cui-llama.rn 1.2.3 → 1.2.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/cpp/ggml-metal.m CHANGED
@@ -241,6 +241,8 @@ enum lm_ggml_metal_kernel_type {
241
241
  LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
242
242
  LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16,
243
243
  LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32,
244
+ LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
245
+ LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
244
246
  LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
245
247
  LM_GGML_METAL_KERNEL_TYPE_PAD_F32,
246
248
  LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32,
@@ -272,6 +274,8 @@ enum lm_ggml_metal_kernel_type {
272
274
  LM_GGML_METAL_KERNEL_TYPE_SIN,
273
275
  LM_GGML_METAL_KERNEL_TYPE_COS,
274
276
  LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS,
277
+ LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
278
+ LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
275
279
 
276
280
  LM_GGML_METAL_KERNEL_TYPE_COUNT
277
281
  };
@@ -446,7 +450,14 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
446
450
  LM_GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
447
451
  return NULL;
448
452
  }
453
+
454
+ #if !__has_feature(objc_arc)
455
+ [options release];
456
+ #endif
449
457
  }
458
+ #if LM_GGML_METAL_EMBED_LIBRARY
459
+ [src release];
460
+ #endif // LM_GGML_METAL_EMBED_LIBRARY
450
461
  }
451
462
  }
452
463
 
@@ -685,6 +696,8 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
685
696
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
686
697
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
687
698
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
699
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
700
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
688
701
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
689
702
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
690
703
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
@@ -716,6 +729,8 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
716
729
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIN, sin, true);
717
730
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_COS, cos, true);
718
731
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
732
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
733
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
719
734
  }
720
735
 
721
736
  [metal_library release];
@@ -844,8 +859,8 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
844
859
  case LM_GGML_OP_IM2COL:
845
860
  return op->src[0]->type == LM_GGML_TYPE_F16;
846
861
  case LM_GGML_OP_POOL_1D:
847
- case LM_GGML_OP_POOL_2D:
848
862
  return false;
863
+ case LM_GGML_OP_POOL_2D:
849
864
  case LM_GGML_OP_UPSCALE:
850
865
  case LM_GGML_OP_PAD:
851
866
  case LM_GGML_OP_ARANGE:
@@ -1007,19 +1022,21 @@ static void lm_ggml_metal_encode_node(
1007
1022
  id<MTLBuffer> id_src2 = src2 ? lm_ggml_metal_get_buffer(src2, &offs_src2) : nil;
1008
1023
  id<MTLBuffer> id_dst = dst ? lm_ggml_metal_get_buffer(dst, &offs_dst) : nil;
1009
1024
 
1010
- //LM_GGML_LOG_INFO("%s: op - %s\n", __func__, lm_ggml_op_name(dst->op));
1011
- //if (src0) {
1012
- // LM_GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, lm_ggml_type_name(src0t), ne00, ne01, ne02,
1013
- // lm_ggml_is_contiguous(src0), src0->name);
1014
- //}
1015
- //if (src1) {
1016
- // LM_GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, lm_ggml_type_name(src1t), ne10, ne11, ne12,
1017
- // lm_ggml_is_contiguous(src1), src1->name);
1018
- //}
1019
- //if (dst) {
1020
- // LM_GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, lm_ggml_type_name(dstt), ne0, ne1, ne2,
1021
- // dst->name);
1022
- //}
1025
+ #if 0
1026
+ LM_GGML_LOG_INFO("%s: op - %s\n", __func__, lm_ggml_op_name(dst->op));
1027
+ if (src0) {
1028
+ LM_GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, lm_ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
1029
+ lm_ggml_is_contiguous(src0), src0->name);
1030
+ }
1031
+ if (src1) {
1032
+ LM_GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, lm_ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
1033
+ lm_ggml_is_contiguous(src1), src1->name);
1034
+ }
1035
+ if (dst) {
1036
+ LM_GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, lm_ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
1037
+ dst->name);
1038
+ }
1039
+ #endif
1023
1040
 
1024
1041
  id<MTLDevice> device = ctx_dev->mtl_device;
1025
1042
 
@@ -1802,14 +1819,16 @@ static void lm_ggml_metal_encode_node(
1802
1819
  [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1803
1820
  [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
1804
1821
  [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
1805
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1806
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1807
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1808
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1809
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1810
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1811
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1812
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1822
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7];
1823
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
1824
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9];
1825
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10];
1826
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11];
1827
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12];
1828
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1829
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
1830
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:15];
1831
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:16];
1813
1832
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1814
1833
  [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1815
1834
  } else {
@@ -1978,20 +1997,22 @@ static void lm_ggml_metal_encode_node(
1978
1997
  [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1979
1998
  [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1980
1999
  [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1981
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
1982
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
1983
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
1984
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
1985
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
1986
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
1987
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
1988
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1989
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1990
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
2000
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2001
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
2002
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
2003
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
2004
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
2005
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
2006
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
2007
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
2008
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
2009
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
2010
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:19];
2011
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:20];
1991
2012
 
1992
2013
  if (src0t == LM_GGML_TYPE_Q4_0 || src0t == LM_GGML_TYPE_Q4_1 || src0t == LM_GGML_TYPE_Q5_0 ||
1993
- src0t == LM_GGML_TYPE_Q5_1 || src0t == LM_GGML_TYPE_Q8_0 || src0t == LM_GGML_TYPE_Q2_K ||
1994
- src0t == LM_GGML_TYPE_IQ1_S || src0t == LM_GGML_TYPE_IQ1_M || src0t == LM_GGML_TYPE_IQ2_S) {
2014
+ src0t == LM_GGML_TYPE_Q5_1 || src0t == LM_GGML_TYPE_Q8_0 || src0t == LM_GGML_TYPE_Q2_K ||
2015
+ src0t == LM_GGML_TYPE_IQ1_S || src0t == LM_GGML_TYPE_IQ1_M || src0t == LM_GGML_TYPE_IQ2_S) {
1995
2016
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1996
2017
  }
1997
2018
  else if (src0t == LM_GGML_TYPE_IQ2_XXS || src0t == LM_GGML_TYPE_IQ2_XS) {
@@ -2040,6 +2061,9 @@ static void lm_ggml_metal_encode_node(
2040
2061
 
2041
2062
  LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
2042
2063
 
2064
+ LM_GGML_ASSERT(ne03 == 1);
2065
+ LM_GGML_ASSERT(ne13 == 1);
2066
+
2043
2067
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
2044
2068
  // to the matrix-vector kernel
2045
2069
  // ne20 = n_used_experts
@@ -2545,6 +2569,8 @@ static void lm_ggml_metal_encode_node(
2545
2569
  } break;
2546
2570
  case LM_GGML_OP_IM2COL:
2547
2571
  {
2572
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(src0));
2573
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(src1));
2548
2574
  LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16);
2549
2575
  LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
2550
2576
  LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F16 || dst->type == LM_GGML_TYPE_F32);
@@ -2574,30 +2600,54 @@ static void lm_ggml_metal_encode_node(
2574
2600
  const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
2575
2601
  const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
2576
2602
 
2577
- id<MTLComputePipelineState> pipeline = nil;
2603
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
2604
+
2605
+ const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
2578
2606
 
2579
2607
  switch (dst->type) {
2580
- case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
2581
- case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
2608
+ case LM_GGML_TYPE_F32: {
2609
+ pipeline = (is_gt_mttpt ?
2610
+ ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
2611
+ :
2612
+ ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
2613
+ } break;
2614
+ case LM_GGML_TYPE_F16: {
2615
+ pipeline = (is_gt_mttpt ?
2616
+ ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
2617
+ :
2618
+ ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
2619
+ } break;
2582
2620
  default: LM_GGML_ABORT("fatal error");
2583
2621
  };
2584
2622
 
2585
2623
  [encoder setComputePipelineState:pipeline];
2586
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
2587
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2588
- [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
2589
- [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
2590
- [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
2591
- [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
2592
- [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
2593
- [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
2594
- [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
2595
- [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
2596
- [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
2597
- [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
2598
- [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
2599
-
2600
- [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
2624
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
2625
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2626
+ [encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2];
2627
+ [encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
2628
+ [encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
2629
+ [encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
2630
+ [encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
2631
+ [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
2632
+ [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
2633
+ [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
2634
+ [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
2635
+ [encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
2636
+ [encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];
2637
+
2638
+ if (is_gt_mttpt) {
2639
+ [encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
2640
+ [encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
2641
+ [encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
2642
+
2643
+ const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
2644
+
2645
+ const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
2646
+
2647
+ [encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
2648
+ } else {
2649
+ [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
2650
+ }
2601
2651
  } break;
2602
2652
  case LM_GGML_OP_UPSCALE:
2603
2653
  {
@@ -3001,6 +3051,64 @@ static void lm_ggml_metal_encode_node(
3001
3051
 
3002
3052
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3003
3053
  } break;
3054
+ case LM_GGML_OP_POOL_2D:
3055
+ {
3056
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(src0));
3057
+ LM_GGML_ASSERT(src0t == LM_GGML_TYPE_F32 && src0t == dstt);
3058
+
3059
+ const int32_t * opts = dst->op_params;
3060
+ enum lm_ggml_op_pool op = opts[0];
3061
+
3062
+ id<MTLComputePipelineState> pipeline = nil;
3063
+ switch (src0t) {
3064
+ case LM_GGML_TYPE_F32: {
3065
+ switch(op) {
3066
+ case LM_GGML_OP_POOL_AVG:
3067
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
3068
+ case LM_GGML_OP_POOL_MAX:
3069
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
3070
+ default: LM_GGML_ASSERT(false && "not implemented");
3071
+ }
3072
+ } break;
3073
+ default: LM_GGML_ASSERT(false && "not implemented");
3074
+ }
3075
+
3076
+ const int32_t k0 = opts[1];
3077
+ const int32_t k1 = opts[2];
3078
+ const int32_t s0 = opts[3];
3079
+ const int32_t s1 = opts[4];
3080
+ const int32_t p0 = opts[5];
3081
+ const int32_t p1 = opts[6];
3082
+
3083
+ const int64_t IH = src0->ne[1];
3084
+ const int64_t IW = src0->ne[0];
3085
+
3086
+ const int64_t N = dst->ne[3];
3087
+ const int64_t OC = dst->ne[2];
3088
+ const int64_t OH = dst->ne[1];
3089
+ const int64_t OW = dst->ne[0];
3090
+
3091
+ const int64_t parallel_elements = N * OC * OH * OW;
3092
+ const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
3093
+ const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
3094
+
3095
+ [encoder setComputePipelineState:pipeline];
3096
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3097
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3098
+ [encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
3099
+ [encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
3100
+ [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
3101
+ [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
3102
+ [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
3103
+ [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
3104
+ [encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
3105
+ [encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
3106
+ [encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
3107
+ [encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
3108
+ [encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
3109
+
3110
+ [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
3111
+ } break;
3004
3112
  default:
3005
3113
  {
3006
3114
  LM_GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, lm_ggml_op_name(dst->op));
@@ -3146,12 +3254,6 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
3146
3254
 
3147
3255
  // backend interface
3148
3256
 
3149
- static const char * lm_ggml_backend_metal_buffer_get_name(lm_ggml_backend_buffer_t buffer) {
3150
- return "Metal";
3151
-
3152
- UNUSED(buffer);
3153
- }
3154
-
3155
3257
  static void lm_ggml_backend_metal_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) {
3156
3258
  struct lm_ggml_backend_metal_buffer_context * ctx = (struct lm_ggml_backend_metal_buffer_context *)buffer->context;
3157
3259
 
@@ -3206,7 +3308,6 @@ static void lm_ggml_backend_metal_buffer_clear(lm_ggml_backend_buffer_t buffer,
3206
3308
  }
3207
3309
 
3208
3310
  static struct lm_ggml_backend_buffer_i lm_ggml_backend_metal_buffer_i = {
3209
- /* .get_name = */ lm_ggml_backend_metal_buffer_get_name,
3210
3311
  /* .free_buffer = */ lm_ggml_backend_metal_buffer_free_buffer,
3211
3312
  /* .get_base = */ lm_ggml_backend_metal_buffer_get_base,
3212
3313
  /* .init_tensor = */ NULL,
@@ -3331,6 +3432,29 @@ lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_type(void) {
3331
3432
  return &lm_ggml_backend_buffer_type_metal;
3332
3433
  }
3333
3434
 
3435
+ static const char * lm_ggml_backend_metal_buffer_from_ptr_type_get_name(lm_ggml_backend_buffer_type_t buft) {
3436
+ return "Metal_Mapped";
3437
+
3438
+ UNUSED(buft);
3439
+ }
3440
+
3441
+ static lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_from_ptr_type(void) {
3442
+ static struct lm_ggml_backend_buffer_type lm_ggml_backend_buffer_from_ptr_type_metal = {
3443
+ /* .iface = */ {
3444
+ /* .get_name = */ lm_ggml_backend_metal_buffer_from_ptr_type_get_name,
3445
+ /* .alloc_buffer = */ lm_ggml_backend_metal_buffer_type_alloc_buffer,
3446
+ /* .get_alignment = */ lm_ggml_backend_metal_buffer_type_get_alignment,
3447
+ /* .get_max_size = */ lm_ggml_backend_metal_buffer_type_get_max_size,
3448
+ /* .get_alloc_size = */ NULL, // defaults to lm_ggml_nbytes
3449
+ /* .is_host = */ lm_ggml_backend_metal_buffer_type_is_host,
3450
+ },
3451
+ /* .device = */ &g_lm_ggml_backend_metal_device,
3452
+ /* .context = */ NULL,
3453
+ };
3454
+
3455
+ return &lm_ggml_backend_buffer_from_ptr_type_metal;
3456
+ }
3457
+
3334
3458
  // TODO: obsoleted by lm_ggml_backend_metal_device_buffer_from_ptr
3335
3459
  lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
3336
3460
  struct lm_ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct lm_ggml_backend_metal_buffer_context));
@@ -3407,7 +3531,7 @@ lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void * data, size
3407
3531
  }
3408
3532
  }
3409
3533
 
3410
- return lm_ggml_backend_buffer_init(lm_ggml_backend_metal_buffer_type(), lm_ggml_backend_metal_buffer_i, ctx, size);
3534
+ return lm_ggml_backend_buffer_init(lm_ggml_backend_metal_buffer_from_ptr_type(), lm_ggml_backend_metal_buffer_i, ctx, size);
3411
3535
  }
3412
3536
 
3413
3537
  // backend
@@ -3428,12 +3552,6 @@ static void lm_ggml_backend_metal_free(lm_ggml_backend_t backend) {
3428
3552
  free(backend);
3429
3553
  }
3430
3554
 
3431
- static lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_get_default_buffer_type(lm_ggml_backend_t backend) {
3432
- return lm_ggml_backend_metal_buffer_type();
3433
-
3434
- UNUSED(backend);
3435
- }
3436
-
3437
3555
  static enum lm_ggml_status lm_ggml_backend_metal_graph_compute(lm_ggml_backend_t backend, struct lm_ggml_cgraph * cgraph) {
3438
3556
  return lm_ggml_metal_graph_compute(backend, cgraph);
3439
3557
  }
@@ -3500,7 +3618,6 @@ static void lm_ggml_backend_metal_set_n_cb(lm_ggml_backend_t backend, int n_cb)
3500
3618
  static struct lm_ggml_backend_i lm_ggml_backend_metal_i = {
3501
3619
  /* .get_name = */ lm_ggml_backend_metal_name,
3502
3620
  /* .free = */ lm_ggml_backend_metal_free,
3503
- /* .get_default_buffer_type = */ lm_ggml_backend_metal_get_default_buffer_type,
3504
3621
  /* .set_tensor_async = */ NULL,
3505
3622
  /* .get_tensor_async = */ NULL,
3506
3623
  /* .cpy_tensor_async = */ NULL,
@@ -3510,9 +3627,6 @@ static struct lm_ggml_backend_i lm_ggml_backend_metal_i = {
3510
3627
  /* .graph_plan_update = */ NULL,
3511
3628
  /* .graph_plan_compute = */ NULL,
3512
3629
  /* .graph_compute = */ lm_ggml_backend_metal_graph_compute,
3513
- /* .supports_op = */ NULL,
3514
- /* .supports_buft = */ NULL,
3515
- /* .offload_op = */ NULL,
3516
3630
  /* .event_record = */ NULL,
3517
3631
  /* .event_wait = */ NULL,
3518
3632
  };
@@ -3607,7 +3721,7 @@ static void lm_ggml_backend_metal_device_get_memory(lm_ggml_backend_dev_t dev, s
3607
3721
  }
3608
3722
 
3609
3723
  static enum lm_ggml_backend_dev_type lm_ggml_backend_metal_device_get_type(lm_ggml_backend_dev_t dev) {
3610
- return LM_GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
3724
+ return LM_GGML_BACKEND_DEVICE_TYPE_GPU;
3611
3725
 
3612
3726
  LM_GGML_UNUSED(dev);
3613
3727
  }
package/cpp/ggml-quants.c CHANGED
@@ -4,7 +4,7 @@
4
4
  #include "ggml-quants.h"
5
5
  #include "ggml-impl.h"
6
6
  #include "ggml-cpu-impl.h"
7
-
7
+ #include "ggml-cpu.h"
8
8
 
9
9
  #include <math.h>
10
10
  #include <string.h>
@@ -9104,10 +9104,8 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
9104
9104
 
9105
9105
  #elif defined __AVX__
9106
9106
 
9107
- const __m128i m4 = _mm_set1_epi8(0xF);
9108
9107
  const __m128i m3 = _mm_set1_epi8(3);
9109
- const __m128i m32s = _mm_set1_epi8(32);
9110
- const __m128i m2 = _mm_set1_epi8(2);
9108
+ const __m128i m15 = _mm_set1_epi8(15);
9111
9109
 
9112
9110
  __m256 acc = _mm256_setzero_ps();
9113
9111
 
@@ -9119,12 +9117,20 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
9119
9117
  const uint8_t * restrict qh = x[i].qh;
9120
9118
  const int8_t * restrict q8 = y[i].qs;
9121
9119
 
9120
+ // handle the q6_k -32 offset separately using bsums
9121
+ const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums);
9122
+ const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1);
9122
9123
  const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
9124
+ const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
9125
+ const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
9126
+ const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5);
9127
+ const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5);
9123
9128
 
9124
9129
  __m128i sumi_0 = _mm_setzero_si128();
9125
9130
  __m128i sumi_1 = _mm_setzero_si128();
9126
9131
 
9127
- __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
9132
+ int is = 0;
9133
+
9128
9134
  for (int j = 0; j < QK_K/128; ++j) {
9129
9135
 
9130
9136
  const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
@@ -9132,26 +9138,26 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
9132
9138
 
9133
9139
  const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
9134
9140
  const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
9135
- const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
9136
- const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
9137
- const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4);
9138
- const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4);
9139
- const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
9140
- const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
9141
+ const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2);
9142
+ const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2);
9143
+ const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48));
9144
+ const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48));
9145
+ const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2);
9146
+ const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2);
9141
9147
 
9142
9148
  const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
9143
9149
  const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
9144
9150
  const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
9145
9151
  const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
9146
9152
 
9147
- const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0);
9148
- const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1);
9149
- const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2);
9150
- const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3);
9151
- const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4);
9152
- const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5);
9153
- const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6);
9154
- const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7);
9153
+ const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
9154
+ const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
9155
+ const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
9156
+ const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
9157
+ const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
9158
+ const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
9159
+ const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
9160
+ const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);
9155
9161
 
9156
9162
  const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
9157
9163
  const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
@@ -9162,15 +9168,6 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
9162
9168
  const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
9163
9169
  const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
9164
9170
 
9165
- __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
9166
- __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
9167
- __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
9168
- __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
9169
- __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
9170
- __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
9171
- __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
9172
- __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
9173
-
9174
9171
  __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
9175
9172
  __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
9176
9173
  __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
@@ -9180,32 +9177,20 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
9180
9177
  __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
9181
9178
  __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
9182
9179
 
9183
- p16_0 = _mm_sub_epi16(p16_0, q8s_0);
9184
- p16_1 = _mm_sub_epi16(p16_1, q8s_1);
9185
- p16_2 = _mm_sub_epi16(p16_2, q8s_2);
9186
- p16_3 = _mm_sub_epi16(p16_3, q8s_3);
9187
- p16_4 = _mm_sub_epi16(p16_4, q8s_4);
9188
- p16_5 = _mm_sub_epi16(p16_5, q8s_5);
9189
- p16_6 = _mm_sub_epi16(p16_6, q8s_6);
9190
- p16_7 = _mm_sub_epi16(p16_7, q8s_7);
9191
-
9192
- const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
9193
- shuffle = _mm_add_epi8(shuffle, m2);
9194
- const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
9195
- shuffle = _mm_add_epi8(shuffle, m2);
9196
- const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
9197
- shuffle = _mm_add_epi8(shuffle, m2);
9198
- const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
9199
- shuffle = _mm_add_epi8(shuffle, m2);
9180
+ const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
9181
+ const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
9182
+ const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
9183
+ const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
9184
+ is += 4;
9200
9185
 
9201
9186
  p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
9202
- p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
9187
+ p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1);
9203
9188
  p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
9204
- p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
9189
+ p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3);
9205
9190
  p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
9206
- p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
9191
+ p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5);
9207
9192
  p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
9208
- p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7);
9193
+ p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7);
9209
9194
 
9210
9195
  sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
9211
9196
  sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
@@ -9214,8 +9199,10 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
9214
9199
 
9215
9200
  }
9216
9201
 
9217
- __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
9218
- acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
9202
+ sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0);
9203
+ sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1);
9204
+ const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
9205
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc);
9219
9206
  }
9220
9207
 
9221
9208
  *s = hsum_float_8(acc);