diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Factory.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Factory.cs index bd66aba..8d4f0c5 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Factory.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Factory.cs @@ -202,6 +202,22 @@ namespace Cysharp.Threading.Tasks return new UniTask(new DeferPromise(factory), 0); } + /// + /// Defer the task creation just before call await. + /// + public static UniTask Defer(TState state, Func factory) + { + return new UniTask(new DeferPromiseWithState(state, factory), 0); + } + + /// + /// Defer the task creation just before call await. + /// + public static UniTask Defer(TState state, Func> factory) + { + return new UniTask(new DeferPromiseWithState(state, factory), 0); + } + /// /// Never complete. /// @@ -465,6 +481,93 @@ namespace Cysharp.Threading.Tasks } } + sealed class DeferPromiseWithState : IUniTaskSource + { + Func factory; + TState argument; + UniTask task; + UniTask.Awaiter awaiter; + + public DeferPromiseWithState(TState argument, Func factory) + { + this.argument = argument; + this.factory = factory; + } + + public void GetResult(short token) + { + awaiter.GetResult(); + } + + public UniTaskStatus GetStatus(short token) + { + var f = Interlocked.Exchange(ref factory, null); + if (f != null) + { + task = f(argument); + awaiter = task.GetAwaiter(); + } + + return task.Status; + } + + public void OnCompleted(Action continuation, object state, short token) + { + awaiter.SourceOnCompleted(continuation, state); + } + + public UniTaskStatus UnsafeGetStatus() + { + return task.Status; + } + } + + sealed class DeferPromiseWithState : IUniTaskSource + { + Func> factory; + TState argument; + UniTask task; + UniTask.Awaiter awaiter; + + public DeferPromiseWithState(TState argument, Func> factory) + { + this.argument = argument; + this.factory = factory; + } + + public TResult GetResult(short token) + { + return awaiter.GetResult(); + } + + void IUniTaskSource.GetResult(short token) + { + awaiter.GetResult(); + } + + public UniTaskStatus GetStatus(short token) + { + var f = Interlocked.Exchange(ref factory, null); + if (f != null) + { + task = f(argument); + awaiter = task.GetAwaiter(); + } + + return task.Status; + } + + public void OnCompleted(Action continuation, object state, short token) + { + awaiter.SourceOnCompleted(continuation, state); + } + + public UniTaskStatus UnsafeGetStatus() + { + return task.Status; + } + } + sealed class NeverPromise : IUniTaskSource { static readonly Action cancellationCallback = CancellationCallback; diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.WaitUntil.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.WaitUntil.cs index c5e5915..2f29b64 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.WaitUntil.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.WaitUntil.cs @@ -15,11 +15,21 @@ namespace Cysharp.Threading.Tasks return new UniTask(WaitUntilPromise.Create(predicate, timing, cancellationToken, cancelImmediately, out var token), token); } + public static UniTask WaitUntil(T state, Func predicate, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) + { + return new UniTask(WaitUntilPromise.Create(state, predicate, timing, cancellationToken, cancelImmediately, out var token), token); + } + public static UniTask WaitWhile(Func predicate, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) { return new UniTask(WaitWhilePromise.Create(predicate, timing, cancellationToken, cancelImmediately, out var token), token); } + public static UniTask WaitWhile(T state, Func predicate, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken), bool cancelImmediately = false) + { + return new UniTask(WaitWhilePromise.Create(state, predicate, timing, cancellationToken, cancelImmediately, out var token), token); + } + public static UniTask WaitUntilCanceled(CancellationToken cancellationToken, PlayerLoopTiming timing = PlayerLoopTiming.Update, bool completeImmediately = false) { return new UniTask(WaitUntilCanceledPromise.Create(cancellationToken, timing, completeImmediately, out var token), token); @@ -162,6 +172,131 @@ namespace Cysharp.Threading.Tasks } } + sealed class WaitUntilPromise : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode> + { + static TaskPool> pool; + WaitUntilPromise nextNode; + public ref WaitUntilPromise NextNode => ref nextNode; + + static WaitUntilPromise() + { + TaskPool.RegisterSizeGetter(typeof(WaitUntilPromise), () => pool.Size); + } + + Func predicate; + T argument; + CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + bool cancelImmediately; + + UniTaskCompletionSourceCore core; + + WaitUntilPromise() + { + } + + public static IUniTaskSource Create(T argument, Func predicate, PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) + { + if (cancellationToken.IsCancellationRequested) + { + return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); + } + + if (!pool.TryPop(out var result)) + { + result = new WaitUntilPromise(); + } + + result.predicate = predicate; + result.argument = argument; + result.cancellationToken = cancellationToken; + result.cancelImmediately = cancelImmediately; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (WaitUntilPromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } + + TaskTracker.TrackActiveTask(result, 3); + + PlayerLoopHelper.AddAction(timing, result); + + token = result.core.Version; + return result; + } + + public void GetResult(short token) + { + try + { + core.GetResult(token); + } + finally + { + if (!(cancelImmediately && cancellationToken.IsCancellationRequested)) + { + TryReturn(); + } + } + } + + public UniTaskStatus GetStatus(short token) + { + return core.GetStatus(token); + } + + public UniTaskStatus UnsafeGetStatus() + { + return core.UnsafeGetStatus(); + } + + public void OnCompleted(Action continuation, object state, short token) + { + core.OnCompleted(continuation, state, token); + } + + public bool MoveNext() + { + if (cancellationToken.IsCancellationRequested) + { + core.TrySetCanceled(cancellationToken); + return false; + } + + try + { + if (!predicate(argument)) + { + return true; + } + } + catch (Exception ex) + { + core.TrySetException(ex); + return false; + } + + core.TrySetResult(null); + return false; + } + + bool TryReturn() + { + TaskTracker.RemoveTracking(this); + core.Reset(); + predicate = default; + argument = default; + cancellationToken = default; + cancellationTokenRegistration.Dispose(); + cancelImmediately = default; + return pool.TryPush(this); + } + } + sealed class WaitWhilePromise : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode { static TaskPool pool; @@ -199,7 +334,7 @@ namespace Cysharp.Threading.Tasks result.predicate = predicate; result.cancellationToken = cancellationToken; result.cancelImmediately = cancelImmediately; - + if (cancelImmediately && cancellationToken.CanBeCanceled) { result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => @@ -288,6 +423,130 @@ namespace Cysharp.Threading.Tasks } } + sealed class WaitWhilePromise : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode> + { + static TaskPool> pool; + WaitWhilePromise nextNode; + public ref WaitWhilePromise NextNode => ref nextNode; + + static WaitWhilePromise() + { + TaskPool.RegisterSizeGetter(typeof(WaitWhilePromise), () => pool.Size); + } + + Func predicate; + T argument; + CancellationToken cancellationToken; + CancellationTokenRegistration cancellationTokenRegistration; + bool cancelImmediately; + + UniTaskCompletionSourceCore core; + + WaitWhilePromise() + { + } + + public static IUniTaskSource Create(T argument, Func predicate, PlayerLoopTiming timing, CancellationToken cancellationToken, bool cancelImmediately, out short token) + { + if (cancellationToken.IsCancellationRequested) + { + return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); + } + + if (!pool.TryPop(out var result)) + { + result = new WaitWhilePromise(); + } + + result.predicate = predicate; + result.argument = argument; + result.cancellationToken = cancellationToken; + + if (cancelImmediately && cancellationToken.CanBeCanceled) + { + result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => + { + var promise = (WaitWhilePromise)state; + promise.core.TrySetCanceled(promise.cancellationToken); + }, result); + } + + TaskTracker.TrackActiveTask(result, 3); + + PlayerLoopHelper.AddAction(timing, result); + + token = result.core.Version; + return result; + } + + public void GetResult(short token) + { + try + { + core.GetResult(token); + } + finally + { + if (!(cancelImmediately && cancellationToken.IsCancellationRequested)) + { + TryReturn(); + } + } + } + + public UniTaskStatus GetStatus(short token) + { + return core.GetStatus(token); + } + + public UniTaskStatus UnsafeGetStatus() + { + return core.UnsafeGetStatus(); + } + + public void OnCompleted(Action continuation, object state, short token) + { + core.OnCompleted(continuation, state, token); + } + + public bool MoveNext() + { + if (cancellationToken.IsCancellationRequested) + { + core.TrySetCanceled(cancellationToken); + return false; + } + + try + { + if (predicate(argument)) + { + return true; + } + } + catch (Exception ex) + { + core.TrySetException(ex); + return false; + } + + core.TrySetResult(null); + return false; + } + + bool TryReturn() + { + TaskTracker.RemoveTracking(this); + core.Reset(); + predicate = default; + argument = default; + cancellationToken = default; + cancellationTokenRegistration.Dispose(); + cancelImmediately = default; + return pool.TryPush(this); + } + } + sealed class WaitUntilCanceledPromise : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode { static TaskPool pool; @@ -443,7 +702,7 @@ namespace Cysharp.Threading.Tasks result.equalityComparer = equalityComparer ?? UnityEqualityComparer.GetDefault(); result.cancellationToken = cancellationToken; result.cancelImmediately = cancelImmediately; - + if (cancelImmediately && cancellationToken.CanBeCanceled) { result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => @@ -586,7 +845,7 @@ namespace Cysharp.Threading.Tasks result.equalityComparer = equalityComparer ?? UnityEqualityComparer.GetDefault(); result.cancellationToken = cancellationToken; result.cancelImmediately = cancelImmediately; - + if (cancelImmediately && cancellationToken.CanBeCanceled) { result.cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(state => diff --git a/src/UniTask/Assets/Tests/AsyncTest.cs b/src/UniTask/Assets/Tests/AsyncTest.cs index 3cd98a3..29645da 100644 --- a/src/UniTask/Assets/Tests/AsyncTest.cs +++ b/src/UniTask/Assets/Tests/AsyncTest.cs @@ -145,6 +145,11 @@ namespace Cysharp.Threading.TasksTests public int MyProperty { get; set; } } + class MyBoolenClass + { + public bool MyProperty { get; set; } + } + [UnityTest] public IEnumerator WaitUntil() => UniTask.ToCoroutine(async () => { @@ -159,6 +164,20 @@ namespace Cysharp.Threading.TasksTests diff.Should().Be(11); }); + [UnityTest] + public IEnumerator WaitUntilWithState() => UniTask.ToCoroutine(async () => + { + var v = new MyBoolenClass { MyProperty = false }; + + UniTask.DelayFrame(10, PlayerLoopTiming.PostLateUpdate).ContinueWith(() => v.MyProperty = true).Forget(); + + var startFrame = Time.frameCount; + await UniTask.WaitUntil(v, static v => v.MyProperty, PlayerLoopTiming.EarlyUpdate); + + var diff = Time.frameCount - startFrame; + diff.Should().Be(11); + }); + [UnityTest] public IEnumerator WaitWhile() => UniTask.ToCoroutine(async () => { @@ -173,6 +192,20 @@ namespace Cysharp.Threading.TasksTests diff.Should().Be(11); }); + [UnityTest] + public IEnumerator WaitWhileWithState() => UniTask.ToCoroutine(async () => + { + var v = new MyBoolenClass { MyProperty = true }; + + UniTask.DelayFrame(10, PlayerLoopTiming.PostLateUpdate).ContinueWith(() => v.MyProperty = false).Forget(); + + var startFrame = Time.frameCount; + await UniTask.WaitWhile(v, static v => v.MyProperty, PlayerLoopTiming.EarlyUpdate); + + var diff = Time.frameCount - startFrame; + diff.Should().Be(11); + }); + [UnityTest] public IEnumerator WaitUntilValueChanged() => UniTask.ToCoroutine(async () => {