diff --git a/src/UniTask.NetCore/Linq/DefaultIfEmpty.cs b/src/UniTask.NetCore/Linq/DefaultIfEmpty.cs index 756fd5a..e8d3662 100644 --- a/src/UniTask.NetCore/Linq/DefaultIfEmpty.cs +++ b/src/UniTask.NetCore/Linq/DefaultIfEmpty.cs @@ -73,11 +73,16 @@ namespace Cysharp.Threading.Tasks.Linq cancellationToken.ThrowIfCancellationRequested(); completionSource.Reset(); - if (iteratingState == IteratingState.Empty) + if (iteratingState == IteratingState.Completed) { return CompletedTasks.False; } + if (enumerator == null) + { + enumerator = source.GetAsyncEnumerator(cancellationToken); + } + awaiter = enumerator.MoveNextAsync().GetAwaiter(); if (awaiter.IsCompleted) diff --git a/src/UniTask.NetCoreTests/Linq/Concat.cs b/src/UniTask.NetCoreTests/Linq/Concat.cs index 2d6cdfa..5b1a3d2 100644 --- a/src/UniTask.NetCoreTests/Linq/Concat.cs +++ b/src/UniTask.NetCoreTests/Linq/Concat.cs @@ -112,5 +112,33 @@ namespace NetCoreTests.Linq await Assert.ThrowsAsync(async () => await zs); } } + + [Fact] + public async Task DefaultIfEmpty() + { + { + var xs = await Enumerable.Range(1, 0).ToUniTaskAsyncEnumerable().DefaultIfEmpty(99).ToArrayAsync(); + var ys = Enumerable.Range(1, 0).DefaultIfEmpty(99).ToArray(); + xs.Should().BeEquivalentTo(ys); + } + { + var xs = await Enumerable.Range(1, 1).ToUniTaskAsyncEnumerable().DefaultIfEmpty(99).ToArrayAsync(); + var ys = Enumerable.Range(1, 1).DefaultIfEmpty(99).ToArray(); + xs.Should().BeEquivalentTo(ys); + } + { + var xs = await Enumerable.Range(1, 10).ToUniTaskAsyncEnumerable().DefaultIfEmpty(99).ToArrayAsync(); + var ys = Enumerable.Range(1, 10).DefaultIfEmpty(99).ToArray(); + xs.Should().BeEquivalentTo(ys); + } + // Throw + { + foreach (var item in UniTaskTestException.Throws()) + { + var xs = item.DefaultIfEmpty().ToArrayAsync(); + await Assert.ThrowsAsync(async () => await xs); + } + } + } } } diff --git a/src/UniTask.NetCoreTests/Linq/_Exception.cs b/src/UniTask.NetCoreTests/Linq/_Exception.cs index 1e977cc..e28c513 100644 --- a/src/UniTask.NetCoreTests/Linq/_Exception.cs +++ b/src/UniTask.NetCoreTests/Linq/_Exception.cs @@ -23,14 +23,14 @@ namespace NetCoreTests.Linq } - public static IEnumerable> Throws() + public static IEnumerable> Throws(int count = 3) { yield return ThrowImmediate(); yield return ThrowAfter(); yield return ThrowInMoveNext(); - yield return UniTaskAsyncEnumerable.Range(1, 3).Concat(ThrowImmediate()); - yield return UniTaskAsyncEnumerable.Range(1, 3).Concat(ThrowAfter()); - yield return UniTaskAsyncEnumerable.Range(1, 3).Concat(ThrowInMoveNext()); + yield return UniTaskAsyncEnumerable.Range(1, count).Concat(ThrowImmediate()); + yield return UniTaskAsyncEnumerable.Range(1, count).Concat(ThrowAfter()); + yield return UniTaskAsyncEnumerable.Range(1, count).Concat(ThrowInMoveNext()); } }