@fugood/llama.node 1.0.2 → 1.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. package/package.json +14 -14
  2. package/src/llama.cpp/CMakeLists.txt +0 -1
  3. package/src/llama.cpp/common/arg.cpp +7 -0
  4. package/src/llama.cpp/common/common.h +1 -0
  5. package/src/llama.cpp/ggml/CMakeLists.txt +7 -2
  6. package/src/llama.cpp/ggml/include/ggml.h +91 -10
  7. package/src/llama.cpp/ggml/src/CMakeLists.txt +0 -1
  8. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
  9. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +12 -1
  10. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +726 -155
  11. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +5 -0
  12. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  13. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +9 -9
  14. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +49 -9
  15. package/src/llama.cpp/include/llama.h +1 -0
  16. package/src/llama.cpp/src/llama-arch.cpp +90 -2
  17. package/src/llama.cpp/src/llama-arch.h +6 -0
  18. package/src/llama.cpp/src/llama-batch.cpp +27 -1
  19. package/src/llama.cpp/src/llama-batch.h +8 -1
  20. package/src/llama.cpp/src/llama-chat.cpp +15 -0
  21. package/src/llama.cpp/src/llama-chat.h +1 -0
  22. package/src/llama.cpp/src/llama-graph.cpp +64 -50
  23. package/src/llama.cpp/src/llama-graph.h +41 -16
  24. package/src/llama.cpp/src/llama-hparams.cpp +2 -1
  25. package/src/llama.cpp/src/llama-hparams.h +1 -0
  26. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +28 -18
  27. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +4 -2
  28. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +214 -65
  29. package/src/llama.cpp/src/llama-kv-cache-unified.h +62 -24
  30. package/src/llama.cpp/src/llama-kv-cells.h +62 -10
  31. package/src/llama.cpp/src/llama-memory-hybrid.cpp +9 -4
  32. package/src/llama.cpp/src/llama-memory-hybrid.h +3 -1
  33. package/src/llama.cpp/src/llama-memory-recurrent.cpp +15 -2
  34. package/src/llama.cpp/src/llama-memory.cpp +17 -0
  35. package/src/llama.cpp/src/llama-memory.h +3 -0
  36. package/src/llama.cpp/src/llama-model.cpp +1234 -248
  37. package/src/llama.cpp/src/llama-model.h +2 -0
  38. package/src/llama.cpp/src/llama-vocab.cpp +8 -1
  39. package/src/llama.cpp/ggml/include/ggml-kompute.h +0 -50
@@ -3,6 +3,7 @@
  #include "ggml-cpu.h"
  #include "ggml-impl.h"
  #include "binary-ops.h"
+ #include "ggml.h"
  #include "unary-ops.h"
  #include "vec.h"

@@ -3613,6 +3614,292 @@ static void ggml_compute_forward_swiglu(
  }
  }

+ // ggml_compute_forward_geglu_erf
+
+ static void ggml_compute_forward_geglu_erf_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = ggml_nrows(src0);
+
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == nr);
+
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * src0_p = (float *) (src0_d + i1*src0_o);
+ float * src1_p = (float *) (src1_d + i1*src1_o);
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ GGML_UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+ #endif
+ }
+ }
+
+ static void ggml_compute_forward_geglu_erf_f16(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = ggml_nrows(src0);
+
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == nr);
+
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ const float v = GGML_FP16_TO_FP32(x);
+ GGML_UNUSED(v);
+ assert(!isnan(v));
+ assert(!isinf(v));
+ }
+ #endif
+ }
+ }
+
+ static void ggml_compute_forward_geglu_erf(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_geglu_erf_f32(params, dst);
+ } break;
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_geglu_erf_f16(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+ }
+
+ // ggml_compute_forward_geglu_quick
+
+ static void ggml_compute_forward_geglu_quick_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = ggml_nrows(src0);
+
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == nr);
+
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * src0_p = (float *) (src0_d + i1*src0_o);
+ float * src1_p = (float *) (src1_d + i1*src1_o);
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ GGML_UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+ #endif
+ }
+ }
+
+ static void ggml_compute_forward_geglu_quick_f16(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = ggml_nrows(src0);
+
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == nr);
+
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ const float v = GGML_FP16_TO_FP32(x);
+ GGML_UNUSED(v);
+ assert(!isnan(v));
+ assert(!isinf(v));
+ }
+ #endif
+ }
+ }
+
+ static void ggml_compute_forward_geglu_quick(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_geglu_quick_f32(params, dst);
+ } break;
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_geglu_quick_f16(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+ }
+
  // ggml_compute_forward_norm

  static void ggml_compute_forward_norm_f32(
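Note: the two new GLU kernels above follow the usual GEGLU construction, where one half of each row is passed through a GELU variant and gates the other half by element-wise multiplication. A scalar sketch of what the ggml_vec_geglu_erf_f32 / ggml_vec_geglu_quick_f32 helpers (added to vec.h/vec.cpp in this same release) are assumed to compute, using the exact erf-based GELU and the sigmoid "quick" approximation:

#include <math.h>

// Illustrative scalar reference only; the shipped versions are the vectorized
// helpers declared in ggml-cpu/vec.h.
static void ref_geglu_erf_f32(int n, float * y, const float * x, const float * g) {
    for (int i = 0; i < n; ++i) {
        const float gelu = 0.5f*x[i]*(1.0f + erff(x[i]*0.70710678f)); // 0.5*x*(1+erf(x/sqrt(2)))
        y[i] = gelu * g[i];
    }
}

static void ref_geglu_quick_f32(int n, float * y, const float * x, const float * g) {
    for (int i = 0; i < n; ++i) {
        const float gelu = x[i] / (1.0f + expf(-1.702f*x[i])); // x*sigmoid(1.702*x)
        y[i] = gelu * g[i];
    }
}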
@@ -5231,14 +5518,17 @@ static void ggml_compute_forward_soft_max_f32(
  memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
  memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));

- // TODO: handle transposed/permuted matrices
-
  const int ith = params->ith;
  const int nth = params->nth;

  GGML_TENSOR_UNARY_OP_LOCALS

- //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+ const int64_t nb11 = src1 ? src1->nb[1] : 1;
+ const int64_t nb12 = src1 ? src1->nb[2] : 1;
+ const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+ const int64_t ne12 = src1 ? src1->ne[2] : 1;
+ const int64_t ne13 = src1 ? src1->ne[3] : 1;

  // TODO: is this supposed to be ceil instead of floor?
  // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -5248,68 +5538,66 @@ static void ggml_compute_forward_soft_max_f32(
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

- const int nc = src0->ne[0];
- const int nr = ggml_nrows(src0);
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+ float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;

  const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);

- for (int i1 = ir0; i1 < ir1; i1++) {
- // ALiBi
- const uint32_t h = (i1/ne01)%ne02; // head
- const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
- float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
- float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
-
- // broadcast the mask across rows
- ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
- float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-
- ggml_vec_cpy_f32 (nc, wp, sp);
- ggml_vec_scale_f32(nc, wp, scale);
- if (mp_f32) {
- if (use_f16) {
- for (int i = 0; i < nc; ++i) {
- wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
- }
- } else {
- for (int i = 0; i < nc; ++i) {
- wp[i] += slope*mp_f32[i];
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+ const int64_t i11 = i01;
+ const int64_t i12 = i02%ne12;
+ const int64_t i13 = i03%ne13;
+
+ // ALiBi
+ const uint32_t h = i02; // head
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+ float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+ float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+ // broadcast the mask across rows
+ ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+ float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+
+ ggml_vec_cpy_f32 (ne00, wp, sp);
+ ggml_vec_scale_f32(ne00, wp, scale);
+ if (mp_f32) {
+ if (use_f16) {
+ for (int i = 0; i < ne00; ++i) {
+ wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
+ }
+ } else {
+ for (int i = 0; i < ne00; ++i) {
+ wp[i] += slope*mp_f32[i];
+ }
+ }
  }
- }
- }

  #ifndef NDEBUG
- for (int i = 0; i < nc; ++i) {
- //printf("p[%d] = %f\n", i, p[i]);
- assert(!isnan(wp[i]));
- }
+ for (int i = 0; i < ne00; ++i) {
+ //printf("p[%d] = %f\n", i, p[i]);
+ assert(!isnan(wp[i]));
+ }
  #endif

- float max = -INFINITY;
- ggml_vec_max_f32(nc, &max, wp);
+ float max = -INFINITY;
+ ggml_vec_max_f32(ne00, &max, wp);

- ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
- assert(sum > 0.0);
+ ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
+ assert(sum > 0.0);

- sum = 1.0/sum;
- ggml_vec_scale_f32(nc, dp, sum);
+ sum = 1.0/sum;
+ ggml_vec_scale_f32(ne00, dp, sum);

  #ifndef NDEBUG
- for (int i = 0; i < nc; ++i) {
- assert(!isnan(dp[i]));
- assert(!isinf(dp[i]));
- }
+ for (int i = 0; i < ne00; ++i) {
+ assert(!isnan(dp[i]));
+ assert(!isinf(dp[i]));
+ }
  #endif
+ }
+ }
  }
  }
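Note: the reworked soft_max loop above replaces the flat row index with explicit i01/i02/i03 loops so the mask (src1) can be broadcast over the head and batch dimensions with a modulo, the same convention ggml uses for other broadcasting ops. A minimal sketch of the index mapping, assuming a mask with ne12 heads and ne13 sequences applied to logits with ne02 heads and ne03 sequences:

#include <stddef.h>
#include <stdint.h>

// Sketch of the broadcast rule: pick the mask row for the logit row (i01, i02, i03).
static const char * mask_row(const char * mask_data,
                             int64_t i01, int64_t i02, int64_t i03,
                             int64_t ne12, int64_t ne13,
                             size_t nb11, size_t nb12, size_t nb13) {
    const int64_t i11 = i01;        // same row inside the matrix
    const int64_t i12 = i02 % ne12; // ne12 == 1 -> one mask shared by all heads
    const int64_t i13 = i03 % ne13; // ne13 == 1 -> one mask shared by all sequences
    return mask_data + i11*nb11 + i12*nb12 + i13*nb13;
}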
@@ -6545,6 +6833,186 @@ void ggml_compute_forward_im2col_back_f32(
  }
  }

+ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
+ void * a, void * b, float * c) {
+ const ggml_type_traits * traits = ggml_get_type_traits(type);
+ struct ggml_tensor src1 = {};
+ src1.type = type;
+ src1.ne[0] = k;
+ src1.ne[1] = m;
+ src1.ne[2] = 1;
+ src1.ne[3] = 1;
+ src1.nb[0] = traits->type_size;
+ src1.nb[1] = k * traits->type_size;
+ src1.nb[2] = src1.nb[1];
+ src1.nb[3] = src1.nb[2];
+ src1.data = a;
+
+ struct ggml_tensor src0 = {};
+ src0.type = type;
+ src0.ne[0] = k;
+ src0.ne[1] = n;
+ src0.ne[2] = 1;
+ src0.ne[3] = 1;
+ src0.nb[0] = traits->type_size;
+ src0.nb[1] = k * traits->type_size;
+ src0.nb[2] = src0.nb[1];
+ src0.nb[3] = src0.nb[2];
+ src0.data = b;
+
+ struct ggml_tensor dst = {};
+ dst.ne[0] = n;
+ dst.ne[1] = m;
+ dst.ne[2] = 1;
+ dst.ne[3] = 1;
+ dst.nb[0] = sizeof(float);
+ dst.nb[1] = n * sizeof(float);
+ dst.nb[2] = dst.nb[1];
+ dst.nb[3] = dst.nb[2];
+ dst.data = c;
+ dst.src[0] = &src0;
+ dst.src[1] = &src1;
+
+ ggml_compute_forward_mul_mat(params, &dst);
+ }
+
+ // ggml_compute_forward_conv_2d
+
+ static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
+ const ggml_tensor * kernel, // [KW, KH, IC, OC]
+ const ggml_tensor * src, // [W, H, C, N]
+ ggml_tensor * dst, // [OW, OH, OC, N]
+ ggml_type kernel_type) {
+
+ GGML_ASSERT(ggml_is_contiguous(kernel));
+ GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
+ GGML_ASSERT(kernel->type == kernel_type);
+
+ const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
+
+ const int32_t stride_x = dst->op_params[0];
+ const int32_t stride_y = dst->op_params[1];
+ const int32_t pad_x = dst->op_params[2];
+ const int32_t pad_y = dst->op_params[3];
+ const int32_t dilation_x = dst->op_params[4];
+ const int32_t dilation_y = dst->op_params[5];
+
+ const int64_t c_in = src->ne[2];
+ const int64_t c_out = kernel->ne[3];
+ GGML_ASSERT(c_in == kernel->ne[2]);
+
+ const int64_t src_w = src->ne[0];
+ const int64_t src_h = src->ne[1];
+ const int64_t knl_w = kernel->ne[0];
+ const int64_t knl_h = kernel->ne[1];
+ const int64_t dst_w = dst->ne[0];
+ const int64_t dst_h = dst->ne[1];
+
+ const float * src_data = (float *) src->data;
+ void * knl_data = kernel->data;
+ float * dst_data = (float *) dst->data;
+
+ const int64_t knl_n = knl_w * knl_h * c_in;
+ const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
+
+ const int64_t space_per_patch = knl_n * traits->type_size + c_out * sizeof(float);
+ const int64_t batch_size = params->wsize / space_per_patch;
+ const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
+ const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
+
+ GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
+
+ void * tmp = params->wdata;
+
+ for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
+
+ const int64_t patch_start_batch = batch_i * patches_per_batch;
+ const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch,
+ patch_total);
+ const int64_t patch_n = patch_end_batch - patch_start_batch;
+
+ const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth;
+ const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
+ const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
+
+ //im2col for a patch
+ for (int64_t p = patch_start; p < patch_end; ++p) {
+ const int64_t batch_n = p / (dst_w * dst_h);
+ const int64_t src_x = (p / dst_w) % dst_h;
+ const int64_t src_y = p % dst_w;
+
+ const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);
+ char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
+
+ for (int64_t ic = 0; ic < c_in; ++ic) {
+ for (int64_t ky = 0; ky < knl_h; ++ky) {
+ for (int64_t kx = 0; kx < knl_w; ++kx) {
+ const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
+ const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
+
+ int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
+
+ float src_val;
+ if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
+ src_val = 0.0f;
+ } else {
+ const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
+ src_val = *src_ptr;
+ }
+
+ char * element_ptr = dst_row + dst_idx * traits->type_size;
+ if (kernel_type == GGML_TYPE_F32) {
+ *(float *) element_ptr = src_val;
+ } else if (kernel_type == GGML_TYPE_F16) {
+ *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
+ }
+ }
+ }
+ }
+ } // patches handled by this thread
+
+ ggml_barrier(params->threadpool);
+
+ float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
+
+ GGML_ASSERT(gemm_output + patch_n * c_out <= (float*)tmp + params->wsize);
+
+ // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
+ ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
+
+ ggml_barrier(params->threadpool);
+
+
+ //permute back [OC, N, OH, OW] to [N, OC, OH, OW]
+ const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
+ const int64_t permute_start = params->ith * permute_per_thread;
+ const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);
+
+ for (int64_t i = permute_start; i < permute_end; ++i) {
+ const int64_t p = patch_start_batch + i;
+ const int64_t batch_n = p / (dst_w * dst_h);
+ const int64_t dst_y = (p / dst_w) % dst_h;
+ const int64_t dst_x = p % dst_w;
+
+ for (int64_t oc = 0; oc < c_out; ++oc) {
+ const float value = gemm_output[i * c_out + oc];
+ float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
+ *dst_ptr = value;
+ }
+ }
+ }
+ }
+
+ void ggml_compute_forward_conv_2d(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
+ }
+
  // ggml_compute_forward_conv_transpose_2d

  void ggml_compute_forward_conv_transpose_2d(
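Note: the new conv_2d path above is an im2col + GEMM implementation: each output pixel is expanded into a patch row of KW*KH*IC samples in the scratch buffer, the batch of patch rows is multiplied against the kernel via ggml_call_mul_mat, and the result is scattered back into [OW, OH, OC, N]. A hedged sketch of the standard output-size arithmetic this layout relies on (illustrative helper, not part of the ggml API):

#include <stdint.h>

// Conventional convolution output extent for one spatial axis.
static int64_t conv_out_size(int64_t in, int64_t kernel,
                             int64_t stride, int64_t pad, int64_t dilation) {
    return (in + 2*pad - dilation*(kernel - 1) - 1) / stride + 1;
}

// Example: conv_out_size(64, 3, /*stride=*/1, /*pad=*/1, /*dilation=*/1) == 64,
// so a 3x3 "same" convolution keeps a 64x64 spatial size and the im2col scratch
// holds 64*64 patch rows of 3*3*IC elements each.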
@@ -7095,12 +7563,13 @@ static void ggml_compute_forward_upscale_f32(

  GGML_TENSOR_UNARY_OP_LOCALS

- const float sf0 = (float)ne0/src0->ne[0];
- const float sf1 = (float)ne1/src0->ne[1];
- const float sf2 = (float)ne2/src0->ne[2];
- const float sf3 = (float)ne3/src0->ne[3];
+ float sf0 = (float)ne0/src0->ne[0];
+ float sf1 = (float)ne1/src0->ne[1];
+ float sf2 = (float)ne2/src0->ne[2];
+ float sf3 = (float)ne3/src0->ne[3];

- const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);
+ const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
+ const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);

  if (mode == GGML_SCALE_MODE_NEAREST) {
  for (int64_t i3 = 0; i3 < ne3; i3++) {
@@ -7121,8 +7590,12 @@ static void ggml_compute_forward_upscale_f32(
  }
  }
  } else if (mode == GGML_SCALE_MODE_BILINEAR) {
- // setting a pixel offset of 0 would replicate the behavior of pytorch interpolate with align_corners=True
- const float pixel_offset = 0.5f;
+ float pixel_offset = 0.5f;
+ if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
+ pixel_offset = 0.0f;
+ sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
+ sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
+ }

  for (int64_t i3 = 0; i3 < ne3; i3++) {
  const int64_t i03 = i3 / sf3;
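Note: the upscale change above packs flag bits into op_params[0]; with GGML_SCALE_FLAG_ALIGN_CORNERS set, bilinear sampling drops the half-pixel offset and rescales so the first and last source pixels map exactly onto the first and last destination pixels, which is why sf0/sf1 become (ne - 1)/(src - 1). A small sketch of the two 1-D coordinate mappings this implies, under the usual conventions (illustrative helpers, not the ggml API):

#include <stdint.h>

// Source coordinate for destination index i, given scale factor sf.
static float src_coord_half_pixel(int64_t i, float sf)    { return (i + 0.5f)/sf - 0.5f; } // default, pixel_offset = 0.5
static float src_coord_align_corners(int64_t i, float sf) { return i/sf; }                 // sf = (dst_n - 1)/(src_n - 1)
// With align_corners, i == 0 lands on source 0 and i == dst_n - 1 lands on source src_n - 1.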
@@ -7580,7 +8053,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

- ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
+ ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
  ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
  ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
  ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
@@ -7612,7 +8085,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
  memset(VKQ32, 0, DV*sizeof(float));
  }

- const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+ const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;

  // k indices
  const int ik3 = iq3 / rk3;
@@ -8150,120 +8623,210 @@ void ggml_compute_forward_ssm_conv(
  static void ggml_compute_forward_ssm_scan_f32(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
- const ggml_tensor * src0 = dst->src[0]; // s
- const ggml_tensor * src1 = dst->src[1]; // x
- const ggml_tensor * src2 = dst->src[2]; // dt
- const ggml_tensor * src3 = dst->src[3]; // A
- const ggml_tensor * src4 = dst->src[4]; // B
- const ggml_tensor * src5 = dst->src[5]; // C
+ const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
+ const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
+ const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
+ const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
+ const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
+ const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
+ const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}

  const int ith = params->ith;
  const int nth = params->nth;

- const int64_t nc = src0->ne[0]; // d_state
- const int64_t nr = src0->ne[1]; // d_inner
- const int64_t n_t = src1->ne[1]; // number of tokens per sequence
- const int64_t n_s = src0->ne[2]; // number of sequences in the batch
+ const int64_t nc = src0->ne[0]; // d_state
+ const int64_t nr = src0->ne[1]; // dim
+ const int64_t nh = src1->ne[1]; // n_head
+ const int64_t ng = src4->ne[1];
+ const int64_t nt = src1->ne[2]; // number of tokens per sequence
+ const int64_t ns = src1->ne[3]; // number of sequences in the batch
+
+ // can't use ggml_nbytes because src1 is not necessarily contiguous
+ const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);

- GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+ GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
  GGML_ASSERT(src0->nb[0] == sizeof(float));
  GGML_ASSERT(src1->nb[0] == sizeof(float));
  GGML_ASSERT(src2->nb[0] == sizeof(float));
  GGML_ASSERT(src3->nb[0] == sizeof(float));
  GGML_ASSERT(src4->nb[0] == sizeof(float));
  GGML_ASSERT(src5->nb[0] == sizeof(float));
- // required for the dot product between s and C
- GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
- // required for per-sequence offsets for states
- GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
- // required to get correct offset for state destination (i.e. src1->nb[3])
- GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
+ GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
+ // allows optimizing the modulo since n_group should be a power of 2
+ GGML_ASSERT((ng & -ng) == ng);

- // rows per thread
- const int dr = (nr + nth - 1)/nth;
+ // heads per thread
+ const int dh = (nh + nth - 1)/nth;

- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
- const int ir = ir1 - ir0;
+ // head range for this thread
+ const int ih0 = dh*ith;
+ const int ih1 = MIN(ih0 + dh, nh);
+
+ const int32_t * ids = (const int32_t *) src6->data;
+
+ for (int i3 = 0; i3 < ns; ++i3) {
+ const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
+ float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
+
+ for (int i2 = 0; i2 < nt; ++i2) {
+ const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
+ const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
+ const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
+ float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
+
+ if (src3->ne[0] == 1) {
+ // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
+
+ // n_head
+ for (int h = ih0; h < ih1; ++h) {
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+ const float dA = expf(dt_soft_plus * A[h]);
+
+ // dim
+ for (int i1 = 0; i1 < nr; ++i1) {
+ const int ii = i1 + h*nr;
+ const float x_dt = x[ii] * dt_soft_plus;
+ float sumf = 0.0f;
+ #if defined(GGML_SIMD)
+ #if defined(__ARM_FEATURE_SVE)
+ const int ggml_f32_epr = svcntw();
+ const int ggml_f32_step = 1 * ggml_f32_epr;
+
+ const int np = (nc & ~(ggml_f32_step - 1));
+
+ GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
+
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+ for (int i = 0; i < np; i += ggml_f32_step) {
+ // TODO: maybe unroll more?
+ for (int j = 0; j < 1; j++) {
+ GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
+ GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+ GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+
+ t0 = GGML_F32_VEC_MUL(t0, adA);
+ t1 = GGML_F32_VEC_MUL(t1, axdt);
+
+ t0 = GGML_F32_VEC_ADD(t0, t1);
+
+ sum = GGML_F32_VEC_FMA(sum, t0, t2);
+
+ GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
+ }
+ }
+
+ sumf = GGML_F32xt_REDUCE_ONE(sum);
+ #else
+ const int np = (nc & ~(GGML_F32_STEP - 1));
+
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+ GGML_F32_VEC ax[GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];
+ GGML_F32_VEC az[GGML_F32_ARR];
+
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
+ ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+ az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);

- #ifdef __ARM_FEATURE_SVE
- for (int i3 = 0; i3 < n_s; ++i3) {
- for (int i2 = 0; i2 < n_t; ++i2) {
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
- // use the output as the source for the next token-wise iterations
- if (i2 > 0) { s0 = s; }
-
- // d_inner
- for (int i1 = 0; i1 < ir; ++i1) {
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
- float x_dt = x[i1] * dt_soft_plus;
- svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
- svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
- svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
-
- for (int64_t k = 0; k < nc; k += svcntw()) {
- svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
- svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
- svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
- svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
-
- svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
- t1 = exp_ps_sve(svptrue_b32(), t1);
- svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
-
- vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
- r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
-
- GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
+ ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
+ ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
+
+ ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
+
+ sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
+
+ GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
+ }
+ }
+
+ // reduce sum0..sum3 to sum0
+ GGML_F32_VEC_REDUCE(sumf, sum);
+ #endif
+ #else
+ const int np = 0;
+ #endif
+ // d_state
+ for (int i0 = np; i0 < nc; ++i0) {
+ const int i = i0 + ii*nc;
+ const int ig = i0 + (h & (ng - 1))*nc;
+ // state = prev_state * dA + dB * x
+ const float state = (s0[i] * dA) + (B[ig] * x_dt);
+ // y = rowwise_dotprod(state, C)
+ sumf += state * C[ig];
+ s[i] = state;
+ }
+ y[ii] = sumf;
  }
- y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
  }
- }
- }
- #else
- for (int i3 = 0; i3 < n_s; ++i3) {
- for (int i2 = 0; i2 < n_t; ++i2) {
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
- // use the output as the source for the next token-wise iterations
- if (i2 > 0) { s0 = s; }
-
- // d_inner
- for (int i1 = 0; i1 < ir; ++i1) {
- // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
- float x_dt = x[i1] * dt_soft_plus;
- float sumf = 0.0f;
- // d_state
- for (int i0 = 0; i0 < nc; ++i0) {
- int i = i0 + i1*nc;
- // state = prev_state * dA + dB * x
- float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
- // y = rowwise_dotprod(state, C)
- sumf += state * C[i0];
- s[i] = state;
+ } else {
+ // Mamba-1 has an element-wise decay factor for the states
+
+ // n_head
+ for (int h = ih0; h < ih1; ++h) {
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+
+ // dim
+ for (int i1 = 0; i1 < nr; ++i1) {
+ const int ii = i1 + h*nr;
+ const float x_dt = x[ii] * dt_soft_plus;
+ #if defined(__ARM_FEATURE_SVE)
+ svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
+ svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
+ svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
+
+ // d_state
+ // TODO: what happens when (d_state % svcntw()) != 0?
+ for (int64_t k = 0; k < nc; k += svcntw()) {
+ svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
+ svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
+ svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
+ svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
+
+ svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+ t1 = exp_ps_sve(svptrue_b32(), t1);
+ svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
+
+ vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
+ r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+ GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
+ }
+ y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
+ #else
+ float sumf = 0.0f;
+ // NOTE: can't really use GGML_SIMD here because d_state is usually 16
+ // and also because expf is used within the loop.
+ // d_state
+ for (int i0 = 0; i0 < nc; ++i0) {
+ const int i = i0 + ii*nc;
+ const int ig = i0 + (h & (ng - 1))*nc;
+ // state = prev_state * dA + dB * x
+ const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
+ // y = rowwise_dotprod(state, C)
+ sumf += state * C[ig];
+ s[i] = state;
+ }
+ y[ii] = sumf;
+ #endif
  }
- y[i1] = sumf;
  }
  }
+ // use the output as the source when it's not the first token-wise iteration
+ s0 = s;
  }
- #endif
+ }
  }

  void ggml_compute_forward_ssm_scan(
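Note: the rewritten ssm_scan above parallelizes over heads and handles both Mamba-1 (A is {d_state, n_head}, element-wise decay) and Mamba-2 (A is {1, n_head}, one scalar decay per head), with B and C shared within a group through the (h & (ng - 1)) index. Stripped of the SVE/SIMD paths, both branches implement the same selective-state-space recurrence; a plain scalar sketch for one (head, channel, state) element, assuming the softplus-discretized dt used above (illustrative only, not the ggml API):

#include <math.h>

// One element of the selective scan: updates the running state in place and
// returns its contribution to the output y.
static float ssm_scan_step(float * s,   // running state, updated in place
                           float s0,    // state carried over from the previous token
                           float A, float B, float C,
                           float x, float dt) {
    const float dt_sp = dt <= 20.0f ? log1pf(expf(dt)) : dt; // softplus with overflow guard
    const float dA    = expf(dt_sp * A);                     // decay factor
    const float state = s0*dA + B*(x*dt_sp);                 // state = prev_state*dA + dB*x
    *s = state;
    return state * C;                                        // accumulated into y[ii]
}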
@@ -8502,6 +9065,14 @@ void ggml_compute_forward_glu(
  {
  ggml_compute_forward_swiglu(params, dst);
  } break;
+ case GGML_GLU_OP_GEGLU_ERF:
+ {
+ ggml_compute_forward_geglu_erf(params, dst);
+ } break;
+ case GGML_GLU_OP_GEGLU_QUICK:
+ {
+ ggml_compute_forward_geglu_quick(params, dst);
+ } break;
  default:
  {
  GGML_ABORT("fatal error");