diff --git a/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs b/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs index 24a7dba8e..146f3d7ba 100644 --- a/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs @@ -121,9 +121,8 @@ public async Task ConnectAsync(CancellationToken cancellationToken = if (logger.IsEnabled(LogLevel.Trace)) { - LogCreateProcessForTransportSensitive(logger, endpointName, _options.Command, + LogCreateProcessForTransportDetailed(logger, endpointName, _options.Command, startInfo.Arguments, - string.Join(", ", startInfo.Environment.Select(kvp => $"{kvp.Key}={kvp.Value}")), startInfo.WorkingDirectory); } else @@ -295,8 +294,8 @@ private static string EscapeArgumentString(string argument) => [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} starting server process. Command: '{Command}'.")] private static partial void LogCreateProcessForTransport(ILogger logger, string endpointName, string command); - [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} starting server process. Command: '{Command}', Arguments: {Arguments}, Environment: {Environment}, Working directory: {WorkingDirectory}.")] - private static partial void LogCreateProcessForTransportSensitive(ILogger logger, string endpointName, string command, string? arguments, string environment, string workingDirectory); + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} starting server process. Command: '{Command}', Arguments: {Arguments}, Working directory: {WorkingDirectory}.")] + private static partial void LogCreateProcessForTransportDetailed(ILogger logger, string endpointName, string command, string? arguments, string workingDirectory); [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} failed to start server process.")] private static partial void LogTransportProcessStartFailed(ILogger logger, string endpointName); diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs index 1a999fd14..3e39ceabc 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs @@ -1,4 +1,5 @@ -using ModelContextProtocol.Client; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Tests.Utils; using System.IO.Pipelines; @@ -12,6 +13,42 @@ public class StdioClientTransportTests(ITestOutputHelper testOutputHelper) : Log { public static bool IsStdErrCallbackSupported => !PlatformDetection.IsMonoRuntime; + [Fact] + public async Task ConnectAsync_DoesNotLogEnvironmentVariablesAtTrace() + { + string secretName = $"MCP_TEST_SECRET_{Guid.NewGuid():N}"; + string secretValue = $"secret-{Guid.NewGuid():N}"; + + using var loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder => + { + builder.AddProvider(MockLoggerProvider); + builder.SetMinimumLevel(LogLevel.Trace); + }); + + StdioClientTransport transport = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? + new(new() + { + Command = "cmd.exe", + Arguments = ["/c", "exit /b 0"], + EnvironmentVariables = new Dictionary { [secretName] = secretValue }, + }, loggerFactory) : + new(new() + { + Command = "sh", + Arguments = ["-c", "exit 0"], + EnvironmentVariables = new Dictionary { [secretName] = secretValue }, + }, loggerFactory); + + await using var _ = await transport.ConnectAsync(TestContext.Current.CancellationToken); + + Assert.Contains(MockLoggerProvider.LogMessages, log => + log.LogLevel == LogLevel.Trace && + log.Message.Contains("starting server process", StringComparison.Ordinal)); + Assert.DoesNotContain(MockLoggerProvider.LogMessages, log => + log.Message.Contains(secretName, StringComparison.Ordinal) || + log.Message.Contains(secretValue, StringComparison.Ordinal)); + } + [Fact] public async Task CreateAsync_ValidProcessInvalidServer_Throws() {