Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.messages import BaseMessage
from ldai import LDMessage, log
from ldai.providers.runner import Runner
from ldai.providers.types import LDAIMetrics, RunnerResult
Expand All @@ -28,13 +28,10 @@ class LangChainModelRunner(Runner):
def __init__(
self,
llm: BaseChatModel,
config_messages: Optional[List[LDMessage]] = None,
multi_turn: bool = True,
):
self._llm = llm
self._chat_history = InMemoryChatMessageHistory(
messages=cast(List[BaseMessage], convert_messages_to_langchain(config_messages or []))
)
self._chat_history = InMemoryChatMessageHistory()
self._multi_turn = multi_turn

def get_llm(self) -> BaseChatModel:
Expand All @@ -47,32 +44,53 @@ def get_llm(self) -> BaseChatModel:

async def run(
    self,
    input: Any,
    output_type: Optional[Dict[str, Any]] = None,
) -> RunnerResult:
    """
    Run the LangChain model with the given input.

    :param input: A string prompt or a list of :class:`LDMessage` objects.
        When a list is provided it is used as the complete message set.
        When a string is provided it is appended to the conversation history.
        Any other type is rejected: a warning is logged and a failed
        :class:`RunnerResult` is returned without invoking the model.
    :param output_type: Optional JSON schema dict requesting structured output.
        When provided, ``parsed`` on the returned :class:`RunnerResult` is
        populated with the parsed JSON document.
    :return: :class:`RunnerResult` containing ``content``, ``metrics``,
        ``raw`` and (when ``output_type`` is set) ``parsed``.
    """
    # Report unsupported input as a failed result instead of raising, so this
    # runner handles bad input the same way the OpenAI runner does.
    try:
        coerced = self._coerce_input(input)
    except TypeError as error:
        log.warning(f'LangChain model runner received unsupported input type: {error}')
        return RunnerResult(content='', metrics=LDAIMetrics(success=False, usage=None))

    converted = cast(List[BaseMessage], convert_messages_to_langchain(coerced))

    # A list is the complete message set; a string is a new turn on top of
    # the accumulated history.
    if isinstance(input, list):
        langchain_messages = converted
    else:
        langchain_messages = self._chat_history.messages + converted

    if output_type is not None:
        result = await self._run_structured(langchain_messages, output_type)
    else:
        result = await self._run_completion(langchain_messages)

    # History only grows for string prompts, and only after a successful call
    # that produced content; list input never touches the stored history.
    if result.metrics.success and result.content and self._multi_turn and isinstance(input, str):
        self._chat_history.add_user_message(input)
        self._chat_history.add_ai_message(result.content)

    return result

# convert_messages_to_langchain only accepts List[LDMessage]; _coerce_input
# normalizes a bare string to [LDMessage(role='user', ...)] before that step.
@staticmethod
def _coerce_input(input: Any) -> List[LDMessage]:
    """
    Normalize the value passed to ``run`` into a list of messages.

    :param input: A string prompt or a list of messages.
    :return: For a string, a one-element list holding a ``user``
        :class:`LDMessage` with that content; for a list, the same list
        object, unchanged.
    :raises TypeError: If ``input`` is neither a ``str`` nor a ``list``.
    """
    if isinstance(input, list):
        # Passed through as-is; the elements are not inspected here.
        return input
    if not isinstance(input, str):
        raise TypeError(
            f"Unsupported input type for LangChainModelRunner.run: {type(input).__name__}"
        )
    return [LDMessage(role='user', content=input)]

async def _run_completion(self, messages: List[BaseMessage]) -> RunnerResult:
try:
response: BaseMessage = await self._llm.ainvoke(messages)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,4 @@ def create_model(self, config: AIConfigKind, multi_turn: bool = True) -> LangCha
:return: LangChainModelRunner ready to invoke the model
"""
llm = create_langchain_model(config)
config_messages = list(getattr(config, 'messages', None) or [])
return LangChainModelRunner(llm, config_messages, multi_turn=multi_turn)
return LangChainModelRunner(llm, multi_turn=multi_turn)
Original file line number Diff line number Diff line change
Expand Up @@ -332,23 +332,22 @@ async def test_does_not_accumulate_history_on_failed_call(self, mock_llm):
assert second_call_messages[0].content == 'Try again'

@pytest.mark.asyncio
async def test_prepends_config_messages_before_history(self, mock_llm):
"""Should send config messages before history on every call."""
mock_llm.ainvoke = AsyncMock(side_effect=[
AIMessage(content='Answer 1'),
AIMessage(content='Answer 2'),
])
config_messages = [LDMessage(role='system', content='You are helpful.')]
provider = LangChainModelRunner(mock_llm, config_messages=config_messages)
async def test_passes_list_input_directly_without_prepending_history(self, mock_llm):
"""When a list[LDMessage] is passed, it is used as-is without prepending chat history."""
mock_llm.ainvoke = AsyncMock(return_value=AIMessage(content='Answer'))
provider = LangChainModelRunner(mock_llm)

await provider.run('Q1')
await provider.run('Q2')
messages = [
LDMessage(role='system', content='You are helpful.'),
LDMessage(role='user', content='Q1'),
]
await provider.run(messages)

second_call_messages = mock_llm.ainvoke.call_args_list[1][0][0]
assert second_call_messages[0].content == 'You are helpful.'
assert second_call_messages[1].content == 'Q1'
assert second_call_messages[2].content == 'Answer 1'
assert second_call_messages[3].content == 'Q2'
first_call_messages = mock_llm.ainvoke.call_args_list[0][0][0]
assert len(first_call_messages) == 2
assert first_call_messages[0].content == 'You are helpful.'
assert first_call_messages[1].content == 'Q1'
assert len(provider._chat_history.messages) == 0


class TestRunStructured:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,44 +28,63 @@ def __init__(
client: AsyncOpenAI,
model_name: str,
parameters: Dict[str, Any],
config_messages: Optional[List[LDMessage]] = None,
multi_turn: bool = True,
):
self._client = client
self._model_name = model_name
self._parameters = parameters
self._history: List[LDMessage] = list(config_messages or [])
self._history: List[LDMessage] = []
self._multi_turn = multi_turn

async def run(
self,
input: str,
input: Any,
output_type: Optional[Dict[str, Any]] = None,
) -> RunnerResult:
"""
Run the OpenAI model with the given input.

:param input: A string prompt
:param input: A string prompt or a list of :class:`LDMessage` objects.
When a list is provided it is used as the complete message set.
When a string is provided it is appended to the conversation history.
:param output_type: Optional JSON schema dict requesting structured output.
When provided, ``parsed`` on the returned :class:`RunnerResult` is
populated with the parsed JSON document.
:return: :class:`RunnerResult` containing ``content``, ``metrics``,
``raw`` and (when ``output_type`` is set) ``parsed``.
"""
user_message = LDMessage(role='user', content=input)
messages = self._history + [user_message]
try:
coerced = self._coerce_input(input)
except TypeError as error:
log.warning(f'OpenAI model runner received unsupported input type: {error}')
return RunnerResult(content='', metrics=LDAIMetrics(success=False, usage=None))

if isinstance(input, list):
messages = coerced
else:
messages = self._history + coerced

if output_type is not None:
result = await self._run_structured(messages, output_type)
else:
result = await self._run_completion(messages)

if result.metrics.success and result.content and self._multi_turn:
self._history.append(user_message)
if result.metrics.success and result.content and self._multi_turn and isinstance(input, str):
self._history.append(LDMessage(role='user', content=input))
self._history.append(LDMessage(role='assistant', content=result.content))

return result

@staticmethod
def _coerce_input(input: Any) -> List[LDMessage]:
    """
    Turn the ``run`` argument into the message list sent to the model.

    :param input: A string prompt or a list of messages.
    :return: The list itself when ``input`` is a list; otherwise a new
        one-element list wrapping the string in a ``user`` :class:`LDMessage`.
    :raises TypeError: If ``input`` is neither a ``str`` nor a ``list``.
    """
    if isinstance(input, list):
        # The caller's list is returned without copying or validation.
        return input
    if not isinstance(input, str):
        raise TypeError(
            f"Unsupported input type for OpenAIModelRunner.run: {type(input).__name__}"
        )
    return [LDMessage(role='user', content=input)]

async def _run_completion(self, messages: List[LDMessage]) -> RunnerResult:
try:
response = await self._client.chat.completions.create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,7 @@ def create_model(self, config: AIConfigKind, multi_turn: bool = True) -> OpenAIM
tool_defs = parameters.pop('tools', None) or []
if tool_defs:
parameters['tools'] = normalize_tool_types(tool_defs)
config_messages = list(getattr(config, 'messages', None) or [])
return OpenAIModelRunner(
self._client, model_name, parameters, config_messages, multi_turn=multi_turn
)
return OpenAIModelRunner(self._client, model_name, parameters, multi_turn=multi_turn)

def get_client(self) -> AsyncOpenAI:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,38 @@ def make_response(text: str):

assert len(provider._history) == baseline_len + 4

@pytest.mark.asyncio
async def test_passes_list_input_directly_without_prepending_history(self, mock_client):
    """A list[LDMessage] input is sent to the client verbatim and history stays empty."""
    from ldai import LDMessage as _LDMsg

    def make_response(text: str):
        # Stand-in completion object: a single choice carrying ``text``, no usage.
        choice = MagicMock()
        choice.message = MagicMock()
        choice.message.content = text

        response = MagicMock()
        response.context_wrapper = None
        response.choices = [choice]
        response.usage = None
        return response

    mock_client.chat = MagicMock()
    mock_client.chat.completions = MagicMock()
    mock_client.chat.completions.create = AsyncMock(return_value=make_response('Answer'))

    provider = OpenAIModelRunner(mock_client, 'gpt-4o', {})
    await provider.run([
        _LDMsg(role='system', content='You are helpful.'),
        _LDMsg(role='user', content='Q1'),
    ])

    # The client must receive exactly the two supplied messages, in order.
    sent = mock_client.chat.completions.create.call_args.kwargs['messages']
    expected = [
        {'role': 'system', 'content': 'You are helpful.'},
        {'role': 'user', 'content': 'Q1'},
    ]
    assert sent == expected
    # List input must not be recorded in the runner's conversation history.
    assert provider._history == []

@pytest.mark.asyncio
async def test_does_not_accumulate_history_on_failed_call(self, mock_client):
"""Should not add to history when the call fails."""
Expand Down
16 changes: 2 additions & 14 deletions packages/sdk/server-ai/src/ldai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ldai import log
from ldai.agent_graph import AgentGraphDefinition
from ldai.evaluator import Evaluator
from ldai.judge import Judge, _strip_legacy_judge_messages
from ldai.judge import Judge
from ldai.managed_agent import ManagedAgent
from ldai.managed_agent_graph import ManagedAgentGraph
from ldai.managed_model import ManagedModel
Expand Down Expand Up @@ -203,17 +203,9 @@ def _judge_config(
"The variable 'response_to_evaluate' is reserved by the judge and will be ignored."
)

# Re-inject the reserved variables as their literal placeholders so they
# survive Mustache interpolation in ``__evaluate``. Without this, legacy
# templates like ``{{message_history}}`` get rendered to empty strings and
# ``_strip_legacy_judge_messages`` below cannot detect them.
extended_variables = dict(variables) if variables else {}
extended_variables['message_history'] = '{{message_history}}'
extended_variables['response_to_evaluate'] = '{{response_to_evaluate}}'

(model, provider, messages, instructions,
tracker_factory, enabled, judge_configuration, variation) = self.__evaluate(
key, context, default.to_dict(), extended_variables
key, context, default.to_dict(), variables
)

def _extract_evaluation_metric_key(variation: Dict[str, Any]) -> Optional[str]:
Expand All @@ -233,10 +225,6 @@ def _extract_evaluation_metric_key(variation: Dict[str, Any]) -> Optional[str]:

evaluation_metric_key = _extract_evaluation_metric_key(variation)

# strip legacy judge template messages before creating config
if messages:
messages = _strip_legacy_judge_messages(messages)

config = AIJudgeConfig(
key=key,
enabled=bool(enabled),
Expand Down
Loading
Loading