From c12bfce4378fb79143585e35cec5c8babe1d80e6 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 1 Jun 2026 12:20:45 +0200 Subject: [PATCH 1/4] fix: prevent X-Conversation-Id from mutating cached session on retry --- pkg/chatserver/agent.go | 12 +- pkg/chatserver/conversations_test.go | 6 +- .../conversations_transaction_test.go | 159 ++++++++++++++++++ pkg/chatserver/server.go | 90 ++++++---- pkg/session/branch.go | 64 +++++++ pkg/session/clone_test.go | 129 ++++++++++++++ 6 files changed, 422 insertions(+), 38 deletions(-) create mode 100644 pkg/chatserver/conversations_transaction_test.go create mode 100644 pkg/session/clone_test.go diff --git a/pkg/chatserver/agent.go b/pkg/chatserver/agent.go index c0a083360..f28eda8aa 100644 --- a/pkg/chatserver/agent.go +++ b/pkg/chatserver/agent.go @@ -157,8 +157,11 @@ func convertParts(in []ContentPart) []chat.MessagePart { // appendLatestUser walks msgs from the end and appends only the last // user-role message into sess. Used by conversation continuation, where // the session already contains the full prior history and we just need -// to inject what the client just said. -func appendLatestUser(sess *session.Session, msgs []ChatCompletionMessage) { +// to inject what the client just said. It returns true when a user +// message was found and appended, and false when the request carried no +// usable user message (so the caller can reject it instead of replaying +// the prior turn). +func appendLatestUser(sess *session.Session, msgs []ChatCompletionMessage) bool { for _, m := range slices.Backward(msgs) { role := strings.ToLower(strings.TrimSpace(m.Role)) // Treat any non-system/assistant/tool role as user (matches @@ -173,15 +176,16 @@ func appendLatestUser(sess *session.Session, msgs []ChatCompletionMessage) { Content: m.Content, MultiContent: parts, }}) - return + return true } content := strings.TrimSpace(m.Content) if content == "" { continue } sess.AddMessage(session.UserMessage(m.Content)) - return + return true } + return false } // agentEmit collects the side-effect callbacks invoked by runAgentLoop as diff --git a/pkg/chatserver/conversations_test.go b/pkg/chatserver/conversations_test.go index 038cc0758..629ed6cbd 100644 --- a/pkg/chatserver/conversations_test.go +++ b/pkg/chatserver/conversations_test.go @@ -70,21 +70,23 @@ func TestConversationStore_Delete(t *testing.T) { func TestAppendLatestUser(t *testing.T) { sess := session.New() - appendLatestUser(sess, []ChatCompletionMessage{ + appended := appendLatestUser(sess, []ChatCompletionMessage{ {Role: "system", Content: "be helpful"}, {Role: "user", Content: "first"}, {Role: "assistant", Content: "ack"}, {Role: "user", Content: "second"}, {Role: "tool", Content: "tool result", ToolCallID: "x"}, }) + assert.True(t, appended) assert.Equal(t, "second", sess.GetLastUserMessageContent()) } func TestAppendLatestUser_NoUserMessage(t *testing.T) { sess := session.New() - appendLatestUser(sess, []ChatCompletionMessage{ + appended := appendLatestUser(sess, []ChatCompletionMessage{ {Role: "system", Content: "be helpful"}, {Role: "assistant", Content: "ack"}, }) + assert.False(t, appended) assert.Empty(t, sess.GetLastUserMessageContent()) } diff --git a/pkg/chatserver/conversations_transaction_test.go b/pkg/chatserver/conversations_transaction_test.go new file mode 100644 index 000000000..c226a0b17 --- /dev/null +++ b/pkg/chatserver/conversations_transaction_test.go @@ -0,0 +1,159 @@ +package chatserver + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/session" +) + +// newConvServer builds a server with a conversation store wired up but no +// team/runtime, for exercising the transactional session plumbing without +// running the agent loop. +func newConvServer(t *testing.T) *server { + t.Helper() + return &server{ + policy: agentPolicy{exposed: []string{"root"}, fallback: "root"}, + conversations: newConversationStore(8, time.Hour), + conversationLocks: newConversationLockSet(), + } +} + +// TestResolveSession_WorksOnClone verifies that continuing a cached +// conversation mutates a copy, leaving the cached session untouched until +// the caller commits. +func TestResolveSession_WorksOnClone(t *testing.T) { + s := newConvServer(t) + + seed := session.New() + seed.AddMessage(session.UserMessage("first")) + s.conversations.Put("conv-1", seed) + + working, err := s.resolveSession("conv-1", []ChatCompletionMessage{ + {Role: "user", Content: "second"}, + }) + require.NoError(t, err) + require.NotNil(t, working) + + // The working copy carries the new turn... + assert.Equal(t, "second", working.GetLastUserMessageContent()) + assert.NotSame(t, seed, working, "must not hand back the cached pointer") + + // ...but the cached session is still at the prior state. + cached := s.conversations.Get("conv-1") + require.NotNil(t, cached) + assert.Same(t, seed, cached) + assert.Equal(t, "first", cached.GetLastUserMessageContent()) + assert.Equal(t, 1, cached.MessageCount()) +} + +// TestResolveSession_RejectsContinuationWithoutUser verifies that a +// continuation request carrying no new user message is rejected rather than +// silently replaying the prior turn. +func TestResolveSession_RejectsContinuationWithoutUser(t *testing.T) { + s := newConvServer(t) + + seed := session.New() + seed.AddMessage(session.UserMessage("first")) + s.conversations.Put("conv-1", seed) + + _, err := s.resolveSession("conv-1", []ChatCompletionMessage{ + {Role: "system", Content: "be helpful"}, + {Role: "assistant", Content: "ack"}, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "no user message") +} + +// TestCommitConversation_FailedRunDoesNotAdvance verifies that a failed run +// leaves the cached conversation at its prior state, so a retry runs against +// the last successful turn instead of inheriting half-failed history. +func TestCommitConversation_FailedRunDoesNotAdvance(t *testing.T) { + s := newConvServer(t) + + seed := session.New() + seed.AddMessage(session.UserMessage("first")) + s.conversations.Put("conv-1", seed) + + working, err := s.resolveSession("conv-1", []ChatCompletionMessage{ + {Role: "user", Content: "second"}, + }) + require.NoError(t, err) + + // The run failed: the working copy must not be committed. + s.commitConversation("conv-1", working, errors.New("boom")) + + cached := s.conversations.Get("conv-1") + require.NotNil(t, cached) + assert.Same(t, seed, cached, "cache must still hold the pre-failure session") + assert.Equal(t, "first", cached.GetLastUserMessageContent()) + assert.Equal(t, 1, cached.MessageCount()) +} + +// TestCommitConversation_SuccessfulRunAdvances verifies that a successful run +// commits the working copy back into the cache. +func TestCommitConversation_SuccessfulRunAdvances(t *testing.T) { + s := newConvServer(t) + + seed := session.New() + seed.AddMessage(session.UserMessage("first")) + s.conversations.Put("conv-1", seed) + + working, err := s.resolveSession("conv-1", []ChatCompletionMessage{ + {Role: "user", Content: "second"}, + }) + require.NoError(t, err) + + s.commitConversation("conv-1", working, nil) + + cached := s.conversations.Get("conv-1") + require.NotNil(t, cached) + assert.Same(t, working, cached, "cache must hold the committed working copy") + assert.Equal(t, "second", cached.GetLastUserMessageContent()) + assert.Equal(t, 2, cached.MessageCount()) +} + +// TestCommitConversation_RestoresAfterEviction verifies that a successful run +// restores a conversation that was evicted while the request was in flight. +func TestCommitConversation_RestoresAfterEviction(t *testing.T) { + s := newConvServer(t) + + seed := session.New() + seed.AddMessage(session.UserMessage("first")) + s.conversations.Put("conv-1", seed) + + working, err := s.resolveSession("conv-1", []ChatCompletionMessage{ + {Role: "user", Content: "second"}, + }) + require.NoError(t, err) + + // Evict while the request is "in flight". + s.conversations.Delete("conv-1") + require.Nil(t, s.conversations.Get("conv-1")) + + s.commitConversation("conv-1", working, nil) + + cached := s.conversations.Get("conv-1") + require.NotNil(t, cached) + assert.Equal(t, "second", cached.GetLastUserMessageContent()) +} + +// TestResolveSession_NewConversation verifies that a request without a cached +// conversation builds a fresh session from the full history. +func TestResolveSession_NewConversation(t *testing.T) { + s := newConvServer(t) + + working, err := s.resolveSession("conv-new", []ChatCompletionMessage{ + {Role: "user", Content: "hello"}, + }) + require.NoError(t, err) + require.NotNil(t, working) + assert.Equal(t, "hello", working.GetLastUserMessageContent()) + + // Nothing is cached until the run commits. + assert.Nil(t, s.conversations.Get("conv-new")) +} diff --git a/pkg/chatserver/server.go b/pkg/chatserver/server.go index e5fdc953c..e2aba26f2 100644 --- a/pkg/chatserver/server.go +++ b/pkg/chatserver/server.go @@ -376,9 +376,9 @@ func (s *server) handleChatCompletions(c echo.Context) error { } defer s.conversationLocks.release(conversationID) - sess := s.resolveSession(conversationID, req.Messages) - if sess == nil { - return writeError(c, http.StatusBadRequest, "no user message provided") + sess, err := s.resolveSession(conversationID, req.Messages) + if err != nil { + return writeError(c, http.StatusBadRequest, err.Error()) } agentName := s.policy.pick(req.Model) @@ -396,45 +396,68 @@ func (s *server) handleChatCompletions(c echo.Context) error { } if req.Stream { - err := s.streamChatCompletion(c, rt, sess, model, req.StreamOptions.IncludeUsage) - s.maybeStoreConversation(conversationID, sess) - return err + runErr := s.streamChatCompletion(c, rt, sess, model, req.StreamOptions.IncludeUsage) + s.commitConversation(conversationID, sess, runErr) + // The agent run outcome is reported in-band (SSE error event for + // streams, JSON error envelope otherwise), so the HTTP handler + // itself always succeeds once we've started writing the response. + return nil } - err = s.chatCompletion(c, rt, sess, model) - s.maybeStoreConversation(conversationID, sess) - return err + runErr := s.chatCompletion(c, rt, sess, model) + s.commitConversation(conversationID, sess, runErr) + return nil } // resolveSession decides whether to start fresh or continue an existing // conversation. When X-Conversation-Id is set and we have an existing -// session for it, we append only the latest user message from the -// request (the prior history is already in the session). Otherwise we -// build a brand-new session from the full request history. -func (s *server) resolveSession(id string, msgs []ChatCompletionMessage) *session.Session { +// session for it, we work on a deep copy and append only the latest user +// message from the request (the prior history is already in the +// session). The cached session is left untouched until the run succeeds +// (see commitConversation), so a failed turn never advances the +// canonical conversation state. Otherwise we build a brand-new session +// from the full request history. +// +// Returns an error when the request carries no usable user message, in +// which case the caller rejects the request rather than replaying the +// prior turn. +func (s *server) resolveSession(id string, msgs []ChatCompletionMessage) (*session.Session, error) { if id != "" { if existing := s.conversations.Get(id); existing != nil { - appendLatestUser(existing, msgs) - return existing + working := existing.Clone() + if !appendLatestUser(working, msgs) { + return nil, errors.New("no user message provided") + } + return working, nil } } - return buildSession(msgs) + sess := buildSession(msgs) + if sess == nil { + return nil, errors.New("no user message provided") + } + return sess, nil } -// maybeStoreConversation inserts the session into the cache after a -// run. We always insert to handle the case where the conversation was -// evicted while the request was in flight. -func (s *server) maybeStoreConversation(id string, sess *session.Session) { - if id == "" || s.conversations == nil { +// commitConversation stores the session into the cache after a run, but +// only when the run succeeded. A failed turn must not advance the cached +// conversation state: the working session was a clone, so leaving the +// cache untouched means a retry runs against the last successful state. +// +// We always Put on success, even for existing conversations, to handle +// the case where the conversation was evicted while the request was in +// flight. Put refreshes the lastUsed timestamp and stores the updated +// session. +func (s *server) commitConversation(id string, sess *session.Session, runErr error) { + if id == "" || s.conversations == nil || runErr != nil { return } - // Always Put, even for existing conversations, to handle eviction - // during request processing. Put refreshes the lastUsed timestamp - // and ensures the updated session is stored. s.conversations.Put(id, sess) } // chatCompletion runs the agent to completion and replies with one -// non-streaming OpenAI ChatCompletion object. +// non-streaming OpenAI ChatCompletion object. It returns the agent run +// error (nil on success) so the caller can decide whether to commit the +// conversation; the HTTP response — success or error envelope — is always +// written here. func (s *server) chatCompletion(c echo.Context, rt runtime.Runtime, sess *session.Session, model string) error { var toolCalls []ToolCallReference emit := agentEmit{ @@ -443,10 +466,11 @@ func (s *server) chatCompletion(c echo.Context, rt runtime.Runtime, sess *sessio }, } if err := runAgentLoop(c.Request().Context(), rt, sess, emit); err != nil { - return writeError(c, http.StatusInternalServerError, fmt.Sprintf("agent execution failed: %v", err)) + _ = writeError(c, http.StatusInternalServerError, fmt.Sprintf("agent execution failed: %v", err)) + return err } - return c.JSON(http.StatusOK, ChatCompletionResponse{ + _ = c.JSON(http.StatusOK, ChatCompletionResponse{ ID: newChatID(), Object: "chat.completion", Created: time.Now().Unix(), @@ -462,15 +486,17 @@ func (s *server) chatCompletion(c echo.Context, rt runtime.Runtime, sess *sessio }}, Usage: sessionUsage(sess), }) + return nil } // streamChatCompletion runs the agent and streams its response back to the // client as Server-Sent Events in OpenAI's chat.completion.chunk format. // -// The error return is reserved for future use (e.g. surfacing a write -// failure to the request logger). Today every error is converted into an -// in-band SSE error event, so the function always returns nil. -func (s *server) streamChatCompletion(c echo.Context, rt runtime.Runtime, sess *session.Session, model string, includeUsage bool) error { //nolint:unparam // see comment +// It returns the agent run error (nil on success) so the caller can decide +// whether to commit the conversation. The error is *also* reported in-band +// as an SSE error event, so the HTTP handler itself still returns nil; the +// return value here exists purely to drive the commit decision. +func (s *server) streamChatCompletion(c echo.Context, rt runtime.Runtime, sess *session.Session, model string, includeUsage bool) error { stream := newSSEStream(c.Response(), newChatID(), model) // Initial "role: assistant" delta so clients can start rendering. @@ -510,7 +536,7 @@ func (s *server) streamChatCompletion(c echo.Context, rt runtime.Runtime, sess * } } stream.done() - return nil + return runErr } // sseStream writes OpenAI-style chat.completion.chunk events to a response. diff --git a/pkg/session/branch.go b/pkg/session/branch.go index a7442a1ec..2dbbafb80 100644 --- a/pkg/session/branch.go +++ b/pkg/session/branch.go @@ -37,6 +37,70 @@ func BranchSession(parent *Session, branchAtPosition int) (*Session, error) { return branched, nil } +// Clone returns a deep copy of the session that is safe to mutate without +// affecting the original. Conversation items (messages and sub-sessions) +// are deep-cloned so the two sessions share no mutable state; scalar and +// configuration fields are copied verbatim so the clone runs identically. +// Unlike BranchSession, Clone keeps the original ID, message IDs, and the +// full message history, making it suitable for transactional "work on a +// copy, commit on success" flows. +func (s *Session) Clone() *Session { + if s == nil { + return nil + } + + s.mu.RLock() + defer s.mu.RUnlock() + + clone := &Session{ + ID: s.ID, + InputID: s.InputID, + Title: s.Title, + Evals: s.Evals, + EvalResult: s.EvalResult, + CreatedAt: s.CreatedAt, + ToolsApproved: s.ToolsApproved, + NonInteractive: s.NonInteractive, + HideToolResults: s.HideToolResults, + WorkingDir: s.WorkingDir, + SendUserMessage: s.SendUserMessage, + MaxIterations: s.MaxIterations, + MaxConsecutiveToolCalls: s.MaxConsecutiveToolCalls, + MaxOldToolCallTokens: s.MaxOldToolCallTokens, + Starred: s.Starred, + InputTokens: s.InputTokens, + OutputTokens: s.OutputTokens, + Cost: s.Cost, + Permissions: clonePermissionsConfig(s.Permissions), + AgentModelOverrides: cloneStringMap(s.AgentModelOverrides), + CustomModelsUsed: cloneStringSlice(s.CustomModelsUsed), + AttachedFiles: cloneStringSlice(s.AttachedFiles), + ExcludedTools: cloneStringSlice(s.ExcludedTools), + AgentName: s.AgentName, + ParentID: s.ParentID, + MessageUsageHistory: slices.Clone(s.MessageUsageHistory), + } + + clone.Messages = make([]Item, 0, len(s.Messages)) + for _, item := range s.Messages { + switch { + case item.Message != nil: + clone.Messages = append(clone.Messages, Item{Message: cloneMessage(item.Message)}) + case item.SubSession != nil: + sub, err := cloneSubSession(item.SubSession) + if err != nil { + continue + } + clone.Messages = append(clone.Messages, Item{SubSession: sub}) + default: + // Summary/Cost/FirstKeptEntry are value fields; a shallow + // copy of the item is already a faithful clone. + clone.Messages = append(clone.Messages, item) + } + } + return clone +} + func cloneSessionItem(item Item) (Item, error) { switch { case item.Message != nil: diff --git a/pkg/session/clone_test.go b/pkg/session/clone_test.go new file mode 100644 index 000000000..1fe8c177e --- /dev/null +++ b/pkg/session/clone_test.go @@ -0,0 +1,129 @@ +package session + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/chat" +) + +func TestClone_NilSession(t *testing.T) { + var s *Session + assert.Nil(t, s.Clone()) +} + +func TestClone_CopiesScalarFields(t *testing.T) { + orig := &Session{ + ID: "sess-1", + Title: "title", + ToolsApproved: true, + NonInteractive: true, + HideToolResults: true, + WorkingDir: "/work", + SendUserMessage: true, + MaxIterations: 7, + MaxConsecutiveToolCalls: 3, + MaxOldToolCallTokens: 99, + Starred: true, + InputTokens: 11, + OutputTokens: 22, + Cost: 1.5, + Permissions: &PermissionsConfig{Allow: []string{"a"}, Deny: []string{"d"}}, + AgentModelOverrides: map[string]string{"root": "openai/gpt-4o"}, + CustomModelsUsed: []string{"openai/gpt-4o"}, + AttachedFiles: []string{"/abs/path.txt"}, + ExcludedTools: []string{"run_skill"}, + AgentName: "root", + ParentID: "parent", + } + orig.AddMessage(UserMessage("hello")) + + clone := orig.Clone() + require.NotNil(t, clone) + + // Unlike BranchSession, Clone keeps the original identity and history. + assert.Equal(t, "sess-1", clone.ID) + assert.Equal(t, "title", clone.Title) + assert.True(t, clone.ToolsApproved) + assert.True(t, clone.NonInteractive) + assert.True(t, clone.HideToolResults) + assert.Equal(t, "/work", clone.WorkingDir) + assert.Equal(t, 7, clone.MaxIterations) + assert.Equal(t, 3, clone.MaxConsecutiveToolCalls) + assert.Equal(t, 99, clone.MaxOldToolCallTokens) + assert.True(t, clone.Starred) + assert.Equal(t, int64(11), clone.InputTokens) + assert.Equal(t, int64(22), clone.OutputTokens) + assert.InEpsilon(t, 1.5, clone.Cost, 1e-9) + assert.Equal(t, "root", clone.AgentName) + assert.Equal(t, "parent", clone.ParentID) + assert.Equal(t, "hello", clone.GetLastUserMessageContent()) +} + +func TestClone_DeepCopiesMessagesAndConfig(t *testing.T) { + orig := &Session{ + ID: "sess-1", + Permissions: &PermissionsConfig{Allow: []string{"a"}}, + AgentModelOverrides: map[string]string{"root": "m1"}, + CustomModelsUsed: []string{"m1"}, + AttachedFiles: []string{"/abs/a.txt"}, + } + orig.AddMessage(&Message{Message: chat.Message{ + Role: chat.MessageRoleUser, + MultiContent: []chat.MessagePart{{ + Type: chat.MessagePartTypeImageURL, + ImageURL: &chat.MessageImageURL{URL: "http://orig"}, + }}, + }}) + + clone := orig.Clone() + require.NotNil(t, clone) + + // Mutate the clone's deep-copied structures; the original must not change. + clone.Permissions.Allow[0] = "mutated" + clone.AgentModelOverrides["root"] = "mutated" + clone.CustomModelsUsed[0] = "mutated" + clone.AttachedFiles[0] = "/abs/mutated.txt" + clone.Messages[0].Message.Message.MultiContent[0].ImageURL.URL = "http://mutated" + + assert.Equal(t, "a", orig.Permissions.Allow[0]) + assert.Equal(t, "m1", orig.AgentModelOverrides["root"]) + assert.Equal(t, "m1", orig.CustomModelsUsed[0]) + assert.Equal(t, "/abs/a.txt", orig.AttachedFiles[0]) + assert.Equal(t, "http://orig", orig.Messages[0].Message.Message.MultiContent[0].ImageURL.URL) +} + +func TestClone_AppendingDoesNotAffectOriginal(t *testing.T) { + orig := New() + orig.AddMessage(UserMessage("first")) + + clone := orig.Clone() + clone.AddMessage(UserMessage("second")) + + assert.Equal(t, 1, orig.MessageCount()) + assert.Equal(t, 2, clone.MessageCount()) + assert.Equal(t, "first", orig.GetLastUserMessageContent()) + assert.Equal(t, "second", clone.GetLastUserMessageContent()) +} + +func TestClone_PreservesSubSessionAndSummary(t *testing.T) { + sub := New() + sub.AddMessage(UserMessage("sub message")) + + orig := New() + orig.AddMessage(UserMessage("first")) + orig.AddSubSession(sub) + orig.Messages = append(orig.Messages, Item{Summary: "a summary", Cost: 0.25}) + + clone := orig.Clone() + require.Len(t, clone.Messages, 3) + + assert.Equal(t, "first", clone.Messages[0].Message.Message.Content) + require.NotNil(t, clone.Messages[1].SubSession) + assert.NotSame(t, sub, clone.Messages[1].SubSession) + assert.Equal(t, "sub message", clone.Messages[1].SubSession.GetLastUserMessageContent()) + assert.Equal(t, "a summary", clone.Messages[2].Summary) + assert.InEpsilon(t, 0.25, clone.Messages[2].Cost, 1e-9) +} From da7a5f5d101ca66dc339d5100b8ed3b3074a8324 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 1 Jun 2026 12:27:30 +0200 Subject: [PATCH 2/4] fix: preserve item value fields and Ask permission in Session.Clone --- pkg/session/branch.go | 29 ++++++++++++++--------------- pkg/session/clone_test.go | 28 ++++++++++++++++++++++++++-- 2 files changed, 40 insertions(+), 17 deletions(-) diff --git a/pkg/session/branch.go b/pkg/session/branch.go index 2dbbafb80..c84abdbcf 100644 --- a/pkg/session/branch.go +++ b/pkg/session/branch.go @@ -81,21 +81,19 @@ func (s *Session) Clone() *Session { MessageUsageHistory: slices.Clone(s.MessageUsageHistory), } - clone.Messages = make([]Item, 0, len(s.Messages)) - for _, item := range s.Messages { - switch { - case item.Message != nil: - clone.Messages = append(clone.Messages, Item{Message: cloneMessage(item.Message)}) - case item.SubSession != nil: - sub, err := cloneSubSession(item.SubSession) - if err != nil { - continue - } - clone.Messages = append(clone.Messages, Item{SubSession: sub}) - default: - // Summary/Cost/FirstKeptEntry are value fields; a shallow - // copy of the item is already a faithful clone. - clone.Messages = append(clone.Messages, item) + // Start from a shallow copy of each item so value fields (Summary, + // Cost, FirstKeptEntry) are preserved verbatim, then deep-copy the + // pointer fields so the clone shares no mutable state. Sub-sessions + // recurse through Clone to stay faithful (unlike BranchSession's + // helper, which mints fresh IDs and resets metadata). + clone.Messages = make([]Item, len(s.Messages)) + for i, item := range s.Messages { + clone.Messages[i] = item + if item.Message != nil { + clone.Messages[i].Message = cloneMessage(item.Message) + } + if item.SubSession != nil { + clone.Messages[i].SubSession = item.SubSession.Clone() } } return clone @@ -199,6 +197,7 @@ func clonePermissionsConfig(src *PermissionsConfig) *PermissionsConfig { } return &PermissionsConfig{ Allow: cloneStringSlice(src.Allow), + Ask: cloneStringSlice(src.Ask), Deny: cloneStringSlice(src.Deny), } } diff --git a/pkg/session/clone_test.go b/pkg/session/clone_test.go index 1fe8c177e..88c45fd52 100644 --- a/pkg/session/clone_test.go +++ b/pkg/session/clone_test.go @@ -30,7 +30,7 @@ func TestClone_CopiesScalarFields(t *testing.T) { InputTokens: 11, OutputTokens: 22, Cost: 1.5, - Permissions: &PermissionsConfig{Allow: []string{"a"}, Deny: []string{"d"}}, + Permissions: &PermissionsConfig{Allow: []string{"a"}, Ask: []string{"k"}, Deny: []string{"d"}}, AgentModelOverrides: map[string]string{"root": "openai/gpt-4o"}, CustomModelsUsed: []string{"openai/gpt-4o"}, AttachedFiles: []string{"/abs/path.txt"}, @@ -60,12 +60,16 @@ func TestClone_CopiesScalarFields(t *testing.T) { assert.Equal(t, "root", clone.AgentName) assert.Equal(t, "parent", clone.ParentID) assert.Equal(t, "hello", clone.GetLastUserMessageContent()) + require.NotNil(t, clone.Permissions) + assert.Equal(t, []string{"a"}, clone.Permissions.Allow) + assert.Equal(t, []string{"k"}, clone.Permissions.Ask) + assert.Equal(t, []string{"d"}, clone.Permissions.Deny) } func TestClone_DeepCopiesMessagesAndConfig(t *testing.T) { orig := &Session{ ID: "sess-1", - Permissions: &PermissionsConfig{Allow: []string{"a"}}, + Permissions: &PermissionsConfig{Allow: []string{"a"}, Ask: []string{"k"}}, AgentModelOverrides: map[string]string{"root": "m1"}, CustomModelsUsed: []string{"m1"}, AttachedFiles: []string{"/abs/a.txt"}, @@ -83,12 +87,14 @@ func TestClone_DeepCopiesMessagesAndConfig(t *testing.T) { // Mutate the clone's deep-copied structures; the original must not change. clone.Permissions.Allow[0] = "mutated" + clone.Permissions.Ask[0] = "mutated" clone.AgentModelOverrides["root"] = "mutated" clone.CustomModelsUsed[0] = "mutated" clone.AttachedFiles[0] = "/abs/mutated.txt" clone.Messages[0].Message.Message.MultiContent[0].ImageURL.URL = "http://mutated" assert.Equal(t, "a", orig.Permissions.Allow[0]) + assert.Equal(t, "k", orig.Permissions.Ask[0]) assert.Equal(t, "m1", orig.AgentModelOverrides["root"]) assert.Equal(t, "m1", orig.CustomModelsUsed[0]) assert.Equal(t, "/abs/a.txt", orig.AttachedFiles[0]) @@ -127,3 +133,21 @@ func TestClone_PreservesSubSessionAndSummary(t *testing.T) { assert.Equal(t, "a summary", clone.Messages[2].Summary) assert.InEpsilon(t, 0.25, clone.Messages[2].Cost, 1e-9) } + +// TestClone_PreservesItemValueFields guards against a clone that rebuilds +// items field-by-field and silently drops the per-item Cost / FirstKeptEntry +// that can ride alongside a message. +func TestClone_PreservesItemValueFields(t *testing.T) { + orig := New() + orig.Messages = []Item{{ + Message: UserMessage("hello"), + Cost: 0.5, + FirstKeptEntry: 3, + }} + + clone := orig.Clone() + require.Len(t, clone.Messages, 1) + assert.Equal(t, "hello", clone.Messages[0].Message.Message.Content) + assert.InEpsilon(t, 0.5, clone.Messages[0].Cost, 1e-9) + assert.Equal(t, 3, clone.Messages[0].FirstKeptEntry) +} From e9e57f0cbf9c23cb2bf640c50a5b558043b1e616 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 1 Jun 2026 13:29:59 +0200 Subject: [PATCH 3/4] e2e: test conversation state handling across failed turns --- e2e/chatserver_test.go | 306 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 306 insertions(+) create mode 100644 e2e/chatserver_test.go diff --git a/e2e/chatserver_test.go b/e2e/chatserver_test.go new file mode 100644 index 000000000..8d9948c11 --- /dev/null +++ b/e2e/chatserver_test.go @@ -0,0 +1,306 @@ +package e2e_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "slices" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/cmd/root" +) + +func TestServeChatConversationFailedTurnDoesNotAdvanceCache(t *testing.T) { + modelServer := newRecordingChatCompletionsServer(t) + + agentFile := filepath.Join(t.TempDir(), "agent.yaml") + agentYAML := fmt.Appendf(nil, `version: "9" + +providers: + e2e: + api_type: openai_chatcompletions + base_url: %s/v1 + +models: + e2e-model: + provider: e2e + model: e2e-model + max_tokens: 64 + +agents: + root: + model: e2e-model + description: E2E chat server agent + instruction: Reply concisely. +`, modelServer.URL()) + require.NoError(t, os.WriteFile(agentFile, agentYAML, 0o644)) + + addr := freeTCPAddr(t) + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + os.Unsetenv("DOCKER_CLI_PLUGIN_ORIGINAL_CLI_COMMAND") + t.Setenv("DOCKER_AGENT_MODELS_GATEWAY", "") + t.Setenv("CAGENT_MODELS_GATEWAY", "") + t.Setenv("OPENAI_API_KEY", "DUMMY") + + var stdout, stderr bytes.Buffer + errCh := make(chan error, 1) + go func() { + errCh <- root.Execute(ctx, nil, &stdout, &stderr, + "--cache-dir", filepath.Join(t.TempDir(), "cache"), + "--config-dir", filepath.Join(t.TempDir(), "config"), + "--data-dir", filepath.Join(t.TempDir(), "data"), + "serve", "chat", + "--listen", addr, + "--conversations-max", "10", + "--request-timeout", "2s", + agentFile, + ) + }() + baseURL := "http://" + addr + waitForChatServer(t, baseURL) + defer func() { + cancel() + select { + case err := <-errCh: + if err != nil && !strings.Contains(err.Error(), "use of closed network connection") { + require.NoError(t, err, "stdout: %s\nstderr: %s", stdout.String(), stderr.String()) + } + case <-time.After(5 * time.Second): + t.Fatal("chat server did not stop") + } + }() + + convID := "e2e-failed-turn" + postChatCompletion(t, baseURL, convID, http.StatusOK, "first") + postChatCompletion(t, baseURL, convID, http.StatusInternalServerError, "please fail") + postChatCompletion(t, baseURL, convID, http.StatusOK, "after failure") + + requests := modelServer.requests() + require.GreaterOrEqual(t, len(requests), 3) + assert.Equal(t, []string{"first"}, requests[0].userMessages()) + for _, req := range requests[1 : len(requests)-1] { + assert.Equal(t, []string{"first", "please fail"}, req.userMessages()) + } + + // This is the end-to-end assertion for #2890: the failed "please fail" + // turn must not have been committed to the X-Conversation-Id cache, so the + // following successful turn resumes from the last successful state. + assert.Equal(t, []string{"first", "after failure"}, requests[len(requests)-1].userMessages()) +} + +type recordingChatCompletionsServer struct { + server *httptest.Server + mu sync.Mutex + reqs []recordedChatCompletionRequest +} + +type recordedChatCompletionRequest struct { + Messages []recordedChatMessage `json:"messages"` +} + +type recordedChatMessage struct { + Role string `json:"role"` + Content any `json:"content"` +} + +func newRecordingChatCompletionsServer(t *testing.T) *recordingChatCompletionsServer { + t.Helper() + + rec := &recordingChatCompletionsServer{} + rec.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/chat/completions" { + http.NotFound(w, r) + return + } + + var req recordedChatCompletionRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + rec.mu.Lock() + rec.reqs = append(rec.reqs, req) + rec.mu.Unlock() + + if lastUserMessage(req.Messages) == "please fail" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + _, _ = io.WriteString(w, `{"error":{"message":"forced failure","type":"server_error"}}`) + return + } + + writeChatCompletionsStream(w, "ok: "+lastUserMessage(req.Messages)) + })) + t.Cleanup(rec.server.Close) + return rec +} + +func (s *recordingChatCompletionsServer) requests() []recordedChatCompletionRequest { + s.mu.Lock() + defer s.mu.Unlock() + return append([]recordedChatCompletionRequest(nil), s.reqs...) +} + +func (s *recordingChatCompletionsServer) URL() string { + return s.server.URL +} + +func (r recordedChatCompletionRequest) userMessages() []string { + var out []string + for _, msg := range r.Messages { + if msg.Role == "user" { + out = append(out, messageContentText(msg.Content)) + } + } + return out +} + +func lastUserMessage(messages []recordedChatMessage) string { + for _, message := range slices.Backward(messages) { + if message.Role == "user" { + return messageContentText(message.Content) + } + } + return "" +} + +func messageContentText(content any) string { + switch v := content.(type) { + case string: + return v + case []any: + var b strings.Builder + for _, part := range v { + m, ok := part.(map[string]any) + if !ok || m["type"] != "text" { + continue + } + if text, ok := m["text"].(string); ok { + b.WriteString(text) + } + } + return b.String() + default: + return fmt.Sprint(v) + } +} + +func writeChatCompletionsStream(w http.ResponseWriter, content string) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.WriteHeader(http.StatusOK) + writeSSEData(w, map[string]any{ + "id": "chatcmpl-e2e", + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": "e2e-model", + "choices": []map[string]any{{ + "index": 0, + "delta": map[string]any{"role": "assistant", "content": content}, + }}, + }) + writeSSEData(w, map[string]any{ + "id": "chatcmpl-e2e", + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": "e2e-model", + "choices": []map[string]any{{ + "index": 0, + "delta": map[string]any{}, + "finish_reason": "stop", + }}, + }) + writeSSEData(w, map[string]any{ + "id": "chatcmpl-e2e", + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": "e2e-model", + "choices": []map[string]any{}, + "usage": map[string]any{ + "prompt_tokens": 1, + "completion_tokens": 1, + "total_tokens": 2, + }, + }) + _, _ = io.WriteString(w, "data: [DONE]\n\n") + if f, ok := w.(http.Flusher); ok { + f.Flush() + } +} + +func writeSSEData(w http.ResponseWriter, payload any) { + data, err := json.Marshal(payload) + if err != nil { + return + } + _, _ = fmt.Fprintf(w, "data: %s\n\n", data) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } +} + +func freeTCPAddr(t *testing.T) string { + t.Helper() + var lc net.ListenConfig + ln, err := lc.Listen(t.Context(), "tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + return ln.Addr().String() +} + +func waitForChatServer(t *testing.T, baseURL string) { + t.Helper() + client := &http.Client{Timeout: 500 * time.Millisecond} + require.Eventually(t, func() bool { + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, baseURL+"/v1/models", http.NoBody) + if err != nil { + return false + } + resp, err := client.Do(req) + if err != nil { + return false + } + defer resp.Body.Close() + return resp.StatusCode == http.StatusOK + }, 5*time.Second, 50*time.Millisecond) +} + +func postChatCompletion(t *testing.T, baseURL, conversationID string, expectedStatus int, userMessage string) { + t.Helper() + body, err := json.Marshal(map[string]any{ + "model": "root", + "messages": []map[string]string{{ + "role": "user", + "content": userMessage, + }}, + }) + require.NoError(t, err) + + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, baseURL+"/v1/chat/completions", bytes.NewReader(body)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Conversation-Id", conversationID) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + responseBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, expectedStatus, resp.StatusCode, "response body: %s", string(responseBody)) +} From 72b361c2a50e47925fa43a4124aefee83edd892c Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 1 Jun 2026 13:57:27 +0200 Subject: [PATCH 4/4] fix: deep-copy Evals, EvalResult, and ToolDefinitions in session clones --- pkg/chatserver/server.go | 3 +- pkg/session/branch.go | 42 +++++++++++++++++++++++-- pkg/session/clone_test.go | 64 +++++++++++++++++++++++++++++++++++++++ pkg/session/session.go | 48 ++++++++++++++++++++++++++++- 4 files changed, 152 insertions(+), 5 deletions(-) diff --git a/pkg/chatserver/server.go b/pkg/chatserver/server.go index e2aba26f2..b0fdf14da 100644 --- a/pkg/chatserver/server.go +++ b/pkg/chatserver/server.go @@ -470,7 +470,7 @@ func (s *server) chatCompletion(c echo.Context, rt runtime.Runtime, sess *sessio return err } - _ = c.JSON(http.StatusOK, ChatCompletionResponse{ + return c.JSON(http.StatusOK, ChatCompletionResponse{ ID: newChatID(), Object: "chat.completion", Created: time.Now().Unix(), @@ -486,7 +486,6 @@ func (s *server) chatCompletion(c echo.Context, rt runtime.Runtime, sess *sessio }}, Usage: sessionUsage(sess), }) - return nil } // streamChatCompletion runs the agent and streams its response back to the diff --git a/pkg/session/branch.go b/pkg/session/branch.go index c84abdbcf..6de7b5ed6 100644 --- a/pkg/session/branch.go +++ b/pkg/session/branch.go @@ -56,8 +56,8 @@ func (s *Session) Clone() *Session { ID: s.ID, InputID: s.InputID, Title: s.Title, - Evals: s.Evals, - EvalResult: s.EvalResult, + Evals: cloneEvalCriteria(s.Evals), + EvalResult: cloneEvalResult(s.EvalResult), CreatedAt: s.CreatedAt, ToolsApproved: s.ToolsApproved, NonInteractive: s.NonInteractive, @@ -191,6 +191,44 @@ func generateBranchTitle(parentTitle string) string { return parentTitle + " (branched)" } +func cloneEvalCriteria(src *EvalCriteria) *EvalCriteria { + if src == nil { + return nil + } + cp := *src + cp.Relevance = cloneStringSlice(src.Relevance) + return &cp +} + +func cloneEvalResult(src *EvalResult) *EvalResult { + if src == nil { + return nil + } + cp := *src + cp.Successes = cloneStringSlice(src.Successes) + cp.Failures = cloneStringSlice(src.Failures) + cp.Checks = cloneEvalResultChecks(src.Checks) + return &cp +} + +func cloneEvalResultChecks(src EvalResultChecks) EvalResultChecks { + var cp EvalResultChecks + if src.Size != nil { + size := *src.Size + cp.Size = &size + } + if src.ToolCalls != nil { + toolCalls := *src.ToolCalls + cp.ToolCalls = &toolCalls + } + if src.Relevance != nil { + relevance := *src.Relevance + relevance.Results = slices.Clone(src.Relevance.Results) + cp.Relevance = &relevance + } + return cp +} + func clonePermissionsConfig(src *PermissionsConfig) *PermissionsConfig { if src == nil { return nil diff --git a/pkg/session/clone_test.go b/pkg/session/clone_test.go index 88c45fd52..63eb3d510 100644 --- a/pkg/session/clone_test.go +++ b/pkg/session/clone_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/require" "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/tools" ) func TestClone_NilSession(t *testing.T) { @@ -66,6 +67,42 @@ func TestClone_CopiesScalarFields(t *testing.T) { assert.Equal(t, []string{"d"}, clone.Permissions.Deny) } +func TestClone_DeepCopiesEvalFields(t *testing.T) { + orig := &Session{ + Evals: &EvalCriteria{ + Relevance: []string{"is helpful"}, + WorkingDir: "work", + }, + EvalResult: &EvalResult{ + Passed: true, + Successes: []string{"ok"}, + Failures: []string{"missing"}, + Checks: EvalResultChecks{ + Size: &SizeCheck{Passed: true, Actual: "S", Expected: "M"}, + Relevance: &RelevanceCheck{Results: []RelevanceCriterionResult{{ + Criterion: "is helpful", + Passed: true, + }}}, + }, + }, + } + + clone := orig.Clone() + require.NotNil(t, clone) + + clone.Evals.Relevance[0] = "mutated" + clone.EvalResult.Successes[0] = "mutated" + clone.EvalResult.Failures[0] = "mutated" + clone.EvalResult.Checks.Size.Actual = "XL" + clone.EvalResult.Checks.Relevance.Results[0].Criterion = "mutated" + + assert.Equal(t, "is helpful", orig.Evals.Relevance[0]) + assert.Equal(t, "ok", orig.EvalResult.Successes[0]) + assert.Equal(t, "missing", orig.EvalResult.Failures[0]) + assert.Equal(t, "S", orig.EvalResult.Checks.Size.Actual) + assert.Equal(t, "is helpful", orig.EvalResult.Checks.Relevance.Results[0].Criterion) +} + func TestClone_DeepCopiesMessagesAndConfig(t *testing.T) { orig := &Session{ ID: "sess-1", @@ -74,12 +111,24 @@ func TestClone_DeepCopiesMessagesAndConfig(t *testing.T) { CustomModelsUsed: []string{"m1"}, AttachedFiles: []string{"/abs/a.txt"}, } + destructive := true orig.AddMessage(&Message{Message: chat.Message{ Role: chat.MessageRoleUser, MultiContent: []chat.MessagePart{{ Type: chat.MessagePartTypeImageURL, ImageURL: &chat.MessageImageURL{URL: "http://orig"}, }}, + ToolDefinitions: []tools.Tool{{ + Name: "read_file", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + }, + }, + OutputSchema: map[string]any{"type": "string"}, + Annotations: tools.ToolAnnotations{DestructiveHint: &destructive}, + }}, }}) clone := orig.Clone() @@ -92,6 +141,17 @@ func TestClone_DeepCopiesMessagesAndConfig(t *testing.T) { clone.CustomModelsUsed[0] = "mutated" clone.AttachedFiles[0] = "/abs/mutated.txt" clone.Messages[0].Message.Message.MultiContent[0].ImageURL.URL = "http://mutated" + cloneParams := clone.Messages[0].Message.Message.ToolDefinitions[0].Parameters.(map[string]any) + cloneParams["type"] = "mutated" + cloneNestedParams := cloneParams["properties"].(map[string]any)["path"].(map[string]any) + cloneNestedParams["type"] = "mutated" + cloneOutputSchema := clone.Messages[0].Message.Message.ToolDefinitions[0].OutputSchema.(map[string]any) + cloneOutputSchema["type"] = "mutated" + *clone.Messages[0].Message.Message.ToolDefinitions[0].Annotations.DestructiveHint = false + + origParams := orig.Messages[0].Message.Message.ToolDefinitions[0].Parameters.(map[string]any) + origNestedParams := origParams["properties"].(map[string]any)["path"].(map[string]any) + origOutputSchema := orig.Messages[0].Message.Message.ToolDefinitions[0].OutputSchema.(map[string]any) assert.Equal(t, "a", orig.Permissions.Allow[0]) assert.Equal(t, "k", orig.Permissions.Ask[0]) @@ -99,6 +159,10 @@ func TestClone_DeepCopiesMessagesAndConfig(t *testing.T) { assert.Equal(t, "m1", orig.CustomModelsUsed[0]) assert.Equal(t, "/abs/a.txt", orig.AttachedFiles[0]) assert.Equal(t, "http://orig", orig.Messages[0].Message.Message.MultiContent[0].ImageURL.URL) + assert.Equal(t, "object", origParams["type"]) + assert.Equal(t, "string", origNestedParams["type"]) + assert.Equal(t, "string", origOutputSchema["type"]) + assert.True(t, *orig.Messages[0].Message.Message.ToolDefinitions[0].Annotations.DestructiveHint) } func TestClone_AppendingDoesNotAffectOriginal(t *testing.T) { diff --git a/pkg/session/session.go b/pkg/session/session.go index 2a2d7f2f7..f65f87a21 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -392,7 +392,7 @@ func cloneChatMessage(m chat.Message) chat.Message { m.ToolCalls = slices.Clone(m.ToolCalls) } if m.ToolDefinitions != nil { - m.ToolDefinitions = slices.Clone(m.ToolDefinitions) + m.ToolDefinitions = cloneToolDefinitions(m.ToolDefinitions) } if m.Usage != nil { usageCopy := *m.Usage @@ -404,6 +404,52 @@ func cloneChatMessage(m chat.Message) chat.Message { return m } +func cloneToolDefinitions(src []tools.Tool) []tools.Tool { + if src == nil { + return nil + } + out := make([]tools.Tool, len(src)) + for i, tool := range src { + out[i] = tool + out[i].Parameters = cloneSchemaValue(tool.Parameters) + out[i].OutputSchema = cloneSchemaValue(tool.OutputSchema) + out[i].Annotations = cloneToolAnnotations(tool.Annotations) + } + return out +} + +func cloneToolAnnotations(src tools.ToolAnnotations) tools.ToolAnnotations { + cp := src + if src.DestructiveHint != nil { + hint := *src.DestructiveHint + cp.DestructiveHint = &hint + } + if src.OpenWorldHint != nil { + hint := *src.OpenWorldHint + cp.OpenWorldHint = &hint + } + return cp +} + +func cloneSchemaValue(v any) any { + switch x := v.(type) { + case map[string]any: + cp := make(map[string]any, len(x)) + for k, v := range x { + cp[k] = cloneSchemaValue(v) + } + return cp + case []any: + cp := make([]any, len(x)) + for i, v := range x { + cp[i] = cloneSchemaValue(v) + } + return cp + default: + return v + } +} + // Session helper methods // AddMessage adds a message to the session