Reduce lock

pull/498/head
hadashiA 2023-09-10 23:33:29 +09:00
parent ea57847c97
commit 3bac16229f
1 changed files with 20 additions and 29 deletions

View File

@ -65,7 +65,7 @@ namespace Cysharp.Threading.Tasks.Linq
readonly int length;
readonly IUniTaskAsyncEnumerator<T>[] enumerators;
readonly MergeSourceState[] states;
readonly int[] states;
readonly Queue<(T, Exception)> queuedResult = new Queue<(T, Exception)>();
readonly CancellationToken cancellationToken;
@ -75,12 +75,12 @@ namespace Cysharp.Threading.Tasks.Linq
{
this.cancellationToken = cancellationToken;
length = sources.Length;
states = ArrayPool<MergeSourceState>.Shared.Rent(length);
states = ArrayPool<int>.Shared.Rent(length);
enumerators = ArrayPool<IUniTaskAsyncEnumerator<T>>.Shared.Rent(length);
for (var i = 0; i < length; i++)
{
enumerators[i] = sources[i].GetAsyncEnumerator(cancellationToken);
states[i] = MergeSourceState.Pending;
states[i] = (int)MergeSourceState.Pending;;
}
}
@ -112,23 +112,17 @@ namespace Cysharp.Threading.Tasks.Linq
for (var i = 0; i < length; i++)
{
lock (states)
if (Interlocked.CompareExchange(ref states[i], (int)MergeSourceState.Running, (int)MergeSourceState.Pending) == (int)MergeSourceState.Pending)
{
if (states[i] != MergeSourceState.Pending)
var awaiter = enumerators[i].MoveNextAsync().GetAwaiter();
if (awaiter.IsCompleted)
{
continue;
GetResultAt(i, awaiter);
}
else
{
awaiter.SourceOnCompleted(GetResultAtAction, StateTuple.Create(this, i, awaiter));
}
states[i] = MergeSourceState.Running;
}
var awaiter = enumerators[i].MoveNextAsync().GetAwaiter();
if (awaiter.IsCompleted)
{
GetResultAt(i, awaiter);
}
else
{
awaiter.SourceOnCompleted(GetResultAtAction, StateTuple.Create(this, i, awaiter));
}
}
return new UniTask<bool>(this, completionSource.Version);
@ -141,7 +135,7 @@ namespace Cysharp.Threading.Tasks.Linq
await enumerators[i].DisposeAsync();
}
ArrayPool<MergeSourceState>.Shared.Return(states, true);
ArrayPool<int>.Shared.Return(states, true);
ArrayPool<IUniTaskAsyncEnumerator<T>>.Shared.Return(enumerators, true);
}
@ -159,10 +153,7 @@ namespace Cysharp.Threading.Tasks.Linq
try
{
hasNext = awaiter.GetResult();
lock (states)
{
states[index] = hasNext ? MergeSourceState.Pending : MergeSourceState.Completed;
}
Interlocked.Exchange(ref states[index], (int)(hasNext ? MergeSourceState.Pending : MergeSourceState.Completed));
}
catch (Exception ex)
{
@ -196,16 +187,16 @@ namespace Cysharp.Threading.Tasks.Linq
bool IsCompletedAll()
{
lock (states)
for (var i = 0; i < length; i++)
{
for (var i = 0; i < length; i++)
if (states[i] != (int)MergeSourceState.Completed)
{
if (states[i] != MergeSourceState.Completed)
{
return false;
}
return false;
}
return true;
}
lock (queuedResult)
{
return queuedResult.Count <= 0;
}
}
}