1+ from typing import Optional , Union , Any
2+ from pydantic import BaseModel , Field , PrivateAttr , model_validator
3+ import mimetypes
4+ import os
5+ import urllib .request
6+ import urllib .parse
7+ import tempfile
8+ from tqdm import tqdm
9+
10+
class File(BaseModel):
    """A class representing a file in the inference.sh ecosystem.

    A ``File`` can be built from a plain string (treated as ``uri``), from
    another ``File`` (its fields are copied), or from keyword arguments.
    After validation, ``path`` always holds an absolute local path:

    * an ``http``/``https`` ``uri`` is downloaded to a temporary file, which
      is removed again when the instance is garbage collected;
    * any other ``uri`` is treated as a local path;
    * otherwise the supplied ``path`` is used.

    ``content_type``, ``size`` and ``filename`` are filled in from the local
    file when it exists and the caller did not supply them.
    """

    uri: Optional[str] = Field(default=None)  # Original location (URL or file path)
    path: Optional[str] = None  # Resolved local file path
    content_type: Optional[str] = None  # MIME type of the file
    size: Optional[int] = None  # File size in bytes
    filename: Optional[str] = None  # Original filename if available
    _tmp_path: Optional[str] = PrivateAttr(default=None)  # Internal storage for temporary file path

    def __init__(self, initializer=None, **data):
        """Create a File from a string, another File, or field keywords.

        Args:
            initializer: Optional positional shortcut. A ``str`` is used as
                ``uri``; a ``File`` has its public fields copied.
            **data: Field values. When ``initializer`` is a ``File`` these
                override the copied fields.

        Raises:
            ValueError: If ``initializer`` is neither ``str`` nor ``File``.
        """
        if initializer is not None:
            if isinstance(initializer, str):
                data['uri'] = initializer
            elif isinstance(initializer, File):
                # Explicit keyword arguments win over the copied values
                # instead of being silently discarded.
                data = {**initializer.model_dump(), **data}
            else:
                raise ValueError(f'Invalid input for File: {initializer}')
        super().__init__(**data)

    @model_validator(mode='before')
    @classmethod
    def convert_str_to_file(cls, values):
        """Normalise raw validation input: a bare string becomes ``{"uri": ...}``.

        Raises:
            ValueError: If the input is neither a string nor a dict.
        """
        if isinstance(values, str):  # Only accept strings
            return {"uri": values}
        if isinstance(values, dict):
            return values
        raise ValueError(f'Invalid input for File: {values}')

    @model_validator(mode='after')
    def validate_required_fields(self) -> 'File':
        """Validate that either uri or path is provided."""
        if not self.uri and not self.path:
            raise ValueError("Either 'uri' or 'path' must be provided")
        return self

    def model_post_init(self, _: Any) -> None:
        """Initialize file path and metadata after model creation.

        This method handles:
        1. Downloading URLs to local files if uri is a URL
        2. Converting relative paths to absolute paths
        3. Populating file metadata

        Raises:
            ValueError: If neither ``uri`` nor ``path`` yields a local path.
            RuntimeError: If ``uri`` is a URL and the download fails.
        """
        if self.uri:
            if self._is_url(self.uri):
                self._download_url()
            else:
                # A non-URL uri is a local path; it takes precedence over
                # any explicitly supplied ``path``.
                self.path = self.uri

        if not self.path:
            raise ValueError("Either 'uri' or 'path' must be provided and be valid")

        # Convert relative paths to absolute, leave absolute paths unchanged
        self.path = os.path.abspath(self.path)
        self._populate_metadata()

    def _is_url(self, path: str) -> bool:
        """Check if the path is an http(s) URL."""
        parsed = urllib.parse.urlparse(path)
        return parsed.scheme in ('http', 'https')

    def _download_url(self) -> None:
        """Download ``self.uri`` to a temporary file and point ``path`` at it.

        On success ``_tmp_path`` and ``path`` both name the downloaded file.
        On failure the partial file is removed and ``_tmp_path`` is reset.

        Raises:
            RuntimeError: If the request, the transfer or the local write
                fails; the original exception is chained as ``__cause__``.
        """
        original_url = self.uri
        # Keep the URL's extension so MIME-type guessing works on the copy.
        suffix = os.path.splitext(urllib.parse.urlparse(original_url).path)[1]
        # Some servers reject urllib's default user agent, so send a
        # browser-like one.
        headers = {
            'User-Agent': (
                'Mozilla/5.0 (Windows NT 10.0; Win64; x64) '
                'AppleWebKit/537.36 (KHTML, like Gecko) '
                'Chrome/91.0.4472.124 Safari/537.36'
            )
        }
        chunk_size = 64 * 1024  # bytes per read

        try:
            req = urllib.request.Request(original_url, headers=headers)
            # Write through the temp file's own handle so it is closed
            # deterministically (and before any cleanup unlink runs).
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as out_file:
                self._tmp_path = out_file.name
                print(f"Downloading URL: {original_url} to {self._tmp_path}")
                with urllib.request.urlopen(req) as response:
                    total_size = int(response.headers.get('content-length', 0))
                    # Unknown length -> let tqdm show a plain byte counter.
                    with tqdm(total=total_size or None, unit='iB', unit_scale=True) as pbar:
                        while True:
                            buffer = response.read(chunk_size)
                            if not buffer:
                                break
                            out_file.write(buffer)
                            pbar.update(len(buffer))
        except Exception as e:
            # Clean up temp file if something went wrong
            self._remove_tmp_file()
            raise RuntimeError(f"Error downloading URL {original_url}: {e}") from e

        self.path = self._tmp_path

    def _remove_tmp_file(self) -> None:
        """Best-effort removal of the downloaded temporary file, if any."""
        # getattr guards against instances whose private state was never
        # initialised (e.g. validation failed part-way through __init__).
        tmp_path = getattr(self, '_tmp_path', None)
        if not tmp_path:
            return
        try:
            os.unlink(tmp_path)
        except OSError:
            # Already gone or not removable; nothing useful to do here.
            pass
        self._tmp_path = None

    def __del__(self):
        """Cleanup temporary file if it exists."""
        self._remove_tmp_file()

    def _populate_metadata(self) -> None:
        """Populate missing file metadata from the path if it exists."""
        if os.path.exists(self.path):
            if not self.content_type:
                self.content_type = self._guess_content_type()
            if not self.size:
                self.size = self._get_file_size()
            if not self.filename:
                self.filename = self._get_filename()

    @classmethod
    def from_path(cls, path: Union[str, os.PathLike]) -> 'File':
        """Create a File instance from a file path (``str`` or path-like)."""
        # os.fsdecode follows the path protocol and always yields ``str``.
        return cls(uri=os.fsdecode(path))

    def _guess_content_type(self) -> Optional[str]:
        """Guess the MIME type of the file from its name."""
        return mimetypes.guess_type(self.path)[0]

    def _get_file_size(self) -> int:
        """Get the size of the file in bytes."""
        return os.path.getsize(self.path)

    def _get_filename(self) -> str:
        """Get the base filename from the path."""
        return os.path.basename(self.path)

    def exists(self) -> bool:
        """Check if the file exists."""
        return os.path.exists(self.path)

    def refresh_metadata(self) -> None:
        """Refresh all metadata from the file, overwriting current values."""
        if os.path.exists(self.path):
            self.content_type = self._guess_content_type()
            self.size = self._get_file_size()  # Always update size
            self.filename = self._get_filename()

    @classmethod
    def model_json_schema(cls, **kwargs):
        """Return a JSON schema accepting either a string or the full object."""
        schema = super().model_json_schema(**kwargs)
        schema["$id"] = "/schemas/File"
        # Create a schema that accepts either a string or the full object
        return {
            "oneOf": [
                {"type": "string"},  # Accept string input
                schema,  # Accept full object input
            ]
        }
0 commit comments