
Commit 7be830e

llm context improvements

Consolidates the single and multiple attachment APIs: drops the singular image/file fields and their capability mixins in favor of the plural images/files lists, adds files to ContextMessage, and replaces per-field multipart bookkeeping with a single check over the context.
1 parent 19fd7eb

File tree

1 file changed (+14, -59 lines)

src/inferencesh/models/llm.py

Lines changed: 14 additions & 59 deletions
@@ -33,14 +33,14 @@ class ContextMessage(BaseAppInput):
         description="the reasoning content of the message",
         default=None
     )
-    image: Optional[File] = Field(
-        description="the image file of the message",
-        default=None
-    )
     images: Optional[List[File]] = Field(
         description="the images of the message",
         default=None
     )
+    files: Optional[List[File]] = Field(
+        description="the files of the message",
+        default=None
+    )
     tool_calls: Optional[List[Dict[str, Any]]] = Field(
         description="the tool calls of the message",
         default=None
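
For reference, a minimal sketch of a message built against the new schema; the File(path=...) constructor and the import path are assumptions, not shown in this diff:

# Sketch only: the plural fields below come from the diff; everything
# else (imports, File construction) is assumed.
from inferencesh.models.llm import ContextMessage, ContextMessageRole, File

msg = ContextMessage(
    role=ContextMessageRole.USER,
    text="Summarize the attached report and charts.",
    images=[File(path="chart1.png"), File(path="chart2.png")],  # plural list replaces the removed singular image
    files=[File(path="report.pdf")],                            # field added in this commit
)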
@@ -84,14 +84,6 @@ class BaseLLMInput(BaseAppInput):
     context_size: int = Field(default=4096)
 
 class ImageCapabilityMixin(BaseModel):
-    """Mixin for models that support image inputs."""
-    image: Optional[File] = Field(
-        description="the image to use for the model",
-        default=None,
-        contentMediaType="image/*",
-    )
-
-class MultipleImageCapabilityMixin(BaseModel):
     """Mixin for models that support image inputs."""
     images: Optional[List[File]] = Field(
         description="the images to use for the model",
@@ -100,13 +92,6 @@ class MultipleImageCapabilityMixin(BaseModel):
 
 class FileCapabilityMixin(BaseModel):
     """Mixin for models that support file inputs."""
-    file: Optional[File] = Field(
-        description="the file to use for the model",
-        default=None,
-    )
-
-class MultipleFileCapabilityMixin(BaseModel):
-    """Mixin for models that support multiple file inputs."""
     files: Optional[List[File]] = Field(
         description="the files to use for the model",
         default=None,
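
With the single and multiple variants collapsed, an input model that accepts both attachment kinds would presumably just mix in the two surviving classes. A sketch; the concrete class name here is made up:

# Hypothetical composition; only BaseLLMInput and the two mixin names
# come from this file.
class VisionFileLLMInput(BaseLLMInput, ImageCapabilityMixin, FileCapabilityMixin):
    """Accepts prompt text plus optional images and files lists."""
    pass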
@@ -163,7 +148,6 @@ class LLMUsage(BaseAppOutput):
     reasoning_tokens: int = 0
     reasoning_time: float = 0.0
 
-
 class BaseLLMOutput(BaseAppOutput):
     """Base class for LLM outputs with common fields."""
     response: str = Field(description="the generated text response")
@@ -304,13 +288,6 @@ def render_message(msg: ContextMessage, allow_multipart: bool) -> str | List[dict]:
     else:
         parts.append({"type": "text", "text": ""})
 
-    if msg.image:
-        if msg.image.path:
-            image_data_uri = image_to_base64_data_uri(msg.image.path)
-            parts.append({"type": "image_url", "image_url": {"url": image_data_uri}})
-        elif msg.image.uri:
-            parts.append({"type": "image_url", "image_url": {"url": msg.image.uri}})
-
     if msg.images:
         for image in msg.images:
             if image.path:
@@ -319,13 +296,6 @@ def render_message(msg: ContextMessage, allow_multipart: bool) -> str | List[dict]:
         elif image.uri:
             parts.append({"type": "image_url", "image_url": {"url": image.uri}})
 
-    if msg.file:
-        if msg.file.path:
-            file_data_uri = file_to_base64_data_uri(msg.file.path)
-            parts.append({"type": "file_url", "file_url": {"url": file_data_uri}})
-        elif msg.file.uri:
-            parts.append({"type": "file_url", "file_url": {"url": msg.file.uri}})
-
     if msg.files:
         for file in msg.files:
             if file.path:
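
After these two deletions, render_message only ever walks the plural lists. As a rough illustration of the part dictionaries it builds (the URLs below are placeholders):

# Illustrative multipart content, matching the shapes appended above.
parts = [
    {"type": "text", "text": "Compare the chart with the report."},
    {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},    # from image.path
    {"type": "file_url", "file_url": {"url": "https://example.com/report.pdf"}}, # from file.uri
]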
@@ -355,19 +325,16 @@ def render_message(msg: ContextMessage, allow_multipart: bool) -> str | List[dict]:
 def merge_messages(messages: List[ContextMessage]) -> ContextMessage:
     text = "\n\n".join(msg.text for msg in messages if msg.text)
     images = []
-    # Collect single images
-    for msg in messages:
-        if msg.image:
-            images.append(msg.image)
-    # Collect multiple images (flatten the list)
+    files = []
     for msg in messages:
         if msg.images:
-            images.extend(msg.images)
-    # Set image to single File if there's exactly one, otherwise None
-    image = images[0] if len(images) == 1 else None
-    # Set images to the list if there are multiple, otherwise None
+            images.extend(msg.images)
+        if msg.files:
+            files.extend(msg.files)
+
     images_list = images if len(images) > 1 else None
-    return ContextMessage(role=messages[0].role, text=text, image=image, images=images_list)
+    files_list = files if len(files) > 1 else None
+    return ContextMessage(role=messages[0].role, text=text, images=images_list, files=files_list)
 
 def merge_tool_calls(messages: List[ContextMessage]) -> List[Dict[str, Any]]:
     tool_calls = []
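
A quick sketch of the merged result under the new code path (File(path=...) is assumed). Note the > 1 guards: a lone collected image or file comes back as None rather than a one-element list:

merged = merge_messages([
    ContextMessage(role=ContextMessageRole.USER, text="First part",
                   images=[File(path="a.png")]),
    ContextMessage(role=ContextMessageRole.USER, text="Second part",
                   images=[File(path="b.png")],
                   files=[File(path="x.txt"), File(path="y.txt")]),
])
# merged.text   -> "First part\n\nSecond part"
# merged.images -> both images (the list survives because len > 1)
# merged.files  -> both files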
@@ -380,35 +347,23 @@ def merge_tool_calls(messages: List[ContextMessage]) -> List[Dict[str, Any]]:
     if hasattr(input_data, "text"):
         user_input_text = transform_user_message(input_data.text) if transform_user_message else input_data.text
 
-    user_input_image = None
-    multipart = any(m.image for m in input_data.context)
-    if hasattr(input_data, "image"):
-        user_input_image = input_data.image
-        multipart = multipart or input_data.image is not None
-
     user_input_images = None
     if hasattr(input_data, "images"):
         user_input_images = input_data.images
-        multipart = multipart or input_data.images is not None
-
-    user_input_file = None
-    if hasattr(input_data, "file"):
-        user_input_file = input_data.file
-        multipart = multipart or input_data.file is not None
 
     user_input_files = None
     if hasattr(input_data, "files"):
         user_input_files = input_data.files
-        multipart = multipart or input_data.files is not None
 
     user_input_reasoning = None
    if hasattr(input_data, "reasoning"):
         user_input_reasoning = input_data.reasoning
-        multipart = multipart or input_data.reasoning is not None
+
+    multipart = any(m.images or m.files or m.reasoning for m in input_data.context)
 
     input_role = input_data.role if hasattr(input_data, "role") else ContextMessageRole.USER
     input_tool_call_id = input_data.tool_call_id if hasattr(input_data, "tool_call_id") else None
-    user_msg = ContextMessage(role=input_role, text=user_input_text, image=user_input_image, images=user_input_images, file=user_input_file, files=user_input_files, reasoning=user_input_reasoning, tool_call_id=input_tool_call_id)
+    user_msg = ContextMessage(role=input_role, text=user_input_text, images=user_input_images, files=user_input_files, reasoning=user_input_reasoning, tool_call_id=input_tool_call_id)
 
     input_data.context.append(user_msg)
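The per-field multipart bookkeeping collapses into a single pass over the existing context. Only the any(...) expression below appears in the diff; the rendering loop is an assumed illustration of how the flag plausibly reaches render_message:

multipart = any(m.images or m.files or m.reasoning for m in input_data.context)
rendered = [render_message(m, allow_multipart=multipart) for m in input_data.context]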