diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 9608d98c3..03f43c9a2 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -57,6 +57,7 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable /// Minimum protocol version this SDK can communicate with. /// private const int MinProtocolVersion = 3; + private static readonly TimeSpan s_stderrPumpShutdownTimeout = TimeSpan.FromSeconds(5); /// /// Provides a thread-safe collection of active Copilot sessions, indexed by session identifier. @@ -235,6 +236,7 @@ async Task StartCoreAsync(CancellationToken ct) var startTimestamp = Stopwatch.GetTimestamp(); Connection? connection = null; Process? cliProcess = null; + ProcessStderrPump? stderrPump = null; try { @@ -247,10 +249,11 @@ async Task StartCoreAsync(CancellationToken ct) else { // Child process (stdio or TCP) - var (startedProcess, portOrNull, stderrBuffer) = await StartCliServerAsync(ct); + var (startedProcess, portOrNull, startedStderrPump) = await StartCliServerAsync(ct); cliProcess = startedProcess; + stderrPump = startedStderrPump; _actualPort = portOrNull; - connection = await ConnectToServerAsync(cliProcess, portOrNull is null ? null : "localhost", portOrNull, stderrBuffer, ct); + connection = await ConnectToServerAsync(cliProcess, portOrNull is null ? null : "localhost", portOrNull, stderrPump, ct); } LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, @@ -292,7 +295,7 @@ async Task StartCoreAsync(CancellationToken ct) } else if (cliProcess is not null) { - await CleanupCliProcessAsync(cliProcess, errors: null, _logger); + await CleanupCliProcessAsync(cliProcess, stderrPump, errors: null, _logger); } throw; @@ -436,31 +439,56 @@ private async Task CleanupConnectionAsync(Connection ctx, List? error if (ctx.CliProcess is { } childProcess) { - await CleanupCliProcessAsync(childProcess, errors, _logger); + await CleanupCliProcessAsync(childProcess, ctx.StderrPump, errors, _logger); } } - private static async Task CleanupCliProcessAsync(Process childProcess, List? errors, ILogger? logger) + private static async Task CleanupCliProcessAsync(Process childProcess, ProcessStderrPump? stderrPump, List? errors, ILogger? logger) { + stderrPump?.Cancel(); + try { + if (!childProcess.HasExited) + { + childProcess.Kill(entireProcessTree: true); + // Kill is asynchronous; wait for the root CLI process to exit so cleanup callers + // do not observe StopAsync/DisposeAsync completion while it is still tearing down. + await childProcess.WaitForExitAsync(); + } + } + catch (Exception ex) + { + AddCleanupError(errors, ex, logger); + } + + if (stderrPump is not null) + { + var stderrPumpWaitTimestamp = Stopwatch.GetTimestamp(); try { - if (!childProcess.HasExited) + await stderrPump.Completion.WaitAsync(s_stderrPumpShutdownTimeout); + } + catch (TimeoutException ex) + { + if (logger is not null) { - childProcess.Kill(entireProcessTree: true); - await childProcess.WaitForExitAsync(); + LoggingHelpers.LogTiming(logger, LogLevel.Debug, ex, + "Timed out waiting for runtime stderr pump to stop. Elapsed={Elapsed}, Timeout={Timeout}", + stderrPumpWaitTimestamp, + s_stderrPumpShutdownTimeout); } + + AddCleanupError(errors, ex, logger); } - finally + catch (Exception ex) { - childProcess.Dispose(); + AddCleanupError(errors, ex, logger); } } - catch (Exception ex) - { - AddCleanupError(errors, ex, logger); - } + + try { childProcess.Dispose(); } + catch (Exception ex) { AddCleanupError(errors, ex, logger); } } private static void AddCleanupError(List? errors, Exception ex, ILogger? logger) @@ -1655,7 +1683,7 @@ private static bool IsUnsupportedConnectMethod(RemoteRpcException ex) || string.Equals(ex.Message, "Unhandled method connect", StringComparison.Ordinal); } - private async Task<(Process Process, int? DetectedLocalhostTcpPort, StringBuilder StderrBuffer)> StartCliServerAsync(CancellationToken cancellationToken) + private async Task<(Process Process, int? DetectedLocalhostTcpPort, ProcessStderrPump StderrPump)> StartCliServerAsync(CancellationToken cancellationToken) { var options = _options; var logger = _logger; @@ -1779,38 +1807,30 @@ private static bool IsUnsupportedConnectMethod(RemoteRpcException ex) if (telemetry.CaptureContent is { } capture) startInfo.Environment["OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"] = capture ? "true" : "false"; } - Process? cliProcess = null; + var cliProcess = new Process { StartInfo = startInfo }; try { - cliProcess = new Process { StartInfo = startInfo }; var spawnTimestamp = Stopwatch.GetTimestamp(); cliProcess.Start(); LoggingHelpers.LogTiming(logger, LogLevel.Debug, null, "CopilotClient.StartCliServerAsync subprocess spawned. Elapsed={Elapsed}", spawnTimestamp); + } + catch + { + cliProcess.Dispose(); + throw; + } - // Capture stderr for error messages and forward to logger - var stderrBuffer = new StringBuilder(); - var stderrReader = Task.Run(async () => - { - while (true) - { - var line = await cliProcess.StandardError.ReadLineAsync(cancellationToken); - if (line is null) - { - break; - } - - lock (stderrBuffer) - { - stderrBuffer.AppendLine(line); - } - - logger.LogWarning("[CLI] {Line}", line); - } - }, cancellationToken); + ProcessStderrPump? stderrPump = null; + int? detectedLocalhostTcpPort = null; + try + { + // Capture stderr for error messages and forward to logger. + // The pump has its own lifetime token and is later cancelled/observed + // by the owning Connection before the process is disposed. + stderrPump = ProcessStderrPump.Start(cliProcess, logger); - var detectedLocalhostTcpPort = (int?)null; if (!useStdio) { // Wait for port announcement @@ -1818,40 +1838,48 @@ private static bool IsUnsupportedConnectMethod(RemoteRpcException ex) using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); cts.CancelAfter(TimeSpan.FromSeconds(30)); - while (!cts.Token.IsCancellationRequested) + try { - var line = await cliProcess.StandardOutput.ReadLineAsync(cts.Token); - if (line is null) + while (await cliProcess.StandardOutput.ReadLineAsync(cts.Token) is string line) { - await stderrReader; - throw CreateCliExitedException("Runtime process exited unexpectedly", stderrBuffer); + if (logger.IsEnabled(LogLevel.Debug)) + { + logger.LogDebug("[CLI] {Line}", line); + } + + if (ListeningOnPortRegex().Match(line) is { Success: true } match) + { + detectedLocalhostTcpPort = int.Parse(match.Groups[1].Value, CultureInfo.InvariantCulture); + LoggingHelpers.LogTiming(logger, LogLevel.Debug, null, + "CopilotClient.StartCliServerAsync TCP port wait complete. Elapsed={Elapsed}, Port={Port}", + portWaitTimestamp, + detectedLocalhostTcpPort.Value); + break; + } } - if (logger.IsEnabled(LogLevel.Debug)) + if (detectedLocalhostTcpPort is null) { - logger.LogDebug("[CLI] {Line}", line); - } - - if (ListeningOnPortRegex().Match(line) is { Success: true } match) - { - detectedLocalhostTcpPort = int.Parse(match.Groups[1].Value, CultureInfo.InvariantCulture); - LoggingHelpers.LogTiming(logger, LogLevel.Debug, null, - "CopilotClient.StartCliServerAsync TCP port wait complete. Elapsed={Elapsed}, Port={Port}", - portWaitTimestamp, - detectedLocalhostTcpPort.Value); - break; + // The CLI's stdout closed (process exited). Drain stderr + // before throwing so the surfaced exception includes the + // final diagnostic lines. + try { await stderrPump.Completion.WaitAsync(s_stderrPumpShutdownTimeout, CancellationToken.None); } + catch (TimeoutException) { /* best-effort: include whatever was captured */ } + catch (Exception ex) { logger.LogDebug(ex, "Runtime stderr pump faulted while draining"); } + throw CreateCliExitedException("Runtime process exited unexpectedly", stderrPump.Buffer); } } + catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested && cts.IsCancellationRequested) + { + throw CreateCliExitedException("Timed out waiting for Copilot CLI to report its TCP listening port.", stderrPump.Buffer); + } } - return (cliProcess, detectedLocalhostTcpPort, stderrBuffer); + return (cliProcess, detectedLocalhostTcpPort, stderrPump); } catch { - if (cliProcess is not null) - { - await CleanupCliProcessAsync(cliProcess, errors: null, logger); - } + await CleanupCliProcessAsync(cliProcess, stderrPump, errors: null, logger); throw; } @@ -1898,77 +1926,94 @@ private static (string FileName, IEnumerable Args) ResolveCliCommand(str return (cliPath, args); } - private async Task ConnectToServerAsync(Process? cliProcess, string? tcpHost, int? tcpPort, StringBuilder? stderrBuffer, CancellationToken cancellationToken) + private async Task ConnectToServerAsync(Process? cliProcess, string? tcpHost, int? tcpPort, ProcessStderrPump? stderrPump, CancellationToken cancellationToken) { var setupTimestamp = Stopwatch.GetTimestamp(); - Stream inputStream, outputStream; NetworkStream? networkStream = null; + JsonRpc? rpc = null; - if (_connection is StdioRuntimeConnection) + try { - if (cliProcess == null) - { - throw new InvalidOperationException("Runtime process not started"); - } + Stream inputStream, outputStream; - inputStream = cliProcess.StandardOutput.BaseStream; - outputStream = cliProcess.StandardInput.BaseStream; - } - else - { - if (tcpHost is null || tcpPort is null) + if (_connection is StdioRuntimeConnection) { - throw new InvalidOperationException("Cannot connect because TCP host or port are not available"); - } + if (cliProcess == null) + { + throw new InvalidOperationException("Runtime process not started"); + } - var socket = new Socket(SocketType.Stream, ProtocolType.Tcp); - try - { - var tcpConnectTimestamp = Stopwatch.GetTimestamp(); - LogConnectingToCliServer(_logger, tcpHost, tcpPort.Value); - await socket.ConnectAsync(tcpHost, tcpPort.Value, cancellationToken); - LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, - "CopilotClient.ConnectToServerAsync TCP connect complete. Elapsed={Elapsed}, Host={Host}, Port={Port}", - tcpConnectTimestamp, - tcpHost, - tcpPort.Value); + inputStream = cliProcess.StandardOutput.BaseStream; + outputStream = cliProcess.StandardInput.BaseStream; } - catch + else { - socket.Dispose(); - throw; + if (tcpHost is null || tcpPort is null) + { + throw new InvalidOperationException("Cannot connect because TCP host or port are not available"); + } + + var socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + try + { + var tcpConnectTimestamp = Stopwatch.GetTimestamp(); + LogConnectingToCliServer(_logger, tcpHost, tcpPort.Value); + await socket.ConnectAsync(tcpHost, tcpPort.Value, cancellationToken); + LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, + "CopilotClient.ConnectToServerAsync TCP connect complete. Elapsed={Elapsed}, Host={Host}, Port={Port}", + tcpConnectTimestamp, + tcpHost, + tcpPort.Value); + } + catch + { + socket.Dispose(); + throw; + } + + inputStream = outputStream = networkStream = new NetworkStream(socket, ownsSocket: true); } - inputStream = outputStream = networkStream = new NetworkStream(socket, ownsSocket: true); - } + rpc = new JsonRpc( + outputStream, + inputStream, + SerializerOptionsForMessageFormatter, + _logger); + + var handler = new RpcHandler(this); + rpc.SetLocalRpcMethod("session.event", handler.OnSessionEvent); + rpc.SetLocalRpcMethod("session.lifecycle", handler.OnSessionLifecycle); + rpc.SetLocalRpcMethod("userInput.request", handler.OnUserInputRequest); + rpc.SetLocalRpcMethod("exitPlanMode.request", handler.OnExitPlanModeRequest); + rpc.SetLocalRpcMethod("autoModeSwitch.request", handler.OnAutoModeSwitchRequest); + rpc.SetLocalRpcMethod("hooks.invoke", handler.OnHooksInvoke); + rpc.SetLocalRpcMethod("systemMessage.transform", handler.OnSystemMessageTransform); + ClientSessionApiRegistration.RegisterClientSessionApiHandlers(rpc, sessionId => + { + var session = GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + return session.ClientSessionApis; + }); + rpc.StartListening(); + LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, + "CopilotClient.ConnectToServerAsync transport setup complete. Elapsed={Elapsed}", + setupTimestamp); - var rpc = new JsonRpc( - outputStream, - inputStream, - SerializerOptionsForMessageFormatter, - _logger); + _serverRpc = new ServerRpc(rpc); - var handler = new RpcHandler(this); - rpc.SetLocalRpcMethod("session.event", handler.OnSessionEvent); - rpc.SetLocalRpcMethod("session.lifecycle", handler.OnSessionLifecycle); - rpc.SetLocalRpcMethod("userInput.request", handler.OnUserInputRequest); - rpc.SetLocalRpcMethod("exitPlanMode.request", handler.OnExitPlanModeRequest); - rpc.SetLocalRpcMethod("autoModeSwitch.request", handler.OnAutoModeSwitchRequest); - rpc.SetLocalRpcMethod("hooks.invoke", handler.OnHooksInvoke); - rpc.SetLocalRpcMethod("systemMessage.transform", handler.OnSystemMessageTransform); - ClientSessionApiRegistration.RegisterClientSessionApiHandlers(rpc, sessionId => + return new Connection(rpc, cliProcess, networkStream, stderrPump); + } + catch { - var session = GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); - return session.ClientSessionApis; - }); - rpc.StartListening(); - LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, - "CopilotClient.ConnectToServerAsync transport setup complete. Elapsed={Elapsed}", - setupTimestamp); - - _serverRpc = new ServerRpc(rpc); + try { rpc?.Dispose(); } + catch (Exception ex) { _logger.LogDebug(ex, "Failed to dispose JSON-RPC connection after startup failure"); } - return new Connection(rpc, cliProcess, networkStream, stderrBuffer); + if (networkStream is not null) + { + try { await networkStream.DisposeAsync(); } + catch (Exception ex) { _logger.LogDebug(ex, "Failed to dispose TCP stream after startup failure"); } + } + throw; + } } private static JsonSerializerOptions SerializerOptionsForMessageFormatter { get; } = CreateSerializerOptions(); @@ -2155,12 +2200,59 @@ private class Connection( JsonRpc rpc, Process? cliProcess, // Set if we created the child process NetworkStream? networkStream, // Set if using TCP - StringBuilder? stderrBuffer = null) // Captures stderr for error messages + ProcessStderrPump? stderrPump = null) // Captures stderr for error messages { public Process? CliProcess => cliProcess; public JsonRpc Rpc => rpc; public NetworkStream? NetworkStream => networkStream; - public StringBuilder? StderrBuffer => stderrBuffer; + public ProcessStderrPump? StderrPump => stderrPump; + public StringBuilder? StderrBuffer => stderrPump?.Buffer; + } + + private sealed class ProcessStderrPump + { + private readonly CancellationTokenSource _cancellationTokenSource = new(); + private readonly Task _completion; + + private ProcessStderrPump(Process process, ILogger logger) + { + _completion = Task.Run(() => PumpAsync(process, logger, _cancellationTokenSource.Token)); + } + + public StringBuilder Buffer { get; } = new(); + + public Task Completion => _completion; + + public static ProcessStderrPump Start(Process process, ILogger logger) + { + return new ProcessStderrPump(process, logger); + } + + public void Cancel() => _cancellationTokenSource.Cancel(); + + private async Task PumpAsync(Process process, ILogger logger, CancellationToken cancellationToken) + { + try + { + while (await process.StandardError.ReadLineAsync(cancellationToken) is string line) + { + lock (Buffer) + { + Buffer.AppendLine(line); + } + + logger.LogWarning("[CLI] {Line}", line); + } + } + catch (Exception e) when (cancellationToken.IsCancellationRequested + && e is OperationCanceledException or InvalidOperationException or ObjectDisposedException or IOException) + { + } + catch (Exception ex) + { + logger.LogDebug(ex, "Runtime stderr pump stopped unexpectedly"); + } + } } private static class ProcessArgumentEscaper diff --git a/dotnet/src/Polyfills/DownlevelExtensions.cs b/dotnet/src/Polyfills/DownlevelExtensions.cs index 0fdf70f3e..17c98643e 100644 --- a/dotnet/src/Polyfills/DownlevelExtensions.cs +++ b/dotnet/src/Polyfills/DownlevelExtensions.cs @@ -614,4 +614,35 @@ internal static class DownlevelValueTaskExtensions public static ValueTask FromResult(T result) => new(result); } } + + internal static class DownlevelTaskExtensions + { + extension(Task task) + { + public async Task WaitAsync(TimeSpan timeout, CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + + using var delayCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + var completed = await Task.WhenAny(task, Task.Delay(timeout, delayCts.Token)).ConfigureAwait(false); + if (!ReferenceEquals(completed, task)) + { + cancellationToken.ThrowIfCancellationRequested(); + throw new TimeoutException(); + } + + delayCts.Cancel(); + await task.ConfigureAwait(false); + } + } + + extension(Task task) + { + public async Task WaitAsync(TimeSpan timeout, CancellationToken cancellationToken = default) + { + await ((Task)task).WaitAsync(timeout, cancellationToken).ConfigureAwait(false); + return await task.ConfigureAwait(false); + } + } + } } diff --git a/dotnet/test/Polyfills/TaskExtensions.cs b/dotnet/test/Polyfills/TaskExtensions.cs deleted file mode 100644 index 04096e81d..000000000 --- a/dotnet/test/Polyfills/TaskExtensions.cs +++ /dev/null @@ -1,73 +0,0 @@ -/*--------------------------------------------------------------------------------------------- - * Copyright (c) Microsoft Corporation. All rights reserved. - *--------------------------------------------------------------------------------------------*/ - -// Polyfills for Task APIs not available on .NET Framework. -// These are test-only and not optimized for production use. - -#if !NET8_0_OR_GREATER - -using System.Threading; - -namespace System.Threading.Tasks; - -internal static class TestDownlevelTaskExtensions -{ - extension(Task task) - { - public Task WaitAsync(TimeSpan timeout) - { - if (task.IsCompleted) - { - return task; - } - - return WaitAsyncCore(task, timeout); - } - } - - extension(Task task) - { - public Task WaitAsync(TimeSpan timeout) - { - if (task.IsCompleted) - { - return task; - } - - return WaitAsyncCoreGeneric(task, timeout); - } - } - - private static async Task WaitAsyncCore(Task task, TimeSpan timeout) - { - using var cts = new CancellationTokenSource(); - var delayTask = Task.Delay(timeout, cts.Token); - var completedTask = await Task.WhenAny(task, delayTask).ConfigureAwait(false); - if (completedTask == task) - { - cts.Cancel(); - await task.ConfigureAwait(false); - } - else - { - throw new TimeoutException(); - } - } - - private static async Task WaitAsyncCoreGeneric(Task task, TimeSpan timeout) - { - using var cts = new CancellationTokenSource(); - var delayTask = Task.Delay(timeout, cts.Token); - var completedTask = await Task.WhenAny(task, delayTask).ConfigureAwait(false); - if (completedTask == task) - { - cts.Cancel(); - return await task.ConfigureAwait(false); - } - - throw new TimeoutException(); - } -} - -#endif