Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/ModelContextProtocol.Core/Client/StdioClientTransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,8 @@ public async Task<ITransport> ConnectAsync(CancellationToken cancellationToken =

if (logger.IsEnabled(LogLevel.Trace))
{
LogCreateProcessForTransportSensitive(logger, endpointName, _options.Command,
LogCreateProcessForTransportDetailed(logger, endpointName, _options.Command,
startInfo.Arguments,
string.Join(", ", startInfo.Environment.Select(kvp => $"{kvp.Key}={kvp.Value}")),
startInfo.WorkingDirectory);
}
else
Expand Down Expand Up @@ -295,8 +294,8 @@ private static string EscapeArgumentString(string argument) =>
[LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} starting server process. Command: '{Command}'.")]
private static partial void LogCreateProcessForTransport(ILogger logger, string endpointName, string command);

[LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} starting server process. Command: '{Command}', Arguments: {Arguments}, Environment: {Environment}, Working directory: {WorkingDirectory}.")]
private static partial void LogCreateProcessForTransportSensitive(ILogger logger, string endpointName, string command, string? arguments, string environment, string workingDirectory);
[LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} starting server process. Command: '{Command}', Arguments: {Arguments}, Working directory: {WorkingDirectory}.")]
private static partial void LogCreateProcessForTransportDetailed(ILogger logger, string endpointName, string command, string? arguments, string workingDirectory);

[LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} failed to start server process.")]
private static partial void LogTransportProcessStartFailed(ILogger logger, string endpointName);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelContextProtocol.Client;
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Tests.Utils;
using System.IO.Pipelines;
Expand All @@ -12,6 +13,42 @@ public class StdioClientTransportTests(ITestOutputHelper testOutputHelper) : Log
{
public static bool IsStdErrCallbackSupported => !PlatformDetection.IsMonoRuntime;

[Fact]
public async Task ConnectAsync_DoesNotLogEnvironmentVariablesAtTrace()
{
string secretName = $"MCP_TEST_SECRET_{Guid.NewGuid():N}";
string secretValue = $"secret-{Guid.NewGuid():N}";

using var loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder =>
{
builder.AddProvider(MockLoggerProvider);
builder.SetMinimumLevel(LogLevel.Trace);
});

StdioClientTransport transport = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ?
new(new()
{
Command = "cmd.exe",
Arguments = ["/c", "exit /b 0"],
EnvironmentVariables = new Dictionary<string, string?> { [secretName] = secretValue },
}, loggerFactory) :
new(new()
{
Command = "sh",
Arguments = ["-c", "exit 0"],
EnvironmentVariables = new Dictionary<string, string?> { [secretName] = secretValue },
}, loggerFactory);

await using var _ = await transport.ConnectAsync(TestContext.Current.CancellationToken);

Assert.Contains(MockLoggerProvider.LogMessages, log =>
log.LogLevel == LogLevel.Trace &&
log.Message.Contains("starting server process", StringComparison.Ordinal));
Assert.DoesNotContain(MockLoggerProvider.LogMessages, log =>
log.Message.Contains(secretName, StringComparison.Ordinal) ||
log.Message.Contains(secretValue, StringComparison.Ordinal));
}

[Fact]
public async Task CreateAsync_ValidProcessInvalidServer_Throws()
{
Expand Down