@novastera-oss/llamarn 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190)
  1. package/android/build.gradle +2 -1
  2. package/android/proguard-rules.pro +12 -0
  3. package/android/src/main/cpp/include/llama.h +15 -47
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakePresets.json +11 -0
  22. package/cpp/llama.cpp/CODEOWNERS +1 -0
  23. package/cpp/llama.cpp/README.md +4 -3
  24. package/cpp/llama.cpp/common/arg.cpp +45 -1
  25. package/cpp/llama.cpp/common/common.cpp +22 -6
  26. package/cpp/llama.cpp/common/common.h +18 -4
  27. package/cpp/llama.cpp/convert_hf_to_gguf.py +500 -32
  28. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +12 -13
  29. package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -1
  30. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
  31. package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  32. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -0
  33. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
  34. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -0
  35. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +8 -20
  36. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +58 -3
  38. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +122 -16
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +5 -2
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +3 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +3 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +14 -4
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +64 -17
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -67
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +45 -62
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +28 -43
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +41 -56
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -47
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +31 -43
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +22 -37
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +73 -23
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -689
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +7 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +13 -1
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
  75. package/cpp/llama.cpp/ggml/src/ggml-impl.h +16 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +13 -3
  77. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +407 -69
  78. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +380 -83
  79. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +2 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +295 -2
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  85. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +131 -46
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +43 -43
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  94. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +287 -22
  95. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +1 -5
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  101. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  102. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
  105. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +8 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  107. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  108. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +71 -16
  109. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  112. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  114. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
  115. package/cpp/llama.cpp/ggml/src/ggml.c +4 -6
  116. package/cpp/llama.cpp/gguf-py/gguf/constants.py +98 -0
  117. package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
  118. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
  119. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +75 -52
  120. package/cpp/llama.cpp/include/llama.h +15 -7
  121. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
  122. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
  123. package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
  124. package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
  125. package/cpp/llama.cpp/src/llama-arch.cpp +106 -0
  126. package/cpp/llama.cpp/src/llama-arch.h +5 -0
  127. package/cpp/llama.cpp/src/llama-batch.cpp +76 -70
  128. package/cpp/llama.cpp/src/llama-batch.h +24 -18
  129. package/cpp/llama.cpp/src/llama-chat.cpp +43 -1
  130. package/cpp/llama.cpp/src/llama-chat.h +2 -0
  131. package/cpp/llama.cpp/src/llama-context.cpp +180 -106
  132. package/cpp/llama.cpp/src/llama-context.h +26 -16
  133. package/cpp/llama.cpp/src/llama-cparams.h +3 -2
  134. package/cpp/llama.cpp/src/llama-graph.cpp +203 -39
  135. package/cpp/llama.cpp/src/llama-graph.h +147 -72
  136. package/cpp/llama.cpp/src/llama-hparams.cpp +40 -0
  137. package/cpp/llama.cpp/src/llama-hparams.h +10 -2
  138. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +11 -5
  139. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +3 -0
  140. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +698 -302
  141. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +89 -31
  142. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +1 -0
  143. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +16 -1
  144. package/cpp/llama.cpp/src/llama-model.cpp +1293 -312
  145. package/cpp/llama.cpp/src/llama-model.h +3 -4
  146. package/cpp/llama.cpp/src/llama-quant.cpp +1 -2
  147. package/cpp/llama.cpp/src/llama-vocab.cpp +363 -8
  148. package/cpp/llama.cpp/src/llama-vocab.h +2 -0
  149. package/cpp/llama.cpp/src/unicode.cpp +207 -0
  150. package/cpp/llama.cpp/src/unicode.h +2 -0
  151. package/ios/include/common.h +18 -4
  152. package/ios/include/llama.h +15 -7
  153. package/ios/libs/llama.xcframework/Info.plist +15 -15
  154. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  155. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
  156. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -7
  157. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  158. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  159. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  160. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
  161. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  162. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  163. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  164. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  165. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3891
  166. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -7
  167. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -7
  168. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  169. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -7
  170. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  171. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  172. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  173. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
  174. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -7
  175. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  176. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  177. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  178. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
  179. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  180. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  181. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  182. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -5095
  183. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -7
  184. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  185. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  186. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -5066
  187. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3919
  188. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  189. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  190. package/package.json +4 -4
@@ -107,8 +107,10 @@ const char * llm_type_name(llm_type type) {
107
107
  case LLM_TYPE_17B_16E: return "17Bx16E (Scout)";
108
108
  case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
109
109
  case LLM_TYPE_A13B: return "A13B";
110
+ case LLM_TYPE_21B_A3B: return "21B.A3B";
110
111
  case LLM_TYPE_30B_A3B: return "30B.A3B";
111
112
  case LLM_TYPE_235B_A22B: return "235B.A22B";
113
+ case LLM_TYPE_300B_A47B: return "300B.A47B";
112
114
  case LLM_TYPE_E2B: return "E2B";
113
115
  case LLM_TYPE_E4B: return "E4B";
114
116
  default: return "?B";
@@ -644,6 +646,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
644
646
  ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale);
645
647
  ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
646
648
 
649
+ // MiniCPM uses rope by default, unlike Granite which uses it as a switch
650
+ hparams.rope_finetuned = true;
651
+
647
652
  switch (hparams.n_layer) {
648
653
  case 52: type = LLM_TYPE_1B; break;
649
654
  case 40: type = LLM_TYPE_2B; break;
@@ -849,6 +854,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
849
854
  default: type = LLM_TYPE_UNKNOWN;
850
855
  }
851
856
  } break;
857
+ case LLM_ARCH_DREAM:
858
+ {
859
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
860
+ // Dream models are primarily 7B with 28 layers
861
+ switch (hparams.n_layer) {
862
+ case 28:
863
+ type = LLM_TYPE_7B;
864
+ break;
865
+ default:
866
+ type = LLM_TYPE_UNKNOWN;
867
+ }
868
+ // Set non-causal attention for diffusion models
869
+ hparams.causal_attn = false;
870
+ }
871
+ break;
852
872
  case LLM_ARCH_QWEN2MOE:
853
873
  {
854
874
  ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
@@ -935,6 +955,33 @@ void llama_model::load_hparams(llama_model_loader & ml) {
935
955
  default: type = LLM_TYPE_UNKNOWN;
936
956
  }
937
957
  } break;
958
+ case LLM_ARCH_PLAMO2:
959
+ {
960
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
961
+
962
+ // Load Mamba SSM parameters
963
+ ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
964
+ ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
965
+ ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
966
+ ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
967
+ ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
968
+
969
+ for (uint32_t i = 0; i < hparams.n_layer; ++i) {
970
+ hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0;
971
+ }
972
+
973
+ switch (hparams.n_layer) {
974
+ case 16: type = LLM_TYPE_1B; break;
975
+ case 32:
976
+ if (hparams.n_embd == 2048) {
977
+ type = LLM_TYPE_2B;
978
+ } else if (hparams.n_embd == 4096) {
979
+ type = LLM_TYPE_8B;
980
+ }
981
+ break;
982
+ default: type = LLM_TYPE_UNKNOWN;
983
+ }
984
+ } break;
938
985
  case LLM_ARCH_GPT2:
939
986
  {
940
987
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -1322,7 +1369,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
1322
1369
  // that have no expert_gating_func model parameter set
1323
1370
  hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
1324
1371
  }
1325
- ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul);
1372
+ ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
1326
1373
 
1327
1374
  switch (hparams.n_layer) {
1328
1375
  case 27: type = LLM_TYPE_16B; break;
@@ -1446,6 +1493,23 @@ void llama_model::load_hparams(llama_model_loader & ml) {
1446
1493
  default: type = LLM_TYPE_UNKNOWN;
1447
1494
  }
1448
1495
  } break;
1496
+ case LLM_ARCH_EXAONE4:
1497
+ {
1498
+ if (hparams.n_layer == 64) { // 32B
1499
+ hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
1500
+ hparams.n_swa = 4096;
1501
+ hparams.set_swa_pattern(4);
1502
+ }
1503
+
1504
+ ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
1505
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
1506
+
1507
+ switch (hparams.n_layer) {
1508
+ case 30: type = LLM_TYPE_1_2B; break;
1509
+ case 64: type = LLM_TYPE_32B; break;
1510
+ default: type = LLM_TYPE_UNKNOWN;
1511
+ }
1512
+ } break;
1449
1513
  case LLM_ARCH_RWKV6:
1450
1514
  case LLM_ARCH_RWKV6QWEN2:
1451
1515
  {
@@ -1483,7 +1547,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
1483
1547
  ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false);
1484
1548
 
1485
1549
  switch (hparams.n_layer) {
1486
- case 12: type = LLM_TYPE_190M; break;
1550
+ case 12:
1551
+ switch (hparams.n_embd) {
1552
+ case 768: type = LLM_TYPE_190M; break;
1553
+ default: type = LLM_TYPE_UNKNOWN;
1554
+ } break;
1487
1555
  case 24:
1488
1556
  switch (hparams.n_embd) {
1489
1557
  case 1024: type = LLM_TYPE_450M; break;
@@ -1496,7 +1564,17 @@ void llama_model::load_hparams(llama_model_loader & ml) {
1496
1564
  case 3584: type = LLM_TYPE_7B; break;
1497
1565
  default: type = LLM_TYPE_UNKNOWN;
1498
1566
  } break;
1499
- case 32: type = LLM_TYPE_2_9B; break; // RWKV-7-World
1567
+ case 32:
1568
+ switch (hparams.n_embd) {
1569
+ case 2560: type = LLM_TYPE_2_9B; break;
1570
+ case 4096: type = LLM_TYPE_7B; break;
1571
+ default: type = LLM_TYPE_UNKNOWN;
1572
+ } break;
1573
+ case 61:
1574
+ switch (hparams.n_embd) {
1575
+ case 4096: type = LLM_TYPE_14B; break;
1576
+ default: type = LLM_TYPE_UNKNOWN;
1577
+ } break;
1500
1578
  default: type = LLM_TYPE_UNKNOWN;
1501
1579
  }
1502
1580
  } break;
@@ -1607,10 +1685,20 @@ void llama_model::load_hparams(llama_model_loader & ml) {
1607
1685
  }
1608
1686
  } break;
1609
1687
  case LLM_ARCH_ERNIE4_5:
1688
+ case LLM_ARCH_ERNIE4_5_MOE:
1610
1689
  {
1611
1690
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
1691
+ if (arch == LLM_ARCH_ERNIE4_5_MOE) {
1692
+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
1693
+ ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
1694
+ ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
1695
+ ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
1696
+ }
1697
+
1612
1698
  switch (hparams.n_layer) {
1613
1699
  case 18: type = LLM_TYPE_0_3B; break;
1700
+ case 28: type = LLM_TYPE_21B_A3B; break;
1701
+ case 54: type = LLM_TYPE_300B_A47B; break;
1614
1702
  default: type = LLM_TYPE_UNKNOWN;
1615
1703
  }
1616
1704
  } break;
@@ -2643,12 +2731,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
2643
2731
  } break;
2644
2732
  case LLM_ARCH_QWEN2:
2645
2733
  case LLM_ARCH_QWEN2VL:
2734
+ case LLM_ARCH_DREAM:
2646
2735
  {
2647
2736
  tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
2648
2737
 
2649
2738
  // output
2650
2739
  output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
2651
2740
  output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
2741
+ output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, TENSOR_NOT_REQUIRED);
2652
2742
  // if output is NULL, init from the input tok embed
2653
2743
  if (output == NULL) {
2654
2744
  output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
@@ -2938,6 +3028,73 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
2938
3028
  layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
2939
3029
  }
2940
3030
  } break;
3031
+ case LLM_ARCH_PLAMO2:
3032
+ {
3033
+ const uint32_t d_conv = hparams.ssm_d_conv;
3034
+ const uint32_t d_state = hparams.ssm_d_state;
3035
+ const uint32_t num_heads = hparams.ssm_dt_rank;
3036
+ const uint32_t intermediate_size = hparams.ssm_d_inner;
3037
+ const uint32_t head_dim = intermediate_size / num_heads;
3038
+ const uint32_t qk_dim = head_dim;
3039
+ const uint32_t v_dim = head_dim;
3040
+ const int64_t num_attention_heads = hparams.n_head();
3041
+ const int64_t q_num_heads = num_attention_heads;
3042
+ const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
3043
+
3044
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
3045
+
3046
+ // output
3047
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
3048
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
3049
+ // if output is NULL, init from the input tok embed
3050
+ if (output == NULL) {
3051
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
3052
+ }
3053
+
3054
+ for (int i = 0; i < n_layer; ++i) {
3055
+ auto & layer = layers[i];
3056
+ bool is_mamba_layer = hparams.is_recurrent(i);
3057
+
3058
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
3059
+
3060
+ if (is_mamba_layer) {
3061
+ layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2 * intermediate_size}, 0);
3062
+ layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0);
3063
+
3064
+ layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {intermediate_size, dt_dim + 2*d_state}, 0);
3065
+ layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_dim, num_heads}, 0);
3066
+ layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {num_heads}, 0);
3067
+
3068
+ layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {num_heads}, 0);
3069
+ layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {num_heads}, 0);
3070
+
3071
+ layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {intermediate_size, n_embd}, 0);
3072
+
3073
+ layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, i), {dt_dim}, 0);
3074
+ layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0);
3075
+ layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0);
3076
+ } else {
3077
+ const int64_t num_key_value_heads = hparams.n_head_kv(i);
3078
+ const int64_t k_num_heads = num_key_value_heads;
3079
+ const int64_t v_num_heads = num_key_value_heads;
3080
+ const int64_t q_proj_dim = q_num_heads * qk_dim;
3081
+ const int64_t k_proj_dim = k_num_heads * qk_dim;
3082
+ const int64_t v_proj_dim = v_num_heads * v_dim;
3083
+
3084
+ layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0);
3085
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim, num_attention_heads}, 0);
3086
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim, k_num_heads}, 0);
3087
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0);
3088
+ }
3089
+
3090
+ // All layers have post-attention norm, FFN norm, and FFN tensors
3091
+ layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0);
3092
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
3093
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
3094
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0);
3095
+ layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0);
3096
+ }
3097
+ } break;
2941
3098
  case LLM_ARCH_GPT2:
2942
3099
  {
2943
3100
  tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4232,6 +4389,39 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
4232
4389
  layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
4233
4390
  }
4234
4391
  } break;
4392
+ case LLM_ARCH_EXAONE4:
4393
+ {
4394
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
4395
+
4396
+ // output
4397
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
4398
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
4399
+
4400
+ // if output is NULL, init from the input tok embed
4401
+ if (output == NULL) {
4402
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
4403
+ }
4404
+
4405
+ for (int i = 0; i < n_layer; ++i) {
4406
+ auto & layer = layers[i];
4407
+
4408
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
4409
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
4410
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
4411
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
4412
+
4413
+ layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
4414
+
4415
+ layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
4416
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
4417
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
4418
+
4419
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
4420
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
4421
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
4422
+ layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
4423
+ }
4424
+ } break;
4235
4425
  case LLM_ARCH_RWKV6:
4236
4426
  {
4237
4427
  tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4747,6 +4937,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
4747
4937
  }
4748
4938
  } break;
4749
4939
  case LLM_ARCH_ERNIE4_5:
4940
+ case LLM_ARCH_ERNIE4_5_MOE:
4750
4941
  {
4751
4942
  tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
4752
4943
 
@@ -4775,9 +4966,27 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
4775
4966
  layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
4776
4967
 
4777
4968
  layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
4778
- layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
4779
- layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
4780
- layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
4969
+
4970
+ if (arch == LLM_ARCH_ERNIE4_5_MOE && static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) { // MoE layers
4971
+ int n_ff_exp = hparams.n_ff_exp;
4972
+
4973
+ layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
4974
+ layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
4975
+ layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED);
4976
+ layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0);
4977
+ layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
4978
+
4979
+ // Shared expert (if present)
4980
+ if (hparams.n_ff_shexp > 0) {
4981
+ layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0);
4982
+ layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd }, 0);
4983
+ layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0);
4984
+ }
4985
+ } else { // Dense layers
4986
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
4987
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
4988
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
4989
+ }
4781
4990
  }
4782
4991
  } break;
4783
4992
  case LLM_ARCH_FALCON_H1:
@@ -5209,6 +5418,7 @@ void llama_model::print_info() const {
5209
5418
  arch == LLM_ARCH_MAMBA2 ||
5210
5419
  arch == LLM_ARCH_JAMBA ||
5211
5420
  arch == LLM_ARCH_FALCON_H1 ||
5421
+ arch == LLM_ARCH_PLAMO2 ||
5212
5422
  arch == LLM_ARCH_GRANITE_HYBRID) {
5213
5423
  LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
5214
5424
  LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
@@ -5381,7 +5591,7 @@ ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int i
5381
5591
  }
5382
5592
 
5383
5593
  struct llm_build_llama : public llm_graph_context {
5384
- llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
5594
+ llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
5385
5595
  const int64_t n_embd_head = hparams.n_embd_head_v;
5386
5596
 
5387
5597
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -5457,7 +5667,7 @@ struct llm_build_llama : public llm_graph_context {
5457
5667
  cb(Kcur, "Kcur", il);
5458
5668
  cb(Vcur, "Vcur", il);
5459
5669
 
5460
- cur = build_attn(inp_attn, gf,
5670
+ cur = build_attn(inp_attn,
5461
5671
  model.layers[il].wo, model.layers[il].bo,
5462
5672
  Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
5463
5673
  cb(cur, "attn_out", il);
@@ -5537,7 +5747,7 @@ struct llm_build_llama : public llm_graph_context {
5537
5747
  };
5538
5748
 
5539
5749
  struct llm_build_llama_iswa : public llm_graph_context {
5540
- llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
5750
+ llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
5541
5751
  const int64_t n_embd_head = hparams.n_embd_head_v;
5542
5752
 
5543
5753
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -5631,7 +5841,7 @@ struct llm_build_llama_iswa : public llm_graph_context {
5631
5841
  cb(Kcur, "Kcur_normed", il);
5632
5842
  }
5633
5843
 
5634
- cur = build_attn(inp_attn, gf,
5844
+ cur = build_attn(inp_attn,
5635
5845
  model.layers[il].wo, model.layers[il].bo,
5636
5846
  Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
5637
5847
  cb(cur, "attn_out", il);
@@ -5720,7 +5930,7 @@ struct llm_build_llama_iswa : public llm_graph_context {
5720
5930
  };
5721
5931
 
5722
5932
  struct llm_build_deci : public llm_graph_context {
5723
- llm_build_deci(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
5933
+ llm_build_deci(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
5724
5934
  const int64_t n_embd_head = hparams.n_embd_head_v;
5725
5935
 
5726
5936
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -5808,7 +6018,7 @@ struct llm_build_deci : public llm_graph_context {
5808
6018
  cb(Kcur, "Kcur", il);
5809
6019
  cb(Vcur, "Vcur", il);
5810
6020
 
5811
- cur = build_attn(inp_attn, gf,
6021
+ cur = build_attn(inp_attn,
5812
6022
  model.layers[il].wo, model.layers[il].bo,
5813
6023
  Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
5814
6024
  }
@@ -5876,7 +6086,7 @@ struct llm_build_deci : public llm_graph_context {
5876
6086
  };
5877
6087
 
5878
6088
  struct llm_build_baichuan : public llm_graph_context {
5879
- llm_build_baichuan(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
6089
+ llm_build_baichuan(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
5880
6090
  const int64_t n_embd_head = hparams.n_embd_head_v;
5881
6091
 
5882
6092
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -5940,7 +6150,7 @@ struct llm_build_baichuan : public llm_graph_context {
5940
6150
  cb(Kcur, "Kcur", il);
5941
6151
  cb(Vcur, "Vcur", il);
5942
6152
 
5943
- cur = build_attn(inp_attn, gf,
6153
+ cur = build_attn(inp_attn,
5944
6154
  model.layers[il].wo, NULL,
5945
6155
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
5946
6156
  }
@@ -5998,7 +6208,7 @@ struct llm_build_baichuan : public llm_graph_context {
5998
6208
  };
5999
6209
 
6000
6210
  struct llm_build_xverse : public llm_graph_context {
6001
- llm_build_xverse(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
6211
+ llm_build_xverse(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
6002
6212
  const int64_t n_embd_head = hparams.n_embd_head_v;
6003
6213
 
6004
6214
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -6055,7 +6265,7 @@ struct llm_build_xverse : public llm_graph_context {
6055
6265
  cb(Kcur, "Kcur", il);
6056
6266
  cb(Vcur, "Vcur", il);
6057
6267
 
6058
- cur = build_attn(inp_attn, gf,
6268
+ cur = build_attn(inp_attn,
6059
6269
  model.layers[il].wo, NULL,
6060
6270
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6061
6271
  }
@@ -6111,7 +6321,7 @@ struct llm_build_xverse : public llm_graph_context {
6111
6321
  };
6112
6322
 
6113
6323
  struct llm_build_falcon : public llm_graph_context {
6114
- llm_build_falcon(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
6324
+ llm_build_falcon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
6115
6325
  const int64_t n_embd_head = hparams.n_embd_head_v;
6116
6326
  const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
6117
6327
 
@@ -6178,7 +6388,7 @@ struct llm_build_falcon : public llm_graph_context {
6178
6388
  cb(Kcur, "Kcur", il);
6179
6389
  cb(Vcur, "Vcur", il);
6180
6390
 
6181
- cur = build_attn(inp_attn, gf,
6391
+ cur = build_attn(inp_attn,
6182
6392
  model.layers[il].wo, NULL,
6183
6393
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6184
6394
  }
@@ -6233,7 +6443,7 @@ struct llm_build_falcon : public llm_graph_context {
6233
6443
  };
6234
6444
 
6235
6445
  struct llm_build_grok : public llm_graph_context {
6236
- llm_build_grok(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
6446
+ llm_build_grok(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
6237
6447
  const int64_t n_embd_head = hparams.n_embd_head_v;
6238
6448
 
6239
6449
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -6308,7 +6518,7 @@ struct llm_build_grok : public llm_graph_context {
6308
6518
  cb(Kcur, "Kcur", il);
6309
6519
  cb(Vcur, "Vcur", il);
6310
6520
 
6311
- cur = build_attn(inp_attn, gf,
6521
+ cur = build_attn(inp_attn,
6312
6522
  model.layers[il].wo, model.layers[il].bo,
6313
6523
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
6314
6524
  }
@@ -6395,7 +6605,7 @@ struct llm_build_grok : public llm_graph_context {
6395
6605
  };
6396
6606
 
6397
6607
  struct llm_build_dbrx : public llm_graph_context {
6398
- llm_build_dbrx(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
6608
+ llm_build_dbrx(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
6399
6609
  const int64_t n_embd_head = hparams.n_embd_head_v;
6400
6610
  const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
6401
6611
 
@@ -6457,7 +6667,7 @@ struct llm_build_dbrx : public llm_graph_context {
6457
6667
  cb(Kcur, "Kcur", il);
6458
6668
  cb(Vcur, "Vcur", il);
6459
6669
 
6460
- cur = build_attn(inp_attn, gf,
6670
+ cur = build_attn(inp_attn,
6461
6671
  model.layers[il].wo, NULL,
6462
6672
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6463
6673
  }
@@ -6520,7 +6730,7 @@ struct llm_build_dbrx : public llm_graph_context {
6520
6730
  };
6521
6731
 
6522
6732
  struct llm_build_starcoder : public llm_graph_context {
6523
- llm_build_starcoder(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
6733
+ llm_build_starcoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
6524
6734
  const int64_t n_embd_head = hparams.n_embd_head_v;
6525
6735
  const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
6526
6736
 
@@ -6571,7 +6781,7 @@ struct llm_build_starcoder : public llm_graph_context {
6571
6781
  cb(Kcur, "Kcur", il);
6572
6782
  cb(Vcur, "Vcur", il);
6573
6783
 
6574
- cur = build_attn(inp_attn, gf,
6784
+ cur = build_attn(inp_attn,
6575
6785
  model.layers[il].wo, model.layers[il].bo,
6576
6786
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6577
6787
  }
@@ -6629,7 +6839,7 @@ struct llm_build_starcoder : public llm_graph_context {
6629
6839
  };
6630
6840
 
6631
6841
  struct llm_build_refact : public llm_graph_context {
6632
- llm_build_refact(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
6842
+ llm_build_refact(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
6633
6843
  const int64_t n_embd_head = hparams.n_embd_head_v;
6634
6844
 
6635
6845
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -6670,7 +6880,7 @@ struct llm_build_refact : public llm_graph_context {
6670
6880
  cb(Kcur, "Kcur", il);
6671
6881
  cb(Vcur, "Vcur", il);
6672
6882
 
6673
- cur = build_attn(inp_attn, gf,
6883
+ cur = build_attn(inp_attn,
6674
6884
  model.layers[il].wo, NULL,
6675
6885
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6676
6886
  }
@@ -6728,7 +6938,7 @@ struct llm_build_refact : public llm_graph_context {
6728
6938
  };
6729
6939
 
6730
6940
  struct llm_build_bert : public llm_graph_context {
6731
- llm_build_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
6941
+ llm_build_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
6732
6942
  const int64_t n_embd_head = hparams.n_embd_head_v;
6733
6943
  const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
6734
6944
 
@@ -6827,7 +7037,7 @@ struct llm_build_bert : public llm_graph_context {
6827
7037
  cb(Kcur, "Kcur", il);
6828
7038
  cb(Vcur, "Vcur", il);
6829
7039
 
6830
- cur = build_attn(inp_attn, gf,
7040
+ cur = build_attn(inp_attn,
6831
7041
  model.layers[il].wo, model.layers[il].bo,
6832
7042
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6833
7043
  cb(cur, "kqv_out", il);
@@ -6914,7 +7124,7 @@ struct llm_build_bert : public llm_graph_context {
6914
7124
  };
6915
7125
 
6916
7126
  struct llm_build_neo_bert : public llm_graph_context {
6917
- llm_build_neo_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
7127
+ llm_build_neo_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
6918
7128
  const int64_t n_embd_head = hparams.n_embd_head_v;
6919
7129
  const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
6920
7130
 
@@ -6972,7 +7182,7 @@ struct llm_build_neo_bert : public llm_graph_context {
6972
7182
  cb(Kcur, "Kcur", il);
6973
7183
  cb(Vcur, "Vcur", il);
6974
7184
 
6975
- cur = build_attn(inp_attn, gf,
7185
+ cur = build_attn(inp_attn,
6976
7186
  model.layers[il].wo, nullptr,
6977
7187
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6978
7188
  cb(cur, "kqv_out", il);
@@ -7024,7 +7234,7 @@ struct llm_build_neo_bert : public llm_graph_context {
7024
7234
  };
7025
7235
 
7026
7236
  struct llm_build_bloom : public llm_graph_context {
7027
- llm_build_bloom(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
7237
+ llm_build_bloom(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
7028
7238
  const int64_t n_embd_head = hparams.n_embd_head_v;
7029
7239
  const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
7030
7240
 
@@ -7072,7 +7282,7 @@ struct llm_build_bloom : public llm_graph_context {
7072
7282
  cb(Kcur, "Kcur", il);
7073
7283
  cb(Vcur, "Vcur", il);
7074
7284
 
7075
- cur = build_attn(inp_attn, gf,
7285
+ cur = build_attn(inp_attn,
7076
7286
  model.layers[il].wo, model.layers[il].bo,
7077
7287
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7078
7288
  }
@@ -7130,7 +7340,7 @@ struct llm_build_bloom : public llm_graph_context {
7130
7340
  };
7131
7341
 
7132
7342
  struct llm_build_mpt : public llm_graph_context {
7133
- llm_build_mpt(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
7343
+ llm_build_mpt(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
7134
7344
  const int64_t n_embd_head = hparams.n_embd_head_v;
7135
7345
  const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
7136
7346
 
@@ -7219,7 +7429,7 @@ struct llm_build_mpt : public llm_graph_context {
7219
7429
  cb(Kcur, "Kcur", il);
7220
7430
  cb(Vcur, "Vcur", il);
7221
7431
 
7222
- cur = build_attn(inp_attn, gf,
7432
+ cur = build_attn(inp_attn,
7223
7433
  model.layers[il].wo, model.layers[il].bo,
7224
7434
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7225
7435
  }
@@ -7278,7 +7488,7 @@ struct llm_build_mpt : public llm_graph_context {
7278
7488
  };
7279
7489
 
7280
7490
  struct llm_build_stablelm : public llm_graph_context {
7281
- llm_build_stablelm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
7491
+ llm_build_stablelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
7282
7492
  const int64_t n_embd_head = hparams.n_embd_head_v;
7283
7493
 
7284
7494
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7365,7 +7575,7 @@ struct llm_build_stablelm : public llm_graph_context {
7365
7575
  cb(Kcur, "Kcur", il);
7366
7576
  cb(Vcur, "Vcur", il);
7367
7577
 
7368
- cur = build_attn(inp_attn, gf,
7578
+ cur = build_attn(inp_attn,
7369
7579
  model.layers[il].wo, NULL,
7370
7580
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7371
7581
  }
@@ -7430,7 +7640,7 @@ struct llm_build_stablelm : public llm_graph_context {
7430
7640
  };
7431
7641
 
7432
7642
  struct llm_build_qwen : public llm_graph_context {
7433
- llm_build_qwen(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
7643
+ llm_build_qwen(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
7434
7644
  const int64_t n_embd_head = hparams.n_embd_head_v;
7435
7645
 
7436
7646
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7486,7 +7696,7 @@ struct llm_build_qwen : public llm_graph_context {
7486
7696
  cb(Kcur, "Kcur", il);
7487
7697
  cb(Vcur, "Vcur", il);
7488
7698
 
7489
- cur = build_attn(inp_attn, gf,
7699
+ cur = build_attn(inp_attn,
7490
7700
  model.layers[il].wo, NULL,
7491
7701
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7492
7702
  }
@@ -7544,7 +7754,7 @@ struct llm_build_qwen : public llm_graph_context {
7544
7754
  };
7545
7755
 
7546
7756
  struct llm_build_qwen2 : public llm_graph_context {
7547
- llm_build_qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
7757
+ llm_build_qwen2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
7548
7758
  const int64_t n_embd_head = hparams.n_embd_head_v;
7549
7759
 
7550
7760
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7606,7 +7816,7 @@ struct llm_build_qwen2 : public llm_graph_context {
7606
7816
  cb(Kcur, "Kcur", il);
7607
7817
  cb(Vcur, "Vcur", il);
7608
7818
 
7609
- cur = build_attn(inp_attn, gf,
7819
+ cur = build_attn(inp_attn,
7610
7820
  model.layers[il].wo, model.layers[il].bo,
7611
7821
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7612
7822
  }
@@ -7654,6 +7864,113 @@ struct llm_build_qwen2 : public llm_graph_context {
7654
7864
  // lm_head
7655
7865
  cur = build_lora_mm(model.output, cur);
7656
7866
 
7867
+ if (model.output_b != nullptr) {
7868
+ cur = ggml_add(ctx0, cur, model.output_b);
7869
+ }
7870
+
7871
+ cb(cur, "result_output", -1);
7872
+ res->t_logits = cur;
7873
+
7874
+ ggml_build_forward_expand(gf, cur);
7875
+ }
7876
+ };
7877
+
7878
+ struct llm_build_dream : public llm_graph_context {
7879
+ llm_build_dream(const llama_model & model, const llm_graph_params & params) :
7880
+ llm_graph_context(params) {
7881
+ //copied from qwen2
7882
+ const int64_t n_embd_head = hparams.n_embd_head_v;
7883
+
7884
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
7885
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
7886
+
7887
+ ggml_tensor * cur;
7888
+ ggml_tensor * inpL;
7889
+
7890
+ inpL = build_inp_embd(model.tok_embd);
7891
+
7892
+ // inp_pos - contains the positions
7893
+ ggml_tensor * inp_pos = build_inp_pos();
7894
+
7895
+ auto * inp_attn = build_attn_inp_no_cache();
7896
+
7897
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
7898
+
7899
+ for (int il = 0; il < n_layer; ++il) {
7900
+ ggml_tensor * inpSA = inpL;
7901
+
7902
+ // norm
7903
+ cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
7904
+ cb(cur, "attn_norm", il);
7905
+
7906
+ // self-attention
7907
+ {
7908
+ // compute Q and K and RoPE them
7909
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
7910
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
7911
+ cb(Qcur, "Qcur", il);
7912
+
7913
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
7914
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
7915
+ cb(Kcur, "Kcur", il);
7916
+
7917
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
7918
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
7919
+ cb(Vcur, "Vcur", il);
7920
+
7921
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
7922
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
7923
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
7924
+
7925
+ Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
7926
+ ext_factor, attn_factor, beta_fast, beta_slow);
7927
+
7928
+ Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
7929
+ ext_factor, attn_factor, beta_fast, beta_slow);
7930
+
7931
+ cb(Qcur, "Qcur", il);
7932
+ cb(Kcur, "Kcur", il);
7933
+ cb(Vcur, "Vcur", il);
7934
+
7935
+ cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr,
7936
+ nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
7937
+ }
7938
+
7939
+ if (il == n_layer - 1 && inp_out_ids) {
7940
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7941
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
7942
+ }
7943
+
7944
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
7945
+ cb(ffn_inp, "ffn_inp", il);
7946
+
7947
+ // feed-forward network
7948
+ cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
7949
+ cb(cur, "ffn_norm", il);
7950
+
7951
+ cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL,
7952
+ model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
7953
+ cb(cur, "ffn_out", il);
7954
+
7955
+ cur = ggml_add(ctx0, cur, ffn_inp);
7956
+
7957
+ cur = build_cvec(cur, il);
7958
+ cb(cur, "l_out", il);
7959
+
7960
+ // input for next layer
7961
+ inpL = cur;
7962
+ }
7963
+
7964
+ cur = inpL;
7965
+
7966
+ cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
7967
+
7968
+ cb(cur, "result_norm", -1);
7969
+ res->t_embd = cur;
7970
+
7971
+ // lm_head
7972
+ cur = build_lora_mm(model.output, cur);
7973
+
7657
7974
  cb(cur, "result_output", -1);
7658
7975
  res->t_logits = cur;
7659
7976
 
@@ -7662,7 +7979,7 @@ struct llm_build_qwen2 : public llm_graph_context {
7662
7979
  };
7663
7980
 
7664
7981
  struct llm_build_qwen2vl : public llm_graph_context {
7665
- llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
7982
+ llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
7666
7983
  const int64_t n_embd_head = hparams.n_embd_head_v;
7667
7984
 
7668
7985
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7727,7 +8044,7 @@ struct llm_build_qwen2vl : public llm_graph_context {
7727
8044
  cb(Kcur, "Kcur", il);
7728
8045
  cb(Vcur, "Vcur", il);
7729
8046
 
7730
- cur = build_attn(inp_attn, gf,
8047
+ cur = build_attn(inp_attn,
7731
8048
  model.layers[il].wo, model.layers[il].bo,
7732
8049
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7733
8050
  }
@@ -7783,7 +8100,7 @@ struct llm_build_qwen2vl : public llm_graph_context {
7783
8100
  };
7784
8101
 
7785
8102
  struct llm_build_qwen2moe : public llm_graph_context {
7786
- llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
8103
+ llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
7787
8104
  const int64_t n_embd_head = hparams.n_embd_head_v;
7788
8105
 
7789
8106
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7854,7 +8171,7 @@ struct llm_build_qwen2moe : public llm_graph_context {
7854
8171
  cb(Kcur, "Kcur", il);
7855
8172
  cb(Vcur, "Vcur", il);
7856
8173
 
7857
- cur = build_attn(inp_attn, gf,
8174
+ cur = build_attn(inp_attn,
7858
8175
  model.layers[il].wo, model.layers[il].bo,
7859
8176
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
7860
8177
  }
@@ -7942,7 +8259,7 @@ struct llm_build_qwen2moe : public llm_graph_context {
7942
8259
  };
7943
8260
 
7944
8261
  struct llm_build_qwen3 : public llm_graph_context {
7945
- llm_build_qwen3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
8262
+ llm_build_qwen3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
7946
8263
  const int64_t n_embd_head = hparams.n_embd_head_v;
7947
8264
 
7948
8265
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -8007,7 +8324,7 @@ struct llm_build_qwen3 : public llm_graph_context {
8007
8324
  cb(Kcur, "Kcur", il);
8008
8325
  cb(Vcur, "Vcur", il);
8009
8326
 
8010
- cur = build_attn(inp_attn, gf,
8327
+ cur = build_attn(inp_attn,
8011
8328
  model.layers[il].wo, model.layers[il].bo,
8012
8329
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
8013
8330
  }
@@ -8063,7 +8380,7 @@ struct llm_build_qwen3 : public llm_graph_context {
8063
8380
  };
8064
8381
 
8065
8382
  struct llm_build_qwen3moe : public llm_graph_context {
8066
- llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
8383
+ llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
8067
8384
  const int64_t n_embd_head = hparams.n_embd_head_v;
8068
8385
 
8069
8386
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -8128,7 +8445,7 @@ struct llm_build_qwen3moe : public llm_graph_context {
8128
8445
  cb(Kcur, "Kcur", il);
8129
8446
  cb(Vcur, "Vcur", il);
8130
8447
 
8131
- cur = build_attn(inp_attn, gf,
8448
+ cur = build_attn(inp_attn,
8132
8449
  model.layers[il].wo, model.layers[il].bo,
8133
8450
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
8134
8451
  }
@@ -8191,7 +8508,7 @@ struct llm_build_qwen3moe : public llm_graph_context {
8191
8508
  };
8192
8509
 
8193
8510
  struct llm_build_phi2 : public llm_graph_context {
8194
- llm_build_phi2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
8511
+ llm_build_phi2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
8195
8512
  const int64_t n_embd_head = hparams.n_embd_head_v;
8196
8513
  const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
8197
8514
 
@@ -8268,7 +8585,7 @@ struct llm_build_phi2 : public llm_graph_context {
8268
8585
  // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66
8269
8586
  Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head)));
8270
8587
 
8271
- cur = build_attn(inp_attn, gf,
8588
+ cur = build_attn(inp_attn,
8272
8589
  model.layers[il].wo, model.layers[il].bo,
8273
8590
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
8274
8591
  }
@@ -8322,7 +8639,7 @@ struct llm_build_phi2 : public llm_graph_context {
8322
8639
 
8323
8640
  template<bool iswa>
8324
8641
  struct llm_build_phi3 : public llm_graph_context {
8325
- llm_build_phi3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
8642
+ llm_build_phi3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
8326
8643
  const int64_t n_embd_head = hparams.n_embd_head_v;
8327
8644
  const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
8328
8645
 
@@ -8405,7 +8722,7 @@ struct llm_build_phi3 : public llm_graph_context {
8405
8722
  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head)));
8406
8723
  cb(Qcur, "Qcur", il);
8407
8724
 
8408
- cur = build_attn(inp_attn, gf,
8725
+ cur = build_attn(inp_attn,
8409
8726
  model.layers[il].wo, model.layers[il].bo,
8410
8727
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
8411
8728
  }
@@ -8480,7 +8797,7 @@ struct llm_build_phi3 : public llm_graph_context {
8480
8797
  };
8481
8798
 
8482
8799
  struct llm_build_plamo : public llm_graph_context {
8483
- llm_build_plamo(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
8800
+ llm_build_plamo(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
8484
8801
  const int64_t n_embd_head = hparams.n_embd_head_v;
8485
8802
 
8486
8803
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -8539,7 +8856,7 @@ struct llm_build_plamo : public llm_graph_context {
8539
8856
  cb(Kcur, "Kcur", il);
8540
8857
  cb(Vcur, "Vcur", il);
8541
8858
 
8542
- cur = build_attn(inp_attn, gf,
8859
+ cur = build_attn(inp_attn,
8543
8860
  model.layers[il].wo, NULL,
8544
8861
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
8545
8862
  }
@@ -8595,7 +8912,7 @@ struct llm_build_plamo : public llm_graph_context {
8595
8912
  };
8596
8913
 
8597
8914
  struct llm_build_gpt2 : public llm_graph_context {
8598
- llm_build_gpt2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
8915
+ llm_build_gpt2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
8599
8916
  const int64_t n_embd_head = hparams.n_embd_head_v;
8600
8917
  const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
8601
8918
 
@@ -8647,7 +8964,7 @@ struct llm_build_gpt2 : public llm_graph_context {
8647
8964
  Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
8648
8965
  Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
8649
8966
 
8650
- cur = build_attn(inp_attn, gf,
8967
+ cur = build_attn(inp_attn,
8651
8968
  model.layers[il].wo, model.layers[il].bo,
8652
8969
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
8653
8970
  }
@@ -8705,7 +9022,7 @@ struct llm_build_gpt2 : public llm_graph_context {
8705
9022
  };
8706
9023
 
8707
9024
  struct llm_build_codeshell : public llm_graph_context {
8708
- llm_build_codeshell(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
9025
+ llm_build_codeshell(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
8709
9026
  const int64_t n_embd_head = hparams.n_embd_head_v;
8710
9027
  const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
8711
9028
 
@@ -8761,7 +9078,7 @@ struct llm_build_codeshell : public llm_graph_context {
8761
9078
  cb(Kcur, "Kcur", il);
8762
9079
  cb(Vcur, "Vcur", il);
8763
9080
 
8764
- cur = build_attn(inp_attn, gf,
9081
+ cur = build_attn(inp_attn,
8765
9082
  model.layers[il].wo, model.layers[il].bo,
8766
9083
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
8767
9084
  }
@@ -8819,7 +9136,7 @@ struct llm_build_codeshell : public llm_graph_context {
8819
9136
  };
8820
9137
 
8821
9138
  struct llm_build_orion : public llm_graph_context {
8822
- llm_build_orion(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
9139
+ llm_build_orion(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
8823
9140
  const int64_t n_embd_head = hparams.n_embd_head_v;
8824
9141
 
8825
9142
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -8890,7 +9207,7 @@ struct llm_build_orion : public llm_graph_context {
8890
9207
  cb(Kcur, "Kcur", il);
8891
9208
  cb(Vcur, "Vcur", il);
8892
9209
 
8893
- cur = build_attn(inp_attn, gf,
9210
+ cur = build_attn(inp_attn,
8894
9211
  model.layers[il].wo, NULL,
8895
9212
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
8896
9213
  }
@@ -8946,7 +9263,7 @@ struct llm_build_orion : public llm_graph_context {
8946
9263
  };
8947
9264
 
8948
9265
  struct llm_build_internlm2 : public llm_graph_context {
8949
- llm_build_internlm2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
9266
+ llm_build_internlm2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
8950
9267
  const int64_t n_embd_head = hparams.n_embd_head_v;
8951
9268
 
8952
9269
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -9017,7 +9334,7 @@ struct llm_build_internlm2 : public llm_graph_context {
9017
9334
  cb(Kcur, "Kcur", il);
9018
9335
  cb(Vcur, "Vcur", il);
9019
9336
 
9020
- cur = build_attn(inp_attn, gf,
9337
+ cur = build_attn(inp_attn,
9021
9338
  model.layers[il].wo, model.layers[il].bo,
9022
9339
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9023
9340
  }
@@ -9073,7 +9390,7 @@ struct llm_build_internlm2 : public llm_graph_context {
9073
9390
  };
9074
9391
 
9075
9392
  struct llm_build_minicpm3 : public llm_graph_context {
9076
- llm_build_minicpm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
9393
+ llm_build_minicpm3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
9077
9394
  //TODO: if the model varies, these parameters need to be read from the model
9078
9395
  const int64_t n_embd_base = 256;
9079
9396
  const float scale_embd = 12.0f;
@@ -9205,7 +9522,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
9205
9522
  ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
9206
9523
  cb(k_states, "k_states", il);
9207
9524
 
9208
- cur = build_attn(inp_attn, gf,
9525
+ cur = build_attn(inp_attn,
9209
9526
  model.layers[il].wo, NULL,
9210
9527
  q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
9211
9528
  }
@@ -9277,7 +9594,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
9277
9594
  };
9278
9595
 
9279
9596
  struct llm_build_gemma : public llm_graph_context {
9280
- llm_build_gemma(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
9597
+ llm_build_gemma(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
9281
9598
  const int64_t n_embd_head = hparams.n_embd_head_v;
9282
9599
 
9283
9600
  ggml_tensor * cur;
@@ -9335,7 +9652,7 @@ struct llm_build_gemma : public llm_graph_context {
9335
9652
  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head)));
9336
9653
  cb(Qcur, "Qcur_scaled", il);
9337
9654
 
9338
- cur = build_attn(inp_attn, gf,
9655
+ cur = build_attn(inp_attn,
9339
9656
  model.layers[il].wo, NULL,
9340
9657
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
9341
9658
  }
@@ -9393,7 +9710,7 @@ struct llm_build_gemma : public llm_graph_context {
9393
9710
  };
9394
9711
 
9395
9712
  struct llm_build_gemma2_iswa : public llm_graph_context {
9396
- llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
9713
+ llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
9397
9714
  const int64_t n_embd_head = hparams.n_embd_head_k;
9398
9715
 
9399
9716
  ggml_tensor * cur;
@@ -9450,7 +9767,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
9450
9767
 
9451
9768
  Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
9452
9769
 
9453
- cur = build_attn(inp_attn, gf,
9770
+ cur = build_attn(inp_attn,
9454
9771
  model.layers[il].wo, NULL,
9455
9772
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
9456
9773
  }
@@ -9523,7 +9840,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
9523
9840
  };
9524
9841
 
9525
9842
  struct llm_build_gemma3_iswa : public llm_graph_context {
9526
- llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
9843
+ llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
9527
9844
  const int64_t n_embd_head = hparams.n_embd_head_k;
9528
9845
 
9529
9846
  ggml_tensor * cur;
@@ -9592,7 +9909,7 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
9592
9909
  // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
9593
9910
  Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
9594
9911
 
9595
- cur = build_attn(inp_attn, gf,
9912
+ cur = build_attn(inp_attn,
9596
9913
  model.layers[il].wo, NULL,
9597
9914
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
9598
9915
  }
@@ -9661,7 +9978,6 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
9661
9978
 
9662
9979
  struct llm_build_gemma3n_iswa : public llm_graph_context {
9663
9980
  const llama_model & model;
9664
- ggml_cgraph * gf;
9665
9981
 
9666
9982
  const int64_t n_embd_head;
9667
9983
  const int64_t n_embd_altup;
@@ -9671,10 +9987,9 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
9671
9987
  const int n_layer_sparsity = 10; // number of layers using activation sparsity
9672
9988
  const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)
9673
9989
 
9674
- llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf)
9990
+ llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params)
9675
9991
  : llm_graph_context(params),
9676
9992
  model(model),
9677
- gf(gf),
9678
9993
  n_embd_head(model.hparams.n_embd_head_k),
9679
9994
  n_embd_altup(model.hparams.n_embd_altup),
9680
9995
  n_altup(model.hparams.n_altup),
@@ -9775,7 +10090,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
9775
10090
  cb(Qcur, "Qcur_pos", il);
9776
10091
  cb(Kcur, "Kcur_pos", il);
9777
10092
 
9778
- cur = build_attn(inp_attn, gf,
10093
+ cur = build_attn(inp_attn,
9779
10094
  model.layers[il].wo, NULL,
9780
10095
  Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
9781
10096
  } else {
@@ -9793,7 +10108,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
9793
10108
  ext_factor, attn_factor, beta_fast, beta_slow);
9794
10109
  cb(Qcur, "Qcur_pos", il);
9795
10110
 
9796
- cur = build_attn(inp_attn, gf,
10111
+ cur = build_attn(inp_attn,
9797
10112
  model.layers[il].wo, NULL,
9798
10113
  Qcur, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
9799
10114
  }
@@ -10087,7 +10402,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
10087
10402
 
10088
10403
  // TODO: move up next to build_starcoder
10089
10404
  struct llm_build_starcoder2 : public llm_graph_context {
10090
- llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
10405
+ llm_build_starcoder2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
10091
10406
  const int64_t n_embd_head = hparams.n_embd_head_v;
10092
10407
 
10093
10408
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10158,7 +10473,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
10158
10473
  cb(Kcur, "Kcur", il);
10159
10474
  cb(Vcur, "Vcur", il);
10160
10475
 
10161
- cur = build_attn(inp_attn, gf,
10476
+ cur = build_attn(inp_attn,
10162
10477
  model.layers[il].wo, model.layers[il].bo,
10163
10478
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
10164
10479
  }
@@ -10219,7 +10534,6 @@ struct llm_graph_context_mamba : public llm_graph_context {
10219
10534
 
10220
10535
  ggml_tensor * build_mamba_layer(
10221
10536
  llm_graph_input_rs * inp,
10222
- ggml_cgraph * gf,
10223
10537
  ggml_tensor * cur,
10224
10538
  const llama_model & model,
10225
10539
  const llama_ubatch & ubatch,
@@ -10244,13 +10558,13 @@ struct llm_graph_context_mamba : public llm_graph_context {
10244
10558
  const int64_t n_seq_tokens = ubatch.n_seq_tokens;
10245
10559
 
10246
10560
  GGML_ASSERT(n_seqs != 0);
10247
- GGML_ASSERT(ubatch.equal_seqs);
10561
+ GGML_ASSERT(ubatch.equal_seqs());
10248
10562
  GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
10249
10563
 
10250
10564
  ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
10251
10565
  ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
10252
10566
 
10253
- ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
10567
+ ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
10254
10568
  conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
10255
10569
 
10256
10570
  // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
@@ -10331,7 +10645,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
10331
10645
  return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
10332
10646
  };
10333
10647
 
10334
- ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
10648
+ ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
10335
10649
 
10336
10650
  // store last states
10337
10651
  ggml_build_forward_expand(gf,
@@ -10358,11 +10672,10 @@ struct llm_graph_context_mamba : public llm_graph_context {
10358
10672
 
10359
10673
  ggml_tensor * build_mamba2_layer(
10360
10674
  llm_graph_input_rs * inp,
10361
- ggml_cgraph * gf,
10362
- ggml_tensor * cur,
10363
- const llama_model & model,
10364
- const llama_ubatch & ubatch,
10365
- int il) const {
10675
+ ggml_tensor * cur,
10676
+ const llama_model & model,
10677
+ const llama_ubatch & ubatch,
10678
+ int il) const {
10366
10679
 
10367
10680
  const auto * mctx_cur = inp->mctx;
10368
10681
 
@@ -10379,13 +10692,13 @@ struct llm_graph_context_mamba : public llm_graph_context {
10379
10692
  const int64_t n_seq_tokens = ubatch.n_seq_tokens;
10380
10693
 
10381
10694
  GGML_ASSERT(n_seqs != 0);
10382
- GGML_ASSERT(ubatch.equal_seqs);
10695
+ GGML_ASSERT(ubatch.equal_seqs());
10383
10696
  GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
10384
10697
 
10385
10698
  ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
10386
10699
  ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
10387
10700
 
10388
- ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
10701
+ ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
10389
10702
  conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
10390
10703
 
10391
10704
  // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
@@ -10455,7 +10768,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
10455
10768
  return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
10456
10769
  };
10457
10770
 
10458
- ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
10771
+ ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
10459
10772
 
10460
10773
  // store last states
10461
10774
  ggml_build_forward_expand(gf,
@@ -10491,7 +10804,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
10491
10804
  };
10492
10805
 
10493
10806
  struct llm_build_mamba : public llm_graph_context_mamba {
10494
- llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) {
10807
+ llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
10495
10808
  ggml_tensor * cur;
10496
10809
  ggml_tensor * inpL;
10497
10810
 
@@ -10510,9 +10823,9 @@ struct llm_build_mamba : public llm_graph_context_mamba {
10510
10823
  cb(cur, "attn_norm", il);
10511
10824
 
10512
10825
  if (model.arch == LLM_ARCH_MAMBA2) {
10513
- cur = build_mamba2_layer(rs_inp, gf, cur, model, ubatch, il);
10826
+ cur = build_mamba2_layer(rs_inp, cur, model, ubatch, il);
10514
10827
  } else {
10515
- cur = build_mamba_layer(rs_inp, gf, cur, model, ubatch, il);
10828
+ cur = build_mamba_layer(rs_inp, cur, model, ubatch, il);
10516
10829
  }
10517
10830
 
10518
10831
  if (il == n_layer - 1 && inp_out_ids) {
@@ -10548,7 +10861,7 @@ struct llm_build_mamba : public llm_graph_context_mamba {
10548
10861
  };
10549
10862
 
10550
10863
  struct llm_build_jamba : public llm_graph_context_mamba {
10551
- llm_build_jamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) {
10864
+ llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
10552
10865
  const int64_t n_embd_head = hparams.n_embd_head_v;
10553
10866
 
10554
10867
  ggml_tensor * cur;
@@ -10568,7 +10881,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
10568
10881
  cb(cur, "attn_norm", il);
10569
10882
 
10570
10883
  if (n_head_kv == 0) {
10571
- cur = build_mamba_layer(inp_hybrid->get_recr(), gf, cur, model, ubatch, il);
10884
+ cur = build_mamba_layer(inp_hybrid->get_recr(), cur, model, ubatch, il);
10572
10885
  } else {
10573
10886
  // Attention
10574
10887
 
@@ -10589,7 +10902,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
10589
10902
  cb(Vcur, "Vcur", il);
10590
10903
 
10591
10904
  // No RoPE :)
10592
- cur = build_attn(inp_hybrid->get_attn(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);
10905
+ cur = build_attn(inp_hybrid->get_attn(), model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);
10593
10906
  }
10594
10907
 
10595
10908
  if (il == n_layer - 1 && inp_out_ids) {
@@ -10657,7 +10970,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
10657
10970
  };
10658
10971
 
10659
10972
  struct llm_build_command_r : public llm_graph_context {
10660
- llm_build_command_r(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
10973
+ llm_build_command_r(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
10661
10974
  const int64_t n_embd_head = hparams.n_embd_head_v;
10662
10975
 
10663
10976
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10745,7 +11058,7 @@ struct llm_build_command_r : public llm_graph_context {
10745
11058
  cb(Kcur, "Kcur", il);
10746
11059
  cb(Vcur, "Vcur", il);
10747
11060
 
10748
- cur = build_attn(inp_attn, gf,
11061
+ cur = build_attn(inp_attn,
10749
11062
  model.layers[il].wo, model.layers[il].bo,
10750
11063
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
10751
11064
  }
@@ -10804,7 +11117,7 @@ struct llm_build_command_r : public llm_graph_context {
10804
11117
  };
10805
11118
 
10806
11119
  struct llm_build_cohere2_iswa : public llm_graph_context {
10807
- llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
11120
+ llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
10808
11121
  const int64_t n_embd_head = hparams.n_embd_head_v;
10809
11122
 
10810
11123
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10880,7 +11193,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context {
10880
11193
  cb(Kcur, "Kcur", il);
10881
11194
  cb(Vcur, "Vcur", il);
10882
11195
 
10883
- cur = build_attn(inp_attn, gf,
11196
+ cur = build_attn(inp_attn,
10884
11197
  model.layers[il].wo, model.layers[il].bo,
10885
11198
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
10886
11199
  }
@@ -10940,7 +11253,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context {
10940
11253
  // * removed bias
10941
11254
  // * removed MoE
10942
11255
  struct llm_build_olmo : public llm_graph_context {
10943
- llm_build_olmo(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
11256
+ llm_build_olmo(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
10944
11257
  const int64_t n_embd_head = hparams.n_embd_head_v;
10945
11258
 
10946
11259
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11011,7 +11324,7 @@ struct llm_build_olmo : public llm_graph_context {
11011
11324
  cb(Kcur, "Kcur", il);
11012
11325
  cb(Vcur, "Vcur", il);
11013
11326
 
11014
- cur = build_attn(inp_attn, gf,
11327
+ cur = build_attn(inp_attn,
11015
11328
  model.layers[il].wo, nullptr,
11016
11329
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11017
11330
  }
@@ -11068,7 +11381,7 @@ struct llm_build_olmo : public llm_graph_context {
11068
11381
  };
11069
11382
 
11070
11383
  struct llm_build_olmo2 : public llm_graph_context {
11071
- llm_build_olmo2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
11384
+ llm_build_olmo2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
11072
11385
  const int64_t n_embd_head = hparams.n_embd_head_v;
11073
11386
 
11074
11387
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11131,7 +11444,7 @@ struct llm_build_olmo2 : public llm_graph_context {
11131
11444
  cb(Kcur, "Kcur", il);
11132
11445
  cb(Vcur, "Vcur", il);
11133
11446
 
11134
- cur = build_attn(inp_attn, gf,
11447
+ cur = build_attn(inp_attn,
11135
11448
  model.layers[il].wo, NULL,
11136
11449
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11137
11450
  }
@@ -11197,7 +11510,7 @@ struct llm_build_olmo2 : public llm_graph_context {
11197
11510
  // * removed bias
11198
11511
  // * added q, k norm
11199
11512
  struct llm_build_olmoe : public llm_graph_context {
11200
- llm_build_olmoe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
11513
+ llm_build_olmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
11201
11514
  const int64_t n_embd_head = hparams.n_embd_head_v;
11202
11515
 
11203
11516
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11264,7 +11577,7 @@ struct llm_build_olmoe : public llm_graph_context {
11264
11577
  cb(Kcur, "Kcur", il);
11265
11578
  cb(Vcur, "Vcur", il);
11266
11579
 
11267
- cur = build_attn(inp_attn, gf,
11580
+ cur = build_attn(inp_attn,
11268
11581
  model.layers[il].wo, NULL,
11269
11582
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11270
11583
  }
@@ -11325,7 +11638,7 @@ struct llm_build_olmoe : public llm_graph_context {
11325
11638
  };
11326
11639
 
11327
11640
  struct llm_build_openelm : public llm_graph_context {
11328
- llm_build_openelm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
11641
+ llm_build_openelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
11329
11642
  const int64_t n_embd_head = hparams.n_embd_head_v;
11330
11643
 
11331
11644
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11397,7 +11710,7 @@ struct llm_build_openelm : public llm_graph_context {
11397
11710
  cb(Kcur, "Kcur", il);
11398
11711
  cb(Qcur, "Vcur", il);
11399
11712
 
11400
- cur = build_attn(inp_attn, gf,
11713
+ cur = build_attn(inp_attn,
11401
11714
  model.layers[il].wo, NULL,
11402
11715
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11403
11716
  }
@@ -11454,7 +11767,7 @@ struct llm_build_openelm : public llm_graph_context {
11454
11767
  };
11455
11768
 
11456
11769
  struct llm_build_gptneox : public llm_graph_context {
11457
- llm_build_gptneox(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
11770
+ llm_build_gptneox(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
11458
11771
  const int64_t n_embd_head = hparams.n_embd_head_v;
11459
11772
  const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
11460
11773
 
@@ -11509,7 +11822,7 @@ struct llm_build_gptneox : public llm_graph_context {
11509
11822
  cb(Kcur, "Kcur", il);
11510
11823
  cb(Vcur, "Vcur", il);
11511
11824
 
11512
- cur = build_attn(inp_attn, gf,
11825
+ cur = build_attn(inp_attn,
11513
11826
  model.layers[il].wo, model.layers[il].bo,
11514
11827
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11515
11828
  }
@@ -11600,7 +11913,7 @@ struct llm_build_gptneox : public llm_graph_context {
11600
11913
  };
11601
11914
 
11602
11915
  struct llm_build_arctic : public llm_graph_context {
11603
- llm_build_arctic(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
11916
+ llm_build_arctic(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
11604
11917
  const int64_t n_embd_head = hparams.n_embd_head_v;
11605
11918
 
11606
11919
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11659,7 +11972,7 @@ struct llm_build_arctic : public llm_graph_context {
11659
11972
  cb(Kcur, "Kcur", il);
11660
11973
  cb(Vcur, "Vcur", il);
11661
11974
 
11662
- cur = build_attn(inp_attn, gf,
11975
+ cur = build_attn(inp_attn,
11663
11976
  model.layers[il].wo, NULL,
11664
11977
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
11665
11978
  }
@@ -11738,7 +12051,7 @@ struct llm_build_arctic : public llm_graph_context {
11738
12051
  };
11739
12052
 
11740
12053
  struct llm_build_deepseek : public llm_graph_context {
11741
- llm_build_deepseek(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
12054
+ llm_build_deepseek(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
11742
12055
  const int64_t n_embd_head = hparams.n_embd_head_v;
11743
12056
 
11744
12057
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11814,7 +12127,7 @@ struct llm_build_deepseek : public llm_graph_context {
11814
12127
  cb(Kcur, "Kcur", il);
11815
12128
  cb(Vcur, "Vcur", il);
11816
12129
 
11817
- cur = build_attn(inp_attn, gf,
12130
+ cur = build_attn(inp_attn,
11818
12131
  model.layers[il].wo, model.layers[il].bo,
11819
12132
  Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
11820
12133
  }
@@ -11900,7 +12213,7 @@ struct llm_build_deepseek : public llm_graph_context {
11900
12213
  };
11901
12214
 
11902
12215
  struct llm_build_deepseek2 : public llm_graph_context {
11903
- llm_build_deepseek2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
12216
+ llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
11904
12217
  bool is_lite = (hparams.n_layer == 27);
11905
12218
 
11906
12219
  const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
@@ -12042,7 +12355,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
12042
12355
  cb(Vcur, "Vcur", il);
12043
12356
 
12044
12357
  // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group)
12045
- cur = build_attn(inp_attn, gf,
12358
+ cur = build_attn(inp_attn,
12046
12359
  model.layers[il].wo, NULL,
12047
12360
  Qcur, Kcur, Vcur, nullptr, model.layers[il].wv_b, kq_scale, il);
12048
12361
  } else {
@@ -12076,7 +12389,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
12076
12389
  cb(Kcur, "Kcur", il);
12077
12390
 
12078
12391
  // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups)
12079
- cur = build_attn(inp_attn, gf,
12392
+ cur = build_attn(inp_attn,
12080
12393
  model.layers[il].wo, NULL,
12081
12394
  Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
12082
12395
  }
@@ -12163,7 +12476,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
12163
12476
  };
12164
12477
 
12165
12478
  struct llm_build_bitnet : public llm_graph_context {
12166
- llm_build_bitnet(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
12479
+ llm_build_bitnet(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
12167
12480
  const int64_t n_embd_head = hparams.n_embd_head_v;
12168
12481
 
12169
12482
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -12243,7 +12556,7 @@ struct llm_build_bitnet : public llm_graph_context {
12243
12556
  cb(Kcur, "Kcur", il);
12244
12557
  cb(Vcur, "Vcur", il);
12245
12558
 
12246
- cur = build_attn(inp_attn, gf,
12559
+ cur = build_attn(inp_attn,
12247
12560
  NULL, NULL,
12248
12561
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
12249
12562
 
@@ -12323,7 +12636,7 @@ struct llm_build_bitnet : public llm_graph_context {
12323
12636
  };
12324
12637
 
12325
12638
  struct llm_build_t5_enc : public llm_graph_context {
12326
- llm_build_t5_enc(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
12639
+ llm_build_t5_enc(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
12327
12640
  const int64_t n_embd_head = hparams.n_embd_head_v;
12328
12641
 
12329
12642
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -12366,7 +12679,7 @@ struct llm_build_t5_enc : public llm_graph_context {
12366
12679
  ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc;
12367
12680
  ggml_tensor * kq_b = build_pos_bias(pos_bucket_enc, attn_rel_b);
12368
12681
 
12369
- cur = build_attn(inp_attn, gf,
12682
+ cur = build_attn(inp_attn,
12370
12683
  model.layers[il].wo_enc, nullptr,
12371
12684
  Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il);
12372
12685
  cb(cur, "kqv_out", il);
@@ -12424,7 +12737,7 @@ struct llm_build_t5_enc : public llm_graph_context {
12424
12737
  };
12425
12738
 
12426
12739
  struct llm_build_t5_dec : public llm_graph_context {
12427
- llm_build_t5_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
12740
+ llm_build_t5_dec(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
12428
12741
  const int64_t n_embd_head = hparams.n_embd_head_v;
12429
12742
  //const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
12430
12743
 
@@ -12472,7 +12785,7 @@ struct llm_build_t5_dec : public llm_graph_context {
12472
12785
  ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
12473
12786
  ggml_tensor * kq_b = build_pos_bias(pos_bucket_dec, attn_rel_b);
12474
12787
 
12475
- cur = build_attn(inp_attn_self, gf,
12788
+ cur = build_attn(inp_attn_self,
12476
12789
  model.layers[il].wo, model.layers[il].bo,
12477
12790
  Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il);
12478
12791
  cb(cur, "kqv_out", il);
@@ -12504,7 +12817,7 @@ struct llm_build_t5_dec : public llm_graph_context {
12504
12817
  Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
12505
12818
  Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_outputs_enc);
12506
12819
 
12507
- cur = build_attn(inp_attn_cross, gf,
12820
+ cur = build_attn(inp_attn_cross,
12508
12821
  model.layers[il].wo_cross, nullptr,
12509
12822
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
12510
12823
  cb(cur, "kqv_out", il);
@@ -12594,7 +12907,7 @@ struct llm_build_t5_dec : public llm_graph_context {
12594
12907
  };
12595
12908
 
12596
12909
  struct llm_build_jais : public llm_graph_context {
12597
- llm_build_jais(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
12910
+ llm_build_jais(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
12598
12911
  const int64_t n_embd_head = hparams.n_embd_head_v;
12599
12912
  const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
12600
12913
 
@@ -12636,7 +12949,7 @@ struct llm_build_jais : public llm_graph_context {
12636
12949
  Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
12637
12950
  Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
12638
12951
 
12639
- cur = build_attn(inp_attn, gf,
12952
+ cur = build_attn(inp_attn,
12640
12953
  model.layers[il].wo, model.layers[il].bo,
12641
12954
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il);
12642
12955
  }
@@ -12689,7 +13002,7 @@ struct llm_build_jais : public llm_graph_context {
12689
13002
  };
12690
13003
 
12691
13004
  struct llm_build_chatglm : public llm_graph_context {
12692
- llm_build_chatglm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
13005
+ llm_build_chatglm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
12693
13006
  const int64_t n_embd_head = hparams.n_embd_head_v;
12694
13007
  const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
12695
13008
 
@@ -12768,7 +13081,7 @@ struct llm_build_chatglm : public llm_graph_context {
12768
13081
  cb(Kcur, "Kcur", il);
12769
13082
  cb(Vcur, "Vcur", il);
12770
13083
 
12771
- cur = build_attn(inp_attn, gf,
13084
+ cur = build_attn(inp_attn,
12772
13085
  model.layers[il].wo, NULL,
12773
13086
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
12774
13087
  }
@@ -12822,7 +13135,7 @@ struct llm_build_chatglm : public llm_graph_context {
12822
13135
  };
12823
13136
 
12824
13137
  struct llm_build_glm4 : public llm_graph_context {
12825
- llm_build_glm4(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
13138
+ llm_build_glm4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
12826
13139
  const int64_t n_embd_head = hparams.n_embd_head_v;
12827
13140
  const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
12828
13141
 
@@ -12901,7 +13214,7 @@ struct llm_build_glm4 : public llm_graph_context {
12901
13214
  cb(Kcur, "Kcur", il);
12902
13215
  cb(Vcur, "Vcur", il);
12903
13216
 
12904
- cur = build_attn(inp_attn, gf,
13217
+ cur = build_attn(inp_attn,
12905
13218
  model.layers[il].wo, NULL,
12906
13219
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
12907
13220
  }
@@ -12973,7 +13286,7 @@ struct llm_build_glm4 : public llm_graph_context {
12973
13286
  };
12974
13287
 
12975
13288
  struct llm_build_nemotron : public llm_graph_context {
12976
- llm_build_nemotron(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
13289
+ llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
12977
13290
  const int64_t n_embd_head = hparams.n_embd_head_v;
12978
13291
 
12979
13292
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -13045,7 +13358,7 @@ struct llm_build_nemotron : public llm_graph_context {
13045
13358
  cb(Kcur, "Kcur", il);
13046
13359
  cb(Vcur, "Vcur", il);
13047
13360
 
13048
- cur = build_attn(inp_attn, gf,
13361
+ cur = build_attn(inp_attn,
13049
13362
  model.layers[il].wo, model.layers[il].bo,
13050
13363
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
13051
13364
  }
@@ -13102,7 +13415,7 @@ struct llm_build_nemotron : public llm_graph_context {
13102
13415
  };
13103
13416
 
13104
13417
  struct llm_build_exaone : public llm_graph_context {
13105
- llm_build_exaone(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
13418
+ llm_build_exaone(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
13106
13419
  const int64_t n_embd_head = hparams.n_embd_head_v;
13107
13420
 
13108
13421
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -13176,7 +13489,7 @@ struct llm_build_exaone : public llm_graph_context {
13176
13489
  cb(Kcur, "Kcur", il);
13177
13490
  cb(Vcur, "Vcur", il);
13178
13491
 
13179
- cur = build_attn(inp_attn, gf,
13492
+ cur = build_attn(inp_attn,
13180
13493
  model.layers[il].wo, model.layers[il].bo,
13181
13494
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
13182
13495
  }
@@ -13232,32 +13545,168 @@ struct llm_build_exaone : public llm_graph_context {
13232
13545
  }
13233
13546
  };
13234
13547
 
13235
- struct llm_build_rwkv6_base : public llm_graph_context {
13236
- const llama_model & model;
13548
+ template <bool iswa>
13549
+ struct llm_build_exaone4 : public llm_graph_context {
13550
+ llm_build_exaone4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
13551
+ const int64_t n_embd_head = hparams.n_embd_head_k;
13237
13552
 
13238
- llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {
13239
- }
13553
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_v);
13554
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
13240
13555
 
13241
- ggml_tensor * build_rwkv6_channel_mix(
13242
- const llama_layer * layer,
13243
- ggml_tensor * cur,
13244
- ggml_tensor * x_prev,
13245
- llm_arch arch) const {
13246
- ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
13247
- switch (arch) {
13248
- case LLM_ARCH_RWKV6:
13249
- {
13250
- ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);
13251
- ggml_tensor * xr = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_r), cur);
13556
+ ggml_tensor * cur;
13557
+ ggml_tensor * inpL;
13252
13558
 
13253
- ggml_tensor * r = ggml_sigmoid(ctx0, build_lora_mm(layer->channel_mix_receptance, xr));
13254
- ggml_tensor * k = ggml_sqr(
13255
- ctx0,
13256
- ggml_relu(
13257
- ctx0,
13258
- build_lora_mm(layer->channel_mix_key, xk)
13259
- )
13260
- );
13559
+ inpL = build_inp_embd(model.tok_embd);
13560
+
13561
+ // inp_pos - contains the positions
13562
+ ggml_tensor * inp_pos = build_inp_pos();
13563
+
13564
+ using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
13565
+ inp_attn_type * inp_attn = nullptr;
13566
+
13567
+ if constexpr (iswa) {
13568
+ inp_attn = build_attn_inp_kv_unified_iswa();
13569
+ } else {
13570
+ inp_attn = build_attn_inp_kv_unified();
13571
+ }
13572
+
13573
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
13574
+
13575
+ for (int il = 0; il < n_layer; ++il) {
13576
+ ggml_tensor * inpSA = inpL;
13577
+
13578
+ // use RoPE for SWA layers or non-SWA models
13579
+ const bool use_rope = hparams.is_swa(il) || hparams.swa_type == LLAMA_SWA_TYPE_NONE;
13580
+
13581
+ cur = inpL;
13582
+
13583
+ // self-attention
13584
+ {
13585
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
13586
+
13587
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
13588
+ cb(Qcur, "Qcur", il);
13589
+
13590
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
13591
+ cb(Kcur, "Kcur", il);
13592
+
13593
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
13594
+ cb(Vcur, "Vcur", il);
13595
+
13596
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
13597
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
13598
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
13599
+
13600
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
13601
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
13602
+ cb(Qcur, "Qcur_normed", il);
13603
+ cb(Kcur, "Kcur_normed", il);
13604
+
13605
+ if (use_rope) {
13606
+ Qcur = ggml_rope_ext(
13607
+ ctx0, Qcur, inp_pos, rope_factors,
13608
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
13609
+ ext_factor, attn_factor, beta_fast, beta_slow
13610
+ );
13611
+
13612
+ Kcur = ggml_rope_ext(
13613
+ ctx0, Kcur, inp_pos, rope_factors,
13614
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
13615
+ ext_factor, attn_factor, beta_fast, beta_slow
13616
+ );
13617
+ }
13618
+
13619
+ cb(Qcur, "Qcur", il);
13620
+ cb(Kcur, "Kcur", il);
13621
+ cb(Vcur, "Vcur", il);
13622
+
13623
+ cur = build_attn(inp_attn,
13624
+ model.layers[il].wo, NULL,
13625
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
13626
+ cb(cur, "attn_out", il);
13627
+ }
13628
+
13629
+ if (il == n_layer - 1 && inp_out_ids) {
13630
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
13631
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
13632
+ }
13633
+
13634
+ cur = build_norm(cur,
13635
+ model.layers[il].attn_post_norm, NULL,
13636
+ LLM_NORM_RMS, il);
13637
+ cb(cur, "attn_post_norm", il);
13638
+
13639
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
13640
+ cb(ffn_inp, "ffn_inp", il);
13641
+
13642
+ // feed-forward network
13643
+ cur = build_ffn(ffn_inp,
13644
+ model.layers[il].ffn_up, NULL, NULL,
13645
+ model.layers[il].ffn_gate, NULL, NULL,
13646
+ model.layers[il].ffn_down, NULL, NULL,
13647
+ NULL,
13648
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
13649
+ cb(cur, "ffn_out", il);
13650
+
13651
+ cur = build_norm(cur,
13652
+ model.layers[il].ffn_post_norm, NULL,
13653
+ LLM_NORM_RMS, -1);
13654
+ cb(cur, "ffn_post_norm", -1);
13655
+
13656
+ cur = ggml_add(ctx0, cur, ffn_inp);
13657
+
13658
+ cur = build_cvec(cur, il);
13659
+ cb(cur, "l_out", il);
13660
+
13661
+ // input for next layer
13662
+ inpL = cur;
13663
+ }
13664
+
13665
+ cur = inpL;
13666
+
13667
+ cur = build_norm(cur,
13668
+ model.output_norm, NULL,
13669
+ LLM_NORM_RMS, -1);
13670
+
13671
+ cb(cur, "result_norm", -1);
13672
+ res->t_embd = cur;
13673
+
13674
+ // lm_head
13675
+ cur = build_lora_mm(model.output, cur);
13676
+
13677
+ cb(cur, "result_output", -1);
13678
+ res->t_logits = cur;
13679
+
13680
+ ggml_build_forward_expand(gf, cur);
13681
+ }
13682
+ };
13683
+
13684
+ struct llm_build_rwkv6_base : public llm_graph_context {
13685
+ const llama_model & model;
13686
+
13687
+ llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {
13688
+ }
13689
+
13690
+ ggml_tensor * build_rwkv6_channel_mix(
13691
+ const llama_layer * layer,
13692
+ ggml_tensor * cur,
13693
+ ggml_tensor * x_prev,
13694
+ llm_arch arch) const {
13695
+ ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
13696
+ switch (arch) {
13697
+ case LLM_ARCH_RWKV6:
13698
+ {
13699
+ ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);
13700
+ ggml_tensor * xr = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_r), cur);
13701
+
13702
+ ggml_tensor * r = ggml_sigmoid(ctx0, build_lora_mm(layer->channel_mix_receptance, xr));
13703
+ ggml_tensor * k = ggml_sqr(
13704
+ ctx0,
13705
+ ggml_relu(
13706
+ ctx0,
13707
+ build_lora_mm(layer->channel_mix_key, xk)
13708
+ )
13709
+ );
13261
13710
  cur = ggml_mul(ctx0, r, build_lora_mm(layer->channel_mix_value, k));
13262
13711
  } break;
13263
13712
  default:
@@ -13269,7 +13718,6 @@ struct llm_build_rwkv6_base : public llm_graph_context {
13269
13718
 
13270
13719
  ggml_tensor * build_rwkv6_time_mix(
13271
13720
  llm_graph_input_rs * inp,
13272
- ggml_cgraph * gf,
13273
13721
  ggml_tensor * cur,
13274
13722
  ggml_tensor * x_prev,
13275
13723
  const llama_ubatch & ubatch,
@@ -13396,7 +13844,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
13396
13844
  }
13397
13845
 
13398
13846
  ggml_tensor * wkv_state = build_rs(
13399
- inp, gf, mctx_cur->get_s_l(il),
13847
+ inp, mctx_cur->get_s_l(il),
13400
13848
  hparams.n_embd_s(), n_seqs);
13401
13849
 
13402
13850
  ggml_tensor * wkv_output;
@@ -13442,7 +13890,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
13442
13890
  };
13443
13891
 
13444
13892
  struct llm_build_rwkv6 : public llm_build_rwkv6_base {
13445
- llm_build_rwkv6(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) {
13893
+ llm_build_rwkv6(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv6_base(model, params) {
13446
13894
  GGML_ASSERT(hparams.token_shift_count == 2);
13447
13895
 
13448
13896
  ggml_tensor * cur;
@@ -13463,7 +13911,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
13463
13911
  const llama_layer * layer = &model.layers[il];
13464
13912
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
13465
13913
 
13466
- ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
13914
+ ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, ubatch, il);
13467
13915
 
13468
13916
  ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
13469
13917
  ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
@@ -13478,7 +13926,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
13478
13926
  1
13479
13927
  );
13480
13928
 
13481
- cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
13929
+ cur = build_rwkv6_time_mix(rs_inp, att_norm, x_prev, ubatch, il);
13482
13930
 
13483
13931
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
13484
13932
  cb(ffn_inp, "ffn_inp", il);
@@ -13543,7 +13991,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
13543
13991
 
13544
13992
  // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
13545
13993
  struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
13546
- llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) {
13994
+ llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv6_base(model, params) {
13547
13995
  GGML_ASSERT(n_embd == hparams.n_embd_r());
13548
13996
 
13549
13997
  ggml_tensor * cur;
@@ -13563,7 +14011,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
13563
14011
  const llama_layer * layer = &model.layers[il];
13564
14012
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
13565
14013
 
13566
- ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
14014
+ ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, ubatch, il);
13567
14015
 
13568
14016
  ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
13569
14017
  cb(att_norm, "attn_norm", il);
@@ -13575,7 +14023,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
13575
14023
  1
13576
14024
  );
13577
14025
 
13578
- cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
14026
+ cur = build_rwkv6_time_mix(rs_inp, att_norm, x_prev, ubatch, il);
13579
14027
 
13580
14028
  token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
13581
14029
  ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -13665,7 +14113,6 @@ struct llm_build_rwkv7_base : public llm_graph_context {
13665
14113
 
13666
14114
  ggml_tensor * build_rwkv7_time_mix(
13667
14115
  llm_graph_input_rs * inp,
13668
- ggml_cgraph * gf,
13669
14116
  ggml_tensor * cur,
13670
14117
  ggml_tensor * x_prev,
13671
14118
  ggml_tensor *& first_layer_value,
@@ -13751,7 +14198,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
13751
14198
  a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
13752
14199
 
13753
14200
  ggml_tensor * wkv_state = build_rs(
13754
- inp, gf, mctx_cur->get_s_l(il),
14201
+ inp, mctx_cur->get_s_l(il),
13755
14202
  hparams.n_embd_s(), n_seqs);
13756
14203
 
13757
14204
  ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
@@ -13798,7 +14245,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
13798
14245
  };
13799
14246
 
13800
14247
  struct llm_build_rwkv7 : public llm_build_rwkv7_base {
13801
- llm_build_rwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
14248
+ llm_build_rwkv7(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv7_base(model, params) {
13802
14249
  GGML_ASSERT(hparams.token_shift_count == 2);
13803
14250
 
13804
14251
  ggml_tensor * cur;
@@ -13820,7 +14267,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
13820
14267
  const llama_layer * layer = &model.layers[il];
13821
14268
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
13822
14269
 
13823
- ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
14270
+ ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, ubatch, il);
13824
14271
 
13825
14272
  ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
13826
14273
  ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
@@ -13835,7 +14282,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
13835
14282
  1
13836
14283
  );
13837
14284
 
13838
- cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
14285
+ cur = build_rwkv7_time_mix(rs_inp, att_norm, x_prev, v_first, ubatch, il);
13839
14286
 
13840
14287
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
13841
14288
  cb(ffn_inp, "ffn_inp", il);
@@ -13894,7 +14341,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
13894
14341
 
13895
14342
 
13896
14343
  struct llm_build_arwkv7 : public llm_build_rwkv7_base {
13897
- llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
14344
+ llm_build_arwkv7(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv7_base(model, params) {
13898
14345
  GGML_ASSERT(n_embd == hparams.n_embd_r());
13899
14346
 
13900
14347
  ggml_tensor * cur;
@@ -13915,7 +14362,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
13915
14362
  const llama_layer * layer = &model.layers[il];
13916
14363
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
13917
14364
 
13918
- ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
14365
+ ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, ubatch, il);
13919
14366
 
13920
14367
  ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
13921
14368
  cb(att_norm, "attn_norm", il);
@@ -13927,7 +14374,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
13927
14374
  1
13928
14375
  );
13929
14376
 
13930
- cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
14377
+ cur = build_rwkv7_time_mix(rs_inp, att_norm, x_prev, v_first, ubatch, il);
13931
14378
 
13932
14379
  token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
13933
14380
  ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -13984,8 +14431,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
13984
14431
  struct llm_build_granite : public llm_graph_context {
13985
14432
  llm_build_granite(
13986
14433
  const llama_model & model,
13987
- const llm_graph_params & params,
13988
- ggml_cgraph * gf)
14434
+ const llm_graph_params & params)
13989
14435
  : llm_graph_context(params) {
13990
14436
 
13991
14437
  const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -14019,7 +14465,7 @@ struct llm_build_granite : public llm_graph_context {
14019
14465
 
14020
14466
  // self-attention
14021
14467
  cur = build_attention_layer(
14022
- gf, cur, inp_pos, inp_attn,
14468
+ cur, inp_pos, inp_attn,
14023
14469
  model, n_embd_head, il);
14024
14470
 
14025
14471
  if (il == n_layer - 1 && inp_out_ids) {
@@ -14055,7 +14501,6 @@ struct llm_build_granite : public llm_graph_context {
14055
14501
  }
14056
14502
 
14057
14503
  ggml_tensor * build_attention_layer(
14058
- ggml_cgraph * gf,
14059
14504
  ggml_tensor * cur,
14060
14505
  ggml_tensor * inp_pos,
14061
14506
  llm_graph_input_attn_kv_unified * inp_attn,
@@ -14110,7 +14555,7 @@ struct llm_build_granite : public llm_graph_context {
14110
14555
  cb(Vcur, "Vcur", il);
14111
14556
 
14112
14557
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
14113
- cur = build_attn(inp_attn, gf,
14558
+ cur = build_attn(inp_attn,
14114
14559
  model.layers[il].wo, model.layers[il].bo,
14115
14560
  Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
14116
14561
  cb(cur, "attn_out", il);
@@ -14198,11 +14643,9 @@ struct llm_build_granite : public llm_graph_context {
14198
14643
  };
14199
14644
 
14200
14645
  struct llm_build_granite_hybrid : public llm_graph_context_mamba {
14201
-
14202
14646
  llm_build_granite_hybrid(
14203
14647
  const llama_model & model,
14204
- const llm_graph_params & params,
14205
- ggml_cgraph * gf) :
14648
+ const llm_graph_params & params) :
14206
14649
  llm_graph_context_mamba(params) {
14207
14650
 
14208
14651
  const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -14234,11 +14677,11 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
14234
14677
 
14235
14678
  if (hparams.is_recurrent(il)) {
14236
14679
  // ssm layer //
14237
- cur = build_mamba2_layer(inp->get_recr(), gf, cur, model, ubatch, il);
14680
+ cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il);
14238
14681
  } else {
14239
14682
  // attention layer //
14240
14683
  cur = build_attention_layer(
14241
- gf, cur, inp_pos, inp->get_attn(), model,
14684
+ cur, inp_pos, inp->get_attn(), model,
14242
14685
  n_embd_head, il);
14243
14686
  }
14244
14687
 
@@ -14277,7 +14720,6 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
14277
14720
  }
14278
14721
 
14279
14722
  ggml_tensor * build_attention_layer(
14280
- ggml_cgraph * gf,
14281
14723
  ggml_tensor * cur,
14282
14724
  ggml_tensor * inp_pos,
14283
14725
  llm_graph_input_attn_kv_unified * inp_attn,
@@ -14332,7 +14774,7 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
14332
14774
  cb(Vcur, "Vcur", il);
14333
14775
 
14334
14776
  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
14335
- cur = build_attn(inp_attn, gf,
14777
+ cur = build_attn(inp_attn,
14336
14778
  model.layers[il].wo, model.layers[il].bo,
14337
14779
  Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
14338
14780
  cb(cur, "attn_out", il);
@@ -14426,7 +14868,7 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
14426
14868
  // * removed bias
14427
14869
  // * removed MoE
14428
14870
  struct llm_build_chameleon : public llm_graph_context {
14429
- llm_build_chameleon(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
14871
+ llm_build_chameleon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
14430
14872
  const int64_t n_embd_head = hparams.n_embd_head_v;
14431
14873
 
14432
14874
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -14517,7 +14959,7 @@ struct llm_build_chameleon : public llm_graph_context {
14517
14959
  cb(Kcur, "Kcur", il);
14518
14960
  cb(Vcur, "Vcur", il);
14519
14961
 
14520
- cur = build_attn(inp_attn, gf,
14962
+ cur = build_attn(inp_attn,
14521
14963
  model.layers[il].wo, nullptr,
14522
14964
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
14523
14965
  }
@@ -14603,7 +15045,7 @@ struct llm_build_chameleon : public llm_graph_context {
14603
15045
  };
14604
15046
 
14605
15047
  struct llm_build_wavtokenizer_dec : public llm_graph_context {
14606
- llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
15048
+ llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
14607
15049
  ggml_tensor * cur;
14608
15050
  ggml_tensor * inpL;
14609
15051
 
@@ -14755,7 +15197,7 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context {
14755
15197
  };
14756
15198
 
14757
15199
  struct llm_build_plm : public llm_graph_context {
14758
- llm_build_plm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
15200
+ llm_build_plm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
14759
15201
  const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k));
14760
15202
 
14761
15203
  const uint32_t n_embd_head_qk_rope = hparams.n_rot;
@@ -14873,7 +15315,7 @@ struct llm_build_plm : public llm_graph_context {
14873
15315
  ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
14874
15316
  cb(k_states, "k_states", il);
14875
15317
 
14876
- cur = build_attn(inp_attn, gf,
15318
+ cur = build_attn(inp_attn,
14877
15319
  model.layers[il].wo, NULL,
14878
15320
  q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
14879
15321
  }
@@ -14927,7 +15369,7 @@ struct llm_build_plm : public llm_graph_context {
14927
15369
  };
14928
15370
 
14929
15371
  struct llm_build_bailingmoe : public llm_graph_context {
14930
- llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
15372
+ llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
14931
15373
  ggml_tensor * cur;
14932
15374
  ggml_tensor * inpL;
14933
15375
 
@@ -14996,7 +15438,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
14996
15438
  cb(Kcur, "Kcur", il);
14997
15439
  cb(Vcur, "Vcur", il);
14998
15440
 
14999
- cur = build_attn(inp_attn, gf,
15441
+ cur = build_attn(inp_attn,
15000
15442
  model.layers[il].wo, model.layers[il].bo,
15001
15443
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il);
15002
15444
  }
@@ -15071,7 +15513,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
15071
15513
  };
15072
15514
 
15073
15515
  struct llm_build_dots1 : public llm_graph_context {
15074
- llm_build_dots1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
15516
+ llm_build_dots1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
15075
15517
  const int64_t n_embd_head = hparams.n_embd_head_v;
15076
15518
 
15077
15519
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -15136,7 +15578,7 @@ struct llm_build_dots1 : public llm_graph_context {
15136
15578
  cb(Kcur, "Kcur", il);
15137
15579
  cb(Vcur, "Vcur", il);
15138
15580
 
15139
- cur = build_attn(inp_attn, gf,
15581
+ cur = build_attn(inp_attn,
15140
15582
  model.layers[il].wo, model.layers[il].bo,
15141
15583
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
15142
15584
  }
@@ -15221,7 +15663,7 @@ struct llm_build_dots1 : public llm_graph_context {
15221
15663
  };
15222
15664
 
15223
15665
  struct llm_build_ernie4_5 : public llm_graph_context {
15224
- llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
15666
+ llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
15225
15667
  const int64_t n_embd_head = hparams.n_embd_head_v;
15226
15668
 
15227
15669
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -15291,7 +15733,7 @@ struct llm_build_ernie4_5 : public llm_graph_context {
15291
15733
  cb(Kcur, "Kcur", il);
15292
15734
  cb(Vcur, "Vcur", il);
15293
15735
 
15294
- cur = build_attn(inp_attn, gf,
15736
+ cur = build_attn(inp_attn,
15295
15737
  model.layers[il].wo, NULL,
15296
15738
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
15297
15739
  }
@@ -15350,8 +15792,178 @@ struct llm_build_ernie4_5 : public llm_graph_context {
15350
15792
  }
15351
15793
  };
15352
15794
 
15795
+ struct llm_build_ernie4_5_moe : public llm_graph_context {
15796
+ llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
15797
+ const int64_t n_embd_head = hparams.n_embd_head_v;
15798
+
15799
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
15800
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
15801
+
15802
+ ggml_tensor * cur;
15803
+ ggml_tensor * inpL;
15804
+
15805
+ inpL = build_inp_embd(model.tok_embd);
15806
+
15807
+ // inp_pos - contains the positions
15808
+ ggml_tensor * inp_pos = build_inp_pos();
15809
+
15810
+ auto * inp_attn = build_attn_inp_kv_unified();
15811
+
15812
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
15813
+
15814
+ GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Ernie 4.5 MoE requires n_moe_layer_step > 0");
15815
+ for (int il = 0; il < n_layer; ++il) {
15816
+ ggml_tensor * inpSA = inpL;
15817
+ // norm
15818
+ {
15819
+ cur = build_norm(inpL,
15820
+ model.layers[il].attn_norm, NULL,
15821
+ LLM_NORM_RMS, il);
15822
+ cb(cur, "attn_norm", il);
15823
+ }
15824
+
15825
+ // self-attention
15826
+ {
15827
+ // compute Q and K and RoPE them
15828
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
15829
+ cb(Qcur, "Qcur", il);
15830
+ if (model.layers[il].bq) {
15831
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
15832
+ cb(Qcur, "Qcur", il);
15833
+ }
15834
+
15835
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
15836
+ cb(Kcur, "Kcur", il);
15837
+ if (model.layers[il].bk) {
15838
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
15839
+ cb(Kcur, "Kcur", il);
15840
+ }
15841
+
15842
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
15843
+ cb(Vcur, "Vcur", il);
15844
+ if (model.layers[il].bv) {
15845
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
15846
+ cb(Vcur, "Vcur", il);
15847
+ }
15848
+
15849
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
15850
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
15851
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
15852
+
15853
+ Qcur = ggml_rope_ext(
15854
+ ctx0, Qcur, inp_pos, nullptr,
15855
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
15856
+ ext_factor, attn_factor, beta_fast, beta_slow
15857
+ );
15858
+
15859
+ Kcur = ggml_rope_ext(
15860
+ ctx0, Kcur, inp_pos, nullptr,
15861
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
15862
+ ext_factor, attn_factor, beta_fast, beta_slow
15863
+ );
15864
+
15865
+ cb(Qcur, "Qcur", il);
15866
+ cb(Kcur, "Kcur", il);
15867
+ cb(Vcur, "Vcur", il);
15868
+
15869
+ cur = build_attn(inp_attn,
15870
+ model.layers[il].wo, NULL,
15871
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
15872
+ cb(cur, "attn_out", il);
15873
+ }
15874
+
15875
+ if (il == n_layer - 1 && inp_out_ids) {
15876
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
15877
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
15878
+ }
15879
+
15880
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
15881
+ cb(ffn_inp, "ffn_inp", il);
15882
+
15883
+ // feed-forward network
15884
+ bool is_moe_layer = static_cast<uint32_t>(il) >= hparams.n_layer_dense_lead && (il + 1) % hparams.n_moe_layer_step == 0;
15885
+
15886
+ if (!is_moe_layer) {
15887
+ cur = build_norm(ffn_inp,
15888
+ model.layers[il].ffn_norm, NULL,
15889
+ LLM_NORM_RMS, il);
15890
+ cb(cur, "ffn_norm", il);
15891
+
15892
+ cur = build_ffn(cur,
15893
+ model.layers[il].ffn_up, NULL, NULL,
15894
+ model.layers[il].ffn_gate, NULL, NULL,
15895
+ model.layers[il].ffn_down, NULL, NULL,
15896
+ NULL,
15897
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
15898
+ cb(cur, "ffn_out", il);
15899
+ } else {
15900
+ // MoE branch
15901
+ cur = build_norm(ffn_inp,
15902
+ model.layers[il].ffn_norm, NULL,
15903
+ LLM_NORM_RMS, il);
15904
+ cb(cur, "ffn_norm", il);
15905
+
15906
+ ggml_tensor * moe_out = build_moe_ffn(cur,
15907
+ model.layers[il].ffn_gate_inp,
15908
+ model.layers[il].ffn_up_exps,
15909
+ model.layers[il].ffn_gate_exps,
15910
+ model.layers[il].ffn_down_exps,
15911
+ model.layers[il].ffn_exp_probs_b,
15912
+ n_expert, n_expert_used,
15913
+ LLM_FFN_SILU, true,
15914
+ false, 0.0,
15915
+ LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
15916
+ il);
15917
+ cb(moe_out, "ffn_moe_out", il);
15918
+
15919
+ // Shared expert (if present)
15920
+ if (hparams.n_ff_shexp > 0) {
15921
+ ggml_tensor * ffn_shexp = build_ffn(cur,
15922
+ model.layers[il].ffn_up_shexp, NULL, NULL,
15923
+ model.layers[il].ffn_gate_shexp, NULL, NULL,
15924
+ model.layers[il].ffn_down_shexp, NULL, NULL,
15925
+ NULL,
15926
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
15927
+ cb(ffn_shexp, "ffn_shexp", il);
15928
+
15929
+ cur = ggml_add(ctx0, moe_out, ffn_shexp);
15930
+ } else {
15931
+ cur = moe_out;
15932
+ }
15933
+ cb(cur, "ffn_out", il);
15934
+ }
15935
+
15936
+ cur = ggml_add(ctx0, cur, ffn_inp);
15937
+ cb(cur, "ffn_out", il);
15938
+
15939
+ cur = build_cvec(cur, il);
15940
+ cb(cur, "l_out", il);
15941
+
15942
+ // input for next layer
15943
+ inpL = cur;
15944
+ }
15945
+
15946
+ cur = inpL;
15947
+
15948
+ cur = build_norm(cur,
15949
+ model.output_norm, NULL,
15950
+ LLM_NORM_RMS, -1);
15951
+
15952
+ cb(cur, "result_norm", -1);
15953
+ res->t_embd = cur;
15954
+
15955
+ // lm_head
15956
+ cur = build_lora_mm(model.output, cur);
15957
+
15958
+ cb(cur, "result_output", -1);
15959
+ res->t_logits = cur;
15960
+
15961
+ ggml_build_forward_expand(gf, cur);
15962
+ }
15963
+ };
15964
+
15353
15965
  struct llm_build_falcon_h1 : public llm_graph_context_mamba {
15354
- llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) {
15966
+ llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
15355
15967
  const int64_t n_embd_head = hparams.n_embd_head_v;
15356
15968
 
15357
15969
  ggml_tensor * cur;
@@ -15407,7 +16019,7 @@ struct llm_build_falcon_h1 : public llm_graph_context_mamba {
15407
16019
  cb(Kcur, "Kcur-post-rope", il);
15408
16020
  cb(Vcur, "Vcur-post-rope", il);
15409
16021
 
15410
- ggml_tensor * attn_out = build_attn(inp->get_attn(), gf,
16022
+ ggml_tensor * attn_out = build_attn(inp->get_attn(),
15411
16023
  model.layers[il].wo, NULL,
15412
16024
  Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
15413
16025
  cb(attn_out, "attn_out", il);
@@ -15418,7 +16030,7 @@ struct llm_build_falcon_h1 : public llm_graph_context_mamba {
15418
16030
  // Mamba2 layer
15419
16031
  cb(cur, "ssm_in", il);
15420
16032
 
15421
- ggml_tensor * ssm_out = build_mamba2_layer(inp->get_recr(), gf, cur, model, ubatch, il);
16033
+ ggml_tensor * ssm_out = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il);
15422
16034
  cb(ssm_out, "ssm_out", il);
15423
16035
 
15424
16036
  // // Aggregation
@@ -15476,8 +16088,321 @@ struct llm_build_falcon_h1 : public llm_graph_context_mamba {
15476
16088
  }
15477
16089
  };
15478
16090
 
16091
+ struct llm_build_plamo2 : public llm_graph_context_mamba {
16092
+ llm_build_plamo2(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
16093
+ ggml_tensor * cur;
16094
+ ggml_tensor * inpL;
16095
+
16096
+ // {n_embd, n_tokens}
16097
+ inpL = build_inp_embd(model.tok_embd);
16098
+ cb(inpL, "embedding_output", -1);
16099
+
16100
+ ggml_tensor * inp_pos = build_inp_pos();
16101
+
16102
+ auto * inp_hybrid = build_inp_mem_hybrid();
16103
+
16104
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
16105
+
16106
+ for (int il = 0; il < n_layer; ++il) {
16107
+ ggml_tensor * residual = inpL;
16108
+
16109
+ // ggml_graph_add_node(gf, model.layers[il].attn_norm);
16110
+ // cb(model.layers[il].attn_norm, "attn_norm", il);
16111
+
16112
+ // pre_mixer_norm
16113
+ cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
16114
+
16115
+ // check if this layer is Mamba or Attention
16116
+ bool is_mamba_layer = hparams.is_recurrent(il);
16117
+
16118
+ if (is_mamba_layer) {
16119
+ // PLaMo-2 Mamba layer
16120
+ cur = build_plamo2_mamba_layer(inp_hybrid->get_recr(), cur, model, ubatch, il);
16121
+ } else {
16122
+ // PLaMo-2 Attention layer
16123
+ cur = build_plamo2_attn_layer(inp_hybrid->get_attn(), inp_pos, cur, model, il);
16124
+ }
16125
+
16126
+ // post_mixer_norm
16127
+ cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
16128
+ cb(cur, "attn_post_norm", il);
16129
+
16130
+ // residual connection
16131
+ cur = ggml_add(ctx0, cur, residual);
16132
+ cb(cur, "attn_residual", il);
16133
+ residual = cur;
16134
+
16135
+ // pre-ffn norm
16136
+ cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
16137
+ cb(cur, "ffn_pre_norm", il);
16138
+
16139
+ // feed-forward network
16140
+ cur = build_ffn(cur,
16141
+ model.layers[il].ffn_up, NULL, NULL,
16142
+ NULL, NULL, NULL,
16143
+ model.layers[il].ffn_down, NULL, NULL,
16144
+ NULL,
16145
+ LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
16146
+ cb(cur, "ffn_out", il);
16147
+
16148
+ // post ffn norm
16149
+ cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il);
16150
+ cb(cur, "ffn_post_norm", il);
16151
+
16152
+ if (il == n_layer - 1 && inp_out_ids) {
16153
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
16154
+ residual = ggml_get_rows(ctx0, residual, inp_out_ids);
16155
+ }
16156
+
16157
+ // residual connection
16158
+ cur = ggml_add(ctx0, cur, residual);
16159
+ cb(cur, "ffn_residual", il);
16160
+
16161
+ inpL = cur;
16162
+ }
16163
+
16164
+ cur = inpL;
16165
+
16166
+ // final norm
16167
+ cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
16168
+ cb(cur, "result_norm", -1);
16169
+
16170
+ // lm_head
16171
+ cur = build_lora_mm(model.output, cur);
16172
+ cb(cur, "result_output", -1);
16173
+
16174
+ // Explicitly mark as output tensor to ensure proper backend assignment
16175
+ ggml_set_output(cur);
16176
+
16177
+ res->t_logits = cur;
16178
+
16179
+ ggml_build_forward_expand(gf, cur);
16180
+ }
16181
+
16182
+ private:
16183
+ ggml_tensor * build_plamo2_attn_layer(
16184
+ llm_graph_input_attn_kv_unified * inp,
16185
+ ggml_tensor * inp_pos,
16186
+ ggml_tensor * cur,
16187
+ const llama_model & model,
16188
+ int il) {
16189
+
16190
+ // self-attention
16191
+ {
16192
+ // PLaMo-2 uses combined QKV tensor
16193
+ ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur);
16194
+ cb(qkv, "wqkv", il);
16195
+
16196
+ // split QKV tensor into Q, K, V
16197
+ const int64_t n_embd_head_q = hparams.n_embd_head_k;
16198
+ const int64_t n_embd_head_k = hparams.n_embd_head_k;
16199
+ const int64_t n_embd_head_v = hparams.n_embd_head_v;
16200
+ int32_t n_head_kv = hparams.n_head_kv(il);
16201
+
16202
+ const int64_t q_offset = 0;
16203
+ const int64_t k_offset = n_embd_head_q * n_head;
16204
+ const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv;
16205
+
16206
+ ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv));
16207
+ ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv));
16208
+ ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv)));
16209
+
16210
+ cb(Qcur, "Qcur", il);
16211
+ cb(Kcur, "Kcur", il);
16212
+ cb(Vcur, "Vcur", il);
16213
+
16214
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens);
16215
+
16216
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
16217
+ cb(Qcur, "Qcur_normed", il);
16218
+
16219
+ Qcur = ggml_rope_ext(
16220
+ ctx0, Qcur, inp_pos, nullptr,
16221
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
16222
+ ext_factor, attn_factor, beta_fast, beta_slow
16223
+ );
16224
+
16225
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
16226
+ cb(Kcur, "Kcur_normed", il);
16227
+
16228
+ Kcur = ggml_rope_ext(
16229
+ ctx0, Kcur, inp_pos, nullptr,
16230
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
16231
+ ext_factor, attn_factor, beta_fast, beta_slow
16232
+ );
16233
+
16234
+ cur = build_attn(inp, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head_v)), il);
16235
+ }
16236
+
16237
+ cb(cur, "attn_out", il);
16238
+
16239
+ return cur;
16240
+ }
16241
+
16242
+ ggml_tensor * build_plamo2_mamba_layer(
16243
+ llm_graph_input_rs * inp,
16244
+ ggml_tensor * cur,
16245
+ const llama_model & model,
16246
+ const llama_ubatch & ubatch,
16247
+ int il) {
16248
+
16249
+ const auto * mctx_cur = inp->mctx;
16250
+
16251
+ const auto kv_head = mctx_cur->get_head();
16252
+
16253
+ const int64_t d_conv = hparams.ssm_d_conv;
16254
+ const int64_t d_inner = hparams.ssm_d_inner;
16255
+ const int64_t d_state = hparams.ssm_d_state;
16256
+ const int64_t n_heads = hparams.ssm_dt_rank;
16257
+ const int64_t head_dim = d_inner / n_heads;
16258
+ const int64_t n_group = hparams.ssm_n_group;
16259
+ const int64_t n_seqs = ubatch.n_seqs;
16260
+
16261
+ const int64_t n_seq_tokens = ubatch.n_seq_tokens;
16262
+
16263
+ GGML_ASSERT(n_seqs != 0);
16264
+ GGML_ASSERT(ubatch.equal_seqs());
16265
+ GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
16266
+
16267
+ ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
16268
+ ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
16269
+
16270
+ ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
16271
+ conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
16272
+
16273
+ // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
16274
+ cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
16275
+
16276
+ // in_proj: {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
16277
+ ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur);
16278
+ cb(zx, "mamba_in_proj", il);
16279
+ // {8192, 5, 1, 1} -> {8192, 1, 5, 1}
16280
+ zx = ggml_permute(ctx0, zx, 0, 2, 1, 3);
16281
+ zx = ggml_cont(ctx0, zx);
16282
+ zx = ggml_reshape_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs);
16283
+ cb(zx, "mamba_in_proj_out", il);
16284
+
16285
+ // split into z and x
16286
+ // => {head_dim * n_heads, n_seq_tokens, n_seqs}
16287
+ ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], head_dim*ggml_element_size(zx));
16288
+ x = ggml_cont(ctx0, x);
16289
+ x = ggml_reshape_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs);
16290
+ // x = ggml_permute(ctx0, x, 0, 2, 1, 3);
16291
+ cb(x, "mamba_x_split", il);
16292
+
16293
+ ggml_tensor * z = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], 0);
16294
+ cb(z, "mamba_z_split", il);
16295
+
16296
+ // conv1d
16297
+ {
16298
+ // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
16299
+ ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);
16300
+ cb(conv_x, "mamba_conv1d_input", il);
16301
+
16302
+ // copy last (d_conv - 1) columns back into the state cache
16303
+ ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs,
16304
+ conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
16305
+
16306
+ ggml_build_forward_expand(gf,
16307
+ ggml_cpy(ctx0, last_conv,
16308
+ ggml_view_1d(ctx0, conv_states_all,
16309
+ (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
16310
+ kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
16311
+ cb(conv_states_all, "mamba_conv1d_state", il);
16312
+
16313
+ // 1D convolution
16314
+ x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
16315
+ cb(x, "mamba_conv1d", il);
16316
+
16317
+ x = ggml_silu(ctx0, x);
16318
+ cb(x, "mamba_conv1d_silu", il);
16319
+ }
16320
+
16321
+ // SSM
16322
+ {
16323
+ // bcdt_proj: {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
16324
+ ggml_tensor * x_bcdt = build_lora_mm(model.layers[il].ssm_x, x);
16325
+ cb(x_bcdt, "mamba_bcdt_proj", il);
16326
+
16327
+ // split into dt, B, C
16328
+ const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
16329
+ ggml_tensor * B = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0);
16330
+ ggml_tensor * C = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*d_state);
16331
+ ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*(2*d_state));
16332
+ cb(B, "mamba_B_raw", il);
16333
+ cb(C, "mamba_C_raw", il);
16334
+ cb(dt, "mamba_dt_raw", il);
16335
+
16336
+ // Apply RMS norm to dt, B, C (PLaMo-2 specific)
16337
+ B = build_norm(B, model.layers[il].ssm_b_norm, NULL, LLM_NORM_RMS, il);
16338
+ C = build_norm(C, model.layers[il].ssm_c_norm, NULL, LLM_NORM_RMS, il);
16339
+ dt = build_norm(dt, model.layers[il].ssm_dt_norm, NULL, LLM_NORM_RMS, il);
16340
+ cb(B, "mamba_B_normed", il);
16341
+ cb(C, "mamba_C_normed", il);
16342
+ cb(dt, "mamba_dt_normed", il);
16343
+
16344
+ // dt_proj: {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
16345
+ dt = build_lora_mm(model.layers[il].ssm_dt, dt);
16346
+ dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
16347
+ cb(dt, "mamba_dt_proj", il);
16348
+
16349
+ ggml_tensor * A = ggml_reshape_2d(ctx0, model.layers[il].ssm_a, 1, n_heads);
16350
+ cb(A, "mamba_A", il);
16351
+
16352
+ x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
16353
+ B = ggml_view_4d(ctx0, B, d_state, 1, n_seq_tokens, n_seqs, d_state * B->nb[0], B->nb[1], B->nb[2], 0);
16354
+ C = ggml_view_4d(ctx0, C, d_state, 1, n_seq_tokens, n_seqs, d_state * C->nb[0], C->nb[1], C->nb[2], 0);
16355
+
16356
+ // use the states and the indices provided by build_recurrent_state
16357
+ // (this is necessary in order to properly use the states before they are overwritten,
16358
+ // while avoiding to make unnecessary copies of the states)
16359
+ auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
16360
+ ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_heads, mctx_cur->get_size());
16361
+
16362
+ // Custom operator to optimize the parallel associative scan
16363
+ // as described in the Annex D of the Mamba paper.
16364
+ // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
16365
+ return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
16366
+ };
16367
+
16368
+ ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
16369
+ cb(y_ssm, "mamba_ssm_scan", il);
16370
+
16371
+ // store last states
16372
+ ggml_build_forward_expand(gf,
16373
+ ggml_cpy(ctx0,
16374
+ ggml_view_1d(ctx0, y_ssm, n_heads*head_dim*d_state*n_seqs, n_heads*head_dim*n_seq_tokens*n_seqs*ggml_element_size(y_ssm)),
16375
+ ggml_view_1d(ctx0, ssm_states_all, n_heads*head_dim*d_state*n_seqs, kv_head*n_seqs*n_heads*head_dim*d_state*ggml_element_size(ssm_states_all))));
16376
+ cb(ssm_states_all, "mamba_ssm_states", il);
16377
+
16378
+ ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
16379
+ cb(y, "mamba_y_view", il);
16380
+
16381
+ // Add D parameter and apply gating with z
16382
+ // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
16383
+ ggml_tensor * D = ggml_reshape_2d(ctx0, model.layers[il].ssm_d, 1, n_heads);
16384
+ y = ggml_add(ctx0, y, ggml_mul(ctx0, x, D));
16385
+ cb(y, "mamba_y_add_d", il);
16386
+
16387
+ y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
16388
+ cb(y, "mamba_y_swiglu_z", il);
16389
+
16390
+ // out_proj: {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
16391
+ y = ggml_view_3d(ctx0, y, head_dim * n_heads, n_seq_tokens, n_seqs, y->nb[2], y->nb[3], 0);
16392
+ cur = build_lora_mm(model.layers[il].ssm_out, y);
16393
+ cb(cur, "mamba_out_proj", il);
16394
+ }
16395
+
16396
+ // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
16397
+ cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
16398
+ cb(cur, "mamba_out", il);
16399
+
16400
+ return cur;
16401
+ }
16402
+ };
16403
+
15479
16404
  struct llm_build_arcee : public llm_graph_context {
15480
- llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
16405
+ llm_build_arcee(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
15481
16406
  const int64_t n_embd_head = hparams.n_embd_head_v;
15482
16407
 
15483
16408
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -15553,7 +16478,7 @@ struct llm_build_arcee : public llm_graph_context {
15553
16478
  cb(Kcur, "Kcur", il);
15554
16479
  cb(Vcur, "Vcur", il);
15555
16480
 
15556
- cur = build_attn(inp_attn, gf,
16481
+ cur = build_attn(inp_attn,
15557
16482
  model.layers[il].wo, model.layers[il].bo,
15558
16483
  Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
15559
16484
  cb(cur, "attn_out", il);
@@ -15612,7 +16537,7 @@ struct llm_build_arcee : public llm_graph_context {
15612
16537
  };
15613
16538
 
15614
16539
  struct llm_build_hunyuan_moe : public llm_graph_context {
15615
- llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
16540
+ llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
15616
16541
  const int64_t n_embd_head = hparams.n_embd_head_v;
15617
16542
 
15618
16543
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -15698,7 +16623,7 @@ struct llm_build_hunyuan_moe : public llm_graph_context {
15698
16623
  LLM_NORM_RMS, il);
15699
16624
  cb(Qcur, "Qcur_norm", il);
15700
16625
 
15701
- cur = build_attn(inp_attn, gf,
16626
+ cur = build_attn(inp_attn,
15702
16627
  model.layers[il].wo, model.layers[il].bo,
15703
16628
  Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
15704
16629
  cb(cur, "attn_out", il);
@@ -15773,7 +16698,7 @@ struct llm_build_hunyuan_moe : public llm_graph_context {
15773
16698
  };
15774
16699
 
15775
16700
  struct llm_build_smollm3 : public llm_graph_context {
15776
- llm_build_smollm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
16701
+ llm_build_smollm3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
15777
16702
  const int64_t n_embd_head = hparams.n_embd_head_v;
15778
16703
 
15779
16704
  GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -15850,7 +16775,7 @@ struct llm_build_smollm3 : public llm_graph_context {
15850
16775
  cb(Kcur, "Kcur", il);
15851
16776
  cb(Vcur, "Vcur", il);
15852
16777
 
15853
- cur = build_attn(inp_attn, gf,
16778
+ cur = build_attn(inp_attn,
15854
16779
  model.layers[il].wo, model.layers[il].bo,
15855
16780
  Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
15856
16781
  cb(cur, "attn_out", il);
@@ -15912,7 +16837,7 @@ struct llm_build_smollm3 : public llm_graph_context {
15912
16837
  struct llm_build_lfm2 : public llm_graph_context {
15913
16838
  const llama_model & model;
15914
16839
 
15915
- llm_build_lfm2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) {
16840
+ llm_build_lfm2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {
15916
16841
 
15917
16842
  ggml_tensor * cur = build_inp_embd(model.tok_embd);
15918
16843
  cb(cur, "model.embed_tokens", -1);
@@ -15927,8 +16852,8 @@ struct llm_build_lfm2 : public llm_graph_context {
15927
16852
  cb(cur, "model.layers.{}.operator_norm", il);
15928
16853
 
15929
16854
  cur = hparams.is_recurrent(il) ?
15930
- build_shortconv_block(gf, cur, inp_hybrid->get_recr(), il) :
15931
- build_attn_block(gf, cur, inp_pos, inp_hybrid->get_attn(), il) ;
16855
+ build_shortconv_block(cur, inp_hybrid->get_recr(), il) :
16856
+ build_attn_block(cur, inp_pos, inp_hybrid->get_attn(), il) ;
15932
16857
 
15933
16858
  if (il == n_layer - 1 && inp_out_ids) {
15934
16859
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
@@ -15971,8 +16896,7 @@ struct llm_build_lfm2 : public llm_graph_context {
15971
16896
  return cur;
15972
16897
  }
15973
16898
 
15974
- ggml_tensor * build_attn_block(ggml_cgraph * gf,
15975
- ggml_tensor * cur,
16899
+ ggml_tensor * build_attn_block(ggml_tensor * cur,
15976
16900
  ggml_tensor * inp_pos,
15977
16901
  llm_graph_input_attn_kv_unified * inp_attn,
15978
16902
  int il) const {
@@ -16009,7 +16933,7 @@ struct llm_build_lfm2 : public llm_graph_context {
16009
16933
  ext_factor, attn_factor, beta_fast, beta_slow
16010
16934
  );
16011
16935
 
16012
- cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL,
16936
+ cur = build_attn(inp_attn, model.layers[il].wo, NULL,
16013
16937
  q, k, v, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
16014
16938
 
16015
16939
  cb(cur, "model.layers.{}.self_attn.out_proj", il);
@@ -16017,11 +16941,22 @@ struct llm_build_lfm2 : public llm_graph_context {
16017
16941
  return cur;
16018
16942
  }
16019
16943
 
16020
- ggml_tensor * build_shortconv_block(ggml_cgraph * gf,
16021
- ggml_tensor * cur,
16944
+ ggml_tensor * build_shortconv_block(ggml_tensor * cur,
16022
16945
  llm_graph_input_rs * inp_recr,
16023
16946
  int il) {
16024
- const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
16947
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
16948
+ const uint32_t kv_head = mctx_cur->get_head();
16949
+ const int64_t n_seq_tokens = ubatch.n_seq_tokens;
16950
+ const int64_t n_seqs = ubatch.n_seqs;
16951
+ GGML_ASSERT(n_seqs != 0);
16952
+ GGML_ASSERT(ubatch.equal_seqs());
16953
+ GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
16954
+
16955
+ GGML_ASSERT(hparams.n_shortconv_l_cache > 1);
16956
+ const uint32_t d_conv = hparams.n_shortconv_l_cache - 1;
16957
+
16958
+ // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
16959
+ cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
16025
16960
 
16026
16961
  auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur);
16027
16962
  cb(bcx, "model.layers.{}.conv.in_proj", il);
@@ -16029,38 +16964,48 @@ struct llm_build_lfm2 : public llm_graph_context {
16029
16964
  constexpr auto n_chunks = 3;
16030
16965
  GGML_ASSERT(bcx->ne[0] % n_chunks == 0);
16031
16966
  auto const chunk_size = bcx->ne[0] / n_chunks;
16032
- auto * b = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 0 * chunk_size * ggml_element_size(bcx));
16033
- auto * c = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 1 * chunk_size * ggml_element_size(bcx));
16034
- auto * x = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 2 * chunk_size * ggml_element_size(bcx));
16967
+ auto * b = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 0*chunk_size*ggml_element_size(bcx));
16968
+ auto * c = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 1*chunk_size*ggml_element_size(bcx));
16969
+ auto * x = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 2*chunk_size*ggml_element_size(bcx));
16035
16970
 
16036
16971
  auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x));
16037
16972
 
16038
- // read conv state directly, with build_rs generation is slower
16039
- ggml_tensor * conv_state = mctx_cur->get_r_l(il);
16040
- const int64_t n_seqs = ubatch.n_seqs;
16041
- ggml_tensor * conv = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs);
16042
- conv = ggml_reshape_3d(ctx0, conv_state, hparams.n_shortconv_l_cache - 1, hparams.n_embd, n_seqs);
16973
+ // read conv state
16974
+ auto * conv_state = mctx_cur->get_r_l(il);
16975
+ auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs);
16976
+ auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);
16043
16977
 
16044
16978
  bx = ggml_concat(ctx0, conv, bx, 0);
16045
16979
  GGML_ASSERT(bx->ne[0] > conv->ne[0]);
16046
16980
 
16047
- auto * new_conv = ggml_view_2d(ctx0, bx, conv->ne[0], bx->ne[1], bx->nb[1], (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx));
16981
+ // last d_conv columns is a new conv state
16982
+ auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2], (bx->ne[0] - conv->ne[0])*ggml_element_size(bx));
16048
16983
  GGML_ASSERT(ggml_are_same_shape(conv, new_conv));
16049
16984
 
16050
- // write conv state
16051
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, conv_state));
16985
+ // write new conv conv state
16986
+ ggml_build_forward_expand(
16987
+ gf,
16988
+ ggml_cpy(
16989
+ ctx0,
16990
+ new_conv,
16991
+ ggml_view_1d(
16992
+ ctx0,
16993
+ conv_state,
16994
+ ggml_nelements(new_conv),
16995
+ kv_head*d_conv*n_embd*ggml_element_size(new_conv)
16996
+ )
16997
+ )
16998
+ );
16052
16999
 
16053
17000
  auto * conv_kernel = model.layers[il].shortconv.conv;
16054
- GGML_ASSERT(hparams.n_shortconv_l_cache > 0);
16055
-
16056
- // construct ssm_conv op
16057
- ggml_tensor * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
17001
+ auto * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
16058
17002
  cb(conv_out, "model.layers.{}.conv.conv", il);
16059
17003
 
16060
17004
  auto * y = ggml_mul(ctx0, c, conv_out);
16061
-
16062
17005
  y = build_lora_mm(model.layers[il].shortconv.out_proj, y);
16063
17006
  cb(y, "model.layers.{}.conv.out_proj", il);
17007
+ // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
17008
+ y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs);
16064
17009
 
16065
17010
  return y;
16066
17011
  }
@@ -16078,6 +17023,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
16078
17023
  case LLM_ARCH_NOMIC_BERT_MOE:
16079
17024
  case LLM_ARCH_NEO_BERT:
16080
17025
  case LLM_ARCH_WAVTOKENIZER_DEC:
17026
+ case LLM_ARCH_DREAM:
16081
17027
  {
16082
17028
  res = nullptr;
16083
17029
  } break;
@@ -16118,7 +17064,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
16118
17064
  } else {
16119
17065
  const auto padding = llama_kv_cache_unified::get_padding(cparams);
16120
17066
 
16121
- cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
17067
+ uint32_t n_ctx_per_stream = cparams.n_ctx;
17068
+
17069
+ if (!cparams.kv_unified) {
17070
+ n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
17071
+ n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
17072
+
17073
+ cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max;
17074
+ } else {
17075
+ n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
17076
+
17077
+ cparams.n_ctx = n_ctx_per_stream;
17078
+ }
16122
17079
 
16123
17080
  LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
16124
17081
 
@@ -16132,7 +17089,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
16132
17089
  !cparams.flash_attn,
16133
17090
  cparams.offload_kqv,
16134
17091
  params.swa_full,
16135
- cparams.n_ctx,
17092
+ cparams.kv_unified,
17093
+ n_ctx_per_stream,
16136
17094
  cparams.n_seq_max,
16137
17095
  cparams.n_ubatch,
16138
17096
  padding);
@@ -16146,7 +17104,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
16146
17104
  params.type_v,
16147
17105
  !cparams.flash_attn,
16148
17106
  cparams.offload_kqv,
16149
- cparams.n_ctx,
17107
+ cparams.kv_unified,
17108
+ n_ctx_per_stream,
16150
17109
  cparams.n_seq_max,
16151
17110
  padding,
16152
17111
  hparams.n_swa,
@@ -16159,227 +17118,233 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
16159
17118
  return res;
16160
17119
  }
16161
17120
 
16162
- llm_graph_result_ptr llama_model::build_graph(
16163
- const llm_graph_params & params,
16164
- ggml_cgraph * gf,
16165
- llm_graph_type type) const {
17121
+ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
16166
17122
  std::unique_ptr<llm_graph_context> llm;
16167
17123
 
16168
17124
  switch (arch) {
16169
17125
  case LLM_ARCH_LLAMA:
16170
17126
  {
16171
- llm = std::make_unique<llm_build_llama>(*this, params, gf);
17127
+ llm = std::make_unique<llm_build_llama>(*this, params);
16172
17128
  } break;
16173
17129
  case LLM_ARCH_LLAMA4:
16174
17130
  {
16175
- llm = std::make_unique<llm_build_llama_iswa>(*this, params, gf);
17131
+ llm = std::make_unique<llm_build_llama_iswa>(*this, params);
16176
17132
  } break;
16177
17133
  case LLM_ARCH_DECI:
16178
17134
  {
16179
- llm = std::make_unique<llm_build_deci>(*this, params, gf);
17135
+ llm = std::make_unique<llm_build_deci>(*this, params);
16180
17136
  } break;
16181
17137
  case LLM_ARCH_BAICHUAN:
16182
17138
  {
16183
- llm = std::make_unique<llm_build_baichuan>(*this, params, gf);
17139
+ llm = std::make_unique<llm_build_baichuan>(*this, params);
16184
17140
  } break;
16185
17141
  case LLM_ARCH_FALCON:
16186
17142
  {
16187
- llm = std::make_unique<llm_build_falcon>(*this, params, gf);
17143
+ llm = std::make_unique<llm_build_falcon>(*this, params);
16188
17144
  } break;
16189
17145
  case LLM_ARCH_GROK:
16190
17146
  {
16191
- llm = std::make_unique<llm_build_grok>(*this, params, gf);
17147
+ llm = std::make_unique<llm_build_grok>(*this, params);
16192
17148
  } break;
16193
17149
  case LLM_ARCH_STARCODER:
16194
17150
  {
16195
- llm = std::make_unique<llm_build_starcoder>(*this, params, gf);
17151
+ llm = std::make_unique<llm_build_starcoder>(*this, params);
16196
17152
  } break;
16197
17153
  case LLM_ARCH_REFACT:
16198
17154
  {
16199
- llm = std::make_unique<llm_build_refact>(*this, params, gf);
17155
+ llm = std::make_unique<llm_build_refact>(*this, params);
16200
17156
  } break;
16201
17157
  case LLM_ARCH_BERT:
16202
17158
  case LLM_ARCH_JINA_BERT_V2:
16203
17159
  case LLM_ARCH_NOMIC_BERT:
16204
17160
  case LLM_ARCH_NOMIC_BERT_MOE:
16205
17161
  {
16206
- llm = std::make_unique<llm_build_bert>(*this, params, gf);
17162
+ llm = std::make_unique<llm_build_bert>(*this, params);
16207
17163
  } break;
16208
17164
  case LLM_ARCH_NEO_BERT:
16209
17165
  {
16210
- llm = std::make_unique<llm_build_neo_bert>(*this, params, gf);
17166
+ llm = std::make_unique<llm_build_neo_bert>(*this, params);
16211
17167
  } break;
16212
17168
  case LLM_ARCH_BLOOM:
16213
17169
  {
16214
- llm = std::make_unique<llm_build_bloom>(*this, params, gf);
17170
+ llm = std::make_unique<llm_build_bloom>(*this, params);
16215
17171
  } break;
16216
17172
  case LLM_ARCH_MPT:
16217
17173
  {
16218
- llm = std::make_unique<llm_build_mpt>(*this, params, gf);
17174
+ llm = std::make_unique<llm_build_mpt>(*this, params);
16219
17175
  } break;
16220
17176
  case LLM_ARCH_STABLELM:
16221
17177
  {
16222
- llm = std::make_unique<llm_build_stablelm>(*this, params, gf);
17178
+ llm = std::make_unique<llm_build_stablelm>(*this, params);
16223
17179
  } break;
16224
17180
  case LLM_ARCH_QWEN:
16225
17181
  {
16226
- llm = std::make_unique<llm_build_qwen>(*this, params, gf);
17182
+ llm = std::make_unique<llm_build_qwen>(*this, params);
16227
17183
  } break;
16228
17184
  case LLM_ARCH_QWEN2:
16229
17185
  {
16230
- llm = std::make_unique<llm_build_qwen2>(*this, params, gf);
17186
+ llm = std::make_unique<llm_build_qwen2>(*this, params);
16231
17187
  } break;
17188
+ case LLM_ARCH_DREAM:
17189
+ {
17190
+ llm = std::make_unique<llm_build_dream>(*this, params);
17191
+ }
17192
+ break;
16232
17193
  case LLM_ARCH_QWEN2VL:
16233
17194
  {
16234
- llm = std::make_unique<llm_build_qwen2vl>(*this, params, gf);
17195
+ llm = std::make_unique<llm_build_qwen2vl>(*this, params);
16235
17196
  } break;
16236
17197
  case LLM_ARCH_QWEN2MOE:
16237
17198
  {
16238
- llm = std::make_unique<llm_build_qwen2moe>(*this, params, gf);
17199
+ llm = std::make_unique<llm_build_qwen2moe>(*this, params);
16239
17200
  } break;
16240
17201
  case LLM_ARCH_QWEN3:
16241
17202
  {
16242
- llm = std::make_unique<llm_build_qwen3>(*this, params, gf);
17203
+ llm = std::make_unique<llm_build_qwen3>(*this, params);
16243
17204
  } break;
16244
17205
  case LLM_ARCH_QWEN3MOE:
16245
17206
  {
16246
- llm = std::make_unique<llm_build_qwen3moe>(*this, params, gf);
17207
+ llm = std::make_unique<llm_build_qwen3moe>(*this, params);
16247
17208
  } break;
16248
17209
  case LLM_ARCH_PHI2:
16249
17210
  {
16250
- llm = std::make_unique<llm_build_phi2>(*this, params, gf);
17211
+ llm = std::make_unique<llm_build_phi2>(*this, params);
16251
17212
  } break;
16252
17213
  case LLM_ARCH_PHI3:
16253
17214
  case LLM_ARCH_PHIMOE:
16254
17215
  {
16255
17216
  if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
16256
- llm = std::make_unique<llm_build_phi3<true>> (*this, params, gf);
17217
+ llm = std::make_unique<llm_build_phi3<true>> (*this, params);
16257
17218
  } else {
16258
- llm = std::make_unique<llm_build_phi3<false>>(*this, params, gf);
17219
+ llm = std::make_unique<llm_build_phi3<false>>(*this, params);
16259
17220
  }
16260
17221
  } break;
16261
17222
  case LLM_ARCH_PLAMO:
16262
17223
  {
16263
- llm = std::make_unique<llm_build_plamo>(*this, params, gf);
17224
+ llm = std::make_unique<llm_build_plamo>(*this, params);
17225
+ } break;
17226
+ case LLM_ARCH_PLAMO2:
17227
+ {
17228
+ llm = std::make_unique<llm_build_plamo2>(*this, params);
16264
17229
  } break;
16265
17230
  case LLM_ARCH_GPT2:
16266
17231
  {
16267
- llm = std::make_unique<llm_build_gpt2>(*this, params, gf);
17232
+ llm = std::make_unique<llm_build_gpt2>(*this, params);
16268
17233
  } break;
16269
17234
  case LLM_ARCH_CODESHELL:
16270
17235
  {
16271
- llm = std::make_unique<llm_build_codeshell>(*this, params, gf);
17236
+ llm = std::make_unique<llm_build_codeshell>(*this, params);
16272
17237
  } break;
16273
17238
  case LLM_ARCH_ORION:
16274
17239
  {
16275
- llm = std::make_unique<llm_build_orion>(*this, params, gf);
17240
+ llm = std::make_unique<llm_build_orion>(*this, params);
16276
17241
  } break;
16277
17242
  case LLM_ARCH_INTERNLM2:
16278
17243
  {
16279
- llm = std::make_unique<llm_build_internlm2>(*this, params, gf);
17244
+ llm = std::make_unique<llm_build_internlm2>(*this, params);
16280
17245
  } break;
16281
17246
  case LLM_ARCH_MINICPM3:
16282
17247
  {
16283
- llm = std::make_unique<llm_build_minicpm3>(*this, params, gf);
17248
+ llm = std::make_unique<llm_build_minicpm3>(*this, params);
16284
17249
  } break;
16285
17250
  case LLM_ARCH_GEMMA:
16286
17251
  {
16287
- llm = std::make_unique<llm_build_gemma>(*this, params, gf);
17252
+ llm = std::make_unique<llm_build_gemma>(*this, params);
16288
17253
  } break;
16289
17254
  case LLM_ARCH_GEMMA2:
16290
17255
  {
16291
- llm = std::make_unique<llm_build_gemma2_iswa>(*this, params, gf);
17256
+ llm = std::make_unique<llm_build_gemma2_iswa>(*this, params);
16292
17257
  } break;
16293
17258
  case LLM_ARCH_GEMMA3:
16294
17259
  {
16295
- llm = std::make_unique<llm_build_gemma3_iswa>(*this, params, gf);
17260
+ llm = std::make_unique<llm_build_gemma3_iswa>(*this, params);
16296
17261
  } break;
16297
17262
  case LLM_ARCH_GEMMA3N:
16298
17263
  {
16299
- llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params, gf);
17264
+ llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params);
16300
17265
  } break;
16301
17266
  case LLM_ARCH_STARCODER2:
16302
17267
  {
16303
- llm = std::make_unique<llm_build_starcoder2>(*this, params, gf);
17268
+ llm = std::make_unique<llm_build_starcoder2>(*this, params);
16304
17269
  } break;
16305
17270
  case LLM_ARCH_MAMBA:
16306
17271
  case LLM_ARCH_MAMBA2:
16307
17272
  {
16308
- llm = std::make_unique<llm_build_mamba>(*this, params, gf);
17273
+ llm = std::make_unique<llm_build_mamba>(*this, params);
16309
17274
  } break;
16310
17275
  case LLM_ARCH_JAMBA:
16311
17276
  {
16312
- llm = std::make_unique<llm_build_jamba>(*this, params, gf);
17277
+ llm = std::make_unique<llm_build_jamba>(*this, params);
16313
17278
  } break;
16314
17279
  case LLM_ARCH_XVERSE:
16315
17280
  {
16316
- llm = std::make_unique<llm_build_xverse>(*this, params, gf);
17281
+ llm = std::make_unique<llm_build_xverse>(*this, params);
16317
17282
  } break;
16318
17283
  case LLM_ARCH_COMMAND_R:
16319
17284
  {
16320
- llm = std::make_unique<llm_build_command_r>(*this, params, gf);
17285
+ llm = std::make_unique<llm_build_command_r>(*this, params);
16321
17286
  } break;
16322
17287
  case LLM_ARCH_COHERE2:
16323
17288
  {
16324
- llm = std::make_unique<llm_build_cohere2_iswa>(*this, params, gf);
17289
+ llm = std::make_unique<llm_build_cohere2_iswa>(*this, params);
16325
17290
  } break;
16326
17291
  case LLM_ARCH_DBRX:
16327
17292
  {
16328
- llm = std::make_unique<llm_build_dbrx>(*this, params, gf);
17293
+ llm = std::make_unique<llm_build_dbrx>(*this, params);
16329
17294
  } break;
16330
17295
  case LLM_ARCH_OLMO:
16331
17296
  {
16332
- llm = std::make_unique<llm_build_olmo>(*this, params, gf);
17297
+ llm = std::make_unique<llm_build_olmo>(*this, params);
16333
17298
  } break;
16334
17299
  case LLM_ARCH_OLMO2:
16335
17300
  {
16336
- llm = std::make_unique<llm_build_olmo2>(*this, params, gf);
17301
+ llm = std::make_unique<llm_build_olmo2>(*this, params);
16337
17302
  } break;
16338
17303
  case LLM_ARCH_OLMOE:
16339
17304
  {
16340
- llm = std::make_unique<llm_build_olmoe>(*this, params, gf);
17305
+ llm = std::make_unique<llm_build_olmoe>(*this, params);
16341
17306
  } break;
16342
17307
  case LLM_ARCH_OPENELM:
16343
17308
  {
16344
- llm = std::make_unique<llm_build_openelm>(*this, params, gf);
17309
+ llm = std::make_unique<llm_build_openelm>(*this, params);
16345
17310
  } break;
16346
17311
  case LLM_ARCH_GPTNEOX:
16347
17312
  {
16348
- llm = std::make_unique<llm_build_gptneox>(*this, params, gf);
17313
+ llm = std::make_unique<llm_build_gptneox>(*this, params);
16349
17314
  } break;
16350
17315
  case LLM_ARCH_ARCTIC:
16351
17316
  {
16352
- llm = std::make_unique<llm_build_arctic>(*this, params, gf);
17317
+ llm = std::make_unique<llm_build_arctic>(*this, params);
16353
17318
  } break;
16354
17319
  case LLM_ARCH_DEEPSEEK:
16355
17320
  {
16356
- llm = std::make_unique<llm_build_deepseek>(*this, params, gf);
17321
+ llm = std::make_unique<llm_build_deepseek>(*this, params);
16357
17322
  } break;
16358
17323
  case LLM_ARCH_DEEPSEEK2:
16359
17324
  {
16360
- llm = std::make_unique<llm_build_deepseek2>(*this, params, gf);
17325
+ llm = std::make_unique<llm_build_deepseek2>(*this, params);
16361
17326
  } break;
16362
17327
  case LLM_ARCH_CHATGLM:
16363
17328
  {
16364
- llm = std::make_unique<llm_build_chatglm>(*this, params, gf);
17329
+ llm = std::make_unique<llm_build_chatglm>(*this, params);
16365
17330
  } break;
16366
17331
  case LLM_ARCH_GLM4:
16367
17332
  {
16368
- llm = std::make_unique<llm_build_glm4>(*this, params, gf);
17333
+ llm = std::make_unique<llm_build_glm4>(*this, params);
16369
17334
  } break;
16370
17335
  case LLM_ARCH_BITNET:
16371
17336
  {
16372
- llm = std::make_unique<llm_build_bitnet>(*this, params, gf);
17337
+ llm = std::make_unique<llm_build_bitnet>(*this, params);
16373
17338
  } break;
16374
17339
  case LLM_ARCH_T5:
16375
17340
  {
16376
- switch (type) {
17341
+ switch (params.gtype) {
16377
17342
  case LLM_GRAPH_TYPE_ENCODER:
16378
- llm = std::make_unique<llm_build_t5_enc>(*this, params, gf);
17343
+ llm = std::make_unique<llm_build_t5_enc>(*this, params);
16379
17344
  break;
16380
17345
  case LLM_GRAPH_TYPE_DEFAULT:
16381
17346
  case LLM_GRAPH_TYPE_DECODER:
16382
- llm = std::make_unique<llm_build_t5_dec>(*this, params, gf);
17347
+ llm = std::make_unique<llm_build_t5_dec>(*this, params);
16383
17348
  break;
16384
17349
  default:
16385
17350
  GGML_ABORT("invalid graph type");
@@ -16387,99 +17352,111 @@ llm_graph_result_ptr llama_model::build_graph(
16387
17352
  } break;
16388
17353
  case LLM_ARCH_T5ENCODER:
16389
17354
  {
16390
- llm = std::make_unique<llm_build_t5_enc>(*this, params, gf);
17355
+ llm = std::make_unique<llm_build_t5_enc>(*this, params);
16391
17356
  }
16392
17357
  break;
16393
17358
  case LLM_ARCH_JAIS:
16394
17359
  {
16395
- llm = std::make_unique<llm_build_jais>(*this, params, gf);
17360
+ llm = std::make_unique<llm_build_jais>(*this, params);
16396
17361
  } break;
16397
17362
  case LLM_ARCH_NEMOTRON:
16398
17363
  {
16399
- llm = std::make_unique<llm_build_nemotron>(*this, params, gf);
17364
+ llm = std::make_unique<llm_build_nemotron>(*this, params);
16400
17365
  } break;
16401
17366
  case LLM_ARCH_EXAONE:
16402
17367
  {
16403
- llm = std::make_unique<llm_build_exaone>(*this, params, gf);
17368
+ llm = std::make_unique<llm_build_exaone>(*this, params);
17369
+ } break;
17370
+ case LLM_ARCH_EXAONE4:
17371
+ {
17372
+ if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) {
17373
+ llm = std::make_unique<llm_build_exaone4<true>>(*this, params);
17374
+ } else {
17375
+ llm = std::make_unique<llm_build_exaone4<false>>(*this, params);
17376
+ }
16404
17377
  } break;
16405
17378
  case LLM_ARCH_RWKV6:
16406
17379
  {
16407
- llm = std::make_unique<llm_build_rwkv6>(*this, params, gf);
17380
+ llm = std::make_unique<llm_build_rwkv6>(*this, params);
16408
17381
  } break;
16409
17382
  case LLM_ARCH_RWKV6QWEN2:
16410
17383
  {
16411
- llm = std::make_unique<llm_build_rwkv6qwen2>(*this, params, gf);
17384
+ llm = std::make_unique<llm_build_rwkv6qwen2>(*this, params);
16412
17385
  } break;
16413
17386
  case LLM_ARCH_RWKV7:
16414
17387
  {
16415
- llm = std::make_unique<llm_build_rwkv7>(*this, params, gf);
17388
+ llm = std::make_unique<llm_build_rwkv7>(*this, params);
16416
17389
  } break;
16417
17390
  case LLM_ARCH_ARWKV7:
16418
17391
  {
16419
- llm = std::make_unique<llm_build_arwkv7>(*this, params, gf);
17392
+ llm = std::make_unique<llm_build_arwkv7>(*this, params);
16420
17393
  } break;
16421
17394
  case LLM_ARCH_GRANITE:
16422
17395
  case LLM_ARCH_GRANITE_MOE:
16423
17396
  case LLM_ARCH_MINICPM:
16424
17397
  {
16425
- llm = std::make_unique<llm_build_granite>(*this, params, gf);
17398
+ llm = std::make_unique<llm_build_granite>(*this, params);
16426
17399
  } break;
16427
17400
  case LLM_ARCH_GRANITE_HYBRID:
16428
17401
  {
16429
- llm = std::make_unique<llm_build_granite_hybrid>(*this, params, gf);
17402
+ llm = std::make_unique<llm_build_granite_hybrid>(*this, params);
16430
17403
  } break;
16431
17404
  case LLM_ARCH_CHAMELEON:
16432
17405
  {
16433
- llm = std::make_unique<llm_build_chameleon>(*this, params, gf);
17406
+ llm = std::make_unique<llm_build_chameleon>(*this, params);
16434
17407
  } break;
16435
17408
  case LLM_ARCH_WAVTOKENIZER_DEC:
16436
17409
  {
16437
- llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params, gf);
17410
+ llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params);
16438
17411
  } break;
16439
17412
  case LLM_ARCH_PLM:
16440
17413
  {
16441
- llm = std::make_unique<llm_build_plm>(*this, params, gf);
17414
+ llm = std::make_unique<llm_build_plm>(*this, params);
16442
17415
  } break;
16443
17416
  case LLM_ARCH_BAILINGMOE:
16444
17417
  {
16445
- llm = std::make_unique<llm_build_bailingmoe>(*this, params, gf);
17418
+ llm = std::make_unique<llm_build_bailingmoe>(*this, params);
16446
17419
  } break;
16447
17420
  case LLM_ARCH_DOTS1:
16448
17421
  {
16449
- llm = std::make_unique<llm_build_dots1>(*this, params, gf);
17422
+ llm = std::make_unique<llm_build_dots1>(*this, params);
16450
17423
  } break;
16451
17424
  case LLM_ARCH_ARCEE:
16452
17425
  {
16453
- llm = std::make_unique<llm_build_arcee>(*this, params, gf);
17426
+ llm = std::make_unique<llm_build_arcee>(*this, params);
16454
17427
  } break;
16455
17428
  case LLM_ARCH_ERNIE4_5:
16456
17429
  {
16457
- llm = std::make_unique<llm_build_ernie4_5>(*this, params, gf);
17430
+ llm = std::make_unique<llm_build_ernie4_5>(*this, params);
17431
+ } break;
17432
+ case LLM_ARCH_ERNIE4_5_MOE:
17433
+ {
17434
+ llm = std::make_unique<llm_build_ernie4_5_moe>(*this, params);
16458
17435
  } break;
16459
17436
  case LLM_ARCH_HUNYUAN_MOE:
16460
17437
  {
16461
- llm = std::make_unique<llm_build_hunyuan_moe>(*this, params, gf);
17438
+ llm = std::make_unique<llm_build_hunyuan_moe>(*this, params);
16462
17439
  } break;
16463
17440
  case LLM_ARCH_SMOLLM3:
16464
17441
  {
16465
- llm = std::make_unique<llm_build_smollm3>(*this, params, gf);
17442
+ llm = std::make_unique<llm_build_smollm3>(*this, params);
16466
17443
  } break;
16467
17444
  case LLM_ARCH_FALCON_H1:
16468
17445
  {
16469
- llm = std::make_unique<llm_build_falcon_h1>(*this, params, gf);
17446
+ llm = std::make_unique<llm_build_falcon_h1>(*this, params);
16470
17447
  } break;
16471
17448
  case LLM_ARCH_LFM2:
16472
17449
  {
16473
- llm = std::make_unique<llm_build_lfm2>(*this, params, gf);
17450
+ llm = std::make_unique<llm_build_lfm2>(*this, params);
16474
17451
  } break;
16475
17452
  default:
16476
17453
  GGML_ABORT("fatal error");
16477
17454
  }
16478
17455
 
16479
17456
  // add on pooling layer
16480
- llm->build_pooling(gf, cls, cls_b, cls_out, cls_out_b);
17457
+ llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
16481
17458
 
16482
- return std::move(llm->res);
17459
+ return llm->res->get_gf();
16483
17460
  }
16484
17461
 
16485
17462
  //
@@ -16628,6 +17605,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
16628
17605
  case LLM_ARCH_SMOLLM3:
16629
17606
  case LLM_ARCH_ARCEE:
16630
17607
  case LLM_ARCH_ERNIE4_5:
17608
+ case LLM_ARCH_ERNIE4_5_MOE:
16631
17609
  return LLAMA_ROPE_TYPE_NORM;
16632
17610
 
16633
17611
  // the pairs of head values are offset by n_rot/2
@@ -16642,6 +17620,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
16642
17620
  case LLM_ARCH_BITNET:
16643
17621
  case LLM_ARCH_QWEN:
16644
17622
  case LLM_ARCH_QWEN2:
17623
+ case LLM_ARCH_DREAM:
16645
17624
  case LLM_ARCH_QWEN2MOE:
16646
17625
  case LLM_ARCH_QWEN3:
16647
17626
  case LLM_ARCH_QWEN3MOE:
@@ -16651,6 +17630,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
16651
17630
  case LLM_ARCH_PHI3:
16652
17631
  case LLM_ARCH_PHIMOE:
16653
17632
  case LLM_ARCH_PLAMO:
17633
+ case LLM_ARCH_PLAMO2:
16654
17634
  case LLM_ARCH_GEMMA:
16655
17635
  case LLM_ARCH_GEMMA2:
16656
17636
  case LLM_ARCH_GEMMA3:
@@ -16662,6 +17642,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
16662
17642
  case LLM_ARCH_ORION:
16663
17643
  case LLM_ARCH_NEMOTRON:
16664
17644
  case LLM_ARCH_EXAONE:
17645
+ case LLM_ARCH_EXAONE4:
16665
17646
  case LLM_ARCH_MINICPM3:
16666
17647
  case LLM_ARCH_DOTS1:
16667
17648
  case LLM_ARCH_HUNYUAN_MOE: