Skip to content

Commit 0457964

Browse files
committed
tests, file improvements, download and dir helpers
1 parent 61820c9 commit 0457964

File tree

3 files changed

+235
-38
lines changed

3 files changed

+235
-38
lines changed

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,9 @@ addopts = "-v"
3535
[tool.flake8]
3636
max-line-length = 100
3737
exclude = [".git", "__pycache__", "build", "dist"]
38+
39+
[project.optional-dependencies]
40+
test = [
41+
"pytest>=7.0.0",
42+
"pytest-cov>=4.0.0",
43+
]

src/inferencesh/sdk.py

Lines changed: 92 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
from typing import Optional, Union
2-
from pydantic import BaseModel, ConfigDict, PrivateAttr, model_validator
2+
from pydantic import BaseModel, ConfigDict, PrivateAttr, model_validator, Field, field_validator
33
import mimetypes
44
import os
55
import urllib.request
66
import urllib.parse
77
import tempfile
8-
from pydantic import Field
98
from typing import Any, Dict, List
109

1110
import inspect
1211
import ast
1312
import textwrap
1413
from collections import OrderedDict
14+
from enum import Enum
15+
import shutil
16+
from pathlib import Path
17+
import hashlib
1518

1619

1720
# inspired by https://github.com/pydantic/pydantic/issues/7580
@@ -102,39 +105,35 @@ async def unload(self):
102105

103106
class File(BaseModel):
104107
"""A class representing a file in the inference.sh ecosystem."""
105-
uri: Optional[str] = None # Original location (URL or file path)
108+
uri: Optional[str] = Field(default=None) # Original location (URL or file path)
106109
path: Optional[str] = None # Resolved local file path
107110
content_type: Optional[str] = None # MIME type of the file
108111
size: Optional[int] = None # File size in bytes
109112
filename: Optional[str] = None # Original filename if available
110113
_tmp_path: Optional[str] = PrivateAttr(default=None) # Internal storage for temporary file path
111114

112-
model_config = ConfigDict(
113-
arbitrary_types_allowed=True,
114-
populate_by_name=True
115-
)
116-
117-
@classmethod
118-
def __get_validators__(cls):
119-
# First yield the default validators
120-
yield cls.validate
121-
115+
def __init__(self, initializer=None, **data):
116+
if initializer is not None:
117+
if isinstance(initializer, str):
118+
data['uri'] = initializer
119+
elif isinstance(initializer, File):
120+
data = initializer.model_dump()
121+
else:
122+
raise ValueError(f'Invalid input for File: {initializer}')
123+
super().__init__(**data)
124+
125+
@model_validator(mode='before')
122126
@classmethod
123-
def validate(cls, value):
124-
"""Convert string values to File objects."""
125-
if isinstance(value, str):
126-
# If it's a string, treat it as a uri
127-
return cls(uri=value)
128-
elif isinstance(value, cls):
129-
# If it's already a File instance, return it as is
130-
return value
131-
elif isinstance(value, dict):
132-
# If it's a dict, use normal validation
133-
return cls(**value)
134-
raise ValueError(f'Invalid input for File: {value}')
135-
127+
def convert_str_to_file(cls, values):
128+
print(f"check_uri_or_path input: {values}")
129+
if isinstance(values, str): # Only accept strings
130+
return {"uri": values}
131+
elif isinstance(values, dict):
132+
return values
133+
raise ValueError(f'Invalid input for File: {values}')
134+
136135
@model_validator(mode='after')
137-
def check_uri_or_path(self) -> 'File':
136+
def validate_required_fields(self) -> 'File':
138137
"""Validate that either uri or path is provided."""
139138
if not self.uri and not self.path:
140139
raise ValueError("Either 'uri' or 'path' must be provided")
@@ -147,7 +146,10 @@ def model_post_init(self, _: Any) -> None:
147146
self.path = os.path.abspath(self.uri)
148147
elif self.uri:
149148
self.path = self.uri
150-
self._populate_metadata()
149+
if self.path:
150+
self._populate_metadata()
151+
else:
152+
raise ValueError("Either 'uri' or 'path' must be provided")
151153

152154
def _is_url(self, path: str) -> bool:
153155
"""Check if the path is a URL."""
@@ -234,13 +236,24 @@ def exists(self) -> bool:
234236

235237
def refresh_metadata(self) -> None:
236238
"""Refresh all metadata from the file."""
237-
self._populate_metadata()
239+
if os.path.exists(self.path):
240+
self.content_type = self._guess_content_type()
241+
self.size = self._get_file_size() # Always update size
242+
self.filename = self._get_filename()
238243

239244

245+
class ContextMessageRole(str, Enum):
246+
USER = "user"
247+
ASSISTANT = "assistant"
248+
SYSTEM = "system"
249+
250+
class Message(BaseModel):
251+
role: ContextMessageRole
252+
content: str
253+
240254
class ContextMessage(BaseModel):
241-
role: str = Field(
255+
role: ContextMessageRole = Field(
242256
description="The role of the message",
243-
enum=["user", "assistant", "system"]
244257
)
245258
text: str = Field(
246259
description="The text content of the message"
@@ -300,4 +313,51 @@ class LLMInputWithImage(LLMInput):
300313
image: Optional[File] = Field(
301314
description="The image to use for the model",
302315
default=None
303-
)
316+
)
317+
318+
class DownloadDir(str, Enum):
319+
"""Standard download directories used by the SDK."""
320+
DATA = "./data" # Persistent storage/cache directory
321+
TEMP = "./tmp" # Temporary storage directory
322+
CACHE = "./cache" # Cache directory
323+
324+
def download(url: str, directory: Union[str, Path, DownloadDir]) -> str:
325+
"""Download a file to the specified directory and return its path.
326+
327+
Args:
328+
url: The URL to download from
329+
directory: The directory to save the file to. Can be a string path,
330+
Path object, or DownloadDir enum value.
331+
332+
Returns:
333+
str: The path to the downloaded file
334+
"""
335+
# Convert directory to Path
336+
dir_path = Path(directory)
337+
dir_path.mkdir(exist_ok=True)
338+
339+
# Create hash directory from URL
340+
url_hash = hashlib.sha256(url.encode()).hexdigest()[:12]
341+
hash_dir = dir_path / url_hash
342+
hash_dir.mkdir(exist_ok=True)
343+
344+
# Keep original filename
345+
filename = os.path.basename(urllib.parse.urlparse(url).path)
346+
if not filename:
347+
filename = 'download'
348+
349+
output_path = hash_dir / filename
350+
351+
# If file exists in directory and it's not a temp directory, return it
352+
if output_path.exists() and directory != DownloadDir.TEMP:
353+
return str(output_path)
354+
355+
# Download the file
356+
file = File(url)
357+
if file.path:
358+
shutil.copy2(file.path, output_path)
359+
# Prevent the File instance from deleting its temporary file
360+
file._tmp_path = None
361+
return str(output_path)
362+
363+
raise RuntimeError(f"Failed to download {url}")

tests/test_sdk.py

Lines changed: 137 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
22
import pytest
3+
import tempfile
34
from inferencesh import BaseApp, BaseAppInput, BaseAppOutput, File
5+
import urllib.parse
46

57
def test_file_creation():
68
# Create a temporary file
@@ -22,10 +24,139 @@ class TestInput(BaseAppInput):
2224
class TestOutput(BaseAppOutput):
2325
result: str
2426

25-
class TestApp(BaseApp):
26-
async def run(self, app_input: TestInput) -> TestOutput:
27-
return TestOutput(result=f"Processed: {app_input.text}")
28-
29-
app = TestApp()
27+
# Use BaseApp directly, don't subclass with implementation
28+
app = BaseApp()
29+
import asyncio
3030
with pytest.raises(NotImplementedError):
31-
app.run(TestInput(text="test"))
31+
asyncio.run(app.run(TestInput(text="test")))
32+
33+
def test_file_from_local_path():
34+
# Create a temporary file
35+
with tempfile.NamedTemporaryFile(suffix='.txt', delete=False) as f:
36+
f.write(b"test content")
37+
path = f.name
38+
39+
try:
40+
# Test creating File from path
41+
file = File(uri=path)
42+
assert file.exists()
43+
assert file.size == len("test content")
44+
assert file.content_type == "text/plain"
45+
assert file.filename == os.path.basename(path)
46+
assert file.path == os.path.abspath(path)
47+
assert file._tmp_path is None # Should not create temp file for local paths
48+
finally:
49+
os.unlink(path)
50+
51+
def test_file_from_relative_path():
52+
# Create a file in current directory
53+
with open("test_relative.txt", "w") as f:
54+
f.write("relative test")
55+
56+
try:
57+
file = File(uri="test_relative.txt")
58+
assert file.exists()
59+
assert os.path.isabs(file.path)
60+
assert file.filename == "test_relative.txt"
61+
finally:
62+
os.unlink("test_relative.txt")
63+
64+
def test_file_validation():
65+
# Test empty initialization
66+
with pytest.raises(ValueError, match="Either 'uri' or 'path' must be provided"):
67+
File()
68+
69+
# Test invalid input type
70+
with pytest.raises(ValueError, match="Invalid input for File"):
71+
File(123)
72+
73+
# Test string input (should work)
74+
with tempfile.NamedTemporaryFile(suffix='.txt', delete=False) as f:
75+
f.write(b"test content")
76+
path = f.name
77+
78+
try:
79+
file = File(path)
80+
assert isinstance(file, File)
81+
assert file.uri == path
82+
assert file.exists()
83+
finally:
84+
os.unlink(path)
85+
86+
def test_file_from_url(monkeypatch):
87+
# Mock URL download
88+
def mock_urlopen(request):
89+
class MockResponse:
90+
def __enter__(self):
91+
return self
92+
93+
def __exit__(self, *args):
94+
pass
95+
96+
def read(self):
97+
return b"mocked content"
98+
99+
return MockResponse()
100+
101+
monkeypatch.setattr(urllib.request, 'urlopen', mock_urlopen)
102+
103+
url = "https://example.com/test.pdf"
104+
file = File(uri=url)
105+
106+
try:
107+
assert file._is_url(url)
108+
assert file.exists()
109+
assert file._tmp_path is not None
110+
assert file._tmp_path.endswith('.pdf') # Just check the extension
111+
assert file.content_type == "application/pdf"
112+
finally:
113+
# Cleanup should happen in __del__ but let's be explicit for testing
114+
if file._tmp_path and os.path.exists(file._tmp_path):
115+
os.unlink(file._tmp_path)
116+
117+
def test_file_metadata_refresh():
118+
with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as f:
119+
initial_content = b'{"test": "data"}'
120+
f.write(initial_content)
121+
path = f.name
122+
123+
try:
124+
file = File(uri=path)
125+
initial_size = file.size
126+
127+
# Modify file with significantly more data
128+
with open(path, 'ab') as f: # Open in append binary mode
129+
additional_data = b'\n{"more": "data"}\n' * 10 # Add multiple lines of data
130+
f.write(additional_data)
131+
132+
# Refresh metadata
133+
file.refresh_metadata()
134+
assert file.size > initial_size, f"New size {file.size} should be larger than initial size {initial_size}"
135+
finally:
136+
os.unlink(path)
137+
138+
def test_file_cleanup(monkeypatch):
139+
# Mock URL download - same mock as test_file_from_url
140+
def mock_urlopen(request):
141+
class MockResponse:
142+
def __enter__(self):
143+
return self
144+
145+
def __exit__(self, *args):
146+
pass
147+
148+
def read(self):
149+
return b"mocked content"
150+
151+
return MockResponse()
152+
153+
monkeypatch.setattr(urllib.request, 'urlopen', mock_urlopen)
154+
155+
url = "https://example.com/test.txt"
156+
file = File(uri=url)
157+
158+
if file._tmp_path:
159+
tmp_path = file._tmp_path
160+
assert os.path.exists(tmp_path)
161+
del file
162+
assert not os.path.exists(tmp_path)

0 commit comments

Comments
 (0)