Skip to content

Commit 4b80829

Browse files
committed
add role to default inputs for llm
1 parent a56acfd commit 4b80829

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

src/inferencesh/models/llm.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ class Message(BaseAppInput):
2121
role: ContextMessageRole
2222
content: str
2323

24-
2524
class ContextMessage(BaseAppInput):
2625
role: ContextMessageRole = Field(
2726
description="the role of the message. user, assistant, or system",
@@ -33,6 +32,10 @@ class ContextMessage(BaseAppInput):
3332
description="the image file of the message",
3433
default=None
3534
)
35+
tool_calls: Optional[List[Dict[str, Any]]] = Field(
36+
description="the tool calls of the message",
37+
default=None
38+
)
3639

3740
class BaseLLMInput(BaseAppInput):
3841
"""Base class with common LLM fields."""
@@ -53,8 +56,12 @@ class BaseLLMInput(BaseAppInput):
5356
]
5457
]
5558
)
59+
role: ContextMessageRole = Field(
60+
description="the role of the input text",
61+
default=ContextMessageRole.USER
62+
)
5663
text: str = Field(
57-
description="the user prompt to use for the model",
64+
description="the input text to use for the model",
5865
examples=[
5966
"write a haiku about artificial general intelligence"
6067
]
@@ -217,6 +224,8 @@ def render_message(msg: ContextMessage, allow_multipart: bool) -> str | List[dic
217224
parts.append({"type": "image_url", "image_url": {"url": image_data_uri}})
218225
elif msg.image.uri:
219226
parts.append({"type": "image_url", "image_url": {"url": msg.image.uri}})
227+
if msg.tool_calls:
228+
parts.append({"type": "tool_call", "tool_calls": msg.tool_calls})
220229
if allow_multipart:
221230
return parts
222231
if len(parts) == 1 and parts[0]["type"] == "text":

0 commit comments

Comments
 (0)