From 7b273c4bd1b7af5e41bf727e2ad246deb851d733 Mon Sep 17 00:00:00 2001
From: neuecc <ils@neue.cc>
Date: Tue, 19 May 2020 03:10:37 +0900
Subject: [PATCH] Add UniTask.Defer

---
 src/UniTask.NetCoreTests/DeferTest.cs         | 47 +++++++++
 .../UniTask/Runtime/UniTask.Factory.cs        | 95 +++++++++++++++++++
 2 files changed, 142 insertions(+)
 create mode 100644 src/UniTask.NetCoreTests/DeferTest.cs

diff --git a/src/UniTask.NetCoreTests/DeferTest.cs b/src/UniTask.NetCoreTests/DeferTest.cs
new file mode 100644
index 0000000..d2578e3
--- /dev/null
+++ b/src/UniTask.NetCoreTests/DeferTest.cs
@@ -0,0 +1,47 @@
+using Cysharp.Threading.Tasks;
+using FluentAssertions;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading;
+using System.Threading.Channels;
+using Cysharp.Threading.Tasks.Linq;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace NetCoreTests
+{
+    public class DeferTest
+    {
+        [Fact]
+        public async Task D()
+        {
+            var created = false;
+            var v = UniTask.Defer(() => { created = true; return UniTask.Run(() => 10); });
+
+            created.Should().BeFalse();
+
+            var t = await v;
+
+            created.Should().BeTrue();
+
+            t.Should().Be(10);
+        }
+
+        [Fact]
+        public async Task D2()
+        {
+            var created = false;
+            var v = UniTask.Defer(() => { created = true; return UniTask.Run(() => 10).AsUniTask(); });
+
+            created.Should().BeFalse();
+
+            await v;
+
+            created.Should().BeTrue();
+        }
+    }
+
+
+}
diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Factory.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Factory.cs
index 214db62..75e3cea 100644
--- a/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Factory.cs
+++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.Factory.cs
@@ -132,6 +132,101 @@ namespace Cysharp.Threading.Tasks
         {
             asyncAction(state).Forget();
         }
+
+        /// <summary>
+        /// Defer the task creation just before call await.
+        /// </summary>
+        public static UniTask Defer(Func<UniTask> factory)
+        {
+            return new UniTask(new DeferPromise(factory), 0);
+        }
+
+        /// <summary>
+        /// Defer the task creation just before call await.
+        /// </summary>
+        public static UniTask<T> Defer<T>(Func<UniTask<T>> factory)
+        {
+            return new UniTask<T>(new DeferPromise<T>(factory), 0);
+        }
+
+        sealed class DeferPromise : IUniTaskSource
+        {
+            Func<UniTask> factory;
+            UniTask task;
+            UniTask.Awaiter awaiter;
+
+            public DeferPromise(Func<UniTask> factory)
+            {
+                this.factory = factory;
+            }
+
+            public void GetResult(short token)
+            {
+                awaiter.GetResult();
+            }
+
+            public UniTaskStatus GetStatus(short token)
+            {
+                var f = Interlocked.Exchange(ref factory, null);
+                if (f == null) throw new InvalidOperationException("Can't call twice.");
+
+                task = f();
+                awaiter = f().GetAwaiter();
+                return task.Status;
+            }
+
+            public void OnCompleted(Action<object> continuation, object state, short token)
+            {
+                awaiter.SourceOnCompleted(continuation, state);
+            }
+
+            public UniTaskStatus UnsafeGetStatus()
+            {
+                return task.Status;
+            }
+        }
+
+        sealed class DeferPromise<T> : IUniTaskSource<T>
+        {
+            Func<UniTask<T>> factory;
+            UniTask<T> task;
+            UniTask<T>.Awaiter awaiter;
+
+            public DeferPromise(Func<UniTask<T>> factory)
+            {
+                this.factory = factory;
+            }
+
+            public T GetResult(short token)
+            {
+                return awaiter.GetResult();
+            }
+
+            void IUniTaskSource.GetResult(short token)
+            {
+                awaiter.GetResult();
+            }
+
+            public UniTaskStatus GetStatus(short token)
+            {
+                var f = Interlocked.Exchange(ref factory, null);
+                if (f == null) throw new InvalidOperationException("Can't call twice.");
+
+                task = f();
+                awaiter = f().GetAwaiter();
+                return task.Status;
+            }
+
+            public void OnCompleted(Action<object> continuation, object state, short token)
+            {
+                awaiter.SourceOnCompleted(continuation, state);
+            }
+
+            public UniTaskStatus UnsafeGetStatus()
+            {
+                return task.Status;
+            }
+        }
     }
 
     internal static class CompletedTasks