llama_cpp 0.15.4 → 0.16.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/ext/llama_cpp/extconf.rb +3 -2
  4. data/ext/llama_cpp/llama_cpp.cpp +17 -3
  5. data/lib/llama_cpp/version.rb +2 -2
  6. data/sig/llama_cpp.rbs +15 -1
  7. data/vendor/tmp/llama.cpp/Makefile +166 -82
  8. data/vendor/tmp/llama.cpp/ggml-alloc.c +82 -26
  9. data/vendor/tmp/llama.cpp/ggml-backend-impl.h +20 -8
  10. data/vendor/tmp/llama.cpp/ggml-backend.c +183 -69
  11. data/vendor/tmp/llama.cpp/ggml-backend.h +4 -4
  12. data/vendor/tmp/llama.cpp/ggml-blas.cpp +363 -0
  13. data/vendor/tmp/llama.cpp/ggml-blas.h +23 -0
  14. data/vendor/tmp/llama.cpp/ggml-common.h +6 -0
  15. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
  16. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
  17. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +104 -0
  18. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
  19. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
  20. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
  21. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
  22. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
  23. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
  24. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +674 -0
  25. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
  26. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
  27. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
  28. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
  29. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
  30. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +88 -0
  31. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +419 -0
  32. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
  33. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
  34. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
  35. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +112 -0
  36. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
  37. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
  38. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +206 -0
  39. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
  40. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  41. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  42. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  43. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  44. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  45. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
  127. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
  128. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
  129. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
  130. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
  131. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
  132. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
  133. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
  134. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
  135. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
  136. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
  137. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
  138. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
  139. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
  140. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
  141. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
  142. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +286 -0
  143. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
  144. data/vendor/tmp/llama.cpp/ggml-cuda.cu +103 -135
  145. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +29 -13
  146. data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
  147. data/vendor/tmp/llama.cpp/ggml-metal.m +45 -33
  148. data/vendor/tmp/llama.cpp/ggml-metal.metal +83 -59
  149. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +15 -14
  150. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +26 -90
  151. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +74522 -14913
  152. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +631 -471
  153. data/vendor/tmp/llama.cpp/ggml.c +278 -603
  154. data/vendor/tmp/llama.cpp/ggml.h +9 -28
  155. data/vendor/tmp/llama.cpp/llama.cpp +345 -473
  156. data/vendor/tmp/llama.cpp/llama.h +21 -43
  157. metadata +134 -7
  158. data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
  159. data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
  160. data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
  161. data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
@@ -172,8 +172,10 @@ enum ggml_metal_kernel_type {
172
172
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
173
173
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
174
174
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
175
- GGML_METAL_KERNEL_TYPE_ROPE_F32,
176
- GGML_METAL_KERNEL_TYPE_ROPE_F16,
175
+ GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
176
+ GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
177
+ GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
178
+ GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
177
179
  GGML_METAL_KERNEL_TYPE_IM2COL_F16,
178
180
  GGML_METAL_KERNEL_TYPE_IM2COL_F32,
179
181
  GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -626,8 +628,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
626
628
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
627
629
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
628
630
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
629
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
630
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
631
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
632
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
633
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
634
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
631
635
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
632
636
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
633
637
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
@@ -740,7 +744,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
740
744
  case GGML_UNARY_OP_GELU:
741
745
  case GGML_UNARY_OP_GELU_QUICK:
742
746
  case GGML_UNARY_OP_SILU:
743
- return true;
747
+ return ggml_is_contiguous(op->src[0]);
744
748
  default:
745
749
  return false;
746
750
  }
@@ -779,6 +783,12 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
779
783
  case GGML_OP_LEAKY_RELU:
780
784
  return true;
781
785
  case GGML_OP_FLASH_ATTN_EXT:
786
+ if (op->src[1]->type != GGML_TYPE_F16) {
787
+ return false;
788
+ }
789
+ if (op->src[2]->type != GGML_TYPE_F16) {
790
+ return false;
791
+ }
782
792
  if (op->src[0]->ne[0] == 256) {
783
793
  return false;
784
794
  }
@@ -1852,9 +1862,10 @@ static enum ggml_status ggml_metal_graph_compute(
1852
1862
  // ne21 = n_rows
1853
1863
  const int dst_rows = ne20*ne21;
1854
1864
  const int dst_rows_min = n_as;
1865
+ const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4;
1855
1866
 
1856
1867
  // max size of the rowids array in the kernel shared buffer
1857
- GGML_ASSERT(dst_rows <= 2048);
1868
+ GGML_ASSERT(dst_rows <= dst_rows_max);
1858
1869
 
1859
1870
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1860
1871
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
@@ -2279,7 +2290,7 @@ static enum ggml_status ggml_metal_graph_compute(
2279
2290
  const int n_dims = ((int32_t *) dst->op_params)[1];
2280
2291
  const int mode = ((int32_t *) dst->op_params)[2];
2281
2292
  // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
2282
- const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
2293
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
2283
2294
 
2284
2295
  float freq_base;
2285
2296
  float freq_scale;
@@ -2296,22 +2307,23 @@ static enum ggml_status ggml_metal_graph_compute(
2296
2307
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
2297
2308
 
2298
2309
  const bool is_neox = mode & 2;
2299
- const bool is_glm = mode & 4;
2300
2310
 
2301
- GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
2311
+ id<MTLComputePipelineState> pipeline = nil;
2302
2312
 
2303
2313
  if (!is_neox) {
2304
- GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
2314
+ switch (src0->type) {
2315
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
2316
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
2317
+ default: GGML_ASSERT(false);
2318
+ };
2319
+ } else {
2320
+ switch (src0->type) {
2321
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
2322
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
2323
+ default: GGML_ASSERT(false);
2324
+ };
2305
2325
  }
2306
2326
 
2307
- id<MTLComputePipelineState> pipeline = nil;
2308
-
2309
- switch (src0->type) {
2310
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
2311
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
2312
- default: GGML_ASSERT(false);
2313
- };
2314
-
2315
2327
  [encoder setComputePipelineState:pipeline];
2316
2328
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2317
2329
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -2339,14 +2351,13 @@ static enum ggml_status ggml_metal_graph_compute(
2339
2351
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
2340
2352
  [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
2341
2353
  [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
2342
- [encoder setBytes:&mode length:sizeof( int) atIndex:22];
2343
- [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
2344
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
2345
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
2346
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
2347
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
2348
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
2349
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
2354
+ [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
2355
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
2356
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
2357
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
2358
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
2359
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
2360
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
2350
2361
 
2351
2362
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2352
2363
  } break;
@@ -3034,12 +3045,6 @@ GGML_CALL static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend
3034
3045
  UNUSED(buft);
3035
3046
  }
3036
3047
 
3037
- GGML_CALL static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
3038
- return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
3039
-
3040
- UNUSED(buft);
3041
- }
3042
-
3043
3048
  GGML_CALL static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
3044
3049
  return true;
3045
3050
 
@@ -3054,7 +3059,6 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
3054
3059
  /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
3055
3060
  /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size,
3056
3061
  /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
3057
- /* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
3058
3062
  /* .is_host = */ ggml_backend_metal_buffer_type_is_host,
3059
3063
  },
3060
3064
  /* .context = */ NULL,
@@ -3169,6 +3173,12 @@ GGML_CALL static bool ggml_backend_metal_supports_op(ggml_backend_t backend, con
3169
3173
  return ggml_metal_supports_op(metal_ctx, op);
3170
3174
  }
3171
3175
 
3176
+ GGML_CALL static bool ggml_backend_metal_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
3177
+ return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name;
3178
+
3179
+ UNUSED(backend);
3180
+ }
3181
+
3172
3182
  static struct ggml_backend_i ggml_backend_metal_i = {
3173
3183
  /* .get_name = */ ggml_backend_metal_name,
3174
3184
  /* .free = */ ggml_backend_metal_free,
@@ -3179,9 +3189,11 @@ static struct ggml_backend_i ggml_backend_metal_i = {
3179
3189
  /* .synchronize = */ NULL,
3180
3190
  /* .graph_plan_create = */ NULL,
3181
3191
  /* .graph_plan_free = */ NULL,
3192
+ /* .graph_plan_update = */ NULL,
3182
3193
  /* .graph_plan_compute = */ NULL,
3183
3194
  /* .graph_compute = */ ggml_backend_metal_graph_compute,
3184
3195
  /* .supports_op = */ ggml_backend_metal_supports_op,
3196
+ /* .supports_buft = */ ggml_backend_metal_supports_buft,
3185
3197
  /* .offload_op = */ NULL,
3186
3198
  /* .event_new = */ NULL,
3187
3199
  /* .event_free = */ NULL,
@@ -1654,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
1654
1654
  // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
1655
1655
  static void rope_yarn(
1656
1656
  float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
1657
- thread float * cos_theta, thread float * sin_theta
1658
- ) {
1657
+ thread float * cos_theta, thread float * sin_theta) {
1659
1658
  // Get n-d rotational scaling corrected for extrapolation
1660
1659
  float theta_interp = freq_scale * theta_extrap;
1661
1660
  float theta = theta_interp;
@@ -1672,19 +1671,20 @@ static void rope_yarn(
1672
1671
 
1673
1672
  // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
1674
1673
  // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
1675
- static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
1676
- return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
1674
+ static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
1675
+ return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
1677
1676
  }
1678
1677
 
1679
1678
  static void rope_yarn_corr_dims(
1680
- int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
1679
+ int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
1681
1680
  ) {
1682
1681
  // start and end correction dims
1683
- dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
1684
- dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
1682
+ dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
1683
+ dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
1685
1684
  }
1686
1685
 
1687
- typedef void (rope_t)(
1686
+ template<typename T>
1687
+ kernel void kernel_rope_norm(
1688
1688
  device const void * src0,
1689
1689
  device const int32_t * src1,
1690
1690
  device const float * src2,
@@ -1707,8 +1707,7 @@ typedef void (rope_t)(
1707
1707
  constant uint64_t & nb3,
1708
1708
  constant int & n_past,
1709
1709
  constant int & n_dims,
1710
- constant int & mode,
1711
- constant int & n_orig_ctx,
1710
+ constant int & n_ctx_orig,
1712
1711
  constant float & freq_base,
1713
1712
  constant float & freq_scale,
1714
1713
  constant float & ext_factor,
@@ -1717,10 +1716,52 @@ typedef void (rope_t)(
1717
1716
  constant float & beta_slow,
1718
1717
  uint tiitg[[thread_index_in_threadgroup]],
1719
1718
  uint3 tptg[[threads_per_threadgroup]],
1720
- uint3 tgpig[[threadgroup_position_in_grid]]);
1719
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
1720
+ const int64_t i3 = tgpig[2];
1721
+ const int64_t i2 = tgpig[1];
1722
+ const int64_t i1 = tgpig[0];
1723
+
1724
+ float corr_dims[2];
1725
+ rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
1726
+
1727
+ device const int32_t * pos = src1;
1728
+
1729
+ const float theta_base = (float) pos[i2];
1730
+ const float inv_ndims = -1.f/n_dims;
1731
+
1732
+ float cos_theta;
1733
+ float sin_theta;
1734
+
1735
+ for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1736
+ if (i0 < n_dims) {
1737
+ const int64_t ic = i0/2;
1738
+
1739
+ const float theta = theta_base * pow(freq_base, inv_ndims*i0);
1740
+
1741
+ const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
1742
+
1743
+ rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1744
+
1745
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1746
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1747
+
1748
+ const float x0 = src[0];
1749
+ const float x1 = src[1];
1750
+
1751
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
1752
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
1753
+ } else {
1754
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1755
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1756
+
1757
+ dst_data[0] = src[0];
1758
+ dst_data[1] = src[1];
1759
+ }
1760
+ }
1761
+ }
1721
1762
 
1722
1763
  template<typename T>
1723
- kernel void kernel_rope(
1764
+ kernel void kernel_rope_neox(
1724
1765
  device const void * src0,
1725
1766
  device const int32_t * src1,
1726
1767
  device const float * src2,
@@ -1743,8 +1784,7 @@ kernel void kernel_rope(
1743
1784
  constant uint64_t & nb3,
1744
1785
  constant int & n_past,
1745
1786
  constant int & n_dims,
1746
- constant int & mode,
1747
- constant int & n_orig_ctx,
1787
+ constant int & n_ctx_orig,
1748
1788
  constant float & freq_base,
1749
1789
  constant float & freq_scale,
1750
1790
  constant float & ext_factor,
@@ -1758,69 +1798,53 @@ kernel void kernel_rope(
1758
1798
  const int64_t i2 = tgpig[1];
1759
1799
  const int64_t i1 = tgpig[0];
1760
1800
 
1761
- const bool is_neox = mode & 2;
1762
-
1763
1801
  float corr_dims[2];
1764
- rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
1802
+ rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
1765
1803
 
1766
1804
  device const int32_t * pos = src1;
1767
1805
 
1768
- const int64_t p = pos[i2];
1769
-
1770
- const float theta_base = (float)p;
1806
+ const float theta_base = (float) pos[i2];
1771
1807
  const float inv_ndims = -1.f/n_dims;
1772
1808
 
1773
- if (!is_neox) {
1774
- for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1775
- const float theta = theta_base * pow(freq_base, inv_ndims*i0);
1776
-
1777
- float cos_theta, sin_theta;
1778
- rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1779
-
1780
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1781
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1782
-
1783
- const T x0 = src[0];
1784
- const T x1 = src[1];
1809
+ float cos_theta;
1810
+ float sin_theta;
1785
1811
 
1786
- dst_data[0] = x0*cos_theta - x1*sin_theta;
1787
- dst_data[1] = x0*sin_theta + x1*cos_theta;
1788
- }
1789
- } else {
1790
- for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
1791
- if (ic < n_dims) {
1792
- const int64_t i0 = ic/2;
1812
+ for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1813
+ if (i0 < n_dims) {
1814
+ const int64_t ic = i0/2;
1793
1815
 
1794
- const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
1795
-
1796
- const float theta = theta_base * pow(freq_base, inv_ndims*ic);
1816
+ const float theta = theta_base * pow(freq_base, inv_ndims*i0);
1797
1817
 
1798
- float cos_theta, sin_theta;
1799
- rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
1818
+ const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
1800
1819
 
1801
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1802
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1820
+ rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1803
1821
 
1804
- const float x0 = src[0];
1805
- const float x1 = src[n_dims/2];
1822
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
1823
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
1806
1824
 
1807
- dst_data[0] = x0*cos_theta - x1*sin_theta;
1808
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
1809
- } else {
1810
- const int64_t i0 = ic;
1825
+ const float x0 = src[0];
1826
+ const float x1 = src[n_dims/2];
1811
1827
 
1812
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1813
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1828
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
1829
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
1830
+ } else {
1831
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1832
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1814
1833
 
1815
- dst_data[0] = src[0];
1816
- dst_data[1] = src[1];
1817
- }
1834
+ dst_data[0] = src[0];
1835
+ dst_data[1] = src[1];
1818
1836
  }
1819
1837
  }
1820
1838
  }
1821
1839
 
1822
- template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
1823
- template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
1840
+ typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
1841
+ typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
1842
+
1843
+ template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
1844
+ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
1845
+
1846
+ template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
1847
+ template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
1824
1848
 
1825
1849
  typedef void (im2col_t)(
1826
1850
  device const float * x,
@@ -491,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
491
491
  if (remote_ptr != 0) {
492
492
  ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
493
493
  ggml_backend_rpc_buffer_interface,
494
- new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
494
+ new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
495
495
  remote_size);
496
496
  return buffer;
497
497
  } else {
@@ -540,22 +540,12 @@ GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend
540
540
  return ggml_nbytes(tensor);
541
541
  }
542
542
 
543
- GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
544
- if (!ggml_backend_is_rpc(backend)) {
545
- return false;
546
- }
547
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
548
- ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
549
- return buft_ctx->endpoint == rpc_ctx->endpoint;
550
- }
551
-
552
543
  static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
553
544
  /* .get_name = */ ggml_backend_rpc_buffer_type_name,
554
545
  /* .alloc_buffer = */ ggml_backend_rpc_buffer_type_alloc_buffer,
555
546
  /* .get_alignment = */ ggml_backend_rpc_buffer_type_get_alignment,
556
547
  /* .get_max_size = */ ggml_backend_rpc_get_max_size,
557
548
  /* .get_alloc_size = */ ggml_backend_rpc_buffer_type_get_alloc_size,
558
- /* .supports_backend = */ ggml_backend_rpc_buffer_type_supports_backend,
559
549
  /* .is_host = */ NULL,
560
550
  };
561
551
 
@@ -634,8 +624,17 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
634
624
  GGML_CALL static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
635
625
  UNUSED(backend);
636
626
  UNUSED(op);
637
- GGML_ASSERT(false && "not implemented");
638
- return false;
627
+ //TODO: call the remote backend and cache the results
628
+ return true;
629
+ }
630
+
631
+ GGML_CALL static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
632
+ if (buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
633
+ return false;
634
+ }
635
+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
636
+ ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
637
+ return buft_ctx->endpoint == rpc_ctx->endpoint;
639
638
  }
640
639
 
641
640
  static ggml_backend_i ggml_backend_rpc_interface = {
@@ -648,9 +647,11 @@ static ggml_backend_i ggml_backend_rpc_interface = {
648
647
  /* .synchronize = */ ggml_backend_rpc_synchronize,
649
648
  /* .graph_plan_create = */ NULL,
650
649
  /* .graph_plan_free = */ NULL,
650
+ /* .graph_plan_update = */ NULL,
651
651
  /* .graph_plan_compute = */ NULL,
652
652
  /* .graph_compute = */ ggml_backend_rpc_graph_compute,
653
653
  /* .supports_op = */ ggml_backend_rpc_supports_op,
654
+ /* .supports_buft = */ ggml_backend_rpc_supports_buft,
654
655
  /* .offload_op = */ NULL,
655
656
  /* .event_new = */ NULL,
656
657
  /* .event_free = */ NULL,
@@ -692,7 +693,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
692
693
  GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
693
694
  ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
694
695
  /* .endpoint = */ endpoint,
695
- /* .name = */ "RPC",
696
+ /* .name = */ "RPC[" + std::string(endpoint) + "]",
696
697
  };
697
698
 
698
699
  ggml_backend_t backend = new ggml_backend {