Fix race condition (todo: too wide lock range?)

pull/498/head
hadashiA 2023-09-08 23:38:34 +09:00
parent b195df9773
commit 6e99accf99
2 changed files with 54 additions and 66 deletions

View File

@ -13,33 +13,24 @@ namespace NetCoreTests.Linq
[Fact] [Fact]
public async Task TwoSource() public async Task TwoSource()
{ {
var semaphore = new SemaphoreSlim(1, 1);
var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) => var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{ {
await UniTask.SwitchToThreadPool(); await UniTask.SwitchToThreadPool();
await semaphore.WaitAsync();
await writer.YieldAsync("A1"); await writer.YieldAsync("A1");
semaphore.Release(); await Task.Delay(TimeSpan.FromMilliseconds(20));
await semaphore.WaitAsync();
await writer.YieldAsync("A2"); await writer.YieldAsync("A2");
semaphore.Release();
}); });
var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) => var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{ {
await UniTask.SwitchToThreadPool(); await UniTask.SwitchToThreadPool();
await semaphore.WaitAsync(); await Task.Delay(TimeSpan.FromMilliseconds(10));
await writer.YieldAsync("B1"); await writer.YieldAsync("B1");
await writer.YieldAsync("B2"); await writer.YieldAsync("B2");
semaphore.Release(); await Task.Delay(TimeSpan.FromMilliseconds(30));
await semaphore.WaitAsync();
await writer.YieldAsync("B3"); await writer.YieldAsync("B3");
semaphore.Release();
}); });
var result = await a.Merge(b).ToArrayAsync(); var result = await a.Merge(b).ToArrayAsync();
@ -49,33 +40,27 @@ namespace NetCoreTests.Linq
[Fact] [Fact]
public async Task ThreeSource() public async Task ThreeSource()
{ {
var semaphore = new SemaphoreSlim(0, 1);
var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) => var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{ {
await UniTask.SwitchToThreadPool(); await UniTask.SwitchToThreadPool();
await semaphore.WaitAsync(); await Task.Delay(TimeSpan.FromMilliseconds(10));
await writer.YieldAsync("A1"); await writer.YieldAsync("A1");
semaphore.Release();
await semaphore.WaitAsync(); await Task.Delay(TimeSpan.FromMilliseconds(30));
await writer.YieldAsync("A2"); await writer.YieldAsync("A2");
semaphore.Release();
}); });
var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) => var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{ {
await UniTask.SwitchToThreadPool(); await UniTask.SwitchToThreadPool();
await semaphore.WaitAsync(); await Task.Delay(TimeSpan.FromMilliseconds(20));
await writer.YieldAsync("B1"); await writer.YieldAsync("B1");
await writer.YieldAsync("B2"); await writer.YieldAsync("B2");
semaphore.Release();
await semaphore.WaitAsync(); await Task.Delay(TimeSpan.FromMilliseconds(40));
await writer.YieldAsync("B3"); await writer.YieldAsync("B3");
semaphore.Release();
}); });
var c = UniTaskAsyncEnumerable.Create<string>(async (writer, _) => var c = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
@ -83,7 +68,6 @@ namespace NetCoreTests.Linq
await UniTask.SwitchToThreadPool(); await UniTask.SwitchToThreadPool();
await writer.YieldAsync("C1"); await writer.YieldAsync("C1");
semaphore.Release();
}); });
var result = await a.Merge(b, c).ToArrayAsync(); var result = await a.Merge(b, c).ToArrayAsync();

View File

@ -89,6 +89,8 @@ namespace Cysharp.Threading.Tasks.Linq
cancellationToken.ThrowIfCancellationRequested(); cancellationToken.ThrowIfCancellationRequested();
completionSource.Reset(); completionSource.Reset();
lock (states)
{
if (TryDequeue(out var queuedValue, out var queuedException)) if (TryDequeue(out var queuedValue, out var queuedException))
{ {
if (queuedException != null) if (queuedException != null)
@ -102,6 +104,7 @@ namespace Cysharp.Threading.Tasks.Linq
} }
return new UniTask<bool>(this, completionSource.Version); return new UniTask<bool>(this, completionSource.Version);
} }
}
for (var i = 0; i < length; i++) for (var i = 0; i < length; i++)
{ {
@ -159,7 +162,8 @@ namespace Cysharp.Threading.Tasks.Linq
{ {
if (!completionSource.TrySetException(ex)) if (!completionSource.TrySetException(ex))
{ {
lock (resultQueue) //
lock (states)
{ {
resultQueue.Enqueue((default, ex)); resultQueue.Enqueue((default, ex));
} }
@ -167,27 +171,27 @@ namespace Cysharp.Threading.Tasks.Linq
return; return;
} }
var completed = IsCompletedAll(); var completedAll = IsCompletedAll();
if (hasNext || completed) if (hasNext || completedAll)
{
lock (states)
{ {
if (completionSource.GetStatus(completionSource.Version).IsCompleted()) if (completionSource.GetStatus(completionSource.Version).IsCompleted())
{
lock (resultQueue)
{ {
resultQueue.Enqueue((enumerators[index].Current, null)); resultQueue.Enqueue((enumerators[index].Current, null));
} }
}
else else
{ {
Current = enumerators[index].Current; Current = enumerators[index].Current;
completionSource.TrySetResult(!completed); completionSource.TrySetResult(!completedAll);
}
} }
} }
} }
bool TryDequeue(out T value, out Exception ex) bool TryDequeue(out T value, out Exception ex)
{ {
lock (resultQueue) lock (states)
{ {
if (resultQueue.Count > 0) if (resultQueue.Count > 0)
{ {