@fugood/llama.node 1.0.2 → 1.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. package/package.json +14 -14
  2. package/src/llama.cpp/CMakeLists.txt +0 -1
  3. package/src/llama.cpp/common/CMakeLists.txt +4 -5
  4. package/src/llama.cpp/common/arg.cpp +44 -0
  5. package/src/llama.cpp/common/common.cpp +22 -6
  6. package/src/llama.cpp/common/common.h +15 -1
  7. package/src/llama.cpp/ggml/CMakeLists.txt +10 -2
  8. package/src/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  9. package/src/llama.cpp/ggml/include/ggml.h +104 -10
  10. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  11. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
  12. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +12 -1
  13. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  14. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +749 -163
  15. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +5 -0
  16. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  17. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +12 -9
  18. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +88 -9
  19. package/src/llama.cpp/include/llama.h +13 -47
  20. package/src/llama.cpp/src/llama-arch.cpp +298 -3
  21. package/src/llama.cpp/src/llama-arch.h +22 -1
  22. package/src/llama.cpp/src/llama-batch.cpp +103 -71
  23. package/src/llama.cpp/src/llama-batch.h +31 -18
  24. package/src/llama.cpp/src/llama-chat.cpp +59 -1
  25. package/src/llama.cpp/src/llama-chat.h +3 -0
  26. package/src/llama.cpp/src/llama-context.cpp +134 -95
  27. package/src/llama.cpp/src/llama-context.h +13 -16
  28. package/src/llama.cpp/src/llama-cparams.h +3 -2
  29. package/src/llama.cpp/src/llama-graph.cpp +279 -180
  30. package/src/llama.cpp/src/llama-graph.h +183 -122
  31. package/src/llama.cpp/src/llama-hparams.cpp +47 -1
  32. package/src/llama.cpp/src/llama-hparams.h +12 -1
  33. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
  34. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
  35. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
  36. package/src/llama.cpp/src/llama-kv-cache-unified.h +143 -47
  37. package/src/llama.cpp/src/llama-kv-cells.h +62 -10
  38. package/src/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
  39. package/src/llama.cpp/src/llama-memory-hybrid.h +3 -1
  40. package/src/llama.cpp/src/llama-memory-recurrent.cpp +21 -11
  41. package/src/llama.cpp/src/llama-memory.cpp +17 -0
  42. package/src/llama.cpp/src/llama-memory.h +3 -0
  43. package/src/llama.cpp/src/llama-model.cpp +3373 -743
  44. package/src/llama.cpp/src/llama-model.h +20 -4
  45. package/src/llama.cpp/src/llama-quant.cpp +2 -2
  46. package/src/llama.cpp/src/llama-vocab.cpp +376 -10
  47. package/src/llama.cpp/src/llama-vocab.h +43 -0
  48. package/src/llama.cpp/src/unicode.cpp +207 -0
  49. package/src/llama.cpp/src/unicode.h +2 -0
  50. package/src/llama.cpp/ggml/include/ggml-kompute.h +0 -50
@@ -3,6 +3,7 @@
  #include "ggml-cpu.h"
  #include "ggml-impl.h"
  #include "binary-ops.h"
+ #include "ggml.h"
  #include "unary-ops.h"
  #include "vec.h"
 
@@ -3613,6 +3614,292 @@ static void ggml_compute_forward_swiglu(
  }
  }
 
+ // ggml_compute_forward_geglu_erf
+
+ static void ggml_compute_forward_geglu_erf_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = ggml_nrows(src0);
+
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == nr);
+
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * src0_p = (float *) (src0_d + i1*src0_o);
+ float * src1_p = (float *) (src1_d + i1*src1_o);
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ GGML_UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+ #endif
+ }
+ }
+
+ static void ggml_compute_forward_geglu_erf_f16(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = ggml_nrows(src0);
+
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == nr);
+
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ const float v = GGML_FP16_TO_FP32(x);
+ GGML_UNUSED(v);
+ assert(!isnan(v));
+ assert(!isinf(v));
+ }
+ #endif
+ }
+ }
+
+ static void ggml_compute_forward_geglu_erf(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_geglu_erf_f32(params, dst);
+ } break;
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_geglu_erf_f16(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+ }
+
+ // ggml_compute_forward_geglu_quick
+
+ static void ggml_compute_forward_geglu_quick_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = ggml_nrows(src0);
+
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == nr);
+
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * src0_p = (float *) (src0_d + i1*src0_o);
+ float * src1_p = (float *) (src1_d + i1*src1_o);
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ GGML_UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+ #endif
+ }
+ }
+
+ static void ggml_compute_forward_geglu_quick_f16(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = ggml_nrows(src0);
+
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == nr);
+
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+ #ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ const float v = GGML_FP16_TO_FP32(x);
+ GGML_UNUSED(v);
+ assert(!isnan(v));
+ assert(!isinf(v));
+ }
+ #endif
+ }
+ }
+
+ static void ggml_compute_forward_geglu_quick(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_geglu_quick_f32(params, dst);
+ } break;
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_geglu_quick_f16(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+ }
+
  // ggml_compute_forward_norm
 
  static void ggml_compute_forward_norm_f32(
@@ -3728,6 +4015,9 @@ static void ggml_compute_forward_rms_norm_f32(
 
  const float scale = 1.0f/sqrtf(mean + eps);
 
+ // if you hit this, likely you got an inf somewhere earlier
+ assert(scale > 0.0f);
+
  ggml_vec_scale_f32(ne00, y, scale);
  }
  }
@@ -4356,9 +4646,11 @@ static void ggml_compute_forward_scale_f32(
  GGML_ASSERT(ggml_is_contiguous(dst));
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
- // scale factor
- float v;
- memcpy(&v, dst->op_params, sizeof(float));
+ float s; // scale factor
+ float b; // bias
+
+ memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
+ memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
 
  const int ith = params->ith;
  const int nth = params->nth;
@@ -4377,12 +4669,22 @@ static void ggml_compute_forward_scale_f32(
 
  const size_t nb1 = dst->nb[1];
 
- for (int i1 = ir0; i1 < ir1; i1++) {
- if (dst->data != src0->data) {
- // src0 is same shape as dst => same indices
- memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
+ if (b == 0.0f) {
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ if (dst->data != src0->data) {
+ // src0 is same shape as dst => same indices
+ // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
+ memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
+ }
+ ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
+ }
+ } else {
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_vec_mad1_f32(nc,
+ (float *) ((char *) dst->data + i1*nb1),
+ (float *) ((char *) src0->data + i1*nb1),
+ s, b);
  }
- ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
  }
  }
 
@@ -5231,14 +5533,17 @@ static void ggml_compute_forward_soft_max_f32(
  memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
  memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
- // TODO: handle transposed/permuted matrices
-
  const int ith = params->ith;
  const int nth = params->nth;
 
  GGML_TENSOR_UNARY_OP_LOCALS
 
- //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+ const int64_t nb11 = src1 ? src1->nb[1] : 1;
+ const int64_t nb12 = src1 ? src1->nb[2] : 1;
+ const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+ const int64_t ne12 = src1 ? src1->ne[2] : 1;
+ const int64_t ne13 = src1 ? src1->ne[3] : 1;
 
  // TODO: is this supposed to be ceil instead of floor?
  // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -5248,68 +5553,66 @@ static void ggml_compute_forward_soft_max_f32(
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
- const int nc = src0->ne[0];
- const int nr = ggml_nrows(src0);
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+ float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
  const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
- for (int i1 = ir0; i1 < ir1; i1++) {
- // ALiBi
- const uint32_t h = (i1/ne01)%ne02; // head
- const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
- float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
- float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
-
- // broadcast the mask across rows
- ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
- float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-
- ggml_vec_cpy_f32 (nc, wp, sp);
- ggml_vec_scale_f32(nc, wp, scale);
- if (mp_f32) {
- if (use_f16) {
- for (int i = 0; i < nc; ++i) {
- wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
- }
- } else {
- for (int i = 0; i < nc; ++i) {
- wp[i] += slope*mp_f32[i];
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+ const int64_t i11 = i01;
+ const int64_t i12 = i02%ne12;
+ const int64_t i13 = i03%ne13;
+
+ // ALiBi
+ const uint32_t h = i02; // head
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+ float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+ float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+ // broadcast the mask across rows
+ ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+ float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+
+ ggml_vec_cpy_f32 (ne00, wp, sp);
+ ggml_vec_scale_f32(ne00, wp, scale);
+ if (mp_f32) {
+ if (use_f16) {
+ for (int i = 0; i < ne00; ++i) {
+ wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
+ }
+ } else {
+ for (int i = 0; i < ne00; ++i) {
+ wp[i] += slope*mp_f32[i];
+ }
+ }
  }
- }
- }
 
  #ifndef NDEBUG
- for (int i = 0; i < nc; ++i) {
- //printf("p[%d] = %f\n", i, p[i]);
- assert(!isnan(wp[i]));
- }
+ for (int i = 0; i < ne00; ++i) {
+ //printf("p[%d] = %f\n", i, p[i]);
+ assert(!isnan(wp[i]));
+ }
  #endif
 
- float max = -INFINITY;
- ggml_vec_max_f32(nc, &max, wp);
+ float max = -INFINITY;
+ ggml_vec_max_f32(ne00, &max, wp);
 
- ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
- assert(sum > 0.0);
+ ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
+ assert(sum > 0.0);
 
- sum = 1.0/sum;
- ggml_vec_scale_f32(nc, dp, sum);
+ sum = 1.0/sum;
+ ggml_vec_scale_f32(ne00, dp, sum);
 
  #ifndef NDEBUG
- for (int i = 0; i < nc; ++i) {
- assert(!isnan(dp[i]));
- assert(!isinf(dp[i]));
- }
+ for (int i = 0; i < ne00; ++i) {
+ assert(!isnan(dp[i]));
+ assert(!isinf(dp[i]));
+ }
  #endif
+ }
+ }
  }
  }
 
@@ -6545,6 +6848,186 @@ void ggml_compute_forward_im2col_back_f32(
  }
  }
 
+ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
+ void * a, void * b, float * c) {
+ const ggml_type_traits * traits = ggml_get_type_traits(type);
+ struct ggml_tensor src1 = {};
+ src1.type = type;
+ src1.ne[0] = k;
+ src1.ne[1] = m;
+ src1.ne[2] = 1;
+ src1.ne[3] = 1;
+ src1.nb[0] = traits->type_size;
+ src1.nb[1] = k * traits->type_size;
+ src1.nb[2] = src1.nb[1];
+ src1.nb[3] = src1.nb[2];
+ src1.data = a;
+
+ struct ggml_tensor src0 = {};
+ src0.type = type;
+ src0.ne[0] = k;
+ src0.ne[1] = n;
+ src0.ne[2] = 1;
+ src0.ne[3] = 1;
+ src0.nb[0] = traits->type_size;
+ src0.nb[1] = k * traits->type_size;
+ src0.nb[2] = src0.nb[1];
+ src0.nb[3] = src0.nb[2];
+ src0.data = b;
+
+ struct ggml_tensor dst = {};
+ dst.ne[0] = n;
+ dst.ne[1] = m;
+ dst.ne[2] = 1;
+ dst.ne[3] = 1;
+ dst.nb[0] = sizeof(float);
+ dst.nb[1] = n * sizeof(float);
+ dst.nb[2] = dst.nb[1];
+ dst.nb[3] = dst.nb[2];
+ dst.data = c;
+ dst.src[0] = &src0;
+ dst.src[1] = &src1;
+
+ ggml_compute_forward_mul_mat(params, &dst);
+ }
+
+ // ggml_compute_forward_conv_2d
+
+ static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
+ const ggml_tensor * kernel, // [KW, KH, IC, OC]
+ const ggml_tensor * src, // [W, H, C, N]
+ ggml_tensor * dst, // [OW, OH, OC, N]
+ ggml_type kernel_type) {
+
+ GGML_ASSERT(ggml_is_contiguous(kernel));
+ GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
+ GGML_ASSERT(kernel->type == kernel_type);
+
+ const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
+
+ const int32_t stride_x = dst->op_params[0];
+ const int32_t stride_y = dst->op_params[1];
+ const int32_t pad_x = dst->op_params[2];
+ const int32_t pad_y = dst->op_params[3];
+ const int32_t dilation_x = dst->op_params[4];
+ const int32_t dilation_y = dst->op_params[5];
+
+ const int64_t c_in = src->ne[2];
+ const int64_t c_out = kernel->ne[3];
+ GGML_ASSERT(c_in == kernel->ne[2]);
+
+ const int64_t src_w = src->ne[0];
+ const int64_t src_h = src->ne[1];
+ const int64_t knl_w = kernel->ne[0];
+ const int64_t knl_h = kernel->ne[1];
+ const int64_t dst_w = dst->ne[0];
+ const int64_t dst_h = dst->ne[1];
+
+ const float * src_data = (float *) src->data;
+ void * knl_data = kernel->data;
+ float * dst_data = (float *) dst->data;
+
+ const int64_t knl_n = knl_w * knl_h * c_in;
+ const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
+
+ const int64_t space_per_patch = knl_n * traits->type_size + c_out * sizeof(float);
+ const int64_t batch_size = params->wsize / space_per_patch;
+ const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
+ const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
+
+ GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
+
+ void * tmp = params->wdata;
+
+ for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
+
+ const int64_t patch_start_batch = batch_i * patches_per_batch;
+ const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch,
+ patch_total);
+ const int64_t patch_n = patch_end_batch - patch_start_batch;
+
+ const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth;
+ const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
+ const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
+
+ //im2col for a patch
+ for (int64_t p = patch_start; p < patch_end; ++p) {
+ const int64_t batch_n = p / (dst_w * dst_h);
+ const int64_t src_x = (p / dst_w) % dst_h;
+ const int64_t src_y = p % dst_w;
+
+ const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);
+ char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
+
+ for (int64_t ic = 0; ic < c_in; ++ic) {
+ for (int64_t ky = 0; ky < knl_h; ++ky) {
+ for (int64_t kx = 0; kx < knl_w; ++kx) {
+ const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
+ const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
+
+ int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
+
+ float src_val;
+ if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
+ src_val = 0.0f;
+ } else {
+ const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
+ src_val = *src_ptr;
+ }
+
+ char * element_ptr = dst_row + dst_idx * traits->type_size;
+ if (kernel_type == GGML_TYPE_F32) {
+ *(float *) element_ptr = src_val;
+ } else if (kernel_type == GGML_TYPE_F16) {
+ *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
+ }
+ }
+ }
+ }
+ } // patches handled by this thread
+
+ ggml_barrier(params->threadpool);
+
+ float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
+
+ GGML_ASSERT(gemm_output + patch_n * c_out <= (float*)tmp + params->wsize);
+
+ // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
+ ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
+
+ ggml_barrier(params->threadpool);
+
+
+ //permute back [OC, N, OH, OW] to [N, OC, OH, OW]
+ const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
+ const int64_t permute_start = params->ith * permute_per_thread;
+ const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);
+
+ for (int64_t i = permute_start; i < permute_end; ++i) {
+ const int64_t p = patch_start_batch + i;
+ const int64_t batch_n = p / (dst_w * dst_h);
+ const int64_t dst_y = (p / dst_w) % dst_h;
+ const int64_t dst_x = p % dst_w;
+
+ for (int64_t oc = 0; oc < c_out; ++oc) {
+ const float value = gemm_output[i * c_out + oc];
+ float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
+ *dst_ptr = value;
+ }
+ }
+ }
+ }
+
+ void ggml_compute_forward_conv_2d(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
+ }
+
  // ggml_compute_forward_conv_transpose_2d
 
  void ggml_compute_forward_conv_transpose_2d(
@@ -7095,12 +7578,13 @@ static void ggml_compute_forward_upscale_f32(
 
  GGML_TENSOR_UNARY_OP_LOCALS
 
- const float sf0 = (float)ne0/src0->ne[0];
- const float sf1 = (float)ne1/src0->ne[1];
- const float sf2 = (float)ne2/src0->ne[2];
- const float sf3 = (float)ne3/src0->ne[3];
+ float sf0 = (float)ne0/src0->ne[0];
+ float sf1 = (float)ne1/src0->ne[1];
+ float sf2 = (float)ne2/src0->ne[2];
+ float sf3 = (float)ne3/src0->ne[3];
 
- const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);
+ const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
+ const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
 
  if (mode == GGML_SCALE_MODE_NEAREST) {
@@ -7121,8 +7605,12 @@ static void ggml_compute_forward_upscale_f32(
  }
  }
  } else if (mode == GGML_SCALE_MODE_BILINEAR) {
- // setting a pixel offset of 0 would replicate the behavior of pytorch interpolate with align_corners=True
- const float pixel_offset = 0.5f;
+ float pixel_offset = 0.5f;
+ if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
+ pixel_offset = 0.0f;
+ sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
+ sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
+ }
 
  for (int64_t i3 = 0; i3 < ne3; i3++) {
  const int64_t i03 = i3 / sf3;
@@ -7580,7 +8068,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
- ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
+ ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
  ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
  ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
  ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
@@ -7612,7 +8100,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
  memset(VKQ32, 0, DV*sizeof(float));
  }
 
- const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+ const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
 
  // k indices
  const int ik3 = iq3 / rk3;
@@ -8150,120 +8638,210 @@ void ggml_compute_forward_ssm_conv(
  static void ggml_compute_forward_ssm_scan_f32(
  const ggml_compute_params * params,
  ggml_tensor * dst) {
- const ggml_tensor * src0 = dst->src[0]; // s
- const ggml_tensor * src1 = dst->src[1]; // x
- const ggml_tensor * src2 = dst->src[2]; // dt
- const ggml_tensor * src3 = dst->src[3]; // A
- const ggml_tensor * src4 = dst->src[4]; // B
- const ggml_tensor * src5 = dst->src[5]; // C
+ const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
+ const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
+ const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
+ const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
+ const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
+ const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
+ const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
 
  const int ith = params->ith;
  const int nth = params->nth;
 
- const int64_t nc = src0->ne[0]; // d_state
- const int64_t nr = src0->ne[1]; // d_inner
- const int64_t n_t = src1->ne[1]; // number of tokens per sequence
- const int64_t n_s = src0->ne[2]; // number of sequences in the batch
+ const int64_t nc = src0->ne[0]; // d_state
+ const int64_t nr = src0->ne[1]; // dim
+ const int64_t nh = src1->ne[1]; // n_head
+ const int64_t ng = src4->ne[1];
+ const int64_t nt = src1->ne[2]; // number of tokens per sequence
+ const int64_t ns = src1->ne[3]; // number of sequences in the batch
+
+ // can't use ggml_nbytes because src1 is not necessarily contiguous
+ const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
 
- GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+ GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
  GGML_ASSERT(src0->nb[0] == sizeof(float));
  GGML_ASSERT(src1->nb[0] == sizeof(float));
  GGML_ASSERT(src2->nb[0] == sizeof(float));
  GGML_ASSERT(src3->nb[0] == sizeof(float));
  GGML_ASSERT(src4->nb[0] == sizeof(float));
  GGML_ASSERT(src5->nb[0] == sizeof(float));
- // required for the dot product between s and C
- GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
- // required for per-sequence offsets for states
- GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
- // required to get correct offset for state destination (i.e. src1->nb[3])
- GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
+ GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
+ // allows optimizing the modulo since n_group should be a power of 2
+ GGML_ASSERT((ng & -ng) == ng);
 
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
+ // heads per thread
+ const int dh = (nh + nth - 1)/nth;
 
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
- const int ir = ir1 - ir0;
+ // head range for this thread
+ const int ih0 = dh*ith;
+ const int ih1 = MIN(ih0 + dh, nh);
+
+ const int32_t * ids = (const int32_t *) src6->data;
+
+ for (int i3 = 0; i3 < ns; ++i3) {
+ const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
+ float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
+
+ for (int i2 = 0; i2 < nt; ++i2) {
+ const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
+ const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
+ const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
+ float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
+
+ if (src3->ne[0] == 1) {
+ // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
+
+ // n_head
+ for (int h = ih0; h < ih1; ++h) {
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+ const float dA = expf(dt_soft_plus * A[h]);
+
+ // dim
+ for (int i1 = 0; i1 < nr; ++i1) {
+ const int ii = i1 + h*nr;
+ const float x_dt = x[ii] * dt_soft_plus;
+ float sumf = 0.0f;
+ #if defined(GGML_SIMD)
+ #if defined(__ARM_FEATURE_SVE)
+ const int ggml_f32_epr = svcntw();
+ const int ggml_f32_step = 1 * ggml_f32_epr;
+
+ const int np = (nc & ~(ggml_f32_step - 1));
+
+ GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
+
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+ for (int i = 0; i < np; i += ggml_f32_step) {
+ // TODO: maybe unroll more?
+ for (int j = 0; j < 1; j++) {
+ GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
+ GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+ GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+
+ t0 = GGML_F32_VEC_MUL(t0, adA);
+ t1 = GGML_F32_VEC_MUL(t1, axdt);
+
+ t0 = GGML_F32_VEC_ADD(t0, t1);
+
+ sum = GGML_F32_VEC_FMA(sum, t0, t2);
+
+ GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
+ }
+ }
+
+ sumf = GGML_F32xt_REDUCE_ONE(sum);
+ #else
+ const int np = (nc & ~(GGML_F32_STEP - 1));
+
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+ GGML_F32_VEC ax[GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];
+ GGML_F32_VEC az[GGML_F32_ARR];
 
- #ifdef __ARM_FEATURE_SVE
- for (int i3 = 0; i3 < n_s; ++i3) {
- for (int i2 = 0; i2 < n_t; ++i2) {
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
- // use the output as the source for the next token-wise iterations
- if (i2 > 0) { s0 = s; }
-
- // d_inner
- for (int i1 = 0; i1 < ir; ++i1) {
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
- float x_dt = x[i1] * dt_soft_plus;
- svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
- svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
- svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
-
- for (int64_t k = 0; k < nc; k += svcntw()) {
- svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
- svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
- svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
- svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
-
- svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
- t1 = exp_ps_sve(svptrue_b32(), t1);
- svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
-
- vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
- r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
-
- GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
+ ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+ az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+
+ ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
+ ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
+
+ ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
+
+ sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
+
+ GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
+ }
+ }
+
+ // reduce sum0..sum3 to sum0
+ GGML_F32_VEC_REDUCE(sumf, sum);
+ #endif
+ #else
+ const int np = 0;
+ #endif
+ // d_state
+ for (int i0 = np; i0 < nc; ++i0) {
+ const int i = i0 + ii*nc;
+ const int ig = i0 + (h & (ng - 1))*nc;
+ // state = prev_state * dA + dB * x
+ const float state = (s0[i] * dA) + (B[ig] * x_dt);
+ // y = rowwise_dotprod(state, C)
+ sumf += state * C[ig];
+ s[i] = state;
+ }
+ y[ii] = sumf;
  }
- y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
  }
- }
- }
- #else
- for (int i3 = 0; i3 < n_s; ++i3) {
- for (int i2 = 0; i2 < n_t; ++i2) {
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
- // use the output as the source for the next token-wise iterations
- if (i2 > 0) { s0 = s; }
-
- // d_inner
- for (int i1 = 0; i1 < ir; ++i1) {
- // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
- float x_dt = x[i1] * dt_soft_plus;
- float sumf = 0.0f;
- // d_state
- for (int i0 = 0; i0 < nc; ++i0) {
- int i = i0 + i1*nc;
- // state = prev_state * dA + dB * x
- float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
- // y = rowwise_dotprod(state, C)
- sumf += state * C[i0];
- s[i] = state;
+ } else {
+ // Mamba-1 has an element-wise decay factor for the states
+
+ // n_head
+ for (int h = ih0; h < ih1; ++h) {
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+
+ // dim
+ for (int i1 = 0; i1 < nr; ++i1) {
+ const int ii = i1 + h*nr;
+ const float x_dt = x[ii] * dt_soft_plus;
+ #if defined(__ARM_FEATURE_SVE)
+ svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
+ svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
+ svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
+
+ // d_state
+ // TODO: what happens when (d_state % svcntw()) != 0?
+ for (int64_t k = 0; k < nc; k += svcntw()) {
+ svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
+ svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
+ svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
+ svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
+
+ svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+ t1 = exp_ps_sve(svptrue_b32(), t1);
+ svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
+
+ vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
+ r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+ GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
+ }
+ y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
+ #else
+ float sumf = 0.0f;
+ // NOTE: can't really use GGML_SIMD here because d_state is usually 16
+ // and also because expf is used within the loop.
+ // d_state
+ for (int i0 = 0; i0 < nc; ++i0) {
+ const int i = i0 + ii*nc;
+ const int ig = i0 + (h & (ng - 1))*nc;
+ // state = prev_state * dA + dB * x
+ const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
+ // y = rowwise_dotprod(state, C)
+ sumf += state * C[ig];
+ s[i] = state;
+ }
+ y[ii] = sumf;
+ #endif
  }
- y[i1] = sumf;
  }
  }
+ // use the output as the source when it's not the first token-wise iteration
+ s0 = s;
  }
- #endif
+ }
  }
 
  void ggml_compute_forward_ssm_scan(
@@ -8502,6 +9080,14 @@ void ggml_compute_forward_glu(
  {
  ggml_compute_forward_swiglu(params, dst);
  } break;
+ case GGML_GLU_OP_GEGLU_ERF:
+ {
+ ggml_compute_forward_geglu_erf(params, dst);
+ } break;
+ case GGML_GLU_OP_GEGLU_QUICK:
+ {
+ ggml_compute_forward_geglu_quick(params, dst);
+ } break;
  default:
  {
  GGML_ABORT("fatal error");