Skip to content

Commit 19fd7eb

Browse files
committed
llm context improvements
1 parent cc9db96 commit 19fd7eb

File tree

1 file changed

+74
-3
lines changed

1 file changed

+74
-3
lines changed

src/inferencesh/models/llm.py

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,41 @@ class MultipleImageCapabilityMixin(BaseModel):
9797
description="the images to use for the model",
9898
default=None,
9999
)
100+
101+
class FileCapabilityMixin(BaseModel):
    """Capability mixin that adds one optional ``file`` input field."""

    # Stays None when no file is supplied.
    file: Optional[File] = Field(
        default=None,
        description="the file to use for the model",
    )
107+
108+
class MultipleFileCapabilityMixin(BaseModel):
    """Capability mixin that adds an optional list-valued ``files`` input field."""

    # Stays None (not an empty list) when no files are supplied.
    files: Optional[List[File]] = Field(
        default=None,
        description="the files to use for the model",
    )
114+
115+
class ReasoningEffortEnum(str, Enum):
    """Reasoning effort levels.

    Members subclass ``str``, so each one compares equal to its lowercase
    string value.
    """

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    NONE = "none"
100121

101122
class ReasoningCapabilityMixin(BaseModel):
    """Capability mixin that adds the reasoning-related input fields."""

    # Optional reasoning text; None when absent.
    reasoning: str | None = Field(
        default=None,
        description="the reasoning input of the message",
    )

    # NOTE(review): the description below reads like an on/off flag although
    # the field holds an effort level - consider rewording it.
    reasoning_effort: ReasoningEffortEnum = Field(
        default=ReasoningEffortEnum.NONE,
        description="enable step-by-step reasoning",
    )

    # Optional token budget for reasoning; None when unset.
    reasoning_max_tokens: int | None = Field(
        default=None,
        description="the maximum number of tokens to use for reasoning",
    )
107136

108137
class ToolsCapabilityMixin(BaseModel):
@@ -246,6 +275,15 @@ def image_to_base64_data_uri(file_path):
246275

247276
return f"data:image/{content_type};base64,{base64_data}"
248277

278+
def file_to_base64_data_uri(file_path):
    """Read a file from disk and return its contents as a base64 ``data:`` URI.

    Args:
        file_path: Location of the file to encode, as a ``str`` or any
            ``os.PathLike`` object (for example ``pathlib.Path``).

    Returns:
        A string of the form ``data:<content-type>;base64,<payload>``. The
        content type is ``application/pdf`` when the file name ends in a
        ``.pdf`` extension (matched case-insensitively) and
        ``application/octet-stream`` for everything else.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # Local import: the module-level import block is not visible from this
    # block, so the dependency is brought into scope here.
    import os

    # splitext only treats a real trailing ".ext" as the extension, so a file
    # that is merely *named* "pdf" is not mistaken for a PDF; fspath lets
    # PathLike inputs work, and lower() makes "Report.PDF" match.
    extension = os.path.splitext(os.fspath(file_path))[1].lower()
    if extension == ".pdf":
        content_type = "application/pdf"
    else:
        content_type = "application/octet-stream"

    # Keep the file open only for as long as the read takes.
    with open(file_path, "rb") as file:
        base64_data = base64.b64encode(file.read()).decode("utf-8")

    return f"data:{content_type};base64,{base64_data}"
286+
249287
def build_messages(
250288
input_data: LLMInput,
251289
transform_user_message: Optional[Callable[[str], str]] = None,
@@ -281,6 +319,24 @@ def render_message(msg: ContextMessage, allow_multipart: bool) -> str | List[dic
281319
elif image.uri:
282320
parts.append({"type": "image_url", "image_url": {"url": image.uri}})
283321

322+
if msg.file:
323+
if msg.file.path:
324+
file_data_uri = file_to_base64_data_uri(msg.file.path)
325+
parts.append({"type": "file_url", "file_url": {"url": file_data_uri}})
326+
elif msg.file.uri:
327+
parts.append({"type": "file_url", "file_url": {"url": msg.file.uri}})
328+
329+
if msg.files:
330+
for file in msg.files:
331+
if file.path:
332+
file_data_uri = file_to_base64_data_uri(file.path)
333+
parts.append({"type": "file_url", "file_url": {"url": file_data_uri}})
334+
elif file.uri:
335+
parts.append({"type": "file_url", "file_url": {"url": file.uri}})
336+
337+
if msg.reasoning:
338+
parts.append({"type": "reasoning", "reasoning": msg.reasoning})
339+
284340
if allow_multipart:
285341
return parts
286342

@@ -335,9 +391,24 @@ def merge_tool_calls(messages: List[ContextMessage]) -> List[Dict[str, Any]]:
335391
user_input_images = input_data.images
336392
multipart = multipart or input_data.images is not None
337393

394+
user_input_file = None
395+
if hasattr(input_data, "file"):
396+
user_input_file = input_data.file
397+
multipart = multipart or input_data.file is not None
398+
399+
user_input_files = None
400+
if hasattr(input_data, "files"):
401+
user_input_files = input_data.files
402+
multipart = multipart or input_data.files is not None
403+
404+
user_input_reasoning = None
405+
if hasattr(input_data, "reasoning"):
406+
user_input_reasoning = input_data.reasoning
407+
multipart = multipart or input_data.reasoning is not None
408+
338409
input_role = input_data.role if hasattr(input_data, "role") else ContextMessageRole.USER
339410
input_tool_call_id = input_data.tool_call_id if hasattr(input_data, "tool_call_id") else None
340-
user_msg = ContextMessage(role=input_role, text=user_input_text, image=user_input_image, images=user_input_images, tool_call_id=input_tool_call_id)
411+
user_msg = ContextMessage(role=input_role, text=user_input_text, image=user_input_image, images=user_input_images, file=user_input_file, files=user_input_files, reasoning=user_input_reasoning, tool_call_id=input_tool_call_id)
341412

342413
input_data.context.append(user_msg)
343414

0 commit comments

Comments
 (0)