cui-llama.rn 1.1.0 → 1.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/cpp/ggml.c CHANGED
@@ -7095,7 +7095,8 @@ struct lm_ggml_tensor * lm_ggml_flash_attn_ext(
  struct lm_ggml_tensor * v,
  struct lm_ggml_tensor * mask,
  float scale,
- float max_bias) {
+ float max_bias,
+ float logit_softcap) {
  LM_GGML_ASSERT(lm_ggml_can_mul_mat(k, q));
  // TODO: check if vT can be multiplied by (k*qT)

@@ -7122,7 +7123,7 @@ struct lm_ggml_tensor * lm_ggml_flash_attn_ext(
  int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne);

- float params[] = { scale, max_bias };
+ float params[] = { scale, max_bias, logit_softcap };
  lm_ggml_set_op_params(result, params, sizeof(params));

  result->op = LM_GGML_OP_FLASH_ATTN_EXT;
@@ -7142,7 +7143,7 @@ void lm_ggml_flash_attn_ext_set_prec(

  const int32_t prec_i32 = (int32_t) prec;

- lm_ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
+ lm_ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
  }

  // lm_ggml_flash_attn_back
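Note on the hunks above: lm_ggml_flash_attn_ext now takes a logit_softcap value and stores it as a third float in the op's parameter block, so the precision flag written by lm_ggml_flash_attn_ext_set_prec moves from slot 2 to slot 3. A minimal reader-side sketch of the resulting op_params layout, assuming the same memcpy convention used in the compute code (the helper name is invented for illustration; <string.h> and the package's ggml.h are assumed to be included):

    // Sketch only: slot indices as implied by the diff.
    static void read_flash_attn_ext_params(struct lm_ggml_tensor * dst,
                                           float * scale, float * max_bias,
                                           float * logit_softcap, int32_t * prec) {
        memcpy(scale,         (float *) dst->op_params + 0, sizeof(float));
        memcpy(max_bias,      (float *) dst->op_params + 1, sizeof(float));
        memcpy(logit_softcap, (float *) dst->op_params + 2, sizeof(float));
        *prec = dst->op_params[3]; // precision now lives in the fourth slot
    }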
@@ -7229,43 +7230,34 @@ struct lm_ggml_tensor * lm_ggml_flash_attn_back(

  struct lm_ggml_tensor * lm_ggml_ssm_conv(
  struct lm_ggml_context * ctx,
- struct lm_ggml_tensor * s,
- struct lm_ggml_tensor * x,
- struct lm_ggml_tensor * c,
- struct lm_ggml_tensor * sq) {
- LM_GGML_ASSERT(lm_ggml_is_3d(s));
- LM_GGML_ASSERT(lm_ggml_is_matrix(x));
+ struct lm_ggml_tensor * sx,
+ struct lm_ggml_tensor * c) {
+ LM_GGML_ASSERT(lm_ggml_is_3d(sx));
  LM_GGML_ASSERT(lm_ggml_is_matrix(c));
- LM_GGML_ASSERT(lm_ggml_is_matrix(sq));
- LM_GGML_ASSERT(sq->type == LM_GGML_TYPE_I32);

- const int64_t d_conv = c->ne[0];
- const int64_t d_inner = c->ne[1];
- const int64_t n_tokens = x->ne[1];
- const int64_t n_kv = s->ne[2];
+ const int64_t d_conv = c->ne[0];
+ const int64_t d_inner = c->ne[1];
+ const int64_t n_t = sx->ne[0] - d_conv + 1; // tokens per sequence
+ const int64_t n_s = sx->ne[2];

- LM_GGML_ASSERT( s->ne[0] == d_conv - 1);
- LM_GGML_ASSERT( s->ne[1] == d_inner);
- LM_GGML_ASSERT( x->ne[0] == d_inner);
- LM_GGML_ASSERT(sq->ne[0] == n_kv);
- LM_GGML_ASSERT(sq->ne[1] == n_tokens);
+ // TODO: maybe support other strides than 1?
+ LM_GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
+ LM_GGML_ASSERT(sx->ne[1] == d_inner);
+ LM_GGML_ASSERT(n_t >= 0);

  bool is_node = false;

- if (s->grad || x->grad || c->grad || sq->grad) {
+ if (sx->grad || c->grad) {
  LM_GGML_ABORT("fatal error"); // TODO: implement
  is_node = true;
  }

- // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv}
- struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv));
+ struct lm_ggml_tensor * result = lm_ggml_new_tensor_3d(ctx, LM_GGML_TYPE_F32, d_inner, n_t, n_s);

  result->op = LM_GGML_OP_SSM_CONV;
  result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
- result->src[0] = s;
- result->src[1] = x;
- result->src[2] = c;
- result->src[3] = sq;
+ result->src[0] = sx;
+ result->src[1] = c;

  return result;
  }
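For orientation: the rewritten lm_ggml_ssm_conv takes a single sx tensor of shape {d_conv - 1 + n_t, d_inner, n_s} (the previous d_conv - 1 columns of conv state concatenated with the n_t new tokens of each of the n_s sequences) plus the conv weight c of shape {d_conv, d_inner}, and returns a {d_inner, n_t, n_s} result. A hedged shape-check sketch with an invented helper name, assuming the package's ggml.h is included:

    // Illustrative only: builds an SSM_CONV node and asserts the output shape
    // implied by the new signature.
    struct lm_ggml_tensor * ssm_conv_example(struct lm_ggml_context * ctx,
                                             struct lm_ggml_tensor * sx,  // {d_conv - 1 + n_t, d_inner, n_s}
                                             struct lm_ggml_tensor * c) { // {d_conv, d_inner}
        struct lm_ggml_tensor * out = lm_ggml_ssm_conv(ctx, sx, c);
        LM_GGML_ASSERT(out->ne[0] == c->ne[1]);                 // d_inner
        LM_GGML_ASSERT(out->ne[1] == sx->ne[0] - c->ne[0] + 1); // n_t
        LM_GGML_ASSERT(out->ne[2] == sx->ne[2]);                // n_s
        return out;
    }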
@@ -7279,39 +7271,42 @@ struct lm_ggml_tensor * lm_ggml_ssm_scan(
  struct lm_ggml_tensor * dt,
  struct lm_ggml_tensor * A,
  struct lm_ggml_tensor * B,
- struct lm_ggml_tensor * C,
- struct lm_ggml_tensor * sq) {
+ struct lm_ggml_tensor * C) {
  LM_GGML_ASSERT(lm_ggml_is_contiguous(s));
  LM_GGML_ASSERT(lm_ggml_is_contiguous(x));
  LM_GGML_ASSERT(lm_ggml_is_contiguous(dt));
  LM_GGML_ASSERT(lm_ggml_is_contiguous(A));
- LM_GGML_ASSERT(sq->type == LM_GGML_TYPE_I32);
+ LM_GGML_ASSERT(lm_ggml_is_matrix(A));
+ LM_GGML_ASSERT(lm_ggml_is_3d(B));
+ LM_GGML_ASSERT(lm_ggml_is_3d(s));
  LM_GGML_ASSERT(B->nb[0] == lm_ggml_type_size(B->type));
  LM_GGML_ASSERT(C->nb[0] == lm_ggml_type_size(C->type));
  LM_GGML_ASSERT(lm_ggml_are_same_shape(x, dt));
+ LM_GGML_ASSERT(lm_ggml_are_same_shape(B, C));

  {
- const int64_t d_state = s->ne[0];
- const int64_t d_inner = s->ne[1];
- const int64_t n_tokens = x->ne[1];
+ const int64_t d_state = s->ne[0];
+ const int64_t d_inner = s->ne[1];
+ const int64_t n_seq_tokens = x->ne[1];
+ const int64_t n_seqs = x->ne[2];

+ LM_GGML_ASSERT(s->ne[2] == n_seqs);
  LM_GGML_ASSERT(x->ne[0] == d_inner);
  LM_GGML_ASSERT(A->ne[0] == d_state);
  LM_GGML_ASSERT(A->ne[1] == d_inner);
  LM_GGML_ASSERT(B->ne[0] == d_state);
- LM_GGML_ASSERT(B->ne[1] == n_tokens);
- LM_GGML_ASSERT(C->ne[0] == d_state);
- LM_GGML_ASSERT(C->ne[1] == n_tokens);
+ LM_GGML_ASSERT(B->ne[1] == n_seq_tokens);
+ LM_GGML_ASSERT(B->ne[2] == n_seqs);
  }

  bool is_node = false;

- if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) {
+ if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad) {
  LM_GGML_ABORT("fatal error"); // TODO: implement
  is_node = true;
  }

- // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
+ // concatenated y + ssm_states
  struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, lm_ggml_nelements(x) + lm_ggml_nelements(s));

  result->op = LM_GGML_OP_SSM_SCAN;
@@ -7322,7 +7317,6 @@ struct lm_ggml_tensor * lm_ggml_ssm_scan(
  result->src[3] = A;
  result->src[4] = B;
  result->src[5] = C;
- result->src[6] = sq;

  return result;
  }
@@ -10995,11 +10989,6 @@ static void lm_ggml_compute_forward_concat_f32(

  LM_GGML_TENSOR_BINARY_OP_LOCALS

- // TODO: support for transposed / permuted tensors
- LM_GGML_ASSERT(nb0 == sizeof(float));
- LM_GGML_ASSERT(nb00 == sizeof(float));
- LM_GGML_ASSERT(nb10 == sizeof(float));
-
  const int32_t dim = lm_ggml_get_op_params_i32(dst, 0);

  LM_GGML_ASSERT(dim >= 0 && dim < 4);
@@ -15283,11 +15272,17 @@ static void lm_ggml_compute_forward_flash_attn_ext_f16(
  const int ir0 = dr*ith;
  const int ir1 = MIN(ir0 + dr, nr);

- float scale = 1.0f;
- float max_bias = 0.0f;
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+ float logit_softcap = 0.0f;

- memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
- memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+ memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
+
+ if (logit_softcap != 0) {
+ scale /= logit_softcap;
+ }

  const uint32_t n_head = neq2;
  const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
@@ -15351,7 +15346,13 @@ static void lm_ggml_compute_forward_flash_attn_ext_f16(
  const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
  kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);

- s = s*scale + mv; // scale KQ value and apply mask
+ s = s*scale; // scale KQ value
+
+ if (logit_softcap != 0.0f) {
+ s = logit_softcap*tanhf(s);
+ }
+
+ s += mv; // apply mask

  const float Mold = M;

@@ -15360,7 +15361,7 @@ static void lm_ggml_compute_forward_flash_attn_ext_f16(

  const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));

- if (v->type== LM_GGML_TYPE_F16) {
+ if (v->type == LM_GGML_TYPE_F16) {
  if (s > M) {
  // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
  M = s;
@@ -15427,7 +15428,7 @@ static void lm_ggml_compute_forward_flash_attn_ext(
  const struct lm_ggml_tensor * v,
  const struct lm_ggml_tensor * mask,
  struct lm_ggml_tensor * dst) {
- switch (dst->op_params[2]) {
+ switch (dst->op_params[3]) {
  case LM_GGML_PREC_DEFAULT:
  case LM_GGML_PREC_F32:
  {
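The compute-side hunks fold the scaling into the cap: when logit_softcap is non-zero the kernel divides scale by it up front and later applies s = logit_softcap*tanhf(s), so the value that reaches the mask and softmax is logit_softcap * tanh(kq * scale / logit_softcap), i.e. scaled attention scores are smoothly limited to the open interval (-logit_softcap, +logit_softcap). A standalone scalar sketch of that transformation (reference only, not the kernel):

    #include <math.h>

    // Mirrors the order of operations in the diff: pre-divide the scale,
    // re-expand with tanh, then add the mask bias.
    static float softcapped_score(float kq, float scale, float logit_softcap, float mask_bias) {
        float s = kq * (logit_softcap != 0.0f ? scale / logit_softcap : scale);
        if (logit_softcap != 0.0f) {
            s = logit_softcap * tanhf(s);
        }
        return s + mask_bias;
    }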
@@ -15782,27 +15783,22 @@ static void lm_ggml_compute_forward_flash_attn_back(
  static void lm_ggml_compute_forward_ssm_conv_f32(
  const struct lm_ggml_compute_params * params,
  struct lm_ggml_tensor * dst) {
- const struct lm_ggml_tensor * src0 = dst->src[0]; // conv_state
- const struct lm_ggml_tensor * src1 = dst->src[1]; // x
- const struct lm_ggml_tensor * src2 = dst->src[2]; // conv1d.weight
- const struct lm_ggml_tensor * src3 = dst->src[3]; // state_seq
+ const struct lm_ggml_tensor * src0 = dst->src[0]; // conv_x
+ const struct lm_ggml_tensor * src1 = dst->src[1]; // conv1d.weight

  const int ith = params->ith;
  const int nth = params->nth;

- const int nc = src2->ne[0]; // d_conv
- const int nr = src0->ne[1]; // d_inner
- const int n_t = src1->ne[1]; // n_tokens
- const int n_kv = src0->ne[2]; // max number of sequences in the batch
+ const int nc = src1->ne[0]; // d_conv
+ const int ncs = src0->ne[0]; // d_conv - 1 + n_t
+ const int nr = src0->ne[1]; // d_inner
+ const int n_t = dst->ne[1]; // tokens per sequence
+ const int n_s = dst->ne[2]; // number of sequences in the batch

- LM_GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == lm_ggml_nelements(dst));
+ LM_GGML_ASSERT( dst->ne[0] == nr);
  LM_GGML_ASSERT(src0->nb[0] == sizeof(float));
  LM_GGML_ASSERT(src1->nb[0] == sizeof(float));
- LM_GGML_ASSERT(src2->nb[0] == sizeof(float));
- LM_GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
  LM_GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
- // for use with the destination state offset between sequences
- LM_GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));

  // rows per thread
  const int dr = (nr + nth - 1)/nth;
@@ -15812,74 +15808,27 @@ static void lm_ggml_compute_forward_ssm_conv_f32(
  const int ir1 = MIN(ir0 + dr, nr);
  const int ir = ir1 - ir0;

- if (n_kv > 1) {
- // multiple sequences means it's hard to know when it's the first time a state is read,
- // so copy them all over to the destination, just to be sure.
- for (int i3 = 0; i3 < n_kv; ++i3) {
- float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
- float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
- // can't use memcpy because of d_conv vs d_conv - 1
- for (int i1 = 0; i1 < ir; ++i1) {
- for (int i0 = 0; i0 < nc - 1; ++i0) {
- // copy s0 to last (d_conv - 1) columns of s
- s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
- }
- }
- }
- }
-
- for (int i2 = 0; i2 < n_t; ++i2) {
- int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens}
- float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
- float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
- float * s0; // {d_conv - 1, d_inner, n_kv}
- float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
- float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
- int ne0s0;
-
- LM_GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
+ for (int i3 = 0; i3 < n_s; ++i3) {
+ for (int i2 = 0; i2 < n_t; ++i2) {
+ // {d_conv - 1 + n_t, d_inner, n_seqs}
+ // sliding window
+ const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
+ const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
+ float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}

- // avoid needing to copy the state for the first token
- if (i2 == 0) {
- s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv}
- ne0s0 = src0->ne[0];
- } else {
- // the source is the last (d_conv - 1) columns of the destination
- s0 = s + 1;
- ne0s0 = nc;
- }
-
- // d_inner
- for (int i1 = 0; i1 < ir; ++i1) {
- // shift state left
- for (int i0 = 0; i0 < nc - 1; ++i0) {
- s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
- }
- // insert x on the last column
- s[(nc - 1) + i1*nc] = x0[i1];
- }
-
- // handle copies when there are multiple output states
- for (int i3 = 1; i3 < n_kv; ++i3) {
- int32_t seq = sq[i3];
- if (0 <= seq && seq < n_kv) {
- float * s1 = s + (seq - sq[0])*nc*nr;
- memcpy(s1, s, nc*ir*sizeof(float));
- } else {
- // stop at negative or too big seq_ids
- break;
- }
- }
+ // TODO: transpose the output for smaller strides for big batches?
+ // d_inner
+ for (int i1 = 0; i1 < ir; ++i1) {
+ // rowwise dot product
+ // NOTE: not using lm_ggml_vec_dot_f32, because its sum is in double precision
+ float sumf = 0.0f;

- // it seems a little faster when this is separate from the state shift
- for (int i1 = 0; i1 < ir; ++i1) {
- // rowwise dot product
- float sumf = 0.0f;
- for (int i0 = 0; i0 < nc; ++i0) {
- int i = i0 + i1*nc;
- sumf += s[i] * c[i];
+ // d_conv
+ for (int i0 = 0; i0 < nc; ++i0) {
+ sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
+ }
+ x[i1] = sumf;
  }
- x[i1] = sumf;
  }
  }
  }
@@ -15910,15 +15859,14 @@ static void lm_ggml_compute_forward_ssm_scan_f32(
  const struct lm_ggml_tensor * src3 = dst->src[3]; // A
  const struct lm_ggml_tensor * src4 = dst->src[4]; // B
  const struct lm_ggml_tensor * src5 = dst->src[5]; // C
- const struct lm_ggml_tensor * src6 = dst->src[6]; // sq

  const int ith = params->ith;
  const int nth = params->nth;

- const int64_t nc = src0->ne[0]; // d_state
- const int64_t nr = src0->ne[1]; // d_inner
- const int64_t n_t = src1->ne[1]; // number of tokens in the batch
- const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch
+ const int64_t nc = src0->ne[0]; // d_state
+ const int64_t nr = src0->ne[1]; // d_inner
+ const int64_t n_t = src1->ne[1]; // number of tokens per sequence
+ const int64_t n_s = src0->ne[2]; // number of sequences in the batch

  LM_GGML_ASSERT(lm_ggml_nelements(src1) + lm_ggml_nelements(src0) == lm_ggml_nelements(dst));
  LM_GGML_ASSERT(src0->nb[0] == sizeof(float));
@@ -15927,12 +15875,12 @@ static void lm_ggml_compute_forward_ssm_scan_f32(
  LM_GGML_ASSERT(src3->nb[0] == sizeof(float));
  LM_GGML_ASSERT(src4->nb[0] == sizeof(float));
  LM_GGML_ASSERT(src5->nb[0] == sizeof(float));
- // required for the dot product between s and C, and when copying the states
+ // required for the dot product between s and C
  LM_GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
  // required for per-sequence offsets for states
  LM_GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
- // required to get correct offset for state destination (i.e. src1->nb[2])
- LM_GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));
+ // required to get correct offset for state destination (i.e. src1->nb[3])
+ LM_GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));

  // rows per thread
  const int dr = (nr + nth - 1)/nth;
@@ -15942,64 +15890,36 @@ static void lm_ggml_compute_forward_ssm_scan_f32(
  const int ir1 = MIN(ir0 + dr, nr);
  const int ir = ir1 - ir0;

- if (n_kv > 1) {
- // it's hard to know if the source states have already been copied
- // when there are multiple, so copy them already.
- for (int i3 = 0; i3 < n_kv; ++i3) {
- float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
- float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
- memcpy(s, s0, nc*ir*sizeof(float));
- }
- }
-
- for (int i2 = 0; i2 < n_t; ++i2) {
- int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens}
- float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
- float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv}
- float * s0;
- float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
- float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
- float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
- float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
- float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
-
- LM_GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
-
- // avoid needing to copy the state for the first token
- if (i2 == 0) {
- s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv}
- } else {
- // otherwise the source is the same as the destination
- s0 = s;
- }
-
- // d_inner
- for (int i1 = 0; i1 < ir; ++i1) {
- // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
- float x_dt = x[i1] * dt_soft_plus;
- float sumf = 0.0f;
- // d_state
- for (int i0 = 0; i0 < nc; ++i0) {
- int i = i0 + i1*nc;
- // state = prev_state * dA + dB * x
- float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
- // y = rowwise_dotprod(state, C)
- sumf += state * C[i0];
- s[i] = state;
- }
- y[i1] = sumf;
- }
-
- // handle copies when there are multiple output states
- for (int i3 = 1; i3 < n_kv; ++i3) {
- int32_t seq = sq[i3];
- if (0 <= seq && seq < n_kv) {
- float * s1 = s + (seq - sq[0])*nc*nr;
- memcpy(s1, s, nc*ir*sizeof(float));
- } else {
- // stop at negative or too big seq_ids
- break;
+ for (int i3 = 0; i3 < n_s; ++i3) {
+ for (int i2 = 0; i2 < n_t; ++i2) {
+ const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+ const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+ const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+ const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+ float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+ float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+
+ // use the output as the source for the next token-wise iterations
+ if (i2 > 0) { s0 = s; }
+
+ // d_inner
+ for (int i1 = 0; i1 < ir; ++i1) {
+ // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+ float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+ float x_dt = x[i1] * dt_soft_plus;
+ float sumf = 0.0f;
+ // d_state
+ for (int i0 = 0; i0 < nc; ++i0) {
+ int i = i0 + i1*nc;
+ // state = prev_state * dA + dB * x
+ float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+ // y = rowwise_dotprod(state, C)
+ sumf += state * C[i0];
+ s[i] = state;
+ }
+ y[i1] = sumf;
  }
  }
  }
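The rewritten scan kernel drops the sq sequence-id indirection and instead iterates sequences along the third dimension, reusing the state written into dst as the source for every token after the first. The per-element recurrence itself is unchanged: state = prev_state * exp(softplus(dt) * A) + B * (softplus(dt) * x), then y is the rowwise dot of state with C. A minimal single-token scalar sketch of that update (illustrative names, not the threaded ggml kernel):

    #include <math.h>

    // One selective-scan step for a single sequence and token.
    static void ssm_scan_step(int d_state, int d_inner,
                              float       * state, // [d_state * d_inner], updated in place
                              const float * x,     // [d_inner]
                              const float * dt,    // [d_inner]
                              const float * A,     // [d_state * d_inner]
                              const float * B,     // [d_state]
                              const float * C,     // [d_state]
                              float       * y) {   // [d_inner]
        for (int i1 = 0; i1 < d_inner; ++i1) {
            const float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
            const float x_dt = x[i1] * dt_soft_plus;
            float sumf = 0.0f;
            for (int i0 = 0; i0 < d_state; ++i0) {
                const int i = i0 + i1*d_state;
                const float s = state[i] * expf(dt_soft_plus * A[i]) + B[i0] * x_dt;
                sumf    += s * C[i0];
                state[i] = s;
            }
            y[i1] = sumf;
        }
    }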
package/cpp/ggml.h CHANGED
@@ -1760,7 +1760,8 @@ extern "C" {
  struct lm_ggml_tensor * v,
  struct lm_ggml_tensor * mask,
  float scale,
- float max_bias);
+ float max_bias,
+ float logit_softcap);

  LM_GGML_API void lm_ggml_flash_attn_ext_set_prec(
  struct lm_ggml_tensor * a,
@@ -1777,10 +1778,8 @@ extern "C" {

  LM_GGML_API struct lm_ggml_tensor * lm_ggml_ssm_conv(
  struct lm_ggml_context * ctx,
- struct lm_ggml_tensor * s,
- struct lm_ggml_tensor * x,
- struct lm_ggml_tensor * c,
- struct lm_ggml_tensor * sq);
+ struct lm_ggml_tensor * sx,
+ struct lm_ggml_tensor * c);

  LM_GGML_API struct lm_ggml_tensor * lm_ggml_ssm_scan(
  struct lm_ggml_context * ctx,
@@ -1789,8 +1788,7 @@ extern "C" {
  struct lm_ggml_tensor * dt,
  struct lm_ggml_tensor * A,
  struct lm_ggml_tensor * B,
- struct lm_ggml_tensor * C,
- struct lm_ggml_tensor * sq);
+ struct lm_ggml_tensor * C);

  // partition into non-overlapping windows with padding if needed
  // example:
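Call sites have to pass the extra flash-attention argument explicitly; 0.0f disables the cap and preserves the old behaviour. A hedged call-site sketch (tensor handles and the leading q/k arguments are placeholders based on the usual ggml signature, which is not shown in full in this diff):

    // Updated flash attention call: a zero logit_softcap matches the previous
    // two-parameter behaviour.
    struct lm_ggml_tensor * attn = lm_ggml_flash_attn_ext(ctx, q, k, v, kq_mask,
                                                          kq_scale, /*max_bias=*/0.0f, /*logit_softcap=*/0.0f);

    // ssm_conv now takes only the concatenated state+token tensor and the conv weight:
    struct lm_ggml_tensor * conv = lm_ggml_ssm_conv(ctx, conv_sx, conv_weight);

    // ssm_scan loses its trailing sequence-id tensor:
    struct lm_ggml_tensor * scan = lm_ggml_ssm_scan(ctx, ssm_state, x, dt, A, B, C);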
package/cpp/llama-impl.h CHANGED
@@ -31,11 +31,17 @@ void llama_log_callback_default(lm_ggml_log_level level, const char * text, void

  static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
  if (search.empty()) {
- return; // Avoid infinite loop if 'search' is an empty string
+ return;
  }
+ std::string builder;
+ builder.reserve(s.length());
  size_t pos = 0;
- while ((pos = s.find(search, pos)) != std::string::npos) {
- s.replace(pos, search.length(), replace);
- pos += replace.length();
+ size_t last_pos = 0;
+ while ((pos = s.find(search, last_pos)) != std::string::npos) {
+ builder.append(s, last_pos, pos - last_pos);
+ builder.append(replace);
+ last_pos = pos + search.length();
  }
+ builder.append(s, last_pos, std::string::npos);
+ s = std::move(builder);
  }
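The rewritten replace_all assembles the result in a separate builder string instead of repeatedly calling std::string::replace on s, so the input is scanned once and no characters are shifted per match; the observable result is unchanged. A small usage sketch (illustrative values only):

    #include <cassert>
    #include <string>

    // Expected behaviour of replace_all as defined above.
    static void replace_all_example() {
        std::string s = "hello world";
        replace_all(s, "o", "0");
        assert(s == "hell0 w0rld");

        std::string t = "aaa";
        replace_all(t, "a", "aa"); // inserted text is not rescanned
        assert(t == "aaaaaa");
    }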
@@ -175,9 +175,10 @@ void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array
  if(xtc_threshold <= 0.0f || !candidates-> size) {
  return;
  }
- // TODO: xtc impl
+
  bool xtc_applied = false;
  const int64_t t_start_sample_us = lm_ggml_time_us();
+ llama_sample_softmax(nullptr, candidates);

  // unsorted iteration
  if (!candidates->sorted) {
@@ -199,7 +200,7 @@ void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array

  // sort top_tokens
  std::sort(top_tokens.begin(), top_tokens.end(), [](const llama_token_data & a, const llama_token_data & b) {
- return a.logit > b.logit;
+ return a.logit < b.logit;
  });

  // insert top_tokens with probability. Always insert lowest top_token
@@ -232,30 +233,31 @@ void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array
  size_t last_index = 0;

  for (; last_index < candidates -> size; ++last_index) {
- if(candidates -> data[last_index].logit < xtc_threshold) {
+ if(candidates -> data[last_index].p < xtc_threshold) {
  break;
  }
  }
- last_index--;
- // check if only 1 last index token or total less than min_keep
- if(last_index <= 1 || candidates-> size - last_index < min_keep) {
+
+ // check if only 1 token above threshold
+ if(last_index <= 1) {
  return;
  }
- // indexes to be skipped
- size_t safe_index = 0;
+ last_index--;
+ // items beyond safe index will be ignored
+ size_t safe_index = candidates -> size;
+
  // remove tokens until last threshold item
- candidates -> data;
  std::uniform_real_distribution<float> random_float(0.0 , 1.0);
  for (size_t i = 0; i < last_index; i++) {
  if(random_float(rng) < xtc_probability) {
- if(i != safe_index) {
- std::swap(candidates-> data[i], candidates->data[safe_index]);
+ std::swap(candidates-> data[i], candidates->data[safe_index - 1]);
+ safe_index--;
+ if (candidates-> sorted) {
+ candidates -> sorted = false;
  }
- safe_index++;
  }
  }
- candidates -> data = candidates -> data + safe_index;
- candidates -> size = candidates -> size - safe_index;
+ candidates -> size = safe_index;
  }

  if (smpl) {
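The XTC ("exclude top choices") fixes above make the sampler compare softmax probabilities rather than raw logits against xtc_threshold (hence the added llama_sample_softmax call), and remove qualifying tokens by swapping them to the tail and shrinking candidates->size instead of advancing the data pointer. The intended rule, in outline: each candidate whose probability is at or above the threshold is dropped with probability xtc_probability, except the last such candidate, so at least one top choice always survives. A rough sketch of that rule over a plain vector (illustrative types and names, not the llama.cpp API):

    #include <random>
    #include <vector>

    struct Candidate { int id; float p; };

    // Candidates are assumed sorted by probability, descending.
    static void xtc_outline(std::vector<Candidate> & cands,
                            float xtc_threshold, float xtc_probability, std::mt19937 & rng) {
        size_t last_index = 0;
        while (last_index < cands.size() && cands[last_index].p >= xtc_threshold) {
            ++last_index;                             // count tokens above the threshold
        }
        if (last_index <= 1) {
            return;                                   // zero or one top choice: nothing to do
        }
        --last_index;                                 // always keep the lowest qualifying token
        std::uniform_real_distribution<float> coin(0.0f, 1.0f);
        size_t keep = cands.size();
        for (size_t i = 0; i < last_index; ++i) {
            if (coin(rng) < xtc_probability) {
                std::swap(cands[i], cands[keep - 1]); // park the removed token past the kept range
                --keep;
            }
        }
        cands.resize(keep);                           // mirrors candidates->size = safe_index
    }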