llama_cpp 0.14.0 → 0.14.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -470,6 +470,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
  .type_size = sizeof(int32_t),
  .is_quantized = false,
  },
+ [GGML_TYPE_I64] = {
+ .type_name = "i64",
+ .blck_size = 1,
+ .type_size = sizeof(int64_t),
+ .is_quantized = false,
+ },
+ [GGML_TYPE_F64] = {
+ .type_name = "f64",
+ .blck_size = 1,
+ .type_size = sizeof(double),
+ .is_quantized = false,
+ .nrows = 1,
+ },
  [GGML_TYPE_F32] = {
  .type_name = "f32",
  .blck_size = 1,
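Note: the new I64 and F64 entries are plain scalar types (block size 1), so their per-row byte count follows the same rule ggml applies to every type. A minimal sketch of that relationship, using hypothetical standalone names rather than the real type_traits table:

    #include <stdint.h>
    #include <stdio.h>

    // sketch: bytes needed for one row of ne elements,
    // mirroring row_size = ne/blck_size * type_size (assumes ne % blck_size == 0)
    static size_t example_row_size(size_t type_size, int64_t blck_size, int64_t ne) {
        return (size_t)(ne/blck_size)*type_size;
    }

    int main(void) {
        // scalar types like the new i64/f64 entries use blck_size == 1
        printf("%zu\n", example_row_size(sizeof(int64_t), 1, 4096)); // 32768
        printf("%zu\n", example_row_size(sizeof(double),  1, 4096)); // 32768
        return 0;
    }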
@@ -857,7 +870,7 @@ inline static float vaddvq_f32(float32x4_t v) {
  #define GGML_F16x8 float16x8_t
  #define GGML_F16x8_ZERO vdupq_n_f16(0.0f)
  #define GGML_F16x8_SET1(x) vdupq_n_f16(x)
- #define GGML_F16x8_LOAD(x) vld1q_f16((const __fp16 *)(x))
+ #define GGML_F16x8_LOAD(x) vld1q_f16((const ggml_fp16_internal_t *)(x))
  #define GGML_F16x8_STORE vst1q_f16
  #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
  #define GGML_F16x8_ADD vaddq_f16
@@ -900,7 +913,7 @@ inline static float vaddvq_f32(float32x4_t v) {
  #define GGML_F32Cx4 float32x4_t
  #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f)
  #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x)
- #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const __fp16 *)(x)))
+ #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const ggml_fp16_internal_t *)(x)))
  #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y))
  #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
  #define GGML_F32Cx4_ADD vaddq_f32
@@ -1841,6 +1854,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
  "FLASH_ATTN",
  "FLASH_FF",
  "FLASH_ATTN_BACK",
+ "SSM_CONV",
+ "SSM_SCAN",
  "WIN_PART",
  "WIN_UNPART",
  "GET_REL_POS",
@@ -1863,7 +1878,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
  "CROSS_ENTROPY_LOSS_BACK",
  };

- static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+ static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");

  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "none",
@@ -1929,6 +1944,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "flash_attn(x)",
  "flash_ff(x)",
  "flash_attn_back(x)",
+ "ssm_conv(x)",
+ "ssm_scan(x)",
  "win_part(x)",
  "win_unpart(x)",
  "get_rel_pos(x)",
@@ -1951,7 +1968,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "cross_entropy_loss_back(x,y)",
  };

- static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+ static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");

  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

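Note: both op name tables are guarded by a static_assert on GGML_OP_COUNT, so adding SSM_CONV and SSM_SCAN means bumping the expected count from 74 to 76 in both places. A reduced sketch of the same compile-time guard pattern, with hypothetical names:

    #include <assert.h>

    enum example_op {
        EXAMPLE_OP_NONE,
        EXAMPLE_OP_SSM_CONV,
        EXAMPLE_OP_SSM_SCAN,
        EXAMPLE_OP_COUNT,
    };

    static const char * EXAMPLE_OP_NAME[EXAMPLE_OP_COUNT] = {
        "NONE",
        "SSM_CONV",
        "SSM_SCAN",
    };

    // fails to compile if an op is added without extending the table above
    static_assert(EXAMPLE_OP_COUNT == 3, "EXAMPLE_OP_COUNT != 3");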
@@ -6154,6 +6171,108 @@ struct ggml_tensor * ggml_flash_attn_back(
  return result;
  }

+ // ggml_ssm_conv
+
+ struct ggml_tensor * ggml_ssm_conv(
+ struct ggml_context * ctx,
+ struct ggml_tensor * s,
+ struct ggml_tensor * x,
+ struct ggml_tensor * c,
+ struct ggml_tensor * sq) {
+ GGML_ASSERT(ggml_is_3d(s));
+ GGML_ASSERT(ggml_is_matrix(x));
+ GGML_ASSERT(ggml_is_matrix(c));
+ GGML_ASSERT(ggml_is_matrix(sq));
+ GGML_ASSERT(sq->type == GGML_TYPE_I32);
+
+ const int64_t d_conv = c->ne[0];
+ const int64_t d_inner = c->ne[1];
+ const int64_t n_tokens = x->ne[1];
+ const int64_t n_kv = s->ne[2];
+
+ GGML_ASSERT( s->ne[0] == d_conv - 1);
+ GGML_ASSERT( s->ne[1] == d_inner);
+ GGML_ASSERT( x->ne[0] == d_inner);
+ GGML_ASSERT(sq->ne[0] == n_kv);
+ GGML_ASSERT(sq->ne[1] == n_tokens);
+
+ bool is_node = false;
+
+ if (s->grad || x->grad || c->grad || sq->grad) {
+ GGML_ASSERT(false); // TODO: implement
+ is_node = true;
+ }
+
+ // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv}
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv));
+
+ result->op = GGML_OP_SSM_CONV;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = s;
+ result->src[1] = x;
+ result->src[2] = c;
+ result->src[3] = sq;
+
+ return result;
+ }
+
+ // ggml_ssm_scan
+
+ struct ggml_tensor * ggml_ssm_scan(
+ struct ggml_context * ctx,
+ struct ggml_tensor * s,
+ struct ggml_tensor * x,
+ struct ggml_tensor * dt,
+ struct ggml_tensor * A,
+ struct ggml_tensor * B,
+ struct ggml_tensor * C,
+ struct ggml_tensor * sq) {
+ GGML_ASSERT(ggml_is_contiguous(s));
+ GGML_ASSERT(ggml_is_contiguous(x));
+ GGML_ASSERT(ggml_is_contiguous(dt));
+ GGML_ASSERT(ggml_is_contiguous(A));
+ GGML_ASSERT(sq->type == GGML_TYPE_I32);
+ GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
+ GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
+ GGML_ASSERT(ggml_are_same_shape(x, dt));
+
+ {
+ const int64_t d_state = s->ne[0];
+ const int64_t d_inner = s->ne[1];
+ const int64_t n_tokens = x->ne[1];
+
+ GGML_ASSERT(x->ne[0] == d_inner);
+ GGML_ASSERT(A->ne[0] == d_state);
+ GGML_ASSERT(A->ne[1] == d_inner);
+ GGML_ASSERT(B->ne[0] == d_state);
+ GGML_ASSERT(B->ne[1] == n_tokens);
+ GGML_ASSERT(C->ne[0] == d_state);
+ GGML_ASSERT(C->ne[1] == n_tokens);
+ }
+
+ bool is_node = false;
+
+ if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) {
+ GGML_ASSERT(false); // TODO: implement
+ is_node = true;
+ }
+
+ // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
+
+ result->op = GGML_OP_SSM_SCAN;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = s;
+ result->src[1] = x;
+ result->src[2] = dt;
+ result->src[3] = A;
+ result->src[4] = B;
+ result->src[5] = C;
+ result->src[6] = sq;
+
+ return result;
+ }
+
  // ggml_win_part

  struct ggml_tensor * ggml_win_part(
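Note: ggml_ssm_scan only builds the graph node and packs the per-token outputs y together with the updated ssm_states into one 1-D result tensor; the recurrence itself is computed later in ggml_compute_forward_ssm_scan_f32. A simplified, single-channel sketch of that recurrence in plain C (hypothetical helper name, no ggml types), following the same math as the kernel below:

    #include <math.h>

    // One token of the selective-scan update for a single inner channel:
    //   state[i] = state[i] * exp(dt' * A[i]) + B[i] * (dt' * x)   for i over d_state
    //   y        = sum_i state[i] * C[i]
    // where dt' is softplus(dt), as in the kernel.
    static float ssm_scan_step(float * state, const float * A,
                               const float * B, const float * C,
                               float x, float dt, int d_state) {
        const float dt_soft_plus = dt <= 20.0f ? log1pf(expf(dt)) : dt;
        const float x_dt = x * dt_soft_plus;
        float y = 0.0f;
        for (int i = 0; i < d_state; ++i) {
            state[i] = state[i]*expf(dt_soft_plus*A[i]) + B[i]*x_dt;
            y += state[i]*C[i];
        }
        return y;
    }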
@@ -11454,8 +11573,6 @@ static void ggml_compute_forward_get_rows_q(
  const struct ggml_tensor * src0 = dst->src[0];
  const struct ggml_tensor * src1 = dst->src[1];

- assert(params->ith == 0);
-
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
  return;
  }
@@ -11463,7 +11580,7 @@ static void ggml_compute_forward_get_rows_q(
  GGML_TENSOR_BINARY_OP_LOCALS

  const int64_t nc = ne00;
- const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
+ const int64_t nr = ggml_nelements(src1);

  const enum ggml_type type = src0->type;
  ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
@@ -11473,17 +11590,25 @@ static void ggml_compute_forward_get_rows_q(
  assert(nb00 == ggml_type_size(type));
  assert(ggml_nrows(dst) == nr);

- // TODO: multi-thread
- for (int64_t i12 = 0; i12 < ne12; ++i12) {
- for (int64_t i11 = 0; i11 < ne11; ++i11) {
- for (int64_t i10 = 0; i10 < ne10; ++i10) {
- const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+ const int ith = params->ith;
+ const int nth = params->nth;

- dequantize_row_q(
- (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
- (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
- }
- }
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int64_t i = ir0; i < ir1; ++i) {
+ const int64_t i12 = i/(ne11*ne10);
+ const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+ const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+ dequantize_row_q(
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
  }
  }

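Note: the get_rows kernels now use the same work-splitting idiom found elsewhere in ggml.c: the flattened row index range is divided into ceil(nr/nth) chunks and each thread handles one contiguous chunk. A small self-contained sketch of the split (hypothetical names):

    #include <stdio.h>

    #define EXAMPLE_MIN(a, b) ((a) < (b) ? (a) : (b))

    int main(void) {
        const int nr  = 10; // total rows to gather
        const int nth = 4;  // number of threads
        const int dr  = (nr + nth - 1)/nth; // rows per thread, rounded up
        for (int ith = 0; ith < nth; ++ith) {
            const int ir0 = dr*ith;
            const int ir1 = EXAMPLE_MIN(ir0 + dr, nr);
            // the last threads may receive fewer (or zero) rows
            printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
        }
        return 0;
    }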
@@ -11494,8 +11619,6 @@ static void ggml_compute_forward_get_rows_f16(
  const struct ggml_tensor * src0 = dst->src[0];
  const struct ggml_tensor * src1 = dst->src[1];

- assert(params->ith == 0);
-
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
  return;
  }
@@ -11503,24 +11626,32 @@ static void ggml_compute_forward_get_rows_f16(
  GGML_TENSOR_BINARY_OP_LOCALS

  const int64_t nc = ne00;
- const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
+ const int64_t nr = ggml_nelements(src1);

  assert(ne0 == nc);
  assert(ne02 == ne11);
  assert(nb00 == sizeof(ggml_fp16_t));
  assert(ggml_nrows(dst) == nr);

- // TODO: multi-thread
- for (int64_t i12 = 0; i12 < ne12; ++i12) {
- for (int64_t i11 = 0; i11 < ne11; ++i11) {
- for (int64_t i10 = 0; i10 < ne10; ++i10) {
- const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;

- ggml_fp16_to_fp32_row(
- (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
- (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
- }
- }
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int64_t i = ir0; i < ir1; ++i) {
+ const int64_t i12 = i/(ne11*ne10);
+ const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+ const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+ ggml_fp16_to_fp32_row(
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
  }
  }

@@ -11531,8 +11662,6 @@ static void ggml_compute_forward_get_rows_f32(
  const struct ggml_tensor * src0 = dst->src[0];
  const struct ggml_tensor * src1 = dst->src[1];

- assert(params->ith == 0);
-
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
  return;
  }
@@ -11540,24 +11669,32 @@ static void ggml_compute_forward_get_rows_f32(
  GGML_TENSOR_BINARY_OP_LOCALS

  const int64_t nc = ne00;
- const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
+ const int64_t nr = ggml_nelements(src1);

  assert(ne0 == nc);
  assert(ne02 == ne11);
  assert(nb00 == sizeof(float));
  assert(ggml_nrows(dst) == nr);

- // TODO: multi-thread
- for (int64_t i12 = 0; i12 < ne12; ++i12) {
- for (int64_t i11 = 0; i11 < ne11; ++i11) {
- for (int64_t i10 = 0; i10 < ne10; ++i10) {
- const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+ const int ith = params->ith;
+ const int nth = params->nth;

- ggml_vec_cpy_f32(nc,
- (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
- (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
- }
- }
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int64_t i = ir0; i < ir1; ++i) {
+ const int64_t i12 = i/(ne11*ne10);
+ const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+ const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+ ggml_vec_cpy_f32(nc,
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
+ (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
  }
  }

@@ -12294,6 +12431,8 @@ static void ggml_compute_forward_alibi(
  case GGML_TYPE_I8:
  case GGML_TYPE_I16:
  case GGML_TYPE_I32:
+ case GGML_TYPE_I64:
+ case GGML_TYPE_F64:
  case GGML_TYPE_COUNT:
  {
  GGML_ASSERT(false);
@@ -12380,6 +12519,8 @@ static void ggml_compute_forward_clamp(
  case GGML_TYPE_I8:
  case GGML_TYPE_I16:
  case GGML_TYPE_I32:
+ case GGML_TYPE_I64:
+ case GGML_TYPE_F64:
  case GGML_TYPE_COUNT:
  {
  GGML_ASSERT(false);
@@ -14771,6 +14912,257 @@ static void ggml_compute_forward_flash_attn_back(
  }
  }

+ // ggml_compute_forward_ssm_conv
+
+ static void ggml_compute_forward_ssm_conv_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+ return;
+ }
+
+ const struct ggml_tensor * src0 = dst->src[0]; // conv_state
+ const struct ggml_tensor * src1 = dst->src[1]; // x
+ const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
+ const struct ggml_tensor * src3 = dst->src[3]; // state_seq
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src2->ne[0]; // d_conv
+ const int nr = src0->ne[1]; // d_inner
+ const int n_t = src1->ne[1]; // n_tokens
+ const int n_kv = src0->ne[2]; // max number of sequences in the batch
+
+ GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
+ GGML_ASSERT(src2->nb[0] == sizeof(float));
+ GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
+ GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
+ // for use with the destination state offset between sequences
+ GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+ const int ir = ir1 - ir0;
+
+ if (n_kv > 1) {
+ // multiple sequences means it's hard to know when it's the first time a state is read,
+ // so copy them all over to the destination, just to be sure.
+ for (int i3 = 0; i3 < n_kv; ++i3) {
+ float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
+ float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
+ // can't use memcpy because of d_conv vs d_conv - 1
+ for (int i1 = 0; i1 < ir; ++i1) {
+ for (int i0 = 0; i0 < nc - 1; ++i0) {
+ // copy s0 to last (d_conv - 1) columns of s
+ s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
+ }
+ }
+ }
+ }
+
+ for (int i2 = 0; i2 < n_t; ++i2) {
+ int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens}
+ float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
+ float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
+ float * s0; // {d_conv - 1, d_inner, n_kv}
+ float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
+ float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
+ int ne0s0;
+
+ GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
+
+ // avoid needing to copy the state for the first token
+ if (i2 == 0) {
+ s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv}
+ ne0s0 = src0->ne[0];
+ } else {
+ // the source is the last (d_conv - 1) columns of the destination
+ s0 = s + 1;
+ ne0s0 = nc;
+ }
+
+ // d_inner
+ for (int i1 = 0; i1 < ir; ++i1) {
+ // shift state left
+ for (int i0 = 0; i0 < nc - 1; ++i0) {
+ s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
+ }
+ // insert x on the last column
+ s[(nc - 1) + i1*nc] = x0[i1];
+ }
+
+ // handle copies when there are multiple output states
+ for (int i3 = 1; i3 < n_kv; ++i3) {
+ int32_t seq = sq[i3];
+ if (0 <= seq && seq < n_kv) {
+ float * s1 = s + (seq - sq[0])*nc*nr;
+ memcpy(s1, s, nc*ir*sizeof(float));
+ } else {
+ // stop at negative or too big seq_ids
+ break;
+ }
+ }
+
+ // it seems a little faster when this is separate from the state shift
+ for (int i1 = 0; i1 < ir; ++i1) {
+ // rowwise dot product
+ float sumf = 0.0f;
+ for (int i0 = 0; i0 < nc; ++i0) {
+ int i = i0 + i1*nc;
+ sumf += s[i] * c[i];
+ }
+ x[i1] = sumf;
+ }
+ }
+ }
+
+ static void ggml_compute_forward_ssm_conv(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ switch (dst->src[0]->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_ssm_conv_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+ }
+
+ // ggml_compute_forward_ssm_scan
+
+ static void ggml_compute_forward_ssm_scan_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+ return;
+ }
+
+ const struct ggml_tensor * src0 = dst->src[0]; // s
+ const struct ggml_tensor * src1 = dst->src[1]; // x
+ const struct ggml_tensor * src2 = dst->src[2]; // dt
+ const struct ggml_tensor * src3 = dst->src[3]; // A
+ const struct ggml_tensor * src4 = dst->src[4]; // B
+ const struct ggml_tensor * src5 = dst->src[5]; // C
+ const struct ggml_tensor * src6 = dst->src[6]; // sq
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t nc = src0->ne[0]; // d_state
+ const int64_t nr = src0->ne[1]; // d_inner
+ const int64_t n_t = src1->ne[1]; // number of tokens in the batch
+ const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch
+
+ GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
+ GGML_ASSERT(src2->nb[0] == sizeof(float));
+ GGML_ASSERT(src3->nb[0] == sizeof(float));
+ GGML_ASSERT(src4->nb[0] == sizeof(float));
+ GGML_ASSERT(src5->nb[0] == sizeof(float));
+ // required for the dot product between s and C, and when copying the states
+ GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
+ // required for per-sequence offsets for states
+ GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
+ // required to get correct offset for state destination (i.e. src1->nb[2])
+ GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+ const int ir = ir1 - ir0;
+
+ if (n_kv > 1) {
+ // it's hard to know if the source states have already been copied
+ // when there are multiple, so copy them already.
+ for (int i3 = 0; i3 < n_kv; ++i3) {
+ float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
+ float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
+ memcpy(s, s0, nc*ir*sizeof(float));
+ }
+ }
+
+ for (int i2 = 0; i2 < n_t; ++i2) {
+ int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens}
+ float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
+ float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv}
+ float * s0;
+ float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
+ float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
+ float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+ float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
+ float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
+
+ GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
+
+ // avoid needing to copy the state for the first token
+ if (i2 == 0) {
+ s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv}
+ } else {
+ // otherwise the source is the same as the destination
+ s0 = s;
+ }
+
+ // d_inner
+ for (int i1 = 0; i1 < ir; ++i1) {
+ // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+ float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+ float x_dt = x[i1] * dt_soft_plus;
+ float sumf = 0.0f;
+ // d_state
+ for (int i0 = 0; i0 < nc; ++i0) {
+ int i = i0 + i1*nc;
+ // state = prev_state * dA + dB * x
+ float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+ // y = rowwise_dotprod(state, C)
+ sumf += state * C[i0];
+ s[i] = state;
+ }
+ y[i1] = sumf;
+ }
+
+ // handle copies when there are multiple output states
+ for (int i3 = 1; i3 < n_kv; ++i3) {
+ int32_t seq = sq[i3];
+ if (0 <= seq && seq < n_kv) {
+ float * s1 = s + (seq - sq[0])*nc*nr;
+ memcpy(s1, s, nc*ir*sizeof(float));
+ } else {
+ // stop at negative or too big seq_ids
+ break;
+ }
+ }
+ }
+ }
+
+ static void ggml_compute_forward_ssm_scan(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ switch (dst->src[0]->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_ssm_scan_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+ }
+
  // ggml_compute_forward_win_part

  static void ggml_compute_forward_win_part_f32(
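Note: the conv kernel above keeps, per inner channel, a rolling window of the last d_conv inputs: the stored state is shifted left by one column, the new x value is appended, and the result is a dot product with that channel's conv weights. A minimal single-channel sketch of that update (plain C, hypothetical function name):

    // Rolling-window update for one channel: s has d_conv slots, c holds the conv weights.
    static float ssm_conv_step(float * s, const float * c, float x, int d_conv) {
        // shift the stored window left by one position
        for (int i = 0; i < d_conv - 1; ++i) {
            s[i] = s[i + 1];
        }
        // insert the new input on the last column
        s[d_conv - 1] = x;
        // rowwise dot product with the conv weights
        float sum = 0.0f;
        for (int i = 0; i < d_conv; ++i) {
            sum += s[i]*c[i];
        }
        return sum;
    }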
@@ -15830,6 +16222,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
  bool masked = t != 0;
  ggml_compute_forward_flash_attn_back(params, masked, tensor);
  } break;
+ case GGML_OP_SSM_CONV:
+ {
+ ggml_compute_forward_ssm_conv(params, tensor);
+ } break;
+ case GGML_OP_SSM_SCAN:
+ {
+ ggml_compute_forward_ssm_scan(params, tensor);
+ } break;
  case GGML_OP_WIN_PART:
  {
  ggml_compute_forward_win_part(params, tensor);
@@ -16884,6 +17284,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  {
  GGML_ASSERT(false); // not supported
  } break;
+ case GGML_OP_SSM_CONV:
+ case GGML_OP_SSM_SCAN:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
  case GGML_OP_WIN_PART:
  case GGML_OP_WIN_UNPART:
  case GGML_OP_UNARY:
@@ -17426,7 +17831,7 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const
  node->perf_time_us += time_us_cur;
  }

- static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
+ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_threads) {
  int n_tasks = 0;

  switch (node->op) {
@@ -17507,6 +17912,12 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
  {
  n_tasks = n_threads;
  } break;
+ case GGML_OP_GET_ROWS:
+ {
+ // FIXME: the cost of launching additional threads decreases performance with GPU offloading
+ //n_tasks = MIN(n_threads, ggml_nelements(node->src[1]));
+ n_tasks = MIN(n_cur_threads, ggml_nelements(node->src[1]));
+ } break;
  case GGML_OP_SCALE:
  case GGML_OP_SET:
  case GGML_OP_CONT:
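Note: GET_ROWS now gets its own task count, capped both by the number of rows to gather and by the number of threads currently running, so tiny gathers (common with GPU offloading) do not pay for extra idle threads. A trivial sketch of that cap with hypothetical names:

    #define EXAMPLE_MIN(a, b) ((a) < (b) ? (a) : (b))

    // n_rows: elements of the index tensor (one gathered row each)
    // n_cur_threads: threads already spun up by the compute loop
    static int example_get_rows_n_tasks(int n_rows, int n_cur_threads) {
        return EXAMPLE_MIN(n_cur_threads, n_rows);
    }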
@@ -17514,7 +17925,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
  case GGML_OP_VIEW:
  case GGML_OP_PERMUTE:
  case GGML_OP_TRANSPOSE:
- case GGML_OP_GET_ROWS:
  case GGML_OP_GET_ROWS_BACK:
  case GGML_OP_DIAG:
  {
@@ -17590,6 +18000,11 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
  {
  n_tasks = n_threads;
  } break;
+ case GGML_OP_SSM_CONV:
+ case GGML_OP_SSM_SCAN:
+ {
+ n_tasks = n_threads;
+ } break;
  case GGML_OP_WIN_PART:
  case GGML_OP_WIN_UNPART:
  case GGML_OP_GET_REL_POS:
@@ -17727,7 +18142,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
  /* FINALIZE */
  struct ggml_tensor * node = cgraph->nodes[node_n];
  if (GGML_OP_HAS_FINALIZE[node->op]) {
- params.nth = ggml_get_n_tasks(node, n_threads);
+ params.nth = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
  ggml_compute_forward(&params, node);
  }
  ggml_graph_compute_perf_stats_node(node, state->shared);
@@ -17737,7 +18152,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
  while (++node_n < cgraph->n_nodes) {
  GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes);
  struct ggml_tensor * node = cgraph->nodes[node_n];
- const int n_tasks = ggml_get_n_tasks(node, n_threads);
+ const int n_tasks = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);

  state->shared->perf_node_start_cycles = ggml_perf_cycles();
  state->shared->perf_node_start_time_us = ggml_perf_time_us();
@@ -17785,7 +18200,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {

  /* INIT & COMPUTE */
  struct ggml_tensor * node = cgraph->nodes[node_n];
- const int n_tasks = ggml_get_n_tasks(node, n_threads);
+ const int n_tasks = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);

  struct ggml_compute_params params = {
  /*.type =*/ GGML_TASK_TYPE_INIT,
@@ -17850,7 +18265,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
  for (int i = 0; i < cgraph->n_nodes; i++) {
  struct ggml_tensor * node = cgraph->nodes[i];

- const int n_tasks = ggml_get_n_tasks(node, n_threads);
+ const int n_tasks = ggml_get_n_tasks(node, n_threads, 1);

  max_tasks = MAX(max_tasks, n_tasks);

@@ -19784,133 +20199,6 @@ void ggml_quantize_free(void) {
  ggml_critical_section_end();
  }

- size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
- assert(k % QK4_0 == 0);
- const int nb = k / QK4_0;
-
- for (int b = 0; b < n; b += k) {
- block_q4_0 * restrict y = (block_q4_0 *) dst + b/QK4_0;
-
- quantize_row_q4_0_reference(src + b, y, k);
-
- for (int i = 0; i < nb; i++) {
- for (int j = 0; j < QK4_0; j += 2) {
- const uint8_t vi0 = y[i].qs[j/2] & 0x0F;
- const uint8_t vi1 = y[i].qs[j/2] >> 4;
-
- hist[vi0]++;
- hist[vi1]++;
- }
- }
- }
-
- return (n/QK4_0*sizeof(block_q4_0));
- }
-
- size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
- assert(k % QK4_1 == 0);
- const int nb = k / QK4_1;
-
- for (int b = 0; b < n; b += k) {
- block_q4_1 * restrict y = (block_q4_1 *) dst + b/QK4_1;
-
- quantize_row_q4_1_reference(src + b, y, k);
-
- for (int i = 0; i < nb; i++) {
- for (int j = 0; j < QK4_1; j += 2) {
- const uint8_t vi0 = y[i].qs[j/2] & 0x0F;
- const uint8_t vi1 = y[i].qs[j/2] >> 4;
-
- hist[vi0]++;
- hist[vi1]++;
- }
- }
- }
-
- return (n/QK4_1*sizeof(block_q4_1));
- }
-
- size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
- assert(k % QK5_0 == 0);
- const int nb = k / QK5_0;
-
- for (int b = 0; b < n; b += k) {
- block_q5_0 * restrict y = (block_q5_0 *)dst + b/QK5_0;
-
- quantize_row_q5_0_reference(src + b, y, k);
-
- for (int i = 0; i < nb; i++) {
- uint32_t qh;
- memcpy(&qh, &y[i].qh, sizeof(qh));
-
- for (int j = 0; j < QK5_0; j += 2) {
- const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4;
- const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12));
-
- // cast to 16 bins
- const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2;
- const uint8_t vi1 = ((y[i].qs[j/2] >> 4) | vh1) / 2;
-
- hist[vi0]++;
- hist[vi1]++;
- }
- }
- }
-
- return (n/QK5_0*sizeof(block_q5_0));
- }
-
- size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) {
- assert(k % QK5_1 == 0);
- const int nb = k / QK5_1;
-
- for (int b = 0; b < n; b += k) {
- block_q5_1 * restrict y = (block_q5_1 *)dst + b/QK5_1;
-
- quantize_row_q5_1_reference(src + b, y, k);
-
- for (int i = 0; i < nb; i++) {
- uint32_t qh;
- memcpy(&qh, &y[i].qh, sizeof(qh));
-
- for (int j = 0; j < QK5_1; j += 2) {
- const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4;
- const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12));
-
- // cast to 16 bins
- const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2;
- const uint8_t vi1 = ((y[i].qs[j/2] >> 4) | vh1) / 2;
-
- hist[vi0]++;
- hist[vi1]++;
- }
- }
- }
-
- return (n/QK5_1*sizeof(block_q5_1));
- }
-
- size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) {
- assert(k % QK8_0 == 0);
- const int nb = k / QK8_0;
-
- for (int b = 0; b < n; b += k) {
- block_q8_0 * restrict y = (block_q8_0 *)dst + b/QK8_0;
-
- quantize_row_q8_0_reference(src + b, y, k);
-
- for (int i = 0; i < nb; i++) {
- for (int j = 0; j < QK8_0; ++j) {
- const int8_t vi = y[i].qs[j];
-
- hist[vi/16 + 8]++;
- }
- }
- }
-
- return (n/QK8_0*sizeof(block_q8_0));
- }
-
  bool ggml_quantize_requires_imatrix(enum ggml_type type) {
  return
  type == GGML_TYPE_IQ2_XXS ||
@@ -19918,177 +20206,52 @@ bool ggml_quantize_requires_imatrix(enum ggml_type type) {
  type == GGML_TYPE_IQ1_S;
  }

- size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start,
- int nrows, int n_per_row, int64_t * hist, const float * imatrix) {
+ size_t ggml_quantize_chunk(
+ enum ggml_type type,
+ const float * src,
+ void * dst,
+ int start,
+ int nrows,
+ int n_per_row,
+ const float * imatrix) {
+ const int n = nrows * n_per_row;
+
+ if (ggml_quantize_requires_imatrix(type)) {
+ GGML_ASSERT(imatrix != NULL);
+ }
+
+ GGML_ASSERT(start % type_traits[type].blck_size == 0);
+ GGML_ASSERT(start % n_per_row == 0);
+
  ggml_quantize_init(type); // this is noop if already initialized
+
+ const size_t start_row = start / n_per_row;
+ const size_t row_size = ggml_row_size(type, n_per_row);
+
  size_t result = 0;
- int n = nrows * n_per_row;
+
  switch (type) {
- case GGML_TYPE_Q4_0:
- {
- GGML_ASSERT(start % QK4_0 == 0);
- GGML_ASSERT(start % n_per_row == 0);
- size_t start_row = start / n_per_row;
- size_t row_size = ggml_row_size(type, n_per_row);
- result = quantize_q4_0(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
- GGML_ASSERT(result == row_size * nrows);
- } break;
- case GGML_TYPE_Q4_1:
- {
- GGML_ASSERT(start % QK4_1 == 0);
- GGML_ASSERT(start % n_per_row == 0);
- size_t start_row = start / n_per_row;
- size_t row_size = ggml_row_size(type, n_per_row);
- result = quantize_q4_1(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
- GGML_ASSERT(result == row_size * nrows);
- } break;
- case GGML_TYPE_Q5_0:
- {
- GGML_ASSERT(start % QK5_0 == 0);
- GGML_ASSERT(start % n_per_row == 0);
- size_t start_row = start / n_per_row;
- size_t row_size = ggml_row_size(type, n_per_row);
- result = quantize_q5_0(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
- GGML_ASSERT(result == row_size * nrows);
- } break;
- case GGML_TYPE_Q5_1:
- {
- GGML_ASSERT(start % QK5_1 == 0);
- GGML_ASSERT(start % n_per_row == 0);
- size_t start_row = start / n_per_row;
- size_t row_size = ggml_row_size(type, n_per_row);
- result = quantize_q5_1(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
- GGML_ASSERT(result == row_size * nrows);
- } break;
- case GGML_TYPE_Q8_0:
- {
- GGML_ASSERT(start % QK8_0 == 0);
- block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
- result = ggml_quantize_q8_0(src + start, block, n, n, hist);
- } break;
- case GGML_TYPE_Q2_K:
- {
- GGML_ASSERT(start % QK_K == 0);
- GGML_ASSERT(start % n_per_row == 0);
- size_t start_row = start / n_per_row;
- size_t row_size = ggml_row_size(type, n_per_row);
- result = quantize_q2_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
- GGML_ASSERT(result == row_size * nrows);
- } break;
- case GGML_TYPE_Q3_K:
- {
- GGML_ASSERT(start % QK_K == 0);
- GGML_ASSERT(start % n_per_row == 0);
- size_t start_row = start / n_per_row;
- size_t row_size = ggml_row_size(type, n_per_row);
- result = quantize_q3_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
- GGML_ASSERT(result == row_size * nrows);
- } break;
- case GGML_TYPE_Q4_K:
- {
- GGML_ASSERT(start % QK_K == 0);
- GGML_ASSERT(start % n_per_row == 0);
- size_t start_row = start / n_per_row;
- size_t row_size = ggml_row_size(type, n_per_row);
- result = quantize_q4_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
- GGML_ASSERT(result == row_size * nrows);
- } break;
- case GGML_TYPE_Q5_K:
- {
- GGML_ASSERT(start % QK_K == 0);
- GGML_ASSERT(start % n_per_row == 0);
- size_t start_row = start / n_per_row;
- size_t row_size = ggml_row_size(type, n_per_row);
- result = quantize_q5_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
- GGML_ASSERT(result == row_size * nrows);
- } break;
- case GGML_TYPE_Q6_K:
- {
- GGML_ASSERT(start % QK_K == 0);
- GGML_ASSERT(start % n_per_row == 0);
- size_t start_row = start / n_per_row;
- size_t row_size = ggml_row_size(type, n_per_row);
- result = quantize_q6_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
- GGML_ASSERT(result == row_size * nrows);
- } break;
- case GGML_TYPE_IQ2_XXS:
- {
- GGML_ASSERT(start % QK_K == 0);
- GGML_ASSERT(start % n_per_row == 0);
- GGML_ASSERT(imatrix);
- size_t start_row = start / n_per_row;
- size_t row_size = ggml_row_size(type, n_per_row);
- result = quantize_iq2_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
- GGML_ASSERT(result == row_size * nrows);
- } break;
- case GGML_TYPE_IQ2_XS:
- {
- GGML_ASSERT(start % QK_K == 0);
- GGML_ASSERT(start % n_per_row == 0);
- GGML_ASSERT(imatrix);
- size_t start_row = start / n_per_row;
- size_t row_size = ggml_row_size(type, n_per_row);
- result = quantize_iq2_xs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
- GGML_ASSERT(result == row_size * nrows);
- } break;
- case GGML_TYPE_IQ3_XXS:
- {
- GGML_ASSERT(start % QK_K == 0);
- GGML_ASSERT(start % n_per_row == 0);
- size_t start_row = start / n_per_row;
- size_t row_size = ggml_row_size(type, n_per_row);
- result = quantize_iq3_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
- GGML_ASSERT(result == row_size * nrows);
- } break;
- case GGML_TYPE_IQ3_S:
- {
- GGML_ASSERT(start % QK_K == 0);
- GGML_ASSERT(start % n_per_row == 0);
- size_t start_row = start / n_per_row;
- size_t row_size = ggml_row_size(type, n_per_row);
- result = quantize_iq3_s(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
- GGML_ASSERT(result == row_size * nrows);
- } break;
- case GGML_TYPE_IQ2_S:
- {
- GGML_ASSERT(start % QK_K == 0);
- GGML_ASSERT(start % n_per_row == 0);
- size_t start_row = start / n_per_row;
- size_t row_size = ggml_row_size(type, n_per_row);
- result = quantize_iq2_s(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
- GGML_ASSERT(result == row_size * nrows);
- } break;
- case GGML_TYPE_IQ1_S:
- {
- GGML_ASSERT(start % QK_K == 0);
- GGML_ASSERT(start % n_per_row == 0);
- size_t start_row = start / n_per_row;
- size_t row_size = ggml_row_size(type, n_per_row);
- result = quantize_iq1_s(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
- GGML_ASSERT(result == row_size * nrows);
- } break;
- case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q5_K: result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ3_S: result = quantize_iq3_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ2_S: result = quantize_iq2_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ1_S: result = quantize_iq1_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
  #if QK_K == 64
- case GGML_TYPE_IQ4_XS:
- #endif
- {
- GGML_ASSERT(start % QK4_NL == 0);
- GGML_ASSERT(start % n_per_row == 0);
- size_t start_row = start / n_per_row;
- size_t row_size = ggml_row_size(type, n_per_row);
- result = quantize_iq4_nl(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
- GGML_ASSERT(result == row_size * nrows);
- } break;
- #if QK_K != 64
- case GGML_TYPE_IQ4_XS:
- {
- GGML_ASSERT(start % QK_K == 0);
- GGML_ASSERT(start % n_per_row == 0);
- size_t start_row = start / n_per_row;
- size_t row_size = ggml_row_size(type, n_per_row);
- result = quantize_iq4_xs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
- GGML_ASSERT(result == row_size * nrows);
- } break;
+ case GGML_TYPE_IQ4_XS: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ #else
+ case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
  #endif
  case GGML_TYPE_F16:
  {
@@ -20105,6 +20268,9 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
  default:
  assert(false);
  }
+
+ GGML_ASSERT(result == nrows * row_size);
+
  return result;
  }
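Note: the rewritten ggml_quantize_chunk hoists the per-case boilerplate: it drops the histogram argument, asserts once that start is a multiple of the type's block size and of n_per_row, and computes the destination offset a single time from the starting element index. A small sketch of that offset arithmetic (hypothetical names, assumes a fixed bytes-per-row value):

    #include <stdio.h>

    // byte offset into the quantized destination buffer for a chunk that
    // starts at element index `start` of a matrix with n_per_row elements per row
    static size_t example_chunk_dst_offset(size_t start, size_t n_per_row, size_t row_size_bytes) {
        const size_t start_row = start / n_per_row; // start % n_per_row == 0 is asserted upstream
        return start_row * row_size_bytes;
    }

    int main(void) {
        // e.g. rows of 4096 elements quantized into 32-element, 18-byte blocks -> 2304 bytes per row
        printf("%zu\n", example_chunk_dst_offset(8192, 4096, 2304)); // prints 4608
        return 0;
    }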