@novastera-oss/llamarn 0.2.7 → 0.2.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +56 -22
  11. package/cpp/build-info.cpp +2 -2
  12. package/cpp/llama.cpp/CMakeLists.txt +1 -1
  13. package/cpp/llama.cpp/common/arg.cpp +7 -0
  14. package/cpp/llama.cpp/common/common.cpp +3 -0
  15. package/cpp/llama.cpp/common/common.h +1 -0
  16. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  17. package/cpp/llama.cpp/convert_hf_to_gguf.py +118 -20
  18. package/cpp/llama.cpp/ggml/CMakeLists.txt +1 -0
  19. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  20. package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
  21. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -0
  22. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  23. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +31 -2
  24. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  25. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  26. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  27. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  28. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  29. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  30. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  31. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  32. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  33. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  34. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  35. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +83 -102
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +192 -67
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +211 -33
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
  48. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +54 -29
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +84 -31
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
  62. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +227 -41
  64. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +362 -182
  65. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  66. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +240 -535
  67. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  68. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  69. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  70. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  71. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  72. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  73. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  74. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  75. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
  76. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  77. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +45 -54
  78. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  79. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  80. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  81. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  82. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  83. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
  84. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  89. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  90. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +57 -1
  91. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  92. package/cpp/llama.cpp/ggml/src/ggml.c +69 -13
  93. package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
  94. package/cpp/llama.cpp/gguf-py/gguf/constants.py +76 -0
  95. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +21 -0
  96. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +64 -0
  97. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
  98. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  99. package/cpp/llama.cpp/include/llama.h +8 -3
  100. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  101. package/cpp/llama.cpp/src/llama-arch.cpp +55 -0
  102. package/cpp/llama.cpp/src/llama-arch.h +18 -0
  103. package/cpp/llama.cpp/src/llama-batch.cpp +570 -359
  104. package/cpp/llama.cpp/src/llama-batch.h +98 -70
  105. package/cpp/llama.cpp/src/llama-chat.cpp +11 -6
  106. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  107. package/cpp/llama.cpp/src/llama-context.h +13 -13
  108. package/cpp/llama.cpp/src/llama-graph.cpp +199 -252
  109. package/cpp/llama.cpp/src/llama-graph.h +44 -32
  110. package/cpp/llama.cpp/src/llama-hparams.cpp +4 -0
  111. package/cpp/llama.cpp/src/llama-hparams.h +8 -0
  112. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +51 -53
  113. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +19 -24
  114. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +110 -104
  115. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +17 -22
  116. package/cpp/llama.cpp/src/llama-kv-cells.h +35 -11
  117. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +66 -67
  118. package/cpp/llama.cpp/src/llama-memory-hybrid.h +16 -21
  119. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +69 -68
  120. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  121. package/cpp/llama.cpp/src/llama-memory.h +18 -22
  122. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  123. package/cpp/llama.cpp/src/llama-model.cpp +1006 -472
  124. package/cpp/llama.cpp/src/llama-model.h +22 -0
  125. package/cpp/llama.cpp/src/llama-quant.cpp +87 -5
  126. package/cpp/llama.cpp/src/llama-vocab.cpp +26 -3
  127. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  128. package/cpp/rn-utils.h +3 -0
  129. package/ios/include/common.h +1 -0
  130. package/ios/include/llama.h +8 -3
  131. package/ios/libs/llama.xcframework/Info.plist +19 -19
  132. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  133. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
  134. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  135. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
  136. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -3
  137. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  138. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  139. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  140. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
  141. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  142. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  143. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  144. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  145. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  146. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  147. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3744
  148. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  149. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
  150. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -3
  151. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  152. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
  153. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -3
  154. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  155. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  156. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
  157. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -3
  158. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  159. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  160. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  161. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
  162. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  163. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
  164. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -3
  165. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  166. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  167. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  168. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
  169. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  170. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  171. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  172. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  173. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  174. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4900
  175. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  176. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
  177. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -3
  178. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  179. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  180. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4871
  181. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3773
  182. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  183. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  184. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  185. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  186. package/package.json +1 -1
@@ -87,41 +87,33 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
87
87
 
88
88
  void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
89
89
  if (pos_bucket) {
90
- kv_state->set_input_pos_bucket(pos_bucket, ubatch);
90
+ mctx->set_input_pos_bucket(pos_bucket, ubatch);
91
91
  }
92
92
  }
93
93
 
94
94
  void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
95
- if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
96
- //GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
95
+ GGML_ASSERT(out_ids);
97
96
 
98
- if (!out_ids) {
99
- LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
100
- } else {
101
- const int64_t n_tokens = ubatch->n_tokens;
97
+ const int64_t n_tokens = ubatch->n_tokens;
102
98
 
103
- GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
104
- int32_t * data = (int32_t *) out_ids->data;
99
+ GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
100
+ int32_t * data = (int32_t *) out_ids->data;
105
101
 
106
- if (n_outputs == n_tokens) {
107
- for (int i = 0; i < n_tokens; ++i) {
108
- data[i] = i;
109
- }
110
- } else if (ubatch->output) {
111
- int32_t n_outputs = 0;
112
- for (int i = 0; i < n_tokens; ++i) {
113
- if (ubatch->output[i]) {
114
- data[n_outputs++] = i;
115
- }
116
- }
117
- // the graph needs to have been passed the correct number of outputs
118
- GGML_ASSERT(n_outputs == n_outputs);
119
- } else if (n_outputs == 1) {
120
- // only keep last output
121
- data[0] = n_tokens - 1;
122
- } else {
123
- GGML_ASSERT(n_outputs == 0);
124
- }
102
+ if (n_outputs == n_tokens) {
103
+ for (int i = 0; i < n_tokens; ++i) {
104
+ data[i] = i;
105
+ }
106
+
107
+ return;
108
+ }
109
+
110
+ GGML_ASSERT(ubatch->output);
111
+
112
+ int n_outputs = 0;
113
+
114
+ for (int i = 0; i < n_tokens; ++i) {
115
+ if (ubatch->output[i]) {
116
+ data[n_outputs++] = i;
125
117
  }
126
118
  }
127
119
  }
@@ -130,110 +122,97 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
130
122
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
131
123
  const int64_t n_tokens = ubatch->n_tokens;
132
124
  const int64_t n_seq_tokens = ubatch->n_seq_tokens;
133
- const int64_t n_seqs = ubatch->n_seqs;
125
+ const int64_t n_seqs_unq = ubatch->n_seqs_unq;
134
126
 
135
127
  GGML_ASSERT(mean);
136
128
  GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
137
129
 
138
130
  float * data = (float *) mean->data;
139
- memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean));
131
+ memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));
140
132
 
141
- std::vector<uint64_t> sum(n_tokens, 0);
133
+ std::vector<uint64_t> sums(n_seqs_unq, 0);
134
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
135
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
136
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
137
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
142
138
 
143
- // TODO: fix indexing [UBATCH_IDX]
144
- for (int s = 0; s < n_seqs; ++s) {
145
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
146
-
147
- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
148
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
149
-
150
- sum[seq_id] += ubatch->n_seq_tokens;
139
+ sums[seq_idx] += ubatch->n_seq_tokens;
140
+ }
151
141
  }
152
142
 
153
- std::vector<float> div(n_tokens, 0.0f);
154
- for (int i = 0; i < n_tokens; ++i) {
155
- const uint64_t s = sum[i];
156
- if (s > 0) {
157
- div[i] = 1.0f/float(s);
143
+ std::vector<float> div(n_seqs_unq, 0.0f);
144
+ for (int s = 0; s < n_seqs_unq; ++s) {
145
+ const uint64_t sum = sums[s];
146
+ if (sum > 0) {
147
+ div[s] = 1.0f/float(sum);
158
148
  }
159
149
  }
160
150
 
161
- // TODO: fix indexing [UBATCH_IDX]
162
- for (int s = 0; s < n_seqs; ++s) {
163
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
151
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
152
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
153
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
154
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
164
155
 
165
- for (int i = 0; i < n_seq_tokens; ++i) {
166
- data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
156
+ for (int j = 0; j < n_seq_tokens; ++j) {
157
+ data[seq_idx*n_tokens + i + j] = div[seq_idx];
158
+ }
167
159
  }
168
160
  }
169
161
  }
170
162
  }
171
163
 
172
164
  void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
173
- if (cparams.embeddings && (
174
- cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
175
- cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
176
- const int64_t n_tokens = ubatch->n_tokens;
177
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
178
- const int64_t n_seqs = ubatch->n_seqs;
165
+ const int64_t n_tokens = ubatch->n_tokens;
166
+ const int64_t n_seq_tokens = ubatch->n_seq_tokens;
167
+ const int64_t n_seqs_unq = ubatch->n_seqs_unq;
179
168
 
169
+ if (cparams.embeddings && (
170
+ cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
171
+ cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
172
+ )) {
180
173
  GGML_ASSERT(cls);
181
174
  GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
182
175
 
183
176
  uint32_t * data = (uint32_t *) cls->data;
184
- memset(cls->data, 0, n_tokens * ggml_element_size(cls));
177
+ memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
185
178
 
186
- // TODO: fix indexing [UBATCH_IDX]
187
- for (int s = 0; s < n_seqs; ++s) {
188
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
179
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
180
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
181
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
182
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
189
183
 
190
- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
191
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
192
-
193
- for (int i = 0; i < n_seq_tokens; ++i) {
194
- const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
195
-
196
- if (pos == 0) {
197
- data[seq_id] = s*n_seq_tokens + i;
198
- }
184
+ data[seq_idx] = i;
199
185
  }
200
186
  }
201
187
  }
202
188
 
203
189
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
204
- const int64_t n_tokens = ubatch->n_tokens;
205
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
206
- const int64_t n_seqs = ubatch->n_seqs;
207
-
208
190
  GGML_ASSERT(cls);
209
191
  GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
210
192
 
211
193
  uint32_t * data = (uint32_t *) cls->data;
212
- memset(cls->data, 0, n_tokens * ggml_element_size(cls));
213
-
214
- std::vector<int> last_pos(n_tokens, -1);
215
- std::vector<int> last_row(n_tokens, -1);
194
+ memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
216
195
 
217
- // TODO: fix indexing [UBATCH_IDX]
218
- for (int s = 0; s < n_seqs; ++s) {
219
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
196
+ std::vector<int> last_pos(n_seqs_unq, -1);
197
+ std::vector<int> last_row(n_seqs_unq, -1);
220
198
 
221
- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
222
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
199
+ for (int i = 0; i < n_tokens; ++i) {
200
+ const llama_pos pos = ubatch->pos[i];
223
201
 
224
- for (int i = 0; i < n_seq_tokens; ++i) {
225
- const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
202
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
203
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
204
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
226
205
 
227
- if (pos >= last_pos[seq_id]) {
228
- last_pos[seq_id] = pos;
229
- last_row[seq_id] = s*n_seq_tokens + i;
206
+ if (pos >= last_pos[seq_idx]) {
207
+ last_pos[seq_idx] = pos;
208
+ last_row[seq_idx] = i;
230
209
  }
231
210
  }
232
211
  }
233
212
 
234
- for (int i = 0; i < n_tokens; ++i) {
235
- if (last_row[i] >= 0) {
236
- data[i] = last_row[i];
213
+ for (int s = 0; s < n_seqs_unq; ++s) {
214
+ if (last_row[s] >= 0) {
215
+ data[s] = last_row[s];
237
216
  }
238
217
  }
239
218
  }
@@ -242,7 +221,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
242
221
  void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
243
222
  GGML_UNUSED(ubatch);
244
223
 
245
- const int64_t n_rs = mem_state->get_n_rs();
224
+ const int64_t n_rs = mctx->get_n_rs();
246
225
 
247
226
  if (s_copy) {
248
227
  GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@@ -250,7 +229,7 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
250
229
 
251
230
  // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
252
231
  for (uint32_t i = 0; i < n_rs; ++i) {
253
- data[i] = mem_state->s_copy(i);
232
+ data[i] = mctx->s_copy(i);
254
233
  }
255
234
  }
256
235
  }
@@ -266,89 +245,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
266
245
  }
267
246
 
268
247
  void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
269
- if (kq_mask) {
270
- if (cparams.causal_attn) {
271
- const int64_t n_kv = ubatch->n_tokens;
272
- const int64_t n_tokens = ubatch->n_tokens;
273
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
274
- const int64_t n_seqs = ubatch->n_seqs;
275
-
276
- GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
277
- float * data = (float *) kq_mask->data;
278
-
279
- for (int h = 0; h < 1; ++h) {
280
- for (int s1 = 0; s1 < n_seqs; ++s1) {
281
- const llama_seq_id seq_id = ubatch->seq_id[s1][0];
282
-
283
- for (int j = 0; j < n_seq_tokens; ++j) {
284
- const int32_t tj = s1*n_seq_tokens + j;
285
-
286
- for (int s0 = 0; s0 < n_seqs; ++s0) {
287
- for (int i = 0; i < n_seq_tokens; ++i) {
288
- const int32_t ti = s0*n_seq_tokens + i;
289
- float f = -INFINITY;
290
-
291
- // TODO: fix indexing [UBATCH_IDX]
292
- for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
293
- if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
294
- if (hparams.use_alibi) {
295
- f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
296
- } else {
297
- f = 0.0f;
298
- }
299
- break;
300
- }
301
- }
302
-
303
- data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
304
- }
305
- }
306
- }
307
- }
308
- }
309
- } else {
310
- const int64_t n_tokens = ubatch->n_tokens;
311
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
312
- const int64_t n_seqs = ubatch->n_seqs;
313
- const int64_t n_stride = ubatch->n_tokens;
314
-
315
- GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
316
-
317
- float * data = (float *) kq_mask->data;
318
-
319
- for (int h = 0; h < 1; ++h) {
320
- for (int s1 = 0; s1 < n_seqs; ++s1) {
321
- const llama_seq_id seq_id = ubatch->seq_id[s1][0];
322
-
323
- for (int j = 0; j < n_seq_tokens; ++j) {
324
- const int32_t tj = s1*n_seq_tokens + j;
325
-
326
- for (int s0 = 0; s0 < n_seqs; ++s0) {
327
- for (int i = 0; i < n_seq_tokens; ++i) {
328
- const int32_t ti = s0*n_seq_tokens + i;
329
- float f = -INFINITY;
330
-
331
- // TODO: fix indexing [UBATCH_IDX]
332
- for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
333
- if (ubatch->seq_id[s0][s] == seq_id) {
334
- if (hparams.use_alibi) {
335
- f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
336
- } else {
337
- f = 0.0f;
338
- }
339
- break;
340
- }
341
- }
342
-
343
- data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
344
- }
345
- }
248
+ const int64_t n_kv = ubatch->n_tokens;
249
+ const int64_t n_tokens = ubatch->n_tokens;
250
+
251
+ GGML_ASSERT(kq_mask);
252
+ GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
253
+
254
+ float * data = (float *) kq_mask->data;
255
+
256
+ for (int h = 0; h < 1; ++h) {
257
+ for (int i1 = 0; i1 < n_tokens; ++i1) {
258
+ const llama_seq_id s1 = ubatch->seq_id[i1][0];
259
+
260
+ for (int i0 = 0; i0 < n_tokens; ++i0) {
261
+ float f = -INFINITY;
346
262
 
347
- for (int i = n_tokens; i < n_stride; ++i) {
348
- data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
263
+ for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
264
+ const llama_seq_id s0 = ubatch->seq_id[i0][0];
265
+
266
+ // TODO: reimplement this like in llama_kv_cache_unified
267
+ if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
268
+ if (hparams.use_alibi) {
269
+ f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
270
+ } else {
271
+ f = 0.0f;
349
272
  }
273
+ break;
350
274
  }
351
275
  }
276
+
277
+ data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
352
278
  }
353
279
  }
354
280
  }
@@ -356,49 +282,51 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
356
282
 
357
283
  void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
358
284
  if (self_kq_mask) {
359
- kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
285
+ mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
360
286
  }
361
287
  }
362
288
 
363
289
  void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
364
290
  if (self_kq_mask) {
365
- kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
291
+ mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
366
292
  }
367
293
 
368
294
  if (self_kq_mask_swa) {
369
- kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
295
+ mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
370
296
  }
371
297
  }
372
298
 
373
299
  void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
374
- if (cross_kq_mask) {
375
- const int64_t n_enc = cross_kq_mask->ne[0];
376
- const int64_t n_tokens = ubatch->n_tokens;
300
+ GGML_ASSERT(cross_kq_mask);
377
301
 
378
- GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
379
- GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
302
+ const int64_t n_enc = cross_kq_mask->ne[0];
303
+ const int64_t n_tokens = ubatch->n_tokens;
380
304
 
381
- float * data = (float *) cross_kq_mask->data;
305
+ GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
306
+ GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
382
307
 
383
- for (int h = 0; h < 1; ++h) {
384
- for (int j = 0; j < n_tokens; ++j) {
385
- for (int i = 0; i < n_enc; ++i) {
386
- float f = -INFINITY;
387
- // TODO: fix indexing [UBATCH_IDX]
388
- for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
389
- const llama_seq_id seq_id = ubatch->seq_id[j][s];
390
- if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
391
- f = 0.0f;
392
- }
308
+ float * data = (float *) cross_kq_mask->data;
309
+
310
+ for (int h = 0; h < 1; ++h) {
311
+ for (int i = 0; i < n_tokens; ++i) {
312
+ for (int j = 0; j < n_enc; ++j) {
313
+ float f = -INFINITY;
314
+
315
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
316
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
317
+
318
+ if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
319
+ f = 0.0f;
393
320
  }
394
- data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
395
321
  }
322
+
323
+ data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
396
324
  }
325
+ }
397
326
 
398
- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
399
- for (int j = 0; j < n_enc; ++j) {
400
- data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
401
- }
327
+ for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
328
+ for (int j = 0; j < n_enc; ++j) {
329
+ data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
402
330
  }
403
331
  }
404
332
  }
@@ -406,10 +334,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
406
334
 
407
335
  void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
408
336
  if (self_kq_mask) {
409
- mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
337
+ mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
410
338
  }
411
339
 
412
- const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
340
+ const int64_t n_rs = mctx->get_recr()->get_n_rs();
413
341
 
414
342
  if (s_copy) {
415
343
  GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@@ -417,11 +345,17 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
417
345
 
418
346
  // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
419
347
  for (uint32_t i = 0; i < n_rs; ++i) {
420
- data[i] = mem_state->get_state_recr()->s_copy(i);
348
+ data[i] = mctx->get_recr()->s_copy(i);
421
349
  }
422
350
  }
423
351
  }
424
352
 
353
+ void llm_graph_input_one::set_input(const llama_ubatch *) {
354
+ GGML_ASSERT(one && ggml_nelements(one) == 1);
355
+ float f_one = 1.0f;
356
+ ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
357
+ }
358
+
425
359
  //
426
360
  // llm_graph_context
427
361
  //
@@ -461,16 +395,12 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
461
395
  backend_cpu (params.backend_cpu),
462
396
  cvec (params.cvec),
463
397
  loras (params.loras),
464
- mstate (params.mstate),
398
+ mctx (params.mctx),
465
399
  cross (params.cross),
466
400
  cb_func (params.cb),
467
401
  res (std::make_unique<llm_graph_result>()) {
468
402
  }
469
403
 
470
- int64_t llm_graph_context::n_pos_per_embd() const {
471
- return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
472
- }
473
-
474
404
  void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
475
405
  if (cb_func) {
476
406
  cb_func(ubatch, cur, name, il);
@@ -915,11 +845,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
915
845
  }
916
846
 
917
847
  ggml_tensor * llm_graph_context::build_inp_pos() const {
918
- auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
848
+ auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
919
849
 
920
850
  auto & cur = inp->pos;
921
851
 
922
- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
852
+ cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
923
853
  ggml_set_input(cur);
924
854
 
925
855
  res->add_input(std::move(inp));
@@ -942,6 +872,14 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
942
872
  }
943
873
 
944
874
  ggml_tensor * llm_graph_context::build_inp_out_ids() const {
875
+ // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
876
+ // but this would make the graph topology depend on the number of output tokens, which can interere with
877
+ // features that require constant topology such as pipline parallelism
878
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
879
+ //if (n_outputs < n_tokens) {
880
+ // return nullptr;
881
+ //}
882
+
945
883
  auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
946
884
 
947
885
  auto & cur = inp->out_ids;
@@ -959,7 +897,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
959
897
 
960
898
  auto & cur = inp->mean;
961
899
 
962
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
900
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
963
901
  ggml_set_input(cur);
964
902
 
965
903
  res->add_input(std::move(inp));
@@ -972,7 +910,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
972
910
 
973
911
  auto & cur = inp->cls;
974
912
 
975
- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
913
+ cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
976
914
  ggml_set_input(cur);
977
915
 
978
916
  res->add_input(std::move(inp));
@@ -1018,11 +956,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
1018
956
  }
1019
957
 
1020
958
  ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
1021
- const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
959
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
1022
960
 
1023
- auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
961
+ auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
1024
962
 
1025
- const auto n_kv = kv_state->get_n_kv();
963
+ const auto n_kv = mctx_cur->get_n_kv();
1026
964
 
1027
965
  auto & cur = inp->pos_bucket;
1028
966
 
@@ -1050,14 +988,14 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
1050
988
  }
1051
989
 
1052
990
  llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
1053
- const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
991
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
1054
992
 
1055
- auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
993
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
1056
994
 
1057
995
  {
1058
996
  GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
1059
997
 
1060
- const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
998
+ const auto n_kv = inp->mctx->get_attn()->get_n_kv();
1061
999
 
1062
1000
  inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1063
1001
  //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1067,7 +1005,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
1067
1005
  }
1068
1006
 
1069
1007
  {
1070
- const auto n_rs = mem_state->get_state_recr()->get_n_rs();
1008
+ const auto n_rs = mctx_cur->get_recr()->get_n_rs();
1071
1009
 
1072
1010
  inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
1073
1011
  ggml_set_input(inp->s_copy);
@@ -1251,14 +1189,14 @@ ggml_tensor * llm_graph_context::build_attn(
1251
1189
  }
1252
1190
 
1253
1191
  llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
1254
- const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
1192
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
1255
1193
 
1256
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
1194
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
1257
1195
 
1258
1196
  {
1259
1197
  GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
1260
1198
 
1261
- const auto n_kv = kv_state->get_n_kv();
1199
+ const auto n_kv = mctx_cur->get_n_kv();
1262
1200
 
1263
1201
  inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1264
1202
  //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1288,19 +1226,19 @@ ggml_tensor * llm_graph_context::build_attn(
1288
1226
  ggml_build_forward_expand(gf, k_cur);
1289
1227
  ggml_build_forward_expand(gf, v_cur);
1290
1228
 
1291
- const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
1229
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
1292
1230
 
1293
1231
  // store to KV cache
1294
1232
  {
1295
- ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
1296
- ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
1233
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
1234
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
1297
1235
  }
1298
1236
 
1299
1237
  const auto & kq_mask = inp->get_kq_mask();
1300
1238
 
1301
1239
  ggml_tensor * q = q_cur;
1302
- ggml_tensor * k = kv_state->get_k(ctx0, il);
1303
- ggml_tensor * v = kv_state->get_v(ctx0, il);
1240
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1241
+ ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1304
1242
 
1305
1243
  ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1306
1244
  cb(cur, "kqv_out", il);
@@ -1335,26 +1273,35 @@ ggml_tensor * llm_graph_context::build_attn(
1335
1273
  // these nodes are added to the graph together so that they are not reordered
1336
1274
  // by doing so, the number of splits in the graph is reduced
1337
1275
  ggml_build_forward_expand(gf, q_cur);
1338
- ggml_build_forward_expand(gf, k_cur);
1339
- ggml_build_forward_expand(gf, v_cur);
1340
1276
 
1341
- const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
1277
+ if (k_cur) {
1278
+ ggml_build_forward_expand(gf, k_cur);
1279
+ }
1280
+
1281
+ if (v_cur) {
1282
+ ggml_build_forward_expand(gf, v_cur);
1283
+ }
1284
+
1285
+ const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
1342
1286
 
1343
1287
  const bool is_swa = hparams.is_swa(il);
1344
1288
 
1345
- const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
1289
+ const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
1346
1290
 
1347
- // store to KV cache
1348
- {
1349
- ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
1350
- ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
1291
+ // optionally store to KV cache
1292
+ if (k_cur) {
1293
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
1294
+ }
1295
+
1296
+ if (v_cur) {
1297
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
1351
1298
  }
1352
1299
 
1353
1300
  const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
1354
1301
 
1355
1302
  ggml_tensor * q = q_cur;
1356
- ggml_tensor * k = kv_state->get_k(ctx0, il);
1357
- ggml_tensor * v = kv_state->get_v(ctx0, il);
1303
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1304
+ ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1358
1305
 
1359
1306
  ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1360
1307
  cb(cur, "kqv_out", il);
@@ -1447,19 +1394,19 @@ ggml_tensor * llm_graph_context::build_attn(
1447
1394
  ggml_build_forward_expand(gf, k_cur);
1448
1395
  ggml_build_forward_expand(gf, v_cur);
1449
1396
 
1450
- const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
1397
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
1451
1398
 
1452
1399
  // store to KV cache
1453
1400
  {
1454
- ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
1455
- ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
1401
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
1402
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
1456
1403
  }
1457
1404
 
1458
1405
  const auto & kq_mask = inp->get_kq_mask();
1459
1406
 
1460
1407
  ggml_tensor * q = q_cur;
1461
- ggml_tensor * k = kv_state->get_k(ctx0, il);
1462
- ggml_tensor * v = kv_state->get_v(ctx0, il);
1408
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1409
+ ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1463
1410
 
1464
1411
  ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1465
1412
  cb(cur, "kqv_out", il);
@@ -1480,12 +1427,12 @@ ggml_tensor * llm_graph_context::build_attn(
1480
1427
  }
1481
1428
 
1482
1429
  llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
1483
- const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
1430
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
1484
1431
 
1485
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
1432
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
1486
1433
 
1487
1434
  {
1488
- const auto n_kv = kv_state->get_base()->get_n_kv();
1435
+ const auto n_kv = mctx_cur->get_base()->get_n_kv();
1489
1436
 
1490
1437
  inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1491
1438
  //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1497,7 +1444,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
1497
1444
  {
1498
1445
  GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
1499
1446
 
1500
- const auto n_kv = kv_state->get_swa()->get_n_kv();
1447
+ const auto n_kv = mctx_cur->get_swa()->get_n_kv();
1501
1448
 
1502
1449
  inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1503
1450
  //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
@@ -1553,11 +1500,11 @@ ggml_tensor * llm_graph_context::build_rs(
1553
1500
  }
1554
1501
 
1555
1502
  llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
1556
- const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
1503
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1557
1504
 
1558
- auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
1505
+ auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
1559
1506
 
1560
- const auto n_rs = kv_state->get_n_rs();
1507
+ const auto n_rs = mctx_cur->get_n_rs();
1561
1508
 
1562
1509
  inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
1563
1510
  ggml_set_input(inp->s_copy);
@@ -1572,9 +1519,9 @@ ggml_tensor * llm_graph_context::build_rs(
1572
1519
  int32_t state_size,
1573
1520
  int32_t n_seqs,
1574
1521
  bool avoid_copies) const {
1575
- const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
1522
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1576
1523
 
1577
- return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
1524
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
1578
1525
  }
1579
1526
 
1580
1527
  ggml_tensor * llm_graph_context::build_rs(
@@ -1584,9 +1531,9 @@ ggml_tensor * llm_graph_context::build_rs(
1584
1531
  int32_t state_size,
1585
1532
  int32_t n_seqs,
1586
1533
  bool avoid_copies) const {
1587
- const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
1534
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
1588
1535
 
1589
- return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
1536
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
1590
1537
  }
1591
1538
 
1592
1539
  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
@@ -1594,13 +1541,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1594
1541
  ggml_cgraph * gf,
1595
1542
  const llama_ubatch & ubatch,
1596
1543
  int il) const {
1597
- const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
1544
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1598
1545
 
1599
1546
  const auto token_shift_count = hparams.token_shift_count;
1600
1547
 
1601
1548
  const int64_t n_seqs = ubatch.n_seqs;
1602
1549
 
1603
- ggml_tensor * token_shift_all = kv_state->get_r_l(il);
1550
+ ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
1604
1551
 
1605
1552
  ggml_tensor * token_shift = build_rs(
1606
1553
  inp, gf, token_shift_all,
@@ -1615,19 +1562,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
1615
1562
  ggml_tensor * token_shift,
1616
1563
  const llama_ubatch & ubatch,
1617
1564
  int il) const {
1618
- const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
1565
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1619
1566
 
1620
1567
  const auto token_shift_count = hparams.token_shift_count;
1621
1568
  const auto n_embd = hparams.n_embd;
1622
1569
 
1623
1570
  const int64_t n_seqs = ubatch.n_seqs;
1624
1571
 
1625
- const auto kv_head = kv_state->get_head();
1572
+ const auto kv_head = mctx_cur->get_head();
1626
1573
 
1627
1574
  return ggml_cpy(
1628
1575
  ctx0,
1629
1576
  ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
1630
- ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
1577
+ ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
1631
1578
  );
1632
1579
  }
1633
1580