llama_cpp 0.14.5 → 0.14.7

@@ -14,47 +14,6 @@
  #include <stdlib.h> // for qsort
  #include <stdio.h> // for GGML_ASSERT

- #ifdef __ARM_NEON
-
- // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
- //
- // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
- //
- #include <arm_neon.h>
-
- #else
-
- #ifdef __wasm_simd128__
- #include <wasm_simd128.h>
- #else
- #if defined(__POWER9_VECTOR__) || defined(__powerpc64__)
- #include <altivec.h>
- #undef bool
- #define bool _Bool
- #else
- #if defined(_MSC_VER) || defined(__MINGW32__)
- #include <intrin.h>
- #else
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
- #if !defined(__riscv)
- #include <immintrin.h>
- #endif
- #endif
- #endif
- #endif
- #endif
- #endif
-
- #ifdef __riscv_v_intrinsic
- #include <riscv_vector.h>
- #endif
-
- #undef MIN
- #undef MAX
-
- #define MIN(a, b) ((a) < (b) ? (a) : (b))
- #define MAX(a, b) ((a) > (b) ? (a) : (b))
-
  #define UNUSED GGML_UNUSED

  // some compilers don't provide _mm256_set_m128i, e.g. gcc 7
@@ -132,7 +91,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
  }

  static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
- #if defined(__AVXVNNI__) || defined(__AVX512VNNI__)
+ #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
  const __m256i zero = _mm256_setzero_si256();
  const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
  return _mm256_cvtepi32_ps(summed_pairs);
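
Note on the tightened guard above: the 256-bit _mm256_dpbusd_epi32 form is only encodable with AVX-VNNI, or with AVX512-VNNI combined with AVX512-VL (VL is what supplies the 128/256-bit variants of AVX-512 instructions), so testing __AVX512VNNI__ alone can select an intrinsic the target cannot emit. A minimal, self-contained sketch of the same dispatch pattern follows; the helper name and the fallback path are illustrative, not taken from the diff.

#include <immintrin.h>

// Sketch: accumulate unsigned-by-signed byte dot products into 32-bit lanes,
// using the 256-bit VNNI instruction only when it is actually available.
static inline __m256i dot_u8s8_accum_i32(__m256i acc, __m256i a_u8, __m256i b_s8) {
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
    return _mm256_dpbusd_epi32(acc, a_u8, b_s8);                       // fused u8*s8 dot product
#else
    const __m256i p16 = _mm256_maddubs_epi16(a_u8, b_s8);              // u8*s8 -> adjacent i16 sums
    const __m256i p32 = _mm256_madd_epi16(p16, _mm256_set1_epi16(1));  // widen i16 pairs to i32
    return _mm256_add_epi32(acc, p32);
#endif
}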
@@ -276,258 +235,6 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
  #endif // __AVX__ || __AVX2__ || __AVX512F__
  #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)

- #if defined(__ARM_NEON)
-
- #ifdef _MSC_VER
-
- #define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
-
- #else
-
- #define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
-
- #endif
-
- #if !defined(__aarch64__)
-
- // 64-bit compatibility
-
- // vaddvq_s16
- // vpaddq_s16
- // vpaddq_s32
- // vaddvq_s32
- // vaddvq_f32
- // vmaxvq_f32
- // vcvtnq_s32_f32
- // vzip1_u8
- // vzip2_u8
-
- inline static int32_t vaddvq_s16(int16x8_t v) {
- return
- (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
- (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
- (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
- (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
- }
-
- inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
- int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
- int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
- return vcombine_s16(a0, b0);
- }
-
- inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
- int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
- int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
- return vcombine_s32(a0, b0);
- }
-
- inline static int32_t vaddvq_s32(int32x4_t v) {
- return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
- }
-
- inline static float vaddvq_f32(float32x4_t v) {
- return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
- }
-
- inline static float vmaxvq_f32(float32x4_t v) {
- return
- MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
- MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
- }
-
- inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
- int32x4_t res;
-
- res[0] = roundf(vgetq_lane_f32(v, 0));
- res[1] = roundf(vgetq_lane_f32(v, 1));
- res[2] = roundf(vgetq_lane_f32(v, 2));
- res[3] = roundf(vgetq_lane_f32(v, 3));
-
- return res;
- }
-
- inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
- uint8x8_t res;
-
- res[0] = a[0]; res[1] = b[0];
- res[2] = a[1]; res[3] = b[1];
- res[4] = a[2]; res[5] = b[2];
- res[6] = a[3]; res[7] = b[3];
-
- return res;
- }
-
- inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
- uint8x8_t res;
-
- res[0] = a[4]; res[1] = b[4];
- res[2] = a[5]; res[3] = b[5];
- res[4] = a[6]; res[5] = b[6];
- res[6] = a[7]; res[7] = b[7];
-
- return res;
- }
-
- // vld1q_s16_x2
- // vld1q_u8_x2
- // vld1q_u8_x4
- // vld1q_s8_x2
- // vld1q_s8_x4
- // TODO: double-check these work correctly
-
- typedef struct ggml_int16x8x2_t {
- int16x8_t val[2];
- } ggml_int16x8x2_t;
-
- inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
- ggml_int16x8x2_t res;
-
- res.val[0] = vld1q_s16(ptr + 0);
- res.val[1] = vld1q_s16(ptr + 8);
-
- return res;
- }
-
- typedef struct ggml_uint8x16x2_t {
- uint8x16_t val[2];
- } ggml_uint8x16x2_t;
-
- inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
- ggml_uint8x16x2_t res;
-
- res.val[0] = vld1q_u8(ptr + 0);
- res.val[1] = vld1q_u8(ptr + 16);
-
- return res;
- }
-
- typedef struct ggml_uint8x16x4_t {
- uint8x16_t val[4];
- } ggml_uint8x16x4_t;
-
- inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
- ggml_uint8x16x4_t res;
-
- res.val[0] = vld1q_u8(ptr + 0);
- res.val[1] = vld1q_u8(ptr + 16);
- res.val[2] = vld1q_u8(ptr + 32);
- res.val[3] = vld1q_u8(ptr + 48);
-
- return res;
- }
-
- typedef struct ggml_int8x16x2_t {
- int8x16_t val[2];
- } ggml_int8x16x2_t;
-
- inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
- ggml_int8x16x2_t res;
-
- res.val[0] = vld1q_s8(ptr + 0);
- res.val[1] = vld1q_s8(ptr + 16);
-
- return res;
- }
-
- typedef struct ggml_int8x16x4_t {
- int8x16_t val[4];
- } ggml_int8x16x4_t;
-
- inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
- ggml_int8x16x4_t res;
-
- res.val[0] = vld1q_s8(ptr + 0);
- res.val[1] = vld1q_s8(ptr + 16);
- res.val[2] = vld1q_s8(ptr + 32);
- res.val[3] = vld1q_s8(ptr + 48);
-
- return res;
- }
-
- // NOTE: not tested
- inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
- int8x16_t res;
-
- res[ 0] = a[b[ 0]];
- res[ 1] = a[b[ 1]];
- res[ 2] = a[b[ 2]];
- res[ 3] = a[b[ 3]];
- res[ 4] = a[b[ 4]];
- res[ 5] = a[b[ 5]];
- res[ 6] = a[b[ 6]];
- res[ 7] = a[b[ 7]];
- res[ 8] = a[b[ 8]];
- res[ 9] = a[b[ 9]];
- res[10] = a[b[10]];
- res[11] = a[b[11]];
- res[12] = a[b[12]];
- res[13] = a[b[13]];
- res[14] = a[b[14]];
- res[15] = a[b[15]];
-
- return res;
- }
-
- // NOTE: not tested
- inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
- uint8x16_t res;
-
- res[ 0] = a[b[ 0]];
- res[ 1] = a[b[ 1]];
- res[ 2] = a[b[ 2]];
- res[ 3] = a[b[ 3]];
- res[ 4] = a[b[ 4]];
- res[ 5] = a[b[ 5]];
- res[ 6] = a[b[ 6]];
- res[ 7] = a[b[ 7]];
- res[ 8] = a[b[ 8]];
- res[ 9] = a[b[ 9]];
- res[10] = a[b[10]];
- res[11] = a[b[11]];
- res[12] = a[b[12]];
- res[13] = a[b[13]];
- res[14] = a[b[14]];
- res[15] = a[b[15]];
-
- return res;
- }
-
- #else
-
- #define ggml_int16x8x2_t int16x8x2_t
- #define ggml_uint8x16x2_t uint8x16x2_t
- #define ggml_uint8x16x4_t uint8x16x4_t
- #define ggml_int8x16x2_t int8x16x2_t
- #define ggml_int8x16x4_t int8x16x4_t
-
- #define ggml_vld1q_s16_x2 vld1q_s16_x2
- #define ggml_vld1q_u8_x2 vld1q_u8_x2
- #define ggml_vld1q_u8_x4 vld1q_u8_x4
- #define ggml_vld1q_s8_x2 vld1q_s8_x2
- #define ggml_vld1q_s8_x4 vld1q_s8_x4
- #define ggml_vqtbl1q_s8 vqtbl1q_s8
- #define ggml_vqtbl1q_u8 vqtbl1q_u8
-
- #endif
-
- #if !defined(__ARM_FEATURE_DOTPROD)
-
- inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
- const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
- const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
-
- return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
- }
-
- #else
-
- #define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
-
- #endif
-
- #endif
-
  #if defined(__ARM_NEON) || defined(__wasm_simd128__)
  #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
  #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
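
Context for the removed ARM compatibility block: ggml_vdotq_s32 emulates the AArch64 vdotq_s32 (SDOT) intrinsic, which adds four independent 4-element signed-byte dot products into the 32-bit lanes of the accumulator. A scalar reference, illustrative only and not part of the diff:

#include <stdint.h>

// Scalar reference for vdotq_s32/SDOT: lane i of acc gains the dot product
// of bytes 4*i .. 4*i+3 of a and b.
static void vdotq_s32_scalar_ref(int32_t acc[4], const int8_t a[16], const int8_t b[16]) {
    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 4; j++) {
            acc[i] += (int32_t) a[4*i + j] * (int32_t) b[4*i + j];
        }
    }
}

The vmull_s8/vpaddlq_s16 fallback in the removed block distributes the partial sums across lanes differently, so it matches the real instruction only after the accumulator is horizontally reduced (e.g. with vaddvq_s32), which is how these dot-product accumulators are typically consumed.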
@@ -3154,7 +3154,6 @@ typedef float (*vec_dot_q_mul_mat_sycl_t)(
  #define SYCL_SCALE_BLOCK_SIZE 256
  #define SYCL_CLAMP_BLOCK_SIZE 256
  #define SYCL_ROPE_BLOCK_SIZE 256
- #define SYCL_SOFT_MAX_BLOCK_SIZE 1024
  #define SYCL_ALIBI_BLOCK_SIZE 32
  #define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32
  #define SYCL_QUANTIZE_BLOCK_SIZE 256
@@ -13080,11 +13079,13 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
  const int nrows_y, const float scale, const float max_bias,
  dpct::queue_ptr stream) {
  int nth = WARP_SIZE;
- while (nth < ncols_x && nth < SYCL_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+ int max_block_size = g_work_group_size;
+ while (nth < ncols_x && nth < max_block_size) nth *= 2;
+ if (nth>max_block_size) nth = max_block_size;
+
  const sycl::range<3> block_dims(1, 1, nth);
  const sycl::range<3> block_nums(1, 1, nrows_x);
  const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
- static_assert(SYCL_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");

  const uint32_t n_head_kv = nrows_x/nrows_y;
  const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
@@ -13094,6 +13095,12 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *

  const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
  if (n_local_scratch*sizeof(float) < local_mem_size) {
+ if (ncols_x > max_block_size) {
+ soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, n_local_scratch, stream);
+ return;
+ }
  switch (ncols_x) {
  case 32:
  soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
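
The two soft-max hunks above replace the compile-time SYCL_SOFT_MAX_BLOCK_SIZE cap with the device's runtime limit (g_work_group_size) and route rows wider than that limit to the generic soft_max_f32_submitter<true, 0, 0> specialization instead of the hard-coded sizes. A minimal sketch of just the sizing rule, with a hypothetical helper name:

// Grow the work-group size in powers of two toward the row width, but never
// past the device's maximum work-group size.
static int pick_soft_max_block_size(int ncols_x, int max_block_size, int warp_size) {
    int nth = warp_size;
    while (nth < ncols_x && nth < max_block_size) {
        nth *= 2;
    }
    if (nth > max_block_size) {
        nth = max_block_size;
    }
    return nth;
}

For example, with a 1024-thread limit and a warp size of 32, a 4096-column row yields a work-group size of 1024 and, because 4096 exceeds the cap, is dispatched through the generic specialization above.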
@@ -15989,73 +15996,76 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) {
  static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
  const ggml_tensor *src1,
  ggml_tensor *dst) try {
- #if 0
- ggml_sycl_mul_mat_id_sycl(dst);
- // TODO: mmq/mmv support
- #endif
+ GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT &&
+ "mul_mat_id does not support split buffers");
+ const ggml_tensor *ids = dst->src[2];
+ const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];

- const int64_t nb11 = src1->nb[1];
- const int64_t nb1 = dst->nb[1];
+ const size_t nb11 = src1->nb[1];
+ const size_t nb1 = dst->nb[1];

- const struct ggml_tensor * ids = src0;
- const int32_t id = ((int32_t *) dst->op_params)[0];
- const int32_t n_as = ((int32_t *) dst->op_params)[1];
+ const int32_t id = ((int32_t *)dst->op_params)[0];
+ const int32_t n_as = src0->ne[2];

  std::vector<char> ids_host(ggml_nbytes(ids));
+ const char *ids_dev = (const char *)ids->data;

- const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
-
- if (ids->backend == GGML_BACKEND_TYPE_GPU) {
- const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
- SYCL_CHECK(CHECK_TRY_ERROR(
- stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids)).wait()));
- // SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
- } else {
- memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
- }
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
+ SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));

- const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
- const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
+ const ggml_tensor_extra_gpu *src0_extra =
+ (const ggml_tensor_extra_gpu *)src0->extra;
+ const ggml_tensor_extra_gpu *src1_extra =
+ (const ggml_tensor_extra_gpu *)src1->extra;
+ const ggml_tensor_extra_gpu *dst_extra =
+ (const ggml_tensor_extra_gpu *)dst->extra;

+ ggml_tensor_extra_gpu src0_row_extra;
  ggml_tensor_extra_gpu src1_row_extra;
  ggml_tensor_extra_gpu dst_row_extra;

+ ggml_tensor src0_row = *src0;
  ggml_tensor src1_row = *src1;
  ggml_tensor dst_row = *dst;

  src1_row.backend = GGML_BACKEND_TYPE_GPU;
  dst_row.backend = GGML_BACKEND_TYPE_GPU;

+ src0_row.extra = &src0_row_extra;
  src1_row.extra = &src1_row_extra;
  dst_row.extra = &dst_row_extra;

- char * src1_original = src1->backend == GGML_BACKEND_TYPE_CPU ?
- (char *) src1->data : (char *) src1_extra->data_device[g_main_device];
- char * dst_original = dst->backend == GGML_BACKEND_TYPE_CPU ?
- (char *) dst->data : (char *) dst_extra->data_device[g_main_device];
+ char *src0_original = src1->backend == GGML_BACKEND_TYPE_CPU
+ ? (char *)src0->data
+ : (char *)src0_extra->data_device[g_main_device];
+ char *src1_original = src1->backend == GGML_BACKEND_TYPE_CPU
+ ? (char *)src1->data
+ : (char *)src1_extra->data_device[g_main_device];
+ char *dst_original = dst->backend == GGML_BACKEND_TYPE_CPU
+ ? (char *)dst->data
+ : (char *)dst_extra->data_device[g_main_device];

- if (src1->ne[1] == 1) {
- GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU);
- GGML_ASSERT(dst->backend == GGML_BACKEND_TYPE_GPU);
+ src0_row.ne[2] = 1;
+ src0_row.ne[3] = 1;
+ src0_row.nb[3] = src0->nb[2];

+ if (src1->ne[1] == 1) {
  for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
- //int32_t row_id;
- //SYCL_CHECK(syclMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), syclMemcpyDeviceToHost, g_syclStreams[g_main_device][0]));
- //SYCL_CHECK(syclStreamSynchronize(g_syclStreams[g_main_device][0]));
-
- const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+ const int32_t row_id =
+ *(const int32_t *)(ids_host.data() + i01 * ids->nb[1] +
+ id * ids->nb[0]);

  GGML_ASSERT(row_id >= 0 && row_id < n_as);

- const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-
- src1_row_extra.data_device[g_main_device] = src1_original + i01*src1->nb[1];
- src1_row.data = (char *) src1->data + i01*src1->nb[1]; // TODO why is this set?
+ src0_row_extra.data_device[g_main_device] =
+ src0_original + row_id * src0->nb[2];
+ src1_row_extra.data_device[g_main_device] =
+ src1_original + i01 * src1->nb[1];
+ dst_row_extra.data_device[g_main_device] =
+ dst_original + i01 * dst->nb[1];

- dst_row_extra.data_device[g_main_device] = dst_original + i01*dst->nb[1];
- dst_row.data = (char *) dst->data + i01*dst->nb[1]; // TODO why is this set?
-
- ggml_sycl_mul_mat(src0_row, &src1_row, &dst_row);
+ ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
  }
  } else {
  sycl_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
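
In the rewritten ggml_sycl_mul_mat_id above, the experts are no longer separate tensors reached through dst->src[row_id + 2]: all n_as experts are stacked along dimension 2 of src0, the routing ids come from dst->src[2], and a per-expert view (src0_row, with ne[2] = ne[3] = 1) is pointed at the chosen expert via the byte offset row_id * src0->nb[2]. A small self-contained illustration of that offset arithmetic, using a hypothetical helper name:

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

// With all experts stacked along dimension 2, expert row_id starts
// row_id * nb2 bytes into the buffer, where nb2 is the byte stride between
// consecutive experts.
static const char * expert_base_ptr(const char * src0_data, size_t nb2,
                                    int32_t row_id, int32_t n_as) {
    assert(row_id >= 0 && row_id < n_as); // same bounds check as the kernel
    return src0_data + (size_t) row_id * nb2;
}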
@@ -16065,8 +16075,6 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
  dst_row_extra.data_device[g_main_device] = dst_contiguous.get();

  for (int32_t row_id = 0; row_id < n_as; ++row_id) {
- const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-
  int64_t num_src1_rows = 0;
  for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
  const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
@@ -16079,7 +16087,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,

  SYCL_CHECK(CHECK_TRY_ERROR(
  stream->memcpy(src1_contiguous.get() + num_src1_rows * nb11,
- src1_original + i01 * nb11, nb11).wait()));
+ src1_original + i01 * nb11, nb11)));
  num_src1_rows++;
  }

@@ -16087,6 +16095,9 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
  continue;
  }

+ src0_row_extra.data_device[g_main_device] =
+ src0_original + row_id * src0->nb[2];
+
  src1_row.ne[1] = num_src1_rows;
  dst_row.ne[1] = num_src1_rows;

@@ -16098,7 +16109,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
  dst_row.nb[2] = num_src1_rows*nb1;
  dst_row.nb[3] = num_src1_rows*nb1;

- ggml_sycl_mul_mat(src0_row, &src1_row, &dst_row);
+ ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);

  num_src1_rows = 0;
  for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
@@ -16112,7 +16123,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,

  SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
  dst_original + i01 * nb1,
- dst_contiguous.get() + num_src1_rows * nb1, nb1).wait()));
+ dst_contiguous.get() + num_src1_rows * nb1, nb1)));
  num_src1_rows++;
  }
  }
@@ -16814,11 +16825,13 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
  const dpct::queue_ptr stream = g_syclStreams[ctx->device][0];
  SYCL_CHECK(
  CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
-
+ char* host_buf = (char*)malloc(size);
+ memcpy(host_buf, data, size);
  SYCL_CHECK(
  CHECK_TRY_ERROR((*stream)
- .memcpy((char *)tensor->data + offset, data, size)
+ .memcpy((char *)tensor->data + offset, host_buf, size)
  .wait()));
+ free(host_buf);
  }
  catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
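
The set_tensor change above stages the incoming data in a freshly allocated host buffer before the device copy, presumably so the queue copies from plain host memory that remains valid for the whole transfer. A generic sketch of that staging pattern, with an invented callback standing in for the enqueued SYCL memcpy:

#include <stdlib.h>
#include <string.h>

typedef void (*device_copy_fn)(void * dst_dev, const void * src_host, size_t size);

// Stage the incoming data in a temporary host buffer, run the (blocking)
// device copy, then release the staging buffer.
static void staged_device_copy(device_copy_fn do_copy, void * dst_dev,
                               const void * data, size_t size) {
    char * host_buf = (char *) malloc(size);
    if (host_buf == NULL) {
        return; // out of memory; real code should surface the error
    }
    memcpy(host_buf, data, size);
    do_copy(dst_dev, host_buf, size);
    free(host_buf);
}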
@@ -17739,7 +17752,7 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons

  GGML_CALL static bool ggml_backend_sycl_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
  const int min_batch_size = 32;
- return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
+ return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID;
  GGML_UNUSED(backend);
  }