llama_cpp 0.15.0 → 0.15.2

Sign up to get free protection for your applications and to get access to all the features.
@@ -4,7 +4,6 @@
4
4
  #include "ggml-impl.h"
5
5
  #include "ggml-quants.h"
6
6
  #include "ggml.h"
7
- #include "sgemm.h"
8
7
 
9
8
  #if defined(_MSC_VER) || defined(__MINGW32__)
10
9
  #include <malloc.h> // using malloc.h with MSC/MINGW
@@ -37,6 +36,10 @@
37
36
  #undef GGML_USE_LLAMAFILE
38
37
  #endif
39
38
 
39
+ #ifdef GGML_USE_LLAMAFILE
40
+ #include "sgemm.h"
41
+ #endif
42
+
40
43
  #if defined(_MSC_VER)
41
44
  // disable "possible loss of data" to avoid hundreds of casts
42
45
  // we should just be careful :)
@@ -109,6 +112,8 @@ typedef void * thread_ret_t;
109
112
 
110
113
  #endif
111
114
 
115
+ typedef pthread_t ggml_thread_t;
116
+
112
117
  #ifdef GGML_USE_CPU_HBM
113
118
  #include <hbwmalloc.h>
114
119
  #endif
@@ -160,9 +165,6 @@ void ggml_print_backtrace(void) {
160
165
  #define GGML_DEBUG 0
161
166
  #define GGML_GELU_FP16
162
167
  #define GGML_GELU_QUICK_FP16
163
- #define GGML_SILU_FP16
164
- // #define GGML_CROSS_ENTROPY_EXP_FP16
165
- // #define GGML_FLASH_ATTN_EXP_FP16
166
168
 
167
169
  #define GGML_SOFT_MAX_UNROLL 4
168
170
  #define GGML_VEC_DOT_UNROLL 2
@@ -313,16 +315,10 @@ static ggml_fp16_t ggml_table_gelu_f16[1 << 16];
313
315
  // precomputed quick gelu table for f16 (128 KB)
314
316
  static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
315
317
 
316
- // precomputed silu table for f16 (128 KB)
317
- static ggml_fp16_t ggml_table_silu_f16[1 << 16];
318
-
319
- // precomputed exp table for f16 (128 KB)
320
- static ggml_fp16_t ggml_table_exp_f16[1 << 16];
321
-
322
318
  // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
323
319
  float ggml_table_f32_f16[1 << 16];
324
320
 
325
- const char * ggml_status_to_string(enum ggml_status status) {
321
+ GGML_CALL const char * ggml_status_to_string(enum ggml_status status) {
326
322
  switch (status) {
327
323
  case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
328
324
  case GGML_STATUS_FAILED: return "GGML status: error (operation failed)";
@@ -333,16 +329,26 @@ const char * ggml_status_to_string(enum ggml_status status) {
333
329
  return "GGML status: unknown";
334
330
  }
335
331
 
336
- // note: do not use these inside ggml.c
337
- // these are meant to be used via the ggml.h API
338
332
  float ggml_fp16_to_fp32(ggml_fp16_t x) {
333
+ #define ggml_fp16_to_fp32 do_not_use__ggml_fp16_to_fp32__in_ggml
339
334
  return GGML_FP16_TO_FP32(x);
340
335
  }
341
336
 
342
337
  ggml_fp16_t ggml_fp32_to_fp16(float x) {
338
+ #define ggml_fp32_to_fp16 do_not_use__ggml_fp32_to_fp16__in_ggml
343
339
  return GGML_FP32_TO_FP16(x);
344
340
  }
345
341
 
342
+ float ggml_bf16_to_fp32(ggml_bf16_t x) {
343
+ #define ggml_bf16_to_fp32 do_not_use__ggml_bf16_to_fp32__in_ggml
344
+ return GGML_BF16_TO_FP32(x); // it just left shifts
345
+ }
346
+
347
+ ggml_bf16_t ggml_fp32_to_bf16(float x) {
348
+ #define ggml_fp32_to_bf16 do_not_use__ggml_fp32_to_bf16__in_ggml
349
+ return GGML_FP32_TO_BF16(x);
350
+ }
351
+
346
352
  void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
347
353
  for (int64_t i = 0; i < n; i++) {
348
354
  y[i] = GGML_FP16_TO_FP32(x[i]);
@@ -368,6 +374,49 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
368
374
  }
369
375
  }
370
376
 
377
+ void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
378
+ int64_t i = 0;
379
+ #if defined(__AVX512F__)
380
+ for (; i + 16 <= n; i += 16) {
381
+ _mm512_storeu_ps(y + i,
382
+ _mm512_castsi512_ps(
383
+ _mm512_slli_epi32(
384
+ _mm512_cvtepu16_epi32(
385
+ _mm256_loadu_si256(
386
+ (const __m256i *)(x + i))),
387
+ 16)));
388
+ }
389
+ #elif defined(__AVX2__)
390
+ for (; i + 8 <= n; i += 8) {
391
+ _mm256_storeu_ps(y + i,
392
+ _mm256_castsi256_ps(
393
+ _mm256_slli_epi32(
394
+ _mm256_cvtepu16_epi32(
395
+ _mm_loadu_si128(
396
+ (const __m128i *)(x + i))),
397
+ 16)));
398
+ }
399
+ #endif
400
+ for (; i < n; i++) {
401
+ y[i] = GGML_BF16_TO_FP32(x[i]);
402
+ }
403
+ }
404
+
405
+ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
406
+ int i = 0;
407
+ #if defined(__AVX512BF16__)
408
+ for (; i + 32 <= n; i += 32) {
409
+ _mm512_storeu_ps(
410
+ (__m512 *)(y + i),
411
+ (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
412
+ _mm512_loadu_ps(x + i)));
413
+ }
414
+ #endif
415
+ for (; i < n; i++) {
416
+ y[i] = GGML_FP32_TO_BF16(x[i]);
417
+ }
418
+ }
419
+
371
420
  bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
372
421
  return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
373
422
  }
@@ -503,6 +552,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
503
552
 
504
553
  static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
505
554
  static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
555
+ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
506
556
 
507
557
  static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
508
558
  [GGML_TYPE_I8] = {
@@ -845,6 +895,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
845
895
  .type_size = sizeof(block_q8_K),
846
896
  .is_quantized = true,
847
897
  .from_float = quantize_row_q8_K,
898
+ },
899
+ [GGML_TYPE_BF16] = {
900
+ .type_name = "bf16",
901
+ .blck_size = 1,
902
+ .type_size = sizeof(ggml_bf16_t),
903
+ .is_quantized = false,
904
+ .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row,
905
+ .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row,
906
+ .from_float_reference = (ggml_from_float_t) ggml_fp32_to_bf16_row,
907
+ .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
908
+ .vec_dot_type = GGML_TYPE_BF16,
909
+ .nrows = 1,
848
910
  }
849
911
  };
850
912
 
@@ -1237,6 +1299,8 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
1237
1299
  #define GGML_F16_VEC_ZERO GGML_F32x4_ZERO
1238
1300
  #define GGML_F16_VEC_SET1 GGML_F32x4_SET1
1239
1301
  #define GGML_F16_VEC_FMA GGML_F32x4_FMA
1302
+ #define GGML_F16_VEC_ADD GGML_F32x4_ADD
1303
+ #define GGML_F16_VEC_MUL GGML_F32x4_MUL
1240
1304
  #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
1241
1305
  // Use vec_xl, not vec_ld, in case the load address is not aligned.
1242
1306
  #define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \
@@ -1468,6 +1532,59 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
1468
1532
  #define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
1469
1533
  #endif
1470
1534
 
1535
+ //
1536
+ // ggml context
1537
+ //
1538
+
1539
+ struct ggml_context {
1540
+ size_t mem_size;
1541
+ void* mem_buffer;
1542
+ bool mem_buffer_owned;
1543
+ bool no_alloc;
1544
+ bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
1545
+
1546
+ int n_objects;
1547
+
1548
+ struct ggml_object* objects_begin;
1549
+ struct ggml_object* objects_end;
1550
+
1551
+ struct ggml_scratch scratch;
1552
+ struct ggml_scratch scratch_save;
1553
+ };
1554
+
1555
+ struct ggml_context_container {
1556
+ bool used;
1557
+
1558
+ struct ggml_context context;
1559
+ };
1560
+
1561
+ struct ggml_compute_state_shared {
1562
+ const struct ggml_cgraph* cgraph;
1563
+ const struct ggml_cplan* cplan;
1564
+
1565
+ int64_t perf_node_start_cycles;
1566
+ int64_t perf_node_start_time_us;
1567
+
1568
+ const int n_threads;
1569
+
1570
+ // synchronization primitives
1571
+ atomic_int n_active; // num active threads
1572
+ atomic_int node_n; // active graph node
1573
+ atomic_int node_task; // active graph node task phase
1574
+
1575
+ ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
1576
+ void* abort_callback_data;
1577
+
1578
+ atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
1579
+ };
1580
+
1581
+ struct ggml_compute_state {
1582
+ ggml_thread_t thrd;
1583
+ int ith;
1584
+ struct ggml_compute_state_shared* shared;
1585
+ enum ggml_status ec;
1586
+ };
1587
+
1471
1588
  //
1472
1589
  // fundamental operations
1473
1590
  //
@@ -1480,6 +1597,8 @@ inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) {
1480
1597
 
1481
1598
  inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1482
1599
 
1600
+ inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1601
+
1483
1602
  inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
1484
1603
  inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
1485
1604
  inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
@@ -1498,7 +1617,7 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float *
1498
1617
  UNUSED(by);
1499
1618
  UNUSED(bs);
1500
1619
 
1501
- #ifdef GGML_SIMD
1620
+ #if defined(GGML_SIMD)
1502
1621
  float sumf = 0.0f;
1503
1622
  const int np = (n & ~(GGML_F32_STEP - 1));
1504
1623
 
@@ -1534,6 +1653,70 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float *
1534
1653
  *s = sumf;
1535
1654
  }
1536
1655
 
1656
+ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) {
1657
+ assert(nrc == 1);
1658
+ UNUSED(nrc);
1659
+ UNUSED(bx);
1660
+ UNUSED(by);
1661
+ UNUSED(bs);
1662
+ int i = 0;
1663
+ ggml_float sumf = 0;
1664
+
1665
+ #if defined(__AVX512BF16__)
1666
+ __m512 c1 = _mm512_setzero_ps();
1667
+ __m512 c2 = _mm512_setzero_ps();
1668
+ for (; i + 64 <= n; i += 64) {
1669
+ c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
1670
+ (__m512bh)_mm512_loadu_ps((const float *)(y + i)));
1671
+ c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
1672
+ (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
1673
+ }
1674
+ sumf += (ggml_float)_mm512_reduce_add_ps(c1);
1675
+ sumf += (ggml_float)_mm512_reduce_add_ps(c2);
1676
+
1677
+ #elif defined(__AVX512F__)
1678
+ #define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
1679
+ __m512 c1 = _mm512_setzero_ps();
1680
+ __m512 c2 = _mm512_setzero_ps();
1681
+ for (; i + 32 <= n; i += 32) {
1682
+ c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
1683
+ c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
1684
+ }
1685
+ sumf += (ggml_float)_mm512_reduce_add_ps(c1);
1686
+ sumf += (ggml_float)_mm512_reduce_add_ps(c2);
1687
+
1688
+ #undef LOAD
1689
+ #elif defined(__AVX2__)
1690
+ #define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
1691
+ __m256 c1 = _mm256_setzero_ps();
1692
+ __m256 c2 = _mm256_setzero_ps();
1693
+ __m256 c3 = _mm256_setzero_ps();
1694
+ __m256 c4 = _mm256_setzero_ps();
1695
+ for (; i + 32 <= n; i += 32) {
1696
+ c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
1697
+ c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
1698
+ c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
1699
+ c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
1700
+ }
1701
+ __m128 g;
1702
+ c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
1703
+ _mm256_add_ps(c2, c4));
1704
+ g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
1705
+ _mm256_castps256_ps128(c1));
1706
+ g = _mm_add_ps(g, _mm_movehl_ps(g, g));
1707
+ g = _mm_add_ss(g, _mm_movehdup_ps(g));
1708
+ sumf += (ggml_float)_mm_cvtss_f32(g);
1709
+
1710
+ #undef LOAD
1711
+ #endif
1712
+
1713
+ for (; i < n; ++i) {
1714
+ sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
1715
+ GGML_BF16_TO_FP32(y[i]));
1716
+ }
1717
+ *s = sumf;
1718
+ }
1719
+
1537
1720
  static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
1538
1721
  assert(nrc == 1);
1539
1722
  UNUSED(nrc);
@@ -1817,6 +2000,7 @@ inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) {
1817
2000
  inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
1818
2001
  inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
1819
2002
  inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
2003
+ inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
1820
2004
  // TODO: optimize performance
1821
2005
  inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
1822
2006
  inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
@@ -1892,52 +2076,291 @@ inline static float ggml_silu_f32(float x) {
1892
2076
  return x/(1.0f + expf(-x));
1893
2077
  }
1894
2078
 
1895
- //inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
1896
- // const uint16_t * i16 = (const uint16_t *) x;
1897
- // for (int i = 0; i < n; ++i) {
1898
- // y[i] = ggml_table_silu_f16[i16[i]];
1899
- // }
1900
- //}
2079
+ #if defined(__ARM_NEON)
1901
2080
 
1902
- #ifdef GGML_SILU_FP16
1903
- inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
1904
- uint16_t t;
1905
- for (int i = 0; i < n; ++i) {
1906
- ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
1907
- memcpy(&t, &fp16, sizeof(uint16_t));
1908
- y[i] = GGML_FP16_TO_FP32(ggml_table_silu_f16[t]);
1909
- }
1910
- }
2081
+ // adapted from arm limited optimized routine
2082
+ // the maximum error is 1.45358 plus 0.5 ulps
2083
+ // numbers above 88.38 will flush to infinity
2084
+ // numbers beneath -103.97 will flush to zero
2085
+ inline static float32x4_t ggml_v_expf(float32x4_t x) {
2086
+ const float32x4_t r = vdupq_n_f32(0x1.8p23f);
2087
+ const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
2088
+ const float32x4_t n = vsubq_f32(z, r);
2089
+ const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
2090
+ vdupq_n_f32(0x1.7f7d1cp-20f));
2091
+ const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
2092
+ const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
2093
+ const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
2094
+ const float32x4_t u = vmulq_f32(b, b);
2095
+ const float32x4_t j = vfmaq_f32(
2096
+ vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
2097
+ vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
2098
+ vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
2099
+ if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
2100
+ return vfmaq_f32(k, j, k);
2101
+ const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
2102
+ const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
2103
+ const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
2104
+ return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
2105
+ vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
2106
+ }
2107
+
2108
+ // computes silu x/(1+exp(-x)) in single precision vector
2109
+ inline static float32x4_t ggml_v_silu(float32x4_t x) {
2110
+ const float32x4_t one = vdupq_n_f32(1.0f);
2111
+ const float32x4_t zero = vdupq_n_f32(0.0f);
2112
+ const float32x4_t neg_x = vsubq_f32(zero, x);
2113
+ const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
2114
+ const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
2115
+ return vdivq_f32(x, one_plus_exp_neg_x);
2116
+ }
2117
+
2118
+ #elif defined(__AVX512F__) && defined(__AVX512DQ__)
2119
+
2120
+ // adapted from arm limited optimized routine
2121
+ // the maximum error is 1.45358 plus 0.5 ulps
2122
+ // numbers above 88.38 will flush to infinity
2123
+ // numbers beneath -103.97 will flush to zero
2124
+ inline static __m512 ggml_v_expf(__m512 x) {
2125
+ const __m512 r = _mm512_set1_ps(0x1.8p23f);
2126
+ const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
2127
+ const __m512 n = _mm512_sub_ps(z, r);
2128
+ const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
2129
+ _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
2130
+ const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
2131
+ const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
2132
+ const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
2133
+ const __m512 u = _mm512_mul_ps(b, b);
2134
+ const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
2135
+ _mm512_set1_ps(0x1.573e2ep-5f)), u,
2136
+ _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
2137
+ _mm512_set1_ps(0x1.fffdb6p-2f))),
2138
+ u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
2139
+ if (_mm512_kortestz(c, c))
2140
+ return _mm512_fmadd_ps(j, k, k);
2141
+ const __m512i g = _mm512_and_si512(
2142
+ _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
2143
+ _mm512_set1_epi32(0x82000000u));
2144
+ const __m512 s1 =
2145
+ _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
2146
+ const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
2147
+ const __mmask16 d =
2148
+ _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
2149
+ return _mm512_mask_blend_ps(
2150
+ d, _mm512_mask_blend_ps(
2151
+ c, _mm512_fmadd_ps(k, j, k),
2152
+ _mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)),
2153
+ _mm512_mul_ps(s1, s1));
2154
+ }
2155
+
2156
+ // computes silu x/(1+exp(-x)) in single precision vector
2157
+ inline static __m512 ggml_v_silu(__m512 x) {
2158
+ const __m512 one = _mm512_set1_ps(1);
2159
+ const __m512 zero = _mm512_setzero_ps();
2160
+ const __m512 neg_x = _mm512_sub_ps(zero, x);
2161
+ const __m512 exp_neg_x = ggml_v_expf(neg_x);
2162
+ const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
2163
+ return _mm512_div_ps(x, one_plus_exp_neg_x);
2164
+ }
2165
+
2166
+ #elif defined(__AVX2__) && defined(__FMA__)
2167
+
2168
+ // adapted from arm limited optimized routine
2169
+ // the maximum error is 1.45358 plus 0.5 ulps
2170
+ // numbers above 88.38 will flush to infinity
2171
+ // numbers beneath -103.97 will flush to zero
2172
+ inline static __m256 ggml_v_expf(__m256 x) {
2173
+ const __m256 r = _mm256_set1_ps(0x1.8p23f);
2174
+ const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
2175
+ const __m256 n = _mm256_sub_ps(z, r);
2176
+ const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
2177
+ _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
2178
+ const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
2179
+ const __m256 k = _mm256_castsi256_ps(
2180
+ _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
2181
+ const __m256i c = _mm256_castps_si256(
2182
+ _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
2183
+ _mm256_set1_ps(126), _CMP_GT_OQ));
2184
+ const __m256 u = _mm256_mul_ps(b, b);
2185
+ const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
2186
+ _mm256_set1_ps(0x1.573e2ep-5f)), u,
2187
+ _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
2188
+ _mm256_set1_ps(0x1.fffdb6p-2f))),
2189
+ u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
2190
+ if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
2191
+ return _mm256_fmadd_ps(j, k, k);
2192
+ const __m256i g = _mm256_and_si256(
2193
+ _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
2194
+ _mm256_set1_epi32(0x82000000u));
2195
+ const __m256 s1 =
2196
+ _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
2197
+ const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
2198
+ const __m256i d = _mm256_castps_si256(
2199
+ _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
2200
+ _mm256_set1_ps(192), _CMP_GT_OQ));
2201
+ return _mm256_or_ps(
2202
+ _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
2203
+ _mm256_andnot_ps(
2204
+ _mm256_castsi256_ps(d),
2205
+ _mm256_or_ps(
2206
+ _mm256_and_ps(_mm256_castsi256_ps(c),
2207
+ _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
2208
+ _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
2209
+ }
2210
+
2211
+ // computes silu x/(1+exp(-x)) in single precision vector
2212
+ inline static __m256 ggml_v_silu(__m256 x) {
2213
+ const __m256 one = _mm256_set1_ps(1);
2214
+ const __m256 zero = _mm256_setzero_ps();
2215
+ const __m256 neg_x = _mm256_sub_ps(zero, x);
2216
+ const __m256 exp_neg_x = ggml_v_expf(neg_x);
2217
+ const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
2218
+ return _mm256_div_ps(x, one_plus_exp_neg_x);
2219
+ }
2220
+
2221
+ #elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
2222
+
2223
+ #if defined(__FMA__)
2224
+ #define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
2225
+ #define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
1911
2226
  #else
1912
- inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
1913
- for (int i = 0; i < n; ++i) {
2227
+ #define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
2228
+ #define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
2229
+ #endif
2230
+
2231
+ // adapted from arm limited optimized routine
2232
+ // the maximum error is 1.45358 plus 0.5 ulps
2233
+ // numbers above 88.38 will flush to infinity
2234
+ // numbers beneath -103.97 will flush to zero
2235
+ inline static __m128 ggml_v_expf(__m128 x) {
2236
+ const __m128 r = _mm_set1_ps(0x1.8p23f);
2237
+ const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
2238
+ const __m128 n = _mm_sub_ps(z, r);
2239
+ const __m128 b =
2240
+ NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
2241
+ const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
2242
+ const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
2243
+ const __m128i c =
2244
+ _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
2245
+ const __m128 u = _mm_mul_ps(b, b);
2246
+ const __m128 j =
2247
+ MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
2248
+ MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
2249
+ u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
2250
+ if (!_mm_movemask_epi8(c))
2251
+ return MADD128(j, k, k);
2252
+ const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
2253
+ _mm_set1_epi32(0x82000000u));
2254
+ const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
2255
+ const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
2256
+ const __m128i d =
2257
+ _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
2258
+ return _mm_or_ps(
2259
+ _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
2260
+ _mm_andnot_ps(_mm_castsi128_ps(d),
2261
+ _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
2262
+ _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
2263
+ }
2264
+
2265
+ // computes silu x/(1+exp(-x)) in single precision vector
2266
+ inline static __m128 ggml_v_silu(__m128 x) {
2267
+ const __m128 one = _mm_set1_ps(1);
2268
+ const __m128 zero = _mm_setzero_ps();
2269
+ const __m128 neg_x = _mm_sub_ps(zero, x);
2270
+ const __m128 exp_neg_x = ggml_v_expf(neg_x);
2271
+ const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
2272
+ return _mm_div_ps(x, one_plus_exp_neg_x);
2273
+ }
2274
+
2275
+ #endif // __ARM_NEON / __AVX2__ / __SSE2__
2276
+
2277
+ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
2278
+ int i = 0;
2279
+ #if defined(__AVX512F__) && defined(__AVX512DQ__)
2280
+ for (; i + 15 < n; i += 16) {
2281
+ _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
2282
+ }
2283
+ #elif defined(__AVX2__) && defined(__FMA__)
2284
+ for (; i + 7 < n; i += 8) {
2285
+ _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
2286
+ }
2287
+ #elif defined(__SSE2__)
2288
+ for (; i + 3 < n; i += 4) {
2289
+ _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
2290
+ }
2291
+ #elif defined(__ARM_NEON)
2292
+ for (; i + 3 < n; i += 4) {
2293
+ vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
2294
+ }
2295
+ #endif
2296
+ for (; i < n; ++i) {
1914
2297
  y[i] = ggml_silu_f32(x[i]);
1915
2298
  }
1916
2299
  }
2300
+
2301
+ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
2302
+ int i = 0;
2303
+ ggml_float sum = 0;
2304
+ #if defined(__AVX512F__) && defined(__AVX512DQ__)
2305
+ for (; i + 15 < n; i += 16) {
2306
+ __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
2307
+ _mm512_set1_ps(max)));
2308
+ _mm512_storeu_ps(y + i, val);
2309
+ sum += (ggml_float)_mm512_reduce_add_ps(val);
2310
+ }
2311
+ #elif defined(__AVX2__) && defined(__FMA__)
2312
+ for (; i + 7 < n; i += 8) {
2313
+ __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
2314
+ _mm256_set1_ps(max)));
2315
+ _mm256_storeu_ps(y + i, val);
2316
+ __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
2317
+ _mm256_castps256_ps128(val));
2318
+ val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
2319
+ val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
2320
+ sum += (ggml_float)_mm_cvtss_f32(val2);
2321
+ }
2322
+ #elif defined(__SSE2__)
2323
+ for (; i + 3 < n; i += 4) {
2324
+ __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
2325
+ _mm_set1_ps(max)));
2326
+ _mm_storeu_ps(y + i, val);
2327
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
2328
+ val = _mm_add_ps(val, _mm_movehl_ps(val, val));
2329
+ val = _mm_add_ss(val, _mm_movehdup_ps(val));
2330
+ #else
2331
+ __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
2332
+ val = _mm_add_ps(val, tmp);
2333
+ tmp = _mm_movehl_ps(tmp, val);
2334
+ val = _mm_add_ss(val, tmp);
1917
2335
  #endif
2336
+ sum += (ggml_float)_mm_cvtss_f32(val);
2337
+ }
2338
+ #elif defined(__ARM_NEON)
2339
+ for (; i + 3 < n; i += 4) {
2340
+ float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
2341
+ vdupq_n_f32(max)));
2342
+ vst1q_f32(y + i, val);
2343
+ sum += (ggml_float)vaddvq_f32(val);
2344
+ }
2345
+ #endif
2346
+ for (; i < n; ++i) {
2347
+ float val = expf(x[i] - max);
2348
+ sum += (ggml_float)val;
2349
+ y[i] = val;
2350
+ }
2351
+ return sum;
2352
+ }
1918
2353
 
1919
2354
  inline static float ggml_silu_backward_f32(float x, float dy) {
1920
2355
  const float s = 1.0f/(1.0f + expf(-x));
1921
2356
  return dy*s*(1.0f + x*(1.0f - s));
1922
2357
  }
1923
2358
 
1924
- #ifdef GGML_SILU_FP16
1925
- inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
1926
- for (int i = 0; i < n; ++i) {
1927
- // we did not use x[i] to compute forward silu but its f16 equivalent
1928
- // take derivative at f16 of x[i]:
1929
- ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
1930
- float usedx = GGML_FP16_TO_FP32(fp16);
1931
- dx[i] = ggml_silu_backward_f32(usedx, dy[i]);
1932
- }
1933
- }
1934
- #else
1935
2359
  inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
1936
2360
  for (int i = 0; i < n; ++i) {
1937
2361
  dx[i] = ggml_silu_backward_f32(x[i], dy[i]);
1938
2362
  }
1939
2363
  }
1940
- #endif
1941
2364
 
1942
2365
  inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
1943
2366
  #ifndef GGML_USE_ACCELERATE
@@ -1967,6 +2390,14 @@ inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_
1967
2390
  *s = sum;
1968
2391
  }
1969
2392
 
2393
+ inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {
2394
+ float sum = 0.0f;
2395
+ for (int i = 0; i < n; ++i) {
2396
+ sum += GGML_BF16_TO_FP32(x[i]);
2397
+ }
2398
+ *s = sum;
2399
+ }
2400
+
1970
2401
  inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
1971
2402
  #ifndef GGML_USE_ACCELERATE
1972
2403
  float max = -INFINITY;
@@ -2045,7 +2476,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2045
2476
  "SOFT_MAX_BACK",
2046
2477
  "ROPE",
2047
2478
  "ROPE_BACK",
2048
- "ALIBI",
2049
2479
  "CLAMP",
2050
2480
  "CONV_TRANSPOSE_1D",
2051
2481
  "IM2COL",
@@ -2087,7 +2517,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2087
2517
  "CROSS_ENTROPY_LOSS_BACK",
2088
2518
  };
2089
2519
 
2090
- static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
2520
+ static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
2091
2521
 
2092
2522
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2093
2523
  "none",
@@ -2136,7 +2566,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2136
2566
  "soft_max_back(x)",
2137
2567
  "rope(x)",
2138
2568
  "rope_back(x)",
2139
- "alibi(x)",
2140
2569
  "clamp(x)",
2141
2570
  "conv_transpose_1d(x)",
2142
2571
  "im2col(x)",
@@ -2178,7 +2607,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2178
2607
  "cross_entropy_loss_back(x,y)",
2179
2608
  };
2180
2609
 
2181
- static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
2610
+ static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
2182
2611
 
2183
2612
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
2184
2613
 
@@ -2191,6 +2620,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
2191
2620
  "TANH",
2192
2621
  "ELU",
2193
2622
  "RELU",
2623
+ "SIGMOID",
2194
2624
  "GELU",
2195
2625
  "GELU_QUICK",
2196
2626
  "SILU",
@@ -2198,7 +2628,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
2198
2628
  "HARDSIGMOID",
2199
2629
  };
2200
2630
 
2201
- static_assert(GGML_UNARY_OP_COUNT == 12, "GGML_UNARY_OP_COUNT != 12");
2631
+ static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
2202
2632
 
2203
2633
 
2204
2634
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -2240,32 +2670,6 @@ static void ggml_setup_op_has_task_pass(void) {
2240
2670
  }
2241
2671
  }
2242
2672
 
2243
- //
2244
- // ggml context
2245
- //
2246
-
2247
- struct ggml_context {
2248
- size_t mem_size;
2249
- void * mem_buffer;
2250
- bool mem_buffer_owned;
2251
- bool no_alloc;
2252
- bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
2253
-
2254
- int n_objects;
2255
-
2256
- struct ggml_object * objects_begin;
2257
- struct ggml_object * objects_end;
2258
-
2259
- struct ggml_scratch scratch;
2260
- struct ggml_scratch scratch_save;
2261
- };
2262
-
2263
- struct ggml_context_container {
2264
- bool used;
2265
-
2266
- struct ggml_context context;
2267
- };
2268
-
2269
2673
  //
2270
2674
  // NUMA support
2271
2675
  //
@@ -2377,7 +2781,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
2377
2781
  // figure out which node we're on
2378
2782
  uint current_cpu;
2379
2783
  int getcpu_ret = 0;
2380
- #if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28)
2784
+ #if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__)
2381
2785
  getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node);
2382
2786
  #else
2383
2787
  // old glibc doesn't have a wrapper for this call. Fall back on direct syscall
@@ -2588,6 +2992,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
2588
2992
  switch (ftype) {
2589
2993
  case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break;
2590
2994
  case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break;
2995
+ case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break;
2591
2996
  case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break;
2592
2997
  case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break;
2593
2998
  case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
@@ -2678,6 +3083,16 @@ bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor
2678
3083
  (t0->ne[3] == t1->ne[3] );
2679
3084
  }
2680
3085
 
3086
+ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
3087
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3088
+
3089
+ return
3090
+ (t0->nb[0] == t1->nb[0] ) &&
3091
+ (t0->nb[1] == t1->nb[1] ) &&
3092
+ (t0->nb[2] == t1->nb[2] ) &&
3093
+ (t0->nb[3] == t1->nb[3] );
3094
+ }
3095
+
2681
3096
  // check if t1 can be represented as a repeatition of t0
2682
3097
  static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
2683
3098
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
@@ -2729,15 +3144,14 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
2729
3144
  {
2730
3145
  const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
2731
3146
 
2732
- ggml_fp16_t ii;
2733
3147
  for (int i = 0; i < (1 << 16); ++i) {
2734
- uint16_t ui = i;
2735
- memcpy(&ii, &ui, sizeof(ii));
2736
- const float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
3148
+ union {
3149
+ uint16_t u16;
3150
+ ggml_fp16_t fp16;
3151
+ } u = {i};
3152
+ float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
2737
3153
  ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
2738
3154
  ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
2739
- ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
2740
- ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
2741
3155
  }
2742
3156
 
2743
3157
  const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
@@ -3021,6 +3435,12 @@ static struct ggml_tensor * ggml_new_tensor_impl(
3021
3435
 
3022
3436
  struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
3023
3437
 
3438
+ #ifdef __clang__
3439
+ // temporary until ggml_tensor::backend is removed
3440
+ #pragma clang diagnostic push
3441
+ #pragma clang diagnostic ignored "-Wdeprecated-declarations"
3442
+ #endif
3443
+
3024
3444
  *result = (struct ggml_tensor) {
3025
3445
  /*.type =*/ type,
3026
3446
  /*.backend =*/ GGML_BACKEND_TYPE_CPU,
@@ -3043,6 +3463,10 @@ static struct ggml_tensor * ggml_new_tensor_impl(
3043
3463
  /*.padding =*/ { 0 },
3044
3464
  };
3045
3465
 
3466
+ #ifdef __clang__
3467
+ #pragma clang diagnostic pop
3468
+ #endif
3469
+
3046
3470
  // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
3047
3471
  //ggml_assert_aligned(result->data);
3048
3472
 
@@ -3201,6 +3625,13 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
3201
3625
  ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
3202
3626
  }
3203
3627
  } break;
3628
+ case GGML_TYPE_BF16:
3629
+ {
3630
+ assert(tensor->nb[0] == sizeof(ggml_fp16_t));
3631
+ for (int i = 0; i < n; i++) {
3632
+ ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
3633
+ }
3634
+ } break;
3204
3635
  case GGML_TYPE_F32:
3205
3636
  {
3206
3637
  assert(tensor->nb[0] == sizeof(float));
@@ -3253,6 +3684,13 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3253
3684
  ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
3254
3685
  }
3255
3686
  } break;
3687
+ case GGML_TYPE_BF16:
3688
+ {
3689
+ assert(tensor->nb[0] == sizeof(ggml_bf16_t));
3690
+ for (int i = 0; i < n; i++) {
3691
+ ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
3692
+ }
3693
+ } break;
3256
3694
  case GGML_TYPE_F32:
3257
3695
  {
3258
3696
  assert(tensor->nb[0] == sizeof(float));
@@ -3320,6 +3758,11 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3320
3758
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3321
3759
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
3322
3760
  }
3761
+ case GGML_TYPE_BF16:
3762
+ {
3763
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3764
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
3765
+ }
3323
3766
  case GGML_TYPE_F32:
3324
3767
  {
3325
3768
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3362,6 +3805,11 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3362
3805
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3363
3806
  ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
3364
3807
  } break;
3808
+ case GGML_TYPE_BF16:
3809
+ {
3810
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3811
+ ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
3812
+ } break;
3365
3813
  case GGML_TYPE_F32:
3366
3814
  {
3367
3815
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3385,6 +3833,8 @@ int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i
3385
3833
  return ((int32_t *) data)[0];
3386
3834
  case GGML_TYPE_F16:
3387
3835
  return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
3836
+ case GGML_TYPE_BF16:
3837
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
3388
3838
  case GGML_TYPE_F32:
3389
3839
  return ((float *) data)[0];
3390
3840
  default:
@@ -3413,6 +3863,10 @@ void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3413
3863
  {
3414
3864
  ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
3415
3865
  } break;
3866
+ case GGML_TYPE_BF16:
3867
+ {
3868
+ ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
3869
+ } break;
3416
3870
  case GGML_TYPE_F32:
3417
3871
  {
3418
3872
  ((float *)(data))[0] = value;
@@ -3451,6 +3905,11 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3451
3905
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3452
3906
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
3453
3907
  }
3908
+ case GGML_TYPE_BF16:
3909
+ {
3910
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3911
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
3912
+ }
3454
3913
  case GGML_TYPE_F32:
3455
3914
  {
3456
3915
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3493,6 +3952,11 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
3493
3952
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3494
3953
  ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
3495
3954
  } break;
3955
+ case GGML_TYPE_BF16:
3956
+ {
3957
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3958
+ ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
3959
+ } break;
3496
3960
  case GGML_TYPE_F32:
3497
3961
  {
3498
3962
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3516,6 +3980,8 @@ float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3516
3980
  return ((int32_t *) data)[0];
3517
3981
  case GGML_TYPE_F16:
3518
3982
  return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
3983
+ case GGML_TYPE_BF16:
3984
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
3519
3985
  case GGML_TYPE_F32:
3520
3986
  return ((float *) data)[0];
3521
3987
  default:
@@ -3544,6 +4010,10 @@ void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3544
4010
  {
3545
4011
  ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
3546
4012
  } break;
4013
+ case GGML_TYPE_BF16:
4014
+ {
4015
+ ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
4016
+ } break;
3547
4017
  case GGML_TYPE_F32:
3548
4018
  {
3549
4019
  ((float *)(data))[0] = value;
@@ -3738,7 +4208,11 @@ static struct ggml_tensor * ggml_add_cast_impl(
3738
4208
  // TODO: support less-strict constraint
3739
4209
  // GGML_ASSERT(ggml_can_repeat(b, a));
3740
4210
  GGML_ASSERT(ggml_can_repeat_rows(b, a));
3741
- GGML_ASSERT(ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16); // currently only supported for quantized input and f16
4211
+
4212
+ // currently only supported for quantized input and f16
4213
+ GGML_ASSERT(ggml_is_quantized(a->type) ||
4214
+ a->type == GGML_TYPE_F16 ||
4215
+ a->type == GGML_TYPE_BF16);
3742
4216
 
3743
4217
  bool is_node = false;
3744
4218
 
@@ -4371,6 +4845,20 @@ struct ggml_tensor * ggml_leaky_relu(
4371
4845
  return result;
4372
4846
  }
4373
4847
 
4848
+ // ggml_sigmoid
4849
+
4850
+ struct ggml_tensor * ggml_sigmoid(
4851
+ struct ggml_context * ctx,
4852
+ struct ggml_tensor * a) {
4853
+ return ggml_unary(ctx, a, GGML_UNARY_OP_SIGMOID);
4854
+ }
4855
+
4856
+ struct ggml_tensor * ggml_sigmoid_inplace(
4857
+ struct ggml_context * ctx,
4858
+ struct ggml_tensor * a) {
4859
+ return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID);
4860
+ }
4861
+
4374
4862
  // ggml_gelu
4375
4863
 
4376
4864
  struct ggml_tensor * ggml_gelu(
@@ -5454,7 +5942,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
5454
5942
  struct ggml_context * ctx,
5455
5943
  struct ggml_tensor * a,
5456
5944
  struct ggml_tensor * mask,
5457
- struct ggml_tensor * pos,
5458
5945
  float scale,
5459
5946
  float max_bias,
5460
5947
  bool inplace) {
@@ -5468,18 +5955,8 @@ static struct ggml_tensor * ggml_soft_max_impl(
5468
5955
  GGML_ASSERT(mask->ne[1] >= a->ne[1]);
5469
5956
  }
5470
5957
 
5471
- if (pos) {
5472
- GGML_ASSERT(ggml_is_vector(pos));
5473
- GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
5474
- GGML_ASSERT(pos->ne[0] == a->ne[0]);
5475
- }
5476
-
5477
- if (pos && mask) {
5478
- GGML_ASSERT(pos->type == mask->type);
5479
- }
5480
-
5481
5958
  if (max_bias > 0.0f) {
5482
- GGML_ASSERT(pos);
5959
+ GGML_ASSERT(mask);
5483
5960
  }
5484
5961
 
5485
5962
  bool is_node = false;
@@ -5497,7 +5974,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
5497
5974
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5498
5975
  result->src[0] = a;
5499
5976
  result->src[1] = mask;
5500
- result->src[2] = pos;
5501
5977
 
5502
5978
  return result;
5503
5979
  }
@@ -5505,23 +5981,22 @@ static struct ggml_tensor * ggml_soft_max_impl(
5505
5981
  struct ggml_tensor * ggml_soft_max(
5506
5982
  struct ggml_context * ctx,
5507
5983
  struct ggml_tensor * a) {
5508
- return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, false);
5984
+ return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false);
5509
5985
  }
5510
5986
 
5511
5987
  struct ggml_tensor * ggml_soft_max_inplace(
5512
5988
  struct ggml_context * ctx,
5513
5989
  struct ggml_tensor * a) {
5514
- return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, true);
5990
+ return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true);
5515
5991
  }
5516
5992
 
5517
5993
  struct ggml_tensor * ggml_soft_max_ext(
5518
5994
  struct ggml_context * ctx,
5519
5995
  struct ggml_tensor * a,
5520
5996
  struct ggml_tensor * mask,
5521
- struct ggml_tensor * pos,
5522
5997
  float scale,
5523
5998
  float max_bias) {
5524
- return ggml_soft_max_impl(ctx, a, mask, pos, scale, max_bias, false);
5999
+ return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
5525
6000
  }
5526
6001
 
5527
6002
  // ggml_soft_max_back
@@ -5736,37 +6211,6 @@ struct ggml_tensor * ggml_rope_back(
5736
6211
  return result;
5737
6212
  }
5738
6213
 
5739
- // ggml_alibi
5740
-
5741
- struct ggml_tensor * ggml_alibi(
5742
- struct ggml_context * ctx,
5743
- struct ggml_tensor * a,
5744
- int n_past,
5745
- int n_head,
5746
- float bias_max) {
5747
- GGML_ASSERT(n_past >= 0);
5748
- bool is_node = false;
5749
-
5750
- if (a->grad) {
5751
- GGML_ASSERT(false); // TODO: implement backward
5752
- is_node = true;
5753
- }
5754
-
5755
- // TODO: when implement backward, fix this:
5756
- //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5757
- struct ggml_tensor * result = ggml_view_tensor(ctx, a);
5758
-
5759
- int32_t op_params[3] = { n_past, n_head };
5760
- memcpy(op_params + 2, &bias_max, sizeof(float));
5761
- ggml_set_op_params(result, op_params, sizeof(op_params));
5762
-
5763
- result->op = GGML_OP_ALIBI;
5764
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5765
- result->src[0] = a;
5766
-
5767
- return result;
5768
- }
5769
-
5770
6214
  // ggml_clamp
5771
6215
 
5772
6216
  struct ggml_tensor * ggml_clamp(
@@ -6116,7 +6560,10 @@ struct ggml_tensor * ggml_pool_2d(
6116
6560
  static struct ggml_tensor * ggml_upscale_impl(
6117
6561
  struct ggml_context * ctx,
6118
6562
  struct ggml_tensor * a,
6119
- int scale_factor) {
6563
+ int ne0,
6564
+ int ne1,
6565
+ int ne2,
6566
+ int ne3) {
6120
6567
  bool is_node = false;
6121
6568
 
6122
6569
  if (a->grad) {
@@ -6124,19 +6571,45 @@ static struct ggml_tensor * ggml_upscale_impl(
6124
6571
  is_node = true;
6125
6572
  }
6126
6573
 
6574
+ GGML_ASSERT(a->ne[0] <= ne0);
6575
+ GGML_ASSERT(a->ne[1] <= ne1);
6576
+ GGML_ASSERT(a->ne[2] <= ne2);
6577
+ GGML_ASSERT(a->ne[3] <= ne3);
6578
+
6127
6579
  struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
6128
- a->ne[0] * scale_factor,
6129
- a->ne[1] * scale_factor,
6130
- a->ne[2], a->ne[3]);
6580
+ ne0,
6581
+ ne1,
6582
+ ne2,
6583
+ ne3
6584
+ );
6131
6585
 
6132
6586
  result->op = GGML_OP_UPSCALE;
6133
- result->op_params[0] = scale_factor;
6587
+
6134
6588
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6135
6589
  result->src[0] = a;
6136
6590
 
6137
6591
  return result;
6138
6592
  }
6139
6593
 
6594
+ struct ggml_tensor * ggml_upscale(
6595
+ struct ggml_context * ctx,
6596
+ struct ggml_tensor * a,
6597
+ int scale_factor) {
6598
+ return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
6599
+ }
6600
+
6601
+ struct ggml_tensor * ggml_upscale_ext(
6602
+ struct ggml_context * ctx,
6603
+ struct ggml_tensor * a,
6604
+ int ne0,
6605
+ int ne1,
6606
+ int ne2,
6607
+ int ne3) {
6608
+ return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
6609
+ }
6610
+
6611
+ // ggml_pad
6612
+
6140
6613
  struct ggml_tensor * ggml_pad(
6141
6614
  struct ggml_context * ctx,
6142
6615
  struct ggml_tensor * a,
@@ -6161,12 +6634,7 @@ struct ggml_tensor * ggml_pad(
6161
6634
  return result;
6162
6635
  }
6163
6636
 
6164
- struct ggml_tensor * ggml_upscale(
6165
- struct ggml_context * ctx,
6166
- struct ggml_tensor * a,
6167
- int scale_factor) {
6168
- return ggml_upscale_impl(ctx, a, scale_factor);
6169
- }
6637
+ // ggml_arange
6170
6638
 
6171
6639
  struct ggml_tensor * ggml_arange(
6172
6640
  struct ggml_context * ctx,
@@ -6188,6 +6656,8 @@ struct ggml_tensor * ggml_arange(
6188
6656
  return result;
6189
6657
  }
6190
6658
 
6659
+ // ggml_timestep_embedding
6660
+
6191
6661
  struct ggml_tensor * ggml_timestep_embedding(
6192
6662
  struct ggml_context * ctx,
6193
6663
  struct ggml_tensor * timesteps,
@@ -6294,9 +6764,11 @@ struct ggml_tensor * ggml_flash_attn_ext(
6294
6764
  struct ggml_tensor * k,
6295
6765
  struct ggml_tensor * v,
6296
6766
  struct ggml_tensor * mask,
6297
- float scale) {
6767
+ float scale,
6768
+ float max_bias) {
6298
6769
  GGML_ASSERT(ggml_can_mul_mat(k, q));
6299
6770
  // TODO: check if vT can be multiplied by (k*qT)
6771
+
6300
6772
  if (mask) {
6301
6773
  GGML_ASSERT(ggml_is_contiguous(mask));
6302
6774
  GGML_ASSERT(mask->ne[2] == 1);
@@ -6306,6 +6778,10 @@ struct ggml_tensor * ggml_flash_attn_ext(
6306
6778
  //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
6307
6779
  }
6308
6780
 
6781
+ if (max_bias > 0.0f) {
6782
+ GGML_ASSERT(mask);
6783
+ }
6784
+
6309
6785
  bool is_node = false;
6310
6786
 
6311
6787
  if (q->grad || k->grad || v->grad) {
@@ -6316,7 +6792,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
6316
6792
  int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
6317
6793
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
6318
6794
 
6319
- float params[] = { scale };
6795
+ float params[] = { scale, max_bias };
6320
6796
  ggml_set_op_params(result, params, sizeof(params));
6321
6797
 
6322
6798
  result->op = GGML_OP_FLASH_ATTN_EXT;
@@ -6336,7 +6812,7 @@ void ggml_flash_attn_ext_set_prec(
6336
6812
 
6337
6813
  const int32_t prec_i32 = (int32_t) prec;
6338
6814
 
6339
- ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
6815
+ ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
6340
6816
  }
6341
6817
 
6342
6818
  // ggml_flash_ff
@@ -7215,8 +7691,8 @@ static void ggml_compute_forward_dup_same_cont(
7215
7691
  ((char *) src0->data + ie0*nb00),
7216
7692
  (ie1 - ie0) * ggml_type_size(src0->type));
7217
7693
  }
7218
-
7219
7694
  }
7695
+
7220
7696
  static void ggml_compute_forward_dup_f16(
7221
7697
  const struct ggml_compute_params * params,
7222
7698
  struct ggml_tensor * dst) {
@@ -7490,6 +7966,366 @@ static void ggml_compute_forward_dup_f16(
7490
7966
  }
7491
7967
  }
7492
7968
 
7969
+ static void ggml_compute_forward_dup_bf16(
7970
+ const struct ggml_compute_params * params,
7971
+ struct ggml_tensor * dst) {
7972
+
7973
+ const struct ggml_tensor * src0 = dst->src[0];
7974
+
7975
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
7976
+
7977
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
7978
+ return;
7979
+ }
7980
+
7981
+ GGML_TENSOR_UNARY_OP_LOCALS
7982
+
7983
+ const int ith = params->ith; // thread index
7984
+ const int nth = params->nth; // number of threads
7985
+
7986
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
7987
+ ggml_compute_forward_dup_same_cont(params, dst);
7988
+ return;
7989
+ }
7990
+
7991
+ // parallelize by rows
7992
+ const int nr = ne01;
7993
+ // number of rows per thread
7994
+ const int dr = (nr + nth - 1) / nth;
7995
+ // row range for this thread
7996
+ const int ir0 = dr * ith;
7997
+ const int ir1 = MIN(ir0 + dr, nr);
7998
+
7999
+ if (src0->type == dst->type &&
8000
+ ne00 == ne0 &&
8001
+ nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
8002
+ // copy by rows
8003
+ const size_t rs = ne00*nb00;
8004
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8005
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8006
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8007
+ memcpy(
8008
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
8009
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
8010
+ rs);
8011
+ }
8012
+ }
8013
+ }
8014
+ return;
8015
+ }
8016
+
8017
+ // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
8018
+
8019
+ if (ggml_is_contiguous(dst)) {
8020
+ if (nb00 == sizeof(ggml_bf16_t)) {
8021
+ if (dst->type == GGML_TYPE_BF16) {
8022
+ size_t id = 0;
8023
+ const size_t rs = ne00 * nb00;
8024
+ char * dst_ptr = (char *) dst->data;
8025
+
8026
+ for (int i03 = 0; i03 < ne03; i03++) {
8027
+ for (int i02 = 0; i02 < ne02; i02++) {
8028
+ id += rs * ir0;
8029
+ for (int i01 = ir0; i01 < ir1; i01++) {
8030
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
8031
+ memcpy(dst_ptr + id, src0_ptr, rs);
8032
+ id += rs;
8033
+ }
8034
+ id += rs * (ne01 - ir1);
8035
+ }
8036
+ }
8037
+ } else if (dst->type == GGML_TYPE_F16) {
8038
+ size_t id = 0;
8039
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
8040
+
8041
+ for (int i03 = 0; i03 < ne03; i03++) {
8042
+ for (int i02 = 0; i02 < ne02; i02++) {
8043
+ id += ne00 * ir0;
8044
+ for (int i01 = ir0; i01 < ir1; i01++) {
8045
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
8046
+ for (int i00 = 0; i00 < ne00; i00++) {
8047
+ dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
8048
+ id++;
8049
+ }
8050
+ }
8051
+ id += ne00 * (ne01 - ir1);
8052
+ }
8053
+ }
8054
+ } else if (dst->type == GGML_TYPE_F32) {
8055
+ size_t id = 0;
8056
+ float * dst_ptr = (float *) dst->data;
8057
+
8058
+ for (int i03 = 0; i03 < ne03; i03++) {
8059
+ for (int i02 = 0; i02 < ne02; i02++) {
8060
+ id += ne00 * ir0;
8061
+ for (int i01 = ir0; i01 < ir1; i01++) {
8062
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
8063
+ for (int i00 = 0; i00 < ne00; i00++) {
8064
+ dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
8065
+ id++;
8066
+ }
8067
+ }
8068
+ id += ne00 * (ne01 - ir1);
8069
+ }
8070
+ }
8071
+ } else if (type_traits[dst->type].from_float) {
8072
+ ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
8073
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
8074
+
8075
+ size_t id = 0;
8076
+ size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
8077
+ char * dst_ptr = (char *) dst->data;
8078
+
8079
+ for (int i03 = 0; i03 < ne03; i03++) {
8080
+ for (int i02 = 0; i02 < ne02; i02++) {
8081
+ id += rs * ir0;
8082
+ for (int i01 = ir0; i01 < ir1; i01++) {
8083
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
8084
+
8085
+ for (int i00 = 0; i00 < ne00; i00++) {
8086
+ src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
8087
+ }
8088
+
8089
+ quantize_row_q(src0_f32, dst_ptr + id, ne00);
8090
+ id += rs;
8091
+ }
8092
+ id += rs * (ne01 - ir1);
8093
+ }
8094
+ }
8095
+ } else {
8096
+ GGML_ASSERT(false); // TODO: implement
8097
+ }
8098
+ } else {
8099
+ //printf("%s: this is not optimal - fix me\n", __func__);
8100
+
8101
+ if (dst->type == GGML_TYPE_F32) {
8102
+ size_t id = 0;
8103
+ float * dst_ptr = (float *) dst->data;
8104
+
8105
+ for (int i03 = 0; i03 < ne03; i03++) {
8106
+ for (int i02 = 0; i02 < ne02; i02++) {
8107
+ id += ne00 * ir0;
8108
+ for (int i01 = ir0; i01 < ir1; i01++) {
8109
+ for (int i00 = 0; i00 < ne00; i00++) {
8110
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8111
+
8112
+ dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
8113
+ id++;
8114
+ }
8115
+ }
8116
+ id += ne00 * (ne01 - ir1);
8117
+ }
8118
+ }
8119
+ } else if (dst->type == GGML_TYPE_BF16) {
8120
+ size_t id = 0;
8121
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
8122
+
8123
+ for (int i03 = 0; i03 < ne03; i03++) {
8124
+ for (int i02 = 0; i02 < ne02; i02++) {
8125
+ id += ne00 * ir0;
8126
+ for (int i01 = ir0; i01 < ir1; i01++) {
8127
+ for (int i00 = 0; i00 < ne00; i00++) {
8128
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8129
+
8130
+ dst_ptr[id] = *src0_ptr;
8131
+ id++;
8132
+ }
8133
+ }
8134
+ id += ne00 * (ne01 - ir1);
8135
+ }
8136
+ }
8137
+ } else if (dst->type == GGML_TYPE_F16) {
8138
+ size_t id = 0;
8139
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
8140
+
8141
+ for (int i03 = 0; i03 < ne03; i03++) {
8142
+ for (int i02 = 0; i02 < ne02; i02++) {
8143
+ id += ne00 * ir0;
8144
+ for (int i01 = ir0; i01 < ir1; i01++) {
8145
+ for (int i00 = 0; i00 < ne00; i00++) {
8146
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8147
+
8148
+ dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
8149
+ id++;
8150
+ }
8151
+ }
8152
+ id += ne00 * (ne01 - ir1);
8153
+ }
8154
+ }
8155
+ } else {
8156
+ GGML_ASSERT(false); // TODO: implement
8157
+ }
8158
+ }
8159
+ return;
8160
+ }
8161
+
8162
+ // dst counters
8163
+ int64_t i10 = 0;
8164
+ int64_t i11 = 0;
8165
+ int64_t i12 = 0;
8166
+ int64_t i13 = 0;
8167
+
8168
+ if (dst->type == GGML_TYPE_BF16) {
8169
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8170
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8171
+ i10 += ne00 * ir0;
8172
+ while (i10 >= ne0) {
8173
+ i10 -= ne0;
8174
+ if (++i11 == ne1) {
8175
+ i11 = 0;
8176
+ if (++i12 == ne2) {
8177
+ i12 = 0;
8178
+ if (++i13 == ne3) {
8179
+ i13 = 0;
8180
+ }
8181
+ }
8182
+ }
8183
+ }
8184
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8185
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8186
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8187
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8188
+
8189
+ memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
8190
+
8191
+ if (++i10 == ne00) {
8192
+ i10 = 0;
8193
+ if (++i11 == ne01) {
8194
+ i11 = 0;
8195
+ if (++i12 == ne02) {
8196
+ i12 = 0;
8197
+ if (++i13 == ne03) {
8198
+ i13 = 0;
8199
+ }
8200
+ }
8201
+ }
8202
+ }
8203
+ }
8204
+ }
8205
+ i10 += ne00 * (ne01 - ir1);
8206
+ while (i10 >= ne0) {
8207
+ i10 -= ne0;
8208
+ if (++i11 == ne1) {
8209
+ i11 = 0;
8210
+ if (++i12 == ne2) {
8211
+ i12 = 0;
8212
+ if (++i13 == ne3) {
8213
+ i13 = 0;
8214
+ }
8215
+ }
8216
+ }
8217
+ }
8218
+ }
8219
+ }
8220
+ } else if (dst->type == GGML_TYPE_F16) {
8221
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8222
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8223
+ i10 += ne00 * ir0;
8224
+ while (i10 >= ne0) {
8225
+ i10 -= ne0;
8226
+ if (++i11 == ne1) {
8227
+ i11 = 0;
8228
+ if (++i12 == ne2) {
8229
+ i12 = 0;
8230
+ if (++i13 == ne3) {
8231
+ i13 = 0;
8232
+ }
8233
+ }
8234
+ }
8235
+ }
8236
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8237
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8238
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8239
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8240
+
8241
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
8242
+
8243
+ if (++i10 == ne0) {
8244
+ i10 = 0;
8245
+ if (++i11 == ne1) {
8246
+ i11 = 0;
8247
+ if (++i12 == ne2) {
8248
+ i12 = 0;
8249
+ if (++i13 == ne3) {
8250
+ i13 = 0;
8251
+ }
8252
+ }
8253
+ }
8254
+ }
8255
+ }
8256
+ }
8257
+ i10 += ne00 * (ne01 - ir1);
8258
+ while (i10 >= ne0) {
8259
+ i10 -= ne0;
8260
+ if (++i11 == ne1) {
8261
+ i11 = 0;
8262
+ if (++i12 == ne2) {
8263
+ i12 = 0;
8264
+ if (++i13 == ne3) {
8265
+ i13 = 0;
8266
+ }
8267
+ }
8268
+ }
8269
+ }
8270
+ }
8271
+ }
8272
+ } else if (dst->type == GGML_TYPE_F32) {
8273
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8274
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8275
+ i10 += ne00 * ir0;
8276
+ while (i10 >= ne0) {
8277
+ i10 -= ne0;
8278
+ if (++i11 == ne1) {
8279
+ i11 = 0;
8280
+ if (++i12 == ne2) {
8281
+ i12 = 0;
8282
+ if (++i13 == ne3) {
8283
+ i13 = 0;
8284
+ }
8285
+ }
8286
+ }
8287
+ }
8288
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8289
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8290
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8291
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8292
+
8293
+ *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
8294
+
8295
+ if (++i10 == ne0) {
8296
+ i10 = 0;
8297
+ if (++i11 == ne1) {
8298
+ i11 = 0;
8299
+ if (++i12 == ne2) {
8300
+ i12 = 0;
8301
+ if (++i13 == ne3) {
8302
+ i13 = 0;
8303
+ }
8304
+ }
8305
+ }
8306
+ }
8307
+ }
8308
+ }
8309
+ i10 += ne00 * (ne01 - ir1);
8310
+ while (i10 >= ne0) {
8311
+ i10 -= ne0;
8312
+ if (++i11 == ne1) {
8313
+ i11 = 0;
8314
+ if (++i12 == ne2) {
8315
+ i12 = 0;
8316
+ if (++i13 == ne3) {
8317
+ i13 = 0;
8318
+ }
8319
+ }
8320
+ }
8321
+ }
8322
+ }
8323
+ }
8324
+ } else {
8325
+ GGML_ASSERT(false); // TODO: implement
8326
+ }
8327
+ }
8328
+
7493
8329
  static void ggml_compute_forward_dup_f32(
7494
8330
  const struct ggml_compute_params * params,
7495
8331
  struct ggml_tensor * dst) {
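The BF16 duplication paths added above convert elements with GGML_BF16_TO_FP32 and GGML_FP32_TO_BF16. As a rough standalone sketch of what such a conversion involves (bf16 keeps the upper 16 bits of an IEEE-754 binary32 value), assuming plain truncation on the narrowing side; the `bf16_t` type and helpers below are local to the example and not ggml API, and the real macros may additionally round or special-case NaN:

```c
#include <stdint.h>
#include <stdio.h>
#include <string.h>

typedef struct { uint16_t bits; } bf16_t; // illustrative stand-in, not ggml_bf16_t

static float bf16_to_f32(bf16_t h) {
    uint32_t u = (uint32_t) h.bits << 16; // stored bits become the high half of a float
    float f;
    memcpy(&f, &u, sizeof(f));
    return f;
}

static bf16_t f32_to_bf16(float f) {
    uint32_t u;
    memcpy(&u, &f, sizeof(u));
    bf16_t h = { (uint16_t) (u >> 16) }; // plain truncation of the low 16 bits
    return h;
}

int main(void) {
    bf16_t h = f32_to_bf16(3.14159f);
    printf("%f\n", (double) bf16_to_f32(h)); // prints 3.140625 (reduced precision)
    return 0;
}
```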
@@ -7596,43 +8432,113 @@ static void ggml_compute_forward_dup_f32(
7596
8432
  id++;
7597
8433
  }
7598
8434
  }
7599
- id += ne00 * (ne01 - ir1);
8435
+ id += ne00 * (ne01 - ir1);
8436
+ }
8437
+ }
8438
+ } else if (dst->type == GGML_TYPE_F16) {
8439
+ size_t id = 0;
8440
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
8441
+
8442
+ for (int i03 = 0; i03 < ne03; i03++) {
8443
+ for (int i02 = 0; i02 < ne02; i02++) {
8444
+ id += ne00 * ir0;
8445
+ for (int i01 = ir0; i01 < ir1; i01++) {
8446
+ for (int i00 = 0; i00 < ne00; i00++) {
8447
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8448
+
8449
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
8450
+ id++;
8451
+ }
8452
+ }
8453
+ id += ne00 * (ne01 - ir1);
8454
+ }
8455
+ }
8456
+ } else if (dst->type == GGML_TYPE_BF16) {
8457
+ size_t id = 0;
8458
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
8459
+
8460
+ for (int i03 = 0; i03 < ne03; i03++) {
8461
+ for (int i02 = 0; i02 < ne02; i02++) {
8462
+ id += ne00 * ir0;
8463
+ for (int i01 = ir0; i01 < ir1; i01++) {
8464
+ for (int i00 = 0; i00 < ne00; i00++) {
8465
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8466
+
8467
+ dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
8468
+ id++;
8469
+ }
8470
+ }
8471
+ id += ne00 * (ne01 - ir1);
8472
+ }
8473
+ }
8474
+ } else {
8475
+ GGML_ASSERT(false); // TODO: implement
8476
+ }
8477
+ }
8478
+
8479
+ return;
8480
+ }
8481
+
8482
+ // dst counters
8483
+
8484
+ int64_t i10 = 0;
8485
+ int64_t i11 = 0;
8486
+ int64_t i12 = 0;
8487
+ int64_t i13 = 0;
8488
+
8489
+ if (dst->type == GGML_TYPE_F32) {
8490
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8491
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8492
+ i10 += ne00 * ir0;
8493
+ while (i10 >= ne0) {
8494
+ i10 -= ne0;
8495
+ if (++i11 == ne1) {
8496
+ i11 = 0;
8497
+ if (++i12 == ne2) {
8498
+ i12 = 0;
8499
+ if (++i13 == ne3) {
8500
+ i13 = 0;
8501
+ }
8502
+ }
8503
+ }
8504
+ }
8505
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8506
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8507
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8508
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8509
+
8510
+ memcpy(dst_ptr, src0_ptr, sizeof(float));
8511
+
8512
+ if (++i10 == ne0) {
8513
+ i10 = 0;
8514
+ if (++i11 == ne1) {
8515
+ i11 = 0;
8516
+ if (++i12 == ne2) {
8517
+ i12 = 0;
8518
+ if (++i13 == ne3) {
8519
+ i13 = 0;
8520
+ }
8521
+ }
8522
+ }
8523
+ }
7600
8524
  }
7601
8525
  }
7602
- } else if (dst->type == GGML_TYPE_F16) {
7603
- size_t id = 0;
7604
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
7605
-
7606
- for (int i03 = 0; i03 < ne03; i03++) {
7607
- for (int i02 = 0; i02 < ne02; i02++) {
7608
- id += ne00 * ir0;
7609
- for (int i01 = ir0; i01 < ir1; i01++) {
7610
- for (int i00 = 0; i00 < ne00; i00++) {
7611
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7612
-
7613
- dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
7614
- id++;
8526
+ i10 += ne00 * (ne01 - ir1);
8527
+ while (i10 >= ne0) {
8528
+ i10 -= ne0;
8529
+ if (++i11 == ne1) {
8530
+ i11 = 0;
8531
+ if (++i12 == ne2) {
8532
+ i12 = 0;
8533
+ if (++i13 == ne3) {
8534
+ i13 = 0;
7615
8535
  }
7616
8536
  }
7617
- id += ne00 * (ne01 - ir1);
7618
8537
  }
7619
8538
  }
7620
- } else {
7621
- GGML_ASSERT(false); // TODO: implement
7622
8539
  }
7623
8540
  }
7624
-
7625
- return;
7626
- }
7627
-
7628
- // dst counters
7629
-
7630
- int64_t i10 = 0;
7631
- int64_t i11 = 0;
7632
- int64_t i12 = 0;
7633
- int64_t i13 = 0;
7634
-
7635
- if (dst->type == GGML_TYPE_F32) {
8541
+ } else if (dst->type == GGML_TYPE_F16) {
7636
8542
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7637
8543
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7638
8544
  i10 += ne00 * ir0;
@@ -7653,7 +8559,7 @@ static void ggml_compute_forward_dup_f32(
7653
8559
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7654
8560
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
7655
8561
 
7656
- memcpy(dst_ptr, src0_ptr, sizeof(float));
8562
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
7657
8563
 
7658
8564
  if (++i10 == ne0) {
7659
8565
  i10 = 0;
@@ -7684,7 +8590,7 @@ static void ggml_compute_forward_dup_f32(
7684
8590
  }
7685
8591
  }
7686
8592
  }
7687
- } else if (dst->type == GGML_TYPE_F16) {
8593
+ } else if (dst->type == GGML_TYPE_BF16) {
7688
8594
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7689
8595
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7690
8596
  i10 += ne00 * ir0;
@@ -7705,7 +8611,7 @@ static void ggml_compute_forward_dup_f32(
7705
8611
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7706
8612
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
7707
8613
 
7708
- *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
8614
+ *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
7709
8615
 
7710
8616
  if (++i10 == ne0) {
7711
8617
  i10 = 0;
@@ -7909,6 +8815,10 @@ static void ggml_compute_forward_dup(
7909
8815
  {
7910
8816
  ggml_compute_forward_dup_f16(params, dst);
7911
8817
  } break;
8818
+ case GGML_TYPE_BF16:
8819
+ {
8820
+ ggml_compute_forward_dup_bf16(params, dst);
8821
+ } break;
7912
8822
  case GGML_TYPE_F32:
7913
8823
  {
7914
8824
  ggml_compute_forward_dup_f32(params, dst);
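The non-contiguous branches of these dup kernels walk the destination with four counters (i10..i13) and propagate a carry whenever one dimension wraps, so source and destination may have different but element-count-compatible shapes. A minimal standalone sketch of that counter pattern, with made-up dimensions:

```c
#include <stdio.h>

int main(void) {
    const int ne0 = 3, ne1 = 2, ne2 = 2, ne3 = 1; // example destination dimensions
    int i10 = 0, i11 = 0, i12 = 0, i13 = 0;
    for (int step = 0; step < ne0*ne1*ne2*ne3; step++) {
        printf("dst element (%d, %d, %d, %d)\n", i10, i11, i12, i13);
        if (++i10 == ne0) {              // carry into the next dimension
            i10 = 0;
            if (++i11 == ne1) {
                i11 = 0;
                if (++i12 == ne2) {
                    i12 = 0;
                    if (++i13 == ne3) {
                        i13 = 0;
                    }
                }
            }
        }
    }
    return 0;
}
```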
@@ -8091,6 +9001,85 @@ static void ggml_compute_forward_add_f16_f32(
8091
9001
  }
8092
9002
  }
8093
9003
 
9004
+ static void ggml_compute_forward_add_bf16_f32(
9005
+ const struct ggml_compute_params * params,
9006
+ struct ggml_tensor * dst) {
9007
+
9008
+ const struct ggml_tensor * src0 = dst->src[0];
9009
+ const struct ggml_tensor * src1 = dst->src[1];
9010
+
9011
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
9012
+
9013
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
9014
+ return;
9015
+ }
9016
+
9017
+ const int ith = params->ith;
9018
+ const int nth = params->nth;
9019
+
9020
+ const int nr = ggml_nrows(src0);
9021
+
9022
+ GGML_TENSOR_BINARY_OP_LOCALS
9023
+
9024
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
9025
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
9026
+
9027
+ if (dst->type == GGML_TYPE_F32) {
9028
+ GGML_ASSERT( nb0 == sizeof(float));
9029
+ }
9030
+ else {
9031
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
9032
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
9033
+ }
9034
+
9035
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
9036
+
9037
+ // rows per thread
9038
+ const int dr = (nr + nth - 1)/nth;
9039
+
9040
+ // row range for this thread
9041
+ const int ir0 = dr*ith;
9042
+ const int ir1 = MIN(ir0 + dr, nr);
9043
+
9044
+ if (nb10 == sizeof(float)) {
9045
+ if (dst->type == GGML_TYPE_BF16) {
9046
+ for (int ir = ir0; ir < ir1; ++ir) {
9047
+ // src0, src1 and dst are same shape => same indices
9048
+ const int i3 = ir/(ne2*ne1);
9049
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9050
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9051
+
9052
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
9053
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9054
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
9055
+
9056
+ for (int i = 0; i < ne0; i++) {
9057
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
9058
+ }
9059
+ }
9060
+ } else {
9061
+ for (int ir = ir0; ir < ir1; ++ir) {
9062
+ // src0, src1 and dst are same shape => same indices
9063
+ const int i3 = ir/(ne2*ne1);
9064
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9065
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9066
+
9067
+ float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
9068
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9069
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
9070
+
9071
+ for (int i = 0; i < ne0; i++) {
9072
+ dst_ptr[i] = GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
9073
+ }
9074
+ }
9075
+ }
9076
+ }
9077
+ else {
9078
+ // src1 is not contiguous
9079
+ GGML_ASSERT(false);
9080
+ }
9081
+ }
9082
+
8094
9083
  static void ggml_compute_forward_add_f16_f16(
8095
9084
  const struct ggml_compute_params * params,
8096
9085
  struct ggml_tensor * dst) {
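The bf16 add kernels in the surrounding hunks split rows across threads with the usual ceil-divide pattern (dr = (nr + nth - 1)/nth), with the last thread taking whatever remains. A small standalone sketch of that split, using example values that are not taken from the diff:

```c
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int nr  = 10; // total rows (example value)
    const int nth = 4;  // number of threads (example value)
    const int dr  = (nr + nth - 1) / nth; // rows per thread, rounded up
    for (int ith = 0; ith < nth; ith++) {
        const int ir0 = dr * ith;
        const int ir1 = MIN(ir0 + dr, nr);
        printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1); // the last thread may get fewer rows
    }
    return 0;
}
```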
@@ -8147,6 +9136,62 @@ static void ggml_compute_forward_add_f16_f16(
8147
9136
  }
8148
9137
  }
8149
9138
 
9139
+ static void ggml_compute_forward_add_bf16_bf16(
9140
+ const struct ggml_compute_params * params,
9141
+ struct ggml_tensor * dst) {
9142
+
9143
+ const struct ggml_tensor * src0 = dst->src[0];
9144
+ const struct ggml_tensor * src1 = dst->src[1];
9145
+
9146
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
9147
+
9148
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
9149
+ return;
9150
+ }
9151
+
9152
+ const int ith = params->ith;
9153
+ const int nth = params->nth;
9154
+
9155
+ const int nr = ggml_nrows(src0);
9156
+
9157
+ GGML_TENSOR_BINARY_OP_LOCALS
9158
+
9159
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
9160
+ GGML_ASSERT(src1->type == GGML_TYPE_BF16);
9161
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
9162
+
9163
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
9164
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
9165
+
9166
+ // rows per thread
9167
+ const int dr = (nr + nth - 1)/nth;
9168
+
9169
+ // row range for this thread
9170
+ const int ir0 = dr*ith;
9171
+ const int ir1 = MIN(ir0 + dr, nr);
9172
+
9173
+ if (nb10 == sizeof(ggml_bf16_t)) {
9174
+ for (int ir = ir0; ir < ir1; ++ir) {
9175
+ // src0, src1 and dst are same shape => same indices
9176
+ const int i3 = ir/(ne2*ne1);
9177
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9178
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9179
+
9180
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
9181
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9182
+ ggml_bf16_t * src1_ptr = (ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
9183
+
9184
+ for (int i = 0; i < ne0; i++) {
9185
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + GGML_BF16_TO_FP32(src1_ptr[i]));
9186
+ }
9187
+ }
9188
+ }
9189
+ else {
9190
+ // src1 is not contiguous
9191
+ GGML_ASSERT(false);
9192
+ }
9193
+ }
9194
+
8150
9195
  static void ggml_compute_forward_add_q_f32(
8151
9196
  const struct ggml_compute_params * params,
8152
9197
  struct ggml_tensor * dst) {
@@ -8256,6 +9301,18 @@ static void ggml_compute_forward_add(
8256
9301
  GGML_ASSERT(false);
8257
9302
  }
8258
9303
  } break;
9304
+ case GGML_TYPE_BF16:
9305
+ {
9306
+ if (src1->type == GGML_TYPE_BF16) {
9307
+ ggml_compute_forward_add_bf16_bf16(params, dst);
9308
+ }
9309
+ else if (src1->type == GGML_TYPE_F32) {
9310
+ ggml_compute_forward_add_bf16_f32(params, dst);
9311
+ }
9312
+ else {
9313
+ GGML_ASSERT(false);
9314
+ }
9315
+ } break;
8259
9316
  case GGML_TYPE_Q4_0:
8260
9317
  case GGML_TYPE_Q4_1:
8261
9318
  case GGML_TYPE_Q5_0:
@@ -8505,12 +9562,116 @@ static void ggml_compute_forward_add1_q_f32(
8505
9562
 
8506
9563
  assert(ne0 % 32 == 0);
8507
9564
 
8508
- // unquantize row from src0 to temp buffer
8509
- dequantize_row_q(src0_row, wdata, ne0);
8510
- // add src1
8511
- ggml_vec_acc1_f32(ne0, wdata, v);
8512
- // quantize row to dst
8513
- quantize_row_q(wdata, dst_row, ne0);
9565
+ // unquantize row from src0 to temp buffer
9566
+ dequantize_row_q(src0_row, wdata, ne0);
9567
+ // add src1
9568
+ ggml_vec_acc1_f32(ne0, wdata, v);
9569
+ // quantize row to dst
9570
+ quantize_row_q(wdata, dst_row, ne0);
9571
+ }
9572
+ }
9573
+
9574
+ static void ggml_compute_forward_add1_bf16_f32(
9575
+ const struct ggml_compute_params * params,
9576
+ struct ggml_tensor * dst) {
9577
+
9578
+ const struct ggml_tensor * src0 = dst->src[0];
9579
+ const struct ggml_tensor * src1 = dst->src[1];
9580
+
9581
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
9582
+ GGML_ASSERT(ggml_is_scalar(src1));
9583
+
9584
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
9585
+ return;
9586
+ }
9587
+
9588
+ // scalar to add
9589
+ const float v = *(float *) src1->data;
9590
+
9591
+ const int ith = params->ith;
9592
+ const int nth = params->nth;
9593
+
9594
+ const int nr = ggml_nrows(src0);
9595
+
9596
+ GGML_TENSOR_UNARY_OP_LOCALS
9597
+
9598
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
9599
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
9600
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
9601
+
9602
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
9603
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
9604
+
9605
+ // rows per thread
9606
+ const int dr = (nr + nth - 1)/nth;
9607
+
9608
+ // row range for this thread
9609
+ const int ir0 = dr*ith;
9610
+ const int ir1 = MIN(ir0 + dr, nr);
9611
+
9612
+ for (int ir = ir0; ir < ir1; ++ir) {
9613
+ // src0 and dst are same shape => same indices
9614
+ const int i3 = ir/(ne2*ne1);
9615
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9616
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9617
+
9618
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
9619
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9620
+ for (int i = 0; i < ne0; i++) {
9621
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
9622
+ }
9623
+ }
9624
+ }
9625
+
9626
+ static void ggml_compute_forward_add1_bf16_bf16(
9627
+ const struct ggml_compute_params * params,
9628
+ struct ggml_tensor * dst) {
9629
+
9630
+ const struct ggml_tensor * src0 = dst->src[0];
9631
+ const struct ggml_tensor * src1 = dst->src[1];
9632
+
9633
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
9634
+ GGML_ASSERT(ggml_is_scalar(src1));
9635
+
9636
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
9637
+ return;
9638
+ }
9639
+
9640
+ // scalar to add
9641
+ const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
9642
+
9643
+ const int ith = params->ith;
9644
+ const int nth = params->nth;
9645
+
9646
+ const int nr = ggml_nrows(src0);
9647
+
9648
+ GGML_TENSOR_UNARY_OP_LOCALS
9649
+
9650
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
9651
+ GGML_ASSERT(src1->type == GGML_TYPE_BF16);
9652
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
9653
+
9654
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
9655
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
9656
+
9657
+ // rows per thread
9658
+ const int dr = (nr + nth - 1)/nth;
9659
+
9660
+ // row range for this thread
9661
+ const int ir0 = dr*ith;
9662
+ const int ir1 = MIN(ir0 + dr, nr);
9663
+
9664
+ for (int ir = ir0; ir < ir1; ++ir) {
9665
+ // src0 and dst are same shape => same indices
9666
+ const int i3 = ir/(ne2*ne1);
9667
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9668
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9669
+
9670
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
9671
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9672
+ for (int i = 0; i < ne0; i++) {
9673
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
9674
+ }
8514
9675
  }
8515
9676
  }
8516
9677
 
@@ -8538,6 +9699,18 @@ static void ggml_compute_forward_add1(
8538
9699
  GGML_ASSERT(false);
8539
9700
  }
8540
9701
  } break;
9702
+ case GGML_TYPE_BF16:
9703
+ {
9704
+ if (src1->type == GGML_TYPE_BF16) {
9705
+ ggml_compute_forward_add1_bf16_bf16(params, dst);
9706
+ }
9707
+ else if (src1->type == GGML_TYPE_F32) {
9708
+ ggml_compute_forward_add1_bf16_f32(params, dst);
9709
+ }
9710
+ else {
9711
+ GGML_ASSERT(false);
9712
+ }
9713
+ } break;
8541
9714
  case GGML_TYPE_Q4_0:
8542
9715
  case GGML_TYPE_Q4_1:
8543
9716
  case GGML_TYPE_Q5_0:
@@ -8666,6 +9839,7 @@ static void ggml_compute_forward_acc(
8666
9839
  ggml_compute_forward_acc_f32(params, dst);
8667
9840
  } break;
8668
9841
  case GGML_TYPE_F16:
9842
+ case GGML_TYPE_BF16:
8669
9843
  case GGML_TYPE_Q4_0:
8670
9844
  case GGML_TYPE_Q4_1:
8671
9845
  case GGML_TYPE_Q5_0:
@@ -9187,6 +10361,40 @@ static void ggml_compute_forward_sum_f16(
9187
10361
  ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
9188
10362
  }
9189
10363
 
10364
+ static void ggml_compute_forward_sum_bf16(
10365
+ const struct ggml_compute_params * params,
10366
+ struct ggml_tensor * dst) {
10367
+
10368
+ const struct ggml_tensor * src0 = dst->src[0];
10369
+
10370
+ assert(params->ith == 0);
10371
+ assert(ggml_is_scalar(dst));
10372
+
10373
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
10374
+ return;
10375
+ }
10376
+
10377
+ assert(src0->nb[0] == sizeof(ggml_bf16_t));
10378
+
10379
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
10380
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
10381
+
10382
+ float sum = 0;
10383
+ float row_sum = 0;
10384
+
10385
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
10386
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
10387
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
10388
+ ggml_vec_sum_bf16_ggf(ne00,
10389
+ &row_sum,
10390
+ (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
10391
+ sum += row_sum;
10392
+ }
10393
+ }
10394
+ }
10395
+ ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
10396
+ }
10397
+
9190
10398
  static void ggml_compute_forward_sum(
9191
10399
  const struct ggml_compute_params * params,
9192
10400
  struct ggml_tensor * dst) {
@@ -9202,6 +10410,10 @@ static void ggml_compute_forward_sum(
9202
10410
  {
9203
10411
  ggml_compute_forward_sum_f16(params, dst);
9204
10412
  } break;
10413
+ case GGML_TYPE_BF16:
10414
+ {
10415
+ ggml_compute_forward_sum_bf16(params, dst);
10416
+ } break;
9205
10417
  default:
9206
10418
  {
9207
10419
  GGML_ASSERT(false);
@@ -9476,6 +10688,7 @@ static void ggml_compute_forward_repeat(
9476
10688
 
9477
10689
  switch (src0->type) {
9478
10690
  case GGML_TYPE_F16:
10691
+ case GGML_TYPE_BF16:
9479
10692
  case GGML_TYPE_I16:
9480
10693
  {
9481
10694
  ggml_compute_forward_repeat_f16(params, dst);
@@ -9963,6 +11176,52 @@ static void ggml_compute_forward_relu(
9963
11176
  }
9964
11177
  }
9965
11178
 
11179
+ // ggml_compute_forward_sigmoid
11180
+
11181
+ static void ggml_compute_forward_sigmoid_f32(
11182
+ const struct ggml_compute_params * params,
11183
+ struct ggml_tensor * dst) {
11184
+
11185
+ const struct ggml_tensor * src0 = dst->src[0];
11186
+
11187
+ assert(params->ith == 0);
11188
+ assert(ggml_are_same_shape(src0, dst));
11189
+
11190
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
11191
+ return;
11192
+ }
11193
+
11194
+ const int n = ggml_nrows(src0);
11195
+ const int nc = src0->ne[0];
11196
+
11197
+ assert(dst->nb[0] == sizeof(float));
11198
+ assert(src0->nb[0] == sizeof(float));
11199
+
11200
+ for (int i = 0; i < n; i++) {
11201
+ ggml_vec_sigmoid_f32(nc,
11202
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
11203
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
11204
+ }
11205
+ }
11206
+
11207
+ static void ggml_compute_forward_sigmoid(
11208
+ const struct ggml_compute_params * params,
11209
+ struct ggml_tensor * dst) {
11210
+
11211
+ const struct ggml_tensor * src0 = dst->src[0];
11212
+
11213
+ switch (src0->type) {
11214
+ case GGML_TYPE_F32:
11215
+ {
11216
+ ggml_compute_forward_sigmoid_f32(params, dst);
11217
+ } break;
11218
+ default:
11219
+ {
11220
+ GGML_ASSERT(false);
11221
+ } break;
11222
+ }
11223
+ }
11224
+
9966
11225
  // ggml_compute_forward_gelu
9967
11226
 
9968
11227
  static void ggml_compute_forward_gelu_f32(
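The new GGML_UNARY_OP_SIGMOID path above defers the per-element work to ggml_vec_sigmoid_f32, defined with the other vector helpers in this release. A hedged sketch of the element-wise operation such a helper is expected to apply (the logistic function); the local helper name below exists only for the example:

```c
#include <math.h>
#include <stdio.h>

static void vec_sigmoid_f32(const int n, float * y, const float * x) {
    for (int i = 0; i < n; i++) {
        y[i] = 1.0f / (1.0f + expf(-x[i])); // logistic function
    }
}

int main(void) {
    const float x[3] = { -2.0f, 0.0f, 2.0f };
    float y[3];
    vec_sigmoid_f32(3, y, x);
    printf("%f %f %f\n", (double) y[0], (double) y[1], (double) y[2]); // ~0.12 0.50 0.88
    return 0;
}
```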
@@ -10813,9 +12072,101 @@ static bool ggml_compute_forward_mul_mat_use_blas(struct ggml_tensor * dst) {
10813
12072
  }
10814
12073
  #endif
10815
12074
 
12075
+ static void ggml_compute_forward_mul_mat_one_chunk(
12076
+ const struct ggml_compute_params * params,
12077
+ struct ggml_tensor * dst,
12078
+ const int64_t num_rows_per_vec_dot,
12079
+ const int64_t ir0_start,
12080
+ const int64_t ir0_end,
12081
+ const int64_t ir1_start,
12082
+ const int64_t ir1_end) {
12083
+
12084
+ const struct ggml_tensor * src0 = dst->src[0];
12085
+ const struct ggml_tensor * src1 = dst->src[1];
12086
+
12087
+ GGML_TENSOR_BINARY_OP_LOCALS
12088
+
12089
+ const enum ggml_type type = src0->type;
12090
+
12091
+ const bool src1_cont = ggml_is_contiguous(src1);
12092
+
12093
+ ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
12094
+ enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
12095
+
12096
+ // broadcast factors
12097
+ const int64_t r2 = ne12 / ne02;
12098
+ const int64_t r3 = ne13 / ne03;
12099
+
12100
+ //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end);
12101
+
12102
+ // threads with no work simply yield (not sure if it helps)
12103
+ if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
12104
+ return;
12105
+ }
12106
+
12107
+ const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
12108
+ const size_t row_size = ggml_row_size(vec_dot_type, ne10);
12109
+
12110
+ assert(ne12 % ne02 == 0);
12111
+ assert(ne13 % ne03 == 0);
12112
+
12113
+ // block-tiling attempt
12114
+ const int64_t blck_0 = 16;
12115
+ const int64_t blck_1 = 16;
12116
+
12117
+ const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
12118
+
12119
+ // attempt to reduce false-sharing (does not seem to make a difference)
12120
+ // 16 * 2, accounting for mmla kernels
12121
+ float tmp[32];
12122
+
12123
+ for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
12124
+ for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
12125
+ for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
12126
+ const int64_t i13 = (ir1 / (ne12 * ne1));
12127
+ const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
12128
+ const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
12129
+
12130
+ // broadcast src0 into src1
12131
+ const int64_t i03 = i13 / r3;
12132
+ const int64_t i02 = i12 / r2;
12133
+
12134
+ const int64_t i1 = i11;
12135
+ const int64_t i2 = i12;
12136
+ const int64_t i3 = i13;
12137
+
12138
+ const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);
12139
+
12140
+ // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
12141
+ // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
12142
+ // the original src1 data pointer, so we should index using the indices directly
12143
+ // TODO: this is a bit of a hack, we should probably have a better way to handle this
12144
+ const char * src1_col = (const char*)wdata +
12145
+ (src1_cont || src1->type != vec_dot_type
12146
+ ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
12147
+ : (i11 * nb11 + i12 * nb12 + i13 * nb13));
12148
+ float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
12149
+
12150
+ //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
12151
+ // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
12152
+ //}
12153
+
12154
+ for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
12155
+ vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
12156
+ }
12157
+
12158
+ for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
12159
+ memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
12160
+ }
12161
+ }
12162
+ }
12163
+ }
12164
+ }
12165
+
10816
12166
  static void ggml_compute_forward_mul_mat(
10817
12167
  const struct ggml_compute_params * params,
10818
- struct ggml_tensor * dst) {
12168
+ struct ggml_tensor * dst,
12169
+ struct ggml_compute_state * state) {
10819
12170
 
10820
12171
  const struct ggml_tensor * src0 = dst->src[0];
10821
12172
  const struct ggml_tensor * src1 = dst->src[1];
@@ -10830,9 +12181,6 @@ static void ggml_compute_forward_mul_mat(
10830
12181
 
10831
12182
  const enum ggml_type type = src0->type;
10832
12183
 
10833
- const bool src1_cont = ggml_is_contiguous(src1);
10834
-
10835
- ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
10836
12184
  enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
10837
12185
  ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
10838
12186
  int64_t const vec_dot_num_rows = type_traits[type].nrows;
@@ -10853,8 +12201,10 @@ static void ggml_compute_forward_mul_mat(
10853
12201
  GGML_ASSERT(nb2 <= nb3);
10854
12202
 
10855
12203
  // broadcast factors
10856
- const int64_t r2 = ne12/ne02;
10857
- const int64_t r3 = ne13/ne03;
12204
+ const int64_t r2 = ne12 / ne02;
12205
+ const int64_t r3 = ne13 / ne03;
12206
+ UNUSED(r2);
12207
+ UNUSED(r3);
10858
12208
 
10859
12209
  // nb01 >= nb00 - src0 is not transposed
10860
12210
  // compute by src0 rows
@@ -10936,6 +12286,8 @@ static void ggml_compute_forward_mul_mat(
10936
12286
  #endif
10937
12287
 
10938
12288
  #if GGML_USE_LLAMAFILE
12289
+ const bool src1_cont = ggml_is_contiguous(src1);
12290
+
10939
12291
  if (src1_cont) {
10940
12292
  for (int64_t i13 = 0; i13 < ne13; i13++)
10941
12293
  for (int64_t i12 = 0; i12 < ne12; i12++)
@@ -10961,6 +12313,8 @@ UseGgmlGemm1:;
10961
12313
  if (ith != 0) {
10962
12314
  return;
10963
12315
  }
12316
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
12317
+ atomic_store(&state->shared->current_chunk, nth);
10964
12318
  if (src1->type != vec_dot_type) {
10965
12319
  char * wdata = params->wdata;
10966
12320
  const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@@ -10985,11 +12339,11 @@ UseGgmlGemm1:;
10985
12339
  return;
10986
12340
  }
10987
12341
 
10988
- const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
10989
- const size_t row_size = ggml_row_size(vec_dot_type, ne10);
10990
-
10991
12342
  #if GGML_USE_LLAMAFILE
10992
12343
  if (src1->type != vec_dot_type) {
12344
+ const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
12345
+ const size_t row_size = ggml_row_size(vec_dot_type, ne10);
12346
+
10993
12347
  for (int64_t i13 = 0; i13 < ne13; i13++)
10994
12348
  for (int64_t i12 = 0; i12 < ne12; i12++)
10995
12349
  if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
@@ -11010,98 +12364,87 @@ UseGgmlGemm1:;
11010
12364
  UseGgmlGemm2:;
11011
12365
  #endif
11012
12366
 
11013
- const int64_t nr0 = ne01; // src0 rows
11014
- const int64_t nr1 = ne1*ne12*ne13; // src1 rows
11015
-
11016
- //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
11017
-
11018
- // distribute the thread work across the inner or outer loop based on which one is larger
11019
-
11020
- const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
11021
- const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
11022
-
11023
- const int64_t ith0 = ith % nth0;
11024
- const int64_t ith1 = ith / nth0;
11025
-
11026
- const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
11027
- const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
11028
-
11029
- const int64_t ir010 = dr0*ith0;
11030
- const int64_t ir011 = MIN(ir010 + dr0, nr0);
11031
-
11032
- const int64_t ir110 = dr1*ith1;
11033
- const int64_t ir111 = MIN(ir110 + dr1, nr1);
11034
-
11035
- //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
11036
-
11037
- // threads with no work simply yield (not sure if it helps)
11038
- if (ir010 >= ir011 || ir110 >= ir111) {
11039
- sched_yield();
11040
- return;
11041
- }
12367
+ #ifdef GGML_PERF
12368
+ int chunks_executed = 0;
12369
+ UNUSED(chunks_executed);
12370
+ #endif
11042
12371
 
11043
- assert(ne12 % ne02 == 0);
11044
- assert(ne13 % ne03 == 0);
12372
+ // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
12373
+ const int64_t nr0 = ne0;
11045
12374
 
11046
- // block-tiling attempt
11047
- const int64_t blck_0 = 16;
11048
- const int64_t blck_1 = 16;
12375
+ // This is the size of the rest of the dimensions of the result
12376
+ const int64_t nr1 = ne1 * ne2 * ne3;
11049
12377
 
11050
12378
  // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
11051
- int64_t nrc = vec_dot_num_rows;
12379
+ int64_t num_rows_per_vec_dot = vec_dot_num_rows;
11052
12380
  // TODO: currently the mmla kernels support only even numbered rows/cols.
11053
12381
  // this check can be removed once they are extended to support odd numbered rows/cols too
11054
12382
  if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
11055
- nrc = 1;
12383
+ num_rows_per_vec_dot = 1;
11056
12384
  }
11057
12385
 
11058
- const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
12386
+ // Now select a reasonable chunk size.
12387
+ int chunk_size = 16;
11059
12388
 
11060
- // attempt to reduce false-sharing (does not seem to make a difference)
11061
- // 16 * 2, accounting for mmla kernels
11062
- float tmp[32];
12389
+ // We need to step up the size if it's small
12390
+ if (nr0 == 1 || nr1 == 1) {
12391
+ chunk_size = 64;
12392
+ }
11063
12393
 
11064
- for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
11065
- for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
11066
- for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += nrc) {
11067
- const int64_t i13 = (ir1/(ne12*ne1));
11068
- const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
11069
- const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
12394
+ // distribute the work across the inner or outer loop based on which one is larger
12395
+ // The number of chunks in the 0/1 dim.
12396
+ // CEIL(nr0/chunk_size)
12397
+ int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
12398
+ int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
11070
12399
 
11071
- // broadcast src0 into src1
11072
- const int64_t i03 = i13/r3;
11073
- const int64_t i02 = i12/r2;
12400
+ // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread.
12401
+ // Also, chunking by thread was measured to perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915
12402
+ // In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that.
12403
+ if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {
12404
+ // distribute the thread work across the inner or outer loop based on which one is larger
12405
+ nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
12406
+ nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
12407
+ }
11074
12408
 
11075
- const int64_t i1 = i11;
11076
- const int64_t i2 = i12;
11077
- const int64_t i3 = i13;
12409
+ // The number of elements in each chunk
12410
+ const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
12411
+ const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
11078
12412
 
11079
- const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);
12413
+ //if (ith == 0)
12414
+ // printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);
11080
12415
 
11081
- // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
11082
- // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
11083
- // the original src1 data pointer, so we should index using the indices directly
11084
- // TODO: this is a bit of a hack, we should probably have a better way to handle this
11085
- const char * src1_col = (const char *) wdata +
11086
- (src1_cont || src1->type != vec_dot_type
11087
- ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
11088
- : (i11*nb11 + i12*nb12 + i13*nb13));
11089
- float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
12416
+ // The first chunk comes from our thread_id, the rest will get auto-assigned.
12417
+ int current_chunk = ith;
11090
12418
 
11091
- //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
11092
- // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
11093
- //}
12419
+ while (current_chunk < nchunk0 * nchunk1) {
12420
+ const int64_t ith0 = current_chunk % nchunk0;
12421
+ const int64_t ith1 = current_chunk / nchunk0;
11094
12422
 
11095
- for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += nrc) {
11096
- vec_dot(ne00, &tmp[ir0 - iir0], (nrc>1 ? 16 : 0), src0_row + ir0*nb01, (nrc>1 ? nb01 : 0), src1_col, (nrc>1 ? src1_col_stride : 0), nrc);
11097
- }
12423
+ const int64_t ir0_start = dr0 * ith0;
12424
+ const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
11098
12425
 
11099
- for (int cn = 0; cn < nrc; ++cn) {
11100
- memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
11101
- }
11102
- }
12426
+ const int64_t ir1_start = dr1 * ith1;
12427
+ const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
12428
+
12429
+ ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
12430
+
12431
+ #ifdef GGML_PERF
12432
+ chunks_executed++;
12433
+ #endif
12434
+
12435
+ if (nth >= nchunk0 * nchunk1) {
12436
+ break;
11103
12437
  }
12438
+
12439
+ current_chunk = atomic_fetch_add(&state->shared->current_chunk, 1);
11104
12440
  }
12441
+
12442
+ #ifdef GGML_PERF
12443
+ // These numbers are useful when trying to measure how well the thread scheduling works.
12444
+ //int64_t workSize = (ne01 * ne11 * ne12 * ne13 * ne00) / nchunk0 / nchunk1;
12445
+ //float time = (ggml_perf_time_us() - t0);
12446
+ //printf("MUL_MAT = %f ms, [%d, %d, %d, %d] x [%d, %d, %d, %d] = %I64u, %f ops/usec in %d chunks.\n", time / 1000.0, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, workSize, (float)workSize/time, chunks_executed);
12447
+ #endif
11105
12448
  }
11106
12449
 
11107
12450
  // ggml_compute_forward_mul_mat_id
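The reworked mul_mat above replaces the fixed per-thread row split with dynamically claimed chunks: every thread starts on the chunk equal to its thread id, and once done it pulls the next unprocessed chunk index from a shared atomic counter. A standalone sketch of that claiming pattern using plain C11 atomics and pthreads (ggml's own threading state and atomic wrappers differ):

```c
#include <pthread.h>
#include <stdatomic.h>
#include <stdio.h>

#define NTH     4   // worker threads (example value)
#define NCHUNKS 16  // chunks of work (example value)

static atomic_int current_chunk;

static void * worker(void * arg) {
    const int ith = *(const int *) arg;
    int chunk = ith;                    // the first chunk comes from the thread id
    while (chunk < NCHUNKS) {
        printf("thread %d processes chunk %d\n", ith, chunk);
        chunk = atomic_fetch_add(&current_chunk, 1); // claim the next chunk
    }
    return NULL;
}

int main(void) {
    atomic_store(&current_chunk, NTH);  // chunks 0..NTH-1 are taken by the thread ids
    pthread_t th[NTH];
    int ids[NTH];
    for (int i = 0; i < NTH; i++) {
        ids[i] = i;
        pthread_create(&th[i], NULL, worker, &ids[i]);
    }
    for (int i = 0; i < NTH; i++) {
        pthread_join(th[i], NULL);
    }
    return 0;
}
```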
@@ -11793,6 +13136,7 @@ static void ggml_compute_forward_set(
11793
13136
  ggml_compute_forward_set_f32(params, dst);
11794
13137
  } break;
11795
13138
  case GGML_TYPE_F16:
13139
+ case GGML_TYPE_BF16:
11796
13140
  case GGML_TYPE_Q4_0:
11797
13141
  case GGML_TYPE_Q4_1:
11798
13142
  case GGML_TYPE_Q5_0:
@@ -11918,13 +13262,56 @@ static void ggml_compute_forward_get_rows_q(
11918
13262
  const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
11919
13263
  const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
11920
13264
 
11921
- dequantize_row_q(
13265
+ dequantize_row_q(
13266
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
13267
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
13268
+ }
13269
+ }
13270
+
13271
+ static void ggml_compute_forward_get_rows_f16(
13272
+ const struct ggml_compute_params * params,
13273
+ struct ggml_tensor * dst) {
13274
+
13275
+ const struct ggml_tensor * src0 = dst->src[0];
13276
+ const struct ggml_tensor * src1 = dst->src[1];
13277
+
13278
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
13279
+ return;
13280
+ }
13281
+
13282
+ GGML_TENSOR_BINARY_OP_LOCALS
13283
+
13284
+ const int64_t nc = ne00;
13285
+ const int64_t nr = ggml_nelements(src1);
13286
+
13287
+ assert(ne0 == nc);
13288
+ assert(ne02 == ne11);
13289
+ assert(nb00 == sizeof(ggml_fp16_t));
13290
+ assert(ggml_nrows(dst) == nr);
13291
+
13292
+ const int ith = params->ith;
13293
+ const int nth = params->nth;
13294
+
13295
+ // rows per thread
13296
+ const int dr = (nr + nth - 1)/nth;
13297
+
13298
+ // row range for this thread
13299
+ const int ir0 = dr*ith;
13300
+ const int ir1 = MIN(ir0 + dr, nr);
13301
+
13302
+ for (int64_t i = ir0; i < ir1; ++i) {
13303
+ const int64_t i12 = i/(ne11*ne10);
13304
+ const int64_t i11 = (i - i12*ne11*ne10)/ne10;
13305
+ const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
13306
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
13307
+
13308
+ ggml_fp16_to_fp32_row(
11922
13309
  (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
11923
13310
  (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
11924
13311
  }
11925
13312
  }
11926
13313
 
11927
- static void ggml_compute_forward_get_rows_f16(
13314
+ static void ggml_compute_forward_get_rows_bf16(
11928
13315
  const struct ggml_compute_params * params,
11929
13316
  struct ggml_tensor * dst) {
11930
13317
 
@@ -11942,7 +13329,7 @@ static void ggml_compute_forward_get_rows_f16(
11942
13329
 
11943
13330
  assert(ne0 == nc);
11944
13331
  assert(ne02 == ne11);
11945
- assert(nb00 == sizeof(ggml_fp16_t));
13332
+ assert(nb00 == sizeof(ggml_bf16_t));
11946
13333
  assert(ggml_nrows(dst) == nr);
11947
13334
 
11948
13335
  const int ith = params->ith;
@@ -11961,7 +13348,7 @@ static void ggml_compute_forward_get_rows_f16(
11961
13348
  const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
11962
13349
  const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
11963
13350
 
11964
- ggml_fp16_to_fp32_row(
13351
+ ggml_bf16_to_fp32_row(
11965
13352
  (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
11966
13353
  (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
11967
13354
  }
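The get_rows kernels above, including the new bf16 variant, share one gather shape: src1 holds int32 row indices and each selected row of src0 is converted into a float row of dst. A reduced sketch of that gather with plain float rows (the array names are made up for the example; the real kernels dequantize or convert per row):

```c
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int nc = 3;                                   // row length
    const float src0[4][3] = { {0,0,0}, {1,1,1}, {2,2,2}, {3,3,3} };
    const int32_t rows[2]  = { 2, 0 };                  // indices, as src1 would supply them
    float dst[2][3];

    for (int i = 0; i < 2; i++) {
        const int32_t i01 = rows[i];                    // which source row to take
        for (int j = 0; j < nc; j++) {
            dst[i][j] = src0[i01][j];                   // real kernels convert/dequantize here
        }
    }
    printf("%g %g\n", (double) dst[0][0], (double) dst[1][0]); // 2 and 0
    return 0;
}
```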
@@ -12044,6 +13431,10 @@ static void ggml_compute_forward_get_rows(
12044
13431
  {
12045
13432
  ggml_compute_forward_get_rows_f16(params, dst);
12046
13433
  } break;
13434
+ case GGML_TYPE_BF16:
13435
+ {
13436
+ ggml_compute_forward_get_rows_bf16(params, dst);
13437
+ } break;
12047
13438
  case GGML_TYPE_F32:
12048
13439
  case GGML_TYPE_I32:
12049
13440
  {
@@ -12356,7 +13747,6 @@ static void ggml_compute_forward_soft_max_f32(
12356
13747
 
12357
13748
  const struct ggml_tensor * src0 = dst->src[0];
12358
13749
  const struct ggml_tensor * src1 = dst->src[1];
12359
- const struct ggml_tensor * src2 = dst->src[2];
12360
13750
 
12361
13751
  assert(ggml_is_contiguous(dst));
12362
13752
  assert(ggml_are_same_shape(src0, dst));
@@ -12382,8 +13772,8 @@ static void ggml_compute_forward_soft_max_f32(
12382
13772
 
12383
13773
  // TODO: is this supposed to be ceil instead of floor?
12384
13774
  // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
12385
- const uint32_t n_head_kv = ne02;
12386
- const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
13775
+ const uint32_t n_head = ne02;
13776
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
12387
13777
 
12388
13778
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
12389
13779
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
@@ -12400,13 +13790,13 @@ static void ggml_compute_forward_soft_max_f32(
12400
13790
 
12401
13791
  float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
12402
13792
 
12403
- // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
12404
- ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
12405
- float * pos_f32 = src2 ? (float *) src2->data : src0->data;
12406
-
12407
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
13793
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
12408
13794
 
12409
13795
  for (int i1 = ir0; i1 < ir1; i1++) {
13796
+ // ALiBi
13797
+ const uint32_t h = (i1/ne01)%ne02; // head
13798
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
13799
+
12410
13800
  float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
12411
13801
  float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
12412
13802
 
@@ -12419,27 +13809,11 @@ static void ggml_compute_forward_soft_max_f32(
12419
13809
  if (mp_f32) {
12420
13810
  if (use_f16) {
12421
13811
  for (int i = 0; i < nc; ++i) {
12422
- wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
12423
- }
12424
- } else {
12425
- for (int i = 0; i < nc; ++i) {
12426
- wp[i] += mp_f32[i];
12427
- }
12428
- }
12429
- }
12430
-
12431
- // ALiBi bias
12432
- if (max_bias > 0.0f) {
12433
- const uint32_t h = (i1/ne01)%ne02; // head
12434
- const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
12435
-
12436
- if (use_f16) {
12437
- for (int i = 0; i < nc; ++i) {
12438
- wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
13812
+ wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
12439
13813
  }
12440
13814
  } else {
12441
13815
  for (int i = 0; i < nc; ++i) {
12442
- wp[i] += slope*pos_f32[i];
13816
+ wp[i] += slope*mp_f32[i];
12443
13817
  }
12444
13818
  }
12445
13819
  }
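The soft_max path above now derives an ALiBi slope per attention head from m0 and m1 and scales the mask contribution by it. A small sketch of the slope values that formula yields for an assumed 8-head model with max_bias = 8 (these numbers are illustrative, not taken from the diff):

```c
#include <math.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const float    max_bias    = 8.0f; // example value
    const uint32_t n_head      = 8;    // example value
    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    for (uint32_t h = 0; h < n_head; h++) {
        const float slope = h < n_head_log2 ? powf(m0, h + 1)
                                            : powf(m1, 2*(h - n_head_log2) + 1);
        printf("head %u: slope %g\n", (unsigned) h, slope); // decays geometrically across heads
    }
    return 0;
}
```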
@@ -12454,22 +13828,7 @@ static void ggml_compute_forward_soft_max_f32(
12454
13828
  float max = -INFINITY;
12455
13829
  ggml_vec_max_f32(nc, &max, wp);
12456
13830
 
12457
- ggml_float sum = 0.0;
12458
-
12459
- uint16_t scvt;
12460
- for (int i = 0; i < nc; i++) {
12461
- if (wp[i] == -INFINITY) {
12462
- dp[i] = 0.0f;
12463
- } else {
12464
- // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
12465
- ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
12466
- memcpy(&scvt, &s, sizeof(scvt));
12467
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
12468
- sum += (ggml_float)val;
12469
- dp[i] = val;
12470
- }
12471
- }
12472
-
13831
+ ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
12473
13832
  assert(sum > 0.0);
12474
13833
 
12475
13834
  sum = 1.0/sum;
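The unrolled exp-table loop is replaced by one call to ggml_vec_soft_max_f32 from the vector helpers. A hedged sketch of what such a helper is assumed to compute: exp(x - max) written into the destination, -INFINITY mapped to 0, and the running sum returned so the caller can normalize (the sum = 1.0/sum step above); the local function below is not the ggml implementation:

```c
#include <math.h>
#include <stdio.h>

static double vec_soft_max_f32(const int n, float * y, const float * x, const float max) {
    double sum = 0.0;
    for (int i = 0; i < n; i++) {
        const float val = (x[i] == -INFINITY) ? 0.0f : expf(x[i] - max);
        y[i] = val;
        sum += (double) val;
    }
    return sum; // caller divides by this to normalize
}

int main(void) {
    float x[3] = { 1.0f, 2.0f, 3.0f };
    float y[3];
    const double sum = vec_soft_max_f32(3, y, x, 3.0f); // 3.0f is the row maximum
    for (int i = 0; i < 3; i++) {
        y[i] = (float) (y[i] / sum);
    }
    printf("%f %f %f\n", (double) y[0], (double) y[1], (double) y[2]); // sums to 1
    return 0;
}
```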
@@ -12601,177 +13960,6 @@ static void ggml_compute_forward_soft_max_back(
12601
13960
  }
12602
13961
  }
12603
13962
 
12604
- // ggml_compute_forward_alibi
12605
-
12606
- static void ggml_compute_forward_alibi_f32(
12607
- const struct ggml_compute_params * params,
12608
- struct ggml_tensor * dst) {
12609
-
12610
- const struct ggml_tensor * src0 = dst->src[0];
12611
-
12612
- assert(params->ith == 0);
12613
-
12614
- if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
12615
- return;
12616
- }
12617
-
12618
- //const int n_past = ((int32_t *) dst->op_params)[0];
12619
- const int n_head = ((int32_t *) dst->op_params)[1];
12620
- float max_bias;
12621
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
12622
-
12623
- const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
12624
- const int64_t ne1 = src0->ne[1]; // seq_len_without_past
12625
- const int64_t ne2 = src0->ne[2]; // n_head -> this is k
12626
- //const int64_t ne3 = src0->ne[3]; // 1 -> bsz
12627
-
12628
- const int64_t n = ggml_nrows(src0);
12629
- const int64_t ne2_ne3 = n/ne1; // ne2*ne3
12630
-
12631
- const size_t nb0 = src0->nb[0];
12632
- const size_t nb1 = src0->nb[1];
12633
- const size_t nb2 = src0->nb[2];
12634
- //const int nb3 = src0->nb[3];
12635
-
12636
- GGML_ASSERT(nb0 == sizeof(float));
12637
- GGML_ASSERT(n_head == ne2);
12638
-
12639
- // add alibi to src0 (KQ_scaled)
12640
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
12641
-
12642
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
12643
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
12644
-
12645
- for (int64_t k = 0; k < ne2_ne3; k++) {
12646
- // TODO: k*nb2 or k*nb3
12647
- float m_k;
12648
-
12649
- if (k < n_heads_log2_floor) {
12650
- m_k = powf(m0, k + 1);
12651
- } else {
12652
- m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
12653
- }
12654
-
12655
- for (int64_t i = 0; i < ne0; i++) {
12656
- for (int64_t j = 0; j < ne1; j++) {
12657
- float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
12658
- float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
12659
- pdst[0] = i * m_k + src[0];
12660
- }
12661
- }
12662
- }
12663
- }
12664
-
12665
- static void ggml_compute_forward_alibi_f16(
12666
- const struct ggml_compute_params * params,
12667
- struct ggml_tensor * dst) {
12668
-
12669
- const struct ggml_tensor * src0 = dst->src[0];
12670
-
12671
- assert(params->ith == 0);
12672
-
12673
- if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
12674
- return;
12675
- }
12676
-
12677
- //const int n_past = ((int32_t *) dst->op_params)[0];
12678
- const int n_head = ((int32_t *) dst->op_params)[1];
12679
- float max_bias;
12680
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
12681
-
12682
- const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
12683
- const int ne1 = src0->ne[1]; // seq_len_without_past
12684
- const int ne2 = src0->ne[2]; // n_head -> this is k
12685
- //const int ne3 = src0->ne[3]; // 1 -> bsz
12686
-
12687
- const int n = ggml_nrows(src0);
12688
- const int ne2_ne3 = n/ne1; // ne2*ne3
12689
-
12690
- const int nb0 = src0->nb[0];
12691
- const int nb1 = src0->nb[1];
12692
- const int nb2 = src0->nb[2];
12693
- //const int nb3 = src0->nb[3];
12694
-
12695
- GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
12696
- //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
12697
- GGML_ASSERT(n_head == ne2);
12698
-
12699
- // add alibi to src0 (KQ_scaled)
12700
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
12701
-
12702
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
12703
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
12704
-
12705
- for (int k = 0; k < ne2_ne3; k++) {
12706
- // TODO: k*nb2 or k*nb3
12707
- float m_k;
12708
-
12709
- if (k < n_heads_log2_floor) {
12710
- m_k = powf(m0, k + 1);
12711
- } else {
12712
- m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
12713
- }
12714
-
12715
- for (int i = 0; i < ne0; i++) {
12716
- for (int j = 0; j < ne1; j++) {
12717
- ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
12718
- float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
12719
-
12720
- // we return F32
12721
- pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
12722
- }
12723
- }
12724
- }
12725
- }
12726
-
12727
- static void ggml_compute_forward_alibi(
12728
- const struct ggml_compute_params * params,
12729
- struct ggml_tensor * dst) {
12730
-
12731
- const struct ggml_tensor * src0 = dst->src[0];
12732
-
12733
- switch (src0->type) {
12734
- case GGML_TYPE_F16:
12735
- {
12736
- ggml_compute_forward_alibi_f16(params, dst);
12737
- } break;
12738
- case GGML_TYPE_F32:
12739
- {
12740
- ggml_compute_forward_alibi_f32(params, dst);
12741
- } break;
12742
- case GGML_TYPE_Q4_0:
12743
- case GGML_TYPE_Q4_1:
12744
- case GGML_TYPE_Q5_0:
12745
- case GGML_TYPE_Q5_1:
12746
- case GGML_TYPE_Q8_0:
12747
- case GGML_TYPE_Q8_1:
12748
- case GGML_TYPE_Q2_K:
12749
- case GGML_TYPE_Q3_K:
12750
- case GGML_TYPE_Q4_K:
12751
- case GGML_TYPE_Q5_K:
12752
- case GGML_TYPE_Q6_K:
12753
- case GGML_TYPE_IQ2_XXS:
12754
- case GGML_TYPE_IQ2_XS:
12755
- case GGML_TYPE_IQ3_XXS:
12756
- case GGML_TYPE_IQ1_S:
12757
- case GGML_TYPE_IQ1_M:
12758
- case GGML_TYPE_IQ4_NL:
12759
- case GGML_TYPE_IQ4_XS:
12760
- case GGML_TYPE_IQ3_S:
12761
- case GGML_TYPE_IQ2_S:
12762
- case GGML_TYPE_Q8_K:
12763
- case GGML_TYPE_I8:
12764
- case GGML_TYPE_I16:
12765
- case GGML_TYPE_I32:
12766
- case GGML_TYPE_I64:
12767
- case GGML_TYPE_F64:
12768
- case GGML_TYPE_COUNT:
12769
- {
12770
- GGML_ASSERT(false);
12771
- } break;
12772
- }
12773
- }
12774
-
12775
13963
  // ggml_compute_forward_clamp
12776
13964
 
12777
13965
  static void ggml_compute_forward_clamp_f32(
@@ -12828,6 +14016,7 @@ static void ggml_compute_forward_clamp(
12828
14016
  ggml_compute_forward_clamp_f32(params, dst);
12829
14017
  } break;
12830
14018
  case GGML_TYPE_F16:
14019
+ case GGML_TYPE_BF16:
12831
14020
  case GGML_TYPE_Q4_0:
12832
14021
  case GGML_TYPE_Q4_1:
12833
14022
  case GGML_TYPE_Q5_0:
@@ -13993,25 +15182,28 @@ static void ggml_compute_forward_upscale_f32(
13993
15182
  return;
13994
15183
  }
13995
15184
 
13996
- GGML_ASSERT(src0->nb[0] == sizeof(float));
15185
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
13997
15186
 
13998
15187
  const int ith = params->ith;
13999
15188
  const int nth = params->nth;
14000
15189
 
14001
15190
  GGML_TENSOR_UNARY_OP_LOCALS
14002
15191
 
14003
- const int scale_factor = dst->op_params[0];
15192
+ const float sf0 = (float)ne0/src0->ne[0];
15193
+ const float sf1 = (float)ne1/src0->ne[1];
15194
+ const float sf2 = (float)ne2/src0->ne[2];
15195
+ const float sf3 = (float)ne3/src0->ne[3];
14004
15196
 
14005
15197
  // TODO: optimize
14006
15198
 
14007
15199
  for (int64_t i3 = 0; i3 < ne3; i3++) {
14008
- const int64_t i03 = i3;
15200
+ const int64_t i03 = i3 / sf3;
14009
15201
  for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
14010
- const int64_t i02 = i2;
15202
+ const int64_t i02 = i2 / sf2;
14011
15203
  for (int64_t i1 = 0; i1 < ne1; i1++) {
14012
- const int64_t i01 = i1 / scale_factor;
15204
+ const int64_t i01 = i1 / sf1;
14013
15205
  for (int64_t i0 = 0; i0 < ne0; i0++) {
14014
- const int64_t i00 = i0 / scale_factor;
15206
+ const int64_t i00 = i0 / sf0;
14015
15207
 
14016
15208
  const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
14017
15209
  float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
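The upscale kernel now computes a scale factor per dimension from the destination/source extents and maps each output index back to a source index by truncating division (nearest lower neighbour). A one-dimensional sketch of that mapping, with made-up sizes:

```c
#include <stdio.h>

int main(void) {
    const int src_ne = 4;                     // source extent (example value)
    const int dst_ne = 10;                    // destination extent (example value)
    const float sf = (float) dst_ne / src_ne; // scale factor, as sf0..sf3 above

    for (int i = 0; i < dst_ne; i++) {
        const int i_src = (int) (i / sf);     // truncation picks the nearest lower source index
        printf("dst[%d] <- src[%d]\n", i, i_src);
    }
    return 0;
}
```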
@@ -14041,6 +15233,7 @@ static void ggml_compute_forward_upscale(
14041
15233
  }
14042
15234
  }
14043
15235
 
15236
+
14044
15237
  // ggml_compute_forward_pad
14045
15238
 
14046
15239
  static void ggml_compute_forward_pad_f32(
@@ -14394,37 +15587,7 @@ static void ggml_compute_forward_flash_attn_f32(
14394
15587
  vvexpf(S, S, &Mup);
14395
15588
  ggml_vec_sum_f32(Mup, &sum, S);
14396
15589
  #else
14397
- uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
14398
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
14399
-
14400
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
14401
- if (i >= masked_begin) {
14402
- break;
14403
- }
14404
- float * SS = S + i;
14405
-
14406
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
14407
- if (i + j >= masked_begin) {
14408
- break;
14409
- } else if (SS[j] == -INFINITY) {
14410
- SS[j] = 0.0f;
14411
- } else {
14412
- #ifndef GGML_FLASH_ATTN_EXP_FP16
14413
- const float val = expf(SS[j] - max);
14414
- #else
14415
- ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
14416
- memcpy(&scvt[j], &s, sizeof(uint16_t));
14417
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
14418
- #endif
14419
- sump[j] += (ggml_float)val;
14420
- SS[j] = val;
14421
- }
14422
- }
14423
- }
14424
-
14425
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
14426
- sum += sump[i];
14427
- }
15590
+ sum = ggml_vec_soft_max_f32(Mup, S, S, max);
14428
15591
  #endif
14429
15592
  }
14430
15593
 
@@ -14606,28 +15769,7 @@ static void ggml_compute_forward_flash_attn_f16(
14606
15769
  vvexpf(S, S, &Mup);
14607
15770
  ggml_vec_sum_f32(Mup, &sum, S);
14608
15771
  #else
14609
- uint16_t scvt[GGML_SOFT_MAX_UNROLL];
14610
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
14611
-
14612
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
14613
- float * SS = S + i;
14614
-
14615
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
14616
- if (SS[j] == -INFINITY) {
14617
- SS[j] = 0.0f;
14618
- } else {
14619
- ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
14620
- memcpy(&scvt[j], &s, sizeof(uint16_t));
14621
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
14622
- sump[j] += (ggml_float)val;
14623
- SS[j] = val;
14624
- }
14625
- }
14626
- }
14627
-
14628
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
14629
- sum += sump[i];
14630
- }
15772
+ sum = ggml_vec_soft_max_f32(Mup, S, S, max);
14631
15773
  #endif
14632
15774
  }
14633
15775
 
@@ -14784,8 +15926,17 @@ static void ggml_compute_forward_flash_attn_ext_f16(
14784
15926
  const int ir0 = dr*ith;
14785
15927
  const int ir1 = MIN(ir0 + dr, nr);
14786
15928
 
14787
- float scale = 1.0f;
14788
- memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
15929
+ float scale = 1.0f;
15930
+ float max_bias = 0.0f;
15931
+
15932
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
15933
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
15934
+
15935
+ const uint32_t n_head = neq2;
15936
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
15937
+
15938
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
15939
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
14789
15940
 
14790
15941
  // loop over n_batch and n_head
14791
15942
  for (int ir = ir0; ir < ir1; ++ir) {
@@ -14794,6 +15945,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
14794
15945
  const int iq2 = (ir - iq3*neq2*neq1)/neq1;
14795
15946
  const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
14796
15947
 
15948
+ const uint32_t h = iq2; // head
15949
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
15950
+
14797
15951
  float S = 0.0f;
14798
15952
  float M = -INFINITY;
14799
15953
 
@@ -14817,7 +15971,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
14817
15971
  // loop over n_kv and n_head_kv
14818
15972
  // ref: https://arxiv.org/pdf/2112.05682.pdf
14819
15973
  for (int64_t ic = 0; ic < nek1; ++ic) {
14820
- const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
15974
+ const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
14821
15975
  if (mv == -INFINITY) {
14822
15976
  continue;
14823
15977
  }
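
With max_bias now read from op_params, the ALiBi bias is applied inside the flash-attention kernel itself: each head h scales the mask values by a slope derived from m0 and m1. As a worked restatement of the expression added above (function and variable names are local to this sketch, not ggml API):

    #include <math.h>
    #include <stdint.h>

    // Per-head ALiBi slope, mirroring the ternary added in the hunk above:
    // heads below n_head_log2 use powers of m0, the rest use odd powers of m1.
    static float alibi_slope(uint32_t h, uint32_t n_head, float max_bias) {
        if (max_bias <= 0.0f) {
            return 1.0f;
        }
        const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        return h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
    }

For example, with n_head = 8 and max_bias = 8.0f this yields the familiar geometric sequence of slopes 1/2, 1/4, ..., 1/256.
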
@@ -14888,7 +16042,7 @@ static void ggml_compute_forward_flash_attn_ext(
14888
16042
  const struct ggml_tensor * v,
14889
16043
  const struct ggml_tensor * mask,
14890
16044
  struct ggml_tensor * dst) {
14891
- switch (dst->op_params[1]) {
16045
+ switch (dst->op_params[2]) {
14892
16046
  case GGML_PREC_DEFAULT:
14893
16047
  case GGML_PREC_F32:
14894
16048
  {
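
The precision flag moves from op_params[1] to op_params[2] because slots 0 and 1 now carry scale and max_bias (see the memcpy calls in the earlier hunk). A hypothetical packing helper, purely to illustrate the layout this diff implies:

    #include <stdint.h>
    #include <string.h>

    // Parameter layout implied by the hunks above: two floats (scale, max_bias)
    // in the first two 32-bit slots, precision flag in the third slot, which is
    // why the switch now reads op_params[2]. Hypothetical helper, not ggml API.
    static void pack_flash_attn_ext_params(int32_t op_params[4], float scale, float max_bias, int32_t prec) {
        memcpy(&op_params[0], &scale,    sizeof(float));
        memcpy(&op_params[1], &max_bias, sizeof(float));
        op_params[2] = prec;
    }
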
@@ -15242,38 +16396,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
15242
16396
  vvexpf(SM, SM, &Mup);
15243
16397
  ggml_vec_sum_f32(Mup, &sum, SM);
15244
16398
  #else
15245
- uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
15246
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
15247
-
15248
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
15249
- if (i >= masked_begin) {
15250
- break;
15251
- }
15252
- float * SR = S + i;
15253
- float * SW = SM + i;
15254
-
15255
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
15256
- if (i + j >= masked_begin) {
15257
- break;
15258
- } else if (SR[j] == -INFINITY) {
15259
- SW[j] = 0.0f;
15260
- } else {
15261
- #ifndef GGML_FLASH_ATTN_EXP_FP16
15262
- const float val = expf(SR[j] - max);
15263
- #else
15264
- ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
15265
- memcpy(&scvt[j], &s, sizeof(uint16_t));
15266
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
15267
- #endif
15268
- sump[j] += (ggml_float)val;
15269
- SW[j] = val;
15270
- }
15271
- }
15272
- }
15273
-
15274
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
15275
- sum += sump[i];
15276
- }
16399
+ sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
15277
16400
  #endif
15278
16401
  }
15279
16402
 
@@ -15855,6 +16978,10 @@ static void ggml_compute_forward_unary(
15855
16978
  {
15856
16979
  ggml_compute_forward_relu(params, dst);
15857
16980
  } break;
16981
+ case GGML_UNARY_OP_SIGMOID:
16982
+ {
16983
+ ggml_compute_forward_sigmoid(params, dst);
16984
+ } break;
15858
16985
  case GGML_UNARY_OP_GELU:
15859
16986
  {
15860
16987
  ggml_compute_forward_gelu(params, dst);
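
GGML_UNARY_OP_SIGMOID is new in this dispatch table; the ggml_compute_forward_sigmoid kernel it forwards to appears earlier in the patch and is not reproduced on this page. Functionally it is the element-wise logistic function, roughly as sketched below (naming and row layout here are assumptions):

    #include <math.h>

    // Element-wise sigmoid over a contiguous float row: y[i] = 1 / (1 + exp(-x[i])).
    // Simplified sketch of the operation GGML_UNARY_OP_SIGMOID dispatches to.
    static void sigmoid_row_f32(const int n, float * y, const float * x) {
        for (int i = 0; i < n; ++i) {
            y[i] = 1.0f / (1.0f + expf(-x[i]));
        }
    }
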
@@ -15921,6 +17048,7 @@ static void ggml_compute_forward_get_rel_pos(
15921
17048
 
15922
17049
  switch (src0->type) {
15923
17050
  case GGML_TYPE_F16:
17051
+ case GGML_TYPE_BF16:
15924
17052
  {
15925
17053
  ggml_compute_forward_get_rel_pos_f16(params, dst);
15926
17054
  } break;
@@ -16294,35 +17422,15 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
16294
17422
  assert(!isnan(s1[i]));
16295
17423
  }
16296
17424
  #endif
16297
- // soft_max
16298
- ggml_float sum = 0.0;
16299
- {
16300
- float max = -INFINITY;
16301
- ggml_vec_max_f32(nc, &max, s0);
16302
17425
 
16303
- uint16_t scvt; UNUSED(scvt);
16304
- for (int i = 0; i < nc; i++) {
16305
- if (s0[i] == -INFINITY) {
16306
- st[i] = 0.0f;
16307
- } else {
16308
- #ifndef GGML_CROSS_ENTROPY_EXP_FP16
16309
- const float s = s0[i] - max;
16310
- const float val = expf(s);
16311
- #else
16312
- ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
16313
- memcpy(&scvt, &s, sizeof(scvt));
16314
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
16315
- #endif
16316
- sum += (ggml_float)val;
16317
- st[i] = val;
16318
- }
16319
- }
17426
+ // soft_max
17427
+ float max = -INFINITY;
17428
+ ggml_vec_max_f32(nc, &max, s0);
17429
+ ggml_float sum = ggml_vec_soft_max_f32(nc, st, s0, max);
17430
+ assert(sum > 0.0);
17431
+ sum = (1.0 - eps) / sum;
16320
17432
 
16321
- assert(sum > 0.0);
16322
- // sum = 1.0/sum;
16323
- }
16324
17433
  // avoid log(0) by rescaling from [0..1] to [eps..1]
16325
- sum = (1.0 - eps) / sum;
16326
17434
  ggml_vec_scale_f32(nc, st, sum);
16327
17435
  ggml_vec_add1_f32(nc, st, st, eps);
16328
17436
  ggml_vec_log_f32(nc, st, st);
@@ -16412,32 +17520,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
16412
17520
  #endif
16413
17521
 
16414
17522
  // soft_max
16415
- ggml_float sum = 0.0;
16416
- {
16417
- float max = -INFINITY;
16418
- ggml_vec_max_f32(nc, &max, s0);
16419
-
16420
- uint16_t scvt; UNUSED(scvt);
16421
- for (int i = 0; i < nc; i++) {
16422
- if (s0[i] == -INFINITY) {
16423
- ds0[i] = 0.0f;
16424
- } else {
16425
- #ifndef GGML_CROSS_ENTROPY_EXP_FP16
16426
- const float s = s0[i] - max;
16427
- const float val = expf(s);
16428
- #else
16429
- ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
16430
- memcpy(&scvt, &s, sizeof(scvt));
16431
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
16432
- #endif
16433
- sum += (ggml_float)val;
16434
- ds0[i] = val;
16435
- }
16436
- }
16437
-
16438
- assert(sum > 0.0);
16439
- sum = (1.0 - eps)/sum;
16440
- }
17523
+ float max = -INFINITY;
17524
+ ggml_vec_max_f32(nc, &max, s0);
17525
+ ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
17526
+ assert(sum > 0.0);
17527
+ sum = (1.0 - eps) / sum;
16441
17528
 
16442
17529
  // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
16443
17530
  ggml_vec_scale_f32(nc, ds0, sum);
@@ -16474,7 +17561,7 @@ static void ggml_compute_forward_cross_entropy_loss_back(
16474
17561
 
16475
17562
  /////////////////////////////////
16476
17563
 
16477
- static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
17564
+ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_compute_state * state) {
16478
17565
  GGML_ASSERT(params);
16479
17566
 
16480
17567
  if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
@@ -16572,7 +17659,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
16572
17659
  } break;
16573
17660
  case GGML_OP_MUL_MAT:
16574
17661
  {
16575
- ggml_compute_forward_mul_mat(params, tensor);
17662
+ ggml_compute_forward_mul_mat(params, tensor, state);
16576
17663
  } break;
16577
17664
  case GGML_OP_MUL_MAT_ID:
16578
17665
  {
@@ -16650,10 +17737,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
16650
17737
  {
16651
17738
  ggml_compute_forward_rope_back(params, tensor);
16652
17739
  } break;
16653
- case GGML_OP_ALIBI:
16654
- {
16655
- ggml_compute_forward_alibi(params, tensor);
16656
- } break;
16657
17740
  case GGML_OP_CLAMP:
16658
17741
  {
16659
17742
  ggml_compute_forward_clamp(params, tensor);
@@ -17672,10 +18755,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
17672
18755
  zero_table);
17673
18756
  }
17674
18757
  } break;
17675
- case GGML_OP_ALIBI:
17676
- {
17677
- GGML_ASSERT(false); // TODO: not implemented
17678
- } break;
17679
18758
  case GGML_OP_CLAMP:
17680
18759
  {
17681
18760
  GGML_ASSERT(false); // TODO: not implemented
@@ -17846,6 +18925,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
17846
18925
  zero_table);
17847
18926
  }
17848
18927
  } break;
18928
+ case GGML_UNARY_OP_SIGMOID:
18929
+ {
18930
+ GGML_ASSERT(false); // TODO: not implemented
18931
+ } break;
17849
18932
  case GGML_UNARY_OP_GELU:
17850
18933
  {
17851
18934
  GGML_ASSERT(false); // TODO: not implemented
@@ -18192,8 +19275,6 @@ typedef int ggml_lock_t;
18192
19275
 
18193
19276
  #define GGML_LOCK_INITIALIZER 0
18194
19277
 
18195
- typedef pthread_t ggml_thread_t;
18196
-
18197
19278
  #define ggml_thread_create pthread_create
18198
19279
  #define ggml_thread_join pthread_join
18199
19280
 
@@ -18219,8 +19300,6 @@ typedef int ggml_lock_t;
18219
19300
 
18220
19301
  #define GGML_LOCK_INITIALIZER 0
18221
19302
 
18222
- typedef pthread_t ggml_thread_t;
18223
-
18224
19303
  #define ggml_thread_create pthread_create
18225
19304
  #define ggml_thread_join pthread_join
18226
19305
 
@@ -18300,31 +19379,6 @@ static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n); }
18300
19379
  static void clear_numa_thread_affinity(void) {}
18301
19380
  #endif
18302
19381
 
18303
- struct ggml_compute_state_shared {
18304
- const struct ggml_cgraph * cgraph;
18305
- const struct ggml_cplan * cplan;
18306
-
18307
- int64_t perf_node_start_cycles;
18308
- int64_t perf_node_start_time_us;
18309
-
18310
- const int n_threads;
18311
-
18312
- // synchronization primitives
18313
- atomic_int n_active; // num active threads
18314
- atomic_int node_n; // active graph node
18315
- atomic_int node_task; // active graph node task phase
18316
-
18317
- ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
18318
- void * abort_callback_data;
18319
- };
18320
-
18321
- struct ggml_compute_state {
18322
- ggml_thread_t thrd;
18323
- int ith;
18324
- struct ggml_compute_state_shared * shared;
18325
- enum ggml_status ec;
18326
- };
18327
-
18328
19382
  static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) {
18329
19383
  int64_t cycles_cur = ggml_perf_cycles() - st->perf_node_start_cycles;
18330
19384
  int64_t time_us_cur = ggml_perf_time_us() - st->perf_node_start_time_us;
@@ -18375,6 +19429,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
18375
19429
  case GGML_UNARY_OP_TANH:
18376
19430
  case GGML_UNARY_OP_ELU:
18377
19431
  case GGML_UNARY_OP_RELU:
19432
+ case GGML_UNARY_OP_SIGMOID:
18378
19433
  case GGML_UNARY_OP_HARDSWISH: // to opt for multiple threads
18379
19434
  case GGML_UNARY_OP_HARDSIGMOID: // to opt for multiple threads
18380
19435
  {
@@ -18448,10 +19503,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
18448
19503
  {
18449
19504
  n_tasks = n_threads;
18450
19505
  } break;
18451
- case GGML_OP_ALIBI:
18452
- {
18453
- n_tasks = 1; //TODO
18454
- } break;
18455
19506
  case GGML_OP_CLAMP:
18456
19507
  {
18457
19508
  n_tasks = 1; //TODO
@@ -18600,6 +19651,10 @@ static void ggml_graph_compute_thread_sync_node(int * node_n, struct ggml_comput
18600
19651
 
18601
19652
  * node_n = atomic_load(&state->shared->node_n);
18602
19653
  if (* node_n != last_node_n) break;
19654
+ #if defined(__SSE3__)
19655
+ // Tell the processor we're spinning. It's a processor hint for spinlocks.
19656
+ _mm_pause();
19657
+ #endif
18603
19658
  }
18604
19659
  }
18605
19660
 
@@ -18614,6 +19669,10 @@ static void ggml_graph_compute_thread_sync_task(int * task_phase, struct ggml_co
18614
19669
 
18615
19670
  * task_phase = atomic_load(&state->shared->node_task);
18616
19671
  if (* task_phase != last_task_phase) break;
19672
+ #if defined(__SSE3__)
19673
+ // Tell the processor we're spinning. It's a processor hint for spinlocks.
19674
+ _mm_pause();
19675
+ #endif
18617
19676
  }
18618
19677
  }
18619
19678
 
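
Both sync loops above now issue _mm_pause() on builds with __SSE3__ defined while waiting for node_n or node_task to change, which reduces the cost of busy-waiting on x86. The same pattern in isolation, with illustrative names (only _mm_pause and the C11 atomics are real APIs here):

    #include <stdatomic.h>
    #if defined(__SSE3__)
    #include <immintrin.h>   // _mm_pause
    #endif

    // Spin until the shared counter changes from last_seen, hinting the CPU on
    // every iteration that this is a spin-wait loop.
    static int spin_wait_for_change(atomic_int * counter, int last_seen) {
        int cur = atomic_load(counter);
        while (cur == last_seen) {
    #if defined(__SSE3__)
            _mm_pause();
    #endif
            cur = atomic_load(counter);
        }
        return cur;
    }
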
@@ -18653,7 +19712,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
18653
19712
  struct ggml_tensor * node = cgraph->nodes[node_n];
18654
19713
  if (GGML_OP_HAS_FINALIZE[node->op]) {
18655
19714
  params.nth = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
18656
- ggml_compute_forward(&params, node);
19715
+ ggml_compute_forward(&params, node, state);
18657
19716
  }
18658
19717
  ggml_graph_compute_perf_stats_node(node, state->shared);
18659
19718
  }
@@ -18673,17 +19732,17 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
18673
19732
  /* INIT */
18674
19733
  if (GGML_OP_HAS_INIT[node->op]) {
18675
19734
  params.type = GGML_TASK_TYPE_INIT;
18676
- ggml_compute_forward(&params, node);
19735
+ ggml_compute_forward(&params, node, state);
18677
19736
  }
18678
19737
 
18679
19738
  // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
18680
19739
  // they do something more efficient than spinning (?)
18681
19740
  params.type = GGML_TASK_TYPE_COMPUTE;
18682
- ggml_compute_forward(&params, node);
19741
+ ggml_compute_forward(&params, node, state);
18683
19742
 
18684
19743
  if (GGML_OP_HAS_FINALIZE[node->op]) {
18685
19744
  params.type = GGML_TASK_TYPE_FINALIZE;
18686
- ggml_compute_forward(&params, node);
19745
+ ggml_compute_forward(&params, node, state);
18687
19746
  }
18688
19747
 
18689
19748
  ggml_graph_compute_perf_stats_node(node, state->shared);
@@ -18722,7 +19781,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
18722
19781
 
18723
19782
  if (state->ith < n_tasks) {
18724
19783
  if (GGML_OP_HAS_INIT[node->op]) {
18725
- ggml_compute_forward(&params, node);
19784
+ ggml_compute_forward(&params, node, state);
18726
19785
  }
18727
19786
  }
18728
19787
 
@@ -18743,7 +19802,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
18743
19802
 
18744
19803
  if (state->ith < n_tasks) {
18745
19804
  params.type = GGML_TASK_TYPE_COMPUTE;
18746
- ggml_compute_forward(&params, node);
19805
+ ggml_compute_forward(&params, node, state);
18747
19806
  }
18748
19807
 
18749
19808
  if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
@@ -18785,7 +19844,10 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18785
19844
  case GGML_OP_CPY:
18786
19845
  case GGML_OP_DUP:
18787
19846
  {
18788
- if (ggml_is_quantized(node->type)) {
19847
+ if (ggml_is_quantized(node->type) ||
19848
+ // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
19849
+ (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
19850
+ (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
18789
19851
  cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
18790
19852
  }
18791
19853
  } break;
@@ -18864,7 +19926,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18864
19926
  const int64_t ne10 = node->src[1]->ne[0]; // L
18865
19927
  const int64_t ne11 = node->src[1]->ne[1]; // Cin
18866
19928
 
18867
- if (node->src[0]->type == GGML_TYPE_F16 &&
19929
+ if ((node->src[0]->type == GGML_TYPE_F16 ||
19930
+ node->src[0]->type == GGML_TYPE_BF16) &&
18868
19931
  node->src[1]->type == GGML_TYPE_F32) {
18869
19932
  cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
18870
19933
  cur += sizeof(ggml_fp16_t)*ne10*ne11;
@@ -18900,6 +19963,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18900
19963
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18901
19964
  cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
18902
19965
  cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19966
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19967
+ cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19968
+ cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
18903
19969
  }
18904
19970
  } break;
18905
19971
  case GGML_OP_FLASH_ATTN_EXT:
@@ -18916,6 +19982,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18916
19982
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18917
19983
  cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
18918
19984
  cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19985
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19986
+ cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19987
+ cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
18919
19988
  }
18920
19989
  } break;
18921
19990
  case GGML_OP_FLASH_ATTN_BACK:
@@ -18929,6 +19998,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18929
19998
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18930
19999
  cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
18931
20000
  cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
20001
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
20002
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
20003
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
18932
20004
  }
18933
20005
  } break;
18934
20006
 
@@ -18981,6 +20053,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
18981
20053
  /*.node_task =*/ GGML_TASK_TYPE_FINALIZE,
18982
20054
  /*.abort_callback =*/ NULL,
18983
20055
  /*.abort_callback_data =*/ NULL,
20056
+ /*.current_chunk =*/ 0,
18984
20057
  };
18985
20058
  struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
18986
20059
 
@@ -19705,7 +20778,9 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
19705
20778
  if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
19706
20779
  fprintf(fp, "%d", ggml_get_i32_1d(node, j));
19707
20780
  }
19708
- else if (node->type == GGML_TYPE_F32 || node->type == GGML_TYPE_F16) {
20781
+ else if (node->type == GGML_TYPE_F32 ||
20782
+ node->type == GGML_TYPE_F16 ||
20783
+ node->type == GGML_TYPE_BF16) {
19709
20784
  fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j));
19710
20785
  }
19711
20786
  else {
@@ -20763,6 +21838,12 @@ size_t ggml_quantize_chunk(
20763
21838
  ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
20764
21839
  result = n * elemsize;
20765
21840
  } break;
21841
+ case GGML_TYPE_BF16:
21842
+ {
21843
+ size_t elemsize = sizeof(ggml_bf16_t);
21844
+ ggml_fp32_to_bf16_row(src + start, (ggml_bf16_t *)dst + start, n);
21845
+ result = n * elemsize;
21846
+ } break;
20766
21847
  case GGML_TYPE_F32:
20767
21848
  {
20768
21849
  size_t elemsize = sizeof(float);
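
ggml_quantize_chunk gains a GGML_TYPE_BF16 case backed by ggml_fp32_to_bf16_row. bfloat16 keeps float32's sign and 8-bit exponent and truncates the mantissa to 7 bits, so the scalar conversion is essentially a 16-bit shift of the IEEE-754 bit pattern. The sketch below shows the simplest truncating form; the actual GGML_FP32_TO_BF16 macro may round and special-case NaN, so treat this as an approximation rather than the library's exact behavior.

    #include <stdint.h>
    #include <string.h>

    // Truncating float32 -> bfloat16: keep the upper 16 bits of the single-
    // precision bit pattern (sign, 8-bit exponent, top 7 mantissa bits).
    static uint16_t fp32_to_bf16_trunc(float x) {
        uint32_t bits;
        memcpy(&bits, &x, sizeof(bits));
        return (uint16_t)(bits >> 16);
    }

    // bfloat16 -> float32 is the inverse: shift back into the high half.
    static float bf16_to_fp32(uint16_t h) {
        uint32_t bits = (uint32_t) h << 16;
        float x;
        memcpy(&x, &bits, sizeof(x));
        return x;
    }
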
@@ -21139,7 +22220,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
21139
22220
  }
21140
22221
 
21141
22222
  // read the tensor infos
21142
- {
22223
+ if (ctx->header.n_tensors > 0) {
21143
22224
  ctx->infos = GGML_CALLOC(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));
21144
22225
 
21145
22226
  for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {