llama_cpp 0.14.7 → 0.15.1

@@ -322,7 +322,7 @@ static ggml_fp16_t ggml_table_exp_f16[1 << 16];
322
322
  // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
323
323
  float ggml_table_f32_f16[1 << 16];
324
324
 
325
- const char * ggml_status_to_string(enum ggml_status status) {
325
+ GGML_CALL const char * ggml_status_to_string(enum ggml_status status) {
326
326
  switch (status) {
327
327
  case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
328
328
  case GGML_STATUS_FAILED: return "GGML status: error (operation failed)";
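Note: the bulk of this diff adds BF16 (bfloat16) support to the vendored ggml CPU code and introduces the GGML_OP_FLASH_ATTN_EXT operator; the hunk above only adds the GGML_CALL calling-convention annotation to ggml_status_to_string. A minimal, hedged usage sketch of the status helper (GGML_STATUS_SUCCESS and the "ggml.h" include path are assumptions taken from the vendored header, not shown in this hunk):

    #include "ggml.h"    // vendored header; include path is an assumption
    #include <stdio.h>

    // Hypothetical error-reporting helper: turn a ggml_status into a log line.
    static int check_status(enum ggml_status st, const char * what) {
        if (st != GGML_STATUS_SUCCESS) {
            fprintf(stderr, "%s: %s\n", what, ggml_status_to_string(st));
            return 1;
        }
        return 0;
    }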
@@ -333,16 +333,26 @@ const char * ggml_status_to_string(enum ggml_status status) {
333
333
  return "GGML status: unknown";
334
334
  }
335
335
 
336
- // note: do not use these inside ggml.c
337
- // these are meant to be used via the ggml.h API
338
336
  float ggml_fp16_to_fp32(ggml_fp16_t x) {
337
+ #define ggml_fp16_to_fp32 do_not_use__ggml_fp16_to_fp32__in_ggml
339
338
  return GGML_FP16_TO_FP32(x);
340
339
  }
341
340
 
342
341
  ggml_fp16_t ggml_fp32_to_fp16(float x) {
342
+ #define ggml_fp32_to_fp16 do_not_use__ggml_fp32_to_fp16__in_ggml
343
343
  return GGML_FP32_TO_FP16(x);
344
344
  }
345
345
 
346
+ float ggml_bf16_to_fp32(ggml_bf16_t x) {
347
+ #define ggml_bf16_to_fp32 do_not_use__ggml_bf16_to_fp32__in_ggml
348
+ return GGML_BF16_TO_FP32(x); // it just left shifts
349
+ }
350
+
351
+ ggml_bf16_t ggml_fp32_to_bf16(float x) {
352
+ #define ggml_fp32_to_bf16 do_not_use__ggml_fp32_to_bf16__in_ggml
353
+ return GGML_FP32_TO_BF16(x);
354
+ }
355
+
346
356
  void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
347
357
  for (int64_t i = 0; i < n; i++) {
348
358
  y[i] = GGML_FP16_TO_FP32(x[i]);
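Note: bf16 is simply the upper 16 bits of an IEEE-754 binary32 value, which is why the comment above says GGML_BF16_TO_FP32 "just left shifts". A standalone sketch of the idea, using plain truncation on the float-to-bf16 side (ggml's actual GGML_FP32_TO_BF16 macro may additionally round and special-case NaN):

    #include <stdint.h>
    #include <string.h>

    typedef struct { uint16_t bits; } bf16_t;   // stand-in for ggml_bf16_t

    static float bf16_to_f32(bf16_t h) {
        uint32_t u = (uint32_t)h.bits << 16;    // bf16 occupies the high half of a float
        float f;
        memcpy(&f, &u, sizeof(f));
        return f;
    }

    static bf16_t f32_to_bf16(float f) {        // truncating variant, for illustration only
        uint32_t u;
        memcpy(&u, &f, sizeof(u));
        bf16_t h = { (uint16_t)(u >> 16) };
        return h;
    }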
@@ -368,6 +378,49 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
368
378
  }
369
379
  }
370
380
 
381
+ void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
382
+ int64_t i = 0;
383
+ #if defined(__AVX512F__)
384
+ for (; i + 16 <= n; i += 16) {
385
+ _mm512_storeu_ps(y + i,
386
+ _mm512_castsi512_ps(
387
+ _mm512_slli_epi32(
388
+ _mm512_cvtepu16_epi32(
389
+ _mm256_loadu_si256(
390
+ (const __m256i *)(x + i))),
391
+ 16)));
392
+ }
393
+ #elif defined(__AVX2__)
394
+ for (; i + 8 <= n; i += 8) {
395
+ _mm256_storeu_ps(y + i,
396
+ _mm256_castsi256_ps(
397
+ _mm256_slli_epi32(
398
+ _mm256_cvtepu16_epi32(
399
+ _mm_loadu_si128(
400
+ (const __m128i *)(x + i))),
401
+ 16)));
402
+ }
403
+ #endif
404
+ for (; i < n; i++) {
405
+ y[i] = GGML_BF16_TO_FP32(x[i]);
406
+ }
407
+ }
408
+
409
+ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
410
+ int i = 0;
411
+ #if defined(__AVX512BF16__)
412
+ for (; i + 32 <= n; i += 32) {
413
+ _mm512_storeu_ps(
414
+ (__m512 *)(y + i),
415
+ (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
416
+ _mm512_loadu_ps(x + i)));
417
+ }
418
+ #endif
419
+ for (; i < n; i++) {
420
+ y[i] = GGML_FP32_TO_BF16(x[i]);
421
+ }
422
+ }
423
+
371
424
  bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
372
425
  return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
373
426
  }
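Note: the new row converters are the bulk interface used elsewhere in the file; the AVX-512/AVX2 paths just widen 16-bit lanes to 32 bits and shift them into the high half, i.e. the same left shift done 8 or 16 elements at a time, and the scalar tail handles the remainder. A hedged usage sketch (it assumes these functions are exported through the vendored ggml.h, which this diff does not show):

    #include "ggml.h"
    #include <stdlib.h>

    // Convert one bf16 row of n elements to a temporary f32 buffer; caller frees.
    static float * bf16_row_as_f32(const ggml_bf16_t * row, int64_t n) {
        float * tmp = malloc((size_t)n * sizeof(float));
        if (tmp) {
            ggml_bf16_to_fp32_row(row, tmp, n);  // vectorized when AVX2/AVX-512 is available
        }
        return tmp;
    }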
@@ -503,6 +556,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
503
556
 
504
557
  static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
505
558
  static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
559
+ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
506
560
 
507
561
  static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
508
562
  [GGML_TYPE_I8] = {
@@ -845,6 +899,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
845
899
  .type_size = sizeof(block_q8_K),
846
900
  .is_quantized = true,
847
901
  .from_float = quantize_row_q8_K,
902
+ },
903
+ [GGML_TYPE_BF16] = {
904
+ .type_name = "bf16",
905
+ .blck_size = 1,
906
+ .type_size = sizeof(ggml_bf16_t),
907
+ .is_quantized = false,
908
+ .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row,
909
+ .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row,
910
+ .from_float_reference = (ggml_from_float_t) ggml_fp32_to_bf16_row,
911
+ .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
912
+ .vec_dot_type = GGML_TYPE_BF16,
913
+ .nrows = 1,
848
914
  }
849
915
  };
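Note: the new type_traits entry is what lets generic code treat BF16 like any other element type: to_float/from_float dispatch to the row converters added above and vec_dot to the new bf16 dot product. A hedged sketch using the accessor named in the surrounding context (ggml_internal_get_type_traits):

    #include "ggml.h"

    // Convert one row of any supported type to f32 via the traits table.
    // For GGML_TYPE_BF16 this ends up in ggml_bf16_to_fp32_row.
    static void row_to_f32(enum ggml_type type, const void * row, float * dst, int64_t n) {
        ggml_type_traits_t tt = ggml_internal_get_type_traits(type);
        if (tt.to_float) {
            tt.to_float(row, dst, n);
        }
    }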
850
916
 
@@ -951,7 +1017,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
951
1017
  #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
952
1018
  #define GGML_F16_VEC_SET1 GGML_F16x8_SET1
953
1019
  #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p)
954
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
1020
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), r[i])
955
1021
  #define GGML_F16_VEC_FMA GGML_F16x8_FMA
956
1022
  #define GGML_F16_VEC_ADD GGML_F16x8_ADD
957
1023
  #define GGML_F16_VEC_MUL GGML_F16x8_MUL
@@ -977,7 +1043,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
977
1043
  #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
978
1044
  #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
979
1045
  #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
980
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
1046
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i])
981
1047
  #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
982
1048
  #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
983
1049
  #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
@@ -1046,7 +1112,7 @@ do { \
1046
1112
 
1047
1113
  // unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
1048
1114
  // so F16C guard isn't required
1049
- #define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x)))
1115
+ #define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
1050
1116
  #define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))
1051
1117
 
1052
1118
  #define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
@@ -1144,7 +1210,7 @@ do { \
1144
1210
 
1145
1211
  #if defined(__F16C__)
1146
1212
  // the _mm256_cvt intrinsics require F16C
1147
- #define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
1213
+ #define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
1148
1214
  #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
1149
1215
  #else
1150
1216
  static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
@@ -1480,6 +1546,8 @@ inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) {
1480
1546
 
1481
1547
  inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1482
1548
 
1549
+ inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1550
+
1483
1551
  inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
1484
1552
  inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
1485
1553
  inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
@@ -1498,7 +1566,7 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float *
1498
1566
  UNUSED(by);
1499
1567
  UNUSED(bs);
1500
1568
 
1501
- #ifdef GGML_SIMD
1569
+ #if defined(GGML_SIMD)
1502
1570
  float sumf = 0.0f;
1503
1571
  const int np = (n & ~(GGML_F32_STEP - 1));
1504
1572
 
@@ -1534,6 +1602,70 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float *
1534
1602
  *s = sumf;
1535
1603
  }
1536
1604
 
1605
+ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) {
1606
+ assert(nrc == 1);
1607
+ UNUSED(nrc);
1608
+ UNUSED(bx);
1609
+ UNUSED(by);
1610
+ UNUSED(bs);
1611
+ int i = 0;
1612
+ ggml_float sumf = 0;
1613
+
1614
+ #if defined(__AVX512BF16__)
1615
+ __m512 c1 = _mm512_setzero_ps();
1616
+ __m512 c2 = _mm512_setzero_ps();
1617
+ for (; i + 64 <= n; i += 64) {
1618
+ c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
1619
+ (__m512bh)_mm512_loadu_ps((const float *)(y + i)));
1620
+ c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
1621
+ (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
1622
+ }
1623
+ sumf += (ggml_float)_mm512_reduce_add_ps(c1);
1624
+ sumf += (ggml_float)_mm512_reduce_add_ps(c2);
1625
+
1626
+ #elif defined(__AVX512F__)
1627
+ #define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
1628
+ __m512 c1 = _mm512_setzero_ps();
1629
+ __m512 c2 = _mm512_setzero_ps();
1630
+ for (; i + 32 <= n; i += 32) {
1631
+ c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
1632
+ c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
1633
+ }
1634
+ sumf += (ggml_float)_mm512_reduce_add_ps(c1);
1635
+ sumf += (ggml_float)_mm512_reduce_add_ps(c2);
1636
+
1637
+ #undef LOAD
1638
+ #elif defined(__AVX2__)
1639
+ #define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
1640
+ __m256 c1 = _mm256_setzero_ps();
1641
+ __m256 c2 = _mm256_setzero_ps();
1642
+ __m256 c3 = _mm256_setzero_ps();
1643
+ __m256 c4 = _mm256_setzero_ps();
1644
+ for (; i + 32 <= n; i += 32) {
1645
+ c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
1646
+ c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
1647
+ c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
1648
+ c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
1649
+ }
1650
+ __m128 g;
1651
+ c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
1652
+ _mm256_add_ps(c2, c4));
1653
+ g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
1654
+ _mm256_castps256_ps128(c1));
1655
+ g = _mm_add_ps(g, _mm_movehl_ps(g, g));
1656
+ g = _mm_add_ss(g, _mm_movehdup_ps(g));
1657
+ sumf += (ggml_float)_mm_cvtss_f32(g);
1658
+
1659
+ #undef LOAD
1660
+ #endif
1661
+
1662
+ for (; i < n; ++i) {
1663
+ sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
1664
+ GGML_BF16_TO_FP32(y[i]));
1665
+ }
1666
+ *s = sumf;
1667
+ }
1668
+
1537
1669
  static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
1538
1670
  assert(nrc == 1);
1539
1671
  UNUSED(nrc);
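Note: ggml_vec_dot_bf16 keeps two (AVX-512) or four (AVX2) independent accumulators so back-to-back multiply-adds do not stall on one another, then folds them with a horizontal add; the scalar tail loop is the reference semantics. A self-contained scalar sketch of that reference, with bf16 modelled as a raw uint16_t and accumulation in double (which is what ggml_float typically is):

    #include <stdint.h>
    #include <string.h>
    #include <stddef.h>

    static float bf16_bits_to_f32(uint16_t bits) {
        uint32_t u = (uint32_t)bits << 16;
        float f;
        memcpy(&f, &u, sizeof(f));
        return f;
    }

    // Reference dot product: sum_i x[i]*y[i] with bf16 inputs and wide accumulation.
    static float dot_bf16_ref(size_t n, const uint16_t * x, const uint16_t * y) {
        double sum = 0.0;
        for (size_t i = 0; i < n; ++i) {
            sum += (double)bf16_bits_to_f32(x[i]) * (double)bf16_bits_to_f32(y[i]);
        }
        return (float)sum;
    }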
@@ -1662,6 +1794,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
1662
1794
  #endif
1663
1795
  }
1664
1796
 
1797
+ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
1798
+ #if defined(GGML_SIMD)
1799
+ const int np = (n & ~(GGML_F16_STEP - 1));
1800
+
1801
+ GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
1802
+
1803
+ GGML_F16_VEC ax[GGML_F16_ARR];
1804
+ GGML_F16_VEC ay[GGML_F16_ARR];
1805
+
1806
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
1807
+ for (int j = 0; j < GGML_F16_ARR; j++) {
1808
+ ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
1809
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
1810
+ ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
1811
+
1812
+ GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
1813
+ }
1814
+ }
1815
+
1816
+ // leftovers
1817
+ for (int i = np; i < n; ++i) {
1818
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
1819
+ }
1820
+ #else
1821
+ // scalar
1822
+ for (int i = 0; i < n; ++i) {
1823
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
1824
+ }
1825
+ #endif
1826
+ }
1827
+
1665
1828
  // xs and vs are byte strides of x and v
1666
1829
  inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
1667
1830
 
@@ -1746,6 +1909,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
1746
1909
  #endif
1747
1910
  }
1748
1911
 
1912
+ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
1913
+ #if defined(GGML_SIMD)
1914
+ const int np = (n & ~(GGML_F16_STEP - 1));
1915
+
1916
+ GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
1917
+
1918
+ GGML_F16_VEC ay[GGML_F16_ARR];
1919
+
1920
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
1921
+ for (int j = 0; j < GGML_F16_ARR; j++) {
1922
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
1923
+ ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
1924
+
1925
+ GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
1926
+ }
1927
+ }
1928
+
1929
+ // leftovers
1930
+ for (int i = np; i < n; ++i) {
1931
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
1932
+ }
1933
+ #else
1934
+ // scalar
1935
+ for (int i = 0; i < n; ++i) {
1936
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
1937
+ }
1938
+ #endif
1939
+ }
1940
+
1749
1941
  inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
1750
1942
  inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
1751
1943
  inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
@@ -1907,6 +2099,14 @@ inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_
1907
2099
  *s = sum;
1908
2100
  }
1909
2101
 
2102
+ inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {
2103
+ float sum = 0.0f;
2104
+ for (int i = 0; i < n; ++i) {
2105
+ sum += GGML_BF16_TO_FP32(x[i]);
2106
+ }
2107
+ *s = sum;
2108
+ }
2109
+
1910
2110
  inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
1911
2111
  #ifndef GGML_USE_ACCELERATE
1912
2112
  float max = -INFINITY;
@@ -2000,6 +2200,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2000
2200
  "LEAKY_RELU",
2001
2201
 
2002
2202
  "FLASH_ATTN",
2203
+ "FLASH_ATTN_EXT",
2003
2204
  "FLASH_FF",
2004
2205
  "FLASH_ATTN_BACK",
2005
2206
  "SSM_CONV",
@@ -2026,7 +2227,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2026
2227
  "CROSS_ENTROPY_LOSS_BACK",
2027
2228
  };
2028
2229
 
2029
- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
2230
+ static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
2030
2231
 
2031
2232
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2032
2233
  "none",
@@ -2090,6 +2291,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2090
2291
  "leaky_relu(x)",
2091
2292
 
2092
2293
  "flash_attn(x)",
2294
+ "flash_attn_ext(x)",
2093
2295
  "flash_ff(x)",
2094
2296
  "flash_attn_back(x)",
2095
2297
  "ssm_conv(x)",
@@ -2116,7 +2318,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2116
2318
  "cross_entropy_loss_back(x,y)",
2117
2319
  };
2118
2320
 
2119
- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
2321
+ static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
2120
2322
 
2121
2323
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
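Note: the bumped static_asserts are the guard that keeps GGML_OP_NAME and GGML_OP_SYMBOL in lock-step with the op enum: adding FLASH_ATTN_EXT without bumping the count (or vice versa) fails at compile time. The same pattern in miniature, with made-up names:

    #include <assert.h>   // static_assert (C11)

    enum demo_op { DEMO_OP_NONE, DEMO_OP_ADD, DEMO_OP_MUL, DEMO_OP_COUNT };

    static const char * DEMO_OP_NAME[DEMO_OP_COUNT] = { "NONE", "ADD", "MUL" };

    // Fails to compile if the name table and the enum drift apart.
    static_assert(sizeof(DEMO_OP_NAME)/sizeof(DEMO_OP_NAME[0]) == DEMO_OP_COUNT,
                  "DEMO_OP_NAME out of sync with demo_op");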
2122
2324
 
@@ -2315,7 +2517,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
2315
2517
  // figure out which node we're on
2316
2518
  uint current_cpu;
2317
2519
  int getcpu_ret = 0;
2318
- #if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28)
2520
+ #if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__)
2319
2521
  getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node);
2320
2522
  #else
2321
2523
  // old glibc doesn't have a wrapper for this call. Fall back on direct syscall
@@ -2526,6 +2728,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
2526
2728
  switch (ftype) {
2527
2729
  case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break;
2528
2730
  case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break;
2731
+ case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break;
2529
2732
  case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break;
2530
2733
  case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break;
2531
2734
  case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
@@ -2667,15 +2870,16 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
2667
2870
  {
2668
2871
  const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
2669
2872
 
2670
- ggml_fp16_t ii;
2671
2873
  for (int i = 0; i < (1 << 16); ++i) {
2672
- uint16_t ui = i;
2673
- memcpy(&ii, &ui, sizeof(ii));
2674
- const float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
2874
+ union {
2875
+ uint16_t u16;
2876
+ ggml_fp16_t fp16;
2877
+ } u = {i};
2878
+ float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
2675
2879
  ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
2676
2880
  ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
2677
2881
  ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
2678
- ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
2882
+ ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
2679
2883
  }
2680
2884
 
2681
2885
  const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
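Note: the table-initialization rewrite replaces the memcpy round trip with a union, which is an equally valid way in C to reinterpret the 16-bit loop index as an fp16 bit pattern, just more compact. A standalone illustration of the same trick on a 32-bit value:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        // Reinterpret an integer bit pattern as a float via a union
        // (the same idea the loop above uses for ggml_fp16_t).
        union { uint32_t u32; float f; } u = { 0x3f800000u };
        printf("bits 0x%08x reinterpreted as float: %f\n", u.u32, u.f);  // prints 1.000000
        return 0;
    }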
@@ -3139,6 +3343,13 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
3139
3343
  ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
3140
3344
  }
3141
3345
  } break;
3346
+ case GGML_TYPE_BF16:
3347
+ {
3348
+ assert(tensor->nb[0] == sizeof(ggml_fp16_t));
3349
+ for (int i = 0; i < n; i++) {
3350
+ ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
3351
+ }
3352
+ } break;
3142
3353
  case GGML_TYPE_F32:
3143
3354
  {
3144
3355
  assert(tensor->nb[0] == sizeof(float));
@@ -3191,6 +3402,13 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3191
3402
  ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
3192
3403
  }
3193
3404
  } break;
3405
+ case GGML_TYPE_BF16:
3406
+ {
3407
+ assert(tensor->nb[0] == sizeof(ggml_bf16_t));
3408
+ for (int i = 0; i < n; i++) {
3409
+ ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
3410
+ }
3411
+ } break;
3194
3412
  case GGML_TYPE_F32:
3195
3413
  {
3196
3414
  assert(tensor->nb[0] == sizeof(float));
@@ -3258,6 +3476,11 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3258
3476
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3259
3477
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
3260
3478
  }
3479
+ case GGML_TYPE_BF16:
3480
+ {
3481
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3482
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
3483
+ }
3261
3484
  case GGML_TYPE_F32:
3262
3485
  {
3263
3486
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3300,6 +3523,11 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3300
3523
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3301
3524
  ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
3302
3525
  } break;
3526
+ case GGML_TYPE_BF16:
3527
+ {
3528
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3529
+ ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
3530
+ } break;
3303
3531
  case GGML_TYPE_F32:
3304
3532
  {
3305
3533
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3323,6 +3551,8 @@ int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i
3323
3551
  return ((int32_t *) data)[0];
3324
3552
  case GGML_TYPE_F16:
3325
3553
  return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
3554
+ case GGML_TYPE_BF16:
3555
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
3326
3556
  case GGML_TYPE_F32:
3327
3557
  return ((float *) data)[0];
3328
3558
  default:
@@ -3351,6 +3581,10 @@ void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3351
3581
  {
3352
3582
  ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
3353
3583
  } break;
3584
+ case GGML_TYPE_BF16:
3585
+ {
3586
+ ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
3587
+ } break;
3354
3588
  case GGML_TYPE_F32:
3355
3589
  {
3356
3590
  ((float *)(data))[0] = value;
@@ -3389,6 +3623,11 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3389
3623
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3390
3624
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
3391
3625
  }
3626
+ case GGML_TYPE_BF16:
3627
+ {
3628
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3629
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
3630
+ }
3392
3631
  case GGML_TYPE_F32:
3393
3632
  {
3394
3633
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3431,6 +3670,11 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
3431
3670
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3432
3671
  ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
3433
3672
  } break;
3673
+ case GGML_TYPE_BF16:
3674
+ {
3675
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3676
+ ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
3677
+ } break;
3434
3678
  case GGML_TYPE_F32:
3435
3679
  {
3436
3680
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3454,6 +3698,8 @@ float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3454
3698
  return ((int32_t *) data)[0];
3455
3699
  case GGML_TYPE_F16:
3456
3700
  return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
3701
+ case GGML_TYPE_BF16:
3702
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
3457
3703
  case GGML_TYPE_F32:
3458
3704
  return ((float *) data)[0];
3459
3705
  default:
@@ -3482,6 +3728,10 @@ void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3482
3728
  {
3483
3729
  ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
3484
3730
  } break;
3731
+ case GGML_TYPE_BF16:
3732
+ {
3733
+ ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
3734
+ } break;
3485
3735
  case GGML_TYPE_F32:
3486
3736
  {
3487
3737
  ((float *)(data))[0] = value;
@@ -3676,7 +3926,11 @@ static struct ggml_tensor * ggml_add_cast_impl(
3676
3926
  // TODO: support less-strict constraint
3677
3927
  // GGML_ASSERT(ggml_can_repeat(b, a));
3678
3928
  GGML_ASSERT(ggml_can_repeat_rows(b, a));
3679
- GGML_ASSERT(ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16); // currently only supported for quantized input and f16
3929
+
3930
+ // currently only supported for quantized input and f16
3931
+ GGML_ASSERT(ggml_is_quantized(a->type) ||
3932
+ a->type == GGML_TYPE_F16 ||
3933
+ a->type == GGML_TYPE_BF16);
3680
3934
 
3681
3935
  bool is_node = false;
3682
3936
 
@@ -4559,6 +4813,8 @@ struct ggml_tensor * ggml_mul_mat(
4559
4813
  void ggml_mul_mat_set_prec(
4560
4814
  struct ggml_tensor * a,
4561
4815
  enum ggml_prec prec) {
4816
+ GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
4817
+
4562
4818
  const int32_t prec_i32 = (int32_t) prec;
4563
4819
 
4564
4820
  ggml_set_op_params_i32(a, 0, prec_i32);
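Note: the added assertion makes ggml_mul_mat_set_prec fail loudly when called on anything other than a GGML_OP_MUL_MAT node, since it writes the precision into that op's parameter slot 0. A hedged usage sketch (GGML_PREC_F32 is assumed to be the existing precision enum value for requesting f32 accumulation):

    #include "ggml.h"

    // Hedged sketch: request f32 accumulation for a K*Q matmul node.
    static struct ggml_tensor * build_kq(struct ggml_context * ctx,
                                         struct ggml_tensor * k,
                                         struct ggml_tensor * q) {
        struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);  // would assert if kq were not a MUL_MAT node
        return kq;
    }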
@@ -5397,17 +5653,23 @@ static struct ggml_tensor * ggml_soft_max_impl(
5397
5653
  GGML_ASSERT(ggml_is_contiguous(a));
5398
5654
 
5399
5655
  if (mask) {
5656
+ GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
5400
5657
  GGML_ASSERT(ggml_is_contiguous(mask));
5401
5658
  GGML_ASSERT(ggml_is_matrix(mask));
5402
- GGML_ASSERT(ggml_can_repeat_rows(mask, a));
5659
+ GGML_ASSERT(mask->ne[0] == a->ne[0]);
5660
+ GGML_ASSERT(mask->ne[1] >= a->ne[1]);
5403
5661
  }
5404
5662
 
5405
5663
  if (pos) {
5406
5664
  GGML_ASSERT(ggml_is_vector(pos));
5407
- GGML_ASSERT(pos->type == GGML_TYPE_F32);
5665
+ GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
5408
5666
  GGML_ASSERT(pos->ne[0] == a->ne[0]);
5409
5667
  }
5410
5668
 
5669
+ if (pos && mask) {
5670
+ GGML_ASSERT(pos->type == mask->type);
5671
+ }
5672
+
5411
5673
  if (max_bias > 0.0f) {
5412
5674
  GGML_ASSERT(pos);
5413
5675
  }
@@ -6216,6 +6478,59 @@ struct ggml_tensor * ggml_flash_attn(
6216
6478
  return result;
6217
6479
  }
6218
6480
 
6481
+ // ggml_flash_attn_ext
6482
+
6483
+ struct ggml_tensor * ggml_flash_attn_ext(
6484
+ struct ggml_context * ctx,
6485
+ struct ggml_tensor * q,
6486
+ struct ggml_tensor * k,
6487
+ struct ggml_tensor * v,
6488
+ struct ggml_tensor * mask,
6489
+ float scale) {
6490
+ GGML_ASSERT(ggml_can_mul_mat(k, q));
6491
+ // TODO: check if vT can be multiplied by (k*qT)
6492
+ if (mask) {
6493
+ GGML_ASSERT(ggml_is_contiguous(mask));
6494
+ GGML_ASSERT(mask->ne[2] == 1);
6495
+ GGML_ASSERT(mask->ne[3] == 1);
6496
+ GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
6497
+ "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
6498
+ //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
6499
+ }
6500
+
6501
+ bool is_node = false;
6502
+
6503
+ if (q->grad || k->grad || v->grad) {
6504
+ is_node = true;
6505
+ }
6506
+
6507
+ // permute(0, 2, 1, 3)
6508
+ int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
6509
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
6510
+
6511
+ float params[] = { scale };
6512
+ ggml_set_op_params(result, params, sizeof(params));
6513
+
6514
+ result->op = GGML_OP_FLASH_ATTN_EXT;
6515
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6516
+ result->src[0] = q;
6517
+ result->src[1] = k;
6518
+ result->src[2] = v;
6519
+ result->src[3] = mask;
6520
+
6521
+ return result;
6522
+ }
6523
+
6524
+ void ggml_flash_attn_ext_set_prec(
6525
+ struct ggml_tensor * a,
6526
+ enum ggml_prec prec) {
6527
+ GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
6528
+
6529
+ const int32_t prec_i32 = (int32_t) prec;
6530
+
6531
+ ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
6532
+ }
6533
+
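Note: ggml_flash_attn_ext builds a single fused attention node (the result dims are permuted relative to q, as the ne[] initialization shows), and the assertion above requires the optional mask's second dimension to be padded up to GGML_KQ_MASK_PAD. A hedged usage sketch; the tensor shapes in the comments and the 1/sqrt(d_head) scale are assumptions about typical use, not something this diff spells out:

    #include "ggml.h"
    #include <math.h>

    // Hypothetical helper: fused attention over already-built q/k/v tensors.
    // kq_mask may be NULL; when present it must be contiguous and padded as asserted above.
    static struct ggml_tensor * attn(struct ggml_context * ctx,
                                     struct ggml_tensor * q,   // assumed [d_head, n_q, n_head, n_batch]
                                     struct ggml_tensor * k,
                                     struct ggml_tensor * v,
                                     struct ggml_tensor * kq_mask,
                                     int d_head) {
        struct ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask,
                                                       1.0f / sqrtf((float) d_head));
        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);  // optional, per the setter added above
        return cur;
    }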
6219
6534
  // ggml_flash_ff
6220
6535
 
6221
6536
  struct ggml_tensor * ggml_flash_ff(
@@ -7092,8 +7407,8 @@ static void ggml_compute_forward_dup_same_cont(
7092
7407
  ((char *) src0->data + ie0*nb00),
7093
7408
  (ie1 - ie0) * ggml_type_size(src0->type));
7094
7409
  }
7095
-
7096
7410
  }
7411
+
7097
7412
  static void ggml_compute_forward_dup_f16(
7098
7413
  const struct ggml_compute_params * params,
7099
7414
  struct ggml_tensor * dst) {
@@ -7367,7 +7682,7 @@ static void ggml_compute_forward_dup_f16(
7367
7682
  }
7368
7683
  }
7369
7684
 
7370
- static void ggml_compute_forward_dup_f32(
7685
+ static void ggml_compute_forward_dup_bf16(
7371
7686
  const struct ggml_compute_params * params,
7372
7687
  struct ggml_tensor * dst) {
7373
7688
 
@@ -7415,10 +7730,11 @@ static void ggml_compute_forward_dup_f32(
7415
7730
  return;
7416
7731
  }
7417
7732
 
7733
+ // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
7734
+
7418
7735
  if (ggml_is_contiguous(dst)) {
7419
- // TODO: simplify
7420
- if (nb00 == sizeof(float)) {
7421
- if (dst->type == GGML_TYPE_F32) {
7736
+ if (nb00 == sizeof(ggml_bf16_t)) {
7737
+ if (dst->type == GGML_TYPE_BF16) {
7422
7738
  size_t id = 0;
7423
7739
  const size_t rs = ne00 * nb00;
7424
7740
  char * dst_ptr = (char *) dst->data;
@@ -7434,8 +7750,43 @@ static void ggml_compute_forward_dup_f32(
7434
7750
  id += rs * (ne01 - ir1);
7435
7751
  }
7436
7752
  }
7753
+ } else if (dst->type == GGML_TYPE_F16) {
7754
+ size_t id = 0;
7755
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
7756
+
7757
+ for (int i03 = 0; i03 < ne03; i03++) {
7758
+ for (int i02 = 0; i02 < ne02; i02++) {
7759
+ id += ne00 * ir0;
7760
+ for (int i01 = ir0; i01 < ir1; i01++) {
7761
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7762
+ for (int i00 = 0; i00 < ne00; i00++) {
7763
+ dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
7764
+ id++;
7765
+ }
7766
+ }
7767
+ id += ne00 * (ne01 - ir1);
7768
+ }
7769
+ }
7770
+ } else if (dst->type == GGML_TYPE_F32) {
7771
+ size_t id = 0;
7772
+ float * dst_ptr = (float *) dst->data;
7773
+
7774
+ for (int i03 = 0; i03 < ne03; i03++) {
7775
+ for (int i02 = 0; i02 < ne02; i02++) {
7776
+ id += ne00 * ir0;
7777
+ for (int i01 = ir0; i01 < ir1; i01++) {
7778
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7779
+ for (int i00 = 0; i00 < ne00; i00++) {
7780
+ dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
7781
+ id++;
7782
+ }
7783
+ }
7784
+ id += ne00 * (ne01 - ir1);
7785
+ }
7786
+ }
7437
7787
  } else if (type_traits[dst->type].from_float) {
7438
7788
  ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
7789
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
7439
7790
 
7440
7791
  size_t id = 0;
7441
7792
  size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
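Note: when the destination is a quantized type, the BF16 source row is first widened into a per-thread f32 scratch slice inside params->wdata; each slice is padded by CACHE_LINE_SIZE_F32 floats so neighbouring threads never write to the same cache line. A hedged sketch of how that scratch would be sized for nth threads (the planner that actually allocates wdata is outside this diff; the 64-byte cache line is an assumption matching the CACHE_LINE_SIZE/sizeof(float) definition referenced in an earlier hunk):

    #include <stddef.h>

    #define CACHE_LINE_SIZE_F32_DEMO (64 / sizeof(float))   // assumed 64-byte cache line

    // One (row-length + padding) f32 slice per thread, matching the indexing
    //   (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith
    static size_t dup_wdata_size(size_t ne00, int nth) {
        return (ne00 + CACHE_LINE_SIZE_F32_DEMO) * (size_t) nth * sizeof(float);
    }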
@@ -7445,8 +7796,13 @@ static void ggml_compute_forward_dup_f32(
7445
7796
  for (int i02 = 0; i02 < ne02; i02++) {
7446
7797
  id += rs * ir0;
7447
7798
  for (int i01 = ir0; i01 < ir1; i01++) {
7448
- const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7449
- quantize_row_q(src0_ptr, dst_ptr + id, ne00);
7799
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7800
+
7801
+ for (int i00 = 0; i00 < ne00; i00++) {
7802
+ src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
7803
+ }
7804
+
7805
+ quantize_row_q(src0_f32, dst_ptr + id, ne00);
7450
7806
  id += rs;
7451
7807
  }
7452
7808
  id += rs * (ne01 - ir1);
@@ -7467,7 +7823,25 @@ static void ggml_compute_forward_dup_f32(
7467
7823
  id += ne00 * ir0;
7468
7824
  for (int i01 = ir0; i01 < ir1; i01++) {
7469
7825
  for (int i00 = 0; i00 < ne00; i00++) {
7470
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7826
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7827
+
7828
+ dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
7829
+ id++;
7830
+ }
7831
+ }
7832
+ id += ne00 * (ne01 - ir1);
7833
+ }
7834
+ }
7835
+ } else if (dst->type == GGML_TYPE_BF16) {
7836
+ size_t id = 0;
7837
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
7838
+
7839
+ for (int i03 = 0; i03 < ne03; i03++) {
7840
+ for (int i02 = 0; i02 < ne02; i02++) {
7841
+ id += ne00 * ir0;
7842
+ for (int i01 = ir0; i01 < ir1; i01++) {
7843
+ for (int i00 = 0; i00 < ne00; i00++) {
7844
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7471
7845
 
7472
7846
  dst_ptr[id] = *src0_ptr;
7473
7847
  id++;
@@ -7485,9 +7859,9 @@ static void ggml_compute_forward_dup_f32(
7485
7859
  id += ne00 * ir0;
7486
7860
  for (int i01 = ir0; i01 < ir1; i01++) {
7487
7861
  for (int i00 = 0; i00 < ne00; i00++) {
7488
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7862
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7489
7863
 
7490
- dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
7864
+ dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
7491
7865
  id++;
7492
7866
  }
7493
7867
  }
@@ -7498,18 +7872,16 @@ static void ggml_compute_forward_dup_f32(
7498
7872
  GGML_ASSERT(false); // TODO: implement
7499
7873
  }
7500
7874
  }
7501
-
7502
7875
  return;
7503
7876
  }
7504
7877
 
7505
7878
  // dst counters
7506
-
7507
7879
  int64_t i10 = 0;
7508
7880
  int64_t i11 = 0;
7509
7881
  int64_t i12 = 0;
7510
7882
  int64_t i13 = 0;
7511
7883
 
7512
- if (dst->type == GGML_TYPE_F32) {
7884
+ if (dst->type == GGML_TYPE_BF16) {
7513
7885
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7514
7886
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7515
7887
  i10 += ne00 * ir0;
@@ -7530,15 +7902,15 @@ static void ggml_compute_forward_dup_f32(
7530
7902
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7531
7903
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
7532
7904
 
7533
- memcpy(dst_ptr, src0_ptr, sizeof(float));
7905
+ memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
7534
7906
 
7535
- if (++i10 == ne0) {
7907
+ if (++i10 == ne00) {
7536
7908
  i10 = 0;
7537
- if (++i11 == ne1) {
7909
+ if (++i11 == ne01) {
7538
7910
  i11 = 0;
7539
- if (++i12 == ne2) {
7911
+ if (++i12 == ne02) {
7540
7912
  i12 = 0;
7541
- if (++i13 == ne3) {
7913
+ if (++i13 == ne03) {
7542
7914
  i13 = 0;
7543
7915
  }
7544
7916
  }
@@ -7582,7 +7954,7 @@ static void ggml_compute_forward_dup_f32(
7582
7954
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7583
7955
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
7584
7956
 
7585
- *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
7957
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
7586
7958
 
7587
7959
  if (++i10 == ne0) {
7588
7960
  i10 = 0;
@@ -7613,10 +7985,383 @@ static void ggml_compute_forward_dup_f32(
7613
7985
  }
7614
7986
  }
7615
7987
  }
7616
- } else {
7617
- GGML_ASSERT(false); // TODO: implement
7618
- }
7619
- }
7988
+ } else if (dst->type == GGML_TYPE_F32) {
7989
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
7990
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
7991
+ i10 += ne00 * ir0;
7992
+ while (i10 >= ne0) {
7993
+ i10 -= ne0;
7994
+ if (++i11 == ne1) {
7995
+ i11 = 0;
7996
+ if (++i12 == ne2) {
7997
+ i12 = 0;
7998
+ if (++i13 == ne3) {
7999
+ i13 = 0;
8000
+ }
8001
+ }
8002
+ }
8003
+ }
8004
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8005
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8006
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8007
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8008
+
8009
+ *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
8010
+
8011
+ if (++i10 == ne0) {
8012
+ i10 = 0;
8013
+ if (++i11 == ne1) {
8014
+ i11 = 0;
8015
+ if (++i12 == ne2) {
8016
+ i12 = 0;
8017
+ if (++i13 == ne3) {
8018
+ i13 = 0;
8019
+ }
8020
+ }
8021
+ }
8022
+ }
8023
+ }
8024
+ }
8025
+ i10 += ne00 * (ne01 - ir1);
8026
+ while (i10 >= ne0) {
8027
+ i10 -= ne0;
8028
+ if (++i11 == ne1) {
8029
+ i11 = 0;
8030
+ if (++i12 == ne2) {
8031
+ i12 = 0;
8032
+ if (++i13 == ne3) {
8033
+ i13 = 0;
8034
+ }
8035
+ }
8036
+ }
8037
+ }
8038
+ }
8039
+ }
8040
+ } else {
8041
+ GGML_ASSERT(false); // TODO: implement
8042
+ }
8043
+ }
8044
+
8045
+ static void ggml_compute_forward_dup_f32(
8046
+ const struct ggml_compute_params * params,
8047
+ struct ggml_tensor * dst) {
8048
+
8049
+ const struct ggml_tensor * src0 = dst->src[0];
8050
+
8051
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
8052
+
8053
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
8054
+ return;
8055
+ }
8056
+
8057
+ GGML_TENSOR_UNARY_OP_LOCALS
8058
+
8059
+ const int ith = params->ith; // thread index
8060
+ const int nth = params->nth; // number of threads
8061
+
8062
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
8063
+ ggml_compute_forward_dup_same_cont(params, dst);
8064
+ return;
8065
+ }
8066
+
8067
+ // parallelize by rows
8068
+ const int nr = ne01;
8069
+ // number of rows per thread
8070
+ const int dr = (nr + nth - 1) / nth;
8071
+ // row range for this thread
8072
+ const int ir0 = dr * ith;
8073
+ const int ir1 = MIN(ir0 + dr, nr);
8074
+
8075
+ if (src0->type == dst->type &&
8076
+ ne00 == ne0 &&
8077
+ nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
8078
+ // copy by rows
8079
+ const size_t rs = ne00*nb00;
8080
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8081
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8082
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8083
+ memcpy(
8084
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
8085
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
8086
+ rs);
8087
+ }
8088
+ }
8089
+ }
8090
+ return;
8091
+ }
8092
+
8093
+ if (ggml_is_contiguous(dst)) {
8094
+ // TODO: simplify
8095
+ if (nb00 == sizeof(float)) {
8096
+ if (dst->type == GGML_TYPE_F32) {
8097
+ size_t id = 0;
8098
+ const size_t rs = ne00 * nb00;
8099
+ char * dst_ptr = (char *) dst->data;
8100
+
8101
+ for (int i03 = 0; i03 < ne03; i03++) {
8102
+ for (int i02 = 0; i02 < ne02; i02++) {
8103
+ id += rs * ir0;
8104
+ for (int i01 = ir0; i01 < ir1; i01++) {
8105
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
8106
+ memcpy(dst_ptr + id, src0_ptr, rs);
8107
+ id += rs;
8108
+ }
8109
+ id += rs * (ne01 - ir1);
8110
+ }
8111
+ }
8112
+ } else if (type_traits[dst->type].from_float) {
8113
+ ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
8114
+
8115
+ size_t id = 0;
8116
+ size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
8117
+ char * dst_ptr = (char *) dst->data;
8118
+
8119
+ for (int i03 = 0; i03 < ne03; i03++) {
8120
+ for (int i02 = 0; i02 < ne02; i02++) {
8121
+ id += rs * ir0;
8122
+ for (int i01 = ir0; i01 < ir1; i01++) {
8123
+ const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
8124
+ quantize_row_q(src0_ptr, dst_ptr + id, ne00);
8125
+ id += rs;
8126
+ }
8127
+ id += rs * (ne01 - ir1);
8128
+ }
8129
+ }
8130
+ } else {
8131
+ GGML_ASSERT(false); // TODO: implement
8132
+ }
8133
+ } else {
8134
+ //printf("%s: this is not optimal - fix me\n", __func__);
8135
+
8136
+ if (dst->type == GGML_TYPE_F32) {
8137
+ size_t id = 0;
8138
+ float * dst_ptr = (float *) dst->data;
8139
+
8140
+ for (int i03 = 0; i03 < ne03; i03++) {
8141
+ for (int i02 = 0; i02 < ne02; i02++) {
8142
+ id += ne00 * ir0;
8143
+ for (int i01 = ir0; i01 < ir1; i01++) {
8144
+ for (int i00 = 0; i00 < ne00; i00++) {
8145
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8146
+
8147
+ dst_ptr[id] = *src0_ptr;
8148
+ id++;
8149
+ }
8150
+ }
8151
+ id += ne00 * (ne01 - ir1);
8152
+ }
8153
+ }
8154
+ } else if (dst->type == GGML_TYPE_F16) {
8155
+ size_t id = 0;
8156
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
8157
+
8158
+ for (int i03 = 0; i03 < ne03; i03++) {
8159
+ for (int i02 = 0; i02 < ne02; i02++) {
8160
+ id += ne00 * ir0;
8161
+ for (int i01 = ir0; i01 < ir1; i01++) {
8162
+ for (int i00 = 0; i00 < ne00; i00++) {
8163
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8164
+
8165
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
8166
+ id++;
8167
+ }
8168
+ }
8169
+ id += ne00 * (ne01 - ir1);
8170
+ }
8171
+ }
8172
+ } else if (dst->type == GGML_TYPE_BF16) {
8173
+ size_t id = 0;
8174
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
8175
+
8176
+ for (int i03 = 0; i03 < ne03; i03++) {
8177
+ for (int i02 = 0; i02 < ne02; i02++) {
8178
+ id += ne00 * ir0;
8179
+ for (int i01 = ir0; i01 < ir1; i01++) {
8180
+ for (int i00 = 0; i00 < ne00; i00++) {
8181
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8182
+
8183
+ dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
8184
+ id++;
8185
+ }
8186
+ }
8187
+ id += ne00 * (ne01 - ir1);
8188
+ }
8189
+ }
8190
+ } else {
8191
+ GGML_ASSERT(false); // TODO: implement
8192
+ }
8193
+ }
8194
+
8195
+ return;
8196
+ }
8197
+
8198
+ // dst counters
8199
+
8200
+ int64_t i10 = 0;
8201
+ int64_t i11 = 0;
8202
+ int64_t i12 = 0;
8203
+ int64_t i13 = 0;
8204
+
8205
+ if (dst->type == GGML_TYPE_F32) {
8206
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8207
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8208
+ i10 += ne00 * ir0;
8209
+ while (i10 >= ne0) {
8210
+ i10 -= ne0;
8211
+ if (++i11 == ne1) {
8212
+ i11 = 0;
8213
+ if (++i12 == ne2) {
8214
+ i12 = 0;
8215
+ if (++i13 == ne3) {
8216
+ i13 = 0;
8217
+ }
8218
+ }
8219
+ }
8220
+ }
8221
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8222
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8223
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8224
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8225
+
8226
+ memcpy(dst_ptr, src0_ptr, sizeof(float));
8227
+
8228
+ if (++i10 == ne0) {
8229
+ i10 = 0;
8230
+ if (++i11 == ne1) {
8231
+ i11 = 0;
8232
+ if (++i12 == ne2) {
8233
+ i12 = 0;
8234
+ if (++i13 == ne3) {
8235
+ i13 = 0;
8236
+ }
8237
+ }
8238
+ }
8239
+ }
8240
+ }
8241
+ }
8242
+ i10 += ne00 * (ne01 - ir1);
8243
+ while (i10 >= ne0) {
8244
+ i10 -= ne0;
8245
+ if (++i11 == ne1) {
8246
+ i11 = 0;
8247
+ if (++i12 == ne2) {
8248
+ i12 = 0;
8249
+ if (++i13 == ne3) {
8250
+ i13 = 0;
8251
+ }
8252
+ }
8253
+ }
8254
+ }
8255
+ }
8256
+ }
8257
+ } else if (dst->type == GGML_TYPE_F16) {
8258
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8259
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8260
+ i10 += ne00 * ir0;
8261
+ while (i10 >= ne0) {
8262
+ i10 -= ne0;
8263
+ if (++i11 == ne1) {
8264
+ i11 = 0;
8265
+ if (++i12 == ne2) {
8266
+ i12 = 0;
8267
+ if (++i13 == ne3) {
8268
+ i13 = 0;
8269
+ }
8270
+ }
8271
+ }
8272
+ }
8273
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8274
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8275
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8276
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8277
+
8278
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
8279
+
8280
+ if (++i10 == ne0) {
8281
+ i10 = 0;
8282
+ if (++i11 == ne1) {
8283
+ i11 = 0;
8284
+ if (++i12 == ne2) {
8285
+ i12 = 0;
8286
+ if (++i13 == ne3) {
8287
+ i13 = 0;
8288
+ }
8289
+ }
8290
+ }
8291
+ }
8292
+ }
8293
+ }
8294
+ i10 += ne00 * (ne01 - ir1);
8295
+ while (i10 >= ne0) {
8296
+ i10 -= ne0;
8297
+ if (++i11 == ne1) {
8298
+ i11 = 0;
8299
+ if (++i12 == ne2) {
8300
+ i12 = 0;
8301
+ if (++i13 == ne3) {
8302
+ i13 = 0;
8303
+ }
8304
+ }
8305
+ }
8306
+ }
8307
+ }
8308
+ }
8309
+ } else if (dst->type == GGML_TYPE_BF16) {
8310
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8311
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8312
+ i10 += ne00 * ir0;
8313
+ while (i10 >= ne0) {
8314
+ i10 -= ne0;
8315
+ if (++i11 == ne1) {
8316
+ i11 = 0;
8317
+ if (++i12 == ne2) {
8318
+ i12 = 0;
8319
+ if (++i13 == ne3) {
8320
+ i13 = 0;
8321
+ }
8322
+ }
8323
+ }
8324
+ }
8325
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8326
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8327
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8328
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8329
+
8330
+ *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
8331
+
8332
+ if (++i10 == ne0) {
8333
+ i10 = 0;
8334
+ if (++i11 == ne1) {
8335
+ i11 = 0;
8336
+ if (++i12 == ne2) {
8337
+ i12 = 0;
8338
+ if (++i13 == ne3) {
8339
+ i13 = 0;
8340
+ }
8341
+ }
8342
+ }
8343
+ }
8344
+ }
8345
+ }
8346
+ i10 += ne00 * (ne01 - ir1);
8347
+ while (i10 >= ne0) {
8348
+ i10 -= ne0;
8349
+ if (++i11 == ne1) {
8350
+ i11 = 0;
8351
+ if (++i12 == ne2) {
8352
+ i12 = 0;
8353
+ if (++i13 == ne3) {
8354
+ i13 = 0;
8355
+ }
8356
+ }
8357
+ }
8358
+ }
8359
+ }
8360
+ }
8361
+ } else {
8362
+ GGML_ASSERT(false); // TODO: implement
8363
+ }
8364
+ }
7620
8365
 
7621
8366
  // A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
7622
8367
  static void ggml_compute_forward_dup_bytes(
@@ -7786,6 +8531,10 @@ static void ggml_compute_forward_dup(
7786
8531
  {
7787
8532
  ggml_compute_forward_dup_f16(params, dst);
7788
8533
  } break;
8534
+ case GGML_TYPE_BF16:
8535
+ {
8536
+ ggml_compute_forward_dup_bf16(params, dst);
8537
+ } break;
7789
8538
  case GGML_TYPE_F32:
7790
8539
  {
7791
8540
  ggml_compute_forward_dup_f32(params, dst);
@@ -7968,6 +8717,85 @@ static void ggml_compute_forward_add_f16_f32(
7968
8717
  }
7969
8718
  }
7970
8719
 
8720
+ static void ggml_compute_forward_add_bf16_f32(
8721
+ const struct ggml_compute_params * params,
8722
+ struct ggml_tensor * dst) {
8723
+
8724
+ const struct ggml_tensor * src0 = dst->src[0];
8725
+ const struct ggml_tensor * src1 = dst->src[1];
8726
+
8727
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
8728
+
8729
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
8730
+ return;
8731
+ }
8732
+
8733
+ const int ith = params->ith;
8734
+ const int nth = params->nth;
8735
+
8736
+ const int nr = ggml_nrows(src0);
8737
+
8738
+ GGML_TENSOR_BINARY_OP_LOCALS
8739
+
8740
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
8741
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
8742
+
8743
+ if (dst->type == GGML_TYPE_F32) {
8744
+ GGML_ASSERT( nb0 == sizeof(float));
8745
+ }
8746
+ else {
8747
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
8748
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
8749
+ }
8750
+
8751
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
8752
+
8753
+ // rows per thread
8754
+ const int dr = (nr + nth - 1)/nth;
8755
+
8756
+ // row range for this thread
8757
+ const int ir0 = dr*ith;
8758
+ const int ir1 = MIN(ir0 + dr, nr);
8759
+
8760
+ if (nb10 == sizeof(float)) {
8761
+ if (dst->type == GGML_TYPE_BF16) {
8762
+ for (int ir = ir0; ir < ir1; ++ir) {
8763
+ // src0, src1 and dst are same shape => same indices
8764
+ const int i3 = ir/(ne2*ne1);
8765
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
8766
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8767
+
8768
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8769
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8770
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8771
+
8772
+ for (int i = 0; i < ne0; i++) {
8773
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
8774
+ }
8775
+ }
8776
+ } else {
8777
+ for (int ir = ir0; ir < ir1; ++ir) {
8778
+ // src0, src1 and dst are same shape => same indices
8779
+ const int i3 = ir/(ne2*ne1);
8780
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
8781
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8782
+
8783
+ float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8784
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8785
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8786
+
8787
+ for (int i = 0; i < ne0; i++) {
8788
+ dst_ptr[i] = GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
8789
+ }
8790
+ }
8791
+ }
8792
+ }
8793
+ else {
8794
+ // src1 is not contiguous
8795
+ GGML_ASSERT(false);
8796
+ }
8797
+ }
8798
+
7971
8799
  static void ggml_compute_forward_add_f16_f16(
7972
8800
  const struct ggml_compute_params * params,
7973
8801
  struct ggml_tensor * dst) {
@@ -8024,6 +8852,62 @@ static void ggml_compute_forward_add_f16_f16(
8024
8852
  }
8025
8853
  }
8026
8854
 
8855
+ static void ggml_compute_forward_add_bf16_bf16(
8856
+ const struct ggml_compute_params * params,
8857
+ struct ggml_tensor * dst) {
8858
+
8859
+ const struct ggml_tensor * src0 = dst->src[0];
8860
+ const struct ggml_tensor * src1 = dst->src[1];
8861
+
8862
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
8863
+
8864
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
8865
+ return;
8866
+ }
8867
+
8868
+ const int ith = params->ith;
8869
+ const int nth = params->nth;
8870
+
8871
+ const int nr = ggml_nrows(src0);
8872
+
8873
+ GGML_TENSOR_BINARY_OP_LOCALS
8874
+
8875
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
8876
+ GGML_ASSERT(src1->type == GGML_TYPE_BF16);
8877
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
8878
+
8879
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
8880
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
8881
+
8882
+ // rows per thread
8883
+ const int dr = (nr + nth - 1)/nth;
8884
+
8885
+ // row range for this thread
8886
+ const int ir0 = dr*ith;
8887
+ const int ir1 = MIN(ir0 + dr, nr);
8888
+
8889
+ if (nb10 == sizeof(ggml_bf16_t)) {
8890
+ for (int ir = ir0; ir < ir1; ++ir) {
8891
+ // src0, src1 and dst are same shape => same indices
8892
+ const int i3 = ir/(ne2*ne1);
8893
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
8894
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8895
+
8896
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8897
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8898
+ ggml_bf16_t * src1_ptr = (ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8899
+
8900
+ for (int i = 0; i < ne0; i++) {
8901
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + GGML_BF16_TO_FP32(src1_ptr[i]));
8902
+ }
8903
+ }
8904
+ }
8905
+ else {
8906
+ // src1 is not contiguous
8907
+ GGML_ASSERT(false);
8908
+ }
8909
+ }
8910
+
8027
8911
  static void ggml_compute_forward_add_q_f32(
8028
8912
  const struct ggml_compute_params * params,
8029
8913
  struct ggml_tensor * dst) {
@@ -8133,6 +9017,18 @@ static void ggml_compute_forward_add(
8133
9017
  GGML_ASSERT(false);
8134
9018
  }
8135
9019
  } break;
9020
+ case GGML_TYPE_BF16:
9021
+ {
9022
+ if (src1->type == GGML_TYPE_BF16) {
9023
+ ggml_compute_forward_add_bf16_bf16(params, dst);
9024
+ }
9025
+ else if (src1->type == GGML_TYPE_F32) {
9026
+ ggml_compute_forward_add_bf16_f32(params, dst);
9027
+ }
9028
+ else {
9029
+ GGML_ASSERT(false);
9030
+ }
9031
+ } break;
8136
9032
  case GGML_TYPE_Q4_0:
8137
9033
  case GGML_TYPE_Q4_1:
8138
9034
  case GGML_TYPE_Q5_0:
@@ -8346,21 +9242,133 @@ static void ggml_compute_forward_add1_q_f32(
8346
9242
 
8347
9243
  GGML_TENSOR_UNARY_OP_LOCALS
8348
9244
 
8349
- const enum ggml_type type = src0->type;
8350
- ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
8351
- ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
8352
-
8353
- // we don't support permuted src0
8354
- GGML_ASSERT(nb00 == ggml_type_size(type));
8355
-
8356
- // dst cannot be transposed or permuted
8357
- GGML_ASSERT(nb0 <= nb1);
8358
- GGML_ASSERT(nb1 <= nb2);
8359
- GGML_ASSERT(nb2 <= nb3);
9245
+ const enum ggml_type type = src0->type;
9246
+ ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
9247
+ ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
9248
+
9249
+ // we don't support permuted src0
9250
+ GGML_ASSERT(nb00 == ggml_type_size(type));
9251
+
9252
+ // dst cannot be transposed or permuted
9253
+ GGML_ASSERT(nb0 <= nb1);
9254
+ GGML_ASSERT(nb1 <= nb2);
9255
+ GGML_ASSERT(nb2 <= nb3);
9256
+
9257
+ GGML_ASSERT(ggml_is_quantized(src0->type));
9258
+ GGML_ASSERT(dst->type == src0->type);
9259
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
9260
+
9261
+ // rows per thread
9262
+ const int dr = (nr + nth - 1)/nth;
9263
+
9264
+ // row range for this thread
9265
+ const int ir0 = dr*ith;
9266
+ const int ir1 = MIN(ir0 + dr, nr);
9267
+
9268
+ float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
9269
+
9270
+ for (int ir = ir0; ir < ir1; ++ir) {
9271
+ // src0 and dst are same shape => same indices
9272
+ const int i3 = ir/(ne2*ne1);
9273
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9274
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9275
+
9276
+ void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
9277
+ void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb0 ));
9278
+
9279
+ assert(ne0 % 32 == 0);
9280
+
9281
+ // unquantize row from src0 to temp buffer
9282
+ dequantize_row_q(src0_row, wdata, ne0);
9283
+ // add src1
9284
+ ggml_vec_acc1_f32(ne0, wdata, v);
9285
+ // quantize row to dst
9286
+ quantize_row_q(wdata, dst_row, ne0);
9287
+ }
9288
+ }
9289
+
9290
+ static void ggml_compute_forward_add1_bf16_f32(
9291
+ const struct ggml_compute_params * params,
9292
+ struct ggml_tensor * dst) {
9293
+
9294
+ const struct ggml_tensor * src0 = dst->src[0];
9295
+ const struct ggml_tensor * src1 = dst->src[1];
9296
+
9297
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
9298
+ GGML_ASSERT(ggml_is_scalar(src1));
9299
+
9300
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
9301
+ return;
9302
+ }
9303
+
9304
+ // scalar to add
9305
+ const float v = *(float *) src1->data;
9306
+
9307
+ const int ith = params->ith;
9308
+ const int nth = params->nth;
9309
+
9310
+ const int nr = ggml_nrows(src0);
9311
+
9312
+ GGML_TENSOR_UNARY_OP_LOCALS
9313
+
9314
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
9315
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
9316
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
9317
+
9318
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
9319
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
9320
+
9321
+ // rows per thread
9322
+ const int dr = (nr + nth - 1)/nth;
9323
+
9324
+ // row range for this thread
9325
+ const int ir0 = dr*ith;
9326
+ const int ir1 = MIN(ir0 + dr, nr);
9327
+
9328
+ for (int ir = ir0; ir < ir1; ++ir) {
9329
+ // src0 and dst are same shape => same indices
9330
+ const int i3 = ir/(ne2*ne1);
9331
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9332
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9333
+
9334
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
9335
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9336
+ for (int i = 0; i < ne0; i++) {
9337
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
9338
+ }
9339
+ }
9340
+ }
9341
+
9342
+ static void ggml_compute_forward_add1_bf16_bf16(
9343
+ const struct ggml_compute_params * params,
9344
+ struct ggml_tensor * dst) {
9345
+
9346
+ const struct ggml_tensor * src0 = dst->src[0];
9347
+ const struct ggml_tensor * src1 = dst->src[1];
9348
+
9349
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
9350
+ GGML_ASSERT(ggml_is_scalar(src1));
9351
+
9352
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
9353
+ return;
9354
+ }
9355
+
9356
+ // scalar to add
9357
+ const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
9358
+
9359
+ const int ith = params->ith;
9360
+ const int nth = params->nth;
9361
+
9362
+ const int nr = ggml_nrows(src0);
9363
+
9364
+ GGML_TENSOR_UNARY_OP_LOCALS
9365
+
9366
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
9367
+ GGML_ASSERT(src1->type == GGML_TYPE_BF16);
9368
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
8360
9369
 
8361
- GGML_ASSERT(ggml_is_quantized(src0->type));
8362
- GGML_ASSERT(dst->type == src0->type);
8363
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
9370
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
9371
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
8364
9372
 
8365
9373
  // rows per thread
8366
9374
  const int dr = (nr + nth - 1)/nth;
@@ -8369,25 +9377,17 @@ static void ggml_compute_forward_add1_q_f32(
8369
9377
  const int ir0 = dr*ith;
8370
9378
  const int ir1 = MIN(ir0 + dr, nr);
8371
9379
 
8372
- float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
8373
-
8374
9380
  for (int ir = ir0; ir < ir1; ++ir) {
8375
9381
  // src0 and dst are same shape => same indices
8376
9382
  const int i3 = ir/(ne2*ne1);
8377
9383
  const int i2 = (ir - i3*ne2*ne1)/ne1;
8378
9384
  const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8379
9385
 
8380
- void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
8381
- void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb0 ));
8382
-
8383
- assert(ne0 % 32 == 0);
8384
-
8385
- // unquantize row from src0 to temp buffer
8386
- dequantize_row_q(src0_row, wdata, ne0);
8387
- // add src1
8388
- ggml_vec_acc1_f32(ne0, wdata, v);
8389
- // quantize row to dst
8390
- quantize_row_q(wdata, dst_row, ne0);
9386
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
9387
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9388
+ for (int i = 0; i < ne0; i++) {
9389
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
9390
+ }
8391
9391
  }
8392
9392
  }
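Editor's note: the bf16 add1 kernels above round-trip every element through f32 via GGML_BF16_TO_FP32 / GGML_FP32_TO_BF16. A minimal standalone sketch of that round trip, assuming bf16 stores the top 16 bits of an IEEE-754 binary32 value; the real macros may treat NaN specially, which this sketch ignores:

    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    typedef uint16_t bf16_t;                     /* illustrative stand-in for ggml_bf16_t */

    static float bf16_to_f32(bf16_t h) {
        uint32_t u = (uint32_t) h << 16;         /* bf16 is the high half of an f32 */
        float f;
        memcpy(&f, &u, sizeof(f));
        return f;
    }

    static bf16_t f32_to_bf16(float f) {
        uint32_t u;
        memcpy(&u, &f, sizeof(u));
        u += 0x7fff + ((u >> 16) & 1);           /* round to nearest, ties to even */
        return (bf16_t)(u >> 16);
    }

    int main(void) {
        const float v = 1.0f;                    /* the scalar that add1 adds */
        const float x = 3.14159f;
        bf16_t y = f32_to_bf16(bf16_to_f32(f32_to_bf16(x)) + v);
        printf("%.6f + %.1f ~= %.6f (bf16)\n", x, v, bf16_to_f32(y));
        return 0;
    }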
8393
9393
 
@@ -8415,6 +9415,18 @@ static void ggml_compute_forward_add1(
8415
9415
  GGML_ASSERT(false);
8416
9416
  }
8417
9417
  } break;
9418
+ case GGML_TYPE_BF16:
9419
+ {
9420
+ if (src1->type == GGML_TYPE_BF16) {
9421
+ ggml_compute_forward_add1_bf16_bf16(params, dst);
9422
+ }
9423
+ else if (src1->type == GGML_TYPE_F32) {
9424
+ ggml_compute_forward_add1_bf16_f32(params, dst);
9425
+ }
9426
+ else {
9427
+ GGML_ASSERT(false);
9428
+ }
9429
+ } break;
8418
9430
  case GGML_TYPE_Q4_0:
8419
9431
  case GGML_TYPE_Q4_1:
8420
9432
  case GGML_TYPE_Q5_0:
@@ -8543,6 +9555,7 @@ static void ggml_compute_forward_acc(
8543
9555
  ggml_compute_forward_acc_f32(params, dst);
8544
9556
  } break;
8545
9557
  case GGML_TYPE_F16:
9558
+ case GGML_TYPE_BF16:
8546
9559
  case GGML_TYPE_Q4_0:
8547
9560
  case GGML_TYPE_Q4_1:
8548
9561
  case GGML_TYPE_Q5_0:
@@ -9064,6 +10077,40 @@ static void ggml_compute_forward_sum_f16(
9064
10077
  ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
9065
10078
  }
9066
10079
 
10080
+ static void ggml_compute_forward_sum_bf16(
10081
+ const struct ggml_compute_params * params,
10082
+ struct ggml_tensor * dst) {
10083
+
10084
+ const struct ggml_tensor * src0 = dst->src[0];
10085
+
10086
+ assert(params->ith == 0);
10087
+ assert(ggml_is_scalar(dst));
10088
+
10089
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
10090
+ return;
10091
+ }
10092
+
10093
+ assert(src0->nb[0] == sizeof(ggml_bf16_t));
10094
+
10095
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
10096
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
10097
+
10098
+ float sum = 0;
10099
+ float row_sum = 0;
10100
+
10101
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
10102
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
10103
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
10104
+ ggml_vec_sum_bf16_ggf(ne00,
10105
+ &row_sum,
10106
+ (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
10107
+ sum += row_sum;
10108
+ }
10109
+ }
10110
+ }
10111
+ ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
10112
+ }
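Editor's note: ggml_vec_sum_bf16_ggf is not shown in this diff; it presumably accumulates a bf16 row in f32, which is what the sum kernel above relies on. A scalar analogue under the same bf16-is-the-high-half assumption:

    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    /* sum a row of bf16 values, accumulating in f32 to limit precision loss */
    static float bf16_row_sum(const uint16_t * x, int64_t n) {
        float sum = 0.0f;
        for (int64_t i = 0; i < n; ++i) {
            uint32_t u = (uint32_t) x[i] << 16;  /* widen bf16 -> f32 */
            float v;
            memcpy(&v, &u, sizeof(v));
            sum += v;
        }
        return sum;
    }

    int main(void) {
        const uint16_t row[3] = { 0x3f80, 0x4000, 0x4040 };  /* 1.0, 2.0, 3.0 in bf16 */
        printf("sum = %f\n", bf16_row_sum(row, 3));
        return 0;
    }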
10113
+
9067
10114
  static void ggml_compute_forward_sum(
9068
10115
  const struct ggml_compute_params * params,
9069
10116
  struct ggml_tensor * dst) {
@@ -9079,6 +10126,10 @@ static void ggml_compute_forward_sum(
9079
10126
  {
9080
10127
  ggml_compute_forward_sum_f16(params, dst);
9081
10128
  } break;
10129
+ case GGML_TYPE_BF16:
10130
+ {
10131
+ ggml_compute_forward_sum_bf16(params, dst);
10132
+ } break;
9082
10133
  default:
9083
10134
  {
9084
10135
  GGML_ASSERT(false);
@@ -9353,6 +10404,7 @@ static void ggml_compute_forward_repeat(
9353
10404
 
9354
10405
  switch (src0->type) {
9355
10406
  case GGML_TYPE_F16:
10407
+ case GGML_TYPE_BF16:
9356
10408
  case GGML_TYPE_I16:
9357
10409
  {
9358
10410
  ggml_compute_forward_repeat_f16(params, dst);
@@ -11670,6 +12722,7 @@ static void ggml_compute_forward_set(
11670
12722
  ggml_compute_forward_set_f32(params, dst);
11671
12723
  } break;
11672
12724
  case GGML_TYPE_F16:
12725
+ case GGML_TYPE_BF16:
11673
12726
  case GGML_TYPE_Q4_0:
11674
12727
  case GGML_TYPE_Q4_1:
11675
12728
  case GGML_TYPE_Q5_0:
@@ -11844,6 +12897,49 @@ static void ggml_compute_forward_get_rows_f16(
11844
12897
  }
11845
12898
  }
11846
12899
 
12900
+ static void ggml_compute_forward_get_rows_bf16(
12901
+ const struct ggml_compute_params * params,
12902
+ struct ggml_tensor * dst) {
12903
+
12904
+ const struct ggml_tensor * src0 = dst->src[0];
12905
+ const struct ggml_tensor * src1 = dst->src[1];
12906
+
12907
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
12908
+ return;
12909
+ }
12910
+
12911
+ GGML_TENSOR_BINARY_OP_LOCALS
12912
+
12913
+ const int64_t nc = ne00;
12914
+ const int64_t nr = ggml_nelements(src1);
12915
+
12916
+ assert(ne0 == nc);
12917
+ assert(ne02 == ne11);
12918
+ assert(nb00 == sizeof(ggml_bf16_t));
12919
+ assert(ggml_nrows(dst) == nr);
12920
+
12921
+ const int ith = params->ith;
12922
+ const int nth = params->nth;
12923
+
12924
+ // rows per thread
12925
+ const int dr = (nr + nth - 1)/nth;
12926
+
12927
+ // row range for this thread
12928
+ const int ir0 = dr*ith;
12929
+ const int ir1 = MIN(ir0 + dr, nr);
12930
+
12931
+ for (int64_t i = ir0; i < ir1; ++i) {
12932
+ const int64_t i12 = i/(ne11*ne10);
12933
+ const int64_t i11 = (i - i12*ne11*ne10)/ne10;
12934
+ const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
12935
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
12936
+
12937
+ ggml_bf16_to_fp32_row(
12938
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
12939
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
12940
+ }
12941
+ }
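Editor's note: the bf16 get_rows kernel above looks up one source row id per destination row (the 3-D broadcast indices i12/i11/i10 are part of that lookup) and converts the gathered row with ggml_bf16_to_fp32_row. A simplified f32-only sketch of the gather itself, with illustrative names, contiguous layout, and the broadcast indices collapsed into a flat id list:

    #include <stdint.h>
    #include <stdio.h>

    /* gather rows of src (nc columns per row) into dst using int32 row ids */
    static void get_rows_f32(const float * src, int64_t nc,
                             const int32_t * ids, int64_t n_ids,
                             float * dst) {
        for (int64_t i = 0; i < n_ids; ++i) {
            const int64_t i01 = ids[i];          /* source row picked by src1 */
            for (int64_t c = 0; c < nc; ++c) {
                dst[i*nc + c] = src[i01*nc + c]; /* the bf16 kernel converts here instead of copying */
            }
        }
    }

    int main(void) {
        const float   src[2*3] = { 1, 2, 3,   4, 5, 6 };
        const int32_t ids[2]   = { 1, 0 };
        float dst[2*3];
        get_rows_f32(src, 3, ids, 2, dst);
        printf("%g %g %g / %g %g %g\n", dst[0], dst[1], dst[2], dst[3], dst[4], dst[5]);
        return 0;
    }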
12942
+
11847
12943
  static void ggml_compute_forward_get_rows_f32(
11848
12944
  const struct ggml_compute_params * params,
11849
12945
  struct ggml_tensor * dst) {
@@ -11921,6 +13017,10 @@ static void ggml_compute_forward_get_rows(
11921
13017
  {
11922
13018
  ggml_compute_forward_get_rows_f16(params, dst);
11923
13019
  } break;
13020
+ case GGML_TYPE_BF16:
13021
+ {
13022
+ ggml_compute_forward_get_rows_bf16(params, dst);
13023
+ } break;
11924
13024
  case GGML_TYPE_F32:
11925
13025
  case GGML_TYPE_I32:
11926
13026
  {
@@ -12255,7 +13355,7 @@ static void ggml_compute_forward_soft_max_f32(
12255
13355
 
12256
13356
  GGML_TENSOR_UNARY_OP_LOCALS
12257
13357
 
12258
- const int64_t ne11 = src1 ? src1->ne[1] : 1;
13358
+ //const int64_t ne11 = src1 ? src1->ne[1] : 1;
12259
13359
 
12260
13360
  // TODO: is this supposed to be ceil instead of floor?
12261
13361
  // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -12278,19 +13378,31 @@ static void ggml_compute_forward_soft_max_f32(
12278
13378
  float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
12279
13379
 
12280
13380
  // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
12281
- float * pos = src2 ? (float *) src2->data : src0->data;
13381
+ ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
13382
+ float * pos_f32 = src2 ? (float *) src2->data : src0->data;
13383
+
13384
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
12282
13385
 
12283
13386
  for (int i1 = ir0; i1 < ir1; i1++) {
12284
13387
  float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
12285
13388
  float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
12286
13389
 
12287
13390
  // broadcast the mask across rows
12288
- float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
13391
+ ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
13392
+ float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
12289
13393
 
12290
13394
  ggml_vec_cpy_f32 (nc, wp, sp);
12291
13395
  ggml_vec_scale_f32(nc, wp, scale);
12292
- if (mp) {
12293
- ggml_vec_acc_f32(nc, wp, mp);
13396
+ if (mp_f32) {
13397
+ if (use_f16) {
13398
+ for (int i = 0; i < nc; ++i) {
13399
+ wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
13400
+ }
13401
+ } else {
13402
+ for (int i = 0; i < nc; ++i) {
13403
+ wp[i] += mp_f32[i];
13404
+ }
13405
+ }
12294
13406
  }
12295
13407
 
12296
13408
  // ALiBi bias
@@ -12298,8 +13410,14 @@ static void ggml_compute_forward_soft_max_f32(
12298
13410
  const uint32_t h = (i1/ne01)%ne02; // head
12299
13411
  const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
12300
13412
 
12301
- for (int i = 0; i < nc; i++) {
12302
- wp[i] = wp[i] + slope*pos[i];
13413
+ if (use_f16) {
13414
+ for (int i = 0; i < nc; ++i) {
13415
+ wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
13416
+ }
13417
+ } else {
13418
+ for (int i = 0; i < nc; ++i) {
13419
+ wp[i] += slope*pos_f32[i];
13420
+ }
12303
13421
  }
12304
13422
  }
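Editor's note: the soft_max changes above add the (optionally F16) mask and the ALiBi positional bias into the working buffer before exponentiation. A scalar f32-only sketch of that per-row preprocessing followed by the softmax itself; names are illustrative, the exp/normalize step sits outside this hunk, and the real code also converts F16 masks/positions and derives `slope` per attention head:

    #include <math.h>
    #include <stdio.h>

    /* wp[i] = sp[i]*scale + mask[i] + slope*pos[i], then softmax in place */
    static void softmax_row(const float * sp, const float * mask, const float * pos,
                            float scale, float slope, int nc, float * wp) {
        float max = -INFINITY;
        for (int i = 0; i < nc; ++i) {
            wp[i] = sp[i]*scale;
            if (mask) wp[i] += mask[i];
            if (pos)  wp[i] += slope*pos[i];
            if (wp[i] > max) max = wp[i];
        }
        float sum = 0.0f;
        for (int i = 0; i < nc; ++i) {
            wp[i] = expf(wp[i] - max);
            sum += wp[i];
        }
        for (int i = 0; i < nc; ++i) {
            wp[i] /= sum;
        }
    }

    int main(void) {
        const float sp[4]   = { 0.5f, 1.5f, -0.5f, 2.0f };
        const float mask[4] = { 0.0f, 0.0f, -INFINITY, 0.0f };  /* -inf masks a position out */
        float wp[4];
        softmax_row(sp, mask, NULL, 1.0f, 0.0f, 4, wp);
        printf("%f %f %f %f\n", wp[0], wp[1], wp[2], wp[3]);
        return 0;
    }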
12305
13423
 
@@ -12598,6 +13716,7 @@ static void ggml_compute_forward_alibi(
12598
13716
  {
12599
13717
  ggml_compute_forward_alibi_f32(params, dst);
12600
13718
  } break;
13719
+ case GGML_TYPE_BF16:
12601
13720
  case GGML_TYPE_Q4_0:
12602
13721
  case GGML_TYPE_Q4_1:
12603
13722
  case GGML_TYPE_Q5_0:
@@ -12687,6 +13806,7 @@ static void ggml_compute_forward_clamp(
12687
13806
  ggml_compute_forward_clamp_f32(params, dst);
12688
13807
  } break;
12689
13808
  case GGML_TYPE_F16:
13809
+ case GGML_TYPE_BF16:
12690
13810
  case GGML_TYPE_Q4_0:
12691
13811
  case GGML_TYPE_Q4_1:
12692
13812
  case GGML_TYPE_Q5_0:
@@ -14569,6 +15689,198 @@ static void ggml_compute_forward_flash_attn(
14569
15689
  }
14570
15690
  }
14571
15691
 
15692
+ // ggml_compute_forward_flash_attn_ext
15693
+
15694
+ static void ggml_compute_forward_flash_attn_ext_f16(
15695
+ const struct ggml_compute_params * params,
15696
+ const struct ggml_tensor * q,
15697
+ const struct ggml_tensor * k,
15698
+ const struct ggml_tensor * v,
15699
+ const struct ggml_tensor * mask,
15700
+ struct ggml_tensor * dst) {
15701
+ int64_t t0 = ggml_perf_time_us();
15702
+ UNUSED(t0);
15703
+
15704
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15705
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15706
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15707
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15708
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15709
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15710
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15711
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15712
+
15713
+ const int ith = params->ith;
15714
+ const int nth = params->nth;
15715
+
15716
+ const int64_t D = neq0;
15717
+ const int64_t N = neq1;
15718
+
15719
+ GGML_ASSERT(ne0 == D);
15720
+ GGML_ASSERT(ne2 == N);
15721
+
15722
+ GGML_ASSERT(nbq0 == sizeof(float));
15723
+ GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
15724
+ GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
15725
+
15726
+ GGML_ASSERT(neq0 == D);
15727
+ GGML_ASSERT(nek0 == D);
15728
+ GGML_ASSERT(nev0 == D);
15729
+
15730
+ GGML_ASSERT(neq1 == N);
15731
+ GGML_ASSERT(nev0 == D);
15732
+
15733
+ // dst cannot be transposed or permuted
15734
+ GGML_ASSERT(nb0 == sizeof(float));
15735
+ GGML_ASSERT(nb0 <= nb1);
15736
+ GGML_ASSERT(nb1 <= nb2);
15737
+ GGML_ASSERT(nb2 <= nb3);
15738
+
15739
+ // broadcast factors
15740
+ const int64_t rk2 = neq2/nek2;
15741
+ const int64_t rk3 = neq3/nek3;
15742
+
15743
+ const int64_t rv2 = neq2/nev2;
15744
+ const int64_t rv3 = neq3/nev3;
15745
+
15746
+ if (params->type == GGML_TASK_TYPE_INIT) {
15747
+ return;
15748
+ }
15749
+
15750
+ if (params->type == GGML_TASK_TYPE_FINALIZE) {
15751
+ return;
15752
+ }
15753
+
15754
+ // parallelize by q rows using ggml_vec_dot_f16
15755
+
15756
+ // total rows in q
15757
+ const int nr = neq1*neq2*neq3;
15758
+
15759
+ // rows per thread
15760
+ const int dr = (nr + nth - 1)/nth;
15761
+
15762
+ // row range for this thread
15763
+ const int ir0 = dr*ith;
15764
+ const int ir1 = MIN(ir0 + dr, nr);
15765
+
15766
+ float scale = 1.0f;
15767
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
15768
+
15769
+ // loop over n_batch and n_head
15770
+ for (int ir = ir0; ir < ir1; ++ir) {
15771
+ // q indices
15772
+ const int iq3 = ir/(neq2*neq1);
15773
+ const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15774
+ const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15775
+
15776
+ float S = 0.0f;
15777
+ float M = -INFINITY;
15778
+
15779
+ float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
15780
+ ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
15781
+ ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
15782
+
15783
+ memset(V16, 0, D*sizeof(ggml_fp16_t));
15784
+
15785
+ const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
15786
+
15787
+ // k indices
15788
+ const int ik3 = iq3 / rk3;
15789
+ const int ik2 = iq2 / rk2;
15790
+
15791
+ // v indices
15792
+ const int iv3 = iq3 / rv3;
15793
+ const int iv2 = iq2 / rv2;
15794
+
15795
+ // online softmax / attention
15796
+ // loop over n_kv and n_head_kv
15797
+ // ref: https://arxiv.org/pdf/2112.05682.pdf
15798
+ for (int64_t ic = 0; ic < nek1; ++ic) {
15799
+ const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
15800
+ if (mv == -INFINITY) {
15801
+ continue;
15802
+ }
15803
+
15804
+ float s;
15805
+
15806
+ // convert Q to F16 in V32
15807
+ {
15808
+ const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
15809
+
15810
+ for (int64_t d = 0; d < D; ++d) {
15811
+ Q16[d] = GGML_FP32_TO_FP16(pq[d]);
15812
+ }
15813
+ }
15814
+
15815
+ ggml_vec_dot_f16(D,
15816
+ &s, 0,
15817
+ (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15818
+ Q16, 0, 1);
15819
+
15820
+ s = s*scale + mv;
15821
+
15822
+ const float Mold = M;
15823
+
15824
+ float ms = 1.0f;
15825
+ float vs = 1.0f;
15826
+
15827
+ if (s > M) {
15828
+ M = s;
15829
+ ms = expf(Mold - M);
15830
+
15831
+ // V = V*expf(Mold - M)
15832
+ ggml_vec_scale_f16(D, V16, ms);
15833
+ } else {
15834
+ vs = expf(s - M);
15835
+ }
15836
+
15837
+ const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
15838
+
15839
+ // V += v*expf(s - M)
15840
+ ggml_vec_mad_f16(D, V16, v16, vs);
15841
+
15842
+ S = S*ms + vs;
15843
+ }
15844
+
15845
+ // V /= S
15846
+ for (int64_t d = 0; d < D; ++d) {
15847
+ V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
15848
+ }
15849
+
15850
+ // dst indices
15851
+ const int i1 = iq1;
15852
+ const int i2 = iq2;
15853
+ const int i3 = iq3;
15854
+
15855
+ // original
15856
+ //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
15857
+
15858
+ // permute(0, 2, 1, 3)
15859
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
15860
+ }
15861
+ }
15862
+
15863
+ static void ggml_compute_forward_flash_attn_ext(
15864
+ const struct ggml_compute_params * params,
15865
+ const struct ggml_tensor * q,
15866
+ const struct ggml_tensor * k,
15867
+ const struct ggml_tensor * v,
15868
+ const struct ggml_tensor * mask,
15869
+ struct ggml_tensor * dst) {
15870
+ switch (dst->op_params[1]) {
15871
+ case GGML_PREC_DEFAULT:
15872
+ case GGML_PREC_F32:
15873
+ {
15874
+ // uses F32 accumulators
15875
+ ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
15876
+ } break;
15877
+ default:
15878
+ {
15879
+ GGML_ASSERT(false);
15880
+ } break;
15881
+ }
15882
+ }
15883
+
14572
15884
  // ggml_compute_forward_flash_ff
14573
15885
 
14574
15886
  static void ggml_compute_forward_flash_ff_f16(
@@ -15588,6 +16900,7 @@ static void ggml_compute_forward_get_rel_pos(
15588
16900
 
15589
16901
  switch (src0->type) {
15590
16902
  case GGML_TYPE_F16:
16903
+ case GGML_TYPE_BF16:
15591
16904
  {
15592
16905
  ggml_compute_forward_get_rel_pos_f16(params, dst);
15593
16906
  } break;
@@ -16376,6 +17689,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
16376
17689
  const bool masked = t != 0;
16377
17690
  ggml_compute_forward_flash_attn(params, masked, tensor);
16378
17691
  } break;
17692
+ case GGML_OP_FLASH_ATTN_EXT:
17693
+ {
17694
+ ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
17695
+ } break;
16379
17696
  case GGML_OP_FLASH_FF:
16380
17697
  {
16381
17698
  ggml_compute_forward_flash_ff(params, tensor);
@@ -17388,6 +18705,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
17388
18705
  GGML_ASSERT(false); // TODO: not implemented
17389
18706
  } break;
17390
18707
  case GGML_OP_FLASH_ATTN:
18708
+ case GGML_OP_FLASH_ATTN_EXT:
17391
18709
  {
17392
18710
  struct ggml_tensor * flash_grad = NULL;
17393
18711
  if (src0->grad || src1->grad || tensor->src[2]->grad) {
@@ -18160,6 +19478,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
18160
19478
  n_tasks = n_threads;
18161
19479
  } break;
18162
19480
  case GGML_OP_FLASH_ATTN:
19481
+ case GGML_OP_FLASH_ATTN_EXT:
18163
19482
  {
18164
19483
  n_tasks = n_threads;
18165
19484
  } break;
@@ -18446,7 +19765,10 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18446
19765
  case GGML_OP_CPY:
18447
19766
  case GGML_OP_DUP:
18448
19767
  {
18449
- if (ggml_is_quantized(node->type)) {
19768
+ if (ggml_is_quantized(node->type) ||
19769
+ // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
19770
+ (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
19771
+ (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
18450
19772
  cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
18451
19773
  }
18452
19774
  } break;
@@ -18525,7 +19847,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18525
19847
  const int64_t ne10 = node->src[1]->ne[0]; // L
18526
19848
  const int64_t ne11 = node->src[1]->ne[1]; // Cin
18527
19849
 
18528
- if (node->src[0]->type == GGML_TYPE_F16 &&
19850
+ if ((node->src[0]->type == GGML_TYPE_F16 ||
19851
+ node->src[0]->type == GGML_TYPE_BF16) &&
18529
19852
  node->src[1]->type == GGML_TYPE_F32) {
18530
19853
  cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
18531
19854
  cur += sizeof(ggml_fp16_t)*ne10*ne11;
@@ -18561,8 +19884,17 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18561
19884
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18562
19885
  cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
18563
19886
  cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19887
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19888
+ cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19889
+ cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
18564
19890
  }
18565
19891
  } break;
19892
+ case GGML_OP_FLASH_ATTN_EXT:
19893
+ {
19894
+ const int64_t ne00 = node->src[0]->ne[0]; // D
19895
+
19896
+ cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
19897
+ } break;
18566
19898
  case GGML_OP_FLASH_FF:
18567
19899
  {
18568
19900
  if (node->src[1]->type == GGML_TYPE_F32) {
@@ -18571,6 +19903,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18571
19903
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18572
19904
  cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
18573
19905
  cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19906
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19907
+ cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19908
+ cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
18574
19909
  }
18575
19910
  } break;
18576
19911
  case GGML_OP_FLASH_ATTN_BACK:
@@ -18584,6 +19919,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18584
19919
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18585
19920
  cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
18586
19921
  cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
19922
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19923
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
19924
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
18587
19925
  }
18588
19926
  } break;
18589
19927
 
@@ -19360,7 +20698,9 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
19360
20698
  if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
19361
20699
  fprintf(fp, "%d", ggml_get_i32_1d(node, j));
19362
20700
  }
19363
- else if (node->type == GGML_TYPE_F32 || node->type == GGML_TYPE_F16) {
20701
+ else if (node->type == GGML_TYPE_F32 ||
20702
+ node->type == GGML_TYPE_F16 ||
20703
+ node->type == GGML_TYPE_BF16) {
19364
20704
  fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j));
19365
20705
  }
19366
20706
  else {
@@ -20418,6 +21758,12 @@ size_t ggml_quantize_chunk(
20418
21758
  ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
20419
21759
  result = n * elemsize;
20420
21760
  } break;
21761
+ case GGML_TYPE_BF16:
21762
+ {
21763
+ size_t elemsize = sizeof(ggml_bf16_t);
21764
+ ggml_fp32_to_bf16_row(src + start, (ggml_bf16_t *)dst + start, n);
21765
+ result = n * elemsize;
21766
+ } break;
20421
21767
  case GGML_TYPE_F32:
20422
21768
  {
20423
21769
  size_t elemsize = sizeof(float);
@@ -20614,7 +21960,7 @@ static void gguf_free_kv(struct gguf_kv * kv) {
20614
21960
  }
20615
21961
 
20616
21962
  struct gguf_context * gguf_init_empty(void) {
20617
- struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context));
21963
+ struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context));
20618
21964
 
20619
21965
  memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic));
20620
21966
  ctx->header.version = GGUF_VERSION;
@@ -20659,7 +22005,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20659
22005
 
20660
22006
  bool ok = true;
20661
22007
 
20662
- struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context));
22008
+ struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context));
20663
22009
 
20664
22010
  // read the header
20665
22011
  {
@@ -20696,9 +22042,13 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20696
22042
 
20697
22043
  // read the kv pairs
20698
22044
  {
20699
- ctx->kv = GGML_MALLOC(ctx->header.n_kv * sizeof(struct gguf_kv));
22045
+ const uint64_t n_kv = ctx->header.n_kv;
20700
22046
 
20701
- for (uint64_t i = 0; i < ctx->header.n_kv; ++i) {
22047
+ // header.n_kv will hold the actual number of pairs that were successfully read in the loop below
22048
+ ctx->header.n_kv = 0;
22049
+ ctx->kv = GGML_CALLOC(n_kv, sizeof(struct gguf_kv));
22050
+
22051
+ for (uint64_t i = 0; i < n_kv; ++i) {
20702
22052
  struct gguf_kv * kv = &ctx->kv[i];
20703
22053
 
20704
22054
  //fprintf(stderr, "%s: reading kv %d\n", __func__, i);
@@ -20747,7 +22097,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20747
22097
  return NULL;
20748
22098
  }
20749
22099
 
20750
- kv->value.arr.data = GGML_MALLOC(kv->value.arr.n * gguf_type_size(kv->value.arr.type));
22100
+ kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, gguf_type_size(kv->value.arr.type));
20751
22101
 
20752
22102
  ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type), &offset);
20753
22103
  } break;
@@ -20761,7 +22111,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20761
22111
  return NULL;
20762
22112
  }
20763
22113
 
20764
- kv->value.arr.data = GGML_MALLOC(kv->value.arr.n * sizeof(struct gguf_str));
22114
+ kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, sizeof(struct gguf_str));
20765
22115
 
20766
22116
  for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
20767
22117
  ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset);
@@ -20777,6 +22127,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20777
22127
  if (!ok) {
20778
22128
  break;
20779
22129
  }
22130
+
22131
+ ctx->header.n_kv++;
20780
22132
  }
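Editor's note: the kv-reading changes above switch to zero-initialized allocation and only count pairs that were fully read, so the cleanup path frees exactly what was initialized. A standalone sketch of that pattern with illustrative names (not the gguf API):

    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>

    struct kv { char * key; };

    /* read up to n_kv newline-terminated keys; returns how many were fully read */
    static size_t read_kvs(FILE * f, struct kv * kvs, size_t n_kv) {
        size_t n_read = 0;                       /* bumped only after a successful read */
        char buf[64];
        for (size_t i = 0; i < n_kv; ++i) {
            if (!fgets(buf, sizeof(buf), f)) break;
            buf[strcspn(buf, "\n")] = '\0';
            kvs[i].key = strdup(buf);
            if (!kvs[i].key) break;
            n_read++;
        }
        return n_read;
    }

    int main(void) {
        const size_t n_kv = 4;                              /* would come from the file header */
        struct kv * kvs = calloc(n_kv, sizeof(*kvs));       /* zeroed, like GGML_CALLOC above */
        const size_t n_read = read_kvs(stdin, kvs, n_kv);
        printf("read %zu/%zu pairs\n", n_read, n_kv);
        for (size_t i = 0; i < n_read; ++i) free(kvs[i].key);  /* free only initialized entries */
        free(kvs);
        return 0;
    }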
20781
22133
 
20782
22134
  if (!ok) {
@@ -20788,8 +22140,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20788
22140
  }
20789
22141
 
20790
22142
  // read the tensor infos
20791
- {
20792
- ctx->infos = GGML_MALLOC(ctx->header.n_tensors * sizeof(struct gguf_tensor_info));
22143
+ if (ctx->header.n_tensors > 0) {
22144
+ ctx->infos = GGML_CALLOC(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));
20793
22145
 
20794
22146
  for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
20795
22147
  struct gguf_tensor_info * info = &ctx->infos[i];
@@ -20810,8 +22162,17 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20810
22162
  ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset);
20811
22163
  ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset);
20812
22164
 
22165
+ // TODO: return an error instead of crashing with GGML_ASSERT
20813
22166
  gguf_tensor_info_sanitize(info);
20814
22167
 
22168
+ // make sure there are no duplicated tensor names
22169
+ for (uint64_t j = 0; j < i; ++j) {
22170
+ if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) {
22171
+ fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data);
22172
+ ok = false;
22173
+ }
22174
+ }
22175
+
20815
22176
  if (!ok) {
20816
22177
  fprintf(stderr, "%s: failed to read tensor info\n", __func__);
20817
22178
  fclose(file);
@@ -20980,7 +22341,7 @@ void gguf_free(struct gguf_context * ctx) {
20980
22341
  GGML_FREE(ctx->infos);
20981
22342
  }
20982
22343
 
20983
- GGML_ALIGNED_FREE(ctx);
22344
+ GGML_FREE(ctx);
20984
22345
  }
20985
22346
 
20986
22347
  const char * gguf_type_name(enum gguf_type type) {
@@ -21291,7 +22652,7 @@ void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_ty
21291
22652
  ctx->kv[idx].type = GGUF_TYPE_ARRAY;
21292
22653
  ctx->kv[idx].value.arr.type = type;
21293
22654
  ctx->kv[idx].value.arr.n = n;
21294
- ctx->kv[idx].value.arr.data = GGML_MALLOC(n*gguf_type_size(type));
22655
+ ctx->kv[idx].value.arr.data = GGML_CALLOC(n, gguf_type_size(type));
21295
22656
  memcpy(ctx->kv[idx].value.arr.data, data, n*gguf_type_size(type));
21296
22657
  }
21297
22658
 
@@ -21301,7 +22662,7 @@ void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char **
21301
22662
  ctx->kv[idx].type = GGUF_TYPE_ARRAY;
21302
22663
  ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING;
21303
22664
  ctx->kv[idx].value.arr.n = n;
21304
- ctx->kv[idx].value.arr.data = GGML_MALLOC(n*sizeof(struct gguf_str));
22665
+ ctx->kv[idx].value.arr.data = GGML_CALLOC(n, sizeof(struct gguf_str));
21305
22666
  for (int i = 0; i < n; i++) {
21306
22667
  struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i];
21307
22668
  str->n = strlen(data[i]);
@@ -21328,7 +22689,7 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
21328
22689
  case GGUF_TYPE_ARRAY:
21329
22690
  {
21330
22691
  if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) {
21331
- const char ** data = GGML_MALLOC(src->kv[i].value.arr.n*sizeof(char *));
22692
+ const char ** data = GGML_CALLOC(src->kv[i].value.arr.n, sizeof(char *));
21332
22693
  for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) {
21333
22694
  data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data;
21334
22695
  }
@@ -21348,6 +22709,10 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
21348
22709
  void gguf_add_tensor(
21349
22710
  struct gguf_context * ctx,
21350
22711
  const struct ggml_tensor * tensor) {
22712
+ if (gguf_find_tensor(ctx, tensor->name) != -1) {
22713
+ GGML_ASSERT(false && "duplicated tensor name");
22714
+ }
22715
+
21351
22716
  const int idx = ctx->header.n_tensors;
21352
22717
  ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info));
21353
22718
 
@@ -21416,7 +22781,7 @@ struct gguf_buf {
21416
22781
 
21417
22782
  static struct gguf_buf gguf_buf_init(size_t size) {
21418
22783
  struct gguf_buf buf = {
21419
- /*buf.data =*/ size == 0 ? NULL : GGML_MALLOC(size),
22784
+ /*buf.data =*/ size == 0 ? NULL : GGML_CALLOC(1, size),
21420
22785
  /*buf.size =*/ size,
21421
22786
  /*buf.offset =*/ 0,
21422
22787
  };