cui-llama.rn 1.6.1 → 1.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. package/android/src/main/CMakeLists.txt +6 -0
  2. package/android/src/main/java/com/rnllama/LlamaContext.java +51 -14
  3. package/android/src/main/java/com/rnllama/RNLlama.java +158 -6
  4. package/android/src/main/jni.cpp +153 -14
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  13. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
  14. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
  15. package/cpp/chat.cpp +128 -106
  16. package/cpp/chat.h +2 -0
  17. package/cpp/common.cpp +38 -76
  18. package/cpp/common.h +23 -19
  19. package/cpp/ggml-backend.cpp +9 -5
  20. package/cpp/ggml-backend.h +4 -4
  21. package/cpp/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
  22. package/cpp/ggml-cpu/ggml-cpu-quants.c +306 -6
  23. package/cpp/ggml-cpu/ggml-cpu.c +5 -13
  24. package/cpp/ggml-cpu/ggml-cpu.cpp +29 -16
  25. package/cpp/ggml-cpu/ops.cpp +107 -13
  26. package/cpp/ggml-cpu/vec.cpp +0 -6
  27. package/cpp/ggml-cpu/vec.h +16 -0
  28. package/cpp/ggml-llama-sim.metallib +0 -0
  29. package/cpp/ggml-llama.metallib +0 -0
  30. package/cpp/ggml-metal-impl.h +36 -11
  31. package/cpp/ggml-metal.m +321 -132
  32. package/cpp/ggml-opt.cpp +373 -190
  33. package/cpp/ggml-opt.h +49 -28
  34. package/cpp/ggml-quants.c +0 -6
  35. package/cpp/ggml.c +93 -38
  36. package/cpp/ggml.h +21 -7
  37. package/cpp/gguf.cpp +33 -33
  38. package/cpp/llama-adapter.cpp +6 -0
  39. package/cpp/llama-arch.cpp +3 -0
  40. package/cpp/llama-batch.cpp +3 -1
  41. package/cpp/llama-chat.cpp +8 -6
  42. package/cpp/llama-chat.h +1 -0
  43. package/cpp/llama-context.cpp +349 -135
  44. package/cpp/llama-context.h +30 -3
  45. package/cpp/llama-cparams.h +1 -0
  46. package/cpp/llama-graph.cpp +150 -234
  47. package/cpp/llama-graph.h +52 -7
  48. package/cpp/llama-hparams.cpp +17 -1
  49. package/cpp/llama-hparams.h +34 -5
  50. package/cpp/llama-kv-cache.cpp +662 -321
  51. package/cpp/llama-kv-cache.h +203 -93
  52. package/cpp/llama-memory.h +3 -2
  53. package/cpp/llama-model-loader.cpp +24 -15
  54. package/cpp/llama-model-saver.cpp +281 -0
  55. package/cpp/llama-model-saver.h +37 -0
  56. package/cpp/llama-model.cpp +536 -132
  57. package/cpp/llama-model.h +7 -1
  58. package/cpp/llama-sampling.cpp +18 -6
  59. package/cpp/llama-vocab.cpp +46 -8
  60. package/cpp/llama-vocab.h +6 -0
  61. package/cpp/llama.cpp +14 -0
  62. package/cpp/llama.h +72 -131
  63. package/cpp/minja/chat-template.hpp +9 -5
  64. package/cpp/minja/minja.hpp +69 -36
  65. package/cpp/rn-llama.cpp +611 -47
  66. package/cpp/rn-llama.h +33 -3
  67. package/cpp/sampling.cpp +57 -50
  68. package/cpp/tools/mtmd/clip-impl.h +462 -0
  69. package/cpp/tools/mtmd/clip.cpp +4024 -0
  70. package/cpp/tools/mtmd/clip.h +101 -0
  71. package/cpp/tools/mtmd/miniaudio.h +93468 -0
  72. package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
  73. package/cpp/tools/mtmd/mtmd-audio.h +62 -0
  74. package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
  75. package/cpp/tools/mtmd/mtmd.cpp +942 -0
  76. package/cpp/tools/mtmd/mtmd.h +362 -0
  77. package/cpp/tools/mtmd/stb_image.h +7988 -0
  78. package/ios/CMakeLists.txt +7 -0
  79. package/ios/RNLlama.mm +77 -3
  80. package/ios/RNLlamaContext.h +5 -1
  81. package/ios/RNLlamaContext.mm +105 -10
  82. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
  83. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +23 -19
  84. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  85. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  86. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  87. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +21 -7
  88. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  89. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +30 -3
  90. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  91. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
  92. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
  93. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  94. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
  95. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  96. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +7 -1
  97. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  98. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +72 -131
  99. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  100. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  101. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
  102. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
  103. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  104. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  105. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  106. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
  107. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  108. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  109. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  110. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
  111. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  112. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
  113. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  114. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
  115. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
  116. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  117. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
  118. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  119. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
  120. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  121. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
  122. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  123. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  124. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
  125. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  126. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  127. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  128. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  129. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
  130. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +23 -19
  131. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  132. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  133. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  134. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +21 -7
  135. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  136. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +30 -3
  137. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  138. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
  139. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
  140. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  141. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
  142. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  143. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +7 -1
  144. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  145. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +72 -131
  146. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  147. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  148. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
  149. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
  150. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  151. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  152. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  153. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
  154. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  155. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  156. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  157. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
  158. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  159. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
  160. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  161. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
  162. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
  163. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
  164. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
  165. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  166. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
  167. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  168. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
  169. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  170. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  171. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
  172. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  173. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  174. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  175. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  176. package/jest/mock.js +33 -7
  177. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  178. package/lib/commonjs/index.js +153 -21
  179. package/lib/commonjs/index.js.map +1 -1
  180. package/lib/module/NativeRNLlama.js.map +1 -1
  181. package/lib/module/index.js +152 -20
  182. package/lib/module/index.js.map +1 -1
  183. package/lib/typescript/NativeRNLlama.d.ts +50 -4
  184. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  185. package/lib/typescript/index.d.ts +72 -6
  186. package/lib/typescript/index.d.ts.map +1 -1
  187. package/package.json +1 -1
  188. package/src/NativeRNLlama.ts +67 -4
  189. package/src/index.ts +212 -38
  190. package/lib/commonjs/chat.js +0 -37
  191. package/lib/commonjs/chat.js.map +0 -1
  192. package/lib/module/chat.js +0 -33
  193. package/lib/module/chat.js.map +0 -1
  194. package/lib/typescript/chat.d.ts +0 -10
  195. package/lib/typescript/chat.d.ts.map +0 -1
  196. package/src/chat.ts +0 -44
package/cpp/llama-context.h

@@ -7,6 +7,7 @@
 #include "llama-adapter.h"
 
 #include "ggml-cpp.h"
+#include "ggml-opt.h"
 
 #include <map>
 #include <vector>
@@ -133,6 +134,32 @@ struct llama_context {
     llama_perf_context_data perf_get_data() const;
     void perf_reset();
 
+    //
+    // training
+    //
+
+    void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
+
+    void opt_epoch(
+            lm_ggml_opt_dataset_t      dataset,
+            lm_ggml_opt_result_t       result_train,
+            lm_ggml_opt_result_t       result_eval,
+            int64_t                    idata_split,
+            lm_ggml_opt_epoch_callback callback_train,
+            lm_ggml_opt_epoch_callback callback_eval);
+
+    void opt_epoch_iter(
+            lm_ggml_opt_dataset_t            dataset,
+            lm_ggml_opt_result_t             result,
+            const std::vector<llama_token> & tokens,
+            const std::vector<llama_token> & labels_sparse,
+            llama_batch                    & batch,
+            lm_ggml_opt_epoch_callback       callback,
+            bool                             train,
+            int64_t                          idata_in_loop,
+            int64_t                          ndata_in_loop,
+            int64_t                          t_loop_start);
+
 private:
     //
     // output
@@ -187,9 +214,6 @@ private:
 
     std::unique_ptr<llama_memory_i> memory;
 
-    // TODO: remove
-    bool logits_all = false;
-
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t  logits_size = 0; // capacity (of floats) for logits
     float * logits      = nullptr;
@@ -215,6 +239,9 @@ private:
 
     lm_ggml_context_ptr ctx_compute;
 
+    // training
+    lm_ggml_opt_context_t opt_ctx = nullptr;
+
     lm_ggml_threadpool_t threadpool       = nullptr;
     lm_ggml_threadpool_t threadpool_batch = nullptr;
 
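The opt_* members above wire ggml's optimizer into llama_context and mirror the finetuning hooks added upstream in llama.cpp around this release. As a rough sketch only, a caller might drive one epoch through the public wrappers roughly like this; the wrapper names (llama_opt_init, llama_opt_epoch) and the llama_opt_params fields are assumptions based on the upstream API, since the corresponding llama.h hunk is not shown in this excerpt:

    // Hedged sketch: one finetuning epoch via the new training hooks.
    // llama_opt_init / llama_opt_epoch and lm_ggml_opt_result_* are assumed to
    // follow the upstream llama.cpp / ggml-opt API; adjust names as needed.
    #include "llama.h"
    #include "ggml-opt.h"

    static void run_one_epoch(llama_context * ctx, llama_model * model,
                              lm_ggml_opt_dataset_t dataset, int64_t idata_split) {
        struct llama_opt_params opt_params = {};   // real code must fill in ctx size, param filter, optimizer params
        llama_opt_init(ctx, model, opt_params);    // allocates the context's opt_ctx

        lm_ggml_opt_result_t result_train = lm_ggml_opt_result_init();
        lm_ggml_opt_result_t result_eval  = lm_ggml_opt_result_init();

        // batches [0, idata_split) are used for training, the remainder for evaluation
        llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
                        /*callback_train=*/nullptr, /*callback_eval=*/nullptr);

        lm_ggml_opt_result_free(result_train);
        lm_ggml_opt_result_free(result_eval);
    }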
package/cpp/llama-cparams.h

@@ -30,6 +30,7 @@ struct llama_cparams {
     bool flash_attn;
     bool no_perf;
     bool warmup;
+    bool op_offload;
 
     enum llama_pooling_type pooling_type;
 
package/cpp/llama-graph.cpp

@@ -9,33 +9,6 @@
 #include <cmath>
 #include <cstring>
 
-static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-    // TODO move to hparams if a T5 variant appears that uses a different value
-    const int64_t max_distance = 128;
-
-    if (bidirectional) {
-        n_buckets >>= 1;
-    }
-
-    const int64_t max_exact = n_buckets >> 1;
-
-    int32_t relative_position = x - y;
-    int32_t relative_bucket = 0;
-
-    if (bidirectional) {
-        relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
-    } else {
-        relative_position = -std::min<int32_t>(relative_position, 0);
-    }
-
-    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
-    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-
-    return relative_bucket;
-}
-
 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -110,22 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
     if (pos_bucket) {
-        const int64_t n_tokens = ubatch->n_tokens;
-
-        LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(pos_bucket->buffer));
-        LM_GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
-
-        int32_t * data = (int32_t *) pos_bucket->data;
-
-        const int64_t n_kv = kv_self->n;
-
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_kv; ++i) {
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
-                }
-            }
-        }
+        kv_self->set_input_pos_bucket(pos_bucket, ubatch);
     }
 }
 
@@ -403,99 +361,18 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask || self_kq_mask_swa) {
-        const int64_t n_kv         = kv_self->n;
-        const int64_t n_tokens     = ubatch->n_tokens;
-        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-        const int64_t n_seqs       = ubatch->n_seqs;
-
-        float * data     = nullptr;
-        float * data_swa = nullptr;
-
-        if (self_kq_mask) {
-            LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(self_kq_mask->buffer));
-            data = (float *) self_kq_mask->data;
-        }
-
-        if (self_kq_mask_swa) {
-            LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
-            data_swa = (float *) self_kq_mask_swa->data;
-        }
-
-        // Use only the previous KV cells of the correct sequence for each token of the ubatch.
-        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-        // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
-        //   Causal mask:
-        //      xxx-------
-        //      xxxx------
-        //      xxxxx-----
-        //   Non-causal mask:
-        //      xxxxx-----
-        //      xxxxx-----
-        //      xxxxx-----
-        // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-        for (int h = 0; h < 1; ++h) {
-            for (int s = 0; s < n_seqs; ++s) {
-                const llama_seq_id seq_id = ubatch->seq_id[s][0];
-
-                for (int j = 0; j < n_seq_tokens; ++j) {
-                    const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
-                    for (int i = 0; i < n_kv; ++i) {
-                        float f;
-                        // mask the token if:
-                        if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
-                            || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
-                        ) {
-                            f = -INFINITY;
-                        } else {
-                            if (hparams.use_alibi) {
-                                f = -std::abs(kv_self->cells[i].pos - pos);
-                            } else {
-                                f = 0.0f;
-                            }
-                        }
-
-                        if (data) {
-                            data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                        }
-
-                        // may need to cut off old tokens for sliding window
-                        // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
-                        if (data_swa) {
-                            if (hparams.n_attn_chunk) {
-                                llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
-                                if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
-                                    f = -INFINITY;
-                                }
-                            } else {
-                                if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
-                                    f = -INFINITY;
-                                }
-                            }
-                            data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                        }
-                    }
-                }
-            }
+    if (self_kq_mask) {
+        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
+}
 
-            // mask padded tokens
-            if (data) {
-                for (int i = n_tokens; i < LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                    }
-                }
-            }
+void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask) {
+        kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
 
-            // mask padded tokens
-            if (data_swa) {
-                for (int i = n_tokens; i < LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                    }
-                }
-            }
-        }
+    if (self_kq_mask_swa) {
+        kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }
 
@@ -545,7 +422,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     n_layer          (hparams.n_layer),
     n_rot            (hparams.n_rot),
     n_ctx            (cparams.n_ctx),
-    n_ctx_per_seq    (cparams.n_ctx / cparams.n_seq_max),
     n_head           (hparams.n_head()),
     n_head_kv        (hparams.n_head_kv()),
    n_embd_head_k    (hparams.n_embd_head_k),
@@ -782,7 +658,7 @@ lm_ggml_tensor * llm_graph_context::build_ffn(
            } break;
    }
 
-    if (type_gate == LLM_FFN_PAR) {
+    if (gate && type_gate == LLM_FFN_PAR) {
        cur = lm_ggml_mul(ctx0, cur, tmp);
        cb(cur, "ffn_gate_par", il);
    }
@@ -971,6 +847,7 @@ lm_ggml_tensor * llm_graph_context::build_inp_embd(lm_ggml_tensor * tok_embd) co
         inp->tokens = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, ubatch.n_tokens);
         //cb(inp->tokens, "inp_tokens", -1);
         lm_ggml_set_input(inp->tokens);
+        res->t_tokens = inp->tokens;
 
         cur = lm_ggml_get_rows(ctx0, tok_embd, inp->tokens);
 
@@ -1152,7 +1029,7 @@ lm_ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
 
-    const auto n_kv = kv_self->n;
+    const auto n_kv = kv_self->get_n();
 
     auto & cur = inp->pos_bucket;
 
@@ -1187,16 +1064,12 @@ lm_ggml_tensor * llm_graph_context::build_attn_mha(
         lm_ggml_tensor * kq_b,
         lm_ggml_tensor * kq_mask,
         lm_ggml_tensor * v_mla,
-        bool v_trans,
         float kq_scale) const {
-    //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+    const bool v_trans = v->nb[1] > v->nb[2];
 
-    //const int64_t n_head    = hparams.n_head(il);
-    //const int64_t n_head_kv = hparams.n_head_kv(il);
-
-    //const auto & n_embd_head_k = hparams.n_embd_head_k;
-    //const auto & n_embd_head_v = hparams.n_embd_head_v;
+    q = lm_ggml_permute(ctx0, q, 0, 2, 1, 3);
+    k = lm_ggml_permute(ctx0, k, 0, 2, 1, 3);
+    v = lm_ggml_permute(ctx0, v, 0, 2, 1, 3);
 
     const auto n_tokens = q->ne[1];
     const auto n_head   = q->ne[2];
@@ -1227,8 +1100,19 @@ lm_ggml_tensor * llm_graph_context::build_attn_mha(
        lm_ggml_flash_attn_ext_set_prec(cur, LM_GGML_PREC_F32);
 
        if (v_mla) {
+#if 0
+            // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
+            // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient.
            cur = lm_ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
            cur = lm_ggml_mul_mat(ctx0, v_mla, cur);
+#else
+            // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
+            // The permutations are noops and only change how the tensor data is interpreted.
+            cur = lm_ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = lm_ggml_mul_mat(ctx0, v_mla, cur);
+            cur = lm_ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = lm_ggml_cont(ctx0, cur); // Needed because lm_ggml_reshape_2d expects contiguous inputs.
+#endif
        }
 
        cur = lm_ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
@@ -1324,17 +1208,11 @@ lm_ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask();
 
-    lm_ggml_tensor * q = lm_ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    lm_ggml_tensor * k = lm_ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    lm_ggml_tensor * v = lm_ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
+    lm_ggml_tensor * q = q_cur;
+    lm_ggml_tensor * k = k_cur;
+    lm_ggml_tensor * v = v_cur;
 
+    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1357,22 +1235,16 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
-    const auto n_kv = kv_self->n;
-
-    inp->self_kq_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
-    lm_ggml_set_input(inp->self_kq_mask);
-
-    inp->self_kq_mask_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask, LM_GGML_TYPE_F16) : inp->self_kq_mask;
+    {
+        LM_GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-    if (hparams.n_swa_pattern > 1) {
-        LM_GGML_ASSERT(hparams.n_swa > 0);
+        const auto n_kv = kv_self->get_n();
 
-        inp->self_kq_mask_swa = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
-        lm_ggml_set_input(inp->self_kq_mask_swa);
+        inp->self_kq_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        lm_ggml_set_input(inp->self_kq_mask);
 
-        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask_swa, LM_GGML_TYPE_F16) : inp->self_kq_mask_swa;
+        inp->self_kq_mask_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask, LM_GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
     return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
@@ -1397,85 +1269,108 @@ lm_ggml_tensor * llm_graph_context::build_attn(
     lm_ggml_build_forward_expand(gf, v_cur);
 
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
-    const auto & n_ctx = cparams.n_ctx;
 
-    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+    // store to KV cache
+    {
+        lm_ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
+        lm_ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+    }
 
-    const auto n_tokens = q_cur->ne[2];
+    const auto & kq_mask = inp->get_kq_mask();
 
-    const bool v_trans = !cparams.flash_attn;
+    lm_ggml_tensor * q = q_cur;
+    lm_ggml_tensor * k = kv_self->get_k(ctx0, il);
+    lm_ggml_tensor * v = kv_self->get_v(ctx0, il);
 
-    // store to KV cache
-    {
-        const auto kv_head = kv_self->head;
+    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
 
-        LM_GGML_ASSERT(kv_self->size == n_ctx);
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+    }
 
-        lm_ggml_tensor * k_cache_view = lm_ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
-        //cb(k_cache_view, "k_cache_view", il);
+    if (wo_b) {
+        cur = lm_ggml_add(ctx0, cur, wo_b);
+    }
 
-        // note: storing RoPE-ed version of K in the KV cache
-        lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx0, k_cur, k_cache_view));
+    return cur;
+}
 
-        v_cur = lm_ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
 
-        lm_ggml_tensor * v_cache_view = nullptr;
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
 
-        if (!v_trans) {
-            v_cache_view = lm_ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
-        } else {
-            // note: the V cache is transposed when not using flash attention
-            v_cache_view = lm_ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
-                    (  n_ctx)*lm_ggml_element_size(kv_self->v_l[il]),
-                    (kv_head)*lm_ggml_element_size(kv_self->v_l[il]));
+    {
+        const auto n_kv = kv_self->get_kv_base()->get_n();
 
-            v_cur = lm_ggml_transpose(ctx0, v_cur);
-        }
-        //cb(v_cache_view, "v_cache_view", il);
+        inp->self_kq_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        lm_ggml_set_input(inp->self_kq_mask);
 
-        lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx0, v_cur, v_cache_view));
+        inp->self_kq_mask_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask, LM_GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
+    {
+        LM_GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+
+        const auto n_kv = kv_self->get_kv_swa()->get_n();
+
+        inp->self_kq_mask_swa = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        lm_ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask_swa, LM_GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    }
+
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+}
+
+lm_ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_unified_iswa * inp,
+        lm_ggml_cgraph * gf,
+        lm_ggml_tensor * wo,
+        lm_ggml_tensor * wo_b,
+        lm_ggml_tensor * q_cur,
+        lm_ggml_tensor * k_cur,
+        lm_ggml_tensor * v_cur,
+        lm_ggml_tensor * kq_b,
+        lm_ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    lm_ggml_build_forward_expand(gf, q_cur);
+    lm_ggml_build_forward_expand(gf, k_cur);
+    lm_ggml_build_forward_expand(gf, v_cur);
+
     const bool is_swa = hparams.is_swa(il);
 
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+
+    const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
+
+    // store to KV cache
+    {
+        lm_ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
+        lm_ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
+    }
+
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
-    const auto n_kv = kv_self->n;
+    lm_ggml_tensor * q = q_cur;
+    lm_ggml_tensor * k = kv->get_k(ctx0, il);
+    lm_ggml_tensor * v = kv->get_v(ctx0, il);
 
-    const int64_t n_head_kv = hparams.n_head_kv(il);
-
-    const auto & n_embd_head_k = hparams.n_embd_head_k;
-    const auto & n_embd_head_v = hparams.n_embd_head_v;
-
-    lm_ggml_tensor * q = lm_ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    lm_ggml_tensor * k =
-        lm_ggml_view_3d(ctx0, kv_self->k_l[il],
-                n_embd_head_k, n_kv, n_head_kv,
-                lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-                lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
-                0);
-    //cb(k, "k", il);
-
-    lm_ggml_tensor * v = !v_trans ?
-        lm_ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_embd_head_v, n_kv, n_head_kv,
-                lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
-                0) :
-        lm_ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_kv, n_embd_head_v, n_head_kv,
-                lm_ggml_element_size(kv_self->v_l[il])*n_ctx,
-                lm_ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
-                0);
-
-    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
+    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            lm_ggml_mul_mat_set_prec(cur, LM_GGML_PREC_F32);
+        }
    }
 
@@ -1522,17 +1417,11 @@ lm_ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask_cross();
 
-    lm_ggml_tensor * q = lm_ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    lm_ggml_tensor * k = lm_ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    lm_ggml_tensor * v = lm_ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
+    lm_ggml_tensor * q = q_cur;
+    lm_ggml_tensor * k = k_cur;
+    lm_ggml_tensor * v = v_cur;
 
+    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1700,3 +1589,30 @@ void llm_graph_context::build_pooling(
 
     lm_ggml_build_forward_expand(gf, cur);
 }
+
+int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+    // TODO move to hparams if a T5 variant appears that uses a different value
+    const int64_t max_distance = 128;
+
+    if (bidirectional) {
+        n_buckets >>= 1;
+    }
+
+    const int64_t max_exact = n_buckets >> 1;
+
+    int32_t relative_position = x - y;
+    int32_t relative_bucket = 0;
+
+    if (bidirectional) {
+        relative_bucket += (relative_position > 0) * n_buckets;
+        relative_position = abs(relative_position);
+    } else {
+        relative_position = -std::min<int32_t>(relative_position, 0);
+    }
+
+    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
+    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+
+    return relative_bucket;
+}
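
For orientation, the function above (now a public symbol instead of a file-local static) is the T5-style relative-position bucketing: distances smaller than max_exact (= n_buckets/2) each keep their own bucket, while larger distances are spaced logarithmically up to max_distance and clamped into the last bucket. A standalone sketch of the same arithmetic for the causal (bidirectional == false) case, written in plain C++ without the llama/ggml types:

    // Mirrors the causal branch of llama_relative_position_bucket above.
    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    static int32_t rel_bucket(int32_t distance, int64_t n_buckets) {
        const int64_t max_distance = 128;
        const int64_t max_exact    = n_buckets >> 1;

        const int32_t rel = std::max(distance, 0); // only past keys contribute

        if (rel < max_exact) {
            return rel; // one bucket per position for small distances
        }

        // logarithmic spacing for larger distances, clamped to the last bucket
        const int32_t large = (int32_t) floorf(
            max_exact + logf(1.0f * rel / max_exact) * (n_buckets - max_exact) / logf(1.0f * max_distance / max_exact));

        return std::min<int32_t>(large, n_buckets - 1);
    }

    int main() {
        // with 32 buckets: distances 0..15 map to themselves, larger distances compress
        for (int32_t d : {0, 1, 15, 16, 40, 100, 500}) {
            printf("distance %3d -> bucket %d\n", d, rel_bucket(d, 32));
        }
        return 0;
    }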