llama_cpp 0.14.7 → 0.15.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +19 -0
- data/README.md +2 -2
- data/ext/llama_cpp/extconf.rb +2 -1
- data/ext/llama_cpp/llama_cpp.cpp +59 -9
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +24 -3
- data/vendor/tmp/llama.cpp/Makefile +42 -18
- data/vendor/tmp/llama.cpp/ggml-backend.c +7 -5
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +295 -17
- data/vendor/tmp/llama.cpp/ggml-impl.h +78 -1
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +7 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +399 -184
- data/vendor/tmp/llama.cpp/ggml-metal.metal +654 -18
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +1 -0
- data/vendor/tmp/llama.cpp/ggml-quants.c +302 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +28 -16
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +46843 -39205
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +951 -263
- data/vendor/tmp/llama.cpp/ggml.c +1457 -92
- data/vendor/tmp/llama.cpp/ggml.h +37 -7
- data/vendor/tmp/llama.cpp/llama.cpp +671 -403
- data/vendor/tmp/llama.cpp/llama.h +34 -10
- data/vendor/tmp/llama.cpp/sgemm.cpp +134 -103
- data/vendor/tmp/llama.cpp/sgemm.h +4 -2
- data/vendor/tmp/llama.cpp/unicode-data.cpp +1188 -656
- data/vendor/tmp/llama.cpp/unicode-data.h +4 -3
- data/vendor/tmp/llama.cpp/unicode.cpp +590 -49
- data/vendor/tmp/llama.cpp/unicode.h +6 -3
- metadata +3 -3
@@ -2119,6 +2119,7 @@ static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_
|
|
2119
2119
|
if (alignment == (cl_uint)-1) {
|
2120
2120
|
ggml_cl_init();
|
2121
2121
|
clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &alignment, NULL);
|
2122
|
+
alignment /= 8; // bits to bytes
|
2122
2123
|
}
|
2123
2124
|
return alignment;
|
2124
2125
|
|
@@ -12383,3 +12383,305 @@ void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k)
|
|
12383
12383
|
block_iq2_s * restrict y = vy;
|
12384
12384
|
quantize_row_iq2_s_reference(x, y, k);
|
12385
12385
|
}
|
12386
|
+
|
12387
|
+
// Reports whether a single float value is usable row data.
// Returns true for any finite value; for inf or NaN it prints a diagnostic
// naming index `i` to stderr and returns false.
static bool validate_float(float f, size_t i) {
    if (!isinf(f) && !isnan(f)) {
        return true;
    }

    if (isinf(f)) {
        fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
    } else {
        fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i);
    }
    return false;
}
|
12400
|
+
|
12401
|
+
// Bit-level infinity test on an fp16 bit pattern: the exponent field
// (mask 0x7c00) is all ones and the mantissa field (mask 0x03ff) is zero.
static bool isinf_fp16(ggml_fp16_t f) {
    const bool exponent_saturated = (f & 0x7c00) == 0x7c00;
    const bool mantissa_empty     = (f & 0x03ff) == 0;
    return exponent_saturated && mantissa_empty;
}
|
12404
|
+
|
12405
|
+
// Bit-level NaN test on an fp16 bit pattern: the exponent field (mask 0x7c00)
// is all ones and the mantissa field (mask 0x03ff) is non-zero.
static bool isnan_fp16(ggml_fp16_t f) {
    if ((f & 0x7c00) != 0x7c00) {
        return false; // exponent not saturated: ordinary (or zero/subnormal) value
    }
    return (f & 0x03ff) != 0; // saturated exponent with payload bits set
}
|
12408
|
+
|
12409
|
+
// fp16 counterpart of validate_float: returns true when the fp16 bit pattern
// `f` is neither inf nor NaN; otherwise prints a diagnostic naming index `i`
// to stderr and returns false.
static bool validate_fp16(ggml_fp16_t f, size_t i) {
    const bool is_inf = isinf_fp16(f);
    if (!is_inf && !isnan_fp16(f)) {
        return true;
    }

    if (is_inf) {
        fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
    } else {
        fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i);
    }
    return false;
}
|
12422
|
+
|
12423
|
+
// Statement-level helper for ggml_validate_row_data: declares a local `q`
// pointing at `nb` blocks of `type` stored at `data`, then checks each block's
// fp16 `d` field with validate_fp16. The first inf/NaN makes the ENCLOSING
// function return false. Because it declares `q`, expand it inside its own braces.
#define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
    const type * q = (const type *) (data); \
    for (size_t ib = 0; ib < (nb); ++ib) { \
        const bool scale_ok = validate_fp16(q[ib].d, ib); \
        if (!scale_ok) { \
            return false; \
        } \
    }
|
12430
|
+
|
12431
|
+
// Like VALIDATE_ROW_DATA_D_F16_IMPL, but checks two fp16 members per block.
// `d` and `m` are the member designators to check (e.g. `d, dmin` or
// `d[0], d[1]`); they are substituted after `q[ib].`. The second member is only
// checked when the first is valid. Any inf/NaN makes the ENCLOSING function
// return false. Declares a local `q`, so expand it inside its own braces.
#define VALIDATE_ROW_DATA_DM_F16_IMPL(type, data, nb, d, m) \
    const type * q = (const type *) (data); \
    for (size_t ib = 0; ib < (nb); ++ib) { \
        const bool scales_ok = validate_fp16(q[ib].d, ib) && validate_fp16(q[ib].m, ib); \
        if (!scales_ok) { \
            return false; \
        } \
    }
|
12438
|
+
|
12439
|
+
bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) {
|
12440
|
+
if (type < 0 || type >= GGML_TYPE_COUNT) {
|
12441
|
+
fprintf(stderr, "%s: invalid type %d\n", __func__, type);
|
12442
|
+
return false;
|
12443
|
+
}
|
12444
|
+
|
12445
|
+
if (nbytes % ggml_type_size(type) != 0) {
|
12446
|
+
fprintf(stderr, "%s: invalid size %zu for type %d\n", __func__, nbytes, type);
|
12447
|
+
return false;
|
12448
|
+
}
|
12449
|
+
|
12450
|
+
const size_t nb = nbytes/ggml_type_size(type);
|
12451
|
+
|
12452
|
+
switch (type) {
|
12453
|
+
case GGML_TYPE_BF16:
|
12454
|
+
{
|
12455
|
+
int nans = 0;
|
12456
|
+
int infs = 0;
|
12457
|
+
const unsigned short * f = (const unsigned short *) data;
|
12458
|
+
for (size_t i = 0; i < nb; ++i) {
|
12459
|
+
nans += (f[i] & 0x7fff) > 0x7f80;
|
12460
|
+
infs += (f[i] & 0x7fff) == 0x7f80;
|
12461
|
+
}
|
12462
|
+
if (nans) {
|
12463
|
+
fprintf(stderr, "%s: found %d NaNs in row of %zu BF16 values\n", __func__, nans, nb);
|
12464
|
+
return false;
|
12465
|
+
}
|
12466
|
+
if (infs) {
|
12467
|
+
fprintf(stderr, "%s: found %d infinities in row of %zu BF16 values\n", __func__, infs, nb);
|
12468
|
+
return false;
|
12469
|
+
}
|
12470
|
+
} break;
|
12471
|
+
case GGML_TYPE_F16:
|
12472
|
+
{
|
12473
|
+
const ggml_fp16_t * f = (const ggml_fp16_t *) data;
|
12474
|
+
size_t i = 0;
|
12475
|
+
#if defined(__AVX2__)
|
12476
|
+
for (; i + 15 < nb; i += 16) {
|
12477
|
+
__m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
|
12478
|
+
__m256i vexp = _mm256_and_si256(v, _mm256_set1_epi16(0x7c00));
|
12479
|
+
__m256i cmp = _mm256_cmpeq_epi16(vexp, _mm256_set1_epi16(0x7c00));
|
12480
|
+
int mask = _mm256_movemask_epi8(cmp);
|
12481
|
+
if (mask) {
|
12482
|
+
for (size_t j = 0; j < 16; ++j) {
|
12483
|
+
if (!validate_fp16(f[i + j], i + j)) {
|
12484
|
+
return false;
|
12485
|
+
}
|
12486
|
+
}
|
12487
|
+
GGML_UNREACHABLE();
|
12488
|
+
}
|
12489
|
+
}
|
12490
|
+
#elif defined(__ARM_NEON)
|
12491
|
+
for (; i + 7 < nb; i += 8) {
|
12492
|
+
uint16x8_t v = vld1q_u16(f + i);
|
12493
|
+
uint16x8_t vexp = vandq_u16(v, vdupq_n_u16(0x7c00));
|
12494
|
+
uint16x8_t cmp = vceqq_u16(vexp, vdupq_n_u16(0x7c00));
|
12495
|
+
uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(cmp, 4)), 0);
|
12496
|
+
if (mask) {
|
12497
|
+
for (size_t j = 0; j < 8; ++j) {
|
12498
|
+
if (!validate_fp16(f[i + j], i + j)) {
|
12499
|
+
return false;
|
12500
|
+
}
|
12501
|
+
}
|
12502
|
+
GGML_UNREACHABLE();
|
12503
|
+
}
|
12504
|
+
}
|
12505
|
+
#endif
|
12506
|
+
for (; i < nb; ++i) {
|
12507
|
+
if (!validate_fp16(f[i], i)) {
|
12508
|
+
return false;
|
12509
|
+
}
|
12510
|
+
}
|
12511
|
+
} break;
|
12512
|
+
case GGML_TYPE_F32:
|
12513
|
+
{
|
12514
|
+
const float * f = (const float *) data;
|
12515
|
+
size_t i = 0;
|
12516
|
+
#if defined(__AVX2__)
|
12517
|
+
for (; i + 7 < nb; i += 8) {
|
12518
|
+
__m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
|
12519
|
+
__m256i vexp = _mm256_and_si256(v, _mm256_set1_epi32(0x7f800000));
|
12520
|
+
__m256i cmp = _mm256_cmpeq_epi32(vexp, _mm256_set1_epi32(0x7f800000));
|
12521
|
+
int mask = _mm256_movemask_epi8(cmp);
|
12522
|
+
if (mask) {
|
12523
|
+
for (size_t j = 0; j < 8; ++j) {
|
12524
|
+
if (!validate_float(f[i + j], i + j)) {
|
12525
|
+
return false;
|
12526
|
+
}
|
12527
|
+
}
|
12528
|
+
GGML_UNREACHABLE();
|
12529
|
+
}
|
12530
|
+
}
|
12531
|
+
#elif defined(__ARM_NEON)
|
12532
|
+
for (; i + 3 < nb; i += 4) {
|
12533
|
+
uint32x4_t v = vld1q_u32((const uint32_t *)f + i);
|
12534
|
+
uint32x4_t vexp = vandq_u32(v, vdupq_n_u32(0x7f800000));
|
12535
|
+
uint32x4_t cmp = vceqq_u32(vexp, vdupq_n_u32(0x7f800000));
|
12536
|
+
uint64_t mask = vget_lane_u64(vreinterpret_u64_u16(vshrn_n_u32(cmp, 8)), 0);
|
12537
|
+
if (mask) {
|
12538
|
+
for (size_t j = 0; j < 4; ++j) {
|
12539
|
+
if (!validate_float(f[i + j], i + j)) {
|
12540
|
+
return false;
|
12541
|
+
}
|
12542
|
+
}
|
12543
|
+
GGML_UNREACHABLE();
|
12544
|
+
}
|
12545
|
+
}
|
12546
|
+
#endif
|
12547
|
+
for (; i < nb; ++i) {
|
12548
|
+
if (!validate_float(f[i], i)) {
|
12549
|
+
return false;
|
12550
|
+
}
|
12551
|
+
}
|
12552
|
+
} break;
|
12553
|
+
case GGML_TYPE_F64:
|
12554
|
+
{
|
12555
|
+
const double * f = (const double *) data;
|
12556
|
+
for (size_t i = 0; i < nb; ++i) {
|
12557
|
+
if (!validate_float(f[i], i)) {
|
12558
|
+
return false;
|
12559
|
+
}
|
12560
|
+
}
|
12561
|
+
} break;
|
12562
|
+
case GGML_TYPE_Q4_0:
|
12563
|
+
{
|
12564
|
+
VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb);
|
12565
|
+
} break;
|
12566
|
+
case GGML_TYPE_Q4_1:
|
12567
|
+
{
|
12568
|
+
VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_1, data, nb, d, m);
|
12569
|
+
} break;
|
12570
|
+
case GGML_TYPE_Q5_0:
|
12571
|
+
{
|
12572
|
+
VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_0, data, nb);
|
12573
|
+
} break;
|
12574
|
+
case GGML_TYPE_Q5_1:
|
12575
|
+
{
|
12576
|
+
VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_1, data, nb, d, m);
|
12577
|
+
} break;
|
12578
|
+
case GGML_TYPE_Q8_0:
|
12579
|
+
{
|
12580
|
+
VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
|
12581
|
+
} break;
|
12582
|
+
case GGML_TYPE_Q2_K:
|
12583
|
+
{
|
12584
|
+
VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
|
12585
|
+
} break;
|
12586
|
+
case GGML_TYPE_Q3_K:
|
12587
|
+
{
|
12588
|
+
VALIDATE_ROW_DATA_D_F16_IMPL(block_q3_K, data, nb);
|
12589
|
+
} break;
|
12590
|
+
case GGML_TYPE_Q4_K:
|
12591
|
+
{
|
12592
|
+
#ifdef GGML_QKK_64
|
12593
|
+
VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d[0], d[1]);
|
12594
|
+
#else
|
12595
|
+
VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d, dmin);
|
12596
|
+
#endif
|
12597
|
+
} break;
|
12598
|
+
case GGML_TYPE_Q5_K:
|
12599
|
+
{
|
12600
|
+
#ifdef GGML_QKK_64
|
12601
|
+
VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_K, data, nb);
|
12602
|
+
#else
|
12603
|
+
VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_K, data, nb, d, dmin);
|
12604
|
+
#endif
|
12605
|
+
} break;
|
12606
|
+
case GGML_TYPE_Q6_K:
|
12607
|
+
{
|
12608
|
+
VALIDATE_ROW_DATA_D_F16_IMPL(block_q6_K, data, nb);
|
12609
|
+
} break;
|
12610
|
+
case GGML_TYPE_Q8_K:
|
12611
|
+
{
|
12612
|
+
const block_q8_K * q = (const block_q8_K *) data;
|
12613
|
+
for (size_t i = 0; i < nb; ++i) {
|
12614
|
+
if (!validate_float(q[i].d, i)) {
|
12615
|
+
return false;
|
12616
|
+
}
|
12617
|
+
}
|
12618
|
+
} break;
|
12619
|
+
case GGML_TYPE_IQ1_S:
|
12620
|
+
{
|
12621
|
+
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb);
|
12622
|
+
} break;
|
12623
|
+
case GGML_TYPE_IQ1_M:
|
12624
|
+
{
|
12625
|
+
const block_iq1_m * q = (const block_iq1_m *) data;
|
12626
|
+
for (size_t i = 0; i < nb; ++i) {
|
12627
|
+
#if QK_K == 64
|
12628
|
+
if (!validate_fp16(q[i].d, i)) {
|
12629
|
+
return false;
|
12630
|
+
}
|
12631
|
+
#else
|
12632
|
+
iq1m_scale_t scale;
|
12633
|
+
const uint16_t * sc = (const uint16_t *)q[i].scales;
|
12634
|
+
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
12635
|
+
if (!validate_fp16(scale.f16, i)) {
|
12636
|
+
return false;
|
12637
|
+
}
|
12638
|
+
#endif
|
12639
|
+
}
|
12640
|
+
} break;
|
12641
|
+
case GGML_TYPE_IQ2_XXS:
|
12642
|
+
{
|
12643
|
+
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xxs, data, nb);
|
12644
|
+
} break;
|
12645
|
+
case GGML_TYPE_IQ2_XS:
|
12646
|
+
{
|
12647
|
+
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xs, data, nb);
|
12648
|
+
} break;
|
12649
|
+
case GGML_TYPE_IQ2_S:
|
12650
|
+
{
|
12651
|
+
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_s, data, nb);
|
12652
|
+
} break;
|
12653
|
+
case GGML_TYPE_IQ3_XXS:
|
12654
|
+
{
|
12655
|
+
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_xxs, data, nb);
|
12656
|
+
} break;
|
12657
|
+
|
12658
|
+
case GGML_TYPE_IQ3_S:
|
12659
|
+
{
|
12660
|
+
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_s, data, nb);
|
12661
|
+
} break;
|
12662
|
+
case GGML_TYPE_IQ4_XS:
|
12663
|
+
#if QK_K != 64
|
12664
|
+
{
|
12665
|
+
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_xs, data, nb);
|
12666
|
+
} break;
|
12667
|
+
#endif
|
12668
|
+
// with QK_K == 64, iq4_xs is iq4_nl
|
12669
|
+
case GGML_TYPE_IQ4_NL:
|
12670
|
+
{
|
12671
|
+
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
|
12672
|
+
} break;
|
12673
|
+
case GGML_TYPE_I8:
|
12674
|
+
case GGML_TYPE_I16:
|
12675
|
+
case GGML_TYPE_I32:
|
12676
|
+
case GGML_TYPE_I64:
|
12677
|
+
// nothing to validate
|
12678
|
+
break;
|
12679
|
+
default:
|
12680
|
+
{
|
12681
|
+
fprintf(stderr, "%s: invalid type %d\n", __func__, type);
|
12682
|
+
return false;
|
12683
|
+
}
|
12684
|
+
}
|
12685
|
+
|
12686
|
+
return true;
|
12687
|
+
}
|
@@ -8330,24 +8330,26 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
|
|
8330
8330
|
const int blocks_per_row = ncols / qk;
|
8331
8331
|
const int blocks_per_warp = vdr * WARP_SIZE / qi;
|
8332
8332
|
|
8333
|
-
//
|
8333
|
+
const int qi_vdr = (qi / vdr); // N_threads processing 1 qk block
|
8334
|
+
|
8335
|
+
// partial sum for each thread
|
8334
8336
|
float tmp = 0.0f;
|
8335
8337
|
|
8336
8338
|
const block_q_t * x = (const block_q_t *) vx;
|
8337
8339
|
const block_q8_1 * y = (const block_q8_1 *) vy;
|
8338
8340
|
|
8339
|
-
for (int i = item_ct1.get_local_id(2) /
|
8341
|
+
for (int i = item_ct1.get_local_id(2) / qi_vdr; i < blocks_per_row;
|
8340
8342
|
i += blocks_per_warp) {
|
8341
|
-
|
8343
|
+
const int ibx = row * blocks_per_row + i; // x block index
|
8342
8344
|
|
8343
|
-
|
8345
|
+
const int iby = i * (qk / QK8_1); // y block index that aligns with ibx
|
8344
8346
|
|
8345
|
-
|
8346
|
-
|
8347
|
-
|
8348
|
-
|
8347
|
+
const int iqs =
|
8348
|
+
vdr *
|
8349
|
+
(item_ct1.get_local_id(2) -
|
8350
|
+
i * qi_vdr); // x block quant index when casting the quants to int
|
8349
8351
|
|
8350
|
-
|
8352
|
+
tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
|
8351
8353
|
}
|
8352
8354
|
|
8353
8355
|
// sum up partial sums and write back result
|
@@ -13416,11 +13418,16 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
|
|
13416
13418
|
version += std::to_string(prop.get_minor_version());
|
13417
13419
|
|
13418
13420
|
device_type = std::regex_replace(device_type, std::regex("ext_oneapi_"), "");
|
13421
|
+
std::string name = std::string(prop.get_name());
|
13422
|
+
name = std::regex_replace(name, std::regex("\\(R\\)"), "");
|
13423
|
+
name = std::regex_replace(name, std::regex("\\(TM\\)"), "");
|
13419
13424
|
|
13420
|
-
|
13421
|
-
|
13425
|
+
auto global_mem_size = prop.get_global_mem_size()/1000000;
|
13426
|
+
|
13427
|
+
fprintf(stderr, "|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
|
13428
|
+
name.c_str(), version.c_str(), prop.get_max_compute_units(),
|
13422
13429
|
prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
|
13423
|
-
|
13430
|
+
global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
|
13424
13431
|
}
|
13425
13432
|
|
13426
13433
|
void ggml_backend_sycl_print_sycl_devices() {
|
@@ -13428,9 +13435,10 @@ void ggml_backend_sycl_print_sycl_devices() {
|
|
13428
13435
|
int device_count = dpct::dev_mgr::instance().device_count();
|
13429
13436
|
std::map<std::string, size_t> DeviceNums;
|
13430
13437
|
fprintf(stderr, "found %d SYCL devices:\n", device_count);
|
13431
|
-
fprintf(stderr, "| |
|
13432
|
-
fprintf(stderr, "|
|
13433
|
-
fprintf(stderr, "
|
13438
|
+
fprintf(stderr, "| | | | |Max | |Max |Global | |\n");
|
13439
|
+
fprintf(stderr, "| | | | |compute|Max work|sub |mem | |\n");
|
13440
|
+
fprintf(stderr, "|ID| Device Type| Name|Version|units |group |group|size | Driver version|\n");
|
13441
|
+
fprintf(stderr, "|--|-------------------|---------------------------------------|-------|-------|--------|-----|-------|---------------------|\n");
|
13434
13442
|
for (int id = 0; id < device_count; ++id) {
|
13435
13443
|
sycl::device device = dpct::dev_mgr::instance().get_device(id);
|
13436
13444
|
sycl::backend backend = device.get_backend();
|
@@ -14738,7 +14746,12 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
|
|
14738
14746
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
14739
14747
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
14740
14748
|
|
14749
|
+
const ggml_tensor * src2 = dst->src[2];
|
14750
|
+
|
14751
|
+
#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
|
14752
|
+
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
|
14741
14753
|
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
14754
|
+
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
|
14742
14755
|
|
14743
14756
|
const int64_t ne00 = src0->ne[0];
|
14744
14757
|
const int64_t nrows_x = ggml_nrows(src0);
|
@@ -14754,7 +14767,6 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
|
|
14754
14767
|
float * src2_dd = nullptr;
|
14755
14768
|
sycl_pool_alloc<float> src2_f;
|
14756
14769
|
|
14757
|
-
ggml_tensor * src2 = dst->src[2];
|
14758
14770
|
const bool use_src2 = src2 != nullptr;
|
14759
14771
|
|
14760
14772
|
if (use_src2) {
|