llama_cpp 0.14.5 → 0.14.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/ext/llama_cpp/llama_cpp.cpp +37 -2
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -0
- data/vendor/tmp/llama.cpp/Makefile +24 -7
- data/vendor/tmp/llama.cpp/ggml-alloc.c +8 -8
- data/vendor/tmp/llama.cpp/ggml-backend.c +14 -10
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +135 -46
- data/vendor/tmp/llama.cpp/ggml-impl.h +263 -5
- data/vendor/tmp/llama.cpp/ggml-metal.m +130 -83
- data/vendor/tmp/llama.cpp/ggml-metal.metal +505 -1467
- data/vendor/tmp/llama.cpp/ggml-quants.c +1 -294
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +65 -52
- data/vendor/tmp/llama.cpp/ggml.c +151 -99
- data/vendor/tmp/llama.cpp/ggml.h +5 -4
- data/vendor/tmp/llama.cpp/llama.cpp +1308 -254
- data/vendor/tmp/llama.cpp/llama.h +19 -6
- data/vendor/tmp/llama.cpp/sgemm.cpp +999 -0
- data/vendor/tmp/llama.cpp/sgemm.h +12 -0
- metadata +4 -2

data/vendor/tmp/llama.cpp/ggml-quants.c

@@ -14,47 +14,6 @@
 #include <stdlib.h> // for qsort
 #include <stdio.h> // for GGML_ASSERT
 
-#ifdef __ARM_NEON
-
-// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
-//
-// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
-//
-#include <arm_neon.h>
-
-#else
-
-#ifdef __wasm_simd128__
-#include <wasm_simd128.h>
-#else
-#if defined(__POWER9_VECTOR__) || defined(__powerpc64__)
-#include <altivec.h>
-#undef bool
-#define bool _Bool
-#else
-#if defined(_MSC_VER) || defined(__MINGW32__)
-#include <intrin.h>
-#else
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
-#if !defined(__riscv)
-#include <immintrin.h>
-#endif
-#endif
-#endif
-#endif
-#endif
-#endif
-
-#ifdef __riscv_v_intrinsic
-#include <riscv_vector.h>
-#endif
-
-#undef MIN
-#undef MAX
-
-#define MIN(a, b) ((a) < (b) ? (a) : (b))
-#define MAX(a, b) ((a) > (b) ? (a) : (b))
-
 #define UNUSED GGML_UNUSED
 
 // some compilers don't provide _mm256_set_m128i, e.g. gcc 7

@@ -132,7 +91,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
 }
 
 static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
-#if defined(__AVXVNNI__) || defined(__AVX512VNNI__)
+#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
     const __m256i zero = _mm256_setzero_si256();
     const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
     return _mm256_cvtepi32_ps(summed_pairs);
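
The guard above is tightened because _mm256_dpbusd_epi32 operates on 256-bit registers, which AVX512VNNI only exposes when the AVX512VL (vector length) extension is also present; a build with AVX512VNNI but no VL would otherwise select an unusable intrinsic. A minimal sketch of the same guard with the usual pre-VNNI fallback (the helper name is hypothetical and the fallback assumes AVX2; this is not part of the diff):

#include <immintrin.h>

// accumulate 4-wide u8*s8 dot products into the 32-bit lanes of acc
static inline __m256i dot_u8s8_accumulate(__m256i acc, __m256i ax, __m256i sy) {
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
    // one VNNI instruction per 256-bit vector
    return _mm256_dpbusd_epi32(acc, ax, sy);
#else
    // AVX2 fallback: widen u8*s8 products to 16-bit pairs, then pair-sum to 32-bit
    const __m256i dot  = _mm256_maddubs_epi16(ax, sy);
    const __m256i ones = _mm256_set1_epi16(1);
    return _mm256_add_epi32(acc, _mm256_madd_epi16(dot, ones));
#endif
}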

@@ -276,258 +235,6 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
 #endif // __AVX__ || __AVX2__ || __AVX512F__
 #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 
-#if defined(__ARM_NEON)
-
-#ifdef _MSC_VER
-
-#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
-
-#else
-
-#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
-
-#endif
-
-#if !defined(__aarch64__)
-
-// 64-bit compatibility
-
-// vaddvq_s16
-// vpaddq_s16
-// vpaddq_s32
-// vaddvq_s32
-// vaddvq_f32
-// vmaxvq_f32
-// vcvtnq_s32_f32
-// vzip1_u8
-// vzip2_u8
-
-inline static int32_t vaddvq_s16(int16x8_t v) {
-    return
-        (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
-        (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
-        (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
-        (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
-}
-
-inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
-    int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
-    int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
-    return vcombine_s16(a0, b0);
-}
-
-inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
-    int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
-    int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
-    return vcombine_s32(a0, b0);
-}
-
-inline static int32_t vaddvq_s32(int32x4_t v) {
-    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
-}
-
-inline static float vaddvq_f32(float32x4_t v) {
-    return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
-}
-
-inline static float vmaxvq_f32(float32x4_t v) {
-    return
-        MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
-            MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
-}
-
-inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
-    int32x4_t res;
-
-    res[0] = roundf(vgetq_lane_f32(v, 0));
-    res[1] = roundf(vgetq_lane_f32(v, 1));
-    res[2] = roundf(vgetq_lane_f32(v, 2));
-    res[3] = roundf(vgetq_lane_f32(v, 3));
-
-    return res;
-}
-
-inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
-    uint8x8_t res;
-
-    res[0] = a[0]; res[1] = b[0];
-    res[2] = a[1]; res[3] = b[1];
-    res[4] = a[2]; res[5] = b[2];
-    res[6] = a[3]; res[7] = b[3];
-
-    return res;
-}
-
-inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
-    uint8x8_t res;
-
-    res[0] = a[4]; res[1] = b[4];
-    res[2] = a[5]; res[3] = b[5];
-    res[4] = a[6]; res[5] = b[6];
-    res[6] = a[7]; res[7] = b[7];
-
-    return res;
-}
-
-// vld1q_s16_x2
-// vld1q_u8_x2
-// vld1q_u8_x4
-// vld1q_s8_x2
-// vld1q_s8_x4
-// TODO: double-check these work correctly
-
-typedef struct ggml_int16x8x2_t {
-    int16x8_t val[2];
-} ggml_int16x8x2_t;
-
-inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
-    ggml_int16x8x2_t res;
-
-    res.val[0] = vld1q_s16(ptr + 0);
-    res.val[1] = vld1q_s16(ptr + 8);
-
-    return res;
-}
-
-typedef struct ggml_uint8x16x2_t {
-    uint8x16_t val[2];
-} ggml_uint8x16x2_t;
-
-inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
-    ggml_uint8x16x2_t res;
-
-    res.val[0] = vld1q_u8(ptr + 0);
-    res.val[1] = vld1q_u8(ptr + 16);
-
-    return res;
-}
-
-typedef struct ggml_uint8x16x4_t {
-    uint8x16_t val[4];
-} ggml_uint8x16x4_t;
-
-inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
-    ggml_uint8x16x4_t res;
-
-    res.val[0] = vld1q_u8(ptr + 0);
-    res.val[1] = vld1q_u8(ptr + 16);
-    res.val[2] = vld1q_u8(ptr + 32);
-    res.val[3] = vld1q_u8(ptr + 48);
-
-    return res;
-}
-
-typedef struct ggml_int8x16x2_t {
-    int8x16_t val[2];
-} ggml_int8x16x2_t;
-
-inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
-    ggml_int8x16x2_t res;
-
-    res.val[0] = vld1q_s8(ptr + 0);
-    res.val[1] = vld1q_s8(ptr + 16);
-
-    return res;
-}
-
-typedef struct ggml_int8x16x4_t {
-    int8x16_t val[4];
-} ggml_int8x16x4_t;
-
-inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
-    ggml_int8x16x4_t res;
-
-    res.val[0] = vld1q_s8(ptr + 0);
-    res.val[1] = vld1q_s8(ptr + 16);
-    res.val[2] = vld1q_s8(ptr + 32);
-    res.val[3] = vld1q_s8(ptr + 48);
-
-    return res;
-}
-
-// NOTE: not tested
-inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
-    int8x16_t res;
-
-    res[ 0] = a[b[ 0]];
-    res[ 1] = a[b[ 1]];
-    res[ 2] = a[b[ 2]];
-    res[ 3] = a[b[ 3]];
-    res[ 4] = a[b[ 4]];
-    res[ 5] = a[b[ 5]];
-    res[ 6] = a[b[ 6]];
-    res[ 7] = a[b[ 7]];
-    res[ 8] = a[b[ 8]];
-    res[ 9] = a[b[ 9]];
-    res[10] = a[b[10]];
-    res[11] = a[b[11]];
-    res[12] = a[b[12]];
-    res[13] = a[b[13]];
-    res[14] = a[b[14]];
-    res[15] = a[b[15]];
-
-    return res;
-}
-
-// NOTE: not tested
-inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
-    uint8x16_t res;
-
-    res[ 0] = a[b[ 0]];
-    res[ 1] = a[b[ 1]];
-    res[ 2] = a[b[ 2]];
-    res[ 3] = a[b[ 3]];
-    res[ 4] = a[b[ 4]];
-    res[ 5] = a[b[ 5]];
-    res[ 6] = a[b[ 6]];
-    res[ 7] = a[b[ 7]];
-    res[ 8] = a[b[ 8]];
-    res[ 9] = a[b[ 9]];
-    res[10] = a[b[10]];
-    res[11] = a[b[11]];
-    res[12] = a[b[12]];
-    res[13] = a[b[13]];
-    res[14] = a[b[14]];
-    res[15] = a[b[15]];
-
-    return res;
-}
-
-#else
-
-#define ggml_int16x8x2_t int16x8x2_t
-#define ggml_uint8x16x2_t uint8x16x2_t
-#define ggml_uint8x16x4_t uint8x16x4_t
-#define ggml_int8x16x2_t int8x16x2_t
-#define ggml_int8x16x4_t int8x16x4_t
-
-#define ggml_vld1q_s16_x2 vld1q_s16_x2
-#define ggml_vld1q_u8_x2 vld1q_u8_x2
-#define ggml_vld1q_u8_x4 vld1q_u8_x4
-#define ggml_vld1q_s8_x2 vld1q_s8_x2
-#define ggml_vld1q_s8_x4 vld1q_s8_x4
-#define ggml_vqtbl1q_s8 vqtbl1q_s8
-#define ggml_vqtbl1q_u8 vqtbl1q_u8
-
-#endif
-
-#if !defined(__ARM_FEATURE_DOTPROD)
-
-inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
-    const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
-    const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
-
-    return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
-}
-
-#else
-
-#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
-
-#endif
-
-#endif
-
 #if defined(__ARM_NEON) || defined(__wasm_simd128__)
 #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
 #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)

data/vendor/tmp/llama.cpp/ggml-sycl.cpp

@@ -3154,7 +3154,6 @@ typedef float (*vec_dot_q_mul_mat_sycl_t)(
 #define SYCL_SCALE_BLOCK_SIZE 256
 #define SYCL_CLAMP_BLOCK_SIZE 256
 #define SYCL_ROPE_BLOCK_SIZE 256
-#define SYCL_SOFT_MAX_BLOCK_SIZE 1024
 #define SYCL_ALIBI_BLOCK_SIZE 32
 #define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32
 #define SYCL_QUANTIZE_BLOCK_SIZE 256

@@ -13080,11 +13079,13 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
                               const int nrows_y, const float scale, const float max_bias,
                               dpct::queue_ptr stream) {
     int nth = WARP_SIZE;
-
+    int max_block_size = g_work_group_size;
+    while (nth < ncols_x && nth < max_block_size) nth *= 2;
+    if (nth>max_block_size) nth = max_block_size;
+
     const sycl::range<3> block_dims(1, 1, nth);
     const sycl::range<3> block_nums(1, 1, nrows_x);
     const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
-    static_assert(SYCL_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
     const uint32_t n_head_kv = nrows_x/nrows_y;
     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
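
The fixed 1024-thread cap is gone: the work-group size is now derived from the device limit (g_work_group_size), doubling from one warp until it covers the row or hits that limit. The sizing rule in isolation (a sketch; the function name is hypothetical):

static int pick_soft_max_block_size(int ncols_x, int max_block_size, int warp_size) {
    int nth = warp_size;
    // double until the whole row fits in one work-group or the device cap is reached
    while (nth < ncols_x && nth < max_block_size) nth *= 2;
    if (nth > max_block_size) nth = max_block_size;
    return nth;
}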

@@ -13094,6 +13095,12 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
 
     const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
     if (n_local_scratch*sizeof(float) < local_mem_size) {
+        if (ncols_x > max_block_size) {
+            soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                               max_bias, m0, m1, n_head_log2, block_nums,
+                                               block_dims, n_local_scratch, stream);
+            return;
+        }
         switch (ncols_x) {
             case 32:
                 soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,

@@ -15989,73 +15996,76 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) {
 static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
                                  const ggml_tensor *src1,
                                  ggml_tensor *dst) try {
-#if 0
-    ggml_sycl_mul_mat_id_sycl(dst);
-    // TODO: mmq/mmv support
-#endif
+    GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT &&
+                "mul_mat_id does not support split buffers");
+    const ggml_tensor *ids = dst->src[2];
+    const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
 
-    const int64_t nb11 = src1->nb[1];
-    const int64_t nb1  =  dst->nb[1];
+    const size_t nb11 = src1->nb[1];
+    const size_t nb1 = dst->nb[1];
 
-    const struct ggml_tensor * ids = src0;
-    const int32_t id = ((int32_t *) dst->op_params)[0];
-    const int32_t n_as = ((int32_t *) dst->op_params)[1];
+    const int32_t id = ((int32_t *)dst->op_params)[0];
+    const int32_t n_as = src0->ne[2];
 
     std::vector<char> ids_host(ggml_nbytes(ids));
+    const char *ids_dev = (const char *)ids->data;
 
-    const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
-
-    if (ids->backend == GGML_BACKEND_TYPE_GPU) {
-        const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
-        SYCL_CHECK(CHECK_TRY_ERROR(
-            stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids)).wait()));
-        // SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
-    } else {
-        memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
-    }
+    SYCL_CHECK(CHECK_TRY_ERROR(
+        stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
+    SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
 
-    const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
-    const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
+    const ggml_tensor_extra_gpu *src0_extra =
+        (const ggml_tensor_extra_gpu *)src0->extra;
+    const ggml_tensor_extra_gpu *src1_extra =
+        (const ggml_tensor_extra_gpu *)src1->extra;
+    const ggml_tensor_extra_gpu *dst_extra =
+        (const ggml_tensor_extra_gpu *)dst->extra;
 
+    ggml_tensor_extra_gpu src0_row_extra;
     ggml_tensor_extra_gpu src1_row_extra;
     ggml_tensor_extra_gpu dst_row_extra;
 
+    ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
     ggml_tensor dst_row = *dst;
 
     src1_row.backend = GGML_BACKEND_TYPE_GPU;
     dst_row.backend = GGML_BACKEND_TYPE_GPU;
 
+    src0_row.extra = &src0_row_extra;
     src1_row.extra = &src1_row_extra;
     dst_row.extra = &dst_row_extra;
 
-    char * src1_original = src1->backend == GGML_BACKEND_TYPE_CPU ?
-        (char *) src1->data : (char *) src1_extra->data_device[g_main_device];
-    char * dst_original  =  dst->backend == GGML_BACKEND_TYPE_CPU ?
-        (char *)  dst->data : (char *)  dst_extra->data_device[g_main_device];
+    char *src0_original = src1->backend == GGML_BACKEND_TYPE_CPU
+                              ? (char *)src0->data
+                              : (char *)src0_extra->data_device[g_main_device];
+    char *src1_original = src1->backend == GGML_BACKEND_TYPE_CPU
+                              ? (char *)src1->data
+                              : (char *)src1_extra->data_device[g_main_device];
+    char *dst_original = dst->backend == GGML_BACKEND_TYPE_CPU
+                             ? (char *)dst->data
+                             : (char *)dst_extra->data_device[g_main_device];
 
-    if (src1->ne[1] == 1) {
-        GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU);
-        GGML_ASSERT(dst->backend == GGML_BACKEND_TYPE_GPU);
+    src0_row.ne[2] = 1;
+    src0_row.ne[3] = 1;
+    src0_row.nb[3] = src0->nb[2];
 
+    if (src1->ne[1] == 1) {
         for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-            //int32_t row_id;
-            //SYCL_CHECK(syclMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), syclMemcpyDeviceToHost, g_syclStreams[g_main_device][0]));
-            //SYCL_CHECK(syclStreamSynchronize(g_syclStreams[g_main_device][0]));
-
-            const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+            const int32_t row_id =
+                *(const int32_t *)(ids_host.data() + i01 * ids->nb[1] +
+                                   id * ids->nb[0]);
 
             GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
-            const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-
-            src1_row_extra.data_device[g_main_device] = src1_original + i01*src1->nb[1];
-            src1_row.data = (char *) src1->data + i01*src1->nb[1]; // TODO why is this set?
+            src0_row_extra.data_device[g_main_device] =
+                src0_original + row_id * src0->nb[2];
+            src1_row_extra.data_device[g_main_device] =
+                src1_original + i01 * src1->nb[1];
+            dst_row_extra.data_device[g_main_device] =
+                dst_original + i01 * dst->nb[1];
 
-            dst_row_extra.data_device[g_main_device] = dst_original + i01*dst->nb[1];
-            dst_row.data = (char *) dst->data + i01*dst->nb[1]; // TODO why is this set?
-
-            ggml_sycl_mul_mat(src0_row, &src1_row, &dst_row);
+            ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
         }
     } else {
         sycl_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
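
This rewrite tracks the upstream ggml change that merged all experts of a MoE layer into one 3-D src0 tensor: the expert id now selects a byte offset along dim 2 of src0 (row_id * nb[2]) instead of dereferencing a separate tensor in dst->src[row_id + 2]. The new selection in isolation (a sketch; the helper name is hypothetical):

#include "ggml.h"

// view of expert `row_id` inside a stacked [ne0, ne1, n_experts] weight tensor;
// nb[2] is the byte stride from one expert's matrix to the next
static char * expert_weights(char * src0_base, const struct ggml_tensor * src0, int32_t row_id) {
    return src0_base + (size_t) row_id * src0->nb[2];
}

The loops below apply exactly this offset when setting src0_row_extra.data_device.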

@@ -16065,8 +16075,6 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
         dst_row_extra.data_device[g_main_device] = dst_contiguous.get();
 
         for (int32_t row_id = 0; row_id < n_as; ++row_id) {
-            const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-
             int64_t num_src1_rows = 0;
             for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
                 const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);

@@ -16079,7 +16087,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
 
                 SYCL_CHECK(CHECK_TRY_ERROR(
                     stream->memcpy(src1_contiguous.get() + num_src1_rows * nb11,
-                                   src1_original + i01 * nb11, nb11).wait()));
+                                   src1_original + i01 * nb11, nb11)));
                 num_src1_rows++;
             }
 

@@ -16087,6 +16095,9 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
                 continue;
             }
 
+            src0_row_extra.data_device[g_main_device] =
+                src0_original + row_id * src0->nb[2];
+
             src1_row.ne[1] = num_src1_rows;
             dst_row.ne[1] = num_src1_rows;
 

@@ -16098,7 +16109,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;
 
-            ggml_sycl_mul_mat(src0_row, &src1_row, &dst_row);
+            ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
 
             num_src1_rows = 0;
             for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {

@@ -16112,7 +16123,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
 
                 SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
                     dst_original + i01 * nb1,
-                    dst_contiguous.get() + num_src1_rows * nb1, nb1).wait()));
+                    dst_contiguous.get() + num_src1_rows * nb1, nb1)));
                 num_src1_rows++;
             }
         }
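
In both gather loops above the per-copy .wait() is dropped, so the row copies are merely queued; on an in-order queue (dpct queues are in-order by default, presumably the case here) later submissions to the same queue run after them without a host-side synchronization. The pattern, as a sketch with the SYCL_CHECK wrappers omitted:

#include <sycl/sycl.hpp>

// queue all row copies asynchronously; a kernel submitted to the same
// in-order queue afterwards will see every row in place
void gather_rows(sycl::queue &q, char *dst, const char *src,
                 const size_t *row_offsets, size_t n_rows, size_t row_bytes) {
    for (size_t i = 0; i < n_rows; ++i) {
        q.memcpy(dst + i * row_bytes, src + row_offsets[i], row_bytes); // no per-copy wait
    }
}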

@@ -16814,11 +16825,13 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
     const dpct::queue_ptr stream = g_syclStreams[ctx->device][0];
     SYCL_CHECK(
         CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
-
+    char* host_buf = (char*)malloc(size);
+    memcpy(host_buf, data, size);
     SYCL_CHECK(
         CHECK_TRY_ERROR((*stream)
-                            .memcpy((char *)tensor->data + offset, data, size)
+                            .memcpy((char *)tensor->data + offset, host_buf, size)
                             .wait()));
+    free(host_buf);
 }
 catch (sycl::exception const &exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
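
set_tensor now stages the caller's buffer through a freshly malloc'd host allocation instead of handing `data` to the SYCL memcpy directly, so a host pointer of unknown origin (for example, memory-mapped model data) is first copied into memory the backend controls for the duration of the transfer. The pattern in isolation (a sketch assuming a plain sycl::queue, with error handling omitted):

#include <sycl/sycl.hpp>
#include <cstdlib>
#include <cstring>

void set_tensor_staged(sycl::queue &q, void *dst_dev, const void *data, size_t size) {
    char *host_buf = (char *) malloc(size);   // staging buffer owned by the backend
    memcpy(host_buf, data, size);             // snapshot the caller's bytes
    q.memcpy(dst_dev, host_buf, size).wait(); // blocking host-to-device copy
    free(host_buf);                           // safe only after the wait
}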

@@ -17739,7 +17752,7 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
 
 GGML_CALL static bool ggml_backend_sycl_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
     const int min_batch_size = 32;
-    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
+    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID;
     GGML_UNUSED(backend);
 }
 