llama_cpp 0.14.2 → 0.14.3

This diff shows the published contents of these two package versions as they appear in their public registry. It is provided for informational purposes only.
@@ -17,29 +17,17 @@ extern "C" {
 
  #define GGML_CUDA_MAX_DEVICES 16
 
- // Always success. To check if CUDA is actually loaded, use `ggml_cublas_loaded`.
- GGML_API GGML_CALL void ggml_init_cublas(void);
-
- // Returns `true` if there are available CUDA devices and cublas loads successfully; otherwise, it returns `false`.
- GGML_API GGML_CALL bool ggml_cublas_loaded(void);
-
- GGML_API GGML_CALL void * ggml_cuda_host_malloc(size_t size);
- GGML_API GGML_CALL void ggml_cuda_host_free(void * ptr);
-
- GGML_API GGML_CALL bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
- GGML_API GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
-
- GGML_API GGML_CALL int ggml_cuda_get_device_count(void);
- GGML_API GGML_CALL void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
-
  // backend API
  GGML_API GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device);
 
  GGML_API GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend);
 
+ // device buffer
  GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
+
  // split tensor buffer that splits matrices by rows across multiple devices
  GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split);
+
  // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
  GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
 
@@ -47,6 +35,9 @@ GGML_API GGML_CALL int ggml_backend_cuda_get_device_count(void);
  GGML_API GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
  GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
 
+ GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
+ GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer);
+
  #ifdef __cplusplus
  }
  #endif
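The new pair of entry points lets an application pin an existing host allocation for CUDA. A minimal usage sketch, assuming only the declarations above from `ggml-cuda.h`; the buffer size and fallback handling here are illustrative, not part of the release:

```c
#include <stdio.h>
#include <stdlib.h>
#include "ggml-cuda.h"

int main(void) {
    const size_t size = 64u * 1024 * 1024;   // illustrative size
    void * buf = malloc(size);
    if (buf == NULL) return 1;

    // Pin the existing host allocation so CUDA copies can DMA directly from
    // it instead of going through a staging buffer. Per the header above,
    // this returns false when registration fails (e.g. CUDA unavailable).
    if (!ggml_backend_cuda_register_host_buffer(buf, size)) {
        fprintf(stderr, "registration failed; continuing with pageable memory\n");
    }

    // ... fill the buffer with tensor data and hand it to ggml ...

    ggml_backend_cuda_unregister_host_buffer(buf);
    free(buf);
    return 0;
}
```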
@@ -1951,6 +1951,7 @@ static struct ggml_backend_i kompute_backend_i = {
  /* .graph_plan_compute = */ NULL,
  /* .graph_compute = */ ggml_backend_kompute_graph_compute,
  /* .supports_op = */ ggml_backend_kompute_supports_op,
+ /* .offload_op = */ NULL,
  /* .event_new = */ NULL,
  /* .event_free = */ NULL,
  /* .event_record = */ NULL,
@@ -173,8 +173,9 @@ enum ggml_metal_kernel_type {
      GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
      GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
      GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
-     //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
-     //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
+     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
+     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
+     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
      GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
      GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
      GGML_METAL_KERNEL_TYPE_CONCAT,
@@ -598,8 +599,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
-     //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
-     //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
+     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
+     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
+     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
@@ -739,6 +741,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
          case GGML_TYPE_Q8_0:
          case GGML_TYPE_Q4_0:
          case GGML_TYPE_Q4_1:
+         case GGML_TYPE_Q5_0:
+         case GGML_TYPE_Q5_1:
+         case GGML_TYPE_IQ4_NL:
              return true;
          default:
              return false;
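With these cases added, `ggml_metal_supports_op` now reports F32 to Q5_0/Q5_1/IQ4_NL copies as supported, which is what enables, for example, a KV cache quantized in those formats on Metal. A hedged sketch of opting in through the llama.cpp API; the `type_k` field exists in `llama_context_params` of this era, but verify against the `llama.h` shipped with your build:

```c
#include "llama.h"

// Illustrative only: request a Q5_0 K cache, which the Metal backend can now
// populate on the fly via the new kernel_cpy_f32_q5_0 kernel.
struct llama_context_params make_ctx_params(void) {
    struct llama_context_params p = llama_context_default_params();
    p.type_k = GGML_TYPE_Q5_0;
    return p;
}
```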
@@ -1387,6 +1392,14 @@ static enum ggml_status ggml_metal_graph_compute(
      (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
      //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
 
+     // some Metal matrix data types require aligned pointers
+     // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
+     switch (src0->type) {
+         case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
+         case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
+         default: break;
+     }
+
      id<MTLComputePipelineState> pipeline = nil;
 
      switch (src0->type) {
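The assertion guards the simdgroup matrix-multiplication path, which loads rows of src0 as Metal vector/matrix types: per the Metal Shading Language spec, `float4x4` loads need 16-byte alignment and `half4x4` loads need 8-byte alignment, so the row stride `nb01` must preserve that. A small host-side sketch of the same check; the function name is illustrative:

```c
#include <stdbool.h>
#include <stddef.h>

// Would a row pointer computed as base + i01*nb01 stay aligned for the
// vector loads the Metal mat-mat kernels perform? Alignments assumed from
// MSL spec Table 2.5: float4x4 -> 16 bytes, half4x4 -> 8 bytes.
static bool row_stride_is_aligned(size_t nb01, bool src_is_f16) {
    const size_t required = src_is_f16 ? 8 : 16;
    return nb01 % required == 0;
}
```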
@@ -1701,6 +1714,14 @@ static enum ggml_status ggml_metal_graph_compute(
      ne20 % 32 == 0 && ne20 >= 64 &&
      ne11 > ne11_mm_min) {
 
+     // some Metal matrix data types require aligned pointers
+     // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
+     switch (src0->type) {
+         case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
+         case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
+         default: break;
+     }
+
      id<MTLComputePipelineState> pipeline = nil;
 
      switch (src2->type) {
@@ -2431,13 +2452,14 @@ static enum ggml_status ggml_metal_graph_compute(
      GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
 
      switch (dstt) {
-         case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
-         case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
-         case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
-         case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
-         case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
-         //case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
-         //case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
+         case GGML_TYPE_F16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline;    break;
+         case GGML_TYPE_F32:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;    break;
+         case GGML_TYPE_Q8_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline;   break;
+         case GGML_TYPE_Q4_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline;   break;
+         case GGML_TYPE_Q4_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline;   break;
+         case GGML_TYPE_Q5_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline;   break;
+         case GGML_TYPE_Q5_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline;   break;
+         case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
          default: GGML_ASSERT(false && "not implemented");
      };
  } break;
@@ -2837,6 +2859,7 @@ static struct ggml_backend_i ggml_backend_metal_i = {
  /* .graph_plan_compute = */ NULL,
  /* .graph_compute = */ ggml_backend_metal_graph_compute,
  /* .supports_op = */ ggml_backend_metal_supports_op,
+ /* .offload_op = */ NULL,
  /* .event_new = */ NULL,
  /* .event_free = */ NULL,
  /* .event_record = */ NULL,
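Both the Kompute and Metal interface tables fill the new `.offload_op` slot with NULL, meaning neither overrides the scheduler's default offload policy. For a backend that does want a say, the callback shape is roughly as below; the signature is assumed from the `ggml_backend_i` struct of this release, and the policy itself is made up for illustration:

```c
#include "ggml-backend.h"

// Hedged sketch: return true to ask the scheduler to run this op on the
// backend even when its inputs currently live on another backend.
static bool my_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) {
    (void) backend;
    // illustrative policy: offload only reasonably large matrix multiplications
    return op->op == GGML_OP_MUL_MAT && op->ne[1] >= 32;
}
```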
@@ -2388,6 +2388,242 @@ kernel void kernel_cpy_f32_q4_1(
      }
  }
 
+ kernel void kernel_cpy_f32_q5_0(
+     device const float * src0,
+     device void * dst,
+     constant int64_t & ne00,
+     constant int64_t & ne01,
+     constant int64_t & ne02,
+     constant int64_t & ne03,
+     constant uint64_t & nb00,
+     constant uint64_t & nb01,
+     constant uint64_t & nb02,
+     constant uint64_t & nb03,
+     constant int64_t & ne0,
+     constant int64_t & ne1,
+     constant int64_t & ne2,
+     constant int64_t & ne3,
+     constant uint64_t & nb0,
+     constant uint64_t & nb1,
+     constant uint64_t & nb2,
+     constant uint64_t & nb3,
+     uint3 tgpig[[threadgroup_position_in_grid]],
+     uint3 tpitg[[thread_position_in_threadgroup]],
+     uint3 ntg[[threads_per_threadgroup]]) {
+     const int64_t i03 = tgpig[2];
+     const int64_t i02 = tgpig[1];
+     const int64_t i01 = tgpig[0];
+
+     const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+     const int64_t i3 = n / (ne2*ne1*ne0);
+     const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+     const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0;
+
+     device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+     for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
+         device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+         float amax = 0.0f; // absolute max
+         float max  = 0.0f;
+
+         for (int j = 0; j < QK5_0; j++) {
+             const float v = src[j];
+             if (amax < fabs(v)) {
+                 amax = fabs(v);
+                 max  = v;
+             }
+         }
+
+         const float d = max / -16;
+         const float id = d ? 1.0f/d : 0.0f;
+
+         dst_data[i00/QK5_0].d = d;
+
+         uint32_t qh = 0;
+         for (int j = 0; j < QK5_0/2; ++j) {
+             const float x0 = src[0 + j]*id;
+             const float x1 = src[QK5_0/2 + j]*id;
+
+             const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+             const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+             dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+             qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+             qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+         }
+         thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+         for (int j = 0; j < 4; ++j) {
+             dst_data[i00/QK5_0].qh[j] = qh8[j];
+         }
+     }
+ }
+
+ kernel void kernel_cpy_f32_q5_1(
+     device const float * src0,
+     device void * dst,
+     constant int64_t & ne00,
+     constant int64_t & ne01,
+     constant int64_t & ne02,
+     constant int64_t & ne03,
+     constant uint64_t & nb00,
+     constant uint64_t & nb01,
+     constant uint64_t & nb02,
+     constant uint64_t & nb03,
+     constant int64_t & ne0,
+     constant int64_t & ne1,
+     constant int64_t & ne2,
+     constant int64_t & ne3,
+     constant uint64_t & nb0,
+     constant uint64_t & nb1,
+     constant uint64_t & nb2,
+     constant uint64_t & nb3,
+     uint3 tgpig[[threadgroup_position_in_grid]],
+     uint3 tpitg[[thread_position_in_threadgroup]],
+     uint3 ntg[[threads_per_threadgroup]]) {
+     const int64_t i03 = tgpig[2];
+     const int64_t i02 = tgpig[1];
+     const int64_t i01 = tgpig[0];
+
+     const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+     const int64_t i3 = n / (ne2*ne1*ne0);
+     const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+     const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1;
+
+     device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+     for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
+         device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+         float max = src[0];
+         float min = src[0];
+
+         for (int j = 1; j < QK5_1; j++) {
+             const float v = src[j];
+             min = v < min ? v : min;
+             max = v > max ? v : max;
+         }
+
+         const float d = (max - min) / 31;
+         const float id = d ? 1.0f/d : 0.0f;
+
+         dst_data[i00/QK5_1].d = d;
+         dst_data[i00/QK5_1].m = min;
+
+         uint32_t qh = 0;
+         for (int j = 0; j < QK5_1/2; ++j) {
+             const float x0 = (src[0 + j] - min)*id;
+             const float x1 = (src[QK5_1/2 + j] - min)*id;
+
+             const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+             const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+             dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+             qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+             qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+         }
+         thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+         for (int j = 0; j < 4; ++j) {
+             dst_data[i00/QK5_1].qh[j] = qh8[j];
+         }
+     }
+ }
+
+ static inline int best_index_int8(int n, constant float * val, float x) {
+     if (x <= val[0]) return 0;
+     if (x >= val[n-1]) return n-1;
+     int ml = 0, mu = n-1;
+     while (mu-ml > 1) {
+         int mav = (ml+mu)/2;
+         if (x < val[mav]) mu = mav; else ml = mav;
+     }
+     return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+ }
+
+ constexpr constant static float kvalues_iq4nl_f[16] = {
+     -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
+ };
+
+ kernel void kernel_cpy_f32_iq4_nl(
+     device const float * src0,
+     device void * dst,
+     constant int64_t & ne00,
+     constant int64_t & ne01,
+     constant int64_t & ne02,
+     constant int64_t & ne03,
+     constant uint64_t & nb00,
+     constant uint64_t & nb01,
+     constant uint64_t & nb02,
+     constant uint64_t & nb03,
+     constant int64_t & ne0,
+     constant int64_t & ne1,
+     constant int64_t & ne2,
+     constant int64_t & ne3,
+     constant uint64_t & nb0,
+     constant uint64_t & nb1,
+     constant uint64_t & nb2,
+     constant uint64_t & nb3,
+     uint3 tgpig[[threadgroup_position_in_grid]],
+     uint3 tpitg[[thread_position_in_threadgroup]],
+     uint3 ntg[[threads_per_threadgroup]]) {
+     const int64_t i03 = tgpig[2];
+     const int64_t i02 = tgpig[1];
+     const int64_t i01 = tgpig[0];
+
+     const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+     const int64_t i3 = n / (ne2*ne1*ne0);
+     const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+     const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;
+
+     device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+     for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
+         device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+         float amax = 0.0f; // absolute max
+         float max  = 0.0f;
+
+         for (int j = 0; j < QK4_0; j++) {
+             const float v = src[j];
+             if (amax < fabs(v)) {
+                 amax = fabs(v);
+                 max  = v;
+             }
+         }
+
+         const float d = max / kvalues_iq4nl_f[0];
+         const float id = d ? 1.0f/d : 0.0f;
+
+         float sumqx = 0, sumq2 = 0;
+         for (int j = 0; j < QK4_NL/2; ++j) {
+             const float x0 = src[0 + j]*id;
+             const float x1 = src[QK4_NL/2 + j]*id;
+
+             const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
+             const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
+
+             dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
+
+             const float v0 = kvalues_iq4nl_f[xi0];
+             const float v1 = kvalues_iq4nl_f[xi1];
+             const float w0 = src[0 + j]*src[0 + j];
+             const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
+             sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
+             sumq2 += w0*v0*v0 + w1*v1*v1;
+         }
+
+         dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
+     }
+ }
+
  kernel void kernel_concat(
      device const char * src0,
      device const char * src1,
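For reference, the new IQ4_NL copy kernel first picks a scale from the extreme lattice value, then refines it with a weighted least-squares fit: minimizing the weighted error over the scale d gives d = Σ w·v·x / Σ w·v². Note the kernel's scan loop bounds with QK4_0, which equals QK4_NL (both 32), so it covers exactly one block. A host-side C sketch of the same per-block logic, with names and structure of my own (only the lattice table and the math come from the kernel above):

```c
#include <math.h>
#include <stdint.h>

#define QK4_NL 32

static const float kvalues_iq4nl_f[16] = {
    -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f,
       1.f,   13.f,  25.f,  38.f,  53.f,  69.f,  89.f, 113.f
};

// binary search for the closest lattice value, mirroring best_index_int8
static int best_index(int n, const float * val, float x) {
    if (x <= val[0]) return 0;
    if (x >= val[n-1]) return n-1;
    int ml = 0, mu = n-1;
    while (mu - ml > 1) {
        int mid = (ml + mu)/2;
        if (x < val[mid]) mu = mid; else ml = mid;
    }
    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}

// quantize one block of 32 floats, returning the refined scale
static float quantize_iq4_nl_block(const float * src, uint8_t qs[QK4_NL/2]) {
    float amax = 0.0f, max = 0.0f;
    for (int j = 0; j < QK4_NL; ++j) {
        if (fabsf(src[j]) > amax) { amax = fabsf(src[j]); max = src[j]; }
    }
    const float d  = max / kvalues_iq4nl_f[0]; // initial scale from the extreme lattice value
    const float id = d ? 1.0f/d : 0.0f;
    float sumqx = 0, sumq2 = 0;
    for (int j = 0; j < QK4_NL/2; ++j) {
        const uint8_t xi0 = (uint8_t) best_index(16, kvalues_iq4nl_f, src[j]*id);
        const uint8_t xi1 = (uint8_t) best_index(16, kvalues_iq4nl_f, src[QK4_NL/2 + j]*id);
        qs[j] = xi0 | (xi1 << 4);
        // weighted least-squares accumulation with weights w = x^2
        const float v0 = kvalues_iq4nl_f[xi0], v1 = kvalues_iq4nl_f[xi1];
        const float w0 = src[j]*src[j], w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
        sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
        sumq2 += w0*v0*v0 + w1*v1*v1;
    }
    return sumq2 > 0 ? sumqx/sumq2 : d; // refined scale
}
```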
@@ -4220,10 +4456,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
      }
  }
 
- constexpr constant static float kvalues_iq4nl_f[16] = {
-     -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
- };
-
  void kernel_mul_mv_iq4_nl_f32_impl(
      device const void * src0,
      device const float * src1,
@@ -5528,18 +5760,18 @@ typedef void (mat_mm_t)(
      threadgroup uchar *,
      uint3, uint, uint);
 
- template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
- template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
- template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
- template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
- template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
- template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
- template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
- template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
- template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
- template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
- template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
- template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
+ template [[host_name("kernel_mul_mm_f32_f32")]]     kernel mat_mm_t kernel_mul_mm<float4x4,      1,     dequantize_f32>;
+ template [[host_name("kernel_mul_mm_f16_f32")]]     kernel mat_mm_t kernel_mul_mm<half4x4,       1,     dequantize_f16>;
+ template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q4_0,    2,     dequantize_q4_0>;
+ template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q4_1,    2,     dequantize_q4_1>;
+ template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q5_0,    2,     dequantize_q5_0>;
+ template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q5_1,    2,     dequantize_q5_1>;
+ template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q8_0,    2,     dequantize_q8_0>;
+ template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q2_K,    QK_NL, dequantize_q2_K>;
+ template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q3_K,    QK_NL, dequantize_q3_K>;
+ template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q4_K,    QK_NL, dequantize_q4_K>;
+ template [[host_name("kernel_mul_mm_q5_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q5_K,    QK_NL, dequantize_q5_K>;
+ template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q6_K,    QK_NL, dequantize_q6_K>;
  template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
  template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
  template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
@@ -5588,18 +5820,18 @@ typedef void (mat_mm_id_t)(
      threadgroup uchar *,
      uint3, uint, uint);
 
- template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
- template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
- template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
- template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
- template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
- template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
- template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
- template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
- template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
- template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
- template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
- template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
+ template [[host_name("kernel_mul_mm_id_f32_f32")]]     kernel mat_mm_id_t kernel_mul_mm_id<float4x4,      1,     dequantize_f32>;
+ template [[host_name("kernel_mul_mm_id_f16_f32")]]     kernel mat_mm_id_t kernel_mul_mm_id<half4x4,       1,     dequantize_f16>;
+ template [[host_name("kernel_mul_mm_id_q4_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0,    2,     dequantize_q4_0>;
+ template [[host_name("kernel_mul_mm_id_q4_1_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1,    2,     dequantize_q4_1>;
+ template [[host_name("kernel_mul_mm_id_q5_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0,    2,     dequantize_q5_0>;
+ template [[host_name("kernel_mul_mm_id_q5_1_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1,    2,     dequantize_q5_1>;
+ template [[host_name("kernel_mul_mm_id_q8_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0,    2,     dequantize_q8_0>;
+ template [[host_name("kernel_mul_mm_id_q2_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K,    QK_NL, dequantize_q2_K>;
+ template [[host_name("kernel_mul_mm_id_q3_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K,    QK_NL, dequantize_q3_K>;
+ template [[host_name("kernel_mul_mm_id_q4_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K,    QK_NL, dequantize_q4_K>;
+ template [[host_name("kernel_mul_mm_id_q5_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K,    QK_NL, dequantize_q5_K>;
+ template [[host_name("kernel_mul_mm_id_q6_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K,    QK_NL, dequantize_q6_K>;
  template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
  template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
  template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
@@ -11705,9 +11705,8 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
      ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
      float * scales, float * weight, uint8_t * L,
      const int8_t * values,
-     const float * quant_weights) {
-
-     const int ntry = 7;
+     const float * quant_weights,
+     const int ntry) {
 
      float sigma2 = 0;
      for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j];
@@ -11719,6 +11718,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
      float max_scale = 0, amax_scale = 0;
      for (int ib = 0; ib < super_block_size/block_size; ++ib) {
          const float * xb = x + ib*block_size;
+         uint8_t * Lb = L + ib*block_size;
          if (quant_weights) {
              const float * qw = quant_weights + ib*block_size;
              for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
@@ -11736,12 +11736,13 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
              scales[ib] = 0;
              continue;
          }
-         float d = -max/values[0];
+         float d = ntry > 0 ? -max/values[0] : max/values[0];
          float id = 1/d;
          float sumqx = 0, sumq2 = 0;
          for (int j = 0; j < block_size; ++j) {
              float al = id*xb[j];
              int l = best_index_int8(16, values, al);
+             Lb[j] = l;
              float q = values[l];
              float w = weight[j];
              sumqx += w*q*xb[j];
@@ -11796,9 +11797,11 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
          }
      } else {
          dh[0] = GGML_FP32_TO_FP16(scales[0]);
-         float id = scales[0] ? 1/scales[0] : 0;
-         for (int j = 0; j < super_block_size; ++j) {
-             L[j] = best_index_int8(16, values, id*x[j]);
+         if (ntry > 0) {
+             float id = scales[0] ? 1/scales[0] : 0;
+             for (int j = 0; j < super_block_size; ++j) {
+                 L[j] = best_index_int8(16, values, id*x[j]);
+             }
          }
      }
 
@@ -11823,7 +11826,7 @@ size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int nrow
      for (int ibl = 0; ibl < nblock; ++ibl) {
          const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
          quantize_row_iq4_nl_impl(QK4_NL, 32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
-                 &scale, weight, L, kvalues_iq4nl, qw);
+                 &scale, weight, L, kvalues_iq4nl, qw, 7);
      }
      src += n_per_row;
      qrow += nblock*sizeof(block_iq4_nl);
@@ -11832,14 +11835,23 @@ size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int nrow
  }
 
  void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) {
-     assert(k % QK4_NL == 0);
-     block_iq4_nl * restrict y = vy;
-     quantize_row_iq4_nl_reference(x, y, k);
+     GGML_ASSERT(k%QK4_NL == 0);
+     int nblock = k/QK4_NL;
+     uint8_t L[QK4_NL];
+     float weight[QK4_NL];
+     uint16_t unused_h;
+     uint8_t * unused_l = NULL;
+     float scale;
+     block_iq4_nl * iq4 = (block_iq4_nl *)vy;
+     for (int ibl = 0; ibl < nblock; ++ibl) {
+         quantize_row_iq4_nl_impl(QK4_NL, 32, x + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
+                 &scale, weight, L, kvalues_iq4nl, NULL, -1);
+     }
  }
 
  void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int k) {
      assert(k % QK4_NL == 0);
-     quantize_iq4_nl(x, y, 1, k, NULL);
+     quantize_row_iq4_nl(x, y, k);
  }
 
  size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
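The net effect of the ntry plumbing: quantize_iq4_nl and quantize_iq4_xs keep the slow search (ntry = 7, negative initial scale, L filled after the final scale is known), while quantize_row_iq4_nl now runs a single-pass fit (ntry = -1, with Lb filled during the scan) instead of recursing through quantize_iq4_nl, and the reference function delegates to it. A hedged caller-side sketch of the fast path; the function and variable names here are illustrative, and the header layout assumes this release's vendored llama.cpp (block_iq4_nl and QK4_NL come in via ggml-quants.h):

```c
#include <stddef.h>
#include "ggml-quants.h"

// Quantize n rows of k floats each to IQ4_NL without importance weights,
// using the fast single-pass path (ntry < 0) that this release wires up.
static void quantize_rows(const float * data, block_iq4_nl * blocks, int n, int k) {
    // k must be a multiple of QK4_NL (32)
    for (int i = 0; i < n; ++i) {
        quantize_row_iq4_nl(data + (size_t)i*k, blocks + (size_t)i*(k/QK4_NL), k);
    }
}
```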
@@ -11857,7 +11869,7 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int nrow
      for (int ibl = 0; ibl < nblock; ++ibl) {
          const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL;
          quantize_row_iq4_nl_impl(QK_K, 32, src + QK_K*ibl, &iq4[ibl].d, iq4[ibl].qs, &iq4[ibl].scales_h, iq4[ibl].scales_l,
-                 scales, weight, L, kvalues_iq4nl, qw);
+                 scales, weight, L, kvalues_iq4nl, qw, 7);
      }
      src += n_per_row;
      qrow += nblock*sizeof(block_iq4_xs);