llama_cpp 0.14.5 → 0.14.7

This diff shows the changes between two publicly released versions of this package, as published to their respective public registries. It is provided for informational purposes only.
@@ -14,47 +14,6 @@
  #include <stdlib.h> // for qsort
  #include <stdio.h> // for GGML_ASSERT
 
- #ifdef __ARM_NEON
-
- // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
- //
- // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
- //
- #include <arm_neon.h>
-
- #else
-
- #ifdef __wasm_simd128__
- #include <wasm_simd128.h>
- #else
- #if defined(__POWER9_VECTOR__) || defined(__powerpc64__)
- #include <altivec.h>
- #undef bool
- #define bool _Bool
- #else
- #if defined(_MSC_VER) || defined(__MINGW32__)
- #include <intrin.h>
- #else
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
- #if !defined(__riscv)
- #include <immintrin.h>
- #endif
- #endif
- #endif
- #endif
- #endif
- #endif
-
- #ifdef __riscv_v_intrinsic
- #include <riscv_vector.h>
- #endif
-
- #undef MIN
- #undef MAX
-
- #define MIN(a, b) ((a) < (b) ? (a) : (b))
- #define MAX(a, b) ((a) > (b) ? (a) : (b))
-
  #define UNUSED GGML_UNUSED
 
  // some compilers don't provide _mm256_set_m128i, e.g. gcc 7
@@ -132,7 +91,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
  }
 
  static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
- #if defined(__AVXVNNI__) || defined(__AVX512VNNI__)
+ #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
      const __m256i zero = _mm256_setzero_si256();
      const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
      return _mm256_cvtepi32_ps(summed_pairs);
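
The tightened guard above matters because _mm256_dpbusd_epi32 operates on 256-bit registers: when the instruction comes from AVX512VNNI, its 256-bit (EVEX) encoding is only available if AVX512VL is present as well, while the standalone AVX-VNNI extension provides it on its own. A compile-time restatement of that condition (illustrative sketch, not part of the diff; HAVE_DPBUSD_256 is a hypothetical name):

    /* Only claim the 256-bit VNNI dot-product path when the compiler can
       actually emit it: standalone AVX-VNNI, or AVX512VNNI combined with
       AVX512VL (which enables EVEX instructions at 256-bit width). */
    #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
    #define HAVE_DPBUSD_256 1
    #else
    #define HAVE_DPBUSD_256 0
    #endif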
@@ -276,258 +235,6 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
  #endif // __AVX__ || __AVX2__ || __AVX512F__
  #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 
- #if defined(__ARM_NEON)
-
- #ifdef _MSC_VER
-
- #define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
-
- #else
-
- #define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
-
- #endif
-
- #if !defined(__aarch64__)
-
- // 64-bit compatibility
-
- // vaddvq_s16
- // vpaddq_s16
- // vpaddq_s32
- // vaddvq_s32
- // vaddvq_f32
- // vmaxvq_f32
- // vcvtnq_s32_f32
- // vzip1_u8
- // vzip2_u8
-
- inline static int32_t vaddvq_s16(int16x8_t v) {
-     return
-         (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
-         (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
-         (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
-         (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
- }
-
- inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
-     int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
-     int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
-     return vcombine_s16(a0, b0);
- }
-
- inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
-     int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
-     int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
-     return vcombine_s32(a0, b0);
- }
-
- inline static int32_t vaddvq_s32(int32x4_t v) {
-     return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
- }
-
- inline static float vaddvq_f32(float32x4_t v) {
-     return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
- }
-
- inline static float vmaxvq_f32(float32x4_t v) {
-     return
-         MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
-             MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
- }
-
- inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
-     int32x4_t res;
-
-     res[0] = roundf(vgetq_lane_f32(v, 0));
-     res[1] = roundf(vgetq_lane_f32(v, 1));
-     res[2] = roundf(vgetq_lane_f32(v, 2));
-     res[3] = roundf(vgetq_lane_f32(v, 3));
-
-     return res;
- }
-
- inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
-     uint8x8_t res;
-
-     res[0] = a[0]; res[1] = b[0];
-     res[2] = a[1]; res[3] = b[1];
-     res[4] = a[2]; res[5] = b[2];
-     res[6] = a[3]; res[7] = b[3];
-
-     return res;
- }
-
- inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
-     uint8x8_t res;
-
-     res[0] = a[4]; res[1] = b[4];
-     res[2] = a[5]; res[3] = b[5];
-     res[4] = a[6]; res[5] = b[6];
-     res[6] = a[7]; res[7] = b[7];
-
-     return res;
- }
-
- // vld1q_s16_x2
- // vld1q_u8_x2
- // vld1q_u8_x4
- // vld1q_s8_x2
- // vld1q_s8_x4
- // TODO: double-check these work correctly
-
- typedef struct ggml_int16x8x2_t {
-     int16x8_t val[2];
- } ggml_int16x8x2_t;
-
- inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
-     ggml_int16x8x2_t res;
-
-     res.val[0] = vld1q_s16(ptr + 0);
-     res.val[1] = vld1q_s16(ptr + 8);
-
-     return res;
- }
-
- typedef struct ggml_uint8x16x2_t {
-     uint8x16_t val[2];
- } ggml_uint8x16x2_t;
-
- inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
-     ggml_uint8x16x2_t res;
-
-     res.val[0] = vld1q_u8(ptr + 0);
-     res.val[1] = vld1q_u8(ptr + 16);
-
-     return res;
- }
-
- typedef struct ggml_uint8x16x4_t {
-     uint8x16_t val[4];
- } ggml_uint8x16x4_t;
-
- inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
-     ggml_uint8x16x4_t res;
-
-     res.val[0] = vld1q_u8(ptr + 0);
-     res.val[1] = vld1q_u8(ptr + 16);
-     res.val[2] = vld1q_u8(ptr + 32);
-     res.val[3] = vld1q_u8(ptr + 48);
-
-     return res;
- }
-
- typedef struct ggml_int8x16x2_t {
-     int8x16_t val[2];
- } ggml_int8x16x2_t;
-
- inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
-     ggml_int8x16x2_t res;
-
-     res.val[0] = vld1q_s8(ptr + 0);
-     res.val[1] = vld1q_s8(ptr + 16);
-
-     return res;
- }
-
- typedef struct ggml_int8x16x4_t {
-     int8x16_t val[4];
- } ggml_int8x16x4_t;
-
- inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
-     ggml_int8x16x4_t res;
-
-     res.val[0] = vld1q_s8(ptr + 0);
-     res.val[1] = vld1q_s8(ptr + 16);
-     res.val[2] = vld1q_s8(ptr + 32);
-     res.val[3] = vld1q_s8(ptr + 48);
-
-     return res;
- }
-
- // NOTE: not tested
- inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
-     int8x16_t res;
-
-     res[ 0] = a[b[ 0]];
-     res[ 1] = a[b[ 1]];
-     res[ 2] = a[b[ 2]];
-     res[ 3] = a[b[ 3]];
-     res[ 4] = a[b[ 4]];
-     res[ 5] = a[b[ 5]];
-     res[ 6] = a[b[ 6]];
-     res[ 7] = a[b[ 7]];
-     res[ 8] = a[b[ 8]];
-     res[ 9] = a[b[ 9]];
-     res[10] = a[b[10]];
-     res[11] = a[b[11]];
-     res[12] = a[b[12]];
-     res[13] = a[b[13]];
-     res[14] = a[b[14]];
-     res[15] = a[b[15]];
-
-     return res;
- }
-
- // NOTE: not tested
- inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
-     uint8x16_t res;
-
-     res[ 0] = a[b[ 0]];
-     res[ 1] = a[b[ 1]];
-     res[ 2] = a[b[ 2]];
-     res[ 3] = a[b[ 3]];
-     res[ 4] = a[b[ 4]];
-     res[ 5] = a[b[ 5]];
-     res[ 6] = a[b[ 6]];
-     res[ 7] = a[b[ 7]];
-     res[ 8] = a[b[ 8]];
-     res[ 9] = a[b[ 9]];
-     res[10] = a[b[10]];
-     res[11] = a[b[11]];
-     res[12] = a[b[12]];
-     res[13] = a[b[13]];
-     res[14] = a[b[14]];
-     res[15] = a[b[15]];
-
-     return res;
- }
-
- #else
-
- #define ggml_int16x8x2_t int16x8x2_t
- #define ggml_uint8x16x2_t uint8x16x2_t
- #define ggml_uint8x16x4_t uint8x16x4_t
- #define ggml_int8x16x2_t int8x16x2_t
- #define ggml_int8x16x4_t int8x16x4_t
-
- #define ggml_vld1q_s16_x2 vld1q_s16_x2
- #define ggml_vld1q_u8_x2 vld1q_u8_x2
- #define ggml_vld1q_u8_x4 vld1q_u8_x4
- #define ggml_vld1q_s8_x2 vld1q_s8_x2
- #define ggml_vld1q_s8_x4 vld1q_s8_x4
- #define ggml_vqtbl1q_s8 vqtbl1q_s8
- #define ggml_vqtbl1q_u8 vqtbl1q_u8
-
- #endif
-
- #if !defined(__ARM_FEATURE_DOTPROD)
-
- inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
-     const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
-     const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
-
-     return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
- }
-
- #else
-
- #define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
-
- #endif
-
- #endif
-
  #if defined(__ARM_NEON) || defined(__wasm_simd128__)
  #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
  #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
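
Worth noting among the removed NEON shims: on targets without __ARM_FEATURE_DOTPROD, ggml_vdotq_s32 emulated the dot-product instruction with widening multiplies (vmull_s8) followed by pairwise widening adds (vpaddlq_s16). The emulation does not reproduce vdotq_s32 lane for lane; it sums different pairs of byte products into each 32-bit lane, but the four-lane total is identical, which is all the quantized dot-product kernels need once they reduce with vaddvq_s32. A scalar model of what the fallback actually computes (illustrative, not from the diff):

    #include <stdint.h>

    /* Scalar model of the removed ggml_vdotq_s32 fallback. Lane i
       accumulates the products of byte pairs (2i, 2i+1) from the low
       half and (8+2i, 8+2i+1) from the high half; a true vdotq_s32
       would instead sum bytes 4i..4i+3 into lane i. */
    static void vdotq_s32_fallback_model(int32_t acc[4],
                                         const int8_t a[16], const int8_t b[16]) {
        int32_t p[16];
        for (int i = 0; i < 16; ++i) {
            p[i] = (int32_t) a[i] * (int32_t) b[i];
        }
        for (int i = 0; i < 4; ++i) {
            acc[i] += (p[2*i] + p[2*i + 1]) + (p[8 + 2*i] + p[8 + 2*i + 1]);
        }
    }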
@@ -3154,7 +3154,6 @@ typedef float (*vec_dot_q_mul_mat_sycl_t)(
  #define SYCL_SCALE_BLOCK_SIZE 256
  #define SYCL_CLAMP_BLOCK_SIZE 256
  #define SYCL_ROPE_BLOCK_SIZE 256
- #define SYCL_SOFT_MAX_BLOCK_SIZE 1024
  #define SYCL_ALIBI_BLOCK_SIZE 32
  #define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32
  #define SYCL_QUANTIZE_BLOCK_SIZE 256
@@ -13080,11 +13079,13 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
                                const int nrows_y, const float scale, const float max_bias,
                                dpct::queue_ptr stream) {
      int nth = WARP_SIZE;
-     while (nth < ncols_x && nth < SYCL_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+     int max_block_size = g_work_group_size;
+     while (nth < ncols_x && nth < max_block_size) nth *= 2;
+     if (nth>max_block_size) nth = max_block_size;
+
      const sycl::range<3> block_dims(1, 1, nth);
      const sycl::range<3> block_nums(1, 1, nrows_x);
      const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
-     static_assert(SYCL_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
      const uint32_t n_head_kv = nrows_x/nrows_y;
      const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
@@ -13094,6 +13095,12 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
 
      const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
      if (n_local_scratch*sizeof(float) < local_mem_size) {
+         if (ncols_x > max_block_size) {
+             soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                                max_bias, m0, m1, n_head_log2, block_nums,
+                                                block_dims, n_local_scratch, stream);
+             return;
+         }
          switch (ncols_x) {
              case 32:
                  soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
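
Taken together, the three soft-max hunks swap the compile-time SYCL_SOFT_MAX_BLOCK_SIZE cap (and its static_assert) for a runtime limit (g_work_group_size, presumably the device's maximum work-group size), and route rows wider than one work-group to the generic <true, 0, 0> submitter. On hardware whose limit is below the old hard-coded 1024, the previous code could request an unsupported launch size. A self-contained sketch of the same clamping pattern, assuming a SYCL queue q (the helper name is illustrative, not ggml's):

    #include <sycl/sycl.hpp>

    // Pick a power-of-two work-group size that covers ncols but never
    // exceeds what the device actually supports.
    static size_t pick_block_size(sycl::queue &q, size_t ncols, size_t warp = 32) {
        const size_t max_wg =
            q.get_device().get_info<sycl::info::device::max_work_group_size>();
        size_t nth = warp;
        while (nth < ncols && nth < max_wg) nth *= 2;
        return nth > max_wg ? max_wg : nth;
    }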
@@ -15989,73 +15996,76 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) {
  static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
                                   const ggml_tensor *src1,
                                   ggml_tensor *dst) try {
- #if 0
-     ggml_sycl_mul_mat_id_sycl(dst);
-     // TODO: mmq/mmv support
- #endif
+     GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT &&
+                 "mul_mat_id does not support split buffers");
+     const ggml_tensor *ids = dst->src[2];
+     const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
 
-     const int64_t nb11 = src1->nb[1];
-     const int64_t nb1 = dst->nb[1];
+     const size_t nb11 = src1->nb[1];
+     const size_t nb1 = dst->nb[1];
 
-     const struct ggml_tensor * ids = src0;
-     const int32_t id = ((int32_t *) dst->op_params)[0];
-     const int32_t n_as = ((int32_t *) dst->op_params)[1];
+     const int32_t id = ((int32_t *)dst->op_params)[0];
+     const int32_t n_as = src0->ne[2];
 
      std::vector<char> ids_host(ggml_nbytes(ids));
+     const char *ids_dev = (const char *)ids->data;
 
-     const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
-
-     if (ids->backend == GGML_BACKEND_TYPE_GPU) {
-         const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
-         SYCL_CHECK(CHECK_TRY_ERROR(
-             stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids)).wait()));
-         // SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
-     } else {
-         memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
-     }
+     SYCL_CHECK(CHECK_TRY_ERROR(
+         stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
+     SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
 
-     const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
-     const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
+     const ggml_tensor_extra_gpu *src0_extra =
+         (const ggml_tensor_extra_gpu *)src0->extra;
+     const ggml_tensor_extra_gpu *src1_extra =
+         (const ggml_tensor_extra_gpu *)src1->extra;
+     const ggml_tensor_extra_gpu *dst_extra =
+         (const ggml_tensor_extra_gpu *)dst->extra;
 
+     ggml_tensor_extra_gpu src0_row_extra;
      ggml_tensor_extra_gpu src1_row_extra;
      ggml_tensor_extra_gpu dst_row_extra;
 
+     ggml_tensor src0_row = *src0;
      ggml_tensor src1_row = *src1;
      ggml_tensor dst_row = *dst;
 
      src1_row.backend = GGML_BACKEND_TYPE_GPU;
      dst_row.backend = GGML_BACKEND_TYPE_GPU;
 
+     src0_row.extra = &src0_row_extra;
      src1_row.extra = &src1_row_extra;
      dst_row.extra = &dst_row_extra;
 
-     char * src1_original = src1->backend == GGML_BACKEND_TYPE_CPU ?
-         (char *) src1->data : (char *) src1_extra->data_device[g_main_device];
-     char * dst_original = dst->backend == GGML_BACKEND_TYPE_CPU ?
-         (char *) dst->data : (char *) dst_extra->data_device[g_main_device];
+     char *src0_original = src1->backend == GGML_BACKEND_TYPE_CPU
+                               ? (char *)src0->data
+                               : (char *)src0_extra->data_device[g_main_device];
+     char *src1_original = src1->backend == GGML_BACKEND_TYPE_CPU
+                               ? (char *)src1->data
+                               : (char *)src1_extra->data_device[g_main_device];
+     char *dst_original = dst->backend == GGML_BACKEND_TYPE_CPU
+                              ? (char *)dst->data
+                              : (char *)dst_extra->data_device[g_main_device];
 
-     if (src1->ne[1] == 1) {
-         GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU);
-         GGML_ASSERT(dst->backend == GGML_BACKEND_TYPE_GPU);
+     src0_row.ne[2] = 1;
+     src0_row.ne[3] = 1;
+     src0_row.nb[3] = src0->nb[2];
 
+     if (src1->ne[1] == 1) {
          for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-             //int32_t row_id;
-             //SYCL_CHECK(syclMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), syclMemcpyDeviceToHost, g_syclStreams[g_main_device][0]));
-             //SYCL_CHECK(syclStreamSynchronize(g_syclStreams[g_main_device][0]));
-
-             const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+             const int32_t row_id =
+                 *(const int32_t *)(ids_host.data() + i01 * ids->nb[1] +
+                                    id * ids->nb[0]);
 
              GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
-             const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-
-             src1_row_extra.data_device[g_main_device] = src1_original + i01*src1->nb[1];
-             src1_row.data = (char *) src1->data + i01*src1->nb[1]; // TODO why is this set?
+             src0_row_extra.data_device[g_main_device] =
+                 src0_original + row_id * src0->nb[2];
+             src1_row_extra.data_device[g_main_device] =
+                 src1_original + i01 * src1->nb[1];
+             dst_row_extra.data_device[g_main_device] =
+                 dst_original + i01 * dst->nb[1];
 
-             dst_row_extra.data_device[g_main_device] = dst_original + i01*dst->nb[1];
-             dst_row.data = (char *) dst->data + i01*dst->nb[1]; // TODO why is this set?
-
-             ggml_sycl_mul_mat(src0_row, &src1_row, &dst_row);
+             ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
          }
      } else {
          sycl_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
@@ -16065,8 +16075,6 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
          dst_row_extra.data_device[g_main_device] = dst_contiguous.get();
 
          for (int32_t row_id = 0; row_id < n_as; ++row_id) {
-             const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-
              int64_t num_src1_rows = 0;
              for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
                  const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
@@ -16079,7 +16087,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
 
                  SYCL_CHECK(CHECK_TRY_ERROR(
                      stream->memcpy(src1_contiguous.get() + num_src1_rows * nb11,
-                                    src1_original + i01 * nb11, nb11).wait()));
+                                    src1_original + i01 * nb11, nb11)));
                  num_src1_rows++;
              }
 
@@ -16087,6 +16095,9 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
                  continue;
              }
 
+             src0_row_extra.data_device[g_main_device] =
+                 src0_original + row_id * src0->nb[2];
+
              src1_row.ne[1] = num_src1_rows;
              dst_row.ne[1] = num_src1_rows;
 
@@ -16098,7 +16109,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
              dst_row.nb[2] = num_src1_rows*nb1;
              dst_row.nb[3] = num_src1_rows*nb1;
 
-             ggml_sycl_mul_mat(src0_row, &src1_row, &dst_row);
+             ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
 
              num_src1_rows = 0;
              for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
@@ -16112,7 +16123,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
 
                  SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
                      dst_original + i01 * nb1,
-                     dst_contiguous.get() + num_src1_rows * nb1, nb1).wait()));
+                     dst_contiguous.get() + num_src1_rows * nb1, nb1)));
                  num_src1_rows++;
              }
          }
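
The through-line of the mul_mat_id hunks: expert weights are no longer separate tensors pulled from dst->src[row_id + 2]. The routing tensor now comes from dst->src[2], and all experts live in a single 3-D src0, so selecting expert row_id is plain pointer arithmetic; its weight matrix starts row_id * src0->nb[2] bytes into the buffer, viewed through a temporary src0_row whose ne[2]/ne[3] are forced to 1. The per-copy .wait() calls in the gather/scatter loops are also dropped, presumably relying on the in-order SYCL queue to sequence the memcpys before the matmul that consumes them. A minimal model of the new expert addressing (hypothetical helper, not ggml API):

    #include <stddef.h>
    #include <stdint.h>

    /* All n_as experts are stacked along dimension 2 of src0, so the
       weight matrix of expert row_id begins row_id * nb2 bytes into the
       buffer, where nb2 plays the role of src0->nb[2] in the diff. */
    static inline char * expert_base(char * src0_data, size_t nb2, int32_t row_id) {
        return src0_data + (size_t) row_id * nb2;
    }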
@@ -16814,11 +16825,13 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
      const dpct::queue_ptr stream = g_syclStreams[ctx->device][0];
      SYCL_CHECK(
          CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
-
+     char* host_buf = (char*)malloc(size);
+     memcpy(host_buf, data, size);
      SYCL_CHECK(
          CHECK_TRY_ERROR((*stream)
-                             .memcpy((char *)tensor->data + offset, data, size)
+                             .memcpy((char *)tensor->data + offset, host_buf, size)
                              .wait()));
+     free(host_buf);
  }
  catch (sycl::exception const &exc) {
      std::cerr << exc.what() << "Exception caught at file:" << __FILE__
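
The set_tensor path now snapshots the caller's buffer into an owned host allocation before handing it to the SYCL memcpy, presumably so the copy does not depend on the caller's pointer remaining valid and accessible to the runtime for the duration of the transfer. A self-contained sketch of the pattern (the function name is illustrative, not ggml's):

    #include <sycl/sycl.hpp>
    #include <cstdlib>
    #include <cstring>

    // Staged host-to-device copy: snapshot the source into a plain
    // malloc'd buffer, run the device memcpy synchronously, then free.
    static void set_tensor_staged(sycl::queue &q, void *dev_dst,
                                  const void *data, size_t size) {
        char *host_buf = (char *) std::malloc(size);
        std::memcpy(host_buf, data, size);          // snapshot caller memory
        q.memcpy(dev_dst, host_buf, size).wait();   // copy completes here
        std::free(host_buf);                        // safe after the wait
    }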
@@ -17739,7 +17752,7 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
 
  GGML_CALL static bool ggml_backend_sycl_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
      const int min_batch_size = 32;
-     return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
+     return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID;
      GGML_UNUSED(backend);
  }