
Commit 9b9710a

Optimization: Improved batch token processing logic in Llava15ChatHandler.
1 parent: a266a0b

3 files changed, +22 -16 lines

llama_cpp/_internals.py

Lines changed: 1 addition & 1 deletion
@@ -662,7 +662,7 @@ def n_tokens(self) -> int:
     def reset(self):
         self.batch.n_tokens = 0
 
-    def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
+    def set_batch(self, batch: Sequence[int], n_past: llama_cpp.llama_pos, logits_all: bool):
         n_tokens = len(batch)
         self.batch.n_tokens = n_tokens
         for i in range(n_tokens):
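
The change above is annotation-only: n_past is now hinted as llama_cpp.llama_pos instead of int. A minimal sketch of why existing call sites keep working unchanged, assuming llama_pos remains the 32-bit ctypes integer alias defined in llama_cpp.py (the names below are illustrative, not part of this commit):

import ctypes
import llama_cpp

# llama_pos is a ctypes integer alias, so the new hint documents intent;
# ctypes converts a plain Python int whenever the value is written into
# the batch's pos array.
n_past = 0                             # plain int, as before
wrapped = llama_cpp.llama_pos(n_past)  # explicit wrapping also works
print(ctypes.sizeof(wrapped))          # 4: a 32-bit signed position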

llama_cpp/llama_chat_format.py

Lines changed: 19 additions & 13 deletions
@@ -2999,22 +2999,19 @@ def __call__(
         llama._ctx.memory_clear(True)
 
         # Process each chunk
-        n_past = llama_cpp.llama_pos(0)
+        n_past = 0
         n_chunks = self._mtmd_cpp.mtmd_input_chunks_size(chunks)
 
         for i in range(n_chunks):
             chunk = self._mtmd_cpp.mtmd_input_chunks_get(chunks, i)
-            if chunk is None:
-                continue
+            if chunk is None: continue
 
             chunk_type = self._mtmd_cpp.mtmd_input_chunk_get_type(chunk)
 
             if chunk_type == self._mtmd_cpp.mtmd_input_chunk_type.MTMD_INPUT_CHUNK_TYPE_TEXT:
                 # Handle text chunk
                 n_tokens_out = ctypes.c_size_t()
-                tokens_ptr = self._mtmd_cpp.mtmd_input_chunk_get_tokens_text(
-                    chunk, ctypes.byref(n_tokens_out)
-                )
+                tokens_ptr = self._mtmd_cpp.mtmd_input_chunk_get_tokens_text(chunk, ctypes.byref(n_tokens_out))
 
                 if tokens_ptr and n_tokens_out.value > 0:
                     # Convert ctypes array to Python list
@@ -3024,23 +3021,25 @@ def __call__(
                         raise ValueError(
                             f"Prompt exceeds n_ctx: {llama.n_tokens + len(tokens)} > {llama.n_ctx()}"
                         )
+                    llama.n_tokens = n_past
                     llama.eval(tokens)
+                    n_past = llama.n_tokens
 
             elif chunk_type in [self._mtmd_cpp.mtmd_input_chunk_type.MTMD_INPUT_CHUNK_TYPE_IMAGE, self._mtmd_cpp.mtmd_input_chunk_type.MTMD_INPUT_CHUNK_TYPE_AUDIO]:
                 # Handle image/audio chunk using helper
                 chunk_n_tokens = self._mtmd_cpp.mtmd_input_chunk_get_n_tokens(chunk)
 
-                if llama.n_tokens + chunk_n_tokens > llama.n_ctx():
+                if n_past + chunk_n_tokens > llama.n_ctx():
                     raise ValueError(
-                        f"Prompt exceeds n_ctx: {llama.n_tokens + chunk_n_tokens} > {llama.n_ctx()}"
+                        f"Prompt exceeds n_ctx: {n_past + chunk_n_tokens} > {llama.n_ctx()}"
                     )
 
                 new_n_past = llama_cpp.llama_pos(0)
                 result = self._mtmd_cpp.mtmd_helper_eval_chunk_single(
                     self.mtmd_ctx,
                     llama._ctx.ctx,
                     chunk,
-                    llama_cpp.llama_pos(llama.n_tokens),
+                    llama_cpp.llama_pos(n_past),
                     llama_cpp.llama_seq_id(0),
                     llama.n_batch,
                     False, # logits_last
@@ -3051,8 +3050,15 @@ def __call__(
                     raise ValueError(f"Failed to evaluate chunk: error code {result}")
 
                 # Update llama's token count
-                llama.n_tokens = new_n_past.value
-
+                n_past = new_n_past.value
+                llama.n_tokens = n_past
+
+        n_past = llama.n_tokens
+        if n_past > 0:
+            llama._ctx.memory_seq_rm(0, n_past - 1, -1)
+            if llama._ctx.memory_seq_pos_min(0) == llama._ctx.memory_seq_pos_max(0):
+                n_past += 1
+        llama.n_tokens = n_past
         # Get prompt tokens to avoid a cache miss
         prompt = llama.input_ids[: llama.n_tokens].tolist()
 
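
Taken together, the three hunks above change how the running position is tracked while multimodal chunks are evaluated: text chunks still go through llama.eval() with llama.n_tokens pinned to the running position, image/audio chunks go through mtmd_helper_eval_chunk_single(), which reports the advanced position via new_n_past, and a final pass adjusts the cache before the prompt tokens are read back. A condensed, self-contained sketch of the position bookkeeping; the Chunk class, N_CTX constant, and process_chunks helper are stand-ins for illustration, not the real mtmd bindings:

from dataclasses import dataclass
from typing import List

N_CTX = 4096  # stand-in for llama.n_ctx()

@dataclass
class Chunk:
    kind: str      # "text", "image", or "audio"
    n_tokens: int

def process_chunks(chunks: List[Chunk]) -> int:
    n_past = 0  # running position, now a plain int (was llama_cpp.llama_pos(0))
    for chunk in chunks:
        if chunk.n_tokens == 0:
            continue  # mirrors the "if chunk is None: continue" guard
        if n_past + chunk.n_tokens > N_CTX:
            raise ValueError(f"Prompt exceeds n_ctx: {n_past + chunk.n_tokens} > {N_CTX}")
        if chunk.kind == "text":
            # real code: llama.n_tokens = n_past; llama.eval(tokens); n_past = llama.n_tokens
            n_past += chunk.n_tokens
        else:
            # real code: mtmd_helper_eval_chunk_single(..., llama_pos(n_past), ...) fills
            # new_n_past, then n_past = new_n_past.value
            n_past += chunk.n_tokens
    return n_past

print(process_chunks([Chunk("text", 12), Chunk("image", 576), Chunk("text", 20)]))  # 608

The new tail of the diff then removes the last cached position (memory_seq_rm(0, n_past - 1, -1)) and, when only a single position would remain (the pos_min == pos_max check), bumps n_past back up, keeping llama.n_tokens consistent with what the KV cache actually holds; the existing "avoid a cache miss" comment explains the intent.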

@@ -3786,9 +3792,9 @@ def __call__(self, **kwargs):
         messages = kwargs.get('messages', [])
         try:
             image_count = len(self.get_image_urls(messages))
-            print(f"GLM4VChatHandler - Processing {image_count} images", file=sys.stderr)
+            print(f"GLM4VChatHandler - Cleared state, processing {image_count} images", file=sys.stderr)
         except Exception:
-            print(f"GLM4VChatHandler - State reset", file=sys.stderr)
+            print(f"GLM4VChatHandler - Cleared state", file=sys.stderr)
 
         # Use parent implementation
         return super().__call__(**kwargs)

llama_cpp/llama_cpp.py

Lines changed: 2 additions & 2 deletions
@@ -550,7 +550,7 @@ class llama_batch(ctypes.Structure):
     The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
 
     Attributes:
-        n_tokens (int): number of tokens
+        n_tokens (ctypes.c_int32): number of tokens
         token (ctypes.Array[llama_token]): the token ids of the input (used when embd is NULL)
         embd (ctypes.Array[ctypes.ctypes.c_float]): token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
         pos (ctypes.Array[ctypes.Array[llama_pos]]): the positions of the respective token in the sequence
@@ -559,7 +559,7 @@ class llama_batch(ctypes.Structure):
     """
 
     if TYPE_CHECKING:
-        n_tokens: int
+        n_tokens: ctypes.c_int32
         token: CtypesArray[llama_token]
         embd: CtypesArray[ctypes.c_float]
         pos: CtypesArray[CtypesArray[llama_pos]]
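
The docstring and the TYPE_CHECKING hint now describe n_tokens as ctypes.c_int32, matching the field's declaration in llama_batch._fields_. ctypes still surfaces such a field as a plain Python int on attribute access, which is why assignments like self.batch.n_tokens = 0 elsewhere in the bindings need no change. A small stand-in structure (not the real llama_batch) illustrating that behaviour:

import ctypes

class FakeBatch(ctypes.Structure):
    # stand-in for llama_batch's n_tokens field
    _fields_ = [("n_tokens", ctypes.c_int32)]

batch = FakeBatch()
batch.n_tokens = 7           # plain int assignment, converted by ctypes
print(type(batch.n_tokens))  # <class 'int'>: the field reads back as a Python int
print(batch.n_tokens + 1)    # 8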
