llama_cpp 0.11.1 → 0.12.1
- checksums.yaml +4 -4
- data/CHANGELOG.md +15 -0
- data/README.md +3 -3
- data/examples/chat.rb +6 -2
- data/examples/embedding.rb +5 -1
- data/examples/simple.rb +4 -1
- data/ext/llama_cpp/llama_cpp.cpp +63 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +2 -2
- data/sig/llama_cpp.rbs +5 -0
- data/vendor/tmp/llama.cpp/Makefile +8 -2
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-backend.c +7 -3
- data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +758 -39
- data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +86 -7
- data/vendor/tmp/llama.cpp/ggml-metal.metal +692 -8
- data/vendor/tmp/llama.cpp/ggml-quants.c +635 -1
- data/vendor/tmp/llama.cpp/ggml-quants.h +25 -1
- data/vendor/tmp/llama.cpp/ggml.c +91 -52
- data/vendor/tmp/llama.cpp/ggml.h +14 -11
- data/vendor/tmp/llama.cpp/llama.cpp +79 -30
- data/vendor/tmp/llama.cpp/llama.h +14 -0
- metadata +2 -2
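
The bulk of the release is the vendored data/vendor/tmp/llama.cpp/ggml-cuda.cu update shown below: it adds CUDA support for the new 2-bit IQ2_XXS and IQ2_XS quantization types (block layouts, dequantize and vec-dot kernels, dispatch tables), half2 warp reductions, an optional FP16 soft-max path, and a generic convert_unary kernel that replaces the dedicated convert_f32 kernel.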
@@ -116,6 +116,7 @@
 #include "ggml.h"
 #include "ggml-backend-impl.h"
 
+#define CC_PASCAL 600
 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA 700
 #define CC_OFFSET_AMD 1000000
@@ -183,7 +184,7 @@ static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
 static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
     c = __builtin_amdgcn_sdot4(a, b, c, false);
-#elif defined(
+#elif defined(RDNA3)
     c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
 #elif defined(__gfx1010__) || defined(__gfx900__)
     int tmp1;
@@ -477,6 +478,23 @@ typedef struct {
 } block_q6_K;
 static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
 
+#define QR2_XXS 8
+#define QI2_XXS (QK_K / (4*QR2_XXS))
+typedef struct {
+    half d;
+    uint16_t qs[QK_K/8];
+} block_iq2_xxs;
+static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
+
+#define QR2_XS 8
+#define QI2_XS (QK_K / (4*QR2_XS))
+typedef struct {
+    half d;
+    uint16_t qs[QK_K/8];
+    uint8_t scales[QK_K/32];
+} block_iq2_xs;
+static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
+
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
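
As a quick size check on the two new block layouts (a standalone sketch, not part of the diff; assumes QK_K == 256 and a 2-byte half):

// Size check for the new IQ2 blocks, assuming QK_K == 256 (standalone sketch).
#include <cstdio>

int main() {
    const int QK_K = 256;
    const int xxs_bytes = 2 /* half d */ + (QK_K/8)*2 /* uint16_t qs[QK_K/8] */;                        // 66 bytes
    const int xs_bytes  = 2 /* half d */ + (QK_K/8)*2 /* qs */ + QK_K/32 /* uint8_t scales[QK_K/32] */; // 74 bytes
    std::printf("iq2_xxs: %d bytes -> %.4f bits/weight\n", xxs_bytes, 8.0*xxs_bytes/QK_K);              // 2.0625
    std::printf("iq2_xs:  %d bytes -> %.4f bits/weight\n", xs_bytes,  8.0*xs_bytes /QK_K);              // 2.3125
    return 0;
}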
@@ -548,11 +566,12 @@ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
 
 struct cuda_device_capabilities {
     int cc; // compute capability
+    size_t smpb; // max. shared memory per block
     bool vmm; // virtual memory support
     size_t vmm_granularity; // granularity of virtual memory
 };
 
-static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
+static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, 0, false, 0} };
 
 static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 0; // disabled by default
@@ -585,6 +604,19 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
     return a;
 }
 
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+    (void) a;
+    bad_arch();
+#else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+}
+
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
@@ -593,6 +625,19 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
     return x;
 }
 
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+    (void) x;
+    bad_arch();
+#else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+}
+
 static __device__ __forceinline__ float op_repeat(const float a, const float b) {
     return b;
     GGML_UNUSED(a);
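
Both new half2 reductions use the same xor-shuffle butterfly as the existing float versions: each step exchanges values between lanes whose indices differ by mask, so after five steps every lane of the warp holds the combined result. A minimal standalone sketch of that pattern for plain float (row_sum_warp is a hypothetical example kernel, not part of the diff):

// Sketch of the xor-shuffle butterfly used by warp_reduce_sum / warp_reduce_max.
// Assumes the kernel is launched with blockDim.x == 32 (one warp).
__global__ void row_sum_warp(const float * x, float * out) {
    float v = x[threadIdx.x];                          // one value per lane
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        v += __shfl_xor_sync(0xffffffff, v, mask, 32); // exchange with the lane 'mask' apart
    }
    if (threadIdx.x == 0) {
        *out = v;                                      // every lane now holds the full sum
    }
}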
@@ -1292,6 +1337,281 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
|
|
1292
1337
|
#endif
|
1293
1338
|
}
|
1294
1339
|
|
1340
|
+
static const __device__ uint64_t iq2xxs_grid[256] = {
|
1341
|
+
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
1342
|
+
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
|
1343
|
+
0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
|
1344
|
+
0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
|
1345
|
+
0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
|
1346
|
+
0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
|
1347
|
+
0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
|
1348
|
+
0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
|
1349
|
+
0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
|
1350
|
+
0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
|
1351
|
+
0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
|
1352
|
+
0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
|
1353
|
+
0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
|
1354
|
+
0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
|
1355
|
+
0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
|
1356
|
+
0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
|
1357
|
+
0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
|
1358
|
+
0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
|
1359
|
+
0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
|
1360
|
+
0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
|
1361
|
+
0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
|
1362
|
+
0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
|
1363
|
+
0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
|
1364
|
+
0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
|
1365
|
+
0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
|
1366
|
+
0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
|
1367
|
+
0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
|
1368
|
+
0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
|
1369
|
+
0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
|
1370
|
+
0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
|
1371
|
+
0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
|
1372
|
+
0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
|
1373
|
+
0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
|
1374
|
+
0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
|
1375
|
+
0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
|
1376
|
+
0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
|
1377
|
+
0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
|
1378
|
+
0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
|
1379
|
+
0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
|
1380
|
+
0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
|
1381
|
+
0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
|
1382
|
+
0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
|
1383
|
+
0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
|
1384
|
+
0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
|
1385
|
+
0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
|
1386
|
+
0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
|
1387
|
+
0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
|
1388
|
+
0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
|
1389
|
+
0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
|
1390
|
+
0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
|
1391
|
+
0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
|
1392
|
+
0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
|
1393
|
+
0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
|
1394
|
+
0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
|
1395
|
+
0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
|
1396
|
+
0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
|
1397
|
+
0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
|
1398
|
+
0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
|
1399
|
+
0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
|
1400
|
+
0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
|
1401
|
+
0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
|
1402
|
+
0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
|
1403
|
+
0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
|
1404
|
+
0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
|
1405
|
+
};
|
1406
|
+
|
1407
|
+
static const __device__ uint64_t iq2xs_grid[512] = {
|
1408
|
+
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
1409
|
+
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
|
1410
|
+
0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
|
1411
|
+
0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
|
1412
|
+
0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
|
1413
|
+
0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
|
1414
|
+
0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
|
1415
|
+
0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
|
1416
|
+
0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
|
1417
|
+
0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
|
1418
|
+
0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
|
1419
|
+
0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
|
1420
|
+
0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
|
1421
|
+
0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
|
1422
|
+
0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
|
1423
|
+
0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
|
1424
|
+
0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
|
1425
|
+
0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
|
1426
|
+
0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
|
1427
|
+
0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
|
1428
|
+
0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
|
1429
|
+
0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
|
1430
|
+
0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
|
1431
|
+
0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
|
1432
|
+
0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
|
1433
|
+
0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
|
1434
|
+
0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
|
1435
|
+
0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
|
1436
|
+
0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
|
1437
|
+
0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
|
1438
|
+
0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
|
1439
|
+
0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
|
1440
|
+
0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
|
1441
|
+
0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
|
1442
|
+
0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
|
1443
|
+
0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
|
1444
|
+
0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
|
1445
|
+
0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
|
1446
|
+
0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
|
1447
|
+
0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
|
1448
|
+
0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
|
1449
|
+
0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
|
1450
|
+
0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
|
1451
|
+
0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
|
1452
|
+
0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
|
1453
|
+
0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
|
1454
|
+
0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
|
1455
|
+
0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
|
1456
|
+
0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
|
1457
|
+
0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
|
1458
|
+
0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
|
1459
|
+
0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
|
1460
|
+
0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
|
1461
|
+
0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
|
1462
|
+
0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
|
1463
|
+
0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
|
1464
|
+
0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
|
1465
|
+
0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
|
1466
|
+
0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
|
1467
|
+
0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
|
1468
|
+
0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
|
1469
|
+
0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
|
1470
|
+
0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
|
1471
|
+
0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
|
1472
|
+
0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
|
1473
|
+
0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
|
1474
|
+
0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
|
1475
|
+
0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
|
1476
|
+
0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
|
1477
|
+
0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
|
1478
|
+
0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
|
1479
|
+
0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
|
1480
|
+
0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
|
1481
|
+
0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
|
1482
|
+
0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
|
1483
|
+
0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
|
1484
|
+
0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
|
1485
|
+
0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
|
1486
|
+
0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
|
1487
|
+
0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
|
1488
|
+
0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
|
1489
|
+
0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
|
1490
|
+
0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
|
1491
|
+
0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
|
1492
|
+
0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
|
1493
|
+
0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
|
1494
|
+
0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
|
1495
|
+
0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
|
1496
|
+
0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
|
1497
|
+
0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
|
1498
|
+
0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
|
1499
|
+
0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
|
1500
|
+
0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
|
1501
|
+
0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
|
1502
|
+
0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
|
1503
|
+
0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
|
1504
|
+
0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
|
1505
|
+
0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
|
1506
|
+
0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
|
1507
|
+
0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
|
1508
|
+
0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
|
1509
|
+
0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
|
1510
|
+
0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
|
1511
|
+
0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
|
1512
|
+
0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
|
1513
|
+
0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
|
1514
|
+
0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
|
1515
|
+
0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
|
1516
|
+
0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
|
1517
|
+
0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
|
1518
|
+
0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
|
1519
|
+
0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
|
1520
|
+
0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
|
1521
|
+
0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
|
1522
|
+
0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
|
1523
|
+
0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
|
1524
|
+
0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
|
1525
|
+
0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
|
1526
|
+
0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
|
1527
|
+
0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
|
1528
|
+
0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
|
1529
|
+
0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
|
1530
|
+
0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
|
1531
|
+
0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
|
1532
|
+
0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
|
1533
|
+
0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
|
1534
|
+
0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
|
1535
|
+
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
1536
|
+
};
|
1537
|
+
|
1538
|
+
static const __device__ uint8_t ksigns_iq2xs[128] = {
|
1539
|
+
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
1540
|
+
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
1541
|
+
160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
|
1542
|
+
48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
|
1543
|
+
192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
|
1544
|
+
80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
|
1545
|
+
96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
|
1546
|
+
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
|
1547
|
+
};
|
1548
|
+
|
1549
|
+
static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
|
1550
|
+
|
1551
|
+
inline bool ggml_cuda_supports_mmq(enum ggml_type type) {
|
1552
|
+
switch (type) {
|
1553
|
+
case GGML_TYPE_Q4_0:
|
1554
|
+
case GGML_TYPE_Q4_1:
|
1555
|
+
case GGML_TYPE_Q5_0:
|
1556
|
+
case GGML_TYPE_Q5_1:
|
1557
|
+
case GGML_TYPE_Q8_0:
|
1558
|
+
case GGML_TYPE_Q2_K:
|
1559
|
+
case GGML_TYPE_Q3_K:
|
1560
|
+
case GGML_TYPE_Q4_K:
|
1561
|
+
case GGML_TYPE_Q5_K:
|
1562
|
+
case GGML_TYPE_Q6_K:
|
1563
|
+
return true;
|
1564
|
+
default:
|
1565
|
+
return false;
|
1566
|
+
}
|
1567
|
+
}
|
1568
|
+
|
1569
|
+
template<typename dst_t>
|
1570
|
+
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
1571
|
+
|
1572
|
+
const int i = blockIdx.x;
|
1573
|
+
const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
|
1574
|
+
|
1575
|
+
const int tid = threadIdx.x;
|
1576
|
+
#if QK_K == 256
|
1577
|
+
const int il = tid/8; // 0...3
|
1578
|
+
const int ib = tid%8; // 0...7
|
1579
|
+
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
1580
|
+
const uint16_t * q2 = x[i].qs + 4*ib;
|
1581
|
+
const uint8_t * aux8 = (const uint8_t *)q2;
|
1582
|
+
const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
|
1583
|
+
const uint32_t aux32 = q2[2] | (q2[3] << 16);
|
1584
|
+
const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
|
1585
|
+
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
|
1586
|
+
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
1587
|
+
#else
|
1588
|
+
assert(false);
|
1589
|
+
#endif
|
1590
|
+
|
1591
|
+
}
|
1592
|
+
|
1593
|
+
template<typename dst_t>
|
1594
|
+
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
1595
|
+
|
1596
|
+
const int i = blockIdx.x;
|
1597
|
+
const block_iq2_xs * x = (const block_iq2_xs *) vx;
|
1598
|
+
|
1599
|
+
const int tid = threadIdx.x;
|
1600
|
+
#if QK_K == 256
|
1601
|
+
const int il = tid/8; // 0...3
|
1602
|
+
const int ib = tid%8; // 0...7
|
1603
|
+
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
1604
|
+
const uint16_t * q2 = x[i].qs + 4*ib;
|
1605
|
+
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
|
1606
|
+
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
|
1607
|
+
const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
|
1608
|
+
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
1609
|
+
#else
|
1610
|
+
assert(false);
|
1611
|
+
#endif
|
1612
|
+
|
1613
|
+
}
|
1614
|
+
|
1295
1615
|
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
|
1296
1616
|
|
1297
1617
|
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
|
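
In dequantize_block_iq2_xxs above, each 32-weight sub-block stores four 8-bit indices into iq2xxs_grid in q2[0..1] and packs four 7-bit sign selectors plus a 4-bit scale into the 32-bit word built from q2[2] and q2[3]. A CPU-side sketch of that decode for one sub-block (assumes QK_K == 256 and host copies of the iq2xxs_grid, ksigns_iq2xs and kmask_iq2xs tables; decode_iq2_xxs_subblock is a hypothetical helper, not part of the diff):

// CPU sketch of the per-sub-block decode done by dequantize_block_iq2_xxs.
#include <stdint.h>

static void decode_iq2_xxs_subblock(float d_block, const uint16_t * q2 /* = qs + 4*ib */, float * y /* 32 outputs */) {
    const uint8_t * aux8  = (const uint8_t *) q2;               // four 8-bit grid indices
    const uint32_t  aux32 = q2[2] | ((uint32_t) q2[3] << 16);   // 4x7-bit sign selectors + 4-bit scale
    const float d = d_block * (0.5f + (aux32 >> 28)) * 0.25f;   // same scale formula as the kernel
    for (int il = 0; il < 4; ++il) {                            // the kernel handles one il per thread
        const uint8_t * grid  = (const uint8_t *) (iq2xxs_grid + aux8[il]); // 8 packed magnitudes
        const uint8_t   signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
        for (int j = 0; j < 8; ++j) {
            y[8*il + j] = d * grid[j] * ((signs & kmask_iq2xs[j]) ? -1.f : 1.f);
        }
    }
}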
@@ -1872,14 +2192,6 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
     v.y = x[ib + iqs + 1];
 }
 
-static __device__ void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
-    const float * x = (const float *) vx;
-
-    // automatic half -> float type cast if dfloat == float
-    v.x = x[ib + iqs + 0];
-    v.y = x[ib + iqs + 1];
-}
-
 static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
     const int ix = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -1983,7 +2295,7 @@ static __global__ void k_get_rows_float(
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+    const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
 
     if (i >= k) {
         return;
@@ -2002,6 +2314,19 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
     y[iybs + iqs + y_offset] = v.y;
 }
 
+template <typename src_t, typename dst_t>
+static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    const src_t * x = (src_t *) vx;
+
+    y[i] = x[i];
+}
+
 // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
 // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
 
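
Two launch policies follow from the hunks above: dequantize_block now consumes two elements per thread (i = 2*(blockDim.x*blockIdx.x + threadIdx.x)), so its dequantize_block_cuda launcher further below halves the grid size accordingly, while convert_unary keeps the one-element-per-thread mapping and gets its own convert_unary_cuda launcher.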
@@ -3820,6 +4145,91 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
|
|
3820
4145
|
return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
|
3821
4146
|
}
|
3822
4147
|
|
4148
|
+
static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
|
4149
|
+
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
|
4150
|
+
#if QK_K == 256
|
4151
|
+
const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
|
4152
|
+
|
4153
|
+
#if QR2_XXS == 8
|
4154
|
+
const int ib32 = iqs;
|
4155
|
+
const uint16_t * q2 = bq2->qs + 4*ib32;
|
4156
|
+
const uint8_t * aux8 = (const uint8_t *)q2;
|
4157
|
+
const int8_t * q8 = bq8_1[ib32].qs;
|
4158
|
+
uint32_t aux32 = q2[2] | (q2[3] << 16);
|
4159
|
+
int sumi = 0;
|
4160
|
+
for (int l = 0; l < 4; ++l) {
|
4161
|
+
const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
|
4162
|
+
const uint8_t signs = ksigns_iq2xs[aux32 & 127];
|
4163
|
+
for (int j = 0; j < 8; ++j) {
|
4164
|
+
sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
|
4165
|
+
}
|
4166
|
+
q8 += 8;
|
4167
|
+
aux32 >>= 7;
|
4168
|
+
}
|
4169
|
+
const float d = (float)bq2->d * (0.5f + aux32) * (float)bq8_1[ib32].ds.x * 0.25f;
|
4170
|
+
return d * sumi;
|
4171
|
+
#else
|
4172
|
+
// iqs is 0...15
|
4173
|
+
const int ib32 = iqs/2;
|
4174
|
+
const int il = iqs%2;
|
4175
|
+
const uint16_t * q2 = bq2->qs + 4*ib32;
|
4176
|
+
const uint8_t * aux8 = (const uint8_t *)q2;
|
4177
|
+
const uint8_t * grid1 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
|
4178
|
+
const uint8_t * grid2 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
|
4179
|
+
const uint32_t aux32 = q2[2] | (q2[3] << 16);
|
4180
|
+
const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;
|
4181
|
+
const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
|
4182
|
+
const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
|
4183
|
+
const int8_t * q8 = bq8_1[ib32].qs + 16*il;
|
4184
|
+
int sumi1 = 0, sumi2 = 0;
|
4185
|
+
for (int j = 0; j < 8; ++j) {
|
4186
|
+
sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
|
4187
|
+
sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
|
4188
|
+
}
|
4189
|
+
return d * (sumi1 + sumi2);
|
4190
|
+
#endif
|
4191
|
+
#else
|
4192
|
+
assert(false);
|
4193
|
+
return 0.f;
|
4194
|
+
#endif
|
4195
|
+
}
|
4196
|
+
|
4197
|
+
static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
|
4198
|
+
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
|
4199
|
+
#if QK_K == 256
|
4200
|
+
const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
|
4201
|
+
|
4202
|
+
const int ib32 = iqs;
|
4203
|
+
const uint16_t * q2 = bq2->qs + 4*ib32;
|
4204
|
+
const int8_t * q8 = bq8_1[ib32].qs;
|
4205
|
+
const uint8_t ls1 = bq2->scales[ib32] & 0xf;
|
4206
|
+
const uint8_t ls2 = bq2->scales[ib32] >> 4;
|
4207
|
+
int sumi1 = 0;
|
4208
|
+
for (int l = 0; l < 2; ++l) {
|
4209
|
+
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
|
4210
|
+
const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
|
4211
|
+
for (int j = 0; j < 8; ++j) {
|
4212
|
+
sumi1 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
|
4213
|
+
}
|
4214
|
+
q8 += 8;
|
4215
|
+
}
|
4216
|
+
int sumi2 = 0;
|
4217
|
+
for (int l = 2; l < 4; ++l) {
|
4218
|
+
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
|
4219
|
+
const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
|
4220
|
+
for (int j = 0; j < 8; ++j) {
|
4221
|
+
sumi2 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
|
4222
|
+
}
|
4223
|
+
q8 += 8;
|
4224
|
+
}
|
4225
|
+
const float d = (float)bq2->d * (float)bq8_1[ib32].ds.x * 0.25f;
|
4226
|
+
return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
|
4227
|
+
#else
|
4228
|
+
assert(false);
|
4229
|
+
return 0.f;
|
4230
|
+
#endif
|
4231
|
+
}
|
4232
|
+
|
3823
4233
|
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
|
3824
4234
|
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
|
3825
4235
|
static __device__ __forceinline__ void mul_mat_q(
|
@@ -5201,75 +5611,233 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
|
|
5201
5611
|
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
|
5202
5612
|
}
|
5203
5613
|
|
5204
|
-
|
5614
|
+
template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
|
5615
|
+
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
5616
|
+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
5617
|
+
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
|
5618
|
+
const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
|
5619
|
+
|
5620
|
+
const int tid = threadIdx.x;
|
5621
|
+
const int rowx = blockIdx.x;
|
5622
|
+
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
|
5623
|
+
|
5624
|
+
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
|
5625
|
+
|
5626
|
+
const int warp_id = threadIdx.x / WARP_SIZE;
|
5627
|
+
const int lane_id = threadIdx.x % WARP_SIZE;
|
5628
|
+
|
5629
|
+
extern __shared__ half data_soft_max_f16[];
|
5630
|
+
half * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
|
5631
|
+
// (shared memory) buffer to cache values between iterations:
|
5632
|
+
half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data);
|
5633
|
+
// if the buffer is larger than max. shared memory per block, use dst as temp. buffer instead
|
5634
|
+
// in that case col_smem == col_data must be enforced to avoid race conditions
|
5635
|
+
|
5636
|
+
half2 max_val = make_half2(-INFINITY, -INFINITY);
|
5637
|
+
|
5638
|
+
#pragma unroll
|
5639
|
+
for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
|
5640
|
+
const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
|
5641
|
+
const int col_smem = vals_smem ? col0 + tid : col_data;
|
5642
|
+
|
5643
|
+
const int ix = rowx*ncols_data + col_data;
|
5644
|
+
const int iy = rowy*ncols_data + col_data;
|
5645
|
+
|
5646
|
+
half2 val;
|
5647
|
+
if (need_check && col_data + 0 >= ncols_data) {
|
5648
|
+
val.x = -INFINITY;
|
5649
|
+
} else {
|
5650
|
+
val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
|
5651
|
+
}
|
5652
|
+
if (need_check && col_data + WARP_SIZE >= ncols_data) {
|
5653
|
+
val.y = -INFINITY;
|
5654
|
+
} else {
|
5655
|
+
val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
|
5656
|
+
}
|
5657
|
+
if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
|
5658
|
+
vals[col_smem] = val;
|
5659
|
+
}
|
5660
|
+
max_val = __hmax2(max_val, val);
|
5661
|
+
}
|
5662
|
+
|
5663
|
+
// find the max value in the block
|
5664
|
+
max_val = warp_reduce_max(max_val);
|
5665
|
+
if (block_size > WARP_SIZE) {
|
5666
|
+
if (warp_id == 0) {
|
5667
|
+
buf_iw[lane_id] = -INFINITY;
|
5668
|
+
}
|
5669
|
+
__syncthreads();
|
5670
|
+
|
5671
|
+
if (lane_id == 0) {
|
5672
|
+
buf_iw[warp_id] = __hmax(max_val.x, max_val.y);
|
5673
|
+
}
|
5674
|
+
__syncthreads();
|
5675
|
+
|
5676
|
+
max_val = __half2half2(buf_iw[lane_id]);
|
5677
|
+
max_val = warp_reduce_max(max_val);
|
5678
|
+
} else {
|
5679
|
+
max_val = __half2half2(__hmax(max_val.x, max_val.y));
|
5680
|
+
}
|
5681
|
+
|
5682
|
+
half2 tmp = make_half2(0.0f, 0.0f); // partial sums
|
5683
|
+
|
5684
|
+
#pragma unroll
|
5685
|
+
for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
|
5686
|
+
const int col_smem = vals_smem ? col0 + tid : 2*col0 + 2*warp_id*WARP_SIZE + lane_id;
|
5687
|
+
|
5688
|
+
if (ncols_template == 0 && col_smem >= (vals_smem ? ncols_smem : ncols_data)) {
|
5689
|
+
break;
|
5690
|
+
}
|
5691
|
+
|
5692
|
+
const half2 val = h2exp(vals[col_smem] - max_val);
|
5693
|
+
|
5694
|
+
tmp += val;
|
5695
|
+
vals[col_smem] = val;
|
5696
|
+
}
|
5697
|
+
|
5698
|
+
// find the sum of exps in the block
|
5699
|
+
tmp = warp_reduce_sum(tmp);
|
5700
|
+
if (block_size > WARP_SIZE) {
|
5701
|
+
if (warp_id == 0) {
|
5702
|
+
buf_iw[lane_id] = 0.0f;
|
5703
|
+
}
|
5704
|
+
__syncthreads();
|
5705
|
+
|
5706
|
+
if (lane_id == 0) {
|
5707
|
+
buf_iw[warp_id] = tmp.x + tmp.y;
|
5708
|
+
}
|
5709
|
+
__syncthreads();
|
5710
|
+
|
5711
|
+
tmp = __half2half2(buf_iw[lane_id]);
|
5712
|
+
tmp = warp_reduce_sum(tmp);
|
5713
|
+
} else {
|
5714
|
+
tmp = __half2half2(tmp.x + tmp.y);
|
5715
|
+
}
|
5716
|
+
|
5717
|
+
const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp;
|
5718
|
+
|
5719
|
+
#pragma unroll
|
5720
|
+
for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
|
5721
|
+
const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
|
5722
|
+
const int col_smem = vals_smem ? col0 + tid : col_data;
|
5723
|
+
|
5724
|
+
const int idst = rowx*ncols_data + col_data;
|
5725
|
+
const half2 result = vals[col_smem] * inv_sum;
|
5726
|
+
|
5727
|
+
if (need_check && col_data + 0 >= ncols_data) {
|
5728
|
+
return;
|
5729
|
+
}
|
5730
|
+
dst[idst] = result.x;
|
5731
|
+
|
5732
|
+
if (need_check && col_data + WARP_SIZE >= ncols_data) {
|
5733
|
+
return;
|
5734
|
+
}
|
5735
|
+
|
5736
|
+
dst[idst + WARP_SIZE] = result.y;
|
5737
|
+
}
|
5738
|
+
#else
|
5739
|
+
(void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
|
5740
|
+
bad_arch();
|
5741
|
+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
5742
|
+
}
|
5743
|
+
|
5744
|
+
template <bool vals_smem, int ncols_template, int block_size_template>
|
5745
|
+
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
5746
|
+
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
5747
|
+
|
5205
5748
|
const int tid = threadIdx.x;
|
5206
5749
|
const int rowx = blockIdx.x;
|
5207
5750
|
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
|
5208
5751
|
|
5209
|
-
const int block_size = blockDim.x;
|
5752
|
+
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
|
5210
5753
|
|
5211
5754
|
const int warp_id = threadIdx.x / WARP_SIZE;
|
5212
5755
|
const int lane_id = threadIdx.x % WARP_SIZE;
|
5213
5756
|
|
5214
|
-
__shared__ float
|
5757
|
+
extern __shared__ float data_soft_max_f32[];
|
5758
|
+
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
|
5759
|
+
// shared memory buffer to cache values between iterations:
|
5760
|
+
float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;
|
5215
5761
|
|
5216
5762
|
float max_val = -INFINITY;
|
5217
5763
|
|
5218
|
-
|
5764
|
+
#pragma unroll
|
5765
|
+
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
5766
|
+
const int col = col0 + tid;
|
5767
|
+
|
5768
|
+
if (ncols_template == 0 && col >= ncols) {
|
5769
|
+
break;
|
5770
|
+
}
|
5771
|
+
|
5219
5772
|
const int ix = rowx*ncols + col;
|
5220
5773
|
const int iy = rowy*ncols + col;
|
5221
|
-
|
5774
|
+
|
5775
|
+
const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
|
5776
|
+
vals[col] = val;
|
5777
|
+
max_val = max(max_val, val);
|
5222
5778
|
}
|
5223
5779
|
|
5224
5780
|
// find the max value in the block
|
5225
5781
|
max_val = warp_reduce_max(max_val);
|
5226
5782
|
if (block_size > WARP_SIZE) {
|
5227
5783
|
if (warp_id == 0) {
|
5228
|
-
|
5784
|
+
buf_iw[lane_id] = -INFINITY;
|
5229
5785
|
}
|
5230
5786
|
__syncthreads();
|
5231
5787
|
|
5232
5788
|
if (lane_id == 0) {
|
5233
|
-
|
5789
|
+
buf_iw[warp_id] = max_val;
|
5234
5790
|
}
|
5235
5791
|
__syncthreads();
|
5236
5792
|
|
5237
|
-
max_val =
|
5793
|
+
max_val = buf_iw[lane_id];
|
5238
5794
|
max_val = warp_reduce_max(max_val);
|
5239
5795
|
}
|
5240
5796
|
|
5241
|
-
float tmp = 0.
|
5797
|
+
float tmp = 0.0f; // partial sum
|
5242
5798
|
|
5243
|
-
|
5244
|
-
|
5245
|
-
const int
|
5246
|
-
|
5799
|
+
#pragma unroll
|
5800
|
+
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
5801
|
+
const int col = col0 + tid;
|
5802
|
+
|
5803
|
+
if (ncols_template == 0 && col >= ncols) {
|
5804
|
+
break;
|
5805
|
+
}
|
5806
|
+
|
5807
|
+
const float val = expf(vals[col] - max_val);
|
5247
5808
|
tmp += val;
|
5248
|
-
|
5809
|
+
vals[col] = val;
|
5249
5810
|
}
|
5250
5811
|
|
5251
5812
|
// find the sum of exps in the block
|
5252
5813
|
tmp = warp_reduce_sum(tmp);
|
5253
5814
|
if (block_size > WARP_SIZE) {
|
5254
5815
|
if (warp_id == 0) {
|
5255
|
-
|
5816
|
+
buf_iw[lane_id] = 0.0f;
|
5256
5817
|
}
|
5257
5818
|
__syncthreads();
|
5258
5819
|
|
5259
5820
|
if (lane_id == 0) {
|
5260
|
-
|
5821
|
+
buf_iw[warp_id] = tmp;
|
5261
5822
|
}
|
5262
5823
|
__syncthreads();
|
5263
5824
|
|
5264
|
-
tmp =
|
5825
|
+
tmp = buf_iw[lane_id];
|
5265
5826
|
tmp = warp_reduce_sum(tmp);
|
5266
5827
|
}
|
5267
5828
|
|
5268
|
-
const float
|
5829
|
+
const float inv_sum = 1.0f / tmp;
|
5269
5830
|
|
5270
|
-
|
5271
|
-
|
5272
|
-
|
5831
|
+
#pragma unroll
|
5832
|
+
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
5833
|
+
const int col = col0 + tid;
|
5834
|
+
|
5835
|
+
if (ncols_template == 0 && col >= ncols) {
|
5836
|
+
return;
|
5837
|
+
}
|
5838
|
+
|
5839
|
+
const int idst = rowx*ncols + col;
|
5840
|
+
dst[idst] = vals[col] * inv_sum;
|
5273
5841
|
}
|
5274
5842
|
}
|
5275
5843
|
|
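
The new soft_max kernels cache one row in shared memory when it fits and otherwise stream intermediate values through dst (with col_smem == col_data to avoid races, as noted in the kernel); the launchers further below compare the required buffer against the device's sharedMemPerBlock (the new smpb field). A quick worked example of the sizing formula (standalone sketch, not part of the diff):

// Shared-memory budget mirrored from soft_max_f16_cuda / soft_max_f32_cuda (sketch only).
#include <cstdio>
#include <cstddef>

#define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))

int main() {
    const int WARP_SIZE = 32;
    const int ncols = 4096;                                                // one soft-max row
    const size_t shmem_f16 = (GGML_PAD(ncols, 2*WARP_SIZE) + WARP_SIZE)*2; // sizeof(half)  -> 8256 bytes
    const size_t shmem_f32 = (GGML_PAD(ncols,   WARP_SIZE) + WARP_SIZE)*4; // sizeof(float) -> 16512 bytes
    std::printf("f16: %zu bytes, f32: %zu bytes\n", shmem_f16, shmem_f32); // both fit a 48 KiB default limit
    return 0;
}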
@@ -5609,7 +6177,7 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
     dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
@@ -5659,6 +6227,24 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
 #endif
 }
 
+template<typename dst_t>
+static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template <typename src_t, typename dst_t>
+static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
 static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
@@ -5681,8 +6267,12 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_q5_K_cuda;
         case GGML_TYPE_Q6_K:
             return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_IQ2_XXS:
+            return dequantize_row_iq2_xxs_cuda;
+        case GGML_TYPE_IQ2_XS:
+            return dequantize_row_iq2_xs_cuda;
         case GGML_TYPE_F32:
-            return
+            return convert_unary_cuda<float>;
         default:
             return nullptr;
     }
@@ -5710,8 +6300,12 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
            return dequantize_row_q5_K_cuda;
        case GGML_TYPE_Q6_K:
            return dequantize_row_q6_K_cuda;
+       case GGML_TYPE_IQ2_XXS:
+           return dequantize_row_iq2_xxs_cuda;
+       case GGML_TYPE_IQ2_XS:
+           return dequantize_row_iq2_xs_cuda;
        case GGML_TYPE_F16:
-           return
+           return convert_unary_cuda<half>;
        default:
            return nullptr;
    }
@@ -5904,6 +6498,24 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
+static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
 static void ggml_mul_mat_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
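
In the two launchers above each block is WARP_SIZE lanes wide and GGML_CUDA_MMV_Y rows tall, so one warp produces one output row and block_num_y rounds nrows up to a multiple of GGML_CUDA_MMV_Y; for example, nrows = 4096 with GGML_CUDA_MMV_Y = 2 launches 2048 blocks.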
@@ -6543,12 +7155,90 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
|
|
6543
7155
|
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
|
6544
7156
|
}
|
6545
7157
|
|
7158
|
+
static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
|
7159
|
+
int nth = WARP_SIZE;
|
7160
|
+
while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
7161
|
+
const dim3 block_dims(nth, 1, 1);
|
7162
|
+
const dim3 block_nums(nrows_x, 1, 1);
|
7163
|
+
const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
|
7164
|
+
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
7165
|
+
if (shmem <= g_device_caps[g_main_device].smpb) {
|
7166
|
+
switch (ncols_x) {
|
7167
|
+
case 32:
|
7168
|
+
soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
7169
|
+
break;
|
7170
|
+
case 64:
|
7171
|
+
soft_max_f16<true, 64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
7172
|
+
break;
|
7173
|
+
case 128:
|
7174
|
+
soft_max_f16<true, 128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
7175
|
+
break;
|
7176
|
+
case 256:
|
7177
|
+
soft_max_f16<true, 256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
7178
|
+
break;
|
7179
|
+
case 512:
|
7180
|
+
soft_max_f16<true, 512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
7181
|
+
break;
|
7182
|
+
case 1024:
|
7183
|
+
soft_max_f16<true, 1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
7184
|
+
break;
|
7185
|
+
case 2048:
|
7186
|
+
soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
7187
|
+
break;
|
7188
|
+
case 4096:
|
7189
|
+
soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
7190
|
+
break;
|
7191
|
+
default:
|
7192
|
+
soft_max_f16<true, 0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
7193
|
+
break;
|
7194
|
+
}
|
7195
|
+
} else {
|
7196
|
+
const size_t shmem_low = WARP_SIZE*sizeof(half);
|
7197
|
+
soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
7198
|
+
}
|
7199
|
+
}
|
7200
|
+
|
6546
7201
|
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
|
6547
7202
|
int nth = WARP_SIZE;
|
6548
7203
|
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
6549
7204
|
const dim3 block_dims(nth, 1, 1);
|
6550
7205
|
const dim3 block_nums(nrows_x, 1, 1);
|
6551
|
-
|
7206
|
+
const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
|
7207
|
+
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
7208
|
+
if (shmem < g_device_caps[g_main_device].smpb) {
|
7209
|
+
switch (ncols_x) {
|
7210
|
+
case 32:
|
7211
|
+
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
7212
|
+
break;
|
7213
|
+
case 64:
|
7214
|
+
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
7215
|
+
break;
|
7216
|
+
case 128:
|
7217
|
+
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
7218
|
+
break;
|
7219
|
+
case 256:
|
7220
|
+
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
7221
|
+
break;
|
7222
|
+
case 512:
|
7223
|
+
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
7224
|
+
break;
|
7225
|
+
case 1024:
|
7226
|
+
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
7227
|
+
break;
|
7228
|
+
case 2048:
|
7229
|
+
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
7230
|
+
break;
|
7231
|
+
case 4096:
|
7232
|
+
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
7233
|
+
break;
|
7234
|
+
default:
|
7235
|
+
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
7236
|
+
break;
|
7237
|
+
}
|
7238
|
+
} else {
|
7239
|
+
const size_t shmem_low = WARP_SIZE*sizeof(float);
|
7240
|
+
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
7241
|
+
}
|
6552
7242
|
}
|
6553
7243
|
|
6554
7244
|
static void im2col_f32_f16_cuda(const float* x, half* dst,
|
@@ -6863,6 +7553,7 @@ void ggml_init_cublas() {
 #else
         g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+        g_device_caps[id].smpb = prop.sharedMemPerBlock;
     }
     for (int id = 0; id < g_device_count; ++id) {
         g_tensor_split[id] /= total_vram;
@@ -7396,6 +8087,8 @@ static int64_t get_row_rounding(ggml_type type) {
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
             return max_compute_capability >= CC_RDNA2 ? 128 : 64;
         default:
             GGML_ASSERT(false);
@@ -7416,6 +8109,8 @@ static int64_t get_row_rounding(ggml_type type) {
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
             return max_compute_capability >= CC_VOLTA ? 128 : 64;
         case GGML_TYPE_Q6_K:
             return 64;
@@ -7466,6 +8161,12 @@ static void ggml_cuda_op_mul_mat_vec_q(
         case GGML_TYPE_Q6_K:
             mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
+        case GGML_TYPE_IQ2_XXS:
+            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_IQ2_XS:
+            mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
         default:
             GGML_ASSERT(false);
             break;
@@ -7873,7 +8574,21 @@ static void ggml_cuda_op_soft_max(
     float scale = 1.0f;
     memcpy(&scale, dst->op_params, sizeof(float));
 
-
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    const bool use_f16_soft_max = false;
+#else
+#ifdef GGML_CUDA_F16
+    const bool use_f16_soft_max = true;
+#else
+    const bool use_f16_soft_max = false;
+#endif // GGML_CUDA_F16
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+    if (use_f16_soft_max) {
+        soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    } else {
+        soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    }
 
     (void) dst;
 }
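
As the guards above show, the f16 soft-max path is only taken when ggml-cuda.cu is built with GGML_CUDA_F16 defined (e.g. -DGGML_CUDA_F16) and not on the HIP/AMD build; otherwise the existing f32 kernels are used.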
@@ -8682,6 +9397,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
 
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 
+    use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type);
+
     // debug helpers
     //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
     //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
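
Since ggml_cuda_supports_mmq (added earlier in this diff) returns false for the IQ2 types, this extra check keeps them off the quantized mul_mat_q path; matrix-vector products instead go through the new MMVQ kernels, with the dequantize converters registered above covering the remaining paths.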
@@ -9689,8 +10406,8 @@ static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, gg
 
     ggml_cuda_set_device(ctx->device);
     CUDA_CHECK(cudaDeviceSynchronize());
-
     CUDA_CHECK(cudaMemcpy((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice));
+    CUDA_CHECK(cudaDeviceSynchronize());
 }
 
 static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
@@ -9910,7 +10627,7 @@ static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_ba
     UNUSED(plan);
 }
 
-static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+static bool ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
 
     ggml_cuda_set_main_device(cuda_ctx->device);
@@ -9967,6 +10684,8 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
     }
 
     UNUSED(backend);
+
+    return true;
 }
 
 static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {