whisper.rn 0.4.0-rc.3 → 0.4.0-rc.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +6 -6
- package/android/build.gradle +4 -0
- package/android/src/main/CMakeLists.txt +7 -0
- package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
- package/android/src/main/java/com/rnwhisper/RNWhisper.java +6 -1
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +53 -135
- package/android/src/main/jni-utils.h +76 -0
- package/android/src/main/jni.cpp +188 -109
- package/cpp/README.md +1 -1
- package/cpp/coreml/whisper-encoder-impl.h +1 -1
- package/cpp/coreml/whisper-encoder.h +4 -0
- package/cpp/coreml/whisper-encoder.mm +4 -2
- package/cpp/ggml-alloc.c +451 -282
- package/cpp/ggml-alloc.h +74 -8
- package/cpp/ggml-backend-impl.h +112 -0
- package/cpp/ggml-backend.c +1357 -0
- package/cpp/ggml-backend.h +181 -0
- package/cpp/ggml-impl.h +243 -0
- package/cpp/{ggml-metal.metal → ggml-metal-whisper.metal} +1556 -329
- package/cpp/ggml-metal.h +28 -1
- package/cpp/ggml-metal.m +1128 -308
- package/cpp/ggml-quants.c +7382 -0
- package/cpp/ggml-quants.h +224 -0
- package/cpp/ggml.c +3848 -5245
- package/cpp/ggml.h +353 -155
- package/cpp/rn-audioutils.cpp +68 -0
- package/cpp/rn-audioutils.h +14 -0
- package/cpp/rn-whisper-log.h +11 -0
- package/cpp/rn-whisper.cpp +141 -59
- package/cpp/rn-whisper.h +47 -15
- package/cpp/whisper.cpp +1750 -964
- package/cpp/whisper.h +97 -15
- package/ios/RNWhisper.mm +15 -9
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +4 -0
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
- package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +19 -0
- package/ios/RNWhisperAudioUtils.h +0 -2
- package/ios/RNWhisperAudioUtils.m +0 -56
- package/ios/RNWhisperContext.h +8 -12
- package/ios/RNWhisperContext.mm +132 -138
- package/jest/mock.js +1 -1
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/index.js +28 -9
- package/lib/commonjs/index.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/index.js +28 -9
- package/lib/module/index.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +7 -1
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +7 -2
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +6 -5
- package/src/NativeRNWhisper.ts +8 -1
- package/src/index.ts +29 -17
- package/src/version.json +1 -1
- package/whisper-rn.podspec +1 -2
|
@@ -3,6 +3,8 @@
|
|
|
3
3
|
using namespace metal;
|
|
4
4
|
|
|
5
5
|
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
|
6
|
+
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
|
7
|
+
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
|
|
6
8
|
|
|
7
9
|
#define QK4_0 32
|
|
8
10
|
#define QR4_0 2
|
|
@@ -13,23 +15,187 @@ typedef struct {
|
|
|
13
15
|
|
|
14
16
|
#define QK4_1 32
|
|
15
17
|
typedef struct {
|
|
16
|
-
half d;
|
|
17
|
-
half m;
|
|
18
|
+
half d; // delta
|
|
19
|
+
half m; // min
|
|
18
20
|
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
|
19
21
|
} block_q4_1;
|
|
20
22
|
|
|
23
|
+
#define QK5_0 32
|
|
24
|
+
typedef struct {
|
|
25
|
+
half d; // delta
|
|
26
|
+
uint8_t qh[4]; // 5-th bit of quants
|
|
27
|
+
uint8_t qs[QK5_0 / 2]; // nibbles / quants
|
|
28
|
+
} block_q5_0;
|
|
29
|
+
|
|
30
|
+
#define QK5_1 32
|
|
31
|
+
typedef struct {
|
|
32
|
+
half d; // delta
|
|
33
|
+
half m; // min
|
|
34
|
+
uint8_t qh[4]; // 5-th bit of quants
|
|
35
|
+
uint8_t qs[QK5_1 / 2]; // nibbles / quants
|
|
36
|
+
} block_q5_1;
|
|
37
|
+
|
|
21
38
|
#define QK8_0 32
|
|
22
39
|
typedef struct {
|
|
23
40
|
half d; // delta
|
|
24
41
|
int8_t qs[QK8_0]; // quants
|
|
25
42
|
} block_q8_0;
|
|
26
43
|
|
|
44
|
+
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
|
45
|
+
|
|
46
|
+
enum ggml_sort_order {
|
|
47
|
+
GGML_SORT_ASC,
|
|
48
|
+
GGML_SORT_DESC,
|
|
49
|
+
};
|
|
50
|
+
|
|
51
|
+
// general-purpose kernel for addition, multiplication and division of two tensors
|
|
52
|
+
// pros: works for non-contiguous tensors, supports broadcast across all dims
|
|
53
|
+
// cons: not very efficient
|
|
27
54
|
kernel void kernel_add(
|
|
28
|
-
device const
|
|
29
|
-
device const
|
|
30
|
-
device
|
|
31
|
-
|
|
32
|
-
|
|
55
|
+
device const char * src0,
|
|
56
|
+
device const char * src1,
|
|
57
|
+
device char * dst,
|
|
58
|
+
constant int64_t & ne00,
|
|
59
|
+
constant int64_t & ne01,
|
|
60
|
+
constant int64_t & ne02,
|
|
61
|
+
constant int64_t & ne03,
|
|
62
|
+
constant int64_t & nb00,
|
|
63
|
+
constant int64_t & nb01,
|
|
64
|
+
constant int64_t & nb02,
|
|
65
|
+
constant int64_t & nb03,
|
|
66
|
+
constant int64_t & ne10,
|
|
67
|
+
constant int64_t & ne11,
|
|
68
|
+
constant int64_t & ne12,
|
|
69
|
+
constant int64_t & ne13,
|
|
70
|
+
constant int64_t & nb10,
|
|
71
|
+
constant int64_t & nb11,
|
|
72
|
+
constant int64_t & nb12,
|
|
73
|
+
constant int64_t & nb13,
|
|
74
|
+
constant int64_t & ne0,
|
|
75
|
+
constant int64_t & ne1,
|
|
76
|
+
constant int64_t & ne2,
|
|
77
|
+
constant int64_t & ne3,
|
|
78
|
+
constant int64_t & nb0,
|
|
79
|
+
constant int64_t & nb1,
|
|
80
|
+
constant int64_t & nb2,
|
|
81
|
+
constant int64_t & nb3,
|
|
82
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
83
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
84
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
85
|
+
const int64_t i03 = tgpig.z;
|
|
86
|
+
const int64_t i02 = tgpig.y;
|
|
87
|
+
const int64_t i01 = tgpig.x;
|
|
88
|
+
|
|
89
|
+
const int64_t i13 = i03 % ne13;
|
|
90
|
+
const int64_t i12 = i02 % ne12;
|
|
91
|
+
const int64_t i11 = i01 % ne11;
|
|
92
|
+
|
|
93
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
|
94
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
|
95
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
|
96
|
+
|
|
97
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
|
98
|
+
const int i10 = i0 % ne10;
|
|
99
|
+
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
|
|
100
|
+
}
|
|
101
|
+
}
|
|
102
|
+
|
|
103
|
+
kernel void kernel_mul(
|
|
104
|
+
device const char * src0,
|
|
105
|
+
device const char * src1,
|
|
106
|
+
device char * dst,
|
|
107
|
+
constant int64_t & ne00,
|
|
108
|
+
constant int64_t & ne01,
|
|
109
|
+
constant int64_t & ne02,
|
|
110
|
+
constant int64_t & ne03,
|
|
111
|
+
constant int64_t & nb00,
|
|
112
|
+
constant int64_t & nb01,
|
|
113
|
+
constant int64_t & nb02,
|
|
114
|
+
constant int64_t & nb03,
|
|
115
|
+
constant int64_t & ne10,
|
|
116
|
+
constant int64_t & ne11,
|
|
117
|
+
constant int64_t & ne12,
|
|
118
|
+
constant int64_t & ne13,
|
|
119
|
+
constant int64_t & nb10,
|
|
120
|
+
constant int64_t & nb11,
|
|
121
|
+
constant int64_t & nb12,
|
|
122
|
+
constant int64_t & nb13,
|
|
123
|
+
constant int64_t & ne0,
|
|
124
|
+
constant int64_t & ne1,
|
|
125
|
+
constant int64_t & ne2,
|
|
126
|
+
constant int64_t & ne3,
|
|
127
|
+
constant int64_t & nb0,
|
|
128
|
+
constant int64_t & nb1,
|
|
129
|
+
constant int64_t & nb2,
|
|
130
|
+
constant int64_t & nb3,
|
|
131
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
132
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
133
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
134
|
+
const int64_t i03 = tgpig.z;
|
|
135
|
+
const int64_t i02 = tgpig.y;
|
|
136
|
+
const int64_t i01 = tgpig.x;
|
|
137
|
+
|
|
138
|
+
const int64_t i13 = i03 % ne13;
|
|
139
|
+
const int64_t i12 = i02 % ne12;
|
|
140
|
+
const int64_t i11 = i01 % ne11;
|
|
141
|
+
|
|
142
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
|
143
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
|
144
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
|
145
|
+
|
|
146
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
|
147
|
+
const int i10 = i0 % ne10;
|
|
148
|
+
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
|
|
149
|
+
}
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
kernel void kernel_div(
|
|
153
|
+
device const char * src0,
|
|
154
|
+
device const char * src1,
|
|
155
|
+
device char * dst,
|
|
156
|
+
constant int64_t & ne00,
|
|
157
|
+
constant int64_t & ne01,
|
|
158
|
+
constant int64_t & ne02,
|
|
159
|
+
constant int64_t & ne03,
|
|
160
|
+
constant int64_t & nb00,
|
|
161
|
+
constant int64_t & nb01,
|
|
162
|
+
constant int64_t & nb02,
|
|
163
|
+
constant int64_t & nb03,
|
|
164
|
+
constant int64_t & ne10,
|
|
165
|
+
constant int64_t & ne11,
|
|
166
|
+
constant int64_t & ne12,
|
|
167
|
+
constant int64_t & ne13,
|
|
168
|
+
constant int64_t & nb10,
|
|
169
|
+
constant int64_t & nb11,
|
|
170
|
+
constant int64_t & nb12,
|
|
171
|
+
constant int64_t & nb13,
|
|
172
|
+
constant int64_t & ne0,
|
|
173
|
+
constant int64_t & ne1,
|
|
174
|
+
constant int64_t & ne2,
|
|
175
|
+
constant int64_t & ne3,
|
|
176
|
+
constant int64_t & nb0,
|
|
177
|
+
constant int64_t & nb1,
|
|
178
|
+
constant int64_t & nb2,
|
|
179
|
+
constant int64_t & nb3,
|
|
180
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
181
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
182
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
183
|
+
const int64_t i03 = tgpig.z;
|
|
184
|
+
const int64_t i02 = tgpig.y;
|
|
185
|
+
const int64_t i01 = tgpig.x;
|
|
186
|
+
|
|
187
|
+
const int64_t i13 = i03 % ne13;
|
|
188
|
+
const int64_t i12 = i02 % ne12;
|
|
189
|
+
const int64_t i11 = i01 % ne11;
|
|
190
|
+
|
|
191
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
|
192
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
|
193
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
|
194
|
+
|
|
195
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
|
196
|
+
const int i10 = i0 % ne10;
|
|
197
|
+
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
|
|
198
|
+
}
|
|
33
199
|
}
|
|
34
200
|
|
|
35
201
|
// assumption: src1 is a row
|
|
@@ -38,34 +204,41 @@ kernel void kernel_add_row(
|
|
|
38
204
|
device const float4 * src0,
|
|
39
205
|
device const float4 * src1,
|
|
40
206
|
device float4 * dst,
|
|
41
|
-
constant int64_t & nb,
|
|
207
|
+
constant int64_t & nb [[buffer(27)]],
|
|
42
208
|
uint tpig[[thread_position_in_grid]]) {
|
|
43
209
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
|
44
210
|
}
|
|
45
211
|
|
|
46
|
-
kernel void
|
|
212
|
+
kernel void kernel_mul_row(
|
|
47
213
|
device const float4 * src0,
|
|
48
214
|
device const float4 * src1,
|
|
49
215
|
device float4 * dst,
|
|
216
|
+
constant int64_t & nb [[buffer(27)]],
|
|
50
217
|
uint tpig[[thread_position_in_grid]]) {
|
|
51
|
-
dst[tpig] = src0[tpig] * src1[tpig];
|
|
218
|
+
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
|
52
219
|
}
|
|
53
220
|
|
|
54
|
-
|
|
55
|
-
// broadcast src1 into src0
|
|
56
|
-
kernel void kernel_mul_row(
|
|
221
|
+
kernel void kernel_div_row(
|
|
57
222
|
device const float4 * src0,
|
|
58
223
|
device const float4 * src1,
|
|
59
224
|
device float4 * dst,
|
|
60
|
-
constant int64_t & nb,
|
|
225
|
+
constant int64_t & nb [[buffer(27)]],
|
|
61
226
|
uint tpig[[thread_position_in_grid]]) {
|
|
62
|
-
dst[tpig] = src0[tpig]
|
|
227
|
+
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
|
63
228
|
}
|
|
64
229
|
|
|
65
230
|
kernel void kernel_scale(
|
|
231
|
+
device const float * src0,
|
|
232
|
+
device float * dst,
|
|
233
|
+
constant float & scale,
|
|
234
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
235
|
+
dst[tpig] = src0[tpig] * scale;
|
|
236
|
+
}
|
|
237
|
+
|
|
238
|
+
kernel void kernel_scale_4(
|
|
66
239
|
device const float4 * src0,
|
|
67
240
|
device float4 * dst,
|
|
68
|
-
constant float
|
|
241
|
+
constant float & scale,
|
|
69
242
|
uint tpig[[thread_position_in_grid]]) {
|
|
70
243
|
dst[tpig] = src0[tpig] * scale;
|
|
71
244
|
}
|
|
@@ -85,6 +258,61 @@ kernel void kernel_relu(
|
|
|
85
258
|
dst[tpig] = max(0.0f, src0[tpig]);
|
|
86
259
|
}
|
|
87
260
|
|
|
261
|
+
kernel void kernel_sqr(
|
|
262
|
+
device const float * src0,
|
|
263
|
+
device float * dst,
|
|
264
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
265
|
+
dst[tpig] = src0[tpig] * src0[tpig];
|
|
266
|
+
}
|
|
267
|
+
|
|
268
|
+
kernel void kernel_sum_rows(
|
|
269
|
+
device const float * src0,
|
|
270
|
+
device float * dst,
|
|
271
|
+
constant int64_t & ne00,
|
|
272
|
+
constant int64_t & ne01,
|
|
273
|
+
constant int64_t & ne02,
|
|
274
|
+
constant int64_t & ne03,
|
|
275
|
+
constant int64_t & nb00,
|
|
276
|
+
constant int64_t & nb01,
|
|
277
|
+
constant int64_t & nb02,
|
|
278
|
+
constant int64_t & nb03,
|
|
279
|
+
constant int64_t & ne10,
|
|
280
|
+
constant int64_t & ne11,
|
|
281
|
+
constant int64_t & ne12,
|
|
282
|
+
constant int64_t & ne13,
|
|
283
|
+
constant int64_t & nb10,
|
|
284
|
+
constant int64_t & nb11,
|
|
285
|
+
constant int64_t & nb12,
|
|
286
|
+
constant int64_t & nb13,
|
|
287
|
+
constant int64_t & ne0,
|
|
288
|
+
constant int64_t & ne1,
|
|
289
|
+
constant int64_t & ne2,
|
|
290
|
+
constant int64_t & ne3,
|
|
291
|
+
constant int64_t & nb0,
|
|
292
|
+
constant int64_t & nb1,
|
|
293
|
+
constant int64_t & nb2,
|
|
294
|
+
constant int64_t & nb3,
|
|
295
|
+
uint3 tpig[[thread_position_in_grid]]) {
|
|
296
|
+
int64_t i3 = tpig.z;
|
|
297
|
+
int64_t i2 = tpig.y;
|
|
298
|
+
int64_t i1 = tpig.x;
|
|
299
|
+
|
|
300
|
+
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
|
|
301
|
+
return;
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
|
|
305
|
+
device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
|
|
306
|
+
|
|
307
|
+
float row_sum = 0;
|
|
308
|
+
|
|
309
|
+
for (int64_t i0 = 0; i0 < ne00; i0++) {
|
|
310
|
+
row_sum += src_row[i0];
|
|
311
|
+
}
|
|
312
|
+
|
|
313
|
+
dst_row[0] = row_sum;
|
|
314
|
+
}
|
|
315
|
+
|
|
88
316
|
constant float GELU_COEF_A = 0.044715f;
|
|
89
317
|
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
|
90
318
|
|
|
@@ -103,82 +331,165 @@ kernel void kernel_gelu(
|
|
|
103
331
|
|
|
104
332
|
kernel void kernel_soft_max(
|
|
105
333
|
device const float * src0,
|
|
334
|
+
device const float * src1,
|
|
106
335
|
device float * dst,
|
|
107
336
|
constant int64_t & ne00,
|
|
108
337
|
constant int64_t & ne01,
|
|
109
338
|
constant int64_t & ne02,
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
339
|
+
constant float & scale,
|
|
340
|
+
threadgroup float * buf [[threadgroup(0)]],
|
|
341
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
|
342
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
|
343
|
+
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
344
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
345
|
+
uint ntg[[threads_per_threadgroup]]) {
|
|
346
|
+
const int64_t i03 = (tgpig) / (ne02*ne01);
|
|
347
|
+
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
|
348
|
+
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
|
349
|
+
|
|
350
|
+
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
351
|
+
device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
|
|
352
|
+
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
119
353
|
|
|
120
354
|
// parallel max
|
|
121
|
-
float lmax =
|
|
122
|
-
|
|
123
|
-
|
|
355
|
+
float lmax = -INFINITY;
|
|
356
|
+
|
|
357
|
+
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
358
|
+
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
|
359
|
+
}
|
|
360
|
+
|
|
361
|
+
// find the max value in the block
|
|
362
|
+
float max_val = simd_max(lmax);
|
|
363
|
+
if (ntg > N_SIMDWIDTH) {
|
|
364
|
+
if (sgitg == 0) {
|
|
365
|
+
buf[tiisg] = -INFINITY;
|
|
366
|
+
}
|
|
367
|
+
|
|
368
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
369
|
+
|
|
370
|
+
if (tiisg == 0) {
|
|
371
|
+
buf[sgitg] = max_val;
|
|
372
|
+
}
|
|
373
|
+
|
|
374
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
375
|
+
|
|
376
|
+
max_val = buf[tiisg];
|
|
377
|
+
max_val = simd_max(max_val);
|
|
124
378
|
}
|
|
125
|
-
const float max = simd_max(lmax);
|
|
126
379
|
|
|
127
380
|
// parallel sum
|
|
128
381
|
float lsum = 0.0f;
|
|
129
|
-
for (int i00 = tpitg
|
|
130
|
-
const float exp_psrc0 = exp(psrc0[i00] -
|
|
382
|
+
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
383
|
+
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
|
131
384
|
lsum += exp_psrc0;
|
|
132
|
-
// Remember the result of exp here. exp is expensive, so we really do not
|
|
133
|
-
// whish to compute it twice.
|
|
134
385
|
pdst[i00] = exp_psrc0;
|
|
135
386
|
}
|
|
136
387
|
|
|
137
|
-
|
|
388
|
+
float sum = simd_sum(lsum);
|
|
389
|
+
if (ntg > N_SIMDWIDTH) {
|
|
390
|
+
if (sgitg == 0) {
|
|
391
|
+
buf[tiisg] = 0.0f;
|
|
392
|
+
}
|
|
393
|
+
|
|
394
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
138
395
|
|
|
139
|
-
|
|
140
|
-
|
|
396
|
+
if (tiisg == 0) {
|
|
397
|
+
buf[sgitg] = sum;
|
|
398
|
+
}
|
|
399
|
+
|
|
400
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
401
|
+
|
|
402
|
+
sum = buf[tiisg];
|
|
403
|
+
sum = simd_sum(sum);
|
|
404
|
+
}
|
|
405
|
+
|
|
406
|
+
const float inv_sum = 1.0f/sum;
|
|
407
|
+
|
|
408
|
+
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
409
|
+
pdst[i00] *= inv_sum;
|
|
141
410
|
}
|
|
142
411
|
}
|
|
143
412
|
|
|
144
413
|
kernel void kernel_soft_max_4(
|
|
145
414
|
device const float * src0,
|
|
415
|
+
device const float * src1,
|
|
146
416
|
device float * dst,
|
|
147
417
|
constant int64_t & ne00,
|
|
148
418
|
constant int64_t & ne01,
|
|
149
419
|
constant int64_t & ne02,
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
420
|
+
constant float & scale,
|
|
421
|
+
threadgroup float * buf [[threadgroup(0)]],
|
|
422
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
|
423
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
|
424
|
+
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
425
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
426
|
+
uint ntg[[threads_per_threadgroup]]) {
|
|
427
|
+
const int64_t i03 = (tgpig) / (ne02*ne01);
|
|
428
|
+
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
|
429
|
+
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
|
430
|
+
|
|
431
|
+
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
432
|
+
device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
|
433
|
+
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
159
434
|
|
|
160
435
|
// parallel max
|
|
161
|
-
float4 lmax4 =
|
|
162
|
-
|
|
163
|
-
|
|
436
|
+
float4 lmax4 = -INFINITY;
|
|
437
|
+
|
|
438
|
+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
439
|
+
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
|
164
440
|
}
|
|
165
|
-
float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
166
441
|
|
|
167
|
-
const float
|
|
442
|
+
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
443
|
+
|
|
444
|
+
float max_val = simd_max(lmax);
|
|
445
|
+
if (ntg > N_SIMDWIDTH) {
|
|
446
|
+
if (sgitg == 0) {
|
|
447
|
+
buf[tiisg] = -INFINITY;
|
|
448
|
+
}
|
|
449
|
+
|
|
450
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
451
|
+
|
|
452
|
+
if (tiisg == 0) {
|
|
453
|
+
buf[sgitg] = max_val;
|
|
454
|
+
}
|
|
455
|
+
|
|
456
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
457
|
+
|
|
458
|
+
max_val = buf[tiisg];
|
|
459
|
+
max_val = simd_max(max_val);
|
|
460
|
+
}
|
|
168
461
|
|
|
169
462
|
// parallel sum
|
|
170
463
|
float4 lsum4 = 0.0f;
|
|
171
|
-
for (int i00 = tpitg
|
|
172
|
-
const float4 exp_psrc4 = exp(psrc4[i00] -
|
|
464
|
+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
465
|
+
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
|
173
466
|
lsum4 += exp_psrc4;
|
|
174
467
|
pdst4[i00] = exp_psrc4;
|
|
175
468
|
}
|
|
176
|
-
float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
|
177
469
|
|
|
178
|
-
const float
|
|
470
|
+
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
|
471
|
+
float sum = simd_sum(lsum);
|
|
472
|
+
if (ntg > N_SIMDWIDTH) {
|
|
473
|
+
if (sgitg == 0) {
|
|
474
|
+
buf[tiisg] = 0.0f;
|
|
475
|
+
}
|
|
179
476
|
|
|
180
|
-
|
|
181
|
-
|
|
477
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
478
|
+
|
|
479
|
+
if (tiisg == 0) {
|
|
480
|
+
buf[sgitg] = sum;
|
|
481
|
+
}
|
|
482
|
+
|
|
483
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
484
|
+
|
|
485
|
+
sum = buf[tiisg];
|
|
486
|
+
sum = simd_sum(sum);
|
|
487
|
+
}
|
|
488
|
+
|
|
489
|
+
const float inv_sum = 1.0f/sum;
|
|
490
|
+
|
|
491
|
+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
492
|
+
pdst4[i00] *= inv_sum;
|
|
182
493
|
}
|
|
183
494
|
}
|
|
184
495
|
|
|
@@ -197,7 +508,7 @@ kernel void kernel_diag_mask_inf(
|
|
|
197
508
|
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
|
|
198
509
|
} else {
|
|
199
510
|
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
|
|
200
|
-
|
|
511
|
+
}
|
|
201
512
|
}
|
|
202
513
|
|
|
203
514
|
kernel void kernel_diag_mask_inf_8(
|
|
@@ -285,16 +596,16 @@ kernel void kernel_rms_norm(
|
|
|
285
596
|
constant int64_t & ne00,
|
|
286
597
|
constant uint64_t & nb01,
|
|
287
598
|
constant float & eps,
|
|
288
|
-
threadgroup float *
|
|
599
|
+
threadgroup float * buf [[threadgroup(0)]],
|
|
289
600
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
290
601
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
291
602
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
292
603
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
293
604
|
uint ntg[[threads_per_threadgroup]]) {
|
|
294
605
|
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
|
295
|
-
|
|
296
|
-
float4 sumf=0;
|
|
297
|
-
float all_sum=0;
|
|
606
|
+
|
|
607
|
+
float4 sumf = 0;
|
|
608
|
+
float all_sum = 0;
|
|
298
609
|
|
|
299
610
|
// parallel sum
|
|
300
611
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
@@ -302,35 +613,30 @@ kernel void kernel_rms_norm(
|
|
|
302
613
|
}
|
|
303
614
|
all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
|
|
304
615
|
all_sum = simd_sum(all_sum);
|
|
305
|
-
if (
|
|
306
|
-
|
|
307
|
-
|
|
616
|
+
if (ntg > N_SIMDWIDTH) {
|
|
617
|
+
if (sgitg == 0) {
|
|
618
|
+
buf[tiisg] = 0.0f;
|
|
619
|
+
}
|
|
308
620
|
|
|
309
|
-
|
|
310
|
-
// broadcast, simd group number is ntg / 32
|
|
311
|
-
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
|
312
|
-
if (tpitg < i) {
|
|
313
|
-
sum[tpitg] += sum[tpitg + i];
|
|
314
|
-
}
|
|
315
|
-
}
|
|
316
|
-
if (tpitg == 0) {
|
|
317
|
-
for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
|
|
318
|
-
sum[0] /= ne00;
|
|
319
|
-
}
|
|
621
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
320
622
|
|
|
321
|
-
|
|
623
|
+
if (tiisg == 0) {
|
|
624
|
+
buf[sgitg] = all_sum;
|
|
625
|
+
}
|
|
322
626
|
|
|
323
|
-
|
|
627
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
628
|
+
|
|
629
|
+
all_sum = buf[tiisg];
|
|
630
|
+
all_sum = simd_sum(all_sum);
|
|
631
|
+
}
|
|
632
|
+
|
|
633
|
+
const float mean = all_sum/ne00;
|
|
324
634
|
const float scale = 1.0f/sqrt(mean + eps);
|
|
325
635
|
|
|
326
636
|
device float4 * y = (device float4 *) (dst + tgpig*ne00);
|
|
327
|
-
device float * y_scalar = (device float *) y;
|
|
328
637
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
329
638
|
y[i00] = x[i00] * scale;
|
|
330
639
|
}
|
|
331
|
-
if (tpitg == 0) {
|
|
332
|
-
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
|
|
333
|
-
}
|
|
334
640
|
}
|
|
335
641
|
|
|
336
642
|
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
|
@@ -339,8 +645,11 @@ kernel void kernel_rms_norm(
|
|
|
339
645
|
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
|
340
646
|
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
|
|
341
647
|
float d = qb_curr->d;
|
|
648
|
+
|
|
342
649
|
float2 acc = 0.f;
|
|
650
|
+
|
|
343
651
|
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
|
|
652
|
+
|
|
344
653
|
for (int i = 0; i < 8; i+=2) {
|
|
345
654
|
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
|
346
655
|
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
|
@@ -357,8 +666,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
|
|
|
357
666
|
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
|
|
358
667
|
float d = qb_curr->d;
|
|
359
668
|
float m = qb_curr->m;
|
|
360
|
-
|
|
669
|
+
|
|
361
670
|
float2 acc = 0.f;
|
|
671
|
+
|
|
672
|
+
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
|
|
673
|
+
|
|
362
674
|
for (int i = 0; i < 8; i+=2) {
|
|
363
675
|
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
|
364
676
|
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
|
@@ -368,31 +680,92 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
|
|
|
368
680
|
return d * (acc[0] + acc[1]) + sumy * m;
|
|
369
681
|
}
|
|
370
682
|
|
|
683
|
+
// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
|
684
|
+
// il indicates where the q5 quants begin (0 or QK5_0/4)
|
|
685
|
+
// we assume that the yl's have been multiplied with the appropriate scale factor
|
|
686
|
+
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
|
687
|
+
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
|
|
688
|
+
float d = qb_curr->d;
|
|
689
|
+
|
|
690
|
+
float2 acc = 0.f;
|
|
691
|
+
|
|
692
|
+
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
|
|
693
|
+
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
|
|
694
|
+
|
|
695
|
+
for (int i = 0; i < 8; i+=2) {
|
|
696
|
+
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
|
|
697
|
+
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
|
|
698
|
+
acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
|
|
699
|
+
+ yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
|
|
700
|
+
}
|
|
701
|
+
return d * (sumy * -16.f + acc[0] + acc[1]);
|
|
702
|
+
}
|
|
703
|
+
|
|
704
|
+
// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
|
|
705
|
+
// il indicates where the q5 quants begin (0 or QK5_1/4)
|
|
706
|
+
// we assume that the yl's have been multiplied with the appropriate scale factor
|
|
707
|
+
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
|
708
|
+
inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
|
|
709
|
+
float d = qb_curr->d;
|
|
710
|
+
float m = qb_curr->m;
|
|
711
|
+
|
|
712
|
+
float2 acc = 0.f;
|
|
713
|
+
|
|
714
|
+
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
|
|
715
|
+
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
|
|
716
|
+
|
|
717
|
+
for (int i = 0; i < 8; i+=2) {
|
|
718
|
+
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
|
|
719
|
+
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
|
|
720
|
+
acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
|
|
721
|
+
+ yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
|
|
722
|
+
}
|
|
723
|
+
return d * (acc[0] + acc[1]) + sumy * m;
|
|
724
|
+
}
|
|
725
|
+
|
|
371
726
|
// putting them in the kernel cause a significant performance penalty
|
|
372
|
-
#define N_DST 4
|
|
373
|
-
#define N_SIMDGROUP 2
|
|
374
|
-
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
|
727
|
+
#define N_DST 4 // each SIMD group works on 4 rows
|
|
728
|
+
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
|
375
729
|
//Note: This is a template, but strictly speaking it only applies to
|
|
376
730
|
// quantizations where the block size is 32. It also does not
|
|
377
731
|
// giard against the number of rows not being divisible by
|
|
378
732
|
// N_DST, so this is another explicit assumption of the implementation.
|
|
379
733
|
template<typename block_q_type, int nr, int nsg, int nw>
|
|
380
|
-
void mul_vec_q_n_f32(
|
|
381
|
-
|
|
382
|
-
|
|
734
|
+
void mul_vec_q_n_f32(
|
|
735
|
+
device const void * src0,
|
|
736
|
+
device const float * src1,
|
|
737
|
+
device float * dst,
|
|
738
|
+
int64_t ne00,
|
|
739
|
+
int64_t ne01,
|
|
740
|
+
int64_t ne02,
|
|
741
|
+
int64_t ne10,
|
|
742
|
+
int64_t ne12,
|
|
743
|
+
int64_t ne0,
|
|
744
|
+
int64_t ne1,
|
|
745
|
+
uint r2,
|
|
746
|
+
uint r3,
|
|
747
|
+
uint3 tgpig, uint tiisg, uint sgitg) {
|
|
383
748
|
const int nb = ne00/QK4_0;
|
|
749
|
+
|
|
384
750
|
const int r0 = tgpig.x;
|
|
385
751
|
const int r1 = tgpig.y;
|
|
386
752
|
const int im = tgpig.z;
|
|
753
|
+
|
|
387
754
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
|
388
|
-
|
|
755
|
+
|
|
756
|
+
const uint i12 = im%ne12;
|
|
757
|
+
const uint i13 = im/ne12;
|
|
758
|
+
|
|
759
|
+
const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
760
|
+
|
|
389
761
|
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
|
|
390
762
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
391
|
-
float yl[16]; // src1 vector cache
|
|
392
|
-
float sumf[nr]={0.f};
|
|
393
763
|
|
|
394
|
-
|
|
395
|
-
|
|
764
|
+
float yl[16]; // src1 vector cache
|
|
765
|
+
float sumf[nr] = {0.f};
|
|
766
|
+
|
|
767
|
+
const int ix = (tiisg/2);
|
|
768
|
+
const int il = (tiisg%2)*8;
|
|
396
769
|
|
|
397
770
|
device const float * yb = y + ix * QK4_0 + il;
|
|
398
771
|
|
|
@@ -403,6 +776,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|
|
403
776
|
sumy += yb[i] + yb[i+1];
|
|
404
777
|
yl[i+0] = yb[i+ 0];
|
|
405
778
|
yl[i+1] = yb[i+ 1]/256.f;
|
|
779
|
+
|
|
406
780
|
sumy += yb[i+16] + yb[i+17];
|
|
407
781
|
yl[i+8] = yb[i+16]/16.f;
|
|
408
782
|
yl[i+9] = yb[i+17]/4096.f;
|
|
@@ -418,12 +792,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|
|
418
792
|
for (int row = 0; row < nr; ++row) {
|
|
419
793
|
const float tot = simd_sum(sumf[row]);
|
|
420
794
|
if (tiisg == 0 && first_row + row < ne01) {
|
|
421
|
-
dst[
|
|
795
|
+
dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
|
|
422
796
|
}
|
|
423
797
|
}
|
|
424
798
|
}
|
|
425
799
|
|
|
426
|
-
kernel void
|
|
800
|
+
kernel void kernel_mul_mv_q4_0_f32(
|
|
427
801
|
device const void * src0,
|
|
428
802
|
device const float * src1,
|
|
429
803
|
device float * dst,
|
|
@@ -432,16 +806,17 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|
|
432
806
|
constant int64_t & ne02[[buffer(5)]],
|
|
433
807
|
constant int64_t & ne10[[buffer(9)]],
|
|
434
808
|
constant int64_t & ne12[[buffer(11)]],
|
|
435
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
436
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
437
|
-
constant uint &
|
|
809
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
810
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
811
|
+
constant uint & r2 [[buffer(17)]],
|
|
812
|
+
constant uint & r3 [[buffer(18)]],
|
|
438
813
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
439
|
-
uint
|
|
440
|
-
uint
|
|
441
|
-
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,
|
|
814
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
815
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
816
|
+
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
442
817
|
}
|
|
443
818
|
|
|
444
|
-
kernel void
|
|
819
|
+
kernel void kernel_mul_mv_q4_1_f32(
|
|
445
820
|
device const void * src0,
|
|
446
821
|
device const float * src1,
|
|
447
822
|
device float * dst,
|
|
@@ -450,18 +825,58 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|
|
450
825
|
constant int64_t & ne02[[buffer(5)]],
|
|
451
826
|
constant int64_t & ne10[[buffer(9)]],
|
|
452
827
|
constant int64_t & ne12[[buffer(11)]],
|
|
453
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
454
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
455
|
-
constant uint &
|
|
828
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
829
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
830
|
+
constant uint & r2 [[buffer(17)]],
|
|
831
|
+
constant uint & r3 [[buffer(18)]],
|
|
456
832
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
457
833
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
458
834
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
459
|
-
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,
|
|
835
|
+
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
460
836
|
}
|
|
461
837
|
|
|
838
|
+
kernel void kernel_mul_mv_q5_0_f32(
|
|
839
|
+
device const void * src0,
|
|
840
|
+
device const float * src1,
|
|
841
|
+
device float * dst,
|
|
842
|
+
constant int64_t & ne00,
|
|
843
|
+
constant int64_t & ne01[[buffer(4)]],
|
|
844
|
+
constant int64_t & ne02[[buffer(5)]],
|
|
845
|
+
constant int64_t & ne10[[buffer(9)]],
|
|
846
|
+
constant int64_t & ne12[[buffer(11)]],
|
|
847
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
848
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
849
|
+
constant uint & r2 [[buffer(17)]],
|
|
850
|
+
constant uint & r3 [[buffer(18)]],
|
|
851
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
852
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
853
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
854
|
+
mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
855
|
+
}
|
|
856
|
+
|
|
857
|
+
kernel void kernel_mul_mv_q5_1_f32(
|
|
858
|
+
device const void * src0,
|
|
859
|
+
device const float * src1,
|
|
860
|
+
device float * dst,
|
|
861
|
+
constant int64_t & ne00,
|
|
862
|
+
constant int64_t & ne01[[buffer(4)]],
|
|
863
|
+
constant int64_t & ne02[[buffer(5)]],
|
|
864
|
+
constant int64_t & ne10[[buffer(9)]],
|
|
865
|
+
constant int64_t & ne12[[buffer(11)]],
|
|
866
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
867
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
868
|
+
constant uint & r2 [[buffer(17)]],
|
|
869
|
+
constant uint & r3 [[buffer(18)]],
|
|
870
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
871
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
872
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
873
|
+
mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
874
|
+
}
|
|
875
|
+
|
|
876
|
+
|
|
462
877
|
#define NB_Q8_0 8
|
|
463
878
|
|
|
464
|
-
kernel void
|
|
879
|
+
kernel void kernel_mul_mv_q8_0_f32(
|
|
465
880
|
device const void * src0,
|
|
466
881
|
device const float * src1,
|
|
467
882
|
device float * dst,
|
|
@@ -470,9 +885,10 @@ kernel void kernel_mul_mat_q8_0_f32(
|
|
|
470
885
|
constant int64_t & ne02[[buffer(5)]],
|
|
471
886
|
constant int64_t & ne10[[buffer(9)]],
|
|
472
887
|
constant int64_t & ne12[[buffer(11)]],
|
|
473
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
474
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
475
|
-
constant uint &
|
|
888
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
889
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
890
|
+
constant uint & r2 [[buffer(17)]],
|
|
891
|
+
constant uint & r3 [[buffer(18)]],
|
|
476
892
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
477
893
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
478
894
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -484,8 +900,14 @@ kernel void kernel_mul_mat_q8_0_f32(
|
|
|
484
900
|
const int r0 = tgpig.x;
|
|
485
901
|
const int r1 = tgpig.y;
|
|
486
902
|
const int im = tgpig.z;
|
|
903
|
+
|
|
487
904
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
|
488
|
-
|
|
905
|
+
|
|
906
|
+
const uint i12 = im%ne12;
|
|
907
|
+
const uint i13 = im/ne12;
|
|
908
|
+
|
|
909
|
+
const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
910
|
+
|
|
489
911
|
device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
|
|
490
912
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
491
913
|
|
|
@@ -525,7 +947,7 @@ kernel void kernel_mul_mat_q8_0_f32(
|
|
|
525
947
|
|
|
526
948
|
#define N_F32_F32 4
|
|
527
949
|
|
|
528
|
-
kernel void
|
|
950
|
+
kernel void kernel_mul_mv_f32_f32(
|
|
529
951
|
device const char * src0,
|
|
530
952
|
device const char * src1,
|
|
531
953
|
device float * dst,
|
|
@@ -543,14 +965,21 @@ kernel void kernel_mul_mat_f32_f32(
|
|
|
543
965
|
constant uint64_t & nb12,
|
|
544
966
|
constant int64_t & ne0,
|
|
545
967
|
constant int64_t & ne1,
|
|
968
|
+
constant uint & r2 [[buffer(17)]],
|
|
969
|
+
constant uint & r3 [[buffer(18)]],
|
|
546
970
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
547
|
-
uint
|
|
971
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
548
972
|
|
|
549
973
|
const int64_t r0 = tgpig.x;
|
|
550
974
|
const int64_t rb = tgpig.y*N_F32_F32;
|
|
551
975
|
const int64_t im = tgpig.z;
|
|
552
976
|
|
|
553
|
-
|
|
977
|
+
const uint i12 = im%ne12;
|
|
978
|
+
const uint i13 = im/ne12;
|
|
979
|
+
|
|
980
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
981
|
+
|
|
982
|
+
device const float * x = (device const float *) (src0 + offset0);
|
|
554
983
|
|
|
555
984
|
if (ne00 < 128) {
|
|
556
985
|
for (int row = 0; row < N_F32_F32; ++row) {
|
|
@@ -596,7 +1025,9 @@ kernel void kernel_mul_mat_f32_f32(
|
|
|
596
1025
|
}
|
|
597
1026
|
}
|
|
598
1027
|
|
|
599
|
-
|
|
1028
|
+
#define N_F16_F16 4
|
|
1029
|
+
|
|
1030
|
+
kernel void kernel_mul_mv_f16_f16(
|
|
600
1031
|
device const char * src0,
|
|
601
1032
|
device const char * src1,
|
|
602
1033
|
device float * dst,
|
|
@@ -614,14 +1045,99 @@ kernel void kernel_mul_mat_f16_f32_1row(
|
|
|
614
1045
|
constant uint64_t & nb12,
|
|
615
1046
|
constant int64_t & ne0,
|
|
616
1047
|
constant int64_t & ne1,
|
|
1048
|
+
constant uint & r2 [[buffer(17)]],
|
|
1049
|
+
constant uint & r3 [[buffer(18)]],
|
|
617
1050
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
618
|
-
uint
|
|
1051
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1052
|
+
|
|
1053
|
+
const int64_t r0 = tgpig.x;
|
|
1054
|
+
const int64_t rb = tgpig.y*N_F16_F16;
|
|
1055
|
+
const int64_t im = tgpig.z;
|
|
1056
|
+
|
|
1057
|
+
const uint i12 = im%ne12;
|
|
1058
|
+
const uint i13 = im/ne12;
|
|
1059
|
+
|
|
1060
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
1061
|
+
|
|
1062
|
+
device const half * x = (device const half *) (src0 + offset0);
|
|
1063
|
+
|
|
1064
|
+
if (ne00 < 128) {
|
|
1065
|
+
for (int row = 0; row < N_F16_F16; ++row) {
|
|
1066
|
+
int r1 = rb + row;
|
|
1067
|
+
if (r1 >= ne11) {
|
|
1068
|
+
break;
|
|
1069
|
+
}
|
|
1070
|
+
|
|
1071
|
+
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
|
|
1072
|
+
|
|
1073
|
+
float sumf = 0;
|
|
1074
|
+
for (int i = tiisg; i < ne00; i += 32) {
|
|
1075
|
+
sumf += (half) x[i] * (half) y[i];
|
|
1076
|
+
}
|
|
1077
|
+
|
|
1078
|
+
float all_sum = simd_sum(sumf);
|
|
1079
|
+
if (tiisg == 0) {
|
|
1080
|
+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
1081
|
+
}
|
|
1082
|
+
}
|
|
1083
|
+
} else {
|
|
1084
|
+
device const half4 * x4 = (device const half4 *)x;
|
|
1085
|
+
for (int row = 0; row < N_F16_F16; ++row) {
|
|
1086
|
+
int r1 = rb + row;
|
|
1087
|
+
if (r1 >= ne11) {
|
|
1088
|
+
break;
|
|
1089
|
+
}
|
|
1090
|
+
|
|
1091
|
+
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
|
|
1092
|
+
device const half4 * y4 = (device const half4 *) y;
|
|
1093
|
+
|
|
1094
|
+
float sumf = 0;
|
|
1095
|
+
for (int i = tiisg; i < ne00/4; i += 32) {
|
|
1096
|
+
for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
|
|
1097
|
+
}
|
|
1098
|
+
|
|
1099
|
+
float all_sum = simd_sum(sumf);
|
|
1100
|
+
if (tiisg == 0) {
|
|
1101
|
+
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
|
|
1102
|
+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
1103
|
+
}
|
|
1104
|
+
}
|
|
1105
|
+
}
|
|
1106
|
+
}
|
|
1107
|
+
|
|
1108
|
+
kernel void kernel_mul_mv_f16_f32_1row(
|
|
1109
|
+
device const char * src0,
|
|
1110
|
+
device const char * src1,
|
|
1111
|
+
device float * dst,
|
|
1112
|
+
constant int64_t & ne00,
|
|
1113
|
+
constant int64_t & ne01,
|
|
1114
|
+
constant int64_t & ne02,
|
|
1115
|
+
constant uint64_t & nb00,
|
|
1116
|
+
constant uint64_t & nb01,
|
|
1117
|
+
constant uint64_t & nb02,
|
|
1118
|
+
constant int64_t & ne10,
|
|
1119
|
+
constant int64_t & ne11,
|
|
1120
|
+
constant int64_t & ne12,
|
|
1121
|
+
constant uint64_t & nb10,
|
|
1122
|
+
constant uint64_t & nb11,
|
|
1123
|
+
constant uint64_t & nb12,
|
|
1124
|
+
constant int64_t & ne0,
|
|
1125
|
+
constant int64_t & ne1,
|
|
1126
|
+
constant uint & r2 [[buffer(17)]],
|
|
1127
|
+
constant uint & r3 [[buffer(18)]],
|
|
1128
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1129
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
619
1130
|
|
|
620
1131
|
const int64_t r0 = tgpig.x;
|
|
621
1132
|
const int64_t r1 = tgpig.y;
|
|
622
1133
|
const int64_t im = tgpig.z;
|
|
623
1134
|
|
|
624
|
-
|
|
1135
|
+
const uint i12 = im%ne12;
|
|
1136
|
+
const uint i13 = im/ne12;
|
|
1137
|
+
|
|
1138
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
1139
|
+
|
|
1140
|
+
device const half * x = (device const half *) (src0 + offset0);
|
|
625
1141
|
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
|
626
1142
|
|
|
627
1143
|
float sumf = 0;
|
|
@@ -650,7 +1166,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
|
|
|
650
1166
|
|
|
651
1167
|
#define N_F16_F32 4
|
|
652
1168
|
|
|
653
|
-
kernel void
|
|
1169
|
+
kernel void kernel_mul_mv_f16_f32(
|
|
654
1170
|
device const char * src0,
|
|
655
1171
|
device const char * src1,
|
|
656
1172
|
device float * dst,
|
|
@@ -668,6 +1184,8 @@ kernel void kernel_mul_mat_f16_f32(
|
|
|
668
1184
|
constant uint64_t & nb12,
|
|
669
1185
|
constant int64_t & ne0,
|
|
670
1186
|
constant int64_t & ne1,
|
|
1187
|
+
constant uint & r2 [[buffer(17)]],
|
|
1188
|
+
constant uint & r3 [[buffer(18)]],
|
|
671
1189
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
672
1190
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
673
1191
|
|
|
@@ -675,7 +1193,12 @@ kernel void kernel_mul_mat_f16_f32(
|
|
|
675
1193
|
const int64_t rb = tgpig.y*N_F16_F32;
|
|
676
1194
|
const int64_t im = tgpig.z;
|
|
677
1195
|
|
|
678
|
-
|
|
1196
|
+
const uint i12 = im%ne12;
|
|
1197
|
+
const uint i13 = im/ne12;
|
|
1198
|
+
|
|
1199
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
1200
|
+
|
|
1201
|
+
device const half * x = (device const half *) (src0 + offset0);
|
|
679
1202
|
|
|
680
1203
|
if (ne00 < 128) {
|
|
681
1204
|
for (int row = 0; row < N_F16_F32; ++row) {
|
|
@@ -722,7 +1245,7 @@ kernel void kernel_mul_mat_f16_f32(
|
|
|
722
1245
|
}
|
|
723
1246
|
|
|
724
1247
|
// Assumes row size (ne00) is a multiple of 4
|
|
725
|
-
kernel void
|
|
1248
|
+
kernel void kernel_mul_mv_f16_f32_l4(
|
|
726
1249
|
device const char * src0,
|
|
727
1250
|
device const char * src1,
|
|
728
1251
|
device float * dst,
|
|
@@ -740,33 +1263,387 @@ kernel void kernel_mul_mat_f16_f32_l4(
|
|
|
740
1263
|
constant uint64_t & nb12,
|
|
741
1264
|
constant int64_t & ne0,
|
|
742
1265
|
constant int64_t & ne1,
|
|
1266
|
+
constant uint & r2 [[buffer(17)]],
|
|
1267
|
+
constant uint & r3 [[buffer(18)]],
|
|
1268
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1269
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1270
|
+
|
|
1271
|
+
const int nrows = ne11;
|
|
1272
|
+
const int64_t r0 = tgpig.x;
|
|
1273
|
+
const int64_t im = tgpig.z;
|
|
1274
|
+
|
|
1275
|
+
const uint i12 = im%ne12;
|
|
1276
|
+
const uint i13 = im/ne12;
|
|
1277
|
+
|
|
1278
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
1279
|
+
|
|
1280
|
+
device const half4 * x4 = (device const half4 *) (src0 + offset0);
|
|
1281
|
+
|
|
1282
|
+
for (int r1 = 0; r1 < nrows; ++r1) {
|
|
1283
|
+
device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
|
|
1284
|
+
|
|
1285
|
+
float sumf = 0;
|
|
1286
|
+
for (int i = tiisg; i < ne00/4; i += 32) {
|
|
1287
|
+
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
|
1288
|
+
}
|
|
1289
|
+
|
|
1290
|
+
float all_sum = simd_sum(sumf);
|
|
1291
|
+
if (tiisg == 0) {
|
|
1292
|
+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
1293
|
+
}
|
|
1294
|
+
}
|
|
1295
|
+
}
|
|
1296
|
+
|
|
1297
|
+
kernel void kernel_alibi_f32(
|
|
1298
|
+
device const float * src0,
|
|
1299
|
+
device float * dst,
|
|
1300
|
+
constant int64_t & ne00,
|
|
1301
|
+
constant int64_t & ne01,
|
|
1302
|
+
constant int64_t & ne02,
|
|
1303
|
+
constant int64_t & ne03,
|
|
1304
|
+
constant uint64_t & nb00,
|
|
1305
|
+
constant uint64_t & nb01,
|
|
1306
|
+
constant uint64_t & nb02,
|
|
1307
|
+
constant uint64_t & nb03,
|
|
1308
|
+
constant int64_t & ne0,
|
|
1309
|
+
constant int64_t & ne1,
|
|
1310
|
+
constant int64_t & ne2,
|
|
1311
|
+
constant int64_t & ne3,
|
|
1312
|
+
constant uint64_t & nb0,
|
|
1313
|
+
constant uint64_t & nb1,
|
|
1314
|
+
constant uint64_t & nb2,
|
|
1315
|
+
constant uint64_t & nb3,
|
|
1316
|
+
constant float & m0,
|
|
1317
|
+
constant float & m1,
|
|
1318
|
+
constant int & n_heads_log2_floor,
|
|
1319
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1320
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1321
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
1322
|
+
const int64_t i03 = tgpig[2];
|
|
1323
|
+
const int64_t i02 = tgpig[1];
|
|
1324
|
+
const int64_t i01 = tgpig[0];
|
|
1325
|
+
|
|
1326
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
1327
|
+
|
|
1328
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
1329
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
1330
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
1331
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
1332
|
+
const int64_t k = i3*ne3 + i2;
|
|
1333
|
+
|
|
1334
|
+
float m_k;
|
|
1335
|
+
if (k < n_heads_log2_floor) {
|
|
1336
|
+
m_k = pow(m0, k + 1);
|
|
1337
|
+
} else {
|
|
1338
|
+
m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
|
|
1339
|
+
}
|
|
1340
|
+
|
|
1341
|
+
device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
|
|
1342
|
+
device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
|
1343
|
+
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
1344
|
+
const float src_v = *(device float *)(src_row + i00*nb00);
|
|
1345
|
+
device float * dst_v = (device float *)(dst_row + i00*nb0);
|
|
1346
|
+
*dst_v = i00 * m_k + src_v;
|
|
1347
|
+
}
|
|
1348
|
+
}
|
|
1349
|
+
|
|
1350
|
+
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
|
1351
|
+
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
|
1352
|
+
return 1.0f - min(1.0f, max(0.0f, y));
|
|
1353
|
+
}
|
|
1354
|
+
|
|
1355
|
+
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
|
1356
|
+
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
|
1357
|
+
static void rope_yarn(
|
|
1358
|
+
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
|
1359
|
+
thread float * cos_theta, thread float * sin_theta
|
|
1360
|
+
) {
|
|
1361
|
+
// Get n-d rotational scaling corrected for extrapolation
|
|
1362
|
+
float theta_interp = freq_scale * theta_extrap;
|
|
1363
|
+
float theta = theta_interp;
|
|
1364
|
+
if (ext_factor != 0.0f) {
|
|
1365
|
+
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
|
1366
|
+
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
|
1367
|
+
|
|
1368
|
+
// Get n-d magnitude scaling corrected for interpolation
|
|
1369
|
+
mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
|
|
1370
|
+
}
|
|
1371
|
+
*cos_theta = cos(theta) * mscale;
|
|
1372
|
+
*sin_theta = sin(theta) * mscale;
|
|
1373
|
+
}
|
|
1374
|
+
|
|
1375
|
+
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
|
1376
|
+
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
|
1377
|
+
static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
|
|
1378
|
+
return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
|
1379
|
+
}
|
|
1380
|
+
|
|
1381
|
+
static void rope_yarn_corr_dims(
|
|
1382
|
+
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
|
1383
|
+
) {
|
|
1384
|
+
// start and end correction dims
|
|
1385
|
+
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
|
|
1386
|
+
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
|
|
1387
|
+
}
|
|
1388
|
+
|
|
1389
|
+
typedef void (rope_t)(
|
|
1390
|
+
device const void * src0,
|
|
1391
|
+
device const int32_t * src1,
|
|
1392
|
+
device float * dst,
|
|
1393
|
+
constant int64_t & ne00,
|
|
1394
|
+
constant int64_t & ne01,
|
|
1395
|
+
constant int64_t & ne02,
|
|
1396
|
+
constant int64_t & ne03,
|
|
1397
|
+
constant uint64_t & nb00,
|
|
1398
|
+
constant uint64_t & nb01,
|
|
1399
|
+
constant uint64_t & nb02,
|
|
1400
|
+
constant uint64_t & nb03,
|
|
1401
|
+
constant int64_t & ne0,
|
|
1402
|
+
constant int64_t & ne1,
|
|
1403
|
+
constant int64_t & ne2,
|
|
1404
|
+
constant int64_t & ne3,
|
|
1405
|
+
constant uint64_t & nb0,
|
|
1406
|
+
constant uint64_t & nb1,
|
|
1407
|
+
constant uint64_t & nb2,
|
|
1408
|
+
constant uint64_t & nb3,
|
|
1409
|
+
constant int & n_past,
|
|
1410
|
+
constant int & n_dims,
|
|
1411
|
+
constant int & mode,
|
|
1412
|
+
constant int & n_orig_ctx,
|
|
1413
|
+
constant float & freq_base,
|
|
1414
|
+
constant float & freq_scale,
|
|
1415
|
+
constant float & ext_factor,
|
|
1416
|
+
constant float & attn_factor,
|
|
1417
|
+
constant float & beta_fast,
|
|
1418
|
+
constant float & beta_slow,
|
|
1419
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
1420
|
+
uint3 tptg[[threads_per_threadgroup]],
|
|
1421
|
+
uint3 tgpig[[threadgroup_position_in_grid]]);
|
|
1422
|
+
|
|
1423
|
+
template<typename T>
|
|
1424
|
+
kernel void kernel_rope(
|
|
1425
|
+
device const void * src0,
|
|
1426
|
+
device const int32_t * src1,
|
|
1427
|
+
device float * dst,
|
|
1428
|
+
constant int64_t & ne00,
|
|
1429
|
+
constant int64_t & ne01,
|
|
1430
|
+
constant int64_t & ne02,
|
|
1431
|
+
constant int64_t & ne03,
|
|
1432
|
+
constant uint64_t & nb00,
|
|
1433
|
+
constant uint64_t & nb01,
|
|
1434
|
+
constant uint64_t & nb02,
|
|
1435
|
+
constant uint64_t & nb03,
|
|
1436
|
+
constant int64_t & ne0,
|
|
1437
|
+
constant int64_t & ne1,
|
|
1438
|
+
constant int64_t & ne2,
|
|
1439
|
+
constant int64_t & ne3,
|
|
1440
|
+
constant uint64_t & nb0,
|
|
1441
|
+
constant uint64_t & nb1,
|
|
1442
|
+
constant uint64_t & nb2,
|
|
1443
|
+
constant uint64_t & nb3,
|
|
1444
|
+
constant int & n_past,
|
|
1445
|
+
constant int & n_dims,
|
|
1446
|
+
constant int & mode,
|
|
1447
|
+
constant int & n_orig_ctx,
|
|
1448
|
+
constant float & freq_base,
|
|
1449
|
+
constant float & freq_scale,
|
|
1450
|
+
constant float & ext_factor,
|
|
1451
|
+
constant float & attn_factor,
|
|
1452
|
+
constant float & beta_fast,
|
|
1453
|
+
constant float & beta_slow,
|
|
1454
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
1455
|
+
uint3 tptg[[threads_per_threadgroup]],
|
|
1456
|
+
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
|
1457
|
+
const int64_t i3 = tgpig[2];
|
|
1458
|
+
const int64_t i2 = tgpig[1];
|
|
1459
|
+
const int64_t i1 = tgpig[0];
|
|
1460
|
+
|
|
1461
|
+
const bool is_neox = mode & 2;
|
|
1462
|
+
|
|
1463
|
+
float corr_dims[2];
|
|
1464
|
+
rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
|
1465
|
+
|
|
1466
|
+
device const int32_t * pos = src1;
|
|
1467
|
+
|
|
1468
|
+
const int64_t p = pos[i2];
|
|
1469
|
+
|
|
1470
|
+
const float theta_0 = (float)p;
|
|
1471
|
+
const float inv_ndims = -1.f/n_dims;
|
|
1472
|
+
|
|
1473
|
+
if (!is_neox) {
|
|
1474
|
+
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
|
1475
|
+
|
|
1476
|
+
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
|
|
1477
|
+
float cos_theta, sin_theta;
|
|
1478
|
+
rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
1479
|
+
|
|
1480
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
1481
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1482
|
+
|
|
1483
|
+
const T x0 = src[0];
|
|
1484
|
+
const T x1 = src[1];
|
|
1485
|
+
|
|
1486
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
1487
|
+
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
|
1488
|
+
}
|
|
1489
|
+
} else {
|
|
1490
|
+
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
|
1491
|
+
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
|
|
1492
|
+
|
|
1493
|
+
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
|
1494
|
+
const float cur_rot = inv_ndims*ic - ib;
|
|
1495
|
+
|
|
1496
|
+
const float theta = theta_0 * pow(freq_base, cur_rot);
|
|
1497
|
+
float cos_theta, sin_theta;
|
|
1498
|
+
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
1499
|
+
|
|
1500
|
+
const int64_t i0 = ib*n_dims + ic/2;
|
|
1501
|
+
|
|
1502
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
1503
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1504
|
+
|
|
1505
|
+
const float x0 = src[0];
|
|
1506
|
+
const float x1 = src[n_dims/2];
|
|
1507
|
+
|
|
1508
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
1509
|
+
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
|
1510
|
+
}
|
|
1511
|
+
}
|
|
1512
|
+
}
|
|
1513
|
+
}
|
|
1514
|
+
|
|
1515
|
+
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
|
|
1516
|
+
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
|
|
1517
|
+
|
|
1518
|
+
kernel void kernel_im2col_f16(
|
|
1519
|
+
device const float * x,
|
|
1520
|
+
device half * dst,
|
|
1521
|
+
constant int32_t & ofs0,
|
|
1522
|
+
constant int32_t & ofs1,
|
|
1523
|
+
constant int32_t & IW,
|
|
1524
|
+
constant int32_t & IH,
|
|
1525
|
+
constant int32_t & CHW,
|
|
1526
|
+
constant int32_t & s0,
|
|
1527
|
+
constant int32_t & s1,
|
|
1528
|
+
constant int32_t & p0,
|
|
1529
|
+
constant int32_t & p1,
|
|
1530
|
+
constant int32_t & d0,
|
|
1531
|
+
constant int32_t & d1,
|
|
1532
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1533
|
+
uint3 tgpg[[threadgroups_per_grid]],
|
|
1534
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1535
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
1536
|
+
const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
|
|
1537
|
+
const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
|
|
1538
|
+
|
|
1539
|
+
const int32_t offset_dst =
|
|
1540
|
+
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
|
|
1541
|
+
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
|
|
1542
|
+
|
|
1543
|
+
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
|
1544
|
+
dst[offset_dst] = 0.0f;
|
|
1545
|
+
} else {
|
|
1546
|
+
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
|
|
1547
|
+
dst[offset_dst] = x[offset_src + iih * IW + iiw];
|
|
1548
|
+
}
|
|
1549
|
+
}
|
|
1550
|
+
|
|
1551
|
+
// bitonic sort implementation following the CUDA kernels as reference
|
|
1552
|
+
typedef void (argsort_t)(
|
|
1553
|
+
device const float * x,
|
|
1554
|
+
device int32_t * dst,
|
|
1555
|
+
constant int64_t & ncols,
|
|
1556
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1557
|
+
uint3 tpitg[[thread_position_in_threadgroup]]);
|
|
1558
|
+
|
|
1559
|
+
template<ggml_sort_order order>
|
|
1560
|
+
kernel void kernel_argsort_f32_i32(
|
|
1561
|
+
device const float * x,
|
|
1562
|
+
device int32_t * dst,
|
|
1563
|
+
constant int64_t & ncols,
|
|
1564
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1565
|
+
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
|
1566
|
+
// bitonic sort
|
|
1567
|
+
int col = tpitg[0];
|
|
1568
|
+
int row = tgpig[1];
|
|
1569
|
+
|
|
1570
|
+
if (col >= ncols) return;
|
|
1571
|
+
|
|
1572
|
+
device const float * x_row = x + row * ncols;
|
|
1573
|
+
device int32_t * dst_row = dst + row * ncols;
|
|
1574
|
+
|
|
1575
|
+
// initialize indices
|
|
1576
|
+
if (col < ncols) {
|
|
1577
|
+
dst_row[col] = col;
|
|
1578
|
+
}
|
|
1579
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1580
|
+
|
|
1581
|
+
for (int k = 2; k <= ncols; k *= 2) {
|
|
1582
|
+
for (int j = k / 2; j > 0; j /= 2) {
|
|
1583
|
+
int ixj = col ^ j;
|
|
1584
|
+
if (ixj > col) {
|
|
1585
|
+
if ((col & k) == 0) {
|
|
1586
|
+
if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
|
|
1587
|
+
SWAP(dst_row[col], dst_row[ixj]);
|
|
1588
|
+
}
|
|
1589
|
+
} else {
|
|
1590
|
+
if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
|
|
1591
|
+
SWAP(dst_row[col], dst_row[ixj]);
|
|
1592
|
+
}
|
|
1593
|
+
}
|
|
1594
|
+
}
|
|
1595
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1596
|
+
}
|
|
1597
|
+
}
|
|
1598
|
+
}
|
|
1599
|
+
|
|
1600
|
+
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
|
|
1601
|
+
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
|
|
1602
|
+
|
|
1603
|
+
kernel void kernel_cpy_f16_f16(
|
|
1604
|
+
device const half * src0,
|
|
1605
|
+
device half * dst,
|
|
1606
|
+
constant int64_t & ne00,
|
|
1607
|
+
constant int64_t & ne01,
|
|
1608
|
+
constant int64_t & ne02,
|
|
1609
|
+
constant int64_t & ne03,
|
|
1610
|
+
constant uint64_t & nb00,
|
|
1611
|
+
constant uint64_t & nb01,
|
|
1612
|
+
constant uint64_t & nb02,
|
|
1613
|
+
constant uint64_t & nb03,
|
|
1614
|
+
constant int64_t & ne0,
|
|
1615
|
+
constant int64_t & ne1,
|
|
1616
|
+
constant int64_t & ne2,
|
|
1617
|
+
constant int64_t & ne3,
|
|
1618
|
+
constant uint64_t & nb0,
|
|
1619
|
+
constant uint64_t & nb1,
|
|
1620
|
+
constant uint64_t & nb2,
|
|
1621
|
+
constant uint64_t & nb3,
|
|
743
1622
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
const
|
|
747
|
-
const int64_t
|
|
748
|
-
const int64_t
|
|
1623
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1624
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
1625
|
+
const int64_t i03 = tgpig[2];
|
|
1626
|
+
const int64_t i02 = tgpig[1];
|
|
1627
|
+
const int64_t i01 = tgpig[0];
|
|
749
1628
|
|
|
750
|
-
|
|
1629
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
751
1630
|
|
|
752
|
-
|
|
753
|
-
|
|
1631
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
1632
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
1633
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
1634
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
754
1635
|
|
|
755
|
-
|
|
756
|
-
for (int i = tiisg; i < ne00/4; i += 32) {
|
|
757
|
-
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
|
758
|
-
}
|
|
1636
|
+
device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
759
1637
|
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
}
|
|
1638
|
+
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
1639
|
+
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
1640
|
+
dst_data[i00] = src[0];
|
|
764
1641
|
}
|
|
765
1642
|
}
|
|
766
1643
|
|
|
767
|
-
kernel void
|
|
1644
|
+
kernel void kernel_cpy_f32_f16(
|
|
768
1645
|
device const float * src0,
|
|
769
|
-
device
|
|
1646
|
+
device half * dst,
|
|
770
1647
|
constant int64_t & ne00,
|
|
771
1648
|
constant int64_t & ne01,
|
|
772
1649
|
constant int64_t & ne02,
|
|
@@ -783,7 +1660,6 @@ kernel void kernel_alibi_f32(
|
|
|
783
1660
|
constant uint64_t & nb1,
|
|
784
1661
|
constant uint64_t & nb2,
|
|
785
1662
|
constant uint64_t & nb3,
|
|
786
|
-
constant float & m0,
|
|
787
1663
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
788
1664
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
789
1665
|
uint3 ntg[[threads_per_threadgroup]]) {
|
|
@@ -798,16 +1674,17 @@ kernel void kernel_alibi_f32(
|
|
|
798
1674
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
799
1675
|
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
800
1676
|
|
|
801
|
-
device
|
|
802
|
-
|
|
1677
|
+
device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1678
|
+
|
|
803
1679
|
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
804
1680
|
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
805
|
-
|
|
1681
|
+
|
|
1682
|
+
dst_data[i00] = src[0];
|
|
806
1683
|
}
|
|
807
1684
|
}
|
|
808
1685
|
|
|
809
|
-
kernel void
|
|
810
|
-
device const
|
|
1686
|
+
kernel void kernel_cpy_f32_f32(
|
|
1687
|
+
device const float * src0,
|
|
811
1688
|
device float * dst,
|
|
812
1689
|
constant int64_t & ne00,
|
|
813
1690
|
constant int64_t & ne01,
|
|
@@ -825,67 +1702,32 @@ kernel void kernel_rope(
|
|
|
825
1702
|
constant uint64_t & nb1,
|
|
826
1703
|
constant uint64_t & nb2,
|
|
827
1704
|
constant uint64_t & nb3,
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
uint3 tptg[[threads_per_threadgroup]],
|
|
835
|
-
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
|
836
|
-
const int64_t i3 = tgpig[2];
|
|
837
|
-
const int64_t i2 = tgpig[1];
|
|
838
|
-
const int64_t i1 = tgpig[0];
|
|
839
|
-
|
|
840
|
-
const bool is_neox = mode & 2;
|
|
841
|
-
|
|
842
|
-
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
|
843
|
-
|
|
844
|
-
const float theta_0 = freq_scale * (float)p;
|
|
845
|
-
const float inv_ndims = -1.f/n_dims;
|
|
846
|
-
|
|
847
|
-
if (!is_neox) {
|
|
848
|
-
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
|
849
|
-
|
|
850
|
-
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
|
|
851
|
-
const float cos_theta = cos(theta);
|
|
852
|
-
const float sin_theta = sin(theta);
|
|
853
|
-
|
|
854
|
-
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
855
|
-
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
856
|
-
|
|
857
|
-
const float x0 = src[0];
|
|
858
|
-
const float x1 = src[1];
|
|
859
|
-
|
|
860
|
-
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
861
|
-
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
|
862
|
-
}
|
|
863
|
-
} else {
|
|
864
|
-
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
|
865
|
-
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
|
|
1705
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1706
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1707
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
1708
|
+
const int64_t i03 = tgpig[2];
|
|
1709
|
+
const int64_t i02 = tgpig[1];
|
|
1710
|
+
const int64_t i01 = tgpig[0];
|
|
866
1711
|
|
|
867
|
-
|
|
868
|
-
const float cos_theta = cos(theta);
|
|
869
|
-
const float sin_theta = sin(theta);
|
|
1712
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
870
1713
|
|
|
871
|
-
|
|
1714
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
1715
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
1716
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
1717
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
872
1718
|
|
|
873
|
-
|
|
874
|
-
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1719
|
+
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
875
1720
|
|
|
876
|
-
|
|
877
|
-
|
|
1721
|
+
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
1722
|
+
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
878
1723
|
|
|
879
|
-
|
|
880
|
-
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
|
881
|
-
}
|
|
882
|
-
}
|
|
1724
|
+
dst_data[i00] = src[0];
|
|
883
1725
|
}
|
|
884
1726
|
}
|
|
885
1727
|
|
|
886
|
-
kernel void
|
|
887
|
-
device const
|
|
888
|
-
device
|
|
1728
|
+
kernel void kernel_cpy_f32_q8_0(
|
|
1729
|
+
device const float * src0,
|
|
1730
|
+
device void * dst,
|
|
889
1731
|
constant int64_t & ne00,
|
|
890
1732
|
constant int64_t & ne01,
|
|
891
1733
|
constant int64_t & ne02,
|
|
@@ -914,19 +1756,36 @@ kernel void kernel_cpy_f16_f16(
|
|
|
914
1756
|
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
915
1757
|
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
916
1758
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
917
|
-
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
1759
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
|
|
918
1760
|
|
|
919
|
-
device
|
|
1761
|
+
device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
920
1762
|
|
|
921
|
-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
922
|
-
device const
|
|
923
|
-
|
|
1763
|
+
for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
|
|
1764
|
+
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
1765
|
+
|
|
1766
|
+
float amax = 0.0f; // absolute max
|
|
1767
|
+
|
|
1768
|
+
for (int j = 0; j < QK8_0; j++) {
|
|
1769
|
+
const float v = src[j];
|
|
1770
|
+
amax = MAX(amax, fabs(v));
|
|
1771
|
+
}
|
|
1772
|
+
|
|
1773
|
+
const float d = amax / ((1 << 7) - 1);
|
|
1774
|
+
const float id = d ? 1.0f/d : 0.0f;
|
|
1775
|
+
|
|
1776
|
+
dst_data[i00/QK8_0].d = d;
|
|
1777
|
+
|
|
1778
|
+
for (int j = 0; j < QK8_0; ++j) {
|
|
1779
|
+
const float x0 = src[j]*id;
|
|
1780
|
+
|
|
1781
|
+
dst_data[i00/QK8_0].qs[j] = round(x0);
|
|
1782
|
+
}
|
|
924
1783
|
}
|
|
925
1784
|
}
|
|
926
1785
|
|
|
927
|
-
kernel void
|
|
1786
|
+
kernel void kernel_cpy_f32_q4_0(
|
|
928
1787
|
device const float * src0,
|
|
929
|
-
device
|
|
1788
|
+
device void * dst,
|
|
930
1789
|
constant int64_t & ne00,
|
|
931
1790
|
constant int64_t & ne01,
|
|
932
1791
|
constant int64_t & ne02,
|
|
@@ -955,20 +1814,45 @@ kernel void kernel_cpy_f32_f16(
|
|
|
955
1814
|
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
956
1815
|
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
957
1816
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
958
|
-
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
1817
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
|
|
959
1818
|
|
|
960
|
-
device
|
|
1819
|
+
device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
961
1820
|
|
|
962
|
-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
1821
|
+
for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
|
|
963
1822
|
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
964
1823
|
|
|
965
|
-
|
|
1824
|
+
float amax = 0.0f; // absolute max
|
|
1825
|
+
float max = 0.0f;
|
|
1826
|
+
|
|
1827
|
+
for (int j = 0; j < QK4_0; j++) {
|
|
1828
|
+
const float v = src[j];
|
|
1829
|
+
if (amax < fabs(v)) {
|
|
1830
|
+
amax = fabs(v);
|
|
1831
|
+
max = v;
|
|
1832
|
+
}
|
|
1833
|
+
}
|
|
1834
|
+
|
|
1835
|
+
const float d = max / -8;
|
|
1836
|
+
const float id = d ? 1.0f/d : 0.0f;
|
|
1837
|
+
|
|
1838
|
+
dst_data[i00/QK4_0].d = d;
|
|
1839
|
+
|
|
1840
|
+
for (int j = 0; j < QK4_0/2; ++j) {
|
|
1841
|
+
const float x0 = src[0 + j]*id;
|
|
1842
|
+
const float x1 = src[QK4_0/2 + j]*id;
|
|
1843
|
+
|
|
1844
|
+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
|
|
1845
|
+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
|
|
1846
|
+
|
|
1847
|
+
dst_data[i00/QK4_0].qs[j] = xi0;
|
|
1848
|
+
dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
|
|
1849
|
+
}
|
|
966
1850
|
}
|
|
967
1851
|
}
|
|
968
1852
|
|
|
969
|
-
kernel void
|
|
1853
|
+
kernel void kernel_cpy_f32_q4_1(
|
|
970
1854
|
device const float * src0,
|
|
971
|
-
device
|
|
1855
|
+
device void * dst,
|
|
972
1856
|
constant int64_t & ne00,
|
|
973
1857
|
constant int64_t & ne01,
|
|
974
1858
|
constant int64_t & ne02,
|
|
@@ -997,14 +1881,94 @@ kernel void kernel_cpy_f32_f32(
|
|
|
997
1881
|
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
998
1882
|
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
999
1883
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
1000
|
-
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
1884
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
|
|
1001
1885
|
|
|
1002
|
-
device
|
|
1886
|
+
device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1003
1887
|
|
|
1004
|
-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
1888
|
+
for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
|
|
1005
1889
|
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
1006
1890
|
|
|
1007
|
-
|
|
1891
|
+
float min = FLT_MAX;
|
|
1892
|
+
float max = -FLT_MAX;
|
|
1893
|
+
|
|
1894
|
+
for (int j = 0; j < QK4_1; j++) {
|
|
1895
|
+
const float v = src[j];
|
|
1896
|
+
if (min > v) min = v;
|
|
1897
|
+
if (max < v) max = v;
|
|
1898
|
+
}
|
|
1899
|
+
|
|
1900
|
+
const float d = (max - min) / ((1 << 4) - 1);
|
|
1901
|
+
const float id = d ? 1.0f/d : 0.0f;
|
|
1902
|
+
|
|
1903
|
+
dst_data[i00/QK4_1].d = d;
|
|
1904
|
+
dst_data[i00/QK4_1].m = min;
|
|
1905
|
+
|
|
1906
|
+
for (int j = 0; j < QK4_1/2; ++j) {
|
|
1907
|
+
const float x0 = (src[0 + j] - min)*id;
|
|
1908
|
+
const float x1 = (src[QK4_1/2 + j] - min)*id;
|
|
1909
|
+
|
|
1910
|
+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
|
|
1911
|
+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
|
|
1912
|
+
|
|
1913
|
+
dst_data[i00/QK4_1].qs[j] = xi0;
|
|
1914
|
+
dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
|
|
1915
|
+
}
|
|
1916
|
+
}
|
|
1917
|
+
}
|
|
1918
|
+
|
|
1919
|
+
kernel void kernel_concat(
|
|
1920
|
+
device const char * src0,
|
|
1921
|
+
device const char * src1,
|
|
1922
|
+
device char * dst,
|
|
1923
|
+
constant int64_t & ne00,
|
|
1924
|
+
constant int64_t & ne01,
|
|
1925
|
+
constant int64_t & ne02,
|
|
1926
|
+
constant int64_t & ne03,
|
|
1927
|
+
constant uint64_t & nb00,
|
|
1928
|
+
constant uint64_t & nb01,
|
|
1929
|
+
constant uint64_t & nb02,
|
|
1930
|
+
constant uint64_t & nb03,
|
|
1931
|
+
constant int64_t & ne10,
|
|
1932
|
+
constant int64_t & ne11,
|
|
1933
|
+
constant int64_t & ne12,
|
|
1934
|
+
constant int64_t & ne13,
|
|
1935
|
+
constant uint64_t & nb10,
|
|
1936
|
+
constant uint64_t & nb11,
|
|
1937
|
+
constant uint64_t & nb12,
|
|
1938
|
+
constant uint64_t & nb13,
|
|
1939
|
+
constant int64_t & ne0,
|
|
1940
|
+
constant int64_t & ne1,
|
|
1941
|
+
constant int64_t & ne2,
|
|
1942
|
+
constant int64_t & ne3,
|
|
1943
|
+
constant uint64_t & nb0,
|
|
1944
|
+
constant uint64_t & nb1,
|
|
1945
|
+
constant uint64_t & nb2,
|
|
1946
|
+
constant uint64_t & nb3,
|
|
1947
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1948
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1949
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
1950
|
+
|
|
1951
|
+
const int64_t i03 = tgpig.z;
|
|
1952
|
+
const int64_t i02 = tgpig.y;
|
|
1953
|
+
const int64_t i01 = tgpig.x;
|
|
1954
|
+
|
|
1955
|
+
const int64_t i13 = i03 % ne13;
|
|
1956
|
+
const int64_t i12 = i02 % ne12;
|
|
1957
|
+
const int64_t i11 = i01 % ne11;
|
|
1958
|
+
|
|
1959
|
+
device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
|
|
1960
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
|
|
1961
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
|
|
1962
|
+
|
|
1963
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
|
1964
|
+
if (i02 < ne02) {
|
|
1965
|
+
((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
|
|
1966
|
+
src0_ptr += ntg.x*nb00;
|
|
1967
|
+
} else {
|
|
1968
|
+
((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
|
|
1969
|
+
src1_ptr += ntg.x*nb10;
|
|
1970
|
+
}
|
|
1971
|
+
dst_ptr += ntg.x*nb0;
|
|
1008
1972
|
}
|
|
1009
1973
|
}
|
|
1010
1974
|
|
|
@@ -1100,7 +2064,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
|
|
1100
2064
|
|
|
1101
2065
|
//====================================== dot products =========================
|
|
1102
2066
|
|
|
1103
|
-
kernel void
|
|
2067
|
+
kernel void kernel_mul_mv_q2_K_f32(
|
|
1104
2068
|
device const void * src0,
|
|
1105
2069
|
device const float * src1,
|
|
1106
2070
|
device float * dst,
|
|
@@ -1109,23 +2073,30 @@ kernel void kernel_mul_mat_q2_K_f32(
|
|
|
1109
2073
|
constant int64_t & ne02[[buffer(5)]],
|
|
1110
2074
|
constant int64_t & ne10[[buffer(9)]],
|
|
1111
2075
|
constant int64_t & ne12[[buffer(11)]],
|
|
1112
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
1113
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
1114
|
-
constant uint &
|
|
2076
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
2077
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
2078
|
+
constant uint & r2 [[buffer(17)]],
|
|
2079
|
+
constant uint & r3 [[buffer(18)]],
|
|
1115
2080
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1116
|
-
uint
|
|
1117
|
-
uint
|
|
2081
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
2082
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
1118
2083
|
|
|
1119
2084
|
const int nb = ne00/QK_K;
|
|
1120
2085
|
const int r0 = tgpig.x;
|
|
1121
2086
|
const int r1 = tgpig.y;
|
|
1122
|
-
const int
|
|
2087
|
+
const int im = tgpig.z;
|
|
1123
2088
|
|
|
1124
2089
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
1125
2090
|
const int ib_row = first_row * nb;
|
|
1126
|
-
|
|
2091
|
+
|
|
2092
|
+
const uint i12 = im%ne12;
|
|
2093
|
+
const uint i13 = im/ne12;
|
|
2094
|
+
|
|
2095
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
2096
|
+
|
|
1127
2097
|
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
|
|
1128
|
-
device const float * y = (device const float *) src1 + r1*ne10 +
|
|
2098
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
2099
|
+
|
|
1129
2100
|
float yl[32];
|
|
1130
2101
|
float sumf[N_DST]={0.f}, all_sum;
|
|
1131
2102
|
|
|
@@ -1134,11 +2105,11 @@ kernel void kernel_mul_mat_q2_K_f32(
|
|
|
1134
2105
|
#if QK_K == 256
|
|
1135
2106
|
const int ix = tiisg/8; // 0...3
|
|
1136
2107
|
const int it = tiisg%8; // 0...7
|
|
1137
|
-
const int
|
|
2108
|
+
const int iq = it/4; // 0 or 1
|
|
1138
2109
|
const int ir = it%4; // 0...3
|
|
1139
2110
|
const int is = (8*ir)/16;// 0 or 1
|
|
1140
2111
|
|
|
1141
|
-
device const float * y4 = y + ix * QK_K + 128 *
|
|
2112
|
+
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
|
|
1142
2113
|
|
|
1143
2114
|
for (int ib = ix; ib < nb; ib += 4) {
|
|
1144
2115
|
|
|
@@ -1150,8 +2121,8 @@ kernel void kernel_mul_mat_q2_K_f32(
|
|
|
1150
2121
|
yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
|
|
1151
2122
|
}
|
|
1152
2123
|
|
|
1153
|
-
device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*
|
|
1154
|
-
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 *
|
|
2124
|
+
device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
|
|
2125
|
+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
|
|
1155
2126
|
device const half * dh = &x[ib].d;
|
|
1156
2127
|
|
|
1157
2128
|
for (int row = 0; row < N_DST; row++) {
|
|
@@ -1238,13 +2209,13 @@ kernel void kernel_mul_mat_q2_K_f32(
|
|
|
1238
2209
|
for (int row = 0; row < N_DST; ++row) {
|
|
1239
2210
|
all_sum = simd_sum(sumf[row]);
|
|
1240
2211
|
if (tiisg == 0) {
|
|
1241
|
-
dst[r1*ne0 +
|
|
2212
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
|
1242
2213
|
}
|
|
1243
2214
|
}
|
|
1244
2215
|
}
|
|
1245
2216
|
|
|
1246
2217
|
#if QK_K == 256
|
|
1247
|
-
kernel void
|
|
2218
|
+
kernel void kernel_mul_mv_q3_K_f32(
|
|
1248
2219
|
device const void * src0,
|
|
1249
2220
|
device const float * src1,
|
|
1250
2221
|
device float * dst,
|
|
@@ -1253,9 +2224,10 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
|
1253
2224
|
constant int64_t & ne02[[buffer(5)]],
|
|
1254
2225
|
constant int64_t & ne10[[buffer(9)]],
|
|
1255
2226
|
constant int64_t & ne12[[buffer(11)]],
|
|
1256
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
1257
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
1258
|
-
constant uint &
|
|
2227
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
2228
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
2229
|
+
constant uint & r2 [[buffer(17)]],
|
|
2230
|
+
constant uint & r3 [[buffer(18)]],
|
|
1259
2231
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1260
2232
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
1261
2233
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -1264,17 +2236,22 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
|
1264
2236
|
|
|
1265
2237
|
const int64_t r0 = tgpig.x;
|
|
1266
2238
|
const int64_t r1 = tgpig.y;
|
|
1267
|
-
const int64_t
|
|
2239
|
+
const int64_t im = tgpig.z;
|
|
1268
2240
|
|
|
1269
2241
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
|
1270
|
-
|
|
2242
|
+
|
|
2243
|
+
const uint i12 = im%ne12;
|
|
2244
|
+
const uint i13 = im/ne12;
|
|
2245
|
+
|
|
2246
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
2247
|
+
|
|
1271
2248
|
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
|
|
1272
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
|
2249
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
1273
2250
|
|
|
1274
2251
|
float yl[32];
|
|
1275
2252
|
|
|
1276
|
-
const uint16_t kmask1 = 0x3030;
|
|
1277
|
-
const uint16_t kmask2 = 0x0f0f;
|
|
2253
|
+
//const uint16_t kmask1 = 0x3030;
|
|
2254
|
+
//const uint16_t kmask2 = 0x0f0f;
|
|
1278
2255
|
|
|
1279
2256
|
const int tid = tiisg/4;
|
|
1280
2257
|
const int ix = tiisg%4;
|
|
@@ -1391,12 +2368,12 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
|
1391
2368
|
}
|
|
1392
2369
|
if (tiisg == 0) {
|
|
1393
2370
|
for (int row = 0; row < 2; ++row) {
|
|
1394
|
-
dst[r1*ne0 +
|
|
2371
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
|
|
1395
2372
|
}
|
|
1396
2373
|
}
|
|
1397
2374
|
}
|
|
1398
2375
|
#else
|
|
1399
|
-
kernel void
|
|
2376
|
+
kernel void kernel_mul_mv_q3_K_f32(
|
|
1400
2377
|
device const void * src0,
|
|
1401
2378
|
device const float * src1,
|
|
1402
2379
|
device float * dst,
|
|
@@ -1405,26 +2382,33 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
|
1405
2382
|
constant int64_t & ne02[[buffer(5)]],
|
|
1406
2383
|
constant int64_t & ne10[[buffer(9)]],
|
|
1407
2384
|
constant int64_t & ne12[[buffer(11)]],
|
|
1408
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
1409
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
1410
|
-
constant uint &
|
|
2385
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
2386
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
2387
|
+
constant uint & r2 [[buffer(17)]],
|
|
2388
|
+
constant uint & r3 [[buffer(18)]],
|
|
1411
2389
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1412
|
-
uint
|
|
1413
|
-
uint
|
|
2390
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
2391
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
1414
2392
|
|
|
1415
2393
|
const int nb = ne00/QK_K;
|
|
1416
2394
|
|
|
1417
2395
|
const int64_t r0 = tgpig.x;
|
|
1418
2396
|
const int64_t r1 = tgpig.y;
|
|
1419
|
-
const int64_t
|
|
2397
|
+
const int64_t im = tgpig.z;
|
|
1420
2398
|
|
|
1421
2399
|
const int row = 2 * r0 + sgitg;
|
|
1422
|
-
|
|
2400
|
+
|
|
2401
|
+
const uint i12 = im%ne12;
|
|
2402
|
+
const uint i13 = im/ne12;
|
|
2403
|
+
|
|
2404
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
2405
|
+
|
|
1423
2406
|
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
|
|
1424
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
|
2407
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
2408
|
+
|
|
1425
2409
|
const int ix = tiisg/4;
|
|
1426
2410
|
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
|
1427
|
-
const int
|
|
2411
|
+
const int iq = il/8; // 0, 0, 1, 1
|
|
1428
2412
|
const int in = il%8; // 0, 4, 0, 4
|
|
1429
2413
|
|
|
1430
2414
|
float2 sum = {0.f, 0.f};
|
|
@@ -1444,7 +2428,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
|
1444
2428
|
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
|
|
1445
2429
|
|
|
1446
2430
|
for (int l = 0; l < 4; l += 2) {
|
|
1447
|
-
const uint16_t hm = h[l/2] >>
|
|
2431
|
+
const uint16_t hm = h[l/2] >> iq;
|
|
1448
2432
|
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
|
|
1449
2433
|
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
|
|
1450
2434
|
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
|
|
@@ -1460,14 +2444,14 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
|
1460
2444
|
|
|
1461
2445
|
const float tot = simd_sum(sumf);
|
|
1462
2446
|
if (tiisg == 0) {
|
|
1463
|
-
dst[r1*ne0 +
|
|
2447
|
+
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
|
1464
2448
|
}
|
|
1465
2449
|
|
|
1466
2450
|
}
|
|
1467
2451
|
#endif
|
|
1468
2452
|
|
|
1469
2453
|
#if QK_K == 256
|
|
1470
|
-
kernel void
|
|
2454
|
+
kernel void kernel_mul_mv_q4_K_f32(
|
|
1471
2455
|
device const void * src0,
|
|
1472
2456
|
device const float * src1,
|
|
1473
2457
|
device float * dst,
|
|
@@ -1478,10 +2462,11 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
|
1478
2462
|
constant int64_t & ne12 [[buffer(11)]],
|
|
1479
2463
|
constant int64_t & ne0 [[buffer(15)]],
|
|
1480
2464
|
constant int64_t & ne1 [[buffer(16)]],
|
|
1481
|
-
constant uint &
|
|
2465
|
+
constant uint & r2 [[buffer(17)]],
|
|
2466
|
+
constant uint & r3 [[buffer(18)]],
|
|
1482
2467
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1483
|
-
uint
|
|
1484
|
-
uint
|
|
2468
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
2469
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
1485
2470
|
|
|
1486
2471
|
const uint16_t kmask1 = 0x3f3f;
|
|
1487
2472
|
const uint16_t kmask2 = 0x0f0f;
|
|
@@ -1489,26 +2474,32 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
|
1489
2474
|
|
|
1490
2475
|
const int ix = tiisg/8; // 0...3
|
|
1491
2476
|
const int it = tiisg%8; // 0...7
|
|
1492
|
-
const int
|
|
2477
|
+
const int iq = it/4; // 0 or 1
|
|
1493
2478
|
const int ir = it%4; // 0...3
|
|
1494
2479
|
|
|
1495
2480
|
const int nb = ne00/QK_K;
|
|
1496
2481
|
const int r0 = tgpig.x;
|
|
1497
2482
|
const int r1 = tgpig.y;
|
|
1498
|
-
const int
|
|
2483
|
+
const int im = tgpig.z;
|
|
1499
2484
|
//const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
1500
2485
|
const int first_row = r0 * N_DST;
|
|
1501
2486
|
const int ib_row = first_row * nb;
|
|
1502
|
-
|
|
2487
|
+
|
|
2488
|
+
const uint i12 = im%ne12;
|
|
2489
|
+
const uint i13 = im/ne12;
|
|
2490
|
+
|
|
2491
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
2492
|
+
|
|
1503
2493
|
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
|
1504
|
-
device const float * y = (device const float *) src1 + r1*ne10 +
|
|
2494
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
2495
|
+
|
|
1505
2496
|
float yl[16];
|
|
1506
2497
|
float yh[16];
|
|
1507
2498
|
float sumf[N_DST]={0.f}, all_sum;
|
|
1508
2499
|
|
|
1509
2500
|
const int step = sizeof(block_q4_K) * nb / 2;
|
|
1510
2501
|
|
|
1511
|
-
device const float * y4 = y + ix * QK_K + 64 *
|
|
2502
|
+
device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
|
|
1512
2503
|
|
|
1513
2504
|
uint16_t sc16[4];
|
|
1514
2505
|
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
|
|
@@ -1523,8 +2514,8 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
|
1523
2514
|
yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
|
|
1524
2515
|
}
|
|
1525
2516
|
|
|
1526
|
-
device const uint16_t * sc = (device const uint16_t *)x[ib].scales +
|
|
1527
|
-
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 *
|
|
2517
|
+
device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
|
|
2518
|
+
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
|
|
1528
2519
|
device const half * dh = &x[ib].d;
|
|
1529
2520
|
|
|
1530
2521
|
for (int row = 0; row < N_DST; row++) {
|
|
@@ -1568,12 +2559,12 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
|
1568
2559
|
for (int row = 0; row < N_DST; ++row) {
|
|
1569
2560
|
all_sum = simd_sum(sumf[row]);
|
|
1570
2561
|
if (tiisg == 0) {
|
|
1571
|
-
dst[r1*ne0 +
|
|
2562
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
|
1572
2563
|
}
|
|
1573
2564
|
}
|
|
1574
2565
|
}
|
|
1575
2566
|
#else
|
|
1576
|
-
kernel void
|
|
2567
|
+
kernel void kernel_mul_mv_q4_K_f32(
|
|
1577
2568
|
device const void * src0,
|
|
1578
2569
|
device const float * src1,
|
|
1579
2570
|
device float * dst,
|
|
@@ -1582,9 +2573,10 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
|
1582
2573
|
constant int64_t & ne02[[buffer(5)]],
|
|
1583
2574
|
constant int64_t & ne10[[buffer(9)]],
|
|
1584
2575
|
constant int64_t & ne12[[buffer(11)]],
|
|
1585
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
1586
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
1587
|
-
constant uint &
|
|
2576
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
2577
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
2578
|
+
constant uint & r2 [[buffer(17)]],
|
|
2579
|
+
constant uint & r3 [[buffer(18)]],
|
|
1588
2580
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1589
2581
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
1590
2582
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -1595,12 +2587,18 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
|
1595
2587
|
const int nb = ne00/QK_K;
|
|
1596
2588
|
const int r0 = tgpig.x;
|
|
1597
2589
|
const int r1 = tgpig.y;
|
|
1598
|
-
const int
|
|
2590
|
+
const int im = tgpig.z;
|
|
1599
2591
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
1600
2592
|
const int ib_row = first_row * nb;
|
|
1601
|
-
|
|
2593
|
+
|
|
2594
|
+
const uint i12 = im%ne12;
|
|
2595
|
+
const uint i13 = im/ne12;
|
|
2596
|
+
|
|
2597
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
2598
|
+
|
|
1602
2599
|
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
|
1603
|
-
device const float * y = (device const float *) src1 + r1*ne10 +
|
|
2600
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
2601
|
+
|
|
1604
2602
|
float yl[8];
|
|
1605
2603
|
float yh[8];
|
|
1606
2604
|
float sumf[N_DST]={0.f}, all_sum;
|
|
@@ -1656,13 +2654,13 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
|
1656
2654
|
for (int row = 0; row < N_DST; ++row) {
|
|
1657
2655
|
all_sum = simd_sum(sumf[row]);
|
|
1658
2656
|
if (tiisg == 0) {
|
|
1659
|
-
dst[r1*ne0+
|
|
2657
|
+
dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
|
|
1660
2658
|
}
|
|
1661
2659
|
}
|
|
1662
2660
|
}
|
|
1663
2661
|
#endif
|
|
1664
2662
|
|
|
1665
|
-
kernel void
|
|
2663
|
+
kernel void kernel_mul_mv_q5_K_f32(
|
|
1666
2664
|
device const void * src0,
|
|
1667
2665
|
device const float * src1,
|
|
1668
2666
|
device float * dst,
|
|
@@ -1671,9 +2669,10 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
|
1671
2669
|
constant int64_t & ne02[[buffer(5)]],
|
|
1672
2670
|
constant int64_t & ne10[[buffer(9)]],
|
|
1673
2671
|
constant int64_t & ne12[[buffer(11)]],
|
|
1674
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
1675
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
1676
|
-
constant uint &
|
|
2672
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
2673
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
2674
|
+
constant uint & r2 [[buffer(17)]],
|
|
2675
|
+
constant uint & r3 [[buffer(18)]],
|
|
1677
2676
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1678
2677
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
1679
2678
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -1682,12 +2681,17 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
|
1682
2681
|
|
|
1683
2682
|
const int64_t r0 = tgpig.x;
|
|
1684
2683
|
const int64_t r1 = tgpig.y;
|
|
1685
|
-
const int
|
|
2684
|
+
const int im = tgpig.z;
|
|
1686
2685
|
|
|
1687
2686
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
|
1688
|
-
|
|
2687
|
+
|
|
2688
|
+
const uint i12 = im%ne12;
|
|
2689
|
+
const uint i13 = im/ne12;
|
|
2690
|
+
|
|
2691
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
2692
|
+
|
|
1689
2693
|
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
|
|
1690
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
|
2694
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
1691
2695
|
|
|
1692
2696
|
float sumf[2]={0.f};
|
|
1693
2697
|
|
|
@@ -1703,15 +2707,15 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
|
1703
2707
|
|
|
1704
2708
|
const int tid = tiisg/4;
|
|
1705
2709
|
const int ix = tiisg%4;
|
|
1706
|
-
const int
|
|
2710
|
+
const int iq = tid/4;
|
|
1707
2711
|
const int ir = tid%4;
|
|
1708
2712
|
const int n = 8;
|
|
1709
2713
|
|
|
1710
2714
|
const int l0 = n*ir;
|
|
1711
|
-
const int q_offset = 32*
|
|
1712
|
-
const int y_offset = 64*
|
|
2715
|
+
const int q_offset = 32*iq + l0;
|
|
2716
|
+
const int y_offset = 64*iq + l0;
|
|
1713
2717
|
|
|
1714
|
-
const uint8_t hm1 = 1u << (2*
|
|
2718
|
+
const uint8_t hm1 = 1u << (2*iq);
|
|
1715
2719
|
const uint8_t hm2 = hm1 << 1;
|
|
1716
2720
|
const uint8_t hm3 = hm1 << 4;
|
|
1717
2721
|
const uint8_t hm4 = hm2 << 4;
|
|
@@ -1726,7 +2730,7 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
|
1726
2730
|
device const uint8_t * q1 = x[i].qs + q_offset;
|
|
1727
2731
|
device const uint8_t * qh = x[i].qh + l0;
|
|
1728
2732
|
device const half * dh = &x[i].d;
|
|
1729
|
-
device const uint16_t * a = (device const uint16_t *)x[i].scales +
|
|
2733
|
+
device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
|
|
1730
2734
|
|
|
1731
2735
|
device const float * y2 = y1 + 128;
|
|
1732
2736
|
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
|
@@ -1782,7 +2786,7 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
|
1782
2786
|
|
|
1783
2787
|
const int il = 4 * (tiisg/8); // 0, 4, 8, 12
|
|
1784
2788
|
const int ix = tiisg%8;
|
|
1785
|
-
const int
|
|
2789
|
+
const int iq = il/8; // 0, 0, 1, 1
|
|
1786
2790
|
const int in = il%8; // 0, 4, 0, 4
|
|
1787
2791
|
|
|
1788
2792
|
device const float * y = yy + ix*QK_K + il;
|
|
@@ -1807,7 +2811,7 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
|
1807
2811
|
|
|
1808
2812
|
float2 acc = {0.f, 0.f};
|
|
1809
2813
|
for (int l = 0; l < 4; ++l) {
|
|
1810
|
-
const uint8_t hl = h[l] >>
|
|
2814
|
+
const uint8_t hl = h[l] >> iq;
|
|
1811
2815
|
acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
|
|
1812
2816
|
+ yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
|
|
1813
2817
|
acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
|
|
@@ -1829,13 +2833,13 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
|
1829
2833
|
for (int row = 0; row < 2; ++row) {
|
|
1830
2834
|
const float tot = simd_sum(sumf[row]);
|
|
1831
2835
|
if (tiisg == 0) {
|
|
1832
|
-
dst[r1*ne0 +
|
|
2836
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
|
1833
2837
|
}
|
|
1834
2838
|
}
|
|
1835
2839
|
|
|
1836
2840
|
}
|
|
1837
2841
|
|
|
1838
|
-
kernel void
|
|
2842
|
+
kernel void kernel_mul_mv_q6_K_f32(
|
|
1839
2843
|
device const void * src0,
|
|
1840
2844
|
device const float * src1,
|
|
1841
2845
|
device float * dst,
|
|
@@ -1844,9 +2848,10 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
|
1844
2848
|
constant int64_t & ne02[[buffer(5)]],
|
|
1845
2849
|
constant int64_t & ne10[[buffer(9)]],
|
|
1846
2850
|
constant int64_t & ne12[[buffer(11)]],
|
|
1847
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
1848
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
1849
|
-
constant uint &
|
|
2851
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
2852
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
2853
|
+
constant uint & r2 [[buffer(17)]],
|
|
2854
|
+
constant uint & r3 [[buffer(18)]],
|
|
1850
2855
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1851
2856
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
1852
2857
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -1860,12 +2865,17 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
|
1860
2865
|
|
|
1861
2866
|
const int64_t r0 = tgpig.x;
|
|
1862
2867
|
const int64_t r1 = tgpig.y;
|
|
1863
|
-
const int
|
|
2868
|
+
const int im = tgpig.z;
|
|
1864
2869
|
|
|
1865
2870
|
const int row = 2 * r0 + sgitg;
|
|
1866
|
-
|
|
2871
|
+
|
|
2872
|
+
const uint i12 = im%ne12;
|
|
2873
|
+
const uint i13 = im/ne12;
|
|
2874
|
+
|
|
2875
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
2876
|
+
|
|
1867
2877
|
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
|
|
1868
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
|
2878
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
1869
2879
|
|
|
1870
2880
|
float sumf = 0;
|
|
1871
2881
|
|
|
@@ -1931,7 +2941,7 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
|
1931
2941
|
|
|
1932
2942
|
const float tot = simd_sum(sumf);
|
|
1933
2943
|
if (tiisg == 0) {
|
|
1934
|
-
dst[r1*ne0 +
|
|
2944
|
+
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
|
1935
2945
|
}
|
|
1936
2946
|
}
|
|
1937
2947
|
|
|
@@ -1984,6 +2994,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
|
|
|
1984
2994
|
}
|
|
1985
2995
|
}
|
|
1986
2996
|
|
|
2997
|
+
template <typename type4x4>
|
|
2998
|
+
void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
|
|
2999
|
+
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
|
|
3000
|
+
const float d = xb->d;
|
|
3001
|
+
const float md = -16.h * xb->d;
|
|
3002
|
+
const ushort mask = il ? 0x00F0 : 0x000F;
|
|
3003
|
+
|
|
3004
|
+
const uint32_t qh = *((device const uint32_t *)xb->qh);
|
|
3005
|
+
|
|
3006
|
+
const int x_mv = il ? 4 : 0;
|
|
3007
|
+
|
|
3008
|
+
const int gh_mv = il ? 12 : 0;
|
|
3009
|
+
const int gh_bk = il ? 0 : 4;
|
|
3010
|
+
|
|
3011
|
+
for (int i = 0; i < 8; i++) {
|
|
3012
|
+
// extract the 5-th bits for x0 and x1
|
|
3013
|
+
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
|
3014
|
+
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
|
3015
|
+
|
|
3016
|
+
// combine the 4-bits from qs with the 5th bit
|
|
3017
|
+
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
|
3018
|
+
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
|
3019
|
+
|
|
3020
|
+
reg[i/2][2*(i%2)+0] = d * x0 + md;
|
|
3021
|
+
reg[i/2][2*(i%2)+1] = d * x1 + md;
|
|
3022
|
+
}
|
|
3023
|
+
}
|
|
3024
|
+
|
|
3025
|
+
template <typename type4x4>
|
|
3026
|
+
void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
|
|
3027
|
+
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
|
|
3028
|
+
const float d = xb->d;
|
|
3029
|
+
const float m = xb->m;
|
|
3030
|
+
const ushort mask = il ? 0x00F0 : 0x000F;
|
|
3031
|
+
|
|
3032
|
+
const uint32_t qh = *((device const uint32_t *)xb->qh);
|
|
3033
|
+
|
|
3034
|
+
const int x_mv = il ? 4 : 0;
|
|
3035
|
+
|
|
3036
|
+
const int gh_mv = il ? 12 : 0;
|
|
3037
|
+
const int gh_bk = il ? 0 : 4;
|
|
3038
|
+
|
|
3039
|
+
for (int i = 0; i < 8; i++) {
|
|
3040
|
+
// extract the 5-th bits for x0 and x1
|
|
3041
|
+
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
|
3042
|
+
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
|
3043
|
+
|
|
3044
|
+
// combine the 4-bits from qs with the 5th bit
|
|
3045
|
+
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
|
3046
|
+
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
|
3047
|
+
|
|
3048
|
+
reg[i/2][2*(i%2)+0] = d * x0 + m;
|
|
3049
|
+
reg[i/2][2*(i%2)+1] = d * x1 + m;
|
|
3050
|
+
}
|
|
3051
|
+
}
|
|
3052
|
+
|
|
1987
3053
|
template <typename type4x4>
|
|
1988
3054
|
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
|
1989
3055
|
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
|
@@ -2173,7 +3239,7 @@ kernel void kernel_get_rows(
|
|
|
2173
3239
|
}
|
|
2174
3240
|
|
|
2175
3241
|
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
|
2176
|
-
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix
|
|
3242
|
+
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
|
2177
3243
|
#define BLOCK_SIZE_K 32
|
|
2178
3244
|
#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
|
|
2179
3245
|
#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
|
|
@@ -2185,24 +3251,25 @@ kernel void kernel_get_rows(
|
|
|
2185
3251
|
|
|
2186
3252
|
// each block_q contains 16*nl weights
|
|
2187
3253
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
2188
|
-
|
|
2189
|
-
|
|
2190
|
-
|
|
2191
|
-
|
|
2192
|
-
|
|
2193
|
-
|
|
2194
|
-
|
|
2195
|
-
|
|
2196
|
-
|
|
2197
|
-
|
|
2198
|
-
|
|
2199
|
-
|
|
2200
|
-
|
|
2201
|
-
|
|
2202
|
-
|
|
2203
|
-
|
|
2204
|
-
|
|
2205
|
-
|
|
3254
|
+
void kernel_mul_mm_impl(device const uchar * src0,
|
|
3255
|
+
device const uchar * src1,
|
|
3256
|
+
device float * dst,
|
|
3257
|
+
constant int64_t & ne00,
|
|
3258
|
+
constant int64_t & ne02,
|
|
3259
|
+
constant int64_t & nb01,
|
|
3260
|
+
constant int64_t & nb02,
|
|
3261
|
+
constant int64_t & ne12,
|
|
3262
|
+
constant int64_t & nb10,
|
|
3263
|
+
constant int64_t & nb11,
|
|
3264
|
+
constant int64_t & nb12,
|
|
3265
|
+
constant int64_t & ne0,
|
|
3266
|
+
constant int64_t & ne1,
|
|
3267
|
+
constant uint & r2,
|
|
3268
|
+
constant uint & r3,
|
|
3269
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
|
3270
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3271
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
3272
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
2206
3273
|
|
|
2207
3274
|
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
|
2208
3275
|
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
|
@@ -2210,9 +3277,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
|
2210
3277
|
const uint r0 = tgpig.y;
|
|
2211
3278
|
const uint r1 = tgpig.x;
|
|
2212
3279
|
const uint im = tgpig.z;
|
|
3280
|
+
|
|
2213
3281
|
// if this block is of 64x32 shape or smaller
|
|
2214
3282
|
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
|
2215
3283
|
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
|
3284
|
+
|
|
2216
3285
|
// a thread shouldn't load data outside of the matrix
|
|
2217
3286
|
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
|
2218
3287
|
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
|
@@ -2226,7 +3295,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
|
2226
3295
|
|
|
2227
3296
|
short il = (tiitg % THREAD_PER_ROW);
|
|
2228
3297
|
|
|
2229
|
-
uint
|
|
3298
|
+
const uint i12 = im%ne12;
|
|
3299
|
+
const uint i13 = im/ne12;
|
|
3300
|
+
|
|
3301
|
+
uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
|
|
2230
3302
|
ushort offset1 = il/nl;
|
|
2231
3303
|
|
|
2232
3304
|
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
|
@@ -2236,26 +3308,30 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
|
2236
3308
|
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
|
2237
3309
|
|
|
2238
3310
|
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
|
2239
|
-
//load data and store to threadgroup memory
|
|
3311
|
+
// load data and store to threadgroup memory
|
|
2240
3312
|
half4x4 temp_a;
|
|
2241
3313
|
dequantize_func(x, il, temp_a);
|
|
2242
3314
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3315
|
+
|
|
2243
3316
|
#pragma unroll(16)
|
|
2244
3317
|
for (int i = 0; i < 16; i++) {
|
|
2245
3318
|
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
|
2246
|
-
+
|
|
2247
|
-
+
|
|
3319
|
+
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
|
|
3320
|
+
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
|
2248
3321
|
}
|
|
2249
|
-
|
|
2250
|
-
|
|
3322
|
+
|
|
3323
|
+
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
|
|
3324
|
+
|
|
2251
3325
|
il = (il + 2 < nl) ? il + 2 : il % 2;
|
|
2252
3326
|
x = (il < 2) ? x + (2+nl-1)/nl : x;
|
|
2253
3327
|
y += BLOCK_SIZE_K;
|
|
2254
3328
|
|
|
2255
3329
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
2256
|
-
|
|
3330
|
+
|
|
3331
|
+
// load matrices from threadgroup memory and conduct outer products
|
|
2257
3332
|
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
|
2258
3333
|
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
|
3334
|
+
|
|
2259
3335
|
#pragma unroll(4)
|
|
2260
3336
|
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
|
2261
3337
|
#pragma unroll(4)
|
|
@@ -2270,6 +3346,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
|
2270
3346
|
|
|
2271
3347
|
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
|
2272
3348
|
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
|
3349
|
+
|
|
2273
3350
|
#pragma unroll(8)
|
|
2274
3351
|
for (int i = 0; i < 8; i++){
|
|
2275
3352
|
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
|
|
@@ -2278,25 +3355,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
|
2278
3355
|
}
|
|
2279
3356
|
|
|
2280
3357
|
if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
|
|
2281
|
-
device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
|
|
2282
|
-
|
|
3358
|
+
device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
|
|
3359
|
+
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
|
|
2283
3360
|
for (int i = 0; i < 8; i++) {
|
|
2284
3361
|
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
|
|
2285
3362
|
}
|
|
2286
3363
|
} else {
|
|
2287
3364
|
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
|
2288
3365
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
2289
|
-
threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
|
|
3366
|
+
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
|
|
2290
3367
|
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
|
2291
3368
|
for (int i = 0; i < 8; i++) {
|
|
2292
3369
|
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
|
2293
3370
|
}
|
|
2294
3371
|
|
|
2295
3372
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
2296
|
-
|
|
2297
|
-
|
|
3373
|
+
|
|
3374
|
+
device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
|
|
3375
|
+
if (sgitg == 0) {
|
|
2298
3376
|
for (int i = 0; i < n_rows; i++) {
|
|
2299
|
-
for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
|
|
3377
|
+
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
|
2300
3378
|
*(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
|
|
2301
3379
|
}
|
|
2302
3380
|
}
|
|
@@ -2304,19 +3382,123 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
|
2304
3382
|
}
|
|
2305
3383
|
}
|
|
2306
3384
|
|
|
3385
|
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
3386
|
+
kernel void kernel_mul_mm(device const uchar * src0,
|
|
3387
|
+
device const uchar * src1,
|
|
3388
|
+
device float * dst,
|
|
3389
|
+
constant int64_t & ne00,
|
|
3390
|
+
constant int64_t & ne02,
|
|
3391
|
+
constant int64_t & nb01,
|
|
3392
|
+
constant int64_t & nb02,
|
|
3393
|
+
constant int64_t & ne12,
|
|
3394
|
+
constant int64_t & nb10,
|
|
3395
|
+
constant int64_t & nb11,
|
|
3396
|
+
constant int64_t & nb12,
|
|
3397
|
+
constant int64_t & ne0,
|
|
3398
|
+
constant int64_t & ne1,
|
|
3399
|
+
constant uint & r2,
|
|
3400
|
+
constant uint & r3,
|
|
3401
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
|
3402
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3403
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
3404
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3405
|
+
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
|
3406
|
+
src0,
|
|
3407
|
+
src1,
|
|
3408
|
+
dst,
|
|
3409
|
+
ne00,
|
|
3410
|
+
ne02,
|
|
3411
|
+
nb01,
|
|
3412
|
+
nb02,
|
|
3413
|
+
ne12,
|
|
3414
|
+
nb10,
|
|
3415
|
+
nb11,
|
|
3416
|
+
nb12,
|
|
3417
|
+
ne0,
|
|
3418
|
+
ne1,
|
|
3419
|
+
r2,
|
|
3420
|
+
r3,
|
|
3421
|
+
shared_memory,
|
|
3422
|
+
tgpig,
|
|
3423
|
+
tiitg,
|
|
3424
|
+
sgitg);
|
|
3425
|
+
}
|
|
3426
|
+
|
|
3427
|
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
3428
|
+
kernel void kernel_mul_mm_id(
|
|
3429
|
+
device const int32_t * ids,
|
|
3430
|
+
device const uchar * src1,
|
|
3431
|
+
device float * dst,
|
|
3432
|
+
constant int64_t & ne00,
|
|
3433
|
+
constant int64_t & ne02,
|
|
3434
|
+
constant int64_t & nb01,
|
|
3435
|
+
constant int64_t & nb02,
|
|
3436
|
+
constant int64_t & ne12,
|
|
3437
|
+
constant int64_t & nb10,
|
|
3438
|
+
constant int64_t & nb11,
|
|
3439
|
+
constant int64_t & nb12,
|
|
3440
|
+
constant int64_t & ne0,
|
|
3441
|
+
constant int64_t & ne1,
|
|
3442
|
+
constant uint & r2,
|
|
3443
|
+
constant uint & r3,
|
|
3444
|
+
constant int & idx,
|
|
3445
|
+
device const uchar * src00,
|
|
3446
|
+
device const uchar * src01,
|
|
3447
|
+
device const uchar * src02,
|
|
3448
|
+
device const uchar * src03,
|
|
3449
|
+
device const uchar * src04,
|
|
3450
|
+
device const uchar * src05,
|
|
3451
|
+
device const uchar * src06,
|
|
3452
|
+
device const uchar * src07,
|
|
3453
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
|
3454
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3455
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
3456
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3457
|
+
device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
3458
|
+
|
|
3459
|
+
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
|
3460
|
+
src0[ids[idx]],
|
|
3461
|
+
src1,
|
|
3462
|
+
dst,
|
|
3463
|
+
ne00,
|
|
3464
|
+
ne02,
|
|
3465
|
+
nb01,
|
|
3466
|
+
nb02,
|
|
3467
|
+
ne12,
|
|
3468
|
+
nb10,
|
|
3469
|
+
nb11,
|
|
3470
|
+
nb12,
|
|
3471
|
+
ne0,
|
|
3472
|
+
ne1,
|
|
3473
|
+
r2,
|
|
3474
|
+
r3,
|
|
3475
|
+
shared_memory,
|
|
3476
|
+
tgpig,
|
|
3477
|
+
tiitg,
|
|
3478
|
+
sgitg);
|
|
3479
|
+
}
|
|
3480
|
+
|
|
2307
3481
|
#if QK_K == 256
|
|
2308
3482
|
#define QK_NL 16
|
|
2309
3483
|
#else
|
|
2310
3484
|
#define QK_NL 4
|
|
2311
3485
|
#endif
|
|
2312
3486
|
|
|
2313
|
-
typedef void (get_rows_t)(
|
|
2314
|
-
|
|
3487
|
+
typedef void (get_rows_t)(
|
|
3488
|
+
device const void * src0,
|
|
3489
|
+
device const int * src1,
|
|
3490
|
+
device float * dst,
|
|
3491
|
+
constant int64_t & ne00,
|
|
3492
|
+
constant uint64_t & nb01,
|
|
3493
|
+
constant uint64_t & nb1,
|
|
3494
|
+
uint, uint, uint);
|
|
2315
3495
|
|
|
2316
3496
|
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
|
2317
3497
|
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
|
2318
3498
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
|
2319
3499
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
|
3500
|
+
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
|
|
3501
|
+
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
|
|
2320
3502
|
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
|
|
2321
3503
|
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
|
|
2322
3504
|
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
|
|
@@ -2338,16 +3520,61 @@ typedef void (mat_mm_t)(
|
|
|
2338
3520
|
constant int64_t & nb12,
|
|
2339
3521
|
constant int64_t & ne0,
|
|
2340
3522
|
constant int64_t & ne1,
|
|
2341
|
-
constant uint &
|
|
2342
|
-
|
|
3523
|
+
constant uint & r2,
|
|
3524
|
+
constant uint & r3,
|
|
3525
|
+
threadgroup uchar *,
|
|
3526
|
+
uint3, uint, uint);
|
|
2343
3527
|
|
|
2344
3528
|
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
|
2345
3529
|
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
|
2346
3530
|
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
|
2347
3531
|
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
|
3532
|
+
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
|
|
3533
|
+
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
|
|
2348
3534
|
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
|
2349
3535
|
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
|
2350
3536
|
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
|
2351
3537
|
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
2352
3538
|
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
2353
3539
|
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
3540
|
+
|
|
3541
|
+
typedef void (mat_mm_id_t)(
|
|
3542
|
+
device const int32_t * ids,
|
|
3543
|
+
device const uchar * src1,
|
|
3544
|
+
device float * dst,
|
|
3545
|
+
constant int64_t & ne00,
|
|
3546
|
+
constant int64_t & ne02,
|
|
3547
|
+
constant int64_t & nb01,
|
|
3548
|
+
constant int64_t & nb02,
|
|
3549
|
+
constant int64_t & ne12,
|
|
3550
|
+
constant int64_t & nb10,
|
|
3551
|
+
constant int64_t & nb11,
|
|
3552
|
+
constant int64_t & nb12,
|
|
3553
|
+
constant int64_t & ne0,
|
|
3554
|
+
constant int64_t & ne1,
|
|
3555
|
+
constant uint & r2,
|
|
3556
|
+
constant uint & r3,
|
|
3557
|
+
constant int & idx,
|
|
3558
|
+
device const uchar * src00,
|
|
3559
|
+
device const uchar * src01,
|
|
3560
|
+
device const uchar * src02,
|
|
3561
|
+
device const uchar * src03,
|
|
3562
|
+
device const uchar * src04,
|
|
3563
|
+
device const uchar * src05,
|
|
3564
|
+
device const uchar * src06,
|
|
3565
|
+
device const uchar * src07,
|
|
3566
|
+
threadgroup uchar *,
|
|
3567
|
+
uint3, uint, uint);
|
|
3568
|
+
|
|
3569
|
+
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
|
3570
|
+
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
|
3571
|
+
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
|
|
3572
|
+
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
|
|
3573
|
+
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
|
|
3574
|
+
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
|
|
3575
|
+
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
|
|
3576
|
+
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
|
|
3577
|
+
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
|
|
3578
|
+
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
3579
|
+
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
3580
|
+
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|