Update behavior of IEnumerator.ToUniTask to be more like StartCoroutine

pull/159/head
Ram.Type-0 2020-09-03 01:30:51 +09:00
parent 833b40e59f
commit e2728560c3
1 changed files with 95 additions and 205 deletions

View File

@ -17,235 +17,125 @@ namespace Cysharp.Threading.Tasks
{
var e = (IEnumerator)enumerator;
Error.ThrowArgumentNullException(e, nameof(enumerator));
return new UniTask(EnumeratorPromise.Create(e, PlayerLoopTiming.Update, CancellationToken.None, out var token), token).GetAwaiter();
return StartCoroutineAsUniTask(enumerator).GetAwaiter();
}
public static UniTask WithCancellation(this IEnumerator enumerator, CancellationToken cancellationToken)
{
Error.ThrowArgumentNullException(enumerator, nameof(enumerator));
return new UniTask(EnumeratorPromise.Create(enumerator, PlayerLoopTiming.Update, cancellationToken, out var token), token);
return StartCoroutineAsUniTask(enumerator, cancellationToken: cancellationToken);
}
public static UniTask ToUniTask(this IEnumerator enumerator, PlayerLoopTiming timing = PlayerLoopTiming.Update, CancellationToken cancellationToken = default(CancellationToken))
public static UniTask ToUniTask(this IEnumerator enumerator, PlayerLoopTiming? timing = null, CancellationToken cancellationToken = default(CancellationToken))
{
Error.ThrowArgumentNullException(enumerator, nameof(enumerator));
return new UniTask(EnumeratorPromise.Create(enumerator, timing, cancellationToken, out var token), token);
return StartCoroutineAsUniTask(enumerator, timing, cancellationToken);
}
sealed class EnumeratorPromise : IUniTaskSource, IPlayerLoopItem, ITaskPoolNode<EnumeratorPromise>
static async UniTask StartCoroutineAsUniTask(IEnumerator enumerator, PlayerLoopTiming? timing = default, CancellationToken cancellationToken = default)
{
static TaskPool<EnumeratorPromise> pool;
public EnumeratorPromise NextNode { get; set; }
static EnumeratorPromise()
{
TaskPool.RegisterSizeGetter(typeof(EnumeratorPromise), () => pool.Size);
}
IEnumerator innerEnumerator;
CancellationToken cancellationToken;
int initialFrame;
UniTaskCompletionSourceCore<object> core;
EnumeratorPromise()
{
}
public static IUniTaskSource Create(IEnumerator innerEnumerator, PlayerLoopTiming timing, CancellationToken cancellationToken, out short token)
{
if (cancellationToken.IsCancellationRequested)
{
return AutoResetUniTaskCompletionSource.CreateFromCanceled(cancellationToken, out token);
}
if (!pool.TryPop(out var result))
{
result = new EnumeratorPromise();
}
TaskTracker.TrackActiveTask(result, 3);
result.innerEnumerator = ConsumeEnumerator(innerEnumerator);
result.cancellationToken = cancellationToken;
result.initialFrame = -1;
PlayerLoopHelper.AddAction(timing, result);
token = result.core.Version;
result.MoveNext(); // run immediately.
return result;
}
public void GetResult(short token)
{
try
{
core.GetResult(token);
}
finally
{
TryReturn();
}
}
public UniTaskStatus GetStatus(short token)
{
return core.GetStatus(token);
}
public UniTaskStatus UnsafeGetStatus()
{
return core.UnsafeGetStatus();
}
public void OnCompleted(Action<object> continuation, object state, short token)
{
core.OnCompleted(continuation, state, token);
}
public bool MoveNext()
{
if (cancellationToken.IsCancellationRequested)
{
core.TrySetCanceled(cancellationToken);
return false;
}
if (initialFrame == -1)
{
// Time can not touch in threadpool.
if (PlayerLoopHelper.IsMainThread)
{
initialFrame = Time.frameCount;
}
}
else if (initialFrame == Time.frameCount)
{
return true; // already executed in first frame, skip.
}
try
{
if (innerEnumerator.MoveNext())
{
return true;
}
}
catch (Exception ex)
{
core.TrySetException(ex);
return false;
}
core.TrySetResult(null);
return false;
}
bool TryReturn()
{
TaskTracker.RemoveTracking(this);
core.Reset();
innerEnumerator = default;
cancellationToken = default;
return pool.TryPush(this);
}
// Unwrap YieldInstructions
static IEnumerator ConsumeEnumerator(IEnumerator enumerator)
Exception exception = null;
try
{
cancellationToken.ThrowIfCancellationRequested();
while (enumerator.MoveNext())
{
cancellationToken.ThrowIfCancellationRequested();
var current = enumerator.Current;
if (current == null)
switch (current)
{
yield return null;
}
else if (current is CustomYieldInstruction)
{
// WWW, WaitForSecondsRealtime
var e2 = UnwrapWaitCustomYieldInstruction((CustomYieldInstruction)current);
while (e2.MoveNext())
{
yield return null;
}
}
else if (current is YieldInstruction)
{
IEnumerator innerCoroutine = null;
switch (current)
{
case AsyncOperation ao:
innerCoroutine = UnwrapWaitAsyncOperation(ao);
break;
case WaitForSeconds wfs:
innerCoroutine = UnwrapWaitForSeconds(wfs);
break;
}
if (innerCoroutine != null)
{
while (innerCoroutine.MoveNext())
case null:
default:
{
yield return null;
await UniTask.Yield(PlayerLoopTiming.LastUpdate);
break;
}
case WaitForFixedUpdate waitForFixedUpdate:
{
await UniTask.Yield(PlayerLoopTiming.LastFixedUpdate);
break;
}
case WaitForEndOfFrame waitForEndOfFrame:
{
await UniTask.WaitForEndOfFrame();
break;
}
case WaitForSeconds waitForSeconds:
{
var second = (float)waitForSeconds_Seconds.GetValue(waitForSeconds);
var elapsed = 0.0f;
do
{
await UniTask.Yield(PlayerLoopTiming.LastUpdate);
cancellationToken.ThrowIfCancellationRequested();
elapsed += Time.deltaTime;
} while (elapsed < second);
break;
}
case CustomYieldInstruction cyi: // Include WWW, WaitForSecondsRealtime
{
while (cyi.keepWaiting)
{
await UniTask.Yield(PlayerLoopTiming.LastUpdate);
cancellationToken.ThrowIfCancellationRequested();
}
break;
}
case IEnumerator innerEnumerator:
{
await StartCoroutineAsUniTask(innerEnumerator, null, cancellationToken);
break;
}
case AsyncOperation ao:
{
await ao;
break;
}
case YieldInstruction yieldInstruction:
{
bool isKnownIssue;
switch (yieldInstruction)
{
case Coroutine coroutine:
{
isKnownIssue = true;
break;
}
default:
{
isKnownIssue = false;
break;
}
}
throw new NotSupportedException("Coroutine yields YieldInstruction of type \"" + yieldInstruction.GetType().FullName + (isKnownIssue ? "\". Which is not supported by UniTask." : "\". Which is unknown by UniTask."));
}
}
else
{
yield return null;
}
}
else if (current is IEnumerator e3)
{
var e4 = ConsumeEnumerator(e3);
while (e4.MoveNext())
{
yield return null;
}
}
else
{
// WaitForEndOfFrame, WaitForFixedUpdate, others.
yield return null;
}
cancellationToken.ThrowIfCancellationRequested();
}
}
catch(Exception e)
{
exception = e;
throw;
}
finally
{
if (timing is PlayerLoopTiming playerLoopTiming && PlayerLoopHelper.TryGetCurrentPlayerLoopTiming() != playerLoopTiming) // To avoid unintentional one frame of extra latency
{
await UniTask.Yield(playerLoopTiming);
}
if(exception is null)
{
cancellationToken.ThrowIfCancellationRequested();
}
}
// WWW and others as CustomYieldInstruction.
static IEnumerator UnwrapWaitCustomYieldInstruction(CustomYieldInstruction yieldInstruction)
{
while (yieldInstruction.keepWaiting)
{
yield return null;
}
}
static readonly FieldInfo waitForSeconds_Seconds = typeof(WaitForSeconds).GetField("m_Seconds", BindingFlags.Instance | BindingFlags.GetField | BindingFlags.NonPublic);
static IEnumerator UnwrapWaitForSeconds(WaitForSeconds waitForSeconds)
{
var second = (float)waitForSeconds_Seconds.GetValue(waitForSeconds);
var elapsed = 0.0f;
while (true)
{
yield return null;
elapsed += Time.deltaTime;
if (elapsed >= second)
{
break;
}
};
}
static IEnumerator UnwrapWaitAsyncOperation(AsyncOperation asyncOperation)
{
while (!asyncOperation.isDone)
{
yield return null;
}
}
}
static readonly FieldInfo waitForSeconds_Seconds = typeof(WaitForSeconds).GetField("m_Seconds", BindingFlags.Instance | BindingFlags.GetField | BindingFlags.NonPublic);
}
}