@@ -39,6 +39,7 @@ def __init__(
3939 n_threads : Optional [int ] = None ,
4040 n_batch : int = 8 ,
4141 last_n_tokens_size : int = 64 ,
42+ lora_base : Optional [str ] = None ,
4243 lora_path : Optional [str ] = None ,
4344 verbose : bool = True ,
4445 ):
@@ -58,6 +59,7 @@ def __init__(
5859 n_threads: Number of threads to use. If None, the number of threads is automatically determined.
5960 n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
6061 last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
62+ lora_base: Optional path to the base model to apply the LoRA to; useful when the loaded model is quantized and you want to apply the LoRA to an f16 model.
6163 lora_path: Path to a LoRA file to apply to the model.
6264 verbose: Print verbose output to stderr.
6365
@@ -110,16 +112,21 @@ def __init__(
110112 self .model_path .encode ("utf-8" ), self .params
111113 )
112114
115+ self .lora_base = None
113116 self .lora_path = None
114117 if lora_path :
118+ self .lora_base = lora_base
119+ # Use lora_base if it is set; otherwise fall back to model_path.
120+ lora_base = lora_base if lora_base is not None else model_path
121+
115122 self .lora_path = lora_path
116123 if llama_cpp .llama_apply_lora_from_file (
117124 self .ctx ,
118- self . lora_path .encode ("utf-8" ),
119- self . model_path .encode ("utf-8" ),
125+ lora_path .encode ("utf-8" ),
126+ lora_base .encode ("utf-8" ),
120127 llama_cpp .c_int (self .n_threads ),
121128 ):
122- raise RuntimeError (f"Failed to apply LoRA from path: { self . lora_path } " )
129+ raise RuntimeError (f"Failed to apply LoRA from lora path: { lora_path } to base path: { lora_base } " )
123130
124131 if self .verbose :
125132 print (llama_cpp .llama_print_system_info ().decode ("utf-8" ), file = sys .stderr )
@@ -815,6 +822,7 @@ def __getstate__(self):
815822 last_n_tokens_size = self .last_n_tokens_size ,
816823 n_batch = self .n_batch ,
817824 n_threads = self .n_threads ,
825+ lora_base = self .lora_base ,
818826 lora_path = self .lora_path ,
819827 )
820828
@@ -833,6 +841,7 @@ def __setstate__(self, state):
833841 n_threads = state ["n_threads" ],
834842 n_batch = state ["n_batch" ],
835843 last_n_tokens_size = state ["last_n_tokens_size" ],
844+ lora_base = state ["lora_base" ],
836845 lora_path = state ["lora_path" ],
837846 verbose = state ["verbose" ],
838847 )
0 commit comments