Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions strands-py/src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,13 @@ def get_config(self) -> BedrockConfig:
"""
return resolve_config_metadata(self.config, self.config.get("model_id", ""))

def _format_request(
def format_request(
self,
messages: Messages,
tool_specs: list[ToolSpec] | None = None,
system_prompt_content: list[SystemContentBlock] | None = None,
tool_choice: ToolChoice | None = None,
**kwargs: Any,
) -> dict[str, Any]:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add kwargs here

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes in https://github.com/strands-agents/sdk-python/pull/2093/changes#r3306308887 is an example of an additive parameter FWIW

"""Format a Bedrock converse stream request.

Expand All @@ -248,6 +249,7 @@ def _format_request(
tool_specs: List of tool specifications to make available to the model.
tool_choice: Selection strategy for tool invocation.
system_prompt_content: System prompt content blocks to provide context to the model.
**kwargs: Additional keyword arguments for future extensibility.

Returns:
A Bedrock converse stream request.
Expand Down Expand Up @@ -830,7 +832,7 @@ async def count_tokens(
if system_prompt and system_prompt_content is None:
system_prompt_content = [{"text": system_prompt}]

request = self._format_request(messages, tool_specs, system_prompt_content)
request = self.format_request(messages, tool_specs, system_prompt_content)
converse_input: dict[str, Any] = {}
if "messages" in request:
converse_input["messages"] = request["messages"]
Expand Down Expand Up @@ -960,7 +962,7 @@ def _stream(
"""
try:
logger.debug("formatting request")
request = self._format_request(messages, tool_specs, system_prompt_content, tool_choice)
request = self.format_request(messages, tool_specs, system_prompt_content, tool_choice)
logger.debug("request=<%s>", request)

logger.debug("invoking model")
Expand All @@ -984,7 +986,7 @@ def _stream(

else:
response = self.client.converse(**request)
for event in self._convert_non_streaming_to_streaming(response):
for event in self.convert_non_streaming_to_streaming(response):
callback(event)

if (
Expand Down Expand Up @@ -1040,11 +1042,12 @@ def _stream(
callback()
logger.debug("finished streaming response from model")

def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
def convert_non_streaming_to_streaming(self, response: dict[str, Any], **kwargs: Any) -> Iterable[StreamEvent]:
"""Convert a non-streaming response to the streaming format.

Args:
response: The non-streaming response from the Bedrock model.
**kwargs: Additional keyword arguments for future extensibility.

Returns:
An iterable of response events in the streaming format.
Expand Down
Loading