llama_cpp 0.12.0 → 0.12.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -116,6 +116,7 @@
116
116
  #include "ggml.h"
117
117
  #include "ggml-backend-impl.h"
118
118
 
119
+ #define CC_PASCAL 600
119
120
  #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
120
121
  #define CC_VOLTA 700
121
122
  #define CC_OFFSET_AMD 1000000
@@ -183,7 +184,7 @@ static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
183
184
  static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
184
185
  #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
185
186
  c = __builtin_amdgcn_sdot4(a, b, c, false);
186
- #elif defined(__gfx1100__)
187
+ #elif defined(RDNA3)
187
188
  c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
188
189
  #elif defined(__gfx1010__) || defined(__gfx900__)
189
190
  int tmp1;
@@ -477,6 +478,23 @@ typedef struct {
477
478
  } block_q6_K;
478
479
  static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
479
480
 
481
// IQ2_XXS super-block: one fp16 scale followed by QK_K/8 16-bit words.
// Each 32-weight sub-block owns 4 consecutive qs words:
//   words 0-1 : four byte indices into iq2xxs_grid (one per 8-weight group)
//   words 2-3 : four 7-bit ksigns_iq2xs indices (bits 0..27) and a 4-bit scale (bits 28..31)
// QR2_XXS / QI2_XXS must remain macros: they are tested with #if elsewhere in the file.
#define QR2_XXS 8
#define QI2_XXS (QK_K / (4*QR2_XXS))
typedef struct {
    half     d;           // super-block scale
    uint16_t qs[QK_K/8];  // packed grid indices, sign indices and sub-block scales
} block_iq2_xxs;
static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + (QK_K/8)*sizeof(uint16_t), "wrong iq2_xxs block size/padding");

// IQ2_XS super-block: one fp16 scale, QK_K/8 16-bit words, QK_K/32 scale bytes.
// Each qs word covers 8 weights: low 9 bits index iq2xs_grid, high 7 bits index ksigns_iq2xs.
// Each scales byte covers 32 weights: low nibble for the first 16, high nibble for the last 16.
#define QR2_XS 8
#define QI2_XS (QK_K / (4*QR2_XS))
typedef struct {
    half     d;                // super-block scale
    uint16_t qs[QK_K/8];       // grid index | (sign index << 9)
    uint8_t  scales[QK_K/32];  // two 4-bit sub-block scales per byte
} block_iq2_xs;
static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + (QK_K/8)*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
497
+
480
498
  #define WARP_SIZE 32
481
499
  #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
482
500
 
@@ -548,11 +566,12 @@ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
548
566
 
549
567
  struct cuda_device_capabilities {
550
568
  int cc; // compute capability
569
+ size_t smpb; // max. shared memory per block
551
570
  bool vmm; // virtual memory support
552
571
  size_t vmm_granularity; // granularity of virtual memory
553
572
  };
554
573
 
555
- static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
574
+ static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, 0, false, 0} };
556
575
 
557
576
  static void * g_scratch_buffer = nullptr;
558
577
  static size_t g_scratch_size = 0; // disabled by default
@@ -585,6 +604,19 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
585
604
  return a;
586
605
  }
587
606
 
607
// Sums a half2 across all 32 lanes of a warp with an xor-shuffle butterfly (full mask,
// so every lane must participate); afterwards every lane holds the total.
// half2 arithmetic is only compiled for Pascal+ CUDA targets; older archs and HIP/AMD
// builds call bad_arch() instead.
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if __CUDA_ARCH__ >= CC_PASCAL && !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
#pragma unroll
    for (int lane_offset = 16; lane_offset >= 1; lane_offset /= 2) {
        const half2 partner = __shfl_xor_sync(0xffffffff, a, lane_offset, 32);
        a = __hadd2(a, partner);
    }
    return a;
#else
    (void) a;
    bad_arch();
#endif // __CUDA_ARCH__ >= CC_PASCAL && !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
}
619
+
588
620
  static __device__ __forceinline__ float warp_reduce_max(float x) {
589
621
  #pragma unroll
590
622
  for (int mask = 16; mask > 0; mask >>= 1) {
@@ -593,6 +625,19 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
593
625
  return x;
594
626
  }
595
627
 
628
// Element-wise maximum of a half2 across all 32 lanes of a warp, using an xor-shuffle
// butterfly (full mask, so every lane must participate); afterwards every lane holds the result.
// Only compiled for Pascal+ CUDA targets; older archs and HIP/AMD builds call bad_arch().
//
// __hmax2 is not declared by older CUDA toolkits for archs below sm_80. It is therefore used
// only with CUDA >= 11.7; otherwise the max is computed in float. The float fallback is exact,
// because half -> float -> half round-trips losslessly. Like __hmax2, fmaxf returns the
// non-NaN operand when only one input is NaN.
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
    (void) x;
    bad_arch();
#else
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        const half2 other = __shfl_xor_sync(0xffffffff, x, mask, 32);
#if defined(CUDART_VERSION) && CUDART_VERSION >= 11070
        x = __hmax2(x, other);
#else
        x = __floats2half2_rn(fmaxf(__low2float(x),  __low2float(other)),
                              fmaxf(__high2float(x), __high2float(other)));
#endif // defined(CUDART_VERSION) && CUDART_VERSION >= 11070
    }
    return x;
#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
}
640
+
596
641
  static __device__ __forceinline__ float op_repeat(const float a, const float b) {
597
642
  return b;
598
643
  GGML_UNUSED(a);
@@ -1292,6 +1337,281 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
1292
1337
  #endif
1293
1338
  }
1294
1339
 
1340
// IQ2_XXS codebook: 256 entries, each read byte-wise as 8 unsigned magnitudes
// (byte values 0x08, 0x19, 0x2b, i.e. 8, 25, 43). It is indexed by one byte per
// 8-weight group. Signs and scale are applied separately by the consumers.
static const __device__ uint64_t iq2xxs_grid[256] = {
    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
};
1406
+
1407
// IQ2_XS codebook: 512 entries, each read byte-wise as 8 unsigned magnitudes
// (byte values 0x08, 0x19, 0x2b). It is indexed by the low 9 bits of a block_iq2_xs qs word.
static const __device__ uint64_t iq2xs_grid[512] = {
    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
    0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
    0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
    0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
    0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
    0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
    0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
    0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
    0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
    0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
    0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
    0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
    0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
    0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
    0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
    0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
    0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
    0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
    0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
    0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
    0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
    0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
    0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
    0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
    0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
    0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
    0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
    0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
    0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
    0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
    0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
    0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
    0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
    0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
    0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
    0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
    0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
    0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
    0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
    0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
    0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
    0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
    0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
    0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
    0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
    0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
    0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
    0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
    0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
    0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
    0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
    0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
    0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
    0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
    0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
    0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
    0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
    0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
    0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
    0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
    0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
    0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
    0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
    0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
    0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
    0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
    0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
    0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
    0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
    0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
    0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
    0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
    0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
    0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
    0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
    0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
    0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
    0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
    0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
    0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
    0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
    0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
    0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
    0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
    0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
    0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
    0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
    0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
    0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
    0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
    0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
    0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
    0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
    0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
    0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
    0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
    0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
    0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
    0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
    0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
    0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
    0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
    0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
    0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
    0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
    0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
    0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
    0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
    0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
    0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
    0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
    0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
    0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
    0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
    0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
    0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
    0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
    0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
    0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
    0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
    0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
    0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
    0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
    0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
    0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
    0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
    0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
};
1537
+
1538
// Sign patterns for 8-weight groups, indexed by a 7-bit value i.
// Entry i equals i with bit 7 set exactly when i has an odd number of set bits, so every
// pattern negates an even number of weights. Bit j set means weight j is negated
// (tested via kmask_iq2xs[j]).
static const __device__ uint8_t ksigns_iq2xs[128] = {
    /*   0 */   0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
    /*  16 */ 144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
    /*  32 */ 160,  33,  34, 163,  36, 165, 166,  39,  40, 169, 170,  43, 172,  45,  46, 175,
    /*  48 */  48, 177, 178,  51, 180,  53,  54, 183, 184,  57,  58, 187,  60, 189, 190,  63,
    /*  64 */ 192,  65,  66, 195,  68, 197, 198,  71,  72, 201, 202,  75, 204,  77,  78, 207,
    /*  80 */  80, 209, 210,  83, 212,  85,  86, 215, 216,  89,  90, 219,  92, 221, 222,  95,
    /*  96 */  96, 225, 226,  99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
    /* 112 */ 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
};
1548
+
1549
// kmask_iq2xs[j] == 1 << j: selects the sign bit of weight j in a ksigns_iq2xs entry.
static const __device__ uint8_t kmask_iq2xs[8] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80};
1550
+
1551
// Returns true only for the legacy quants (Q4_0, Q4_1, Q5_0, Q5_1, Q8_0) and the
// k-quants (Q2_K .. Q6_K); every other type yields false.
// NOTE(review): the name suggests callers use this to gate the mul_mat_q path;
// confirm at the call sites.
inline bool ggml_cuda_supports_mmq(enum ggml_type type) {
    const bool is_legacy_quant =
        type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 ||
        type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1 ||
        type == GGML_TYPE_Q8_0;
    const bool is_k_quant =
        type == GGML_TYPE_Q2_K || type == GGML_TYPE_Q3_K ||
        type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K ||
        type == GGML_TYPE_Q6_K;
    return is_legacy_quant || is_k_quant;
}
1568
+
1569
// Dequantizes IQ2_XXS data into yy (QK_K == 256 only; other QK_K values assert).
// blockIdx.x selects the super-block. The indexing assumes 32 threads per block
// (il = 0..3, ib = 0..7); each thread writes 8 consecutive outputs of one 8-weight group.
// Per 32-weight sub-block ib, qs words 0-1 hold four iq2xxs_grid byte indices, and
// words 2-3 (aux32) hold four 7-bit sign indices (bits 0..27) plus a 4-bit scale (bits 28..31).
template<typename dst_t>
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int i = blockIdx.x;
    const block_iq2_xxs * x = (const block_iq2_xxs *) vx;

    const int tid = threadIdx.x;
#if QK_K == 256
    const int il = tid/8; // 0...3
    const int ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint16_t * q2 = x[i].qs + 4*ib;
    const uint8_t  * aux8 = (const uint8_t *)q2;
    const uint8_t  * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
    // Widen before shifting: uint16_t promotes to signed int, and q2[3] << 16 would
    // overflow int (undefined behaviour) whenever q2[3] >= 0x8000.
    const uint32_t aux32 = (uint32_t)q2[2] | ((uint32_t)q2[3] << 16);
    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
#else
    assert(false);
#endif

}
1592
+
1593
// Dequantizes IQ2_XS data into yy (QK_K == 256 only; other QK_K values assert).
// blockIdx.x selects the super-block. The indexing assumes 32 threads per block; each
// thread expands one 16-bit qs word into 8 consecutive outputs. The low 9 bits pick an
// iq2xs_grid entry and the high 7 bits pick a ksigns_iq2xs pattern.
template<typename dst_t>
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
#if QK_K == 256
    const int i   = blockIdx.x;
    const int tid = threadIdx.x;
    const int ib  = tid % 8; // 32-weight sub-block: 0...7
    const int il  = tid / 8; // 8-weight group inside the sub-block: 0...3

    const block_iq2_xs & blk = ((const block_iq2_xs *) vx)[i];
    const uint16_t packed = blk.qs[4*ib + il];

    // groups 0,1 take the low nibble of the sub-block scale byte, groups 2,3 the high nibble
    const int   nibble = (blk.scales[ib] >> (4*(il/2))) & 0xf;
    const float d      = (float)blk.d * (0.5f + nibble) * 0.25f;

    const uint8_t * grid  = (const uint8_t *)(iq2xs_grid + (packed & 511));
    const uint8_t   signs = ksigns_iq2xs[packed >> 9];

    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    for (int j = 0; j < 8; ++j) {
        const float sign = (signs & kmask_iq2xs[j]) ? -1.f : 1.f;
        y[j] = d * grid[j] * sign;
    }
#else
    assert(false);
#endif
}
1614
+
1295
1615
  static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
1296
1616
 
1297
1617
  static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
@@ -1872,14 +2192,6 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
1872
2192
  v.y = x[ib + iqs + 1];
1873
2193
  }
1874
2194
 
1875
// Loads two consecutive floats, starting at element ib + iqs of vx, into v.x / v.y.
// When dfloat is half, the float values are narrowed implicitly on assignment.
static __device__ void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
    const float * src = (const float *) vx;
    const int base = ib + iqs;

    v.x = src[base];
    v.y = src[base + 1];
}
1882
-
1883
2195
  static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
1884
2196
  const int ix = blockDim.x*blockIdx.x + threadIdx.x;
1885
2197
 
@@ -1983,7 +2295,7 @@ static __global__ void k_get_rows_float(
1983
2295
 
1984
2296
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
1985
2297
  static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
1986
- const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
2298
+ const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
1987
2299
 
1988
2300
  if (i >= k) {
1989
2301
  return;
@@ -2002,6 +2314,19 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
2002
2314
  y[iybs + iqs + y_offset] = v.y;
2003
2315
  }
2004
2316
 
2317
// Element-wise type conversion: y[i] = x[i] for i in [0, k), where x is vx viewed as src_t.
// One element per thread. A 1D launch must provide at least k threads; surplus threads exit.
template <typename src_t, typename dst_t>
static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    // Keep the source const-qualified; casting vx to a non-const src_t * would silently drop const.
    const src_t * x = (const src_t *) vx;

    y[i] = x[i];
}
2329
+
2005
2330
  // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
2006
2331
  // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
2007
2332
 
@@ -3820,6 +4145,91 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
3820
4145
  return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
3821
4146
  }
3822
4147
 
4148
// Dot product of part of an IQ2_XXS super-block with q8_1 activations (QK_K == 256 only;
// otherwise asserts and returns 0).
// With QR2_XXS == 8, iqs selects one 32-weight sub-block (ib32), which is dotted with
// bq8_1[ib32]. NOTE(review): this assumes bq8_1 points at the first q8_1 block of this
// super-block; confirm at the call site.
// Sub-block layout (4 uint16 words at qs + 4*ib32): bytes 0..3 are iq2xxs_grid indices,
// one per 8 weights. Words 2-3 (aux32) hold four 7-bit ksigns_iq2xs indices (bits 0..27)
// and a 4-bit scale (bits 28..31).
static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if QK_K == 256
    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;

#if QR2_XXS == 8
    const int ib32 = iqs;
    const uint16_t * q2 = bq2->qs + 4*ib32;
    const uint8_t * aux8 = (const uint8_t *)q2;
    const int8_t * q8 = bq8_1[ib32].qs;
    // Widen before shifting: uint16_t promotes to signed int, and q2[3] << 16 would
    // overflow int (undefined behaviour) whenever q2[3] >= 0x8000.
    uint32_t aux32 = (uint32_t)q2[2] | ((uint32_t)q2[3] << 16);
    int sumi = 0;
    for (int l = 0; l < 4; ++l) {
        const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
        const uint8_t signs = ksigns_iq2xs[aux32 & 127];
        for (int j = 0; j < 8; ++j) {
            sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
        }
        q8 += 8;
        aux32 >>= 7;
    }
    // After four 7-bit shifts only the 4-bit sub-block scale (original bits 28..31) remains.
    const float d = (float)bq2->d * (0.5f + aux32) * (float)bq8_1[ib32].ds.x * 0.25f;
    return d * sumi;
#else
    // iqs is 0...15
    const int ib32 = iqs/2;
    const int il = iqs%2;
    const uint16_t * q2 = bq2->qs + 4*ib32;
    const uint8_t * aux8 = (const uint8_t *)q2;
    const uint8_t * grid1 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
    const uint8_t * grid2 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
    // same widening as above
    const uint32_t aux32 = (uint32_t)q2[2] | ((uint32_t)q2[3] << 16);
    const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;
    const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
    const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
    const int8_t * q8 = bq8_1[ib32].qs + 16*il;
    int sumi1 = 0, sumi2 = 0;
    for (int j = 0; j < 8; ++j) {
        sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
        sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
    }
    return d * (sumi1 + sumi2);
#endif
#else
    assert(false);
    return 0.f;
#endif
}
4196
+
4197
// Dot product of one 32-weight IQ2_XS sub-block (ib32 = iqs) with bq8_1[ib32]
// (QK_K == 256 only; otherwise asserts and returns 0).
// The sub-block's four qs words each cover 8 weights: a 9-bit iq2xs_grid index plus a
// 7-bit ksigns_iq2xs index. Words 0,1 are scaled by the low nibble of scales[ib32];
// words 2,3 by the high nibble.
static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if QK_K == 256
    const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;

    const int        ib32 = iqs;
    const uint16_t * q2   = bq2->qs + 4*ib32;
    const int8_t   * q8   = bq8_1[ib32].qs;

    int sumi_lo = 0; // groups 0,1
    int sumi_hi = 0; // groups 2,3
    for (int l = 0; l < 4; ++l) {
        const uint16_t  word  = q2[l];
        const uint8_t * grid  = (const uint8_t *)(iq2xs_grid + (word & 511));
        const uint8_t   signs = ksigns_iq2xs[word >> 9];
        const int8_t  * act   = q8 + 8*l;

        int partial = 0;
        for (int j = 0; j < 8; ++j) {
            partial += act[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
        }
        if (l < 2) {
            sumi_lo += partial;
        } else {
            sumi_hi += partial;
        }
    }

    const uint8_t scale_byte = bq2->scales[ib32];
    const uint8_t ls1 = scale_byte & 0xf;
    const uint8_t ls2 = scale_byte >> 4;
    const float d = (float)bq2->d * (float)bq8_1[ib32].ds.x * 0.25f;
    return d * ((0.5f + ls1) * sumi_lo + (0.5f + ls2) * sumi_hi);
#else
    assert(false);
    return 0.f;
#endif
}
4232
+
3823
4233
  template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
3824
4234
  allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
3825
4235
  static __device__ __forceinline__ void mul_mat_q(
@@ -5201,75 +5611,233 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
5201
5611
  dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
5202
5612
  }
5203
5613
 
5204
- static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
5614
// Row-wise softmax of (x*scale + y), using half2 arithmetic for the exponentials and reductions.
//
// Launch layout: one block per row of x (blockIdx.x = row). blockDim.x is assumed to be a multiple of
// WARP_SIZE and at most WARP_SIZE*WARP_SIZE, because buf_iw holds one entry per warp.
// Each thread handles column pairs (col_data, col_data + WARP_SIZE) packed into a half2, where
// col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id.
//
// Template parameters:
//   vals_smem           - true: cache intermediate values in dynamic shared memory placed after the
//                         WARP_SIZE-half inter-warp buffer. This needs
//                         (GGML_PAD(ncols, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half) bytes.
//                         false: use the dst row as scratch (only buf_iw lives in shared memory).
//   ncols_template      - compile-time row length, 0 = use ncols_par.
//   block_size_template - compile-time block size, 0 = use blockDim.x.
//   need_check          - guard column indices against the row length. Required unless every
//                         thread's columns are in range (e.g. ncols a multiple of 2*block_size).
//
// y is an optional mask (nullptr allowed) broadcast over rows: row r of x uses row r % nrows_y of y.
// The real body is only compiled for non-HIP targets with __CUDA_ARCH__ >= CC_PASCAL; elsewhere the
// kernel calls bad_arch().
template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
    const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
    const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;

    const int tid  = threadIdx.x;
    const int rowx = blockIdx.x;
    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension

    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;

    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    extern __shared__ half data_soft_max_f16[];
    half * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
    // (shared memory) buffer to cache values between iterations:
    half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data);
    // if the buffer is larger than max. shared memory per block, use dst as temp. buffer instead
    // in that case col_smem == col_data must be enforced to avoid race conditions

    half2 max_val = make_half2(-INFINITY, -INFINITY);

#pragma unroll
    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
        const int col_smem = vals_smem ? col0 + tid : col_data;

        const int ix = rowx*ncols_data + col_data;
        const int iy = rowy*ncols_data + col_data;

        half2 val;
        if (need_check && col_data + 0 >= ncols_data) {
            val.x = -INFINITY;
        } else {
            val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
        }
        if (need_check && col_data + WARP_SIZE >= ncols_data) {
            val.y = -INFINITY;
        } else {
            val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
        }
        if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
            vals[col_smem] = val;
        }
        max_val = __hmax2(max_val, val);
    }

    // find the max value in the block
    max_val = warp_reduce_max(max_val);
    if (block_size > WARP_SIZE) {
        if (warp_id == 0) {
            buf_iw[lane_id] = -INFINITY;
        }
        __syncthreads();

        if (lane_id == 0) {
            buf_iw[warp_id] = __hmax(max_val.x, max_val.y);
        }
        __syncthreads();

        max_val = __half2half2(buf_iw[lane_id]);
        max_val = warp_reduce_max(max_val);
    } else {
        max_val = __half2half2(__hmax(max_val.x, max_val.y));
    }

    half2 tmp = make_half2(0.0f, 0.0f); // partial sums

#pragma unroll
    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
        const int col_smem = vals_smem ? col0 + tid : 2*col0 + 2*warp_id*WARP_SIZE + lane_id;

        if (ncols_template == 0 && col_smem >= (vals_smem ? ncols_smem : ncols_data)) {
            break;
        }

        const half2 val = h2exp(vals[col_smem] - max_val);

        tmp += val;
        vals[col_smem] = val;
    }

    // find the sum of exps in the block
    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        // Barrier: every warp must have read its row max from buf_iw (above) before warp 0
        // overwrites the buffer. Otherwise a slower warp can read 0 instead of the max, and the
        // exponentials of different warps end up computed against different maxima.
        __syncthreads();
        if (warp_id == 0) {
            buf_iw[lane_id] = 0.0f;
        }
        __syncthreads();

        if (lane_id == 0) {
            buf_iw[warp_id] = tmp.x + tmp.y;
        }
        __syncthreads();

        tmp = __half2half2(buf_iw[lane_id]);
        tmp = warp_reduce_sum(tmp);
    } else {
        tmp = __half2half2(tmp.x + tmp.y);
    }

    const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp;

#pragma unroll
    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
        const int col_smem = vals_smem ? col0 + tid : col_data;

        // Bounds check BEFORE reading vals. For threads past the end of the row, col_smem can lie
        // beyond the shared-memory cache (or beyond the dst row / buffer when vals_smem == false).
        if (need_check && col_data + 0 >= ncols_data) {
            return;
        }

        const int idst = rowx*ncols_data + col_data;
        const half2 result = vals[col_smem] * inv_sum;

        dst[idst] = result.x;

        if (need_check && col_data + WARP_SIZE >= ncols_data) {
            return;
        }

        dst[idst + WARP_SIZE] = result.y;
    }
#else
    (void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
    bad_arch();
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
5743
+
5744
// Row-wise softmax of (x*scale + y) in fp32.
//
// Launch layout: one block per row of x (blockIdx.x = row). blockDim.x is assumed to be a multiple of
// WARP_SIZE and at most WARP_SIZE*WARP_SIZE, because buf_iw holds one entry per warp.
//
// Template parameters:
//   vals_smem           - true: cache x*scale + y in dynamic shared memory after the WARP_SIZE-float
//                         inter-warp buffer (needs at least ncols + WARP_SIZE floats).
//                         false: use the dst row itself as the cache.
//   ncols_template      - compile-time row length, 0 = use ncols_par. When non-zero the column loops
//                         have no bounds check, so it must be a multiple of the block size.
//   block_size_template - compile-time block size, 0 = use blockDim.x.
//
// y is an optional mask (nullptr allowed) broadcast over rows: row r of x uses row r % nrows_y of y.
template <bool vals_smem, int ncols_template, int block_size_template>
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;

    const int tid  = threadIdx.x;
    const int rowx = blockIdx.x;
    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension

    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;

    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    extern __shared__ float data_soft_max_f32[];
    float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
    // shared memory buffer to cache values between iterations:
    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;

    float max_val = -INFINITY;

#pragma unroll
    for (int col0 = 0; col0 < ncols; col0 += block_size) {
        const int col = col0 + tid;

        if (ncols_template == 0 && col >= ncols) {
            break;
        }

        const int ix = rowx*ncols + col;
        const int iy = rowy*ncols + col;

        const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
        vals[col] = val;
        max_val = max(max_val, val);
    }

    // find the max value in the block
    max_val = warp_reduce_max(max_val);
    if (block_size > WARP_SIZE) {
        if (warp_id == 0) {
            buf_iw[lane_id] = -INFINITY;
        }
        __syncthreads();

        if (lane_id == 0) {
            buf_iw[warp_id] = max_val;
        }
        __syncthreads();

        max_val = buf_iw[lane_id];
        max_val = warp_reduce_max(max_val);
    }

    float tmp = 0.0f; // partial sum

#pragma unroll
    for (int col0 = 0; col0 < ncols; col0 += block_size) {
        const int col = col0 + tid;

        if (ncols_template == 0 && col >= ncols) {
            break;
        }

        const float val = expf(vals[col] - max_val);
        tmp += val;
        vals[col] = val;
    }

    // find the sum of exps in the block
    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        // Barrier: every warp must have read its row max from buf_iw (above) before warp 0
        // overwrites the buffer. Otherwise a slower warp can read 0.0f instead of the max, and the
        // exponentials of different warps end up computed against different maxima.
        __syncthreads();
        if (warp_id == 0) {
            buf_iw[lane_id] = 0.0f;
        }
        __syncthreads();

        if (lane_id == 0) {
            buf_iw[warp_id] = tmp;
        }
        __syncthreads();

        tmp = buf_iw[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    const float inv_sum = 1.0f / tmp;

#pragma unroll
    for (int col0 = 0; col0 < ncols; col0 += block_size) {
        const int col = col0 + tid;

        if (ncols_template == 0 && col >= ncols) {
            return;
        }

        const int idst = rowx*ncols + col;
        dst[idst] = vals[col] * inv_sum;
    }
}
5275
5843
 
@@ -5609,7 +6177,7 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
5609
6177
 
5610
6178
// Launches dequantize_block over k values on `stream`.
// Blocks have CUDA_DEQUANTIZE_BLOCK_SIZE threads. The grid is sized so that each block covers
// 2*CUDA_DEQUANTIZE_BLOCK_SIZE values (presumably two values per thread -- see dequantize_block).
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
    const int values_per_block = 2*CUDA_DEQUANTIZE_BLOCK_SIZE;
    const int num_blocks       = (k + values_per_block - 1) / values_per_block; // ceil(k / values_per_block)
    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
5615
6183
 
@@ -5659,6 +6227,24 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
5659
6227
  #endif
5660
6228
  }
5661
6229
 
6230
// Dequantizes iq2_xxs data into y: one 32-thread block per QK_K-value super-block.
// Any remainder of k beyond a multiple of QK_K is not processed.
template<typename dst_t>
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int n_superblocks     = k / QK_K;
    const int threads_per_block = 32;
    dequantize_block_iq2_xxs<<<n_superblocks, threads_per_block, 0, stream>>>(vx, y);
}
6235
+
6236
// Dequantizes iq2_xs data into y: one 32-thread block per QK_K-value super-block.
// Any remainder of k beyond a multiple of QK_K is not processed.
template<typename dst_t>
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int n_superblocks     = k / QK_K;
    const int threads_per_block = 32;
    dequantize_block_iq2_xs<<<n_superblocks, threads_per_block, 0, stream>>>(vx, y);
}
6241
+
6242
// Converts k values from src_t (in vx) to dst_t (in y) on `stream` via convert_unary.
// The grid provides at least one thread per element, in blocks of CUDA_DEQUANTIZE_BLOCK_SIZE.
template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
    const int threads    = CUDA_DEQUANTIZE_BLOCK_SIZE;
    const int num_blocks = (k + threads - 1) / threads; // ceil(k / threads)
    convert_unary<src_t><<<num_blocks, threads, 0, stream>>>(vx, y, k);
}
6247
+
5662
6248
  static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
5663
6249
  switch (type) {
5664
6250
  case GGML_TYPE_Q4_0:
@@ -5681,8 +6267,12 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
5681
6267
  return dequantize_row_q5_K_cuda;
5682
6268
  case GGML_TYPE_Q6_K:
5683
6269
  return dequantize_row_q6_K_cuda;
6270
+ case GGML_TYPE_IQ2_XXS:
6271
+ return dequantize_row_iq2_xxs_cuda;
6272
+ case GGML_TYPE_IQ2_XS:
6273
+ return dequantize_row_iq2_xs_cuda;
5684
6274
  case GGML_TYPE_F32:
5685
- return dequantize_block_cuda<1, 1, convert_f32>;
6275
+ return convert_unary_cuda<float>;
5686
6276
  default:
5687
6277
  return nullptr;
5688
6278
  }
@@ -5710,8 +6300,12 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
5710
6300
  return dequantize_row_q5_K_cuda;
5711
6301
  case GGML_TYPE_Q6_K:
5712
6302
  return dequantize_row_q6_K_cuda;
6303
+ case GGML_TYPE_IQ2_XXS:
6304
+ return dequantize_row_iq2_xxs_cuda;
6305
+ case GGML_TYPE_IQ2_XS:
6306
+ return dequantize_row_iq2_xs_cuda;
5713
6307
  case GGML_TYPE_F16:
5714
- return dequantize_block_cuda<1, 1, convert_f16>;
6308
+ return convert_unary_cuda<half>;
5715
6309
  default:
5716
6310
  return nullptr;
5717
6311
  }
@@ -5904,6 +6498,24 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
5904
6498
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
5905
6499
  }
5906
6500
 
6501
// Matrix-vector product of iq2_xxs-quantized rows (nrows x ncols) with a q8_1-quantized vector.
// ncols must be a multiple of QK_K (asserted). The grid is ceil(nrows / GGML_CUDA_MMV_Y) blocks
// of WARP_SIZE x GGML_CUDA_MMV_Y threads.
static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int rows_per_block = GGML_CUDA_MMV_Y;
    const dim3 block_nums((nrows + rows_per_block - 1) / rows_per_block, 1, 1);
    const dim3 block_dims(WARP_SIZE, rows_per_block, 1);
    mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
6509
+
6510
// Matrix-vector product of iq2_xs-quantized rows (nrows x ncols) with a q8_1-quantized vector.
// ncols must be a multiple of QK_K (asserted). The grid is ceil(nrows / GGML_CUDA_MMV_Y) blocks
// of WARP_SIZE x GGML_CUDA_MMV_Y threads.
static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int rows_per_block = GGML_CUDA_MMV_Y;
    const dim3 block_nums((nrows + rows_per_block - 1) / rows_per_block, 1, 1);
    const dim3 block_dims(WARP_SIZE, rows_per_block, 1);
    mul_mat_vec_q<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
6518
+
5907
6519
  static void ggml_mul_mat_q4_0_q8_1_cuda(
5908
6520
  const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
5909
6521
  const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@@ -6543,12 +7155,90 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
6543
7155
  diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
6544
7156
  }
6545
7157
 
7158
// Host launcher for the half2 softmax kernel soft_max_f16.
// One block per row of x (nrows_x blocks). The block size is the smallest power-of-two multiple of
// WARP_SIZE that is >= ncols_x/2 (column pairs), capped at CUDA_SOFT_MAX_BLOCK_SIZE.
// If the value cache fits in the main device's per-block shared memory (smpb), values are cached in
// shared memory, with compile-time specializations for common row lengths. Otherwise dst is used as
// scratch and only WARP_SIZE halves of shared memory are requested.
static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
    int nth = WARP_SIZE;
    while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) {
        nth *= 2;
    }
    const dim3 block_dims(nth, 1, 1);
    const dim3 block_nums(nrows_x, 1, 1);
    const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");

    if (shmem > g_device_caps[g_main_device].smpb) {
        // value cache does not fit: keep only the inter-warp buffer in shared memory
        const size_t shmem_low = WARP_SIZE*sizeof(half);
        soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
        return;
    }

    // template args: <vals_smem, ncols, block_size, need_check>
    switch (ncols_x) {
        case 32:
            soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 64:
            soft_max_f16<true, 64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 128:
            soft_max_f16<true, 128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 256:
            soft_max_f16<true, 256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 512:
            soft_max_f16<true, 512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 1024:
            soft_max_f16<true, 1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 2048:
            soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 4096:
            soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        default:
            soft_max_f16<true, 0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
    }
}
7200
+
6546
7201
// Host launcher for the fp32 softmax kernel soft_max_f32.
// One block per row of x (nrows_x blocks). The block size is the smallest power-of-two multiple of
// WARP_SIZE that is >= ncols_x, capped at CUDA_SOFT_MAX_BLOCK_SIZE.
// If the value cache plus the inter-warp buffer fits in the main device's per-block shared memory
// (smpb), values are cached in shared memory, with compile-time specializations for common row
// lengths. Otherwise dst is used as scratch and only WARP_SIZE floats of shared memory are requested.
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
    int nth = WARP_SIZE;
    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
    const dim3 block_dims(nth, 1, 1);
    const dim3 block_nums(nrows_x, 1, 1);
    const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
    // "<=": a dynamic shared-memory request exactly equal to the per-block limit still fits
    // (the kernel uses no static shared memory). This is the same test soft_max_f16_cuda uses.
    if (shmem <= g_device_caps[g_main_device].smpb) {
        // template args: <vals_smem, ncols, block_size>
        switch (ncols_x) {
            case 32:
                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
                break;
            case 64:
                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
                break;
            case 128:
                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
                break;
            case 256:
                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
                break;
            case 512:
                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
                break;
            case 1024:
                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
                break;
            case 2048:
                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
                break;
            case 4096:
                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
                break;
            default:
                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
                break;
        }
    } else {
        // value cache does not fit: keep only the inter-warp buffer in shared memory
        const size_t shmem_low = WARP_SIZE*sizeof(float);
        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
    }
}
6553
7243
 
6554
7244
  static void im2col_f32_f16_cuda(const float* x, half* dst,
@@ -6863,6 +7553,7 @@ void ggml_init_cublas() {
6863
7553
  #else
6864
7554
  g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
6865
7555
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
7556
+ g_device_caps[id].smpb = prop.sharedMemPerBlock;
6866
7557
  }
6867
7558
  for (int id = 0; id < g_device_count; ++id) {
6868
7559
  g_tensor_split[id] /= total_vram;
@@ -7396,6 +8087,8 @@ static int64_t get_row_rounding(ggml_type type) {
7396
8087
  case GGML_TYPE_Q4_K:
7397
8088
  case GGML_TYPE_Q5_K:
7398
8089
  case GGML_TYPE_Q6_K:
8090
+ case GGML_TYPE_IQ2_XXS:
8091
+ case GGML_TYPE_IQ2_XS:
7399
8092
  return max_compute_capability >= CC_RDNA2 ? 128 : 64;
7400
8093
  default:
7401
8094
  GGML_ASSERT(false);
@@ -7416,6 +8109,8 @@ static int64_t get_row_rounding(ggml_type type) {
7416
8109
  case GGML_TYPE_Q3_K:
7417
8110
  case GGML_TYPE_Q4_K:
7418
8111
  case GGML_TYPE_Q5_K:
8112
+ case GGML_TYPE_IQ2_XXS:
8113
+ case GGML_TYPE_IQ2_XS:
7419
8114
  return max_compute_capability >= CC_VOLTA ? 128 : 64;
7420
8115
  case GGML_TYPE_Q6_K:
7421
8116
  return 64;
@@ -7466,6 +8161,12 @@ static void ggml_cuda_op_mul_mat_vec_q(
7466
8161
  case GGML_TYPE_Q6_K:
7467
8162
  mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
7468
8163
  break;
8164
+ case GGML_TYPE_IQ2_XXS:
8165
+ mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
8166
+ break;
8167
+ case GGML_TYPE_IQ2_XS:
8168
+ mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
8169
+ break;
7469
8170
  default:
7470
8171
  GGML_ASSERT(false);
7471
8172
  break;
@@ -7873,7 +8574,21 @@ static void ggml_cuda_op_soft_max(
7873
8574
  float scale = 1.0f;
7874
8575
  memcpy(&scale, dst->op_params, sizeof(float));
7875
8576
 
7876
- soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
8577
+ #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
8578
+ const bool use_f16_soft_max = false;
8579
+ #else
8580
+ #ifdef GGML_CUDA_F16
8581
+ const bool use_f16_soft_max = true;
8582
+ #else
8583
+ const bool use_f16_soft_max = false;
8584
+ #endif // GGML_CUDA_F16
8585
+ #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
8586
+
8587
+ if (use_f16_soft_max) {
8588
+ soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
8589
+ } else {
8590
+ soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
8591
+ }
7877
8592
 
7878
8593
  (void) dst;
7879
8594
  }
@@ -8682,6 +9397,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
8682
9397
 
8683
9398
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
8684
9399
 
9400
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type);
9401
+
8685
9402
  // debug helpers
8686
9403
  //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
8687
9404
  //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
@@ -9689,8 +10406,8 @@ static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, gg
9689
10406
 
9690
10407
  ggml_cuda_set_device(ctx->device);
9691
10408
  CUDA_CHECK(cudaDeviceSynchronize());
9692
-
9693
10409
  CUDA_CHECK(cudaMemcpy((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice));
10410
+ CUDA_CHECK(cudaDeviceSynchronize());
9694
10411
  }
9695
10412
 
9696
10413
  static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
@@ -9910,7 +10627,7 @@ static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_ba
9910
10627
  UNUSED(plan);
9911
10628
  }
9912
10629
 
9913
- static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
10630
+ static bool ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
9914
10631
  ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
9915
10632
 
9916
10633
  ggml_cuda_set_main_device(cuda_ctx->device);
@@ -9967,6 +10684,8 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
9967
10684
  }
9968
10685
 
9969
10686
  UNUSED(backend);
10687
+
10688
+ return true;
9970
10689
  }
9971
10690
 
9972
10691
  static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {