Skip to content

Commit 940ed6a

Browse files
committed
improve file helper to accept strings
1 parent 79a8f3c commit 940ed6a

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

src/inferencesh/sdk.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Union, ClassVar
1+
from typing import Optional, Union
22
from pydantic import BaseModel, ConfigDict, PrivateAttr, model_validator
33
import mimetypes
44
import os
@@ -114,6 +114,25 @@ class File(BaseModel):
114114
populate_by_name=True
115115
)
116116

117+
@classmethod
118+
def __get_validators__(cls):
119+
# First yield the default validators
120+
yield cls.validate
121+
122+
@classmethod
123+
def validate(cls, value):
124+
"""Convert string values to File objects."""
125+
if isinstance(value, str):
126+
# If it's a string, treat it as a uri
127+
return cls(uri=value)
128+
elif isinstance(value, cls):
129+
# If it's already a File instance, return it as is
130+
return value
131+
elif isinstance(value, dict):
132+
# If it's a dict, use normal validation
133+
return cls(**value)
134+
raise ValueError(f'Invalid input for File: {value}')
135+
117136
@model_validator(mode='after')
118137
def check_uri_or_path(self) -> 'File':
119138
"""Validate that either uri or path is provided."""
@@ -170,7 +189,7 @@ def _download_url(self) -> None:
170189
if tmp_file is not None and hasattr(self, '_tmp_path'):
171190
try:
172191
os.unlink(self._tmp_path)
173-
except:
192+
except (OSError, IOError):
174193
pass
175194
raise RuntimeError(f"Error downloading URL {original_url}: {str(e)}")
176195

@@ -179,7 +198,7 @@ def __del__(self):
179198
if hasattr(self, '_tmp_path') and self._tmp_path:
180199
try:
181200
os.unlink(self._tmp_path)
182-
except:
201+
except (OSError, IOError):
183202
pass
184203

185204
def _populate_metadata(self) -> None:

0 commit comments

Comments
 (0)