whisper.rn 0.1.4 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +1 -1
- package/README.md +43 -4
- package/android/build.gradle +2 -4
- package/android/src/main/java/com/rnwhisper/RNWhisperModule.java +47 -7
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +196 -7
- package/android/src/main/jni/whisper/Whisper.mk +1 -1
- package/android/src/main/jni/whisper/jni.cpp +33 -9
- package/cpp/rn-whisper.cpp +26 -0
- package/cpp/rn-whisper.h +5 -0
- package/cpp/whisper.cpp +603 -412
- package/cpp/whisper.h +120 -40
- package/ios/RNWhisper.h +2 -2
- package/ios/RNWhisper.mm +78 -111
- package/ios/RNWhisperContext.h +53 -0
- package/ios/RNWhisperContext.mm +303 -0
- package/jest/mock.js +38 -2
- package/lib/commonjs/index.js +63 -2
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/index.js +64 -3
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/index.d.ts +61 -2
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +2 -2
- package/src/index.tsx +121 -4
- package/whisper-rn.podspec +15 -8
package/cpp/whisper.cpp
CHANGED
|
@@ -547,13 +547,11 @@ struct whisper_decoder {
|
|
|
547
547
|
std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
|
|
548
548
|
};
|
|
549
549
|
|
|
550
|
-
struct
|
|
551
|
-
int64_t t_load_us = 0;
|
|
552
|
-
int64_t t_mel_us = 0;
|
|
550
|
+
struct whisper_state {
|
|
553
551
|
int64_t t_sample_us = 0;
|
|
554
552
|
int64_t t_encode_us = 0;
|
|
555
553
|
int64_t t_decode_us = 0;
|
|
556
|
-
int64_t
|
|
554
|
+
int64_t t_mel_us = 0;
|
|
557
555
|
|
|
558
556
|
int32_t n_sample = 0; // number of tokens sampled
|
|
559
557
|
int32_t n_encode = 0; // number of encoder calls
|
|
@@ -561,16 +559,10 @@ struct whisper_context {
|
|
|
561
559
|
int32_t n_fail_p = 0; // number of logprob threshold failures
|
|
562
560
|
int32_t n_fail_h = 0; // number of entropy threshold failures
|
|
563
561
|
|
|
564
|
-
ggml_type wtype; // weight type (FP32 or FP16)
|
|
565
|
-
|
|
566
|
-
whisper_mel mel;
|
|
567
|
-
|
|
568
|
-
whisper_model model;
|
|
569
|
-
whisper_vocab vocab;
|
|
570
|
-
|
|
571
562
|
// cross-attention KV cache for the decoders
|
|
572
563
|
// shared between all decoders
|
|
573
564
|
whisper_kv_cache kv_cross;
|
|
565
|
+
whisper_mel mel;
|
|
574
566
|
|
|
575
567
|
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
|
|
576
568
|
|
|
@@ -635,6 +627,18 @@ struct whisper_context {
|
|
|
635
627
|
}
|
|
636
628
|
};
|
|
637
629
|
|
|
630
|
+
struct whisper_context {
|
|
631
|
+
int64_t t_load_us = 0;
|
|
632
|
+
int64_t t_start_us = 0;
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 or FP16)
|
|
636
|
+
|
|
637
|
+
whisper_model model;
|
|
638
|
+
whisper_vocab vocab;
|
|
639
|
+
whisper_state * state = nullptr;
|
|
640
|
+
};
|
|
641
|
+
|
|
638
642
|
template<typename T>
|
|
639
643
|
static void read_safe(whisper_model_loader * loader, T & dest) {
|
|
640
644
|
loader->read(loader->context, &dest, sizeof(T));
|
|
@@ -821,32 +825,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
821
825
|
wctx.model.buf = new std::vector<uint8_t>();
|
|
822
826
|
wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
|
|
823
827
|
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
return false;
|
|
827
|
-
}
|
|
828
|
-
|
|
829
|
-
{
|
|
830
|
-
const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v);
|
|
831
|
-
fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
|
|
832
|
-
}
|
|
833
|
-
|
|
834
|
-
if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) {
|
|
835
|
-
fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
|
|
836
|
-
return false;
|
|
837
|
-
}
|
|
838
|
-
|
|
839
|
-
{
|
|
840
|
-
const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v);
|
|
841
|
-
fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
|
|
842
|
-
}
|
|
843
|
-
|
|
844
|
-
wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
|
|
845
|
-
|
|
846
|
-
wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type));
|
|
847
|
-
wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type));
|
|
848
|
-
wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type));
|
|
849
|
-
wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type));
|
|
828
|
+
// we skip initialization of the state until it is needed
|
|
829
|
+
// because it might be that state will always be provided externally.
|
|
850
830
|
}
|
|
851
831
|
|
|
852
832
|
// load mel filters
|
|
@@ -929,17 +909,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
929
909
|
vocab.id_to_token[i] = word;
|
|
930
910
|
}
|
|
931
911
|
}
|
|
932
|
-
|
|
933
|
-
wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
|
|
934
|
-
|
|
935
|
-
wctx.logits_id.reserve(n_vocab);
|
|
936
|
-
|
|
937
|
-
// TAGS: WHISPER_DECODER_INIT
|
|
938
|
-
wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
|
|
939
|
-
|
|
940
|
-
wctx.decoders[0].probs.reserve (vocab.n_vocab);
|
|
941
|
-
wctx.decoders[0].logits.reserve (vocab.n_vocab);
|
|
942
|
-
wctx.decoders[0].logprobs.reserve(vocab.n_vocab);
|
|
943
912
|
}
|
|
944
913
|
|
|
945
914
|
size_t ctx_size = 0;
|
|
@@ -1339,33 +1308,34 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|
|
1339
1308
|
}
|
|
1340
1309
|
}
|
|
1341
1310
|
|
|
1342
|
-
wctx.rng = std::mt19937(0);
|
|
1343
|
-
|
|
1344
1311
|
wctx.t_load_us = ggml_time_us() - t_start_us;
|
|
1345
1312
|
|
|
1346
1313
|
return true;
|
|
1347
1314
|
}
|
|
1348
1315
|
|
|
1349
|
-
// evaluate the encoder
|
|
1316
|
+
// evaluate the encoder with the given state
|
|
1350
1317
|
//
|
|
1351
1318
|
// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
|
|
1352
1319
|
// part of the transformer model and returns the encoded features
|
|
1353
1320
|
//
|
|
1354
|
-
// -
|
|
1321
|
+
// - wctx: the model
|
|
1322
|
+
// - wstate: the state of the encoder
|
|
1355
1323
|
// - n_threads: number of threads to use
|
|
1356
1324
|
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
|
|
1357
1325
|
//
|
|
1358
|
-
static bool
|
|
1326
|
+
static bool whisper_encode_internal(
|
|
1359
1327
|
whisper_context & wctx,
|
|
1328
|
+
whisper_state & wstate,
|
|
1360
1329
|
const int mel_offset,
|
|
1361
|
-
const int n_threads)
|
|
1330
|
+
const int n_threads){
|
|
1331
|
+
|
|
1362
1332
|
const int64_t t_start_us = ggml_time_us();
|
|
1363
1333
|
|
|
1364
1334
|
const auto & model = wctx.model;
|
|
1365
|
-
const auto & mel_inp =
|
|
1335
|
+
const auto & mel_inp = wstate.mel;
|
|
1366
1336
|
const auto & hparams = model.hparams;
|
|
1367
1337
|
|
|
1368
|
-
const int n_ctx =
|
|
1338
|
+
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
|
1369
1339
|
const int n_state = hparams.n_audio_state;
|
|
1370
1340
|
const int n_head = hparams.n_audio_head;
|
|
1371
1341
|
const int n_layer = hparams.n_audio_layer;
|
|
@@ -1374,12 +1344,12 @@ static bool whisper_encode(
|
|
|
1374
1344
|
assert(mel_inp.n_mel == n_mels);
|
|
1375
1345
|
|
|
1376
1346
|
struct ggml_init_params params;
|
|
1377
|
-
params.mem_size =
|
|
1378
|
-
params.mem_buffer =
|
|
1347
|
+
params.mem_size = wstate.buf_compute.size();
|
|
1348
|
+
params.mem_buffer = wstate.buf_compute.data();
|
|
1379
1349
|
|
|
1380
1350
|
struct ggml_context * ctx0 = ggml_init(params);
|
|
1381
1351
|
|
|
1382
|
-
|
|
1352
|
+
wstate.use_buf(ctx0, 0);
|
|
1383
1353
|
|
|
1384
1354
|
struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
|
|
1385
1355
|
assert(mel->type == GGML_TYPE_F32);
|
|
@@ -1401,30 +1371,30 @@ static bool whisper_encode(
|
|
|
1401
1371
|
|
|
1402
1372
|
// convolution + gelu
|
|
1403
1373
|
{
|
|
1404
|
-
|
|
1374
|
+
wstate.use_buf(ctx0, 1);
|
|
1405
1375
|
|
|
1406
1376
|
cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
|
|
1407
1377
|
cur = ggml_add(ctx0,
|
|
1408
|
-
|
|
1409
|
-
|
|
1410
|
-
|
|
1411
|
-
|
|
1378
|
+
ggml_repeat(ctx0,
|
|
1379
|
+
model.e_conv_1_b,
|
|
1380
|
+
cur),
|
|
1381
|
+
cur);
|
|
1412
1382
|
|
|
1413
1383
|
cur = ggml_gelu(ctx0, cur);
|
|
1414
1384
|
|
|
1415
|
-
|
|
1385
|
+
wstate.use_buf(ctx0, 0);
|
|
1416
1386
|
|
|
1417
1387
|
cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
|
|
1418
1388
|
cur = ggml_add(ctx0,
|
|
1419
|
-
|
|
1420
|
-
|
|
1421
|
-
|
|
1422
|
-
|
|
1389
|
+
ggml_repeat(ctx0,
|
|
1390
|
+
model.e_conv_2_b,
|
|
1391
|
+
cur),
|
|
1392
|
+
cur);
|
|
1423
1393
|
|
|
1424
1394
|
cur = ggml_gelu(ctx0, cur);
|
|
1425
1395
|
}
|
|
1426
1396
|
|
|
1427
|
-
|
|
1397
|
+
wstate.use_buf(ctx0, 3);
|
|
1428
1398
|
|
|
1429
1399
|
// ===================================================================
|
|
1430
1400
|
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
|
@@ -1439,7 +1409,7 @@ static bool whisper_encode(
|
|
|
1439
1409
|
//}
|
|
1440
1410
|
|
|
1441
1411
|
static int iter = 0;
|
|
1442
|
-
|
|
1412
|
+
|
|
1443
1413
|
const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
|
|
1444
1414
|
const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
|
|
1445
1415
|
|
|
@@ -1459,54 +1429,54 @@ static bool whisper_encode(
|
|
|
1459
1429
|
|
|
1460
1430
|
// norm
|
|
1461
1431
|
{
|
|
1462
|
-
|
|
1432
|
+
wstate.use_buf(ctx0, 0);
|
|
1463
1433
|
|
|
1464
1434
|
cur = ggml_norm(ctx0, inpL);
|
|
1465
1435
|
|
|
1466
1436
|
// cur = ln_0_w*cur + ln_0_b
|
|
1467
1437
|
cur = ggml_add(ctx0,
|
|
1468
|
-
|
|
1469
|
-
|
|
1470
|
-
|
|
1471
|
-
|
|
1438
|
+
ggml_mul(ctx0,
|
|
1439
|
+
ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
|
|
1440
|
+
cur),
|
|
1441
|
+
ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
|
|
1472
1442
|
}
|
|
1473
1443
|
|
|
1474
1444
|
// self-attention
|
|
1475
1445
|
{
|
|
1476
|
-
|
|
1446
|
+
wstate.use_buf(ctx0, 1);
|
|
1477
1447
|
|
|
1478
1448
|
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
|
1479
|
-
|
|
1480
|
-
|
|
1449
|
+
layer.attn_q_w,
|
|
1450
|
+
cur);
|
|
1481
1451
|
|
|
1482
1452
|
Qcur = ggml_add(ctx0,
|
|
1483
|
-
|
|
1484
|
-
|
|
1485
|
-
|
|
1486
|
-
|
|
1453
|
+
ggml_repeat(ctx0,
|
|
1454
|
+
layer.attn_q_b,
|
|
1455
|
+
Qcur),
|
|
1456
|
+
Qcur);
|
|
1487
1457
|
|
|
1488
1458
|
//Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
|
1489
1459
|
|
|
1490
1460
|
// note: no bias for Key
|
|
1491
1461
|
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
|
|
1492
|
-
|
|
1493
|
-
|
|
1462
|
+
layer.attn_k_w,
|
|
1463
|
+
cur);
|
|
1494
1464
|
|
|
1495
1465
|
//Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
|
1496
1466
|
|
|
1497
1467
|
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
|
|
1498
|
-
|
|
1499
|
-
|
|
1468
|
+
layer.attn_v_w,
|
|
1469
|
+
cur);
|
|
1500
1470
|
|
|
1501
1471
|
Vcur = ggml_add(ctx0,
|
|
1502
|
-
|
|
1503
|
-
|
|
1504
|
-
|
|
1505
|
-
|
|
1472
|
+
ggml_repeat(ctx0,
|
|
1473
|
+
layer.attn_v_b,
|
|
1474
|
+
Vcur),
|
|
1475
|
+
Vcur);
|
|
1506
1476
|
|
|
1507
1477
|
// ------
|
|
1508
1478
|
|
|
1509
|
-
|
|
1479
|
+
wstate.use_buf(ctx0, 0);
|
|
1510
1480
|
|
|
1511
1481
|
#ifdef WHISPER_USE_FLASH_ATTN
|
|
1512
1482
|
struct ggml_tensor * Q =
|
|
@@ -1583,29 +1553,29 @@ static bool whisper_encode(
|
|
|
1583
1553
|
#endif
|
|
1584
1554
|
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
1585
1555
|
|
|
1586
|
-
|
|
1556
|
+
wstate.use_buf(ctx0, 1);
|
|
1587
1557
|
|
|
1588
1558
|
cur = ggml_cpy(ctx0,
|
|
1589
|
-
|
|
1590
|
-
|
|
1559
|
+
KQV_merged,
|
|
1560
|
+
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
|
|
1591
1561
|
}
|
|
1592
1562
|
|
|
1593
1563
|
// projection
|
|
1594
1564
|
{
|
|
1595
|
-
|
|
1565
|
+
wstate.use_buf(ctx0, 0);
|
|
1596
1566
|
|
|
1597
1567
|
cur = ggml_mul_mat(ctx0,
|
|
1598
|
-
|
|
1599
|
-
|
|
1568
|
+
layer.attn_ln_1_w,
|
|
1569
|
+
cur);
|
|
1600
1570
|
|
|
1601
|
-
|
|
1571
|
+
wstate.use_buf(ctx0, 1);
|
|
1602
1572
|
|
|
1603
1573
|
cur = ggml_add(ctx0,
|
|
1604
|
-
|
|
1605
|
-
|
|
1574
|
+
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
|
|
1575
|
+
cur);
|
|
1606
1576
|
}
|
|
1607
1577
|
|
|
1608
|
-
|
|
1578
|
+
wstate.use_buf(ctx0, 2);
|
|
1609
1579
|
|
|
1610
1580
|
// add the input
|
|
1611
1581
|
cur = ggml_add(ctx0, cur, inpL);
|
|
@@ -1616,61 +1586,61 @@ static bool whisper_encode(
|
|
|
1616
1586
|
{
|
|
1617
1587
|
// norm
|
|
1618
1588
|
{
|
|
1619
|
-
|
|
1589
|
+
wstate.use_buf(ctx0, 0);
|
|
1620
1590
|
|
|
1621
1591
|
cur = ggml_norm(ctx0, inpFF);
|
|
1622
1592
|
|
|
1623
|
-
|
|
1593
|
+
wstate.use_buf(ctx0, 1);
|
|
1624
1594
|
|
|
1625
1595
|
// cur = mlp_ln_w*cur + mlp_ln_b
|
|
1626
1596
|
cur = ggml_add(ctx0,
|
|
1627
|
-
|
|
1628
|
-
|
|
1629
|
-
|
|
1630
|
-
|
|
1631
|
-
|
|
1597
|
+
ggml_mul(ctx0,
|
|
1598
|
+
ggml_repeat(ctx0, layer.mlp_ln_w, cur),
|
|
1599
|
+
cur),
|
|
1600
|
+
ggml_repeat(ctx0, layer.mlp_ln_b, cur));
|
|
1601
|
+
}
|
|
1632
1602
|
|
|
1633
1603
|
#ifdef WHISPER_USE_FLASH_FF
|
|
1634
|
-
|
|
1604
|
+
wstate.use_buf(ctx0, 0);
|
|
1635
1605
|
|
|
1636
1606
|
cur = ggml_flash_ff(ctx0,
|
|
1637
|
-
|
|
1638
|
-
|
|
1607
|
+
ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.wtype, n_state, n_ctx)),
|
|
1608
|
+
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
|
1639
1609
|
#else
|
|
1640
|
-
|
|
1610
|
+
wstate.use_buf(ctx0, 0);
|
|
1641
1611
|
|
|
1642
1612
|
// fully connected
|
|
1643
1613
|
cur = ggml_mul_mat(ctx0,
|
|
1644
|
-
|
|
1645
|
-
|
|
1614
|
+
layer.mlp_0_w,
|
|
1615
|
+
cur);
|
|
1646
1616
|
|
|
1647
|
-
|
|
1617
|
+
wstate.use_buf(ctx0, 1);
|
|
1648
1618
|
|
|
1649
1619
|
cur = ggml_add(ctx0,
|
|
1650
|
-
|
|
1651
|
-
|
|
1620
|
+
ggml_repeat(ctx0, layer.mlp_0_b, cur),
|
|
1621
|
+
cur);
|
|
1652
1622
|
|
|
1653
|
-
|
|
1623
|
+
wstate.use_buf(ctx0, 0);
|
|
1654
1624
|
|
|
1655
1625
|
// GELU activation
|
|
1656
1626
|
cur = ggml_gelu(ctx0, cur);
|
|
1657
1627
|
|
|
1658
|
-
|
|
1628
|
+
wstate.use_buf(ctx0, 1);
|
|
1659
1629
|
|
|
1660
1630
|
// projection
|
|
1661
1631
|
cur = ggml_mul_mat(ctx0,
|
|
1662
|
-
|
|
1663
|
-
|
|
1632
|
+
layer.mlp_1_w,
|
|
1633
|
+
cur);
|
|
1664
1634
|
|
|
1665
|
-
|
|
1635
|
+
wstate.use_buf(ctx0, 0);
|
|
1666
1636
|
|
|
1667
1637
|
cur = ggml_add(ctx0,
|
|
1668
|
-
|
|
1669
|
-
|
|
1638
|
+
ggml_repeat(ctx0, layer.mlp_1_b, cur),
|
|
1639
|
+
cur);
|
|
1670
1640
|
#endif
|
|
1671
|
-
|
|
1641
|
+
}
|
|
1672
1642
|
|
|
1673
|
-
|
|
1643
|
+
wstate.use_buf(ctx0, 3);
|
|
1674
1644
|
|
|
1675
1645
|
inpL = ggml_add(ctx0, cur, inpFF);
|
|
1676
1646
|
}
|
|
@@ -1679,21 +1649,21 @@ static bool whisper_encode(
|
|
|
1679
1649
|
|
|
1680
1650
|
// norm
|
|
1681
1651
|
{
|
|
1682
|
-
|
|
1652
|
+
wstate.use_buf(ctx0, 0);
|
|
1683
1653
|
|
|
1684
1654
|
cur = ggml_norm(ctx0, cur);
|
|
1685
1655
|
|
|
1686
|
-
|
|
1656
|
+
wstate.use_buf(ctx0, 1);
|
|
1687
1657
|
|
|
1688
1658
|
// cur = ln_f_g*cur + ln_f_b
|
|
1689
1659
|
cur = ggml_add(ctx0,
|
|
1690
|
-
|
|
1691
|
-
|
|
1692
|
-
|
|
1693
|
-
|
|
1660
|
+
ggml_mul(ctx0,
|
|
1661
|
+
ggml_repeat(ctx0, model.e_ln_w, cur),
|
|
1662
|
+
cur),
|
|
1663
|
+
ggml_repeat(ctx0, model.e_ln_b, cur));
|
|
1694
1664
|
}
|
|
1695
1665
|
|
|
1696
|
-
|
|
1666
|
+
wstate.use_buf(ctx0, -1);
|
|
1697
1667
|
|
|
1698
1668
|
// run the computation
|
|
1699
1669
|
{
|
|
@@ -1701,7 +1671,7 @@ static bool whisper_encode(
|
|
|
1701
1671
|
gf.n_threads = n_threads;
|
|
1702
1672
|
|
|
1703
1673
|
ggml_build_forward_expand(&gf, cur);
|
|
1704
|
-
ggml_graph_compute
|
|
1674
|
+
ggml_graph_compute(ctx0, &gf);
|
|
1705
1675
|
|
|
1706
1676
|
//ggml_graph_print(&gf);
|
|
1707
1677
|
}
|
|
@@ -1731,34 +1701,34 @@ static bool whisper_encode(
|
|
|
1731
1701
|
cur->src1 = nullptr;
|
|
1732
1702
|
|
|
1733
1703
|
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
|
1734
|
-
auto
|
|
1704
|
+
auto& layer = model.layers_decoder[il];
|
|
1735
1705
|
|
|
1736
|
-
|
|
1706
|
+
wstate.use_buf(ctx0, 0);
|
|
1737
1707
|
|
|
1738
|
-
struct ggml_tensor
|
|
1739
|
-
|
|
1740
|
-
|
|
1708
|
+
struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
|
|
1709
|
+
layer.cross_attn_k_w,
|
|
1710
|
+
cur);
|
|
1741
1711
|
|
|
1742
|
-
Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
|
1712
|
+
Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25)));
|
|
1743
1713
|
|
|
1744
|
-
|
|
1714
|
+
wstate.use_buf(ctx0, 1);
|
|
1745
1715
|
|
|
1746
|
-
struct ggml_tensor
|
|
1747
|
-
|
|
1748
|
-
|
|
1716
|
+
struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
|
|
1717
|
+
layer.cross_attn_v_w,
|
|
1718
|
+
cur);
|
|
1749
1719
|
|
|
1750
1720
|
Vcross = ggml_add(ctx0,
|
|
1751
|
-
|
|
1752
|
-
|
|
1753
|
-
|
|
1754
|
-
|
|
1721
|
+
ggml_repeat(ctx0,
|
|
1722
|
+
layer.cross_attn_v_b,
|
|
1723
|
+
Vcross),
|
|
1724
|
+
Vcross);
|
|
1755
1725
|
|
|
1756
|
-
|
|
1726
|
+
wstate.use_buf(ctx0, -1);
|
|
1757
1727
|
|
|
1758
|
-
//struct ggml_tensor * k = ggml_view_1d(ctx0,
|
|
1759
|
-
//struct ggml_tensor * v = ggml_view_1d(ctx0,
|
|
1760
|
-
struct ggml_tensor
|
|
1761
|
-
struct ggml_tensor
|
|
1728
|
+
//struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
|
|
1729
|
+
//struct ggml_tensor * v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
|
|
1730
|
+
struct ggml_tensor* k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
|
|
1731
|
+
struct ggml_tensor* v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx));
|
|
1762
1732
|
|
|
1763
1733
|
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
|
|
1764
1734
|
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
|
|
@@ -1779,8 +1749,8 @@ static bool whisper_encode(
|
|
|
1779
1749
|
|
|
1780
1750
|
ggml_free(ctx0);
|
|
1781
1751
|
|
|
1782
|
-
|
|
1783
|
-
|
|
1752
|
+
wstate.t_encode_us += ggml_time_us() - t_start_us;
|
|
1753
|
+
wstate.n_encode++;
|
|
1784
1754
|
|
|
1785
1755
|
return true;
|
|
1786
1756
|
}
|
|
@@ -1795,8 +1765,9 @@ static bool whisper_encode(
|
|
|
1795
1765
|
// - n_tokens: number of tokens in the prompt
|
|
1796
1766
|
// - n_past: number of past tokens to prefix the prompt with
|
|
1797
1767
|
//
|
|
1798
|
-
static bool
|
|
1768
|
+
static bool whisper_decode_internal(
|
|
1799
1769
|
whisper_context & wctx,
|
|
1770
|
+
whisper_state & wstate,
|
|
1800
1771
|
whisper_decoder & decoder,
|
|
1801
1772
|
const whisper_token * tokens,
|
|
1802
1773
|
const int n_tokens,
|
|
@@ -1811,7 +1782,7 @@ static bool whisper_decode(
|
|
|
1811
1782
|
|
|
1812
1783
|
WHISPER_ASSERT(!!kv_self.ctx);
|
|
1813
1784
|
|
|
1814
|
-
auto & logits_out =
|
|
1785
|
+
auto & logits_out = wstate.logits;
|
|
1815
1786
|
|
|
1816
1787
|
const int n_vocab = hparams.n_vocab;
|
|
1817
1788
|
|
|
@@ -1821,13 +1792,13 @@ static bool whisper_decode(
|
|
|
1821
1792
|
const int n_layer = hparams.n_text_layer;
|
|
1822
1793
|
|
|
1823
1794
|
const int N = n_tokens;
|
|
1824
|
-
const int M =
|
|
1795
|
+
const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
|
1825
1796
|
|
|
1826
1797
|
//WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
|
|
1827
1798
|
|
|
1828
1799
|
struct ggml_init_params params;
|
|
1829
|
-
params.mem_size =
|
|
1830
|
-
params.mem_buffer =
|
|
1800
|
+
params.mem_size = wstate.buf_compute.size();
|
|
1801
|
+
params.mem_buffer = wstate.buf_compute.data();
|
|
1831
1802
|
|
|
1832
1803
|
struct ggml_context * ctx0 = ggml_init(params);
|
|
1833
1804
|
|
|
@@ -1842,7 +1813,7 @@ static bool whisper_decode(
|
|
|
1842
1813
|
((int32_t *) position->data)[i] = n_past + i;
|
|
1843
1814
|
}
|
|
1844
1815
|
|
|
1845
|
-
|
|
1816
|
+
wstate.use_buf(ctx0, 3);
|
|
1846
1817
|
|
|
1847
1818
|
// token encoding + position encoding
|
|
1848
1819
|
struct ggml_tensor * cur =
|
|
@@ -1857,7 +1828,7 @@ static bool whisper_decode(
|
|
|
1857
1828
|
|
|
1858
1829
|
// norm
|
|
1859
1830
|
{
|
|
1860
|
-
|
|
1831
|
+
wstate.use_buf(ctx0, 0);
|
|
1861
1832
|
|
|
1862
1833
|
cur = ggml_norm(ctx0, inpL);
|
|
1863
1834
|
|
|
@@ -1871,7 +1842,7 @@ static bool whisper_decode(
|
|
|
1871
1842
|
|
|
1872
1843
|
// self-attention
|
|
1873
1844
|
{
|
|
1874
|
-
|
|
1845
|
+
wstate.use_buf(ctx0, 1);
|
|
1875
1846
|
|
|
1876
1847
|
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
|
1877
1848
|
layer.attn_q_w,
|
|
@@ -1913,7 +1884,7 @@ static bool whisper_decode(
|
|
|
1913
1884
|
|
|
1914
1885
|
// ------
|
|
1915
1886
|
|
|
1916
|
-
|
|
1887
|
+
wstate.use_buf(ctx0, 0);
|
|
1917
1888
|
|
|
1918
1889
|
struct ggml_tensor * Q =
|
|
1919
1890
|
ggml_permute(ctx0,
|
|
@@ -1929,12 +1900,12 @@ static bool whisper_decode(
|
|
|
1929
1900
|
n_state/n_head, n_head, n_past + N),
|
|
1930
1901
|
0, 2, 1, 3);
|
|
1931
1902
|
|
|
1932
|
-
|
|
1903
|
+
wstate.use_buf(ctx0, 1);
|
|
1933
1904
|
|
|
1934
1905
|
// K * Q
|
|
1935
1906
|
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
|
1936
1907
|
|
|
1937
|
-
|
|
1908
|
+
wstate.use_buf(ctx0, 0);
|
|
1938
1909
|
|
|
1939
1910
|
//struct ggml_tensor * KQ_scaled =
|
|
1940
1911
|
// ggml_scale(ctx0,
|
|
@@ -1944,11 +1915,11 @@ static bool whisper_decode(
|
|
|
1944
1915
|
|
|
1945
1916
|
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
|
|
1946
1917
|
|
|
1947
|
-
|
|
1918
|
+
wstate.use_buf(ctx0, 1);
|
|
1948
1919
|
|
|
1949
1920
|
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
|
|
1950
1921
|
|
|
1951
|
-
|
|
1922
|
+
wstate.use_buf(ctx0, 0);
|
|
1952
1923
|
|
|
1953
1924
|
struct ggml_tensor * V_trans =
|
|
1954
1925
|
ggml_permute(ctx0,
|
|
@@ -1957,7 +1928,7 @@ static bool whisper_decode(
|
|
|
1957
1928
|
n_state/n_head, n_head, n_past + N),
|
|
1958
1929
|
1, 2, 0, 3);
|
|
1959
1930
|
|
|
1960
|
-
|
|
1931
|
+
wstate.use_buf(ctx0, 1);
|
|
1961
1932
|
|
|
1962
1933
|
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
|
|
1963
1934
|
|
|
@@ -1970,31 +1941,31 @@ static bool whisper_decode(
|
|
|
1970
1941
|
|
|
1971
1942
|
// projection
|
|
1972
1943
|
{
|
|
1973
|
-
|
|
1944
|
+
wstate.use_buf(ctx0, 0);
|
|
1974
1945
|
|
|
1975
1946
|
cur = ggml_mul_mat(ctx0,
|
|
1976
1947
|
layer.attn_ln_1_w,
|
|
1977
1948
|
cur);
|
|
1978
1949
|
|
|
1979
|
-
|
|
1950
|
+
wstate.use_buf(ctx0, 1);
|
|
1980
1951
|
|
|
1981
1952
|
cur = ggml_add(ctx0,
|
|
1982
1953
|
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
|
|
1983
1954
|
cur);
|
|
1984
1955
|
}
|
|
1985
1956
|
|
|
1986
|
-
|
|
1957
|
+
wstate.use_buf(ctx0, 2);
|
|
1987
1958
|
|
|
1988
1959
|
// add the input
|
|
1989
1960
|
struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
|
|
1990
1961
|
|
|
1991
1962
|
// norm
|
|
1992
1963
|
{
|
|
1993
|
-
|
|
1964
|
+
wstate.use_buf(ctx0, 0);
|
|
1994
1965
|
|
|
1995
1966
|
cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
|
|
1996
1967
|
|
|
1997
|
-
|
|
1968
|
+
wstate.use_buf(ctx0, 1);
|
|
1998
1969
|
|
|
1999
1970
|
// cur = ln_0_w*cur + ln_0_b
|
|
2000
1971
|
cur = ggml_add(ctx0,
|
|
@@ -2006,7 +1977,7 @@ static bool whisper_decode(
|
|
|
2006
1977
|
|
|
2007
1978
|
// cross-attention
|
|
2008
1979
|
{
|
|
2009
|
-
|
|
1980
|
+
wstate.use_buf(ctx0, 0);
|
|
2010
1981
|
|
|
2011
1982
|
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
|
2012
1983
|
layer.cross_attn_q_w,
|
|
@@ -2023,19 +1994,19 @@ static bool whisper_decode(
|
|
|
2023
1994
|
// Kcross is already scaled
|
|
2024
1995
|
struct ggml_tensor * Kcross =
|
|
2025
1996
|
ggml_reshape_3d(ctx0,
|
|
2026
|
-
ggml_view_1d(ctx0,
|
|
1997
|
+
ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state),
|
|
2027
1998
|
n_state/n_head, n_head, M);
|
|
2028
1999
|
|
|
2029
2000
|
struct ggml_tensor * Vcross =
|
|
2030
2001
|
ggml_reshape_3d(ctx0,
|
|
2031
|
-
ggml_view_1d(ctx0,
|
|
2002
|
+
ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
|
|
2032
2003
|
n_state/n_head, n_head, M);
|
|
2033
2004
|
|
|
2034
2005
|
struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3);
|
|
2035
2006
|
|
|
2036
2007
|
// ------
|
|
2037
2008
|
|
|
2038
|
-
|
|
2009
|
+
wstate.use_buf(ctx0, 1);
|
|
2039
2010
|
|
|
2040
2011
|
struct ggml_tensor * Q =
|
|
2041
2012
|
ggml_permute(ctx0,
|
|
@@ -2046,7 +2017,7 @@ static bool whisper_decode(
|
|
|
2046
2017
|
|
|
2047
2018
|
struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
|
|
2048
2019
|
|
|
2049
|
-
|
|
2020
|
+
wstate.use_buf(ctx0, 0);
|
|
2050
2021
|
|
|
2051
2022
|
// K * Q
|
|
2052
2023
|
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
|
@@ -2060,15 +2031,15 @@ static bool whisper_decode(
|
|
|
2060
2031
|
// no masking for cross-attention
|
|
2061
2032
|
//struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
|
|
2062
2033
|
|
|
2063
|
-
|
|
2034
|
+
wstate.use_buf(ctx0, 1);
|
|
2064
2035
|
|
|
2065
2036
|
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
|
|
2066
2037
|
|
|
2067
|
-
|
|
2038
|
+
wstate.use_buf(ctx0, 0);
|
|
2068
2039
|
|
|
2069
2040
|
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
|
|
2070
2041
|
|
|
2071
|
-
|
|
2042
|
+
wstate.use_buf(ctx0, 1);
|
|
2072
2043
|
|
|
2073
2044
|
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
|
2074
2045
|
|
|
@@ -2080,20 +2051,20 @@ static bool whisper_decode(
|
|
|
2080
2051
|
|
|
2081
2052
|
// projection
|
|
2082
2053
|
{
|
|
2083
|
-
|
|
2054
|
+
wstate.use_buf(ctx0, 0);
|
|
2084
2055
|
|
|
2085
2056
|
cur = ggml_mul_mat(ctx0,
|
|
2086
2057
|
layer.cross_attn_ln_1_w,
|
|
2087
2058
|
cur);
|
|
2088
2059
|
|
|
2089
|
-
|
|
2060
|
+
wstate.use_buf(ctx0, 1);
|
|
2090
2061
|
|
|
2091
2062
|
cur = ggml_add(ctx0,
|
|
2092
2063
|
ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
|
|
2093
2064
|
cur);
|
|
2094
2065
|
}
|
|
2095
2066
|
|
|
2096
|
-
|
|
2067
|
+
wstate.use_buf(ctx0, 2);
|
|
2097
2068
|
|
|
2098
2069
|
// add the input
|
|
2099
2070
|
cur = ggml_add(ctx0, cur, inpCA);
|
|
@@ -2104,11 +2075,11 @@ static bool whisper_decode(
|
|
|
2104
2075
|
{
|
|
2105
2076
|
// norm
|
|
2106
2077
|
{
|
|
2107
|
-
|
|
2078
|
+
wstate.use_buf(ctx0, 0);
|
|
2108
2079
|
|
|
2109
2080
|
cur = ggml_norm(ctx0, inpFF);
|
|
2110
2081
|
|
|
2111
|
-
|
|
2082
|
+
wstate.use_buf(ctx0, 1);
|
|
2112
2083
|
|
|
2113
2084
|
// cur = mlp_ln_w*cur + mlp_ln_b
|
|
2114
2085
|
cur = ggml_add(ctx0,
|
|
@@ -2118,39 +2089,39 @@ static bool whisper_decode(
|
|
|
2118
2089
|
ggml_repeat(ctx0, layer.mlp_ln_b, cur));
|
|
2119
2090
|
}
|
|
2120
2091
|
|
|
2121
|
-
|
|
2092
|
+
wstate.use_buf(ctx0, 0);
|
|
2122
2093
|
|
|
2123
2094
|
// fully connected
|
|
2124
2095
|
cur = ggml_mul_mat(ctx0,
|
|
2125
2096
|
layer.mlp_0_w,
|
|
2126
2097
|
cur);
|
|
2127
2098
|
|
|
2128
|
-
|
|
2099
|
+
wstate.use_buf(ctx0, 1);
|
|
2129
2100
|
|
|
2130
2101
|
cur = ggml_add(ctx0,
|
|
2131
2102
|
ggml_repeat(ctx0, layer.mlp_0_b, cur),
|
|
2132
2103
|
cur);
|
|
2133
2104
|
|
|
2134
|
-
|
|
2105
|
+
wstate.use_buf(ctx0, 0);
|
|
2135
2106
|
|
|
2136
2107
|
// GELU activation
|
|
2137
2108
|
cur = ggml_gelu(ctx0, cur);
|
|
2138
2109
|
|
|
2139
|
-
|
|
2110
|
+
wstate.use_buf(ctx0, 1);
|
|
2140
2111
|
|
|
2141
2112
|
// projection
|
|
2142
2113
|
cur = ggml_mul_mat(ctx0,
|
|
2143
2114
|
layer.mlp_1_w,
|
|
2144
2115
|
cur);
|
|
2145
2116
|
|
|
2146
|
-
|
|
2117
|
+
wstate.use_buf(ctx0, 0);
|
|
2147
2118
|
|
|
2148
2119
|
cur = ggml_add(ctx0,
|
|
2149
2120
|
ggml_repeat(ctx0, layer.mlp_1_b, cur),
|
|
2150
2121
|
cur);
|
|
2151
2122
|
}
|
|
2152
2123
|
|
|
2153
|
-
|
|
2124
|
+
wstate.use_buf(ctx0, 3);
|
|
2154
2125
|
|
|
2155
2126
|
inpL = ggml_add(ctx0, cur, inpFF);
|
|
2156
2127
|
}
|
|
@@ -2159,11 +2130,11 @@ static bool whisper_decode(
|
|
|
2159
2130
|
|
|
2160
2131
|
// norm
|
|
2161
2132
|
{
|
|
2162
|
-
|
|
2133
|
+
wstate.use_buf(ctx0, 0);
|
|
2163
2134
|
|
|
2164
2135
|
cur = ggml_norm(ctx0, cur);
|
|
2165
2136
|
|
|
2166
|
-
|
|
2137
|
+
wstate.use_buf(ctx0, 1);
|
|
2167
2138
|
|
|
2168
2139
|
cur = ggml_add(ctx0,
|
|
2169
2140
|
ggml_mul(ctx0,
|
|
@@ -2172,7 +2143,7 @@ static bool whisper_decode(
|
|
|
2172
2143
|
ggml_repeat(ctx0, model.d_ln_b, cur));
|
|
2173
2144
|
}
|
|
2174
2145
|
|
|
2175
|
-
|
|
2146
|
+
wstate.use_buf(ctx0, 0);
|
|
2176
2147
|
|
|
2177
2148
|
// compute logits only for the last token
|
|
2178
2149
|
// comment this line to compute logits for all N tokens
|
|
@@ -2181,7 +2152,7 @@ static bool whisper_decode(
|
|
|
2181
2152
|
|
|
2182
2153
|
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
|
|
2183
2154
|
|
|
2184
|
-
|
|
2155
|
+
wstate.use_buf(ctx0, -1);
|
|
2185
2156
|
|
|
2186
2157
|
// run the computation
|
|
2187
2158
|
{
|
|
@@ -2208,8 +2179,8 @@ static bool whisper_decode(
|
|
|
2208
2179
|
|
|
2209
2180
|
ggml_free(ctx0);
|
|
2210
2181
|
|
|
2211
|
-
|
|
2212
|
-
|
|
2182
|
+
wstate.t_decode_us += ggml_time_us() - t_start_us;
|
|
2183
|
+
wstate.n_decode++;
|
|
2213
2184
|
|
|
2214
2185
|
return true;
|
|
2215
2186
|
}
|
|
@@ -2313,7 +2284,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
|
|
|
2313
2284
|
|
|
2314
2285
|
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
|
|
2315
2286
|
static bool log_mel_spectrogram(
|
|
2316
|
-
|
|
2287
|
+
whisper_state & wstate,
|
|
2317
2288
|
const float * samples,
|
|
2318
2289
|
const int n_samples,
|
|
2319
2290
|
const int /*sample_rate*/,
|
|
@@ -2433,7 +2404,7 @@ static bool log_mel_spectrogram(
|
|
|
2433
2404
|
mel.data[i] = (mel.data[i] + 4.0)/4.0;
|
|
2434
2405
|
}
|
|
2435
2406
|
|
|
2436
|
-
|
|
2407
|
+
wstate.t_mel_us += ggml_time_us() - t_start_us;
|
|
2437
2408
|
|
|
2438
2409
|
return true;
|
|
2439
2410
|
}
|
|
@@ -2507,7 +2478,56 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
|
|
|
2507
2478
|
// interface implementation
|
|
2508
2479
|
//
|
|
2509
2480
|
|
|
2510
|
-
struct
|
|
2481
|
+
// Allocates and initializes a new inference state for `ctx`. The state includes:
// - the self-attention KV cache of the first decoder and the cross-attention KV cache,
// - the logits / per-decoder work vectors,
// - the compute and scratch buffers,
// - the sampling RNG.
//
// Returns nullptr on failure. Everything allocated before the failure is released
// here, so the caller never receives (or has to free) a partially built state.
struct whisper_state * whisper_init_state(whisper_context * ctx) {
    // Value-initialize (note the parentheses) so every member without a default member
    // initializer starts zeroed. whisper_full() relies on decoder slots that are not set
    // up here having kv_self.ctx == nullptr.
    whisper_state * state = new whisper_state();

    // Memory requirements are doubled for models that are not f16.
    const size_t scale = ctx->model.hparams.f16 ? 1 : 2;

    // Releases whatever has been allocated so far. kv_cache_free() must tolerate caches
    // that were never initialized; whisper_free_state() already relies on that.
    const auto release_and_fail = [state]() -> whisper_state * {
        kv_cache_free(state->decoders[0].kv_self);
        kv_cache_free(state->kv_cross);
        delete state;
        return nullptr;
    };

    if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) {
        fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
        return release_and_fail();
    }

    {
        const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
        fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
    }

    if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->wtype, ctx->model.hparams.n_audio_ctx)) {
        fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
        // the self-attention cache above is already live and must not leak
        return release_and_fail();
    }

    {
        const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v);
        fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
    }


    state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);

    state->logits_id.reserve(ctx->model.hparams.n_vocab);

    // TAGS: WHISPER_DECODER_INIT
    state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);

    state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
    state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
    state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
    state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type)));

    state->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
    state->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
    state->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type));
    state->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type));

    // fixed seed so sampling is reproducible across runs
    state->rng = std::mt19937(0);

    return state;
}
|
|
2529
|
+
|
|
2530
|
+
struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
|
|
2511
2531
|
whisper_model_loader loader = {};
|
|
2512
2532
|
|
|
2513
2533
|
fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
|
|
@@ -2535,10 +2555,10 @@ struct whisper_context * whisper_init_from_file(const char * path_model) {
|
|
|
2535
2555
|
fin->close();
|
|
2536
2556
|
};
|
|
2537
2557
|
|
|
2538
|
-
return
|
|
2558
|
+
return whisper_init_no_state(&loader);
|
|
2539
2559
|
}
|
|
2540
2560
|
|
|
2541
|
-
struct whisper_context *
|
|
2561
|
+
struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
|
|
2542
2562
|
struct buf_context {
|
|
2543
2563
|
uint8_t* buffer;
|
|
2544
2564
|
size_t size;
|
|
@@ -2571,10 +2591,10 @@ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_s
|
|
|
2571
2591
|
|
|
2572
2592
|
loader.close = [](void * /*ctx*/) { };
|
|
2573
2593
|
|
|
2574
|
-
return
|
|
2594
|
+
return whisper_init_no_state(&loader);
|
|
2575
2595
|
}
|
|
2576
2596
|
|
|
2577
|
-
struct whisper_context *
|
|
2597
|
+
struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) {
|
|
2578
2598
|
ggml_time_init();
|
|
2579
2599
|
|
|
2580
2600
|
whisper_context * ctx = new whisper_context;
|
|
@@ -2591,6 +2611,64 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
|
|
|
2591
2611
|
return ctx;
|
|
2592
2612
|
}
|
|
2593
2613
|
|
|
2614
|
+
struct whisper_context * whisper_init_from_file(const char * path_model) {
|
|
2615
|
+
whisper_context * ctx = whisper_init_from_file_no_state(path_model);
|
|
2616
|
+
if (!ctx) {
|
|
2617
|
+
return nullptr;
|
|
2618
|
+
}
|
|
2619
|
+
|
|
2620
|
+
ctx->state = whisper_init_state(ctx);
|
|
2621
|
+
if (!ctx->state) {
|
|
2622
|
+
whisper_free(ctx);
|
|
2623
|
+
return nullptr;
|
|
2624
|
+
}
|
|
2625
|
+
|
|
2626
|
+
return ctx;
|
|
2627
|
+
}
|
|
2628
|
+
|
|
2629
|
+
struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
|
|
2630
|
+
whisper_context * ctx = whisper_init_from_buffer_no_state(buffer, buffer_size);
|
|
2631
|
+
if (!ctx) {
|
|
2632
|
+
return nullptr;
|
|
2633
|
+
}
|
|
2634
|
+
|
|
2635
|
+
ctx->state = whisper_init_state(ctx);
|
|
2636
|
+
if (!ctx->state) {
|
|
2637
|
+
whisper_free(ctx);
|
|
2638
|
+
return nullptr;
|
|
2639
|
+
}
|
|
2640
|
+
|
|
2641
|
+
return ctx;
|
|
2642
|
+
}
|
|
2643
|
+
|
|
2644
|
+
struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
|
|
2645
|
+
whisper_context * ctx = whisper_init_no_state(loader);
|
|
2646
|
+
if (!ctx) {
|
|
2647
|
+
return nullptr;
|
|
2648
|
+
}
|
|
2649
|
+
|
|
2650
|
+
ctx->state = whisper_init_state(ctx);
|
|
2651
|
+
if (!ctx->state) {
|
|
2652
|
+
whisper_free(ctx);
|
|
2653
|
+
return nullptr;
|
|
2654
|
+
}
|
|
2655
|
+
|
|
2656
|
+
return ctx;
|
|
2657
|
+
}
|
|
2658
|
+
|
|
2659
|
+
void whisper_free_state(struct whisper_state * state)
|
|
2660
|
+
{
|
|
2661
|
+
if (state) {
|
|
2662
|
+
kv_cache_free(state->kv_cross);
|
|
2663
|
+
|
|
2664
|
+
for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
|
|
2665
|
+
kv_cache_free(state->decoders[i].kv_self);
|
|
2666
|
+
}
|
|
2667
|
+
|
|
2668
|
+
delete state;
|
|
2669
|
+
}
|
|
2670
|
+
}
|
|
2671
|
+
|
|
2594
2672
|
void whisper_free(struct whisper_context * ctx) {
|
|
2595
2673
|
if (ctx) {
|
|
2596
2674
|
if (ctx->model.ctx) {
|
|
@@ -2599,20 +2677,29 @@ void whisper_free(struct whisper_context * ctx) {
|
|
|
2599
2677
|
if (ctx->model.buf) {
|
|
2600
2678
|
delete ctx->model.buf;
|
|
2601
2679
|
}
|
|
2602
|
-
|
|
2603
|
-
|
|
2604
|
-
|
|
2605
|
-
for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
|
|
2606
|
-
if (ctx->decoders[i].kv_self.ctx) {
|
|
2607
|
-
ggml_free(ctx->decoders[i].kv_self.ctx);
|
|
2608
|
-
}
|
|
2609
|
-
}
|
|
2680
|
+
|
|
2681
|
+
whisper_free_state(ctx->state);
|
|
2682
|
+
|
|
2610
2683
|
delete ctx;
|
|
2611
2684
|
}
|
|
2612
2685
|
}
|
|
2613
2686
|
|
|
2687
|
+
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
|
|
2688
|
+
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
|
|
2689
|
+
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
|
|
2690
|
+
return -1;
|
|
2691
|
+
}
|
|
2692
|
+
|
|
2693
|
+
return 0;
|
|
2694
|
+
}
|
|
2695
|
+
|
|
2614
2696
|
int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
|
|
2615
|
-
|
|
2697
|
+
return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
|
|
2698
|
+
}
|
|
2699
|
+
|
|
2700
|
+
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
|
|
2701
|
+
int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
|
|
2702
|
+
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, state->mel)) {
|
|
2616
2703
|
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
|
|
2617
2704
|
return -1;
|
|
2618
2705
|
}
|
|
@@ -2622,11 +2709,26 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
|
|
|
2622
2709
|
|
|
2623
2710
|
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
|
|
2624
2711
|
int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
|
|
2625
|
-
|
|
2626
|
-
|
|
2712
|
+
return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
|
|
2713
|
+
}
|
|
2714
|
+
|
|
2715
|
+
int whisper_set_mel_with_state(
|
|
2716
|
+
struct whisper_context * /*ctx*/,
|
|
2717
|
+
struct whisper_state * state,
|
|
2718
|
+
const float * data,
|
|
2719
|
+
int n_len,
|
|
2720
|
+
int n_mel) {
|
|
2721
|
+
if (n_mel != WHISPER_N_MEL) {
|
|
2722
|
+
fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
|
|
2627
2723
|
return -1;
|
|
2628
2724
|
}
|
|
2629
2725
|
|
|
2726
|
+
state->mel.n_len = n_len;
|
|
2727
|
+
state->mel.n_mel = n_mel;
|
|
2728
|
+
|
|
2729
|
+
state->mel.data.resize(n_len*n_mel);
|
|
2730
|
+
memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float));
|
|
2731
|
+
|
|
2630
2732
|
return 0;
|
|
2631
2733
|
}
|
|
2632
2734
|
|
|
@@ -2635,22 +2737,20 @@ int whisper_set_mel(
|
|
|
2635
2737
|
const float * data,
|
|
2636
2738
|
int n_len,
|
|
2637
2739
|
int n_mel) {
|
|
2638
|
-
|
|
2639
|
-
|
|
2740
|
+
return whisper_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel);
|
|
2741
|
+
}
|
|
2742
|
+
|
|
2743
|
+
int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
|
|
2744
|
+
if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
|
|
2745
|
+
fprintf(stderr, "%s: failed to eval\n", __func__);
|
|
2640
2746
|
return -1;
|
|
2641
2747
|
}
|
|
2642
2748
|
|
|
2643
|
-
ctx->mel.n_len = n_len;
|
|
2644
|
-
ctx->mel.n_mel = n_mel;
|
|
2645
|
-
|
|
2646
|
-
ctx->mel.data.resize(n_len*n_mel);
|
|
2647
|
-
memcpy(ctx->mel.data.data(), data, n_len*n_mel*sizeof(float));
|
|
2648
|
-
|
|
2649
2749
|
return 0;
|
|
2650
2750
|
}
|
|
2651
2751
|
|
|
2652
2752
|
int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
|
|
2653
|
-
if (!
|
|
2753
|
+
if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
|
|
2654
2754
|
fprintf(stderr, "%s: failed to eval\n", __func__);
|
|
2655
2755
|
return -1;
|
|
2656
2756
|
}
|
|
@@ -2658,11 +2758,28 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
|
|
|
2658
2758
|
return 0;
|
|
2659
2759
|
}
|
|
2660
2760
|
|
|
2761
|
+
int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
|
|
2762
|
+
const int selected_decoder_id = 0;
|
|
2763
|
+
|
|
2764
|
+
if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
|
|
2765
|
+
fprintf(stderr, "%s: failed to eval\n", __func__);
|
|
2766
|
+
return 1;
|
|
2767
|
+
}
|
|
2768
|
+
|
|
2769
|
+
return 0;
|
|
2770
|
+
}
|
|
2771
|
+
|
|
2661
2772
|
int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
|
|
2662
|
-
// TODO: add selected_decoder_id to
|
|
2773
|
+
// TODO: add selected_decoder_id to state
|
|
2663
2774
|
const int selected_decoder_id = 0;
|
|
2664
2775
|
|
|
2665
|
-
if (
|
|
2776
|
+
if (ctx->state == nullptr) {
|
|
2777
|
+
fprintf(stderr, "%s: ERROR state was not loaded.\n", __func__);
|
|
2778
|
+
return false;
|
|
2779
|
+
}
|
|
2780
|
+
|
|
2781
|
+
|
|
2782
|
+
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
|
|
2666
2783
|
fprintf(stderr, "%s: failed to eval\n", __func__);
|
|
2667
2784
|
return 1;
|
|
2668
2785
|
}
|
|
@@ -2720,11 +2837,12 @@ const char * whisper_lang_str(int id) {
|
|
|
2720
2837
|
return nullptr;
|
|
2721
2838
|
}
|
|
2722
2839
|
|
|
2723
|
-
int
|
|
2840
|
+
int whisper_lang_auto_detect_with_state(
|
|
2724
2841
|
struct whisper_context * ctx,
|
|
2725
|
-
|
|
2726
|
-
|
|
2727
|
-
|
|
2842
|
+
struct whisper_state * state,
|
|
2843
|
+
int offset_ms,
|
|
2844
|
+
int n_threads,
|
|
2845
|
+
float * lang_probs) {
|
|
2728
2846
|
const int seek = offset_ms/10;
|
|
2729
2847
|
|
|
2730
2848
|
if (seek < 0) {
|
|
@@ -2732,8 +2850,8 @@ int whisper_lang_auto_detect(
|
|
|
2732
2850
|
return -1;
|
|
2733
2851
|
}
|
|
2734
2852
|
|
|
2735
|
-
if (seek >=
|
|
2736
|
-
fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms,
|
|
2853
|
+
if (seek >= state->mel.n_len) {
|
|
2854
|
+
fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len*10);
|
|
2737
2855
|
return -2;
|
|
2738
2856
|
}
|
|
2739
2857
|
|
|
@@ -2745,17 +2863,17 @@ int whisper_lang_auto_detect(
|
|
|
2745
2863
|
|
|
2746
2864
|
const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
|
|
2747
2865
|
|
|
2748
|
-
if (
|
|
2866
|
+
if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
|
|
2749
2867
|
fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
2750
2868
|
return -7;
|
|
2751
2869
|
}
|
|
2752
2870
|
|
|
2753
|
-
auto & logits_id =
|
|
2871
|
+
auto & logits_id = state->logits_id;
|
|
2754
2872
|
logits_id.clear();
|
|
2755
2873
|
|
|
2756
2874
|
for (const auto & kv : g_lang) {
|
|
2757
2875
|
const auto token_lang = whisper_token_lang(ctx, kv.second.first);
|
|
2758
|
-
logits_id.emplace_back(
|
|
2876
|
+
logits_id.emplace_back(state->logits[token_lang], kv.second.first);
|
|
2759
2877
|
}
|
|
2760
2878
|
|
|
2761
2879
|
// sort descending
|
|
@@ -2794,8 +2912,20 @@ int whisper_lang_auto_detect(
|
|
|
2794
2912
|
return logits_id[0].second;
|
|
2795
2913
|
}
|
|
2796
2914
|
|
|
2915
|
+
int whisper_lang_auto_detect(
|
|
2916
|
+
struct whisper_context * ctx,
|
|
2917
|
+
int offset_ms,
|
|
2918
|
+
int n_threads,
|
|
2919
|
+
float * lang_probs) {
|
|
2920
|
+
return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs);
|
|
2921
|
+
}
|
|
2922
|
+
|
|
2923
|
+
int whisper_n_len_from_state(struct whisper_state * state) {
|
|
2924
|
+
return state->mel.n_len;
|
|
2925
|
+
}
|
|
2926
|
+
|
|
2797
2927
|
int whisper_n_len(struct whisper_context * ctx) {
|
|
2798
|
-
return ctx->mel.n_len;
|
|
2928
|
+
return ctx->state->mel.n_len;
|
|
2799
2929
|
}
|
|
2800
2930
|
|
|
2801
2931
|
int whisper_n_vocab(struct whisper_context * ctx) {
|
|
@@ -2815,7 +2945,12 @@ int whisper_is_multilingual(struct whisper_context * ctx) {
|
|
|
2815
2945
|
}
|
|
2816
2946
|
|
|
2817
2947
|
float * whisper_get_logits(struct whisper_context * ctx) {
|
|
2818
|
-
return ctx->logits.data();
|
|
2948
|
+
return ctx->state->logits.data();
|
|
2949
|
+
}
|
|
2950
|
+
|
|
2951
|
+
|
|
2952
|
+
float * whisper_get_logits_from_state(struct whisper_state * state) {
|
|
2953
|
+
return state->logits.data();
|
|
2819
2954
|
}
|
|
2820
2955
|
|
|
2821
2956
|
const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
|
|
@@ -2861,24 +2996,29 @@ whisper_token whisper_token_transcribe(void) {
|
|
|
2861
2996
|
void whisper_print_timings(struct whisper_context * ctx) {
|
|
2862
2997
|
const int64_t t_end_us = ggml_time_us();
|
|
2863
2998
|
|
|
2864
|
-
const int32_t n_sample = std::max(1, ctx->n_sample);
|
|
2865
|
-
const int32_t n_encode = std::max(1, ctx->n_encode);
|
|
2866
|
-
const int32_t n_decode = std::max(1, ctx->n_decode);
|
|
2867
|
-
|
|
2868
2999
|
fprintf(stderr, "\n");
|
|
2869
|
-
fprintf(stderr, "%s:
|
|
2870
|
-
|
|
2871
|
-
|
|
2872
|
-
|
|
2873
|
-
|
|
2874
|
-
|
|
3000
|
+
fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
|
|
3001
|
+
if (ctx->state != nullptr) {
|
|
3002
|
+
|
|
3003
|
+
const int32_t n_sample = std::max(1, ctx->state->n_sample);
|
|
3004
|
+
const int32_t n_encode = std::max(1, ctx->state->n_encode);
|
|
3005
|
+
const int32_t n_decode = std::max(1, ctx->state->n_decode);
|
|
3006
|
+
|
|
3007
|
+
fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
|
|
3008
|
+
fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
|
|
3009
|
+
fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
|
|
3010
|
+
fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
|
|
3011
|
+
fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
|
|
3012
|
+
}
|
|
2875
3013
|
fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
|
|
2876
3014
|
}
|
|
2877
3015
|
|
|
2878
3016
|
void whisper_reset_timings(struct whisper_context * ctx) {
|
|
2879
|
-
ctx->
|
|
2880
|
-
|
|
2881
|
-
|
|
3017
|
+
if (ctx->state != nullptr) {
|
|
3018
|
+
ctx->state->t_sample_us = 0;
|
|
3019
|
+
ctx->state->t_encode_us = 0;
|
|
3020
|
+
ctx->state->t_decode_us = 0;
|
|
3021
|
+
}
|
|
2882
3022
|
}
|
|
2883
3023
|
|
|
2884
3024
|
const char * whisper_print_system_info(void) {
|
|
@@ -2913,7 +3053,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
2913
3053
|
/*.duration_ms =*/ 0,
|
|
2914
3054
|
|
|
2915
3055
|
/*.translate =*/ false,
|
|
2916
|
-
/*.no_context =*/
|
|
3056
|
+
/*.no_context =*/ true,
|
|
2917
3057
|
/*.single_segment =*/ false,
|
|
2918
3058
|
/*.print_special =*/ false,
|
|
2919
3059
|
/*.print_progress =*/ true,
|
|
@@ -2991,6 +3131,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
|
2991
3131
|
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
|
|
2992
3132
|
static void whisper_exp_compute_token_level_timestamps(
|
|
2993
3133
|
struct whisper_context & ctx,
|
|
3134
|
+
struct whisper_state & state,
|
|
2994
3135
|
int i_segment,
|
|
2995
3136
|
float thold_pt,
|
|
2996
3137
|
float thold_ptsum);
|
|
@@ -3023,8 +3164,8 @@ static inline bool should_split_on_word(const char * txt, bool split_on_word) {
|
|
|
3023
3164
|
|
|
3024
3165
|
// wrap the last segment to max_len characters
|
|
3025
3166
|
// returns the number of new segments
|
|
3026
|
-
static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) {
|
|
3027
|
-
auto segment =
|
|
3167
|
+
static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) {
|
|
3168
|
+
auto segment = state.result_all.back();
|
|
3028
3169
|
|
|
3029
3170
|
int res = 1;
|
|
3030
3171
|
int acc = 0;
|
|
@@ -3046,24 +3187,24 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
|
|
|
3046
3187
|
trim(text);
|
|
3047
3188
|
}
|
|
3048
3189
|
|
|
3049
|
-
|
|
3050
|
-
|
|
3051
|
-
|
|
3190
|
+
state.result_all.back().text = std::move(text);
|
|
3191
|
+
state.result_all.back().t1 = token.t0;
|
|
3192
|
+
state.result_all.back().tokens.resize(i);
|
|
3052
3193
|
|
|
3053
|
-
|
|
3054
|
-
|
|
3055
|
-
|
|
3194
|
+
state.result_all.push_back({});
|
|
3195
|
+
state.result_all.back().t0 = token.t0;
|
|
3196
|
+
state.result_all.back().t1 = segment.t1;
|
|
3056
3197
|
|
|
3057
3198
|
// add tokens [i, end] to the new segment
|
|
3058
|
-
|
|
3059
|
-
|
|
3199
|
+
state.result_all.back().tokens.insert(
|
|
3200
|
+
state.result_all.back().tokens.end(),
|
|
3060
3201
|
segment.tokens.begin() + i,
|
|
3061
3202
|
segment.tokens.end());
|
|
3062
3203
|
|
|
3063
3204
|
acc = 0;
|
|
3064
3205
|
text = "";
|
|
3065
3206
|
|
|
3066
|
-
segment =
|
|
3207
|
+
segment = state.result_all.back();
|
|
3067
3208
|
i = -1;
|
|
3068
3209
|
|
|
3069
3210
|
res++;
|
|
@@ -3076,7 +3217,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
|
|
|
3076
3217
|
if (split_on_word) {
|
|
3077
3218
|
trim(text);
|
|
3078
3219
|
}
|
|
3079
|
-
|
|
3220
|
+
state.result_all.back().text = std::move(text);
|
|
3080
3221
|
|
|
3081
3222
|
return res;
|
|
3082
3223
|
}
|
|
@@ -3093,6 +3234,7 @@ static const std::vector<std::string> non_speech_tokens = {
|
|
|
3093
3234
|
// - computes logprobs and probs
|
|
3094
3235
|
static void whisper_process_logits(
|
|
3095
3236
|
struct whisper_context & ctx,
|
|
3237
|
+
struct whisper_state & state,
|
|
3096
3238
|
const struct whisper_full_params params,
|
|
3097
3239
|
struct whisper_decoder & decoder,
|
|
3098
3240
|
float temperature) {
|
|
@@ -3111,7 +3253,7 @@ static void whisper_process_logits(
|
|
|
3111
3253
|
auto & logprobs = decoder.logprobs;
|
|
3112
3254
|
{
|
|
3113
3255
|
logits.resize(n_logits);
|
|
3114
|
-
memcpy(logits.data(),
|
|
3256
|
+
memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float));
|
|
3115
3257
|
|
|
3116
3258
|
if (temperature > 0.0f) {
|
|
3117
3259
|
for (int i = 0; i < n_logits; i++) {
|
|
@@ -3149,7 +3291,7 @@ static void whisper_process_logits(
|
|
|
3149
3291
|
logits[vocab.token_transcribe] = -INFINITY;
|
|
3150
3292
|
|
|
3151
3293
|
if (params.logits_filter_callback) {
|
|
3152
|
-
params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
|
|
3294
|
+
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
|
|
3153
3295
|
}
|
|
3154
3296
|
|
|
3155
3297
|
// suppress non-speech tokens
|
|
@@ -3310,6 +3452,7 @@ static void whisper_process_logits(
|
|
|
3310
3452
|
|
|
3311
3453
|
static whisper_token_data whisper_sample_token(
|
|
3312
3454
|
whisper_context & ctx,
|
|
3455
|
+
whisper_state & state,
|
|
3313
3456
|
const whisper_decoder & decoder,
|
|
3314
3457
|
bool best) {
|
|
3315
3458
|
whisper_token_data result = {
|
|
@@ -3354,7 +3497,7 @@ static whisper_token_data whisper_sample_token(
|
|
|
3354
3497
|
} else {
|
|
3355
3498
|
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
|
3356
3499
|
|
|
3357
|
-
result.id = dist(
|
|
3500
|
+
result.id = dist(state.rng);
|
|
3358
3501
|
result.p = probs[result.id];
|
|
3359
3502
|
result.plog = logprobs[result.id];
|
|
3360
3503
|
}
|
|
@@ -3364,13 +3507,14 @@ static whisper_token_data whisper_sample_token(
|
|
|
3364
3507
|
result.pt = result.p;
|
|
3365
3508
|
}
|
|
3366
3509
|
|
|
3367
|
-
|
|
3510
|
+
state.n_sample++;
|
|
3368
3511
|
|
|
3369
3512
|
return result;
|
|
3370
3513
|
}
|
|
3371
3514
|
|
|
3372
3515
|
static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
3373
3516
|
whisper_context & ctx,
|
|
3517
|
+
whisper_state & state,
|
|
3374
3518
|
const whisper_decoder & decoder,
|
|
3375
3519
|
int k) {
|
|
3376
3520
|
const auto & vocab = ctx.vocab;
|
|
@@ -3381,7 +3525,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
|
3381
3525
|
|
|
3382
3526
|
const int n_logits = vocab.n_vocab;
|
|
3383
3527
|
|
|
3384
|
-
auto & logits_id =
|
|
3528
|
+
auto & logits_id = state.logits_id;
|
|
3385
3529
|
|
|
3386
3530
|
logits_id.clear();
|
|
3387
3531
|
for (int i = 0; i < n_logits; ++i) {
|
|
@@ -3434,7 +3578,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
|
|
3434
3578
|
}
|
|
3435
3579
|
}
|
|
3436
3580
|
|
|
3437
|
-
|
|
3581
|
+
state.n_sample++;
|
|
3438
3582
|
|
|
3439
3583
|
return result;
|
|
3440
3584
|
}
|
|
@@ -3488,24 +3632,25 @@ static void whisper_sequence_score(
|
|
|
3488
3632
|
}
|
|
3489
3633
|
}
|
|
3490
3634
|
|
|
3491
|
-
int
|
|
3635
|
+
int whisper_full_with_state(
|
|
3492
3636
|
struct whisper_context * ctx,
|
|
3493
|
-
|
|
3494
|
-
|
|
3495
|
-
|
|
3637
|
+
struct whisper_state * state,
|
|
3638
|
+
struct whisper_full_params params,
|
|
3639
|
+
const float * samples,
|
|
3640
|
+
int n_samples) {
|
|
3496
3641
|
// clear old results
|
|
3497
|
-
auto & result_all =
|
|
3642
|
+
auto & result_all = state->result_all;
|
|
3498
3643
|
|
|
3499
3644
|
result_all.clear();
|
|
3500
3645
|
|
|
3501
3646
|
// compute log mel spectrogram
|
|
3502
3647
|
if (params.speed_up) {
|
|
3503
|
-
if (
|
|
3648
|
+
if (whisper_pcm_to_mel_phase_vocoder_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
|
|
3504
3649
|
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
|
|
3505
3650
|
return -1;
|
|
3506
3651
|
}
|
|
3507
3652
|
} else {
|
|
3508
|
-
if (
|
|
3653
|
+
if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
|
|
3509
3654
|
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
|
|
3510
3655
|
return -2;
|
|
3511
3656
|
}
|
|
@@ -3515,26 +3660,26 @@ int whisper_full(
|
|
|
3515
3660
|
if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
|
|
3516
3661
|
std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
|
|
3517
3662
|
|
|
3518
|
-
const auto lang_id =
|
|
3663
|
+
const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
|
|
3519
3664
|
if (lang_id < 0) {
|
|
3520
3665
|
fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
|
|
3521
3666
|
return -3;
|
|
3522
3667
|
}
|
|
3523
|
-
|
|
3668
|
+
state->lang_id = lang_id;
|
|
3524
3669
|
params.language = whisper_lang_str(lang_id);
|
|
3525
3670
|
|
|
3526
3671
|
fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
|
|
3527
3672
|
}
|
|
3528
3673
|
|
|
3529
3674
|
if (params.token_timestamps) {
|
|
3530
|
-
|
|
3531
|
-
|
|
3532
|
-
|
|
3533
|
-
|
|
3675
|
+
state->t_beg = 0;
|
|
3676
|
+
state->t_last = 0;
|
|
3677
|
+
state->tid_last = 0;
|
|
3678
|
+
state->energy = get_signal_energy(samples, n_samples, 32);
|
|
3534
3679
|
}
|
|
3535
3680
|
|
|
3536
3681
|
const int seek_start = params.offset_ms/10;
|
|
3537
|
-
const int seek_end = seek_start + (params.duration_ms == 0 ?
|
|
3682
|
+
const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len_from_state(state) : params.duration_ms/10);
|
|
3538
3683
|
|
|
3539
3684
|
// if length of spectrogram is less than 1s (100 samples), then return
|
|
3540
3685
|
// basically don't process anything that is less than 1s
|
|
@@ -3572,10 +3717,10 @@ int whisper_full(
|
|
|
3572
3717
|
|
|
3573
3718
|
// TAGS: WHISPER_DECODER_INIT
|
|
3574
3719
|
for (int j = 1; j < n_decoders; j++) {
|
|
3575
|
-
auto & decoder =
|
|
3720
|
+
auto & decoder = state->decoders[j];
|
|
3576
3721
|
|
|
3577
3722
|
if (decoder.kv_self.ctx == nullptr) {
|
|
3578
|
-
decoder.kv_self =
|
|
3723
|
+
decoder.kv_self = state->decoders[0].kv_self;
|
|
3579
3724
|
if (!kv_cache_reinit(decoder.kv_self)) {
|
|
3580
3725
|
fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
|
|
3581
3726
|
return -4;
|
|
@@ -3583,7 +3728,7 @@ int whisper_full(
|
|
|
3583
3728
|
|
|
3584
3729
|
WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
|
|
3585
3730
|
|
|
3586
|
-
decoder.sequence.tokens.reserve(
|
|
3731
|
+
decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
|
|
3587
3732
|
|
|
3588
3733
|
decoder.probs.resize (ctx->vocab.n_vocab);
|
|
3589
3734
|
decoder.logits.resize (ctx->vocab.n_vocab);
|
|
@@ -3592,7 +3737,7 @@ int whisper_full(
|
|
|
3592
3737
|
}
|
|
3593
3738
|
|
|
3594
3739
|
// the accumulated text context so far
|
|
3595
|
-
auto & prompt_past =
|
|
3740
|
+
auto & prompt_past = state->prompt_past;
|
|
3596
3741
|
if (params.no_context) {
|
|
3597
3742
|
prompt_past.clear();
|
|
3598
3743
|
}
|
|
@@ -3611,13 +3756,13 @@ int whisper_full(
|
|
|
3611
3756
|
fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
|
|
3612
3757
|
return -5;
|
|
3613
3758
|
}
|
|
3614
|
-
|
|
3759
|
+
state->exp_n_audio_ctx = params.audio_ctx;
|
|
3615
3760
|
|
|
3616
3761
|
// these tokens determine the task that will be performed
|
|
3617
3762
|
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
|
|
3618
3763
|
if (whisper_is_multilingual(ctx)) {
|
|
3619
3764
|
const int lang_id = whisper_lang_id(params.language);
|
|
3620
|
-
|
|
3765
|
+
state->lang_id = lang_id;
|
|
3621
3766
|
prompt_init.push_back(whisper_token_lang(ctx, lang_id));
|
|
3622
3767
|
if (params.translate) {
|
|
3623
3768
|
prompt_init.push_back(whisper_token_translate());
|
|
@@ -3669,14 +3814,14 @@ int whisper_full(
|
|
|
3669
3814
|
}
|
|
3670
3815
|
|
|
3671
3816
|
if (params.encoder_begin_callback) {
|
|
3672
|
-
if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
|
|
3817
|
+
if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
|
|
3673
3818
|
fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
|
|
3674
3819
|
break;
|
|
3675
3820
|
}
|
|
3676
3821
|
}
|
|
3677
3822
|
|
|
3678
3823
|
// encode audio features starting at offset seek
|
|
3679
|
-
if (!
|
|
3824
|
+
if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
|
|
3680
3825
|
fprintf(stderr, "%s: failed to encode\n", __func__);
|
|
3681
3826
|
return -6;
|
|
3682
3827
|
}
|
|
@@ -3717,7 +3862,7 @@ int whisper_full(
|
|
|
3717
3862
|
|
|
3718
3863
|
// TAGS: WHISPER_DECODER_INIT
|
|
3719
3864
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
3720
|
-
auto & decoder =
|
|
3865
|
+
auto & decoder = state->decoders[j];
|
|
3721
3866
|
|
|
3722
3867
|
decoder.kv_self.n = 0;
|
|
3723
3868
|
|
|
@@ -3759,7 +3904,7 @@ int whisper_full(
|
|
|
3759
3904
|
}
|
|
3760
3905
|
WHISPER_PRINT_DEBUG("\n\n");
|
|
3761
3906
|
|
|
3762
|
-
if (!
|
|
3907
|
+
if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
|
|
3763
3908
|
fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
3764
3909
|
return -7;
|
|
3765
3910
|
}
|
|
@@ -3767,24 +3912,24 @@ int whisper_full(
|
|
|
3767
3912
|
{
|
|
3768
3913
|
const int64_t t_start_sample_us = ggml_time_us();
|
|
3769
3914
|
|
|
3770
|
-
whisper_process_logits(*ctx, params,
|
|
3915
|
+
whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
|
|
3771
3916
|
|
|
3772
|
-
|
|
3917
|
+
state->decoders[0].kv_self.n += prompt.size();
|
|
3773
3918
|
|
|
3774
3919
|
for (int j = 1; j < n_decoders_cur; ++j) {
|
|
3775
|
-
auto & decoder =
|
|
3920
|
+
auto & decoder = state->decoders[j];
|
|
3776
3921
|
|
|
3777
|
-
memcpy(decoder.kv_self.k->data,
|
|
3778
|
-
memcpy(decoder.kv_self.v->data,
|
|
3922
|
+
memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
|
|
3923
|
+
memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
|
|
3779
3924
|
|
|
3780
3925
|
decoder.kv_self.n += prompt.size();
|
|
3781
3926
|
|
|
3782
|
-
memcpy(decoder.probs.data(),
|
|
3783
|
-
memcpy(decoder.logits.data(),
|
|
3784
|
-
memcpy(decoder.logprobs.data(),
|
|
3927
|
+
memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
|
|
3928
|
+
memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
|
|
3929
|
+
memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
|
|
3785
3930
|
}
|
|
3786
3931
|
|
|
3787
|
-
|
|
3932
|
+
state->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
3788
3933
|
}
|
|
3789
3934
|
}
|
|
3790
3935
|
|
|
@@ -3795,7 +3940,7 @@ int whisper_full(
|
|
|
3795
3940
|
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
|
|
3796
3941
|
kv_bufs.resize(n_decoders_cur);
|
|
3797
3942
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
3798
|
-
auto & decoder =
|
|
3943
|
+
auto & decoder = state->decoders[j];
|
|
3799
3944
|
|
|
3800
3945
|
if (decoder.completed || decoder.failed) {
|
|
3801
3946
|
continue;
|
|
@@ -3813,7 +3958,7 @@ int whisper_full(
|
|
|
3813
3958
|
|
|
3814
3959
|
// generate new sequence candidates for each decoder
|
|
3815
3960
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
3816
|
-
auto & decoder =
|
|
3961
|
+
auto & decoder = state->decoders[j];
|
|
3817
3962
|
|
|
3818
3963
|
if (decoder.completed || decoder.failed) {
|
|
3819
3964
|
continue;
|
|
@@ -3823,16 +3968,16 @@ int whisper_full(
|
|
|
3823
3968
|
case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
|
|
3824
3969
|
{
|
|
3825
3970
|
if (t_cur < 1e-6f) {
|
|
3826
|
-
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
|
|
3971
|
+
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true));
|
|
3827
3972
|
} else {
|
|
3828
|
-
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
|
|
3973
|
+
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false));
|
|
3829
3974
|
}
|
|
3830
3975
|
|
|
3831
3976
|
decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
|
|
3832
3977
|
} break;
|
|
3833
3978
|
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
|
|
3834
3979
|
{
|
|
3835
|
-
const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
|
|
3980
|
+
const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
|
|
3836
3981
|
|
|
3837
3982
|
for (const auto & token : tokens_new) {
|
|
3838
3983
|
beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
|
|
@@ -3857,7 +4002,7 @@ int whisper_full(
|
|
|
3857
4002
|
uint32_t cur_c = 0;
|
|
3858
4003
|
|
|
3859
4004
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
3860
|
-
auto & decoder =
|
|
4005
|
+
auto & decoder = state->decoders[j];
|
|
3861
4006
|
|
|
3862
4007
|
if (decoder.completed || decoder.failed) {
|
|
3863
4008
|
continue;
|
|
@@ -3886,7 +4031,7 @@ int whisper_full(
|
|
|
3886
4031
|
// - check if the sequence is failed
|
|
3887
4032
|
// - update sliding window based on timestamp tokens
|
|
3888
4033
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
3889
|
-
auto & decoder =
|
|
4034
|
+
auto & decoder = state->decoders[j];
|
|
3890
4035
|
|
|
3891
4036
|
if (decoder.completed || decoder.failed) {
|
|
3892
4037
|
continue;
|
|
@@ -3968,7 +4113,7 @@ int whisper_full(
|
|
|
3968
4113
|
bool completed_all = true;
|
|
3969
4114
|
|
|
3970
4115
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
3971
|
-
auto & decoder =
|
|
4116
|
+
auto & decoder = state->decoders[j];
|
|
3972
4117
|
|
|
3973
4118
|
if (decoder.completed || decoder.failed) {
|
|
3974
4119
|
continue;
|
|
@@ -3982,11 +4127,11 @@ int whisper_full(
|
|
|
3982
4127
|
}
|
|
3983
4128
|
}
|
|
3984
4129
|
|
|
3985
|
-
|
|
4130
|
+
state->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
3986
4131
|
|
|
3987
4132
|
// obtain logits for the next token
|
|
3988
4133
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
3989
|
-
auto & decoder =
|
|
4134
|
+
auto & decoder = state->decoders[j];
|
|
3990
4135
|
|
|
3991
4136
|
if (decoder.failed || decoder.completed) {
|
|
3992
4137
|
continue;
|
|
@@ -3997,7 +4142,7 @@ int whisper_full(
|
|
|
3997
4142
|
|
|
3998
4143
|
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
|
|
3999
4144
|
|
|
4000
|
-
if (!
|
|
4145
|
+
if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
|
|
4001
4146
|
fprintf(stderr, "%s: failed to decode\n", __func__);
|
|
4002
4147
|
return -8;
|
|
4003
4148
|
}
|
|
@@ -4005,11 +4150,11 @@ int whisper_full(
|
|
|
4005
4150
|
{
|
|
4006
4151
|
const int64_t t_start_sample_us = ggml_time_us();
|
|
4007
4152
|
|
|
4008
|
-
whisper_process_logits(*ctx, params, decoder, t_cur);
|
|
4153
|
+
whisper_process_logits(*ctx, *state, params, decoder, t_cur);
|
|
4009
4154
|
|
|
4010
4155
|
++decoder.kv_self.n;
|
|
4011
4156
|
|
|
4012
|
-
|
|
4157
|
+
state->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
4013
4158
|
}
|
|
4014
4159
|
}
|
|
4015
4160
|
}
|
|
@@ -4019,7 +4164,7 @@ int whisper_full(
|
|
|
4019
4164
|
double best_score = -INFINITY;
|
|
4020
4165
|
|
|
4021
4166
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
|
4022
|
-
auto & decoder =
|
|
4167
|
+
auto & decoder = state->decoders[j];
|
|
4023
4168
|
|
|
4024
4169
|
if (decoder.failed) {
|
|
4025
4170
|
continue;
|
|
@@ -4036,7 +4181,7 @@ int whisper_full(
|
|
|
4036
4181
|
__func__, j, decoder.sequence.entropy, params.entropy_thold);
|
|
4037
4182
|
|
|
4038
4183
|
decoder.failed = true;
|
|
4039
|
-
|
|
4184
|
+
state->n_fail_h++;
|
|
4040
4185
|
|
|
4041
4186
|
continue;
|
|
4042
4187
|
}
|
|
@@ -4054,11 +4199,11 @@ int whisper_full(
|
|
|
4054
4199
|
{
|
|
4055
4200
|
bool success = true;
|
|
4056
4201
|
|
|
4057
|
-
const auto & decoder =
|
|
4202
|
+
const auto & decoder = state->decoders[best_decoder_id];
|
|
4058
4203
|
|
|
4059
4204
|
if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
|
|
4060
4205
|
success = false;
|
|
4061
|
-
|
|
4206
|
+
state->n_fail_p++;
|
|
4062
4207
|
}
|
|
4063
4208
|
|
|
4064
4209
|
if (success) {
|
|
@@ -4075,7 +4220,7 @@ int whisper_full(
|
|
|
4075
4220
|
|
|
4076
4221
|
// output results through a user-provided callback
|
|
4077
4222
|
{
|
|
4078
|
-
const auto & best_decoder =
|
|
4223
|
+
const auto & best_decoder = state->decoders[best_decoder_id];
|
|
4079
4224
|
|
|
4080
4225
|
const auto seek_delta = best_decoder.seek_delta;
|
|
4081
4226
|
const auto result_len = best_decoder.sequence.result_len;
|
|
@@ -4138,14 +4283,14 @@ int whisper_full(
|
|
|
4138
4283
|
|
|
4139
4284
|
if (params.token_timestamps) {
|
|
4140
4285
|
whisper_exp_compute_token_level_timestamps(
|
|
4141
|
-
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
|
|
4286
|
+
*ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
|
|
4142
4287
|
|
|
4143
4288
|
if (params.max_len > 0) {
|
|
4144
|
-
n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
|
|
4289
|
+
n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
|
|
4145
4290
|
}
|
|
4146
4291
|
}
|
|
4147
4292
|
if (params.new_segment_callback) {
|
|
4148
|
-
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
|
|
4293
|
+
params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
|
|
4149
4294
|
}
|
|
4150
4295
|
}
|
|
4151
4296
|
text = "";
|
|
@@ -4182,14 +4327,14 @@ int whisper_full(
|
|
|
4182
4327
|
|
|
4183
4328
|
if (params.token_timestamps) {
|
|
4184
4329
|
whisper_exp_compute_token_level_timestamps(
|
|
4185
|
-
*ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
|
|
4330
|
+
*ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
|
|
4186
4331
|
|
|
4187
4332
|
if (params.max_len > 0) {
|
|
4188
|
-
n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
|
|
4333
|
+
n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
|
|
4189
4334
|
}
|
|
4190
4335
|
}
|
|
4191
4336
|
if (params.new_segment_callback) {
|
|
4192
|
-
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
|
|
4337
|
+
params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
|
|
4193
4338
|
}
|
|
4194
4339
|
}
|
|
4195
4340
|
}
|
|
@@ -4204,6 +4349,15 @@ int whisper_full(
|
|
|
4204
4349
|
return 0;
|
|
4205
4350
|
}
|
|
4206
4351
|
|
|
4352
|
+
|
|
4353
|
+
int whisper_full(
|
|
4354
|
+
struct whisper_context * ctx,
|
|
4355
|
+
struct whisper_full_params params,
|
|
4356
|
+
const float * samples,
|
|
4357
|
+
int n_samples) {
|
|
4358
|
+
return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples);
|
|
4359
|
+
}
|
|
4360
|
+
|
|
4207
4361
|
int whisper_full_parallel(
|
|
4208
4362
|
struct whisper_context * ctx,
|
|
4209
4363
|
struct whisper_full_params params,
|
|
@@ -4213,40 +4367,10 @@ int whisper_full_parallel(
|
|
|
4213
4367
|
if (n_processors == 1) {
|
|
4214
4368
|
return whisper_full(ctx, params, samples, n_samples);
|
|
4215
4369
|
}
|
|
4216
|
-
|
|
4217
4370
|
int ret = 0;
|
|
4218
4371
|
|
|
4219
|
-
// prepare separate
|
|
4220
|
-
std::vector<
|
|
4221
|
-
|
|
4222
|
-
for (int i = 0; i < n_processors - 1; ++i) {
|
|
4223
|
-
auto & ctx_p = ctxs[i];
|
|
4224
|
-
|
|
4225
|
-
ctx_p = *ctx;
|
|
4226
|
-
|
|
4227
|
-
ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
|
|
4228
|
-
|
|
4229
|
-
ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab);
|
|
4230
|
-
|
|
4231
|
-
if (!kv_cache_reinit(ctx_p.kv_cross)) {
|
|
4232
|
-
fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i);
|
|
4233
|
-
return false;
|
|
4234
|
-
}
|
|
4235
|
-
|
|
4236
|
-
// TAGS: WHISPER_DECODER_INIT
|
|
4237
|
-
for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
|
|
4238
|
-
if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
|
|
4239
|
-
fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i);
|
|
4240
|
-
return false;
|
|
4241
|
-
}
|
|
4242
|
-
|
|
4243
|
-
ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
|
|
4244
|
-
|
|
4245
|
-
ctx_p.decoders[j].probs.reserve (ctx_p.vocab.n_vocab);
|
|
4246
|
-
ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab);
|
|
4247
|
-
ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab);
|
|
4248
|
-
}
|
|
4249
|
-
}
|
|
4372
|
+
// prepare separate states for each thread
|
|
4373
|
+
std::vector<whisper_state*> states;
|
|
4250
4374
|
|
|
4251
4375
|
const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
|
|
4252
4376
|
const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
|
|
@@ -4256,6 +4380,9 @@ int whisper_full_parallel(
|
|
|
4256
4380
|
|
|
4257
4381
|
std::vector<std::thread> workers(n_processors - 1);
|
|
4258
4382
|
for (int i = 0; i < n_processors - 1; ++i) {
|
|
4383
|
+
// create a new state for each thread
|
|
4384
|
+
states.push_back(whisper_init_state(ctx));
|
|
4385
|
+
|
|
4259
4386
|
const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
|
|
4260
4387
|
const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
|
|
4261
4388
|
|
|
@@ -4268,13 +4395,17 @@ int whisper_full_parallel(
|
|
|
4268
4395
|
params_cur.new_segment_callback = nullptr;
|
|
4269
4396
|
params_cur.new_segment_callback_user_data = nullptr;
|
|
4270
4397
|
|
|
4271
|
-
workers[i] = std::thread(
|
|
4398
|
+
workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur);
|
|
4272
4399
|
}
|
|
4273
4400
|
|
|
4274
4401
|
{
|
|
4275
4402
|
auto params_cur = params;
|
|
4276
4403
|
|
|
4277
|
-
|
|
4404
|
+
// We need to disable the print real-time for this one as well, otherwise it will show only for the first chunk.
|
|
4405
|
+
params_cur.print_realtime = false;
|
|
4406
|
+
|
|
4407
|
+
// Run the first transformation using default state but only for the first chunk.
|
|
4408
|
+
ret = whisper_full_with_state(ctx, ctx->state, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
|
|
4278
4409
|
}
|
|
4279
4410
|
|
|
4280
4411
|
for (int i = 0; i < n_processors - 1; ++i) {
|
|
@@ -4283,45 +4414,43 @@ int whisper_full_parallel(
|
|
|
4283
4414
|
|
|
4284
4415
|
const int64_t offset_t = (int64_t) params.offset_ms/10.0;
|
|
4285
4416
|
|
|
4286
|
-
// combine results into
|
|
4417
|
+
// combine results into result_state->result_all from all other states
|
|
4287
4418
|
for (int i = 0; i < n_processors - 1; ++i) {
|
|
4288
|
-
auto
|
|
4419
|
+
auto& results_i = states[i]->result_all;
|
|
4289
4420
|
|
|
4290
|
-
for (auto
|
|
4421
|
+
for (auto& result : results_i) {
|
|
4291
4422
|
// correct the segment timestamp taking into account the offset
|
|
4292
|
-
result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
|
|
4293
|
-
result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
|
|
4423
|
+
result.t0 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
|
|
4424
|
+
result.t1 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
|
|
4425
|
+
|
|
4294
4426
|
|
|
4295
4427
|
// make sure that segments are not overlapping
|
|
4296
|
-
if (!ctx->result_all.empty()) {
|
|
4297
|
-
result.t0 = std::max(result.t0, ctx->result_all.back().t1);
|
|
4428
|
+
if (!ctx->state->result_all.empty()) {
|
|
4429
|
+
result.t0 = std::max(result.t0, ctx->state->result_all.back().t1);
|
|
4298
4430
|
}
|
|
4299
4431
|
|
|
4300
|
-
ctx->result_all.push_back(std::move(result));
|
|
4432
|
+
ctx->state->result_all.push_back(std::move(result));
|
|
4301
4433
|
|
|
4302
4434
|
// call the new_segment_callback for each segment
|
|
4303
4435
|
if (params.new_segment_callback) {
|
|
4304
|
-
params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
|
|
4436
|
+
params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data);
|
|
4305
4437
|
}
|
|
4306
4438
|
}
|
|
4307
4439
|
|
|
4308
|
-
ctx->t_mel_us
|
|
4309
|
-
ctx->t_sample_us += ctxs[i].t_sample_us;
|
|
4310
|
-
ctx->t_encode_us += ctxs[i].t_encode_us;
|
|
4311
|
-
ctx->t_decode_us += ctxs[i].t_decode_us;
|
|
4440
|
+
ctx->state->t_mel_us += states[i]->t_mel_us;
|
|
4312
4441
|
|
|
4313
|
-
|
|
4442
|
+
ctx->state->t_sample_us += states[i]->t_sample_us;
|
|
4443
|
+
ctx->state->t_encode_us += states[i]->t_encode_us;
|
|
4444
|
+
ctx->state->t_decode_us += states[i]->t_decode_us;
|
|
4314
4445
|
|
|
4315
|
-
|
|
4316
|
-
kv_cache_free(ctx->decoders[j].kv_self);
|
|
4317
|
-
}
|
|
4446
|
+
whisper_free_state(states[i]);
|
|
4318
4447
|
}
|
|
4319
4448
|
|
|
4320
4449
|
// average the timings
|
|
4321
|
-
ctx->t_mel_us /= n_processors;
|
|
4322
|
-
ctx->t_sample_us /= n_processors;
|
|
4323
|
-
ctx->t_encode_us /= n_processors;
|
|
4324
|
-
ctx->t_decode_us /= n_processors;
|
|
4450
|
+
ctx->state->t_mel_us /= n_processors;
|
|
4451
|
+
ctx->state->t_sample_us /= n_processors;
|
|
4452
|
+
ctx->state->t_encode_us /= n_processors;
|
|
4453
|
+
ctx->state->t_decode_us /= n_processors;
|
|
4325
4454
|
|
|
4326
4455
|
// print information about the audio boundaries
|
|
4327
4456
|
fprintf(stderr, "\n");
|
|
@@ -4334,44 +4463,84 @@ int whisper_full_parallel(
|
|
|
4334
4463
|
return ret;
|
|
4335
4464
|
}
|
|
4336
4465
|
|
|
4466
|
+
int whisper_full_n_segments_from_state(struct whisper_state * state) {
|
|
4467
|
+
return state->result_all.size();
|
|
4468
|
+
}
|
|
4469
|
+
|
|
4337
4470
|
int whisper_full_n_segments(struct whisper_context * ctx) {
|
|
4338
|
-
return ctx->result_all.size();
|
|
4471
|
+
return ctx->state->result_all.size();
|
|
4472
|
+
}
|
|
4473
|
+
|
|
4474
|
+
int whisper_full_lang_id_from_state(struct whisper_state * state) {
|
|
4475
|
+
return state->lang_id;
|
|
4339
4476
|
}
|
|
4340
4477
|
|
|
4341
4478
|
int whisper_full_lang_id(struct whisper_context * ctx) {
|
|
4342
|
-
return ctx->lang_id;
|
|
4479
|
+
return ctx->state->lang_id;
|
|
4480
|
+
}
|
|
4481
|
+
|
|
4482
|
+
int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
|
|
4483
|
+
return state->result_all[i_segment].t0;
|
|
4343
4484
|
}
|
|
4344
4485
|
|
|
4345
4486
|
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
|
|
4346
|
-
return ctx->result_all[i_segment].t0;
|
|
4487
|
+
return ctx->state->result_all[i_segment].t0;
|
|
4488
|
+
}
|
|
4489
|
+
|
|
4490
|
+
int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
|
|
4491
|
+
return state->result_all[i_segment].t1;
|
|
4347
4492
|
}
|
|
4348
4493
|
|
|
4349
4494
|
int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
|
|
4350
|
-
return ctx->result_all[i_segment].t1;
|
|
4495
|
+
return ctx->state->result_all[i_segment].t1;
|
|
4496
|
+
}
|
|
4497
|
+
|
|
4498
|
+
const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) {
|
|
4499
|
+
return state->result_all[i_segment].text.c_str();
|
|
4351
4500
|
}
|
|
4352
4501
|
|
|
4353
4502
|
const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
|
|
4354
|
-
return ctx->result_all[i_segment].text.c_str();
|
|
4503
|
+
return ctx->state->result_all[i_segment].text.c_str();
|
|
4504
|
+
}
|
|
4505
|
+
|
|
4506
|
+
int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment) {
|
|
4507
|
+
return state->result_all[i_segment].tokens.size();
|
|
4355
4508
|
}
|
|
4356
4509
|
|
|
4357
4510
|
int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
|
|
4358
|
-
return ctx->result_all[i_segment].tokens.size();
|
|
4511
|
+
return ctx->state->result_all[i_segment].tokens.size();
|
|
4512
|
+
}
|
|
4513
|
+
|
|
4514
|
+
const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token) {
|
|
4515
|
+
return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str();
|
|
4516
|
+
}
|
|
4517
|
+
|
|
4518
|
+
const char* whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
|
|
4519
|
+
return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str();
|
|
4359
4520
|
}
|
|
4360
4521
|
|
|
4361
|
-
|
|
4362
|
-
return
|
|
4522
|
+
whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token) {
|
|
4523
|
+
return state->result_all[i_segment].tokens[i_token].id;
|
|
4363
4524
|
}
|
|
4364
4525
|
|
|
4365
4526
|
whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
|
|
4366
|
-
return ctx->result_all[i_segment].tokens[i_token].id;
|
|
4527
|
+
return ctx->state->result_all[i_segment].tokens[i_token].id;
|
|
4528
|
+
}
|
|
4529
|
+
|
|
4530
|
+
struct whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token) {
|
|
4531
|
+
return state->result_all[i_segment].tokens[i_token];
|
|
4367
4532
|
}
|
|
4368
4533
|
|
|
4369
4534
|
struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) {
|
|
4370
|
-
return ctx->result_all[i_segment].tokens[i_token];
|
|
4535
|
+
return ctx->state->result_all[i_segment].tokens[i_token];
|
|
4536
|
+
}
|
|
4537
|
+
|
|
4538
|
+
float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token) {
|
|
4539
|
+
return state->result_all[i_segment].tokens[i_token].p;
|
|
4371
4540
|
}
|
|
4372
4541
|
|
|
4373
4542
|
float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
|
|
4374
|
-
return ctx->result_all[i_segment].tokens[i_token].p;
|
|
4543
|
+
return ctx->state->result_all[i_segment].tokens[i_token].p;
|
|
4375
4544
|
}
|
|
4376
4545
|
|
|
4377
4546
|
// =================================================================================================
|
|
@@ -4382,6 +4551,15 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
|
|
|
4382
4551
|
//
|
|
4383
4552
|
|
|
4384
4553
|
WHISPER_API int whisper_bench_memcpy(int n_threads) {
|
|
4554
|
+
fputs(whisper_bench_memcpy_str(n_threads), stderr);
|
|
4555
|
+
return 0;
|
|
4556
|
+
}
|
|
4557
|
+
|
|
4558
|
+
WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
|
|
4559
|
+
static std::string s;
|
|
4560
|
+
s = "";
|
|
4561
|
+
char strbuf[256];
|
|
4562
|
+
|
|
4385
4563
|
ggml_time_init();
|
|
4386
4564
|
|
|
4387
4565
|
size_t n = 50;
|
|
@@ -4411,7 +4589,8 @@ WHISPER_API int whisper_bench_memcpy(int n_threads) {
|
|
|
4411
4589
|
src[0] = rand();
|
|
4412
4590
|
}
|
|
4413
4591
|
|
|
4414
|
-
|
|
4592
|
+
snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
|
|
4593
|
+
s += strbuf;
|
|
4415
4594
|
|
|
4416
4595
|
// needed to prevent the compile from optimizing the memcpy away
|
|
4417
4596
|
{
|
|
@@ -4419,16 +4598,26 @@ WHISPER_API int whisper_bench_memcpy(int n_threads) {
|
|
|
4419
4598
|
|
|
4420
4599
|
for (size_t i = 0; i < size; i++) sum += dst[i];
|
|
4421
4600
|
|
|
4422
|
-
|
|
4601
|
+
snprintf(strbuf, sizeof(strbuf), "sum: %s %f\n", sum == -536870910.00 ? "ok" : "error", sum);
|
|
4602
|
+
s += strbuf;
|
|
4423
4603
|
}
|
|
4424
4604
|
|
|
4425
4605
|
free(src);
|
|
4426
4606
|
free(dst);
|
|
4427
4607
|
|
|
4428
|
-
return
|
|
4608
|
+
return s.c_str();
|
|
4429
4609
|
}
|
|
4430
4610
|
|
|
4431
4611
|
WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
|
|
4612
|
+
fputs(whisper_bench_ggml_mul_mat_str(n_threads), stderr);
|
|
4613
|
+
return 0;
|
|
4614
|
+
}
|
|
4615
|
+
|
|
4616
|
+
WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
|
|
4617
|
+
static std::string s;
|
|
4618
|
+
s = "";
|
|
4619
|
+
char strbuf[256];
|
|
4620
|
+
|
|
4432
4621
|
ggml_time_init();
|
|
4433
4622
|
|
|
4434
4623
|
const int n_max = 128;
|
|
@@ -4504,11 +4693,12 @@ WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
|
|
|
4504
4693
|
s = ((2.0*N*N*N*n)/tsum)*1e-9;
|
|
4505
4694
|
}
|
|
4506
4695
|
|
|
4507
|
-
|
|
4696
|
+
snprintf(strbuf, sizeof(strbuf), "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n",
|
|
4508
4697
|
N, N, s_fp16, n_fp16, s_fp32, n_fp32);
|
|
4698
|
+
s += strbuf;
|
|
4509
4699
|
}
|
|
4510
4700
|
|
|
4511
|
-
return
|
|
4701
|
+
return s.c_str();
|
|
4512
4702
|
}
|
|
4513
4703
|
|
|
4514
4704
|
// =================================================================================================
|
|
@@ -4583,13 +4773,14 @@ static std::vector<float> get_signal_energy(const float * signal, int n_samples,
|
|
|
4583
4773
|
|
|
4584
4774
|
static void whisper_exp_compute_token_level_timestamps(
|
|
4585
4775
|
struct whisper_context & ctx,
|
|
4776
|
+
struct whisper_state & state,
|
|
4586
4777
|
int i_segment,
|
|
4587
4778
|
float thold_pt,
|
|
4588
4779
|
float thold_ptsum) {
|
|
4589
|
-
auto & segment =
|
|
4780
|
+
auto & segment = state.result_all[i_segment];
|
|
4590
4781
|
auto & tokens = segment.tokens;
|
|
4591
4782
|
|
|
4592
|
-
const int n_samples =
|
|
4783
|
+
const int n_samples = state.energy.size();
|
|
4593
4784
|
|
|
4594
4785
|
if (n_samples == 0) {
|
|
4595
4786
|
fprintf(stderr, "%s: no signal data available\n", __func__);
|
|
@@ -4612,9 +4803,9 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
4612
4803
|
return;
|
|
4613
4804
|
}
|
|
4614
4805
|
|
|
4615
|
-
auto & t_beg =
|
|
4616
|
-
auto & t_last =
|
|
4617
|
-
auto & tid_last =
|
|
4806
|
+
auto & t_beg = state.t_beg;
|
|
4807
|
+
auto & t_last = state.t_last;
|
|
4808
|
+
auto & tid_last = state.tid_last;
|
|
4618
4809
|
|
|
4619
4810
|
for (int j = 0; j < n; ++j) {
|
|
4620
4811
|
auto & token = tokens[j];
|
|
@@ -4737,15 +4928,15 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
4737
4928
|
float sum = 0.0f;
|
|
4738
4929
|
|
|
4739
4930
|
for (int k = ss0; k < ss1; k++) {
|
|
4740
|
-
sum +=
|
|
4931
|
+
sum += state.energy[k];
|
|
4741
4932
|
}
|
|
4742
4933
|
|
|
4743
4934
|
const float thold = 0.5*sum/ns;
|
|
4744
4935
|
|
|
4745
4936
|
{
|
|
4746
4937
|
int k = s0;
|
|
4747
|
-
if (
|
|
4748
|
-
while (k > 0 &&
|
|
4938
|
+
if (state.energy[k] > thold && j > 0) {
|
|
4939
|
+
while (k > 0 && state.energy[k] > thold) {
|
|
4749
4940
|
k--;
|
|
4750
4941
|
}
|
|
4751
4942
|
tokens[j].t0 = sample_to_timestamp(k);
|
|
@@ -4755,7 +4946,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
4755
4946
|
s0 = k;
|
|
4756
4947
|
}
|
|
4757
4948
|
} else {
|
|
4758
|
-
while (
|
|
4949
|
+
while (state.energy[k] < thold && k < s1) {
|
|
4759
4950
|
k++;
|
|
4760
4951
|
}
|
|
4761
4952
|
s0 = k;
|
|
@@ -4765,8 +4956,8 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
4765
4956
|
|
|
4766
4957
|
{
|
|
4767
4958
|
int k = s1;
|
|
4768
|
-
if (
|
|
4769
|
-
while (k < n_samples - 1 &&
|
|
4959
|
+
if (state.energy[k] > thold) {
|
|
4960
|
+
while (k < n_samples - 1 && state.energy[k] > thold) {
|
|
4770
4961
|
k++;
|
|
4771
4962
|
}
|
|
4772
4963
|
tokens[j].t1 = sample_to_timestamp(k);
|
|
@@ -4776,7 +4967,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
|
4776
4967
|
s1 = k;
|
|
4777
4968
|
}
|
|
4778
4969
|
} else {
|
|
4779
|
-
while (
|
|
4970
|
+
while (state.energy[k] < thold && k > s0) {
|
|
4780
4971
|
k--;
|
|
4781
4972
|
}
|
|
4782
4973
|
s1 = k;
|