Fix race condition (todo: too wide lock range?)

pull/498/head
hadashiA 2023-09-08 23:38:34 +09:00
parent b195df9773
commit 6e99accf99
2 changed files with 54 additions and 66 deletions

View File

@ -13,77 +13,61 @@ namespace NetCoreTests.Linq
[Fact] [Fact]
public async Task TwoSource() public async Task TwoSource()
{ {
var semaphore = new SemaphoreSlim(1, 1);
var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) => var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{ {
await UniTask.SwitchToThreadPool(); await UniTask.SwitchToThreadPool();
await semaphore.WaitAsync();
await writer.YieldAsync("A1"); await writer.YieldAsync("A1");
semaphore.Release(); await Task.Delay(TimeSpan.FromMilliseconds(20));
await semaphore.WaitAsync();
await writer.YieldAsync("A2"); await writer.YieldAsync("A2");
semaphore.Release();
}); });
var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) => var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{ {
await UniTask.SwitchToThreadPool(); await UniTask.SwitchToThreadPool();
await semaphore.WaitAsync(); await Task.Delay(TimeSpan.FromMilliseconds(10));
await writer.YieldAsync("B1"); await writer.YieldAsync("B1");
await writer.YieldAsync("B2"); await writer.YieldAsync("B2");
semaphore.Release(); await Task.Delay(TimeSpan.FromMilliseconds(30));
await semaphore.WaitAsync();
await writer.YieldAsync("B3"); await writer.YieldAsync("B3");
semaphore.Release();
}); });
var result = await a.Merge(b).ToArrayAsync(); var result = await a.Merge(b).ToArrayAsync();
result.Should().Equal("A1", "B1", "B2", "A2", "B3"); result.Should().Equal("A1", "B1", "B2", "A2", "B3");
} }
[Fact] [Fact]
public async Task ThreeSource() public async Task ThreeSource()
{ {
var semaphore = new SemaphoreSlim(0, 1);
var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) => var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{ {
await UniTask.SwitchToThreadPool(); await UniTask.SwitchToThreadPool();
await semaphore.WaitAsync(); await Task.Delay(TimeSpan.FromMilliseconds(10));
await writer.YieldAsync("A1"); await writer.YieldAsync("A1");
semaphore.Release();
await Task.Delay(TimeSpan.FromMilliseconds(30));
await semaphore.WaitAsync();
await writer.YieldAsync("A2"); await writer.YieldAsync("A2");
semaphore.Release();
}); });
var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) => var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{ {
await UniTask.SwitchToThreadPool(); await UniTask.SwitchToThreadPool();
await semaphore.WaitAsync(); await Task.Delay(TimeSpan.FromMilliseconds(20));
await writer.YieldAsync("B1"); await writer.YieldAsync("B1");
await writer.YieldAsync("B2"); await writer.YieldAsync("B2");
semaphore.Release();
await Task.Delay(TimeSpan.FromMilliseconds(40));
await semaphore.WaitAsync();
await writer.YieldAsync("B3"); await writer.YieldAsync("B3");
semaphore.Release();
}); });
var c = UniTaskAsyncEnumerable.Create<string>(async (writer, _) => var c = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{ {
await UniTask.SwitchToThreadPool(); await UniTask.SwitchToThreadPool();
await writer.YieldAsync("C1"); await writer.YieldAsync("C1");
semaphore.Release();
}); });
var result = await a.Merge(b, c).ToArrayAsync(); var result = await a.Merge(b, c).ToArrayAsync();
@ -107,15 +91,15 @@ namespace NetCoreTests.Linq
var enumerator = a.Merge(b).GetAsyncEnumerator(); var enumerator = a.Merge(b).GetAsyncEnumerator();
(await enumerator.MoveNextAsync()).Should().Be(true); (await enumerator.MoveNextAsync()).Should().Be(true);
enumerator.Current.Should().Be("A1"); enumerator.Current.Should().Be("A1");
await Assert.ThrowsAsync<UniTaskTestException>(async () => await enumerator.MoveNextAsync()); await Assert.ThrowsAsync<UniTaskTestException>(async () => await enumerator.MoveNextAsync());
} }
[Fact] [Fact]
public async Task Cancel() public async Task Cancel()
{ {
var cts = new CancellationTokenSource(); var cts = new CancellationTokenSource();
var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) => var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
{ {
await writer.YieldAsync("A1"); await writer.YieldAsync("A1");
@ -129,7 +113,7 @@ namespace NetCoreTests.Linq
var enumerator = a.Merge(b).GetAsyncEnumerator(cts.Token); var enumerator = a.Merge(b).GetAsyncEnumerator(cts.Token);
(await enumerator.MoveNextAsync()).Should().Be(true); (await enumerator.MoveNextAsync()).Should().Be(true);
enumerator.Current.Should().Be("A1"); enumerator.Current.Should().Be("A1");
cts.Cancel(); cts.Cancel();
await Assert.ThrowsAsync<OperationCanceledException>(async () => await enumerator.MoveNextAsync()); await Assert.ThrowsAsync<OperationCanceledException>(async () => await enumerator.MoveNextAsync());
} }

View File

@ -15,7 +15,7 @@ namespace Cysharp.Threading.Tasks.Linq
return new Merge<T>(new [] { first, second }); return new Merge<T>(new [] { first, second });
} }
public static IUniTaskAsyncEnumerable<T> Merge<T>(this IUniTaskAsyncEnumerable<T> first, IUniTaskAsyncEnumerable<T> second, IUniTaskAsyncEnumerable<T> third) public static IUniTaskAsyncEnumerable<T> Merge<T>(this IUniTaskAsyncEnumerable<T> first, IUniTaskAsyncEnumerable<T> second, IUniTaskAsyncEnumerable<T> third)
{ {
Error.ThrowArgumentNullException(first, nameof(first)); Error.ThrowArgumentNullException(first, nameof(first));
@ -24,7 +24,7 @@ namespace Cysharp.Threading.Tasks.Linq
return new Merge<T>(new[] { first, second, third }); return new Merge<T>(new[] { first, second, third });
} }
public static IUniTaskAsyncEnumerable<T> Merge<T>(this IEnumerable<IUniTaskAsyncEnumerable<T>> sources) public static IUniTaskAsyncEnumerable<T> Merge<T>(this IEnumerable<IUniTaskAsyncEnumerable<T>> sources)
{ {
return new Merge<T>(sources.ToArray()); return new Merge<T>(sources.ToArray());
@ -35,11 +35,11 @@ namespace Cysharp.Threading.Tasks.Linq
return new Merge<T>(sources); return new Merge<T>(sources);
} }
} }
internal sealed class Merge<T> : IUniTaskAsyncEnumerable<T> internal sealed class Merge<T> : IUniTaskAsyncEnumerable<T>
{ {
readonly IUniTaskAsyncEnumerable<T>[] sources; readonly IUniTaskAsyncEnumerable<T>[] sources;
public Merge(IUniTaskAsyncEnumerable<T>[] sources) public Merge(IUniTaskAsyncEnumerable<T>[] sources)
{ {
if (sources.Length <= 0) if (sources.Length <= 0)
@ -49,7 +49,7 @@ namespace Cysharp.Threading.Tasks.Linq
this.sources = sources; this.sources = sources;
} }
public IUniTaskAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default) public IUniTaskAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
=> new _Merge(sources, cancellationToken); => new _Merge(sources, cancellationToken);
enum MergeSourceState enum MergeSourceState
@ -82,27 +82,30 @@ namespace Cysharp.Threading.Tasks.Linq
enumerators[i] = sources[i].GetAsyncEnumerator(cancellationToken); enumerators[i] = sources[i].GetAsyncEnumerator(cancellationToken);
states[i] = MergeSourceState.Pending; states[i] = MergeSourceState.Pending;
} }
} }
public UniTask<bool> MoveNextAsync() public UniTask<bool> MoveNextAsync()
{ {
cancellationToken.ThrowIfCancellationRequested(); cancellationToken.ThrowIfCancellationRequested();
completionSource.Reset(); completionSource.Reset();
if (TryDequeue(out var queuedValue, out var queuedException)) lock (states)
{ {
if (queuedException != null) if (TryDequeue(out var queuedValue, out var queuedException))
{ {
completionSource.TrySetException(queuedException); if (queuedException != null)
{
completionSource.TrySetException(queuedException);
}
else
{
Current = queuedValue;
completionSource.TrySetResult(!IsCompletedAll());
}
return new UniTask<bool>(this, completionSource.Version);
} }
else
{
Current = queuedValue;
completionSource.TrySetResult(!IsCompletedAll());
}
return new UniTask<bool>(this, completionSource.Version);
} }
for (var i = 0; i < length; i++) for (var i = 0; i < length; i++)
{ {
lock (states) lock (states)
@ -113,7 +116,7 @@ namespace Cysharp.Threading.Tasks.Linq
} }
states[i] = MergeSourceState.Running; states[i] = MergeSourceState.Running;
} }
var awaiter = enumerators[i].MoveNextAsync().GetAwaiter(); var awaiter = enumerators[i].MoveNextAsync().GetAwaiter();
if (awaiter.IsCompleted) if (awaiter.IsCompleted)
{ {
@ -159,7 +162,8 @@ namespace Cysharp.Threading.Tasks.Linq
{ {
if (!completionSource.TrySetException(ex)) if (!completionSource.TrySetException(ex))
{ {
lock (resultQueue) //
lock (states)
{ {
resultQueue.Enqueue((default, ex)); resultQueue.Enqueue((default, ex));
} }
@ -167,27 +171,27 @@ namespace Cysharp.Threading.Tasks.Linq
return; return;
} }
var completed = IsCompletedAll(); var completedAll = IsCompletedAll();
if (hasNext || completed) if (hasNext || completedAll)
{ {
if (completionSource.GetStatus(completionSource.Version).IsCompleted()) lock (states)
{ {
lock (resultQueue) if (completionSource.GetStatus(completionSource.Version).IsCompleted())
{ {
resultQueue.Enqueue((enumerators[index].Current, null)); resultQueue.Enqueue((enumerators[index].Current, null));
} }
} else
else {
{ Current = enumerators[index].Current;
Current = enumerators[index].Current; completionSource.TrySetResult(!completedAll);
completionSource.TrySetResult(!completed); }
} }
} }
} }
bool TryDequeue(out T value, out Exception ex) bool TryDequeue(out T value, out Exception ex)
{ {
lock (resultQueue) lock (states)
{ {
if (resultQueue.Count > 0) if (resultQueue.Count > 0)
{ {