llama_cpp 0.14.5 → 0.14.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/ext/llama_cpp/llama_cpp.cpp +37 -2
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -0
- data/vendor/tmp/llama.cpp/Makefile +24 -7
- data/vendor/tmp/llama.cpp/ggml-alloc.c +8 -8
- data/vendor/tmp/llama.cpp/ggml-backend.c +14 -10
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +135 -46
- data/vendor/tmp/llama.cpp/ggml-impl.h +263 -5
- data/vendor/tmp/llama.cpp/ggml-metal.m +130 -83
- data/vendor/tmp/llama.cpp/ggml-metal.metal +505 -1467
- data/vendor/tmp/llama.cpp/ggml-quants.c +1 -294
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +65 -52
- data/vendor/tmp/llama.cpp/ggml.c +151 -99
- data/vendor/tmp/llama.cpp/ggml.h +5 -4
- data/vendor/tmp/llama.cpp/llama.cpp +1308 -254
- data/vendor/tmp/llama.cpp/llama.h +19 -6
- data/vendor/tmp/llama.cpp/sgemm.cpp +999 -0
- data/vendor/tmp/llama.cpp/sgemm.h +12 -0
- metadata +4 -2
@@ -11,6 +11,12 @@
|
|
11
11
|
#include <string.h> // memcpy
|
12
12
|
#include <math.h> // fabsf
|
13
13
|
|
14
|
+
#undef MIN
|
15
|
+
#undef MAX
|
16
|
+
|
17
|
+
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
18
|
+
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
19
|
+
|
14
20
|
#ifdef __cplusplus
|
15
21
|
extern "C" {
|
16
22
|
#endif
|
@@ -45,7 +51,7 @@ extern "C" {
|
|
45
51
|
// 16-bit float
|
46
52
|
// on Arm, we use __fp16
|
47
53
|
// on x86, we use uint16_t
|
48
|
-
#if defined(__ARM_NEON)
|
54
|
+
#if defined(__ARM_NEON)
|
49
55
|
|
50
56
|
// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
|
51
57
|
//
|
@@ -53,8 +59,262 @@ extern "C" {
|
|
53
59
|
//
|
54
60
|
#include <arm_neon.h>
|
55
61
|
|
62
|
+
#ifdef _MSC_VER
|
63
|
+
|
64
|
+
typedef uint16_t ggml_fp16_internal_t;
|
65
|
+
|
66
|
+
#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
|
67
|
+
|
68
|
+
#else
|
69
|
+
|
56
70
|
typedef __fp16 ggml_fp16_internal_t;
|
57
71
|
|
72
|
+
#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
|
73
|
+
|
74
|
+
#endif // _MSC_VER
|
75
|
+
|
76
|
+
#if !defined(__aarch64__)
|
77
|
+
|
78
|
+
// 32-bit ARM compatibility
|
79
|
+
|
80
|
+
// vaddvq_s16
|
81
|
+
// vpaddq_s16
|
82
|
+
// vpaddq_s32
|
83
|
+
// vaddvq_s32
|
84
|
+
// vaddvq_f32
|
85
|
+
// vmaxvq_f32
|
86
|
+
// vcvtnq_s32_f32
|
87
|
+
// vzip1_u8
|
88
|
+
// vzip2_u8
|
89
|
+
|
90
|
+
inline static int32_t vaddvq_s16(int16x8_t v) {
|
91
|
+
return
|
92
|
+
(int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
|
93
|
+
(int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
|
94
|
+
(int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
|
95
|
+
(int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
|
96
|
+
}
|
97
|
+
|
98
|
+
inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
|
99
|
+
int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
|
100
|
+
int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
|
101
|
+
return vcombine_s16(a0, b0);
|
102
|
+
}
|
103
|
+
|
104
|
+
inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
|
105
|
+
int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
|
106
|
+
int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
|
107
|
+
return vcombine_s32(a0, b0);
|
108
|
+
}
|
109
|
+
|
110
|
+
inline static int32_t vaddvq_s32(int32x4_t v) {
|
111
|
+
return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
|
112
|
+
}
|
113
|
+
|
114
|
+
inline static float vaddvq_f32(float32x4_t v) {
|
115
|
+
return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
|
116
|
+
}
|
117
|
+
|
118
|
+
inline static float vmaxvq_f32(float32x4_t v) {
|
119
|
+
return
|
120
|
+
MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
|
121
|
+
MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
|
122
|
+
}
|
123
|
+
|
124
|
+
inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
|
125
|
+
int32x4_t res;
|
126
|
+
|
127
|
+
res[0] = roundf(vgetq_lane_f32(v, 0));
|
128
|
+
res[1] = roundf(vgetq_lane_f32(v, 1));
|
129
|
+
res[2] = roundf(vgetq_lane_f32(v, 2));
|
130
|
+
res[3] = roundf(vgetq_lane_f32(v, 3));
|
131
|
+
|
132
|
+
return res;
|
133
|
+
}
|
134
|
+
|
135
|
+
inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
|
136
|
+
uint8x8_t res;
|
137
|
+
|
138
|
+
res[0] = a[0]; res[1] = b[0];
|
139
|
+
res[2] = a[1]; res[3] = b[1];
|
140
|
+
res[4] = a[2]; res[5] = b[2];
|
141
|
+
res[6] = a[3]; res[7] = b[3];
|
142
|
+
|
143
|
+
return res;
|
144
|
+
}
|
145
|
+
|
146
|
+
inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
|
147
|
+
uint8x8_t res;
|
148
|
+
|
149
|
+
res[0] = a[4]; res[1] = b[4];
|
150
|
+
res[2] = a[5]; res[3] = b[5];
|
151
|
+
res[4] = a[6]; res[5] = b[6];
|
152
|
+
res[6] = a[7]; res[7] = b[7];
|
153
|
+
|
154
|
+
return res;
|
155
|
+
}
|
156
|
+
|
157
|
+
// vld1q_s16_x2
|
158
|
+
// vld1q_u8_x2
|
159
|
+
// vld1q_u8_x4
|
160
|
+
// vld1q_s8_x2
|
161
|
+
// vld1q_s8_x4
|
162
|
+
// TODO: double-check these work correctly
|
163
|
+
|
164
|
+
typedef struct ggml_int16x8x2_t {
|
165
|
+
int16x8_t val[2];
|
166
|
+
} ggml_int16x8x2_t;
|
167
|
+
|
168
|
+
inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
|
169
|
+
ggml_int16x8x2_t res;
|
170
|
+
|
171
|
+
res.val[0] = vld1q_s16(ptr + 0);
|
172
|
+
res.val[1] = vld1q_s16(ptr + 8);
|
173
|
+
|
174
|
+
return res;
|
175
|
+
}
|
176
|
+
|
177
|
+
typedef struct ggml_uint8x16x2_t {
|
178
|
+
uint8x16_t val[2];
|
179
|
+
} ggml_uint8x16x2_t;
|
180
|
+
|
181
|
+
inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
|
182
|
+
ggml_uint8x16x2_t res;
|
183
|
+
|
184
|
+
res.val[0] = vld1q_u8(ptr + 0);
|
185
|
+
res.val[1] = vld1q_u8(ptr + 16);
|
186
|
+
|
187
|
+
return res;
|
188
|
+
}
|
189
|
+
|
190
|
+
typedef struct ggml_uint8x16x4_t {
|
191
|
+
uint8x16_t val[4];
|
192
|
+
} ggml_uint8x16x4_t;
|
193
|
+
|
194
|
+
inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
|
195
|
+
ggml_uint8x16x4_t res;
|
196
|
+
|
197
|
+
res.val[0] = vld1q_u8(ptr + 0);
|
198
|
+
res.val[1] = vld1q_u8(ptr + 16);
|
199
|
+
res.val[2] = vld1q_u8(ptr + 32);
|
200
|
+
res.val[3] = vld1q_u8(ptr + 48);
|
201
|
+
|
202
|
+
return res;
|
203
|
+
}
|
204
|
+
|
205
|
+
typedef struct ggml_int8x16x2_t {
|
206
|
+
int8x16_t val[2];
|
207
|
+
} ggml_int8x16x2_t;
|
208
|
+
|
209
|
+
inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
|
210
|
+
ggml_int8x16x2_t res;
|
211
|
+
|
212
|
+
res.val[0] = vld1q_s8(ptr + 0);
|
213
|
+
res.val[1] = vld1q_s8(ptr + 16);
|
214
|
+
|
215
|
+
return res;
|
216
|
+
}
|
217
|
+
|
218
|
+
typedef struct ggml_int8x16x4_t {
|
219
|
+
int8x16_t val[4];
|
220
|
+
} ggml_int8x16x4_t;
|
221
|
+
|
222
|
+
inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
|
223
|
+
ggml_int8x16x4_t res;
|
224
|
+
|
225
|
+
res.val[0] = vld1q_s8(ptr + 0);
|
226
|
+
res.val[1] = vld1q_s8(ptr + 16);
|
227
|
+
res.val[2] = vld1q_s8(ptr + 32);
|
228
|
+
res.val[3] = vld1q_s8(ptr + 48);
|
229
|
+
|
230
|
+
return res;
|
231
|
+
}
|
232
|
+
|
233
|
+
// NOTE: not tested
|
234
|
+
inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
|
235
|
+
int8x16_t res;
|
236
|
+
|
237
|
+
res[ 0] = a[b[ 0]];
|
238
|
+
res[ 1] = a[b[ 1]];
|
239
|
+
res[ 2] = a[b[ 2]];
|
240
|
+
res[ 3] = a[b[ 3]];
|
241
|
+
res[ 4] = a[b[ 4]];
|
242
|
+
res[ 5] = a[b[ 5]];
|
243
|
+
res[ 6] = a[b[ 6]];
|
244
|
+
res[ 7] = a[b[ 7]];
|
245
|
+
res[ 8] = a[b[ 8]];
|
246
|
+
res[ 9] = a[b[ 9]];
|
247
|
+
res[10] = a[b[10]];
|
248
|
+
res[11] = a[b[11]];
|
249
|
+
res[12] = a[b[12]];
|
250
|
+
res[13] = a[b[13]];
|
251
|
+
res[14] = a[b[14]];
|
252
|
+
res[15] = a[b[15]];
|
253
|
+
|
254
|
+
return res;
|
255
|
+
}
|
256
|
+
|
257
|
+
// NOTE: not tested
|
258
|
+
inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
|
259
|
+
uint8x16_t res;
|
260
|
+
|
261
|
+
res[ 0] = a[b[ 0]];
|
262
|
+
res[ 1] = a[b[ 1]];
|
263
|
+
res[ 2] = a[b[ 2]];
|
264
|
+
res[ 3] = a[b[ 3]];
|
265
|
+
res[ 4] = a[b[ 4]];
|
266
|
+
res[ 5] = a[b[ 5]];
|
267
|
+
res[ 6] = a[b[ 6]];
|
268
|
+
res[ 7] = a[b[ 7]];
|
269
|
+
res[ 8] = a[b[ 8]];
|
270
|
+
res[ 9] = a[b[ 9]];
|
271
|
+
res[10] = a[b[10]];
|
272
|
+
res[11] = a[b[11]];
|
273
|
+
res[12] = a[b[12]];
|
274
|
+
res[13] = a[b[13]];
|
275
|
+
res[14] = a[b[14]];
|
276
|
+
res[15] = a[b[15]];
|
277
|
+
|
278
|
+
return res;
|
279
|
+
}
|
280
|
+
|
281
|
+
#else
|
282
|
+
|
283
|
+
#define ggml_int16x8x2_t int16x8x2_t
|
284
|
+
#define ggml_uint8x16x2_t uint8x16x2_t
|
285
|
+
#define ggml_uint8x16x4_t uint8x16x4_t
|
286
|
+
#define ggml_int8x16x2_t int8x16x2_t
|
287
|
+
#define ggml_int8x16x4_t int8x16x4_t
|
288
|
+
|
289
|
+
#define ggml_vld1q_s16_x2 vld1q_s16_x2
|
290
|
+
#define ggml_vld1q_u8_x2 vld1q_u8_x2
|
291
|
+
#define ggml_vld1q_u8_x4 vld1q_u8_x4
|
292
|
+
#define ggml_vld1q_s8_x2 vld1q_s8_x2
|
293
|
+
#define ggml_vld1q_s8_x4 vld1q_s8_x4
|
294
|
+
#define ggml_vqtbl1q_s8 vqtbl1q_s8
|
295
|
+
#define ggml_vqtbl1q_u8 vqtbl1q_u8
|
296
|
+
|
297
|
+
#endif // !defined(__aarch64__)
|
298
|
+
|
299
|
+
#if !defined(__ARM_FEATURE_DOTPROD)
|
300
|
+
|
301
|
+
inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
|
302
|
+
const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
|
303
|
+
const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
|
304
|
+
|
305
|
+
return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
|
306
|
+
}
|
307
|
+
|
308
|
+
#else
|
309
|
+
|
310
|
+
#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
|
311
|
+
|
312
|
+
#endif // !defined(__ARM_FEATURE_DOTPROD)
|
313
|
+
|
314
|
+
#endif // defined(__ARM_NEON)
|
315
|
+
|
316
|
+
#if defined(__ARM_NEON) && !defined(__MSC_VER)
|
317
|
+
|
58
318
|
#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
|
59
319
|
#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
|
60
320
|
|
@@ -75,8 +335,6 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
|
|
75
335
|
|
76
336
|
#else
|
77
337
|
|
78
|
-
typedef uint16_t ggml_fp16_internal_t;
|
79
|
-
|
80
338
|
#ifdef __wasm_simd128__
|
81
339
|
#include <wasm_simd128.h>
|
82
340
|
#else
|
@@ -88,7 +346,7 @@ typedef uint16_t ggml_fp16_internal_t;
|
|
88
346
|
#if defined(_MSC_VER) || defined(__MINGW32__)
|
89
347
|
#include <intrin.h>
|
90
348
|
#else
|
91
|
-
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
|
349
|
+
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
|
92
350
|
#if !defined(__riscv)
|
93
351
|
#include <immintrin.h>
|
94
352
|
#endif
|
@@ -221,7 +479,7 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
|
|
221
479
|
|
222
480
|
#endif // __F16C__
|
223
481
|
|
224
|
-
#endif // __ARM_NEON
|
482
|
+
#endif // defined(__ARM_NEON) && (!defined(__MSC_VER)
|
225
483
|
|
226
484
|
// precomputed f32 table for f16 (256 KB)
|
227
485
|
// defined in ggml.c, initialized in ggml_init()
|
@@ -37,11 +37,15 @@ enum ggml_metal_kernel_type {
|
|
37
37
|
GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
38
38
|
GGML_METAL_KERNEL_TYPE_SCALE,
|
39
39
|
GGML_METAL_KERNEL_TYPE_SCALE_4,
|
40
|
+
GGML_METAL_KERNEL_TYPE_CLAMP,
|
40
41
|
GGML_METAL_KERNEL_TYPE_TANH,
|
41
42
|
GGML_METAL_KERNEL_TYPE_RELU,
|
42
43
|
GGML_METAL_KERNEL_TYPE_GELU,
|
44
|
+
GGML_METAL_KERNEL_TYPE_GELU_4,
|
43
45
|
GGML_METAL_KERNEL_TYPE_GELU_QUICK,
|
46
|
+
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
44
47
|
GGML_METAL_KERNEL_TYPE_SILU,
|
48
|
+
GGML_METAL_KERNEL_TYPE_SILU_4,
|
45
49
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX,
|
46
50
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
|
47
51
|
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
|
@@ -468,11 +472,15 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
468
472
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
469
473
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
470
474
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
475
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
471
476
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
472
477
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
473
478
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
479
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
|
474
480
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
|
481
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
475
482
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
483
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
476
484
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
|
477
485
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
|
478
486
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
@@ -713,6 +721,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
713
721
|
case GGML_OP_MUL:
|
714
722
|
case GGML_OP_DIV:
|
715
723
|
case GGML_OP_SCALE:
|
724
|
+
case GGML_OP_CLAMP:
|
716
725
|
case GGML_OP_SQR:
|
717
726
|
case GGML_OP_SUM_ROWS:
|
718
727
|
return true;
|
@@ -1154,8 +1163,30 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1154
1163
|
|
1155
1164
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1156
1165
|
} break;
|
1166
|
+
case GGML_OP_CLAMP:
|
1167
|
+
{
|
1168
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
|
1169
|
+
|
1170
|
+
float min;
|
1171
|
+
float max;
|
1172
|
+
memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
|
1173
|
+
memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
|
1174
|
+
|
1175
|
+
[encoder setComputePipelineState:pipeline];
|
1176
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1177
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1178
|
+
[encoder setBytes:&min length:sizeof(min) atIndex:2];
|
1179
|
+
[encoder setBytes:&max length:sizeof(max) atIndex:3];
|
1180
|
+
|
1181
|
+
const int64_t n = ggml_nelements(dst);
|
1182
|
+
|
1183
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1184
|
+
} break;
|
1157
1185
|
case GGML_OP_UNARY:
|
1158
1186
|
switch (ggml_get_unary_op(gf->nodes[i])) {
|
1187
|
+
// we are not taking into account the strides, so for now require contiguous tensors
|
1188
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
1189
|
+
|
1159
1190
|
case GGML_UNARY_OP_TANH:
|
1160
1191
|
{
|
1161
1192
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
|
@@ -1182,42 +1213,60 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1182
1213
|
} break;
|
1183
1214
|
case GGML_UNARY_OP_GELU:
|
1184
1215
|
{
|
1185
|
-
|
1216
|
+
int64_t n = ggml_nelements(dst);
|
1217
|
+
|
1218
|
+
id<MTLComputePipelineState> pipeline = nil;
|
1219
|
+
|
1220
|
+
if (n % 4 == 0) {
|
1221
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
|
1222
|
+
n /= 4;
|
1223
|
+
} else {
|
1224
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
|
1225
|
+
}
|
1186
1226
|
|
1187
1227
|
[encoder setComputePipelineState:pipeline];
|
1188
1228
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1189
1229
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1190
1230
|
|
1191
|
-
|
1192
|
-
GGML_ASSERT(n % 4 == 0);
|
1193
|
-
|
1194
|
-
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1231
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1195
1232
|
} break;
|
1196
1233
|
case GGML_UNARY_OP_GELU_QUICK:
|
1197
1234
|
{
|
1198
|
-
|
1235
|
+
int64_t n = ggml_nelements(dst);
|
1236
|
+
|
1237
|
+
id<MTLComputePipelineState> pipeline = nil;
|
1238
|
+
|
1239
|
+
if (n % 4 == 0) {
|
1240
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
|
1241
|
+
n /= 4;
|
1242
|
+
} else {
|
1243
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
|
1244
|
+
}
|
1199
1245
|
|
1200
1246
|
[encoder setComputePipelineState:pipeline];
|
1201
1247
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1202
1248
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1203
1249
|
|
1204
|
-
|
1205
|
-
GGML_ASSERT(n % 4 == 0);
|
1206
|
-
|
1207
|
-
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1250
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1208
1251
|
} break;
|
1209
1252
|
case GGML_UNARY_OP_SILU:
|
1210
1253
|
{
|
1211
|
-
|
1254
|
+
int64_t n = ggml_nelements(dst);
|
1255
|
+
|
1256
|
+
id<MTLComputePipelineState> pipeline = nil;
|
1257
|
+
|
1258
|
+
if (n % 4 == 0) {
|
1259
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
|
1260
|
+
n /= 4;
|
1261
|
+
} else {
|
1262
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
|
1263
|
+
}
|
1212
1264
|
|
1213
1265
|
[encoder setComputePipelineState:pipeline];
|
1214
1266
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1215
1267
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1216
1268
|
|
1217
|
-
|
1218
|
-
GGML_ASSERT(n % 4 == 0);
|
1219
|
-
|
1220
|
-
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1269
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1221
1270
|
} break;
|
1222
1271
|
default:
|
1223
1272
|
{
|
@@ -1683,15 +1732,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1683
1732
|
} break;
|
1684
1733
|
case GGML_OP_MUL_MAT_ID:
|
1685
1734
|
{
|
1686
|
-
//GGML_ASSERT(ne00 == ne10);
|
1687
|
-
//GGML_ASSERT(ne03 == ne13);
|
1688
1735
|
const int n_as = src0->ne[2];
|
1689
1736
|
|
1690
|
-
// max size of the src1ids array in the kernel shared buffer
|
1691
|
-
GGML_ASSERT(ne11 <= 4096);
|
1692
|
-
|
1693
1737
|
// src2 = ids
|
1694
|
-
const int64_t ne20 = src2->ne[0];
|
1738
|
+
const int64_t ne20 = src2->ne[0];
|
1695
1739
|
const int64_t ne21 = src2->ne[1];
|
1696
1740
|
const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
|
1697
1741
|
const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
|
@@ -1712,15 +1756,13 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1712
1756
|
|
1713
1757
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
1714
1758
|
// to the matrix-vector kernel
|
1715
|
-
|
1716
|
-
|
1717
|
-
const int
|
1759
|
+
// ne20 = n_used_experts
|
1760
|
+
// ne21 = n_rows
|
1761
|
+
const int dst_rows = ne20*ne21;
|
1762
|
+
const int dst_rows_min = n_as;
|
1718
1763
|
|
1719
|
-
//
|
1720
|
-
GGML_ASSERT(
|
1721
|
-
GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
|
1722
|
-
const uint r2 = 1;
|
1723
|
-
const uint r3 = 1;
|
1764
|
+
// max size of the rowids array in the kernel shared buffer
|
1765
|
+
GGML_ASSERT(dst_rows <= 2048);
|
1724
1766
|
|
1725
1767
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
1726
1768
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
@@ -1730,7 +1772,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1730
1772
|
// !!!
|
1731
1773
|
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
1732
1774
|
ne00 % 32 == 0 && ne00 >= 64 &&
|
1733
|
-
|
1775
|
+
dst_rows > dst_rows_min) {
|
1734
1776
|
|
1735
1777
|
// some Metal matrix data types require aligned pointers
|
1736
1778
|
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
@@ -1772,26 +1814,26 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1772
1814
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1773
1815
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1774
1816
|
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
|
1775
|
-
[encoder setBytes:&
|
1776
|
-
[encoder setBytes:&
|
1777
|
-
[encoder setBytes:&
|
1778
|
-
[encoder setBytes:&
|
1779
|
-
[encoder setBytes:&
|
1780
|
-
[encoder setBytes:&
|
1781
|
-
[encoder setBytes:&
|
1782
|
-
[encoder setBytes:&
|
1783
|
-
[encoder setBytes:&
|
1784
|
-
[encoder setBytes:&
|
1785
|
-
[encoder setBytes:&
|
1786
|
-
[encoder setBytes:&
|
1787
|
-
[encoder setBytes:&
|
1788
|
-
[encoder setBytes:&
|
1789
|
-
[encoder setBytes:&
|
1790
|
-
[encoder setBytes:&
|
1791
|
-
|
1792
|
-
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 +
|
1793
|
-
|
1794
|
-
[encoder dispatchThreadgroups:MTLSizeMake((
|
1817
|
+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
1818
|
+
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
1819
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
1820
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
|
1821
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
|
1822
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
|
1823
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
|
1824
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
1825
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
1826
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
1827
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
1828
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
1829
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
1830
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
1831
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
|
1832
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
|
1833
|
+
|
1834
|
+
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
|
1835
|
+
|
1836
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
1795
1837
|
} else {
|
1796
1838
|
int nth0 = 32;
|
1797
1839
|
int nth1 = 1;
|
@@ -1926,7 +1968,12 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1926
1968
|
{
|
1927
1969
|
nth0 = 4;
|
1928
1970
|
nth1 = 16;
|
1971
|
+
#if QK_K == 64
|
1972
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
|
1973
|
+
#else
|
1929
1974
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
1975
|
+
#endif
|
1976
|
+
|
1930
1977
|
} break;
|
1931
1978
|
default:
|
1932
1979
|
{
|
@@ -1939,72 +1986,72 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1939
1986
|
GGML_ASSERT(ne00 >= nth0*nth1);
|
1940
1987
|
}
|
1941
1988
|
|
1942
|
-
const int64_t _ne1 = 1; // kernels needs a reference in constant memory
|
1943
|
-
|
1944
1989
|
[encoder setComputePipelineState:pipeline];
|
1945
1990
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1946
1991
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1947
1992
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1948
1993
|
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
|
1949
|
-
[encoder setBytes:&
|
1950
|
-
[encoder setBytes:&
|
1951
|
-
[encoder setBytes:&
|
1952
|
-
[encoder setBytes:&
|
1953
|
-
[encoder setBytes:&
|
1954
|
-
[encoder setBytes:&
|
1955
|
-
[encoder setBytes:&
|
1956
|
-
[encoder setBytes:&
|
1957
|
-
[encoder setBytes:&
|
1958
|
-
[encoder setBytes:&
|
1959
|
-
[encoder setBytes:&
|
1960
|
-
[encoder setBytes:&
|
1961
|
-
[encoder setBytes:&
|
1962
|
-
[encoder setBytes:&
|
1963
|
-
[encoder setBytes:&
|
1964
|
-
[encoder setBytes:&
|
1965
|
-
[encoder setBytes:&
|
1966
|
-
[encoder setBytes:&
|
1967
|
-
[encoder setBytes:&
|
1968
|
-
|
1994
|
+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
1995
|
+
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
1996
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
1997
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
|
1998
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
|
1999
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
|
2000
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
|
2001
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
|
2002
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
|
2003
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
|
2004
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
|
2005
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
|
2006
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
|
2007
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
|
2008
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
|
2009
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
|
2010
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
|
2011
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
|
2012
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
|
2013
|
+
|
2014
|
+
const int64_t _ne1 = 1;
|
2015
|
+
const int tgz = dst_rows;
|
1969
2016
|
|
1970
2017
|
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
|
1971
2018
|
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
|
1972
2019
|
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
|
1973
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1,
|
2020
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1974
2021
|
}
|
1975
2022
|
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
1976
2023
|
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
1977
2024
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1978
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1,
|
2025
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1979
2026
|
}
|
1980
2027
|
else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
|
1981
2028
|
const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
1982
2029
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1983
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1,
|
2030
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1984
2031
|
}
|
1985
2032
|
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
|
1986
2033
|
const int mem_size = 32*sizeof(float);
|
1987
2034
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1988
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1,
|
2035
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1989
2036
|
}
|
1990
2037
|
else if (src0t == GGML_TYPE_Q4_K) {
|
1991
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1,
|
2038
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1992
2039
|
}
|
1993
2040
|
else if (src0t == GGML_TYPE_Q3_K) {
|
1994
2041
|
#ifdef GGML_QKK_64
|
1995
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1,
|
2042
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1996
2043
|
#else
|
1997
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1,
|
2044
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1998
2045
|
#endif
|
1999
2046
|
}
|
2000
2047
|
else if (src0t == GGML_TYPE_Q5_K) {
|
2001
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1,
|
2048
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2002
2049
|
}
|
2003
2050
|
else if (src0t == GGML_TYPE_Q6_K) {
|
2004
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1,
|
2051
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2005
2052
|
} else {
|
2006
|
-
const int64_t ny = (_ne1 + nrows - 1)/nrows;
|
2007
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny,
|
2053
|
+
const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
|
2054
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2008
2055
|
}
|
2009
2056
|
}
|
2010
2057
|
} break;
|