llama_cpp 0.15.4 → 0.16.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/ext/llama_cpp/extconf.rb +3 -2
- data/ext/llama_cpp/llama_cpp.cpp +17 -3
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +15 -1
- data/vendor/tmp/llama.cpp/Makefile +166 -82
- data/vendor/tmp/llama.cpp/ggml-alloc.c +82 -26
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +20 -8
- data/vendor/tmp/llama.cpp/ggml-backend.c +183 -69
- data/vendor/tmp/llama.cpp/ggml-backend.h +4 -4
- data/vendor/tmp/llama.cpp/ggml-blas.cpp +363 -0
- data/vendor/tmp/llama.cpp/ggml-blas.h +23 -0
- data/vendor/tmp/llama.cpp/ggml-common.h +6 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +104 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +674 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +88 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +419 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +112 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +206 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +286 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +103 -135
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +29 -13
- data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +45 -33
- data/vendor/tmp/llama.cpp/ggml-metal.metal +83 -59
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +15 -14
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +26 -90
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +74522 -14913
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +631 -471
- data/vendor/tmp/llama.cpp/ggml.c +278 -603
- data/vendor/tmp/llama.cpp/ggml.h +9 -28
- data/vendor/tmp/llama.cpp/llama.cpp +345 -473
- data/vendor/tmp/llama.cpp/llama.h +21 -43
- metadata +134 -7
- data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
- data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
- data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
@@ -172,8 +172,10 @@ enum ggml_metal_kernel_type {
|
|
172
172
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
|
173
173
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
|
174
174
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
175
|
-
|
176
|
-
|
175
|
+
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
|
176
|
+
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
|
177
|
+
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
|
178
|
+
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
|
177
179
|
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
178
180
|
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
179
181
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
@@ -626,8 +628,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
626
628
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
|
627
629
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
628
630
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
629
|
-
GGML_METAL_ADD_KERNEL(
|
630
|
-
GGML_METAL_ADD_KERNEL(
|
631
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
|
632
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
633
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
634
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
|
631
635
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
632
636
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
633
637
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
@@ -740,7 +744,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
740
744
|
case GGML_UNARY_OP_GELU:
|
741
745
|
case GGML_UNARY_OP_GELU_QUICK:
|
742
746
|
case GGML_UNARY_OP_SILU:
|
743
|
-
return
|
747
|
+
return ggml_is_contiguous(op->src[0]);
|
744
748
|
default:
|
745
749
|
return false;
|
746
750
|
}
|
@@ -779,6 +783,12 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
779
783
|
case GGML_OP_LEAKY_RELU:
|
780
784
|
return true;
|
781
785
|
case GGML_OP_FLASH_ATTN_EXT:
|
786
|
+
if (op->src[1]->type != GGML_TYPE_F16) {
|
787
|
+
return false;
|
788
|
+
}
|
789
|
+
if (op->src[2]->type != GGML_TYPE_F16) {
|
790
|
+
return false;
|
791
|
+
}
|
782
792
|
if (op->src[0]->ne[0] == 256) {
|
783
793
|
return false;
|
784
794
|
}
|
@@ -1852,9 +1862,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1852
1862
|
// ne21 = n_rows
|
1853
1863
|
const int dst_rows = ne20*ne21;
|
1854
1864
|
const int dst_rows_min = n_as;
|
1865
|
+
const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4;
|
1855
1866
|
|
1856
1867
|
// max size of the rowids array in the kernel shared buffer
|
1857
|
-
GGML_ASSERT(dst_rows <=
|
1868
|
+
GGML_ASSERT(dst_rows <= dst_rows_max);
|
1858
1869
|
|
1859
1870
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
1860
1871
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
@@ -2279,7 +2290,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2279
2290
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
2280
2291
|
const int mode = ((int32_t *) dst->op_params)[2];
|
2281
2292
|
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
2282
|
-
const int
|
2293
|
+
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
2283
2294
|
|
2284
2295
|
float freq_base;
|
2285
2296
|
float freq_scale;
|
@@ -2296,22 +2307,23 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2296
2307
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
2297
2308
|
|
2298
2309
|
const bool is_neox = mode & 2;
|
2299
|
-
const bool is_glm = mode & 4;
|
2300
2310
|
|
2301
|
-
|
2311
|
+
id<MTLComputePipelineState> pipeline = nil;
|
2302
2312
|
|
2303
2313
|
if (!is_neox) {
|
2304
|
-
|
2314
|
+
switch (src0->type) {
|
2315
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
|
2316
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
|
2317
|
+
default: GGML_ASSERT(false);
|
2318
|
+
};
|
2319
|
+
} else {
|
2320
|
+
switch (src0->type) {
|
2321
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
|
2322
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
|
2323
|
+
default: GGML_ASSERT(false);
|
2324
|
+
};
|
2305
2325
|
}
|
2306
2326
|
|
2307
|
-
id<MTLComputePipelineState> pipeline = nil;
|
2308
|
-
|
2309
|
-
switch (src0->type) {
|
2310
|
-
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
|
2311
|
-
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
|
2312
|
-
default: GGML_ASSERT(false);
|
2313
|
-
};
|
2314
|
-
|
2315
2327
|
[encoder setComputePipelineState:pipeline];
|
2316
2328
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2317
2329
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
@@ -2339,14 +2351,13 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2339
2351
|
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
|
2340
2352
|
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
|
2341
2353
|
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
|
2342
|
-
[encoder setBytes:&
|
2343
|
-
[encoder setBytes:&
|
2344
|
-
[encoder setBytes:&
|
2345
|
-
[encoder setBytes:&
|
2346
|
-
[encoder setBytes:&
|
2347
|
-
[encoder setBytes:&
|
2348
|
-
[encoder setBytes:&
|
2349
|
-
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
|
2354
|
+
[encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
|
2355
|
+
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
|
2356
|
+
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
|
2357
|
+
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
|
2358
|
+
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
|
2359
|
+
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
|
2360
|
+
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
|
2350
2361
|
|
2351
2362
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2352
2363
|
} break;
|
@@ -3034,12 +3045,6 @@ GGML_CALL static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend
|
|
3034
3045
|
UNUSED(buft);
|
3035
3046
|
}
|
3036
3047
|
|
3037
|
-
GGML_CALL static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
|
3038
|
-
return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
|
3039
|
-
|
3040
|
-
UNUSED(buft);
|
3041
|
-
}
|
3042
|
-
|
3043
3048
|
GGML_CALL static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
3044
3049
|
return true;
|
3045
3050
|
|
@@ -3054,7 +3059,6 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
|
3054
3059
|
/* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
|
3055
3060
|
/* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size,
|
3056
3061
|
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
|
3057
|
-
/* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
|
3058
3062
|
/* .is_host = */ ggml_backend_metal_buffer_type_is_host,
|
3059
3063
|
},
|
3060
3064
|
/* .context = */ NULL,
|
@@ -3169,6 +3173,12 @@ GGML_CALL static bool ggml_backend_metal_supports_op(ggml_backend_t backend, con
|
|
3169
3173
|
return ggml_metal_supports_op(metal_ctx, op);
|
3170
3174
|
}
|
3171
3175
|
|
3176
|
+
GGML_CALL static bool ggml_backend_metal_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
|
3177
|
+
return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name;
|
3178
|
+
|
3179
|
+
UNUSED(backend);
|
3180
|
+
}
|
3181
|
+
|
3172
3182
|
static struct ggml_backend_i ggml_backend_metal_i = {
|
3173
3183
|
/* .get_name = */ ggml_backend_metal_name,
|
3174
3184
|
/* .free = */ ggml_backend_metal_free,
|
@@ -3179,9 +3189,11 @@ static struct ggml_backend_i ggml_backend_metal_i = {
|
|
3179
3189
|
/* .synchronize = */ NULL,
|
3180
3190
|
/* .graph_plan_create = */ NULL,
|
3181
3191
|
/* .graph_plan_free = */ NULL,
|
3192
|
+
/* .graph_plan_update = */ NULL,
|
3182
3193
|
/* .graph_plan_compute = */ NULL,
|
3183
3194
|
/* .graph_compute = */ ggml_backend_metal_graph_compute,
|
3184
3195
|
/* .supports_op = */ ggml_backend_metal_supports_op,
|
3196
|
+
/* .supports_buft = */ ggml_backend_metal_supports_buft,
|
3185
3197
|
/* .offload_op = */ NULL,
|
3186
3198
|
/* .event_new = */ NULL,
|
3187
3199
|
/* .event_free = */ NULL,
|
@@ -1654,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
|
1654
1654
|
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
1655
1655
|
static void rope_yarn(
|
1656
1656
|
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
1657
|
-
thread float * cos_theta, thread float * sin_theta
|
1658
|
-
) {
|
1657
|
+
thread float * cos_theta, thread float * sin_theta) {
|
1659
1658
|
// Get n-d rotational scaling corrected for extrapolation
|
1660
1659
|
float theta_interp = freq_scale * theta_extrap;
|
1661
1660
|
float theta = theta_interp;
|
@@ -1672,19 +1671,20 @@ static void rope_yarn(
|
|
1672
1671
|
|
1673
1672
|
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
1674
1673
|
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
1675
|
-
static float rope_yarn_corr_factor(int n_dims, int
|
1676
|
-
return n_dims * log(
|
1674
|
+
static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
|
1675
|
+
return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
1677
1676
|
}
|
1678
1677
|
|
1679
1678
|
static void rope_yarn_corr_dims(
|
1680
|
-
int n_dims, int
|
1679
|
+
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
1681
1680
|
) {
|
1682
1681
|
// start and end correction dims
|
1683
|
-
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims,
|
1684
|
-
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims,
|
1682
|
+
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
|
1683
|
+
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
|
1685
1684
|
}
|
1686
1685
|
|
1687
|
-
|
1686
|
+
template<typename T>
|
1687
|
+
kernel void kernel_rope_norm(
|
1688
1688
|
device const void * src0,
|
1689
1689
|
device const int32_t * src1,
|
1690
1690
|
device const float * src2,
|
@@ -1707,8 +1707,7 @@ typedef void (rope_t)(
|
|
1707
1707
|
constant uint64_t & nb3,
|
1708
1708
|
constant int & n_past,
|
1709
1709
|
constant int & n_dims,
|
1710
|
-
constant int &
|
1711
|
-
constant int & n_orig_ctx,
|
1710
|
+
constant int & n_ctx_orig,
|
1712
1711
|
constant float & freq_base,
|
1713
1712
|
constant float & freq_scale,
|
1714
1713
|
constant float & ext_factor,
|
@@ -1717,10 +1716,52 @@ typedef void (rope_t)(
|
|
1717
1716
|
constant float & beta_slow,
|
1718
1717
|
uint tiitg[[thread_index_in_threadgroup]],
|
1719
1718
|
uint3 tptg[[threads_per_threadgroup]],
|
1720
|
-
uint3 tgpig[[threadgroup_position_in_grid]])
|
1719
|
+
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
1720
|
+
const int64_t i3 = tgpig[2];
|
1721
|
+
const int64_t i2 = tgpig[1];
|
1722
|
+
const int64_t i1 = tgpig[0];
|
1723
|
+
|
1724
|
+
float corr_dims[2];
|
1725
|
+
rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
1726
|
+
|
1727
|
+
device const int32_t * pos = src1;
|
1728
|
+
|
1729
|
+
const float theta_base = (float) pos[i2];
|
1730
|
+
const float inv_ndims = -1.f/n_dims;
|
1731
|
+
|
1732
|
+
float cos_theta;
|
1733
|
+
float sin_theta;
|
1734
|
+
|
1735
|
+
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
1736
|
+
if (i0 < n_dims) {
|
1737
|
+
const int64_t ic = i0/2;
|
1738
|
+
|
1739
|
+
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
1740
|
+
|
1741
|
+
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
1742
|
+
|
1743
|
+
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
1744
|
+
|
1745
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
1746
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1747
|
+
|
1748
|
+
const float x0 = src[0];
|
1749
|
+
const float x1 = src[1];
|
1750
|
+
|
1751
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
1752
|
+
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
1753
|
+
} else {
|
1754
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
1755
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1756
|
+
|
1757
|
+
dst_data[0] = src[0];
|
1758
|
+
dst_data[1] = src[1];
|
1759
|
+
}
|
1760
|
+
}
|
1761
|
+
}
|
1721
1762
|
|
1722
1763
|
template<typename T>
|
1723
|
-
kernel void
|
1764
|
+
kernel void kernel_rope_neox(
|
1724
1765
|
device const void * src0,
|
1725
1766
|
device const int32_t * src1,
|
1726
1767
|
device const float * src2,
|
@@ -1743,8 +1784,7 @@ kernel void kernel_rope(
|
|
1743
1784
|
constant uint64_t & nb3,
|
1744
1785
|
constant int & n_past,
|
1745
1786
|
constant int & n_dims,
|
1746
|
-
constant int &
|
1747
|
-
constant int & n_orig_ctx,
|
1787
|
+
constant int & n_ctx_orig,
|
1748
1788
|
constant float & freq_base,
|
1749
1789
|
constant float & freq_scale,
|
1750
1790
|
constant float & ext_factor,
|
@@ -1758,69 +1798,53 @@ kernel void kernel_rope(
|
|
1758
1798
|
const int64_t i2 = tgpig[1];
|
1759
1799
|
const int64_t i1 = tgpig[0];
|
1760
1800
|
|
1761
|
-
const bool is_neox = mode & 2;
|
1762
|
-
|
1763
1801
|
float corr_dims[2];
|
1764
|
-
rope_yarn_corr_dims(n_dims,
|
1802
|
+
rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
1765
1803
|
|
1766
1804
|
device const int32_t * pos = src1;
|
1767
1805
|
|
1768
|
-
const
|
1769
|
-
|
1770
|
-
const float theta_base = (float)p;
|
1806
|
+
const float theta_base = (float) pos[i2];
|
1771
1807
|
const float inv_ndims = -1.f/n_dims;
|
1772
1808
|
|
1773
|
-
|
1774
|
-
|
1775
|
-
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
1776
|
-
|
1777
|
-
float cos_theta, sin_theta;
|
1778
|
-
rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
1779
|
-
|
1780
|
-
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
1781
|
-
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1782
|
-
|
1783
|
-
const T x0 = src[0];
|
1784
|
-
const T x1 = src[1];
|
1809
|
+
float cos_theta;
|
1810
|
+
float sin_theta;
|
1785
1811
|
|
1786
|
-
|
1787
|
-
|
1788
|
-
|
1789
|
-
} else {
|
1790
|
-
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
|
1791
|
-
if (ic < n_dims) {
|
1792
|
-
const int64_t i0 = ic/2;
|
1812
|
+
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
1813
|
+
if (i0 < n_dims) {
|
1814
|
+
const int64_t ic = i0/2;
|
1793
1815
|
|
1794
|
-
|
1795
|
-
|
1796
|
-
const float theta = theta_base * pow(freq_base, inv_ndims*ic);
|
1816
|
+
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
1797
1817
|
|
1798
|
-
|
1799
|
-
rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
1818
|
+
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
1800
1819
|
|
1801
|
-
|
1802
|
-
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1820
|
+
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
1803
1821
|
|
1804
|
-
|
1805
|
-
|
1822
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
1823
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
1806
1824
|
|
1807
|
-
|
1808
|
-
|
1809
|
-
} else {
|
1810
|
-
const int64_t i0 = ic;
|
1825
|
+
const float x0 = src[0];
|
1826
|
+
const float x1 = src[n_dims/2];
|
1811
1827
|
|
1812
|
-
|
1813
|
-
|
1828
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
1829
|
+
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
1830
|
+
} else {
|
1831
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
1832
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1814
1833
|
|
1815
|
-
|
1816
|
-
|
1817
|
-
}
|
1834
|
+
dst_data[0] = src[0];
|
1835
|
+
dst_data[1] = src[1];
|
1818
1836
|
}
|
1819
1837
|
}
|
1820
1838
|
}
|
1821
1839
|
|
1822
|
-
|
1823
|
-
|
1840
|
+
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
|
1841
|
+
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
|
1842
|
+
|
1843
|
+
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
|
1844
|
+
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
|
1845
|
+
|
1846
|
+
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
|
1847
|
+
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
|
1824
1848
|
|
1825
1849
|
typedef void (im2col_t)(
|
1826
1850
|
device const float * x,
|
@@ -491,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
|
|
491
491
|
if (remote_ptr != 0) {
|
492
492
|
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
|
493
493
|
ggml_backend_rpc_buffer_interface,
|
494
|
-
new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
|
494
|
+
new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
|
495
495
|
remote_size);
|
496
496
|
return buffer;
|
497
497
|
} else {
|
@@ -540,22 +540,12 @@ GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend
|
|
540
540
|
return ggml_nbytes(tensor);
|
541
541
|
}
|
542
542
|
|
543
|
-
GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
|
544
|
-
if (!ggml_backend_is_rpc(backend)) {
|
545
|
-
return false;
|
546
|
-
}
|
547
|
-
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
548
|
-
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
549
|
-
return buft_ctx->endpoint == rpc_ctx->endpoint;
|
550
|
-
}
|
551
|
-
|
552
543
|
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
|
553
544
|
/* .get_name = */ ggml_backend_rpc_buffer_type_name,
|
554
545
|
/* .alloc_buffer = */ ggml_backend_rpc_buffer_type_alloc_buffer,
|
555
546
|
/* .get_alignment = */ ggml_backend_rpc_buffer_type_get_alignment,
|
556
547
|
/* .get_max_size = */ ggml_backend_rpc_get_max_size,
|
557
548
|
/* .get_alloc_size = */ ggml_backend_rpc_buffer_type_get_alloc_size,
|
558
|
-
/* .supports_backend = */ ggml_backend_rpc_buffer_type_supports_backend,
|
559
549
|
/* .is_host = */ NULL,
|
560
550
|
};
|
561
551
|
|
@@ -634,8 +624,17 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
|
|
634
624
|
GGML_CALL static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
|
635
625
|
UNUSED(backend);
|
636
626
|
UNUSED(op);
|
637
|
-
|
638
|
-
return
|
627
|
+
//TODO: call the remote backend and cache the results
|
628
|
+
return true;
|
629
|
+
}
|
630
|
+
|
631
|
+
GGML_CALL static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
|
632
|
+
if (buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
|
633
|
+
return false;
|
634
|
+
}
|
635
|
+
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
636
|
+
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
637
|
+
return buft_ctx->endpoint == rpc_ctx->endpoint;
|
639
638
|
}
|
640
639
|
|
641
640
|
static ggml_backend_i ggml_backend_rpc_interface = {
|
@@ -648,9 +647,11 @@ static ggml_backend_i ggml_backend_rpc_interface = {
|
|
648
647
|
/* .synchronize = */ ggml_backend_rpc_synchronize,
|
649
648
|
/* .graph_plan_create = */ NULL,
|
650
649
|
/* .graph_plan_free = */ NULL,
|
650
|
+
/* .graph_plan_update = */ NULL,
|
651
651
|
/* .graph_plan_compute = */ NULL,
|
652
652
|
/* .graph_compute = */ ggml_backend_rpc_graph_compute,
|
653
653
|
/* .supports_op = */ ggml_backend_rpc_supports_op,
|
654
|
+
/* .supports_buft = */ ggml_backend_rpc_supports_buft,
|
654
655
|
/* .offload_op = */ NULL,
|
655
656
|
/* .event_new = */ NULL,
|
656
657
|
/* .event_free = */ NULL,
|
@@ -692,7 +693,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
|
|
692
693
|
GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
|
693
694
|
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
|
694
695
|
/* .endpoint = */ endpoint,
|
695
|
-
/* .name = */ "RPC",
|
696
|
+
/* .name = */ "RPC[" + std::string(endpoint) + "]",
|
696
697
|
};
|
697
698
|
|
698
699
|
ggml_backend_t backend = new ggml_backend {
|