Skip to content

Commit 375231e

Browse files
committed
improvements for File and Dirs
1 parent be74b56 commit 375231e

File tree

1 file changed

+36
-17
lines changed

1 file changed

+36
-17
lines changed

src/inferencesh/sdk.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -140,17 +140,29 @@ def validate_required_fields(self) -> 'File':
140140
return self
141141

142142
def model_post_init(self, _: Any) -> None:
143-
if self.uri and self._is_url(self.uri):
144-
self._download_url()
145-
elif self.uri and not os.path.isabs(self.uri):
146-
self.path = os.path.abspath(self.uri)
147-
elif self.uri:
148-
self.path = self.uri
143+
"""Initialize file path and metadata after model creation.
144+
145+
This method handles:
146+
1. Downloading URLs to local files if uri is a URL
147+
2. Converting relative paths to absolute paths
148+
3. Populating file metadata
149+
"""
150+
# Handle uri if provided
151+
if self.uri:
152+
if self._is_url(self.uri):
153+
self._download_url()
154+
else:
155+
# Convert relative paths to absolute, leave absolute paths unchanged
156+
self.path = os.path.abspath(self.uri)
157+
158+
# Handle path if provided
149159
if self.path:
160+
# Convert relative paths to absolute, leave absolute paths unchanged
161+
self.path = os.path.abspath(self.path)
150162
self._populate_metadata()
151-
else:
152-
raise ValueError("Either 'uri' or 'path' must be provided")
153-
163+
return
164+
165+
raise ValueError("Either 'uri' or 'path' must be provided and be valid")
154166
def _is_url(self, path: str) -> bool:
155167
"""Check if the path is a URL."""
156168
parsed = urllib.parse.urlparse(path)
@@ -326,19 +338,26 @@ class LLMInputWithImage(LLMInput):
326338
default=None
327339
)
328340

329-
class DownloadDir(str, Enum):
330-
"""Standard download directories used by the SDK."""
331-
DATA = "./data" # Persistent storage/cache directory
332-
TEMP = "./tmp" # Temporary storage directory
333-
CACHE = "./cache" # Cache directory
341+
class StorageDir(str, Enum):
342+
"""Standard storage directories used by the SDK."""
343+
DATA = "/app/data" # Persistent storage/cache directory
344+
TEMP = "/app/tmp" # Temporary storage directory
345+
CACHE = "/app/cache" # Cache directory
346+
347+
@property
348+
def path(self) -> Path:
349+
"""Get the Path object for this storage directory, ensuring it exists."""
350+
path = Path(self.value)
351+
path.mkdir(parents=True, exist_ok=True)
352+
return path
334353

335-
def download(url: str, directory: Union[str, Path, DownloadDir]) -> str:
354+
def download(url: str, directory: Union[str, Path, StorageDir]) -> str:
336355
"""Download a file to the specified directory and return its path.
337356
338357
Args:
339358
url: The URL to download from
340359
directory: The directory to save the file to. Can be a string path,
341-
Path object, or DownloadDir enum value.
360+
Path object, or StorageDir enum value.
342361
343362
Returns:
344363
str: The path to the downloaded file
@@ -360,7 +379,7 @@ def download(url: str, directory: Union[str, Path, DownloadDir]) -> str:
360379
output_path = hash_dir / filename
361380

362381
# If file exists in directory and it's not a temp directory, return it
363-
if output_path.exists() and directory != DownloadDir.TEMP:
382+
if output_path.exists() and directory != StorageDir.TEMP:
364383
return str(output_path)
365384

366385
# Download the file

0 commit comments

Comments
 (0)