
Commit aa72a3e: File cache
1 parent 1d9b607 commit aa72a3e

3 files changed: +87 / -26 lines changed


examples/run.py

Lines changed: 23 additions & 16 deletions
@@ -5,40 +5,47 @@
 
 from inferencesh import Inference, TaskStatus
 
+
 def main() -> None:
-    api_key = "YOUR_INFERENCESH_API_KEY"
-    client = Inference(api_key=api_key)
-
+    api_key = "1nfsh-0fd8bxd5faawt9m0cztdym4q6s"
+    client = Inference(api_key=api_key, base_url="https://api-dev.inference.sh")
+
     app = "infsh/text-templating"
 
     try:
-        result = client.run_sync(
+        task = client.run(
             {
-                "app": app,
+                "app": "infsh/lightning-wan-2-2-i2v-a14b",
                 "input": {
-                    "template": "{1} / {2}",
-                    "strings": [
-                        "god",
-                        "particle",
-                    ]
+                    "negative_prompt": "oversaturated, overexposed, static, blurry details, subtitles, stylized, artwork, painting, still image, overall gray, worst quality, low quality, JPEG artifacts, ugly, deformed, extra fingers, poorly drawn hands, poorly drawn face, malformed, disfigured, deformed limbs, fused fingers, static motionless frame, cluttered background, three legs, crowded background, walking backwards",
+                    "prompt": "test",
+                    "num_frames": 81,
+                    "num_inference_steps": 4,
+                    "fps": 16,
+                    "boundary_ratio": 0.875,
+                    "image": "https://images.dev.letz.ai/5ed74083-f9d1-4897-b8e3-c8f1596af767/fa6b9cbc-9465-4fe8-b5ba-08c7a75d4975/drawing_extreme_closeup_portrait_of_junck37342762320240205225633.jpg",
                 },
-                "worker_selection_mode": "private",
-            },
+                "infra": "private",
+                # "workers": [],
+                "variant": "fp16_480p",
+            }
         )
 
+        print(task["id"])
 
-        # Print final result
-        if result.get("status") == TaskStatus.COMPLETED:
+        # Print final task
+        if task.get("status") == TaskStatus.COMPLETED:
             print(f"\n✓ task completed successfully!")
-            print(f"result: {result.get('output', {}).get('result')}")
+            print(f"task: {task.get('output', {}).get('task')}")
         else:
-            status = result.get("status")
+            status = task.get("status")
             status_name = TaskStatus(status).name if status is not None else "UNKNOWN"
             print(f"\n✗ task did not complete. final status: {status_name}")
 
     except Exception as exc:  # noqa: BLE001
         print(f"\nerror during run_sync: {type(exc).__name__}: {exc}")
         raise  # Re-raise to see full traceback
 
+
 if __name__ == "__main__":
     main()

src/inferencesh/models/file.py

Lines changed: 49 additions & 3 deletions
@@ -5,11 +5,45 @@
 import urllib.request
 import urllib.parse
 import tempfile
+import hashlib
+from pathlib import Path
 from tqdm import tqdm
 
 
 class File(BaseModel):
     """A class representing a file in the inference.sh ecosystem."""
+
+    @classmethod
+    def get_cache_dir(cls) -> Path:
+        """Get the cache directory path based on environment variables or default location."""
+        if cache_dir := os.environ.get("FILE_CACHE_DIR"):
+            path = Path(cache_dir)
+        else:
+            path = Path.home() / ".cache" / "inferencesh" / "files"
+        path.mkdir(parents=True, exist_ok=True)
+        return path
+
+    def _get_cache_path(self, url: str) -> Path:
+        """Get the cache path for a URL using a hash-based directory structure."""
+        # Parse URL components
+        parsed_url = urllib.parse.urlparse(url)
+
+        # Create hash from URL path and query parameters for uniqueness
+        url_components = parsed_url.netloc + parsed_url.path
+        if parsed_url.query:
+            url_components += '?' + parsed_url.query
+        url_hash = hashlib.sha256(url_components.encode()).hexdigest()[:12]
+
+        # Get filename from URL or use default
+        filename = os.path.basename(parsed_url.path)
+        if not filename:
+            filename = 'download'
+
+        # Create hash directory in cache
+        cache_dir = self.get_cache_dir() / url_hash
+        cache_dir.mkdir(exist_ok=True)
+
+        return cache_dir / filename
     uri: Optional[str] = Field(default=None)  # Original location (URL or file path)
     path: Optional[str] = None  # Resolved local file path
     content_type: Optional[str] = None  # MIME type of the file

@@ -74,11 +108,20 @@ def _is_url(self, path: str) -> bool:
         return parsed.scheme in ('http', 'https')
 
     def _download_url(self) -> None:
-        """Download the URL to a temporary file and update the path."""
+        """Download the URL to the cache directory and update the path."""
         original_url = self.uri
+        cache_path = self._get_cache_path(original_url)
+
+        # If file exists in cache, use it
+        if cache_path.exists():
+            print(f"Using cached file: {cache_path}")
+            self.path = str(cache_path)
+            return
+
+        print(f"Downloading URL: {original_url} to {cache_path}")
         tmp_file = None
         try:
-            # Create a temporary file with a suffix based on the URL path
+            # Download to temporary file first to avoid partial downloads in cache
            suffix = os.path.splitext(urllib.parse.urlparse(original_url).path)[1]
            tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
            self._tmp_path = tmp_file.name

@@ -133,7 +176,10 @@ def _download_url(self) -> None:
                     # If we read the whole body at once, exit loop
                     break
 
-            self.path = self._tmp_path
+            # Move the temporary file to the cache location
+            os.replace(self._tmp_path, cache_path)
+            self._tmp_path = None  # Prevent deletion in __del__
+            self.path = str(cache_path)
         except (urllib.error.URLError, urllib.error.HTTPError) as e:
             raise RuntimeError(f"Failed to download URL {original_url}: {str(e)}")
         except IOError as e:

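For reference, the cache layout this commit introduces can be reproduced with a small standalone sketch (stdlib only, not part of the diff; the example URLs below are illustrative, while FILE_CACHE_DIR and the ~/.cache/inferencesh/files default come from get_cache_dir above):

import hashlib
import os
import urllib.parse
from pathlib import Path

def cache_path_for(url: str) -> Path:
    # Mirrors File._get_cache_path: <cache dir>/<first 12 hex chars of
    # sha256(netloc + path [+ '?' + query])>/<basename of path, or 'download'>.
    default = Path.home() / ".cache" / "inferencesh" / "files"
    cache_dir = Path(os.environ.get("FILE_CACHE_DIR", str(default)))
    parsed = urllib.parse.urlparse(url)
    components = parsed.netloc + parsed.path
    if parsed.query:
        components += "?" + parsed.query
    url_hash = hashlib.sha256(components.encode()).hexdigest()[:12]
    filename = os.path.basename(parsed.path) or "download"
    return cache_dir / url_hash / filename

# Two URLs that differ only in their query string hash into different
# cache directories, so cached files never collide.
print(cache_path_for("https://example.com/assets/cat.jpg"))
print(cache_path_for("https://example.com/assets/cat.jpg?size=large"))

Note that the diff also downloads into a NamedTemporaryFile first and only os.replace()s it to cache_path once the body has been read in full, so an interrupted download never leaves a partial file in the cache.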
src/inferencesh/utils/download.py

Lines changed: 15 additions & 7 deletions
@@ -24,16 +24,24 @@ def download(url: str, directory: Union[str, Path, StorageDir]) -> str:
     dir_path = Path(directory)
     dir_path.mkdir(exist_ok=True)
 
-    # Create hash directory from URL
-    url_hash = hashlib.sha256(url.encode()).hexdigest()[:12]
-    hash_dir = dir_path / url_hash
-    hash_dir.mkdir(exist_ok=True)
+    # Parse URL components
+    parsed_url = urllib.parse.urlparse(url)
 
-    # Keep original filename
-    filename = os.path.basename(urllib.parse.urlparse(url).path)
+    # Create hash from URL path and query parameters for uniqueness
+    url_components = parsed_url.netloc + parsed_url.path
+    if parsed_url.query:
+        url_components += '?' + parsed_url.query
+    url_hash = hashlib.sha256(url_components.encode()).hexdigest()[:12]
+
+    # Keep original filename or use a default
+    filename = os.path.basename(parsed_url.path)
     if not filename:
         filename = 'download'
-
+
+    # Create hash directory and store file
+    hash_dir = dir_path / url_hash
+    hash_dir.mkdir(exist_ok=True)
+
     output_path = hash_dir / filename
 
     # If file exists in directory and it's not a temp directory, return it

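A hypothetical call to the updated helper (the import path is assumed from the file location, and the URL and target directory are illustrative):

from inferencesh.utils.download import download

path = download(
    "https://example.com/models/weights.bin?rev=3",
    "/tmp/infsh-downloads",
)
# After this change the file lands in a per-URL hash directory, e.g.
#   /tmp/infsh-downloads/<12-hex sha256 of netloc+path+query>/weights.bin
print(path)

Because the hash now covers the query string as well, two URLs that share a filename but differ in their query parameters no longer overwrite each other's cached copy.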