whisper.rn 0.1.4 → 0.2.0

This diff shows the differences between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
package/cpp/whisper.cpp CHANGED
@@ -547,13 +547,11 @@ struct whisper_decoder {
547
547
  std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
548
548
  };
549
549
 
550
- struct whisper_context {
551
- int64_t t_load_us = 0;
552
- int64_t t_mel_us = 0;
550
+ struct whisper_state {
553
551
  int64_t t_sample_us = 0;
554
552
  int64_t t_encode_us = 0;
555
553
  int64_t t_decode_us = 0;
556
- int64_t t_start_us = 0;
554
+ int64_t t_mel_us = 0;
557
555
 
558
556
  int32_t n_sample = 0; // number of tokens sampled
559
557
  int32_t n_encode = 0; // number of encoder calls
@@ -561,16 +559,10 @@ struct whisper_context {
561
559
  int32_t n_fail_p = 0; // number of logprob threshold failures
562
560
  int32_t n_fail_h = 0; // number of entropy threshold failures
563
561
 
564
- ggml_type wtype; // weight type (FP32 or FP16)
565
-
566
- whisper_mel mel;
567
-
568
- whisper_model model;
569
- whisper_vocab vocab;
570
-
571
562
  // cross-attention KV cache for the decoders
572
563
  // shared between all decoders
573
564
  whisper_kv_cache kv_cross;
565
+ whisper_mel mel;
574
566
 
575
567
  whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
576
568
 
@@ -635,6 +627,18 @@ struct whisper_context {
635
627
  }
636
628
  };
637
629
 
630
+ struct whisper_context {
631
+ int64_t t_load_us = 0;
632
+ int64_t t_start_us = 0;
633
+
634
+
635
+ ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 or FP16)
636
+
637
+ whisper_model model;
638
+ whisper_vocab vocab;
639
+ whisper_state * state = nullptr;
640
+ };
641
+
638
642
  template<typename T>
639
643
  static void read_safe(whisper_model_loader * loader, T & dest) {
640
644
  loader->read(loader->context, &dest, sizeof(T));
@@ -821,32 +825,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
821
825
  wctx.model.buf = new std::vector<uint8_t>();
822
826
  wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
823
827
 
824
- if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
825
- fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
826
- return false;
827
- }
828
-
829
- {
830
- const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v);
831
- fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
832
- }
833
-
834
- if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) {
835
- fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
836
- return false;
837
- }
838
-
839
- {
840
- const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v);
841
- fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
842
- }
843
-
844
- wctx.buf_compute.resize(scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
845
-
846
- wctx.buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(model.type));
847
- wctx.buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(model.type));
848
- wctx.buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(model.type));
849
- wctx.buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(model.type));
828
+ // we skip initialization of the state until it is needed
829
+ // because it might be that state will always be provided externally.
850
830
  }
851
831
 
852
832
  // load mel filters
@@ -929,17 +909,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
929
909
  vocab.id_to_token[i] = word;
930
910
  }
931
911
  }
932
-
933
- wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
934
-
935
- wctx.logits_id.reserve(n_vocab);
936
-
937
- // TAGS: WHISPER_DECODER_INIT
938
- wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
939
-
940
- wctx.decoders[0].probs.reserve (vocab.n_vocab);
941
- wctx.decoders[0].logits.reserve (vocab.n_vocab);
942
- wctx.decoders[0].logprobs.reserve(vocab.n_vocab);
943
912
  }
944
913
 
945
914
  size_t ctx_size = 0;
@@ -1339,33 +1308,34 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1339
1308
  }
1340
1309
  }
1341
1310
 
1342
- wctx.rng = std::mt19937(0);
1343
-
1344
1311
  wctx.t_load_us = ggml_time_us() - t_start_us;
1345
1312
 
1346
1313
  return true;
1347
1314
  }
1348
1315
 
1349
- // evaluate the encoder
1316
+ // evaluate the encoder with the given state
1350
1317
  //
1351
1318
  // given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
1352
1319
  // part of the transformer model and returns the encoded features
1353
1320
  //
1354
- // - model: the model
1321
+ // - wctx: the model
1322
+ // - wstate: the state of the encoder
1355
1323
  // - n_threads: number of threads to use
1356
1324
  // - mel_offset: offset in the mel spectrogram (i.e. audio offset)
1357
1325
  //
1358
- static bool whisper_encode(
1326
+ static bool whisper_encode_internal(
1359
1327
  whisper_context & wctx,
1328
+ whisper_state & wstate,
1360
1329
  const int mel_offset,
1361
- const int n_threads) {
1330
+ const int n_threads){
1331
+
1362
1332
  const int64_t t_start_us = ggml_time_us();
1363
1333
 
1364
1334
  const auto & model = wctx.model;
1365
- const auto & mel_inp = wctx.mel;
1335
+ const auto & mel_inp = wstate.mel;
1366
1336
  const auto & hparams = model.hparams;
1367
1337
 
1368
- const int n_ctx = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
1338
+ const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
1369
1339
  const int n_state = hparams.n_audio_state;
1370
1340
  const int n_head = hparams.n_audio_head;
1371
1341
  const int n_layer = hparams.n_audio_layer;
@@ -1374,12 +1344,12 @@ static bool whisper_encode(
1374
1344
  assert(mel_inp.n_mel == n_mels);
1375
1345
 
1376
1346
  struct ggml_init_params params;
1377
- params.mem_size = wctx.buf_compute.size();
1378
- params.mem_buffer = wctx.buf_compute.data();
1347
+ params.mem_size = wstate.buf_compute.size();
1348
+ params.mem_buffer = wstate.buf_compute.data();
1379
1349
 
1380
1350
  struct ggml_context * ctx0 = ggml_init(params);
1381
1351
 
1382
- wctx.use_buf(ctx0, 0);
1352
+ wstate.use_buf(ctx0, 0);
1383
1353
 
1384
1354
  struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
1385
1355
  assert(mel->type == GGML_TYPE_F32);
@@ -1401,30 +1371,30 @@ static bool whisper_encode(
1401
1371
 
1402
1372
  // convolution + gelu
1403
1373
  {
1404
- wctx.use_buf(ctx0, 1);
1374
+ wstate.use_buf(ctx0, 1);
1405
1375
 
1406
1376
  cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
1407
1377
  cur = ggml_add(ctx0,
1408
- ggml_repeat(ctx0,
1409
- model.e_conv_1_b,
1410
- cur),
1411
- cur);
1378
+ ggml_repeat(ctx0,
1379
+ model.e_conv_1_b,
1380
+ cur),
1381
+ cur);
1412
1382
 
1413
1383
  cur = ggml_gelu(ctx0, cur);
1414
1384
 
1415
- wctx.use_buf(ctx0, 0);
1385
+ wstate.use_buf(ctx0, 0);
1416
1386
 
1417
1387
  cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
1418
1388
  cur = ggml_add(ctx0,
1419
- ggml_repeat(ctx0,
1420
- model.e_conv_2_b,
1421
- cur),
1422
- cur);
1389
+ ggml_repeat(ctx0,
1390
+ model.e_conv_2_b,
1391
+ cur),
1392
+ cur);
1423
1393
 
1424
1394
  cur = ggml_gelu(ctx0, cur);
1425
1395
  }
1426
1396
 
1427
- wctx.use_buf(ctx0, 3);
1397
+ wstate.use_buf(ctx0, 3);
1428
1398
 
1429
1399
  // ===================================================================
1430
1400
  // NOTE: experimenting with partial evaluation of the encoder (ignore)
@@ -1439,7 +1409,7 @@ static bool whisper_encode(
1439
1409
  //}
1440
1410
 
1441
1411
  static int iter = 0;
1442
-
1412
+
1443
1413
  const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
1444
1414
  const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
1445
1415
 
@@ -1459,54 +1429,54 @@ static bool whisper_encode(
1459
1429
 
1460
1430
  // norm
1461
1431
  {
1462
- wctx.use_buf(ctx0, 0);
1432
+ wstate.use_buf(ctx0, 0);
1463
1433
 
1464
1434
  cur = ggml_norm(ctx0, inpL);
1465
1435
 
1466
1436
  // cur = ln_0_w*cur + ln_0_b
1467
1437
  cur = ggml_add(ctx0,
1468
- ggml_mul(ctx0,
1469
- ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
1470
- cur),
1471
- ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
1438
+ ggml_mul(ctx0,
1439
+ ggml_repeat(ctx0, layer.attn_ln_0_w, cur),
1440
+ cur),
1441
+ ggml_repeat(ctx0, layer.attn_ln_0_b, cur));
1472
1442
  }
1473
1443
 
1474
1444
  // self-attention
1475
1445
  {
1476
- wctx.use_buf(ctx0, 1);
1446
+ wstate.use_buf(ctx0, 1);
1477
1447
 
1478
1448
  struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
1479
- layer.attn_q_w,
1480
- cur);
1449
+ layer.attn_q_w,
1450
+ cur);
1481
1451
 
1482
1452
  Qcur = ggml_add(ctx0,
1483
- ggml_repeat(ctx0,
1484
- layer.attn_q_b,
1485
- Qcur),
1486
- Qcur);
1453
+ ggml_repeat(ctx0,
1454
+ layer.attn_q_b,
1455
+ Qcur),
1456
+ Qcur);
1487
1457
 
1488
1458
  //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1489
1459
 
1490
1460
  // note: no bias for Key
1491
1461
  struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
1492
- layer.attn_k_w,
1493
- cur);
1462
+ layer.attn_k_w,
1463
+ cur);
1494
1464
 
1495
1465
  //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1496
1466
 
1497
1467
  struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
1498
- layer.attn_v_w,
1499
- cur);
1468
+ layer.attn_v_w,
1469
+ cur);
1500
1470
 
1501
1471
  Vcur = ggml_add(ctx0,
1502
- ggml_repeat(ctx0,
1503
- layer.attn_v_b,
1504
- Vcur),
1505
- Vcur);
1472
+ ggml_repeat(ctx0,
1473
+ layer.attn_v_b,
1474
+ Vcur),
1475
+ Vcur);
1506
1476
 
1507
1477
  // ------
1508
1478
 
1509
- wctx.use_buf(ctx0, 0);
1479
+ wstate.use_buf(ctx0, 0);
1510
1480
 
1511
1481
  #ifdef WHISPER_USE_FLASH_ATTN
1512
1482
  struct ggml_tensor * Q =
@@ -1583,29 +1553,29 @@ static bool whisper_encode(
1583
1553
  #endif
1584
1554
  struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
1585
1555
 
1586
- wctx.use_buf(ctx0, 1);
1556
+ wstate.use_buf(ctx0, 1);
1587
1557
 
1588
1558
  cur = ggml_cpy(ctx0,
1589
- KQV_merged,
1590
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
1559
+ KQV_merged,
1560
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
1591
1561
  }
1592
1562
 
1593
1563
  // projection
1594
1564
  {
1595
- wctx.use_buf(ctx0, 0);
1565
+ wstate.use_buf(ctx0, 0);
1596
1566
 
1597
1567
  cur = ggml_mul_mat(ctx0,
1598
- layer.attn_ln_1_w,
1599
- cur);
1568
+ layer.attn_ln_1_w,
1569
+ cur);
1600
1570
 
1601
- wctx.use_buf(ctx0, 1);
1571
+ wstate.use_buf(ctx0, 1);
1602
1572
 
1603
1573
  cur = ggml_add(ctx0,
1604
- ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
1605
- cur);
1574
+ ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
1575
+ cur);
1606
1576
  }
1607
1577
 
1608
- wctx.use_buf(ctx0, 2);
1578
+ wstate.use_buf(ctx0, 2);
1609
1579
 
1610
1580
  // add the input
1611
1581
  cur = ggml_add(ctx0, cur, inpL);
@@ -1616,61 +1586,61 @@ static bool whisper_encode(
1616
1586
  {
1617
1587
  // norm
1618
1588
  {
1619
- wctx.use_buf(ctx0, 0);
1589
+ wstate.use_buf(ctx0, 0);
1620
1590
 
1621
1591
  cur = ggml_norm(ctx0, inpFF);
1622
1592
 
1623
- wctx.use_buf(ctx0, 1);
1593
+ wstate.use_buf(ctx0, 1);
1624
1594
 
1625
1595
  // cur = mlp_ln_w*cur + mlp_ln_b
1626
1596
  cur = ggml_add(ctx0,
1627
- ggml_mul(ctx0,
1628
- ggml_repeat(ctx0, layer.mlp_ln_w, cur),
1629
- cur),
1630
- ggml_repeat(ctx0, layer.mlp_ln_b, cur));
1631
- }
1597
+ ggml_mul(ctx0,
1598
+ ggml_repeat(ctx0, layer.mlp_ln_w, cur),
1599
+ cur),
1600
+ ggml_repeat(ctx0, layer.mlp_ln_b, cur));
1601
+ }
1632
1602
 
1633
1603
  #ifdef WHISPER_USE_FLASH_FF
1634
- wctx.use_buf(ctx0, 0);
1604
+ wstate.use_buf(ctx0, 0);
1635
1605
 
1636
1606
  cur = ggml_flash_ff(ctx0,
1637
- ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)),
1638
- layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
1607
+ ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.wtype, n_state, n_ctx)),
1608
+ layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
1639
1609
  #else
1640
- wctx.use_buf(ctx0, 0);
1610
+ wstate.use_buf(ctx0, 0);
1641
1611
 
1642
1612
  // fully connected
1643
1613
  cur = ggml_mul_mat(ctx0,
1644
- layer.mlp_0_w,
1645
- cur);
1614
+ layer.mlp_0_w,
1615
+ cur);
1646
1616
 
1647
- wctx.use_buf(ctx0, 1);
1617
+ wstate.use_buf(ctx0, 1);
1648
1618
 
1649
1619
  cur = ggml_add(ctx0,
1650
- ggml_repeat(ctx0, layer.mlp_0_b, cur),
1651
- cur);
1620
+ ggml_repeat(ctx0, layer.mlp_0_b, cur),
1621
+ cur);
1652
1622
 
1653
- wctx.use_buf(ctx0, 0);
1623
+ wstate.use_buf(ctx0, 0);
1654
1624
 
1655
1625
  // GELU activation
1656
1626
  cur = ggml_gelu(ctx0, cur);
1657
1627
 
1658
- wctx.use_buf(ctx0, 1);
1628
+ wstate.use_buf(ctx0, 1);
1659
1629
 
1660
1630
  // projection
1661
1631
  cur = ggml_mul_mat(ctx0,
1662
- layer.mlp_1_w,
1663
- cur);
1632
+ layer.mlp_1_w,
1633
+ cur);
1664
1634
 
1665
- wctx.use_buf(ctx0, 0);
1635
+ wstate.use_buf(ctx0, 0);
1666
1636
 
1667
1637
  cur = ggml_add(ctx0,
1668
- ggml_repeat(ctx0, layer.mlp_1_b, cur),
1669
- cur);
1638
+ ggml_repeat(ctx0, layer.mlp_1_b, cur),
1639
+ cur);
1670
1640
  #endif
1671
- }
1641
+ }
1672
1642
 
1673
- wctx.use_buf(ctx0, 3);
1643
+ wstate.use_buf(ctx0, 3);
1674
1644
 
1675
1645
  inpL = ggml_add(ctx0, cur, inpFF);
1676
1646
  }
@@ -1679,21 +1649,21 @@ static bool whisper_encode(
1679
1649
 
1680
1650
  // norm
1681
1651
  {
1682
- wctx.use_buf(ctx0, 0);
1652
+ wstate.use_buf(ctx0, 0);
1683
1653
 
1684
1654
  cur = ggml_norm(ctx0, cur);
1685
1655
 
1686
- wctx.use_buf(ctx0, 1);
1656
+ wstate.use_buf(ctx0, 1);
1687
1657
 
1688
1658
  // cur = ln_f_g*cur + ln_f_b
1689
1659
  cur = ggml_add(ctx0,
1690
- ggml_mul(ctx0,
1691
- ggml_repeat(ctx0, model.e_ln_w, cur),
1692
- cur),
1693
- ggml_repeat(ctx0, model.e_ln_b, cur));
1660
+ ggml_mul(ctx0,
1661
+ ggml_repeat(ctx0, model.e_ln_w, cur),
1662
+ cur),
1663
+ ggml_repeat(ctx0, model.e_ln_b, cur));
1694
1664
  }
1695
1665
 
1696
- wctx.use_buf(ctx0, -1);
1666
+ wstate.use_buf(ctx0, -1);
1697
1667
 
1698
1668
  // run the computation
1699
1669
  {
@@ -1701,7 +1671,7 @@ static bool whisper_encode(
1701
1671
  gf.n_threads = n_threads;
1702
1672
 
1703
1673
  ggml_build_forward_expand(&gf, cur);
1704
- ggml_graph_compute (ctx0, &gf);
1674
+ ggml_graph_compute(ctx0, &gf);
1705
1675
 
1706
1676
  //ggml_graph_print(&gf);
1707
1677
  }
@@ -1731,34 +1701,34 @@ static bool whisper_encode(
1731
1701
  cur->src1 = nullptr;
1732
1702
 
1733
1703
  for (int il = 0; il < model.hparams.n_text_layer; ++il) {
1734
- auto & layer = model.layers_decoder[il];
1704
+ auto& layer = model.layers_decoder[il];
1735
1705
 
1736
- wctx.use_buf(ctx0, 0);
1706
+ wstate.use_buf(ctx0, 0);
1737
1707
 
1738
- struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
1739
- layer.cross_attn_k_w,
1740
- cur);
1708
+ struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
1709
+ layer.cross_attn_k_w,
1710
+ cur);
1741
1711
 
1742
- Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1712
+ Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25)));
1743
1713
 
1744
- wctx.use_buf(ctx0, 1);
1714
+ wstate.use_buf(ctx0, 1);
1745
1715
 
1746
- struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
1747
- layer.cross_attn_v_w,
1748
- cur);
1716
+ struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
1717
+ layer.cross_attn_v_w,
1718
+ cur);
1749
1719
 
1750
1720
  Vcross = ggml_add(ctx0,
1751
- ggml_repeat(ctx0,
1752
- layer.cross_attn_v_b,
1753
- Vcross),
1754
- Vcross);
1721
+ ggml_repeat(ctx0,
1722
+ layer.cross_attn_v_b,
1723
+ Vcross),
1724
+ Vcross);
1755
1725
 
1756
- wctx.use_buf(ctx0, -1);
1726
+ wstate.use_buf(ctx0, -1);
1757
1727
 
1758
- //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
1759
- //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
1760
- struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx));
1761
- struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*n_ctx));
1728
+ //struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
1729
+ //struct ggml_tensor * v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
1730
+ struct ggml_tensor* k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
1731
+ struct ggml_tensor* v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx));
1762
1732
 
1763
1733
  ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
1764
1734
  ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
@@ -1779,8 +1749,8 @@ static bool whisper_encode(
1779
1749
 
1780
1750
  ggml_free(ctx0);
1781
1751
 
1782
- wctx.t_encode_us += ggml_time_us() - t_start_us;
1783
- wctx.n_encode++;
1752
+ wstate.t_encode_us += ggml_time_us() - t_start_us;
1753
+ wstate.n_encode++;
1784
1754
 
1785
1755
  return true;
1786
1756
  }
@@ -1795,8 +1765,9 @@ static bool whisper_encode(
1795
1765
  // - n_tokens: number of tokens in the prompt
1796
1766
  // - n_past: number of past tokens to prefix the prompt with
1797
1767
  //
1798
- static bool whisper_decode(
1768
+ static bool whisper_decode_internal(
1799
1769
  whisper_context & wctx,
1770
+ whisper_state & wstate,
1800
1771
  whisper_decoder & decoder,
1801
1772
  const whisper_token * tokens,
1802
1773
  const int n_tokens,
@@ -1811,7 +1782,7 @@ static bool whisper_decode(
1811
1782
 
1812
1783
  WHISPER_ASSERT(!!kv_self.ctx);
1813
1784
 
1814
- auto & logits_out = wctx.logits;
1785
+ auto & logits_out = wstate.logits;
1815
1786
 
1816
1787
  const int n_vocab = hparams.n_vocab;
1817
1788
 
@@ -1821,13 +1792,13 @@ static bool whisper_decode(
1821
1792
  const int n_layer = hparams.n_text_layer;
1822
1793
 
1823
1794
  const int N = n_tokens;
1824
- const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
1795
+ const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
1825
1796
 
1826
1797
  //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
1827
1798
 
1828
1799
  struct ggml_init_params params;
1829
- params.mem_size = wctx.buf_compute.size();
1830
- params.mem_buffer = wctx.buf_compute.data();
1800
+ params.mem_size = wstate.buf_compute.size();
1801
+ params.mem_buffer = wstate.buf_compute.data();
1831
1802
 
1832
1803
  struct ggml_context * ctx0 = ggml_init(params);
1833
1804
 
@@ -1842,7 +1813,7 @@ static bool whisper_decode(
1842
1813
  ((int32_t *) position->data)[i] = n_past + i;
1843
1814
  }
1844
1815
 
1845
- wctx.use_buf(ctx0, 3);
1816
+ wstate.use_buf(ctx0, 3);
1846
1817
 
1847
1818
  // token encoding + position encoding
1848
1819
  struct ggml_tensor * cur =
@@ -1857,7 +1828,7 @@ static bool whisper_decode(
1857
1828
 
1858
1829
  // norm
1859
1830
  {
1860
- wctx.use_buf(ctx0, 0);
1831
+ wstate.use_buf(ctx0, 0);
1861
1832
 
1862
1833
  cur = ggml_norm(ctx0, inpL);
1863
1834
 
@@ -1871,7 +1842,7 @@ static bool whisper_decode(
1871
1842
 
1872
1843
  // self-attention
1873
1844
  {
1874
- wctx.use_buf(ctx0, 1);
1845
+ wstate.use_buf(ctx0, 1);
1875
1846
 
1876
1847
  struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
1877
1848
  layer.attn_q_w,
@@ -1913,7 +1884,7 @@ static bool whisper_decode(
1913
1884
 
1914
1885
  // ------
1915
1886
 
1916
- wctx.use_buf(ctx0, 0);
1887
+ wstate.use_buf(ctx0, 0);
1917
1888
 
1918
1889
  struct ggml_tensor * Q =
1919
1890
  ggml_permute(ctx0,
@@ -1929,12 +1900,12 @@ static bool whisper_decode(
1929
1900
  n_state/n_head, n_head, n_past + N),
1930
1901
  0, 2, 1, 3);
1931
1902
 
1932
- wctx.use_buf(ctx0, 1);
1903
+ wstate.use_buf(ctx0, 1);
1933
1904
 
1934
1905
  // K * Q
1935
1906
  struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
1936
1907
 
1937
- wctx.use_buf(ctx0, 0);
1908
+ wstate.use_buf(ctx0, 0);
1938
1909
 
1939
1910
  //struct ggml_tensor * KQ_scaled =
1940
1911
  // ggml_scale(ctx0,
@@ -1944,11 +1915,11 @@ static bool whisper_decode(
1944
1915
 
1945
1916
  struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
1946
1917
 
1947
- wctx.use_buf(ctx0, 1);
1918
+ wstate.use_buf(ctx0, 1);
1948
1919
 
1949
1920
  struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
1950
1921
 
1951
- wctx.use_buf(ctx0, 0);
1922
+ wstate.use_buf(ctx0, 0);
1952
1923
 
1953
1924
  struct ggml_tensor * V_trans =
1954
1925
  ggml_permute(ctx0,
@@ -1957,7 +1928,7 @@ static bool whisper_decode(
1957
1928
  n_state/n_head, n_head, n_past + N),
1958
1929
  1, 2, 0, 3);
1959
1930
 
1960
- wctx.use_buf(ctx0, 1);
1931
+ wstate.use_buf(ctx0, 1);
1961
1932
 
1962
1933
  struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
1963
1934
 
@@ -1970,31 +1941,31 @@ static bool whisper_decode(
1970
1941
 
1971
1942
  // projection
1972
1943
  {
1973
- wctx.use_buf(ctx0, 0);
1944
+ wstate.use_buf(ctx0, 0);
1974
1945
 
1975
1946
  cur = ggml_mul_mat(ctx0,
1976
1947
  layer.attn_ln_1_w,
1977
1948
  cur);
1978
1949
 
1979
- wctx.use_buf(ctx0, 1);
1950
+ wstate.use_buf(ctx0, 1);
1980
1951
 
1981
1952
  cur = ggml_add(ctx0,
1982
1953
  ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
1983
1954
  cur);
1984
1955
  }
1985
1956
 
1986
- wctx.use_buf(ctx0, 2);
1957
+ wstate.use_buf(ctx0, 2);
1987
1958
 
1988
1959
  // add the input
1989
1960
  struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
1990
1961
 
1991
1962
  // norm
1992
1963
  {
1993
- wctx.use_buf(ctx0, 0);
1964
+ wstate.use_buf(ctx0, 0);
1994
1965
 
1995
1966
  cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
1996
1967
 
1997
- wctx.use_buf(ctx0, 1);
1968
+ wstate.use_buf(ctx0, 1);
1998
1969
 
1999
1970
  // cur = ln_0_w*cur + ln_0_b
2000
1971
  cur = ggml_add(ctx0,
@@ -2006,7 +1977,7 @@ static bool whisper_decode(
2006
1977
 
2007
1978
  // cross-attention
2008
1979
  {
2009
- wctx.use_buf(ctx0, 0);
1980
+ wstate.use_buf(ctx0, 0);
2010
1981
 
2011
1982
  struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
2012
1983
  layer.cross_attn_q_w,
@@ -2023,19 +1994,19 @@ static bool whisper_decode(
2023
1994
  // Kcross is already scaled
2024
1995
  struct ggml_tensor * Kcross =
2025
1996
  ggml_reshape_3d(ctx0,
2026
- ggml_view_1d(ctx0, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state),
1997
+ ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state),
2027
1998
  n_state/n_head, n_head, M);
2028
1999
 
2029
2000
  struct ggml_tensor * Vcross =
2030
2001
  ggml_reshape_3d(ctx0,
2031
- ggml_view_1d(ctx0, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state),
2002
+ ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
2032
2003
  n_state/n_head, n_head, M);
2033
2004
 
2034
2005
  struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3);
2035
2006
 
2036
2007
  // ------
2037
2008
 
2038
- wctx.use_buf(ctx0, 1);
2009
+ wstate.use_buf(ctx0, 1);
2039
2010
 
2040
2011
  struct ggml_tensor * Q =
2041
2012
  ggml_permute(ctx0,
@@ -2046,7 +2017,7 @@ static bool whisper_decode(
2046
2017
 
2047
2018
  struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
2048
2019
 
2049
- wctx.use_buf(ctx0, 0);
2020
+ wstate.use_buf(ctx0, 0);
2050
2021
 
2051
2022
  // K * Q
2052
2023
  struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
@@ -2060,15 +2031,15 @@ static bool whisper_decode(
2060
2031
  // no masking for cross-attention
2061
2032
  //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
2062
2033
 
2063
- wctx.use_buf(ctx0, 1);
2034
+ wstate.use_buf(ctx0, 1);
2064
2035
 
2065
2036
  struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
2066
2037
 
2067
- wctx.use_buf(ctx0, 0);
2038
+ wstate.use_buf(ctx0, 0);
2068
2039
 
2069
2040
  struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
2070
2041
 
2071
- wctx.use_buf(ctx0, 1);
2042
+ wstate.use_buf(ctx0, 1);
2072
2043
 
2073
2044
  struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2074
2045
 
@@ -2080,20 +2051,20 @@ static bool whisper_decode(
2080
2051
 
2081
2052
  // projection
2082
2053
  {
2083
- wctx.use_buf(ctx0, 0);
2054
+ wstate.use_buf(ctx0, 0);
2084
2055
 
2085
2056
  cur = ggml_mul_mat(ctx0,
2086
2057
  layer.cross_attn_ln_1_w,
2087
2058
  cur);
2088
2059
 
2089
- wctx.use_buf(ctx0, 1);
2060
+ wstate.use_buf(ctx0, 1);
2090
2061
 
2091
2062
  cur = ggml_add(ctx0,
2092
2063
  ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
2093
2064
  cur);
2094
2065
  }
2095
2066
 
2096
- wctx.use_buf(ctx0, 2);
2067
+ wstate.use_buf(ctx0, 2);
2097
2068
 
2098
2069
  // add the input
2099
2070
  cur = ggml_add(ctx0, cur, inpCA);
@@ -2104,11 +2075,11 @@ static bool whisper_decode(
2104
2075
  {
2105
2076
  // norm
2106
2077
  {
2107
- wctx.use_buf(ctx0, 0);
2078
+ wstate.use_buf(ctx0, 0);
2108
2079
 
2109
2080
  cur = ggml_norm(ctx0, inpFF);
2110
2081
 
2111
- wctx.use_buf(ctx0, 1);
2082
+ wstate.use_buf(ctx0, 1);
2112
2083
 
2113
2084
  // cur = mlp_ln_w*cur + mlp_ln_b
2114
2085
  cur = ggml_add(ctx0,
@@ -2118,39 +2089,39 @@ static bool whisper_decode(
2118
2089
  ggml_repeat(ctx0, layer.mlp_ln_b, cur));
2119
2090
  }
2120
2091
 
2121
- wctx.use_buf(ctx0, 0);
2092
+ wstate.use_buf(ctx0, 0);
2122
2093
 
2123
2094
  // fully connected
2124
2095
  cur = ggml_mul_mat(ctx0,
2125
2096
  layer.mlp_0_w,
2126
2097
  cur);
2127
2098
 
2128
- wctx.use_buf(ctx0, 1);
2099
+ wstate.use_buf(ctx0, 1);
2129
2100
 
2130
2101
  cur = ggml_add(ctx0,
2131
2102
  ggml_repeat(ctx0, layer.mlp_0_b, cur),
2132
2103
  cur);
2133
2104
 
2134
- wctx.use_buf(ctx0, 0);
2105
+ wstate.use_buf(ctx0, 0);
2135
2106
 
2136
2107
  // GELU activation
2137
2108
  cur = ggml_gelu(ctx0, cur);
2138
2109
 
2139
- wctx.use_buf(ctx0, 1);
2110
+ wstate.use_buf(ctx0, 1);
2140
2111
 
2141
2112
  // projection
2142
2113
  cur = ggml_mul_mat(ctx0,
2143
2114
  layer.mlp_1_w,
2144
2115
  cur);
2145
2116
 
2146
- wctx.use_buf(ctx0, 0);
2117
+ wstate.use_buf(ctx0, 0);
2147
2118
 
2148
2119
  cur = ggml_add(ctx0,
2149
2120
  ggml_repeat(ctx0, layer.mlp_1_b, cur),
2150
2121
  cur);
2151
2122
  }
2152
2123
 
2153
- wctx.use_buf(ctx0, 3);
2124
+ wstate.use_buf(ctx0, 3);
2154
2125
 
2155
2126
  inpL = ggml_add(ctx0, cur, inpFF);
2156
2127
  }
@@ -2159,11 +2130,11 @@ static bool whisper_decode(
2159
2130
 
2160
2131
  // norm
2161
2132
  {
2162
- wctx.use_buf(ctx0, 0);
2133
+ wstate.use_buf(ctx0, 0);
2163
2134
 
2164
2135
  cur = ggml_norm(ctx0, cur);
2165
2136
 
2166
- wctx.use_buf(ctx0, 1);
2137
+ wstate.use_buf(ctx0, 1);
2167
2138
 
2168
2139
  cur = ggml_add(ctx0,
2169
2140
  ggml_mul(ctx0,
@@ -2172,7 +2143,7 @@ static bool whisper_decode(
2172
2143
  ggml_repeat(ctx0, model.d_ln_b, cur));
2173
2144
  }
2174
2145
 
2175
- wctx.use_buf(ctx0, 0);
2146
+ wstate.use_buf(ctx0, 0);
2176
2147
 
2177
2148
  // compute logits only for the last token
2178
2149
  // comment this line to compute logits for all N tokens
@@ -2181,7 +2152,7 @@ static bool whisper_decode(
2181
2152
 
2182
2153
  struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
2183
2154
 
2184
- wctx.use_buf(ctx0, -1);
2155
+ wstate.use_buf(ctx0, -1);
2185
2156
 
2186
2157
  // run the computation
2187
2158
  {
@@ -2208,8 +2179,8 @@ static bool whisper_decode(
2208
2179
 
2209
2180
  ggml_free(ctx0);
2210
2181
 
2211
- wctx.t_decode_us += ggml_time_us() - t_start_us;
2212
- wctx.n_decode++;
2182
+ wstate.t_decode_us += ggml_time_us() - t_start_us;
2183
+ wstate.n_decode++;
2213
2184
 
2214
2185
  return true;
2215
2186
  }
@@ -2313,7 +2284,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
2313
2284
 
2314
2285
  // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
2315
2286
  static bool log_mel_spectrogram(
2316
- whisper_context & wctx,
2287
+ whisper_state & wstate,
2317
2288
  const float * samples,
2318
2289
  const int n_samples,
2319
2290
  const int /*sample_rate*/,
@@ -2433,7 +2404,7 @@ static bool log_mel_spectrogram(
2433
2404
  mel.data[i] = (mel.data[i] + 4.0)/4.0;
2434
2405
  }
2435
2406
 
2436
- wctx.t_mel_us += ggml_time_us() - t_start_us;
2407
+ wstate.t_mel_us += ggml_time_us() - t_start_us;
2437
2408
 
2438
2409
  return true;
2439
2410
  }
@@ -2507,7 +2478,56 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
2507
2478
  // interface implementation
2508
2479
  //
2509
2480
 
2510
- struct whisper_context * whisper_init_from_file(const char * path_model) {
2481
+ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2482
+ whisper_state * state = new whisper_state;
2483
+
2484
+ const size_t scale = ctx->model.hparams.f16 ? 1 : 2;
2485
+
2486
+
2487
+ if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) {
2488
+ fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
2489
+ return nullptr;
2490
+ }
2491
+
2492
+ {
2493
+ const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
2494
+ fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2495
+ }
2496
+
2497
+ if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->wtype, ctx->model.hparams.n_audio_ctx)) {
2498
+ fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
2499
+ return nullptr;
2500
+ }
2501
+
2502
+ {
2503
+ const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v);
2504
+ fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2505
+ }
2506
+
2507
+
2508
+ state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
2509
+
2510
+ state->logits_id.reserve(ctx->model.hparams.n_vocab);
2511
+
2512
+ // TAGS: WHISPER_DECODER_INIT
2513
+ state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
2514
+
2515
+ state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
2516
+ state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
2517
+ state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
2518
+ state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type)));
2519
+
2520
+ state->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
2521
+ state->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
2522
+ state->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type));
2523
+ state->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type));
2524
+
2525
+ state->rng = std::mt19937(0);
2526
+
2527
+ return state;
2528
+ }
2529
+
2530
+ struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
2511
2531
  whisper_model_loader loader = {};
2512
2532
 
2513
2533
  fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
@@ -2535,10 +2555,10 @@ struct whisper_context * whisper_init_from_file(const char * path_model) {
2535
2555
  fin->close();
2536
2556
  };
2537
2557
 
2538
- return whisper_init(&loader);
2558
+ return whisper_init_no_state(&loader);
2539
2559
  }
2540
2560
 
2541
- struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
2561
+ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
2542
2562
  struct buf_context {
2543
2563
  uint8_t* buffer;
2544
2564
  size_t size;
@@ -2571,10 +2591,10 @@ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_s
2571
2591
 
2572
2592
  loader.close = [](void * /*ctx*/) { };
2573
2593
 
2574
- return whisper_init(&loader);
2594
+ return whisper_init_no_state(&loader);
2575
2595
  }
2576
2596
 
2577
- struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
2597
+ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) {
2578
2598
  ggml_time_init();
2579
2599
 
2580
2600
  whisper_context * ctx = new whisper_context;
@@ -2591,6 +2611,64 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
2591
2611
  return ctx;
2592
2612
  }
2593
2613
 
2614
+ struct whisper_context * whisper_init_from_file(const char * path_model) {
2615
+ whisper_context * ctx = whisper_init_from_file_no_state(path_model);
2616
+ if (!ctx) {
2617
+ return nullptr;
2618
+ }
2619
+
2620
+ ctx->state = whisper_init_state(ctx);
2621
+ if (!ctx->state) {
2622
+ whisper_free(ctx);
2623
+ return nullptr;
2624
+ }
2625
+
2626
+ return ctx;
2627
+ }
2628
+
2629
+ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
2630
+ whisper_context * ctx = whisper_init_from_buffer_no_state(buffer, buffer_size);
2631
+ if (!ctx) {
2632
+ return nullptr;
2633
+ }
2634
+
2635
+ ctx->state = whisper_init_state(ctx);
2636
+ if (!ctx->state) {
2637
+ whisper_free(ctx);
2638
+ return nullptr;
2639
+ }
2640
+
2641
+ return ctx;
2642
+ }
2643
+
2644
+ struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
2645
+ whisper_context * ctx = whisper_init_no_state(loader);
2646
+ if (!ctx) {
2647
+ return nullptr;
2648
+ }
2649
+
2650
+ ctx->state = whisper_init_state(ctx);
2651
+ if (!ctx->state) {
2652
+ whisper_free(ctx);
2653
+ return nullptr;
2654
+ }
2655
+
2656
+ return ctx;
2657
+ }
2658
+
2659
+ void whisper_free_state(struct whisper_state * state)
2660
+ {
2661
+ if (state) {
2662
+ kv_cache_free(state->kv_cross);
2663
+
2664
+ for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
2665
+ kv_cache_free(state->decoders[i].kv_self);
2666
+ }
2667
+
2668
+ delete state;
2669
+ }
2670
+ }
2671
+
2594
2672
  void whisper_free(struct whisper_context * ctx) {
2595
2673
  if (ctx) {
2596
2674
  if (ctx->model.ctx) {
@@ -2599,20 +2677,29 @@ void whisper_free(struct whisper_context * ctx) {
2599
2677
  if (ctx->model.buf) {
2600
2678
  delete ctx->model.buf;
2601
2679
  }
2602
- if (ctx->kv_cross.ctx) {
2603
- ggml_free(ctx->kv_cross.ctx);
2604
- }
2605
- for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
2606
- if (ctx->decoders[i].kv_self.ctx) {
2607
- ggml_free(ctx->decoders[i].kv_self.ctx);
2608
- }
2609
- }
2680
+
2681
+ whisper_free_state(ctx->state);
2682
+
2610
2683
  delete ctx;
2611
2684
  }
2612
2685
  }
2613
2686
 
2687
+ int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
2688
+ if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
2689
+ fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
2690
+ return -1;
2691
+ }
2692
+
2693
+ return 0;
2694
+ }
2695
+
2614
2696
  int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
2615
- if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) {
2697
+ return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
2698
+ }
2699
+
2700
+ // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
2701
+ int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
2702
+ if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, state->mel)) {
2616
2703
  fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
2617
2704
  return -1;
2618
2705
  }
@@ -2622,11 +2709,26 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
2622
2709
 
2623
2710
  // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
2624
2711
  int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
2625
- if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) {
2626
- fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
2712
+ return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
2713
+ }
2714
+
2715
+ int whisper_set_mel_with_state(
2716
+ struct whisper_context * /*ctx*/,
2717
+ struct whisper_state * state,
2718
+ const float * data,
2719
+ int n_len,
2720
+ int n_mel) {
2721
+ if (n_mel != WHISPER_N_MEL) {
2722
+ fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
2627
2723
  return -1;
2628
2724
  }
2629
2725
 
2726
+ state->mel.n_len = n_len;
2727
+ state->mel.n_mel = n_mel;
2728
+
2729
+ state->mel.data.resize(n_len*n_mel);
2730
+ memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float));
2731
+
2630
2732
  return 0;
2631
2733
  }
2632
2734
 
@@ -2635,22 +2737,20 @@ int whisper_set_mel(
2635
2737
  const float * data,
2636
2738
  int n_len,
2637
2739
  int n_mel) {
2638
- if (n_mel != WHISPER_N_MEL) {
2639
- fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
2740
+ return whisper_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel);
2741
+ }
2742
+
2743
+ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
2744
+ if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
2745
+ fprintf(stderr, "%s: failed to eval\n", __func__);
2640
2746
  return -1;
2641
2747
  }
2642
2748
 
2643
- ctx->mel.n_len = n_len;
2644
- ctx->mel.n_mel = n_mel;
2645
-
2646
- ctx->mel.data.resize(n_len*n_mel);
2647
- memcpy(ctx->mel.data.data(), data, n_len*n_mel*sizeof(float));
2648
-
2649
2749
  return 0;
2650
2750
  }
2651
2751
 
2652
2752
  int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
2653
- if (!whisper_encode(*ctx, offset, n_threads)) {
2753
+ if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
2654
2754
  fprintf(stderr, "%s: failed to eval\n", __func__);
2655
2755
  return -1;
2656
2756
  }
@@ -2658,11 +2758,28 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
2658
2758
  return 0;
2659
2759
  }
2660
2760
 
2761
+ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
2762
+ const int selected_decoder_id = 0;
2763
+
2764
+ if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
2765
+ fprintf(stderr, "%s: failed to eval\n", __func__);
2766
+ return 1;
2767
+ }
2768
+
2769
+ return 0;
2770
+ }
2771
+
2661
2772
  int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
2662
- // TODO: add selected_decoder_id to context
2773
+ // TODO: add selected_decoder_id to state
2663
2774
  const int selected_decoder_id = 0;
2664
2775
 
2665
- if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
2776
+ if (ctx->state == nullptr) {
2777
+ fprintf(stderr, "%s: ERROR state was not loaded.\n", __func__);
2778
+ return false;
2779
+ }
2780
+
2781
+
2782
+ if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
2666
2783
  fprintf(stderr, "%s: failed to eval\n", __func__);
2667
2784
  return 1;
2668
2785
  }
@@ -2720,11 +2837,12 @@ const char * whisper_lang_str(int id) {
2720
2837
  return nullptr;
2721
2838
  }
2722
2839
 
2723
- int whisper_lang_auto_detect(
2840
+ int whisper_lang_auto_detect_with_state(
2724
2841
  struct whisper_context * ctx,
2725
- int offset_ms,
2726
- int n_threads,
2727
- float * lang_probs) {
2842
+ struct whisper_state * state,
2843
+ int offset_ms,
2844
+ int n_threads,
2845
+ float * lang_probs) {
2728
2846
  const int seek = offset_ms/10;
2729
2847
 
2730
2848
  if (seek < 0) {
@@ -2732,8 +2850,8 @@ int whisper_lang_auto_detect(
2732
2850
  return -1;
2733
2851
  }
2734
2852
 
2735
- if (seek >= ctx->mel.n_len) {
2736
- fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
2853
+ if (seek >= state->mel.n_len) {
2854
+ fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len*10);
2737
2855
  return -2;
2738
2856
  }
2739
2857
 
@@ -2745,17 +2863,17 @@ int whisper_lang_auto_detect(
2745
2863
 
2746
2864
  const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
2747
2865
 
2748
- if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
2866
+ if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
2749
2867
  fprintf(stderr, "%s: failed to decode\n", __func__);
2750
2868
  return -7;
2751
2869
  }
2752
2870
 
2753
- auto & logits_id = ctx->logits_id;
2871
+ auto & logits_id = state->logits_id;
2754
2872
  logits_id.clear();
2755
2873
 
2756
2874
  for (const auto & kv : g_lang) {
2757
2875
  const auto token_lang = whisper_token_lang(ctx, kv.second.first);
2758
- logits_id.emplace_back(ctx->logits[token_lang], kv.second.first);
2876
+ logits_id.emplace_back(state->logits[token_lang], kv.second.first);
2759
2877
  }
2760
2878
 
2761
2879
  // sort descending
@@ -2794,8 +2912,20 @@ int whisper_lang_auto_detect(
2794
2912
  return logits_id[0].second;
2795
2913
  }
2796
2914
 
2915
+ int whisper_lang_auto_detect(
2916
+ struct whisper_context * ctx,
2917
+ int offset_ms,
2918
+ int n_threads,
2919
+ float * lang_probs) {
2920
+ return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs);
2921
+ }
2922
+
2923
+ int whisper_n_len_from_state(struct whisper_state * state) {
2924
+ return state->mel.n_len;
2925
+ }
2926
+
2797
2927
  int whisper_n_len(struct whisper_context * ctx) {
2798
- return ctx->mel.n_len;
2928
+ return ctx->state->mel.n_len;
2799
2929
  }
2800
2930
 
2801
2931
  int whisper_n_vocab(struct whisper_context * ctx) {
@@ -2815,7 +2945,12 @@ int whisper_is_multilingual(struct whisper_context * ctx) {
2815
2945
  }
2816
2946
 
2817
2947
  float * whisper_get_logits(struct whisper_context * ctx) {
2818
- return ctx->logits.data();
2948
+ return ctx->state->logits.data();
2949
+ }
2950
+
2951
+
2952
+ float * whisper_get_logits_from_state(struct whisper_state * state) {
2953
+ return state->logits.data();
2819
2954
  }
2820
2955
 
2821
2956
  const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
@@ -2861,24 +2996,29 @@ whisper_token whisper_token_transcribe(void) {
2861
2996
  void whisper_print_timings(struct whisper_context * ctx) {
2862
2997
  const int64_t t_end_us = ggml_time_us();
2863
2998
 
2864
- const int32_t n_sample = std::max(1, ctx->n_sample);
2865
- const int32_t n_encode = std::max(1, ctx->n_encode);
2866
- const int32_t n_decode = std::max(1, ctx->n_decode);
2867
-
2868
2999
  fprintf(stderr, "\n");
2869
- fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->n_fail_p, ctx->n_fail_h);
2870
- fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f);
2871
- fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f);
2872
- fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_sample_us, n_sample, 1e-3f*ctx->t_sample_us/n_sample);
2873
- fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_encode_us, n_encode, 1e-3f*ctx->t_encode_us/n_encode);
2874
- fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f*ctx->t_decode_us, n_decode, 1e-3f*ctx->t_decode_us/n_decode);
3000
+ fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
3001
+ if (ctx->state != nullptr) {
3002
+
3003
+ const int32_t n_sample = std::max(1, ctx->state->n_sample);
3004
+ const int32_t n_encode = std::max(1, ctx->state->n_encode);
3005
+ const int32_t n_decode = std::max(1, ctx->state->n_decode);
3006
+
3007
+ fprintf(stderr, "%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
3008
+ fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
3009
+ fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
3010
+ fprintf(stderr, "%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
3011
+ fprintf(stderr, "%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
3012
+ }
2875
3013
  fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
2876
3014
  }
2877
3015
 
2878
3016
  void whisper_reset_timings(struct whisper_context * ctx) {
2879
- ctx->t_sample_us = 0;
2880
- ctx->t_encode_us = 0;
2881
- ctx->t_decode_us = 0;
3017
+ if (ctx->state != nullptr) {
3018
+ ctx->state->t_sample_us = 0;
3019
+ ctx->state->t_encode_us = 0;
3020
+ ctx->state->t_decode_us = 0;
3021
+ }
2882
3022
  }
2883
3023
 
2884
3024
  const char * whisper_print_system_info(void) {
@@ -2913,7 +3053,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
2913
3053
  /*.duration_ms =*/ 0,
2914
3054
 
2915
3055
  /*.translate =*/ false,
2916
- /*.no_context =*/ false,
3056
+ /*.no_context =*/ true,
2917
3057
  /*.single_segment =*/ false,
2918
3058
  /*.print_special =*/ false,
2919
3059
  /*.print_progress =*/ true,
@@ -2991,6 +3131,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
2991
3131
  static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
2992
3132
  static void whisper_exp_compute_token_level_timestamps(
2993
3133
  struct whisper_context & ctx,
3134
+ struct whisper_state & state,
2994
3135
  int i_segment,
2995
3136
  float thold_pt,
2996
3137
  float thold_ptsum);
@@ -3023,8 +3164,8 @@ static inline bool should_split_on_word(const char * txt, bool split_on_word) {
3023
3164
 
3024
3165
  // wrap the last segment to max_len characters
3025
3166
  // returns the number of new segments
3026
- static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool split_on_word) {
3027
- auto segment = ctx.result_all.back();
3167
+ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) {
3168
+ auto segment = state.result_all.back();
3028
3169
 
3029
3170
  int res = 1;
3030
3171
  int acc = 0;
@@ -3046,24 +3187,24 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
3046
3187
  trim(text);
3047
3188
  }
3048
3189
 
3049
- ctx.result_all.back().text = std::move(text);
3050
- ctx.result_all.back().t1 = token.t0;
3051
- ctx.result_all.back().tokens.resize(i);
3190
+ state.result_all.back().text = std::move(text);
3191
+ state.result_all.back().t1 = token.t0;
3192
+ state.result_all.back().tokens.resize(i);
3052
3193
 
3053
- ctx.result_all.push_back({});
3054
- ctx.result_all.back().t0 = token.t0;
3055
- ctx.result_all.back().t1 = segment.t1;
3194
+ state.result_all.push_back({});
3195
+ state.result_all.back().t0 = token.t0;
3196
+ state.result_all.back().t1 = segment.t1;
3056
3197
 
3057
3198
  // add tokens [i, end] to the new segment
3058
- ctx.result_all.back().tokens.insert(
3059
- ctx.result_all.back().tokens.end(),
3199
+ state.result_all.back().tokens.insert(
3200
+ state.result_all.back().tokens.end(),
3060
3201
  segment.tokens.begin() + i,
3061
3202
  segment.tokens.end());
3062
3203
 
3063
3204
  acc = 0;
3064
3205
  text = "";
3065
3206
 
3066
- segment = ctx.result_all.back();
3207
+ segment = state.result_all.back();
3067
3208
  i = -1;
3068
3209
 
3069
3210
  res++;
@@ -3076,7 +3217,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
3076
3217
  if (split_on_word) {
3077
3218
  trim(text);
3078
3219
  }
3079
- ctx.result_all.back().text = std::move(text);
3220
+ state.result_all.back().text = std::move(text);
3080
3221
 
3081
3222
  return res;
3082
3223
  }
@@ -3093,6 +3234,7 @@ static const std::vector<std::string> non_speech_tokens = {
3093
3234
  // - computes logprobs and probs
3094
3235
  static void whisper_process_logits(
3095
3236
  struct whisper_context & ctx,
3237
+ struct whisper_state & state,
3096
3238
  const struct whisper_full_params params,
3097
3239
  struct whisper_decoder & decoder,
3098
3240
  float temperature) {
@@ -3111,7 +3253,7 @@ static void whisper_process_logits(
3111
3253
  auto & logprobs = decoder.logprobs;
3112
3254
  {
3113
3255
  logits.resize(n_logits);
3114
- memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float));
3256
+ memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float));
3115
3257
 
3116
3258
  if (temperature > 0.0f) {
3117
3259
  for (int i = 0; i < n_logits; i++) {
@@ -3149,7 +3291,7 @@ static void whisper_process_logits(
3149
3291
  logits[vocab.token_transcribe] = -INFINITY;
3150
3292
 
3151
3293
  if (params.logits_filter_callback) {
3152
- params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
3294
+ params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
3153
3295
  }
3154
3296
 
3155
3297
  // suppress non-speech tokens
@@ -3310,6 +3452,7 @@ static void whisper_process_logits(
3310
3452
 
3311
3453
  static whisper_token_data whisper_sample_token(
3312
3454
  whisper_context & ctx,
3455
+ whisper_state & state,
3313
3456
  const whisper_decoder & decoder,
3314
3457
  bool best) {
3315
3458
  whisper_token_data result = {
@@ -3354,7 +3497,7 @@ static whisper_token_data whisper_sample_token(
3354
3497
  } else {
3355
3498
  std::discrete_distribution<> dist(probs.begin(), probs.end());
3356
3499
 
3357
- result.id = dist(ctx.rng);
3500
+ result.id = dist(state.rng);
3358
3501
  result.p = probs[result.id];
3359
3502
  result.plog = logprobs[result.id];
3360
3503
  }
@@ -3364,13 +3507,14 @@ static whisper_token_data whisper_sample_token(
3364
3507
  result.pt = result.p;
3365
3508
  }
3366
3509
 
3367
- ctx.n_sample++;
3510
+ state.n_sample++;
3368
3511
 
3369
3512
  return result;
3370
3513
  }
3371
3514
 
3372
3515
  static std::vector<whisper_token_data> whisper_sample_token_topk(
3373
3516
  whisper_context & ctx,
3517
+ whisper_state & state,
3374
3518
  const whisper_decoder & decoder,
3375
3519
  int k) {
3376
3520
  const auto & vocab = ctx.vocab;
@@ -3381,7 +3525,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
3381
3525
 
3382
3526
  const int n_logits = vocab.n_vocab;
3383
3527
 
3384
- auto & logits_id = ctx.logits_id;
3528
+ auto & logits_id = state.logits_id;
3385
3529
 
3386
3530
  logits_id.clear();
3387
3531
  for (int i = 0; i < n_logits; ++i) {
@@ -3434,7 +3578,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
3434
3578
  }
3435
3579
  }
3436
3580
 
3437
- ctx.n_sample++;
3581
+ state.n_sample++;
3438
3582
 
3439
3583
  return result;
3440
3584
  }
@@ -3488,24 +3632,25 @@ static void whisper_sequence_score(
3488
3632
  }
3489
3633
  }
3490
3634
 
3491
- int whisper_full(
3635
+ int whisper_full_with_state(
3492
3636
  struct whisper_context * ctx,
3493
- struct whisper_full_params params,
3494
- const float * samples,
3495
- int n_samples) {
3637
+ struct whisper_state * state,
3638
+ struct whisper_full_params params,
3639
+ const float * samples,
3640
+ int n_samples) {
3496
3641
  // clear old results
3497
- auto & result_all = ctx->result_all;
3642
+ auto & result_all = state->result_all;
3498
3643
 
3499
3644
  result_all.clear();
3500
3645
 
3501
3646
  // compute log mel spectrogram
3502
3647
  if (params.speed_up) {
3503
- if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) {
3648
+ if (whisper_pcm_to_mel_phase_vocoder_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
3504
3649
  fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
3505
3650
  return -1;
3506
3651
  }
3507
3652
  } else {
3508
- if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
3653
+ if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
3509
3654
  fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
3510
3655
  return -2;
3511
3656
  }
@@ -3515,26 +3660,26 @@ int whisper_full(
3515
3660
  if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
3516
3661
  std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
3517
3662
 
3518
- const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
3663
+ const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
3519
3664
  if (lang_id < 0) {
3520
3665
  fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
3521
3666
  return -3;
3522
3667
  }
3523
- ctx->lang_id = lang_id;
3668
+ state->lang_id = lang_id;
3524
3669
  params.language = whisper_lang_str(lang_id);
3525
3670
 
3526
3671
  fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
3527
3672
  }
3528
3673
 
3529
3674
  if (params.token_timestamps) {
3530
- ctx->t_beg = 0;
3531
- ctx->t_last = 0;
3532
- ctx->tid_last = 0;
3533
- ctx->energy = get_signal_energy(samples, n_samples, 32);
3675
+ state->t_beg = 0;
3676
+ state->t_last = 0;
3677
+ state->tid_last = 0;
3678
+ state->energy = get_signal_energy(samples, n_samples, 32);
3534
3679
  }
3535
3680
 
3536
3681
  const int seek_start = params.offset_ms/10;
3537
- const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len(ctx) : params.duration_ms/10);
3682
+ const int seek_end = seek_start + (params.duration_ms == 0 ? whisper_n_len_from_state(state) : params.duration_ms/10);
3538
3683
 
3539
3684
  // if length of spectrogram is less than 1s (100 samples), then return
3540
3685
  // basically don't process anything that is less than 1s
@@ -3572,10 +3717,10 @@ int whisper_full(
3572
3717
 
3573
3718
  // TAGS: WHISPER_DECODER_INIT
3574
3719
  for (int j = 1; j < n_decoders; j++) {
3575
- auto & decoder = ctx->decoders[j];
3720
+ auto & decoder = state->decoders[j];
3576
3721
 
3577
3722
  if (decoder.kv_self.ctx == nullptr) {
3578
- decoder.kv_self = ctx->decoders[0].kv_self;
3723
+ decoder.kv_self = state->decoders[0].kv_self;
3579
3724
  if (!kv_cache_reinit(decoder.kv_self)) {
3580
3725
  fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
3581
3726
  return -4;
@@ -3583,7 +3728,7 @@ int whisper_full(
3583
3728
 
3584
3729
  WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
3585
3730
 
3586
- decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
3731
+ decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
3587
3732
 
3588
3733
  decoder.probs.resize (ctx->vocab.n_vocab);
3589
3734
  decoder.logits.resize (ctx->vocab.n_vocab);
@@ -3592,7 +3737,7 @@ int whisper_full(
3592
3737
  }
3593
3738
 
3594
3739
  // the accumulated text context so far
3595
- auto & prompt_past = ctx->prompt_past;
3740
+ auto & prompt_past = state->prompt_past;
3596
3741
  if (params.no_context) {
3597
3742
  prompt_past.clear();
3598
3743
  }
@@ -3611,13 +3756,13 @@ int whisper_full(
3611
3756
  fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
3612
3757
  return -5;
3613
3758
  }
3614
- ctx->exp_n_audio_ctx = params.audio_ctx;
3759
+ state->exp_n_audio_ctx = params.audio_ctx;
3615
3760
 
3616
3761
  // these tokens determine the task that will be performed
3617
3762
  std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
3618
3763
  if (whisper_is_multilingual(ctx)) {
3619
3764
  const int lang_id = whisper_lang_id(params.language);
3620
- ctx->lang_id = lang_id;
3765
+ state->lang_id = lang_id;
3621
3766
  prompt_init.push_back(whisper_token_lang(ctx, lang_id));
3622
3767
  if (params.translate) {
3623
3768
  prompt_init.push_back(whisper_token_translate());
@@ -3669,14 +3814,14 @@ int whisper_full(
3669
3814
  }
3670
3815
 
3671
3816
  if (params.encoder_begin_callback) {
3672
- if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
3817
+ if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
3673
3818
  fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
3674
3819
  break;
3675
3820
  }
3676
3821
  }
3677
3822
 
3678
3823
  // encode audio features starting at offset seek
3679
- if (!whisper_encode(*ctx, seek, params.n_threads)) {
3824
+ if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
3680
3825
  fprintf(stderr, "%s: failed to encode\n", __func__);
3681
3826
  return -6;
3682
3827
  }
@@ -3717,7 +3862,7 @@ int whisper_full(
3717
3862
 
3718
3863
  // TAGS: WHISPER_DECODER_INIT
3719
3864
  for (int j = 0; j < n_decoders_cur; ++j) {
3720
- auto & decoder = ctx->decoders[j];
3865
+ auto & decoder = state->decoders[j];
3721
3866
 
3722
3867
  decoder.kv_self.n = 0;
3723
3868
 
@@ -3759,7 +3904,7 @@ int whisper_full(
3759
3904
  }
3760
3905
  WHISPER_PRINT_DEBUG("\n\n");
3761
3906
 
3762
- if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
3907
+ if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
3763
3908
  fprintf(stderr, "%s: failed to decode\n", __func__);
3764
3909
  return -7;
3765
3910
  }
@@ -3767,24 +3912,24 @@ int whisper_full(
3767
3912
  {
3768
3913
  const int64_t t_start_sample_us = ggml_time_us();
3769
3914
 
3770
- whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur);
3915
+ whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
3771
3916
 
3772
- ctx->decoders[0].kv_self.n += prompt.size();
3917
+ state->decoders[0].kv_self.n += prompt.size();
3773
3918
 
3774
3919
  for (int j = 1; j < n_decoders_cur; ++j) {
3775
- auto & decoder = ctx->decoders[j];
3920
+ auto & decoder = state->decoders[j];
3776
3921
 
3777
- memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
3778
- memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
3922
+ memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
3923
+ memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
3779
3924
 
3780
3925
  decoder.kv_self.n += prompt.size();
3781
3926
 
3782
- memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
3783
- memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
3784
- memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
3927
+ memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
3928
+ memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
3929
+ memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
3785
3930
  }
3786
3931
 
3787
- ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
3932
+ state->t_sample_us += ggml_time_us() - t_start_sample_us;
3788
3933
  }
3789
3934
  }
3790
3935
 
@@ -3795,7 +3940,7 @@ int whisper_full(
3795
3940
  if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
3796
3941
  kv_bufs.resize(n_decoders_cur);
3797
3942
  for (int j = 0; j < n_decoders_cur; ++j) {
3798
- auto & decoder = ctx->decoders[j];
3943
+ auto & decoder = state->decoders[j];
3799
3944
 
3800
3945
  if (decoder.completed || decoder.failed) {
3801
3946
  continue;
@@ -3813,7 +3958,7 @@ int whisper_full(
3813
3958
 
3814
3959
  // generate new sequence candidates for each decoder
3815
3960
  for (int j = 0; j < n_decoders_cur; ++j) {
3816
- auto & decoder = ctx->decoders[j];
3961
+ auto & decoder = state->decoders[j];
3817
3962
 
3818
3963
  if (decoder.completed || decoder.failed) {
3819
3964
  continue;
@@ -3823,16 +3968,16 @@ int whisper_full(
3823
3968
  case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
3824
3969
  {
3825
3970
  if (t_cur < 1e-6f) {
3826
- decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
3971
+ decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true));
3827
3972
  } else {
3828
- decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
3973
+ decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false));
3829
3974
  }
3830
3975
 
3831
3976
  decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
3832
3977
  } break;
3833
3978
  case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
3834
3979
  {
3835
- const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
3980
+ const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
3836
3981
 
3837
3982
  for (const auto & token : tokens_new) {
3838
3983
  beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
@@ -3857,7 +4002,7 @@ int whisper_full(
3857
4002
  uint32_t cur_c = 0;
3858
4003
 
3859
4004
  for (int j = 0; j < n_decoders_cur; ++j) {
3860
- auto & decoder = ctx->decoders[j];
4005
+ auto & decoder = state->decoders[j];
3861
4006
 
3862
4007
  if (decoder.completed || decoder.failed) {
3863
4008
  continue;
@@ -3886,7 +4031,7 @@ int whisper_full(
3886
4031
  // - check if the sequence is failed
3887
4032
  // - update sliding window based on timestamp tokens
3888
4033
  for (int j = 0; j < n_decoders_cur; ++j) {
3889
- auto & decoder = ctx->decoders[j];
4034
+ auto & decoder = state->decoders[j];
3890
4035
 
3891
4036
  if (decoder.completed || decoder.failed) {
3892
4037
  continue;
@@ -3968,7 +4113,7 @@ int whisper_full(
3968
4113
  bool completed_all = true;
3969
4114
 
3970
4115
  for (int j = 0; j < n_decoders_cur; ++j) {
3971
- auto & decoder = ctx->decoders[j];
4116
+ auto & decoder = state->decoders[j];
3972
4117
 
3973
4118
  if (decoder.completed || decoder.failed) {
3974
4119
  continue;
@@ -3982,11 +4127,11 @@ int whisper_full(
3982
4127
  }
3983
4128
  }
3984
4129
 
3985
- ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
4130
+ state->t_sample_us += ggml_time_us() - t_start_sample_us;
3986
4131
 
3987
4132
  // obtain logits for the next token
3988
4133
  for (int j = 0; j < n_decoders_cur; ++j) {
3989
- auto & decoder = ctx->decoders[j];
4134
+ auto & decoder = state->decoders[j];
3990
4135
 
3991
4136
  if (decoder.failed || decoder.completed) {
3992
4137
  continue;
@@ -3997,7 +4142,7 @@ int whisper_full(
3997
4142
 
3998
4143
  //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
3999
4144
 
4000
- if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
4145
+ if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
4001
4146
  fprintf(stderr, "%s: failed to decode\n", __func__);
4002
4147
  return -8;
4003
4148
  }
@@ -4005,11 +4150,11 @@ int whisper_full(
4005
4150
  {
4006
4151
  const int64_t t_start_sample_us = ggml_time_us();
4007
4152
 
4008
- whisper_process_logits(*ctx, params, decoder, t_cur);
4153
+ whisper_process_logits(*ctx, *state, params, decoder, t_cur);
4009
4154
 
4010
4155
  ++decoder.kv_self.n;
4011
4156
 
4012
- ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
4157
+ state->t_sample_us += ggml_time_us() - t_start_sample_us;
4013
4158
  }
4014
4159
  }
4015
4160
  }
@@ -4019,7 +4164,7 @@ int whisper_full(
4019
4164
  double best_score = -INFINITY;
4020
4165
 
4021
4166
  for (int j = 0; j < n_decoders_cur; ++j) {
4022
- auto & decoder = ctx->decoders[j];
4167
+ auto & decoder = state->decoders[j];
4023
4168
 
4024
4169
  if (decoder.failed) {
4025
4170
  continue;
@@ -4036,7 +4181,7 @@ int whisper_full(
4036
4181
  __func__, j, decoder.sequence.entropy, params.entropy_thold);
4037
4182
 
4038
4183
  decoder.failed = true;
4039
- ctx->n_fail_h++;
4184
+ state->n_fail_h++;
4040
4185
 
4041
4186
  continue;
4042
4187
  }
@@ -4054,11 +4199,11 @@ int whisper_full(
4054
4199
  {
4055
4200
  bool success = true;
4056
4201
 
4057
- const auto & decoder = ctx->decoders[best_decoder_id];
4202
+ const auto & decoder = state->decoders[best_decoder_id];
4058
4203
 
4059
4204
  if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
4060
4205
  success = false;
4061
- ctx->n_fail_p++;
4206
+ state->n_fail_p++;
4062
4207
  }
4063
4208
 
4064
4209
  if (success) {
@@ -4075,7 +4220,7 @@ int whisper_full(
4075
4220
 
4076
4221
  // output results through a user-provided callback
4077
4222
  {
4078
- const auto & best_decoder = ctx->decoders[best_decoder_id];
4223
+ const auto & best_decoder = state->decoders[best_decoder_id];
4079
4224
 
4080
4225
  const auto seek_delta = best_decoder.seek_delta;
4081
4226
  const auto result_len = best_decoder.sequence.result_len;
@@ -4138,14 +4283,14 @@ int whisper_full(
4138
4283
 
4139
4284
  if (params.token_timestamps) {
4140
4285
  whisper_exp_compute_token_level_timestamps(
4141
- *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
4286
+ *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
4142
4287
 
4143
4288
  if (params.max_len > 0) {
4144
- n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
4289
+ n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
4145
4290
  }
4146
4291
  }
4147
4292
  if (params.new_segment_callback) {
4148
- params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
4293
+ params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
4149
4294
  }
4150
4295
  }
4151
4296
  text = "";
@@ -4182,14 +4327,14 @@ int whisper_full(
4182
4327
 
4183
4328
  if (params.token_timestamps) {
4184
4329
  whisper_exp_compute_token_level_timestamps(
4185
- *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
4330
+ *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
4186
4331
 
4187
4332
  if (params.max_len > 0) {
4188
- n_new = whisper_wrap_segment(*ctx, params.max_len, params.split_on_word);
4333
+ n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
4189
4334
  }
4190
4335
  }
4191
4336
  if (params.new_segment_callback) {
4192
- params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
4337
+ params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
4193
4338
  }
4194
4339
  }
4195
4340
  }
@@ -4204,6 +4349,15 @@ int whisper_full(
4204
4349
  return 0;
4205
4350
  }
4206
4351
 
4352
+
4353
+ int whisper_full(
4354
+ struct whisper_context * ctx,
4355
+ struct whisper_full_params params,
4356
+ const float * samples,
4357
+ int n_samples) {
4358
+ return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples);
4359
+ }
4360
+
4207
4361
  int whisper_full_parallel(
4208
4362
  struct whisper_context * ctx,
4209
4363
  struct whisper_full_params params,
@@ -4213,40 +4367,10 @@ int whisper_full_parallel(
4213
4367
  if (n_processors == 1) {
4214
4368
  return whisper_full(ctx, params, samples, n_samples);
4215
4369
  }
4216
-
4217
4370
  int ret = 0;
4218
4371
 
4219
- // prepare separate contexts for each thread
4220
- std::vector<struct whisper_context> ctxs(n_processors - 1);
4221
-
4222
- for (int i = 0; i < n_processors - 1; ++i) {
4223
- auto & ctx_p = ctxs[i];
4224
-
4225
- ctx_p = *ctx;
4226
-
4227
- ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx);
4228
-
4229
- ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab);
4230
-
4231
- if (!kv_cache_reinit(ctx_p.kv_cross)) {
4232
- fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i);
4233
- return false;
4234
- }
4235
-
4236
- // TAGS: WHISPER_DECODER_INIT
4237
- for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
4238
- if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) {
4239
- fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i);
4240
- return false;
4241
- }
4242
-
4243
- ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx);
4244
-
4245
- ctx_p.decoders[j].probs.reserve (ctx_p.vocab.n_vocab);
4246
- ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab);
4247
- ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab);
4248
- }
4249
- }
4372
+ // prepare separate states for each thread
4373
+ std::vector<whisper_state*> states;
4250
4374
 
4251
4375
  const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
4252
4376
  const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
@@ -4256,6 +4380,9 @@ int whisper_full_parallel(
4256
4380
 
4257
4381
  std::vector<std::thread> workers(n_processors - 1);
4258
4382
  for (int i = 0; i < n_processors - 1; ++i) {
4383
+ // create a new state for each thread
4384
+ states.push_back(whisper_init_state(ctx));
4385
+
4259
4386
  const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
4260
4387
  const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
4261
4388
 
@@ -4268,13 +4395,17 @@ int whisper_full_parallel(
4268
4395
  params_cur.new_segment_callback = nullptr;
4269
4396
  params_cur.new_segment_callback_user_data = nullptr;
4270
4397
 
4271
- workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur);
4398
+ workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur);
4272
4399
  }
4273
4400
 
4274
4401
  {
4275
4402
  auto params_cur = params;
4276
4403
 
4277
- ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
4404
+ // We need to disable the print real-time for this one as well, otherwise it will show only for the first chunk.
4405
+ params_cur.print_realtime = false;
4406
+
4407
+ // Run the first transformation using default state but only for the first chunk.
4408
+ ret = whisper_full_with_state(ctx, ctx->state, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
4278
4409
  }
4279
4410
 
4280
4411
  for (int i = 0; i < n_processors - 1; ++i) {
@@ -4283,45 +4414,43 @@ int whisper_full_parallel(
4283
4414
 
4284
4415
  const int64_t offset_t = (int64_t) params.offset_ms/10.0;
4285
4416
 
4286
- // combine results into ctx->result_all
4417
+ // combine results into result_state->result_all from all other states
4287
4418
  for (int i = 0; i < n_processors - 1; ++i) {
4288
- auto & results_i = ctxs[i].result_all;
4419
+ auto& results_i = states[i]->result_all;
4289
4420
 
4290
- for (auto & result : results_i) {
4421
+ for (auto& result : results_i) {
4291
4422
  // correct the segment timestamp taking into account the offset
4292
- result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
4293
- result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
4423
+ result.t0 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
4424
+ result.t1 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
4425
+
4294
4426
 
4295
4427
  // make sure that segments are not overlapping
4296
- if (!ctx->result_all.empty()) {
4297
- result.t0 = std::max(result.t0, ctx->result_all.back().t1);
4428
+ if (!ctx->state->result_all.empty()) {
4429
+ result.t0 = std::max(result.t0, ctx->state->result_all.back().t1);
4298
4430
  }
4299
4431
 
4300
- ctx->result_all.push_back(std::move(result));
4432
+ ctx->state->result_all.push_back(std::move(result));
4301
4433
 
4302
4434
  // call the new_segment_callback for each segment
4303
4435
  if (params.new_segment_callback) {
4304
- params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
4436
+ params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data);
4305
4437
  }
4306
4438
  }
4307
4439
 
4308
- ctx->t_mel_us += ctxs[i].t_mel_us;
4309
- ctx->t_sample_us += ctxs[i].t_sample_us;
4310
- ctx->t_encode_us += ctxs[i].t_encode_us;
4311
- ctx->t_decode_us += ctxs[i].t_decode_us;
4440
+ ctx->state->t_mel_us += states[i]->t_mel_us;
4312
4441
 
4313
- kv_cache_free(ctx->kv_cross);
4442
+ ctx->state->t_sample_us += states[i]->t_sample_us;
4443
+ ctx->state->t_encode_us += states[i]->t_encode_us;
4444
+ ctx->state->t_decode_us += states[i]->t_decode_us;
4314
4445
 
4315
- for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) {
4316
- kv_cache_free(ctx->decoders[j].kv_self);
4317
- }
4446
+ whisper_free_state(states[i]);
4318
4447
  }
4319
4448
 
4320
4449
  // average the timings
4321
- ctx->t_mel_us /= n_processors;
4322
- ctx->t_sample_us /= n_processors;
4323
- ctx->t_encode_us /= n_processors;
4324
- ctx->t_decode_us /= n_processors;
4450
+ ctx->state->t_mel_us /= n_processors;
4451
+ ctx->state->t_sample_us /= n_processors;
4452
+ ctx->state->t_encode_us /= n_processors;
4453
+ ctx->state->t_decode_us /= n_processors;
4325
4454
 
4326
4455
  // print information about the audio boundaries
4327
4456
  fprintf(stderr, "\n");
@@ -4334,44 +4463,84 @@ int whisper_full_parallel(
4334
4463
  return ret;
4335
4464
  }
4336
4465
 
4466
+ int whisper_full_n_segments_from_state(struct whisper_state * state) {
4467
+ return state->result_all.size();
4468
+ }
4469
+
4337
4470
  int whisper_full_n_segments(struct whisper_context * ctx) {
4338
- return ctx->result_all.size();
4471
+ return ctx->state->result_all.size();
4472
+ }
4473
+
4474
+ int whisper_full_lang_id_from_state(struct whisper_state * state) {
4475
+ return state->lang_id;
4339
4476
  }
4340
4477
 
4341
4478
  int whisper_full_lang_id(struct whisper_context * ctx) {
4342
- return ctx->lang_id;
4479
+ return ctx->state->lang_id;
4480
+ }
4481
+
4482
+ int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
4483
+ return state->result_all[i_segment].t0;
4343
4484
  }
4344
4485
 
4345
4486
  int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
4346
- return ctx->result_all[i_segment].t0;
4487
+ return ctx->state->result_all[i_segment].t0;
4488
+ }
4489
+
4490
+ int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
4491
+ return state->result_all[i_segment].t1;
4347
4492
  }
4348
4493
 
4349
4494
  int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
4350
- return ctx->result_all[i_segment].t1;
4495
+ return ctx->state->result_all[i_segment].t1;
4496
+ }
4497
+
4498
+ const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) {
4499
+ return state->result_all[i_segment].text.c_str();
4351
4500
  }
4352
4501
 
4353
4502
  const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
4354
- return ctx->result_all[i_segment].text.c_str();
4503
+ return ctx->state->result_all[i_segment].text.c_str();
4504
+ }
4505
+
4506
+ int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment) {
4507
+ return state->result_all[i_segment].tokens.size();
4355
4508
  }
4356
4509
 
4357
4510
  int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
4358
- return ctx->result_all[i_segment].tokens.size();
4511
+ return ctx->state->result_all[i_segment].tokens.size();
4512
+ }
4513
+
4514
+ const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token) {
4515
+ return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str();
4516
+ }
4517
+
4518
+ const char* whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
4519
+ return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str();
4359
4520
  }
4360
4521
 
4361
- const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
4362
- return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
4522
+ whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token) {
4523
+ return state->result_all[i_segment].tokens[i_token].id;
4363
4524
  }
4364
4525
 
4365
4526
  whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
4366
- return ctx->result_all[i_segment].tokens[i_token].id;
4527
+ return ctx->state->result_all[i_segment].tokens[i_token].id;
4528
+ }
4529
+
4530
+ struct whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token) {
4531
+ return state->result_all[i_segment].tokens[i_token];
4367
4532
  }
4368
4533
 
4369
4534
  struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) {
4370
- return ctx->result_all[i_segment].tokens[i_token];
4535
+ return ctx->state->result_all[i_segment].tokens[i_token];
4536
+ }
4537
+
4538
+ float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token) {
4539
+ return state->result_all[i_segment].tokens[i_token].p;
4371
4540
  }
4372
4541
 
4373
4542
  float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
4374
- return ctx->result_all[i_segment].tokens[i_token].p;
4543
+ return ctx->state->result_all[i_segment].tokens[i_token].p;
4375
4544
  }
4376
4545
 
4377
4546
  // =================================================================================================
@@ -4382,6 +4551,15 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int
4382
4551
  //
4383
4552
 
4384
4553
  WHISPER_API int whisper_bench_memcpy(int n_threads) {
4554
+ fputs(whisper_bench_memcpy_str(n_threads), stderr);
4555
+ return 0;
4556
+ }
4557
+
4558
+ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
4559
+ static std::string s;
4560
+ s = "";
4561
+ char strbuf[256];
4562
+
4385
4563
  ggml_time_init();
4386
4564
 
4387
4565
  size_t n = 50;
@@ -4411,7 +4589,8 @@ WHISPER_API int whisper_bench_memcpy(int n_threads) {
4411
4589
  src[0] = rand();
4412
4590
  }
4413
4591
 
4414
- fprintf(stderr, "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
4592
+ snprintf(strbuf, sizeof(strbuf), "memcpy: %.2f GB/s\n", (double) (n*size)/(tsum*1024llu*1024llu*1024llu));
4593
+ s += strbuf;
4415
4594
 
4416
4595
  // needed to prevent the compile from optimizing the memcpy away
4417
4596
  {
@@ -4419,16 +4598,26 @@ WHISPER_API int whisper_bench_memcpy(int n_threads) {
4419
4598
 
4420
4599
  for (size_t i = 0; i < size; i++) sum += dst[i];
4421
4600
 
4422
- fprintf(stderr, "sum: %s %f\n", sum == -536870910.00 ? "ok" : "error", sum);
4601
+ snprintf(strbuf, sizeof(strbuf), "sum: %s %f\n", sum == -536870910.00 ? "ok" : "error", sum);
4602
+ s += strbuf;
4423
4603
  }
4424
4604
 
4425
4605
  free(src);
4426
4606
  free(dst);
4427
4607
 
4428
- return 0;
4608
+ return s.c_str();
4429
4609
  }
4430
4610
 
4431
4611
  WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
4612
+ fputs(whisper_bench_ggml_mul_mat_str(n_threads), stderr);
4613
+ return 0;
4614
+ }
4615
+
4616
+ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
4617
+ static std::string s;
4618
+ s = "";
4619
+ char strbuf[256];
4620
+
4432
4621
  ggml_time_init();
4433
4622
 
4434
4623
  const int n_max = 128;
@@ -4504,11 +4693,12 @@ WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
4504
4693
  s = ((2.0*N*N*N*n)/tsum)*1e-9;
4505
4694
  }
4506
4695
 
4507
- fprintf(stderr, "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n",
4696
+ snprintf(strbuf, sizeof(strbuf), "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n",
4508
4697
  N, N, s_fp16, n_fp16, s_fp32, n_fp32);
4698
+ s += strbuf;
4509
4699
  }
4510
4700
 
4511
- return 0;
4701
+ return s.c_str();
4512
4702
  }
4513
4703
 
4514
4704
  // =================================================================================================
@@ -4583,13 +4773,14 @@ static std::vector<float> get_signal_energy(const float * signal, int n_samples,
4583
4773
 
4584
4774
  static void whisper_exp_compute_token_level_timestamps(
4585
4775
  struct whisper_context & ctx,
4776
+ struct whisper_state & state,
4586
4777
  int i_segment,
4587
4778
  float thold_pt,
4588
4779
  float thold_ptsum) {
4589
- auto & segment = ctx.result_all[i_segment];
4780
+ auto & segment = state.result_all[i_segment];
4590
4781
  auto & tokens = segment.tokens;
4591
4782
 
4592
- const int n_samples = ctx.energy.size();
4783
+ const int n_samples = state.energy.size();
4593
4784
 
4594
4785
  if (n_samples == 0) {
4595
4786
  fprintf(stderr, "%s: no signal data available\n", __func__);
@@ -4612,9 +4803,9 @@ static void whisper_exp_compute_token_level_timestamps(
4612
4803
  return;
4613
4804
  }
4614
4805
 
4615
- auto & t_beg = ctx.t_beg;
4616
- auto & t_last = ctx.t_last;
4617
- auto & tid_last = ctx.tid_last;
4806
+ auto & t_beg = state.t_beg;
4807
+ auto & t_last = state.t_last;
4808
+ auto & tid_last = state.tid_last;
4618
4809
 
4619
4810
  for (int j = 0; j < n; ++j) {
4620
4811
  auto & token = tokens[j];
@@ -4737,15 +4928,15 @@ static void whisper_exp_compute_token_level_timestamps(
4737
4928
  float sum = 0.0f;
4738
4929
 
4739
4930
  for (int k = ss0; k < ss1; k++) {
4740
- sum += ctx.energy[k];
4931
+ sum += state.energy[k];
4741
4932
  }
4742
4933
 
4743
4934
  const float thold = 0.5*sum/ns;
4744
4935
 
4745
4936
  {
4746
4937
  int k = s0;
4747
- if (ctx.energy[k] > thold && j > 0) {
4748
- while (k > 0 && ctx.energy[k] > thold) {
4938
+ if (state.energy[k] > thold && j > 0) {
4939
+ while (k > 0 && state.energy[k] > thold) {
4749
4940
  k--;
4750
4941
  }
4751
4942
  tokens[j].t0 = sample_to_timestamp(k);
@@ -4755,7 +4946,7 @@ static void whisper_exp_compute_token_level_timestamps(
4755
4946
  s0 = k;
4756
4947
  }
4757
4948
  } else {
4758
- while (ctx.energy[k] < thold && k < s1) {
4949
+ while (state.energy[k] < thold && k < s1) {
4759
4950
  k++;
4760
4951
  }
4761
4952
  s0 = k;
@@ -4765,8 +4956,8 @@ static void whisper_exp_compute_token_level_timestamps(
4765
4956
 
4766
4957
  {
4767
4958
  int k = s1;
4768
- if (ctx.energy[k] > thold) {
4769
- while (k < n_samples - 1 && ctx.energy[k] > thold) {
4959
+ if (state.energy[k] > thold) {
4960
+ while (k < n_samples - 1 && state.energy[k] > thold) {
4770
4961
  k++;
4771
4962
  }
4772
4963
  tokens[j].t1 = sample_to_timestamp(k);
@@ -4776,7 +4967,7 @@ static void whisper_exp_compute_token_level_timestamps(
4776
4967
  s1 = k;
4777
4968
  }
4778
4969
  } else {
4779
- while (ctx.energy[k] < thold && k > s0) {
4970
+ while (state.energy[k] < thold && k > s0) {
4780
4971
  k--;
4781
4972
  }
4782
4973
  s1 = k;