llama_cpp 0.15.1 → 0.15.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/vendor/tmp/llama.cpp/Makefile +3 -3
- data/vendor/tmp/llama.cpp/ggml-backend.c +2 -3
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +15 -7
- data/vendor/tmp/llama.cpp/ggml-impl.h +7 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +9 -3
- data/vendor/tmp/llama.cpp/ggml-metal.m +114 -125
- data/vendor/tmp/llama.cpp/ggml-metal.metal +86 -109
- data/vendor/tmp/llama.cpp/ggml-quants.c +2202 -28
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +1032 -0
- data/vendor/tmp/llama.cpp/ggml-rpc.h +24 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +24 -143
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +4 -2
- data/vendor/tmp/llama.cpp/ggml.c +726 -646
- data/vendor/tmp/llama.cpp/ggml.h +28 -17
- data/vendor/tmp/llama.cpp/llama.cpp +478 -281
- data/vendor/tmp/llama.cpp/llama.h +3 -0
- data/vendor/tmp/llama.cpp/unicode-data.cpp +6969 -2169
- data/vendor/tmp/llama.cpp/unicode-data.h +15 -12
- data/vendor/tmp/llama.cpp/unicode.cpp +89 -111
- data/vendor/tmp/llama.cpp/unicode.h +44 -12
- metadata +4 -2
@@ -229,6 +229,13 @@ kernel void kernel_relu(
|
|
229
229
|
dst[tpig] = max(0.0f, src0[tpig]);
|
230
230
|
}
|
231
231
|
|
232
|
+
kernel void kernel_sigmoid(
|
233
|
+
device const float * src0,
|
234
|
+
device float * dst,
|
235
|
+
uint tpig[[thread_position_in_grid]]) {
|
236
|
+
dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
|
237
|
+
}
|
238
|
+
|
232
239
|
kernel void kernel_tanh(
|
233
240
|
device const float * src0,
|
234
241
|
device float * dst,
|
@@ -356,7 +363,6 @@ template<typename T>
|
|
356
363
|
kernel void kernel_soft_max(
|
357
364
|
device const char * src0,
|
358
365
|
device const char * src1,
|
359
|
-
device const char * src2,
|
360
366
|
device char * dst,
|
361
367
|
constant int64_t & ne00,
|
362
368
|
constant int64_t & ne01,
|
@@ -378,10 +384,9 @@ kernel void kernel_soft_max(
|
|
378
384
|
|
379
385
|
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
380
386
|
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
|
381
|
-
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
|
382
387
|
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
383
388
|
|
384
|
-
float slope =
|
389
|
+
float slope = 1.0f;
|
385
390
|
|
386
391
|
// ALiBi
|
387
392
|
if (max_bias > 0.0f) {
|
@@ -397,7 +402,7 @@ kernel void kernel_soft_max(
|
|
397
402
|
float lmax = -INFINITY;
|
398
403
|
|
399
404
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
400
|
-
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ?
|
405
|
+
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
401
406
|
}
|
402
407
|
|
403
408
|
// find the max value in the block
|
@@ -422,7 +427,7 @@ kernel void kernel_soft_max(
|
|
422
427
|
// parallel sum
|
423
428
|
float lsum = 0.0f;
|
424
429
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
425
|
-
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ?
|
430
|
+
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
426
431
|
lsum += exp_psrc0;
|
427
432
|
pdst[i00] = exp_psrc0;
|
428
433
|
}
|
@@ -461,7 +466,6 @@ template<typename T>
|
|
461
466
|
kernel void kernel_soft_max_4(
|
462
467
|
device const char * src0,
|
463
468
|
device const char * src1,
|
464
|
-
device const char * src2,
|
465
469
|
device char * dst,
|
466
470
|
constant int64_t & ne00,
|
467
471
|
constant int64_t & ne01,
|
@@ -483,10 +487,9 @@ kernel void kernel_soft_max_4(
|
|
483
487
|
|
484
488
|
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
485
489
|
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
|
486
|
-
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
|
487
490
|
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
488
491
|
|
489
|
-
float slope =
|
492
|
+
float slope = 1.0f;
|
490
493
|
|
491
494
|
if (max_bias > 0.0f) {
|
492
495
|
const int64_t h = i02;
|
@@ -501,7 +504,7 @@ kernel void kernel_soft_max_4(
|
|
501
504
|
float4 lmax4 = -INFINITY;
|
502
505
|
|
503
506
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
504
|
-
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ?
|
507
|
+
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
505
508
|
}
|
506
509
|
|
507
510
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
@@ -527,7 +530,7 @@ kernel void kernel_soft_max_4(
|
|
527
530
|
// parallel sum
|
528
531
|
float4 lsum4 = 0.0f;
|
529
532
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
530
|
-
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ?
|
533
|
+
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
531
534
|
lsum4 += exp_psrc4;
|
532
535
|
pdst4[i00] = exp_psrc4;
|
533
536
|
}
|
@@ -1595,60 +1598,6 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
1595
1598
|
}
|
1596
1599
|
}
|
1597
1600
|
|
1598
|
-
kernel void kernel_alibi_f32(
|
1599
|
-
device const float * src0,
|
1600
|
-
device float * dst,
|
1601
|
-
constant int64_t & ne00,
|
1602
|
-
constant int64_t & ne01,
|
1603
|
-
constant int64_t & ne02,
|
1604
|
-
constant int64_t & ne03,
|
1605
|
-
constant uint64_t & nb00,
|
1606
|
-
constant uint64_t & nb01,
|
1607
|
-
constant uint64_t & nb02,
|
1608
|
-
constant uint64_t & nb03,
|
1609
|
-
constant int64_t & ne0,
|
1610
|
-
constant int64_t & ne1,
|
1611
|
-
constant int64_t & ne2,
|
1612
|
-
constant int64_t & ne3,
|
1613
|
-
constant uint64_t & nb0,
|
1614
|
-
constant uint64_t & nb1,
|
1615
|
-
constant uint64_t & nb2,
|
1616
|
-
constant uint64_t & nb3,
|
1617
|
-
constant float & m0,
|
1618
|
-
constant float & m1,
|
1619
|
-
constant int & n_heads_log2_floor,
|
1620
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
1621
|
-
uint3 tpitg[[thread_position_in_threadgroup]],
|
1622
|
-
uint3 ntg[[threads_per_threadgroup]]) {
|
1623
|
-
const int64_t i03 = tgpig[2];
|
1624
|
-
const int64_t i02 = tgpig[1];
|
1625
|
-
const int64_t i01 = tgpig[0];
|
1626
|
-
|
1627
|
-
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
1628
|
-
|
1629
|
-
const int64_t i3 = n / (ne2*ne1*ne0);
|
1630
|
-
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
1631
|
-
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
1632
|
-
//const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
1633
|
-
|
1634
|
-
const int64_t k = i3*ne3 + i2;
|
1635
|
-
|
1636
|
-
float m_k;
|
1637
|
-
if (k < n_heads_log2_floor) {
|
1638
|
-
m_k = pow(m0, k + 1);
|
1639
|
-
} else {
|
1640
|
-
m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
|
1641
|
-
}
|
1642
|
-
|
1643
|
-
device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
|
1644
|
-
device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
1645
|
-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
1646
|
-
const float src_v = *(device float *)(src_row + i00*nb00);
|
1647
|
-
device float * dst_v = (device float *)(dst_row + i00*nb0);
|
1648
|
-
*dst_v = i00 * m_k + src_v;
|
1649
|
-
}
|
1650
|
-
}
|
1651
|
-
|
1652
1601
|
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
1653
1602
|
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
1654
1603
|
return 1.0f - min(1.0f, max(0.0f, y));
|
@@ -1903,7 +1852,10 @@ kernel void kernel_upscale_f32(
|
|
1903
1852
|
constant uint64_t & nb1,
|
1904
1853
|
constant uint64_t & nb2,
|
1905
1854
|
constant uint64_t & nb3,
|
1906
|
-
constant
|
1855
|
+
constant float & sf0,
|
1856
|
+
constant float & sf1,
|
1857
|
+
constant float & sf2,
|
1858
|
+
constant float & sf3,
|
1907
1859
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1908
1860
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
1909
1861
|
uint3 ntg[[threads_per_threadgroup]]) {
|
@@ -1912,15 +1864,17 @@ kernel void kernel_upscale_f32(
|
|
1912
1864
|
const int64_t i2 = tgpig.y;
|
1913
1865
|
const int64_t i1 = tgpig.x;
|
1914
1866
|
|
1915
|
-
const int64_t i03 = i3;
|
1916
|
-
const int64_t i02 = i2;
|
1917
|
-
const int64_t i01 = i1/
|
1918
|
-
|
1919
|
-
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
1920
|
-
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
|
1867
|
+
const int64_t i03 = i3/sf3;
|
1868
|
+
const int64_t i02 = i2/sf2;
|
1869
|
+
const int64_t i01 = i1/sf1;
|
1921
1870
|
|
1922
1871
|
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
1923
|
-
|
1872
|
+
const int64_t i00 = i0/sf0;
|
1873
|
+
|
1874
|
+
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
1875
|
+
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1876
|
+
|
1877
|
+
dst_ptr[0] = src0_ptr[0];
|
1924
1878
|
}
|
1925
1879
|
}
|
1926
1880
|
|
@@ -2100,29 +2054,29 @@ typedef void (flash_attn_ext_f16_t)(
|
|
2100
2054
|
device const char * v,
|
2101
2055
|
device const char * mask,
|
2102
2056
|
device float * dst,
|
2103
|
-
constant int64_t & ne00,
|
2104
2057
|
constant int64_t & ne01,
|
2105
2058
|
constant int64_t & ne02,
|
2106
2059
|
constant int64_t & ne03,
|
2107
|
-
constant uint64_t & nb00,
|
2108
2060
|
constant uint64_t & nb01,
|
2109
2061
|
constant uint64_t & nb02,
|
2110
2062
|
constant uint64_t & nb03,
|
2111
|
-
constant int64_t & ne10,
|
2112
2063
|
constant int64_t & ne11,
|
2113
2064
|
constant int64_t & ne12,
|
2114
2065
|
constant int64_t & ne13,
|
2115
|
-
constant uint64_t & nb10,
|
2116
2066
|
constant uint64_t & nb11,
|
2117
2067
|
constant uint64_t & nb12,
|
2118
2068
|
constant uint64_t & nb13,
|
2119
|
-
constant
|
2069
|
+
constant uint64_t & nb21,
|
2070
|
+
constant uint64_t & nb22,
|
2071
|
+
constant uint64_t & nb23,
|
2120
2072
|
constant uint64_t & nb31,
|
2121
|
-
constant int64_t & ne0,
|
2122
2073
|
constant int64_t & ne1,
|
2123
2074
|
constant int64_t & ne2,
|
2124
|
-
constant int64_t & ne3,
|
2125
2075
|
constant float & scale,
|
2076
|
+
constant float & max_bias,
|
2077
|
+
constant float & m0,
|
2078
|
+
constant float & m1,
|
2079
|
+
constant uint32_t & n_head_log2,
|
2126
2080
|
threadgroup half * shared,
|
2127
2081
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2128
2082
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
@@ -2138,29 +2092,29 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2138
2092
|
device const char * v,
|
2139
2093
|
device const char * mask,
|
2140
2094
|
device float * dst,
|
2141
|
-
constant int64_t & ne00,
|
2142
2095
|
constant int64_t & ne01,
|
2143
2096
|
constant int64_t & ne02,
|
2144
2097
|
constant int64_t & ne03,
|
2145
|
-
constant uint64_t & nb00,
|
2146
2098
|
constant uint64_t & nb01,
|
2147
2099
|
constant uint64_t & nb02,
|
2148
2100
|
constant uint64_t & nb03,
|
2149
|
-
constant int64_t & ne10,
|
2150
2101
|
constant int64_t & ne11,
|
2151
2102
|
constant int64_t & ne12,
|
2152
2103
|
constant int64_t & ne13,
|
2153
|
-
constant uint64_t & nb10,
|
2154
2104
|
constant uint64_t & nb11,
|
2155
2105
|
constant uint64_t & nb12,
|
2156
2106
|
constant uint64_t & nb13,
|
2157
|
-
constant
|
2107
|
+
constant uint64_t & nb21,
|
2108
|
+
constant uint64_t & nb22,
|
2109
|
+
constant uint64_t & nb23,
|
2158
2110
|
constant uint64_t & nb31,
|
2159
|
-
constant int64_t & ne0,
|
2160
2111
|
constant int64_t & ne1,
|
2161
2112
|
constant int64_t & ne2,
|
2162
|
-
constant int64_t & ne3,
|
2163
2113
|
constant float & scale,
|
2114
|
+
constant float & max_bias,
|
2115
|
+
constant float & m0,
|
2116
|
+
constant float & m1,
|
2117
|
+
constant uint32_t & n_head_log2,
|
2164
2118
|
threadgroup half * shared [[threadgroup(0)]],
|
2165
2119
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2166
2120
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
@@ -2225,10 +2179,6 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2225
2179
|
const short ne22 = ne12;
|
2226
2180
|
const short ne23 = ne13;
|
2227
2181
|
|
2228
|
-
const uint nb21 = nb11;
|
2229
|
-
const uint nb22 = nb12;
|
2230
|
-
const uint nb23 = nb13;
|
2231
|
-
|
2232
2182
|
// broadcast
|
2233
2183
|
const short rk2 = ne02/ne12;
|
2234
2184
|
const short rk3 = ne03/ne13;
|
@@ -2257,6 +2207,19 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2257
2207
|
// prepare diagonal scale matrix
|
2258
2208
|
simdgroup_float8x8 mscale(scale);
|
2259
2209
|
|
2210
|
+
// prepare diagonal slope matrix
|
2211
|
+
simdgroup_float8x8 mslope(1.0f);
|
2212
|
+
|
2213
|
+
// ALiBi
|
2214
|
+
if (max_bias > 0.0f) {
|
2215
|
+
const uint32_t h = iq2;
|
2216
|
+
|
2217
|
+
const float base = h < n_head_log2 ? m0 : m1;
|
2218
|
+
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
2219
|
+
|
2220
|
+
mslope = simdgroup_float8x8(pow(base, exph));
|
2221
|
+
}
|
2222
|
+
|
2260
2223
|
// loop over the KV cache
|
2261
2224
|
// each simdgroup handles blocks of Q rows and C columns
|
2262
2225
|
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
|
@@ -2279,10 +2242,16 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2279
2242
|
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
2280
2243
|
}
|
2281
2244
|
|
2282
|
-
|
2283
|
-
|
2284
|
-
|
2285
|
-
|
2245
|
+
if (mask != q) {
|
2246
|
+
// mqk = mqk*scale + mask*slope
|
2247
|
+
simdgroup_half8x8 mm;
|
2248
|
+
simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
|
2249
|
+
simdgroup_multiply(mm, mslope, mm);
|
2250
|
+
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
|
2251
|
+
} else {
|
2252
|
+
// mqk = mqk*scale
|
2253
|
+
simdgroup_multiply(mqk, mscale, mqk);
|
2254
|
+
}
|
2286
2255
|
|
2287
2256
|
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
2288
2257
|
}
|
@@ -2456,29 +2425,29 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
2456
2425
|
device const char * v,
|
2457
2426
|
device const char * mask,
|
2458
2427
|
device float * dst,
|
2459
|
-
constant int64_t & ne00,
|
2460
2428
|
constant int64_t & ne01,
|
2461
2429
|
constant int64_t & ne02,
|
2462
2430
|
constant int64_t & ne03,
|
2463
|
-
constant uint64_t & nb00,
|
2464
2431
|
constant uint64_t & nb01,
|
2465
2432
|
constant uint64_t & nb02,
|
2466
2433
|
constant uint64_t & nb03,
|
2467
|
-
constant int64_t & ne10,
|
2468
2434
|
constant int64_t & ne11,
|
2469
2435
|
constant int64_t & ne12,
|
2470
2436
|
constant int64_t & ne13,
|
2471
|
-
constant uint64_t & nb10,
|
2472
2437
|
constant uint64_t & nb11,
|
2473
2438
|
constant uint64_t & nb12,
|
2474
2439
|
constant uint64_t & nb13,
|
2475
|
-
constant
|
2440
|
+
constant uint64_t & nb21,
|
2441
|
+
constant uint64_t & nb22,
|
2442
|
+
constant uint64_t & nb23,
|
2476
2443
|
constant uint64_t & nb31,
|
2477
|
-
constant int64_t & ne0,
|
2478
2444
|
constant int64_t & ne1,
|
2479
2445
|
constant int64_t & ne2,
|
2480
|
-
constant int64_t & ne3,
|
2481
2446
|
constant float & scale,
|
2447
|
+
constant float & max_bias,
|
2448
|
+
constant float & m0,
|
2449
|
+
constant float & m1,
|
2450
|
+
constant uint32_t & n_head_log2,
|
2482
2451
|
threadgroup half * shared [[threadgroup(0)]],
|
2483
2452
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2484
2453
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
@@ -2497,6 +2466,18 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
2497
2466
|
|
2498
2467
|
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
2499
2468
|
|
2469
|
+
float slope = 1.0f;
|
2470
|
+
|
2471
|
+
// ALiBi
|
2472
|
+
if (max_bias > 0.0f) {
|
2473
|
+
const uint32_t h = iq2;
|
2474
|
+
|
2475
|
+
const float base = h < n_head_log2 ? m0 : m1;
|
2476
|
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
2477
|
+
|
2478
|
+
slope = pow(base, exp);
|
2479
|
+
}
|
2480
|
+
|
2500
2481
|
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
2501
2482
|
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
2502
2483
|
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
@@ -2537,10 +2518,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
2537
2518
|
const short ne22 = ne12;
|
2538
2519
|
const short ne23 = ne13;
|
2539
2520
|
|
2540
|
-
const uint nb21 = nb11;
|
2541
|
-
const uint nb22 = nb12;
|
2542
|
-
const uint nb23 = nb13;
|
2543
|
-
|
2544
2521
|
// broadcast
|
2545
2522
|
const short rk2 = ne02/ne12;
|
2546
2523
|
const short rk3 = ne03/ne13;
|
@@ -2603,10 +2580,9 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
2603
2580
|
mqk += simd_shuffle_down(mqk, 2);
|
2604
2581
|
mqk += simd_shuffle_down(mqk, 1);
|
2605
2582
|
|
2606
|
-
// mqk = mqk*scale + mask
|
2583
|
+
// mqk = mqk*scale + mask*slope
|
2607
2584
|
if (tiisg == 0) {
|
2608
|
-
|
2609
|
-
mqk = mqk*scale + mm;
|
2585
|
+
mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);
|
2610
2586
|
|
2611
2587
|
ss4[cc] = mqk;
|
2612
2588
|
}
|
@@ -2840,7 +2816,8 @@ kernel void kernel_cpy_f32_f16(
|
|
2840
2816
|
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
2841
2817
|
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
2842
2818
|
|
2843
|
-
|
2819
|
+
// TODO: is there a better way to handle -INFINITY?
|
2820
|
+
dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
|
2844
2821
|
}
|
2845
2822
|
}
|
2846
2823
|
|