diff --git a/src/UniTask.NetCore/Linq/SkipLast.cs b/src/UniTask.NetCore/Linq/SkipLast.cs index 665e15d..345ec3e 100644 --- a/src/UniTask.NetCore/Linq/SkipLast.cs +++ b/src/UniTask.NetCore/Linq/SkipLast.cs @@ -138,6 +138,10 @@ namespace Cysharp.Threading.Tasks.Linq self.completionSource.TrySetResult(false); } } + else + { + self.continueNext = false; + } } public UniTask DisposeAsync() diff --git a/src/UniTask.NetCore/Linq/TakeLast.cs b/src/UniTask.NetCore/Linq/TakeLast.cs index 4b4d955..d770c3b 100644 --- a/src/UniTask.NetCore/Linq/TakeLast.cs +++ b/src/UniTask.NetCore/Linq/TakeLast.cs @@ -1,775 +1,173 @@ -namespace Cysharp.Threading.Tasks.Linq +using Cysharp.Threading.Tasks.Internal; +using System; +using System.Collections.Generic; +using System.Threading; + +namespace Cysharp.Threading.Tasks.Linq { - internal sealed class TakeLast + public static partial class UniTaskAsyncEnumerable { + public static IUniTaskAsyncEnumerable TakeLast(this IUniTaskAsyncEnumerable source, Int32 count) + { + Error.ThrowArgumentNullException(source, nameof(source)); + + // non take. + if (count <= 0) + { + return Empty(); + } + + return new TakeLast(source, count); + } } - -} - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + internal sealed class TakeLast : IUniTaskAsyncEnumerable + { + readonly IUniTaskAsyncEnumerable source; + readonly int count; + + public TakeLast(IUniTaskAsyncEnumerable source, int count) + { + this.source = source; + this.count = count; + } + + public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new Enumerator(source, count, cancellationToken); + } + + sealed class Enumerator : MoveNextSource, IUniTaskAsyncEnumerator + { + static readonly Action MoveNextCoreDelegate = MoveNextCore; + + readonly IUniTaskAsyncEnumerable source; + readonly int count; + CancellationToken cancellationToken; + + IUniTaskAsyncEnumerator enumerator; + UniTask.Awaiter awaiter; + Queue queue; + + bool iterateCompleted; + bool continueNext; + + public Enumerator(IUniTaskAsyncEnumerable source, int count, CancellationToken cancellationToken) + { + this.source = source; + this.count = count; + this.cancellationToken = cancellationToken; + } + + public TSource Current { get; private set; } + + public UniTask MoveNextAsync() + { + cancellationToken.ThrowIfCancellationRequested(); + + if (enumerator == null) + { + enumerator = source.GetAsyncEnumerator(cancellationToken); + queue = new Queue(); + } + + completionSource.Reset(); + SourceMoveNext(); + return new UniTask(this, completionSource.Version); + } + + void SourceMoveNext() + { + if (iterateCompleted) + { + if (queue.Count > 0) + { + Current = queue.Dequeue(); + completionSource.TrySetResult(true); + } + else + { + completionSource.TrySetResult(false); + } + + return; + } + + try + { + LOOP: + awaiter = enumerator.MoveNextAsync().GetAwaiter(); + if (awaiter.IsCompleted) + { + continueNext = true; + MoveNextCore(this); + if (continueNext) + { + continueNext = false; + goto LOOP; // avoid recursive + } + } + else + { + awaiter.SourceOnCompleted(MoveNextCoreDelegate, this); + } + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + } + } + + + static void MoveNextCore(object state) + { + var self = (Enumerator)state; + + if (self.TryGetResult(self.awaiter, out var result)) + { + if (result) + { + if (self.queue.Count < self.count) + { + self.queue.Enqueue(self.enumerator.Current); + + if (!self.continueNext) + { + self.SourceMoveNext(); + } + } + else + { + self.queue.Dequeue(); + self.queue.Enqueue(self.enumerator.Current); + + if (!self.continueNext) + { + self.SourceMoveNext(); + } + } + } + else + { + self.continueNext = false; + self.iterateCompleted = true; + self.SourceMoveNext(); + } + } + else + { + self.continueNext = false; + } + } + + public UniTask DisposeAsync() + { + if (enumerator != null) + { + return enumerator.DisposeAsync(); + } + return default; + } + } + } +} \ No newline at end of file diff --git a/src/UniTask.NetCoreSandbox/Program.cs b/src/UniTask.NetCoreSandbox/Program.cs index dc26574..52a62d1 100644 --- a/src/UniTask.NetCoreSandbox/Program.cs +++ b/src/UniTask.NetCoreSandbox/Program.cs @@ -41,7 +41,7 @@ namespace NetCoreSandbox { await foreach (var item in UniTaskAsyncEnumerable.Range(1, 10) .SelectAwait(x => UniTask.Run(() => x)) - .SkipLast(6) + .TakeLast(6) diff --git a/src/UniTask.NetCoreTests/Linq/Paging.cs b/src/UniTask.NetCoreTests/Linq/Paging.cs index aef0600..cfea04b 100644 --- a/src/UniTask.NetCoreTests/Linq/Paging.cs +++ b/src/UniTask.NetCoreTests/Linq/Paging.cs @@ -64,6 +64,31 @@ namespace NetCoreTests.Linq await Assert.ThrowsAsync(async () => await xs); } } + [Theory] + [InlineData(0, 0)] + [InlineData(0, 1)] + [InlineData(9, 0)] + [InlineData(9, 1)] + [InlineData(9, 5)] + [InlineData(9, 9)] + [InlineData(9, 15)] + public async Task TakeLast(int collection, int takeCount) + { + var xs = await UniTaskAsyncEnumerable.Range(1, collection).TakeLast(takeCount).ToArrayAsync(); + var ys = Enumerable.Range(1, collection).TakeLast(takeCount).ToArray(); + + xs.Should().BeEquivalentTo(ys); + } + + [Fact] + public async Task TakeLastException() + { + foreach (var item in UniTaskTestException.Throws()) + { + var xs = item.TakeLast(5).ToArrayAsync(); + await Assert.ThrowsAsync(async () => await xs); + } + } [Theory] [InlineData(0, 0)]