From 7b8dbcb78b2f65c4676e41da215800d65846edd0 Mon Sep 17 00:00:00 2001 From: anzz1 Date: Tue, 28 Mar 2023 17:09:55 +0300 Subject: [PATCH] main.cpp fixes, refactoring (#571) - main: entering empty line passes back control without new input in interactive/instruct modes - instruct mode: keep prompt fix - instruct mode: duplicate instruct prompt fix - refactor: move common console code from main->common --- examples/common.cpp | 67 +++++++++++++++-- examples/common.h | 30 ++++++++ examples/main/main.cpp | 164 +++++++++++++---------------------------- 3 files changed, 143 insertions(+), 118 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 2ab000f..880ebe9 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -9,11 +9,20 @@ #include #include - #if defined(_MSC_VER) || defined(__MINGW32__) - #include // using malloc.h with MSC/MINGW - #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) - #include - #endif +#if defined(_MSC_VER) || defined(__MINGW32__) +#include // using malloc.h with MSC/MINGW +#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) +#include +#endif + +#if defined (_WIN32) +#pragma comment(lib,"kernel32.lib") +extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle); +extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode); +extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode); +extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID); +extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID); +#endif bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { // determine sensible default number of threads. @@ -204,7 +213,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " --in-prefix STRING string to prefix user inputs with (default: empty)\n"); fprintf(stderr, " -f FNAME, --file FNAME\n"); fprintf(stderr, " prompt file to start generation.\n"); - fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 - infinity)\n", params.n_predict); + fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict); fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k); fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p); fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n); @@ -216,7 +225,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n"); fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch); fprintf(stderr, " --perplexity compute perplexity over the prompt\n"); - fprintf(stderr, " --keep number of tokens to keep from the initial prompt\n"); + fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); if (ggml_mlock_supported()) { fprintf(stderr, " --mlock force system to keep model in RAM rather than swapping or compressing\n"); } @@ -256,3 +265,47 @@ std::vector llama_tokenize(struct llama_context * ctx, const std::s return res; } + +/* Keep track of current color of output, and emit ANSI code if it changes. */ +void set_console_color(console_state & con_st, console_color_t color) { + if (con_st.use_color && con_st.color != color) { + switch(color) { + case CONSOLE_COLOR_DEFAULT: + printf(ANSI_COLOR_RESET); + break; + case CONSOLE_COLOR_PROMPT: + printf(ANSI_COLOR_YELLOW); + break; + case CONSOLE_COLOR_USER_INPUT: + printf(ANSI_BOLD ANSI_COLOR_GREEN); + break; + } + con_st.color = color; + } +} + +#if defined (_WIN32) +void win32_console_init(bool enable_color) { + unsigned long dwMode = 0; + void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11) + if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) { + hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12) + if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) { + hConOut = 0; + } + } + if (hConOut) { + // Enable ANSI colors on Windows 10+ + if (enable_color && !(dwMode & 0x4)) { + SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4) + } + // Set console output codepage to UTF8 + SetConsoleOutputCP(65001); // CP_UTF8 + } + void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10) + if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) { + // Set console input codepage to UTF8 + SetConsoleCP(65001); // CP_UTF8 + } +} +#endif diff --git a/examples/common.h b/examples/common.h index 8caefd8..1505aa9 100644 --- a/examples/common.h +++ b/examples/common.h @@ -63,3 +63,33 @@ std::string gpt_random_prompt(std::mt19937 & rng); // std::vector llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos); + +// +// Console utils +// + +#define ANSI_COLOR_RED "\x1b[31m" +#define ANSI_COLOR_GREEN "\x1b[32m" +#define ANSI_COLOR_YELLOW "\x1b[33m" +#define ANSI_COLOR_BLUE "\x1b[34m" +#define ANSI_COLOR_MAGENTA "\x1b[35m" +#define ANSI_COLOR_CYAN "\x1b[36m" +#define ANSI_COLOR_RESET "\x1b[0m" +#define ANSI_BOLD "\x1b[1m" + +enum console_color_t { + CONSOLE_COLOR_DEFAULT=0, + CONSOLE_COLOR_PROMPT, + CONSOLE_COLOR_USER_INPUT +}; + +struct console_state { + bool use_color = false; + console_color_t color = CONSOLE_COLOR_DEFAULT; +}; + +void set_console_color(console_state & con_st, console_color_t color); + +#if defined (_WIN32) +void win32_console_init(bool enable_color); +#endif diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 66b7c2d..d5ab2cf 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -18,58 +18,13 @@ #include #endif -#if defined (_WIN32) -#pragma comment(lib,"kernel32.lib") -extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle); -extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode); -extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode); -extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID); -extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID); -#endif - -#define ANSI_COLOR_RED "\x1b[31m" -#define ANSI_COLOR_GREEN "\x1b[32m" -#define ANSI_COLOR_YELLOW "\x1b[33m" -#define ANSI_COLOR_BLUE "\x1b[34m" -#define ANSI_COLOR_MAGENTA "\x1b[35m" -#define ANSI_COLOR_CYAN "\x1b[36m" -#define ANSI_COLOR_RESET "\x1b[0m" -#define ANSI_BOLD "\x1b[1m" - -/* Keep track of current color of output, and emit ANSI code if it changes. */ -enum console_state { - CONSOLE_STATE_DEFAULT=0, - CONSOLE_STATE_PROMPT, - CONSOLE_STATE_USER_INPUT -}; - -static console_state con_st = CONSOLE_STATE_DEFAULT; -static bool con_use_color = false; - -void set_console_state(console_state new_st) { - if (!con_use_color) return; - // only emit color code if state changed - if (new_st != con_st) { - con_st = new_st; - switch(con_st) { - case CONSOLE_STATE_DEFAULT: - printf(ANSI_COLOR_RESET); - return; - case CONSOLE_STATE_PROMPT: - printf(ANSI_COLOR_YELLOW); - return; - case CONSOLE_STATE_USER_INPUT: - printf(ANSI_BOLD ANSI_COLOR_GREEN); - return; - } - } -} +static console_state con_st; static bool is_interacting = false; #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) void sigint_handler(int signo) { - set_console_state(CONSOLE_STATE_DEFAULT); + set_console_color(con_st, CONSOLE_COLOR_DEFAULT); printf("\n"); // this also force flush stdout. if (signo == SIGINT) { if (!is_interacting) { @@ -81,32 +36,6 @@ void sigint_handler(int signo) { } #endif -#if defined (_WIN32) -void win32_console_init(void) { - unsigned long dwMode = 0; - void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11) - if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) { - hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12) - if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) { - hConOut = 0; - } - } - if (hConOut) { - // Enable ANSI colors on Windows 10+ - if (con_use_color && !(dwMode & 0x4)) { - SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4) - } - // Set console output codepage to UTF8 - SetConsoleOutputCP(65001); // CP_UTF8 - } - void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10) - if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) { - // Set console input codepage to UTF8 - SetConsoleCP(65001); // CP_UTF8 - } -} -#endif - int main(int argc, char ** argv) { gpt_params params; params.model = "models/llama-7B/ggml-model.bin"; @@ -115,13 +44,12 @@ int main(int argc, char ** argv) { return 1; } - // save choice to use color for later // (note for later: this is a slightly awkward choice) - con_use_color = params.use_color; + con_st.use_color = params.use_color; #if defined (_WIN32) - win32_console_init(); + win32_console_init(params.use_color); #endif if (params.perplexity) { @@ -218,7 +146,10 @@ int main(int argc, char ** argv) { return 1; } - params.n_keep = std::min(params.n_keep, (int) embd_inp.size()); + // number of tokens to keep when resetting context + if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct) { + params.n_keep = (int)embd_inp.size(); + } // prefix & suffix for instruct mode const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true); @@ -226,16 +157,12 @@ int main(int argc, char ** argv) { // in instruct mode, we inject a prefix and a suffix to each input by the user if (params.instruct) { - params.interactive = true; + params.interactive_start = true; params.antiprompt.push_back("### Instruction:\n\n"); } - // enable interactive mode if reverse prompt is specified - if (params.antiprompt.size() != 0) { - params.interactive = true; - } - - if (params.interactive_start) { + // enable interactive mode if reverse prompt or interactive start is specified + if (params.antiprompt.size() != 0 || params.interactive_start) { params.interactive = true; } @@ -297,17 +224,18 @@ int main(int argc, char ** argv) { #endif " - Press Return to return control to LLaMa.\n" " - If you want to submit another line, end your input in '\\'.\n\n"); - is_interacting = params.interactive_start || params.instruct; + is_interacting = params.interactive_start; } - bool input_noecho = false; + bool is_antiprompt = false; + bool input_noecho = false; int n_past = 0; int n_remain = params.n_predict; int n_consumed = 0; // the first thing we will do is to output the prompt, so set color accordingly - set_console_state(CONSOLE_STATE_PROMPT); + set_console_color(con_st, CONSOLE_COLOR_PROMPT); std::vector embd; @@ -408,36 +336,38 @@ int main(int argc, char ** argv) { } // reset color to default if we there is no pending user input if (!input_noecho && (int)embd_inp.size() == n_consumed) { - set_console_state(CONSOLE_STATE_DEFAULT); + set_console_color(con_st, CONSOLE_COLOR_DEFAULT); } // in interactive mode, and not currently processing queued inputs; // check if we should prompt the user for more if (params.interactive && (int) embd_inp.size() <= n_consumed) { + // check for reverse prompt - std::string last_output; - for (auto id : last_n_tokens) { - last_output += llama_token_to_str(ctx, id); - } + if (params.antiprompt.size()) { + std::string last_output; + for (auto id : last_n_tokens) { + last_output += llama_token_to_str(ctx, id); + } - // Check if each of the reverse prompts appears at the end of the output. - for (std::string & antiprompt : params.antiprompt) { - if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { - is_interacting = true; - set_console_state(CONSOLE_STATE_USER_INPUT); - fflush(stdout); - break; + is_antiprompt = false; + // Check if each of the reverse prompts appears at the end of the output. + for (std::string & antiprompt : params.antiprompt) { + if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { + is_interacting = true; + is_antiprompt = true; + set_console_color(con_st, CONSOLE_COLOR_USER_INPUT); + fflush(stdout); + break; + } } } if (n_past > 0 && is_interacting) { // potentially set color to indicate we are taking user input - set_console_state(CONSOLE_STATE_USER_INPUT); + set_console_color(con_st, CONSOLE_COLOR_USER_INPUT); if (params.instruct) { - n_consumed = embd_inp.size(); - embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); - printf("\n> "); } @@ -463,16 +393,28 @@ int main(int argc, char ** argv) { } while (another_line); // done taking input, reset color - set_console_state(CONSOLE_STATE_DEFAULT); + set_console_color(con_st, CONSOLE_COLOR_DEFAULT); - auto line_inp = ::llama_tokenize(ctx, buffer, false); - embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); + // Add tokens to embd only if the input buffer is non-empty + // Entering a empty line lets the user pass control back + if (buffer.length() > 1) { - if (params.instruct) { - embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); - } + // instruct mode: insert instruction prefix + if (params.instruct && !is_antiprompt) { + n_consumed = embd_inp.size(); + embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); + } - n_remain -= line_inp.size(); + auto line_inp = ::llama_tokenize(ctx, buffer, false); + embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); + + // instruct mode: insert response suffix + if (params.instruct) { + embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); + } + + n_remain -= line_inp.size(); + } input_noecho = true; // do not echo this again } @@ -506,7 +448,7 @@ int main(int argc, char ** argv) { llama_print_timings(ctx); llama_free(ctx); - set_console_state(CONSOLE_STATE_DEFAULT); + set_console_color(con_st, CONSOLE_COLOR_DEFAULT); return 0; }