@@ -207,7 +207,6 @@ def __init__(
         n_ctx: int = 512,
         n_parts: int = -1,
         n_gpu_layers: int = 0,
-        tensor_split: list[float] = None,
         seed: int = 1337,
         f16_kv: bool = True,
         logits_all: bool = False,
@@ -221,6 +220,7 @@ def __init__(
         lora_base: Optional[str] = None,
         lora_path: Optional[str] = None,
         low_vram: bool = False,
+        tensor_split: Optional[List[float]] = None,
         verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
@@ -241,6 +241,7 @@ def __init__(
             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
             lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
             lora_path: Path to a LoRA file to apply to the model.
+            tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
             verbose: Print verbose output to stderr.

         Raises:
@@ -249,20 +250,13 @@ def __init__(
         Returns:
             A Llama instance.
         """
-        if tensor_split is None:
-            tensor_split = [0.0] * llama_cpp.LLAMA_MAX_DEVICES.value
-
-        # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
-        FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
-        c_tensor_split = FloatArray(*tensor_split)

         self.verbose = verbose
         self.model_path = model_path

         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
         self.params.n_gpu_layers = n_gpu_layers
-        self.params.tensor_split = c_tensor_split
         self.params.seed = seed
         self.params.f16_kv = f16_kv
         self.params.logits_all = logits_all
@@ -272,6 +266,15 @@ def __init__(
         self.params.embedding = embedding
         self.params.low_vram = low_vram

+        self.tensor_split = tensor_split
+        self._c_tensor_split = None
+
+        if self.tensor_split is not None:
+            # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
+            FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
+            self._c_tensor_split = FloatArray(*tensor_split)  # keep a reference to the array so it is not gc'd
+            self.params.tensor_split = self._c_tensor_split
+
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)

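Taken together, the __init__ hunks make tensor_split opt-in: when it is left as None, params.tensor_split is never written, and when a list is given it is converted to a ctypes float array that is kept alive on the instance as _c_tensor_split so the C side does not end up pointing at a garbage-collected buffer. A minimal usage sketch of the new parameter, assuming a CUDA build whose LLAMA_MAX_DEVICES covers two GPUs; the model path, layer count, and split ratios below are placeholders:

from llama_cpp import Llama

llm = Llama(
    model_path="./models/7B/ggml-model.bin",  # placeholder path
    n_gpu_layers=40,                          # offload layers to the GPUs
    tensor_split=[0.6, 0.4],                  # roughly 60% of tensors on device 0, 40% on device 1
)

# Omitting tensor_split (or passing None) keeps the previous behaviour:
# the model is not split and params.tensor_split is left untouched.
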
@@ -1499,7 +1502,6 @@ def __getstate__(self):
            model_path=self.model_path,
            n_ctx=self.params.n_ctx,
            n_gpu_layers=self.params.n_gpu_layers,
-           tensor_split=self.params.tensor_split,
            seed=self.params.seed,
            f16_kv=self.params.f16_kv,
            logits_all=self.params.logits_all,
@@ -1513,6 +1515,7 @@ def __getstate__(self):
            n_threads=self.n_threads,
            lora_base=self.lora_base,
            lora_path=self.lora_path,
+           tensor_split=self.tensor_split,
            ### DEPRECATED ###
            n_parts=self.n_parts,
            ### DEPRECATED ###
@@ -1524,7 +1527,6 @@ def __setstate__(self, state):
            n_ctx=state["n_ctx"],
            n_parts=state["n_parts"],
            n_gpu_layers=state["n_gpu_layers"],
-           tensor_split=state["tensor_split"],
            seed=state["seed"],
            f16_kv=state["f16_kv"],
            logits_all=state["logits_all"],
@@ -1538,6 +1540,7 @@ def __setstate__(self, state):
            last_n_tokens_size=state["last_n_tokens_size"],
            lora_base=state["lora_base"],
            lora_path=state["lora_path"],
+           tensor_split=state["tensor_split"],
            verbose=state["verbose"],
        )

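The pickling hunks follow the same idea: __getstate__ now records the plain Python list that was passed in rather than the ctypes value stored on params, so the state dict stays plain Python data and __setstate__ can hand it straight back to __init__, which rebuilds the C array. A small sketch of that round trip, again with a placeholder model path and split ratios:

import pickle

from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin", tensor_split=[0.6, 0.4])

state = llm.__getstate__()                  # tensor_split is the plain list, not a ctypes array
assert state["tensor_split"] == [0.6, 0.4]

blob = pickle.dumps(state)                  # only plain Python values, so this serializes cleanly
assert pickle.loads(blob)["tensor_split"] == [0.6, 0.4]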