Skip to content

Commit 4277135

Browse files
committed
get closer to openai schema for llms
1 parent c775791 commit 4277135

File tree

1 file changed

+37
-8
lines changed

1 file changed

+37
-8
lines changed

src/inferencesh/models/llm.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ class ContextMessage(BaseAppInput):
3737
description="the tool calls of the message",
3838
default=None
3939
)
40+
tool_call_id: Optional[str] = Field(
41+
description="the tool call id for tool role messages",
42+
default=None
43+
)
4044

4145
class BaseLLMInput(BaseAppInput):
4246
"""Base class with common LLM fields."""
@@ -275,19 +279,44 @@ def merge_tool_calls(messages: List[ContextMessage]) -> List[Dict[str, Any]]:
275279
current_messages.append(msg)
276280
current_role = msg.role
277281
else:
278-
messages.append({
279-
"role": current_role,
282+
# Convert role enum to string for OpenAI API compatibility
283+
role_str = current_role.value if hasattr(current_role, "value") else current_role
284+
msg_dict = {
285+
"role": role_str,
280286
"content": render_message(merge_messages(current_messages), allow_multipart=multipart),
281-
"tool_calls": merge_tool_calls(current_messages)
282-
})
287+
}
288+
289+
# Only add tool_calls if not empty
290+
tool_calls = merge_tool_calls(current_messages)
291+
if tool_calls:
292+
msg_dict["tool_calls"] = tool_calls
293+
294+
# Add tool_call_id if present (for tool role messages)
295+
if current_messages and current_messages[0].tool_call_id:
296+
msg_dict["tool_call_id"] = current_messages[0].tool_call_id
297+
298+
messages.append(msg_dict)
283299
current_messages = [msg]
284300
current_role = msg.role
301+
285302
if len(current_messages) > 0:
286-
messages.append({
287-
"role": current_role,
303+
# Convert role enum to string for OpenAI API compatibility
304+
role_str = current_role.value if hasattr(current_role, "value") else current_role
305+
msg_dict = {
306+
"role": role_str,
288307
"content": render_message(merge_messages(current_messages), allow_multipart=multipart),
289-
"tool_calls": merge_tool_calls(current_messages)
290-
})
308+
}
309+
310+
# Only add tool_calls if not empty
311+
tool_calls = merge_tool_calls(current_messages)
312+
if tool_calls:
313+
msg_dict["tool_calls"] = tool_calls
314+
315+
# Add tool_call_id if present (for tool role messages)
316+
if current_messages and current_messages[0].tool_call_id:
317+
msg_dict["tool_call_id"] = current_messages[0].tool_call_id
318+
319+
messages.append(msg_dict)
291320

292321
return messages
293322

0 commit comments

Comments
 (0)