@novastera-oss/llamarn 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190)
  1. package/android/build.gradle +2 -1
  2. package/android/proguard-rules.pro +12 -0
  3. package/android/src/main/cpp/include/llama.h +15 -47
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakePresets.json +11 -0
  22. package/cpp/llama.cpp/CODEOWNERS +1 -0
  23. package/cpp/llama.cpp/README.md +4 -3
  24. package/cpp/llama.cpp/common/arg.cpp +45 -1
  25. package/cpp/llama.cpp/common/common.cpp +22 -6
  26. package/cpp/llama.cpp/common/common.h +18 -4
  27. package/cpp/llama.cpp/convert_hf_to_gguf.py +500 -32
  28. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +12 -13
  29. package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -1
  30. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
  31. package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  32. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -0
  33. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
  34. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -0
  35. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +8 -20
  36. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +58 -3
  38. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +122 -16
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +5 -2
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +3 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +3 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +14 -4
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +64 -17
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -67
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +45 -62
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +28 -43
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +41 -56
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -47
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +31 -43
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +22 -37
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +73 -23
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -689
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +7 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +13 -1
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
  75. package/cpp/llama.cpp/ggml/src/ggml-impl.h +16 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +13 -3
  77. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +407 -69
  78. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +380 -83
  79. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +2 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +295 -2
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  85. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +131 -46
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +43 -43
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  94. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +287 -22
  95. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +1 -5
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  101. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  102. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
  105. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +8 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  107. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  108. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +71 -16
  109. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  112. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  114. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
  115. package/cpp/llama.cpp/ggml/src/ggml.c +4 -6
  116. package/cpp/llama.cpp/gguf-py/gguf/constants.py +98 -0
  117. package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
  118. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
  119. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +75 -52
  120. package/cpp/llama.cpp/include/llama.h +15 -7
  121. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
  122. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
  123. package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
  124. package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
  125. package/cpp/llama.cpp/src/llama-arch.cpp +106 -0
  126. package/cpp/llama.cpp/src/llama-arch.h +5 -0
  127. package/cpp/llama.cpp/src/llama-batch.cpp +76 -70
  128. package/cpp/llama.cpp/src/llama-batch.h +24 -18
  129. package/cpp/llama.cpp/src/llama-chat.cpp +43 -1
  130. package/cpp/llama.cpp/src/llama-chat.h +2 -0
  131. package/cpp/llama.cpp/src/llama-context.cpp +180 -106
  132. package/cpp/llama.cpp/src/llama-context.h +26 -16
  133. package/cpp/llama.cpp/src/llama-cparams.h +3 -2
  134. package/cpp/llama.cpp/src/llama-graph.cpp +203 -39
  135. package/cpp/llama.cpp/src/llama-graph.h +147 -72
  136. package/cpp/llama.cpp/src/llama-hparams.cpp +40 -0
  137. package/cpp/llama.cpp/src/llama-hparams.h +10 -2
  138. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +11 -5
  139. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +3 -0
  140. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +698 -302
  141. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +89 -31
  142. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +1 -0
  143. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +16 -1
  144. package/cpp/llama.cpp/src/llama-model.cpp +1293 -312
  145. package/cpp/llama.cpp/src/llama-model.h +3 -4
  146. package/cpp/llama.cpp/src/llama-quant.cpp +1 -2
  147. package/cpp/llama.cpp/src/llama-vocab.cpp +363 -8
  148. package/cpp/llama.cpp/src/llama-vocab.h +2 -0
  149. package/cpp/llama.cpp/src/unicode.cpp +207 -0
  150. package/cpp/llama.cpp/src/unicode.h +2 -0
  151. package/ios/include/common.h +18 -4
  152. package/ios/include/llama.h +15 -7
  153. package/ios/libs/llama.xcframework/Info.plist +15 -15
  154. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  155. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
  156. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -7
  157. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  158. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  159. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  160. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
  161. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  162. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  163. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  164. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  165. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3891
  166. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -7
  167. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -7
  168. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  169. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -7
  170. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  171. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  172. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  173. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
  174. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -7
  175. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  176. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  177. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  178. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
  179. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  180. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  181. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  182. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -5095
  183. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -7
  184. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  185. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  186. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -5066
  187. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3919
  188. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  189. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  190. package/package.json +4 -4
@@ -71,12 +71,15 @@ struct rhs_packing_info {
71
71
  std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
72
72
  std::function<size_t(size_t n, size_t k)>
73
73
  > packed_size;
74
+ size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl);
74
75
  std::variant<
75
76
  std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
76
77
  const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
77
78
  std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
78
79
  const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
79
80
  > pack_func;
81
+ void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride,
82
+ size_t kr, size_t bl, size_t num_bytes_multiplier);
80
83
  };
81
84
 
82
85
  struct ggml_kleidiai_kernels {
@@ -40,6 +40,17 @@ struct ggml_kleidiai_context {
40
40
  ggml_kleidiai_kernels * kernels;
41
41
  } static ctx = { CPU_FEATURE_NONE, NULL };
42
42
 
43
+ static const char* cpu_feature_to_string(cpu_feature f) {
44
+ switch (f) {
45
+ case CPU_FEATURE_NONE: return "NONE";
46
+ case CPU_FEATURE_DOTPROD: return "DOTPROD";
47
+ case CPU_FEATURE_I8MM: return "I8MM";
48
+ case CPU_FEATURE_SVE: return "SVE";
49
+ case CPU_FEATURE_SME: return "SME";
50
+ default: return "UNKNOWN";
51
+ }
52
+ }
53
+
43
54
  static void init_kleidiai_context(void) {
44
55
 
45
56
  ggml_critical_section_start();
@@ -62,6 +73,11 @@ static void init_kleidiai_context(void) {
62
73
  ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
63
74
  }
64
75
  ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
76
+ #ifndef NDEBUG
77
+ if (ctx.kernels) {
78
+ GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
79
+ }
80
+ #endif
65
81
  }
66
82
  ggml_critical_section_end();
67
83
  }
@@ -102,6 +118,9 @@ static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint1
102
118
 
103
119
  class tensor_traits : public ggml::cpu::tensor_traits {
104
120
  bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
121
+ if (op->op != GGML_OP_MUL_MAT) {
122
+ return false;
123
+ }
105
124
  ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
106
125
  GGML_ASSERT(kernels);
107
126
  kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
@@ -135,6 +154,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
135
154
  } else if (dst->src[0]->type == GGML_TYPE_F16) {
136
155
  return compute_forward_kv_cache(params, dst);
137
156
  }
157
+ } else if (dst->op == GGML_OP_GET_ROWS) {
158
+ if (dst->src[0]->type == GGML_TYPE_Q4_0) {
159
+ return compute_forward_get_rows(params, dst);
160
+ }
138
161
  }
139
162
  return false;
140
163
  }
@@ -270,6 +293,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
270
293
  }
271
294
 
272
295
  bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
296
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
297
+
273
298
  const ggml_tensor * src0 = dst->src[0];
274
299
  const ggml_tensor * src1 = dst->src[1];
275
300
 
@@ -342,8 +367,49 @@ class tensor_traits : public ggml::cpu::tensor_traits {
342
367
  return true;
343
368
  }
344
369
 
370
+ bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
371
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
372
+ GGML_ASSERT(ctx.kernels);
373
+
374
+ const ggml_tensor * src0 = dst->src[0];
375
+ const ggml_tensor * src1 = dst->src[1];
376
+
377
+ GGML_TENSOR_BINARY_OP_LOCALS
378
+
379
+ rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
380
+ kernel_info * kernel = &ctx.kernels->gemm;
381
+
382
+ const int64_t nc = ne00;
383
+ const int64_t nr = ggml_nelements(src1);
384
+
385
+ const size_t block_rows = kernel->get_nr();
386
+ const size_t kr = kernel->get_kr();
387
+
388
+ const size_t num_bytes_multiplier = sizeof(uint16_t);
389
+ const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
390
+
391
+ const int ith = params->ith;
392
+ const int nth = params->nth;
393
+
394
+ const int dr = (nr + nth - 1) / nth;
395
+ const int ir0 = dr * ith;
396
+ const int ir1 = MIN(ir0 + dr, nr);
397
+
398
+ for (int64_t i = ir0; i < ir1; ++i) {
399
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
400
+ int64_t row_idx = ((const int32_t *)src1->data)[i];
401
+ GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
402
+
403
+ float *out = (float *)((char *)dst->data + i * nb1);
404
+ rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
405
+ }
406
+
407
+ return true;
408
+ }
409
+
345
410
  public:
346
411
  int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
412
+ GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
347
413
  GGML_ASSERT(ctx.kernels);
348
414
  const size_t n = tensor->ne[1];
349
415
  const size_t k = tensor->ne[0];
@@ -351,17 +417,12 @@ public:
351
417
  size_t kr = ctx.kernels->gemm.get_kr();
352
418
  size_t sr = ctx.kernels->gemm.get_sr();
353
419
 
354
- #ifndef NDEBUG
355
- const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
356
- GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
357
- #endif
358
420
  struct kai_rhs_pack_qs4cxs1s0_param params;
359
421
  params.lhs_zero_point = 1;
360
422
  params.rhs_zero_point = 8;
361
423
  variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);
362
424
 
363
425
  return 0;
364
-
365
426
  GGML_UNUSED(data_size);
366
427
  }
367
428
  };
@@ -375,8 +436,8 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
375
436
  static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
376
437
  tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
377
438
 
378
- GGML_UNUSED(buffer);
379
439
  return GGML_STATUS_SUCCESS;
440
+ GGML_UNUSED(buffer);
380
441
  }
381
442
 
382
443
  static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
@@ -418,18 +479,35 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
418
479
  GGML_UNUSED(buft);
419
480
  }
420
481
 
482
+ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
483
+ GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
484
+ GGML_ASSERT(ctx.kernels);
485
+
486
+ const size_t n = tensor->ne[1];
487
+ const size_t k = tensor->ne[0];
488
+ const size_t nr = ctx.kernels->gemm.get_nr();
489
+ const size_t kr = ctx.kernels->gemm.get_kr();
490
+
491
+ return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
492
+
493
+ GGML_UNUSED(buft);
494
+ }
495
+
421
496
  namespace ggml::cpu::kleidiai {
422
497
  class extra_buffer_type : ggml::cpu::extra_buffer_type {
423
498
  bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
424
- if (op->op == GGML_OP_MUL_MAT &&
499
+ if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
425
500
  op->src[0]->type == GGML_TYPE_Q4_0 &&
426
501
  op->src[0]->buffer &&
427
502
  (ggml_n_dims(op->src[0]) == 2) &&
428
503
  op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
504
+ if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) {
505
+ return false;
506
+ }
429
507
  if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
430
508
  return false;
431
509
  }
432
- if (op->src[1]->type == GGML_TYPE_F32 &&
510
+ if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
433
511
  ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
434
512
  return true;
435
513
  }
@@ -438,7 +516,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
438
516
  }
439
517
 
440
518
  ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
441
- if (op->op == GGML_OP_MUL_MAT) {
519
+ if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
442
520
  if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
443
521
  return (ggml::cpu::tensor_traits *) op->src[0]->extra;
444
522
  }
@@ -469,7 +547,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
469
547
  /* .alloc_buffer = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
470
548
  /* .get_alignment = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
471
549
  /* .get_max_size = */ nullptr, // defaults to SIZE_MAX
472
- /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
550
+ /* .get_alloc_size = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
473
551
  /* .is_host = */ nullptr,
474
552
  },
475
553
  /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),