From 0640f278cc5f54912196c5527a540beb1ff5a963 Mon Sep 17 00:00:00 2001 From: neuecc Date: Thu, 18 Jun 2020 03:02:01 +0900 Subject: [PATCH] Fix AsyncLazy can not await multiple times when task is not completed --- src/UniTask.NetCoreTests/AsyncLazyTest.cs | 167 +++++++++++++++ .../Plugins/UniTask/Runtime/AsyncLazy.cs | 194 +++++++++++++++--- .../Runtime/UniTaskCompletionSource.cs | 2 + .../UniTask/Runtime/UniTaskExtensions.cs | 4 +- 4 files changed, 331 insertions(+), 36 deletions(-) create mode 100644 src/UniTask.NetCoreTests/AsyncLazyTest.cs diff --git a/src/UniTask.NetCoreTests/AsyncLazyTest.cs b/src/UniTask.NetCoreTests/AsyncLazyTest.cs new file mode 100644 index 0000000..c517d72 --- /dev/null +++ b/src/UniTask.NetCoreTests/AsyncLazyTest.cs @@ -0,0 +1,167 @@ +using Cysharp.Threading.Tasks; +using FluentAssertions; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Channels; +using Cysharp.Threading.Tasks.Linq; +using System.Threading.Tasks; +using Xunit; + +namespace NetCoreTests +{ + public class AsyncLazyTest + { + [Fact] + public async Task LazyLazy() + { + { + var l = UniTask.Lazy(() => After()); + var a = AwaitAwait(l.Task); + var b = AwaitAwait(l.Task); + var c = AwaitAwait(l.Task); + + await a; + await b; + await c; + } + { + var l = UniTask.Lazy(() => AfterException()); + var a = AwaitAwait(l.Task); + var b = AwaitAwait(l.Task); + var c = AwaitAwait(l.Task); + + await Assert.ThrowsAsync(async () => await a); + await Assert.ThrowsAsync(async () => await b); + await Assert.ThrowsAsync(async () => await c); + } + } + + [Fact] + public async Task LazyImmediate() + { + { + var l = UniTask.Lazy(() => UniTask.FromResult(1).AsUniTask()); + var a = AwaitAwait(l.Task); + var b = AwaitAwait(l.Task); + var c = AwaitAwait(l.Task); + + await a; + await b; + await c; + } + { + var l = UniTask.Lazy(() => UniTask.FromException(new TaskTestException())); + var a = AwaitAwait(l.Task); + var b = AwaitAwait(l.Task); + var c = AwaitAwait(l.Task); + + await Assert.ThrowsAsync(async () => await a); + await Assert.ThrowsAsync(async () => await b); + await Assert.ThrowsAsync(async () => await c); + } + } + + static async UniTask AwaitAwait(UniTask t) + { + await t; + } + + + async UniTask After() + { + await UniTask.Yield(); + Thread.Sleep(TimeSpan.FromSeconds(1)); + await UniTask.Yield(); + await UniTask.Yield(); + } + + async UniTask AfterException() + { + await UniTask.Yield(); + Thread.Sleep(TimeSpan.FromSeconds(1)); + await UniTask.Yield(); + throw new TaskTestException(); + } + } + + public class AsyncLazyTest2 + { + [Fact] + public async Task LazyLazy() + { + { + var l = UniTask.Lazy(() => After()); + var a = AwaitAwait(l.Task); + var b = AwaitAwait(l.Task); + var c = AwaitAwait(l.Task); + + var a2 = await a; + var b2 = await b; + var c2 = await c; + (a2, b2, c2).Should().Be((10, 10, 10)); + } + { + var l = UniTask.Lazy(() => AfterException()); + var a = AwaitAwait(l.Task); + var b = AwaitAwait(l.Task); + var c = AwaitAwait(l.Task); + + await Assert.ThrowsAsync(async () => await a); + await Assert.ThrowsAsync(async () => await b); + await Assert.ThrowsAsync(async () => await c); + } + } + + [Fact] + public async Task LazyImmediate() + { + { + var l = UniTask.Lazy(() => UniTask.FromResult(1)); + var a = AwaitAwait(l.Task); + var b = AwaitAwait(l.Task); + var c = AwaitAwait(l.Task); + + var a2 = await a; + var b2 = await b; + var c2 = await c; + (a2, b2, c2).Should().Be((1, 1, 1)); + } + { + var l = UniTask.Lazy(() => UniTask.FromException(new TaskTestException())); + var a = AwaitAwait(l.Task); + var b = AwaitAwait(l.Task); + var c = AwaitAwait(l.Task); + + await Assert.ThrowsAsync(async () => await a); + await Assert.ThrowsAsync(async () => await b); + await Assert.ThrowsAsync(async () => await c); + } + } + + static async UniTask AwaitAwait(UniTask t) + { + return await t; + } + + + async UniTask After() + { + await UniTask.Yield(); + Thread.Sleep(TimeSpan.FromSeconds(1)); + await UniTask.Yield(); + await UniTask.Yield(); + return 10; + } + + async UniTask AfterException() + { + await UniTask.Yield(); + Thread.Sleep(TimeSpan.FromSeconds(1)); + await UniTask.Yield(); + throw new TaskTestException(); + } + } +} diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/AsyncLazy.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/AsyncLazy.cs index 3e66a1f..51bfadc 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/AsyncLazy.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/AsyncLazy.cs @@ -7,113 +7,239 @@ namespace Cysharp.Threading.Tasks { public class AsyncLazy { - Func valueFactory; - UniTask target; + static Action continuation = SetCompletionSource; + + Func taskFactory; + UniTaskCompletionSource completionSource; + UniTask.Awaiter awaiter; + object syncLock; bool initialized; - public AsyncLazy(Func valueFactory) + public AsyncLazy(Func taskFactory) { - this.valueFactory = valueFactory; - this.target = default; + this.taskFactory = taskFactory; + this.completionSource = new UniTaskCompletionSource(); this.syncLock = new object(); this.initialized = false; } - internal AsyncLazy(UniTask value) + internal AsyncLazy(UniTask task) { - this.valueFactory = null; - this.target = value; + this.taskFactory = null; + this.completionSource = new UniTaskCompletionSource(); this.syncLock = null; this.initialized = true; + + var awaiter = task.GetAwaiter(); + if (awaiter.IsCompleted) + { + SetCompletionSource(awaiter); + } + else + { + this.awaiter = awaiter; + awaiter.SourceOnCompleted(continuation, this); + } } - public UniTask Task => EnsureInitialized(); + public UniTask Task + { + get + { + EnsureInitialized(); + return completionSource.Task; + } + } - public UniTask.Awaiter GetAwaiter() => EnsureInitialized().GetAwaiter(); - UniTask EnsureInitialized() + public UniTask.Awaiter GetAwaiter() => Task.GetAwaiter(); + + void EnsureInitialized() { if (Volatile.Read(ref initialized)) { - return target; + return; } - return EnsureInitializedCore(); + EnsureInitializedCore(); } - UniTask EnsureInitializedCore() + void EnsureInitializedCore() { lock (syncLock) { if (!Volatile.Read(ref initialized)) { - var f = Interlocked.Exchange(ref valueFactory, null); + var f = Interlocked.Exchange(ref taskFactory, null); if (f != null) { - target = f().Preserve(); // with preserve(allow multiple await). + var task = f(); + var awaiter = task.GetAwaiter(); + if (awaiter.IsCompleted) + { + SetCompletionSource(awaiter); + } + else + { + this.awaiter = awaiter; + awaiter.SourceOnCompleted(continuation, this); + } + Volatile.Write(ref initialized, true); } } } + } - return target; + void SetCompletionSource(in UniTask.Awaiter awaiter) + { + try + { + awaiter.GetResult(); + completionSource.TrySetResult(); + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + } + } + + static void SetCompletionSource(object state) + { + var self = (AsyncLazy)state; + try + { + self.awaiter.GetResult(); + self.completionSource.TrySetResult(); + } + catch (Exception ex) + { + self.completionSource.TrySetException(ex); + } + finally + { + self.awaiter = default; + } } } public class AsyncLazy { - Func> valueFactory; - UniTask target; + static Action continuation = SetCompletionSource; + + Func> taskFactory; + UniTaskCompletionSource completionSource; + UniTask.Awaiter awaiter; + object syncLock; bool initialized; - public AsyncLazy(Func> valueFactory) + public AsyncLazy(Func> taskFactory) { - this.valueFactory = valueFactory; - this.target = default; + this.taskFactory = taskFactory; + this.completionSource = new UniTaskCompletionSource(); this.syncLock = new object(); this.initialized = false; } - internal AsyncLazy(UniTask value) + internal AsyncLazy(UniTask task) { - this.valueFactory = null; - this.target = value; + this.taskFactory = null; + this.completionSource = new UniTaskCompletionSource(); this.syncLock = null; this.initialized = true; + + var awaiter = task.GetAwaiter(); + if (awaiter.IsCompleted) + { + SetCompletionSource(awaiter); + } + else + { + this.awaiter = awaiter; + awaiter.SourceOnCompleted(continuation, this); + } } - public UniTask Task => EnsureInitialized(); + public UniTask Task + { + get + { + EnsureInitialized(); + return completionSource.Task; + } + } - public UniTask.Awaiter GetAwaiter() => EnsureInitialized().GetAwaiter(); - UniTask EnsureInitialized() + public UniTask.Awaiter GetAwaiter() => Task.GetAwaiter(); + + void EnsureInitialized() { if (Volatile.Read(ref initialized)) { - return target; + return; } - return EnsureInitializedCore(); + EnsureInitializedCore(); } - UniTask EnsureInitializedCore() + void EnsureInitializedCore() { lock (syncLock) { if (!Volatile.Read(ref initialized)) { - var f = Interlocked.Exchange(ref valueFactory, null); + var f = Interlocked.Exchange(ref taskFactory, null); if (f != null) { - target = f().Preserve(); // with preserve(allow multiple await). + var task = f(); + var awaiter = task.GetAwaiter(); + if (awaiter.IsCompleted) + { + SetCompletionSource(awaiter); + } + else + { + this.awaiter = awaiter; + awaiter.SourceOnCompleted(continuation, this); + } + Volatile.Write(ref initialized, true); } } } + } - return target; + void SetCompletionSource(in UniTask.Awaiter awaiter) + { + try + { + var result = awaiter.GetResult(); + completionSource.TrySetResult(result); + } + catch (Exception ex) + { + completionSource.TrySetException(ex); + } + } + + static void SetCompletionSource(object state) + { + var self = (AsyncLazy)state; + try + { + var result = self.awaiter.GetResult(); + self.completionSource.TrySetResult(result); + } + catch (Exception ex) + { + self.completionSource.TrySetException(ex); + } + finally + { + self.awaiter = default; + } } } } diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTaskCompletionSource.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTaskCompletionSource.cs index acc3000..70f804d 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTaskCompletionSource.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTaskCompletionSource.cs @@ -696,6 +696,7 @@ namespace Cysharp.Threading.Tasks } } + [DebuggerHidden] bool TrySignalCompletion(UniTaskStatus status) { if (Interlocked.CompareExchange(ref intStatus, (int)status, (int)UniTaskStatus.Pending) == (int)UniTaskStatus.Pending) @@ -886,6 +887,7 @@ namespace Cysharp.Threading.Tasks } } + [DebuggerHidden] bool TrySignalCompletion(UniTaskStatus status) { if (Interlocked.CompareExchange(ref intStatus, (int)status, (int)UniTaskStatus.Pending) == (int)UniTaskStatus.Pending) diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTaskExtensions.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTaskExtensions.cs index f0cf2ac..7833b3e 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTaskExtensions.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTaskExtensions.cs @@ -181,12 +181,12 @@ namespace Cysharp.Threading.Tasks public static AsyncLazy ToAsyncLazy(this UniTask task) { - return new AsyncLazy(task.Preserve()); // require Preserve + return new AsyncLazy(task); } public static AsyncLazy ToAsyncLazy(this UniTask task) { - return new AsyncLazy(task.Preserve()); // require Preserve + return new AsyncLazy(task); } #if UNITY_2018_3_OR_NEWER