|
5 | 5 | import urllib.request |
6 | 6 | import urllib.parse |
7 | 7 | import tempfile |
| 8 | +import hashlib |
| 9 | +from pathlib import Path |
8 | 10 | from tqdm import tqdm |
9 | 11 |
|
10 | 12 |
|
11 | 13 | class File(BaseModel): |
12 | 14 | """A class representing a file in the inference.sh ecosystem.""" |
| 15 | + |
| 16 | + @classmethod |
| 17 | + def get_cache_dir(cls) -> Path: |
| 18 | + """Get the cache directory path based on environment variables or default location.""" |
| 19 | + if cache_dir := os.environ.get("FILE_CACHE_DIR"): |
| 20 | + path = Path(cache_dir) |
| 21 | + else: |
| 22 | + path = Path.home() / ".cache" / "inferencesh" / "files" |
| 23 | + path.mkdir(parents=True, exist_ok=True) |
| 24 | + return path |
| 25 | + |
| 26 | + def _get_cache_path(self, url: str) -> Path: |
| 27 | + """Get the cache path for a URL using a hash-based directory structure.""" |
| 28 | + # Parse URL components |
| 29 | + parsed_url = urllib.parse.urlparse(url) |
| 30 | + |
| 31 | + # Create hash from URL path and query parameters for uniqueness |
| 32 | + url_components = parsed_url.netloc + parsed_url.path |
| 33 | + if parsed_url.query: |
| 34 | + url_components += '?' + parsed_url.query |
| 35 | + url_hash = hashlib.sha256(url_components.encode()).hexdigest()[:12] |
| 36 | + |
| 37 | + # Get filename from URL or use default |
| 38 | + filename = os.path.basename(parsed_url.path) |
| 39 | + if not filename: |
| 40 | + filename = 'download' |
| 41 | + |
| 42 | + # Create hash directory in cache |
| 43 | + cache_dir = self.get_cache_dir() / url_hash |
| 44 | + cache_dir.mkdir(exist_ok=True) |
| 45 | + |
| 46 | + return cache_dir / filename |
    # Original location of the file: either a remote URL or a local filesystem path.
    uri: Optional[str] = Field(default=None)  # Original location (URL or file path)
    # Local path once the file has been resolved/downloaded; None until then.
    path: Optional[str] = None  # Resolved local file path
    # MIME type of the file, when known (e.g. from the server's Content-Type header).
    content_type: Optional[str] = None  # MIME type of the file
@@ -74,11 +108,20 @@ def _is_url(self, path: str) -> bool: |
74 | 108 | return parsed.scheme in ('http', 'https') |
75 | 109 |
|
76 | 110 | def _download_url(self) -> None: |
77 | | - """Download the URL to a temporary file and update the path.""" |
| 111 | + """Download the URL to the cache directory and update the path.""" |
78 | 112 | original_url = self.uri |
| 113 | + cache_path = self._get_cache_path(original_url) |
| 114 | + |
| 115 | + # If file exists in cache, use it |
| 116 | + if cache_path.exists(): |
| 117 | + print(f"Using cached file: {cache_path}") |
| 118 | + self.path = str(cache_path) |
| 119 | + return |
| 120 | + |
| 121 | + print(f"Downloading URL: {original_url} to {cache_path}") |
79 | 122 | tmp_file = None |
80 | 123 | try: |
81 | | - # Create a temporary file with a suffix based on the URL path |
| 124 | + # Download to temporary file first to avoid partial downloads in cache |
82 | 125 | suffix = os.path.splitext(urllib.parse.urlparse(original_url).path)[1] |
83 | 126 | tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix) |
84 | 127 | self._tmp_path = tmp_file.name |
@@ -133,7 +176,10 @@ def _download_url(self) -> None: |
133 | 176 | # If we read the whole body at once, exit loop |
134 | 177 | break |
135 | 178 |
|
136 | | - self.path = self._tmp_path |
| 179 | + # Move the temporary file to the cache location |
| 180 | + os.replace(self._tmp_path, cache_path) |
| 181 | + self._tmp_path = None # Prevent deletion in __del__ |
| 182 | + self.path = str(cache_path) |
137 | 183 | except (urllib.error.URLError, urllib.error.HTTPError) as e: |
138 | 184 | raise RuntimeError(f"Failed to download URL {original_url}: {str(e)}") |
139 | 185 | except IOError as e: |
|
0 commit comments