cui-llama.rn 1.1.0 → 1.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/cpp/common.cpp CHANGED
@@ -333,6 +333,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
333
333
  void gpt_params_parse_from_env(gpt_params & params) {
334
334
  // we only care about server-related params for now
335
335
  get_env("LLAMA_ARG_MODEL", params.model);
336
+ get_env("LLAMA_ARG_MODEL_URL", params.model_url);
337
+ get_env("LLAMA_ARG_MODEL_ALIAS", params.model_alias);
338
+ get_env("LLAMA_ARG_HF_REPO", params.hf_repo);
339
+ get_env("LLAMA_ARG_HF_FILE", params.hf_file);
336
340
  get_env("LLAMA_ARG_THREADS", params.n_threads);
337
341
  get_env("LLAMA_ARG_CTX_SIZE", params.n_ctx);
338
342
  get_env("LLAMA_ARG_N_PARALLEL", params.n_parallel);
@@ -347,6 +351,9 @@ void gpt_params_parse_from_env(gpt_params & params) {
347
351
  get_env("LLAMA_ARG_EMBEDDINGS", params.embedding);
348
352
  get_env("LLAMA_ARG_FLASH_ATTN", params.flash_attn);
349
353
  get_env("LLAMA_ARG_DEFRAG_THOLD", params.defrag_thold);
354
+ get_env("LLAMA_ARG_CONT_BATCHING", params.cont_batching);
355
+ get_env("LLAMA_ARG_HOST", params.hostname);
356
+ get_env("LLAMA_ARG_PORT", params.port);
350
357
  }
351
358
 
352
359
  bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
@@ -907,7 +914,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
907
914
  }
908
915
  return true;
909
916
  }
910
- if (arg == "-ngld" || arg == "--gpu-layers-draft" || arg == "--gpu-layers-draft") {
917
+ if (arg == "-ngld" || arg == "--gpu-layers-draft" || arg == "--n-gpu-layers-draft") {
911
918
  CHECK_ARG
912
919
  params.n_gpu_layers_draft = std::stoi(argv[i]);
913
920
  if (!llama_supports_gpu_offload()) {
@@ -1867,13 +1874,19 @@ std::string string_get_sortable_timestamp() {
1867
1874
 
1868
1875
  void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
1869
1876
  if (search.empty()) {
1870
- return; // Avoid infinite loop if 'search' is an empty string
1877
+ return;
1871
1878
  }
1879
+ std::string builder;
1880
+ builder.reserve(s.length());
1872
1881
  size_t pos = 0;
1873
- while ((pos = s.find(search, pos)) != std::string::npos) {
1874
- s.replace(pos, search.length(), replace);
1875
- pos += replace.length();
1876
- }
1882
+ size_t last_pos = 0;
1883
+ while ((pos = s.find(search, last_pos)) != std::string::npos) {
1884
+ builder.append(s, last_pos, pos - last_pos);
1885
+ builder.append(replace);
1886
+ last_pos = pos + search.length();
1887
+ }
1888
+ builder.append(s, last_pos, std::string::npos);
1889
+ s = std::move(builder);
1877
1890
  }
1878
1891
 
1879
1892
  void string_process_escapes(std::string & input) {
package/cpp/ggml-aarch64.c CHANGED
@@ -337,33 +337,18 @@ static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict ds
337
337
  }
338
338
 
339
339
  size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
340
- if (!quant_weights) {
341
- return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4);
342
- }
343
- else {
344
- assert(false);
345
- return 0;
346
- }
340
+ UNUSED(quant_weights);
341
+ return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4);
347
342
  }
348
343
 
349
344
  size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
350
- if (!quant_weights) {
351
- return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8);
352
- }
353
- else {
354
- assert(false);
355
- return 0;
356
- }
345
+ UNUSED(quant_weights);
346
+ return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8);
357
347
  }
358
348
 
359
349
  size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
360
- if (!quant_weights) {
361
- return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8);
362
- }
363
- else {
364
- assert(false);
365
- return 0;
366
- }
350
+ UNUSED(quant_weights);
351
+ return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8);
367
352
  }
368
353
 
369
354
  void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
package/cpp/ggml-metal.m CHANGED
@@ -82,6 +82,8 @@ enum lm_ggml_metal_kernel_type {
82
82
  LM_GGML_METAL_KERNEL_TYPE_RMS_NORM,
83
83
  LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM,
84
84
  LM_GGML_METAL_KERNEL_TYPE_NORM,
85
+ LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
86
+ LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
85
87
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
86
88
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
87
89
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
@@ -542,6 +544,8 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(int n_cb) {
542
544
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
543
545
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
544
546
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_NORM, norm, true);
547
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
548
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
545
549
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
546
550
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
547
551
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
@@ -803,6 +807,9 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_context
803
807
  return false;
804
808
  }
805
809
  return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
810
+ case LM_GGML_OP_SSM_CONV:
811
+ case LM_GGML_OP_SSM_SCAN:
812
+ return true;
806
813
  case LM_GGML_OP_MUL_MAT:
807
814
  case LM_GGML_OP_MUL_MAT_ID:
808
815
  return ctx->support_simdgroup_reduction &&
@@ -1538,6 +1545,121 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
1538
1545
  [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1539
1546
  }
1540
1547
  } break;
1548
+ case LM_GGML_OP_SSM_CONV:
1549
+ {
1550
+ LM_GGML_ASSERT(src0t == LM_GGML_TYPE_F32);
1551
+ LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
1552
+
1553
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(src0));
1554
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(src1));
1555
+
1556
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
1557
+
1558
+ [encoder setComputePipelineState:pipeline];
1559
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1560
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1561
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1562
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1563
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1564
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1565
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1566
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1567
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1568
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
1569
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
1570
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
1571
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
1572
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1573
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
1574
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15];
1575
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16];
1576
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17];
1577
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18];
1578
+
1579
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1580
+ } break;
1581
+ case LM_GGML_OP_SSM_SCAN:
1582
+ {
1583
+ struct lm_ggml_tensor * src3 = gf->nodes[i]->src[3];
1584
+ struct lm_ggml_tensor * src4 = gf->nodes[i]->src[4];
1585
+ struct lm_ggml_tensor * src5 = gf->nodes[i]->src[5];
1586
+
1587
+ LM_GGML_ASSERT(src3);
1588
+ LM_GGML_ASSERT(src4);
1589
+ LM_GGML_ASSERT(src5);
1590
+
1591
+ size_t offs_src3 = 0;
1592
+ size_t offs_src4 = 0;
1593
+ size_t offs_src5 = 0;
1594
+
1595
+ id<MTLBuffer> id_src3 = src3 ? lm_ggml_metal_get_buffer(src3, &offs_src3) : nil;
1596
+ id<MTLBuffer> id_src4 = src4 ? lm_ggml_metal_get_buffer(src4, &offs_src4) : nil;
1597
+ id<MTLBuffer> id_src5 = src5 ? lm_ggml_metal_get_buffer(src5, &offs_src5) : nil;
1598
+
1599
+ const int64_t ne30 = src3->ne[0]; LM_GGML_UNUSED(ne30);
1600
+ const int64_t ne31 = src3->ne[1]; LM_GGML_UNUSED(ne31);
1601
+
1602
+ const uint64_t nb30 = src3->nb[0];
1603
+ const uint64_t nb31 = src3->nb[1];
1604
+
1605
+ const int64_t ne40 = src4->ne[0]; LM_GGML_UNUSED(ne40);
1606
+ const int64_t ne41 = src4->ne[1]; LM_GGML_UNUSED(ne41);
1607
+ const int64_t ne42 = src4->ne[2]; LM_GGML_UNUSED(ne42);
1608
+
1609
+ const uint64_t nb40 = src4->nb[0];
1610
+ const uint64_t nb41 = src4->nb[1];
1611
+ const uint64_t nb42 = src4->nb[2];
1612
+
1613
+ const int64_t ne50 = src5->ne[0]; LM_GGML_UNUSED(ne50);
1614
+ const int64_t ne51 = src5->ne[1]; LM_GGML_UNUSED(ne51);
1615
+ const int64_t ne52 = src5->ne[2]; LM_GGML_UNUSED(ne52);
1616
+
1617
+ const uint64_t nb50 = src5->nb[0];
1618
+ const uint64_t nb51 = src5->nb[1];
1619
+ const uint64_t nb52 = src5->nb[2];
1620
+
1621
+ const int64_t d_state = ne00;
1622
+ const int64_t d_inner = ne01;
1623
+ const int64_t n_seq_tokens = ne11;
1624
+ const int64_t n_seqs = ne02;
1625
+
1626
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
1627
+
1628
+ [encoder setComputePipelineState:pipeline];
1629
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1630
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1631
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
1632
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
1633
+ [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
1634
+ [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
1635
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
1636
+
1637
+ [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7];
1638
+ [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8];
1639
+ [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
1640
+ [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10];
1641
+
1642
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
1643
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
1644
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
1645
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1646
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1647
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1648
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
1649
+ [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
1650
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
1651
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
1652
+ [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
1653
+ [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
1654
+ [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
1655
+ [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
1656
+ [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
1657
+ [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
1658
+ [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
1659
+ [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
1660
+
1661
+ [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1662
+ } break;
1541
1663
  case LM_GGML_OP_MUL_MAT:
1542
1664
  {
1543
1665
  LM_GGML_ASSERT(ne00 == ne10);
@@ -2624,9 +2746,14 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
2624
2746
 
2625
2747
  float scale;
2626
2748
  float max_bias;
2749
+ float logit_softcap;
2750
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
2751
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
2752
+ memcpy(&logit_softcap, ((int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
2627
2753
 
2628
- memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
2629
- memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
2754
+ if (logit_softcap != 0.0f) {
2755
+ scale /= logit_softcap;
2756
+ }
2630
2757
 
2631
2758
  const uint32_t n_head = src0->ne[2];
2632
2759
  const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
@@ -2677,30 +2804,31 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
2677
2804
  } else {
2678
2805
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
2679
2806
  }
2680
- [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2681
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
2682
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
2683
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
2684
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
2685
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
2686
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
2687
- [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
2688
- [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
2689
- [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
2690
- [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
2691
- [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
2692
- [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
2693
- [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
2694
- [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
2695
- [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
2696
- [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
2697
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
2698
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
2699
- [encoder setBytes:&scale length:sizeof( float) atIndex:23];
2700
- [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
2701
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
2702
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
2703
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
2807
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2808
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
2809
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
2810
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
2811
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
2812
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
2813
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
2814
+ [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
2815
+ [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
2816
+ [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
2817
+ [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
2818
+ [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
2819
+ [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
2820
+ [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
2821
+ [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
2822
+ [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
2823
+ [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
2824
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
2825
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
2826
+ [encoder setBytes:&scale length:sizeof( float) atIndex:23];
2827
+ [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
2828
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
2829
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
2830
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
2831
+ [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
2704
2832
 
2705
2833
  if (!use_vec_kernel) {
2706
2834
  // half8x8 kernel