llama_cpp 0.14.7 → 0.15.1

This diff shows the changes between the publicly released versions of this package as they appear in its registry, and is provided for informational purposes only. The bulk of the changes land in ggml.c: a new BF16 (bfloat16) tensor type with conversion, dup, add/add1, sum, repeat, set and get_rows support, plus a new FLASH_ATTN_EXT operator.
@@ -322,7 +322,7 @@ static ggml_fp16_t ggml_table_exp_f16[1 << 16];
322
322
  // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
323
323
  float ggml_table_f32_f16[1 << 16];
324
324
 
325
- const char * ggml_status_to_string(enum ggml_status status) {
325
+ GGML_CALL const char * ggml_status_to_string(enum ggml_status status) {
326
326
  switch (status) {
327
327
  case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
328
328
  case GGML_STATUS_FAILED: return "GGML status: error (operation failed)";
@@ -333,16 +333,26 @@ const char * ggml_status_to_string(enum ggml_status status) {
333
333
  return "GGML status: unknown";
334
334
  }
335
335
 
336
- // note: do not use these inside ggml.c
337
- // these are meant to be used via the ggml.h API
338
336
  float ggml_fp16_to_fp32(ggml_fp16_t x) {
337
+ #define ggml_fp16_to_fp32 do_not_use__ggml_fp16_to_fp32__in_ggml
339
338
  return GGML_FP16_TO_FP32(x);
340
339
  }
341
340
 
342
341
  ggml_fp16_t ggml_fp32_to_fp16(float x) {
342
+ #define ggml_fp32_to_fp16 do_not_use__ggml_fp32_to_fp16__in_ggml
343
343
  return GGML_FP32_TO_FP16(x);
344
344
  }
345
345
 
346
+ float ggml_bf16_to_fp32(ggml_bf16_t x) {
347
+ #define ggml_bf16_to_fp32 do_not_use__ggml_bf16_to_fp32__in_ggml
348
+ return GGML_BF16_TO_FP32(x); // it just left shifts
349
+ }
350
+
351
+ ggml_bf16_t ggml_fp32_to_bf16(float x) {
352
+ #define ggml_fp32_to_bf16 do_not_use__ggml_fp32_to_bf16__in_ggml
353
+ return GGML_FP32_TO_BF16(x);
354
+ }
355
+
346
356
  void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
347
357
  for (int64_t i = 0; i < n; i++) {
348
358
  y[i] = GGML_FP16_TO_FP32(x[i]);
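For orientation on the bf16 helpers added above: bfloat16 keeps the sign bit and 8-bit exponent of an IEEE-754 binary32 and truncates the mantissa to 7 bits, so widening to f32 is just a 16-bit left shift (as the comment notes), while narrowing can be as simple as dropping the low half. A minimal, self-contained sketch of that bit manipulation; the helper names are made up, and the shipped GGML_FP32_TO_BF16 macro may additionally round and quiet NaNs:

```c
#include <stdint.h>
#include <string.h>

// Hypothetical stand-ins for the BF16 conversion macros used above.
static float bf16_bits_to_f32(uint16_t h) {
    uint32_t u = (uint32_t) h << 16;  // bf16 occupies the top 16 bits of an f32
    float f;
    memcpy(&f, &u, sizeof(f));
    return f;
}

static uint16_t f32_to_bf16_bits_truncate(float f) {
    uint32_t u;
    memcpy(&u, &f, sizeof(u));
    return (uint16_t)(u >> 16);       // drop the low 16 mantissa bits (no rounding)
}
```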
@@ -368,6 +378,49 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
368
378
  }
369
379
  }
370
380
 
381
+ void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
382
+ int64_t i = 0;
383
+ #if defined(__AVX512F__)
384
+ for (; i + 16 <= n; i += 16) {
385
+ _mm512_storeu_ps(y + i,
386
+ _mm512_castsi512_ps(
387
+ _mm512_slli_epi32(
388
+ _mm512_cvtepu16_epi32(
389
+ _mm256_loadu_si256(
390
+ (const __m256i *)(x + i))),
391
+ 16)));
392
+ }
393
+ #elif defined(__AVX2__)
394
+ for (; i + 8 <= n; i += 8) {
395
+ _mm256_storeu_ps(y + i,
396
+ _mm256_castsi256_ps(
397
+ _mm256_slli_epi32(
398
+ _mm256_cvtepu16_epi32(
399
+ _mm_loadu_si128(
400
+ (const __m128i *)(x + i))),
401
+ 16)));
402
+ }
403
+ #endif
404
+ for (; i < n; i++) {
405
+ y[i] = GGML_BF16_TO_FP32(x[i]);
406
+ }
407
+ }
408
+
409
+ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
410
+ int i = 0;
411
+ #if defined(__AVX512BF16__)
412
+ for (; i + 32 <= n; i += 32) {
413
+ _mm512_storeu_ps(
414
+ (__m512 *)(y + i),
415
+ (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
416
+ _mm512_loadu_ps(x + i)));
417
+ }
418
+ #endif
419
+ for (; i < n; i++) {
420
+ y[i] = GGML_FP32_TO_BF16(x[i]);
421
+ }
422
+ }
423
+
371
424
  bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
372
425
  return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
373
426
  }
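A hedged usage sketch of the row converters added above, assuming they are exposed through ggml.h as ggml_fp32_to_bf16_row / ggml_bf16_to_fp32_row; the buffer contents are illustrative only:

```c
#include <stdio.h>
#include "ggml.h"

// Round-trip a small f32 buffer through BF16. The f32 -> BF16 step is lossy
// (7 mantissa bits); the BF16 -> f32 step is an exact widening.
void bf16_roundtrip_demo(void) {
    float       src[8] = {0.0f, 1.0f, -2.5f, 3.140625f, 1e-3f, 1e5f, -7.0f, 0.5f};
    ggml_bf16_t packed[8];
    float       back[8];

    ggml_fp32_to_bf16_row(src, packed, 8);
    ggml_bf16_to_fp32_row(packed, back, 8);

    for (int i = 0; i < 8; i++) {
        printf("%g -> %g\n", src[i], back[i]);
    }
}
```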
@@ -503,6 +556,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
503
556
 
504
557
  static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
505
558
  static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
559
+ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
506
560
 
507
561
  static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
508
562
  [GGML_TYPE_I8] = {
@@ -845,6 +899,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
845
899
  .type_size = sizeof(block_q8_K),
846
900
  .is_quantized = true,
847
901
  .from_float = quantize_row_q8_K,
902
+ },
903
+ [GGML_TYPE_BF16] = {
904
+ .type_name = "bf16",
905
+ .blck_size = 1,
906
+ .type_size = sizeof(ggml_bf16_t),
907
+ .is_quantized = false,
908
+ .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row,
909
+ .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row,
910
+ .from_float_reference = (ggml_from_float_t) ggml_fp32_to_bf16_row,
911
+ .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
912
+ .vec_dot_type = GGML_TYPE_BF16,
913
+ .nrows = 1,
848
914
  }
849
915
  };
850
916
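The new traits entry lets generic code handle BF16 through the same table as every other type. A sketch of how a caller outside ggml.c might use it, assuming the ggml_internal_get_type_traits accessor declared in ggml.h; the helper name is made up:

```c
#include "ggml.h"

// Convert one row of raw BF16 data to f32 through the traits table instead of
// calling the BF16 helpers directly; n is the element count.
void bf16_row_to_f32_via_traits(const void * bf16_row, float * out, int64_t n) {
    ggml_type_traits_t traits = ggml_internal_get_type_traits(GGML_TYPE_BF16);
    traits.to_float(bf16_row, out, n); // dispatches to ggml_bf16_to_fp32_row
}
```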
 
@@ -951,7 +1017,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
951
1017
  #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
952
1018
  #define GGML_F16_VEC_SET1 GGML_F16x8_SET1
953
1019
  #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p)
954
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
1020
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), r[i])
955
1021
  #define GGML_F16_VEC_FMA GGML_F16x8_FMA
956
1022
  #define GGML_F16_VEC_ADD GGML_F16x8_ADD
957
1023
  #define GGML_F16_VEC_MUL GGML_F16x8_MUL
@@ -977,7 +1043,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
977
1043
  #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
978
1044
  #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
979
1045
  #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
980
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
1046
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i])
981
1047
  #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
982
1048
  #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
983
1049
  #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
@@ -1046,7 +1112,7 @@ do { \
1046
1112
 
1047
1113
  // unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
1048
1114
  // so F16C guard isn't required
1049
- #define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x)))
1115
+ #define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
1050
1116
  #define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))
1051
1117
 
1052
1118
  #define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
@@ -1144,7 +1210,7 @@ do { \
1144
1210
 
1145
1211
  #if defined(__F16C__)
1146
1212
  // the _mm256_cvt intrinsics require F16C
1147
- #define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
1213
+ #define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
1148
1214
  #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
1149
1215
  #else
1150
1216
  static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
@@ -1480,6 +1546,8 @@ inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) {
1480
1546
 
1481
1547
  inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1482
1548
 
1549
+ inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1550
+
1483
1551
  inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
1484
1552
  inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
1485
1553
  inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
@@ -1498,7 +1566,7 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float *
1498
1566
  UNUSED(by);
1499
1567
  UNUSED(bs);
1500
1568
 
1501
- #ifdef GGML_SIMD
1569
+ #if defined(GGML_SIMD)
1502
1570
  float sumf = 0.0f;
1503
1571
  const int np = (n & ~(GGML_F32_STEP - 1));
1504
1572
 
@@ -1534,6 +1602,70 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float *
1534
1602
  *s = sumf;
1535
1603
  }
1536
1604
 
1605
+ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) {
1606
+ assert(nrc == 1);
1607
+ UNUSED(nrc);
1608
+ UNUSED(bx);
1609
+ UNUSED(by);
1610
+ UNUSED(bs);
1611
+ int i = 0;
1612
+ ggml_float sumf = 0;
1613
+
1614
+ #if defined(__AVX512BF16__)
1615
+ __m512 c1 = _mm512_setzero_ps();
1616
+ __m512 c2 = _mm512_setzero_ps();
1617
+ for (; i + 64 <= n; i += 64) {
1618
+ c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
1619
+ (__m512bh)_mm512_loadu_ps((const float *)(y + i)));
1620
+ c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
1621
+ (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
1622
+ }
1623
+ sumf += (ggml_float)_mm512_reduce_add_ps(c1);
1624
+ sumf += (ggml_float)_mm512_reduce_add_ps(c2);
1625
+
1626
+ #elif defined(__AVX512F__)
1627
+ #define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
1628
+ __m512 c1 = _mm512_setzero_ps();
1629
+ __m512 c2 = _mm512_setzero_ps();
1630
+ for (; i + 32 <= n; i += 32) {
1631
+ c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
1632
+ c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
1633
+ }
1634
+ sumf += (ggml_float)_mm512_reduce_add_ps(c1);
1635
+ sumf += (ggml_float)_mm512_reduce_add_ps(c2);
1636
+
1637
+ #undef LOAD
1638
+ #elif defined(__AVX2__)
1639
+ #define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
1640
+ __m256 c1 = _mm256_setzero_ps();
1641
+ __m256 c2 = _mm256_setzero_ps();
1642
+ __m256 c3 = _mm256_setzero_ps();
1643
+ __m256 c4 = _mm256_setzero_ps();
1644
+ for (; i + 32 <= n; i += 32) {
1645
+ c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
1646
+ c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
1647
+ c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
1648
+ c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
1649
+ }
1650
+ __m128 g;
1651
+ c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
1652
+ _mm256_add_ps(c2, c4));
1653
+ g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
1654
+ _mm256_castps256_ps128(c1));
1655
+ g = _mm_add_ps(g, _mm_movehl_ps(g, g));
1656
+ g = _mm_add_ss(g, _mm_movehdup_ps(g));
1657
+ sumf += (ggml_float)_mm_cvtss_f32(g);
1658
+
1659
+ #undef LOAD
1660
+ #endif
1661
+
1662
+ for (; i < n; ++i) {
1663
+ sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
1664
+ GGML_BF16_TO_FP32(y[i]));
1665
+ }
1666
+ *s = sumf;
1667
+ }
1668
+
1537
1669
  static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
1538
1670
  assert(nrc == 1);
1539
1671
  UNUSED(nrc);
@@ -1662,6 +1794,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
1662
1794
  #endif
1663
1795
  }
1664
1796
 
1797
+ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
1798
+ #if defined(GGML_SIMD)
1799
+ const int np = (n & ~(GGML_F16_STEP - 1));
1800
+
1801
+ GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
1802
+
1803
+ GGML_F16_VEC ax[GGML_F16_ARR];
1804
+ GGML_F16_VEC ay[GGML_F16_ARR];
1805
+
1806
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
1807
+ for (int j = 0; j < GGML_F16_ARR; j++) {
1808
+ ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
1809
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
1810
+ ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
1811
+
1812
+ GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
1813
+ }
1814
+ }
1815
+
1816
+ // leftovers
1817
+ for (int i = np; i < n; ++i) {
1818
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
1819
+ }
1820
+ #else
1821
+ // scalar
1822
+ for (int i = 0; i < n; ++i) {
1823
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
1824
+ }
1825
+ #endif
1826
+ }
1827
+
1665
1828
  // xs and vs are byte strides of x and v
1666
1829
  inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
1667
1830
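Like the existing f32 kernels, the new ggml_vec_mad_f16 processes GGML_F16_STEP elements per outer iteration, split into GGML_F16_ARR vector registers of GGML_F16_EPR lanes each, then handles leftovers scalar. A plain-C sketch of that blocking shape with made-up constants (the real values depend on the target ISA):

```c
// Illustration of the STEP / ARR / EPR blocking used by the GGML_F16_VEC kernels.
enum { EPR = 8, ARR = 4, STEP = EPR * ARR }; // lanes per register, registers, elems per iter

static void axpy_blocked(int n, float * y, const float * x, float v) {
    const int np = n & ~(STEP - 1);          // largest multiple of STEP <= n
    for (int i = 0; i < np; i += STEP) {
        for (int j = 0; j < ARR; j++) {      // ARR independent registers in the real kernel
            for (int k = 0; k < EPR; k++) {
                y[i + j*EPR + k] += x[i + j*EPR + k] * v;
            }
        }
    }
    for (int i = np; i < n; ++i) {           // scalar leftovers
        y[i] += x[i] * v;
    }
}
```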
 
@@ -1746,6 +1909,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
1746
1909
  #endif
1747
1910
  }
1748
1911
 
1912
+ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
1913
+ #if defined(GGML_SIMD)
1914
+ const int np = (n & ~(GGML_F16_STEP - 1));
1915
+
1916
+ GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
1917
+
1918
+ GGML_F16_VEC ay[GGML_F16_ARR];
1919
+
1920
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
1921
+ for (int j = 0; j < GGML_F16_ARR; j++) {
1922
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
1923
+ ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
1924
+
1925
+ GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
1926
+ }
1927
+ }
1928
+
1929
+ // leftovers
1930
+ for (int i = np; i < n; ++i) {
1931
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
1932
+ }
1933
+ #else
1934
+ // scalar
1935
+ for (int i = 0; i < n; ++i) {
1936
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
1937
+ }
1938
+ #endif
1939
+ }
1940
+
1749
1941
  inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
1750
1942
  inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
1751
1943
  inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
@@ -1907,6 +2099,14 @@ inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_
1907
2099
  *s = sum;
1908
2100
  }
1909
2101
 
2102
+ inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {
2103
+ float sum = 0.0f;
2104
+ for (int i = 0; i < n; ++i) {
2105
+ sum += GGML_BF16_TO_FP32(x[i]);
2106
+ }
2107
+ *s = sum;
2108
+ }
2109
+
1910
2110
  inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
1911
2111
  #ifndef GGML_USE_ACCELERATE
1912
2112
  float max = -INFINITY;
@@ -2000,6 +2200,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2000
2200
  "LEAKY_RELU",
2001
2201
 
2002
2202
  "FLASH_ATTN",
2203
+ "FLASH_ATTN_EXT",
2003
2204
  "FLASH_FF",
2004
2205
  "FLASH_ATTN_BACK",
2005
2206
  "SSM_CONV",
@@ -2026,7 +2227,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2026
2227
  "CROSS_ENTROPY_LOSS_BACK",
2027
2228
  };
2028
2229
 
2029
- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
2230
+ static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
2030
2231
 
2031
2232
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2032
2233
  "none",
@@ -2090,6 +2291,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2090
2291
  "leaky_relu(x)",
2091
2292
 
2092
2293
  "flash_attn(x)",
2294
+ "flash_attn_ext(x)",
2093
2295
  "flash_ff(x)",
2094
2296
  "flash_attn_back(x)",
2095
2297
  "ssm_conv(x)",
@@ -2116,7 +2318,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2116
2318
  "cross_entropy_loss_back(x,y)",
2117
2319
  };
2118
2320
 
2119
- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
2321
+ static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
2120
2322
 
2121
2323
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
2122
2324
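Adding GGML_OP_FLASH_ATTN_EXT means bumping both name tables and both static_asserts in lockstep. A hypothetical alternative (not what ggml does) that makes a mismatched entry harder to introduce is indexing the table by enum value with designated initializers, the same style the type_traits table above already uses:

```c
// Hypothetical: designated initializers tie each string to its enum value, so
// a reordered or missing entry cannot silently shift every later name.
static const char * OP_NAME_ALT[GGML_OP_COUNT] = {
    [GGML_OP_FLASH_ATTN]     = "FLASH_ATTN",
    [GGML_OP_FLASH_ATTN_EXT] = "FLASH_ATTN_EXT",
    // ... one entry per ggml_op value; gaps stay NULL and can be checked at startup
};
```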
 
@@ -2315,7 +2517,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
2315
2517
  // figure out which node we're on
2316
2518
  uint current_cpu;
2317
2519
  int getcpu_ret = 0;
2318
- #if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28)
2520
+ #if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__)
2319
2521
  getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node);
2320
2522
  #else
2321
2523
  // old glibc doesn't have a wrapper for this call. Fall back on direct syscall
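The widened guard also takes the getcpu() wrapper path on Cosmopolitan libc; everything older falls through to the raw syscall below. A self-contained sketch of the same detection logic, with error handling elided and the helper name made up:

```c
#define _GNU_SOURCE
#include <sched.h>         // getcpu() wrapper on newer glibc / Cosmopolitan
#include <unistd.h>
#include <sys/syscall.h>

// Return the CPU and NUMA node the calling thread is currently running on.
static int current_cpu_and_node(unsigned * cpu, unsigned * node) {
#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__)
    return getcpu(cpu, node);
#else
    return (int) syscall(SYS_getcpu, cpu, node, NULL); // no wrapper on old glibc
#endif
}
```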
@@ -2526,6 +2728,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
2526
2728
  switch (ftype) {
2527
2729
  case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break;
2528
2730
  case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break;
2731
+ case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break;
2529
2732
  case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break;
2530
2733
  case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break;
2531
2734
  case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
@@ -2667,15 +2870,16 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
2667
2870
  {
2668
2871
  const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
2669
2872
 
2670
- ggml_fp16_t ii;
2671
2873
  for (int i = 0; i < (1 << 16); ++i) {
2672
- uint16_t ui = i;
2673
- memcpy(&ii, &ui, sizeof(ii));
2674
- const float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
2874
+ union {
2875
+ uint16_t u16;
2876
+ ggml_fp16_t fp16;
2877
+ } u = {i};
2878
+ float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
2675
2879
  ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
2676
2880
  ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
2677
2881
  ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
2678
- ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
2882
+ ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
2679
2883
  }
2680
2884
 
2681
2885
  const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
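The rewritten loop walks all 2^16 half-precision bit patterns, reinterprets each through a union instead of the previous memcpy, and fills the precomputed f32 / GELU / SiLU / exp lookup tables. The same punning technique in isolation, with a hypothetical table name:

```c
#include <stdint.h>
#include <math.h>
#include "ggml.h"

// Hypothetical 64K-entry table: expf() of every possible fp16 value.
static float exp_of_every_fp16[1 << 16];

static void build_exp_table(void) {
    for (int i = 0; i < (1 << 16); ++i) {
        union {
            uint16_t    u16;
            ggml_fp16_t fp16;
        } u = { (uint16_t) i };                 // reinterpret the raw bits as fp16
        exp_of_every_fp16[i] = expf(ggml_fp16_to_fp32(u.fp16));
    }
}
```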
@@ -3139,6 +3343,13 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
3139
3343
  ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
3140
3344
  }
3141
3345
  } break;
3346
+ case GGML_TYPE_BF16:
3347
+ {
3348
+ assert(tensor->nb[0] == sizeof(ggml_fp16_t));
3349
+ for (int i = 0; i < n; i++) {
3350
+ ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
3351
+ }
3352
+ } break;
3142
3353
  case GGML_TYPE_F32:
3143
3354
  {
3144
3355
  assert(tensor->nb[0] == sizeof(float));
@@ -3191,6 +3402,13 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3191
3402
  ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
3192
3403
  }
3193
3404
  } break;
3405
+ case GGML_TYPE_BF16:
3406
+ {
3407
+ assert(tensor->nb[0] == sizeof(ggml_bf16_t));
3408
+ for (int i = 0; i < n; i++) {
3409
+ ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
3410
+ }
3411
+ } break;
3194
3412
  case GGML_TYPE_F32:
3195
3413
  {
3196
3414
  assert(tensor->nb[0] == sizeof(float));
@@ -3258,6 +3476,11 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3258
3476
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3259
3477
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
3260
3478
  }
3479
+ case GGML_TYPE_BF16:
3480
+ {
3481
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3482
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
3483
+ }
3261
3484
  case GGML_TYPE_F32:
3262
3485
  {
3263
3486
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3300,6 +3523,11 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3300
3523
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3301
3524
  ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
3302
3525
  } break;
3526
+ case GGML_TYPE_BF16:
3527
+ {
3528
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3529
+ ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
3530
+ } break;
3303
3531
  case GGML_TYPE_F32:
3304
3532
  {
3305
3533
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3323,6 +3551,8 @@ int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i
3323
3551
  return ((int32_t *) data)[0];
3324
3552
  case GGML_TYPE_F16:
3325
3553
  return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
3554
+ case GGML_TYPE_BF16:
3555
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
3326
3556
  case GGML_TYPE_F32:
3327
3557
  return ((float *) data)[0];
3328
3558
  default:
@@ -3351,6 +3581,10 @@ void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3351
3581
  {
3352
3582
  ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
3353
3583
  } break;
3584
+ case GGML_TYPE_BF16:
3585
+ {
3586
+ ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
3587
+ } break;
3354
3588
  case GGML_TYPE_F32:
3355
3589
  {
3356
3590
  ((float *)(data))[0] = value;
@@ -3389,6 +3623,11 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3389
3623
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3390
3624
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
3391
3625
  }
3626
+ case GGML_TYPE_BF16:
3627
+ {
3628
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3629
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
3630
+ }
3392
3631
  case GGML_TYPE_F32:
3393
3632
  {
3394
3633
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3431,6 +3670,11 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
3431
3670
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3432
3671
  ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
3433
3672
  } break;
3673
+ case GGML_TYPE_BF16:
3674
+ {
3675
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3676
+ ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
3677
+ } break;
3434
3678
  case GGML_TYPE_F32:
3435
3679
  {
3436
3680
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3454,6 +3698,8 @@ float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3454
3698
  return ((int32_t *) data)[0];
3455
3699
  case GGML_TYPE_F16:
3456
3700
  return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
3701
+ case GGML_TYPE_BF16:
3702
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
3457
3703
  case GGML_TYPE_F32:
3458
3704
  return ((float *) data)[0];
3459
3705
  default:
@@ -3482,6 +3728,10 @@ void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3482
3728
  {
3483
3729
  ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
3484
3730
  } break;
3731
+ case GGML_TYPE_BF16:
3732
+ {
3733
+ ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
3734
+ } break;
3485
3735
  case GGML_TYPE_F32:
3486
3736
  {
3487
3737
  ((float *)(data))[0] = value;
@@ -3676,7 +3926,11 @@ static struct ggml_tensor * ggml_add_cast_impl(
3676
3926
  // TODO: support less-strict constraint
3677
3927
  // GGML_ASSERT(ggml_can_repeat(b, a));
3678
3928
  GGML_ASSERT(ggml_can_repeat_rows(b, a));
3679
- GGML_ASSERT(ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16); // currently only supported for quantized input and f16
3929
+
3930
+ // currently only supported for quantized input and f16
3931
+ GGML_ASSERT(ggml_is_quantized(a->type) ||
3932
+ a->type == GGML_TYPE_F16 ||
3933
+ a->type == GGML_TYPE_BF16);
3680
3934
 
3681
3935
  bool is_node = false;
3682
3936
 
@@ -4559,6 +4813,8 @@ struct ggml_tensor * ggml_mul_mat(
4559
4813
  void ggml_mul_mat_set_prec(
4560
4814
  struct ggml_tensor * a,
4561
4815
  enum ggml_prec prec) {
4816
+ GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
4817
+
4562
4818
  const int32_t prec_i32 = (int32_t) prec;
4563
4819
 
4564
4820
  ggml_set_op_params_i32(a, 0, prec_i32);
@@ -5397,17 +5653,23 @@ static struct ggml_tensor * ggml_soft_max_impl(
5397
5653
  GGML_ASSERT(ggml_is_contiguous(a));
5398
5654
 
5399
5655
  if (mask) {
5656
+ GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
5400
5657
  GGML_ASSERT(ggml_is_contiguous(mask));
5401
5658
  GGML_ASSERT(ggml_is_matrix(mask));
5402
- GGML_ASSERT(ggml_can_repeat_rows(mask, a));
5659
+ GGML_ASSERT(mask->ne[0] == a->ne[0]);
5660
+ GGML_ASSERT(mask->ne[1] >= a->ne[1]);
5403
5661
  }
5404
5662
 
5405
5663
  if (pos) {
5406
5664
  GGML_ASSERT(ggml_is_vector(pos));
5407
- GGML_ASSERT(pos->type == GGML_TYPE_F32);
5665
+ GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
5408
5666
  GGML_ASSERT(pos->ne[0] == a->ne[0]);
5409
5667
  }
5410
5668
 
5669
+ if (pos && mask) {
5670
+ GGML_ASSERT(pos->type == mask->type);
5671
+ }
5672
+
5411
5673
  if (max_bias > 0.0f) {
5412
5674
  GGML_ASSERT(pos);
5413
5675
  }
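With the relaxed asserts, the mask may be f16 or f32, only has to match `a` in its first dimension while covering at least a->ne[1] rows, and must share a type with `pos` when both are present. As an interpretation of what the soft-max op above computes per row (not the kernel itself, and with the ALiBi slope handling simplified):

```c
#include <math.h>

// Per-row reading of the masked soft-max: out = softmax(scale*a + mask + slope*pos).
static void soft_max_row_ref(int n, float * out, const float * a,
                             const float * mask, const float * pos,
                             float scale, float slope) {
    float max = -INFINITY;
    for (int i = 0; i < n; i++) {
        out[i] = scale * a[i]
               + (mask ? mask[i]        : 0.0f)
               + (pos  ? slope * pos[i] : 0.0f);
        if (out[i] > max) max = out[i];
    }
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        out[i] = expf(out[i] - max);  // subtract the max for numerical stability
        sum += out[i];
    }
    for (int i = 0; i < n; i++) {
        out[i] /= sum;
    }
}
```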
@@ -6216,6 +6478,59 @@ struct ggml_tensor * ggml_flash_attn(
6216
6478
  return result;
6217
6479
  }
6218
6480
 
6481
+ // ggml_flash_attn_ext
6482
+
6483
+ struct ggml_tensor * ggml_flash_attn_ext(
6484
+ struct ggml_context * ctx,
6485
+ struct ggml_tensor * q,
6486
+ struct ggml_tensor * k,
6487
+ struct ggml_tensor * v,
6488
+ struct ggml_tensor * mask,
6489
+ float scale) {
6490
+ GGML_ASSERT(ggml_can_mul_mat(k, q));
6491
+ // TODO: check if vT can be multiplied by (k*qT)
6492
+ if (mask) {
6493
+ GGML_ASSERT(ggml_is_contiguous(mask));
6494
+ GGML_ASSERT(mask->ne[2] == 1);
6495
+ GGML_ASSERT(mask->ne[3] == 1);
6496
+ GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
6497
+ "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
6498
+ //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
6499
+ }
6500
+
6501
+ bool is_node = false;
6502
+
6503
+ if (q->grad || k->grad || v->grad) {
6504
+ is_node = true;
6505
+ }
6506
+
6507
+ // permute(0, 2, 1, 3)
6508
+ int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
6509
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
6510
+
6511
+ float params[] = { scale };
6512
+ ggml_set_op_params(result, params, sizeof(params));
6513
+
6514
+ result->op = GGML_OP_FLASH_ATTN_EXT;
6515
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6516
+ result->src[0] = q;
6517
+ result->src[1] = k;
6518
+ result->src[2] = v;
6519
+ result->src[3] = mask;
6520
+
6521
+ return result;
6522
+ }
6523
+
6524
+ void ggml_flash_attn_ext_set_prec(
6525
+ struct ggml_tensor * a,
6526
+ enum ggml_prec prec) {
6527
+ GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
6528
+
6529
+ const int32_t prec_i32 = (int32_t) prec;
6530
+
6531
+ ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
6532
+ }
6533
+
6219
6534
  // ggml_flash_ff
6220
6535
 
6221
6536
  struct ggml_tensor * ggml_flash_ff(
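A hedged usage sketch of the new operator. The call signature and the set_prec helper are taken from the code above; the dimension names in the comments and the scaling by 1/sqrt(head_dim) are assumptions for illustration:

```c
#include <math.h>
#include "ggml.h"

struct ggml_tensor * attn_sketch(struct ggml_context * ctx,
                                 struct ggml_tensor  * q,     // assumed [head_dim, n_q,  n_head, n_batch]
                                 struct ggml_tensor  * k,     // assumed [head_dim, n_kv, n_head, n_batch]
                                 struct ggml_tensor  * v,     // assumed [head_dim, n_kv, n_head, n_batch]
                                 struct ggml_tensor  * mask,  // [n_kv, n_q padded to GGML_KQ_MASK_PAD]
                                 int                   head_dim) {
    const float scale = 1.0f / sqrtf((float) head_dim);

    struct ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, scale);

    // optionally request f32 accumulation, mirroring ggml_mul_mat_set_prec
    ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);

    return out; // permuted result: [q->ne[0], q->ne[2], q->ne[1], q->ne[3]]
}
```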
@@ -7092,8 +7407,8 @@ static void ggml_compute_forward_dup_same_cont(
7092
7407
  ((char *) src0->data + ie0*nb00),
7093
7408
  (ie1 - ie0) * ggml_type_size(src0->type));
7094
7409
  }
7095
-
7096
7410
  }
7411
+
7097
7412
  static void ggml_compute_forward_dup_f16(
7098
7413
  const struct ggml_compute_params * params,
7099
7414
  struct ggml_tensor * dst) {
@@ -7367,7 +7682,7 @@ static void ggml_compute_forward_dup_f16(
7367
7682
  }
7368
7683
  }
7369
7684
 
7370
- static void ggml_compute_forward_dup_f32(
7685
+ static void ggml_compute_forward_dup_bf16(
7371
7686
  const struct ggml_compute_params * params,
7372
7687
  struct ggml_tensor * dst) {
7373
7688
 
@@ -7415,10 +7730,11 @@ static void ggml_compute_forward_dup_f32(
7415
7730
  return;
7416
7731
  }
7417
7732
 
7733
+ // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
7734
+
7418
7735
  if (ggml_is_contiguous(dst)) {
7419
- // TODO: simplify
7420
- if (nb00 == sizeof(float)) {
7421
- if (dst->type == GGML_TYPE_F32) {
7736
+ if (nb00 == sizeof(ggml_bf16_t)) {
7737
+ if (dst->type == GGML_TYPE_BF16) {
7422
7738
  size_t id = 0;
7423
7739
  const size_t rs = ne00 * nb00;
7424
7740
  char * dst_ptr = (char *) dst->data;
@@ -7434,8 +7750,43 @@ static void ggml_compute_forward_dup_f32(
7434
7750
  id += rs * (ne01 - ir1);
7435
7751
  }
7436
7752
  }
7753
+ } else if (dst->type == GGML_TYPE_F16) {
7754
+ size_t id = 0;
7755
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
7756
+
7757
+ for (int i03 = 0; i03 < ne03; i03++) {
7758
+ for (int i02 = 0; i02 < ne02; i02++) {
7759
+ id += ne00 * ir0;
7760
+ for (int i01 = ir0; i01 < ir1; i01++) {
7761
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7762
+ for (int i00 = 0; i00 < ne00; i00++) {
7763
+ dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
7764
+ id++;
7765
+ }
7766
+ }
7767
+ id += ne00 * (ne01 - ir1);
7768
+ }
7769
+ }
7770
+ } else if (dst->type == GGML_TYPE_F32) {
7771
+ size_t id = 0;
7772
+ float * dst_ptr = (float *) dst->data;
7773
+
7774
+ for (int i03 = 0; i03 < ne03; i03++) {
7775
+ for (int i02 = 0; i02 < ne02; i02++) {
7776
+ id += ne00 * ir0;
7777
+ for (int i01 = ir0; i01 < ir1; i01++) {
7778
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7779
+ for (int i00 = 0; i00 < ne00; i00++) {
7780
+ dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
7781
+ id++;
7782
+ }
7783
+ }
7784
+ id += ne00 * (ne01 - ir1);
7785
+ }
7786
+ }
7437
7787
  } else if (type_traits[dst->type].from_float) {
7438
7788
  ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
7789
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
7439
7790
 
7440
7791
  size_t id = 0;
7441
7792
  size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
@@ -7445,8 +7796,13 @@ static void ggml_compute_forward_dup_f32(
7445
7796
  for (int i02 = 0; i02 < ne02; i02++) {
7446
7797
  id += rs * ir0;
7447
7798
  for (int i01 = ir0; i01 < ir1; i01++) {
7448
- const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7449
- quantize_row_q(src0_ptr, dst_ptr + id, ne00);
7799
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7800
+
7801
+ for (int i00 = 0; i00 < ne00; i00++) {
7802
+ src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
7803
+ }
7804
+
7805
+ quantize_row_q(src0_f32, dst_ptr + id, ne00);
7450
7806
  id += rs;
7451
7807
  }
7452
7808
  id += rs * (ne01 - ir1);
@@ -7467,7 +7823,25 @@ static void ggml_compute_forward_dup_f32(
7467
7823
  id += ne00 * ir0;
7468
7824
  for (int i01 = ir0; i01 < ir1; i01++) {
7469
7825
  for (int i00 = 0; i00 < ne00; i00++) {
7470
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7826
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7827
+
7828
+ dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
7829
+ id++;
7830
+ }
7831
+ }
7832
+ id += ne00 * (ne01 - ir1);
7833
+ }
7834
+ }
7835
+ } else if (dst->type == GGML_TYPE_BF16) {
7836
+ size_t id = 0;
7837
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
7838
+
7839
+ for (int i03 = 0; i03 < ne03; i03++) {
7840
+ for (int i02 = 0; i02 < ne02; i02++) {
7841
+ id += ne00 * ir0;
7842
+ for (int i01 = ir0; i01 < ir1; i01++) {
7843
+ for (int i00 = 0; i00 < ne00; i00++) {
7844
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7471
7845
 
7472
7846
  dst_ptr[id] = *src0_ptr;
7473
7847
  id++;
@@ -7485,9 +7859,9 @@ static void ggml_compute_forward_dup_f32(
7485
7859
  id += ne00 * ir0;
7486
7860
  for (int i01 = ir0; i01 < ir1; i01++) {
7487
7861
  for (int i00 = 0; i00 < ne00; i00++) {
7488
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7862
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7489
7863
 
7490
- dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
7864
+ dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
7491
7865
  id++;
7492
7866
  }
7493
7867
  }
@@ -7498,18 +7872,16 @@ static void ggml_compute_forward_dup_f32(
7498
7872
  GGML_ASSERT(false); // TODO: implement
7499
7873
  }
7500
7874
  }
7501
-
7502
7875
  return;
7503
7876
  }
7504
7877
 
7505
7878
  // dst counters
7506
-
7507
7879
  int64_t i10 = 0;
7508
7880
  int64_t i11 = 0;
7509
7881
  int64_t i12 = 0;
7510
7882
  int64_t i13 = 0;
7511
7883
 
7512
- if (dst->type == GGML_TYPE_F32) {
7884
+ if (dst->type == GGML_TYPE_BF16) {
7513
7885
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7514
7886
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7515
7887
  i10 += ne00 * ir0;
@@ -7530,15 +7902,15 @@ static void ggml_compute_forward_dup_f32(
7530
7902
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7531
7903
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
7532
7904
 
7533
- memcpy(dst_ptr, src0_ptr, sizeof(float));
7905
+ memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
7534
7906
 
7535
- if (++i10 == ne0) {
7907
+ if (++i10 == ne00) {
7536
7908
  i10 = 0;
7537
- if (++i11 == ne1) {
7909
+ if (++i11 == ne01) {
7538
7910
  i11 = 0;
7539
- if (++i12 == ne2) {
7911
+ if (++i12 == ne02) {
7540
7912
  i12 = 0;
7541
- if (++i13 == ne3) {
7913
+ if (++i13 == ne03) {
7542
7914
  i13 = 0;
7543
7915
  }
7544
7916
  }
@@ -7582,7 +7954,7 @@ static void ggml_compute_forward_dup_f32(
7582
7954
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7583
7955
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
7584
7956
 
7585
- *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
7957
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
7586
7958
 
7587
7959
  if (++i10 == ne0) {
7588
7960
  i10 = 0;
@@ -7613,10 +7985,383 @@ static void ggml_compute_forward_dup_f32(
7613
7985
  }
7614
7986
  }
7615
7987
  }
7616
- } else {
7617
- GGML_ASSERT(false); // TODO: implement
7618
- }
7619
- }
7988
+ } else if (dst->type == GGML_TYPE_F32) {
7989
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
7990
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
7991
+ i10 += ne00 * ir0;
7992
+ while (i10 >= ne0) {
7993
+ i10 -= ne0;
7994
+ if (++i11 == ne1) {
7995
+ i11 = 0;
7996
+ if (++i12 == ne2) {
7997
+ i12 = 0;
7998
+ if (++i13 == ne3) {
7999
+ i13 = 0;
8000
+ }
8001
+ }
8002
+ }
8003
+ }
8004
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8005
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8006
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8007
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8008
+
8009
+ *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
8010
+
8011
+ if (++i10 == ne0) {
8012
+ i10 = 0;
8013
+ if (++i11 == ne1) {
8014
+ i11 = 0;
8015
+ if (++i12 == ne2) {
8016
+ i12 = 0;
8017
+ if (++i13 == ne3) {
8018
+ i13 = 0;
8019
+ }
8020
+ }
8021
+ }
8022
+ }
8023
+ }
8024
+ }
8025
+ i10 += ne00 * (ne01 - ir1);
8026
+ while (i10 >= ne0) {
8027
+ i10 -= ne0;
8028
+ if (++i11 == ne1) {
8029
+ i11 = 0;
8030
+ if (++i12 == ne2) {
8031
+ i12 = 0;
8032
+ if (++i13 == ne3) {
8033
+ i13 = 0;
8034
+ }
8035
+ }
8036
+ }
8037
+ }
8038
+ }
8039
+ }
8040
+ } else {
8041
+ GGML_ASSERT(false); // TODO: implement
8042
+ }
8043
+ }
8044
+
8045
+ static void ggml_compute_forward_dup_f32(
8046
+ const struct ggml_compute_params * params,
8047
+ struct ggml_tensor * dst) {
8048
+
8049
+ const struct ggml_tensor * src0 = dst->src[0];
8050
+
8051
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
8052
+
8053
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
8054
+ return;
8055
+ }
8056
+
8057
+ GGML_TENSOR_UNARY_OP_LOCALS
8058
+
8059
+ const int ith = params->ith; // thread index
8060
+ const int nth = params->nth; // number of threads
8061
+
8062
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
8063
+ ggml_compute_forward_dup_same_cont(params, dst);
8064
+ return;
8065
+ }
8066
+
8067
+ // parallelize by rows
8068
+ const int nr = ne01;
8069
+ // number of rows per thread
8070
+ const int dr = (nr + nth - 1) / nth;
8071
+ // row range for this thread
8072
+ const int ir0 = dr * ith;
8073
+ const int ir1 = MIN(ir0 + dr, nr);
8074
+
8075
+ if (src0->type == dst->type &&
8076
+ ne00 == ne0 &&
8077
+ nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
8078
+ // copy by rows
8079
+ const size_t rs = ne00*nb00;
8080
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8081
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8082
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8083
+ memcpy(
8084
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
8085
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
8086
+ rs);
8087
+ }
8088
+ }
8089
+ }
8090
+ return;
8091
+ }
8092
+
8093
+ if (ggml_is_contiguous(dst)) {
8094
+ // TODO: simplify
8095
+ if (nb00 == sizeof(float)) {
8096
+ if (dst->type == GGML_TYPE_F32) {
8097
+ size_t id = 0;
8098
+ const size_t rs = ne00 * nb00;
8099
+ char * dst_ptr = (char *) dst->data;
8100
+
8101
+ for (int i03 = 0; i03 < ne03; i03++) {
8102
+ for (int i02 = 0; i02 < ne02; i02++) {
8103
+ id += rs * ir0;
8104
+ for (int i01 = ir0; i01 < ir1; i01++) {
8105
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
8106
+ memcpy(dst_ptr + id, src0_ptr, rs);
8107
+ id += rs;
8108
+ }
8109
+ id += rs * (ne01 - ir1);
8110
+ }
8111
+ }
8112
+ } else if (type_traits[dst->type].from_float) {
8113
+ ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
8114
+
8115
+ size_t id = 0;
8116
+ size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
8117
+ char * dst_ptr = (char *) dst->data;
8118
+
8119
+ for (int i03 = 0; i03 < ne03; i03++) {
8120
+ for (int i02 = 0; i02 < ne02; i02++) {
8121
+ id += rs * ir0;
8122
+ for (int i01 = ir0; i01 < ir1; i01++) {
8123
+ const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
8124
+ quantize_row_q(src0_ptr, dst_ptr + id, ne00);
8125
+ id += rs;
8126
+ }
8127
+ id += rs * (ne01 - ir1);
8128
+ }
8129
+ }
8130
+ } else {
8131
+ GGML_ASSERT(false); // TODO: implement
8132
+ }
8133
+ } else {
8134
+ //printf("%s: this is not optimal - fix me\n", __func__);
8135
+
8136
+ if (dst->type == GGML_TYPE_F32) {
8137
+ size_t id = 0;
8138
+ float * dst_ptr = (float *) dst->data;
8139
+
8140
+ for (int i03 = 0; i03 < ne03; i03++) {
8141
+ for (int i02 = 0; i02 < ne02; i02++) {
8142
+ id += ne00 * ir0;
8143
+ for (int i01 = ir0; i01 < ir1; i01++) {
8144
+ for (int i00 = 0; i00 < ne00; i00++) {
8145
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8146
+
8147
+ dst_ptr[id] = *src0_ptr;
8148
+ id++;
8149
+ }
8150
+ }
8151
+ id += ne00 * (ne01 - ir1);
8152
+ }
8153
+ }
8154
+ } else if (dst->type == GGML_TYPE_F16) {
8155
+ size_t id = 0;
8156
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
8157
+
8158
+ for (int i03 = 0; i03 < ne03; i03++) {
8159
+ for (int i02 = 0; i02 < ne02; i02++) {
8160
+ id += ne00 * ir0;
8161
+ for (int i01 = ir0; i01 < ir1; i01++) {
8162
+ for (int i00 = 0; i00 < ne00; i00++) {
8163
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8164
+
8165
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
8166
+ id++;
8167
+ }
8168
+ }
8169
+ id += ne00 * (ne01 - ir1);
8170
+ }
8171
+ }
8172
+ } else if (dst->type == GGML_TYPE_BF16) {
8173
+ size_t id = 0;
8174
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
8175
+
8176
+ for (int i03 = 0; i03 < ne03; i03++) {
8177
+ for (int i02 = 0; i02 < ne02; i02++) {
8178
+ id += ne00 * ir0;
8179
+ for (int i01 = ir0; i01 < ir1; i01++) {
8180
+ for (int i00 = 0; i00 < ne00; i00++) {
8181
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8182
+
8183
+ dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
8184
+ id++;
8185
+ }
8186
+ }
8187
+ id += ne00 * (ne01 - ir1);
8188
+ }
8189
+ }
8190
+ } else {
8191
+ GGML_ASSERT(false); // TODO: implement
8192
+ }
8193
+ }
8194
+
8195
+ return;
8196
+ }
8197
+
8198
+ // dst counters
8199
+
8200
+ int64_t i10 = 0;
8201
+ int64_t i11 = 0;
8202
+ int64_t i12 = 0;
8203
+ int64_t i13 = 0;
8204
+
8205
+ if (dst->type == GGML_TYPE_F32) {
8206
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8207
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8208
+ i10 += ne00 * ir0;
8209
+ while (i10 >= ne0) {
8210
+ i10 -= ne0;
8211
+ if (++i11 == ne1) {
8212
+ i11 = 0;
8213
+ if (++i12 == ne2) {
8214
+ i12 = 0;
8215
+ if (++i13 == ne3) {
8216
+ i13 = 0;
8217
+ }
8218
+ }
8219
+ }
8220
+ }
8221
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8222
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8223
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8224
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8225
+
8226
+ memcpy(dst_ptr, src0_ptr, sizeof(float));
8227
+
8228
+ if (++i10 == ne0) {
8229
+ i10 = 0;
8230
+ if (++i11 == ne1) {
8231
+ i11 = 0;
8232
+ if (++i12 == ne2) {
8233
+ i12 = 0;
8234
+ if (++i13 == ne3) {
8235
+ i13 = 0;
8236
+ }
8237
+ }
8238
+ }
8239
+ }
8240
+ }
8241
+ }
8242
+ i10 += ne00 * (ne01 - ir1);
8243
+ while (i10 >= ne0) {
8244
+ i10 -= ne0;
8245
+ if (++i11 == ne1) {
8246
+ i11 = 0;
8247
+ if (++i12 == ne2) {
8248
+ i12 = 0;
8249
+ if (++i13 == ne3) {
8250
+ i13 = 0;
8251
+ }
8252
+ }
8253
+ }
8254
+ }
8255
+ }
8256
+ }
8257
+ } else if (dst->type == GGML_TYPE_F16) {
8258
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8259
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8260
+ i10 += ne00 * ir0;
8261
+ while (i10 >= ne0) {
8262
+ i10 -= ne0;
8263
+ if (++i11 == ne1) {
8264
+ i11 = 0;
8265
+ if (++i12 == ne2) {
8266
+ i12 = 0;
8267
+ if (++i13 == ne3) {
8268
+ i13 = 0;
8269
+ }
8270
+ }
8271
+ }
8272
+ }
8273
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8274
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8275
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8276
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8277
+
8278
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
8279
+
8280
+ if (++i10 == ne0) {
8281
+ i10 = 0;
8282
+ if (++i11 == ne1) {
8283
+ i11 = 0;
8284
+ if (++i12 == ne2) {
8285
+ i12 = 0;
8286
+ if (++i13 == ne3) {
8287
+ i13 = 0;
8288
+ }
8289
+ }
8290
+ }
8291
+ }
8292
+ }
8293
+ }
8294
+ i10 += ne00 * (ne01 - ir1);
8295
+ while (i10 >= ne0) {
8296
+ i10 -= ne0;
8297
+ if (++i11 == ne1) {
8298
+ i11 = 0;
8299
+ if (++i12 == ne2) {
8300
+ i12 = 0;
8301
+ if (++i13 == ne3) {
8302
+ i13 = 0;
8303
+ }
8304
+ }
8305
+ }
8306
+ }
8307
+ }
8308
+ }
8309
+ } else if (dst->type == GGML_TYPE_BF16) {
8310
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8311
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8312
+ i10 += ne00 * ir0;
8313
+ while (i10 >= ne0) {
8314
+ i10 -= ne0;
8315
+ if (++i11 == ne1) {
8316
+ i11 = 0;
8317
+ if (++i12 == ne2) {
8318
+ i12 = 0;
8319
+ if (++i13 == ne3) {
8320
+ i13 = 0;
8321
+ }
8322
+ }
8323
+ }
8324
+ }
8325
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8326
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8327
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8328
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8329
+
8330
+ *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
8331
+
8332
+ if (++i10 == ne0) {
8333
+ i10 = 0;
8334
+ if (++i11 == ne1) {
8335
+ i11 = 0;
8336
+ if (++i12 == ne2) {
8337
+ i12 = 0;
8338
+ if (++i13 == ne3) {
8339
+ i13 = 0;
8340
+ }
8341
+ }
8342
+ }
8343
+ }
8344
+ }
8345
+ }
8346
+ i10 += ne00 * (ne01 - ir1);
8347
+ while (i10 >= ne0) {
8348
+ i10 -= ne0;
8349
+ if (++i11 == ne1) {
8350
+ i11 = 0;
8351
+ if (++i12 == ne2) {
8352
+ i12 = 0;
8353
+ if (++i13 == ne3) {
8354
+ i13 = 0;
8355
+ }
8356
+ }
8357
+ }
8358
+ }
8359
+ }
8360
+ }
8361
+ } else {
8362
+ GGML_ASSERT(false); // TODO: implement
8363
+ }
8364
+ }
7620
8365
 
7621
8366
  // A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
7622
8367
  static void ggml_compute_forward_dup_bytes(
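All of these forward kernels split work the same way: rows are divided across params->nth threads and thread ith handles the half-open range [ir0, ir1). The arithmetic in isolation, as a small helper with a made-up name:

```c
// Row partitioning used throughout the forward kernels (dup, add, add1, ...):
// thread `ith` of `nth` gets rows [ir0, ir1) out of nr total rows.
static void thread_row_range(int nr, int ith, int nth, int * ir0, int * ir1) {
    const int dr = (nr + nth - 1) / nth;            // rows per thread, rounded up
    *ir0 = dr * ith;
    *ir1 = (*ir0 + dr < nr) ? (*ir0 + dr) : nr;     // MIN(ir0 + dr, nr)
}
```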
@@ -7786,6 +8531,10 @@ static void ggml_compute_forward_dup(
7786
8531
  {
7787
8532
  ggml_compute_forward_dup_f16(params, dst);
7788
8533
  } break;
8534
+ case GGML_TYPE_BF16:
8535
+ {
8536
+ ggml_compute_forward_dup_bf16(params, dst);
8537
+ } break;
7789
8538
  case GGML_TYPE_F32:
7790
8539
  {
7791
8540
  ggml_compute_forward_dup_f32(params, dst);
@@ -7968,6 +8717,85 @@ static void ggml_compute_forward_add_f16_f32(
7968
8717
  }
7969
8718
  }
7970
8719
 
8720
+ static void ggml_compute_forward_add_bf16_f32(
8721
+ const struct ggml_compute_params * params,
8722
+ struct ggml_tensor * dst) {
8723
+
8724
+ const struct ggml_tensor * src0 = dst->src[0];
8725
+ const struct ggml_tensor * src1 = dst->src[1];
8726
+
8727
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
8728
+
8729
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
8730
+ return;
8731
+ }
8732
+
8733
+ const int ith = params->ith;
8734
+ const int nth = params->nth;
8735
+
8736
+ const int nr = ggml_nrows(src0);
8737
+
8738
+ GGML_TENSOR_BINARY_OP_LOCALS
8739
+
8740
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
8741
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
8742
+
8743
+ if (dst->type == GGML_TYPE_F32) {
8744
+ GGML_ASSERT( nb0 == sizeof(float));
8745
+ }
8746
+ else {
8747
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
8748
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
8749
+ }
8750
+
8751
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
8752
+
8753
+ // rows per thread
8754
+ const int dr = (nr + nth - 1)/nth;
8755
+
8756
+ // row range for this thread
8757
+ const int ir0 = dr*ith;
8758
+ const int ir1 = MIN(ir0 + dr, nr);
8759
+
8760
+ if (nb10 == sizeof(float)) {
8761
+ if (dst->type == GGML_TYPE_BF16) {
8762
+ for (int ir = ir0; ir < ir1; ++ir) {
8763
+ // src0, src1 and dst are same shape => same indices
8764
+ const int i3 = ir/(ne2*ne1);
8765
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
8766
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8767
+
8768
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8769
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8770
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8771
+
8772
+ for (int i = 0; i < ne0; i++) {
8773
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
8774
+ }
8775
+ }
8776
+ } else {
8777
+ for (int ir = ir0; ir < ir1; ++ir) {
8778
+ // src0, src1 and dst are same shape => same indices
8779
+ const int i3 = ir/(ne2*ne1);
8780
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
8781
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8782
+
8783
+ float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8784
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8785
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8786
+
8787
+ for (int i = 0; i < ne0; i++) {
8788
+ dst_ptr[i] = GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
8789
+ }
8790
+ }
8791
+ }
8792
+ }
8793
+ else {
8794
+ // src1 is not contiguous
8795
+ GGML_ASSERT(false);
8796
+ }
8797
+ }
8798
+
7971
8799
  static void ggml_compute_forward_add_f16_f16(
7972
8800
  const struct ggml_compute_params * params,
7973
8801
  struct ggml_tensor * dst) {
@@ -8024,6 +8852,62 @@ static void ggml_compute_forward_add_f16_f16(
8024
8852
  }
8025
8853
  }
8026
8854
 
8855
+ static void ggml_compute_forward_add_bf16_bf16(
8856
+ const struct ggml_compute_params * params,
8857
+ struct ggml_tensor * dst) {
8858
+
8859
+ const struct ggml_tensor * src0 = dst->src[0];
8860
+ const struct ggml_tensor * src1 = dst->src[1];
8861
+
8862
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
8863
+
8864
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
8865
+ return;
8866
+ }
8867
+
8868
+ const int ith = params->ith;
8869
+ const int nth = params->nth;
8870
+
8871
+ const int nr = ggml_nrows(src0);
8872
+
8873
+ GGML_TENSOR_BINARY_OP_LOCALS
8874
+
8875
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
8876
+ GGML_ASSERT(src1->type == GGML_TYPE_BF16);
8877
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
8878
+
8879
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
8880
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
8881
+
8882
+ // rows per thread
8883
+ const int dr = (nr + nth - 1)/nth;
8884
+
8885
+ // row range for this thread
8886
+ const int ir0 = dr*ith;
8887
+ const int ir1 = MIN(ir0 + dr, nr);
8888
+
8889
+ if (nb10 == sizeof(ggml_bf16_t)) {
8890
+ for (int ir = ir0; ir < ir1; ++ir) {
8891
+ // src0, src1 and dst are same shape => same indices
8892
+ const int i3 = ir/(ne2*ne1);
8893
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
8894
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8895
+
8896
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8897
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8898
+ ggml_bf16_t * src1_ptr = (ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8899
+
8900
+ for (int i = 0; i < ne0; i++) {
8901
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + GGML_BF16_TO_FP32(src1_ptr[i]));
8902
+ }
8903
+ }
8904
+ }
8905
+ else {
8906
+ // src1 is not contiguous
8907
+ GGML_ASSERT(false);
8908
+ }
8909
+ }
8910
+
8027
8911
  static void ggml_compute_forward_add_q_f32(
8028
8912
  const struct ggml_compute_params * params,
8029
8913
  struct ggml_tensor * dst) {
@@ -8133,6 +9017,18 @@ static void ggml_compute_forward_add(
8133
9017
  GGML_ASSERT(false);
8134
9018
  }
8135
9019
  } break;
9020
+ case GGML_TYPE_BF16:
9021
+ {
9022
+ if (src1->type == GGML_TYPE_BF16) {
9023
+ ggml_compute_forward_add_bf16_bf16(params, dst);
9024
+ }
9025
+ else if (src1->type == GGML_TYPE_F32) {
9026
+ ggml_compute_forward_add_bf16_f32(params, dst);
9027
+ }
9028
+ else {
9029
+ GGML_ASSERT(false);
9030
+ }
9031
+ } break;
8136
9032
  case GGML_TYPE_Q4_0:
8137
9033
  case GGML_TYPE_Q4_1:
8138
9034
  case GGML_TYPE_Q5_0:
@@ -8346,21 +9242,133 @@ static void ggml_compute_forward_add1_q_f32(
8346
9242
 
8347
9243
  GGML_TENSOR_UNARY_OP_LOCALS
8348
9244
 
8349
- const enum ggml_type type = src0->type;
8350
- ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
8351
- ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
8352
-
8353
- // we don't support permuted src0
8354
- GGML_ASSERT(nb00 == ggml_type_size(type));
8355
-
8356
- // dst cannot be transposed or permuted
8357
- GGML_ASSERT(nb0 <= nb1);
8358
- GGML_ASSERT(nb1 <= nb2);
8359
- GGML_ASSERT(nb2 <= nb3);
9245
+ const enum ggml_type type = src0->type;
9246
+ ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
9247
+ ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
9248
+
9249
+ // we don't support permuted src0
9250
+ GGML_ASSERT(nb00 == ggml_type_size(type));
9251
+
9252
+ // dst cannot be transposed or permuted
9253
+ GGML_ASSERT(nb0 <= nb1);
9254
+ GGML_ASSERT(nb1 <= nb2);
9255
+ GGML_ASSERT(nb2 <= nb3);
9256
+
9257
+ GGML_ASSERT(ggml_is_quantized(src0->type));
9258
+ GGML_ASSERT(dst->type == src0->type);
9259
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
9260
+
9261
+ // rows per thread
9262
+ const int dr = (nr + nth - 1)/nth;
9263
+
9264
+ // row range for this thread
9265
+ const int ir0 = dr*ith;
9266
+ const int ir1 = MIN(ir0 + dr, nr);
9267
+
9268
+ float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
9269
+
9270
+ for (int ir = ir0; ir < ir1; ++ir) {
9271
+ // src0 and dst are same shape => same indices
9272
+ const int i3 = ir/(ne2*ne1);
9273
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9274
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9275
+
9276
+ void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
9277
+ void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb0 ));
9278
+
9279
+ assert(ne0 % 32 == 0);
9280
+
9281
+ // unquantize row from src0 to temp buffer
9282
+ dequantize_row_q(src0_row, wdata, ne0);
9283
+ // add src1
9284
+ ggml_vec_acc1_f32(ne0, wdata, v);
9285
+ // quantize row to dst
9286
+ quantize_row_q(wdata, dst_row, ne0);
9287
+ }
9288
+ }
9289
+
9290
+ static void ggml_compute_forward_add1_bf16_f32(
9291
+ const struct ggml_compute_params * params,
9292
+ struct ggml_tensor * dst) {
9293
+
9294
+ const struct ggml_tensor * src0 = dst->src[0];
9295
+ const struct ggml_tensor * src1 = dst->src[1];
9296
+
9297
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
9298
+ GGML_ASSERT(ggml_is_scalar(src1));
9299
+
9300
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
9301
+ return;
9302
+ }
9303
+
9304
+ // scalar to add
9305
+ const float v = *(float *) src1->data;
9306
+
9307
+ const int ith = params->ith;
9308
+ const int nth = params->nth;
9309
+
9310
+ const int nr = ggml_nrows(src0);
9311
+
9312
+ GGML_TENSOR_UNARY_OP_LOCALS
9313
+
9314
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
9315
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
9316
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
9317
+
9318
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
9319
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
9320
+
9321
+ // rows per thread
9322
+ const int dr = (nr + nth - 1)/nth;
9323
+
9324
+ // row range for this thread
9325
+ const int ir0 = dr*ith;
9326
+ const int ir1 = MIN(ir0 + dr, nr);
9327
+
9328
+ for (int ir = ir0; ir < ir1; ++ir) {
9329
+ // src0 and dst are same shape => same indices
9330
+ const int i3 = ir/(ne2*ne1);
9331
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9332
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9333
+
9334
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
9335
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9336
+ for (int i = 0; i < ne0; i++) {
9337
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
9338
+ }
9339
+ }
9340
+ }
9341
+
9342
+ static void ggml_compute_forward_add1_bf16_bf16(
9343
+ const struct ggml_compute_params * params,
9344
+ struct ggml_tensor * dst) {
9345
+
9346
+ const struct ggml_tensor * src0 = dst->src[0];
9347
+ const struct ggml_tensor * src1 = dst->src[1];
9348
+
9349
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
9350
+ GGML_ASSERT(ggml_is_scalar(src1));
9351
+
9352
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
9353
+ return;
9354
+ }
9355
+
9356
+ // scalar to add
9357
+ const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
9358
+
9359
+ const int ith = params->ith;
9360
+ const int nth = params->nth;
9361
+
9362
+ const int nr = ggml_nrows(src0);
9363
+
9364
+ GGML_TENSOR_UNARY_OP_LOCALS
9365
+
9366
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
9367
+ GGML_ASSERT(src1->type == GGML_TYPE_BF16);
9368
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
8360
9369
 
8361
- GGML_ASSERT(ggml_is_quantized(src0->type));
8362
- GGML_ASSERT(dst->type == src0->type);
8363
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
9370
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
9371
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
8364
9372
 
8365
9373
  // rows per thread
8366
9374
  const int dr = (nr + nth - 1)/nth;
@@ -8369,25 +9377,17 @@ static void ggml_compute_forward_add1_q_f32(
8369
9377
  const int ir0 = dr*ith;
8370
9378
  const int ir1 = MIN(ir0 + dr, nr);
8371
9379
 
8372
- float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
8373
-
8374
9380
  for (int ir = ir0; ir < ir1; ++ir) {
8375
9381
  // src0 and dst are same shape => same indices
8376
9382
  const int i3 = ir/(ne2*ne1);
8377
9383
  const int i2 = (ir - i3*ne2*ne1)/ne1;
8378
9384
  const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8379
9385
 
8380
- void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
8381
- void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb0 ));
8382
-
8383
- assert(ne0 % 32 == 0);
8384
-
8385
- // unquantize row from src0 to temp buffer
8386
- dequantize_row_q(src0_row, wdata, ne0);
8387
- // add src1
8388
- ggml_vec_acc1_f32(ne0, wdata, v);
8389
- // quantize row to dst
8390
- quantize_row_q(wdata, dst_row, ne0);
9386
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
9387
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9388
+ for (int i = 0; i < ne0; i++) {
9389
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
9390
+ }
8391
9391
  }
8392
9392
  }
8393
9393
 
@@ -8415,6 +9415,18 @@ static void ggml_compute_forward_add1(
8415
9415
  GGML_ASSERT(false);
8416
9416
  }
8417
9417
  } break;
9418
+ case GGML_TYPE_BF16:
9419
+ {
9420
+ if (src1->type == GGML_TYPE_BF16) {
9421
+ ggml_compute_forward_add1_bf16_bf16(params, dst);
9422
+ }
9423
+ else if (src1->type == GGML_TYPE_F32) {
9424
+ ggml_compute_forward_add1_bf16_f32(params, dst);
9425
+ }
9426
+ else {
9427
+ GGML_ASSERT(false);
9428
+ }
9429
+ } break;
8418
9430
  case GGML_TYPE_Q4_0:
8419
9431
  case GGML_TYPE_Q4_1:
8420
9432
  case GGML_TYPE_Q5_0:
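
The add1 kernels above never do bf16 arithmetic directly: each element is widened to F32, the scalar is added, and the result is narrowed back. A minimal self-contained sketch of that round trip follows; bf16_t, bf16_to_f32 and f32_to_bf16 are illustrative stand-ins for ggml_bf16_t, GGML_BF16_TO_FP32 and GGML_FP32_TO_BF16, and the narrowing shown here simply truncates (ggml's macro may round instead).

    #include <stdint.h>
    #include <string.h>

    typedef struct { uint16_t bits; } bf16_t;       // stand-in for ggml_bf16_t

    static float bf16_to_f32(bf16_t h) {
        // bf16 is the high half of an IEEE-754 binary32, so widening is a shift
        uint32_t u = (uint32_t) h.bits << 16;
        float f;
        memcpy(&f, &u, sizeof(f));
        return f;
    }

    static bf16_t f32_to_bf16(float f) {
        // truncating narrow; ggml's GGML_FP32_TO_BF16 may round instead
        uint32_t u;
        memcpy(&u, &f, sizeof(u));
        return (bf16_t) { (uint16_t) (u >> 16) };
    }

    // per-row body of the add1 kernels above: x[i] = bf16(f32(x[i]) + v)
    static void add1_bf16_row(int64_t n, bf16_t * x, float v) {
        for (int64_t i = 0; i < n; ++i) {
            x[i] = f32_to_bf16(bf16_to_f32(x[i]) + v);
        }
    }
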
@@ -8543,6 +9555,7 @@ static void ggml_compute_forward_acc(
8543
9555
  ggml_compute_forward_acc_f32(params, dst);
8544
9556
  } break;
8545
9557
  case GGML_TYPE_F16:
9558
+ case GGML_TYPE_BF16:
8546
9559
  case GGML_TYPE_Q4_0:
8547
9560
  case GGML_TYPE_Q4_1:
8548
9561
  case GGML_TYPE_Q5_0:
@@ -9064,6 +10077,40 @@ static void ggml_compute_forward_sum_f16(
9064
10077
  ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
9065
10078
  }
9066
10079
 
10080
+ static void ggml_compute_forward_sum_bf16(
10081
+ const struct ggml_compute_params * params,
10082
+ struct ggml_tensor * dst) {
10083
+
10084
+ const struct ggml_tensor * src0 = dst->src[0];
10085
+
10086
+ assert(params->ith == 0);
10087
+ assert(ggml_is_scalar(dst));
10088
+
10089
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
10090
+ return;
10091
+ }
10092
+
10093
+ assert(src0->nb[0] == sizeof(ggml_bf16_t));
10094
+
10095
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
10096
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
10097
+
10098
+ float sum = 0;
10099
+ float row_sum = 0;
10100
+
10101
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
10102
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
10103
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
10104
+ ggml_vec_sum_bf16_ggf(ne00,
10105
+ &row_sum,
10106
+ (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
10107
+ sum += row_sum;
10108
+ }
10109
+ }
10110
+ }
10111
+ ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
10112
+ }
10113
+
9067
10114
  static void ggml_compute_forward_sum(
9068
10115
  const struct ggml_compute_params * params,
9069
10116
  struct ggml_tensor * dst) {
@@ -9079,6 +10126,10 @@ static void ggml_compute_forward_sum(
9079
10126
  {
9080
10127
  ggml_compute_forward_sum_f16(params, dst);
9081
10128
  } break;
10129
+ case GGML_TYPE_BF16:
10130
+ {
10131
+ ggml_compute_forward_sum_bf16(params, dst);
10132
+ } break;
9082
10133
  default:
9083
10134
  {
9084
10135
  GGML_ASSERT(false);
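
ggml_vec_sum_bf16_ggf is not part of this hunk; judging by the float row_sum it is handed, it presumably behaves like the sketch below (names here are illustrative), accumulating each row in F32 so that precision is dropped only once, when the kernel narrows the final scalar back to BF16.

    #include <stdint.h>
    #include <string.h>

    // sum a row of raw bf16 bits into a float accumulator
    static void vec_sum_bf16_row(int64_t n, float * s, const uint16_t * x) {
        float sum = 0.0f;
        for (int64_t i = 0; i < n; ++i) {
            uint32_t u = (uint32_t) x[i] << 16;     // bf16 -> f32 widen
            float f;
            memcpy(&f, &u, sizeof(f));
            sum += f;
        }
        *s = sum;
    }
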
@@ -9353,6 +10404,7 @@ static void ggml_compute_forward_repeat(
9353
10404
 
9354
10405
  switch (src0->type) {
9355
10406
  case GGML_TYPE_F16:
10407
+ case GGML_TYPE_BF16:
9356
10408
  case GGML_TYPE_I16:
9357
10409
  {
9358
10410
  ggml_compute_forward_repeat_f16(params, dst);
@@ -11670,6 +12722,7 @@ static void ggml_compute_forward_set(
11670
12722
  ggml_compute_forward_set_f32(params, dst);
11671
12723
  } break;
11672
12724
  case GGML_TYPE_F16:
12725
+ case GGML_TYPE_BF16:
11673
12726
  case GGML_TYPE_Q4_0:
11674
12727
  case GGML_TYPE_Q4_1:
11675
12728
  case GGML_TYPE_Q5_0:
@@ -11844,6 +12897,49 @@ static void ggml_compute_forward_get_rows_f16(
11844
12897
  }
11845
12898
  }
11846
12899
 
12900
+ static void ggml_compute_forward_get_rows_bf16(
12901
+ const struct ggml_compute_params * params,
12902
+ struct ggml_tensor * dst) {
12903
+
12904
+ const struct ggml_tensor * src0 = dst->src[0];
12905
+ const struct ggml_tensor * src1 = dst->src[1];
12906
+
12907
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
12908
+ return;
12909
+ }
12910
+
12911
+ GGML_TENSOR_BINARY_OP_LOCALS
12912
+
12913
+ const int64_t nc = ne00;
12914
+ const int64_t nr = ggml_nelements(src1);
12915
+
12916
+ assert(ne0 == nc);
12917
+ assert(ne02 == ne11);
12918
+ assert(nb00 == sizeof(ggml_bf16_t));
12919
+ assert(ggml_nrows(dst) == nr);
12920
+
12921
+ const int ith = params->ith;
12922
+ const int nth = params->nth;
12923
+
12924
+ // rows per thread
12925
+ const int dr = (nr + nth - 1)/nth;
12926
+
12927
+ // row range for this thread
12928
+ const int ir0 = dr*ith;
12929
+ const int ir1 = MIN(ir0 + dr, nr);
12930
+
12931
+ for (int64_t i = ir0; i < ir1; ++i) {
12932
+ const int64_t i12 = i/(ne11*ne10);
12933
+ const int64_t i11 = (i - i12*ne11*ne10)/ne10;
12934
+ const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
12935
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
12936
+
12937
+ ggml_bf16_to_fp32_row(
12938
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
12939
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
12940
+ }
12941
+ }
12942
+
11847
12943
  static void ggml_compute_forward_get_rows_f32(
11848
12944
  const struct ggml_compute_params * params,
11849
12945
  struct ggml_tensor * dst) {
@@ -11921,6 +13017,10 @@ static void ggml_compute_forward_get_rows(
11921
13017
  {
11922
13018
  ggml_compute_forward_get_rows_f16(params, dst);
11923
13019
  } break;
13020
+ case GGML_TYPE_BF16:
13021
+ {
13022
+ ggml_compute_forward_get_rows_bf16(params, dst);
13023
+ } break;
11924
13024
  case GGML_TYPE_F32:
11925
13025
  case GGML_TYPE_I32:
11926
13026
  {
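
Unlike the F32 path, the BF16 variant of get_rows above writes F32 rows: every i32 entry of the index tensor selects one source row, which is widened on the fly. An illustrative sketch of that gather over contiguous rows (the layout and names are assumptions, not the ggml API):

    #include <stdint.h>
    #include <string.h>

    // gather nr rows of nc raw bf16 elements, widening each row to f32 in dst
    static void get_rows_bf16_sketch(int64_t nc, int64_t nr,
                                     const uint16_t * src, const int32_t * rows, float * dst) {
        for (int64_t r = 0; r < nr; ++r) {
            const uint16_t * src_row = src + (int64_t) rows[r]*nc;
            float          * dst_row = dst + r*nc;
            for (int64_t c = 0; c < nc; ++c) {
                uint32_t u = (uint32_t) src_row[c] << 16;   // bf16 -> f32 widen
                memcpy(&dst_row[c], &u, sizeof(float));
            }
        }
    }
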
@@ -12255,7 +13355,7 @@ static void ggml_compute_forward_soft_max_f32(
12255
13355
 
12256
13356
  GGML_TENSOR_UNARY_OP_LOCALS
12257
13357
 
12258
- const int64_t ne11 = src1 ? src1->ne[1] : 1;
13358
+ //const int64_t ne11 = src1 ? src1->ne[1] : 1;
12259
13359
 
12260
13360
  // TODO: is this supposed to be ceil instead of floor?
12261
13361
  // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -12278,19 +13378,31 @@ static void ggml_compute_forward_soft_max_f32(
12278
13378
  float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
12279
13379
 
12280
13380
  // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
12281
- float * pos = src2 ? (float *) src2->data : src0->data;
13381
+ ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
13382
+ float * pos_f32 = src2 ? (float *) src2->data : src0->data;
13383
+
13384
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
12282
13385
 
12283
13386
  for (int i1 = ir0; i1 < ir1; i1++) {
12284
13387
  float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
12285
13388
  float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
12286
13389
 
12287
13390
  // broadcast the mask across rows
12288
- float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
13391
+ ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
13392
+ float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
12289
13393
 
12290
13394
  ggml_vec_cpy_f32 (nc, wp, sp);
12291
13395
  ggml_vec_scale_f32(nc, wp, scale);
12292
- if (mp) {
12293
- ggml_vec_acc_f32(nc, wp, mp);
13396
+ if (mp_f32) {
13397
+ if (use_f16) {
13398
+ for (int i = 0; i < nc; ++i) {
13399
+ wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
13400
+ }
13401
+ } else {
13402
+ for (int i = 0; i < nc; ++i) {
13403
+ wp[i] += mp_f32[i];
13404
+ }
13405
+ }
12294
13406
  }
12295
13407
 
12296
13408
  // ALiBi bias
@@ -12298,8 +13410,14 @@ static void ggml_compute_forward_soft_max_f32(
12298
13410
  const uint32_t h = (i1/ne01)%ne02; // head
12299
13411
  const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
12300
13412
 
12301
- for (int i = 0; i < nc; i++) {
12302
- wp[i] = wp[i] + slope*pos[i];
13413
+ if (use_f16) {
13414
+ for (int i = 0; i < nc; ++i) {
13415
+ wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
13416
+ }
13417
+ } else {
13418
+ for (int i = 0; i < nc; ++i) {
13419
+ wp[i] += slope*pos_f32[i];
13420
+ }
12303
13421
  }
12304
13422
  }
12305
13423
 
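
For reference, the per-head ALiBi slope applied in the bias loop above follows a simple geometric schedule; m0, m1 and n_head_log2 are computed earlier in the soft_max kernel (outside this hunk) from max_bias and the head count, so the sketch below takes them as inputs.

    #include <math.h>
    #include <stdint.h>

    // slope multiplying the position bias for attention head h, matching the
    // expression used in the ALiBi branch above
    static float alibi_slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) {
        return h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
    }
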
@@ -12598,6 +13716,7 @@ static void ggml_compute_forward_alibi(
12598
13716
  {
12599
13717
  ggml_compute_forward_alibi_f32(params, dst);
12600
13718
  } break;
13719
+ case GGML_TYPE_BF16:
12601
13720
  case GGML_TYPE_Q4_0:
12602
13721
  case GGML_TYPE_Q4_1:
12603
13722
  case GGML_TYPE_Q5_0:
@@ -12687,6 +13806,7 @@ static void ggml_compute_forward_clamp(
12687
13806
  ggml_compute_forward_clamp_f32(params, dst);
12688
13807
  } break;
12689
13808
  case GGML_TYPE_F16:
13809
+ case GGML_TYPE_BF16:
12690
13810
  case GGML_TYPE_Q4_0:
12691
13811
  case GGML_TYPE_Q4_1:
12692
13812
  case GGML_TYPE_Q5_0:
@@ -14569,6 +15689,198 @@ static void ggml_compute_forward_flash_attn(
14569
15689
  }
14570
15690
  }
14571
15691
 
15692
+ // ggml_compute_forward_flash_attn_ext
15693
+
15694
+ static void ggml_compute_forward_flash_attn_ext_f16(
15695
+ const struct ggml_compute_params * params,
15696
+ const struct ggml_tensor * q,
15697
+ const struct ggml_tensor * k,
15698
+ const struct ggml_tensor * v,
15699
+ const struct ggml_tensor * mask,
15700
+ struct ggml_tensor * dst) {
15701
+ int64_t t0 = ggml_perf_time_us();
15702
+ UNUSED(t0);
15703
+
15704
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15705
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15706
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15707
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15708
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15709
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15710
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15711
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15712
+
15713
+ const int ith = params->ith;
15714
+ const int nth = params->nth;
15715
+
15716
+ const int64_t D = neq0;
15717
+ const int64_t N = neq1;
15718
+
15719
+ GGML_ASSERT(ne0 == D);
15720
+ GGML_ASSERT(ne2 == N);
15721
+
15722
+ GGML_ASSERT(nbq0 == sizeof(float));
15723
+ GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
15724
+ GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
15725
+
15726
+ GGML_ASSERT(neq0 == D);
15727
+ GGML_ASSERT(nek0 == D);
15728
+ GGML_ASSERT(nev0 == D);
15729
+
15730
+ GGML_ASSERT(neq1 == N);
15731
+ GGML_ASSERT(nev0 == D);
15732
+
15733
+ // dst cannot be transposed or permuted
15734
+ GGML_ASSERT(nb0 == sizeof(float));
15735
+ GGML_ASSERT(nb0 <= nb1);
15736
+ GGML_ASSERT(nb1 <= nb2);
15737
+ GGML_ASSERT(nb2 <= nb3);
15738
+
15739
+ // broadcast factors
15740
+ const int64_t rk2 = neq2/nek2;
15741
+ const int64_t rk3 = neq3/nek3;
15742
+
15743
+ const int64_t rv2 = neq2/nev2;
15744
+ const int64_t rv3 = neq3/nev3;
15745
+
15746
+ if (params->type == GGML_TASK_TYPE_INIT) {
15747
+ return;
15748
+ }
15749
+
15750
+ if (params->type == GGML_TASK_TYPE_FINALIZE) {
15751
+ return;
15752
+ }
15753
+
15754
+ // parallelize by q rows using ggml_vec_dot_f16
15755
+
15756
+ // total rows in q
15757
+ const int nr = neq1*neq2*neq3;
15758
+
15759
+ // rows per thread
15760
+ const int dr = (nr + nth - 1)/nth;
15761
+
15762
+ // row range for this thread
15763
+ const int ir0 = dr*ith;
15764
+ const int ir1 = MIN(ir0 + dr, nr);
15765
+
15766
+ float scale = 1.0f;
15767
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
15768
+
15769
+ // loop over n_batch and n_head
15770
+ for (int ir = ir0; ir < ir1; ++ir) {
15771
+ // q indices
15772
+ const int iq3 = ir/(neq2*neq1);
15773
+ const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15774
+ const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15775
+
15776
+ float S = 0.0f;
15777
+ float M = -INFINITY;
15778
+
15779
+ float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
15780
+ ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
15781
+ ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
15782
+
15783
+ memset(V16, 0, D*sizeof(ggml_fp16_t));
15784
+
15785
+ const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
15786
+
15787
+ // k indices
15788
+ const int ik3 = iq3 / rk3;
15789
+ const int ik2 = iq2 / rk2;
15790
+
15791
+ // v indices
15792
+ const int iv3 = iq3 / rv3;
15793
+ const int iv2 = iq2 / rv2;
15794
+
15795
+ // online softmax / attention
15796
+ // loop over n_kv and n_head_kv
15797
+ // ref: https://arxiv.org/pdf/2112.05682.pdf
15798
+ for (int64_t ic = 0; ic < nek1; ++ic) {
15799
+ const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
15800
+ if (mv == -INFINITY) {
15801
+ continue;
15802
+ }
15803
+
15804
+ float s;
15805
+
15806
+ // convert Q to F16 in V32
15807
+ {
15808
+ const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
15809
+
15810
+ for (int64_t d = 0; d < D; ++d) {
15811
+ Q16[d] = GGML_FP32_TO_FP16(pq[d]);
15812
+ }
15813
+ }
15814
+
15815
+ ggml_vec_dot_f16(D,
15816
+ &s, 0,
15817
+ (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15818
+ Q16, 0, 1);
15819
+
15820
+ s = s*scale + mv;
15821
+
15822
+ const float Mold = M;
15823
+
15824
+ float ms = 1.0f;
15825
+ float vs = 1.0f;
15826
+
15827
+ if (s > M) {
15828
+ M = s;
15829
+ ms = expf(Mold - M);
15830
+
15831
+ // V = V*expf(Mold - M)
15832
+ ggml_vec_scale_f16(D, V16, ms);
15833
+ } else {
15834
+ vs = expf(s - M);
15835
+ }
15836
+
15837
+ const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
15838
+
15839
+ // V += v*expf(s - M)
15840
+ ggml_vec_mad_f16(D, V16, v16, vs);
15841
+
15842
+ S = S*ms + vs;
15843
+ }
15844
+
15845
+ // V /= S
15846
+ for (int64_t d = 0; d < D; ++d) {
15847
+ V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
15848
+ }
15849
+
15850
+ // dst indices
15851
+ const int i1 = iq1;
15852
+ const int i2 = iq2;
15853
+ const int i3 = iq3;
15854
+
15855
+ // original
15856
+ //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
15857
+
15858
+ // permute(0, 2, 1, 3)
15859
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
15860
+ }
15861
+ }
15862
+
15863
+ static void ggml_compute_forward_flash_attn_ext(
15864
+ const struct ggml_compute_params * params,
15865
+ const struct ggml_tensor * q,
15866
+ const struct ggml_tensor * k,
15867
+ const struct ggml_tensor * v,
15868
+ const struct ggml_tensor * mask,
15869
+ struct ggml_tensor * dst) {
15870
+ switch (dst->op_params[1]) {
15871
+ case GGML_PREC_DEFAULT:
15872
+ case GGML_PREC_F32:
15873
+ {
15874
+ // uses F32 accumulators
15875
+ ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
15876
+ } break;
15877
+ default:
15878
+ {
15879
+ GGML_ASSERT(false);
15880
+ } break;
15881
+ }
15882
+ }
15883
+
14572
15884
  // ggml_compute_forward_flash_ff
14573
15885
 
14574
15886
  static void ggml_compute_forward_flash_ff_f16(
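
The new ggml_compute_forward_flash_attn_ext_f16 kernel above never materializes the full softmax: per query row it streams over the KV positions, keeping only a running maximum M, a running denominator S and a D-sized accumulator, and rescales the accumulator whenever a larger score appears (which also keeps the magnitudes stored in the F16 accumulator bounded). A standalone scalar sketch of that recurrence, with illustrative names and assuming the scores already include the scale and mask terms:

    #include <math.h>
    #include <stdint.h>

    // online softmax-weighted sum: out = sum_i softmax(s)_i * v_i, computed in
    // one pass over the scores, mirroring the M/S/ms/vs updates in the kernel above
    static void online_attention_row(int64_t n_kv, int64_t D,
                                     const float * s,   // n_kv scores (scaled + masked)
                                     const float * v,   // n_kv rows of D values, row i at v + i*D
                                     float * out) {     // D outputs, used as the accumulator
        float S = 0.0f;          // running denominator
        float M = -INFINITY;     // running maximum score

        for (int64_t d = 0; d < D; ++d) {
            out[d] = 0.0f;
        }

        for (int64_t i = 0; i < n_kv; ++i) {
            if (s[i] == -INFINITY) {
                continue;        // masked-out position, skipped as in the kernel
            }

            const float Mold = M;

            float ms = 1.0f;     // rescale factor for what was accumulated so far
            float vs = 1.0f;     // softmax weight of the new value row

            if (s[i] > M) {
                M  = s[i];
                ms = expf(Mold - M);
                for (int64_t d = 0; d < D; ++d) {
                    out[d] *= ms;                 // V = V*expf(Mold - M)
                }
            } else {
                vs = expf(s[i] - M);
            }

            for (int64_t d = 0; d < D; ++d) {
                out[d] += vs*v[i*D + d];          // V += v*expf(s - M)
            }

            S = S*ms + vs;
        }

        for (int64_t d = 0; d < D; ++d) {
            out[d] /= S;                          // V /= S
        }
    }
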
@@ -15588,6 +16900,7 @@ static void ggml_compute_forward_get_rel_pos(
15588
16900
 
15589
16901
  switch (src0->type) {
15590
16902
  case GGML_TYPE_F16:
16903
+ case GGML_TYPE_BF16:
15591
16904
  {
15592
16905
  ggml_compute_forward_get_rel_pos_f16(params, dst);
15593
16906
  } break;
@@ -16376,6 +17689,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
16376
17689
  const bool masked = t != 0;
16377
17690
  ggml_compute_forward_flash_attn(params, masked, tensor);
16378
17691
  } break;
17692
+ case GGML_OP_FLASH_ATTN_EXT:
17693
+ {
17694
+ ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
17695
+ } break;
16379
17696
  case GGML_OP_FLASH_FF:
16380
17697
  {
16381
17698
  ggml_compute_forward_flash_ff(params, tensor);
@@ -17388,6 +18705,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
17388
18705
  GGML_ASSERT(false); // TODO: not implemented
17389
18706
  } break;
17390
18707
  case GGML_OP_FLASH_ATTN:
18708
+ case GGML_OP_FLASH_ATTN_EXT:
17391
18709
  {
17392
18710
  struct ggml_tensor * flash_grad = NULL;
17393
18711
  if (src0->grad || src1->grad || tensor->src[2]->grad) {
@@ -18160,6 +19478,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
18160
19478
  n_tasks = n_threads;
18161
19479
  } break;
18162
19480
  case GGML_OP_FLASH_ATTN:
19481
+ case GGML_OP_FLASH_ATTN_EXT:
18163
19482
  {
18164
19483
  n_tasks = n_threads;
18165
19484
  } break;
@@ -18446,7 +19765,10 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18446
19765
  case GGML_OP_CPY:
18447
19766
  case GGML_OP_DUP:
18448
19767
  {
18449
- if (ggml_is_quantized(node->type)) {
19768
+ if (ggml_is_quantized(node->type) ||
19769
+ // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
19770
+ (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
19771
+ (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
18450
19772
  cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
18451
19773
  }
18452
19774
  } break;
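
The work-buffer case above reserves ne[0] F32 elements per task because there is no direct F16↔BF16 conversion: such a copy presumably widens each row into the F32 scratch and then narrows it to the destination type. A sketch of that two-stage pattern, with the per-row converters passed in as stand-ins for ggml's row helpers (e.g. ggml_fp16_to_fp32_row and ggml_fp32_to_bf16_row):

    #include <stdint.h>

    typedef void (*widen_row_t) (const void  * src, float * dst, int64_t n);
    typedef void (*narrow_row_t)(const float * src, void  * dst, int64_t n);

    // copy one row through an F32 scratch buffer of ne0 elements (the per-task
    // allocation sized by the GGML_OP_CPY / GGML_OP_DUP case above)
    static void cpy_row_via_f32(int64_t ne0,
                                const void * src_row, widen_row_t  widen,
                                void       * dst_row, narrow_row_t narrow,
                                float      * wdata) {
        widen (src_row, wdata, ne0);    // e.g. F16 row -> F32 scratch
        narrow(wdata, dst_row, ne0);    // e.g. F32 scratch -> BF16 row
    }
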
@@ -18525,7 +19847,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18525
19847
  const int64_t ne10 = node->src[1]->ne[0]; // L
18526
19848
  const int64_t ne11 = node->src[1]->ne[1]; // Cin
18527
19849
 
18528
- if (node->src[0]->type == GGML_TYPE_F16 &&
19850
+ if ((node->src[0]->type == GGML_TYPE_F16 ||
19851
+ node->src[0]->type == GGML_TYPE_BF16) &&
18529
19852
  node->src[1]->type == GGML_TYPE_F32) {
18530
19853
  cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
18531
19854
  cur += sizeof(ggml_fp16_t)*ne10*ne11;
@@ -18561,8 +19884,17 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18561
19884
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18562
19885
  cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
18563
19886
  cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19887
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19888
+ cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19889
+ cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
18564
19890
  }
18565
19891
  } break;
19892
+ case GGML_OP_FLASH_ATTN_EXT:
19893
+ {
19894
+ const int64_t ne00 = node->src[0]->ne[0]; // D
19895
+
19896
+ cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
19897
+ } break;
18566
19898
  case GGML_OP_FLASH_FF:
18567
19899
  {
18568
19900
  if (node->src[1]->type == GGML_TYPE_F32) {
@@ -18571,6 +19903,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18571
19903
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18572
19904
  cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
18573
19905
  cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19906
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19907
+ cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19908
+ cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
18574
19909
  }
18575
19910
  } break;
18576
19911
  case GGML_OP_FLASH_ATTN_BACK:
@@ -18584,6 +19919,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18584
19919
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18585
19920
  cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
18586
19921
  cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
19922
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19923
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
19924
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
18587
19925
  }
18588
19926
  } break;
18589
19927
 
@@ -19360,7 +20698,9 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
19360
20698
  if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
19361
20699
  fprintf(fp, "%d", ggml_get_i32_1d(node, j));
19362
20700
  }
19363
- else if (node->type == GGML_TYPE_F32 || node->type == GGML_TYPE_F16) {
20701
+ else if (node->type == GGML_TYPE_F32 ||
20702
+ node->type == GGML_TYPE_F16 ||
20703
+ node->type == GGML_TYPE_BF16) {
19364
20704
  fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j));
19365
20705
  }
19366
20706
  else {
@@ -20418,6 +21758,12 @@ size_t ggml_quantize_chunk(
20418
21758
  ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
20419
21759
  result = n * elemsize;
20420
21760
  } break;
21761
+ case GGML_TYPE_BF16:
21762
+ {
21763
+ size_t elemsize = sizeof(ggml_bf16_t);
21764
+ ggml_fp32_to_bf16_row(src + start, (ggml_bf16_t *)dst + start, n);
21765
+ result = n * elemsize;
21766
+ } break;
20421
21767
  case GGML_TYPE_F32:
20422
21768
  {
20423
21769
  size_t elemsize = sizeof(float);
@@ -20614,7 +21960,7 @@ static void gguf_free_kv(struct gguf_kv * kv) {
20614
21960
  }
20615
21961
 
20616
21962
  struct gguf_context * gguf_init_empty(void) {
20617
- struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context));
21963
+ struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context));
20618
21964
 
20619
21965
  memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic));
20620
21966
  ctx->header.version = GGUF_VERSION;
@@ -20659,7 +22005,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20659
22005
 
20660
22006
  bool ok = true;
20661
22007
 
20662
- struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context));
22008
+ struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context));
20663
22009
 
20664
22010
  // read the header
20665
22011
  {
@@ -20696,9 +22042,13 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20696
22042
 
20697
22043
  // read the kv pairs
20698
22044
  {
20699
- ctx->kv = GGML_MALLOC(ctx->header.n_kv * sizeof(struct gguf_kv));
22045
+ const uint64_t n_kv = ctx->header.n_kv;
20700
22046
 
20701
- for (uint64_t i = 0; i < ctx->header.n_kv; ++i) {
22047
+ // header.n_kv will hold the actual number of pairs that were successfully read in the loop below
22048
+ ctx->header.n_kv = 0;
22049
+ ctx->kv = GGML_CALLOC(n_kv, sizeof(struct gguf_kv));
22050
+
22051
+ for (uint64_t i = 0; i < n_kv; ++i) {
20702
22052
  struct gguf_kv * kv = &ctx->kv[i];
20703
22053
 
20704
22054
  //fprintf(stderr, "%s: reading kv %d\n", __func__, i);
@@ -20747,7 +22097,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20747
22097
  return NULL;
20748
22098
  }
20749
22099
 
20750
- kv->value.arr.data = GGML_MALLOC(kv->value.arr.n * gguf_type_size(kv->value.arr.type));
22100
+ kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, gguf_type_size(kv->value.arr.type));
20751
22101
 
20752
22102
  ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type), &offset);
20753
22103
  } break;
@@ -20761,7 +22111,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20761
22111
  return NULL;
20762
22112
  }
20763
22113
 
20764
- kv->value.arr.data = GGML_MALLOC(kv->value.arr.n * sizeof(struct gguf_str));
22114
+ kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, sizeof(struct gguf_str));
20765
22115
 
20766
22116
  for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
20767
22117
  ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset);
@@ -20777,6 +22127,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20777
22127
  if (!ok) {
20778
22128
  break;
20779
22129
  }
22130
+
22131
+ ctx->header.n_kv++;
20780
22132
  }
20781
22133
 
20782
22134
  if (!ok) {
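
The kv-reading changes above combine two defensive patterns: zero-initialized allocation (GGML_CALLOC) and counting an entry only once it has been read completely, so the failure path can hand gguf_free a structure in which everything up to header.n_kv is valid and everything past it is all zeros. A generic sketch of the pattern with hypothetical names (not the gguf API):

    #include <stdbool.h>
    #include <stdint.h>
    #include <stdlib.h>

    struct entry { char * data; };

    // hypothetical stand-in for the per-kv parsing done in the loop above
    static bool read_entry(struct entry * e) {
        e->data = NULL;             // a real parser would fill the entry from the file
        return true;
    }

    static struct entry * read_entries(uint64_t n_declared, uint64_t * n_read) {
        *n_read = 0;
        struct entry * entries = calloc(n_declared, sizeof(*entries));  // zeroed, like GGML_CALLOC
        if (entries == NULL) {
            return NULL;
        }
        for (uint64_t i = 0; i < n_declared; ++i) {
            if (!read_entry(&entries[i])) {
                break;              // stop early; only entries[0..*n_read) were completed
            }
            (*n_read)++;            // mirrors ctx->header.n_kv++ above
        }
        return entries;
    }
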
@@ -20788,8 +22140,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20788
22140
  }
20789
22141
 
20790
22142
  // read the tensor infos
20791
- {
20792
- ctx->infos = GGML_MALLOC(ctx->header.n_tensors * sizeof(struct gguf_tensor_info));
22143
+ if (ctx->header.n_tensors > 0) {
22144
+ ctx->infos = GGML_CALLOC(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));
20793
22145
 
20794
22146
  for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
20795
22147
  struct gguf_tensor_info * info = &ctx->infos[i];
@@ -20810,8 +22162,17 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20810
22162
  ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset);
20811
22163
  ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset);
20812
22164
 
22165
+ // TODO: return an error instead of crashing with GGML_ASSERT
20813
22166
  gguf_tensor_info_sanitize(info);
20814
22167
 
22168
+ // make sure there are no duplicated tensor names
22169
+ for (uint64_t j = 0; j < i; ++j) {
22170
+ if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) {
22171
+ fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data);
22172
+ ok = false;
22173
+ }
22174
+ }
22175
+
20815
22176
  if (!ok) {
20816
22177
  fprintf(stderr, "%s: failed to read tensor info\n", __func__);
20817
22178
  fclose(file);
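
The duplicate-name guard above is a plain quadratic scan against the tensor infos read so far (a similar guard is added to gguf_add_tensor further down). Reduced to its core:

    #include <stdbool.h>
    #include <stdint.h>
    #include <string.h>

    // true if name already occurs among the first n names
    static bool is_duplicate_name(const char * name, const char * const * names, uint64_t n) {
        for (uint64_t j = 0; j < n; ++j) {
            if (strcmp(name, names[j]) == 0) {
                return true;
            }
        }
        return false;
    }
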
@@ -20980,7 +22341,7 @@ void gguf_free(struct gguf_context * ctx) {
20980
22341
  GGML_FREE(ctx->infos);
20981
22342
  }
20982
22343
 
20983
- GGML_ALIGNED_FREE(ctx);
22344
+ GGML_FREE(ctx);
20984
22345
  }
20985
22346
 
20986
22347
  const char * gguf_type_name(enum gguf_type type) {
@@ -21291,7 +22652,7 @@ void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_ty
21291
22652
  ctx->kv[idx].type = GGUF_TYPE_ARRAY;
21292
22653
  ctx->kv[idx].value.arr.type = type;
21293
22654
  ctx->kv[idx].value.arr.n = n;
21294
- ctx->kv[idx].value.arr.data = GGML_MALLOC(n*gguf_type_size(type));
22655
+ ctx->kv[idx].value.arr.data = GGML_CALLOC(n, gguf_type_size(type));
21295
22656
  memcpy(ctx->kv[idx].value.arr.data, data, n*gguf_type_size(type));
21296
22657
  }
21297
22658
 
@@ -21301,7 +22662,7 @@ void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char **
21301
22662
  ctx->kv[idx].type = GGUF_TYPE_ARRAY;
21302
22663
  ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING;
21303
22664
  ctx->kv[idx].value.arr.n = n;
21304
- ctx->kv[idx].value.arr.data = GGML_MALLOC(n*sizeof(struct gguf_str));
22665
+ ctx->kv[idx].value.arr.data = GGML_CALLOC(n, sizeof(struct gguf_str));
21305
22666
  for (int i = 0; i < n; i++) {
21306
22667
  struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i];
21307
22668
  str->n = strlen(data[i]);
@@ -21328,7 +22689,7 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
21328
22689
  case GGUF_TYPE_ARRAY:
21329
22690
  {
21330
22691
  if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) {
21331
- const char ** data = GGML_MALLOC(src->kv[i].value.arr.n*sizeof(char *));
22692
+ const char ** data = GGML_CALLOC(src->kv[i].value.arr.n, sizeof(char *));
21332
22693
  for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) {
21333
22694
  data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data;
21334
22695
  }
@@ -21348,6 +22709,10 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
21348
22709
  void gguf_add_tensor(
21349
22710
  struct gguf_context * ctx,
21350
22711
  const struct ggml_tensor * tensor) {
22712
+ if (gguf_find_tensor(ctx, tensor->name) != -1) {
22713
+ GGML_ASSERT(false && "duplicated tensor name");
22714
+ }
22715
+
21351
22716
  const int idx = ctx->header.n_tensors;
21352
22717
  ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info));
21353
22718
 
@@ -21416,7 +22781,7 @@ struct gguf_buf {
21416
22781
 
21417
22782
  static struct gguf_buf gguf_buf_init(size_t size) {
21418
22783
  struct gguf_buf buf = {
21419
- /*buf.data =*/ size == 0 ? NULL : GGML_MALLOC(size),
22784
+ /*buf.data =*/ size == 0 ? NULL : GGML_CALLOC(1, size),
21420
22785
  /*buf.size =*/ size,
21421
22786
  /*buf.offset =*/ 0,
21422
22787
  };