llama_cpp 0.1.3 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,1916 @@
1
+ #include <cstddef>
2
+ #include <cstdint>
3
+ #include <stdint.h>
4
+ #include <stdio.h>
5
+ #include <atomic>
6
+ #include <assert.h>
7
+
8
+ #include <cuda_runtime.h>
9
+ #include <cublas_v2.h>
10
+ #include <cuda_fp16.h>
11
+
12
+ #include "ggml-cuda.h"
13
+ #include "ggml.h"
14
+
15
+ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
16
+
17
// Evaluates a CUDA runtime call exactly once. If the returned status is not
// cudaSuccess, prints the numeric code, the source location and the
// cudaGetErrorString() text to stderr, then terminates the process with exit(1).
#define CUDA_CHECK(err) \
    do { \
        cudaError_t err_ = (err); \
        if (err_ != cudaSuccess) { \
            fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
                cudaGetErrorString(err_)); \
            exit(1); \
        } \
    } while (0)
26
+
27
// Evaluates a cuBLAS call exactly once. If the returned status is not
// CUBLAS_STATUS_SUCCESS, prints the numeric code and source location to stderr
// and terminates the process with exit(1).
//
// CUDART_VERSION is encoded as major*1000 + minor*10 (CUDA 12.0 == 12000), so the
// toolkit test must compare against 12000. Comparing against 12 is always true and
// selects cublasGetStatusString() even on toolkits that do not provide it.
#if CUDART_VERSION >= 12000
// CUDA 12+: also print the textual status name.
#define CUBLAS_CHECK(err) \
    do { \
        cublasStatus_t err_ = (err); \
        if (err_ != CUBLAS_STATUS_SUCCESS) { \
            fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
                err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
            exit(1); \
        } \
    } while (0)
#else
// Older toolkits: numeric status only.
#define CUBLAS_CHECK(err) \
    do { \
        cublasStatus_t err_ = (err); \
        if (err_ != CUBLAS_STATUS_SUCCESS) { \
            fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
            exit(1); \
        } \
    } while (0)
#endif // CUDART_VERSION >= 12000
47
+
48
// Device dequantizer: for block `ib` and quant index `iqs` of the data at `vx`,
// returns two dequantized values through `v0` and `v1`.
using dequantize_kernel_t = void (*)(const void * vx, const int ib, const int iqs, float & v0, float & v1);

// Host launcher converting `k` values from `x` into floats at `y` on `stream`
// (the dequantize_row_*_cuda functions in this file have this signature).
using to_fp32_cuda_t = void (*)(const void * x, float * y, int k, cudaStream_t stream);

// Device k-quant dot product: combines quants of block `ib` (starting at quant
// index `iqs`) with the floats at `y` and returns the partial sum through `v`.
using dot_kernel_k_t = void (*)(const void * vx, const int ib, const int iqs, const float * y, float & v);

// Tensor-level operation taking two sources and a destination.
using ggml_cuda_func_t = void (*)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);

// Per-slice operation callback; the meaning of the buffer and index arguments
// is defined by the implementations, which are not part of this section.
using ggml_cuda_op_t = void (*)(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i,
    float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
    cudaStream_t & cudaStream_main);
56
+
57
+ // QK = number of values after dequantization
58
+ // QR = QK / number of values before dequantization
59
+
60
// q4_0: 32 values per block, stored as one fp16 scale plus packed 4-bit quants.
// A value is reconstructed as (quant - 8) * d.
#define QK4_0 32 // values per block after dequantization
#define QR4_0 2  // dequantized values per stored quant byte
typedef struct {
    half    d;             // block scale
    uint8_t qs[QK4_0 / 2]; // two 4-bit quants per byte (low nibble, high nibble)
} block_q4_0;
static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
67
+
68
// q4_1: 32 values per block with a scale and a minimum.
// A value is reconstructed as quant * d + m.
#define QK4_1 32 // values per block after dequantization
#define QR4_1 2  // dequantized values per stored quant byte
typedef struct {
    half    d;             // block scale
    half    m;             // block minimum, added after scaling
    uint8_t qs[QK4_1 / 2]; // two 4-bit quants per byte (low nibble, high nibble)
} block_q4_1;
static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
76
+
77
// q5_0: 32 values per block; the low 4 bits of each quant live in qs and the
// fifth bit in qh. A value is reconstructed as (quant - 16) * d.
#define QK5_0 32 // values per block after dequantization
#define QR5_0 2  // dequantized values per stored quant byte
typedef struct {
    half    d;             // block scale
    uint8_t qh[4];         // fifth bit of every quant, read as one 32-bit word
    uint8_t qs[QK5_0 / 2]; // low 4 bits, two quants per byte
} block_q5_0;
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
85
+
86
// q5_1: 32 values per block with scale and minimum; the low 4 bits of each
// quant live in qs and the fifth bit in qh. Reconstructed as quant * d + m.
#define QK5_1 32 // values per block after dequantization
#define QR5_1 2  // dequantized values per stored quant byte
typedef struct {
    half    d;             // block scale
    half    m;             // block minimum, added after scaling
    uint8_t qh[4];         // fifth bit of every quant, read as one 32-bit word
    uint8_t qs[QK5_1 / 2]; // low 4 bits, two quants per byte
} block_q5_1;
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
95
+
96
// q8_0: 32 signed 8-bit quants per block sharing one fp16 scale.
// A value is reconstructed as quant * d.
#define QK8_0 32 // values per block after dequantization
#define QR8_0 1  // one dequantized value per stored quant
typedef struct {
    half   d;         // block scale
    int8_t qs[QK8_0]; // signed quants
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
103
+
104
+ //================================= k-quants
105
+
106
+ #define QK_K 256
107
+
108
// q2_K super-block of QK_K values: 2-bit quants with one 4-bit scale and one
// 4-bit min per group of 16 values, plus two fp16 super-block factors.
typedef struct {
    uint8_t scales[QK_K/16]; // per group: low nibble = scale, high nibble = min
    uint8_t qs[QK_K/4];      // 2-bit quants, four per byte
    half    d;               // super-block factor applied to the scales
    half    dmin;            // super-block factor applied to the mins
} block_q2_K;
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
115
+
116
// q3_K super-block of QK_K values: the low 2 bits of each quant live in qs, the
// third bit in hmask; scales are 6-bit values packed into 12 bytes.
typedef struct {
    uint8_t hmask[QK_K/8];     // one high bit per quant; a clear bit subtracts 4
    uint8_t qs[QK_K/4];        // low 2 bits of the quants, four per byte
    uint8_t scales[3*QK_K/64]; // packed 6-bit scales (used as scale - 32)
    half    d;                 // super-block scale
} block_q3_K;
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
123
+
124
// q4_K super-block of QK_K values: 4-bit quants with 6-bit scales and mins
// (decoded by get_scale_min_k4) and two fp16 super-block factors.
typedef struct {
    half    d;                 // super-block factor applied to the scales
    half    dmin;              // super-block factor applied to the mins
    uint8_t scales[3*QK_K/64]; // packed 6-bit scales and mins
    uint8_t qs[QK_K/2];        // 4-bit quants, two per byte
} block_q4_K;
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
131
+
132
// q5_K super-block of QK_K values: low 4 bits of each quant in qs, fifth bit
// in qh, 6-bit scales and mins (decoded by get_scale_min_k4).
typedef struct {
    half    d;                 // super-block factor applied to the scales
    half    dmin;              // super-block factor applied to the mins
    uint8_t scales[3*QK_K/64]; // packed 6-bit scales and mins
    uint8_t qh[QK_K/8];        // high (fifth) bit of the quants
    uint8_t qs[QK_K/2];        // low 4 bits of the quants, two per byte
} block_q5_K;
static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
140
+
141
// q6_K super-block of QK_K values: 6-bit quants split into 4 low bits (ql) and
// 2 high bits (qh), one signed 8-bit scale per 16 values, one fp16 factor.
// A value is reconstructed as d * scale * (quant - 32).
typedef struct {
    uint8_t ql[QK_K/2];      // low 4 bits of the quants, two per byte
    uint8_t qh[QK_K/4];      // high 2 bits of the quants, four per byte
    int8_t  scales[QK_K/16]; // signed per-group scales
    half    d;               // super-block scale
} block_q6_K;
static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
148
+
149
// Lanes per warp; the shuffle reductions in this file are written for this width.
#define WARP_SIZE 32

// Threads per block for the element-wise and dequantization kernels.
#define CUDA_ADD_BLOCK_SIZE 256
#define CUDA_MUL_BLOCK_SIZE 256
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_ROPE_BLOCK_SIZE 256
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256

// dmmv = dequantize_mul_mat_vec
// X sets the column stride of the dmmv kernel (2*X values per loop iteration),
// Y the number of rows handled per thread block. Both may be overridden at build time.
#ifndef GGML_CUDA_DMMV_X
#define GGML_CUDA_DMMV_X 32
#endif
#ifndef GGML_CUDA_DMMV_Y
#define GGML_CUDA_DMMV_Y 1
#endif
164
+
165
// Element-wise sum: dst[i] = x[i] + y[i] for i in [0, k).
// One thread per element on a 1D grid; threads whose index is >= k do nothing.
static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
    const int idx = blockDim.x*blockIdx.x + threadIdx.x;

    if (idx < k) {
        dst[idx] = x[idx] + y[idx];
    }
}
173
+
174
// Element-wise product with a repeating second operand:
// dst[i] = x[i] * y[i % ky] for i in [0, kx).
// One thread per element on a 1D grid; threads whose index is >= kx do nothing.
static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
    const int idx = blockDim.x*blockIdx.x + threadIdx.x;

    if (idx < kx) {
        dst[idx] = x[idx] * y[idx%ky];
    }
}
182
+
183
// SiLU activation: dst[i] = x[i] / (1 + exp(-x[i])) for i in [0, k).
// One thread per element on a 1D grid; threads whose index is >= k do nothing.
static __global__ void silu_f32(const float * x, float * dst, const int k) {
    const int idx = blockDim.x*blockIdx.x + threadIdx.x;

    if (idx < k) {
        dst[idx] = x[idx] / (1.0f + expf(-x[idx]));
    }
}
191
+
192
// RMS normalization of one row per warp:
//   dst[row][c] = x[row][c] / sqrt(mean(x[row][:]^2) + 1e-6)
// The row is blockIdx.x*blockDim.y + threadIdx.y and threadIdx.x is the lane.
// Neither loop checks the column against ncols, so ncols must be a multiple of
// WARP_SIZE (rms_norm_f32_cuda asserts this).
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
    const int row  = blockIdx.x*blockDim.y + threadIdx.y;
    const int lane = threadIdx.x;

    const float eps = 1e-6;

    // every lane accumulates the squares of each WARP_SIZE-th column
    float sq_sum = 0.0f;

    for (int base = 0; base < ncols; base += WARP_SIZE) {
        const int col = base + lane;
        const float val = x[row*ncols + col];
        sq_sum += val * val;
    }

    // butterfly reduction: afterwards every lane holds the full row sum
    __syncthreads();
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        sq_sum += __shfl_xor_sync(0xffffffff, sq_sum, mask, 32);
    }

    const float mean  = sq_sum / ncols;
    const float scale = 1.0f / sqrtf(mean + eps);

    for (int base = 0; base < ncols; base += WARP_SIZE) {
        const int col = base + lane;
        dst[row*ncols + col] = scale * x[row*ncols + col];
    }
}
221
+
222
// Dequantizes one byte of a q4_0 block.
//   vx  - array of block_q4_0
//   ib  - block index
//   iqs - byte index inside the block's qs array
//   v0  - receives the value from the low nibble:  (nibble - 8) * d
//   v1  - receives the value from the high nibble: (nibble - 8) * d
static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
    const block_q4_0 * blocks = (const block_q4_0 *) vx;

    const float   scale  = blocks[ib].d;
    const uint8_t packed = blocks[ib].qs[iqs];

    const int8_t lo = packed & 0xF;
    const int8_t hi = packed >> 4;

    v0 = (lo - 8)*scale;
    v1 = (hi - 8)*scale;
}
235
+
236
// Dequantizes one byte of a q4_1 block.
//   vx  - array of block_q4_1
//   ib  - block index
//   iqs - byte index inside the block's qs array
//   v0  - receives the value from the low nibble:  nibble * d + m
//   v1  - receives the value from the high nibble: nibble * d + m
static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
    const block_q4_1 * blocks = (const block_q4_1 *) vx;

    const float scale  = blocks[ib].d;
    const float offset = blocks[ib].m;

    const uint8_t packed = blocks[ib].qs[iqs];

    const int8_t lo = packed & 0xF;
    const int8_t hi = packed >> 4;

    v0 = lo*scale + offset;
    v1 = hi*scale + offset;
}
250
+
251
// Dequantizes one byte of a q5_0 block.
//   vx  - array of block_q5_0
//   ib  - block index
//   iqs - byte index inside the block's qs array
//   v0  - value built from the low nibble and bit iqs of qh:         (q - 16) * d
//   v1  - value built from the high nibble and bit (iqs + 16) of qh: (q - 16) * d
static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
    const block_q5_0 * blocks = (const block_q5_0 *) vx;

    const float scale = blocks[ib].d;

    // the four qh bytes are copied into one 32-bit word
    uint32_t high_bits;
    memcpy(&high_bits, blocks[ib].qh, sizeof(high_bits));

    // move the wanted high bit into bit position 4 (0x10)
    const uint8_t hi_0 = ((high_bits >> (iqs +  0)) << 4) & 0x10;
    const uint8_t hi_1 = ((high_bits >> (iqs + 12))     ) & 0x10;

    const int32_t q0 = ((blocks[ib].qs[iqs] & 0xf) | hi_0) - 16;
    const int32_t q1 = ((blocks[ib].qs[iqs] >>  4) | hi_1) - 16;

    v0 = q0*scale;
    v1 = q1*scale;
}
268
+
269
// Dequantizes one byte of a q5_1 block.
//   vx  - array of block_q5_1
//   ib  - block index
//   iqs - byte index inside the block's qs array
//   v0  - value built from the low nibble and bit iqs of qh:         q * d + m
//   v1  - value built from the high nibble and bit (iqs + 16) of qh: q * d + m
static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
    const block_q5_1 * blocks = (const block_q5_1 *) vx;

    const float scale  = blocks[ib].d;
    const float offset = blocks[ib].m;

    // the four qh bytes are copied into one 32-bit word
    uint32_t high_bits;
    memcpy(&high_bits, blocks[ib].qh, sizeof(high_bits));

    // move the wanted high bit into bit position 4 (0x10)
    const uint8_t hi_0 = ((high_bits >> (iqs +  0)) << 4) & 0x10;
    const uint8_t hi_1 = ((high_bits >> (iqs + 12))     ) & 0x10;

    const int32_t q0 = ((blocks[ib].qs[iqs] & 0xf) | hi_0);
    const int32_t q1 = ((blocks[ib].qs[iqs] >>  4) | hi_1);

    v0 = q0*scale + offset;
    v1 = q1*scale + offset;
}
287
+
288
// Dequantizes two neighbouring quants of a q8_0 block.
//   vx  - array of block_q8_0
//   ib  - block index
//   iqs - index of the first quant inside the block
//   v0  - receives qs[iqs]     * d
//   v1  - receives qs[iqs + 1] * d
static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
    const block_q8_0 * blocks = (const block_q8_0 *) vx;

    const float scale = blocks[ib].d;

    const int8_t q0 = blocks[ib].qs[iqs + 0];
    const int8_t q1 = blocks[ib].qs[iqs + 1];

    v0 = q0*scale;
    v1 = q1*scale;
}
299
+
300
+ //================================== k-quants
301
+
302
// Dequantizes q2_K super-blocks to floats, one super-block per thread block.
// Written for 64 threads per block (see dequantize_row_q2_K_cuda): each thread
// reads one quant byte and writes the four values it encodes, 32 floats apart.
//   vx - array of block_q2_K
//   yy - output, QK_K floats per super-block
static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {

    const int i   = blockIdx.x;
    const int tid = threadIdx.x;
    const int n   = tid/32;         // which 128-value half of the super-block
    const int l   = tid - 32*n;     // byte within that half
    const int is  = 8*n + l/16;     // first scale index used by this thread

    const block_q2_K * x = (const block_q2_K *) vx;

    const uint8_t q = x[i].qs[32*n + l];
    float * y = yy + i*QK_K + 128*n;

    float dall = x[i].d;
    float dmin = x[i].dmin;

    // bit pair g of the byte belongs to output l + 32*g and uses scale entry is + 2*g
#pragma unroll
    for (int g = 0; g < 4; ++g) {
        y[l + 32*g] = dall * (x[i].scales[is + 2*g] & 0xF) * ((q >> (2*g)) & 3) - dmin * (x[i].scales[is + 2*g] >> 4);
    }

}
323
+
324
// Partial dot product of one q2_K super-block with a float vector.
//   vx     - array of block_q2_K
//   ib     - super-block index
//   iqs    - start position inside the super-block; the comments below assume
//            steps of 8 in [0, QK_K)
//   yy     - floats aligned with the start of the super-block
//   result - receives the sum over the 8 values handled by this call
static __device__ void vec_dot_q2_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {

    const block_q2_K * x = (const block_q2_K *) vx;

    // half 0 covers the lower 128 values, half 1 the upper 128; within a half the
    // call touches offsets off+0, +32, +64, +96 and off+16, +48, +80, +112
    int half_idx = iqs/128;            // 0 or 1
    int rem      = iqs - 128*half_idx; // 0...120 in steps of 8
    int off      = rem/8;              // 0...15 in steps of 1

    const float   * y  = yy + 128*half_idx + off;
    const uint8_t * q  = x[ib].qs + 32*half_idx + off;
    const uint8_t * sc = x[ib].scales + 8*half_idx;

    const float dall = x[ib].d;
    const float dmin = x[ib].dmin;

    // low nibble of a scale byte is the scale, high nibble the min
    float sum = y[  0] * (dall * ((sc[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (sc[0] >> 4))
              + y[ 32] * (dall * ((sc[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (sc[2] >> 4))
              + y[ 64] * (dall * ((sc[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (sc[4] >> 4))
              + y[ 96] * (dall * ((sc[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (sc[6] >> 4))
              + y[ 16] * (dall * ((sc[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (sc[1] >> 4))
              + y[ 48] * (dall * ((sc[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (sc[3] >> 4))
              + y[ 80] * (dall * ((sc[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (sc[5] >> 4))
              + y[112] * (dall * ((sc[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (sc[7] >> 4));

    result = sum;

}
354
+
355
// Dequantizes q3_K super-blocks to floats, one super-block per thread block.
// Written for 64 threads per block (see dequantize_row_q3_K_cuda); each thread
// produces four consecutive outputs.
//   vx - array of block_q3_K
//   yy - output, QK_K floats per super-block
static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {

    int quad = threadIdx.x/4;
    int i    = blockIdx.x;
    int grp  = quad/2;
    int is0  = quad%2;
    int l0   = 16*is0 + 4*(threadIdx.x%4); // first of the four values handled here
    int n    = grp / 4;                    // which 128-value half
    int j    = grp - 4*n;                  // which 32-value slice of that half

    const block_q3_K * x = (const block_q3_K *) vx;

    uint8_t hbit  = 1 << (4*n + j);        // hmask bit belonging to this slice
    int     is    = 8*n + 2*j + is0;       // scale index, 0...15
    int     shift = 2*j;

    // rebuild the 6-bit scale: 4 low bits from bytes 0..7, 2 high bits from bytes 8..11
    int8_t us;
    if (is < 4) {
        us = (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4);
    } else if (is < 8) {
        us = (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4);
    } else if (is < 12) {
        us = (x[i].scales[is-8] >>  4) | (((x[i].scales[is+0] >> 4) & 3) << 4);
    } else {
        us = (x[i].scales[is-8] >>  4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
    }
    float d_all = x[i].d;
    float dl    = d_all * (us - 32);

    float * y = yy + i*QK_K + 128*n + 32*j;
    const uint8_t * q  = x[i].qs + 32*n;
    const uint8_t * hm = x[i].hmask;

    // a clear hmask bit lowers the 2-bit quant by 4
    for (int l = l0; l < l0+4; ++l) {
        y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & hbit) ? 0 : 4));
    }

}
385
+
386
// Partial dot product of one q3_K super-block with a float vector.
//   vx     - array of block_q3_K
//   ib     - super-block index
//   iqs    - start position inside the super-block; the comments below assume
//            steps of 8 in [0, QK_K)
//   yy     - floats aligned with the start of the super-block
//   result - receives the sum over the 8 values handled by this call
static __device__ void vec_dot_q3_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {

    const block_q3_K * x = (const block_q3_K *) vx;

    const uint32_t kmask1 = 0x03030303;
    const uint32_t kmask2 = 0x0f0f0f0f;

    uint32_t packed[3];   // raw 12 scale bytes
    uint32_t unpacked[4]; // sixteen 6-bit scales, one per byte

    // half 0 covers the lower 128 values, half 1 the upper 128; within a half the
    // call touches offsets off+0, +32, +64, +96 and off+16, +48, +80, +112
    int half_idx = iqs/128;            // 0 or 1
    int rem      = iqs - 128*half_idx; // 0...120 in steps of 8
    int off      = rem/8;              // 0...15 in steps of 1

    const float   * y     = yy + 128*half_idx + off;
    const uint8_t * qs    = x[ib].qs + 32*half_idx + off;
    const uint8_t * hmask = x[ib].hmask + off;
    const int8_t  * sc    = (const int8_t *)unpacked + 8*half_idx;

    // expand the packed scales: 4 low bits from words 0/1, 2 high bits from word 2
    memcpy(packed, x[ib].scales, 12);
    unpacked[3] = ((packed[1] >> 4) & kmask2) | (((packed[2] >> 6) & kmask1) << 4);
    unpacked[2] = ((packed[0] >> 4) & kmask2) | (((packed[2] >> 4) & kmask1) << 4);
    unpacked[1] =  (packed[1]       & kmask2) | (((packed[2] >> 2) & kmask1) << 4);
    unpacked[0] =  (packed[0]       & kmask2) | (((packed[2] >> 0) & kmask1) << 4);

    const float dall = x[ib].d;

    const uint8_t hbit = 1 << (4*half_idx); // first hmask bit of this half

    // a clear hmask bit lowers the 2-bit quant by 4; scales are used as (scale - 32)
    float sum = y[  0] * (sc[0] - 32) * (((qs[ 0] >> 0) & 3) - (hmask[ 0] & (hbit << 0) ? 0 : 4))
              + y[ 32] * (sc[2] - 32) * (((qs[ 0] >> 2) & 3) - (hmask[ 0] & (hbit << 1) ? 0 : 4))
              + y[ 64] * (sc[4] - 32) * (((qs[ 0] >> 4) & 3) - (hmask[ 0] & (hbit << 2) ? 0 : 4))
              + y[ 96] * (sc[6] - 32) * (((qs[ 0] >> 6) & 3) - (hmask[ 0] & (hbit << 3) ? 0 : 4))
              + y[ 16] * (sc[1] - 32) * (((qs[16] >> 0) & 3) - (hmask[16] & (hbit << 0) ? 0 : 4))
              + y[ 48] * (sc[3] - 32) * (((qs[16] >> 2) & 3) - (hmask[16] & (hbit << 1) ? 0 : 4))
              + y[ 80] * (sc[5] - 32) * (((qs[16] >> 4) & 3) - (hmask[16] & (hbit << 2) ? 0 : 4))
              + y[112] * (sc[7] - 32) * (((qs[16] >> 6) & 3) - (hmask[16] & (hbit << 3) ? 0 : 4));

    result = sum * dall;

}
430
+
431
// Decodes the j-th 6-bit scale and 6-bit min from a packed scales array
// (the callers in this section pass the 12-byte scales of q4_K / q5_K blocks with j in 0...7).
//   j - entry index
//   q - packed scales
//   d - receives the scale
//   m - receives the min
static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
    if (j < 4) {
        // entries 0..3: the low 6 bits of bytes j and j+4
        d = q[j] & 63;
        m = q[j + 4] & 63;
        return;
    }
    // entries 4..7: 4 low bits from byte j+4, 2 high bits from the top of bytes j-4 / j
    d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
    m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
}
439
+
440
// Dequantizes q4_K super-blocks to floats, one super-block per thread block.
// Written for 32 threads per block (see dequantize_row_q4_K_cuda): each thread
// reads 4 quant bytes and writes 8 values (4 low nibbles, 4 high nibbles).
// The original author noted that a 64-thread layout (il = tid/16, ir = tid%16,
// n = 2) was very slightly faster.
//   vx - array of block_q4_K
//   yy - output, QK_K floats per super-block
static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
    const block_q4_K * x = (const block_q4_K *) vx;

    const int i = blockIdx.x;

    const int tid = threadIdx.x;
    const int il  = tid/8;  // 64-value chunk of the super-block
    const int ir  = tid%8;  // position inside the chunk, in units of n bytes
    const int is  = 2*il;   // index of the first of the chunk's two scale/min pairs
    const int n   = 4;      // quant bytes per thread

    float * y = yy + i*QK_K + 64*il + n*ir;

    const float dall = x[i].d;
    const float dmin = x[i].dmin;

    const uint8_t * q = x[i].qs + 32*il + n*ir;

    uint8_t sc, m;
    get_scale_min_k4(is + 0, x[i].scales, sc, m);
    const float d1 = dall * sc;
    const float m1 = dmin * m;
    get_scale_min_k4(is + 1, x[i].scales, sc, m);
    const float d2 = dall * sc;
    const float m2 = dmin * m;

    // low nibbles form the first 32 values of the chunk, high nibbles the next 32
    for (int l = 0; l < n; ++l) {
        y[l +  0] = d1 * (q[l] & 0xF) - m1;
        y[l + 32] = d2 * (q[l] >>  4) - m2;
    }
}
476
+
477
// Partial dot product of one q4_K super-block with a float vector.
//   vx     - array of block_q4_K
//   ib     - super-block index
//   iqs    - start position inside the super-block; the comments below assume
//            steps of 8 in [0, QK_K)
//   yy     - floats aligned with the start of the super-block
//   result - receives the sum over the 8 values handled by this call
static __device__ void vec_dot_q4_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {

    const block_q4_K * x = (const block_q4_K *) vx;

    // iqs is in 0...248 in steps of 8 =>
    const int chunk = iqs / 64;            // 0...3, the 64-value chunk
    const int ir    = (iqs - 64*chunk)/2;  // 0...28 in steps of 4
    const int is    = 2*chunk;             // 0...6 in steps of 2

    const float   * y = yy + 64*chunk + ir;
    const uint8_t * q = x[ib].qs + 32*chunk + ir;

    const float dall = x[ib].d;
    const float dmin = x[ib].dmin;

    uint8_t sc, m;
    get_scale_min_k4(is + 0, x[ib].scales, sc, m);
    const float d1 = dall * sc;
    const float m1 = dmin * m;
    get_scale_min_k4(is + 1, x[ib].scales, sc, m);
    const float d2 = dall * sc;
    const float m2 = dmin * m;

    // low nibbles pair with y[k], high nibbles with y[k + 32]
    float acc = 0;
    for (int k = 0; k < 4; ++k) {
        acc += y[k +  0] * (d1 * (q[k] & 0xF) - m1);
        acc += y[k + 32] * (d2 * (q[k] >>  4) - m2);
    }
    result = acc;

}
508
+
509
// Dequantizes q5_K super-blocks to floats, one super-block per thread block.
// Written for 64 threads per block (see dequantize_row_q5_K_cuda): each thread
// reads 2 quant bytes and writes 4 values.
//   vx - array of block_q5_K
//   yy - output, QK_K floats per super-block
static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
    const block_q5_K * x = (const block_q5_K *) vx;

    const int i = blockIdx.x;

    const int tid = threadIdx.x;
    const int il  = tid/16;   // il is in 0...3, the 64-value chunk
    const int ir  = tid%16;   // ir is in 0...15
    const int is  = 2*il;     // is is in 0...6

    float * y = yy + i*QK_K + 64*il + 2*ir;

    const float dall = x[i].d;
    const float dmin = x[i].dmin;

    const uint8_t * ql = x[i].qs + 32*il + 2*ir;
    const uint8_t * qh = x[i].qh + 2*ir;

    uint8_t sc, m;
    get_scale_min_k4(is + 0, x[i].scales, sc, m);
    const float d1 = dall * sc;
    const float m1 = dmin * m;
    get_scale_min_k4(is + 1, x[i].scales, sc, m);
    const float d2 = dall * sc;
    const float m2 = dmin * m;

    // a set qh bit adds 16: bit 2*il serves the low nibbles, bit 2*il + 1 the high nibbles
    const uint8_t hm_lo = 1 << (2*il);
    y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm_lo ? 16 : 0)) - m1;
    y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm_lo ? 16 : 0)) - m1;
    const uint8_t hm_hi = hm_lo << 1;
    y[32] = d2 * ((ql[ 0] >>  4) + (qh[ 0] & hm_hi ? 16 : 0)) - m2;
    y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm_hi ? 16 : 0)) - m2;
}
541
+
542
// Partial dot product of one q5_K super-block with a float vector.
//   vx     - array of block_q5_K
//   ib     - super-block index
//   iqs    - start position inside the super-block; the comments below assume
//            steps of 8 in [0, QK_K)
//   yy     - floats aligned with the start of the super-block
//   result - receives the sum over the 8 values handled by this call
static __device__ void vec_dot_q5_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {

    const block_q5_K * x = (const block_q5_K *) vx;

    // iqs is in 0...248 in steps of 8 =>
    const int chunk = iqs / 64;            // 0...3, the 64-value chunk
    const int ir    = (iqs - 64*chunk)/2;  // 0...28 in steps of 4
    const int is    = 2*chunk;             // 0...6 in steps of 2

    const float   * y  = yy + 64*chunk + ir;
    const uint8_t * ql = x[ib].qs + 32*chunk + ir;
    const uint8_t * qh = x[ib].qh + ir;

    const float dall = x[ib].d;
    const float dmin = x[ib].dmin;

    uint8_t sc, m;
    get_scale_min_k4(is + 0, x[ib].scales, sc, m);
    const float d1 = dall * sc;
    const float m1 = dmin * m;
    get_scale_min_k4(is + 1, x[ib].scales, sc, m);
    const float d2 = dall * sc;
    const float m2 = dmin * m;

    // a set qh bit adds 16: bit `is` serves the low nibbles, bit `is + 1` the high nibbles
    const uint8_t hm_lo = 1 << is;
    float acc = 0;
    for (int k = 0; k < 4; ++k) {
        acc += y[k +  0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm_lo ? 16 : 0)) - m1);
    }
    const uint8_t hm_hi = hm_lo << 1;
    for (int k = 0; k < 4; ++k) {
        acc += y[k + 32] * (d2 * ((ql[k] >>  4) + (qh[k] & hm_hi ? 16 : 0)) - m2);
    }
    result = acc;

}
578
+
579
// Dequantizes q6_K super-blocks to floats, one super-block per thread block.
// Written for 64 threads per block (see dequantize_row_q6_K_cuda): each thread
// reads two ql bytes and one qh byte and writes 4 values, 32 floats apart.
//   vx - array of block_q6_K
//   yy - output, QK_K floats per super-block
static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
    const block_q6_K * x = (const block_q6_K *) vx;

    const int i = blockIdx.x;

    const int tid = threadIdx.x;
    const int ip  = tid/32;          // ip is 0 or 1, the 128-value half
    const int il  = tid - 32*ip;     // 0...31
    const int is  = 8*ip + il/16;    // first scale index used by this thread

    float * y = yy + i*QK_K + 128*ip + il;

    const float d = x[i].d;

    const uint8_t * ql = x[i].ql + 64*ip + il;
    const uint8_t   qh = x[i].qh[32*ip + il];
    const int8_t  * sc = x[i].scales + is;

    // each quant = 4 low bits from ql | 2 high bits from qh, used as (quant - 32)
    y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
    y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
    y[64] = d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh >> 4) & 3) << 4)) - 32);
    y[96] = d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh >> 6) & 3) << 4)) - 32);
}
603
+
604
// Partial dot product of one q6_K super-block with a float vector.
//   vx     - array of block_q6_K
//   ib     - super-block index
//   iqs    - start position inside the super-block; presumably steps of 8 in
//            [0, QK_K), as in the other vec_dot_* helpers
//   yy     - floats aligned with the start of the super-block
//   result - receives the sum over the 8 values handled by this call
static __device__ void vec_dot_q6_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {

    const block_q6_K * x = (const block_q6_K *) vx;

    const int ip = iqs / 128;          // 0 or 1, the 128-value half
    const int il = (iqs - 128*ip)/8;   // 0...15
    const int is = 8*ip;               // first scale index of that half

    const float * y = yy + 128*ip + il;

    const float d = x[ib].d;

    const uint8_t * ql = x[ib].ql + 64*ip + il;
    const uint8_t * qh = x[ib].qh + 32*ip + il;
    const int8_t  * sc = x[ib].scales + is;

    // each quant = 4 low bits from ql | 2 high bits from qh, used as (quant - 32)
    result = y[  0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32)
           + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
           + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
           + y[ 96] * d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
           + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
           + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
           + y[ 80] * d * sc[5] * ((int8_t)((ql[16]  >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
           + y[112] * d * sc[7] * ((int8_t)((ql[48]  >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);

}
630
+
631
// "Dequantizer" for plain fp16 data, matching the dequantize_kernel_t signature:
// converts the two consecutive halves at index ib + iqs to float.
//   vx - array of half
//   v0 - receives element ib + iqs
//   v1 - receives element ib + iqs + 1
static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
    const half * src = (const half *) vx;
    const int first = ib + iqs;

    v0 = __half2float(src[first + 0]);
    v1 = __half2float(src[first + 1]);
}
637
+
638
// Generic dequantization kernel: every thread dequantizes one pair of values.
//   qk                - values per quantized block
//   qr                - dequantized values per stored quant (1 or 2 in this file)
//   dequantize_kernel - per-type device dequantizer
//   vx                - quantized input
//   y                 - float output, k elements
//   k                 - number of values to produce
//
// Each thread owns a unique even position i = 2 * (global thread index). The
// previous indexing (blockDim.x*blockIdx.x + 2*threadIdx.x) made consecutive
// thread blocks overlap, so every pair was dequantized and written twice. The
// index is formed in 64 bits so it cannot wrap before the bound check. A grid
// of ceil(k / blockDim.x) blocks, as used by the launchers in this file, still
// covers all k values; surplus threads return immediately.
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_block(const void * vx, float * y, const int k) {
    const int64_t pos = 2*((int64_t) blockDim.x*blockIdx.x + threadIdx.x);

    if (pos >= k) {
        return;
    }

    const int i = (int) pos; // safe: pos < k

    const int ib = i/qk; // block index
    const int iqs = (i%qk)/qr; // quant index
    const int iybs = i - i%qk; // y block start index
    // for qr == 2 the second value of a quant lives half a block further on
    const int y_offset = qr == 1 ? 1 : qk/2;

    // dequantize directly into the output
    float & v0 = y[iybs + iqs + 0];
    float & v1 = y[iybs + iqs + y_offset];
    dequantize_kernel(vx, ib, iqs, v0, v1);
}
656
+
657
// Fused dequantize + matrix-vector product: dst[row] = dot(dequant(vx[row, :]), y).
//   qk                - quantized weights per x block
//   qr                - number of quantized weights per data value in x block
//   dequantize_kernel - per-type device dequantizer
//   vx                - quantized matrix, ncols values per row
//   y                 - float vector with ncols entries
//   dst               - one float per row
//   ncols             - row length; must be a whole number of qk-sized blocks
//
// Launch layout (see the dequantize_mul_mat_vec_*_cuda launchers): blockDim.x is
// one warp of 32 lanes, blockDim.y rows per block, row = blockIdx.x*blockDim.y +
// threadIdx.y. There is no bound check on row.
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    const int iter_stride = 2*GGML_CUDA_DMMV_X;
    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
    const int y_offset = qr == 1 ? 1 : qk/2;

    float tmp = 0.0f; // partial sum for thread in warp

    for (int i = 0; i < ncols; i += iter_stride) {
        const int col = i + vals_per_iter*tid;
        if (col >= ncols) {
            // The warp consumes iter_stride = 2*GGML_CUDA_DMMV_X columns per pass,
            // while the launchers only guarantee ncols % GGML_CUDA_DMMV_X == 0.
            // In a last, partially filled pass the upper lanes would otherwise
            // read past the end of the row (and of y).
            break;
        }
        const int ib = (row*ncols + col)/qk; // x block index
        const int iqs = (col%qk)/qr; // x quant index
        const int iybs = col - col%qk; // y block start index

        // processing >2 values per i iter is faster for fast GPUs
#pragma unroll
        for (int j = 0; j < vals_per_iter; j += 2) {
            // process 2 vals per j iter

            // dequantize
            float v0, v1;
            dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val

            // matrix multiplication
            tmp += v0 * y[iybs + iqs + j/qr + 0];
            tmp += v1 * y[iybs + iqs + j/qr + y_offset];
            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
        }
    }

    // sum up partial sums and write back result
    // (every lane reaches this point: the guard above only leaves the loop)
    __syncthreads();
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }

    if (tid == 0) {
        dst[row] = tmp;
    }
}
704
+
705
// Fused dequantize + matrix-vector product for k-quant types:
// dst[row] = dot(dequant(vx[row, :]), y).
//   n_thread   - threads cooperating on one row; the final reduction is a
//                32-lane shuffle
//   dot_kernel - per-type device dot product over QK_K / n_thread values
//   vx         - quantized matrix, ncols / QK_K super-blocks per row
//   y          - float vector with ncols entries
//   dst        - one float per row
//   ncols      - row length, expected to be a multiple of QK_K
//
// row = blockIdx.x*blockDim.y + threadIdx.y.
// NOTE(review): row is not bounds-checked and the kernel receives no row count.
// The q2_K launcher rounds the grid up for blockDim.y = 2, so an odd number of
// rows would make the last block access one row past the end - confirm that
// callers never pass an odd row count.
template <int n_thread, dot_kernel_k_t dot_kernel>
static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols) {
    const int row  = blockIdx.x*blockDim.y + threadIdx.y;
    const int lane = threadIdx.x;

    const int iter_stride        = QK_K;                   // one super-block per pass
    const int vals_per_iter      = iter_stride / n_thread; // values per thread and pass
    const int num_blocks_per_row = ncols / QK_K;
    const int first_block        = row*num_blocks_per_row; // first super-block of this row

    float partial = 0; // partial sum for thread in warp

    for (int base = 0; base < ncols; base += iter_stride) {
        const int col  = base + vals_per_iter*lane;
        const int ib   = first_block + col/QK_K; // x block index
        const int iqs  = col%QK_K;               // x quant index
        const int iybs = col - col%QK_K;         // y block start index

        float v;
        dot_kernel(vx, ib, iqs, y + iybs, v);
        partial += v;
    }

    // sum up partial sums and write back result
    __syncthreads();
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        partial += __shfl_xor_sync(0xffffffff, partial, mask, 32);
    }

    if (lane == 0) {
        dst[row] = partial;
    }
}
739
+
740
// Rotates consecutive pairs (x[i], x[i+1]) of every row by an angle that
// depends on the pair index: theta = p * theta_scale^(col/2).
//   x, dst      - input / output with ncols floats per row
//   ncols       - row length
//   p           - base angle multiplier
//   theta_scale - per-pair geometric factor
// Each thread handles one pair: the column comes from the x dimension of the
// grid, the row from the y dimension. Only the column is bounds-checked.
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) {
    const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);

    if (col < ncols) {
        const int row = blockDim.y*blockIdx.y + threadIdx.y;
        const int idx = row*ncols + col;

        const float theta = p*powf(theta_scale, col/2);
        const float sin_theta = sinf(theta);
        const float cos_theta = cosf(theta);

        const float a = x[idx + 0];
        const float b = x[idx + 1];

        // 2D rotation of (a, b) by theta
        dst[idx + 0] = a*cos_theta - b*sin_theta;
        dst[idx + 1] = a*sin_theta + b*cos_theta;
    }
}
760
+
761
// Launches add_f32 for k elements on `stream`, using CUDA_ADD_BLOCK_SIZE threads
// per block and enough blocks to cover k (rounded up).
static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
    const int grid_size = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
    add_f32<<<grid_size, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
}
765
+
766
// Launches mul_f32 for kx elements on `stream` (y is indexed modulo ky), using
// CUDA_MUL_BLOCK_SIZE threads per block and enough blocks to cover kx (rounded up).
static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
    const int grid_size = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
    mul_f32<<<grid_size, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
}
770
+
771
// Launches silu_f32 for k elements on `stream`, using CUDA_SILU_BLOCK_SIZE threads
// per block and enough blocks to cover k (rounded up).
static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int grid_size = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
    silu_f32<<<grid_size, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
775
+
776
// Launches rms_norm_f32 on `stream` with one block of a single warp per row.
// ncols must be a multiple of WARP_SIZE because the kernel does not
// bounds-check columns.
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    const dim3 warp_per_row(WARP_SIZE, 1, 1);
    rms_norm_f32<<<nrows, warp_per_row, 0, stream>>>(x, dst, ncols);
}
781
+
782
// Dequantizes k q4_0 values from vx into floats at y on `stream`.
// Grid: ceil(k / CUDA_DEQUANTIZE_BLOCK_SIZE) blocks of CUDA_DEQUANTIZE_BLOCK_SIZE threads.
static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    const int grid_size = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
    dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<grid_size, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
786
+
787
// Dequantizes k q4_1 values from vx into floats at y on `stream`.
// Grid: ceil(k / CUDA_DEQUANTIZE_BLOCK_SIZE) blocks of CUDA_DEQUANTIZE_BLOCK_SIZE threads.
static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    const int grid_size = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
    dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<grid_size, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
791
+
792
// Dequantizes k q5_0 values from vx into floats at y on `stream`.
// Grid: ceil(k / CUDA_DEQUANTIZE_BLOCK_SIZE) blocks of CUDA_DEQUANTIZE_BLOCK_SIZE threads.
static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    const int grid_size = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
    dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<grid_size, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
796
+
797
// Dequantizes k q5_1 values from vx into floats at y on `stream`.
// Grid: ceil(k / CUDA_DEQUANTIZE_BLOCK_SIZE) blocks of CUDA_DEQUANTIZE_BLOCK_SIZE threads.
static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    const int grid_size = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
    dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<grid_size, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
801
+
802
// Dequantizes k q8_0 values from vx into floats at y on `stream`.
// Grid: ceil(k / CUDA_DEQUANTIZE_BLOCK_SIZE) blocks of CUDA_DEQUANTIZE_BLOCK_SIZE threads.
static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    const int grid_size = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
    dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<grid_size, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
806
+
807
// Dequantizes k q2_K values from vx into floats at y on `stream`:
// one 64-thread block per super-block of QK_K values (k / QK_K blocks, rounded down).
static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;
    dequantize_block_q2_K<<<num_super_blocks, 64, 0, stream>>>(vx, y);
}
811
+
812
// Dequantizes k q3_K values from vx into floats at y on `stream`:
// one 64-thread block per super-block of QK_K values (k / QK_K blocks, rounded down).
static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;
    dequantize_block_q3_K<<<num_super_blocks, 64, 0, stream>>>(vx, y);
}
816
+
817
// Dequantizes k q4_K values from vx into floats at y on `stream`:
// one 32-thread block per super-block of QK_K values (k / QK_K blocks, rounded down).
static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;
    dequantize_block_q4_K<<<num_super_blocks, 32, 0, stream>>>(vx, y);
}
821
+
822
// Dequantizes k q5_K values from vx into floats at y on `stream`:
// one 64-thread block per super-block of QK_K values (k / QK_K blocks, rounded down).
static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;
    dequantize_block_q5_K<<<num_super_blocks, 64, 0, stream>>>(vx, y);
}
826
+
827
// Dequantizes k q6_K values from vx into floats at y on `stream`:
// one 64-thread block per super-block of QK_K values (k / QK_K blocks, rounded down).
static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    const int num_super_blocks = k / QK_K;
    dequantize_block_q6_K<<<num_super_blocks, 64, 0, stream>>>(vx, y);
}
831
+
832
// Launches the fused q4_0 dequantize + matrix-vector kernel on `stream`.
// Each thread block is one warp wide and covers GGML_CUDA_DMMV_Y rows, so
// nrows / GGML_CUDA_DMMV_Y blocks are launched; both divisibility
// requirements are asserted.
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
    const dim3 threads(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
    const int  grid_size = nrows/GGML_CUDA_DMMV_Y;
    dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
        <<<grid_size, threads, 0, stream>>>(vx, y, dst, ncols);
}
839
+
840
// Launches the fused q4_1 dequantize + matrix-vector kernel on `stream`.
// Each thread block is one warp wide and covers GGML_CUDA_DMMV_Y rows, so
// nrows / GGML_CUDA_DMMV_Y blocks are launched; both divisibility
// requirements are asserted.
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
    const dim3 threads(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
    const int  grid_size = nrows/GGML_CUDA_DMMV_Y;
    dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
        <<<grid_size, threads, 0, stream>>>(vx, y, dst, ncols);
}
847
+
848
// Host-side launcher for dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>.
// Asserts that ncols is a multiple of GGML_CUDA_DMMV_X and nrows a multiple of GGML_CUDA_DMMV_Y.
// Launch shape: nrows/GGML_CUDA_DMMV_Y blocks of WARP_SIZE x GGML_CUDA_DMMV_Y threads on `stream`.
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);

    const int  grid_size = nrows/GGML_CUDA_DMMV_Y;
    const dim3 threads(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);

    dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0><<<grid_size, threads, 0, stream>>>(vx, y, dst, ncols);
}
855
+
856
// Host-side launcher for dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>.
// Asserts that ncols is a multiple of GGML_CUDA_DMMV_X and nrows a multiple of GGML_CUDA_DMMV_Y.
// Launch shape: nrows/GGML_CUDA_DMMV_Y blocks of WARP_SIZE x GGML_CUDA_DMMV_Y threads on `stream`.
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);

    const int  grid_size = nrows/GGML_CUDA_DMMV_Y;
    const dim3 threads(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);

    dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1><<<grid_size, threads, 0, stream>>>(vx, y, dst, ncols);
}
863
+
864
// Host-side launcher for dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>.
// Asserts that ncols is a multiple of GGML_CUDA_DMMV_X and nrows a multiple of GGML_CUDA_DMMV_Y.
// Launch shape: nrows/GGML_CUDA_DMMV_Y blocks of WARP_SIZE x GGML_CUDA_DMMV_Y threads on `stream`.
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);

    const int  grid_size = nrows/GGML_CUDA_DMMV_Y;
    const dim3 threads(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);

    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0><<<grid_size, threads, 0, stream>>>(vx, y, dst, ncols);
}
871
+
872
// Host-side launcher for dequantize_mul_mat_vec_k<32, vec_dot_q2_K>.
// Preconditions (asserted): ncols is a multiple of QK_K and nrows is a multiple of ny.
// Launch shape: nrows/ny blocks of 32 x ny threads on `stream`.
static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int ny = 2; // y-dimension of the thread block
    // The launch below does not pass nrows to the kernel, so the kernel has no way to
    // bounds-check a partially filled last block; require the rows to fill the grid exactly.
    GGML_ASSERT(nrows % ny == 0);
    const dim3 block_dims(32, ny, 1);
    dequantize_mul_mat_vec_k<32, vec_dot_q2_K><<<nrows/ny, block_dims, 0, stream>>>(vx, y, dst, ncols);
}
878
+
879
// Host-side launcher for dequantize_mul_mat_vec_k<32, vec_dot_q3_K>.
// Preconditions (asserted): ncols is a multiple of QK_K and nrows is a multiple of ny.
// Launch shape: nrows/ny blocks of 32 x ny threads on `stream`.
static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int ny = 2; // y-dimension of the thread block
    // The grid is nrows/ny, so an odd nrows would silently leave the last row without a thread row.
    GGML_ASSERT(nrows % ny == 0);
    const dim3 block_dims(32, ny, 1);
    dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<<nrows/ny, block_dims, 0, stream>>>(vx, y, dst, ncols);
}
884
+
885
// Host-side launcher for dequantize_mul_mat_vec_k<32, vec_dot_q4_K>.
// Preconditions (asserted): ncols is a multiple of QK_K and nrows is a multiple of ny.
// Launch shape: nrows/ny blocks of 32 x ny threads on `stream`.
static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int ny = 2; // y-dimension of the thread block
    // The grid is nrows/ny, so an odd nrows would silently leave the last row without a thread row.
    GGML_ASSERT(nrows % ny == 0);
    const dim3 block_dims(32, ny, 1);
    dequantize_mul_mat_vec_k<32, vec_dot_q4_K><<<nrows/ny, block_dims, 0, stream>>>(vx, y, dst, ncols);
}
890
+
891
// Host-side launcher for dequantize_mul_mat_vec_k<32, vec_dot_q5_K>.
// Preconditions (asserted): ncols is a multiple of QK_K and nrows is a multiple of ny.
// Launch shape: nrows/ny blocks of 32 x ny threads on `stream`.
static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int ny = 2; // y-dimension of the thread block
    // The grid is nrows/ny, so an odd nrows would silently leave the last row without a thread row.
    GGML_ASSERT(nrows % ny == 0);
    const dim3 block_dims(32, ny, 1);
    dequantize_mul_mat_vec_k<32, vec_dot_q5_K><<<nrows/ny, block_dims, 0, stream>>>(vx, y, dst, ncols);
}
896
+
897
// Host-side launcher for dequantize_mul_mat_vec_k<32, vec_dot_q6_K>.
// Preconditions (asserted): ncols is a multiple of QK_K and nrows is a multiple of ny.
// Launch shape: nrows/ny blocks of 32 x ny threads on `stream`.
static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int ny = 2; // y-dimension of the thread block
    // The grid is nrows/ny, so an odd nrows would silently leave the last row without a thread row.
    GGML_ASSERT(nrows % ny == 0);
    const dim3 block_dims(32, ny, 1);
    dequantize_mul_mat_vec_k<32, vec_dot_q6_K><<<nrows/ny, block_dims, 0, stream>>>(vx, y, dst, ncols);
}
902
+
903
// Launches dequantize_block<1, 1, convert_f16> on `stream`.
// The grid is k ceil-divided by CUDA_DEQUANTIZE_BLOCK_SIZE; k is forwarded to the kernel.
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    const int threads_per_block = CUDA_DEQUANTIZE_BLOCK_SIZE;
    const int grid_size = (k + threads_per_block - 1) / threads_per_block;
    dequantize_block<1, 1, convert_f16><<<grid_size, threads_per_block, 0, stream>>>(vx, y, k);
}
907
+
908
// Host-side launcher for dequantize_mul_mat_vec<1, 1, convert_f16>.
// Asserts that ncols is a multiple of GGML_CUDA_DMMV_X and nrows a multiple of GGML_CUDA_DMMV_Y.
// Launch shape: nrows/GGML_CUDA_DMMV_Y blocks of WARP_SIZE x GGML_CUDA_DMMV_Y threads on `stream`.
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);

    const int  grid_size = nrows/GGML_CUDA_DMMV_Y;
    const dim3 threads(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);

    dequantize_mul_mat_vec<1, 1, convert_f16><<<grid_size, threads, 0, stream>>>(vx, y, dst, ncols);
}
915
+
916
// Maps a ggml tensor type to the host launcher that converts it to fp32 on the device.
// Returns nullptr for any type that has no entry below.
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
    to_fp32_cuda_t converter = nullptr;
    switch (type) {
        case GGML_TYPE_Q4_0: converter = dequantize_row_q4_0_cuda;  break;
        case GGML_TYPE_Q4_1: converter = dequantize_row_q4_1_cuda;  break;
        case GGML_TYPE_Q5_0: converter = dequantize_row_q5_0_cuda;  break;
        case GGML_TYPE_Q5_1: converter = dequantize_row_q5_1_cuda;  break;
        case GGML_TYPE_Q8_0: converter = dequantize_row_q8_0_cuda;  break;
        case GGML_TYPE_Q2_K: converter = dequantize_row_q2_K_cuda;  break;
        case GGML_TYPE_Q3_K: converter = dequantize_row_q3_K_cuda;  break;
        case GGML_TYPE_Q4_K: converter = dequantize_row_q4_K_cuda;  break;
        case GGML_TYPE_Q5_K: converter = dequantize_row_q5_K_cuda;  break;
        case GGML_TYPE_Q6_K: converter = dequantize_row_q6_K_cuda;  break;
        case GGML_TYPE_F16:  converter = convert_fp16_to_fp32_cuda; break;
        default:             break; // unsupported type: leave converter as nullptr
    }
    return converter;
}
944
+
945
// Host-side launcher for the rope_f32 kernel on `stream`.
// Asserts that nrows is even. The grid is 2D: x covers ncols in chunks of
// 2*CUDA_ROPE_BLOCK_SIZE (ceil-divided), y has one entry per row.
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float theta_scale, cudaStream_t stream) {
    GGML_ASSERT(nrows % 2 == 0);

    const int threads_x = 2*CUDA_ROPE_BLOCK_SIZE;
    const int grid_x    = (ncols + threads_x - 1) / threads_x;

    const dim3 threads(threads_x, 1, 1);
    const dim3 grid(grid_x, nrows, 1);

    rope_f32<<<grid, threads, 0, stream>>>(x, dst, ncols, p, theta_scale);
}
952
+
953
+ // buffer pool for CUDA device allocations
954
+ #define MAX_CUDA_BUFFERS 256 // number of slots in each device's row of g_cuda_buffer_pool
955
+
956
// Scope guard around a std::atomic_flag used as a spin lock:
// the constructor busy-waits until it manages to set the flag, the destructor clears it.
// Copying is disabled so that a flag is released exactly once per guard.
struct scoped_spin_lock {
    std::atomic_flag& lock;

    scoped_spin_lock(std::atomic_flag& flag) : lock(flag) {
        // test_and_set returns the previous value: true means the flag is still held elsewhere
        for (;;) {
            if (!lock.test_and_set(std::memory_order_acquire)) {
                break;
            }
        }
    }

    ~scoped_spin_lock() {
        lock.clear(std::memory_order_release);
    }

    scoped_spin_lock(const scoped_spin_lock&) = delete;
    scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
};
969
+
970
// One slot of the buffer pool: a pointer plus its size.
// A slot with ptr == nullptr is treated as empty by the pool functions.
struct cuda_buffer {
    void * ptr{nullptr};
    size_t size{0};
};
974
+
975
+ static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS]; // pooled buffers, indexed [device id][slot]
976
+ static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT; // spin-lock flag held by ggml_cuda_pool_malloc / ggml_cuda_pool_free
977
+
978
// Returns a device buffer of at least `size` bytes for the current CUDA device.
// The first pooled buffer that is large enough is reused (and removed from the pool);
// if none fits, a new buffer is allocated with cudaMalloc.
// *actual_size receives the real size of the returned buffer, which may exceed `size`.
// The whole call runs under g_cuda_pool_lock.
static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
    scoped_spin_lock lock(g_cuda_pool_lock);
    int id;
    CUDA_CHECK(cudaGetDevice(&id));

    cuda_buffer * const pool = g_cuda_buffer_pool[id];
    for (int slot = 0; slot < MAX_CUDA_BUFFERS; ++slot) {
        cuda_buffer & entry = pool[slot];
        if (entry.ptr == nullptr || entry.size < size) {
            continue; // empty slot or too small
        }
        void * reused = entry.ptr;
        *actual_size  = entry.size;
        entry.ptr  = nullptr;
        entry.size = 0;
        return reused;
    }

    // nothing suitable in the pool
    void * fresh;
    CUDA_CHECK(cudaMalloc((void **) &fresh, size));
    *actual_size = size;
    return fresh;
}
998
+
999
// Hands a buffer back to the pool of the current CUDA device by storing it in the first empty slot.
// If every slot is occupied, a warning is printed and the buffer is released with cudaFree.
// The whole call runs under g_cuda_pool_lock.
static void ggml_cuda_pool_free(void * ptr, size_t size) {
    scoped_spin_lock lock(g_cuda_pool_lock);
    int id;
    CUDA_CHECK(cudaGetDevice(&id));

    cuda_buffer * const pool = g_cuda_buffer_pool[id];
    for (int slot = 0; slot < MAX_CUDA_BUFFERS; ++slot) {
        cuda_buffer & entry = pool[slot];
        if (entry.ptr != nullptr) {
            continue; // slot already in use
        }
        entry.ptr  = ptr;
        entry.size = size;
        return;
    }

    fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
    CUDA_CHECK(cudaFree(ptr));
}
1015
+
1016
+
1017
+ static void * g_scratch_buffer = nullptr; // NOTE(review): not used in this part of the file; presumably a device scratch allocation - confirm
1018
+ static size_t g_scratch_size = 1024*1024*1024; // 1 GB (1024^3 bytes) by default
1019
+ static size_t g_scratch_offset = 0; // NOTE(review): presumably the current offset into g_scratch_buffer - confirm
1020
+
1021
+ #define GGML_CUDA_MAX_STREAMS 8 // Streams per device; set this to 1 for reproducible matrix multiplication.
1022
+ #define GGML_CUDA_MAX_EVENTS 64 // events per device in g_cudaEvents_memcpy_src1
1023
+
1024
+ static int g_device_count = -1; // set from cudaGetDeviceCount in ggml_init_cublas
1025
+ static int g_main_device = 0; // device used for non-split tensors
1026
+ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0}; // per device: cumulative fraction at which its share starts
1027
+
1028
+ static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; // one cuBLAS handle per device, created in ggml_init_cublas
1029
+
1030
+ static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr }; // compute streams, indexed [device][stream]
1031
+
1032
+ static cudaStream_t g_cudaStreams_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr }; // streams used for src1 copies
1033
+ static cudaEvent_t g_cudaEvents_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_EVENTS] = { nullptr }; // events recorded on the src1 copy streams
1034
+
1035
// One-time initialization of the CUDA backend (guarded by a function-local static flag):
//  - queries the device count and prints each device name,
//  - fills g_tensor_split with the cumulative VRAM fraction at which each device starts,
//  - creates the per-device streams, events and cuBLAS handle.
void ggml_init_cublas() {
    static bool initialized = false;
    if (initialized) {
        return;
    }

    CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
    GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);

    // default split: proportional to the total global memory of each device
    int64_t total_vram = 0;
    fprintf(stderr, "%s: found %d CUDA devices:\n", __func__, g_device_count);
    for (int dev = 0; dev < g_device_count; ++dev) {
        cudaDeviceProp prop;
        CUDA_CHECK(cudaGetDeviceProperties(&prop, dev));
        fprintf(stderr, " Device %d: %s\n", dev, prop.name);
        g_tensor_split[dev] = total_vram; // VRAM of all preceding devices
        total_vram += prop.totalGlobalMem;
    }
    for (int dev = 0; dev < g_device_count; ++dev) {
        g_tensor_split[dev] /= total_vram;
    }

    for (int dev = 0; dev < g_device_count; ++dev) {
        CUDA_CHECK(cudaSetDevice(dev));

        // create streams
        for (int s = 0; s < GGML_CUDA_MAX_STREAMS; ++s) {
            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[dev][s], cudaStreamNonBlocking));
            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_memcpy_src1[dev][s], cudaStreamNonBlocking));
        }

        // create events
        for (int e = 0; e < GGML_CUDA_MAX_EVENTS; ++e) {
            CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents_memcpy_src1[dev][e], cudaEventDisableTiming));
        }

        // create cublas handle
        CUBLAS_CHECK(cublasCreate(&g_cublas_handles[dev]));
        CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[dev], CUBLAS_TF32_TENSOR_OP_MATH));
    }

    // configure logging to stdout
    // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));

    initialized = true;
}
1078
+
1079
// Overrides g_tensor_split from user-provided per-device proportions.
// Reads g_device_count entries of tensor_split. If all of them are zero the current split is kept.
// Otherwise each device gets the normalized cumulative sum of the proportions preceding it.
void ggml_cuda_set_tensor_split(const float * tensor_split) {
    bool any_nonzero = false;
    for (int dev = 0; dev < g_device_count && !any_nonzero; ++dev) {
        any_nonzero = tensor_split[dev] != 0.0f;
    }
    if (!any_nonzero) {
        return;
    }

    // running prefix sum: entry dev holds the sum of all proportions before dev
    float running_total = 0.0f;
    for (int dev = 0; dev < g_device_count; ++dev) {
        g_tensor_split[dev] = running_total;
        running_total += tensor_split[dev];
    }

    // normalize by the overall sum
    for (int dev = 0; dev < g_device_count; ++dev) {
        g_tensor_split[dev] /= running_total;
    }
}
1099
+
1100
// Allocates `size` bytes of pinned host memory with cudaMallocHost.
// Returns nullptr (instead of aborting) when
//  - the environment variable GGML_CUDA_NO_PINNED is set, or
//  - the allocation fails; a warning is printed in that case.
void * ggml_cuda_host_malloc(size_t size) {
    if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
        return nullptr;
    }

    void * ptr = nullptr;
    cudaError_t err = cudaMallocHost((void **) &ptr, size);
    if (err != cudaSuccess) {
        // The failure is handled here by returning nullptr, so reset the error recorded by the runtime.
        // Otherwise a later CUDA_CHECK(cudaGetLastError()) would report this stale error and exit.
        (void) cudaGetLastError();
        fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
            size/1024.0/1024.0, cudaGetErrorString(err));
        return nullptr;
    }

    return ptr;
}
1115
+
1116
// Releases host memory with cudaFreeHost; any CUDA error is fatal (CUDA_CHECK).
void ggml_cuda_host_free(void * ptr) {
    const cudaError_t status = cudaFreeHost(ptr);
    CUDA_CHECK(status);
}
1119
+
1120
// Asynchronously copies rows [i1_low, i1_high) of the (i3, i2) slice of host tensor `src`
// to the device buffer `dst` on `stream`. The destination rows are written back to back.
// Three layouts are distinguished by the source strides:
//  - elements and rows contiguous: one cudaMemcpyAsync,
//  - elements contiguous, rows strided: one cudaMemcpy2DAsync,
//  - elements strided: one cudaMemcpy2DAsync per row.
// Returns the first CUDA error encountered, or the status of the (last) copy call.
static cudaError_t ggml_cuda_h2d_tensor_2d(
    void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {

    const enum ggml_type type = src->type;
    const int64_t ts = ggml_type_size(type);
    const int64_t bs = ggml_blck_size(type);

    const int64_t ne0 = src->ne[0];
    const int64_t nb0 = src->nb[0];
    const int64_t nb1 = src->nb[1];
    const int64_t nb2 = src->nb[2];
    const int64_t nb3 = src->nb[3];

    const int64_t n_rows = i1_high - i1_low;

    char * dst_bytes = (char *) dst;
    const char * src_bytes = (const char *) src->data + i1_low*nb1 + i2*nb2 + i3*nb3;

    if (nb0 != ts) {
        // elements are not contiguous within a row: copy row by row
        for (int64_t row = 0; row < n_rows; row++) {
            const void * row_src = (const void *) (src_bytes + row*nb1);
            void * row_dst = (void *) (dst_bytes + row*ts*ne0/bs);
            // pretend the row is a matrix with cols=1
            const cudaError_t status = cudaMemcpy2DAsync(row_dst, ts/bs, row_src, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
            if (status != cudaSuccess) {
                return status;
            }
        }
        return cudaSuccess;
    }

    if (nb1 == ts*ne0/bs) {
        // fully contiguous
        return cudaMemcpyAsync(dst_bytes, src_bytes, n_rows*nb1, cudaMemcpyHostToDevice, stream);
    }

    // contiguous rows separated by a stride of nb1 bytes
    return cudaMemcpy2DAsync(dst_bytes, ts*ne0/bs, src_bytes, nb1, ts*ne0/bs, n_rows, cudaMemcpyHostToDevice, stream);
}
1150
+
1151
// ggml_cuda_op_t callback: launches add_f32_cuda on src0_ddf_i and src1_ddf_i, writing dst_ddf_i,
// for ne0 * (i01_high - i01_low) elements on cudaStream_main.
inline void ggml_cuda_op_add(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
    cudaStream_t & cudaStream_main){

    // part of the common callback signature, not needed here
    (void) src1;
    (void) dst;
    (void) src0_ddq_i;
    (void) i02;
    (void) i1;

    GGML_ASSERT(src0_ddf_i != nullptr);
    GGML_ASSERT(src1_ddf_i != nullptr);
    GGML_ASSERT(dst_ddf_i != nullptr);

    const int64_t row_len    = src0->ne[0];
    const int64_t n_rows     = i01_high - i01_low;
    const int64_t n_elements = row_len*n_rows;

    // compute
    add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, n_elements, cudaStream_main);
    CUDA_CHECK(cudaGetLastError());
}
1173
+
1174
// ggml_cuda_op_t callback: launches mul_f32_cuda once per src0 row in [i01_low, i01_high).
// The src1 row used for a given src0 row is i1*ne11 + row % ne11, i.e. src1 rows repeat across src0.
inline void ggml_cuda_op_mul(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
    cudaStream_t & cudaStream_main){

    // part of the common callback signature, not needed here
    (void) dst;
    (void) src0_ddq_i;
    (void) i02;

    GGML_ASSERT(src0_ddf_i != nullptr);
    GGML_ASSERT(src1_ddf_i != nullptr);
    GGML_ASSERT(dst_ddf_i != nullptr);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne10 = src1->ne[0];
    const int64_t ne11 = src1->ne[1];

    for (int64_t row = i01_low; row < i01_high; row++) {
        const int64_t src1_row = i1*ne11 + row%ne11; // broadcast src1 across src0

        float * src0_row_ptr = src0_ddf_i + row*ne00;
        float * src1_row_ptr = src1_ddf_i + src1_row*ne10;
        float * dst_row_ptr  = dst_ddf_i + row*ne00;

        // compute
        mul_f32_cuda(src0_row_ptr, src1_row_ptr, dst_row_ptr, ne00, ne10, cudaStream_main);
        CUDA_CHECK(cudaGetLastError());
    }
}
1204
+
1205
// ggml_cuda_op_t callback: launches silu_f32_cuda from src0_ddf_i into dst_ddf_i
// for ne00 * (i01_high - i01_low) elements on cudaStream_main. src1 is not used.
inline void ggml_cuda_op_silu(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
    cudaStream_t & cudaStream_main){

    // part of the common callback signature, not needed here
    (void) src1;
    (void) dst;
    (void) src0_ddq_i;
    (void) src1_ddf_i;
    (void) i02;
    (void) i1;

    GGML_ASSERT(src0_ddf_i != nullptr);
    GGML_ASSERT(dst_ddf_i != nullptr);

    const int64_t row_len    = src0->ne[0];
    const int64_t n_rows     = i01_high - i01_low;
    const int64_t n_elements = row_len*n_rows;

    // compute
    silu_f32_cuda(src0_ddf_i, dst_ddf_i, n_elements, cudaStream_main);
    CUDA_CHECK(cudaGetLastError());
}
1227
+
1228
// ggml_cuda_op_t callback: launches rms_norm_f32_cuda from src0_ddf_i into dst_ddf_i
// with ne00 columns and (i01_high - i01_low) rows on cudaStream_main. src1 is not used.
inline void ggml_cuda_op_rms_norm(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
    cudaStream_t & cudaStream_main){

    // part of the common callback signature, not needed here
    (void) src1;
    (void) dst;
    (void) src0_ddq_i;
    (void) src1_ddf_i;
    (void) i02;
    (void) i1;

    GGML_ASSERT(src0_ddf_i != nullptr);
    GGML_ASSERT(dst_ddf_i != nullptr);

    const int64_t row_len = src0->ne[0];
    const int64_t n_rows  = i01_high - i01_low;

    // compute
    rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, row_len, n_rows, cudaStream_main);
    CUDA_CHECK(cudaGetLastError());
}
1250
+
1251
// ggml_cuda_op_t callback: picks the matrix-vector launcher matching src0->type and runs it
// on src0_ddq_i (matrix data in its stored type) and src1_ddf_i, writing dst_ddf_i.
// ncols is src0->ne[0], nrows is i01_high - i01_low. Unsupported types hit GGML_ASSERT(false).
inline void ggml_cuda_op_dequantize_mul_mat_vec(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
    cudaStream_t & cudaStream_main){

    // part of the common callback signature, not needed here
    (void) src1;
    (void) dst;
    (void) src0_ddf_i;
    (void) i02;
    (void) i1;

    GGML_ASSERT(src0_ddq_i != nullptr);
    GGML_ASSERT(src1_ddf_i != nullptr);
    GGML_ASSERT(dst_ddf_i != nullptr);

    const int64_t ne00 = src0->ne[0];
    const int64_t nrows = i01_high - i01_low;

    // all launchers share this signature, so select one and call it once
    typedef void (*mul_mat_vec_launcher_t)(const void *, const float *, float *, const int, const int, cudaStream_t);
    mul_mat_vec_launcher_t launcher = nullptr;

    switch (src0->type) {
        case GGML_TYPE_Q4_0: launcher = dequantize_mul_mat_vec_q4_0_cuda; break;
        case GGML_TYPE_Q4_1: launcher = dequantize_mul_mat_vec_q4_1_cuda; break;
        case GGML_TYPE_Q5_0: launcher = dequantize_mul_mat_vec_q5_0_cuda; break;
        case GGML_TYPE_Q5_1: launcher = dequantize_mul_mat_vec_q5_1_cuda; break;
        case GGML_TYPE_Q8_0: launcher = dequantize_mul_mat_vec_q8_0_cuda; break;
        case GGML_TYPE_Q2_K: launcher = dequantize_mul_mat_vec_q2_K_cuda; break;
        case GGML_TYPE_Q3_K: launcher = dequantize_mul_mat_vec_q3_K_cuda; break;
        case GGML_TYPE_Q4_K: launcher = dequantize_mul_mat_vec_q4_K_cuda; break;
        case GGML_TYPE_Q5_K: launcher = dequantize_mul_mat_vec_q5_K_cuda; break;
        case GGML_TYPE_Q6_K: launcher = dequantize_mul_mat_vec_q6_K_cuda; break;
        case GGML_TYPE_F16:  launcher = convert_mul_mat_vec_f16_cuda;     break;
        default:
            GGML_ASSERT(false);
            break;
    }

    if (launcher != nullptr) {
        launcher(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
    }
    CUDA_CHECK(cudaGetLastError());
}
1309
+
1310
// ggml_cuda_op_t callback: single-precision GEMM through cuBLAS on the current device.
// Calls cublasSgemm with op(A) = transposed src0_ddf_i (lda = ne00), B = src1_ddf_i (ldb = ne10),
// m = i01_high - i01_low, n = ne11, k = ne10, alpha = 1, beta = 0, on cudaStream_main.
inline void ggml_cuda_op_mul_mat_cublas(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
    cudaStream_t & cudaStream_main){

    // part of the common callback signature, not needed here
    (void) dst;
    (void) src0_ddq_i;
    (void) i02;
    (void) i1;

    GGML_ASSERT(src0_ddf_i != nullptr);
    GGML_ASSERT(src1_ddf_i != nullptr);
    GGML_ASSERT(dst_ddf_i != nullptr);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne10 = src1->ne[0];
    const int64_t ne11 = src1->ne[1];
    const int64_t ne0  = dst->ne[0];

    const int64_t i01_diff = i01_high - i01_low;

    int id;
    CUDA_CHECK(cudaGetDevice(&id));

    // the main device has a larger memory buffer to hold the results from all GPUs
    // ldc == nrows of the matrix that cuBLAS writes into
    const bool dst_is_on_main_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
    int ldc = dst_is_on_main_device ? ne0 : i01_diff;

    const float alpha = 1.0f;
    const float beta  = 0.0f;

    cublasHandle_t handle = g_cublas_handles[id];
    CUBLAS_CHECK(cublasSetStream(handle, cudaStream_main));
    CUBLAS_CHECK(
        cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                i01_diff, ne11, ne10,
                &alpha, src0_ddf_i, ne00,
                        src1_ddf_i, ne10,
                &beta,  dst_ddf_i,  ldc));
}
1350
+
1351
// ggml_cuda_op_t callback: launches rope_f32_cuda from src0_ddf_i into dst_ddf_i.
// src1->data is read on the host as three int32 values: n_past, n_dims, mode. Only mode == 0 is accepted.
// theta_scale is 10000^(-2/n_dims); p is n_past + i02 when bit 0 of mode is clear, else i02.
inline void ggml_cuda_op_rope(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
    cudaStream_t & cudaStream_main){

    // part of the common callback signature, not needed here
    (void) dst;
    (void) src0_ddq_i;
    (void) src1_ddf_i;
    (void) i1;

    GGML_ASSERT(src0_ddf_i != nullptr);
    GGML_ASSERT(dst_ddf_i != nullptr);

    const int64_t row_len = src0->ne[0];
    const int64_t n_rows  = i01_high - i01_low;

    const int32_t * params = (const int32_t *) src1->data;
    const int n_past = params[0];
    const int n_dims = params[1];
    const int mode   = params[2];
    GGML_ASSERT(mode == 0);

    const float theta_scale = powf(10000.0, -2.0f/n_dims);

    int64_t position = i02;
    if ((mode & 1) == 0) {
        position = n_past + i02;
    }
    const float p = position;

    // compute
    rope_f32_cuda(src0_ddf_i, dst_ddf_i, row_len, n_rows, p, theta_scale, cudaStream_main);
    CUDA_CHECK(cudaGetLastError());
}
1379
+
1380
// Generic driver that runs the callback `op` once per (i03, i02) 2D slice of src0.
//
//  - src0 with backend GGML_BACKEND_GPU_SPLIT: every device processes the row range
//    [row_low, row_high) derived from g_tensor_split; otherwise only g_main_device does any work.
//  - Operands that are not already resident on the device get temporary buffers from the buffer pool
//    and are copied over asynchronously. src1 copies use a separate stream and the main stream
//    waits on an event before `op` is invoked.
//  - src0_needs_f32: if true and src0 is not F32, the slice is converted with the function returned
//    by ggml_get_to_fp32_cuda before `op` runs.
//  - If dst is not resident on the computing device the result is copied to dst->data (CPU backend)
//    or to the main device's buffer (GPU backend).
//  - Before returning, every device is synchronized and the pool buffers are handed back.
//
// src1 may be nullptr. dst and src1 must not use GGML_BACKEND_GPU_SPLIT (asserted).
static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                         ggml_cuda_op_t op, bool src0_needs_f32) {
    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];
    const int64_t nrows0 = ggml_nrows(src0);

    const bool use_src1 = src1 != nullptr;
    const int64_t ne10 = use_src1 ? src1->ne[0] : 1;
    const int64_t ne11 = use_src1 ? src1->ne[1] : 1;
    const int64_t ne12 = use_src1 ? src1->ne[2] : 1;
    const int64_t ne13 = use_src1 ? src1->ne[3] : 1;

    const int64_t ne0 = dst->ne[0];
    const int64_t ne1 = dst->ne[1];

    const int nb2  = dst->nb[2];
    const int nb3  = dst->nb[3];

    GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT);
    GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);

    // strides for iteration over dims 3 and 2
    const int64_t src0_stride = ne00 * ne01;
    const int64_t src1_stride = ne10 * ne11;
    const int64_t dst_stride = ne0 * ne1;
    const int64_t num_iters = ne02 * ne03;

    const size_t src0_ts = ggml_type_size(src0->type);
    const size_t src0_bs = ggml_blck_size(src0->type);

    struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
    struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
    struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;

    const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
    const bool src0_is_f32 = src0->type == GGML_TYPE_F32;

    const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;

    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);

    // dd = data device
    char  * src0_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // quantized
    float * src0_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float
    float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
    float *  dst_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};

    // asq = actual size quantized, asf = actual size float
    // a non-zero entry marks a buffer that came from the pool and has to be returned at the end
    size_t src0_asq[GGML_CUDA_MAX_DEVICES] = {0};
    size_t src0_asf[GGML_CUDA_MAX_DEVICES] = {0};
    size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
    size_t  dst_asf[GGML_CUDA_MAX_DEVICES] = {0};

    for (int id = 0; id < g_device_count; ++id) {
        if (!split && id != g_main_device) {
            continue;
        }

        const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU && id == g_main_device;
        const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;

        // range of src0 rows handled by this device
        int64_t row_low, row_high;
        if (split) {
            row_low = id == 0 ? 0 : nrows0*g_tensor_split[id];
            row_low -= row_low % GGML_CUDA_DMMV_Y;
            row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1];
            row_high -= row_high % GGML_CUDA_DMMV_Y;
        } else {
            row_low = 0;
            row_high = nrows0;
        }
        if (row_low == row_high) {
            continue;
        }

        int64_t row_diff = row_high - row_low;

        // a failed device switch must not go unnoticed: everything below would otherwise
        // allocate and launch on whichever device happened to be current
        CUDA_CHECK(cudaSetDevice(id));

        if (src0_on_device) {
            if (src0_is_f32) {
                src0_ddf[id] = (float *) src0_extra->data_device[id];
            } else {
                src0_ddq[id] = (char *) src0_extra->data_device[id];
            }
        } else {
            if (src0_is_f32) {
                src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]);
            } else {
                src0_ddq[id] = (char *) ggml_cuda_pool_malloc(row_diff*ne00 * src0_ts/src0_bs, &src0_asq[id]);
            }
        }

        if (src0_needs_f32 && !src0_is_f32) {
            src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]);
        }

        if (use_src1) {
            if (src1_on_device) {
                src1_ddf[id] = (float *) src1_extra->data_device[id];
            } else {
                src1_ddf[id] = (float *) ggml_cuda_pool_malloc(num_iters*src1_stride * sizeof(float), &src1_asf[id]);
            }
        }
        if (dst_on_device) {
            dst_ddf[id] = (float *) dst_extra->data_device[id];
        } else {
            size_t size_dst_ddf = split ? row_diff*ne1 * sizeof(float) : num_iters*dst_stride * sizeof(float);
            dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]);
        }

        for (int64_t i03 = 0; i03 < ne03; i03++) {
            const int64_t i13 = i03 % ne13;
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                const int64_t i12 = i02 % ne12;

                const int64_t i0 = i03*ne02 + i02;
                const int64_t i0_offset_low = row_low/ne01;
                const int64_t i0_offset_high = row_high/ne01;

                int64_t i01_low = 0;
                int64_t i01_high = ne01;
                if (split) {
                    if (i0 < i0_offset_low || i0 > i0_offset_high) {
                        continue;
                    }
                    if (i0 == i0_offset_low) {
                        i01_low = row_low % ne01;
                    }
                    if (i0 == i0_offset_high) {
                        i01_high = row_high % ne01;
                    }
                }

                // There is possibly a bug in the Windows nvcc compiler regarding instruction reordering or optimizing out local variables.
                // Removing the first assert or changing the order of the arguments causes the second assert to fail.
                // Removing both asserts results in i01_high becoming 0 which in turn results in garbage output.
                // The root cause seems to be a problem with i0_offset_high becoming 0 when it should always be >0 (for single GPU).
                GGML_ASSERT(i01_low == 0 || g_device_count > 1);
                GGML_ASSERT(i01_high == ne01 || g_device_count > 1);

                const int64_t i01_diff = i01_high - i01_low;
                if (i01_diff == 0) {
                    continue;
                }
                const int64_t i11 = i13*ne12 + i12;

                cudaStream_t cudaStream_main = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS];
                cudaStream_t cudaStream_memcpy_src1 = g_cudaStreams_memcpy_src1[id][i0 % GGML_CUDA_MAX_STREAMS];
                cudaEvent_t  cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS];

                // for split tensors the data begins at i0 == i0_offset_low
                char  * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
                float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
                float * src1_ddf_i = src1_ddf[id] + i11*src1_stride;
                float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;

                // for split tensors the data pointer needs to be rounded down
                // to the bin edge for i03, i02 bins beyond the first
                if (i0 - i0_offset_low > 0) {
                    src0_ddq_i -= (row_low % ne01)*ne00 * src0_ts/src0_bs;
                    src0_ddf_i -= (row_low % ne01)*ne00;
                }
                if (i0 - i0_offset_low > 0) {
                    dst_ddf_i -= (row_low % ne0)*ne1;
                }

                // the main device memory buffer can be on VRAM scratch, with space for all partial results
                // in that case an offset on dst_ddf_i is needed
                if (dst->backend == GGML_BACKEND_GPU && id == g_main_device) {
                    dst_ddf_i += i01_low; // offset is 0 if no tensor split
                }

                // copy src0, src1 to device if necessary
                if (use_src1) {
                    if (src1->backend == GGML_BACKEND_CPU) {
                        CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_memcpy_src1));
                    } else if (src1->backend == GGML_BACKEND_GPU) {
                        if (id != g_main_device) {
                            float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
                            src1_ddf_i_source += i11*src1_stride;
                            CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
                                                       cudaMemcpyDeviceToDevice, cudaStream_memcpy_src1));
                        }
                    } else {
                        GGML_ASSERT(false);
                    }
                }
                CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1));
                if (!src0_on_device) {
                    if (src0_is_f32) {
                        CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
                    } else {
                        CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
                    }
                }

                // convert src0 to f32 if it's necessary for the ggml_cuda_op
                if (src0_needs_f32 && !src0_is_f32) {
                    to_fp32_cuda(src0_ddq_i, src0_ddf_i, i01_diff*ne00, cudaStream_main);
                    CUDA_CHECK(cudaGetLastError());
                }

                // wait with main stream until src1 memcpy is done
                CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, cudaEvent_memcpy_src1, 0));

                // do the computation
                op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);

                // copy dst to host or other device if necessary
                if (!dst_on_device) {
                    void * dst_off_device;
                    cudaMemcpyKind kind;
                    if (dst->backend == GGML_BACKEND_CPU) {
                        dst_off_device = dst->data;
                        kind = cudaMemcpyDeviceToHost;
                    } else if (dst->backend == GGML_BACKEND_GPU) {
                        dst_off_device = dst_extra->data_device[g_main_device];
                        kind = cudaMemcpyDeviceToDevice;
                    } else {
                        GGML_ASSERT(false);
                    }
                    if (split) {
                        // src0 = weight matrix is saved as a transposed matrix for better memory layout.
                        // dst is NOT transposed.
                        // The outputs of cuBLAS matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
                        // Instead they need to be copied to the correct slice in ne0 = dst row index.
                        // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
                        for (int64_t j = 0; j < ne1; ++j) {
                            float * dhf_dst_i = (float *) ((char *) dst_off_device + (j*ne0 + i01_low)*sizeof(float) + i02*nb2 + i03*nb3);
                            CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i + j*i01_diff, i01_diff*sizeof(float), kind, cudaStream_main));
                        }
                    } else {
                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
                        CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
                    }
                }
            }
        }
    }

    // wait until each device is finished, then free their buffers
    for (int id = 0; id < g_device_count; ++id) {
        CUDA_CHECK(cudaSetDevice(id));
        CUDA_CHECK(cudaDeviceSynchronize());
        if (src0_asq[id] > 0) {
            ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
        }
        if (src0_asf[id] > 0) {
            ggml_cuda_pool_free(src0_ddf[id], src0_asf[id]);
        }
        if (src1_asf[id] > 0) {
            ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
        }
        if (dst_asf[id] > 0) {
            ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
        }
    }
}
1641
+
1642
+ void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1643
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
1644
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true);
1645
+ }
1646
+
1647
+ void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1648
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
1649
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true);
1650
+ }
1651
+
1652
+ void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1653
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
1654
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true);
1655
+ }
1656
+
1657
+ void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1658
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
1659
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true);
1660
+ }
1661
+
1662
+ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
1663
+ GGML_ASSERT(src0->backend != GGML_BACKEND_GPU);
1664
+ const int64_t ne10 = src1->ne[0];
1665
+
1666
+ const int64_t ne0 = dst->ne[0];
1667
+ const int64_t ne1 = dst->ne[1];
1668
+
1669
+ // if (strcmp(dst->name, "KQ") == 0 || strcmp(dst->name, "KQV") == 0) {
1670
+ // fprintf(stderr, "(%ld, %ld, %ld, %ld) + (%ld, %ld, %ld, %ld) -> (%ld, %ld, %ld, %ld)\n",
1671
+ // src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
1672
+ // src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
1673
+ // dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
1674
+ // return false;
1675
+ // }
1676
+
1677
+ // TODO: find the optimal values for these
1678
+ if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
1679
+ src1->type == GGML_TYPE_F32 &&
1680
+ dst->type == GGML_TYPE_F32 &&
1681
+ (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
1682
+ return true;
1683
+ }
1684
+
1685
+ return false;
1686
+ }
1687
+
1688
+ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1689
+ if (src0->type == GGML_TYPE_F32) {
1690
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
1691
+ } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
1692
+ if (src1->ne[1] == 1) {
1693
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
1694
+ } else {
1695
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
1696
+ }
1697
+ } else {
1698
+ GGML_ASSERT(false);
1699
+ }
1700
+ }
1701
+
1702
+ void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1703
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
1704
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true);
1705
+ }
1706
+
1707
+ void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1708
+ (void) src0;
1709
+ (void) src1;
1710
+ (void) dst;
1711
+ }
1712
+
1713
+ void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
1714
+ FILE * fp = fopen(fname, "rb");
1715
+ int nrows = ggml_nrows(tensor);
1716
+ const size_t nb1 = tensor->nb[1];
1717
+ ggml_backend backend = tensor->backend;
1718
+ struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
1719
+
1720
+ for (int id = 0; id < g_device_count; ++id) {
1721
+ extra->data_device[id] = nullptr;
1722
+
1723
+ if (backend == GGML_BACKEND_GPU && id != g_main_device) {
1724
+ continue;
1725
+ }
1726
+
1727
+ cudaSetDevice(id);
1728
+
1729
+ int row_low, row_high;
1730
+ if (backend == GGML_BACKEND_GPU) {
1731
+ row_low = 0;
1732
+ row_high = nrows;
1733
+ } else if (backend == GGML_BACKEND_GPU_SPLIT) {
1734
+ row_low = id == 0 ? 0 : nrows*g_tensor_split[id];
1735
+ row_low -= row_low % GGML_CUDA_DMMV_Y;
1736
+ row_high = id == g_device_count - 1 ? nrows : nrows*g_tensor_split[id + 1];
1737
+ row_high -= row_high % GGML_CUDA_DMMV_Y;
1738
+ GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
1739
+ } else {
1740
+ GGML_ASSERT(false);
1741
+ }
1742
+ if (row_low == row_high) {
1743
+ continue;
1744
+ }
1745
+
1746
+ int64_t nrows_split = row_high - row_low;
1747
+
1748
+ const size_t offset_split = offset + row_low*nb1;
1749
+ const size_t size = ggml_nbytes_split(tensor, nrows_split);
1750
+
1751
+ void * buf;
1752
+ CUDA_CHECK(cudaMalloc(&buf, size));
1753
+ void * buf_host = malloc(size);
1754
+
1755
+ #ifdef _WIN32
1756
+ int ret = _fseeki64(fp, (__int64) offset_split, SEEK_SET);
1757
+ #else
1758
+ int ret = fseek(fp, (long) offset_split, SEEK_SET);
1759
+ #endif
1760
+ GGML_ASSERT(ret == 0); // same
1761
+
1762
+ size_t ret2 = fread(buf_host, size, 1, fp);
1763
+ if (ret2 != 1) {
1764
+ fprintf(stderr, "unexpectedly reached end of file");
1765
+ exit(1);
1766
+ }
1767
+
1768
+ cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
1769
+ cudaDeviceSynchronize();
1770
+
1771
+ free(buf_host);
1772
+ extra->data_device[id] = buf;
1773
+ }
1774
+
1775
+ tensor->extra = extra;
1776
+ fclose(fp);
1777
+ }
1778
+
1779
+ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
1780
+ if (tensor->backend != GGML_BACKEND_GPU && tensor->backend != GGML_BACKEND_GPU_SPLIT) {
1781
+ return;
1782
+ }
1783
+
1784
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
1785
+
1786
+ for (int id = 0; id < g_device_count; ++id) {
1787
+ if (extra->data_device[id] == nullptr) {
1788
+ continue;
1789
+ }
1790
+
1791
+ CUDA_CHECK(cudaSetDevice(id));
1792
+ CUDA_CHECK(cudaFree(extra->data_device[id]));
1793
+ }
1794
+
1795
+ delete extra;
1796
+ }
1797
+
1798
+ void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
1799
+ if (tensor->src0 != nullptr && tensor->src0->op == GGML_OP_RESHAPE) {
1800
+ ggml_cuda_assign_buffers(tensor);
1801
+ }
1802
+
1803
+ const size_t size = ggml_nbytes(tensor);
1804
+ GGML_ASSERT(size <= g_scratch_size);
1805
+ if (g_scratch_offset + size > g_scratch_size) {
1806
+ g_scratch_offset = 0;
1807
+ }
1808
+
1809
+ tensor->backend = GGML_BACKEND_GPU;
1810
+ struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
1811
+
1812
+ bool inplace = tensor->src0 != nullptr && tensor->src0->data == tensor->data;
1813
+
1814
+ CUDA_CHECK(cudaSetDevice(g_main_device));
1815
+ if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) {
1816
+ struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
1817
+ extra->data_device[g_main_device] = src0_extra->data_device;
1818
+ GGML_ASSERT(false);
1819
+ } else {
1820
+ char * data = (char *) g_scratch_buffer;
1821
+ if (data == nullptr) {
1822
+ CUDA_CHECK(cudaMalloc(&data, g_scratch_size));
1823
+ g_scratch_buffer = data;
1824
+ }
1825
+ extra->data_device[g_main_device] = data + g_scratch_offset;
1826
+ }
1827
+
1828
+ // fprintf(stderr, "data=%p offset=%ld data_device=%p\n", data, g_scratch_offset, extra->data_device[0]);
1829
+ g_scratch_offset += size;
1830
+ // fprintf(stderr, "%s: scratch %d, %p - %p\n",
1831
+ // tensor->name, g_scratch_index, data + g_scratch_offset, data + g_scratch_offset + size);
1832
+
1833
+ GGML_ASSERT(g_scratch_offset <= g_scratch_size);
1834
+ tensor->extra = extra;
1835
+ }
1836
+
1837
+ void ggml_cuda_set_main_device(int main_device) {
1838
+ if (main_device > g_device_count) {
1839
+ fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n",
1840
+ main_device, g_device_count, g_main_device);
1841
+ return;
1842
+ }
1843
+ g_main_device = main_device;
1844
+ if (g_device_count > 1) {
1845
+ cudaDeviceProp prop;
1846
+ CUDA_CHECK(cudaGetDeviceProperties(&prop, g_main_device));
1847
+ fprintf(stderr, "%s: using device %d (%s) as main device\n", __func__, g_main_device, prop.name);
1848
+ }
1849
+ }
1850
+
1851
+ void ggml_cuda_set_scratch_size(size_t scratch_size) {
1852
+ g_scratch_size = scratch_size;
1853
+ }
1854
+
1855
+ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
1856
+ ggml_cuda_func_t func;
1857
+ const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
1858
+ || tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT
1859
+ || (tensor->src1 != nullptr && tensor->src1->backend == GGML_BACKEND_GPU);
1860
+
1861
+ switch (tensor->op) {
1862
+ case GGML_OP_ADD:
1863
+ if (!any_on_device) {
1864
+ return false;
1865
+ }
1866
+ func = ggml_cuda_add;
1867
+ break;
1868
+ case GGML_OP_MUL:
1869
+ if (!any_on_device) {
1870
+ return false;
1871
+ }
1872
+ func = ggml_cuda_mul;
1873
+ break;
1874
+ case GGML_OP_SILU:
1875
+ if (!any_on_device) {
1876
+ return false;
1877
+ }
1878
+ func = ggml_cuda_silu;
1879
+ break;
1880
+ case GGML_OP_RMS_NORM:
1881
+ if (!any_on_device) {
1882
+ return false;
1883
+ }
1884
+ func = ggml_cuda_rms_norm;
1885
+ break;
1886
+ case GGML_OP_MUL_MAT:
1887
+ if (!any_on_device && !ggml_cuda_can_mul_mat(tensor->src0, tensor->src1, tensor)) {
1888
+ return false;
1889
+ }
1890
+ func = ggml_cuda_mul_mat;
1891
+ break;
1892
+ case GGML_OP_RESHAPE:
1893
+ if (!any_on_device) {
1894
+ return false;
1895
+ }
1896
+ func = ggml_cuda_nop;
1897
+ break;
1898
+ case GGML_OP_ROPE:
1899
+ if (!any_on_device) {
1900
+ return false;
1901
+ }
1902
+ func = ggml_cuda_rope;
1903
+ break;
1904
+ default:
1905
+ return false;
1906
+ }
1907
+
1908
+ if (params->ith != 0) {
1909
+ return true;
1910
+ }
1911
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
1912
+ return true;
1913
+ }
1914
+ func(tensor->src0, tensor->src1, tensor);
1915
+ return true;
1916
+ }