diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Take.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Take.cs index 13388d8..be24ca7 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Take.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Take.cs @@ -30,40 +30,93 @@ namespace Cysharp.Threading.Tasks.Linq return new _Take(source, count, cancellationToken); } - sealed class _Take : AsyncEnumeratorBase + sealed class _Take : MoveNextSource, IUniTaskAsyncEnumerator { - readonly int count; + static readonly Action MoveNextCoreDelegate = MoveNextCore; + readonly IUniTaskAsyncEnumerable source; + readonly int count; + CancellationToken cancellationToken; + + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; int index; public _Take(IUniTaskAsyncEnumerable source, int count, CancellationToken cancellationToken) - : base(source, cancellationToken) { + this.source = source; this.count = count; + this.cancellationToken = cancellationToken; } - protected override bool TryMoveNextCore(bool sourceHasCurrent, out bool result) + public TSource Current { get; private set; } + + public UniTask MoveNextAsync() { - if (sourceHasCurrent) + cancellationToken.ThrowIfCancellationRequested(); + + if (enumerator == null) { - if (checked(index++) < count) + enumerator = source.GetAsyncEnumerator(cancellationToken); + } + + if (checked(index) >= count) + { + return CompletedTasks.False; + } + + completionSource.Reset(); + SourceMoveNext(); + return new UniTask(this, completionSource.Version); + } + + void SourceMoveNext() + { + try + { + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) { - Current = SourceCurrent; - result = true; - return true; + MoveNextCore(this); } else { - result = false; - return true; + awaiter.SourceOnCompleted(MoveNextCoreDelegate, this); } } - else + catch (Exception ex) { - result = false; - return true; + completionSource.TrySetException(ex); } } + + static void MoveNextCore(object state) + { + var self = (_Take)state; + + if (self.TryGetResult(self.awaiter, out var result)) + { + if (result) + { + self.index++; + self.Current = self.enumerator.Current; + self.completionSource.TrySetResult(true); + } + else + { + self.completionSource.TrySetResult(false); + } + } + } + + public UniTask DisposeAsync() + { + if (enumerator != null) + { + return enumerator.DisposeAsync(); + } + return default; + } } } } \ No newline at end of file