diff --git a/config/config.go b/config/config.go index 17ce01e4..4ab1c3c7 100644 --- a/config/config.go +++ b/config/config.go @@ -21,6 +21,10 @@ type Anthropic struct { // with a access token. When set, the access token is used for upstream // LLM requests instead of the API key. BYOKBearerToken string + // MaxRetries controls the number of automatic retries the SDK will perform + // on transient errors. If nil, the SDK default (2) is used. + // Set to 0 to disable retries entirely. + MaxRetries *int } type AWSBedrock struct { @@ -43,6 +47,10 @@ type OpenAI struct { CircuitBreaker *CircuitBreaker SendActorHeaders bool ExtraHeaders map[string]string + // MaxRetries controls the number of automatic retries the SDK will perform + // on transient errors. If nil, the SDK default (2) is used. + // Set to 0 to disable retries entirely. + MaxRetries *int } type Copilot struct { @@ -51,6 +59,10 @@ type Copilot struct { BaseURL string APIDumpDir string CircuitBreaker *CircuitBreaker + // MaxRetries controls the number of automatic retries the SDK will perform + // on transient errors. If nil, the SDK default (2) is used. + // Set to 0 to disable retries entirely. + MaxRetries *int } // CircuitBreaker holds configuration for circuit breakers. diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index 84e5f706..50b5e931 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -47,6 +47,9 @@ type interceptionBase struct { func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService { opts := []option.RequestOption{option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.cfg.BaseURL)} + if i.cfg.MaxRetries != nil { + opts = append(opts, option.WithMaxRetries(*i.cfg.MaxRetries)) + } // Add extra headers if configured. // Some providers require additional headers that are not added by the SDK. diff --git a/intercept/messages/base.go b/intercept/messages/base.go index a5ac8b86..d7b9ddb8 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -219,6 +219,9 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio opts = append(opts, option.WithAPIKey(i.cfg.Key)) } opts = append(opts, option.WithBaseURL(i.cfg.BaseURL)) + if i.cfg.MaxRetries != nil { + opts = append(opts, option.WithMaxRetries(*i.cfg.MaxRetries)) + } // Add extra headers if configured. // Some providers require additional headers that are not added by the SDK. diff --git a/intercept/responses/base.go b/intercept/responses/base.go index cbd63df2..daccc300 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -55,6 +55,9 @@ type responsesInterceptionBase struct { func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService { opts := []option.RequestOption{option.WithBaseURL(i.cfg.BaseURL), option.WithAPIKey(i.cfg.Key)} + if i.cfg.MaxRetries != nil { + opts = append(opts, option.WithMaxRetries(*i.cfg.MaxRetries)) + } // Add extra headers if configured. // Some providers require additional headers that are not added by the SDK. diff --git a/internal/integrationtest/responses_test.go b/internal/integrationtest/responses_test.go index 3213e2ff..1a35f707 100644 --- a/internal/integrationtest/responses_test.go +++ b/internal/integrationtest/responses_test.go @@ -596,9 +596,6 @@ func TestResponsesParallelToolsOverwritten(t *testing.T) { } } -// TODO set MaxRetries to speed up this test -// option.WithMaxRetries(0), in base responses interceptor -// https://github.com/coder/aibridge/issues/115 func TestClientAndConnectionError(t *testing.T) { t.Parallel() @@ -642,7 +639,11 @@ func TestClientAndConnectionError(t *testing.T) { t.Cleanup(cancel) // tc.addr may be an intentionally invalid URL; use withCustomProvider. - bridgeServer := newBridgeTestServer(ctx, t, tc.addr, withCustomProvider(provider.NewOpenAI(openAICfg(tc.addr, apiKey)))) + // MaxRetries is set to 0 to disable SDK retries and speed up the test. + cfg := openAICfg(tc.addr, apiKey) + maxRetries := 0 + cfg.MaxRetries = &maxRetries + bridgeServer := newBridgeTestServer(ctx, t, tc.addr, withCustomProvider(provider.NewOpenAI(cfg))) reqBytes := responsesRequestBytes(t, tc.streaming) resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes) @@ -660,9 +661,6 @@ func TestClientAndConnectionError(t *testing.T) { } } -// TODO set MaxRetries to speed up this test -// option.WithMaxRetries(0), in base responses interceptor -// https://github.com/coder/aibridge/issues/115 func TestUpstreamError(t *testing.T) { t.Parallel() @@ -721,7 +719,11 @@ func TestUpstreamError(t *testing.T) { })) t.Cleanup(upstream.Close) - bridgeServer := newBridgeTestServer(ctx, t, upstream.URL) + // MaxRetries is set to 0 to disable SDK retries and speed up the test. + cfg := openAICfg(upstream.URL, apiKey) + maxRetries := 0 + cfg.MaxRetries = &maxRetries + bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withCustomProvider(provider.NewOpenAI(cfg))) reqBytes := responsesRequestBytes(t, tc.streaming) resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes) diff --git a/provider/anthropic.go b/provider/anthropic.go index 32de0d57..eb8a8371 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -5,6 +5,7 @@ import ( "io" "net/http" "os" + "strconv" "strings" "github.com/google/uuid" @@ -62,6 +63,13 @@ func NewAnthropic(cfg config.Anthropic, bedrockCfg *config.AWSBedrock) *Anthropi if cfg.APIDumpDir == "" { cfg.APIDumpDir = os.Getenv("BRIDGE_DUMP_DIR") } + if cfg.MaxRetries == nil { + if v := os.Getenv("ANTHROPIC_MAX_RETRIES"); v != "" { + if n, err := strconv.Atoi(v); err == nil { + cfg.MaxRetries = &n + } + } + } if cfg.CircuitBreaker != nil { cfg.CircuitBreaker.IsFailure = anthropicIsFailure cfg.CircuitBreaker.OpenErrorResponse = anthropicOpenErrorResponse diff --git a/provider/copilot.go b/provider/copilot.go index 56caa6e7..7ffaad24 100644 --- a/provider/copilot.go +++ b/provider/copilot.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "os" + "strconv" "strings" "github.com/google/uuid" @@ -63,6 +64,13 @@ func NewCopilot(cfg config.Copilot) *Copilot { if cfg.APIDumpDir == "" { cfg.APIDumpDir = os.Getenv("BRIDGE_DUMP_DIR") } + if cfg.MaxRetries == nil { + if v := os.Getenv("COPILOT_MAX_RETRIES"); v != "" { + if n, err := strconv.Atoi(v); err == nil { + cfg.MaxRetries = &n + } + } + } if cfg.CircuitBreaker != nil { cfg.CircuitBreaker.OpenErrorResponse = copilotOpenErrorResponse } @@ -145,6 +153,7 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac APIDumpDir: p.cfg.APIDumpDir, CircuitBreaker: p.cfg.CircuitBreaker, ExtraHeaders: extractCopilotHeaders(r), + MaxRetries: p.cfg.MaxRetries, } cred := intercept.NewCredentialInfo(intercept.CredentialKindBYOK, key) diff --git a/provider/openai.go b/provider/openai.go index 6281aefe..827b86db 100644 --- a/provider/openai.go +++ b/provider/openai.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "os" + "strconv" "strings" "github.com/google/uuid" @@ -51,6 +52,13 @@ func NewOpenAI(cfg config.OpenAI) *OpenAI { if cfg.APIDumpDir == "" { cfg.APIDumpDir = os.Getenv("BRIDGE_DUMP_DIR") } + if cfg.MaxRetries == nil { + if v := os.Getenv("OPENAI_MAX_RETRIES"); v != "" { + if n, err := strconv.Atoi(v); err == nil { + cfg.MaxRetries = &n + } + } + } if cfg.CircuitBreaker != nil { cfg.CircuitBreaker.OpenErrorResponse = openAIOpenErrorResponse }