diff --git a/src/UniTask.NetCoreTests/AsyncReactivePropertyTest.cs b/src/UniTask.NetCoreTests/AsyncReactivePropertyTest.cs index 21d1ef1..33eee0d 100644 --- a/src/UniTask.NetCoreTests/AsyncReactivePropertyTest.cs +++ b/src/UniTask.NetCoreTests/AsyncReactivePropertyTest.cs @@ -112,6 +112,85 @@ namespace NetCoreTests state.Value.Should().Be(20); } + [Fact] + public async Task WaitAsyncTest() + { + var rp = new AsyncReactiveProperty(128); + + var f = await rp.FirstAsync(); + f.Should().Be(128); + + { + var t = rp.WaitAsync(); + rp.Value = 99; + rp.Value = 100; + var v = await t; + + v.Should().Be(99); + } + + { + var t = rp.WaitAsync(); + rp.Value = 99; + rp.Value = 100; + var v = await t; + + v.Should().Be(99); + } + } + + + [Fact] + public async Task WaitAsyncCancellationTest() + { + var cts = new CancellationTokenSource(); + + var rp = new AsyncReactiveProperty(128); + + var t = rp.WaitAsync(cts.Token); + + cts.Cancel(); + + rp.Value = 99; + rp.Value = 100; + + await Assert.ThrowsAsync(async () => { await t; }); + } + + + [Fact] + public async Task ReadOnlyWaitAsyncTest() + { + var rp = new AsyncReactiveProperty(128); + var rrp = rp.ToReadOnlyAsyncReactiveProperty(CancellationToken.None); + + var t = rrp.WaitAsync(); + rp.Value = 99; + rp.Value = 100; + var v = await t; + + v.Should().Be(99); + } + + + [Fact] + public async Task ReadOnlyWaitAsyncCancellationTest() + { + var cts = new CancellationTokenSource(); + + var rp = new AsyncReactiveProperty(128); + var rrp = rp.ToReadOnlyAsyncReactiveProperty(CancellationToken.None); + + var t = rrp.WaitAsync(cts.Token); + + cts.Cancel(); + + rp.Value = 99; + rp.Value = 100; + + await Assert.ThrowsAsync(async () => { await t; }); + } + } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/AsyncReactiveProperty.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/AsyncReactiveProperty.cs index 8c84d6d..c6315f5 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/AsyncReactiveProperty.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/AsyncReactiveProperty.cs @@ -7,6 +7,7 @@ namespace Cysharp.Threading.Tasks { T Value { get; } IUniTaskAsyncEnumerable WithoutCurrent(); + UniTask WaitAsync(CancellationToken cancellationToken = default); } public interface IAsyncReactiveProperty : IReadOnlyAsyncReactiveProperty @@ -69,6 +70,11 @@ namespace Cysharp.Threading.Tasks return latestValue?.ToString(); } + public UniTask WaitAsync(CancellationToken cancellationToken = default) + { + return new UniTask(WaitAsyncSource.Create(this, cancellationToken, out var token), token); + } + static bool isValueType; static AsyncReactiveProperty() @@ -76,7 +82,143 @@ namespace Cysharp.Threading.Tasks isValueType = typeof(T).IsValueType; } - class WithoutCurrentEnumerable : IUniTaskAsyncEnumerable + sealed class WaitAsyncSource : IUniTaskSource, ITriggerHandler, ITaskPoolNode + { + static Action cancellationCallback = CancellationCallback; + + static TaskPool pool; + WaitAsyncSource ITaskPoolNode.NextNode { get; set; } + + static WaitAsyncSource() + { + TaskPool.RegisterSizeGetter(typeof(WaitAsyncSource), () => pool.Size); + } + + AsyncReactiveProperty parent; + CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + UniTaskCompletionSourceCore core; + + WaitAsyncSource() + { + } + + public static IUniTaskSource Create(AsyncReactiveProperty parent, CancellationToken cancellationToken, out short token) + { + if (cancellationToken.IsCancellationRequested) + { + return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); + } + + if (!pool.TryPop(out var result)) + { + result = new WaitAsyncSource(); + } + + result.parent = parent; + result.cancellationToken = cancellationToken; + + if (cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(cancellationCallback, result); + } + + result.parent.triggerEvent.Add(result); + + TaskTracker.TrackActiveTask(result, 3); + + token = result.core.Version; + return result; + } + + bool TryReturn() + { + TaskTracker.RemoveTracking(this); + core.Reset(); + cancellationTokenRegistration.Dispose(); + cancellationTokenRegistration = default; + parent.triggerEvent.Remove(this); + parent = null; + cancellationToken = default; + return pool.TryPush(this); + } + + ~WaitAsyncSource() + { + if (TryReturn()) + { + GC.ReRegisterForFinalize(this); + } + } + + static void CancellationCallback(object state) + { + var self = (WaitAsyncSource)state; + self.OnCanceled(self.cancellationToken); + } + + // IUniTaskSource + + public T GetResult(short token) + { + try + { + return core.GetResult(token); + } + finally + { + TryReturn(); + } + } + + void IUniTaskSource.GetResult(short token) + { + GetResult(token); + } + + public void OnCompleted(Action continuation, object state, short token) + { + core.OnCompleted(continuation, state, token); + } + + public UniTaskStatus GetStatus(short token) + { + return core.GetStatus(token); + } + + public UniTaskStatus UnsafeGetStatus() + { + return core.UnsafeGetStatus(); + } + + // ITriggerHandler + + ITriggerHandler ITriggerHandler.Prev { get; set; } + ITriggerHandler ITriggerHandler.Next { get; set; } + + public void OnCanceled(CancellationToken cancellationToken) + { + core.TrySetCanceled(cancellationToken); + } + + public void OnCompleted() + { + // Complete as Cancel. + core.TrySetCanceled(CancellationToken.None); + } + + public void OnError(Exception ex) + { + core.TrySetException(ex); + } + + public void OnNext(T value) + { + core.TrySetResult(value); + } + } + + sealed class WithoutCurrentEnumerable : IUniTaskAsyncEnumerable { readonly AsyncReactiveProperty parent; @@ -253,6 +395,11 @@ namespace Cysharp.Threading.Tasks return latestValue?.ToString(); } + public UniTask WaitAsync(CancellationToken cancellationToken = default) + { + return new UniTask(WaitAsyncSource.Create(this, cancellationToken, out var token), token); + } + static bool isValueType; static ReadOnlyAsyncReactiveProperty() @@ -260,7 +407,143 @@ namespace Cysharp.Threading.Tasks isValueType = typeof(T).IsValueType; } - class WithoutCurrentEnumerable : IUniTaskAsyncEnumerable + sealed class WaitAsyncSource : IUniTaskSource, ITriggerHandler, ITaskPoolNode + { + static Action cancellationCallback = CancellationCallback; + + static TaskPool pool; + WaitAsyncSource ITaskPoolNode.NextNode { get; set; } + + static WaitAsyncSource() + { + TaskPool.RegisterSizeGetter(typeof(WaitAsyncSource), () => pool.Size); + } + + ReadOnlyAsyncReactiveProperty parent; + CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + UniTaskCompletionSourceCore core; + + WaitAsyncSource() + { + } + + public static IUniTaskSource Create(ReadOnlyAsyncReactiveProperty parent, CancellationToken cancellationToken, out short token) + { + if (cancellationToken.IsCancellationRequested) + { + return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); + } + + if (!pool.TryPop(out var result)) + { + result = new WaitAsyncSource(); + } + + result.parent = parent; + result.cancellationToken = cancellationToken; + + if (cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(cancellationCallback, result); + } + + result.parent.triggerEvent.Add(result); + + TaskTracker.TrackActiveTask(result, 3); + + token = result.core.Version; + return result; + } + + bool TryReturn() + { + TaskTracker.RemoveTracking(this); + core.Reset(); + cancellationTokenRegistration.Dispose(); + cancellationTokenRegistration = default; + parent.triggerEvent.Remove(this); + parent = null; + cancellationToken = default; + return pool.TryPush(this); + } + + ~WaitAsyncSource() + { + if (TryReturn()) + { + GC.ReRegisterForFinalize(this); + } + } + + static void CancellationCallback(object state) + { + var self = (WaitAsyncSource)state; + self.OnCanceled(self.cancellationToken); + } + + // IUniTaskSource + + public T GetResult(short token) + { + try + { + return core.GetResult(token); + } + finally + { + TryReturn(); + } + } + + void IUniTaskSource.GetResult(short token) + { + GetResult(token); + } + + public void OnCompleted(Action continuation, object state, short token) + { + core.OnCompleted(continuation, state, token); + } + + public UniTaskStatus GetStatus(short token) + { + return core.GetStatus(token); + } + + public UniTaskStatus UnsafeGetStatus() + { + return core.UnsafeGetStatus(); + } + + // ITriggerHandler + + ITriggerHandler ITriggerHandler.Prev { get; set; } + ITriggerHandler ITriggerHandler.Next { get; set; } + + public void OnCanceled(CancellationToken cancellationToken) + { + core.TrySetCanceled(cancellationToken); + } + + public void OnCompleted() + { + // Complete as Cancel. + core.TrySetCanceled(CancellationToken.None); + } + + public void OnError(Exception ex) + { + core.TrySetException(ex); + } + + public void OnNext(T value) + { + core.TrySetResult(value); + } + } + + sealed class WithoutCurrentEnumerable : IUniTaskAsyncEnumerable { readonly ReadOnlyAsyncReactiveProperty parent;