From eff2b884b1625bd193fafc4104d8268a4dcc72a1 Mon Sep 17 00:00:00 2001
From: Alde Rojas
Date: Fri, 5 Dec 2025 01:49:33 -0600
Subject: [PATCH 1/6] llama : add token support to llama-grammar

---
 src/llama-grammar.cpp              | 226 ++++++++++++++++++++++++++---
 src/llama-grammar.h                |  17 ++-
 tests/test-grammar-integration.cpp |  96 +++++++++++-
 tests/test-grammar-parser.cpp      |  14 ++
 4 files changed, 327 insertions(+), 26 deletions(-)

diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
index b3c5eb57174..381ef9e08c8 100644
--- a/src/llama-grammar.cpp
+++ b/src/llama-grammar.cpp
@@ -181,6 +181,52 @@ static std::pair<uint32_t, const char *> parse_char(const char * src) {
     throw std::runtime_error("unexpected end of input");
 }

+static std::pair<uint32_t, const char *> parse_token(const llama_vocab * vocab, const char * src) {
+    const char * pos = src;
+    if (*pos != '<') {
+        throw std::runtime_error(std::string("expecting '<' at ") + pos);
+    }
+    pos++;
+
+    // Parse <[id]>
+    if (*pos == '[') {
+        pos++;
+        const char * int_end = parse_int(pos);
+        uint32_t token_id = std::stoul(std::string(pos, int_end - pos));
+        pos = int_end;
+        if (*pos != ']') {
+            throw std::runtime_error(std::string("expecting ']' at ") + pos);
+        }
+        pos++;
+        if (*pos != '>') {
+            throw std::runtime_error(std::string("expecting '>' at ") + pos);
+        }
+        pos++;
+        return std::make_pair(token_id, pos);
+    }
+
+    if (vocab == nullptr) {
+        throw std::runtime_error(std::string("no vocab to parse token at ") + src);
+    }
+
+    // Parse <token-text> and tokenize it to obtain the token id
+    while (*pos != 0 && *pos != '>') {
+        pos++;
+    }
+    if (*pos != '>') {
+        throw std::runtime_error(std::string("expecting '>' at ") + pos);
+    }
+    pos++;
+
+    llama_token tokens[2];
+    int32_t n_tokens = vocab->tokenize(src, static_cast<int32_t>(pos - src), tokens, 2, false, true);
+    if (n_tokens != 1) {
+        // must tokenize to exactly 1 token
+        throw std::runtime_error("invalid token '" + std::string(src, pos - src) + "'");
+    }
+    return std::make_pair(tokens[0], pos);
+}
+
 static void print_grammar_char(FILE * file, uint32_t c) {
     if (0x20 <= c && c <= 0x7f) {
         fprintf(file, "%c", static_cast<char>(c));
@@ -212,6 +258,8 @@ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
             case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
             case LLAMA_GRETYPE_CHAR_ALT:       fprintf(file, "CHAR_ALT"); break;
             case LLAMA_GRETYPE_CHAR_ANY:       fprintf(file, "CHAR_ANY"); break;
+            case LLAMA_GRETYPE_TOKEN:          fprintf(file, "TOKEN"); break;
+            case LLAMA_GRETYPE_TOKEN_NOT:      fprintf(file, "TOKEN_NOT"); break;
         }
         switch (elem.type) {
             case LLAMA_GRETYPE_END:
@@ -228,6 +276,17 @@ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
                 print_grammar_char(file, elem.value);
                 fprintf(file, "\") ");
                 break;
+            case LLAMA_GRETYPE_TOKEN:
+                fprintf(file, "<[");
+                fprintf(file, "%u", elem.value);
+                fprintf(file, "]> ");
+                break;
+            case LLAMA_GRETYPE_TOKEN_NOT:
+                fprintf(file, "!");
+                fprintf(file, "<[");
+                fprintf(file, "%u", elem.value);
+                fprintf(file, "]> ");
+                break;
         }
     }
     fprintf(file, "\n");
@@ -284,6 +343,17 @@ static void print_rule(
         case LLAMA_GRETYPE_CHAR_ANY:
             fprintf(file, ".");
             break;
+        case LLAMA_GRETYPE_TOKEN:
+            fprintf(file, "<[");
+            fprintf(file, "%u", elem.value);
+            fprintf(file, "]> ");
+            break;
+        case LLAMA_GRETYPE_TOKEN_NOT:
+            fprintf(file, "!");
+            fprintf(file, "<[");
+            fprintf(file, "%u", elem.value);
+            fprintf(file, "]> ");
+            break;
     }
     if (is_char_element(elem)) {
         switch (rule[i + 1].type) {
@@ -444,6 +514,17 @@ const char * llama_grammar_parser::parse_sequence(
                 }
             }
             pos = parse_space(pos + 1, is_nested);
+        } else if (*pos == '<' || *pos == '!') { // token
+            auto type = LLAMA_GRETYPE_TOKEN;
+            if (*pos == '!') { // token inverse
+                type = LLAMA_GRETYPE_TOKEN_NOT;
+                pos++;
+            }
+            auto token_pair = parse_token(vocab, pos);
+            const char * token_end = token_pair.second;
+            last_sym_start = rule.size();
+            rule.push_back({type, token_pair.first});
+            pos = parse_space(token_end, is_nested);
         } else if (is_word_char(*pos)) { // rule reference
             const char * name_end = parse_name(pos);
             uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
@@ -691,6 +772,21 @@ static bool llama_grammar_match_partial_char(
     return !is_positive_char;
 }

+// returns true iff token matches the rule at pos (regular or inverse)
+// asserts that pos is pointing to a token element
+static bool llama_grammar_match_token(
+        const llama_grammar_element * pos,
+        const llama_token             token) {
+    GGML_ASSERT(pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT);
+    if (pos->type == LLAMA_GRETYPE_TOKEN) {
+        return pos->value == static_cast<uint32_t>(token);
+    }
+    if (pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
+        return pos->value != static_cast<uint32_t>(token);
+    }
+    return false;
+}
+
 // transforms a grammar pushdown stack into N possible stacks, all ending
 // at a character range (terminal element)
 static void llama_grammar_advance_stack(
@@ -738,6 +834,8 @@ static void llama_grammar_advance_stack(
         case LLAMA_GRETYPE_CHAR:
         case LLAMA_GRETYPE_CHAR_NOT:
         case LLAMA_GRETYPE_CHAR_ANY:
+        case LLAMA_GRETYPE_TOKEN:
+        case LLAMA_GRETYPE_TOKEN_NOT:
             if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
                 // only add the stack if it's not a duplicate of one we already have
                 new_stacks.emplace_back(stack);
@@ -831,26 +929,38 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
     return grammar->stacks;
 }

+static void llama_grammar_accept_chr(
+        struct llama_grammar      & grammar,
+        const llama_grammar_stack & stack,
+        uint32_t                    chr,
+        llama_grammar_stacks      & new_stacks) {
+    if (stack.empty()) {
+        return;
+    }
+
+    const llama_grammar_element * pos = stack.back();
+
+    // a character can never match a token element - drop such stacks here
+    if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
+        return;
+    }
+
+    auto match = llama_grammar_match_char(pos, chr);
+    if (match.first) {
+        llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
+        if (!llama_grammar_is_end_of_sequence(match.second)) {
+            new_stack.push_back(match.second);
+        }
+        llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks);
+    }
+}
+
 void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
     llama_grammar_stacks stacks_new;
     stacks_new.reserve(grammar->stacks.size());

     for (const auto & stack : grammar->stacks) {
-        if (stack.empty()) {
-            continue;
-        }
-
-        auto match = llama_grammar_match_char(stack.back(), chr);
-        if (match.first) {
-            const llama_grammar_element * pos = match.second;
-
-            // update top of stack to next element, if any
-            llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
-            if (!llama_grammar_is_end_of_sequence(pos)) {
-                new_stack.push_back(pos);
-            }
-            llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new);
-        }
+        llama_grammar_accept_chr(*grammar, stack, chr, stacks_new);
     }

     grammar->stacks = std::move(stacks_new);
@@ -875,6 +985,22 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(

     const llama_grammar_element * stack_pos = stack.back();

+    // if the top of the stack is a token rule, then we only need to check the token id
+    if (stack_pos->type == LLAMA_GRETYPE_TOKEN || stack_pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
+        for (const auto & tok : candidates) {
+            if (*tok.code_points == 0) {
+                // reached the end of a token consumed by char rules; reject iff it ended in a partial UTF-8 sequence
+                if (tok.partial_utf8.n_remain != 0) {
+                    rejects.push_back(tok);
+                }
+            } else if (!llama_grammar_match_token(stack_pos, tok.id)) {
+                rejects.push_back(tok);
+            }
+        }
+        return rejects;
+    }
+
     llama_grammar_candidates next_candidates;
     next_candidates.reserve(candidates.size());

@@ -887,7 +1013,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
                 rejects.push_back(tok);
             }
         } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
-            next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
+            next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8, tok.id });
         } else {
             rejects.push_back(tok);
         }
@@ -905,7 +1031,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
     auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);

     for (const auto & tok : next_rejects) {
-        rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
+        rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8, tok.id });
     }

     return rejects;
@@ -990,7 +1116,7 @@ struct llama_grammar * llama_grammar_init_impl(
         size_t num_trigger_patterns,
         const llama_token * trigger_tokens,
         size_t num_trigger_tokens) {
-    llama_grammar_parser parser;
+    llama_grammar_parser parser(vocab);

     // if there is a grammar, parse it
     // rules will be empty (default) if there are parse errors
@@ -1156,7 +1282,7 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
             cur_p->data[i].logit = -INFINITY;
         } else {
             candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
-            candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
+            candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second, id });
         }
     }

@@ -1175,7 +1301,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
         if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
             grammar.awaiting_trigger = false;
             grammar.trigger_buffer.clear();
-            llama_grammar_accept_str(grammar, piece);
+            llama_grammar_accept_token(grammar, token, piece);
             LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
             return;
         } else {
@@ -1199,7 +1325,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
                 auto constrained_str = grammar.trigger_buffer.substr(start);
                 // std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
                 grammar.trigger_buffer.clear();
-                llama_grammar_accept_str(grammar, constrained_str);
+                llama_grammar_accept_token(grammar, -1, constrained_str);
                 LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
                 return;
             }
@@ -1218,7 +1344,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
         GGML_ABORT("fatal error");
     }

-    llama_grammar_accept_str(grammar, piece);
+    llama_grammar_accept_token(grammar, token, piece);
 }

 void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
@@ -1235,3 +1361,59 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
         throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
     }
 }
+
+void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) {
+    // Note terminating 0 in decoded string
+    const auto   decoded     = decode_utf8(piece, grammar.partial_utf8);
+    const auto & code_points = decoded.first;
+
+    llama_grammar_stacks stacks_new;
+    stacks_new.reserve(grammar.stacks.size());
+
+    for (const auto & stack : grammar.stacks) {
+        if (stack.empty()) {
+            continue;
+        }
+
+        const llama_grammar_element * pos = stack.back();
+
+        if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
+            if (llama_grammar_match_token(pos, token)) {
+                llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
+                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
+                    new_stack.push_back(pos + 1);
+                }
+                llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new);
+            }
+        } else {
+            llama_grammar_stacks current_stacks = {stack};
+
+            for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+                llama_grammar_stacks next_stacks;
+
+                for (const auto & cur_stack : current_stacks) {
+                    llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks);
+                }
+
+                current_stacks = std::move(next_stacks);
+                if (current_stacks.empty()) {
+                    break;
+                }
+            }
+
+            for (auto & surviving_stack : current_stacks) {
+                if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) {
+                    stacks_new.emplace_back(surviving_stack);
+                }
+            }
+        }
+    }
+
+    grammar.stacks = std::move(stacks_new);
+    grammar.partial_utf8 = decoded.second;
+
+    if (grammar.stacks.empty()) {
+        throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")");
+    }
+}

diff --git a/src/llama-grammar.h b/src/llama-grammar.h
index f8c291de999..09ca934a0d2 100644
--- a/src/llama-grammar.h
+++ b/src/llama-grammar.h
@@ -36,11 +36,17 @@ enum llama_gretype {

     // any character (.)
     LLAMA_GRETYPE_CHAR_ANY = 7,
+
+    // terminal element: token (<[token-id]>)
+    LLAMA_GRETYPE_TOKEN = 8,
+
+    // inverse token (<[^token-id]>)
+    LLAMA_GRETYPE_TOKEN_NOT = 9,
 };

 typedef struct llama_grammar_element {
     enum llama_gretype type;
-    uint32_t value; // Unicode code point or rule ID
+    uint32_t value; // Unicode code point, rule ID, or token ID
 } llama_grammar_element;

 struct llama_partial_utf8 {
@@ -52,6 +58,7 @@ struct llama_grammar_candidate {
     size_t index;
     const uint32_t * code_points;
     llama_partial_utf8 partial_utf8;
+    llama_token id;
 };

 using llama_grammar_rule = std::vector< llama_grammar_element>;
@@ -77,10 +84,13 @@ std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
         const llama_grammar_candidates & candidates);

 struct llama_grammar_parser {
+    const llama_vocab * vocab;
     std::map<std::string, uint32_t> symbol_ids;

     llama_grammar_rules rules;

+    llama_grammar_parser(const struct llama_vocab * vocab = nullptr) : vocab(vocab) {}
+
     llama_grammar_stack c_rules() const;

     uint32_t get_symbol_id(const char * src, size_t len);
@@ -171,3 +181,8 @@ void llama_grammar_accept_impl(
 void llama_grammar_accept_str(
         struct llama_grammar & grammar,
         const std::string & piece);
+
+void llama_grammar_accept_token(
+        struct llama_grammar & grammar,
+        llama_token token,
+        const std::string & piece);
diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp
index 82fae671ed0..133cc07f54b 100644
--- a/tests/test-grammar-integration.cpp
+++ b/tests/test-grammar-integration.cpp
@@ -32,13 +32,66 @@ static bool test_build_grammar_fails(const std::string & grammar_str) {
     return grammar_fails;
 }

+struct token_and_piece {
+    llama_token token;
+    std::string piece;
+};
+
+// token() encodes a 32-bit ID as 5 bytes: a 0xff marker followed by the ID in big-endian order.
+static std::string token(llama_token id) {
+    return std::string{
+        static_cast<char>(0xff),
+        static_cast<char>((id >> 24) & 0xff),
+        static_cast<char>((id >> 16) & 0xff),
+        static_cast<char>((id >> 8) & 0xff),
+        static_cast<char>(id & 0xff)
+    };
+}
+
+// parse_tokens() parses the token encodings above along with plain UTF-8 text.
+static std::vector<token_and_piece> parse_tokens(const std::string & input) {
+    std::vector<token_and_piece> result;
+    result.reserve(input.size());
+    size_t offset = 0;
+    while (offset < input.size()) {
+        try {
+            if (static_cast<uint8_t>(input[offset]) == 0xff) {
+                if (offset + 5 > input.size()) {
+                    throw std::runtime_error("not enough bytes for token id");
+                }
+                uint32_t val =
+                    (static_cast<uint8_t>(input[offset + 1]) << 24) |
+                    (static_cast<uint8_t>(input[offset + 2]) << 16) |
+                    (static_cast<uint8_t>(input[offset + 3]) << 8) |
+                    (static_cast<uint8_t>(input[offset + 4]));
+                auto piece = "<[" + std::to_string(val) + "]>";
+                result.push_back({static_cast<llama_token>(val), piece});
+                offset += 5;
+            } else {
+                uint32_t cpt = unicode_cpt_from_utf8(input, offset);
+                result.push_back({0, unicode_cpt_to_utf8(cpt)});
+            }
+        } catch (const std::invalid_argument & /*ex*/) {
+            // Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize
+            ++offset;
+            result.push_back({0, unicode_cpt_to_utf8(0xFFFD)}); // replacement character
+        }
+    }
+    return result;
+}
+
 static bool match_string(const std::string & input, llama_grammar * grammar) {
-    const auto cpts = unicode_cpts_from_utf8(input);
+    const auto parsed = parse_tokens(input);

     auto & stacks_cur = llama_grammar_get_stacks(grammar);

-    for (const auto & cpt : cpts) {
-        llama_grammar_accept(grammar, cpt);
+    for (const auto & in : parsed) {
+        try {
+            llama_grammar_accept_token(*grammar, in.token, in.piece);
+        } catch (const std::runtime_error & /*e*/) {
+            // normally this shouldn't get hit because of llama_grammar_apply
+            return false;
+        }

         if (stacks_cur.empty()) {
             // no stacks means that the grammar failed to match at this point
@@ -426,6 +479,25 @@ static void test_simple_grammar() {
             "12a45",
         }
     );
+
+    // Test case for a simple grammar with tokens
+    test_grammar(
+        "simple grammar with tokens",
+        R"""(
+            root ::= <[10]> content <[11]>
+            content ::= !<[10]>*)""",
+        // Passing strings
+        {
+            token(10) + "content goes here" + token(11),
+            token(10) + "content goes here" + token(12) + ", optionally other tokens too" + token(11),
+            token(10) + token(11),
+        },
+        // Failing strings
+        {
+            token(10) + "content goes here",
+            token(10),
+        }
+    );
 }

 static void test_complex_grammar() {
@@ -487,6 +559,24 @@ static void test_complex_grammar() {
             "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
         }
     );
+
+    // Test case for a more complex grammar with tokens
+    test_grammar(
+        "complex grammar with tokens",
+        R"""(
+            root ::= reasoning content tool-call
+            reasoning ::= <[10]> !<[10]>* <[11]>
+            content ::= !<[12]>*
+            tool-call ::= <[12]> .*)""",
+        // Passing strings
+        {
+            token(10) + "I am thinking" + token(11) + "hello!" + token(12) + "... tool call ...",
+        },
+        // Failing strings
+        {
+            "hello",
+        }
+    );
 }

 static void test_special_chars() {
diff --git a/tests/test-grammar-parser.cpp b/tests/test-grammar-parser.cpp
index 67821a2d5c6..03ae78ff739 100644
--- a/tests/test-grammar-parser.cpp
+++ b/tests/test-grammar-parser.cpp
@@ -515,5 +515,19 @@ int main()
         {LLAMA_GRETYPE_END, 0},
     });

+    // <[1000]> = "<think>"
+    // <[1001]> = "</think>"
+    verify_parsing(R"""(
+        root ::= <[1000]> !<[1001]> <[1001]>
+        )""", {
+        {"root", 0}
+    }, {
+        // root (index 0)
+        {LLAMA_GRETYPE_TOKEN, 1000},
+        {LLAMA_GRETYPE_TOKEN_NOT, 1001},
+        {LLAMA_GRETYPE_TOKEN, 1001},
+        {LLAMA_GRETYPE_END, 0},
+    });
+
     return 0;
 }

From ab8d1376f434bb13b6fe0db6ac8142c06995fdf4 Mon Sep 17 00:00:00 2001
From: Alde Rojas
Date: Sat, 6 Dec 2025 00:51:26 -0600
Subject: [PATCH 2/6] fix inverse token comment

---
 src/llama-grammar.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llama-grammar.h b/src/llama-grammar.h
index 09ca934a0d2..82c273b9c45 100644
--- a/src/llama-grammar.h
+++ b/src/llama-grammar.h
@@ -40,7 +40,7 @@ enum llama_gretype {
     // terminal element: token (<[token-id]>)
     LLAMA_GRETYPE_TOKEN = 8,

-    // inverse token (<[^token-id]>)
+    // inverse token (!<[token-id]>)
     LLAMA_GRETYPE_TOKEN_NOT = 9,
 };

From 27dbaef31809cfa062cf6bb52de37cc6943a508e Mon Sep 17 00:00:00 2001
From: Alde Rojas
Date: Sat, 6 Dec 2025 14:03:00 -0600
Subject: [PATCH 3/6] refactor trigger_patterns to replay tokens instead of the
 entire string

---
 src/llama-grammar.cpp | 42 ++++++++++++++++++++++++++++++------------
 src/llama-grammar.h   |  4 ++++
 2 files changed, 34 insertions(+), 12 deletions(-)

diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
index 381ef9e08c8..75d5d750c39 100644
--- a/src/llama-grammar.cpp
+++ b/src/llama-grammar.cpp
@@ -1098,12 +1098,13 @@ struct llama_grammar * llama_grammar_init_impl(
         vocab,
         std::move(vec_rules),
         std::move(stacks),
-        /* .partial_utf8 = */ {},
-        /* .lazy =*/ false,
-        /* .awaiting_trigger = */ false,
-        /* .trigger_buffer = */ "",
-        /* .trigger_tokens = */ {},
-        /* .trigger_patterns = */ {},
+        /* .partial_utf8 = */             {},
+        /* .lazy = */                     false,
+        /* .awaiting_trigger = */         false,
+        /* .trigger_buffer = */           "",
+        /* .trigger_buffer_positions = */ {},
+        /* .trigger_tokens = */           {},
+        /* .trigger_patterns = */         {},
     };
 }

@@ -1203,10 +1204,11 @@ struct llama_grammar * llama_grammar_init_impl(
         vocab,
         std::move(vec_rules),
         std::move(stacks),
-        /* .partial_utf8 = */ {},
-        /* .lazy = */ lazy,
-        /* .awaiting_trigger = */ lazy,
-        /* .trigger_buffer = */ "",
+        /* .partial_utf8 = */             {},
+        /* .lazy = */                     lazy,
+        /* .awaiting_trigger = */         lazy,
+        /* .trigger_buffer = */           "",
+        /* .trigger_buffer_positions = */ {},
         std::move(vec_trigger_tokens),
         std::move(vec_trigger_patterns),
     };
@@ -1229,6 +1231,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
         grammar.lazy,
         grammar.awaiting_trigger,
         grammar.trigger_buffer,
+        grammar.trigger_buffer_positions,
         grammar.trigger_tokens,
         grammar.trigger_patterns,
     };
@@ -1305,6 +1308,8 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
             LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
             return;
         } else {
+            auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size());
+            grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
             grammar.trigger_buffer += piece;

             std::smatch match;
@@ -1322,10 +1327,23 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
                 if (start == std::string::npos) {
                     start = match.position(0);
                 }
+
+                // replay tokens that overlap with [start, end)
+                for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
+                    auto [tok_start, tok_end] = tok_pos;
+                    if (tok_end <= start) {
+                        continue;
+                    }
+
+                    size_t piece_start = (tok_start < start) ? start : tok_start; // allow for partial token pieces
+                    size_t piece_len = tok_end - piece_start;
+                    auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len);
+                    llama_grammar_accept_token(grammar, tok, tok_piece);
+                }
+
                 auto constrained_str = grammar.trigger_buffer.substr(start);
-                // std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
                 grammar.trigger_buffer.clear();
-                llama_grammar_accept_token(grammar, -1, constrained_str);
+                grammar.trigger_buffer_positions.clear();
                 LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
                 return;
             }
diff --git a/src/llama-grammar.h b/src/llama-grammar.h
index 82c273b9c45..a4c978ac115 100644
--- a/src/llama-grammar.h
+++ b/src/llama-grammar.h
@@ -122,6 +122,9 @@ struct llama_grammar_trigger_pattern {
 };

 struct llama_grammar {
+    // maintain a list of llama_tokens and their positions in the trigger_buffer
+    using token_pos = std::pair<llama_token, std::pair<size_t, size_t>>;
+
     // note: allow null vocab for testing (not great)
     const llama_vocab * vocab;

@@ -137,6 +140,7 @@ struct llama_grammar {
     bool lazy = false;
     bool awaiting_trigger = false; // Initialized to true for lazy grammars only
     std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
+    std::vector<token_pos> trigger_buffer_positions; // Tokens buffered by lazy grammar. Used to replay when a trigger is found.
     std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
     std::vector<llama_grammar_trigger_pattern> trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated

From 02621fe703905528ebfa47c65c5ec8da78fbf450 Mon Sep 17 00:00:00 2001
From: Alde Rojas
Date: Sat, 6 Dec 2025 14:17:04 -0600
Subject: [PATCH 4/6] add token documentation

---
 grammars/README.md | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/grammars/README.md b/grammars/README.md
index a63198b5aeb..11e3b6dd900 100644
--- a/grammars/README.md
+++ b/grammars/README.md
@@ -67,6 +67,30 @@ Parentheses `()` can be used to group sequences, which allows for embedding alte
 - `{m,n}` repeats the precedent symbol or sequence at between `m` and `n` times (included)
 - `{0,n}` repeats the precedent symbol or sequence at most `n` times (included)

+## Tokens
+
+Tokens allow grammars to match specific tokenizer tokens rather than character sequences. This is useful for constraining outputs based on special tokens (like `<think>` or `</think>`).
+
+Tokens can be specified in two ways:
+
+1. **Token ID**: Use angle brackets with the token ID in square brackets: `<[token-id]>`. For example, `<[1000]>` matches the token with ID 1000.
+
+2. **Token string**: Use angle brackets with the token text directly: `<token-text>`. For example, `<think>` will match the token whose text is exactly `<think>`. This only works if the string tokenizes to exactly one token in the vocabulary, otherwise the grammar will fail to parse.
+
+You can negate token matches using the `!` prefix: `!<[1000]>` or `!<think>` matches any token *except* the specified one.
+
+```
+# Match a thinking block: <think> ... </think>
+# Using token strings (requires these to be single tokens in the vocab)
+root ::= <think> thinking </think> .*
+thinking ::= !</think>*
+
+# Equivalent grammar using explicit token IDs
+# Assumes token 1000 = <think>, token 1001 = </think>
+root ::= <[1000]> thinking <[1001]> .*
+thinking ::= !<[1001]>*
+```
+
 ## Comments and newlines

 Comments can be specified with `#`:

From 83cb005f2a3229285ad0488da83dcfd604a2412e Mon Sep 17 00:00:00 2001
From: Alde Rojas
Date: Sat, 6 Dec 2025 14:19:37 -0600
Subject: [PATCH 5/6] fix test-llama-grammar

---
 tests/test-llama-grammar.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp
index cc198f3e3c9..fd45d5ada83 100644
--- a/tests/test-llama-grammar.cpp
+++ b/tests/test-llama-grammar.cpp
@@ -202,7 +202,7 @@ int main()
         uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
         cp[0] = 37 + i;
         cp[1] = 0;
-        next_candidates[i] = {i, cp, {}};
+        next_candidates[i] = {i, cp, {}, 0};
     }

     std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {

From 49f16b4b9515546bffc060271e27ffa0863e1669 Mon Sep 17 00:00:00 2001
From: Alde Rojas
Date: Sat, 6 Dec 2025 15:24:10 -0600
Subject: [PATCH 6/6] improve test cases for tokens

---
 tests/test-grammar-integration.cpp | 35 +++++++++++++++++++++---------
 1 file changed, 25 insertions(+), 10 deletions(-)

diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp
index 133cc07f54b..7aa7e58a5c6 100644
--- a/tests/test-grammar-integration.cpp
+++ b/tests/test-grammar-integration.cpp
@@ -485,17 +485,22 @@ static void test_simple_grammar() {
         "simple grammar with tokens",
         R"""(
             root ::= <[10]> content <[11]>
-            content ::= !<[10]>*)""",
+            content ::= (!<[11]>)*)""",
         // Passing strings
         {
-            token(10) + "content goes here" + token(11),
-            token(10) + "content goes here" + token(12) + ", optionally other tokens too" + token(11),
+            token(10) + "hello world" + token(11),
+            token(10) + "text with " + token(12) + " other tokens " + token(13) + " mixed in" + token(11),
             token(10) + token(11),
+            token(10) + token(12) + token(13) + token(14) + token(15) + token(11),
+            token(10) + "a" + token(11),
         },
         // Failing strings
         {
-            token(10) + "content goes here",
+            token(10) + "missing end token",
             token(10),
+            "missing start token" + token(11),
+            token(10) + token(11) + token(11), // double end token
+            token(11) + "wrong order" + token(10),
         }
     );
 }
@@ -564,17 +569,27 @@ static void test_complex_grammar() {
     test_grammar(
         "complex grammar with tokens",
         R"""(
-            root ::= reasoning content tool-call
-            reasoning ::= <[10]> !<[10]>* <[11]>
-            content ::= !<[12]>*
-            tool-call ::= <[12]> .*)""",
+            root ::= reasoning+ content tool-call*
+            reasoning ::= <[10]> (!<[11]>)* <[11]>
+            content ::= <[20]> (!<[21]>)* <[21]>
+            tool-call ::= <[12]> name <[13]> args <[14]>
+            name ::= (!<[13]>)+
+            args ::= (!<[14]>)*)""",
        // Passing strings
         {
-            token(10) + "I am thinking" + token(11) + "hello!" + token(12) + "... tool call ...",
+            token(10) + "I am thinking" + token(11) + token(20) + "hello world!"
+ token(21) + token(12) + "search" + token(13) + "query=test" + token(14), + token(10) + "reasoning 1" + token(11) + token(10) + "reasoning 2" + token(11) + token(20) + token(21) + token(12) + "tool" + token(13) + token(14), + token(10) + token(11) + token(20) + "content" + token(21), + token(10) + "think" + token(12) + " nested" + token(11) + token(20) + token(10) + "more content" + token(21) + token(12) + "fn" + token(13) + "x=1,y=2" + token(14) + token(12) + "fn2" + token(13) + token(14), + token(10) + "reasoning" + token(11) + token(10) + "more" + token(11) + token(10) + "even more" + token(11) + token(20) + "text" + token(21) + token(12) + "a" + token(13) + "b" + token(14) + token(12) + "c" + token(13) + "d" + token(14), }, // Failing strings { - "hello", + token(20) + "content only" + token(21), + token(10) + "no closing reasoning", + token(10) + token(11) + token(20) + "no closing content", + token(10) + token(11) + token(20) + token(21) + token(12) + "incomplete tool", + token(10) + token(11) + token(11) + token(20) + token(21), } ); }
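
Taken together, the series enables grammars like the following end-to-end sketch (not part of the patches above; it assumes `<think>` and `</think>` each tokenize to a single vocab token, per the parse-time check in `parse_token`):

```
# constrain output to: one reasoning block, then a free-form answer
# (any tokens except the closing marker may appear inside the block)
root     ::= <think> thinking </think> answer
thinking ::= !</think>*
answer   ::= .*
```

Saved as e.g. `think.gbnf`, this can be loaded through the existing `--grammar-file` option of `llama-cli`; the token elements are then matched against whole sampled token IDs rather than their decoded UTF-8 pieces, which is exactly the path exercised by `llama_grammar_accept_token` above.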