Commit 453e517

Add separate lora_base path for applying LoRA to quantized models using the original unquantized model weights.

1 parent: 32ca803

llama_cpp/llama.py

Lines changed: 12 additions & 3 deletions
@@ -39,6 +39,7 @@ def __init__(
         n_threads: Optional[int] = None,
         n_batch: int = 8,
         last_n_tokens_size: int = 64,
+        lora_base: Optional[str] = None,
         lora_path: Optional[str] = None,
         verbose: bool = True,
     ):
@@ -58,6 +59,7 @@ def __init__(
             n_threads: Number of threads to use. If None, the number of threads is automatically determined.
             n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
+            lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
             lora_path: Path to a LoRA file to apply to the model.
             verbose: Print verbose output to stderr.

@@ -110,16 +112,21 @@ def __init__(
             self.model_path.encode("utf-8"), self.params
         )

+        self.lora_base = None
         self.lora_path = None
         if lora_path:
+            self.lora_base = lora_base
+            # Use lora_base if set otherwise revert to using model_path.
+            lora_base = lora_base if lora_base is not None else model_path
+
             self.lora_path = lora_path
             if llama_cpp.llama_apply_lora_from_file(
                 self.ctx,
-                self.lora_path.encode("utf-8"),
-                self.model_path.encode("utf-8"),
+                lora_path.encode("utf-8"),
+                lora_base.encode("utf-8"),
                 llama_cpp.c_int(self.n_threads),
             ):
-                raise RuntimeError(f"Failed to apply LoRA from path: {self.lora_path}")
+                raise RuntimeError(f"Failed to apply LoRA from lora path: {lora_path} to base path: {lora_base}")

         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
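A minimal usage sketch of the change above, assuming hypothetical file paths: the quantized weights load as model_path, lora_base points at the original unquantized (f16) weights, and the adapter is applied on top. If lora_base is omitted, the new code falls back to using model_path as the base.

    from llama_cpp import Llama

    # Hypothetical paths: quantized weights for inference, plus the
    # original f16 weights the LoRA adapter was trained against.
    llm = Llama(
        model_path="./models/7B/ggml-model-q4_0.bin",
        lora_base="./models/7B/ggml-model-f16.bin",
        lora_path="./loras/ggml-adapter-model.bin",
    )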
@@ -815,6 +822,7 @@ def __getstate__(self):
             last_n_tokens_size=self.last_n_tokens_size,
             n_batch=self.n_batch,
             n_threads=self.n_threads,
+            lora_base=self.lora_base,
             lora_path=self.lora_path,
         )

@@ -833,6 +841,7 @@ def __setstate__(self, state):
             n_threads=state["n_threads"],
             n_batch=state["n_batch"],
             last_n_tokens_size=state["last_n_tokens_size"],
+            lora_base=state["lora_base"],
             lora_path=state["lora_path"],
             verbose=state["verbose"],
         )
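The __getstate__ and __setstate__ hunks thread lora_base through serialization. A quick sketch of the round trip this enables, with the same hypothetical paths (restoring re-runs __init__, so the adapter is re-applied against the same f16 base):

    import pickle
    from llama_cpp import Llama

    llm = Llama(
        model_path="./models/7B/ggml-model-q4_0.bin",
        lora_base="./models/7B/ggml-model-f16.bin",
        lora_path="./loras/ggml-adapter-model.bin",
    )

    # __getstate__ now records lora_base, so the restored copy applies
    # the adapter against the same unquantized base weights.
    restored = pickle.loads(pickle.dumps(llm))
    assert restored.lora_base == llm.lora_base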
