whisper.rn 0.4.0-rc.4 → 0.4.0-rc.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +5 -0
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
  5. package/android/src/main/java/com/rnwhisper/WhisperContext.java +57 -134
  6. package/android/src/main/jni-utils.h +76 -0
  7. package/android/src/main/jni.cpp +188 -112
  8. package/cpp/README.md +1 -1
  9. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  10. package/cpp/coreml/whisper-encoder.h +4 -0
  11. package/cpp/coreml/whisper-encoder.mm +4 -2
  12. package/cpp/ggml-alloc.c +55 -19
  13. package/cpp/ggml-alloc.h +8 -1
  14. package/cpp/ggml-backend-impl.h +46 -21
  15. package/cpp/ggml-backend.c +563 -156
  16. package/cpp/ggml-backend.h +62 -17
  17. package/cpp/ggml-impl.h +1 -1
  18. package/cpp/ggml-metal-whisper.metal +2444 -359
  19. package/cpp/ggml-metal.h +7 -1
  20. package/cpp/ggml-metal.m +1105 -197
  21. package/cpp/ggml-quants.c +66 -61
  22. package/cpp/ggml-quants.h +40 -40
  23. package/cpp/ggml.c +1040 -1590
  24. package/cpp/ggml.h +109 -30
  25. package/cpp/rn-audioutils.cpp +68 -0
  26. package/cpp/rn-audioutils.h +14 -0
  27. package/cpp/rn-whisper-log.h +11 -0
  28. package/cpp/rn-whisper.cpp +143 -59
  29. package/cpp/rn-whisper.h +48 -15
  30. package/cpp/whisper.cpp +1635 -928
  31. package/cpp/whisper.h +55 -10
  32. package/ios/RNWhisper.mm +7 -7
  33. package/ios/RNWhisperAudioUtils.h +0 -2
  34. package/ios/RNWhisperAudioUtils.m +0 -56
  35. package/ios/RNWhisperContext.h +3 -11
  36. package/ios/RNWhisperContext.mm +68 -137
  37. package/lib/commonjs/index.js.map +1 -1
  38. package/lib/commonjs/version.json +1 -1
  39. package/lib/module/index.js.map +1 -1
  40. package/lib/module/version.json +1 -1
  41. package/lib/typescript/index.d.ts +5 -0
  42. package/lib/typescript/index.d.ts.map +1 -1
  43. package/package.json +6 -5
  44. package/src/index.ts +5 -0
  45. package/src/version.json +1 -1
  46. package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +0 -4
  47. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +0 -8
  48. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
  49. package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +0 -19
package/cpp/ggml.c CHANGED
@@ -1,4 +1,4 @@
1
- #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
1
+ #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
2
2
  #define _USE_MATH_DEFINES // For M_PI on MSVC
3
3
 
4
4
  #include "ggml-impl.h"
@@ -33,7 +33,7 @@
33
33
  // we should just be careful :)
34
34
  #pragma warning(disable: 4244 4267)
35
35
 
36
- // disable POSIX deprecation warnigns
36
+ // disable POSIX deprecation warnings
37
37
  // these functions are never going away, anyway
38
38
  #pragma warning(disable: 4996)
39
39
  #endif
@@ -233,24 +233,6 @@ inline static void * wsp_ggml_aligned_malloc(size_t size) {
233
233
  #define UNUSED WSP_GGML_UNUSED
234
234
  #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
235
235
 
236
- //
237
- // tensor access macros
238
- //
239
-
240
- #define WSP_GGML_TENSOR_UNARY_OP_LOCALS \
241
- WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
242
- WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
243
- WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
244
- WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
245
-
246
- #define WSP_GGML_TENSOR_BINARY_OP_LOCALS \
247
- WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
248
- WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
249
- WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
250
- WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
251
- WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
252
- WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
253
-
254
236
  #if defined(WSP_GGML_USE_ACCELERATE)
255
237
  #include <Accelerate/Accelerate.h>
256
238
  #if defined(WSP_GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions
@@ -455,9 +437,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
455
437
  .blck_size = QK4_0,
456
438
  .type_size = sizeof(block_q4_0),
457
439
  .is_quantized = true,
458
- .to_float = (wsp_ggml_to_float_t) dequantize_row_q4_0,
459
- .from_float = quantize_row_q4_0,
460
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q4_0_reference,
440
+ .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q4_0,
441
+ .from_float = wsp_quantize_row_q4_0,
442
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q4_0_reference,
461
443
  .vec_dot = wsp_ggml_vec_dot_q4_0_q8_0,
462
444
  .vec_dot_type = WSP_GGML_TYPE_Q8_0,
463
445
  },
@@ -466,9 +448,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
466
448
  .blck_size = QK4_1,
467
449
  .type_size = sizeof(block_q4_1),
468
450
  .is_quantized = true,
469
- .to_float = (wsp_ggml_to_float_t) dequantize_row_q4_1,
470
- .from_float = quantize_row_q4_1,
471
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q4_1_reference,
451
+ .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q4_1,
452
+ .from_float = wsp_quantize_row_q4_1,
453
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q4_1_reference,
472
454
  .vec_dot = wsp_ggml_vec_dot_q4_1_q8_1,
473
455
  .vec_dot_type = WSP_GGML_TYPE_Q8_1,
474
456
  },
@@ -499,9 +481,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
499
481
  .blck_size = QK5_0,
500
482
  .type_size = sizeof(block_q5_0),
501
483
  .is_quantized = true,
502
- .to_float = (wsp_ggml_to_float_t) dequantize_row_q5_0,
503
- .from_float = quantize_row_q5_0,
504
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q5_0_reference,
484
+ .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q5_0,
485
+ .from_float = wsp_quantize_row_q5_0,
486
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q5_0_reference,
505
487
  .vec_dot = wsp_ggml_vec_dot_q5_0_q8_0,
506
488
  .vec_dot_type = WSP_GGML_TYPE_Q8_0,
507
489
  },
@@ -510,9 +492,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
510
492
  .blck_size = QK5_1,
511
493
  .type_size = sizeof(block_q5_1),
512
494
  .is_quantized = true,
513
- .to_float = (wsp_ggml_to_float_t) dequantize_row_q5_1,
514
- .from_float = quantize_row_q5_1,
515
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q5_1_reference,
495
+ .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q5_1,
496
+ .from_float = wsp_quantize_row_q5_1,
497
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q5_1_reference,
516
498
  .vec_dot = wsp_ggml_vec_dot_q5_1_q8_1,
517
499
  .vec_dot_type = WSP_GGML_TYPE_Q8_1,
518
500
  },
@@ -521,9 +503,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
521
503
  .blck_size = QK8_0,
522
504
  .type_size = sizeof(block_q8_0),
523
505
  .is_quantized = true,
524
- .to_float = (wsp_ggml_to_float_t) dequantize_row_q8_0,
525
- .from_float = quantize_row_q8_0,
526
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q8_0_reference,
506
+ .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q8_0,
507
+ .from_float = wsp_quantize_row_q8_0,
508
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q8_0_reference,
527
509
  .vec_dot = wsp_ggml_vec_dot_q8_0_q8_0,
528
510
  .vec_dot_type = WSP_GGML_TYPE_Q8_0,
529
511
  },
@@ -532,8 +514,8 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
532
514
  .blck_size = QK8_1,
533
515
  .type_size = sizeof(block_q8_1),
534
516
  .is_quantized = true,
535
- .from_float = quantize_row_q8_1,
536
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q8_1_reference,
517
+ .from_float = wsp_quantize_row_q8_1,
518
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q8_1_reference,
537
519
  .vec_dot_type = WSP_GGML_TYPE_Q8_1,
538
520
  },
539
521
  [WSP_GGML_TYPE_Q2_K] = {
@@ -541,9 +523,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
541
523
  .blck_size = QK_K,
542
524
  .type_size = sizeof(block_q2_K),
543
525
  .is_quantized = true,
544
- .to_float = (wsp_ggml_to_float_t) dequantize_row_q2_K,
545
- .from_float = quantize_row_q2_K,
546
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q2_K_reference,
526
+ .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q2_K,
527
+ .from_float = wsp_quantize_row_q2_K,
528
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q2_K_reference,
547
529
  .vec_dot = wsp_ggml_vec_dot_q2_K_q8_K,
548
530
  .vec_dot_type = WSP_GGML_TYPE_Q8_K,
549
531
  },
@@ -552,9 +534,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
552
534
  .blck_size = QK_K,
553
535
  .type_size = sizeof(block_q3_K),
554
536
  .is_quantized = true,
555
- .to_float = (wsp_ggml_to_float_t) dequantize_row_q3_K,
556
- .from_float = quantize_row_q3_K,
557
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q3_K_reference,
537
+ .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q3_K,
538
+ .from_float = wsp_quantize_row_q3_K,
539
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q3_K_reference,
558
540
  .vec_dot = wsp_ggml_vec_dot_q3_K_q8_K,
559
541
  .vec_dot_type = WSP_GGML_TYPE_Q8_K,
560
542
  },
@@ -563,9 +545,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
563
545
  .blck_size = QK_K,
564
546
  .type_size = sizeof(block_q4_K),
565
547
  .is_quantized = true,
566
- .to_float = (wsp_ggml_to_float_t) dequantize_row_q4_K,
567
- .from_float = quantize_row_q4_K,
568
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q4_K_reference,
548
+ .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q4_K,
549
+ .from_float = wsp_quantize_row_q4_K,
550
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q4_K_reference,
569
551
  .vec_dot = wsp_ggml_vec_dot_q4_K_q8_K,
570
552
  .vec_dot_type = WSP_GGML_TYPE_Q8_K,
571
553
  },
@@ -574,9 +556,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
574
556
  .blck_size = QK_K,
575
557
  .type_size = sizeof(block_q5_K),
576
558
  .is_quantized = true,
577
- .to_float = (wsp_ggml_to_float_t) dequantize_row_q5_K,
578
- .from_float = quantize_row_q5_K,
579
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q5_K_reference,
559
+ .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q5_K,
560
+ .from_float = wsp_quantize_row_q5_K,
561
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q5_K_reference,
580
562
  .vec_dot = wsp_ggml_vec_dot_q5_K_q8_K,
581
563
  .vec_dot_type = WSP_GGML_TYPE_Q8_K,
582
564
  },
@@ -585,9 +567,9 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
585
567
  .blck_size = QK_K,
586
568
  .type_size = sizeof(block_q6_K),
587
569
  .is_quantized = true,
588
- .to_float = (wsp_ggml_to_float_t) dequantize_row_q6_K,
589
- .from_float = quantize_row_q6_K,
590
- .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q6_K_reference,
570
+ .to_float = (wsp_ggml_to_float_t) wsp_dewsp_quantize_row_q6_K,
571
+ .from_float = wsp_quantize_row_q6_K,
572
+ .from_float_reference = (wsp_ggml_from_float_t) wsp_quantize_row_q6_K_reference,
591
573
  .vec_dot = wsp_ggml_vec_dot_q6_K_q8_K,
592
574
  .vec_dot_type = WSP_GGML_TYPE_Q8_K,
593
575
  },
@@ -596,7 +578,7 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = {
596
578
  .blck_size = QK_K,
597
579
  .type_size = sizeof(block_q8_K),
598
580
  .is_quantized = true,
599
- .from_float = quantize_row_q8_K,
581
+ .from_float = wsp_quantize_row_q8_K,
600
582
  }
601
583
  };
602
584
 
@@ -1413,7 +1395,7 @@ inline static void wsp_ggml_vec_step_f32 (const int n, float * y, const float *
1413
1395
  inline static void wsp_ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
1414
1396
  inline static void wsp_ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
1415
1397
  inline static void wsp_ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
1416
- inline static void wsp_ggml_vec_leaky_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.1f*x[i]; }
1398
+ inline static void wsp_ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
1417
1399
 
1418
1400
  static const float GELU_COEF_A = 0.044715f;
1419
1401
  static const float GELU_QUICK_COEF = -1.702f;
@@ -1613,6 +1595,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
1613
1595
  "GROUP_NORM",
1614
1596
 
1615
1597
  "MUL_MAT",
1598
+ "MUL_MAT_ID",
1616
1599
  "OUT_PROD",
1617
1600
 
1618
1601
  "SCALE",
@@ -1634,17 +1617,15 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
1634
1617
  "ROPE_BACK",
1635
1618
  "ALIBI",
1636
1619
  "CLAMP",
1637
- "CONV_1D",
1638
- "CONV_1D_STAGE_0",
1639
- "CONV_1D_STAGE_1",
1640
1620
  "CONV_TRANSPOSE_1D",
1641
- "CONV_2D",
1642
- "CONV_2D_STAGE_0",
1643
- "CONV_2D_STAGE_1",
1621
+ "IM2COL",
1644
1622
  "CONV_TRANSPOSE_2D",
1645
1623
  "POOL_1D",
1646
1624
  "POOL_2D",
1647
1625
  "UPSCALE",
1626
+ "PAD",
1627
+ "ARGSORT",
1628
+ "LEAKY_RELU",
1648
1629
 
1649
1630
  "FLASH_ATTN",
1650
1631
  "FLASH_FF",
@@ -1671,7 +1652,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
1671
1652
  "CROSS_ENTROPY_LOSS_BACK",
1672
1653
  };
1673
1654
 
1674
- static_assert(WSP_GGML_OP_COUNT == 73, "WSP_GGML_OP_COUNT != 73");
1655
+ static_assert(WSP_GGML_OP_COUNT == 72, "WSP_GGML_OP_COUNT != 72");
1675
1656
 
1676
1657
  static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
1677
1658
  "none",
@@ -1700,6 +1681,7 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
1700
1681
  "group_norm(x)",
1701
1682
 
1702
1683
  "X*Y",
1684
+ "X[i]*Y",
1703
1685
  "X*Y",
1704
1686
 
1705
1687
  "x*v",
@@ -1721,17 +1703,15 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
1721
1703
  "rope_back(x)",
1722
1704
  "alibi(x)",
1723
1705
  "clamp(x)",
1724
- "conv_1d(x)",
1725
- "conv_1d_stage_0(x)",
1726
- "conv_1d_stage_1(x)",
1727
1706
  "conv_transpose_1d(x)",
1728
- "conv_2d(x)",
1729
- "conv_2d_stage_0(x)",
1730
- "conv_2d_stage_1(x)",
1707
+ "im2col(x)",
1731
1708
  "conv_transpose_2d(x)",
1732
1709
  "pool_1d(x)",
1733
1710
  "pool_2d(x)",
1734
1711
  "upscale(x)",
1712
+ "pad(x)",
1713
+ "argsort(x)",
1714
+ "leaky_relu(x)",
1735
1715
 
1736
1716
  "flash_attn(x)",
1737
1717
  "flash_ff(x)",
@@ -1758,15 +1738,32 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
1758
1738
  "cross_entropy_loss_back(x,y)",
1759
1739
  };
1760
1740
 
1761
- static_assert(WSP_GGML_OP_COUNT == 73, "WSP_GGML_OP_COUNT != 73");
1741
+ static_assert(WSP_GGML_OP_COUNT == 72, "WSP_GGML_OP_COUNT != 72");
1762
1742
 
1763
1743
  static_assert(WSP_GGML_OP_POOL_COUNT == 2, "WSP_GGML_OP_POOL_COUNT != 2");
1764
1744
 
1745
+
1746
+ static const char * WSP_GGML_UNARY_OP_NAME[WSP_GGML_UNARY_OP_COUNT] = {
1747
+ "ABS",
1748
+ "SGN",
1749
+ "NEG",
1750
+ "STEP",
1751
+ "TANH",
1752
+ "ELU",
1753
+ "RELU",
1754
+ "GELU",
1755
+ "GELU_QUICK",
1756
+ "SILU",
1757
+ };
1758
+
1759
+ static_assert(WSP_GGML_UNARY_OP_COUNT == 10, "WSP_GGML_UNARY_OP_COUNT != 10");
1760
+
1761
+
1765
1762
  static_assert(sizeof(struct wsp_ggml_object)%WSP_GGML_MEM_ALIGN == 0, "wsp_ggml_object size must be a multiple of WSP_GGML_MEM_ALIGN");
1766
1763
  static_assert(sizeof(struct wsp_ggml_tensor)%WSP_GGML_MEM_ALIGN == 0, "wsp_ggml_tensor size must be a multiple of WSP_GGML_MEM_ALIGN");
1767
1764
 
1768
1765
  // WARN:
1769
- // Mis-confguration can lead to problem that's hard to reason about:
1766
+ // Mis-configuration can lead to problem that's hard to reason about:
1770
1767
  // * At best it crash or talks nosense.
1771
1768
  // * At worst it talks slightly difference but hard to perceive.
1772
1769
  //
@@ -1781,18 +1778,13 @@ static void wsp_ggml_setup_op_has_task_pass(void) {
1781
1778
 
1782
1779
  p[WSP_GGML_OP_ACC ] = true;
1783
1780
  p[WSP_GGML_OP_MUL_MAT ] = true;
1781
+ p[WSP_GGML_OP_MUL_MAT_ID ] = true;
1784
1782
  p[WSP_GGML_OP_OUT_PROD ] = true;
1785
1783
  p[WSP_GGML_OP_SET ] = true;
1786
1784
  p[WSP_GGML_OP_GET_ROWS_BACK ] = true;
1787
1785
  p[WSP_GGML_OP_DIAG_MASK_INF ] = true;
1788
1786
  p[WSP_GGML_OP_DIAG_MASK_ZERO ] = true;
1789
- p[WSP_GGML_OP_CONV_1D ] = true;
1790
- p[WSP_GGML_OP_CONV_1D_STAGE_0 ] = true;
1791
- p[WSP_GGML_OP_CONV_1D_STAGE_1 ] = true;
1792
1787
  p[WSP_GGML_OP_CONV_TRANSPOSE_1D ] = true;
1793
- p[WSP_GGML_OP_CONV_2D ] = true;
1794
- p[WSP_GGML_OP_CONV_2D_STAGE_0 ] = true;
1795
- p[WSP_GGML_OP_CONV_2D_STAGE_1 ] = true;
1796
1788
  p[WSP_GGML_OP_CONV_TRANSPOSE_2D ] = true;
1797
1789
  p[WSP_GGML_OP_FLASH_ATTN_BACK ] = true;
1798
1790
  p[WSP_GGML_OP_CROSS_ENTROPY_LOSS ] = true;
@@ -2039,6 +2031,20 @@ const char * wsp_ggml_op_symbol(enum wsp_ggml_op op) {
2039
2031
  return WSP_GGML_OP_SYMBOL[op];
2040
2032
  }
2041
2033
 
2034
+ const char * wsp_ggml_unary_op_name(enum wsp_ggml_unary_op op) {
2035
+ return WSP_GGML_UNARY_OP_NAME[op];
2036
+ }
2037
+
2038
+ const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t) {
2039
+ if (t->op == WSP_GGML_OP_UNARY) {
2040
+ enum wsp_ggml_unary_op uop = wsp_ggml_get_unary_op(t);
2041
+ return wsp_ggml_unary_op_name(uop);
2042
+ }
2043
+ else {
2044
+ return wsp_ggml_op_name(t->op);
2045
+ }
2046
+ }
2047
+
2042
2048
  size_t wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor) {
2043
2049
  return wsp_ggml_type_size(tensor->type);
2044
2050
  }
@@ -3170,9 +3176,7 @@ static struct wsp_ggml_tensor * wsp_ggml_add_impl(
3170
3176
  struct wsp_ggml_tensor * a,
3171
3177
  struct wsp_ggml_tensor * b,
3172
3178
  bool inplace) {
3173
- // TODO: support less-strict constraint
3174
- // WSP_GGML_ASSERT(wsp_ggml_can_repeat(b, a));
3175
- WSP_GGML_ASSERT(wsp_ggml_can_repeat_rows(b, a));
3179
+ WSP_GGML_ASSERT(wsp_ggml_can_repeat(b, a));
3176
3180
 
3177
3181
  bool is_node = false;
3178
3182
 
@@ -3387,9 +3391,7 @@ static struct wsp_ggml_tensor * wsp_ggml_mul_impl(
3387
3391
  struct wsp_ggml_tensor * a,
3388
3392
  struct wsp_ggml_tensor * b,
3389
3393
  bool inplace) {
3390
- // TODO: support less-strict constraint
3391
- // WSP_GGML_ASSERT(wsp_ggml_can_repeat(b, a));
3392
- WSP_GGML_ASSERT(wsp_ggml_can_repeat_rows(b, a));
3394
+ WSP_GGML_ASSERT(wsp_ggml_can_repeat(b, a));
3393
3395
 
3394
3396
  bool is_node = false;
3395
3397
 
@@ -3434,7 +3436,7 @@ static struct wsp_ggml_tensor * wsp_ggml_div_impl(
3434
3436
  struct wsp_ggml_tensor * a,
3435
3437
  struct wsp_ggml_tensor * b,
3436
3438
  bool inplace) {
3437
- WSP_GGML_ASSERT(wsp_ggml_are_same_shape(a, b));
3439
+ WSP_GGML_ASSERT(wsp_ggml_can_repeat(b, a));
3438
3440
 
3439
3441
  bool is_node = false;
3440
3442
 
@@ -3831,12 +3833,25 @@ struct wsp_ggml_tensor * wsp_ggml_relu_inplace(
3831
3833
  return wsp_ggml_unary_inplace(ctx, a, WSP_GGML_UNARY_OP_RELU);
3832
3834
  }
3833
3835
 
3834
- // wsp_ggml_leaky
3836
+ // wsp_ggml_leaky_relu
3835
3837
 
3836
- struct wsp_ggml_tensor * wsp_ggml_leaky(
3838
+ struct wsp_ggml_tensor * wsp_ggml_leaky_relu(
3837
3839
  struct wsp_ggml_context * ctx,
3838
- struct wsp_ggml_tensor * a) {
3839
- return wsp_ggml_unary(ctx, a, WSP_GGML_UNARY_OP_LEAKY);
3840
+ struct wsp_ggml_tensor * a, float negative_slope, bool inplace) {
3841
+ bool is_node = false;
3842
+
3843
+ if (!inplace && (a->grad)) {
3844
+ is_node = true;
3845
+ }
3846
+
3847
+ struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a);
3848
+ wsp_ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
3849
+
3850
+ result->op = WSP_GGML_OP_LEAKY_RELU;
3851
+ result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
3852
+ result->src[0] = a;
3853
+
3854
+ return result;
3840
3855
  }
3841
3856
 
3842
3857
  // wsp_ggml_gelu
@@ -4023,8 +4038,9 @@ static struct wsp_ggml_tensor * wsp_ggml_group_norm_impl(
4023
4038
 
4024
4039
  struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a);
4025
4040
 
4026
- result->op = WSP_GGML_OP_GROUP_NORM;
4027
4041
  result->op_params[0] = n_groups;
4042
+
4043
+ result->op = WSP_GGML_OP_GROUP_NORM;
4028
4044
  result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
4029
4045
  result->src[0] = a;
4030
4046
  result->src[1] = NULL; // TODO: maybe store epsilon here?
@@ -4072,6 +4088,51 @@ struct wsp_ggml_tensor * wsp_ggml_mul_mat(
4072
4088
  return result;
4073
4089
  }
4074
4090
 
4091
+ // wsp_ggml_mul_mat_id
4092
+
4093
+ struct wsp_ggml_tensor * wsp_ggml_mul_mat_id(
4094
+ struct wsp_ggml_context * ctx,
4095
+ struct wsp_ggml_tensor * const as[],
4096
+ int n_as,
4097
+ struct wsp_ggml_tensor * ids,
4098
+ int id,
4099
+ struct wsp_ggml_tensor * b) {
4100
+
4101
+ WSP_GGML_ASSERT(ids->type == WSP_GGML_TYPE_I32);
4102
+ WSP_GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
4103
+ WSP_GGML_ASSERT(ids->ne[1] == b->ne[1]);
4104
+ WSP_GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
4105
+ WSP_GGML_ASSERT(n_as > 0 && n_as <= WSP_GGML_MAX_SRC - 2);
4106
+ WSP_GGML_ASSERT(id >= 0 && id < ids->ne[0]);
4107
+
4108
+ bool is_node = false;
4109
+
4110
+ if (as[0]->grad || b->grad) {
4111
+ is_node = true;
4112
+ }
4113
+
4114
+ const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] };
4115
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);
4116
+
4117
+ wsp_ggml_set_op_params_i32(result, 0, id);
4118
+ wsp_ggml_set_op_params_i32(result, 1, n_as);
4119
+
4120
+ result->op = WSP_GGML_OP_MUL_MAT_ID;
4121
+ result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
4122
+ result->src[0] = ids;
4123
+ result->src[1] = b;
4124
+
4125
+ for (int i = 0; i < n_as; i++) {
4126
+ struct wsp_ggml_tensor * a = as[i];
4127
+ WSP_GGML_ASSERT(wsp_ggml_are_same_shape(as[0], a));
4128
+ WSP_GGML_ASSERT(wsp_ggml_can_mul_mat(a, b));
4129
+ WSP_GGML_ASSERT(!wsp_ggml_is_transposed(a));
4130
+ result->src[i + 2] = a;
4131
+ }
4132
+
4133
+ return result;
4134
+ }
4135
+
4075
4136
  // wsp_ggml_out_prod
4076
4137
 
4077
4138
  struct wsp_ggml_tensor * wsp_ggml_out_prod(
@@ -4225,7 +4286,7 @@ struct wsp_ggml_tensor * wsp_ggml_set_2d_inplace(
4225
4286
  struct wsp_ggml_tensor * b,
4226
4287
  size_t nb1,
4227
4288
  size_t offset) {
4228
- return wsp_ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false);
4289
+ return wsp_ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true);
4229
4290
  }
4230
4291
 
4231
4292
  // wsp_ggml_cpy
@@ -4689,7 +4750,9 @@ struct wsp_ggml_tensor * wsp_ggml_get_rows(
4689
4750
  struct wsp_ggml_context * ctx,
4690
4751
  struct wsp_ggml_tensor * a,
4691
4752
  struct wsp_ggml_tensor * b) {
4692
- WSP_GGML_ASSERT(wsp_ggml_is_matrix(a) && wsp_ggml_is_vector(b) && b->type == WSP_GGML_TYPE_I32);
4753
+ WSP_GGML_ASSERT(a->ne[2] == b->ne[1]);
4754
+ WSP_GGML_ASSERT(b->ne[3] == 1);
4755
+ WSP_GGML_ASSERT(b->type == WSP_GGML_TYPE_I32);
4693
4756
 
4694
4757
  bool is_node = false;
4695
4758
 
@@ -4699,7 +4762,7 @@ struct wsp_ggml_tensor * wsp_ggml_get_rows(
4699
4762
 
4700
4763
  // TODO: implement non F32 return
4701
4764
  //struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
4702
- struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, a->ne[0], b->ne[0]);
4765
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_4d(ctx, WSP_GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
4703
4766
 
4704
4767
  result->op = WSP_GGML_OP_GET_ROWS;
4705
4768
  result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
@@ -4842,7 +4905,17 @@ struct wsp_ggml_tensor * wsp_ggml_diag_mask_zero_inplace(
4842
4905
  static struct wsp_ggml_tensor * wsp_ggml_soft_max_impl(
4843
4906
  struct wsp_ggml_context * ctx,
4844
4907
  struct wsp_ggml_tensor * a,
4908
+ struct wsp_ggml_tensor * mask,
4909
+ float scale,
4845
4910
  bool inplace) {
4911
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(a));
4912
+ if (mask) {
4913
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(mask));
4914
+ WSP_GGML_ASSERT(mask->ne[2] == 1);
4915
+ WSP_GGML_ASSERT(mask->ne[3] == 1);
4916
+ WSP_GGML_ASSERT(wsp_ggml_can_repeat_rows(mask, a));
4917
+ }
4918
+
4846
4919
  bool is_node = false;
4847
4920
 
4848
4921
  if (a->grad) {
@@ -4851,9 +4924,13 @@ static struct wsp_ggml_tensor * wsp_ggml_soft_max_impl(
4851
4924
 
4852
4925
  struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a);
4853
4926
 
4927
+ float params[] = { scale };
4928
+ wsp_ggml_set_op_params(result, params, sizeof(params));
4929
+
4854
4930
  result->op = WSP_GGML_OP_SOFT_MAX;
4855
4931
  result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
4856
4932
  result->src[0] = a;
4933
+ result->src[1] = mask;
4857
4934
 
4858
4935
  return result;
4859
4936
  }
@@ -4861,13 +4938,21 @@ static struct wsp_ggml_tensor * wsp_ggml_soft_max_impl(
4861
4938
  struct wsp_ggml_tensor * wsp_ggml_soft_max(
4862
4939
  struct wsp_ggml_context * ctx,
4863
4940
  struct wsp_ggml_tensor * a) {
4864
- return wsp_ggml_soft_max_impl(ctx, a, false);
4941
+ return wsp_ggml_soft_max_impl(ctx, a, NULL, 1.0f, false);
4865
4942
  }
4866
4943
 
4867
4944
  struct wsp_ggml_tensor * wsp_ggml_soft_max_inplace(
4868
4945
  struct wsp_ggml_context * ctx,
4869
4946
  struct wsp_ggml_tensor * a) {
4870
- return wsp_ggml_soft_max_impl(ctx, a, true);
4947
+ return wsp_ggml_soft_max_impl(ctx, a, NULL, 1.0f, true);
4948
+ }
4949
+
4950
+ struct wsp_ggml_tensor * wsp_ggml_soft_max_ext(
4951
+ struct wsp_ggml_context * ctx,
4952
+ struct wsp_ggml_tensor * a,
4953
+ struct wsp_ggml_tensor * mask,
4954
+ float scale) {
4955
+ return wsp_ggml_soft_max_impl(ctx, a, mask, scale, false);
4871
4956
  }
4872
4957
 
4873
4958
  // wsp_ggml_soft_max_back
@@ -5040,8 +5125,13 @@ struct wsp_ggml_tensor * wsp_ggml_rope_back(
5040
5125
  int n_dims,
5041
5126
  int mode,
5042
5127
  int n_ctx,
5128
+ int n_orig_ctx,
5043
5129
  float freq_base,
5044
5130
  float freq_scale,
5131
+ float ext_factor,
5132
+ float attn_factor,
5133
+ float beta_fast,
5134
+ float beta_slow,
5045
5135
  float xpos_base,
5046
5136
  bool xpos_down) {
5047
5137
  WSP_GGML_ASSERT(wsp_ggml_is_vector(b));
@@ -5058,11 +5148,15 @@ struct wsp_ggml_tensor * wsp_ggml_rope_back(
5058
5148
 
5059
5149
  struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a);
5060
5150
 
5061
- int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
5062
- memcpy(params + 4, &freq_base, sizeof(float));
5063
- memcpy(params + 5, &freq_scale, sizeof(float));
5064
- memcpy(params + 6, &xpos_base, sizeof(float));
5065
- memcpy(params + 7, &xpos_down, sizeof(bool));
5151
+ int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
5152
+ memcpy(params + 5, &freq_base, sizeof(float));
5153
+ memcpy(params + 6, &freq_scale, sizeof(float));
5154
+ memcpy(params + 7, &ext_factor, sizeof(float));
5155
+ memcpy(params + 8, &attn_factor, sizeof(float));
5156
+ memcpy(params + 9, &beta_fast, sizeof(float));
5157
+ memcpy(params + 10, &beta_slow, sizeof(float));
5158
+ memcpy(params + 11, &xpos_base, sizeof(float));
5159
+ memcpy(params + 12, &xpos_down, sizeof(bool));
5066
5160
  wsp_ggml_set_op_params(result, params, sizeof(params));
5067
5161
 
5068
5162
  result->op = WSP_GGML_OP_ROPE_BACK;
@@ -5137,82 +5231,6 @@ static int64_t wsp_ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, in
5137
5231
  return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
5138
5232
  }
5139
5233
 
5140
- // im2col: [N, IC, IL] => [N, OL, IC*K]
5141
- // a: [OC,IC, K]
5142
- // b: [N, IC, IL]
5143
- // result: [N, OL, IC*K]
5144
- static struct wsp_ggml_tensor * wsp_ggml_conv_1d_stage_0(
5145
- struct wsp_ggml_context * ctx,
5146
- struct wsp_ggml_tensor * a,
5147
- struct wsp_ggml_tensor * b,
5148
- int s0,
5149
- int p0,
5150
- int d0) {
5151
- WSP_GGML_ASSERT(a->ne[1] == b->ne[1]);
5152
- bool is_node = false;
5153
-
5154
- if (a->grad || b->grad) {
5155
- WSP_GGML_ASSERT(false); // TODO: implement backward
5156
- is_node = true;
5157
- }
5158
-
5159
- const int64_t OL = wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
5160
-
5161
- const int64_t ne[4] = {
5162
- a->ne[1] * a->ne[0],
5163
- OL,
5164
- b->ne[2],
5165
- 1,
5166
- };
5167
- struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F16, 4, ne);
5168
-
5169
- int32_t params[] = { s0, p0, d0 };
5170
- wsp_ggml_set_op_params(result, params, sizeof(params));
5171
-
5172
- result->op = WSP_GGML_OP_CONV_1D_STAGE_0;
5173
- result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
5174
- result->src[0] = a;
5175
- result->src[1] = b;
5176
-
5177
- return result;
5178
- }
5179
-
5180
- // wsp_ggml_conv_1d_stage_1
5181
-
5182
- // gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
5183
- // a: [OC, IC, K]
5184
- // b: [N, OL, IC * K]
5185
- // result: [N, OC, OL]
5186
- static struct wsp_ggml_tensor * wsp_ggml_conv_1d_stage_1(
5187
- struct wsp_ggml_context * ctx,
5188
- struct wsp_ggml_tensor * a,
5189
- struct wsp_ggml_tensor * b) {
5190
-
5191
- bool is_node = false;
5192
-
5193
- if (a->grad || b->grad) {
5194
- WSP_GGML_ASSERT(false); // TODO: implement backward
5195
- is_node = true;
5196
- }
5197
-
5198
- const int64_t ne[4] = {
5199
- b->ne[1],
5200
- a->ne[2],
5201
- b->ne[2],
5202
- 1,
5203
- };
5204
- struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne);
5205
-
5206
- result->op = WSP_GGML_OP_CONV_1D_STAGE_1;
5207
- result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
5208
- result->src[0] = a;
5209
- result->src[1] = b;
5210
-
5211
- return result;
5212
- }
5213
-
5214
- // wsp_ggml_conv_1d
5215
-
5216
5234
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d(
5217
5235
  struct wsp_ggml_context * ctx,
5218
5236
  struct wsp_ggml_tensor * a,
@@ -5220,43 +5238,17 @@ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d(
5220
5238
  int s0,
5221
5239
  int p0,
5222
5240
  int d0) {
5223
- struct wsp_ggml_tensor * result = wsp_ggml_conv_1d_stage_0(ctx, a, b, s0, p0, d0);
5224
- result = wsp_ggml_conv_1d_stage_1(ctx, a, result);
5225
- return result;
5226
- }
5227
-
5228
- // WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d(
5229
- // struct wsp_ggml_context * ctx,
5230
- // struct wsp_ggml_tensor * a,
5231
- // struct wsp_ggml_tensor * b,
5232
- // int s0,
5233
- // int p0,
5234
- // int d0) {
5235
- // WSP_GGML_ASSERT(wsp_ggml_is_matrix(b));
5236
- // WSP_GGML_ASSERT(a->ne[1] == b->ne[1]);
5237
- // bool is_node = false;
5238
-
5239
- // if (a->grad || b->grad) {
5240
- // WSP_GGML_ASSERT(false); // TODO: implement backward
5241
- // is_node = true;
5242
- // }
5241
+ struct wsp_ggml_tensor * im2col = wsp_ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K]
5243
5242
 
5244
- // const int64_t ne[4] = {
5245
- // wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
5246
- // a->ne[2], 1, 1,
5247
- // };
5248
- // struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 2, ne);
5243
+ struct wsp_ggml_tensor * result =
5244
+ wsp_ggml_mul_mat(ctx,
5245
+ wsp_ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
5246
+ wsp_ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])); // [OC,IC, K] => [OC, IC * K]
5249
5247
 
5250
- // int32_t params[] = { s0, p0, d0 };
5251
- // wsp_ggml_set_op_params(result, params, sizeof(params));
5248
+ result = wsp_ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
5252
5249
 
5253
- // result->op = WSP_GGML_OP_CONV_1D;
5254
- // result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
5255
- // result->src[0] = a;
5256
- // result->src[1] = b;
5257
-
5258
- // return result;
5259
- // }
5250
+ return result;
5251
+ }
5260
5252
 
5261
5253
  // wsp_ggml_conv_1d_ph
5262
5254
 
@@ -5319,7 +5311,7 @@ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_transpose_1d(
5319
5311
  // a: [OC,IC, KH, KW]
5320
5312
  // b: [N, IC, IH, IW]
5321
5313
  // result: [N, OH, OW, IC*KH*KW]
5322
- static struct wsp_ggml_tensor * wsp_ggml_conv_2d_stage_0(
5314
+ struct wsp_ggml_tensor * wsp_ggml_im2col(
5323
5315
  struct wsp_ggml_context * ctx,
5324
5316
  struct wsp_ggml_tensor * a,
5325
5317
  struct wsp_ggml_tensor * b,
@@ -5328,9 +5320,14 @@ static struct wsp_ggml_tensor * wsp_ggml_conv_2d_stage_0(
5328
5320
  int p0,
5329
5321
  int p1,
5330
5322
  int d0,
5331
- int d1) {
5323
+ int d1,
5324
+ bool is_2D) {
5332
5325
 
5333
- WSP_GGML_ASSERT(a->ne[2] == b->ne[2]);
5326
+ if(is_2D) {
5327
+ WSP_GGML_ASSERT(a->ne[2] == b->ne[2]);
5328
+ } else {
5329
+ WSP_GGML_ASSERT(a->ne[1] == b->ne[1]);
5330
+ }
5334
5331
  bool is_node = false;
5335
5332
 
5336
5333
  if (a->grad || b->grad) {
@@ -5338,81 +5335,51 @@ static struct wsp_ggml_tensor * wsp_ggml_conv_2d_stage_0(
5338
5335
  is_node = true;
5339
5336
  }
5340
5337
 
5341
- const int64_t OH = wsp_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
5342
- const int64_t OW = wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
5338
+ const int64_t OH = is_2D ? wsp_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
5339
+ const int64_t OW = wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
5343
5340
 
5344
5341
  const int64_t ne[4] = {
5345
- a->ne[2] * a->ne[1] * a->ne[0],
5342
+ is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
5346
5343
  OW,
5347
- OH,
5348
- b->ne[3],
5344
+ is_2D ? OH : b->ne[2],
5345
+ is_2D ? b->ne[3] : 1,
5349
5346
  };
5350
- struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F16, 4, ne);
5351
5347
 
5352
- int32_t params[] = { s0, s1, p0, p1, d0, d1 };
5348
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F16, 4, ne);
5349
+ int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
5353
5350
  wsp_ggml_set_op_params(result, params, sizeof(params));
5354
5351
 
5355
- result->op = WSP_GGML_OP_CONV_2D_STAGE_0;
5356
- result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
5357
- result->src[0] = a;
5358
- result->src[1] = b;
5359
-
5360
- return result;
5361
-
5362
- }
5363
-
5364
- // gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
5365
- // a: [OC, IC, KH, KW]
5366
- // b: [N, OH, OW, IC * KH * KW]
5367
- // result: [N, OC, OH, OW]
5368
- static struct wsp_ggml_tensor * wsp_ggml_conv_2d_stage_1(
5369
- struct wsp_ggml_context * ctx,
5370
- struct wsp_ggml_tensor * a,
5371
- struct wsp_ggml_tensor * b) {
5372
-
5373
- bool is_node = false;
5374
-
5375
- if (a->grad || b->grad) {
5376
- WSP_GGML_ASSERT(false); // TODO: implement backward
5377
- is_node = true;
5378
- }
5379
-
5380
- const int64_t ne[4] = {
5381
- b->ne[1],
5382
- b->ne[2],
5383
- a->ne[3],
5384
- b->ne[3],
5385
- };
5386
- struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne);
5387
-
5388
- result->op = WSP_GGML_OP_CONV_2D_STAGE_1;
5352
+ result->op = WSP_GGML_OP_IM2COL;
5389
5353
  result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
5390
5354
  result->src[0] = a;
5391
5355
  result->src[1] = b;
5392
5356
 
5393
5357
  return result;
5394
-
5395
5358
  }
5396
5359
 
5397
5360
  // a: [OC,IC, KH, KW]
5398
5361
  // b: [N, IC, IH, IW]
5399
5362
  // result: [N, OC, OH, OW]
5400
5363
  struct wsp_ggml_tensor * wsp_ggml_conv_2d(
5401
- struct wsp_ggml_context * ctx,
5402
- struct wsp_ggml_tensor * a,
5403
- struct wsp_ggml_tensor * b,
5404
- int s0,
5405
- int s1,
5406
- int p0,
5407
- int p1,
5408
- int d0,
5409
- int d1) {
5364
+ struct wsp_ggml_context * ctx,
5365
+ struct wsp_ggml_tensor * a,
5366
+ struct wsp_ggml_tensor * b,
5367
+ int s0,
5368
+ int s1,
5369
+ int p0,
5370
+ int p1,
5371
+ int d0,
5372
+ int d1) {
5373
+ struct wsp_ggml_tensor * im2col = wsp_ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]
5410
5374
 
5411
- struct wsp_ggml_tensor * result = wsp_ggml_conv_2d_stage_0(ctx, a, b, s0, s1, p0, p1, d0, d1); // [N, OH, OW, IC * KH * KW]
5412
- result = wsp_ggml_conv_2d_stage_1(ctx, a, result);
5375
+ struct wsp_ggml_tensor * result =
5376
+ wsp_ggml_mul_mat(ctx,
5377
+ wsp_ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
5378
+ wsp_ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OC,IC, KH, KW] => [OC, IC * KH * KW]
5413
5379
 
5414
- return result;
5380
+ result = wsp_ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[3], im2col->ne[3]); // [N, OC, OH, OW]
5415
5381
 
5382
+ return result;
5416
5383
  }
5417
5384
 
5418
5385
  // wsp_ggml_conv_2d_sk_p0
@@ -5573,6 +5540,30 @@ static struct wsp_ggml_tensor * wsp_ggml_upscale_impl(
5573
5540
  return result;
5574
5541
  }
5575
5542
 
5543
+ struct wsp_ggml_tensor * wsp_ggml_pad(
5544
+ struct wsp_ggml_context * ctx,
5545
+ struct wsp_ggml_tensor * a,
5546
+ int p0, int p1, int p2, int p3) {
5547
+ bool is_node = false;
5548
+
5549
+ if (a->grad) {
5550
+ WSP_GGML_ASSERT(false); // TODO: implement backward
5551
+ is_node = true;
5552
+ }
5553
+
5554
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_4d(ctx, a->type,
5555
+ a->ne[0] + p0,
5556
+ a->ne[1] + p1,
5557
+ a->ne[2] + p2,
5558
+ a->ne[3] + p3);
5559
+
5560
+ result->op = WSP_GGML_OP_PAD;
5561
+ result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
5562
+ result->src[0] = a;
5563
+
5564
+ return result;
5565
+ }
5566
+
5576
5567
  struct wsp_ggml_tensor * wsp_ggml_upscale(
5577
5568
  struct wsp_ggml_context * ctx,
5578
5569
  struct wsp_ggml_tensor * a,
@@ -5580,6 +5571,43 @@ struct wsp_ggml_tensor * wsp_ggml_upscale(
5580
5571
  return wsp_ggml_upscale_impl(ctx, a, scale_factor);
5581
5572
  }
5582
5573
 
5574
+ // wsp_ggml_argsort
5575
+
5576
+ struct wsp_ggml_tensor * wsp_ggml_argsort(
5577
+ struct wsp_ggml_context * ctx,
5578
+ struct wsp_ggml_tensor * a,
5579
+ enum wsp_ggml_sort_order order) {
5580
+ bool is_node = false;
5581
+
5582
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_I32, a->n_dims, a->ne);
5583
+
5584
+ wsp_ggml_set_op_params_i32(result, 0, (int32_t) order);
5585
+
5586
+ result->op = WSP_GGML_OP_ARGSORT;
5587
+ result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
5588
+ result->src[0] = a;
5589
+
5590
+ return result;
5591
+ }
5592
+
5593
+ // wsp_ggml_top_k
5594
+
5595
+ struct wsp_ggml_tensor * wsp_ggml_top_k(
5596
+ struct wsp_ggml_context * ctx,
5597
+ struct wsp_ggml_tensor * a,
5598
+ int k) {
5599
+ WSP_GGML_ASSERT(a->ne[0] >= k);
5600
+
5601
+ struct wsp_ggml_tensor * result = wsp_ggml_argsort(ctx, a, WSP_GGML_SORT_DESC);
5602
+
5603
+ result = wsp_ggml_view_4d(ctx, result,
5604
+ k, result->ne[1], result->ne[2], result->ne[3],
5605
+ result->nb[1], result->nb[2], result->nb[3],
5606
+ 0);
5607
+
5608
+ return result;
5609
+ }
5610
+
5583
5611
  // wsp_ggml_flash_attn
5584
5612
 
5585
5613
  struct wsp_ggml_tensor * wsp_ggml_flash_attn(
@@ -6472,7 +6500,7 @@ static void wsp_ggml_compute_forward_dup_f16(
6472
6500
  }
6473
6501
  }
6474
6502
  } else if (type_traits[dst->type].from_float) {
6475
- wsp_ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
6503
+ wsp_ggml_from_float_t const wsp_quantize_row_q = type_traits[dst->type].from_float;
6476
6504
  float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
6477
6505
 
6478
6506
  size_t id = 0;
@@ -6489,7 +6517,7 @@ static void wsp_ggml_compute_forward_dup_f16(
6489
6517
  src0_f32[i00] = WSP_GGML_FP16_TO_FP32(src0_ptr[i00]);
6490
6518
  }
6491
6519
 
6492
- quantize_row_q(src0_f32, dst_ptr + id, ne00);
6520
+ wsp_quantize_row_q(src0_f32, dst_ptr + id, ne00);
6493
6521
  id += rs;
6494
6522
  }
6495
6523
  id += rs * (ne01 - ir1);
@@ -6725,7 +6753,7 @@ static void wsp_ggml_compute_forward_dup_f32(
6725
6753
  }
6726
6754
  }
6727
6755
  } else if (type_traits[dst->type].from_float) {
6728
- wsp_ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
6756
+ wsp_ggml_from_float_t const wsp_quantize_row_q = type_traits[dst->type].from_float;
6729
6757
 
6730
6758
  size_t id = 0;
6731
6759
  size_t rs = nb0 * (ne00 / wsp_ggml_blck_size(dst->type));
@@ -6736,7 +6764,7 @@ static void wsp_ggml_compute_forward_dup_f32(
6736
6764
  id += rs * ir0;
6737
6765
  for (int i01 = ir0; i01 < ir1; i01++) {
6738
6766
  const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
6739
- quantize_row_q(src0_ptr, dst_ptr + id, ne00);
6767
+ wsp_quantize_row_q(src0_ptr, dst_ptr + id, ne00);
6740
6768
  id += rs;
6741
6769
  }
6742
6770
  id += rs * (ne01 - ir1);
@@ -6939,7 +6967,7 @@ static void wsp_ggml_compute_forward_add_f32(
6939
6967
  const struct wsp_ggml_tensor * src0,
6940
6968
  const struct wsp_ggml_tensor * src1,
6941
6969
  struct wsp_ggml_tensor * dst) {
6942
- WSP_GGML_ASSERT(wsp_ggml_can_repeat_rows(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
6970
+ WSP_GGML_ASSERT(wsp_ggml_can_repeat(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
6943
6971
 
6944
6972
  if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
6945
6973
  return;
@@ -6972,16 +7000,19 @@ static void wsp_ggml_compute_forward_add_f32(
6972
7000
  const int64_t i13 = i03 % ne13;
6973
7001
  const int64_t i12 = i02 % ne12;
6974
7002
  const int64_t i11 = i01 % ne11;
7003
+ const int64_t nr0 = ne00 / ne10;
6975
7004
 
6976
7005
  float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
6977
7006
  float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
6978
7007
  float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
6979
7008
 
7009
+ for (int64_t r = 0; r < nr0; ++r) {
6980
7010
  #ifdef WSP_GGML_USE_ACCELERATE
6981
- vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
7011
+ vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
6982
7012
  #else
6983
- wsp_ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
7013
+ wsp_ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
6984
7014
  #endif
7015
+ }
6985
7016
  }
6986
7017
  } else {
6987
7018
  // src1 is not contiguous
@@ -6998,8 +7029,9 @@ static void wsp_ggml_compute_forward_add_f32(
6998
7029
  float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
6999
7030
  float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
7000
7031
 
7001
- for (int i0 = 0; i0 < ne0; i0++) {
7002
- float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
7032
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
7033
+ const int64_t i10 = i0 % ne10;
7034
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
7003
7035
 
7004
7036
  dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
7005
7037
  }
@@ -7158,8 +7190,8 @@ static void wsp_ggml_compute_forward_add_q_f32(
7158
7190
 
7159
7191
  const enum wsp_ggml_type type = src0->type;
7160
7192
  const enum wsp_ggml_type dtype = dst->type;
7161
- wsp_ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
7162
- wsp_ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float;
7193
+ wsp_ggml_to_float_t const wsp_dewsp_quantize_row_q = type_traits[type].to_float;
7194
+ wsp_ggml_from_float_t const wsp_quantize_row_q = type_traits[dtype].from_float;
7163
7195
 
7164
7196
  // we don't support permuted src0 or src1
7165
7197
  WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(type));
@@ -7204,12 +7236,12 @@ static void wsp_ggml_compute_forward_add_q_f32(
7204
7236
  assert(ne00 % 32 == 0);
7205
7237
 
7206
7238
  // unquantize row from src0 to temp buffer
7207
- dequantize_row_q(src0_row, wdata, ne00);
7239
+ wsp_dewsp_quantize_row_q(src0_row, wdata, ne00);
7208
7240
  // add src1
7209
7241
  wsp_ggml_vec_acc_f32(ne00, wdata, src1_row);
7210
7242
  // quantize row to dst
7211
- if (quantize_row_q != NULL) {
7212
- quantize_row_q(wdata, dst_row, ne00);
7243
+ if (wsp_quantize_row_q != NULL) {
7244
+ wsp_quantize_row_q(wdata, dst_row, ne00);
7213
7245
  } else {
7214
7246
  memcpy(dst_row, wdata, ne0*nb0);
7215
7247
  }
@@ -7435,8 +7467,8 @@ static void wsp_ggml_compute_forward_add1_q_f32(
7435
7467
  WSP_GGML_TENSOR_UNARY_OP_LOCALS
7436
7468
 
7437
7469
  const enum wsp_ggml_type type = src0->type;
7438
- wsp_ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
7439
- wsp_ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
7470
+ wsp_ggml_to_float_t const wsp_dewsp_quantize_row_q = type_traits[type].to_float;
7471
+ wsp_ggml_from_float_t const wsp_quantize_row_q = type_traits[type].from_float;
7440
7472
 
7441
7473
  // we don't support permuted src0
7442
7474
  WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(type));
@@ -7471,11 +7503,11 @@ static void wsp_ggml_compute_forward_add1_q_f32(
7471
7503
  assert(ne0 % 32 == 0);
7472
7504
 
7473
7505
  // unquantize row from src0 to temp buffer
7474
- dequantize_row_q(src0_row, wdata, ne0);
7506
+ wsp_dewsp_quantize_row_q(src0_row, wdata, ne0);
7475
7507
  // add src1
7476
7508
  wsp_ggml_vec_acc1_f32(ne0, wdata, v);
7477
7509
  // quantize row to dst
7478
- quantize_row_q(wdata, dst_row, ne0);
7510
+ wsp_quantize_row_q(wdata, dst_row, ne0);
7479
7511
  }
7480
7512
  }
7481
7513
 
@@ -7533,7 +7565,7 @@ static void wsp_ggml_compute_forward_acc_f32(
7533
7565
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst) && wsp_ggml_is_contiguous(src0));
7534
7566
 
7535
7567
  // view src0 and dst with these strides and data offset inbytes during acc
7536
- // nb0 is implicitely element_size because src0 and dst are contiguous
7568
+ // nb0 is implicitly element_size because src0 and dst are contiguous
7537
7569
  size_t nb1 = ((int32_t *) dst->op_params)[0];
7538
7570
  size_t nb2 = ((int32_t *) dst->op_params)[1];
7539
7571
  size_t nb3 = ((int32_t *) dst->op_params)[2];
@@ -7719,7 +7751,7 @@ static void wsp_ggml_compute_forward_mul_f32(
7719
7751
  const struct wsp_ggml_tensor * src0,
7720
7752
  const struct wsp_ggml_tensor * src1,
7721
7753
  struct wsp_ggml_tensor * dst) {
7722
- WSP_GGML_ASSERT(wsp_ggml_can_repeat_rows(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
7754
+ WSP_GGML_ASSERT(wsp_ggml_can_repeat(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
7723
7755
 
7724
7756
  if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
7725
7757
  return;
@@ -7727,8 +7759,10 @@ static void wsp_ggml_compute_forward_mul_f32(
7727
7759
  const int ith = params->ith;
7728
7760
  const int nth = params->nth;
7729
7761
 
7762
+ // TODO: OpenCL kernel support broadcast
7730
7763
  #ifdef WSP_GGML_USE_CLBLAST
7731
7764
  if (src1->backend == WSP_GGML_BACKEND_GPU) {
7765
+ WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, src1));
7732
7766
  if (ith == 0) {
7733
7767
  wsp_ggml_cl_mul(src0, src1, dst);
7734
7768
  }
@@ -7742,7 +7776,6 @@ static void wsp_ggml_compute_forward_mul_f32(
7742
7776
 
7743
7777
  WSP_GGML_ASSERT( nb0 == sizeof(float));
7744
7778
  WSP_GGML_ASSERT(nb00 == sizeof(float));
7745
- WSP_GGML_ASSERT(ne00 == ne10);
7746
7779
 
7747
7780
  if (nb10 == sizeof(float)) {
7748
7781
  for (int64_t ir = ith; ir < nr; ir += nth) {
@@ -7754,20 +7787,21 @@ static void wsp_ggml_compute_forward_mul_f32(
7754
7787
  const int64_t i13 = i03 % ne13;
7755
7788
  const int64_t i12 = i02 % ne12;
7756
7789
  const int64_t i11 = i01 % ne11;
7790
+ const int64_t nr0 = ne00 / ne10;
7757
7791
 
7758
7792
  float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
7759
7793
  float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
7760
7794
  float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
7761
7795
 
7796
+ for (int64_t r = 0 ; r < nr0; ++r) {
7762
7797
  #ifdef WSP_GGML_USE_ACCELERATE
7763
- UNUSED(wsp_ggml_vec_mul_f32);
7798
+ UNUSED(wsp_ggml_vec_mul_f32);
7764
7799
 
7765
- vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
7800
+ vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
7766
7801
  #else
7767
- wsp_ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
7802
+ wsp_ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
7768
7803
  #endif
7769
- // }
7770
- // }
7804
+ }
7771
7805
  }
7772
7806
  } else {
7773
7807
  // src1 is not contiguous
@@ -7785,8 +7819,9 @@ static void wsp_ggml_compute_forward_mul_f32(
7785
7819
  float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
7786
7820
  float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
7787
7821
 
7788
- for (int64_t i0 = 0; i0 < ne00; i0++) {
7789
- float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
7822
+ for (int64_t i0 = 0; i0 < ne00; ++i0) {
7823
+ const int64_t i10 = i0 % ne10;
7824
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
7790
7825
 
7791
7826
  dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
7792
7827
  }
@@ -7820,14 +7855,16 @@ static void wsp_ggml_compute_forward_div_f32(
7820
7855
  const struct wsp_ggml_tensor * src0,
7821
7856
  const struct wsp_ggml_tensor * src1,
7822
7857
  struct wsp_ggml_tensor * dst) {
7823
- assert(params->ith == 0);
7824
- assert(wsp_ggml_are_same_shape(src0, src1) && wsp_ggml_are_same_shape(src0, dst));
7858
+ WSP_GGML_ASSERT(wsp_ggml_can_repeat(src1, src0) && wsp_ggml_are_same_shape(src0, dst));
7825
7859
 
7826
7860
  if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
7827
7861
  return;
7828
7862
  }
7829
7863
 
7830
- const int nr = wsp_ggml_nrows(src0);
7864
+ const int ith = params->ith;
7865
+ const int nth = params->nth;
7866
+
7867
+ const int64_t nr = wsp_ggml_nrows(src0);
7831
7868
 
7832
7869
  WSP_GGML_TENSOR_BINARY_OP_LOCALS
7833
7870
 
@@ -7835,41 +7872,50 @@ static void wsp_ggml_compute_forward_div_f32(
7835
7872
  WSP_GGML_ASSERT(nb00 == sizeof(float));
7836
7873
 
7837
7874
  if (nb10 == sizeof(float)) {
7838
- for (int ir = 0; ir < nr; ++ir) {
7839
- // src0, src1 and dst are same shape => same indices
7840
- const int i3 = ir/(ne2*ne1);
7841
- const int i2 = (ir - i3*ne2*ne1)/ne1;
7842
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
7875
+ for (int64_t ir = ith; ir < nr; ir += nth) {
7876
+ // src0 and dst are same shape => same indices
7877
+ const int64_t i03 = ir/(ne02*ne01);
7878
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
7879
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
7843
7880
 
7844
- #ifdef WSP_GGML_USE_ACCELERATE
7845
- UNUSED(wsp_ggml_vec_div_f32);
7881
+ const int64_t i13 = i03 % ne13;
7882
+ const int64_t i12 = i02 % ne12;
7883
+ const int64_t i11 = i01 % ne11;
7884
+ const int64_t nr0 = ne00 / ne10;
7846
7885
 
7847
- vDSP_vdiv(
7848
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
7849
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
7850
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
7851
- ne0);
7886
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
7887
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
7888
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
7889
+
7890
+ for (int64_t r = 0; r < nr0; ++r) {
7891
+ #ifdef WSP_GGML_USE_ACCELERATE
7892
+ UNUSED(wsp_ggml_vec_div_f32);
7893
+
7894
+ vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
7852
7895
  #else
7853
- wsp_ggml_vec_div_f32(ne0,
7854
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
7855
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
7856
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
7896
+ wsp_ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
7857
7897
  #endif
7858
- // }
7859
- // }
7898
+ }
7860
7899
  }
7861
7900
  } else {
7862
7901
  // src1 is not contiguous
7863
- for (int ir = 0; ir < nr; ++ir) {
7864
- // src0, src1 and dst are same shape => same indices
7865
- const int i3 = ir/(ne2*ne1);
7866
- const int i2 = (ir - i3*ne2*ne1)/ne1;
7867
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
7902
+ for (int64_t ir = ith; ir < nr; ir += nth) {
7903
+ // src0 and dst are same shape => same indices
7904
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
7905
+ const int64_t i03 = ir/(ne02*ne01);
7906
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
7907
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
7868
7908
 
7869
- float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
7870
- float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
7871
- for (int i0 = 0; i0 < ne0; i0++) {
7872
- float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
7909
+ const int64_t i13 = i03 % ne13;
7910
+ const int64_t i12 = i02 % ne12;
7911
+ const int64_t i11 = i01 % ne11;
7912
+
7913
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
7914
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
7915
+
7916
+ for (int64_t i0 = 0; i0 < ne00; ++i0) {
7917
+ const int64_t i10 = i0 % ne10;
7918
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
7873
7919
 
7874
7920
  dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr);
7875
7921
  }
@@ -8315,7 +8361,7 @@ static void wsp_ggml_compute_forward_repeat_f16(
8315
8361
  return;
8316
8362
  }
8317
8363
 
8318
- WSP_GGML_TENSOR_UNARY_OP_LOCALS;
8364
+ WSP_GGML_TENSOR_UNARY_OP_LOCALS
8319
8365
 
8320
8366
  // guaranteed to be an integer due to the check in wsp_ggml_can_repeat
8321
8367
  const int nr0 = (int)(ne0/ne00);
@@ -8460,6 +8506,7 @@ static void wsp_ggml_compute_forward_concat_f32(
8460
8506
  WSP_GGML_ASSERT(src0->nb[0] == sizeof(float));
8461
8507
 
8462
8508
  const int ith = params->ith;
8509
+ const int nth = params->nth;
8463
8510
 
8464
8511
  WSP_GGML_TENSOR_BINARY_OP_LOCALS
8465
8512
 
@@ -8469,7 +8516,7 @@ static void wsp_ggml_compute_forward_concat_f32(
8469
8516
  WSP_GGML_ASSERT(nb10 == sizeof(float));
8470
8517
 
8471
8518
  for (int i3 = 0; i3 < ne3; i3++) {
8472
- for (int i2 = ith; i2 < ne2; i2++) {
8519
+ for (int i2 = ith; i2 < ne2; i2 += nth) {
8473
8520
  if (i2 < ne02) { // src0
8474
8521
  for (int i1 = 0; i1 < ne1; i1++) {
8475
8522
  for (int i0 = 0; i0 < ne0; i0++) {
@@ -8981,10 +9028,9 @@ static void wsp_ggml_compute_forward_silu(
8981
9028
  } break;
8982
9029
  }
8983
9030
  }
9031
+ // wsp_ggml_compute_forward_leaky_relu
8984
9032
 
8985
- // wsp_ggml_compute_forward_leaky
8986
-
8987
- static void wsp_ggml_compute_forward_leaky_f32(
9033
+ static void wsp_ggml_compute_forward_leaky_relu_f32(
8988
9034
  const struct wsp_ggml_compute_params * params,
8989
9035
  const struct wsp_ggml_tensor * src0,
8990
9036
  struct wsp_ggml_tensor * dst) {
@@ -8998,24 +9044,27 @@ static void wsp_ggml_compute_forward_leaky_f32(
8998
9044
  const int n = wsp_ggml_nrows(src0);
8999
9045
  const int nc = src0->ne[0];
9000
9046
 
9047
+ float negative_slope;
9048
+ memcpy(&negative_slope, dst->op_params, sizeof(float));
9049
+
9001
9050
  assert(dst->nb[0] == sizeof(float));
9002
9051
  assert(src0->nb[0] == sizeof(float));
9003
9052
 
9004
9053
  for (int i = 0; i < n; i++) {
9005
- wsp_ggml_vec_leaky_f32(nc,
9054
+ wsp_ggml_vec_leaky_relu_f32(nc,
9006
9055
  (float *) ((char *) dst->data + i*( dst->nb[1])),
9007
- (float *) ((char *) src0->data + i*(src0->nb[1])));
9056
+ (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
9008
9057
  }
9009
9058
  }
9010
9059
 
9011
- static void wsp_ggml_compute_forward_leaky(
9060
+ static void wsp_ggml_compute_forward_leaky_relu(
9012
9061
  const struct wsp_ggml_compute_params * params,
9013
9062
  const struct wsp_ggml_tensor * src0,
9014
9063
  struct wsp_ggml_tensor * dst) {
9015
9064
  switch (src0->type) {
9016
9065
  case WSP_GGML_TYPE_F32:
9017
9066
  {
9018
- wsp_ggml_compute_forward_leaky_f32(params, src0, dst);
9067
+ wsp_ggml_compute_forward_leaky_relu_f32(params, src0, dst);
9019
9068
  } break;
9020
9069
  default:
9021
9070
  {
@@ -9504,9 +9553,14 @@ static bool wsp_ggml_compute_forward_mul_mat_use_blas(
9504
9553
  const int64_t ne0 = dst->ne[0];
9505
9554
  const int64_t ne1 = dst->ne[1];
9506
9555
 
9556
+ // NOTE: with WSP_GGML_OP_MUL_MAT_ID we don't want to go through the BLAS branch because it will dequantize (to_float)
9557
+ // all the experts for each batch element and the processing would become incredibly slow
9507
9558
  // TODO: find the optimal values for these
9508
- if (wsp_ggml_is_contiguous(src0) &&
9559
+ if (dst->op != WSP_GGML_OP_MUL_MAT_ID &&
9560
+ wsp_ggml_is_contiguous(src0) &&
9509
9561
  wsp_ggml_is_contiguous(src1) &&
9562
+ //src0->type == WSP_GGML_TYPE_F32 &&
9563
+ src1->type == WSP_GGML_TYPE_F32 &&
9510
9564
  (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
9511
9565
 
9512
9566
  /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
@@ -9517,11 +9571,16 @@ static bool wsp_ggml_compute_forward_mul_mat_use_blas(
9517
9571
  }
9518
9572
  #endif
9519
9573
 
9574
+ // off1 = offset in i11 and i1
9575
+ // cne1 = ne11 and ne1
9576
+ // in a normal matrix multiplication, off1 = 0 and cne1 = ne1
9577
+ // during WSP_GGML_TASK_INIT, the full src1 is converted regardless of off1 and cne1
9520
9578
  static void wsp_ggml_compute_forward_mul_mat(
9521
9579
  const struct wsp_ggml_compute_params * params,
9522
9580
  const struct wsp_ggml_tensor * src0,
9523
9581
  const struct wsp_ggml_tensor * src1,
9524
- struct wsp_ggml_tensor * dst) {
9582
+ struct wsp_ggml_tensor * dst,
9583
+ int64_t off1, int64_t cne1) {
9525
9584
  int64_t t0 = wsp_ggml_perf_time_us();
9526
9585
  UNUSED(t0);
9527
9586
 
@@ -9545,7 +9604,7 @@ static void wsp_ggml_compute_forward_mul_mat(
9545
9604
 
9546
9605
  // we don't support permuted src0 or src1
9547
9606
  WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(type));
9548
- WSP_GGML_ASSERT(nb10 == sizeof(float));
9607
+ WSP_GGML_ASSERT(nb10 == wsp_ggml_type_size(src1->type));
9549
9608
 
9550
9609
  // dst cannot be transposed or permuted
9551
9610
  WSP_GGML_ASSERT(nb0 == sizeof(float));
@@ -9589,10 +9648,9 @@ static void wsp_ggml_compute_forward_mul_mat(
9589
9648
  const int64_t i03 = i13/r3;
9590
9649
  const int64_t i02 = i12/r2;
9591
9650
 
9592
- const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
9593
- const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
9594
-
9595
- float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
9651
+ const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
9652
+ const float * y = (float *) ((char *) src1->data + off1*nb11 + i12*nb12 + i13*nb13);
9653
+ float * d = (float *) ((char *) dst->data + off1*nb1 + i12*nb2 + i13*nb3);
9596
9654
 
9597
9655
  if (type != WSP_GGML_TYPE_F32) {
9598
9656
  float * const wdata = params->wdata;
@@ -9609,10 +9667,10 @@ static void wsp_ggml_compute_forward_mul_mat(
9609
9667
  }
9610
9668
 
9611
9669
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
9612
- ne11, ne01, ne10,
9613
- 1.0f, y, ne10,
9614
- x, ne00,
9615
- 0.0f, d, ne01);
9670
+ cne1, ne01, ne10,
9671
+ 1.0f, y, ne10,
9672
+ x, ne00,
9673
+ 0.0f, d, ne01);
9616
9674
  }
9617
9675
  }
9618
9676
 
@@ -9627,6 +9685,9 @@ static void wsp_ggml_compute_forward_mul_mat(
9627
9685
  char * wdata = params->wdata;
9628
9686
  const size_t row_size = ne10*wsp_ggml_type_size(vec_dot_type)/wsp_ggml_blck_size(vec_dot_type);
9629
9687
 
9688
+ assert(params->wsize >= ne11*ne12*ne13*row_size);
9689
+ assert(src1->type == WSP_GGML_TYPE_F32);
9690
+
9630
9691
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
9631
9692
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
9632
9693
  for (int64_t i11 = 0; i11 < ne11; ++i11) {
@@ -9648,7 +9709,7 @@ static void wsp_ggml_compute_forward_mul_mat(
9648
9709
  const size_t row_size = ne10*wsp_ggml_type_size(vec_dot_type)/wsp_ggml_blck_size(vec_dot_type);
9649
9710
 
9650
9711
  const int64_t nr0 = ne01; // src0 rows
9651
- const int64_t nr1 = ne11*ne12*ne13; // src1 rows
9712
+ const int64_t nr1 = cne1*ne12*ne13; // src1 rows
9652
9713
 
9653
9714
  //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
9654
9715
 
@@ -9690,9 +9751,9 @@ static void wsp_ggml_compute_forward_mul_mat(
9690
9751
  for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
9691
9752
  for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
9692
9753
  for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
9693
- const int64_t i13 = (ir1/(ne12*ne11));
9694
- const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
9695
- const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
9754
+ const int64_t i13 = (ir1/(ne12*cne1));
9755
+ const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
9756
+ const int64_t i11 = (ir1 - i13*ne12*cne1 - i12*cne1) + off1;
9696
9757
 
9697
9758
  // broadcast src0 into src1
9698
9759
  const int64_t i03 = i13/r3;
@@ -9728,6 +9789,34 @@ static void wsp_ggml_compute_forward_mul_mat(
9728
9789
  }
9729
9790
  }
9730
9791
 
9792
+ // wsp_ggml_compute_forward_mul_mat_id
9793
+
9794
+ static void wsp_ggml_compute_forward_mul_mat_id(
9795
+ const struct wsp_ggml_compute_params * params,
9796
+ const struct wsp_ggml_tensor * src0,
9797
+ const struct wsp_ggml_tensor * src1,
9798
+ struct wsp_ggml_tensor * dst) {
9799
+
9800
+ if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
9801
+ // during WSP_GGML_TASK_INIT the entire src1 is converted to vec_dot_type
9802
+ wsp_ggml_compute_forward_mul_mat(params, dst->src[2], src1, dst, 0, dst->ne[1]);
9803
+ return;
9804
+ }
9805
+
9806
+ const struct wsp_ggml_tensor * ids = src0;
9807
+ const int id = wsp_ggml_get_op_params_i32(dst, 0);
9808
+ const int n_as = wsp_ggml_get_op_params_i32(dst, 1);
9809
+
9810
+ for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
9811
+ const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
9812
+
9813
+ WSP_GGML_ASSERT(row_id >= 0 && row_id < n_as);
9814
+
9815
+ const struct wsp_ggml_tensor * src0_row = dst->src[row_id + 2];
9816
+ wsp_ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1);
9817
+ }
9818
+ }
9819
+
9731
9820
  // wsp_ggml_compute_forward_out_prod
9732
9821
 
9733
9822
  static void wsp_ggml_compute_forward_out_prod_f32(
@@ -9743,10 +9832,12 @@ static void wsp_ggml_compute_forward_out_prod_f32(
9743
9832
  const int ith = params->ith;
9744
9833
  const int nth = params->nth;
9745
9834
 
9835
+ WSP_GGML_ASSERT(ne0 == ne00);
9836
+ WSP_GGML_ASSERT(ne1 == ne10);
9837
+ WSP_GGML_ASSERT(ne2 == ne02);
9746
9838
  WSP_GGML_ASSERT(ne02 == ne12);
9747
- WSP_GGML_ASSERT(ne03 == ne13);
9748
- WSP_GGML_ASSERT(ne2 == ne12);
9749
9839
  WSP_GGML_ASSERT(ne3 == ne13);
9840
+ WSP_GGML_ASSERT(ne03 == ne13);
9750
9841
 
9751
9842
  // we don't support permuted src0 or src1
9752
9843
  WSP_GGML_ASSERT(nb00 == sizeof(float));
@@ -9757,18 +9848,25 @@ static void wsp_ggml_compute_forward_out_prod_f32(
9757
9848
  // WSP_GGML_ASSERT(nb1 <= nb2);
9758
9849
  // WSP_GGML_ASSERT(nb2 <= nb3);
9759
9850
 
9760
- WSP_GGML_ASSERT(ne0 == ne00);
9761
- WSP_GGML_ASSERT(ne1 == ne10);
9762
- WSP_GGML_ASSERT(ne2 == ne02);
9763
- WSP_GGML_ASSERT(ne3 == ne03);
9764
-
9765
9851
  // nb01 >= nb00 - src0 is not transposed
9766
9852
  // compute by src0 rows
9767
9853
 
9768
9854
  // TODO: #if defined(WSP_GGML_USE_CUBLAS) wsp_ggml_cuda_out_prod
9769
- // TODO: #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) || defined(WSP_GGML_USE_CLBLAST)
9855
+ // TODO: #if defined(WSP_GGML_USE_CLBLAST)
9856
+
9857
+ #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS)
9858
+ bool use_blas = wsp_ggml_is_matrix(src0) &&
9859
+ wsp_ggml_is_matrix(src1) &&
9860
+ wsp_ggml_is_contiguous(src0) &&
9861
+ (wsp_ggml_is_contiguous(src1) || wsp_ggml_is_transposed(src1));
9862
+ #endif
9770
9863
 
9771
9864
  if (params->type == WSP_GGML_TASK_INIT) {
9865
+ #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) // gemm beta will zero dst
9866
+ if (use_blas) {
9867
+ return;
9868
+ }
9869
+ #endif
9772
9870
  wsp_ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
9773
9871
  return;
9774
9872
  }
@@ -9777,6 +9875,50 @@ static void wsp_ggml_compute_forward_out_prod_f32(
9777
9875
  return;
9778
9876
  }
9779
9877
 
9878
+ #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS)
9879
+ if (use_blas) {
9880
+ if (params->ith != 0) { // All threads other than the first do no work.
9881
+ return;
9882
+ }
9883
+ // Arguments to wsp_ggml_compute_forward_out_prod (expressed as major,minor)
9884
+ // src0: (k,n)
9885
+ // src1: (k,m)
9886
+ // dst: (m,n)
9887
+ //
9888
+ // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
9889
+ // Also expressed as (major,minor)
9890
+ // a: (m,k): so src1 transposed
9891
+ // b: (k,n): so src0
9892
+ // c: (m,n)
9893
+ //
9894
+ // However, if wsp_ggml_is_transposed(src1) is true, then
9895
+ // src1->data already contains a transposed version, so sgemm mustn't
9896
+ // transpose it further.
9897
+
9898
+ int n = src0->ne[0];
9899
+ int k = src0->ne[1];
9900
+ int m = src1->ne[0];
9901
+
9902
+ int transposeA, lda;
9903
+
9904
+ if (!wsp_ggml_is_transposed(src1)) {
9905
+ transposeA = CblasTrans;
9906
+ lda = m;
9907
+ } else {
9908
+ transposeA = CblasNoTrans;
9909
+ lda = k;
9910
+ }
9911
+
9912
+ float * a = (float *) ((char *) src1->data);
9913
+ float * b = (float *) ((char *) src0->data);
9914
+ float * c = (float *) ((char *) dst->data);
9915
+
9916
+ cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
9917
+
9918
+ return;
9919
+ }
9920
+ #endif
9921
+
9780
9922
  // dst[:,:,:,:] = 0
9781
9923
  // for i2,i3:
9782
9924
  // for i1:
@@ -9880,7 +10022,7 @@ static void wsp_ggml_compute_forward_out_prod_q_f32(
9880
10022
  const int nth = params->nth;
9881
10023
 
9882
10024
  const enum wsp_ggml_type type = src0->type;
9883
- wsp_ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
10025
+ wsp_ggml_to_float_t const wsp_dewsp_quantize_row_q = type_traits[type].to_float;
9884
10026
 
9885
10027
  WSP_GGML_ASSERT(ne02 == ne12);
9886
10028
  WSP_GGML_ASSERT(ne03 == ne13);
@@ -9957,7 +10099,7 @@ static void wsp_ggml_compute_forward_out_prod_q_f32(
9957
10099
  float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
9958
10100
  float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
9959
10101
 
9960
- dequantize_row_q(s0, wdata, ne0);
10102
+ wsp_dewsp_quantize_row_q(s0, wdata, ne0);
9961
10103
  wsp_ggml_vec_mad_f32(ne0, d, wdata, *s1);
9962
10104
  }
9963
10105
  }
@@ -10084,7 +10226,7 @@ static void wsp_ggml_compute_forward_set_f32(
10084
10226
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst) && wsp_ggml_is_contiguous(src0));
10085
10227
 
10086
10228
  // view src0 and dst with these strides and data offset in bytes during set
10087
- // nb0 is implicitely element_size because src0 and dst are contiguous
10229
+ // nb0 is implicitly element_size because src0 and dst are contiguous
10088
10230
  size_t nb1 = ((int32_t *) dst->op_params)[0];
10089
10231
  size_t nb2 = ((int32_t *) dst->op_params)[1];
10090
10232
  size_t nb3 = ((int32_t *) dst->op_params)[2];
@@ -10248,21 +10390,30 @@ static void wsp_ggml_compute_forward_get_rows_q(
10248
10390
  return;
10249
10391
  }
10250
10392
 
10251
- const int nc = src0->ne[0];
10252
- const int nr = wsp_ggml_nelements(src1);
10253
- const enum wsp_ggml_type type = src0->type;
10254
- wsp_ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
10255
-
10256
- assert( dst->ne[0] == nc);
10257
- assert( dst->ne[1] == nr);
10258
- assert(src0->nb[0] == wsp_ggml_type_size(type));
10393
+ WSP_GGML_TENSOR_BINARY_OP_LOCALS
10259
10394
 
10260
- for (int i = 0; i < nr; ++i) {
10261
- const int r = ((int32_t *) src1->data)[i];
10395
+ const int64_t nc = ne00;
10396
+ const int64_t nr = wsp_ggml_nelements(src1); WSP_GGML_UNUSED(nr);
10262
10397
 
10263
- dequantize_row_q(
10264
- (const void *) ((char *) src0->data + r*src0->nb[1]),
10265
- (float *) ((char *) dst->data + i*dst->nb[1]), nc);
10398
+ const enum wsp_ggml_type type = src0->type;
10399
+ wsp_ggml_to_float_t const wsp_dewsp_quantize_row_q = type_traits[type].to_float;
10400
+
10401
+ assert(ne0 == nc);
10402
+ assert(ne02 == ne11);
10403
+ assert(nb00 == wsp_ggml_type_size(type));
10404
+ assert(wsp_ggml_nrows(dst) == nr);
10405
+
10406
+ // TODO: multi-thread
10407
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
10408
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
10409
+ for (int64_t i10 = 0; i10 < ne10; ++i10) {
10410
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
10411
+
10412
+ wsp_dewsp_quantize_row_q(
10413
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
10414
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
10415
+ }
10416
+ }
10266
10417
  }
10267
10418
  }
10268
10419
 
@@ -10277,19 +10428,26 @@ static void wsp_ggml_compute_forward_get_rows_f16(
10277
10428
  return;
10278
10429
  }
10279
10430
 
10280
- const int nc = src0->ne[0];
10281
- const int nr = wsp_ggml_nelements(src1);
10431
+ WSP_GGML_TENSOR_BINARY_OP_LOCALS
10282
10432
 
10283
- assert( dst->ne[0] == nc);
10284
- assert( dst->ne[1] == nr);
10285
- assert(src0->nb[0] == sizeof(wsp_ggml_fp16_t));
10433
+ const int64_t nc = ne00;
10434
+ const int64_t nr = wsp_ggml_nelements(src1); WSP_GGML_UNUSED(nr);
10286
10435
 
10287
- for (int i = 0; i < nr; ++i) {
10288
- const int r = ((int32_t *) src1->data)[i];
10436
+ assert(ne0 == nc);
10437
+ assert(ne02 == ne11);
10438
+ assert(nb00 == sizeof(wsp_ggml_fp16_t));
10439
+ assert(wsp_ggml_nrows(dst) == nr);
10289
10440
 
10290
- for (int j = 0; j < nc; ++j) {
10291
- wsp_ggml_fp16_t v = ((wsp_ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
10292
- ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = WSP_GGML_FP16_TO_FP32(v);
10441
+ // TODO: multi-thread
10442
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
10443
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
10444
+ for (int64_t i10 = 0; i10 < ne10; ++i10) {
10445
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
10446
+
10447
+ wsp_ggml_fp16_to_fp32_row(
10448
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
10449
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
10450
+ }
10293
10451
  }
10294
10452
  }
10295
10453
  }
@@ -10305,19 +10463,27 @@ static void wsp_ggml_compute_forward_get_rows_f32(
10305
10463
  return;
10306
10464
  }
10307
10465
 
10308
- const int nc = src0->ne[0];
10309
- const int nr = wsp_ggml_nelements(src1);
10466
+ WSP_GGML_TENSOR_BINARY_OP_LOCALS
10310
10467
 
10311
- assert( dst->ne[0] == nc);
10312
- assert( dst->ne[1] == nr);
10313
- assert(src0->nb[0] == sizeof(float));
10468
+ const int64_t nc = ne00;
10469
+ const int64_t nr = wsp_ggml_nelements(src1); WSP_GGML_UNUSED(nr);
10314
10470
 
10315
- for (int i = 0; i < nr; ++i) {
10316
- const int r = ((int32_t *) src1->data)[i];
10471
+ assert(ne0 == nc);
10472
+ assert(ne02 == ne11);
10473
+ assert(nb00 == sizeof(float));
10474
+ assert(wsp_ggml_nrows(dst) == nr);
10317
10475
 
10318
- wsp_ggml_vec_cpy_f32(nc,
10319
- (float *) ((char *) dst->data + i*dst->nb[1]),
10320
- (float *) ((char *) src0->data + r*src0->nb[1]));
10476
+ // TODO: multi-thread
10477
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
10478
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
10479
+ for (int64_t i10 = 0; i10 < ne10; ++i10) {
10480
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
10481
+
10482
+ wsp_ggml_vec_cpy_f32(nc,
10483
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
10484
+ (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
10485
+ }
10486
+ }
10321
10487
  }
10322
10488
  }
10323
10489
 
@@ -10630,20 +10796,25 @@ static void wsp_ggml_compute_forward_diag_mask_zero(
10630
10796
  static void wsp_ggml_compute_forward_soft_max_f32(
10631
10797
  const struct wsp_ggml_compute_params * params,
10632
10798
  const struct wsp_ggml_tensor * src0,
10633
- struct wsp_ggml_tensor * dst) {
10634
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
10635
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst));
10636
- WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst));
10799
+ const struct wsp_ggml_tensor * src1,
10800
+ struct wsp_ggml_tensor * dst) {
10801
+ assert(wsp_ggml_is_contiguous(dst));
10802
+ assert(wsp_ggml_are_same_shape(src0, dst));
10637
10803
 
10638
10804
  if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
10639
10805
  return;
10640
10806
  }
10641
10807
 
10808
+ float scale = 1.0f;
10809
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
10810
+
10642
10811
  // TODO: handle transposed/permuted matrices
10643
10812
 
10644
10813
  const int ith = params->ith;
10645
10814
  const int nth = params->nth;
10646
10815
 
10816
+ const int64_t ne11 = src1 ? src1->ne[1] : 1;
10817
+
10647
10818
  const int nc = src0->ne[0];
10648
10819
  const int nr = wsp_ggml_nrows(src0);
10649
10820
 
@@ -10654,29 +10825,40 @@ static void wsp_ggml_compute_forward_soft_max_f32(
10654
10825
  const int ir0 = dr*ith;
10655
10826
  const int ir1 = MIN(ir0 + dr, nr);
10656
10827
 
10828
+ float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
10829
+
10657
10830
  for (int i1 = ir0; i1 < ir1; i1++) {
10658
- float *sp = (float *)((char *) src0->data + i1*src0->nb[1]);
10659
- float *dp = (float *)((char *) dst->data + i1*dst->nb[1]);
10831
+ float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
10832
+ float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
10833
+
10834
+ // broadcast the mask across rows
10835
+ float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
10836
+
10837
+ wsp_ggml_vec_cpy_f32 (nc, wp, sp);
10838
+ wsp_ggml_vec_scale_f32(nc, wp, scale);
10839
+ if (mp) {
10840
+ wsp_ggml_vec_acc_f32(nc, wp, mp);
10841
+ }
10660
10842
 
10661
10843
  #ifndef NDEBUG
10662
10844
  for (int i = 0; i < nc; ++i) {
10663
10845
  //printf("p[%d] = %f\n", i, p[i]);
10664
- assert(!isnan(sp[i]));
10846
+ assert(!isnan(wp[i]));
10665
10847
  }
10666
10848
  #endif
10667
10849
 
10668
10850
  float max = -INFINITY;
10669
- wsp_ggml_vec_max_f32(nc, &max, sp);
10851
+ wsp_ggml_vec_max_f32(nc, &max, wp);
10670
10852
 
10671
10853
  wsp_ggml_float sum = 0.0;
10672
10854
 
10673
10855
  uint16_t scvt;
10674
10856
  for (int i = 0; i < nc; i++) {
10675
- if (sp[i] == -INFINITY) {
10857
+ if (wp[i] == -INFINITY) {
10676
10858
  dp[i] = 0.0f;
10677
10859
  } else {
10678
- // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
10679
- wsp_ggml_fp16_t s = WSP_GGML_FP32_TO_FP16(sp[i] - max);
10860
+ // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
10861
+ wsp_ggml_fp16_t s = WSP_GGML_FP32_TO_FP16(wp[i] - max);
10680
10862
  memcpy(&scvt, &s, sizeof(scvt));
10681
10863
  const float val = WSP_GGML_FP16_TO_FP32(wsp_ggml_table_exp_f16[scvt]);
10682
10864
  sum += (wsp_ggml_float)val;
@@ -10701,11 +10883,12 @@ static void wsp_ggml_compute_forward_soft_max_f32(
10701
10883
  static void wsp_ggml_compute_forward_soft_max(
10702
10884
  const struct wsp_ggml_compute_params * params,
10703
10885
  const struct wsp_ggml_tensor * src0,
10704
- struct wsp_ggml_tensor * dst) {
10886
+ const struct wsp_ggml_tensor * src1,
10887
+ struct wsp_ggml_tensor * dst) {
10705
10888
  switch (src0->type) {
10706
10889
  case WSP_GGML_TYPE_F32:
10707
10890
  {
10708
- wsp_ggml_compute_forward_soft_max_f32(params, src0, dst);
10891
+ wsp_ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
10709
10892
  } break;
10710
10893
  default:
10711
10894
  {
@@ -11086,7 +11269,8 @@ static void wsp_ggml_compute_forward_rope_f32(
11086
11269
  const struct wsp_ggml_compute_params * params,
11087
11270
  const struct wsp_ggml_tensor * src0,
11088
11271
  const struct wsp_ggml_tensor * src1,
11089
- struct wsp_ggml_tensor * dst) {
11272
+ struct wsp_ggml_tensor * dst,
11273
+ const bool forward) {
11090
11274
  if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
11091
11275
  return;
11092
11276
  }
@@ -11145,6 +11329,11 @@ static void wsp_ggml_compute_forward_rope_f32(
11145
11329
  const bool is_neox = mode & 2;
11146
11330
  const bool is_glm = mode & 4;
11147
11331
 
11332
+ // backward process uses inverse rotation by cos and sin.
11333
+ // cos and sin build a rotation matrix, where the inverse is the transpose.
11334
+ // this essentially just switches the sign of sin.
11335
+ const float sin_sign = forward ? 1.0f : -1.0f;
11336
+
11148
11337
  const int32_t * pos = (const int32_t *) src1->data;
11149
11338
 
11150
11339
  for (int64_t i3 = 0; i3 < ne3; i3++) {
@@ -11161,9 +11350,9 @@ static void wsp_ggml_compute_forward_rope_f32(
11161
11350
  float block_theta = MAX(p - (n_ctx - 2), 0);
11162
11351
  for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
11163
11352
  const float cos_theta = cosf(theta_base);
11164
- const float sin_theta = sinf(theta_base);
11353
+ const float sin_theta = sinf(theta_base) * sin_sign;
11165
11354
  const float cos_block_theta = cosf(block_theta);
11166
- const float sin_block_theta = sinf(block_theta);
11355
+ const float sin_block_theta = sinf(block_theta) * sin_sign;
11167
11356
 
11168
11357
  theta_base *= theta_scale;
11169
11358
  block_theta *= theta_scale;
@@ -11187,6 +11376,7 @@ static void wsp_ggml_compute_forward_rope_f32(
11187
11376
  rope_yarn(
11188
11377
  theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
11189
11378
  );
11379
+ sin_theta *= sin_sign;
11190
11380
 
11191
11381
  // zeta scaling for xPos only:
11192
11382
  float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
@@ -11217,6 +11407,7 @@ static void wsp_ggml_compute_forward_rope_f32(
11217
11407
  theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
11218
11408
  &cos_theta, &sin_theta
11219
11409
  );
11410
+ sin_theta *= sin_sign;
11220
11411
 
11221
11412
  theta_base *= theta_scale;
11222
11413
 
@@ -11242,7 +11433,8 @@ static void wsp_ggml_compute_forward_rope_f16(
11242
11433
  const struct wsp_ggml_compute_params * params,
11243
11434
  const struct wsp_ggml_tensor * src0,
11244
11435
  const struct wsp_ggml_tensor * src1,
11245
- struct wsp_ggml_tensor * dst) {
11436
+ struct wsp_ggml_tensor * dst,
11437
+ const bool forward) {
11246
11438
  if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
11247
11439
  return;
11248
11440
  }
@@ -11294,6 +11486,11 @@ static void wsp_ggml_compute_forward_rope_f16(
11294
11486
  const bool is_neox = mode & 2;
11295
11487
  const bool is_glm = mode & 4;
11296
11488
 
11489
+ // backward process uses inverse rotation by cos and sin.
11490
+ // cos and sin build a rotation matrix, where the inverse is the transpose.
11491
+ // this essentially just switches the sign of sin.
11492
+ const float sin_sign = forward ? 1.0f : -1.0f;
11493
+
11297
11494
  const int32_t * pos = (const int32_t *) src1->data;
11298
11495
 
11299
11496
  for (int64_t i3 = 0; i3 < ne3; i3++) {
@@ -11310,9 +11507,9 @@ static void wsp_ggml_compute_forward_rope_f16(
11310
11507
  float block_theta = MAX(p - (n_ctx - 2), 0);
11311
11508
  for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
11312
11509
  const float cos_theta = cosf(theta_base);
11313
- const float sin_theta = sinf(theta_base);
11510
+ const float sin_theta = sinf(theta_base) * sin_sign;
11314
11511
  const float cos_block_theta = cosf(block_theta);
11315
- const float sin_block_theta = sinf(block_theta);
11512
+ const float sin_block_theta = sinf(block_theta) * sin_sign;
11316
11513
 
11317
11514
  theta_base *= theta_scale;
11318
11515
  block_theta *= theta_scale;
@@ -11336,6 +11533,7 @@ static void wsp_ggml_compute_forward_rope_f16(
11336
11533
  rope_yarn(
11337
11534
  theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
11338
11535
  );
11536
+ sin_theta *= sin_sign;
11339
11537
 
11340
11538
  theta_base *= theta_scale;
11341
11539
 
@@ -11362,6 +11560,7 @@ static void wsp_ggml_compute_forward_rope_f16(
11362
11560
  theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
11363
11561
  &cos_theta, &sin_theta
11364
11562
  );
11563
+ sin_theta *= sin_sign;
11365
11564
 
11366
11565
  theta_base *= theta_scale;
11367
11566
 
@@ -11391,11 +11590,11 @@ static void wsp_ggml_compute_forward_rope(
11391
11590
  switch (src0->type) {
11392
11591
  case WSP_GGML_TYPE_F16:
11393
11592
  {
11394
- wsp_ggml_compute_forward_rope_f16(params, src0, src1, dst);
11593
+ wsp_ggml_compute_forward_rope_f16(params, src0, src1, dst, true);
11395
11594
  } break;
11396
11595
  case WSP_GGML_TYPE_F32:
11397
11596
  {
11398
- wsp_ggml_compute_forward_rope_f32(params, src0, src1, dst);
11597
+ wsp_ggml_compute_forward_rope_f32(params, src0, src1, dst, true);
11399
11598
  } break;
11400
11599
  default:
11401
11600
  {
@@ -11406,726 +11605,106 @@ static void wsp_ggml_compute_forward_rope(
11406
11605
 
11407
11606
  // wsp_ggml_compute_forward_rope_back
11408
11607
 
11409
- static void wsp_ggml_compute_forward_rope_back_f32(
11608
+ static void wsp_ggml_compute_forward_rope_back(
11410
11609
  const struct wsp_ggml_compute_params * params,
11411
11610
  const struct wsp_ggml_tensor * src0,
11412
11611
  const struct wsp_ggml_tensor * src1,
11413
11612
  struct wsp_ggml_tensor * dst) {
11414
-
11415
- if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
11416
- return;
11613
+ switch (src0->type) { // dispatch on the element type of src0; only F16 and F32 are handled
11614
+ case WSP_GGML_TYPE_F16:
11615
+ {
11616
+ wsp_ggml_compute_forward_rope_f16(params, src0, src1, dst, false); // forward = false: the shared rope kernel flips the sign of sin (inverse rotation)
11617
+ } break;
11618
+ case WSP_GGML_TYPE_F32:
11619
+ {
11620
+ wsp_ggml_compute_forward_rope_f32(params, src0, src1, dst, false); // forward = false: the shared rope kernel flips the sign of sin (inverse rotation)
11621
+ } break;
11622
+ default:
11623
+ {
11624
+ WSP_GGML_ASSERT(false); // any other src0 type aborts here
11625
+ } break;
11417
11626
  }
11627
+ }
11418
11628
 
11419
- // y = rope(x, src1)
11420
- // dx = rope_back(dy, src1)
11421
- // src0 is dy, src1 contains options
11422
-
11423
- float freq_base;
11424
- float freq_scale;
11425
-
11426
- // these two only relevant for xPos RoPE:
11427
- float xpos_base;
11428
- bool xpos_down;
11429
-
11430
- //const int n_past = ((int32_t *) dst->op_params)[0];
11431
- const int n_dims = ((int32_t *) dst->op_params)[1];
11432
- const int mode = ((int32_t *) dst->op_params)[2];
11433
- const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx);
11434
- memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
11435
- memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
11436
- memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
11437
- memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
11629
+ // wsp_ggml_compute_forward_conv_transpose_1d
11438
11630
 
11439
- WSP_GGML_TENSOR_UNARY_OP_LOCALS
11631
+ static void wsp_ggml_compute_forward_conv_transpose_1d_f16_f32(
11632
+ const struct wsp_ggml_compute_params * params,
11633
+ const struct wsp_ggml_tensor * src0,
11634
+ const struct wsp_ggml_tensor * src1,
11635
+ struct wsp_ggml_tensor * dst) {
11636
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
11637
+ WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
11638
+ WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
11440
11639
 
11441
- //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
11442
- //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
11640
+ int64_t t0 = wsp_ggml_perf_time_us();
11641
+ UNUSED(t0);
11443
11642
 
11444
- assert(nb0 == sizeof(float));
11643
+ WSP_GGML_TENSOR_BINARY_OP_LOCALS
11445
11644
 
11446
11645
  const int ith = params->ith;
11447
11646
  const int nth = params->nth;
11448
11647
 
11449
- const int nr = wsp_ggml_nrows(dst);
11450
-
11451
- // rows per thread
11452
- const int dr = (nr + nth - 1)/nth;
11648
+ const int nk = ne00*ne01*ne02;
11453
11649
 
11454
- // row range for this thread
11455
- const int ir0 = dr*ith;
11456
- const int ir1 = MIN(ir0 + dr, nr);
11650
+ WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
11651
+ WSP_GGML_ASSERT(nb10 == sizeof(float));
11457
11652
 
11458
- // row index used to determine which thread to use
11459
- int ir = 0;
11653
+ if (params->type == WSP_GGML_TASK_INIT) {
11654
+ memset(params->wdata, 0, params->wsize);
11460
11655
 
11461
- const float theta_scale = powf(freq_base, -2.0f/n_dims);
11656
+ // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
11657
+ {
11658
+ wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0;
11462
11659
 
11463
- const bool is_neox = mode & 2;
11660
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
11661
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
11662
+ const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
11663
+ wsp_ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
11664
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
11665
+ dst_data[i00*ne02 + i02] = src[i00];
11666
+ }
11667
+ }
11668
+ }
11669
+ }
11464
11670
 
11465
- const int32_t * pos = (const int32_t *) src1->data;
11671
+ // permute source data (src1) from (L x Cin) to (Cin x L)
11672
+ {
11673
+ wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + nk;
11674
+ wsp_ggml_fp16_t * dst_data = wdata;
11466
11675
 
11467
- for (int64_t i3 = 0; i3 < ne3; i3++) {
11468
- for (int64_t i2 = 0; i2 < ne2; i2++) {
11469
- const int64_t p = pos[i2];
11470
- for (int64_t i1 = 0; i1 < ne1; i1++) {
11471
- if (ir++ < ir0) continue;
11472
- if (ir > ir1) break;
11676
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
11677
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
11678
+ for (int64_t i10 = 0; i10 < ne10; i10++) {
11679
+ dst_data[i10*ne11 + i11] = WSP_GGML_FP32_TO_FP16(src[i10]);
11680
+ }
11681
+ }
11682
+ }
11473
11683
 
11474
- float theta_base = freq_scale * (float)p;
11684
+ // need to zero dst since we are accumulating into it
11685
+ memset(dst->data, 0, wsp_ggml_nbytes(dst));
11475
11686
 
11476
- if (!is_neox) {
11477
- for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
11478
- const float cos_theta = cosf(theta_base);
11479
- const float sin_theta = sinf(theta_base);
11687
+ return;
11688
+ }
11480
11689
 
11481
- // zeta scaling for xPos only:
11482
- float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
11483
- if (xpos_down) zeta = 1.0f / zeta;
11690
+ if (params->type == WSP_GGML_TASK_FINALIZE) {
11691
+ return;
11692
+ }
11484
11693
 
11485
- theta_base *= theta_scale;
11694
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
11486
11695
 
11487
- const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11488
- float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11696
+ // total rows in dst
11697
+ const int nr = ne1;
11489
11698
 
11490
- const float dy0 = dy[0];
11491
- const float dy1 = dy[1];
11699
+ // rows per thread
11700
+ const int dr = (nr + nth - 1)/nth;
11492
11701
 
11493
- dx[0] = dy0*cos_theta*zeta + dy1*sin_theta*zeta;
11494
- dx[1] = - dy0*sin_theta*zeta + dy1*cos_theta*zeta;
11495
- }
11496
- } else {
11497
- for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
11498
- for (int64_t ic = 0; ic < n_dims; ic += 2) {
11499
- const float cos_theta = cosf(theta_base);
11500
- const float sin_theta = sinf(theta_base);
11702
+ // row range for this thread
11703
+ const int ir0 = dr*ith;
11704
+ const int ir1 = MIN(ir0 + dr, nr);
11501
11705
 
11502
- theta_base *= theta_scale;
11503
-
11504
- const int64_t i0 = ib*n_dims + ic/2;
11505
-
11506
- const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11507
- float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11508
-
11509
- const float dy0 = dy[0];
11510
- const float dy1 = dy[n_dims/2];
11511
-
11512
- dx[0] = dy0*cos_theta + dy1*sin_theta;
11513
- dx[n_dims/2] = - dy0*sin_theta + dy1*cos_theta;
11514
- }
11515
- }
11516
- }
11517
- }
11518
- }
11519
- }
11520
- }
11521
-
11522
- static void wsp_ggml_compute_forward_rope_back_f16(
11523
- const struct wsp_ggml_compute_params * params,
11524
- const struct wsp_ggml_tensor * src0,
11525
- const struct wsp_ggml_tensor * src1,
11526
- struct wsp_ggml_tensor * dst) {
11527
-
11528
- if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
11529
- return;
11530
- }
11531
-
11532
- // y = rope(x, src1)
11533
- // dx = rope_back(dy, src1)
11534
- // src0 is dy, src1 contains options
11535
-
11536
- //const int n_past = ((int32_t *) dst->op_params)[0];
11537
- const int n_dims = ((int32_t *) dst->op_params)[1];
11538
- const int mode = ((int32_t *) dst->op_params)[2];
11539
-
11540
- WSP_GGML_TENSOR_UNARY_OP_LOCALS
11541
-
11542
- //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
11543
- //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
11544
-
11545
- assert(nb0 == sizeof(wsp_ggml_fp16_t));
11546
-
11547
- const int ith = params->ith;
11548
- const int nth = params->nth;
11549
-
11550
- const int nr = wsp_ggml_nrows(dst);
11551
-
11552
- // rows per thread
11553
- const int dr = (nr + nth - 1)/nth;
11554
-
11555
- // row range for this thread
11556
- const int ir0 = dr*ith;
11557
- const int ir1 = MIN(ir0 + dr, nr);
11558
-
11559
- // row index used to determine which thread to use
11560
- int ir = 0;
11561
-
11562
- const float theta_scale = powf(10000.0, -2.0f/n_dims);
11563
-
11564
- const bool is_neox = mode & 2;
11565
-
11566
- const int32_t * pos = (const int32_t *) src1->data;
11567
-
11568
- for (int64_t i3 = 0; i3 < ne3; i3++) {
11569
- for (int64_t i2 = 0; i2 < ne2; i2++) {
11570
- const int64_t p = pos[i2];
11571
- for (int64_t i1 = 0; i1 < ne1; i1++) {
11572
- if (ir++ < ir0) continue;
11573
- if (ir > ir1) break;
11574
-
11575
- float theta_base = (float)p;
11576
-
11577
- if (!is_neox) {
11578
- for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
11579
- const float cos_theta = cosf(theta_base);
11580
- const float sin_theta = sinf(theta_base);
11581
-
11582
- theta_base *= theta_scale;
11583
-
11584
- const wsp_ggml_fp16_t * const dy = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11585
- wsp_ggml_fp16_t * dx = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11586
-
11587
- const float dy0 = WSP_GGML_FP16_TO_FP32(dy[0]);
11588
- const float dy1 = WSP_GGML_FP16_TO_FP32(dy[1]);
11589
-
11590
- dx[0] = WSP_GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
11591
- dx[1] = WSP_GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
11592
- }
11593
- } else {
11594
- for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
11595
- for (int64_t ic = 0; ic < n_dims; ic += 2) {
11596
- const float cos_theta = cosf(theta_base);
11597
- const float sin_theta = sinf(theta_base);
11598
-
11599
- theta_base *= theta_scale;
11600
-
11601
- const int64_t i0 = ib*n_dims + ic/2;
11602
-
11603
- const wsp_ggml_fp16_t * const dy = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11604
- wsp_ggml_fp16_t * dx = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11605
-
11606
- const float dy0 = WSP_GGML_FP16_TO_FP32(dy[0]);
11607
- const float dy1 = WSP_GGML_FP16_TO_FP32(dy[n_dims/2]);
11608
-
11609
- dx[0] = WSP_GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
11610
- dx[n_dims/2] = WSP_GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
11611
- }
11612
- }
11613
- }
11614
- }
11615
- }
11616
- }
11617
- }
11618
-
11619
- static void wsp_ggml_compute_forward_rope_back(
11620
- const struct wsp_ggml_compute_params * params,
11621
- const struct wsp_ggml_tensor * src0,
11622
- const struct wsp_ggml_tensor * src1,
11623
- struct wsp_ggml_tensor * dst) {
11624
- switch (src0->type) {
11625
- case WSP_GGML_TYPE_F16:
11626
- {
11627
- wsp_ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
11628
- } break;
11629
- case WSP_GGML_TYPE_F32:
11630
- {
11631
- wsp_ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
11632
- } break;
11633
- default:
11634
- {
11635
- WSP_GGML_ASSERT(false);
11636
- } break;
11637
- }
11638
- }
11639
-
11640
- // wsp_ggml_compute_forward_conv_1d
11641
-
11642
- static void wsp_ggml_compute_forward_conv_1d_f16_f32(
11643
- const struct wsp_ggml_compute_params * params,
11644
- const struct wsp_ggml_tensor * src0,
11645
- const struct wsp_ggml_tensor * src1,
11646
- struct wsp_ggml_tensor * dst) {
11647
- WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
11648
- WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
11649
- WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
11650
-
11651
- int64_t t0 = wsp_ggml_perf_time_us();
11652
- UNUSED(t0);
11653
-
11654
- WSP_GGML_TENSOR_BINARY_OP_LOCALS
11655
-
11656
- const int ith = params->ith;
11657
- const int nth = params->nth;
11658
-
11659
- const int nk = ne00;
11660
-
11661
- // size of the convolution row - the kernel size unrolled across all input channels
11662
- const int ew0 = nk*ne01;
11663
-
11664
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
11665
- const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
11666
- const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
11667
-
11668
- WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
11669
- WSP_GGML_ASSERT(nb10 == sizeof(float));
11670
-
11671
- if (params->type == WSP_GGML_TASK_INIT) {
11672
- memset(params->wdata, 0, params->wsize);
11673
-
11674
- wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0;
11675
-
11676
- for (int64_t i11 = 0; i11 < ne11; i11++) {
11677
- const float * const src = (float *)((char *) src1->data + i11*nb11);
11678
- wsp_ggml_fp16_t * dst_data = wdata;
11679
-
11680
- for (int64_t i0 = 0; i0 < ne0; i0++) {
11681
- for (int64_t ik = 0; ik < nk; ik++) {
11682
- const int idx0 = i0*s0 + ik*d0 - p0;
11683
-
11684
- if(!(idx0 < 0 || idx0 >= ne10)) {
11685
- dst_data[i0*ew0 + i11*nk + ik] = WSP_GGML_FP32_TO_FP16(src[idx0]);
11686
- }
11687
- }
11688
- }
11689
- }
11690
-
11691
- return;
11692
- }
11693
-
11694
- if (params->type == WSP_GGML_TASK_FINALIZE) {
11695
- return;
11696
- }
11697
-
11698
- // total rows in dst
11699
- const int nr = ne2;
11700
-
11701
- // rows per thread
11702
- const int dr = (nr + nth - 1)/nth;
11703
-
11704
- // row range for this thread
11705
- const int ir0 = dr*ith;
11706
- const int ir1 = MIN(ir0 + dr, nr);
11707
-
11708
- wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0;
11709
-
11710
- for (int i2 = 0; i2 < ne2; i2++) {
11711
- for (int i1 = ir0; i1 < ir1; i1++) {
11712
- float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
11713
-
11714
- for (int i0 = 0; i0 < ne0; i0++) {
11715
- wsp_ggml_vec_dot_f16(ew0, dst_data + i0,
11716
- (wsp_ggml_fp16_t *) ((char *) src0->data + i1*nb02),
11717
- (wsp_ggml_fp16_t *) wdata + i2*nb2 + i0*ew0);
11718
- }
11719
- }
11720
- }
11721
- }
11722
-
11723
- static void wsp_ggml_compute_forward_conv_1d_f32(
11724
- const struct wsp_ggml_compute_params * params,
11725
- const struct wsp_ggml_tensor * src0,
11726
- const struct wsp_ggml_tensor * src1,
11727
- struct wsp_ggml_tensor * dst) {
11728
- WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
11729
- WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
11730
- WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
11731
-
11732
- int64_t t0 = wsp_ggml_perf_time_us();
11733
- UNUSED(t0);
11734
-
11735
- WSP_GGML_TENSOR_BINARY_OP_LOCALS
11736
-
11737
- const int ith = params->ith;
11738
- const int nth = params->nth;
11739
-
11740
- const int nk = ne00;
11741
-
11742
- const int ew0 = nk*ne01;
11743
-
11744
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
11745
- const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
11746
- const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
11747
-
11748
- WSP_GGML_ASSERT(nb00 == sizeof(float));
11749
- WSP_GGML_ASSERT(nb10 == sizeof(float));
11750
-
11751
- if (params->type == WSP_GGML_TASK_INIT) {
11752
- memset(params->wdata, 0, params->wsize);
11753
-
11754
- float * const wdata = (float *) params->wdata + 0;
11755
-
11756
- for (int64_t i11 = 0; i11 < ne11; i11++) {
11757
- const float * const src = (float *)((char *) src1->data + i11*nb11);
11758
- float * dst_data = wdata;
11759
-
11760
- for (int64_t i0 = 0; i0 < ne0; i0++) {
11761
- for (int64_t ik = 0; ik < nk; ik++) {
11762
- const int idx0 = i0*s0 + ik*d0 - p0;
11763
-
11764
- if(!(idx0 < 0 || idx0 >= ne10)) {
11765
- dst_data[i0*ew0 + i11*nk + ik] = src[idx0];
11766
- }
11767
- }
11768
- }
11769
- }
11770
-
11771
- return;
11772
- }
11773
-
11774
- if (params->type == WSP_GGML_TASK_FINALIZE) {
11775
- return;
11776
- }
11777
-
11778
- // total rows in dst
11779
- const int nr = ne02;
11780
-
11781
- // rows per thread
11782
- const int dr = (nr + nth - 1)/nth;
11783
-
11784
- // row range for this thread
11785
- const int ir0 = dr*ith;
11786
- const int ir1 = MIN(ir0 + dr, nr);
11787
-
11788
- float * const wdata = (float *) params->wdata + 0;
11789
-
11790
- for (int i2 = 0; i2 < ne2; i2++) {
11791
- for (int i1 = ir0; i1 < ir1; i1++) {
11792
- float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
11793
-
11794
- for (int i0 = 0; i0 < ne0; i0++) {
11795
- wsp_ggml_vec_dot_f32(ew0, dst_data + i0,
11796
- (float *) ((char *) src0->data + i1*nb02),
11797
- (float *) wdata + i2*nb2 + i0*ew0);
11798
- }
11799
- }
11800
- }
11801
- }
11802
-
11803
- // TODO: reuse wsp_ggml_mul_mat or implement wsp_ggml_im2col and remove stage_0 and stage_1
11804
- static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k,
11805
- wsp_ggml_fp16_t * A,
11806
- wsp_ggml_fp16_t * B,
11807
- float * C,
11808
- const int ith, const int nth) {
11809
- // does not seem to make a difference
11810
- int64_t m0, m1, n0, n1;
11811
- // patches per thread
11812
- if (m > n) {
11813
- n0 = 0;
11814
- n1 = n;
11815
-
11816
- // total patches in dst
11817
- const int np = m;
11818
-
11819
- // patches per thread
11820
- const int dp = (np + nth - 1)/nth;
11821
-
11822
- // patch range for this thread
11823
- m0 = dp*ith;
11824
- m1 = MIN(m0 + dp, np);
11825
- } else {
11826
- m0 = 0;
11827
- m1 = m;
11828
-
11829
- // total patches in dst
11830
- const int np = n;
11831
-
11832
- // patches per thread
11833
- const int dp = (np + nth - 1)/nth;
11834
-
11835
- // patch range for this thread
11836
- n0 = dp*ith;
11837
- n1 = MIN(n0 + dp, np);
11838
- }
11839
-
11840
- // block-tiling attempt
11841
- int64_t blck_n = 16;
11842
- int64_t blck_m = 16;
11843
-
11844
- // int64_t CACHE_SIZE = 2 * 1024 * 1024; // 2MB
11845
- // int64_t blck_size = CACHE_SIZE / (sizeof(float) + 2 * sizeof(wsp_ggml_fp16_t) * K);
11846
- // if (blck_size > 0) {
11847
- // blck_0 = 4;
11848
- // blck_1 = blck_size / blck_0;
11849
- // if (blck_1 < 0) {
11850
- // blck_1 = 1;
11851
- // }
11852
- // // blck_0 = (int64_t)sqrt(blck_size);
11853
- // // blck_1 = blck_0;
11854
- // }
11855
- // // printf("%zd %zd %zd %zd\n", blck_size, K, blck_0, blck_1);
11856
-
11857
- for (int j = n0; j < n1; j+=blck_n) {
11858
- for (int i = m0; i < m1; i+=blck_m) {
11859
- // printf("i j k => %d %d %d\n", i, j, K);
11860
- for (int ii = i; ii < i + blck_m && ii < m1; ii++) {
11861
- for (int jj = j; jj < j + blck_n && jj < n1; jj++) {
11862
- wsp_ggml_vec_dot_f16(k,
11863
- C + ii*n + jj,
11864
- A + ii * k,
11865
- B + jj * k);
11866
- }
11867
- }
11868
- }
11869
- }
11870
- }
11871
-
11872
- // src0: kernel [OC, IC, K]
11873
- // src1: signal [N, IC, IL]
11874
- // dst: result [N, OL, IC*K]
11875
- static void wsp_ggml_compute_forward_conv_1d_stage_0_f32(
11876
- const struct wsp_ggml_compute_params * params,
11877
- const struct wsp_ggml_tensor * src0,
11878
- const struct wsp_ggml_tensor * src1,
11879
- struct wsp_ggml_tensor * dst) {
11880
- WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
11881
- WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
11882
- WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16);
11883
-
11884
- int64_t t0 = wsp_ggml_perf_time_us();
11885
- UNUSED(t0);
11886
-
11887
- WSP_GGML_TENSOR_BINARY_OP_LOCALS;
11888
-
11889
- const int64_t N = ne12;
11890
- const int64_t IC = ne11;
11891
- const int64_t IL = ne10;
11892
-
11893
- const int64_t K = ne00;
11894
-
11895
- const int64_t OL = ne1;
11896
-
11897
- const int ith = params->ith;
11898
- const int nth = params->nth;
11899
-
11900
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
11901
- const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
11902
- const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
11903
-
11904
- WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
11905
- WSP_GGML_ASSERT(nb10 == sizeof(float));
11906
-
11907
- if (params->type == WSP_GGML_TASK_INIT) {
11908
- memset(dst->data, 0, wsp_ggml_nbytes(dst));
11909
- return;
11910
- }
11911
-
11912
- if (params->type == WSP_GGML_TASK_FINALIZE) {
11913
- return;
11914
- }
11915
-
11916
- // im2col: [N, IC, IL] => [N, OL, IC*K]
11917
- {
11918
- wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) dst->data;
11919
-
11920
- for (int64_t in = 0; in < N; in++) {
11921
- for (int64_t iol = 0; iol < OL; iol++) {
11922
- for (int64_t iic = ith; iic < IC; iic+=nth) {
11923
-
11924
- // micro kernel
11925
- wsp_ggml_fp16_t * dst_data = wdata + (in*OL + iol)*(IC*K); // [IC, K]
11926
- const float * const src_data = (float *)((char *) src1->data + in*nb12 + iic*nb11); // [IL]
11927
-
11928
- for (int64_t ik = 0; ik < K; ik++) {
11929
- const int64_t iil = iol*s0 + ik*d0 - p0;
11930
-
11931
- if (!(iil < 0 || iil >= IL)) {
11932
- dst_data[iic*K + ik] = WSP_GGML_FP32_TO_FP16(src_data[iil]);
11933
- }
11934
- }
11935
- }
11936
- }
11937
- }
11938
- }
11939
- }
11940
-
11941
- // gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
11942
- // src0: [OC, IC, K]
11943
- // src1: [N, OL, IC * K]
11944
- // result: [N, OC, OL]
11945
- static void wsp_ggml_compute_forward_conv_1d_stage_1_f16(
11946
- const struct wsp_ggml_compute_params * params,
11947
- const struct wsp_ggml_tensor * src0,
11948
- const struct wsp_ggml_tensor * src1,
11949
- struct wsp_ggml_tensor * dst) {
11950
- WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
11951
- WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F16);
11952
- WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
11953
-
11954
- int64_t t0 = wsp_ggml_perf_time_us();
11955
- UNUSED(t0);
11956
-
11957
- if (params->type == WSP_GGML_TASK_INIT) {
11958
- return;
11959
- }
11960
-
11961
- if (params->type == WSP_GGML_TASK_FINALIZE) {
11962
- return;
11963
- }
11964
-
11965
- WSP_GGML_TENSOR_BINARY_OP_LOCALS;
11966
-
11967
- WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
11968
- WSP_GGML_ASSERT(nb10 == sizeof(wsp_ggml_fp16_t));
11969
- WSP_GGML_ASSERT(nb0 == sizeof(float));
11970
-
11971
- const int N = ne12;
11972
- const int OL = ne11;
11973
-
11974
- const int OC = ne02;
11975
- const int IC = ne01;
11976
- const int K = ne00;
11977
-
11978
- const int ith = params->ith;
11979
- const int nth = params->nth;
11980
-
11981
- int64_t m = OC;
11982
- int64_t n = OL;
11983
- int64_t k = IC * K;
11984
-
11985
- // [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
11986
- for (int i = 0; i < N; i++) {
11987
- wsp_ggml_fp16_t * A = (wsp_ggml_fp16_t *)src0->data; // [m, k]
11988
- wsp_ggml_fp16_t * B = (wsp_ggml_fp16_t *)src1->data + i * m * k; // [n, k]
11989
- float * C = (float *)dst->data + i * m * n; // [m, n]
11990
-
11991
- gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
11992
- }
11993
- }
11994
-
11995
- static void wsp_ggml_compute_forward_conv_1d(
11996
- const struct wsp_ggml_compute_params * params,
11997
- const struct wsp_ggml_tensor * src0,
11998
- const struct wsp_ggml_tensor * src1,
11999
- struct wsp_ggml_tensor * dst) {
12000
- switch(src0->type) {
12001
- case WSP_GGML_TYPE_F16:
12002
- {
12003
- wsp_ggml_compute_forward_conv_1d_f16_f32(params, src0, src1, dst);
12004
- } break;
12005
- case WSP_GGML_TYPE_F32:
12006
- {
12007
- wsp_ggml_compute_forward_conv_1d_f32(params, src0, src1, dst);
12008
- } break;
12009
- default:
12010
- {
12011
- WSP_GGML_ASSERT(false);
12012
- } break;
12013
- }
12014
- }
12015
-
12016
- static void wsp_ggml_compute_forward_conv_1d_stage_0(
12017
- const struct wsp_ggml_compute_params * params,
12018
- const struct wsp_ggml_tensor * src0,
12019
- const struct wsp_ggml_tensor * src1,
12020
- struct wsp_ggml_tensor * dst) {
12021
- switch(src0->type) {
12022
- case WSP_GGML_TYPE_F16:
12023
- {
12024
- wsp_ggml_compute_forward_conv_1d_stage_0_f32(params, src0, src1, dst);
12025
- } break;
12026
- default:
12027
- {
12028
- WSP_GGML_ASSERT(false);
12029
- } break;
12030
- }
12031
- }
12032
-
12033
- static void wsp_ggml_compute_forward_conv_1d_stage_1(
12034
- const struct wsp_ggml_compute_params * params,
12035
- const struct wsp_ggml_tensor * src0,
12036
- const struct wsp_ggml_tensor * src1,
12037
- struct wsp_ggml_tensor * dst) {
12038
- switch(src0->type) {
12039
- case WSP_GGML_TYPE_F16:
12040
- {
12041
- wsp_ggml_compute_forward_conv_1d_stage_1_f16(params, src0, src1, dst);
12042
- } break;
12043
- default:
12044
- {
12045
- WSP_GGML_ASSERT(false);
12046
- } break;
12047
- }
12048
- }
12049
-
12050
- // wsp_ggml_compute_forward_conv_transpose_1d
12051
-
12052
- static void wsp_ggml_compute_forward_conv_transpose_1d_f16_f32(
12053
- const struct wsp_ggml_compute_params * params,
12054
- const struct wsp_ggml_tensor * src0,
12055
- const struct wsp_ggml_tensor * src1,
12056
- struct wsp_ggml_tensor * dst) {
12057
- WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
12058
- WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
12059
- WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
12060
-
12061
- int64_t t0 = wsp_ggml_perf_time_us();
12062
- UNUSED(t0);
12063
-
12064
- WSP_GGML_TENSOR_BINARY_OP_LOCALS
12065
-
12066
- const int ith = params->ith;
12067
- const int nth = params->nth;
12068
-
12069
- const int nk = ne00*ne01*ne02;
12070
-
12071
- WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
12072
- WSP_GGML_ASSERT(nb10 == sizeof(float));
12073
-
12074
- if (params->type == WSP_GGML_TASK_INIT) {
12075
- memset(params->wdata, 0, params->wsize);
12076
-
12077
- // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
12078
- {
12079
- wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0;
12080
-
12081
- for (int64_t i02 = 0; i02 < ne02; i02++) {
12082
- for (int64_t i01 = 0; i01 < ne01; i01++) {
12083
- const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
12084
- wsp_ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
12085
- for (int64_t i00 = 0; i00 < ne00; i00++) {
12086
- dst_data[i00*ne02 + i02] = src[i00];
12087
- }
12088
- }
12089
- }
12090
- }
12091
-
12092
- // permute source data (src1) from (L x Cin) to (Cin x L)
12093
- {
12094
- wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + nk;
12095
- wsp_ggml_fp16_t * dst_data = wdata;
12096
-
12097
- for (int64_t i11 = 0; i11 < ne11; i11++) {
12098
- const float * const src = (float *)((char *) src1->data + i11*nb11);
12099
- for (int64_t i10 = 0; i10 < ne10; i10++) {
12100
- dst_data[i10*ne11 + i11] = WSP_GGML_FP32_TO_FP16(src[i10]);
12101
- }
12102
- }
12103
- }
12104
-
12105
- // need to zero dst since we are accumulating into it
12106
- memset(dst->data, 0, wsp_ggml_nbytes(dst));
12107
-
12108
- return;
12109
- }
12110
-
12111
- if (params->type == WSP_GGML_TASK_FINALIZE) {
12112
- return;
12113
- }
12114
-
12115
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
12116
-
12117
- // total rows in dst
12118
- const int nr = ne1;
12119
-
12120
- // rows per thread
12121
- const int dr = (nr + nth - 1)/nth;
12122
-
12123
- // row range for this thread
12124
- const int ir0 = dr*ith;
12125
- const int ir1 = MIN(ir0 + dr, nr);
12126
-
12127
- wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0;
12128
- wsp_ggml_fp16_t * const wdata_src = wdata + nk;
11706
+ wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0;
11707
+ wsp_ggml_fp16_t * const wdata_src = wdata + nk;
12129
11708
 
12130
11709
  for (int i1 = ir0; i1 < ir1; i1++) {
12131
11710
  float * dst_data = (float *)((char *) dst->data + i1*nb1);
@@ -12258,12 +11837,10 @@ static void wsp_ggml_compute_forward_conv_transpose_1d(
12258
11837
  }
12259
11838
  }
12260
11839
 
12261
- // wsp_ggml_compute_forward_conv_2d
12262
-
12263
11840
  // src0: kernel [OC, IC, KH, KW]
12264
11841
  // src1: image [N, IC, IH, IW]
12265
11842
  // dst: result [N, OH, OW, IC*KH*KW]
12266
- static void wsp_ggml_compute_forward_conv_2d_stage_0_f32(
11843
+ static void wsp_ggml_compute_forward_im2col_f16(
12267
11844
  const struct wsp_ggml_compute_params * params,
12268
11845
  const struct wsp_ggml_tensor * src0,
12269
11846
  const struct wsp_ggml_tensor * src1,
@@ -12277,218 +11854,35 @@ static void wsp_ggml_compute_forward_conv_2d_stage_0_f32(
12277
11854
 
12278
11855
  WSP_GGML_TENSOR_BINARY_OP_LOCALS;
12279
11856
 
12280
- const int64_t N = ne13;
12281
- const int64_t IC = ne12;
12282
- const int64_t IH = ne11;
12283
- const int64_t IW = ne10;
12284
-
12285
- // const int64_t OC = ne03;
12286
- // const int64_t IC = ne02;
12287
- const int64_t KH = ne01;
12288
- const int64_t KW = ne00;
12289
-
12290
- const int64_t OH = ne2;
12291
- const int64_t OW = ne1;
12292
-
12293
- const int ith = params->ith;
12294
- const int nth = params->nth;
12295
-
12296
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
12297
- const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
12298
- const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
12299
- const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
12300
- const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
12301
- const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
12302
-
12303
- WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
12304
- WSP_GGML_ASSERT(nb10 == sizeof(float));
12305
-
12306
- if (params->type == WSP_GGML_TASK_INIT) {
12307
- memset(dst->data, 0, wsp_ggml_nbytes(dst));
12308
- return;
12309
- }
12310
-
12311
- if (params->type == WSP_GGML_TASK_FINALIZE) {
12312
- return;
12313
- }
12314
-
12315
- // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
12316
- {
12317
- wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) dst->data;
12318
-
12319
- for (int64_t in = 0; in < N; in++) {
12320
- for (int64_t ioh = 0; ioh < OH; ioh++) {
12321
- for (int64_t iow = 0; iow < OW; iow++) {
12322
- for (int64_t iic = ith; iic < IC; iic+=nth) {
12323
-
12324
- // micro kernel
12325
- wsp_ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
12326
- const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW]
12327
-
12328
- for (int64_t ikh = 0; ikh < KH; ikh++) {
12329
- for (int64_t ikw = 0; ikw < KW; ikw++) {
12330
- const int64_t iiw = iow*s0 + ikw*d0 - p0;
12331
- const int64_t iih = ioh*s1 + ikh*d1 - p1;
12332
-
12333
- if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
12334
- dst_data[iic*(KH*KW) + ikh*KW + ikw] = WSP_GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
12335
- }
12336
- }
12337
- }
12338
- }
12339
- }
12340
- }
12341
- }
12342
- }
12343
- }
12344
-
12345
- // gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
12346
- // src0: [OC, IC, KH, KW]
12347
- // src1: [N, OH, OW, IC * KH * KW]
12348
- // result: [N, OC, OH, OW]
12349
- static void wsp_ggml_compute_forward_conv_2d_stage_1_f16(
12350
- const struct wsp_ggml_compute_params * params,
12351
- const struct wsp_ggml_tensor * src0,
12352
- const struct wsp_ggml_tensor * src1,
12353
- struct wsp_ggml_tensor * dst) {
12354
- WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
12355
- WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F16);
12356
- WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
12357
-
12358
- int64_t t0 = wsp_ggml_perf_time_us();
12359
- UNUSED(t0);
12360
-
12361
- if (params->type == WSP_GGML_TASK_INIT) {
12362
- return;
12363
- }
12364
-
12365
- if (params->type == WSP_GGML_TASK_FINALIZE) {
12366
- return;
12367
- }
12368
-
12369
- WSP_GGML_TENSOR_BINARY_OP_LOCALS;
12370
-
12371
- WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
12372
- WSP_GGML_ASSERT(nb10 == sizeof(wsp_ggml_fp16_t));
12373
- WSP_GGML_ASSERT(nb0 == sizeof(float));
12374
-
12375
- const int N = ne13;
12376
- const int OH = ne12;
12377
- const int OW = ne11;
12378
-
12379
- const int OC = ne03;
12380
- const int IC = ne02;
12381
- const int KH = ne01;
12382
- const int KW = ne00;
11857
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
11858
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
11859
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
11860
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
11861
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
11862
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
11863
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
12383
11864
 
12384
11865
  const int ith = params->ith;
12385
11866
  const int nth = params->nth;
12386
11867
 
12387
- int64_t m = OC;
12388
- int64_t n = OH * OW;
12389
- int64_t k = IC * KH * KW;
12390
-
12391
- // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
12392
- for (int i = 0; i < N; i++) {
12393
- wsp_ggml_fp16_t * A = (wsp_ggml_fp16_t *)src0->data; // [m, k]
12394
- wsp_ggml_fp16_t * B = (wsp_ggml_fp16_t *)src1->data + i * m * k; // [n, k]
12395
- float * C = (float *)dst->data + i * m * n; // [m, n]
12396
-
12397
- gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
12398
- }
12399
- }
12400
-
12401
- static void wsp_ggml_compute_forward_conv_2d_f16_f32(
12402
- const struct wsp_ggml_compute_params * params,
12403
- const struct wsp_ggml_tensor * src0,
12404
- const struct wsp_ggml_tensor * src1,
12405
- struct wsp_ggml_tensor * dst) {
12406
- WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
12407
- WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
12408
- WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
12409
-
12410
- int64_t t0 = wsp_ggml_perf_time_us();
12411
- UNUSED(t0);
12412
-
12413
- WSP_GGML_TENSOR_BINARY_OP_LOCALS
12414
-
12415
- // src1: image [N, IC, IH, IW]
12416
- // src0: kernel [OC, IC, KH, KW]
12417
- // dst: result [N, OC, OH, OW]
12418
- // ne12: IC
12419
- // ne0: OW
12420
- // ne1: OH
12421
- // nk0: KW
12422
- // nk1: KH
12423
- // ne13: N
12424
-
12425
- const int N = ne13;
12426
- const int IC = ne12;
12427
- const int IH = ne11;
12428
- const int IW = ne10;
12429
-
12430
- const int OC = ne03;
12431
- // const int IC = ne02;
12432
- const int KH = ne01;
12433
- const int KW = ne00;
12434
-
12435
- const int OH = ne1;
12436
- const int OW = ne0;
12437
-
12438
- const int ith = params->ith;
12439
- const int nth = params->nth;
11868
+ const int64_t N = is_2D ? ne13 : ne12;
11869
+ const int64_t IC = is_2D ? ne12 : ne11;
11870
+ const int64_t IH = is_2D ? ne11 : 1;
11871
+ const int64_t IW = ne10;
12440
11872
 
12441
- // const int nk0 = ne00;
12442
- // const int nk1 = ne01;
11873
+ const int64_t KH = is_2D ? ne01 : 1;
11874
+ const int64_t KW = ne00;
12443
11875
 
12444
- // size of the convolution row - the kernel size unrolled across all channels
12445
- // const int ew0 = nk0*nk1*ne02;
12446
- // ew0: IC*KH*KW
11876
+ const int64_t OH = is_2D ? ne2 : 1;
11877
+ const int64_t OW = ne1;
12447
11878
 
12448
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
12449
- const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
12450
- const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
12451
- const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
12452
- const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
12453
- const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
11879
+ int ofs0 = is_2D ? nb13 : nb12;
11880
+ int ofs1 = is_2D ? nb12 : nb11;
12454
11881
 
12455
11882
  WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t));
12456
11883
  WSP_GGML_ASSERT(nb10 == sizeof(float));
12457
11884
 
12458
11885
  if (params->type == WSP_GGML_TASK_INIT) {
12459
- memset(params->wdata, 0, params->wsize);
12460
-
12461
- // prepare source data (src1)
12462
- // im2col: [N, IC, IH, IW] => [N*OH*OW, IC*KH*KW]
12463
-
12464
- {
12465
- wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0;
12466
-
12467
- for (int in = 0; in < N; in++) {
12468
- for (int iic = 0; iic < IC; iic++) {
12469
- for (int ioh = 0; ioh < OH; ioh++) {
12470
- for (int iow = 0; iow < OW; iow++) {
12471
-
12472
- // micro kernel
12473
- wsp_ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
12474
- const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW]
12475
-
12476
- for (int ikh = 0; ikh < KH; ikh++) {
12477
- for (int ikw = 0; ikw < KW; ikw++) {
12478
- const int iiw = iow*s0 + ikw*d0 - p0;
12479
- const int iih = ioh*s1 + ikh*d1 - p1;
12480
-
12481
- if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
12482
- dst_data[iic*(KH*KW) + ikh*KW + ikw] = WSP_GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
12483
- }
12484
- }
12485
- }
12486
- }
12487
- }
12488
- }
12489
- }
12490
- }
12491
-
12492
11886
  return;
12493
11887
  }
12494
11888
 
@@ -12496,69 +11890,39 @@ static void wsp_ggml_compute_forward_conv_2d_f16_f32(
12496
11890
  return;
12497
11891
  }
12498
11892
 
12499
- wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0;
12500
- // wdata: [N*OH*OW, IC*KH*KW]
12501
- // dst: result [N, OC, OH, OW]
12502
- // src0: kernel [OC, IC, KH, KW]
12503
-
12504
- int64_t m = OC;
12505
- int64_t n = OH * OW;
12506
- int64_t k = IC * KH * KW;
12507
-
12508
- // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
12509
- for (int i = 0; i < N; i++) {
12510
- wsp_ggml_fp16_t * A = (wsp_ggml_fp16_t *)src0->data; // [m, k]
12511
- wsp_ggml_fp16_t * B = (wsp_ggml_fp16_t *)wdata + i * m * k; // [n, k]
12512
- float * C = (float *)dst->data + i * m * n; // [m * k]
12513
-
12514
- gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
12515
- }
12516
- }
12517
-
12518
- static void wsp_ggml_compute_forward_conv_2d(
12519
- const struct wsp_ggml_compute_params * params,
12520
- const struct wsp_ggml_tensor * src0,
12521
- const struct wsp_ggml_tensor * src1,
12522
- struct wsp_ggml_tensor * dst) {
12523
- switch (src0->type) {
12524
- case WSP_GGML_TYPE_F16:
12525
- {
12526
- wsp_ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, dst);
12527
- } break;
12528
- case WSP_GGML_TYPE_F32:
12529
- {
12530
- //wsp_ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
12531
- WSP_GGML_ASSERT(false);
12532
- } break;
12533
- default:
12534
- {
12535
- WSP_GGML_ASSERT(false);
12536
- } break;
12537
- }
12538
- }
12539
-
12540
- static void wsp_ggml_compute_forward_conv_2d_stage_0(
12541
- const struct wsp_ggml_compute_params * params,
12542
- const struct wsp_ggml_tensor * src0,
12543
- const struct wsp_ggml_tensor * src1,
12544
- struct wsp_ggml_tensor * dst) {
12545
- switch (src0->type) {
12546
- case WSP_GGML_TYPE_F16:
12547
- {
12548
- wsp_ggml_compute_forward_conv_2d_stage_0_f32(params, src0, src1, dst);
12549
- } break;
12550
- case WSP_GGML_TYPE_F32:
12551
- {
12552
- WSP_GGML_ASSERT(false);
12553
- } break;
12554
- default:
12555
- {
12556
- WSP_GGML_ASSERT(false);
12557
- } break;
11893
+ // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
11894
+ {
11895
+ wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) dst->data;
11896
+
11897
+ for (int64_t in = 0; in < N; in++) {
11898
+ for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
11899
+ for (int64_t iow = 0; iow < OW; iow++) {
11900
+ for (int64_t iic = ith; iic < IC; iic += nth) {
11901
+
11902
+ // micro kernel
11903
+ wsp_ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
11904
+ const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
11905
+
11906
+ for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
11907
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
11908
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
11909
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
11910
+
11911
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
11912
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
11913
+ } else {
11914
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = WSP_GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
11915
+ }
11916
+ }
11917
+ }
11918
+ }
11919
+ }
11920
+ }
11921
+ }
12558
11922
  }
12559
11923
  }
12560
11924
 
12561
- static void wsp_ggml_compute_forward_conv_2d_stage_1(
11925
+ static void wsp_ggml_compute_forward_im2col(
12562
11926
  const struct wsp_ggml_compute_params * params,
12563
11927
  const struct wsp_ggml_tensor * src0,
12564
11928
  const struct wsp_ggml_tensor * src1,
@@ -12566,7 +11930,7 @@ static void wsp_ggml_compute_forward_conv_2d_stage_1(
12566
11930
  switch (src0->type) {
12567
11931
  case WSP_GGML_TYPE_F16:
12568
11932
  {
12569
- wsp_ggml_compute_forward_conv_2d_stage_1_f16(params, src0, src1, dst);
11933
+ wsp_ggml_compute_forward_im2col_f16(params, src0, src1, dst);
12570
11934
  } break;
12571
11935
  case WSP_GGML_TYPE_F32:
12572
11936
  {
@@ -12839,6 +12203,7 @@ static void wsp_ggml_compute_forward_upscale_f32(
12839
12203
  WSP_GGML_ASSERT(src0->nb[0] == sizeof(float));
12840
12204
 
12841
12205
  const int ith = params->ith;
12206
+ const int nth = params->nth;
12842
12207
 
12843
12208
  WSP_GGML_TENSOR_UNARY_OP_LOCALS
12844
12209
 
@@ -12846,16 +12211,17 @@ static void wsp_ggml_compute_forward_upscale_f32(
12846
12211
 
12847
12212
  // TODO: optimize
12848
12213
 
12849
- for (int i03 = 0; i03 < ne03; i03++) {
12850
- for (int i02 = ith; i02 < ne02; i02++) {
12851
- for (int m = 0; m < dst->ne[1]; m++) {
12852
- int i01 = m / scale_factor;
12853
- for (int n = 0; n < dst->ne[0]; n++) {
12854
- int i00 = n / scale_factor;
12855
-
12856
- const float * x = (float *)((char *) src0->data + i00 * nb00 +i01 * nb01 + i02 * nb02 + i03 * nb03);
12214
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
12215
+ const int64_t i03 = i3;
12216
+ for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
12217
+ const int64_t i02 = i2;
12218
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
12219
+ const int64_t i01 = i1 / scale_factor;
12220
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
12221
+ const int64_t i00 = i0 / scale_factor;
12857
12222
 
12858
- float * y = (float *)((char *) dst->data + n * dst->nb[0] + m * dst->nb[1] + i02 * dst->nb[2] + i03 * dst->nb[3]);
12223
+ const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
12224
+ float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
12859
12225
 
12860
12226
  *y = *x;
12861
12227
  }
@@ -12880,6 +12246,125 @@ static void wsp_ggml_compute_forward_upscale(
12880
12246
  }
12881
12247
  }
12882
12248
 
12249
+ // wsp_ggml_compute_forward_pad
12250
+
12251
+ static void wsp_ggml_compute_forward_pad_f32(
12252
+ const struct wsp_ggml_compute_params * params,
12253
+ const struct wsp_ggml_tensor * src0,
12254
+ struct wsp_ggml_tensor * dst) {
12255
+
12256
+ if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
12257
+ return;
12258
+ }
12259
+
12260
+ WSP_GGML_ASSERT(src0->nb[0] == sizeof(float));
12261
+ WSP_GGML_ASSERT( dst->nb[0] == sizeof(float));
12262
+
12263
+ const int ith = params->ith;
12264
+ const int nth = params->nth;
12265
+
12266
+ WSP_GGML_TENSOR_UNARY_OP_LOCALS
12267
+
12268
+ float * dst_ptr = (float *) dst->data;
12269
+
12270
+ // TODO: optimize
12271
+
12272
+ for (int64_t i2 = 0; i2 < ne2; ++i2) {
12273
+ for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
12274
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
12275
+ for (int64_t i3 = 0; i3 < ne3; ++i3) {
12276
+ const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
12277
+
12278
+ const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
12279
+
12280
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
12281
+ dst_ptr[dst_idx] = *src_ptr;
12282
+ } else {
12283
+ dst_ptr[dst_idx] = 0;
12284
+ }
12285
+ }
12286
+ }
12287
+ }
12288
+ }
12289
+ }
12290
+
12291
+ static void wsp_ggml_compute_forward_pad(
12292
+ const struct wsp_ggml_compute_params * params,
12293
+ const struct wsp_ggml_tensor * src0,
12294
+ struct wsp_ggml_tensor * dst) {
12295
+ switch (src0->type) {
12296
+ case WSP_GGML_TYPE_F32:
12297
+ {
12298
+ wsp_ggml_compute_forward_pad_f32(params, src0, dst);
12299
+ } break;
12300
+ default:
12301
+ {
12302
+ WSP_GGML_ASSERT(false);
12303
+ } break;
12304
+ }
12305
+ }
12306
+
12307
+ // wsp_ggml_compute_forward_argsort
12308
+
12309
+ static void wsp_ggml_compute_forward_argsort_f32(
12310
+ const struct wsp_ggml_compute_params * params,
12311
+ const struct wsp_ggml_tensor * src0,
12312
+ struct wsp_ggml_tensor * dst) {
12313
+
12314
+ if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
12315
+ return;
12316
+ }
12317
+
12318
+ WSP_GGML_TENSOR_UNARY_OP_LOCALS
12319
+
12320
+ WSP_GGML_ASSERT(nb0 == sizeof(float));
12321
+
12322
+ const int ith = params->ith;
12323
+ const int nth = params->nth;
12324
+
12325
+ const int64_t nr = wsp_ggml_nrows(src0);
12326
+
12327
+ enum wsp_ggml_sort_order order = (enum wsp_ggml_sort_order) wsp_ggml_get_op_params_i32(dst, 0);
12328
+
12329
+ for (int64_t i = ith; i < nr; i += nth) {
12330
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
12331
+ const float * src_data = (float *)((char *) src0->data + i*nb01);
12332
+
12333
+ for (int64_t j = 0; j < ne0; j++) {
12334
+ dst_data[j] = j;
12335
+ }
12336
+
12337
+ // C doesn't have a functional sort, so we do a bubble sort instead
12338
+ for (int64_t j = 0; j < ne0; j++) {
12339
+ for (int64_t k = j + 1; k < ne0; k++) {
12340
+ if ((order == WSP_GGML_SORT_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
12341
+ (order == WSP_GGML_SORT_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
12342
+ int32_t tmp = dst_data[j];
12343
+ dst_data[j] = dst_data[k];
12344
+ dst_data[k] = tmp;
12345
+ }
12346
+ }
12347
+ }
12348
+ }
12349
+ }
12350
+
12351
+ static void wsp_ggml_compute_forward_argsort(
12352
+ const struct wsp_ggml_compute_params * params,
12353
+ const struct wsp_ggml_tensor * src0,
12354
+ struct wsp_ggml_tensor * dst) {
12355
+
12356
+ switch (src0->type) {
12357
+ case WSP_GGML_TYPE_F32:
12358
+ {
12359
+ wsp_ggml_compute_forward_argsort_f32(params, src0, dst);
12360
+ } break;
12361
+ default:
12362
+ {
12363
+ WSP_GGML_ASSERT(false);
12364
+ } break;
12365
+ }
12366
+ }
12367
+
12883
12368
  // wsp_ggml_compute_forward_flash_attn
12884
12369
 
12885
12370
  static void wsp_ggml_compute_forward_flash_attn_f32(
@@ -14026,10 +13511,6 @@ static void wsp_ggml_compute_forward_unary(
14026
13511
  {
14027
13512
  wsp_ggml_compute_forward_silu(params, src0, dst);
14028
13513
  } break;
14029
- case WSP_GGML_UNARY_OP_LEAKY:
14030
- {
14031
- wsp_ggml_compute_forward_leaky(params, src0, dst);
14032
- } break;
14033
13514
  default:
14034
13515
  {
14035
13516
  WSP_GGML_ASSERT(false);
@@ -14701,7 +14182,11 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
14701
14182
  } break;
14702
14183
  case WSP_GGML_OP_MUL_MAT:
14703
14184
  {
14704
- wsp_ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
14185
+ wsp_ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor, 0, tensor->ne[1]);
14186
+ } break;
14187
+ case WSP_GGML_OP_MUL_MAT_ID:
14188
+ {
14189
+ wsp_ggml_compute_forward_mul_mat_id(params, tensor->src[0], tensor->src[1], tensor);
14705
14190
  } break;
14706
14191
  case WSP_GGML_OP_OUT_PROD:
14707
14192
  {
@@ -14761,7 +14246,7 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
14761
14246
  } break;
14762
14247
  case WSP_GGML_OP_SOFT_MAX:
14763
14248
  {
14764
- wsp_ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
14249
+ wsp_ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
14765
14250
  } break;
14766
14251
  case WSP_GGML_OP_SOFT_MAX_BACK:
14767
14252
  {
@@ -14783,33 +14268,13 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
14783
14268
  {
14784
14269
  wsp_ggml_compute_forward_clamp(params, tensor->src[0], tensor);
14785
14270
  } break;
14786
- case WSP_GGML_OP_CONV_1D:
14787
- {
14788
- wsp_ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor);
14789
- } break;
14790
- case WSP_GGML_OP_CONV_1D_STAGE_0:
14791
- {
14792
- wsp_ggml_compute_forward_conv_1d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
14793
- } break;
14794
- case WSP_GGML_OP_CONV_1D_STAGE_1:
14795
- {
14796
- wsp_ggml_compute_forward_conv_1d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
14797
- } break;
14798
14271
  case WSP_GGML_OP_CONV_TRANSPOSE_1D:
14799
14272
  {
14800
14273
  wsp_ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor);
14801
14274
  } break;
14802
- case WSP_GGML_OP_CONV_2D:
14803
- {
14804
- wsp_ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor);
14805
- } break;
14806
- case WSP_GGML_OP_CONV_2D_STAGE_0:
14807
- {
14808
- wsp_ggml_compute_forward_conv_2d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
14809
- } break;
14810
- case WSP_GGML_OP_CONV_2D_STAGE_1:
14275
+ case WSP_GGML_OP_IM2COL:
14811
14276
  {
14812
- wsp_ggml_compute_forward_conv_2d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
14277
+ wsp_ggml_compute_forward_im2col(params, tensor->src[0], tensor->src[1], tensor);
14813
14278
  } break;
14814
14279
  case WSP_GGML_OP_CONV_TRANSPOSE_2D:
14815
14280
  {
@@ -14827,6 +14292,18 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
14827
14292
  {
14828
14293
  wsp_ggml_compute_forward_upscale(params, tensor->src[0], tensor);
14829
14294
  } break;
14295
+ case WSP_GGML_OP_PAD:
14296
+ {
14297
+ wsp_ggml_compute_forward_pad(params, tensor->src[0], tensor);
14298
+ } break;
14299
+ case WSP_GGML_OP_ARGSORT:
14300
+ {
14301
+ wsp_ggml_compute_forward_argsort(params, tensor->src[0], tensor);
14302
+ } break;
14303
+ case WSP_GGML_OP_LEAKY_RELU:
14304
+ {
14305
+ wsp_ggml_compute_forward_leaky_relu(params, tensor->src[0], tensor);
14306
+ } break;
14830
14307
  case WSP_GGML_OP_FLASH_ATTN:
14831
14308
  {
14832
14309
  const int32_t t = wsp_ggml_get_op_params_i32(tensor, 0);
@@ -15151,7 +14628,7 @@ void wsp_ggml_build_backward_gradient_checkpointing(
15151
14628
  // insert new tensors recomputing src, reusing already made replacements,
15152
14629
  // remember replacements: remember new tensors with mapping from corresponding gf nodes
15153
14630
  // recurse for input tensors,
15154
- // unless (i.e. terminating when) input tensors are replacments (like checkpoints)
14631
+ // unless (i.e. terminating when) input tensors are replacements (like checkpoints)
15155
14632
  node->src[k] = wsp_ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
15156
14633
  }
15157
14634
  // insert rewritten backward node with replacements made into resulting backward graph gb
@@ -15477,6 +14954,10 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
15477
14954
  zero_table);
15478
14955
  }
15479
14956
  } break;
14957
+ case WSP_GGML_OP_MUL_MAT_ID:
14958
+ {
14959
+ WSP_GGML_ASSERT(false); // TODO: not implemented
14960
+ } break;
15480
14961
  case WSP_GGML_OP_OUT_PROD:
15481
14962
  {
15482
14963
  WSP_GGML_ASSERT(false); // TODO: not implemented
@@ -15708,17 +15189,20 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
15708
15189
  // necessary for llama
15709
15190
  if (src0->grad) {
15710
15191
  //const int n_past = ((int32_t *) tensor->op_params)[0];
15711
- const int n_dims = ((int32_t *) tensor->op_params)[1];
15712
- const int mode = ((int32_t *) tensor->op_params)[2];
15713
- const int n_ctx = ((int32_t *) tensor->op_params)[3];
15714
- float freq_base;
15715
- float freq_scale;
15716
- float xpos_base;
15717
- bool xpos_down;
15718
- memcpy(&freq_base, (int32_t *) tensor->op_params + 4, sizeof(float));
15719
- memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
15720
- memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
15721
- memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
15192
+ const int n_dims = ((int32_t *) tensor->op_params)[1];
15193
+ const int mode = ((int32_t *) tensor->op_params)[2];
15194
+ const int n_ctx = ((int32_t *) tensor->op_params)[3];
15195
+ const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
15196
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
15197
+
15198
+ memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
15199
+ memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
15200
+ memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float));
15201
+ memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
15202
+ memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
15203
+ memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
15204
+ memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
15205
+ memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
15722
15206
 
15723
15207
  src0->grad = wsp_ggml_add_or_set(ctx,
15724
15208
  src0->grad,
@@ -15728,8 +15212,13 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
15728
15212
  n_dims,
15729
15213
  mode,
15730
15214
  n_ctx,
15215
+ n_orig_ctx,
15731
15216
  freq_base,
15732
15217
  freq_scale,
15218
+ ext_factor,
15219
+ attn_factor,
15220
+ beta_fast,
15221
+ beta_slow,
15733
15222
  xpos_base,
15734
15223
  xpos_down),
15735
15224
  zero_table);
@@ -15739,17 +15228,20 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
15739
15228
  {
15740
15229
  if (src0->grad) {
15741
15230
  //const int n_past = ((int32_t *) tensor->op_params)[0];
15742
- const int n_dims = ((int32_t *) tensor->op_params)[1];
15743
- const int mode = ((int32_t *) tensor->op_params)[2];
15744
- const int n_ctx = ((int32_t *) tensor->op_params)[3];
15745
- float freq_base;
15746
- float freq_scale;
15747
- float xpos_base;
15748
- bool xpos_down;
15749
- memcpy(&freq_base, (int32_t *) tensor->op_params + 4, sizeof(float));
15750
- memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
15751
- memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
15752
- memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
15231
+ const int n_dims = ((int32_t *) tensor->op_params)[1];
15232
+ const int mode = ((int32_t *) tensor->op_params)[2];
15233
+ const int n_ctx = ((int32_t *) tensor->op_params)[3];
15234
+ const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
15235
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
15236
+
15237
+ memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
15238
+ memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
15239
+ memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float));
15240
+ memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
15241
+ memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
15242
+ memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
15243
+ memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
15244
+ memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
15753
15245
 
15754
15246
  src0->grad = wsp_ggml_add_or_set(ctx,
15755
15247
  src0->grad,
@@ -15758,14 +15250,14 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
15758
15250
  src1,
15759
15251
  n_dims,
15760
15252
  mode,
15761
- 0,
15762
15253
  n_ctx,
15254
+ n_orig_ctx,
15763
15255
  freq_base,
15764
15256
  freq_scale,
15765
- 0.0f,
15766
- 1.0f,
15767
- 0.0f,
15768
- 0.0f,
15257
+ ext_factor,
15258
+ attn_factor,
15259
+ beta_fast,
15260
+ beta_slow,
15769
15261
  xpos_base,
15770
15262
  xpos_down,
15771
15263
  false),
@@ -15780,47 +15272,39 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
15780
15272
  {
15781
15273
  WSP_GGML_ASSERT(false); // TODO: not implemented
15782
15274
  } break;
15783
- case WSP_GGML_OP_CONV_1D:
15784
- {
15785
- WSP_GGML_ASSERT(false); // TODO: not implemented
15786
- } break;
15787
- case WSP_GGML_OP_CONV_1D_STAGE_0:
15788
- {
15789
- WSP_GGML_ASSERT(false); // TODO: not implemented
15790
- } break;
15791
- case WSP_GGML_OP_CONV_1D_STAGE_1:
15275
+ case WSP_GGML_OP_CONV_TRANSPOSE_1D:
15792
15276
  {
15793
15277
  WSP_GGML_ASSERT(false); // TODO: not implemented
15794
15278
  } break;
15795
- case WSP_GGML_OP_CONV_TRANSPOSE_1D:
15279
+ case WSP_GGML_OP_IM2COL:
15796
15280
  {
15797
15281
  WSP_GGML_ASSERT(false); // TODO: not implemented
15798
15282
  } break;
15799
- case WSP_GGML_OP_CONV_2D:
15283
+ case WSP_GGML_OP_CONV_TRANSPOSE_2D:
15800
15284
  {
15801
15285
  WSP_GGML_ASSERT(false); // TODO: not implemented
15802
15286
  } break;
15803
- case WSP_GGML_OP_CONV_2D_STAGE_0:
15287
+ case WSP_GGML_OP_POOL_1D:
15804
15288
  {
15805
15289
  WSP_GGML_ASSERT(false); // TODO: not implemented
15806
15290
  } break;
15807
- case WSP_GGML_OP_CONV_2D_STAGE_1:
15291
+ case WSP_GGML_OP_POOL_2D:
15808
15292
  {
15809
15293
  WSP_GGML_ASSERT(false); // TODO: not implemented
15810
15294
  } break;
15811
- case WSP_GGML_OP_CONV_TRANSPOSE_2D:
15295
+ case WSP_GGML_OP_UPSCALE:
15812
15296
  {
15813
15297
  WSP_GGML_ASSERT(false); // TODO: not implemented
15814
15298
  } break;
15815
- case WSP_GGML_OP_POOL_1D:
15299
+ case WSP_GGML_OP_PAD:
15816
15300
  {
15817
15301
  WSP_GGML_ASSERT(false); // TODO: not implemented
15818
15302
  } break;
15819
- case WSP_GGML_OP_POOL_2D:
15303
+ case WSP_GGML_OP_ARGSORT:
15820
15304
  {
15821
15305
  WSP_GGML_ASSERT(false); // TODO: not implemented
15822
15306
  } break;
15823
- case WSP_GGML_OP_UPSCALE:
15307
+ case WSP_GGML_OP_LEAKY_RELU:
15824
15308
  {
15825
15309
  WSP_GGML_ASSERT(false); // TODO: not implemented
15826
15310
  } break;
@@ -16184,12 +15668,8 @@ struct wsp_ggml_cgraph * wsp_ggml_new_graph(struct wsp_ggml_context * ctx) {
16184
15668
  return wsp_ggml_new_graph_custom(ctx, WSP_GGML_DEFAULT_GRAPH_SIZE, false);
16185
15669
  }
16186
15670
 
16187
- struct wsp_ggml_cgraph * wsp_ggml_graph_view(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph0, int i0, int i1) {
16188
- const size_t obj_size = sizeof(struct wsp_ggml_cgraph);
16189
- struct wsp_ggml_object * obj = wsp_ggml_new_object(ctx, WSP_GGML_OBJECT_GRAPH, obj_size);
16190
- struct wsp_ggml_cgraph * cgraph = (struct wsp_ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs);
16191
-
16192
- *cgraph = (struct wsp_ggml_cgraph) {
15671
+ struct wsp_ggml_cgraph wsp_ggml_graph_view(struct wsp_ggml_cgraph * cgraph0, int i0, int i1) {
15672
+ struct wsp_ggml_cgraph cgraph = {
16193
15673
  /*.size =*/ 0,
16194
15674
  /*.n_nodes =*/ i1 - i0,
16195
15675
  /*.n_leafs =*/ 0,
@@ -16424,7 +15904,6 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
16424
15904
  n_tasks = n_threads;
16425
15905
  } break;
16426
15906
  case WSP_GGML_OP_SUB:
16427
- case WSP_GGML_OP_DIV:
16428
15907
  case WSP_GGML_OP_SQR:
16429
15908
  case WSP_GGML_OP_SQRT:
16430
15909
  case WSP_GGML_OP_LOG:
@@ -16434,6 +15913,7 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
16434
15913
  case WSP_GGML_OP_ARGMAX:
16435
15914
  case WSP_GGML_OP_REPEAT:
16436
15915
  case WSP_GGML_OP_REPEAT_BACK:
15916
+ case WSP_GGML_OP_LEAKY_RELU:
16437
15917
  {
16438
15918
  n_tasks = 1;
16439
15919
  } break;
@@ -16446,7 +15926,6 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
16446
15926
  case WSP_GGML_UNARY_OP_TANH:
16447
15927
  case WSP_GGML_UNARY_OP_ELU:
16448
15928
  case WSP_GGML_UNARY_OP_RELU:
16449
- case WSP_GGML_UNARY_OP_LEAKY:
16450
15929
  {
16451
15930
  n_tasks = 1;
16452
15931
  } break;
@@ -16457,10 +15936,13 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
16457
15936
  {
16458
15937
  n_tasks = n_threads;
16459
15938
  } break;
15939
+ default:
15940
+ WSP_GGML_ASSERT(false);
16460
15941
  }
16461
15942
  break;
16462
15943
  case WSP_GGML_OP_SILU_BACK:
16463
15944
  case WSP_GGML_OP_MUL:
15945
+ case WSP_GGML_OP_DIV:
16464
15946
  case WSP_GGML_OP_NORM:
16465
15947
  case WSP_GGML_OP_RMS_NORM:
16466
15948
  case WSP_GGML_OP_RMS_NORM_BACK:
@@ -16498,6 +15980,11 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
16498
15980
  }
16499
15981
  #endif
16500
15982
  } break;
15983
+ case WSP_GGML_OP_MUL_MAT_ID:
15984
+ {
15985
+ // FIXME: blas
15986
+ n_tasks = n_threads;
15987
+ } break;
16501
15988
  case WSP_GGML_OP_OUT_PROD:
16502
15989
  {
16503
15990
  n_tasks = n_threads;
@@ -16517,7 +16004,6 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
16517
16004
  } break;
16518
16005
  case WSP_GGML_OP_DIAG_MASK_ZERO:
16519
16006
  case WSP_GGML_OP_DIAG_MASK_INF:
16520
- case WSP_GGML_OP_SOFT_MAX:
16521
16007
  case WSP_GGML_OP_SOFT_MAX_BACK:
16522
16008
  case WSP_GGML_OP_ROPE:
16523
16009
  case WSP_GGML_OP_ROPE_BACK:
@@ -16533,31 +16019,15 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
16533
16019
  {
16534
16020
  n_tasks = 1; //TODO
16535
16021
  } break;
16536
- case WSP_GGML_OP_CONV_1D:
16537
- {
16538
- n_tasks = n_threads;
16539
- } break;
16540
- case WSP_GGML_OP_CONV_1D_STAGE_0:
16541
- {
16542
- n_tasks = n_threads;
16543
- } break;
16544
- case WSP_GGML_OP_CONV_1D_STAGE_1:
16022
+ case WSP_GGML_OP_SOFT_MAX:
16545
16023
  {
16546
- n_tasks = n_threads;
16024
+ n_tasks = MIN(MIN(4, n_threads), wsp_ggml_nrows(node->src[0]));
16547
16025
  } break;
16548
16026
  case WSP_GGML_OP_CONV_TRANSPOSE_1D:
16549
16027
  {
16550
16028
  n_tasks = n_threads;
16551
16029
  } break;
16552
- case WSP_GGML_OP_CONV_2D:
16553
- {
16554
- n_tasks = n_threads;
16555
- } break;
16556
- case WSP_GGML_OP_CONV_2D_STAGE_0:
16557
- {
16558
- n_tasks = n_threads;
16559
- } break;
16560
- case WSP_GGML_OP_CONV_2D_STAGE_1:
16030
+ case WSP_GGML_OP_IM2COL:
16561
16031
  {
16562
16032
  n_tasks = n_threads;
16563
16033
  } break;
@@ -16574,6 +16044,14 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
16574
16044
  {
16575
16045
  n_tasks = n_threads;
16576
16046
  } break;
16047
+ case WSP_GGML_OP_PAD:
16048
+ {
16049
+ n_tasks = n_threads;
16050
+ } break;
16051
+ case WSP_GGML_OP_ARGSORT:
16052
+ {
16053
+ n_tasks = n_threads;
16054
+ } break;
16577
16055
  case WSP_GGML_OP_FLASH_ATTN:
16578
16056
  {
16579
16057
  n_tasks = n_threads;
@@ -16642,6 +16120,12 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
16642
16120
  } break;
16643
16121
  default:
16644
16122
  {
16123
+ fprintf(stderr, "%s: op not implemented: ", __func__);
16124
+ if (node->op < WSP_GGML_OP_COUNT) {
16125
+ fprintf(stderr, "%s\n", wsp_ggml_op_name(node->op));
16126
+ } else {
16127
+ fprintf(stderr, "%d\n", node->op);
16128
+ }
16645
16129
  WSP_GGML_ASSERT(false);
16646
16130
  } break;
16647
16131
  }
@@ -16782,18 +16266,16 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
16782
16266
 
16783
16267
  // thread scheduling for the different operations + work buffer size estimation
16784
16268
  for (int i = 0; i < cgraph->n_nodes; i++) {
16785
- int n_tasks = 1;
16786
-
16787
16269
  struct wsp_ggml_tensor * node = cgraph->nodes[i];
16788
16270
 
16271
+ const int n_tasks = wsp_ggml_get_n_tasks(node, n_threads);
16272
+
16789
16273
  size_t cur = 0;
16790
16274
 
16791
16275
  switch (node->op) {
16792
16276
  case WSP_GGML_OP_CPY:
16793
16277
  case WSP_GGML_OP_DUP:
16794
16278
  {
16795
- n_tasks = n_threads;
16796
-
16797
16279
  if (wsp_ggml_is_quantized(node->type)) {
16798
16280
  cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->ne[0] * n_tasks;
16799
16281
  }
@@ -16801,16 +16283,12 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
16801
16283
  case WSP_GGML_OP_ADD:
16802
16284
  case WSP_GGML_OP_ADD1:
16803
16285
  {
16804
- n_tasks = n_threads;
16805
-
16806
16286
  if (wsp_ggml_is_quantized(node->src[0]->type)) {
16807
16287
  cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
16808
16288
  }
16809
16289
  } break;
16810
16290
  case WSP_GGML_OP_ACC:
16811
16291
  {
16812
- n_tasks = n_threads;
16813
-
16814
16292
  if (wsp_ggml_is_quantized(node->src[0]->type)) {
16815
16293
  cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
16816
16294
  }
@@ -16836,45 +16314,32 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
16836
16314
  cur = wsp_ggml_type_size(vec_dot_type)*wsp_ggml_nelements(node->src[1])/wsp_ggml_blck_size(vec_dot_type);
16837
16315
  }
16838
16316
  } break;
16317
+ case WSP_GGML_OP_MUL_MAT_ID:
16318
+ {
16319
+ const struct wsp_ggml_tensor * a = node->src[2];
16320
+ const struct wsp_ggml_tensor * b = node->src[1];
16321
+ const enum wsp_ggml_type vec_dot_type = type_traits[a->type].vec_dot_type;
16322
+ #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS)
16323
+ if (wsp_ggml_compute_forward_mul_mat_use_blas(a, b, node)) {
16324
+ if (a->type != WSP_GGML_TYPE_F32) {
16325
+ // here we need memory just for single 2D matrix from src0
16326
+ cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(a->ne[0]*a->ne[1]);
16327
+ }
16328
+ } else
16329
+ #endif
16330
+ if (b->type != vec_dot_type) {
16331
+ cur = wsp_ggml_type_size(vec_dot_type)*wsp_ggml_nelements(b)/wsp_ggml_blck_size(vec_dot_type);
16332
+ }
16333
+ } break;
16839
16334
  case WSP_GGML_OP_OUT_PROD:
16840
16335
  {
16841
- n_tasks = n_threads;
16842
-
16843
16336
  if (wsp_ggml_is_quantized(node->src[0]->type)) {
16844
16337
  cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
16845
16338
  }
16846
16339
  } break;
16847
- case WSP_GGML_OP_CONV_1D:
16340
+ case WSP_GGML_OP_SOFT_MAX:
16848
16341
  {
16849
- WSP_GGML_ASSERT(node->src[0]->ne[3] == 1);
16850
- WSP_GGML_ASSERT(node->src[1]->ne[2] == 1);
16851
- WSP_GGML_ASSERT(node->src[1]->ne[3] == 1);
16852
-
16853
- const int64_t ne00 = node->src[0]->ne[0];
16854
- const int64_t ne01 = node->src[0]->ne[1];
16855
- const int64_t ne02 = node->src[0]->ne[2];
16856
-
16857
- const int64_t ne10 = node->src[1]->ne[0];
16858
- const int64_t ne11 = node->src[1]->ne[1];
16859
-
16860
- const int64_t ne0 = node->ne[0];
16861
- const int64_t ne1 = node->ne[1];
16862
- const int64_t nk = ne00;
16863
- const int64_t ew0 = nk * ne01;
16864
-
16865
- UNUSED(ne02);
16866
- UNUSED(ne10);
16867
- UNUSED(ne11);
16868
-
16869
- if (node->src[0]->type == WSP_GGML_TYPE_F16 &&
16870
- node->src[1]->type == WSP_GGML_TYPE_F32) {
16871
- cur = sizeof(wsp_ggml_fp16_t)*(ne0*ne1*ew0);
16872
- } else if (node->src[0]->type == WSP_GGML_TYPE_F32 &&
16873
- node->src[1]->type == WSP_GGML_TYPE_F32) {
16874
- cur = sizeof(float)*(ne0*ne1*ew0);
16875
- } else {
16876
- WSP_GGML_ASSERT(false);
16877
- }
16342
+ cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->ne[0] * n_tasks;
16878
16343
  } break;
16879
16344
  case WSP_GGML_OP_CONV_TRANSPOSE_1D:
16880
16345
  {
@@ -16901,38 +16366,6 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
16901
16366
  WSP_GGML_ASSERT(false);
16902
16367
  }
16903
16368
  } break;
16904
- case WSP_GGML_OP_CONV_2D:
16905
- {
16906
- const int64_t ne00 = node->src[0]->ne[0]; // W
16907
- const int64_t ne01 = node->src[0]->ne[1]; // H
16908
- const int64_t ne02 = node->src[0]->ne[2]; // C
16909
- const int64_t ne03 = node->src[0]->ne[3]; // N
16910
-
16911
- const int64_t ne10 = node->src[1]->ne[0]; // W
16912
- const int64_t ne11 = node->src[1]->ne[1]; // H
16913
- const int64_t ne12 = node->src[1]->ne[2]; // C
16914
-
16915
- const int64_t ne0 = node->ne[0];
16916
- const int64_t ne1 = node->ne[1];
16917
- const int64_t ne2 = node->ne[2];
16918
- const int64_t ne3 = node->ne[3];
16919
- const int64_t nk = ne00*ne01;
16920
- const int64_t ew0 = nk * ne02;
16921
-
16922
- UNUSED(ne03);
16923
- UNUSED(ne2);
16924
-
16925
- if (node->src[0]->type == WSP_GGML_TYPE_F16 &&
16926
- node->src[1]->type == WSP_GGML_TYPE_F32) {
16927
- // im2col: [N*OH*OW, IC*KH*KW]
16928
- cur = sizeof(wsp_ggml_fp16_t)*(ne3*ne0*ne1*ew0);
16929
- } else if (node->src[0]->type == WSP_GGML_TYPE_F32 &&
16930
- node->src[1]->type == WSP_GGML_TYPE_F32) {
16931
- cur = sizeof(float)* (ne10*ne11*ne12);
16932
- } else {
16933
- WSP_GGML_ASSERT(false);
16934
- }
16935
- } break;
16936
16369
  case WSP_GGML_OP_CONV_TRANSPOSE_2D:
16937
16370
  {
16938
16371
  const int64_t ne00 = node->src[0]->ne[0]; // W
@@ -16949,8 +16382,6 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
16949
16382
  } break;
16950
16383
  case WSP_GGML_OP_FLASH_ATTN:
16951
16384
  {
16952
- n_tasks = n_threads;
16953
-
16954
16385
  const int64_t ne11 = wsp_ggml_up(node->src[1]->ne[1], WSP_GGML_SOFT_MAX_UNROLL);
16955
16386
 
16956
16387
  if (node->src[1]->type == WSP_GGML_TYPE_F32) {
@@ -16963,8 +16394,6 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
16963
16394
  } break;
16964
16395
  case WSP_GGML_OP_FLASH_FF:
16965
16396
  {
16966
- n_tasks = n_threads;
16967
-
16968
16397
  if (node->src[1]->type == WSP_GGML_TYPE_F32) {
16969
16398
  cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
16970
16399
  cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
@@ -16975,8 +16404,6 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
16975
16404
  } break;
16976
16405
  case WSP_GGML_OP_FLASH_ATTN_BACK:
16977
16406
  {
16978
- n_tasks = n_threads;
16979
-
16980
16407
  const int64_t D = node->src[0]->ne[0];
16981
16408
  const int64_t ne11 = wsp_ggml_up(node->src[1]->ne[1], WSP_GGML_SOFT_MAX_UNROLL);
16982
16409
  const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in wsp_ggml_compute_forward_flash_attn_back
@@ -16991,8 +16418,6 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n
16991
16418
 
16992
16419
  case WSP_GGML_OP_CROSS_ENTROPY_LOSS:
16993
16420
  {
16994
- n_tasks = n_threads;
16995
-
16996
16421
  cur = wsp_ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
16997
16422
  } break;
16998
16423
  case WSP_GGML_OP_COUNT:
@@ -18719,14 +18144,14 @@ enum wsp_ggml_opt_result wsp_ggml_opt_resume_g(
18719
18144
 
18720
18145
  ////////////////////////////////////////////////////////////////////////////////
18721
18146
 
18722
- size_t wsp_ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
18147
+ size_t wsp_ggml_wsp_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
18723
18148
  assert(k % QK4_0 == 0);
18724
18149
  const int nb = k / QK4_0;
18725
18150
 
18726
18151
  for (int b = 0; b < n; b += k) {
18727
18152
  block_q4_0 * restrict y = (block_q4_0 *) dst + b/QK4_0;
18728
18153
 
18729
- quantize_row_q4_0_reference(src + b, y, k);
18154
+ wsp_quantize_row_q4_0_reference(src + b, y, k);
18730
18155
 
18731
18156
  for (int i = 0; i < nb; i++) {
18732
18157
  for (int j = 0; j < QK4_0; j += 2) {
@@ -18742,14 +18167,14 @@ size_t wsp_ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64
18742
18167
  return (n/QK4_0*sizeof(block_q4_0));
18743
18168
  }
18744
18169
 
18745
- size_t wsp_ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
18170
+ size_t wsp_ggml_wsp_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
18746
18171
  assert(k % QK4_1 == 0);
18747
18172
  const int nb = k / QK4_1;
18748
18173
 
18749
18174
  for (int b = 0; b < n; b += k) {
18750
18175
  block_q4_1 * restrict y = (block_q4_1 *) dst + b/QK4_1;
18751
18176
 
18752
- quantize_row_q4_1_reference(src + b, y, k);
18177
+ wsp_quantize_row_q4_1_reference(src + b, y, k);
18753
18178
 
18754
18179
  for (int i = 0; i < nb; i++) {
18755
18180
  for (int j = 0; j < QK4_1; j += 2) {
@@ -18765,22 +18190,22 @@ size_t wsp_ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64
18765
18190
  return (n/QK4_1*sizeof(block_q4_1));
18766
18191
  }
18767
18192
 
18768
- size_t wsp_ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
18193
+ size_t wsp_ggml_wsp_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
18769
18194
  assert(k % QK5_0 == 0);
18770
18195
  const int nb = k / QK5_0;
18771
18196
 
18772
18197
  for (int b = 0; b < n; b += k) {
18773
18198
  block_q5_0 * restrict y = (block_q5_0 *)dst + b/QK5_0;
18774
18199
 
18775
- quantize_row_q5_0_reference(src + b, y, k);
18200
+ wsp_quantize_row_q5_0_reference(src + b, y, k);
18776
18201
 
18777
18202
  for (int i = 0; i < nb; i++) {
18778
18203
  uint32_t qh;
18779
18204
  memcpy(&qh, &y[i].qh, sizeof(qh));
18780
18205
 
18781
18206
  for (int j = 0; j < QK5_0; j += 2) {
18782
- const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
18783
- const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12));
18207
+ const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4;
18208
+ const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12));
18784
18209
 
18785
18210
  // cast to 16 bins
18786
18211
  const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2;
@@ -18795,22 +18220,22 @@ size_t wsp_ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64
18795
18220
  return (n/QK5_0*sizeof(block_q5_0));
18796
18221
  }
18797
18222
 
18798
- size_t wsp_ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) {
18223
+ size_t wsp_ggml_wsp_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) {
18799
18224
  assert(k % QK5_1 == 0);
18800
18225
  const int nb = k / QK5_1;
18801
18226
 
18802
18227
  for (int b = 0; b < n; b += k) {
18803
18228
  block_q5_1 * restrict y = (block_q5_1 *)dst + b/QK5_1;
18804
18229
 
18805
- quantize_row_q5_1_reference(src + b, y, k);
18230
+ wsp_quantize_row_q5_1_reference(src + b, y, k);
18806
18231
 
18807
18232
  for (int i = 0; i < nb; i++) {
18808
18233
  uint32_t qh;
18809
18234
  memcpy(&qh, &y[i].qh, sizeof(qh));
18810
18235
 
18811
18236
  for (int j = 0; j < QK5_1; j += 2) {
18812
- const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
18813
- const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12));
18237
+ const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4;
18238
+ const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12));
18814
18239
 
18815
18240
  // cast to 16 bins
18816
18241
  const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2;
@@ -18825,14 +18250,14 @@ size_t wsp_ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64
18825
18250
  return (n/QK5_1*sizeof(block_q5_1));
18826
18251
  }
18827
18252
 
18828
- size_t wsp_ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) {
18253
+ size_t wsp_ggml_wsp_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) {
18829
18254
  assert(k % QK8_0 == 0);
18830
18255
  const int nb = k / QK8_0;
18831
18256
 
18832
18257
  for (int b = 0; b < n; b += k) {
18833
18258
  block_q8_0 * restrict y = (block_q8_0 *)dst + b/QK8_0;
18834
18259
 
18835
- quantize_row_q8_0_reference(src + b, y, k);
18260
+ wsp_quantize_row_q8_0_reference(src + b, y, k);
18836
18261
 
18837
18262
  for (int i = 0; i < nb; i++) {
18838
18263
  for (int j = 0; j < QK8_0; ++j) {
@@ -18846,68 +18271,68 @@ size_t wsp_ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64
18846
18271
  return (n/QK8_0*sizeof(block_q8_0));
18847
18272
  }
18848
18273
 
18849
- size_t wsp_ggml_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
18274
+ size_t wsp_ggml_wsp_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
18850
18275
  size_t result = 0;
18851
18276
  switch (type) {
18852
18277
  case WSP_GGML_TYPE_Q4_0:
18853
18278
  {
18854
18279
  WSP_GGML_ASSERT(start % QK4_0 == 0);
18855
18280
  block_q4_0 * block = (block_q4_0*)dst + start / QK4_0;
18856
- result = wsp_ggml_quantize_q4_0(src + start, block, n, n, hist);
18281
+ result = wsp_ggml_wsp_quantize_q4_0(src + start, block, n, n, hist);
18857
18282
  } break;
18858
18283
  case WSP_GGML_TYPE_Q4_1:
18859
18284
  {
18860
18285
  WSP_GGML_ASSERT(start % QK4_1 == 0);
18861
18286
  block_q4_1 * block = (block_q4_1*)dst + start / QK4_1;
18862
- result = wsp_ggml_quantize_q4_1(src + start, block, n, n, hist);
18287
+ result = wsp_ggml_wsp_quantize_q4_1(src + start, block, n, n, hist);
18863
18288
  } break;
18864
18289
  case WSP_GGML_TYPE_Q5_0:
18865
18290
  {
18866
18291
  WSP_GGML_ASSERT(start % QK5_0 == 0);
18867
18292
  block_q5_0 * block = (block_q5_0*)dst + start / QK5_0;
18868
- result = wsp_ggml_quantize_q5_0(src + start, block, n, n, hist);
18293
+ result = wsp_ggml_wsp_quantize_q5_0(src + start, block, n, n, hist);
18869
18294
  } break;
18870
18295
  case WSP_GGML_TYPE_Q5_1:
18871
18296
  {
18872
18297
  WSP_GGML_ASSERT(start % QK5_1 == 0);
18873
18298
  block_q5_1 * block = (block_q5_1*)dst + start / QK5_1;
18874
- result = wsp_ggml_quantize_q5_1(src + start, block, n, n, hist);
18299
+ result = wsp_ggml_wsp_quantize_q5_1(src + start, block, n, n, hist);
18875
18300
  } break;
18876
18301
  case WSP_GGML_TYPE_Q8_0:
18877
18302
  {
18878
18303
  WSP_GGML_ASSERT(start % QK8_0 == 0);
18879
18304
  block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
18880
- result = wsp_ggml_quantize_q8_0(src + start, block, n, n, hist);
18305
+ result = wsp_ggml_wsp_quantize_q8_0(src + start, block, n, n, hist);
18881
18306
  } break;
18882
18307
  case WSP_GGML_TYPE_Q2_K:
18883
18308
  {
18884
18309
  WSP_GGML_ASSERT(start % QK_K == 0);
18885
18310
  block_q2_K * block = (block_q2_K*)dst + start / QK_K;
18886
- result = wsp_ggml_quantize_q2_K(src + start, block, n, n, hist);
18311
+ result = wsp_ggml_wsp_quantize_q2_K(src + start, block, n, n, hist);
18887
18312
  } break;
18888
18313
  case WSP_GGML_TYPE_Q3_K:
18889
18314
  {
18890
18315
  WSP_GGML_ASSERT(start % QK_K == 0);
18891
18316
  block_q3_K * block = (block_q3_K*)dst + start / QK_K;
18892
- result = wsp_ggml_quantize_q3_K(src + start, block, n, n, hist);
18317
+ result = wsp_ggml_wsp_quantize_q3_K(src + start, block, n, n, hist);
18893
18318
  } break;
18894
18319
  case WSP_GGML_TYPE_Q4_K:
18895
18320
  {
18896
18321
  WSP_GGML_ASSERT(start % QK_K == 0);
18897
18322
  block_q4_K * block = (block_q4_K*)dst + start / QK_K;
18898
- result = wsp_ggml_quantize_q4_K(src + start, block, n, n, hist);
18323
+ result = wsp_ggml_wsp_quantize_q4_K(src + start, block, n, n, hist);
18899
18324
  } break;
18900
18325
  case WSP_GGML_TYPE_Q5_K:
18901
18326
  {
18902
18327
  WSP_GGML_ASSERT(start % QK_K == 0);
18903
18328
  block_q5_K * block = (block_q5_K*)dst + start / QK_K;
18904
- result = wsp_ggml_quantize_q5_K(src + start, block, n, n, hist);
18329
+ result = wsp_ggml_wsp_quantize_q5_K(src + start, block, n, n, hist);
18905
18330
  } break;
18906
18331
  case WSP_GGML_TYPE_Q6_K:
18907
18332
  {
18908
18333
  WSP_GGML_ASSERT(start % QK_K == 0);
18909
18334
  block_q6_K * block = (block_q6_K*)dst + start / QK_K;
18910
- result = wsp_ggml_quantize_q6_K(src + start, block, n, n, hist);
18335
+ result = wsp_ggml_wsp_quantize_q6_K(src + start, block, n, n, hist);
18911
18336
  } break;
18912
18337
  case WSP_GGML_TYPE_F16:
18913
18338
  {
@@ -19000,6 +18425,7 @@ struct wsp_gguf_kv {
19000
18425
 
19001
18426
  struct wsp_gguf_header {
19002
18427
  char magic[4];
18428
+
19003
18429
  uint32_t version;
19004
18430
  uint64_t n_tensors; // GGUFv2
19005
18431
  uint64_t n_kv; // GGUFv2
@@ -19089,7 +18515,7 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
19089
18515
 
19090
18516
  for (uint32_t i = 0; i < sizeof(magic); i++) {
19091
18517
  if (magic[i] != WSP_GGUF_MAGIC[i]) {
19092
- fprintf(stderr, "%s: invalid magic characters %s.\n", __func__, magic);
18518
+ fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
19093
18519
  fclose(file);
19094
18520
  return NULL;
19095
18521
  }
@@ -19104,7 +18530,6 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
19104
18530
  {
19105
18531
  strncpy(ctx->header.magic, magic, 4);
19106
18532
 
19107
-
19108
18533
  ctx->kv = NULL;
19109
18534
  ctx->infos = NULL;
19110
18535
  ctx->data = NULL;
@@ -19132,7 +18557,7 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
19132
18557
  {
19133
18558
  ctx->kv = malloc(ctx->header.n_kv * sizeof(struct wsp_gguf_kv));
19134
18559
 
19135
- for (uint32_t i = 0; i < ctx->header.n_kv; ++i) {
18560
+ for (uint64_t i = 0; i < ctx->header.n_kv; ++i) {
19136
18561
  struct wsp_gguf_kv * kv = &ctx->kv[i];
19137
18562
 
19138
18563
  //fprintf(stderr, "%s: reading kv %d\n", __func__, i);
@@ -19179,7 +18604,7 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
19179
18604
  case WSP_GGUF_TYPE_STRING:
19180
18605
  {
19181
18606
  kv->value.arr.data = malloc(kv->value.arr.n * sizeof(struct wsp_gguf_str));
19182
- for (uint32_t j = 0; j < kv->value.arr.n; ++j) {
18607
+ for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
19183
18608
  ok = ok && wsp_gguf_fread_str(file, &((struct wsp_gguf_str *) kv->value.arr.data)[j], &offset);
19184
18609
  }
19185
18610
  } break;
@@ -19207,7 +18632,7 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
19207
18632
  {
19208
18633
  ctx->infos = malloc(ctx->header.n_tensors * sizeof(struct wsp_gguf_tensor_info));
19209
18634
 
19210
- for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
18635
+ for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
19211
18636
  struct wsp_gguf_tensor_info * info = &ctx->infos[i];
19212
18637
 
19213
18638
  for (int j = 0; j < WSP_GGML_MAX_DIMS; ++j) {
@@ -19254,7 +18679,7 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
19254
18679
  // compute the total size of the data section, taking into account the alignment
19255
18680
  {
19256
18681
  ctx->size = 0;
19257
- for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
18682
+ for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
19258
18683
  struct wsp_gguf_tensor_info * info = &ctx->infos[i];
19259
18684
 
19260
18685
  const int64_t ne =
@@ -19323,7 +18748,7 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp
19323
18748
  wsp_ggml_set_no_alloc(ctx_data, true);
19324
18749
 
19325
18750
  // create the tensors
19326
- for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
18751
+ for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
19327
18752
  const int64_t ne[WSP_GGML_MAX_DIMS] = {
19328
18753
  ctx->infos[i].ne[0],
19329
18754
  ctx->infos[i].ne[1],
@@ -19458,24 +18883,29 @@ int wsp_gguf_find_key(const struct wsp_gguf_context * ctx, const char * key) {
19458
18883
  }
19459
18884
 
19460
18885
  const char * wsp_gguf_get_key(const struct wsp_gguf_context * ctx, int key_id) {
18886
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19461
18887
  return ctx->kv[key_id].key.data;
19462
18888
  }
19463
18889
 
19464
18890
  enum wsp_gguf_type wsp_gguf_get_kv_type(const struct wsp_gguf_context * ctx, int key_id) {
18891
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19465
18892
  return ctx->kv[key_id].type;
19466
18893
  }
19467
18894
 
19468
18895
  enum wsp_gguf_type wsp_gguf_get_arr_type(const struct wsp_gguf_context * ctx, int key_id) {
18896
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19469
18897
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_ARRAY);
19470
18898
  return ctx->kv[key_id].value.arr.type;
19471
18899
  }
19472
18900
 
19473
18901
  const void * wsp_gguf_get_arr_data(const struct wsp_gguf_context * ctx, int key_id) {
18902
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19474
18903
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_ARRAY);
19475
18904
  return ctx->kv[key_id].value.arr.data;
19476
18905
  }
19477
18906
 
19478
18907
  const char * wsp_gguf_get_arr_str(const struct wsp_gguf_context * ctx, int key_id, int i) {
18908
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19479
18909
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_ARRAY);
19480
18910
  struct wsp_gguf_kv * kv = &ctx->kv[key_id];
19481
18911
  struct wsp_gguf_str * str = &((struct wsp_gguf_str *) kv->value.arr.data)[i];
@@ -19483,70 +18913,90 @@ const char * wsp_gguf_get_arr_str(const struct wsp_gguf_context * ctx, int key_i
19483
18913
  }
19484
18914
 
19485
18915
  int wsp_gguf_get_arr_n(const struct wsp_gguf_context * ctx, int key_id) {
18916
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19486
18917
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_ARRAY);
19487
18918
  return ctx->kv[key_id].value.arr.n;
19488
18919
  }
19489
18920
 
19490
18921
  uint8_t wsp_gguf_get_val_u8(const struct wsp_gguf_context * ctx, int key_id) {
18922
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19491
18923
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_UINT8);
19492
18924
  return ctx->kv[key_id].value.uint8;
19493
18925
  }
19494
18926
 
19495
18927
  int8_t wsp_gguf_get_val_i8(const struct wsp_gguf_context * ctx, int key_id) {
18928
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19496
18929
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_INT8);
19497
18930
  return ctx->kv[key_id].value.int8;
19498
18931
  }
19499
18932
 
19500
18933
  uint16_t wsp_gguf_get_val_u16(const struct wsp_gguf_context * ctx, int key_id) {
18934
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19501
18935
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_UINT16);
19502
18936
  return ctx->kv[key_id].value.uint16;
19503
18937
  }
19504
18938
 
19505
18939
  int16_t wsp_gguf_get_val_i16(const struct wsp_gguf_context * ctx, int key_id) {
18940
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19506
18941
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_INT16);
19507
18942
  return ctx->kv[key_id].value.int16;
19508
18943
  }
19509
18944
 
19510
18945
  uint32_t wsp_gguf_get_val_u32(const struct wsp_gguf_context * ctx, int key_id) {
18946
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19511
18947
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_UINT32);
19512
18948
  return ctx->kv[key_id].value.uint32;
19513
18949
  }
19514
18950
 
19515
18951
  int32_t wsp_gguf_get_val_i32(const struct wsp_gguf_context * ctx, int key_id) {
18952
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19516
18953
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_INT32);
19517
18954
  return ctx->kv[key_id].value.int32;
19518
18955
  }
19519
18956
 
19520
18957
  float wsp_gguf_get_val_f32(const struct wsp_gguf_context * ctx, int key_id) {
18958
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19521
18959
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_FLOAT32);
19522
18960
  return ctx->kv[key_id].value.float32;
19523
18961
  }
19524
18962
 
19525
18963
  uint64_t wsp_gguf_get_val_u64(const struct wsp_gguf_context * ctx, int key_id) {
18964
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19526
18965
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_UINT64);
19527
18966
  return ctx->kv[key_id].value.uint64;
19528
18967
  }
19529
18968
 
19530
18969
  int64_t wsp_gguf_get_val_i64(const struct wsp_gguf_context * ctx, int key_id) {
18970
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19531
18971
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_INT64);
19532
18972
  return ctx->kv[key_id].value.int64;
19533
18973
  }
19534
18974
 
19535
18975
  double wsp_gguf_get_val_f64(const struct wsp_gguf_context * ctx, int key_id) {
18976
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19536
18977
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_FLOAT64);
19537
18978
  return ctx->kv[key_id].value.float64;
19538
18979
  }
19539
18980
 
19540
18981
  bool wsp_gguf_get_val_bool(const struct wsp_gguf_context * ctx, int key_id) {
18982
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19541
18983
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_BOOL);
19542
18984
  return ctx->kv[key_id].value.bool_;
19543
18985
  }
19544
18986
 
19545
18987
  const char * wsp_gguf_get_val_str(const struct wsp_gguf_context * ctx, int key_id) {
18988
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
19546
18989
  WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_STRING);
19547
18990
  return ctx->kv[key_id].value.str.data;
19548
18991
  }
19549
18992
 
18993
+ const void * wsp_gguf_get_val_data(const struct wsp_gguf_context * ctx, int key_id) {
18994
+ WSP_GGML_ASSERT(key_id >= 0 && key_id < wsp_gguf_get_n_kv(ctx));
18995
+ WSP_GGML_ASSERT(ctx->kv[key_id].type != WSP_GGUF_TYPE_ARRAY);
18996
+ WSP_GGML_ASSERT(ctx->kv[key_id].type != WSP_GGUF_TYPE_STRING);
18997
+ return &ctx->kv[key_id].value;
18998
+ }
18999
+
19550
19000
  int wsp_gguf_get_n_tensors(const struct wsp_gguf_context * ctx) {
19551
19001
  return ctx->header.n_tensors;
19552
19002
  }