llama_cpp 0.16.1 → 0.17.0

Sign up to get free protection for your applications and access to all features.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +25 -0
  3. data/README.md +7 -12
  4. data/ext/llama_cpp/extconf.rb +2 -42
  5. data/ext/llama_cpp/llama_cpp.cpp +20 -0
  6. data/lib/llama_cpp/version.rb +3 -3
  7. data/sig/llama_cpp.rbs +5 -0
  8. metadata +2 -171
  9. data/vendor/include/.gitkeep +0 -0
  10. data/vendor/lib/.gitkeep +0 -0
  11. data/vendor/tmp/llama.cpp/LICENSE +0 -21
  12. data/vendor/tmp/llama.cpp/Makefile +0 -1116
  13. data/vendor/tmp/llama.cpp/ggml-alloc.c +0 -1041
  14. data/vendor/tmp/llama.cpp/ggml-alloc.h +0 -76
  15. data/vendor/tmp/llama.cpp/ggml-backend-impl.h +0 -153
  16. data/vendor/tmp/llama.cpp/ggml-backend.c +0 -2214
  17. data/vendor/tmp/llama.cpp/ggml-backend.h +0 -233
  18. data/vendor/tmp/llama.cpp/ggml-blas.cpp +0 -363
  19. data/vendor/tmp/llama.cpp/ggml-blas.h +0 -23
  20. data/vendor/tmp/llama.cpp/ggml-common.h +0 -1805
  21. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +0 -47
  22. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +0 -34
  23. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +0 -104
  24. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +0 -280
  25. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +0 -34
  26. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +0 -196
  27. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +0 -686
  28. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +0 -490
  29. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +0 -40
  30. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +0 -674
  31. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +0 -319
  32. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +0 -312
  33. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +0 -345
  34. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +0 -178
  35. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +0 -104
  36. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +0 -88
  37. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +0 -419
  38. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +0 -221
  39. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +0 -49
  40. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +0 -94
  41. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +0 -112
  42. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +0 -271
  43. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +0 -31
  44. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +0 -206
  45. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +0 -40
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  127. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  128. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  129. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  130. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  131. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  132. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +0 -10
  133. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +0 -9
  134. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +0 -10
  135. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +0 -10
  136. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +0 -8
  137. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q2_k.cu +0 -5
  138. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q3_k.cu +0 -5
  139. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_0.cu +0 -5
  140. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_1.cu +0 -5
  141. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_k.cu +0 -5
  142. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_0.cu +0 -5
  143. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_1.cu +0 -5
  144. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_k.cu +0 -5
  145. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q6_k.cu +0 -5
  146. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q8_0.cu +0 -5
  147. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +0 -47
  148. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +0 -286
  149. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +0 -51
  150. data/vendor/tmp/llama.cpp/ggml-cuda.cu +0 -3069
  151. data/vendor/tmp/llama.cpp/ggml-cuda.h +0 -44
  152. data/vendor/tmp/llama.cpp/ggml-impl.h +0 -651
  153. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +0 -2038
  154. data/vendor/tmp/llama.cpp/ggml-kompute.h +0 -46
  155. data/vendor/tmp/llama.cpp/ggml-metal.h +0 -66
  156. data/vendor/tmp/llama.cpp/ggml-metal.m +0 -3267
  157. data/vendor/tmp/llama.cpp/ggml-metal.metal +0 -6540
  158. data/vendor/tmp/llama.cpp/ggml-quants.c +0 -14380
  159. data/vendor/tmp/llama.cpp/ggml-quants.h +0 -133
  160. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +0 -1173
  161. data/vendor/tmp/llama.cpp/ggml-rpc.h +0 -24
  162. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +0 -17429
  163. data/vendor/tmp/llama.cpp/ggml-sycl.h +0 -49
  164. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +0 -140820
  165. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +0 -7271
  166. data/vendor/tmp/llama.cpp/ggml-vulkan.h +0 -29
  167. data/vendor/tmp/llama.cpp/ggml.c +0 -22589
  168. data/vendor/tmp/llama.cpp/ggml.h +0 -2452
  169. data/vendor/tmp/llama.cpp/llama.cpp +0 -18692
  170. data/vendor/tmp/llama.cpp/llama.h +0 -1143
  171. data/vendor/tmp/llama.cpp/scripts/get-flags.mk +0 -38
  172. data/vendor/tmp/llama.cpp/sgemm.cpp +0 -1030
  173. data/vendor/tmp/llama.cpp/sgemm.h +0 -14
  174. data/vendor/tmp/llama.cpp/unicode-data.cpp +0 -6983
  175. data/vendor/tmp/llama.cpp/unicode-data.h +0 -20
  176. data/vendor/tmp/llama.cpp/unicode.cpp +0 -796
  177. data/vendor/tmp/llama.cpp/unicode.h +0 -63
@@ -1,3267 +0,0 @@
1
- #import "ggml-metal.h"
2
-
3
- #import "ggml-backend-impl.h"
4
- #import "ggml.h"
5
-
6
- #import <Foundation/Foundation.h>
7
-
8
- #import <Metal/Metal.h>
9
-
10
- #undef MIN
11
- #undef MAX
12
- #define MIN(a, b) ((a) < (b) ? (a) : (b))
13
- #define MAX(a, b) ((a) > (b) ? (a) : (b))
14
-
15
- #ifdef GGML_METAL_NDEBUG
16
- #define GGML_METAL_LOG_INFO(...)
17
- #define GGML_METAL_LOG_WARN(...)
18
- #define GGML_METAL_LOG_ERROR(...)
19
- #else
20
- #define GGML_METAL_LOG_INFO(...) ggml_metal_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
21
- #define GGML_METAL_LOG_WARN(...) ggml_metal_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
22
- #define GGML_METAL_LOG_ERROR(...) ggml_metal_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
23
- #endif
24
-
25
- #define UNUSED(x) (void)(x)
26
-
27
- struct ggml_metal_kernel {
28
- id<MTLComputePipelineState> pipeline;
29
- };
30
-
31
- enum ggml_metal_kernel_type {
32
- GGML_METAL_KERNEL_TYPE_ADD,
33
- GGML_METAL_KERNEL_TYPE_ADD_ROW,
34
- GGML_METAL_KERNEL_TYPE_MUL,
35
- GGML_METAL_KERNEL_TYPE_MUL_ROW,
36
- GGML_METAL_KERNEL_TYPE_DIV,
37
- GGML_METAL_KERNEL_TYPE_DIV_ROW,
38
- GGML_METAL_KERNEL_TYPE_REPEAT_F32,
39
- GGML_METAL_KERNEL_TYPE_REPEAT_F16,
40
- GGML_METAL_KERNEL_TYPE_REPEAT_I32,
41
- GGML_METAL_KERNEL_TYPE_REPEAT_I16,
42
- GGML_METAL_KERNEL_TYPE_SCALE,
43
- GGML_METAL_KERNEL_TYPE_SCALE_4,
44
- GGML_METAL_KERNEL_TYPE_CLAMP,
45
- GGML_METAL_KERNEL_TYPE_TANH,
46
- GGML_METAL_KERNEL_TYPE_RELU,
47
- GGML_METAL_KERNEL_TYPE_SIGMOID,
48
- GGML_METAL_KERNEL_TYPE_GELU,
49
- GGML_METAL_KERNEL_TYPE_GELU_4,
50
- GGML_METAL_KERNEL_TYPE_GELU_QUICK,
51
- GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
52
- GGML_METAL_KERNEL_TYPE_SILU,
53
- GGML_METAL_KERNEL_TYPE_SILU_4,
54
- GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
55
- GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
56
- GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
57
- GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
58
- GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
59
- GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
60
- GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
61
- GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
62
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
63
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
64
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
65
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
66
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
67
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
68
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
69
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
70
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,
71
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
72
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
73
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
74
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
75
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
76
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
77
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
78
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
79
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
80
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
81
- GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
82
- GGML_METAL_KERNEL_TYPE_RMS_NORM,
83
- GGML_METAL_KERNEL_TYPE_GROUP_NORM,
84
- GGML_METAL_KERNEL_TYPE_NORM,
85
- GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
86
- GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
87
- GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
88
- GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
89
- GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
90
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
91
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
92
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
93
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
94
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
95
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
96
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
97
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
98
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,
99
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
100
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
101
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
102
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
103
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
104
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
105
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
106
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,
107
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
108
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
109
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
110
- //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
111
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
112
- //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
113
- //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
114
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
115
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
116
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
117
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
118
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
119
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
120
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
121
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
122
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,
123
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
124
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
125
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
126
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
127
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
128
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
129
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
130
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
131
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
132
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
133
- GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
134
- GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
135
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
136
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
137
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
138
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
139
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
140
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
141
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
142
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
143
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,
144
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
145
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
146
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
147
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
148
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
149
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
150
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
151
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
152
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
153
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
154
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
155
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
156
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
157
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
158
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
159
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
160
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
161
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
162
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
163
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
164
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
165
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
166
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
167
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
168
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
169
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
170
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
171
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
172
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
173
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
174
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
175
- GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
176
- GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
177
- GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
178
- GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
179
- GGML_METAL_KERNEL_TYPE_IM2COL_F16,
180
- GGML_METAL_KERNEL_TYPE_IM2COL_F32,
181
- GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
182
- GGML_METAL_KERNEL_TYPE_PAD_F32,
183
- GGML_METAL_KERNEL_TYPE_ARANGE_F32,
184
- GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
185
- GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
186
- GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
187
- GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
188
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
189
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
190
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
191
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
192
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
193
- //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
194
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
195
- //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
196
- GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
197
- GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
198
- GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
199
- GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
200
- GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
201
- GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
202
- GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
203
- GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
204
- GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
205
- GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
206
- GGML_METAL_KERNEL_TYPE_CONCAT,
207
- GGML_METAL_KERNEL_TYPE_SQR,
208
- GGML_METAL_KERNEL_TYPE_SUM_ROWS,
209
-
210
- GGML_METAL_KERNEL_TYPE_COUNT
211
- };
212
-
213
- struct ggml_metal_context {
214
- int n_cb;
215
-
216
- id<MTLDevice> device;
217
- id<MTLCommandQueue> queue;
218
-
219
- dispatch_queue_t d_queue;
220
-
221
- struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
222
-
223
- bool support_simdgroup_reduction;
224
- bool support_simdgroup_mm;
225
-
226
- bool should_capture_next_compute;
227
- };
228
-
229
- // MSL code
230
- // TODO: move the contents here when ready
231
- // for now it is easier to work in a separate file
232
- // static NSString * const msl_library_source = @"see metal.metal";
233
-
234
- // Here to assist with NSBundle Path Hack
235
- @interface GGMLMetalClass : NSObject
236
- @end
237
- @implementation GGMLMetalClass
238
- @end
239
-
240
- static void ggml_metal_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
241
- fprintf(stderr, "%s", msg);
242
-
243
- UNUSED(level);
244
- UNUSED(user_data);
245
- }
246
-
247
- ggml_log_callback ggml_metal_log_callback = ggml_metal_default_log_callback;
248
- void * ggml_metal_log_user_data = NULL;
249
-
250
- GGML_ATTRIBUTE_FORMAT(2, 3)
251
- static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
252
- if (ggml_metal_log_callback != NULL) {
253
- va_list args;
254
- va_start(args, format);
255
- char buffer[128];
256
- int len = vsnprintf(buffer, 128, format, args);
257
- if (len < 128) {
258
- ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
259
- } else {
260
- char* buffer2 = malloc(len+1);
261
- va_end(args);
262
- va_start(args, format);
263
- vsnprintf(buffer2, len+1, format, args);
264
- buffer2[len] = 0;
265
- ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
266
- free(buffer2);
267
- }
268
- va_end(args);
269
- }
270
- }
271
-
272
- static void * ggml_metal_host_malloc(size_t n) {
273
- void * data = NULL;
274
-
275
- #if TARGET_OS_OSX
276
- kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE);
277
- if (err != KERN_SUCCESS) {
278
- GGML_METAL_LOG_ERROR("%s: error: vm_allocate failed\n", __func__);
279
- return NULL;
280
- }
281
- #else
282
- const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
283
- if (result != 0) {
284
- GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
285
- return NULL;
286
- }
287
- #endif
288
-
289
- return data;
290
- }
291
-
292
- static struct ggml_metal_context * ggml_metal_init(int n_cb) {
293
- GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
294
-
295
- #if TARGET_OS_OSX && !GGML_METAL_NDEBUG
296
- // Show all the Metal device instances in the system
297
- NSArray * devices = MTLCopyAllDevices();
298
- for (id<MTLDevice> device in devices) {
299
- GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
300
- }
301
- [devices release]; // since it was created by a *Copy* C method
302
- #endif
303
-
304
- // Pick and show default Metal device
305
- id<MTLDevice> device = MTLCreateSystemDefaultDevice();
306
- GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
307
-
308
- // Configure context
309
- struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
310
- ctx->device = device;
311
- ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
312
- ctx->queue = [ctx->device newCommandQueue];
313
- ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
314
-
315
- id<MTLLibrary> metal_library;
316
-
317
- // load library
318
- //
319
- // - first check if the library is embedded
320
- // - then check if the library is in the bundle
321
- // - if not found, load the source and compile it
322
- // - if that fails, return NULL
323
- {
324
- NSBundle * bundle = nil;
325
- #ifdef SWIFT_PACKAGE
326
- bundle = SWIFTPM_MODULE_BUNDLE;
327
- #else
328
- bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
329
- #endif
330
-
331
- NSError * error = nil;
332
-
333
- #if GGML_METAL_EMBED_LIBRARY
334
- const bool try_metallib = false;
335
- #else
336
- const bool try_metallib = true;
337
- #endif
338
-
339
- NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
340
- if (try_metallib && path_lib != nil) {
341
- // pre-compiled library found
342
- NSURL * libURL = [NSURL fileURLWithPath:path_lib];
343
- GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
344
-
345
- metal_library = [ctx->device newLibraryWithURL:libURL error:&error];
346
- if (error) {
347
- GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
348
- return NULL;
349
- }
350
- } else {
351
- #if GGML_METAL_EMBED_LIBRARY
352
- GGML_METAL_LOG_INFO("%s: using embedded metal library\n", __func__);
353
-
354
- extern const char ggml_metallib_start[];
355
- extern const char ggml_metallib_end[];
356
-
357
- NSString * src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
358
- #else
359
- GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
360
-
361
- NSString * path_source;
362
- NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
363
-
364
- GGML_METAL_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
365
-
366
- if (path_resource) {
367
- path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
368
- } else {
369
- path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
370
- }
371
-
372
- if (path_source == nil) {
373
- GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
374
- path_source = @"ggml-metal.metal";
375
- }
376
-
377
- GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
378
-
379
- NSString * src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
380
- if (error) {
381
- GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
382
- return NULL;
383
- }
384
- #endif // GGML_METAL_EMBED_LIBRARY
385
-
386
- @autoreleasepool {
387
- // dictionary of preprocessor macros
388
- NSMutableDictionary * prep = [NSMutableDictionary dictionary];
389
-
390
- MTLCompileOptions* options = [MTLCompileOptions new];
391
- options.preprocessorMacros = prep;
392
-
393
- //[options setFastMathEnabled:false];
394
-
395
- metal_library = [ctx->device newLibraryWithSource:src options:options error:&error];
396
- if (error) {
397
- GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
398
- return NULL;
399
- }
400
- }
401
- }
402
- }
403
-
404
- // print MTL GPU family:
405
- GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
406
-
407
- const NSInteger MTLGPUFamilyMetal3 = 5001;
408
-
409
- // determine max supported GPU family
410
- // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
411
- // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
412
- {
413
- for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
414
- if ([ctx->device supportsFamily:i]) {
415
- GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
416
- break;
417
- }
418
- }
419
-
420
- for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
421
- if ([ctx->device supportsFamily:i]) {
422
- GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
423
- break;
424
- }
425
- }
426
-
427
- for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) {
428
- if ([ctx->device supportsFamily:i]) {
429
- GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i);
430
- break;
431
- }
432
- }
433
- }
434
-
435
- ctx->support_simdgroup_reduction = [ctx->device supportsFamily:MTLGPUFamilyApple7];
436
- ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3];
437
-
438
- ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7];
439
-
440
- GGML_METAL_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false");
441
- GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
442
- GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
443
-
444
- ctx->should_capture_next_compute = false;
445
-
446
- #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
447
- if (@available(macOS 10.12, iOS 16.0, *)) {
448
- GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
449
- }
450
- #elif TARGET_OS_OSX
451
- if (ctx->device.maxTransferRate != 0) {
452
- GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
453
- } else {
454
- GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
455
- }
456
- #endif
457
-
458
- // load kernels
459
- {
460
- NSError * error = nil;
461
-
462
- for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
463
- ctx->kernels[i].pipeline = nil;
464
- }
465
-
466
- /*
467
- GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
468
- (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
469
- (int) kernel->pipeline.threadExecutionWidth); \
470
- */
471
- #define GGML_METAL_ADD_KERNEL(e, name, supported) \
472
- if (supported) { \
473
- struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
474
- id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
475
- kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
476
- [metal_function release]; \
477
- if (error) { \
478
- GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
479
- [metal_library release]; \
480
- return NULL; \
481
- } \
482
- } else { \
483
- GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
484
- }
485
-
486
- // simd_sum and simd_max requires MTLGPUFamilyApple7
487
-
488
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
489
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
490
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
491
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
492
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
493
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
494
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
495
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
496
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
497
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
498
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
499
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
500
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
501
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
502
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
503
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
504
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
505
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
506
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
507
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
508
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
509
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
510
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
511
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
512
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
513
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction);
514
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
515
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
516
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
517
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
518
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
519
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
520
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
521
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
522
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
523
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
524
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
525
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
526
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
527
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
528
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
529
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
530
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
531
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
532
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
533
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
534
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
535
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
536
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
537
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
538
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
539
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
540
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
541
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
542
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
543
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
544
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
545
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
546
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
547
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
548
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
549
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
550
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
551
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
552
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
553
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
554
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
555
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
556
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
557
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
558
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
559
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
560
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
561
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
562
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
563
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
564
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
565
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
566
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
567
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
568
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
569
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
570
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
571
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
572
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
573
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
574
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
575
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
576
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
577
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
578
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
579
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
580
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
581
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
582
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
583
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
584
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
585
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
586
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
587
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
588
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
589
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
590
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
591
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
592
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
593
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
594
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
595
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
596
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
597
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
598
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
599
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
600
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
601
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
602
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
603
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
604
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
605
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
606
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
607
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
608
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
609
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
610
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
611
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
612
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
613
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
614
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
615
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
616
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
617
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
618
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
619
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
620
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
621
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
622
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
623
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
624
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
625
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
626
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
627
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
628
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
629
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
630
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
631
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
632
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
633
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
634
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
635
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
636
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
637
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
638
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
639
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
640
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
641
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
642
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
643
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
644
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
645
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
646
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
647
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
648
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
649
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
650
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
651
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
652
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
653
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
654
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
655
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
656
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
657
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
658
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
659
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
660
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
661
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
662
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
663
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
664
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
665
- }
666
-
667
- [metal_library release];
668
- return ctx;
669
- }
670
-
671
- static void ggml_metal_free(struct ggml_metal_context * ctx) {
672
- GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
673
-
674
- for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
675
- [ctx->kernels[i].pipeline release];
676
- }
677
-
678
- [ctx->queue release];
679
- [ctx->device release];
680
-
681
- dispatch_release(ctx->d_queue);
682
-
683
- free(ctx);
684
- }
685
-
686
- // temporarily defined here for compatibility between ggml-backend and the old API
687
-
688
- struct ggml_backend_metal_buffer {
689
- void * data;
690
- size_t size;
691
-
692
- id<MTLBuffer> metal;
693
- };
694
-
695
- struct ggml_backend_metal_buffer_context {
696
- void * all_data;
697
- size_t all_size;
698
- bool owned;
699
-
700
- // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
701
- int n_buffers;
702
- struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
703
- };
704
-
705
- // finds the Metal buffer that contains the tensor data on the GPU device
706
- // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
707
- // Metal buffer based on the host memory pointer
708
- //
709
- static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) {
710
- //GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
711
-
712
- const int64_t tsize = ggml_nbytes(t);
713
-
714
- ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
715
-
716
- struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context;
717
-
718
- // find the view that contains the tensor fully
719
- for (int i = 0; i < buf_ctx->n_buffers; ++i) {
720
- const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;
721
-
722
- //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
723
- if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
724
- *offs = (size_t) ioffs;
725
-
726
- //GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);
727
-
728
- return buf_ctx->buffers[i].metal;
729
- }
730
- }
731
-
732
- GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);
733
-
734
- return nil;
735
- }
736
-
737
- static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
738
- switch (op->op) {
739
- case GGML_OP_UNARY:
740
- switch (ggml_get_unary_op(op)) {
741
- case GGML_UNARY_OP_TANH:
742
- case GGML_UNARY_OP_RELU:
743
- case GGML_UNARY_OP_SIGMOID:
744
- case GGML_UNARY_OP_GELU:
745
- case GGML_UNARY_OP_GELU_QUICK:
746
- case GGML_UNARY_OP_SILU:
747
- return ggml_is_contiguous(op->src[0]);
748
- default:
749
- return false;
750
- }
751
- case GGML_OP_NONE:
752
- case GGML_OP_RESHAPE:
753
- case GGML_OP_VIEW:
754
- case GGML_OP_TRANSPOSE:
755
- case GGML_OP_PERMUTE:
756
- case GGML_OP_CONCAT:
757
- case GGML_OP_ADD:
758
- case GGML_OP_ACC:
759
- case GGML_OP_MUL:
760
- case GGML_OP_DIV:
761
- case GGML_OP_REPEAT:
762
- case GGML_OP_SCALE:
763
- case GGML_OP_CLAMP:
764
- case GGML_OP_SQR:
765
- case GGML_OP_SUM_ROWS:
766
- return true;
767
- case GGML_OP_SOFT_MAX:
768
- case GGML_OP_RMS_NORM:
769
- case GGML_OP_GROUP_NORM:
770
- return ctx->support_simdgroup_reduction;
771
- case GGML_OP_NORM:
772
- case GGML_OP_ROPE:
773
- case GGML_OP_IM2COL:
774
- return true;
775
- case GGML_OP_POOL_1D:
776
- case GGML_OP_POOL_2D:
777
- return false;
778
- case GGML_OP_UPSCALE:
779
- case GGML_OP_PAD:
780
- case GGML_OP_ARANGE:
781
- case GGML_OP_TIMESTEP_EMBEDDING:
782
- case GGML_OP_ARGSORT:
783
- case GGML_OP_LEAKY_RELU:
784
- return true;
785
- case GGML_OP_FLASH_ATTN_EXT:
786
- if (op->src[1]->type != GGML_TYPE_F16) {
787
- return false;
788
- }
789
- if (op->src[2]->type != GGML_TYPE_F16) {
790
- return false;
791
- }
792
- if (op->src[0]->ne[0] == 256) {
793
- return false;
794
- }
795
- return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
796
- case GGML_OP_MUL_MAT:
797
- case GGML_OP_MUL_MAT_ID:
798
- return ctx->support_simdgroup_reduction &&
799
- (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
800
- case GGML_OP_CPY:
801
- case GGML_OP_DUP:
802
- case GGML_OP_CONT:
803
- {
804
- switch (op->src[0]->type) {
805
- case GGML_TYPE_F32:
806
- switch (op->type) {
807
- case GGML_TYPE_F16:
808
- case GGML_TYPE_F32:
809
- case GGML_TYPE_Q8_0:
810
- case GGML_TYPE_Q4_0:
811
- case GGML_TYPE_Q4_1:
812
- case GGML_TYPE_Q5_0:
813
- case GGML_TYPE_Q5_1:
814
- case GGML_TYPE_IQ4_NL:
815
- return true;
816
- default:
817
- return false;
818
- }
819
- case GGML_TYPE_F16:
820
- switch (op->type) {
821
- case GGML_TYPE_F16:
822
- case GGML_TYPE_F32:
823
- return true;
824
- default:
825
- return false;
826
- }
827
- default:
828
- return false;
829
- };
830
- }
831
- case GGML_OP_DIAG_MASK_INF:
832
- case GGML_OP_GET_ROWS:
833
- {
834
- return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1;
835
- }
836
- default:
837
- return false;
838
- }
839
- }
840
-
841
- static enum ggml_status ggml_metal_graph_compute(
842
- struct ggml_metal_context * ctx,
843
- struct ggml_cgraph * gf) {
844
-
845
- @autoreleasepool {
846
- MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
847
- edesc.dispatchType = MTLDispatchTypeSerial;
848
-
849
- // create multiple command buffers and enqueue them
850
- // then, we encode the graph into the command buffers in parallel
851
-
852
- const int n_nodes = gf->n_nodes;
853
- const int n_cb = ctx->n_cb;
854
- const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
855
-
856
- const bool should_capture = ctx->should_capture_next_compute;
857
- if (should_capture) {
858
- ctx->should_capture_next_compute = false;
859
-
860
- MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
861
- descriptor.captureObject = ctx->queue;
862
-
863
- NSError * error = nil;
864
- if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
865
- GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
866
- GGML_ASSERT(!"capture failed");
867
- }
868
- }
869
-
870
- id<MTLCommandBuffer> command_buffer_builder[n_cb];
871
- for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
872
- id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
873
- command_buffer_builder[cb_idx] = command_buffer;
874
-
875
- // enqueue the command buffers in order to specify their execution order
876
- [command_buffer enqueue];
877
- }
878
-
879
- const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
880
-
881
- dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
882
- const int cb_idx = iter;
883
-
884
- size_t offs_src0 = 0;
885
- size_t offs_src1 = 0;
886
- size_t offs_src2 = 0;
887
- size_t offs_dst = 0;
888
-
889
- id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
890
- id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
891
-
892
- const int node_start = (cb_idx + 0) * n_nodes_per_cb;
893
- const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
894
-
895
- for (int i = node_start; i < node_end; ++i) {
896
- if (i == -1) {
897
- [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
898
- continue;
899
- }
900
-
901
- //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
902
-
903
- struct ggml_tensor * src0 = gf->nodes[i]->src[0];
904
- struct ggml_tensor * src1 = gf->nodes[i]->src[1];
905
- struct ggml_tensor * src2 = gf->nodes[i]->src[2];
906
- struct ggml_tensor * dst = gf->nodes[i];
907
-
908
- if (ggml_is_empty(dst)) {
909
- continue;
910
- }
911
-
912
- switch (dst->op) {
913
- case GGML_OP_NONE:
914
- case GGML_OP_RESHAPE:
915
- case GGML_OP_VIEW:
916
- case GGML_OP_TRANSPOSE:
917
- case GGML_OP_PERMUTE:
918
- {
919
- // noop -> next node
920
- } continue;
921
- default:
922
- {
923
- } break;
924
- }
925
-
926
- if (!ggml_metal_supports_op(ctx, dst)) {
927
- GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
928
- GGML_ASSERT(!"unsupported op");
929
- }
930
-
931
- if (should_capture) {
932
- [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
933
- }
934
-
935
- const int64_t ne00 = src0 ? src0->ne[0] : 0;
936
- const int64_t ne01 = src0 ? src0->ne[1] : 0;
937
- const int64_t ne02 = src0 ? src0->ne[2] : 0;
938
- const int64_t ne03 = src0 ? src0->ne[3] : 0;
939
-
940
- const uint64_t nb00 = src0 ? src0->nb[0] : 0;
941
- const uint64_t nb01 = src0 ? src0->nb[1] : 0;
942
- const uint64_t nb02 = src0 ? src0->nb[2] : 0;
943
- const uint64_t nb03 = src0 ? src0->nb[3] : 0;
944
-
945
- const int64_t ne10 = src1 ? src1->ne[0] : 0;
946
- const int64_t ne11 = src1 ? src1->ne[1] : 0;
947
- const int64_t ne12 = src1 ? src1->ne[2] : 0;
948
- const int64_t ne13 = src1 ? src1->ne[3] : 0;
949
-
950
- const uint64_t nb10 = src1 ? src1->nb[0] : 0;
951
- const uint64_t nb11 = src1 ? src1->nb[1] : 0;
952
- const uint64_t nb12 = src1 ? src1->nb[2] : 0;
953
- const uint64_t nb13 = src1 ? src1->nb[3] : 0;
954
-
955
- const int64_t ne20 = src2 ? src2->ne[0] : 0;
956
- const int64_t ne21 = src2 ? src2->ne[1] : 0;
957
- const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
958
- const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
959
-
960
- const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
961
- const uint64_t nb21 = src2 ? src2->nb[1] : 0;
962
- const uint64_t nb22 = src2 ? src2->nb[2] : 0;
963
- const uint64_t nb23 = src2 ? src2->nb[3] : 0;
964
-
965
- const int64_t ne0 = dst ? dst->ne[0] : 0;
966
- const int64_t ne1 = dst ? dst->ne[1] : 0;
967
- const int64_t ne2 = dst ? dst->ne[2] : 0;
968
- const int64_t ne3 = dst ? dst->ne[3] : 0;
969
-
970
- const uint64_t nb0 = dst ? dst->nb[0] : 0;
971
- const uint64_t nb1 = dst ? dst->nb[1] : 0;
972
- const uint64_t nb2 = dst ? dst->nb[2] : 0;
973
- const uint64_t nb3 = dst ? dst->nb[3] : 0;
974
-
975
- const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
976
- const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
977
- const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
978
-
979
- id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
980
- id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
981
- id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
982
- id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
983
-
984
- //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
985
- //if (src0) {
986
- // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
987
- // ggml_is_contiguous(src0), src0->name);
988
- //}
989
- //if (src1) {
990
- // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
991
- // ggml_is_contiguous(src1), src1->name);
992
- //}
993
- //if (dst) {
994
- // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
995
- // dst->name);
996
- //}
997
-
998
- switch (dst->op) {
999
- case GGML_OP_CONCAT:
1000
- {
1001
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
1002
-
1003
- const int32_t dim = ((int32_t *) dst->op_params)[0];
1004
-
1005
- [encoder setComputePipelineState:pipeline];
1006
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1007
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1008
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1009
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1010
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1011
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1012
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1013
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1014
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
1015
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
1016
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
1017
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1018
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1019
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1020
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1021
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1022
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1023
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1024
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1025
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1026
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1027
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1028
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1029
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1030
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1031
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1032
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1033
- [encoder setBytes:&dim length:sizeof(dim) atIndex:27];
1034
-
1035
- const int nth = MIN(1024, ne0);
1036
-
1037
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1038
- } break;
1039
- case GGML_OP_ADD:
1040
- case GGML_OP_MUL:
1041
- case GGML_OP_DIV:
1042
- {
1043
- GGML_ASSERT(src0t == GGML_TYPE_F32);
1044
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1045
-
1046
- const size_t offs = 0;
1047
-
1048
- bool bcast_row = false;
1049
-
1050
- int64_t nb = ne00; // used by the "row" kernels
1051
-
1052
- id<MTLComputePipelineState> pipeline = nil;
1053
-
1054
- if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
1055
- GGML_ASSERT(ggml_is_contiguous(src0));
1056
-
1057
- // src1 is a row
1058
- GGML_ASSERT(ne11 == 1);
1059
-
1060
- nb = ne00 / 4;
1061
- switch (dst->op) {
1062
- case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
1063
- case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
1064
- case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
1065
- default: GGML_ASSERT(false);
1066
- }
1067
-
1068
- bcast_row = true;
1069
- } else {
1070
- switch (dst->op) {
1071
- case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
1072
- case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
1073
- case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
1074
- default: GGML_ASSERT(false);
1075
- }
1076
- }
1077
-
1078
- [encoder setComputePipelineState:pipeline];
1079
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1080
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1081
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1082
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1083
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1084
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1085
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1086
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1087
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
1088
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
1089
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
1090
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1091
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1092
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1093
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1094
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1095
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1096
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1097
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1098
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1099
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1100
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1101
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1102
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1103
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1104
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1105
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1106
- [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1107
- [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
1108
-
1109
- if (bcast_row) {
1110
- const int64_t n = ggml_nelements(dst)/4;
1111
-
1112
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1113
- } else {
1114
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1115
-
1116
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1117
- }
1118
- } break;
1119
- case GGML_OP_REPEAT:
1120
- {
1121
- id<MTLComputePipelineState> pipeline;
1122
-
1123
- switch (src0t) {
1124
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
1125
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
1126
- case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
1127
- case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
1128
- default: GGML_ASSERT(false);
1129
- }
1130
-
1131
- [encoder setComputePipelineState:pipeline];
1132
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1133
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1134
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1135
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1136
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1137
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1138
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1139
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1140
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1141
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1142
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
1143
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
1144
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
1145
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
1146
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
1147
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1148
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
1149
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
1150
-
1151
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1152
-
1153
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1154
- } break;
1155
- case GGML_OP_ACC:
1156
- {
1157
- GGML_ASSERT(src0t == GGML_TYPE_F32);
1158
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1159
- GGML_ASSERT(dstt == GGML_TYPE_F32);
1160
-
1161
- GGML_ASSERT(ggml_is_contiguous(src0));
1162
- GGML_ASSERT(ggml_is_contiguous(src1));
1163
-
1164
- const size_t pnb1 = ((int32_t *) dst->op_params)[0];
1165
- const size_t pnb2 = ((int32_t *) dst->op_params)[1];
1166
- const size_t pnb3 = ((int32_t *) dst->op_params)[2];
1167
- const size_t offs = ((int32_t *) dst->op_params)[3];
1168
-
1169
- const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
1170
-
1171
- if (!inplace) {
1172
- // run a separete kernel to cpy src->dst
1173
- // not sure how to avoid this
1174
- // TODO: make a simpler cpy_bytes kernel
1175
-
1176
- const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
1177
-
1178
- [encoder setComputePipelineState:pipeline];
1179
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1180
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1181
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1182
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1183
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1184
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1185
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1186
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1187
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1188
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1189
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1190
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1191
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1192
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1193
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1194
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1195
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1196
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1197
-
1198
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
1199
-
1200
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1201
- }
1202
-
1203
- const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
1204
-
1205
- [encoder setComputePipelineState:pipeline];
1206
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1207
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1208
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1209
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1210
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1211
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1212
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1213
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1214
- [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
1215
- [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
1216
- [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
1217
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1218
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1219
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1220
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1221
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1222
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1223
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1224
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1225
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1226
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1227
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1228
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1229
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1230
- [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
1231
- [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
1232
- [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
1233
- [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1234
-
1235
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
1236
-
1237
- [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1238
- } break;
1239
- case GGML_OP_SCALE:
1240
- {
1241
- GGML_ASSERT(ggml_is_contiguous(src0));
1242
-
1243
- float scale;
1244
- memcpy(&scale, dst->op_params, sizeof(scale));
1245
-
1246
- int64_t n = ggml_nelements(dst);
1247
-
1248
- id<MTLComputePipelineState> pipeline = nil;
1249
-
1250
- if (n % 4 == 0) {
1251
- n /= 4;
1252
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
1253
- } else {
1254
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
1255
- }
1256
-
1257
- [encoder setComputePipelineState:pipeline];
1258
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1259
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1260
- [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
1261
-
1262
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1263
- } break;
1264
- case GGML_OP_CLAMP:
1265
- {
1266
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
1267
-
1268
- float min;
1269
- float max;
1270
- memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
1271
- memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
1272
-
1273
- [encoder setComputePipelineState:pipeline];
1274
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1275
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1276
- [encoder setBytes:&min length:sizeof(min) atIndex:2];
1277
- [encoder setBytes:&max length:sizeof(max) atIndex:3];
1278
-
1279
- const int64_t n = ggml_nelements(dst);
1280
-
1281
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1282
- } break;
1283
- case GGML_OP_UNARY:
1284
- switch (ggml_get_unary_op(gf->nodes[i])) {
1285
- // we are not taking into account the strides, so for now require contiguous tensors
1286
- GGML_ASSERT(ggml_is_contiguous(src0));
1287
-
1288
- case GGML_UNARY_OP_TANH:
1289
- {
1290
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
1291
-
1292
- [encoder setComputePipelineState:pipeline];
1293
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1294
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1295
-
1296
- const int64_t n = ggml_nelements(dst);
1297
-
1298
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1299
- } break;
1300
- case GGML_UNARY_OP_RELU:
1301
- {
1302
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
1303
-
1304
- [encoder setComputePipelineState:pipeline];
1305
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1306
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1307
-
1308
- const int64_t n = ggml_nelements(dst);
1309
-
1310
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1311
- } break;
1312
- case GGML_UNARY_OP_SIGMOID:
1313
- {
1314
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
1315
-
1316
- [encoder setComputePipelineState:pipeline];
1317
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1318
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1319
-
1320
- const int64_t n = ggml_nelements(dst);
1321
-
1322
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1323
- } break;
1324
- case GGML_UNARY_OP_GELU:
1325
- {
1326
- int64_t n = ggml_nelements(dst);
1327
-
1328
- id<MTLComputePipelineState> pipeline = nil;
1329
-
1330
- if (n % 4 == 0) {
1331
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
1332
- n /= 4;
1333
- } else {
1334
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
1335
- }
1336
-
1337
- [encoder setComputePipelineState:pipeline];
1338
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1339
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1340
-
1341
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1342
- } break;
1343
- case GGML_UNARY_OP_GELU_QUICK:
1344
- {
1345
- int64_t n = ggml_nelements(dst);
1346
-
1347
- id<MTLComputePipelineState> pipeline = nil;
1348
-
1349
- if (n % 4 == 0) {
1350
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
1351
- n /= 4;
1352
- } else {
1353
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
1354
- }
1355
-
1356
- [encoder setComputePipelineState:pipeline];
1357
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1358
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1359
-
1360
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1361
- } break;
1362
- case GGML_UNARY_OP_SILU:
1363
- {
1364
- int64_t n = ggml_nelements(dst);
1365
-
1366
- id<MTLComputePipelineState> pipeline = nil;
1367
-
1368
- if (n % 4 == 0) {
1369
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
1370
- n /= 4;
1371
- } else {
1372
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
1373
- }
1374
-
1375
- [encoder setComputePipelineState:pipeline];
1376
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1377
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1378
-
1379
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1380
- } break;
1381
- default:
1382
- {
1383
- GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1384
- GGML_ASSERT(false);
1385
- }
1386
- } break;
1387
- case GGML_OP_SQR:
1388
- {
1389
- GGML_ASSERT(ggml_is_contiguous(src0));
1390
-
1391
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
1392
-
1393
- [encoder setComputePipelineState:pipeline];
1394
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1395
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1396
-
1397
- const int64_t n = ggml_nelements(dst);
1398
-
1399
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1400
- } break;
1401
- case GGML_OP_SUM_ROWS:
1402
- {
1403
- GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
1404
-
1405
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
1406
-
1407
- [encoder setComputePipelineState:pipeline];
1408
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1409
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1410
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1411
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1412
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1413
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1414
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1415
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1416
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1417
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1418
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1419
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1420
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1421
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1422
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1423
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1424
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1425
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
1426
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1427
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
1428
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
1429
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
1430
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
1431
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
1432
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
1433
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
1434
-
1435
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1436
- } break;
1437
- case GGML_OP_SOFT_MAX:
1438
- {
1439
- GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
1440
-
1441
- int nth = 32; // SIMD width
1442
-
1443
- id<MTLComputePipelineState> pipeline = nil;
1444
-
1445
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
1446
-
1447
- if (ne00%4 == 0) {
1448
- while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
1449
- nth *= 2;
1450
- }
1451
- if (use_f16) {
1452
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
1453
- } else {
1454
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
1455
- }
1456
- } else {
1457
- while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
1458
- nth *= 2;
1459
- }
1460
- if (use_f16) {
1461
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
1462
- } else {
1463
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
1464
- }
1465
- }
1466
-
1467
- float scale;
1468
- float max_bias;
1469
-
1470
- memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
1471
- memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
1472
-
1473
- const int64_t nrows_x = ggml_nrows(src0);
1474
- const int64_t nrows_y = src0->ne[1];
1475
-
1476
- const uint32_t n_head = nrows_x/nrows_y;
1477
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
1478
-
1479
- const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
1480
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1481
-
1482
- [encoder setComputePipelineState:pipeline];
1483
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1484
- if (id_src1) {
1485
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1486
- } else {
1487
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1488
- }
1489
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1490
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1491
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1492
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1493
- [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1494
- [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
1495
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
1496
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
1497
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
1498
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1499
-
1500
- [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1501
- } break;
1502
- case GGML_OP_DIAG_MASK_INF:
1503
- {
1504
- const int n_past = ((int32_t *)(dst->op_params))[0];
1505
-
1506
- id<MTLComputePipelineState> pipeline = nil;
1507
-
1508
- if (ne00%8 == 0) {
1509
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
1510
- } else {
1511
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
1512
- }
1513
-
1514
- [encoder setComputePipelineState:pipeline];
1515
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1516
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1517
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1518
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1519
- [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
1520
-
1521
- if (ne00%8 == 0) {
1522
- [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1523
- }
1524
- else {
1525
- [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1526
- }
1527
- } break;
1528
- case GGML_OP_MUL_MAT:
1529
- {
1530
- GGML_ASSERT(ne00 == ne10);
1531
-
1532
- GGML_ASSERT(ne12 % ne02 == 0);
1533
- GGML_ASSERT(ne13 % ne03 == 0);
1534
-
1535
- const uint r2 = ne12/ne02;
1536
- const uint r3 = ne13/ne03;
1537
-
1538
- // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1539
- // to the matrix-vector kernel
1540
- int ne11_mm_min = 1;
1541
-
1542
- #if 0
1543
- // the numbers below are measured on M2 Ultra for 7B and 13B models
1544
- // these numbers do not translate to other devices or model sizes
1545
- // TODO: need to find a better approach
1546
- if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
1547
- switch (src0t) {
1548
- case GGML_TYPE_F16: ne11_mm_min = 2; break;
1549
- case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1550
- case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1551
- case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1552
- case GGML_TYPE_Q4_0:
1553
- case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1554
- case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1555
- case GGML_TYPE_Q5_0: // not tested yet
1556
- case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1557
- case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1558
- case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1559
- default: ne11_mm_min = 1; break;
1560
- }
1561
- }
1562
- #endif
1563
-
1564
- // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1565
- // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1566
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1567
- !ggml_is_transposed(src0) &&
1568
- !ggml_is_transposed(src1) &&
1569
- src1t == GGML_TYPE_F32 &&
1570
- ne00 % 32 == 0 && ne00 >= 64 &&
1571
- (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
1572
- //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1573
-
1574
- // some Metal matrix data types require aligned pointers
1575
- // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1576
- switch (src0->type) {
1577
- case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1578
- case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1579
- default: break;
1580
- }
1581
-
1582
- id<MTLComputePipelineState> pipeline = nil;
1583
-
1584
- switch (src0->type) {
1585
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
1586
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
1587
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
1588
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
1589
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
1590
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
1591
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
1592
- case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
1593
- case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
1594
- case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
1595
- case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
1596
- case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
1597
- case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
1598
- case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
1599
- case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
1600
- case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
1601
- case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
1602
- case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
1603
- case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
1604
- case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
1605
- case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
1606
- default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
1607
- }
1608
-
1609
- [encoder setComputePipelineState:pipeline];
1610
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1611
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1612
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1613
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1614
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1615
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
1616
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
1617
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1618
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1619
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1620
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1621
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1622
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1623
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1624
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1625
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1626
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1627
- } else {
1628
- int nth0 = 32;
1629
- int nth1 = 1;
1630
- int nrows = 1;
1631
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1632
-
1633
- id<MTLComputePipelineState> pipeline = nil;
1634
-
1635
- // use custom matrix x vector kernel
1636
- switch (src0t) {
1637
- case GGML_TYPE_F32:
1638
- {
1639
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1640
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
1641
- nrows = 4;
1642
- } break;
1643
- case GGML_TYPE_F16:
1644
- {
1645
- nth0 = 32;
1646
- nth1 = 1;
1647
- if (src1t == GGML_TYPE_F32) {
1648
- if (ne11 * ne12 < 4) {
1649
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
1650
- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1651
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
1652
- nrows = ne11;
1653
- } else {
1654
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
1655
- nrows = 4;
1656
- }
1657
- } else {
1658
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
1659
- nrows = 4;
1660
- }
1661
- } break;
1662
- case GGML_TYPE_Q4_0:
1663
- {
1664
- nth0 = 8;
1665
- nth1 = 8;
1666
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
1667
- } break;
1668
- case GGML_TYPE_Q4_1:
1669
- {
1670
- nth0 = 8;
1671
- nth1 = 8;
1672
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
1673
- } break;
1674
- case GGML_TYPE_Q5_0:
1675
- {
1676
- nth0 = 8;
1677
- nth1 = 8;
1678
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
1679
- } break;
1680
- case GGML_TYPE_Q5_1:
1681
- {
1682
- nth0 = 8;
1683
- nth1 = 8;
1684
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
1685
- } break;
1686
- case GGML_TYPE_Q8_0:
1687
- {
1688
- nth0 = 8;
1689
- nth1 = 8;
1690
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
1691
- } break;
1692
- case GGML_TYPE_Q2_K:
1693
- {
1694
- nth0 = 2;
1695
- nth1 = 32;
1696
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
1697
- } break;
1698
- case GGML_TYPE_Q3_K:
1699
- {
1700
- nth0 = 2;
1701
- nth1 = 32;
1702
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
1703
- } break;
1704
- case GGML_TYPE_Q4_K:
1705
- {
1706
- nth0 = 4; //1;
1707
- nth1 = 8; //32;
1708
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
1709
- } break;
1710
- case GGML_TYPE_Q5_K:
1711
- {
1712
- nth0 = 2;
1713
- nth1 = 32;
1714
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
1715
- } break;
1716
- case GGML_TYPE_Q6_K:
1717
- {
1718
- nth0 = 2;
1719
- nth1 = 32;
1720
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
1721
- } break;
1722
- case GGML_TYPE_IQ2_XXS:
1723
- {
1724
- nth0 = 4;
1725
- nth1 = 16;
1726
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
1727
- } break;
1728
- case GGML_TYPE_IQ2_XS:
1729
- {
1730
- nth0 = 4;
1731
- nth1 = 16;
1732
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
1733
- } break;
1734
- case GGML_TYPE_IQ3_XXS:
1735
- {
1736
- nth0 = 4;
1737
- nth1 = 16;
1738
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
1739
- } break;
1740
- case GGML_TYPE_IQ3_S:
1741
- {
1742
- nth0 = 4;
1743
- nth1 = 16;
1744
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
1745
- } break;
1746
- case GGML_TYPE_IQ2_S:
1747
- {
1748
- nth0 = 4;
1749
- nth1 = 16;
1750
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
1751
- } break;
1752
- case GGML_TYPE_IQ1_S:
1753
- {
1754
- nth0 = 4;
1755
- nth1 = 16;
1756
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
1757
- } break;
1758
- case GGML_TYPE_IQ1_M:
1759
- {
1760
- nth0 = 4;
1761
- nth1 = 16;
1762
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
1763
- } break;
1764
- case GGML_TYPE_IQ4_NL:
1765
- {
1766
- nth0 = 4;
1767
- nth1 = 16;
1768
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
1769
- } break;
1770
- case GGML_TYPE_IQ4_XS:
1771
- {
1772
- nth0 = 4;
1773
- nth1 = 16;
1774
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
1775
- } break;
1776
- default:
1777
- {
1778
- GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1779
- GGML_ASSERT(false && "not implemented");
1780
- }
1781
- };
1782
-
1783
- if (ggml_is_quantized(src0t)) {
1784
- GGML_ASSERT(ne00 >= nth0*nth1);
1785
- }
1786
-
1787
- [encoder setComputePipelineState:pipeline];
1788
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1789
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1790
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1791
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1792
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1793
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1794
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1795
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1796
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1797
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
1798
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
1799
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
1800
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
1801
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
1802
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
1803
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
1804
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1805
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1806
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1807
-
1808
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
1809
- src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
1810
- src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
1811
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1812
- }
1813
- else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
1814
- const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1815
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1816
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1817
- }
1818
- else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
1819
- const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
1820
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1821
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1822
- }
1823
- else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
1824
- const int mem_size = 32*sizeof(float);
1825
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1826
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1827
- }
1828
- else if (src0t == GGML_TYPE_Q4_K) {
1829
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1830
- }
1831
- else if (src0t == GGML_TYPE_Q3_K) {
1832
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1833
- }
1834
- else if (src0t == GGML_TYPE_Q5_K) {
1835
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1836
- }
1837
- else if (src0t == GGML_TYPE_Q6_K) {
1838
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1839
- } else {
1840
- const int64_t ny = (ne11 + nrows - 1)/nrows;
1841
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1842
- }
1843
- }
1844
- } break;
1845
- case GGML_OP_MUL_MAT_ID:
1846
- {
1847
- const int n_as = src0->ne[2];
1848
-
1849
- // src2 = ids
1850
- const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
1851
-
1852
- GGML_ASSERT(src2t == GGML_TYPE_I32);
1853
-
1854
- GGML_ASSERT(!ggml_is_transposed(src0));
1855
- GGML_ASSERT(!ggml_is_transposed(src1));
1856
-
1857
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1858
-
1859
- // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1860
- // to the matrix-vector kernel
1861
- // ne20 = n_used_experts
1862
- // ne21 = n_rows
1863
- const int dst_rows = ne20*ne21;
1864
- const int dst_rows_min = n_as;
1865
- const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4;
1866
-
1867
- // max size of the rowids array in the kernel shared buffer
1868
- GGML_ASSERT(dst_rows <= dst_rows_max);
1869
-
1870
- // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1871
- // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1872
- // !!!
1873
- // TODO: for now, always use mat-vec kernels until we figure out how to improve the
1874
- // indirect matrix multiplication
1875
- // !!!
1876
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1877
- ne00 % 32 == 0 && ne00 >= 64 &&
1878
- dst_rows > dst_rows_min) {
1879
-
1880
- // some Metal matrix data types require aligned pointers
1881
- // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1882
- switch (src0->type) {
1883
- case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1884
- case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1885
- default: break;
1886
- }
1887
-
1888
- id<MTLComputePipelineState> pipeline = nil;
1889
-
1890
- switch (src0->type) {
1891
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
1892
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
1893
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
1894
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
1895
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
1896
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
1897
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
1898
- case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
1899
- case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
1900
- case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
1901
- case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
1902
- case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
1903
- case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
1904
- case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
1905
- case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
1906
- case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
1907
- case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
1908
- case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
1909
- case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
1910
- case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
1911
- case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
1912
- default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1913
- }
1914
-
1915
- [encoder setComputePipelineState:pipeline];
1916
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1917
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1918
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1919
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
1920
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1921
- [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1922
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1923
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
1924
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
1925
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
1926
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
1927
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1928
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1929
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1930
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1931
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1932
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1933
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1934
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
1935
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1936
-
1937
- [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
1938
-
1939
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1940
- } else {
1941
- int nth0 = 32;
1942
- int nth1 = 1;
1943
- int nrows = 1;
1944
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1945
-
1946
- id<MTLComputePipelineState> pipeline = nil;
1947
-
1948
- // use custom matrix x vector kernel
1949
- switch (src0t) {
1950
- case GGML_TYPE_F32:
1951
- {
1952
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1953
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
1954
- } break;
1955
- case GGML_TYPE_F16:
1956
- {
1957
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1958
- nth0 = 32;
1959
- nth1 = 1;
1960
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
1961
- } break;
1962
- case GGML_TYPE_Q4_0:
1963
- {
1964
- nth0 = 8;
1965
- nth1 = 8;
1966
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
1967
- } break;
1968
- case GGML_TYPE_Q4_1:
1969
- {
1970
- nth0 = 8;
1971
- nth1 = 8;
1972
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
1973
- } break;
1974
- case GGML_TYPE_Q5_0:
1975
- {
1976
- nth0 = 8;
1977
- nth1 = 8;
1978
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
1979
- } break;
1980
- case GGML_TYPE_Q5_1:
1981
- {
1982
- nth0 = 8;
1983
- nth1 = 8;
1984
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
1985
- } break;
1986
- case GGML_TYPE_Q8_0:
1987
- {
1988
- nth0 = 8;
1989
- nth1 = 8;
1990
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
1991
- } break;
1992
- case GGML_TYPE_Q2_K:
1993
- {
1994
- nth0 = 2;
1995
- nth1 = 32;
1996
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
1997
- } break;
1998
- case GGML_TYPE_Q3_K:
1999
- {
2000
- nth0 = 2;
2001
- nth1 = 32;
2002
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
2003
- } break;
2004
- case GGML_TYPE_Q4_K:
2005
- {
2006
- nth0 = 4; //1;
2007
- nth1 = 8; //32;
2008
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
2009
- } break;
2010
- case GGML_TYPE_Q5_K:
2011
- {
2012
- nth0 = 2;
2013
- nth1 = 32;
2014
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
2015
- } break;
2016
- case GGML_TYPE_Q6_K:
2017
- {
2018
- nth0 = 2;
2019
- nth1 = 32;
2020
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
2021
- } break;
2022
- case GGML_TYPE_IQ2_XXS:
2023
- {
2024
- nth0 = 4;
2025
- nth1 = 16;
2026
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
2027
- } break;
2028
- case GGML_TYPE_IQ2_XS:
2029
- {
2030
- nth0 = 4;
2031
- nth1 = 16;
2032
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
2033
- } break;
2034
- case GGML_TYPE_IQ3_XXS:
2035
- {
2036
- nth0 = 4;
2037
- nth1 = 16;
2038
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
2039
- } break;
2040
- case GGML_TYPE_IQ3_S:
2041
- {
2042
- nth0 = 4;
2043
- nth1 = 16;
2044
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
2045
- } break;
2046
- case GGML_TYPE_IQ2_S:
2047
- {
2048
- nth0 = 4;
2049
- nth1 = 16;
2050
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
2051
- } break;
2052
- case GGML_TYPE_IQ1_S:
2053
- {
2054
- nth0 = 4;
2055
- nth1 = 16;
2056
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
2057
- } break;
2058
- case GGML_TYPE_IQ1_M:
2059
- {
2060
- nth0 = 4;
2061
- nth1 = 16;
2062
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
2063
- } break;
2064
- case GGML_TYPE_IQ4_NL:
2065
- {
2066
- nth0 = 4;
2067
- nth1 = 16;
2068
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
2069
- } break;
2070
- case GGML_TYPE_IQ4_XS:
2071
- {
2072
- nth0 = 4;
2073
- nth1 = 16;
2074
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
2075
- } break;
2076
- default:
2077
- {
2078
- GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
2079
- GGML_ASSERT(false && "not implemented");
2080
- }
2081
- };
2082
-
2083
- if (ggml_is_quantized(src0t)) {
2084
- GGML_ASSERT(ne00 >= nth0*nth1);
2085
- }
2086
-
2087
- [encoder setComputePipelineState:pipeline];
2088
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2089
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2090
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2091
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
2092
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
2093
- [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
2094
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
2095
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
2096
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
2097
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
2098
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
2099
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
2100
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
2101
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
2102
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
2103
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
2104
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
2105
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
2106
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
2107
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
2108
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
2109
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
2110
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
2111
-
2112
- const int64_t _ne1 = 1;
2113
- const int tgz = dst_rows;
2114
-
2115
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
2116
- src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
2117
- src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
2118
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2119
- }
2120
- else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
2121
- const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
2122
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2123
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2124
- }
2125
- else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
2126
- const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
2127
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2128
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2129
- }
2130
- else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
2131
- const int mem_size = 32*sizeof(float);
2132
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2133
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2134
- }
2135
- else if (src0t == GGML_TYPE_Q4_K) {
2136
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2137
- }
2138
- else if (src0t == GGML_TYPE_Q3_K) {
2139
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2140
- }
2141
- else if (src0t == GGML_TYPE_Q5_K) {
2142
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2143
- }
2144
- else if (src0t == GGML_TYPE_Q6_K) {
2145
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2146
- } else {
2147
- const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
2148
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2149
- }
2150
- }
2151
- } break;
2152
- case GGML_OP_GET_ROWS:
2153
- {
2154
- id<MTLComputePipelineState> pipeline = nil;
2155
-
2156
- switch (src0->type) {
2157
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
2158
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
2159
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
2160
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
2161
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
2162
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
2163
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
2164
- case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
2165
- case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
2166
- case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
2167
- case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break;
2168
- case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
2169
- case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
2170
- case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
2171
- case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
2172
- case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
2173
- case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
2174
- case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
2175
- case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break;
2176
- case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
2177
- case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
2178
- case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
2179
- default: GGML_ASSERT(false && "not implemented");
2180
- }
2181
-
2182
- [encoder setComputePipelineState:pipeline];
2183
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2184
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2185
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2186
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
2187
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
2188
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
2189
- [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
2190
- [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
2191
- [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
2192
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
2193
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
2194
-
2195
- [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
2196
- } break;
2197
- case GGML_OP_RMS_NORM:
2198
- {
2199
- GGML_ASSERT(ne00 % 4 == 0);
2200
- GGML_ASSERT(ggml_is_contiguous_1(src0));
2201
-
2202
- float eps;
2203
- memcpy(&eps, dst->op_params, sizeof(float));
2204
-
2205
- int nth = 32; // SIMD width
2206
-
2207
- while (nth < ne00/4 && nth < 1024) {
2208
- nth *= 2;
2209
- }
2210
-
2211
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
2212
-
2213
- [encoder setComputePipelineState:pipeline];
2214
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2215
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2216
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2217
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
2218
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
2219
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2220
-
2221
- const int64_t nrows = ggml_nrows(src0);
2222
-
2223
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2224
- } break;
2225
- case GGML_OP_GROUP_NORM:
2226
- {
2227
- GGML_ASSERT(ne00 % 4 == 0);
2228
- GGML_ASSERT(ggml_is_contiguous(src0));
2229
-
2230
- //float eps;
2231
- //memcpy(&eps, dst->op_params, sizeof(float));
2232
-
2233
- const float eps = 1e-6f; // TODO: temporarily hardcoded
2234
-
2235
- const int32_t n_groups = ((int32_t *) dst->op_params)[0];
2236
-
2237
- int nth = 32; // SIMD width
2238
-
2239
- //while (nth < ne00/4 && nth < 1024) {
2240
- // nth *= 2;
2241
- //}
2242
-
2243
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
2244
-
2245
- [encoder setComputePipelineState:pipeline];
2246
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2247
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2248
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2249
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2250
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2251
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
2252
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
2253
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
2254
- [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
2255
- [encoder setBytes:&eps length:sizeof( float) atIndex:9];
2256
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2257
-
2258
- [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2259
- } break;
2260
- case GGML_OP_NORM:
2261
- {
2262
- GGML_ASSERT(ggml_is_contiguous_1(src0));
2263
-
2264
- float eps;
2265
- memcpy(&eps, dst->op_params, sizeof(float));
2266
-
2267
- const int nth = MIN(256, ne00);
2268
-
2269
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
2270
-
2271
- [encoder setComputePipelineState:pipeline];
2272
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2273
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2274
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2275
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
2276
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
2277
- [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
2278
-
2279
- const int64_t nrows = ggml_nrows(src0);
2280
-
2281
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2282
- } break;
2283
- case GGML_OP_ROPE:
2284
- {
2285
- GGML_ASSERT(ne10 == ne02);
2286
-
2287
- const int nth = MIN(1024, ne00);
2288
-
2289
- const int n_past = ((int32_t *) dst->op_params)[0];
2290
- const int n_dims = ((int32_t *) dst->op_params)[1];
2291
- const int mode = ((int32_t *) dst->op_params)[2];
2292
- // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
2293
- const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
2294
-
2295
- float freq_base;
2296
- float freq_scale;
2297
- float ext_factor;
2298
- float attn_factor;
2299
- float beta_fast;
2300
- float beta_slow;
2301
-
2302
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
2303
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
2304
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
2305
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
2306
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
2307
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
2308
-
2309
- const bool is_neox = mode & 2;
2310
-
2311
- id<MTLComputePipelineState> pipeline = nil;
2312
-
2313
- if (!is_neox) {
2314
- switch (src0->type) {
2315
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
2316
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
2317
- default: GGML_ASSERT(false);
2318
- };
2319
- } else {
2320
- switch (src0->type) {
2321
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
2322
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
2323
- default: GGML_ASSERT(false);
2324
- };
2325
- }
2326
-
2327
- [encoder setComputePipelineState:pipeline];
2328
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2329
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2330
- if (id_src2 != nil) {
2331
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2332
- } else {
2333
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
2334
- }
2335
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2336
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
2337
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
2338
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
2339
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
2340
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
2341
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
2342
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
2343
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
2344
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
2345
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
2346
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
2347
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
2348
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
2349
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
2350
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
2351
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
2352
- [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
2353
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
2354
- [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
2355
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
2356
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
2357
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
2358
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
2359
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
2360
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
2361
-
2362
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2363
- } break;
2364
- case GGML_OP_IM2COL:
2365
- {
2366
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
2367
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
2368
- GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
2369
-
2370
- const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
2371
- const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
2372
- const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
2373
- const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
2374
- const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
2375
- const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
2376
-
2377
- const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
2378
-
2379
- const int32_t N = src1->ne[is_2D ? 3 : 2];
2380
- const int32_t IC = src1->ne[is_2D ? 2 : 1];
2381
- const int32_t IH = is_2D ? src1->ne[1] : 1;
2382
- const int32_t IW = src1->ne[0];
2383
-
2384
- const int32_t KH = is_2D ? src0->ne[1] : 1;
2385
- const int32_t KW = src0->ne[0];
2386
-
2387
- const int32_t OH = is_2D ? dst->ne[2] : 1;
2388
- const int32_t OW = dst->ne[1];
2389
-
2390
- const int32_t CHW = IC * KH * KW;
2391
-
2392
- const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
2393
- const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
2394
-
2395
- id<MTLComputePipelineState> pipeline = nil;
2396
-
2397
- switch (dst->type) {
2398
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
2399
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
2400
- default: GGML_ASSERT(false);
2401
- };
2402
-
2403
- [encoder setComputePipelineState:pipeline];
2404
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
2405
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2406
- [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
2407
- [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
2408
- [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
2409
- [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
2410
- [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
2411
- [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
2412
- [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
2413
- [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
2414
- [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
2415
- [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
2416
- [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
2417
-
2418
- [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
2419
- } break;
2420
- case GGML_OP_UPSCALE:
2421
- {
2422
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2423
-
2424
- const float sf0 = (float)ne0/src0->ne[0];
2425
- const float sf1 = (float)ne1/src0->ne[1];
2426
- const float sf2 = (float)ne2/src0->ne[2];
2427
- const float sf3 = (float)ne3/src0->ne[3];
2428
-
2429
- const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
2430
-
2431
- [encoder setComputePipelineState:pipeline];
2432
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2433
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2434
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2435
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2436
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2437
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2438
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2439
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2440
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2441
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2442
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2443
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2444
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2445
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2446
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2447
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2448
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2449
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2450
- [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
2451
- [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
2452
- [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
2453
- [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
2454
-
2455
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
2456
-
2457
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2458
- } break;
2459
- case GGML_OP_PAD:
2460
- {
2461
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2462
-
2463
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
2464
-
2465
- [encoder setComputePipelineState:pipeline];
2466
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2467
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2468
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2469
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2470
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2471
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2472
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2473
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2474
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2475
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2476
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2477
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2478
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2479
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2480
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2481
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2482
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2483
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2484
-
2485
- const int nth = MIN(1024, ne0);
2486
-
2487
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2488
- } break;
2489
- case GGML_OP_ARANGE:
2490
- {
2491
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
2492
-
2493
- float start;
2494
- float step;
2495
-
2496
- memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
2497
- memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float));
2498
-
2499
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
2500
-
2501
- [encoder setComputePipelineState:pipeline];
2502
- [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
2503
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
2504
- [encoder setBytes:&start length:sizeof(start) atIndex:2];
2505
- [encoder setBytes:&step length:sizeof(step) atIndex:3];
2506
-
2507
- const int nth = MIN(1024, ne0);
2508
-
2509
- [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2510
- } break;
2511
- case GGML_OP_TIMESTEP_EMBEDDING:
2512
- {
2513
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2514
-
2515
- const int dim = dst->op_params[0];
2516
- const int max_period = dst->op_params[1];
2517
-
2518
- const int half = dim / 2;
2519
-
2520
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
2521
-
2522
- [encoder setComputePipelineState:pipeline];
2523
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2524
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2525
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
2526
- [encoder setBytes:&dim length:sizeof(dim) atIndex:3];
2527
- [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
2528
-
2529
- const int nth = MIN(1024, half);
2530
-
2531
- [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2532
- } break;
2533
- case GGML_OP_ARGSORT:
2534
- {
2535
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2536
- GGML_ASSERT( dst->type == GGML_TYPE_I32);
2537
-
2538
- const int nrows = ggml_nrows(src0);
2539
-
2540
- enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
2541
-
2542
- // bitonic sort requires the number of elements to be power of 2
2543
- int64_t ne00_padded = 1;
2544
- while (ne00_padded < ne00) {
2545
- ne00_padded *= 2;
2546
- }
2547
-
2548
- // Metal kernels require the buffer size to be multiple of 16 bytes
2549
- // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
2550
- const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
2551
-
2552
- id<MTLComputePipelineState> pipeline = nil;
2553
-
2554
- switch (order) {
2555
- case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
2556
- case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
2557
- default: GGML_ASSERT(false);
2558
- };
2559
-
2560
- [encoder setComputePipelineState:pipeline];
2561
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2562
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2563
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2564
- [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
2565
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2566
-
2567
- [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
2568
- } break;
2569
- case GGML_OP_LEAKY_RELU:
2570
- {
2571
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2572
-
2573
- float slope;
2574
- memcpy(&slope, dst->op_params, sizeof(float));
2575
-
2576
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
2577
-
2578
- [encoder setComputePipelineState:pipeline];
2579
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2580
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2581
- [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
2582
-
2583
- const int64_t n = ggml_nelements(dst);
2584
-
2585
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2586
- } break;
2587
- case GGML_OP_FLASH_ATTN_EXT:
2588
- {
2589
- GGML_ASSERT(ne00 % 4 == 0);
2590
- GGML_ASSERT(ne11 % 32 == 0);
2591
-
2592
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2593
-
2594
- GGML_ASSERT(ggml_are_same_shape (src1, src2));
2595
-
2596
- struct ggml_tensor * src3 = gf->nodes[i]->src[3];
2597
-
2598
- size_t offs_src3 = 0;
2599
-
2600
- id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
2601
-
2602
- GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
2603
- GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
2604
- "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
2605
-
2606
- const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
2607
- //const int64_t ne31 = src3 ? src3->ne[1] : 0;
2608
- const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
2609
- const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
2610
-
2611
- const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
2612
- const uint64_t nb31 = src3 ? src3->nb[1] : 0;
2613
- const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
2614
- const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
2615
-
2616
- const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
2617
-
2618
- float scale;
2619
- float max_bias;
2620
-
2621
- memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
2622
- memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
2623
-
2624
- const uint32_t n_head = src0->ne[2];
2625
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
2626
-
2627
- const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
2628
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
2629
-
2630
- id<MTLComputePipelineState> pipeline = nil;
2631
-
2632
- bool use_vec_kernel = false;
2633
-
2634
- if (ne01 >= 4 || (ne00%128 != 0)) {
2635
- switch (ne00) {
2636
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
2637
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
2638
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
2639
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
2640
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
2641
- //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
2642
- default:
2643
- {
2644
- GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
2645
- GGML_METAL_LOG_ERROR("add template specialization for this size\n");
2646
- GGML_ASSERT(false && "add template specialization for this size");
2647
- }
2648
- }
2649
- } else {
2650
- use_vec_kernel = true;
2651
-
2652
- switch (ne00) {
2653
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
2654
- //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
2655
- default:
2656
- {
2657
- GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
2658
- GGML_METAL_LOG_ERROR("add template specialization for this size\n");
2659
- GGML_ASSERT(false && "add template specialization for this size");
2660
- }
2661
- }
2662
- }
2663
-
2664
- [encoder setComputePipelineState:pipeline];
2665
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2666
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2667
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2668
- if (id_src3) {
2669
- [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2670
- } else {
2671
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
2672
- }
2673
- [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2674
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
2675
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
2676
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
2677
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
2678
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
2679
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
2680
- [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
2681
- [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
2682
- [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
2683
- [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
2684
- [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
2685
- [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
2686
- [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
2687
- [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
2688
- [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
2689
- [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
2690
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
2691
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
2692
- [encoder setBytes:&scale length:sizeof( float) atIndex:23];
2693
- [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
2694
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
2695
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
2696
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
2697
-
2698
- if (!use_vec_kernel) {
2699
- // half8x8 kernel
2700
- const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
2701
- const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
2702
-
2703
- GGML_ASSERT(nqptg <= 32);
2704
- GGML_ASSERT(nqptg % 8 == 0);
2705
- GGML_ASSERT(ncpsg % 32 == 0);
2706
-
2707
- int64_t nsgmax = 2;
2708
-
2709
- while (true) {
2710
- const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
2711
- if (smem > ctx->device.maxThreadgroupMemoryLength) {
2712
- break;
2713
- }
2714
- nsgmax *= 2;
2715
- }
2716
- nsgmax /= 2;
2717
-
2718
- // simdgroups per threadgroup (a.k.a. warps)
2719
- const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
2720
-
2721
- const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
2722
-
2723
- //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
2724
- GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
2725
-
2726
- [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
2727
-
2728
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
2729
- } else {
2730
- // half1x4 kernel
2731
- const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
2732
- const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
2733
-
2734
- GGML_ASSERT(nqptg <= 32);
2735
- GGML_ASSERT(nqptg % 1 == 0);
2736
- GGML_ASSERT(ncpsg % 32 == 0);
2737
-
2738
- // simdgroups per threadgroup (a.k.a. warps)
2739
- const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
2740
-
2741
- int64_t nsg = 1;
2742
- while (nsg <= nsgt) {
2743
- nsg *= 2;
2744
- }
2745
- nsg /= 2;
2746
-
2747
- const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
2748
-
2749
- //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
2750
- GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
2751
- [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
2752
-
2753
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
2754
- }
2755
- } break;
2756
- case GGML_OP_DUP:
2757
- case GGML_OP_CPY:
2758
- case GGML_OP_CONT:
2759
- {
2760
- GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
2761
-
2762
- int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
2763
-
2764
- id<MTLComputePipelineState> pipeline = nil;
2765
-
2766
- switch (src0t) {
2767
- case GGML_TYPE_F32:
2768
- {
2769
- GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
2770
-
2771
- switch (dstt) {
2772
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
2773
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
2774
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
2775
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
2776
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
2777
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
2778
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
2779
- case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
2780
- default: GGML_ASSERT(false && "not implemented");
2781
- };
2782
- } break;
2783
- case GGML_TYPE_F16:
2784
- {
2785
- switch (dstt) {
2786
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
2787
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
2788
- default: GGML_ASSERT(false && "not implemented");
2789
- };
2790
- } break;
2791
- default: GGML_ASSERT(false && "not implemented");
2792
- }
2793
-
2794
- [encoder setComputePipelineState:pipeline];
2795
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2796
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2797
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2798
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2799
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2800
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
2801
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
2802
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
2803
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
2804
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
2805
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
2806
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
2807
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
2808
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
2809
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
2810
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
2811
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
2812
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
2813
-
2814
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2815
- } break;
2816
- default:
2817
- {
2818
- GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
2819
- GGML_ASSERT(false);
2820
- }
2821
- }
2822
-
2823
- if (should_capture) {
2824
- [encoder popDebugGroup];
2825
- }
2826
- }
2827
-
2828
- [encoder endEncoding];
2829
-
2830
- [command_buffer commit];
2831
- });
2832
-
2833
- // Wait for completion and check status of each command buffer
2834
- // needed to detect if the device ran out-of-memory for example (#1881)
2835
-
2836
- for (int i = 0; i < n_cb; ++i) {
2837
- id<MTLCommandBuffer> command_buffer = command_buffers[i];
2838
- [command_buffer waitUntilCompleted];
2839
-
2840
- MTLCommandBufferStatus status = [command_buffer status];
2841
- if (status != MTLCommandBufferStatusCompleted) {
2842
- GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
2843
- if (status == MTLCommandBufferStatusError) {
2844
- NSString * error_code = [command_buffer error].localizedDescription;
2845
- GGML_METAL_LOG_INFO("error: %s\n", [error_code UTF8String]);
2846
- }
2847
-
2848
- return GGML_STATUS_FAILED;
2849
- }
2850
- }
2851
-
2852
- if (should_capture) {
2853
- [[MTLCaptureManager sharedCaptureManager] stopCapture];
2854
- }
2855
-
2856
- }
2857
- return GGML_STATUS_SUCCESS;
2858
- }
2859
-
2860
- ////////////////////////////////////////////////////////////////////////////////
2861
-
2862
- // backend interface
2863
-
2864
- // default buffer
2865
- static id<MTLDevice> g_backend_device = nil;
2866
- static int g_backend_device_ref_count = 0;
2867
-
2868
- static id<MTLDevice> ggml_backend_metal_get_device(void) {
2869
- if (g_backend_device == nil) {
2870
- g_backend_device = MTLCreateSystemDefaultDevice();
2871
- }
2872
-
2873
- g_backend_device_ref_count++;
2874
-
2875
- return g_backend_device;
2876
- }
2877
-
2878
- static void ggml_backend_metal_free_device(void) {
2879
- assert(g_backend_device_ref_count > 0);
2880
-
2881
- g_backend_device_ref_count--;
2882
-
2883
- if (g_backend_device_ref_count == 0) {
2884
- [g_backend_device release];
2885
- g_backend_device = nil;
2886
- }
2887
- }
2888
-
2889
- GGML_CALL static const char * ggml_backend_metal_buffer_get_name(ggml_backend_buffer_t buffer) {
2890
- return "Metal";
2891
-
2892
- UNUSED(buffer);
2893
- }
2894
-
2895
- GGML_CALL static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
2896
- struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
2897
-
2898
- for (int i = 0; i < ctx->n_buffers; i++) {
2899
- [ctx->buffers[i].metal release];
2900
- }
2901
- ggml_backend_metal_free_device();
2902
-
2903
- if (ctx->owned) {
2904
- #if TARGET_OS_OSX
2905
- vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size);
2906
- #else
2907
- free(ctx->all_data);
2908
- #endif
2909
- }
2910
-
2911
- free(ctx);
2912
- }
2913
-
2914
- GGML_CALL static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
2915
- struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
2916
-
2917
- return ctx->all_data;
2918
- }
2919
-
2920
- GGML_CALL static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
2921
- memcpy((char *)tensor->data + offset, data, size);
2922
-
2923
- UNUSED(buffer);
2924
- }
2925
-
2926
- GGML_CALL static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
2927
- memcpy(data, (const char *)tensor->data + offset, size);
2928
-
2929
- UNUSED(buffer);
2930
- }
2931
-
2932
- GGML_CALL static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
2933
- if (ggml_backend_buffer_is_host(src->buffer)) {
2934
- memcpy(dst->data, src->data, ggml_nbytes(src));
2935
- return true;
2936
- }
2937
- return false;
2938
-
2939
- UNUSED(buffer);
2940
- }
2941
-
2942
- GGML_CALL static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
2943
- struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
2944
-
2945
- memset(ctx->all_data, value, ctx->all_size);
2946
- }
2947
-
2948
- static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
2949
- /* .get_name = */ ggml_backend_metal_buffer_get_name,
2950
- /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
2951
- /* .get_base = */ ggml_backend_metal_buffer_get_base,
2952
- /* .init_tensor = */ NULL,
2953
- /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
2954
- /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
2955
- /* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor,
2956
- /* .clear = */ ggml_backend_metal_buffer_clear,
2957
- /* .reset = */ NULL,
2958
- };
2959
-
2960
- // default buffer type
2961
-
2962
- GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
2963
- return "Metal";
2964
-
2965
- UNUSED(buft);
2966
- }
2967
-
2968
- static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
2969
- #ifndef GGML_METAL_NDEBUG
2970
- #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
2971
- if (@available(macOS 10.12, iOS 16.0, *)) {
2972
- GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)",
2973
- __func__,
2974
- size_aligned / 1024.0 / 1024.0,
2975
- device.currentAllocatedSize / 1024.0 / 1024.0,
2976
- device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
2977
-
2978
- if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
2979
- GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
2980
- } else {
2981
- GGML_METAL_LOG_INFO("\n");
2982
- }
2983
- } else {
2984
- GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n",
2985
- __func__,
2986
- size_aligned / 1024.0 / 1024.0,
2987
- device.currentAllocatedSize / 1024.0 / 1024.0);
2988
- }
2989
- #endif
2990
- #endif
2991
- UNUSED(device);
2992
- UNUSED(size_aligned);
2993
- }
2994
-
2995
- GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
2996
- struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
2997
-
2998
- const size_t size_page = sysconf(_SC_PAGESIZE);
2999
-
3000
- size_t size_aligned = size;
3001
- if ((size_aligned % size_page) != 0) {
3002
- size_aligned += (size_page - (size_aligned % size_page));
3003
- }
3004
-
3005
- id<MTLDevice> device = ggml_backend_metal_get_device();
3006
-
3007
- ctx->all_data = ggml_metal_host_malloc(size_aligned);
3008
- ctx->all_size = size_aligned;
3009
- ctx->owned = true;
3010
- ctx->n_buffers = 1;
3011
-
3012
- if (ctx->all_data != NULL) {
3013
- ctx->buffers[0].data = ctx->all_data;
3014
- ctx->buffers[0].size = size;
3015
- ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
3016
- length:size_aligned
3017
- options:MTLResourceStorageModeShared
3018
- deallocator:nil];
3019
- }
3020
-
3021
- if (ctx->all_data == NULL || ctx->buffers[0].metal == nil) {
3022
- GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
3023
- free(ctx);
3024
- ggml_backend_metal_free_device();
3025
- return NULL;
3026
- }
3027
-
3028
- //ggml_backend_metal_log_allocated_size(device, size_aligned);
3029
-
3030
- return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
3031
- }
3032
-
3033
- GGML_CALL static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
3034
- return 32;
3035
- UNUSED(buft);
3036
- }
3037
-
3038
- GGML_CALL static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
3039
- id<MTLDevice> device = ggml_backend_metal_get_device();
3040
- size_t max_size = device.maxBufferLength;
3041
- ggml_backend_metal_free_device();
3042
-
3043
- return max_size;
3044
-
3045
- UNUSED(buft);
3046
- }
3047
-
3048
- GGML_CALL static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
3049
- return true;
3050
-
3051
- UNUSED(buft);
3052
- }
3053
-
3054
- GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
3055
- static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
3056
- /* .iface = */ {
3057
- /* .get_name = */ ggml_backend_metal_buffer_type_get_name,
3058
- /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
3059
- /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
3060
- /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size,
3061
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
3062
- /* .is_host = */ ggml_backend_metal_buffer_type_is_host,
3063
- },
3064
- /* .context = */ NULL,
3065
- };
3066
-
3067
- return &ggml_backend_buffer_type_metal;
3068
- }
3069
-
3070
- // buffer from ptr
3071
-
3072
- GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
3073
- struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
3074
-
3075
- ctx->all_data = data;
3076
- ctx->all_size = size;
3077
- ctx->owned = false;
3078
- ctx->n_buffers = 0;
3079
-
3080
- const size_t size_page = sysconf(_SC_PAGESIZE);
3081
-
3082
- // page-align the data ptr
3083
- {
3084
- const uintptr_t offs = (uintptr_t) data % size_page;
3085
- data = (void *) ((char *) data - offs);
3086
- size += offs;
3087
- }
3088
-
3089
- size_t size_aligned = size;
3090
- if ((size_aligned % size_page) != 0) {
3091
- size_aligned += (size_page - (size_aligned % size_page));
3092
- }
3093
-
3094
- id<MTLDevice> device = ggml_backend_metal_get_device();
3095
-
3096
- // the buffer fits into the max buffer size allowed by the device
3097
- if (size_aligned <= device.maxBufferLength) {
3098
- ctx->buffers[ctx->n_buffers].data = data;
3099
- ctx->buffers[ctx->n_buffers].size = size;
3100
-
3101
- ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
3102
-
3103
- if (ctx->buffers[ctx->n_buffers].metal == nil) {
3104
- GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
3105
- return false;
3106
- }
3107
-
3108
- ggml_backend_metal_log_allocated_size(device, size_aligned);
3109
-
3110
- ++ctx->n_buffers;
3111
- } else {
3112
- // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
3113
- // one of the views
3114
- const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
3115
- const size_t size_step = device.maxBufferLength - size_ovlp;
3116
- const size_t size_view = device.maxBufferLength;
3117
-
3118
- for (size_t i = 0; i < size; i += size_step) {
3119
- const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
3120
-
3121
- ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
3122
- ctx->buffers[ctx->n_buffers].size = size_step_aligned;
3123
-
3124
- ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
3125
-
3126
- if (ctx->buffers[ctx->n_buffers].metal == nil) {
3127
- GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
3128
- return false;
3129
- }
3130
-
3131
- ggml_backend_metal_log_allocated_size(device, size_step_aligned);
3132
-
3133
- if (i + size_step < size) {
3134
- GGML_METAL_LOG_INFO("\n");
3135
- }
3136
-
3137
- ++ctx->n_buffers;
3138
- }
3139
- }
3140
-
3141
- return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
3142
- }
3143
-
3144
- // backend
3145
-
3146
- GGML_CALL static const char * ggml_backend_metal_name(ggml_backend_t backend) {
3147
- return "Metal";
3148
-
3149
- UNUSED(backend);
3150
- }
3151
-
3152
- GGML_CALL static void ggml_backend_metal_free(ggml_backend_t backend) {
3153
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
3154
- ggml_metal_free(ctx);
3155
- free(backend);
3156
- }
3157
-
3158
- GGML_CALL static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
3159
- return ggml_backend_metal_buffer_type();
3160
-
3161
- UNUSED(backend);
3162
- }
3163
-
3164
- GGML_CALL static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
3165
- struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
3166
-
3167
- return ggml_metal_graph_compute(metal_ctx, cgraph);
3168
- }
3169
-
3170
- GGML_CALL static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
3171
- struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
3172
-
3173
- return ggml_metal_supports_op(metal_ctx, op);
3174
- }
3175
-
3176
- GGML_CALL static bool ggml_backend_metal_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
3177
- return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name;
3178
-
3179
- UNUSED(backend);
3180
- }
3181
-
3182
- static struct ggml_backend_i ggml_backend_metal_i = {
3183
- /* .get_name = */ ggml_backend_metal_name,
3184
- /* .free = */ ggml_backend_metal_free,
3185
- /* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
3186
- /* .set_tensor_async = */ NULL,
3187
- /* .get_tensor_async = */ NULL,
3188
- /* .cpy_tensor_async = */ NULL,
3189
- /* .synchronize = */ NULL,
3190
- /* .graph_plan_create = */ NULL,
3191
- /* .graph_plan_free = */ NULL,
3192
- /* .graph_plan_update = */ NULL,
3193
- /* .graph_plan_compute = */ NULL,
3194
- /* .graph_compute = */ ggml_backend_metal_graph_compute,
3195
- /* .supports_op = */ ggml_backend_metal_supports_op,
3196
- /* .supports_buft = */ ggml_backend_metal_supports_buft,
3197
- /* .offload_op = */ NULL,
3198
- /* .event_new = */ NULL,
3199
- /* .event_free = */ NULL,
3200
- /* .event_record = */ NULL,
3201
- /* .event_wait = */ NULL,
3202
- /* .event_synchronize = */ NULL,
3203
- };
3204
-
3205
- void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
3206
- ggml_metal_log_callback = log_callback;
3207
- ggml_metal_log_user_data = user_data;
3208
- }
3209
-
3210
- static ggml_guid_t ggml_backend_metal_guid(void) {
3211
- static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 };
3212
- return &guid;
3213
- }
3214
-
3215
- ggml_backend_t ggml_backend_metal_init(void) {
3216
- struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
3217
-
3218
- if (ctx == NULL) {
3219
- return NULL;
3220
- }
3221
-
3222
- ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
3223
-
3224
- *metal_backend = (struct ggml_backend) {
3225
- /* .guid = */ ggml_backend_metal_guid(),
3226
- /* .interface = */ ggml_backend_metal_i,
3227
- /* .context = */ ctx,
3228
- };
3229
-
3230
- return metal_backend;
3231
- }
3232
-
3233
- bool ggml_backend_is_metal(ggml_backend_t backend) {
3234
- return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
3235
- }
3236
-
3237
- void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
3238
- GGML_ASSERT(ggml_backend_is_metal(backend));
3239
-
3240
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
3241
-
3242
- ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
3243
- }
3244
-
3245
- bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
3246
- GGML_ASSERT(ggml_backend_is_metal(backend));
3247
-
3248
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
3249
-
3250
- return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
3251
- }
3252
-
3253
- void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
3254
- GGML_ASSERT(ggml_backend_is_metal(backend));
3255
-
3256
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
3257
- ctx->should_capture_next_compute = true;
3258
- }
3259
-
3260
- GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
3261
-
3262
- GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
3263
- return ggml_backend_metal_init();
3264
-
3265
- GGML_UNUSED(params);
3266
- GGML_UNUSED(user_data);
3267
- }