using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

namespace ZeroLevel.Services.Async.Internal
{
    /// <summary>
    /// An <see cref="IAwaiterFactory{T}"/> whose awaiters are backed by a
    /// <see cref="TaskCompletionSource{TResult}"/> and cancelled by a single
    /// <see cref="CancellationToken"/>.
    /// </summary>
    /// <typeparam name="T">Type of the value delivered through each awaiter.</typeparam>
    /// <remarks>
    /// Two factories compare equal exactly when they wrap equal cancellation tokens.
    /// </remarks>
    internal struct CompletionSourceAwaiterFactory<T> : IAwaiterFactory<T>, IEquatable<CompletionSourceAwaiterFactory<T>>
    {
        private readonly CancellationToken _cancellationToken;

        /// <summary>
        /// Creates a factory whose awaiters are cancelled when <paramref name="cancellationToken"/> is cancelled.
        /// </summary>
        /// <param name="cancellationToken">Token registered by every awaiter this factory creates.</param>
        public CompletionSourceAwaiterFactory(CancellationToken cancellationToken)
        {
            _cancellationToken = cancellationToken;
        }

        /// <summary>
        /// Creates a new awaiter bound to this factory's cancellation token.
        /// </summary>
        public IAwaiter<T> CreateAwaiter() => new CompletionSourceAwaiter(_cancellationToken);

        /// <summary>
        /// A simple wrapper that implements <see cref="IAwaiter{T}"/> on top of a
        /// <see cref="TaskCompletionSource{TResult}"/>.
        /// </summary>
        private sealed class CompletionSourceAwaiter : IAwaiter<T>
        {
            private readonly TaskCompletionSource<T> _completionSource;
            private readonly CancellationToken _cancellationToken;
            private readonly CancellationTokenRegistration _registration;

            public CompletionSourceAwaiter(CancellationToken cancellationToken)
            {
                _completionSource = new TaskCompletionSource<T>();
                _cancellationToken = cancellationToken;
                Task = new ValueTask<T>(_completionSource.Task.WithYield());

                // `this` is passed as the callback state so the lambda captures nothing.
                // Register invokes the callback synchronously when the token is already
                // cancelled, so every field OnCanceled reads must be assigned before this call.
                _registration = cancellationToken.Register(
                    state => ((CompletionSourceAwaiter)state).OnCanceled(),
                    this,
                    useSynchronizationContext: false);
            }

            /// <summary>
            /// Unregisters the cancellation callback and attempts to complete the task with <paramref name="result"/>.
            /// </summary>
            /// <param name="result">Value to deliver to whoever awaits <see cref="Task"/>.</param>
            /// <returns>
            /// <c>true</c> if the task was completed by this call; <c>false</c> if it had already
            /// completed (for example because the token was cancelled first).
            /// </returns>
            public bool TrySetResult(T result)
            {
                _registration.Dispose();
                return _completionSource.TrySetResult(result);
            }

            /// <summary>
            /// The task observed by the consumer. It wraps the completion source's task via <c>WithYield()</c>.
            /// </summary>
            public ValueTask<T> Task { get; }

            // Pass the token so the resulting cancellation exception reports which token
            // triggered it (callers can then filter on OperationCanceledException.CancellationToken).
            private void OnCanceled() => _completionSource.TrySetCanceled(_cancellationToken);
        }

        #region IEquatable<CompletionSourceAwaiterFactory<T>>

        public override int GetHashCode() => EqualityComparer<CancellationToken>.Default.GetHashCode(_cancellationToken);

        public bool Equals(CompletionSourceAwaiterFactory<T> other) => _cancellationToken == other._cancellationToken;

        public override bool Equals(object obj) => obj is CompletionSourceAwaiterFactory<T> other && Equals(other);

        public static bool operator ==(CompletionSourceAwaiterFactory<T> x, CompletionSourceAwaiterFactory<T> y) => x.Equals(y);

        public static bool operator !=(CompletionSourceAwaiterFactory<T> x, CompletionSourceAwaiterFactory<T> y) => !x.Equals(y);

        #endregion
    }
}