llama_cpp 0.15.1 → 0.15.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/vendor/tmp/llama.cpp/Makefile +3 -3
- data/vendor/tmp/llama.cpp/ggml-backend.c +2 -3
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +15 -7
- data/vendor/tmp/llama.cpp/ggml-impl.h +7 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +9 -3
- data/vendor/tmp/llama.cpp/ggml-metal.m +114 -125
- data/vendor/tmp/llama.cpp/ggml-metal.metal +86 -109
- data/vendor/tmp/llama.cpp/ggml-quants.c +2202 -28
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +1032 -0
- data/vendor/tmp/llama.cpp/ggml-rpc.h +24 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +24 -143
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +4 -2
- data/vendor/tmp/llama.cpp/ggml.c +726 -646
- data/vendor/tmp/llama.cpp/ggml.h +28 -17
- data/vendor/tmp/llama.cpp/llama.cpp +478 -281
- data/vendor/tmp/llama.cpp/llama.h +3 -0
- data/vendor/tmp/llama.cpp/unicode-data.cpp +6969 -2169
- data/vendor/tmp/llama.cpp/unicode-data.h +15 -12
- data/vendor/tmp/llama.cpp/unicode.cpp +89 -111
- data/vendor/tmp/llama.cpp/unicode.h +44 -12
- metadata +4 -2
@@ -229,6 +229,13 @@ kernel void kernel_relu(
|
|
229
229
|
dst[tpig] = max(0.0f, src0[tpig]);
|
230
230
|
}
|
231
231
|
|
232
|
+
kernel void kernel_sigmoid(
|
233
|
+
device const float * src0,
|
234
|
+
device float * dst,
|
235
|
+
uint tpig[[thread_position_in_grid]]) {
|
236
|
+
dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
|
237
|
+
}
|
238
|
+
|
232
239
|
kernel void kernel_tanh(
|
233
240
|
device const float * src0,
|
234
241
|
device float * dst,
|
@@ -356,7 +363,6 @@ template<typename T>
|
|
356
363
|
kernel void kernel_soft_max(
|
357
364
|
device const char * src0,
|
358
365
|
device const char * src1,
|
359
|
-
device const char * src2,
|
360
366
|
device char * dst,
|
361
367
|
constant int64_t & ne00,
|
362
368
|
constant int64_t & ne01,
|
@@ -378,10 +384,9 @@ kernel void kernel_soft_max(
|
|
378
384
|
|
379
385
|
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
380
386
|
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
|
381
|
-
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
|
382
387
|
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
383
388
|
|
384
|
-
float slope =
|
389
|
+
float slope = 1.0f;
|
385
390
|
|
386
391
|
// ALiBi
|
387
392
|
if (max_bias > 0.0f) {
|
@@ -397,7 +402,7 @@ kernel void kernel_soft_max(
|
|
397
402
|
float lmax = -INFINITY;
|
398
403
|
|
399
404
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
400
|
-
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ?
|
405
|
+
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
401
406
|
}
|
402
407
|
|
403
408
|
// find the max value in the block
|
@@ -422,7 +427,7 @@ kernel void kernel_soft_max(
|
|
422
427
|
// parallel sum
|
423
428
|
float lsum = 0.0f;
|
424
429
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
425
|
-
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ?
|
430
|
+
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
426
431
|
lsum += exp_psrc0;
|
427
432
|
pdst[i00] = exp_psrc0;
|
428
433
|
}
|
@@ -461,7 +466,6 @@ template<typename T>
|
|
461
466
|
kernel void kernel_soft_max_4(
|
462
467
|
device const char * src0,
|
463
468
|
device const char * src1,
|
464
|
-
device const char * src2,
|
465
469
|
device char * dst,
|
466
470
|
constant int64_t & ne00,
|
467
471
|
constant int64_t & ne01,
|
@@ -483,10 +487,9 @@ kernel void kernel_soft_max_4(
|
|
483
487
|
|
484
488
|
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
485
489
|
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
|
486
|
-
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
|
487
490
|
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
488
491
|
|
489
|
-
float slope =
|
492
|
+
float slope = 1.0f;
|
490
493
|
|
491
494
|
if (max_bias > 0.0f) {
|
492
495
|
const int64_t h = i02;
|
@@ -501,7 +504,7 @@ kernel void kernel_soft_max_4(
|
|
501
504
|
float4 lmax4 = -INFINITY;
|
502
505
|
|
503
506
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
504
|
-
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ?
|
507
|
+
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
505
508
|
}
|
506
509
|
|
507
510
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
@@ -527,7 +530,7 @@ kernel void kernel_soft_max_4(
|
|
527
530
|
// parallel sum
|
528
531
|
float4 lsum4 = 0.0f;
|
529
532
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
530
|
-
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ?
|
533
|
+
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
531
534
|
lsum4 += exp_psrc4;
|
532
535
|
pdst4[i00] = exp_psrc4;
|
533
536
|
}
|
@@ -1595,60 +1598,6 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
1595
1598
|
}
|
1596
1599
|
}
|
1597
1600
|
|
1598
|
-
kernel void kernel_alibi_f32(
|
1599
|
-
device const float * src0,
|
1600
|
-
device float * dst,
|
1601
|
-
constant int64_t & ne00,
|
1602
|
-
constant int64_t & ne01,
|
1603
|
-
constant int64_t & ne02,
|
1604
|
-
constant int64_t & ne03,
|
1605
|
-
constant uint64_t & nb00,
|
1606
|
-
constant uint64_t & nb01,
|
1607
|
-
constant uint64_t & nb02,
|
1608
|
-
constant uint64_t & nb03,
|
1609
|
-
constant int64_t & ne0,
|
1610
|
-
constant int64_t & ne1,
|
1611
|
-
constant int64_t & ne2,
|
1612
|
-
constant int64_t & ne3,
|
1613
|
-
constant uint64_t & nb0,
|
1614
|
-
constant uint64_t & nb1,
|
1615
|
-
constant uint64_t & nb2,
|
1616
|
-
constant uint64_t & nb3,
|
1617
|
-
constant float & m0,
|
1618
|
-
constant float & m1,
|
1619
|
-
constant int & n_heads_log2_floor,
|
1620
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
1621
|
-
uint3 tpitg[[thread_position_in_threadgroup]],
|
1622
|
-
uint3 ntg[[threads_per_threadgroup]]) {
|
1623
|
-
const int64_t i03 = tgpig[2];
|
1624
|
-
const int64_t i02 = tgpig[1];
|
1625
|
-
const int64_t i01 = tgpig[0];
|
1626
|
-
|
1627
|
-
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
1628
|
-
|
1629
|
-
const int64_t i3 = n / (ne2*ne1*ne0);
|
1630
|
-
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
1631
|
-
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
1632
|
-
//const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
1633
|
-
|
1634
|
-
const int64_t k = i3*ne3 + i2;
|
1635
|
-
|
1636
|
-
float m_k;
|
1637
|
-
if (k < n_heads_log2_floor) {
|
1638
|
-
m_k = pow(m0, k + 1);
|
1639
|
-
} else {
|
1640
|
-
m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
|
1641
|
-
}
|
1642
|
-
|
1643
|
-
device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
|
1644
|
-
device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
1645
|
-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
1646
|
-
const float src_v = *(device float *)(src_row + i00*nb00);
|
1647
|
-
device float * dst_v = (device float *)(dst_row + i00*nb0);
|
1648
|
-
*dst_v = i00 * m_k + src_v;
|
1649
|
-
}
|
1650
|
-
}
|
1651
|
-
|
1652
1601
|
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
1653
1602
|
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
1654
1603
|
return 1.0f - min(1.0f, max(0.0f, y));
|
@@ -1903,7 +1852,10 @@ kernel void kernel_upscale_f32(
|
|
1903
1852
|
constant uint64_t & nb1,
|
1904
1853
|
constant uint64_t & nb2,
|
1905
1854
|
constant uint64_t & nb3,
|
1906
|
-
constant
|
1855
|
+
constant float & sf0,
|
1856
|
+
constant float & sf1,
|
1857
|
+
constant float & sf2,
|
1858
|
+
constant float & sf3,
|
1907
1859
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1908
1860
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
1909
1861
|
uint3 ntg[[threads_per_threadgroup]]) {
|
@@ -1912,15 +1864,17 @@ kernel void kernel_upscale_f32(
|
|
1912
1864
|
const int64_t i2 = tgpig.y;
|
1913
1865
|
const int64_t i1 = tgpig.x;
|
1914
1866
|
|
1915
|
-
const int64_t i03 = i3;
|
1916
|
-
const int64_t i02 = i2;
|
1917
|
-
const int64_t i01 = i1/
|
1918
|
-
|
1919
|
-
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
1920
|
-
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
|
1867
|
+
const int64_t i03 = i3/sf3;
|
1868
|
+
const int64_t i02 = i2/sf2;
|
1869
|
+
const int64_t i01 = i1/sf1;
|
1921
1870
|
|
1922
1871
|
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
1923
|
-
|
1872
|
+
const int64_t i00 = i0/sf0;
|
1873
|
+
|
1874
|
+
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
1875
|
+
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1876
|
+
|
1877
|
+
dst_ptr[0] = src0_ptr[0];
|
1924
1878
|
}
|
1925
1879
|
}
|
1926
1880
|
|
@@ -2100,29 +2054,29 @@ typedef void (flash_attn_ext_f16_t)(
|
|
2100
2054
|
device const char * v,
|
2101
2055
|
device const char * mask,
|
2102
2056
|
device float * dst,
|
2103
|
-
constant int64_t & ne00,
|
2104
2057
|
constant int64_t & ne01,
|
2105
2058
|
constant int64_t & ne02,
|
2106
2059
|
constant int64_t & ne03,
|
2107
|
-
constant uint64_t & nb00,
|
2108
2060
|
constant uint64_t & nb01,
|
2109
2061
|
constant uint64_t & nb02,
|
2110
2062
|
constant uint64_t & nb03,
|
2111
|
-
constant int64_t & ne10,
|
2112
2063
|
constant int64_t & ne11,
|
2113
2064
|
constant int64_t & ne12,
|
2114
2065
|
constant int64_t & ne13,
|
2115
|
-
constant uint64_t & nb10,
|
2116
2066
|
constant uint64_t & nb11,
|
2117
2067
|
constant uint64_t & nb12,
|
2118
2068
|
constant uint64_t & nb13,
|
2119
|
-
constant
|
2069
|
+
constant uint64_t & nb21,
|
2070
|
+
constant uint64_t & nb22,
|
2071
|
+
constant uint64_t & nb23,
|
2120
2072
|
constant uint64_t & nb31,
|
2121
|
-
constant int64_t & ne0,
|
2122
2073
|
constant int64_t & ne1,
|
2123
2074
|
constant int64_t & ne2,
|
2124
|
-
constant int64_t & ne3,
|
2125
2075
|
constant float & scale,
|
2076
|
+
constant float & max_bias,
|
2077
|
+
constant float & m0,
|
2078
|
+
constant float & m1,
|
2079
|
+
constant uint32_t & n_head_log2,
|
2126
2080
|
threadgroup half * shared,
|
2127
2081
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2128
2082
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
@@ -2138,29 +2092,29 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2138
2092
|
device const char * v,
|
2139
2093
|
device const char * mask,
|
2140
2094
|
device float * dst,
|
2141
|
-
constant int64_t & ne00,
|
2142
2095
|
constant int64_t & ne01,
|
2143
2096
|
constant int64_t & ne02,
|
2144
2097
|
constant int64_t & ne03,
|
2145
|
-
constant uint64_t & nb00,
|
2146
2098
|
constant uint64_t & nb01,
|
2147
2099
|
constant uint64_t & nb02,
|
2148
2100
|
constant uint64_t & nb03,
|
2149
|
-
constant int64_t & ne10,
|
2150
2101
|
constant int64_t & ne11,
|
2151
2102
|
constant int64_t & ne12,
|
2152
2103
|
constant int64_t & ne13,
|
2153
|
-
constant uint64_t & nb10,
|
2154
2104
|
constant uint64_t & nb11,
|
2155
2105
|
constant uint64_t & nb12,
|
2156
2106
|
constant uint64_t & nb13,
|
2157
|
-
constant
|
2107
|
+
constant uint64_t & nb21,
|
2108
|
+
constant uint64_t & nb22,
|
2109
|
+
constant uint64_t & nb23,
|
2158
2110
|
constant uint64_t & nb31,
|
2159
|
-
constant int64_t & ne0,
|
2160
2111
|
constant int64_t & ne1,
|
2161
2112
|
constant int64_t & ne2,
|
2162
|
-
constant int64_t & ne3,
|
2163
2113
|
constant float & scale,
|
2114
|
+
constant float & max_bias,
|
2115
|
+
constant float & m0,
|
2116
|
+
constant float & m1,
|
2117
|
+
constant uint32_t & n_head_log2,
|
2164
2118
|
threadgroup half * shared [[threadgroup(0)]],
|
2165
2119
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2166
2120
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
@@ -2225,10 +2179,6 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2225
2179
|
const short ne22 = ne12;
|
2226
2180
|
const short ne23 = ne13;
|
2227
2181
|
|
2228
|
-
const uint nb21 = nb11;
|
2229
|
-
const uint nb22 = nb12;
|
2230
|
-
const uint nb23 = nb13;
|
2231
|
-
|
2232
2182
|
// broadcast
|
2233
2183
|
const short rk2 = ne02/ne12;
|
2234
2184
|
const short rk3 = ne03/ne13;
|
@@ -2257,6 +2207,19 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2257
2207
|
// prepare diagonal scale matrix
|
2258
2208
|
simdgroup_float8x8 mscale(scale);
|
2259
2209
|
|
2210
|
+
// prepare diagonal slope matrix
|
2211
|
+
simdgroup_float8x8 mslope(1.0f);
|
2212
|
+
|
2213
|
+
// ALiBi
|
2214
|
+
if (max_bias > 0.0f) {
|
2215
|
+
const uint32_t h = iq2;
|
2216
|
+
|
2217
|
+
const float base = h < n_head_log2 ? m0 : m1;
|
2218
|
+
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
2219
|
+
|
2220
|
+
mslope = simdgroup_float8x8(pow(base, exph));
|
2221
|
+
}
|
2222
|
+
|
2260
2223
|
// loop over the KV cache
|
2261
2224
|
// each simdgroup handles blocks of Q rows and C columns
|
2262
2225
|
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
|
@@ -2279,10 +2242,16 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2279
2242
|
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
2280
2243
|
}
|
2281
2244
|
|
2282
|
-
|
2283
|
-
|
2284
|
-
|
2285
|
-
|
2245
|
+
if (mask != q) {
|
2246
|
+
// mqk = mqk*scale + mask*slope
|
2247
|
+
simdgroup_half8x8 mm;
|
2248
|
+
simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
|
2249
|
+
simdgroup_multiply(mm, mslope, mm);
|
2250
|
+
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
|
2251
|
+
} else {
|
2252
|
+
// mqk = mqk*scale
|
2253
|
+
simdgroup_multiply(mqk, mscale, mqk);
|
2254
|
+
}
|
2286
2255
|
|
2287
2256
|
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
2288
2257
|
}
|
@@ -2456,29 +2425,29 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
2456
2425
|
device const char * v,
|
2457
2426
|
device const char * mask,
|
2458
2427
|
device float * dst,
|
2459
|
-
constant int64_t & ne00,
|
2460
2428
|
constant int64_t & ne01,
|
2461
2429
|
constant int64_t & ne02,
|
2462
2430
|
constant int64_t & ne03,
|
2463
|
-
constant uint64_t & nb00,
|
2464
2431
|
constant uint64_t & nb01,
|
2465
2432
|
constant uint64_t & nb02,
|
2466
2433
|
constant uint64_t & nb03,
|
2467
|
-
constant int64_t & ne10,
|
2468
2434
|
constant int64_t & ne11,
|
2469
2435
|
constant int64_t & ne12,
|
2470
2436
|
constant int64_t & ne13,
|
2471
|
-
constant uint64_t & nb10,
|
2472
2437
|
constant uint64_t & nb11,
|
2473
2438
|
constant uint64_t & nb12,
|
2474
2439
|
constant uint64_t & nb13,
|
2475
|
-
constant
|
2440
|
+
constant uint64_t & nb21,
|
2441
|
+
constant uint64_t & nb22,
|
2442
|
+
constant uint64_t & nb23,
|
2476
2443
|
constant uint64_t & nb31,
|
2477
|
-
constant int64_t & ne0,
|
2478
2444
|
constant int64_t & ne1,
|
2479
2445
|
constant int64_t & ne2,
|
2480
|
-
constant int64_t & ne3,
|
2481
2446
|
constant float & scale,
|
2447
|
+
constant float & max_bias,
|
2448
|
+
constant float & m0,
|
2449
|
+
constant float & m1,
|
2450
|
+
constant uint32_t & n_head_log2,
|
2482
2451
|
threadgroup half * shared [[threadgroup(0)]],
|
2483
2452
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2484
2453
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
@@ -2497,6 +2466,18 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
2497
2466
|
|
2498
2467
|
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
2499
2468
|
|
2469
|
+
float slope = 1.0f;
|
2470
|
+
|
2471
|
+
// ALiBi
|
2472
|
+
if (max_bias > 0.0f) {
|
2473
|
+
const uint32_t h = iq2;
|
2474
|
+
|
2475
|
+
const float base = h < n_head_log2 ? m0 : m1;
|
2476
|
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
2477
|
+
|
2478
|
+
slope = pow(base, exp);
|
2479
|
+
}
|
2480
|
+
|
2500
2481
|
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
2501
2482
|
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
2502
2483
|
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
@@ -2537,10 +2518,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
2537
2518
|
const short ne22 = ne12;
|
2538
2519
|
const short ne23 = ne13;
|
2539
2520
|
|
2540
|
-
const uint nb21 = nb11;
|
2541
|
-
const uint nb22 = nb12;
|
2542
|
-
const uint nb23 = nb13;
|
2543
|
-
|
2544
2521
|
// broadcast
|
2545
2522
|
const short rk2 = ne02/ne12;
|
2546
2523
|
const short rk3 = ne03/ne13;
|
@@ -2603,10 +2580,9 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
2603
2580
|
mqk += simd_shuffle_down(mqk, 2);
|
2604
2581
|
mqk += simd_shuffle_down(mqk, 1);
|
2605
2582
|
|
2606
|
-
// mqk = mqk*scale + mask
|
2583
|
+
// mqk = mqk*scale + mask*slope
|
2607
2584
|
if (tiisg == 0) {
|
2608
|
-
|
2609
|
-
mqk = mqk*scale + mm;
|
2585
|
+
mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);
|
2610
2586
|
|
2611
2587
|
ss4[cc] = mqk;
|
2612
2588
|
}
|
@@ -2840,7 +2816,8 @@ kernel void kernel_cpy_f32_f16(
|
|
2840
2816
|
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
2841
2817
|
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
2842
2818
|
|
2843
|
-
|
2819
|
+
// TODO: is there a better way to handle -INFINITY?
|
2820
|
+
dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
|
2844
2821
|
}
|
2845
2822
|
}
|
2846
2823
|
|