From e2728560c338a133eaf595b0b01be5c3faf3843e Mon Sep 17 00:00:00 2001 From: "Ram.Type-0" Date: Thu, 3 Sep 2020 01:30:51 +0900 Subject: [PATCH] Update behavior of IEnumerator.ToUniTask to be more like StartCoroutine --- .../Runtime/EnumeratorAsyncExtensions.cs | 300 ++++++------------ 1 file changed, 95 insertions(+), 205 deletions(-) diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/EnumeratorAsyncExtensions.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/EnumeratorAsyncExtensions.cs index 36dca5a..894e81f 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/EnumeratorAsyncExtensions.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/EnumeratorAsyncExtensions.cs @@ -17,235 +17,125 @@ namespace Cysharp.Threading.Tasks { var e = (IEnumerator)enumerator; Error.ThrowArgumentNullException(e, nameof(enumerator)); - return new UniTask(EnumeratorPromise.Create(e, PlayerLoopTiming.Update, CancellationToken.None, out var token), token).GetAwaiter(); + return StartCoroutineAsUniTask(enumerator).GetAwaiter(); } public static UniTask WithCancellation(this IEnumerator enumerator, CancellationToken cancellationToken) { Error.ThrowArgumentNullException(enumerator, nameof(enumerator)); - return new UniTask(EnumeratorPromise.Create(enumerator, PlayerLoopTiming.Update, cancellationToken, out var token), token); + return StartCoroutineAsUniTask(enumerator, cancellationToken: cancellationToken); } - public static UniTask ToUniTask(this IEnumerator enumerator, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken)) + public static UniTask ToUniTask(this IEnumerator enumerator, PlayerLoopTiming? timing = null, CancellationToken cancellationToken = default(CancellationToken)) { Error.ThrowArgumentNullException(enumerator, nameof(enumerator)); - return new UniTask(EnumeratorPromise.Create(enumerator, timing, cancellationToken, out var token), token); + return StartCoroutineAsUniTask(enumerator, timing, cancellationToken); } - sealed class EnumeratorPromise : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode + static async UniTask StartCoroutineAsUniTask(IEnumerator enumerator, PlayerLoopTiming? timing = default, CancellationToken cancellationToken = default) { - static TaskPool pool; - public EnumeratorPromise NextNode { get; set; } - - static EnumeratorPromise() - { - TaskPool.RegisterSizeGetter(typeof(EnumeratorPromise), () => pool.Size); - } - - IEnumerator innerEnumerator; - CancellationToken cancellationToken; - int initialFrame; - - UniTaskCompletionSourceCore core; - - EnumeratorPromise() - { - } - - public static IUniTaskSource Create(IEnumerator innerEnumerator, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token) - { - if (cancellationToken.IsCancellationRequested) - { - return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token); - } - - if (!pool.TryPop(out var result)) - { - result = new EnumeratorPromise(); - } - TaskTracker.TrackActiveTask(result, 3); - - result.innerEnumerator = ConsumeEnumerator(innerEnumerator); - result.cancellationToken = cancellationToken; - result.initialFrame = -1; - - PlayerLoopHelper.AddAction(timing, result); - - token = result.core.Version; - - result.MoveNext(); // run immediately. - return result; - } - - public void GetResult(short token) - { - try - { - core.GetResult(token); - } - finally - { - TryReturn(); - } - } - - public UniTaskStatus GetStatus(short token) - { - return core.GetStatus(token); - } - - public UniTaskStatus UnsafeGetStatus() - { - return core.UnsafeGetStatus(); - } - - public void OnCompleted(Action continuation, object state, short token) - { - core.OnCompleted(continuation, state, token); - } - - public bool MoveNext() - { - if (cancellationToken.IsCancellationRequested) - { - core.TrySetCanceled(cancellationToken); - return false; - } - - if (initialFrame == -1) - { - // Time can not touch in threadpool. - if (PlayerLoopHelper.IsMainThread) - { - initialFrame = Time.frameCount; - } - } - else if (initialFrame == Time.frameCount) - { - return true; // already executed in first frame, skip. - } - - try - { - if (innerEnumerator.MoveNext()) - { - return true; - } - } - catch (Exception ex) - { - core.TrySetException(ex); - return false; - } - - core.TrySetResult(null); - return false; - } - - bool TryReturn() - { - TaskTracker.RemoveTracking(this); - core.Reset(); - innerEnumerator = default; - cancellationToken = default; - return pool.TryPush(this); - } - - // Unwrap YieldInstructions - - static IEnumerator ConsumeEnumerator(IEnumerator enumerator) + Exception exception = null; + try { + cancellationToken.ThrowIfCancellationRequested(); while (enumerator.MoveNext()) { + cancellationToken.ThrowIfCancellationRequested(); var current = enumerator.Current; - if (current == null) + switch (current) { - yield return null; - } - else if (current is CustomYieldInstruction) - { - // WWW, WaitForSecondsRealtime - var e2 = UnwrapWaitCustomYieldInstruction((CustomYieldInstruction)current); - while (e2.MoveNext()) - { - yield return null; - } - } - else if (current is YieldInstruction) - { - IEnumerator innerCoroutine = null; - switch (current) - { - case AsyncOperation ao: - innerCoroutine = UnwrapWaitAsyncOperation(ao); - break; - case WaitForSeconds wfs: - innerCoroutine = UnwrapWaitForSeconds(wfs); - break; - } - if (innerCoroutine != null) - { - while (innerCoroutine.MoveNext()) + case null: + default: { - yield return null; + await UniTask.Yield(PlayerLoopTiming.LastUpdate); + break; + } + case WaitForFixedUpdate waitForFixedUpdate: + { + await UniTask.Yield(PlayerLoopTiming.LastFixedUpdate); + break; + } + case WaitForEndOfFrame waitForEndOfFrame: + { + await UniTask.WaitForEndOfFrame(); + break; + } + case WaitForSeconds waitForSeconds: + { + var second = (float)waitForSeconds_Seconds.GetValue(waitForSeconds); + var elapsed = 0.0f; + do + { + await UniTask.Yield(PlayerLoopTiming.LastUpdate); + cancellationToken.ThrowIfCancellationRequested(); + elapsed += Time.deltaTime; + } while (elapsed < second); + break; + } + case CustomYieldInstruction cyi: // Include WWW, WaitForSecondsRealtime + { + while (cyi.keepWaiting) + { + await UniTask.Yield(PlayerLoopTiming.LastUpdate); + cancellationToken.ThrowIfCancellationRequested(); + } + break; + } + case IEnumerator innerEnumerator: + { + await StartCoroutineAsUniTask(innerEnumerator, null, cancellationToken); + break; + } + + case AsyncOperation ao: + { + await ao; + break; + } + + case YieldInstruction yieldInstruction: + { + bool isKnownIssue; + switch (yieldInstruction) + { + case Coroutine coroutine: + { + isKnownIssue = true; + break; + } + default: + { + isKnownIssue = false; + break; + } + } + throw new NotSupportedException("Coroutine yields YieldInstruction of type \"" + yieldInstruction.GetType().FullName + (isKnownIssue ? "\". Which is not supported by UniTask." : "\". Which is unknown by UniTask.")); } - } - else - { - yield return null; - } - } - else if (current is IEnumerator e3) - { - var e4 = ConsumeEnumerator(e3); - while (e4.MoveNext()) - { - yield return null; - } - } - else - { - // WaitForEndOfFrame, WaitForFixedUpdate, others. - yield return null; } + cancellationToken.ThrowIfCancellationRequested(); + } + } + catch(Exception e) + { + exception = e; + throw; + } + finally + { + if (timing is PlayerLoopTiming playerLoopTiming && PlayerLoopHelper.TryGetCurrentPlayerLoopTiming() != playerLoopTiming) // To avoid unintentional one frame of extra latency + { + await UniTask.Yield(playerLoopTiming); + + } + if(exception is null) + { + cancellationToken.ThrowIfCancellationRequested(); } } - // WWW and others as CustomYieldInstruction. - static IEnumerator UnwrapWaitCustomYieldInstruction(CustomYieldInstruction yieldInstruction) - { - while (yieldInstruction.keepWaiting) - { - yield return null; - } - } - - static readonly FieldInfo waitForSeconds_Seconds = typeof(WaitForSeconds).GetField("m_Seconds", BindingFlags.Instance | BindingFlags.GetField | BindingFlags.NonPublic); - - static IEnumerator UnwrapWaitForSeconds(WaitForSeconds waitForSeconds) - { - var second = (float)waitForSeconds_Seconds.GetValue(waitForSeconds); - var elapsed = 0.0f; - while (true) - { - yield return null; - - elapsed += Time.deltaTime; - if (elapsed >= second) - { - break; - } - }; - } - - static IEnumerator UnwrapWaitAsyncOperation(AsyncOperation asyncOperation) - { - while (!asyncOperation.isDone) - { - yield return null; - } - } } + + static readonly FieldInfo waitForSeconds_Seconds = typeof(WaitForSeconds).GetField("m_Seconds", BindingFlags.Instance | BindingFlags.GetField | BindingFlags.NonPublic); } }