llama_cpp 0.15.3 → 0.16.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/ext/llama_cpp/extconf.rb +1 -2
- data/ext/llama_cpp/llama_cpp.cpp +27 -3
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +15 -1
- data/vendor/tmp/llama.cpp/Makefile +66 -36
- data/vendor/tmp/llama.cpp/ggml-alloc.c +4 -4
- data/vendor/tmp/llama.cpp/ggml-backend.c +5 -5
- data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +103 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +662 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +1564 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +404 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +45 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +205 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +266 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +35 -16
- data/vendor/tmp/llama.cpp/ggml-impl.h +4 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +21 -7
- data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +99 -35
- data/vendor/tmp/llama.cpp/ggml-metal.metal +146 -80
- data/vendor/tmp/llama.cpp/ggml-quants.c +101 -11
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +75 -58
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +345 -227
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +99301 -39793
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +458 -329
- data/vendor/tmp/llama.cpp/ggml.c +301 -409
- data/vendor/tmp/llama.cpp/ggml.h +19 -23
- data/vendor/tmp/llama.cpp/llama.cpp +855 -651
- data/vendor/tmp/llama.cpp/llama.h +28 -48
- metadata +121 -6
- data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
- data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
- data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
|
@@ -35,6 +35,10 @@ enum ggml_metal_kernel_type {
|
|
|
35
35
|
GGML_METAL_KERNEL_TYPE_MUL_ROW,
|
|
36
36
|
GGML_METAL_KERNEL_TYPE_DIV,
|
|
37
37
|
GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
|
38
|
+
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
|
|
39
|
+
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
|
|
40
|
+
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
|
|
41
|
+
GGML_METAL_KERNEL_TYPE_REPEAT_I16,
|
|
38
42
|
GGML_METAL_KERNEL_TYPE_SCALE,
|
|
39
43
|
GGML_METAL_KERNEL_TYPE_SCALE_4,
|
|
40
44
|
GGML_METAL_KERNEL_TYPE_CLAMP,
|
|
@@ -168,8 +172,10 @@ enum ggml_metal_kernel_type {
|
|
|
168
172
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
|
|
169
173
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
|
|
170
174
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
|
171
|
-
|
|
172
|
-
|
|
175
|
+
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
|
|
176
|
+
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
|
|
177
|
+
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
|
|
178
|
+
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
|
|
173
179
|
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
|
174
180
|
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
|
175
181
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
|
@@ -184,9 +190,9 @@ enum ggml_metal_kernel_type {
|
|
|
184
190
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
|
185
191
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
|
186
192
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
|
187
|
-
|
|
193
|
+
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
|
188
194
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
|
189
|
-
|
|
195
|
+
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
|
190
196
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
|
191
197
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
|
192
198
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
|
@@ -485,6 +491,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
|
485
491
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
|
486
492
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
|
487
493
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
|
494
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
|
|
495
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
|
|
496
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
|
|
497
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
|
|
488
498
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
|
489
499
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
|
490
500
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
|
@@ -618,8 +628,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
|
618
628
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
|
|
619
629
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
|
620
630
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
|
621
|
-
GGML_METAL_ADD_KERNEL(
|
|
622
|
-
GGML_METAL_ADD_KERNEL(
|
|
631
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
|
|
632
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
|
633
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
|
634
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
|
|
623
635
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
|
624
636
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
|
625
637
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
|
@@ -634,9 +646,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
|
634
646
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
|
|
635
647
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
|
|
636
648
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
|
|
637
|
-
|
|
649
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
|
|
638
650
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
|
|
639
|
-
|
|
651
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
|
|
640
652
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
|
641
653
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
|
642
654
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
|
@@ -746,6 +758,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
|
746
758
|
case GGML_OP_ACC:
|
|
747
759
|
case GGML_OP_MUL:
|
|
748
760
|
case GGML_OP_DIV:
|
|
761
|
+
case GGML_OP_REPEAT:
|
|
749
762
|
case GGML_OP_SCALE:
|
|
750
763
|
case GGML_OP_CLAMP:
|
|
751
764
|
case GGML_OP_SQR:
|
|
@@ -770,6 +783,15 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
|
770
783
|
case GGML_OP_LEAKY_RELU:
|
|
771
784
|
return true;
|
|
772
785
|
case GGML_OP_FLASH_ATTN_EXT:
|
|
786
|
+
if (op->src[1]->type != GGML_TYPE_F16) {
|
|
787
|
+
return false;
|
|
788
|
+
}
|
|
789
|
+
if (op->src[2]->type != GGML_TYPE_F16) {
|
|
790
|
+
return false;
|
|
791
|
+
}
|
|
792
|
+
if (op->src[0]->ne[0] == 256) {
|
|
793
|
+
return false;
|
|
794
|
+
}
|
|
773
795
|
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
|
774
796
|
case GGML_OP_MUL_MAT:
|
|
775
797
|
case GGML_OP_MUL_MAT_ID:
|
|
@@ -976,10 +998,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
976
998
|
switch (dst->op) {
|
|
977
999
|
case GGML_OP_CONCAT:
|
|
978
1000
|
{
|
|
979
|
-
const int64_t nb = ne00;
|
|
980
|
-
|
|
981
1001
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
|
|
982
1002
|
|
|
1003
|
+
const int32_t dim = ((int32_t *) dst->op_params)[0];
|
|
1004
|
+
|
|
983
1005
|
[encoder setComputePipelineState:pipeline];
|
|
984
1006
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
985
1007
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
@@ -1008,7 +1030,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
1008
1030
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
|
1009
1031
|
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
|
1010
1032
|
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
|
1011
|
-
[encoder setBytes:&
|
|
1033
|
+
[encoder setBytes:&dim length:sizeof(dim) atIndex:27];
|
|
1012
1034
|
|
|
1013
1035
|
const int nth = MIN(1024, ne0);
|
|
1014
1036
|
|
|
@@ -1018,11 +1040,14 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
1018
1040
|
case GGML_OP_MUL:
|
|
1019
1041
|
case GGML_OP_DIV:
|
|
1020
1042
|
{
|
|
1043
|
+
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
|
1044
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
|
1045
|
+
|
|
1021
1046
|
const size_t offs = 0;
|
|
1022
1047
|
|
|
1023
1048
|
bool bcast_row = false;
|
|
1024
1049
|
|
|
1025
|
-
int64_t nb = ne00;
|
|
1050
|
+
int64_t nb = ne00; // used by the "row" kernels
|
|
1026
1051
|
|
|
1027
1052
|
id<MTLComputePipelineState> pipeline = nil;
|
|
1028
1053
|
|
|
@@ -1091,6 +1116,42 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
1091
1116
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1092
1117
|
}
|
|
1093
1118
|
} break;
|
|
1119
|
+
case GGML_OP_REPEAT:
|
|
1120
|
+
{
|
|
1121
|
+
id<MTLComputePipelineState> pipeline;
|
|
1122
|
+
|
|
1123
|
+
switch (src0t) {
|
|
1124
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
|
|
1125
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
|
|
1126
|
+
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
|
|
1127
|
+
case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
|
|
1128
|
+
default: GGML_ASSERT(false);
|
|
1129
|
+
}
|
|
1130
|
+
|
|
1131
|
+
[encoder setComputePipelineState:pipeline];
|
|
1132
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1133
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
1134
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
1135
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
1136
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
1137
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
|
1138
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
|
1139
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
|
1140
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
|
1141
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
|
1142
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
|
1143
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
|
1144
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
|
1145
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
|
1146
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
|
1147
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
|
1148
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
|
1149
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
|
1150
|
+
|
|
1151
|
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
|
1152
|
+
|
|
1153
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
1154
|
+
} break;
|
|
1094
1155
|
case GGML_OP_ACC:
|
|
1095
1156
|
{
|
|
1096
1157
|
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
|
@@ -1468,7 +1529,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
1468
1529
|
{
|
|
1469
1530
|
GGML_ASSERT(ne00 == ne10);
|
|
1470
1531
|
|
|
1471
|
-
// TODO: assert that dim2 and dim3 are contiguous
|
|
1472
1532
|
GGML_ASSERT(ne12 % ne02 == 0);
|
|
1473
1533
|
GGML_ASSERT(ne13 % ne03 == 0);
|
|
1474
1534
|
|
|
@@ -2136,6 +2196,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
2136
2196
|
case GGML_OP_RMS_NORM:
|
|
2137
2197
|
{
|
|
2138
2198
|
GGML_ASSERT(ne00 % 4 == 0);
|
|
2199
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
2139
2200
|
|
|
2140
2201
|
float eps;
|
|
2141
2202
|
memcpy(&eps, dst->op_params, sizeof(float));
|
|
@@ -2163,6 +2224,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
2163
2224
|
case GGML_OP_GROUP_NORM:
|
|
2164
2225
|
{
|
|
2165
2226
|
GGML_ASSERT(ne00 % 4 == 0);
|
|
2227
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
2166
2228
|
|
|
2167
2229
|
//float eps;
|
|
2168
2230
|
//memcpy(&eps, dst->op_params, sizeof(float));
|
|
@@ -2196,6 +2258,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
2196
2258
|
} break;
|
|
2197
2259
|
case GGML_OP_NORM:
|
|
2198
2260
|
{
|
|
2261
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
2262
|
+
|
|
2199
2263
|
float eps;
|
|
2200
2264
|
memcpy(&eps, dst->op_params, sizeof(float));
|
|
2201
2265
|
|
|
@@ -2225,7 +2289,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
2225
2289
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
|
2226
2290
|
const int mode = ((int32_t *) dst->op_params)[2];
|
|
2227
2291
|
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
|
2228
|
-
const int
|
|
2292
|
+
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
|
2229
2293
|
|
|
2230
2294
|
float freq_base;
|
|
2231
2295
|
float freq_scale;
|
|
@@ -2242,22 +2306,23 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
2242
2306
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
|
2243
2307
|
|
|
2244
2308
|
const bool is_neox = mode & 2;
|
|
2245
|
-
const bool is_glm = mode & 4;
|
|
2246
2309
|
|
|
2247
|
-
|
|
2310
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2248
2311
|
|
|
2249
2312
|
if (!is_neox) {
|
|
2250
|
-
|
|
2313
|
+
switch (src0->type) {
|
|
2314
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
|
|
2315
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
|
|
2316
|
+
default: GGML_ASSERT(false);
|
|
2317
|
+
};
|
|
2318
|
+
} else {
|
|
2319
|
+
switch (src0->type) {
|
|
2320
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
|
|
2321
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
|
|
2322
|
+
default: GGML_ASSERT(false);
|
|
2323
|
+
};
|
|
2251
2324
|
}
|
|
2252
2325
|
|
|
2253
|
-
id<MTLComputePipelineState> pipeline = nil;
|
|
2254
|
-
|
|
2255
|
-
switch (src0->type) {
|
|
2256
|
-
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
|
|
2257
|
-
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
|
|
2258
|
-
default: GGML_ASSERT(false);
|
|
2259
|
-
};
|
|
2260
|
-
|
|
2261
2326
|
[encoder setComputePipelineState:pipeline];
|
|
2262
2327
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2263
2328
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
@@ -2285,14 +2350,13 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
2285
2350
|
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
|
|
2286
2351
|
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
|
|
2287
2352
|
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
|
|
2288
|
-
[encoder setBytes:&
|
|
2289
|
-
[encoder setBytes:&
|
|
2290
|
-
[encoder setBytes:&
|
|
2291
|
-
[encoder setBytes:&
|
|
2292
|
-
[encoder setBytes:&
|
|
2293
|
-
[encoder setBytes:&
|
|
2294
|
-
[encoder setBytes:&
|
|
2295
|
-
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
|
|
2353
|
+
[encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
|
|
2354
|
+
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
|
|
2355
|
+
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
|
|
2356
|
+
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
|
|
2357
|
+
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
|
|
2358
|
+
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
|
|
2359
|
+
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
|
|
2296
2360
|
|
|
2297
2361
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2298
2362
|
} break;
|
|
@@ -2573,7 +2637,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
2573
2637
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
|
2574
2638
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
|
2575
2639
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
|
2576
|
-
|
|
2640
|
+
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
|
2577
2641
|
default:
|
|
2578
2642
|
{
|
|
2579
2643
|
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
|
@@ -2586,7 +2650,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
|
2586
2650
|
|
|
2587
2651
|
switch (ne00) {
|
|
2588
2652
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
|
2589
|
-
|
|
2653
|
+
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
|
2590
2654
|
default:
|
|
2591
2655
|
{
|
|
2592
2656
|
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
|
@@ -168,6 +168,53 @@ kernel void kernel_div(
|
|
|
168
168
|
}
|
|
169
169
|
}
|
|
170
170
|
|
|
171
|
+
template<typename T>
|
|
172
|
+
kernel void kernel_repeat(
|
|
173
|
+
device const char * src0,
|
|
174
|
+
device char * dst,
|
|
175
|
+
constant int64_t & ne00,
|
|
176
|
+
constant int64_t & ne01,
|
|
177
|
+
constant int64_t & ne02,
|
|
178
|
+
constant int64_t & ne03,
|
|
179
|
+
constant uint64_t & nb00,
|
|
180
|
+
constant uint64_t & nb01,
|
|
181
|
+
constant uint64_t & nb02,
|
|
182
|
+
constant uint64_t & nb03,
|
|
183
|
+
constant int64_t & ne0,
|
|
184
|
+
constant int64_t & ne1,
|
|
185
|
+
constant int64_t & ne2,
|
|
186
|
+
constant int64_t & ne3,
|
|
187
|
+
constant uint64_t & nb0,
|
|
188
|
+
constant uint64_t & nb1,
|
|
189
|
+
constant uint64_t & nb2,
|
|
190
|
+
constant uint64_t & nb3,
|
|
191
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
192
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
193
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
194
|
+
const int64_t i3 = tgpig.z;
|
|
195
|
+
const int64_t i2 = tgpig.y;
|
|
196
|
+
const int64_t i1 = tgpig.x;
|
|
197
|
+
|
|
198
|
+
const int64_t i03 = i3 % ne03;
|
|
199
|
+
const int64_t i02 = i2 % ne02;
|
|
200
|
+
const int64_t i01 = i1 % ne01;
|
|
201
|
+
|
|
202
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
|
203
|
+
device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ;
|
|
204
|
+
|
|
205
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
|
206
|
+
const int i00 = i0 % ne00;
|
|
207
|
+
*((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
|
|
208
|
+
}
|
|
209
|
+
}
|
|
210
|
+
|
|
211
|
+
typedef decltype(kernel_repeat<float>) kernel_repeat_t;
|
|
212
|
+
|
|
213
|
+
template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
|
|
214
|
+
template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
|
|
215
|
+
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
|
|
216
|
+
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
|
|
217
|
+
|
|
171
218
|
// assumption: src1 is a row
|
|
172
219
|
// broadcast src1 into src0
|
|
173
220
|
kernel void kernel_add_row(
|
|
@@ -1607,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
|
|
1607
1654
|
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
|
1608
1655
|
static void rope_yarn(
|
|
1609
1656
|
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
|
1610
|
-
thread float * cos_theta, thread float * sin_theta
|
|
1611
|
-
) {
|
|
1657
|
+
thread float * cos_theta, thread float * sin_theta) {
|
|
1612
1658
|
// Get n-d rotational scaling corrected for extrapolation
|
|
1613
1659
|
float theta_interp = freq_scale * theta_extrap;
|
|
1614
1660
|
float theta = theta_interp;
|
|
@@ -1625,19 +1671,20 @@ static void rope_yarn(
|
|
|
1625
1671
|
|
|
1626
1672
|
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
|
1627
1673
|
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
|
1628
|
-
static float rope_yarn_corr_factor(int n_dims, int
|
|
1629
|
-
return n_dims * log(
|
|
1674
|
+
static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
|
|
1675
|
+
return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
|
1630
1676
|
}
|
|
1631
1677
|
|
|
1632
1678
|
static void rope_yarn_corr_dims(
|
|
1633
|
-
int n_dims, int
|
|
1679
|
+
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
|
1634
1680
|
) {
|
|
1635
1681
|
// start and end correction dims
|
|
1636
|
-
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims,
|
|
1637
|
-
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims,
|
|
1682
|
+
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
|
|
1683
|
+
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
|
|
1638
1684
|
}
|
|
1639
1685
|
|
|
1640
|
-
|
|
1686
|
+
template<typename T>
|
|
1687
|
+
kernel void kernel_rope_norm(
|
|
1641
1688
|
device const void * src0,
|
|
1642
1689
|
device const int32_t * src1,
|
|
1643
1690
|
device const float * src2,
|
|
@@ -1660,8 +1707,7 @@ typedef void (rope_t)(
|
|
|
1660
1707
|
constant uint64_t & nb3,
|
|
1661
1708
|
constant int & n_past,
|
|
1662
1709
|
constant int & n_dims,
|
|
1663
|
-
constant int &
|
|
1664
|
-
constant int & n_orig_ctx,
|
|
1710
|
+
constant int & n_ctx_orig,
|
|
1665
1711
|
constant float & freq_base,
|
|
1666
1712
|
constant float & freq_scale,
|
|
1667
1713
|
constant float & ext_factor,
|
|
@@ -1670,10 +1716,52 @@ typedef void (rope_t)(
|
|
|
1670
1716
|
constant float & beta_slow,
|
|
1671
1717
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
1672
1718
|
uint3 tptg[[threads_per_threadgroup]],
|
|
1673
|
-
uint3 tgpig[[threadgroup_position_in_grid]])
|
|
1719
|
+
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
|
1720
|
+
const int64_t i3 = tgpig[2];
|
|
1721
|
+
const int64_t i2 = tgpig[1];
|
|
1722
|
+
const int64_t i1 = tgpig[0];
|
|
1723
|
+
|
|
1724
|
+
float corr_dims[2];
|
|
1725
|
+
rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
|
1726
|
+
|
|
1727
|
+
device const int32_t * pos = src1;
|
|
1728
|
+
|
|
1729
|
+
const float theta_base = (float) pos[i2];
|
|
1730
|
+
const float inv_ndims = -1.f/n_dims;
|
|
1731
|
+
|
|
1732
|
+
float cos_theta;
|
|
1733
|
+
float sin_theta;
|
|
1734
|
+
|
|
1735
|
+
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
|
1736
|
+
if (i0 < n_dims) {
|
|
1737
|
+
const int64_t ic = i0/2;
|
|
1738
|
+
|
|
1739
|
+
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
|
1740
|
+
|
|
1741
|
+
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
|
1742
|
+
|
|
1743
|
+
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
1744
|
+
|
|
1745
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
1746
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1747
|
+
|
|
1748
|
+
const float x0 = src[0];
|
|
1749
|
+
const float x1 = src[1];
|
|
1750
|
+
|
|
1751
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
1752
|
+
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
|
1753
|
+
} else {
|
|
1754
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
1755
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1756
|
+
|
|
1757
|
+
dst_data[0] = src[0];
|
|
1758
|
+
dst_data[1] = src[1];
|
|
1759
|
+
}
|
|
1760
|
+
}
|
|
1761
|
+
}
|
|
1674
1762
|
|
|
1675
1763
|
template<typename T>
|
|
1676
|
-
kernel void
|
|
1764
|
+
kernel void kernel_rope_neox(
|
|
1677
1765
|
device const void * src0,
|
|
1678
1766
|
device const int32_t * src1,
|
|
1679
1767
|
device const float * src2,
|
|
@@ -1696,8 +1784,7 @@ kernel void kernel_rope(
|
|
|
1696
1784
|
constant uint64_t & nb3,
|
|
1697
1785
|
constant int & n_past,
|
|
1698
1786
|
constant int & n_dims,
|
|
1699
|
-
constant int &
|
|
1700
|
-
constant int & n_orig_ctx,
|
|
1787
|
+
constant int & n_ctx_orig,
|
|
1701
1788
|
constant float & freq_base,
|
|
1702
1789
|
constant float & freq_scale,
|
|
1703
1790
|
constant float & ext_factor,
|
|
@@ -1711,73 +1798,53 @@ kernel void kernel_rope(
|
|
|
1711
1798
|
const int64_t i2 = tgpig[1];
|
|
1712
1799
|
const int64_t i1 = tgpig[0];
|
|
1713
1800
|
|
|
1714
|
-
const bool is_neox = mode & 2;
|
|
1715
|
-
|
|
1716
1801
|
float corr_dims[2];
|
|
1717
|
-
rope_yarn_corr_dims(n_dims,
|
|
1802
|
+
rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
|
1718
1803
|
|
|
1719
1804
|
device const int32_t * pos = src1;
|
|
1720
1805
|
|
|
1721
|
-
const
|
|
1722
|
-
|
|
1723
|
-
const float theta_0 = (float)p;
|
|
1806
|
+
const float theta_base = (float) pos[i2];
|
|
1724
1807
|
const float inv_ndims = -1.f/n_dims;
|
|
1725
1808
|
|
|
1726
|
-
|
|
1727
|
-
|
|
1728
|
-
|
|
1729
|
-
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
|
|
1730
|
-
float cos_theta, sin_theta;
|
|
1731
|
-
rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
1732
|
-
|
|
1733
|
-
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
1734
|
-
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1735
|
-
|
|
1736
|
-
const T x0 = src[0];
|
|
1737
|
-
const T x1 = src[1];
|
|
1809
|
+
float cos_theta;
|
|
1810
|
+
float sin_theta;
|
|
1738
1811
|
|
|
1739
|
-
|
|
1740
|
-
|
|
1741
|
-
|
|
1742
|
-
} else {
|
|
1743
|
-
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
|
|
1744
|
-
if (ic < n_dims) {
|
|
1745
|
-
const int64_t ib = 0;
|
|
1746
|
-
|
|
1747
|
-
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
|
1748
|
-
const float cur_rot = inv_ndims*ic - ib;
|
|
1749
|
-
const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
|
|
1750
|
-
|
|
1751
|
-
const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
|
|
1812
|
+
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
|
1813
|
+
if (i0 < n_dims) {
|
|
1814
|
+
const int64_t ic = i0/2;
|
|
1752
1815
|
|
|
1753
|
-
|
|
1754
|
-
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
1816
|
+
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
|
1755
1817
|
|
|
1756
|
-
|
|
1818
|
+
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
|
1757
1819
|
|
|
1758
|
-
|
|
1759
|
-
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1820
|
+
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
1760
1821
|
|
|
1761
|
-
|
|
1762
|
-
|
|
1822
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
1823
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
1763
1824
|
|
|
1764
|
-
|
|
1765
|
-
|
|
1766
|
-
} else {
|
|
1767
|
-
const int64_t i0 = ic;
|
|
1825
|
+
const float x0 = src[0];
|
|
1826
|
+
const float x1 = src[n_dims/2];
|
|
1768
1827
|
|
|
1769
|
-
|
|
1770
|
-
|
|
1828
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
1829
|
+
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
|
1830
|
+
} else {
|
|
1831
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
1832
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1771
1833
|
|
|
1772
|
-
|
|
1773
|
-
|
|
1774
|
-
}
|
|
1834
|
+
dst_data[0] = src[0];
|
|
1835
|
+
dst_data[1] = src[1];
|
|
1775
1836
|
}
|
|
1776
1837
|
}
|
|
1777
1838
|
}
|
|
1778
1839
|
|
|
1779
|
-
|
|
1780
|
-
|
|
1840
|
+
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
|
|
1841
|
+
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
|
|
1842
|
+
|
|
1843
|
+
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
|
|
1844
|
+
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
|
|
1845
|
+
|
|
1846
|
+
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
|
|
1847
|
+
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
|
|
1781
1848
|
|
|
1782
1849
|
typedef void (im2col_t)(
|
|
1783
1850
|
device const float * x,
|
|
@@ -2418,7 +2485,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f
|
|
|
2418
2485
|
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
|
|
2419
2486
|
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
|
|
2420
2487
|
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
|
|
2421
|
-
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
|
|
2488
|
+
//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
|
|
2422
2489
|
|
|
2423
2490
|
template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
|
|
2424
2491
|
kernel void kernel_flash_attn_ext_vec_f16(
|
|
@@ -2696,7 +2763,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
|
2696
2763
|
}
|
|
2697
2764
|
|
|
2698
2765
|
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
|
|
2699
|
-
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
|
|
2766
|
+
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
|
|
2700
2767
|
|
|
2701
2768
|
kernel void kernel_cpy_f16_f16(
|
|
2702
2769
|
device const half * src0,
|
|
@@ -3319,31 +3386,30 @@ kernel void kernel_concat(
|
|
|
3319
3386
|
constant uint64_t & nb1,
|
|
3320
3387
|
constant uint64_t & nb2,
|
|
3321
3388
|
constant uint64_t & nb3,
|
|
3389
|
+
constant int32_t & dim,
|
|
3322
3390
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3323
3391
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
3324
3392
|
uint3 ntg[[threads_per_threadgroup]]) {
|
|
3325
3393
|
|
|
3326
|
-
const int64_t
|
|
3327
|
-
const int64_t
|
|
3328
|
-
const int64_t
|
|
3394
|
+
const int64_t i3 = tgpig.z;
|
|
3395
|
+
const int64_t i2 = tgpig.y;
|
|
3396
|
+
const int64_t i1 = tgpig.x;
|
|
3329
3397
|
|
|
3330
|
-
|
|
3331
|
-
|
|
3332
|
-
const int64_t i11 = i01 % ne11;
|
|
3398
|
+
int64_t o[4] = {0, 0, 0, 0};
|
|
3399
|
+
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
|
|
3333
3400
|
|
|
3334
|
-
device const
|
|
3335
|
-
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
|
|
3336
|
-
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
|
|
3401
|
+
device const float * x;
|
|
3337
3402
|
|
|
3338
3403
|
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
|
3339
|
-
if (
|
|
3340
|
-
(
|
|
3341
|
-
src0_ptr += ntg.x*nb00;
|
|
3404
|
+
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
|
3405
|
+
x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
|
|
3342
3406
|
} else {
|
|
3343
|
-
(
|
|
3344
|
-
src1_ptr += ntg.x*nb10;
|
|
3407
|
+
x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
|
|
3345
3408
|
}
|
|
3346
|
-
|
|
3409
|
+
|
|
3410
|
+
device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
3411
|
+
|
|
3412
|
+
*y = *x;
|
|
3347
3413
|
}
|
|
3348
3414
|
}
|
|
3349
3415
|
|