llama_cpp 0.15.4 → 0.16.1
This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/ext/llama_cpp/extconf.rb +3 -2
- data/ext/llama_cpp/llama_cpp.cpp +17 -3
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +15 -1
- data/vendor/tmp/llama.cpp/Makefile +166 -82
- data/vendor/tmp/llama.cpp/ggml-alloc.c +82 -26
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +20 -8
- data/vendor/tmp/llama.cpp/ggml-backend.c +183 -69
- data/vendor/tmp/llama.cpp/ggml-backend.h +4 -4
- data/vendor/tmp/llama.cpp/ggml-blas.cpp +363 -0
- data/vendor/tmp/llama.cpp/ggml-blas.h +23 -0
- data/vendor/tmp/llama.cpp/ggml-common.h +6 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +104 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +674 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +88 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +419 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +112 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +206 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +286 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +103 -135
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +29 -13
- data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +45 -33
- data/vendor/tmp/llama.cpp/ggml-metal.metal +83 -59
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +15 -14
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +26 -90
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +74522 -14913
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +631 -471
- data/vendor/tmp/llama.cpp/ggml.c +278 -603
- data/vendor/tmp/llama.cpp/ggml.h +9 -28
- data/vendor/tmp/llama.cpp/llama.cpp +345 -473
- data/vendor/tmp/llama.cpp/llama.h +21 -43
- metadata +134 -7
- data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
- data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
- data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
data/vendor/tmp/llama.cpp/ggml-metal.m

@@ -172,8 +172,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
-    GGML_METAL_KERNEL_TYPE_ROPE_F32,
-    GGML_METAL_KERNEL_TYPE_ROPE_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -626,8 +628,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
@@ -740,7 +744,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
         case GGML_UNARY_OP_GELU:
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_SILU:
-            return true;
+            return ggml_is_contiguous(op->src[0]);
         default:
             return false;
     }
@@ -779,6 +783,12 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
         case GGML_OP_LEAKY_RELU:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
+            if (op->src[1]->type != GGML_TYPE_F16) {
+                return false;
+            }
+            if (op->src[2]->type != GGML_TYPE_F16) {
+                return false;
+            }
             if (op->src[0]->ne[0] == 256) {
                 return false;
             }
@@ -1852,9 +1862,10 @@ static enum ggml_status ggml_metal_graph_compute(
         // ne21 = n_rows
         const int dst_rows = ne20*ne21;
         const int dst_rows_min = n_as;
+        const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4;

         // max size of the rowids array in the kernel shared buffer
-        GGML_ASSERT(dst_rows <=
+        GGML_ASSERT(dst_rows <= dst_rows_max);

         // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
         // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
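The row-id cap above stops being a hardcoded constant and is instead derived from the device's threadgroup memory. A worked instance of the new formula in C++, using a hypothetical 32 KiB maxThreadgroupMemoryLength (the device value is an assumption, not from the diff):

// Evaluates the dst_rows_max expression from the hunk above for a
// hypothetical device exposing 32 KiB of threadgroup memory:
//   (32768 - 32 - 8192) / 4 = 6136 rows
constexpr int maxThreadgroupMemoryLength = 32 * 1024;                      // assumed device value
constexpr int dst_rows_max = (maxThreadgroupMemoryLength - 32 - 8192) / 4; // = 6136
static_assert(dst_rows_max == 6136, "row-id budget for a 32 KiB device");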
@@ -2279,7 +2290,7 @@ static enum ggml_status ggml_metal_graph_compute(
         const int n_dims = ((int32_t *) dst->op_params)[1];
         const int mode   = ((int32_t *) dst->op_params)[2];
         // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
-        const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+        const int n_ctx_orig = ((int32_t *) dst->op_params)[4];

         float freq_base;
         float freq_scale;
@@ -2296,22 +2307,23 @@ static enum ggml_status ggml_metal_graph_compute(
         memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));

         const bool is_neox = mode & 2;
-        const bool is_glm  = mode & 4;

-        GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
+        id<MTLComputePipelineState> pipeline = nil;

         if (!is_neox) {
-            GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
+            switch (src0->type) {
+                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
+                default: GGML_ASSERT(false);
+            };
+        } else {
+            switch (src0->type) {
+                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+                default: GGML_ASSERT(false);
+            };
         }

-        id<MTLComputePipelineState> pipeline = nil;
-
-        switch (src0->type) {
-            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
-            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
-            default: GGML_ASSERT(false);
-        };
-
         [encoder setComputePipelineState:pipeline];
         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
         [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -2339,14 +2351,13 @@ static enum ggml_status ggml_metal_graph_compute(
         [encoder setBytes:&nb3    length:sizeof(uint64_t) atIndex:19];
         [encoder setBytes:&n_past length:sizeof(     int) atIndex:20];
         [encoder setBytes:&n_dims length:sizeof(     int) atIndex:21];
-        [encoder setBytes:&mode        length:sizeof(   int) atIndex:22];
-        [encoder setBytes:&n_orig_ctx  length:sizeof(   int) atIndex:23];
-        [encoder setBytes:&freq_base   length:sizeof( float) atIndex:24];
-        [encoder setBytes:&freq_scale  length:sizeof( float) atIndex:25];
-        [encoder setBytes:&ext_factor  length:sizeof( float) atIndex:26];
-        [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
-        [encoder setBytes:&beta_fast   length:sizeof( float) atIndex:28];
-        [encoder setBytes:&beta_slow   length:sizeof( float) atIndex:29];
+        [encoder setBytes:&n_ctx_orig length:sizeof(    int) atIndex:22];
+        [encoder setBytes:&freq_base  length:sizeof(  float) atIndex:23];
+        [encoder setBytes:&freq_scale length:sizeof(  float) atIndex:24];
+        [encoder setBytes:&ext_factor length:sizeof(  float) atIndex:25];
+        [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
+        [encoder setBytes:&beta_fast  length:sizeof(  float) atIndex:27];
+        [encoder setBytes:&beta_slow  length:sizeof(  float) atIndex:28];

         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
     } break;
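Taken together, the two RoPE hunks above move the norm/neox decision to the host: the encoder binds one of four specialized pipelines instead of passing mode into a single kernel, and mode disappears from the argument buffer, shifting the remaining byte-argument indices down by one. A minimal C++ sketch of the selection logic (illustrative names, not the actual Metal host code):

// Mirrors the mode/type dispatch shown in the hunk above.
enum class RopePipeline { NormF32, NormF16, NeoxF32, NeoxF16 };

RopePipeline select_rope_pipeline(int mode, bool src_is_f16) {
    const bool is_neox = mode & 2;  // same bit test the backend uses
    if (!is_neox) {
        return src_is_f16 ? RopePipeline::NormF16 : RopePipeline::NormF32;
    }
    return src_is_f16 ? RopePipeline::NeoxF16 : RopePipeline::NeoxF32;
}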
@@ -3034,12 +3045,6 @@ GGML_CALL static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
     UNUSED(buft);
 }

-GGML_CALL static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
-    return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
-
-    UNUSED(buft);
-}
-
 GGML_CALL static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
     return true;

@@ -3054,7 +3059,6 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
         /* .get_alignment    = */ ggml_backend_metal_buffer_type_get_alignment,
         /* .get_max_size     = */ ggml_backend_metal_buffer_type_get_max_size,
         /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
-        /* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
         /* .is_host          = */ ggml_backend_metal_buffer_type_is_host,
     },
     /* .context = */ NULL,
@@ -3169,6 +3173,12 @@ GGML_CALL static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
     return ggml_metal_supports_op(metal_ctx, op);
 }

+GGML_CALL static bool ggml_backend_metal_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name;
+
+    UNUSED(backend);
+}
+
 static struct ggml_backend_i ggml_backend_metal_i = {
     /* .get_name = */ ggml_backend_metal_name,
     /* .free     = */ ggml_backend_metal_free,
@@ -3179,9 +3189,11 @@ static struct ggml_backend_i ggml_backend_metal_i = {
     /* .synchronize        = */ NULL,
     /* .graph_plan_create  = */ NULL,
     /* .graph_plan_free    = */ NULL,
+    /* .graph_plan_update  = */ NULL,
     /* .graph_plan_compute = */ NULL,
     /* .graph_compute      = */ ggml_backend_metal_graph_compute,
     /* .supports_op        = */ ggml_backend_metal_supports_op,
+    /* .supports_buft      = */ ggml_backend_metal_supports_buft,
     /* .offload_op         = */ NULL,
     /* .event_new          = */ NULL,
     /* .event_free         = */ NULL,
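Together with the supports_backend callback removed earlier in this file, this inverts the query: the scheduler now asks a backend whether it can consume a given buffer type, rather than asking the buffer type which backends it supports. A C++ sketch of how a caller might use the new hook, assuming the public wrapper ggml_backend_supports_buft() (added alongside this change, per the ggml-backend.h entry in the file list) forwards to the .supports_buft member:

// Sketch: pick the first backend that accepts a tensor's buffer type.
ggml_backend_t pick_backend(ggml_backend_t * backends, int n_backends,
                            ggml_backend_buffer_type_t buft) {
    for (int i = 0; i < n_backends; ++i) {
        if (ggml_backend_supports_buft(backends[i], buft)) {
            return backends[i];
        }
    }
    return NULL; // no compatible backend
}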
data/vendor/tmp/llama.cpp/ggml-metal.metal

@@ -1654,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 static void rope_yarn(
     float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
-    thread float * cos_theta, thread float * sin_theta
-) {
+    thread float * cos_theta, thread float * sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
@@ -1672,19 +1671,20 @@ static void rope_yarn(

 // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
 // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
-    return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
+    return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
 }

 static void rope_yarn_corr_dims(
-    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
 ) {
     // start and end correction dims
-    dims[0] = max(0.0f,          floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
-    dims[1] = min(n_dims - 1.0f,  ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+    dims[0] = max(0.0f,          floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f,  ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
 }

-typedef void (rope_t)(
+template<typename T>
+kernel void kernel_rope_norm(
     device const void * src0,
     device const int32_t * src1,
     device const float * src2,
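In the comment above, max_pos_emb is the renamed n_ctx_orig, the original training context length that anchors the YaRN ramp. Restating the correction factor computed by rope_yarn_corr_factor in display form:

\[
\mathrm{corr\_fac}(n_{\mathrm{rot}}) \;=\; \frac{n_{\mathrm{dims}}\,\ln\!\bigl(n_{\mathrm{ctx\_orig}} / (2\pi\, n_{\mathrm{rot}})\bigr)}{2\,\ln(\mathrm{base})}
\]

rope_yarn_corr_dims then clamps floor(corr_fac(beta_fast)) and ceil(corr_fac(beta_slow)) into [0, n_dims - 1] to get the start and end correction dims.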
@@ -1707,8 +1707,7 @@ typedef void (rope_t)(
     constant uint64_t & nb3,
     constant int & n_past,
     constant int & n_dims,
-    constant int & mode,
-    constant int & n_orig_ctx,
+    constant int & n_ctx_orig,
     constant float & freq_base,
     constant float & freq_scale,
     constant float & ext_factor,
@@ -1717,10 +1716,52 @@ typedef void (rope_t)(
     constant float & beta_slow,
     uint  tiitg[[thread_index_in_threadgroup]],
     uint3 tptg[[threads_per_threadgroup]],
-    uint3 tgpig[[threadgroup_position_in_grid]]);
+    uint3 tgpig[[threadgroup_position_in_grid]]) {
+    const int64_t i3 = tgpig[2];
+    const int64_t i2 = tgpig[1];
+    const int64_t i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    device const int32_t * pos = src1;
+
+    const float theta_base = (float) pos[i2];
+    const float inv_ndims = -1.f/n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+        if (i0 < n_dims) {
+            const int64_t ic = i0/2;
+
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[1];
+
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[1] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}

 template<typename T>
-kernel void kernel_rope(
+kernel void kernel_rope_neox(
     device const void * src0,
     device const int32_t * src1,
     device const float * src2,
@@ -1743,8 +1784,7 @@ kernel void kernel_rope(
     constant uint64_t & nb3,
     constant int & n_past,
     constant int & n_dims,
-    constant int & mode,
-    constant int & n_orig_ctx,
+    constant int & n_ctx_orig,
     constant float & freq_base,
     constant float & freq_scale,
     constant float & ext_factor,
@@ -1758,69 +1798,53 @@ kernel void kernel_rope(
     const int64_t i2 = tgpig[1];
     const int64_t i1 = tgpig[0];

-    const bool is_neox = mode & 2;
-
     float corr_dims[2];
-    rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+    rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);

     device const int32_t * pos = src1;

-    const int64_t p = pos[i2];
-
-    const float theta_base = (float)p;
+    const float theta_base = (float) pos[i2];
     const float inv_ndims = -1.f/n_dims;

-    if (!is_neox) {
-        for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
-            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
-
-            float cos_theta, sin_theta;
-            rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-            const T x0 = src[0];
-            const T x1 = src[1];
+    float cos_theta;
+    float sin_theta;

-            dst_data[0] = x0*cos_theta - x1*sin_theta;
-            dst_data[1] = x0*sin_theta + x1*cos_theta;
-        }
-    } else {
-        for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
-            if (ic < n_dims) {
-                const int64_t i0 = ic/2;
+    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+        if (i0 < n_dims) {
+            const int64_t ic = i0/2;

-                const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
-
-                const float theta = theta_base * pow(freq_base, inv_ndims*ic);
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);

-                float cos_theta, sin_theta;
-                rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;

-                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

-                const float x0 = src[0];
-                const float x1 = src[n_dims/2];
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);

-                dst_data[0]        = x0*cos_theta - x1*sin_theta;
-                dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
-            } else {
-                const int64_t i0 = ic;
+            const float x0 = src[0];
+            const float x1 = src[n_dims/2];

-                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+            dst_data[0]        = x0*cos_theta - x1*sin_theta;
+            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);

-                dst_data[0] = src[0];
-                dst_data[1] = src[1];
-            }
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
         }
     }
 }

-template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
-template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
+typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
+typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
+
+template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
+template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
+
+template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
+template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;

 typedef void (im2col_t)(
     device const float * x,
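The split kernels make the two rotation layouts explicit: kernel_rope_norm rotates adjacent pairs (i0, i0 + 1), while kernel_rope_neox rotates split-half pairs (ic, ic + n_dims/2). A scalar C++ sketch of the two pairings (illustrative only, not the Metal code):

// "norm" layout: rotate adjacent elements (i0, i0 + 1) of a row.
void rope_norm_pair(float * row, int i0, float cos_t, float sin_t) {
    const float x0 = row[i0], x1 = row[i0 + 1];
    row[i0]     = x0 * cos_t - x1 * sin_t;
    row[i0 + 1] = x0 * sin_t + x1 * cos_t;
}

// "neox" layout: rotate split halves (ic, ic + n_dims/2) of a row.
void rope_neox_pair(float * row, int ic, int n_dims, float cos_t, float sin_t) {
    const float x0 = row[ic], x1 = row[ic + n_dims/2];
    row[ic]            = x0 * cos_t - x1 * sin_t;
    row[ic + n_dims/2] = x0 * sin_t + x1 * cos_t;
}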
data/vendor/tmp/llama.cpp/ggml-rpc.cpp

@@ -491,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
     if (remote_ptr != 0) {
         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
             ggml_backend_rpc_buffer_interface,
-            new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
+            new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
             remote_size);
         return buffer;
     } else {
@@ -540,22 +540,12 @@ GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
     return ggml_nbytes(tensor);
 }

-GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
-    if (!ggml_backend_is_rpc(backend)) {
-        return false;
-    }
-    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
-    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    return buft_ctx->endpoint == rpc_ctx->endpoint;
-}
-
 static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
     /* .get_name         = */ ggml_backend_rpc_buffer_type_name,
     /* .alloc_buffer     = */ ggml_backend_rpc_buffer_type_alloc_buffer,
     /* .get_alignment    = */ ggml_backend_rpc_buffer_type_get_alignment,
     /* .get_max_size     = */ ggml_backend_rpc_get_max_size,
     /* .get_alloc_size   = */ ggml_backend_rpc_buffer_type_get_alloc_size,
-    /* .supports_backend = */ ggml_backend_rpc_buffer_type_supports_backend,
     /* .is_host          = */ NULL,
 };

@@ -634,8 +624,17 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
 GGML_CALL static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
     UNUSED(backend);
     UNUSED(op);
-
-    return
+    //TODO: call the remote backend and cache the results
+    return true;
+}
+
+GGML_CALL static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+    if (buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
+        return false;
+    }
+    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
+    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
+    return buft_ctx->endpoint == rpc_ctx->endpoint;
 }

 static ggml_backend_i ggml_backend_rpc_interface = {
@@ -648,9 +647,11 @@ static ggml_backend_i ggml_backend_rpc_interface = {
     /* .synchronize        = */ ggml_backend_rpc_synchronize,
     /* .graph_plan_create  = */ NULL,
     /* .graph_plan_free    = */ NULL,
+    /* .graph_plan_update  = */ NULL,
     /* .graph_plan_compute = */ NULL,
     /* .graph_compute      = */ ggml_backend_rpc_graph_compute,
     /* .supports_op        = */ ggml_backend_rpc_supports_op,
+    /* .supports_buft      = */ ggml_backend_rpc_supports_buft,
     /* .offload_op         = */ NULL,
     /* .event_new          = */ NULL,
     /* .event_free         = */ NULL,
@@ -692,7 +693,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
 GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
         /* .endpoint = */ endpoint,
-        /* .name     = */ "RPC",
+        /* .name     = */ "RPC[" + std::string(endpoint) + "]",
     };

     ggml_backend_t backend = new ggml_backend {