Skip to content

Commit 3d167c1

Browse files
committed
url loading
1 parent e5a0795 commit 3d167c1

File tree

1 file changed

+33
-2
lines changed

1 file changed

+33
-2
lines changed

src/inferencesh/sdk.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
from pydantic import BaseModel, ConfigDict
33
import mimetypes
44
import os
5+
import urllib.request
6+
import urllib.parse
7+
import tempfile
58

69
class BaseAppInput(BaseModel):
710
pass
@@ -27,17 +30,45 @@ async def unload(self):
2730

2831
class File(BaseModel):
2932
"""A class representing a file in the inference.sh ecosystem."""
30-
path: str # Absolute path to the file
33+
path: str # Absolute path to the file or URL
3134
mime_type: Optional[str] = None # MIME type of the file
3235
size: Optional[int] = None # File size in bytes
3336
filename: Optional[str] = None # Original filename if available
37+
_tmp_path: Optional[str] = None # Internal storage for temporary file path
3438

3539
def __init__(self, **data):
3640
super().__init__(**data)
37-
if not os.path.isabs(self.path):
41+
if self._is_url(self.path):
42+
self._download_url()
43+
elif not os.path.isabs(self.path):
3844
self.path = os.path.abspath(self.path)
3945
self._populate_metadata()
4046

47+
def _is_url(self, path: str) -> bool:
48+
"""Check if the path is a URL."""
49+
parsed = urllib.parse.urlparse(path)
50+
return parsed.scheme in ('http', 'https')
51+
52+
def _download_url(self) -> None:
53+
"""Download the URL to a temporary file and update the path."""
54+
original_url = self.path
55+
# Create a temporary file with a suffix based on the URL path
56+
suffix = os.path.splitext(urllib.parse.urlparse(original_url).path)[1]
57+
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
58+
self._tmp_path = tmp_file.name
59+
60+
# Download the file
61+
urllib.request.urlretrieve(original_url, self._tmp_path)
62+
self.path = self._tmp_path
63+
64+
def __del__(self):
65+
"""Cleanup temporary file if it exists."""
66+
if hasattr(self, '_tmp_path') and self._tmp_path:
67+
try:
68+
os.unlink(self._tmp_path)
69+
except:
70+
pass
71+
4172
def _populate_metadata(self) -> None:
4273
"""Populate file metadata from the path if it exists."""
4374
if os.path.exists(self.path):

0 commit comments

Comments
 (0)