whisper.rn 0.4.0-rc.4 → 0.4.0-rc.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +5 -0
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
  5. package/android/src/main/java/com/rnwhisper/WhisperContext.java +51 -133
  6. package/android/src/main/jni-utils.h +76 -0
  7. package/android/src/main/jni.cpp +187 -112
  8. package/cpp/README.md +1 -1
  9. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  10. package/cpp/coreml/whisper-encoder.h +4 -0
  11. package/cpp/coreml/whisper-encoder.mm +4 -2
  12. package/cpp/ggml-alloc.c +55 -19
  13. package/cpp/ggml-alloc.h +7 -0
  14. package/cpp/ggml-backend-impl.h +46 -21
  15. package/cpp/ggml-backend.c +563 -156
  16. package/cpp/ggml-backend.h +62 -17
  17. package/cpp/ggml-impl.h +1 -1
  18. package/cpp/ggml-metal-whisper.metal +1010 -253
  19. package/cpp/ggml-metal.h +7 -1
  20. package/cpp/ggml-metal.m +618 -187
  21. package/cpp/ggml-quants.c +64 -59
  22. package/cpp/ggml-quants.h +40 -40
  23. package/cpp/ggml.c +751 -1466
  24. package/cpp/ggml.h +90 -25
  25. package/cpp/rn-audioutils.cpp +68 -0
  26. package/cpp/rn-audioutils.h +14 -0
  27. package/cpp/rn-whisper-log.h +11 -0
  28. package/cpp/rn-whisper.cpp +141 -59
  29. package/cpp/rn-whisper.h +47 -15
  30. package/cpp/whisper.cpp +1635 -928
  31. package/cpp/whisper.h +55 -10
  32. package/ios/RNWhisper.mm +7 -7
  33. package/ios/RNWhisperAudioUtils.h +0 -2
  34. package/ios/RNWhisperAudioUtils.m +0 -56
  35. package/ios/RNWhisperContext.h +3 -11
  36. package/ios/RNWhisperContext.mm +62 -134
  37. package/lib/commonjs/version.json +1 -1
  38. package/lib/module/version.json +1 -1
  39. package/package.json +6 -5
  40. package/src/version.json +1 -1
package/cpp/ggml.c CHANGED
@@ -233,24 +233,6 @@ inline static void * wsp_ggml_aligned_malloc(size_t size) {
233
233
  #define UNUSED WSP_GGML_UNUSED
234
234
  #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
235
235
 
236
- //
237
- // tensor access macros
238
- //
239
-
240
- #define WSP_GGML_TENSOR_UNARY_OP_LOCALS \
241
- WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
242
- WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
243
- WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
244
- WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
245
-
246
- #define WSP_GGML_TENSOR_BINARY_OP_LOCALS \
247
- WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
248
- WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
249
- WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
250
- WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
251
- WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
252
- WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
253
-
254
236
  #if defined(WSP_GGML_USE_ACCELERATE)
255
237
  #include <Accelerate/Accelerate.h>
256
238
  #if defined(WSP_GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions
@@ -455,9 +437,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
455
437
  .blck_size = QK4_0,
456
438
  .type_size = sizeof(block_q4_0),
457
439
  .is_quantized = true,
458
- .to_float = (wsp_ggml_to_float_t) dequantize_row_q4_0,
459
- .from_float = quantize_row_q4_0,
460
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q4_0_reference,
440
+ .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q4_0,
441
+ .from_float = wsp_quantize_row_q4_0,
442
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q4_0_reference,
461
443
  .vec_dot = wsp_ggml_vec_dot_q4_0_q8_0,
462
444
  .vec_dot_type = WSP_GGML_TYPE_Q8_0,
463
445
  },
@@ -466,9 +448,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
466
448
  .blck_size = QK4_1,
467
449
  .type_size = sizeof(block_q4_1),
468
450
  .is_quantized = true,
469
- .to_float = (wsp_ggml_to_float_t) dequantize_row_q4_1,
470
- .from_float = quantize_row_q4_1,
471
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q4_1_reference,
451
+ .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q4_1,
452
+ .from_float = wsp_quantize_row_q4_1,
453
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q4_1_reference,
472
454
  .vec_dot = wsp_ggml_vec_dot_q4_1_q8_1,
473
455
  .vec_dot_type = WSP_GGML_TYPE_Q8_1,
474
456
  },
@@ -499,9 +481,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
499
481
  .blck_size = QK5_0,
500
482
  .type_size = sizeof(block_q5_0),
501
483
  .is_quantized = true,
502
- .to_float = (wsp_ggml_to_float_t) dequantize_row_q5_0,
503
- .from_float = quantize_row_q5_0,
504
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q5_0_reference,
484
+ .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q5_0,
485
+ .from_float = wsp_quantize_row_q5_0,
486
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q5_0_reference,
505
487
  .vec_dot = wsp_ggml_vec_dot_q5_0_q8_0,
506
488
  .vec_dot_type = WSP_GGML_TYPE_Q8_0,
507
489
  },
@@ -510,9 +492,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
510
492
  .blck_size = QK5_1,
511
493
  .type_size = sizeof(block_q5_1),
512
494
  .is_quantized = true,
513
- .to_float = (wsp_ggml_to_float_t) dequantize_row_q5_1,
514
- .from_float = quantize_row_q5_1,
515
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q5_1_reference,
495
+ .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q5_1,
496
+ .from_float = wsp_quantize_row_q5_1,
497
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q5_1_reference,
516
498
  .vec_dot = wsp_ggml_vec_dot_q5_1_q8_1,
517
499
  .vec_dot_type = WSP_GGML_TYPE_Q8_1,
518
500
  },
@@ -521,9 +503,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
521
503
  .blck_size = QK8_0,
522
504
  .type_size = sizeof(block_q8_0),
523
505
  .is_quantized = true,
524
- .to_float = (wsp_ggml_to_float_t) dequantize_row_q8_0,
525
- .from_float = quantize_row_q8_0,
526
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q8_0_reference,
506
+ .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q8_0,
507
+ .from_float = wsp_quantize_row_q8_0,
508
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q8_0_reference,
527
509
  .vec_dot = wsp_ggml_vec_dot_q8_0_q8_0,
528
510
  .vec_dot_type = WSP_GGML_TYPE_Q8_0,
529
511
  },
@@ -532,8 +514,8 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
532
514
  .blck_size = QK8_1,
533
515
  .type_size = sizeof(block_q8_1),
534
516
  .is_quantized = true,
535
- .from_float = quantize_row_q8_1,
536
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q8_1_reference,
517
+ .from_float = wsp_quantize_row_q8_1,
518
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q8_1_reference,
537
519
  .vec_dot_type = WSP_GGML_TYPE_Q8_1,
538
520
  },
539
521
  [WSP_GGML_TYPE_Q2_K] = {
@@ -541,9 +523,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
541
523
  .blck_size = QK_K,
542
524
  .type_size = sizeof(block_q2_K),
543
525
  .is_quantized = true,
544
- .to_float = (wsp_ggml_to_float_t) dequantize_row_q2_K,
545
- .from_float = quantize_row_q2_K,
546
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q2_K_reference,
526
+ .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q2_K,
527
+ .from_float = wsp_quantize_row_q2_K,
528
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q2_K_reference,
547
529
  .vec_dot = wsp_ggml_vec_dot_q2_K_q8_K,
548
530
  .vec_dot_type = WSP_GGML_TYPE_Q8_K,
549
531
  },
@@ -552,9 +534,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
552
534
  .blck_size = QK_K,
553
535
  .type_size = sizeof(block_q3_K),
554
536
  .is_quantized = true,
555
- .to_float = (wsp_ggml_to_float_t) dequantize_row_q3_K,
556
- .from_float = quantize_row_q3_K,
557
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q3_K_reference,
537
+ .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q3_K,
538
+ .from_float = wsp_quantize_row_q3_K,
539
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q3_K_reference,
558
540
  .vec_dot = wsp_ggml_vec_dot_q3_K_q8_K,
559
541
  .vec_dot_type = WSP_GGML_TYPE_Q8_K,
560
542
  },
@@ -563,9 +545,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
563
545
  .blck_size = QK_K,
564
546
  .type_size = sizeof(block_q4_K),
565
547
  .is_quantized = true,
566
- .to_float = (wsp_ggml_to_float_t) dequantize_row_q4_K,
567
- .from_float = quantize_row_q4_K,
568
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q4_K_reference,
548
+ .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q4_K,
549
+ .from_float = wsp_quantize_row_q4_K,
550
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q4_K_reference,
569
551
  .vec_dot = wsp_ggml_vec_dot_q4_K_q8_K,
570
552
  .vec_dot_type = WSP_GGML_TYPE_Q8_K,
571
553
  },
@@ -574,9 +556,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
574
556
  .blck_size = QK_K,
575
557
  .type_size = sizeof(block_q5_K),
576
558
  .is_quantized = true,
577
- .to_float = (wsp_ggml_to_float_t) dequantize_row_q5_K,
578
- .from_float = quantize_row_q5_K,
579
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q5_K_reference,
559
+ .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q5_K,
560
+ .from_float = wsp_quantize_row_q5_K,
561
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q5_K_reference,
580
562
  .vec_dot = wsp_ggml_vec_dot_q5_K_q8_K,
581
563
  .vec_dot_type = WSP_GGML_TYPE_Q8_K,
582
564
  },
@@ -585,9 +567,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
585
567
  .blck_size = QK_K,
586
568
  .type_size = sizeof(block_q6_K),
587
569
  .is_quantized = true,
588
- .to_float = (wsp_ggml_to_float_t) dequantize_row_q6_K,
589
- .from_float = quantize_row_q6_K,
590
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q6_K_reference,
570
+ .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q6_K,
571
+ .from_float = wsp_quantize_row_q6_K,
572
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q6_K_reference,
591
573
  .vec_dot = wsp_ggml_vec_dot_q6_K_q8_K,
592
574
  .vec_dot_type = WSP_GGML_TYPE_Q8_K,
593
575
  },
@@ -596,7 +578,7 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
596
578
  .blck_size = QK_K,
597
579
  .type_size = sizeof(block_q8_K),
598
580
  .is_quantized = true,
599
- .from_float = quantize_row_q8_K,
581
+ .from_float = wsp_quantize_row_q8_K,
600
582
  }
601
583
  };
602
584
 
@@ -1613,6 +1595,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
1613
1595
  "GROUP_NORM",
1614
1596
 
1615
1597
  "MUL_MAT",
1598
+ "MUL_MAT_ID",
1616
1599
  "OUT_PROD",
1617
1600
 
1618
1601
  "SCALE",
@@ -1634,17 +1617,13 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
1634
1617
  "ROPE_BACK",
1635
1618
  "ALIBI",
1636
1619
  "CLAMP",
1637
- "CONV_1D",
1638
- "CONV_1D_STAGE_0",
1639
- "CONV_1D_STAGE_1",
1640
1620
  "CONV_TRANSPOSE_1D",
1641
- "CONV_2D",
1642
- "CONV_2D_STAGE_0",
1643
- "CONV_2D_STAGE_1",
1621
+ "IM2COL",
1644
1622
  "CONV_TRANSPOSE_2D",
1645
1623
  "POOL_1D",
1646
1624
  "POOL_2D",
1647
1625
  "UPSCALE",
1626
+ "ARGSORT",
1648
1627
 
1649
1628
  "FLASH_ATTN",
1650
1629
  "FLASH_FF",
@@ -1671,7 +1650,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
1671
1650
  "CROSS_ENTROPY_LOSS_BACK",
1672
1651
  };
1673
1652
 
1674
- static_assert(WSP_GGML_OP_COUNT == 73, "WSP_GGML_OP_COUNT != 73");
1653
+ static_assert(WSP_GGML_OP_COUNT == 70, "WSP_GGML_OP_COUNT != 70");
1675
1654
 
1676
1655
  static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
1677
1656
  "none",
@@ -1700,6 +1679,7 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
1700
1679
  "group_norm(x)",
1701
1680
 
1702
1681
  "X*Y",
1682
+ "X[i]*Y",
1703
1683
  "X*Y",
1704
1684
 
1705
1685
  "x*v",
@@ -1721,17 +1701,13 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
1721
1701
  "rope_back(x)",
1722
1702
  "alibi(x)",
1723
1703
  "clamp(x)",
1724
- "conv_1d(x)",
1725
- "conv_1d_stage_0(x)",
1726
- "conv_1d_stage_1(x)",
1727
1704
  "conv_transpose_1d(x)",
1728
- "conv_2d(x)",
1729
- "conv_2d_stage_0(x)",
1730
- "conv_2d_stage_1(x)",
1705
+ "im2col(x)",
1731
1706
  "conv_transpose_2d(x)",
1732
1707
  "pool_1d(x)",
1733
1708
  "pool_2d(x)",
1734
1709
  "upscale(x)",
1710
+ "argsort(x)",
1735
1711
 
1736
1712
  "flash_attn(x)",
1737
1713
  "flash_ff(x)",
@@ -1758,10 +1734,28 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
1758
1734
  "cross_entropy_loss_back(x,y)",
1759
1735
  };
1760
1736
 
1761
- static_assert(WSP_GGML_OP_COUNT == 73, "WSP_GGML_OP_COUNT != 73");
1737
+ static_assert(WSP_GGML_OP_COUNT == 70, "WSP_GGML_OP_COUNT != 70");
1762
1738
 
1763
1739
  static_assert(WSP_GGML_OP_POOL_COUNT == 2, "WSP_GGML_OP_POOL_COUNT != 2");
1764
1740
 
1741
+
1742
+ static const char * WSP_GGML_UNARY_OP_NAME[WSP_GGML_UNARY_OP_COUNT] = {
1743
+ "ABS",
1744
+ "SGN",
1745
+ "NEG",
1746
+ "STEP",
1747
+ "TANH",
1748
+ "ELU",
1749
+ "RELU",
1750
+ "GELU",
1751
+ "GELU_QUICK",
1752
+ "SILU",
1753
+ "LEAKY",
1754
+ };
1755
+
1756
+ static_assert(WSP_GGML_UNARY_OP_COUNT == 11, "WSP_GGML_UNARY_OP_COUNT != 11");
1757
+
1758
+
1765
1759
  static_assert(sizeof(struct wsp_ggml_object)%WSP_GGML_MEM_ALIGN == 0, "wsp_ggml_object size must be a multiple of WSP_GGML_MEM_ALIGN");
1766
1760
  static_assert(sizeof(struct wsp_ggml_tensor)%WSP_GGML_MEM_ALIGN == 0, "wsp_ggml_tensor size must be a multiple of WSP_GGML_MEM_ALIGN");
1767
1761
 
@@ -1781,18 +1775,13 @@ static void wsp_ggml_setup_op_has_task_pass(void) {
1781
1775
 
1782
1776
  p[WSP_GGML_OP_ACC ] = true;
1783
1777
  p[WSP_GGML_OP_MUL_MAT ] = true;
1778
+ p[WSP_GGML_OP_MUL_MAT_ID ] = true;
1784
1779
  p[WSP_GGML_OP_OUT_PROD ] = true;
1785
1780
  p[WSP_GGML_OP_SET ] = true;
1786
1781
  p[WSP_GGML_OP_GET_ROWS_BACK ] = true;
1787
1782
  p[WSP_GGML_OP_DIAG_MASK_INF ] = true;
1788
1783
  p[WSP_GGML_OP_DIAG_MASK_ZERO ] = true;
1789
- p[WSP_GGML_OP_CONV_1D ] = true;
1790
- p[WSP_GGML_OP_CONV_1D_STAGE_0 ] = true;
1791
- p[WSP_GGML_OP_CONV_1D_STAGE_1 ] = true;
1792
1784
  p[WSP_GGML_OP_CONV_TRANSPOSE_1D ] = true;
1793
- p[WSP_GGML_OP_CONV_2D ] = true;
1794
- p[WSP_GGML_OP_CONV_2D_STAGE_0 ] = true;
1795
- p[WSP_GGML_OP_CONV_2D_STAGE_1 ] = true;
1796
1785
  p[WSP_GGML_OP_CONV_TRANSPOSE_2D ] = true;
1797
1786
  p[WSP_GGML_OP_FLASH_ATTN_BACK ] = true;
1798
1787
  p[WSP_GGML_OP_CROSS_ENTROPY_LOSS ] = true;
@@ -2039,6 +2028,20 @@ const char * wsp_ggml_op_symbol(enum wsp_ggml_op op) {
2039
2028
  return WSP_GGML_OP_SYMBOL[op];
2040
2029
  }
2041
2030
 
2031
+ const char * wsp_ggml_unary_op_name(enum wsp_ggml_unary_op op) {
2032
+ return WSP_GGML_UNARY_OP_NAME[op];
2033
+ }
2034
+
2035
+ const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t) {
2036
+ if (t->op == WSP_GGML_OP_UNARY) {
2037
+ enum wsp_ggml_unary_op uop = wsp_ggml_get_unary_op(t);
2038
+ return wsp_ggml_unary_op_name(uop);
2039
+ }
2040
+ else {
2041
+ return wsp_ggml_op_name(t->op);
2042
+ }
2043
+ }
2044
+
2042
2045
  size_t wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor) {
2043
2046
  return wsp_ggml_type_size(tensor->type);
2044
2047
  }
@@ -3170,9 +3173,7 @@ static struct wsp_ggml_tensor * wsp_ggml_add_impl(
3170
3173
  struct wsp_ggml_tensor * a,
3171
3174
  struct wsp_ggml_tensor * b,
3172
3175
  bool inplace) {
3173
- // TODO: support less-strict constraint
3174
- // WSP_GGML_ASSERT(wsp_ggml_can_repeat(b, a));
3175
- WSP_GGML_ASSERT(wsp_ggml_can_repeat_rows(b, a));
3176
+ WSP_GGML_ASSERT(wsp_ggml_can_repeat(b, a));
3176
3177
 
3177
3178
  bool is_node = false;
3178
3179
 
@@ -3387,9 +3388,7 @@ static struct wsp_ggml_tensor * wsp_ggml_mul_impl(
3387
3388
  struct wsp_ggml_tensor * a,
3388
3389
  struct wsp_ggml_tensor * b,
3389
3390
  bool inplace) {
3390
- // TODO: support less-strict constraint
3391
- // WSP_GGML_ASSERT(wsp_ggml_can_repeat(b, a));
3392
- WSP_GGML_ASSERT(wsp_ggml_can_repeat_rows(b, a));
3391
+ WSP_GGML_ASSERT(wsp_ggml_can_repeat(b, a));
3393
3392
 
3394
3393
  bool is_node = false;
3395
3394
 
@@ -3434,7 +3433,7 @@ static struct wsp_ggml_tensor * wsp_ggml_div_impl(
3434
3433
  struct wsp_ggml_tensor * a,
3435
3434
  struct wsp_ggml_tensor * b,
3436
3435
  bool inplace) {
3437
- WSP_GGML_ASSERT(wsp_ggml_are_same_shape(a, b));
3436
+ WSP_GGML_ASSERT(wsp_ggml_can_repeat(b, a));
3438
3437
 
3439
3438
  bool is_node = false;
3440
3439
 
@@ -4072,6 +4071,49 @@ struct wsp_ggml_tensor * wsp_ggml_mul_mat(
4072
4071
  return result;
4073
4072
  }
4074
4073
 
4074
+ // wsp_ggml_mul_mat_id
4075
+
4076
+ struct wsp_ggml_tensor * wsp_ggml_mul_mat_id(
4077
+ struct wsp_ggml_context * ctx,
4078
+ struct wsp_ggml_tensor * as[],
4079
+ struct wsp_ggml_tensor * ids,
4080
+ int id,
4081
+ struct wsp_ggml_tensor * b) {
4082
+
4083
+ int64_t n_as = ids->ne[0];
4084
+
4085
+ WSP_GGML_ASSERT(ids->type == WSP_GGML_TYPE_I32);
4086
+ WSP_GGML_ASSERT(wsp_ggml_is_vector(ids));
4087
+ WSP_GGML_ASSERT(n_as > 0 && n_as <= WSP_GGML_MAX_SRC - 2);
4088
+ WSP_GGML_ASSERT(id >= 0 && id < n_as);
4089
+
4090
+ bool is_node = false;
4091
+
4092
+ if (as[0]->grad || b->grad) {
4093
+ is_node = true;
4094
+ }
4095
+
4096
+ const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] };
4097
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);
4098
+
4099
+ wsp_ggml_set_op_params_i32(result, 0, id);
4100
+
4101
+ result->op = WSP_GGML_OP_MUL_MAT_ID;
4102
+ result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
4103
+ result->src[0] = ids;
4104
+ result->src[1] = b;
4105
+
4106
+ for (int64_t i = 0; i < n_as; i++) {
4107
+ struct wsp_ggml_tensor * a = as[i];
4108
+ WSP_GGML_ASSERT(wsp_ggml_are_same_shape(as[0], a));
4109
+ WSP_GGML_ASSERT(wsp_ggml_can_mul_mat(a, b));
4110
+ WSP_GGML_ASSERT(!wsp_ggml_is_transposed(a));
4111
+ result->src[i + 2] = a;
4112
+ }
4113
+
4114
+ return result;
4115
+ }
4116
+
4075
4117
  // wsp_ggml_out_prod
4076
4118
 
4077
4119
  struct wsp_ggml_tensor * wsp_ggml_out_prod(
@@ -4225,7 +4267,7 @@ struct wsp_ggml_tensor * wsp_ggml_set_2d_inplace(
4225
4267
  struct wsp_ggml_tensor * b,
4226
4268
  size_t nb1,
4227
4269
  size_t offset) {
4228
- return wsp_ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false);
4270
+ return wsp_ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true);
4229
4271
  }
4230
4272
 
4231
4273
  // wsp_ggml_cpy
@@ -4842,7 +4884,17 @@ struct wsp_ggml_tensor * wsp_ggml_diag_mask_zero_inplace(
4842
4884
  static struct wsp_ggml_tensor * wsp_ggml_soft_max_impl(
4843
4885
  struct wsp_ggml_context * ctx,
4844
4886
  struct wsp_ggml_tensor * a,
4887
+ struct wsp_ggml_tensor * mask,
4888
+ float scale,
4845
4889
  bool inplace) {
4890
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(a));
4891
+ if (mask) {
4892
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(mask));
4893
+ WSP_GGML_ASSERT(mask->ne[2] == 1);
4894
+ WSP_GGML_ASSERT(mask->ne[3] == 1);
4895
+ WSP_GGML_ASSERT(wsp_ggml_can_repeat_rows(mask, a));
4896
+ }
4897
+
4846
4898
  bool is_node = false;
4847
4899
 
4848
4900
  if (a->grad) {
@@ -4851,9 +4903,13 @@ static struct wsp_ggml_tensor * wsp_ggml_soft_max_impl(
4851
4903
 
4852
4904
  struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a);
4853
4905
 
4906
+ float params[] = { scale };
4907
+ wsp_ggml_set_op_params(result, params, sizeof(params));
4908
+
4854
4909
  result->op = WSP_GGML_OP_SOFT_MAX;
4855
4910
  result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
4856
4911
  result->src[0] = a;
4912
+ result->src[1] = mask;
4857
4913
 
4858
4914
  return result;
4859
4915
  }
@@ -4861,13 +4917,21 @@ static struct wsp_ggml_tensor * wsp_ggml_soft_max_impl(
4861
4917
  struct wsp_ggml_tensor * wsp_ggml_soft_max(
4862
4918
  struct wsp_ggml_context * ctx,
4863
4919
  struct wsp_ggml_tensor * a) {
4864
- return wsp_ggml_soft_max_impl(ctx, a, false);
4920
+ return wsp_ggml_soft_max_impl(ctx, a, NULL, 1.0f, false);
4865
4921
  }
4866
4922
 
4867
4923
  struct wsp_ggml_tensor * wsp_ggml_soft_max_inplace(
4868
4924
  struct wsp_ggml_context * ctx,
4869
4925
  struct wsp_ggml_tensor * a) {
4870
- return wsp_ggml_soft_max_impl(ctx, a, true);
4926
+ return wsp_ggml_soft_max_impl(ctx, a, NULL, 1.0f, true);
4927
+ }
4928
+
4929
+ struct wsp_ggml_tensor * wsp_ggml_soft_max_ext(
4930
+ struct wsp_ggml_context * ctx,
4931
+ struct wsp_ggml_tensor * a,
4932
+ struct wsp_ggml_tensor * mask,
4933
+ float scale) {
4934
+ return wsp_ggml_soft_max_impl(ctx, a, mask, scale, false);
4871
4935
  }
4872
4936
 
4873
4937
  // wsp_ggml_soft_max_back
@@ -5040,8 +5104,13 @@ struct wsp_ggml_tensor * wsp_ggml_rope_back(
5040
5104
  int n_dims,
5041
5105
  int mode,
5042
5106
  int n_ctx,
5107
+ int n_orig_ctx,
5043
5108
  float freq_base,
5044
5109
  float freq_scale,
5110
+ float ext_factor,
5111
+ float attn_factor,
5112
+ float beta_fast,
5113
+ float beta_slow,
5045
5114
  float xpos_base,
5046
5115
  bool xpos_down) {
5047
5116
  WSP_GGML_ASSERT(wsp_ggml_is_vector(b));
@@ -5058,11 +5127,15 @@ struct wsp_ggml_tensor * wsp_ggml_rope_back(
5058
5127
 
5059
5128
  struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a);
5060
5129
 
5061
- int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
5062
- memcpy(params + 4, &freq_base, sizeof(float));
5063
- memcpy(params + 5, &freq_scale, sizeof(float));
5064
- memcpy(params + 6, &xpos_base, sizeof(float));
5065
- memcpy(params + 7, &xpos_down, sizeof(bool));
5130
+ int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
5131
+ memcpy(params + 5, &freq_base, sizeof(float));
5132
+ memcpy(params + 6, &freq_scale, sizeof(float));
5133
+ memcpy(params + 7, &ext_factor, sizeof(float));
5134
+ memcpy(params + 8, &attn_factor, sizeof(float));
5135
+ memcpy(params + 9, &beta_fast, sizeof(float));
5136
+ memcpy(params + 10, &beta_slow, sizeof(float));
5137
+ memcpy(params + 11, &xpos_base, sizeof(float));
5138
+ memcpy(params + 12, &xpos_down, sizeof(bool));
5066
5139
  wsp_ggml_set_op_params(result, params, sizeof(params));
5067
5140
 
5068
5141
  result->op = WSP_GGML_OP_ROPE_BACK;
@@ -5137,82 +5210,6 @@ static int64_t wsp_ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, in
5137
5210
  return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
5138
5211
  }
5139
5212
 
5140
- // im2col: [N, IC, IL] => [N, OL, IC*K]
5141
- // a: [OC,IC, K]
5142
- // b: [N, IC, IL]
5143
- // result: [N, OL, IC*K]
5144
- static struct wsp_ggml_tensor * wsp_ggml_conv_1d_stage_0(
5145
- struct wsp_ggml_context * ctx,
5146
- struct wsp_ggml_tensor * a,
5147
- struct wsp_ggml_tensor * b,
5148
- int s0,
5149
- int p0,
5150
- int d0) {
5151
- WSP_GGML_ASSERT(a->ne[1] == b->ne[1]);
5152
- bool is_node = false;
5153
-
5154
- if (a->grad || b->grad) {
5155
- WSP_GGML_ASSERT(false); // TODO: implement backward
5156
- is_node = true;
5157
- }
5158
-
5159
- const int64_t OL = wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
5160
-
5161
- const int64_t ne[4] = {
5162
- a->ne[1] * a->ne[0],
5163
- OL,
5164
- b->ne[2],
5165
- 1,
5166
- };
5167
- struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F16, 4, ne);
5168
-
5169
- int32_t params[] = { s0, p0, d0 };
5170
- wsp_ggml_set_op_params(result, params, sizeof(params));
5171
-
5172
- result->op = WSP_GGML_OP_CONV_1D_STAGE_0;
5173
- result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
5174
- result->src[0] = a;
5175
- result->src[1] = b;
5176
-
5177
- return result;
5178
- }
5179
-
5180
- // wsp_ggml_conv_1d_stage_1
5181
-
5182
- // gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
5183
- // a: [OC, IC, K]
5184
- // b: [N, OL, IC * K]
5185
- // result: [N, OC, OL]
5186
- static struct wsp_ggml_tensor * wsp_ggml_conv_1d_stage_1(
5187
- struct wsp_ggml_context * ctx,
5188
- struct wsp_ggml_tensor * a,
5189
- struct wsp_ggml_tensor * b) {
5190
-
5191
- bool is_node = false;
5192
-
5193
- if (a->grad || b->grad) {
5194
- WSP_GGML_ASSERT(false); // TODO: implement backward
5195
- is_node = true;
5196
- }
5197
-
5198
- const int64_t ne[4] = {
5199
- b->ne[1],
5200
- a->ne[2],
5201
- b->ne[2],
5202
- 1,
5203
- };
5204
- struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne);
5205
-
5206
- result->op = WSP_GGML_OP_CONV_1D_STAGE_1;
5207
- result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
5208
- result->src[0] = a;
5209
- result->src[1] = b;
5210
-
5211
- return result;
5212
- }
5213
-
5214
- // wsp_ggml_conv_1d
5215
-
5216
5213
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d(
5217
5214
  struct wsp_ggml_context * ctx,
5218
5215
  struct wsp_ggml_tensor * a,
@@ -5220,43 +5217,17 @@ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d(
5220
5217
  int s0,
5221
5218
  int p0,
5222
5219
  int d0) {
5223
- struct wsp_ggml_tensor * result = wsp_ggml_conv_1d_stage_0(ctx, a, b, s0, p0, d0);
5224
- result = wsp_ggml_conv_1d_stage_1(ctx, a, result);
5225
- return result;
5226
- }
5220
+ struct wsp_ggml_tensor * im2col = wsp_ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K]
5227
5221
 
5228
- // WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d(
5229
- // struct wsp_ggml_context * ctx,
5230
- // struct wsp_ggml_tensor * a,
5231
- // struct wsp_ggml_tensor * b,
5232
- // int s0,
5233
- // int p0,
5234
- // int d0) {
5235
- // WSP_GGML_ASSERT(wsp_ggml_is_matrix(b));
5236
- // WSP_GGML_ASSERT(a->ne[1] == b->ne[1]);
5237
- // bool is_node = false;
5222
+ struct wsp_ggml_tensor * result =
5223
+ wsp_ggml_mul_mat(ctx,
5224
+ wsp_ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
5225
+ wsp_ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])); // [OC,IC, K] => [OC, IC * K]
5238
5226
 
5239
- // if (a->grad || b->grad) {
5240
- // WSP_GGML_ASSERT(false); // TODO: implement backward
5241
- // is_node = true;
5242
- // }
5227
+ result = wsp_ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
5243
5228
 
5244
- // const int64_t ne[4] = {
5245
- // wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
5246
- // a->ne[2], 1, 1,
5247
- // };
5248
- // struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 2, ne);
5249
-
5250
- // int32_t params[] = { s0, p0, d0 };
5251
- // wsp_ggml_set_op_params(result, params, sizeof(params));
5252
-
5253
- // result->op = WSP_GGML_OP_CONV_1D;
5254
- // result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
5255
- // result->src[0] = a;
5256
- // result->src[1] = b;
5257
-
5258
- // return result;
5259
- // }
5229
+ return result;
5230
+ }
5260
5231
 
5261
5232
  // wsp_ggml_conv_1d_ph
5262
5233
 
@@ -5319,7 +5290,7 @@ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_transpose_1d(
5319
5290
  // a: [OC,IC, KH, KW]
5320
5291
  // b: [N, IC, IH, IW]
5321
5292
  // result: [N, OH, OW, IC*KH*KW]
5322
- static struct wsp_ggml_tensor * wsp_ggml_conv_2d_stage_0(
5293
+ struct wsp_ggml_tensor * wsp_ggml_im2col(
5323
5294
  struct wsp_ggml_context * ctx,
5324
5295
  struct wsp_ggml_tensor * a,
5325
5296
  struct wsp_ggml_tensor * b,
@@ -5328,9 +5299,14 @@ static struct wsp_ggml_tensor * wsp_ggml_conv_2d_stage_0(
5328
5299
  int p0,
5329
5300
  int p1,
5330
5301
  int d0,
5331
- int d1) {
5302
+ int d1,
5303
+ bool is_2D) {
5332
5304
 
5333
- WSP_GGML_ASSERT(a->ne[2] == b->ne[2]);
5305
+ if(is_2D) {
5306
+ WSP_GGML_ASSERT(a->ne[2] == b->ne[2]);
5307
+ } else {
5308
+ WSP_GGML_ASSERT(a->ne[1] == b->ne[1]);
5309
+ }
5334
5310
  bool is_node = false;
5335
5311
 
5336
5312
  if (a->grad || b->grad) {
@@ -5338,81 +5314,51 @@ static struct wsp_ggml_tensor * wsp_ggml_conv_2d_stage_0(
5338
5314
  is_node = true;
5339
5315
  }
5340
5316
 
5341
- const int64_t OH = wsp_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
5342
- const int64_t OW = wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
5317
+ const int64_t OH = is_2D ? wsp_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
5318
+ const int64_t OW = wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
5343
5319
 
5344
5320
  const int64_t ne[4] = {
5345
- a->ne[2] * a->ne[1] * a->ne[0],
5321
+ is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
5346
5322
  OW,
5347
- OH,
5348
- b->ne[3],
5323
+ is_2D ? OH : b->ne[2],
5324
+ is_2D ? b->ne[3] : 1,
5349
5325
  };
5350
- struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F16, 4, ne);
5351
5326
 
5352
- int32_t params[] = { s0, s1, p0, p1, d0, d1 };
5327
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F16, 4, ne);
5328
+ int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
5353
5329
  wsp_ggml_set_op_params(result, params, sizeof(params));
5354
5330
 
5355
- result->op = WSP_GGML_OP_CONV_2D_STAGE_0;
5356
- result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
5357
- result->src[0] = a;
5358
- result->src[1] = b;
5359
-
5360
- return result;
5361
-
5362
- }
5363
-
5364
- // gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
5365
- // a: [OC, IC, KH, KW]
5366
- // b: [N, OH, OW, IC * KH * KW]
5367
- // result: [N, OC, OH, OW]
5368
- static struct wsp_ggml_tensor * wsp_ggml_conv_2d_stage_1(
5369
- struct wsp_ggml_context * ctx,
5370
- struct wsp_ggml_tensor * a,
5371
- struct wsp_ggml_tensor * b) {
5372
-
5373
- bool is_node = false;
5374
-
5375
- if (a->grad || b->grad) {
5376
- WSP_GGML_ASSERT(false); // TODO: implement backward
5377
- is_node = true;
5378
- }
5379
-
5380
- const int64_t ne[4] = {
5381
- b->ne[1],
5382
- b->ne[2],
5383
- a->ne[3],
5384
- b->ne[3],
5385
- };
5386
- struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne);
5387
-
5388
- result->op = WSP_GGML_OP_CONV_2D_STAGE_1;
5331
+ result->op = WSP_GGML_OP_IM2COL;
5389
5332
  result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
5390
5333
  result->src[0] = a;
5391
5334
  result->src[1] = b;
5392
5335
 
5393
5336
  return result;
5394
-
5395
5337
  }
5396
5338
 
5397
5339
  // a: [OC,IC, KH, KW]
5398
5340
  // b: [N, IC, IH, IW]
5399
5341
  // result: [N, OC, OH, OW]
5400
5342
  struct wsp_ggml_tensor * wsp_ggml_conv_2d(
5401
- struct wsp_ggml_context * ctx,
5402
- struct wsp_ggml_tensor * a,
5403
- struct wsp_ggml_tensor * b,
5404
- int s0,
5405
- int s1,
5406
- int p0,
5407
- int p1,
5408
- int d0,
5409
- int d1) {
5343
+ struct wsp_ggml_context * ctx,
5344
+ struct wsp_ggml_tensor * a,
5345
+ struct wsp_ggml_tensor * b,
5346
+ int s0,
5347
+ int s1,
5348
+ int p0,
5349
+ int p1,
5350
+ int d0,
5351
+ int d1) {
5352
+ struct wsp_ggml_tensor * im2col = wsp_ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]
5410
5353
 
5411
- struct wsp_ggml_tensor * result = wsp_ggml_conv_2d_stage_0(ctx, a, b, s0, s1, p0, p1, d0, d1); // [N, OH, OW, IC * KH * KW]
5412
- result = wsp_ggml_conv_2d_stage_1(ctx, a, result);
5354
+ struct wsp_ggml_tensor * result =
5355
+ wsp_ggml_mul_mat(ctx,
5356
+ wsp_ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
5357
+ wsp_ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OC,IC, KH, KW] => [OC, IC * KH * KW]
5413
5358
 
5414
- return result;
5359
+ result = wsp_ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[3], im2col->ne[3]); // [N, OC, OH, OW]
5415
5360
 
5361
+ return result;
5416
5362
  }
5417
5363
 
5418
5364
  // wsp_ggml_conv_2d_sk_p0
@@ -5580,6 +5526,43 @@ struct wsp_ggml_tensor * wsp_ggml_upscale(
5580
5526
  return wsp_ggml_upscale_impl(ctx, a, scale_factor);
5581
5527
  }
5582
5528
 
5529
+ // wsp_ggml_argsort
5530
+
5531
+ struct wsp_ggml_tensor * wsp_ggml_argsort(
5532
+ struct wsp_ggml_context * ctx,
5533
+ struct wsp_ggml_tensor * a,
5534
+ enum wsp_ggml_sort_order order) {
5535
+ bool is_node = false;
5536
+
5537
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_I32, a->n_dims, a->ne);
5538
+
5539
+ wsp_ggml_set_op_params_i32(result, 0, (int32_t) order);
5540
+
5541
+ result->op = WSP_GGML_OP_ARGSORT;
5542
+ result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
5543
+ result->src[0] = a;
5544
+
5545
+ return result;
5546
+ }
5547
+
5548
+ // wsp_ggml_top_k
5549
+
5550
+ struct wsp_ggml_tensor * wsp_ggml_top_k(
5551
+ struct wsp_ggml_context * ctx,
5552
+ struct wsp_ggml_tensor * a,
5553
+ int k) {
5554
+ WSP_GGML_ASSERT(a->ne[0] >= k);
5555
+
5556
+ struct wsp_ggml_tensor * result = wsp_ggml_argsort(ctx, a, WSP_GGML_SORT_DESC);
5557
+
5558
+ result = wsp_ggml_view_4d(ctx, result,
5559
+ k, result->ne[1], result->ne[2], result->ne[3],
5560
+ result->nb[1], result->nb[2], result->nb[3],
5561
+ 0);
5562
+
5563
+ return result;
5564
+ }
5565
+
5583
5566
  // wsp_ggml_flash_attn
5584
5567
 
5585
5568
  struct wsp_ggml_tensor * wsp_ggml_flash_attn(
@@ -6472,7 +6455,7 @@ static void wsp_ggml_compute_forward_dup_f16(
6472
6455
  }
6473
6456
  }
6474
6457
  } else if (type_traits[dst->type].from_float) {
6475
- wsp_ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
6458
+ wsp_ggml_from_float_t const wsp_quantize_row_q = type_traits[dst->type].from_float;
6476
6459
  float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
6477
6460
 
6478
6461
  size_t id = 0;
@@ -6489,7 +6472,7 @@ static void wsp_ggml_compute_forward_dup_f16(
6489
6472
  src0_f32[i00] = WSP_GGML_FP16_TO_FP32(src0_ptr[i00]);
6490
6473
  }
6491
6474
 
6492
- quantize_row_q(src0_f32, dst_ptr + id, ne00);
6475
+ wsp_quantize_row_q(src0_f32, dst_ptr + id, ne00);
6493
6476
  id += rs;
6494
6477
  }
6495
6478
  id += rs * (ne01 - ir1);
@@ -6725,7 +6708,7 @@ static void wsp_ggml_compute_forward_dup_f32(
6725
6708
  }
6726
6709
  }
6727
6710
  } else if (type_traits[dst->type].from_float) {
6728
- wsp_ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
6711
+ wsp_ggml_from_float_t const wsp_quantize_row_q = type_traits[dst->type].from_float;
6729
6712
 
6730
6713
  size_t id = 0;
6731
6714
  size_t rs = nb0 * (ne00 / wsp_ggml_blck_size(dst->type));
@@ -6736,7 +6719,7 @@ static void wsp_ggml_compute_forward_dup_f32(
6736
6719
  id += rs * ir0;
6737
6720
  for (int i01 = ir0; i01 < ir1; i01++) {
6738
6721
  const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
6739
- quantize_row_q(src0_ptr, dst_ptr + id, ne00);
6722
+ wsp_quantize_row_q(src0_ptr, dst_ptr + id, ne00);
6740
6723
  id += rs;
6741
6724
  }
6742
6725
  id += rs * (ne01 - ir1);
@@ -6939,7 +6922,7 @@ static void wsp_ggml_compute_forward_add_f32(
6939
6922
  const struct wsp_ggml_tensor * src0,
6940
6923
  const struct wsp_ggml_tensor * src1,
6941
6924
  struct wsp_ggml_tensor * dst) {
6942
- WSP_GGML_ASSERT(wsp_ggml_can_repeat_rows(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
6925
+ WSP_GGML_ASSERT(wsp_ggml_can_repeat(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
6943
6926
 
6944
6927
  if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
6945
6928
  return;
@@ -6972,16 +6955,19 @@ static void wsp_ggml_compute_forward_add_f32(
6972
6955
  const int64_t i13 = i03 % ne13;
6973
6956
  const int64_t i12 = i02 % ne12;
6974
6957
  const int64_t i11 = i01 % ne11;
6958
+ const int64_t nr0 = ne00 / ne10;
6975
6959
 
6976
6960
  float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
6977
6961
  float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
6978
6962
  float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
6979
6963
 
6964
+ for (int64_t r = 0; r < nr0; ++r) {
6980
6965
  #ifdef WSP_GGML_USE_ACCELERATE
6981
- vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
6966
+ vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
6982
6967
  #else
6983
- wsp_ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
6968
+ wsp_ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
6984
6969
  #endif
6970
+ }
6985
6971
  }
6986
6972
  } else {
6987
6973
  // src1 is not contiguous
@@ -6998,8 +6984,9 @@ static void wsp_ggml_compute_forward_add_f32(
6998
6984
  float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
6999
6985
  float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
7000
6986
 
7001
- for (int i0 = 0; i0 < ne0; i0++) {
7002
- float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
6987
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
6988
+ const int64_t i10 = i0 % ne10;
6989
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
7003
6990
 
7004
6991
  dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
7005
6992
  }
@@ -7158,8 +7145,8 @@ static void wsp_ggml_compute_forward_add_q_f32(
7158
7145
 
7159
7146
  const enum wsp_ggml_type type = src0->type;
7160
7147
  const enum wsp_ggml_type dtype = dst->type;
7161
- wsp_ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
7162
- wsp_ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float;
7148
+ wsp_ggml_to_float_t const wsp_dewsp_quantize_row_q = type_traits[type].to_float;
7149
+ wsp_ggml_from_float_t const wsp_quantize_row_q = type_traits[dtype].from_float;
7163
7150
 
7164
7151
  // we don't support permuted src0 or src1
7165
7152
  WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(type));
@@ -7204,12 +7191,12 @@ static void wsp_ggml_compute_forward_add_q_f32(
7204
7191
  assert(ne00 % 32 == 0);
7205
7192
 
7206
7193
  // unquantize row from src0 to temp buffer
7207
- dequantize_row_q(src0_row, wdata, ne00);
7194
+ wsp_dewsp_quantize_row_q(src0_row, wdata, ne00);
7208
7195
  // add src1
7209
7196
  wsp_ggml_vec_acc_f32(ne00, wdata, src1_row);
7210
7197
  // quantize row to dst
7211
- if (quantize_row_q != NULL) {
7212
- quantize_row_q(wdata, dst_row, ne00);
7198
+ if (wsp_quantize_row_q != NULL) {
7199
+ wsp_quantize_row_q(wdata, dst_row, ne00);
7213
7200
  } else {
7214
7201
  memcpy(dst_row, wdata, ne0*nb0);
7215
7202
  }
@@ -7435,8 +7422,8 @@ static void wsp_ggml_compute_forward_add1_q_f32(
7435
7422
  WSP_GGML_TENSOR_UNARY_OP_LOCALS
7436
7423
 
7437
7424
  const enum wsp_ggml_type type = src0->type;
7438
- wsp_ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
7439
- wsp_ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
7425
+ wsp_ggml_to_float_t const wsp_dewsp_quantize_row_q = type_traits[type].to_float;
7426
+ wsp_ggml_from_float_t const wsp_quantize_row_q = type_traits[type].from_float;
7440
7427
 
7441
7428
  // we don't support permuted src0
7442
7429
  WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(type));
@@ -7471,11 +7458,11 @@ static void wsp_ggml_compute_forward_add1_q_f32(
7471
7458
  assert(ne0 % 32 == 0);
7472
7459
 
7473
7460
  // unquantize row from src0 to temp buffer
7474
- dequantize_row_q(src0_row, wdata, ne0);
7461
+ wsp_dewsp_quantize_row_q(src0_row, wdata, ne0);
7475
7462
  // add src1
7476
7463
  wsp_ggml_vec_acc1_f32(ne0, wdata, v);
7477
7464
  // quantize row to dst
7478
- quantize_row_q(wdata, dst_row, ne0);
7465
+ wsp_quantize_row_q(wdata, dst_row, ne0);
7479
7466
  }
7480
7467
  }
7481
7468
 
@@ -7719,7 +7706,7 @@ static void wsp_ggml_compute_forward_mul_f32(
7719
7706
  const struct wsp_ggml_tensor * src0,
7720
7707
  const struct wsp_ggml_tensor * src1,
7721
7708
  struct wsp_ggml_tensor * dst) {
7722
- WSP_GGML_ASSERT(wsp_ggml_can_repeat_rows(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
7709
+ WSP_GGML_ASSERT(wsp_ggml_can_repeat(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
7723
7710
 
7724
7711
  if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
7725
7712
  return;
@@ -7742,7 +7729,6 @@ static void wsp_ggml_compute_forward_mul_f32(
7742
7729
 
7743
7730
  WSP_GGML_ASSERT( nb0 == sizeof(float));
7744
7731
  WSP_GGML_ASSERT(nb00 == sizeof(float));
7745
- WSP_GGML_ASSERT(ne00 == ne10);
7746
7732
 
7747
7733
  if (nb10 == sizeof(float)) {
7748
7734
  for (int64_t ir = ith; ir < nr; ir += nth) {
@@ -7754,20 +7740,21 @@ static void wsp_ggml_compute_forward_mul_f32(
7754
7740
  const int64_t i13 = i03 % ne13;
7755
7741
  const int64_t i12 = i02 % ne12;
7756
7742
  const int64_t i11 = i01 % ne11;
7743
+ const int64_t nr0 = ne00 / ne10;
7757
7744
 
7758
7745
  float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
7759
7746
  float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
7760
7747
  float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
7761
7748
 
7749
+ for (int64_t r = 0 ; r < nr0; ++r) {
7762
7750
  #ifdef WSP_GGML_USE_ACCELERATE
7763
- UNUSED(wsp_ggml_vec_mul_f32);
7751
+ UNUSED(wsp_ggml_vec_mul_f32);
7764
7752
 
7765
- vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
7753
+ vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
7766
7754
  #else
7767
- wsp_ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
7755
+ wsp_ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
7768
7756
  #endif
7769
- // }
7770
- // }
7757
+ }
7771
7758
  }
7772
7759
  } else {
7773
7760
  // src1 is not contiguous
@@ -7785,8 +7772,9 @@ static void wsp_ggml_compute_forward_mul_f32(
7785
7772
  float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
7786
7773
  float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
7787
7774
 
7788
- for (int64_t i0 = 0; i0 < ne00; i0++) {
7789
- float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
7775
+ for (int64_t i0 = 0; i0 < ne00; ++i0) {
7776
+ const int64_t i10 = i0 % ne10;
7777
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
7790
7778
 
7791
7779
  dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
7792
7780
  }
@@ -7820,14 +7808,16 @@ static void wsp_ggml_compute_forward_div_f32(
7820
7808
  const struct wsp_ggml_tensor * src0,
7821
7809
  const struct wsp_ggml_tensor * src1,
7822
7810
  struct wsp_ggml_tensor * dst) {
7823
- assert(params->ith == 0);
7824
- assert(wsp_ggml_are_same_shape(src0, src1) && wsp_ggml_are_same_shape(src0, dst));
7811
+ WSP_GGML_ASSERT(wsp_ggml_can_repeat(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
7825
7812
 
7826
7813
  if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
7827
7814
  return;
7828
7815
  }
7829
7816
 
7830
- const int nr = wsp_ggml_nrows(src0);
7817
+ const int ith = params->ith;
7818
+ const int nth = params->nth;
7819
+
7820
+ const int64_t nr = wsp_ggml_nrows(src0);
7831
7821
 
7832
7822
  WSP_GGML_TENSOR_BINARY_OP_LOCALS
7833
7823
 
@@ -7835,41 +7825,50 @@ static void wsp_ggml_compute_forward_div_f32(
7835
7825
  WSP_GGML_ASSERT(nb00 == sizeof(float));
7836
7826
 
7837
7827
  if (nb10 == sizeof(float)) {
7838
- for (int ir = 0; ir < nr; ++ir) {
7839
- // src0, src1 and dst are same shape => same indices
7840
- const int i3 = ir/(ne2*ne1);
7841
- const int i2 = (ir - i3*ne2*ne1)/ne1;
7842
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
7828
+ for (int64_t ir = ith; ir < nr; ir += nth) {
7829
+ // src0 and dst are same shape => same indices
7830
+ const int64_t i03 = ir/(ne02*ne01);
7831
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
7832
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
7833
+
7834
+ const int64_t i13 = i03 % ne13;
7835
+ const int64_t i12 = i02 % ne12;
7836
+ const int64_t i11 = i01 % ne11;
7837
+ const int64_t nr0 = ne00 / ne10;
7838
+
7839
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
7840
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
7841
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
7843
7842
 
7843
+ for (int64_t r = 0; r < nr0; ++r) {
7844
7844
  #ifdef WSP_GGML_USE_ACCELERATE
7845
- UNUSED(wsp_ggml_vec_div_f32);
7845
+ UNUSED(wsp_ggml_vec_div_f32);
7846
7846
 
7847
- vDSP_vdiv(
7848
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
7849
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
7850
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
7851
- ne0);
7847
+ vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
7852
7848
  #else
7853
- wsp_ggml_vec_div_f32(ne0,
7854
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
7855
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
7856
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
7849
+ wsp_ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
7857
7850
  #endif
7858
- // }
7859
- // }
7851
+ }
7860
7852
  }
7861
7853
  } else {
7862
7854
  // src1 is not contiguous
7863
- for (int ir = 0; ir < nr; ++ir) {
7864
- // src0, src1 and dst are same shape => same indices
7865
- const int i3 = ir/(ne2*ne1);
7866
- const int i2 = (ir - i3*ne2*ne1)/ne1;
7867
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
7855
+ for (int64_t ir = ith; ir < nr; ir += nth) {
7856
+ // src0 and dst are same shape => same indices
7857
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
7858
+ const int64_t i03 = ir/(ne02*ne01);
7859
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
7860
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
7868
7861
 
7869
- float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
7870
- float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
7871
- for (int i0 = 0; i0 < ne0; i0++) {
7872
- float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
7862
+ const int64_t i13 = i03 % ne13;
7863
+ const int64_t i12 = i02 % ne12;
7864
+ const int64_t i11 = i01 % ne11;
7865
+
7866
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
7867
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
7868
+
7869
+ for (int64_t i0 = 0; i0 < ne00; ++i0) {
7870
+ const int64_t i10 = i0 % ne10;
7871
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
7873
7872
 
7874
7873
  dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr);
7875
7874
  }
@@ -8315,7 +8314,7 @@ static void wsp_ggml_compute_forward_repeat_f16(
8315
8314
  return;
8316
8315
  }
8317
8316
 
8318
- WSP_GGML_TENSOR_UNARY_OP_LOCALS;
8317
+ WSP_GGML_TENSOR_UNARY_OP_LOCALS
8319
8318
 
8320
8319
  // guaranteed to be an integer due to the check in wsp_ggml_can_repeat
8321
8320
  const int nr0 = (int)(ne0/ne00);
@@ -8460,6 +8459,7 @@ static void wsp_ggml_compute_forward_concat_f32(
8460
8459
  WSP_GGML_ASSERT(src0->nb[0] == sizeof(float));
8461
8460
 
8462
8461
  const int ith = params->ith;
8462
+ const int nth = params->nth;
8463
8463
 
8464
8464
  WSP_GGML_TENSOR_BINARY_OP_LOCALS
8465
8465
 
@@ -8469,7 +8469,7 @@ static void wsp_ggml_compute_forward_concat_f32(
8469
8469
  WSP_GGML_ASSERT(nb10 == sizeof(float));
8470
8470
 
8471
8471
  for (int i3 = 0; i3 < ne3; i3++) {
8472
- for (int i2 = ith; i2 < ne2; i2++) {
8472
+ for (int i2 = ith; i2 < ne2; i2 += nth) {
8473
8473
  if (i2 < ne02) { // src0
8474
8474
  for (int i1 = 0; i1 < ne1; i1++) {
8475
8475
  for (int i0 = 0; i0 < ne0; i0++) {
@@ -9507,6 +9507,8 @@ static bool wsp_ggml_compute_forward_mul_mat_use_blas(
9507
9507
  // TODO: find the optimal values for these
9508
9508
  if (wsp_ggml_is_contiguous(src0) &&
9509
9509
  wsp_ggml_is_contiguous(src1) &&
9510
+ //src0->type == WSP_GGML_TYPE_F32 &&
9511
+ src1->type == WSP_GGML_TYPE_F32 &&
9510
9512
  (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
9511
9513
 
9512
9514
  /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
@@ -9545,7 +9547,7 @@ static void wsp_ggml_compute_forward_mul_mat(
9545
9547
 
9546
9548
  // we don't support permuted src0 or src1
9547
9549
  WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(type));
9548
- WSP_GGML_ASSERT(nb10 == sizeof(float));
9550
+ WSP_GGML_ASSERT(nb10 == wsp_ggml_type_size(src1->type));
9549
9551
 
9550
9552
  // dst cannot be transposed or permuted
9551
9553
  WSP_GGML_ASSERT(nb0 == sizeof(float));
@@ -9627,6 +9629,8 @@ static void wsp_ggml_compute_forward_mul_mat(
9627
9629
  char * wdata = params->wdata;
9628
9630
  const size_t row_size = ne10*wsp_ggml_type_size(vec_dot_type)/wsp_ggml_blck_size(vec_dot_type);
9629
9631
 
9632
+ assert(params->wsize >= ne11*ne12*ne13*row_size);
9633
+
9630
9634
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
9631
9635
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
9632
9636
  for (int64_t i11 = 0; i11 < ne11; ++i11) {
@@ -9728,6 +9732,26 @@ static void wsp_ggml_compute_forward_mul_mat(
9728
9732
  }
9729
9733
  }
9730
9734
 
9735
+ // wsp_ggml_compute_forward_mul_mat_id
9736
+
9737
+ static void wsp_ggml_compute_forward_mul_mat_id(
9738
+ const struct wsp_ggml_compute_params * params,
9739
+ struct wsp_ggml_tensor * dst) {
9740
+
9741
+ const struct wsp_ggml_tensor * ids = dst->src[0];
9742
+ const struct wsp_ggml_tensor * src1 = dst->src[1];
9743
+
9744
+ const int id = wsp_ggml_get_op_params_i32(dst, 0);
9745
+
9746
+ const int a_id = ((int32_t *)ids->data)[id];
9747
+
9748
+ WSP_GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
9749
+
9750
+ const struct wsp_ggml_tensor * src0 = dst->src[a_id + 2];
9751
+
9752
+ wsp_ggml_compute_forward_mul_mat(params, src0, src1, dst);
9753
+ }
9754
+
9731
9755
  // wsp_ggml_compute_forward_out_prod
9732
9756
 
9733
9757
  static void wsp_ggml_compute_forward_out_prod_f32(
@@ -9743,10 +9767,12 @@ static void wsp_ggml_compute_forward_out_prod_f32(
9743
9767
  const int ith = params->ith;
9744
9768
  const int nth = params->nth;
9745
9769
 
9770
+ WSP_GGML_ASSERT(ne0 == ne00);
9771
+ WSP_GGML_ASSERT(ne1 == ne10);
9772
+ WSP_GGML_ASSERT(ne2 == ne02);
9746
9773
  WSP_GGML_ASSERT(ne02 == ne12);
9747
- WSP_GGML_ASSERT(ne03 == ne13);
9748
- WSP_GGML_ASSERT(ne2 == ne12);
9749
9774
  WSP_GGML_ASSERT(ne3 == ne13);
9775
+ WSP_GGML_ASSERT(ne03 == ne13);
9750
9776
 
9751
9777
  // we don't support permuted src0 or src1
9752
9778
  WSP_GGML_ASSERT(nb00 == sizeof(float));
@@ -9757,18 +9783,25 @@ static void wsp_ggml_compute_forward_out_prod_f32(
9757
9783
  // WSP_GGML_ASSERT(nb1 <= nb2);
9758
9784
  // WSP_GGML_ASSERT(nb2 <= nb3);
9759
9785
 
9760
- WSP_GGML_ASSERT(ne0 == ne00);
9761
- WSP_GGML_ASSERT(ne1 == ne10);
9762
- WSP_GGML_ASSERT(ne2 == ne02);
9763
- WSP_GGML_ASSERT(ne3 == ne03);
9764
-
9765
9786
  // nb01 >= nb00 - src0 is not transposed
9766
9787
  // compute by src0 rows
9767
9788
 
9768
9789
  // TODO: #if defined(WSP_GGML_USE_CUBLAS) wsp_ggml_cuda_out_prod
9769
- // TODO: #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) || defined(WSP_GGML_USE_CLBLAST)
9790
+ // TODO: #if defined(WSP_GGML_USE_CLBLAST)
9791
+
9792
+ #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS)
9793
+ bool use_blas = wsp_ggml_is_matrix(src0) &&
9794
+ wsp_ggml_is_matrix(src1) &&
9795
+ wsp_ggml_is_contiguous(src0) &&
9796
+ (wsp_ggml_is_contiguous(src1) || wsp_ggml_is_transposed(src1));
9797
+ #endif
9770
9798
 
9771
9799
  if (params->type == WSP_GGML_TASK_INIT) {
9800
+ #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) // gemm beta will zero dst
9801
+ if (use_blas) {
9802
+ return;
9803
+ }
9804
+ #endif
9772
9805
  wsp_ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
9773
9806
  return;
9774
9807
  }
@@ -9777,6 +9810,50 @@ static void wsp_ggml_compute_forward_out_prod_f32(
9777
9810
  return;
9778
9811
  }
9779
9812
 
9813
+ #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS)
9814
+ if (use_blas) {
9815
+ if (params->ith != 0) { // All threads other than the first do no work.
9816
+ return;
9817
+ }
9818
+ // Arguments to wsp_ggml_compute_forward_out_prod (expressed as major,minor)
9819
+ // src0: (k,n)
9820
+ // src1: (k,m)
9821
+ // dst: (m,n)
9822
+ //
9823
+ // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
9824
+ // Also expressed as (major,minor)
9825
+ // a: (m,k): so src1 transposed
9826
+ // b: (k,n): so src0
9827
+ // c: (m,n)
9828
+ //
9829
+ // However, if wsp_ggml_is_transposed(src1) is true, then
9830
+ // src1->data already contains a transposed version, so sgemm mustn't
9831
+ // transpose it further.
9832
+
9833
+ int n = src0->ne[0];
9834
+ int k = src0->ne[1];
9835
+ int m = src1->ne[0];
9836
+
9837
+ int transposeA, lda;
9838
+
9839
+ if (!wsp_ggml_is_transposed(src1)) {
9840
+ transposeA = CblasTrans;
9841
+ lda = m;
9842
+ } else {
9843
+ transposeA = CblasNoTrans;
9844
+ lda = k;
9845
+ }
9846
+
9847
+ float * a = (float *) ((char *) src1->data);
9848
+ float * b = (float *) ((char *) src0->data);
9849
+ float * c = (float *) ((char *) dst->data);
9850
+
9851
+ cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
9852
+
9853
+ return;
9854
+ }
9855
+ #endif
9856
+
9780
9857
  // dst[:,:,:,:] = 0
9781
9858
  // for i2,i3:
9782
9859
  // for i1:
@@ -9880,7 +9957,7 @@ static void wsp_ggml_compute_forward_out_prod_q_f32(
9880
9957
  const int nth = params->nth;
9881
9958
 
9882
9959
  const enum wsp_ggml_type type = src0->type;
9883
- wsp_ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
9960
+ wsp_ggml_to_float_t const wsp_dewsp_quantize_row_q = type_traits[type].to_float;
9884
9961
 
9885
9962
  WSP_GGML_ASSERT(ne02 == ne12);
9886
9963
  WSP_GGML_ASSERT(ne03 == ne13);
@@ -9957,7 +10034,7 @@ static void wsp_ggml_compute_forward_out_prod_q_f32(
9957
10034
  float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
9958
10035
  float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
9959
10036
 
9960
- dequantize_row_q(s0, wdata, ne0);
10037
+ wsp_dewsp_quantize_row_q(s0, wdata, ne0);
9961
10038
  wsp_ggml_vec_mad_f32(ne0, d, wdata, *s1);
9962
10039
  }
9963
10040
  }
@@ -10251,7 +10328,7 @@ static void wsp_ggml_compute_forward_get_rows_q(
10251
10328
  const int nc = src0->ne[0];
10252
10329
  const int nr = wsp_ggml_nelements(src1);
10253
10330
  const enum wsp_ggml_type type = src0->type;
10254
- wsp_ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
10331
+ wsp_ggml_to_float_t const wsp_dewsp_quantize_row_q = type_traits[type].to_float;
10255
10332
 
10256
10333
  assert( dst->ne[0] == nc);
10257
10334
  assert( dst->ne[1] == nr);
@@ -10260,7 +10337,7 @@ static void wsp_ggml_compute_forward_get_rows_q(
10260
10337
  for (int i = 0; i < nr; ++i) {
10261
10338
  const int r = ((int32_t *) src1->data)[i];
10262
10339
 
10263
- dequantize_row_q(
10340
+ wsp_dewsp_quantize_row_q(
10264
10341
  (const void *) ((char *) src0->data + r*src0->nb[1]),
10265
10342
  (float *) ((char *) dst->data + i*dst->nb[1]), nc);
10266
10343
  }
@@ -10630,20 +10707,25 @@ static void wsp_ggml_compute_forward_diag_mask_zero(
10630
10707
  static void wsp_ggml_compute_forward_soft_max_f32(
10631
10708
  const struct wsp_ggml_compute_params * params,
10632
10709
  const struct wsp_ggml_tensor * src0,
10633
- struct wsp_ggml_tensor * dst) {
10634
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
10635
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst));
10636
- WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst));
10710
+ const struct wsp_ggml_tensor * src1,
10711
+ struct wsp_ggml_tensor * dst) {
10712
+ assert(wsp_ggml_is_contiguous(dst));
10713
+ assert(wsp_ggml_are_same_shape(src0, dst));
10637
10714
 
10638
10715
  if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
10639
10716
  return;
10640
10717
  }
10641
10718
 
10719
+ float scale = 1.0f;
10720
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
10721
+
10642
10722
  // TODO: handle transposed/permuted matrices
10643
10723
 
10644
10724
  const int ith = params->ith;
10645
10725
  const int nth = params->nth;
10646
10726
 
10727
+ const int64_t ne11 = src1 ? src1->ne[1] : 1;
10728
+
10647
10729
  const int nc = src0->ne[0];
10648
10730
  const int nr = wsp_ggml_nrows(src0);
10649
10731
 
@@ -10654,29 +10736,40 @@ static void wsp_ggml_compute_forward_soft_max_f32(
10654
10736
  const int ir0 = dr*ith;
10655
10737
  const int ir1 = MIN(ir0 + dr, nr);
10656
10738
 
10739
+ float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
10740
+
10657
10741
  for (int i1 = ir0; i1 < ir1; i1++) {
10658
- float *sp = (float *)((char *) src0->data + i1*src0->nb[1]);
10659
- float *dp = (float *)((char *) dst->data + i1*dst->nb[1]);
10742
+ float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
10743
+ float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
10744
+
10745
+ // broadcast the mask across rows
10746
+ float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
10747
+
10748
+ wsp_ggml_vec_cpy_f32 (nc, wp, sp);
10749
+ wsp_ggml_vec_scale_f32(nc, wp, scale);
10750
+ if (mp) {
10751
+ wsp_ggml_vec_acc_f32(nc, wp, mp);
10752
+ }
10660
10753
 
10661
10754
  #ifndef NDEBUG
10662
10755
  for (int i = 0; i < nc; ++i) {
10663
10756
  //printf("p[%d] = %f\n", i, p[i]);
10664
- assert(!isnan(sp[i]));
10757
+ assert(!isnan(wp[i]));
10665
10758
  }
10666
10759
  #endif
10667
10760
 
10668
10761
  float max = -INFINITY;
10669
- wsp_ggml_vec_max_f32(nc, &max, sp);
10762
+ wsp_ggml_vec_max_f32(nc, &max, wp);
10670
10763
 
10671
10764
  wsp_ggml_float sum = 0.0;
10672
10765
 
10673
10766
  uint16_t scvt;
10674
10767
  for (int i = 0; i < nc; i++) {
10675
- if (sp[i] == -INFINITY) {
10768
+ if (wp[i] == -INFINITY) {
10676
10769
  dp[i] = 0.0f;
10677
10770
  } else {
10678
- // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
10679
- wsp_ggml_fp16_t s = WSP_GGML_FP32_TO_FP16(sp[i] - max);
10771
+ // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
10772
+ wsp_ggml_fp16_t s = WSP_GGML_FP32_TO_FP16(wp[i] - max);
10680
10773
  memcpy(&scvt, &s, sizeof(scvt));
10681
10774
  const float val = WSP_GGML_FP16_TO_FP32(wsp_ggml_table_exp_f16[scvt]);
10682
10775
  sum += (wsp_ggml_float)val;
@@ -10701,11 +10794,12 @@ static void wsp_ggml_compute_forward_soft_max_f32(
10701
10794
  static void wsp_ggml_compute_forward_soft_max(
10702
10795
  const struct wsp_ggml_compute_params * params,
10703
10796
  const struct wsp_ggml_tensor * src0,
10704
- struct wsp_ggml_tensor * dst) {
10797
+ const struct wsp_ggml_tensor * src1,
10798
+ struct wsp_ggml_tensor * dst) {
10705
10799
  switch (src0->type) {
10706
10800
  case WSP_GGML_TYPE_F32:
10707
10801
  {
10708
- wsp_ggml_compute_forward_soft_max_f32(params, src0, dst);
10802
+ wsp_ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
10709
10803
  } break;
10710
10804
  default:
10711
10805
  {
@@ -11086,7 +11180,8 @@ static void wsp_ggml_compute_forward_rope_f32(
11086
11180
  const struct wsp_ggml_compute_params * params,
11087
11181
  const struct wsp_ggml_tensor * src0,
11088
11182
  const struct wsp_ggml_tensor * src1,
11089
- struct wsp_ggml_tensor * dst) {
11183
+ struct wsp_ggml_tensor * dst,
11184
+ const bool forward) {
11090
11185
  if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
11091
11186
  return;
11092
11187
  }
@@ -11145,6 +11240,11 @@ static void wsp_ggml_compute_forward_rope_f32(
11145
11240
  const bool is_neox = mode & 2;
11146
11241
  const bool is_glm = mode & 4;
11147
11242
 
11243
+ // backward process uses inverse rotation by cos and sin.
11244
+ // cos and sin build a rotation matrix, where the inverse is the transpose.
11245
+ // this essentially just switches the sign of sin.
11246
+ const float sin_sign = forward ? 1.0f : -1.0f;
11247
+
11148
11248
  const int32_t * pos = (const int32_t *) src1->data;
11149
11249
 
11150
11250
  for (int64_t i3 = 0; i3 < ne3; i3++) {
@@ -11161,9 +11261,9 @@ static void wsp_ggml_compute_forward_rope_f32(
11161
11261
  float block_theta = MAX(p - (n_ctx - 2), 0);
11162
11262
  for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
11163
11263
  const float cos_theta = cosf(theta_base);
11164
- const float sin_theta = sinf(theta_base);
11264
+ const float sin_theta = sinf(theta_base) * sin_sign;
11165
11265
  const float cos_block_theta = cosf(block_theta);
11166
- const float sin_block_theta = sinf(block_theta);
11266
+ const float sin_block_theta = sinf(block_theta) * sin_sign;
11167
11267
 
11168
11268
  theta_base *= theta_scale;
11169
11269
  block_theta *= theta_scale;
@@ -11187,6 +11287,7 @@ static void wsp_ggml_compute_forward_rope_f32(
11187
11287
  rope_yarn(
11188
11288
  theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
11189
11289
  );
11290
+ sin_theta *= sin_sign;
11190
11291
 
11191
11292
  // zeta scaling for xPos only:
11192
11293
  float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
@@ -11217,6 +11318,7 @@ static void wsp_ggml_compute_forward_rope_f32(
11217
11318
  theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
11218
11319
  &cos_theta, &sin_theta
11219
11320
  );
11321
+ sin_theta *= sin_sign;
11220
11322
 
11221
11323
  theta_base *= theta_scale;
11222
11324
 
@@ -11242,7 +11344,8 @@ static void wsp_ggml_compute_forward_rope_f16(
11242
11344
  const struct wsp_ggml_compute_params * params,
11243
11345
  const struct wsp_ggml_tensor * src0,
11244
11346
  const struct wsp_ggml_tensor * src1,
11245
- struct wsp_ggml_tensor * dst) {
11347
+ struct wsp_ggml_tensor * dst,
11348
+ const bool forward) {
11246
11349
  if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
11247
11350
  return;
11248
11351
  }
@@ -11294,6 +11397,11 @@ static void wsp_ggml_compute_forward_rope_f16(
11294
11397
  const bool is_neox = mode & 2;
11295
11398
  const bool is_glm = mode & 4;
11296
11399
 
11400
+ // backward process uses inverse rotation by cos and sin.
11401
+ // cos and sin build a rotation matrix, where the inverse is the transpose.
11402
+ // this essentially just switches the sign of sin.
11403
+ const float sin_sign = forward ? 1.0f : -1.0f;
11404
+
11297
11405
  const int32_t * pos = (const int32_t *) src1->data;
11298
11406
 
11299
11407
  for (int64_t i3 = 0; i3 < ne3; i3++) {
@@ -11310,9 +11418,9 @@ static void wsp_ggml_compute_forward_rope_f16(
11310
11418
  float block_theta = MAX(p - (n_ctx - 2), 0);
11311
11419
  for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
11312
11420
  const float cos_theta = cosf(theta_base);
11313
- const float sin_theta = sinf(theta_base);
11421
+ const float sin_theta = sinf(theta_base) * sin_sign;
11314
11422
  const float cos_block_theta = cosf(block_theta);
11315
- const float sin_block_theta = sinf(block_theta);
11423
+ const float sin_block_theta = sinf(block_theta) * sin_sign;
11316
11424
 
11317
11425
  theta_base *= theta_scale;
11318
11426
  block_theta *= theta_scale;
@@ -11336,6 +11444,7 @@ static void wsp_ggml_compute_forward_rope_f16(
11336
11444
  rope_yarn(
11337
11445
  theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
11338
11446
  );
11447
+ sin_theta *= sin_sign;
11339
11448
 
11340
11449
  theta_base *= theta_scale;
11341
11450
 
@@ -11362,6 +11471,7 @@ static void wsp_ggml_compute_forward_rope_f16(
11362
11471
  theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
11363
11472
  &cos_theta, &sin_theta
11364
11473
  );
11474
+ sin_theta *= sin_sign;
11365
11475
 
11366
11476
  theta_base *= theta_scale;
11367
11477
 
@@ -11391,11 +11501,11 @@ static void wsp_ggml_compute_forward_rope(
11391
11501
  switch (src0->type) {
11392
11502
  case WSP_GGML_TYPE_F16:
11393
11503
  {
11394
- wsp_ggml_compute_forward_rope_f16(params, src0, src1, dst);
11504
+ wsp_ggml_compute_forward_rope_f16(params, src0, src1, dst, true);
11395
11505
  } break;
11396
11506
  case WSP_GGML_TYPE_F32:
11397
11507
  {
11398
- wsp_ggml_compute_forward_rope_f32(params, src0, src1, dst);
11508
+ wsp_ggml_compute_forward_rope_f32(params, src0, src1, dst, true);
11399
11509
  } break;
11400
11510
  default:
11401
11511
  {
@@ -11406,216 +11516,6 @@ static void wsp_ggml_compute_forward_rope(
11406
11516
 
11407
11517
  // wsp_ggml_compute_forward_rope_back
11408
11518
 
11409
- static void wsp_ggml_compute_forward_rope_back_f32(
11410
- const struct wsp_ggml_compute_params * params,
11411
- const struct wsp_ggml_tensor * src0,
11412
- const struct wsp_ggml_tensor * src1,
11413
- struct wsp_ggml_tensor * dst) {
11414
-
11415
- if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
11416
- return;
11417
- }
11418
-
11419
- // y = rope(x, src1)
11420
- // dx = rope_back(dy, src1)
11421
- // src0 is dy, src1 contains options
11422
-
11423
- float freq_base;
11424
- float freq_scale;
11425
-
11426
- // these two only relevant for xPos RoPE:
11427
- float xpos_base;
11428
- bool xpos_down;
11429
-
11430
- //const int n_past = ((int32_t *) dst->op_params)[0];
11431
- const int n_dims = ((int32_t *) dst->op_params)[1];
11432
- const int mode = ((int32_t *) dst->op_params)[2];
11433
- const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx);
11434
- memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
11435
- memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
11436
- memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
11437
- memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
11438
-
11439
- WSP_GGML_TENSOR_UNARY_OP_LOCALS
11440
-
11441
- //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
11442
- //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
11443
-
11444
- assert(nb0 == sizeof(float));
11445
-
11446
- const int ith = params->ith;
11447
- const int nth = params->nth;
11448
-
11449
- const int nr = wsp_ggml_nrows(dst);
11450
-
11451
- // rows per thread
11452
- const int dr = (nr + nth - 1)/nth;
11453
-
11454
- // row range for this thread
11455
- const int ir0 = dr*ith;
11456
- const int ir1 = MIN(ir0 + dr, nr);
11457
-
11458
- // row index used to determine which thread to use
11459
- int ir = 0;
11460
-
11461
- const float theta_scale = powf(freq_base, -2.0f/n_dims);
11462
-
11463
- const bool is_neox = mode & 2;
11464
-
11465
- const int32_t * pos = (const int32_t *) src1->data;
11466
-
11467
- for (int64_t i3 = 0; i3 < ne3; i3++) {
11468
- for (int64_t i2 = 0; i2 < ne2; i2++) {
11469
- const int64_t p = pos[i2];
11470
- for (int64_t i1 = 0; i1 < ne1; i1++) {
11471
- if (ir++ < ir0) continue;
11472
- if (ir > ir1) break;
11473
-
11474
- float theta_base = freq_scale * (float)p;
11475
-
11476
- if (!is_neox) {
11477
- for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
11478
- const float cos_theta = cosf(theta_base);
11479
- const float sin_theta = sinf(theta_base);
11480
-
11481
- // zeta scaling for xPos only:
11482
- float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
11483
- if (xpos_down) zeta = 1.0f / zeta;
11484
-
11485
- theta_base *= theta_scale;
11486
-
11487
- const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11488
- float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11489
-
11490
- const float dy0 = dy[0];
11491
- const float dy1 = dy[1];
11492
-
11493
- dx[0] = dy0*cos_theta*zeta + dy1*sin_theta*zeta;
11494
- dx[1] = - dy0*sin_theta*zeta + dy1*cos_theta*zeta;
11495
- }
11496
- } else {
11497
- for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
11498
- for (int64_t ic = 0; ic < n_dims; ic += 2) {
11499
- const float cos_theta = cosf(theta_base);
11500
- const float sin_theta = sinf(theta_base);
11501
-
11502
- theta_base *= theta_scale;
11503
-
11504
- const int64_t i0 = ib*n_dims + ic/2;
11505
-
11506
- const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11507
- float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11508
-
11509
- const float dy0 = dy[0];
11510
- const float dy1 = dy[n_dims/2];
11511
-
11512
- dx[0] = dy0*cos_theta + dy1*sin_theta;
11513
- dx[n_dims/2] = - dy0*sin_theta + dy1*cos_theta;
11514
- }
11515
- }
11516
- }
11517
- }
11518
- }
11519
- }
11520
- }
11521
-
11522
- static void wsp_ggml_compute_forward_rope_back_f16(
11523
- const struct wsp_ggml_compute_params * params,
11524
- const struct wsp_ggml_tensor * src0,
11525
- const struct wsp_ggml_tensor * src1,
11526
- struct wsp_ggml_tensor * dst) {
11527
-
11528
- if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
11529
- return;
11530
- }
11531
-
11532
- // y = rope(x, src1)
11533
- // dx = rope_back(dy, src1)
11534
- // src0 is dy, src1 contains options
11535
-
11536
- //const int n_past = ((int32_t *) dst->op_params)[0];
11537
- const int n_dims = ((int32_t *) dst->op_params)[1];
11538
- const int mode = ((int32_t *) dst->op_params)[2];
11539
-
11540
- WSP_GGML_TENSOR_UNARY_OP_LOCALS
11541
-
11542
- //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
11543
- //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
11544
-
11545
- assert(nb0 == sizeof(wsp_ggml_fp16_t));
11546
-
11547
- const int ith = params->ith;
11548
- const int nth = params->nth;
11549
-
11550
- const int nr = wsp_ggml_nrows(dst);
11551
-
11552
- // rows per thread
11553
- const int dr = (nr + nth - 1)/nth;
11554
-
11555
- // row range for this thread
11556
- const int ir0 = dr*ith;
11557
- const int ir1 = MIN(ir0 + dr, nr);
11558
-
11559
- // row index used to determine which thread to use
11560
- int ir = 0;
11561
-
11562
- const float theta_scale = powf(10000.0, -2.0f/n_dims);
11563
-
11564
- const bool is_neox = mode & 2;
11565
-
11566
- const int32_t * pos = (const int32_t *) src1->data;
11567
-
11568
- for (int64_t i3 = 0; i3 < ne3; i3++) {
11569
- for (int64_t i2 = 0; i2 < ne2; i2++) {
11570
- const int64_t p = pos[i2];
11571
- for (int64_t i1 = 0; i1 < ne1; i1++) {
11572
- if (ir++ < ir0) continue;
11573
- if (ir > ir1) break;
11574
-
11575
- float theta_base = (float)p;
11576
-
11577
- if (!is_neox) {
11578
- for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
11579
- const float cos_theta = cosf(theta_base);
11580
- const float sin_theta = sinf(theta_base);
11581
-
11582
- theta_base *= theta_scale;
11583
-
11584
- const wsp_ggml_fp16_t * const dy = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11585
- wsp_ggml_fp16_t * dx = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11586
-
11587
- const float dy0 = WSP_GGML_FP16_TO_FP32(dy[0]);
11588
- const float dy1 = WSP_GGML_FP16_TO_FP32(dy[1]);
11589
-
11590
- dx[0] = WSP_GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
11591
- dx[1] = WSP_GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
11592
- }
11593
- } else {
11594
- for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
11595
- for (int64_t ic = 0; ic < n_dims; ic += 2) {
11596
- const float cos_theta = cosf(theta_base);
11597
- const float sin_theta = sinf(theta_base);
11598
-
11599
- theta_base *= theta_scale;
11600
-
11601
- const int64_t i0 = ib*n_dims + ic/2;
11602
-
11603
- const wsp_ggml_fp16_t * const dy = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11604
- wsp_ggml_fp16_t * dx = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11605
-
11606
- const float dy0 = WSP_GGML_FP16_TO_FP32(dy[0]);
11607
- const float dy1 = WSP_GGML_FP16_TO_FP32(dy[n_dims/2]);
11608
-
11609
- dx[0] = WSP_GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
11610
- dx[n_dims/2] = WSP_GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
11611
- }
11612
- }
11613
- }
11614
- }
11615
- }
11616
- }
11617
- }
11618
-
11619
11519
  static void wsp_ggml_compute_forward_rope_back(
11620
11520
  const struct wsp_ggml_compute_params * params,
11621
11521
  const struct wsp_ggml_tensor * src0,
@@ -11624,11 +11524,11 @@ static void wsp_ggml_compute_forward_rope_back(
11624
11524
  switch (src0->type) {
11625
11525
  case WSP_GGML_TYPE_F16:
11626
11526
  {
11627
- wsp_ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
11527
+ wsp_ggml_compute_forward_rope_f16(params, src0, src1, dst, false);
11628
11528
  } break;
11629
11529
  case WSP_GGML_TYPE_F32:
11630
11530
  {
11631
- wsp_ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
11531
+ wsp_ggml_compute_forward_rope_f32(params, src0, src1, dst, false);
11632
11532
  } break;
11633
11533
  default:
11634
11534
  {
@@ -11637,9 +11537,9 @@ static void wsp_ggml_compute_forward_rope_back(
11637
11537
  }
11638
11538
  }
11639
11539
 
11640
- // wsp_ggml_compute_forward_conv_1d
11540
+ // wsp_ggml_compute_forward_conv_transpose_1d
11641
11541
 
11642
- static void wsp_ggml_compute_forward_conv_1d_f16_f32(
11542
+ static void wsp_ggml_compute_forward_conv_transpose_1d_f16_f32(
11643
11543
  const struct wsp_ggml_compute_params * params,
11644
11544
  const struct wsp_ggml_tensor * src0,
11645
11545
  const struct wsp_ggml_tensor * src1,
@@ -11656,14 +11556,7 @@ static void wsp_ggml_compute_forward_conv_1d_f16_f32(
11656
11556
  const int ith = params->ith;
11657
11557
  const int nth = params->nth;
11658
11558
 
11659
- const int nk = ne00;
11660
-
11661
- // size of the convolution row - the kernel size unrolled across all input channels
11662
- const int ew0 = nk*ne01;
11663
-
11664
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
11665
- const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
11666
- const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
11559
+ const int nk = ne00*ne01*ne02;
11667
11560
 
11668
11561
  WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
11669
11562
  WSP_GGML_ASSERT(nb10 == sizeof(float));
@@ -11671,23 +11564,37 @@ static void wsp_ggml_compute_forward_conv_1d_f16_f32(
11671
11564
  if (params->type == WSP_GGML_TASK_INIT) {
11672
11565
  memset(params->wdata, 0, params->wsize);
11673
11566
 
11674
- wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0;
11567
+ // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
11568
+ {
11569
+ wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0;
11675
11570
 
11676
- for (int64_t i11 = 0; i11 < ne11; i11++) {
11677
- const float * const src = (float *)((char *) src1->data + i11*nb11);
11678
- wsp_ggml_fp16_t * dst_data = wdata;
11571
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
11572
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
11573
+ const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
11574
+ wsp_ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
11575
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
11576
+ dst_data[i00*ne02 + i02] = src[i00];
11577
+ }
11578
+ }
11579
+ }
11580
+ }
11679
11581
 
11680
- for (int64_t i0 = 0; i0 < ne0; i0++) {
11681
- for (int64_t ik = 0; ik < nk; ik++) {
11682
- const int idx0 = i0*s0 + ik*d0 - p0;
11582
+ // permute source data (src1) from (L x Cin) to (Cin x L)
11583
+ {
11584
+ wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + nk;
11585
+ wsp_ggml_fp16_t * dst_data = wdata;
11683
11586
 
11684
- if(!(idx0 < 0 || idx0 >= ne10)) {
11685
- dst_data[i0*ew0 + i11*nk + ik] = WSP_GGML_FP32_TO_FP16(src[idx0]);
11686
- }
11587
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
11588
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
11589
+ for (int64_t i10 = 0; i10 < ne10; i10++) {
11590
+ dst_data[i10*ne11 + i11] = WSP_GGML_FP32_TO_FP16(src[i10]);
11687
11591
  }
11688
11592
  }
11689
11593
  }
11690
11594
 
11595
+ // need to zero dst since we are accumulating into it
11596
+ memset(dst->data, 0, wsp_ggml_nbytes(dst));
11597
+
11691
11598
  return;
11692
11599
  }
11693
11600
 
@@ -11695,424 +11602,7 @@ static void wsp_ggml_compute_forward_conv_1d_f16_f32(
11695
11602
  return;
11696
11603
  }
11697
11604
 
11698
- // total rows in dst
11699
- const int nr = ne2;
11700
-
11701
- // rows per thread
11702
- const int dr = (nr + nth - 1)/nth;
11703
-
11704
- // row range for this thread
11705
- const int ir0 = dr*ith;
11706
- const int ir1 = MIN(ir0 + dr, nr);
11707
-
11708
- wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0;
11709
-
11710
- for (int i2 = 0; i2 < ne2; i2++) {
11711
- for (int i1 = ir0; i1 < ir1; i1++) {
11712
- float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
11713
-
11714
- for (int i0 = 0; i0 < ne0; i0++) {
11715
- wsp_ggml_vec_dot_f16(ew0, dst_data + i0,
11716
- (wsp_ggml_fp16_t *) ((char *) src0->data + i1*nb02),
11717
- (wsp_ggml_fp16_t *) wdata + i2*nb2 + i0*ew0);
11718
- }
11719
- }
11720
- }
11721
- }
11722
-
11723
- static void wsp_ggml_compute_forward_conv_1d_f32(
11724
- const struct wsp_ggml_compute_params * params,
11725
- const struct wsp_ggml_tensor * src0,
11726
- const struct wsp_ggml_tensor * src1,
11727
- struct wsp_ggml_tensor * dst) {
11728
- WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
11729
- WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
11730
- WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
11731
-
11732
- int64_t t0 = wsp_ggml_perf_time_us();
11733
- UNUSED(t0);
11734
-
11735
- WSP_GGML_TENSOR_BINARY_OP_LOCALS
11736
-
11737
- const int ith = params->ith;
11738
- const int nth = params->nth;
11739
-
11740
- const int nk = ne00;
11741
-
11742
- const int ew0 = nk*ne01;
11743
-
11744
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
11745
- const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
11746
- const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
11747
-
11748
- WSP_GGML_ASSERT(nb00 == sizeof(float));
11749
- WSP_GGML_ASSERT(nb10 == sizeof(float));
11750
-
11751
- if (params->type == WSP_GGML_TASK_INIT) {
11752
- memset(params->wdata, 0, params->wsize);
11753
-
11754
- float * const wdata = (float *) params->wdata + 0;
11755
-
11756
- for (int64_t i11 = 0; i11 < ne11; i11++) {
11757
- const float * const src = (float *)((char *) src1->data + i11*nb11);
11758
- float * dst_data = wdata;
11759
-
11760
- for (int64_t i0 = 0; i0 < ne0; i0++) {
11761
- for (int64_t ik = 0; ik < nk; ik++) {
11762
- const int idx0 = i0*s0 + ik*d0 - p0;
11763
-
11764
- if(!(idx0 < 0 || idx0 >= ne10)) {
11765
- dst_data[i0*ew0 + i11*nk + ik] = src[idx0];
11766
- }
11767
- }
11768
- }
11769
- }
11770
-
11771
- return;
11772
- }
11773
-
11774
- if (params->type == WSP_GGML_TASK_FINALIZE) {
11775
- return;
11776
- }
11777
-
11778
- // total rows in dst
11779
- const int nr = ne02;
11780
-
11781
- // rows per thread
11782
- const int dr = (nr + nth - 1)/nth;
11783
-
11784
- // row range for this thread
11785
- const int ir0 = dr*ith;
11786
- const int ir1 = MIN(ir0 + dr, nr);
11787
-
11788
- float * const wdata = (float *) params->wdata + 0;
11789
-
11790
- for (int i2 = 0; i2 < ne2; i2++) {
11791
- for (int i1 = ir0; i1 < ir1; i1++) {
11792
- float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
11793
-
11794
- for (int i0 = 0; i0 < ne0; i0++) {
11795
- wsp_ggml_vec_dot_f32(ew0, dst_data + i0,
11796
- (float *) ((char *) src0->data + i1*nb02),
11797
- (float *) wdata + i2*nb2 + i0*ew0);
11798
- }
11799
- }
11800
- }
11801
- }
11802
-
11803
- // TODO: reuse wsp_ggml_mul_mat or implement wsp_ggml_im2col and remove stage_0 and stage_1
11804
- static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k,
11805
- wsp_ggml_fp16_t * A,
11806
- wsp_ggml_fp16_t * B,
11807
- float * C,
11808
- const int ith, const int nth) {
11809
- // does not seem to make a difference
11810
- int64_t m0, m1, n0, n1;
11811
- // patches per thread
11812
- if (m > n) {
11813
- n0 = 0;
11814
- n1 = n;
11815
-
11816
- // total patches in dst
11817
- const int np = m;
11818
-
11819
- // patches per thread
11820
- const int dp = (np + nth - 1)/nth;
11821
-
11822
- // patch range for this thread
11823
- m0 = dp*ith;
11824
- m1 = MIN(m0 + dp, np);
11825
- } else {
11826
- m0 = 0;
11827
- m1 = m;
11828
-
11829
- // total patches in dst
11830
- const int np = n;
11831
-
11832
- // patches per thread
11833
- const int dp = (np + nth - 1)/nth;
11834
-
11835
- // patch range for this thread
11836
- n0 = dp*ith;
11837
- n1 = MIN(n0 + dp, np);
11838
- }
11839
-
11840
- // block-tiling attempt
11841
- int64_t blck_n = 16;
11842
- int64_t blck_m = 16;
11843
-
11844
- // int64_t CACHE_SIZE = 2 * 1024 * 1024; // 2MB
11845
- // int64_t blck_size = CACHE_SIZE / (sizeof(float) + 2 * sizeof(wsp_ggml_fp16_t) * K);
11846
- // if (blck_size > 0) {
11847
- // blck_0 = 4;
11848
- // blck_1 = blck_size / blck_0;
11849
- // if (blck_1 < 0) {
11850
- // blck_1 = 1;
11851
- // }
11852
- // // blck_0 = (int64_t)sqrt(blck_size);
11853
- // // blck_1 = blck_0;
11854
- // }
11855
- // // printf("%zd %zd %zd %zd\n", blck_size, K, blck_0, blck_1);
11856
-
11857
- for (int j = n0; j < n1; j+=blck_n) {
11858
- for (int i = m0; i < m1; i+=blck_m) {
11859
- // printf("i j k => %d %d %d\n", i, j, K);
11860
- for (int ii = i; ii < i + blck_m && ii < m1; ii++) {
11861
- for (int jj = j; jj < j + blck_n && jj < n1; jj++) {
11862
- wsp_ggml_vec_dot_f16(k,
11863
- C + ii*n + jj,
11864
- A + ii * k,
11865
- B + jj * k);
11866
- }
11867
- }
11868
- }
11869
- }
11870
- }
11871
-
11872
- // src0: kernel [OC, IC, K]
11873
- // src1: signal [N, IC, IL]
11874
- // dst: result [N, OL, IC*K]
11875
- static void wsp_ggml_compute_forward_conv_1d_stage_0_f32(
11876
- const struct wsp_ggml_compute_params * params,
11877
- const struct wsp_ggml_tensor * src0,
11878
- const struct wsp_ggml_tensor * src1,
11879
- struct wsp_ggml_tensor * dst) {
11880
- WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
11881
- WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
11882
- WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16);
11883
-
11884
- int64_t t0 = wsp_ggml_perf_time_us();
11885
- UNUSED(t0);
11886
-
11887
- WSP_GGML_TENSOR_BINARY_OP_LOCALS;
11888
-
11889
- const int64_t N = ne12;
11890
- const int64_t IC = ne11;
11891
- const int64_t IL = ne10;
11892
-
11893
- const int64_t K = ne00;
11894
-
11895
- const int64_t OL = ne1;
11896
-
11897
- const int ith = params->ith;
11898
- const int nth = params->nth;
11899
-
11900
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
11901
- const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
11902
- const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
11903
-
11904
- WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
11905
- WSP_GGML_ASSERT(nb10 == sizeof(float));
11906
-
11907
- if (params->type == WSP_GGML_TASK_INIT) {
11908
- memset(dst->data, 0, wsp_ggml_nbytes(dst));
11909
- return;
11910
- }
11911
-
11912
- if (params->type == WSP_GGML_TASK_FINALIZE) {
11913
- return;
11914
- }
11915
-
11916
- // im2col: [N, IC, IL] => [N, OL, IC*K]
11917
- {
11918
- wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) dst->data;
11919
-
11920
- for (int64_t in = 0; in < N; in++) {
11921
- for (int64_t iol = 0; iol < OL; iol++) {
11922
- for (int64_t iic = ith; iic < IC; iic+=nth) {
11923
-
11924
- // micro kernel
11925
- wsp_ggml_fp16_t * dst_data = wdata + (in*OL + iol)*(IC*K); // [IC, K]
11926
- const float * const src_data = (float *)((char *) src1->data + in*nb12 + iic*nb11); // [IL]
11927
-
11928
- for (int64_t ik = 0; ik < K; ik++) {
11929
- const int64_t iil = iol*s0 + ik*d0 - p0;
11930
-
11931
- if (!(iil < 0 || iil >= IL)) {
11932
- dst_data[iic*K + ik] = WSP_GGML_FP32_TO_FP16(src_data[iil]);
11933
- }
11934
- }
11935
- }
11936
- }
11937
- }
11938
- }
11939
- }
11940
-
11941
- // gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
11942
- // src0: [OC, IC, K]
11943
- // src1: [N, OL, IC * K]
11944
- // result: [N, OC, OL]
11945
- static void wsp_ggml_compute_forward_conv_1d_stage_1_f16(
11946
- const struct wsp_ggml_compute_params * params,
11947
- const struct wsp_ggml_tensor * src0,
11948
- const struct wsp_ggml_tensor * src1,
11949
- struct wsp_ggml_tensor * dst) {
11950
- WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
11951
- WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F16);
11952
- WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
11953
-
11954
- int64_t t0 = wsp_ggml_perf_time_us();
11955
- UNUSED(t0);
11956
-
11957
- if (params->type == WSP_GGML_TASK_INIT) {
11958
- return;
11959
- }
11960
-
11961
- if (params->type == WSP_GGML_TASK_FINALIZE) {
11962
- return;
11963
- }
11964
-
11965
- WSP_GGML_TENSOR_BINARY_OP_LOCALS;
11966
-
11967
- WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
11968
- WSP_GGML_ASSERT(nb10 == sizeof(wsp_ggml_fp16_t));
11969
- WSP_GGML_ASSERT(nb0 == sizeof(float));
11970
-
11971
- const int N = ne12;
11972
- const int OL = ne11;
11973
-
11974
- const int OC = ne02;
11975
- const int IC = ne01;
11976
- const int K = ne00;
11977
-
11978
- const int ith = params->ith;
11979
- const int nth = params->nth;
11980
-
11981
- int64_t m = OC;
11982
- int64_t n = OL;
11983
- int64_t k = IC * K;
11984
-
11985
- // [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
11986
- for (int i = 0; i < N; i++) {
11987
- wsp_ggml_fp16_t * A = (wsp_ggml_fp16_t *)src0->data; // [m, k]
11988
- wsp_ggml_fp16_t * B = (wsp_ggml_fp16_t *)src1->data + i * m * k; // [n, k]
11989
- float * C = (float *)dst->data + i * m * n; // [m, n]
11990
-
11991
- gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
11992
- }
11993
- }
11994
-
11995
- static void wsp_ggml_compute_forward_conv_1d(
11996
- const struct wsp_ggml_compute_params * params,
11997
- const struct wsp_ggml_tensor * src0,
11998
- const struct wsp_ggml_tensor * src1,
11999
- struct wsp_ggml_tensor * dst) {
12000
- switch(src0->type) {
12001
- case WSP_GGML_TYPE_F16:
12002
- {
12003
- wsp_ggml_compute_forward_conv_1d_f16_f32(params, src0, src1, dst);
12004
- } break;
12005
- case WSP_GGML_TYPE_F32:
12006
- {
12007
- wsp_ggml_compute_forward_conv_1d_f32(params, src0, src1, dst);
12008
- } break;
12009
- default:
12010
- {
12011
- WSP_GGML_ASSERT(false);
12012
- } break;
12013
- }
12014
- }
12015
-
12016
- static void wsp_ggml_compute_forward_conv_1d_stage_0(
12017
- const struct wsp_ggml_compute_params * params,
12018
- const struct wsp_ggml_tensor * src0,
12019
- const struct wsp_ggml_tensor * src1,
12020
- struct wsp_ggml_tensor * dst) {
12021
- switch(src0->type) {
12022
- case WSP_GGML_TYPE_F16:
12023
- {
12024
- wsp_ggml_compute_forward_conv_1d_stage_0_f32(params, src0, src1, dst);
12025
- } break;
12026
- default:
12027
- {
12028
- WSP_GGML_ASSERT(false);
12029
- } break;
12030
- }
12031
- }
12032
-
12033
- static void wsp_ggml_compute_forward_conv_1d_stage_1(
12034
- const struct wsp_ggml_compute_params * params,
12035
- const struct wsp_ggml_tensor * src0,
12036
- const struct wsp_ggml_tensor * src1,
12037
- struct wsp_ggml_tensor * dst) {
12038
- switch(src0->type) {
12039
- case WSP_GGML_TYPE_F16:
12040
- {
12041
- wsp_ggml_compute_forward_conv_1d_stage_1_f16(params, src0, src1, dst);
12042
- } break;
12043
- default:
12044
- {
12045
- WSP_GGML_ASSERT(false);
12046
- } break;
12047
- }
12048
- }
12049
-
12050
- // wsp_ggml_compute_forward_conv_transpose_1d
12051
-
12052
- static void wsp_ggml_compute_forward_conv_transpose_1d_f16_f32(
12053
- const struct wsp_ggml_compute_params * params,
12054
- const struct wsp_ggml_tensor * src0,
12055
- const struct wsp_ggml_tensor * src1,
12056
- struct wsp_ggml_tensor * dst) {
12057
- WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
12058
- WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
12059
- WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
12060
-
12061
- int64_t t0 = wsp_ggml_perf_time_us();
12062
- UNUSED(t0);
12063
-
12064
- WSP_GGML_TENSOR_BINARY_OP_LOCALS
12065
-
12066
- const int ith = params->ith;
12067
- const int nth = params->nth;
12068
-
12069
- const int nk = ne00*ne01*ne02;
12070
-
12071
- WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
12072
- WSP_GGML_ASSERT(nb10 == sizeof(float));
12073
-
12074
- if (params->type == WSP_GGML_TASK_INIT) {
12075
- memset(params->wdata, 0, params->wsize);
12076
-
12077
- // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
12078
- {
12079
- wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0;
12080
-
12081
- for (int64_t i02 = 0; i02 < ne02; i02++) {
12082
- for (int64_t i01 = 0; i01 < ne01; i01++) {
12083
- const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
12084
- wsp_ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
12085
- for (int64_t i00 = 0; i00 < ne00; i00++) {
12086
- dst_data[i00*ne02 + i02] = src[i00];
12087
- }
12088
- }
12089
- }
12090
- }
12091
-
12092
- // permute source data (src1) from (L x Cin) to (Cin x L)
12093
- {
12094
- wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + nk;
12095
- wsp_ggml_fp16_t * dst_data = wdata;
12096
-
12097
- for (int64_t i11 = 0; i11 < ne11; i11++) {
12098
- const float * const src = (float *)((char *) src1->data + i11*nb11);
12099
- for (int64_t i10 = 0; i10 < ne10; i10++) {
12100
- dst_data[i10*ne11 + i11] = WSP_GGML_FP32_TO_FP16(src[i10]);
12101
- }
12102
- }
12103
- }
12104
-
12105
- // need to zero dst since we are accumulating into it
12106
- memset(dst->data, 0, wsp_ggml_nbytes(dst));
12107
-
12108
- return;
12109
- }
12110
-
12111
- if (params->type == WSP_GGML_TASK_FINALIZE) {
12112
- return;
12113
- }
12114
-
12115
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
11605
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
12116
11606
 
12117
11607
  // total rows in dst
12118
11608
  const int nr = ne1;
@@ -12258,229 +11748,81 @@ static void wsp_ggml_compute_forward_conv_transpose_1d(
12258
11748
  }
12259
11749
  }
12260
11750
 
12261
- // wsp_ggml_compute_forward_conv_2d
12262
-
12263
11751
  // src0: kernel [OC, IC, KH, KW]
12264
11752
  // src1: image [N, IC, IH, IW]
12265
11753
  // dst: result [N, OH, OW, IC*KH*KW]
12266
- static void wsp_ggml_compute_forward_conv_2d_stage_0_f32(
12267
- const struct wsp_ggml_compute_params * params,
12268
- const struct wsp_ggml_tensor * src0,
12269
- const struct wsp_ggml_tensor * src1,
12270
- struct wsp_ggml_tensor * dst) {
12271
- WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
12272
- WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
12273
- WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16);
12274
-
12275
- int64_t t0 = wsp_ggml_perf_time_us();
12276
- UNUSED(t0);
12277
-
12278
- WSP_GGML_TENSOR_BINARY_OP_LOCALS;
12279
-
12280
- const int64_t N = ne13;
12281
- const int64_t IC = ne12;
12282
- const int64_t IH = ne11;
12283
- const int64_t IW = ne10;
12284
-
12285
- // const int64_t OC = ne03;
12286
- // const int64_t IC = ne02;
12287
- const int64_t KH = ne01;
12288
- const int64_t KW = ne00;
12289
-
12290
- const int64_t OH = ne2;
12291
- const int64_t OW = ne1;
12292
-
12293
- const int ith = params->ith;
12294
- const int nth = params->nth;
12295
-
12296
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
12297
- const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
12298
- const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
12299
- const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
12300
- const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
12301
- const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
12302
-
12303
- WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
12304
- WSP_GGML_ASSERT(nb10 == sizeof(float));
12305
-
12306
- if (params->type == WSP_GGML_TASK_INIT) {
12307
- memset(dst->data, 0, wsp_ggml_nbytes(dst));
12308
- return;
12309
- }
12310
-
12311
- if (params->type == WSP_GGML_TASK_FINALIZE) {
12312
- return;
12313
- }
12314
-
12315
- // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
12316
- {
12317
- wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) dst->data;
12318
-
12319
- for (int64_t in = 0; in < N; in++) {
12320
- for (int64_t ioh = 0; ioh < OH; ioh++) {
12321
- for (int64_t iow = 0; iow < OW; iow++) {
12322
- for (int64_t iic = ith; iic < IC; iic+=nth) {
12323
-
12324
- // micro kernel
12325
- wsp_ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
12326
- const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW]
12327
-
12328
- for (int64_t ikh = 0; ikh < KH; ikh++) {
12329
- for (int64_t ikw = 0; ikw < KW; ikw++) {
12330
- const int64_t iiw = iow*s0 + ikw*d0 - p0;
12331
- const int64_t iih = ioh*s1 + ikh*d1 - p1;
12332
-
12333
- if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
12334
- dst_data[iic*(KH*KW) + ikh*KW + ikw] = WSP_GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
12335
- }
12336
- }
12337
- }
12338
- }
12339
- }
12340
- }
12341
- }
12342
- }
12343
- }
12344
-
12345
- // gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
12346
- // src0: [OC, IC, KH, KW]
12347
- // src1: [N, OH, OW, IC * KH * KW]
12348
- // result: [N, OC, OH, OW]
12349
- static void wsp_ggml_compute_forward_conv_2d_stage_1_f16(
12350
- const struct wsp_ggml_compute_params * params,
12351
- const struct wsp_ggml_tensor * src0,
12352
- const struct wsp_ggml_tensor * src1,
12353
- struct wsp_ggml_tensor * dst) {
12354
- WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
12355
- WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F16);
12356
- WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
12357
-
12358
- int64_t t0 = wsp_ggml_perf_time_us();
12359
- UNUSED(t0);
12360
-
12361
- if (params->type == WSP_GGML_TASK_INIT) {
12362
- return;
12363
- }
12364
-
12365
- if (params->type == WSP_GGML_TASK_FINALIZE) {
12366
- return;
12367
- }
12368
-
12369
- WSP_GGML_TENSOR_BINARY_OP_LOCALS;
12370
-
12371
- WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
12372
- WSP_GGML_ASSERT(nb10 == sizeof(wsp_ggml_fp16_t));
12373
- WSP_GGML_ASSERT(nb0 == sizeof(float));
12374
-
12375
- const int N = ne13;
12376
- const int OH = ne12;
12377
- const int OW = ne11;
12378
-
12379
- const int OC = ne03;
12380
- const int IC = ne02;
12381
- const int KH = ne01;
12382
- const int KW = ne00;
12383
-
12384
- const int ith = params->ith;
12385
- const int nth = params->nth;
12386
-
12387
- int64_t m = OC;
12388
- int64_t n = OH * OW;
12389
- int64_t k = IC * KH * KW;
12390
-
12391
- // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
12392
- for (int i = 0; i < N; i++) {
12393
- wsp_ggml_fp16_t * A = (wsp_ggml_fp16_t *)src0->data; // [m, k]
12394
- wsp_ggml_fp16_t * B = (wsp_ggml_fp16_t *)src1->data + i * m * k; // [n, k]
12395
- float * C = (float *)dst->data + i * m * n; // [m, n]
12396
-
12397
- gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
12398
- }
12399
- }
12400
-
12401
- static void wsp_ggml_compute_forward_conv_2d_f16_f32(
11754
+ static void wsp_ggml_compute_forward_im2col_f16(
12402
11755
  const struct wsp_ggml_compute_params * params,
12403
11756
  const struct wsp_ggml_tensor * src0,
12404
11757
  const struct wsp_ggml_tensor * src1,
12405
11758
  struct wsp_ggml_tensor * dst) {
12406
11759
  WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
12407
11760
  WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
12408
- WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
11761
+ WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16);
12409
11762
 
12410
11763
  int64_t t0 = wsp_ggml_perf_time_us();
12411
11764
  UNUSED(t0);
12412
11765
 
12413
- WSP_GGML_TENSOR_BINARY_OP_LOCALS
11766
+ WSP_GGML_TENSOR_BINARY_OP_LOCALS;
12414
11767
 
12415
- // src1: image [N, IC, IH, IW]
12416
- // src0: kernel [OC, IC, KH, KW]
12417
- // dst: result [N, OC, OH, OW]
12418
- // ne12: IC
12419
- // ne0: OW
12420
- // ne1: OH
12421
- // nk0: KW
12422
- // nk1: KH
12423
- // ne13: N
12424
-
12425
- const int N = ne13;
12426
- const int IC = ne12;
12427
- const int IH = ne11;
12428
- const int IW = ne10;
12429
-
12430
- const int OC = ne03;
12431
- // const int IC = ne02;
12432
- const int KH = ne01;
12433
- const int KW = ne00;
12434
-
12435
- const int OH = ne1;
12436
- const int OW = ne0;
11768
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
11769
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
11770
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
11771
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
11772
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
11773
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
11774
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
12437
11775
 
12438
11776
  const int ith = params->ith;
12439
11777
  const int nth = params->nth;
12440
11778
 
12441
- // const int nk0 = ne00;
12442
- // const int nk1 = ne01;
11779
+ const int64_t N = is_2D ? ne13 : ne12;
11780
+ const int64_t IC = is_2D ? ne12 : ne11;
11781
+ const int64_t IH = is_2D ? ne11 : 1;
11782
+ const int64_t IW = ne10;
12443
11783
 
12444
- // size of the convolution row - the kernel size unrolled across all channels
12445
- // const int ew0 = nk0*nk1*ne02;
12446
- // ew0: IC*KH*KW
11784
+ const int64_t KH = is_2D ? ne01 : 1;
11785
+ const int64_t KW = ne00;
12447
11786
 
12448
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
12449
- const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
12450
- const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
12451
- const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
12452
- const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
12453
- const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
11787
+ const int64_t OH = is_2D ? ne2 : 1;
11788
+ const int64_t OW = ne1;
11789
+
11790
+ int ofs0 = is_2D ? nb13 : nb12;
11791
+ int ofs1 = is_2D ? nb12 : nb11;
12454
11792
 
12455
11793
  WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
12456
11794
  WSP_GGML_ASSERT(nb10 == sizeof(float));
12457
11795
 
12458
11796
  if (params->type == WSP_GGML_TASK_INIT) {
12459
- memset(params->wdata, 0, params->wsize);
11797
+ return;
11798
+ }
12460
11799
 
12461
- // prepare source data (src1)
12462
- // im2col: [N, IC, IH, IW] => [N*OH*OW, IC*KH*KW]
11800
+ if (params->type == WSP_GGML_TASK_FINALIZE) {
11801
+ return;
11802
+ }
12463
11803
 
12464
- {
12465
- wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0;
11804
+ // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
11805
+ {
11806
+ wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) dst->data;
12466
11807
 
12467
- for (int in = 0; in < N; in++) {
12468
- for (int iic = 0; iic < IC; iic++) {
12469
- for (int ioh = 0; ioh < OH; ioh++) {
12470
- for (int iow = 0; iow < OW; iow++) {
11808
+ for (int64_t in = 0; in < N; in++) {
11809
+ for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
11810
+ for (int64_t iow = 0; iow < OW; iow++) {
11811
+ for (int64_t iic = ith; iic < IC; iic += nth) {
12471
11812
 
12472
- // micro kernel
12473
- wsp_ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
12474
- const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW]
11813
+ // micro kernel
11814
+ wsp_ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
11815
+ const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
12475
11816
 
12476
- for (int ikh = 0; ikh < KH; ikh++) {
12477
- for (int ikw = 0; ikw < KW; ikw++) {
12478
- const int iiw = iow*s0 + ikw*d0 - p0;
12479
- const int iih = ioh*s1 + ikh*d1 - p1;
11817
+ for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
11818
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
11819
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
11820
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
12480
11821
 
12481
- if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
12482
- dst_data[iic*(KH*KW) + ikh*KW + ikw] = WSP_GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
12483
- }
11822
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
11823
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
11824
+ } else {
11825
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = WSP_GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
12484
11826
  }
12485
11827
  }
12486
11828
  }
@@ -12488,77 +11830,10 @@ static void wsp_ggml_compute_forward_conv_2d_f16_f32(
12488
11830
  }
12489
11831
  }
12490
11832
  }
12491
-
12492
- return;
12493
- }
12494
-
12495
- if (params->type == WSP_GGML_TASK_FINALIZE) {
12496
- return;
12497
- }
12498
-
12499
- wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0;
12500
- // wdata: [N*OH*OW, IC*KH*KW]
12501
- // dst: result [N, OC, OH, OW]
12502
- // src0: kernel [OC, IC, KH, KW]
12503
-
12504
- int64_t m = OC;
12505
- int64_t n = OH * OW;
12506
- int64_t k = IC * KH * KW;
12507
-
12508
- // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
12509
- for (int i = 0; i < N; i++) {
12510
- wsp_ggml_fp16_t * A = (wsp_ggml_fp16_t *)src0->data; // [m, k]
12511
- wsp_ggml_fp16_t * B = (wsp_ggml_fp16_t *)wdata + i * m * k; // [n, k]
12512
- float * C = (float *)dst->data + i * m * n; // [m * k]
12513
-
12514
- gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
12515
- }
12516
- }
12517
-
12518
- static void wsp_ggml_compute_forward_conv_2d(
12519
- const struct wsp_ggml_compute_params * params,
12520
- const struct wsp_ggml_tensor * src0,
12521
- const struct wsp_ggml_tensor * src1,
12522
- struct wsp_ggml_tensor * dst) {
12523
- switch (src0->type) {
12524
- case WSP_GGML_TYPE_F16:
12525
- {
12526
- wsp_ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, dst);
12527
- } break;
12528
- case WSP_GGML_TYPE_F32:
12529
- {
12530
- //wsp_ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
12531
- WSP_GGML_ASSERT(false);
12532
- } break;
12533
- default:
12534
- {
12535
- WSP_GGML_ASSERT(false);
12536
- } break;
12537
- }
12538
- }
12539
-
12540
- static void wsp_ggml_compute_forward_conv_2d_stage_0(
12541
- const struct wsp_ggml_compute_params * params,
12542
- const struct wsp_ggml_tensor * src0,
12543
- const struct wsp_ggml_tensor * src1,
12544
- struct wsp_ggml_tensor * dst) {
12545
- switch (src0->type) {
12546
- case WSP_GGML_TYPE_F16:
12547
- {
12548
- wsp_ggml_compute_forward_conv_2d_stage_0_f32(params, src0, src1, dst);
12549
- } break;
12550
- case WSP_GGML_TYPE_F32:
12551
- {
12552
- WSP_GGML_ASSERT(false);
12553
- } break;
12554
- default:
12555
- {
12556
- WSP_GGML_ASSERT(false);
12557
- } break;
12558
11833
  }
12559
11834
  }
12560
11835
 
12561
- static void wsp_ggml_compute_forward_conv_2d_stage_1(
11836
+ static void wsp_ggml_compute_forward_im2col(
12562
11837
  const struct wsp_ggml_compute_params * params,
12563
11838
  const struct wsp_ggml_tensor * src0,
12564
11839
  const struct wsp_ggml_tensor * src1,
@@ -12566,7 +11841,7 @@ static void wsp_ggml_compute_forward_conv_2d_stage_1(
12566
11841
  switch (src0->type) {
12567
11842
  case WSP_GGML_TYPE_F16:
12568
11843
  {
12569
- wsp_ggml_compute_forward_conv_2d_stage_1_f16(params, src0, src1, dst);
11844
+ wsp_ggml_compute_forward_im2col_f16(params, src0, src1, dst);
12570
11845
  } break;
12571
11846
  case WSP_GGML_TYPE_F32:
12572
11847
  {
@@ -12880,6 +12155,67 @@ static void wsp_ggml_compute_forward_upscale(
12880
12155
  }
12881
12156
  }
12882
12157
 
12158
+ // wsp_ggml_compute_forward_argsort
12159
+
12160
+ static void wsp_ggml_compute_forward_argsort_f32(
12161
+ const struct wsp_ggml_compute_params * params,
12162
+ const struct wsp_ggml_tensor * src0,
12163
+ struct wsp_ggml_tensor * dst) {
12164
+
12165
+ if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
12166
+ return;
12167
+ }
12168
+
12169
+ WSP_GGML_TENSOR_UNARY_OP_LOCALS
12170
+
12171
+ WSP_GGML_ASSERT(nb0 == sizeof(float));
12172
+
12173
+ const int ith = params->ith;
12174
+ const int nth = params->nth;
12175
+
12176
+ const int64_t nr = wsp_ggml_nrows(src0);
12177
+
12178
+ enum wsp_ggml_sort_order order = (enum wsp_ggml_sort_order) wsp_ggml_get_op_params_i32(dst, 0);
12179
+
12180
+ for (int64_t i = ith; i < nr; i += nth) {
12181
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
12182
+ const float * src_data = (float *)((char *) src0->data + i*nb01);
12183
+
12184
+ for (int64_t j = 0; j < ne0; j++) {
12185
+ dst_data[j] = j;
12186
+ }
12187
+
12188
+ // C doesn't have a functional sort, so we do a bubble sort instead
12189
+ for (int64_t j = 0; j < ne0; j++) {
12190
+ for (int64_t k = j + 1; k < ne0; k++) {
12191
+ if ((order == WSP_GGML_SORT_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
12192
+ (order == WSP_GGML_SORT_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
12193
+ int32_t tmp = dst_data[j];
12194
+ dst_data[j] = dst_data[k];
12195
+ dst_data[k] = tmp;
12196
+ }
12197
+ }
12198
+ }
12199
+ }
12200
+ }
12201
+
12202
+ static void wsp_ggml_compute_forward_argsort(
12203
+ const struct wsp_ggml_compute_params * params,
12204
+ const struct wsp_ggml_tensor * src0,
12205
+ struct wsp_ggml_tensor * dst) {
12206
+
12207
+ switch (src0->type) {
12208
+ case WSP_GGML_TYPE_F32:
12209
+ {
12210
+ wsp_ggml_compute_forward_argsort_f32(params, src0, dst);
12211
+ } break;
12212
+ default:
12213
+ {
12214
+ WSP_GGML_ASSERT(false);
12215
+ } break;
12216
+ }
12217
+ }
12218
+
12883
12219
  // wsp_ggml_compute_forward_flash_attn
12884
12220
 
12885
12221
  static void wsp_ggml_compute_forward_flash_attn_f32(
@@ -14703,6 +14039,10 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
14703
14039
  {
14704
14040
  wsp_ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
14705
14041
  } break;
14042
+ case WSP_GGML_OP_MUL_MAT_ID:
14043
+ {
14044
+ wsp_ggml_compute_forward_mul_mat_id(params, tensor);
14045
+ } break;
14706
14046
  case WSP_GGML_OP_OUT_PROD:
14707
14047
  {
14708
14048
  wsp_ggml_compute_forward_out_prod(params, tensor->src[0], tensor->src[1], tensor);
@@ -14761,7 +14101,7 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
14761
14101
  } break;
14762
14102
  case WSP_GGML_OP_SOFT_MAX:
14763
14103
  {
14764
- wsp_ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
14104
+ wsp_ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
14765
14105
  } break;
14766
14106
  case WSP_GGML_OP_SOFT_MAX_BACK:
14767
14107
  {
@@ -14783,33 +14123,13 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
14783
14123
  {
14784
14124
  wsp_ggml_compute_forward_clamp(params, tensor->src[0], tensor);
14785
14125
  } break;
14786
- case WSP_GGML_OP_CONV_1D:
14787
- {
14788
- wsp_ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor);
14789
- } break;
14790
- case WSP_GGML_OP_CONV_1D_STAGE_0:
14791
- {
14792
- wsp_ggml_compute_forward_conv_1d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
14793
- } break;
14794
- case WSP_GGML_OP_CONV_1D_STAGE_1:
14795
- {
14796
- wsp_ggml_compute_forward_conv_1d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
14797
- } break;
14798
14126
  case WSP_GGML_OP_CONV_TRANSPOSE_1D:
14799
14127
  {
14800
14128
  wsp_ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor);
14801
14129
  } break;
14802
- case WSP_GGML_OP_CONV_2D:
14803
- {
14804
- wsp_ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor);
14805
- } break;
14806
- case WSP_GGML_OP_CONV_2D_STAGE_0:
14807
- {
14808
- wsp_ggml_compute_forward_conv_2d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
14809
- } break;
14810
- case WSP_GGML_OP_CONV_2D_STAGE_1:
14130
+ case WSP_GGML_OP_IM2COL:
14811
14131
  {
14812
- wsp_ggml_compute_forward_conv_2d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
14132
+ wsp_ggml_compute_forward_im2col(params, tensor->src[0], tensor->src[1], tensor);
14813
14133
  } break;
14814
14134
  case WSP_GGML_OP_CONV_TRANSPOSE_2D:
14815
14135
  {
@@ -14827,6 +14147,10 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
14827
14147
  {
14828
14148
  wsp_ggml_compute_forward_upscale(params, tensor->src[0], tensor);
14829
14149
  } break;
14150
+ case WSP_GGML_OP_ARGSORT:
14151
+ {
14152
+ wsp_ggml_compute_forward_argsort(params, tensor->src[0], tensor);
14153
+ } break;
14830
14154
  case WSP_GGML_OP_FLASH_ATTN:
14831
14155
  {
14832
14156
  const int32_t t = wsp_ggml_get_op_params_i32(tensor, 0);
@@ -15477,6 +14801,10 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
15477
14801
  zero_table);
15478
14802
  }
15479
14803
  } break;
14804
+ case WSP_GGML_OP_MUL_MAT_ID:
14805
+ {
14806
+ WSP_GGML_ASSERT(false); // TODO: not implemented
14807
+ } break;
15480
14808
  case WSP_GGML_OP_OUT_PROD:
15481
14809
  {
15482
14810
  WSP_GGML_ASSERT(false); // TODO: not implemented
@@ -15708,17 +15036,20 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
15708
15036
  // necessary for llama
15709
15037
  if (src0->grad) {
15710
15038
  //const int n_past = ((int32_t *) tensor->op_params)[0];
15711
- const int n_dims = ((int32_t *) tensor->op_params)[1];
15712
- const int mode = ((int32_t *) tensor->op_params)[2];
15713
- const int n_ctx = ((int32_t *) tensor->op_params)[3];
15714
- float freq_base;
15715
- float freq_scale;
15716
- float xpos_base;
15717
- bool xpos_down;
15718
- memcpy(&freq_base, (int32_t *) tensor->op_params + 4, sizeof(float));
15719
- memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
15720
- memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
15721
- memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
15039
+ const int n_dims = ((int32_t *) tensor->op_params)[1];
15040
+ const int mode = ((int32_t *) tensor->op_params)[2];
15041
+ const int n_ctx = ((int32_t *) tensor->op_params)[3];
15042
+ const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
15043
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
15044
+
15045
+ memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
15046
+ memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
15047
+ memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float));
15048
+ memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
15049
+ memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
15050
+ memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
15051
+ memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
15052
+ memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
15722
15053
 
15723
15054
  src0->grad = wsp_ggml_add_or_set(ctx,
15724
15055
  src0->grad,
@@ -15728,8 +15059,13 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
15728
15059
  n_dims,
15729
15060
  mode,
15730
15061
  n_ctx,
15062
+ n_orig_ctx,
15731
15063
  freq_base,
15732
15064
  freq_scale,
15065
+ ext_factor,
15066
+ attn_factor,
15067
+ beta_fast,
15068
+ beta_slow,
15733
15069
  xpos_base,
15734
15070
  xpos_down),
15735
15071
  zero_table);
@@ -15739,17 +15075,20 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
15739
15075
  {
15740
15076
  if (src0->grad) {
15741
15077
  //const int n_past = ((int32_t *) tensor->op_params)[0];
15742
- const int n_dims = ((int32_t *) tensor->op_params)[1];
15743
- const int mode = ((int32_t *) tensor->op_params)[2];
15744
- const int n_ctx = ((int32_t *) tensor->op_params)[3];
15745
- float freq_base;
15746
- float freq_scale;
15747
- float xpos_base;
15748
- bool xpos_down;
15749
- memcpy(&freq_base, (int32_t *) tensor->op_params + 4, sizeof(float));
15750
- memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
15751
- memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
15752
- memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
15078
+ const int n_dims = ((int32_t *) tensor->op_params)[1];
15079
+ const int mode = ((int32_t *) tensor->op_params)[2];
15080
+ const int n_ctx = ((int32_t *) tensor->op_params)[3];
15081
+ const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
15082
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
15083
+
15084
+ memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
15085
+ memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
15086
+ memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float));
15087
+ memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
15088
+ memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
15089
+ memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
15090
+ memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
15091
+ memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
15753
15092
 
15754
15093
  src0->grad = wsp_ggml_add_or_set(ctx,
15755
15094
  src0->grad,
@@ -15758,14 +15097,14 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
15758
15097
  src1,
15759
15098
  n_dims,
15760
15099
  mode,
15761
- 0,
15762
15100
  n_ctx,
15101
+ n_orig_ctx,
15763
15102
  freq_base,
15764
15103
  freq_scale,
15765
- 0.0f,
15766
- 1.0f,
15767
- 0.0f,
15768
- 0.0f,
15104
+ ext_factor,
15105
+ attn_factor,
15106
+ beta_fast,
15107
+ beta_slow,
15769
15108
  xpos_base,
15770
15109
  xpos_down,
15771
15110
  false),
@@ -15780,31 +15119,11 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
15780
15119
  {
15781
15120
  WSP_GGML_ASSERT(false); // TODO: not implemented
15782
15121
  } break;
15783
- case WSP_GGML_OP_CONV_1D:
15784
- {
15785
- WSP_GGML_ASSERT(false); // TODO: not implemented
15786
- } break;
15787
- case WSP_GGML_OP_CONV_1D_STAGE_0:
15788
- {
15789
- WSP_GGML_ASSERT(false); // TODO: not implemented
15790
- } break;
15791
- case WSP_GGML_OP_CONV_1D_STAGE_1:
15792
- {
15793
- WSP_GGML_ASSERT(false); // TODO: not implemented
15794
- } break;
15795
15122
  case WSP_GGML_OP_CONV_TRANSPOSE_1D:
15796
15123
  {
15797
15124
  WSP_GGML_ASSERT(false); // TODO: not implemented
15798
15125
  } break;
15799
- case WSP_GGML_OP_CONV_2D:
15800
- {
15801
- WSP_GGML_ASSERT(false); // TODO: not implemented
15802
- } break;
15803
- case WSP_GGML_OP_CONV_2D_STAGE_0:
15804
- {
15805
- WSP_GGML_ASSERT(false); // TODO: not implemented
15806
- } break;
15807
- case WSP_GGML_OP_CONV_2D_STAGE_1:
15126
+ case WSP_GGML_OP_IM2COL:
15808
15127
  {
15809
15128
  WSP_GGML_ASSERT(false); // TODO: not implemented
15810
15129
  } break;
@@ -15824,6 +15143,10 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
15824
15143
  {
15825
15144
  WSP_GGML_ASSERT(false); // TODO: not implemented
15826
15145
  } break;
15146
+ case WSP_GGML_OP_ARGSORT:
15147
+ {
15148
+ WSP_GGML_ASSERT(false); // TODO: not implemented
15149
+ } break;
15827
15150
  case WSP_GGML_OP_FLASH_ATTN:
15828
15151
  {
15829
15152
  struct wsp_ggml_tensor * flash_grad = NULL;
@@ -16184,12 +15507,8 @@ struct wsp_ggml_cgraph * wsp_ggml_new_graph(struct wsp_ggml_context * ctx) {
16184
15507
  return wsp_ggml_new_graph_custom(ctx, WSP_GGML_DEFAULT_GRAPH_SIZE, false);
16185
15508
  }
16186
15509
 
16187
- struct wsp_ggml_cgraph * wsp_ggml_graph_view(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph0, int i0, int i1) {
16188
- const size_t obj_size = sizeof(struct wsp_ggml_cgraph);
16189
- struct wsp_ggml_object * obj = wsp_ggml_new_object(ctx, WSP_GGML_OBJECT_GRAPH, obj_size);
16190
- struct wsp_ggml_cgraph * cgraph = (struct wsp_ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs);
16191
-
16192
- *cgraph = (struct wsp_ggml_cgraph) {
15510
+ struct wsp_ggml_cgraph wsp_ggml_graph_view(struct wsp_ggml_cgraph * cgraph0, int i0, int i1) {
15511
+ struct wsp_ggml_cgraph cgraph = {
16193
15512
  /*.size =*/ 0,
16194
15513
  /*.n_nodes =*/ i1 - i0,
16195
15514
  /*.n_leafs =*/ 0,
@@ -16424,7 +15743,6 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
16424
15743
  n_tasks = n_threads;
16425
15744
  } break;
16426
15745
  case WSP_GGML_OP_SUB:
16427
- case WSP_GGML_OP_DIV:
16428
15746
  case WSP_GGML_OP_SQR:
16429
15747
  case WSP_GGML_OP_SQRT:
16430
15748
  case WSP_GGML_OP_LOG:
@@ -16457,10 +15775,13 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
16457
15775
  {
16458
15776
  n_tasks = n_threads;
16459
15777
  } break;
15778
+ default:
15779
+ WSP_GGML_ASSERT(false);
16460
15780
  }
16461
15781
  break;
16462
15782
  case WSP_GGML_OP_SILU_BACK:
16463
15783
  case WSP_GGML_OP_MUL:
15784
+ case WSP_GGML_OP_DIV:
16464
15785
  case WSP_GGML_OP_NORM:
16465
15786
  case WSP_GGML_OP_RMS_NORM:
16466
15787
  case WSP_GGML_OP_RMS_NORM_BACK:
@@ -16498,6 +15819,11 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
16498
15819
  }
16499
15820
  #endif
16500
15821
  } break;
15822
+ case WSP_GGML_OP_MUL_MAT_ID:
15823
+ {
15824
+ // FIXME: blas
15825
+ n_tasks = n_threads;
15826
+ } break;
16501
15827
  case WSP_GGML_OP_OUT_PROD:
16502
15828
  {
16503
15829
  n_tasks = n_threads;
@@ -16517,7 +15843,6 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
16517
15843
  } break;
16518
15844
  case WSP_GGML_OP_DIAG_MASK_ZERO:
16519
15845
  case WSP_GGML_OP_DIAG_MASK_INF:
16520
- case WSP_GGML_OP_SOFT_MAX:
16521
15846
  case WSP_GGML_OP_SOFT_MAX_BACK:
16522
15847
  case WSP_GGML_OP_ROPE:
16523
15848
  case WSP_GGML_OP_ROPE_BACK:
@@ -16533,31 +15858,15 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
16533
15858
  {
16534
15859
  n_tasks = 1; //TODO
16535
15860
  } break;
16536
- case WSP_GGML_OP_CONV_1D:
16537
- {
16538
- n_tasks = n_threads;
16539
- } break;
16540
- case WSP_GGML_OP_CONV_1D_STAGE_0:
16541
- {
16542
- n_tasks = n_threads;
16543
- } break;
16544
- case WSP_GGML_OP_CONV_1D_STAGE_1:
15861
+ case WSP_GGML_OP_SOFT_MAX:
16545
15862
  {
16546
- n_tasks = n_threads;
15863
+ n_tasks = MIN(MIN(4, n_threads), wsp_ggml_nrows(node->src[0]));
16547
15864
  } break;
16548
15865
  case WSP_GGML_OP_CONV_TRANSPOSE_1D:
16549
15866
  {
16550
15867
  n_tasks = n_threads;
16551
15868
  } break;
16552
- case WSP_GGML_OP_CONV_2D:
16553
- {
16554
- n_tasks = n_threads;
16555
- } break;
16556
- case WSP_GGML_OP_CONV_2D_STAGE_0:
16557
- {
16558
- n_tasks = n_threads;
16559
- } break;
16560
- case WSP_GGML_OP_CONV_2D_STAGE_1:
15869
+ case WSP_GGML_OP_IM2COL:
16561
15870
  {
16562
15871
  n_tasks = n_threads;
16563
15872
  } break;
@@ -16574,6 +15883,10 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
16574
15883
  {
16575
15884
  n_tasks = n_threads;
16576
15885
  } break;
15886
+ case WSP_GGML_OP_ARGSORT:
15887
+ {
15888
+ n_tasks = n_threads;
15889
+ } break;
16577
15890
  case WSP_GGML_OP_FLASH_ATTN:
16578
15891
  {
16579
15892
  n_tasks = n_threads;
@@ -16642,6 +15955,12 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
16642
15955
  } break;
16643
15956
  default:
16644
15957
  {
15958
+ fprintf(stderr, "%s: op not implemented: ", __func__);
15959
+ if (node->op < WSP_GGML_OP_COUNT) {
15960
+ fprintf(stderr, "%s\n", wsp_ggml_op_name(node->op));
15961
+ } else {
15962
+ fprintf(stderr, "%d\n", node->op);
15963
+ }
16645
15964
  WSP_GGML_ASSERT(false);
16646
15965
  } break;
16647
15966
  }
@@ -16782,18 +16101,16 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
16782
16101
 
16783
16102
  // thread scheduling for the different operations + work buffer size estimation
16784
16103
  for (int i = 0; i < cgraph->n_nodes; i++) {
16785
- int n_tasks = 1;
16786
-
16787
16104
  struct wsp_ggml_tensor * node = cgraph->nodes[i];
16788
16105
 
16106
+ const int n_tasks = wsp_ggml_get_n_tasks(node, n_threads);
16107
+
16789
16108
  size_t cur = 0;
16790
16109
 
16791
16110
  switch (node->op) {
16792
16111
  case WSP_GGML_OP_CPY:
16793
16112
  case WSP_GGML_OP_DUP:
16794
16113
  {
16795
- n_tasks = n_threads;
16796
-
16797
16114
  if (wsp_ggml_is_quantized(node->type)) {
16798
16115
  cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->ne[0] * n_tasks;
16799
16116
  }
@@ -16801,16 +16118,12 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
16801
16118
  case WSP_GGML_OP_ADD:
16802
16119
  case WSP_GGML_OP_ADD1:
16803
16120
  {
16804
- n_tasks = n_threads;
16805
-
16806
16121
  if (wsp_ggml_is_quantized(node->src[0]->type)) {
16807
16122
  cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
16808
16123
  }
16809
16124
  } break;
16810
16125
  case WSP_GGML_OP_ACC:
16811
16126
  {
16812
- n_tasks = n_threads;
16813
-
16814
16127
  if (wsp_ggml_is_quantized(node->src[0]->type)) {
16815
16128
  cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
16816
16129
  }
@@ -16836,45 +16149,32 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
16836
16149
  cur = wsp_ggml_type_size(vec_dot_type)*wsp_ggml_nelements(node->src[1])/wsp_ggml_blck_size(vec_dot_type);
16837
16150
  }
16838
16151
  } break;
16152
+ case WSP_GGML_OP_MUL_MAT_ID:
16153
+ {
16154
+ const struct wsp_ggml_tensor * a = node->src[2];
16155
+ const struct wsp_ggml_tensor * b = node->src[1];
16156
+ const enum wsp_ggml_type vec_dot_type = type_traits[a->type].vec_dot_type;
16157
+ #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS)
16158
+ if (wsp_ggml_compute_forward_mul_mat_use_blas(a, b, node)) {
16159
+ if (a->type != WSP_GGML_TYPE_F32) {
16160
+ // here we need memory just for single 2D matrix from src0
16161
+ cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(a->ne[0]*a->ne[1]);
16162
+ }
16163
+ } else
16164
+ #endif
16165
+ if (b->type != vec_dot_type) {
16166
+ cur = wsp_ggml_type_size(vec_dot_type)*wsp_ggml_nelements(b)/wsp_ggml_blck_size(vec_dot_type);
16167
+ }
16168
+ } break;
16839
16169
  case WSP_GGML_OP_OUT_PROD:
16840
16170
  {
16841
- n_tasks = n_threads;
16842
-
16843
16171
  if (wsp_ggml_is_quantized(node->src[0]->type)) {
16844
16172
  cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
16845
16173
  }
16846
16174
  } break;
16847
- case WSP_GGML_OP_CONV_1D:
16175
+ case WSP_GGML_OP_SOFT_MAX:
16848
16176
  {
16849
- WSP_GGML_ASSERT(node->src[0]->ne[3] == 1);
16850
- WSP_GGML_ASSERT(node->src[1]->ne[2] == 1);
16851
- WSP_GGML_ASSERT(node->src[1]->ne[3] == 1);
16852
-
16853
- const int64_t ne00 = node->src[0]->ne[0];
16854
- const int64_t ne01 = node->src[0]->ne[1];
16855
- const int64_t ne02 = node->src[0]->ne[2];
16856
-
16857
- const int64_t ne10 = node->src[1]->ne[0];
16858
- const int64_t ne11 = node->src[1]->ne[1];
16859
-
16860
- const int64_t ne0 = node->ne[0];
16861
- const int64_t ne1 = node->ne[1];
16862
- const int64_t nk = ne00;
16863
- const int64_t ew0 = nk * ne01;
16864
-
16865
- UNUSED(ne02);
16866
- UNUSED(ne10);
16867
- UNUSED(ne11);
16868
-
16869
- if (node->src[0]->type == WSP_GGML_TYPE_F16 &&
16870
- node->src[1]->type == WSP_GGML_TYPE_F32) {
16871
- cur = sizeof(wsp_ggml_fp16_t)*(ne0*ne1*ew0);
16872
- } else if (node->src[0]->type == WSP_GGML_TYPE_F32 &&
16873
- node->src[1]->type == WSP_GGML_TYPE_F32) {
16874
- cur = sizeof(float)*(ne0*ne1*ew0);
16875
- } else {
16876
- WSP_GGML_ASSERT(false);
16877
- }
16177
+ cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->ne[0] * n_tasks;
16878
16178
  } break;
16879
16179
  case WSP_GGML_OP_CONV_TRANSPOSE_1D:
16880
16180
  {
@@ -16901,38 +16201,6 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
16901
16201
  WSP_GGML_ASSERT(false);
16902
16202
  }
16903
16203
  } break;
16904
- case WSP_GGML_OP_CONV_2D:
16905
- {
16906
- const int64_t ne00 = node->src[0]->ne[0]; // W
16907
- const int64_t ne01 = node->src[0]->ne[1]; // H
16908
- const int64_t ne02 = node->src[0]->ne[2]; // C
16909
- const int64_t ne03 = node->src[0]->ne[3]; // N
16910
-
16911
- const int64_t ne10 = node->src[1]->ne[0]; // W
16912
- const int64_t ne11 = node->src[1]->ne[1]; // H
16913
- const int64_t ne12 = node->src[1]->ne[2]; // C
16914
-
16915
- const int64_t ne0 = node->ne[0];
16916
- const int64_t ne1 = node->ne[1];
16917
- const int64_t ne2 = node->ne[2];
16918
- const int64_t ne3 = node->ne[3];
16919
- const int64_t nk = ne00*ne01;
16920
- const int64_t ew0 = nk * ne02;
16921
-
16922
- UNUSED(ne03);
16923
- UNUSED(ne2);
16924
-
16925
- if (node->src[0]->type == WSP_GGML_TYPE_F16 &&
16926
- node->src[1]->type == WSP_GGML_TYPE_F32) {
16927
- // im2col: [N*OH*OW, IC*KH*KW]
16928
- cur = sizeof(wsp_ggml_fp16_t)*(ne3*ne0*ne1*ew0);
16929
- } else if (node->src[0]->type == WSP_GGML_TYPE_F32 &&
16930
- node->src[1]->type == WSP_GGML_TYPE_F32) {
16931
- cur = sizeof(float)* (ne10*ne11*ne12);
16932
- } else {
16933
- WSP_GGML_ASSERT(false);
16934
- }
16935
- } break;
16936
16204
  case WSP_GGML_OP_CONV_TRANSPOSE_2D:
16937
16205
  {
16938
16206
  const int64_t ne00 = node->src[0]->ne[0]; // W
@@ -16949,8 +16217,6 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
16949
16217
  } break;
16950
16218
  case WSP_GGML_OP_FLASH_ATTN:
16951
16219
  {
16952
- n_tasks = n_threads;
16953
-
16954
16220
  const int64_t ne11 = wsp_ggml_up(node->src[1]->ne[1], WSP_GGML_SOFT_MAX_UNROLL);
16955
16221
 
16956
16222
  if (node->src[1]->type == WSP_GGML_TYPE_F32) {
@@ -16963,8 +16229,6 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
16963
16229
  } break;
16964
16230
  case WSP_GGML_OP_FLASH_FF:
16965
16231
  {
16966
- n_tasks = n_threads;
16967
-
16968
16232
  if (node->src[1]->type == WSP_GGML_TYPE_F32) {
16969
16233
  cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
16970
16234
  cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
@@ -16975,8 +16239,6 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
16975
16239
  } break;
16976
16240
  case WSP_GGML_OP_FLASH_ATTN_BACK:
16977
16241
  {
16978
- n_tasks = n_threads;
16979
-
16980
16242
  const int64_t D = node->src[0]->ne[0];
16981
16243
  const int64_t ne11 = wsp_ggml_up(node->src[1]->ne[1], WSP_GGML_SOFT_MAX_UNROLL);
16982
16244
  const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in wsp_ggml_compute_forward_flash_attn_back
@@ -16991,8 +16253,6 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
16991
16253
 
16992
16254
  case WSP_GGML_OP_CROSS_ENTROPY_LOSS:
16993
16255
  {
16994
- n_tasks = n_threads;
16995
-
16996
16256
  cur = wsp_ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
16997
16257
  } break;
16998
16258
  case WSP_GGML_OP_COUNT:
@@ -18719,14 +17979,14 @@ enum wsp_ggml_opt_result wsp_ggml_opt_resume_g(
18719
17979
 
18720
17980
  ////////////////////////////////////////////////////////////////////////////////
18721
17981
 
18722
- size_t wsp_ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
17982
+ size_t wsp_ggml_wsp_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
18723
17983
  assert(k % QK4_0 == 0);
18724
17984
  const int nb = k / QK4_0;
18725
17985
 
18726
17986
  for (int b = 0; b < n; b += k) {
18727
17987
  block_q4_0 * restrict y = (block_q4_0 *) dst + b/QK4_0;
18728
17988
 
18729
- quantize_row_q4_0_reference(src + b, y, k);
17989
+ wsp_quantize_row_q4_0_reference(src + b, y, k);
18730
17990
 
18731
17991
  for (int i = 0; i < nb; i++) {
18732
17992
  for (int j = 0; j < QK4_0; j += 2) {
@@ -18742,14 +18002,14 @@ size_t wsp_ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64
18742
18002
  return (n/QK4_0*sizeof(block_q4_0));
18743
18003
  }
18744
18004
 
18745
- size_t wsp_ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
18005
+ size_t wsp_ggml_wsp_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
18746
18006
  assert(k % QK4_1 == 0);
18747
18007
  const int nb = k / QK4_1;
18748
18008
 
18749
18009
  for (int b = 0; b < n; b += k) {
18750
18010
  block_q4_1 * restrict y = (block_q4_1 *) dst + b/QK4_1;
18751
18011
 
18752
- quantize_row_q4_1_reference(src + b, y, k);
18012
+ wsp_quantize_row_q4_1_reference(src + b, y, k);
18753
18013
 
18754
18014
  for (int i = 0; i < nb; i++) {
18755
18015
  for (int j = 0; j < QK4_1; j += 2) {
@@ -18765,22 +18025,22 @@ size_t wsp_ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64
18765
18025
  return (n/QK4_1*sizeof(block_q4_1));
18766
18026
  }
18767
18027
 
18768
- size_t wsp_ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
18028
+ size_t wsp_ggml_wsp_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
18769
18029
  assert(k % QK5_0 == 0);
18770
18030
  const int nb = k / QK5_0;
18771
18031
 
18772
18032
  for (int b = 0; b < n; b += k) {
18773
18033
  block_q5_0 * restrict y = (block_q5_0 *)dst + b/QK5_0;
18774
18034
 
18775
- quantize_row_q5_0_reference(src + b, y, k);
18035
+ wsp_quantize_row_q5_0_reference(src + b, y, k);
18776
18036
 
18777
18037
  for (int i = 0; i < nb; i++) {
18778
18038
  uint32_t qh;
18779
18039
  memcpy(&qh, &y[i].qh, sizeof(qh));
18780
18040
 
18781
18041
  for (int j = 0; j < QK5_0; j += 2) {
18782
- const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
18783
- const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12));
18042
+ const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4;
18043
+ const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12));
18784
18044
 
18785
18045
  // cast to 16 bins
18786
18046
  const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2;
@@ -18795,22 +18055,22 @@ size_t wsp_ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64
18795
18055
  return (n/QK5_0*sizeof(block_q5_0));
18796
18056
  }
18797
18057
 
18798
- size_t wsp_ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) {
18058
+ size_t wsp_ggml_wsp_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) {
18799
18059
  assert(k % QK5_1 == 0);
18800
18060
  const int nb = k / QK5_1;
18801
18061
 
18802
18062
  for (int b = 0; b < n; b += k) {
18803
18063
  block_q5_1 * restrict y = (block_q5_1 *)dst + b/QK5_1;
18804
18064
 
18805
- quantize_row_q5_1_reference(src + b, y, k);
18065
+ wsp_quantize_row_q5_1_reference(src + b, y, k);
18806
18066
 
18807
18067
  for (int i = 0; i < nb; i++) {
18808
18068
  uint32_t qh;
18809
18069
  memcpy(&qh, &y[i].qh, sizeof(qh));
18810
18070
 
18811
18071
  for (int j = 0; j < QK5_1; j += 2) {
18812
- const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
18813
- const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12));
18072
+ const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4;
18073
+ const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12));
18814
18074
 
18815
18075
  // cast to 16 bins
18816
18076
  const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2;
@@ -18825,14 +18085,14 @@ size_t wsp_ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64
18825
18085
  return (n/QK5_1*sizeof(block_q5_1));
18826
18086
  }
18827
18087
 
18828
- size_t wsp_ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) {
18088
+ size_t wsp_ggml_wsp_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) {
18829
18089
  assert(k % QK8_0 == 0);
18830
18090
  const int nb = k / QK8_0;
18831
18091
 
18832
18092
  for (int b = 0; b < n; b += k) {
18833
18093
  block_q8_0 * restrict y = (block_q8_0 *)dst + b/QK8_0;
18834
18094
 
18835
- quantize_row_q8_0_reference(src + b, y, k);
18095
+ wsp_quantize_row_q8_0_reference(src + b, y, k);
18836
18096
 
18837
18097
  for (int i = 0; i < nb; i++) {
18838
18098
  for (int j = 0; j < QK8_0; ++j) {
@@ -18846,68 +18106,68 @@ size_t wsp_ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64
18846
18106
  return (n/QK8_0*sizeof(block_q8_0));
18847
18107
  }
18848
18108
 
18849
- size_t wsp_ggml_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
18109
+ size_t wsp_ggml_wsp_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
18850
18110
  size_t result = 0;
18851
18111
  switch (type) {
18852
18112
  case WSP_GGML_TYPE_Q4_0:
18853
18113
  {
18854
18114
  WSP_GGML_ASSERT(start % QK4_0 == 0);
18855
18115
  block_q4_0 * block = (block_q4_0*)dst + start / QK4_0;
18856
- result = wsp_ggml_quantize_q4_0(src + start, block, n, n, hist);
18116
+ result = wsp_ggml_wsp_quantize_q4_0(src + start, block, n, n, hist);
18857
18117
  } break;
18858
18118
  case WSP_GGML_TYPE_Q4_1:
18859
18119
  {
18860
18120
  WSP_GGML_ASSERT(start % QK4_1 == 0);
18861
18121
  block_q4_1 * block = (block_q4_1*)dst + start / QK4_1;
18862
- result = wsp_ggml_quantize_q4_1(src + start, block, n, n, hist);
18122
+ result = wsp_ggml_wsp_quantize_q4_1(src + start, block, n, n, hist);
18863
18123
  } break;
18864
18124
  case WSP_GGML_TYPE_Q5_0:
18865
18125
  {
18866
18126
  WSP_GGML_ASSERT(start % QK5_0 == 0);
18867
18127
  block_q5_0 * block = (block_q5_0*)dst + start / QK5_0;
18868
- result = wsp_ggml_quantize_q5_0(src + start, block, n, n, hist);
18128
+ result = wsp_ggml_wsp_quantize_q5_0(src + start, block, n, n, hist);
18869
18129
  } break;
18870
18130
  case WSP_GGML_TYPE_Q5_1:
18871
18131
  {
18872
18132
  WSP_GGML_ASSERT(start % QK5_1 == 0);
18873
18133
  block_q5_1 * block = (block_q5_1*)dst + start / QK5_1;
18874
- result = wsp_ggml_quantize_q5_1(src + start, block, n, n, hist);
18134
+ result = wsp_ggml_wsp_quantize_q5_1(src + start, block, n, n, hist);
18875
18135
  } break;
18876
18136
  case WSP_GGML_TYPE_Q8_0:
18877
18137
  {
18878
18138
  WSP_GGML_ASSERT(start % QK8_0 == 0);
18879
18139
  block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
18880
- result = wsp_ggml_quantize_q8_0(src + start, block, n, n, hist);
18140
+ result = wsp_ggml_wsp_quantize_q8_0(src + start, block, n, n, hist);
18881
18141
  } break;
18882
18142
  case WSP_GGML_TYPE_Q2_K:
18883
18143
  {
18884
18144
  WSP_GGML_ASSERT(start % QK_K == 0);
18885
18145
  block_q2_K * block = (block_q2_K*)dst + start / QK_K;
18886
- result = wsp_ggml_quantize_q2_K(src + start, block, n, n, hist);
18146
+ result = wsp_ggml_wsp_quantize_q2_K(src + start, block, n, n, hist);
18887
18147
  } break;
18888
18148
  case WSP_GGML_TYPE_Q3_K:
18889
18149
  {
18890
18150
  WSP_GGML_ASSERT(start % QK_K == 0);
18891
18151
  block_q3_K * block = (block_q3_K*)dst + start / QK_K;
18892
- result = wsp_ggml_quantize_q3_K(src + start, block, n, n, hist);
18152
+ result = wsp_ggml_wsp_quantize_q3_K(src + start, block, n, n, hist);
18893
18153
  } break;
18894
18154
  case WSP_GGML_TYPE_Q4_K:
18895
18155
  {
18896
18156
  WSP_GGML_ASSERT(start % QK_K == 0);
18897
18157
  block_q4_K * block = (block_q4_K*)dst + start / QK_K;
18898
- result = wsp_ggml_quantize_q4_K(src + start, block, n, n, hist);
18158
+ result = wsp_ggml_wsp_quantize_q4_K(src + start, block, n, n, hist);
18899
18159
  } break;
18900
18160
  case WSP_GGML_TYPE_Q5_K:
18901
18161
  {
18902
18162
  WSP_GGML_ASSERT(start % QK_K == 0);
18903
18163
  block_q5_K * block = (block_q5_K*)dst + start / QK_K;
18904
- result = wsp_ggml_quantize_q5_K(src + start, block, n, n, hist);
18164
+ result = wsp_ggml_wsp_quantize_q5_K(src + start, block, n, n, hist);
18905
18165
  } break;
18906
18166
  case WSP_GGML_TYPE_Q6_K:
18907
18167
  {
18908
18168
  WSP_GGML_ASSERT(start % QK_K == 0);
18909
18169
  block_q6_K * block = (block_q6_K*)dst + start / QK_K;
18910
- result = wsp_ggml_quantize_q6_K(src + start, block, n, n, hist);
18170
+ result = wsp_ggml_wsp_quantize_q6_K(src + start, block, n, n, hist);
18911
18171
  } break;
18912
18172
  case WSP_GGML_TYPE_F16:
18913
18173
  {
@@ -19000,6 +18260,7 @@ struct wsp_gguf_kv {
19000
18260
 
19001
18261
  struct wsp_gguf_header {
19002
18262
  char magic[4];
18263
+
19003
18264
  uint32_t version;
19004
18265
  uint64_t n_tensors; // GGUFv2
19005
18266
  uint64_t n_kv; // GGUFv2
@@ -19089,7 +18350,7 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
19089
18350
 
19090
18351
  for (uint32_t i = 0; i < sizeof(magic); i++) {
19091
18352
  if (magic[i] != WSP_GGUF_MAGIC[i]) {
19092
- fprintf(stderr, "%s: invalid magic characters %s.\n", __func__, magic);
18353
+ fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
19093
18354
  fclose(file);
19094
18355
  return NULL;
19095
18356
  }
@@ -19104,7 +18365,6 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
19104
18365
  {
19105
18366
  strncpy(ctx->header.magic, magic, 4);
19106
18367
 
19107
-
19108
18368
  ctx->kv = NULL;
19109
18369
  ctx->infos = NULL;
19110
18370
  ctx->data = NULL;
@@ -19132,7 +18392,7 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
19132
18392
  {
19133
18393
  ctx->kv = malloc(ctx->header.n_kv * sizeof(struct wsp_gguf_kv));
19134
18394
 
19135
- for (uint32_t i = 0; i < ctx->header.n_kv; ++i) {
18395
+ for (uint64_t i = 0; i < ctx->header.n_kv; ++i) {
19136
18396
  struct wsp_gguf_kv * kv = &ctx->kv[i];
19137
18397
 
19138
18398
  //fprintf(stderr, "%s: reading kv %d\n", __func__, i);
@@ -19179,7 +18439,7 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
19179
18439
  case WSP_GGUF_TYPE_STRING:
19180
18440
  {
19181
18441
  kv->value.arr.data = malloc(kv->value.arr.n * sizeof(struct wsp_gguf_str));
19182
- for (uint32_t j = 0; j < kv->value.arr.n; ++j) {
18442
+ for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
19183
18443
  ok = ok && wsp_gguf_fread_str(file, &((struct wsp_gguf_str *) kv->value.arr.data)[j], &offset);
19184
18444
  }
19185
18445
  } break;
@@ -19207,7 +18467,7 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
19207
18467
  {
19208
18468
  ctx->infos = malloc(ctx->header.n_tensors * sizeof(struct wsp_gguf_tensor_info));
19209
18469
 
19210
- for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
18470
+ for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
19211
18471
  struct wsp_gguf_tensor_info * info = &ctx->infos[i];
19212
18472
 
19213
18473
  for (int j = 0; j < WSP_GGML_MAX_DIMS; ++j) {
@@ -19254,7 +18514,7 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
19254
18514
  // compute the total size of the data section, taking into account the alignment
19255
18515
  {
19256
18516
  ctx->size = 0;
19257
- for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
18517
+ for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
19258
18518
  struct wsp_gguf_tensor_info * info = &ctx->infos[i];
19259
18519
 
19260
18520
  const int64_t ne =
@@ -19323,7 +18583,7 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
19323
18583
  wsp_ggml_set_no_alloc(ctx_data, true);
19324
18584
 
19325
18585
  // create the tensors
19326
- for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
18586
+ for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
19327
18587
  const int64_t ne[WSP_GGML_MAX_DIMS] = {
19328
18588
  ctx->infos[i].ne[0],
19329
18589
  ctx->infos[i].ne[1],
@@ -19458,24 +18718,29 @@ int wsp_gguf_find_key(const struct wsp_gguf_context * ctx, const char * key) {
19458
18718
  }
19459
18719
 
19460
18720
  const char * wsp_gguf_get_key(const struct wsp_gguf_context * ctx, int key_id) {
18721
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19461
18722
  return ctx->kv[key_id].key.data;
19462
18723
  }
19463
18724
 
19464
18725
  enum wsp_gguf_type wsp_gguf_get_kv_type(const struct wsp_gguf_context * ctx, int key_id) {
18726
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19465
18727
  return ctx->kv[key_id].type;
19466
18728
  }
19467
18729
 
19468
18730
  enum wsp_gguf_type wsp_gguf_get_arr_type(const struct wsp_gguf_context * ctx, int key_id) {
18731
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19469
18732
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_ARRAY);
19470
18733
  return ctx->kv[key_id].value.arr.type;
19471
18734
  }
19472
18735
 
19473
18736
  const void * wsp_gguf_get_arr_data(const struct wsp_gguf_context * ctx, int key_id) {
18737
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19474
18738
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_ARRAY);
19475
18739
  return ctx->kv[key_id].value.arr.data;
19476
18740
  }
19477
18741
 
19478
18742
  const char * wsp_gguf_get_arr_str(const struct wsp_gguf_context * ctx, int key_id, int i) {
18743
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19479
18744
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_ARRAY);
19480
18745
  struct wsp_gguf_kv * kv = &ctx->kv[key_id];
19481
18746
  struct wsp_gguf_str * str = &((struct wsp_gguf_str *) kv->value.arr.data)[i];
@@ -19483,70 +18748,90 @@ const char * wsp_gguf_get_arr_str(const struct wsp_gguf_context * ctx, int key_i
19483
18748
  }
19484
18749
 
19485
18750
  int wsp_gguf_get_arr_n(const struct wsp_gguf_context * ctx, int key_id) {
18751
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19486
18752
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_ARRAY);
19487
18753
  return ctx->kv[key_id].value.arr.n;
19488
18754
  }
19489
18755
 
19490
18756
  uint8_t wsp_gguf_get_val_u8(const struct wsp_gguf_context * ctx, int key_id) {
18757
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19491
18758
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_UINT8);
19492
18759
  return ctx->kv[key_id].value.uint8;
19493
18760
  }
19494
18761
 
19495
18762
  int8_t wsp_gguf_get_val_i8(const struct wsp_gguf_context * ctx, int key_id) {
18763
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19496
18764
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_INT8);
19497
18765
  return ctx->kv[key_id].value.int8;
19498
18766
  }
19499
18767
 
19500
18768
  uint16_t wsp_gguf_get_val_u16(const struct wsp_gguf_context * ctx, int key_id) {
18769
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19501
18770
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_UINT16);
19502
18771
  return ctx->kv[key_id].value.uint16;
19503
18772
  }
19504
18773
 
19505
18774
  int16_t wsp_gguf_get_val_i16(const struct wsp_gguf_context * ctx, int key_id) {
18775
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19506
18776
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_INT16);
19507
18777
  return ctx->kv[key_id].value.int16;
19508
18778
  }
19509
18779
 
19510
18780
  uint32_t wsp_gguf_get_val_u32(const struct wsp_gguf_context * ctx, int key_id) {
18781
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19511
18782
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_UINT32);
19512
18783
  return ctx->kv[key_id].value.uint32;
19513
18784
  }
19514
18785
 
19515
18786
  int32_t wsp_gguf_get_val_i32(const struct wsp_gguf_context * ctx, int key_id) {
18787
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19516
18788
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_INT32);
19517
18789
  return ctx->kv[key_id].value.int32;
19518
18790
  }
19519
18791
 
19520
18792
  float wsp_gguf_get_val_f32(const struct wsp_gguf_context * ctx, int key_id) {
18793
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19521
18794
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_FLOAT32);
19522
18795
  return ctx->kv[key_id].value.float32;
19523
18796
  }
19524
18797
 
19525
18798
  uint64_t wsp_gguf_get_val_u64(const struct wsp_gguf_context * ctx, int key_id) {
18799
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19526
18800
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_UINT64);
19527
18801
  return ctx->kv[key_id].value.uint64;
19528
18802
  }
19529
18803
 
19530
18804
  int64_t wsp_gguf_get_val_i64(const struct wsp_gguf_context * ctx, int key_id) {
18805
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19531
18806
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_INT64);
19532
18807
  return ctx->kv[key_id].value.int64;
19533
18808
  }
19534
18809
 
19535
18810
  double wsp_gguf_get_val_f64(const struct wsp_gguf_context * ctx, int key_id) {
18811
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19536
18812
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_FLOAT64);
19537
18813
  return ctx->kv[key_id].value.float64;
19538
18814
  }
19539
18815
 
19540
18816
  bool wsp_gguf_get_val_bool(const struct wsp_gguf_context * ctx, int key_id) {
18817
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19541
18818
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_BOOL);
19542
18819
  return ctx->kv[key_id].value.bool_;
19543
18820
  }
19544
18821
 
19545
18822
  const char * wsp_gguf_get_val_str(const struct wsp_gguf_context * ctx, int key_id) {
18823
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19546
18824
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_STRING);
19547
18825
  return ctx->kv[key_id].value.str.data;
19548
18826
  }
19549
18827
 
18828
+ const void * wsp_gguf_get_val_data(const struct wsp_gguf_context * ctx, int key_id) {
18829
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
18830
+ WSP_GGML_ASSERT(ctx->kv[key_id].type != WSP_GGUF_TYPE_ARRAY);
18831
+ WSP_GGML_ASSERT(ctx->kv[key_id].type != WSP_GGUF_TYPE_STRING);
18832
+ return &ctx->kv[key_id].value;
18833
+ }
18834
+
19550
18835
  int wsp_gguf_get_n_tensors(const struct wsp_gguf_context * ctx) {
19551
18836
  return ctx->header.n_tensors;
19552
18837
  }