llama_cpp 0.12.0 → 0.12.1
This diff shows the changes between the publicly released package versions as they appear in their public registry; it is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/ext/llama_cpp/llama_cpp.cpp +14 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -0
- data/vendor/tmp/llama.cpp/Makefile +8 -2
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-backend.c +7 -3
- data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +758 -39
- data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +86 -7
- data/vendor/tmp/llama.cpp/ggml-metal.metal +692 -8
- data/vendor/tmp/llama.cpp/ggml-quants.c +635 -1
- data/vendor/tmp/llama.cpp/ggml-quants.h +25 -1
- data/vendor/tmp/llama.cpp/ggml.c +91 -52
- data/vendor/tmp/llama.cpp/ggml.h +14 -11
- data/vendor/tmp/llama.cpp/llama.cpp +79 -30
- data/vendor/tmp/llama.cpp/llama.h +14 -0
- metadata +2 -2
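In summary (inferred from the vendored diff below rather than from an upstream changelog): the bundled llama.cpp sources are updated to add the new 2-bit IQ2_XXS and IQ2_XS quantization types to the CUDA, Metal, and CPU quantization code, and the CUDA backend gains a shared-memory, optionally half-precision soft-max path plus a `bool` return value for graph compute. A worked block-size example is given after the struct definitions below.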
@@ -116,6 +116,7 @@
 #include "ggml.h"
 #include "ggml-backend-impl.h"
 
+#define CC_PASCAL 600
 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA 700
 #define CC_OFFSET_AMD 1000000
@@ -183,7 +184,7 @@ static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
 static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
     c = __builtin_amdgcn_sdot4(a, b, c, false);
-#elif defined(
+#elif defined(RDNA3)
     c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
 #elif defined(__gfx1010__) || defined(__gfx900__)
     int tmp1;
@@ -477,6 +478,23 @@ typedef struct {
 } block_q6_K;
 static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
 
+#define QR2_XXS 8
+#define QI2_XXS (QK_K / (4*QR2_XXS))
+typedef struct {
+    half d;
+    uint16_t qs[QK_K/8];
+} block_iq2_xxs;
+static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
+
+#define QR2_XS 8
+#define QI2_XS (QK_K / (4*QR2_XS))
+typedef struct {
+    half d;
+    uint16_t qs[QK_K/8];
+    uint8_t scales[QK_K/32];
+} block_iq2_xs;
+static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
+
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
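As a sanity check on the two `static_assert`s above, here is a small host-side C sketch (an illustration only; it assumes `QK_K == 256` and uses a plain 16-bit integer in place of the CUDA `half`/`ggml_fp16_t`) that reproduces the block sizes: 66 bytes per 256 weights (about 2.06 bits per weight) for `block_iq2_xxs` and 74 bytes (about 2.31 bits per weight) for `block_iq2_xs`.

```c
#include <stdint.h>
#include <stdio.h>

/* Assumptions for this sketch: QK_K == 256 and a 16-bit stand-in for half/ggml_fp16_t. */
#define QK_K 256
typedef uint16_t ggml_fp16_t;

typedef struct {
    ggml_fp16_t d;               /* super-block scale                                   */
    uint16_t    qs[QK_K/8];      /* packed grid indices, sign indices and sub-scales    */
} block_iq2_xxs;                 /* 2 + 64 = 66 bytes                                   */

typedef struct {
    ggml_fp16_t d;
    uint16_t    qs[QK_K/8];      /* per entry: 9-bit grid index + 7-bit sign index      */
    uint8_t     scales[QK_K/32]; /* two 4-bit sub-scales per byte                       */
} block_iq2_xs;                  /* 2 + 64 + 8 = 74 bytes                               */

int main(void) {
    printf("iq2_xxs: %zu bytes per %d weights -> %.4f bpw\n",
           sizeof(block_iq2_xxs), QK_K, 8.0 * sizeof(block_iq2_xxs) / QK_K);
    printf("iq2_xs : %zu bytes per %d weights -> %.4f bpw\n",
           sizeof(block_iq2_xs),  QK_K, 8.0 * sizeof(block_iq2_xs)  / QK_K);
    return 0;
}
```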
@@ -548,11 +566,12 @@ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
 
 struct cuda_device_capabilities {
     int cc; // compute capability
+    size_t smpb; // max. shared memory per block
     bool vmm; // virtual memory support
     size_t vmm_granularity; // granularity of virtual memory
 };
 
-static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
+static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, 0, false, 0} };
 
 static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 0; // disabled by default
@@ -585,6 +604,19 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
     return a;
 }
 
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+    (void) a;
+    bad_arch();
+#else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+}
+
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
@@ -593,6 +625,19 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
     return x;
 }
 
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+    (void) x;
+    bad_arch();
+#else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+}
+
 static __device__ __forceinline__ float op_repeat(const float a, const float b) {
     return b;
     GGML_UNUSED(a);
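The two `half2` overloads above mirror the existing `float` reductions: five `__shfl_xor_sync` butterfly steps (masks 16, 8, 4, 2, 1) fold the 32-lane warp in half each time, so every lane ends up holding the warp-wide sum or maximum. The `__CUDA_ARCH__ < CC_PASCAL` guard, together with the new `CC_PASCAL 600` define introduced at the top of this diff, appears to keep these half-precision paths off pre-Pascal and HIP/AMD builds, where they fall through to `bad_arch()` instead.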
@@ -1292,6 +1337,281 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
 #endif
 }
 
+static const __device__ uint64_t iq2xxs_grid[256] = {
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
+    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
+    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
+    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
+    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
+    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
+    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
+    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
+    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
+    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
+    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
+    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
+    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
+    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
+    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
+    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
+    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
+    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
+    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
+    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
+    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
+    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
+    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
+    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
+    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
+    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
+    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
+    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
+    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
+    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
+    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
+    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
+    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
+    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
+    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
+    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
+    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
+    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
+    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
+    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
+    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
+    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
+    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
+    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
+    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
+    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
+    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
+    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
+    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
+    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
+    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
+    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
+    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
+    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
+    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
+    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
+    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
+    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
+    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
+    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
+    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
+    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
+    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
+};
+
+static const __device__ uint64_t iq2xs_grid[512] = {
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
+    0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
+    0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
+    0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
+    0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
+    0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
+    0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
+    0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
+    0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
+    0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
+    0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
+    0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
+    0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
+    0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
+    0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
+    0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
+    0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
+    0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
+    0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
+    0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
+    0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
+    0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
+    0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
+    0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
+    0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
+    0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
+    0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
+    0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
+    0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
+    0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
+    0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
+    0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
+    0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
+    0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
+    0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
+    0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
+    0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
+    0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
+    0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
+    0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
+    0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
+    0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
+    0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
+    0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
+    0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
+    0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
+    0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
+    0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
+    0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
+    0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
+    0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
+    0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
+    0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
+    0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
+    0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
+    0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
+    0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
+    0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
+    0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
+    0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
+    0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
+    0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
+    0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
+    0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
+    0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
+    0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
+    0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
+    0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
+    0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
+    0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
+    0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
+    0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
+    0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
+    0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
+    0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
+    0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
+    0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
+    0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
+    0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
+    0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
+    0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
+    0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
+    0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
+    0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
+    0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
+    0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
+    0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
+    0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
+    0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
+    0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
+    0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
+    0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
+    0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
+    0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
+    0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
+    0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
+    0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
+    0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
+    0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
+    0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
+    0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
+    0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
+    0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
+    0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
+    0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
+    0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
+    0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
+    0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
+    0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
+    0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
+    0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
+    0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
+    0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
+    0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
+    0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
+    0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
+    0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
+    0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
+    0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
+    0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
+    0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
+    0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
+    0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
+    0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
+    0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
+    0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
+    0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
+};
+
+static const __device__ uint8_t ksigns_iq2xs[128] = {
+      0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
+    144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
+    160,  33,  34, 163,  36, 165, 166,  39,  40, 169, 170,  43, 172,  45,  46, 175,
+     48, 177, 178,  51, 180,  53,  54, 183, 184,  57,  58, 187,  60, 189, 190,  63,
+    192,  65,  66, 195,  68, 197, 198,  71,  72, 201, 202,  75, 204,  77,  78, 207,
+     80, 209, 210,  83, 212,  85,  86, 215, 216,  89,  90, 219,  92, 221, 222,  95,
+     96, 225, 226,  99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
+    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
+};
+
+static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
+
+inline bool ggml_cuda_supports_mmq(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+            return true;
+        default:
+            return false;
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i = blockIdx.x;
+    const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
+
+    const int tid = threadIdx.x;
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint16_t * q2 = x[i].qs + 4*ib;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const uint8_t  * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
+    const uint32_t aux32 = q2[2] | (q2[3] << 16);
+    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
+    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+#else
+    assert(false);
+#endif
+
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i = blockIdx.x;
+    const block_iq2_xs * x = (const block_iq2_xs *) vx;
+
+    const int tid = threadIdx.x;
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint16_t * q2 = x[i].qs + 4*ib;
+    const uint8_t  * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
+    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
+    const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+#else
+    assert(false);
+#endif
+
+}
+
 static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
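Reading the new kernels: for iq2_xxs each group of 32 weights is described by four `uint16_t` values; the first two supply four byte-sized indices into `iq2xxs_grid` (8 weights each), while `aux32 = q2[2] | (q2[3] << 16)` packs four 7-bit indices into `ksigns_iq2xs` plus a 4-bit sub-scale in bits 28..31, which is where the `d * (0.5f + (aux32 >> 28)) * 0.25f` scaling above comes from. iq2_xs instead stores a 9-bit grid index and a 7-bit sign index directly in each `uint16_t` (`q2[l] & 511`, `q2[l] >> 9`) and keeps its 4-bit sub-scales in the separate `scales` array.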
@@ -1872,14 +2192,6 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
     v.y = x[ib + iqs + 1];
 }
 
-static __device__ void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
-    const float * x = (const float *) vx;
-
-    // automatic half -> float type cast if dfloat == float
-    v.x = x[ib + iqs + 0];
-    v.y = x[ib + iqs + 1];
-}
-
 static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
     const int ix = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -1983,7 +2295,7 @@ static __global__ void k_get_rows_float(
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+    const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
 
     if (i >= k) {
         return;
@@ -2002,6 +2314,19 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
     y[iybs + iqs + y_offset] = v.y;
 }
 
+template <typename src_t, typename dst_t>
+static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    const src_t * x = (src_t *) vx;
+
+    y[i] = x[i];
+}
+
 // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
 // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
 
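The templated `convert_unary` kernel added here is a plain element-wise cast. Further down in this diff, `ggml_get_to_fp16_cuda` and `ggml_get_to_fp32_cuda` return `convert_unary_cuda<float>` and `convert_unary_cuda<half>` for the F32 and F16 cases, which is presumably why the two-element `convert_f32` helper above could be removed.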
@@ -3820,6 +4145,91 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
     return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
 }
 
+static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+#if QK_K == 256
+    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
+
+#if QR2_XXS == 8
+    const int ib32 = iqs;
+    const uint16_t * q2 = bq2->qs + 4*ib32;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const int8_t   * q8 = bq8_1[ib32].qs;
+    uint32_t aux32 = q2[2] | (q2[3] << 16);
+    int sumi = 0;
+    for (int l = 0; l < 4; ++l) {
+        const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
+        const uint8_t signs = ksigns_iq2xs[aux32 & 127];
+        for (int j = 0; j < 8; ++j) {
+            sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+        }
+        q8 += 8;
+        aux32 >>= 7;
+    }
+    const float d = (float)bq2->d * (0.5f + aux32) * (float)bq8_1[ib32].ds.x * 0.25f;
+    return d * sumi;
+#else
+    // iqs is 0...15
+    const int ib32 = iqs/2;
+    const int il = iqs%2;
+    const uint16_t * q2 = bq2->qs + 4*ib32;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const uint8_t  * grid1 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
+    const uint8_t  * grid2 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
+    const uint32_t aux32 = q2[2] | (q2[3] << 16);
+    const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;
+    const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
+    const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
+    const int8_t * q8 = bq8_1[ib32].qs + 16*il;
+    int sumi1 = 0, sumi2 = 0;
+    for (int j = 0; j < 8; ++j) {
+        sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
+        sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
+    }
+    return d * (sumi1 + sumi2);
+#endif
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
+
+static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+#if QK_K == 256
+    const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
+
+    const int ib32 = iqs;
+    const uint16_t * q2 = bq2->qs + 4*ib32;
+    const int8_t   * q8 = bq8_1[ib32].qs;
+    const uint8_t ls1 = bq2->scales[ib32] & 0xf;
+    const uint8_t ls2 = bq2->scales[ib32] >> 4;
+    int sumi1 = 0;
+    for (int l = 0; l < 2; ++l) {
+        const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
+        const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
+        for (int j = 0; j < 8; ++j) {
+            sumi1 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+        }
+        q8 += 8;
+    }
+    int sumi2 = 0;
+    for (int l = 2; l < 4; ++l) {
+        const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
+        const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
+        for (int j = 0; j < 8; ++j) {
+            sumi2 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+        }
+        q8 += 8;
+    }
+    const float d = (float)bq2->d * (float)bq8_1[ib32].ds.x * 0.25f;
+    return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
+
 template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
     allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __device__ __forceinline__ void mul_mat_q(
@@ -5201,75 +5611,233 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }
 
-
+template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
+static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+    const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
+    const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
+
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
+
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    extern __shared__ half data_soft_max_f16[];
+    half  * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
+    // (shared memory) buffer to cache values between iterations:
+    half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data);
+    // if the buffer is larger than max. shared memory per block, use dst as temp. buffer instead
+    // in that case col_smem == col_data must be enforced to avoid race conditions
+
+    half2 max_val = make_half2(-INFINITY, -INFINITY);
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
+        const int col_smem = vals_smem ? col0 + tid : col_data;
+
+        const int ix = rowx*ncols_data + col_data;
+        const int iy = rowy*ncols_data + col_data;
+
+        half2 val;
+        if (need_check && col_data + 0 >= ncols_data) {
+            val.x = -INFINITY;
+        } else {
+            val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
+        }
+        if (need_check && col_data + WARP_SIZE >= ncols_data) {
+            val.y = -INFINITY;
+        } else {
+            val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
+        }
+        if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
+            vals[col_smem] = val;
+        }
+        max_val = __hmax2(max_val, val);
+    }
+
+    // find the max value in the block
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf_iw[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = __hmax(max_val.x, max_val.y);
+        }
+        __syncthreads();
+
+        max_val = __half2half2(buf_iw[lane_id]);
+        max_val = warp_reduce_max(max_val);
+    } else {
+        max_val = __half2half2(__hmax(max_val.x, max_val.y));
+    }
+
+    half2 tmp = make_half2(0.0f, 0.0f); // partial sums
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_smem = vals_smem ? col0 + tid : 2*col0 + 2*warp_id*WARP_SIZE + lane_id;
+
+        if (ncols_template == 0 && col_smem >= (vals_smem ? ncols_smem : ncols_data)) {
+            break;
+        }
+
+        const half2 val = h2exp(vals[col_smem] - max_val);
+
+        tmp += val;
+        vals[col_smem] = val;
+    }
+
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf_iw[lane_id] = 0.0f;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = tmp.x + tmp.y;
+        }
+        __syncthreads();
+
+        tmp = __half2half2(buf_iw[lane_id]);
+        tmp = warp_reduce_sum(tmp);
+    } else {
+        tmp = __half2half2(tmp.x + tmp.y);
+    }
+
+    const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp;
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
+        const int col_smem = vals_smem ? col0 + tid : col_data;
+
+        const int idst = rowx*ncols_data + col_data;
+        const half2 result = vals[col_smem] * inv_sum;
+
+        if (need_check && col_data + 0 >= ncols_data) {
+            return;
+        }
+        dst[idst] = result.x;
+
+        if (need_check && col_data + WARP_SIZE >= ncols_data) {
+            return;
+        }
+
+        dst[idst + WARP_SIZE] = result.y;
+    }
+#else
+    (void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
+    bad_arch();
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template>
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
+    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+
     const int tid  = threadIdx.x;
     const int rowx = blockIdx.x;
     const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
 
-    const int block_size = blockDim.x;
+    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
 
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    __shared__ float
+    extern __shared__ float data_soft_max_f32[];
+    float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
+    // shared memory buffer to cache values between iterations:
+    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;
 
     float max_val = -INFINITY;
 
-
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
        const int ix = rowx*ncols + col;
        const int iy = rowy*ncols + col;
-
+
+        const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
+        vals[col] = val;
+        max_val = max(max_val, val);
     }
 
     // find the max value in the block
     max_val = warp_reduce_max(max_val);
     if (block_size > WARP_SIZE) {
         if (warp_id == 0) {
-
+            buf_iw[lane_id] = -INFINITY;
         }
         __syncthreads();
 
         if (lane_id == 0) {
-
+            buf_iw[warp_id] = max_val;
         }
         __syncthreads();
 
-        max_val =
+        max_val = buf_iw[lane_id];
         max_val = warp_reduce_max(max_val);
     }
 
-    float tmp = 0.
+    float tmp = 0.0f; // partial sum
 
-
-
-        const int
-
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const float val = expf(vals[col] - max_val);
         tmp += val;
-
+        vals[col] = val;
     }
 
     // find the sum of exps in the block
     tmp = warp_reduce_sum(tmp);
     if (block_size > WARP_SIZE) {
         if (warp_id == 0) {
-
+            buf_iw[lane_id] = 0.0f;
         }
         __syncthreads();
 
         if (lane_id == 0) {
-
+            buf_iw[warp_id] = tmp;
         }
         __syncthreads();
 
-        tmp =
+        tmp = buf_iw[lane_id];
         tmp = warp_reduce_sum(tmp);
     }
 
-    const float
+    const float inv_sum = 1.0f / tmp;
 
-
-
-
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            return;
+        }
+
+        const int idst = rowx*ncols + col;
+        dst[idst] = vals[col] * inv_sum;
     }
 }
 
@@ -5609,7 +6177,7 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
     dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
@@ -5659,6 +6227,24 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
 #endif
 }
 
+template<typename dst_t>
+static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template <typename src_t, typename dst_t>
+static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
 static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
@@ -5681,8 +6267,12 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_q5_K_cuda;
         case GGML_TYPE_Q6_K:
             return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_IQ2_XXS:
+            return dequantize_row_iq2_xxs_cuda;
+        case GGML_TYPE_IQ2_XS:
+            return dequantize_row_iq2_xs_cuda;
         case GGML_TYPE_F32:
-            return
+            return convert_unary_cuda<float>;
         default:
             return nullptr;
     }
@@ -5710,8 +6300,12 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_q5_K_cuda;
         case GGML_TYPE_Q6_K:
             return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_IQ2_XXS:
+            return dequantize_row_iq2_xxs_cuda;
+        case GGML_TYPE_IQ2_XS:
+            return dequantize_row_iq2_xs_cuda;
         case GGML_TYPE_F16:
-            return
+            return convert_unary_cuda<half>;
         default:
             return nullptr;
     }
@@ -5904,6 +6498,24 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
+static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
 static void ggml_mul_mat_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@@ -6543,12 +7155,90 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
+static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth,     1, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
+    const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
+    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+    if (shmem <= g_device_caps[g_main_device].smpb) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 64:
+                soft_max_f16<true, 64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 128:
+                soft_max_f16<true, 128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 256:
+                soft_max_f16<true, 256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 512:
+                soft_max_f16<true, 512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 1024:
+                soft_max_f16<true, 1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 2048:
+                soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 4096:
+                soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            default:
+                soft_max_f16<true, 0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+        }
+    } else {
+        const size_t shmem_low = WARP_SIZE*sizeof(half);
+        soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+    }
+}
+
 static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-
+    const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
+    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+    if (shmem < g_device_caps[g_main_device].smpb) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 64:
+                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 128:
+                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 256:
+                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 512:
+                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 1024:
+                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 2048:
+                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 4096:
+                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            default:
+                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+        }
+    } else {
+        const size_t shmem_low = WARP_SIZE*sizeof(float);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+    }
 }
 
 static void im2col_f32_f16_cuda(const float* x, half* dst,
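A worked example of the shared-memory sizing above (illustrative; the 48 KB figure is the common per-block default and is hardware dependent): for `ncols_x = 4096`, the f16 launcher requests `(GGML_PAD(4096, 64) + 32) * sizeof(half) = 4128 * 2 = 8256` bytes, while the f32 launcher requests `(GGML_PAD(4096, 32) + 32) * sizeof(float) = 4128 * 4 = 16512` bytes, so caching the row in half precision roughly halves shared-memory pressure. When the request exceeds the device's `smpb` (the `sharedMemPerBlock` value recorded in `ggml_init_cublas` below), both launchers fall back to the `vals_smem = false` variants that stage values in `dst` instead.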
@@ -6863,6 +7553,7 @@ void ggml_init_cublas() {
 #else
         g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+        g_device_caps[id].smpb = prop.sharedMemPerBlock;
     }
     for (int id = 0; id < g_device_count; ++id) {
         g_tensor_split[id] /= total_vram;
@@ -7396,6 +8087,8 @@ static int64_t get_row_rounding(ggml_type type) {
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
             return max_compute_capability >= CC_RDNA2 ? 128 : 64;
         default:
             GGML_ASSERT(false);
@@ -7416,6 +8109,8 @@ static int64_t get_row_rounding(ggml_type type) {
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
             return max_compute_capability >= CC_VOLTA ? 128 : 64;
         case GGML_TYPE_Q6_K:
             return 64;
@@ -7466,6 +8161,12 @@ static void ggml_cuda_op_mul_mat_vec_q(
         case GGML_TYPE_Q6_K:
             mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
+        case GGML_TYPE_IQ2_XXS:
+            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_IQ2_XS:
+            mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
         default:
             GGML_ASSERT(false);
             break;
@@ -7873,7 +8574,21 @@ static void ggml_cuda_op_soft_max(
     float scale = 1.0f;
     memcpy(&scale, dst->op_params, sizeof(float));
 
-
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    const bool use_f16_soft_max = false;
+#else
+#ifdef GGML_CUDA_F16
+    const bool use_f16_soft_max = true;
+#else
+    const bool use_f16_soft_max = false;
+#endif // GGML_CUDA_F16
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+    if (use_f16_soft_max) {
+        soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    } else {
+        soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    }
 
     (void) dst;
 }
@@ -8682,6 +9397,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
 
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 
+    use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type);
+
     // debug helpers
     //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
     //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
@@ -9689,8 +10406,8 @@ static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, gg
 
     ggml_cuda_set_device(ctx->device);
     CUDA_CHECK(cudaDeviceSynchronize());
-
     CUDA_CHECK(cudaMemcpy((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice));
+    CUDA_CHECK(cudaDeviceSynchronize());
 }
 
 static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
@@ -9910,7 +10627,7 @@ static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_ba
     UNUSED(plan);
 }
 
-static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+static bool ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
 
     ggml_cuda_set_main_device(cuda_ctx->device);
@@ -9967,6 +10684,8 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
     }
 
     UNUSED(backend);
+
+    return true;
 }
 
 static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {