llama_cpp 0.16.2 → 0.17.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +18 -0
  3. data/README.md +7 -12
  4. data/ext/llama_cpp/extconf.rb +2 -43
  5. data/ext/llama_cpp/llama_cpp.cpp +8 -0
  6. data/lib/llama_cpp/version.rb +3 -3
  7. data/sig/llama_cpp.rbs +3 -0
  8. metadata +2 -171
  9. data/vendor/include/.gitkeep +0 -0
  10. data/vendor/lib/.gitkeep +0 -0
  11. data/vendor/tmp/llama.cpp/LICENSE +0 -21
  12. data/vendor/tmp/llama.cpp/Makefile +0 -1124
  13. data/vendor/tmp/llama.cpp/ggml-alloc.c +0 -1041
  14. data/vendor/tmp/llama.cpp/ggml-alloc.h +0 -76
  15. data/vendor/tmp/llama.cpp/ggml-backend-impl.h +0 -153
  16. data/vendor/tmp/llama.cpp/ggml-backend.c +0 -2225
  17. data/vendor/tmp/llama.cpp/ggml-backend.h +0 -236
  18. data/vendor/tmp/llama.cpp/ggml-blas.cpp +0 -363
  19. data/vendor/tmp/llama.cpp/ggml-blas.h +0 -23
  20. data/vendor/tmp/llama.cpp/ggml-common.h +0 -1805
  21. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +0 -47
  22. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +0 -34
  23. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +0 -104
  24. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +0 -280
  25. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +0 -34
  26. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +0 -196
  27. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +0 -686
  28. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +0 -490
  29. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +0 -40
  30. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +0 -674
  31. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +0 -319
  32. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +0 -312
  33. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +0 -345
  34. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +0 -178
  35. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +0 -104
  36. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +0 -88
  37. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +0 -419
  38. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +0 -221
  39. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +0 -49
  40. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +0 -94
  41. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +0 -112
  42. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +0 -271
  43. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +0 -31
  44. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +0 -206
  45. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +0 -40
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  127. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  128. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  129. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  130. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  131. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  132. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +0 -10
  133. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +0 -9
  134. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +0 -10
  135. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +0 -10
  136. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +0 -8
  137. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q2_k.cu +0 -5
  138. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q3_k.cu +0 -5
  139. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_0.cu +0 -5
  140. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_1.cu +0 -5
  141. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_k.cu +0 -5
  142. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_0.cu +0 -5
  143. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_1.cu +0 -5
  144. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_k.cu +0 -5
  145. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q6_k.cu +0 -5
  146. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q8_0.cu +0 -5
  147. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +0 -47
  148. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +0 -314
  149. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +0 -51
  150. data/vendor/tmp/llama.cpp/ggml-cuda.cu +0 -3069
  151. data/vendor/tmp/llama.cpp/ggml-cuda.h +0 -44
  152. data/vendor/tmp/llama.cpp/ggml-impl.h +0 -651
  153. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +0 -2038
  154. data/vendor/tmp/llama.cpp/ggml-kompute.h +0 -46
  155. data/vendor/tmp/llama.cpp/ggml-metal.h +0 -66
  156. data/vendor/tmp/llama.cpp/ggml-metal.m +0 -3273
  157. data/vendor/tmp/llama.cpp/ggml-metal.metal +0 -6540
  158. data/vendor/tmp/llama.cpp/ggml-quants.c +0 -14994
  159. data/vendor/tmp/llama.cpp/ggml-quants.h +0 -133
  160. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +0 -1178
  161. data/vendor/tmp/llama.cpp/ggml-rpc.h +0 -24
  162. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +0 -6351
  163. data/vendor/tmp/llama.cpp/ggml-sycl.h +0 -40
  164. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +0 -144508
  165. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +0 -7183
  166. data/vendor/tmp/llama.cpp/ggml-vulkan.h +0 -29
  167. data/vendor/tmp/llama.cpp/ggml.c +0 -22506
  168. data/vendor/tmp/llama.cpp/ggml.h +0 -2458
  169. data/vendor/tmp/llama.cpp/llama.cpp +0 -18985
  170. data/vendor/tmp/llama.cpp/llama.h +0 -1147
  171. data/vendor/tmp/llama.cpp/scripts/get-flags.mk +0 -38
  172. data/vendor/tmp/llama.cpp/sgemm.cpp +0 -1032
  173. data/vendor/tmp/llama.cpp/sgemm.h +0 -14
  174. data/vendor/tmp/llama.cpp/unicode-data.cpp +0 -7033
  175. data/vendor/tmp/llama.cpp/unicode-data.h +0 -20
  176. data/vendor/tmp/llama.cpp/unicode.cpp +0 -810
  177. data/vendor/tmp/llama.cpp/unicode.h +0 -63
@@ -1,3273 +0,0 @@
1
- #import "ggml-metal.h"
2
-
3
- #import "ggml-backend-impl.h"
4
- #import "ggml.h"
5
-
6
- #import <Foundation/Foundation.h>
7
-
8
- #import <Metal/Metal.h>
9
-
10
- #undef MIN
11
- #undef MAX
12
- #define MIN(a, b) ((a) < (b) ? (a) : (b))
13
- #define MAX(a, b) ((a) > (b) ? (a) : (b))
14
-
15
- #ifdef GGML_METAL_NDEBUG
16
- #define GGML_METAL_LOG_INFO(...)
17
- #define GGML_METAL_LOG_WARN(...)
18
- #define GGML_METAL_LOG_ERROR(...)
19
- #else
20
- #define GGML_METAL_LOG_INFO(...) ggml_metal_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
21
- #define GGML_METAL_LOG_WARN(...) ggml_metal_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
22
- #define GGML_METAL_LOG_ERROR(...) ggml_metal_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
23
- #endif
24
-
25
- #define UNUSED(x) (void)(x)
26
-
27
- struct ggml_metal_kernel {
28
- id<MTLComputePipelineState> pipeline;
29
- };
30
-
31
- enum ggml_metal_kernel_type {
32
- GGML_METAL_KERNEL_TYPE_ADD,
33
- GGML_METAL_KERNEL_TYPE_ADD_ROW,
34
- GGML_METAL_KERNEL_TYPE_MUL,
35
- GGML_METAL_KERNEL_TYPE_MUL_ROW,
36
- GGML_METAL_KERNEL_TYPE_DIV,
37
- GGML_METAL_KERNEL_TYPE_DIV_ROW,
38
- GGML_METAL_KERNEL_TYPE_REPEAT_F32,
39
- GGML_METAL_KERNEL_TYPE_REPEAT_F16,
40
- GGML_METAL_KERNEL_TYPE_REPEAT_I32,
41
- GGML_METAL_KERNEL_TYPE_REPEAT_I16,
42
- GGML_METAL_KERNEL_TYPE_SCALE,
43
- GGML_METAL_KERNEL_TYPE_SCALE_4,
44
- GGML_METAL_KERNEL_TYPE_CLAMP,
45
- GGML_METAL_KERNEL_TYPE_TANH,
46
- GGML_METAL_KERNEL_TYPE_RELU,
47
- GGML_METAL_KERNEL_TYPE_SIGMOID,
48
- GGML_METAL_KERNEL_TYPE_GELU,
49
- GGML_METAL_KERNEL_TYPE_GELU_4,
50
- GGML_METAL_KERNEL_TYPE_GELU_QUICK,
51
- GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
52
- GGML_METAL_KERNEL_TYPE_SILU,
53
- GGML_METAL_KERNEL_TYPE_SILU_4,
54
- GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
55
- GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
56
- GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
57
- GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
58
- GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
59
- GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
60
- GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
61
- GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
62
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
63
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
64
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
65
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
66
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
67
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
68
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
69
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
70
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,
71
- GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
72
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
73
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
74
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
75
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
76
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
77
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
78
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
79
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
80
- GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
81
- GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
82
- GGML_METAL_KERNEL_TYPE_RMS_NORM,
83
- GGML_METAL_KERNEL_TYPE_GROUP_NORM,
84
- GGML_METAL_KERNEL_TYPE_NORM,
85
- GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
86
- GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
87
- GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
88
- GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
89
- GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
90
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
91
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
92
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
93
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
94
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
95
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
96
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
97
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
98
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,
99
- GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
100
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
101
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
102
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
103
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
104
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
105
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
106
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,
107
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
108
- GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
109
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
110
- //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
111
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
112
- //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
113
- //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
114
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
115
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
116
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
117
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
118
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
119
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
120
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
121
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
122
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,
123
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
124
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
125
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
126
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
127
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
128
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
129
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
130
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
131
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
132
- GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
133
- GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
134
- GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
135
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
136
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
137
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
138
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
139
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
140
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
141
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
142
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
143
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,
144
- GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
145
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
146
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
147
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
148
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
149
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
150
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
151
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
152
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
153
- GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
154
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
155
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
156
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
157
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
158
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
159
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
160
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
161
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
162
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
163
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
164
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
165
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
166
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
167
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
168
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
169
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
170
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
171
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
172
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
173
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
174
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
175
- GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
176
- GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
177
- GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
178
- GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
179
- GGML_METAL_KERNEL_TYPE_IM2COL_F16,
180
- GGML_METAL_KERNEL_TYPE_IM2COL_F32,
181
- GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
182
- GGML_METAL_KERNEL_TYPE_PAD_F32,
183
- GGML_METAL_KERNEL_TYPE_ARANGE_F32,
184
- GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
185
- GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
186
- GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
187
- GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
188
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
189
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
190
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
191
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
192
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
193
- //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
194
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
195
- //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
196
- GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
197
- GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
198
- GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
199
- GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
200
- GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
201
- GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
202
- GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
203
- GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
204
- GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
205
- GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
206
- GGML_METAL_KERNEL_TYPE_CONCAT,
207
- GGML_METAL_KERNEL_TYPE_SQR,
208
- GGML_METAL_KERNEL_TYPE_SUM_ROWS,
209
-
210
- GGML_METAL_KERNEL_TYPE_COUNT
211
- };
212
-
213
- struct ggml_metal_context {
214
- int n_cb;
215
-
216
- id<MTLDevice> device;
217
- id<MTLCommandQueue> queue;
218
-
219
- dispatch_queue_t d_queue;
220
-
221
- struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
222
-
223
- bool support_simdgroup_reduction;
224
- bool support_simdgroup_mm;
225
-
226
- bool should_capture_next_compute;
227
- };
228
-
229
- // MSL code
230
- // TODO: move the contents here when ready
231
- // for now it is easier to work in a separate file
232
- // static NSString * const msl_library_source = @"see metal.metal";
233
-
234
- // Here to assist with NSBundle Path Hack
235
- @interface GGMLMetalClass : NSObject
236
- @end
237
- @implementation GGMLMetalClass
238
- @end
239
-
240
- static void ggml_metal_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
241
- fprintf(stderr, "%s", msg);
242
-
243
- UNUSED(level);
244
- UNUSED(user_data);
245
- }
246
-
247
- ggml_log_callback ggml_metal_log_callback = ggml_metal_default_log_callback;
248
- void * ggml_metal_log_user_data = NULL;
249
-
250
- GGML_ATTRIBUTE_FORMAT(2, 3)
251
- static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
252
- if (ggml_metal_log_callback != NULL) {
253
- va_list args;
254
- va_start(args, format);
255
- char buffer[128];
256
- int len = vsnprintf(buffer, 128, format, args);
257
- if (len < 128) {
258
- ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
259
- } else {
260
- char* buffer2 = malloc(len+1);
261
- va_end(args);
262
- va_start(args, format);
263
- vsnprintf(buffer2, len+1, format, args);
264
- buffer2[len] = 0;
265
- ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
266
- free(buffer2);
267
- }
268
- va_end(args);
269
- }
270
- }
271
-
272
- static void * ggml_metal_host_malloc(size_t n) {
273
- void * data = NULL;
274
-
275
- #if TARGET_OS_OSX
276
- kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE);
277
- if (err != KERN_SUCCESS) {
278
- GGML_METAL_LOG_ERROR("%s: error: vm_allocate failed\n", __func__);
279
- return NULL;
280
- }
281
- #else
282
- const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
283
- if (result != 0) {
284
- GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
285
- return NULL;
286
- }
287
- #endif
288
-
289
- return data;
290
- }
291
-
292
- static struct ggml_metal_context * ggml_metal_init(int n_cb) {
293
- GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
294
-
295
- #if TARGET_OS_OSX && !GGML_METAL_NDEBUG
296
- // Show all the Metal device instances in the system
297
- NSArray * devices = MTLCopyAllDevices();
298
- for (id<MTLDevice> device in devices) {
299
- GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
300
- }
301
- [devices release]; // since it was created by a *Copy* C method
302
- #endif
303
-
304
- // Pick and show default Metal device
305
- id<MTLDevice> device = MTLCreateSystemDefaultDevice();
306
- GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
307
-
308
- // Configure context
309
- struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
310
- ctx->device = device;
311
- ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
312
- ctx->queue = [ctx->device newCommandQueue];
313
- ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
314
-
315
- id<MTLLibrary> metal_library;
316
-
317
- // load library
318
- //
319
- // - first check if the library is embedded
320
- // - then check if the library is in the bundle
321
- // - if not found, load the source and compile it
322
- // - if that fails, return NULL
323
- {
324
- NSBundle * bundle = nil;
325
- #ifdef SWIFT_PACKAGE
326
- bundle = SWIFTPM_MODULE_BUNDLE;
327
- #else
328
- bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
329
- #endif
330
-
331
- NSError * error = nil;
332
-
333
- #if GGML_METAL_EMBED_LIBRARY
334
- const bool try_metallib = false;
335
- #else
336
- const bool try_metallib = true;
337
- #endif
338
-
339
- NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
340
- if (try_metallib && path_lib != nil) {
341
- // pre-compiled library found
342
- NSURL * libURL = [NSURL fileURLWithPath:path_lib];
343
- GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
344
-
345
- metal_library = [ctx->device newLibraryWithURL:libURL error:&error];
346
- if (error) {
347
- GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
348
- return NULL;
349
- }
350
- } else {
351
- #if GGML_METAL_EMBED_LIBRARY
352
- GGML_METAL_LOG_INFO("%s: using embedded metal library\n", __func__);
353
-
354
- extern const char ggml_metallib_start[];
355
- extern const char ggml_metallib_end[];
356
-
357
- NSString * src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
358
- #else
359
- GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
360
-
361
- NSString * path_source;
362
- NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
363
-
364
- GGML_METAL_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
365
-
366
- if (path_resource) {
367
- path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
368
- } else {
369
- path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
370
- }
371
-
372
- if (path_source == nil) {
373
- GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
374
- path_source = @"ggml-metal.metal";
375
- }
376
-
377
- GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
378
-
379
- NSString * src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
380
- if (error) {
381
- GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
382
- return NULL;
383
- }
384
- #endif // GGML_METAL_EMBED_LIBRARY
385
-
386
- @autoreleasepool {
387
- // dictionary of preprocessor macros
388
- NSMutableDictionary * prep = [NSMutableDictionary dictionary];
389
-
390
- MTLCompileOptions* options = [MTLCompileOptions new];
391
- options.preprocessorMacros = prep;
392
-
393
- //[options setFastMathEnabled:false];
394
-
395
- metal_library = [ctx->device newLibraryWithSource:src options:options error:&error];
396
- if (error) {
397
- GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
398
- return NULL;
399
- }
400
- }
401
- }
402
- }
403
-
404
- // print MTL GPU family:
405
- GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
406
-
407
- const NSInteger MTLGPUFamilyMetal3 = 5001;
408
-
409
- // determine max supported GPU family
410
- // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
411
- // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
412
- {
413
- for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
414
- if ([ctx->device supportsFamily:i]) {
415
- GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
416
- break;
417
- }
418
- }
419
-
420
- for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
421
- if ([ctx->device supportsFamily:i]) {
422
- GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
423
- break;
424
- }
425
- }
426
-
427
- for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) {
428
- if ([ctx->device supportsFamily:i]) {
429
- GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i);
430
- break;
431
- }
432
- }
433
- }
434
-
435
- ctx->support_simdgroup_reduction = [ctx->device supportsFamily:MTLGPUFamilyApple7];
436
- ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3];
437
-
438
- ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7];
439
-
440
- GGML_METAL_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false");
441
- GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
442
- GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
443
-
444
- ctx->should_capture_next_compute = false;
445
-
446
- #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
447
- if (@available(macOS 10.12, iOS 16.0, *)) {
448
- GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
449
- }
450
- #elif TARGET_OS_OSX
451
- if (ctx->device.maxTransferRate != 0) {
452
- GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
453
- } else {
454
- GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
455
- }
456
- #endif
457
-
458
- // load kernels
459
- {
460
- NSError * error = nil;
461
-
462
- for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
463
- ctx->kernels[i].pipeline = nil;
464
- }
465
-
466
- /*
467
- GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
468
- (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
469
- (int) kernel->pipeline.threadExecutionWidth); \
470
- */
471
- #define GGML_METAL_ADD_KERNEL(e, name, supported) \
472
- if (supported) { \
473
- struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
474
- id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
475
- kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
476
- [metal_function release]; \
477
- if (error) { \
478
- GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
479
- [metal_library release]; \
480
- return NULL; \
481
- } \
482
- } else { \
483
- GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
484
- }
485
-
486
- // simd_sum and simd_max requires MTLGPUFamilyApple7
487
-
488
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
489
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
490
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
491
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
492
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
493
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
494
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
495
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
496
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
497
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
498
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
499
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
500
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
501
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
502
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
503
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
504
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
505
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
506
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
507
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
508
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
509
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
510
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
511
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
512
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
513
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction);
514
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
515
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
516
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
517
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
518
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
519
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
520
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
521
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
522
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
523
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
524
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
525
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
526
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
527
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
528
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
529
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
530
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
531
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
532
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
533
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
534
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
535
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
536
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
537
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
538
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
539
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
540
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
541
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
542
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
543
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
544
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
545
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
546
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
547
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
548
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
549
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
550
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
551
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
552
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
553
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
554
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
555
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
556
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
557
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
558
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
559
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
560
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
561
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
562
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
563
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
564
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
565
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
566
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
567
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
568
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
569
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
570
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
571
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
572
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
573
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
574
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
575
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
576
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
577
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
578
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
579
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
580
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
581
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
582
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
583
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
584
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
585
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
586
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
587
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
588
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
589
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
590
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
591
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
592
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
593
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
594
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
595
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
596
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
597
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
598
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
599
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
600
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
601
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
602
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
603
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
604
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
605
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
606
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
607
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
608
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
609
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
610
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
611
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
612
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
613
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
614
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
615
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
616
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
617
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
618
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
619
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
620
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
621
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
622
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
623
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
624
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
625
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
626
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
627
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
628
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
629
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
630
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
631
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
632
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
633
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
634
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
635
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
636
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
637
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
638
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
639
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
640
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
641
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
642
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
643
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
644
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
645
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
646
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
647
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
648
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
649
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
650
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
651
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
652
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
653
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
654
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
655
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
656
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
657
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
658
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
659
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
660
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
661
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
662
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
663
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
664
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
665
- }
666
-
667
- [metal_library release];
668
- return ctx;
669
- }
670
-
671
// Destroys a Metal context.
//
// Releases, in this order: every kernel pipeline state, the command queue,
// the device, the dispatch queue, and finally the context allocation itself.
// `ctx` must not be used after this call.
static void ggml_metal_free(struct ggml_metal_context * ctx) {
    GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);

    // walk the whole kernel table; a slot whose pipeline is nil is harmless
    // here because sending a message to nil does nothing in Objective-C
    for (int k = 0; k < GGML_METAL_KERNEL_TYPE_COUNT; ++k) {
        struct ggml_metal_kernel * kernel = &ctx->kernels[k];
        [kernel->pipeline release];
    }

    // Metal objects held by the context
    [ctx->queue  release];
    [ctx->device release];

    // GCD queue held by the context
    dispatch_release(ctx->d_queue);

    free(ctx);
}
685
-
686
- // temporarily defined here for compatibility between ggml-backend and the old API
687
-
688
// One host-memory range together with the Metal buffer associated with it.
struct ggml_backend_metal_buffer {
    void * data; // host address where the range begins
    size_t size; // extent of the range, compared against tensor byte sizes

    id<MTLBuffer> metal; // Metal buffer object handed to the encoder for this range
};
694
-
695
// Per-backend-buffer state: the overall allocation plus the list of Metal
// buffers that cover it.
struct ggml_backend_metal_buffer_context {
    void * all_data; // presumably the base of the whole host allocation - confirm at the allocation site
    size_t all_size; // presumably the total size of that allocation - confirm at the allocation site
    bool   owned;    // NOTE(review): looks like an ownership flag for all_data - verify where it is freed

    // more than one buffer is used only to work around the maximum buffer
    // size limitation when the data comes from mmap
    int n_buffers; // number of valid entries in `buffers`
    struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
};
704
-
705
// Looks up the Metal buffer that holds the data of tensor `t` on the GPU device.
//
// Host and device buffers are assumed to map 1-to-1, so the lookup is done
// purely on the host pointer: the first registered range that contains the
// tensor's bytes completely is selected.
//
// t    - tensor to locate; when it is a view, the buffer of the viewed tensor is searched
// offs - out: byte offset of the tensor data from the start of the selected range;
//        left unmodified when nothing matches
//
// Returns the matching MTLBuffer, or nil (after logging an error) if no range
// contains the tensor.
static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) {
    const int64_t nbytes = ggml_nbytes(t);

    ggml_backend_buffer_t backing = t->view_src ? t->view_src->buffer : t->buffer;

    struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) backing->context;

    for (int i = 0; i < buf_ctx->n_buffers; ++i) {
        const struct ggml_backend_metal_buffer * range = &buf_ctx->buffers[i];

        // signed distance from the start of this range to the tensor data
        const int64_t rel = (int64_t) t->data - (int64_t) range->data;

        // skip ranges that begin after the tensor or end before the tensor does
        if (rel < 0 || rel + nbytes > (int64_t) range->size) {
            continue;
        }

        *offs = (size_t) rel;

        return range->metal;
    }

    GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);

    return nil;
}
736
-
737
// Decides whether the Metal backend can execute the graph node `op`.
//
// ctx - provides the device capability flags support_simdgroup_reduction and
//       support_simdgroup_mm
// op  - node to examine; its op code, result type and source tensors are inspected
//
// Returns true if the node is supported, false otherwise. Op codes that are
// not listed below are reported as unsupported.
static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
    // a BF16 tensor among the first three sources rules the node out, whatever the op is
    for (size_t s = 0; s < 3; ++s) {
        const struct ggml_tensor * src = op->src[s];
        if (src != NULL && src->type == GGML_TYPE_BF16) {
            return false;
        }
    }

    switch (op->op) {
        // no constraints beyond the BF16 check above
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
        case GGML_OP_TRANSPOSE:
        case GGML_OP_PERMUTE:
        case GGML_OP_CONCAT:
        case GGML_OP_ADD:
        case GGML_OP_ACC:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_REPEAT:
        case GGML_OP_SCALE:
        case GGML_OP_CLAMP:
        case GGML_OP_SQR:
        case GGML_OP_SUM_ROWS:
        case GGML_OP_NORM:
        case GGML_OP_ROPE:
        case GGML_OP_IM2COL:
        case GGML_OP_UPSCALE:
        case GGML_OP_PAD:
        case GGML_OP_ARANGE:
        case GGML_OP_TIMESTEP_EMBEDDING:
        case GGML_OP_ARGSORT:
        case GGML_OP_LEAKY_RELU:
            return true;

        // explicitly reported as unsupported
        case GGML_OP_POOL_1D:
        case GGML_OP_POOL_2D:
            return false;

        // the listed unary ops are accepted only for a contiguous input
        case GGML_OP_UNARY:
            switch (ggml_get_unary_op(op)) {
                case GGML_UNARY_OP_TANH:
                case GGML_UNARY_OP_RELU:
                case GGML_UNARY_OP_SIGMOID:
                case GGML_UNARY_OP_GELU:
                case GGML_UNARY_OP_GELU_QUICK:
                case GGML_UNARY_OP_SILU:
                    return ggml_is_contiguous(op->src[0]);
                default:
                    return false;
            }

        // gated on the simdgroup reduction capability
        case GGML_OP_SOFT_MAX:
        case GGML_OP_RMS_NORM:
        case GGML_OP_GROUP_NORM:
            return ctx->support_simdgroup_reduction;

        case GGML_OP_FLASH_ATTN_EXT:
            {
                // src[1] and src[2] must both be F16
                if (op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16) {
                    return false;
                }
                // a first dimension of 256 on src[0] is rejected
                if (op->src[0]->ne[0] == 256) {
                    return false;
                }
                return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
            }

        case GGML_OP_MUL_MAT:
        case GGML_OP_MUL_MAT_ID:
            {
                if (!ctx->support_simdgroup_reduction) {
                    return false;
                }
                // an F32 src[0] is accepted only together with an F32 src[1]
                return op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32;
            }

        // copies: supported for a fixed set of (source type -> result type) pairs
        case GGML_OP_CPY:
        case GGML_OP_DUP:
        case GGML_OP_CONT:
            {
                const enum ggml_type from = op->src[0]->type;
                const enum ggml_type to   = op->type;

                if (from == GGML_TYPE_F16) {
                    return to == GGML_TYPE_F16 || to == GGML_TYPE_F32;
                }

                if (from != GGML_TYPE_F32) {
                    return false;
                }

                switch (to) {
                    case GGML_TYPE_F16:
                    case GGML_TYPE_F32:
                    case GGML_TYPE_Q8_0:
                    case GGML_TYPE_Q4_0:
                    case GGML_TYPE_Q4_1:
                    case GGML_TYPE_Q5_0:
                    case GGML_TYPE_Q5_1:
                    case GGML_TYPE_IQ4_NL:
                        return true;
                    default:
                        return false;
                }
            }

        // require a result whose fourth dimension is 1
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_GET_ROWS:
            return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1;

        default:
            return false;
    }
}
846
-
847
- static enum ggml_status ggml_metal_graph_compute(
848
- struct ggml_metal_context * ctx,
849
- struct ggml_cgraph * gf) {
850
-
851
- @autoreleasepool {
852
- MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
853
- edesc.dispatchType = MTLDispatchTypeSerial;
854
-
855
- // create multiple command buffers and enqueue them
856
- // then, we encode the graph into the command buffers in parallel
857
-
858
- const int n_nodes = gf->n_nodes;
859
- const int n_cb = ctx->n_cb;
860
- const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
861
-
862
- const bool should_capture = ctx->should_capture_next_compute;
863
- if (should_capture) {
864
- ctx->should_capture_next_compute = false;
865
-
866
- MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
867
- descriptor.captureObject = ctx->queue;
868
-
869
- NSError * error = nil;
870
- if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
871
- GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
872
- GGML_ASSERT(!"capture failed");
873
- }
874
- }
875
-
876
- id<MTLCommandBuffer> command_buffer_builder[n_cb];
877
- for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
878
- id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
879
- command_buffer_builder[cb_idx] = command_buffer;
880
-
881
- // enqueue the command buffers in order to specify their execution order
882
- [command_buffer enqueue];
883
- }
884
-
885
- const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
886
-
887
- dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
888
- const int cb_idx = iter;
889
-
890
- size_t offs_src0 = 0;
891
- size_t offs_src1 = 0;
892
- size_t offs_src2 = 0;
893
- size_t offs_dst = 0;
894
-
895
- id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
896
- id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
897
-
898
- const int node_start = (cb_idx + 0) * n_nodes_per_cb;
899
- const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
900
-
901
- for (int i = node_start; i < node_end; ++i) {
902
- if (i == -1) {
903
- [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
904
- continue;
905
- }
906
-
907
- //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
908
-
909
- struct ggml_tensor * src0 = gf->nodes[i]->src[0];
910
- struct ggml_tensor * src1 = gf->nodes[i]->src[1];
911
- struct ggml_tensor * src2 = gf->nodes[i]->src[2];
912
- struct ggml_tensor * dst = gf->nodes[i];
913
-
914
- if (ggml_is_empty(dst)) {
915
- continue;
916
- }
917
-
918
- switch (dst->op) {
919
- case GGML_OP_NONE:
920
- case GGML_OP_RESHAPE:
921
- case GGML_OP_VIEW:
922
- case GGML_OP_TRANSPOSE:
923
- case GGML_OP_PERMUTE:
924
- {
925
- // noop -> next node
926
- } continue;
927
- default:
928
- {
929
- } break;
930
- }
931
-
932
- if (!ggml_metal_supports_op(ctx, dst)) {
933
- GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
934
- GGML_ASSERT(!"unsupported op");
935
- }
936
-
937
- if (should_capture) {
938
- [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
939
- }
940
-
941
- const int64_t ne00 = src0 ? src0->ne[0] : 0;
942
- const int64_t ne01 = src0 ? src0->ne[1] : 0;
943
- const int64_t ne02 = src0 ? src0->ne[2] : 0;
944
- const int64_t ne03 = src0 ? src0->ne[3] : 0;
945
-
946
- const uint64_t nb00 = src0 ? src0->nb[0] : 0;
947
- const uint64_t nb01 = src0 ? src0->nb[1] : 0;
948
- const uint64_t nb02 = src0 ? src0->nb[2] : 0;
949
- const uint64_t nb03 = src0 ? src0->nb[3] : 0;
950
-
951
- const int64_t ne10 = src1 ? src1->ne[0] : 0;
952
- const int64_t ne11 = src1 ? src1->ne[1] : 0;
953
- const int64_t ne12 = src1 ? src1->ne[2] : 0;
954
- const int64_t ne13 = src1 ? src1->ne[3] : 0;
955
-
956
- const uint64_t nb10 = src1 ? src1->nb[0] : 0;
957
- const uint64_t nb11 = src1 ? src1->nb[1] : 0;
958
- const uint64_t nb12 = src1 ? src1->nb[2] : 0;
959
- const uint64_t nb13 = src1 ? src1->nb[3] : 0;
960
-
961
- const int64_t ne20 = src2 ? src2->ne[0] : 0;
962
- const int64_t ne21 = src2 ? src2->ne[1] : 0;
963
- const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
964
- const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
965
-
966
- const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
967
- const uint64_t nb21 = src2 ? src2->nb[1] : 0;
968
- const uint64_t nb22 = src2 ? src2->nb[2] : 0;
969
- const uint64_t nb23 = src2 ? src2->nb[3] : 0;
970
-
971
- const int64_t ne0 = dst ? dst->ne[0] : 0;
972
- const int64_t ne1 = dst ? dst->ne[1] : 0;
973
- const int64_t ne2 = dst ? dst->ne[2] : 0;
974
- const int64_t ne3 = dst ? dst->ne[3] : 0;
975
-
976
- const uint64_t nb0 = dst ? dst->nb[0] : 0;
977
- const uint64_t nb1 = dst ? dst->nb[1] : 0;
978
- const uint64_t nb2 = dst ? dst->nb[2] : 0;
979
- const uint64_t nb3 = dst ? dst->nb[3] : 0;
980
-
981
- const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
982
- const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
983
- const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
984
-
985
- id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
986
- id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
987
- id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
988
- id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
989
-
990
- //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
991
- //if (src0) {
992
- // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
993
- // ggml_is_contiguous(src0), src0->name);
994
- //}
995
- //if (src1) {
996
- // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
997
- // ggml_is_contiguous(src1), src1->name);
998
- //}
999
- //if (dst) {
1000
- // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
1001
- // dst->name);
1002
- //}
1003
-
1004
- switch (dst->op) {
1005
- case GGML_OP_CONCAT:
1006
- {
1007
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
1008
-
1009
- const int32_t dim = ((int32_t *) dst->op_params)[0];
1010
-
1011
- [encoder setComputePipelineState:pipeline];
1012
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1013
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1014
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1015
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1016
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1017
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1018
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1019
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1020
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
1021
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
1022
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
1023
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1024
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1025
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1026
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1027
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1028
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1029
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1030
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1031
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1032
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1033
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1034
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1035
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1036
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1037
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1038
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1039
- [encoder setBytes:&dim length:sizeof(dim) atIndex:27];
1040
-
1041
- const int nth = MIN(1024, ne0);
1042
-
1043
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1044
- } break;
1045
- case GGML_OP_ADD:
1046
- case GGML_OP_MUL:
1047
- case GGML_OP_DIV:
1048
- {
1049
- GGML_ASSERT(src0t == GGML_TYPE_F32);
1050
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1051
-
1052
- const size_t offs = 0;
1053
-
1054
- bool bcast_row = false;
1055
-
1056
- int64_t nb = ne00; // used by the "row" kernels
1057
-
1058
- id<MTLComputePipelineState> pipeline = nil;
1059
-
1060
- if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
1061
- GGML_ASSERT(ggml_is_contiguous(src0));
1062
-
1063
- // src1 is a row
1064
- GGML_ASSERT(ne11 == 1);
1065
-
1066
- nb = ne00 / 4;
1067
- switch (dst->op) {
1068
- case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
1069
- case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
1070
- case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
1071
- default: GGML_ASSERT(false);
1072
- }
1073
-
1074
- bcast_row = true;
1075
- } else {
1076
- switch (dst->op) {
1077
- case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
1078
- case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
1079
- case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
1080
- default: GGML_ASSERT(false);
1081
- }
1082
- }
1083
-
1084
- [encoder setComputePipelineState:pipeline];
1085
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1086
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1087
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1088
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1089
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1090
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1091
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1092
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1093
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
1094
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
1095
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
1096
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1097
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1098
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1099
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1100
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1101
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1102
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1103
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1104
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1105
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1106
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1107
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1108
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1109
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1110
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1111
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1112
- [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1113
- [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
1114
-
1115
- if (bcast_row) {
1116
- const int64_t n = ggml_nelements(dst)/4;
1117
-
1118
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1119
- } else {
1120
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1121
-
1122
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1123
- }
1124
- } break;
1125
- case GGML_OP_REPEAT:
1126
- {
1127
- id<MTLComputePipelineState> pipeline;
1128
-
1129
- switch (src0t) {
1130
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
1131
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
1132
- case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
1133
- case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
1134
- default: GGML_ASSERT(false);
1135
- }
1136
-
1137
- [encoder setComputePipelineState:pipeline];
1138
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1139
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1140
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1141
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1142
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1143
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1144
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1145
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1146
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1147
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1148
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
1149
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
1150
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
1151
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
1152
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
1153
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1154
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
1155
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
1156
-
1157
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1158
-
1159
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1160
- } break;
1161
- case GGML_OP_ACC:
1162
- {
1163
- GGML_ASSERT(src0t == GGML_TYPE_F32);
1164
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1165
- GGML_ASSERT(dstt == GGML_TYPE_F32);
1166
-
1167
- GGML_ASSERT(ggml_is_contiguous(src0));
1168
- GGML_ASSERT(ggml_is_contiguous(src1));
1169
-
1170
- const size_t pnb1 = ((int32_t *) dst->op_params)[0];
1171
- const size_t pnb2 = ((int32_t *) dst->op_params)[1];
1172
- const size_t pnb3 = ((int32_t *) dst->op_params)[2];
1173
- const size_t offs = ((int32_t *) dst->op_params)[3];
1174
-
1175
- const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
1176
-
1177
- if (!inplace) {
1178
- // run a separate kernel to cpy src->dst
1179
- // not sure how to avoid this
1180
- // TODO: make a simpler cpy_bytes kernel
1181
-
1182
- const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
1183
-
1184
- [encoder setComputePipelineState:pipeline];
1185
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1186
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1187
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1188
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1189
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1190
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1191
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1192
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1193
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1194
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1195
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1196
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1197
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1198
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1199
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1200
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1201
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1202
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1203
-
1204
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
1205
-
1206
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1207
- }
1208
-
1209
- const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
1210
-
1211
- [encoder setComputePipelineState:pipeline];
1212
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1213
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1214
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1215
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1216
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1217
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1218
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1219
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1220
- [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
1221
- [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
1222
- [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
1223
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1224
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1225
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1226
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1227
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1228
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1229
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1230
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1231
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1232
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1233
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1234
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1235
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1236
- [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
1237
- [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
1238
- [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
1239
- [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1240
-
1241
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
1242
-
1243
- [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1244
- } break;
1245
- case GGML_OP_SCALE:
1246
- {
1247
- GGML_ASSERT(ggml_is_contiguous(src0));
1248
-
1249
- float scale;
1250
- memcpy(&scale, dst->op_params, sizeof(scale));
1251
-
1252
- int64_t n = ggml_nelements(dst);
1253
-
1254
- id<MTLComputePipelineState> pipeline = nil;
1255
-
1256
- if (n % 4 == 0) {
1257
- n /= 4;
1258
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
1259
- } else {
1260
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
1261
- }
1262
-
1263
- [encoder setComputePipelineState:pipeline];
1264
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1265
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1266
- [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
1267
-
1268
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1269
- } break;
1270
- case GGML_OP_CLAMP:
1271
- {
1272
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
1273
-
1274
- float min;
1275
- float max;
1276
- memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
1277
- memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
1278
-
1279
- [encoder setComputePipelineState:pipeline];
1280
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1281
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1282
- [encoder setBytes:&min length:sizeof(min) atIndex:2];
1283
- [encoder setBytes:&max length:sizeof(max) atIndex:3];
1284
-
1285
- const int64_t n = ggml_nelements(dst);
1286
-
1287
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1288
- } break;
1289
- case GGML_OP_UNARY:
1290
- switch (ggml_get_unary_op(gf->nodes[i])) {
1291
- // we are not taking into account the strides, so for now require contiguous tensors
1292
- GGML_ASSERT(ggml_is_contiguous(src0));
1293
-
1294
- case GGML_UNARY_OP_TANH:
1295
- {
1296
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
1297
-
1298
- [encoder setComputePipelineState:pipeline];
1299
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1300
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1301
-
1302
- const int64_t n = ggml_nelements(dst);
1303
-
1304
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1305
- } break;
1306
- case GGML_UNARY_OP_RELU:
1307
- {
1308
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
1309
-
1310
- [encoder setComputePipelineState:pipeline];
1311
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1312
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1313
-
1314
- const int64_t n = ggml_nelements(dst);
1315
-
1316
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1317
- } break;
1318
- case GGML_UNARY_OP_SIGMOID:
1319
- {
1320
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
1321
-
1322
- [encoder setComputePipelineState:pipeline];
1323
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1324
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1325
-
1326
- const int64_t n = ggml_nelements(dst);
1327
-
1328
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1329
- } break;
1330
- case GGML_UNARY_OP_GELU:
1331
- {
1332
- int64_t n = ggml_nelements(dst);
1333
-
1334
- id<MTLComputePipelineState> pipeline = nil;
1335
-
1336
- if (n % 4 == 0) {
1337
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
1338
- n /= 4;
1339
- } else {
1340
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
1341
- }
1342
-
1343
- [encoder setComputePipelineState:pipeline];
1344
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1345
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1346
-
1347
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1348
- } break;
1349
- case GGML_UNARY_OP_GELU_QUICK:
1350
- {
1351
- int64_t n = ggml_nelements(dst);
1352
-
1353
- id<MTLComputePipelineState> pipeline = nil;
1354
-
1355
- if (n % 4 == 0) {
1356
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
1357
- n /= 4;
1358
- } else {
1359
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
1360
- }
1361
-
1362
- [encoder setComputePipelineState:pipeline];
1363
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1364
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1365
-
1366
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1367
- } break;
1368
- case GGML_UNARY_OP_SILU:
1369
- {
1370
- int64_t n = ggml_nelements(dst);
1371
-
1372
- id<MTLComputePipelineState> pipeline = nil;
1373
-
1374
- if (n % 4 == 0) {
1375
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
1376
- n /= 4;
1377
- } else {
1378
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
1379
- }
1380
-
1381
- [encoder setComputePipelineState:pipeline];
1382
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1383
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1384
-
1385
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1386
- } break;
1387
- default:
1388
- {
1389
- GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1390
- GGML_ASSERT(false);
1391
- }
1392
- } break;
1393
- case GGML_OP_SQR:
1394
- {
1395
- GGML_ASSERT(ggml_is_contiguous(src0));
1396
-
1397
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
1398
-
1399
- [encoder setComputePipelineState:pipeline];
1400
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1401
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1402
-
1403
- const int64_t n = ggml_nelements(dst);
1404
-
1405
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1406
- } break;
1407
- case GGML_OP_SUM_ROWS:
1408
- {
1409
- GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
1410
-
1411
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
1412
-
1413
- [encoder setComputePipelineState:pipeline];
1414
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1415
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1416
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1417
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1418
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1419
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1420
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1421
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1422
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1423
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1424
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1425
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1426
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1427
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1428
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1429
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1430
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1431
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
1432
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1433
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
1434
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
1435
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
1436
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
1437
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
1438
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
1439
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
1440
-
1441
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1442
- } break;
1443
- case GGML_OP_SOFT_MAX:
1444
- {
1445
- GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
1446
-
1447
- int nth = 32; // SIMD width
1448
-
1449
- id<MTLComputePipelineState> pipeline = nil;
1450
-
1451
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
1452
-
1453
- if (ne00%4 == 0) {
1454
- while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
1455
- nth *= 2;
1456
- }
1457
- if (use_f16) {
1458
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
1459
- } else {
1460
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
1461
- }
1462
- } else {
1463
- while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
1464
- nth *= 2;
1465
- }
1466
- if (use_f16) {
1467
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
1468
- } else {
1469
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
1470
- }
1471
- }
1472
-
1473
- float scale;
1474
- float max_bias;
1475
-
1476
- memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
1477
- memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
1478
-
1479
- const int64_t nrows_x = ggml_nrows(src0);
1480
- const int64_t nrows_y = src0->ne[1];
1481
-
1482
- const uint32_t n_head = nrows_x/nrows_y;
1483
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
1484
-
1485
- const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
1486
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1487
-
1488
- [encoder setComputePipelineState:pipeline];
1489
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1490
- if (id_src1) {
1491
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1492
- } else {
1493
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1494
- }
1495
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1496
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1497
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1498
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1499
- [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1500
- [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
1501
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
1502
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
1503
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
1504
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1505
-
1506
- [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1507
- } break;
1508
- case GGML_OP_DIAG_MASK_INF:
1509
- {
1510
- const int n_past = ((int32_t *)(dst->op_params))[0];
1511
-
1512
- id<MTLComputePipelineState> pipeline = nil;
1513
-
1514
- if (ne00%8 == 0) {
1515
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
1516
- } else {
1517
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
1518
- }
1519
-
1520
- [encoder setComputePipelineState:pipeline];
1521
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1522
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1523
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1524
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1525
- [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
1526
-
1527
- if (ne00%8 == 0) {
1528
- [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1529
- }
1530
- else {
1531
- [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1532
- }
1533
- } break;
1534
- case GGML_OP_MUL_MAT:
1535
- {
1536
- GGML_ASSERT(ne00 == ne10);
1537
-
1538
- GGML_ASSERT(ne12 % ne02 == 0);
1539
- GGML_ASSERT(ne13 % ne03 == 0);
1540
-
1541
- const uint r2 = ne12/ne02;
1542
- const uint r3 = ne13/ne03;
1543
-
1544
- // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1545
- // to the matrix-vector kernel
1546
- int ne11_mm_min = 1;
1547
-
1548
- #if 0
1549
- // the numbers below are measured on M2 Ultra for 7B and 13B models
1550
- // these numbers do not translate to other devices or model sizes
1551
- // TODO: need to find a better approach
1552
- if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
1553
- switch (src0t) {
1554
- case GGML_TYPE_F16: ne11_mm_min = 2; break;
1555
- case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1556
- case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1557
- case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1558
- case GGML_TYPE_Q4_0:
1559
- case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1560
- case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1561
- case GGML_TYPE_Q5_0: // not tested yet
1562
- case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1563
- case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1564
- case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1565
- default: ne11_mm_min = 1; break;
1566
- }
1567
- }
1568
- #endif
1569
-
1570
- // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1571
- // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1572
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1573
- !ggml_is_transposed(src0) &&
1574
- !ggml_is_transposed(src1) &&
1575
- src1t == GGML_TYPE_F32 &&
1576
- ne00 % 32 == 0 && ne00 >= 64 &&
1577
- (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
1578
- //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1579
-
1580
- // some Metal matrix data types require aligned pointers
1581
- // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1582
- switch (src0->type) {
1583
- case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1584
- case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1585
- default: break;
1586
- }
1587
-
1588
- id<MTLComputePipelineState> pipeline = nil;
1589
-
1590
- switch (src0->type) {
1591
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
1592
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
1593
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
1594
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
1595
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
1596
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
1597
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
1598
- case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
1599
- case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
1600
- case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
1601
- case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
1602
- case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
1603
- case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
1604
- case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
1605
- case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
1606
- case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
1607
- case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
1608
- case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
1609
- case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
1610
- case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
1611
- case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
1612
- default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
1613
- }
1614
-
1615
- [encoder setComputePipelineState:pipeline];
1616
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1617
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1618
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1619
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1620
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1621
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
1622
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
1623
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1624
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1625
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1626
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1627
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1628
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1629
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1630
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1631
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1632
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1633
- } else {
1634
- int nth0 = 32;
1635
- int nth1 = 1;
1636
- int nrows = 1;
1637
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1638
-
1639
- id<MTLComputePipelineState> pipeline = nil;
1640
-
1641
- // use custom matrix x vector kernel
1642
- switch (src0t) {
1643
- case GGML_TYPE_F32:
1644
- {
1645
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1646
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
1647
- nrows = 4;
1648
- } break;
1649
- case GGML_TYPE_F16:
1650
- {
1651
- nth0 = 32;
1652
- nth1 = 1;
1653
- if (src1t == GGML_TYPE_F32) {
1654
- if (ne11 * ne12 < 4) {
1655
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
1656
- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1657
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
1658
- nrows = ne11;
1659
- } else {
1660
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
1661
- nrows = 4;
1662
- }
1663
- } else {
1664
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
1665
- nrows = 4;
1666
- }
1667
- } break;
1668
- case GGML_TYPE_Q4_0:
1669
- {
1670
- nth0 = 8;
1671
- nth1 = 8;
1672
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
1673
- } break;
1674
- case GGML_TYPE_Q4_1:
1675
- {
1676
- nth0 = 8;
1677
- nth1 = 8;
1678
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
1679
- } break;
1680
- case GGML_TYPE_Q5_0:
1681
- {
1682
- nth0 = 8;
1683
- nth1 = 8;
1684
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
1685
- } break;
1686
- case GGML_TYPE_Q5_1:
1687
- {
1688
- nth0 = 8;
1689
- nth1 = 8;
1690
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
1691
- } break;
1692
- case GGML_TYPE_Q8_0:
1693
- {
1694
- nth0 = 8;
1695
- nth1 = 8;
1696
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
1697
- } break;
1698
- case GGML_TYPE_Q2_K:
1699
- {
1700
- nth0 = 2;
1701
- nth1 = 32;
1702
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
1703
- } break;
1704
- case GGML_TYPE_Q3_K:
1705
- {
1706
- nth0 = 2;
1707
- nth1 = 32;
1708
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
1709
- } break;
1710
- case GGML_TYPE_Q4_K:
1711
- {
1712
- nth0 = 4; //1;
1713
- nth1 = 8; //32;
1714
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
1715
- } break;
1716
- case GGML_TYPE_Q5_K:
1717
- {
1718
- nth0 = 2;
1719
- nth1 = 32;
1720
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
1721
- } break;
1722
- case GGML_TYPE_Q6_K:
1723
- {
1724
- nth0 = 2;
1725
- nth1 = 32;
1726
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
1727
- } break;
1728
- case GGML_TYPE_IQ2_XXS:
1729
- {
1730
- nth0 = 4;
1731
- nth1 = 16;
1732
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
1733
- } break;
1734
- case GGML_TYPE_IQ2_XS:
1735
- {
1736
- nth0 = 4;
1737
- nth1 = 16;
1738
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
1739
- } break;
1740
- case GGML_TYPE_IQ3_XXS:
1741
- {
1742
- nth0 = 4;
1743
- nth1 = 16;
1744
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
1745
- } break;
1746
- case GGML_TYPE_IQ3_S:
1747
- {
1748
- nth0 = 4;
1749
- nth1 = 16;
1750
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
1751
- } break;
1752
- case GGML_TYPE_IQ2_S:
1753
- {
1754
- nth0 = 4;
1755
- nth1 = 16;
1756
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
1757
- } break;
1758
- case GGML_TYPE_IQ1_S:
1759
- {
1760
- nth0 = 4;
1761
- nth1 = 16;
1762
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
1763
- } break;
1764
- case GGML_TYPE_IQ1_M:
1765
- {
1766
- nth0 = 4;
1767
- nth1 = 16;
1768
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
1769
- } break;
1770
- case GGML_TYPE_IQ4_NL:
1771
- {
1772
- nth0 = 4;
1773
- nth1 = 16;
1774
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
1775
- } break;
1776
- case GGML_TYPE_IQ4_XS:
1777
- {
1778
- nth0 = 4;
1779
- nth1 = 16;
1780
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
1781
- } break;
1782
- default:
1783
- {
1784
- GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1785
- GGML_ASSERT(false && "not implemented");
1786
- }
1787
- };
1788
-
1789
- if (ggml_is_quantized(src0t)) {
1790
- GGML_ASSERT(ne00 >= nth0*nth1);
1791
- }
1792
-
1793
- [encoder setComputePipelineState:pipeline];
1794
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1795
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1796
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1797
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1798
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1799
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1800
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1801
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1802
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1803
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
1804
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
1805
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
1806
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
1807
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
1808
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
1809
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
1810
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1811
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1812
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1813
-
1814
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
1815
- src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
1816
- src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
1817
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1818
- }
1819
- else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
1820
- const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1821
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1822
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1823
- }
1824
- else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
1825
- const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
1826
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1827
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1828
- }
1829
- else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
1830
- const int mem_size = 32*sizeof(float);
1831
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1832
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1833
- }
1834
- else if (src0t == GGML_TYPE_Q4_K) {
1835
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1836
- }
1837
- else if (src0t == GGML_TYPE_Q3_K) {
1838
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1839
- }
1840
- else if (src0t == GGML_TYPE_Q5_K) {
1841
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1842
- }
1843
- else if (src0t == GGML_TYPE_Q6_K) {
1844
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1845
- } else {
1846
- const int64_t ny = (ne11 + nrows - 1)/nrows;
1847
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1848
- }
1849
- }
1850
- } break;
1851
- case GGML_OP_MUL_MAT_ID:
1852
- {
1853
- const int n_as = src0->ne[2];
1854
-
1855
- // src2 = ids
1856
- const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
1857
-
1858
- GGML_ASSERT(src2t == GGML_TYPE_I32);
1859
-
1860
- GGML_ASSERT(!ggml_is_transposed(src0));
1861
- GGML_ASSERT(!ggml_is_transposed(src1));
1862
-
1863
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1864
-
1865
- // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1866
- // to the matrix-vector kernel
1867
- // ne20 = n_used_experts
1868
- // ne21 = n_rows
1869
- const int dst_rows = ne20*ne21;
1870
- const int dst_rows_min = n_as;
1871
- const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4;
1872
-
1873
- // max size of the rowids array in the kernel shared buffer
1874
- GGML_ASSERT(dst_rows <= dst_rows_max);
1875
-
1876
- // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1877
- // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1878
- // !!!
1879
- // TODO: for now, always use mat-vec kernels until we figure out how to improve the
1880
- // indirect matrix multiplication
1881
- // !!!
1882
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1883
- ne00 % 32 == 0 && ne00 >= 64 &&
1884
- dst_rows > dst_rows_min) {
1885
-
1886
- // some Metal matrix data types require aligned pointers
1887
- // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1888
- switch (src0->type) {
1889
- case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1890
- case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1891
- default: break;
1892
- }
1893
-
1894
- id<MTLComputePipelineState> pipeline = nil;
1895
-
1896
- switch (src0->type) {
1897
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
1898
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
1899
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
1900
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
1901
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
1902
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
1903
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
1904
- case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
1905
- case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
1906
- case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
1907
- case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
1908
- case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
1909
- case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
1910
- case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
1911
- case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
1912
- case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
1913
- case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
1914
- case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
1915
- case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
1916
- case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
1917
- case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
1918
- default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1919
- }
1920
-
1921
- [encoder setComputePipelineState:pipeline];
1922
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1923
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1924
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1925
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
1926
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1927
- [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1928
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1929
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
1930
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
1931
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
1932
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
1933
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1934
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1935
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1936
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1937
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1938
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1939
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1940
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
1941
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1942
-
1943
- [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
1944
-
1945
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1946
- } else {
1947
- int nth0 = 32;
1948
- int nth1 = 1;
1949
- int nrows = 1;
1950
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1951
-
1952
- id<MTLComputePipelineState> pipeline = nil;
1953
-
1954
- // use custom matrix x vector kernel
1955
- switch (src0t) {
1956
- case GGML_TYPE_F32:
1957
- {
1958
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1959
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
1960
- } break;
1961
- case GGML_TYPE_F16:
1962
- {
1963
- GGML_ASSERT(src1t == GGML_TYPE_F32);
1964
- nth0 = 32;
1965
- nth1 = 1;
1966
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
1967
- } break;
1968
- case GGML_TYPE_Q4_0:
1969
- {
1970
- nth0 = 8;
1971
- nth1 = 8;
1972
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
1973
- } break;
1974
- case GGML_TYPE_Q4_1:
1975
- {
1976
- nth0 = 8;
1977
- nth1 = 8;
1978
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
1979
- } break;
1980
- case GGML_TYPE_Q5_0:
1981
- {
1982
- nth0 = 8;
1983
- nth1 = 8;
1984
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
1985
- } break;
1986
- case GGML_TYPE_Q5_1:
1987
- {
1988
- nth0 = 8;
1989
- nth1 = 8;
1990
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
1991
- } break;
1992
- case GGML_TYPE_Q8_0:
1993
- {
1994
- nth0 = 8;
1995
- nth1 = 8;
1996
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
1997
- } break;
1998
- case GGML_TYPE_Q2_K:
1999
- {
2000
- nth0 = 2;
2001
- nth1 = 32;
2002
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
2003
- } break;
2004
- case GGML_TYPE_Q3_K:
2005
- {
2006
- nth0 = 2;
2007
- nth1 = 32;
2008
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
2009
- } break;
2010
- case GGML_TYPE_Q4_K:
2011
- {
2012
- nth0 = 4; //1;
2013
- nth1 = 8; //32;
2014
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
2015
- } break;
2016
- case GGML_TYPE_Q5_K:
2017
- {
2018
- nth0 = 2;
2019
- nth1 = 32;
2020
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
2021
- } break;
2022
- case GGML_TYPE_Q6_K:
2023
- {
2024
- nth0 = 2;
2025
- nth1 = 32;
2026
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
2027
- } break;
2028
- case GGML_TYPE_IQ2_XXS:
2029
- {
2030
- nth0 = 4;
2031
- nth1 = 16;
2032
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
2033
- } break;
2034
- case GGML_TYPE_IQ2_XS:
2035
- {
2036
- nth0 = 4;
2037
- nth1 = 16;
2038
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
2039
- } break;
2040
- case GGML_TYPE_IQ3_XXS:
2041
- {
2042
- nth0 = 4;
2043
- nth1 = 16;
2044
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
2045
- } break;
2046
- case GGML_TYPE_IQ3_S:
2047
- {
2048
- nth0 = 4;
2049
- nth1 = 16;
2050
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
2051
- } break;
2052
- case GGML_TYPE_IQ2_S:
2053
- {
2054
- nth0 = 4;
2055
- nth1 = 16;
2056
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
2057
- } break;
2058
- case GGML_TYPE_IQ1_S:
2059
- {
2060
- nth0 = 4;
2061
- nth1 = 16;
2062
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
2063
- } break;
2064
- case GGML_TYPE_IQ1_M:
2065
- {
2066
- nth0 = 4;
2067
- nth1 = 16;
2068
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
2069
- } break;
2070
- case GGML_TYPE_IQ4_NL:
2071
- {
2072
- nth0 = 4;
2073
- nth1 = 16;
2074
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
2075
- } break;
2076
- case GGML_TYPE_IQ4_XS:
2077
- {
2078
- nth0 = 4;
2079
- nth1 = 16;
2080
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
2081
- } break;
2082
- default:
2083
- {
2084
- GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
2085
- GGML_ASSERT(false && "not implemented");
2086
- }
2087
- };
2088
-
2089
- if (ggml_is_quantized(src0t)) {
2090
- GGML_ASSERT(ne00 >= nth0*nth1);
2091
- }
2092
-
2093
- [encoder setComputePipelineState:pipeline];
2094
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2095
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2096
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2097
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
2098
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
2099
- [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
2100
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
2101
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
2102
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
2103
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
2104
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
2105
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
2106
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
2107
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
2108
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
2109
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
2110
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
2111
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
2112
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
2113
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
2114
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
2115
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
2116
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
2117
-
2118
- const int64_t _ne1 = 1;
2119
- const int tgz = dst_rows;
2120
-
2121
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
2122
- src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
2123
- src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
2124
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2125
- }
2126
- else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
2127
- const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
2128
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2129
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2130
- }
2131
- else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
2132
- const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
2133
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2134
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2135
- }
2136
- else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
2137
- const int mem_size = 32*sizeof(float);
2138
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2139
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2140
- }
2141
- else if (src0t == GGML_TYPE_Q4_K) {
2142
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2143
- }
2144
- else if (src0t == GGML_TYPE_Q3_K) {
2145
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2146
- }
2147
- else if (src0t == GGML_TYPE_Q5_K) {
2148
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2149
- }
2150
- else if (src0t == GGML_TYPE_Q6_K) {
2151
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2152
- } else {
2153
- const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
2154
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2155
- }
2156
- }
2157
- } break;
2158
- case GGML_OP_GET_ROWS:
2159
- {
2160
- id<MTLComputePipelineState> pipeline = nil;
2161
-
2162
- switch (src0->type) {
2163
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
2164
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
2165
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
2166
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
2167
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
2168
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
2169
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
2170
- case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
2171
- case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
2172
- case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
2173
- case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break;
2174
- case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
2175
- case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
2176
- case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
2177
- case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
2178
- case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
2179
- case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
2180
- case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
2181
- case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break;
2182
- case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
2183
- case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
2184
- case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
2185
- default: GGML_ASSERT(false && "not implemented");
2186
- }
2187
-
2188
- [encoder setComputePipelineState:pipeline];
2189
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2190
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2191
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2192
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
2193
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
2194
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
2195
- [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
2196
- [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
2197
- [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
2198
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
2199
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
2200
-
2201
- [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
2202
- } break;
2203
- case GGML_OP_RMS_NORM:
2204
- {
2205
- GGML_ASSERT(ne00 % 4 == 0);
2206
- GGML_ASSERT(ggml_is_contiguous_1(src0));
2207
-
2208
- float eps;
2209
- memcpy(&eps, dst->op_params, sizeof(float));
2210
-
2211
- int nth = 32; // SIMD width
2212
-
2213
- while (nth < ne00/4 && nth < 1024) {
2214
- nth *= 2;
2215
- }
2216
-
2217
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
2218
-
2219
- [encoder setComputePipelineState:pipeline];
2220
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2221
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2222
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2223
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
2224
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
2225
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2226
-
2227
- const int64_t nrows = ggml_nrows(src0);
2228
-
2229
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2230
- } break;
2231
- case GGML_OP_GROUP_NORM:
2232
- {
2233
- GGML_ASSERT(ne00 % 4 == 0);
2234
- GGML_ASSERT(ggml_is_contiguous(src0));
2235
-
2236
- //float eps;
2237
- //memcpy(&eps, dst->op_params, sizeof(float));
2238
-
2239
- const float eps = 1e-6f; // TODO: temporarily hardcoded
2240
-
2241
- const int32_t n_groups = ((int32_t *) dst->op_params)[0];
2242
-
2243
- int nth = 32; // SIMD width
2244
-
2245
- //while (nth < ne00/4 && nth < 1024) {
2246
- // nth *= 2;
2247
- //}
2248
-
2249
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
2250
-
2251
- [encoder setComputePipelineState:pipeline];
2252
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2253
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2254
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2255
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2256
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2257
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
2258
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
2259
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
2260
- [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
2261
- [encoder setBytes:&eps length:sizeof( float) atIndex:9];
2262
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2263
-
2264
- [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2265
- } break;
2266
- case GGML_OP_NORM:
2267
- {
2268
- GGML_ASSERT(ggml_is_contiguous_1(src0));
2269
-
2270
- float eps;
2271
- memcpy(&eps, dst->op_params, sizeof(float));
2272
-
2273
- const int nth = MIN(256, ne00);
2274
-
2275
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
2276
-
2277
- [encoder setComputePipelineState:pipeline];
2278
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2279
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2280
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2281
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
2282
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
2283
- [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
2284
-
2285
- const int64_t nrows = ggml_nrows(src0);
2286
-
2287
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2288
- } break;
2289
- case GGML_OP_ROPE:
2290
- {
2291
- GGML_ASSERT(ne10 == ne02);
2292
-
2293
- const int nth = MIN(1024, ne00);
2294
-
2295
- const int n_past = ((int32_t *) dst->op_params)[0];
2296
- const int n_dims = ((int32_t *) dst->op_params)[1];
2297
- const int mode = ((int32_t *) dst->op_params)[2];
2298
- // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
2299
- const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
2300
-
2301
- float freq_base;
2302
- float freq_scale;
2303
- float ext_factor;
2304
- float attn_factor;
2305
- float beta_fast;
2306
- float beta_slow;
2307
-
2308
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
2309
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
2310
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
2311
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
2312
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
2313
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
2314
-
2315
- const bool is_neox = mode & 2;
2316
-
2317
- id<MTLComputePipelineState> pipeline = nil;
2318
-
2319
- if (!is_neox) {
2320
- switch (src0->type) {
2321
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
2322
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
2323
- default: GGML_ASSERT(false);
2324
- };
2325
- } else {
2326
- switch (src0->type) {
2327
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
2328
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
2329
- default: GGML_ASSERT(false);
2330
- };
2331
- }
2332
-
2333
- [encoder setComputePipelineState:pipeline];
2334
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2335
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2336
- if (id_src2 != nil) {
2337
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2338
- } else {
2339
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
2340
- }
2341
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2342
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
2343
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
2344
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
2345
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
2346
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
2347
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
2348
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
2349
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
2350
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
2351
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
2352
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
2353
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
2354
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
2355
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
2356
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
2357
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
2358
- [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
2359
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
2360
- [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
2361
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
2362
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
2363
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
2364
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
2365
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
2366
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
2367
-
2368
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2369
- } break;
2370
- case GGML_OP_IM2COL:
2371
- {
2372
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
2373
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
2374
- GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
2375
-
2376
- const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
2377
- const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
2378
- const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
2379
- const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
2380
- const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
2381
- const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
2382
-
2383
- const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
2384
-
2385
- const int32_t N = src1->ne[is_2D ? 3 : 2];
2386
- const int32_t IC = src1->ne[is_2D ? 2 : 1];
2387
- const int32_t IH = is_2D ? src1->ne[1] : 1;
2388
- const int32_t IW = src1->ne[0];
2389
-
2390
- const int32_t KH = is_2D ? src0->ne[1] : 1;
2391
- const int32_t KW = src0->ne[0];
2392
-
2393
- const int32_t OH = is_2D ? dst->ne[2] : 1;
2394
- const int32_t OW = dst->ne[1];
2395
-
2396
- const int32_t CHW = IC * KH * KW;
2397
-
2398
- const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
2399
- const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
2400
-
2401
- id<MTLComputePipelineState> pipeline = nil;
2402
-
2403
- switch (dst->type) {
2404
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
2405
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
2406
- default: GGML_ASSERT(false);
2407
- };
2408
-
2409
- [encoder setComputePipelineState:pipeline];
2410
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
2411
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2412
- [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
2413
- [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
2414
- [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
2415
- [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
2416
- [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
2417
- [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
2418
- [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
2419
- [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
2420
- [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
2421
- [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
2422
- [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
2423
-
2424
- [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
2425
- } break;
2426
- case GGML_OP_UPSCALE:
2427
- {
2428
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2429
-
2430
- const float sf0 = (float)ne0/src0->ne[0];
2431
- const float sf1 = (float)ne1/src0->ne[1];
2432
- const float sf2 = (float)ne2/src0->ne[2];
2433
- const float sf3 = (float)ne3/src0->ne[3];
2434
-
2435
- const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
2436
-
2437
- [encoder setComputePipelineState:pipeline];
2438
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2439
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2440
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2441
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2442
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2443
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2444
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2445
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2446
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2447
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2448
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2449
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2450
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2451
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2452
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2453
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2454
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2455
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2456
- [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
2457
- [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
2458
- [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
2459
- [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
2460
-
2461
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
2462
-
2463
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2464
- } break;
2465
- case GGML_OP_PAD:
2466
- {
2467
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2468
-
2469
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
2470
-
2471
- [encoder setComputePipelineState:pipeline];
2472
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2473
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2474
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2475
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2476
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2477
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2478
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2479
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2480
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2481
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2482
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2483
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2484
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2485
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2486
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2487
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2488
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2489
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2490
-
2491
- const int nth = MIN(1024, ne0);
2492
-
2493
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2494
- } break;
2495
- case GGML_OP_ARANGE:
2496
- {
2497
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
2498
-
2499
- float start;
2500
- float step;
2501
-
2502
- memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
2503
- memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float));
2504
-
2505
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
2506
-
2507
- [encoder setComputePipelineState:pipeline];
2508
- [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
2509
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
2510
- [encoder setBytes:&start length:sizeof(start) atIndex:2];
2511
- [encoder setBytes:&step length:sizeof(step) atIndex:3];
2512
-
2513
- const int nth = MIN(1024, ne0);
2514
-
2515
- [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2516
- } break;
2517
- case GGML_OP_TIMESTEP_EMBEDDING:
2518
- {
2519
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2520
-
2521
- const int dim = dst->op_params[0];
2522
- const int max_period = dst->op_params[1];
2523
-
2524
- const int half = dim / 2;
2525
-
2526
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
2527
-
2528
- [encoder setComputePipelineState:pipeline];
2529
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2530
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2531
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
2532
- [encoder setBytes:&dim length:sizeof(dim) atIndex:3];
2533
- [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
2534
-
2535
- const int nth = MIN(1024, half);
2536
-
2537
- [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2538
- } break;
2539
- case GGML_OP_ARGSORT:
2540
- {
2541
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2542
- GGML_ASSERT( dst->type == GGML_TYPE_I32);
2543
-
2544
- const int nrows = ggml_nrows(src0);
2545
-
2546
- enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
2547
-
2548
- // bitonic sort requires the number of elements to be power of 2
2549
- int64_t ne00_padded = 1;
2550
- while (ne00_padded < ne00) {
2551
- ne00_padded *= 2;
2552
- }
2553
-
2554
- // Metal kernels require the buffer size to be multiple of 16 bytes
2555
- // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
2556
- const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
2557
-
2558
- id<MTLComputePipelineState> pipeline = nil;
2559
-
2560
- switch (order) {
2561
- case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
2562
- case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
2563
- default: GGML_ASSERT(false);
2564
- };
2565
-
2566
- [encoder setComputePipelineState:pipeline];
2567
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2568
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2569
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2570
- [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
2571
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2572
-
2573
- [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
2574
- } break;
2575
- case GGML_OP_LEAKY_RELU:
2576
- {
2577
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2578
-
2579
- float slope;
2580
- memcpy(&slope, dst->op_params, sizeof(float));
2581
-
2582
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
2583
-
2584
- [encoder setComputePipelineState:pipeline];
2585
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2586
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2587
- [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
2588
-
2589
- const int64_t n = ggml_nelements(dst);
2590
-
2591
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2592
- } break;
2593
- case GGML_OP_FLASH_ATTN_EXT:
2594
- {
2595
- GGML_ASSERT(ne00 % 4 == 0);
2596
- GGML_ASSERT(ne11 % 32 == 0);
2597
-
2598
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
2599
-
2600
- GGML_ASSERT(ggml_are_same_shape (src1, src2));
2601
-
2602
- struct ggml_tensor * src3 = gf->nodes[i]->src[3];
2603
-
2604
- size_t offs_src3 = 0;
2605
-
2606
- id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
2607
-
2608
- GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
2609
- GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
2610
- "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
2611
-
2612
- const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
2613
- //const int64_t ne31 = src3 ? src3->ne[1] : 0;
2614
- const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
2615
- const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
2616
-
2617
- const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
2618
- const uint64_t nb31 = src3 ? src3->nb[1] : 0;
2619
- const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
2620
- const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
2621
-
2622
- const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
2623
-
2624
- float scale;
2625
- float max_bias;
2626
-
2627
- memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
2628
- memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
2629
-
2630
- const uint32_t n_head = src0->ne[2];
2631
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
2632
-
2633
- const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
2634
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
2635
-
2636
- id<MTLComputePipelineState> pipeline = nil;
2637
-
2638
- bool use_vec_kernel = false;
2639
-
2640
- if (ne01 >= 4 || (ne00%128 != 0)) {
2641
- switch (ne00) {
2642
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
2643
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
2644
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
2645
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
2646
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
2647
- //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
2648
- default:
2649
- {
2650
- GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
2651
- GGML_METAL_LOG_ERROR("add template specialization for this size\n");
2652
- GGML_ASSERT(false && "add template specialization for this size");
2653
- }
2654
- }
2655
- } else {
2656
- use_vec_kernel = true;
2657
-
2658
- switch (ne00) {
2659
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
2660
- //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
2661
- default:
2662
- {
2663
- GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
2664
- GGML_METAL_LOG_ERROR("add template specialization for this size\n");
2665
- GGML_ASSERT(false && "add template specialization for this size");
2666
- }
2667
- }
2668
- }
2669
-
2670
- [encoder setComputePipelineState:pipeline];
2671
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2672
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2673
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2674
- if (id_src3) {
2675
- [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2676
- } else {
2677
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
2678
- }
2679
- [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2680
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
2681
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
2682
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
2683
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
2684
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
2685
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
2686
- [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
2687
- [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
2688
- [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
2689
- [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
2690
- [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
2691
- [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
2692
- [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
2693
- [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
2694
- [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
2695
- [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
2696
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
2697
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
2698
- [encoder setBytes:&scale length:sizeof( float) atIndex:23];
2699
- [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
2700
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
2701
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
2702
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
2703
-
2704
- if (!use_vec_kernel) {
2705
- // half8x8 kernel
2706
- const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
2707
- const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
2708
-
2709
- GGML_ASSERT(nqptg <= 32);
2710
- GGML_ASSERT(nqptg % 8 == 0);
2711
- GGML_ASSERT(ncpsg % 32 == 0);
2712
-
2713
- int64_t nsgmax = 2;
2714
-
2715
- while (true) {
2716
- const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
2717
- if (smem > ctx->device.maxThreadgroupMemoryLength) {
2718
- break;
2719
- }
2720
- nsgmax *= 2;
2721
- }
2722
- nsgmax /= 2;
2723
-
2724
- // simdgroups per threadgroup (a.k.a. warps)
2725
- const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
2726
-
2727
- const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
2728
-
2729
- //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
2730
- GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
2731
-
2732
- [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
2733
-
2734
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
2735
- } else {
2736
- // half1x4 kernel
2737
- const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
2738
- const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
2739
-
2740
- GGML_ASSERT(nqptg <= 32);
2741
- GGML_ASSERT(nqptg % 1 == 0);
2742
- GGML_ASSERT(ncpsg % 32 == 0);
2743
-
2744
- // simdgroups per threadgroup (a.k.a. warps)
2745
- const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
2746
-
2747
- int64_t nsg = 1;
2748
- while (nsg <= nsgt) {
2749
- nsg *= 2;
2750
- }
2751
- nsg /= 2;
2752
-
2753
- const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
2754
-
2755
- //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
2756
- GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
2757
- [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
2758
-
2759
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
2760
- }
2761
- } break;
2762
- case GGML_OP_DUP:
2763
- case GGML_OP_CPY:
2764
- case GGML_OP_CONT:
2765
- {
2766
- GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
2767
-
2768
- int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
2769
-
2770
- id<MTLComputePipelineState> pipeline = nil;
2771
-
2772
- switch (src0t) {
2773
- case GGML_TYPE_F32:
2774
- {
2775
- GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
2776
-
2777
- switch (dstt) {
2778
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
2779
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
2780
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
2781
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
2782
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
2783
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
2784
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
2785
- case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
2786
- default: GGML_ASSERT(false && "not implemented");
2787
- };
2788
- } break;
2789
- case GGML_TYPE_F16:
2790
- {
2791
- switch (dstt) {
2792
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
2793
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
2794
- default: GGML_ASSERT(false && "not implemented");
2795
- };
2796
- } break;
2797
- default: GGML_ASSERT(false && "not implemented");
2798
- }
2799
-
2800
- [encoder setComputePipelineState:pipeline];
2801
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2802
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2803
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2804
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2805
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2806
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
2807
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
2808
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
2809
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
2810
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
2811
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
2812
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
2813
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
2814
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
2815
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
2816
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
2817
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
2818
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
2819
-
2820
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2821
- } break;
2822
- default:
2823
- {
2824
- GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
2825
- GGML_ASSERT(false);
2826
- }
2827
- }
2828
-
2829
- if (should_capture) {
2830
- [encoder popDebugGroup];
2831
- }
2832
- }
2833
-
2834
- [encoder endEncoding];
2835
-
2836
- [command_buffer commit];
2837
- });
2838
-
2839
- // Wait for completion and check status of each command buffer
2840
- // needed to detect if the device ran out-of-memory for example (#1881)
2841
-
2842
- for (int i = 0; i < n_cb; ++i) {
2843
- id<MTLCommandBuffer> command_buffer = command_buffers[i];
2844
- [command_buffer waitUntilCompleted];
2845
-
2846
- MTLCommandBufferStatus status = [command_buffer status];
2847
- if (status != MTLCommandBufferStatusCompleted) {
2848
- GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
2849
- if (status == MTLCommandBufferStatusError) {
2850
- NSString * error_code = [command_buffer error].localizedDescription;
2851
- GGML_METAL_LOG_INFO("error: %s\n", [error_code UTF8String]);
2852
- }
2853
-
2854
- return GGML_STATUS_FAILED;
2855
- }
2856
- }
2857
-
2858
- if (should_capture) {
2859
- [[MTLCaptureManager sharedCaptureManager] stopCapture];
2860
- }
2861
-
2862
- }
2863
- return GGML_STATUS_SUCCESS;
2864
- }
2865
-
2866
- ////////////////////////////////////////////////////////////////////////////////
2867
-
2868
- // backend interface
2869
-
2870
- // default buffer
2871
- static id<MTLDevice> g_backend_device = nil;
2872
- static int g_backend_device_ref_count = 0;
2873
-
2874
- static id<MTLDevice> ggml_backend_metal_get_device(void) {
2875
- if (g_backend_device == nil) {
2876
- g_backend_device = MTLCreateSystemDefaultDevice();
2877
- }
2878
-
2879
- g_backend_device_ref_count++;
2880
-
2881
- return g_backend_device;
2882
- }
2883
-
2884
- static void ggml_backend_metal_free_device(void) {
2885
- assert(g_backend_device_ref_count > 0);
2886
-
2887
- g_backend_device_ref_count--;
2888
-
2889
- if (g_backend_device_ref_count == 0) {
2890
- [g_backend_device release];
2891
- g_backend_device = nil;
2892
- }
2893
- }
2894
-
2895
- GGML_CALL static const char * ggml_backend_metal_buffer_get_name(ggml_backend_buffer_t buffer) {
2896
- return "Metal";
2897
-
2898
- UNUSED(buffer);
2899
- }
2900
-
2901
- GGML_CALL static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
2902
- struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
2903
-
2904
- for (int i = 0; i < ctx->n_buffers; i++) {
2905
- [ctx->buffers[i].metal release];
2906
- }
2907
- ggml_backend_metal_free_device();
2908
-
2909
- if (ctx->owned) {
2910
- #if TARGET_OS_OSX
2911
- vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size);
2912
- #else
2913
- free(ctx->all_data);
2914
- #endif
2915
- }
2916
-
2917
- free(ctx);
2918
- }
2919
-
2920
- GGML_CALL static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
2921
- struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
2922
-
2923
- return ctx->all_data;
2924
- }
2925
-
2926
- GGML_CALL static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
2927
- memcpy((char *)tensor->data + offset, data, size);
2928
-
2929
- UNUSED(buffer);
2930
- }
2931
-
2932
- GGML_CALL static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
2933
- memcpy(data, (const char *)tensor->data + offset, size);
2934
-
2935
- UNUSED(buffer);
2936
- }
2937
-
2938
- GGML_CALL static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
2939
- if (ggml_backend_buffer_is_host(src->buffer)) {
2940
- memcpy(dst->data, src->data, ggml_nbytes(src));
2941
- return true;
2942
- }
2943
- return false;
2944
-
2945
- UNUSED(buffer);
2946
- }
2947
-
2948
- GGML_CALL static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
2949
- struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
2950
-
2951
- memset(ctx->all_data, value, ctx->all_size);
2952
- }
2953
-
2954
- static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
2955
- /* .get_name = */ ggml_backend_metal_buffer_get_name,
2956
- /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
2957
- /* .get_base = */ ggml_backend_metal_buffer_get_base,
2958
- /* .init_tensor = */ NULL,
2959
- /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
2960
- /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
2961
- /* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor,
2962
- /* .clear = */ ggml_backend_metal_buffer_clear,
2963
- /* .reset = */ NULL,
2964
- };
2965
-
2966
- // default buffer type
2967
-
2968
- GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
2969
- return "Metal";
2970
-
2971
- UNUSED(buft);
2972
- }
2973
-
2974
- static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
2975
- #ifndef GGML_METAL_NDEBUG
2976
- #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
2977
- if (@available(macOS 10.12, iOS 16.0, *)) {
2978
- GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)",
2979
- __func__,
2980
- size_aligned / 1024.0 / 1024.0,
2981
- device.currentAllocatedSize / 1024.0 / 1024.0,
2982
- device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
2983
-
2984
- if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
2985
- GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
2986
- } else {
2987
- GGML_METAL_LOG_INFO("\n");
2988
- }
2989
- } else {
2990
- GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n",
2991
- __func__,
2992
- size_aligned / 1024.0 / 1024.0,
2993
- device.currentAllocatedSize / 1024.0 / 1024.0);
2994
- }
2995
- #endif
2996
- #endif
2997
- UNUSED(device);
2998
- UNUSED(size_aligned);
2999
- }
3000
-
3001
- GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
3002
- struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
3003
-
3004
- const size_t size_page = sysconf(_SC_PAGESIZE);
3005
-
3006
- size_t size_aligned = size;
3007
- if ((size_aligned % size_page) != 0) {
3008
- size_aligned += (size_page - (size_aligned % size_page));
3009
- }
3010
-
3011
- id<MTLDevice> device = ggml_backend_metal_get_device();
3012
-
3013
- ctx->all_data = ggml_metal_host_malloc(size_aligned);
3014
- ctx->all_size = size_aligned;
3015
- ctx->owned = true;
3016
- ctx->n_buffers = 1;
3017
-
3018
- if (ctx->all_data != NULL) {
3019
- ctx->buffers[0].data = ctx->all_data;
3020
- ctx->buffers[0].size = size;
3021
- ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
3022
- length:size_aligned
3023
- options:MTLResourceStorageModeShared
3024
- deallocator:nil];
3025
- }
3026
-
3027
- if (ctx->all_data == NULL || ctx->buffers[0].metal == nil) {
3028
- GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
3029
- free(ctx);
3030
- ggml_backend_metal_free_device();
3031
- return NULL;
3032
- }
3033
-
3034
- //ggml_backend_metal_log_allocated_size(device, size_aligned);
3035
-
3036
- return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
3037
- }
3038
-
3039
- GGML_CALL static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
3040
- return 32;
3041
- UNUSED(buft);
3042
- }
3043
-
3044
- GGML_CALL static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
3045
- id<MTLDevice> device = ggml_backend_metal_get_device();
3046
- size_t max_size = device.maxBufferLength;
3047
- ggml_backend_metal_free_device();
3048
-
3049
- return max_size;
3050
-
3051
- UNUSED(buft);
3052
- }
3053
-
3054
- GGML_CALL static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
3055
- return true;
3056
-
3057
- UNUSED(buft);
3058
- }
3059
-
3060
- GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
3061
- static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
3062
- /* .iface = */ {
3063
- /* .get_name = */ ggml_backend_metal_buffer_type_get_name,
3064
- /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
3065
- /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
3066
- /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size,
3067
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
3068
- /* .is_host = */ ggml_backend_metal_buffer_type_is_host,
3069
- },
3070
- /* .context = */ NULL,
3071
- };
3072
-
3073
- return &ggml_backend_buffer_type_metal;
3074
- }
3075
-
3076
- // buffer from ptr
3077
-
3078
- GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
3079
- struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
3080
-
3081
- ctx->all_data = data;
3082
- ctx->all_size = size;
3083
- ctx->owned = false;
3084
- ctx->n_buffers = 0;
3085
-
3086
- const size_t size_page = sysconf(_SC_PAGESIZE);
3087
-
3088
- // page-align the data ptr
3089
- {
3090
- const uintptr_t offs = (uintptr_t) data % size_page;
3091
- data = (void *) ((char *) data - offs);
3092
- size += offs;
3093
- }
3094
-
3095
- size_t size_aligned = size;
3096
- if ((size_aligned % size_page) != 0) {
3097
- size_aligned += (size_page - (size_aligned % size_page));
3098
- }
3099
-
3100
- id<MTLDevice> device = ggml_backend_metal_get_device();
3101
-
3102
- // the buffer fits into the max buffer size allowed by the device
3103
- if (size_aligned <= device.maxBufferLength) {
3104
- ctx->buffers[ctx->n_buffers].data = data;
3105
- ctx->buffers[ctx->n_buffers].size = size;
3106
-
3107
- ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
3108
-
3109
- if (ctx->buffers[ctx->n_buffers].metal == nil) {
3110
- GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
3111
- return false;
3112
- }
3113
-
3114
- ggml_backend_metal_log_allocated_size(device, size_aligned);
3115
-
3116
- ++ctx->n_buffers;
3117
- } else {
3118
- // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
3119
- // one of the views
3120
- const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
3121
- const size_t size_step = device.maxBufferLength - size_ovlp;
3122
- const size_t size_view = device.maxBufferLength;
3123
-
3124
- for (size_t i = 0; i < size; i += size_step) {
3125
- const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
3126
-
3127
- ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
3128
- ctx->buffers[ctx->n_buffers].size = size_step_aligned;
3129
-
3130
- ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
3131
-
3132
- if (ctx->buffers[ctx->n_buffers].metal == nil) {
3133
- GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
3134
- return false;
3135
- }
3136
-
3137
- ggml_backend_metal_log_allocated_size(device, size_step_aligned);
3138
-
3139
- if (i + size_step < size) {
3140
- GGML_METAL_LOG_INFO("\n");
3141
- }
3142
-
3143
- ++ctx->n_buffers;
3144
- }
3145
- }
3146
-
3147
- return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
3148
- }
3149
-
3150
- // backend
3151
-
3152
- GGML_CALL static const char * ggml_backend_metal_name(ggml_backend_t backend) {
3153
- return "Metal";
3154
-
3155
- UNUSED(backend);
3156
- }
3157
-
3158
- GGML_CALL static void ggml_backend_metal_free(ggml_backend_t backend) {
3159
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
3160
- ggml_metal_free(ctx);
3161
- free(backend);
3162
- }
3163
-
3164
- GGML_CALL static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
3165
- return ggml_backend_metal_buffer_type();
3166
-
3167
- UNUSED(backend);
3168
- }
3169
-
3170
- GGML_CALL static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
3171
- struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
3172
-
3173
- return ggml_metal_graph_compute(metal_ctx, cgraph);
3174
- }
3175
-
3176
- GGML_CALL static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
3177
- struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
3178
-
3179
- return ggml_metal_supports_op(metal_ctx, op);
3180
- }
3181
-
3182
- GGML_CALL static bool ggml_backend_metal_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
3183
- return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name;
3184
-
3185
- UNUSED(backend);
3186
- }
3187
-
3188
- static struct ggml_backend_i ggml_backend_metal_i = {
3189
- /* .get_name = */ ggml_backend_metal_name,
3190
- /* .free = */ ggml_backend_metal_free,
3191
- /* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
3192
- /* .set_tensor_async = */ NULL,
3193
- /* .get_tensor_async = */ NULL,
3194
- /* .cpy_tensor_async = */ NULL,
3195
- /* .synchronize = */ NULL,
3196
- /* .graph_plan_create = */ NULL,
3197
- /* .graph_plan_free = */ NULL,
3198
- /* .graph_plan_update = */ NULL,
3199
- /* .graph_plan_compute = */ NULL,
3200
- /* .graph_compute = */ ggml_backend_metal_graph_compute,
3201
- /* .supports_op = */ ggml_backend_metal_supports_op,
3202
- /* .supports_buft = */ ggml_backend_metal_supports_buft,
3203
- /* .offload_op = */ NULL,
3204
- /* .event_new = */ NULL,
3205
- /* .event_free = */ NULL,
3206
- /* .event_record = */ NULL,
3207
- /* .event_wait = */ NULL,
3208
- /* .event_synchronize = */ NULL,
3209
- };
3210
-
3211
- void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
3212
- ggml_metal_log_callback = log_callback;
3213
- ggml_metal_log_user_data = user_data;
3214
- }
3215
-
3216
- static ggml_guid_t ggml_backend_metal_guid(void) {
3217
- static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 };
3218
- return &guid;
3219
- }
3220
-
3221
- ggml_backend_t ggml_backend_metal_init(void) {
3222
- struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
3223
-
3224
- if (ctx == NULL) {
3225
- return NULL;
3226
- }
3227
-
3228
- ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
3229
-
3230
- *metal_backend = (struct ggml_backend) {
3231
- /* .guid = */ ggml_backend_metal_guid(),
3232
- /* .interface = */ ggml_backend_metal_i,
3233
- /* .context = */ ctx,
3234
- };
3235
-
3236
- return metal_backend;
3237
- }
3238
-
3239
- bool ggml_backend_is_metal(ggml_backend_t backend) {
3240
- return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
3241
- }
3242
-
3243
- void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
3244
- GGML_ASSERT(ggml_backend_is_metal(backend));
3245
-
3246
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
3247
-
3248
- ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
3249
- }
3250
-
3251
- bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
3252
- GGML_ASSERT(ggml_backend_is_metal(backend));
3253
-
3254
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
3255
-
3256
- return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
3257
- }
3258
-
3259
- void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
3260
- GGML_ASSERT(ggml_backend_is_metal(backend));
3261
-
3262
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
3263
- ctx->should_capture_next_compute = true;
3264
- }
3265
-
3266
- GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
3267
-
3268
- GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
3269
- return ggml_backend_metal_init();
3270
-
3271
- GGML_UNUSED(params);
3272
- GGML_UNUSED(user_data);
3273
- }