llama_cpp 0.14.5 → 0.14.7
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/ext/llama_cpp/llama_cpp.cpp +37 -2
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -0
- data/vendor/tmp/llama.cpp/Makefile +24 -7
- data/vendor/tmp/llama.cpp/ggml-alloc.c +8 -8
- data/vendor/tmp/llama.cpp/ggml-backend.c +14 -10
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +135 -46
- data/vendor/tmp/llama.cpp/ggml-impl.h +263 -5
- data/vendor/tmp/llama.cpp/ggml-metal.m +130 -83
- data/vendor/tmp/llama.cpp/ggml-metal.metal +505 -1467
- data/vendor/tmp/llama.cpp/ggml-quants.c +1 -294
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +65 -52
- data/vendor/tmp/llama.cpp/ggml.c +151 -99
- data/vendor/tmp/llama.cpp/ggml.h +5 -4
- data/vendor/tmp/llama.cpp/llama.cpp +1308 -254
- data/vendor/tmp/llama.cpp/llama.h +19 -6
- data/vendor/tmp/llama.cpp/sgemm.cpp +999 -0
- data/vendor/tmp/llama.cpp/sgemm.h +12 -0
- metadata +4 -2
@@ -14,47 +14,6 @@
|
|
14
14
|
#include <stdlib.h> // for qsort
|
15
15
|
#include <stdio.h> // for GGML_ASSERT
|
16
16
|
|
17
|
-
#ifdef __ARM_NEON
|
18
|
-
|
19
|
-
// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
|
20
|
-
//
|
21
|
-
// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
|
22
|
-
//
|
23
|
-
#include <arm_neon.h>
|
24
|
-
|
25
|
-
#else
|
26
|
-
|
27
|
-
#ifdef __wasm_simd128__
|
28
|
-
#include <wasm_simd128.h>
|
29
|
-
#else
|
30
|
-
#if defined(__POWER9_VECTOR__) || defined(__powerpc64__)
|
31
|
-
#include <altivec.h>
|
32
|
-
#undef bool
|
33
|
-
#define bool _Bool
|
34
|
-
#else
|
35
|
-
#if defined(_MSC_VER) || defined(__MINGW32__)
|
36
|
-
#include <intrin.h>
|
37
|
-
#else
|
38
|
-
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
|
39
|
-
#if !defined(__riscv)
|
40
|
-
#include <immintrin.h>
|
41
|
-
#endif
|
42
|
-
#endif
|
43
|
-
#endif
|
44
|
-
#endif
|
45
|
-
#endif
|
46
|
-
#endif
|
47
|
-
|
48
|
-
#ifdef __riscv_v_intrinsic
|
49
|
-
#include <riscv_vector.h>
|
50
|
-
#endif
|
51
|
-
|
52
|
-
#undef MIN
|
53
|
-
#undef MAX
|
54
|
-
|
55
|
-
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
56
|
-
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
57
|
-
|
58
17
|
#define UNUSED GGML_UNUSED
|
59
18
|
|
60
19
|
// some compilers don't provide _mm256_set_m128i, e.g. gcc 7
|
@@ -132,7 +91,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
|
|
132
91
|
}
|
133
92
|
|
134
93
|
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
|
135
|
-
#if defined(__AVXVNNI__) || defined(__AVX512VNNI__)
|
94
|
+
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
|
136
95
|
const __m256i zero = _mm256_setzero_si256();
|
137
96
|
const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
|
138
97
|
return _mm256_cvtepi32_ps(summed_pairs);
|
@@ -276,258 +235,6 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
|
|
276
235
|
#endif // __AVX__ || __AVX2__ || __AVX512F__
|
277
236
|
#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
|
278
237
|
|
279
|
-
#if defined(__ARM_NEON)
|
280
|
-
|
281
|
-
#ifdef _MSC_VER
|
282
|
-
|
283
|
-
#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
|
284
|
-
|
285
|
-
#else
|
286
|
-
|
287
|
-
#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
|
288
|
-
|
289
|
-
#endif
|
290
|
-
|
291
|
-
#if !defined(__aarch64__)
|
292
|
-
|
293
|
-
// 64-bit compatibility
|
294
|
-
|
295
|
-
// vaddvq_s16
|
296
|
-
// vpaddq_s16
|
297
|
-
// vpaddq_s32
|
298
|
-
// vaddvq_s32
|
299
|
-
// vaddvq_f32
|
300
|
-
// vmaxvq_f32
|
301
|
-
// vcvtnq_s32_f32
|
302
|
-
// vzip1_u8
|
303
|
-
// vzip2_u8
|
304
|
-
|
305
|
-
inline static int32_t vaddvq_s16(int16x8_t v) {
|
306
|
-
return
|
307
|
-
(int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
|
308
|
-
(int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
|
309
|
-
(int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
|
310
|
-
(int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
|
311
|
-
}
|
312
|
-
|
313
|
-
inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
|
314
|
-
int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
|
315
|
-
int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
|
316
|
-
return vcombine_s16(a0, b0);
|
317
|
-
}
|
318
|
-
|
319
|
-
inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
|
320
|
-
int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
|
321
|
-
int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
|
322
|
-
return vcombine_s32(a0, b0);
|
323
|
-
}
|
324
|
-
|
325
|
-
inline static int32_t vaddvq_s32(int32x4_t v) {
|
326
|
-
return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
|
327
|
-
}
|
328
|
-
|
329
|
-
inline static float vaddvq_f32(float32x4_t v) {
|
330
|
-
return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
|
331
|
-
}
|
332
|
-
|
333
|
-
inline static float vmaxvq_f32(float32x4_t v) {
|
334
|
-
return
|
335
|
-
MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
|
336
|
-
MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
|
337
|
-
}
|
338
|
-
|
339
|
-
inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
|
340
|
-
int32x4_t res;
|
341
|
-
|
342
|
-
res[0] = roundf(vgetq_lane_f32(v, 0));
|
343
|
-
res[1] = roundf(vgetq_lane_f32(v, 1));
|
344
|
-
res[2] = roundf(vgetq_lane_f32(v, 2));
|
345
|
-
res[3] = roundf(vgetq_lane_f32(v, 3));
|
346
|
-
|
347
|
-
return res;
|
348
|
-
}
|
349
|
-
|
350
|
-
inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
|
351
|
-
uint8x8_t res;
|
352
|
-
|
353
|
-
res[0] = a[0]; res[1] = b[0];
|
354
|
-
res[2] = a[1]; res[3] = b[1];
|
355
|
-
res[4] = a[2]; res[5] = b[2];
|
356
|
-
res[6] = a[3]; res[7] = b[3];
|
357
|
-
|
358
|
-
return res;
|
359
|
-
}
|
360
|
-
|
361
|
-
inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
|
362
|
-
uint8x8_t res;
|
363
|
-
|
364
|
-
res[0] = a[4]; res[1] = b[4];
|
365
|
-
res[2] = a[5]; res[3] = b[5];
|
366
|
-
res[4] = a[6]; res[5] = b[6];
|
367
|
-
res[6] = a[7]; res[7] = b[7];
|
368
|
-
|
369
|
-
return res;
|
370
|
-
}
|
371
|
-
|
372
|
-
// vld1q_s16_x2
|
373
|
-
// vld1q_u8_x2
|
374
|
-
// vld1q_u8_x4
|
375
|
-
// vld1q_s8_x2
|
376
|
-
// vld1q_s8_x4
|
377
|
-
// TODO: double-check these work correctly
|
378
|
-
|
379
|
-
typedef struct ggml_int16x8x2_t {
|
380
|
-
int16x8_t val[2];
|
381
|
-
} ggml_int16x8x2_t;
|
382
|
-
|
383
|
-
inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
|
384
|
-
ggml_int16x8x2_t res;
|
385
|
-
|
386
|
-
res.val[0] = vld1q_s16(ptr + 0);
|
387
|
-
res.val[1] = vld1q_s16(ptr + 8);
|
388
|
-
|
389
|
-
return res;
|
390
|
-
}
|
391
|
-
|
392
|
-
typedef struct ggml_uint8x16x2_t {
|
393
|
-
uint8x16_t val[2];
|
394
|
-
} ggml_uint8x16x2_t;
|
395
|
-
|
396
|
-
inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
|
397
|
-
ggml_uint8x16x2_t res;
|
398
|
-
|
399
|
-
res.val[0] = vld1q_u8(ptr + 0);
|
400
|
-
res.val[1] = vld1q_u8(ptr + 16);
|
401
|
-
|
402
|
-
return res;
|
403
|
-
}
|
404
|
-
|
405
|
-
typedef struct ggml_uint8x16x4_t {
|
406
|
-
uint8x16_t val[4];
|
407
|
-
} ggml_uint8x16x4_t;
|
408
|
-
|
409
|
-
inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
|
410
|
-
ggml_uint8x16x4_t res;
|
411
|
-
|
412
|
-
res.val[0] = vld1q_u8(ptr + 0);
|
413
|
-
res.val[1] = vld1q_u8(ptr + 16);
|
414
|
-
res.val[2] = vld1q_u8(ptr + 32);
|
415
|
-
res.val[3] = vld1q_u8(ptr + 48);
|
416
|
-
|
417
|
-
return res;
|
418
|
-
}
|
419
|
-
|
420
|
-
typedef struct ggml_int8x16x2_t {
|
421
|
-
int8x16_t val[2];
|
422
|
-
} ggml_int8x16x2_t;
|
423
|
-
|
424
|
-
inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
|
425
|
-
ggml_int8x16x2_t res;
|
426
|
-
|
427
|
-
res.val[0] = vld1q_s8(ptr + 0);
|
428
|
-
res.val[1] = vld1q_s8(ptr + 16);
|
429
|
-
|
430
|
-
return res;
|
431
|
-
}
|
432
|
-
|
433
|
-
typedef struct ggml_int8x16x4_t {
|
434
|
-
int8x16_t val[4];
|
435
|
-
} ggml_int8x16x4_t;
|
436
|
-
|
437
|
-
inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
|
438
|
-
ggml_int8x16x4_t res;
|
439
|
-
|
440
|
-
res.val[0] = vld1q_s8(ptr + 0);
|
441
|
-
res.val[1] = vld1q_s8(ptr + 16);
|
442
|
-
res.val[2] = vld1q_s8(ptr + 32);
|
443
|
-
res.val[3] = vld1q_s8(ptr + 48);
|
444
|
-
|
445
|
-
return res;
|
446
|
-
}
|
447
|
-
|
448
|
-
// NOTE: not tested
|
449
|
-
inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
|
450
|
-
int8x16_t res;
|
451
|
-
|
452
|
-
res[ 0] = a[b[ 0]];
|
453
|
-
res[ 1] = a[b[ 1]];
|
454
|
-
res[ 2] = a[b[ 2]];
|
455
|
-
res[ 3] = a[b[ 3]];
|
456
|
-
res[ 4] = a[b[ 4]];
|
457
|
-
res[ 5] = a[b[ 5]];
|
458
|
-
res[ 6] = a[b[ 6]];
|
459
|
-
res[ 7] = a[b[ 7]];
|
460
|
-
res[ 8] = a[b[ 8]];
|
461
|
-
res[ 9] = a[b[ 9]];
|
462
|
-
res[10] = a[b[10]];
|
463
|
-
res[11] = a[b[11]];
|
464
|
-
res[12] = a[b[12]];
|
465
|
-
res[13] = a[b[13]];
|
466
|
-
res[14] = a[b[14]];
|
467
|
-
res[15] = a[b[15]];
|
468
|
-
|
469
|
-
return res;
|
470
|
-
}
|
471
|
-
|
472
|
-
// NOTE: not tested
|
473
|
-
inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
|
474
|
-
uint8x16_t res;
|
475
|
-
|
476
|
-
res[ 0] = a[b[ 0]];
|
477
|
-
res[ 1] = a[b[ 1]];
|
478
|
-
res[ 2] = a[b[ 2]];
|
479
|
-
res[ 3] = a[b[ 3]];
|
480
|
-
res[ 4] = a[b[ 4]];
|
481
|
-
res[ 5] = a[b[ 5]];
|
482
|
-
res[ 6] = a[b[ 6]];
|
483
|
-
res[ 7] = a[b[ 7]];
|
484
|
-
res[ 8] = a[b[ 8]];
|
485
|
-
res[ 9] = a[b[ 9]];
|
486
|
-
res[10] = a[b[10]];
|
487
|
-
res[11] = a[b[11]];
|
488
|
-
res[12] = a[b[12]];
|
489
|
-
res[13] = a[b[13]];
|
490
|
-
res[14] = a[b[14]];
|
491
|
-
res[15] = a[b[15]];
|
492
|
-
|
493
|
-
return res;
|
494
|
-
}
|
495
|
-
|
496
|
-
#else
|
497
|
-
|
498
|
-
#define ggml_int16x8x2_t int16x8x2_t
|
499
|
-
#define ggml_uint8x16x2_t uint8x16x2_t
|
500
|
-
#define ggml_uint8x16x4_t uint8x16x4_t
|
501
|
-
#define ggml_int8x16x2_t int8x16x2_t
|
502
|
-
#define ggml_int8x16x4_t int8x16x4_t
|
503
|
-
|
504
|
-
#define ggml_vld1q_s16_x2 vld1q_s16_x2
|
505
|
-
#define ggml_vld1q_u8_x2 vld1q_u8_x2
|
506
|
-
#define ggml_vld1q_u8_x4 vld1q_u8_x4
|
507
|
-
#define ggml_vld1q_s8_x2 vld1q_s8_x2
|
508
|
-
#define ggml_vld1q_s8_x4 vld1q_s8_x4
|
509
|
-
#define ggml_vqtbl1q_s8 vqtbl1q_s8
|
510
|
-
#define ggml_vqtbl1q_u8 vqtbl1q_u8
|
511
|
-
|
512
|
-
#endif
|
513
|
-
|
514
|
-
#if !defined(__ARM_FEATURE_DOTPROD)
|
515
|
-
|
516
|
-
inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
|
517
|
-
const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
|
518
|
-
const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
|
519
|
-
|
520
|
-
return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
|
521
|
-
}
|
522
|
-
|
523
|
-
#else
|
524
|
-
|
525
|
-
#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
|
526
|
-
|
527
|
-
#endif
|
528
|
-
|
529
|
-
#endif
|
530
|
-
|
531
238
|
#if defined(__ARM_NEON) || defined(__wasm_simd128__)
|
532
239
|
#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
|
533
240
|
#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
|
@@ -3154,7 +3154,6 @@ typedef float (*vec_dot_q_mul_mat_sycl_t)(
|
|
3154
3154
|
#define SYCL_SCALE_BLOCK_SIZE 256
|
3155
3155
|
#define SYCL_CLAMP_BLOCK_SIZE 256
|
3156
3156
|
#define SYCL_ROPE_BLOCK_SIZE 256
|
3157
|
-
#define SYCL_SOFT_MAX_BLOCK_SIZE 1024
|
3158
3157
|
#define SYCL_ALIBI_BLOCK_SIZE 32
|
3159
3158
|
#define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32
|
3160
3159
|
#define SYCL_QUANTIZE_BLOCK_SIZE 256
|
@@ -13080,11 +13079,13 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
|
|
13080
13079
|
const int nrows_y, const float scale, const float max_bias,
|
13081
13080
|
dpct::queue_ptr stream) {
|
13082
13081
|
int nth = WARP_SIZE;
|
13083
|
-
|
13082
|
+
int max_block_size = g_work_group_size;
|
13083
|
+
while (nth < ncols_x && nth < max_block_size) nth *= 2;
|
13084
|
+
if (nth>max_block_size) nth = max_block_size;
|
13085
|
+
|
13084
13086
|
const sycl::range<3> block_dims(1, 1, nth);
|
13085
13087
|
const sycl::range<3> block_nums(1, 1, nrows_x);
|
13086
13088
|
const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
|
13087
|
-
static_assert(SYCL_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
13088
13089
|
|
13089
13090
|
const uint32_t n_head_kv = nrows_x/nrows_y;
|
13090
13091
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
@@ -13094,6 +13095,12 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
|
|
13094
13095
|
|
13095
13096
|
const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
|
13096
13097
|
if (n_local_scratch*sizeof(float) < local_mem_size) {
|
13098
|
+
if (ncols_x > max_block_size) {
|
13099
|
+
soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
13100
|
+
max_bias, m0, m1, n_head_log2, block_nums,
|
13101
|
+
block_dims, n_local_scratch, stream);
|
13102
|
+
return;
|
13103
|
+
}
|
13097
13104
|
switch (ncols_x) {
|
13098
13105
|
case 32:
|
13099
13106
|
soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
@@ -15989,73 +15996,76 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) {
|
|
15989
15996
|
static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
|
15990
15997
|
const ggml_tensor *src1,
|
15991
15998
|
ggml_tensor *dst) try {
|
15992
|
-
|
15993
|
-
|
15994
|
-
|
15995
|
-
|
15999
|
+
GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT &&
|
16000
|
+
"mul_mat_id does not support split buffers");
|
16001
|
+
const ggml_tensor *ids = dst->src[2];
|
16002
|
+
const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
|
15996
16003
|
|
15997
|
-
const
|
15998
|
-
const
|
16004
|
+
const size_t nb11 = src1->nb[1];
|
16005
|
+
const size_t nb1 = dst->nb[1];
|
15999
16006
|
|
16000
|
-
const
|
16001
|
-
const int32_t
|
16002
|
-
const int32_t n_as = ((int32_t *) dst->op_params)[1];
|
16007
|
+
const int32_t id = ((int32_t *)dst->op_params)[0];
|
16008
|
+
const int32_t n_as = src0->ne[2];
|
16003
16009
|
|
16004
16010
|
std::vector<char> ids_host(ggml_nbytes(ids));
|
16011
|
+
const char *ids_dev = (const char *)ids->data;
|
16005
16012
|
|
16006
|
-
|
16007
|
-
|
16008
|
-
|
16009
|
-
const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
|
16010
|
-
SYCL_CHECK(CHECK_TRY_ERROR(
|
16011
|
-
stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids)).wait()));
|
16012
|
-
// SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
|
16013
|
-
} else {
|
16014
|
-
memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
|
16015
|
-
}
|
16013
|
+
SYCL_CHECK(CHECK_TRY_ERROR(
|
16014
|
+
stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
|
16015
|
+
SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
|
16016
16016
|
|
16017
|
-
const ggml_tensor_extra_gpu *
|
16018
|
-
|
16017
|
+
const ggml_tensor_extra_gpu *src0_extra =
|
16018
|
+
(const ggml_tensor_extra_gpu *)src0->extra;
|
16019
|
+
const ggml_tensor_extra_gpu *src1_extra =
|
16020
|
+
(const ggml_tensor_extra_gpu *)src1->extra;
|
16021
|
+
const ggml_tensor_extra_gpu *dst_extra =
|
16022
|
+
(const ggml_tensor_extra_gpu *)dst->extra;
|
16019
16023
|
|
16024
|
+
ggml_tensor_extra_gpu src0_row_extra;
|
16020
16025
|
ggml_tensor_extra_gpu src1_row_extra;
|
16021
16026
|
ggml_tensor_extra_gpu dst_row_extra;
|
16022
16027
|
|
16028
|
+
ggml_tensor src0_row = *src0;
|
16023
16029
|
ggml_tensor src1_row = *src1;
|
16024
16030
|
ggml_tensor dst_row = *dst;
|
16025
16031
|
|
16026
16032
|
src1_row.backend = GGML_BACKEND_TYPE_GPU;
|
16027
16033
|
dst_row.backend = GGML_BACKEND_TYPE_GPU;
|
16028
16034
|
|
16035
|
+
src0_row.extra = &src0_row_extra;
|
16029
16036
|
src1_row.extra = &src1_row_extra;
|
16030
16037
|
dst_row.extra = &dst_row_extra;
|
16031
16038
|
|
16032
|
-
char *
|
16033
|
-
|
16034
|
-
|
16035
|
-
|
16039
|
+
char *src0_original = src1->backend == GGML_BACKEND_TYPE_CPU
|
16040
|
+
? (char *)src0->data
|
16041
|
+
: (char *)src0_extra->data_device[g_main_device];
|
16042
|
+
char *src1_original = src1->backend == GGML_BACKEND_TYPE_CPU
|
16043
|
+
? (char *)src1->data
|
16044
|
+
: (char *)src1_extra->data_device[g_main_device];
|
16045
|
+
char *dst_original = dst->backend == GGML_BACKEND_TYPE_CPU
|
16046
|
+
? (char *)dst->data
|
16047
|
+
: (char *)dst_extra->data_device[g_main_device];
|
16036
16048
|
|
16037
|
-
|
16038
|
-
|
16039
|
-
|
16049
|
+
src0_row.ne[2] = 1;
|
16050
|
+
src0_row.ne[3] = 1;
|
16051
|
+
src0_row.nb[3] = src0->nb[2];
|
16040
16052
|
|
16053
|
+
if (src1->ne[1] == 1) {
|
16041
16054
|
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
|
16042
|
-
|
16043
|
-
|
16044
|
-
|
16045
|
-
|
16046
|
-
const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
|
16055
|
+
const int32_t row_id =
|
16056
|
+
*(const int32_t *)(ids_host.data() + i01 * ids->nb[1] +
|
16057
|
+
id * ids->nb[0]);
|
16047
16058
|
|
16048
16059
|
GGML_ASSERT(row_id >= 0 && row_id < n_as);
|
16049
16060
|
|
16050
|
-
|
16051
|
-
|
16052
|
-
src1_row_extra.data_device[g_main_device] =
|
16053
|
-
|
16061
|
+
src0_row_extra.data_device[g_main_device] =
|
16062
|
+
src0_original + row_id * src0->nb[2];
|
16063
|
+
src1_row_extra.data_device[g_main_device] =
|
16064
|
+
src1_original + i01 * src1->nb[1];
|
16065
|
+
dst_row_extra.data_device[g_main_device] =
|
16066
|
+
dst_original + i01 * dst->nb[1];
|
16054
16067
|
|
16055
|
-
|
16056
|
-
dst_row.data = (char *) dst->data + i01*dst->nb[1]; // TODO why is this set?
|
16057
|
-
|
16058
|
-
ggml_sycl_mul_mat(src0_row, &src1_row, &dst_row);
|
16068
|
+
ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
|
16059
16069
|
}
|
16060
16070
|
} else {
|
16061
16071
|
sycl_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
|
@@ -16065,8 +16075,6 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
|
|
16065
16075
|
dst_row_extra.data_device[g_main_device] = dst_contiguous.get();
|
16066
16076
|
|
16067
16077
|
for (int32_t row_id = 0; row_id < n_as; ++row_id) {
|
16068
|
-
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
|
16069
|
-
|
16070
16078
|
int64_t num_src1_rows = 0;
|
16071
16079
|
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
|
16072
16080
|
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
|
@@ -16079,7 +16087,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
|
|
16079
16087
|
|
16080
16088
|
SYCL_CHECK(CHECK_TRY_ERROR(
|
16081
16089
|
stream->memcpy(src1_contiguous.get() + num_src1_rows * nb11,
|
16082
|
-
src1_original + i01 * nb11, nb11)
|
16090
|
+
src1_original + i01 * nb11, nb11)));
|
16083
16091
|
num_src1_rows++;
|
16084
16092
|
}
|
16085
16093
|
|
@@ -16087,6 +16095,9 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
|
|
16087
16095
|
continue;
|
16088
16096
|
}
|
16089
16097
|
|
16098
|
+
src0_row_extra.data_device[g_main_device] =
|
16099
|
+
src0_original + row_id * src0->nb[2];
|
16100
|
+
|
16090
16101
|
src1_row.ne[1] = num_src1_rows;
|
16091
16102
|
dst_row.ne[1] = num_src1_rows;
|
16092
16103
|
|
@@ -16098,7 +16109,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
|
|
16098
16109
|
dst_row.nb[2] = num_src1_rows*nb1;
|
16099
16110
|
dst_row.nb[3] = num_src1_rows*nb1;
|
16100
16111
|
|
16101
|
-
ggml_sycl_mul_mat(src0_row, &src1_row, &dst_row);
|
16112
|
+
ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
|
16102
16113
|
|
16103
16114
|
num_src1_rows = 0;
|
16104
16115
|
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
|
@@ -16112,7 +16123,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
|
|
16112
16123
|
|
16113
16124
|
SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
|
16114
16125
|
dst_original + i01 * nb1,
|
16115
|
-
dst_contiguous.get() + num_src1_rows * nb1, nb1)
|
16126
|
+
dst_contiguous.get() + num_src1_rows * nb1, nb1)));
|
16116
16127
|
num_src1_rows++;
|
16117
16128
|
}
|
16118
16129
|
}
|
@@ -16814,11 +16825,13 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
|
16814
16825
|
const dpct::queue_ptr stream = g_syclStreams[ctx->device][0];
|
16815
16826
|
SYCL_CHECK(
|
16816
16827
|
CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
|
16817
|
-
|
16828
|
+
char* host_buf = (char*)malloc(size);
|
16829
|
+
memcpy(host_buf, data, size);
|
16818
16830
|
SYCL_CHECK(
|
16819
16831
|
CHECK_TRY_ERROR((*stream)
|
16820
|
-
.memcpy((char *)tensor->data + offset,
|
16832
|
+
.memcpy((char *)tensor->data + offset, host_buf, size)
|
16821
16833
|
.wait()));
|
16834
|
+
free(host_buf);
|
16822
16835
|
}
|
16823
16836
|
catch (sycl::exception const &exc) {
|
16824
16837
|
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
@@ -17739,7 +17752,7 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
|
|
17739
17752
|
|
17740
17753
|
GGML_CALL static bool ggml_backend_sycl_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
|
17741
17754
|
const int min_batch_size = 32;
|
17742
|
-
return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
|
17755
|
+
return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID;
|
17743
17756
|
GGML_UNUSED(backend);
|
17744
17757
|
}
|
17745
17758
|
|