diff --git a/common/console.cpp b/common/console.cpp index 078a8d678d9..f57d17f1ce3 100644 --- a/common/console.cpp +++ b/common/console.cpp @@ -1,6 +1,7 @@ #include "console.h" #include #include +#include #if defined(_WIN32) #define WIN32_LEAN_AND_MEAN @@ -24,14 +25,14 @@ #include #endif -#define ANSI_COLOR_RED "\x1b[31m" -#define ANSI_COLOR_GREEN "\x1b[32m" -#define ANSI_COLOR_YELLOW "\x1b[33m" -#define ANSI_COLOR_BLUE "\x1b[34m" -#define ANSI_COLOR_MAGENTA "\x1b[35m" -#define ANSI_COLOR_CYAN "\x1b[36m" -#define ANSI_COLOR_RESET "\x1b[0m" -#define ANSI_BOLD "\x1b[1m" +#define ANSI_COLOR_RED LOG_COL_RED +#define ANSI_COLOR_GREEN LOG_COL_GREEN +#define ANSI_COLOR_YELLOW LOG_COL_YELLOW +#define ANSI_COLOR_BLUE LOG_COL_BLUE +#define ANSI_COLOR_MAGENTA LOG_COL_MAGENTA +#define ANSI_COLOR_CYAN LOG_COL_CYAN +#define ANSI_COLOR_RESET LOG_COL_DEFAULT +#define ANSI_BOLD LOG_COL_BOLD namespace console { @@ -142,25 +143,65 @@ namespace console { // Keep track of current display and only emit ANSI code if it changes void set_display(display_t display) { if (advanced_display && current_display != display) { - fflush(stdout); - switch(display) { - case reset: - fprintf(out, ANSI_COLOR_RESET); - break; - case prompt: - fprintf(out, ANSI_COLOR_YELLOW); - break; - case user_input: - fprintf(out, ANSI_BOLD ANSI_COLOR_GREEN); - break; - case error: - fprintf(out, ANSI_BOLD ANSI_COLOR_RED); - } current_display = display; - fflush(out); + + if (display == user_input && common_log_is_active(common_log_main())) { + common_log_flush(common_log_main()); + } + + if (display == user_input || !common_log_is_active(common_log_main())) { + fflush(stdout); + switch(display) { + case reset: + fprintf(out, ANSI_COLOR_RESET); + break; + case prompt: + fprintf(out, ANSI_COLOR_YELLOW); + break; + case user_input: + fprintf(out, ANSI_BOLD ANSI_COLOR_GREEN); + break; + case error: + fprintf(out, ANSI_BOLD ANSI_COLOR_RED); + break; + case reasoning: + fprintf(out, ANSI_COLOR_BLUE); + break; + } + fflush(out); + } } } + display_t get_display() { + return current_display; + } + + const char * get_display_color() { + switch(current_display) { + case reset: + return ANSI_COLOR_RESET; + case prompt: + return ANSI_COLOR_YELLOW; + case user_input: + return ANSI_BOLD ANSI_COLOR_GREEN; + case error: + return ANSI_BOLD ANSI_COLOR_RED; + case reasoning: + return ANSI_COLOR_BLUE; + default: + return ""; + } + } + + void write_console(const char * format, ...) { + va_list args; + va_start(args, format); + vfprintf(out, format, args); + va_end(args); + fflush(out); + } + static char32_t getchar32() { #if defined(_WIN32) HANDLE hConsole = GetStdHandle(STD_INPUT_HANDLE); diff --git a/common/console.h b/common/console.h index ec175269b9d..ce35bbe2efa 100644 --- a/common/console.h +++ b/common/console.h @@ -3,17 +3,43 @@ #pragma once #include +#include "log.h" namespace console { enum display_t { reset = 0, prompt, user_input, - error + error, + reasoning }; void init(bool use_simple_io, bool use_advanced_display); void cleanup(); void set_display(display_t display); + display_t get_display(); + const char * get_display_color(); bool readline(std::string & line, bool multiline_input); + + void write_console(const char * format, ...); + + template + void write(const char * format, Args... args) { + if (get_display() == user_input || !common_log_is_active(common_log_main())) { + write_console(format, args...); + + } else { + const char * color = get_display_color(); + std::string colored_format = std::string(color) + format + LOG_COL_DEFAULT; + common_log_add(common_log_main(), GGML_LOG_LEVEL_CONT, colored_format.c_str(), args...); + } + } + + inline void write(const char * data) { + write("%s", data); + } + + inline void write(const std::string & data) { + write("%s", data.c_str()); + } } diff --git a/common/log.cpp b/common/log.cpp index a24782b7399..8958d0fce35 100644 --- a/common/log.cpp +++ b/common/log.cpp @@ -174,6 +174,7 @@ struct common_log { std::mutex mtx; std::thread thrd; std::condition_variable cv; + std::condition_variable cv_flushed; FILE * file; @@ -288,6 +289,10 @@ struct common_log { cur = entries[head]; head = (head + 1) % entries.size(); + + if (head == tail) { + cv_flushed.notify_all(); + } } if (cur.is_end) { @@ -376,6 +381,18 @@ struct common_log { this->timestamps = timestamps; } + + bool is_active() const { + return running; + } + + void flush() { + if (!running) { + return; + } + std::unique_lock lock(mtx); + cv_flushed.wait(lock, [this]() { return head == tail; }); + } }; // @@ -409,6 +426,14 @@ void common_log_free(struct common_log * log) { delete log; } +bool common_log_is_active(struct common_log * log) { + return log->is_active(); +} + +void common_log_flush(struct common_log * log) { + log->flush(); +} + void common_log_add(struct common_log * log, enum ggml_log_level level, const char * fmt, ...) { va_list args; va_start(args, fmt); diff --git a/common/log.h b/common/log.h index 7edb239a339..7f7a80a7ecc 100644 --- a/common/log.h +++ b/common/log.h @@ -44,9 +44,11 @@ struct common_log; struct common_log * common_log_init(); struct common_log * common_log_main(); // singleton, automatically destroys itself on exit -void common_log_pause (struct common_log * log); // pause the worker thread, not thread-safe -void common_log_resume(struct common_log * log); // resume the worker thread, not thread-safe -void common_log_free (struct common_log * log); +void common_log_pause (struct common_log * log); // pause the worker thread, not thread-safe +void common_log_resume (struct common_log * log); // resume the worker thread, not thread-safe +void common_log_free (struct common_log * log); +bool common_log_is_active(struct common_log * log); // check if logging is active +void common_log_flush (struct common_log * log); // wait for all pending messages to be processed LOG_ATTRIBUTE_FORMAT(3, 4) void common_log_add(struct common_log * log, enum ggml_log_level level, const char * fmt, ...); diff --git a/tools/main/main.cpp b/tools/main/main.cpp index 78b42267b59..e9ed31869a5 100644 --- a/tools/main/main.cpp +++ b/tools/main/main.cpp @@ -44,10 +44,10 @@ static bool need_insert_eot = false; static void print_usage(int argc, char ** argv) { (void) argc; - LOG("\nexample usage:\n"); - LOG("\n text generation: %s -m your_model.gguf -p \"I believe the meaning of life is\" -n 128 -no-cnv\n", argv[0]); - LOG("\n chat (conversation): %s -m your_model.gguf -sys \"You are a helpful assistant\"\n", argv[0]); - LOG("\n"); + console::write("\nexample usage:\n"); + console::write("\n text generation: %s -m your_model.gguf -p \"I believe the meaning of life is\" -n 128 -no-cnv\n", argv[0]); + console::write("\n chat (conversation): %s -m your_model.gguf -sys \"You are a helpful assistant\"\n", argv[0]); + console::write("\n"); } static bool file_exists(const std::string & path) { @@ -70,11 +70,11 @@ static void sigint_handler(int signo) { need_insert_eot = true; } else { console::cleanup(); - LOG("\n"); + console::write("\n"); common_perf_print(*g_ctx, *g_smpl); // make sure all logs are flushed - LOG("Interrupted by user\n"); + console::write("Interrupted by user\n"); common_log_pause(common_log_main()); _exit(130); @@ -83,6 +83,139 @@ static void sigint_handler(int signo) { } #endif +class partial_formatter { +public: + enum output_type { + CONTENT, + REASONING, + }; + + struct output { + std::string formatted; + output_type type; + }; + + partial_formatter(const common_chat_syntax & syntax) : syntax(syntax), had_reasoning(false) {} + + std::vector operator()(const std::string & accumulated) { + common_chat_msg next = common_chat_parse(accumulated, true, syntax); + + auto diffs = common_chat_msg_diff::compute_diffs(previous, next); + std::vector result; + for (const auto & diff : diffs) { + if (!diff.reasoning_content_delta.empty()) { + if (!had_reasoning) { + result.push_back({"\n⇒ ", REASONING}); + } + result.push_back({diff.reasoning_content_delta, REASONING}); + had_reasoning = true; + } + if (!diff.content_delta.empty()) { + if (had_reasoning) { + result.push_back({"\n\n", REASONING}); + had_reasoning = false; + } + result.push_back({diff.content_delta, CONTENT}); + } + } + previous = next; + return result; + } + + void clear() { + previous = common_chat_msg(); + had_reasoning = false; + } + +private: + common_chat_syntax syntax; + common_chat_msg previous; + bool had_reasoning; +}; + +class chat_formatter { +public: + chat_formatter( + std::vector & chat_msgs, + const common_chat_templates_ptr & chat_templates, + const common_params & params) + : chat_msgs(chat_msgs), + chat_templates(chat_templates), + params(params) {} + + std::string operator()(const std::string & role, const std::string & content) { + if (role == "user") { + formatted_cumulative.clear(); // Needed if template strips reasoning + + if (partial_formatter_ptr) { + partial_formatter_ptr->clear(); // Remove stale data from delta + } + } + + common_chat_msg new_msg; + if (role == "assistant" && syntax_ptr) { + new_msg = common_chat_parse(content, false, *syntax_ptr); + } else { + new_msg.content = content; + } + new_msg.role = role; + + chat_msgs.push_back(new_msg); + + common_chat_templates_inputs cinputs; + cinputs.messages.assign(chat_msgs.cbegin(), chat_msgs.cend()); + cinputs.use_jinja = params.use_jinja; + cinputs.add_generation_prompt = (role == "user"); + cinputs.reasoning_format = params.reasoning_format; + + cinputs.enable_thinking = + params.use_jinja && + params.reasoning_budget != 0 && + common_chat_templates_support_enable_thinking(chat_templates.get()); + + common_chat_params cparams = common_chat_templates_apply(chat_templates.get(), cinputs); + + if (!syntax_ptr) { + syntax_ptr.reset(new common_chat_syntax); + syntax_ptr->format = cparams.format; + syntax_ptr->reasoning_format = params.reasoning_format; + syntax_ptr->thinking_forced_open = cparams.thinking_forced_open; + syntax_ptr->parse_tool_calls = false; + } + + bool use_partial_formatter = params.reasoning_format != COMMON_REASONING_FORMAT_NONE; + if (!partial_formatter_ptr && use_partial_formatter) { + partial_formatter_ptr = std::make_unique(*syntax_ptr); + } + + std::string formatted; + if (formatted_cumulative.size() > cparams.prompt.size()) { + LOG_WRN("template cumulative size was reduced from \"%zu\" to \"%zu\" " + "likely due to template's removal of message reasoning.\n", + formatted_cumulative.size(), cparams.prompt.size()); + + } else { + formatted = cparams.prompt.substr(formatted_cumulative.size()); + } + + formatted_cumulative = cparams.prompt; + + LOG_DBG("formatted: '%s'\n", formatted.c_str()); + return formatted; + } + + partial_formatter * get_partial_formatter() { return partial_formatter_ptr.get(); } + const std::string & get_full_prompt() const { return formatted_cumulative; } + +private: + std::vector & chat_msgs; + const common_chat_templates_ptr & chat_templates; + const common_params & params; + std::unique_ptr syntax_ptr; + std::unique_ptr partial_formatter_ptr; + std::string formatted_cumulative; +}; + int main(int argc, char ** argv) { common_params params; g_params = ¶ms; @@ -269,15 +402,7 @@ int main(int argc, char ** argv) { std::vector embd_inp; bool waiting_for_first_input = false; - auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) { - common_chat_msg new_msg; - new_msg.role = role; - new_msg.content = content; - auto formatted = common_chat_format_single(chat_templates.get(), chat_msgs, new_msg, role == "user", g_params->use_jinja); - chat_msgs.push_back(new_msg); - LOG_DBG("formatted: '%s'\n", formatted.c_str()); - return formatted; - }; + chat_formatter chat_add_and_format(chat_msgs, chat_templates, params); std::string prompt; { @@ -295,13 +420,9 @@ int main(int argc, char ** argv) { } if (!params.system_prompt.empty() || !params.prompt.empty()) { - common_chat_templates_inputs inputs; - inputs.use_jinja = g_params->use_jinja; - inputs.messages = chat_msgs; - inputs.add_generation_prompt = !params.prompt.empty(); - - prompt = common_chat_templates_apply(chat_templates.get(), inputs).prompt; + prompt = chat_add_and_format.get_full_prompt(); } + } else { // otherwise use the prompt as is prompt = params.prompt; @@ -570,6 +691,12 @@ int main(int argc, char ** argv) { embd_inp.push_back(decoder_start_token_id); } + if (chat_add_and_format.get_partial_formatter()) { + for (const auto & msg : chat_msgs) { + console::write(msg.content + "\n"); + } + } + while ((n_remain != 0 && !is_antiprompt) || params.interactive) { // predict if (!embd.empty()) { @@ -717,6 +844,19 @@ int main(int argc, char ** argv) { if (params.conversation_mode && !waiting_for_first_input && !llama_vocab_is_eog(vocab, id)) { assistant_ss << common_token_to_piece(ctx, id, false); + + if (auto * formatter = chat_add_and_format.get_partial_formatter()) { + auto outputs = (*formatter)(assistant_ss.str()); + for (const auto & out : outputs) { + if (out.type == partial_formatter::REASONING) { + console::set_display(console::reasoning); + } else { + console::set_display(console::reset); + } + console::write(out.formatted); + } + console::set_display(console::reset); + } } // echo this to console @@ -748,8 +888,9 @@ int main(int argc, char ** argv) { for (auto id : embd) { const std::string token_str = common_token_to_piece(ctx, id, params.special); - // Console/Stream Output - LOG("%s", token_str.c_str()); + if (!chat_add_and_format.get_partial_formatter()) { + console::write(token_str); + } // Record Displayed Tokens To Log // Note: Generated tokens are created one by one hence this check @@ -832,7 +973,7 @@ int main(int argc, char ** argv) { chat_add_and_format("assistant", assistant_ss.str()); } is_interacting = true; - LOG("\n"); + console::write("\n"); } } @@ -846,8 +987,12 @@ int main(int argc, char ** argv) { if ((n_past > 0 || waiting_for_first_input) && is_interacting) { LOG_DBG("waiting for user input\n"); + // color user input only + console::set_display(console::user_input); + display = params.display_prompt; + if (params.conversation_mode) { - LOG("\n> "); + console::write("\n> "); } if (params.input_prefix_bos) { @@ -858,13 +1003,9 @@ int main(int argc, char ** argv) { std::string buffer; if (!params.input_prefix.empty() && !params.conversation_mode) { LOG_DBG("appending input prefix: '%s'\n", params.input_prefix.c_str()); - LOG("%s", params.input_prefix.c_str()); + console::write(params.input_prefix); } - // color user input only - console::set_display(console::user_input); - display = params.display_prompt; - std::string line; bool another_line = true; do { @@ -877,7 +1018,7 @@ int main(int argc, char ** argv) { display = true; if (buffer.empty()) { // Ctrl+D on empty line exits - LOG("EOF by user\n"); + console::write("EOF by user\n"); break; } @@ -895,7 +1036,7 @@ int main(int argc, char ** argv) { // append input suffix if any if (!params.input_suffix.empty() && !params.conversation_mode) { LOG_DBG("appending input suffix: '%s'\n", params.input_suffix.c_str()); - LOG("%s", params.input_suffix.c_str()); + console::write(params.input_suffix); } LOG_DBG("buffer: '%s'\n", buffer.c_str()); @@ -969,7 +1110,7 @@ int main(int argc, char ** argv) { // end of generation if (!embd.empty() && llama_vocab_is_eog(vocab, embd.back()) && !(params.interactive)) { - LOG(" [end of text]\n"); + console::write(" [end of text]\n"); break; } @@ -982,11 +1123,11 @@ int main(int argc, char ** argv) { } if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) { - LOG("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str()); + LOG_INF("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str()); llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); } - LOG("\n\n"); + console::write("\n\n"); common_perf_print(ctx, smpl); common_sampler_free(smpl);