llama_cpp 0.15.4 → 0.16.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (161) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/ext/llama_cpp/extconf.rb +3 -2
  4. data/ext/llama_cpp/llama_cpp.cpp +17 -3
  5. data/lib/llama_cpp/version.rb +2 -2
  6. data/sig/llama_cpp.rbs +15 -1
  7. data/vendor/tmp/llama.cpp/Makefile +166 -82
  8. data/vendor/tmp/llama.cpp/ggml-alloc.c +82 -26
  9. data/vendor/tmp/llama.cpp/ggml-backend-impl.h +20 -8
  10. data/vendor/tmp/llama.cpp/ggml-backend.c +183 -69
  11. data/vendor/tmp/llama.cpp/ggml-backend.h +4 -4
  12. data/vendor/tmp/llama.cpp/ggml-blas.cpp +363 -0
  13. data/vendor/tmp/llama.cpp/ggml-blas.h +23 -0
  14. data/vendor/tmp/llama.cpp/ggml-common.h +6 -0
  15. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
  16. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
  17. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +104 -0
  18. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
  19. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
  20. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
  21. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
  22. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
  23. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
  24. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +674 -0
  25. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
  26. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
  27. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
  28. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
  29. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
  30. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +88 -0
  31. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +419 -0
  32. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
  33. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
  34. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
  35. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +112 -0
  36. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
  37. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
  38. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +206 -0
  39. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
  40. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  41. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  42. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  43. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  44. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  45. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
  127. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
  128. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
  129. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
  130. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
  131. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
  132. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
  133. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
  134. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
  135. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
  136. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
  137. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
  138. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
  139. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
  140. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
  141. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
  142. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +286 -0
  143. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
  144. data/vendor/tmp/llama.cpp/ggml-cuda.cu +103 -135
  145. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +29 -13
  146. data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
  147. data/vendor/tmp/llama.cpp/ggml-metal.m +45 -33
  148. data/vendor/tmp/llama.cpp/ggml-metal.metal +83 -59
  149. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +15 -14
  150. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +26 -90
  151. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +74522 -14913
  152. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +631 -471
  153. data/vendor/tmp/llama.cpp/ggml.c +278 -603
  154. data/vendor/tmp/llama.cpp/ggml.h +9 -28
  155. data/vendor/tmp/llama.cpp/llama.cpp +345 -473
  156. data/vendor/tmp/llama.cpp/llama.h +21 -43
  157. metadata +134 -7
  158. data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
  159. data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
  160. data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
  161. data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
@@ -172,8 +172,10 @@ enum ggml_metal_kernel_type {
172
172
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
173
173
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
174
174
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
175
- GGML_METAL_KERNEL_TYPE_ROPE_F32,
176
- GGML_METAL_KERNEL_TYPE_ROPE_F16,
175
+ GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
176
+ GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
177
+ GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
178
+ GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
177
179
  GGML_METAL_KERNEL_TYPE_IM2COL_F16,
178
180
  GGML_METAL_KERNEL_TYPE_IM2COL_F32,
179
181
  GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -626,8 +628,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
626
628
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
627
629
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
628
630
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
629
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
630
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
631
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
632
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
633
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
634
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
631
635
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
632
636
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
633
637
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
@@ -740,7 +744,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
740
744
  case GGML_UNARY_OP_GELU:
741
745
  case GGML_UNARY_OP_GELU_QUICK:
742
746
  case GGML_UNARY_OP_SILU:
743
- return true;
747
+ return ggml_is_contiguous(op->src[0]);
744
748
  default:
745
749
  return false;
746
750
  }
@@ -779,6 +783,12 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
779
783
  case GGML_OP_LEAKY_RELU:
780
784
  return true;
781
785
  case GGML_OP_FLASH_ATTN_EXT:
786
+ if (op->src[1]->type != GGML_TYPE_F16) {
787
+ return false;
788
+ }
789
+ if (op->src[2]->type != GGML_TYPE_F16) {
790
+ return false;
791
+ }
782
792
  if (op->src[0]->ne[0] == 256) {
783
793
  return false;
784
794
  }
@@ -1852,9 +1862,10 @@ static enum ggml_status ggml_metal_graph_compute(
1852
1862
  // ne21 = n_rows
1853
1863
  const int dst_rows = ne20*ne21;
1854
1864
  const int dst_rows_min = n_as;
1865
+ const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4;
1855
1866
 
1856
1867
  // max size of the rowids array in the kernel shared buffer
1857
- GGML_ASSERT(dst_rows <= 2048);
1868
+ GGML_ASSERT(dst_rows <= dst_rows_max);
1858
1869
 
1859
1870
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1860
1871
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
@@ -2279,7 +2290,7 @@ static enum ggml_status ggml_metal_graph_compute(
2279
2290
  const int n_dims = ((int32_t *) dst->op_params)[1];
2280
2291
  const int mode = ((int32_t *) dst->op_params)[2];
2281
2292
  // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
2282
- const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
2293
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
2283
2294
 
2284
2295
  float freq_base;
2285
2296
  float freq_scale;
@@ -2296,22 +2307,23 @@ static enum ggml_status ggml_metal_graph_compute(
2296
2307
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
2297
2308
 
2298
2309
  const bool is_neox = mode & 2;
2299
- const bool is_glm = mode & 4;
2300
2310
 
2301
- GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
2311
+ id<MTLComputePipelineState> pipeline = nil;
2302
2312
 
2303
2313
  if (!is_neox) {
2304
- GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
2314
+ switch (src0->type) {
2315
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
2316
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
2317
+ default: GGML_ASSERT(false);
2318
+ };
2319
+ } else {
2320
+ switch (src0->type) {
2321
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
2322
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
2323
+ default: GGML_ASSERT(false);
2324
+ };
2305
2325
  }
2306
2326
 
2307
- id<MTLComputePipelineState> pipeline = nil;
2308
-
2309
- switch (src0->type) {
2310
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
2311
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
2312
- default: GGML_ASSERT(false);
2313
- };
2314
-
2315
2327
  [encoder setComputePipelineState:pipeline];
2316
2328
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2317
2329
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -2339,14 +2351,13 @@ static enum ggml_status ggml_metal_graph_compute(
2339
2351
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
2340
2352
  [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
2341
2353
  [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
2342
- [encoder setBytes:&mode length:sizeof( int) atIndex:22];
2343
- [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
2344
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
2345
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
2346
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
2347
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
2348
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
2349
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
2354
+ [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
2355
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
2356
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
2357
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
2358
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
2359
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
2360
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
2350
2361
 
2351
2362
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2352
2363
  } break;
@@ -3034,12 +3045,6 @@ GGML_CALL static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend
3034
3045
  UNUSED(buft);
3035
3046
  }
3036
3047
 
3037
- GGML_CALL static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
3038
- return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
3039
-
3040
- UNUSED(buft);
3041
- }
3042
-
3043
3048
  GGML_CALL static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
3044
3049
  return true;
3045
3050
 
@@ -3054,7 +3059,6 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
3054
3059
  /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
3055
3060
  /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size,
3056
3061
  /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
3057
- /* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
3058
3062
  /* .is_host = */ ggml_backend_metal_buffer_type_is_host,
3059
3063
  },
3060
3064
  /* .context = */ NULL,
@@ -3169,6 +3173,12 @@ GGML_CALL static bool ggml_backend_metal_supports_op(ggml_backend_t backend, con
3169
3173
  return ggml_metal_supports_op(metal_ctx, op);
3170
3174
  }
3171
3175
 
3176
+ GGML_CALL static bool ggml_backend_metal_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
3177
+ return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name;
3178
+
3179
+ UNUSED(backend);
3180
+ }
3181
+
3172
3182
  static struct ggml_backend_i ggml_backend_metal_i = {
3173
3183
  /* .get_name = */ ggml_backend_metal_name,
3174
3184
  /* .free = */ ggml_backend_metal_free,
@@ -3179,9 +3189,11 @@ static struct ggml_backend_i ggml_backend_metal_i = {
3179
3189
  /* .synchronize = */ NULL,
3180
3190
  /* .graph_plan_create = */ NULL,
3181
3191
  /* .graph_plan_free = */ NULL,
3192
+ /* .graph_plan_update = */ NULL,
3182
3193
  /* .graph_plan_compute = */ NULL,
3183
3194
  /* .graph_compute = */ ggml_backend_metal_graph_compute,
3184
3195
  /* .supports_op = */ ggml_backend_metal_supports_op,
3196
+ /* .supports_buft = */ ggml_backend_metal_supports_buft,
3185
3197
  /* .offload_op = */ NULL,
3186
3198
  /* .event_new = */ NULL,
3187
3199
  /* .event_free = */ NULL,
@@ -1654,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
1654
1654
  // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
1655
1655
  static void rope_yarn(
1656
1656
  float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
1657
- thread float * cos_theta, thread float * sin_theta
1658
- ) {
1657
+ thread float * cos_theta, thread float * sin_theta) {
1659
1658
  // Get n-d rotational scaling corrected for extrapolation
1660
1659
  float theta_interp = freq_scale * theta_extrap;
1661
1660
  float theta = theta_interp;
@@ -1672,19 +1671,20 @@ static void rope_yarn(
1672
1671
 
1673
1672
  // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
1674
1673
  // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
1675
- static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
1676
- return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
1674
+ static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
1675
+ return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
1677
1676
  }
1678
1677
 
1679
1678
  static void rope_yarn_corr_dims(
1680
- int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
1679
+ int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
1681
1680
  ) {
1682
1681
  // start and end correction dims
1683
- dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
1684
- dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
1682
+ dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
1683
+ dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
1685
1684
  }
1686
1685
 
1687
- typedef void (rope_t)(
1686
+ template<typename T>
1687
+ kernel void kernel_rope_norm(
1688
1688
  device const void * src0,
1689
1689
  device const int32_t * src1,
1690
1690
  device const float * src2,
@@ -1707,8 +1707,7 @@ typedef void (rope_t)(
1707
1707
  constant uint64_t & nb3,
1708
1708
  constant int & n_past,
1709
1709
  constant int & n_dims,
1710
- constant int & mode,
1711
- constant int & n_orig_ctx,
1710
+ constant int & n_ctx_orig,
1712
1711
  constant float & freq_base,
1713
1712
  constant float & freq_scale,
1714
1713
  constant float & ext_factor,
@@ -1717,10 +1716,52 @@ typedef void (rope_t)(
1717
1716
  constant float & beta_slow,
1718
1717
  uint tiitg[[thread_index_in_threadgroup]],
1719
1718
  uint3 tptg[[threads_per_threadgroup]],
1720
- uint3 tgpig[[threadgroup_position_in_grid]]);
1719
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
1720
+ const int64_t i3 = tgpig[2];
1721
+ const int64_t i2 = tgpig[1];
1722
+ const int64_t i1 = tgpig[0];
1723
+
1724
+ float corr_dims[2];
1725
+ rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
1726
+
1727
+ device const int32_t * pos = src1;
1728
+
1729
+ const float theta_base = (float) pos[i2];
1730
+ const float inv_ndims = -1.f/n_dims;
1731
+
1732
+ float cos_theta;
1733
+ float sin_theta;
1734
+
1735
+ for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1736
+ if (i0 < n_dims) {
1737
+ const int64_t ic = i0/2;
1738
+
1739
+ const float theta = theta_base * pow(freq_base, inv_ndims*i0);
1740
+
1741
+ const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
1742
+
1743
+ rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1744
+
1745
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1746
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1747
+
1748
+ const float x0 = src[0];
1749
+ const float x1 = src[1];
1750
+
1751
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
1752
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
1753
+ } else {
1754
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1755
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1756
+
1757
+ dst_data[0] = src[0];
1758
+ dst_data[1] = src[1];
1759
+ }
1760
+ }
1761
+ }
1721
1762
 
1722
1763
  template<typename T>
1723
- kernel void kernel_rope(
1764
+ kernel void kernel_rope_neox(
1724
1765
  device const void * src0,
1725
1766
  device const int32_t * src1,
1726
1767
  device const float * src2,
@@ -1743,8 +1784,7 @@ kernel void kernel_rope(
1743
1784
  constant uint64_t & nb3,
1744
1785
  constant int & n_past,
1745
1786
  constant int & n_dims,
1746
- constant int & mode,
1747
- constant int & n_orig_ctx,
1787
+ constant int & n_ctx_orig,
1748
1788
  constant float & freq_base,
1749
1789
  constant float & freq_scale,
1750
1790
  constant float & ext_factor,
@@ -1758,69 +1798,53 @@ kernel void kernel_rope(
1758
1798
  const int64_t i2 = tgpig[1];
1759
1799
  const int64_t i1 = tgpig[0];
1760
1800
 
1761
- const bool is_neox = mode & 2;
1762
-
1763
1801
  float corr_dims[2];
1764
- rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
1802
+ rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
1765
1803
 
1766
1804
  device const int32_t * pos = src1;
1767
1805
 
1768
- const int64_t p = pos[i2];
1769
-
1770
- const float theta_base = (float)p;
1806
+ const float theta_base = (float) pos[i2];
1771
1807
  const float inv_ndims = -1.f/n_dims;
1772
1808
 
1773
- if (!is_neox) {
1774
- for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1775
- const float theta = theta_base * pow(freq_base, inv_ndims*i0);
1776
-
1777
- float cos_theta, sin_theta;
1778
- rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1779
-
1780
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1781
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1782
-
1783
- const T x0 = src[0];
1784
- const T x1 = src[1];
1809
+ float cos_theta;
1810
+ float sin_theta;
1785
1811
 
1786
- dst_data[0] = x0*cos_theta - x1*sin_theta;
1787
- dst_data[1] = x0*sin_theta + x1*cos_theta;
1788
- }
1789
- } else {
1790
- for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
1791
- if (ic < n_dims) {
1792
- const int64_t i0 = ic/2;
1812
+ for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1813
+ if (i0 < n_dims) {
1814
+ const int64_t ic = i0/2;
1793
1815
 
1794
- const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
1795
-
1796
- const float theta = theta_base * pow(freq_base, inv_ndims*ic);
1816
+ const float theta = theta_base * pow(freq_base, inv_ndims*i0);
1797
1817
 
1798
- float cos_theta, sin_theta;
1799
- rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
1818
+ const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
1800
1819
 
1801
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1802
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1820
+ rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1803
1821
 
1804
- const float x0 = src[0];
1805
- const float x1 = src[n_dims/2];
1822
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
1823
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
1806
1824
 
1807
- dst_data[0] = x0*cos_theta - x1*sin_theta;
1808
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
1809
- } else {
1810
- const int64_t i0 = ic;
1825
+ const float x0 = src[0];
1826
+ const float x1 = src[n_dims/2];
1811
1827
 
1812
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1813
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1828
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
1829
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
1830
+ } else {
1831
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1832
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1814
1833
 
1815
- dst_data[0] = src[0];
1816
- dst_data[1] = src[1];
1817
- }
1834
+ dst_data[0] = src[0];
1835
+ dst_data[1] = src[1];
1818
1836
  }
1819
1837
  }
1820
1838
  }
1821
1839
 
1822
- template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
1823
- template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
1840
+ typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
1841
+ typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
1842
+
1843
+ template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
1844
+ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
1845
+
1846
+ template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
1847
+ template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
1824
1848
 
1825
1849
  typedef void (im2col_t)(
1826
1850
  device const float * x,
@@ -491,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
491
491
  if (remote_ptr != 0) {
492
492
  ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
493
493
  ggml_backend_rpc_buffer_interface,
494
- new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
494
+ new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
495
495
  remote_size);
496
496
  return buffer;
497
497
  } else {
@@ -540,22 +540,12 @@ GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend
540
540
  return ggml_nbytes(tensor);
541
541
  }
542
542
 
543
- GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
544
- if (!ggml_backend_is_rpc(backend)) {
545
- return false;
546
- }
547
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
548
- ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
549
- return buft_ctx->endpoint == rpc_ctx->endpoint;
550
- }
551
-
552
543
  static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
553
544
  /* .get_name = */ ggml_backend_rpc_buffer_type_name,
554
545
  /* .alloc_buffer = */ ggml_backend_rpc_buffer_type_alloc_buffer,
555
546
  /* .get_alignment = */ ggml_backend_rpc_buffer_type_get_alignment,
556
547
  /* .get_max_size = */ ggml_backend_rpc_get_max_size,
557
548
  /* .get_alloc_size = */ ggml_backend_rpc_buffer_type_get_alloc_size,
558
- /* .supports_backend = */ ggml_backend_rpc_buffer_type_supports_backend,
559
549
  /* .is_host = */ NULL,
560
550
  };
561
551
 
@@ -634,8 +624,17 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
634
624
  GGML_CALL static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
635
625
  UNUSED(backend);
636
626
  UNUSED(op);
637
- GGML_ASSERT(false && "not implemented");
638
- return false;
627
+ //TODO: call the remote backend and cache the results
628
+ return true;
629
+ }
630
+
631
+ GGML_CALL static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
632
+ if (buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
633
+ return false;
634
+ }
635
+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
636
+ ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
637
+ return buft_ctx->endpoint == rpc_ctx->endpoint;
639
638
  }
640
639
 
641
640
  static ggml_backend_i ggml_backend_rpc_interface = {
@@ -648,9 +647,11 @@ static ggml_backend_i ggml_backend_rpc_interface = {
648
647
  /* .synchronize = */ ggml_backend_rpc_synchronize,
649
648
  /* .graph_plan_create = */ NULL,
650
649
  /* .graph_plan_free = */ NULL,
650
+ /* .graph_plan_update = */ NULL,
651
651
  /* .graph_plan_compute = */ NULL,
652
652
  /* .graph_compute = */ ggml_backend_rpc_graph_compute,
653
653
  /* .supports_op = */ ggml_backend_rpc_supports_op,
654
+ /* .supports_buft = */ ggml_backend_rpc_supports_buft,
654
655
  /* .offload_op = */ NULL,
655
656
  /* .event_new = */ NULL,
656
657
  /* .event_free = */ NULL,
@@ -692,7 +693,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
692
693
  GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
693
694
  ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
694
695
  /* .endpoint = */ endpoint,
695
- /* .name = */ "RPC",
696
+ /* .name = */ "RPC[" + std::string(endpoint) + "]",
696
697
  };
697
698
 
698
699
  ggml_backend_t backend = new ggml_backend {