llama_cpp 0.14.7 → 0.15.0

@@ -951,7 +951,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
  #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
  #define GGML_F16_VEC_SET1 GGML_F16x8_SET1
  #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p)
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), r[i])
  #define GGML_F16_VEC_FMA GGML_F16x8_FMA
  #define GGML_F16_VEC_ADD GGML_F16x8_ADD
  #define GGML_F16_VEC_MUL GGML_F16x8_MUL
@@ -977,7 +977,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
  #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
  #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
  #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i])
  #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
  #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
  #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
@@ -1046,7 +1046,7 @@ do { \

  // unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
  // so F16C guard isn't required
- #define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x)))
+ #define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
  #define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))

  #define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
@@ -1144,7 +1144,7 @@ do { \

  #if defined(__F16C__)
  // the _mm256_cvt intrinsics require F16C
- #define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
+ #define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
  #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
  #else
  static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
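The four hunks above only tighten pointer casts: the generic F16 store macros now go through ggml_fp16_internal_t and the AVX loads take a const __m256i/__m128i pointer, so the code compiles cleanly when ggml_fp16_t is not the platform's native half type. The conversion path itself is unchanged. As a minimal, self-contained sketch (not ggml code; the function name is made up), this is the F16C round trip those macros wrap on x86:

// Sketch only: the FP16 <-> FP32 round trip behind GGML_F32Cx8_LOAD/STORE
// on x86 with F16C available (compile with e.g. -mavx -mf16c).
#include <immintrin.h>
#include <stdint.h>

static void scale_f16_block(uint16_t * h, float v) {
    // load 8 half-precision values and widen them to 8 floats
    __m256 x = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)h));
    x = _mm256_mul_ps(x, _mm256_set1_ps(v));
    // narrow back to half precision and store (0 = round to nearest)
    _mm_storeu_si128((__m128i *)h, _mm256_cvtps_ph(x, 0));
}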
@@ -1662,6 +1662,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
  #endif
  }

+ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
+ #if defined(GGML_SIMD)
+     const int np = (n & ~(GGML_F16_STEP - 1));
+
+     GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+     GGML_F16_VEC ax[GGML_F16_ARR];
+     GGML_F16_VEC ay[GGML_F16_ARR];
+
+     for (int i = 0; i < np; i += GGML_F16_STEP) {
+         for (int j = 0; j < GGML_F16_ARR; j++) {
+             ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+             ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+             ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+
+             GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+         }
+     }
+
+     // leftovers
+     for (int i = np; i < n; ++i) {
+         y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+     }
+ #else
+     // scalar
+     for (int i = 0; i < n; ++i) {
+         y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+     }
+ #endif
+ }
+
  // xs and vs are byte strides of x and v
  inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {

@@ -1746,6 +1777,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
  #endif
  }

+ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
+ #if defined(GGML_SIMD)
+     const int np = (n & ~(GGML_F16_STEP - 1));
+
+     GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+     GGML_F16_VEC ay[GGML_F16_ARR];
+
+     for (int i = 0; i < np; i += GGML_F16_STEP) {
+         for (int j = 0; j < GGML_F16_ARR; j++) {
+             ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+             ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
+
+             GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+         }
+     }
+
+     // leftovers
+     for (int i = np; i < n; ++i) {
+         y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+     }
+ #else
+     // scalar
+     for (int i = 0; i < n; ++i) {
+         y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+     }
+ #endif
+ }
+
  inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
  inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
  inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
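The two new helpers, ggml_vec_mad_f16 and ggml_vec_scale_f16, are exactly what the new CPU flash-attention kernel further down in this diff needs to keep its output accumulator in F16: scale rescales the accumulator when a larger score appears, mad adds one weighted value row. A hedged sketch of that update step, using the same names as the kernel and assuming these ggml-internal helpers plus <math.h> are visible:

// One step of the streaming (online) softmax-weighted accumulation.
//   s   : scaled, masked score of the current KV column
//   M,S : running maximum and running sum of exp(score - M)
//   V16 : F16 accumulator of length D, v16 : current value row
static void online_softmax_step(int D, ggml_fp16_t * V16, const ggml_fp16_t * v16,
                                float s, float * M, float * S) {
    float ms = 1.0f; // rescale factor for what was accumulated so far
    float vs = 1.0f; // weight of the incoming value row
    if (s > *M) {
        ms = expf(*M - s);              // exp(M_old - M_new)
        ggml_vec_scale_f16(D, V16, ms); // shrink previous contributions
        *M = s;
    } else {
        vs = expf(s - *M);
    }
    ggml_vec_mad_f16(D, V16, v16, vs);  // V16 += vs * v16
    *S = (*S)*ms + vs;                  // running normalizer
}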
@@ -2000,6 +2060,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
      "LEAKY_RELU",

      "FLASH_ATTN",
+     "FLASH_ATTN_EXT",
      "FLASH_FF",
      "FLASH_ATTN_BACK",
      "SSM_CONV",
@@ -2026,7 +2087,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
      "CROSS_ENTROPY_LOSS_BACK",
  };

- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+ static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");

  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "none",
@@ -2090,6 +2151,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "leaky_relu(x)",

      "flash_attn(x)",
+     "flash_attn_ext(x)",
      "flash_ff(x)",
      "flash_attn_back(x)",
      "ssm_conv(x)",
@@ -2116,7 +2178,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "cross_entropy_loss_back(x,y)",
  };

- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+ static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");

  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -4559,6 +4621,8 @@ struct ggml_tensor * ggml_mul_mat(
  void ggml_mul_mat_set_prec(
          struct ggml_tensor * a,
          enum ggml_prec prec) {
+     GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
+
      const int32_t prec_i32 = (int32_t) prec;

      ggml_set_op_params_i32(a, 0, prec_i32);
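ggml_mul_mat_set_prec now rejects tensors that were not produced by ggml_mul_mat. A hedged usage fragment (ctx, w and x are assumed to already exist):

struct ggml_tensor * kq = ggml_mul_mat(ctx, w, x);
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);   // fine: kq->op == GGML_OP_MUL_MAT
// ggml_mul_mat_set_prec(x, GGML_PREC_F32); // would now trip the new assert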
@@ -5397,17 +5461,23 @@ static struct ggml_tensor * ggml_soft_max_impl(
      GGML_ASSERT(ggml_is_contiguous(a));

      if (mask) {
+         GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
          GGML_ASSERT(ggml_is_contiguous(mask));
          GGML_ASSERT(ggml_is_matrix(mask));
-         GGML_ASSERT(ggml_can_repeat_rows(mask, a));
+         GGML_ASSERT(mask->ne[0] == a->ne[0]);
+         GGML_ASSERT(mask->ne[1] >= a->ne[1]);
      }

      if (pos) {
          GGML_ASSERT(ggml_is_vector(pos));
-         GGML_ASSERT(pos->type == GGML_TYPE_F32);
+         GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
          GGML_ASSERT(pos->ne[0] == a->ne[0]);
      }

+     if (pos && mask) {
+         GGML_ASSERT(pos->type == mask->type);
+     }
+
      if (max_bias > 0.0f) {
          GGML_ASSERT(pos);
      }
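The soft-max mask and position tensors may now be F16 as well as F32 (and must agree when both are given), and the old repeat-rows check becomes two explicit shape asserts. A hedged fragment that allocates a mask satisfying them (n_kv and n_tokens are placeholders; scores stands for the tensor being soft-maxed):

// mask->ne[0] must equal scores->ne[0]; mask->ne[1] must be >= scores->ne[1]
struct ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_kv, n_tokens);
// fill with 0 (visible) or -INFINITY (masked); it is added to the scaled logits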
@@ -6216,6 +6286,59 @@ struct ggml_tensor * ggml_flash_attn(
      return result;
  }

+ // ggml_flash_attn_ext
+
+ struct ggml_tensor * ggml_flash_attn_ext(
+         struct ggml_context * ctx,
+         struct ggml_tensor * q,
+         struct ggml_tensor * k,
+         struct ggml_tensor * v,
+         struct ggml_tensor * mask,
+         float scale) {
+     GGML_ASSERT(ggml_can_mul_mat(k, q));
+     // TODO: check if vT can be multiplied by (k*qT)
+     if (mask) {
+         GGML_ASSERT(ggml_is_contiguous(mask));
+         GGML_ASSERT(mask->ne[2] == 1);
+         GGML_ASSERT(mask->ne[3] == 1);
+         GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
+                 "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
+         //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
+     }
+
+     bool is_node = false;
+
+     if (q->grad || k->grad || v->grad) {
+         is_node = true;
+     }
+
+     // permute(0, 2, 1, 3)
+     int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+     float params[] = { scale };
+     ggml_set_op_params(result, params, sizeof(params));
+
+     result->op = GGML_OP_FLASH_ATTN_EXT;
+     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+     result->src[0] = q;
+     result->src[1] = k;
+     result->src[2] = v;
+     result->src[3] = mask;
+
+     return result;
+ }
+
+ void ggml_flash_attn_ext_set_prec(
+         struct ggml_tensor * a,
+         enum ggml_prec prec) {
+     GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+     const int32_t prec_i32 = (int32_t) prec;
+
+     ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
+ }
+
  // ggml_flash_ff

  struct ggml_tensor * ggml_flash_ff(
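A hedged usage sketch for the new op (sizes are placeholders; the F16 K/V requirement comes from the CPU kernel later in this diff, and GGML_KQ_MASK_PAD is defined in ggml.h):

// D = head size, N = n_tokens, n_kv = KV cache length
struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, D, N,    n_head,    n_batch);
struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, D, n_kv, n_head_kv, n_batch);
struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, D, n_kv, n_head_kv, n_batch);

// the mask must be padded up to GGML_KQ_MASK_PAD rows (see the assert above)
struct ggml_tensor * kq_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_kv, GGML_PAD(N, GGML_KQ_MASK_PAD));

struct ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, 1.0f/sqrtf((float) D));
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); // optional: force F32 accumulators
// cur is F32 with ne = { D, n_head, N, n_batch }, i.e. permute(0, 2, 1, 3) of q's layout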
@@ -12255,7 +12378,7 @@ static void ggml_compute_forward_soft_max_f32(

      GGML_TENSOR_UNARY_OP_LOCALS

-     const int64_t ne11 = src1 ? src1->ne[1] : 1;
+     //const int64_t ne11 = src1 ? src1->ne[1] : 1;

      // TODO: is this supposed to be ceil instead of floor?
      //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -12278,19 +12401,31 @@ static void ggml_compute_forward_soft_max_f32(
      float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;

      // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
-     float * pos = src2 ? (float *) src2->data : src0->data;
+     ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
+     float * pos_f32 = src2 ? (float *) src2->data : src0->data;
+
+     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);

      for (int i1 = ir0; i1 < ir1; i1++) {
          float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
          float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);

          // broadcast the mask across rows
-         float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
+         ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+         float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;

          ggml_vec_cpy_f32 (nc, wp, sp);
          ggml_vec_scale_f32(nc, wp, scale);
-         if (mp) {
-             ggml_vec_acc_f32(nc, wp, mp);
+         if (mp_f32) {
+             if (use_f16) {
+                 for (int i = 0; i < nc; ++i) {
+                     wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
+                 }
+             } else {
+                 for (int i = 0; i < nc; ++i) {
+                     wp[i] += mp_f32[i];
+                 }
+             }
          }

          // ALiBi bias
@@ -12298,8 +12433,14 @@ static void ggml_compute_forward_soft_max_f32(
              const uint32_t h = (i1/ne01)%ne02; // head
              const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);

-             for (int i = 0; i < nc; i++) {
-                 wp[i] = wp[i] + slope*pos[i];
+             if (use_f16) {
+                 for (int i = 0; i < nc; ++i) {
+                     wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
+                 }
+             } else {
+                 for (int i = 0; i < nc; ++i) {
+                     wp[i] += slope*pos_f32[i];
+                 }
              }
          }
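For reference, the per-head slope applied in the hunk above follows the ALiBi geometric schedule; in upstream ggml.c the bases m0 and m1 are derived from max_bias a few lines above this hunk (not shown here), roughly:

\mathrm{slope}(h) =
\begin{cases}
  m_0^{\,h+1},                              & h < n_{\mathrm{head\_log2}} \\
  m_1^{\,2(h - n_{\mathrm{head\_log2}})+1}, & \text{otherwise}
\end{cases}
\qquad
m_0 = 2^{-\mathrm{max\_bias}/n_{\mathrm{head\_log2}}},\quad
m_1 = 2^{-\mathrm{max\_bias}/(2\,n_{\mathrm{head\_log2}})}

Each bias term slope(h) * pos[i] is simply added to the scaled logits, with pos now read in either F16 or F32.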
@@ -14569,6 +14710,198 @@ static void ggml_compute_forward_flash_attn(
      }
  }

+ // ggml_compute_forward_flash_attn_ext
+
+ static void ggml_compute_forward_flash_attn_ext_f16(
+         const struct ggml_compute_params * params,
+         const struct ggml_tensor * q,
+         const struct ggml_tensor * k,
+         const struct ggml_tensor * v,
+         const struct ggml_tensor * mask,
+         struct ggml_tensor * dst) {
+     int64_t t0 = ggml_perf_time_us();
+     UNUSED(t0);
+
+     GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
+     GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
+     GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
+     GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
+     GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
+     GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
+     GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+     GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+     const int ith = params->ith;
+     const int nth = params->nth;
+
+     const int64_t D = neq0;
+     const int64_t N = neq1;
+
+     GGML_ASSERT(ne0 == D);
+     GGML_ASSERT(ne2 == N);
+
+     GGML_ASSERT(nbq0 == sizeof(float));
+     GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
+     GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
+
+     GGML_ASSERT(neq0 == D);
+     GGML_ASSERT(nek0 == D);
+     GGML_ASSERT(nev0 == D);
+
+     GGML_ASSERT(neq1 == N);
+     GGML_ASSERT(nev0 == D);
+
+     // dst cannot be transposed or permuted
+     GGML_ASSERT(nb0 == sizeof(float));
+     GGML_ASSERT(nb0 <= nb1);
+     GGML_ASSERT(nb1 <= nb2);
+     GGML_ASSERT(nb2 <= nb3);
+
+     // broadcast factors
+     const int64_t rk2 = neq2/nek2;
+     const int64_t rk3 = neq3/nek3;
+
+     const int64_t rv2 = neq2/nev2;
+     const int64_t rv3 = neq3/nev3;
+
+     if (params->type == GGML_TASK_TYPE_INIT) {
+         return;
+     }
+
+     if (params->type == GGML_TASK_TYPE_FINALIZE) {
+         return;
+     }
+
+     // parallelize by q rows using ggml_vec_dot_f32
+
+     // total rows in q
+     const int nr = neq1*neq2*neq3;
+
+     // rows per thread
+     const int dr = (nr + nth - 1)/nth;
+
+     // row range for this thread
+     const int ir0 = dr*ith;
+     const int ir1 = MIN(ir0 + dr, nr);
+
+     float scale = 1.0f;
+     memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+
+     // loop over n_batch and n_head
+     for (int ir = ir0; ir < ir1; ++ir) {
+         // q indices
+         const int iq3 = ir/(neq2*neq1);
+         const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+         const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+         float S = 0.0f;
+         float M = -INFINITY;
+
+         float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
+         ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
+         ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
+
+         memset(V16, 0, D*sizeof(ggml_fp16_t));
+
+         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+
+         // k indices
+         const int ik3 = iq3 / rk3;
+         const int ik2 = iq2 / rk2;
+
+         // v indices
+         const int iv3 = iq3 / rv3;
+         const int iv2 = iq2 / rv2;
+
+         // online softmax / attention
+         // loop over n_kv and n_head_kv
+         // ref: https://arxiv.org/pdf/2112.05682.pdf
+         for (int64_t ic = 0; ic < nek1; ++ic) {
+             const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+             if (mv == -INFINITY) {
+                 continue;
+             }
+
+             float s;
+
+             // convert Q to F16 in V32
+             {
+                 const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+
+                 for (int64_t d = 0; d < D; ++d) {
+                     Q16[d] = GGML_FP32_TO_FP16(pq[d]);
+                 }
+             }
+
+             ggml_vec_dot_f16(D,
+                     &s, 0,
+                     (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
+                     Q16, 0, 1);
+
+             s = s*scale + mv;
+
+             const float Mold = M;
+
+             float ms = 1.0f;
+             float vs = 1.0f;
+
+             if (s > M) {
+                 M = s;
+                 ms = expf(Mold - M);
+
+                 // V = V*expf(Mold - M)
+                 ggml_vec_scale_f16(D, V16, ms);
+             } else {
+                 vs = expf(s - M);
+             }
+
+             const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+
+             // V += v*expf(s - M)
+             ggml_vec_mad_f16(D, V16, v16, vs);
+
+             S = S*ms + vs;
+         }
+
+         // V /= S
+         for (int64_t d = 0; d < D; ++d) {
+             V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
+         }
+
+         // dst indices
+         const int i1 = iq1;
+         const int i2 = iq2;
+         const int i3 = iq3;
+
+         // original
+         //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
+
+         // permute(0, 2, 1, 3)
+         memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
+     }
+ }
+
+ static void ggml_compute_forward_flash_attn_ext(
+         const struct ggml_compute_params * params,
+         const struct ggml_tensor * q,
+         const struct ggml_tensor * k,
+         const struct ggml_tensor * v,
+         const struct ggml_tensor * mask,
+         struct ggml_tensor * dst) {
+     switch (dst->op_params[1]) {
+         case GGML_PREC_DEFAULT:
+         case GGML_PREC_F32:
+             {
+                 // uses F32 accumulators
+                 ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+             } break;
+         default:
+             {
+                 GGML_ASSERT(false);
+             } break;
+     }
+ }
+
  // ggml_compute_forward_flash_ff

  static void ggml_compute_forward_flash_ff_f16(
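In the notation of the kernel above (s_i is the scaled, masked score of KV column i and v_i its value row), the inner loop maintains the streaming-softmax recurrences from the paper it cites:

M_i = \max(M_{i-1}, s_i), \qquad M_0 = -\infty,\; S_0 = 0,\; V_0 = 0
V_i = V_{i-1}\, e^{M_{i-1}-M_i} + v_i\, e^{s_i - M_i}
S_i = S_{i-1}\, e^{M_{i-1}-M_i} + e^{s_i - M_i}
\text{row output} = \frac{V_n}{S_n} = \frac{\sum_i v_i\, e^{s_i}}{\sum_i e^{s_i}} = \mathrm{softmax}(s) \cdot V

In the code, ms = e^{M_{i-1}-M_i} and vs = e^{s_i-M_i}; fully masked columns (mv == -INFINITY) are skipped, and the division by S happens once per row when V16 is converted back to F32.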
@@ -16376,6 +16709,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
              const bool masked = t != 0;
              ggml_compute_forward_flash_attn(params, masked, tensor);
          } break;
+         case GGML_OP_FLASH_ATTN_EXT:
+             {
+                 ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+             } break;
          case GGML_OP_FLASH_FF:
              {
                  ggml_compute_forward_flash_ff(params, tensor);
@@ -17388,6 +17725,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                  GGML_ASSERT(false); // TODO: not implemented
              } break;
          case GGML_OP_FLASH_ATTN:
+         case GGML_OP_FLASH_ATTN_EXT:
              {
                  struct ggml_tensor * flash_grad = NULL;
                  if (src0->grad || src1->grad || tensor->src[2]->grad) {
@@ -18160,6 +18498,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
                  n_tasks = n_threads;
              } break;
          case GGML_OP_FLASH_ATTN:
+         case GGML_OP_FLASH_ATTN_EXT:
              {
                  n_tasks = n_threads;
              } break;
@@ -18563,6 +18902,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                      cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
                  }
              } break;
+         case GGML_OP_FLASH_ATTN_EXT:
+             {
+                 const int64_t ne00 = node->src[0]->ne[0]; // D
+
+                 cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
+             } break;
          case GGML_OP_FLASH_FF:
              {
                  if (node->src[1]->type == GGML_TYPE_F32) {
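The work-buffer estimate matches how the kernel carves up params->wdata: each thread gets a D-float output row whose front bytes are reused as the F16 query buffer, followed by a D-half accumulator, for 2*D floats of scratch per thread (for example, D = 128 and 8 threads is 2*4*128*8 = 8192 bytes); the CACHE_LINE_SIZE_F32 term in the kernel's indexing keeps the threads on separate cache lines. A restatement of the layout, copied from the kernel above:

// per-thread scratch, as indexed by ggml_compute_forward_flash_attn_ext_f16 (D = head size)
float       * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32); // D floats, also reused as Q16
ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);                                 // D halves, running accumulator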
@@ -20614,7 +20959,7 @@ static void gguf_free_kv(struct gguf_kv * kv) {
  }

  struct gguf_context * gguf_init_empty(void) {
-     struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context));
+     struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context));

      memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic));
      ctx->header.version = GGUF_VERSION;
@@ -20659,7 +21004,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p

      bool ok = true;

-     struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context));
+     struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context));

      // read the header
      {
@@ -20696,9 +21041,13 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p

      // read the kv pairs
      {
-         ctx->kv = GGML_MALLOC(ctx->header.n_kv * sizeof(struct gguf_kv));
+         const uint64_t n_kv = ctx->header.n_kv;

-         for (uint64_t i = 0; i < ctx->header.n_kv; ++i) {
+         // header.n_kv will hold the actual value of pairs that were successfully read in the loop below
+         ctx->header.n_kv = 0;
+         ctx->kv = GGML_CALLOC(n_kv, sizeof(struct gguf_kv));
+
+         for (uint64_t i = 0; i < n_kv; ++i) {
              struct gguf_kv * kv = &ctx->kv[i];

              //fprintf(stderr, "%s: reading kv %d\n", __func__, i);
@@ -20747,7 +21096,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
                              return NULL;
                          }

-                         kv->value.arr.data = GGML_MALLOC(kv->value.arr.n * gguf_type_size(kv->value.arr.type));
+                         kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, gguf_type_size(kv->value.arr.type));

                          ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type), &offset);
                      } break;
@@ -20761,7 +21110,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
                              return NULL;
                          }

-                         kv->value.arr.data = GGML_MALLOC(kv->value.arr.n * sizeof(struct gguf_str));
+                         kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, sizeof(struct gguf_str));

                          for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
                              ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset);
@@ -20777,6 +21126,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
              if (!ok) {
                  break;
              }
+
+             ctx->header.n_kv++;
          }

          if (!ok) {
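The allocation changes above (GGML_ALIGNED_MALLOC/GGML_MALLOC replaced by GGML_CALLOC, and header.n_kv counted up only as pairs are successfully parsed) are about error paths: a partially read context now contains only NULL/zero fields, so gguf_free can release it safely. A minimal standalone sketch of the pattern (plain C, not ggml code):

#include <stdlib.h>

struct kv { char * key; char * val; };

static void kv_free(struct kv * p, size_t n) {
    if (!p) return;
    for (size_t i = 0; i < n; ++i) {
        free(p[i].key); // free(NULL) is a no-op, so unpopulated entries are fine
        free(p[i].val);
    }
    free(p);
}

static struct kv * kv_load(size_t n) {
    struct kv * p = calloc(n, sizeof(*p)); // every key/val starts out NULL
    if (!p) return NULL;
    for (size_t i = 0; i < n; ++i) {
        p[i].key = malloc(16);
        if (!p[i].key) { kv_free(p, n); return NULL; } // partial state is still freeable
    }
    return p;
}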
@@ -20789,7 +21140,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p

      // read the tensor infos
      {
-         ctx->infos = GGML_MALLOC(ctx->header.n_tensors * sizeof(struct gguf_tensor_info));
+         ctx->infos = GGML_CALLOC(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));

          for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
              struct gguf_tensor_info * info = &ctx->infos[i];
@@ -20810,8 +21161,17 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
              ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset);
              ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset);

+             // TODO: return an error instead of crashing with GGML_ASSERT
              gguf_tensor_info_sanitize(info);

+             // make sure there is no duplicated tensor names
+             for (uint64_t j = 0; j < i; ++j) {
+                 if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) {
+                     fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data);
+                     ok = false;
+                 }
+             }
+
              if (!ok) {
                  fprintf(stderr, "%s: failed to read tensor info\n", __func__);
                  fclose(file);
@@ -20980,7 +21340,7 @@ void gguf_free(struct gguf_context * ctx) {
          GGML_FREE(ctx->infos);
      }

-     GGML_ALIGNED_FREE(ctx);
+     GGML_FREE(ctx);
  }

  const char * gguf_type_name(enum gguf_type type) {
@@ -21291,7 +21651,7 @@ void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_ty
      ctx->kv[idx].type = GGUF_TYPE_ARRAY;
      ctx->kv[idx].value.arr.type = type;
      ctx->kv[idx].value.arr.n = n;
-     ctx->kv[idx].value.arr.data = GGML_MALLOC(n*gguf_type_size(type));
+     ctx->kv[idx].value.arr.data = GGML_CALLOC(n, gguf_type_size(type));
      memcpy(ctx->kv[idx].value.arr.data, data, n*gguf_type_size(type));
  }

@@ -21301,7 +21661,7 @@ void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char **
      ctx->kv[idx].type = GGUF_TYPE_ARRAY;
      ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING;
      ctx->kv[idx].value.arr.n = n;
-     ctx->kv[idx].value.arr.data = GGML_MALLOC(n*sizeof(struct gguf_str));
+     ctx->kv[idx].value.arr.data = GGML_CALLOC(n, sizeof(struct gguf_str));
      for (int i = 0; i < n; i++) {
          struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i];
          str->n = strlen(data[i]);
@@ -21328,7 +21688,7 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
          case GGUF_TYPE_ARRAY:
              {
                  if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) {
-                     const char ** data = GGML_MALLOC(src->kv[i].value.arr.n*sizeof(char *));
+                     const char ** data = GGML_CALLOC(src->kv[i].value.arr.n, sizeof(char *));
                      for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) {
                          data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data;
                      }
@@ -21348,6 +21708,10 @@ void gguf_add_tensor(
  void gguf_add_tensor(
          struct gguf_context * ctx,
          const struct ggml_tensor * tensor) {
+     if (gguf_find_tensor(ctx, tensor->name) != -1) {
+         GGML_ASSERT(false && "duplicated tensor name");
+     }
+
      const int idx = ctx->header.n_tensors;
      ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info));
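Since gguf_add_tensor now asserts on duplicate names, writers that may encounter the same tensor twice can check first. A hedged fragment (gctx and t are assumed to exist; gguf_find_tensor is the lookup used by the new check itself):

if (gguf_find_tensor(gctx, t->name) == -1) {
    gguf_add_tensor(gctx, t);
} else {
    fprintf(stderr, "skipping duplicate tensor %s\n", t->name);
}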
@@ -21416,7 +21780,7 @@ struct gguf_buf {

  static struct gguf_buf gguf_buf_init(size_t size) {
      struct gguf_buf buf = {
-         /*buf.data =*/ size == 0 ? NULL : GGML_MALLOC(size),
+         /*buf.data =*/ size == 0 ? NULL : GGML_CALLOC(1, size),
          /*buf.size =*/ size,
          /*buf.offset =*/ 0,
      };