llama_cpp 0.14.7 → 0.15.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +19 -0
- data/README.md +2 -2
- data/ext/llama_cpp/extconf.rb +2 -1
- data/ext/llama_cpp/llama_cpp.cpp +59 -9
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +24 -3
- data/vendor/tmp/llama.cpp/Makefile +42 -18
- data/vendor/tmp/llama.cpp/ggml-backend.c +7 -5
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +295 -17
- data/vendor/tmp/llama.cpp/ggml-impl.h +78 -1
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +7 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +399 -184
- data/vendor/tmp/llama.cpp/ggml-metal.metal +654 -18
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +1 -0
- data/vendor/tmp/llama.cpp/ggml-quants.c +302 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +28 -16
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +46843 -39205
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +951 -263
- data/vendor/tmp/llama.cpp/ggml.c +1457 -92
- data/vendor/tmp/llama.cpp/ggml.h +37 -7
- data/vendor/tmp/llama.cpp/llama.cpp +671 -403
- data/vendor/tmp/llama.cpp/llama.h +34 -10
- data/vendor/tmp/llama.cpp/sgemm.cpp +134 -103
- data/vendor/tmp/llama.cpp/sgemm.h +4 -2
- data/vendor/tmp/llama.cpp/unicode-data.cpp +1188 -656
- data/vendor/tmp/llama.cpp/unicode-data.h +4 -3
- data/vendor/tmp/llama.cpp/unicode.cpp +590 -49
- data/vendor/tmp/llama.cpp/unicode.h +6 -3
- metadata +3 -3
data/vendor/tmp/llama.cpp/ggml-opencl.cpp:

```diff
@@ -2119,6 +2119,7 @@ static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_
     if (alignment == (cl_uint)-1) {
         ggml_cl_init();
         clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &alignment, NULL);
+        alignment /= 8; // bits to bytes
     }
     return alignment;
```
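The one-line fix accounts for `CL_DEVICE_MEM_BASE_ADDR_ALIGN` being defined in bits, while ggml-backend treats buffer alignment as a byte count. A standalone sketch of the same query, assuming an OpenCL SDK is available (platform/device discovery and the lack of error handling are illustrative, not part of the gem's code):

```cpp
// Query the base address alignment of the first OpenCL device and convert
// it from bits to bytes, mirroring the fix above.
#include <CL/cl.h>
#include <cstdio>

int main() {
    cl_platform_id platform;
    cl_device_id   device;
    clGetPlatformIDs(1, &platform, NULL);
    clGetDeviceIDs(platform, CL_DEVICE_TYPE_DEFAULT, 1, &device, NULL);

    cl_uint alignment = 0;
    clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &alignment, NULL);
    std::printf("alignment: %u bits = %u bytes\n", alignment, alignment / 8);
}
```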
data/vendor/tmp/llama.cpp/ggml-quants.c:

```diff
@@ -12383,3 +12383,305 @@ void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k)
     block_iq2_s * restrict y = vy;
     quantize_row_iq2_s_reference(x, y, k);
 }
+
+static bool validate_float(float f, size_t i) {
+    if (isinf(f)) {
+        fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
+        return false;
+    }
+
+    if (isnan(f)) {
+        fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i);
+        return false;
+    }
+
+    return true;
+}
+
+static bool isinf_fp16(ggml_fp16_t f) {
+    return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) == 0;
+}
+
+static bool isnan_fp16(ggml_fp16_t f) {
+    return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) != 0;
+}
+
+static bool validate_fp16(ggml_fp16_t f, size_t i) {
+    if (isinf_fp16(f)) {
+        fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
+        return false;
+    }
+
+    if (isnan_fp16(f)) {
+        fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i);
+        return false;
+    }
+
+    return true;
+}
+
+#define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
+    const type * q = (const type *) (data); \
+    for (size_t i = 0; i < (nb); ++i) { \
+        if (!validate_fp16(q[i].d, i)) { \
+            return false; \
+        } \
+    }
+
+#define VALIDATE_ROW_DATA_DM_F16_IMPL(type, data, nb, d, m) \
+    const type * q = (const type *) (data); \
+    for (size_t i = 0; i < (nb); ++i) { \
+        if (!validate_fp16(q[i].d, i) || !validate_fp16(q[i].m, i)) { \
+            return false; \
+        } \
+    }
+
+bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) {
+    if (type < 0 || type >= GGML_TYPE_COUNT) {
+        fprintf(stderr, "%s: invalid type %d\n", __func__, type);
+        return false;
+    }
+
+    if (nbytes % ggml_type_size(type) != 0) {
+        fprintf(stderr, "%s: invalid size %zu for type %d\n", __func__, nbytes, type);
+        return false;
+    }
+
+    const size_t nb = nbytes/ggml_type_size(type);
+
+    switch (type) {
+        case GGML_TYPE_BF16:
+            {
+                int nans = 0;
+                int infs = 0;
+                const unsigned short * f = (const unsigned short *) data;
+                for (size_t i = 0; i < nb; ++i) {
+                    nans += (f[i] & 0x7fff) > 0x7f80;
+                    infs += (f[i] & 0x7fff) == 0x7f80;
+                }
+                if (nans) {
+                    fprintf(stderr, "%s: found %d NaNs in row of %zu BF16 values\n", __func__, nans, nb);
+                    return false;
+                }
+                if (infs) {
+                    fprintf(stderr, "%s: found %d infinities in row of %zu BF16 values\n", __func__, infs, nb);
+                    return false;
+                }
+            } break;
+        case GGML_TYPE_F16:
+            {
+                const ggml_fp16_t * f = (const ggml_fp16_t *) data;
+                size_t i = 0;
+#if defined(__AVX2__)
+                for (; i + 15 < nb; i += 16) {
+                    __m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
+                    __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi16(0x7c00));
+                    __m256i cmp = _mm256_cmpeq_epi16(vexp, _mm256_set1_epi16(0x7c00));
+                    int mask = _mm256_movemask_epi8(cmp);
+                    if (mask) {
+                        for (size_t j = 0; j < 16; ++j) {
+                            if (!validate_fp16(f[i + j], i + j)) {
+                                return false;
+                            }
+                        }
+                        GGML_UNREACHABLE();
+                    }
+                }
+#elif defined(__ARM_NEON)
+                for (; i + 7 < nb; i += 8) {
+                    uint16x8_t v = vld1q_u16(f + i);
+                    uint16x8_t vexp = vandq_u16(v, vdupq_n_u16(0x7c00));
+                    uint16x8_t cmp = vceqq_u16(vexp, vdupq_n_u16(0x7c00));
+                    uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(cmp, 4)), 0);
+                    if (mask) {
+                        for (size_t j = 0; j < 8; ++j) {
+                            if (!validate_fp16(f[i + j], i + j)) {
+                                return false;
+                            }
+                        }
+                        GGML_UNREACHABLE();
+                    }
+                }
+#endif
+                for (; i < nb; ++i) {
+                    if (!validate_fp16(f[i], i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_F32:
+            {
+                const float * f = (const float *) data;
+                size_t i = 0;
+#if defined(__AVX2__)
+                for (; i + 7 < nb; i += 8) {
+                    __m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
+                    __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi32(0x7f800000));
+                    __m256i cmp = _mm256_cmpeq_epi32(vexp, _mm256_set1_epi32(0x7f800000));
+                    int mask = _mm256_movemask_epi8(cmp);
+                    if (mask) {
+                        for (size_t j = 0; j < 8; ++j) {
+                            if (!validate_float(f[i + j], i + j)) {
+                                return false;
+                            }
+                        }
+                        GGML_UNREACHABLE();
+                    }
+                }
+#elif defined(__ARM_NEON)
+                for (; i + 3 < nb; i += 4) {
+                    uint32x4_t v = vld1q_u32((const uint32_t *)f + i);
+                    uint32x4_t vexp = vandq_u32(v, vdupq_n_u32(0x7f800000));
+                    uint32x4_t cmp = vceqq_u32(vexp, vdupq_n_u32(0x7f800000));
+                    uint64_t mask = vget_lane_u64(vreinterpret_u64_u16(vshrn_n_u32(cmp, 8)), 0);
+                    if (mask) {
+                        for (size_t j = 0; j < 4; ++j) {
+                            if (!validate_float(f[i + j], i + j)) {
+                                return false;
+                            }
+                        }
+                        GGML_UNREACHABLE();
+                    }
+                }
+#endif
+                for (; i < nb; ++i) {
+                    if (!validate_float(f[i], i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_F64:
+            {
+                const double * f = (const double *) data;
+                for (size_t i = 0; i < nb; ++i) {
+                    if (!validate_float(f[i], i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_Q4_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_1, data, nb, d, m);
+            } break;
+        case GGML_TYPE_Q5_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_0, data, nb);
+            } break;
+        case GGML_TYPE_Q5_1:
+            {
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_1, data, nb, d, m);
+            } break;
+        case GGML_TYPE_Q8_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
+            } break;
+        case GGML_TYPE_Q2_K:
+            {
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
+            } break;
+        case GGML_TYPE_Q3_K:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q3_K, data, nb);
+            } break;
+        case GGML_TYPE_Q4_K:
+            {
+            #ifdef GGML_QKK_64
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d[0], d[1]);
+            #else
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d, dmin);
+            #endif
+            } break;
+        case GGML_TYPE_Q5_K:
+            {
+            #ifdef GGML_QKK_64
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_K, data, nb);
+            #else
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_K, data, nb, d, dmin);
+            #endif
+            } break;
+        case GGML_TYPE_Q6_K:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q6_K, data, nb);
+            } break;
+        case GGML_TYPE_Q8_K:
+            {
+                const block_q8_K * q = (const block_q8_K *) data;
+                for (size_t i = 0; i < nb; ++i) {
+                    if (!validate_float(q[i].d, i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_IQ1_S:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb);
+            } break;
+        case GGML_TYPE_IQ1_M:
+            {
+                const block_iq1_m * q = (const block_iq1_m *) data;
+                for (size_t i = 0; i < nb; ++i) {
+                #if QK_K == 64
+                    if (!validate_fp16(q[i].d, i)) {
+                        return false;
+                    }
+                #else
+                    iq1m_scale_t scale;
+                    const uint16_t * sc = (const uint16_t *)q[i].scales;
+                    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+                    if (!validate_fp16(scale.f16, i)) {
+                        return false;
+                    }
+                #endif
+                }
+            } break;
+        case GGML_TYPE_IQ2_XXS:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xxs, data, nb);
+            } break;
+        case GGML_TYPE_IQ2_XS:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xs, data, nb);
+            } break;
+        case GGML_TYPE_IQ2_S:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_s, data, nb);
+            } break;
+        case GGML_TYPE_IQ3_XXS:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_xxs, data, nb);
+            } break;
+
+        case GGML_TYPE_IQ3_S:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_s, data, nb);
+            } break;
+        case GGML_TYPE_IQ4_XS:
+        #if QK_K != 64
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_xs, data, nb);
+            } break;
+        #endif
+            // with QK_K == 64, iq4_xs is iq4_nl
+        case GGML_TYPE_IQ4_NL:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_I64:
+            // nothing to validate
+            break;
+        default:
+            {
+                fprintf(stderr, "%s: invalid type %d\n", __func__, type);
+                return false;
+            }
+    }
+
+    return true;
+}
```
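The fp16 helpers in this hunk classify values purely by bit pattern: an IEEE-754 half stores a 5-bit exponent in bits 10–14, and an all-ones exponent (mask `0x7c00`) marks a non-finite value, with the 10-bit mantissa distinguishing inf (zero) from NaN (non-zero). A standalone sketch of the same tests, with a plain `uint16_t` standing in for `ggml_fp16_t` since both are raw 16-bit patterns:

```cpp
#include <cstdint>
#include <cstdio>

using fp16_bits = std::uint16_t; // stand-in for ggml_fp16_t's raw bit pattern

static bool isinf_fp16(fp16_bits f) { return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) == 0; }
static bool isnan_fp16(fp16_bits f) { return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) != 0; }

int main() {
    const fp16_bits samples[] = {
        0x3c00, // 1.0: finite exponent
        0x7c00, // +inf: exponent all ones, mantissa zero
        0xfc00, // -inf: the sign bit is outside both masks
        0x7e00, // a quiet NaN: exponent all ones, mantissa non-zero
    };
    for (fp16_bits s : samples) {
        std::printf("0x%04x: inf=%d nan=%d\n", s, isinf_fp16(s), isnan_fp16(s));
    }
}
```

The SIMD paths in the hunk apply the same `0x7c00` exponent mask sixteen (AVX2) or eight (NEON) lanes at a time and only fall back to the scalar check to report which element failed, which is why the fallback loop ends in `GGML_UNREACHABLE()`.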
data/vendor/tmp/llama.cpp/ggml-sycl.cpp (removed lines are truncated in the registry diff view):

```diff
@@ -8330,24 +8330,26 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
     const int blocks_per_row = ncols / qk;
     const int blocks_per_warp = vdr * WARP_SIZE / qi;
 
-    //
+    const int qi_vdr = (qi / vdr); // N_threads processing 1 qk block
+
+    // partial sum for each thread
     float tmp = 0.0f;
 
     const block_q_t * x = (const block_q_t *) vx;
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
-    for (int i = item_ct1.get_local_id(2) /
+    for (int i = item_ct1.get_local_id(2) / qi_vdr; i < blocks_per_row;
          i += blocks_per_warp) {
-
+        const int ibx = row * blocks_per_row + i; // x block index
 
-
+        const int iby = i * (qk / QK8_1); // y block index that aligns with ibx
 
-
-
-
-
+        const int iqs =
+            vdr *
+            (item_ct1.get_local_id(2) -
+             i * qi_vdr); // x block quant index when casting the quants to int
 
-
+        tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
     }
 
     // sum up partial sums and write back result
```
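The rewritten loop makes the thread-to-data mapping explicit: `qi_vdr = qi / vdr` threads cooperate on one qk-wide quant block, so a thread with lane id `t` starts on block `t / qi_vdr` and reads `vdr` packed ints beginning at quant index `vdr * (t - i0 * qi_vdr)`, i.e. `vdr * (t % qi_vdr)`. A host-side sketch of that arithmetic for the first warp iteration (the qk/qi/vdr constants are illustrative picks, not taken from the diff):

```cpp
#include <cstdio>

int main() {
    // Illustrative constants: a 32-wide quant block (qk), qi = 4 packed ints
    // of quants per block, vdr = 2 ints consumed per dot-product call.
    const int qk = 32, qi = 4, vdr = 2, QK8_1 = 32;
    const int qi_vdr = qi / vdr; // threads cooperating on one qk block

    for (int t = 0; t < 4; ++t) { // first few lanes of a warp
        const int i0  = t / qi_vdr;              // x block the lane starts on
        const int iby = i0 * (qk / QK8_1);       // y (q8_1) block aligned with it
        const int iqs = vdr * (t - i0 * qi_vdr); // first quant index for the lane
        std::printf("lane %d -> x block %d, y block %d, quant index %d\n", t, i0, iby, iqs);
    }
}
```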
data/vendor/tmp/llama.cpp/ggml-sycl.cpp (removed lines are truncated in the registry diff view):

```diff
@@ -13416,11 +13418,16 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
     version += std::to_string(prop.get_minor_version());
 
     device_type = std::regex_replace(device_type, std::regex("ext_oneapi_"), "");
+    std::string name = std::string(prop.get_name());
+    name = std::regex_replace(name, std::regex("\\(R\\)"), "");
+    name = std::regex_replace(name, std::regex("\\(TM\\)"), "");
 
-
-
+    auto global_mem_size = prop.get_global_mem_size()/1000000;
+
+    fprintf(stderr, "|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
+            name.c_str(), version.c_str(), prop.get_max_compute_units(),
             prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
-
+            global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
 }
 
 void ggml_backend_sycl_print_sycl_devices() {
```
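The name cleanup is two plain `std::regex_replace` passes that strip the `(R)`/`(TM)` marks devices report, keeping the fixed-width name column readable. A standalone sketch (the sample device name is invented):

```cpp
#include <iostream>
#include <regex>
#include <string>

int main() {
    std::string name = "Intel(R) Arc(TM) A770 Graphics"; // invented sample
    name = std::regex_replace(name, std::regex("\\(R\\)"), "");
    name = std::regex_replace(name, std::regex("\\(TM\\)"), "");
    std::cout << name << '\n'; // "Intel Arc A770 Graphics"
}
```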
data/vendor/tmp/llama.cpp/ggml-sycl.cpp (removed lines are truncated in the registry diff view; padding in the added header strings is reconstructed from the separator row and the row format string):

```diff
@@ -13428,9 +13435,10 @@ void ggml_backend_sycl_print_sycl_devices() {
     int device_count = dpct::dev_mgr::instance().device_count();
     std::map<std::string, size_t> DeviceNums;
     fprintf(stderr, "found %d SYCL devices:\n", device_count);
-    fprintf(stderr, "| |
-    fprintf(stderr, "|
-    fprintf(stderr, "
+    fprintf(stderr, "|  |                   |                                       |       |Max    |        |Max  |Global |                     |\n");
+    fprintf(stderr, "|  |                   |                                       |       |compute|Max work|sub  |mem    |                     |\n");
+    fprintf(stderr, "|ID|        Device Type|                                   Name|Version|units  |group   |group|size   |       Driver version|\n");
+    fprintf(stderr, "|--|-------------------|---------------------------------------|-------|-------|--------|-----|-------|---------------------|\n");
     for (int id = 0; id < device_count; ++id) {
         sycl::device device = dpct::dev_mgr::instance().get_device(id);
         sycl::backend backend = device.get_backend();
```
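Each device row is then printed by `print_device_detail` with the fixed-width format string from the previous hunk; its field widths (2, 19, 39, 7, 7, 8, 5, 6+`M`, 21) line up with the dash runs in the separator row. A sketch printing one such row (only the format string is taken from the diff; all device values are invented):

```cpp
#include <cstdio>

int main() {
    // Invented sample values for one device row.
    std::printf("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n",
                0, "[opencl:gpu]", "Intel Arc A770 Graphics", "3.0",
                512, 1024, 32, 16225UL, "23.17.26241.33");
}
```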
data/vendor/tmp/llama.cpp/ggml-sycl.cpp:

```diff
@@ -14738,7 +14746,12 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
+    const ggml_tensor * src2 = dst->src[2];
+
+#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
 
     const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
```
The companion change hoists the `src2` lookup out of the function body (it now happens next to the type asserts above):

```diff
@@ -14754,7 +14767,6 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
     float * src2_dd = nullptr;
     sycl_pool_alloc<float> src2_f;
 
-    ggml_tensor * src2 = dst->src[2];
     const bool use_src2 = src2 != nullptr;
 
     if (use_src2) {
```