11from typing import Optional , Union
2- from pydantic import BaseModel , ConfigDict , PrivateAttr , model_validator
2+ from pydantic import BaseModel , ConfigDict , PrivateAttr , model_validator , Field , field_validator
33import mimetypes
44import os
55import urllib .request
66import urllib .parse
77import tempfile
8- from pydantic import Field
98from typing import Any , Dict , List
109
1110import inspect
1211import ast
1312import textwrap
1413from collections import OrderedDict
14+ from enum import Enum
15+ import shutil
16+ from pathlib import Path
17+ import hashlib
1518
1619
1720# inspired by https://github.com/pydantic/pydantic/issues/7580
@@ -102,39 +105,35 @@ async def unload(self):
102105
103106class File (BaseModel ):
104107 """A class representing a file in the inference.sh ecosystem."""
105- uri : Optional [str ] = None # Original location (URL or file path)
108+ uri : Optional [str ] = Field ( default = None ) # Original location (URL or file path)
106109 path : Optional [str ] = None # Resolved local file path
107110 content_type : Optional [str ] = None # MIME type of the file
108111 size : Optional [int ] = None # File size in bytes
109112 filename : Optional [str ] = None # Original filename if available
110113 _tmp_path : Optional [str ] = PrivateAttr (default = None ) # Internal storage for temporary file path
111114
112- model_config = ConfigDict (
113- arbitrary_types_allowed = True ,
114- populate_by_name = True
115- )
116-
117- @classmethod
118- def __get_validators__ (cls ):
119- # First yield the default validators
120- yield cls .validate
121-
def __init__(self, initializer=None, **data):
    """Create a File from a URI string, another File, or field keywords.

    Args:
        initializer: Optional positional shortcut. A ``str`` is used as
            the ``uri`` field; a ``File`` has its dumped fields copied.
        **data: Field values passed by keyword. When ``initializer`` is a
            ``File`` these override the copied fields instead of being
            silently discarded.

    Raises:
        ValueError: If ``initializer`` is neither ``None``, a ``str`` nor
            a ``File``.
    """
    if initializer is not None:
        if isinstance(initializer, str):
            data['uri'] = initializer
        elif isinstance(initializer, File):
            # Start from the source's fields but keep explicit keyword
            # overrides; previously they were dropped without warning.
            data = {**initializer.model_dump(), **data}
        else:
            raise ValueError(f'Invalid input for File: {initializer}')
    super().__init__(**data)
124+
@model_validator(mode='before')
@classmethod
def convert_str_to_file(cls, values):
    """Normalise raw input before field validation.

    Args:
        values: The raw value handed to the model. A ``str`` is treated
            as a URI; a ``dict`` or an existing instance of this class is
            passed through unchanged.

    Returns:
        ``{"uri": values}`` for a string, otherwise ``values`` itself.

    Raises:
        ValueError: If ``values`` is not a ``str``, ``dict`` or instance
            of this class.
    """
    if isinstance(values, str):
        return {"uri": values}
    # A before-mode validator sees the raw input, which can be an already
    # constructed File (e.g. assigned to a File-typed field); accept it
    # rather than rejecting it as invalid.
    if isinstance(values, (dict, cls)):
        return values
    raise ValueError(f'Invalid input for File: {values}')
134+
136135 @model_validator (mode = 'after' )
137- def check_uri_or_path (self ) -> 'File' :
136+ def validate_required_fields (self ) -> 'File' :
138137 """Validate that either uri or path is provided."""
139138 if not self .uri and not self .path :
140139 raise ValueError ("Either 'uri' or 'path' must be provided" )
@@ -147,7 +146,10 @@ def model_post_init(self, _: Any) -> None:
147146 self .path = os .path .abspath (self .uri )
148147 elif self .uri :
149148 self .path = self .uri
150- self ._populate_metadata ()
149+ if self .path :
150+ self ._populate_metadata ()
151+ else :
152+ raise ValueError ("Either 'uri' or 'path' must be provided" )
151153
152154 def _is_url (self , path : str ) -> bool :
153155 """Check if the path is a URL."""
@@ -234,13 +236,24 @@ def exists(self) -> bool:
234236
def refresh_metadata(self) -> None:
    """Refresh content type, size and filename from the file on disk.

    Does nothing when ``path`` is unset or does not exist, leaving any
    previously stored metadata untouched.
    """
    # ``path`` is Optional; os.path.exists(None) raises TypeError, so
    # check for an unset path before touching the filesystem.
    if not self.path or not os.path.exists(self.path):
        return
    self.content_type = self._guess_content_type()
    self.size = self._get_file_size()  # Always update size
    self.filename = self._get_filename()
238243
239244
class ContextMessageRole(str, Enum):
    """Allowed values for a message's ``role`` field.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
249+
class Message(BaseModel):
    """A message made of a role and its string content."""

    # Must be one of the ContextMessageRole members.
    role: ContextMessageRole
    # The message body.
    content: str
253+
240254class ContextMessage (BaseModel ):
241- role : str = Field (
255+ role : ContextMessageRole = Field (
242256 description = "The role of the message" ,
243- enum = ["user" , "assistant" , "system" ]
244257 )
245258 text : str = Field (
246259 description = "The text content of the message"
@@ -300,4 +313,51 @@ class LLMInputWithImage(LLMInput):
300313 image : Optional [File ] = Field (
301314 description = "The image to use for the model" ,
302315 default = None
303- )
316+ )
317+
class DownloadDir(str, Enum):
    """Standard download directories used by the SDK.

    Each member's value is a path relative to the current working
    directory.
    """

    # Persistent storage/cache directory
    DATA = "./data"
    # Temporary storage directory
    TEMP = "./tmp"
    # Cache directory
    CACHE = "./cache"
323+
def download(url: str, directory: Union[str, Path, DownloadDir]) -> str:
    """Download a file to the specified directory and return its path.

    The file is stored under ``<directory>/<hash>/<filename>``, where
    ``<hash>`` is the first 12 hex digits of the URL's SHA-256 and
    ``<filename>`` is the last path segment of the URL (``download`` when
    the URL has none). An existing file at that location is reused unless
    the target is the temp directory.

    Args:
        url: The URL to download from
        directory: The directory to save the file to. Can be a string path,
            Path object, or DownloadDir enum value.

    Returns:
        str: The path to the downloaded file

    Raises:
        RuntimeError: If the File created for ``url`` has no resolved path.
    """
    # Unwrap enum members explicitly: older pathlib versions build the
    # path from str(member), which for a (str, Enum) mix-in is the text
    # "DownloadDir.DATA" rather than the member's value.
    if isinstance(directory, DownloadDir):
        directory = directory.value
    dir_path = Path(directory)

    # Compare as paths so "./tmp", "tmp" and Path("tmp") all count as temp.
    is_temp = dir_path == Path(DownloadDir.TEMP.value)

    # Create hash directory from URL; parents=True also creates the base
    # directory, including any missing intermediate directories.
    url_hash = hashlib.sha256(url.encode()).hexdigest()[:12]
    hash_dir = dir_path / url_hash
    hash_dir.mkdir(parents=True, exist_ok=True)

    # Keep original filename
    filename = os.path.basename(urllib.parse.urlparse(url).path) or 'download'
    output_path = hash_dir / filename

    # If file exists in directory and it's not a temp directory, return it
    if output_path.exists() and not is_temp:
        return str(output_path)

    # Download the file
    file = File(url)
    if file.path:
        shutil.copy2(file.path, output_path)
        # Prevent the File instance from deleting its temporary file
        file._tmp_path = None
        return str(output_path)

    raise RuntimeError(f"Failed to download {url}")
0 commit comments