cui-llama.rn 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. package/LICENSE +20 -0
  2. package/README.md +330 -0
  3. package/android/build.gradle +107 -0
  4. package/android/gradle.properties +5 -0
  5. package/android/src/main/AndroidManifest.xml +4 -0
  6. package/android/src/main/CMakeLists.txt +69 -0
  7. package/android/src/main/java/com/rnllama/LlamaContext.java +353 -0
  8. package/android/src/main/java/com/rnllama/RNLlama.java +446 -0
  9. package/android/src/main/java/com/rnllama/RNLlamaPackage.java +48 -0
  10. package/android/src/main/jni.cpp +635 -0
  11. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +94 -0
  12. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +95 -0
  13. package/cpp/README.md +4 -0
  14. package/cpp/common.cpp +3237 -0
  15. package/cpp/common.h +467 -0
  16. package/cpp/ggml-aarch64.c +2193 -0
  17. package/cpp/ggml-aarch64.h +39 -0
  18. package/cpp/ggml-alloc.c +1041 -0
  19. package/cpp/ggml-alloc.h +76 -0
  20. package/cpp/ggml-backend-impl.h +153 -0
  21. package/cpp/ggml-backend.c +2225 -0
  22. package/cpp/ggml-backend.h +236 -0
  23. package/cpp/ggml-common.h +1829 -0
  24. package/cpp/ggml-impl.h +655 -0
  25. package/cpp/ggml-metal.h +65 -0
  26. package/cpp/ggml-metal.m +3273 -0
  27. package/cpp/ggml-quants.c +15022 -0
  28. package/cpp/ggml-quants.h +132 -0
  29. package/cpp/ggml.c +22034 -0
  30. package/cpp/ggml.h +2444 -0
  31. package/cpp/grammar-parser.cpp +536 -0
  32. package/cpp/grammar-parser.h +29 -0
  33. package/cpp/json-schema-to-grammar.cpp +1045 -0
  34. package/cpp/json-schema-to-grammar.h +8 -0
  35. package/cpp/json.hpp +24766 -0
  36. package/cpp/llama.cpp +21789 -0
  37. package/cpp/llama.h +1201 -0
  38. package/cpp/log.h +737 -0
  39. package/cpp/rn-llama.hpp +630 -0
  40. package/cpp/sampling.cpp +460 -0
  41. package/cpp/sampling.h +160 -0
  42. package/cpp/sgemm.cpp +1027 -0
  43. package/cpp/sgemm.h +14 -0
  44. package/cpp/unicode-data.cpp +7032 -0
  45. package/cpp/unicode-data.h +20 -0
  46. package/cpp/unicode.cpp +812 -0
  47. package/cpp/unicode.h +64 -0
  48. package/ios/RNLlama.h +11 -0
  49. package/ios/RNLlama.mm +302 -0
  50. package/ios/RNLlama.xcodeproj/project.pbxproj +278 -0
  51. package/ios/RNLlamaContext.h +39 -0
  52. package/ios/RNLlamaContext.mm +426 -0
  53. package/jest/mock.js +169 -0
  54. package/lib/commonjs/NativeRNLlama.js +10 -0
  55. package/lib/commonjs/NativeRNLlama.js.map +1 -0
  56. package/lib/commonjs/grammar.js +574 -0
  57. package/lib/commonjs/grammar.js.map +1 -0
  58. package/lib/commonjs/index.js +151 -0
  59. package/lib/commonjs/index.js.map +1 -0
  60. package/lib/module/NativeRNLlama.js +3 -0
  61. package/lib/module/NativeRNLlama.js.map +1 -0
  62. package/lib/module/grammar.js +566 -0
  63. package/lib/module/grammar.js.map +1 -0
  64. package/lib/module/index.js +129 -0
  65. package/lib/module/index.js.map +1 -0
  66. package/lib/typescript/NativeRNLlama.d.ts +107 -0
  67. package/lib/typescript/NativeRNLlama.d.ts.map +1 -0
  68. package/lib/typescript/grammar.d.ts +38 -0
  69. package/lib/typescript/grammar.d.ts.map +1 -0
  70. package/lib/typescript/index.d.ts +46 -0
  71. package/lib/typescript/index.d.ts.map +1 -0
  72. package/llama-rn.podspec +56 -0
  73. package/package.json +230 -0
  74. package/src/NativeRNLlama.ts +132 -0
  75. package/src/grammar.ts +849 -0
  76. package/src/index.ts +182 -0
package/cpp/ggml-metal.m
@@ -0,0 +1,3273 @@
#import "ggml-metal.h"

#import "ggml-backend-impl.h"
#import "ggml.h"

#import <Foundation/Foundation.h>

#import <Metal/Metal.h>

#undef MIN
#undef MAX
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))

#ifdef LM_GGML_METAL_NDEBUG
#define LM_GGML_METAL_LOG_INFO(...)
#define LM_GGML_METAL_LOG_WARN(...)
#define LM_GGML_METAL_LOG_ERROR(...)
#else
#define LM_GGML_METAL_LOG_INFO(...) lm_ggml_metal_log(LM_GGML_LOG_LEVEL_INFO, __VA_ARGS__)
#define LM_GGML_METAL_LOG_WARN(...) lm_ggml_metal_log(LM_GGML_LOG_LEVEL_WARN, __VA_ARGS__)
#define LM_GGML_METAL_LOG_ERROR(...) lm_ggml_metal_log(LM_GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
#endif

#define UNUSED(x) (void)(x)

struct lm_ggml_metal_kernel {
    id<MTLComputePipelineState> pipeline;
};

enum lm_ggml_metal_kernel_type {
    LM_GGML_METAL_KERNEL_TYPE_ADD,
    LM_GGML_METAL_KERNEL_TYPE_ADD_ROW,
    LM_GGML_METAL_KERNEL_TYPE_MUL,
    LM_GGML_METAL_KERNEL_TYPE_MUL_ROW,
    LM_GGML_METAL_KERNEL_TYPE_DIV,
    LM_GGML_METAL_KERNEL_TYPE_DIV_ROW,
    LM_GGML_METAL_KERNEL_TYPE_REPEAT_F32,
    LM_GGML_METAL_KERNEL_TYPE_REPEAT_F16,
    LM_GGML_METAL_KERNEL_TYPE_REPEAT_I32,
    LM_GGML_METAL_KERNEL_TYPE_REPEAT_I16,
    LM_GGML_METAL_KERNEL_TYPE_SCALE,
    LM_GGML_METAL_KERNEL_TYPE_SCALE_4,
    LM_GGML_METAL_KERNEL_TYPE_CLAMP,
    LM_GGML_METAL_KERNEL_TYPE_TANH,
    LM_GGML_METAL_KERNEL_TYPE_RELU,
    LM_GGML_METAL_KERNEL_TYPE_SIGMOID,
    LM_GGML_METAL_KERNEL_TYPE_GELU,
    LM_GGML_METAL_KERNEL_TYPE_GELU_4,
    LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK,
    LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
    LM_GGML_METAL_KERNEL_TYPE_SILU,
    LM_GGML_METAL_KERNEL_TYPE_SILU_4,
    LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
    LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
    LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
    LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
    LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
    LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
    LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
    LM_GGML_METAL_KERNEL_TYPE_RMS_NORM,
    LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM,
    LM_GGML_METAL_KERNEL_TYPE_NORM,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
    //LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
    //LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
    //LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
    LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
    LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
    LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
    LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
    LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16,
    LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32,
    LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
    LM_GGML_METAL_KERNEL_TYPE_PAD_F32,
    LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32,
    LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
    LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
    LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
    LM_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
    //LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
    //LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
    LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
    LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
    LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
    LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
    LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
    LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
    LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
    LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
    LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
    LM_GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
    LM_GGML_METAL_KERNEL_TYPE_CONCAT,
    LM_GGML_METAL_KERNEL_TYPE_SQR,
    LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS,

    LM_GGML_METAL_KERNEL_TYPE_COUNT
};
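Each enumerator above is a dense index into the fixed-size pipeline table held by the backend context (see struct lm_ggml_metal_context just below), with LM_GGML_METAL_KERNEL_TYPE_COUNT serving as the sentinel that sizes the array. A minimal C sketch of this enum-indexed table pattern, using hypothetical names:

    // Hypothetical illustration of the enum-indexed pipeline table used above.
    enum kernel_type {
        KERNEL_TYPE_ADD,
        KERNEL_TYPE_MUL,
        KERNEL_TYPE_COUNT  // sentinel: also the number of table slots
    };

    struct kernel {
        void * pipeline;   // stands in for id<MTLComputePipelineState>
    };

    struct kernel kernels[KERNEL_TYPE_COUNT];  // one slot per enumerator

    void * lookup(enum kernel_type t) {
        return kernels[t].pipeline;            // O(1) lookup, e.g. lookup(KERNEL_TYPE_ADD)
    }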
struct lm_ggml_metal_context {
    int n_cb;

    id<MTLDevice> device;
    id<MTLCommandQueue> queue;

    dispatch_queue_t d_queue;

    struct lm_ggml_metal_kernel kernels[LM_GGML_METAL_KERNEL_TYPE_COUNT];

    bool support_simdgroup_reduction;
    bool support_simdgroup_mm;

    bool should_capture_next_compute;
};

// MSL code
// TODO: move the contents here when ready
//       for now it is easier to work in a separate file
// static NSString * const msl_library_source = @"see metal.metal";

// Here to assist with NSBundle Path Hack
@interface LMGGMLMetalClass : NSObject
@end
@implementation LMGGMLMetalClass
@end

static void lm_ggml_metal_default_log_callback(enum lm_ggml_log_level level, const char * msg, void * user_data) {
    fprintf(stderr, "%s", msg);

    UNUSED(level);
    UNUSED(user_data);
}

lm_ggml_log_callback lm_ggml_metal_log_callback = lm_ggml_metal_default_log_callback;
void * lm_ggml_metal_log_user_data = NULL;

LM_GGML_ATTRIBUTE_FORMAT(2, 3)
static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char * format, ...){
    if (lm_ggml_metal_log_callback != NULL) {
        va_list args;
        va_start(args, format);
        char buffer[128];
        int len = vsnprintf(buffer, 128, format, args);
        if (len < 128) {
            lm_ggml_metal_log_callback(level, buffer, lm_ggml_metal_log_user_data);
        } else {
            char* buffer2 = malloc(len+1);
            va_end(args);
            va_start(args, format);
            vsnprintf(buffer2, len+1, format, args);
            buffer2[len] = 0;
            lm_ggml_metal_log_callback(level, buffer2, lm_ggml_metal_log_user_data);
            free(buffer2);
        }
        va_end(args);
    }
}
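The fallback branch above relies on vsnprintf returning the number of characters the full message would need: when the 128-byte stack buffer is too small, that return value sizes a heap buffer, and the va_list is restarted (va_end then va_start) before formatting again. A self-contained C sketch of the same grow-on-truncation idiom, under the assumption that messages usually fit the small buffer:

    #include <stdarg.h>
    #include <stdio.h>
    #include <stdlib.h>

    // Try a small stack buffer first; fall back to a heap buffer sized from
    // vsnprintf's return value (the length the full message would require).
    static void log_fmt(const char * fmt, ...) {
        va_list args;
        va_start(args, fmt);
        char small[128];
        const int len = vsnprintf(small, sizeof(small), fmt, args);
        va_end(args);
        if (len < (int) sizeof(small)) {
            fputs(small, stderr);
            return;
        }
        char * big = malloc((size_t) len + 1);
        if (big != NULL) {
            va_start(args, fmt);  // a consumed va_list must be restarted before reuse
            vsnprintf(big, (size_t) len + 1, fmt, args);
            va_end(args);
            fputs(big, stderr);
            free(big);
        }
    }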
static void * lm_ggml_metal_host_malloc(size_t n) {
    void * data = NULL;

#if TARGET_OS_OSX
    kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE);
    if (err != KERN_SUCCESS) {
        LM_GGML_METAL_LOG_ERROR("%s: error: vm_allocate failed\n", __func__);
        return NULL;
    }
#else
    const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
    if (result != 0) {
        LM_GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
        return NULL;
    }
#endif

    return data;
}
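Both branches above return page-aligned memory (vm_allocate on macOS, posix_memalign with the page size elsewhere), presumably so the allocation can later be wrapped in a no-copy Metal buffer, which requires page-aligned pointers and lengths; a plain malloc would not guarantee that. A hedged usage sketch, rounding a hypothetical request up to whole pages the way a caller of this function might:

    #include <unistd.h>

    size_t page = (size_t) sysconf(_SC_PAGESIZE);
    size_t want = 16u * 1024u * 1024u;                // hypothetical 16 MiB request
    size_t size = ((want + page - 1) / page) * page;  // round up to a page multiple
    void * data = lm_ggml_metal_host_malloc(size);
    // data is page-aligned (or NULL on failure)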
static struct lm_ggml_metal_context * lm_ggml_metal_init(int n_cb) {
    LM_GGML_METAL_LOG_INFO("%s: allocating\n", __func__);

#if TARGET_OS_OSX && !LM_GGML_METAL_NDEBUG
    // Show all the Metal device instances in the system
    NSArray * devices = MTLCopyAllDevices();
    for (id<MTLDevice> device in devices) {
        LM_GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
    }
    [devices release]; // since it was created by a *Copy* C method
#endif

    // Pick and show default Metal device
    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
    LM_GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);

    // Configure context
    struct lm_ggml_metal_context * ctx = malloc(sizeof(struct lm_ggml_metal_context));
    ctx->device = device;
    ctx->n_cb = MIN(n_cb, LM_GGML_METAL_MAX_BUFFERS);
    ctx->queue = [ctx->device newCommandQueue];
    ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

    id<MTLLibrary> metal_library;

    // load library
    //
    // - first check if the library is embedded
    // - then check if the library is in the bundle
    // - if not found, load the source and compile it
    // - if that fails, return NULL
    {
        NSBundle * bundle = nil;
#ifdef SWIFT_PACKAGE
        bundle = SWIFTPM_MODULE_BUNDLE;
#else
        bundle = [NSBundle bundleForClass:[LMGGMLMetalClass class]];
#endif

        NSError * error = nil;

#if LM_GGML_METAL_EMBED_LIBRARY
        const bool try_metallib = false;
#else
        const bool try_metallib = true;
#endif

        NSString * path_lib = [bundle pathForResource:@"ggml-llama" ofType:@"metallib"];
        if (try_metallib && path_lib != nil) {
            // pre-compiled library found
            NSURL * libURL = [NSURL fileURLWithPath:path_lib];
            LM_GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);

            metal_library = [ctx->device newLibraryWithURL:libURL error:&error];
            if (error) {
                LM_GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
                return NULL;
            }
        } else {
#if LM_GGML_METAL_EMBED_LIBRARY
            LM_GGML_METAL_LOG_INFO("%s: using embedded metal library\n", __func__);

            extern const char lm_ggml_metallib_start[];
            extern const char lm_ggml_metallib_end[];

            NSString * src = [[NSString alloc] initWithBytes:lm_ggml_metallib_start length:(lm_ggml_metallib_end-lm_ggml_metallib_start) encoding:NSUTF8StringEncoding];
#else
            LM_GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);

            NSString * path_source;
            NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"LM_GGML_METAL_PATH_RESOURCES"];

            LM_GGML_METAL_LOG_INFO("%s: LM_GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");

            if (path_resource) {
                path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
            } else {
                path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
            }

            if (path_source == nil) {
                LM_GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
                path_source = @"ggml-metal.metal";
            }

            LM_GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);

            NSString * src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
            if (error) {
                LM_GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
                return NULL;
            }
#endif // LM_GGML_METAL_EMBED_LIBRARY

            @autoreleasepool {
                // dictionary of preprocessor macros
                NSMutableDictionary * prep = [NSMutableDictionary dictionary];

                MTLCompileOptions* options = [MTLCompileOptions new];
                options.preprocessorMacros = prep;

                //[options setFastMathEnabled:false];

                metal_library = [ctx->device newLibraryWithSource:src options:options error:&error];
                if (error) {
                    LM_GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
                    return NULL;
                }
            }
        }
    }

    // print MTL GPU family:
    LM_GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);

    const NSInteger MTLGPUFamilyMetal3 = 5001;

    // determine max supported GPU family
    // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
    // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
    {
        for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
            if ([ctx->device supportsFamily:i]) {
                LM_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
                break;
            }
        }

        for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
            if ([ctx->device supportsFamily:i]) {
                LM_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
                break;
            }
        }

        for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) {
            if ([ctx->device supportsFamily:i]) {
                LM_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i);
                break;
            }
        }
    }

    ctx->support_simdgroup_reduction = [ctx->device supportsFamily:MTLGPUFamilyApple7];
    ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3];

    ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7];

    LM_GGML_METAL_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false");
    LM_GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
    LM_GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");

    ctx->should_capture_next_compute = false;

#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
    if (@available(macOS 10.12, iOS 16.0, *)) {
        LM_GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
    }
#elif TARGET_OS_OSX
    if (ctx->device.maxTransferRate != 0) {
        LM_GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
    } else {
        LM_GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
    }
#endif

    // load kernels
    {
        NSError * error = nil;

        for (int i = 0; i < LM_GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
            ctx->kernels[i].pipeline = nil;
        }

        /*
        LM_GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
                (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
                (int) kernel->pipeline.threadExecutionWidth); \
        */
#define LM_GGML_METAL_ADD_KERNEL(e, name, supported) \
        if (supported) { \
            struct lm_ggml_metal_kernel * kernel = &ctx->kernels[e]; \
            id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
            kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
            [metal_function release]; \
            if (error) { \
                LM_GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
                [metal_library release]; \
                return NULL; \
            } \
        } else { \
            LM_GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
        }
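For reference, a single registration call such as LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ADD, add, true); expands (roughly, after the "kernel_"#name string pasting) to:

    if (true) {
        struct lm_ggml_metal_kernel * kernel = &ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ADD];
        id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_add"];
        kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error];
        [metal_function release];
        if (error) {
            LM_GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]);
            [metal_library release];
            return NULL;
        }
    } else {
        LM_GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_add");
    }

so a kernel whose "supported" condition is false leaves its table slot nil (from the initialization loop above) instead of failing the whole init.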
        // simd_sum and simd_max requires MTLGPUFamilyApple7

        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ADD, add, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL, mul, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIV, div, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RELU, relu, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU, silu, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_NORM, norm, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
        //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
        //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
        //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
        //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
        //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
    }

    [metal_library release];
    return ctx;
}
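A minimal, hedged sketch of the lifecycle this file exposes: initialize once, run graphs, then free. lm_ggml_metal_graph_compute is defined further below; gf stands for a previously built and allocated lm_ggml_cgraph, whose construction is elided here:

    struct lm_ggml_metal_context * ctx = lm_ggml_metal_init(/*n_cb =*/ 1);
    if (ctx == NULL) {
        // device, library, or pipeline creation failed
        return;
    }

    // ... build a graph gf whose tensors live in Metal-backed buffers ...

    enum lm_ggml_status status = lm_ggml_metal_graph_compute(ctx, gf);
    (void) status;

    lm_ggml_metal_free(ctx);  // releases pipelines, queue, device, dispatch queue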
static void lm_ggml_metal_free(struct lm_ggml_metal_context * ctx) {
    LM_GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);

    for (int i = 0; i < LM_GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
        [ctx->kernels[i].pipeline release];
    }

    [ctx->queue release];
    [ctx->device release];

    dispatch_release(ctx->d_queue);

    free(ctx);
}

// temporarily defined here for compatibility between ggml-backend and the old API

struct lm_ggml_backend_metal_buffer {
    void * data;
    size_t size;

    id<MTLBuffer> metal;
};

struct lm_ggml_backend_metal_buffer_context {
    void * all_data;
    size_t all_size;
    bool owned;

    // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
    int n_buffers;
    struct lm_ggml_backend_metal_buffer buffers[LM_GGML_METAL_MAX_BUFFERS];
};

// finds the Metal buffer that contains the tensor data on the GPU device
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
// Metal buffer based on the host memory pointer
//
static id<MTLBuffer> lm_ggml_metal_get_buffer(struct lm_ggml_tensor * t, size_t * offs) {
    //LM_GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);

    const int64_t tsize = lm_ggml_nbytes(t);

    lm_ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;

    struct lm_ggml_backend_metal_buffer_context * buf_ctx = (struct lm_ggml_backend_metal_buffer_context *) buffer->context;

    // find the view that contains the tensor fully
    for (int i = 0; i < buf_ctx->n_buffers; ++i) {
        const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;

        //LM_GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
        if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
            *offs = (size_t) ioffs;

            //LM_GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);

            return buf_ctx->buffers[i].metal;
        }
    }

    LM_GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);

    return nil;
}
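A worked example of the containment test above, with hypothetical numbers: if view i starts at host address base with size 8192, and the tensor's data pointer is base + 4096 with lm_ggml_nbytes(t) == 1024, then ioffs = 4096 and ioffs + tsize = 5120 <= 8192, so the tensor resolves to that view with *offs = 4096 and buffers[i].metal returned. A tensor lying before a view's start (ioffs < 0) or spilling past its end keeps scanning, and the function ultimately returns nil if no view contains it fully.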
static bool lm_ggml_metal_supports_op(const struct lm_ggml_metal_context * ctx, const struct lm_ggml_tensor * op) {
    for (size_t i = 0, n = 3; i < n; ++i) {
        if (op->src[i] != NULL && op->src[i]->type == LM_GGML_TYPE_BF16) {
            return false;
        }
    }

    switch (op->op) {
        case LM_GGML_OP_UNARY:
            switch (lm_ggml_get_unary_op(op)) {
                case LM_GGML_UNARY_OP_TANH:
                case LM_GGML_UNARY_OP_RELU:
                case LM_GGML_UNARY_OP_SIGMOID:
                case LM_GGML_UNARY_OP_GELU:
                case LM_GGML_UNARY_OP_GELU_QUICK:
                case LM_GGML_UNARY_OP_SILU:
                    return lm_ggml_is_contiguous(op->src[0]);
                default:
                    return false;
            }
        case LM_GGML_OP_NONE:
        case LM_GGML_OP_RESHAPE:
        case LM_GGML_OP_VIEW:
        case LM_GGML_OP_TRANSPOSE:
        case LM_GGML_OP_PERMUTE:
        case LM_GGML_OP_CONCAT:
        case LM_GGML_OP_ADD:
        case LM_GGML_OP_ACC:
        case LM_GGML_OP_MUL:
        case LM_GGML_OP_DIV:
        case LM_GGML_OP_REPEAT:
        case LM_GGML_OP_SCALE:
        case LM_GGML_OP_CLAMP:
        case LM_GGML_OP_SQR:
        case LM_GGML_OP_SUM_ROWS:
            return true;
        case LM_GGML_OP_SOFT_MAX:
        case LM_GGML_OP_RMS_NORM:
        case LM_GGML_OP_GROUP_NORM:
            return ctx->support_simdgroup_reduction;
        case LM_GGML_OP_NORM:
        case LM_GGML_OP_ROPE:
        case LM_GGML_OP_IM2COL:
            return true;
        case LM_GGML_OP_POOL_1D:
        case LM_GGML_OP_POOL_2D:
            return false;
        case LM_GGML_OP_UPSCALE:
        case LM_GGML_OP_PAD:
        case LM_GGML_OP_ARANGE:
        case LM_GGML_OP_TIMESTEP_EMBEDDING:
        case LM_GGML_OP_ARGSORT:
        case LM_GGML_OP_LEAKY_RELU:
            return true;
        case LM_GGML_OP_FLASH_ATTN_EXT:
            if (op->src[1]->type != LM_GGML_TYPE_F16) {
                return false;
            }
            if (op->src[2]->type != LM_GGML_TYPE_F16) {
                return false;
            }
            if (op->src[0]->ne[0] == 256) {
                return false;
            }
            return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
        case LM_GGML_OP_MUL_MAT:
        case LM_GGML_OP_MUL_MAT_ID:
            return ctx->support_simdgroup_reduction &&
                (op->src[0]->type != LM_GGML_TYPE_F32 || op->src[1]->type == LM_GGML_TYPE_F32);
        case LM_GGML_OP_CPY:
        case LM_GGML_OP_DUP:
        case LM_GGML_OP_CONT:
            {
                switch (op->src[0]->type) {
                    case LM_GGML_TYPE_F32:
                        switch (op->type) {
                            case LM_GGML_TYPE_F32:
                            case LM_GGML_TYPE_F16:
                            case LM_GGML_TYPE_Q8_0:
                            case LM_GGML_TYPE_Q4_0:
                            case LM_GGML_TYPE_Q4_1:
                            case LM_GGML_TYPE_Q5_0:
                            case LM_GGML_TYPE_Q5_1:
                            case LM_GGML_TYPE_IQ4_NL:
                                return true;
                            default:
                                return false;
                        }
                    case LM_GGML_TYPE_F16:
                        switch (op->type) {
                            case LM_GGML_TYPE_F32:
                            case LM_GGML_TYPE_F16:
                                return true;
                            default:
                                return false;
                        }
                    default:
                        return false;
                };
            }
        case LM_GGML_OP_DIAG_MASK_INF:
        case LM_GGML_OP_GET_ROWS:
            {
                return op->ne[3] == 1;
            }
        default:
            return false;
    }
}
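A hedged sketch of how a caller might gate dispatch on this predicate (the graph-compute loop below does exactly this check, but asserts instead of falling back):

    if (!lm_ggml_metal_supports_op(ctx, dst)) {
        // rejected cases include any BF16 source, POOL_1D/POOL_2D,
        // and FLASH_ATTN_EXT with head size 256 (see the disabled H256 kernels above)
        LM_GGML_METAL_LOG_ERROR("%s: unsupported op '%s'\n", __func__, lm_ggml_op_desc(dst));
        return LM_GGML_STATUS_FAILED;  // hypothetical fallback path, not what this file does
    }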
static enum lm_ggml_status lm_ggml_metal_graph_compute(
        struct lm_ggml_metal_context * ctx,
        struct lm_ggml_cgraph * gf) {

    @autoreleasepool {
        MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
        edesc.dispatchType = MTLDispatchTypeSerial;

        // create multiple command buffers and enqueue them
        // then, we encode the graph into the command buffers in parallel

        const int n_nodes = gf->n_nodes;
        const int n_cb = ctx->n_cb;
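        // (n_nodes + n_cb - 1) / n_cb is ceiling division: e.g. 10 nodes over 3
        // command buffers gives 4 nodes per buffer, with the last buffer taking
        // the remaining 2 (node_end is clamped to n_nodes below).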
860
+ const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
861
+
862
+ const bool should_capture = ctx->should_capture_next_compute;
863
+ if (should_capture) {
864
+ ctx->should_capture_next_compute = false;
865
+
866
+ MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
867
+ descriptor.captureObject = ctx->queue;
868
+
869
+ NSError * error = nil;
870
+ if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
871
+ LM_GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
872
+ LM_GGML_ASSERT(!"capture failed");
873
+ }
874
+ }
875
+
876
+ id<MTLCommandBuffer> command_buffer_builder[n_cb];
877
+ for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
878
+ id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
879
+ command_buffer_builder[cb_idx] = command_buffer;
880
+
881
+ // enqueue the command buffers in order to specify their execution order
882
+ [command_buffer enqueue];
883
+ }
884
+
885
+ const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
886
+
887
+ dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
888
+ const int cb_idx = iter;
889
+
890
+ size_t offs_src0 = 0;
891
+ size_t offs_src1 = 0;
892
+ size_t offs_src2 = 0;
893
+ size_t offs_dst = 0;
894
+
895
+ id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
896
+ id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
897
+
898
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
899
+ const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
900
+
901
+ for (int i = node_start; i < node_end; ++i) {
902
+ if (i == -1) {
903
+ [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
904
+ continue;
905
+ }
906
+
907
+ //LM_GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, lm_ggml_op_name(gf->nodes[i]->op));
908
+
909
+ struct lm_ggml_tensor * src0 = gf->nodes[i]->src[0];
910
+ struct lm_ggml_tensor * src1 = gf->nodes[i]->src[1];
911
+ struct lm_ggml_tensor * src2 = gf->nodes[i]->src[2];
912
+ struct lm_ggml_tensor * dst = gf->nodes[i];
913
+
914
+ if (lm_ggml_is_empty(dst)) {
915
+ continue;
916
+ }
917
+
918
+ switch (dst->op) {
919
+ case LM_GGML_OP_NONE:
920
+ case LM_GGML_OP_RESHAPE:
921
+ case LM_GGML_OP_VIEW:
922
+ case LM_GGML_OP_TRANSPOSE:
923
+ case LM_GGML_OP_PERMUTE:
924
+ {
925
+ // noop -> next node
926
+ } continue;
927
+ default:
928
+ {
929
+ } break;
930
+ }
931
+
932
+ if (!lm_ggml_metal_supports_op(ctx, dst)) {
933
+ LM_GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, lm_ggml_op_desc(dst));
934
+ LM_GGML_ASSERT(!"unsupported op");
935
+ }
936
+
937
+ if (should_capture) {
938
+ [encoder pushDebugGroup:[NSString stringWithCString:lm_ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
939
+ }
940
+
941
+ const int64_t ne00 = src0 ? src0->ne[0] : 0;
942
+ const int64_t ne01 = src0 ? src0->ne[1] : 0;
943
+ const int64_t ne02 = src0 ? src0->ne[2] : 0;
944
+ const int64_t ne03 = src0 ? src0->ne[3] : 0;
945
+
946
+ const uint64_t nb00 = src0 ? src0->nb[0] : 0;
947
+ const uint64_t nb01 = src0 ? src0->nb[1] : 0;
948
+ const uint64_t nb02 = src0 ? src0->nb[2] : 0;
949
+ const uint64_t nb03 = src0 ? src0->nb[3] : 0;
950
+
951
+ const int64_t ne10 = src1 ? src1->ne[0] : 0;
952
+ const int64_t ne11 = src1 ? src1->ne[1] : 0;
953
+ const int64_t ne12 = src1 ? src1->ne[2] : 0;
954
+ const int64_t ne13 = src1 ? src1->ne[3] : 0;
955
+
956
+ const uint64_t nb10 = src1 ? src1->nb[0] : 0;
957
+ const uint64_t nb11 = src1 ? src1->nb[1] : 0;
958
+ const uint64_t nb12 = src1 ? src1->nb[2] : 0;
959
+ const uint64_t nb13 = src1 ? src1->nb[3] : 0;
960
+
961
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
962
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
963
+ const int64_t ne22 = src2 ? src2->ne[2] : 0; LM_GGML_UNUSED(ne22);
964
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; LM_GGML_UNUSED(ne23);
965
+
966
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; LM_GGML_UNUSED(nb20);
967
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
968
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
969
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0;
970
+
971
+ const int64_t ne0 = dst ? dst->ne[0] : 0;
972
+ const int64_t ne1 = dst ? dst->ne[1] : 0;
973
+ const int64_t ne2 = dst ? dst->ne[2] : 0;
974
+ const int64_t ne3 = dst ? dst->ne[3] : 0;
975
+
976
+ const uint64_t nb0 = dst ? dst->nb[0] : 0;
977
+ const uint64_t nb1 = dst ? dst->nb[1] : 0;
978
+ const uint64_t nb2 = dst ? dst->nb[2] : 0;
979
+ const uint64_t nb3 = dst ? dst->nb[3] : 0;
980
+
981
+ const enum lm_ggml_type src0t = src0 ? src0->type : LM_GGML_TYPE_COUNT;
982
+ const enum lm_ggml_type src1t = src1 ? src1->type : LM_GGML_TYPE_COUNT;
983
+ const enum lm_ggml_type dstt = dst ? dst->type : LM_GGML_TYPE_COUNT;
984
+
985
+ id<MTLBuffer> id_src0 = src0 ? lm_ggml_metal_get_buffer(src0, &offs_src0) : nil;
986
+ id<MTLBuffer> id_src1 = src1 ? lm_ggml_metal_get_buffer(src1, &offs_src1) : nil;
987
+ id<MTLBuffer> id_src2 = src2 ? lm_ggml_metal_get_buffer(src2, &offs_src2) : nil;
988
+ id<MTLBuffer> id_dst = dst ? lm_ggml_metal_get_buffer(dst, &offs_dst) : nil;
989
+
990
+ //LM_GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, lm_ggml_op_name(dst->op));
991
+ //if (src0) {
992
+ // LM_GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, lm_ggml_type_name(src0t), ne00, ne01, ne02,
993
+ // lm_ggml_is_contiguous(src0), src0->name);
994
+ //}
995
+ //if (src1) {
996
+ // LM_GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, lm_ggml_type_name(src1t), ne10, ne11, ne12,
997
+ // lm_ggml_is_contiguous(src1), src1->name);
998
+ //}
999
+ //if (dst) {
1000
+ // LM_GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, lm_ggml_type_name(dstt), ne0, ne1, ne2,
1001
+ // dst->name);
1002
+ //}
1003
+
1004
+ switch (dst->op) {
1005
+ case LM_GGML_OP_CONCAT:
1006
+ {
1007
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
1008
+
1009
+ const int32_t dim = ((int32_t *) dst->op_params)[0];
1010
+
1011
+ [encoder setComputePipelineState:pipeline];
1012
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1013
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1014
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1015
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1016
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1017
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1018
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1019
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1020
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
1021
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
1022
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
1023
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1024
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1025
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1026
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1027
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1028
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1029
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1030
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1031
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1032
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1033
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1034
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1035
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1036
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1037
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1038
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1039
+ [encoder setBytes:&dim length:sizeof(dim) atIndex:27];
1040
+
1041
+ const int nth = MIN(1024, ne0);
1042
+
1043
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1044
+ } break;
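// Thread mapping for the CONCAT dispatch above, as a sketch (the kernel body
// lives in ggml-metal.metal, which is not part of this diff): one threadgroup
// per output (i1, i2, i3) coordinate, with up to 1024 threads striding along
// dim 0; each thread picks its source by comparing its coordinate along `dim`
// with src0's extent:
//
//     // dst[i3][i2][i1][i0] = i_dim < src0->ne[dim]
//     //     ? src0[i3][i2][i1][i0]
//     //     : src1[... with i_dim - src0->ne[dim] along dim ...]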
1045
+ case LM_GGML_OP_ADD:
1046
+ case LM_GGML_OP_MUL:
1047
+ case LM_GGML_OP_DIV:
1048
+ {
1049
+ LM_GGML_ASSERT(src0t == LM_GGML_TYPE_F32);
1050
+ LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
1051
+
1052
+ const size_t offs = 0;
1053
+
1054
+ bool bcast_row = false;
1055
+
1056
+ int64_t nb = ne00; // used by the "row" kernels
1057
+
1058
+ id<MTLComputePipelineState> pipeline = nil;
1059
+
1060
+ if (lm_ggml_nelements(src1) == ne10 && lm_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
1061
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(src0));
1062
+
1063
+ // src1 is a row
1064
+ LM_GGML_ASSERT(ne11 == 1);
1065
+
1066
+ nb = ne00 / 4;
1067
+ switch (dst->op) {
1068
+ case LM_GGML_OP_ADD: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
1069
+ case LM_GGML_OP_MUL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
1070
+ case LM_GGML_OP_DIV: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
1071
+ default: LM_GGML_ASSERT(false);
1072
+ }
1073
+
1074
+ bcast_row = true;
1075
+ } else {
1076
+ switch (dst->op) {
1077
+ case LM_GGML_OP_ADD: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
1078
+ case LM_GGML_OP_MUL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
1079
+ case LM_GGML_OP_DIV: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
1080
+ default: LM_GGML_ASSERT(false);
1081
+ }
1082
+ }
1083
+
1084
+ [encoder setComputePipelineState:pipeline];
1085
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1086
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1087
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1088
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1089
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1090
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1091
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1092
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1093
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
1094
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
1095
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
1096
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1097
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1098
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1099
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1100
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1101
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1102
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1103
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1104
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1105
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1106
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1107
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1108
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1109
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1110
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1111
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1112
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1113
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
1114
+
1115
+ if (bcast_row) {
1116
+ const int64_t n = lm_ggml_nelements(dst)/4;
1117
+
1118
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1119
+ } else {
1120
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1121
+
1122
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1123
+ }
1124
+ } break;
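// The "_ROW" fast path above broadcasts a single contiguous src1 row over every
// row of src0; scalar reference semantics (a sketch of the assumed kernel behavior):
//
//     // dst[i] = src0[i] OP src1[i % ne10]    -- requires ne00 % 4 == 0 and ne10 % 4 == 0
//
// so the float4 kernels see nb = ne00/4 vectors per row, and one single-thread
// threadgroup is launched per float4 of dst (lm_ggml_nelements(dst)/4 in total).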
1125
+ case LM_GGML_OP_REPEAT:
1126
+ {
1127
+ id<MTLComputePipelineState> pipeline;
1128
+
1129
+ switch (src0t) {
1130
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
1131
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
1132
+ case LM_GGML_TYPE_I32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
1133
+ case LM_GGML_TYPE_I16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
1134
+ default: LM_GGML_ASSERT(false);
1135
+ }
1136
+
1137
+ [encoder setComputePipelineState:pipeline];
1138
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1139
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1140
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1141
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1142
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1143
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1144
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1145
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1146
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1147
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1148
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
1149
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
1150
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
1151
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
1152
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
1153
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1154
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
1155
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
1156
+
1157
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1158
+
1159
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1160
+ } break;
1161
+ case LM_GGML_OP_ACC:
1162
+ {
1163
+ LM_GGML_ASSERT(src0t == LM_GGML_TYPE_F32);
1164
+ LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
1165
+ LM_GGML_ASSERT(dstt == LM_GGML_TYPE_F32);
1166
+
1167
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(src0));
1168
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(src1));
1169
+
1170
+ const size_t pnb1 = ((int32_t *) dst->op_params)[0];
1171
+ const size_t pnb2 = ((int32_t *) dst->op_params)[1];
1172
+ const size_t pnb3 = ((int32_t *) dst->op_params)[2];
1173
+ const size_t offs = ((int32_t *) dst->op_params)[3];
1174
+
1175
+ const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
1176
+
1177
+ if (!inplace) {
1178
+ // run a separate kernel to copy src->dst
1179
+ // not sure how to avoid this
1180
+ // TODO: make a simpler cpy_bytes kernel
1181
+
1182
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
1183
+
1184
+ [encoder setComputePipelineState:pipeline];
1185
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1186
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1187
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1188
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1189
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1190
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1191
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1192
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1193
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1194
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1195
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1196
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1197
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1198
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1199
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1200
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1201
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1202
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1203
+
1204
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
1205
+
1206
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1207
+ }
1208
+
1209
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ADD].pipeline;
1210
+
1211
+ [encoder setComputePipelineState:pipeline];
1212
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1213
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1214
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1215
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1216
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1217
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1218
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1219
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1220
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
1221
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
1222
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
1223
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1224
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1225
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1226
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1227
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1228
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1229
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1230
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1231
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1232
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1233
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1234
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1235
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1236
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
1237
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
1238
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
1239
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1240
+
1241
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
1242
+
1243
+ [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1244
+ } break;
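// LM_GGML_OP_ACC accumulates src1 into a *view* of src0 described by op_params:
// byte strides pnb1/pnb2/pnb3 plus a byte offset offs. Sketch of the assumed semantics:
//
//     // if (!inplace) dst = src0;                      // the CPY pass above
//     // view(dst, offs, pnb1, pnb2, pnb3) += src1;     // the ADD pass
//
// which is why the ADD kernel is re-dispatched over src1's shape (ne11, ne12, ne13)
// with the view strides substituted for dst's natural nb1/nb2/nb3.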
1245
+ case LM_GGML_OP_SCALE:
1246
+ {
1247
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(src0));
1248
+
1249
+ float scale;
1250
+ memcpy(&scale, dst->op_params, sizeof(scale));
1251
+
1252
+ int64_t n = lm_ggml_nelements(dst);
1253
+
1254
+ id<MTLComputePipelineState> pipeline = nil;
1255
+
1256
+ if (n % 4 == 0) {
1257
+ n /= 4;
1258
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
1259
+ } else {
1260
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
1261
+ }
1262
+
1263
+ [encoder setComputePipelineState:pipeline];
1264
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1265
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1266
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
1267
+
1268
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1269
+ } break;
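// SCALE is a pure elementwise multiply by a scalar read out of op_params:
//
//     // dst[i] = src0[i] * scale
//
// The SCALE_4 variant processes float4 lanes, hence n /= 4 when the element
// count is divisible by 4; either way one single-thread threadgroup covers each
// (vectorized) element.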
1270
+ case LM_GGML_OP_CLAMP:
1271
+ {
1272
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
1273
+
1274
+ float min;
1275
+ float max;
1276
+ memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
1277
+ memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
1278
+
1279
+ [encoder setComputePipelineState:pipeline];
1280
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1281
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1282
+ [encoder setBytes:&min length:sizeof(min) atIndex:2];
1283
+ [encoder setBytes:&max length:sizeof(max) atIndex:3];
1284
+
1285
+ const int64_t n = lm_ggml_nelements(dst);
1286
+
1287
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1288
+ } break;
1289
+ case LM_GGML_OP_UNARY:
1290
+ // we are not taking into account the strides, so for now require contiguous tensors
1291
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(src0));
1292
+ switch (lm_ggml_get_unary_op(gf->nodes[i])) {
1293
+
1294
+ case LM_GGML_UNARY_OP_TANH:
1295
+ {
1296
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_TANH].pipeline;
1297
+
1298
+ [encoder setComputePipelineState:pipeline];
1299
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1300
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1301
+
1302
+ const int64_t n = lm_ggml_nelements(dst);
1303
+
1304
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1305
+ } break;
1306
+ case LM_GGML_UNARY_OP_RELU:
1307
+ {
1308
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_RELU].pipeline;
1309
+
1310
+ [encoder setComputePipelineState:pipeline];
1311
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1312
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1313
+
1314
+ const int64_t n = lm_ggml_nelements(dst);
1315
+
1316
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1317
+ } break;
1318
+ case LM_GGML_UNARY_OP_SIGMOID:
1319
+ {
1320
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
1321
+
1322
+ [encoder setComputePipelineState:pipeline];
1323
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1324
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1325
+
1326
+ const int64_t n = lm_ggml_nelements(dst);
1327
+
1328
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1329
+ } break;
1330
+ case LM_GGML_UNARY_OP_GELU:
1331
+ {
1332
+ int64_t n = lm_ggml_nelements(dst);
1333
+
1334
+ id<MTLComputePipelineState> pipeline = nil;
1335
+
1336
+ if (n % 4 == 0) {
1337
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
1338
+ n /= 4;
1339
+ } else {
1340
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GELU].pipeline;
1341
+ }
1342
+
1343
+ [encoder setComputePipelineState:pipeline];
1344
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1345
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1346
+
1347
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1348
+ } break;
1349
+ case LM_GGML_UNARY_OP_GELU_QUICK:
1350
+ {
1351
+ int64_t n = lm_ggml_nelements(dst);
1352
+
1353
+ id<MTLComputePipelineState> pipeline = nil;
1354
+
1355
+ if (n % 4 == 0) {
1356
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
1357
+ n /= 4;
1358
+ } else {
1359
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
1360
+ }
1361
+
1362
+ [encoder setComputePipelineState:pipeline];
1363
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1364
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1365
+
1366
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1367
+ } break;
1368
+ case LM_GGML_UNARY_OP_SILU:
1369
+ {
1370
+ int64_t n = lm_ggml_nelements(dst);
1371
+
1372
+ id<MTLComputePipelineState> pipeline = nil;
1373
+
1374
+ if (n % 4 == 0) {
1375
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
1376
+ n /= 4;
1377
+ } else {
1378
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SILU].pipeline;
1379
+ }
1380
+
1381
+ [encoder setComputePipelineState:pipeline];
1382
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1383
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1384
+
1385
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1386
+ } break;
1387
+ default:
1388
+ {
1389
+ LM_GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, lm_ggml_op_name(dst->op));
1390
+ LM_GGML_ASSERT(false);
1391
+ }
1392
+ } break;
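// Worked formulas for the unary kernels above (assumed to match the kernels in
// ggml-metal.metal, which is not shown in this diff):
//
//     SILU(x)       = x * sigmoid(x)
//     GELU_QUICK(x) = x * sigmoid(1.702 * x)
//     GELU(x)       = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
//
// The *_4 variants apply the same function to float4 lanes, which is why n is
// divided by 4 whenever lm_ggml_nelements(dst) is a multiple of 4.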
1393
+ case LM_GGML_OP_SQR:
1394
+ {
1395
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(src0));
1396
+
1397
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SQR].pipeline;
1398
+
1399
+ [encoder setComputePipelineState:pipeline];
1400
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1401
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1402
+
1403
+ const int64_t n = lm_ggml_nelements(dst);
1404
+
1405
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1406
+ } break;
1407
+ case LM_GGML_OP_SUM_ROWS:
1408
+ {
1409
+ LM_GGML_ASSERT(src0->nb[0] == lm_ggml_type_size(src0->type));
1410
+
1411
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
1412
+
1413
+ [encoder setComputePipelineState:pipeline];
1414
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1415
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1416
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1417
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1418
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1419
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1420
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1421
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1422
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1423
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1424
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1425
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1426
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1427
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1428
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1429
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1430
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1431
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
1432
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1433
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
1434
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
1435
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
1436
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
1437
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
1438
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
1439
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
1440
+
1441
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1442
+ } break;
1443
+ case LM_GGML_OP_SOFT_MAX:
1444
+ {
1445
+ LM_GGML_ASSERT(!src1 || src1->type == LM_GGML_TYPE_F16 || src1->type == LM_GGML_TYPE_F32);
1446
+
1447
+ int nth = 32; // SIMD width
1448
+
1449
+ id<MTLComputePipelineState> pipeline = nil;
1450
+
1451
+ const bool use_f16 = (src1 && src1->type == LM_GGML_TYPE_F16);
1452
+
1453
+ if (ne00%4 == 0) {
1454
+ while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
1455
+ nth *= 2;
1456
+ }
1457
+ if (use_f16) {
1458
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
1459
+ } else {
1460
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
1461
+ }
1462
+ } else {
1463
+ while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
1464
+ nth *= 2;
1465
+ }
1466
+ if (use_f16) {
1467
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
1468
+ } else {
1469
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
1470
+ }
1471
+ }
1472
+
1473
+ float scale;
1474
+ float max_bias;
1475
+
1476
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
1477
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
1478
+
1479
+ const int64_t nrows_x = lm_ggml_nrows(src0);
1480
+ const int64_t nrows_y = src0->ne[1];
1481
+
1482
+ const uint32_t n_head = nrows_x/nrows_y;
1483
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
1484
+
1485
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
1486
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1487
+
1488
+ [encoder setComputePipelineState:pipeline];
1489
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1490
+ if (id_src1) {
1491
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1492
+ } else {
1493
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1494
+ }
1495
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1496
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1497
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1498
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1499
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1500
+ [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
1501
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
1502
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
1503
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
1504
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1505
+
1506
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1507
+ } break;
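// The m0/m1 values computed above are the ALiBi slope parameters; per head h the
// kernel is assumed to derive (matching the usual ggml formulation):
//
//     n_head_log2 = 2^floor(log2(n_head))
//     slope(h) = pow(m0, h + 1)                    for h <  n_head_log2
//     slope(h) = pow(m1, 2*(h - n_head_log2) + 1)  for h >= n_head_log2
//
// with m0 = 2^(-max_bias / n_head_log2) and m1 = 2^(-max_bias / (2*n_head_log2)),
// so max_bias == 0 gives slope == 1 for every head (no positional bias).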
1508
+ case LM_GGML_OP_DIAG_MASK_INF:
1509
+ {
1510
+ const int n_past = ((int32_t *)(dst->op_params))[0];
1511
+
1512
+ id<MTLComputePipelineState> pipeline = nil;
1513
+
1514
+ if (ne00%8 == 0) {
1515
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
1516
+ } else {
1517
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
1518
+ }
1519
+
1520
+ [encoder setComputePipelineState:pipeline];
1521
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1522
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1523
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1524
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1525
+ [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
1526
+
1527
+ if (ne00%8 == 0) {
1528
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1529
+ }
1530
+ else {
1531
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1532
+ }
1533
+ } break;
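// DIAG_MASK_INF is the causal-attention mask: entries to the right of the
// diagonal shifted by n_past become -INFINITY. Reference semantics (a sketch):
//
//     // dst[i01][i00] = (i00 > n_past + i01) ? -INFINITY : src0[i01][i00]
//
// The _8 variant handles 8 floats per thread, which matches the flattened
// ne00*ne01*ne02/8 grid dispatched above.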
1534
+ case LM_GGML_OP_MUL_MAT:
1535
+ {
1536
+ LM_GGML_ASSERT(ne00 == ne10);
1537
+
1538
+ LM_GGML_ASSERT(ne12 % ne02 == 0);
1539
+ LM_GGML_ASSERT(ne13 % ne03 == 0);
1540
+
1541
+ const uint r2 = ne12/ne02;
1542
+ const uint r3 = ne13/ne03;
1543
+
1544
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1545
+ // to the matrix-vector kernel
1546
+ int ne11_mm_min = 1;
1547
+
1548
+ #if 0
1549
+ // the numbers below are measured on M2 Ultra for 7B and 13B models
1550
+ // these numbers do not translate to other devices or model sizes
1551
+ // TODO: need to find a better approach
1552
+ if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
1553
+ switch (src0t) {
1554
+ case LM_GGML_TYPE_F16: ne11_mm_min = 2; break;
1555
+ case LM_GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1556
+ case LM_GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1557
+ case LM_GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1558
+ case LM_GGML_TYPE_Q4_0:
1559
+ case LM_GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1560
+ case LM_GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1561
+ case LM_GGML_TYPE_Q5_0: // not tested yet
1562
+ case LM_GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1563
+ case LM_GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1564
+ case LM_GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1565
+ default: ne11_mm_min = 1; break;
1566
+ }
1567
+ }
1568
+ #endif
1569
+
1570
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1571
+ // AMD GPUs and older A-chips will reuse the matrix-vector multiplication kernel
1572
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1573
+ !lm_ggml_is_transposed(src0) &&
1574
+ !lm_ggml_is_transposed(src1) &&
1575
+ src1t == LM_GGML_TYPE_F32 &&
1576
+ ne00 % 32 == 0 && ne00 >= 64 &&
1577
+ (ne11 > ne11_mm_min || (lm_ggml_is_quantized(src0t) && ne12 > 1))) {
1578
+ //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1579
+
1580
+ // some Metal matrix data types require aligned pointers
1581
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1582
+ switch (src0->type) {
1583
+ case LM_GGML_TYPE_F32: LM_GGML_ASSERT(nb01 % 16 == 0); break;
1584
+ case LM_GGML_TYPE_F16: LM_GGML_ASSERT(nb01 % 8 == 0); break;
1585
+ default: break;
1586
+ }
1587
+
1588
+ id<MTLComputePipelineState> pipeline = nil;
1589
+
1590
+ switch (src0->type) {
1591
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
1592
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
1593
+ case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
1594
+ case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
1595
+ case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
1596
+ case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
1597
+ case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
1598
+ case LM_GGML_TYPE_Q2_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
1599
+ case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
1600
+ case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
1601
+ case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
1602
+ case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
1603
+ case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
1604
+ case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
1605
+ case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
1606
+ case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
1607
+ case LM_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
1608
+ case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
1609
+ case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
1610
+ case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
1611
+ case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
1612
+ default: LM_GGML_ASSERT(false && "MUL MAT-MAT not implemented");
1613
+ }
1614
+
1615
+ [encoder setComputePipelineState:pipeline];
1616
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1617
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1618
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1619
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1620
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1621
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
1622
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
1623
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1624
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1625
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1626
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1627
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1628
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1629
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1630
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1631
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1632
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1633
+ } else {
1634
+ int nth0 = 32;
1635
+ int nth1 = 1;
1636
+ int nrows = 1;
1637
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1638
+
1639
+ id<MTLComputePipelineState> pipeline = nil;
1640
+
1641
+ // use custom matrix x vector kernel
1642
+ switch (src0t) {
1643
+ case LM_GGML_TYPE_F32:
1644
+ {
1645
+ LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
1646
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
1647
+ nrows = 4;
1648
+ } break;
1649
+ case LM_GGML_TYPE_F16:
1650
+ {
1651
+ nth0 = 32;
1652
+ nth1 = 1;
1653
+ if (src1t == LM_GGML_TYPE_F32) {
1654
+ if (ne11 * ne12 < 4) {
1655
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
1656
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1657
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
1658
+ nrows = ne11;
1659
+ } else {
1660
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
1661
+ nrows = 4;
1662
+ }
1663
+ } else {
1664
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
1665
+ nrows = 4;
1666
+ }
1667
+ } break;
1668
+ case LM_GGML_TYPE_Q4_0:
1669
+ {
1670
+ nth0 = 8;
1671
+ nth1 = 8;
1672
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
1673
+ } break;
1674
+ case LM_GGML_TYPE_Q4_1:
1675
+ {
1676
+ nth0 = 8;
1677
+ nth1 = 8;
1678
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
1679
+ } break;
1680
+ case LM_GGML_TYPE_Q5_0:
1681
+ {
1682
+ nth0 = 8;
1683
+ nth1 = 8;
1684
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
1685
+ } break;
1686
+ case LM_GGML_TYPE_Q5_1:
1687
+ {
1688
+ nth0 = 8;
1689
+ nth1 = 8;
1690
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
1691
+ } break;
1692
+ case LM_GGML_TYPE_Q8_0:
1693
+ {
1694
+ nth0 = 8;
1695
+ nth1 = 8;
1696
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
1697
+ } break;
1698
+ case LM_GGML_TYPE_Q2_K:
1699
+ {
1700
+ nth0 = 2;
1701
+ nth1 = 32;
1702
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
1703
+ } break;
1704
+ case LM_GGML_TYPE_Q3_K:
1705
+ {
1706
+ nth0 = 2;
1707
+ nth1 = 32;
1708
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
1709
+ } break;
1710
+ case LM_GGML_TYPE_Q4_K:
1711
+ {
1712
+ nth0 = 4; //1;
1713
+ nth1 = 8; //32;
1714
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
1715
+ } break;
1716
+ case LM_GGML_TYPE_Q5_K:
1717
+ {
1718
+ nth0 = 2;
1719
+ nth1 = 32;
1720
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
1721
+ } break;
1722
+ case LM_GGML_TYPE_Q6_K:
1723
+ {
1724
+ nth0 = 2;
1725
+ nth1 = 32;
1726
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
1727
+ } break;
1728
+ case LM_GGML_TYPE_IQ2_XXS:
1729
+ {
1730
+ nth0 = 4;
1731
+ nth1 = 16;
1732
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
1733
+ } break;
1734
+ case LM_GGML_TYPE_IQ2_XS:
1735
+ {
1736
+ nth0 = 4;
1737
+ nth1 = 16;
1738
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
1739
+ } break;
1740
+ case LM_GGML_TYPE_IQ3_XXS:
1741
+ {
1742
+ nth0 = 4;
1743
+ nth1 = 16;
1744
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
1745
+ } break;
1746
+ case LM_GGML_TYPE_IQ3_S:
1747
+ {
1748
+ nth0 = 4;
1749
+ nth1 = 16;
1750
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
1751
+ } break;
1752
+ case LM_GGML_TYPE_IQ2_S:
1753
+ {
1754
+ nth0 = 4;
1755
+ nth1 = 16;
1756
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
1757
+ } break;
1758
+ case LM_GGML_TYPE_IQ1_S:
1759
+ {
1760
+ nth0 = 4;
1761
+ nth1 = 16;
1762
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
1763
+ } break;
1764
+ case LM_GGML_TYPE_IQ1_M:
1765
+ {
1766
+ nth0 = 4;
1767
+ nth1 = 16;
1768
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
1769
+ } break;
1770
+ case LM_GGML_TYPE_IQ4_NL:
1771
+ {
1772
+ nth0 = 4;
1773
+ nth1 = 16;
1774
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
1775
+ } break;
1776
+ case LM_GGML_TYPE_IQ4_XS:
1777
+ {
1778
+ nth0 = 4;
1779
+ nth1 = 16;
1780
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
1781
+ } break;
1782
+ default:
1783
+ {
1784
+ LM_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1785
+ LM_GGML_ASSERT(false && "not implemented");
1786
+ }
1787
+ }
1788
+
1789
+ if (lm_ggml_is_quantized(src0t)) {
1790
+ LM_GGML_ASSERT(ne00 >= nth0*nth1);
1791
+ }
1792
+
1793
+ [encoder setComputePipelineState:pipeline];
1794
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1795
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1796
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1797
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1798
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1799
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1800
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1801
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1802
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1803
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
1804
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
1805
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
1806
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
1807
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
1808
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
1809
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
1810
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1811
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1812
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1813
+
1814
+ if (src0t == LM_GGML_TYPE_Q4_0 || src0t == LM_GGML_TYPE_Q4_1 || src0t == LM_GGML_TYPE_Q5_0 ||
1815
+ src0t == LM_GGML_TYPE_Q5_1 || src0t == LM_GGML_TYPE_Q8_0 || src0t == LM_GGML_TYPE_Q2_K ||
1816
+ src0t == LM_GGML_TYPE_IQ1_S || src0t == LM_GGML_TYPE_IQ1_M || src0t == LM_GGML_TYPE_IQ2_S) {
1817
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1818
+ }
1819
+ else if (src0t == LM_GGML_TYPE_IQ2_XXS || src0t == LM_GGML_TYPE_IQ2_XS) {
1820
+ const int mem_size = src0t == LM_GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1821
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1822
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1823
+ }
1824
+ else if (src0t == LM_GGML_TYPE_IQ3_XXS || src0t == LM_GGML_TYPE_IQ3_S) {
1825
+ const int mem_size = src0t == LM_GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
1826
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1827
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1828
+ }
1829
+ else if (src0t == LM_GGML_TYPE_IQ4_NL || src0t == LM_GGML_TYPE_IQ4_XS) {
1830
+ const int mem_size = 32*sizeof(float);
1831
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1832
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1833
+ }
1834
+ else if (src0t == LM_GGML_TYPE_Q4_K) {
1835
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1836
+ }
1837
+ else if (src0t == LM_GGML_TYPE_Q3_K) {
1838
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1839
+ }
1840
+ else if (src0t == LM_GGML_TYPE_Q5_K) {
1841
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1842
+ }
1843
+ else if (src0t == LM_GGML_TYPE_Q6_K) {
1844
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1845
+ } else {
1846
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
1847
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1848
+ }
1849
+ }
1850
+ } break;
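// Dispatch arithmetic for the two MUL_MAT paths above, assuming the tile sizes
// of the simdgroup mat-mat kernels in ggml-metal.metal (not part of this diff):
// each 128-thread threadgroup produces a 64x32 tile of dst using the 8 KiB of
// threadgroup memory set above, so e.g.
//
//     // ne01 = 4096, ne11 = 512, ne12*ne13 = 1:
//     // threadgroups = ((512 + 31)/32, (4096 + 63)/64, 1) = (16, 64, 1)
//
// while the mat-vec fallback assigns nth0 x nth1 threads per threadgroup and a
// type-dependent number of dst rows per threadgroup (8, 4, 2, or nrows, per the
// branches above).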
1851
+ case LM_GGML_OP_MUL_MAT_ID:
1852
+ {
1853
+ const int n_as = src0->ne[2];
1854
+
1855
+ // src2 = ids
1856
+ const enum lm_ggml_type src2t = src2->type; LM_GGML_UNUSED(src2t);
1857
+
1858
+ LM_GGML_ASSERT(src2t == LM_GGML_TYPE_I32);
1859
+
1860
+ LM_GGML_ASSERT(!lm_ggml_is_transposed(src0));
1861
+ LM_GGML_ASSERT(!lm_ggml_is_transposed(src1));
1862
+
1863
+ LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
1864
+
1865
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1866
+ // to the matrix-vector kernel
1867
+ // ne20 = n_used_experts
1868
+ // ne21 = n_rows
1869
+ const int dst_rows = ne20*ne21;
1870
+ const int dst_rows_min = n_as;
1871
+ const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4;
1872
+
1873
+ // max size of the rowids array in the kernel shared buffer
1874
+ LM_GGML_ASSERT(dst_rows <= dst_rows_max);
1875
+
1876
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1877
+ // AMD GPUs and older A-chips will reuse the matrix-vector multiplication kernel
1878
+ // !!!
1879
+ // TODO: the indirect matrix multiplication kernel still needs tuning, so the
1880
+ //       mat-mat path below is only taken when there are enough rows (dst_rows > dst_rows_min)
1881
+ // !!!
1882
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1883
+ ne00 % 32 == 0 && ne00 >= 64 &&
1884
+ dst_rows > dst_rows_min) {
1885
+
1886
+ // some Metal matrix data types require aligned pointers
1887
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1888
+ switch (src0->type) {
1889
+ case LM_GGML_TYPE_F32: LM_GGML_ASSERT(nb01 % 16 == 0); break;
1890
+ case LM_GGML_TYPE_F16: LM_GGML_ASSERT(nb01 % 8 == 0); break;
1891
+ default: break;
1892
+ }
1893
+
1894
+ id<MTLComputePipelineState> pipeline = nil;
1895
+
1896
+ switch (src0->type) {
1897
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
1898
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
1899
+ case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
1900
+ case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
1901
+ case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
1902
+ case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
1903
+ case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
1904
+ case LM_GGML_TYPE_Q2_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
1905
+ case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
1906
+ case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
1907
+ case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
1908
+ case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
1909
+ case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
1910
+ case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
1911
+ case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
1912
+ case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
1913
+ case LM_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
1914
+ case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
1915
+ case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
1916
+ case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
1917
+ case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
1918
+ default: LM_GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1919
+ }
1920
+
1921
+ [encoder setComputePipelineState:pipeline];
1922
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1923
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1924
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1925
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
1926
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1927
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1928
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1929
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
1930
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
1931
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
1932
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
1933
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1934
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1935
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1936
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1937
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1938
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1939
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1940
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
1941
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1942
+
1943
+ [encoder setThreadgroupMemoryLength:LM_GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
1944
+
1945
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1946
+ } else {
1947
+ int nth0 = 32;
1948
+ int nth1 = 1;
1949
+ int nrows = 1;
1950
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1951
+
1952
+ id<MTLComputePipelineState> pipeline = nil;
1953
+
1954
+ // use custom matrix x vector kernel
1955
+ switch (src0t) {
1956
+ case LM_GGML_TYPE_F32:
1957
+ {
1958
+ LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
1959
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
1960
+ } break;
1961
+ case LM_GGML_TYPE_F16:
1962
+ {
1963
+ LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
1964
+ nth0 = 32;
1965
+ nth1 = 1;
1966
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
1967
+ } break;
1968
+ case LM_GGML_TYPE_Q4_0:
1969
+ {
1970
+ nth0 = 8;
1971
+ nth1 = 8;
1972
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
1973
+ } break;
1974
+ case LM_GGML_TYPE_Q4_1:
1975
+ {
1976
+ nth0 = 8;
1977
+ nth1 = 8;
1978
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
1979
+ } break;
1980
+ case LM_GGML_TYPE_Q5_0:
1981
+ {
1982
+ nth0 = 8;
1983
+ nth1 = 8;
1984
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
1985
+ } break;
1986
+ case LM_GGML_TYPE_Q5_1:
1987
+ {
1988
+ nth0 = 8;
1989
+ nth1 = 8;
1990
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
1991
+ } break;
1992
+ case LM_GGML_TYPE_Q8_0:
1993
+ {
1994
+ nth0 = 8;
1995
+ nth1 = 8;
1996
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
1997
+ } break;
1998
+ case LM_GGML_TYPE_Q2_K:
1999
+ {
2000
+ nth0 = 2;
2001
+ nth1 = 32;
2002
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
2003
+ } break;
2004
+ case LM_GGML_TYPE_Q3_K:
2005
+ {
2006
+ nth0 = 2;
2007
+ nth1 = 32;
2008
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
2009
+ } break;
2010
+ case LM_GGML_TYPE_Q4_K:
2011
+ {
2012
+ nth0 = 4; //1;
2013
+ nth1 = 8; //32;
2014
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
2015
+ } break;
2016
+ case LM_GGML_TYPE_Q5_K:
2017
+ {
2018
+ nth0 = 2;
2019
+ nth1 = 32;
2020
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
2021
+ } break;
2022
+ case LM_GGML_TYPE_Q6_K:
2023
+ {
2024
+ nth0 = 2;
2025
+ nth1 = 32;
2026
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
2027
+ } break;
2028
+ case LM_GGML_TYPE_IQ2_XXS:
2029
+ {
2030
+ nth0 = 4;
2031
+ nth1 = 16;
2032
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
2033
+ } break;
2034
+ case LM_GGML_TYPE_IQ2_XS:
2035
+ {
2036
+ nth0 = 4;
2037
+ nth1 = 16;
2038
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
2039
+ } break;
2040
+ case LM_GGML_TYPE_IQ3_XXS:
2041
+ {
2042
+ nth0 = 4;
2043
+ nth1 = 16;
2044
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
2045
+ } break;
2046
+ case LM_GGML_TYPE_IQ3_S:
2047
+ {
2048
+ nth0 = 4;
2049
+ nth1 = 16;
2050
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
2051
+ } break;
2052
+ case LM_GGML_TYPE_IQ2_S:
2053
+ {
2054
+ nth0 = 4;
2055
+ nth1 = 16;
2056
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
2057
+ } break;
2058
+ case LM_GGML_TYPE_IQ1_S:
2059
+ {
2060
+ nth0 = 4;
2061
+ nth1 = 16;
2062
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
2063
+ } break;
2064
+ case LM_GGML_TYPE_IQ1_M:
2065
+ {
2066
+ nth0 = 4;
2067
+ nth1 = 16;
2068
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
2069
+ } break;
2070
+ case LM_GGML_TYPE_IQ4_NL:
2071
+ {
2072
+ nth0 = 4;
2073
+ nth1 = 16;
2074
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
2075
+ } break;
2076
+ case LM_GGML_TYPE_IQ4_XS:
2077
+ {
2078
+ nth0 = 4;
2079
+ nth1 = 16;
2080
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
2081
+ } break;
2082
+ default:
2083
+ {
2084
+ LM_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
2085
+ LM_GGML_ASSERT(false && "not implemented");
2086
+ }
2087
+ }
2088
+
2089
+ if (lm_ggml_is_quantized(src0t)) {
2090
+ LM_GGML_ASSERT(ne00 >= nth0*nth1);
2091
+ }
2092
+
2093
+ [encoder setComputePipelineState:pipeline];
2094
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2095
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2096
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2097
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
2098
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
2099
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
2100
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
2101
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
2102
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
2103
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
2104
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
2105
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
2106
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
2107
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
2108
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
2109
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
2110
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
2111
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
2112
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
2113
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
2114
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
2115
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
2116
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
2117
+
2118
+ const int64_t _ne1 = 1;
2119
+ const int tgz = dst_rows;
2120
+
2121
+ if (src0t == LM_GGML_TYPE_Q4_0 || src0t == LM_GGML_TYPE_Q4_1 || src0t == LM_GGML_TYPE_Q5_0 ||
2122
+ src0t == LM_GGML_TYPE_Q5_1 || src0t == LM_GGML_TYPE_Q8_0 || src0t == LM_GGML_TYPE_Q2_K ||
2123
+ src0t == LM_GGML_TYPE_IQ1_S || src0t == LM_GGML_TYPE_IQ1_M || src0t == LM_GGML_TYPE_IQ2_S) {
2124
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2125
+ }
2126
+ else if (src0t == LM_GGML_TYPE_IQ2_XXS || src0t == LM_GGML_TYPE_IQ2_XS) {
2127
+ const int mem_size = src0t == LM_GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
2128
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2129
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2130
+ }
2131
+ else if (src0t == LM_GGML_TYPE_IQ3_XXS || src0t == LM_GGML_TYPE_IQ3_S) {
2132
+ const int mem_size = src0t == LM_GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
2133
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2134
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2135
+ }
2136
+ else if (src0t == LM_GGML_TYPE_IQ4_NL || src0t == LM_GGML_TYPE_IQ4_XS) {
2137
+ const int mem_size = 32*sizeof(float);
2138
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2139
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2140
+ }
2141
+ else if (src0t == LM_GGML_TYPE_Q4_K) {
2142
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2143
+ }
2144
+ else if (src0t == LM_GGML_TYPE_Q3_K) {
2145
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2146
+ }
2147
+ else if (src0t == LM_GGML_TYPE_Q5_K) {
2148
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2149
+ }
2150
+ else if (src0t == LM_GGML_TYPE_Q6_K) {
2151
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2152
+ } else {
2153
+ const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
2154
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2155
+ }
2156
+ }
2157
+ } break;
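// The dst_rows_max bound above falls out of the shared-memory budget of the
// indirect mat-mat kernel: 8192 bytes of tile storage, dst_rows row ids of
// sizeof(ushort2) == 4 bytes each, and 32 bytes of slack must all fit:
//
//     // 8192 + 4*dst_rows + 32 <= maxThreadgroupMemoryLength
//     // => dst_rows_max = (maxThreadgroupMemoryLength - 32 - 8192)/4
//     // e.g. a 32 KiB limit gives (32768 - 32 - 8192)/4 = 6136 rows
//
// matching the setThreadgroupMemoryLength:LM_GGML_PAD(8192 + dst_rows*4, 16) call above.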
2158
+ case LM_GGML_OP_GET_ROWS:
2159
+ {
2160
+ id<MTLComputePipelineState> pipeline = nil;
2161
+
2162
+ switch (src0->type) {
2163
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
2164
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
2165
+ case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
2166
+ case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
2167
+ case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
2168
+ case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
2169
+ case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
2170
+ case LM_GGML_TYPE_Q2_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
2171
+ case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
2172
+ case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
2173
+ case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break;
2174
+ case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
2175
+ case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
2176
+ case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
2177
+ case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
2178
+ case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
2179
+ case LM_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
2180
+ case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
2181
+ case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break;
2182
+ case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
2183
+ case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
2184
+ case LM_GGML_TYPE_I32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
2185
+ default: LM_GGML_ASSERT(false && "not implemented");
2186
+ }
2187
+
2188
+ [encoder setComputePipelineState:pipeline];
2189
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2190
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2191
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2192
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
2193
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
2194
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
2195
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
2196
+ [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
2197
+ [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
2198
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
2199
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
2200
+
2201
+ [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
2202
+ } break;
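// GET_ROWS gathers (and, for quantized types, dequantizes) rows of src0 indexed
// by the i32 tensor src1; sketch of the assumed semantics:
//
//     // for each (i10, i11) in ne10 x ne11:
//     //     dst[i11][i10][:] = dequantize(src0[ src1[i11][i10] ][:])
//
// hence one 32-thread threadgroup per (index, batch) pair in the dispatch above.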
2203
+ case LM_GGML_OP_RMS_NORM:
+ {
+ LM_GGML_ASSERT(ne00 % 4 == 0);
+ LM_GGML_ASSERT(lm_ggml_is_contiguous_1(src0));
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ int nth = 32; // SIMD width
+
+ while (nth < ne00/4 && nth < 1024) {
+ nth *= 2;
+ }
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ const int64_t nrows = lm_ggml_nrows(src0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
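+ // note: eps is temporarily hardcoded to 1e-6f (see TODO below); dispatch is
+ // one threadgroup per group, 32 threads (a single simdgroup) each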
+ case LM_GGML_OP_GROUP_NORM:
+ {
+ LM_GGML_ASSERT(ne00 % 4 == 0);
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(src0));
+
+ //float eps;
+ //memcpy(&eps, dst->op_params, sizeof(float));
+
+ const float eps = 1e-6f; // TODO: temporarily hardcoded
+
+ const int32_t n_groups = ((int32_t *) dst->op_params)[0];
+
+ int nth = 32; // SIMD width
+
+ //while (nth < ne00/4 && nth < 1024) {
+ // nth *= 2;
+ //}
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:9];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
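+ // note: unlike RMS_NORM above, the thread count is simply capped at
+ // MIN(256, ne00) and the threadgroup memory is padded to a 16-byte multiple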
+ case LM_GGML_OP_NORM:
+ {
+ LM_GGML_ASSERT(lm_ggml_is_contiguous_1(src0));
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ const int nth = MIN(256, ne00);
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_NORM].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setThreadgroupMemoryLength:LM_GGML_PAD(nth*sizeof(float), 16) atIndex:0];
+
+ const int64_t nrows = lm_ggml_nrows(src0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
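+ // note: op_params layout: [0]=n_past, [1]=n_dims, [2]=mode, [3]=n_ctx (skipped),
+ // [4]=n_ctx_orig, then six floats freq_base..beta_slow; bit 1 of mode selects
+ // the NEOX variant, and the optional src2 (frequency factors) falls back to
+ // src0 at buffer index 2 when absent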
+ case LM_GGML_OP_ROPE:
+ {
+ LM_GGML_ASSERT(ne10 == ne02);
+
+ const int nth = MIN(1024, ne00);
+
+ const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+ float freq_base;
+ float freq_scale;
+ float ext_factor;
+ float attn_factor;
+ float beta_fast;
+ float beta_slow;
+
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+
+ const bool is_neox = mode & 2;
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ if (!is_neox) {
+ switch (src0->type) {
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
+ default: LM_GGML_ASSERT(false);
+ };
+ } else {
+ switch (src0->type) {
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+ default: LM_GGML_ASSERT(false);
+ };
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ if (id_src2 != nil) {
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
+ [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
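+ // note: im2col expects an F16 kernel tensor (src0) and F32 input (src1);
+ // the destination may be F16 or F32, which selects the pipeline below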
+ case LM_GGML_OP_IM2COL:
+ {
+ LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16);
+ LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
+ LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F16 || dst->type == LM_GGML_TYPE_F32);
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+ const int32_t N = src1->ne[is_2D ? 3 : 2];
+ const int32_t IC = src1->ne[is_2D ? 2 : 1];
+ const int32_t IH = is_2D ? src1->ne[1] : 1;
+ const int32_t IW = src1->ne[0];
+
+ const int32_t KH = is_2D ? src0->ne[1] : 1;
+ const int32_t KW = src0->ne[0];
+
+ const int32_t OH = is_2D ? dst->ne[2] : 1;
+ const int32_t OW = dst->ne[1];
+
+ const int32_t CHW = IC * KH * KW;
+
+ const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+ const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (dst->type) {
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
+ default: LM_GGML_ASSERT(false);
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
+ [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
+ [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
+ [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
+ [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
+ [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
+ [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
+ [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
+ [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
+ [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
+ [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+ } break;
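+ // note: the per-dimension scale factors are the ratios of dst to src extents;
+ // presumably the kernel uses them to map each dst element back to a src element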
+ case LM_GGML_OP_UPSCALE:
+ {
+ LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
+
+ const float sf0 = (float)ne0/src0->ne[0];
+ const float sf1 = (float)ne1/src0->ne[1];
+ const float sf2 = (float)ne2/src0->ne[2];
+ const float sf3 = (float)ne3/src0->ne[3];
+
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+ [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
+ [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
+ [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
+ [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
+
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
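+ // note: one threadgroup per destination row with up to 1024 threads;
+ // elements outside the src0 extents are presumably zero-filled by the kernel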
+ case LM_GGML_OP_PAD:
+ {
+ LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
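+ // note: start and step are read from op_params slots 0 and 2; slot 1
+ // (the stop value) is implied by ne0 and not needed by the kernel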
+ case LM_GGML_OP_ARANGE:
+ {
+ LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_F32);
+
+ float start;
+ float step;
+
+ memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
+ memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float));
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
+ [encoder setBytes:&start length:sizeof(start) atIndex:2];
+ [encoder setBytes:&step length:sizeof(step) atIndex:3];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
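+ // note: half = dim/2 threads suffice since each frequency presumably
+ // produces a cos/sin pair; one threadgroup per timestep (ne00)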
+ case LM_GGML_OP_TIMESTEP_EMBEDDING:
+ {
+ LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
+
+ const int dim = dst->op_params[0];
+ const int max_period = dst->op_params[1];
+
+ const int half = dim / 2;
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
+ [encoder setBytes:&dim length:sizeof(dim) atIndex:3];
+ [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
+
+ const int nth = MIN(1024, half);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
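+ // note: each row is sorted by a single threadgroup with one thread per
+ // padded element, so the row length is bounded by the max threadgroup size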
+ case LM_GGML_OP_ARGSORT:
+ {
+ LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
+ LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_I32);
+
+ const int nrows = lm_ggml_nrows(src0);
+
+ enum lm_ggml_sort_order order = (enum lm_ggml_sort_order) dst->op_params[0];
+
+ // bitonic sort requires the number of elements to be a power of 2
+ int64_t ne00_padded = 1;
+ while (ne00_padded < ne00) {
+ ne00_padded *= 2;
+ }
+
+ // Metal kernels require the buffer size to be a multiple of 16 bytes
+ // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+ const int mem_size = LM_GGML_PAD(ne00_padded*sizeof(int32_t), 16);
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (order) {
+ case LM_GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
+ case LM_GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
+ default: LM_GGML_ASSERT(false);
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
+ } break;
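+ // note: trivially parallel; dispatched as one single-thread threadgroup
+ // per element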
+ case LM_GGML_OP_LEAKY_RELU:
+ {
+ LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
+
+ float slope;
+ memcpy(&slope, dst->op_params, sizeof(float));
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+
+ const int64_t n = lm_ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
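+ // note: two kernel families below: the block (half8x8) kernel handles
+ // batched queries, while the vector kernel is used when ne01 < 4 and the
+ // head size is a multiple of 128; the sizing loops pick the largest
+ // simdgroup count whose shared memory fits maxThreadgroupMemoryLength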
+ case LM_GGML_OP_FLASH_ATTN_EXT:
+ {
+ LM_GGML_ASSERT(ne00 % 4 == 0);
+ LM_GGML_ASSERT(ne11 % 32 == 0);
+
+ LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
+
+ LM_GGML_ASSERT(lm_ggml_are_same_shape (src1, src2));
+
+ struct lm_ggml_tensor * src3 = gf->nodes[i]->src[3];
+
+ size_t offs_src3 = 0;
+
+ id<MTLBuffer> id_src3 = src3 ? lm_ggml_metal_get_buffer(src3, &offs_src3) : nil;
+
+ LM_GGML_ASSERT(!src3 || src3->type == LM_GGML_TYPE_F16);
+ LM_GGML_ASSERT(!src3 || src3->ne[1] >= LM_GGML_PAD(src0->ne[1], 8) &&
+ "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
+
+ const int64_t ne30 = src3 ? src3->ne[0] : 0; LM_GGML_UNUSED(ne30);
+ //const int64_t ne31 = src3 ? src3->ne[1] : 0;
+ const int64_t ne32 = src3 ? src3->ne[2] : 0; LM_GGML_UNUSED(ne32);
+ const int64_t ne33 = src3 ? src3->ne[3] : 0; LM_GGML_UNUSED(ne33);
+
+ const uint64_t nb30 = src3 ? src3->nb[0] : 0; LM_GGML_UNUSED(nb30);
+ const uint64_t nb31 = src3 ? src3->nb[1] : 0;
+ const uint64_t nb32 = src3 ? src3->nb[2] : 0; LM_GGML_UNUSED(nb32);
+ const uint64_t nb33 = src3 ? src3->nb[3] : 0; LM_GGML_UNUSED(nb33);
+
+ const enum lm_ggml_type src2t = src2 ? src2->type : LM_GGML_TYPE_COUNT; LM_GGML_UNUSED(src2t);
+
+ float scale;
+ float max_bias;
+
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+
+ const uint32_t n_head = src0->ne[2];
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ bool use_vec_kernel = false;
+
+ if (ne01 >= 4 || (ne00%128 != 0)) {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
+ //case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+ default:
+ {
+ LM_GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+ LM_GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+ LM_GGML_ASSERT(false && "add template specialization for this size");
+ }
+ }
+ } else {
+ use_vec_kernel = true;
+
+ switch (ne00) {
+ case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+ //case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+ default:
+ {
+ LM_GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+ LM_GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+ LM_GGML_ASSERT(false && "add template specialization for this size");
+ }
+ }
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ if (id_src3) {
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
+ [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
+ [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
+ [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
+ [encoder setBytes:&scale length:sizeof( float) atIndex:23];
+ [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
+
+ if (!use_vec_kernel) {
+ // half8x8 kernel
+ const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
+ const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+ LM_GGML_ASSERT(nqptg <= 32);
+ LM_GGML_ASSERT(nqptg % 8 == 0);
+ LM_GGML_ASSERT(ncpsg % 32 == 0);
+
+ int64_t nsgmax = 2;
+
+ while (true) {
+ const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
+ if (smem > ctx->device.maxThreadgroupMemoryLength) {
+ break;
+ }
+ nsgmax *= 2;
+ }
+ nsgmax /= 2;
+
+ // simdgroups per threadgroup (a.k.a. warps)
+ const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
+
+ const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+
+ //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+ LM_GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+
+ [encoder setThreadgroupMemoryLength:LM_GGML_PAD(smem, 16) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ } else {
+ // half1x4 kernel
+ const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
+ const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+ LM_GGML_ASSERT(nqptg <= 32);
+ LM_GGML_ASSERT(nqptg % 1 == 0);
+ LM_GGML_ASSERT(ncpsg % 32 == 0);
+
+ // simdgroups per threadgroup (a.k.a. warps)
+ const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+
+ int64_t nsg = 1;
+ while (nsg <= nsgt) {
+ nsg *= 2;
+ }
+ nsg /= 2;
+
+ const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
+
+ //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+ LM_GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+ [encoder setThreadgroupMemoryLength:LM_GGML_PAD(smem, 16) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ }
+ } break;
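+ // note: DUP/CPY/CONT share one family of copy kernels selected by the
+ // (src, dst) type pair; F32 sources can also quantize on the fly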
+ case LM_GGML_OP_DUP:
+ case LM_GGML_OP_CPY:
+ case LM_GGML_OP_CONT:
+ {
+ LM_GGML_ASSERT(ne00 % lm_ggml_blck_size(src0->type) == 0);
+
+ int nth = MIN(1024, ne00/lm_ggml_blck_size(src0->type));
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (src0t) {
+ case LM_GGML_TYPE_F32:
+ {
+ LM_GGML_ASSERT(ne0 % lm_ggml_blck_size(dst->type) == 0);
+
+ switch (dstt) {
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
+ case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
+ case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
+ case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
+ case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
+ case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
+ case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
+ default: LM_GGML_ASSERT(false && "not implemented");
+ };
+ } break;
+ case LM_GGML_TYPE_F16:
+ {
+ switch (dstt) {
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
+ default: LM_GGML_ASSERT(false && "not implemented");
+ };
+ } break;
+ default: LM_GGML_ASSERT(false && "not implemented");
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ default:
+ {
+ LM_GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, lm_ggml_op_name(dst->op));
+ LM_GGML_ASSERT(false);
+ }
+ }
+
+ if (should_capture) {
+ [encoder popDebugGroup];
+ }
+ }
+
+ [encoder endEncoding];
+
+ [command_buffer commit];
+ });
+
+ // Wait for completion and check status of each command buffer
+ // needed to detect if the device ran out-of-memory for example (#1881)
+
+ for (int i = 0; i < n_cb; ++i) {
+ id<MTLCommandBuffer> command_buffer = command_buffers[i];
+ [command_buffer waitUntilCompleted];
+
+ MTLCommandBufferStatus status = [command_buffer status];
+ if (status != MTLCommandBufferStatusCompleted) {
+ LM_GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
+ if (status == MTLCommandBufferStatusError) {
+ NSString * error_code = [command_buffer error].localizedDescription;
+ LM_GGML_METAL_LOG_INFO("error: %s\n", [error_code UTF8String]);
+ }
+
+ return LM_GGML_STATUS_FAILED;
+ }
+ }
+
+ if (should_capture) {
+ [[MTLCaptureManager sharedCaptureManager] stopCapture];
+ }
+
+ }
+ return LM_GGML_STATUS_SUCCESS;
+ }
+
+ ////////////////////////////////////////////////////////////////////////////////
+
+ // backend interface
+
+ // default buffer
+ static id<MTLDevice> g_backend_device = nil;
+ static int g_backend_device_ref_count = 0;
+
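+ // note: the default MTLDevice is shared and reference-counted so that
+ // buffers and the backend can acquire and release it independently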
+ static id<MTLDevice> lm_ggml_backend_metal_get_device(void) {
+ if (g_backend_device == nil) {
+ g_backend_device = MTLCreateSystemDefaultDevice();
+ }
+
+ g_backend_device_ref_count++;
+
+ return g_backend_device;
+ }
+
+ static void lm_ggml_backend_metal_free_device(void) {
+ assert(g_backend_device_ref_count > 0);
+
+ g_backend_device_ref_count--;
+
+ if (g_backend_device_ref_count == 0) {
+ [g_backend_device release];
+ g_backend_device = nil;
+ }
+ }
+
+ LM_GGML_CALL static const char * lm_ggml_backend_metal_buffer_get_name(lm_ggml_backend_buffer_t buffer) {
+ return "Metal";
+
+ UNUSED(buffer);
+ }
+
+ LM_GGML_CALL static void lm_ggml_backend_metal_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) {
+ struct lm_ggml_backend_metal_buffer_context * ctx = (struct lm_ggml_backend_metal_buffer_context *)buffer->context;
+
+ for (int i = 0; i < ctx->n_buffers; i++) {
+ [ctx->buffers[i].metal release];
+ }
+ lm_ggml_backend_metal_free_device();
+
+ if (ctx->owned) {
+ #if TARGET_OS_OSX
+ vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size);
+ #else
+ free(ctx->all_data);
+ #endif
+ }
+
+ free(ctx);
+ }
+
+ LM_GGML_CALL static void * lm_ggml_backend_metal_buffer_get_base(lm_ggml_backend_buffer_t buffer) {
+ struct lm_ggml_backend_metal_buffer_context * ctx = (struct lm_ggml_backend_metal_buffer_context *)buffer->context;
+
+ return ctx->all_data;
+ }
+
+ LM_GGML_CALL static void lm_ggml_backend_metal_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ memcpy((char *)tensor->data + offset, data, size);
+
+ UNUSED(buffer);
+ }
+
+ LM_GGML_CALL static void lm_ggml_backend_metal_buffer_get_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ memcpy(data, (const char *)tensor->data + offset, size);
+
+ UNUSED(buffer);
+ }
+
+ LM_GGML_CALL static bool lm_ggml_backend_metal_buffer_cpy_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst) {
+ if (lm_ggml_backend_buffer_is_host(src->buffer)) {
+ memcpy(dst->data, src->data, lm_ggml_nbytes(src));
+ return true;
+ }
+ return false;
+
+ UNUSED(buffer);
+ }
+
+ LM_GGML_CALL static void lm_ggml_backend_metal_buffer_clear(lm_ggml_backend_buffer_t buffer, uint8_t value) {
+ struct lm_ggml_backend_metal_buffer_context * ctx = (struct lm_ggml_backend_metal_buffer_context *)buffer->context;
+
+ memset(ctx->all_data, value, ctx->all_size);
+ }
+
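+ // note: buffers use MTLResourceStorageModeShared, i.e. host-visible memory,
+ // so set/get_tensor above are plain CPU-side memcpy calls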
+ static struct lm_ggml_backend_buffer_i lm_ggml_backend_metal_buffer_i = {
+ /* .get_name = */ lm_ggml_backend_metal_buffer_get_name,
+ /* .free_buffer = */ lm_ggml_backend_metal_buffer_free_buffer,
+ /* .get_base = */ lm_ggml_backend_metal_buffer_get_base,
+ /* .init_tensor = */ NULL,
+ /* .set_tensor = */ lm_ggml_backend_metal_buffer_set_tensor,
+ /* .get_tensor = */ lm_ggml_backend_metal_buffer_get_tensor,
+ /* .cpy_tensor = */ lm_ggml_backend_metal_buffer_cpy_tensor,
+ /* .clear = */ lm_ggml_backend_metal_buffer_clear,
+ /* .reset = */ NULL,
+ };
+
+ // default buffer type
+
+ LM_GGML_CALL static const char * lm_ggml_backend_metal_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) {
+ return "Metal";
+
+ UNUSED(buft);
+ }
+
+ static void lm_ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
+ #ifndef LM_GGML_METAL_NDEBUG
+ #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
+ if (@available(macOS 10.12, iOS 16.0, *)) {
+ LM_GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)",
+ __func__,
+ size_aligned / 1024.0 / 1024.0,
+ device.currentAllocatedSize / 1024.0 / 1024.0,
+ device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+
+ if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
+ LM_GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
+ } else {
+ LM_GGML_METAL_LOG_INFO("\n");
+ }
+ } else {
+ LM_GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n",
+ __func__,
+ size_aligned / 1024.0 / 1024.0,
+ device.currentAllocatedSize / 1024.0 / 1024.0);
+ }
+ #endif
+ #endif
+ UNUSED(device);
+ UNUSED(size_aligned);
+ }
+
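+ // note: the host allocation is rounded up to a whole number of pages because
+ // newBufferWithBytesNoCopy requires page-aligned storage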
+ LM_GGML_CALL static lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) {
+ struct lm_ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct lm_ggml_backend_metal_buffer_context));
+
+ const size_t size_page = sysconf(_SC_PAGESIZE);
+
+ size_t size_aligned = size;
+ if ((size_aligned % size_page) != 0) {
+ size_aligned += (size_page - (size_aligned % size_page));
+ }
+
+ id<MTLDevice> device = lm_ggml_backend_metal_get_device();
+
+ ctx->all_data = lm_ggml_metal_host_malloc(size_aligned);
+ ctx->all_size = size_aligned;
+ ctx->owned = true;
+ ctx->n_buffers = 1;
+
+ if (ctx->all_data != NULL) {
+ ctx->buffers[0].data = ctx->all_data;
+ ctx->buffers[0].size = size;
+ ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
+ length:size_aligned
+ options:MTLResourceStorageModeShared
+ deallocator:nil];
+ }
+
+ if (ctx->all_data == NULL || ctx->buffers[0].metal == nil) {
+ LM_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
+ free(ctx);
+ lm_ggml_backend_metal_free_device();
+ return NULL;
+ }
+
+ //lm_ggml_backend_metal_log_allocated_size(device, size_aligned);
+
+ return lm_ggml_backend_buffer_init(buft, lm_ggml_backend_metal_buffer_i, ctx, size);
+ }
+
+ LM_GGML_CALL static size_t lm_ggml_backend_metal_buffer_type_get_alignment(lm_ggml_backend_buffer_type_t buft) {
+ return 32;
+ UNUSED(buft);
+ }
+
+ LM_GGML_CALL static size_t lm_ggml_backend_metal_buffer_type_get_max_size(lm_ggml_backend_buffer_type_t buft) {
+ id<MTLDevice> device = lm_ggml_backend_metal_get_device();
+ size_t max_size = device.maxBufferLength;
+ lm_ggml_backend_metal_free_device();
+
+ return max_size;
+
+ UNUSED(buft);
+ }
+
+ LM_GGML_CALL static bool lm_ggml_backend_metal_buffer_type_is_host(lm_ggml_backend_buffer_type_t buft) {
+ return true;
+
+ UNUSED(buft);
+ }
+
+ LM_GGML_CALL lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_type(void) {
+ static struct lm_ggml_backend_buffer_type lm_ggml_backend_buffer_type_metal = {
+ /* .iface = */ {
+ /* .get_name = */ lm_ggml_backend_metal_buffer_type_get_name,
+ /* .alloc_buffer = */ lm_ggml_backend_metal_buffer_type_alloc_buffer,
+ /* .get_alignment = */ lm_ggml_backend_metal_buffer_type_get_alignment,
+ /* .get_max_size = */ lm_ggml_backend_metal_buffer_type_get_max_size,
+ /* .get_alloc_size = */ NULL, // defaults to lm_ggml_nbytes
+ /* .is_host = */ lm_ggml_backend_metal_buffer_type_is_host,
+ },
+ /* .context = */ NULL,
+ };
+
+ return &lm_ggml_backend_buffer_type_metal;
+ }
+
+ // buffer from ptr
+
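+ // note: if the region exceeds the device's maxBufferLength it is mapped as
+ // multiple overlapping views; the overlap (size_ovlp) guarantees the largest
+ // tensor still fits entirely inside one view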
+ LM_GGML_CALL lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
+ struct lm_ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct lm_ggml_backend_metal_buffer_context));
+
+ ctx->all_data = data;
+ ctx->all_size = size;
+ ctx->owned = false;
+ ctx->n_buffers = 0;
+
+ const size_t size_page = sysconf(_SC_PAGESIZE);
+
+ // page-align the data ptr
+ {
+ const uintptr_t offs = (uintptr_t) data % size_page;
+ data = (void *) ((char *) data - offs);
+ size += offs;
+ }
+
+ size_t size_aligned = size;
+ if ((size_aligned % size_page) != 0) {
+ size_aligned += (size_page - (size_aligned % size_page));
+ }
+
+ id<MTLDevice> device = lm_ggml_backend_metal_get_device();
+
+ // the buffer fits into the max buffer size allowed by the device
+ if (size_aligned <= device.maxBufferLength) {
+ ctx->buffers[ctx->n_buffers].data = data;
+ ctx->buffers[ctx->n_buffers].size = size;
+
+ ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+ if (ctx->buffers[ctx->n_buffers].metal == nil) {
+ LM_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
+ return NULL; // was `return false;` — this function returns a pointer
+ }
+
+ lm_ggml_backend_metal_log_allocated_size(device, size_aligned);
+
+ ++ctx->n_buffers;
+ } else {
+ // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
+ // one of the views
+ const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
+ const size_t size_step = device.maxBufferLength - size_ovlp;
+ const size_t size_view = device.maxBufferLength;
+
+ for (size_t i = 0; i < size; i += size_step) {
+ const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
+
+ ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
+ ctx->buffers[ctx->n_buffers].size = size_step_aligned;
+
+ ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+ if (ctx->buffers[ctx->n_buffers].metal == nil) {
+ LM_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
+ return NULL; // was `return false;` — this function returns a pointer
+ }
+
+ lm_ggml_backend_metal_log_allocated_size(device, size_step_aligned);
+
+ if (i + size_step < size) {
+ LM_GGML_METAL_LOG_INFO("\n");
+ }
+
+ ++ctx->n_buffers;
+ }
+ }
+
+ return lm_ggml_backend_buffer_init(lm_ggml_backend_metal_buffer_type(), lm_ggml_backend_metal_buffer_i, ctx, size);
+ }
+
+ // backend
+
+ LM_GGML_CALL static const char * lm_ggml_backend_metal_name(lm_ggml_backend_t backend) {
+ return "Metal";
+
+ UNUSED(backend);
+ }
+
+ LM_GGML_CALL static void lm_ggml_backend_metal_free(lm_ggml_backend_t backend) {
+ struct lm_ggml_metal_context * ctx = (struct lm_ggml_metal_context *)backend->context;
+ lm_ggml_metal_free(ctx);
+ free(backend);
+ }
+
+ LM_GGML_CALL static lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_get_default_buffer_type(lm_ggml_backend_t backend) {
+ return lm_ggml_backend_metal_buffer_type();
+
+ UNUSED(backend);
+ }
+
+ LM_GGML_CALL static enum lm_ggml_status lm_ggml_backend_metal_graph_compute(lm_ggml_backend_t backend, struct lm_ggml_cgraph * cgraph) {
+ struct lm_ggml_metal_context * metal_ctx = (struct lm_ggml_metal_context *)backend->context;
+
+ return lm_ggml_metal_graph_compute(metal_ctx, cgraph);
+ }
+
+ LM_GGML_CALL static bool lm_ggml_backend_metal_supports_op(lm_ggml_backend_t backend, const struct lm_ggml_tensor * op) {
+ struct lm_ggml_metal_context * metal_ctx = (struct lm_ggml_metal_context *)backend->context;
+
+ return lm_ggml_metal_supports_op(metal_ctx, op);
+ }
+
+ LM_GGML_CALL static bool lm_ggml_backend_metal_supports_buft(lm_ggml_backend_t backend, lm_ggml_backend_buffer_type_t buft) {
+ return buft->iface.get_name == lm_ggml_backend_metal_buffer_type_get_name;
+
+ UNUSED(backend);
+ }
+
+ static struct lm_ggml_backend_i lm_ggml_backend_metal_i = {
+ /* .get_name = */ lm_ggml_backend_metal_name,
+ /* .free = */ lm_ggml_backend_metal_free,
+ /* .get_default_buffer_type = */ lm_ggml_backend_metal_get_default_buffer_type,
+ /* .set_tensor_async = */ NULL,
+ /* .get_tensor_async = */ NULL,
+ /* .cpy_tensor_async = */ NULL,
+ /* .synchronize = */ NULL,
+ /* .graph_plan_create = */ NULL,
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_update = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ lm_ggml_backend_metal_graph_compute,
+ /* .supports_op = */ lm_ggml_backend_metal_supports_op,
+ /* .supports_buft = */ lm_ggml_backend_metal_supports_buft,
+ /* .offload_op = */ NULL,
+ /* .event_new = */ NULL,
+ /* .event_free = */ NULL,
+ /* .event_record = */ NULL,
+ /* .event_wait = */ NULL,
+ /* .event_synchronize = */ NULL,
+ };
+
+ void lm_ggml_backend_metal_log_set_callback(lm_ggml_log_callback log_callback, void * user_data) {
+ lm_ggml_metal_log_callback = log_callback;
+ lm_ggml_metal_log_user_data = user_data;
+ }
+
+ static lm_ggml_guid_t lm_ggml_backend_metal_guid(void) {
+ static lm_ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 };
+ return &guid;
+ }
+
+ lm_ggml_backend_t lm_ggml_backend_metal_init(void) {
+ struct lm_ggml_metal_context * ctx = lm_ggml_metal_init(LM_GGML_DEFAULT_N_THREADS);
+
+ if (ctx == NULL) {
+ return NULL;
+ }
+
+ lm_ggml_backend_t metal_backend = malloc(sizeof(struct lm_ggml_backend));
+
+ *metal_backend = (struct lm_ggml_backend) {
+ /* .guid = */ lm_ggml_backend_metal_guid(),
+ /* .interface = */ lm_ggml_backend_metal_i,
+ /* .context = */ ctx,
+ };
+
+ return metal_backend;
+ }
+
+ bool lm_ggml_backend_is_metal(lm_ggml_backend_t backend) {
+ return backend != NULL && lm_ggml_guid_matches(backend->guid, lm_ggml_backend_metal_guid());
+ }
+
+ void lm_ggml_backend_metal_set_n_cb(lm_ggml_backend_t backend, int n_cb) {
+ LM_GGML_ASSERT(lm_ggml_backend_is_metal(backend));
+
+ struct lm_ggml_metal_context * ctx = (struct lm_ggml_metal_context *)backend->context;
+
+ ctx->n_cb = MIN(n_cb, LM_GGML_METAL_MAX_BUFFERS);
+ }
+
+ bool lm_ggml_backend_metal_supports_family(lm_ggml_backend_t backend, int family) {
+ LM_GGML_ASSERT(lm_ggml_backend_is_metal(backend));
+
+ struct lm_ggml_metal_context * ctx = (struct lm_ggml_metal_context *)backend->context;
+
+ return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
+ }
+
+ void lm_ggml_backend_metal_capture_next_compute(lm_ggml_backend_t backend) {
+ LM_GGML_ASSERT(lm_ggml_backend_is_metal(backend));
+
+ struct lm_ggml_metal_context * ctx = (struct lm_ggml_metal_context *)backend->context;
+ ctx->should_capture_next_compute = true;
+ }
+
+ LM_GGML_CALL lm_ggml_backend_t lm_ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
+
+ LM_GGML_CALL lm_ggml_backend_t lm_ggml_backend_reg_metal_init(const char * params, void * user_data) {
+ return lm_ggml_backend_metal_init();
+
+ LM_GGML_UNUSED(params);
+ LM_GGML_UNUSED(user_data);
+ }