Commit d1a0e1c

Commit message: org and llms
1 parent 13157da

File tree

9 files changed

+717 additions, -416 deletions


src/inferencesh/__init__.py

Lines changed: 28 additions & 1 deletion
@@ -2,4 +2,31 @@
 
 __version__ = "0.1.2"
 
-from .sdk import BaseApp, BaseAppInput, BaseAppOutput, File, LLMInput, ContextMessage, ContextMessageWithImage, LLMInputWithImage
+from .models import (
+    BaseApp,
+    BaseAppInput,
+    BaseAppOutput,
+    File,
+    ContextMessageRole,
+    Message,
+    ContextMessage,
+    ContextMessageWithImage,
+    LLMInput,
+    LLMInputWithImage,
+)
+from .utils import StorageDir, download
+
+__all__ = [
+    "BaseApp",
+    "BaseAppInput",
+    "BaseAppOutput",
+    "File",
+    "ContextMessageRole",
+    "Message",
+    "ContextMessage",
+    "ContextMessageWithImage",
+    "LLMInput",
+    "LLMInputWithImage",
+    "StorageDir",
+    "download",
+]
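
For orientation, the re-exports above keep the package's public import surface flat even though the definitions move from the old .sdk module into the new models and utils subpackages. A minimal, hypothetical usage sketch follows (names are taken from the __all__ list above; nothing in it is part of the commit itself):

# Hypothetical sketch: downstream imports keep working against the top-level package.
from inferencesh import BaseApp, BaseAppInput, BaseAppOutput, File, LLMInput

# The same names are also reachable through the subpackages introduced in this commit.
from inferencesh.models import ContextMessage, ContextMessageWithImage
from inferencesh.utils import StorageDir, download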

src/inferencesh/models/__init__.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
"""Models package for inference.sh SDK."""

from .base import BaseApp, BaseAppInput, BaseAppOutput
from .file import File
from .llm import (
    ContextMessageRole,
    Message,
    ContextMessage,
    ContextMessageWithImage,
    LLMInput,
    LLMInputWithImage,
)

__all__ = [
    "BaseApp",
    "BaseAppInput",
    "BaseAppOutput",
    "File",
    "ContextMessageRole",
    "Message",
    "ContextMessage",
    "ContextMessageWithImage",
    "LLMInput",
    "LLMInputWithImage",
]

src/inferencesh/models/base.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
from typing import Any, Dict, List
from pydantic import BaseModel, ConfigDict
import inspect
import ast
import textwrap
from collections import OrderedDict


class OrderedSchemaModel(BaseModel):
    """A base model that ensures the JSON schema properties and required fields are in the order of field definition."""

    @classmethod
    def model_json_schema(cls, by_alias: bool = True, **kwargs: Any) -> Dict[str, Any]:
        schema = super().model_json_schema(by_alias=by_alias, **kwargs)

        field_order = cls._get_field_order()

        if field_order:
            # Order properties
            ordered_properties = OrderedDict()
            for field_name in field_order:
                if field_name in schema['properties']:
                    ordered_properties[field_name] = schema['properties'][field_name]

            # Add any remaining properties that weren't in field_order
            for field_name, field_schema in schema['properties'].items():
                if field_name not in ordered_properties:
                    ordered_properties[field_name] = field_schema

            schema['properties'] = ordered_properties

            # Order required fields
            if 'required' in schema:
                ordered_required = [field for field in field_order if field in schema['required']]
                # Add any remaining required fields that weren't in field_order
                ordered_required.extend([field for field in schema['required'] if field not in ordered_required])
                schema['required'] = ordered_required

        return schema

    @classmethod
    def _get_field_order(cls) -> List[str]:
        """Get the order of fields as they were defined in the class."""
        source = inspect.getsource(cls)

        # Unindent the entire source code
        source = textwrap.dedent(source)

        try:
            module = ast.parse(source)
        except IndentationError:
            # If we still get an IndentationError, wrap the class in a dummy module
            source = f"class DummyModule:\n{textwrap.indent(source, '    ')}"
            module = ast.parse(source)
            # Adjust to look at the first class def inside DummyModule
            # noinspection PyUnresolvedReferences
            class_def = module.body[0].body[0]
        else:
            # Find the class definition
            class_def = next(
                node for node in module.body if isinstance(node, ast.ClassDef) and node.name == cls.__name__
            )

        # Extract field names in the order they were defined
        field_order = []
        for node in class_def.body:
            if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
                field_order.append(node.target.id)

        return field_order


class BaseAppInput(OrderedSchemaModel):
    pass


class BaseAppOutput(OrderedSchemaModel):
    pass


class BaseApp(BaseModel):
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra='allow'
    )

    async def setup(self):
        pass

    async def run(self, app_input: BaseAppInput) -> BaseAppOutput:
        raise NotImplementedError("run method must be implemented")

    async def unload(self):
        pass
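
As a quick illustration of what base.py provides, here is a hypothetical subclass sketch (EchoInput, EchoOutput, and EchoApp are made-up names, not part of the commit); it assumes pydantic v2 is installed and that the classes live in a regular source file so inspect.getsource can read them:

# Hypothetical sketch using the base classes defined above.
import asyncio

from inferencesh.models.base import BaseApp, BaseAppInput, BaseAppOutput


class EchoInput(BaseAppInput):
    prompt: str
    temperature: float = 0.7


class EchoOutput(BaseAppOutput):
    text: str


class EchoApp(BaseApp):
    async def run(self, app_input: EchoInput) -> EchoOutput:
        # setup() and unload() are optional lifecycle hooks; only run() is required.
        return EchoOutput(text=app_input.prompt)


# OrderedSchemaModel emits schema properties in declaration order.
print(list(EchoInput.model_json_schema()["properties"]))  # ['prompt', 'temperature']

print(asyncio.run(EchoApp().run(EchoInput(prompt="hello"))))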

src/inferencesh/models/file.py

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
from typing import Optional, Union, Any
from pydantic import BaseModel, Field, PrivateAttr, model_validator
import mimetypes
import os
import urllib.request
import urllib.parse
import tempfile
from tqdm import tqdm


class File(BaseModel):
    """A class representing a file in the inference.sh ecosystem."""
    uri: Optional[str] = Field(default=None)  # Original location (URL or file path)
    path: Optional[str] = None  # Resolved local file path
    content_type: Optional[str] = None  # MIME type of the file
    size: Optional[int] = None  # File size in bytes
    filename: Optional[str] = None  # Original filename if available
    _tmp_path: Optional[str] = PrivateAttr(default=None)  # Internal storage for temporary file path

    def __init__(self, initializer=None, **data):
        if initializer is not None:
            if isinstance(initializer, str):
                data['uri'] = initializer
            elif isinstance(initializer, File):
                data = initializer.model_dump()
            else:
                raise ValueError(f'Invalid input for File: {initializer}')
        super().__init__(**data)

    @model_validator(mode='before')
    @classmethod
    def convert_str_to_file(cls, values):
        if isinstance(values, str):  # Only accept strings
            return {"uri": values}
        elif isinstance(values, dict):
            return values
        raise ValueError(f'Invalid input for File: {values}')

    @model_validator(mode='after')
    def validate_required_fields(self) -> 'File':
        """Validate that either uri or path is provided."""
        if not self.uri and not self.path:
            raise ValueError("Either 'uri' or 'path' must be provided")
        return self

    def model_post_init(self, _: Any) -> None:
        """Initialize file path and metadata after model creation.

        This method handles:
        1. Downloading URLs to local files if uri is a URL
        2. Converting relative paths to absolute paths
        3. Populating file metadata
        """
        # Handle uri if provided
        if self.uri:
            if self._is_url(self.uri):
                self._download_url()
            else:
                # Convert relative paths to absolute, leave absolute paths unchanged
                self.path = os.path.abspath(self.uri)

        # Handle path if provided
        if self.path:
            # Convert relative paths to absolute, leave absolute paths unchanged
            self.path = os.path.abspath(self.path)
            self._populate_metadata()
            return

        raise ValueError("Either 'uri' or 'path' must be provided and be valid")

    def _is_url(self, path: str) -> bool:
        """Check if the path is a URL."""
        parsed = urllib.parse.urlparse(path)
        return parsed.scheme in ('http', 'https')

    def _download_url(self) -> None:
        """Download the URL to a temporary file and update the path."""
        original_url = self.uri
        tmp_file = None
        try:
            # Create a temporary file with a suffix based on the URL path
            suffix = os.path.splitext(urllib.parse.urlparse(original_url).path)[1]
            tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
            self._tmp_path = tmp_file.name

            # Set up request with user agent
            headers = {
                'User-Agent': (
                    'Mozilla/5.0 (Windows NT 10.0; Win64; x64) '
                    'AppleWebKit/537.36 (KHTML, like Gecko) '
                    'Chrome/91.0.4472.124 Safari/537.36'
                )
            }
            req = urllib.request.Request(original_url, headers=headers)

            # Download the file with progress bar
            print(f"Downloading URL: {original_url} to {self._tmp_path}")
            try:
                with urllib.request.urlopen(req) as response:
                    total_size = int(response.headers.get('content-length', 0))
                    block_size = 1024  # 1 Kibibyte

                    with tqdm(total=total_size, unit='iB', unit_scale=True) as pbar:
                        with open(self._tmp_path, 'wb') as out_file:
                            while True:
                                buffer = response.read(block_size)
                                if not buffer:
                                    break
                                out_file.write(buffer)
                                pbar.update(len(buffer))

                self.path = self._tmp_path
            except (urllib.error.URLError, urllib.error.HTTPError) as e:
                raise RuntimeError(f"Failed to download URL {original_url}: {str(e)}")
            except IOError as e:
                raise RuntimeError(f"Failed to write downloaded file to {self._tmp_path}: {str(e)}")
        except Exception as e:
            # Clean up temp file if something went wrong
            if tmp_file is not None and hasattr(self, '_tmp_path'):
                try:
                    os.unlink(self._tmp_path)
                except (OSError, IOError):
                    pass
            raise RuntimeError(f"Error downloading URL {original_url}: {str(e)}")

    def __del__(self):
        """Cleanup temporary file if it exists."""
        if hasattr(self, '_tmp_path') and self._tmp_path:
            try:
                os.unlink(self._tmp_path)
            except (OSError, IOError):
                pass

    def _populate_metadata(self) -> None:
        """Populate file metadata from the path if it exists."""
        if os.path.exists(self.path):
            if not self.content_type:
                self.content_type = self._guess_content_type()
            if not self.size:
                self.size = self._get_file_size()
            if not self.filename:
                self.filename = self._get_filename()

    @classmethod
    def from_path(cls, path: Union[str, os.PathLike]) -> 'File':
        """Create a File instance from a file path."""
        return cls(uri=str(path))

    def _guess_content_type(self) -> Optional[str]:
        """Guess the MIME type of the file."""
        return mimetypes.guess_type(self.path)[0]

    def _get_file_size(self) -> int:
        """Get the size of the file in bytes."""
        return os.path.getsize(self.path)

    def _get_filename(self) -> str:
        """Get the base filename from the path."""
        return os.path.basename(self.path)

    def exists(self) -> bool:
        """Check if the file exists."""
        return os.path.exists(self.path)

    def refresh_metadata(self) -> None:
        """Refresh all metadata from the file."""
        if os.path.exists(self.path):
            self.content_type = self._guess_content_type()
            self.size = self._get_file_size()  # Always update size
            self.filename = self._get_filename()

    @classmethod
    def model_json_schema(cls, **kwargs):
        schema = super().model_json_schema(**kwargs)
        schema["$id"] = "/schemas/File"
        # Create a schema that accepts either a string or the full object
        return {
            "oneOf": [
                {"type": "string"},  # Accept string input
                schema  # Accept full object input
            ]
        }
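
A short, hypothetical usage sketch of the File model (the local path and URL below are placeholders, not values from the commit; a real, reachable URL is needed for the download branch to succeed):

# Hypothetical sketch of the File model defined above.
from inferencesh import File

# A bare string is coerced to {"uri": ...} by the `before` validator, then
# model_post_init resolves it to an absolute path; content_type, size, and
# filename are only populated if the file actually exists on disk.
local_file = File("./weights.bin")  # placeholder path
print(local_file.path, local_file.content_type, local_file.size)

# An http(s) uri is downloaded to a temporary file with a tqdm progress bar;
# path then points at the temporary copy, which __del__ removes on cleanup.
remote_file = File(uri="https://example.com/sample.png")  # placeholder URL
print(remote_file.filename, remote_file.exists())

# The overridden model_json_schema accepts either a plain string or the full object.
print(File.model_json_schema()["oneOf"][0])  # {'type': 'string'}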
