llama_cpp 0.14.2 → 0.14.3

@@ -17,29 +17,17 @@ extern "C" {
 
  #define GGML_CUDA_MAX_DEVICES 16
 
- // Always success. To check if CUDA is actually loaded, use `ggml_cublas_loaded`.
- GGML_API GGML_CALL void ggml_init_cublas(void);
-
- // Returns `true` if there are available CUDA devices and cublas loads successfully; otherwise, it returns `false`.
- GGML_API GGML_CALL bool ggml_cublas_loaded(void);
-
- GGML_API GGML_CALL void * ggml_cuda_host_malloc(size_t size);
- GGML_API GGML_CALL void ggml_cuda_host_free(void * ptr);
-
- GGML_API GGML_CALL bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
- GGML_API GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
-
- GGML_API GGML_CALL int ggml_cuda_get_device_count(void);
- GGML_API GGML_CALL void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
-
  // backend API
  GGML_API GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device);
 
  GGML_API GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend);
 
+ // device buffer
  GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
+
  // split tensor buffer that splits matrices by rows across multiple devices
  GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split);
+
  // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
  GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
 
@@ -47,6 +35,9 @@ GGML_API GGML_CALL int ggml_backend_cuda_get_device_count(void);
  GGML_API GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
  GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
 
+ GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
+ GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer);
+
  #ifdef __cplusplus
  }
  #endif
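Note on the hunk above: the removed entries are the pre-backend cuBLAS entry points (`ggml_init_cublas`, `ggml_cublas_loaded`, the host malloc/free helpers, and the direct mul-mat hooks); only the backend API remains, and two host-buffer registration calls are added. Below is a minimal usage sketch, not taken from the diff: the two `ggml_backend_cuda_*_host_buffer` declarations are the ones added in the hunk, while the include path, buffer size, and error handling are illustrative assumptions.

// Sketch: pinning an application-owned staging buffer for faster CPU<->GPU copies.
// The ggml_backend_cuda_*_host_buffer declarations are from the hunk above;
// the include path, buffer size and error handling are illustrative only.
#include <stdio.h>
#include <stdlib.h>
#include "ggml-cuda.h"

int main(void) {
    const size_t size = 64u * 1024 * 1024;   // 64 MiB staging area (arbitrary)
    void * host = malloc(size);
    if (host == NULL) {
        return 1;
    }

    // Returns false when pinning is unavailable (e.g. no CUDA device);
    // the buffer stays usable, just without pinned-transfer speed.
    if (!ggml_backend_cuda_register_host_buffer(host, size)) {
        fprintf(stderr, "could not pin host buffer, continuing unpinned\n");
    }

    // ... stage tensor data through `host` here ...

    ggml_backend_cuda_unregister_host_buffer(host);
    free(host);
    return 0;
}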
@@ -1951,6 +1951,7 @@ static struct ggml_backend_i kompute_backend_i = {
  /* .graph_plan_compute = */ NULL,
  /* .graph_compute = */ ggml_backend_kompute_graph_compute,
  /* .supports_op = */ ggml_backend_kompute_supports_op,
+ /* .offload_op = */ NULL,
  /* .event_new = */ NULL,
  /* .event_free = */ NULL,
  /* .event_record = */ NULL,
@@ -173,8 +173,9 @@ enum ggml_metal_kernel_type {
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
- //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
- //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
  GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
  GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
  GGML_METAL_KERNEL_TYPE_CONCAT,
@@ -598,8 +599,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
@@ -739,6 +741,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
  case GGML_TYPE_Q8_0:
  case GGML_TYPE_Q4_0:
  case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_IQ4_NL:
  return true;
  default:
  return false;
@@ -1387,6 +1392,14 @@ static enum ggml_status ggml_metal_graph_compute(
  (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
  //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
 
+ // some Metal matrix data types require aligned pointers
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
+ switch (src0->type) {
+ case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
+ case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
+ default: break;
+ }
+
  id<MTLComputePipelineState> pipeline = nil;
 
  switch (src0->type) {
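The assertion added above (and repeated in the next hunk for the mul-mat-id path) guards the simdgroup matrix-multiplication route: nb01 is the byte stride between rows of src0, and the referenced MSL alignment table requires 16-byte alignment for the F32 case and 8-byte alignment for F16, matching float4/half4 row loads. A tiny standalone restatement, for illustration only and not part of the diff:

// Illustrative restatement of the added assertion (assumption: nb01 is the
// byte stride between rows of src0, as in ggml's tensor layout).
#include <stdbool.h>
#include <stddef.h>

static bool mm_rows_aligned(size_t nb01, bool is_f32) {
    return is_f32 ? (nb01 % 16 == 0)   // F32 rows: 16-byte aligned
                  : (nb01 %  8 == 0);  // F16 rows:  8-byte aligned
}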
@@ -1701,6 +1714,14 @@ static enum ggml_status ggml_metal_graph_compute(
  ne20 % 32 == 0 && ne20 >= 64 &&
  ne11 > ne11_mm_min) {
 
+ // some Metal matrix data types require aligned pointers
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
+ switch (src0->type) {
+ case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
+ case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
+ default: break;
+ }
+
  id<MTLComputePipelineState> pipeline = nil;
 
  switch (src2->type) {
@@ -2431,13 +2452,14 @@ static enum ggml_status ggml_metal_graph_compute(
  GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
 
  switch (dstt) {
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
- //case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
- //case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
  default: GGML_ASSERT(false && "not implemented");
  };
  } break;
@@ -2837,6 +2859,7 @@ static struct ggml_backend_i ggml_backend_metal_i = {
  /* .graph_plan_compute = */ NULL,
  /* .graph_compute = */ ggml_backend_metal_graph_compute,
  /* .supports_op = */ ggml_backend_metal_supports_op,
+ /* .offload_op = */ NULL,
  /* .event_new = */ NULL,
  /* .event_free = */ NULL,
  /* .event_record = */ NULL,
@@ -2388,6 +2388,242 @@ kernel void kernel_cpy_f32_q4_1(
  }
  }
 
+ kernel void kernel_cpy_f32_q5_0(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0;
+
+ device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < QK5_0; j++) {
+ const float v = src[j];
+ if (amax < fabs(v)) {
+ amax = fabs(v);
+ max = v;
+ }
+ }
+
+ const float d = max / -16;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK5_0].d = d;
+
+ uint32_t qh = 0;
+ for (int j = 0; j < QK5_0/2; ++j) {
+ const float x0 = src[0 + j]*id;
+ const float x1 = src[QK5_0/2 + j]*id;
+
+ const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+ const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+ dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+ }
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+ for (int j = 0; j < 4; ++j) {
+ dst_data[i00/QK5_0].qh[j] = qh8[j];
+ }
+ }
+ }
+
+ kernel void kernel_cpy_f32_q5_1(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1;
+
+ device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float max = src[0];
+ float min = src[0];
+
+ for (int j = 1; j < QK5_1; j++) {
+ const float v = src[j];
+ min = v < min ? v : min;
+ max = v > max ? v : max;
+ }
+
+ const float d = (max - min) / 31;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK5_1].d = d;
+ dst_data[i00/QK5_1].m = min;
+
+ uint32_t qh = 0;
+ for (int j = 0; j < QK5_1/2; ++j) {
+ const float x0 = (src[0 + j] - min)*id;
+ const float x1 = (src[QK5_1/2 + j] - min)*id;
+
+ const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+ const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+ dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+ }
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+ for (int j = 0; j < 4; ++j) {
+ dst_data[i00/QK5_1].qh[j] = qh8[j];
+ }
+ }
+ }
+
+ static inline int best_index_int8(int n, constant float * val, float x) {
+ if (x <= val[0]) return 0;
+ if (x >= val[n-1]) return n-1;
+ int ml = 0, mu = n-1;
+ while (mu-ml > 1) {
+ int mav = (ml+mu)/2;
+ if (x < val[mav]) mu = mav; else ml = mav;
+ }
+ return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+ }
+
+ constexpr constant static float kvalues_iq4nl_f[16] = {
+ -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
+ };
+
+ kernel void kernel_cpy_f32_iq4_nl(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;
+
+ device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < QK4_0; j++) {
+ const float v = src[j];
+ if (amax < fabs(v)) {
+ amax = fabs(v);
+ max = v;
+ }
+ }
+
+ const float d = max / kvalues_iq4nl_f[0];
+ const float id = d ? 1.0f/d : 0.0f;
+
+ float sumqx = 0, sumq2 = 0;
+ for (int j = 0; j < QK4_NL/2; ++j) {
+ const float x0 = src[0 + j]*id;
+ const float x1 = src[QK4_NL/2 + j]*id;
+
+ const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
+ const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
+
+ dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
+
+ const float v0 = kvalues_iq4nl_f[xi0];
+ const float v1 = kvalues_iq4nl_f[xi1];
+ const float w0 = src[0 + j]*src[0 + j];
+ const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
+ sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
+ sumq2 += w0*v0*v0 + w1*v1*v1;
+
+ }
+
+ dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
+
+ }
+ }
+
  kernel void kernel_concat(
  device const char * src0,
  device const char * src1,
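Note on kernel_cpy_f32_iq4_nl above: the kernel first derives a provisional scale from the block maximum and the first codebook value, maps each input to the nearest of the 16 non-linear levels, and then refits the stored scale by weighted least squares. In the kernel's own notation (x_j the inputs, q_j the chosen indices, v_j = kvalues_iq4nl_f[q_j], weights w_j = x_j^2), the stored scale is

\[
d \;=\; \frac{\sum_j w_j\, v_j\, x_j}{\sum_j w_j\, v_j^{2}},
\qquad w_j = x_j^{2},
\]

falling back to the provisional d = max / kvalues_iq4nl_f[0] when the denominator is zero; this is the scale that minimizes the weighted squared error \(\sum_j w_j (x_j - d\, v_j)^2\) for the already-chosen codes.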
@@ -4220,10 +4456,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
  }
  }
 
- constexpr constant static float kvalues_iq4nl_f[16] = {
- -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
- };
-
  void kernel_mul_mv_iq4_nl_f32_impl(
  device const void * src0,
  device const float * src1,
@@ -5528,18 +5760,18 @@ typedef void (mat_mm_t)(
  threadgroup uchar *,
  uint3, uint, uint);
 
- template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
- template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
- template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
- template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
- template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
- template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
- template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
- template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
- template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
- template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
- template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
- template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
+ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
+ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
+ template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
+ template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
+ template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
+ template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
+ template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
+ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
+ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
+ template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
+ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
  template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
  template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
  template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
@@ -5588,18 +5820,18 @@ typedef void (mat_mm_id_t)(
  threadgroup uchar *,
  uint3, uint, uint);
 
- template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
- template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
- template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
- template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
- template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
- template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
- template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
- template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
- template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
- template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
- template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
- template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
+ template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
+ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
+ template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
+ template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
+ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
+ template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
+ template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
+ template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
+ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
+ template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
+ template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
+ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
  template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
  template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
  template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
@@ -11705,9 +11705,8 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
  ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
  float * scales, float * weight, uint8_t * L,
  const int8_t * values,
- const float * quant_weights) {
-
- const int ntry = 7;
+ const float * quant_weights,
+ const int ntry) {
 
  float sigma2 = 0;
  for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j];
@@ -11719,6 +11718,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
  float max_scale = 0, amax_scale = 0;
  for (int ib = 0; ib < super_block_size/block_size; ++ib) {
  const float * xb = x + ib*block_size;
+ uint8_t * Lb = L + ib*block_size;
  if (quant_weights) {
  const float * qw = quant_weights + ib*block_size;
  for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
@@ -11736,12 +11736,13 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
  scales[ib] = 0;
  continue;
  }
- float d = -max/values[0];
+ float d = ntry > 0 ? -max/values[0] : max/values[0];
  float id = 1/d;
  float sumqx = 0, sumq2 = 0;
  for (int j = 0; j < block_size; ++j) {
  float al = id*xb[j];
  int l = best_index_int8(16, values, al);
+ Lb[j] = l;
  float q = values[l];
  float w = weight[j];
  sumqx += w*q*xb[j];
@@ -11796,9 +11797,11 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
  }
  } else {
  dh[0] = GGML_FP32_TO_FP16(scales[0]);
- float id = scales[0] ? 1/scales[0] : 0;
- for (int j = 0; j < super_block_size; ++j) {
- L[j] = best_index_int8(16, values, id*x[j]);
+ if (ntry > 0) {
+ float id = scales[0] ? 1/scales[0] : 0;
+ for (int j = 0; j < super_block_size; ++j) {
+ L[j] = best_index_int8(16, values, id*x[j]);
+ }
  }
  }
 
@@ -11823,7 +11826,7 @@ size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int nrow
  for (int ibl = 0; ibl < nblock; ++ibl) {
  const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
  quantize_row_iq4_nl_impl(QK4_NL, 32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
- &scale, weight, L, kvalues_iq4nl, qw);
+ &scale, weight, L, kvalues_iq4nl, qw, 7);
  }
  src += n_per_row;
  qrow += nblock*sizeof(block_iq4_nl);
@@ -11832,14 +11835,23 @@ size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int nrow
  }
 
  void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) {
- assert(k % QK4_NL == 0);
- block_iq4_nl * restrict y = vy;
- quantize_row_iq4_nl_reference(x, y, k);
+ GGML_ASSERT(k%QK4_NL == 0);
+ int nblock = k/QK4_NL;
+ uint8_t L[QK4_NL];
+ float weight[QK4_NL];
+ uint16_t unused_h;
+ uint8_t * unused_l = NULL;
+ float scale;
+ block_iq4_nl * iq4 = (block_iq4_nl *)vy;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ quantize_row_iq4_nl_impl(QK4_NL, 32, x + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
+ &scale, weight, L, kvalues_iq4nl, NULL, -1);
+ }
  }
 
  void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int k) {
  assert(k % QK4_NL == 0);
- quantize_iq4_nl(x, y, 1, k, NULL);
+ quantize_row_iq4_nl(x, y, k);
  }
 
  size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
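Note on the hunks above: quantize_row_iq4_nl_impl now takes ntry as a parameter. quantize_iq4_nl and quantize_iq4_xs keep the iterative scale search (ntry = 7), while the rewritten quantize_row_iq4_nl quantizes each block directly with no importance weights (ntry = -1), and the reference function simply forwards to it. A minimal calling sketch follows; the function signature is the one shown above, but the include path and toy data are illustrative assumptions.

// Sketch: quantizing one row of 64 floats with the rewritten quantize_row_iq4_nl.
// QK4_NL is 32 in ggml, so k must be a multiple of 32; the include path and the
// data below are assumptions, not part of the diff.
#include "ggml-quants.h"

void example(void) {
    float x[64];
    for (int i = 0; i < 64; ++i) {
        x[i] = 0.01f * (float)(i - 32);   // toy data
    }

    block_iq4_nl blocks[2];               // 64 / QK4_NL = 2 blocks
    quantize_row_iq4_nl(x, blocks, 64);   // no importance weights (ntry = -1 path)
}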
@@ -11857,7 +11869,7 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int nrow
  for (int ibl = 0; ibl < nblock; ++ibl) {
  const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL;
  quantize_row_iq4_nl_impl(QK_K, 32, src + QK_K*ibl, &iq4[ibl].d, iq4[ibl].qs, &iq4[ibl].scales_h, iq4[ibl].scales_l,
- scales, weight, L, kvalues_iq4nl, qw);
+ scales, weight, L, kvalues_iq4nl, qw, 7);
  }
  src += n_per_row;
  qrow += nblock*sizeof(block_iq4_xs);