diff --git a/src/Client/AzureManaged/DurableTaskSchedulerClientExtensions.cs b/src/Client/AzureManaged/DurableTaskSchedulerClientExtensions.cs index d98ac1256..efe79748c 100644 --- a/src/Client/AzureManaged/DurableTaskSchedulerClientExtensions.cs +++ b/src/Client/AzureManaged/DurableTaskSchedulerClientExtensions.cs @@ -4,9 +4,11 @@ using System.Collections.Concurrent; using System.Diagnostics; using System.Linq; +using System.Threading; using Azure.Core; using Grpc.Net.Client; using Microsoft.DurableTask.Client.Grpc; +using Microsoft.DurableTask.Client.Grpc.Internal; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Options; @@ -156,7 +158,122 @@ public void Configure(string? name, GrpcDurableTaskClientOptions options) string cacheKey = $"{optionsName}\u001F{source.EndpointAddress}\u001F{source.TaskHubName}\u001F{source.ResourceId}\u001F{credentialType}\u001F{source.AllowInsecureCredentials}\u001F{retryOptionsKey}"; options.Channel = this.channels.GetOrAdd( cacheKey, - _ => new Lazy(source.CreateChannel)).Value; + _ => new Lazy(source.CreateChannel, LazyThreadSafetyMode.PublicationOnly)).Value; + options.SetChannelRecreator((oldChannel, ct) => this.RecreateChannelAsync(cacheKey, source, oldChannel, ct)); + } + + /// + /// Atomically swaps the cached channel for the given key with a freshly created one and schedules + /// graceful disposal of the old channel after a grace period so any in-flight RPCs from peer + /// clients can drain. Returns the currently cached channel if a peer client has already recreated it. + /// + async Task RecreateChannelAsync( + string cacheKey, + DurableTaskSchedulerClientOptions source, + GrpcChannel oldChannel, + CancellationToken cancellation) + { + cancellation.ThrowIfCancellationRequested(); + + // Recreate callbacks can outlive Configure(...) because clients keep the delegate on their + // options. Best-effort check for disposal before publishing anything back into the shared cache. + if (this.disposed == 1) + { + throw new ObjectDisposedException(nameof(ConfigureGrpcChannel)); + } + + // Shared-cache recreation has four relevant states: + // 1. No entry exists anymore. Create one and use it. + // 2. The entry already materialized a different channel. A peer client already refreshed it. + // 3. The entry still represents what this client observed. Win TryUpdate and publish the new channel. + // 4. The entry changes between our read and TryUpdate. Lose the race, dispose ours, and reuse the winner. + if (!this.channels.TryGetValue(cacheKey, out Lazy? currentLazy)) + { + // PublicationOnly avoids permanently caching a transient CreateChannel exception. + Lazy created = new(source.CreateChannel, LazyThreadSafetyMode.PublicationOnly); + if (this.disposed == 1) + { + throw new ObjectDisposedException(nameof(ConfigureGrpcChannel)); + } + + if (this.channels.TryAdd(cacheKey, created)) + { + return created.Value; + } + + this.channels.TryGetValue(cacheKey, out currentLazy); + } + + if (currentLazy is null) + { + throw new InvalidOperationException("Failed to obtain a cached gRPC channel after recreation attempt."); + } + + // Only a materialized Lazy can be compared against oldChannel by reference. If the cache slot + // has not created its channel yet, let TryUpdate decide whether this recreate attempt still owns it. + if (currentLazy.IsValueCreated && !ReferenceEquals(currentLazy.Value, oldChannel)) + { + // A peer client already swapped in a new channel; reuse it. 
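+                // (Reading .Value here is safe: IsValueCreated was checked above, so this cannot
+                // re-invoke the channel factory.)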
+ return currentLazy.Value; + } + + // Materialize the new channel BEFORE swapping the dictionary so a CreateChannel failure + // leaves the existing entry intact. If we swapped a not-yet-materialized Lazy and then + // CreateChannel threw, the dictionary would point to a permanently-failing Lazy and the + // old channel would have already been queued for disposal — an unrecoverable state. + GrpcChannel newChannel = source.CreateChannel(); + if (this.disposed == 1) + { + await DisposeChannelAsync(newChannel).ConfigureAwait(false); + throw new ObjectDisposedException(nameof(ConfigureGrpcChannel)); + } + + // The cache always stores Lazy so the steady-state Configure path and the + // recreate path use the same dictionary value shape. Recreate materializes first only to + // avoid publishing a lazy that could fault before we know channel creation succeeded. + Lazy newLazy = new(newChannel); + if (!this.channels.TryUpdate(cacheKey, newLazy, currentLazy)) + { + // Lost the race. Always queue the freshly-created channel for deferred disposal so + // it does not leak. Then return the winning entry — but if the cache slot has been + // removed entirely (e.g. concurrent DisposeAsync cleared the dictionary), do NOT + // hand back the doomed `newChannel`: it has already been scheduled for shutdown. + _ = ScheduleDeferredDisposeAsync(newChannel); + if (this.channels.TryGetValue(cacheKey, out Lazy? winner) && winner is not null) + { + return winner.Value; + } + + throw new ObjectDisposedException(this.GetType().FullName); + } + + if (currentLazy.IsValueCreated) + { + _ = ScheduleDeferredDisposeAsync(currentLazy.Value); + } + + return newChannel; + } + + static async Task ScheduleDeferredDisposeAsync(GrpcChannel channel) + { + try + { + await Task.Delay(TimeSpan.FromSeconds(30)).ConfigureAwait(false); + await DisposeChannelAsync(channel).ConfigureAwait(false); + } + catch (Exception ex) when (ex is not OutOfMemoryException + and not StackOverflowException + and not AccessViolationException + and not ThreadAbortException) + { + if (ex is not OperationCanceledException and not ObjectDisposedException) + { + Trace.TraceError( + "Unexpected exception while deferred-disposing gRPC channel in DurableTaskSchedulerClientExtensions.ScheduleDeferredDisposeAsync: {0}", + ex); + } + } } /// @@ -175,6 +292,7 @@ public async ValueTask DisposeAsync() } catch (Exception ex) when (ex is not OutOfMemoryException and not StackOverflowException + and not AccessViolationException and not ThreadAbortException) { // Swallow disposal exceptions - disposal should be best-effort to ensure diff --git a/src/Client/Grpc/ChannelRecreatingCallInvoker.cs b/src/Client/Grpc/ChannelRecreatingCallInvoker.cs new file mode 100644 index 000000000..f68ccfe9b --- /dev/null +++ b/src/Client/Grpc/ChannelRecreatingCallInvoker.cs @@ -0,0 +1,450 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using Microsoft.Extensions.Logging; + +namespace Microsoft.DurableTask.Client.Grpc; + +/// +/// A wrapper that observes RPC outcomes and triggers a fire-and-forget channel +/// recreation after a configurable number of consecutive transport failures +/// (, or on RPCs that are +/// not long-poll waits). This guards against half-open HTTP/2 connections that can otherwise wedge +/// an entire client process for the lifetime of the gRPC channel. +/// +/// +/// The wrapper holds an immutable (channel + invoker pair) and swaps +/// the entire pair atomically on recreate to avoid torn state. 
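+/// A sketch of the pattern, in terms of the members defined below (local names illustrative):
+/// <code>
+/// TransportState current = Volatile.Read(ref this.state);               // readers see a consistent pair
+/// Volatile.Write(ref this.state, new TransportState(channel, invoker)); // recreate publishes atomically
+/// </code>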
Streaming RPCs are forwarded without +/// outcome observation; only unary RPC outcomes count toward the failure threshold. +/// The triggering RPC still surfaces its original failure to the caller; subsequent RPCs benefit +/// from the recreated channel. +/// +sealed class ChannelRecreatingCallInvoker : CallInvoker, IAsyncDisposable +{ + /// + /// Methods on which a response is expected behavior + /// (long-poll-style waits) and must NOT be counted toward the recreate threshold. + /// + static readonly HashSet DeadlineExceededAllowedMethods = new(StringComparer.Ordinal) + { + "/TaskHubSidecarService/WaitForInstanceCompletion", + "/TaskHubSidecarService/WaitForInstanceStart", + }; + + readonly Func> recreator; + readonly int failureThreshold; + readonly TimeSpan minRecreateInterval; + readonly bool ownsChannel; + readonly ILogger logger; + + // Cancelled in DisposeAsync so an in-flight RecreateAsync stops promptly and does not leak the + // freshly created channel back into our state after we've disposed. + readonly CancellationTokenSource disposalCts = new(); + + // Accessed from call-site threads (reads) and the background recreate task (writes). + // Read/written with Volatile.Read / Volatile.Write to prevent torn reads and to publish + // the new reference so that callers on other threads observe it without additional + // synchronization. The TransportState itself is immutable so readers see a consistent + // (Channel, Invoker) pair once Volatile.Read returns. + TransportState state; + int consecutiveFailures; + int recreateInFlight; + // Stopwatch timestamps are monotonic, so backend-recreate cooldowns cannot be shortened or + // extended by wall-clock jumps. + long lastRecreateTimestamp; + int disposed; + + public ChannelRecreatingCallInvoker( + GrpcChannel initialChannel, + Func> recreator, + int failureThreshold, + TimeSpan minRecreateInterval, + bool ownsChannel, + ILogger logger) + { + this.recreator = recreator; + this.failureThreshold = failureThreshold; + this.minRecreateInterval = minRecreateInterval; + this.ownsChannel = ownsChannel; + this.logger = logger; + this.state = new TransportState(initialChannel, initialChannel.CreateCallInvoker()); + + // Backdate the initial timestamp so the first recreate is never blocked by the cooldown. + // Leaving the field at 0 would make the first attempt depend on how long the current process + // has been running since machine startup. + this.lastRecreateTimestamp = CreateInitialRecreateTimestamp(minRecreateInterval); + } + + public override TResponse BlockingUnaryCall( + Method method, string? host, CallOptions options, TRequest request) + { + TransportState current = Volatile.Read(ref this.state); + try + { + TResponse response = current.Invoker.BlockingUnaryCall(method, host, options, request); + this.RecordSuccess(); + return response; + } + catch (RpcException ex) + { + this.RecordFailure(ex.StatusCode, method.FullName); + throw; + } + } + + public override AsyncUnaryCall AsyncUnaryCall( + Method method, string? host, CallOptions options, TRequest request) + { + TransportState current = Volatile.Read(ref this.state); + AsyncUnaryCall call = current.Invoker.AsyncUnaryCall(method, host, options, request); + this.ObserveOutcome(call.ResponseAsync, method.FullName); + return call; + } + + public override AsyncServerStreamingCall AsyncServerStreamingCall( + Method method, string? host, CallOptions options, TRequest request) + { + // Streaming calls are forwarded without outcome observation. 
The streaming methods used by the + // DurableTask client are bounded snapshots (e.g. StreamInstanceHistory) where errors surface as + // exceptions to user code, so global failure counting on these would create false positives. + return Volatile.Read(ref this.state).Invoker.AsyncServerStreamingCall(method, host, options, request); + } + + public override AsyncClientStreamingCall AsyncClientStreamingCall( + Method method, string? host, CallOptions options) + { + return Volatile.Read(ref this.state).Invoker.AsyncClientStreamingCall(method, host, options); + } + + public override AsyncDuplexStreamingCall AsyncDuplexStreamingCall( + Method method, string? host, CallOptions options) + { + return Volatile.Read(ref this.state).Invoker.AsyncDuplexStreamingCall(method, host, options); + } + + public async ValueTask DisposeAsync() + { + if (Interlocked.Exchange(ref this.disposed, 1) != 0) + { + return; + } + + // Signal any in-flight RecreateAsync to abort. We do this BEFORE shutting down the channel so + // the recreator's cancellation token is observed and the recreate task does not race to + // Volatile.Write a freshly created channel into our state after we've torn it down. + try + { + this.disposalCts.Cancel(); + } + catch (ObjectDisposedException) + { + // Already disposed by a racing caller; nothing more to do for cancellation. + } + + if (!this.ownsChannel) + { + // The wrapper still owns disposalCts and background recreate state, but the caller owns the channel. + this.disposalCts.Dispose(); + return; + } + + TransportState current = Volatile.Read(ref this.state); + try + { + await ShutdownAndDisposeOwnedChannelAsync(current.Channel).ConfigureAwait(false); + } + finally + { + this.disposalCts.Dispose(); + } + } + + static long CreateInitialRecreateTimestamp(TimeSpan minRecreateInterval) => + Stopwatch.GetTimestamp() - ToStopwatchTicks(minRecreateInterval); + + static long ToStopwatchTicks(TimeSpan ts) + { + if (ts <= TimeSpan.Zero) + { + return 0; + } + + long timeSpanTicks = ts.Ticks; + long wholeSeconds = timeSpanTicks / TimeSpan.TicksPerSecond; + long remainingTimeSpanTicks = timeSpanTicks % TimeSpan.TicksPerSecond; + if (wholeSeconds > long.MaxValue / Stopwatch.Frequency) + { + return long.MaxValue; + } + + long wholeSecondStopwatchTicks = wholeSeconds * Stopwatch.Frequency; + long partialSecondStopwatchTicks = + (long)(((decimal)remainingTimeSpanTicks * Stopwatch.Frequency) / TimeSpan.TicksPerSecond); + if (wholeSecondStopwatchTicks > long.MaxValue - partialSecondStopwatchTicks) + { + return long.MaxValue; + } + + return wholeSecondStopwatchTicks + partialSecondStopwatchTicks; + } + + static TimeSpan ElapsedSince(long previousTimestamp, long nowTimestamp) + { + long elapsedTicks = Math.Max(0, nowTimestamp - previousTimestamp); + return TimeSpan.FromSeconds((double)elapsedTicks / Stopwatch.Frequency); + } + + void ObserveOutcome(Task responseAsync, string methodFullName) + { + // Use ContinueWith with TaskScheduler.Default so we don't capture sync context. 
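+        // The continuation body only updates failure counters and, at most, logs and starts a
+        // background Task.Run, so ExecuteSynchronously adds negligible work to the completing thread.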
+ responseAsync.ContinueWith( + (t, state) => + { + var (self, name) = ((ChannelRecreatingCallInvoker, string))state!; + if (t.Status == TaskStatus.RanToCompletion) + { + self.RecordSuccess(); + } + else if (t.Exception?.InnerException is RpcException rpcEx) + { + self.RecordFailure(rpcEx.StatusCode, name); + } + }, + (this, methodFullName), + CancellationToken.None, + TaskContinuationOptions.ExecuteSynchronously, + TaskScheduler.Default); + } + + void RecordSuccess() + { + Volatile.Write(ref this.consecutiveFailures, 0); + } + + void RecordFailure(StatusCode status, string methodFullName) + { + // Only count statuses that indicate an actual transport problem, not application-level errors: + // * Unavailable — half-open connection, peer reset, or dead routing target. + // * DeadlineExceeded — the call exceeded the *client-supplied* deadline. This is a + // transport hint EXCEPT for long-poll RPCs (e.g. WaitForInstance*) + // where a deadline timeout is expected behavior, so those are + // excluded explicitly. + // Other statuses (NotFound, InvalidArgument, FailedPrecondition, etc.) are application + // failures that a fresh channel won't fix and would otherwise produce false-positive + // recreates. + bool counts = status switch + { + StatusCode.Unavailable => true, + StatusCode.DeadlineExceeded => !DeadlineExceededAllowedMethods.Contains(methodFullName), + _ => false, + }; + + if (!counts) + { + // Any gRPC status reply (even an application-level error) is proof that the transport + // is healthy enough to deliver round-trips, so reset the failure counter. This prevents + // unrelated app-level failures from silently accumulating between transport blips and + // tripping a false-positive recreate. + Volatile.Write(ref this.consecutiveFailures, 0); + return; + } + + int count = Interlocked.Increment(ref this.consecutiveFailures); + if (this.failureThreshold <= 0 || count < this.failureThreshold) + { + return; + } + + this.MaybeTriggerRecreate(count); + } + + void MaybeTriggerRecreate(int observedCount) + { + if (!this.HasReachedRecreateCooldown(Stopwatch.GetTimestamp())) + { + return; + } + + // Single-flight guard: only one recreate task in flight at a time. + if (Interlocked.CompareExchange(ref this.recreateInFlight, 1, 0) != 0) + { + return; + } + + // A previous recreate can finish after the fast-path cooldown check but before we acquire the + // single-flight slot. Re-check after taking the slot so we don't immediately recreate again. + if (!this.HasReachedRecreateCooldown(Stopwatch.GetTimestamp())) + { + Interlocked.Exchange(ref this.recreateInFlight, 0); + return; + } + + this.logger.RecreatingChannel(observedCount); + + // Keep recreate work off the caller's RPC-failure path. + _ = Task.Run(() => this.RecreateAsync(observedCount)); + } + + async Task RecreateAsync(int observedCount) + { + try + { + if (Volatile.Read(ref this.disposed) != 0) + { + return; + } + + TransportState current = Volatile.Read(ref this.state); + + // Link to the disposal CTS so DisposeAsync can promptly abort an in-flight recreate. + // The 30s timeout keeps a wedged recreator from holding the single-flight slot indefinitely. + using CancellationTokenSource cts = this.CreateRecreateCancellationSource(); + GrpcChannel newChannel = await this.recreator(current.Channel, cts.Token).ConfigureAwait(false); + + if (!ReferenceEquals(newChannel, current.Channel)) + { + // Re-check disposal before publishing the new channel into state. 
Otherwise we could + // race with DisposeAsync and leak the new channel (its socket handlers + DNS resolver + // would never be torn down). + if (Volatile.Read(ref this.disposed) != 0) + { + if (this.ownsChannel) + { + try + { + await ShutdownAndDisposeOwnedChannelAsync(newChannel).ConfigureAwait(false); + } + catch (Exception shutdownEx) when (shutdownEx is not OutOfMemoryException + and not StackOverflowException + and not AccessViolationException + and not ThreadAbortException) + { + // Best-effort cleanup. + } + } + + return; + } + + Volatile.Write(ref this.state, new TransportState(newChannel, newChannel.CreateCallInvoker())); + this.logger.ChannelRecreated(GetEndpointDescription(newChannel)); + + // When we own the channel, no external party is responsible for tearing down the old + // one. Defer disposal briefly so any in-flight RPCs issued against the previous + // CallInvoker before the swap can still complete (they already captured the old + // TransportState before Volatile.Write). + if (this.ownsChannel) + { + _ = ScheduleDeferredDisposeAsync(current.Channel); + } + } + else + { + // Returning the same channel means no swap was needed (for example, because a peer + // already refreshed a shared cache). Keep using the published state and reset the + // failure counter below. + } + + // Successful recreate (even if a peer beat us to it) → reset the failure counter. + Volatile.Write(ref this.consecutiveFailures, 0); + Volatile.Write(ref this.lastRecreateTimestamp, Stopwatch.GetTimestamp()); + } + catch (OperationCanceledException) when (Volatile.Read(ref this.disposed) != 0) + { + // We were disposed mid-recreate; nothing to log. + } + catch (Exception ex) when (ex is not OutOfMemoryException + and not StackOverflowException + and not AccessViolationException + and not ThreadAbortException) + { + this.logger.ChannelRecreateFailed(ex); + + // Update the last-attempt timestamp even on failure so the cooldown applies to failed attempts too. + Volatile.Write(ref this.lastRecreateTimestamp, Stopwatch.GetTimestamp()); + } + finally + { + Interlocked.Exchange(ref this.recreateInFlight, 0); + } + } + + CancellationTokenSource CreateRecreateCancellationSource() + { + try + { + CancellationTokenSource cts = CancellationTokenSource.CreateLinkedTokenSource(this.disposalCts.Token); + cts.CancelAfter(TimeSpan.FromSeconds(30)); + return cts; + } + catch (ObjectDisposedException) when (Volatile.Read(ref this.disposed) != 0) + { + CancellationTokenSource cts = new(); + cts.Cancel(); + return cts; + } + } + + static async Task ScheduleDeferredDisposeAsync(GrpcChannel channel) + { + try + { + // Grace period to let in-flight RPCs captured against the old invoker drain before we + // tear down the channel's HTTP handler / sockets. + await Task.Delay(TimeSpan.FromSeconds(30)).ConfigureAwait(false); + await ShutdownAndDisposeOwnedChannelAsync(channel).ConfigureAwait(false); + } + catch (Exception ex) when (ex is not OutOfMemoryException + and not StackOverflowException + and not AccessViolationException + and not ThreadAbortException) + { + if (ex is not OperationCanceledException and not ObjectDisposedException) + { + Trace.TraceError( + "Unexpected exception while deferred-disposing gRPC channel in ChannelRecreatingCallInvoker.ScheduleDeferredDisposeAsync: {0}", + ex); + } + } + } + + static string GetEndpointDescription(GrpcChannel channel) + { + return channel.Target ?? 
"(unknown)"; + } + + bool HasReachedRecreateCooldown(long nowTimestamp) + { + TimeSpan elapsed = ElapsedSince(Volatile.Read(ref this.lastRecreateTimestamp), nowTimestamp); + return elapsed >= this.minRecreateInterval; + } + + static async Task ShutdownAndDisposeOwnedChannelAsync(GrpcChannel channel) + { + try + { + await channel.ShutdownAsync().ConfigureAwait(false); + } + catch (Exception ex) when (ex is OperationCanceledException or ObjectDisposedException) + { + // Expected during shutdown races; nothing more to do. + } +#if NET6_0_OR_GREATER + channel.Dispose(); +#endif + } + + sealed class TransportState + { + public TransportState(GrpcChannel channel, CallInvoker invoker) + { + this.Channel = channel; + this.Invoker = invoker; + } + + public GrpcChannel Channel { get; } + + public CallInvoker Invoker { get; } + } +} diff --git a/src/Client/Grpc/GrpcDurableTaskClient.cs b/src/Client/Grpc/GrpcDurableTaskClient.cs index 46e4dd2ed..23350d4cd 100644 --- a/src/Client/Grpc/GrpcDurableTaskClient.cs +++ b/src/Client/Grpc/GrpcDurableTaskClient.cs @@ -52,7 +52,7 @@ public GrpcDurableTaskClient(string name, GrpcDurableTaskClientOptions options, { this.logger = Check.NotNull(logger); this.options = Check.NotNull(options); - this.asyncDisposable = GetCallInvoker(options, out CallInvoker callInvoker); + this.asyncDisposable = GetCallInvoker(options, logger, out CallInvoker callInvoker); this.sidecarClient = new TaskHubSidecarServiceClient(callInvoker); if (this.options.EnableEntitySupport) @@ -624,23 +624,69 @@ public override async Task> GetOrchestrationHistoryAsync( } } - static AsyncDisposable GetCallInvoker(GrpcDurableTaskClientOptions options, out CallInvoker callInvoker) + static AsyncDisposable GetCallInvoker(GrpcDurableTaskClientOptions options, ILogger logger, out CallInvoker callInvoker) { + Func>? recreator = options.Internal.ChannelRecreator; + int threshold = options.Internal.ChannelRecreateFailureThreshold; + TimeSpan cooldown = options.Internal.MinRecreateInterval; + bool recreateEnabled = recreator != null && threshold > 0; + if (options.Channel is GrpcChannel c) { + if (recreateEnabled) + { + ChannelRecreatingCallInvoker wrapper = new(c, recreator!, threshold, cooldown, ownsChannel: false, logger); + callInvoker = wrapper; + + // We do not own the externally-supplied channel, but we DO own the wrapper. Without + // disposing the wrapper its CancellationTokenSource and any in-flight recreate task + // would outlive the client. The wrapper's DisposeAsync is a no-op for the channel + // itself when ownsChannel == false. + return new AsyncDisposable(() => wrapper.DisposeAsync()); + } + callInvoker = c.CreateCallInvoker(); return default; } if (options.CallInvoker is CallInvoker invoker) { + // Externally supplied invoker — we do not own the underlying channel and cannot recreate it. callInvoker = invoker; return default; } + // Self-owned address path: create the channel ourselves so we own its lifecycle. 
c = GetChannel(options.Address); + + if (recreateEnabled) + { + ChannelRecreatingCallInvoker wrapper = new(c, recreator!, threshold, cooldown, ownsChannel: true, logger); + callInvoker = wrapper; + return new AsyncDisposable(() => wrapper.DisposeAsync()); + } + callInvoker = c.CreateCallInvoker(); - return new AsyncDisposable(() => new(c.ShutdownAsync())); + return CreateOwnedChannelDisposable(c); + } + + static AsyncDisposable CreateOwnedChannelDisposable(GrpcChannel channel) + { + return new AsyncDisposable(() => ShutdownAndDisposeChannelAsync(channel)); + } + + static async ValueTask ShutdownAndDisposeChannelAsync(GrpcChannel channel) + { + try + { + await channel.ShutdownAsync().ConfigureAwait(false); + } + finally + { +#if NET6_0_OR_GREATER + channel.Dispose(); +#endif + } } #if NET6_0_OR_GREATER diff --git a/src/Client/Grpc/GrpcDurableTaskClientOptions.cs b/src/Client/Grpc/GrpcDurableTaskClientOptions.cs index b67b62f9a..126aad345 100644 --- a/src/Client/Grpc/GrpcDurableTaskClientOptions.cs +++ b/src/Client/Grpc/GrpcDurableTaskClientOptions.cs @@ -22,4 +22,40 @@ public sealed class GrpcDurableTaskClientOptions : DurableTaskClientOptions /// Gets or sets the gRPC call invoker to use. Will supersede when provided. /// public CallInvoker? CallInvoker { get; set; } + + /// + /// Gets the internal options. These are not exposed directly, but configurable via + /// . + /// + internal InternalOptions Internal { get; } = new(); + + /// + /// Internal options are not exposed directly, but configurable via . + /// + internal class InternalOptions + { + /// + /// Gets or sets the number of consecutive transport failures (Unavailable responses, or + /// DeadlineExceeded responses on RPCs other than long-poll waits) after which the underlying + /// gRPC channel will be recreated to clear stale DNS, sub-channel state, or routing-affinity + /// bindings. Setting to 0 or a negative value disables channel recreation. Defaults to 5. + /// + public int ChannelRecreateFailureThreshold { get; set; } = 5; + + /// + /// Gets or sets the minimum interval between consecutive channel recreate attempts. Acts as a + /// cooldown so a burst of failures during a real outage cannot thrash the channel cache. + /// Defaults to 30 seconds. + /// + public TimeSpan MinRecreateInterval { get; set; } = TimeSpan.FromSeconds(30); + + /// + /// Gets or sets an optional callback invoked when the client requests a fresh gRPC channel after + /// repeated transport failures. The callback receives the previously-used channel and should + /// return either a freshly created channel or the currently cached channel if a peer has already + /// recreated it. Implementations are responsible for atomic swap and deferred disposal of the + /// old channel so in-flight RPCs from peer clients are not interrupted. + /// + public Func>? ChannelRecreator { get; set; } + } } diff --git a/src/Client/Grpc/Internal/InternalOptionsExtensions.cs b/src/Client/Grpc/Internal/InternalOptionsExtensions.cs new file mode 100644 index 000000000..800848f34 --- /dev/null +++ b/src/Client/Grpc/Internal/InternalOptionsExtensions.cs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Microsoft.DurableTask.Client.Grpc.Internal; + +/// +/// Provides access to configuring internal options for the gRPC client. 
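+/// For example, the Azure-managed client extension wires the recreator as:
+/// <code>
+/// options.SetChannelRecreator((oldChannel, ct) => this.RecreateChannelAsync(cacheKey, source, oldChannel, ct));
+/// </code>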
+/// +public static class InternalOptionsExtensions +{ + /// + /// Sets a callback that the client invokes when the underlying gRPC channel needs to be recreated + /// after repeated transport failures (e.g., because the backend was replaced and the existing channel + /// is wedged on a half-open HTTP/2 connection). The callback receives the channel the client last + /// observed and must return either a freshly created channel or the currently cached channel if a + /// peer client has already swapped it. Implementations are responsible for atomic swap and deferred + /// disposal of the old channel so in-flight RPCs from peer clients are not interrupted. + /// + /// The gRPC client options. + /// The recreate callback. + /// + /// This is an internal API that supports the DurableTask infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new DurableTask release. + /// + public static void SetChannelRecreator( + this GrpcDurableTaskClientOptions options, + Func> recreator) + { + options.Internal.ChannelRecreator = recreator ?? throw new ArgumentNullException(nameof(recreator)); + } +} diff --git a/src/Client/Grpc/Logs.cs b/src/Client/Grpc/Logs.cs index 8887f17d9..7eec022a4 100644 --- a/src/Client/Grpc/Logs.cs +++ b/src/Client/Grpc/Logs.cs @@ -49,5 +49,14 @@ public static void PurgingInstances(this ILogger logger, PurgeInstancesFilter fi string? statuses = filter?.Statuses is null ? null : string.Join("|", filter.Statuses); PurgingInstances(logger, filter?.CreatedFrom, filter?.CreatedTo, statuses); } + + [LoggerMessage(EventId = 80, Level = LogLevel.Warning, Message = "Recreating gRPC channel to backend after {failureCount} consecutive transport failures.")] + public static partial void RecreatingChannel(this ILogger logger, int failureCount); + + [LoggerMessage(EventId = 81, Level = LogLevel.Information, Message = "gRPC channel to backend has been recreated. New target: {endpoint}.")] + public static partial void ChannelRecreated(this ILogger logger, string endpoint); + + [LoggerMessage(EventId = 82, Level = LogLevel.Warning, Message = "gRPC channel recreation failed.")] + public static partial void ChannelRecreateFailed(this ILogger logger, Exception exception); } } diff --git a/src/InProcessTestHost/Sidecar/Dispatcher/TaskOrchestrationDispatcher.cs b/src/InProcessTestHost/Sidecar/Dispatcher/TaskOrchestrationDispatcher.cs index cc2812a83..f6bafb450 100644 --- a/src/InProcessTestHost/Sidecar/Dispatcher/TaskOrchestrationDispatcher.cs +++ b/src/InProcessTestHost/Sidecar/Dispatcher/TaskOrchestrationDispatcher.cs @@ -149,6 +149,22 @@ protected override async Task ExecuteWorkItemAsync(TaskOrchestrationWorkItem wor out bool continueAsNew); if (continueAsNew) { + // Self-targeted external events emitted during ContinueAsNew (for example by + // preserveUnprocessedEvents) must be carried into the new runtime state before + // we clear the per-iteration message lists. Otherwise they are dropped here and + // the next generation starts without the buffered events it is expecting. 
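+                    // (These are carried forward as history events only; the per-iteration message
+                    // lists themselves are still cleared below.)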
+ foreach (TaskMessage message in orchestratorMessages) + { + if (message.Event is EventRaisedEvent eventRaised + && string.Equals( + message.OrchestrationInstance.InstanceId, + workItem.InstanceId, + StringComparison.Ordinal)) + { + workItem.OrchestrationRuntimeState.AddEvent(eventRaised); + } + } + // The previous execution is being replaced by a new one. Clear any // accumulated messages from the old execution so they are not // re-enqueued when the work item completes. Without this, stale diff --git a/src/Worker/AzureManaged/DurableTaskSchedulerWorkerExtensions.cs b/src/Worker/AzureManaged/DurableTaskSchedulerWorkerExtensions.cs index 0164d1cc5..3b9d4e55c 100644 --- a/src/Worker/AzureManaged/DurableTaskSchedulerWorkerExtensions.cs +++ b/src/Worker/AzureManaged/DurableTaskSchedulerWorkerExtensions.cs @@ -1,9 +1,10 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using System.Collections.Concurrent; +using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Threading; using Azure.Core; using Grpc.Net.Client; using Microsoft.DurableTask.Worker.Grpc; @@ -110,7 +111,8 @@ static void ConfigureSchedulerOptions( sealed class ConfigureGrpcChannel : IConfigureNamedOptions, IAsyncDisposable { readonly IOptionsMonitor schedulerOptions; - readonly ConcurrentDictionary> channels = new(); + readonly Dictionary channels = new(); + readonly object syncRoot = new(); volatile int disposed; /// @@ -135,15 +137,6 @@ public ConfigureGrpcChannel(IOptionsMonitor s /// The options instance to configure. public void Configure(string? name, GrpcDurableTaskWorkerOptions options) { -#if NET7_0_OR_GREATER - ObjectDisposedException.ThrowIf(this.disposed == 1, this); -#else - if (this.disposed == 1) - { - throw new ObjectDisposedException(nameof(ConfigureGrpcChannel)); - } -#endif - string optionsName = name ?? Options.DefaultName; DurableTaskSchedulerWorkerOptions source = this.schedulerOptions.Get(optionsName); @@ -153,12 +146,124 @@ public void Configure(string? name, GrpcDurableTaskWorkerOptions options) // Use a delimiter character (\u001F) that will not appear in typical endpoint URIs. string credentialType = source.Credential?.GetType().FullName ?? "null"; string cacheKey = $"{optionsName}\u001F{source.EndpointAddress}\u001F{source.TaskHubName}\u001F{source.ResourceId}\u001F{credentialType}\u001F{source.AllowInsecureCredentials}\u001F{source.WorkerId}"; - options.Channel = this.channels.GetOrAdd( - cacheKey, - _ => new Lazy(source.CreateChannel)).Value; + GrpcChannel? newChannel = null; + bool disposeNewChannel = false; + lock (this.syncRoot) + { +#if NET7_0_OR_GREATER + ObjectDisposedException.ThrowIf(this.disposed == 1, this); +#else + if (this.disposed == 1) + { + throw new ObjectDisposedException(nameof(ConfigureGrpcChannel)); + } +#endif + + if (!this.channels.TryGetValue(cacheKey, out GrpcChannel? 
channel)) + { + newChannel = source.CreateChannel(); + if (this.disposed == 1) + { + disposeNewChannel = true; + } + else + { + channel = newChannel; + this.channels.Add(cacheKey, channel); + } + } + + options.Channel = channel; + } + + if (disposeNewChannel) + { + newChannel!.Dispose(); + throw new ObjectDisposedException(nameof(ConfigureGrpcChannel)); + } + + options.SetChannelRecreator((oldChannel, ct) => this.RecreateChannelAsync(cacheKey, source, oldChannel, ct)); options.ConfigureForAzureManaged(); } + /// + /// Replaces the cached worker channel for the given key with a freshly created one and schedules + /// graceful disposal of the old channel after a grace period so any in-flight RPCs that already + /// captured the previous channel can drain. + /// + async Task RecreateChannelAsync(string cacheKey, DurableTaskSchedulerWorkerOptions source, GrpcChannel oldChannel, CancellationToken cancellation) + { + cancellation.ThrowIfCancellationRequested(); + + GrpcChannel? cachedChannel = null; + GrpcChannel? newChannel = null; + bool disposeNewChannel = false; + lock (this.syncRoot) + { +#if NET7_0_OR_GREATER + ObjectDisposedException.ThrowIf(this.disposed == 1, this); +#else + if (this.disposed == 1) + { + throw new ObjectDisposedException(nameof(ConfigureGrpcChannel)); + } +#endif + + // Worker cache keys include WorkerId, so a single worker normally owns each cached entry. + // Still guard against a stale recreate callback that is racing a more recent successful swap. + if (this.channels.TryGetValue(cacheKey, out cachedChannel) + && !ReferenceEquals(cachedChannel, oldChannel)) + { + return cachedChannel; + } + + // Materialize the replacement channel only after we've established that this callback still + // corresponds to the currently cached worker channel. + newChannel = source.CreateChannel(); + if (this.disposed == 1) + { + disposeNewChannel = true; + } + else + { + this.channels[cacheKey] = newChannel; + } + } + + if (disposeNewChannel) + { + await DisposeChannelAsync(newChannel!).ConfigureAwait(false); + throw new ObjectDisposedException(nameof(ConfigureGrpcChannel)); + } + + // Successful swap. Schedule graceful disposal of the old channel after a grace period + // so any in-flight RPCs that already captured it can drain. + _ = ScheduleDeferredDisposeAsync(oldChannel); + return newChannel!; + } + + static async Task ScheduleDeferredDisposeAsync(GrpcChannel channel) + { + try + { + // Grace period to let in-flight RPCs using the previous channel complete before draining it. 
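+            // Callers discard the returned task on purpose (fire-and-forget); failures are
+            // swallowed or traced below rather than propagated to the recreate path.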
+ await Task.Delay(TimeSpan.FromSeconds(30)).ConfigureAwait(false); + await DisposeChannelAsync(channel).ConfigureAwait(false); + } + catch (Exception ex) when (ex is not OutOfMemoryException + and not StackOverflowException + and not AccessViolationException + and not ThreadAbortException) + { + if (ex is not OperationCanceledException and not ObjectDisposedException) + { + Trace.TraceError( + "Unexpected exception while deferred-disposing gRPC channel in DurableTaskSchedulerWorkerExtensions.ScheduleDeferredDisposeAsync: {0}", + ex); + } + } + } + /// public async ValueTask DisposeAsync() { @@ -167,14 +272,22 @@ public async ValueTask DisposeAsync() return; } - foreach (Lazy channel in this.channels.Values.Where(lazy => lazy.IsValueCreated)) + List channelsToDispose; + lock (this.syncRoot) + { + channelsToDispose = this.channels.Values.ToList(); + this.channels.Clear(); + } + + foreach (GrpcChannel channel in channelsToDispose) { try { - await DisposeChannelAsync(channel.Value).ConfigureAwait(false); + await DisposeChannelAsync(channel).ConfigureAwait(false); } catch (Exception ex) when (ex is not OutOfMemoryException and not StackOverflowException + and not AccessViolationException and not ThreadAbortException) { // Swallow disposal exceptions - disposal should be best-effort to ensure @@ -187,8 +300,6 @@ and not StackOverflowException } } } - - this.channels.Clear(); GC.SuppressFinalize(this); } diff --git a/src/Worker/Grpc/GrpcDurableTaskWorker.Processor.cs b/src/Worker/Grpc/GrpcDurableTaskWorker.Processor.cs index 58a6db040..7f61b68a9 100644 --- a/src/Worker/Grpc/GrpcDurableTaskWorker.Processor.cs +++ b/src/Worker/Grpc/GrpcDurableTaskWorker.Processor.cs @@ -54,29 +54,66 @@ public Processor(GrpcDurableTaskWorker worker, TaskHubSidecarServiceClient clien ILogger Logger => this.worker.logger; - public async Task ExecuteAsync(CancellationToken cancellation) + public async Task ExecuteAsync(CancellationToken cancellation) { + // Tracks consecutive failures against the same channel. Reset only after the stream + // has actually delivered a message (HelloAsync alone is not proof the channel is healthy). + int consecutiveChannelFailures = 0; + + // Tracks consecutive retry attempts for backoff calculation. Reset on first stream message. + int reconnectAttempt = 0; + Random backoffRandom = ReconnectBackoff.CreateRandom(); + while (!cancellation.IsCancellationRequested) { + bool channelLikelyPoisoned = false; try { - AsyncServerStreamingCall stream = await this.ConnectAsync(cancellation); - await this.ProcessWorkItemsAsync(stream, cancellation); + using AsyncServerStreamingCall stream = await this.ConnectAsync(cancellation); + await this.ProcessWorkItemsAsync( + stream, + cancellation, + onFirstMessage: () => + { + consecutiveChannelFailures = 0; + reconnectAttempt = 0; + }, + onChannelLikelyPoisoned: () => channelLikelyPoisoned = true); } catch (RpcException) when (cancellation.IsCancellationRequested) { // Worker is shutting down - let the method exit gracefully - break; + return ProcessorExitReason.Shutdown; } catch (RpcException ex) when (ex.StatusCode == StatusCode.Cancelled) { - // Sidecar is shutting down - retry + // Sidecar is shutting down - retry. Don't count toward channel-poisoned threshold: + // Cancelled is ambiguous and shouldn't drive recreate storms. this.Logger.SidecarDisconnected(); } + catch (RpcException ex) when (ex.StatusCode == StatusCode.DeadlineExceeded) + { + // Only HelloAsync carries a deadline. 
Once the work-item stream is established, + // ProcessWorkItemsAsync relies on the silent-disconnect timer instead of per-read deadlines. + // A DeadlineExceeded here therefore means the handshake hung on a stale or half-open channel. + this.Logger.HelloTimeout(this.internalOptions.HelloDeadline); + channelLikelyPoisoned = true; + } catch (RpcException ex) when (ex.StatusCode == StatusCode.Unavailable) { - // Sidecar is down - keep retrying + // Sidecar is down - keep retrying. this.Logger.SidecarUnavailable(); + channelLikelyPoisoned = true; + } + catch (RpcException ex) when (ex.StatusCode == StatusCode.Unauthenticated) + { + // Auth rejection — log distinctly so it's diagnosable. Do not count toward channel + // recreate: a fresh channel won't fix bad credentials. Reset the consecutive-failure + // counters: a status reply is proof the transport itself is healthy, so prior + // transport failures should not combine with later ones to trip the recreate. + this.Logger.AuthenticationFailed(ex); + consecutiveChannelFailures = 0; + reconnectAttempt = 0; } catch (RpcException ex) when (ex.StatusCode == StatusCode.NotFound) { @@ -91,7 +128,7 @@ public async Task ExecuteAsync(CancellationToken cancellation) catch (OperationCanceledException) when (cancellation.IsCancellationRequested) { // Shutting down, lets exit gracefully. - break; + return ProcessorExitReason.Shutdown; } catch (Exception ex) { @@ -99,19 +136,39 @@ public async Task ExecuteAsync(CancellationToken cancellation) this.Logger.UnexpectedError(ex, string.Empty); } + if (channelLikelyPoisoned) + { + consecutiveChannelFailures++; + int threshold = this.internalOptions.ChannelRecreateFailureThreshold; + if (threshold > 0 && consecutiveChannelFailures >= threshold) + { + this.Logger.RecreatingChannel(consecutiveChannelFailures); + return ProcessorExitReason.ChannelRecreateRequested; + } + } + try { - // CONSIDER: Exponential backoff - await Task.Delay(TimeSpan.FromSeconds(5), cancellation); + TimeSpan delay = ReconnectBackoff.Compute( + reconnectAttempt, + this.internalOptions.ReconnectBackoffBase, + this.internalOptions.ReconnectBackoffCap, + backoffRandom); + this.Logger.ReconnectBackoff(reconnectAttempt, (int)Math.Min(int.MaxValue, delay.TotalMilliseconds)); + reconnectAttempt = Math.Min(reconnectAttempt + 1, 30); // cap to avoid overflow in 2^attempt + await Task.Delay(delay, cancellation); } catch (OperationCanceledException) when (cancellation.IsCancellationRequested) { // Worker is shutting down - let the method exit gracefully - break; + return ProcessorExitReason.Shutdown; } } + + return ProcessorExitReason.Shutdown; } + static string GetActionsListForLogging(IReadOnlyList actions) { if (actions.Count == 0) @@ -242,7 +299,21 @@ async ValueTask BuildRuntimeStateAsync( async Task> ConnectAsync(CancellationToken cancellation) { - await this.client!.HelloAsync(EmptyMessage, cancellationToken: cancellation); + TimeSpan helloDeadline = this.internalOptions.HelloDeadline; + DateTime? deadline = null; + + if (helloDeadline > TimeSpan.Zero) + { + // Clamp to a UTC DateTime.MaxValue so a misconfigured (very large) HelloDeadline cannot + // throw ArgumentOutOfRangeException out of DateTime.Add and so the gRPC deadline remains + // unambiguous during internal normalization. + DateTime now = DateTime.UtcNow; + DateTime maxDeadlineUtc = DateTime.SpecifyKind(DateTime.MaxValue, DateTimeKind.Utc); + TimeSpan maxOffset = maxDeadlineUtc - now; + deadline = helloDeadline >= maxOffset ? 
maxDeadlineUtc : now.Add(helloDeadline); + } + + await this.client!.HelloAsync(EmptyMessage, deadline: deadline, cancellationToken: cancellation); this.Logger.EstablishedWorkItemConnection(); DurableTaskWorkerOptions workerOptions = this.worker.workerOptions; @@ -263,85 +334,115 @@ async ValueTask BuildRuntimeStateAsync( cancellationToken: cancellation); } - async Task ProcessWorkItemsAsync(AsyncServerStreamingCall stream, CancellationToken cancellation) + async Task ProcessWorkItemsAsync( + AsyncServerStreamingCall stream, + CancellationToken cancellation, + Action? onFirstMessage = null, + Action? onChannelLikelyPoisoned = null) { - // Create a new token source for timing out and a final token source that keys off of them both. - // The timeout token is used to detect when we are no longer getting any messages, including health checks. - // If this is the case, it signifies the connection has been dropped silently and we need to reconnect. - using var timeoutSource = new CancellationTokenSource(); - timeoutSource.CancelAfter(TimeSpan.FromSeconds(60)); - using var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellation, timeoutSource.Token); - - while (!cancellation.IsCancellationRequested) + // The timeout token (managed by WorkItemStreamConsumer) detects when no messages — + // including health pings sent periodically by the server — arrive within the configured + // window. If that fires we treat the stream as silently disconnected and reconnect. + TimeSpan silentDisconnectTimeout = this.internalOptions.SilentDisconnectTimeout; + + // NOTE: the consumer deliberately does NOT wrap its await foreach in an outer loop. + // The underlying IAsyncStreamReader is single-use — once the server terminates the stream + // (e.g. via a graceful HTTP/2 GOAWAY with OK trailers during a rolling upgrade), MoveNext + // returns false forever and re-entering await foreach would tight-spin with no yield. 
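+            // ConsumeAsync therefore makes a single pass over the stream and reports how it ended
+            // (Shutdown, SilentDisconnect, or GracefulDrain) rather than looping internally.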
+ WorkItemStreamResult result = await WorkItemStreamConsumer.ConsumeAsync( + ct => stream.ResponseStream.ReadAllAsync(ct), + silentDisconnectTimeout, + workItem => this.DispatchWorkItem(workItem, cancellation), + onFirstMessage, + cancellation); + + switch (result.Outcome) { - await foreach (P.WorkItem workItem in stream.ResponseStream.ReadAllAsync(tokenSource.Token)) - { - timeoutSource.CancelAfter(TimeSpan.FromSeconds(60)); - if (workItem.RequestCase == P.WorkItem.RequestOneofCase.OrchestratorRequest) - { - this.RunBackgroundTask( - workItem, - () => this.OnRunOrchestratorAsync( - workItem.OrchestratorRequest, - workItem.CompletionToken, - cancellation), - cancellation); - } - else if (workItem.RequestCase == P.WorkItem.RequestOneofCase.ActivityRequest) - { - this.RunBackgroundTask( - workItem, - () => this.OnRunActivityAsync( - workItem.ActivityRequest, - workItem.CompletionToken, - cancellation), - cancellation); - } - else if (workItem.RequestCase == P.WorkItem.RequestOneofCase.EntityRequest) - { - this.RunBackgroundTask( - workItem, - () => this.OnRunEntityBatchAsync(workItem.EntityRequest.ToEntityBatchRequest(), cancellation), - cancellation); - } - else if (workItem.RequestCase == P.WorkItem.RequestOneofCase.EntityRequestV2) - { - workItem.EntityRequestV2.ToEntityBatchRequest( - out EntityBatchRequest batchRequest, - out List operationInfos); - - this.RunBackgroundTask( - workItem, - () => this.OnRunEntityBatchAsync( - batchRequest, - cancellation, - workItem.CompletionToken, - operationInfos), - cancellation); - } - else if (workItem.RequestCase == P.WorkItem.RequestOneofCase.HealthPing) - { - // No-op - } - else - { - this.Logger.UnexpectedWorkItemType(workItem.RequestCase.ToString()); - } - } + case WorkItemStreamOutcome.Shutdown: + return; - if (tokenSource.IsCancellationRequested || tokenSource.Token.IsCancellationRequested) - { - // The token has cancelled, this means either: - // 1. The broader 'cancellation' was triggered, return here to start a graceful shutdown. - // 2. The timeoutSource was triggered, return here to trigger a reconnect to the backend. - if (!cancellation.IsCancellationRequested) + case WorkItemStreamOutcome.SilentDisconnect: + // Stream stopped producing messages (including health pings) for longer than the + // configured window. Treat as a poisoned channel. + this.Logger.ConnectionTimeout(); + onChannelLikelyPoisoned?.Invoke(); + return; + + case WorkItemStreamOutcome.GracefulDrain: + // Canonical signal sent by the backend during a graceful drain (HTTP/2 GOAWAY + + // OK trailers when a DTS instance is being replaced). Log it explicitly so + // operators can see it. Only count it toward the channel-poisoned threshold when + // the stream produced no messages: a stream that successfully delivered work and + // was then closed by the server is healthy behavior (e.g. routine rolling + // upgrade), and counting those would let a long-lived process accumulate spurious + // "poison" credits across many healthy drains. An empty drain, on the other hand, + // is a strong signal the channel is latched onto a dead/evacuated backend and + // needs to be recreated to pick up fresh DNS/routing. + this.Logger.StreamEndedByPeer(); + if (!result.FirstMessageObserved) { - // Since the cancellation came from the timeout, log a warning. 
- this.Logger.ConnectionTimeout(); + onChannelLikelyPoisoned?.Invoke(); } return; - } + } + } + + void DispatchWorkItem(P.WorkItem workItem, CancellationToken cancellation) + { + if (workItem.RequestCase == P.WorkItem.RequestOneofCase.OrchestratorRequest) + { + this.RunBackgroundTask( + workItem, + () => this.OnRunOrchestratorAsync( + workItem.OrchestratorRequest, + workItem.CompletionToken, + cancellation), + cancellation); + } + else if (workItem.RequestCase == P.WorkItem.RequestOneofCase.ActivityRequest) + { + this.RunBackgroundTask( + workItem, + () => this.OnRunActivityAsync( + workItem.ActivityRequest, + workItem.CompletionToken, + cancellation), + cancellation); + } + else if (workItem.RequestCase == P.WorkItem.RequestOneofCase.EntityRequest) + { + this.RunBackgroundTask( + workItem, + () => this.OnRunEntityBatchAsync(workItem.EntityRequest.ToEntityBatchRequest(), cancellation), + cancellation); + } + else if (workItem.RequestCase == P.WorkItem.RequestOneofCase.EntityRequestV2) + { + workItem.EntityRequestV2.ToEntityBatchRequest( + out EntityBatchRequest batchRequest, + out List operationInfos); + + this.RunBackgroundTask( + workItem, + () => this.OnRunEntityBatchAsync( + batchRequest, + cancellation, + workItem.CompletionToken, + operationInfos), + cancellation); + } + else if (workItem.RequestCase == P.WorkItem.RequestOneofCase.HealthPing) + { + // Health pings are heartbeat-only signals from the backend; the silent-disconnect + // timer reset (handled inside WorkItemStreamConsumer) is the actionable behavior. + // Logging at Trace allows operators to confirm liveness without flooding info-level + // telemetry. + this.Logger.ReceivedHealthPing(); + } + else + { + this.Logger.UnexpectedWorkItemType(workItem.RequestCase.ToString()); } } diff --git a/src/Worker/Grpc/GrpcDurableTaskWorker.cs b/src/Worker/Grpc/GrpcDurableTaskWorker.cs index 1d03d96ae..a200cedbf 100644 --- a/src/Worker/Grpc/GrpcDurableTaskWorker.cs +++ b/src/Worker/Grpc/GrpcDurableTaskWorker.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Diagnostics; using Microsoft.DurableTask.Worker.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; @@ -12,6 +13,8 @@ namespace Microsoft.DurableTask.Worker.Grpc; /// sealed partial class GrpcDurableTaskWorker : DurableTaskWorker { + static readonly TimeSpan DeferredDisposeGracePeriod = TimeSpan.FromSeconds(30); + readonly GrpcDurableTaskWorkerOptions grpcOptions; readonly DurableTaskWorkerOptions workerOptions; readonly IServiceProvider services; @@ -57,9 +60,146 @@ public GrpcDurableTaskWorker( /// protected override async Task ExecuteAsync(CancellationToken stoppingToken) { - await using AsyncDisposable disposable = this.GetCallInvoker(out CallInvoker callInvoker, out string address); - this.logger.StartingTaskHubWorker(address); - await new Processor(this, new(callInvoker), this.orchestrationFilter, this.ExceptionPropertiesProvider).ExecuteAsync(stoppingToken); + AsyncDisposable workerOwnedChannelDisposable = this.GetCallInvoker(out CallInvoker callInvoker, out string address); + + // Seed the tracker from the configured channel once, then update latestObservedChannel after + // each successful recreate. Do not re-read this.grpcOptions.Channel inside the loop: the options + // object keeps its original Channel reference even when a shared backend-channel cache has already + // swapped to a newer instance. + GrpcChannel? 
latestObservedChannel = this.grpcOptions.Channel; + try + { + this.logger.StartingTaskHubWorker(address); + + while (!stoppingToken.IsCancellationRequested) + { + Processor processor = new(this, new(callInvoker), this.orchestrationFilter, this.ExceptionPropertiesProvider); + ProcessorExitReason reason = await processor.ExecuteAsync(stoppingToken); + + if (reason == ProcessorExitReason.Shutdown || stoppingToken.IsCancellationRequested) + { + return; + } + + // ChannelRecreateRequested: try to obtain a fresh channel before re-entering the loop. + ChannelRecreateResult result = await this.TryRecreateChannelAsync(stoppingToken, workerOwnedChannelDisposable, latestObservedChannel); + if (result.Recreated) + { + this.ApplySuccessfulRecreate( + result, + ref callInvoker, + ref address, + ref latestObservedChannel, + ref workerOwnedChannelDisposable, + DeferredDisposeGracePeriod); + } + + // If we couldn't recreate (e.g., caller-owned CallInvoker), fall through and retry on the + // existing transport. The Processor's outer backoff already waited before returning. + } + } + finally + { + await workerOwnedChannelDisposable.DisposeAsync(); + } + } + + async Task TryRecreateChannelAsync( + CancellationToken cancellation, + AsyncDisposable currentWorkerOwnedDisposable, + GrpcChannel? currentChannel) + { + // There are three ownership models here: + // 1. A caller-supplied recreator owns shared/cache-backed channels and decides how to swap them. + // 2. The worker owns an Address-created channel and can rebuild it directly. + // 3. The caller owns the Channel/CallInvoker, so the worker can only keep retrying on the same transport. + + // Path 1: caller (or extension method like ConfigureGrpcChannel) supplied a recreator. + Func>? recreator = this.grpcOptions.Internal.ChannelRecreator; + if (recreator is not null && currentChannel is not null) + { + try + { + GrpcChannel newChannel = await recreator(currentChannel, cancellation).ConfigureAwait(false); + if (!ReferenceEquals(newChannel, currentChannel)) + { + // The recreator owns the replacement channel lifetime. Return a default disposable + // so the caller disposes the previous worker-owned channel exactly once without + // carrying that ownership forward to the recreated state. + return new ChannelRecreateResult(true, newChannel.CreateCallInvoker(), newChannel.Target, default, newChannel); + } + + // Recreator returned the same instance — nothing to swap. + return ChannelRecreateResult.NotRecreated(currentWorkerOwnedDisposable); + } + catch (OperationCanceledException) when (cancellation.IsCancellationRequested) + { + throw; + } + catch (Exception ex) when (!IsFatal(ex)) + { + // Don't crash the worker if recreate fails; just keep using the existing transport. + this.logger.UnexpectedError(ex, string.Empty); + return ChannelRecreateResult.NotRecreated(currentWorkerOwnedDisposable); + } + } + + // Path 2: worker-owned channel created from Address. We can rebuild it ourselves. + if (this.grpcOptions.Channel is null + && this.grpcOptions.CallInvoker is null) + { + try + { + GrpcChannel newChannel = GetChannel(this.grpcOptions.Address); + // This new channel is worker-owned, so hand back a disposable that will shut it down + // (and dispose it on frameworks where GrpcChannel implements IDisposable). 
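+                // (On .NET 6+ that disposable also calls Dispose after ShutdownAsync; older
+                // targets only get ShutdownAsync.)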
+ AsyncDisposable newDisposable = CreateOwnedChannelDisposable(newChannel); + return new ChannelRecreateResult(true, newChannel.CreateCallInvoker(), newChannel.Target, newDisposable, newChannel); + } + catch (OperationCanceledException) when (cancellation.IsCancellationRequested) + { + throw; + } + catch (Exception ex) when (!IsFatal(ex)) + { + this.logger.UnexpectedError(ex, string.Empty); + return ChannelRecreateResult.NotRecreated(currentWorkerOwnedDisposable); + } + } + + // Path 3: caller-owned CallInvoker or externally-supplied Channel without a recreator. + // No safe way to recreate; let the inner loop continue trying on the existing transport. + return ChannelRecreateResult.NotRecreated(currentWorkerOwnedDisposable); + + static bool IsFatal(Exception ex) => ex is OutOfMemoryException + or StackOverflowException + or AccessViolationException + or ThreadAbortException; + } + + readonly struct ChannelRecreateResult + { + public ChannelRecreateResult(bool recreated, CallInvoker? newCallInvoker, string? newAddress, AsyncDisposable newWorkerOwnedDisposable, GrpcChannel? newChannel) + { + this.Recreated = recreated; + this.NewCallInvoker = newCallInvoker; + this.NewAddress = newAddress; + this.NewWorkerOwnedDisposable = newWorkerOwnedDisposable; + this.NewChannel = newChannel; + } + + public bool Recreated { get; } + + public CallInvoker? NewCallInvoker { get; } + + public string? NewAddress { get; } + + public AsyncDisposable NewWorkerOwnedDisposable { get; } + + public GrpcChannel? NewChannel { get; } + + public static ChannelRecreateResult NotRecreated(AsyncDisposable currentDisposable) + => new(false, null, null, currentDisposable, null); } #if NET6_0_OR_GREATER @@ -86,6 +226,72 @@ static GrpcChannel GetChannel(string? address) } #endif + static AsyncDisposable CreateOwnedChannelDisposable(GrpcChannel channel) + { + return new AsyncDisposable(() => ShutdownAndDisposeChannelAsync(channel)); + } + + static async ValueTask ShutdownAndDisposeChannelAsync(GrpcChannel channel) + { + try + { + await channel.ShutdownAsync().ConfigureAwait(false); + } + finally + { +#if NET6_0_OR_GREATER + channel.Dispose(); +#endif + } + } + + void ApplySuccessfulRecreate( + ChannelRecreateResult result, + ref CallInvoker callInvoker, + ref string address, + ref GrpcChannel? latestObservedChannel, + ref AsyncDisposable workerOwnedChannelDisposable, + TimeSpan deferredDisposeGracePeriod) + { + callInvoker = result.NewCallInvoker!; + address = result.NewAddress!; + latestObservedChannel = result.NewChannel; + AsyncDisposable previousDisposable = workerOwnedChannelDisposable; + workerOwnedChannelDisposable = result.NewWorkerOwnedDisposable; + this.logger.ChannelRecreated(address); + + // Defer disposal of the prior worker-owned channel so background completion/abandon RPCs + // from the previous processor instance can drain before the transport is torn down. + // Path 1 hands ownership of the replacement channel to the recreator, Path 2 installs a + // fresh worker-owned disposable, and Path 3 never recreates at all. 
+ _ = ScheduleDeferredDisposeAsync(previousDisposable, deferredDisposeGracePeriod); + } + + static async Task ScheduleDeferredDisposeAsync(AsyncDisposable disposable, TimeSpan delay) + { + try + { + if (delay > TimeSpan.Zero) + { + await Task.Delay(delay).ConfigureAwait(false); + } + + await disposable.DisposeAsync().ConfigureAwait(false); + } + catch (Exception ex) when (ex is not OutOfMemoryException + and not StackOverflowException + and not AccessViolationException + and not ThreadAbortException) + { + if (ex is not OperationCanceledException and not ObjectDisposedException) + { + Trace.TraceError( + "Unexpected exception while deferred-disposing gRPC channel in GrpcDurableTaskWorker.ScheduleDeferredDisposeAsync: {0}", + ex); + } + } + } + AsyncDisposable GetCallInvoker(out CallInvoker callInvoker, out string address) { if (this.grpcOptions.Channel is GrpcChannel c) @@ -105,7 +311,7 @@ AsyncDisposable GetCallInvoker(out CallInvoker callInvoker, out string address) c = GetChannel(this.grpcOptions.Address); callInvoker = c.CreateCallInvoker(); address = c.Target; - return new AsyncDisposable(() => new(c.ShutdownAsync())); + return CreateOwnedChannelDisposable(c); } static ILogger CreateLogger(ILoggerFactory loggerFactory, DurableTaskWorkerOptions options) diff --git a/src/Worker/Grpc/GrpcDurableTaskWorkerOptions.cs b/src/Worker/Grpc/GrpcDurableTaskWorkerOptions.cs index 52372f65d..464c50a8b 100644 --- a/src/Worker/Grpc/GrpcDurableTaskWorkerOptions.cs +++ b/src/Worker/Grpc/GrpcDurableTaskWorkerOptions.cs @@ -97,5 +97,50 @@ internal class InternalOptions /// unlock events into the history when an orchestration terminates while holding an entity lock. /// public bool InsertEntityUnlocksOnCompletion { get; set; } + + /// + /// Gets or sets the maximum amount of time to wait for the initial Hello handshake against the + /// backend before treating the connect attempt as failed and retrying. A non-positive value disables + /// the deadline. Defaults to 30 seconds. This guards against half-open HTTP/2 connections that can + /// otherwise cause reconnect to hang indefinitely. + /// + public TimeSpan HelloDeadline { get; set; } = TimeSpan.FromSeconds(30); + + /// + /// Gets or sets the maximum amount of time the worker will wait between messages on an established + /// work-item stream before treating the channel as silently disconnected and forcing a reconnect. + /// The backend sends periodic health-ping work items expressly to keep this window alive when no + /// real work is flowing, so this value should be larger than the server's ping cadence to avoid + /// false positives. Defaults to 120 seconds. A non-positive value disables silent-disconnect detection. + /// + public TimeSpan SilentDisconnectTimeout { get; set; } = TimeSpan.FromSeconds(120); + + /// + /// Gets or sets the number of consecutive connect failures (Hello timeouts, Unavailable responses, or + /// silent stream disconnects) after which the underlying gRPC channel will be recreated to clear + /// stale DNS, sub-channel state, or routing-affinity bindings. Setting to 0 or a negative value + /// disables channel recreation. Defaults to 5. + /// + public int ChannelRecreateFailureThreshold { get; set; } = 5; + + /// + /// Gets or sets the base delay used when computing reconnect backoff with full jitter: + /// the actual delay is uniformly random in [0, min(cap, base * 2^attempt)]. Defaults to 1 second. 
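+    /// For example, with the default 1-second base and 30-second cap, attempt 3 draws uniformly
+    /// from [0, 8s], while attempt 5 and beyond are clamped to the full [0, 30s] window.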
+ /// + public TimeSpan ReconnectBackoffBase { get; set; } = TimeSpan.FromSeconds(1); + + /// + /// Gets or sets the maximum delay used when computing reconnect backoff with full jitter. + /// Defaults to 30 seconds. + /// + public TimeSpan ReconnectBackoffCap { get; set; } = TimeSpan.FromSeconds(30); + + /// + /// Gets or sets an optional callback invoked when the worker requests a fresh gRPC channel after + /// repeated connect failures. The callback receives the previously-used channel and should return + /// its replacement. Implementations are responsible for publishing the replacement channel and + /// deferring disposal of the old channel so in-flight RPCs already using it are not interrupted. + /// + public Func>? ChannelRecreator { get; set; } } } diff --git a/src/Worker/Grpc/Internal/InternalOptionsExtensions.cs b/src/Worker/Grpc/Internal/InternalOptionsExtensions.cs index 0739c15c2..b26b36cc6 100644 --- a/src/Worker/Grpc/Internal/InternalOptionsExtensions.cs +++ b/src/Worker/Grpc/Internal/InternalOptionsExtensions.cs @@ -27,4 +27,60 @@ public static void ConfigureForAzureManaged(this GrpcDurableTaskWorkerOptions op options.Internal.ConvertOrchestrationEntityEvents = true; options.Internal.InsertEntityUnlocksOnCompletion = true; } + + /// + /// Sets a callback that the worker invokes when the underlying gRPC channel needs to be recreated + /// after repeated connect failures (e.g., because the backend was replaced and the existing channel + /// is wedged on a half-open HTTP/2 connection). The callback receives the channel the worker last + /// observed and must return its replacement. Implementations are responsible for publishing the + /// replacement channel and deferring disposal of the old channel so in-flight RPCs already using it + /// are not interrupted. + /// + /// The gRPC worker options. + /// The recreate callback. + /// + /// This is an internal API that supports the DurableTask infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new DurableTask release. + /// + public static void SetChannelRecreator( + this GrpcDurableTaskWorkerOptions options, + Func> recreator) + { + options.Internal.ChannelRecreator = recreator ?? throw new ArgumentNullException(nameof(recreator)); + } + + /// + /// Sets the deadline applied to the initial Hello RPC during worker connect. A wedged + /// handshake on a half-open HTTP/2 connection no longer hangs the reconnect loop indefinitely. + /// + /// The gRPC worker options. + /// The deadline; non-positive disables the deadline. + /// + /// This is an internal API that supports the DurableTask infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. + /// + public static void SetHelloDeadline(this GrpcDurableTaskWorkerOptions options, TimeSpan deadline) + { + options.Internal.HelloDeadline = deadline; + } + + /// + /// Sets the silent-disconnect timeout. If no message (including health pings) arrives on the + /// work-item stream within this window, the worker treats the stream as silently disconnected + /// and reconnects. + /// + /// The gRPC worker options. + /// The timeout; non-positive disables silent-disconnect detection. 
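+    /// For example, SetSilentDisconnectTimeout(TimeSpan.FromMinutes(5)) widens the window for
+    /// backends that ping infrequently, and TimeSpan.Zero switches detection off.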
+    ///
+    /// This is an internal API that supports the DurableTask infrastructure and not subject to
+    /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+    /// any release.
+    ///
+    public static void SetSilentDisconnectTimeout(this GrpcDurableTaskWorkerOptions options, TimeSpan timeout)
+    {
+        options.Internal.SilentDisconnectTimeout = timeout;
+    }
 }
diff --git a/src/Worker/Grpc/Logs.cs b/src/Worker/Grpc/Logs.cs
index b7d1f957e..ea585dcfa 100644
--- a/src/Worker/Grpc/Logs.cs
+++ b/src/Worker/Grpc/Logs.cs
@@ -79,5 +79,26 @@ static partial class Logs
         [LoggerMessage(EventId = 65, Level = LogLevel.Information, Message = "{instanceId}: Abandoned entity work item. Completion token = '{completionToken}'")]
         public static partial void AbandonedEntityWorkItem(this ILogger logger, string instanceId, string completionToken);
+
+        [LoggerMessage(EventId = 70, Level = LogLevel.Warning, Message = "Hello handshake to backend timed out after {timeout}. Will retry.")]
+        public static partial void HelloTimeout(this ILogger logger, TimeSpan timeout);
+
+        [LoggerMessage(EventId = 71, Level = LogLevel.Warning, Message = "Authentication failed when connecting to backend. Will retry.")]
+        public static partial void AuthenticationFailed(this ILogger logger, Exception ex);
+
+        [LoggerMessage(EventId = 72, Level = LogLevel.Warning, Message = "Recreating gRPC channel to backend after {failureCount} consecutive connect failures.")]
+        public static partial void RecreatingChannel(this ILogger logger, int failureCount);
+
+        [LoggerMessage(EventId = 73, Level = LogLevel.Information, Message = "gRPC channel to backend has been recreated. New target: {endpoint}.")]
+        public static partial void ChannelRecreated(this ILogger logger, string endpoint);
+
+        [LoggerMessage(EventId = 74, Level = LogLevel.Debug, Message = "Reconnect attempt {attempt} will retry after {delayMs} ms.")]
+        public static partial void ReconnectBackoff(this ILogger logger, int attempt, int delayMs);
+
+        [LoggerMessage(EventId = 75, Level = LogLevel.Trace, Message = "Received health ping from the backend.")]
+        public static partial void ReceivedHealthPing(this ILogger logger);
+
+        [LoggerMessage(EventId = 76, Level = LogLevel.Information, Message = "Work-item stream ended by the backend (graceful close). Will reconnect.")]
+        public static partial void StreamEndedByPeer(this ILogger logger);
     }
 }
diff --git a/src/Worker/Grpc/ProcessorExitReason.cs b/src/Worker/Grpc/ProcessorExitReason.cs
new file mode 100644
index 000000000..144ae4304
--- /dev/null
+++ b/src/Worker/Grpc/ProcessorExitReason.cs
@@ -0,0 +1,22 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+namespace Microsoft.DurableTask.Worker.Grpc;
+
+///
+/// Indicates why the worker's internal processor loop returned to its caller.
+///
+enum ProcessorExitReason
+{
+    ///
+    /// The processor exited because cancellation was requested (graceful shutdown).
+    ///
+    Shutdown,
+
+    ///
+    /// The processor exited because the underlying gRPC channel appears poisoned and should be recreated
+    /// before the next reconnect attempt. The caller is expected to obtain a fresh channel and rebuild the
+    /// processor.
+    ///
+    ChannelRecreateRequested,
+}
diff --git a/src/Worker/Grpc/ReconnectBackoff.cs b/src/Worker/Grpc/ReconnectBackoff.cs
new file mode 100644
index 000000000..dd08a58ce
--- /dev/null
+++ b/src/Worker/Grpc/ReconnectBackoff.cs
@@ -0,0 +1,60 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
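+
+// Design note: this implements "full jitter" backoff. Rather than sleeping exactly
+// base * 2^attempt, each worker sleeps a uniformly random fraction of that (capped)
+// bound, so a fleet of workers that observed the same outage spreads its reconnect
+// attempts instead of retrying in lockstep.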
+
+using System.Security.Cryptography;
+
+namespace Microsoft.DurableTask.Worker.Grpc;
+
+///
+/// Helpers for computing reconnect backoff delays in the gRPC worker.
+///
+static class ReconnectBackoff
+{
+    ///
+    /// Creates a random source for reconnect jitter using an explicit random seed so multiple workers on
+    /// older runtimes don't converge on the same time-based seed.
+    ///
+    /// A random source suitable for reconnect jitter.
+    public static Random CreateRandom()
+    {
+        byte[] seedBytes = new byte[sizeof(int)];
+        using RandomNumberGenerator randomNumberGenerator = RandomNumberGenerator.Create();
+        randomNumberGenerator.GetBytes(seedBytes);
+        return new Random(BitConverter.ToInt32(seedBytes, 0));
+    }
+
+    ///
+    /// Computes a full-jitter exponential backoff delay: a uniformly random TimeSpan in
+    /// [0, min(cap, base * 2^attempt)]. Returns TimeSpan.Zero when baseDelay
+    /// or cap is non-positive.
+    ///
+    /// The retry attempt index, starting at 0.
+    /// The base delay used for the exponential growth.
+    /// The maximum delay before jitter is applied.
+    /// The random source used for jitter.
+    /// The computed jittered delay.
+    public static TimeSpan Compute(int attempt, TimeSpan baseDelay, TimeSpan cap, Random random)
+    {
+        if (baseDelay <= TimeSpan.Zero)
+        {
+            return TimeSpan.Zero;
+        }
+
+        if (attempt < 0)
+        {
+            attempt = 0;
+        }
+
+        // Cap the exponent to avoid overflow in 2^attempt for pathological attempt values.
+        int safeAttempt = Math.Min(attempt, 30);
+
+        double capMs = Math.Max(0, cap.TotalMilliseconds);
+        double exponentialMs = baseDelay.TotalMilliseconds * Math.Pow(2, safeAttempt);
+        double upperBoundMs = Math.Min(capMs, exponentialMs);
+
+        // Full jitter intentionally allows any value in the retry window. The wide spread keeps many
+        // workers that saw the same outage from reconnecting in lockstep against the backend.
+        double jitteredMs = random.NextDouble() * upperBoundMs;
+        return TimeSpan.FromMilliseconds(jitteredMs);
+    }
+}
diff --git a/src/Worker/Grpc/WorkItemStreamConsumer.cs b/src/Worker/Grpc/WorkItemStreamConsumer.cs
new file mode 100644
index 000000000..cce0b23af
--- /dev/null
+++ b/src/Worker/Grpc/WorkItemStreamConsumer.cs
@@ -0,0 +1,152 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using P = Microsoft.DurableTask.Protobuf;
+
+namespace Microsoft.DurableTask.Worker.Grpc;
+
+///
+/// Reason a ConsumeAsync invocation terminated.
+///
+internal enum WorkItemStreamOutcome
+{
+    /// The outer cancellation token was signalled (worker shutdown).
+    Shutdown,
+
+    /// The silent-disconnect timer fired (no item or health ping arrived in time).
+    SilentDisconnect,
+
+    /// The stream completed without exception (e.g. server-initiated graceful close).
+    GracefulDrain,
+}
+
+/// Result of consuming a work-item stream.
+/// Why the loop terminated.
+/// Whether at least one message was delivered before termination.
+internal readonly record struct WorkItemStreamResult(WorkItemStreamOutcome Outcome, bool FirstMessageObserved);
+
+///
+/// Consumes a work-item stream and classifies its termination. Owns the silent-disconnect timeout
+/// wiring and the catch chain that distinguishes a wedged stream (silent disconnect) from a normal
+/// graceful drain or a worker shutdown. Per-item dispatch is delegated to the caller via the
+/// onItem callback.
+///
+///
+/// The onItem callback is synchronous because production dispatch is fire-and-forget.
+///
+internal static class WorkItemStreamConsumer
+{
+    // Stay just below the historical CancelAfter(TimeSpan) ceiling so extremely large configuration
+    // values are still treated as "effectively infinite" without depending on framework-specific edge cases.
+    static readonly TimeSpan MaxSupportedCancelAfterTimeout = TimeSpan.FromMilliseconds(int.MaxValue - 1d);
+
+    ///
+    /// Consumes a work-item stream until shutdown, silent disconnect, or graceful drain.
+    ///
+    ///
+    /// Factory that opens the stream with the supplied linked-cancellation token. Production passes
+    /// ct => stream.ResponseStream.ReadAllAsync(ct); tests pass arbitrary fakes.
+    ///
+    ///
+    /// How long to wait between successive items (or health pings) before treating the stream as
+    /// silently disconnected. Non-positive values disable detection entirely.
+    ///
+    ///
+    /// Synchronous per-item dispatch. Invoked once per delivered work item, after the silent-disconnect
+    /// timer has been re-armed.
+    ///
+    ///
+    /// Optional callback invoked exactly once when the first message is observed. Used by callers to
+    /// reset retry counters that should only count consecutive transport failures.
+    ///
+    /// Outer worker cancellation token.
+    /// The classified outcome plus whether any message was observed.
+    public static async Task<WorkItemStreamResult> ConsumeAsync(
+        Func<CancellationToken, IAsyncEnumerable<P.WorkItem>> openStream,
+        TimeSpan silentDisconnectTimeout,
+        Action<P.WorkItem> onItem,
+        Action? onFirstMessage,
+        CancellationToken cancellation)
+    {
+        bool silentDisconnectEnabled = silentDisconnectTimeout > TimeSpan.Zero;
+
+        // Clamp enormous values once up-front so the timer-reset path can simply re-arm the same window.
+        TimeSpan effectiveTimeout = ClampCancelAfterTimeout(silentDisconnectTimeout);
+
+        using CancellationTokenSource timeoutSource = new();
+        void ArmSilentDisconnectTimer()
+        {
+            if (silentDisconnectEnabled)
+            {
+                timeoutSource.CancelAfter(effectiveTimeout);
+            }
+        }
+
+        // Arm once before reading so the initial gap before the first message is also bounded.
+        ArmSilentDisconnectTimer();
+
+        using CancellationTokenSource tokenSource = CancellationTokenSource.CreateLinkedTokenSource(
+            cancellation, timeoutSource.Token);
+
+        bool firstMessageObserved = false;
+
+        try
+        {
+            await foreach (P.WorkItem workItem in openStream(tokenSource.Token).ConfigureAwait(false))
+            {
+                ArmSilentDisconnectTimer();
+
+                if (!firstMessageObserved)
+                {
+                    firstMessageObserved = true;
+                    onFirstMessage?.Invoke();
+                }
+
+                onItem(workItem);
+            }
+        }
+        catch (OperationCanceledException) when (cancellation.IsCancellationRequested)
+        {
+            // Worker is shutting down.
+        }
+        catch (OperationCanceledException) when (timeoutSource.IsCancellationRequested)
+        {
+            // Silent-disconnect timer fired and grpc-dotnet surfaced cancellation as OCE
+            // (when GrpcChannelOptions.ThrowOperationCanceledOnCancellation == true).
+        }
+        catch (RpcException ex) when (ex.StatusCode == StatusCode.Cancelled
+            && timeoutSource.IsCancellationRequested
+            && !cancellation.IsCancellationRequested)
+        {
+            // Silent-disconnect timer fired mid-MoveNext. By default
+            // (GrpcChannelOptions.ThrowOperationCanceledOnCancellation == false), grpc-dotnet
+            // surfaces the linked cancellation as RpcException(Cancelled) rather than OCE.
+            // Without this catch the exception would propagate past the silent-disconnect
+            // branch and the recreate path would never fire.
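+            // Either way, execution falls through to the classification below, where a fired
+            // timeoutSource maps the termination to WorkItemStreamOutcome.SilentDisconnect.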
+ } + + if (cancellation.IsCancellationRequested) + { + return new WorkItemStreamResult(WorkItemStreamOutcome.Shutdown, firstMessageObserved); + } + + if (timeoutSource.IsCancellationRequested) + { + return new WorkItemStreamResult(WorkItemStreamOutcome.SilentDisconnect, firstMessageObserved); + } + + return new WorkItemStreamResult(WorkItemStreamOutcome.GracefulDrain, firstMessageObserved); + } + + static TimeSpan ClampCancelAfterTimeout(TimeSpan timeout) + { + if (timeout <= TimeSpan.Zero) + { + return timeout; + } + + return timeout <= MaxSupportedCancelAfterTimeout + ? timeout + : MaxSupportedCancelAfterTimeout; + } +} diff --git a/test/Client/Grpc.Tests/GrpcDurableTaskClientChannelRecreationTests.cs b/test/Client/Grpc.Tests/GrpcDurableTaskClientChannelRecreationTests.cs new file mode 100644 index 000000000..29d3e9105 --- /dev/null +++ b/test/Client/Grpc.Tests/GrpcDurableTaskClientChannelRecreationTests.cs @@ -0,0 +1,289 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using System.Reflection; +using System.Text; +using Grpc.Core; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client.Grpc.Internal; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Microsoft.DurableTask.Client.Grpc.Tests; + +public class GrpcDurableTaskClientChannelRecreationTests +{ + static readonly Marshaller StringMarshaller = Marshallers.Create( + value => Encoding.UTF8.GetBytes(value), + bytes => Encoding.UTF8.GetString(bytes)); + static readonly Method TestMethod = new( + MethodType.Unary, + "TestService", + "TestMethod", + StringMarshaller, + StringMarshaller); + static readonly MethodInfo GetCallInvokerMethod = typeof(GrpcDurableTaskClient) + .GetMethod("GetCallInvoker", BindingFlags.Static | BindingFlags.NonPublic)!; + static readonly MethodInfo ToStopwatchTicksMethod = typeof(ChannelRecreatingCallInvoker) + .GetMethod("ToStopwatchTicks", BindingFlags.Static | BindingFlags.NonPublic)!; + + [Fact] + public async Task GetCallInvoker_WithProvidedChannel_RecreatesTransportAfterUnaryFailure() + { + // Arrange + CallbackHttpMessageHandler initialHandler = new((request, cancellationToken) => + Task.FromResult(CreateFailureResponse(StatusCode.Unavailable, "initial transport failure"))); + TaskCompletionSource recreatedTransportUsed = new(TaskCreationOptions.RunContinuationsAsynchronously); + CallbackHttpMessageHandler recreatedHandler = new((request, cancellationToken) => + { + recreatedTransportUsed.TrySetResult(); + return Task.FromResult(CreateFailureResponse(StatusCode.Unavailable, "recreated transport failure")); + }); + + GrpcChannel channel = CreateChannel("http://initial.client.test", initialHandler); + GrpcChannel recreatedChannel = CreateChannel("http://recreated.client.test", recreatedHandler); + GrpcDurableTaskClientOptions options = new() + { + Channel = channel, + }; + options.Internal.ChannelRecreateFailureThreshold = 2; + options.Internal.MinRecreateInterval = TimeSpan.Zero; + + TaskCompletionSource recreateRequested = new(TaskCreationOptions.RunContinuationsAsynchronously); + int recreatorCalls = 0; + options.SetChannelRecreator((existingChannel, ct) => + { + recreatorCalls++; + recreateRequested.TrySetResult(existingChannel); + return Task.FromResult(recreatedChannel); + }); + + try + { + // Act + (AsyncDisposable disposable, CallInvoker callInvoker) = InvokeGetCallInvoker(options); + + try + { + callInvoker.Should().BeOfType(); + GetOwnsChannel(callInvoker).Should().BeFalse(); + + // Act + await 
AssertRpcFailureAsync(callInvoker); + await AssertRpcFailureAsync(callInvoker); + await recreateRequested.Task.WaitAsync(TimeSpan.FromSeconds(5)); + await AssertRpcFailureAsync(callInvoker); + await recreatedTransportUsed.Task.WaitAsync(TimeSpan.FromSeconds(5)); + + // Assert + initialHandler.CallCount.Should().Be(2); + recreatedHandler.CallCount.Should().Be(1); + recreatorCalls.Should().Be(1); + } + finally + { + await disposable.DisposeAsync(); + } + } + finally + { + await DisposeChannelAsync(channel); + await DisposeChannelAsync(recreatedChannel); + } + } + + [Fact] + public async Task GetCallInvoker_WithAddressAndRecreator_UsesWrapperThatOwnsCreatedChannel() + { + // Arrange + GrpcDurableTaskClientOptions options = new() + { + Address = "http://owned.client.test", + }; + options.SetChannelRecreator((existingChannel, ct) => Task.FromResult(existingChannel)); + + // Act + (AsyncDisposable disposable, CallInvoker callInvoker) = InvokeGetCallInvoker(options); + + try + { + // Assert + callInvoker.Should().BeOfType(); + GetOwnsChannel(callInvoker).Should().BeTrue(); + } + finally + { + await disposable.DisposeAsync(); + } + } + + [Fact] + public async Task CreateRecreateCancellationSource_WhenDisposedDuringRecreateWindow_ReturnsCanceledTokenSource() + { + // Arrange + GrpcChannel channel = GrpcChannel.ForAddress("http://disposed-race.client.test"); + GrpcDurableTaskClientOptions options = new() + { + Channel = channel, + }; + options.SetChannelRecreator((existingChannel, ct) => Task.FromResult(existingChannel)); + + try + { + (AsyncDisposable disposable, CallInvoker callInvoker) = InvokeGetCallInvoker(options); + + try + { + ChannelRecreatingCallInvoker wrapper = callInvoker.Should().BeOfType().Subject; + MethodInfo? method = typeof(ChannelRecreatingCallInvoker).GetMethod( + "CreateRecreateCancellationSource", + BindingFlags.Instance | BindingFlags.NonPublic); + method.Should().NotBeNull(); + + SetDisposed(wrapper, 1); + GetDisposalCancellationSource(wrapper).Dispose(); + + // Act + using CancellationTokenSource recreateCts = + (CancellationTokenSource)method!.Invoke(wrapper, Array.Empty())!; + + // Assert + recreateCts.IsCancellationRequested.Should().BeTrue(); + } + finally + { + await disposable.DisposeAsync(); + } + } + finally + { + await DisposeChannelAsync(channel); + } + } + + [Theory] + [InlineData(0, 0)] + [InlineData(-1, 0)] + public void ToStopwatchTicks_NonPositiveInterval_ReturnsZero(long ticks, long expected) + { + // Arrange + TimeSpan interval = TimeSpan.FromTicks(ticks); + + // Act + long stopwatchTicks = InvokeToStopwatchTicks(interval); + + // Assert + stopwatchTicks.Should().Be(expected); + } + + [Fact] + public void ToStopwatchTicks_VeryLargeInterval_SaturatesAtLongMaxValue() + { + // Arrange + TimeSpan interval = TimeSpan.MaxValue; + + // Act + long stopwatchTicks = InvokeToStopwatchTicks(interval); + + // Assert + stopwatchTicks.Should().Be(long.MaxValue); + } + + static (AsyncDisposable Disposable, CallInvoker CallInvoker) InvokeGetCallInvoker(GrpcDurableTaskClientOptions options) + { + object?[] args = { options, NullLogger.Instance, null }; + AsyncDisposable disposable = (AsyncDisposable)GetCallInvokerMethod.Invoke(null, args)!; + CallInvoker callInvoker = (CallInvoker)args[2]!; + return (disposable, callInvoker); + } + + static bool GetOwnsChannel(CallInvoker callInvoker) + { + return (bool)typeof(ChannelRecreatingCallInvoker) + .GetField("ownsChannel", BindingFlags.Instance | BindingFlags.NonPublic)! 
+ .GetValue(callInvoker)!; + } + + static CancellationTokenSource GetDisposalCancellationSource(CallInvoker callInvoker) + { + return (CancellationTokenSource)typeof(ChannelRecreatingCallInvoker) + .GetField("disposalCts", BindingFlags.Instance | BindingFlags.NonPublic)! + .GetValue(callInvoker)!; + } + + static void SetDisposed(CallInvoker callInvoker, int value) + { + typeof(ChannelRecreatingCallInvoker) + .GetField("disposed", BindingFlags.Instance | BindingFlags.NonPublic)! + .SetValue(callInvoker, value); + } + + static long InvokeToStopwatchTicks(TimeSpan interval) + { + return (long)ToStopwatchTicksMethod.Invoke(null, new object?[] { interval })!; + } + + static async Task AssertRpcFailureAsync(CallInvoker callInvoker) + { + Func act = async () => + { + using AsyncUnaryCall call = callInvoker.AsyncUnaryCall( + TestMethod, + host: null, + new CallOptions(deadline: DateTime.UtcNow.AddSeconds(1)), + request: "ping"); + + await call.ResponseAsync; + }; + + RpcException rpcException = (await act.Should().ThrowAsync()).Which; + rpcException.StatusCode.Should().Be(StatusCode.Unavailable); + } + + static GrpcChannel CreateChannel(string address, HttpMessageHandler handler) + { + return GrpcChannel.ForAddress(address, new GrpcChannelOptions + { + HttpHandler = handler, + }); + } + + static async ValueTask DisposeChannelAsync(GrpcChannel channel) + { + await channel.ShutdownAsync(); + channel.Dispose(); + } + + static HttpResponseMessage CreateFailureResponse(StatusCode statusCode, string detail) + { + HttpResponseMessage response = new(System.Net.HttpStatusCode.OK) + { + Version = new Version(2, 0), + Content = new ByteArrayContent([]), + }; + + response.Content.Headers.ContentType = new System.Net.Http.Headers.MediaTypeHeaderValue("application/grpc"); + response.TrailingHeaders.Add("grpc-status", ((int)statusCode).ToString()); + response.TrailingHeaders.Add("grpc-message", detail); + return response; + } + + sealed class CallbackHttpMessageHandler : HttpMessageHandler + { + readonly Func> callback; + int callCount; + + public CallbackHttpMessageHandler(Func> callback) + { + this.callback = callback; + } + + public int CallCount => Volatile.Read(ref this.callCount); + + protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + Interlocked.Increment(ref this.callCount); + return this.callback(request, cancellationToken); + } + } +} diff --git a/test/Worker/AzureManaged.Tests/DurableTaskSchedulerWorkerExtensionsTests.cs b/test/Worker/AzureManaged.Tests/DurableTaskSchedulerWorkerExtensionsTests.cs index 6e08347ee..0c1d72805 100644 --- a/test/Worker/AzureManaged.Tests/DurableTaskSchedulerWorkerExtensionsTests.cs +++ b/test/Worker/AzureManaged.Tests/DurableTaskSchedulerWorkerExtensionsTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
+using System.Reflection; using Azure.Core; using Azure.Identity; using FluentAssertions; @@ -292,6 +293,37 @@ public async Task UseDurableTaskScheduler_ServiceProviderDispose_DisposesChannel newOptions.Channel.Should().NotBeSameAs(channel, "new provider should create a new channel"); } + [Fact] + public async Task UseDurableTaskScheduler_ChannelRecreator_ReplacesCachedChannel() + { + // Arrange + ServiceCollection services = new ServiceCollection(); + Mock mockBuilder = new Mock(); + mockBuilder.Setup(b => b.Services).Returns(services); + DefaultAzureCredential credential = new DefaultAzureCredential(); + mockBuilder.Object.UseDurableTaskScheduler(ValidEndpoint, ValidTaskHub, credential, options => + { + options.WorkerId = "worker-id-1"; + }); + + await using ServiceProvider provider = services.BuildServiceProvider(); + IOptionsFactory optionsFactory = provider.GetRequiredService>(); + GrpcDurableTaskWorkerOptions options = optionsFactory.Create(Options.DefaultName); + GrpcChannel originalChannel = options.Channel!; + + // Act + Func>? recreator = GetChannelRecreator(options); + recreator.Should().NotBeNull(); + GrpcChannel recreatedChannel = await recreator!(originalChannel, CancellationToken.None); + GrpcChannel staleRecreateResult = await recreator(originalChannel, CancellationToken.None); + GrpcDurableTaskWorkerOptions refreshedOptions = optionsFactory.Create(Options.DefaultName); + + // Assert + recreatedChannel.Should().NotBeSameAs(originalChannel); + staleRecreateResult.Should().BeSameAs(recreatedChannel, "a stale recreate callback should keep the already-cached replacement channel"); + refreshedOptions.Channel.Should().BeSameAs(recreatedChannel, "the worker recreator should replace the cached channel for the same worker entry"); + } + [Fact] public async Task UseDurableTaskScheduler_ConfigureAfterDispose_ThrowsObjectDisposedException() { @@ -448,5 +480,16 @@ public async Task UseDurableTaskScheduler_DifferentWorkerId_UsesSeparateChannels options2.Channel.Should().NotBeNull(); options1.Channel.Should().NotBeSameAs(options2.Channel, "different WorkerId should use different channels"); } + + static Func>? GetChannelRecreator(GrpcDurableTaskWorkerOptions options) + { + object internalOptions = typeof(GrpcDurableTaskWorkerOptions) + .GetProperty("Internal", BindingFlags.Instance | BindingFlags.NonPublic)! + .GetValue(options)!; + + return (Func>?)internalOptions.GetType() + .GetProperty("ChannelRecreator", BindingFlags.Instance | BindingFlags.Public)! + .GetValue(internalOptions); + } } diff --git a/test/Worker/Grpc.Tests/GrpcDurableTaskWorkerOptionsInternalTests.cs b/test/Worker/Grpc.Tests/GrpcDurableTaskWorkerOptionsInternalTests.cs new file mode 100644 index 000000000..5813bc683 --- /dev/null +++ b/test/Worker/Grpc.Tests/GrpcDurableTaskWorkerOptionsInternalTests.cs @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +using Microsoft.DurableTask.Worker.Grpc.Internal; + +namespace Microsoft.DurableTask.Worker.Grpc.Tests; + +public class GrpcDurableTaskWorkerOptionsInternalTests +{ + [Fact] + public void InternalOptions_HasSafeDefaults() + { + // Arrange + GrpcDurableTaskWorkerOptions options = new(); + + // Act + GrpcDurableTaskWorkerOptions.InternalOptions internalOptions = options.Internal; + + // Assert + internalOptions.HelloDeadline.Should().Be(TimeSpan.FromSeconds(30)); + internalOptions.ChannelRecreateFailureThreshold.Should().Be(5); + internalOptions.ReconnectBackoffBase.Should().Be(TimeSpan.FromSeconds(1)); + internalOptions.ReconnectBackoffCap.Should().Be(TimeSpan.FromSeconds(30)); + internalOptions.SilentDisconnectTimeout.Should().Be(TimeSpan.FromSeconds(120)); + internalOptions.ChannelRecreator.Should().BeNull(); + } + + [Fact] + public void SetChannelRecreator_NullCallback_Throws() + { + // Arrange + GrpcDurableTaskWorkerOptions options = new(); + + // Act + Action act = () => options.SetChannelRecreator(null!); + + // Assert + act.Should().Throw(); + } + + [Fact] + public void SetChannelRecreator_StoresCallbackOnInternalOptions() + { + // Arrange + GrpcDurableTaskWorkerOptions options = new(); + bool invoked = false; + Func> recreator = (channel, ct) => + { + invoked = true; + return Task.FromResult(channel); + }; + + // Act + options.SetChannelRecreator(recreator); + + // Assert + options.Internal.ChannelRecreator.Should().BeSameAs(recreator); + + // Sanity-check that invoking the stored delegate calls the original. + options.Internal.ChannelRecreator!.Invoke(null!, CancellationToken.None); + invoked.Should().BeTrue(); + } +} diff --git a/test/Worker/Grpc.Tests/GrpcDurableTaskWorkerTests.cs b/test/Worker/Grpc.Tests/GrpcDurableTaskWorkerTests.cs new file mode 100644 index 000000000..db6c98da5 --- /dev/null +++ b/test/Worker/Grpc.Tests/GrpcDurableTaskWorkerTests.cs @@ -0,0 +1,814 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.IO; +using System.Reflection; +using Google.Protobuf.WellKnownTypes; +using Grpc.Core; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Tests.Logging; +using Microsoft.DurableTask.Worker; +using Microsoft.DurableTask.Worker.Grpc.Internal; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using P = Microsoft.DurableTask.Protobuf; + +namespace Microsoft.DurableTask.Worker.Grpc.Tests; + +public class GrpcDurableTaskWorkerTests +{ + const string Category = "Microsoft.DurableTask.Worker.Grpc"; + static readonly MethodInfo ExecuteAsyncMethod = typeof(GrpcDurableTaskWorker) + .GetMethod("ExecuteAsync", BindingFlags.Instance | BindingFlags.NonPublic)!; + static readonly MethodInfo ApplySuccessfulRecreateMethod = typeof(GrpcDurableTaskWorker) + .GetMethod("ApplySuccessfulRecreate", BindingFlags.Instance | BindingFlags.NonPublic)!; + static readonly MethodInfo ProcessorExecuteAsyncMethod = typeof(GrpcDurableTaskWorker) + .GetNestedType("Processor", BindingFlags.NonPublic)! + .GetMethod("ExecuteAsync", BindingFlags.Instance | BindingFlags.Public)!; + static readonly MethodInfo ProcessorConnectAsyncMethod = typeof(GrpcDurableTaskWorker) + .GetNestedType("Processor", BindingFlags.NonPublic)! 
+ .GetMethod("ConnectAsync", BindingFlags.Instance | BindingFlags.NonPublic)!; + static readonly MethodInfo TryRecreateChannelAsyncMethod = typeof(GrpcDurableTaskWorker) + .GetMethod("TryRecreateChannelAsync", BindingFlags.Instance | BindingFlags.NonPublic)!; + + [Fact] + public async Task ExecuteAsync_ConnectFailureThreshold_RecreatesConfiguredChannel() + { + // Arrange + CallbackHttpMessageHandler initialHandler = new((request, cancellationToken) => + Task.FromResult(CreateFailureResponse(StatusCode.Unavailable, "initial transport failure"))); + TaskCompletionSource recreatedTransportUsed = new(TaskCreationOptions.RunContinuationsAsynchronously); + CallbackHttpMessageHandler recreatedHandler = new(async (request, cancellationToken) => + { + recreatedTransportUsed.TrySetResult(); + await Task.Delay(Timeout.Infinite, cancellationToken); + return CreateFailureResponse(StatusCode.Cancelled, "recreated transport cancelled"); + }); + + GrpcChannel currentChannel = CreateChannel("http://initial.worker.test", initialHandler); + GrpcChannel recreatedChannel = CreateChannel("http://recreated.worker.test", recreatedHandler); + GrpcDurableTaskWorkerOptions grpcOptions = new() + { + Channel = currentChannel, + }; + grpcOptions.Internal.ChannelRecreateFailureThreshold = 2; + grpcOptions.Internal.ReconnectBackoffBase = TimeSpan.Zero; + grpcOptions.Internal.ReconnectBackoffCap = TimeSpan.Zero; + + DurableTaskWorkerOptions workerOptions = new() + { + Logging = { UseLegacyCategories = false }, + }; + TestLogProvider logProvider = new(new NullOutput()); + using CancellationTokenSource stoppingToken = new(); + int recreatorCalls = 0; + grpcOptions.SetChannelRecreator((channel, ct) => + { + recreatorCalls++; + return Task.FromResult(recreatedChannel); + }); + + GrpcDurableTaskWorker worker = CreateWorker(grpcOptions, workerOptions, new SimpleLoggerFactory(logProvider)); + + try + { + // Act + Task executeTask = InvokeExecuteAsync(worker, stoppingToken.Token); + await recreatedTransportUsed.Task.WaitAsync(TimeSpan.FromSeconds(5)); + stoppingToken.Cancel(); + await executeTask; + + // Assert + recreatorCalls.Should().Be(1); + initialHandler.CallCount.Should().Be(2); + recreatedHandler.CallCount.Should().Be(1); + logProvider.TryGetLogs(Category, out IReadOnlyCollection? 
logs).Should().BeTrue(); + logs!.Should().Contain(log => + log.Message.Contains("gRPC channel to backend has been recreated") + && log.Message.Contains(recreatedChannel.Target)); + } + finally + { + await DisposeChannelAsync(currentChannel); + await DisposeChannelAsync(recreatedChannel); + } + } + + [Fact] + public async Task ExecuteAsync_TransportResetDuringHello_RecreatesConfiguredChannel() + { + // Arrange + CallbackHttpMessageHandler initialHandler = new((request, cancellationToken) => + throw new HttpRequestException( + "connection reset by peer", + new IOException("An existing connection was forcibly closed by the remote host."))); + TaskCompletionSource recreatedTransportUsed = new(TaskCreationOptions.RunContinuationsAsynchronously); + CallbackHttpMessageHandler recreatedHandler = new(async (request, cancellationToken) => + { + recreatedTransportUsed.TrySetResult(); + await Task.Delay(Timeout.Infinite, cancellationToken); + return CreateFailureResponse(StatusCode.Cancelled, "recreated transport cancelled"); + }); + + GrpcChannel currentChannel = CreateChannel("http://transport-reset.worker.test", initialHandler); + GrpcChannel recreatedChannel = CreateChannel("http://recreated-after-reset.worker.test", recreatedHandler); + GrpcDurableTaskWorkerOptions grpcOptions = new() + { + Channel = currentChannel, + }; + grpcOptions.Internal.ChannelRecreateFailureThreshold = 1; + grpcOptions.Internal.ReconnectBackoffBase = TimeSpan.Zero; + grpcOptions.Internal.ReconnectBackoffCap = TimeSpan.Zero; + + DurableTaskWorkerOptions workerOptions = new() + { + Logging = { UseLegacyCategories = false }, + }; + TestLogProvider logProvider = new(new NullOutput()); + using CancellationTokenSource stoppingToken = new(); + int recreatorCalls = 0; + grpcOptions.SetChannelRecreator((channel, ct) => + { + recreatorCalls++; + return Task.FromResult(recreatedChannel); + }); + + GrpcDurableTaskWorker worker = CreateWorker(grpcOptions, workerOptions, new SimpleLoggerFactory(logProvider)); + Task? executeTask = null; + + try + { + // Act + executeTask = InvokeExecuteAsync(worker, stoppingToken.Token); + await recreatedTransportUsed.Task.WaitAsync(TimeSpan.FromSeconds(5)); + stoppingToken.Cancel(); + await executeTask; + + // Assert + recreatorCalls.Should().Be(1); + initialHandler.CallCount.Should().Be(1); + recreatedHandler.CallCount.Should().Be(1); + logProvider.TryGetLogs(Category, out IReadOnlyCollection? 
logs).Should().BeTrue(); + logs!.Should().Contain(log => log.Message.Contains("gRPC channel to backend has been recreated")); + } + finally + { + stoppingToken.Cancel(); + if (executeTask is not null) + { + await executeTask; + } + + await DisposeChannelAsync(currentChannel); + await DisposeChannelAsync(recreatedChannel); + } + } + + [Fact] + public async Task ProcessorExecuteAsync_SilentDisconnectBeforeFirstMessage_ReturnsChannelRecreateRequested() + { + // Arrange + GrpcDurableTaskWorkerOptions grpcOptions = new(); + grpcOptions.Internal.ChannelRecreateFailureThreshold = 1; + grpcOptions.Internal.ReconnectBackoffBase = TimeSpan.Zero; + grpcOptions.Internal.ReconnectBackoffCap = TimeSpan.Zero; + grpcOptions.Internal.SilentDisconnectTimeout = TimeSpan.FromMilliseconds(100); + + DurableTaskWorkerOptions workerOptions = new() + { + Logging = { UseLegacyCategories = false }, + }; + TestLogProvider logProvider = new(new NullOutput()); + ScriptedWorkerCallInvoker callInvoker = new( + helloFactory: static (callNumber, options) => CreateUnaryCall(Task.FromResult(new Empty())), + getWorkItemsFactory: static (callNumber, options) => CreateServerStreamingCall( + new HangingAsyncStreamReader(throwAsRpc: true))); + + GrpcDurableTaskWorker worker = CreateWorker(grpcOptions, workerOptions, new SimpleLoggerFactory(logProvider)); + object processor = CreateProcessor(worker, new P.TaskHubSidecarService.TaskHubSidecarServiceClient(callInvoker)); + + // Act + ProcessorExitReason reason = await InvokeProcessorExecuteAsync(processor, CancellationToken.None); + + // Assert + reason.Should().Be(ProcessorExitReason.ChannelRecreateRequested); + callInvoker.HelloCallCount.Should().Be(1); + callInvoker.GetWorkItemsCallCount.Should().Be(1); + logProvider.TryGetLogs(Category, out IReadOnlyCollection? 
logs).Should().BeTrue(); + logs!.Should().Contain(log => log.Message.Contains("Channel to backend has stopped receiving traffic")); + logs.Should().Contain(log => log.Message.Contains("Recreating gRPC channel to backend")); + } + + [Fact] + public async Task ProcessorExecuteAsync_GracefulDrainAfterFirstMessage_ReconnectsWithoutChannelRecreate() + { + // Arrange + GrpcDurableTaskWorkerOptions grpcOptions = new(); + grpcOptions.Internal.ChannelRecreateFailureThreshold = 1; + grpcOptions.Internal.ReconnectBackoffBase = TimeSpan.Zero; + grpcOptions.Internal.ReconnectBackoffCap = TimeSpan.Zero; + grpcOptions.Internal.SilentDisconnectTimeout = TimeSpan.FromSeconds(5); + + DurableTaskWorkerOptions workerOptions = new() + { + Logging = { UseLegacyCategories = false }, + }; + TestLogProvider logProvider = new(new NullOutput()); + TaskCompletionSource secondStreamOpened = new(TaskCreationOptions.RunContinuationsAsynchronously); + using CancellationTokenSource stoppingToken = new(); + + ScriptedWorkerCallInvoker callInvoker = new( + helloFactory: static (callNumber, options) => CreateUnaryCall(Task.FromResult(new Empty())), + getWorkItemsFactory: (callNumber, options) => + { + if (callNumber == 1) + { + return CreateServerStreamingCall( + new SequenceAsyncStreamReader(new P.WorkItem { HealthPing = new P.HealthPing() })); + } + + secondStreamOpened.TrySetResult(); + return CreateServerStreamingCall( + new HangingAsyncStreamReader(throwAsRpc: false)); + }); + + GrpcDurableTaskWorker worker = CreateWorker(grpcOptions, workerOptions, new SimpleLoggerFactory(logProvider)); + object processor = CreateProcessor(worker, new P.TaskHubSidecarService.TaskHubSidecarServiceClient(callInvoker)); + + // Act + Task executeTask = InvokeProcessorExecuteAsync(processor, stoppingToken.Token); + await secondStreamOpened.Task.WaitAsync(TimeSpan.FromSeconds(5)); + stoppingToken.Cancel(); + ProcessorExitReason reason = await executeTask; + + // Assert + reason.Should().Be(ProcessorExitReason.Shutdown); + callInvoker.HelloCallCount.Should().BeGreaterThanOrEqualTo(2); + callInvoker.GetWorkItemsCallCount.Should().BeGreaterThanOrEqualTo(2); + logProvider.TryGetLogs(Category, out IReadOnlyCollection? 
logs).Should().BeTrue(); + logs!.Should().Contain(log => log.Message.Contains("Work-item stream ended by the backend")); + logs.Should().NotContain(log => log.Message.Contains("Recreating gRPC channel to backend")); + } + + [Fact] + public async Task ProcessorExecuteAsync_HelloDeadlineExceeded_ReturnsChannelRecreateRequested() + { + // Arrange + GrpcDurableTaskWorkerOptions grpcOptions = new(); + grpcOptions.SetHelloDeadline(TimeSpan.FromMilliseconds(123)); + grpcOptions.Internal.ChannelRecreateFailureThreshold = 1; + grpcOptions.Internal.ReconnectBackoffBase = TimeSpan.Zero; + grpcOptions.Internal.ReconnectBackoffCap = TimeSpan.Zero; + + DurableTaskWorkerOptions workerOptions = new() + { + Logging = { UseLegacyCategories = false }, + }; + TestLogProvider logProvider = new(new NullOutput()); + ScriptedWorkerCallInvoker callInvoker = new( + helloFactory: static (callNumber, options) => CreateUnaryCall( + Task.FromException(new RpcException(new Status(StatusCode.DeadlineExceeded, "hello timed out")))), + getWorkItemsFactory: static (callNumber, options) => throw new InvalidOperationException("GetWorkItems should not be called.")); + + GrpcDurableTaskWorker worker = CreateWorker(grpcOptions, workerOptions, new SimpleLoggerFactory(logProvider)); + object processor = CreateProcessor(worker, new P.TaskHubSidecarService.TaskHubSidecarServiceClient(callInvoker)); + + // Act + ProcessorExitReason reason = await InvokeProcessorExecuteAsync(processor, CancellationToken.None); + + // Assert + reason.Should().Be(ProcessorExitReason.ChannelRecreateRequested); + callInvoker.HelloCallCount.Should().Be(1); + callInvoker.GetWorkItemsCallCount.Should().Be(0); + logProvider.TryGetLogs(Category, out IReadOnlyCollection? logs).Should().BeTrue(); + logs!.Should().Contain(log => log.Message.Contains("Hello handshake to backend timed out after 00:00:00.123")); + logs.Should().Contain(log => log.Message.Contains("Recreating gRPC channel to backend")); + } + + [Theory] + [InlineData(StatusCode.Cancelled, "Durable Task gRPC worker has disconnected from gRPC server.")] + [InlineData(StatusCode.Unauthenticated, "Authentication failed when connecting to backend. Will retry.")] + [InlineData(StatusCode.NotFound, "Task hub NotFound. 
Will continue retrying.")] + public async Task ProcessorExecuteAsync_NonPoisonHandshakeFailures_RetryWithoutChannelRecreate( + StatusCode statusCode, + string expectedLogMessage) + { + // Arrange + GrpcDurableTaskWorkerOptions grpcOptions = new(); + grpcOptions.Internal.ChannelRecreateFailureThreshold = 1; + grpcOptions.Internal.ReconnectBackoffBase = TimeSpan.Zero; + grpcOptions.Internal.ReconnectBackoffCap = TimeSpan.Zero; + + DurableTaskWorkerOptions workerOptions = new() + { + Logging = { UseLegacyCategories = false }, + }; + TestLogProvider logProvider = new(new NullOutput()); + using CancellationTokenSource stoppingToken = new(); + + ScriptedWorkerCallInvoker callInvoker = new( + helloFactory: (callNumber, options) => + { + if (callNumber == 2) + { + stoppingToken.Cancel(); + } + + return CreateUnaryCall( + Task.FromException(new RpcException(new Status(statusCode, "hello failed")))); + }, + getWorkItemsFactory: static (callNumber, options) => throw new InvalidOperationException("GetWorkItems should not be called.")); + + GrpcDurableTaskWorker worker = CreateWorker(grpcOptions, workerOptions, new SimpleLoggerFactory(logProvider)); + object processor = CreateProcessor(worker, new P.TaskHubSidecarService.TaskHubSidecarServiceClient(callInvoker)); + + // Act + ProcessorExitReason reason = await InvokeProcessorExecuteAsync(processor, stoppingToken.Token); + + // Assert + reason.Should().Be(ProcessorExitReason.Shutdown); + callInvoker.HelloCallCount.Should().BeGreaterThanOrEqualTo(2); + callInvoker.GetWorkItemsCallCount.Should().Be(0); + logProvider.TryGetLogs(Category, out IReadOnlyCollection? logs).Should().BeTrue(); + logs!.Should().Contain(log => log.Message.Contains(expectedLogMessage)); + logs.Should().NotContain(log => log.Message.Contains("Recreating gRPC channel to backend")); + } + + [Fact] + public async Task TryRecreateChannelAsync_ConfiguredRecreatorReturningSameChannel_DoesNotRecreate() + { + // Arrange + GrpcChannel currentChannel = GrpcChannel.ForAddress("http://localhost:5003"); + GrpcDurableTaskWorkerOptions grpcOptions = new() + { + Channel = currentChannel, + }; + + GrpcChannel? 
observedChannel = null; + grpcOptions.SetChannelRecreator((channel, ct) => + { + observedChannel = channel; + return Task.FromResult(channel); + }); + + GrpcDurableTaskWorker worker = CreateWorker(grpcOptions); + + try + { + // Act + object result = await InvokeTryRecreateChannelAsync(worker, currentChannel); + + // Assert + observedChannel.Should().BeSameAs(currentChannel); + GetResultProperty(result, "Recreated").Should().BeFalse(); + GetResultProperty(result, "NewChannel").Should().BeNull(); + GetResultProperty(result, "NewAddress").Should().BeNull(); + } + finally + { + await DisposeChannelAsync(currentChannel); + } + } + + [Fact] + public async Task TryRecreateChannelAsync_ConfiguredRecreatorReturningDifferentChannel_DoesNotCarryForwardOldDisposable() + { + // Arrange + GrpcChannel currentChannel = GrpcChannel.ForAddress("http://localhost:5004"); + GrpcChannel recreatedChannel = GrpcChannel.ForAddress("http://localhost:5005"); + GrpcDurableTaskWorkerOptions grpcOptions = new() + { + Channel = currentChannel, + }; + grpcOptions.SetChannelRecreator((channel, ct) => Task.FromResult(recreatedChannel)); + + GrpcDurableTaskWorker worker = CreateWorker(grpcOptions); + int disposeCalls = 0; + AsyncDisposable currentWorkerOwnedDisposable = new(() => + { + Interlocked.Increment(ref disposeCalls); + return ValueTask.CompletedTask; + }); + + try + { + // Act + object result = await InvokeTryRecreateChannelAsync(worker, currentWorkerOwnedDisposable, currentChannel); + AsyncDisposable newDisposable = GetResultProperty(result, "NewWorkerOwnedDisposable"); + + // Simulate the outer worker handoff. + await currentWorkerOwnedDisposable.DisposeAsync(); + await newDisposable.DisposeAsync(); + + // Assert + GetResultProperty(result, "Recreated").Should().BeTrue(); + GetResultProperty(result, "NewChannel").Should().BeSameAs(recreatedChannel); + Volatile.Read(ref disposeCalls).Should().Be(1); + } + finally + { + await DisposeChannelAsync(currentChannel); + await DisposeChannelAsync(recreatedChannel); + } + } + + [Fact] + public async Task ApplySuccessfulRecreate_DefersDisposalOfPreviousWorkerOwnedChannel() + { + // Arrange + GrpcChannel currentChannel = GrpcChannel.ForAddress("http://localhost:5004"); + GrpcChannel recreatedChannel = GrpcChannel.ForAddress("http://localhost:5005"); + GrpcDurableTaskWorkerOptions grpcOptions = new() + { + Channel = currentChannel, + }; + grpcOptions.SetChannelRecreator((channel, ct) => Task.FromResult(recreatedChannel)); + GrpcDurableTaskWorker worker = CreateWorker(grpcOptions); + + int disposeCalls = 0; + TaskCompletionSource disposalObserved = new(TaskCreationOptions.RunContinuationsAsynchronously); + AsyncDisposable disposable = new(() => + { + Interlocked.Increment(ref disposeCalls); + disposalObserved.TrySetResult(); + return ValueTask.CompletedTask; + }); + + CallInvoker callInvoker = currentChannel.CreateCallInvoker(); + string address = currentChannel.Target; + GrpcChannel? 
latestObservedChannel = currentChannel; + AsyncDisposable workerOwnedChannelDisposable = disposable; + + try + { + object result = await InvokeTryRecreateChannelAsync(worker, disposable, currentChannel); + + // Act + InvokeApplySuccessfulRecreate( + worker, + result, + ref callInvoker, + ref address, + ref latestObservedChannel, + ref workerOwnedChannelDisposable, + TimeSpan.FromMilliseconds(100)); + + // Assert + disposalObserved.Task.IsCompleted.Should().BeFalse(); + Volatile.Read(ref disposeCalls).Should().Be(0); + address.Should().Be(recreatedChannel.Target); + latestObservedChannel.Should().BeSameAs(recreatedChannel); + await disposalObserved.Task.WaitAsync(TimeSpan.FromSeconds(5)); + Volatile.Read(ref disposeCalls).Should().Be(1); + } + finally + { + await DisposeChannelAsync(currentChannel); + await DisposeChannelAsync(recreatedChannel); + } + } + + [Fact] + public async Task ConnectAsync_VeryLargeHelloDeadline_UsesUtcMaxValueDeadline() + { + // Arrange + GrpcDurableTaskWorkerOptions grpcOptions = new(); + grpcOptions.SetHelloDeadline(TimeSpan.MaxValue); + GrpcDurableTaskWorker worker = CreateWorker(grpcOptions); + RecordingCallInvoker callInvoker = new(); + P.TaskHubSidecarService.TaskHubSidecarServiceClient client = new(callInvoker); + object processor = CreateProcessor(worker, client); + + // Act + using AsyncServerStreamingCall stream = await InvokeProcessorConnectAsync(processor); + + // Assert + callInvoker.HelloDeadline.Should().HaveValue(); + DateTime deadline = callInvoker.HelloDeadline!.Value; + deadline.Kind.Should().Be(DateTimeKind.Utc); + deadline.Should().Be(DateTime.SpecifyKind(DateTime.MaxValue, DateTimeKind.Utc)); + } + + static GrpcDurableTaskWorker CreateWorker(GrpcDurableTaskWorkerOptions grpcOptions) + { + return CreateWorker(grpcOptions, new DurableTaskWorkerOptions(), NullLoggerFactory.Instance); + } + + static GrpcDurableTaskWorker CreateWorker( + GrpcDurableTaskWorkerOptions grpcOptions, + DurableTaskWorkerOptions workerOptions, + ILoggerFactory loggerFactory) + { + Mock factoryMock = new(MockBehavior.Strict); + + return new GrpcDurableTaskWorker( + name: "Test", + factory: factoryMock.Object, + grpcOptions: new OptionsMonitorStub(grpcOptions), + workerOptions: new OptionsMonitorStub(workerOptions), + services: Mock.Of(), + loggerFactory: loggerFactory, + orchestrationFilter: null, + exceptionPropertiesProvider: null); + } + + static Task InvokeExecuteAsync(GrpcDurableTaskWorker worker, CancellationToken cancellationToken) + { + return (Task)ExecuteAsyncMethod.Invoke(worker, new object?[] { cancellationToken })!; + } + + static object CreateProcessor(GrpcDurableTaskWorker worker, P.TaskHubSidecarService.TaskHubSidecarServiceClient client) + { + System.Type processorType = typeof(GrpcDurableTaskWorker).GetNestedType("Processor", BindingFlags.NonPublic)!; + return Activator.CreateInstance( + processorType, + BindingFlags.Public | BindingFlags.Instance, + binder: null, + args: new object?[] { worker, client, null, null }, + culture: null)!; + } + + static async Task> InvokeProcessorConnectAsync(object processor) + { + Task task = (Task)ProcessorConnectAsyncMethod.Invoke(processor, new object?[] { CancellationToken.None })!; + await task; + return (AsyncServerStreamingCall)task.GetType().GetProperty("Result")!.GetValue(task)!; + } + + static async Task InvokeProcessorExecuteAsync(object processor, CancellationToken cancellationToken) + { + Task task = (Task)ProcessorExecuteAsyncMethod.Invoke(processor, new object?[] { cancellationToken })!; + await task; + 
return (ProcessorExitReason)task.GetType().GetProperty("Result")!.GetValue(task)!; + } + + static void InvokeApplySuccessfulRecreate( + GrpcDurableTaskWorker worker, + object result, + ref CallInvoker callInvoker, + ref string address, + ref GrpcChannel? latestObservedChannel, + ref AsyncDisposable workerOwnedChannelDisposable, + TimeSpan deferredDisposeGracePeriod) + { + object?[] args = { result, callInvoker, address, latestObservedChannel, workerOwnedChannelDisposable, deferredDisposeGracePeriod }; + ApplySuccessfulRecreateMethod.Invoke(worker, args); + callInvoker = (CallInvoker)args[1]!; + address = (string)args[2]!; + latestObservedChannel = (GrpcChannel?)args[3]; + workerOwnedChannelDisposable = (AsyncDisposable)args[4]!; + } + + static async Task InvokeTryRecreateChannelAsync(GrpcDurableTaskWorker worker, GrpcChannel currentChannel) + { + return await InvokeTryRecreateChannelAsync(worker, default, currentChannel); + } + + static async Task InvokeTryRecreateChannelAsync( + GrpcDurableTaskWorker worker, + AsyncDisposable currentWorkerOwnedDisposable, + GrpcChannel currentChannel) + { + object?[] args = { CancellationToken.None, currentWorkerOwnedDisposable, currentChannel }; + Task task = (Task)TryRecreateChannelAsyncMethod.Invoke(worker, args)!; + await task; + return task.GetType().GetProperty("Result")!.GetValue(task)!; + } + + static T GetResultProperty(object result, string propertyName) + { + return (T)result.GetType().GetProperty(propertyName)!.GetValue(result)!; + } + + static async ValueTask DisposeChannelAsync(GrpcChannel channel) + { + await channel.ShutdownAsync(); + channel.Dispose(); + } + + static GrpcChannel CreateChannel(string address, HttpMessageHandler handler) + { + return GrpcChannel.ForAddress(address, new GrpcChannelOptions + { + HttpHandler = handler, + }); + } + + static HttpResponseMessage CreateFailureResponse(StatusCode statusCode, string detail) + { + HttpResponseMessage response = new(System.Net.HttpStatusCode.OK) + { + Version = new Version(2, 0), + Content = new ByteArrayContent([]), + }; + + response.Content.Headers.ContentType = new System.Net.Http.Headers.MediaTypeHeaderValue("application/grpc"); + response.TrailingHeaders.Add("grpc-status", ((int)statusCode).ToString()); + response.TrailingHeaders.Add("grpc-message", detail); + return response; + } + + static AsyncUnaryCall CreateUnaryCall(Task responseTask) + { + return new AsyncUnaryCall( + responseTask, + Task.FromResult(new Metadata()), + () => new Status(StatusCode.OK, string.Empty), + () => new Metadata(), + () => { }); + } + + static AsyncServerStreamingCall CreateServerStreamingCall(IAsyncStreamReader reader) + { + return new AsyncServerStreamingCall( + reader, + Task.FromResult(new Metadata()), + () => new Status(StatusCode.OK, string.Empty), + () => new Metadata(), + () => (reader as IDisposable)?.Dispose()); + } + + sealed class CallbackHttpMessageHandler : HttpMessageHandler + { + readonly Func> callback; + int callCount; + + public CallbackHttpMessageHandler(Func> callback) + { + this.callback = callback; + } + + public int CallCount => Volatile.Read(ref this.callCount); + + protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + Interlocked.Increment(ref this.callCount); + return this.callback(request, cancellationToken); + } + } + + sealed class RecordingCallInvoker : CallInvoker + { + public DateTime? HelloDeadline { get; private set; } + + public override TResponse BlockingUnaryCall(Method method, string? 
host, CallOptions options, TRequest request) + { + throw new NotSupportedException(); + } + + public override AsyncUnaryCall AsyncUnaryCall(Method method, string? host, CallOptions options, TRequest request) + { + if (method.FullName == "/TaskHubSidecarService/Hello") + { + this.HelloDeadline = options.Deadline; + TResponse response = (TResponse)(object)new Empty(); + return new AsyncUnaryCall( + Task.FromResult(response), + Task.FromResult(new Metadata()), + () => new Status(StatusCode.OK, string.Empty), + () => new Metadata(), + () => { }); + } + + throw new NotSupportedException($"Unexpected unary method {method.FullName}."); + } + + public override AsyncServerStreamingCall AsyncServerStreamingCall(Method method, string? host, CallOptions options, TRequest request) + { + if (method.FullName == "/TaskHubSidecarService/GetWorkItems") + { + return new AsyncServerStreamingCall( + new EmptyAsyncStreamReader(), + Task.FromResult(new Metadata()), + () => new Status(StatusCode.OK, string.Empty), + () => new Metadata(), + () => { }); + } + + throw new NotSupportedException($"Unexpected server-streaming method {method.FullName}."); + } + + public override AsyncClientStreamingCall AsyncClientStreamingCall(Method method, string? host, CallOptions options) + { + throw new NotSupportedException(); + } + + public override AsyncDuplexStreamingCall AsyncDuplexStreamingCall(Method method, string? host, CallOptions options) + { + throw new NotSupportedException(); + } + } + + sealed class EmptyAsyncStreamReader : IAsyncStreamReader + { + public T Current => default!; + + public Task MoveNext(CancellationToken cancellationToken) => Task.FromResult(false); + } + + sealed class SequenceAsyncStreamReader : IAsyncStreamReader + { + readonly Queue items; + + public SequenceAsyncStreamReader(params T[] items) + { + this.items = new Queue(items); + } + + public T Current { get; private set; } = default!; + + public Task MoveNext(CancellationToken cancellationToken) + { + if (this.items.Count == 0) + { + return Task.FromResult(false); + } + + this.Current = this.items.Dequeue(); + return Task.FromResult(true); + } + } + + sealed class HangingAsyncStreamReader : IAsyncStreamReader + { + readonly bool throwAsRpc; + + public HangingAsyncStreamReader(bool throwAsRpc) + { + this.throwAsRpc = throwAsRpc; + } + + public T Current => default!; + + public async Task MoveNext(CancellationToken cancellationToken) + { + try + { + await Task.Delay(Timeout.Infinite, cancellationToken); + } + catch (OperationCanceledException) when (this.throwAsRpc) + { + throw new RpcException(new Status(StatusCode.Cancelled, "stream cancelled")); + } + + return false; + } + } + + sealed class ScriptedWorkerCallInvoker : CallInvoker + { + readonly Func> helloFactory; + readonly Func> getWorkItemsFactory; + int helloCallCount; + int getWorkItemsCallCount; + + public ScriptedWorkerCallInvoker( + Func> helloFactory, + Func> getWorkItemsFactory) + { + this.helloFactory = helloFactory; + this.getWorkItemsFactory = getWorkItemsFactory; + } + + public int HelloCallCount => Volatile.Read(ref this.helloCallCount); + + public int GetWorkItemsCallCount => Volatile.Read(ref this.getWorkItemsCallCount); + + public override TResponse BlockingUnaryCall(Method method, string? host, CallOptions options, TRequest request) + { + throw new NotSupportedException(); + } + + public override AsyncUnaryCall AsyncUnaryCall(Method method, string? 
+    sealed class ScriptedWorkerCallInvoker : CallInvoker
+    {
+        readonly Func<int, CallOptions, AsyncUnaryCall<Empty>> helloFactory;
+        readonly Func<int, CallOptions, AsyncServerStreamingCall<P.WorkItem>> getWorkItemsFactory;
+        int helloCallCount;
+        int getWorkItemsCallCount;
+
+        public ScriptedWorkerCallInvoker(
+            Func<int, CallOptions, AsyncUnaryCall<Empty>> helloFactory,
+            Func<int, CallOptions, AsyncServerStreamingCall<P.WorkItem>> getWorkItemsFactory)
+        {
+            this.helloFactory = helloFactory;
+            this.getWorkItemsFactory = getWorkItemsFactory;
+        }
+
+        public int HelloCallCount => Volatile.Read(ref this.helloCallCount);
+
+        public int GetWorkItemsCallCount => Volatile.Read(ref this.getWorkItemsCallCount);
+
+        public override TResponse BlockingUnaryCall<TRequest, TResponse>(Method<TRequest, TResponse> method, string? host, CallOptions options, TRequest request)
+        {
+            throw new NotSupportedException();
+        }
+
+        public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(Method<TRequest, TResponse> method, string? host, CallOptions options, TRequest request)
+        {
+            if (method.FullName == "/TaskHubSidecarService/Hello")
+            {
+                AsyncUnaryCall<Empty> call = this.helloFactory(Interlocked.Increment(ref this.helloCallCount), options);
+                return (AsyncUnaryCall<TResponse>)(object)call;
+            }
+
+            throw new NotSupportedException($"Unexpected unary method {method.FullName}.");
+        }
+
+        public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(Method<TRequest, TResponse> method, string? host, CallOptions options, TRequest request)
+        {
+            if (method.FullName == "/TaskHubSidecarService/GetWorkItems")
+            {
+                AsyncServerStreamingCall<P.WorkItem> call = this.getWorkItemsFactory(Interlocked.Increment(ref this.getWorkItemsCallCount), options);
+                return (AsyncServerStreamingCall<TResponse>)(object)call;
+            }
+
+            throw new NotSupportedException($"Unexpected server-streaming method {method.FullName}.");
+        }
+
+        public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(Method<TRequest, TResponse> method, string? host, CallOptions options)
+        {
+            throw new NotSupportedException();
+        }
+
+        public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreamingCall<TRequest, TResponse>(Method<TRequest, TResponse> method, string? host, CallOptions options)
+        {
+            throw new NotSupportedException();
+        }
+    }
+}
diff --git a/test/Worker/Grpc.Tests/ReconnectBackoffTests.cs b/test/Worker/Grpc.Tests/ReconnectBackoffTests.cs
new file mode 100644
index 000000000..024f179eb
--- /dev/null
+++ b/test/Worker/Grpc.Tests/ReconnectBackoffTests.cs
@@ -0,0 +1,145 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+namespace Microsoft.DurableTask.Worker.Grpc.Tests;
+
+public class ReconnectBackoffTests
+{
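+    // Contract pinned by the tests below (full-jitter exponential backoff, as the assertions
+    // imply; the formula is a reading of the tests, not quoted from the implementation):
+    //
+    //   delay = NextDouble() * min(cap, baseDelay * 2^max(attempt, 0))
+    //
+    // with TimeSpan.Zero whenever baseDelay or cap is non-positive.
+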
+    [Fact]
+    public void Compute_ZeroBase_ReturnsZero()
+    {
+        // Arrange
+        Random random = new(42);
+
+        // Act
+        TimeSpan delay = ReconnectBackoff.Compute(attempt: 5, baseDelay: TimeSpan.Zero, cap: TimeSpan.FromSeconds(30), random);
+
+        // Assert
+        delay.Should().Be(TimeSpan.Zero);
+    }
+
+    [Fact]
+    public void Compute_NegativeBase_ReturnsZero()
+    {
+        // Arrange
+        Random random = new(42);
+
+        // Act
+        TimeSpan delay = ReconnectBackoff.Compute(attempt: 0, baseDelay: TimeSpan.FromMilliseconds(-100), cap: TimeSpan.FromSeconds(30), random);
+
+        // Assert
+        delay.Should().Be(TimeSpan.Zero);
+    }
+
+    [Fact]
+    public void Compute_NeverExceedsCap()
+    {
+        // Arrange
+        TimeSpan cap = TimeSpan.FromSeconds(30);
+        TimeSpan baseDelay = TimeSpan.FromSeconds(1);
+        Random random = new(1);
+
+        // Act + Assert: try a wide range of attempts, including exponents far past the cap.
+        for (int attempt = 0; attempt < 50; attempt++)
+        {
+            TimeSpan delay = ReconnectBackoff.Compute(attempt, baseDelay, cap, random);
+            delay.Should().BeLessThanOrEqualTo(cap, $"attempt {attempt} produced {delay}");
+            delay.Should().BeGreaterThanOrEqualTo(TimeSpan.Zero);
+        }
+    }
+
+    [Fact]
+    public void Compute_GrowsExponentiallyUntilCap()
+    {
+        // Arrange: a Random whose NextDouble() always returns just under 1.0 forces the upper
+        // bound of the jitter window.
+        DeterministicRandom random = new(value: 0.999999);
+        TimeSpan baseDelay = TimeSpan.FromSeconds(1);
+        TimeSpan cap = TimeSpan.FromSeconds(30);
+
+        // Act
+        double d0 = ReconnectBackoff.Compute(0, baseDelay, cap, random).TotalMilliseconds;
+        double d1 = ReconnectBackoff.Compute(1, baseDelay, cap, random).TotalMilliseconds;
+        double d2 = ReconnectBackoff.Compute(2, baseDelay, cap, random).TotalMilliseconds;
+        double d3 = ReconnectBackoff.Compute(3, baseDelay, cap, random).TotalMilliseconds;
+        double d10 = ReconnectBackoff.Compute(10, baseDelay, cap, random).TotalMilliseconds;
+
+        // Assert: roughly doubles each step until the cap is reached.
+        d0.Should().BeApproximately(1000, 1);
+        d1.Should().BeApproximately(2000, 1);
+        d2.Should().BeApproximately(4000, 1);
+        d3.Should().BeApproximately(8000, 1);
+        d10.Should().BeApproximately(30000, 1, "should be clamped at the cap");
+    }
+
+    [Fact]
+    public void Compute_WithFullJitter_StaysWithinBounds()
+    {
+        // Arrange: with random=0 the result is 0; with random≈1 the result is the bound.
+        TimeSpan baseDelay = TimeSpan.FromSeconds(1);
+        TimeSpan cap = TimeSpan.FromSeconds(30);
+
+        // Act + Assert: random=0 → 0
+        TimeSpan low = ReconnectBackoff.Compute(3, baseDelay, cap, new DeterministicRandom(0.0));
+        low.TotalMilliseconds.Should().BeApproximately(0, 0.5);
+
+        // random ≈ 1 → bound (= 8s for attempt=3, base=1s)
+        TimeSpan high = ReconnectBackoff.Compute(3, baseDelay, cap, new DeterministicRandom(0.999999));
+        high.TotalMilliseconds.Should().BeApproximately(8000, 1);
+    }
+
+    [Fact]
+    public void Compute_NegativeAttempt_TreatedAsZero()
+    {
+        // Arrange
+        DeterministicRandom random = new(0.999999);
+
+        // Act
+        TimeSpan delay = ReconnectBackoff.Compute(attempt: -5, baseDelay: TimeSpan.FromSeconds(1), cap: TimeSpan.FromSeconds(30), random);
+
+        // Assert
+        delay.TotalMilliseconds.Should().BeApproximately(1000, 1);
+    }
+
+    [Fact]
+    public void Compute_CapSmallerThanBase_ClampsToCap()
+    {
+        // Arrange: cap is intentionally smaller than baseDelay; the cap must still be honored.
+        DeterministicRandom random = new(0.999999);
+        TimeSpan baseDelay = TimeSpan.FromSeconds(5);
+        TimeSpan cap = TimeSpan.FromSeconds(1);
+
+        // Act
+        TimeSpan delay = ReconnectBackoff.Compute(attempt: 3, baseDelay, cap, random);
+
+        // Assert: with random ≈ 1 the result is the bound, which must equal the cap.
+        delay.TotalMilliseconds.Should().BeApproximately(1000, 1);
+        delay.Should().BeLessThanOrEqualTo(cap);
+    }
+
+    [Fact]
+    public void Compute_NonPositiveCap_ReturnsZero()
+    {
+        // Arrange
+        DeterministicRandom random = new(0.999999);
+
+        // Act
+        TimeSpan zero = ReconnectBackoff.Compute(attempt: 3, baseDelay: TimeSpan.FromSeconds(1), cap: TimeSpan.Zero, random);
+        TimeSpan negative = ReconnectBackoff.Compute(attempt: 3, baseDelay: TimeSpan.FromSeconds(1), cap: TimeSpan.FromSeconds(-1), random);
+
+        // Assert
+        zero.Should().Be(TimeSpan.Zero);
+        negative.Should().Be(TimeSpan.Zero);
+    }
+
+    sealed class DeterministicRandom : Random
+    {
+        readonly double value;
+
+        public DeterministicRandom(double value)
+        {
+            this.value = value;
+        }
+
+        public override double NextDouble() => this.value;
+    }
+}
diff --git a/test/Worker/Grpc.Tests/WorkItemStreamConsumerTests.cs b/test/Worker/Grpc.Tests/WorkItemStreamConsumerTests.cs
new file mode 100644
index 000000000..2464c0c7e
--- /dev/null
+++ b/test/Worker/Grpc.Tests/WorkItemStreamConsumerTests.cs
@@ -0,0 +1,263 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
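+
+// These tests pin the outcome taxonomy of WorkItemStreamConsumer.ConsumeAsync:
+//   GracefulDrain:    the server completed the stream normally.
+//   SilentDisconnect: the stream went quiet for longer than silentDisconnectTimeout; the
+//                     caller is expected to recreate the channel.
+//   Shutdown:         the outer cancellation token fired.
+// Anything else, including unrelated RpcExceptions, propagates to the caller.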
+
+using System.Runtime.CompilerServices;
+using System.Threading.Channels;
+using Grpc.Core;
+using Microsoft.DurableTask.Worker.Grpc;
+using P = Microsoft.DurableTask.Protobuf;
+
+namespace Microsoft.DurableTask.Worker.Grpc.Tests;
+
+public class WorkItemStreamConsumerTests
+{
+    static readonly TimeSpan ShortTimeout = TimeSpan.FromMilliseconds(150);
+
+    [Fact]
+    public async Task EmptyStream_ReturnsGracefulDrain()
+    {
+        WorkItemStreamResult result = await WorkItemStreamConsumer.ConsumeAsync(
+            openStream: _ => EmptyStream(),
+            silentDisconnectTimeout: TimeSpan.FromSeconds(5),
+            onItem: _ => throw new InvalidOperationException("onItem should not be invoked"),
+            onFirstMessage: () => throw new InvalidOperationException("onFirstMessage should not be invoked"),
+            cancellation: CancellationToken.None);
+
+        result.Outcome.Should().Be(WorkItemStreamOutcome.GracefulDrain);
+        result.FirstMessageObserved.Should().BeFalse();
+    }
+
+    [Fact]
+    public async Task StreamWithItems_ReturnsGracefulDrain_AndFiresCallbacks()
+    {
+        P.WorkItem item1 = new() { HealthPing = new P.HealthPing() };
+        P.WorkItem item2 = new() { HealthPing = new P.HealthPing() };
+        List<P.WorkItem> received = new();
+        int firstMessageCount = 0;
+
+        WorkItemStreamResult result = await WorkItemStreamConsumer.ConsumeAsync(
+            openStream: _ => StreamOf(item1, item2),
+            silentDisconnectTimeout: TimeSpan.FromSeconds(5),
+            onItem: received.Add,
+            onFirstMessage: () => firstMessageCount++,
+            cancellation: CancellationToken.None);
+
+        result.Outcome.Should().Be(WorkItemStreamOutcome.GracefulDrain);
+        result.FirstMessageObserved.Should().BeTrue();
+        received.Should().BeEquivalentTo(new[] { item1, item2 }, o => o.WithStrictOrdering());
+        firstMessageCount.Should().Be(1);
+    }
+
+    [Fact]
+    public async Task VeryLargeSilentDisconnectTimeout_IsClamped_AndStreamCanStillComplete()
+    {
+        WorkItemStreamResult result = await WorkItemStreamConsumer.ConsumeAsync(
+            openStream: _ => EmptyStream(),
+            silentDisconnectTimeout: TimeSpan.FromDays(365),
+            onItem: _ => throw new InvalidOperationException("onItem should not be invoked"),
+            onFirstMessage: null,
+            cancellation: CancellationToken.None);
+
+        result.Outcome.Should().Be(WorkItemStreamOutcome.GracefulDrain);
+        result.FirstMessageObserved.Should().BeFalse();
+    }
+
+    [Fact]
+    public async Task HangingStream_SurfacingOce_ReturnsSilentDisconnect()
+    {
+        WorkItemStreamResult result = await WorkItemStreamConsumer.ConsumeAsync(
+            openStream: ct => HangingStream(ct, throwAsRpc: false),
+            silentDisconnectTimeout: ShortTimeout,
+            onItem: _ => { },
+            onFirstMessage: null,
+            cancellation: CancellationToken.None);
+
+        result.Outcome.Should().Be(WorkItemStreamOutcome.SilentDisconnect);
+        result.FirstMessageObserved.Should().BeFalse();
+    }
+
+    /// <summary>
+    /// Regression test for the C1 silent-disconnect bug. grpc-dotnet by default surfaces a linked-token
+    /// cancellation as <see cref="RpcException"/> (StatusCode.Cancelled), not <see cref="OperationCanceledException"/>.
+    /// Before the fix, that exception propagated past the silent-disconnect branch and the channel-recreate
+    /// callback was never invoked.
+    /// </summary>
+    [Fact]
+    public async Task HangingStream_SurfacingRpcCancelled_ReturnsSilentDisconnect()
+    {
+        WorkItemStreamResult result = await WorkItemStreamConsumer.ConsumeAsync(
+            openStream: ct => HangingStream(ct, throwAsRpc: true),
+            silentDisconnectTimeout: ShortTimeout,
+            onItem: _ => { },
+            onFirstMessage: null,
+            cancellation: CancellationToken.None);
+
+        result.Outcome.Should().Be(WorkItemStreamOutcome.SilentDisconnect);
+        result.FirstMessageObserved.Should().BeFalse();
+    }
+
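+    // A sketch of the classification the two silent-disconnect tests above pin down. This is an
+    // illustration of the documented behavior, not the actual implementation; `timeoutCts` is a
+    // hypothetical name for the consumer's internal silence-timer source:
+    //
+    //   catch (RpcException ex) when (ex.StatusCode == StatusCode.Cancelled
+    //       && timeoutCts.IsCancellationRequested
+    //       && !cancellation.IsCancellationRequested)
+    //   {
+    //       // Timer fired, outer token did not: classify as SilentDisconnect.
+    //   }
+    //
+    // An RpcException(Cancelled) caused by the outer token deliberately falls through; see
+    // OuterCancellation_WithRpcCancelledFromStream below.
+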
+    [Fact]
+    public async Task OuterCancellation_WithOceFromStream_ReturnsShutdown()
+    {
+        // When the inner stream surfaces cancellation as OperationCanceledException, the helper
+        // classifies the termination and returns Shutdown.
+        using CancellationTokenSource outer = new();
+        outer.CancelAfter(ShortTimeout);
+
+        WorkItemStreamResult result = await WorkItemStreamConsumer.ConsumeAsync(
+            openStream: ct => HangingStream(ct, throwAsRpc: false),
+            silentDisconnectTimeout: TimeSpan.FromSeconds(30),
+            onItem: _ => { },
+            onFirstMessage: null,
+            cancellation: outer.Token);
+
+        result.Outcome.Should().Be(WorkItemStreamOutcome.Shutdown);
+        result.FirstMessageObserved.Should().BeFalse();
+    }
+
+    [Fact]
+    public async Task OuterCancellation_WithRpcCancelledFromStream_PropagatesException()
+    {
+        // When the inner stream surfaces outer cancellation as RpcException(Cancelled), the helper
+        // does NOT classify it as Shutdown — the caller's outer catch chain (ExecuteAsync) handles
+        // RpcException(Cancelled)-during-shutdown. Adding it to the helper would conflict with the
+        // post-fix silent-disconnect catch, which scopes RpcException(Cancelled) handling to the case
+        // where the timeout source — not the outer cancellation — fired.
+        using CancellationTokenSource outer = new();
+        outer.CancelAfter(ShortTimeout);
+
+        Func<Task> act = () => WorkItemStreamConsumer.ConsumeAsync(
+            openStream: ct => HangingStream(ct, throwAsRpc: true),
+            silentDisconnectTimeout: TimeSpan.FromSeconds(30),
+            onItem: _ => { },
+            onFirstMessage: null,
+            cancellation: outer.Token);
+
+        await act.Should().ThrowAsync<RpcException>().Where(e => e.StatusCode == StatusCode.Cancelled);
+    }
+
+    [Fact]
+    public async Task PerItem_HeartbeatReset_KeepsTimerAlive()
+    {
+        // Feed one item, wait long enough that the original timer would have expired, then complete.
+        // Synchronize on the first item actually being processed so the second delay is measured from
+        // the consumer's timer reset instead of from the test thread's write timing.
+        Channel<P.WorkItem> channel = Channel.CreateUnbounded<P.WorkItem>();
+        TimeSpan timeout = TimeSpan.FromMilliseconds(500);
+        TaskCompletionSource firstItemProcessed = new(TaskCreationOptions.RunContinuationsAsynchronously);
+        int itemCount = 0;
+
+        Task<WorkItemStreamResult> consumeTask = WorkItemStreamConsumer.ConsumeAsync(
+            openStream: ct => channel.Reader.ReadAllAsync(ct),
+            silentDisconnectTimeout: timeout,
+            onItem: _ =>
+            {
+                if (Interlocked.Increment(ref itemCount) == 1)
+                {
+                    firstItemProcessed.TrySetResult();
+                }
+            },
+            onFirstMessage: null,
+            cancellation: CancellationToken.None);
+
+        await Task.Delay(TimeSpan.FromMilliseconds(150));
+        await channel.Writer.WriteAsync(new P.WorkItem { HealthPing = new P.HealthPing() });
+        await firstItemProcessed.Task.WaitAsync(TimeSpan.FromSeconds(5));
+
+        // Without the per-item reset, the original 500 ms timer would fire before this second item arrives.
+        await Task.Delay(TimeSpan.FromMilliseconds(400));
+        await channel.Writer.WriteAsync(new P.WorkItem { HealthPing = new P.HealthPing() });
+        channel.Writer.Complete();
+
+        WorkItemStreamResult result = await consumeTask;
+
+        result.Outcome.Should().Be(WorkItemStreamOutcome.GracefulDrain);
+        result.FirstMessageObserved.Should().BeTrue();
+    }
+
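+    // The reset the test above depends on, sketched under the assumption that the consumer
+    // re-arms a single timeout source per received item (the real code may differ):
+    //
+    //   await foreach (P.WorkItem item in stream)
+    //   {
+    //       timeoutCts.CancelAfter(silentDisconnectTimeout); // re-arm the silence timer
+    //       onItem(item);
+    //   }
+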
+    [Fact]
+    public async Task UnrelatedRpcException_Propagates()
+    {
+        Func<Task> act = () => WorkItemStreamConsumer.ConsumeAsync(
+            openStream: _ => ThrowingStream(new RpcException(new Status(StatusCode.Unavailable, "boom"))),
+            silentDisconnectTimeout: TimeSpan.FromSeconds(5),
+            onItem: _ => { },
+            onFirstMessage: null,
+            cancellation: CancellationToken.None);
+
+        await act.Should().ThrowAsync<RpcException>().Where(e => e.StatusCode == StatusCode.Unavailable);
+    }
+
+    [Theory]
+    [InlineData(0)]
+    [InlineData(-1)]
+    public async Task NonPositiveSilentDisconnectTimeout_OnlyShutdownEndsLoop(int timeoutMilliseconds)
+    {
+        // Arrange
+        using CancellationTokenSource outer = new();
+        outer.CancelAfter(ShortTimeout);
+
+        // Act
+        WorkItemStreamResult result = await WorkItemStreamConsumer.ConsumeAsync(
+            openStream: ct => HangingStream(ct, throwAsRpc: false),
+            silentDisconnectTimeout: TimeSpan.FromMilliseconds(timeoutMilliseconds),
+            onItem: _ => { },
+            onFirstMessage: null,
+            cancellation: outer.Token);
+
+        // Assert
+        result.Outcome.Should().Be(WorkItemStreamOutcome.Shutdown);
+    }
+
+#pragma warning disable CS1998 // Async method lacks 'await' operators
+    static async IAsyncEnumerable<P.WorkItem> EmptyStream()
+    {
+        yield break;
+    }
+
+    static async IAsyncEnumerable<P.WorkItem> StreamOf(params P.WorkItem[] items)
+    {
+        foreach (P.WorkItem item in items)
+        {
+            yield return item;
+        }
+    }
+
+    static IAsyncEnumerable<P.WorkItem> ThrowingStream(Exception ex) => new ThrowingAsyncEnumerable(ex);
+#pragma warning restore CS1998
+
+    static async IAsyncEnumerable<P.WorkItem> HangingStream(
+        [EnumeratorCancellation] CancellationToken ct,
+        bool throwAsRpc)
+    {
+        try
+        {
+            await Task.Delay(Timeout.Infinite, ct);
+        }
+        catch (OperationCanceledException) when (throwAsRpc)
+        {
+            // Mimic grpc-dotnet's default surface shape for linked-token cancellation.
+            throw new RpcException(new Status(StatusCode.Cancelled, "stream cancelled"));
+        }
+
+        yield break;
+    }
+
+    sealed class ThrowingAsyncEnumerable : IAsyncEnumerable<P.WorkItem>, IAsyncEnumerator<P.WorkItem>
+    {
+        readonly Exception exception;
+
+        public ThrowingAsyncEnumerable(Exception exception)
+        {
+            this.exception = exception;
+        }
+
+        public P.WorkItem Current => throw new InvalidOperationException("No current item is available for a throwing stream.");
+
+        public IAsyncEnumerator<P.WorkItem> GetAsyncEnumerator(CancellationToken cancellationToken = default) => this;
+
+        public ValueTask DisposeAsync() => default;
+
+        public ValueTask<bool> MoveNextAsync() => ValueTask.FromException<bool>(this.exception);
+    }
+}