llama_cpp 0.12.0 → 0.12.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -116,6 +116,7 @@
116
116
  #include "ggml.h"
117
117
  #include "ggml-backend-impl.h"
118
118
 
119
#define CC_PASCAL 600 // compute capability 6.0 (Pascal); the half2 warp reductions and soft_max_f16 fall back to bad_arch() below this
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define CC_VOLTA 700 // compute capability 7.0 (Volta)
#define CC_OFFSET_AMD 1000000 // NOTE(review): presumably added to AMD device ids so they never collide with NVIDIA compute capabilities — confirm at the use sites
@@ -183,7 +184,7 @@ static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
183
184
  static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
184
185
  #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
185
186
  c = __builtin_amdgcn_sdot4(a, b, c, false);
186
- #elif defined(__gfx1100__)
187
+ #elif defined(RDNA3)
187
188
  c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
188
189
  #elif defined(__gfx1010__) || defined(__gfx900__)
189
190
  int tmp1;
@@ -477,6 +478,23 @@ typedef struct {
477
478
  } block_q6_K;
478
479
// Compile-time guard: block_q6_K must be exactly one fp16 scale plus 13*QK_K/16 bytes for its
// remaining fields, with no compiler-inserted padding.
static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
479
480
 
481
// IQ2_XXS super-block. The layout is a binary format (the static_assert pins its size), so
// member order and types must not change.
// QR2_XXS == 8 selects the one-group-per-iqs path in vec_dot_iq2_xxs_q8_1.
// QI2_XXS == 8 when QK_K == 256.
#define QR2_XXS 8
#define QI2_XXS (QK_K / (4*QR2_XXS))
typedef struct {
    half d;               // super-block scale
    // 4 entries per 32-weight group:
    //   bytes of qs[0..1] = four 8-bit iq2xxs_grid indices (one per 8 weights)
    //   qs[2] | qs[3] << 16 = four 7-bit ksigns_iq2xs indices (bits 0..27)
    //                         plus a 4-bit group scale (bits 28..31)
    uint16_t qs[QK_K/8];
} block_iq2_xxs;
static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
488
+
489
// IQ2_XS super-block. The layout is a binary format (the static_assert pins its size), so
// member order and types must not change.
#define QR2_XS 8
#define QI2_XS (QK_K / (4*QR2_XS))
typedef struct {
    half d;                  // super-block scale
    uint16_t qs[QK_K/8];     // per 8 weights: low 9 bits = iq2xs_grid index, high 7 bits = ksigns_iq2xs index
    uint8_t scales[QK_K/32]; // per 32 weights: low nibble scales the first 16, high nibble the last 16
} block_iq2_xs;
static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
497
+
480
498
#define WARP_SIZE 32 // threads per warp
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
482
500
 
@@ -548,11 +566,12 @@ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
548
566
 
549
567
// Cached per-device properties; g_device_caps holds one of these per device slot.
// Aggregate type: brace initializers fill members in declaration order, so keep the order stable.
struct cuda_device_capabilities {
    int     cc;              // compute capability
    size_t  smpb;            // max. shared memory per block
    bool    vmm;             // virtual memory support
    size_t  vmm_granularity; // granularity of virtual memory
};
554
573
 
555
// One capability record per device slot (presumably indexed by CUDA device id — confirm).
// Every field starts out zero/false until it is populated elsewhere; value-initialization
// keeps this independent of the struct's member order.
static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = {};
556
575
 
557
576
// Scratch-buffer state: no buffer allocated, and a size of zero means scratch use is disabled.
static void * g_scratch_buffer{nullptr};
static size_t g_scratch_size{0}; // disabled by default
@@ -585,6 +604,19 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
585
604
  return a;
586
605
  }
587
606
 
607
// Sums a half2 across all 32 lanes of a warp with an XOR-shuffle butterfly. Each half is
// reduced independently, and every lane ends up holding the full total.
// Requires all 32 lanes to participate (mask 0xffffffff) and __CUDA_ARCH__ >= CC_PASCAL.
// On older architectures and on HIP/AMD builds the body is replaced by bad_arch().
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
    bad_arch();
    // Not reached if bad_arch() traps. Without it, this non-void function would fall off its
    // end on this path.
    return a;
#else
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
    }
    return a;
#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
}
619
+
588
620
  static __device__ __forceinline__ float warp_reduce_max(float x) {
589
621
  #pragma unroll
590
622
  for (int mask = 16; mask > 0; mask >>= 1) {
@@ -593,6 +625,19 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
593
625
  return x;
594
626
  }
595
627
 
628
// Element-wise maximum of a half2 across all 32 lanes of a warp, via an XOR-shuffle
// butterfly. Every lane ends up holding the maximum of each half.
// Requires all 32 lanes to participate (mask 0xffffffff) and __CUDA_ARCH__ >= CC_PASCAL.
// On older architectures and on HIP/AMD builds the body is replaced by bad_arch().
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
    bad_arch();
    // Not reached if bad_arch() traps. Without it, this non-void function would fall off its
    // end on this path.
    return x;
#else
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
}
640
+
596
641
  static __device__ __forceinline__ float op_repeat(const float a, const float b) {
597
642
  return b;
598
643
  GGML_UNUSED(a);
@@ -1292,6 +1337,281 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
1292
1337
  #endif
1293
1338
  }
1294
1339
 
1340
// IQ2_XXS codebook: 256 entries, each packing 8 unsigned byte magnitudes. Every byte is one of
// 0x08, 0x19, 0x2b (8, 25, 43).
// Consumers reinterpret an entry as `const uint8_t *` and read bytes 0..7. They index the
// table with an 8-bit value taken from block_iq2_xxs::qs.
static const __device__ uint64_t iq2xxs_grid[256] = {
    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
};
1406
+
1407
// IQ2_XS codebook: 512 entries, each packing 8 unsigned byte magnitudes. Every byte is one of
// 0x08, 0x19, 0x2b (8, 25, 43).
// Consumers index it with the low 9 bits of a block_iq2_xs::qs entry (q & 511) and read the
// entry as `const uint8_t *` bytes 0..7.
static const __device__ uint64_t iq2xs_grid[512] = {
    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
    0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
    0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
    0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
    0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
    0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
    0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
    0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
    0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
    0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
    0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
    0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
    0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
    0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
    0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
    0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
    0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
    0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
    0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
    0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
    0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
    0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
    0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
    0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
    0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
    0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
    0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
    0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
    0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
    0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
    0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
    0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
    0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
    0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
    0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
    0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
    0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
    0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
    0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
    0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
    0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
    0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
    0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
    0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
    0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
    0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
    0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
    0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
    0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
    0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
    0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
    0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
    0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
    0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
    0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
    0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
    0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
    0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
    0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
    0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
    0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
    0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
    0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
    0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
    0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
    0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
    0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
    0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
    0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
    0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
    0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
    0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
    0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
    0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
    0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
    0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
    0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
    0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
    0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
    0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
    0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
    0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
    0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
    0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
    0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
    0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
    0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
    0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
    0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
    0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
    0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
    0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
    0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
    0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
    0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
    0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
    0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
    0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
    0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
    0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
    0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
    0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
    0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
    0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
    0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
    0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
    0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
    0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
    0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
    0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
    0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
    0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
    0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
    0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
    0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
    0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
    0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
    0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
    0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
    0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
    0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
    0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
    0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
    0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
    0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
    0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
    0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
};
1537
+
1538
// Sign patterns for the IQ2 formats, indexed by a 7-bit code. Entry i equals i with bit 7 set
// exactly when i has an odd number of set bits, so every resulting 8-bit pattern has even
// parity.
// Bit j set means element j of the 8-value group is negated (tested with kmask_iq2xs[j]).
static const __device__ uint8_t ksigns_iq2xs[128] = {
      0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
    144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
    160,  33,  34, 163,  36, 165, 166,  39,  40, 169, 170,  43, 172,  45,  46, 175,
     48, 177, 178,  51, 180,  53,  54, 183, 184,  57,  58, 187,  60, 189, 190,  63,
    192,  65,  66, 195,  68, 197, 198,  71,  72, 201, 202,  75, 204,  77,  78, 207,
     80, 209, 210,  83, 212,  85,  86, 215, 216,  89,  90, 219,  92, 221, 222,  95,
     96, 225, 226,  99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
};
1548
+
1549
// Single-bit masks, kmask_iq2xs[j] == 1 << j, used to test bit j of a ksigns_iq2xs pattern.
static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
1550
+
1551
// Returns true when `type` is one of the formats with a mul_mat_q (MMQ) kernel: the legacy
// block formats Q4_0, Q4_1, Q5_0, Q5_1, Q8_0 and the k-quants Q2_K..Q6_K. Every other type
// returns false.
inline bool ggml_cuda_supports_mmq(enum ggml_type type) {
    const bool is_legacy_quant =
        type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 ||
        type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1 ||
        type == GGML_TYPE_Q8_0;

    const bool is_k_quant =
        type == GGML_TYPE_Q2_K || type == GGML_TYPE_Q3_K ||
        type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K ||
        type == GGML_TYPE_Q6_K;

    return is_legacy_quant || is_k_quant;
}
1568
+
1569
// Dequantizes one block_iq2_xxs super-block per CUDA block (blockIdx.x) into QK_K dst_t values.
// For QK_K == 256 the kernel expects 32 threads. Each thread writes 8 consecutive outputs:
//   tid % 8 -> 32-weight group ib
//   tid / 8 -> 8-weight sub-group il within that group
// Per group, the bytes of qs[4*ib + 0..1] are four 8-bit iq2xxs_grid indices.
// qs[4*ib + 2] | qs[4*ib + 3] << 16 packs four 7-bit ksigns_iq2xs indices (bits 0..27) and a
// 4-bit scale (bits 28..31).
template<typename dst_t>
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int i = blockIdx.x;
    const block_iq2_xxs * x = (const block_iq2_xxs *) vx;

    const int tid = threadIdx.x;
#if QK_K == 256
    const int il = tid/8; // 0...3
    const int ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint16_t * q2 = x[i].qs + 4*ib;
    const uint8_t  * aux8 = (const uint8_t *)q2;
    const uint8_t  * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
    // Widen to uint32_t before shifting. Otherwise q2[3] is promoted to signed int, and
    // values >= 0x8000 shifted left by 16 do not fit in an int.
    const uint32_t aux32 = (uint32_t)q2[2] | ((uint32_t)q2[3] << 16);
    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
#else
    assert(false);
#endif

}
1592
+
1593
// Expands the block_iq2_xs super-block selected by blockIdx.x into QK_K dst_t values.
// For QK_K == 256, thread t (0...31) fills 8 consecutive outputs: sub-group t/8 of the
// 32-weight group t%8.
// A qs entry carries a 9-bit iq2xs_grid index (low bits) and a 7-bit ksigns_iq2xs index
// (high bits). scales[group] holds two 4-bit scales: the low nibble serves sub-groups 0-1 and
// the high nibble sub-groups 2-3.
template<typename dst_t>
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const block_iq2_xs * blocks = (const block_iq2_xs *) vx;
    const int sb   = blockIdx.x;  // super-block handled by this CUDA block
    const int lane = threadIdx.x;

#if QK_K == 256
    const int sub   = lane / 8;   // 8-weight sub-group inside the group: 0...3
    const int group = lane % 8;   // 32-weight group inside the super-block: 0...7

    const block_iq2_xs & blk = blocks[sb];
    const uint16_t packed = blk.qs[4*group + sub];

    const uint8_t * magnitudes = (const uint8_t *)(iq2xs_grid + (packed & 511));
    const uint8_t   nibble     = (blk.scales[group] >> 4*(sub/2)) & 0xf;
    const float     scale      = (float)blk.d * (0.5f + nibble) * 0.25f;
    const uint8_t   sign_bits  = ksigns_iq2xs[packed >> 9];

    dst_t * out = yy + sb*QK_K + 32*group + 8*sub;
    for (int k = 0; k < 8; ++k) {
        out[k] = scale * magnitudes[k] * (sign_bits & kmask_iq2xs[k] ? -1.f : 1.f);
    }
#else
    assert(false);
#endif
}
1614
+
1295
1615
  static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
1296
1616
 
1297
1617
  static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
@@ -1872,14 +2192,6 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
1872
2192
  v.y = x[ib + iqs + 1];
1873
2193
  }
1874
2194
 
1875
- static __device__ void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
1876
- const float * x = (const float *) vx;
1877
-
1878
- // automatic half -> float type cast if dfloat == float
1879
- v.x = x[ib + iqs + 0];
1880
- v.y = x[ib + iqs + 1];
1881
- }
1882
-
1883
2195
  static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
1884
2196
  const int ix = blockDim.x*blockIdx.x + threadIdx.x;
1885
2197
 
@@ -1983,7 +2295,7 @@ static __global__ void k_get_rows_float(
1983
2295
 
1984
2296
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
1985
2297
  static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
1986
- const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
2298
+ const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
1987
2299
 
1988
2300
  if (i >= k) {
1989
2301
  return;
@@ -2002,6 +2314,19 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
2002
2314
  y[iybs + iqs + y_offset] = v.y;
2003
2315
  }
2004
2316
 
2317
// Element-wise conversion kernel: y[i] = x[i] for i in [0, k), with x read as src_t from vx and
// converted implicitly to dst_t.
// One thread per element (flat index blockDim.x*blockIdx.x + threadIdx.x). The launch must
// cover at least k threads; threads with i >= k return immediately.
template <typename src_t, typename dst_t>
static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    // vx is read-only: cast to pointer-to-const instead of casting the const qualifier away.
    const src_t * x = (const src_t *) vx;

    y[i] = x[i];
}
2329
+
2005
2330
  // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
2006
2331
  // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
2007
2332
 
@@ -3820,6 +4145,91 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
3820
4145
  return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
3821
4146
  }
3822
4147
 
4148
// Dot product of part of a block_iq2_xxs super-block with the matching block_q8_1 data
// (QK_K == 256).
// With QR2_XXS == 8, iqs is the 32-weight group index (0...7). The alternative branch takes
// iqs 0...15 and handles half a group (16 weights) per call.
// Result: d_iq2 * (0.5 + s) * d_q8 * 0.25 * sum(q8 * grid * sign), where s is the group's
// 4-bit scale (bits 28..31 of the packed sign/scale word).
static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if QK_K == 256
    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;

#if QR2_XXS == 8
    const int ib32 = iqs;
    const uint16_t * q2 = bq2->qs + 4*ib32;
    const uint8_t  * aux8 = (const uint8_t *)q2;
    const int8_t   * q8 = bq8_1[ib32].qs;
    // Widen before shifting so the high half never passes through signed int.
    uint32_t aux32 = (uint32_t)q2[2] | ((uint32_t)q2[3] << 16);
    int sumi = 0;
    for (int l = 0; l < 4; ++l) {
        const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
        const uint8_t signs = ksigns_iq2xs[aux32 & 127];
        for (int j = 0; j < 8; ++j) {
            sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
        }
        q8 += 8;
        aux32 >>= 7;
    }
    // After four 7-bit shifts only the 4-bit scale (original bits 28..31) remains in aux32.
    const float d = (float)bq2->d * (0.5f + aux32) * (float)bq8_1[ib32].ds.x * 0.25f;
    return d * sumi;
#else
    // iqs is 0...15
    const int ib32 = iqs/2;
    const int il = iqs%2;
    const uint16_t * q2 = bq2->qs + 4*ib32;
    const uint8_t  * aux8 = (const uint8_t *)q2;
    const uint8_t  * grid1 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
    const uint8_t  * grid2 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
    // Widen before shifting so the high half never passes through signed int.
    const uint32_t aux32 = (uint32_t)q2[2] | ((uint32_t)q2[3] << 16);
    const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;
    const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
    const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
    const int8_t * q8 = bq8_1[ib32].qs + 16*il;
    int sumi1 = 0, sumi2 = 0;
    for (int j = 0; j < 8; ++j) {
        sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
        sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
    }
    return d * (sumi1 + sumi2);
#endif
#else
    assert(false);
    return 0.f;
#endif
}
4196
+
4197
// Dot product of the 32-weight group `iqs` (0...7) of a block_iq2_xs super-block with the
// matching block_q8_1 (QK_K == 256).
// Weights 0..15 are scaled by (0.5 + low nibble of scales[iqs]); weights 16..31 by
// (0.5 + high nibble). The total is multiplied by d_iq2 * d_q8 * 0.25.
static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if QK_K == 256
    const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;

    const int group = iqs;
    const uint16_t * packed = bq2->qs + 4*group;
    const int8_t   * q8 = bq8_1[group].qs;
    const uint8_t scale_byte = bq2->scales[group];
    const uint8_t ls1 = scale_byte & 0xf;  // scale for weights 0..15
    const uint8_t ls2 = scale_byte >> 4;   // scale for weights 16..31

    int sum_lo = 0; // sub-groups 0 and 1
    int sum_hi = 0; // sub-groups 2 and 3
#pragma unroll
    for (int sub = 0; sub < 4; ++sub) {
        const uint8_t * grid  = (const uint8_t *)(iq2xs_grid + (packed[sub] & 511));
        const uint8_t   signs = ksigns_iq2xs[packed[sub] >> 9];
        int & acc = sub < 2 ? sum_lo : sum_hi;
        for (int j = 0; j < 8; ++j) {
            acc += q8[8*sub + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
        }
    }

    const float d = (float)bq2->d * (float)bq8_1[group].ds.x * 0.25f;
    return d * ((0.5f + ls1) * sum_lo + (0.5f + ls2) * sum_hi);
#else
    assert(false);
    return 0.f;
#endif
}
4232
+
3823
4233
  template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
3824
4234
  allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
3825
4235
  static __device__ __forceinline__ void mul_mat_q(
@@ -5201,75 +5611,233 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
5201
5611
  dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
5202
5612
  }
5203
5613
 
5204
- static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
5614
// Softmax of one row per CUDA block, computed with half2 arithmetic:
//   dst[row, :] = softmax(x[row, :]*scale + y[row % nrows_y, :])   (y may be nullptr -> no mask)
// Template parameters:
//   vals_smem           - cache intermediate values in dynamic shared memory (true) or in dst (false)
//   ncols_template      - compile-time row length, or 0 to use ncols_par at runtime
//   block_size_template - compile-time block size, or 0 to use blockDim.x
//   need_check          - bounds-check column indices against the row length
// Each thread packs columns (c, c + WARP_SIZE) into one half2.
// Dynamic shared memory layout: WARP_SIZE halves of inter-warp scratch (buf_iw), followed, when
// vals_smem, by GGML_PAD(ncols, 2*WARP_SIZE)/2 half2 values caching the row.
// Only compiled for non-HIP builds with __CUDA_ARCH__ >= CC_PASCAL; otherwise calls bad_arch().
template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
    const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
    const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;

    const int tid  = threadIdx.x;
    const int rowx = blockIdx.x;
    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension

    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;

    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    extern __shared__ half data_soft_max_f16[];
    half * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
    // (shared memory) buffer to cache values between iterations:
    half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data);
    // if the buffer is larger than max. shared memory per block, use dst as temp. buffer instead
    // in that case col_smem == col_data must be enforced to avoid race conditions

    half2 max_val = make_half2(-INFINITY, -INFINITY);

#pragma unroll
    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
        // .x holds column col_data, .y holds column col_data + WARP_SIZE
        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
        const int col_smem = vals_smem ? col0 + tid : col_data;

        const int ix = rowx*ncols_data + col_data;
        const int iy = rowy*ncols_data + col_data;

        half2 val;
        if (need_check && col_data + 0 >= ncols_data) {
            val.x = -INFINITY;
        } else {
            val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
        }
        if (need_check && col_data + WARP_SIZE >= ncols_data) {
            val.y = -INFINITY;
        } else {
            val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
        }
        if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
            vals[col_smem] = val;
        }
        max_val = __hmax2(max_val, val);
    }

    // find the max value in the block
    max_val = warp_reduce_max(max_val);
    if (block_size > WARP_SIZE) {
        if (warp_id == 0) {
            buf_iw[lane_id] = -INFINITY;
        }
        __syncthreads();

        if (lane_id == 0) {
            buf_iw[warp_id] = __hmax(max_val.x, max_val.y);
        }
        __syncthreads();

        max_val = __half2half2(buf_iw[lane_id]);
        max_val = warp_reduce_max(max_val);
    } else {
        max_val = __half2half2(__hmax(max_val.x, max_val.y));
    }

    half2 tmp = make_half2(0.0f, 0.0f); // partial sums

#pragma unroll
    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
        const int col_smem = vals_smem ? col0 + tid : 2*col0 + 2*warp_id*WARP_SIZE + lane_id;

        if (ncols_template == 0 && col_smem >= (vals_smem ? ncols_smem : ncols_data)) {
            break;
        }

        const half2 val = h2exp(vals[col_smem] - max_val);

        tmp += val;
        vals[col_smem] = val;
    }

    // find the sum of exps in the block
    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        // Barrier: every warp must have finished reading the block max from buf_iw
        // before warp 0 overwrites it below (otherwise a slow warp could read 0 as the max).
        // block_size is uniform across the block, so all threads reach this barrier.
        __syncthreads();
        if (warp_id == 0) {
            buf_iw[lane_id] = 0.0f;
        }
        __syncthreads();

        if (lane_id == 0) {
            buf_iw[warp_id] = tmp.x + tmp.y;
        }
        __syncthreads();

        tmp = __half2half2(buf_iw[lane_id]);
        tmp = warp_reduce_sum(tmp);
    } else {
        tmp = __half2half2(tmp.x + tmp.y);
    }

    const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp;

#pragma unroll
    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
        const int col_smem = vals_smem ? col0 + tid : col_data;

        // Bounds check BEFORE touching vals: for out-of-range columns col_smem can exceed the
        // shared-memory cache (smem path) or point past the end of dst (dst path, last row).
        // When col_data < ncols_data, col_smem is guaranteed to be in range.
        if (need_check && col_data + 0 >= ncols_data) {
            return;
        }

        const int idst = rowx*ncols_data + col_data;
        const half2 result = vals[col_smem] * inv_sum;

        dst[idst] = result.x;

        if (need_check && col_data + WARP_SIZE >= ncols_data) {
            return;
        }

        dst[idst + WARP_SIZE] = result.y;
    }
#else
    (void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
    bad_arch();
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
5743
+
5744
// Softmax of one row per CUDA block in fp32:
//   dst[row, :] = softmax(x[row, :]*scale + y[row % nrows_y, :])   (y may be nullptr -> no mask)
// Template parameters:
//   vals_smem           - cache intermediate values in dynamic shared memory (true) or in dst (false)
//   ncols_template      - compile-time row length, or 0 to use ncols_par at runtime
//   block_size_template - compile-time block size, or 0 to use blockDim.x
// Dynamic shared memory layout: WARP_SIZE floats of inter-warp scratch (buf_iw), followed, when
// vals_smem, by the cached row values. Each thread always revisits the same columns
// (col0 + tid), so the cached values need no extra synchronization.
template <bool vals_smem, int ncols_template, int block_size_template>
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;

    const int tid  = threadIdx.x;
    const int rowx = blockIdx.x;
    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension

    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;

    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    extern __shared__ float data_soft_max_f32[];
    float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
    // shared memory buffer to cache values between iterations:
    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;

    float max_val = -INFINITY;

#pragma unroll
    for (int col0 = 0; col0 < ncols; col0 += block_size) {
        const int col = col0 + tid;

        if (ncols_template == 0 && col >= ncols) {
            break;
        }

        const int ix = rowx*ncols + col;
        const int iy = rowy*ncols + col;

        const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
        vals[col] = val;
        max_val = max(max_val, val);
    }

    // find the max value in the block
    max_val = warp_reduce_max(max_val);
    if (block_size > WARP_SIZE) {
        if (warp_id == 0) {
            buf_iw[lane_id] = -INFINITY;
        }
        __syncthreads();

        if (lane_id == 0) {
            buf_iw[warp_id] = max_val;
        }
        __syncthreads();

        max_val = buf_iw[lane_id];
        max_val = warp_reduce_max(max_val);
    }

    float tmp = 0.0f; // partial sum

#pragma unroll
    for (int col0 = 0; col0 < ncols; col0 += block_size) {
        const int col = col0 + tid;

        if (ncols_template == 0 && col >= ncols) {
            break;
        }

        const float val = expf(vals[col] - max_val);
        tmp += val;
        vals[col] = val;
    }

    // find the sum of exps in the block
    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        // Barrier: every warp must have finished reading the block max from buf_iw
        // before warp 0 resets it below (otherwise a slow warp could read 0 as the max).
        // block_size is uniform across the block, so all threads reach this barrier.
        __syncthreads();
        if (warp_id == 0) {
            buf_iw[lane_id] = 0.0f;
        }
        __syncthreads();

        if (lane_id == 0) {
            buf_iw[warp_id] = tmp;
        }
        __syncthreads();

        tmp = buf_iw[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    const float inv_sum = 1.0f / tmp;

#pragma unroll
    for (int col0 = 0; col0 < ncols; col0 += block_size) {
        const int col = col0 + tid;

        // no barriers follow, so exiting early is safe
        if (ncols_template == 0 && col >= ncols) {
            return;
        }

        const int idst = rowx*ncols + col;
        dst[idst] = vals[col] * inv_sum;
    }
}
5275
5843
 
@@ -5609,7 +6177,7 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
5609
6177
 
5610
6178
// Launch dequantize_block over k values on `stream`.
// The grid is sized so that each block of CUDA_DEQUANTIZE_BLOCK_SIZE threads covers
// 2*CUDA_DEQUANTIZE_BLOCK_SIZE values (presumably two values per thread — see dequantize_block).
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
    const int values_per_block = 2*CUDA_DEQUANTIZE_BLOCK_SIZE;
    const int grid_size        = (k + values_per_block - 1) / values_per_block; // ceil-div
    dequantize_block<qk, qr, dequantize_kernel><<<grid_size, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
5615
6183
 
@@ -5659,6 +6227,24 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
5659
6227
  #endif
5660
6228
  }
5661
6229
 
6230
// Dequantize k iq2_xxs values on `stream`: one 32-thread block per QK_K-value super-block.
// NOTE(review): k is presumably a multiple of QK_K; any remainder is silently ignored.
template<typename dst_t>
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int n_superblocks = k / QK_K;
    dequantize_block_iq2_xxs<<<n_superblocks, 32, 0, stream>>>(vx, y);
}
6235
+
6236
// Dequantize k iq2_xs values on `stream`: one 32-thread block per QK_K-value super-block.
// NOTE(review): k is presumably a multiple of QK_K; any remainder is silently ignored.
template<typename dst_t>
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int n_superblocks = k / QK_K;
    dequantize_block_iq2_xs<<<n_superblocks, 32, 0, stream>>>(vx, y);
}
6241
+
6242
// Element-wise type conversion of k values (src_t -> dst_t) on `stream`,
// using CUDA_DEQUANTIZE_BLOCK_SIZE threads per block and enough blocks to cover k.
template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
    const int threads = CUDA_DEQUANTIZE_BLOCK_SIZE;
    const int blocks  = (k + threads - 1) / threads; // ceil-div
    convert_unary<src_t><<<blocks, threads, 0, stream>>>(vx, y, k);
}
6247
+
5662
6248
  static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
5663
6249
  switch (type) {
5664
6250
  case GGML_TYPE_Q4_0:
@@ -5681,8 +6267,12 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
5681
6267
  return dequantize_row_q5_K_cuda;
5682
6268
  case GGML_TYPE_Q6_K:
5683
6269
  return dequantize_row_q6_K_cuda;
6270
+ case GGML_TYPE_IQ2_XXS:
6271
+ return dequantize_row_iq2_xxs_cuda;
6272
+ case GGML_TYPE_IQ2_XS:
6273
+ return dequantize_row_iq2_xs_cuda;
5684
6274
  case GGML_TYPE_F32:
5685
- return dequantize_block_cuda<1, 1, convert_f32>;
6275
+ return convert_unary_cuda<float>;
5686
6276
  default:
5687
6277
  return nullptr;
5688
6278
  }
@@ -5710,8 +6300,12 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
5710
6300
  return dequantize_row_q5_K_cuda;
5711
6301
  case GGML_TYPE_Q6_K:
5712
6302
  return dequantize_row_q6_K_cuda;
6303
+ case GGML_TYPE_IQ2_XXS:
6304
+ return dequantize_row_iq2_xxs_cuda;
6305
+ case GGML_TYPE_IQ2_XS:
6306
+ return dequantize_row_iq2_xs_cuda;
5713
6307
  case GGML_TYPE_F16:
5714
- return dequantize_block_cuda<1, 1, convert_f16>;
6308
+ return convert_unary_cuda<half>;
5715
6309
  default:
5716
6310
  return nullptr;
5717
6311
  }
@@ -5904,6 +6498,24 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
5904
6498
  <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
5905
6499
  }
5906
6500
 
6501
// Matrix-vector product of an iq2_xxs matrix (nrows x ncols) with a q8_1 vector on `stream`.
// Blocks are WARP_SIZE x GGML_CUDA_MMV_Y threads; the grid covers nrows in chunks of GGML_CUDA_MMV_Y.
// ncols must be a multiple of QK_K (asserted).
static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    const dim3 block_nums((nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y, 1, 1);
    mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
6509
+
6510
// Matrix-vector product of an iq2_xs matrix (nrows x ncols) with a q8_1 vector on `stream`.
// Blocks are WARP_SIZE x GGML_CUDA_MMV_Y threads; the grid covers nrows in chunks of GGML_CUDA_MMV_Y.
// ncols must be a multiple of QK_K (asserted).
static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    const dim3 block_nums((nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y, 1, 1);
    mul_mat_vec_q<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
6518
+
5907
6519
  static void ggml_mul_mat_q4_0_q8_1_cuda(
5908
6520
  const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
5909
6521
  const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@@ -6543,12 +7155,90 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
6543
7155
  diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
6544
7156
  }
6545
7157
 
7158
// Launch soft_max_f16 over nrows_x rows of ncols_x values on `stream` (one block per row).
// Block size: WARP_SIZE doubled until it reaches ncols_x/2 (each thread handles a half2),
// capped at CUDA_SOFT_MAX_BLOCK_SIZE.
// If the row cache fits in the main device's shared memory per block, a kernel variant that
// caches values in shared memory is used (specialized for common power-of-two widths);
// otherwise a variant that uses dst as the temporary buffer is launched.
static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
    int nth = WARP_SIZE;
    while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) {
        nth *= 2;
    }
    const dim3 block_dims(nth, 1, 1);
    const dim3 block_nums(nrows_x, 1, 1);

    // WARP_SIZE halves of inter-warp scratch + the row padded to 2*WARP_SIZE columns
    const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");

    if (shmem > g_device_caps[g_main_device].smpb) {
        // row cache does not fit: only the scratch buffer lives in shared memory
        const size_t shmem_low = WARP_SIZE*sizeof(half);
        soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
        return;
    }

    switch (ncols_x) {
        case 32:   soft_max_f16<true,   32,   32, true ><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); break;
        case 64:   soft_max_f16<true,   64,   32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); break;
        case 128:  soft_max_f16<true,  128,   64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); break;
        case 256:  soft_max_f16<true,  256,  128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); break;
        case 512:  soft_max_f16<true,  512,  256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); break;
        case 1024: soft_max_f16<true, 1024,  512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); break;
        case 2048: soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); break;
        case 4096: soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); break;
        default:   soft_max_f16<true,    0,    0, true ><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); break;
    }
}
7200
+
6546
7201
// Launch soft_max_f32 over nrows_x rows of ncols_x values on `stream` (one block per row).
// Block size: WARP_SIZE doubled until it reaches ncols_x, capped at CUDA_SOFT_MAX_BLOCK_SIZE.
// If the row cache fits in the main device's shared memory per block, a kernel variant that
// caches values in shared memory is used (specialized for common power-of-two widths);
// otherwise a variant that uses dst as the temporary buffer is launched.
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
    int nth = WARP_SIZE;
    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
    const dim3 block_dims(nth, 1, 1);
    const dim3 block_nums(nrows_x, 1, 1);
    // WARP_SIZE floats of inter-warp scratch + the row padded to WARP_SIZE columns
    const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
    // A request exactly equal to the per-block limit still fits, so use <= (a strict < would
    // needlessly fall back to the slower dst-buffer path at the boundary).
    if (shmem <= g_device_caps[g_main_device].smpb) {
        switch (ncols_x) {
            case 32:
                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
                break;
            case 64:
                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
                break;
            case 128:
                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
                break;
            case 256:
                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
                break;
            case 512:
                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
                break;
            case 1024:
                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
                break;
            case 2048:
                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
                break;
            case 4096:
                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
                break;
            default:
                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
                break;
        }
    } else {
        // row cache does not fit: only the scratch buffer lives in shared memory
        const size_t shmem_low = WARP_SIZE*sizeof(float);
        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
    }
}
6553
7243
 
6554
7244
  static void im2col_f32_f16_cuda(const float* x, half* dst,
@@ -6863,6 +7553,7 @@ void ggml_init_cublas() {
6863
7553
  #else
6864
7554
  g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
6865
7555
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
7556
+ g_device_caps[id].smpb = prop.sharedMemPerBlock;
6866
7557
  }
6867
7558
  for (int id = 0; id < g_device_count; ++id) {
6868
7559
  g_tensor_split[id] /= total_vram;
@@ -7396,6 +8087,8 @@ static int64_t get_row_rounding(ggml_type type) {
7396
8087
  case GGML_TYPE_Q4_K:
7397
8088
  case GGML_TYPE_Q5_K:
7398
8089
  case GGML_TYPE_Q6_K:
8090
+ case GGML_TYPE_IQ2_XXS:
8091
+ case GGML_TYPE_IQ2_XS:
7399
8092
  return max_compute_capability >= CC_RDNA2 ? 128 : 64;
7400
8093
  default:
7401
8094
  GGML_ASSERT(false);
@@ -7416,6 +8109,8 @@ static int64_t get_row_rounding(ggml_type type) {
7416
8109
  case GGML_TYPE_Q3_K:
7417
8110
  case GGML_TYPE_Q4_K:
7418
8111
  case GGML_TYPE_Q5_K:
8112
+ case GGML_TYPE_IQ2_XXS:
8113
+ case GGML_TYPE_IQ2_XS:
7419
8114
  return max_compute_capability >= CC_VOLTA ? 128 : 64;
7420
8115
  case GGML_TYPE_Q6_K:
7421
8116
  return 64;
@@ -7466,6 +8161,12 @@ static void ggml_cuda_op_mul_mat_vec_q(
7466
8161
  case GGML_TYPE_Q6_K:
7467
8162
  mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
7468
8163
  break;
8164
+ case GGML_TYPE_IQ2_XXS:
8165
+ mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
8166
+ break;
8167
+ case GGML_TYPE_IQ2_XS:
8168
+ mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
8169
+ break;
7469
8170
  default:
7470
8171
  GGML_ASSERT(false);
7471
8172
  break;
@@ -7873,7 +8574,21 @@ static void ggml_cuda_op_soft_max(
7873
8574
  float scale = 1.0f;
7874
8575
  memcpy(&scale, dst->op_params, sizeof(float));
7875
8576
 
7876
- soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
8577
+ #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
8578
+ const bool use_f16_soft_max = false;
8579
+ #else
8580
+ #ifdef GGML_CUDA_F16
8581
+ const bool use_f16_soft_max = true;
8582
+ #else
8583
+ const bool use_f16_soft_max = false;
8584
+ #endif // GGML_CUDA_F16
8585
+ #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
8586
+
8587
+ if (use_f16_soft_max) {
8588
+ soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
8589
+ } else {
8590
+ soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
8591
+ }
7877
8592
 
7878
8593
  (void) dst;
7879
8594
  }
@@ -8682,6 +9397,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
8682
9397
 
8683
9398
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
8684
9399
 
9400
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type);
9401
+
8685
9402
  // debug helpers
8686
9403
  //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
8687
9404
  //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
@@ -9689,8 +10406,8 @@ static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, gg
9689
10406
 
9690
10407
  ggml_cuda_set_device(ctx->device);
9691
10408
  CUDA_CHECK(cudaDeviceSynchronize());
9692
-
9693
10409
  CUDA_CHECK(cudaMemcpy((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice));
10410
+ CUDA_CHECK(cudaDeviceSynchronize());
9694
10411
  }
9695
10412
 
9696
10413
  static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
@@ -9910,7 +10627,7 @@ static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_ba
9910
10627
  UNUSED(plan);
9911
10628
  }
9912
10629
 
9913
- static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
10630
+ static bool ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
9914
10631
  ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
9915
10632
 
9916
10633
  ggml_cuda_set_main_device(cuda_ctx->device);
@@ -9967,6 +10684,8 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
9967
10684
  }
9968
10685
 
9969
10686
  UNUSED(backend);
10687
+
10688
+ return true;
9970
10689
  }
9971
10690
 
9972
10691
  static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {