@novastera-oss/llamarn 0.2.7 → 0.2.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186) hide show
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +56 -22
  11. package/cpp/build-info.cpp +2 -2
  12. package/cpp/llama.cpp/CMakeLists.txt +1 -1
  13. package/cpp/llama.cpp/common/arg.cpp +7 -0
  14. package/cpp/llama.cpp/common/common.cpp +3 -0
  15. package/cpp/llama.cpp/common/common.h +1 -0
  16. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  17. package/cpp/llama.cpp/convert_hf_to_gguf.py +118 -20
  18. package/cpp/llama.cpp/ggml/CMakeLists.txt +1 -0
  19. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  20. package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
  21. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -0
  22. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  23. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +31 -2
  24. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  25. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  26. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  27. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  28. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  29. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  30. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  31. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  32. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  33. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  34. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  35. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +83 -102
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +192 -67
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +211 -33
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
  48. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +54 -29
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +84 -31
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
  62. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +227 -41
  64. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +362 -182
  65. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  66. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +240 -535
  67. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  68. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  69. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  70. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  71. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  72. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  73. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  74. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  75. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
  76. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  77. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +45 -54
  78. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  79. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  80. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  81. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  82. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  83. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
  84. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  89. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  90. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +57 -1
  91. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  92. package/cpp/llama.cpp/ggml/src/ggml.c +69 -13
  93. package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
  94. package/cpp/llama.cpp/gguf-py/gguf/constants.py +76 -0
  95. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +21 -0
  96. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +64 -0
  97. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
  98. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  99. package/cpp/llama.cpp/include/llama.h +8 -3
  100. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  101. package/cpp/llama.cpp/src/llama-arch.cpp +55 -0
  102. package/cpp/llama.cpp/src/llama-arch.h +18 -0
  103. package/cpp/llama.cpp/src/llama-batch.cpp +570 -359
  104. package/cpp/llama.cpp/src/llama-batch.h +98 -70
  105. package/cpp/llama.cpp/src/llama-chat.cpp +11 -6
  106. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  107. package/cpp/llama.cpp/src/llama-context.h +13 -13
  108. package/cpp/llama.cpp/src/llama-graph.cpp +199 -252
  109. package/cpp/llama.cpp/src/llama-graph.h +44 -32
  110. package/cpp/llama.cpp/src/llama-hparams.cpp +4 -0
  111. package/cpp/llama.cpp/src/llama-hparams.h +8 -0
  112. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +51 -53
  113. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +19 -24
  114. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +110 -104
  115. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +17 -22
  116. package/cpp/llama.cpp/src/llama-kv-cells.h +35 -11
  117. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +66 -67
  118. package/cpp/llama.cpp/src/llama-memory-hybrid.h +16 -21
  119. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +69 -68
  120. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  121. package/cpp/llama.cpp/src/llama-memory.h +18 -22
  122. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  123. package/cpp/llama.cpp/src/llama-model.cpp +1006 -472
  124. package/cpp/llama.cpp/src/llama-model.h +22 -0
  125. package/cpp/llama.cpp/src/llama-quant.cpp +87 -5
  126. package/cpp/llama.cpp/src/llama-vocab.cpp +26 -3
  127. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  128. package/cpp/rn-utils.h +3 -0
  129. package/ios/include/common.h +1 -0
  130. package/ios/include/llama.h +8 -3
  131. package/ios/libs/llama.xcframework/Info.plist +19 -19
  132. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  133. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
  134. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  135. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
  136. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -3
  137. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  138. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  139. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  140. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
  141. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  142. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  143. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  144. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  145. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  146. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  147. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3744
  148. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  149. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
  150. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -3
  151. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  152. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
  153. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -3
  154. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  155. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  156. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
  157. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -3
  158. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  159. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  160. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  161. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
  162. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  163. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
  164. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -3
  165. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  166. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  167. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  168. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
  169. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  170. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  171. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  172. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  173. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  174. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4900
  175. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  176. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
  177. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -3
  178. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  179. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  180. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4871
  181. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3773
  182. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  183. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  184. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  185. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  186. package/package.json +1 -1
@@ -35,6 +35,17 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
35
35
  -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
36
36
  };
37
37
 
38
+ static inline int best_index_int8(int n, constant float * val, float x) {
39
+ if (x <= val[0]) return 0;
40
+ if (x >= val[n-1]) return n-1;
41
+ int ml = 0, mu = n-1;
42
+ while (mu-ml > 1) {
43
+ int mav = (ml+mu)/2;
44
+ if (x < val[mav]) mu = mav; else ml = mav;
45
+ }
46
+ return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
47
+ }
48
+
38
49
  // NOTE: this is not dequantizing - we are simply fitting the template
39
50
  template <typename type4x4>
40
51
  void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -97,6 +108,173 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
97
108
  }
98
109
  }
99
110
 
111
+ void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
112
+ float amax = 0.0f; // absolute max
113
+ float max = 0.0f;
114
+
115
+ for (int j = 0; j < QK4_0; j++) {
116
+ const float v = src[j];
117
+ if (amax < fabs(v)) {
118
+ amax = fabs(v);
119
+ max = v;
120
+ }
121
+ }
122
+
123
+ const float d = max / -8;
124
+ const float id = d ? 1.0f/d : 0.0f;
125
+
126
+ dst.d = d;
127
+
128
+ for (int j = 0; j < QK4_0/2; ++j) {
129
+ const float x0 = src[0 + j]*id;
130
+ const float x1 = src[QK4_0/2 + j]*id;
131
+
132
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
133
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
134
+
135
+ dst.qs[j] = xi0;
136
+ dst.qs[j] |= xi1 << 4;
137
+ }
138
+ }
139
+
140
+ void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
141
+ float min = FLT_MAX;
142
+ float max = -FLT_MAX;
143
+
144
+ for (int j = 0; j < QK4_1; j++) {
145
+ const float v = src[j];
146
+ if (min > v) min = v;
147
+ if (max < v) max = v;
148
+ }
149
+
150
+ const float d = (max - min) / ((1 << 4) - 1);
151
+ const float id = d ? 1.0f/d : 0.0f;
152
+
153
+ dst.d = d;
154
+ dst.m = min;
155
+
156
+ for (int j = 0; j < QK4_1/2; ++j) {
157
+ const float x0 = (src[0 + j] - min)*id;
158
+ const float x1 = (src[QK4_1/2 + j] - min)*id;
159
+
160
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
161
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
162
+
163
+ dst.qs[j] = xi0;
164
+ dst.qs[j] |= xi1 << 4;
165
+ }
166
+ }
167
+
168
+ void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
169
+ float amax = 0.0f; // absolute max
170
+ float max = 0.0f;
171
+
172
+ for (int j = 0; j < QK5_0; j++) {
173
+ const float v = src[j];
174
+ if (amax < fabs(v)) {
175
+ amax = fabs(v);
176
+ max = v;
177
+ }
178
+ }
179
+
180
+ const float d = max / -16;
181
+ const float id = d ? 1.0f/d : 0.0f;
182
+
183
+ dst.d = d;
184
+
185
+ uint32_t qh = 0;
186
+ for (int j = 0; j < QK5_0/2; ++j) {
187
+ const float x0 = src[0 + j]*id;
188
+ const float x1 = src[QK5_0/2 + j]*id;
189
+
190
+ const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
191
+ const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
192
+
193
+ dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
194
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
195
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
196
+ }
197
+
198
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
199
+
200
+ for (int j = 0; j < 4; ++j) {
201
+ dst.qh[j] = qh8[j];
202
+ }
203
+ }
204
+
205
+ void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
206
+ float max = src[0];
207
+ float min = src[0];
208
+
209
+ for (int j = 1; j < QK5_1; j++) {
210
+ const float v = src[j];
211
+ min = v < min ? v : min;
212
+ max = v > max ? v : max;
213
+ }
214
+
215
+ const float d = (max - min) / 31;
216
+ const float id = d ? 1.0f/d : 0.0f;
217
+
218
+ dst.d = d;
219
+ dst.m = min;
220
+
221
+ uint32_t qh = 0;
222
+ for (int j = 0; j < QK5_1/2; ++j) {
223
+ const float x0 = (src[0 + j] - min)*id;
224
+ const float x1 = (src[QK5_1/2 + j] - min)*id;
225
+
226
+ const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
227
+ const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
228
+
229
+ dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
230
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
231
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
232
+ }
233
+
234
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
235
+
236
+ for (int j = 0; j < 4; ++j) {
237
+ dst.qh[j] = qh8[j];
238
+ }
239
+ }
240
+
241
+ void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
242
+ float amax = 0.0f; // absolute max
243
+ float max = 0.0f;
244
+
245
+ for (int j = 0; j < QK4_NL; j++) {
246
+ const float v = src[j];
247
+ if (amax < fabs(v)) {
248
+ amax = fabs(v);
249
+ max = v;
250
+ }
251
+ }
252
+
253
+ const float d = max / kvalues_iq4nl_f[0];
254
+ const float id = d ? 1.0f/d : 0.0f;
255
+
256
+ float sumqx = 0, sumq2 = 0;
257
+ for (int j = 0; j < QK4_NL/2; ++j) {
258
+ const float x0 = src[0 + j]*id;
259
+ const float x1 = src[QK4_NL/2 + j]*id;
260
+
261
+ const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
262
+ const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
263
+
264
+ dst.qs[j] = xi0 | (xi1 << 4);
265
+
266
+ const float v0 = kvalues_iq4nl_f[xi0];
267
+ const float v1 = kvalues_iq4nl_f[xi1];
268
+ const float w0 = src[0 + j]*src[0 + j];
269
+ const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
270
+ sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
271
+ sumq2 += w0*v0*v0 + w1*v1*v1;
272
+
273
+ }
274
+
275
+ dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
276
+ }
277
+
100
278
  template <typename type4x4>
101
279
  void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
102
280
  device const uint16_t * qs = ((device const uint16_t *)xb + 2);
@@ -279,6 +457,26 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
279
457
  }
280
458
  }
281
459
 
460
+ void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
461
+ float amax = 0.0f; // absolute max
462
+
463
+ for (int j = 0; j < QK8_0; j++) {
464
+ const float v = src[j];
465
+ amax = MAX(amax, fabs(v));
466
+ }
467
+
468
+ const float d = amax / ((1 << 7) - 1);
469
+ const float id = d ? 1.0f/d : 0.0f;
470
+
471
+ dst.d = d;
472
+
473
+ for (int j = 0; j < QK8_0; ++j) {
474
+ const float x0 = src[j]*id;
475
+
476
+ dst.qs[j] = round(x0);
477
+ }
478
+ }
479
+
282
480
  template <typename type4x4>
283
481
  void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
284
482
  const float d = xb->d;
@@ -2532,6 +2730,70 @@ template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<
2532
2730
  template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
2533
2731
  #endif
2534
2732
 
2733
+ template<typename T04, typename T14, typename args_t>
2734
+ void kernel_mul_mv_c4_impl(
2735
+ args_t args,
2736
+ device const char * src0,
2737
+ device const char * src1,
2738
+ device char * dst,
2739
+ uint3 tgpig,
2740
+ ushort tiisg) {
2741
+ const int r0 = tgpig.x*32 + tiisg;
2742
+ const int rb = tgpig.y*N_MV_T_T;
2743
+ const int im = tgpig.z;
2744
+
2745
+ if (r0 >= args.ne01) {
2746
+ return;
2747
+ }
2748
+
2749
+ const uint i12 = im%args.ne12;
2750
+ const uint i13 = im/args.ne12;
2751
+
2752
+ const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
2753
+
2754
+ device const T04 * x = (device const T04 *) (src0 + offset0);
2755
+
2756
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
2757
+
2758
+ for (int row = 0; row < N_MV_T_T; ++row) {
2759
+ int r1 = rb + row;
2760
+ if (r1 >= args.ne11) {
2761
+ break;
2762
+ }
2763
+
2764
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
2765
+
2766
+ device const T14 * y = (device const T14 *) (src1 + offset1);
2767
+
2768
+ dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]);
2769
+ }
2770
+ }
2771
+
2772
+ template<typename T04, typename T14>
2773
+ kernel void kernel_mul_mv_c4(
2774
+ constant ggml_metal_kargs_mul_mv & args,
2775
+ device const char * src0,
2776
+ device const char * src1,
2777
+ device char * dst,
2778
+ uint3 tgpig[[threadgroup_position_in_grid]],
2779
+ ushort tiisg[[thread_index_in_simdgroup]]) {
2780
+ kernel_mul_mv_c4_impl<T04, T14, constant ggml_metal_kargs_mul_mv &>(
2781
+ args,
2782
+ src0,
2783
+ src1,
2784
+ dst,
2785
+ tgpig,
2786
+ tiisg);
2787
+ }
2788
+
2789
+ typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;
2790
+
2791
+ template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<float4, float4>;
2792
+ template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, float4>;
2793
+ #if defined(GGML_METAL_USE_BF16)
2794
+ template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
2795
+ #endif
2796
+
2535
2797
  template<typename T, typename T4>
2536
2798
  kernel void kernel_mul_mv_1row(
2537
2799
  constant ggml_metal_kargs_mul_mv & args,
@@ -4306,11 +4568,16 @@ kernel void kernel_cpy(
4306
4568
  device const char * src0,
4307
4569
  device char * dst,
4308
4570
  uint3 tgpig[[threadgroup_position_in_grid]],
4571
+ uint tiitg[[thread_index_in_threadgroup]],
4309
4572
  ushort3 tpitg[[thread_position_in_threadgroup]],
4310
- ushort3 ntg[[threads_per_threadgroup]]) {
4573
+ ushort3 tptg[[threads_per_threadgroup]]) {
4311
4574
  const int i03 = tgpig[2];
4312
4575
  const int i02 = tgpig[1];
4313
- const int i01 = tgpig[0];
4576
+ const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x;
4577
+
4578
+ if (i01 >= args.ne01) {
4579
+ return;
4580
+ }
4314
4581
 
4315
4582
  const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
4316
4583
 
@@ -4321,7 +4588,7 @@ kernel void kernel_cpy(
4321
4588
 
4322
4589
  device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
4323
4590
 
4324
- for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
4591
+ for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) {
4325
4592
  device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4326
4593
  dst_data[i00] = (T1) src[0];
4327
4594
  }
@@ -4341,6 +4608,7 @@ template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bf
4341
4608
  template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
4342
4609
  #endif
4343
4610
 
4611
+ // TODO: templetify these kernels
4344
4612
  kernel void kernel_cpy_f32_q8_0(
4345
4613
  constant ggml_metal_kargs_cpy & args,
4346
4614
  device const char * src0,
@@ -4364,23 +4632,7 @@ kernel void kernel_cpy_f32_q8_0(
4364
4632
  for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
4365
4633
  device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4366
4634
 
4367
- float amax = 0.0f; // absolute max
4368
-
4369
- for (int j = 0; j < QK8_0; j++) {
4370
- const float v = src[j];
4371
- amax = MAX(amax, fabs(v));
4372
- }
4373
-
4374
- const float d = amax / ((1 << 7) - 1);
4375
- const float id = d ? 1.0f/d : 0.0f;
4376
-
4377
- dst_data[i00/QK8_0].d = d;
4378
-
4379
- for (int j = 0; j < QK8_0; ++j) {
4380
- const float x0 = src[j]*id;
4381
-
4382
- dst_data[i00/QK8_0].qs[j] = round(x0);
4383
- }
4635
+ quantize_q8_0(src, dst_data[i00/QK8_0]);
4384
4636
  }
4385
4637
  }
4386
4638
 
@@ -4407,32 +4659,7 @@ kernel void kernel_cpy_f32_q4_0(
4407
4659
  for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
4408
4660
  device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4409
4661
 
4410
- float amax = 0.0f; // absolute max
4411
- float max = 0.0f;
4412
-
4413
- for (int j = 0; j < QK4_0; j++) {
4414
- const float v = src[j];
4415
- if (amax < fabs(v)) {
4416
- amax = fabs(v);
4417
- max = v;
4418
- }
4419
- }
4420
-
4421
- const float d = max / -8;
4422
- const float id = d ? 1.0f/d : 0.0f;
4423
-
4424
- dst_data[i00/QK4_0].d = d;
4425
-
4426
- for (int j = 0; j < QK4_0/2; ++j) {
4427
- const float x0 = src[0 + j]*id;
4428
- const float x1 = src[QK4_0/2 + j]*id;
4429
-
4430
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
4431
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
4432
-
4433
- dst_data[i00/QK4_0].qs[j] = xi0;
4434
- dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
4435
- }
4662
+ quantize_q4_0(src, dst_data[i00/QK4_0]);
4436
4663
  }
4437
4664
  }
4438
4665
 
@@ -4459,31 +4686,7 @@ kernel void kernel_cpy_f32_q4_1(
4459
4686
  for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
4460
4687
  device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4461
4688
 
4462
- float min = FLT_MAX;
4463
- float max = -FLT_MAX;
4464
-
4465
- for (int j = 0; j < QK4_1; j++) {
4466
- const float v = src[j];
4467
- if (min > v) min = v;
4468
- if (max < v) max = v;
4469
- }
4470
-
4471
- const float d = (max - min) / ((1 << 4) - 1);
4472
- const float id = d ? 1.0f/d : 0.0f;
4473
-
4474
- dst_data[i00/QK4_1].d = d;
4475
- dst_data[i00/QK4_1].m = min;
4476
-
4477
- for (int j = 0; j < QK4_1/2; ++j) {
4478
- const float x0 = (src[0 + j] - min)*id;
4479
- const float x1 = (src[QK4_1/2 + j] - min)*id;
4480
-
4481
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
4482
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
4483
-
4484
- dst_data[i00/QK4_1].qs[j] = xi0;
4485
- dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
4486
- }
4689
+ quantize_q4_1(src, dst_data[i00/QK4_1]);
4487
4690
  }
4488
4691
  }
4489
4692
 
@@ -4510,38 +4713,7 @@ kernel void kernel_cpy_f32_q5_0(
4510
4713
  for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
4511
4714
  device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4512
4715
 
4513
- float amax = 0.0f; // absolute max
4514
- float max = 0.0f;
4515
-
4516
- for (int j = 0; j < QK5_0; j++) {
4517
- const float v = src[j];
4518
- if (amax < fabs(v)) {
4519
- amax = fabs(v);
4520
- max = v;
4521
- }
4522
- }
4523
-
4524
- const float d = max / -16;
4525
- const float id = d ? 1.0f/d : 0.0f;
4526
-
4527
- dst_data[i00/QK5_0].d = d;
4528
-
4529
- uint32_t qh = 0;
4530
- for (int j = 0; j < QK5_0/2; ++j) {
4531
- const float x0 = src[0 + j]*id;
4532
- const float x1 = src[QK5_0/2 + j]*id;
4533
-
4534
- const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
4535
- const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
4536
-
4537
- dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
4538
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
4539
- qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
4540
- }
4541
- thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
4542
- for (int j = 0; j < 4; ++j) {
4543
- dst_data[i00/QK5_0].qh[j] = qh8[j];
4544
- }
4716
+ quantize_q5_0(src, dst_data[i00/QK5_0]);
4545
4717
  }
4546
4718
  }
4547
4719
 
@@ -4568,49 +4740,8 @@ kernel void kernel_cpy_f32_q5_1(
4568
4740
  for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
4569
4741
  device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4570
4742
 
4571
- float max = src[0];
4572
- float min = src[0];
4573
-
4574
- for (int j = 1; j < QK5_1; j++) {
4575
- const float v = src[j];
4576
- min = v < min ? v : min;
4577
- max = v > max ? v : max;
4578
- }
4579
-
4580
- const float d = (max - min) / 31;
4581
- const float id = d ? 1.0f/d : 0.0f;
4582
-
4583
- dst_data[i00/QK5_1].d = d;
4584
- dst_data[i00/QK5_1].m = min;
4585
-
4586
- uint32_t qh = 0;
4587
- for (int j = 0; j < QK5_1/2; ++j) {
4588
- const float x0 = (src[0 + j] - min)*id;
4589
- const float x1 = (src[QK5_1/2 + j] - min)*id;
4590
-
4591
- const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
4592
- const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
4593
-
4594
- dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
4595
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
4596
- qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
4597
- }
4598
- thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
4599
- for (int j = 0; j < 4; ++j) {
4600
- dst_data[i00/QK5_1].qh[j] = qh8[j];
4601
- }
4602
- }
4603
- }
4604
-
4605
- static inline int best_index_int8(int n, constant float * val, float x) {
4606
- if (x <= val[0]) return 0;
4607
- if (x >= val[n-1]) return n-1;
4608
- int ml = 0, mu = n-1;
4609
- while (mu-ml > 1) {
4610
- int mav = (ml+mu)/2;
4611
- if (x < val[mav]) mu = mav; else ml = mav;
4743
+ quantize_q5_1(src, dst_data[i00/QK5_1]);
4612
4744
  }
4613
- return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
4614
4745
  }
4615
4746
 
4616
4747
  kernel void kernel_cpy_f32_iq4_nl(
@@ -4636,40 +4767,7 @@ kernel void kernel_cpy_f32_iq4_nl(
4636
4767
  for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
4637
4768
  device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4638
4769
 
4639
- float amax = 0.0f; // absolute max
4640
- float max = 0.0f;
4641
-
4642
- for (int j = 0; j < QK4_NL; j++) {
4643
- const float v = src[j];
4644
- if (amax < fabs(v)) {
4645
- amax = fabs(v);
4646
- max = v;
4647
- }
4648
- }
4649
-
4650
- const float d = max / kvalues_iq4nl_f[0];
4651
- const float id = d ? 1.0f/d : 0.0f;
4652
-
4653
- float sumqx = 0, sumq2 = 0;
4654
- for (int j = 0; j < QK4_NL/2; ++j) {
4655
- const float x0 = src[0 + j]*id;
4656
- const float x1 = src[QK4_NL/2 + j]*id;
4657
-
4658
- const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
4659
- const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
4660
-
4661
- dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
4662
-
4663
- const float v0 = kvalues_iq4nl_f[xi0];
4664
- const float v1 = kvalues_iq4nl_f[xi1];
4665
- const float w0 = src[0 + j]*src[0 + j];
4666
- const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
4667
- sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
4668
- sumq2 += w0*v0*v0 + w1*v1*v1;
4669
-
4670
- }
4671
-
4672
- dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
4770
+ quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
4673
4771
  }
4674
4772
  }
4675
4773
 
@@ -6350,10 +6448,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(
6350
6448
 
6351
6449
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
6352
6450
  kernel void kernel_get_rows_q(
6451
+ constant ggml_metal_kargs_get_rows & args,
6353
6452
  device const void * src0,
6354
6453
  device const void * src1,
6355
6454
  device float * dst,
6356
- constant ggml_metal_kargs_get_rows & args,
6357
6455
  uint3 tgpig[[threadgroup_position_in_grid]],
6358
6456
  uint tiitg[[thread_index_in_threadgroup]],
6359
6457
  uint3 tptg [[threads_per_threadgroup]]) {
@@ -6373,10 +6471,10 @@ kernel void kernel_get_rows_q(
6373
6471
 
6374
6472
  template<typename T>
6375
6473
  kernel void kernel_get_rows_f(
6474
+ constant ggml_metal_kargs_get_rows & args,
6376
6475
  device const void * src0,
6377
6476
  device const void * src1,
6378
6477
  device float * dst,
6379
- constant ggml_metal_kargs_get_rows & args,
6380
6478
  uint3 tgpig[[threadgroup_position_in_grid]],
6381
6479
  uint tiitg[[thread_index_in_threadgroup]],
6382
6480
  uint3 tptg [[threads_per_threadgroup]]) {
@@ -6394,10 +6492,10 @@ kernel void kernel_get_rows_f(
6394
6492
  }
6395
6493
 
6396
6494
  kernel void kernel_get_rows_i32(
6495
+ constant ggml_metal_kargs_get_rows & args,
6397
6496
  device const void * src0,
6398
6497
  device const void * src1,
6399
6498
  device int32_t * dst,
6400
- constant ggml_metal_kargs_get_rows & args,
6401
6499
  uint3 tgpig[[threadgroup_position_in_grid]],
6402
6500
  uint tiitg[[thread_index_in_threadgroup]],
6403
6501
  uint3 tptg [[threads_per_threadgroup]]) {
@@ -6414,6 +6512,67 @@ kernel void kernel_get_rows_i32(
6414
6512
  }
6415
6513
  }
6416
6514
 
6515
+ template<typename block_q, void (*quantize_func)(device const float *, device block_q &)>
6516
+ kernel void kernel_set_rows_q32(
6517
+ constant ggml_metal_kargs_set_rows & args,
6518
+ device const void * src0,
6519
+ device const void * src1,
6520
+ device float * dst,
6521
+ uint3 tgpig[[threadgroup_position_in_grid]],
6522
+ uint tiitg[[thread_index_in_threadgroup]],
6523
+ uint3 tptg [[threads_per_threadgroup]]) {
6524
+ const int32_t i03 = tgpig.z;
6525
+ const int32_t i02 = tgpig.y;
6526
+
6527
+ const int32_t i12 = i03%args.ne12;
6528
+ const int32_t i11 = i02%args.ne11;
6529
+
6530
+ const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
6531
+ if (i01 >= args.ne01) {
6532
+ return;
6533
+ }
6534
+
6535
+ const int32_t i10 = i01;
6536
+ const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
6537
+
6538
+ device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
6539
+ const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
6540
+
6541
+ for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
6542
+ quantize_func(src_row + 32*ind, dst_row[ind]);
6543
+ }
6544
+ }
6545
+
6546
+ template<typename T>
6547
+ kernel void kernel_set_rows_f(
6548
+ constant ggml_metal_kargs_set_rows & args,
6549
+ device const void * src0,
6550
+ device const void * src1,
6551
+ device float * dst,
6552
+ uint3 tgpig[[threadgroup_position_in_grid]],
6553
+ uint tiitg[[thread_index_in_threadgroup]],
6554
+ uint3 tptg [[threads_per_threadgroup]]) {
6555
+ const int32_t i03 = tgpig.z;
6556
+ const int32_t i02 = tgpig.y;
6557
+
6558
+ const int32_t i12 = i03%args.ne12;
6559
+ const int32_t i11 = i02%args.ne11;
6560
+
6561
+ const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
6562
+ if (i01 >= args.ne01) {
6563
+ return;
6564
+ }
6565
+
6566
+ const int32_t i10 = i01;
6567
+ const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
6568
+
6569
+ device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
6570
+ const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
6571
+
6572
+ for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
6573
+ dst_row[ind] = (T) src_row[ind];
6574
+ }
6575
+ }
6417
6576
 
6418
6577
  #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
6419
6578
  #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
@@ -6837,6 +6996,27 @@ template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get
6837
6996
  template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
6838
6997
  template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6839
6998
 
6999
+ //
7000
+ // set rows
7001
+ //
7002
+
7003
+ typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
7004
+
7005
+ template [[host_name("kernel_set_rows_f32")]] kernel set_rows_f_t kernel_set_rows_f<float>;
7006
+ template [[host_name("kernel_set_rows_f16")]] kernel set_rows_f_t kernel_set_rows_f<half>;
7007
+ #if defined(GGML_METAL_USE_BF16)
7008
+ template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
7009
+ #endif
7010
+
7011
+ typedef decltype(kernel_set_rows_q32<block_q8_0, quantize_q8_0>) set_rows_q32_t;
7012
+
7013
+ template [[host_name("kernel_set_rows_q8_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q8_0, quantize_q8_0>;
7014
+ template [[host_name("kernel_set_rows_q4_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_0, quantize_q4_0>;
7015
+ template [[host_name("kernel_set_rows_q4_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_1, quantize_q4_1>;
7016
+ template [[host_name("kernel_set_rows_q5_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_0, quantize_q5_0>;
7017
+ template [[host_name("kernel_set_rows_q5_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_1, quantize_q5_1>;
7018
+ template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32<block_iq4_nl, quantize_iq4_nl>;
7019
+
6840
7020
  //
6841
7021
  // matrix-matrix multiplication
6842
7022
  //
@@ -1,7 +1,7 @@
1
1
  #pragma once
2
2
 
3
- #include "../include/ggml.h"
4
- #include "../ggml-cuda/common.cuh"
3
+ #include "ggml-cuda/common.cuh"
4
+ #include "ggml.h"
5
5
 
6
6
  // Asynchronously copies data from src tensor to dst tensor using the provided context.
7
7
  // Returns a musaError_t indicating success or failure.