llama_cpp 0.15.0 → 0.15.2

@@ -4,7 +4,6 @@
4
4
  #include "ggml-impl.h"
5
5
  #include "ggml-quants.h"
6
6
  #include "ggml.h"
7
- #include "sgemm.h"
8
7
 
9
8
  #if defined(_MSC_VER) || defined(__MINGW32__)
10
9
  #include <malloc.h> // using malloc.h with MSC/MINGW
@@ -37,6 +36,10 @@
37
36
  #undef GGML_USE_LLAMAFILE
38
37
  #endif
39
38
 
39
+ #ifdef GGML_USE_LLAMAFILE
40
+ #include "sgemm.h"
41
+ #endif
42
+
40
43
  #if defined(_MSC_VER)
41
44
  // disable "possible loss of data" to avoid hundreds of casts
42
45
  // we should just be careful :)
@@ -109,6 +112,8 @@ typedef void * thread_ret_t;
109
112
 
110
113
  #endif
111
114
 
115
+ typedef pthread_t ggml_thread_t;
116
+
112
117
  #ifdef GGML_USE_CPU_HBM
113
118
  #include <hbwmalloc.h>
114
119
  #endif
@@ -160,9 +165,6 @@ void ggml_print_backtrace(void) {
160
165
  #define GGML_DEBUG 0
161
166
  #define GGML_GELU_FP16
162
167
  #define GGML_GELU_QUICK_FP16
163
- #define GGML_SILU_FP16
164
- // #define GGML_CROSS_ENTROPY_EXP_FP16
165
- // #define GGML_FLASH_ATTN_EXP_FP16
166
168
 
167
169
  #define GGML_SOFT_MAX_UNROLL 4
168
170
  #define GGML_VEC_DOT_UNROLL 2
@@ -313,16 +315,10 @@ static ggml_fp16_t ggml_table_gelu_f16[1 << 16];
313
315
  // precomputed quick gelu table for f16 (128 KB)
314
316
  static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
315
317
 
316
- // precomputed silu table for f16 (128 KB)
317
- static ggml_fp16_t ggml_table_silu_f16[1 << 16];
318
-
319
- // precomputed exp table for f16 (128 KB)
320
- static ggml_fp16_t ggml_table_exp_f16[1 << 16];
321
-
322
318
  // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
323
319
  float ggml_table_f32_f16[1 << 16];
324
320
 
325
- const char * ggml_status_to_string(enum ggml_status status) {
321
+ GGML_CALL const char * ggml_status_to_string(enum ggml_status status) {
326
322
  switch (status) {
327
323
  case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
328
324
  case GGML_STATUS_FAILED: return "GGML status: error (operation failed)";
@@ -333,16 +329,26 @@ const char * ggml_status_to_string(enum ggml_status status) {
333
329
  return "GGML status: unknown";
334
330
  }
335
331
 
336
- // note: do not use these inside ggml.c
337
- // these are meant to be used via the ggml.h API
338
332
  float ggml_fp16_to_fp32(ggml_fp16_t x) {
333
+ #define ggml_fp16_to_fp32 do_not_use__ggml_fp16_to_fp32__in_ggml
339
334
  return GGML_FP16_TO_FP32(x);
340
335
  }
341
336
 
342
337
  ggml_fp16_t ggml_fp32_to_fp16(float x) {
338
+ #define ggml_fp32_to_fp16 do_not_use__ggml_fp32_to_fp16__in_ggml
343
339
  return GGML_FP32_TO_FP16(x);
344
340
  }
345
341
 
342
+ float ggml_bf16_to_fp32(ggml_bf16_t x) {
343
+ #define ggml_bf16_to_fp32 do_not_use__ggml_bf16_to_fp32__in_ggml
344
+ return GGML_BF16_TO_FP32(x); // it just left shifts
345
+ }
346
+
347
+ ggml_bf16_t ggml_fp32_to_bf16(float x) {
348
+ #define ggml_fp32_to_bf16 do_not_use__ggml_fp32_to_bf16__in_ggml
349
+ return GGML_FP32_TO_BF16(x);
350
+ }
351
+
346
352
  void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
347
353
  for (int64_t i = 0; i < n; i++) {
348
354
  y[i] = GGML_FP16_TO_FP32(x[i]);
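
Aside, not part of the upstream diff: the new ggml_bf16_t helpers work because bfloat16 is simply the upper 16 bits of an IEEE-754 binary32 value, which is why the conversion above "just left shifts". The sketch below is a self-contained illustration of that idea; the names are hypothetical, and the fp32 to bf16 direction is shown as plain truncation, whereas the real GGML_FP32_TO_BF16 macro may additionally round to nearest and quiet NaNs.

    /* Illustrative only: minimal bf16 <-> fp32 round-trip, assuming bf16
       stores the upper 16 bits of an IEEE-754 binary32 value. */
    #include <stdint.h>
    #include <string.h>

    typedef struct { uint16_t bits; } demo_bf16_t;   /* hypothetical stand-in */

    static float demo_bf16_to_fp32(demo_bf16_t h) {
        uint32_t u = (uint32_t) h.bits << 16;         /* "it just left shifts" */
        float f;
        memcpy(&f, &u, sizeof f);
        return f;
    }

    static demo_bf16_t demo_fp32_to_bf16(float f) {
        uint32_t u;
        memcpy(&u, &f, sizeof u);
        demo_bf16_t h = { (uint16_t) (u >> 16) };     /* truncate the low 16 mantissa bits */
        return h;
    }
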
@@ -368,6 +374,49 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
368
374
  }
369
375
  }
370
376
 
377
+ void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
378
+ int64_t i = 0;
379
+ #if defined(__AVX512F__)
380
+ for (; i + 16 <= n; i += 16) {
381
+ _mm512_storeu_ps(y + i,
382
+ _mm512_castsi512_ps(
383
+ _mm512_slli_epi32(
384
+ _mm512_cvtepu16_epi32(
385
+ _mm256_loadu_si256(
386
+ (const __m256i *)(x + i))),
387
+ 16)));
388
+ }
389
+ #elif defined(__AVX2__)
390
+ for (; i + 8 <= n; i += 8) {
391
+ _mm256_storeu_ps(y + i,
392
+ _mm256_castsi256_ps(
393
+ _mm256_slli_epi32(
394
+ _mm256_cvtepu16_epi32(
395
+ _mm_loadu_si128(
396
+ (const __m128i *)(x + i))),
397
+ 16)));
398
+ }
399
+ #endif
400
+ for (; i < n; i++) {
401
+ y[i] = GGML_BF16_TO_FP32(x[i]);
402
+ }
403
+ }
404
+
405
+ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
406
+ int i = 0;
407
+ #if defined(__AVX512BF16__)
408
+ for (; i + 32 <= n; i += 32) {
409
+ _mm512_storeu_ps(
410
+ (__m512 *)(y + i),
411
+ (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
412
+ _mm512_loadu_ps(x + i)));
413
+ }
414
+ #endif
415
+ for (; i < n; i++) {
416
+ y[i] = GGML_FP32_TO_BF16(x[i]);
417
+ }
418
+ }
419
+
371
420
  bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
372
421
  return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
373
422
  }
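
Aside, not part of the upstream diff: the new row converters above always fall through to a plain scalar tail, so they behave identically with or without the AVX paths. A hedged usage sketch follows, assuming both functions are exported through ggml.h the same way their f16 counterparts are; the buffer contents are arbitrary.

    /* Illustrative only: round-tripping a small buffer through the new row
       converters (lossy: bf16 keeps roughly 8 bits of mantissa). */
    #include "ggml.h"

    void demo_bf16_roundtrip(void) {
        float       in[4] = { 1.0f, -2.5f, 3.14159265f, 65504.0f };
        ggml_bf16_t tmp[4];
        float       out[4];
        ggml_fp32_to_bf16_row(in, tmp, 4);
        ggml_bf16_to_fp32_row(tmp, out, 4);   /* out[i] ~= in[i], rounded to bf16 */
    }
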
@@ -503,6 +552,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
503
552
 
504
553
  static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
505
554
  static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
555
+ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
506
556
 
507
557
  static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
508
558
  [GGML_TYPE_I8] = {
@@ -845,6 +895,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
845
895
  .type_size = sizeof(block_q8_K),
846
896
  .is_quantized = true,
847
897
  .from_float = quantize_row_q8_K,
898
+ },
899
+ [GGML_TYPE_BF16] = {
900
+ .type_name = "bf16",
901
+ .blck_size = 1,
902
+ .type_size = sizeof(ggml_bf16_t),
903
+ .is_quantized = false,
904
+ .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row,
905
+ .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row,
906
+ .from_float_reference = (ggml_from_float_t) ggml_fp32_to_bf16_row,
907
+ .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
908
+ .vec_dot_type = GGML_TYPE_BF16,
909
+ .nrows = 1,
848
910
  }
849
911
  };
850
912
 
@@ -1237,6 +1299,8 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
1237
1299
  #define GGML_F16_VEC_ZERO GGML_F32x4_ZERO
1238
1300
  #define GGML_F16_VEC_SET1 GGML_F32x4_SET1
1239
1301
  #define GGML_F16_VEC_FMA GGML_F32x4_FMA
1302
+ #define GGML_F16_VEC_ADD GGML_F32x4_ADD
1303
+ #define GGML_F16_VEC_MUL GGML_F32x4_MUL
1240
1304
  #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
1241
1305
  // Use vec_xl, not vec_ld, in case the load address is not aligned.
1242
1306
  #define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \
@@ -1468,6 +1532,59 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
1468
1532
  #define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
1469
1533
  #endif
1470
1534
 
1535
+ //
1536
+ // ggml context
1537
+ //
1538
+
1539
+ struct ggml_context {
1540
+ size_t mem_size;
1541
+ void* mem_buffer;
1542
+ bool mem_buffer_owned;
1543
+ bool no_alloc;
1544
+ bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
1545
+
1546
+ int n_objects;
1547
+
1548
+ struct ggml_object* objects_begin;
1549
+ struct ggml_object* objects_end;
1550
+
1551
+ struct ggml_scratch scratch;
1552
+ struct ggml_scratch scratch_save;
1553
+ };
1554
+
1555
+ struct ggml_context_container {
1556
+ bool used;
1557
+
1558
+ struct ggml_context context;
1559
+ };
1560
+
1561
+ struct ggml_compute_state_shared {
1562
+ const struct ggml_cgraph* cgraph;
1563
+ const struct ggml_cplan* cplan;
1564
+
1565
+ int64_t perf_node_start_cycles;
1566
+ int64_t perf_node_start_time_us;
1567
+
1568
+ const int n_threads;
1569
+
1570
+ // synchronization primitives
1571
+ atomic_int n_active; // num active threads
1572
+ atomic_int node_n; // active graph node
1573
+ atomic_int node_task; // active graph node task phase
1574
+
1575
+ ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
1576
+ void* abort_callback_data;
1577
+
1578
+ atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
1579
+ };
1580
+
1581
+ struct ggml_compute_state {
1582
+ ggml_thread_t thrd;
1583
+ int ith;
1584
+ struct ggml_compute_state_shared* shared;
1585
+ enum ggml_status ec;
1586
+ };
1587
+
1471
1588
  //
1472
1589
  // fundamental operations
1473
1590
  //
@@ -1480,6 +1597,8 @@ inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) {
1480
1597
 
1481
1598
  inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1482
1599
 
1600
+ inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1601
+
1483
1602
  inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
1484
1603
  inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
1485
1604
  inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
@@ -1498,7 +1617,7 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float *
1498
1617
  UNUSED(by);
1499
1618
  UNUSED(bs);
1500
1619
 
1501
- #ifdef GGML_SIMD
1620
+ #if defined(GGML_SIMD)
1502
1621
  float sumf = 0.0f;
1503
1622
  const int np = (n & ~(GGML_F32_STEP - 1));
1504
1623
 
@@ -1534,6 +1653,70 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float *
1534
1653
  *s = sumf;
1535
1654
  }
1536
1655
 
1656
+ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) {
1657
+ assert(nrc == 1);
1658
+ UNUSED(nrc);
1659
+ UNUSED(bx);
1660
+ UNUSED(by);
1661
+ UNUSED(bs);
1662
+ int i = 0;
1663
+ ggml_float sumf = 0;
1664
+
1665
+ #if defined(__AVX512BF16__)
1666
+ __m512 c1 = _mm512_setzero_ps();
1667
+ __m512 c2 = _mm512_setzero_ps();
1668
+ for (; i + 64 <= n; i += 64) {
1669
+ c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
1670
+ (__m512bh)_mm512_loadu_ps((const float *)(y + i)));
1671
+ c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
1672
+ (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
1673
+ }
1674
+ sumf += (ggml_float)_mm512_reduce_add_ps(c1);
1675
+ sumf += (ggml_float)_mm512_reduce_add_ps(c2);
1676
+
1677
+ #elif defined(__AVX512F__)
1678
+ #define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
1679
+ __m512 c1 = _mm512_setzero_ps();
1680
+ __m512 c2 = _mm512_setzero_ps();
1681
+ for (; i + 32 <= n; i += 32) {
1682
+ c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
1683
+ c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
1684
+ }
1685
+ sumf += (ggml_float)_mm512_reduce_add_ps(c1);
1686
+ sumf += (ggml_float)_mm512_reduce_add_ps(c2);
1687
+
1688
+ #undef LOAD
1689
+ #elif defined(__AVX2__)
1690
+ #define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
1691
+ __m256 c1 = _mm256_setzero_ps();
1692
+ __m256 c2 = _mm256_setzero_ps();
1693
+ __m256 c3 = _mm256_setzero_ps();
1694
+ __m256 c4 = _mm256_setzero_ps();
1695
+ for (; i + 32 <= n; i += 32) {
1696
+ c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
1697
+ c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
1698
+ c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
1699
+ c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
1700
+ }
1701
+ __m128 g;
1702
+ c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
1703
+ _mm256_add_ps(c2, c4));
1704
+ g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
1705
+ _mm256_castps256_ps128(c1));
1706
+ g = _mm_add_ps(g, _mm_movehl_ps(g, g));
1707
+ g = _mm_add_ss(g, _mm_movehdup_ps(g));
1708
+ sumf += (ggml_float)_mm_cvtss_f32(g);
1709
+
1710
+ #undef LOAD
1711
+ #endif
1712
+
1713
+ for (; i < n; ++i) {
1714
+ sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
1715
+ GGML_BF16_TO_FP32(y[i]));
1716
+ }
1717
+ *s = sumf;
1718
+ }
1719
+
1537
1720
  static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
1538
1721
  assert(nrc == 1);
1539
1722
  UNUSED(nrc);
@@ -1817,6 +2000,7 @@ inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) {
1817
2000
  inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
1818
2001
  inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
1819
2002
  inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
2003
+ inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
1820
2004
  // TODO: optimize performance
1821
2005
  inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
1822
2006
  inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
@@ -1892,52 +2076,291 @@ inline static float ggml_silu_f32(float x) {
1892
2076
  return x/(1.0f + expf(-x));
1893
2077
  }
1894
2078
 
1895
- //inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
1896
- // const uint16_t * i16 = (const uint16_t *) x;
1897
- // for (int i = 0; i < n; ++i) {
1898
- // y[i] = ggml_table_silu_f16[i16[i]];
1899
- // }
1900
- //}
2079
+ #if defined(__ARM_NEON)
1901
2080
 
1902
- #ifdef GGML_SILU_FP16
1903
- inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
1904
- uint16_t t;
1905
- for (int i = 0; i < n; ++i) {
1906
- ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
1907
- memcpy(&t, &fp16, sizeof(uint16_t));
1908
- y[i] = GGML_FP16_TO_FP32(ggml_table_silu_f16[t]);
1909
- }
1910
- }
2081
+ // adapted from arm limited optimized routine
2082
+ // the maximum error is 1.45358 plus 0.5 ulps
2083
+ // numbers above 88.38 will flush to infinity
2084
+ // numbers beneath -103.97 will flush to zero
2085
+ inline static float32x4_t ggml_v_expf(float32x4_t x) {
2086
+ const float32x4_t r = vdupq_n_f32(0x1.8p23f);
2087
+ const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
2088
+ const float32x4_t n = vsubq_f32(z, r);
2089
+ const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
2090
+ vdupq_n_f32(0x1.7f7d1cp-20f));
2091
+ const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
2092
+ const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
2093
+ const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
2094
+ const float32x4_t u = vmulq_f32(b, b);
2095
+ const float32x4_t j = vfmaq_f32(
2096
+ vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
2097
+ vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
2098
+ vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
2099
+ if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
2100
+ return vfmaq_f32(k, j, k);
2101
+ const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
2102
+ const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
2103
+ const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
2104
+ return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
2105
+ vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
2106
+ }
2107
+
2108
+ // computes silu x/(1+exp(-x)) in single precision vector
2109
+ inline static float32x4_t ggml_v_silu(float32x4_t x) {
2110
+ const float32x4_t one = vdupq_n_f32(1.0f);
2111
+ const float32x4_t zero = vdupq_n_f32(0.0f);
2112
+ const float32x4_t neg_x = vsubq_f32(zero, x);
2113
+ const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
2114
+ const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
2115
+ return vdivq_f32(x, one_plus_exp_neg_x);
2116
+ }
2117
+
2118
+ #elif defined(__AVX512F__) && defined(__AVX512DQ__)
2119
+
2120
+ // adapted from arm limited optimized routine
2121
+ // the maximum error is 1.45358 plus 0.5 ulps
2122
+ // numbers above 88.38 will flush to infinity
2123
+ // numbers beneath -103.97 will flush to zero
2124
+ inline static __m512 ggml_v_expf(__m512 x) {
2125
+ const __m512 r = _mm512_set1_ps(0x1.8p23f);
2126
+ const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
2127
+ const __m512 n = _mm512_sub_ps(z, r);
2128
+ const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
2129
+ _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
2130
+ const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
2131
+ const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
2132
+ const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
2133
+ const __m512 u = _mm512_mul_ps(b, b);
2134
+ const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
2135
+ _mm512_set1_ps(0x1.573e2ep-5f)), u,
2136
+ _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
2137
+ _mm512_set1_ps(0x1.fffdb6p-2f))),
2138
+ u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
2139
+ if (_mm512_kortestz(c, c))
2140
+ return _mm512_fmadd_ps(j, k, k);
2141
+ const __m512i g = _mm512_and_si512(
2142
+ _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
2143
+ _mm512_set1_epi32(0x82000000u));
2144
+ const __m512 s1 =
2145
+ _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
2146
+ const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
2147
+ const __mmask16 d =
2148
+ _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
2149
+ return _mm512_mask_blend_ps(
2150
+ d, _mm512_mask_blend_ps(
2151
+ c, _mm512_fmadd_ps(k, j, k),
2152
+ _mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)),
2153
+ _mm512_mul_ps(s1, s1));
2154
+ }
2155
+
2156
+ // computes silu x/(1+exp(-x)) in single precision vector
2157
+ inline static __m512 ggml_v_silu(__m512 x) {
2158
+ const __m512 one = _mm512_set1_ps(1);
2159
+ const __m512 zero = _mm512_setzero_ps();
2160
+ const __m512 neg_x = _mm512_sub_ps(zero, x);
2161
+ const __m512 exp_neg_x = ggml_v_expf(neg_x);
2162
+ const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
2163
+ return _mm512_div_ps(x, one_plus_exp_neg_x);
2164
+ }
2165
+
2166
+ #elif defined(__AVX2__) && defined(__FMA__)
2167
+
2168
+ // adapted from arm limited optimized routine
2169
+ // the maximum error is 1.45358 plus 0.5 ulps
2170
+ // numbers above 88.38 will flush to infinity
2171
+ // numbers beneath -103.97 will flush to zero
2172
+ inline static __m256 ggml_v_expf(__m256 x) {
2173
+ const __m256 r = _mm256_set1_ps(0x1.8p23f);
2174
+ const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
2175
+ const __m256 n = _mm256_sub_ps(z, r);
2176
+ const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
2177
+ _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
2178
+ const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
2179
+ const __m256 k = _mm256_castsi256_ps(
2180
+ _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
2181
+ const __m256i c = _mm256_castps_si256(
2182
+ _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
2183
+ _mm256_set1_ps(126), _CMP_GT_OQ));
2184
+ const __m256 u = _mm256_mul_ps(b, b);
2185
+ const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
2186
+ _mm256_set1_ps(0x1.573e2ep-5f)), u,
2187
+ _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
2188
+ _mm256_set1_ps(0x1.fffdb6p-2f))),
2189
+ u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
2190
+ if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
2191
+ return _mm256_fmadd_ps(j, k, k);
2192
+ const __m256i g = _mm256_and_si256(
2193
+ _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
2194
+ _mm256_set1_epi32(0x82000000u));
2195
+ const __m256 s1 =
2196
+ _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
2197
+ const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
2198
+ const __m256i d = _mm256_castps_si256(
2199
+ _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
2200
+ _mm256_set1_ps(192), _CMP_GT_OQ));
2201
+ return _mm256_or_ps(
2202
+ _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
2203
+ _mm256_andnot_ps(
2204
+ _mm256_castsi256_ps(d),
2205
+ _mm256_or_ps(
2206
+ _mm256_and_ps(_mm256_castsi256_ps(c),
2207
+ _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
2208
+ _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
2209
+ }
2210
+
2211
+ // computes silu x/(1+exp(-x)) in single precision vector
2212
+ inline static __m256 ggml_v_silu(__m256 x) {
2213
+ const __m256 one = _mm256_set1_ps(1);
2214
+ const __m256 zero = _mm256_setzero_ps();
2215
+ const __m256 neg_x = _mm256_sub_ps(zero, x);
2216
+ const __m256 exp_neg_x = ggml_v_expf(neg_x);
2217
+ const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
2218
+ return _mm256_div_ps(x, one_plus_exp_neg_x);
2219
+ }
2220
+
2221
+ #elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
2222
+
2223
+ #if defined(__FMA__)
2224
+ #define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
2225
+ #define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
1911
2226
  #else
1912
- inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
1913
- for (int i = 0; i < n; ++i) {
2227
+ #define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
2228
+ #define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
2229
+ #endif
2230
+
2231
+ // adapted from arm limited optimized routine
2232
+ // the maximum error is 1.45358 plus 0.5 ulps
2233
+ // numbers above 88.38 will flush to infinity
2234
+ // numbers beneath -103.97 will flush to zero
2235
+ inline static __m128 ggml_v_expf(__m128 x) {
2236
+ const __m128 r = _mm_set1_ps(0x1.8p23f);
2237
+ const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
2238
+ const __m128 n = _mm_sub_ps(z, r);
2239
+ const __m128 b =
2240
+ NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
2241
+ const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
2242
+ const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
2243
+ const __m128i c =
2244
+ _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
2245
+ const __m128 u = _mm_mul_ps(b, b);
2246
+ const __m128 j =
2247
+ MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
2248
+ MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
2249
+ u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
2250
+ if (!_mm_movemask_epi8(c))
2251
+ return MADD128(j, k, k);
2252
+ const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
2253
+ _mm_set1_epi32(0x82000000u));
2254
+ const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
2255
+ const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
2256
+ const __m128i d =
2257
+ _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
2258
+ return _mm_or_ps(
2259
+ _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
2260
+ _mm_andnot_ps(_mm_castsi128_ps(d),
2261
+ _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
2262
+ _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
2263
+ }
2264
+
2265
+ // computes silu x/(1+exp(-x)) in single precision vector
2266
+ inline static __m128 ggml_v_silu(__m128 x) {
2267
+ const __m128 one = _mm_set1_ps(1);
2268
+ const __m128 zero = _mm_setzero_ps();
2269
+ const __m128 neg_x = _mm_sub_ps(zero, x);
2270
+ const __m128 exp_neg_x = ggml_v_expf(neg_x);
2271
+ const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
2272
+ return _mm_div_ps(x, one_plus_exp_neg_x);
2273
+ }
2274
+
2275
+ #endif // __ARM_NEON / __AVX2__ / __SSE2__
2276
+
2277
+ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
2278
+ int i = 0;
2279
+ #if defined(__AVX512F__) && defined(__AVX512DQ__)
2280
+ for (; i + 15 < n; i += 16) {
2281
+ _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
2282
+ }
2283
+ #elif defined(__AVX2__) && defined(__FMA__)
2284
+ for (; i + 7 < n; i += 8) {
2285
+ _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
2286
+ }
2287
+ #elif defined(__SSE2__)
2288
+ for (; i + 3 < n; i += 4) {
2289
+ _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
2290
+ }
2291
+ #elif defined(__ARM_NEON)
2292
+ for (; i + 3 < n; i += 4) {
2293
+ vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
2294
+ }
2295
+ #endif
2296
+ for (; i < n; ++i) {
1914
2297
  y[i] = ggml_silu_f32(x[i]);
1915
2298
  }
1916
2299
  }
2300
+
2301
+ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
2302
+ int i = 0;
2303
+ ggml_float sum = 0;
2304
+ #if defined(__AVX512F__) && defined(__AVX512DQ__)
2305
+ for (; i + 15 < n; i += 16) {
2306
+ __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
2307
+ _mm512_set1_ps(max)));
2308
+ _mm512_storeu_ps(y + i, val);
2309
+ sum += (ggml_float)_mm512_reduce_add_ps(val);
2310
+ }
2311
+ #elif defined(__AVX2__) && defined(__FMA__)
2312
+ for (; i + 7 < n; i += 8) {
2313
+ __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
2314
+ _mm256_set1_ps(max)));
2315
+ _mm256_storeu_ps(y + i, val);
2316
+ __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
2317
+ _mm256_castps256_ps128(val));
2318
+ val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
2319
+ val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
2320
+ sum += (ggml_float)_mm_cvtss_f32(val2);
2321
+ }
2322
+ #elif defined(__SSE2__)
2323
+ for (; i + 3 < n; i += 4) {
2324
+ __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
2325
+ _mm_set1_ps(max)));
2326
+ _mm_storeu_ps(y + i, val);
2327
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
2328
+ val = _mm_add_ps(val, _mm_movehl_ps(val, val));
2329
+ val = _mm_add_ss(val, _mm_movehdup_ps(val));
2330
+ #else
2331
+ __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
2332
+ val = _mm_add_ps(val, tmp);
2333
+ tmp = _mm_movehl_ps(tmp, val);
2334
+ val = _mm_add_ss(val, tmp);
1917
2335
  #endif
2336
+ sum += (ggml_float)_mm_cvtss_f32(val);
2337
+ }
2338
+ #elif defined(__ARM_NEON)
2339
+ for (; i + 3 < n; i += 4) {
2340
+ float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
2341
+ vdupq_n_f32(max)));
2342
+ vst1q_f32(y + i, val);
2343
+ sum += (ggml_float)vaddvq_f32(val);
2344
+ }
2345
+ #endif
2346
+ for (; i < n; ++i) {
2347
+ float val = expf(x[i] - max);
2348
+ sum += (ggml_float)val;
2349
+ y[i] = val;
2350
+ }
2351
+ return sum;
2352
+ }
1918
2353
 
1919
2354
  inline static float ggml_silu_backward_f32(float x, float dy) {
1920
2355
  const float s = 1.0f/(1.0f + expf(-x));
1921
2356
  return dy*s*(1.0f + x*(1.0f - s));
1922
2357
  }
1923
2358
 
1924
- #ifdef GGML_SILU_FP16
1925
- inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
1926
- for (int i = 0; i < n; ++i) {
1927
- // we did not use x[i] to compute forward silu but its f16 equivalent
1928
- // take derivative at f16 of x[i]:
1929
- ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
1930
- float usedx = GGML_FP16_TO_FP32(fp16);
1931
- dx[i] = ggml_silu_backward_f32(usedx, dy[i]);
1932
- }
1933
- }
1934
- #else
1935
2359
  inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
1936
2360
  for (int i = 0; i < n; ++i) {
1937
2361
  dx[i] = ggml_silu_backward_f32(x[i], dy[i]);
1938
2362
  }
1939
2363
  }
1940
- #endif
1941
2364
 
1942
2365
  inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
1943
2366
  #ifndef GGML_USE_ACCELERATE
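
Aside, not part of the upstream diff: every SIMD branch added above reduces to the same behaviour as its scalar fallback loop. ggml_vec_silu_f32 computes x/(1+exp(-x)) element-wise, and ggml_vec_soft_max_f32 stores exp(x[i] - max) and returns the running sum (subtracting the row maximum keeps expf in range). A minimal scalar sketch with hypothetical names:

    /* Illustrative only: scalar reference for what the vectorized helpers
       compute; the SIMD paths are just faster routes to the same result. */
    #include <math.h>

    /* silu(x) = x / (1 + exp(-x)), applied element-wise */
    static void demo_vec_silu_f32(int n, float * y, const float * x) {
        for (int i = 0; i < n; ++i) {
            y[i] = x[i] / (1.0f + expf(-x[i]));
        }
    }

    /* stores y[i] = exp(x[i] - max) and returns the sum of the stored values;
       a double accumulator stands in for the wider ggml accumulator type */
    static double demo_vec_soft_max_f32(int n, float * y, const float * x, float max) {
        double sum = 0.0;
        for (int i = 0; i < n; ++i) {
            const float v = expf(x[i] - max);
            y[i] = v;
            sum += (double) v;
        }
        return sum;
    }
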
@@ -1967,6 +2390,14 @@ inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_
1967
2390
  *s = sum;
1968
2391
  }
1969
2392
 
2393
+ inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {
2394
+ float sum = 0.0f;
2395
+ for (int i = 0; i < n; ++i) {
2396
+ sum += GGML_BF16_TO_FP32(x[i]);
2397
+ }
2398
+ *s = sum;
2399
+ }
2400
+
1970
2401
  inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
1971
2402
  #ifndef GGML_USE_ACCELERATE
1972
2403
  float max = -INFINITY;
@@ -2045,7 +2476,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2045
2476
  "SOFT_MAX_BACK",
2046
2477
  "ROPE",
2047
2478
  "ROPE_BACK",
2048
- "ALIBI",
2049
2479
  "CLAMP",
2050
2480
  "CONV_TRANSPOSE_1D",
2051
2481
  "IM2COL",
@@ -2087,7 +2517,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2087
2517
  "CROSS_ENTROPY_LOSS_BACK",
2088
2518
  };
2089
2519
 
2090
- static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
2520
+ static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
2091
2521
 
2092
2522
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2093
2523
  "none",
@@ -2136,7 +2566,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2136
2566
  "soft_max_back(x)",
2137
2567
  "rope(x)",
2138
2568
  "rope_back(x)",
2139
- "alibi(x)",
2140
2569
  "clamp(x)",
2141
2570
  "conv_transpose_1d(x)",
2142
2571
  "im2col(x)",
@@ -2178,7 +2607,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2178
2607
  "cross_entropy_loss_back(x,y)",
2179
2608
  };
2180
2609
 
2181
- static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
2610
+ static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
2182
2611
 
2183
2612
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
2184
2613
 
@@ -2191,6 +2620,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
2191
2620
  "TANH",
2192
2621
  "ELU",
2193
2622
  "RELU",
2623
+ "SIGMOID",
2194
2624
  "GELU",
2195
2625
  "GELU_QUICK",
2196
2626
  "SILU",
@@ -2198,7 +2628,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
2198
2628
  "HARDSIGMOID",
2199
2629
  };
2200
2630
 
2201
- static_assert(GGML_UNARY_OP_COUNT == 12, "GGML_UNARY_OP_COUNT != 12");
2631
+ static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
2202
2632
 
2203
2633
 
2204
2634
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -2240,32 +2670,6 @@ static void ggml_setup_op_has_task_pass(void) {
2240
2670
  }
2241
2671
  }
2242
2672
 
2243
- //
2244
- // ggml context
2245
- //
2246
-
2247
- struct ggml_context {
2248
- size_t mem_size;
2249
- void * mem_buffer;
2250
- bool mem_buffer_owned;
2251
- bool no_alloc;
2252
- bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
2253
-
2254
- int n_objects;
2255
-
2256
- struct ggml_object * objects_begin;
2257
- struct ggml_object * objects_end;
2258
-
2259
- struct ggml_scratch scratch;
2260
- struct ggml_scratch scratch_save;
2261
- };
2262
-
2263
- struct ggml_context_container {
2264
- bool used;
2265
-
2266
- struct ggml_context context;
2267
- };
2268
-
2269
2673
  //
2270
2674
  // NUMA support
2271
2675
  //
@@ -2377,7 +2781,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
2377
2781
  // figure out which node we're on
2378
2782
  uint current_cpu;
2379
2783
  int getcpu_ret = 0;
2380
- #if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28)
2784
+ #if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__)
2381
2785
  getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node);
2382
2786
  #else
2383
2787
  // old glibc doesn't have a wrapper for this call. Fall back on direct syscall
@@ -2588,6 +2992,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
2588
2992
  switch (ftype) {
2589
2993
  case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break;
2590
2994
  case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break;
2995
+ case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break;
2591
2996
  case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break;
2592
2997
  case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break;
2593
2998
  case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
@@ -2678,6 +3083,16 @@ bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor
2678
3083
  (t0->ne[3] == t1->ne[3] );
2679
3084
  }
2680
3085
 
3086
+ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
3087
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3088
+
3089
+ return
3090
+ (t0->nb[0] == t1->nb[0] ) &&
3091
+ (t0->nb[1] == t1->nb[1] ) &&
3092
+ (t0->nb[2] == t1->nb[2] ) &&
3093
+ (t0->nb[3] == t1->nb[3] );
3094
+ }
3095
+
2681
3096
  // check if t1 can be represented as a repetition of t0
2682
3097
  static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
2683
3098
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
@@ -2729,15 +3144,14 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
2729
3144
  {
2730
3145
  const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
2731
3146
 
2732
- ggml_fp16_t ii;
2733
3147
  for (int i = 0; i < (1 << 16); ++i) {
2734
- uint16_t ui = i;
2735
- memcpy(&ii, &ui, sizeof(ii));
2736
- const float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
3148
+ union {
3149
+ uint16_t u16;
3150
+ ggml_fp16_t fp16;
3151
+ } u = {i};
3152
+ float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
2737
3153
  ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
2738
3154
  ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
2739
- ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
2740
- ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
2741
3155
  }
2742
3156
 
2743
3157
  const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
@@ -3021,6 +3435,12 @@ static struct ggml_tensor * ggml_new_tensor_impl(
3021
3435
 
3022
3436
  struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
3023
3437
 
3438
+ #ifdef __clang__
3439
+ // temporary until ggml_tensor::backend is removed
3440
+ #pragma clang diagnostic push
3441
+ #pragma clang diagnostic ignored "-Wdeprecated-declarations"
3442
+ #endif
3443
+
3024
3444
  *result = (struct ggml_tensor) {
3025
3445
  /*.type =*/ type,
3026
3446
  /*.backend =*/ GGML_BACKEND_TYPE_CPU,
@@ -3043,6 +3463,10 @@ static struct ggml_tensor * ggml_new_tensor_impl(
3043
3463
  /*.padding =*/ { 0 },
3044
3464
  };
3045
3465
 
3466
+ #ifdef __clang__
3467
+ #pragma clang diagnostic pop
3468
+ #endif
3469
+
3046
3470
  // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
3047
3471
  //ggml_assert_aligned(result->data);
3048
3472
 
@@ -3201,6 +3625,13 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
3201
3625
  ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
3202
3626
  }
3203
3627
  } break;
3628
+ case GGML_TYPE_BF16:
3629
+ {
3630
+ assert(tensor->nb[0] == sizeof(ggml_bf16_t));
3631
+ for (int i = 0; i < n; i++) {
3632
+ ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
3633
+ }
3634
+ } break;
3204
3635
  case GGML_TYPE_F32:
3205
3636
  {
3206
3637
  assert(tensor->nb[0] == sizeof(float));
@@ -3253,6 +3684,13 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3253
3684
  ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
3254
3685
  }
3255
3686
  } break;
3687
+ case GGML_TYPE_BF16:
3688
+ {
3689
+ assert(tensor->nb[0] == sizeof(ggml_bf16_t));
3690
+ for (int i = 0; i < n; i++) {
3691
+ ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
3692
+ }
3693
+ } break;
3256
3694
  case GGML_TYPE_F32:
3257
3695
  {
3258
3696
  assert(tensor->nb[0] == sizeof(float));
@@ -3320,6 +3758,11 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3320
3758
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3321
3759
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
3322
3760
  }
3761
+ case GGML_TYPE_BF16:
3762
+ {
3763
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3764
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
3765
+ }
3323
3766
  case GGML_TYPE_F32:
3324
3767
  {
3325
3768
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3362,6 +3805,11 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3362
3805
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3363
3806
  ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
3364
3807
  } break;
3808
+ case GGML_TYPE_BF16:
3809
+ {
3810
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3811
+ ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
3812
+ } break;
3365
3813
  case GGML_TYPE_F32:
3366
3814
  {
3367
3815
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3385,6 +3833,8 @@ int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i
3385
3833
  return ((int32_t *) data)[0];
3386
3834
  case GGML_TYPE_F16:
3387
3835
  return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
3836
+ case GGML_TYPE_BF16:
3837
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
3388
3838
  case GGML_TYPE_F32:
3389
3839
  return ((float *) data)[0];
3390
3840
  default:
@@ -3413,6 +3863,10 @@ void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3413
3863
  {
3414
3864
  ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
3415
3865
  } break;
3866
+ case GGML_TYPE_BF16:
3867
+ {
3868
+ ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
3869
+ } break;
3416
3870
  case GGML_TYPE_F32:
3417
3871
  {
3418
3872
  ((float *)(data))[0] = value;
@@ -3451,6 +3905,11 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3451
3905
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3452
3906
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
3453
3907
  }
3908
+ case GGML_TYPE_BF16:
3909
+ {
3910
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3911
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
3912
+ }
3454
3913
  case GGML_TYPE_F32:
3455
3914
  {
3456
3915
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3493,6 +3952,11 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
3493
3952
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3494
3953
  ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
3495
3954
  } break;
3955
+ case GGML_TYPE_BF16:
3956
+ {
3957
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3958
+ ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
3959
+ } break;
3496
3960
  case GGML_TYPE_F32:
3497
3961
  {
3498
3962
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3516,6 +3980,8 @@ float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3516
3980
  return ((int32_t *) data)[0];
3517
3981
  case GGML_TYPE_F16:
3518
3982
  return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
3983
+ case GGML_TYPE_BF16:
3984
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
3519
3985
  case GGML_TYPE_F32:
3520
3986
  return ((float *) data)[0];
3521
3987
  default:
@@ -3544,6 +4010,10 @@ void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3544
4010
  {
3545
4011
  ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
3546
4012
  } break;
4013
+ case GGML_TYPE_BF16:
4014
+ {
4015
+ ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
4016
+ } break;
3547
4017
  case GGML_TYPE_F32:
3548
4018
  {
3549
4019
  ((float *)(data))[0] = value;
@@ -3738,7 +4208,11 @@ static struct ggml_tensor * ggml_add_cast_impl(
3738
4208
  // TODO: support less-strict constraint
3739
4209
  // GGML_ASSERT(ggml_can_repeat(b, a));
3740
4210
  GGML_ASSERT(ggml_can_repeat_rows(b, a));
3741
- GGML_ASSERT(ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16); // currently only supported for quantized input and f16
4211
+
4212
+ // currently only supported for quantized input and f16
4213
+ GGML_ASSERT(ggml_is_quantized(a->type) ||
4214
+ a->type == GGML_TYPE_F16 ||
4215
+ a->type == GGML_TYPE_BF16);
3742
4216
 
3743
4217
  bool is_node = false;
3744
4218
 
@@ -4371,6 +4845,20 @@ struct ggml_tensor * ggml_leaky_relu(
4371
4845
  return result;
4372
4846
  }
4373
4847
 
4848
+ // ggml_sigmoid
4849
+
4850
+ struct ggml_tensor * ggml_sigmoid(
4851
+ struct ggml_context * ctx,
4852
+ struct ggml_tensor * a) {
4853
+ return ggml_unary(ctx, a, GGML_UNARY_OP_SIGMOID);
4854
+ }
4855
+
4856
+ struct ggml_tensor * ggml_sigmoid_inplace(
4857
+ struct ggml_context * ctx,
4858
+ struct ggml_tensor * a) {
4859
+ return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID);
4860
+ }
4861
+
4374
4862
  // ggml_gelu
4375
4863
 
4376
4864
  struct ggml_tensor * ggml_gelu(
@@ -5454,7 +5942,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
5454
5942
  struct ggml_context * ctx,
5455
5943
  struct ggml_tensor * a,
5456
5944
  struct ggml_tensor * mask,
5457
- struct ggml_tensor * pos,
5458
5945
  float scale,
5459
5946
  float max_bias,
5460
5947
  bool inplace) {
@@ -5468,18 +5955,8 @@ static struct ggml_tensor * ggml_soft_max_impl(
5468
5955
  GGML_ASSERT(mask->ne[1] >= a->ne[1]);
5469
5956
  }
5470
5957
 
5471
- if (pos) {
5472
- GGML_ASSERT(ggml_is_vector(pos));
5473
- GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
5474
- GGML_ASSERT(pos->ne[0] == a->ne[0]);
5475
- }
5476
-
5477
- if (pos && mask) {
5478
- GGML_ASSERT(pos->type == mask->type);
5479
- }
5480
-
5481
5958
  if (max_bias > 0.0f) {
5482
- GGML_ASSERT(pos);
5959
+ GGML_ASSERT(mask);
5483
5960
  }
5484
5961
 
5485
5962
  bool is_node = false;
@@ -5497,7 +5974,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
5497
5974
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5498
5975
  result->src[0] = a;
5499
5976
  result->src[1] = mask;
5500
- result->src[2] = pos;
5501
5977
 
5502
5978
  return result;
5503
5979
  }
@@ -5505,23 +5981,22 @@ static struct ggml_tensor * ggml_soft_max_impl(
5505
5981
  struct ggml_tensor * ggml_soft_max(
5506
5982
  struct ggml_context * ctx,
5507
5983
  struct ggml_tensor * a) {
5508
- return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, false);
5984
+ return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false);
5509
5985
  }
5510
5986
 
5511
5987
  struct ggml_tensor * ggml_soft_max_inplace(
5512
5988
  struct ggml_context * ctx,
5513
5989
  struct ggml_tensor * a) {
5514
- return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, true);
5990
+ return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true);
5515
5991
  }
5516
5992
 
5517
5993
  struct ggml_tensor * ggml_soft_max_ext(
5518
5994
  struct ggml_context * ctx,
5519
5995
  struct ggml_tensor * a,
5520
5996
  struct ggml_tensor * mask,
5521
- struct ggml_tensor * pos,
5522
5997
  float scale,
5523
5998
  float max_bias) {
5524
- return ggml_soft_max_impl(ctx, a, mask, pos, scale, max_bias, false);
5999
+ return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
5525
6000
  }
5526
6001
 
5527
6002
  // ggml_soft_max_back
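
Aside, not part of the upstream diff: the `pos` tensor is gone from ggml_soft_max_ext, and the assert above now ties the ALiBi-style path to `mask` plus `max_bias` (a mask is required whenever max_bias > 0). A hedged call sketch with hypothetical tensor names:

    /* Illustrative only: updated call shape after the pos argument was dropped. */
    #include "ggml.h"

    struct ggml_tensor * demo_attn_softmax(struct ggml_context * ctx,
                                           struct ggml_tensor  * kq,
                                           struct ggml_tensor  * kq_mask,
                                           float scale, float max_bias) {
        /* previously: ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, scale, max_bias) */
        return ggml_soft_max_ext(ctx, kq, kq_mask, scale, max_bias);
    }
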
@@ -5736,37 +6211,6 @@ struct ggml_tensor * ggml_rope_back(
5736
6211
  return result;
5737
6212
  }
5738
6213
 
5739
- // ggml_alibi
5740
-
5741
- struct ggml_tensor * ggml_alibi(
5742
- struct ggml_context * ctx,
5743
- struct ggml_tensor * a,
5744
- int n_past,
5745
- int n_head,
5746
- float bias_max) {
5747
- GGML_ASSERT(n_past >= 0);
5748
- bool is_node = false;
5749
-
5750
- if (a->grad) {
5751
- GGML_ASSERT(false); // TODO: implement backward
5752
- is_node = true;
5753
- }
5754
-
5755
- // TODO: when implement backward, fix this:
5756
- //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5757
- struct ggml_tensor * result = ggml_view_tensor(ctx, a);
5758
-
5759
- int32_t op_params[3] = { n_past, n_head };
5760
- memcpy(op_params + 2, &bias_max, sizeof(float));
5761
- ggml_set_op_params(result, op_params, sizeof(op_params));
5762
-
5763
- result->op = GGML_OP_ALIBI;
5764
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5765
- result->src[0] = a;
5766
-
5767
- return result;
5768
- }
5769
-
5770
6214
  // ggml_clamp
5771
6215
 
5772
6216
  struct ggml_tensor * ggml_clamp(
@@ -6116,7 +6560,10 @@ struct ggml_tensor * ggml_pool_2d(
6116
6560
  static struct ggml_tensor * ggml_upscale_impl(
6117
6561
  struct ggml_context * ctx,
6118
6562
  struct ggml_tensor * a,
6119
- int scale_factor) {
6563
+ int ne0,
6564
+ int ne1,
6565
+ int ne2,
6566
+ int ne3) {
6120
6567
  bool is_node = false;
6121
6568
 
6122
6569
  if (a->grad) {
@@ -6124,19 +6571,45 @@ static struct ggml_tensor * ggml_upscale_impl(
6124
6571
  is_node = true;
6125
6572
  }
6126
6573
 
6574
+ GGML_ASSERT(a->ne[0] <= ne0);
6575
+ GGML_ASSERT(a->ne[1] <= ne1);
6576
+ GGML_ASSERT(a->ne[2] <= ne2);
6577
+ GGML_ASSERT(a->ne[3] <= ne3);
6578
+
6127
6579
  struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
6128
- a->ne[0] * scale_factor,
6129
- a->ne[1] * scale_factor,
6130
- a->ne[2], a->ne[3]);
6580
+ ne0,
6581
+ ne1,
6582
+ ne2,
6583
+ ne3
6584
+ );
6131
6585
 
6132
6586
  result->op = GGML_OP_UPSCALE;
6133
- result->op_params[0] = scale_factor;
6587
+
6134
6588
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6135
6589
  result->src[0] = a;
6136
6590
 
6137
6591
  return result;
6138
6592
  }
6139
6593
 
6594
+ struct ggml_tensor * ggml_upscale(
6595
+ struct ggml_context * ctx,
6596
+ struct ggml_tensor * a,
6597
+ int scale_factor) {
6598
+ return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
6599
+ }
6600
+
6601
+ struct ggml_tensor * ggml_upscale_ext(
6602
+ struct ggml_context * ctx,
6603
+ struct ggml_tensor * a,
6604
+ int ne0,
6605
+ int ne1,
6606
+ int ne2,
6607
+ int ne3) {
6608
+ return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
6609
+ }
6610
+
6611
+ // ggml_pad
6612
+
6140
6613
  struct ggml_tensor * ggml_pad(
6141
6614
  struct ggml_context * ctx,
6142
6615
  struct ggml_tensor * a,
@@ -6161,12 +6634,7 @@ struct ggml_tensor * ggml_pad(
6161
6634
  return result;
6162
6635
  }
6163
6636
 
6164
- struct ggml_tensor * ggml_upscale(
6165
- struct ggml_context * ctx,
6166
- struct ggml_tensor * a,
6167
- int scale_factor) {
6168
- return ggml_upscale_impl(ctx, a, scale_factor);
6169
- }
6637
+ // ggml_arange
6170
6638
 
6171
6639
  struct ggml_tensor * ggml_arange(
6172
6640
  struct ggml_context * ctx,
@@ -6188,6 +6656,8 @@ struct ggml_tensor * ggml_arange(
6188
6656
  return result;
6189
6657
  }
6190
6658
 
6659
+ // ggml_timestep_embedding
6660
+
6191
6661
  struct ggml_tensor * ggml_timestep_embedding(
6192
6662
  struct ggml_context * ctx,
6193
6663
  struct ggml_tensor * timesteps,
@@ -6294,9 +6764,11 @@ struct ggml_tensor * ggml_flash_attn_ext(
6294
6764
  struct ggml_tensor * k,
6295
6765
  struct ggml_tensor * v,
6296
6766
  struct ggml_tensor * mask,
6297
- float scale) {
6767
+ float scale,
6768
+ float max_bias) {
6298
6769
  GGML_ASSERT(ggml_can_mul_mat(k, q));
6299
6770
  // TODO: check if vT can be multiplied by (k*qT)
6771
+
6300
6772
  if (mask) {
6301
6773
  GGML_ASSERT(ggml_is_contiguous(mask));
6302
6774
  GGML_ASSERT(mask->ne[2] == 1);
@@ -6306,6 +6778,10 @@ struct ggml_tensor * ggml_flash_attn_ext(
6306
6778
  //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
6307
6779
  }
6308
6780
 
6781
+ if (max_bias > 0.0f) {
6782
+ GGML_ASSERT(mask);
6783
+ }
6784
+
6309
6785
  bool is_node = false;
6310
6786
 
6311
6787
  if (q->grad || k->grad || v->grad) {
@@ -6316,7 +6792,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
6316
6792
  int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
6317
6793
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
6318
6794
 
6319
- float params[] = { scale };
6795
+ float params[] = { scale, max_bias };
6320
6796
  ggml_set_op_params(result, params, sizeof(params));
6321
6797
 
6322
6798
  result->op = GGML_OP_FLASH_ATTN_EXT;
@@ -6336,7 +6812,7 @@ void ggml_flash_attn_ext_set_prec(
6336
6812
 
6337
6813
  const int32_t prec_i32 = (int32_t) prec;
6338
6814
 
6339
- ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
6815
+ ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
6340
6816
  }
6341
6817
 
6342
6818
  // ggml_flash_ff
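
Aside, not part of the upstream diff: ggml_flash_attn_ext now takes max_bias after scale, and ggml_flash_attn_ext_set_prec writes the precision into op-param slot 2 because slots 0 and 1 hold scale and max_bias. A hedged call sketch with hypothetical tensor names; passing 0.0f for max_bias keeps the previous behaviour with no ALiBi-style bias.

    /* Illustrative only: updated call shape with the extra max_bias parameter. */
    #include "ggml.h"

    struct ggml_tensor * demo_flash_attn(struct ggml_context * ctx,
                                         struct ggml_tensor  * q,
                                         struct ggml_tensor  * k,
                                         struct ggml_tensor  * v,
                                         struct ggml_tensor  * kq_mask,
                                         float kq_scale) {
        /* previously: ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale) */
        return ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, 0.0f);
    }
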
@@ -7215,8 +7691,8 @@ static void ggml_compute_forward_dup_same_cont(
7215
7691
  ((char *) src0->data + ie0*nb00),
7216
7692
  (ie1 - ie0) * ggml_type_size(src0->type));
7217
7693
  }
7218
-
7219
7694
  }
7695
+
7220
7696
  static void ggml_compute_forward_dup_f16(
7221
7697
  const struct ggml_compute_params * params,
7222
7698
  struct ggml_tensor * dst) {
@@ -7490,6 +7966,366 @@ static void ggml_compute_forward_dup_f16(
7490
7966
  }
7491
7967
  }
7492
7968
 
7969
+ static void ggml_compute_forward_dup_bf16(
7970
+ const struct ggml_compute_params * params,
7971
+ struct ggml_tensor * dst) {
7972
+
7973
+ const struct ggml_tensor * src0 = dst->src[0];
7974
+
7975
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
7976
+
7977
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
7978
+ return;
7979
+ }
7980
+
7981
+ GGML_TENSOR_UNARY_OP_LOCALS
7982
+
7983
+ const int ith = params->ith; // thread index
7984
+ const int nth = params->nth; // number of threads
7985
+
7986
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
7987
+ ggml_compute_forward_dup_same_cont(params, dst);
7988
+ return;
7989
+ }
7990
+
7991
+ // parallelize by rows
7992
+ const int nr = ne01;
7993
+ // number of rows per thread
7994
+ const int dr = (nr + nth - 1) / nth;
7995
+ // row range for this thread
7996
+ const int ir0 = dr * ith;
7997
+ const int ir1 = MIN(ir0 + dr, nr);
7998
+
7999
+ if (src0->type == dst->type &&
8000
+ ne00 == ne0 &&
8001
+ nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
8002
+ // copy by rows
8003
+ const size_t rs = ne00*nb00;
8004
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8005
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8006
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8007
+ memcpy(
8008
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
8009
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
8010
+ rs);
8011
+ }
8012
+ }
8013
+ }
8014
+ return;
8015
+ }
8016
+
8017
+ // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
8018
+
8019
+ if (ggml_is_contiguous(dst)) {
8020
+ if (nb00 == sizeof(ggml_bf16_t)) {
8021
+ if (dst->type == GGML_TYPE_BF16) {
8022
+ size_t id = 0;
8023
+ const size_t rs = ne00 * nb00;
8024
+ char * dst_ptr = (char *) dst->data;
8025
+
8026
+ for (int i03 = 0; i03 < ne03; i03++) {
8027
+ for (int i02 = 0; i02 < ne02; i02++) {
8028
+ id += rs * ir0;
8029
+ for (int i01 = ir0; i01 < ir1; i01++) {
8030
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
8031
+ memcpy(dst_ptr + id, src0_ptr, rs);
8032
+ id += rs;
8033
+ }
8034
+ id += rs * (ne01 - ir1);
8035
+ }
8036
+ }
8037
+ } else if (dst->type == GGML_TYPE_F16) {
8038
+ size_t id = 0;
8039
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
8040
+
8041
+ for (int i03 = 0; i03 < ne03; i03++) {
8042
+ for (int i02 = 0; i02 < ne02; i02++) {
8043
+ id += ne00 * ir0;
8044
+ for (int i01 = ir0; i01 < ir1; i01++) {
8045
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
8046
+ for (int i00 = 0; i00 < ne00; i00++) {
8047
+ dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
8048
+ id++;
8049
+ }
8050
+ }
8051
+ id += ne00 * (ne01 - ir1);
8052
+ }
8053
+ }
8054
+ } else if (dst->type == GGML_TYPE_F32) {
8055
+ size_t id = 0;
8056
+ float * dst_ptr = (float *) dst->data;
8057
+
8058
+ for (int i03 = 0; i03 < ne03; i03++) {
8059
+ for (int i02 = 0; i02 < ne02; i02++) {
8060
+ id += ne00 * ir0;
8061
+ for (int i01 = ir0; i01 < ir1; i01++) {
8062
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
8063
+ for (int i00 = 0; i00 < ne00; i00++) {
8064
+ dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
8065
+ id++;
8066
+ }
8067
+ }
8068
+ id += ne00 * (ne01 - ir1);
8069
+ }
8070
+ }
8071
+ } else if (type_traits[dst->type].from_float) {
8072
+ ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
8073
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
8074
+
8075
+ size_t id = 0;
8076
+ size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
8077
+ char * dst_ptr = (char *) dst->data;
8078
+
8079
+ for (int i03 = 0; i03 < ne03; i03++) {
8080
+ for (int i02 = 0; i02 < ne02; i02++) {
8081
+ id += rs * ir0;
8082
+ for (int i01 = ir0; i01 < ir1; i01++) {
8083
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
8084
+
8085
+ for (int i00 = 0; i00 < ne00; i00++) {
8086
+ src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
8087
+ }
8088
+
8089
+ quantize_row_q(src0_f32, dst_ptr + id, ne00);
8090
+ id += rs;
8091
+ }
8092
+ id += rs * (ne01 - ir1);
8093
+ }
8094
+ }
8095
+ } else {
8096
+ GGML_ASSERT(false); // TODO: implement
8097
+ }
8098
+ } else {
8099
+ //printf("%s: this is not optimal - fix me\n", __func__);
8100
+
8101
+ if (dst->type == GGML_TYPE_F32) {
8102
+ size_t id = 0;
8103
+ float * dst_ptr = (float *) dst->data;
8104
+
8105
+ for (int i03 = 0; i03 < ne03; i03++) {
8106
+ for (int i02 = 0; i02 < ne02; i02++) {
8107
+ id += ne00 * ir0;
8108
+ for (int i01 = ir0; i01 < ir1; i01++) {
8109
+ for (int i00 = 0; i00 < ne00; i00++) {
8110
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8111
+
8112
+ dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
8113
+ id++;
8114
+ }
8115
+ }
8116
+ id += ne00 * (ne01 - ir1);
8117
+ }
8118
+ }
8119
+ } else if (dst->type == GGML_TYPE_BF16) {
8120
+ size_t id = 0;
8121
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
8122
+
8123
+ for (int i03 = 0; i03 < ne03; i03++) {
8124
+ for (int i02 = 0; i02 < ne02; i02++) {
8125
+ id += ne00 * ir0;
8126
+ for (int i01 = ir0; i01 < ir1; i01++) {
8127
+ for (int i00 = 0; i00 < ne00; i00++) {
8128
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8129
+
8130
+ dst_ptr[id] = *src0_ptr;
8131
+ id++;
8132
+ }
8133
+ }
8134
+ id += ne00 * (ne01 - ir1);
8135
+ }
8136
+ }
8137
+ } else if (dst->type == GGML_TYPE_F16) {
8138
+ size_t id = 0;
8139
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
8140
+
8141
+ for (int i03 = 0; i03 < ne03; i03++) {
8142
+ for (int i02 = 0; i02 < ne02; i02++) {
8143
+ id += ne00 * ir0;
8144
+ for (int i01 = ir0; i01 < ir1; i01++) {
8145
+ for (int i00 = 0; i00 < ne00; i00++) {
8146
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8147
+
8148
+ dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
8149
+ id++;
8150
+ }
8151
+ }
8152
+ id += ne00 * (ne01 - ir1);
8153
+ }
8154
+ }
8155
+ } else {
8156
+ GGML_ASSERT(false); // TODO: implement
8157
+ }
8158
+ }
8159
+ return;
8160
+ }
8161
+
8162
+ // dst counters
8163
+ int64_t i10 = 0;
8164
+ int64_t i11 = 0;
8165
+ int64_t i12 = 0;
8166
+ int64_t i13 = 0;
8167
+
8168
+ if (dst->type == GGML_TYPE_BF16) {
8169
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8170
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8171
+ i10 += ne00 * ir0;
8172
+ while (i10 >= ne0) {
8173
+ i10 -= ne0;
8174
+ if (++i11 == ne1) {
8175
+ i11 = 0;
8176
+ if (++i12 == ne2) {
8177
+ i12 = 0;
8178
+ if (++i13 == ne3) {
8179
+ i13 = 0;
8180
+ }
8181
+ }
8182
+ }
8183
+ }
8184
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8185
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8186
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8187
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8188
+
8189
+ memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
8190
+
8191
+ if (++i10 == ne00) {
8192
+ i10 = 0;
8193
+ if (++i11 == ne01) {
8194
+ i11 = 0;
8195
+ if (++i12 == ne02) {
8196
+ i12 = 0;
8197
+ if (++i13 == ne03) {
8198
+ i13 = 0;
8199
+ }
8200
+ }
8201
+ }
8202
+ }
8203
+ }
8204
+ }
8205
+ i10 += ne00 * (ne01 - ir1);
8206
+ while (i10 >= ne0) {
8207
+ i10 -= ne0;
8208
+ if (++i11 == ne1) {
8209
+ i11 = 0;
8210
+ if (++i12 == ne2) {
8211
+ i12 = 0;
8212
+ if (++i13 == ne3) {
8213
+ i13 = 0;
8214
+ }
8215
+ }
8216
+ }
8217
+ }
8218
+ }
8219
+ }
8220
+ } else if (dst->type == GGML_TYPE_F16) {
8221
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8222
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8223
+ i10 += ne00 * ir0;
8224
+ while (i10 >= ne0) {
8225
+ i10 -= ne0;
8226
+ if (++i11 == ne1) {
8227
+ i11 = 0;
8228
+ if (++i12 == ne2) {
8229
+ i12 = 0;
8230
+ if (++i13 == ne3) {
8231
+ i13 = 0;
8232
+ }
8233
+ }
8234
+ }
8235
+ }
8236
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8237
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8238
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8239
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8240
+
8241
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
8242
+
8243
+ if (++i10 == ne0) {
8244
+ i10 = 0;
8245
+ if (++i11 == ne1) {
8246
+ i11 = 0;
8247
+ if (++i12 == ne2) {
8248
+ i12 = 0;
8249
+ if (++i13 == ne3) {
8250
+ i13 = 0;
8251
+ }
8252
+ }
8253
+ }
8254
+ }
8255
+ }
8256
+ }
8257
+ i10 += ne00 * (ne01 - ir1);
8258
+ while (i10 >= ne0) {
8259
+ i10 -= ne0;
8260
+ if (++i11 == ne1) {
8261
+ i11 = 0;
8262
+ if (++i12 == ne2) {
8263
+ i12 = 0;
8264
+ if (++i13 == ne3) {
8265
+ i13 = 0;
8266
+ }
8267
+ }
8268
+ }
8269
+ }
8270
+ }
8271
+ }
8272
+ } else if (dst->type == GGML_TYPE_F32) {
8273
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8274
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8275
+ i10 += ne00 * ir0;
8276
+ while (i10 >= ne0) {
8277
+ i10 -= ne0;
8278
+ if (++i11 == ne1) {
8279
+ i11 = 0;
8280
+ if (++i12 == ne2) {
8281
+ i12 = 0;
8282
+ if (++i13 == ne3) {
8283
+ i13 = 0;
8284
+ }
8285
+ }
8286
+ }
8287
+ }
8288
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8289
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8290
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8291
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8292
+
8293
+ *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
8294
+
8295
+ if (++i10 == ne0) {
8296
+ i10 = 0;
8297
+ if (++i11 == ne1) {
8298
+ i11 = 0;
8299
+ if (++i12 == ne2) {
8300
+ i12 = 0;
8301
+ if (++i13 == ne3) {
8302
+ i13 = 0;
8303
+ }
8304
+ }
8305
+ }
8306
+ }
8307
+ }
8308
+ }
8309
+ i10 += ne00 * (ne01 - ir1);
8310
+ while (i10 >= ne0) {
8311
+ i10 -= ne0;
8312
+ if (++i11 == ne1) {
8313
+ i11 = 0;
8314
+ if (++i12 == ne2) {
8315
+ i12 = 0;
8316
+ if (++i13 == ne3) {
8317
+ i13 = 0;
8318
+ }
8319
+ }
8320
+ }
8321
+ }
8322
+ }
8323
+ }
8324
+ } else {
8325
+ GGML_ASSERT(false); // TODO: implement
8326
+ }
8327
+ }
8328
+
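As an aside for the new BF16 code paths above: bf16 is just the upper 16 bits of an IEEE-754 binary32 value, which is why the GGML_BF16_TO_FP32 / GGML_FP32_TO_BF16 conversions sprinkled through these kernels are cheap. The sketch below uses hypothetical stand-in names rather than the real macros, and it narrows by plain truncation, whereas a production conversion would normally round to nearest even and special-case NaN.

#include <stdint.h>
#include <string.h>

/* hypothetical stand-ins for the bf16 <-> f32 conversions used above */
typedef struct { uint16_t bits; } bf16_t;

static inline float bf16_to_f32(bf16_t h) {
    uint32_t u = (uint32_t) h.bits << 16;   /* bf16 payload is the high half of a binary32 */
    float f;
    memcpy(&f, &u, sizeof(f));
    return f;
}

static inline bf16_t f32_to_bf16(float f) {
    uint32_t u;
    memcpy(&u, &f, sizeof(u));
    bf16_t h = { (uint16_t)(u >> 16) };     /* drop the low 16 mantissa bits (truncation) */
    return h;
}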
7493
8329
  static void ggml_compute_forward_dup_f32(
7494
8330
  const struct ggml_compute_params * params,
7495
8331
  struct ggml_tensor * dst) {
@@ -7596,43 +8432,113 @@ static void ggml_compute_forward_dup_f32(
7596
8432
  id++;
7597
8433
  }
7598
8434
  }
7599
- id += ne00 * (ne01 - ir1);
8435
+ id += ne00 * (ne01 - ir1);
8436
+ }
8437
+ }
8438
+ } else if (dst->type == GGML_TYPE_F16) {
8439
+ size_t id = 0;
8440
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
8441
+
8442
+ for (int i03 = 0; i03 < ne03; i03++) {
8443
+ for (int i02 = 0; i02 < ne02; i02++) {
8444
+ id += ne00 * ir0;
8445
+ for (int i01 = ir0; i01 < ir1; i01++) {
8446
+ for (int i00 = 0; i00 < ne00; i00++) {
8447
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8448
+
8449
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
8450
+ id++;
8451
+ }
8452
+ }
8453
+ id += ne00 * (ne01 - ir1);
8454
+ }
8455
+ }
8456
+ } else if (dst->type == GGML_TYPE_BF16) {
8457
+ size_t id = 0;
8458
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
8459
+
8460
+ for (int i03 = 0; i03 < ne03; i03++) {
8461
+ for (int i02 = 0; i02 < ne02; i02++) {
8462
+ id += ne00 * ir0;
8463
+ for (int i01 = ir0; i01 < ir1; i01++) {
8464
+ for (int i00 = 0; i00 < ne00; i00++) {
8465
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8466
+
8467
+ dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
8468
+ id++;
8469
+ }
8470
+ }
8471
+ id += ne00 * (ne01 - ir1);
8472
+ }
8473
+ }
8474
+ } else {
8475
+ GGML_ASSERT(false); // TODO: implement
8476
+ }
8477
+ }
8478
+
8479
+ return;
8480
+ }
8481
+
8482
+ // dst counters
8483
+
8484
+ int64_t i10 = 0;
8485
+ int64_t i11 = 0;
8486
+ int64_t i12 = 0;
8487
+ int64_t i13 = 0;
8488
+
8489
+ if (dst->type == GGML_TYPE_F32) {
8490
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8491
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8492
+ i10 += ne00 * ir0;
8493
+ while (i10 >= ne0) {
8494
+ i10 -= ne0;
8495
+ if (++i11 == ne1) {
8496
+ i11 = 0;
8497
+ if (++i12 == ne2) {
8498
+ i12 = 0;
8499
+ if (++i13 == ne3) {
8500
+ i13 = 0;
8501
+ }
8502
+ }
8503
+ }
8504
+ }
8505
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8506
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8507
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8508
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8509
+
8510
+ memcpy(dst_ptr, src0_ptr, sizeof(float));
8511
+
8512
+ if (++i10 == ne0) {
8513
+ i10 = 0;
8514
+ if (++i11 == ne1) {
8515
+ i11 = 0;
8516
+ if (++i12 == ne2) {
8517
+ i12 = 0;
8518
+ if (++i13 == ne3) {
8519
+ i13 = 0;
8520
+ }
8521
+ }
8522
+ }
8523
+ }
7600
8524
  }
7601
8525
  }
7602
- } else if (dst->type == GGML_TYPE_F16) {
7603
- size_t id = 0;
7604
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
7605
-
7606
- for (int i03 = 0; i03 < ne03; i03++) {
7607
- for (int i02 = 0; i02 < ne02; i02++) {
7608
- id += ne00 * ir0;
7609
- for (int i01 = ir0; i01 < ir1; i01++) {
7610
- for (int i00 = 0; i00 < ne00; i00++) {
7611
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7612
-
7613
- dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
7614
- id++;
8526
+ i10 += ne00 * (ne01 - ir1);
8527
+ while (i10 >= ne0) {
8528
+ i10 -= ne0;
8529
+ if (++i11 == ne1) {
8530
+ i11 = 0;
8531
+ if (++i12 == ne2) {
8532
+ i12 = 0;
8533
+ if (++i13 == ne3) {
8534
+ i13 = 0;
7615
8535
  }
7616
8536
  }
7617
- id += ne00 * (ne01 - ir1);
7618
8537
  }
7619
8538
  }
7620
- } else {
7621
- GGML_ASSERT(false); // TODO: implement
7622
8539
  }
7623
8540
  }
7624
-
7625
- return;
7626
- }
7627
-
7628
- // dst counters
7629
-
7630
- int64_t i10 = 0;
7631
- int64_t i11 = 0;
7632
- int64_t i12 = 0;
7633
- int64_t i13 = 0;
7634
-
7635
- if (dst->type == GGML_TYPE_F32) {
8541
+ } else if (dst->type == GGML_TYPE_F16) {
7636
8542
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7637
8543
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7638
8544
  i10 += ne00 * ir0;
@@ -7653,7 +8559,7 @@ static void ggml_compute_forward_dup_f32(
7653
8559
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7654
8560
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
7655
8561
 
7656
- memcpy(dst_ptr, src0_ptr, sizeof(float));
8562
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
7657
8563
 
7658
8564
  if (++i10 == ne0) {
7659
8565
  i10 = 0;
@@ -7684,7 +8590,7 @@ static void ggml_compute_forward_dup_f32(
7684
8590
  }
7685
8591
  }
7686
8592
  }
7687
- } else if (dst->type == GGML_TYPE_F16) {
8593
+ } else if (dst->type == GGML_TYPE_BF16) {
7688
8594
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7689
8595
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7690
8596
  i10 += ne00 * ir0;
@@ -7705,7 +8611,7 @@ static void ggml_compute_forward_dup_f32(
7705
8611
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7706
8612
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
7707
8613
 
7708
- *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
8614
+ *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
7709
8615
 
7710
8616
  if (++i10 == ne0) {
7711
8617
  i10 = 0;
@@ -7909,6 +8815,10 @@ static void ggml_compute_forward_dup(
7909
8815
  {
7910
8816
  ggml_compute_forward_dup_f16(params, dst);
7911
8817
  } break;
8818
+ case GGML_TYPE_BF16:
8819
+ {
8820
+ ggml_compute_forward_dup_bf16(params, dst);
8821
+ } break;
7912
8822
  case GGML_TYPE_F32:
7913
8823
  {
7914
8824
  ggml_compute_forward_dup_f32(params, dst);
@@ -8091,6 +9001,85 @@ static void ggml_compute_forward_add_f16_f32(
8091
9001
  }
8092
9002
  }
8093
9003
 
9004
+ static void ggml_compute_forward_add_bf16_f32(
9005
+ const struct ggml_compute_params * params,
9006
+ struct ggml_tensor * dst) {
9007
+
9008
+ const struct ggml_tensor * src0 = dst->src[0];
9009
+ const struct ggml_tensor * src1 = dst->src[1];
9010
+
9011
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
9012
+
9013
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
9014
+ return;
9015
+ }
9016
+
9017
+ const int ith = params->ith;
9018
+ const int nth = params->nth;
9019
+
9020
+ const int nr = ggml_nrows(src0);
9021
+
9022
+ GGML_TENSOR_BINARY_OP_LOCALS
9023
+
9024
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
9025
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
9026
+
9027
+ if (dst->type == GGML_TYPE_F32) {
9028
+ GGML_ASSERT( nb0 == sizeof(float));
9029
+ }
9030
+ else {
9031
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
9032
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
9033
+ }
9034
+
9035
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
9036
+
9037
+ // rows per thread
9038
+ const int dr = (nr + nth - 1)/nth;
9039
+
9040
+ // row range for this thread
9041
+ const int ir0 = dr*ith;
9042
+ const int ir1 = MIN(ir0 + dr, nr);
9043
+
9044
+ if (nb10 == sizeof(float)) {
9045
+ if (dst->type == GGML_TYPE_BF16) {
9046
+ for (int ir = ir0; ir < ir1; ++ir) {
9047
+ // src0, src1 and dst are same shape => same indices
9048
+ const int i3 = ir/(ne2*ne1);
9049
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9050
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9051
+
9052
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
9053
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9054
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
9055
+
9056
+ for (int i = 0; i < ne0; i++) {
9057
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
9058
+ }
9059
+ }
9060
+ } else {
9061
+ for (int ir = ir0; ir < ir1; ++ir) {
9062
+ // src0, src1 and dst are same shape => same indices
9063
+ const int i3 = ir/(ne2*ne1);
9064
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9065
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9066
+
9067
+ float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
9068
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9069
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
9070
+
9071
+ for (int i = 0; i < ne0; i++) {
9072
+ dst_ptr[i] = GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
9073
+ }
9074
+ }
9075
+ }
9076
+ }
9077
+ else {
9078
+ // src1 is not contiguous
9079
+ GGML_ASSERT(false);
9080
+ }
9081
+ }
9082
+
8094
9083
  static void ggml_compute_forward_add_f16_f16(
8095
9084
  const struct ggml_compute_params * params,
8096
9085
  struct ggml_tensor * dst) {
@@ -8147,6 +9136,62 @@ static void ggml_compute_forward_add_f16_f16(
8147
9136
  }
8148
9137
  }
8149
9138
 
9139
+ static void ggml_compute_forward_add_bf16_bf16(
9140
+ const struct ggml_compute_params * params,
9141
+ struct ggml_tensor * dst) {
9142
+
9143
+ const struct ggml_tensor * src0 = dst->src[0];
9144
+ const struct ggml_tensor * src1 = dst->src[1];
9145
+
9146
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
9147
+
9148
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
9149
+ return;
9150
+ }
9151
+
9152
+ const int ith = params->ith;
9153
+ const int nth = params->nth;
9154
+
9155
+ const int nr = ggml_nrows(src0);
9156
+
9157
+ GGML_TENSOR_BINARY_OP_LOCALS
9158
+
9159
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
9160
+ GGML_ASSERT(src1->type == GGML_TYPE_BF16);
9161
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
9162
+
9163
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
9164
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
9165
+
9166
+ // rows per thread
9167
+ const int dr = (nr + nth - 1)/nth;
9168
+
9169
+ // row range for this thread
9170
+ const int ir0 = dr*ith;
9171
+ const int ir1 = MIN(ir0 + dr, nr);
9172
+
9173
+ if (nb10 == sizeof(ggml_bf16_t)) {
9174
+ for (int ir = ir0; ir < ir1; ++ir) {
9175
+ // src0, src1 and dst are same shape => same indices
9176
+ const int i3 = ir/(ne2*ne1);
9177
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9178
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9179
+
9180
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
9181
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9182
+ ggml_bf16_t * src1_ptr = (ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
9183
+
9184
+ for (int i = 0; i < ne0; i++) {
9185
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + GGML_BF16_TO_FP32(src1_ptr[i]));
9186
+ }
9187
+ }
9188
+ }
9189
+ else {
9190
+ // src1 is not contiguous
9191
+ GGML_ASSERT(false);
9192
+ }
9193
+ }
9194
+
8150
9195
  static void ggml_compute_forward_add_q_f32(
8151
9196
  const struct ggml_compute_params * params,
8152
9197
  struct ggml_tensor * dst) {
@@ -8256,6 +9301,18 @@ static void ggml_compute_forward_add(
8256
9301
  GGML_ASSERT(false);
8257
9302
  }
8258
9303
  } break;
9304
+ case GGML_TYPE_BF16:
9305
+ {
9306
+ if (src1->type == GGML_TYPE_BF16) {
9307
+ ggml_compute_forward_add_bf16_bf16(params, dst);
9308
+ }
9309
+ else if (src1->type == GGML_TYPE_F32) {
9310
+ ggml_compute_forward_add_bf16_f32(params, dst);
9311
+ }
9312
+ else {
9313
+ GGML_ASSERT(false);
9314
+ }
9315
+ } break;
8259
9316
  case GGML_TYPE_Q4_0:
8260
9317
  case GGML_TYPE_Q4_1:
8261
9318
  case GGML_TYPE_Q5_0:
@@ -8505,12 +9562,116 @@ static void ggml_compute_forward_add1_q_f32(
8505
9562
 
8506
9563
  assert(ne0 % 32 == 0);
8507
9564
 
8508
- // unquantize row from src0 to temp buffer
8509
- dequantize_row_q(src0_row, wdata, ne0);
8510
- // add src1
8511
- ggml_vec_acc1_f32(ne0, wdata, v);
8512
- // quantize row to dst
8513
- quantize_row_q(wdata, dst_row, ne0);
9565
+ // unquantize row from src0 to temp buffer
9566
+ dequantize_row_q(src0_row, wdata, ne0);
9567
+ // add src1
9568
+ ggml_vec_acc1_f32(ne0, wdata, v);
9569
+ // quantize row to dst
9570
+ quantize_row_q(wdata, dst_row, ne0);
9571
+ }
9572
+ }
9573
+
9574
+ static void ggml_compute_forward_add1_bf16_f32(
9575
+ const struct ggml_compute_params * params,
9576
+ struct ggml_tensor * dst) {
9577
+
9578
+ const struct ggml_tensor * src0 = dst->src[0];
9579
+ const struct ggml_tensor * src1 = dst->src[1];
9580
+
9581
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
9582
+ GGML_ASSERT(ggml_is_scalar(src1));
9583
+
9584
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
9585
+ return;
9586
+ }
9587
+
9588
+ // scalar to add
9589
+ const float v = *(float *) src1->data;
9590
+
9591
+ const int ith = params->ith;
9592
+ const int nth = params->nth;
9593
+
9594
+ const int nr = ggml_nrows(src0);
9595
+
9596
+ GGML_TENSOR_UNARY_OP_LOCALS
9597
+
9598
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
9599
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
9600
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
9601
+
9602
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
9603
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
9604
+
9605
+ // rows per thread
9606
+ const int dr = (nr + nth - 1)/nth;
9607
+
9608
+ // row range for this thread
9609
+ const int ir0 = dr*ith;
9610
+ const int ir1 = MIN(ir0 + dr, nr);
9611
+
9612
+ for (int ir = ir0; ir < ir1; ++ir) {
9613
+ // src0 and dst are same shape => same indices
9614
+ const int i3 = ir/(ne2*ne1);
9615
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9616
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9617
+
9618
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
9619
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9620
+ for (int i = 0; i < ne0; i++) {
9621
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
9622
+ }
9623
+ }
9624
+ }
9625
+
9626
+ static void ggml_compute_forward_add1_bf16_bf16(
9627
+ const struct ggml_compute_params * params,
9628
+ struct ggml_tensor * dst) {
9629
+
9630
+ const struct ggml_tensor * src0 = dst->src[0];
9631
+ const struct ggml_tensor * src1 = dst->src[1];
9632
+
9633
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
9634
+ GGML_ASSERT(ggml_is_scalar(src1));
9635
+
9636
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
9637
+ return;
9638
+ }
9639
+
9640
+ // scalar to add
9641
+ const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
9642
+
9643
+ const int ith = params->ith;
9644
+ const int nth = params->nth;
9645
+
9646
+ const int nr = ggml_nrows(src0);
9647
+
9648
+ GGML_TENSOR_UNARY_OP_LOCALS
9649
+
9650
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
9651
+ GGML_ASSERT(src1->type == GGML_TYPE_BF16);
9652
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
9653
+
9654
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
9655
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
9656
+
9657
+ // rows per thread
9658
+ const int dr = (nr + nth - 1)/nth;
9659
+
9660
+ // row range for this thread
9661
+ const int ir0 = dr*ith;
9662
+ const int ir1 = MIN(ir0 + dr, nr);
9663
+
9664
+ for (int ir = ir0; ir < ir1; ++ir) {
9665
+ // src0 and dst are same shape => same indices
9666
+ const int i3 = ir/(ne2*ne1);
9667
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9668
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9669
+
9670
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
9671
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9672
+ for (int i = 0; i < ne0; i++) {
9673
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
9674
+ }
8514
9675
  }
8515
9676
  }
8516
9677
 
@@ -8538,6 +9699,18 @@ static void ggml_compute_forward_add1(
8538
9699
  GGML_ASSERT(false);
8539
9700
  }
8540
9701
  } break;
9702
+ case GGML_TYPE_BF16:
9703
+ {
9704
+ if (src1->type == GGML_TYPE_BF16) {
9705
+ ggml_compute_forward_add1_bf16_bf16(params, dst);
9706
+ }
9707
+ else if (src1->type == GGML_TYPE_F32) {
9708
+ ggml_compute_forward_add1_bf16_f32(params, dst);
9709
+ }
9710
+ else {
9711
+ GGML_ASSERT(false);
9712
+ }
9713
+ } break;
8541
9714
  case GGML_TYPE_Q4_0:
8542
9715
  case GGML_TYPE_Q4_1:
8543
9716
  case GGML_TYPE_Q5_0:
@@ -8666,6 +9839,7 @@ static void ggml_compute_forward_acc(
8666
9839
  ggml_compute_forward_acc_f32(params, dst);
8667
9840
  } break;
8668
9841
  case GGML_TYPE_F16:
9842
+ case GGML_TYPE_BF16:
8669
9843
  case GGML_TYPE_Q4_0:
8670
9844
  case GGML_TYPE_Q4_1:
8671
9845
  case GGML_TYPE_Q5_0:
@@ -9187,6 +10361,40 @@ static void ggml_compute_forward_sum_f16(
9187
10361
  ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
9188
10362
  }
9189
10363
 
10364
+ static void ggml_compute_forward_sum_bf16(
10365
+ const struct ggml_compute_params * params,
10366
+ struct ggml_tensor * dst) {
10367
+
10368
+ const struct ggml_tensor * src0 = dst->src[0];
10369
+
10370
+ assert(params->ith == 0);
10371
+ assert(ggml_is_scalar(dst));
10372
+
10373
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
10374
+ return;
10375
+ }
10376
+
10377
+ assert(src0->nb[0] == sizeof(ggml_bf16_t));
10378
+
10379
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
10380
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
10381
+
10382
+ float sum = 0;
10383
+ float row_sum = 0;
10384
+
10385
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
10386
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
10387
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
10388
+ ggml_vec_sum_bf16_ggf(ne00,
10389
+ &row_sum,
10390
+ (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
10391
+ sum += row_sum;
10392
+ }
10393
+ }
10394
+ }
10395
+ ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
10396
+ }
10397
+
9190
10398
  static void ggml_compute_forward_sum(
9191
10399
  const struct ggml_compute_params * params,
9192
10400
  struct ggml_tensor * dst) {
@@ -9202,6 +10410,10 @@ static void ggml_compute_forward_sum(
9202
10410
  {
9203
10411
  ggml_compute_forward_sum_f16(params, dst);
9204
10412
  } break;
10413
+ case GGML_TYPE_BF16:
10414
+ {
10415
+ ggml_compute_forward_sum_bf16(params, dst);
10416
+ } break;
9205
10417
  default:
9206
10418
  {
9207
10419
  GGML_ASSERT(false);
@@ -9476,6 +10688,7 @@ static void ggml_compute_forward_repeat(
9476
10688
 
9477
10689
  switch (src0->type) {
9478
10690
  case GGML_TYPE_F16:
10691
+ case GGML_TYPE_BF16:
9479
10692
  case GGML_TYPE_I16:
9480
10693
  {
9481
10694
  ggml_compute_forward_repeat_f16(params, dst);
@@ -9963,6 +11176,52 @@ static void ggml_compute_forward_relu(
9963
11176
  }
9964
11177
  }
9965
11178
 
11179
+ // ggml_compute_forward_sigmoid
11180
+
11181
+ static void ggml_compute_forward_sigmoid_f32(
11182
+ const struct ggml_compute_params * params,
11183
+ struct ggml_tensor * dst) {
11184
+
11185
+ const struct ggml_tensor * src0 = dst->src[0];
11186
+
11187
+ assert(params->ith == 0);
11188
+ assert(ggml_are_same_shape(src0, dst));
11189
+
11190
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
11191
+ return;
11192
+ }
11193
+
11194
+ const int n = ggml_nrows(src0);
11195
+ const int nc = src0->ne[0];
11196
+
11197
+ assert(dst->nb[0] == sizeof(float));
11198
+ assert(src0->nb[0] == sizeof(float));
11199
+
11200
+ for (int i = 0; i < n; i++) {
11201
+ ggml_vec_sigmoid_f32(nc,
11202
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
11203
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
11204
+ }
11205
+ }
11206
+
11207
+ static void ggml_compute_forward_sigmoid(
11208
+ const struct ggml_compute_params * params,
11209
+ struct ggml_tensor * dst) {
11210
+
11211
+ const struct ggml_tensor * src0 = dst->src[0];
11212
+
11213
+ switch (src0->type) {
11214
+ case GGML_TYPE_F32:
11215
+ {
11216
+ ggml_compute_forward_sigmoid_f32(params, dst);
11217
+ } break;
11218
+ default:
11219
+ {
11220
+ GGML_ASSERT(false);
11221
+ } break;
11222
+ }
11223
+ }
11224
+
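The new unary op dispatches to ggml_vec_sigmoid_f32, a vector helper that is not part of this file's diff. Presumably it is just the logistic function applied element-wise over one row; a minimal sketch under that assumption, with a hypothetical name:

#include <math.h>

static void vec_sigmoid_f32_sketch(const int n, float * y, const float * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = 1.0f / (1.0f + expf(-x[i]));  // logistic sigmoid, one element at a time
    }
}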
9966
11225
  // ggml_compute_forward_gelu
9967
11226
 
9968
11227
  static void ggml_compute_forward_gelu_f32(
@@ -10813,9 +12072,101 @@ static bool ggml_compute_forward_mul_mat_use_blas(struct ggml_tensor * dst) {
10813
12072
  }
10814
12073
  #endif
10815
12074
 
12075
+ static void ggml_compute_forward_mul_mat_one_chunk(
12076
+ const struct ggml_compute_params * params,
12077
+ struct ggml_tensor * dst,
12078
+ const int64_t num_rows_per_vec_dot,
12079
+ const int64_t ir0_start,
12080
+ const int64_t ir0_end,
12081
+ const int64_t ir1_start,
12082
+ const int64_t ir1_end) {
12083
+
12084
+ const struct ggml_tensor * src0 = dst->src[0];
12085
+ const struct ggml_tensor * src1 = dst->src[1];
12086
+
12087
+ GGML_TENSOR_BINARY_OP_LOCALS
12088
+
12089
+ const enum ggml_type type = src0->type;
12090
+
12091
+ const bool src1_cont = ggml_is_contiguous(src1);
12092
+
12093
+ ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
12094
+ enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
12095
+
12096
+ // broadcast factors
12097
+ const int64_t r2 = ne12 / ne02;
12098
+ const int64_t r3 = ne13 / ne03;
12099
+
12100
+ //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end);
12101
+
12102
+ // threads with no work simply yield (not sure if it helps)
12103
+ if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
12104
+ return;
12105
+ }
12106
+
12107
+ const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
12108
+ const size_t row_size = ggml_row_size(vec_dot_type, ne10);
12109
+
12110
+ assert(ne12 % ne02 == 0);
12111
+ assert(ne13 % ne03 == 0);
12112
+
12113
+ // block-tiling attempt
12114
+ const int64_t blck_0 = 16;
12115
+ const int64_t blck_1 = 16;
12116
+
12117
+ const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
12118
+
12119
+ // attempt to reduce false-sharing (does not seem to make a difference)
12120
+ // 16 * 2, accounting for mmla kernels
12121
+ float tmp[32];
12122
+
12123
+ for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
12124
+ for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
12125
+ for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
12126
+ const int64_t i13 = (ir1 / (ne12 * ne1));
12127
+ const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
12128
+ const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
12129
+
12130
+ // broadcast src0 into src1
12131
+ const int64_t i03 = i13 / r3;
12132
+ const int64_t i02 = i12 / r2;
12133
+
12134
+ const int64_t i1 = i11;
12135
+ const int64_t i2 = i12;
12136
+ const int64_t i3 = i13;
12137
+
12138
+ const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);
12139
+
12140
+ // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
12141
+ // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
12142
+ // the original src1 data pointer, so we should index using the indices directly
12143
+ // TODO: this is a bit of a hack, we should probably have a better way to handle this
12144
+ const char * src1_col = (const char*)wdata +
12145
+ (src1_cont || src1->type != vec_dot_type
12146
+ ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
12147
+ : (i11 * nb11 + i12 * nb12 + i13 * nb13));
12148
+ float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
12149
+
12150
+ //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
12151
+ // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
12152
+ //}
12153
+
12154
+ for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
12155
+ vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
12156
+ }
12157
+
12158
+ for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
12159
+ memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
12160
+ }
12161
+ }
12162
+ }
12163
+ }
12164
+ }
12165
+
10816
12166
  static void ggml_compute_forward_mul_mat(
10817
12167
  const struct ggml_compute_params * params,
10818
- struct ggml_tensor * dst) {
12168
+ struct ggml_tensor * dst,
12169
+ struct ggml_compute_state * state) {
10819
12170
 
10820
12171
  const struct ggml_tensor * src0 = dst->src[0];
10821
12172
  const struct ggml_tensor * src1 = dst->src[1];
@@ -10830,9 +12181,6 @@ static void ggml_compute_forward_mul_mat(
10830
12181
 
10831
12182
  const enum ggml_type type = src0->type;
10832
12183
 
10833
- const bool src1_cont = ggml_is_contiguous(src1);
10834
-
10835
- ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
10836
12184
  enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
10837
12185
  ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
10838
12186
  int64_t const vec_dot_num_rows = type_traits[type].nrows;
@@ -10853,8 +12201,10 @@ static void ggml_compute_forward_mul_mat(
10853
12201
  GGML_ASSERT(nb2 <= nb3);
10854
12202
 
10855
12203
  // broadcast factors
10856
- const int64_t r2 = ne12/ne02;
10857
- const int64_t r3 = ne13/ne03;
12204
+ const int64_t r2 = ne12 / ne02;
12205
+ const int64_t r3 = ne13 / ne03;
12206
+ UNUSED(r2);
12207
+ UNUSED(r3);
10858
12208
 
10859
12209
  // nb01 >= nb00 - src0 is not transposed
10860
12210
  // compute by src0 rows
@@ -10936,6 +12286,8 @@ static void ggml_compute_forward_mul_mat(
10936
12286
  #endif
10937
12287
 
10938
12288
  #if GGML_USE_LLAMAFILE
12289
+ const bool src1_cont = ggml_is_contiguous(src1);
12290
+
10939
12291
  if (src1_cont) {
10940
12292
  for (int64_t i13 = 0; i13 < ne13; i13++)
10941
12293
  for (int64_t i12 = 0; i12 < ne12; i12++)
@@ -10961,6 +12313,8 @@ UseGgmlGemm1:;
10961
12313
  if (ith != 0) {
10962
12314
  return;
10963
12315
  }
12316
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
12317
+ atomic_store(&state->shared->current_chunk, nth);
10964
12318
  if (src1->type != vec_dot_type) {
10965
12319
  char * wdata = params->wdata;
10966
12320
  const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@@ -10985,11 +12339,11 @@ UseGgmlGemm1:;
10985
12339
  return;
10986
12340
  }
10987
12341
 
10988
- const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
10989
- const size_t row_size = ggml_row_size(vec_dot_type, ne10);
10990
-
10991
12342
  #if GGML_USE_LLAMAFILE
10992
12343
  if (src1->type != vec_dot_type) {
12344
+ const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
12345
+ const size_t row_size = ggml_row_size(vec_dot_type, ne10);
12346
+
10993
12347
  for (int64_t i13 = 0; i13 < ne13; i13++)
10994
12348
  for (int64_t i12 = 0; i12 < ne12; i12++)
10995
12349
  if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
@@ -11010,98 +12364,87 @@ UseGgmlGemm1:;
11010
12364
  UseGgmlGemm2:;
11011
12365
  #endif
11012
12366
 
11013
- const int64_t nr0 = ne01; // src0 rows
11014
- const int64_t nr1 = ne1*ne12*ne13; // src1 rows
11015
-
11016
- //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
11017
-
11018
- // distribute the thread work across the inner or outer loop based on which one is larger
11019
-
11020
- const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
11021
- const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
11022
-
11023
- const int64_t ith0 = ith % nth0;
11024
- const int64_t ith1 = ith / nth0;
11025
-
11026
- const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
11027
- const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
11028
-
11029
- const int64_t ir010 = dr0*ith0;
11030
- const int64_t ir011 = MIN(ir010 + dr0, nr0);
11031
-
11032
- const int64_t ir110 = dr1*ith1;
11033
- const int64_t ir111 = MIN(ir110 + dr1, nr1);
11034
-
11035
- //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
11036
-
11037
- // threads with no work simply yield (not sure if it helps)
11038
- if (ir010 >= ir011 || ir110 >= ir111) {
11039
- sched_yield();
11040
- return;
11041
- }
12367
+ #ifdef GGML_PERF
12368
+ int chunks_executed = 0;
12369
+ UNUSED(chunks_executed);
12370
+ #endif
11042
12371
 
11043
- assert(ne12 % ne02 == 0);
11044
- assert(ne13 % ne03 == 0);
12372
+ // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
12373
+ const int64_t nr0 = ne0;
11045
12374
 
11046
- // block-tiling attempt
11047
- const int64_t blck_0 = 16;
11048
- const int64_t blck_1 = 16;
12375
+ // This is the size of the rest of the dimensions of the result
12376
+ const int64_t nr1 = ne1 * ne2 * ne3;
11049
12377
 
11050
12378
  // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
11051
- int64_t nrc = vec_dot_num_rows;
12379
+ int64_t num_rows_per_vec_dot = vec_dot_num_rows;
11052
12380
  // TODO: currently the mmla kernels support only even numbered rows/cols.
11053
12381
  // this check can be removed once they are extended to support odd numbered rows/cols too
11054
12382
  if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
11055
- nrc = 1;
12383
+ num_rows_per_vec_dot = 1;
11056
12384
  }
11057
12385
 
11058
- const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
12386
+ // Now select a reasonable chunk size.
12387
+ int chunk_size = 16;
11059
12388
 
11060
- // attempt to reduce false-sharing (does not seem to make a difference)
11061
- // 16 * 2, accounting for mmla kernels
11062
- float tmp[32];
12389
+ // We need to step up the size if it's small
12390
+ if (nr0 == 1 || nr1 == 1) {
12391
+ chunk_size = 64;
12392
+ }
11063
12393
 
11064
- for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
11065
- for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
11066
- for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += nrc) {
11067
- const int64_t i13 = (ir1/(ne12*ne1));
11068
- const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
11069
- const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
12394
+ // distribute the work across the inner or outer loop based on which one is larger
12395
+ // The number of chunks in the 0/1 dim.
12396
+ // CEIL(nr0/chunk_size)
12397
+ int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
12398
+ int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
11070
12399
 
11071
- // broadcast src0 into src1
11072
- const int64_t i03 = i13/r3;
11073
- const int64_t i02 = i12/r2;
12400
+ // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread.
12401
+ // Also, chunking by thread was measured to perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915
12403
+ // In theory, chunking should be just as useful on NUMA and non-NUMA systems, but testing disagreed with that.
12403
+ if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {
12404
+ // distribute the thread work across the inner or outer loop based on which one is larger
12405
+ nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
12406
+ nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
12407
+ }
11074
12408
 
11075
- const int64_t i1 = i11;
11076
- const int64_t i2 = i12;
11077
- const int64_t i3 = i13;
12409
+ // The number of elements in each chunk
12410
+ const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
12411
+ const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
11078
12412
 
11079
- const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);
12413
+ //if (ith == 0)
12414
+ // printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);
11080
12415
 
11081
- // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
11082
- // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
11083
- // the original src1 data pointer, so we should index using the indices directly
11084
- // TODO: this is a bit of a hack, we should probably have a better way to handle this
11085
- const char * src1_col = (const char *) wdata +
11086
- (src1_cont || src1->type != vec_dot_type
11087
- ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
11088
- : (i11*nb11 + i12*nb12 + i13*nb13));
11089
- float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
12416
+ // The first chunk comes from our thread_id; the rest will get auto-assigned.
12417
+ int current_chunk = ith;
11090
12418
 
11091
- //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
11092
- // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
11093
- //}
12419
+ while (current_chunk < nchunk0 * nchunk1) {
12420
+ const int64_t ith0 = current_chunk % nchunk0;
12421
+ const int64_t ith1 = current_chunk / nchunk0;
11094
12422
 
11095
- for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += nrc) {
11096
- vec_dot(ne00, &tmp[ir0 - iir0], (nrc>1 ? 16 : 0), src0_row + ir0*nb01, (nrc>1 ? nb01 : 0), src1_col, (nrc>1 ? src1_col_stride : 0), nrc);
11097
- }
12423
+ const int64_t ir0_start = dr0 * ith0;
12424
+ const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
11098
12425
 
11099
- for (int cn = 0; cn < nrc; ++cn) {
11100
- memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
11101
- }
11102
- }
12426
+ const int64_t ir1_start = dr1 * ith1;
12427
+ const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
12428
+
12429
+ ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
12430
+
12431
+ #ifdef GGML_PERF
12432
+ chunks_executed++;
12433
+ #endif
12434
+
12435
+ if (nth >= nchunk0 * nchunk1) {
12436
+ break;
11103
12437
  }
12438
+
12439
+ current_chunk = atomic_fetch_add(&state->shared->current_chunk, 1);
11104
12440
  }
12441
+
12442
+ #ifdef GGML_PERF
12443
+ // These numbers are useful when trying to measure how well the threading scheduling works.
12444
+ //int64_t workSize = (ne01 * ne11 * ne12 * ne13 * ne00) / nchunk0 / nchunk1;
12445
+ //float time = (ggml_perf_time_us() - t0);
12446
+ //printf("MUL_MAT = %f ms, [%d, %d, %d, %d] x [%d, %d, %d, %d] = %I64u, %f ops/usec in %d chunks.\n", time / 1000.0, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, workSize, (float)workSize/time, chunks_executed);
12447
+ #endif
11105
12448
  }
11106
12449
 
11107
12450
  // ggml_compute_forward_mul_mat_id
@@ -11793,6 +13136,7 @@ static void ggml_compute_forward_set(
11793
13136
  ggml_compute_forward_set_f32(params, dst);
11794
13137
  } break;
11795
13138
  case GGML_TYPE_F16:
13139
+ case GGML_TYPE_BF16:
11796
13140
  case GGML_TYPE_Q4_0:
11797
13141
  case GGML_TYPE_Q4_1:
11798
13142
  case GGML_TYPE_Q5_0:
@@ -11918,13 +13262,56 @@ static void ggml_compute_forward_get_rows_q(
11918
13262
  const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
11919
13263
  const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
11920
13264
 
11921
- dequantize_row_q(
13265
+ dequantize_row_q(
13266
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
13267
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
13268
+ }
13269
+ }
13270
+
13271
+ static void ggml_compute_forward_get_rows_f16(
13272
+ const struct ggml_compute_params * params,
13273
+ struct ggml_tensor * dst) {
13274
+
13275
+ const struct ggml_tensor * src0 = dst->src[0];
13276
+ const struct ggml_tensor * src1 = dst->src[1];
13277
+
13278
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
13279
+ return;
13280
+ }
13281
+
13282
+ GGML_TENSOR_BINARY_OP_LOCALS
13283
+
13284
+ const int64_t nc = ne00;
13285
+ const int64_t nr = ggml_nelements(src1);
13286
+
13287
+ assert(ne0 == nc);
13288
+ assert(ne02 == ne11);
13289
+ assert(nb00 == sizeof(ggml_fp16_t));
13290
+ assert(ggml_nrows(dst) == nr);
13291
+
13292
+ const int ith = params->ith;
13293
+ const int nth = params->nth;
13294
+
13295
+ // rows per thread
13296
+ const int dr = (nr + nth - 1)/nth;
13297
+
13298
+ // row range for this thread
13299
+ const int ir0 = dr*ith;
13300
+ const int ir1 = MIN(ir0 + dr, nr);
13301
+
13302
+ for (int64_t i = ir0; i < ir1; ++i) {
13303
+ const int64_t i12 = i/(ne11*ne10);
13304
+ const int64_t i11 = (i - i12*ne11*ne10)/ne10;
13305
+ const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
13306
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
13307
+
13308
+ ggml_fp16_to_fp32_row(
11922
13309
  (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
11923
13310
  (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
11924
13311
  }
11925
13312
  }
11926
13313
 
11927
- static void ggml_compute_forward_get_rows_f16(
13314
+ static void ggml_compute_forward_get_rows_bf16(
11928
13315
  const struct ggml_compute_params * params,
11929
13316
  struct ggml_tensor * dst) {
11930
13317
 
@@ -11942,7 +13329,7 @@ static void ggml_compute_forward_get_rows_f16(
11942
13329
 
11943
13330
  assert(ne0 == nc);
11944
13331
  assert(ne02 == ne11);
11945
- assert(nb00 == sizeof(ggml_fp16_t));
13332
+ assert(nb00 == sizeof(ggml_bf16_t));
11946
13333
  assert(ggml_nrows(dst) == nr);
11947
13334
 
11948
13335
  const int ith = params->ith;
@@ -11961,7 +13348,7 @@ static void ggml_compute_forward_get_rows_f16(
11961
13348
  const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
11962
13349
  const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
11963
13350
 
11964
- ggml_fp16_to_fp32_row(
13351
+ ggml_bf16_to_fp32_row(
11965
13352
  (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
11966
13353
  (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
11967
13354
  }
@@ -12044,6 +13431,10 @@ static void ggml_compute_forward_get_rows(
12044
13431
  {
12045
13432
  ggml_compute_forward_get_rows_f16(params, dst);
12046
13433
  } break;
13434
+ case GGML_TYPE_BF16:
13435
+ {
13436
+ ggml_compute_forward_get_rows_bf16(params, dst);
13437
+ } break;
12047
13438
  case GGML_TYPE_F32:
12048
13439
  case GGML_TYPE_I32:
12049
13440
  {
@@ -12356,7 +13747,6 @@ static void ggml_compute_forward_soft_max_f32(
12356
13747
 
12357
13748
  const struct ggml_tensor * src0 = dst->src[0];
12358
13749
  const struct ggml_tensor * src1 = dst->src[1];
12359
- const struct ggml_tensor * src2 = dst->src[2];
12360
13750
 
12361
13751
  assert(ggml_is_contiguous(dst));
12362
13752
  assert(ggml_are_same_shape(src0, dst));
@@ -12382,8 +13772,8 @@ static void ggml_compute_forward_soft_max_f32(
12382
13772
 
12383
13773
  // TODO: is this supposed to be ceil instead of floor?
12384
13774
  // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
12385
- const uint32_t n_head_kv = ne02;
12386
- const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
13775
+ const uint32_t n_head = ne02;
13776
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
12387
13777
 
12388
13778
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
12389
13779
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
@@ -12400,13 +13790,13 @@ static void ggml_compute_forward_soft_max_f32(
12400
13790
 
12401
13791
  float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
12402
13792
 
12403
- // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
12404
- ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
12405
- float * pos_f32 = src2 ? (float *) src2->data : src0->data;
12406
-
12407
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
13793
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
12408
13794
 
12409
13795
  for (int i1 = ir0; i1 < ir1; i1++) {
13796
+ // ALiBi
13797
+ const uint32_t h = (i1/ne01)%ne02; // head
13798
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
13799
+
12410
13800
  float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
12411
13801
  float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
12412
13802
 
@@ -12419,27 +13809,11 @@ static void ggml_compute_forward_soft_max_f32(
12419
13809
  if (mp_f32) {
12420
13810
  if (use_f16) {
12421
13811
  for (int i = 0; i < nc; ++i) {
12422
- wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
12423
- }
12424
- } else {
12425
- for (int i = 0; i < nc; ++i) {
12426
- wp[i] += mp_f32[i];
12427
- }
12428
- }
12429
- }
12430
-
12431
- // ALiBi bias
12432
- if (max_bias > 0.0f) {
12433
- const uint32_t h = (i1/ne01)%ne02; // head
12434
- const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
12435
-
12436
- if (use_f16) {
12437
- for (int i = 0; i < nc; ++i) {
12438
- wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
13812
+ wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
12439
13813
  }
12440
13814
  } else {
12441
13815
  for (int i = 0; i < nc; ++i) {
12442
- wp[i] += slope*pos_f32[i];
13816
+ wp[i] += slope*mp_f32[i];
12443
13817
  }
12444
13818
  }
12445
13819
  }
@@ -12454,22 +13828,7 @@ static void ggml_compute_forward_soft_max_f32(
12454
13828
  float max = -INFINITY;
12455
13829
  ggml_vec_max_f32(nc, &max, wp);
12456
13830
 
12457
- ggml_float sum = 0.0;
12458
-
12459
- uint16_t scvt;
12460
- for (int i = 0; i < nc; i++) {
12461
- if (wp[i] == -INFINITY) {
12462
- dp[i] = 0.0f;
12463
- } else {
12464
- // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
12465
- ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
12466
- memcpy(&scvt, &s, sizeof(scvt));
12467
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
12468
- sum += (ggml_float)val;
12469
- dp[i] = val;
12470
- }
12471
- }
12472
-
13831
+ ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
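The helper call that replaces the removed loop has to compute exactly what that loop did: a max-subtracted softmax over the row. With w the scaled/masked row and M its maximum (computed by ggml_vec_max_f32 just above):

$$
d_i = e^{\,w_i - M}, \qquad \mathrm{sum} = \sum_i d_i, \qquad p_i = d_i / \mathrm{sum},
$$

where the final division is performed by the existing scaling step below (sum = 1.0/sum). Subtracting M first keeps every exponent non-positive, so expf cannot overflow.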
12473
13832
  assert(sum > 0.0);
12474
13833
 
12475
13834
  sum = 1.0/sum;
@@ -12601,177 +13960,6 @@ static void ggml_compute_forward_soft_max_back(
12601
13960
  }
12602
13961
  }
12603
13962
 
12604
- // ggml_compute_forward_alibi
12605
-
12606
- static void ggml_compute_forward_alibi_f32(
12607
- const struct ggml_compute_params * params,
12608
- struct ggml_tensor * dst) {
12609
-
12610
- const struct ggml_tensor * src0 = dst->src[0];
12611
-
12612
- assert(params->ith == 0);
12613
-
12614
- if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
12615
- return;
12616
- }
12617
-
12618
- //const int n_past = ((int32_t *) dst->op_params)[0];
12619
- const int n_head = ((int32_t *) dst->op_params)[1];
12620
- float max_bias;
12621
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
12622
-
12623
- const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
12624
- const int64_t ne1 = src0->ne[1]; // seq_len_without_past
12625
- const int64_t ne2 = src0->ne[2]; // n_head -> this is k
12626
- //const int64_t ne3 = src0->ne[3]; // 1 -> bsz
12627
-
12628
- const int64_t n = ggml_nrows(src0);
12629
- const int64_t ne2_ne3 = n/ne1; // ne2*ne3
12630
-
12631
- const size_t nb0 = src0->nb[0];
12632
- const size_t nb1 = src0->nb[1];
12633
- const size_t nb2 = src0->nb[2];
12634
- //const int nb3 = src0->nb[3];
12635
-
12636
- GGML_ASSERT(nb0 == sizeof(float));
12637
- GGML_ASSERT(n_head == ne2);
12638
-
12639
- // add alibi to src0 (KQ_scaled)
12640
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
12641
-
12642
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
12643
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
12644
-
12645
- for (int64_t k = 0; k < ne2_ne3; k++) {
12646
- // TODO: k*nb2 or k*nb3
12647
- float m_k;
12648
-
12649
- if (k < n_heads_log2_floor) {
12650
- m_k = powf(m0, k + 1);
12651
- } else {
12652
- m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
12653
- }
12654
-
12655
- for (int64_t i = 0; i < ne0; i++) {
12656
- for (int64_t j = 0; j < ne1; j++) {
12657
- float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
12658
- float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
12659
- pdst[0] = i * m_k + src[0];
12660
- }
12661
- }
12662
- }
12663
- }
12664
-
12665
- static void ggml_compute_forward_alibi_f16(
12666
- const struct ggml_compute_params * params,
12667
- struct ggml_tensor * dst) {
12668
-
12669
- const struct ggml_tensor * src0 = dst->src[0];
12670
-
12671
- assert(params->ith == 0);
12672
-
12673
- if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
12674
- return;
12675
- }
12676
-
12677
- //const int n_past = ((int32_t *) dst->op_params)[0];
12678
- const int n_head = ((int32_t *) dst->op_params)[1];
12679
- float max_bias;
12680
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
12681
-
12682
- const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
12683
- const int ne1 = src0->ne[1]; // seq_len_without_past
12684
- const int ne2 = src0->ne[2]; // n_head -> this is k
12685
- //const int ne3 = src0->ne[3]; // 1 -> bsz
12686
-
12687
- const int n = ggml_nrows(src0);
12688
- const int ne2_ne3 = n/ne1; // ne2*ne3
12689
-
12690
- const int nb0 = src0->nb[0];
12691
- const int nb1 = src0->nb[1];
12692
- const int nb2 = src0->nb[2];
12693
- //const int nb3 = src0->nb[3];
12694
-
12695
- GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
12696
- //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
12697
- GGML_ASSERT(n_head == ne2);
12698
-
12699
- // add alibi to src0 (KQ_scaled)
12700
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
12701
-
12702
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
12703
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
12704
-
12705
- for (int k = 0; k < ne2_ne3; k++) {
12706
- // TODO: k*nb2 or k*nb3
12707
- float m_k;
12708
-
12709
- if (k < n_heads_log2_floor) {
12710
- m_k = powf(m0, k + 1);
12711
- } else {
12712
- m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
12713
- }
12714
-
12715
- for (int i = 0; i < ne0; i++) {
12716
- for (int j = 0; j < ne1; j++) {
12717
- ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
12718
- float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
12719
-
12720
- // we return F32
12721
- pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
12722
- }
12723
- }
12724
- }
12725
- }
12726
-
12727
- static void ggml_compute_forward_alibi(
12728
- const struct ggml_compute_params * params,
12729
- struct ggml_tensor * dst) {
12730
-
12731
- const struct ggml_tensor * src0 = dst->src[0];
12732
-
12733
- switch (src0->type) {
12734
- case GGML_TYPE_F16:
12735
- {
12736
- ggml_compute_forward_alibi_f16(params, dst);
12737
- } break;
12738
- case GGML_TYPE_F32:
12739
- {
12740
- ggml_compute_forward_alibi_f32(params, dst);
12741
- } break;
12742
- case GGML_TYPE_Q4_0:
12743
- case GGML_TYPE_Q4_1:
12744
- case GGML_TYPE_Q5_0:
12745
- case GGML_TYPE_Q5_1:
12746
- case GGML_TYPE_Q8_0:
12747
- case GGML_TYPE_Q8_1:
12748
- case GGML_TYPE_Q2_K:
12749
- case GGML_TYPE_Q3_K:
12750
- case GGML_TYPE_Q4_K:
12751
- case GGML_TYPE_Q5_K:
12752
- case GGML_TYPE_Q6_K:
12753
- case GGML_TYPE_IQ2_XXS:
12754
- case GGML_TYPE_IQ2_XS:
12755
- case GGML_TYPE_IQ3_XXS:
12756
- case GGML_TYPE_IQ1_S:
12757
- case GGML_TYPE_IQ1_M:
12758
- case GGML_TYPE_IQ4_NL:
12759
- case GGML_TYPE_IQ4_XS:
12760
- case GGML_TYPE_IQ3_S:
12761
- case GGML_TYPE_IQ2_S:
12762
- case GGML_TYPE_Q8_K:
12763
- case GGML_TYPE_I8:
12764
- case GGML_TYPE_I16:
12765
- case GGML_TYPE_I32:
12766
- case GGML_TYPE_I64:
12767
- case GGML_TYPE_F64:
12768
- case GGML_TYPE_COUNT:
12769
- {
12770
- GGML_ASSERT(false);
12771
- } break;
12772
- }
12773
- }
12774
-
12775
13963
  // ggml_compute_forward_clamp
12776
13964
 
12777
13965
  static void ggml_compute_forward_clamp_f32(
@@ -12828,6 +14016,7 @@ static void ggml_compute_forward_clamp(
12828
14016
  ggml_compute_forward_clamp_f32(params, dst);
12829
14017
  } break;
12830
14018
  case GGML_TYPE_F16:
14019
+ case GGML_TYPE_BF16:
12831
14020
  case GGML_TYPE_Q4_0:
12832
14021
  case GGML_TYPE_Q4_1:
12833
14022
  case GGML_TYPE_Q5_0:
@@ -13993,25 +15182,28 @@ static void ggml_compute_forward_upscale_f32(
13993
15182
  return;
13994
15183
  }
13995
15184
 
13996
- GGML_ASSERT(src0->nb[0] == sizeof(float));
15185
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
13997
15186
 
13998
15187
  const int ith = params->ith;
13999
15188
  const int nth = params->nth;
14000
15189
 
14001
15190
  GGML_TENSOR_UNARY_OP_LOCALS
14002
15191
 
14003
- const int scale_factor = dst->op_params[0];
15192
+ const float sf0 = (float)ne0/src0->ne[0];
15193
+ const float sf1 = (float)ne1/src0->ne[1];
15194
+ const float sf2 = (float)ne2/src0->ne[2];
15195
+ const float sf3 = (float)ne3/src0->ne[3];
14004
15196
 
14005
15197
  // TODO: optimize
14006
15198
 
14007
15199
  for (int64_t i3 = 0; i3 < ne3; i3++) {
14008
- const int64_t i03 = i3;
15200
+ const int64_t i03 = i3 / sf3;
14009
15201
  for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
14010
- const int64_t i02 = i2;
15202
+ const int64_t i02 = i2 / sf2;
14011
15203
  for (int64_t i1 = 0; i1 < ne1; i1++) {
14012
- const int64_t i01 = i1 / scale_factor;
15204
+ const int64_t i01 = i1 / sf1;
14013
15205
  for (int64_t i0 = 0; i0 < ne0; i0++) {
14014
- const int64_t i00 = i0 / scale_factor;
15206
+ const int64_t i00 = i0 / sf0;
14015
15207
 
14016
15208
  const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
14017
15209
  float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
@@ -14041,6 +15233,7 @@ static void ggml_compute_forward_upscale(
14041
15233
  }
14042
15234
  }
14043
15235
 
15236
+
14044
15237
  // ggml_compute_forward_pad
14045
15238
 
14046
15239
  static void ggml_compute_forward_pad_f32(
@@ -14394,37 +15587,7 @@ static void ggml_compute_forward_flash_attn_f32(
14394
15587
  vvexpf(S, S, &Mup);
14395
15588
  ggml_vec_sum_f32(Mup, &sum, S);
14396
15589
  #else
14397
- uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
14398
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
14399
-
14400
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
14401
- if (i >= masked_begin) {
14402
- break;
14403
- }
14404
- float * SS = S + i;
14405
-
14406
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
14407
- if (i + j >= masked_begin) {
14408
- break;
14409
- } else if (SS[j] == -INFINITY) {
14410
- SS[j] = 0.0f;
14411
- } else {
14412
- #ifndef GGML_FLASH_ATTN_EXP_FP16
14413
- const float val = expf(SS[j] - max);
14414
- #else
14415
- ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
14416
- memcpy(&scvt[j], &s, sizeof(uint16_t));
14417
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
14418
- #endif
14419
- sump[j] += (ggml_float)val;
14420
- SS[j] = val;
14421
- }
14422
- }
14423
- }
14424
-
14425
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
14426
- sum += sump[i];
14427
- }
15590
+ sum = ggml_vec_soft_max_f32(Mup, S, S, max);
14428
15591
  #endif
14429
15592
  }
14430
15593
 
@@ -14606,28 +15769,7 @@ static void ggml_compute_forward_flash_attn_f16(
14606
15769
  vvexpf(S, S, &Mup);
14607
15770
  ggml_vec_sum_f32(Mup, &sum, S);
14608
15771
  #else
14609
- uint16_t scvt[GGML_SOFT_MAX_UNROLL];
14610
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
14611
-
14612
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
14613
- float * SS = S + i;
14614
-
14615
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
14616
- if (SS[j] == -INFINITY) {
14617
- SS[j] = 0.0f;
14618
- } else {
14619
- ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
14620
- memcpy(&scvt[j], &s, sizeof(uint16_t));
14621
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
14622
- sump[j] += (ggml_float)val;
14623
- SS[j] = val;
14624
- }
14625
- }
14626
- }
14627
-
14628
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
14629
- sum += sump[i];
14630
- }
15772
+ sum = ggml_vec_soft_max_f32(Mup, S, S, max);
14631
15773
  #endif
14632
15774
  }
14633
15775
 
@@ -14784,8 +15926,17 @@ static void ggml_compute_forward_flash_attn_ext_f16(
14784
15926
  const int ir0 = dr*ith;
14785
15927
  const int ir1 = MIN(ir0 + dr, nr);
14786
15928
 
14787
- float scale = 1.0f;
14788
- memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
15929
+ float scale = 1.0f;
15930
+ float max_bias = 0.0f;
15931
+
15932
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
15933
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
15934
+
15935
+ const uint32_t n_head = neq2;
15936
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
15937
+
15938
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
15939
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
14789
15940
 
14790
15941
  // loop over n_batch and n_head
14791
15942
  for (int ir = ir0; ir < ir1; ++ir) {
@@ -14794,6 +15945,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
14794
15945
  const int iq2 = (ir - iq3*neq2*neq1)/neq1;
14795
15946
  const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
14796
15947
 
15948
+ const uint32_t h = iq2; // head
15949
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
15950
+
14797
15951
  float S = 0.0f;
14798
15952
  float M = -INFINITY;
14799
15953
 
@@ -14817,7 +15971,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
14817
15971
  // loop over n_kv and n_head_kv
14818
15972
  // ref: https://arxiv.org/pdf/2112.05682.pdf
14819
15973
  for (int64_t ic = 0; ic < nek1; ++ic) {
14820
- const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
15974
+ const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
14821
15975
  if (mv == -INFINITY) {
14822
15976
  continue;
14823
15977
  }
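The hunks above thread ALiBi directly through ggml_compute_forward_flash_attn_ext_f16: max_bias is read from op_params alongside scale, a per-head slope is derived from it, and the slope scales the mask value before it biases the attention score (the standalone GGML_OP_ALIBI operator is dropped further down). A sketch of the slope computation, lifted from the added lines into a standalone helper whose name is illustrative:

    #include <math.h>
    #include <stdint.h>

    // Per-head ALiBi slope: heads below n_head_log2 take successive powers of
    // m0, the remaining heads take odd powers of m1. max_bias == 0 disables
    // the bias (slope 1.0f), matching the previous behaviour.
    static float alibi_slope(uint32_t h, uint32_t n_head, float max_bias) {
        if (max_bias <= 0.0f) {
            return 1.0f;
        }
        const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        return h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
    }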
@@ -14888,7 +16042,7 @@ static void ggml_compute_forward_flash_attn_ext(
14888
16042
  const struct ggml_tensor * v,
14889
16043
  const struct ggml_tensor * mask,
14890
16044
  struct ggml_tensor * dst) {
14891
- switch (dst->op_params[1]) {
16045
+ switch (dst->op_params[2]) {
14892
16046
  case GGML_PREC_DEFAULT:
14893
16047
  case GGML_PREC_F32:
14894
16048
  {
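Since max_bias now lives at op_params[1], the precision selector that ggml_compute_forward_flash_attn_ext switches on moves from op_params[1] to op_params[2]. A hedged sketch of how the three parameters can be read back on the compute side (the wrapper name is illustrative; the layout itself is what the diff shows):

    #include <stdint.h>
    #include <string.h>
    #include "ggml.h"

    // Assumed GGML_OP_FLASH_ATTN_EXT op_params layout after this change:
    // [0] scale (float bits), [1] max_bias (float bits), [2] precision enum.
    static void read_flash_attn_ext_params(const struct ggml_tensor * dst,
                                           float * scale, float * max_bias,
                                           int32_t * prec) {
        memcpy(scale,    (const float *) dst->op_params + 0, sizeof(float));
        memcpy(max_bias, (const float *) dst->op_params + 1, sizeof(float));
        *prec = dst->op_params[2]; // GGML_PREC_DEFAULT or GGML_PREC_F32
    }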
@@ -15242,38 +16396,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
15242
16396
  vvexpf(SM, SM, &Mup);
15243
16397
  ggml_vec_sum_f32(Mup, &sum, SM);
15244
16398
  #else
15245
- uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
15246
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
15247
-
15248
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
15249
- if (i >= masked_begin) {
15250
- break;
15251
- }
15252
- float * SR = S + i;
15253
- float * SW = SM + i;
15254
-
15255
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
15256
- if (i + j >= masked_begin) {
15257
- break;
15258
- } else if (SR[j] == -INFINITY) {
15259
- SW[j] = 0.0f;
15260
- } else {
15261
- #ifndef GGML_FLASH_ATTN_EXP_FP16
15262
- const float val = expf(SR[j] - max);
15263
- #else
15264
- ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
15265
- memcpy(&scvt[j], &s, sizeof(uint16_t));
15266
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
15267
- #endif
15268
- sump[j] += (ggml_float)val;
15269
- SW[j] = val;
15270
- }
15271
- }
15272
- }
15273
-
15274
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
15275
- sum += sump[i];
15276
- }
16399
+ sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
15277
16400
  #endif
15278
16401
  }
15279
16402
 
@@ -15855,6 +16978,10 @@ static void ggml_compute_forward_unary(
15855
16978
  {
15856
16979
  ggml_compute_forward_relu(params, dst);
15857
16980
  } break;
16981
+ case GGML_UNARY_OP_SIGMOID:
16982
+ {
16983
+ ggml_compute_forward_sigmoid(params, dst);
16984
+ } break;
15858
16985
  case GGML_UNARY_OP_GELU:
15859
16986
  {
15860
16987
  ggml_compute_forward_gelu(params, dst);
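GGML_UNARY_OP_SIGMOID is a new unary operator: it is dispatched here in the forward pass, asserted as not-implemented in the backward pass, and added to the multi-threadable unary ops in ggml_get_n_tasks further down. A minimal sketch of the element-wise computation it stands for; the real ggml_compute_forward_sigmoid additionally handles tensor strides, types and row-wise threading:

    #include <math.h>

    // Logistic sigmoid over a contiguous f32 row.
    static void sigmoid_row_f32(const int n, float * y, const float * x) {
        for (int i = 0; i < n; ++i) {
            y[i] = 1.0f / (1.0f + expf(-x[i]));
        }
    }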
@@ -15921,6 +17048,7 @@ static void ggml_compute_forward_get_rel_pos(
15921
17048
 
15922
17049
  switch (src0->type) {
15923
17050
  case GGML_TYPE_F16:
17051
+ case GGML_TYPE_BF16:
15924
17052
  {
15925
17053
  ggml_compute_forward_get_rel_pos_f16(params, dst);
15926
17054
  } break;
@@ -16294,35 +17422,15 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
16294
17422
  assert(!isnan(s1[i]));
16295
17423
  }
16296
17424
  #endif
16297
- // soft_max
16298
- ggml_float sum = 0.0;
16299
- {
16300
- float max = -INFINITY;
16301
- ggml_vec_max_f32(nc, &max, s0);
16302
17425
 
16303
- uint16_t scvt; UNUSED(scvt);
16304
- for (int i = 0; i < nc; i++) {
16305
- if (s0[i] == -INFINITY) {
16306
- st[i] = 0.0f;
16307
- } else {
16308
- #ifndef GGML_CROSS_ENTROPY_EXP_FP16
16309
- const float s = s0[i] - max;
16310
- const float val = expf(s);
16311
- #else
16312
- ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
16313
- memcpy(&scvt, &s, sizeof(scvt));
16314
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
16315
- #endif
16316
- sum += (ggml_float)val;
16317
- st[i] = val;
16318
- }
16319
- }
17426
+ // soft_max
17427
+ float max = -INFINITY;
17428
+ ggml_vec_max_f32(nc, &max, s0);
17429
+ ggml_float sum = ggml_vec_soft_max_f32(nc, st, s0, max);
17430
+ assert(sum > 0.0);
17431
+ sum = (1.0 - eps) / sum;
16320
17432
 
16321
- assert(sum > 0.0);
16322
- // sum = 1.0/sum;
16323
- }
16324
17433
  // avoid log(0) by rescaling from [0..1] to [eps..1]
16325
- sum = (1.0 - eps) / sum;
16326
17434
  ggml_vec_scale_f32(nc, st, sum);
16327
17435
  ggml_vec_add1_f32(nc, st, st, eps);
16328
17436
  ggml_vec_log_f32(nc, st, st);
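The cross-entropy forward path now computes the row maximum, calls ggml_vec_soft_max_f32, and rescales so every probability lands in [eps, 1] before the logarithm, which is what the scale/add1/log vector calls above perform. A scalar sketch of that final step, assuming st[i] already holds exp(s0[i] - max) and sum is their total:

    #include <math.h>

    // Map softmax outputs from [0, 1] into [eps, 1] and take the log,
    // mirroring ggml_vec_scale_f32 + ggml_vec_add1_f32 + ggml_vec_log_f32.
    static void rescaled_log_softmax(int nc, float * st, double sum, float eps) {
        const float scale = (float) ((1.0 - eps) / sum);
        for (int i = 0; i < nc; ++i) {
            st[i] = logf(st[i] * scale + eps);
        }
    }

The same reshuffle is applied to ggml_compute_forward_cross_entropy_loss_back_f32 in the next hunk.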
@@ -16412,32 +17520,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
16412
17520
  #endif
16413
17521
 
16414
17522
  // soft_max
16415
- ggml_float sum = 0.0;
16416
- {
16417
- float max = -INFINITY;
16418
- ggml_vec_max_f32(nc, &max, s0);
16419
-
16420
- uint16_t scvt; UNUSED(scvt);
16421
- for (int i = 0; i < nc; i++) {
16422
- if (s0[i] == -INFINITY) {
16423
- ds0[i] = 0.0f;
16424
- } else {
16425
- #ifndef GGML_CROSS_ENTROPY_EXP_FP16
16426
- const float s = s0[i] - max;
16427
- const float val = expf(s);
16428
- #else
16429
- ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
16430
- memcpy(&scvt, &s, sizeof(scvt));
16431
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
16432
- #endif
16433
- sum += (ggml_float)val;
16434
- ds0[i] = val;
16435
- }
16436
- }
16437
-
16438
- assert(sum > 0.0);
16439
- sum = (1.0 - eps)/sum;
16440
- }
17523
+ float max = -INFINITY;
17524
+ ggml_vec_max_f32(nc, &max, s0);
17525
+ ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
17526
+ assert(sum > 0.0);
17527
+ sum = (1.0 - eps) / sum;
16441
17528
 
16442
17529
  // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
16443
17530
  ggml_vec_scale_f32(nc, ds0, sum);
@@ -16474,7 +17561,7 @@ static void ggml_compute_forward_cross_entropy_loss_back(
16474
17561
 
16475
17562
  /////////////////////////////////
16476
17563
 
16477
- static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
17564
+ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_compute_state * state) {
16478
17565
  GGML_ASSERT(params);
16479
17566
 
16480
17567
  if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
@@ -16572,7 +17659,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
16572
17659
  } break;
16573
17660
  case GGML_OP_MUL_MAT:
16574
17661
  {
16575
- ggml_compute_forward_mul_mat(params, tensor);
17662
+ ggml_compute_forward_mul_mat(params, tensor, state);
16576
17663
  } break;
16577
17664
  case GGML_OP_MUL_MAT_ID:
16578
17665
  {
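ggml_compute_forward gains a struct ggml_compute_state * parameter so the per-thread state reaches ggml_compute_forward_mul_mat, and the shared state it points to is initialized with a new current_chunk field in ggml_graph_compute further down. The field name comes from that initializer; how mul_mat uses it is not shown in this diff, but an atomic chunk counter typically enables dynamic work distribution along these lines (purely illustrative):

    #include <stdatomic.h>

    // Threads claim output chunks from a shared atomic counter instead of a
    // fixed static split, which balances uneven matrix shapes across threads.
    static void process_chunks(atomic_int * current_chunk, int n_chunks,
                               void (*do_chunk)(int chunk)) {
        int chunk;
        while ((chunk = atomic_fetch_add(current_chunk, 1)) < n_chunks) {
            do_chunk(chunk);
        }
    }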
@@ -16650,10 +17737,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
16650
17737
  {
16651
17738
  ggml_compute_forward_rope_back(params, tensor);
16652
17739
  } break;
16653
- case GGML_OP_ALIBI:
16654
- {
16655
- ggml_compute_forward_alibi(params, tensor);
16656
- } break;
16657
17740
  case GGML_OP_CLAMP:
16658
17741
  {
16659
17742
  ggml_compute_forward_clamp(params, tensor);
@@ -17672,10 +18755,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
17672
18755
  zero_table);
17673
18756
  }
17674
18757
  } break;
17675
- case GGML_OP_ALIBI:
17676
- {
17677
- GGML_ASSERT(false); // TODO: not implemented
17678
- } break;
17679
18758
  case GGML_OP_CLAMP:
17680
18759
  {
17681
18760
  GGML_ASSERT(false); // TODO: not implemented
@@ -17846,6 +18925,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
17846
18925
  zero_table);
17847
18926
  }
17848
18927
  } break;
18928
+ case GGML_UNARY_OP_SIGMOID:
18929
+ {
18930
+ GGML_ASSERT(false); // TODO: not implemented
18931
+ } break;
17849
18932
  case GGML_UNARY_OP_GELU:
17850
18933
  {
17851
18934
  GGML_ASSERT(false); // TODO: not implemented
@@ -18192,8 +19275,6 @@ typedef int ggml_lock_t;
18192
19275
 
18193
19276
  #define GGML_LOCK_INITIALIZER 0
18194
19277
 
18195
- typedef pthread_t ggml_thread_t;
18196
-
18197
19278
  #define ggml_thread_create pthread_create
18198
19279
  #define ggml_thread_join pthread_join
18199
19280
 
@@ -18219,8 +19300,6 @@ typedef int ggml_lock_t;
18219
19300
 
18220
19301
  #define GGML_LOCK_INITIALIZER 0
18221
19302
 
18222
- typedef pthread_t ggml_thread_t;
18223
-
18224
19303
  #define ggml_thread_create pthread_create
18225
19304
  #define ggml_thread_join pthread_join
18226
19305
 
@@ -18300,31 +19379,6 @@ static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n); }
18300
19379
  static void clear_numa_thread_affinity(void) {}
18301
19380
  #endif
18302
19381
 
18303
- struct ggml_compute_state_shared {
18304
- const struct ggml_cgraph * cgraph;
18305
- const struct ggml_cplan * cplan;
18306
-
18307
- int64_t perf_node_start_cycles;
18308
- int64_t perf_node_start_time_us;
18309
-
18310
- const int n_threads;
18311
-
18312
- // synchronization primitives
18313
- atomic_int n_active; // num active threads
18314
- atomic_int node_n; // active graph node
18315
- atomic_int node_task; // active graph node task phase
18316
-
18317
- ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
18318
- void * abort_callback_data;
18319
- };
18320
-
18321
- struct ggml_compute_state {
18322
- ggml_thread_t thrd;
18323
- int ith;
18324
- struct ggml_compute_state_shared * shared;
18325
- enum ggml_status ec;
18326
- };
18327
-
18328
19382
  static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) {
18329
19383
  int64_t cycles_cur = ggml_perf_cycles() - st->perf_node_start_cycles;
18330
19384
  int64_t time_us_cur = ggml_perf_time_us() - st->perf_node_start_time_us;
@@ -18375,6 +19429,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
18375
19429
  case GGML_UNARY_OP_TANH:
18376
19430
  case GGML_UNARY_OP_ELU:
18377
19431
  case GGML_UNARY_OP_RELU:
19432
+ case GGML_UNARY_OP_SIGMOID:
18378
19433
  case GGML_UNARY_OP_HARDSWISH: // to opt for multiple threads
18379
19434
  case GGML_UNARY_OP_HARDSIGMOID: // to opt for multiple threads
18380
19435
  {
@@ -18448,10 +19503,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
18448
19503
  {
18449
19504
  n_tasks = n_threads;
18450
19505
  } break;
18451
- case GGML_OP_ALIBI:
18452
- {
18453
- n_tasks = 1; //TODO
18454
- } break;
18455
19506
  case GGML_OP_CLAMP:
18456
19507
  {
18457
19508
  n_tasks = 1; //TODO
@@ -18600,6 +19651,10 @@ static void ggml_graph_compute_thread_sync_node(int * node_n, struct ggml_comput
18600
19651
 
18601
19652
  * node_n = atomic_load(&state->shared->node_n);
18602
19653
  if (* node_n != last_node_n) break;
19654
+ #if defined(__SSE3__)
19655
+ // Tell the processor we're spinning. It's a processor hint for spinlocks.
19656
+ _mm_pause();
19657
+ #endif
18603
19658
  }
18604
19659
  }
18605
19660
 
@@ -18614,6 +19669,10 @@ static void ggml_graph_compute_thread_sync_task(int * task_phase, struct ggml_co
18614
19669
 
18615
19670
  * task_phase = atomic_load(&state->shared->node_task);
18616
19671
  if (* task_phase != last_task_phase) break;
19672
+ #if defined(__SSE3__)
19673
+ // Tell the processor we're spinning. It's a processor hint for spinlocks.
19674
+ _mm_pause();
19675
+ #endif
18617
19676
  }
18618
19677
  }
18619
19678
 
@@ -18653,7 +19712,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
18653
19712
  struct ggml_tensor * node = cgraph->nodes[node_n];
18654
19713
  if (GGML_OP_HAS_FINALIZE[node->op]) {
18655
19714
  params.nth = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
18656
- ggml_compute_forward(&params, node);
19715
+ ggml_compute_forward(&params, node, state);
18657
19716
  }
18658
19717
  ggml_graph_compute_perf_stats_node(node, state->shared);
18659
19718
  }
@@ -18673,17 +19732,17 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
18673
19732
  /* INIT */
18674
19733
  if (GGML_OP_HAS_INIT[node->op]) {
18675
19734
  params.type = GGML_TASK_TYPE_INIT;
18676
- ggml_compute_forward(&params, node);
19735
+ ggml_compute_forward(&params, node, state);
18677
19736
  }
18678
19737
 
18679
19738
  // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
18680
19739
  // they do something more efficient than spinning (?)
18681
19740
  params.type = GGML_TASK_TYPE_COMPUTE;
18682
- ggml_compute_forward(&params, node);
19741
+ ggml_compute_forward(&params, node, state);
18683
19742
 
18684
19743
  if (GGML_OP_HAS_FINALIZE[node->op]) {
18685
19744
  params.type = GGML_TASK_TYPE_FINALIZE;
18686
- ggml_compute_forward(&params, node);
19745
+ ggml_compute_forward(&params, node, state);
18687
19746
  }
18688
19747
 
18689
19748
  ggml_graph_compute_perf_stats_node(node, state->shared);
@@ -18722,7 +19781,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
18722
19781
 
18723
19782
  if (state->ith < n_tasks) {
18724
19783
  if (GGML_OP_HAS_INIT[node->op]) {
18725
- ggml_compute_forward(&params, node);
19784
+ ggml_compute_forward(&params, node, state);
18726
19785
  }
18727
19786
  }
18728
19787
 
@@ -18743,7 +19802,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
18743
19802
 
18744
19803
  if (state->ith < n_tasks) {
18745
19804
  params.type = GGML_TASK_TYPE_COMPUTE;
18746
- ggml_compute_forward(&params, node);
19805
+ ggml_compute_forward(&params, node, state);
18747
19806
  }
18748
19807
 
18749
19808
  if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
@@ -18785,7 +19844,10 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18785
19844
  case GGML_OP_CPY:
18786
19845
  case GGML_OP_DUP:
18787
19846
  {
18788
- if (ggml_is_quantized(node->type)) {
19847
+ if (ggml_is_quantized(node->type) ||
19848
+ // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
19849
+ (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
19850
+ (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
18789
19851
  cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
18790
19852
  }
18791
19853
  } break;
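The CPY/DUP work-buffer estimate now also reserves a float row when copying between F16 and BF16, because such copies go through an intermediate F32 row, as the added comment states. A sketch of that two-step conversion using the row helpers ggml already exposes (the wrapper name is illustrative):

    #include "ggml.h"

    // F16 -> BF16 via an F32 scratch row of n floats; the reverse direction
    // swaps the two helpers.
    static void copy_row_f16_to_bf16(const ggml_fp16_t * src, ggml_bf16_t * dst,
                                     float * tmp, int64_t n) {
        ggml_fp16_to_fp32_row(src, tmp, n);   // widen to F32
        ggml_fp32_to_bf16_row(tmp, dst, n);   // narrow to BF16
    }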
@@ -18864,7 +19926,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18864
19926
  const int64_t ne10 = node->src[1]->ne[0]; // L
18865
19927
  const int64_t ne11 = node->src[1]->ne[1]; // Cin
18866
19928
 
18867
- if (node->src[0]->type == GGML_TYPE_F16 &&
19929
+ if ((node->src[0]->type == GGML_TYPE_F16 ||
19930
+ node->src[0]->type == GGML_TYPE_BF16) &&
18868
19931
  node->src[1]->type == GGML_TYPE_F32) {
18869
19932
  cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
18870
19933
  cur += sizeof(ggml_fp16_t)*ne10*ne11;
@@ -18900,6 +19963,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18900
19963
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18901
19964
  cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
18902
19965
  cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19966
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19967
+ cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19968
+ cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
18903
19969
  }
18904
19970
  } break;
18905
19971
  case GGML_OP_FLASH_ATTN_EXT:
@@ -18916,6 +19982,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18916
19982
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18917
19983
  cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
18918
19984
  cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19985
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19986
+ cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19987
+ cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
18919
19988
  }
18920
19989
  } break;
18921
19990
  case GGML_OP_FLASH_ATTN_BACK:
@@ -18929,6 +19998,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18929
19998
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18930
19999
  cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
18931
20000
  cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
20001
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
20002
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
20003
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
18932
20004
  }
18933
20005
  } break;
18934
20006
 
@@ -18981,6 +20053,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
18981
20053
  /*.node_task =*/ GGML_TASK_TYPE_FINALIZE,
18982
20054
  /*.abort_callback =*/ NULL,
18983
20055
  /*.abort_callback_data =*/ NULL,
20056
+ /*.current_chunk; =*/ 0,
18984
20057
  };
18985
20058
  struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
18986
20059
 
@@ -19705,7 +20778,9 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
19705
20778
  if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
19706
20779
  fprintf(fp, "%d", ggml_get_i32_1d(node, j));
19707
20780
  }
19708
- else if (node->type == GGML_TYPE_F32 || node->type == GGML_TYPE_F16) {
20781
+ else if (node->type == GGML_TYPE_F32 ||
20782
+ node->type == GGML_TYPE_F16 ||
20783
+ node->type == GGML_TYPE_BF16) {
19709
20784
  fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j));
19710
20785
  }
19711
20786
  else {
@@ -20763,6 +21838,12 @@ size_t ggml_quantize_chunk(
20763
21838
  ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
20764
21839
  result = n * elemsize;
20765
21840
  } break;
21841
+ case GGML_TYPE_BF16:
21842
+ {
21843
+ size_t elemsize = sizeof(ggml_bf16_t);
21844
+ ggml_fp32_to_bf16_row(src + start, (ggml_bf16_t *)dst + start, n);
21845
+ result = n * elemsize;
21846
+ } break;
20766
21847
  case GGML_TYPE_F32:
20767
21848
  {
20768
21849
  size_t elemsize = sizeof(float);
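ggml_quantize_chunk gains a GGML_TYPE_BF16 case mirroring the F16 one above it. BF16 keeps the upper 16 bits of an IEEE-754 single, so the conversions are cheap; a sketch of the bit-level relationship, with truncation shown only for illustration (the in-tree GGML_FP32_TO_BF16 may round rather than truncate):

    #include <stdint.h>
    #include <string.h>

    // Widening: place the stored 16 bits in the high half of a float.
    static float bf16_bits_to_f32(uint16_t bits) {
        uint32_t u = (uint32_t) bits << 16;
        float f;
        memcpy(&f, &u, sizeof(f));
        return f;
    }

    // Narrowing by truncation (illustrative only).
    static uint16_t f32_to_bf16_bits_truncate(float f) {
        uint32_t u;
        memcpy(&u, &f, sizeof(u));
        return (uint16_t) (u >> 16);
    }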
@@ -21139,7 +22220,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
21139
22220
  }
21140
22221
 
21141
22222
  // read the tensor infos
21142
- {
22223
+ if (ctx->header.n_tensors > 0) {
21143
22224
  ctx->infos = GGML_CALLOC(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));
21144
22225
 
21145
22226
  for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {