From f303d9d7e8486d1b31a79a8a9e77aec24163384d Mon Sep 17 00:00:00 2001 From: hadashiA Date: Fri, 8 Sep 2023 13:04:10 +0900 Subject: [PATCH 01/10] Add UniTaskAsyncEnumerable.Merge --- src/UniTask.NetCoreTests/Linq/Merge.cs | 137 +++++++++++ .../Plugins/UniTask/Runtime/Internal/Error.cs | 2 +- .../Plugins/UniTask/Runtime/Linq/Merge.cs | 221 ++++++++++++++++++ .../UniTask/Runtime/Linq/Merge.cs.meta | 3 + 4 files changed, 362 insertions(+), 1 deletion(-) create mode 100644 src/UniTask.NetCoreTests/Linq/Merge.cs create mode 100644 src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs create mode 100644 src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs.meta diff --git a/src/UniTask.NetCoreTests/Linq/Merge.cs b/src/UniTask.NetCoreTests/Linq/Merge.cs new file mode 100644 index 0000000..049ae5a --- /dev/null +++ b/src/UniTask.NetCoreTests/Linq/Merge.cs @@ -0,0 +1,137 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Cysharp.Threading.Tasks; +using Cysharp.Threading.Tasks.Linq; +using FluentAssertions; +using Xunit; + +namespace NetCoreTests.Linq +{ + public class MergeTest + { + [Fact] + public async Task TwoSource() + { + var semaphore = new SemaphoreSlim(1, 1); + + var a = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await UniTask.SwitchToThreadPool(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("A1"); + semaphore.Release(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("A2"); + semaphore.Release(); + }); + + var b = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await UniTask.SwitchToThreadPool(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("B1"); + await writer.YieldAsync("B2"); + semaphore.Release(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("B3"); + semaphore.Release(); + }); + + var result = await a.Merge(b).ToArrayAsync(); + result.Should().Equal("A1", "B1", "B2", "A2", "B3"); + } + + [Fact] + public async Task ThreeSource() + { + var semaphore = new SemaphoreSlim(0, 1); + + var a = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await UniTask.SwitchToThreadPool(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("A1"); + semaphore.Release(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("A2"); + semaphore.Release(); + }); + + var b = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await UniTask.SwitchToThreadPool(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("B1"); + await writer.YieldAsync("B2"); + semaphore.Release(); + + await semaphore.WaitAsync(); + await writer.YieldAsync("B3"); + semaphore.Release(); + }); + + var c = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await UniTask.SwitchToThreadPool(); + + await writer.YieldAsync("C1"); + semaphore.Release(); + }); + + var result = await a.Merge(b, c).ToArrayAsync(); + result.Should().Equal("C1", "A1", "B1", "B2", "A2", "B3"); + } + + [Fact] + public async Task Throw() + { + var a = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await writer.YieldAsync("A1"); + + }); + + var b = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + throw new UniTaskTestException(); + }); + + var enumerator = a.Merge(b).GetAsyncEnumerator(); + (await enumerator.MoveNextAsync()).Should().Be(true); + enumerator.Current.Should().Be("A1"); + + await Assert.ThrowsAsync(async () => await enumerator.MoveNextAsync()); + } + + [Fact] + public async Task Cancel() + { + var cts = new CancellationTokenSource(); + + var a = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await writer.YieldAsync("A1"); + }); + + var b = UniTaskAsyncEnumerable.Create(async (writer, _) => + { + await writer.YieldAsync("B1"); + }); + + var enumerator = a.Merge(b).GetAsyncEnumerator(cts.Token); + (await enumerator.MoveNextAsync()).Should().Be(true); + enumerator.Current.Should().Be("A1"); + + cts.Cancel(); + await Assert.ThrowsAsync(async () => await enumerator.MoveNextAsync()); + } + } +} \ No newline at end of file diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Internal/Error.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Internal/Error.cs index 5c7bc93..9664491 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Internal/Error.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Internal/Error.cs @@ -39,7 +39,7 @@ namespace Cysharp.Threading.Tasks.Internal } [MethodImpl(MethodImplOptions.NoInlining)] - public static void ThrowArgumentException(string message) + public static void ThrowArgumentException(string message) { throw new ArgumentException(message); } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs new file mode 100644 index 0000000..5bc7649 --- /dev/null +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs @@ -0,0 +1,221 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using Cysharp.Threading.Tasks.Internal; + +namespace Cysharp.Threading.Tasks.Linq +{ + public static partial class UniTaskAsyncEnumerable + { + public static IUniTaskAsyncEnumerable Merge(this IUniTaskAsyncEnumerable first, IUniTaskAsyncEnumerable second) + { + Error.ThrowArgumentNullException(first, nameof(first)); + Error.ThrowArgumentNullException(second, nameof(second)); + + return new Merge(new [] { first, second }); + } + + public static IUniTaskAsyncEnumerable Merge(this IUniTaskAsyncEnumerable first, IUniTaskAsyncEnumerable second, IUniTaskAsyncEnumerable third) + { + Error.ThrowArgumentNullException(first, nameof(first)); + Error.ThrowArgumentNullException(second, nameof(second)); + Error.ThrowArgumentNullException(third, nameof(third)); + + return new Merge(new[] { first, second, third }); + } + + public static IUniTaskAsyncEnumerable Merge(this IEnumerable> sources) + { + return new Merge(sources.ToArray()); + } + + public static IUniTaskAsyncEnumerable Merge(params IUniTaskAsyncEnumerable[] sources) + { + return new Merge(sources); + } + } + + internal sealed class Merge : IUniTaskAsyncEnumerable + { + readonly IUniTaskAsyncEnumerable[] sources; + + public Merge(IUniTaskAsyncEnumerable[] sources) + { + if (sources.Length <= 0) + { + Error.ThrowArgumentException("No source async enumerable to merge"); + } + this.sources = sources; + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + => new _Merge(sources, cancellationToken); + + enum MergeSourceState + { + Pending, + Running, + Completed, + } + + sealed class _Merge : MoveNextSource, IUniTaskAsyncEnumerator + { + static readonly Action GetResultAtAction = GetResultAt; + + readonly int length; + readonly IUniTaskAsyncEnumerator[] enumerators; + readonly MergeSourceState[] states; + readonly Queue<(T, Exception)> resultQueue = new Queue<(T, Exception)>(); + readonly CancellationToken cancellationToken; + + public T Current { get; private set; } + + public _Merge(IUniTaskAsyncEnumerable[] sources, CancellationToken cancellationToken) + { + this.cancellationToken = cancellationToken; + length = sources.Length; + states = ArrayPool.Shared.Rent(length); + enumerators = ArrayPool>.Shared.Rent(length); + for (var i = 0; i < length; i++) + { + enumerators[i] = sources[i].GetAsyncEnumerator(cancellationToken); + states[i] = MergeSourceState.Pending; + } + } + + public UniTask MoveNextAsync() + { + cancellationToken.ThrowIfCancellationRequested(); + completionSource.Reset(); + + if (TryDequeue(out var queuedValue, out var queuedException)) + { + if (queuedException != null) + { + completionSource.TrySetException(queuedException); + } + else + { + Current = queuedValue; + completionSource.TrySetResult(!IsCompletedAll()); + } + return new UniTask(this, completionSource.Version); + } + + for (var i = 0; i < length; i++) + { + lock (states) + { + if (states[i] != MergeSourceState.Pending) + { + continue; + } + states[i] = MergeSourceState.Running; + } + + var awaiter = enumerators[i].MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + GetResultAt(i, awaiter); + } + else + { + awaiter.SourceOnCompleted(GetResultAtAction, StateTuple.Create(this, i, awaiter)); + } + } + return new UniTask(this, completionSource.Version); + } + + public async UniTask DisposeAsync() + { + for (var i = 0; i < length; i++) + { + await enumerators[i].DisposeAsync(); + } + + ArrayPool.Shared.Return(states, true); + ArrayPool>.Shared.Return(enumerators, true); + } + + static void GetResultAt(object state) + { + var tuple = (StateTuple<_Merge, int, UniTask.Awaiter>)state; + tuple.Item1.GetResultAt(tuple.Item2, tuple.Item3); + } + + void GetResultAt(int index, UniTask.Awaiter awaiter) + { + bool hasNext; + try + { + hasNext = awaiter.GetResult(); + lock (states) + { + states[index] = hasNext ? MergeSourceState.Pending : MergeSourceState.Completed; + } + } + catch (Exception ex) + { + if (!completionSource.TrySetException(ex)) + { + lock (resultQueue) + { + resultQueue.Enqueue((default, ex)); + } + } + return; + } + + var completed = IsCompletedAll(); + if (hasNext || completed) + { + if (completionSource.GetStatus(completionSource.Version).IsCompleted()) + { + lock (resultQueue) + { + resultQueue.Enqueue((enumerators[index].Current, null)); + } + } + else + { + Current = enumerators[index].Current; + completionSource.TrySetResult(!completed); + } + } + } + + bool TryDequeue(out T value, out Exception ex) + { + lock (resultQueue) + { + if (resultQueue.Count > 0) + { + var result = resultQueue.Dequeue(); + value = result.Item1; + ex = result.Item2; + return true; + } + } + value = default; + ex = default; + return false; + } + + bool IsCompletedAll() + { + lock (states) + { + for (var i = 0; i < length; i++) + { + if (states[i] != MergeSourceState.Completed) + { + return false; + } + } + return true; + } + } + } + } +} \ No newline at end of file diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs.meta b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs.meta new file mode 100644 index 0000000..2f671f4 --- /dev/null +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: ca56812f160c45d0bacb4339819edf1a +timeCreated: 1694133666 \ No newline at end of file From b195df977395b97fc64a9f00105a70df70f2ed79 Mon Sep 17 00:00:00 2001 From: hadashiA Date: Fri, 8 Sep 2023 20:03:08 +0900 Subject: [PATCH 02/10] Update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 8571cb0..b066542 100644 --- a/README.md +++ b/README.md @@ -716,7 +716,7 @@ Async LINQ is enabled when `using Cysharp.Threading.Tasks.Linq;`, and `UniTaskAs It's closer to UniRx (Reactive Extensions), but UniTaskAsyncEnumerable is a pull-based asynchronous stream, whereas Rx was a push-based asynchronous stream. Note that although similar, the characteristics are different and the details behave differently along with them. -`UniTaskAsyncEnumerable` is the entry point like `Enumerable`. In addition to the standard query operators, there are other generators for Unity such as `EveryUpdate`, `Timer`, `TimerFrame`, `Interval`, `IntervalFrame`, and `EveryValueChanged`. And also added additional UniTask original query operators like `Append`, `Prepend`, `DistinctUntilChanged`, `ToHashSet`, `Buffer`, `CombineLatest`, `Do`, `Never`, `ForEachAsync`, `Pairwise`, `Publish`, `Queue`, `Return`, `SkipUntil`, `TakeUntil`, `SkipUntilCanceled`, `TakeUntilCanceled`, `TakeLast`, `Subscribe`. +`UniTaskAsyncEnumerable` is the entry point like `Enumerable`. In addition to the standard query operators, there are other generators for Unity such as `EveryUpdate`, `Timer`, `TimerFrame`, `Interval`, `IntervalFrame`, and `EveryValueChanged`. And also added additional UniTask original query operators like `Append`, `Prepend`, `DistinctUntilChanged`, `ToHashSet`, `Buffer`, `CombineLatest`,`Merge` `Do`, `Never`, `ForEachAsync`, `Pairwise`, `Publish`, `Queue`, `Return`, `SkipUntil`, `TakeUntil`, `SkipUntilCanceled`, `TakeUntilCanceled`, `TakeLast`, `Subscribe`. The method with Func as an argument has three additional overloads, `***Await`, `***AwaitWithCancellation`. From 6e99accf994e1a4ac67cc3f1059cc10ff5b6020f Mon Sep 17 00:00:00 2001 From: hadashiA Date: Fri, 8 Sep 2023 23:38:34 +0900 Subject: [PATCH 03/10] Fix race condition (todo: too wide lock range?) --- src/UniTask.NetCoreTests/Linq/Merge.cs | 60 +++++++------------ .../Plugins/UniTask/Runtime/Linq/Merge.cs | 60 ++++++++++--------- 2 files changed, 54 insertions(+), 66 deletions(-) diff --git a/src/UniTask.NetCoreTests/Linq/Merge.cs b/src/UniTask.NetCoreTests/Linq/Merge.cs index 049ae5a..e669580 100644 --- a/src/UniTask.NetCoreTests/Linq/Merge.cs +++ b/src/UniTask.NetCoreTests/Linq/Merge.cs @@ -13,77 +13,61 @@ namespace NetCoreTests.Linq [Fact] public async Task TwoSource() { - var semaphore = new SemaphoreSlim(1, 1); - var a = UniTaskAsyncEnumerable.Create(async (writer, _) => { await UniTask.SwitchToThreadPool(); - - await semaphore.WaitAsync(); + await writer.YieldAsync("A1"); - semaphore.Release(); - - await semaphore.WaitAsync(); + await Task.Delay(TimeSpan.FromMilliseconds(20)); await writer.YieldAsync("A2"); - semaphore.Release(); }); - + var b = UniTaskAsyncEnumerable.Create(async (writer, _) => { await UniTask.SwitchToThreadPool(); - - await semaphore.WaitAsync(); + + await Task.Delay(TimeSpan.FromMilliseconds(10)); await writer.YieldAsync("B1"); await writer.YieldAsync("B2"); - semaphore.Release(); - - await semaphore.WaitAsync(); + await Task.Delay(TimeSpan.FromMilliseconds(30)); await writer.YieldAsync("B3"); - semaphore.Release(); }); var result = await a.Merge(b).ToArrayAsync(); result.Should().Equal("A1", "B1", "B2", "A2", "B3"); } - + [Fact] public async Task ThreeSource() { - var semaphore = new SemaphoreSlim(0, 1); - var a = UniTaskAsyncEnumerable.Create(async (writer, _) => { await UniTask.SwitchToThreadPool(); - - await semaphore.WaitAsync(); + + await Task.Delay(TimeSpan.FromMilliseconds(10)); await writer.YieldAsync("A1"); - semaphore.Release(); - - await semaphore.WaitAsync(); + + await Task.Delay(TimeSpan.FromMilliseconds(30)); await writer.YieldAsync("A2"); - semaphore.Release(); }); - + var b = UniTaskAsyncEnumerable.Create(async (writer, _) => { await UniTask.SwitchToThreadPool(); - - await semaphore.WaitAsync(); + + await Task.Delay(TimeSpan.FromMilliseconds(20)); await writer.YieldAsync("B1"); await writer.YieldAsync("B2"); - semaphore.Release(); - - await semaphore.WaitAsync(); + + await Task.Delay(TimeSpan.FromMilliseconds(40)); await writer.YieldAsync("B3"); - semaphore.Release(); }); - + var c = UniTaskAsyncEnumerable.Create(async (writer, _) => { await UniTask.SwitchToThreadPool(); - + await writer.YieldAsync("C1"); - semaphore.Release(); }); var result = await a.Merge(b, c).ToArrayAsync(); @@ -107,15 +91,15 @@ namespace NetCoreTests.Linq var enumerator = a.Merge(b).GetAsyncEnumerator(); (await enumerator.MoveNextAsync()).Should().Be(true); enumerator.Current.Should().Be("A1"); - + await Assert.ThrowsAsync(async () => await enumerator.MoveNextAsync()); } - + [Fact] public async Task Cancel() { var cts = new CancellationTokenSource(); - + var a = UniTaskAsyncEnumerable.Create(async (writer, _) => { await writer.YieldAsync("A1"); @@ -129,7 +113,7 @@ namespace NetCoreTests.Linq var enumerator = a.Merge(b).GetAsyncEnumerator(cts.Token); (await enumerator.MoveNextAsync()).Should().Be(true); enumerator.Current.Should().Be("A1"); - + cts.Cancel(); await Assert.ThrowsAsync(async () => await enumerator.MoveNextAsync()); } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs index 5bc7649..f8a5fb0 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs @@ -15,7 +15,7 @@ namespace Cysharp.Threading.Tasks.Linq return new Merge(new [] { first, second }); } - + public static IUniTaskAsyncEnumerable Merge(this IUniTaskAsyncEnumerable first, IUniTaskAsyncEnumerable second, IUniTaskAsyncEnumerable third) { Error.ThrowArgumentNullException(first, nameof(first)); @@ -24,7 +24,7 @@ namespace Cysharp.Threading.Tasks.Linq return new Merge(new[] { first, second, third }); } - + public static IUniTaskAsyncEnumerable Merge(this IEnumerable> sources) { return new Merge(sources.ToArray()); @@ -35,11 +35,11 @@ namespace Cysharp.Threading.Tasks.Linq return new Merge(sources); } } - + internal sealed class Merge : IUniTaskAsyncEnumerable { readonly IUniTaskAsyncEnumerable[] sources; - + public Merge(IUniTaskAsyncEnumerable[] sources) { if (sources.Length <= 0) @@ -49,7 +49,7 @@ namespace Cysharp.Threading.Tasks.Linq this.sources = sources; } - public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) => new _Merge(sources, cancellationToken); enum MergeSourceState @@ -82,27 +82,30 @@ namespace Cysharp.Threading.Tasks.Linq enumerators[i] = sources[i].GetAsyncEnumerator(cancellationToken); states[i] = MergeSourceState.Pending; } - } + } public UniTask MoveNextAsync() { cancellationToken.ThrowIfCancellationRequested(); completionSource.Reset(); - if (TryDequeue(out var queuedValue, out var queuedException)) + lock (states) { - if (queuedException != null) + if (TryDequeue(out var queuedValue, out var queuedException)) { - completionSource.TrySetException(queuedException); + if (queuedException != null) + { + completionSource.TrySetException(queuedException); + } + else + { + Current = queuedValue; + completionSource.TrySetResult(!IsCompletedAll()); + } + return new UniTask(this, completionSource.Version); } - else - { - Current = queuedValue; - completionSource.TrySetResult(!IsCompletedAll()); - } - return new UniTask(this, completionSource.Version); } - + for (var i = 0; i < length; i++) { lock (states) @@ -113,7 +116,7 @@ namespace Cysharp.Threading.Tasks.Linq } states[i] = MergeSourceState.Running; } - + var awaiter = enumerators[i].MoveNextAsync().GetAwaiter(); if (awaiter.IsCompleted) { @@ -159,7 +162,8 @@ namespace Cysharp.Threading.Tasks.Linq { if (!completionSource.TrySetException(ex)) { - lock (resultQueue) + // + lock (states) { resultQueue.Enqueue((default, ex)); } @@ -167,27 +171,27 @@ namespace Cysharp.Threading.Tasks.Linq return; } - var completed = IsCompletedAll(); - if (hasNext || completed) + var completedAll = IsCompletedAll(); + if (hasNext || completedAll) { - if (completionSource.GetStatus(completionSource.Version).IsCompleted()) + lock (states) { - lock (resultQueue) + if (completionSource.GetStatus(completionSource.Version).IsCompleted()) { resultQueue.Enqueue((enumerators[index].Current, null)); } - } - else - { - Current = enumerators[index].Current; - completionSource.TrySetResult(!completed); + else + { + Current = enumerators[index].Current; + completionSource.TrySetResult(!completedAll); + } } } } bool TryDequeue(out T value, out Exception ex) { - lock (resultQueue) + lock (states) { if (resultQueue.Count > 0) { From ba7e676c6f1cd51bd31940dbbdcfacb96c3cbd40 Mon Sep 17 00:00:00 2001 From: hadashiA Date: Sat, 9 Sep 2023 09:05:06 +0900 Subject: [PATCH 04/10] Fix test --- src/UniTask.NetCoreTests/Linq/Merge.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/UniTask.NetCoreTests/Linq/Merge.cs b/src/UniTask.NetCoreTests/Linq/Merge.cs index e669580..4db2214 100644 --- a/src/UniTask.NetCoreTests/Linq/Merge.cs +++ b/src/UniTask.NetCoreTests/Linq/Merge.cs @@ -29,7 +29,7 @@ namespace NetCoreTests.Linq await Task.Delay(TimeSpan.FromMilliseconds(10)); await writer.YieldAsync("B1"); await writer.YieldAsync("B2"); - await Task.Delay(TimeSpan.FromMilliseconds(30)); + await Task.Delay(TimeSpan.FromMilliseconds(100)); await writer.YieldAsync("B3"); }); @@ -47,7 +47,7 @@ namespace NetCoreTests.Linq await Task.Delay(TimeSpan.FromMilliseconds(10)); await writer.YieldAsync("A1"); - await Task.Delay(TimeSpan.FromMilliseconds(30)); + await Task.Delay(TimeSpan.FromMilliseconds(40)); await writer.YieldAsync("A2"); }); @@ -59,7 +59,7 @@ namespace NetCoreTests.Linq await writer.YieldAsync("B1"); await writer.YieldAsync("B2"); - await Task.Delay(TimeSpan.FromMilliseconds(40)); + await Task.Delay(TimeSpan.FromMilliseconds(80)); await writer.YieldAsync("B3"); }); From 6db872236ef90eca99bed3963a250f6f8bbb5671 Mon Sep 17 00:00:00 2001 From: hadashiA Date: Sat, 9 Sep 2023 10:16:01 +0900 Subject: [PATCH 05/10] Fix test --- src/UniTask.NetCoreTests/Linq/Merge.cs | 30 ++++++++++++++++++++------ 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/src/UniTask.NetCoreTests/Linq/Merge.cs b/src/UniTask.NetCoreTests/Linq/Merge.cs index 4db2214..7021d1d 100644 --- a/src/UniTask.NetCoreTests/Linq/Merge.cs +++ b/src/UniTask.NetCoreTests/Linq/Merge.cs @@ -13,24 +13,33 @@ namespace NetCoreTests.Linq [Fact] public async Task TwoSource() { + var semaphore = new SemaphoreSlim(1, 1); + var a = UniTaskAsyncEnumerable.Create(async (writer, _) => { await UniTask.SwitchToThreadPool(); + await semaphore.WaitAsync(); await writer.YieldAsync("A1"); - await Task.Delay(TimeSpan.FromMilliseconds(20)); + semaphore.Release(); + + await semaphore.WaitAsync(); await writer.YieldAsync("A2"); + semaphore.Release(); }); var b = UniTaskAsyncEnumerable.Create(async (writer, _) => { await UniTask.SwitchToThreadPool(); - await Task.Delay(TimeSpan.FromMilliseconds(10)); + await semaphore.WaitAsync(); await writer.YieldAsync("B1"); await writer.YieldAsync("B2"); - await Task.Delay(TimeSpan.FromMilliseconds(100)); + semaphore.Release(); + + await semaphore.WaitAsync(); await writer.YieldAsync("B3"); + semaphore.Release(); }); var result = await a.Merge(b).ToArrayAsync(); @@ -40,27 +49,33 @@ namespace NetCoreTests.Linq [Fact] public async Task ThreeSource() { + var semaphore = new SemaphoreSlim(0, 1); + var a = UniTaskAsyncEnumerable.Create(async (writer, _) => { await UniTask.SwitchToThreadPool(); - await Task.Delay(TimeSpan.FromMilliseconds(10)); + await semaphore.WaitAsync(); await writer.YieldAsync("A1"); + semaphore.Release(); - await Task.Delay(TimeSpan.FromMilliseconds(40)); + await semaphore.WaitAsync(); await writer.YieldAsync("A2"); + semaphore.Release(); }); var b = UniTaskAsyncEnumerable.Create(async (writer, _) => { await UniTask.SwitchToThreadPool(); - await Task.Delay(TimeSpan.FromMilliseconds(20)); + await semaphore.WaitAsync(); await writer.YieldAsync("B1"); await writer.YieldAsync("B2"); + semaphore.Release(); - await Task.Delay(TimeSpan.FromMilliseconds(80)); + await semaphore.WaitAsync(); await writer.YieldAsync("B3"); + semaphore.Release(); }); var c = UniTaskAsyncEnumerable.Create(async (writer, _) => @@ -68,6 +83,7 @@ namespace NetCoreTests.Linq await UniTask.SwitchToThreadPool(); await writer.YieldAsync("C1"); + semaphore.Release(); }); var result = await a.Merge(b, c).ToArrayAsync(); From 730d68132d562032ec83d4ccaad2d4a1ec6202cc Mon Sep 17 00:00:00 2001 From: hadashiA Date: Sat, 9 Sep 2023 14:27:06 +0900 Subject: [PATCH 06/10] Tweaks --- .../Plugins/UniTask/Runtime/Linq/Merge.cs | 36 ++++++------------- 1 file changed, 11 insertions(+), 25 deletions(-) diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs index f8a5fb0..001daae 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs @@ -66,7 +66,7 @@ namespace Cysharp.Threading.Tasks.Linq readonly int length; readonly IUniTaskAsyncEnumerator[] enumerators; readonly MergeSourceState[] states; - readonly Queue<(T, Exception)> resultQueue = new Queue<(T, Exception)>(); + readonly Queue<(T, Exception)> queuedResult = new Queue<(T, Exception)>(); readonly CancellationToken cancellationToken; public T Current { get; private set; } @@ -89,10 +89,14 @@ namespace Cysharp.Threading.Tasks.Linq cancellationToken.ThrowIfCancellationRequested(); completionSource.Reset(); - lock (states) + lock (queuedResult) { - if (TryDequeue(out var queuedValue, out var queuedException)) + if (queuedResult.Count > 0) { + var result = queuedResult.Dequeue(); + var queuedValue = result.Item1; + var queuedException = result.Item2; + if (queuedException != null) { completionSource.TrySetException(queuedException); @@ -162,10 +166,9 @@ namespace Cysharp.Threading.Tasks.Linq { if (!completionSource.TrySetException(ex)) { - // - lock (states) + lock (queuedResult) { - resultQueue.Enqueue((default, ex)); + queuedResult.Enqueue((default, ex)); } } return; @@ -174,11 +177,11 @@ namespace Cysharp.Threading.Tasks.Linq var completedAll = IsCompletedAll(); if (hasNext || completedAll) { - lock (states) + lock (queuedResult) { if (completionSource.GetStatus(completionSource.Version).IsCompleted()) { - resultQueue.Enqueue((enumerators[index].Current, null)); + queuedResult.Enqueue((enumerators[index].Current, null)); } else { @@ -189,23 +192,6 @@ namespace Cysharp.Threading.Tasks.Linq } } - bool TryDequeue(out T value, out Exception ex) - { - lock (states) - { - if (resultQueue.Count > 0) - { - var result = resultQueue.Dequeue(); - value = result.Item1; - ex = result.Item2; - return true; - } - } - value = default; - ex = default; - return false; - } - bool IsCompletedAll() { lock (states) From ea57847c9703f3782f0a8bdcd748c1f80cb73ad5 Mon Sep 17 00:00:00 2001 From: hadashiA Date: Sat, 9 Sep 2023 17:04:02 +0900 Subject: [PATCH 07/10] Add dispose --- src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs index 001daae..226b9fb 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs @@ -147,8 +147,10 @@ namespace Cysharp.Threading.Tasks.Linq static void GetResultAt(object state) { - var tuple = (StateTuple<_Merge, int, UniTask.Awaiter>)state; - tuple.Item1.GetResultAt(tuple.Item2, tuple.Item3); + using (var tuple = (StateTuple<_Merge, int, UniTask.Awaiter>)state) + { + tuple.Item1.GetResultAt(tuple.Item2, tuple.Item3); + } } void GetResultAt(int index, UniTask.Awaiter awaiter) From 3bac16229fb1a83eb82104947e56a570937b3e28 Mon Sep 17 00:00:00 2001 From: hadashiA Date: Sun, 10 Sep 2023 23:33:29 +0900 Subject: [PATCH 08/10] Reduce lock --- .../Plugins/UniTask/Runtime/Linq/Merge.cs | 49 ++++++++----------- 1 file changed, 20 insertions(+), 29 deletions(-) diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs index 226b9fb..86240a6 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs @@ -65,7 +65,7 @@ namespace Cysharp.Threading.Tasks.Linq readonly int length; readonly IUniTaskAsyncEnumerator[] enumerators; - readonly MergeSourceState[] states; + readonly int[] states; readonly Queue<(T, Exception)> queuedResult = new Queue<(T, Exception)>(); readonly CancellationToken cancellationToken; @@ -75,12 +75,12 @@ namespace Cysharp.Threading.Tasks.Linq { this.cancellationToken = cancellationToken; length = sources.Length; - states = ArrayPool.Shared.Rent(length); + states = ArrayPool.Shared.Rent(length); enumerators = ArrayPool>.Shared.Rent(length); for (var i = 0; i < length; i++) { enumerators[i] = sources[i].GetAsyncEnumerator(cancellationToken); - states[i] = MergeSourceState.Pending; + states[i] = (int)MergeSourceState.Pending;; } } @@ -112,23 +112,17 @@ namespace Cysharp.Threading.Tasks.Linq for (var i = 0; i < length; i++) { - lock (states) + if (Interlocked.CompareExchange(ref states[i], (int)MergeSourceState.Running, (int)MergeSourceState.Pending) == (int)MergeSourceState.Pending) { - if (states[i] != MergeSourceState.Pending) + var awaiter = enumerators[i].MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) { - continue; + GetResultAt(i, awaiter); + } + else + { + awaiter.SourceOnCompleted(GetResultAtAction, StateTuple.Create(this, i, awaiter)); } - states[i] = MergeSourceState.Running; - } - - var awaiter = enumerators[i].MoveNextAsync().GetAwaiter(); - if (awaiter.IsCompleted) - { - GetResultAt(i, awaiter); - } - else - { - awaiter.SourceOnCompleted(GetResultAtAction, StateTuple.Create(this, i, awaiter)); } } return new UniTask(this, completionSource.Version); @@ -141,7 +135,7 @@ namespace Cysharp.Threading.Tasks.Linq await enumerators[i].DisposeAsync(); } - ArrayPool.Shared.Return(states, true); + ArrayPool.Shared.Return(states, true); ArrayPool>.Shared.Return(enumerators, true); } @@ -159,10 +153,7 @@ namespace Cysharp.Threading.Tasks.Linq try { hasNext = awaiter.GetResult(); - lock (states) - { - states[index] = hasNext ? MergeSourceState.Pending : MergeSourceState.Completed; - } + Interlocked.Exchange(ref states[index], (int)(hasNext ? MergeSourceState.Pending : MergeSourceState.Completed)); } catch (Exception ex) { @@ -196,16 +187,16 @@ namespace Cysharp.Threading.Tasks.Linq bool IsCompletedAll() { - lock (states) + for (var i = 0; i < length; i++) { - for (var i = 0; i < length; i++) + if (states[i] != (int)MergeSourceState.Completed) { - if (states[i] != MergeSourceState.Completed) - { - return false; - } + return false; } - return true; + } + lock (queuedResult) + { + return queuedResult.Count <= 0; } } } From 937d3adf66dc1a80853acb3d6edfd951359b08e3 Mon Sep 17 00:00:00 2001 From: hadashiA Date: Tue, 12 Sep 2023 14:34:53 +0900 Subject: [PATCH 09/10] Fix race condition --- .../Plugins/UniTask/Runtime/Linq/Merge.cs | 39 +++++++++++-------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs index 86240a6..f129082 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs @@ -65,7 +65,7 @@ namespace Cysharp.Threading.Tasks.Linq readonly int length; readonly IUniTaskAsyncEnumerator[] enumerators; - readonly int[] states; + readonly MergeSourceState[] states; readonly Queue<(T, Exception)> queuedResult = new Queue<(T, Exception)>(); readonly CancellationToken cancellationToken; @@ -75,7 +75,7 @@ namespace Cysharp.Threading.Tasks.Linq { this.cancellationToken = cancellationToken; length = sources.Length; - states = ArrayPool.Shared.Rent(length); + states = ArrayPool.Shared.Rent(length); enumerators = ArrayPool>.Shared.Rent(length); for (var i = 0; i < length; i++) { @@ -112,18 +112,26 @@ namespace Cysharp.Threading.Tasks.Linq for (var i = 0; i < length; i++) { - if (Interlocked.CompareExchange(ref states[i], (int)MergeSourceState.Running, (int)MergeSourceState.Pending) == (int)MergeSourceState.Pending) + lock (queuedResult) { - var awaiter = enumerators[i].MoveNextAsync().GetAwaiter(); - if (awaiter.IsCompleted) + if (states[i] == (int)MergeSourceState.Pending) { - GetResultAt(i, awaiter); + states[i] = MergeSourceState.Running; } else { - awaiter.SourceOnCompleted(GetResultAtAction, StateTuple.Create(this, i, awaiter)); + continue; } } + var awaiter = enumerators[i].MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + GetResultAt(i, awaiter); + } + else + { + awaiter.SourceOnCompleted(GetResultAtAction, StateTuple.Create(this, i, awaiter)); + } } return new UniTask(this, completionSource.Version); } @@ -135,7 +143,7 @@ namespace Cysharp.Threading.Tasks.Linq await enumerators[i].DisposeAsync(); } - ArrayPool.Shared.Return(states, true); + ArrayPool.Shared.Return(states, true); ArrayPool>.Shared.Return(enumerators, true); } @@ -153,7 +161,6 @@ namespace Cysharp.Threading.Tasks.Linq try { hasNext = awaiter.GetResult(); - Interlocked.Exchange(ref states[index], (int)(hasNext ? MergeSourceState.Pending : MergeSourceState.Completed)); } catch (Exception ex) { @@ -167,10 +174,11 @@ namespace Cysharp.Threading.Tasks.Linq return; } - var completedAll = IsCompletedAll(); - if (hasNext || completedAll) + lock (queuedResult) { - lock (queuedResult) + states[index] = hasNext ? MergeSourceState.Pending : MergeSourceState.Completed; + var completedAll = !hasNext && IsCompletedAll(); + if (hasNext || completedAll) { if (completionSource.GetStatus(completionSource.Version).IsCompleted()) { @@ -189,15 +197,12 @@ namespace Cysharp.Threading.Tasks.Linq { for (var i = 0; i < length; i++) { - if (states[i] != (int)MergeSourceState.Completed) + if (states[i] != MergeSourceState.Completed) { return false; } } - lock (queuedResult) - { - return queuedResult.Count <= 0; - } + return true; } } } From 3ba64412f8b5f502004892652c4bb7e9b284779d Mon Sep 17 00:00:00 2001 From: hadashiA Date: Wed, 13 Sep 2023 19:12:11 +0900 Subject: [PATCH 10/10] Reduce the lock --- .../Plugins/UniTask/Runtime/Linq/Merge.cs | 93 ++++++++++++------- 1 file changed, 58 insertions(+), 35 deletions(-) diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs index f129082..d4ea969 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs @@ -66,9 +66,11 @@ namespace Cysharp.Threading.Tasks.Linq readonly int length; readonly IUniTaskAsyncEnumerator[] enumerators; readonly MergeSourceState[] states; - readonly Queue<(T, Exception)> queuedResult = new Queue<(T, Exception)>(); + readonly Queue<(T, Exception, bool)> queuedResult = new Queue<(T, Exception, bool)>(); readonly CancellationToken cancellationToken; + int moveNextCompleted; + public T Current { get; private set; } public _Merge(IUniTaskAsyncEnumerable[] sources, CancellationToken cancellationToken) @@ -88,33 +90,35 @@ namespace Cysharp.Threading.Tasks.Linq { cancellationToken.ThrowIfCancellationRequested(); completionSource.Reset(); + Interlocked.Exchange(ref moveNextCompleted, 0); - lock (queuedResult) + if (HasQueuedResult() && Interlocked.CompareExchange(ref moveNextCompleted, 1, 0) == 0) { - if (queuedResult.Count > 0) + (T, Exception, bool) value; + lock (states) { - var result = queuedResult.Dequeue(); - var queuedValue = result.Item1; - var queuedException = result.Item2; - - if (queuedException != null) - { - completionSource.TrySetException(queuedException); - } - else - { - Current = queuedValue; - completionSource.TrySetResult(!IsCompletedAll()); - } - return new UniTask(this, completionSource.Version); + value = queuedResult.Dequeue(); } + var resultValue = value.Item1; + var exception = value.Item2; + var hasNext = value.Item3; + if (exception != null) + { + completionSource.TrySetException(exception); + } + else + { + Current = resultValue; + completionSource.TrySetResult(hasNext); + } + return new UniTask(this, completionSource.Version); } for (var i = 0; i < length; i++) { - lock (queuedResult) + lock (states) { - if (states[i] == (int)MergeSourceState.Pending) + if (states[i] == MergeSourceState.Pending) { states[i] = MergeSourceState.Running; } @@ -158,48 +162,67 @@ namespace Cysharp.Threading.Tasks.Linq void GetResultAt(int index, UniTask.Awaiter awaiter) { bool hasNext; + bool completedAll; try { hasNext = awaiter.GetResult(); } catch (Exception ex) { - if (!completionSource.TrySetException(ex)) + if (Interlocked.CompareExchange(ref moveNextCompleted, 1, 0) == 0) { - lock (queuedResult) + completionSource.TrySetException(ex); + } + else + { + lock (states) { - queuedResult.Enqueue((default, ex)); + queuedResult.Enqueue((default, ex, default)); } } return; } - lock (queuedResult) + lock (states) { states[index] = hasNext ? MergeSourceState.Pending : MergeSourceState.Completed; - var completedAll = !hasNext && IsCompletedAll(); - if (hasNext || completedAll) + completedAll = !hasNext && IsCompletedAll(); + } + if (hasNext || completedAll) + { + if (Interlocked.CompareExchange(ref moveNextCompleted, 1, 0) == 0) { - if (completionSource.GetStatus(completionSource.Version).IsCompleted()) + Current = enumerators[index].Current; + completionSource.TrySetResult(!completedAll); + } + else + { + lock (states) { - queuedResult.Enqueue((enumerators[index].Current, null)); - } - else - { - Current = enumerators[index].Current; - completionSource.TrySetResult(!completedAll); + queuedResult.Enqueue((enumerators[index].Current, null, !completedAll)); } } } } + bool HasQueuedResult() + { + lock (states) + { + return queuedResult.Count > 0; + } + } + bool IsCompletedAll() { - for (var i = 0; i < length; i++) + lock (states) { - if (states[i] != MergeSourceState.Completed) + for (var i = 0; i < length; i++) { - return false; + if (states[i] != MergeSourceState.Completed) + { + return false; + } } } return true;