cui-llama.rn 1.4.0 → 1.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. package/README.md +4 -23
  2. package/android/build.gradle +12 -3
  3. package/android/src/main/CMakeLists.txt +13 -7
  4. package/android/src/main/java/com/rnllama/LlamaContext.java +27 -20
  5. package/android/src/main/java/com/rnllama/RNLlama.java +5 -1
  6. package/android/src/main/jni.cpp +15 -12
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  15. package/cpp/README.md +1 -1
  16. package/cpp/common.cpp +158 -267
  17. package/cpp/common.h +46 -12
  18. package/cpp/ggml-alloc.c +1042 -1037
  19. package/cpp/ggml-backend-impl.h +255 -256
  20. package/cpp/ggml-backend-reg.cpp +582 -582
  21. package/cpp/ggml-backend.cpp +2002 -2002
  22. package/cpp/ggml-backend.h +354 -352
  23. package/cpp/ggml-common.h +1853 -1853
  24. package/cpp/ggml-cpp.h +39 -39
  25. package/cpp/ggml-cpu-aarch64.cpp +4247 -4247
  26. package/cpp/ggml-cpu-aarch64.h +8 -8
  27. package/cpp/ggml-cpu-impl.h +386 -386
  28. package/cpp/ggml-cpu-quants.c +10920 -10839
  29. package/cpp/ggml-cpu-traits.cpp +36 -36
  30. package/cpp/ggml-cpu-traits.h +38 -38
  31. package/cpp/ggml-cpu.c +329 -60
  32. package/cpp/ggml-cpu.cpp +10 -2
  33. package/cpp/ggml-cpu.h +135 -135
  34. package/cpp/ggml-impl.h +567 -567
  35. package/cpp/ggml-metal-impl.h +17 -17
  36. package/cpp/ggml-metal.m +4884 -4884
  37. package/cpp/ggml-quants.c +5238 -5238
  38. package/cpp/ggml-threading.h +14 -14
  39. package/cpp/ggml.c +6514 -6448
  40. package/cpp/ggml.h +2194 -2163
  41. package/cpp/gguf.cpp +1329 -1325
  42. package/cpp/gguf.h +202 -202
  43. package/cpp/json-schema-to-grammar.cpp +1045 -1045
  44. package/cpp/json-schema-to-grammar.h +8 -8
  45. package/cpp/json.hpp +24766 -24766
  46. package/cpp/llama-adapter.cpp +347 -346
  47. package/cpp/llama-adapter.h +74 -73
  48. package/cpp/llama-arch.cpp +1487 -1434
  49. package/cpp/llama-arch.h +400 -395
  50. package/cpp/llama-batch.cpp +368 -368
  51. package/cpp/llama-batch.h +88 -88
  52. package/cpp/llama-chat.cpp +578 -567
  53. package/cpp/llama-chat.h +52 -51
  54. package/cpp/llama-context.cpp +1775 -1771
  55. package/cpp/llama-context.h +128 -128
  56. package/cpp/llama-cparams.cpp +1 -1
  57. package/cpp/llama-cparams.h +37 -37
  58. package/cpp/llama-cpp.h +30 -30
  59. package/cpp/llama-grammar.cpp +1139 -1139
  60. package/cpp/llama-grammar.h +143 -143
  61. package/cpp/llama-hparams.cpp +71 -71
  62. package/cpp/llama-hparams.h +139 -140
  63. package/cpp/llama-impl.cpp +167 -167
  64. package/cpp/llama-impl.h +61 -61
  65. package/cpp/llama-kv-cache.cpp +718 -718
  66. package/cpp/llama-kv-cache.h +218 -218
  67. package/cpp/llama-mmap.cpp +2 -1
  68. package/cpp/llama-mmap.h +67 -67
  69. package/cpp/llama-model-loader.cpp +1124 -1011
  70. package/cpp/llama-model-loader.h +167 -158
  71. package/cpp/llama-model.cpp +3997 -2202
  72. package/cpp/llama-model.h +370 -391
  73. package/cpp/llama-sampling.cpp +2408 -2406
  74. package/cpp/llama-sampling.h +32 -48
  75. package/cpp/llama-vocab.cpp +3247 -1982
  76. package/cpp/llama-vocab.h +125 -182
  77. package/cpp/llama.cpp +416 -2886
  78. package/cpp/llama.h +1323 -1285
  79. package/cpp/log.cpp +401 -401
  80. package/cpp/log.h +121 -121
  81. package/cpp/rn-llama.cpp +822 -0
  82. package/cpp/rn-llama.h +123 -0
  83. package/cpp/rn-llama.hpp +18 -12
  84. package/cpp/sampling.cpp +505 -500
  85. package/cpp/sgemm.cpp +2597 -2597
  86. package/cpp/speculative.cpp +277 -274
  87. package/cpp/speculative.h +28 -28
  88. package/cpp/unicode.cpp +2 -3
  89. package/ios/CMakeLists.txt +99 -0
  90. package/ios/RNLlama.h +5 -1
  91. package/ios/RNLlama.mm +2 -2
  92. package/ios/RNLlamaContext.h +8 -1
  93. package/ios/RNLlamaContext.mm +15 -11
  94. package/ios/rnllama.xcframework/Info.plist +74 -0
  95. package/jest/mock.js +3 -2
  96. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  97. package/lib/commonjs/index.js +4 -2
  98. package/lib/commonjs/index.js.map +1 -1
  99. package/lib/module/NativeRNLlama.js.map +1 -1
  100. package/lib/module/index.js +4 -2
  101. package/lib/module/index.js.map +1 -1
  102. package/lib/typescript/NativeRNLlama.d.ts +5 -1
  103. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  104. package/lib/typescript/index.d.ts.map +1 -1
  105. package/llama-rn.podspec +8 -2
  106. package/package.json +5 -2
  107. package/src/NativeRNLlama.ts +5 -1
  108. package/src/index.ts +9 -2
package/cpp/ggml-cpu.c CHANGED
@@ -3966,6 +3966,57 @@ static void lm_ggml_compute_forward_dup_bytes(
     }
 }
 
+static void lm_ggml_compute_forward_dup_q(
+        const struct lm_ggml_compute_params * params,
+        struct lm_ggml_tensor * dst) {
+
+    const struct lm_ggml_tensor * src0 = dst->src[0];
+    const struct lm_ggml_tensor * src1 = dst->src[1];
+
+    LM_GGML_TENSOR_BINARY_OP_LOCALS
+
+    const enum lm_ggml_type type = src0->type;
+    lm_ggml_to_float_t const dequantize_row_q = lm_ggml_get_type_traits(type)->to_float;
+
+    size_t qk = lm_ggml_blck_size(type);
+    const int64_t nr = lm_ggml_nelements(src1) / qk;
+
+    // destination must be contiguous in the first dimension
+    LM_GGML_ASSERT(nb10 == lm_ggml_type_size(dst->type));
+    // must either have first dimension large enough to hold a row, or fully contiguous
+    LM_GGML_ASSERT((ne10 % qk) == 0 || lm_ggml_is_contiguous(dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+
+        uint32_t i = ir * qk;
+
+        const int64_t i03 = i/(ne00 * ne01 * ne02);
+        const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
+        const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+        const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+        const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
+
+        const int64_t i13 = i/(ne10 * ne11 * ne12);
+        const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+        const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+        const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+        const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+        dequantize_row_q(
+                (const void *) ((char *) src0->data + x_offset),
+                     (float *) ((char *)  dst->data + dst_offset), qk);
+    }
+}
+
 static void lm_ggml_compute_forward_dup(
         const struct lm_ggml_compute_params * params,
         struct lm_ggml_tensor * dst) {
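Note: dup_q splits work across threads by quantized block rows and recovers 4-D tensor coordinates from a flat element index before calling the per-row dequantizer. A minimal sketch of that unravel-then-offset arithmetic, with illustrative names not taken from the diff:

#include <stdint.h>

// Sketch: unravel flat index i into (i0,i1,i2,i3) for extents ne0..ne3,
// then turn the coordinates into a byte offset via per-dimension byte
// strides nb0..nb3 -- the same arithmetic dup_q uses for both tensors.
static int64_t unravel_offset(int64_t i,
                              const int64_t ne[4], const int64_t nb[4]) {
    const int64_t i3 = i / (ne[0]*ne[1]*ne[2]);
    const int64_t i2 = (i - i3*ne[0]*ne[1]*ne[2]) / (ne[0]*ne[1]);
    const int64_t i1 = (i - i3*ne[0]*ne[1]*ne[2] - i2*ne[0]*ne[1]) / ne[0];
    const int64_t i0 =  i - i3*ne[0]*ne[1]*ne[2] - i2*ne[0]*ne[1] - i1*ne[0];
    return i0*nb[0] + i1*nb[1] + i2*nb[2] + i3*nb[3];
}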
@@ -3992,6 +4043,10 @@ static void lm_ggml_compute_forward_dup(
             } break;
         default:
             {
+                if (lm_ggml_is_quantized(src0->type) && dst->type == LM_GGML_TYPE_F32) {
+                    lm_ggml_compute_forward_dup_q(params, dst);
+                    break;
+                }
                 LM_GGML_ABORT("fatal error");
             }
     }
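With this fallback in place, a copy from any quantized type straight to F32 dequantizes instead of aborting. A hedged usage sketch against the package's lm_ggml API (the helper name is illustrative):

#include "ggml.h"   // lm_ggml_* API shipped in package/cpp

// Sketch: a dequantizing copy. cpy from a quantized tensor to an F32
// tensor of the same shape now routes through the new dup_q path
// instead of hitting LM_GGML_ABORT. ctx must be a valid context.
static struct lm_ggml_tensor * dequant_copy(struct lm_ggml_context * ctx,
                                            struct lm_ggml_tensor * q) {
    struct lm_ggml_tensor * f = lm_ggml_new_tensor_4d(ctx, LM_GGML_TYPE_F32,
            q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
    return lm_ggml_cpy(ctx, q, f);
}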
@@ -6690,20 +6745,20 @@ static void lm_ggml_compute_forward_silu_back_f32(
         const struct lm_ggml_compute_params * params,
         struct lm_ggml_tensor * dst) {
 
-    const struct lm_ggml_tensor * src0 = dst->src[0];
-    const struct lm_ggml_tensor * grad = dst->src[1];
+    const struct lm_ggml_tensor * grad = dst->src[0];
+    const struct lm_ggml_tensor * src1 = dst->src[1];
 
     assert(lm_ggml_is_contiguous_1(grad));
-    assert(lm_ggml_is_contiguous_1(src0));
+    assert(lm_ggml_is_contiguous_1(src1));
     assert(lm_ggml_is_contiguous_1(dst));
-    assert(lm_ggml_are_same_shape(src0, dst));
-    assert(lm_ggml_are_same_shape(src0, grad));
+    assert(lm_ggml_are_same_shape(src1, dst));
+    assert(lm_ggml_are_same_shape(src1, grad));
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nc = src0->ne[0];
-    const int nr = lm_ggml_nrows(src0);
+    const int nc = src1->ne[0];
+    const int nr = lm_ggml_nrows(src1);
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -6715,7 +6770,7 @@ static void lm_ggml_compute_forward_silu_back_f32(
     for (int i1 = ir0; i1 < ir1; i1++) {
         lm_ggml_vec_silu_backward_f32(nc,
                 (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])),
+                (float *) ((char *) src1->data + i1*(src1->nb[1])),
                 (float *) ((char *) grad->data + i1*(grad->nb[1])));
 
 #ifndef NDEBUG
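For reference, lm_ggml_vec_silu_backward_f32 applies the SiLU derivative elementwise to the stored forward input; the operand swap above only changes which src slot carries the incoming gradient. A scalar sketch of the per-element math:

#include <math.h>

// Sketch: gradient of silu(x) = x * sigmoid(x).
// d/dx silu(x) = s * (1 + x * (1 - s)), with s = sigmoid(x).
static float silu_backward(float x, float dy) {
    const float s = 1.0f / (1.0f + expf(-x));
    return dy * s * (1.0f + x * (1.0f - s));
}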
@@ -6894,7 +6949,7 @@ static void lm_ggml_compute_forward_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    LM_GGML_ASSERT(eps > 0.0f);
+    LM_GGML_ASSERT(eps >= 0.0f);
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -6965,7 +7020,7 @@ static void lm_ggml_compute_forward_rms_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    LM_GGML_ASSERT(eps > 0.0f);
+    LM_GGML_ASSERT(eps >= 0.0f);
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
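The relaxed assertion now permits eps == 0, i.e. a bare root-mean-square divisor. A scalar sketch of the forward normalization these asserts guard (illustrative only):

#include <math.h>

// Sketch: RMS-normalize one row of n floats. With eps == 0 this is a
// plain RMS division, which the relaxed assertion now allows (the
// caller must then avoid all-zero rows).
static void rms_norm_row(float * y, const float * x, int n, float eps) {
    float sum_xx = 0.0f;
    for (int i = 0; i < n; i++) {
        sum_xx += x[i] * x[i];
    }
    const float rrms = 1.0f / sqrtf(sum_xx / n + eps);
    for (int i = 0; i < n; i++) {
        y[i] = x[i] * rrms;
    }
}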
@@ -7017,12 +7072,13 @@ static void lm_ggml_compute_forward_rms_norm_back_f32(
         const struct lm_ggml_compute_params * params,
         struct lm_ggml_tensor * dst) {
 
-    const struct lm_ggml_tensor * src0 = dst->src[0];
-    const struct lm_ggml_tensor * src1 = dst->src[1];
+    const struct lm_ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
+    const struct lm_ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
 
     LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst) && lm_ggml_are_same_shape(src0, src1));
 
     LM_GGML_ASSERT(src0->nb[0] == sizeof(float));
+    LM_GGML_ASSERT(src1->nb[0] == sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -7041,8 +7097,8 @@ static void lm_ggml_compute_forward_rms_norm_back_f32(
             const int64_t i12 = i02;
             const int64_t i13 = i03;
 
-            const float * x  = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-            const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
+            const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+            const float * x  = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
 
             lm_ggml_float sum_xx  = 0.0;
             lm_ggml_float sum_xdz = 0.0;
@@ -7065,9 +7121,9 @@ static void lm_ggml_compute_forward_rms_norm_back_f32(
             {
                 // z = rms_norm(x)
                 //
-                // rms_norm(src0) =
+                // rms_norm(src1) =
                 //     scale(
-                //         src0,
+                //         src1,
                 //      div(
                 //          1,
                 //          sqrt(
@@ -7075,13 +7131,13 @@ static void lm_ggml_compute_forward_rms_norm_back_f32(
                //          scale(
                //              sum(
                //                  sqr(
-               //                      src0)),
+               //                      src1)),
                //              (1.0/N)),
                //          eps))));
 
                // postorder:
                // ## op    args         grad
-               // 00 param src0         grad[#00]
+               // 00 param src1         grad[#00]
                // 01 const 1
                // 02 sqr   (#00)        grad[#02]
                // 03 sum   (#02)        grad[#03]
@@ -7158,6 +7214,7 @@ static void lm_ggml_compute_forward_rms_norm_back_f32(
                 // dx := scale(dx, rrms)
                 float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
+                // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
                 lm_ggml_vec_cpy_f32  (ne00, dx, x);
                 // lm_ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
                 lm_ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
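The new comment states the closed form that the vector-op sequence implements. A scalar sketch of the same backward step under that formula (names illustrative):

#include <math.h>

// Sketch: RMS-norm backward for one row, matching the commented formula
// dx[i] = (x[i] * (-sum_xdz/sum_eps) + dz[i]) / sqrtf(mean_eps).
static void rms_norm_back_row(float * dx, const float * x, const float * dz,
                              int n, float eps) {
    double sum_xx = 0.0, sum_xdz = 0.0;
    for (int i = 0; i < n; i++) {
        sum_xx  += (double) x[i] * x[i];
        sum_xdz += (double) x[i] * dz[i];
    }
    const double mean_eps = sum_xx / n + eps;  // mean(x^2) + eps
    const double sum_eps  = sum_xx + eps * n;  // n * mean_eps
    for (int i = 0; i < n; i++) {
        dx[i] = (float) ((x[i] * (-sum_xdz / sum_eps) + dz[i]) / sqrt(mean_eps));
    }
}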
@@ -7749,12 +7806,13 @@ static void lm_ggml_compute_forward_out_prod_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    LM_GGML_ASSERT(ne0  == ne00);
-    LM_GGML_ASSERT(ne1  == ne10);
-    LM_GGML_ASSERT(ne2  == ne02);
-    LM_GGML_ASSERT(ne02 == ne12);
-    LM_GGML_ASSERT(ne3  == ne13);
-    LM_GGML_ASSERT(ne03 == ne13);
+    LM_GGML_ASSERT(ne0 == ne00);
+    LM_GGML_ASSERT(ne1 == ne10);
+    LM_GGML_ASSERT(ne2 == ne12);
+    LM_GGML_ASSERT(ne3 == ne13);
+
+    LM_GGML_ASSERT(ne2 % ne02 == 0);
+    LM_GGML_ASSERT(ne3 % ne03 == 0);
 
     // we don't support permuted src0 or src1
     LM_GGML_ASSERT(nb00 == sizeof(float));
@@ -7796,6 +7854,10 @@ static void lm_ggml_compute_forward_out_prod_f32(
     const int64_t blck_0 = MAX(LM_GGML_VEC_MAD_UNROLL, 32);
     const int64_t blck_1 = 16;
 
+    // dps == dst per src0, used for group query attention
+    const int64_t dps2 = ne2 / ne02;
+    const int64_t dps3 = ne3 / ne03;
+
     for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
         const int64_t bir1 = MIN(bir + blck_1, ir1);
         for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
@@ -7806,8 +7868,8 @@ static void lm_ggml_compute_forward_out_prod_f32(
                 const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
                 const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
 
-                const int64_t i02 = i2;
-                const int64_t i03 = i3;
+                const int64_t i02 = i2 / dps2;
+                const int64_t i03 = i3 / dps3;
 
                 //const int64_t i10 = i1;
                 const int64_t i12 = i2;
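The dps2/dps3 divisors let several dst batch/head indices share one src0 index, which is the broadcast pattern grouped-query attention needs. A toy illustration with hypothetical head counts:

#include <stdio.h>

// Sketch: with ne2 = 32 dst heads and ne02 = 8 src0 heads,
// dps2 = ne2/ne02 = 4, so dst heads 0..3 all read src0 head 0,
// dst heads 4..7 read src0 head 1, and so on.
int main(void) {
    const int ne2 = 32, ne02 = 8;
    const int dps2 = ne2 / ne02;        // dst-per-src, as in the diff
    for (int i2 = 0; i2 < ne2; i2++) {
        const int i02 = i2 / dps2;      // broadcast src0 index
        printf("dst head %2d <- src0 head %d\n", i2, i02);
    }
    return 0;
}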
@@ -8905,9 +8967,9 @@ static void lm_ggml_compute_forward_soft_max(
 }
 
 
-// lm_ggml_compute_forward_soft_max_back
+// lm_ggml_compute_forward_soft_max_ext_back
 
-static void lm_ggml_compute_forward_soft_max_back_f32(
+static void lm_ggml_compute_forward_soft_max_ext_back_f32(
         const struct lm_ggml_compute_params * params,
         struct lm_ggml_tensor * dst) {
 
@@ -8920,6 +8982,14 @@ static void lm_ggml_compute_forward_soft_max_back_f32(
     LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst));
     LM_GGML_ASSERT(lm_ggml_are_same_shape(src1, dst));
 
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+
+    LM_GGML_ASSERT(max_bias == 0.0f);
+
     // TODO: handle transposed/permuted matrices
 
     const int ith = params->ith;
@@ -8968,10 +9038,11 @@ static void lm_ggml_compute_forward_soft_max_back_f32(
 
         // linear runtime, no additional memory
         float dot_y_dy = 0;
-        lm_ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
-        lm_ggml_vec_cpy_f32 (nc, dx, dy);
-        lm_ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
-        lm_ggml_vec_mul_f32 (nc, dx, dx, y);
+        lm_ggml_vec_dot_f32  (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
+        lm_ggml_vec_cpy_f32  (nc, dx, dy);
+        lm_ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
+        lm_ggml_vec_mul_f32  (nc, dx, dx, y);
+        lm_ggml_vec_scale_f32(nc, dx, scale);
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
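With the added lm_ggml_vec_scale_f32 call, the sequence computes the standard softmax Jacobian product times the forward scale. A scalar sketch of what the five vector ops amount to per row, assuming max_bias == 0 as the new assert requires:

// Sketch: backward of y = softmax(scale * x) for one row of n floats:
//   dx = scale * y * (dy - dot(y, dy))
static void soft_max_back_row(float * dx, const float * y, const float * dy,
                              int n, float scale) {
    float dot_y_dy = 0.0f;
    for (int i = 0; i < n; i++) {
        dot_y_dy += y[i] * dy[i];
    }
    for (int i = 0; i < n; i++) {
        dx[i] = scale * y[i] * (dy[i] - dot_y_dy);
    }
}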
@@ -8982,7 +9053,7 @@ static void lm_ggml_compute_forward_soft_max_back_f32(
     }
 }
 
-static void lm_ggml_compute_forward_soft_max_back(
+static void lm_ggml_compute_forward_soft_max_ext_back(
         const struct lm_ggml_compute_params * params,
         struct lm_ggml_tensor * dst) {
 
@@ -8991,7 +9062,7 @@ static void lm_ggml_compute_forward_soft_max_back(
     switch (src0->type) {
         case LM_GGML_TYPE_F32:
             {
-                lm_ggml_compute_forward_soft_max_back_f32(params, dst);
+                lm_ggml_compute_forward_soft_max_ext_back_f32(params, dst);
             } break;
         default:
             {
@@ -9984,9 +10055,10 @@ static void lm_ggml_compute_forward_im2col_back_f32(
         const struct lm_ggml_compute_params * params,
         struct lm_ggml_tensor * dst) {
 
-    const struct lm_ggml_tensor * src0 = dst->src[0];
-    const struct lm_ggml_tensor * src1 = dst->src[1];
+    const struct lm_ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
+    const struct lm_ggml_tensor * src1 = dst->src[1]; // convolution kernel
 
+    LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
     LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
     LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32);
 
@@ -10008,11 +10080,11 @@ static void lm_ggml_compute_forward_im2col_back_f32(
     const int64_t IH = is_2D ? ne1 : 1;
     const int64_t IW = ne0;
 
-    const int64_t KH = is_2D ? ne01 : 1;
-    const int64_t KW = ne00;
+    const int64_t KH = is_2D ? ne11 : 1;
+    const int64_t KW = ne10;
 
-    const int64_t OH = is_2D ? ne12 : 1;
-    const int64_t OW = ne11;
+    const int64_t OH = is_2D ? ne02 : 1;
+    const int64_t OW = ne01;
 
     int ofs0 = is_2D ? nb3 : nb2;
     int ofs1 = is_2D ? nb2 : nb1;
@@ -10058,9 +10130,9 @@ static void lm_ggml_compute_forward_im2col_back_f32(
                                     continue;
                                 }
 
-                                const float * const src_data = (const float *) src1->data
+                                const float * const grad_in = (const float *) src0->data
                                     + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                                grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
+                                grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
                             }
                         }
                         float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
@@ -11802,9 +11874,9 @@ static void lm_ggml_compute_forward_add_rel_pos(
 static void lm_ggml_compute_forward_rwkv_wkv6_f32(
         const struct lm_ggml_compute_params * params,
         struct lm_ggml_tensor * dst) {
-    const int64_t T = dst->src[1]->ne[3];
+    const int64_t T = dst->src[1]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t HEADS = dst->src[1]->ne[2];
+    const int64_t HEADS = dst->src[1]->ne[1];
     const int64_t n_seqs = dst->src[5]->ne[1];
     const int64_t head_size = C / HEADS;
 
@@ -11999,6 +12071,197 @@ static void lm_ggml_compute_forward_rwkv_wkv6(
     }
 }
 
+// lm_ggml_compute_forward_gla
+
+static void lm_ggml_compute_forward_gla_f32(
+        const struct lm_ggml_compute_params * params,
+        struct lm_ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[4]->ne[1];
+    const int64_t head_size = C / HEADS;
+    const float scale = lm_ggml_get_op_params_f32(dst, 0);
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * k = (float *) dst->src[0]->data;
+    float * v = (float *) dst->src[1]->data;
+    float * q = (float *) dst->src[2]->data;
+    float * g = (float *) dst->src[3]->data;
+
+    size_t t_stride = HEADS * head_size; // Same to C
+
+    size_t h_stride = C / HEADS;
+    LM_GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
+
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    lm_ggml_barrier(params->threadpool);
+
+
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define LM_GGML_F32X LM_GGML_F32x8
+        #define LM_GGML_F32X_SET1 LM_GGML_F32x8_SET1
+        #define LM_GGML_F32X_LOAD LM_GGML_F32x8_LOAD
+        #define LM_GGML_F32X_STORE LM_GGML_F32x8_STORE
+        #define LM_GGML_F32X_MUL LM_GGML_F32x8_MUL
+        #define LM_GGML_F32X_FMA LM_GGML_F32x8_FMA
+        #define GLA_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define LM_GGML_F32X LM_GGML_F32x16
+        #define LM_GGML_F32X_SET1 LM_GGML_F32x16_SET1
+        #define LM_GGML_F32X_LOAD LM_GGML_F32x16_LOAD
+        #define LM_GGML_F32X_STORE LM_GGML_F32x16_STORE
+        #define LM_GGML_F32X_MUL LM_GGML_F32x16_MUL
+        #define LM_GGML_F32X_FMA LM_GGML_F32x16_FMA
+        #define GLA_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define LM_GGML_F32X LM_GGML_F32x4
+        #define LM_GGML_F32X_SET1 LM_GGML_F32x4_SET1
+        #define LM_GGML_F32X_LOAD LM_GGML_F32x4_LOAD
+        #define LM_GGML_F32X_STORE LM_GGML_F32x4_STORE
+        #define LM_GGML_F32X_MUL LM_GGML_F32x4_MUL
+        #define LM_GGML_F32X_FMA LM_GGML_F32x4_FMA
+        #define GLA_VECTOR_SIZE 4
+    #endif
+
+    #ifdef GLA_VECTOR_SIZE
+        const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    LM_GGML_F32X k_vec = LM_GGML_F32X_SET1(k_val);
+                    LM_GGML_F32X q_vec = LM_GGML_F32X_SET1(q_val);
+                    LM_GGML_F32X g_vec = LM_GGML_F32X_SET1(g_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * GLA_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load x elements at once
+                        LM_GGML_F32X v_vec = LM_GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        LM_GGML_F32X prev_state_vec = LM_GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        LM_GGML_F32X dst_vec = LM_GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        LM_GGML_F32X kv_vec = LM_GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = prev_state * g + kv
+                        LM_GGML_F32X temp_vec = LM_GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
+
+                        // Update dst: dst += temp * q
+                        dst_vec = LM_GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
+                        LM_GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state
+                        LM_GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
+                    }
+
+                    // Handle remaining elements, this will not be used.
+                    for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val + prev_state_val * g_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = prev_state_val * g_val + kv_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+    #endif
+}
+
+
+static void lm_ggml_compute_forward_gla(
+        const struct lm_ggml_compute_params * params,
+        struct lm_ggml_tensor * dst) {
+
+    const struct lm_ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case LM_GGML_TYPE_F32:
+            {
+                lm_ggml_compute_forward_gla_f32(params, dst);
+            } break;
+        default:
+            {
+                LM_GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // lm_ggml_compute_forward_map_unary
 
 static void lm_ggml_compute_forward_map_unary_f32(
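Stripped of SIMD and threading, the recurrence this kernel implements per head is a gated linear-attention state update; the scalar #else branch above is the reference. A minimal single-timestep sketch (layout and names illustrative):

// Sketch: one timestep of gated linear attention for a single head.
// S is a head_size x head_size state matrix (row i indexed by the k/q
// dimension, column j by the v dimension), matching the scalar branch:
//   S[i][j]  = g[i] * S[i][j] + k[i] * v[j]
//   out[j]  += q[i] * scale * S[i][j]
static void gla_step(float * S, float * out,
                     const float * k, const float * v,
                     const float * q, const float * g,
                     int head_size, float scale) {
    for (int i = 0; i < head_size; i++) {
        const float k_val = k[i];
        const float q_val = q[i] * scale;
        const float g_val = g[i];
        for (int j = 0; j < head_size; j++) {
            const float temp = S[i*head_size + j] * g_val + v[j] * k_val;
            out[j] += temp * q_val;
            S[i*head_size + j] = temp;
        }
    }
}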
@@ -12292,22 +12555,22 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back_f32(
         const struct lm_ggml_compute_params * params,
         struct lm_ggml_tensor * dst) {
 
-    const struct lm_ggml_tensor * src0 = dst->src[0];
-    const struct lm_ggml_tensor * src1 = dst->src[1];
-    const struct lm_ggml_tensor * opt0 = dst->src[2];
+    const struct lm_ggml_tensor * grad  = dst->src[0]; // gradient of forward pass output
+    const struct lm_ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
+    const struct lm_ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
 
     LM_GGML_ASSERT(lm_ggml_is_contiguous(dst));
-    LM_GGML_ASSERT(lm_ggml_is_contiguous(src0));
-    LM_GGML_ASSERT(lm_ggml_is_contiguous(src1));
-    LM_GGML_ASSERT(lm_ggml_is_contiguous(opt0));
-    LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst));
+    LM_GGML_ASSERT(lm_ggml_is_contiguous(src0f));
+    LM_GGML_ASSERT(lm_ggml_is_contiguous(src1f));
+    LM_GGML_ASSERT(lm_ggml_is_contiguous(grad));
+    LM_GGML_ASSERT(lm_ggml_are_same_shape(src0f, src1f) && lm_ggml_are_same_shape(src0f, dst));
 
     const int64_t ith = params->ith;
     const int64_t nth = params->nth;
 
     // TODO: handle transposed/permuted matrices
-    const int64_t nc = src0->ne[0];
-    const int64_t nr = lm_ggml_nrows(src0);
+    const int64_t nc = src0f->ne[0];
+    const int64_t nr = lm_ggml_nrows(src0f);
 
     // rows per thread
     const int64_t dr = (nr + nth - 1)/nth;
@@ -12316,12 +12579,12 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back_f32(
     const int64_t ir0 = dr*ith;
     const int64_t ir1 = MIN(ir0 + dr, nr);
 
-    const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
+    const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
 
     for (int64_t i1 = ir0; i1 < ir1; i1++) {
-        float * ds0 = (float *)((char *) dst->data  + i1*dst->nb[1]);
-        float * s0  = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * s1  = (float *)((char *) src1->data + i1*src1->nb[1]);
+              float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
+        const float * s0  = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
+        const float * s1  = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
 
 #ifndef NDEBUG
         for (int64_t i = 0; i < nc; ++i) {
@@ -12334,11 +12597,11 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back_f32(
         // soft_max
         float max = -INFINITY;
         lm_ggml_vec_max_f32(nc, &max, s0);
-        lm_ggml_float sum = lm_ggml_vec_soft_max_f32(nc, ds0, s0, max);
+        const lm_ggml_float sum = lm_ggml_vec_soft_max_f32(nc, ds0, s0, max);
        assert(sum > 0.0);
        lm_ggml_vec_scale_f32(nc, ds0, 1.0/sum);
 
-        // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
+        // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
        lm_ggml_vec_sub_f32(nc, ds0, ds0, s1);
        lm_ggml_vec_scale_f32(nc, ds0, d_by_nr);
 
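The renamed operands make the commented gradient formula direct to read off. A scalar sketch of the per-row computation, grad(src0f) = (softmax(src0f) - src1f) * d / nr (illustrative):

#include <math.h>

// Sketch: cross-entropy backward for one row of n logits.
static void cross_entropy_back_row(float * ds0, const float * s0,
                                   const float * s1, int n, float d_by_nr) {
    float max = -INFINITY, sum = 0.0f;
    for (int i = 0; i < n; i++) {
        if (s0[i] > max) max = s0[i];
    }
    for (int i = 0; i < n; i++) {
        ds0[i] = expf(s0[i] - max);   // numerically stable softmax
        sum += ds0[i];
    }
    for (int i = 0; i < n; i++) {
        ds0[i] = (ds0[i] / sum - s1[i]) * d_by_nr;
    }
}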
@@ -12635,7 +12898,7 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * tensor) {
             } break;
         case LM_GGML_OP_SOFT_MAX_BACK:
             {
-                lm_ggml_compute_forward_soft_max_back(params, tensor);
+                lm_ggml_compute_forward_soft_max_ext_back(params, tensor);
             } break;
         case LM_GGML_OP_ROPE:
             {
@@ -12748,6 +13011,10 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * tensor) {
             {
                 lm_ggml_compute_forward_rwkv_wkv6(params, tensor);
             } break;
+        case LM_GGML_OP_GATED_LINEAR_ATTN:
+            {
+                lm_ggml_compute_forward_gla(params, tensor);
+            } break;
         case LM_GGML_OP_MAP_UNARY:
             {
                 lm_ggml_unary_op_f32_t fun;
@@ -13046,6 +13313,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) {
         case LM_GGML_OP_WIN_UNPART:
         case LM_GGML_OP_GET_REL_POS:
         case LM_GGML_OP_RWKV_WKV6:
+        case LM_GGML_OP_GATED_LINEAR_ATTN:
         case LM_GGML_OP_MAP_UNARY:
         case LM_GGML_OP_MAP_BINARY:
         case LM_GGML_OP_MAP_CUSTOM1_F32:
@@ -13471,6 +13739,7 @@ struct lm_ggml_cplan lm_ggml_graph_plan(
                 } break;
             case LM_GGML_OP_SOFT_MAX:
             case LM_GGML_OP_ROPE:
+            case LM_GGML_OP_ROPE_BACK:
                 {
                     cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->ne[0] * n_tasks;
                 } break;
package/cpp/ggml-cpu.cpp CHANGED
@@ -402,8 +402,16 @@ static bool lm_ggml_backend_cpu_device_supports_op(lm_ggml_backend_dev_t dev, const struct lm_ggml_tensor * op) {
                    op->type != LM_GGML_TYPE_IQ1_M; // missing type_traits.from_float
         case LM_GGML_OP_MUL_MAT:
             return src1->type == LM_GGML_TYPE_F32 || src1->type == lm_ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
-        case LM_GGML_OP_ROPE_BACK:
-            return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
+        case LM_GGML_OP_SOFT_MAX_BACK: {
+            if (op->src[0]->type != LM_GGML_TYPE_F32 || op->src[1]->type != LM_GGML_TYPE_F32) {
+                return false;
+            }
+            float max_bias = 0.0f;
+
+            memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
+
+            return max_bias == 0.0f;
+        }
         case LM_GGML_OP_IM2COL_BACK:
             return src0->type == LM_GGML_TYPE_F32 && src1->type == LM_GGML_TYPE_F32;
         case LM_GGML_OP_OUT_PROD:
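The new LM_GGML_OP_SOFT_MAX_BACK case mirrors the kernel's max_bias == 0 assert by rejecting unsupported ops up front. ggml packs small op parameters into an int32 array and reads floats back with memcpy; a minimal sketch of that read (the helper name is hypothetical):

#include <stdint.h>
#include <string.h>

// Sketch: read op_params[idx] reinterpreted as a float, the same
// pattern the supports_op check above uses for max_bias.
static float op_param_f32(const int32_t * op_params, int idx) {
    float v = 0.0f;
    memcpy(&v, (const float *) op_params + idx, sizeof(float));
    return v;
}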