gpt_neox_client 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,809 @@
1
+ #define _USE_MATH_DEFINES // for M_PI
2
+
3
+ #include "common.h"
4
+
5
+ // third-party utilities
6
+ // use your favorite implementations
7
+ #define DR_WAV_IMPLEMENTATION
8
+ #include "dr_wav.h"
9
+
10
+ #include <cmath>
11
+ #include <cstring>
12
+ #include <fstream>
13
+ #include <regex>
14
+ #include <locale>
15
+ #include <codecvt>
16
+ #include <sstream>
17
+
18
+ #if defined(_MSC_VER)
19
+ #pragma warning(disable: 4244 4267) // possible loss of data
20
+ #endif
21
+
22
+ // Function to check if the next argument exists
23
+ std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) {
24
+ if (i + 1 < argc && argv[i + 1][0] != '-') {
25
+ return argv[++i];
26
+ } else {
27
+ fprintf(stderr, "error: %s requires one argument.\n", flag.c_str());
28
+ gpt_print_usage(argc, argv, params);
29
+ exit(0);
30
+ }
31
+ }
32
+
33
+ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
34
+ for (int i = 1; i < argc; i++) {
35
+ std::string arg = argv[i];
36
+
37
+ if (arg == "-s" || arg == "--seed") {
38
+ params.seed = std::stoi(get_next_arg(i, argc, argv, arg, params));
39
+ } else if (arg == "-t" || arg == "--threads") {
40
+ params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params));
41
+ } else if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") {
42
+ params.n_gpu_layers = std::stoi(get_next_arg(i, argc, argv, arg, params));
43
+ } else if (arg == "-p" || arg == "--prompt") {
44
+ params.prompt = get_next_arg(i, argc, argv, arg, params);
45
+ } else if (arg == "-n" || arg == "--n_predict") {
46
+ params.n_predict = std::stoi(get_next_arg(i, argc, argv, arg, params));
47
+ } else if (arg == "--top_k") {
48
+ params.top_k = std::stoi(get_next_arg(i, argc, argv, arg, params));
49
+ } else if (arg == "--top_p") {
50
+ params.top_p = std::stof(get_next_arg(i, argc, argv, arg, params));
51
+ } else if (arg == "--temp") {
52
+ params.temp = std::stof(get_next_arg(i, argc, argv, arg, params));
53
+ } else if (arg == "--repeat-last-n") {
54
+ params.repeat_last_n = std::stoi(get_next_arg(i, argc, argv, arg, params));
55
+ } else if (arg == "--repeat-penalty") {
56
+ params.repeat_penalty = std::stof(get_next_arg(i, argc, argv, arg, params));
57
+ } else if (arg == "-b" || arg == "--batch_size") {
58
+ params.n_batch= std::stoi(get_next_arg(i, argc, argv, arg, params));
59
+ } else if (arg == "-m" || arg == "--model") {
60
+ params.model = get_next_arg(i, argc, argv, arg, params);
61
+ } else if (arg == "-i" || arg == "--interactive") {
62
+ params.interactive = true;
63
+ } else if (arg == "-ip" || arg == "--interactive-port") {
64
+ params.interactive = true;
65
+ params.interactive_port = std::stoi(get_next_arg(i, argc, argv, arg, params));
66
+ } else if (arg == "-h" || arg == "--help") {
67
+ gpt_print_usage(argc, argv, params);
68
+ exit(0);
69
+ } else if (arg == "-f" || arg == "--file") {
70
+ get_next_arg(i, argc, argv, arg, params);
71
+ std::ifstream file(argv[i]);
72
+ if (!file) {
73
+ fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
74
+ break;
75
+ }
76
+ std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
77
+ if (params.prompt.back() == '\n') {
78
+ params.prompt.pop_back();
79
+ }
80
+ } else if (arg == "-tt" || arg == "--token_test") {
81
+ params.token_test = get_next_arg(i, argc, argv, arg, params);
82
+ }
83
+ else {
84
+ fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
85
+ gpt_print_usage(argc, argv, params);
86
+ exit(0);
87
+ }
88
+ }
89
+
90
+ return true;
91
+ }
92
+
93
// Print the CLI help text for the shared example flags to stderr.
// Defaults are read from `params`; argc is unused (kept for signature parity
// with gpt_params_parse, which passes both through).
void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
    fprintf(stderr, "usage: %s [options]\n", argv[0]);
    fprintf(stderr, "\n");
    fprintf(stderr, "options:\n");
    fprintf(stderr, "  -h, --help            show this help message and exit\n");
    fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1)\n");
    fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
    fprintf(stderr, "  -ngl N, --gpu-layers N number of layers to offload to GPU on supported models (default: %d)\n", params.n_gpu_layers);
    fprintf(stderr, "  -p PROMPT, --prompt PROMPT\n");
    fprintf(stderr, "                        prompt to start generation with (default: random)\n");
    fprintf(stderr, "  -f FNAME, --file FNAME\n");
    fprintf(stderr, "                        load prompt from a file\n");
    fprintf(stderr, "  -tt TOKEN_TEST, --token_test TOKEN_TEST\n");
    fprintf(stderr, "                        test tokenization\n");
    fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d)\n", params.n_predict);
    fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k);
    fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", params.top_p);
    fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", params.temp);
    fprintf(stderr, "  --repeat-last-n N     last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
    fprintf(stderr, "  --repeat-penalty N    penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty);
    fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
    fprintf(stderr, "  -m FNAME, --model FNAME\n");
    fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
    fprintf(stderr, "\n");
}
118
+
119
// Pick a short random prompt opener, uniform over ten fixed choices.
// The original ended with an unreachable `return "The";` after the
// exhaustive switch; that dead code is removed here.
std::string gpt_random_prompt(std::mt19937 & rng) {
    switch (rng() % 10) {
        case 0: return "So";
        case 1: return "Once upon a time";
        case 2: return "When";
        case 3: return "The";
        case 4: return "After";
        case 5: return "If";
        case 6: return "import";
        case 7: return "He";
        case 8: return "She";
        case 9: return "They";
    }

    // unreachable (r is always in [0, 9]); silences missing-return warnings
    return "To";
}
137
+
138
// Return a copy of s with leading and trailing whitespace removed.
std::string trim(const std::string & s) {
    return std::regex_replace(s, std::regex("^\\s+|\\s+$"), "");
}
142
+
143
// Return s with every occurrence of `from` replaced by `to`.
// Scanning resumes just past the inserted text, so a `to` that contains
// `from` is never re-replaced.
std::string replace(const std::string & s, const std::string & from, const std::string & to) {
    std::string out(s);
    for (size_t at = out.find(from); at != std::string::npos; at = out.find(from, at + to.size())) {
        out.replace(at, from.size(), to);
    }
    return out;
}
152
+
153
// Register an extra token to be treated atomically by the tokenizer
// (gpt_tokenize splits the input on these before the regular word split).
void gpt_vocab::add_special_token(const std::string & token) {
    special_tokens.push_back(token);
}
156
+
157
// Minimal ad-hoc parser for a flat JSON object mapping token strings to
// integer ids (the Hugging Face "vocab.json" shape). Returns an empty map
// when the file does not start with '{'. NOTE(review): this is not a general
// JSON parser — it assumes a single-level object and handles only the escape
// sequences that appear in GPT-2 style vocab files.
std::map<std::string, int32_t> json_parse(const std::string & fname) {
    std::map<std::string, int32_t> result;

    // read file into string
    std::string json;
    {
        std::ifstream ifs(fname);
        if (!ifs) {
            fprintf(stderr, "Failed to open %s\n", fname.c_str());
            exit(1);
        }

        json = std::string((std::istreambuf_iterator<char>(ifs)),
                (std::istreambuf_iterator<char>()));
    }

    // not an object -> nothing to parse (empty string yields '\0' here, safe)
    if (json[0] != '{') {
        return result;
    }

    // parse json
    {
        bool has_key = false;  // true while scanning the value that follows a key
        bool in_token = false; // true while inside a quoted string

        std::string str_key = "";
        std::string str_val = "";

        int n = json.size();
        for (int i = 1; i < n; ++i) {
            if (!in_token) {
                if (json[i] == ' ') continue;
                if (json[i] == '"') {
                    in_token = true;
                    continue;
                }
            } else {
                if (json[i] == '\\' && i+1 < n) {
                    // keep the backslash and take the escaped char verbatim
                    if (has_key == false) {
                        str_key += json[i];
                    } else {
                        str_val += json[i];
                    }
                    ++i;
                } else if (json[i] == '"') {
                    if (has_key == false) {
                        // closing quote of a key: skip spaces and the ':' separator
                        has_key = true;
                        ++i;
                        while (json[i] == ' ') ++i;
                        ++i; // :
                        while (json[i] == ' ') ++i;
                        if (json[i] != '\"') {
                            // unquoted (numeric) value: consume up to ',' or '}'
                            while (json[i] != ',' && json[i] != '}') {
                                str_val += json[i++];
                            }
                            has_key = false;
                        } else {
                            // quoted value: re-enter string scanning
                            in_token = true;
                            continue;
                        }
                    } else {
                        // closing quote of a quoted value
                        has_key = false;
                    }

                    // undo GPT-2 byte-level encodings used in vocab keys
                    str_key = ::replace(str_key, "\\u0120", " " ); // \u0120 -> space
                    str_key = ::replace(str_key, "\\u010a", "\n"); // \u010a -> new line
                    str_key = ::replace(str_key, "\\\"", "\""); // \\\" -> "

                    // non-numeric values are silently ignored
                    try {
                        result[str_key] = std::stoi(str_val);
                    } catch (...) {
                        //fprintf(stderr, "%s: ignoring key '%s' with value '%s'\n", fname.c_str(), str_key.c_str(), str_val.c_str());

                    }
                    str_key = "";
                    str_val = "";
                    in_token = false;
                    continue;
                }
                if (has_key == false) {
                    str_key += json[i];
                } else {
                    str_val += json[i];
                }
            }
        }
    }

    return result;
}
247
+
248
// Narrow a wide string to UTF-8 bytes.
// (std::wstring_convert is deprecated in C++17 but still functional.)
std::string convert_to_utf8(const std::wstring & input) {
    return std::wstring_convert<std::codecvt_utf8<wchar_t>>().to_bytes(input);
}
252
+
253
+
254
// Widen a UTF-8 byte string to a wide string.
// (std::wstring_convert is deprecated in C++17 but still functional.)
std::wstring convert_to_wstring(const std::string & input) {
    return std::wstring_convert<std::codecvt_utf8<wchar_t>>().from_bytes(input);
}
258
+
259
// Split `str` into GPT-2 style word pieces (contractions, letter runs,
// digit runs, punctuation runs, whitespace) and append them to `words`.
void gpt_split_words(std::string str, std::vector<std::string>& words) {
    static const std::regex word_re(
        R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)");

    std::smatch match;
    while (std::regex_search(str, match, word_re)) {
        for (const auto & piece : match) {
            words.push_back(piece);
        }
        str = match.suffix();
    }
}
271
+
272
// Convert `text` into token ids by greedy longest-match against the vocab.
// Special tokens (vocab.special_tokens) are split out first and kept whole;
// the remaining text is split into GPT-2 style word pieces and each piece is
// tokenized by repeatedly taking the longest matching vocabulary entry.
// Characters with no match at all are reported on stderr and skipped.
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
    std::vector<std::string> words;

    // first split the text into words
    {
        std::string str = text;

        // Generate the subpattern from the special_tokens vector if it's not empty
        if (!vocab.special_tokens.empty()) {
            // escape regex metacharacters so special tokens match literally
            const std::regex escape(R"([\[\\\^\$\.\|\?\*\+\(\)\{\}])");
            std::string special_tokens_subpattern;
            for (const auto & token : vocab.special_tokens) {
                if (!special_tokens_subpattern.empty()) {
                    special_tokens_subpattern += "|";
                }
                special_tokens_subpattern += std::regex_replace(token, escape, R"(\$&)");
            }

            std::regex re(special_tokens_subpattern);
            std::smatch m;
            // Split the text by special tokens.
            while (std::regex_search(str, m, re)) {
                // Split the substrings in-between special tokens into words.
                gpt_split_words(m.prefix(), words);
                // Add matched special tokens as words.
                for (auto x : m) {
                    words.push_back(x);
                }
                str = m.suffix();
            }
            // Remaining text without special tokens will be handled below.
        }

        gpt_split_words(str, words);
    }

    // find the longest token that forms each word in words:
    std::vector<gpt_vocab::id> tokens;
    for (const auto & word : words) {
        for (int i = 0; i < (int) word.size(); ){
            // try the longest candidate first: word[i..j], j from end down to i
            for (int j = word.size() - 1; j >= i; j--){
                auto cand = word.substr(i, j-i+1);
                auto it = vocab.token_to_id.find(cand);
                if (it != vocab.token_to_id.end()){ // word.substr(i, j-i+1) in vocab
                    tokens.push_back(it->second);
                    i = j + 1;
                    break;
                }
                else if (j == i){ // word.substr(i, 1) has no matching
                    fprintf(stderr, "%s: unknown token '%s'\n", __func__, word.substr(i, 1).data());
                    i++;
                }
            }
        }
    }

    return tokens;
}
330
+
331
+ std::vector<gpt_vocab::id> parse_tokens_from_string(const std::string& input, char delimiter) {
332
+ std::vector<gpt_vocab::id> output;
333
+ std::stringstream ss(input);
334
+ std::string token;
335
+
336
+ while (std::getline(ss, token, delimiter)) {
337
+ output.push_back(std::stoi(token));
338
+ }
339
+
340
+ return output;
341
+ }
342
+
343
+ std::map<std::string, std::vector<gpt_vocab::id>> extract_tests_from_file(const std::string & fpath_test){
344
+ if (fpath_test.empty()){
345
+ fprintf(stderr, "%s : No test file found.\n", __func__);
346
+ return std::map<std::string, std::vector<gpt_vocab::id>>();
347
+ }
348
+
349
+ std::map<std::string, std::vector<gpt_vocab::id>> tests;
350
+
351
+ auto fin = std::ifstream(fpath_test, std::ios_base::in);
352
+ const char * delimeter = " => ";
353
+ const char del_tok = ',';
354
+ std::string line;
355
+ while (std::getline(fin, line)) {
356
+ size_t delimiterPos = line.find(delimeter);
357
+ if (delimiterPos != std::string::npos) {
358
+ std::string text = line.substr(0, delimiterPos);
359
+ std::string s_tokens = line.substr(delimiterPos + std::strlen(delimeter));
360
+ tests[text] = parse_tokens_from_string(s_tokens, del_tok);
361
+ }
362
+ }
363
+ return tests;
364
+ }
365
+
366
// Run gpt_tokenize against reference cases loaded from fpath_test and report
// every mismatch (expected "hf" tokens vs produced "ggml" tokens) on stderr.
// NOTE(review): vocab is taken by non-const reference because id_to_token is
// accessed with operator[], which inserts an empty entry for unknown ids.
void test_gpt_tokenizer(gpt_vocab & vocab, const std::string & fpath_test){
    std::map<std::string, std::vector<gpt_vocab::id>> tests = extract_tests_from_file(fpath_test);

    size_t n_fails = 0;

    for (const auto & test : tests) {
        std::vector<gpt_vocab::id> tokens = gpt_tokenize(vocab, test.first);

        if (tokens != test.second){
            n_fails++;

            // print out failure cases
            fprintf(stderr, "%s : failed test: '%s'\n", __func__, test.first.c_str());
            fprintf(stderr, "%s : tokens in hf: ", __func__);
            for (const auto & t : test.second) {
                fprintf(stderr, "%s(%d), ", vocab.id_to_token[t].c_str(), t);
            }
            fprintf(stderr, "\n");
            fprintf(stderr, "%s : tokens in ggml: ", __func__);
            for (const auto & t : tokens) {
                fprintf(stderr, "%s(%d), ", vocab.id_to_token[t].c_str(), t);
            }
            fprintf(stderr, "\n");
        }
    }

    fprintf(stderr, "%s : %zu tests failed out of %zu tests.\n", __func__, n_fails, tests.size());
}
394
+
395
+ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
396
+ printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
397
+
398
+ vocab.token_to_id = ::json_parse(fname);
399
+
400
+ for (const auto & kv : vocab.token_to_id) {
401
+ vocab.id_to_token[kv.second] = kv.first;
402
+ }
403
+
404
+ printf("%s: vocab size = %d\n", __func__, (int) vocab.token_to_id.size());
405
+
406
+ // print the vocabulary
407
+ //for (auto kv : vocab.token_to_id) {
408
+ // printf("'%s' -> %d\n", kv.first.data(), kv.second);
409
+ //}
410
+
411
+ return true;
412
+ }
413
+
414
+ gpt_vocab::id gpt_sample_top_k_top_p(
415
+ const gpt_vocab & vocab,
416
+ const float * logits,
417
+ int top_k,
418
+ double top_p,
419
+ double temp,
420
+ std::mt19937 & rng) {
421
+ int n_logits = vocab.id_to_token.size();
422
+
423
+ std::vector<std::pair<double, gpt_vocab::id>> logits_id;
424
+ logits_id.reserve(n_logits);
425
+
426
+ {
427
+ const double scale = 1.0/temp;
428
+ for (int i = 0; i < n_logits; ++i) {
429
+ logits_id.push_back(std::make_pair(logits[i]*scale, i));
430
+ }
431
+ }
432
+
433
+ // find the top K tokens
434
+ std::partial_sort(
435
+ logits_id.begin(),
436
+ logits_id.begin() + top_k, logits_id.end(),
437
+ [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
438
+ return a.first > b.first;
439
+ });
440
+
441
+ logits_id.resize(top_k);
442
+
443
+ double maxl = -INFINITY;
444
+ for (const auto & kv : logits_id) {
445
+ maxl = std::max(maxl, kv.first);
446
+ }
447
+
448
+ // compute probs for the top K tokens
449
+ std::vector<double> probs;
450
+ probs.reserve(logits_id.size());
451
+
452
+ double sum = 0.0;
453
+ for (const auto & kv : logits_id) {
454
+ double p = exp(kv.first - maxl);
455
+ probs.push_back(p);
456
+ sum += p;
457
+ }
458
+
459
+ // normalize the probs
460
+ for (auto & p : probs) {
461
+ p /= sum;
462
+ }
463
+
464
+ if (top_p < 1.0f) {
465
+ double cumsum = 0.0f;
466
+ for (int i = 0; i < top_k; i++) {
467
+ cumsum += probs[i];
468
+ if (cumsum >= top_p) {
469
+ top_k = i + 1;
470
+ probs.resize(top_k);
471
+ logits_id.resize(top_k);
472
+ break;
473
+ }
474
+ }
475
+
476
+ cumsum = 1.0/cumsum;
477
+ for (int i = 0; i < (int) probs.size(); i++) {
478
+ probs[i] *= cumsum;
479
+ }
480
+ }
481
+
482
+ //printf("\n");
483
+ //for (int i = 0; i < (int) probs.size(); i++) {
484
+ // printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
485
+ //}
486
+ //exit(0);
487
+
488
+ std::discrete_distribution<> dist(probs.begin(), probs.end());
489
+ int idx = dist(rng);
490
+
491
+ return logits_id[idx].second;
492
+ }
493
+
494
+ gpt_vocab::id gpt_sample_top_k_top_p_repeat(
495
+ const gpt_vocab & vocab,
496
+ const float * logits,
497
+ const int32_t * last_n_tokens_data,
498
+ size_t last_n_tokens_data_size,
499
+ int top_k,
500
+ double top_p,
501
+ double temp,
502
+ int repeat_last_n,
503
+ float repeat_penalty,
504
+ std::mt19937 & rng) {
505
+
506
+ int n_logits = vocab.id_to_token.size();
507
+
508
+ const auto * plogits = logits;
509
+
510
+ const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_data_size);
511
+
512
+ if (temp <= 0) {
513
+ // select the token with the highest logit directly
514
+ float max_logit = plogits[0];
515
+ gpt_vocab::id max_id = 0;
516
+
517
+ for (int i = 1; i < n_logits; ++i) {
518
+ if (plogits[i] > max_logit) {
519
+ max_logit = plogits[i];
520
+ max_id = i;
521
+ }
522
+ }
523
+ return max_id;
524
+ }
525
+
526
+
527
+ std::vector<std::pair<double, gpt_vocab::id>> logits_id;
528
+ logits_id.reserve(n_logits);
529
+
530
+ {
531
+ const float scale = 1.0f/temp;
532
+ for (int i = 0; i < n_logits; ++i) {
533
+ // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
534
+ // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
535
+ if (repeat_last_n > 0 && std::find(last_n_tokens.end()-repeat_last_n, last_n_tokens.end(), i) != last_n_tokens.end()) {
536
+ // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
537
+ if (plogits[i] < 0.0f) {
538
+ logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
539
+ } else {
540
+ logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
541
+ }
542
+ } else {
543
+ logits_id.push_back(std::make_pair(plogits[i]*scale, i));
544
+ }
545
+ }
546
+ }
547
+
548
+ // find the top K tokens
549
+ std::partial_sort(
550
+ logits_id.begin(),
551
+ logits_id.begin() + top_k, logits_id.end(),
552
+ [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
553
+ return a.first > b.first;
554
+ });
555
+
556
+ logits_id.resize(top_k);
557
+
558
+ double maxl = -INFINITY;
559
+ for (const auto & kv : logits_id) {
560
+ maxl = std::max(maxl, kv.first);
561
+ }
562
+
563
+ // compute probs for the top K tokens
564
+ std::vector<double> probs;
565
+ probs.reserve(logits_id.size());
566
+
567
+ double sum = 0.0;
568
+ for (const auto & kv : logits_id) {
569
+ double p = exp(kv.first - maxl);
570
+ probs.push_back(p);
571
+ sum += p;
572
+ }
573
+
574
+ // normalize the probs
575
+ for (auto & p : probs) {
576
+ p /= sum;
577
+ }
578
+
579
+ if (top_p < 1.0f) {
580
+ double cumsum = 0.0f;
581
+ for (int i = 0; i < top_k; i++) {
582
+ cumsum += probs[i];
583
+ if (cumsum >= top_p) {
584
+ top_k = i + 1;
585
+ probs.resize(top_k);
586
+ logits_id.resize(top_k);
587
+ break;
588
+ }
589
+ }
590
+
591
+ cumsum = 1.0/cumsum;
592
+ for (int i = 0; i < (int) probs.size(); i++) {
593
+ probs[i] *= cumsum;
594
+ }
595
+ }
596
+
597
+ // printf("\n");
598
+ // for (int i = 0; i < (int) probs.size(); i++) {
599
+ // for (int i = 0; i < 10; i++) {
600
+ // printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
601
+ // }
602
+
603
+ std::discrete_distribution<> dist(probs.begin(), probs.end());
604
+ int idx = dist(rng);
605
+
606
+ return logits_id[idx].second;
607
+
608
+ }
609
+
610
// Read a 16-bit WAV file (or stdin when fname == "-") into float PCM.
//   pcmf32  - mono samples in [-1, 1]; stereo input is averaged down
//   pcmf32s - when `stereo` is true, per-channel samples [left, right]
// The input must be mono or stereo, COMMON_SAMPLE_RATE Hz, and 16-bit;
// returns false (with a message on stderr) when any requirement fails.
bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
    drwav wav;
    std::vector<uint8_t> wav_data; // used for pipe input from stdin

    if (fname == "-") {
        // slurp all of stdin into memory, then parse it as a WAV blob
        {
            uint8_t buf[1024];
            while (true)
            {
                const size_t n = fread(buf, 1, sizeof(buf), stdin);
                if (n == 0) {
                    break;
                }
                wav_data.insert(wav_data.end(), buf, buf + n);
            }
        }

        if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
            fprintf(stderr, "error: failed to open WAV file from stdin\n");
            return false;
        }

        fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
    }
    else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) {
        fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str());
        return false;
    }

    if (wav.channels != 1 && wav.channels != 2) {
        fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, fname.c_str());
        return false;
    }

    if (stereo && wav.channels != 2) {
        fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization\n", __func__, fname.c_str());
        return false;
    }

    if (wav.sampleRate != COMMON_SAMPLE_RATE) {
        fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", __func__, fname.c_str(), COMMON_SAMPLE_RATE/1000);
        return false;
    }

    if (wav.bitsPerSample != 16) {
        fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", __func__, fname.c_str());
        return false;
    }

    // frame count. NOTE(review): for stdin input this divides the full byte
    // size (which includes the WAV header) by the frame size, so it slightly
    // overestimates the frame count — confirm how dr_wav pads short reads.
    const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);

    std::vector<int16_t> pcm16;
    pcm16.resize(n*wav.channels);
    drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
    drwav_uninit(&wav);

    // convert to mono, float
    pcmf32.resize(n);
    if (wav.channels == 1) {
        for (uint64_t i = 0; i < n; i++) {
            pcmf32[i] = float(pcm16[i])/32768.0f;
        }
    } else {
        // average the two channels into one
        for (uint64_t i = 0; i < n; i++) {
            pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
        }
    }

    if (stereo) {
        // convert to stereo, float
        pcmf32s.resize(2);

        pcmf32s[0].resize(n);
        pcmf32s[1].resize(n);
        for (uint64_t i = 0; i < n; i++) {
            pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
            pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
        }
    }

    return true;
}
692
+
693
+ void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
694
+ const float rc = 1.0f / (2.0f * M_PI * cutoff);
695
+ const float dt = 1.0f / sample_rate;
696
+ const float alpha = dt / (rc + dt);
697
+
698
+ float y = data[0];
699
+
700
+ for (size_t i = 1; i < data.size(); i++) {
701
+ y = alpha * (y + data[i] - data[i - 1]);
702
+ data[i] = y;
703
+ }
704
+ }
705
+
706
+ bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
707
+ const int n_samples = pcmf32.size();
708
+ const int n_samples_last = (sample_rate * last_ms) / 1000;
709
+
710
+ if (n_samples_last >= n_samples) {
711
+ // not enough samples - assume no speech
712
+ return false;
713
+ }
714
+
715
+ if (freq_thold > 0.0f) {
716
+ high_pass_filter(pcmf32, freq_thold, sample_rate);
717
+ }
718
+
719
+ float energy_all = 0.0f;
720
+ float energy_last = 0.0f;
721
+
722
+ for (int i = 0; i < n_samples; i++) {
723
+ energy_all += fabsf(pcmf32[i]);
724
+
725
+ if (i >= n_samples - n_samples_last) {
726
+ energy_last += fabsf(pcmf32[i]);
727
+ }
728
+ }
729
+
730
+ energy_all /= n_samples;
731
+ energy_last /= n_samples_last;
732
+
733
+ if (verbose) {
734
+ fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
735
+ }
736
+
737
+ if (energy_last > vad_thold*energy_all) {
738
+ return false;
739
+ }
740
+
741
+ return true;
742
+ }
743
+
744
// Normalized string similarity based on Levenshtein edit distance:
// 1.0 for identical strings, 0.0 when every character must change.
// Two empty strings are identical (the original computed 0/0 -> NaN here).
float similarity(const std::string & s0, const std::string & s1) {
    if (s0.empty() && s1.empty()) {
        return 1.0f;
    }

    const size_t len0 = s0.size() + 1;
    const size_t len1 = s1.size() + 1;

    std::vector<int> col(len1, 0);
    std::vector<int> prevCol(len1, 0);

    for (size_t i = 0; i < len1; i++) {
        prevCol[i] = i;
    }

    // classic two-row dynamic program for edit distance
    for (size_t i = 0; i < len0; i++) {
        col[0] = i;
        for (size_t j = 1; j < len1; j++) {
            col[j] = std::min(std::min(1 + col[j - 1], 1 + prevCol[j]), prevCol[j - 1] + (i > 0 && s0[i - 1] == s1[j - 1] ? 0 : 1));
        }
        col.swap(prevCol);
    }

    const float dist = prevCol[len1 - 1];

    // normalize by the longer string so the result lands in [0, 1]
    return 1.0f - (dist / std::max(s0.size(), s1.size()));
}
767
+
768
+ bool sam_params_parse(int argc, char ** argv, sam_params & params) {
769
+ for (int i = 1; i < argc; i++) {
770
+ std::string arg = argv[i];
771
+
772
+ if (arg == "-s" || arg == "--seed") {
773
+ params.seed = std::stoi(argv[++i]);
774
+ } else if (arg == "-t" || arg == "--threads") {
775
+ params.n_threads = std::stoi(argv[++i]);
776
+ } else if (arg == "-m" || arg == "--model") {
777
+ params.model = argv[++i];
778
+ } else if (arg == "-i" || arg == "--inp") {
779
+ params.fname_inp = argv[++i];
780
+ } else if (arg == "-o" || arg == "--out") {
781
+ params.fname_out = argv[++i];
782
+ } else if (arg == "-h" || arg == "--help") {
783
+ sam_print_usage(argc, argv, params);
784
+ exit(0);
785
+ } else {
786
+ fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
787
+ sam_print_usage(argc, argv, params);
788
+ exit(0);
789
+ }
790
+ }
791
+
792
+ return true;
793
+ }
794
+
795
+ void sam_print_usage(int argc, char ** argv, const sam_params & params) {
796
+ fprintf(stderr, "usage: %s [options]\n", argv[0]);
797
+ fprintf(stderr, "\n");
798
+ fprintf(stderr, "options:\n");
799
+ fprintf(stderr, " -h, --help show this help message and exit\n");
800
+ fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
801
+ fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
802
+ fprintf(stderr, " -m FNAME, --model FNAME\n");
803
+ fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
804
+ fprintf(stderr, " -i FNAME, --inp FNAME\n");
805
+ fprintf(stderr, " input file (default: %s)\n", params.fname_inp.c_str());
806
+ fprintf(stderr, " -o FNAME, --out FNAME\n");
807
+ fprintf(stderr, " output file (default: %s)\n", params.fname_out.c_str());
808
+ fprintf(stderr, "\n");
809
+ }