cui-llama.rn 1.2.6 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. package/README.md +3 -2
  2. package/android/src/main/CMakeLists.txt +26 -6
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +115 -27
  4. package/android/src/main/java/com/rnllama/RNLlama.java +40 -7
  5. package/android/src/main/jni.cpp +228 -40
  6. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +9 -4
  7. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +9 -4
  8. package/cpp/amx/amx.cpp +196 -0
  9. package/cpp/amx/amx.h +20 -0
  10. package/cpp/amx/common.h +101 -0
  11. package/cpp/amx/mmq.cpp +2524 -0
  12. package/cpp/amx/mmq.h +16 -0
  13. package/cpp/common.cpp +118 -251
  14. package/cpp/common.h +53 -30
  15. package/cpp/ggml-aarch64.c +46 -3395
  16. package/cpp/ggml-aarch64.h +0 -20
  17. package/cpp/ggml-alloc.c +6 -8
  18. package/cpp/ggml-backend-impl.h +33 -11
  19. package/cpp/ggml-backend-reg.cpp +423 -0
  20. package/cpp/ggml-backend.cpp +14 -676
  21. package/cpp/ggml-backend.h +46 -9
  22. package/cpp/ggml-common.h +6 -0
  23. package/cpp/ggml-cpu-aarch64.c +3823 -0
  24. package/cpp/ggml-cpu-aarch64.h +32 -0
  25. package/cpp/ggml-cpu-impl.h +14 -242
  26. package/cpp/ggml-cpu-quants.c +10835 -0
  27. package/cpp/ggml-cpu-quants.h +63 -0
  28. package/cpp/ggml-cpu.c +13971 -13720
  29. package/cpp/ggml-cpu.cpp +715 -0
  30. package/cpp/ggml-cpu.h +65 -63
  31. package/cpp/ggml-impl.h +285 -25
  32. package/cpp/ggml-metal.h +8 -8
  33. package/cpp/ggml-metal.m +1221 -728
  34. package/cpp/ggml-quants.c +189 -10681
  35. package/cpp/ggml-quants.h +78 -125
  36. package/cpp/ggml-threading.cpp +12 -0
  37. package/cpp/ggml-threading.h +12 -0
  38. package/cpp/ggml.c +688 -1460
  39. package/cpp/ggml.h +58 -244
  40. package/cpp/json-schema-to-grammar.cpp +1045 -1045
  41. package/cpp/json.hpp +24766 -24766
  42. package/cpp/llama-sampling.cpp +5 -2
  43. package/cpp/llama.cpp +409 -123
  44. package/cpp/llama.h +8 -4
  45. package/cpp/rn-llama.hpp +89 -25
  46. package/cpp/sampling.cpp +42 -3
  47. package/cpp/sampling.h +22 -1
  48. package/cpp/sgemm.cpp +608 -0
  49. package/cpp/speculative.cpp +270 -0
  50. package/cpp/speculative.h +28 -0
  51. package/cpp/unicode.cpp +11 -0
  52. package/ios/RNLlama.mm +43 -20
  53. package/ios/RNLlamaContext.h +9 -3
  54. package/ios/RNLlamaContext.mm +146 -33
  55. package/jest/mock.js +0 -1
  56. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  57. package/lib/commonjs/grammar.js +4 -2
  58. package/lib/commonjs/grammar.js.map +1 -1
  59. package/lib/commonjs/index.js +52 -15
  60. package/lib/commonjs/index.js.map +1 -1
  61. package/lib/module/NativeRNLlama.js.map +1 -1
  62. package/lib/module/grammar.js +2 -1
  63. package/lib/module/grammar.js.map +1 -1
  64. package/lib/module/index.js +51 -15
  65. package/lib/module/index.js.map +1 -1
  66. package/lib/typescript/NativeRNLlama.d.ts +122 -8
  67. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  68. package/lib/typescript/grammar.d.ts +5 -6
  69. package/lib/typescript/grammar.d.ts.map +1 -1
  70. package/lib/typescript/index.d.ts +15 -6
  71. package/lib/typescript/index.d.ts.map +1 -1
  72. package/package.json +2 -1
  73. package/src/NativeRNLlama.ts +135 -13
  74. package/src/grammar.ts +10 -8
  75. package/src/index.ts +104 -28
package/cpp/ggml.c CHANGED
@@ -3,9 +3,11 @@
 
  #include "ggml-backend.h"
  #include "ggml-impl.h"
- #include "ggml-cpu-impl.h"
- #include "ggml-quants.h"
+ #include "ggml-threading.h"
  #include "ggml.h"
+
+ // FIXME: required here for quantization functions
+ #include "ggml-quants.h"
  #include "ggml-aarch64.h"
 
  #if defined(_MSC_VER) || defined(__MINGW32__)
@@ -47,6 +49,17 @@
 
  #define UNUSED LM_GGML_UNUSED
 
+ #if defined(_MSC_VER)
+ #define m512bh(p) p
+ #define m512i(p) p
+ #else
+ #define m512bh(p) (__m512bh)(p)
+ #define m512i(p) (__m512i)(p)
+ #endif
+
+ // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
+ float lm_ggml_table_f32_f16[1 << 16];
+
  #if (defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)) && \
  (!defined(TARGET_OS_TV) && !defined(TARGET_OS_WATCH))
  #include <unistd.h>
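Note: the new global lm_ggml_table_f32_f16 caches every possible half-precision value as a float, so FP16 to FP32 conversion becomes a single table load. A minimal standalone sketch of the same lookup-table idea; the names and the injected conversion callback below are illustrative, not part of this package:

    #include <stdint.h>

    // 65536 entries x 4 bytes = 256 KB, one slot per possible f16 bit pattern
    static float f16_to_f32_table[1 << 16];

    // fill the table once, using whatever scalar f16 -> f32 conversion is available
    static void init_f16_table(float (*fp16_to_fp32)(uint16_t)) {
        for (uint32_t i = 0; i < (1u << 16); ++i) {
            f16_to_f32_table[i] = fp16_to_fp32((uint16_t) i);
        }
    }

    // after initialization, conversion is a single indexed load
    static inline float f16_lookup(uint16_t h) {
        return f16_to_f32_table[h];
    }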
@@ -363,7 +376,7 @@ void lm_ggml_fp16_to_fp32_row(const lm_ggml_fp16_t * x, float * y, int64_t n) {
  void lm_ggml_fp32_to_fp16_row(const float * x, lm_ggml_fp16_t * y, int64_t n) {
  int64_t i = 0;
  #if defined(__F16C__)
- if (lm_ggml_cpu_has_f16c()) {
+ //if (lm_ggml_cpu_has_f16c()) {
  for (; i + 7 < n; i += 8) {
  __m256 x_vec = _mm256_loadu_ps(x + i);
  __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
@@ -374,7 +387,7 @@ void lm_ggml_fp32_to_fp16_row(const float * x, lm_ggml_fp16_t * y, int64_t n) {
  __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
  _mm_storel_epi64((__m128i *)(y + i), y_vec);
  }
- }
+ //}
  #endif
  for (; i < n; i++) {
  y[i] = LM_GGML_FP32_TO_FP16(x[i]);
@@ -384,7 +397,7 @@ void lm_ggml_fp32_to_fp16_row(const float * x, lm_ggml_fp16_t * y, int64_t n) {
  void lm_ggml_bf16_to_fp32_row(const lm_ggml_bf16_t * x, float * y, int64_t n) {
  int64_t i = 0;
  #if defined(__AVX512F__)
- if (lm_ggml_cpu_has_avx512()) {
+ //if (lm_ggml_cpu_has_avx512()) {
  for (; i + 16 <= n; i += 16) {
  _mm512_storeu_ps(y + i,
  _mm512_castsi512_ps(
@@ -394,10 +407,10 @@ void lm_ggml_bf16_to_fp32_row(const lm_ggml_bf16_t * x, float * y, int64_t n) {
  (const __m256i *)(x + i))),
  16)));
  }
- }
+ //}
  #endif
  #if defined(__AVX2__)
- if (lm_ggml_cpu_has_avx2()) {
+ //if (lm_ggml_cpu_has_avx2()) {
  for (; i + 8 <= n; i += 8) {
  _mm256_storeu_ps(y + i,
  _mm256_castsi256_ps(
@@ -407,7 +420,7 @@ void lm_ggml_bf16_to_fp32_row(const lm_ggml_bf16_t * x, float * y, int64_t n) {
  (const __m128i *)(x + i))),
  16)));
  }
- }
+ //}
  #endif
  for (; i < n; i++) {
  y[i] = LM_GGML_BF16_TO_FP32(x[i]);
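Note: the vector paths above widen bf16 by zero-extending each 16-bit value and shifting it left by 16 before reinterpreting the bits as a float. A scalar sketch of the same conversion, with illustrative names:

    #include <stdint.h>
    #include <string.h>

    // bf16 keeps the sign, exponent and top 7 mantissa bits of an IEEE-754 float,
    // so widening it is just a 16-bit left shift followed by a bit-cast
    static inline float bf16_to_f32(uint16_t h) {
        uint32_t bits = (uint32_t) h << 16;
        float f;
        memcpy(&f, &bits, sizeof(f)); // bit-cast without breaking strict aliasing
        return f;
    }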
@@ -601,7 +614,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(lm_ggml_fp16_t),
  .is_quantized = false,
  .to_float = (lm_ggml_to_float_t) lm_ggml_fp16_to_fp32_row,
- .from_float = (lm_ggml_from_float_t) lm_ggml_fp32_to_fp16_row,
  .from_float_ref = (lm_ggml_from_float_t) lm_ggml_fp32_to_fp16_row,
  },
  [LM_GGML_TYPE_Q4_0] = {
@@ -610,7 +622,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_q4_0),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_q4_0,
- .from_float = quantize_row_q4_0,
  .from_float_ref = (lm_ggml_from_float_t) quantize_row_q4_0_ref,
  },
  [LM_GGML_TYPE_Q4_1] = {
@@ -619,7 +630,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_q4_1),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_q4_1,
- .from_float = quantize_row_q4_1,
  .from_float_ref = (lm_ggml_from_float_t) quantize_row_q4_1_ref,
  },
  [4] = { // LM_GGML_TYPE_Q4_2
@@ -627,18 +637,12 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .blck_size = 0,
  .type_size = 0,
  .is_quantized = false,
- .to_float = NULL,
- .from_float = NULL,
- .from_float_ref = NULL,
  },
  [5] = { // LM_GGML_TYPE_Q4_3
  .type_name = "DEPRECATED",
  .blck_size = 0,
  .type_size = 0,
  .is_quantized = false,
- .to_float = NULL,
- .from_float = NULL,
- .from_float_ref = NULL,
  },
  [LM_GGML_TYPE_Q5_0] = {
  .type_name = "q5_0",
@@ -646,7 +650,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_q5_0),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_q5_0,
- .from_float = quantize_row_q5_0,
  .from_float_ref = (lm_ggml_from_float_t) quantize_row_q5_0_ref,
  },
  [LM_GGML_TYPE_Q5_1] = {
@@ -655,7 +658,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_q5_1),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_q5_1,
- .from_float = quantize_row_q5_1,
  .from_float_ref = (lm_ggml_from_float_t) quantize_row_q5_1_ref,
  },
  [LM_GGML_TYPE_Q8_0] = {
@@ -664,7 +666,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_q8_0),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_q8_0,
- .from_float = quantize_row_q8_0,
  .from_float_ref = (lm_ggml_from_float_t) quantize_row_q8_0_ref,
  },
  [LM_GGML_TYPE_Q8_1] = {
@@ -672,7 +673,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .blck_size = QK8_1,
  .type_size = sizeof(block_q8_1),
  .is_quantized = true,
- .from_float = quantize_row_q8_1,
  .from_float_ref = (lm_ggml_from_float_t) quantize_row_q8_1_ref,
  },
  [LM_GGML_TYPE_Q2_K] = {
@@ -681,7 +681,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_q2_K),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_q2_K,
- .from_float = quantize_row_q2_K,
  .from_float_ref = (lm_ggml_from_float_t) quantize_row_q2_K_ref,
  },
  [LM_GGML_TYPE_Q3_K] = {
@@ -690,7 +689,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_q3_K),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_q3_K,
- .from_float = quantize_row_q3_K,
  .from_float_ref = (lm_ggml_from_float_t) quantize_row_q3_K_ref,
  },
  [LM_GGML_TYPE_Q4_K] = {
@@ -699,7 +697,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_q4_K),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_q4_K,
- .from_float = quantize_row_q4_K,
  .from_float_ref = (lm_ggml_from_float_t) quantize_row_q4_K_ref,
  },
  [LM_GGML_TYPE_Q5_K] = {
@@ -708,7 +705,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_q5_K),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_q5_K,
- .from_float = quantize_row_q5_K,
  .from_float_ref = (lm_ggml_from_float_t) quantize_row_q5_K_ref,
  },
  [LM_GGML_TYPE_Q6_K] = {
@@ -717,7 +713,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_q6_K),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_q6_K,
- .from_float = quantize_row_q6_K,
  .from_float_ref = (lm_ggml_from_float_t) quantize_row_q6_K_ref,
  },
  [LM_GGML_TYPE_IQ2_XXS] = {
@@ -726,7 +721,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_iq2_xxs),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_iq2_xxs,
- .from_float = NULL,
  .from_float_ref = NULL,
  },
  [LM_GGML_TYPE_IQ2_XS] = {
@@ -735,7 +729,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_iq2_xs),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_iq2_xs,
- .from_float = NULL,
  .from_float_ref = NULL,
  },
  [LM_GGML_TYPE_IQ3_XXS] = {
@@ -744,7 +737,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_iq3_xxs),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_iq3_xxs,
- .from_float = quantize_row_iq3_xxs,
  .from_float_ref = (lm_ggml_from_float_t)quantize_row_iq3_xxs_ref,
  },
  [LM_GGML_TYPE_IQ3_S] = {
@@ -753,7 +745,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_iq3_s),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_iq3_s,
- .from_float = quantize_row_iq3_s,
  .from_float_ref = (lm_ggml_from_float_t)quantize_row_iq3_s_ref,
  },
  [LM_GGML_TYPE_IQ2_S] = {
@@ -762,7 +753,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_iq2_s),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_iq2_s,
- .from_float = quantize_row_iq2_s,
  .from_float_ref = (lm_ggml_from_float_t)quantize_row_iq2_s_ref,
  },
  [LM_GGML_TYPE_IQ1_S] = {
@@ -771,7 +761,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_iq1_s),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_iq1_s,
- .from_float = NULL,
  .from_float_ref = NULL,
  },
  [LM_GGML_TYPE_IQ1_M] = {
@@ -780,7 +769,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_iq1_m),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_iq1_m,
- .from_float = NULL,
  .from_float_ref = NULL,
  },
  [LM_GGML_TYPE_IQ4_NL] = {
@@ -789,7 +777,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_iq4_nl),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_iq4_nl,
- .from_float = quantize_row_iq4_nl,
  .from_float_ref = (lm_ggml_from_float_t)quantize_row_iq4_nl_ref,
  },
  [LM_GGML_TYPE_IQ4_XS] = {
@@ -798,7 +785,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_iq4_xs),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_iq4_xs,
- .from_float = quantize_row_iq4_xs,
  .from_float_ref = (lm_ggml_from_float_t)quantize_row_iq4_xs_ref,
  },
  [LM_GGML_TYPE_Q8_K] = {
@@ -806,7 +792,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .blck_size = QK_K,
  .type_size = sizeof(block_q8_K),
  .is_quantized = true,
- .from_float = quantize_row_q8_K,
  },
  [LM_GGML_TYPE_BF16] = {
  .type_name = "bf16",
@@ -814,7 +799,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(lm_ggml_bf16_t),
  .is_quantized = false,
  .to_float = (lm_ggml_to_float_t) lm_ggml_bf16_to_fp32_row,
- .from_float = (lm_ggml_from_float_t) lm_ggml_fp32_to_bf16_row,
  .from_float_ref = (lm_ggml_from_float_t) lm_ggml_fp32_to_bf16_row_ref,
  },
  [LM_GGML_TYPE_Q4_0_4_4] = {
@@ -824,7 +808,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_q4_0),
  .is_quantized = true,
  .to_float = NULL,
- .from_float = NULL,
  .from_float_ref = NULL,
  },
  [LM_GGML_TYPE_Q4_0_4_8] = {
@@ -834,7 +817,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_q4_0),
  .is_quantized = true,
  .to_float = NULL,
- .from_float = NULL,
  .from_float_ref = NULL,
  },
  [LM_GGML_TYPE_Q4_0_8_8] = {
@@ -844,7 +826,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_q4_0),
  .is_quantized = true,
  .to_float = NULL,
- .from_float = NULL,
  .from_float_ref = NULL,
  },
  [LM_GGML_TYPE_TQ1_0] = {
@@ -853,7 +834,6 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_tq1_0),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_tq1_0,
- .from_float = quantize_row_tq1_0,
  .from_float_ref = (lm_ggml_from_float_t) quantize_row_tq1_0_ref,
  },
  [LM_GGML_TYPE_TQ2_0] = {
@@ -862,9 +842,17 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_tq2_0),
  .is_quantized = true,
  .to_float = (lm_ggml_to_float_t) dequantize_row_tq2_0,
- .from_float = quantize_row_tq2_0,
  .from_float_ref = (lm_ggml_from_float_t) quantize_row_tq2_0_ref,
  },
+ [LM_GGML_TYPE_IQ4_NL_4_4] = {
+ .type_name = "iq4_nl_4x4",
+ .blck_size = QK4_NL,
+ .blck_size_interleave = 4,
+ .type_size = sizeof(block_iq4_nl),
+ .is_quantized = true,
+ .to_float = NULL,
+ .from_float_ref = NULL,
+ },
  };
 
  const struct lm_ggml_type_traits * lm_ggml_get_type_traits(enum lm_ggml_type type) {
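Note: throughout this table the .from_float fast-path entries are dropped; per the file list above, the optimized quantization routines now live in the CPU backend (ggml-cpu-quants.c), and ggml.c keeps only to_float and the reference from_float_ref. A hedged sketch of consuming what remains in the core table; the wrapper function is illustrative, not part of the package:

    #include <stdbool.h>
    #include <stdint.h>
    #include "ggml.h"

    // dequantize one row of `n` elements through the core type-traits table;
    // returns false if the type has no to_float entry (e.g. the interleaved
    // repacking types such as iq4_nl_4x4)
    static bool dequantize_row_via_traits(enum lm_ggml_type type,
                                          const void * src, float * dst, int64_t n) {
        const struct lm_ggml_type_traits * tt = lm_ggml_get_type_traits(type);
        if (!tt->to_float) {
            return false;
        }
        tt->to_float(src, dst, n);
        return true;
    }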
@@ -988,7 +976,7 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = {
  "WIN_UNPART",
  "GET_REL_POS",
  "ADD_REL_POS",
- "RWKV_WKV",
+ "RWKV_WKV6",
 
  "UNARY",
 
@@ -1083,7 +1071,7 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
  "win_unpart(x)",
  "get_rel_pos(x)",
  "add_rel_pos(x)",
- "rwkv_wkv(k, v, r, tf, td, s)",
+ "rwkv_wkv6(k, v, r, tf, td, s)",
 
  "unary(x)",
 
@@ -1420,11 +1408,11 @@ static inline bool lm_ggml_can_repeat_rows(const struct lm_ggml_tensor * t0, con
  ////////////////////////////////////////////////////////////////////////////////
 
  struct lm_ggml_context * lm_ggml_init(struct lm_ggml_init_params params) {
- static bool is_first_call = false;
+ static bool is_first_call = true;
 
  lm_ggml_critical_section_start();
 
- if (!is_first_call) {
+ if (is_first_call) {
  // initialize time system (required on Windows)
  lm_ggml_time_init();
 
@@ -1435,7 +1423,8 @@ struct lm_ggml_context * lm_ggml_init(struct lm_ggml_init_params params) {
  } u = {i};
  lm_ggml_table_f32_f16[i] = LM_GGML_COMPUTE_FP16_TO_FP32(u.fp16);
  }
- is_first_call = true;
+
+ is_first_call = false;
  }
 
  lm_ggml_critical_section_end();
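Note: the lm_ggml_init change straightens out the one-time initialization guard: the static flag now starts as true and is cleared after the first pass, replacing the previous double-negative. A minimal sketch of the corrected pattern, with illustrative names:

    #include <stdbool.h>

    static void do_global_setup_once(void) {
        static bool is_first_call = true;
        if (is_first_call) {
            // one-time work, e.g. timer init and filling lm_ggml_table_f32_f16
            is_first_call = false;
        }
    }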
@@ -1625,14 +1614,13 @@ static struct lm_ggml_tensor * lm_ggml_new_tensor_impl(
  /*.op =*/ LM_GGML_OP_NONE,
  /*.op_params =*/ { 0 },
  /*.flags =*/ 0,
- /*.grad =*/ NULL,
  /*.src =*/ { NULL },
  /*.view_src =*/ view_src,
  /*.view_offs =*/ view_offs,
  /*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
  /*.name =*/ { 0 },
  /*.extra =*/ NULL,
- ///*.padding =*/ { 0 },
+ /*.padding =*/ { 0 },
  };
 
  #ifdef __clang__
@@ -2289,6 +2277,7 @@ struct lm_ggml_tensor * lm_ggml_argmax(
  struct lm_ggml_context * ctx,
  struct lm_ggml_tensor * a) {
  LM_GGML_ASSERT(lm_ggml_is_matrix(a));
+ LM_GGML_ASSERT(a->ne[0] <= INT32_MAX);
 
  struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_I32, a->ne[1]);
 
@@ -3658,6 +3647,22 @@ struct lm_ggml_tensor * lm_ggml_rope_custom_inplace(
  );
  }
 
+ // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+ // `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+ static float lm_ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
+ return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
+ }
+
+ void lm_ggml_rope_yarn_corr_dims(
+ int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
+ ) {
+ // start and end correction dims
+ float start = floorf(lm_ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
+ float end = ceilf(lm_ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
+ dims[0] = MAX(0, start);
+ dims[1] = MIN(n_dims - 1, end);
+ }
+
  // lm_ggml_rope_back
 
  struct lm_ggml_tensor * lm_ggml_rope_back(
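Note: the relocated YaRN helpers map the beta_fast/beta_slow thresholds onto the range of rotation dimensions over which the interpolation ramp is applied. A hedged usage sketch; the numeric arguments are illustrative defaults, not values taken from this package:

    #include <stdio.h>
    #include "ggml.h"

    static void print_yarn_corr_dims_example(void) {
        float dims[2];
        // head dim 128, original training context 4096, RoPE base 10000,
        // beta_fast = 32, beta_slow = 1
        lm_ggml_rope_yarn_corr_dims(128, 4096, 10000.0f, 32.0f, 1.0f, dims);
        printf("YaRN correction dims: [%.1f, %.1f]\n", dims[0], dims[1]);
    }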
@@ -4156,6 +4161,7 @@ struct lm_ggml_tensor * lm_ggml_argsort(
  struct lm_ggml_context * ctx,
  struct lm_ggml_tensor * a,
  enum lm_ggml_sort_order order) {
+ LM_GGML_ASSERT(a->ne[0] <= INT32_MAX);
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_I32, LM_GGML_MAX_DIMS, a->ne);
 
  lm_ggml_set_op_params_i32(result, 0, (int32_t) order);
@@ -4211,8 +4217,6 @@ struct lm_ggml_tensor * lm_ggml_flash_attn_ext(
  LM_GGML_ASSERT(mask);
  }
 
- bool is_node = false;
-
  // permute(0, 2, 1, 3)
  int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne);
@@ -4220,8 +4224,7 @@ struct lm_ggml_tensor * lm_ggml_flash_attn_ext(
  float params[] = { scale, max_bias, logit_softcap };
  lm_ggml_set_op_params(result, params, sizeof(params));
 
- result->op = LM_GGML_OP_FLASH_ATTN_EXT;
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
+ result->op = LM_GGML_OP_FLASH_ATTN_EXT;
  result->src[0] = q;
  result->src[1] = k;
  result->src[2] = v;
@@ -4240,6 +4243,15 @@ void lm_ggml_flash_attn_ext_set_prec(
  lm_ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
  }
 
+ enum lm_ggml_prec lm_ggml_flash_attn_ext_get_prec(
+ const struct lm_ggml_tensor * a) {
+ LM_GGML_ASSERT(a->op == LM_GGML_OP_FLASH_ATTN_EXT);
+
+ const int32_t prec_i32 = lm_ggml_get_op_params_i32(a, 3);
+
+ return (enum lm_ggml_prec) prec_i32;
+ }
+
  // lm_ggml_flash_attn_back
 
  struct lm_ggml_tensor * lm_ggml_flash_attn_back(
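Note: lm_ggml_flash_attn_ext_get_prec is the read-side counterpart of the existing setter and pulls the precision back out of op_params slot 3. A hedged round-trip sketch, assuming attn is a node produced by lm_ggml_flash_attn_ext():

    #include "ggml.h"

    static enum lm_ggml_prec force_f32_attention(struct lm_ggml_tensor * attn) {
        // request full f32 accumulation for this flash-attention node ...
        lm_ggml_flash_attn_ext_set_prec(attn, LM_GGML_PREC_F32);
        // ... and read the setting back through the new getter
        return lm_ggml_flash_attn_ext_get_prec(attn);
    }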
@@ -4280,14 +4292,6 @@ struct lm_ggml_tensor * lm_ggml_flash_attn_back(
 
  LM_GGML_ASSERT(ne2 % kvne2 == 0);
 
- bool is_node = false;
-
- if (q->grad || k->grad || v->grad) {
- // when using this operation (in backwards pass) these grads are set.
- // we don't want to create (big) grad of our result, so is_node is false.
- is_node = false;
- }
-
  // store gradients of q, k and v as continuous tensors concatenated in result.
  // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
  const int64_t elem_q = lm_ggml_nelements(q);
@@ -4310,8 +4314,7 @@
  int32_t masked_i = masked ? 1 : 0;
  lm_ggml_set_op_params(result, &masked_i, sizeof(masked_i));
 
- result->op = LM_GGML_OP_FLASH_ATTN_BACK;
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
+ result->op = LM_GGML_OP_FLASH_ATTN_BACK;
  result->src[0] = q;
  result->src[1] = k;
  result->src[2] = v;
@@ -4515,9 +4518,9 @@ struct lm_ggml_tensor * lm_ggml_add_rel_pos_inplace(
  return lm_ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
  }
 
- // lm_ggml_rwkv_wkv
+ // lm_ggml_rwkv_wkv6
 
- struct lm_ggml_tensor * lm_ggml_rwkv_wkv(
+ struct lm_ggml_tensor * lm_ggml_rwkv_wkv6(
  struct lm_ggml_context * ctx,
  struct lm_ggml_tensor * k,
  struct lm_ggml_tensor * v,
@@ -4549,7 +4552,7 @@ struct lm_ggml_tensor * lm_ggml_rwkv_wkv(
  const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne);
 
- result->op = LM_GGML_OP_RWKV_WKV;
+ result->op = LM_GGML_OP_RWKV_WKV6;
  result->src[0] = k;
  result->src[1] = v;
  result->src[2] = r;
@@ -4953,34 +4956,24 @@ struct lm_ggml_tensor * lm_ggml_opt_step_adamw(
  struct lm_ggml_context * ctx,
  struct lm_ggml_tensor * a,
  struct lm_ggml_tensor * grad,
- float alpha,
- float beta1,
- float beta2,
- float eps,
- float wd) {
+ struct lm_ggml_tensor * m,
+ struct lm_ggml_tensor * v,
+ struct lm_ggml_tensor * adamw_params) {
  LM_GGML_ASSERT(a->flags & LM_GGML_TENSOR_FLAG_PARAM);
  LM_GGML_ASSERT(lm_ggml_are_same_shape(a, grad));
- LM_GGML_ASSERT(alpha > 0.0f);
- LM_GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
- LM_GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
- LM_GGML_ASSERT(eps >= 0.0f);
- LM_GGML_ASSERT(wd >= 0.0f && wd <= 1.0f);
+ LM_GGML_ASSERT(lm_ggml_are_same_shape(a, m));
+ LM_GGML_ASSERT(lm_ggml_are_same_shape(a, v));
+ LM_GGML_ASSERT(adamw_params->type == LM_GGML_TYPE_F32);
+ LM_GGML_ASSERT(lm_ggml_nelements(adamw_params) == 7);
 
  struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a);
 
- const int64_t iter = 1;
- memcpy(&result->op_params[0], &iter, sizeof(int64_t));
- lm_ggml_set_op_params_f32(result, 2, alpha);
- lm_ggml_set_op_params_f32(result, 3, beta1);
- lm_ggml_set_op_params_f32(result, 4, beta2);
- lm_ggml_set_op_params_f32(result, 5, eps);
- lm_ggml_set_op_params_f32(result, 6, wd);
-
  result->op = LM_GGML_OP_OPT_STEP_ADAMW;
  result->src[0] = a;
  result->src[1] = grad;
- result->src[2] = lm_ggml_dup_tensor(ctx, grad);
- result->src[3] = lm_ggml_dup_tensor(ctx, grad);
+ result->src[2] = m;
+ result->src[3] = v;
+ result->src[4] = adamw_params;
 
  return result;
  }
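Note: the AdamW step now takes its optimizer state as graph tensors: explicit first/second moments plus a 7-element F32 tensor of hyperparameters, instead of five scalar arguments baked into op_params. A hedged sketch of the new call shape; the meaning and ordering of the 7 values are not shown in this hunk, so they are left unfilled here:

    #include "ggml.h"

    static struct lm_ggml_tensor * build_adamw_step(
            struct lm_ggml_context * ctx,
            struct lm_ggml_tensor  * param,   // must have LM_GGML_TENSOR_FLAG_PARAM set
            struct lm_ggml_tensor  * grad) {
        // optimizer state lives in the graph now, same shape as the parameter
        struct lm_ggml_tensor * m = lm_ggml_dup_tensor(ctx, param); // first moment
        struct lm_ggml_tensor * v = lm_ggml_dup_tensor(ctx, param); // second moment
        // 7 packed scalar hyperparameters; fill them in before running the graph
        struct lm_ggml_tensor * adamw_params = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, 7);
        return lm_ggml_opt_step_adamw(ctx, param, grad, m, v, adamw_params);
    }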
@@ -5049,1112 +5042,526 @@ static void lm_ggml_hash_map_free(struct hash_map * map) {
  LM_GGML_FREE(map);
  }
 
- // gradient checkpointing
+ // utility functions to change gradients
+ // isrc is the index of tensor in cgraph->visited_has_set.keys
+ // the corresponding gradient (accumulators) are also at position isrc
+ // if tensor has a gradient accumulator, modify that accumulator in-place
+ // else if there is no gradient for tensor, set the corresponding value
+ // else, just add/subtract/etc. the gradients
 
- static struct lm_ggml_tensor * lm_ggml_recompute_graph_node(
+ static void lm_ggml_add_or_set(
  struct lm_ggml_context * ctx,
- struct lm_ggml_cgraph * graph,
- struct hash_map * replacements,
- struct lm_ggml_tensor * node) {
-
- if (node == NULL) {
- return NULL;
- }
-
- if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) {
- return node;
- }
-
- if (!lm_ggml_hash_contains(&graph->visited_hash_set, node)) {
- return node;
- }
-
- int count_children = 0;
- for (int k = 0; k < LM_GGML_MAX_SRC; ++k) {
- if (node->src[k]) {
- ++count_children;
- }
- }
-
- if (count_children == 0) {
- return node;
- }
-
- size_t i = lm_ggml_hash_find(&replacements->set, node);
- LM_GGML_ASSERT(i != LM_GGML_HASHSET_FULL); // assert that not full
- if (replacements->set.keys[i] == node) {
- return replacements->vals[i];
- }
-
- struct lm_ggml_tensor * clone = lm_ggml_new_tensor(ctx, node->type, LM_GGML_MAX_DIMS, node->ne);
-
- // insert clone into replacements
- LM_GGML_ASSERT(replacements->set.keys[i] == NULL); // assert that we don't overwrite
- replacements->set.keys[i] = node;
- replacements->vals[i] = clone;
-
- clone->op = node->op;
- clone->grad = node->grad;
- clone->flags = node->flags;
- clone->extra = node->extra;
- for (int k = 0; k < LM_GGML_MAX_DIMS; ++k) {
- clone->nb[k] = node->nb[k];
- }
- for (int k = 0; k < LM_GGML_MAX_SRC; ++k) {
- clone->src[k] = lm_ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
- }
- if (node->view_src != NULL) {
- clone->data = (node->view_src->data == NULL)
- ? NULL // view_src not yet allocated
- : (char *) node->view_src->data // view_src already allocated
- + node->view_offs;
- clone->view_src = node->view_src;
- clone->view_offs = node->view_offs;
+ struct lm_ggml_cgraph * cgraph,
+ size_t isrc,
+ struct lm_ggml_tensor * tensor) {
+ struct lm_ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
+ LM_GGML_ASSERT(src);
+ if (cgraph->grads[isrc]) {
+ cgraph->grads[isrc] = lm_ggml_add_impl(ctx, cgraph->grads[isrc], tensor, /*inplace =*/ cgraph->grad_accs[isrc]);
+ } else {
+ cgraph->grads[isrc] = tensor;
  }
-
- LM_GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (LM_GGML_MAX_OP_PARAMS / sizeof(int32_t)));
- LM_GGML_ASSERT(sizeof(node->name) == LM_GGML_MAX_NAME);
- memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
- lm_ggml_format_name(clone, "%s (clone)", lm_ggml_get_name(node));
-
- return clone;
+ lm_ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
+ lm_ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
  }
 
- void lm_ggml_build_backward_gradient_checkpointing(
- struct lm_ggml_context * ctx,
- struct lm_ggml_cgraph * gf,
- struct lm_ggml_cgraph * gb,
- struct lm_ggml_cgraph * gb_tmp,
- struct lm_ggml_tensor * * checkpoints,
- int n_checkpoints) {
- lm_ggml_graph_cpy(gf, gb_tmp);
- lm_ggml_build_backward_expand(ctx, gf, gb_tmp, false);
-
- if (n_checkpoints <= 0) {
- lm_ggml_graph_cpy(gb_tmp, gb);
- return;
- }
-
- struct hash_map * replacements = lm_ggml_new_hash_map(gf->n_nodes + gf->n_leafs + n_checkpoints);
-
- // insert checkpoints in replacements
- for (int i = 0; i < n_checkpoints; ++i) {
- size_t k = lm_ggml_hash_find(&replacements->set, checkpoints[i]);
- LM_GGML_ASSERT(k != LM_GGML_HASHSET_FULL); // assert that not full
- LM_GGML_ASSERT(replacements->set.keys[k] == NULL); // assert that we don't overwrite
- replacements->set.keys[k] = checkpoints[i];
- replacements->vals[k] = checkpoints[i];
- }
-
- lm_ggml_graph_cpy(gf, gb);
- // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
- // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
- // by recomputing them from checkpoints
- for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
- struct lm_ggml_tensor * node = gb_tmp->nodes[i];
- for (int k = 0; k < LM_GGML_MAX_SRC; ++k) {
- // insert new tensors recomputing src, reusing already made replacements,
- // remember replacements: remember new tensors with mapping from corresponding gf nodes
- // recurse for input tensors,
- // unless (i.e. terminating when) input tensors are replacements (like checkpoints)
- node->src[k] = lm_ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
- }
- // insert rewritten backward node with replacements made into resulting backward graph gb
- lm_ggml_build_forward_expand(gb, node);
+ static void lm_ggml_acc_or_set(
+ struct lm_ggml_context * ctx,
+ struct lm_ggml_cgraph * cgraph,
+ size_t isrc,
+ struct lm_ggml_tensor * tensor,
+ const size_t nb1,
+ const size_t nb2,
+ const size_t nb3,
+ const size_t offset) {
+ struct lm_ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
+ LM_GGML_ASSERT(src);
+ if (cgraph->grads[isrc]) {
+ cgraph->grads[isrc] = lm_ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]);
+ } else {
+ struct lm_ggml_tensor * a_zero = lm_ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
+ cgraph->grads[isrc] = lm_ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false);
  }
-
- lm_ggml_hash_map_free(replacements);
+ lm_ggml_format_name(cgraph->grads[isrc], "grad for %s", cgraph->visited_hash_set.keys[isrc]->name);
+ lm_ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
  }
 
- // utility functions to change gradients
- // if a is in acc_table, modify gradients in-place and mark result as gradient accumulator
- // else if a is in zero_table, replace a
- // else, just add/subtract/etc. the gradients
-
- static struct lm_ggml_tensor * lm_ggml_add_or_set(
- struct lm_ggml_context * ctx,
- struct lm_ggml_tensor * a,
- struct lm_ggml_tensor * b,
- struct lm_ggml_hash_set * zero_table,
- struct lm_ggml_hash_set * acc_table) {
- if (lm_ggml_hash_contains(acc_table, a)) {
- struct lm_ggml_tensor * ret = lm_ggml_add_impl(ctx, a, b, true);
- const size_t insert_result = lm_ggml_hash_insert(acc_table, ret);
- LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL);
- LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS);
- return ret;
- }
- if (lm_ggml_hash_contains(zero_table, a)) {
- return b;
+ static void lm_ggml_add1_or_set(
+ struct lm_ggml_context * ctx,
+ struct lm_ggml_cgraph * cgraph,
+ size_t isrc,
+ struct lm_ggml_tensor * tensor) {
+ struct lm_ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
+ LM_GGML_ASSERT(src);
+ if (cgraph->grads[isrc]) {
+ cgraph->grads[isrc] = lm_ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
+ } else {
+ cgraph->grads[isrc] = lm_ggml_repeat(ctx, tensor, src);
  }
- return lm_ggml_add_impl(ctx, a, b, false);
+ lm_ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
+ lm_ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
  }
 
- static struct lm_ggml_tensor * lm_ggml_acc_or_set(
- struct lm_ggml_context * ctx,
- struct lm_ggml_tensor * a,
- struct lm_ggml_tensor * b,
- const size_t nb1,
- const size_t nb2,
- const size_t nb3,
- const size_t offset,
- struct lm_ggml_hash_set * zero_table,
- struct lm_ggml_hash_set * acc_table) {
- if (lm_ggml_hash_contains(acc_table, a)) {
- struct lm_ggml_tensor * ret = lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
- const size_t insert_result = lm_ggml_hash_insert(acc_table, ret);
- LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL);
- LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS);
- return ret;
- }
- if (lm_ggml_hash_contains(zero_table, a)) {
- struct lm_ggml_tensor * a_zero = lm_ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
- return lm_ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
+ static void lm_ggml_sub_or_set(
+ struct lm_ggml_context * ctx,
+ struct lm_ggml_cgraph * cgraph,
+ size_t isrc,
+ struct lm_ggml_tensor * tensor) {
+ struct lm_ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
+ LM_GGML_ASSERT(src);
+ if (cgraph->grads[isrc]) {
+ cgraph->grads[isrc] = lm_ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
+ } else {
+ cgraph->grads[isrc] = lm_ggml_neg(ctx, tensor);
  }
- return lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
+ lm_ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
+ lm_ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
  }
 
5217
- static struct lm_ggml_tensor * lm_ggml_add1_or_set(
5218
- struct lm_ggml_context * ctx,
5219
- struct lm_ggml_tensor * a,
5220
- struct lm_ggml_tensor * b,
5221
- struct lm_ggml_hash_set * zero_table,
5222
- struct lm_ggml_hash_set * acc_table) {
5223
- if (lm_ggml_hash_contains(acc_table, a)) {
5224
- struct lm_ggml_tensor * ret = lm_ggml_add1_impl(ctx, a, b, true);
5225
- const size_t insert_result = lm_ggml_hash_insert(acc_table, ret);
5226
- LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL);
5227
- LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS);
5228
- return ret;
5229
- }
5230
- if (lm_ggml_hash_contains(zero_table, a)) {
5231
- return lm_ggml_repeat(ctx, b, a);
5232
- }
5233
- return lm_ggml_add1_impl(ctx, a, b, false);
5234
- }
5121
+ static void lm_ggml_compute_backward(
5122
+ struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph, int i, bool * grads_needed) {
5123
+ struct lm_ggml_tensor * tensor = cgraph->nodes[i];
5124
+ struct lm_ggml_tensor * grad = lm_ggml_graph_get_grad(cgraph, tensor);
5235
5125
 
5236
- static struct lm_ggml_tensor * lm_ggml_sub_or_set(
5237
- struct lm_ggml_context * ctx,
5238
- struct lm_ggml_tensor * a,
5239
- struct lm_ggml_tensor * b,
5240
- struct lm_ggml_hash_set * zero_table,
5241
- struct lm_ggml_hash_set * acc_table) {
5242
- if (lm_ggml_hash_contains(acc_table, a)) {
5243
- struct lm_ggml_tensor * ret = lm_ggml_sub_impl(ctx, a, b, true);
5244
- const size_t insert_result = lm_ggml_hash_insert(acc_table, ret);
5245
- LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL);
5246
- LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS);
5247
- return ret;
5248
- }
5249
- if (lm_ggml_hash_contains(zero_table, a)) {
5250
- return lm_ggml_neg(ctx, b);
5126
+ if (!grad) {
5127
+ return;
5251
5128
  }
5252
- return lm_ggml_sub_impl(ctx, a, b, false);
5253
- }
5254
5129
 
5255
- static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor, struct lm_ggml_hash_set * zero_table, struct lm_ggml_hash_set * acc_table) {
5256
5130
  struct lm_ggml_tensor * src0 = tensor->src[0];
5257
5131
  struct lm_ggml_tensor * src1 = tensor->src[1];
5258
5132
  struct lm_ggml_tensor * src2 = tensor->src[2];
5133
+ struct lm_ggml_hash_set * hash_set = &cgraph->visited_hash_set;
5134
+ const size_t isrc0 = src0 ? lm_ggml_hash_find(hash_set, src0) : (size_t) -1;
5135
+ const size_t isrc1 = src1 ? lm_ggml_hash_find(hash_set, src1) : (size_t) -1;
5136
+ const size_t isrc2 = src2 ? lm_ggml_hash_find(hash_set, src2) : (size_t) -1;
5137
+ const bool src0_needs_grads = src0 && isrc0 != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0];
5138
+ const bool src1_needs_grads = src1 && isrc1 != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1];
5139
+ const bool src2_needs_grads = src2 && isrc2 != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2];
5259
5140
 
5260
5141
  switch (tensor->op) {
5261
- case LM_GGML_OP_DUP:
5262
- {
5263
- if (src0->grad) {
5264
- src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
5265
- }
5266
- } break;
5267
- case LM_GGML_OP_ADD:
5268
- {
5269
- if (src0->grad) {
5270
- src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
5271
- }
5272
- if (src1->grad) {
5273
- if (lm_ggml_are_same_shape(src0, src1)) {
5274
- src1->grad = lm_ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table);
5275
- } else {
5276
- src1->grad = lm_ggml_add_or_set(ctx, src1->grad, lm_ggml_repeat_back(ctx, tensor->grad, src1), zero_table, acc_table);
5277
- }
5278
- }
5279
- } break;
5280
- case LM_GGML_OP_ADD1:
5281
- {
5282
- if (src0->grad) {
5283
- src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
5284
- }
5285
- if (src1->grad) {
5286
- src1->grad = lm_ggml_add_or_set(ctx,
5287
- src1->grad,
5288
- lm_ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
5289
- zero_table, acc_table);
5290
- }
5291
- } break;
5292
- case LM_GGML_OP_ACC:
5293
- {
5294
- if (src0->grad) {
5295
- src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
5296
- }
5297
- if (src1->grad) {
5298
- const size_t nb1 = ((int32_t *) tensor->op_params)[0];
5299
- const size_t nb2 = ((int32_t *) tensor->op_params)[1];
5300
- const size_t nb3 = ((int32_t *) tensor->op_params)[2];
5301
- const size_t offset = ((int32_t *) tensor->op_params)[3];
5302
-
5303
- struct lm_ggml_tensor * tensor_grad_view = lm_ggml_view_4d(ctx,
5304
- tensor->grad,
5305
- src1->grad->ne[0],
5306
- src1->grad->ne[1],
5307
- src1->grad->ne[2],
5308
- src1->grad->ne[3],
5309
- nb1, nb2, nb3, offset);
5310
-
5311
- src1->grad =
5312
- lm_ggml_add_or_set(ctx,
5313
- src1->grad,
5314
- lm_ggml_reshape(ctx,
5315
- lm_ggml_cont(ctx, tensor_grad_view),
5316
- src1->grad),
5317
- zero_table, acc_table);
5318
- }
5319
- } break;
5320
- case LM_GGML_OP_SUB:
5321
- {
5322
- if (src0->grad) {
5323
- src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
5324
- }
5325
- if (src1->grad) {
5326
- src1->grad = lm_ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table);
5327
- }
5328
- } break;
5329
- case LM_GGML_OP_MUL:
5330
- {
5331
- if (src0->grad) {
5332
- src0->grad =
5333
- lm_ggml_add_or_set(ctx,
5334
- src0->grad,
5335
- lm_ggml_mul(ctx, src1, tensor->grad),
5336
- zero_table, acc_table);
5337
- }
5338
- if (src1->grad) {
5339
- src1->grad =
5340
- lm_ggml_add_or_set(ctx,
5341
- src1->grad,
5342
- lm_ggml_mul(ctx, src0, tensor->grad),
5343
- zero_table, acc_table);
5344
- }
5345
- } break;
5346
- case LM_GGML_OP_DIV:
5347
- {
5348
- if (src0->grad) {
5349
- src0->grad =
5350
- lm_ggml_add_or_set(ctx,
5351
- src0->grad,
5352
- lm_ggml_div(ctx, tensor->grad, src1),
5353
- zero_table, acc_table);
5354
- }
5355
- if (src1->grad) {
5356
- src1->grad =
5357
- lm_ggml_sub_or_set(ctx,
5358
- src1->grad,
5359
- lm_ggml_mul(ctx,
5360
- tensor->grad,
5361
- lm_ggml_div(ctx, tensor, src1)),
5362
- zero_table, acc_table);
5363
- }
5364
- } break;
5365
- case LM_GGML_OP_SQR:
5366
- {
5367
- if (src0->grad) {
5368
- src0->grad =
5369
- lm_ggml_add_or_set(ctx,
5370
- src0->grad,
5371
- lm_ggml_scale(ctx,
5372
- lm_ggml_mul(ctx, src0, tensor->grad),
5373
- 2.0f),
5374
- zero_table, acc_table);
5375
- }
5376
- } break;
5377
- case LM_GGML_OP_SQRT:
5378
- {
5379
- if (src0->grad) {
5380
- src0->grad =
5381
- lm_ggml_add_or_set(ctx,
5382
- src0->grad,
5383
- lm_ggml_scale(ctx,
5384
- lm_ggml_div(ctx,
5385
- tensor->grad,
5386
- tensor),
5387
- 0.5f),
5388
- zero_table, acc_table);
5389
- }
5390
- } break;
5391
- case LM_GGML_OP_LOG:
5392
- {
5393
- if (src0->grad) {
5394
- src0->grad =
5395
- lm_ggml_add_or_set(ctx,
5396
- src0->grad,
5397
- lm_ggml_div(ctx,
5398
- tensor->grad,
5399
- src0),
5400
- zero_table, acc_table);
5401
- }
5402
- } break;
5403
- case LM_GGML_OP_SIN:
5404
- {
5405
- if (src0->grad) {
5406
- src0->grad =
5407
- lm_ggml_add_or_set(ctx,
5408
- src0->grad,
5409
- lm_ggml_mul(ctx,
5410
- tensor->grad,
5411
- lm_ggml_cos(ctx, src0)),
5412
- zero_table, acc_table);
5413
- }
5414
- } break;
5415
- case LM_GGML_OP_COS:
5416
- {
5417
- if (src0->grad) {
5418
- src0->grad =
5419
- lm_ggml_sub_or_set(ctx,
5420
- src0->grad,
5421
- lm_ggml_mul(ctx,
5422
- tensor->grad,
5423
- lm_ggml_sin(ctx, src0)),
5424
- zero_table, acc_table);
5425
- }
5426
- } break;
5427
- case LM_GGML_OP_SUM:
5428
- {
5429
- if (src0->grad) {
5430
- src0->grad =
5431
- lm_ggml_add1_or_set(ctx,
5432
- src0->grad,
5433
- tensor->grad,
5434
- zero_table, acc_table);
5435
- }
5436
- } break;
5437
- case LM_GGML_OP_SUM_ROWS:
5438
- {
5439
- if (src0->grad) {
5440
- src0->grad =
5441
- lm_ggml_add_or_set(ctx,
5442
- src0->grad,
5443
- lm_ggml_repeat(ctx,
5444
- tensor->grad,
5445
- src0->grad),
5446
- zero_table, acc_table);
5447
- }
5448
- } break;
5449
- case LM_GGML_OP_MEAN:
5450
- case LM_GGML_OP_ARGMAX:
5451
- case LM_GGML_OP_COUNT_EQUAL:
5452
- {
5453
- LM_GGML_ABORT("fatal error"); // TODO: implement
5454
- }
5455
- case LM_GGML_OP_REPEAT:
5456
- {
5457
- // necessary for llama
5458
- if (src0->grad) {
5459
- src0->grad = lm_ggml_add_or_set(ctx,
5460
- src0->grad,
5461
- lm_ggml_repeat_back(ctx, tensor->grad, src0->grad),
5462
- zero_table, acc_table);
5463
- }
5464
- } break;
5465
- case LM_GGML_OP_REPEAT_BACK:
5466
- {
5467
- if (src0->grad) {
5468
- // TODO: test this
5469
- src0->grad = lm_ggml_add_or_set(ctx,
5470
- src0->grad,
5471
- lm_ggml_repeat(ctx, tensor->grad, src0->grad),
5472
- zero_table, acc_table);
5473
- }
5474
- } break;
5475
- case LM_GGML_OP_CONCAT:
5476
- {
5477
- LM_GGML_ABORT("fatal error"); // TODO: implement
5478
- }
5479
- case LM_GGML_OP_SILU_BACK:
5480
- {
5481
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5142
+ case LM_GGML_OP_DUP: {
5143
+ if (src0_needs_grads) {
5144
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, grad);
5482
5145
  }
5483
- case LM_GGML_OP_NORM:
5484
- {
5485
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5146
+ } break;
5147
+ case LM_GGML_OP_ADD: {
5148
+ if (src0_needs_grads) {
5149
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, grad);
5486
5150
  }
5487
- case LM_GGML_OP_RMS_NORM:
5488
- {
5489
- // necessary for llama
5490
- if (src0->grad) {
5491
- float eps;
5492
- memcpy(&eps, tensor->op_params, sizeof(float));
5493
-
5494
- src0->grad = lm_ggml_add_or_set(ctx,
5495
- src0->grad,
5496
- lm_ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
5497
- zero_table, acc_table);
5151
+ if (src1_needs_grads) {
5152
+ struct lm_ggml_tensor * tmp = grad;
5153
+ if (!lm_ggml_are_same_shape(src0, src1)) {
5154
+ tmp = lm_ggml_repeat_back(ctx, tmp, src1);
5498
5155
  }
5499
- } break;
5500
- case LM_GGML_OP_RMS_NORM_BACK:
5501
- {
5502
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5156
+ lm_ggml_add_or_set(ctx, cgraph, isrc1, tmp);
5503
5157
  }
5504
- case LM_GGML_OP_GROUP_NORM:
5505
- {
5506
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5158
+ } break;
5159
+ case LM_GGML_OP_ADD1: {
5160
+ if (src0_needs_grads) {
5161
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, grad);
5507
5162
  }
5508
- case LM_GGML_OP_MUL_MAT:
5509
- {
5510
- // https://cs231n.github.io/optimization-2/#staged
5511
- // # forward pass
5512
- // s0 = np.random.randn(5, 10)
5513
- // s1 = np.random.randn(10, 3)
5514
- // t = s0.dot(s1)
5515
-
5516
- // # now suppose we had the gradient on t from above in the circuit
5517
- // dt = np.random.randn(*t.shape) # same shape as t
5518
- // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
5519
- // ds1 = t.T.dot(dt)
5520
-
5521
- // tensor.shape [m,p,qq,rr]
5522
- // src0.shape [n,m,q1,r1]
5523
- // src1.shape [n,p,qq,rr]
5524
-
5525
- // necessary for llama
5526
- if (src0->grad) {
5527
- struct lm_ggml_tensor * s1_tg =
5528
- lm_ggml_out_prod(ctx, // [n,m,qq,rr]
5529
- src1, // [n,p,qq,rr]
5530
- tensor->grad); // [m,p,qq,rr]
5531
- const int64_t qq = s1_tg->ne[2];
5532
- const int64_t rr = s1_tg->ne[3];
5533
- const int64_t q1 = src0->ne[2];
5534
- const int64_t r1 = src0->ne[3];
5535
- const bool ne2_broadcasted = qq > q1;
5536
- const bool ne3_broadcasted = rr > r1;
5537
- if (ne2_broadcasted || ne3_broadcasted) {
5538
- // sum broadcast repetitions of s1_tg into shape of src0
5539
- s1_tg = lm_ggml_repeat_back(ctx, s1_tg, src0);
5540
- }
5541
- src0->grad =
5542
- lm_ggml_add_or_set(ctx,
5543
- src0->grad, // [n,m,q1,r1]
5544
- s1_tg, // [n,m,q1,r1]
5545
- zero_table, acc_table);
5546
- }
5547
- if (src1->grad) {
5548
- src1->grad =
5549
- lm_ggml_add_or_set(ctx,
5550
- src1->grad, // [n,p,qq,rr]
5551
- // lm_ggml_mul_mat(ctx, // [n,p,qq,rr]
5552
- // lm_ggml_cont(ctx, // [m,n,q1,r1]
5553
- // lm_ggml_transpose(ctx, src0)), // [m,n,q1,r1]
5554
- // tensor->grad), // [m,p,qq,rr]
5555
-
5556
- // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
5557
- // // avoid transpose of src0, rather transpose smaller tensor->grad
5558
- // // and then use lm_ggml_out_prod
5559
- lm_ggml_out_prod(ctx, // [n,p,qq,rr]
5560
- src0, // [n,m,q1,r1]
5561
- lm_ggml_transpose(ctx, // [p,m,qq,rr]
5562
- tensor->grad)), // [m,p,qq,rr]
5563
- zero_table, acc_table);
5564
- }
5565
- } break;
5566
- case LM_GGML_OP_MUL_MAT_ID:
5567
- {
5568
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5163
+ if (src1_needs_grads) {
5164
+ lm_ggml_add_or_set(ctx, cgraph, isrc1, lm_ggml_mean(ctx, grad)); // TODO: should probably be sum instead of mean
5569
5165
  }
5570
- case LM_GGML_OP_OUT_PROD:
5571
- {
5572
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5166
+ } break;
5167
+ case LM_GGML_OP_ACC: {
5168
+ if (src0_needs_grads) {
5169
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, grad);
5573
5170
  }
5574
- case LM_GGML_OP_SCALE:
5575
- {
5576
- // necessary for llama
5577
- if (src0->grad) {
5578
- float s;
5579
- memcpy(&s, tensor->op_params, sizeof(float));
5580
-
5581
- src0->grad =
5582
- lm_ggml_add_or_set(ctx,
5583
- src0->grad,
5584
- lm_ggml_scale_impl(ctx, tensor->grad, s, false),
5585
- zero_table, acc_table);
5586
- }
5587
- } break;
5588
- case LM_GGML_OP_SET:
5589
- {
5590
- const size_t nb1 = ((int32_t *) tensor->op_params)[0];
5591
- const size_t nb2 = ((int32_t *) tensor->op_params)[1];
5592
- const size_t nb3 = ((int32_t *) tensor->op_params)[2];
5593
- const size_t offset = ((int32_t *) tensor->op_params)[3];
5594
-
5595
- struct lm_ggml_tensor * tensor_grad_view = NULL;
5596
-
5597
- if (src0->grad || src1->grad) {
5598
- LM_GGML_ASSERT(src0->type == tensor->type);
5599
- LM_GGML_ASSERT(tensor->grad->type == tensor->type);
5600
- LM_GGML_ASSERT(!src1->grad || src1->grad->type == tensor->grad->type);
5601
-
5602
- tensor_grad_view = lm_ggml_view_4d(ctx,
5603
- tensor->grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
5604
- nb1, nb2, nb3, offset);
5605
- }
5606
-
5607
- if (src0->grad) {
5608
- src0->grad = lm_ggml_add_or_set(ctx,
5609
- src0->grad,
5610
- lm_ggml_acc_impl(ctx,
5611
- tensor->grad,
5612
- lm_ggml_neg(ctx, tensor_grad_view),
5613
- nb1, nb2, nb3, offset, false),
5614
- zero_table, acc_table);
5615
- }
5171
+ if (src1_needs_grads) {
5172
+ const size_t nb1 = ((int32_t *) tensor->op_params)[0];
5173
+ const size_t nb2 = ((int32_t *) tensor->op_params)[1];
5174
+ const size_t nb3 = ((int32_t *) tensor->op_params)[2];
5175
+ const size_t offset = ((int32_t *) tensor->op_params)[3];
5616
5176
 
5617
- if (src1->grad) {
5618
- src1->grad =
5619
- lm_ggml_add_or_set(ctx,
5620
- src1->grad,
5621
- lm_ggml_reshape(ctx,
5622
- lm_ggml_cont(ctx, tensor_grad_view),
5623
- src1->grad),
5624
- zero_table, acc_table);
5625
- }
5626
- } break;
5627
- case LM_GGML_OP_CPY:
5628
- {
5629
- // necessary for llama
5630
- // cpy overwrites value of src1 by src0 and returns view(src1)
5631
- // the overwriting is mathematically equivalent to:
5632
- // tensor = src0 * 1 + src1 * 0
5633
- if (src0->grad) {
5634
- // dsrc0 = dtensor * 1
5635
- src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
5636
- }
5637
- if (src1->grad) {
5638
- // dsrc1 = dtensor * 0 -> noop
5639
- }
5640
- } break;
5641
- case LM_GGML_OP_CONT:
5642
- {
5643
- // same as cpy
5644
- if (src0->grad) {
5645
- LM_GGML_ASSERT(lm_ggml_is_contiguous(src0->grad));
5646
- LM_GGML_ASSERT(lm_ggml_is_contiguous(tensor->grad));
5647
- src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
5648
- }
5649
- } break;
5650
- case LM_GGML_OP_RESHAPE:
5651
- {
5652
- // necessary for llama
5653
- if (src0->grad) {
5654
- src0->grad =
5655
- lm_ggml_add_or_set(ctx, src0->grad,
5656
- lm_ggml_reshape(ctx,
5657
- lm_ggml_is_contiguous(tensor->grad)
5658
- ? tensor->grad
5659
- : lm_ggml_cont(ctx, tensor->grad),
5660
- src0->grad),
5661
- zero_table, acc_table);
5662
- }
5663
- } break;
5664
- case LM_GGML_OP_VIEW:
5665
- {
5666
- // necessary for llama
5667
- if (src0->grad) {
5668
- size_t offset;
5669
-
5670
- memcpy(&offset, tensor->op_params, sizeof(offset));
5671
-
5672
- size_t nb1 = tensor->nb[1];
5673
- size_t nb2 = tensor->nb[2];
5674
- size_t nb3 = tensor->nb[3];
5675
-
5676
- if (src0->type != src0->grad->type) {
5677
- // gradient is typically F32, but src0 could be other type
5678
- size_t ng = lm_ggml_element_size(src0->grad);
5679
- size_t n0 = lm_ggml_element_size(src0);
5680
- LM_GGML_ASSERT(offset % n0 == 0);
5681
- LM_GGML_ASSERT(nb1 % n0 == 0);
5682
- LM_GGML_ASSERT(nb2 % n0 == 0);
5683
- LM_GGML_ASSERT(nb3 % n0 == 0);
5684
- offset = (offset / n0) * ng;
5685
- nb1 = (nb1 / n0) * ng;
5686
- nb2 = (nb2 / n0) * ng;
5687
- nb3 = (nb3 / n0) * ng;
5688
- }
5177
+ struct lm_ggml_tensor * tensor_grad_view = lm_ggml_view_4d(ctx,
5178
+ grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
5179
+ nb1, nb2, nb3, offset);
5689
5180
 
5690
- src0->grad = lm_ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table, acc_table);
5691
- }
5692
- } break;
5693
- case LM_GGML_OP_PERMUTE:
5694
- {
5695
- // necessary for llama
5696
- if (src0->grad) {
5697
- int32_t * axes = (int32_t *) tensor->op_params;
5698
- int axis0 = axes[0] & 0x3;
5699
- int axis1 = axes[1] & 0x3;
5700
- int axis2 = axes[2] & 0x3;
5701
- int axis3 = axes[3] & 0x3;
5702
- int axes_backward[4] = {0,0,0,0};
5703
- axes_backward[axis0] = 0;
5704
- axes_backward[axis1] = 1;
5705
- axes_backward[axis2] = 2;
5706
- axes_backward[axis3] = 3;
5707
- src0->grad =
5708
- lm_ggml_add_or_set(ctx, src0->grad,
5709
- lm_ggml_permute(ctx,
5710
- tensor->grad,
5711
- axes_backward[0],
5712
- axes_backward[1],
5713
- axes_backward[2],
5714
- axes_backward[3]),
5715
- zero_table, acc_table);
5716
- }
5717
- } break;
5718
- case LM_GGML_OP_TRANSPOSE:
5719
- {
5720
- // necessary for llama
5721
- if (src0->grad) {
5722
- src0->grad =
5723
- lm_ggml_add_or_set(ctx, src0->grad,
5724
- lm_ggml_transpose(ctx, tensor->grad),
5725
- zero_table, acc_table);
5726
- }
5727
- } break;
5728
- case LM_GGML_OP_GET_ROWS:
5729
- {
5730
- // necessary for llama (only for tokenizer)
5731
- if (src0->grad) {
5732
- src0->grad =
5733
- lm_ggml_add_or_set(ctx, src0->grad,
5734
- // last lm_ggml_get_rows_back argument src0->grad is only
5735
- // necessary to setup correct output shape
5736
- lm_ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
5737
- zero_table, acc_table);
5738
- }
5739
- if (src1->grad) {
5740
- // noop
5741
- }
5742
- } break;
5743
- case LM_GGML_OP_GET_ROWS_BACK:
5744
- {
5745
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5181
+ lm_ggml_add_or_set(ctx, cgraph, isrc1, lm_ggml_reshape(ctx, lm_ggml_cont(ctx, tensor_grad_view), src1));
5746
5182
  }
5747
- case LM_GGML_OP_DIAG:
5748
- {
5749
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5183
+ } break;
5184
+ case LM_GGML_OP_SUB: {
5185
+ if (src0_needs_grads) {
5186
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, grad);
5750
5187
  }
5751
- case LM_GGML_OP_DIAG_MASK_INF:
5752
- {
5753
- // necessary for llama
5754
- if (src0->grad) {
5755
- const int n_past = ((int32_t *) tensor->op_params)[0];
5756
- src0->grad =
5757
- lm_ggml_add_or_set(ctx, src0->grad,
5758
- /* lm_ggml_diag_mask_inf_impl() shouldn't be here */
5759
- /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
5760
- lm_ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
5761
- zero_table, acc_table);
5762
- }
5763
- } break;
5764
- case LM_GGML_OP_DIAG_MASK_ZERO:
5765
- {
5766
- // necessary for llama
5767
- if (src0->grad) {
5768
- const int n_past = ((int32_t *) tensor->op_params)[0];
5769
- src0->grad =
5770
- lm_ggml_add_or_set(ctx, src0->grad,
5771
- lm_ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
5772
- zero_table, acc_table);
5773
- }
5774
- } break;
5775
- case LM_GGML_OP_SOFT_MAX:
5776
- {
5777
- // necessary for llama
5778
- if (src0->grad) {
5779
- src0->grad =
5780
- lm_ggml_add_or_set(ctx, src0->grad,
5781
- lm_ggml_soft_max_back(ctx, tensor->grad, tensor),
5782
- zero_table, acc_table);
5783
- }
5784
- LM_GGML_ASSERT((!src1 || !src1->grad) && "backward pass for softmax mask not implemented");
5785
- } break;
5786
- case LM_GGML_OP_SOFT_MAX_BACK:
5787
- {
5788
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5188
+ if (src1_needs_grads) {
5189
+ lm_ggml_sub_or_set(ctx, cgraph, isrc1, grad);
5789
5190
  }
5790
- case LM_GGML_OP_ROPE:
5791
- {
5792
- // necessary for llama
5793
- if (src0->grad) {
5794
- //const int n_past = ((int32_t *) tensor->op_params)[0];
5795
- const int n_dims = ((int32_t *) tensor->op_params)[1];
5796
- const int mode = ((int32_t *) tensor->op_params)[2];
5797
- //const int n_ctx = ((int32_t *) tensor->op_params)[3];
5798
- const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
5799
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5800
-
5801
- memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
5802
- memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
5803
- memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float));
5804
- memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
5805
- memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
5806
- memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
5807
-
5808
- src0->grad = lm_ggml_add_or_set(ctx,
5809
- src0->grad,
5810
- lm_ggml_rope_back(ctx,
5811
- tensor->grad,
5812
- src1,
5813
- src2,
5814
- n_dims,
5815
- mode,
5816
- n_ctx_orig,
5817
- freq_base,
5818
- freq_scale,
5819
- ext_factor,
5820
- attn_factor,
5821
- beta_fast,
5822
- beta_slow),
5823
- zero_table, acc_table);
5824
- }
5825
- LM_GGML_ASSERT((!src2 || !src2->grad) && "gradients for freq factors not implemented");
5826
- } break;
5827
- case LM_GGML_OP_ROPE_BACK:
5828
- {
5829
- if (src0->grad) {
5830
- //const int n_past = ((int32_t *) tensor->op_params)[0];
5831
- const int n_dims = ((int32_t *) tensor->op_params)[1];
5832
- const int mode = ((int32_t *) tensor->op_params)[2];
5833
- //const int n_ctx = ((int32_t *) tensor->op_params)[3];
5834
- const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
5835
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5836
-
5837
- memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
5838
- memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
5839
- memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float));
5840
- memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
5841
- memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
5842
- memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
5843
-
5844
- src0->grad = lm_ggml_add_or_set(ctx,
5845
- src0->grad,
5846
- lm_ggml_rope_impl(ctx,
5847
- tensor->grad,
5848
- src1,
5849
- src2,
5850
- n_dims,
5851
- mode,
5852
- n_ctx_orig,
5853
- freq_base,
5854
- freq_scale,
5855
- ext_factor,
5856
- attn_factor,
5857
- beta_fast,
5858
- beta_slow,
5859
- false),
5860
- zero_table, acc_table);
5191
+ } break;
5192
+ case LM_GGML_OP_MUL: {
5193
+ if (src0_needs_grads) {
5194
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_mul(ctx, src1, grad));
5195
+ }
5196
+ if (src1_needs_grads) {
5197
+ struct lm_ggml_tensor * tmp = lm_ggml_mul(ctx, src0, grad);
5198
+ if (!lm_ggml_are_same_shape(src0, src1)) {
5199
+ tmp = lm_ggml_repeat_back(ctx, tmp, src1);
5861
5200
  }
5862
- } break;
5863
- case LM_GGML_OP_CLAMP:
5864
- {
5865
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5201
+ lm_ggml_add_or_set(ctx, cgraph, isrc1, tmp);
5866
5202
  }
5867
- case LM_GGML_OP_CONV_TRANSPOSE_1D:
5868
- {
5869
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5203
+ } break;
5204
+ case LM_GGML_OP_DIV: {
5205
+ if (src0_needs_grads) {
5206
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_div(ctx, grad, src1));
5870
5207
  }
5871
- case LM_GGML_OP_IM2COL:
5872
- {
5873
- if (src1->grad) {
5874
- const int32_t s0 = lm_ggml_get_op_params_i32(tensor, 0);
5875
- const int32_t s1 = lm_ggml_get_op_params_i32(tensor, 1);
5876
- const int32_t p0 = lm_ggml_get_op_params_i32(tensor, 2);
5877
- const int32_t p1 = lm_ggml_get_op_params_i32(tensor, 3);
5878
- const int32_t d0 = lm_ggml_get_op_params_i32(tensor, 4);
5879
- const int32_t d1 = lm_ggml_get_op_params_i32(tensor, 5);
5880
- const bool is_2D = lm_ggml_get_op_params_i32(tensor, 6) == 1;
5881
-
5882
- src1->grad = lm_ggml_add_or_set(ctx,
5883
- src1->grad,
5884
- lm_ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D),
5885
- zero_table, acc_table);
5886
- }
5887
- } break;
5888
- case LM_GGML_OP_IM2COL_BACK:
5889
- {
5890
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5208
+ if (src1_needs_grads) {
5209
+ lm_ggml_sub_or_set(ctx, cgraph, isrc1, lm_ggml_mul(ctx, grad, lm_ggml_div(ctx, tensor, src1)));
5891
5210
  }
5892
- case LM_GGML_OP_CONV_TRANSPOSE_2D:
5893
- {
5894
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5211
+ } break;
5212
+ case LM_GGML_OP_SQR: {
5213
+ if (src0_needs_grads) {
5214
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_scale(ctx, lm_ggml_mul(ctx, src0, grad), 2.0f));
5895
5215
  }
5896
- case LM_GGML_OP_POOL_1D:
5897
- {
5898
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5216
+ } break;
5217
+ case LM_GGML_OP_SQRT: {
5218
+ if (src0_needs_grads) {
5219
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_scale(ctx, lm_ggml_div(ctx, grad, tensor), 0.5f));
5899
5220
  }
5900
- case LM_GGML_OP_POOL_2D:
5901
- {
5902
- if (src0->grad) {
5903
- const enum lm_ggml_op_pool op = lm_ggml_get_op_params_i32(tensor, 0);
5904
- const int32_t k0 = lm_ggml_get_op_params_i32(tensor, 1);
5905
- const int32_t k1 = lm_ggml_get_op_params_i32(tensor, 2);
5906
- const int32_t s0 = lm_ggml_get_op_params_i32(tensor, 3);
5907
- const int32_t s1 = lm_ggml_get_op_params_i32(tensor, 4);
5908
- const int32_t p0 = lm_ggml_get_op_params_i32(tensor, 5);
5909
- const int32_t p1 = lm_ggml_get_op_params_i32(tensor, 6);
5910
-
5911
- src0->grad = lm_ggml_add_or_set(ctx,
5912
- src0->grad,
5913
- lm_ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1),
5914
- zero_table, acc_table);
5915
- }
5916
- } break;
5917
- case LM_GGML_OP_POOL_2D_BACK:
5918
- {
5919
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5221
+ } break;
5222
+ case LM_GGML_OP_LOG: {
5223
+ if (src0_needs_grads) {
5224
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_div(ctx, grad, src0));
5920
5225
  }
5921
- case LM_GGML_OP_UPSCALE:
5922
- {
5923
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5226
+ } break;
5227
+ case LM_GGML_OP_SIN: {
5228
+ if (src0_needs_grads) {
5229
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_mul(ctx, grad, lm_ggml_cos(ctx, src0)));
5924
5230
  }
5925
- case LM_GGML_OP_PAD:
5926
- {
5927
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5231
+ } break;
5232
+ case LM_GGML_OP_COS: {
5233
+ if (src0_needs_grads) {
5234
+ lm_ggml_sub_or_set(ctx, cgraph, isrc0, lm_ggml_mul(ctx, grad, lm_ggml_sin(ctx, src0)));
5928
5235
  }
5929
- case LM_GGML_OP_ARANGE:
5930
- {
5931
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5236
+ } break;
5237
+ case LM_GGML_OP_SUM: {
5238
+ if (src0_needs_grads) {
5239
+ lm_ggml_add1_or_set(ctx, cgraph, isrc0, grad);
5932
5240
  }
5933
- case LM_GGML_OP_TIMESTEP_EMBEDDING:
5934
- {
5935
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5241
+ } break;
5242
+ case LM_GGML_OP_SUM_ROWS: {
5243
+ if (src0_needs_grads) {
5244
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_repeat(ctx, grad, src0));
5936
5245
  }
5937
- case LM_GGML_OP_ARGSORT:
5938
- {
5939
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5246
+ } break;
5247
+ case LM_GGML_OP_MEAN: {
5248
+ if (src0_needs_grads) {
5249
+ lm_ggml_add1_or_set(ctx, cgraph, isrc0, lm_ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
5940
5250
  }
5941
- case LM_GGML_OP_LEAKY_RELU:
5942
- {
5943
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5251
+ } break;
5252
+ case LM_GGML_OP_REPEAT: {
5253
+ if (src0_needs_grads) {
5254
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_repeat_back(ctx, grad, src0));
5944
5255
  }
5945
- case LM_GGML_OP_FLASH_ATTN_EXT:
5946
- {
5947
- LM_GGML_ABORT("FA backward pass not adapted after rework");
5948
- struct lm_ggml_tensor * flash_grad = NULL;
5949
- if (src0->grad || src1->grad || tensor->src[2]->grad) {
5950
- int32_t t = lm_ggml_get_op_params_i32(tensor, 0);
5951
- LM_GGML_ASSERT(t == 0 || t == 1);
5952
- bool masked = t != 0;
5953
- flash_grad =
5954
- lm_ggml_flash_attn_back(ctx,
5955
- src0,
5956
- src1,
5957
- tensor->src[2],
5958
- tensor->grad,
5959
- masked);
5256
+ } break;
5257
+ case LM_GGML_OP_REPEAT_BACK: {
5258
+ if (src0_needs_grads) {
5259
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_repeat(ctx, grad, src0));
5260
+ }
5261
+ } break;
5262
+ case LM_GGML_OP_RMS_NORM: {
5263
+ if (src0_needs_grads) {
5264
+ float eps;
5265
+ memcpy(&eps, tensor->op_params, sizeof(float));
5266
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_rms_norm_back(ctx, src0, grad, eps));
5267
+ }
5268
+ } break;
5269
+ case LM_GGML_OP_MUL_MAT: {
5270
+ // https://cs231n.github.io/optimization-2/#staged
5271
+ // # forward pass
5272
+ // s0 = np.random.randn(5, 10)
5273
+ // s1 = np.random.randn(10, 3)
5274
+ // t = s0.dot(s1)
5275
+
5276
+ // # now suppose we had the gradient on t from above in the circuit
5277
+ // dt = np.random.randn(*t.shape) # same shape as t
5278
+ // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
5279
+ // ds1 = t.T.dot(dt)
5280
+
5281
+ // tensor.shape [m,p,qq,rr]
5282
+ // src0.shape [n,m,q1,r1]
5283
+ // src1.shape [n,p,qq,rr]
5284
+
5285
+ if (src0_needs_grads) {
5286
+ struct lm_ggml_tensor * s1_tg =
5287
+ lm_ggml_out_prod(ctx, // [n,m,qq,rr]
5288
+ src1, // [n,p,qq,rr]
5289
+ grad); // [m,p,qq,rr]
5290
+ const int64_t qq = s1_tg->ne[2];
5291
+ const int64_t rr = s1_tg->ne[3];
5292
+ const int64_t q1 = src0->ne[2];
5293
+ const int64_t r1 = src0->ne[3];
5294
+ const bool ne2_broadcasted = qq > q1;
5295
+ const bool ne3_broadcasted = rr > r1;
5296
+ if (ne2_broadcasted || ne3_broadcasted) {
5297
+ // sum broadcast repetitions of s1_tg into shape of src0
5298
+ s1_tg = lm_ggml_repeat_back(ctx, s1_tg, src0);
5960
5299
  }
5300
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
5301
+ }
5302
+ if (src1_needs_grads) {
5303
+ lm_ggml_add_or_set(ctx, cgraph, isrc1,
5304
+ // lm_ggml_mul_mat(ctx, // [n,p,qq,rr]
5305
+ // lm_ggml_cont(ctx, // [m,n,q1,r1]
5306
+ // lm_ggml_transpose(ctx, src0)), // [m,n,q1,r1]
5307
+ // grad), // [m,p,qq,rr]
5308
+
5309
+ // when src0 is bigger than tensor->grad (this is mostly the case in llama),
5310
+ // avoid transpose of src0, rather transpose smaller tensor->grad
5311
+ // and then use lm_ggml_out_prod
5312
+ lm_ggml_out_prod(ctx, // [n,p,qq,rr]
5313
+ src0, // [n,m,q1,r1]
5314
+ lm_ggml_transpose(ctx, // [p,m,qq,rr]
5315
+ grad))); // [m,p,qq,rr]
5316
+ }
5317
+ } break;
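Restating the cs231n-style comment above in matrix form (notation mine): with t = s0 . s1 and upstream gradient g = dL/dt,

    dL/ds0 = g . s1^T        dL/ds1 = s0^T . g

The first product is what the out_prod(src1, grad) call builds for src0's gradient, followed by repeat_back when the ne[2]/ne[3] dimensions were broadcast; the second is obtained as out_prod(src0, transpose(grad)) rather than through the commented-out mul_mat route, so the larger src0 never has to be transposed and made contiguous.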
5318
+ case LM_GGML_OP_SCALE: {
5319
+ if (src0_needs_grads) {
5320
+ float s;
5321
+ memcpy(&s, tensor->op_params, sizeof(float));
5322
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_scale_impl(ctx, grad, s, false));
5323
+ }
5324
+ } break;
5325
+ case LM_GGML_OP_SET: {
5326
+ const size_t nb1 = ((const int32_t *) tensor->op_params)[0];
5327
+ const size_t nb2 = ((const int32_t *) tensor->op_params)[1];
5328
+ const size_t nb3 = ((const int32_t *) tensor->op_params)[2];
5329
+ const size_t offset = ((const int32_t *) tensor->op_params)[3];
5330
+
5331
+ struct lm_ggml_tensor * tensor_grad_view = NULL;
5332
+
5333
+ if (src0_needs_grads || src1_needs_grads) {
5334
+ LM_GGML_ASSERT(src0->type == tensor->type);
5335
+ LM_GGML_ASSERT(!cgraph->grads[isrc0] || cgraph->grads[isrc0]->type == grad->type);
5336
+ LM_GGML_ASSERT(!cgraph->grads[isrc1] || !src1_needs_grads || cgraph->grads[isrc1]->type == grad->type);
5337
+
5338
+ tensor_grad_view = lm_ggml_view_4d(ctx,
5339
+ grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
5340
+ nb1, nb2, nb3, offset);
5341
+ }
5961
5342
 
5962
- const int64_t elem_q = lm_ggml_nelements(src0);
5963
- const int64_t elem_k = lm_ggml_nelements(src1);
5964
- const int64_t elem_v = lm_ggml_nelements(src2);
5965
-
5966
- enum lm_ggml_type result_type = flash_grad->type;
5967
- LM_GGML_ASSERT(lm_ggml_blck_size(result_type) == 1);
5968
- const size_t tsize = lm_ggml_type_size(result_type);
5969
-
5970
- const size_t offs_q = 0;
5971
- const size_t offs_k = offs_q + LM_GGML_PAD(elem_q * tsize, LM_GGML_MEM_ALIGN);
5972
- const size_t offs_v = offs_k + LM_GGML_PAD(elem_k * tsize, LM_GGML_MEM_ALIGN);
5973
-
5974
- if (src0->grad) {
5975
- struct lm_ggml_tensor * view_q = lm_ggml_view_1d(ctx, flash_grad, elem_q, offs_q);
5976
- struct lm_ggml_tensor * grad_q = lm_ggml_reshape(ctx, view_q, src0);
5977
- src0->grad = lm_ggml_add_or_set(ctx,
5978
- src0->grad,
5979
- grad_q,
5980
- zero_table, acc_table);
5981
- }
5982
- if (src1->grad) {
5983
- struct lm_ggml_tensor * view_k = lm_ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
5984
- struct lm_ggml_tensor * grad_k = lm_ggml_reshape(ctx, view_k, src1);
5985
- src1->grad = lm_ggml_add_or_set(ctx,
5986
- src1->grad,
5987
- grad_k,
5988
- zero_table, acc_table);
5989
- }
5990
- if (src2->grad) {
5991
- struct lm_ggml_tensor * view_v = lm_ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
5992
- struct lm_ggml_tensor * grad_v = lm_ggml_reshape(ctx, view_v, src2);
5993
- src2->grad = lm_ggml_add_or_set(ctx,
5994
- src2->grad,
5995
- grad_v,
5996
- zero_table, acc_table);
5343
+ if (src0_needs_grads) {
5344
+ struct lm_ggml_tensor * tmp = lm_ggml_neg(ctx, tensor_grad_view);
5345
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_acc_impl(ctx, grad, tmp, nb1, nb2, nb3, offset, false));
5346
+ }
5347
+
5348
+ if (src1_needs_grads) {
5349
+ lm_ggml_add_or_set(ctx, cgraph, isrc1, lm_ggml_reshape(ctx, lm_ggml_cont(ctx, tensor_grad_view), src1));
5350
+ }
5351
+ } break;
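Put differently (restating the code above): SET writes src1 into a byte window of src0 described by nb1/nb2/nb3 and offset, so the gradient flowing to src1 is exactly that window of grad (viewed, made contiguous, and reshaped to src1), while the gradient flowing to src0 is grad with that same window zeroed out, implemented by accumulating the negated view back into grad.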
5352
+ case LM_GGML_OP_CPY: {
5353
+ // cpy overwrites value of src1 by src0 and returns view(src1)
5354
+ // the overwriting is mathematically equivalent to:
5355
+ // tensor = src0 * 1 + src1 * 0
5356
+ if (src0_needs_grads) {
5357
+ // dsrc0 = dtensor * 1
5358
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, grad);
5359
+ }
5360
+ if (src1_needs_grads) {
5361
+ // dsrc1 = dtensor * 0 -> noop
5362
+ }
5363
+ } break;
5364
+ case LM_GGML_OP_CONT: {
5365
+ // same as cpy
5366
+ if (src0_needs_grads) {
5367
+ LM_GGML_ASSERT(!cgraph->grads[isrc0] || lm_ggml_is_contiguous(cgraph->grads[isrc0]));
5368
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(grad));
5369
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, grad);
5370
+ }
5371
+ } break;
5372
+ case LM_GGML_OP_RESHAPE: {
5373
+ if (src0_needs_grads) {
5374
+ struct lm_ggml_tensor * grad_cont = lm_ggml_is_contiguous(grad) ? grad : lm_ggml_cont(ctx, grad);
5375
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_reshape(ctx, grad_cont, src0));
5376
+ }
5377
+ } break;
5378
+ case LM_GGML_OP_VIEW: {
5379
+ if (src0_needs_grads) {
5380
+ size_t offset;
5381
+
5382
+ memcpy(&offset, tensor->op_params, sizeof(offset));
5383
+
5384
+ size_t nb1 = tensor->nb[1];
5385
+ size_t nb2 = tensor->nb[2];
5386
+ size_t nb3 = tensor->nb[3];
5387
+
5388
+ if (cgraph->grads[isrc0] && src0->type != cgraph->grads[isrc0]->type) {
5389
+ // gradient is typically F32, but src0 could be other type
5390
+ size_t ng = lm_ggml_element_size(cgraph->grads[isrc0]);
5391
+ size_t n0 = lm_ggml_element_size(src0);
5392
+ LM_GGML_ASSERT(offset % n0 == 0);
5393
+ LM_GGML_ASSERT(nb1 % n0 == 0);
5394
+ LM_GGML_ASSERT(nb2 % n0 == 0);
5395
+ LM_GGML_ASSERT(nb3 % n0 == 0);
5396
+ offset = (offset / n0) * ng;
5397
+ nb1 = (nb1 / n0) * ng;
5398
+ nb2 = (nb2 / n0) * ng;
5399
+ nb3 = (nb3 / n0) * ng;
5997
5400
  }
5998
- } break;
5999
- case LM_GGML_OP_FLASH_ATTN_BACK:
6000
- {
6001
- LM_GGML_ABORT("fatal error"); // not supported
5401
+
5402
+ lm_ggml_acc_or_set(ctx, cgraph, isrc0, grad, nb1, nb2, nb3, offset);
6002
5403
  }
6003
- case LM_GGML_OP_SSM_CONV:
6004
- case LM_GGML_OP_SSM_SCAN:
6005
- {
6006
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5404
+ } break;
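A quick numeric illustration of the rescaling above (illustrative values, not taken from the diff): if src0 is F16 (element size 2) while its gradient is stored as F32 (element size 4), a view with byte offset 128 and row stride nb1 = 64 in src0 maps to offset (128 / 2) * 4 = 256 and nb1 = (64 / 2) * 4 = 128 in the gradient tensor, so the accumulated region covers the same logical elements.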
5405
+ case LM_GGML_OP_PERMUTE: {
5406
+ if (src0_needs_grads) {
5407
+ const int32_t * axes = (const int32_t *) tensor->op_params;
5408
+ const int axis0 = axes[0] & 0x3;
5409
+ const int axis1 = axes[1] & 0x3;
5410
+ const int axis2 = axes[2] & 0x3;
5411
+ const int axis3 = axes[3] & 0x3;
5412
+ int axb[4] = {0,0,0,0}; // axes backward
5413
+ axb[axis0] = 0;
5414
+ axb[axis1] = 1;
5415
+ axb[axis2] = 2;
5416
+ axb[axis3] = 3;
5417
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_permute(ctx, grad, axb[0], axb[1], axb[2], axb[3]));
5418
+ }
5419
+ } break;
5420
+ case LM_GGML_OP_TRANSPOSE: {
5421
+ if (src0_needs_grads) {
5422
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_transpose(ctx, grad));
5423
+ }
5424
+ } break;
5425
+ case LM_GGML_OP_GET_ROWS: {
5426
+ if (src0_needs_grads) {
5427
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_get_rows_back(ctx, grad, src1, src0));
5428
+ }
5429
+ if (src1_needs_grads) {
5430
+ // noop
6007
5431
  }
5432
+ } break;
5433
+ case LM_GGML_OP_DIAG_MASK_INF: {
5434
+ if (src0_needs_grads) {
5435
+ /* lm_ggml_diag_mask_inf_impl() shouldn't be here */
5436
+ /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
5437
+ const int n_past = ((const int32_t *) tensor->op_params)[0];
5438
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_diag_mask_zero_impl(ctx, grad, n_past, false));
5439
+ }
5440
+ } break;
5441
+ case LM_GGML_OP_DIAG_MASK_ZERO: {
5442
+ if (src0_needs_grads) {
5443
+ const int n_past = ((const int32_t *) tensor->op_params)[0];
5444
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_diag_mask_zero_impl(ctx, grad, n_past, false));
5445
+ }
5446
+ } break;
5447
+ case LM_GGML_OP_SOFT_MAX: {
5448
+ if (src0_needs_grads) {
5449
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_soft_max_back(ctx, grad, tensor));
5450
+ }
5451
+ LM_GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
5452
+ } break;
5453
+ case LM_GGML_OP_ROPE: {
5454
+ if (src0_needs_grads) {
5455
+ //const int n_past = ((int32_t *) tensor->op_params)[0];
5456
+ const int n_dims = ((const int32_t *) tensor->op_params)[1];
5457
+ const int mode = ((const int32_t *) tensor->op_params)[2];
5458
+ //const int n_ctx = ((int32_t *) tensor->op_params)[3];
5459
+ const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];
5460
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5461
+
5462
+ memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float));
5463
+ memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float));
5464
+ memcpy(&ext_factor, (const float *) tensor->op_params + 7, sizeof(float));
5465
+ memcpy(&attn_factor, (const float *) tensor->op_params + 8, sizeof(float));
5466
+ memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float));
5467
+ memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float));
5468
+
5469
+ lm_ggml_add_or_set(ctx, cgraph, isrc0,
5470
+ lm_ggml_rope_back(ctx, grad, src1, src2, n_dims, mode, n_ctx_orig, freq_base,
5471
+ freq_scale, ext_factor, attn_factor, beta_fast, beta_slow));
5472
+ }
5473
+ LM_GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
5474
+ } break;
5475
+ case LM_GGML_OP_IM2COL: {
5476
+ if (src1_needs_grads) {
5477
+ const int32_t s0 = lm_ggml_get_op_params_i32(tensor, 0);
5478
+ const int32_t s1 = lm_ggml_get_op_params_i32(tensor, 1);
5479
+ const int32_t p0 = lm_ggml_get_op_params_i32(tensor, 2);
5480
+ const int32_t p1 = lm_ggml_get_op_params_i32(tensor, 3);
5481
+ const int32_t d0 = lm_ggml_get_op_params_i32(tensor, 4);
5482
+ const int32_t d1 = lm_ggml_get_op_params_i32(tensor, 5);
5483
+ const bool is_2D = lm_ggml_get_op_params_i32(tensor, 6) == 1;
5484
+
5485
+ lm_ggml_add_or_set(ctx, cgraph, isrc1, lm_ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
5486
+ }
5487
+ } break;
5488
+ case LM_GGML_OP_POOL_2D: {
5489
+ if (src0_needs_grads) {
5490
+ const enum lm_ggml_op_pool op = lm_ggml_get_op_params_i32(tensor, 0);
5491
+ const int32_t k0 = lm_ggml_get_op_params_i32(tensor, 1);
5492
+ const int32_t k1 = lm_ggml_get_op_params_i32(tensor, 2);
5493
+ const int32_t s0 = lm_ggml_get_op_params_i32(tensor, 3);
5494
+ const int32_t s1 = lm_ggml_get_op_params_i32(tensor, 4);
5495
+ const int32_t p0 = lm_ggml_get_op_params_i32(tensor, 5);
5496
+ const int32_t p1 = lm_ggml_get_op_params_i32(tensor, 6);
5497
+
5498
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_pool_2d_back(ctx, grad, src0, op, k0, k1, s0, s1, p0, p1));
5499
+ }
5500
+ } break;
6008
5501
  case LM_GGML_OP_WIN_PART:
6009
5502
  case LM_GGML_OP_WIN_UNPART:
6010
- case LM_GGML_OP_UNARY:
6011
- {
6012
- switch (lm_ggml_get_unary_op(tensor)) {
6013
- case LM_GGML_UNARY_OP_ABS:
6014
- {
6015
- if (src0->grad) {
6016
- src0->grad =
6017
- lm_ggml_add_or_set(ctx,
6018
- src0->grad,
6019
- lm_ggml_mul(ctx,
6020
- lm_ggml_sgn(ctx, src0),
6021
- tensor->grad),
6022
- zero_table, acc_table);
6023
- }
6024
- } break;
6025
- case LM_GGML_UNARY_OP_SGN:
6026
- {
6027
- if (src0->grad) {
6028
- // noop
6029
- }
6030
- } break;
6031
- case LM_GGML_UNARY_OP_NEG:
6032
- {
6033
- if (src0->grad) {
6034
- src0->grad = lm_ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
6035
- }
6036
- } break;
6037
- case LM_GGML_UNARY_OP_STEP:
6038
- {
6039
- if (src0->grad) {
6040
- // noop
6041
- }
6042
- } break;
6043
- case LM_GGML_UNARY_OP_TANH:
6044
- {
6045
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
6046
- }
6047
- case LM_GGML_UNARY_OP_ELU:
6048
- {
6049
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
6050
- }
6051
- case LM_GGML_UNARY_OP_RELU:
6052
- {
6053
- if (src0->grad) {
6054
- src0->grad = lm_ggml_add_or_set(ctx,
6055
- src0->grad,
6056
- lm_ggml_mul(ctx,
6057
- lm_ggml_step(ctx, src0),
6058
- tensor->grad),
6059
- zero_table, acc_table);
6060
- }
6061
- } break;
6062
- case LM_GGML_UNARY_OP_SIGMOID:
6063
- {
6064
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
6065
- }
6066
- case LM_GGML_UNARY_OP_GELU:
6067
- {
6068
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
6069
- }
6070
- case LM_GGML_UNARY_OP_GELU_QUICK:
6071
- {
6072
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
6073
- }
6074
- case LM_GGML_UNARY_OP_SILU:
6075
- {
6076
- // necessary for llama
6077
- if (src0->grad) {
6078
- src0->grad = lm_ggml_add_or_set(ctx,
6079
- src0->grad,
6080
- lm_ggml_silu_back(ctx, src0, tensor->grad),
6081
- zero_table, acc_table);
6082
- }
6083
- } break;
6084
- case LM_GGML_UNARY_OP_EXP:
6085
- {
6086
- if (src0->grad) {
6087
- src0->grad = lm_ggml_add_or_set(ctx,
6088
- src0->grad,
6089
- lm_ggml_mul(ctx, tensor, tensor->grad),
6090
- zero_table, acc_table);
6091
- }
6092
- } break;
6093
- default:
6094
- LM_GGML_ABORT("fatal error");
6095
- }
6096
- } break;
6097
- case LM_GGML_OP_GET_REL_POS:
6098
- case LM_GGML_OP_ADD_REL_POS:
6099
- case LM_GGML_OP_RWKV_WKV:
6100
- case LM_GGML_OP_MAP_UNARY:
6101
- case LM_GGML_OP_MAP_BINARY:
6102
- case LM_GGML_OP_MAP_CUSTOM1_F32:
6103
- case LM_GGML_OP_MAP_CUSTOM2_F32:
6104
- case LM_GGML_OP_MAP_CUSTOM3_F32:
6105
- case LM_GGML_OP_MAP_CUSTOM1:
6106
- case LM_GGML_OP_MAP_CUSTOM2:
6107
- case LM_GGML_OP_MAP_CUSTOM3:
6108
- {
6109
- LM_GGML_ABORT("fatal error"); // not supported
6110
- }
6111
- case LM_GGML_OP_CROSS_ENTROPY_LOSS:
6112
- {
6113
- if (src0->grad) {
6114
- src0->grad = lm_ggml_add_or_set(ctx,
6115
- src0->grad,
6116
- lm_ggml_cross_entropy_loss_back(ctx,
6117
- src0,
6118
- src1,
6119
- tensor->grad),
6120
- zero_table, acc_table);
6121
- }
6122
- LM_GGML_ASSERT(!src1->grad && "backward pass for labels not implemented");
6123
- } break;
6124
- case LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK:
6125
- {
6126
- LM_GGML_ABORT("fatal error"); // not supported
5503
+ case LM_GGML_OP_UNARY: {
5504
+ switch (lm_ggml_get_unary_op(tensor)) {
5505
+ case LM_GGML_UNARY_OP_ABS: {
5506
+ if (src0_needs_grads) {
5507
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_mul(ctx, lm_ggml_sgn(ctx, src0), grad));
5508
+ }
5509
+ } break;
5510
+ case LM_GGML_UNARY_OP_SGN: {
5511
+ // noop
5512
+ } break;
5513
+ case LM_GGML_UNARY_OP_NEG: {
5514
+ if (src0_needs_grads) {
5515
+ lm_ggml_sub_or_set(ctx, cgraph, isrc0, grad);
5516
+ }
5517
+ } break;
5518
+ case LM_GGML_UNARY_OP_STEP: {
5519
+ // noop
5520
+ } break;
5521
+ case LM_GGML_UNARY_OP_RELU: {
5522
+ if (src0_needs_grads) {
5523
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_mul(ctx, lm_ggml_step(ctx, src0), grad));
5524
+ }
5525
+ } break;
5526
+ case LM_GGML_UNARY_OP_SILU: {
5527
+ if (src0_needs_grads) {
5528
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_silu_back(ctx, src0, grad));
5529
+ }
5530
+ } break;
5531
+ case LM_GGML_UNARY_OP_EXP: {
5532
+ if (src0_needs_grads) {
5533
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_mul(ctx, tensor, grad));
5534
+ }
5535
+ } break;
5536
+ default: {
5537
+ fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n",
5538
+ __func__, lm_ggml_unary_op_name(lm_ggml_get_unary_op(tensor)));
5539
+ LM_GGML_ABORT("fatal error");
5540
+ } //break;
6127
5541
  }
6128
- case LM_GGML_OP_OPT_STEP_ADAMW:
6129
- {
6130
- LM_GGML_ABORT("fatal error"); // not supported
5542
+ } break;
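The unary cases reduce to the usual scalar derivatives; with g again the incoming gradient (notation mine):

    d|a|      : sgn(a) * g
    d(-a)     : -g
    d relu(a) : step(a) * g
    d exp(a)  : t * g                  (t = exp(a), so the forward output is reused)
    d silu(a) : via lm_ggml_silu_back(a, g)
    sgn, step : treated as piecewise constant, so no gradient is propagated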
5543
+ case LM_GGML_OP_CROSS_ENTROPY_LOSS: {
5544
+ if (src0_needs_grads) {
5545
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_cross_entropy_loss_back(ctx, src0, src1, grad));
6131
5546
  }
6132
- case LM_GGML_OP_NONE:
6133
- {
6134
- // nop
6135
- } break;
5547
+ LM_GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
5548
+ } break;
5549
+ case LM_GGML_OP_NONE: {
5550
+ // noop
5551
+ } break;
6136
5552
  case LM_GGML_OP_COUNT:
6137
- {
6138
- LM_GGML_ABORT("fatal error");
6139
- }
5553
+ default: {
5554
+ fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, lm_ggml_op_name(tensor->op));
5555
+ LM_GGML_ABORT("fatal error");
5556
+ } //break;
6140
5557
  }
6141
5558
 
6142
- for (int i = 0; i < LM_GGML_MAX_SRC; ++i) {
6143
- if (tensor->src[i] && tensor->src[i]->grad) {
6144
- LM_GGML_ASSERT(lm_ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad));
6145
- }
6146
- }
5559
+ LM_GGML_ASSERT(!src0_needs_grads || lm_ggml_are_same_shape(src0, cgraph->grads[isrc0]));
5560
+ LM_GGML_ASSERT(!src1_needs_grads || lm_ggml_are_same_shape(src1, cgraph->grads[isrc1]));
5561
+ LM_GGML_ASSERT(!src2_needs_grads || lm_ggml_are_same_shape(src2, cgraph->grads[isrc2]));
6147
5562
  }
6148
5563
 
6149
5564
  static void lm_ggml_visit_parents(struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * node) {
6150
- if (node->grad == NULL) {
6151
- // this usually happens when we generate intermediate nodes from constants in the backward pass
6152
- // it can also happen during forward pass, if the user performs computations with constants
6153
- if (node->op != LM_GGML_OP_NONE) {
6154
- //LM_GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op);
6155
- }
6156
- }
6157
-
6158
5565
  // check if already visited
6159
5566
  if (lm_ggml_hash_insert(&cgraph->visited_hash_set, node) == LM_GGML_HASHSET_ALREADY_EXISTS) {
6160
5567
  return;
@@ -6215,18 +5622,41 @@ void lm_ggml_build_forward_expand(struct lm_ggml_cgraph * cgraph, struct lm_ggml
6215
5622
  lm_ggml_build_forward_impl(cgraph, tensor, true);
6216
5623
  }
6217
5624
 
6218
- void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * gf, struct lm_ggml_cgraph * gb, bool accumulate) {
6219
- LM_GGML_ASSERT(gf->n_nodes > 0);
6220
- LM_GGML_ASSERT(gf->grads);
5625
+ void lm_ggml_build_backward_expand(
5626
+ struct lm_ggml_context * ctx_static,
5627
+ struct lm_ggml_context * ctx_compute,
5628
+ struct lm_ggml_cgraph * cgraph,
5629
+ bool accumulate) {
5630
+ LM_GGML_ASSERT(cgraph->n_nodes > 0);
5631
+ LM_GGML_ASSERT(cgraph->grads);
5632
+ LM_GGML_ASSERT(cgraph->grad_accs);
5633
+
5634
+ const int n_nodes_f = cgraph->n_nodes;
5635
+
5636
+ memset(cgraph->grads, 0, cgraph->visited_hash_set.size*sizeof(struct lm_ggml_tensor *));
5637
+ memset(cgraph->grad_accs, 0, cgraph->visited_hash_set.size*sizeof(struct lm_ggml_tensor *));
5638
+ bool * grads_needed = calloc(cgraph->visited_hash_set.size, sizeof(bool));
5639
+
5640
+ {
5641
+ bool any_params = false;
5642
+ bool any_loss = false;
5643
+ for (int i = 0; i < n_nodes_f; ++i) {
5644
+ struct lm_ggml_tensor * node = cgraph->nodes[i];
5645
+ any_params = any_params || (node->flags & LM_GGML_TENSOR_FLAG_PARAM);
5646
+ any_loss = any_loss || (node->flags & LM_GGML_TENSOR_FLAG_LOSS);
5647
+ }
5648
+ LM_GGML_ASSERT(any_params && "no trainable parameters found, did you forget to call lm_ggml_set_param?");
5649
+ LM_GGML_ASSERT(any_loss && "no training loss found, did you forget to call lm_ggml_set_loss?");
5650
+ }
6221
5651
 
6222
- for (int i = 0; i < gf->n_nodes; ++i) {
6223
- struct lm_ggml_tensor * node = gf->nodes[i];
5652
+ for (int i = 0; i < n_nodes_f; ++i) {
5653
+ struct lm_ggml_tensor * node = cgraph->nodes[i];
6224
5654
 
6225
5655
  if (node->type == LM_GGML_TYPE_I32) {
6226
5656
  continue;
6227
5657
  }
6228
5658
 
6229
- bool needs_grad = node->flags & LM_GGML_TENSOR_FLAG_PARAM;
5659
+ bool node_needs_grad = (node->flags & LM_GGML_TENSOR_FLAG_PARAM) || (node->flags & LM_GGML_TENSOR_FLAG_LOSS);
6230
5660
  bool ignore_src[LM_GGML_MAX_SRC] = {false};
6231
5661
  switch (node->op) {
6232
5662
  // gradients in node->src[0] for one reason or another have no effect on output gradients
@@ -6243,7 +5673,7 @@ void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_
6243
5673
  } break;
6244
5674
 
6245
5675
  // gradients in node->src[1] for one reason or another have no effect on output gradients
6246
- case LM_GGML_OP_CPY: // gradients in CPY target are irrelevant
5676
+ case LM_GGML_OP_CPY: // gradients in CPY target are irrelevant
6247
5677
  case LM_GGML_OP_GET_ROWS: // row indices not differentiable
6248
5678
  case LM_GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS
6249
5679
  case LM_GGML_OP_ROPE: // positions not differentiable
@@ -6254,14 +5684,14 @@ void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_
6254
5684
  break;
6255
5685
  }
6256
5686
  for (int j = 0; j < LM_GGML_MAX_SRC; ++j) {
6257
- if (!node->src[j] || !node->src[j]->grad || ignore_src[j]) {
5687
+ if (!node->src[j] || ignore_src[j] || !grads_needed[lm_ggml_hash_find(&cgraph->visited_hash_set, node->src[j])]) {
6258
5688
  continue;
6259
5689
  }
6260
5690
  LM_GGML_ASSERT(node->src[j]->type == LM_GGML_TYPE_F32 || node->src[j]->type == LM_GGML_TYPE_F16);
6261
- needs_grad = true;
5691
+ node_needs_grad = true;
6262
5692
  break;
6263
5693
  }
6264
- if (!needs_grad) {
5694
+ if (!node_needs_grad) {
6265
5695
  continue;
6266
5696
  }
6267
5697
 
@@ -6269,73 +5699,24 @@ void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_
6269
5699
  LM_GGML_ASSERT(!node->view_src || node->op == LM_GGML_OP_CPY || node->op == LM_GGML_OP_VIEW ||
6270
5700
  node->op == LM_GGML_OP_RESHAPE || node->op == LM_GGML_OP_PERMUTE || node->op == LM_GGML_OP_TRANSPOSE);
6271
5701
 
6272
- // create a new tensor with the same type and shape as the node and set it as grad
6273
- node->grad = lm_ggml_dup_tensor(ctx, node);
6274
- }
6275
-
6276
- // keep tables of original gradients for replacement/accumulation logic
6277
- struct lm_ggml_hash_set zero_table = lm_ggml_hash_set_new(gf->size);
6278
- struct lm_ggml_hash_set acc_table = lm_ggml_hash_set_new(gf->size);
6279
- for (int i = 0; i < gf->n_nodes; i++) {
6280
- struct lm_ggml_tensor * node = gf->nodes[i];
6281
-
6282
- if (node->grad) {
6283
- {
6284
- const size_t insert_result = lm_ggml_hash_insert(&zero_table, node->grad);
6285
- LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL);
6286
- LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS);
6287
- }
6288
-
6289
- // only gradients of trainable parameters should be accumulated
6290
- if (accumulate && (node->flags & LM_GGML_TENSOR_FLAG_PARAM)) {
6291
- const size_t insert_result = lm_ggml_hash_insert(&acc_table, node->grad);
6292
- LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL);
6293
- LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS);
6294
- }
5702
+ const size_t igrad = lm_ggml_hash_find(&cgraph->visited_hash_set, node);
5703
+ LM_GGML_ASSERT(igrad != LM_GGML_HASHSET_FULL);
5704
+ LM_GGML_ASSERT(lm_ggml_bitset_get(cgraph->visited_hash_set.used, igrad));
5705
+ if ((accumulate && (node->flags & LM_GGML_TENSOR_FLAG_PARAM)) || (node->flags & LM_GGML_TENSOR_FLAG_LOSS)) {
5706
+ cgraph->grad_accs[igrad] = lm_ggml_dup_tensor(ctx_static, node);
5707
+ cgraph->grads[igrad] = cgraph->grad_accs[igrad];
5708
+ lm_ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name);
6295
5709
  }
5710
+ grads_needed[igrad] = true;
6296
5711
  }
6297
5712
 
6298
- for (int i = gf->n_nodes - 1; i >= 0; i--) {
6299
- struct lm_ggml_tensor * node = gf->nodes[i];
6300
-
5713
+ for (int i = n_nodes_f - 1; i >= 0; --i) {
6301
5714
  // inplace operations to add gradients are not created by lm_ggml_compute_backward except for gradient accumulation
6302
5715
  // use allocator to automatically make inplace operations
6303
- if (node->grad) {
6304
- lm_ggml_compute_backward(ctx, node, &zero_table, &acc_table);
6305
- }
6306
- }
6307
-
6308
- for (int i = 0; i < gf->n_nodes; i++) {
6309
- struct lm_ggml_tensor * node = gf->nodes[i];
6310
-
6311
- if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) {
6312
- LM_GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
6313
- lm_ggml_build_forward_expand(gb, node->grad);
6314
- }
5716
+ lm_ggml_compute_backward(ctx_compute, cgraph, i, grads_needed);
6315
5717
  }
6316
5718
 
6317
- lm_ggml_hash_set_free(&zero_table);
6318
- lm_ggml_hash_set_free(&acc_table);
6319
- }
6320
-
6321
- void lm_ggml_build_opt_adamw(
6322
- struct lm_ggml_context * ctx,
6323
- struct lm_ggml_cgraph * gf,
6324
- struct lm_ggml_cgraph * gb,
6325
- float alpha,
6326
- float beta1,
6327
- float beta2,
6328
- float eps,
6329
- float wd) {
6330
- for (int i = 0; i < gf->n_nodes; i++) {
6331
- struct lm_ggml_tensor * node = gf->nodes[i];
6332
-
6333
- if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) {
6334
- LM_GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
6335
- struct lm_ggml_tensor * opt_step = lm_ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd);
6336
- lm_ggml_build_forward_expand(gb, opt_step);
6337
- }
6338
- }
5719
+ free(grads_needed);
6339
5720
  }
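The rewritten lm_ggml_build_backward_expand takes a static context (where gradient-accumulator tensors are allocated), a compute context, and a single graph instead of the old gf/gb pair. A minimal usage sketch under those assumptions; lm_ggml_set_param and lm_ggml_set_loss are named by the assertions above, but their exact signatures, the graph-size constant, and the tensor variables here are assumptions:

    // sketch: build a trainable graph with the single-graph backward API
    struct lm_ggml_cgraph * graph = lm_ggml_new_graph_custom(ctx_static, LM_GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);

    lm_ggml_set_param(ctx_compute, weights);   // mark trainable tensors (assumed signature)
    lm_ggml_set_loss(loss);                    // mark the scalar training loss (assumed signature)

    lm_ggml_build_forward_expand(graph, loss);
    lm_ggml_build_backward_expand(ctx_static, ctx_compute, graph, /*accumulate =*/ false);

    lm_ggml_graph_reset(graph);                // seeds d(loss)/d(loss) = 1, zeros the other accumulators
    struct lm_ggml_tensor * d_weights = lm_ggml_graph_get_grad(graph, weights);   // accessor added later in this diff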
6340
5721
 
6341
5722
  static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
@@ -6353,7 +5734,8 @@ static size_t lm_ggml_graph_nbytes(size_t size, bool grads) {
6353
5734
  incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); // leafs
6354
5735
  incr_ptr_aligned(&p, hash_size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); // hash keys
6355
5736
  if (grads) {
6356
- incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); // grads
5737
+ incr_ptr_aligned(&p, hash_size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); // grads
5738
+ incr_ptr_aligned(&p, hash_size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)); // grad_accs
6357
5739
  }
6358
5740
  incr_ptr_aligned(&p, lm_ggml_bitset_size(hash_size) * sizeof(lm_ggml_bitset_t), sizeof(lm_ggml_bitset_t));
6359
5741
 
@@ -6379,10 +5761,12 @@ struct lm_ggml_cgraph * lm_ggml_new_graph_custom(struct lm_ggml_context * ctx, s
6379
5761
 
6380
5762
  void * p = cgraph + 1;
6381
5763
 
6382
- struct lm_ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *));
6383
- struct lm_ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *));
6384
- struct lm_ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *));
6385
- struct lm_ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)) : NULL;
5764
+ struct lm_ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *));
5765
+ struct lm_ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *));
5766
+ struct lm_ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *));
5767
+ struct lm_ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)) : NULL;
5768
+ struct lm_ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct lm_ggml_tensor *), sizeof(struct lm_ggml_tensor *)) : NULL;
5769
+
6386
5770
  lm_ggml_bitset_t * hash_used = incr_ptr_aligned(&p, lm_ggml_bitset_size(hash_size) * sizeof(lm_ggml_bitset_t), sizeof(lm_ggml_bitset_t));
6387
5771
 
6388
5772
  // check that we allocated the correct amount of memory
@@ -6394,12 +5778,17 @@ struct lm_ggml_cgraph * lm_ggml_new_graph_custom(struct lm_ggml_context * ctx, s
6394
5778
  /*.n_leafs =*/ 0,
6395
5779
  /*.nodes =*/ nodes_ptr,
6396
5780
  /*.grads =*/ grads_ptr,
5781
+ /*.grad_accs =*/ grad_accs_ptr,
6397
5782
  /*.leafs =*/ leafs_ptr,
6398
5783
  /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr },
6399
5784
  /*.order =*/ LM_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
6400
5785
  };
6401
5786
 
6402
5787
  lm_ggml_hash_set_reset(&cgraph->visited_hash_set);
5788
+ if (grads) {
5789
+ memset(cgraph->grads, 0, hash_size*sizeof(struct lm_ggml_tensor *));
5790
+ memset(cgraph->grad_accs, 0, hash_size*sizeof(struct lm_ggml_tensor *));
5791
+ }
6403
5792
 
6404
5793
  return cgraph;
6405
5794
  }
@@ -6410,14 +5799,15 @@ struct lm_ggml_cgraph * lm_ggml_new_graph(struct lm_ggml_context * ctx) {
6410
5799
 
6411
5800
  struct lm_ggml_cgraph lm_ggml_graph_view(struct lm_ggml_cgraph * cgraph0, int i0, int i1) {
6412
5801
  struct lm_ggml_cgraph cgraph = {
6413
- /*.size =*/ 0,
6414
- /*.n_nodes =*/ i1 - i0,
6415
- /*.n_leafs =*/ 0,
6416
- /*.nodes =*/ cgraph0->nodes + i0,
6417
- /*.grads =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL,
6418
- /*.leafs =*/ NULL,
6419
- /*.hash_table =*/ { 0, NULL, NULL },
6420
- /*.order =*/ cgraph0->order,
5802
+ /*.size =*/ 0,
5803
+ /*.n_nodes =*/ i1 - i0,
5804
+ /*.n_leafs =*/ 0,
5805
+ /*.nodes =*/ cgraph0->nodes + i0,
5806
+ /*.grads =*/ NULL, // gradients would need visited_hash_set
5807
+ /*.grad_accs =*/ NULL,
5808
+ /*.leafs =*/ NULL,
5809
+ /*.visited_hash_set =*/ { 0, NULL, NULL },
5810
+ /*.order =*/ cgraph0->order,
6421
5811
  };
6422
5812
 
6423
5813
  return cgraph;
@@ -6440,19 +5830,33 @@ void lm_ggml_graph_cpy(struct lm_ggml_cgraph * src, struct lm_ggml_cgraph * dst)
6440
5830
  dst->nodes[i] = src->nodes[i];
6441
5831
  }
6442
5832
 
6443
- if (src->grads) {
6444
- LM_GGML_ASSERT(dst->grads != NULL);
6445
- for (int i = 0; i < src->n_nodes; ++i) {
6446
- dst->grads[i] = src->grads[i];
6447
- }
6448
- }
6449
-
6450
5833
  for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
6451
5834
  // copy all hashset keys (tensors) that are in use
6452
5835
  if (lm_ggml_bitset_get(src->visited_hash_set.used, i)) {
6453
5836
  lm_ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
6454
5837
  }
6455
5838
  }
5839
+
5840
+ if (dst->grads) {
5841
+ memset(dst->grads, 0, dst->visited_hash_set.size*sizeof(struct lm_ggml_tensor *));
5842
+ memset(dst->grad_accs, 0, dst->visited_hash_set.size*sizeof(struct lm_ggml_tensor *));
5843
+ }
5844
+ if (src->grads) {
5845
+ LM_GGML_ASSERT(dst->grads != NULL);
5846
+ LM_GGML_ASSERT(dst->grad_accs != NULL);
5847
+ for (int i = 0; i < src->n_nodes; ++i) {
5848
+ const size_t igrad_src = lm_ggml_hash_find(&src->visited_hash_set, src->nodes[i]);
5849
+ const size_t igrad_dst = lm_ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);
5850
+
5851
+ LM_GGML_ASSERT(igrad_src != LM_GGML_HASHSET_FULL);
5852
+ LM_GGML_ASSERT(lm_ggml_bitset_get(src->visited_hash_set.used, igrad_src));
5853
+ LM_GGML_ASSERT(igrad_dst != LM_GGML_HASHSET_FULL);
5854
+ LM_GGML_ASSERT(lm_ggml_bitset_get(dst->visited_hash_set.used, igrad_dst));
5855
+
5856
+ dst->grads[igrad_dst] = src->grads[igrad_src];
5857
+ dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];
5858
+ }
5859
+ }
6456
5860
  }
6457
5861
 
6458
5862
  struct lm_ggml_cgraph * lm_ggml_graph_dup(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph) {
@@ -6478,29 +5882,32 @@ void lm_ggml_graph_reset(struct lm_ggml_cgraph * cgraph) {
6478
5882
  LM_GGML_ASSERT(cgraph->grads != NULL);
6479
5883
 
6480
5884
  for (int i = 0; i < cgraph->n_nodes; i++) {
6481
- struct lm_ggml_tensor * node = cgraph->nodes[i];
5885
+ struct lm_ggml_tensor * node = cgraph->nodes[i];
5886
+ struct lm_ggml_tensor * grad_acc = lm_ggml_graph_get_grad_acc(cgraph, node);
5887
+
5888
+ if (node->op == LM_GGML_OP_OPT_STEP_ADAMW) {
5889
+ // clear momenta
5890
+ lm_ggml_set_zero(node->src[2]);
5891
+ lm_ggml_set_zero(node->src[3]);
5892
+ }
6482
5893
 
6483
5894
  // initial gradients of loss should be 1, 0 otherwise
6484
- if (node->grad) {
5895
+ if (grad_acc) {
6485
5896
  if (node->flags & LM_GGML_TENSOR_FLAG_LOSS) {
6486
- LM_GGML_ASSERT(node->grad->buffer);
6487
- LM_GGML_ASSERT(node->type == LM_GGML_TYPE_F32);
6488
- LM_GGML_ASSERT(lm_ggml_is_scalar(node));
5897
+ LM_GGML_ASSERT(grad_acc->type == LM_GGML_TYPE_F32);
5898
+ LM_GGML_ASSERT(lm_ggml_is_scalar(grad_acc));
6489
5899
 
6490
5900
  const float onef = 1.0f;
6491
- lm_ggml_backend_tensor_set(node->grad, &onef, 0, lm_ggml_nbytes(node->grad));
5901
+ if (grad_acc->buffer) {
5902
+ lm_ggml_backend_tensor_set(grad_acc, &onef, 0, sizeof(float));
5903
+ } else {
5904
+ LM_GGML_ASSERT(grad_acc->data);
5905
+ *((float *) grad_acc->data) = onef;
5906
+ }
6492
5907
  } else {
6493
- lm_ggml_set_zero(node->grad);
5908
+ lm_ggml_set_zero(grad_acc);
6494
5909
  }
6495
5910
  }
6496
-
6497
- LM_GGML_ASSERT(node);
6498
- if (node->op == LM_GGML_OP_OPT_STEP_ADAMW) {
6499
- // set iteration to 1 and clear momenta
6500
- lm_ggml_set_op_params_i32(node, 0, 1);
6501
- lm_ggml_set_zero(node->src[2]);
6502
- lm_ggml_set_zero(node->src[3]);
6503
- }
6504
5911
  }
6505
5912
  }
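The reset logic above is the base case of reverse-mode differentiation: the accumulator of the scalar loss is seeded with dL/dL = 1 while every other gradient accumulator starts at 0, and the AdamW momenta in src[2]/src[3] are cleared so a new run does not inherit stale optimizer state.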
6506
5913
 
@@ -6538,7 +5945,7 @@ void lm_ggml_graph_add_node(struct lm_ggml_cgraph * cgraph, struct lm_ggml_tenso
6538
5945
  cgraph->n_nodes++;
6539
5946
  }
6540
5947
 
6541
- struct lm_ggml_tensor * lm_ggml_graph_get_tensor(struct lm_ggml_cgraph * cgraph, const char * name) {
5948
+ struct lm_ggml_tensor * lm_ggml_graph_get_tensor(const struct lm_ggml_cgraph * cgraph, const char * name) {
6542
5949
  for (int i = 0; i < cgraph->n_leafs; i++) {
6543
5950
  struct lm_ggml_tensor * leaf = cgraph->leafs[i];
6544
5951
 
@@ -6558,6 +5965,16 @@ struct lm_ggml_tensor * lm_ggml_graph_get_tensor(struct lm_ggml_cgraph * cgraph,
6558
5965
  return NULL;
6559
5966
  }
6560
5967
 
5968
+ struct lm_ggml_tensor * lm_ggml_graph_get_grad(const struct lm_ggml_cgraph * cgraph, const struct lm_ggml_tensor * node) {
5969
+ const size_t igrad = lm_ggml_hash_find(&cgraph->visited_hash_set, node);
5970
+ return igrad != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grads[igrad] : NULL;
5971
+ }
5972
+
5973
+ struct lm_ggml_tensor * lm_ggml_graph_get_grad_acc(const struct lm_ggml_cgraph * cgraph, const struct lm_ggml_tensor * node) {
5974
+ const size_t igrad = lm_ggml_hash_find(&cgraph->visited_hash_set, node);
5975
+ return igrad != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grad_accs[igrad] : NULL;
5976
+ }
5977
+
6561
5978
  void lm_ggml_graph_print(const struct lm_ggml_cgraph * cgraph) {
6562
5979
  LM_GGML_LOG_INFO("=== GRAPH ===\n");
6563
5980
 
@@ -6568,7 +5985,8 @@ void lm_ggml_graph_print(const struct lm_ggml_cgraph * cgraph) {
6568
5985
  LM_GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n",
6569
5986
  i,
6570
5987
  node->ne[0], node->ne[1], node->ne[2],
6571
- lm_ggml_op_name(node->op), (node->flags & LM_GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " ");
5988
+ lm_ggml_op_name(node->op), (node->flags & LM_GGML_TENSOR_FLAG_PARAM) ? "x" :
5989
+ lm_ggml_graph_get_grad(cgraph, node) ? "g" : " ");
6572
5990
  }
6573
5991
 
6574
5992
  LM_GGML_LOG_INFO("n_leafs = %d\n", cgraph->n_leafs);
@@ -6603,8 +6021,9 @@ static bool lm_ggml_graph_find(const struct lm_ggml_cgraph * cgraph, const struc
6603
6021
  static struct lm_ggml_tensor * lm_ggml_graph_get_parent(const struct lm_ggml_cgraph * cgraph, const struct lm_ggml_tensor * node) {
6604
6022
  for (int i = 0; i < cgraph->n_nodes; i++) {
6605
6023
  struct lm_ggml_tensor * parent = cgraph->nodes[i];
6024
+ struct lm_ggml_tensor * grad = lm_ggml_graph_get_grad(cgraph, parent);
6606
6025
 
6607
- if (parent->grad == node) {
6026
+ if (grad == node) {
6608
6027
  return parent;
6609
6028
  }
6610
6029
  }
@@ -6644,6 +6063,7 @@ void lm_ggml_graph_dump_dot(const struct lm_ggml_cgraph * gb, const struct lm_gg
6644
6063
 
6645
6064
  for (int i = 0; i < gb->n_nodes; i++) {
6646
6065
  struct lm_ggml_tensor * node = gb->nodes[i];
6066
+ struct lm_ggml_tensor * grad = lm_ggml_graph_get_grad(gb, node);
6647
6067
 
6648
6068
  if (lm_ggml_graph_get_parent(gb, node) != NULL) {
6649
6069
  continue;
@@ -6651,7 +6071,7 @@ void lm_ggml_graph_dump_dot(const struct lm_ggml_cgraph * gb, const struct lm_gg
6651
6071
 
6652
6072
  if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) {
6653
6073
  snprintf(color, sizeof(color), "yellow");
6654
- } else if (node->grad) {
6074
+ } else if (grad) {
6655
6075
  if (lm_ggml_graph_find(gf, node)) {
6656
6076
  snprintf(color, sizeof(color), "green");
6657
6077
  } else {
@@ -6678,8 +6098,8 @@ void lm_ggml_graph_dump_dot(const struct lm_ggml_cgraph * gb, const struct lm_gg
6678
6098
  fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], node->ne[2], lm_ggml_op_symbol(node->op));
6679
6099
  }
6680
6100
 
6681
- if (node->grad) {
6682
- fprintf(fp, " | <g>%s\"; ]\n", lm_ggml_op_symbol(node->grad->op));
6101
+ if (grad) {
6102
+ fprintf(fp, " | <g>%s\"; ]\n", lm_ggml_op_symbol(grad->op));
6683
6103
  } else {
6684
6104
  fprintf(fp, "\"; ]\n");
6685
6105
  }
@@ -6789,9 +6209,9 @@ void lm_ggml_quantize_init(enum lm_ggml_type type) {
6789
6209
  case LM_GGML_TYPE_IQ2_XS:
6790
6210
  case LM_GGML_TYPE_IQ2_S:
6791
6211
  case LM_GGML_TYPE_IQ1_S:
6792
- case LM_GGML_TYPE_IQ1_M: iq2xs_init_impl(type); break;
6793
- case LM_GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
6794
- case LM_GGML_TYPE_IQ3_S: iq3xs_init_impl(512); break;
6212
+ case LM_GGML_TYPE_IQ1_M: lm_iq2xs_init_impl(type); break;
6213
+ case LM_GGML_TYPE_IQ3_XXS: lm_iq3xs_init_impl(256); break;
6214
+ case LM_GGML_TYPE_IQ3_S: lm_iq3xs_init_impl(512); break;
6795
6215
  default: // nothing
6796
6216
  break;
6797
6217
  }
@@ -6802,10 +6222,10 @@ void lm_ggml_quantize_init(enum lm_ggml_type type) {
6802
6222
  void lm_ggml_quantize_free(void) {
6803
6223
  lm_ggml_critical_section_start();
6804
6224
 
6805
- iq2xs_free_impl(LM_GGML_TYPE_IQ2_XXS);
6806
- iq2xs_free_impl(LM_GGML_TYPE_IQ2_XS);
6807
- iq2xs_free_impl(LM_GGML_TYPE_IQ1_S);
6808
- iq3xs_free_impl(256);
6225
+ lm_iq2xs_free_impl(LM_GGML_TYPE_IQ2_XXS);
6226
+ lm_iq2xs_free_impl(LM_GGML_TYPE_IQ2_XS);
6227
+ lm_iq2xs_free_impl(LM_GGML_TYPE_IQ1_S);
6228
+ lm_iq3xs_free_impl(256);
6809
6229
 
6810
6230
  lm_ggml_critical_section_end();
6811
6231
  }
@@ -8169,222 +7589,30 @@ void lm_gguf_get_meta_data(const struct lm_gguf_context * ctx, void * data) {
8169
7589
  lm_gguf_buf_free(buf);
8170
7590
  }
8171
7591
 
8172
- ////////////////////////////////////////////////////////////////////////////////
8173
-
8174
- int lm_ggml_cpu_has_avx(void) {
8175
- #if defined(__AVX__)
8176
- return 1;
8177
- #else
8178
- return 0;
8179
- #endif
8180
- }
8181
-
8182
- int lm_ggml_cpu_has_avx_vnni(void) {
8183
- #if defined(__AVXVNNI__)
8184
- return 1;
8185
- #else
8186
- return 0;
8187
- #endif
8188
- }
8189
-
8190
- int lm_ggml_cpu_has_avx2(void) {
8191
- #if defined(__AVX2__)
8192
- return 1;
8193
- #else
8194
- return 0;
8195
- #endif
8196
- }
8197
-
8198
- int lm_ggml_cpu_has_avx512(void) {
8199
- #if defined(__AVX512F__)
8200
- return 1;
8201
- #else
8202
- return 0;
8203
- #endif
8204
- }
8205
-
8206
- int lm_ggml_cpu_has_avx512_vbmi(void) {
8207
- #if defined(__AVX512VBMI__)
8208
- return 1;
8209
- #else
8210
- return 0;
8211
- #endif
8212
- }
8213
-
8214
- int lm_ggml_cpu_has_avx512_vnni(void) {
8215
- #if defined(__AVX512VNNI__)
8216
- return 1;
8217
- #else
8218
- return 0;
8219
- #endif
8220
- }
8221
-
8222
- int lm_ggml_cpu_has_avx512_bf16(void) {
8223
- #if defined(__AVX512BF16__)
8224
- return 1;
8225
- #else
8226
- return 0;
8227
- #endif
8228
- }
8229
-
8230
- int lm_ggml_cpu_has_amx_int8(void) {
8231
- #if defined(__AMX_INT8__)
8232
- return 1;
8233
- #else
8234
- return 0;
8235
- #endif
8236
- }
8237
-
8238
- int lm_ggml_cpu_has_fma(void) {
8239
- #if defined(__FMA__)
8240
- return 1;
8241
- #else
8242
- return 0;
8243
- #endif
8244
- }
8245
-
8246
- int lm_ggml_cpu_has_arm_fma(void) {
8247
- #if defined(__ARM_FEATURE_FMA)
8248
- return 1;
8249
- #else
8250
- return 0;
8251
- #endif
8252
- }
8253
-
8254
- int lm_ggml_cpu_has_riscv_v(void) {
8255
- #if defined(__riscv_v_intrinsic)
8256
- return 1;
8257
- #else
8258
- return 0;
8259
- #endif
8260
- }
8261
-
8262
- int lm_ggml_cpu_has_metal(void) {
8263
- #if defined(LM_GGML_USE_METAL)
8264
- return 1;
8265
- #else
8266
- return 0;
8267
- #endif
8268
- }
8269
-
8270
- int lm_ggml_cpu_has_f16c(void) {
8271
- #if defined(__F16C__)
8272
- return 1;
8273
- #else
8274
- return 0;
8275
- #endif
8276
- }
8277
-
8278
- int lm_ggml_cpu_has_fp16_va(void) {
8279
- #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
8280
- return 1;
8281
- #else
8282
- return 0;
8283
- #endif
8284
- }
8285
-
8286
- int lm_ggml_cpu_has_wasm_simd(void) {
8287
- #if defined(__wasm_simd128__)
8288
- return 1;
8289
- #else
8290
- return 0;
8291
- #endif
8292
- }
8293
-
8294
- int lm_ggml_cpu_has_blas(void) {
8295
- #if defined(LM_GGML_USE_BLAS) || defined(LM_GGML_USE_CUDA) || defined(LM_GGML_USE_VULKAN) || defined(LM_GGML_USE_SYCL)
8296
- return 1;
8297
- #else
8298
- return 0;
8299
- #endif
8300
- }
8301
-
8302
- int lm_ggml_cpu_has_cuda(void) {
8303
- #if defined(LM_GGML_USE_CUDA)
8304
- return 1;
8305
- #else
8306
- return 0;
8307
- #endif
8308
- }
8309
-
8310
- int lm_ggml_cpu_has_vulkan(void) {
8311
- #if defined(LM_GGML_USE_VULKAN)
8312
- return 1;
8313
- #else
8314
- return 0;
8315
- #endif
8316
- }
8317
-
8318
- int lm_ggml_cpu_has_kompute(void) {
8319
- #if defined(LM_GGML_USE_KOMPUTE)
8320
- return 1;
8321
- #else
8322
- return 0;
8323
- #endif
8324
- }
8325
-
8326
- int lm_ggml_cpu_has_sycl(void) {
8327
- #if defined(LM_GGML_USE_SYCL)
8328
- return 1;
8329
- #else
8330
- return 0;
8331
- #endif
8332
- }
8333
-
8334
- int lm_ggml_cpu_has_rpc(void) {
8335
- #if defined(LM_GGML_USE_RPC)
8336
- return 1;
8337
- #else
8338
- return 0;
8339
- #endif
8340
- }
8341
-
8342
- int lm_ggml_cpu_has_cann(void) {
8343
- #if defined(LM_GGML_USE_CANN)
8344
- return 1;
8345
- #else
8346
- return 0;
8347
- #endif
8348
- }
8349
-
8350
- int lm_ggml_cpu_has_llamafile(void) {
8351
- #if defined(LM_GGML_USE_LLAMAFILE)
8352
- return 1;
8353
- #else
8354
- return 0;
8355
- #endif
8356
- }
8357
-
8358
- int lm_ggml_cpu_has_gpublas(void) {
8359
- return lm_ggml_cpu_has_cuda() || lm_ggml_cpu_has_vulkan() || lm_ggml_cpu_has_kompute() || lm_ggml_cpu_has_sycl();
8360
- }
8361
-
8362
- int lm_ggml_cpu_has_sse3(void) {
8363
- #if defined(__SSE3__)
8364
- return 1;
8365
- #else
8366
- return 0;
8367
- #endif
7592
+ void lm_ggml_log_set(lm_ggml_log_callback log_callback, void * user_data) {
7593
+ g_logger_state.log_callback = log_callback ? log_callback : lm_ggml_log_callback_default;
7594
+ g_logger_state.log_callback_user_data = user_data;
8368
7595
  }
8369
7596
 
8370
- int lm_ggml_cpu_has_ssse3(void) {
8371
- #if defined(__SSSE3__)
8372
- return 1;
8373
- #else
8374
- return 0;
8375
- #endif
7597
+ void lm_ggml_threadpool_params_init(struct lm_ggml_threadpool_params * p, int n_threads) {
7598
+ p->n_threads = n_threads;
7599
+ p->prio = 0; // default priority (usually means normal or inherited)
7600
+ p->poll = 50; // hybrid-polling enabled
7601
+ p->strict_cpu = false; // no strict placement (all threads share same cpumask)
7602
+ p->paused = false; // threads are ready to go
7603
+ memset(p->cpumask, 0, LM_GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited)
8376
7604
  }
8377
7605
 
8378
- int lm_ggml_cpu_has_vsx(void) {
8379
- #if defined(__POWER9_VECTOR__)
8380
- return 1;
8381
- #else
8382
- return 0;
8383
- #endif
7606
+ struct lm_ggml_threadpool_params lm_ggml_threadpool_params_default(int n_threads) {
7607
+ struct lm_ggml_threadpool_params p;
7608
+ lm_ggml_threadpool_params_init(&p, n_threads);
7609
+ return p;
8384
7610
  }
8385
7611
 
8386
- void lm_ggml_log_set(lm_ggml_log_callback log_callback, void * user_data) {
8387
- g_logger_state.log_callback = log_callback ? log_callback : lm_ggml_log_callback_default;
8388
- g_logger_state.log_callback_user_data = user_data;
7612
+ bool lm_ggml_threadpool_params_match(const struct lm_ggml_threadpool_params * p0, const struct lm_ggml_threadpool_params * p1) {
7613
+ if (p0->n_threads != p1->n_threads ) return false;
7614
+ if (p0->prio != p1->prio ) return false;
7615
+ if (p0->poll != p1->poll ) return false;
7616
+ if (p0->strict_cpu != p1->strict_cpu ) return false;
7617
+ return memcmp(p0->cpumask, p1->cpumask, LM_GGML_MAX_N_THREADS) == 0;
8389
7618
  }
8390
- ////////////////////////////////////////////////////////////////////////////////
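A small configuration sketch for the threadpool parameters introduced above; the field names come from lm_ggml_threadpool_params_init, the values are only illustrative:

    struct lm_ggml_threadpool_params tpp = lm_ggml_threadpool_params_default(8);
    tpp.prio       = 1;      // raise scheduling priority above the default 0
    tpp.poll       = 0;      // disable hybrid polling; worker threads block instead of spinning
    tpp.strict_cpu = true;   // place threads strictly according to the cpumask
    tpp.cpumask[0] = true;   // cpumask has LM_GGML_MAX_N_THREADS slots; all-zero means inherit affinity
    bool same = lm_ggml_threadpool_params_match(&tpp, &tpp);   // true: every field, including the cpumask, matches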