From ad5fd5b60cfdfbfb22b0f2bc9e9f6c9692768f8d Mon Sep 17 00:00:00 2001 From: tjohnman Date: Sun, 19 Mar 2023 19:36:19 +0100 Subject: [PATCH] Make prompt randomization optional. (#300) Co-authored-by: Johnman <> --- main.cpp | 2 +- utils.cpp | 5 ++++- utils.h | 2 ++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/main.cpp b/main.cpp index 57e9249..6c78cb0 100644 --- a/main.cpp +++ b/main.cpp @@ -803,7 +803,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: seed = %d\n", __func__, params.seed); std::mt19937 rng(params.seed); - if (params.prompt.empty()) { + if (params.random_prompt) { params.prompt = gpt_random_prompt(rng); } diff --git a/utils.cpp b/utils.cpp index a4135b9..04840e4 100644 --- a/utils.cpp +++ b/utils.cpp @@ -76,6 +76,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } else if (arg == "-h" || arg == "--help") { gpt_print_usage(argc, argv, params); exit(0); + } else if (arg == "--random-prompt") { + params.random_prompt = true; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); gpt_print_usage(argc, argv, params); @@ -99,7 +101,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n"); fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); fprintf(stderr, " -p PROMPT, --prompt PROMPT\n"); - fprintf(stderr, " prompt to start generation with (default: random)\n"); + fprintf(stderr, " prompt to start generation with (default: empty)\n"); + fprintf(stderr, " --random-prompt start with a randomized prompt.\n"); fprintf(stderr, " -f FNAME, --file FNAME\n"); fprintf(stderr, " prompt file to start generation.\n"); fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict); diff --git a/utils.h b/utils.h index 2132519..60ef12b 100644 --- a/utils.h +++ b/utils.h @@ -32,6 +32,8 @@ struct gpt_params { std::string prompt = ""; std::string antiprompt = ""; // string upon seeing which more user input is prompted + bool random_prompt = false; + bool use_color = false; // use color to distinguish generations and inputs bool interactive = false; // interactive mode