llama_cpp 0.15.3 → 0.16.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/ext/llama_cpp/extconf.rb +1 -2
- data/ext/llama_cpp/llama_cpp.cpp +27 -3
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +15 -1
- data/vendor/tmp/llama.cpp/Makefile +66 -36
- data/vendor/tmp/llama.cpp/ggml-alloc.c +4 -4
- data/vendor/tmp/llama.cpp/ggml-backend.c +5 -5
- data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +103 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +662 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +1564 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +404 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +45 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +205 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +266 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +35 -16
- data/vendor/tmp/llama.cpp/ggml-impl.h +4 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +21 -7
- data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +99 -35
- data/vendor/tmp/llama.cpp/ggml-metal.metal +146 -80
- data/vendor/tmp/llama.cpp/ggml-quants.c +101 -11
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +75 -58
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +345 -227
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +99301 -39793
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +458 -329
- data/vendor/tmp/llama.cpp/ggml.c +301 -409
- data/vendor/tmp/llama.cpp/ggml.h +19 -23
- data/vendor/tmp/llama.cpp/llama.cpp +855 -651
- data/vendor/tmp/llama.cpp/llama.h +28 -48
- metadata +121 -6
- data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
- data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
- data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
@@ -35,6 +35,10 @@ enum ggml_metal_kernel_type {
|
|
35
35
|
GGML_METAL_KERNEL_TYPE_MUL_ROW,
|
36
36
|
GGML_METAL_KERNEL_TYPE_DIV,
|
37
37
|
GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
38
|
+
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
|
39
|
+
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
|
40
|
+
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
|
41
|
+
GGML_METAL_KERNEL_TYPE_REPEAT_I16,
|
38
42
|
GGML_METAL_KERNEL_TYPE_SCALE,
|
39
43
|
GGML_METAL_KERNEL_TYPE_SCALE_4,
|
40
44
|
GGML_METAL_KERNEL_TYPE_CLAMP,
|
@@ -168,8 +172,10 @@ enum ggml_metal_kernel_type {
|
|
168
172
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
|
169
173
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
|
170
174
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
171
|
-
|
172
|
-
|
175
|
+
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
|
176
|
+
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
|
177
|
+
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
|
178
|
+
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
|
173
179
|
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
174
180
|
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
175
181
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
@@ -184,9 +190,9 @@ enum ggml_metal_kernel_type {
|
|
184
190
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
185
191
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
186
192
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
187
|
-
|
193
|
+
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
188
194
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
189
|
-
|
195
|
+
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
190
196
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
191
197
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
192
198
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
@@ -485,6 +491,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
485
491
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
486
492
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
487
493
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
494
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
|
495
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
|
496
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
|
497
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
|
488
498
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
489
499
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
490
500
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
@@ -618,8 +628,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
618
628
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
|
619
629
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
620
630
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
621
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
622
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
631
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
|
632
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
633
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
634
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
|
623
635
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
624
636
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
625
637
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
@@ -634,9 +646,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
634
646
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
|
635
647
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
|
636
648
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
|
637
|
-
|
649
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
|
638
650
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
|
639
|
-
|
651
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
|
640
652
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
641
653
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
642
654
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
@@ -746,6 +758,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
746
758
|
case GGML_OP_ACC:
|
747
759
|
case GGML_OP_MUL:
|
748
760
|
case GGML_OP_DIV:
|
761
|
+
case GGML_OP_REPEAT:
|
749
762
|
case GGML_OP_SCALE:
|
750
763
|
case GGML_OP_CLAMP:
|
751
764
|
case GGML_OP_SQR:
|
@@ -770,6 +783,15 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
770
783
|
case GGML_OP_LEAKY_RELU:
|
771
784
|
return true;
|
772
785
|
case GGML_OP_FLASH_ATTN_EXT:
|
786
|
+
if (op->src[1]->type != GGML_TYPE_F16) {
|
787
|
+
return false;
|
788
|
+
}
|
789
|
+
if (op->src[2]->type != GGML_TYPE_F16) {
|
790
|
+
return false;
|
791
|
+
}
|
792
|
+
if (op->src[0]->ne[0] == 256) {
|
793
|
+
return false;
|
794
|
+
}
|
773
795
|
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
774
796
|
case GGML_OP_MUL_MAT:
|
775
797
|
case GGML_OP_MUL_MAT_ID:
|
@@ -976,10 +998,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
976
998
|
switch (dst->op) {
|
977
999
|
case GGML_OP_CONCAT:
|
978
1000
|
{
|
979
|
-
const int64_t nb = ne00;
|
980
|
-
|
981
1001
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
|
982
1002
|
|
1003
|
+
const int32_t dim = ((int32_t *) dst->op_params)[0];
|
1004
|
+
|
983
1005
|
[encoder setComputePipelineState:pipeline];
|
984
1006
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
985
1007
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
@@ -1008,7 +1030,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1008
1030
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
1009
1031
|
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
1010
1032
|
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
1011
|
-
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
|
1033
|
+
[encoder setBytes:&dim length:sizeof(dim) atIndex:27];
|
1012
1034
|
|
1013
1035
|
const int nth = MIN(1024, ne0);
|
1014
1036
|
|
@@ -1018,11 +1040,14 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1018
1040
|
case GGML_OP_MUL:
|
1019
1041
|
case GGML_OP_DIV:
|
1020
1042
|
{
|
1043
|
+
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
1044
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1045
|
+
|
1021
1046
|
const size_t offs = 0;
|
1022
1047
|
|
1023
1048
|
bool bcast_row = false;
|
1024
1049
|
|
1025
|
-
int64_t nb = ne00;
|
1050
|
+
int64_t nb = ne00; // used by the "row" kernels
|
1026
1051
|
|
1027
1052
|
id<MTLComputePipelineState> pipeline = nil;
|
1028
1053
|
|
@@ -1091,6 +1116,42 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1091
1116
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1092
1117
|
}
|
1093
1118
|
} break;
|
1119
|
+
case GGML_OP_REPEAT:
|
1120
|
+
{
|
1121
|
+
id<MTLComputePipelineState> pipeline;
|
1122
|
+
|
1123
|
+
switch (src0t) {
|
1124
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
|
1125
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
|
1126
|
+
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
|
1127
|
+
case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
|
1128
|
+
default: GGML_ASSERT(false);
|
1129
|
+
}
|
1130
|
+
|
1131
|
+
[encoder setComputePipelineState:pipeline];
|
1132
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1133
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1134
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
1135
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
1136
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
1137
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
1138
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
1139
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
1140
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
1141
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
1142
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
1143
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
1144
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
1145
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
1146
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
1147
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
1148
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
1149
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
1150
|
+
|
1151
|
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
1152
|
+
|
1153
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1154
|
+
} break;
|
1094
1155
|
case GGML_OP_ACC:
|
1095
1156
|
{
|
1096
1157
|
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
@@ -1468,7 +1529,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1468
1529
|
{
|
1469
1530
|
GGML_ASSERT(ne00 == ne10);
|
1470
1531
|
|
1471
|
-
// TODO: assert that dim2 and dim3 are contiguous
|
1472
1532
|
GGML_ASSERT(ne12 % ne02 == 0);
|
1473
1533
|
GGML_ASSERT(ne13 % ne03 == 0);
|
1474
1534
|
|
@@ -2136,6 +2196,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2136
2196
|
case GGML_OP_RMS_NORM:
|
2137
2197
|
{
|
2138
2198
|
GGML_ASSERT(ne00 % 4 == 0);
|
2199
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
2139
2200
|
|
2140
2201
|
float eps;
|
2141
2202
|
memcpy(&eps, dst->op_params, sizeof(float));
|
@@ -2163,6 +2224,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2163
2224
|
case GGML_OP_GROUP_NORM:
|
2164
2225
|
{
|
2165
2226
|
GGML_ASSERT(ne00 % 4 == 0);
|
2227
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
2166
2228
|
|
2167
2229
|
//float eps;
|
2168
2230
|
//memcpy(&eps, dst->op_params, sizeof(float));
|
@@ -2196,6 +2258,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2196
2258
|
} break;
|
2197
2259
|
case GGML_OP_NORM:
|
2198
2260
|
{
|
2261
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
2262
|
+
|
2199
2263
|
float eps;
|
2200
2264
|
memcpy(&eps, dst->op_params, sizeof(float));
|
2201
2265
|
|
@@ -2225,7 +2289,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2225
2289
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
2226
2290
|
const int mode = ((int32_t *) dst->op_params)[2];
|
2227
2291
|
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
2228
|
-
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
2292
|
+
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
2229
2293
|
|
2230
2294
|
float freq_base;
|
2231
2295
|
float freq_scale;
|
@@ -2242,22 +2306,23 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2242
2306
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
2243
2307
|
|
2244
2308
|
const bool is_neox = mode & 2;
|
2245
|
-
const bool is_glm = mode & 4;
|
2246
2309
|
|
2247
|
-
|
2310
|
+
id<MTLComputePipelineState> pipeline = nil;
|
2248
2311
|
|
2249
2312
|
if (!is_neox) {
|
2250
|
-
|
2313
|
+
switch (src0->type) {
|
2314
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
|
2315
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
|
2316
|
+
default: GGML_ASSERT(false);
|
2317
|
+
};
|
2318
|
+
} else {
|
2319
|
+
switch (src0->type) {
|
2320
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
|
2321
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
|
2322
|
+
default: GGML_ASSERT(false);
|
2323
|
+
};
|
2251
2324
|
}
|
2252
2325
|
|
2253
|
-
id<MTLComputePipelineState> pipeline = nil;
|
2254
|
-
|
2255
|
-
switch (src0->type) {
|
2256
|
-
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
|
2257
|
-
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
|
2258
|
-
default: GGML_ASSERT(false);
|
2259
|
-
};
|
2260
|
-
|
2261
2326
|
[encoder setComputePipelineState:pipeline];
|
2262
2327
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2263
2328
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
@@ -2285,14 +2350,13 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2285
2350
|
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
|
2286
2351
|
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
|
2287
2352
|
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
|
2288
|
-
[encoder setBytes:&
|
2289
|
-
[encoder setBytes:&
|
2290
|
-
[encoder setBytes:&
|
2291
|
-
[encoder setBytes:&
|
2292
|
-
[encoder setBytes:&
|
2293
|
-
[encoder setBytes:&
|
2294
|
-
[encoder setBytes:&
|
2295
|
-
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
|
2353
|
+
[encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
|
2354
|
+
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
|
2355
|
+
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
|
2356
|
+
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
|
2357
|
+
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
|
2358
|
+
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
|
2359
|
+
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
|
2296
2360
|
|
2297
2361
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2298
2362
|
} break;
|
@@ -2573,7 +2637,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2573
2637
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
2574
2638
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
2575
2639
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
2576
|
-
|
2640
|
+
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
2577
2641
|
default:
|
2578
2642
|
{
|
2579
2643
|
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
@@ -2586,7 +2650,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2586
2650
|
|
2587
2651
|
switch (ne00) {
|
2588
2652
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
2589
|
-
|
2653
|
+
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
2590
2654
|
default:
|
2591
2655
|
{
|
2592
2656
|
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
@@ -168,6 +168,53 @@ kernel void kernel_div(
|
|
168
168
|
}
|
169
169
|
}
|
170
170
|
|
171
|
+
template<typename T>
|
172
|
+
kernel void kernel_repeat(
|
173
|
+
device const char * src0,
|
174
|
+
device char * dst,
|
175
|
+
constant int64_t & ne00,
|
176
|
+
constant int64_t & ne01,
|
177
|
+
constant int64_t & ne02,
|
178
|
+
constant int64_t & ne03,
|
179
|
+
constant uint64_t & nb00,
|
180
|
+
constant uint64_t & nb01,
|
181
|
+
constant uint64_t & nb02,
|
182
|
+
constant uint64_t & nb03,
|
183
|
+
constant int64_t & ne0,
|
184
|
+
constant int64_t & ne1,
|
185
|
+
constant int64_t & ne2,
|
186
|
+
constant int64_t & ne3,
|
187
|
+
constant uint64_t & nb0,
|
188
|
+
constant uint64_t & nb1,
|
189
|
+
constant uint64_t & nb2,
|
190
|
+
constant uint64_t & nb3,
|
191
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
192
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
193
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
194
|
+
const int64_t i3 = tgpig.z;
|
195
|
+
const int64_t i2 = tgpig.y;
|
196
|
+
const int64_t i1 = tgpig.x;
|
197
|
+
|
198
|
+
const int64_t i03 = i3 % ne03;
|
199
|
+
const int64_t i02 = i2 % ne02;
|
200
|
+
const int64_t i01 = i1 % ne01;
|
201
|
+
|
202
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
203
|
+
device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ;
|
204
|
+
|
205
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
206
|
+
const int i00 = i0 % ne00;
|
207
|
+
*((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
|
208
|
+
}
|
209
|
+
}
|
210
|
+
|
211
|
+
typedef decltype(kernel_repeat<float>) kernel_repeat_t;
|
212
|
+
|
213
|
+
template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
|
214
|
+
template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
|
215
|
+
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
|
216
|
+
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
|
217
|
+
|
171
218
|
// assumption: src1 is a row
|
172
219
|
// broadcast src1 into src0
|
173
220
|
kernel void kernel_add_row(
|
@@ -1607,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
|
1607
1654
|
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
1608
1655
|
static void rope_yarn(
|
1609
1656
|
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
1610
|
-
thread float * cos_theta, thread float * sin_theta
|
1611
|
-
) {
|
1657
|
+
thread float * cos_theta, thread float * sin_theta) {
|
1612
1658
|
// Get n-d rotational scaling corrected for extrapolation
|
1613
1659
|
float theta_interp = freq_scale * theta_extrap;
|
1614
1660
|
float theta = theta_interp;
|
@@ -1625,19 +1671,20 @@ static void rope_yarn(
|
|
1625
1671
|
|
1626
1672
|
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
1627
1673
|
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
1628
|
-
static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
|
1629
|
-
return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
1674
|
+
static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
|
1675
|
+
return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
1630
1676
|
}
|
1631
1677
|
|
1632
1678
|
static void rope_yarn_corr_dims(
|
1633
|
-
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
1679
|
+
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
1634
1680
|
) {
|
1635
1681
|
// start and end correction dims
|
1636
|
-
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims,
|
1637
|
-
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims,
|
1682
|
+
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
|
1683
|
+
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
|
1638
1684
|
}
|
1639
1685
|
|
1640
|
-
|
1686
|
+
template<typename T>
|
1687
|
+
kernel void kernel_rope_norm(
|
1641
1688
|
device const void * src0,
|
1642
1689
|
device const int32_t * src1,
|
1643
1690
|
device const float * src2,
|
@@ -1660,8 +1707,7 @@ typedef void (rope_t)(
|
|
1660
1707
|
constant uint64_t & nb3,
|
1661
1708
|
constant int & n_past,
|
1662
1709
|
constant int & n_dims,
|
1663
|
-
constant int &
|
1664
|
-
constant int & n_orig_ctx,
|
1710
|
+
constant int & n_ctx_orig,
|
1665
1711
|
constant float & freq_base,
|
1666
1712
|
constant float & freq_scale,
|
1667
1713
|
constant float & ext_factor,
|
@@ -1670,10 +1716,52 @@ typedef void (rope_t)(
|
|
1670
1716
|
constant float & beta_slow,
|
1671
1717
|
uint tiitg[[thread_index_in_threadgroup]],
|
1672
1718
|
uint3 tptg[[threads_per_threadgroup]],
|
1673
|
-
uint3 tgpig[[threadgroup_position_in_grid]])
|
1719
|
+
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
1720
|
+
const int64_t i3 = tgpig[2];
|
1721
|
+
const int64_t i2 = tgpig[1];
|
1722
|
+
const int64_t i1 = tgpig[0];
|
1723
|
+
|
1724
|
+
float corr_dims[2];
|
1725
|
+
rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
1726
|
+
|
1727
|
+
device const int32_t * pos = src1;
|
1728
|
+
|
1729
|
+
const float theta_base = (float) pos[i2];
|
1730
|
+
const float inv_ndims = -1.f/n_dims;
|
1731
|
+
|
1732
|
+
float cos_theta;
|
1733
|
+
float sin_theta;
|
1734
|
+
|
1735
|
+
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
1736
|
+
if (i0 < n_dims) {
|
1737
|
+
const int64_t ic = i0/2;
|
1738
|
+
|
1739
|
+
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
1740
|
+
|
1741
|
+
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
1742
|
+
|
1743
|
+
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
1744
|
+
|
1745
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
1746
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1747
|
+
|
1748
|
+
const float x0 = src[0];
|
1749
|
+
const float x1 = src[1];
|
1750
|
+
|
1751
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
1752
|
+
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
1753
|
+
} else {
|
1754
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
1755
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1756
|
+
|
1757
|
+
dst_data[0] = src[0];
|
1758
|
+
dst_data[1] = src[1];
|
1759
|
+
}
|
1760
|
+
}
|
1761
|
+
}
|
1674
1762
|
|
1675
1763
|
template<typename T>
|
1676
|
-
kernel void
|
1764
|
+
kernel void kernel_rope_neox(
|
1677
1765
|
device const void * src0,
|
1678
1766
|
device const int32_t * src1,
|
1679
1767
|
device const float * src2,
|
@@ -1696,8 +1784,7 @@ kernel void kernel_rope(
|
|
1696
1784
|
constant uint64_t & nb3,
|
1697
1785
|
constant int & n_past,
|
1698
1786
|
constant int & n_dims,
|
1699
|
-
constant int &
|
1700
|
-
constant int & n_orig_ctx,
|
1787
|
+
constant int & n_ctx_orig,
|
1701
1788
|
constant float & freq_base,
|
1702
1789
|
constant float & freq_scale,
|
1703
1790
|
constant float & ext_factor,
|
@@ -1711,73 +1798,53 @@ kernel void kernel_rope(
|
|
1711
1798
|
const int64_t i2 = tgpig[1];
|
1712
1799
|
const int64_t i1 = tgpig[0];
|
1713
1800
|
|
1714
|
-
const bool is_neox = mode & 2;
|
1715
|
-
|
1716
1801
|
float corr_dims[2];
|
1717
|
-
rope_yarn_corr_dims(n_dims,
|
1802
|
+
rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
1718
1803
|
|
1719
1804
|
device const int32_t * pos = src1;
|
1720
1805
|
|
1721
|
-
const
|
1722
|
-
|
1723
|
-
const float theta_0 = (float)p;
|
1806
|
+
const float theta_base = (float) pos[i2];
|
1724
1807
|
const float inv_ndims = -1.f/n_dims;
|
1725
1808
|
|
1726
|
-
|
1727
|
-
|
1728
|
-
|
1729
|
-
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
|
1730
|
-
float cos_theta, sin_theta;
|
1731
|
-
rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
1732
|
-
|
1733
|
-
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
1734
|
-
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1735
|
-
|
1736
|
-
const T x0 = src[0];
|
1737
|
-
const T x1 = src[1];
|
1809
|
+
float cos_theta;
|
1810
|
+
float sin_theta;
|
1738
1811
|
|
1739
|
-
|
1740
|
-
|
1741
|
-
|
1742
|
-
} else {
|
1743
|
-
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
|
1744
|
-
if (ic < n_dims) {
|
1745
|
-
const int64_t ib = 0;
|
1746
|
-
|
1747
|
-
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
1748
|
-
const float cur_rot = inv_ndims*ic - ib;
|
1749
|
-
const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
|
1750
|
-
|
1751
|
-
const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
|
1812
|
+
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
1813
|
+
if (i0 < n_dims) {
|
1814
|
+
const int64_t ic = i0/2;
|
1752
1815
|
|
1753
|
-
|
1754
|
-
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
1816
|
+
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
1755
1817
|
|
1756
|
-
|
1818
|
+
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
1757
1819
|
|
1758
|
-
|
1759
|
-
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1820
|
+
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
1760
1821
|
|
1761
|
-
|
1762
|
-
|
1822
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
1823
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
1763
1824
|
|
1764
|
-
|
1765
|
-
|
1766
|
-
} else {
|
1767
|
-
const int64_t i0 = ic;
|
1825
|
+
const float x0 = src[0];
|
1826
|
+
const float x1 = src[n_dims/2];
|
1768
1827
|
|
1769
|
-
|
1770
|
-
|
1828
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
1829
|
+
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
1830
|
+
} else {
|
1831
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
1832
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1771
1833
|
|
1772
|
-
|
1773
|
-
|
1774
|
-
}
|
1834
|
+
dst_data[0] = src[0];
|
1835
|
+
dst_data[1] = src[1];
|
1775
1836
|
}
|
1776
1837
|
}
|
1777
1838
|
}
|
1778
1839
|
|
1779
|
-
|
1780
|
-
|
1840
|
+
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
|
1841
|
+
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
|
1842
|
+
|
1843
|
+
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
|
1844
|
+
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
|
1845
|
+
|
1846
|
+
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
|
1847
|
+
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
|
1781
1848
|
|
1782
1849
|
typedef void (im2col_t)(
|
1783
1850
|
device const float * x,
|
@@ -2418,7 +2485,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f
|
|
2418
2485
|
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
|
2419
2486
|
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
|
2420
2487
|
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
|
2421
|
-
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
|
2488
|
+
//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
|
2422
2489
|
|
2423
2490
|
template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
|
2424
2491
|
kernel void kernel_flash_attn_ext_vec_f16(
|
@@ -2696,7 +2763,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
2696
2763
|
}
|
2697
2764
|
|
2698
2765
|
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
|
2699
|
-
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
|
2766
|
+
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
|
2700
2767
|
|
2701
2768
|
kernel void kernel_cpy_f16_f16(
|
2702
2769
|
device const half * src0,
|
@@ -3319,31 +3386,30 @@ kernel void kernel_concat(
|
|
3319
3386
|
constant uint64_t & nb1,
|
3320
3387
|
constant uint64_t & nb2,
|
3321
3388
|
constant uint64_t & nb3,
|
3389
|
+
constant int32_t & dim,
|
3322
3390
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
3323
3391
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
3324
3392
|
uint3 ntg[[threads_per_threadgroup]]) {
|
3325
3393
|
|
3326
|
-
const int64_t
|
3327
|
-
const int64_t
|
3328
|
-
const int64_t
|
3394
|
+
const int64_t i3 = tgpig.z;
|
3395
|
+
const int64_t i2 = tgpig.y;
|
3396
|
+
const int64_t i1 = tgpig.x;
|
3329
3397
|
|
3330
|
-
|
3331
|
-
|
3332
|
-
const int64_t i11 = i01 % ne11;
|
3398
|
+
int64_t o[4] = {0, 0, 0, 0};
|
3399
|
+
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
|
3333
3400
|
|
3334
|
-
device const
|
3335
|
-
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
|
3336
|
-
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
|
3401
|
+
device const float * x;
|
3337
3402
|
|
3338
3403
|
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
3339
|
-
if (
|
3340
|
-
(
|
3341
|
-
src0_ptr += ntg.x*nb00;
|
3404
|
+
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
3405
|
+
x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
|
3342
3406
|
} else {
|
3343
|
-
(
|
3344
|
-
src1_ptr += ntg.x*nb10;
|
3407
|
+
x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
|
3345
3408
|
}
|
3346
|
-
|
3409
|
+
|
3410
|
+
device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
3411
|
+
|
3412
|
+
*y = *x;
|
3347
3413
|
}
|
3348
3414
|
}
|
3349
3415
|
|