@@ -2227,6 +2227,7 @@ def from_pretrained(
22272227 cls ,
22282228 repo_id : str ,
22292229 filename : Optional [str ],
2230+ additional_files : Optional [List ] = None ,
22302231 local_dir : Optional [Union [str , os .PathLike [str ]]] = None ,
22312232 local_dir_use_symlinks : Union [bool , Literal ["auto" ]] = "auto" ,
22322233 cache_dir : Optional [Union [str , os .PathLike [str ]]] = None ,
@@ -2239,6 +2240,7 @@ def from_pretrained(
22392240 Args:
22402241 repo_id: The model repo id.
22412242 filename: A filename or glob pattern to match the model file in the repo.
2243+ additional_files: A list of filenames or glob patterns to match additional model files in the repo.
22422244 local_dir: The local directory to save the model to.
22432245 local_dir_use_symlinks: Whether to use symlinks when downloading the model.
22442246 **kwargs: Additional keyword arguments to pass to the Llama constructor.
@@ -2269,6 +2271,7 @@ def from_pretrained(
22692271 rel_path = Path (file ).relative_to (repo_id )
22702272 file_list .append (str (rel_path ))
22712273
2274+ # find the only/first shard file:
22722275 matching_files = [file for file in file_list if fnmatch .fnmatch (file , filename )] # type: ignore
22732276
22742277 if len (matching_files ) == 0 :
@@ -2298,6 +2301,35 @@ def from_pretrained(
22982301 cache_dir = cache_dir ,
22992302 )
23002303
2304+ if additional_files :
2305+ for additonal_file_name in additional_files :
2306+ # find the additional shard file:
2307+ matching_additional_files = [file for file in file_list if fnmatch .fnmatch (file , additonal_file_name )]
2308+
2309+ if len (matching_additional_files ) == 0 :
2310+ raise ValueError (
2311+ f"No file found in { repo_id } that match { additonal_file_name } \n \n "
2312+ f"Available Files:\n { json .dumps (file_list )} "
2313+ )
2314+
2315+ if len (matching_additional_files ) > 1 :
2316+ raise ValueError (
2317+ f"Multiple files found in { repo_id } matching { additonal_file_name } \n \n "
2318+ f"Available Files:\n { json .dumps (files )} "
2319+ )
2320+
2321+ (matching_additional_file ,) = matching_additional_files
2322+
2323+ # download the additional file
2324+ hf_hub_download (
2325+ repo_id = repo_id ,
2326+ filename = matching_additional_file ,
2327+ subfolder = subfolder ,
2328+ local_dir = local_dir ,
2329+ local_dir_use_symlinks = local_dir_use_symlinks ,
2330+ cache_dir = cache_dir ,
2331+ )
2332+
23012333 if local_dir is None :
23022334 model_path = hf_hub_download (
23032335 repo_id = repo_id ,
@@ -2311,6 +2343,7 @@ def from_pretrained(
23112343 else :
23122344 model_path = os .path .join (local_dir , filename )
23132345
2346+ # loading the first file of a sharded GGUF loads all remaining shard files in the subfolder
23142347 return cls (
23152348 model_path = model_path ,
23162349 ** kwargs ,
0 commit comments