mirror of https://github.com/Cysharp/UniTask
Fix race condition (todo: too wide lock range?)
parent
b195df9773
commit
6e99accf99
|
@ -13,77 +13,61 @@ namespace NetCoreTests.Linq
|
||||||
[Fact]
|
[Fact]
|
||||||
public async Task TwoSource()
|
public async Task TwoSource()
|
||||||
{
|
{
|
||||||
var semaphore = new SemaphoreSlim(1, 1);
|
|
||||||
|
|
||||||
var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
|
var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
|
||||||
{
|
{
|
||||||
await UniTask.SwitchToThreadPool();
|
await UniTask.SwitchToThreadPool();
|
||||||
|
|
||||||
await semaphore.WaitAsync();
|
|
||||||
await writer.YieldAsync("A1");
|
await writer.YieldAsync("A1");
|
||||||
semaphore.Release();
|
await Task.Delay(TimeSpan.FromMilliseconds(20));
|
||||||
|
|
||||||
await semaphore.WaitAsync();
|
|
||||||
await writer.YieldAsync("A2");
|
await writer.YieldAsync("A2");
|
||||||
semaphore.Release();
|
|
||||||
});
|
});
|
||||||
|
|
||||||
var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
|
var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
|
||||||
{
|
{
|
||||||
await UniTask.SwitchToThreadPool();
|
await UniTask.SwitchToThreadPool();
|
||||||
|
|
||||||
await semaphore.WaitAsync();
|
await Task.Delay(TimeSpan.FromMilliseconds(10));
|
||||||
await writer.YieldAsync("B1");
|
await writer.YieldAsync("B1");
|
||||||
await writer.YieldAsync("B2");
|
await writer.YieldAsync("B2");
|
||||||
semaphore.Release();
|
await Task.Delay(TimeSpan.FromMilliseconds(30));
|
||||||
|
|
||||||
await semaphore.WaitAsync();
|
|
||||||
await writer.YieldAsync("B3");
|
await writer.YieldAsync("B3");
|
||||||
semaphore.Release();
|
|
||||||
});
|
});
|
||||||
|
|
||||||
var result = await a.Merge(b).ToArrayAsync();
|
var result = await a.Merge(b).ToArrayAsync();
|
||||||
result.Should().Equal("A1", "B1", "B2", "A2", "B3");
|
result.Should().Equal("A1", "B1", "B2", "A2", "B3");
|
||||||
}
|
}
|
||||||
|
|
||||||
[Fact]
|
[Fact]
|
||||||
public async Task ThreeSource()
|
public async Task ThreeSource()
|
||||||
{
|
{
|
||||||
var semaphore = new SemaphoreSlim(0, 1);
|
|
||||||
|
|
||||||
var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
|
var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
|
||||||
{
|
{
|
||||||
await UniTask.SwitchToThreadPool();
|
await UniTask.SwitchToThreadPool();
|
||||||
|
|
||||||
await semaphore.WaitAsync();
|
await Task.Delay(TimeSpan.FromMilliseconds(10));
|
||||||
await writer.YieldAsync("A1");
|
await writer.YieldAsync("A1");
|
||||||
semaphore.Release();
|
|
||||||
|
await Task.Delay(TimeSpan.FromMilliseconds(30));
|
||||||
await semaphore.WaitAsync();
|
|
||||||
await writer.YieldAsync("A2");
|
await writer.YieldAsync("A2");
|
||||||
semaphore.Release();
|
|
||||||
});
|
});
|
||||||
|
|
||||||
var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
|
var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
|
||||||
{
|
{
|
||||||
await UniTask.SwitchToThreadPool();
|
await UniTask.SwitchToThreadPool();
|
||||||
|
|
||||||
await semaphore.WaitAsync();
|
await Task.Delay(TimeSpan.FromMilliseconds(20));
|
||||||
await writer.YieldAsync("B1");
|
await writer.YieldAsync("B1");
|
||||||
await writer.YieldAsync("B2");
|
await writer.YieldAsync("B2");
|
||||||
semaphore.Release();
|
|
||||||
|
await Task.Delay(TimeSpan.FromMilliseconds(40));
|
||||||
await semaphore.WaitAsync();
|
|
||||||
await writer.YieldAsync("B3");
|
await writer.YieldAsync("B3");
|
||||||
semaphore.Release();
|
|
||||||
});
|
});
|
||||||
|
|
||||||
var c = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
|
var c = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
|
||||||
{
|
{
|
||||||
await UniTask.SwitchToThreadPool();
|
await UniTask.SwitchToThreadPool();
|
||||||
|
|
||||||
await writer.YieldAsync("C1");
|
await writer.YieldAsync("C1");
|
||||||
semaphore.Release();
|
|
||||||
});
|
});
|
||||||
|
|
||||||
var result = await a.Merge(b, c).ToArrayAsync();
|
var result = await a.Merge(b, c).ToArrayAsync();
|
||||||
|
@ -107,15 +91,15 @@ namespace NetCoreTests.Linq
|
||||||
var enumerator = a.Merge(b).GetAsyncEnumerator();
|
var enumerator = a.Merge(b).GetAsyncEnumerator();
|
||||||
(await enumerator.MoveNextAsync()).Should().Be(true);
|
(await enumerator.MoveNextAsync()).Should().Be(true);
|
||||||
enumerator.Current.Should().Be("A1");
|
enumerator.Current.Should().Be("A1");
|
||||||
|
|
||||||
await Assert.ThrowsAsync<UniTaskTestException>(async () => await enumerator.MoveNextAsync());
|
await Assert.ThrowsAsync<UniTaskTestException>(async () => await enumerator.MoveNextAsync());
|
||||||
}
|
}
|
||||||
|
|
||||||
[Fact]
|
[Fact]
|
||||||
public async Task Cancel()
|
public async Task Cancel()
|
||||||
{
|
{
|
||||||
var cts = new CancellationTokenSource();
|
var cts = new CancellationTokenSource();
|
||||||
|
|
||||||
var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
|
var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) =>
|
||||||
{
|
{
|
||||||
await writer.YieldAsync("A1");
|
await writer.YieldAsync("A1");
|
||||||
|
@ -129,7 +113,7 @@ namespace NetCoreTests.Linq
|
||||||
var enumerator = a.Merge(b).GetAsyncEnumerator(cts.Token);
|
var enumerator = a.Merge(b).GetAsyncEnumerator(cts.Token);
|
||||||
(await enumerator.MoveNextAsync()).Should().Be(true);
|
(await enumerator.MoveNextAsync()).Should().Be(true);
|
||||||
enumerator.Current.Should().Be("A1");
|
enumerator.Current.Should().Be("A1");
|
||||||
|
|
||||||
cts.Cancel();
|
cts.Cancel();
|
||||||
await Assert.ThrowsAsync<OperationCanceledException>(async () => await enumerator.MoveNextAsync());
|
await Assert.ThrowsAsync<OperationCanceledException>(async () => await enumerator.MoveNextAsync());
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ namespace Cysharp.Threading.Tasks.Linq
|
||||||
|
|
||||||
return new Merge<T>(new [] { first, second });
|
return new Merge<T>(new [] { first, second });
|
||||||
}
|
}
|
||||||
|
|
||||||
public static IUniTaskAsyncEnumerable<T> Merge<T>(this IUniTaskAsyncEnumerable<T> first, IUniTaskAsyncEnumerable<T> second, IUniTaskAsyncEnumerable<T> third)
|
public static IUniTaskAsyncEnumerable<T> Merge<T>(this IUniTaskAsyncEnumerable<T> first, IUniTaskAsyncEnumerable<T> second, IUniTaskAsyncEnumerable<T> third)
|
||||||
{
|
{
|
||||||
Error.ThrowArgumentNullException(first, nameof(first));
|
Error.ThrowArgumentNullException(first, nameof(first));
|
||||||
|
@ -24,7 +24,7 @@ namespace Cysharp.Threading.Tasks.Linq
|
||||||
|
|
||||||
return new Merge<T>(new[] { first, second, third });
|
return new Merge<T>(new[] { first, second, third });
|
||||||
}
|
}
|
||||||
|
|
||||||
public static IUniTaskAsyncEnumerable<T> Merge<T>(this IEnumerable<IUniTaskAsyncEnumerable<T>> sources)
|
public static IUniTaskAsyncEnumerable<T> Merge<T>(this IEnumerable<IUniTaskAsyncEnumerable<T>> sources)
|
||||||
{
|
{
|
||||||
return new Merge<T>(sources.ToArray());
|
return new Merge<T>(sources.ToArray());
|
||||||
|
@ -35,11 +35,11 @@ namespace Cysharp.Threading.Tasks.Linq
|
||||||
return new Merge<T>(sources);
|
return new Merge<T>(sources);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
internal sealed class Merge<T> : IUniTaskAsyncEnumerable<T>
|
internal sealed class Merge<T> : IUniTaskAsyncEnumerable<T>
|
||||||
{
|
{
|
||||||
readonly IUniTaskAsyncEnumerable<T>[] sources;
|
readonly IUniTaskAsyncEnumerable<T>[] sources;
|
||||||
|
|
||||||
public Merge(IUniTaskAsyncEnumerable<T>[] sources)
|
public Merge(IUniTaskAsyncEnumerable<T>[] sources)
|
||||||
{
|
{
|
||||||
if (sources.Length <= 0)
|
if (sources.Length <= 0)
|
||||||
|
@ -49,7 +49,7 @@ namespace Cysharp.Threading.Tasks.Linq
|
||||||
this.sources = sources;
|
this.sources = sources;
|
||||||
}
|
}
|
||||||
|
|
||||||
public IUniTaskAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
|
public IUniTaskAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
|
||||||
=> new _Merge(sources, cancellationToken);
|
=> new _Merge(sources, cancellationToken);
|
||||||
|
|
||||||
enum MergeSourceState
|
enum MergeSourceState
|
||||||
|
@ -82,27 +82,30 @@ namespace Cysharp.Threading.Tasks.Linq
|
||||||
enumerators[i] = sources[i].GetAsyncEnumerator(cancellationToken);
|
enumerators[i] = sources[i].GetAsyncEnumerator(cancellationToken);
|
||||||
states[i] = MergeSourceState.Pending;
|
states[i] = MergeSourceState.Pending;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public UniTask<bool> MoveNextAsync()
|
public UniTask<bool> MoveNextAsync()
|
||||||
{
|
{
|
||||||
cancellationToken.ThrowIfCancellationRequested();
|
cancellationToken.ThrowIfCancellationRequested();
|
||||||
completionSource.Reset();
|
completionSource.Reset();
|
||||||
|
|
||||||
if (TryDequeue(out var queuedValue, out var queuedException))
|
lock (states)
|
||||||
{
|
{
|
||||||
if (queuedException != null)
|
if (TryDequeue(out var queuedValue, out var queuedException))
|
||||||
{
|
{
|
||||||
completionSource.TrySetException(queuedException);
|
if (queuedException != null)
|
||||||
|
{
|
||||||
|
completionSource.TrySetException(queuedException);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
Current = queuedValue;
|
||||||
|
completionSource.TrySetResult(!IsCompletedAll());
|
||||||
|
}
|
||||||
|
return new UniTask<bool>(this, completionSource.Version);
|
||||||
}
|
}
|
||||||
else
|
|
||||||
{
|
|
||||||
Current = queuedValue;
|
|
||||||
completionSource.TrySetResult(!IsCompletedAll());
|
|
||||||
}
|
|
||||||
return new UniTask<bool>(this, completionSource.Version);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (var i = 0; i < length; i++)
|
for (var i = 0; i < length; i++)
|
||||||
{
|
{
|
||||||
lock (states)
|
lock (states)
|
||||||
|
@ -113,7 +116,7 @@ namespace Cysharp.Threading.Tasks.Linq
|
||||||
}
|
}
|
||||||
states[i] = MergeSourceState.Running;
|
states[i] = MergeSourceState.Running;
|
||||||
}
|
}
|
||||||
|
|
||||||
var awaiter = enumerators[i].MoveNextAsync().GetAwaiter();
|
var awaiter = enumerators[i].MoveNextAsync().GetAwaiter();
|
||||||
if (awaiter.IsCompleted)
|
if (awaiter.IsCompleted)
|
||||||
{
|
{
|
||||||
|
@ -159,7 +162,8 @@ namespace Cysharp.Threading.Tasks.Linq
|
||||||
{
|
{
|
||||||
if (!completionSource.TrySetException(ex))
|
if (!completionSource.TrySetException(ex))
|
||||||
{
|
{
|
||||||
lock (resultQueue)
|
//
|
||||||
|
lock (states)
|
||||||
{
|
{
|
||||||
resultQueue.Enqueue((default, ex));
|
resultQueue.Enqueue((default, ex));
|
||||||
}
|
}
|
||||||
|
@ -167,27 +171,27 @@ namespace Cysharp.Threading.Tasks.Linq
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
var completed = IsCompletedAll();
|
var completedAll = IsCompletedAll();
|
||||||
if (hasNext || completed)
|
if (hasNext || completedAll)
|
||||||
{
|
{
|
||||||
if (completionSource.GetStatus(completionSource.Version).IsCompleted())
|
lock (states)
|
||||||
{
|
{
|
||||||
lock (resultQueue)
|
if (completionSource.GetStatus(completionSource.Version).IsCompleted())
|
||||||
{
|
{
|
||||||
resultQueue.Enqueue((enumerators[index].Current, null));
|
resultQueue.Enqueue((enumerators[index].Current, null));
|
||||||
}
|
}
|
||||||
}
|
else
|
||||||
else
|
{
|
||||||
{
|
Current = enumerators[index].Current;
|
||||||
Current = enumerators[index].Current;
|
completionSource.TrySetResult(!completedAll);
|
||||||
completionSource.TrySetResult(!completed);
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TryDequeue(out T value, out Exception ex)
|
bool TryDequeue(out T value, out Exception ex)
|
||||||
{
|
{
|
||||||
lock (resultQueue)
|
lock (states)
|
||||||
{
|
{
|
||||||
if (resultQueue.Count > 0)
|
if (resultQueue.Count > 0)
|
||||||
{
|
{
|
||||||
|
|
Loading…
Reference in New Issue