@fugood/llama.node 0.3.6 → 0.3.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. package/README.md +17 -2
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +3 -1
  19. package/lib/index.js +16 -1
  20. package/lib/index.ts +16 -0
  21. package/package.json +1 -1
  22. package/src/EmbeddingWorker.cpp +4 -3
  23. package/src/LlamaCompletionWorker.cpp +4 -2
  24. package/src/LlamaContext.cpp +61 -6
  25. package/src/LlamaContext.h +1 -0
  26. package/src/common.hpp +6 -11
  27. package/src/llama.cpp/.github/workflows/build.yml +19 -17
  28. package/src/llama.cpp/.github/workflows/docker.yml +77 -30
  29. package/src/llama.cpp/.github/workflows/editorconfig.yml +3 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +22 -3
  31. package/src/llama.cpp/CMakeLists.txt +49 -24
  32. package/src/llama.cpp/common/arg.cpp +82 -26
  33. package/src/llama.cpp/common/arg.h +3 -0
  34. package/src/llama.cpp/common/common.cpp +192 -72
  35. package/src/llama.cpp/common/common.h +51 -18
  36. package/src/llama.cpp/common/ngram-cache.cpp +12 -12
  37. package/src/llama.cpp/common/ngram-cache.h +2 -2
  38. package/src/llama.cpp/common/sampling.cpp +11 -6
  39. package/src/llama.cpp/common/speculative.cpp +18 -15
  40. package/src/llama.cpp/docs/build.md +2 -0
  41. package/src/llama.cpp/examples/batched/batched.cpp +9 -7
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +3 -3
  43. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +10 -8
  44. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +11 -8
  45. package/src/llama.cpp/examples/cvector-generator/mean.hpp +1 -1
  46. package/src/llama.cpp/examples/cvector-generator/pca.hpp +1 -1
  47. package/src/llama.cpp/examples/embedding/embedding.cpp +8 -7
  48. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +7 -6
  49. package/src/llama.cpp/examples/export-lora/export-lora.cpp +8 -7
  50. package/src/llama.cpp/examples/gguf/gguf.cpp +10 -6
  51. package/src/llama.cpp/examples/gguf-hash/gguf-hash.cpp +1 -0
  52. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +8 -7
  53. package/src/llama.cpp/examples/gritlm/gritlm.cpp +13 -10
  54. package/src/llama.cpp/examples/imatrix/imatrix.cpp +13 -12
  55. package/src/llama.cpp/examples/infill/infill.cpp +23 -24
  56. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +44 -13
  57. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -6
  58. package/src/llama.cpp/examples/llava/clip.cpp +4 -2
  59. package/src/llama.cpp/examples/llava/llava-cli.cpp +9 -6
  60. package/src/llama.cpp/examples/llava/llava.cpp +2 -2
  61. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +8 -4
  62. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +11 -8
  63. package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -7
  64. package/src/llama.cpp/examples/lookup/lookup-create.cpp +4 -9
  65. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +3 -7
  66. package/src/llama.cpp/examples/lookup/lookup.cpp +5 -6
  67. package/src/llama.cpp/examples/main/main.cpp +51 -29
  68. package/src/llama.cpp/examples/parallel/parallel.cpp +5 -6
  69. package/src/llama.cpp/examples/passkey/passkey.cpp +7 -5
  70. package/src/llama.cpp/examples/perplexity/perplexity.cpp +37 -23
  71. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -14
  72. package/src/llama.cpp/examples/retrieval/retrieval.cpp +8 -8
  73. package/src/llama.cpp/examples/rpc/rpc-server.cpp +12 -0
  74. package/src/llama.cpp/examples/run/CMakeLists.txt +1 -1
  75. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +1351 -0
  76. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +114 -0
  77. package/src/llama.cpp/examples/run/run.cpp +175 -61
  78. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -25
  79. package/src/llama.cpp/examples/server/CMakeLists.txt +1 -0
  80. package/src/llama.cpp/examples/server/httplib.h +1295 -409
  81. package/src/llama.cpp/examples/server/server.cpp +387 -181
  82. package/src/llama.cpp/examples/server/tests/requirements.txt +1 -0
  83. package/src/llama.cpp/examples/server/utils.hpp +170 -58
  84. package/src/llama.cpp/examples/simple/simple.cpp +9 -8
  85. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +16 -12
  86. package/src/llama.cpp/examples/speculative/speculative.cpp +22 -23
  87. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +8 -12
  88. package/src/llama.cpp/examples/tokenize/tokenize.cpp +17 -5
  89. package/src/llama.cpp/examples/tts/tts.cpp +64 -23
  90. package/src/llama.cpp/ggml/CMakeLists.txt +5 -21
  91. package/src/llama.cpp/ggml/include/ggml-backend.h +2 -0
  92. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -0
  93. package/src/llama.cpp/ggml/include/ggml.h +36 -145
  94. package/src/llama.cpp/ggml/include/gguf.h +202 -0
  95. package/src/llama.cpp/ggml/src/CMakeLists.txt +6 -3
  96. package/src/llama.cpp/ggml/src/ggml-alloc.c +5 -0
  97. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +0 -1
  98. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +79 -49
  99. package/src/llama.cpp/ggml/src/ggml-backend.cpp +5 -2
  100. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +33 -23
  101. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +57 -72
  102. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +87 -2
  103. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +335 -66
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +10 -2
  105. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1090 -378
  106. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +2 -2
  107. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +1 -0
  108. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +3 -0
  109. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  110. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +3 -1
  111. package/src/llama.cpp/ggml/src/ggml-impl.h +11 -16
  112. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +16 -0
  113. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +6 -6
  114. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +154 -35
  115. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  116. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +9 -3
  117. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +18 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
  119. package/src/llama.cpp/ggml/src/ggml-sycl/concat.hpp +1 -2
  120. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +3 -2
  121. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +1 -2
  122. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +40 -95
  123. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +48 -48
  124. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +24 -24
  125. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -164
  126. package/src/llama.cpp/ggml/src/ggml-sycl/gla.cpp +105 -0
  127. package/src/llama.cpp/ggml/src/ggml-sycl/gla.hpp +8 -0
  128. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +3 -3
  129. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +1 -2
  130. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -2
  131. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +1 -2
  132. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +7 -5
  133. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +1 -2
  134. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +74 -4
  135. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +314 -116
  136. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -2
  137. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +9 -3
  138. package/src/llama.cpp/ggml/src/ggml.c +117 -1327
  139. package/src/llama.cpp/ggml/src/gguf.cpp +1329 -0
  140. package/src/llama.cpp/include/llama-cpp.h +6 -1
  141. package/src/llama.cpp/include/llama.h +138 -75
  142. package/src/llama.cpp/src/CMakeLists.txt +13 -1
  143. package/src/llama.cpp/src/llama-adapter.cpp +347 -0
  144. package/src/llama.cpp/src/llama-adapter.h +74 -0
  145. package/src/llama.cpp/src/llama-arch.cpp +1487 -0
  146. package/src/llama.cpp/src/llama-arch.h +400 -0
  147. package/src/llama.cpp/src/llama-batch.cpp +368 -0
  148. package/src/llama.cpp/src/llama-batch.h +88 -0
  149. package/src/llama.cpp/src/llama-chat.cpp +578 -0
  150. package/src/llama.cpp/src/llama-chat.h +52 -0
  151. package/src/llama.cpp/src/llama-context.cpp +1775 -0
  152. package/src/llama.cpp/src/llama-context.h +128 -0
  153. package/src/llama.cpp/src/llama-cparams.cpp +1 -0
  154. package/src/llama.cpp/src/llama-cparams.h +37 -0
  155. package/src/llama.cpp/src/llama-grammar.cpp +5 -4
  156. package/src/llama.cpp/src/llama-grammar.h +3 -1
  157. package/src/llama.cpp/src/llama-hparams.cpp +71 -0
  158. package/src/llama.cpp/src/llama-hparams.h +139 -0
  159. package/src/llama.cpp/src/llama-impl.cpp +167 -0
  160. package/src/llama.cpp/src/llama-impl.h +16 -136
  161. package/src/llama.cpp/src/llama-kv-cache.cpp +718 -0
  162. package/src/llama.cpp/src/llama-kv-cache.h +218 -0
  163. package/src/llama.cpp/src/llama-mmap.cpp +589 -0
  164. package/src/llama.cpp/src/llama-mmap.h +67 -0
  165. package/src/llama.cpp/src/llama-model-loader.cpp +1124 -0
  166. package/src/llama.cpp/src/llama-model-loader.h +167 -0
  167. package/src/llama.cpp/src/llama-model.cpp +3953 -0
  168. package/src/llama.cpp/src/llama-model.h +370 -0
  169. package/src/llama.cpp/src/llama-quant.cpp +934 -0
  170. package/src/llama.cpp/src/llama-quant.h +1 -0
  171. package/src/llama.cpp/src/llama-sampling.cpp +147 -32
  172. package/src/llama.cpp/src/llama-sampling.h +3 -19
  173. package/src/llama.cpp/src/llama-vocab.cpp +1832 -575
  174. package/src/llama.cpp/src/llama-vocab.h +97 -142
  175. package/src/llama.cpp/src/llama.cpp +7160 -20314
  176. package/src/llama.cpp/src/unicode.cpp +8 -3
  177. package/src/llama.cpp/tests/CMakeLists.txt +2 -0
  178. package/src/llama.cpp/tests/test-autorelease.cpp +3 -3
  179. package/src/llama.cpp/tests/test-backend-ops.cpp +370 -59
  180. package/src/llama.cpp/tests/test-chat-template.cpp +162 -125
  181. package/src/llama.cpp/tests/test-gguf.cpp +222 -187
  182. package/src/llama.cpp/tests/test-model-load-cancel.cpp +1 -1
  183. package/src/llama.cpp/tests/test-sampling.cpp +0 -1
  184. package/src/llama.cpp/tests/test-tokenizer-0.cpp +4 -4
  185. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +9 -7
  186. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +8 -6
@@ -986,7 +986,7 @@ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
986
986
  #define GGML_F16_STEP 32
987
987
  #define GGML_F16_EPR 4
988
988
 
989
- static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) {
989
+ static inline __m128 __sse_f16x4_load(const ggml_fp16_t * x) {
990
990
  float tmp[4];
991
991
 
992
992
  tmp[0] = GGML_FP16_TO_FP32(x[0]);
@@ -997,7 +997,7 @@ static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) {
997
997
  return _mm_loadu_ps(tmp);
998
998
  }
999
999
 
1000
- static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
1000
+ static inline void __sse_f16x4_store(ggml_fp16_t * x, __m128 y) {
1001
1001
  float arr[4];
1002
1002
 
1003
1003
  _mm_storeu_ps(arr, y);
@@ -3967,6 +3967,57 @@ static void ggml_compute_forward_dup_bytes(
3967
3967
  }
3968
3968
  }
3969
3969
 
3970
+ static void ggml_compute_forward_dup_q(
3971
+ const struct ggml_compute_params * params,
3972
+ struct ggml_tensor * dst) {
3973
+
3974
+ const struct ggml_tensor * src0 = dst->src[0];
3975
+ const struct ggml_tensor * src1 = dst->src[1];
3976
+
3977
+ GGML_TENSOR_BINARY_OP_LOCALS
3978
+
3979
+ const enum ggml_type type = src0->type;
3980
+ ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
3981
+
3982
+ size_t qk = ggml_blck_size(type);
3983
+ const int64_t nr = ggml_nelements(src1) / qk;
3984
+
3985
+ // destination must be contiguous in the first dimension
3986
+ GGML_ASSERT(nb10 == ggml_type_size(dst->type));
3987
+ // must either have first dimension large enough to hold a row, or fully contiguous
3988
+ GGML_ASSERT((ne10 % qk) == 0 || ggml_is_contiguous(dst));
3989
+
3990
+ const int ith = params->ith;
3991
+ const int nth = params->nth;
3992
+
3993
+ const int dr = (nr + nth - 1)/nth;
3994
+
3995
+ // row range for this thread
3996
+ const int ir0 = dr*ith;
3997
+ const int ir1 = MIN(ir0 + dr, nr);
3998
+
3999
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
4000
+
4001
+ uint32_t i = ir * qk;
4002
+
4003
+ const int64_t i03 = i/(ne00 * ne01 * ne02);
4004
+ const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
4005
+ const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
4006
+ const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
4007
+ const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
4008
+
4009
+ const int64_t i13 = i/(ne10 * ne11 * ne12);
4010
+ const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
4011
+ const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
4012
+ const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
4013
+ const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
4014
+
4015
+ dequantize_row_q(
4016
+ (const void *) ((char *) src0->data + x_offset),
4017
+ (float *) ((char *) dst->data + dst_offset), qk);
4018
+ }
4019
+ }
4020
+
3970
4021
  static void ggml_compute_forward_dup(
3971
4022
  const struct ggml_compute_params * params,
3972
4023
  struct ggml_tensor * dst) {
@@ -3993,6 +4044,10 @@ static void ggml_compute_forward_dup(
3993
4044
  } break;
3994
4045
  default:
3995
4046
  {
4047
+ if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
4048
+ ggml_compute_forward_dup_q(params, dst);
4049
+ break;
4050
+ }
3996
4051
  GGML_ABORT("fatal error");
3997
4052
  }
3998
4053
  }
@@ -6691,20 +6746,20 @@ static void ggml_compute_forward_silu_back_f32(
6691
6746
  const struct ggml_compute_params * params,
6692
6747
  struct ggml_tensor * dst) {
6693
6748
 
6694
- const struct ggml_tensor * src0 = dst->src[0];
6695
- const struct ggml_tensor * grad = dst->src[1];
6749
+ const struct ggml_tensor * grad = dst->src[0];
6750
+ const struct ggml_tensor * src1 = dst->src[1];
6696
6751
 
6697
6752
  assert(ggml_is_contiguous_1(grad));
6698
- assert(ggml_is_contiguous_1(src0));
6753
+ assert(ggml_is_contiguous_1(src1));
6699
6754
  assert(ggml_is_contiguous_1(dst));
6700
- assert(ggml_are_same_shape(src0, dst));
6701
- assert(ggml_are_same_shape(src0, grad));
6755
+ assert(ggml_are_same_shape(src1, dst));
6756
+ assert(ggml_are_same_shape(src1, grad));
6702
6757
 
6703
6758
  const int ith = params->ith;
6704
6759
  const int nth = params->nth;
6705
6760
 
6706
- const int nc = src0->ne[0];
6707
- const int nr = ggml_nrows(src0);
6761
+ const int nc = src1->ne[0];
6762
+ const int nr = ggml_nrows(src1);
6708
6763
 
6709
6764
  // rows per thread
6710
6765
  const int dr = (nr + nth - 1)/nth;
@@ -6716,7 +6771,7 @@ static void ggml_compute_forward_silu_back_f32(
6716
6771
  for (int i1 = ir0; i1 < ir1; i1++) {
6717
6772
  ggml_vec_silu_backward_f32(nc,
6718
6773
  (float *) ((char *) dst->data + i1*( dst->nb[1])),
6719
- (float *) ((char *) src0->data + i1*(src0->nb[1])),
6774
+ (float *) ((char *) src1->data + i1*(src1->nb[1])),
6720
6775
  (float *) ((char *) grad->data + i1*(grad->nb[1])));
6721
6776
 
6722
6777
  #ifndef NDEBUG
@@ -6895,7 +6950,7 @@ static void ggml_compute_forward_norm_f32(
6895
6950
  float eps;
6896
6951
  memcpy(&eps, dst->op_params, sizeof(float));
6897
6952
 
6898
- GGML_ASSERT(eps > 0.0f);
6953
+ GGML_ASSERT(eps >= 0.0f);
6899
6954
 
6900
6955
  // TODO: optimize
6901
6956
  for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -6966,7 +7021,7 @@ static void ggml_compute_forward_rms_norm_f32(
6966
7021
  float eps;
6967
7022
  memcpy(&eps, dst->op_params, sizeof(float));
6968
7023
 
6969
- GGML_ASSERT(eps > 0.0f);
7024
+ GGML_ASSERT(eps >= 0.0f);
6970
7025
 
6971
7026
  // TODO: optimize
6972
7027
  for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7018,12 +7073,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
7018
7073
  const struct ggml_compute_params * params,
7019
7074
  struct ggml_tensor * dst) {
7020
7075
 
7021
- const struct ggml_tensor * src0 = dst->src[0];
7022
- const struct ggml_tensor * src1 = dst->src[1];
7076
+ const struct ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
7077
+ const struct ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
7023
7078
 
7024
7079
  GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
7025
7080
 
7026
7081
  GGML_ASSERT(src0->nb[0] == sizeof(float));
7082
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
7027
7083
 
7028
7084
  const int ith = params->ith;
7029
7085
  const int nth = params->nth;
@@ -7042,8 +7098,8 @@ static void ggml_compute_forward_rms_norm_back_f32(
7042
7098
  const int64_t i12 = i02;
7043
7099
  const int64_t i13 = i03;
7044
7100
 
7045
- const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7046
- const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
7101
+ const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7102
+ const float * x = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
7047
7103
 
7048
7104
  ggml_float sum_xx = 0.0;
7049
7105
  ggml_float sum_xdz = 0.0;
@@ -7066,9 +7122,9 @@ static void ggml_compute_forward_rms_norm_back_f32(
7066
7122
  {
7067
7123
  // z = rms_norm(x)
7068
7124
  //
7069
- // rms_norm(src0) =
7125
+ // rms_norm(src1) =
7070
7126
  // scale(
7071
- // src0,
7127
+ // src1,
7072
7128
  // div(
7073
7129
  // 1,
7074
7130
  // sqrt(
@@ -7076,13 +7132,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
7076
7132
  // scale(
7077
7133
  // sum(
7078
7134
  // sqr(
7079
- // src0)),
7135
+ // src1)),
7080
7136
  // (1.0/N)),
7081
7137
  // eps))));
7082
7138
 
7083
7139
  // postorder:
7084
7140
  // ## op args grad
7085
- // 00 param src0 grad[#00]
7141
+ // 00 param src1 grad[#00]
7086
7142
  // 01 const 1
7087
7143
  // 02 sqr (#00) grad[#02]
7088
7144
  // 03 sum (#02) grad[#03]
@@ -7159,6 +7215,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
7159
7215
  // dx := scale(dx, rrms)
7160
7216
  float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
7161
7217
 
7218
+ // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
7162
7219
  ggml_vec_cpy_f32 (ne00, dx, x);
7163
7220
  // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
7164
7221
  ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
@@ -7419,14 +7476,14 @@ static void ggml_compute_forward_mul_mat(
7419
7476
  if (src1_cont) {
7420
7477
  for (int64_t i13 = 0; i13 < ne13; i13++)
7421
7478
  for (int64_t i12 = 0; i12 < ne12; i12++)
7422
- if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
7479
+ if (!llamafile_sgemm(params,
7480
+ ne01, ne11, ne00/ggml_blck_size(src0->type),
7423
7481
  (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
7424
7482
  nb01/ggml_type_size(src0->type),
7425
7483
  (const char *)src1->data + i12*nb12 + i13*nb13,
7426
7484
  nb11/ggml_type_size(src1->type),
7427
7485
  (char *)dst->data + i12*nb2 + i13*nb3,
7428
7486
  nb1/ggml_type_size(dst->type),
7429
- ith, nth,
7430
7487
  src0->type,
7431
7488
  src1->type,
7432
7489
  dst->type))
@@ -7471,14 +7528,14 @@ UseGgmlGemm1:;
7471
7528
 
7472
7529
  for (int64_t i13 = 0; i13 < ne13; i13++)
7473
7530
  for (int64_t i12 = 0; i12 < ne12; i12++)
7474
- if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
7531
+ if (!llamafile_sgemm(params,
7532
+ ne01, ne11, ne00/ggml_blck_size(src0->type),
7475
7533
  (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
7476
7534
  nb01/ggml_type_size(src0->type),
7477
7535
  (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
7478
7536
  row_size/ggml_type_size(vec_dot_type),
7479
7537
  (char *)dst->data + i12*nb2 + i13*nb3,
7480
7538
  nb1/ggml_type_size(dst->type),
7481
- ith, nth,
7482
7539
  src0->type,
7483
7540
  vec_dot_type,
7484
7541
  dst->type))
@@ -7750,12 +7807,13 @@ static void ggml_compute_forward_out_prod_f32(
7750
7807
  const int ith = params->ith;
7751
7808
  const int nth = params->nth;
7752
7809
 
7753
- GGML_ASSERT(ne0 == ne00);
7754
- GGML_ASSERT(ne1 == ne10);
7755
- GGML_ASSERT(ne2 == ne02);
7756
- GGML_ASSERT(ne02 == ne12);
7757
- GGML_ASSERT(ne3 == ne13);
7758
- GGML_ASSERT(ne03 == ne13);
7810
+ GGML_ASSERT(ne0 == ne00);
7811
+ GGML_ASSERT(ne1 == ne10);
7812
+ GGML_ASSERT(ne2 == ne12);
7813
+ GGML_ASSERT(ne3 == ne13);
7814
+
7815
+ GGML_ASSERT(ne2 % ne02 == 0);
7816
+ GGML_ASSERT(ne3 % ne03 == 0);
7759
7817
 
7760
7818
  // we don't support permuted src0 or src1
7761
7819
  GGML_ASSERT(nb00 == sizeof(float));
@@ -7797,6 +7855,10 @@ static void ggml_compute_forward_out_prod_f32(
7797
7855
  const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
7798
7856
  const int64_t blck_1 = 16;
7799
7857
 
7858
+ // dps == dst per src0, used for group query attention
7859
+ const int64_t dps2 = ne2 / ne02;
7860
+ const int64_t dps3 = ne3 / ne03;
7861
+
7800
7862
  for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
7801
7863
  const int64_t bir1 = MIN(bir + blck_1, ir1);
7802
7864
  for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
@@ -7807,8 +7869,8 @@ static void ggml_compute_forward_out_prod_f32(
7807
7869
  const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
7808
7870
  const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
7809
7871
 
7810
- const int64_t i02 = i2;
7811
- const int64_t i03 = i3;
7872
+ const int64_t i02 = i2 / dps2;
7873
+ const int64_t i03 = i3 / dps3;
7812
7874
 
7813
7875
  //const int64_t i10 = i1;
7814
7876
  const int64_t i12 = i2;
@@ -8906,9 +8968,9 @@ static void ggml_compute_forward_soft_max(
8906
8968
  }
8907
8969
 
8908
8970
 
8909
- // ggml_compute_forward_soft_max_back
8971
+ // ggml_compute_forward_soft_max_ext_back
8910
8972
 
8911
- static void ggml_compute_forward_soft_max_back_f32(
8973
+ static void ggml_compute_forward_soft_max_ext_back_f32(
8912
8974
  const struct ggml_compute_params * params,
8913
8975
  struct ggml_tensor * dst) {
8914
8976
 
@@ -8921,6 +8983,14 @@ static void ggml_compute_forward_soft_max_back_f32(
8921
8983
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
8922
8984
  GGML_ASSERT(ggml_are_same_shape(src1, dst));
8923
8985
 
8986
+ float scale = 1.0f;
8987
+ float max_bias = 0.0f;
8988
+
8989
+ memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
8990
+ memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
8991
+
8992
+ GGML_ASSERT(max_bias == 0.0f);
8993
+
8924
8994
  // TODO: handle transposed/permuted matrices
8925
8995
 
8926
8996
  const int ith = params->ith;
@@ -8969,10 +9039,11 @@ static void ggml_compute_forward_soft_max_back_f32(
8969
9039
 
8970
9040
  // linear runtime, no additional memory
8971
9041
  float dot_y_dy = 0;
8972
- ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
8973
- ggml_vec_cpy_f32 (nc, dx, dy);
8974
- ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
8975
- ggml_vec_mul_f32 (nc, dx, dx, y);
9042
+ ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
9043
+ ggml_vec_cpy_f32 (nc, dx, dy);
9044
+ ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
9045
+ ggml_vec_mul_f32 (nc, dx, dx, y);
9046
+ ggml_vec_scale_f32(nc, dx, scale);
8976
9047
 
8977
9048
  #ifndef NDEBUG
8978
9049
  for (int i = 0; i < nc; ++i) {
@@ -8983,7 +9054,7 @@ static void ggml_compute_forward_soft_max_back_f32(
8983
9054
  }
8984
9055
  }
8985
9056
 
8986
- static void ggml_compute_forward_soft_max_back(
9057
+ static void ggml_compute_forward_soft_max_ext_back(
8987
9058
  const struct ggml_compute_params * params,
8988
9059
  struct ggml_tensor * dst) {
8989
9060
 
@@ -8992,7 +9063,7 @@ static void ggml_compute_forward_soft_max_back(
8992
9063
  switch (src0->type) {
8993
9064
  case GGML_TYPE_F32:
8994
9065
  {
8995
- ggml_compute_forward_soft_max_back_f32(params, dst);
9066
+ ggml_compute_forward_soft_max_ext_back_f32(params, dst);
8996
9067
  } break;
8997
9068
  default:
8998
9069
  {
@@ -9985,9 +10056,10 @@ static void ggml_compute_forward_im2col_back_f32(
9985
10056
  const struct ggml_compute_params * params,
9986
10057
  struct ggml_tensor * dst) {
9987
10058
 
9988
- const struct ggml_tensor * src0 = dst->src[0];
9989
- const struct ggml_tensor * src1 = dst->src[1];
10059
+ const struct ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
10060
+ const struct ggml_tensor * src1 = dst->src[1]; // convolution kernel
9990
10061
 
10062
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
9991
10063
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
9992
10064
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
9993
10065
 
@@ -10009,11 +10081,11 @@ static void ggml_compute_forward_im2col_back_f32(
10009
10081
  const int64_t IH = is_2D ? ne1 : 1;
10010
10082
  const int64_t IW = ne0;
10011
10083
 
10012
- const int64_t KH = is_2D ? ne01 : 1;
10013
- const int64_t KW = ne00;
10084
+ const int64_t KH = is_2D ? ne11 : 1;
10085
+ const int64_t KW = ne10;
10014
10086
 
10015
- const int64_t OH = is_2D ? ne12 : 1;
10016
- const int64_t OW = ne11;
10087
+ const int64_t OH = is_2D ? ne02 : 1;
10088
+ const int64_t OW = ne01;
10017
10089
 
10018
10090
  int ofs0 = is_2D ? nb3 : nb2;
10019
10091
  int ofs1 = is_2D ? nb2 : nb1;
@@ -10059,9 +10131,9 @@ static void ggml_compute_forward_im2col_back_f32(
10059
10131
  continue;
10060
10132
  }
10061
10133
 
10062
- const float * const src_data = (const float *) src1->data
10134
+ const float * const grad_in = (const float *) src0->data
10063
10135
  + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
10064
- grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
10136
+ grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
10065
10137
  }
10066
10138
  }
10067
10139
  float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
@@ -11803,9 +11875,9 @@ static void ggml_compute_forward_add_rel_pos(
11803
11875
  static void ggml_compute_forward_rwkv_wkv6_f32(
11804
11876
  const struct ggml_compute_params * params,
11805
11877
  struct ggml_tensor * dst) {
11806
- const int64_t T = dst->src[1]->ne[3];
11878
+ const int64_t T = dst->src[1]->ne[2];
11807
11879
  const int64_t C = dst->ne[0];
11808
- const int64_t HEADS = dst->src[1]->ne[2];
11880
+ const int64_t HEADS = dst->src[1]->ne[1];
11809
11881
  const int64_t n_seqs = dst->src[5]->ne[1];
11810
11882
  const int64_t head_size = C / HEADS;
11811
11883
 
@@ -12000,6 +12072,197 @@ static void ggml_compute_forward_rwkv_wkv6(
12000
12072
  }
12001
12073
  }
12002
12074
 
12075
+ // ggml_compute_forward_gla
12076
+
12077
+ static void ggml_compute_forward_gla_f32(
12078
+ const struct ggml_compute_params * params,
12079
+ struct ggml_tensor * dst) {
12080
+ const int64_t T = dst->src[1]->ne[2];
12081
+ const int64_t C = dst->ne[0];
12082
+ const int64_t HEADS = dst->src[1]->ne[1];
12083
+ const int64_t n_seqs = dst->src[4]->ne[1];
12084
+ const int64_t head_size = C / HEADS;
12085
+ const float scale = ggml_get_op_params_f32(dst, 0);
12086
+
12087
+ float * dst_data = (float *) dst->data;
12088
+ float * state = ((float *) dst->data) + C * T;
12089
+
12090
+ const int ith = params->ith;
12091
+ const int nth = params->nth;
12092
+
12093
+ if (ith >= HEADS) {
12094
+ return;
12095
+ }
12096
+
12097
+ const int h_start = (HEADS * ith) / nth;
12098
+ const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
12099
+ (HEADS * (ith + 1)) / nth : HEADS;
12100
+
12101
+ float * k = (float *) dst->src[0]->data;
12102
+ float * v = (float *) dst->src[1]->data;
12103
+ float * q = (float *) dst->src[2]->data;
12104
+ float * g = (float *) dst->src[3]->data;
12105
+
12106
+ size_t t_stride = HEADS * head_size; // Same to C
12107
+
12108
+ size_t h_stride = C / HEADS;
12109
+ GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
12110
+ size_t h_stride_2d = head_size * head_size;
12111
+
12112
+ if (ith == 0) {
12113
+ memset(dst_data, 0, T * C * sizeof(float));
12114
+ }
12115
+ ggml_barrier(params->threadpool);
12116
+
12117
+
12118
+ #if defined(__AVX__) && !defined(__AVX512F__)
12119
+ #define GGML_F32X GGML_F32x8
12120
+ #define GGML_F32X_SET1 GGML_F32x8_SET1
12121
+ #define GGML_F32X_LOAD GGML_F32x8_LOAD
12122
+ #define GGML_F32X_STORE GGML_F32x8_STORE
12123
+ #define GGML_F32X_MUL GGML_F32x8_MUL
12124
+ #define GGML_F32X_FMA GGML_F32x8_FMA
12125
+ #define GLA_VECTOR_SIZE 8
12126
+ #elif defined(__AVX512F__)
12127
+ #define GGML_F32X GGML_F32x16
12128
+ #define GGML_F32X_SET1 GGML_F32x16_SET1
12129
+ #define GGML_F32X_LOAD GGML_F32x16_LOAD
12130
+ #define GGML_F32X_STORE GGML_F32x16_STORE
12131
+ #define GGML_F32X_MUL GGML_F32x16_MUL
12132
+ #define GGML_F32X_FMA GGML_F32x16_FMA
12133
+ #define GLA_VECTOR_SIZE 16
12134
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
12135
+ #define GGML_F32X GGML_F32x4
12136
+ #define GGML_F32X_SET1 GGML_F32x4_SET1
12137
+ #define GGML_F32X_LOAD GGML_F32x4_LOAD
12138
+ #define GGML_F32X_STORE GGML_F32x4_STORE
12139
+ #define GGML_F32X_MUL GGML_F32x4_MUL
12140
+ #define GGML_F32X_FMA GGML_F32x4_FMA
12141
+ #define GLA_VECTOR_SIZE 4
12142
+ #endif
12143
+
12144
+ #ifdef GLA_VECTOR_SIZE
12145
+ const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
12146
+
12147
+ for (int64_t t = 0; t < T; t++) {
12148
+ size_t t_offset = t * t_stride;
12149
+ size_t state_offset = head_size * C * (t / (T / n_seqs));
12150
+ float * state_cur = state + state_offset;
12151
+ float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
12152
+
12153
+ for (int64_t h = h_start; h < h_end; h++) {
12154
+ size_t h_offset = h * h_stride;
12155
+ size_t t_h_offset = t_offset + h_offset;
12156
+ size_t h_2d_offset = h * h_stride_2d;
12157
+
12158
+ for (int64_t i = 0; i < head_size; i++) {
12159
+ size_t t_h_i_offset = t_h_offset + i;
12160
+ size_t h_2d_i_offset = h_2d_offset + i * h_stride;
12161
+
12162
+ float k_val = k[t_h_i_offset];
12163
+ float q_val = q[t_h_i_offset] * scale;
12164
+ float g_val = g[t_h_i_offset];
12165
+
12166
+ // Broadcast scalar values to vectors
12167
+ GGML_F32X k_vec = GGML_F32X_SET1(k_val);
12168
+ GGML_F32X q_vec = GGML_F32X_SET1(q_val);
12169
+ GGML_F32X g_vec = GGML_F32X_SET1(g_val);
12170
+
12171
+ for (int64_t j = 0; j < vec_count; j++) {
12172
+ size_t base_j = j * GLA_VECTOR_SIZE;
12173
+ size_t t_h_j_offset = t_h_offset + base_j;
12174
+ size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
12175
+
12176
+ // Load x elements at once
12177
+ GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
12178
+ GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
12179
+ GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
12180
+
12181
+ // Compute kv = v * k
12182
+ GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
12183
+
12184
+ // Compute temp = prev_state * g + kv
12185
+ GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
12186
+
12187
+ // Update dst: dst += temp * q
12188
+ dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
12189
+ GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
12190
+
12191
+ // Update state
12192
+ GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
12193
+ }
12194
+
12195
+ // Handle remaining elements, this will not be used.
12196
+ for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
12197
+ size_t t_h_j_offset = t_h_offset + j;
12198
+ size_t h_2d_i_j_offset = h_2d_i_offset + j;
12199
+ float v_val = v[t_h_j_offset];
12200
+ float kv_val = v_val * k_val;
12201
+ float prev_state_val = state_prev[h_2d_i_j_offset];
12202
+ float temp_val = kv_val + prev_state_val * g_val;
12203
+ dst_data[t_h_j_offset] += temp_val * q_val;
12204
+ state_cur[h_2d_i_j_offset] = temp_val;
12205
+ }
12206
+ }
12207
+ }
12208
+ }
12209
+
12210
+ #else
12211
+ for (int64_t t = 0; t < T; t++) {
12212
+ size_t t_offset = t * t_stride;
12213
+ size_t state_offset = head_size * C * (t / (T / n_seqs));
12214
+ float * state_cur = state + state_offset;
12215
+ float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
12216
+
12217
+ for (int64_t h = h_start; h < h_end; h++) {
12218
+ size_t h_offset = h * h_stride;
12219
+ size_t t_h_offset = t_offset + h_offset;
12220
+ size_t h_2d_offset = h * h_stride_2d;
12221
+
12222
+ for (int64_t i = 0; i < head_size; i++) {
12223
+ size_t t_h_i_offset = t_h_offset + i;
12224
+ size_t h_2d_i_offset = h_2d_offset + i * h_stride;
12225
+
12226
+ float k_val = k[t_h_i_offset];
12227
+ float q_val = q[t_h_i_offset] * scale;
12228
+ float g_val = g[t_h_i_offset];
12229
+
12230
+ for (int64_t j = 0; j < head_size; j++) {
12231
+ size_t t_h_j_offset = t_h_offset + j;
12232
+ size_t h_2d_i_j_offset = h_2d_i_offset + j;
12233
+
12234
+ float v_val = v[t_h_j_offset];
12235
+ float kv_val = v_val * k_val;
12236
+ float prev_state_val = state_prev[h_2d_i_j_offset];
12237
+ float temp_val = prev_state_val * g_val + kv_val;
12238
+ dst_data[t_h_j_offset] += temp_val * q_val;
12239
+ state_cur[h_2d_i_j_offset] = temp_val;
12240
+ }
12241
+ }
12242
+ }
12243
+ }
12244
+ #endif
12245
+ }
12246
+
12247
+
12248
+ static void ggml_compute_forward_gla(
12249
+ const struct ggml_compute_params * params,
12250
+ struct ggml_tensor * dst) {
12251
+
12252
+ const struct ggml_tensor * src0 = dst->src[0];
12253
+
12254
+ switch (src0->type) {
12255
+ case GGML_TYPE_F32:
12256
+ {
12257
+ ggml_compute_forward_gla_f32(params, dst);
12258
+ } break;
12259
+ default:
12260
+ {
12261
+ GGML_ABORT("fatal error");
12262
+ }
12263
+ }
12264
+ }
12265
+
12003
12266
  // ggml_compute_forward_map_unary
12004
12267
 
12005
12268
  static void ggml_compute_forward_map_unary_f32(
@@ -12293,22 +12556,22 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
12293
12556
  const struct ggml_compute_params * params,
12294
12557
  struct ggml_tensor * dst) {
12295
12558
 
12296
- const struct ggml_tensor * src0 = dst->src[0];
12297
- const struct ggml_tensor * src1 = dst->src[1];
12298
- const struct ggml_tensor * opt0 = dst->src[2];
12559
+ const struct ggml_tensor * grad = dst->src[0]; // gradient of forward pass output
12560
+ const struct ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
12561
+ const struct ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
12299
12562
 
12300
12563
  GGML_ASSERT(ggml_is_contiguous(dst));
12301
- GGML_ASSERT(ggml_is_contiguous(src0));
12302
- GGML_ASSERT(ggml_is_contiguous(src1));
12303
- GGML_ASSERT(ggml_is_contiguous(opt0));
12304
- GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
12564
+ GGML_ASSERT(ggml_is_contiguous(src0f));
12565
+ GGML_ASSERT(ggml_is_contiguous(src1f));
12566
+ GGML_ASSERT(ggml_is_contiguous(grad));
12567
+ GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
12305
12568
 
12306
12569
  const int64_t ith = params->ith;
12307
12570
  const int64_t nth = params->nth;
12308
12571
 
12309
12572
  // TODO: handle transposed/permuted matrices
12310
- const int64_t nc = src0->ne[0];
12311
- const int64_t nr = ggml_nrows(src0);
12573
+ const int64_t nc = src0f->ne[0];
12574
+ const int64_t nr = ggml_nrows(src0f);
12312
12575
 
12313
12576
  // rows per thread
12314
12577
  const int64_t dr = (nr + nth - 1)/nth;
@@ -12317,12 +12580,12 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
12317
12580
  const int64_t ir0 = dr*ith;
12318
12581
  const int64_t ir1 = MIN(ir0 + dr, nr);
12319
12582
 
12320
- const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
12583
+ const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
12321
12584
 
12322
12585
  for (int64_t i1 = ir0; i1 < ir1; i1++) {
12323
- float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
12324
- float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
12325
- float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
12586
+ float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
12587
+ const float * s0 = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
12588
+ const float * s1 = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
12326
12589
 
12327
12590
  #ifndef NDEBUG
12328
12591
  for (int64_t i = 0; i < nc; ++i) {
@@ -12335,11 +12598,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
12335
12598
  // soft_max
12336
12599
  float max = -INFINITY;
12337
12600
  ggml_vec_max_f32(nc, &max, s0);
12338
- ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
12601
+ const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
12339
12602
  assert(sum > 0.0);
12340
12603
  ggml_vec_scale_f32(nc, ds0, 1.0/sum);
12341
12604
 
12342
- // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
12605
+ // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
12343
12606
  ggml_vec_sub_f32(nc, ds0, ds0, s1);
12344
12607
  ggml_vec_scale_f32(nc, ds0, d_by_nr);
12345
12608
 
@@ -12636,7 +12899,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
12636
12899
  } break;
12637
12900
  case GGML_OP_SOFT_MAX_BACK:
12638
12901
  {
12639
- ggml_compute_forward_soft_max_back(params, tensor);
12902
+ ggml_compute_forward_soft_max_ext_back(params, tensor);
12640
12903
  } break;
12641
12904
  case GGML_OP_ROPE:
12642
12905
  {
@@ -12749,6 +13012,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
12749
13012
  {
12750
13013
  ggml_compute_forward_rwkv_wkv6(params, tensor);
12751
13014
  } break;
13015
+ case GGML_OP_GATED_LINEAR_ATTN:
13016
+ {
13017
+ ggml_compute_forward_gla(params, tensor);
13018
+ } break;
12752
13019
  case GGML_OP_MAP_UNARY:
12753
13020
  {
12754
13021
  ggml_unary_op_f32_t fun;
@@ -13047,6 +13314,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
13047
13314
  case GGML_OP_WIN_UNPART:
13048
13315
  case GGML_OP_GET_REL_POS:
13049
13316
  case GGML_OP_RWKV_WKV6:
13317
+ case GGML_OP_GATED_LINEAR_ATTN:
13050
13318
  case GGML_OP_MAP_UNARY:
13051
13319
  case GGML_OP_MAP_BINARY:
13052
13320
  case GGML_OP_MAP_CUSTOM1_F32:
@@ -13472,6 +13740,7 @@ struct ggml_cplan ggml_graph_plan(
13472
13740
  } break;
13473
13741
  case GGML_OP_SOFT_MAX:
13474
13742
  case GGML_OP_ROPE:
13743
+ case GGML_OP_ROPE_BACK:
13475
13744
  {
13476
13745
  cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
13477
13746
  } break;