llama_cpp 0.15.0 → 0.15.1

This diff compares the contents of publicly available package versions as released to one of the supported registries; it is provided for informational purposes only and reflects the changes as they appear in those public registries.
@@ -322,7 +322,7 @@ static ggml_fp16_t ggml_table_exp_f16[1 << 16];
322
322
  // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
323
323
  float ggml_table_f32_f16[1 << 16];
324
324
 
325
- const char * ggml_status_to_string(enum ggml_status status) {
325
+ GGML_CALL const char * ggml_status_to_string(enum ggml_status status) {
326
326
  switch (status) {
327
327
  case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
328
328
  case GGML_STATUS_FAILED: return "GGML status: error (operation failed)";
@@ -333,16 +333,26 @@ const char * ggml_status_to_string(enum ggml_status status) {
333
333
  return "GGML status: unknown";
334
334
  }
335
335
 
336
- // note: do not use these inside ggml.c
337
- // these are meant to be used via the ggml.h API
338
336
  float ggml_fp16_to_fp32(ggml_fp16_t x) {
337
+ #define ggml_fp16_to_fp32 do_not_use__ggml_fp16_to_fp32__in_ggml
339
338
  return GGML_FP16_TO_FP32(x);
340
339
  }
341
340
 
342
341
  ggml_fp16_t ggml_fp32_to_fp16(float x) {
342
+ #define ggml_fp32_to_fp16 do_not_use__ggml_fp32_to_fp16__in_ggml
343
343
  return GGML_FP32_TO_FP16(x);
344
344
  }
345
345
 
346
+ float ggml_bf16_to_fp32(ggml_bf16_t x) {
347
+ #define ggml_bf16_to_fp32 do_not_use__ggml_bf16_to_fp32__in_ggml
348
+ return GGML_BF16_TO_FP32(x); // it just left shifts
349
+ }
350
+
351
+ ggml_bf16_t ggml_fp32_to_bf16(float x) {
352
+ #define ggml_fp32_to_bf16 do_not_use__ggml_fp32_to_bf16__in_ggml
353
+ return GGML_FP32_TO_BF16(x);
354
+ }
355
+
346
356
  void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
347
357
  for (int64_t i = 0; i < n; i++) {
348
358
  y[i] = GGML_FP16_TO_FP32(x[i]);
@@ -368,6 +378,49 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
368
378
  }
369
379
  }
370
380
 
381
+ void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
382
+ int64_t i = 0;
383
+ #if defined(__AVX512F__)
384
+ for (; i + 16 <= n; i += 16) {
385
+ _mm512_storeu_ps(y + i,
386
+ _mm512_castsi512_ps(
387
+ _mm512_slli_epi32(
388
+ _mm512_cvtepu16_epi32(
389
+ _mm256_loadu_si256(
390
+ (const __m256i *)(x + i))),
391
+ 16)));
392
+ }
393
+ #elif defined(__AVX2__)
394
+ for (; i + 8 <= n; i += 8) {
395
+ _mm256_storeu_ps(y + i,
396
+ _mm256_castsi256_ps(
397
+ _mm256_slli_epi32(
398
+ _mm256_cvtepu16_epi32(
399
+ _mm_loadu_si128(
400
+ (const __m128i *)(x + i))),
401
+ 16)));
402
+ }
403
+ #endif
404
+ for (; i < n; i++) {
405
+ y[i] = GGML_BF16_TO_FP32(x[i]);
406
+ }
407
+ }
408
+
409
+ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
410
+ int i = 0;
411
+ #if defined(__AVX512BF16__)
412
+ for (; i + 32 <= n; i += 32) {
413
+ _mm512_storeu_ps(
414
+ (__m512 *)(y + i),
415
+ (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
416
+ _mm512_loadu_ps(x + i)));
417
+ }
418
+ #endif
419
+ for (; i < n; i++) {
420
+ y[i] = GGML_FP32_TO_BF16(x[i]);
421
+ }
422
+ }
423
+
371
424
  bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
372
425
  return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
373
426
  }
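
Note: bfloat16 is simply the upper 16 bits of an IEEE-754 binary32, which is why the widening conversion above is described in the hunk as "it just left shifts". A minimal standalone sketch of the two conversions, using a plain uint16_t container instead of ggml_bf16_t and ignoring any NaN special-casing the real GGML_FP32_TO_BF16 may apply:

    #include <stdint.h>
    #include <string.h>

    /* widen: the stored 16 bits become the high half of the f32 bit pattern */
    static float bf16_to_f32(uint16_t h) {
        uint32_t bits = (uint32_t) h << 16;
        float f;
        memcpy(&f, &bits, sizeof(f));   /* bit-level reinterpretation */
        return f;
    }

    /* narrow: round to nearest, ties to even, then keep the high 16 bits */
    static uint16_t f32_to_bf16(float f) {
        uint32_t bits;
        memcpy(&bits, &f, sizeof(bits));
        bits += 0x7FFF + ((bits >> 16) & 1);
        return (uint16_t)(bits >> 16);
    }
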
@@ -503,6 +556,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
503
556
 
504
557
  static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
505
558
  static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
559
+ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
506
560
 
507
561
  static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
508
562
  [GGML_TYPE_I8] = {
@@ -845,6 +899,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
845
899
  .type_size = sizeof(block_q8_K),
846
900
  .is_quantized = true,
847
901
  .from_float = quantize_row_q8_K,
902
+ },
903
+ [GGML_TYPE_BF16] = {
904
+ .type_name = "bf16",
905
+ .blck_size = 1,
906
+ .type_size = sizeof(ggml_bf16_t),
907
+ .is_quantized = false,
908
+ .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row,
909
+ .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row,
910
+ .from_float_reference = (ggml_from_float_t) ggml_fp32_to_bf16_row,
911
+ .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
912
+ .vec_dot_type = GGML_TYPE_BF16,
913
+ .nrows = 1,
848
914
  }
849
915
  };
850
916
 
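
Note: this type_traits entry is what exposes BF16 to ggml's generic code paths; anything that dispatches through type_traits[type].to_float, .from_float or .vec_dot picks the new type up without further changes. A rough sketch of that dispatch pattern (row_to_f32 is a hypothetical helper; type_traits and ggml_to_float_t are the names used in this file):

    /* convert one row of any registered type to f32 via the traits table */
    static void row_to_f32(enum ggml_type type, const void * src, float * dst, int64_t n) {
        ggml_to_float_t to_float = type_traits[type].to_float;
        GGML_ASSERT(to_float != NULL);  /* for BF16 this is ggml_bf16_to_fp32_row */
        to_float(src, dst, n);
    }
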
@@ -1480,6 +1546,8 @@ inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) {
1480
1546
 
1481
1547
  inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1482
1548
 
1549
+ inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1550
+
1483
1551
  inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
1484
1552
  inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
1485
1553
  inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
@@ -1498,7 +1566,7 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float *
1498
1566
  UNUSED(by);
1499
1567
  UNUSED(bs);
1500
1568
 
1501
- #ifdef GGML_SIMD
1569
+ #if defined(GGML_SIMD)
1502
1570
  float sumf = 0.0f;
1503
1571
  const int np = (n & ~(GGML_F32_STEP - 1));
1504
1572
 
@@ -1534,6 +1602,70 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float *
1534
1602
  *s = sumf;
1535
1603
  }
1536
1604
 
1605
+ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) {
1606
+ assert(nrc == 1);
1607
+ UNUSED(nrc);
1608
+ UNUSED(bx);
1609
+ UNUSED(by);
1610
+ UNUSED(bs);
1611
+ int i = 0;
1612
+ ggml_float sumf = 0;
1613
+
1614
+ #if defined(__AVX512BF16__)
1615
+ __m512 c1 = _mm512_setzero_ps();
1616
+ __m512 c2 = _mm512_setzero_ps();
1617
+ for (; i + 64 <= n; i += 64) {
1618
+ c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
1619
+ (__m512bh)_mm512_loadu_ps((const float *)(y + i)));
1620
+ c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
1621
+ (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
1622
+ }
1623
+ sumf += (ggml_float)_mm512_reduce_add_ps(c1);
1624
+ sumf += (ggml_float)_mm512_reduce_add_ps(c2);
1625
+
1626
+ #elif defined(__AVX512F__)
1627
+ #define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
1628
+ __m512 c1 = _mm512_setzero_ps();
1629
+ __m512 c2 = _mm512_setzero_ps();
1630
+ for (; i + 32 <= n; i += 32) {
1631
+ c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
1632
+ c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
1633
+ }
1634
+ sumf += (ggml_float)_mm512_reduce_add_ps(c1);
1635
+ sumf += (ggml_float)_mm512_reduce_add_ps(c2);
1636
+
1637
+ #undef LOAD
1638
+ #elif defined(__AVX2__)
1639
+ #define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
1640
+ __m256 c1 = _mm256_setzero_ps();
1641
+ __m256 c2 = _mm256_setzero_ps();
1642
+ __m256 c3 = _mm256_setzero_ps();
1643
+ __m256 c4 = _mm256_setzero_ps();
1644
+ for (; i + 32 <= n; i += 32) {
1645
+ c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
1646
+ c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
1647
+ c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
1648
+ c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
1649
+ }
1650
+ __m128 g;
1651
+ c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
1652
+ _mm256_add_ps(c2, c4));
1653
+ g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
1654
+ _mm256_castps256_ps128(c1));
1655
+ g = _mm_add_ps(g, _mm_movehl_ps(g, g));
1656
+ g = _mm_add_ss(g, _mm_movehdup_ps(g));
1657
+ sumf += (ggml_float)_mm_cvtss_f32(g);
1658
+
1659
+ #undef LOAD
1660
+ #endif
1661
+
1662
+ for (; i < n; ++i) {
1663
+ sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
1664
+ GGML_BF16_TO_FP32(y[i]));
1665
+ }
1666
+ *s = sumf;
1667
+ }
1668
+
1537
1669
  static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
1538
1670
  assert(nrc == 1);
1539
1671
  UNUSED(nrc);
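
Note: the three SIMD branches above (AVX512-BF16 dot-product instructions, AVX-512F, AVX2) compute the same quantity as the scalar tail loop: widen both bf16 operands to f32 and accumulate the products in a wider type (ggml_float, typically double). A scalar reference of what ggml_vec_dot_bf16 computes, reusing the illustrative bf16_to_f32() helper sketched earlier in these notes:

    /* reference dot product for two bf16 rows of length n */
    static float dot_bf16_ref(int n, const uint16_t * x, const uint16_t * y) {
        double sum = 0.0;   /* wide accumulator, like ggml_float */
        for (int i = 0; i < n; ++i) {
            sum += (double) bf16_to_f32(x[i]) * (double) bf16_to_f32(y[i]);
        }
        return (float) sum;
    }
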
@@ -1967,6 +2099,14 @@ inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_
1967
2099
  *s = sum;
1968
2100
  }
1969
2101
 
2102
+ inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {
2103
+ float sum = 0.0f;
2104
+ for (int i = 0; i < n; ++i) {
2105
+ sum += GGML_BF16_TO_FP32(x[i]);
2106
+ }
2107
+ *s = sum;
2108
+ }
2109
+
1970
2110
  inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
1971
2111
  #ifndef GGML_USE_ACCELERATE
1972
2112
  float max = -INFINITY;
@@ -2377,7 +2517,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
2377
2517
  // figure out which node we're on
2378
2518
  uint current_cpu;
2379
2519
  int getcpu_ret = 0;
2380
- #if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28)
2520
+ #if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__)
2381
2521
  getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node);
2382
2522
  #else
2383
2523
  // old glibc doesn't have a wrapper for this call. Fall back on direct syscall
@@ -2588,6 +2728,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
2588
2728
  switch (ftype) {
2589
2729
  case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break;
2590
2730
  case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break;
2731
+ case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break;
2591
2732
  case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break;
2592
2733
  case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break;
2593
2734
  case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
@@ -2729,15 +2870,16 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
2729
2870
  {
2730
2871
  const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
2731
2872
 
2732
- ggml_fp16_t ii;
2733
2873
  for (int i = 0; i < (1 << 16); ++i) {
2734
- uint16_t ui = i;
2735
- memcpy(&ii, &ui, sizeof(ii));
2736
- const float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
2874
+ union {
2875
+ uint16_t u16;
2876
+ ggml_fp16_t fp16;
2877
+ } u = {i};
2878
+ float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
2737
2879
  ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
2738
2880
  ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
2739
2881
  ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
2740
- ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
2882
+ ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
2741
2883
  }
2742
2884
 
2743
2885
  const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
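
Note: the table-initialization change swaps a memcpy through a temporary for a union, the other aliasing-safe way in C to reinterpret the loop index as an f16 bit pattern; the distinction matters most on targets where ggml_fp16_t is a genuine half type rather than a uint16_t. A minimal sketch of the idiom with stand-in names (fp16_t and fp16_bits_to_f32 are placeholders, not ggml API):

    #include <stdint.h>

    typedef uint16_t fp16_t;            /* stand-in for ggml_fp16_t */
    float fp16_bits_to_f32(fp16_t h);   /* stand-in for GGML_COMPUTE_FP16_TO_FP32 */

    static void init_f16_table(float * table /* 1 << 16 entries */) {
        for (uint32_t i = 0; i < (1u << 16); ++i) {
            union {
                uint16_t u16;
                fp16_t   fp16;
            } u = { (uint16_t) i };                 /* store the index as raw bits */
            table[i] = fp16_bits_to_f32(u.fp16);    /* read them back as f16 */
        }
    }
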
@@ -3201,6 +3343,13 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
3201
3343
  ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
3202
3344
  }
3203
3345
  } break;
3346
+ case GGML_TYPE_BF16:
3347
+ {
3348
+ assert(tensor->nb[0] == sizeof(ggml_fp16_t));
3349
+ for (int i = 0; i < n; i++) {
3350
+ ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
3351
+ }
3352
+ } break;
3204
3353
  case GGML_TYPE_F32:
3205
3354
  {
3206
3355
  assert(tensor->nb[0] == sizeof(float));
@@ -3253,6 +3402,13 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3253
3402
  ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
3254
3403
  }
3255
3404
  } break;
3405
+ case GGML_TYPE_BF16:
3406
+ {
3407
+ assert(tensor->nb[0] == sizeof(ggml_bf16_t));
3408
+ for (int i = 0; i < n; i++) {
3409
+ ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
3410
+ }
3411
+ } break;
3256
3412
  case GGML_TYPE_F32:
3257
3413
  {
3258
3414
  assert(tensor->nb[0] == sizeof(float));
@@ -3320,6 +3476,11 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3320
3476
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3321
3477
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
3322
3478
  }
3479
+ case GGML_TYPE_BF16:
3480
+ {
3481
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3482
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
3483
+ }
3323
3484
  case GGML_TYPE_F32:
3324
3485
  {
3325
3486
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3362,6 +3523,11 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3362
3523
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3363
3524
  ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
3364
3525
  } break;
3526
+ case GGML_TYPE_BF16:
3527
+ {
3528
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3529
+ ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
3530
+ } break;
3365
3531
  case GGML_TYPE_F32:
3366
3532
  {
3367
3533
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3385,6 +3551,8 @@ int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i
3385
3551
  return ((int32_t *) data)[0];
3386
3552
  case GGML_TYPE_F16:
3387
3553
  return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
3554
+ case GGML_TYPE_BF16:
3555
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
3388
3556
  case GGML_TYPE_F32:
3389
3557
  return ((float *) data)[0];
3390
3558
  default:
@@ -3413,6 +3581,10 @@ void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3413
3581
  {
3414
3582
  ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
3415
3583
  } break;
3584
+ case GGML_TYPE_BF16:
3585
+ {
3586
+ ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
3587
+ } break;
3416
3588
  case GGML_TYPE_F32:
3417
3589
  {
3418
3590
  ((float *)(data))[0] = value;
@@ -3451,6 +3623,11 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3451
3623
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3452
3624
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
3453
3625
  }
3626
+ case GGML_TYPE_BF16:
3627
+ {
3628
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3629
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
3630
+ }
3454
3631
  case GGML_TYPE_F32:
3455
3632
  {
3456
3633
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3493,6 +3670,11 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
3493
3670
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3494
3671
  ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
3495
3672
  } break;
3673
+ case GGML_TYPE_BF16:
3674
+ {
3675
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3676
+ ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
3677
+ } break;
3496
3678
  case GGML_TYPE_F32:
3497
3679
  {
3498
3680
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3516,6 +3698,8 @@ float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3516
3698
  return ((int32_t *) data)[0];
3517
3699
  case GGML_TYPE_F16:
3518
3700
  return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
3701
+ case GGML_TYPE_BF16:
3702
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
3519
3703
  case GGML_TYPE_F32:
3520
3704
  return ((float *) data)[0];
3521
3705
  default:
@@ -3544,6 +3728,10 @@ void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3544
3728
  {
3545
3729
  ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
3546
3730
  } break;
3731
+ case GGML_TYPE_BF16:
3732
+ {
3733
+ ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
3734
+ } break;
3547
3735
  case GGML_TYPE_F32:
3548
3736
  {
3549
3737
  ((float *)(data))[0] = value;
@@ -3738,7 +3926,11 @@ static struct ggml_tensor * ggml_add_cast_impl(
3738
3926
  // TODO: support less-strict constraint
3739
3927
  // GGML_ASSERT(ggml_can_repeat(b, a));
3740
3928
  GGML_ASSERT(ggml_can_repeat_rows(b, a));
3741
- GGML_ASSERT(ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16); // currently only supported for quantized input and f16
3929
+
3930
+ // currently only supported for quantized input and f16
3931
+ GGML_ASSERT(ggml_is_quantized(a->type) ||
3932
+ a->type == GGML_TYPE_F16 ||
3933
+ a->type == GGML_TYPE_BF16);
3742
3934
 
3743
3935
  bool is_node = false;
3744
3936
 
@@ -7215,8 +7407,8 @@ static void ggml_compute_forward_dup_same_cont(
7215
7407
  ((char *) src0->data + ie0*nb00),
7216
7408
  (ie1 - ie0) * ggml_type_size(src0->type));
7217
7409
  }
7218
-
7219
7410
  }
7411
+
7220
7412
  static void ggml_compute_forward_dup_f16(
7221
7413
  const struct ggml_compute_params * params,
7222
7414
  struct ggml_tensor * dst) {
@@ -7490,7 +7682,7 @@ static void ggml_compute_forward_dup_f16(
7490
7682
  }
7491
7683
  }
7492
7684
 
7493
- static void ggml_compute_forward_dup_f32(
7685
+ static void ggml_compute_forward_dup_bf16(
7494
7686
  const struct ggml_compute_params * params,
7495
7687
  struct ggml_tensor * dst) {
7496
7688
 
@@ -7538,10 +7730,11 @@ static void ggml_compute_forward_dup_f32(
7538
7730
  return;
7539
7731
  }
7540
7732
 
7733
+ // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
7734
+
7541
7735
  if (ggml_is_contiguous(dst)) {
7542
- // TODO: simplify
7543
- if (nb00 == sizeof(float)) {
7544
- if (dst->type == GGML_TYPE_F32) {
7736
+ if (nb00 == sizeof(ggml_bf16_t)) {
7737
+ if (dst->type == GGML_TYPE_BF16) {
7545
7738
  size_t id = 0;
7546
7739
  const size_t rs = ne00 * nb00;
7547
7740
  char * dst_ptr = (char *) dst->data;
@@ -7557,8 +7750,43 @@ static void ggml_compute_forward_dup_f32(
7557
7750
  id += rs * (ne01 - ir1);
7558
7751
  }
7559
7752
  }
7753
+ } else if (dst->type == GGML_TYPE_F16) {
7754
+ size_t id = 0;
7755
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
7756
+
7757
+ for (int i03 = 0; i03 < ne03; i03++) {
7758
+ for (int i02 = 0; i02 < ne02; i02++) {
7759
+ id += ne00 * ir0;
7760
+ for (int i01 = ir0; i01 < ir1; i01++) {
7761
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7762
+ for (int i00 = 0; i00 < ne00; i00++) {
7763
+ dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
7764
+ id++;
7765
+ }
7766
+ }
7767
+ id += ne00 * (ne01 - ir1);
7768
+ }
7769
+ }
7770
+ } else if (dst->type == GGML_TYPE_F32) {
7771
+ size_t id = 0;
7772
+ float * dst_ptr = (float *) dst->data;
7773
+
7774
+ for (int i03 = 0; i03 < ne03; i03++) {
7775
+ for (int i02 = 0; i02 < ne02; i02++) {
7776
+ id += ne00 * ir0;
7777
+ for (int i01 = ir0; i01 < ir1; i01++) {
7778
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7779
+ for (int i00 = 0; i00 < ne00; i00++) {
7780
+ dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
7781
+ id++;
7782
+ }
7783
+ }
7784
+ id += ne00 * (ne01 - ir1);
7785
+ }
7786
+ }
7560
7787
  } else if (type_traits[dst->type].from_float) {
7561
7788
  ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
7789
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
7562
7790
 
7563
7791
  size_t id = 0;
7564
7792
  size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
@@ -7568,8 +7796,13 @@ static void ggml_compute_forward_dup_f32(
7568
7796
  for (int i02 = 0; i02 < ne02; i02++) {
7569
7797
  id += rs * ir0;
7570
7798
  for (int i01 = ir0; i01 < ir1; i01++) {
7571
- const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7572
- quantize_row_q(src0_ptr, dst_ptr + id, ne00);
7799
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7800
+
7801
+ for (int i00 = 0; i00 < ne00; i00++) {
7802
+ src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
7803
+ }
7804
+
7805
+ quantize_row_q(src0_f32, dst_ptr + id, ne00);
7573
7806
  id += rs;
7574
7807
  }
7575
7808
  id += rs * (ne01 - ir1);
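
Note: there is no direct bf16-to-quantized routine, so the dup path above widens each source row into a per-thread f32 scratch buffer (params->wdata) and then calls the destination type's from_float on it. Roughly, per row (the helper itself is illustrative, not ggml API; the macro and typedef names come from the hunk above):

    static void dup_row_bf16_to_quant(ggml_from_float_t quantize_row_q,
                                      const ggml_bf16_t * src, float * scratch_f32,
                                      void * dst_row, int64_t ne00) {
        for (int64_t i = 0; i < ne00; ++i) {
            scratch_f32[i] = GGML_BF16_TO_FP32(src[i]);  /* bf16 -> f32 */
        }
        quantize_row_q(scratch_f32, dst_row, ne00);      /* f32 -> dst type */
    }
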
@@ -7590,7 +7823,25 @@ static void ggml_compute_forward_dup_f32(
7590
7823
  id += ne00 * ir0;
7591
7824
  for (int i01 = ir0; i01 < ir1; i01++) {
7592
7825
  for (int i00 = 0; i00 < ne00; i00++) {
7593
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7826
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7827
+
7828
+ dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
7829
+ id++;
7830
+ }
7831
+ }
7832
+ id += ne00 * (ne01 - ir1);
7833
+ }
7834
+ }
7835
+ } else if (dst->type == GGML_TYPE_BF16) {
7836
+ size_t id = 0;
7837
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
7838
+
7839
+ for (int i03 = 0; i03 < ne03; i03++) {
7840
+ for (int i02 = 0; i02 < ne02; i02++) {
7841
+ id += ne00 * ir0;
7842
+ for (int i01 = ir0; i01 < ir1; i01++) {
7843
+ for (int i00 = 0; i00 < ne00; i00++) {
7844
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7594
7845
 
7595
7846
  dst_ptr[id] = *src0_ptr;
7596
7847
  id++;
@@ -7608,9 +7859,9 @@ static void ggml_compute_forward_dup_f32(
7608
7859
  id += ne00 * ir0;
7609
7860
  for (int i01 = ir0; i01 < ir1; i01++) {
7610
7861
  for (int i00 = 0; i00 < ne00; i00++) {
7611
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7862
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7612
7863
 
7613
- dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
7864
+ dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
7614
7865
  id++;
7615
7866
  }
7616
7867
  }
@@ -7621,18 +7872,16 @@ static void ggml_compute_forward_dup_f32(
7621
7872
  GGML_ASSERT(false); // TODO: implement
7622
7873
  }
7623
7874
  }
7624
-
7625
7875
  return;
7626
7876
  }
7627
7877
 
7628
7878
  // dst counters
7629
-
7630
7879
  int64_t i10 = 0;
7631
7880
  int64_t i11 = 0;
7632
7881
  int64_t i12 = 0;
7633
7882
  int64_t i13 = 0;
7634
7883
 
7635
- if (dst->type == GGML_TYPE_F32) {
7884
+ if (dst->type == GGML_TYPE_BF16) {
7636
7885
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7637
7886
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7638
7887
  i10 += ne00 * ir0;
@@ -7653,7 +7902,59 @@ static void ggml_compute_forward_dup_f32(
7653
7902
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7654
7903
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
7655
7904
 
7656
- memcpy(dst_ptr, src0_ptr, sizeof(float));
7905
+ memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
7906
+
7907
+ if (++i10 == ne00) {
7908
+ i10 = 0;
7909
+ if (++i11 == ne01) {
7910
+ i11 = 0;
7911
+ if (++i12 == ne02) {
7912
+ i12 = 0;
7913
+ if (++i13 == ne03) {
7914
+ i13 = 0;
7915
+ }
7916
+ }
7917
+ }
7918
+ }
7919
+ }
7920
+ }
7921
+ i10 += ne00 * (ne01 - ir1);
7922
+ while (i10 >= ne0) {
7923
+ i10 -= ne0;
7924
+ if (++i11 == ne1) {
7925
+ i11 = 0;
7926
+ if (++i12 == ne2) {
7927
+ i12 = 0;
7928
+ if (++i13 == ne3) {
7929
+ i13 = 0;
7930
+ }
7931
+ }
7932
+ }
7933
+ }
7934
+ }
7935
+ }
7936
+ } else if (dst->type == GGML_TYPE_F16) {
7937
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
7938
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
7939
+ i10 += ne00 * ir0;
7940
+ while (i10 >= ne0) {
7941
+ i10 -= ne0;
7942
+ if (++i11 == ne1) {
7943
+ i11 = 0;
7944
+ if (++i12 == ne2) {
7945
+ i12 = 0;
7946
+ if (++i13 == ne3) {
7947
+ i13 = 0;
7948
+ }
7949
+ }
7950
+ }
7951
+ }
7952
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
7953
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
7954
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7955
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
7956
+
7957
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
7657
7958
 
7658
7959
  if (++i10 == ne0) {
7659
7960
  i10 = 0;
@@ -7684,7 +7985,7 @@ static void ggml_compute_forward_dup_f32(
7684
7985
  }
7685
7986
  }
7686
7987
  }
7687
- } else if (dst->type == GGML_TYPE_F16) {
7988
+ } else if (dst->type == GGML_TYPE_F32) {
7688
7989
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7689
7990
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7690
7991
  i10 += ne00 * ir0;
@@ -7705,7 +8006,7 @@ static void ggml_compute_forward_dup_f32(
7705
8006
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7706
8007
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
7707
8008
 
7708
- *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
8009
+ *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
7709
8010
 
7710
8011
  if (++i10 == ne0) {
7711
8012
  i10 = 0;
@@ -7741,31 +8042,27 @@ static void ggml_compute_forward_dup_f32(
7741
8042
  }
7742
8043
  }
7743
8044
 
7744
- // A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
7745
- static void ggml_compute_forward_dup_bytes(
8045
+ static void ggml_compute_forward_dup_f32(
7746
8046
  const struct ggml_compute_params * params,
7747
8047
  struct ggml_tensor * dst) {
7748
8048
 
7749
8049
  const struct ggml_tensor * src0 = dst->src[0];
7750
8050
 
7751
8051
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
7752
- GGML_ASSERT(src0->type == dst->type);
7753
8052
 
7754
8053
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
7755
8054
  return;
7756
8055
  }
7757
8056
 
7758
- if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
7759
- ggml_compute_forward_dup_same_cont(params, dst);
7760
- return;
7761
- }
7762
-
7763
- GGML_TENSOR_UNARY_OP_LOCALS;
8057
+ GGML_TENSOR_UNARY_OP_LOCALS
7764
8058
 
7765
- const size_t type_size = ggml_type_size(src0->type);
7766
8059
  const int ith = params->ith; // thread index
7767
8060
  const int nth = params->nth; // number of threads
7768
8061
 
8062
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
8063
+ ggml_compute_forward_dup_same_cont(params, dst);
8064
+ return;
8065
+ }
7769
8066
 
7770
8067
  // parallelize by rows
7771
8068
  const int nr = ne01;
@@ -7777,9 +8074,9 @@ static void ggml_compute_forward_dup_bytes(
7777
8074
 
7778
8075
  if (src0->type == dst->type &&
7779
8076
  ne00 == ne0 &&
7780
- nb00 == type_size && nb0 == type_size) {
8077
+ nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
7781
8078
  // copy by rows
7782
- const size_t rs = ne00 * type_size;
8079
+ const size_t rs = ne00*nb00;
7783
8080
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7784
8081
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7785
8082
  for (int64_t i01 = ir0; i01 < ir1; i01++) {
@@ -7794,41 +8091,366 @@ static void ggml_compute_forward_dup_bytes(
7794
8091
  }
7795
8092
 
7796
8093
  if (ggml_is_contiguous(dst)) {
7797
- size_t id = 0;
7798
- char * dst_ptr = (char *) dst->data;
7799
- const size_t rs = ne00 * type_size;
7800
-
7801
- if (nb00 == type_size) {
7802
- // src0 is contigous on first dimension, copy by rows
7803
- for (int64_t i03 = 0; i03 < ne03; i03++) {
7804
- for (int64_t i02 = 0; i02 < ne02; i02++) {
7805
- id += rs * ir0;
7806
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
7807
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
7808
- memcpy(dst_ptr + id, src0_ptr, rs);
7809
- id += rs;
7810
- }
7811
- id += rs * (ne01 - ir1);
7812
- }
7813
- }
7814
- } else {
7815
- //printf("%s: this is not optimal - fix me\n", __func__);
7816
-
7817
- for (int64_t i03 = 0; i03 < ne03; i03++) {
7818
- for (int64_t i02 = 0; i02 < ne02; i02++) {
7819
- id += rs * ir0;
7820
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
7821
- for (int64_t i00 = 0; i00 < ne00; i00++) {
7822
- const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
7823
- memcpy(dst_ptr + id, src0_ptr, type_size);
8094
+ // TODO: simplify
8095
+ if (nb00 == sizeof(float)) {
8096
+ if (dst->type == GGML_TYPE_F32) {
8097
+ size_t id = 0;
8098
+ const size_t rs = ne00 * nb00;
8099
+ char * dst_ptr = (char *) dst->data;
7824
8100
 
7825
- id += type_size;
8101
+ for (int i03 = 0; i03 < ne03; i03++) {
8102
+ for (int i02 = 0; i02 < ne02; i02++) {
8103
+ id += rs * ir0;
8104
+ for (int i01 = ir0; i01 < ir1; i01++) {
8105
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
8106
+ memcpy(dst_ptr + id, src0_ptr, rs);
8107
+ id += rs;
7826
8108
  }
8109
+ id += rs * (ne01 - ir1);
7827
8110
  }
7828
- id += rs * (ne01 - ir1);
7829
8111
  }
7830
- }
7831
- }
8112
+ } else if (type_traits[dst->type].from_float) {
8113
+ ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
8114
+
8115
+ size_t id = 0;
8116
+ size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
8117
+ char * dst_ptr = (char *) dst->data;
8118
+
8119
+ for (int i03 = 0; i03 < ne03; i03++) {
8120
+ for (int i02 = 0; i02 < ne02; i02++) {
8121
+ id += rs * ir0;
8122
+ for (int i01 = ir0; i01 < ir1; i01++) {
8123
+ const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
8124
+ quantize_row_q(src0_ptr, dst_ptr + id, ne00);
8125
+ id += rs;
8126
+ }
8127
+ id += rs * (ne01 - ir1);
8128
+ }
8129
+ }
8130
+ } else {
8131
+ GGML_ASSERT(false); // TODO: implement
8132
+ }
8133
+ } else {
8134
+ //printf("%s: this is not optimal - fix me\n", __func__);
8135
+
8136
+ if (dst->type == GGML_TYPE_F32) {
8137
+ size_t id = 0;
8138
+ float * dst_ptr = (float *) dst->data;
8139
+
8140
+ for (int i03 = 0; i03 < ne03; i03++) {
8141
+ for (int i02 = 0; i02 < ne02; i02++) {
8142
+ id += ne00 * ir0;
8143
+ for (int i01 = ir0; i01 < ir1; i01++) {
8144
+ for (int i00 = 0; i00 < ne00; i00++) {
8145
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8146
+
8147
+ dst_ptr[id] = *src0_ptr;
8148
+ id++;
8149
+ }
8150
+ }
8151
+ id += ne00 * (ne01 - ir1);
8152
+ }
8153
+ }
8154
+ } else if (dst->type == GGML_TYPE_F16) {
8155
+ size_t id = 0;
8156
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
8157
+
8158
+ for (int i03 = 0; i03 < ne03; i03++) {
8159
+ for (int i02 = 0; i02 < ne02; i02++) {
8160
+ id += ne00 * ir0;
8161
+ for (int i01 = ir0; i01 < ir1; i01++) {
8162
+ for (int i00 = 0; i00 < ne00; i00++) {
8163
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8164
+
8165
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
8166
+ id++;
8167
+ }
8168
+ }
8169
+ id += ne00 * (ne01 - ir1);
8170
+ }
8171
+ }
8172
+ } else if (dst->type == GGML_TYPE_BF16) {
8173
+ size_t id = 0;
8174
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
8175
+
8176
+ for (int i03 = 0; i03 < ne03; i03++) {
8177
+ for (int i02 = 0; i02 < ne02; i02++) {
8178
+ id += ne00 * ir0;
8179
+ for (int i01 = ir0; i01 < ir1; i01++) {
8180
+ for (int i00 = 0; i00 < ne00; i00++) {
8181
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8182
+
8183
+ dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
8184
+ id++;
8185
+ }
8186
+ }
8187
+ id += ne00 * (ne01 - ir1);
8188
+ }
8189
+ }
8190
+ } else {
8191
+ GGML_ASSERT(false); // TODO: implement
8192
+ }
8193
+ }
8194
+
8195
+ return;
8196
+ }
8197
+
8198
+ // dst counters
8199
+
8200
+ int64_t i10 = 0;
8201
+ int64_t i11 = 0;
8202
+ int64_t i12 = 0;
8203
+ int64_t i13 = 0;
8204
+
8205
+ if (dst->type == GGML_TYPE_F32) {
8206
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8207
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8208
+ i10 += ne00 * ir0;
8209
+ while (i10 >= ne0) {
8210
+ i10 -= ne0;
8211
+ if (++i11 == ne1) {
8212
+ i11 = 0;
8213
+ if (++i12 == ne2) {
8214
+ i12 = 0;
8215
+ if (++i13 == ne3) {
8216
+ i13 = 0;
8217
+ }
8218
+ }
8219
+ }
8220
+ }
8221
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8222
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8223
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8224
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8225
+
8226
+ memcpy(dst_ptr, src0_ptr, sizeof(float));
8227
+
8228
+ if (++i10 == ne0) {
8229
+ i10 = 0;
8230
+ if (++i11 == ne1) {
8231
+ i11 = 0;
8232
+ if (++i12 == ne2) {
8233
+ i12 = 0;
8234
+ if (++i13 == ne3) {
8235
+ i13 = 0;
8236
+ }
8237
+ }
8238
+ }
8239
+ }
8240
+ }
8241
+ }
8242
+ i10 += ne00 * (ne01 - ir1);
8243
+ while (i10 >= ne0) {
8244
+ i10 -= ne0;
8245
+ if (++i11 == ne1) {
8246
+ i11 = 0;
8247
+ if (++i12 == ne2) {
8248
+ i12 = 0;
8249
+ if (++i13 == ne3) {
8250
+ i13 = 0;
8251
+ }
8252
+ }
8253
+ }
8254
+ }
8255
+ }
8256
+ }
8257
+ } else if (dst->type == GGML_TYPE_F16) {
8258
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8259
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8260
+ i10 += ne00 * ir0;
8261
+ while (i10 >= ne0) {
8262
+ i10 -= ne0;
8263
+ if (++i11 == ne1) {
8264
+ i11 = 0;
8265
+ if (++i12 == ne2) {
8266
+ i12 = 0;
8267
+ if (++i13 == ne3) {
8268
+ i13 = 0;
8269
+ }
8270
+ }
8271
+ }
8272
+ }
8273
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8274
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8275
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8276
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8277
+
8278
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
8279
+
8280
+ if (++i10 == ne0) {
8281
+ i10 = 0;
8282
+ if (++i11 == ne1) {
8283
+ i11 = 0;
8284
+ if (++i12 == ne2) {
8285
+ i12 = 0;
8286
+ if (++i13 == ne3) {
8287
+ i13 = 0;
8288
+ }
8289
+ }
8290
+ }
8291
+ }
8292
+ }
8293
+ }
8294
+ i10 += ne00 * (ne01 - ir1);
8295
+ while (i10 >= ne0) {
8296
+ i10 -= ne0;
8297
+ if (++i11 == ne1) {
8298
+ i11 = 0;
8299
+ if (++i12 == ne2) {
8300
+ i12 = 0;
8301
+ if (++i13 == ne3) {
8302
+ i13 = 0;
8303
+ }
8304
+ }
8305
+ }
8306
+ }
8307
+ }
8308
+ }
8309
+ } else if (dst->type == GGML_TYPE_BF16) {
8310
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8311
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8312
+ i10 += ne00 * ir0;
8313
+ while (i10 >= ne0) {
8314
+ i10 -= ne0;
8315
+ if (++i11 == ne1) {
8316
+ i11 = 0;
8317
+ if (++i12 == ne2) {
8318
+ i12 = 0;
8319
+ if (++i13 == ne3) {
8320
+ i13 = 0;
8321
+ }
8322
+ }
8323
+ }
8324
+ }
8325
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8326
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8327
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8328
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8329
+
8330
+ *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
8331
+
8332
+ if (++i10 == ne0) {
8333
+ i10 = 0;
8334
+ if (++i11 == ne1) {
8335
+ i11 = 0;
8336
+ if (++i12 == ne2) {
8337
+ i12 = 0;
8338
+ if (++i13 == ne3) {
8339
+ i13 = 0;
8340
+ }
8341
+ }
8342
+ }
8343
+ }
8344
+ }
8345
+ }
8346
+ i10 += ne00 * (ne01 - ir1);
8347
+ while (i10 >= ne0) {
8348
+ i10 -= ne0;
8349
+ if (++i11 == ne1) {
8350
+ i11 = 0;
8351
+ if (++i12 == ne2) {
8352
+ i12 = 0;
8353
+ if (++i13 == ne3) {
8354
+ i13 = 0;
8355
+ }
8356
+ }
8357
+ }
8358
+ }
8359
+ }
8360
+ }
8361
+ } else {
8362
+ GGML_ASSERT(false); // TODO: implement
8363
+ }
8364
+ }
8365
+
8366
+ // A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
8367
+ static void ggml_compute_forward_dup_bytes(
8368
+ const struct ggml_compute_params * params,
8369
+ struct ggml_tensor * dst) {
8370
+
8371
+ const struct ggml_tensor * src0 = dst->src[0];
8372
+
8373
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
8374
+ GGML_ASSERT(src0->type == dst->type);
8375
+
8376
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
8377
+ return;
8378
+ }
8379
+
8380
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
8381
+ ggml_compute_forward_dup_same_cont(params, dst);
8382
+ return;
8383
+ }
8384
+
8385
+ GGML_TENSOR_UNARY_OP_LOCALS;
8386
+
8387
+ const size_t type_size = ggml_type_size(src0->type);
8388
+ const int ith = params->ith; // thread index
8389
+ const int nth = params->nth; // number of threads
8390
+
8391
+
8392
+ // parallelize by rows
8393
+ const int nr = ne01;
8394
+ // number of rows per thread
8395
+ const int dr = (nr + nth - 1) / nth;
8396
+ // row range for this thread
8397
+ const int ir0 = dr * ith;
8398
+ const int ir1 = MIN(ir0 + dr, nr);
8399
+
8400
+ if (src0->type == dst->type &&
8401
+ ne00 == ne0 &&
8402
+ nb00 == type_size && nb0 == type_size) {
8403
+ // copy by rows
8404
+ const size_t rs = ne00 * type_size;
8405
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8406
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8407
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8408
+ memcpy(
8409
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
8410
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
8411
+ rs);
8412
+ }
8413
+ }
8414
+ }
8415
+ return;
8416
+ }
8417
+
8418
+ if (ggml_is_contiguous(dst)) {
8419
+ size_t id = 0;
8420
+ char * dst_ptr = (char *) dst->data;
8421
+ const size_t rs = ne00 * type_size;
8422
+
8423
+ if (nb00 == type_size) {
8424
+ // src0 is contigous on first dimension, copy by rows
8425
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8426
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8427
+ id += rs * ir0;
8428
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8429
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
8430
+ memcpy(dst_ptr + id, src0_ptr, rs);
8431
+ id += rs;
8432
+ }
8433
+ id += rs * (ne01 - ir1);
8434
+ }
8435
+ }
8436
+ } else {
8437
+ //printf("%s: this is not optimal - fix me\n", __func__);
8438
+
8439
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8440
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8441
+ id += rs * ir0;
8442
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8443
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8444
+ const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
8445
+ memcpy(dst_ptr + id, src0_ptr, type_size);
8446
+
8447
+ id += type_size;
8448
+ }
8449
+ }
8450
+ id += rs * (ne01 - ir1);
8451
+ }
8452
+ }
8453
+ }
7832
8454
 
7833
8455
  return;
7834
8456
  }
@@ -7909,6 +8531,10 @@ static void ggml_compute_forward_dup(
7909
8531
  {
7910
8532
  ggml_compute_forward_dup_f16(params, dst);
7911
8533
  } break;
8534
+ case GGML_TYPE_BF16:
8535
+ {
8536
+ ggml_compute_forward_dup_bf16(params, dst);
8537
+ } break;
7912
8538
  case GGML_TYPE_F32:
7913
8539
  {
7914
8540
  ggml_compute_forward_dup_f32(params, dst);
@@ -8002,17 +8628,96 @@ static void ggml_compute_forward_add_f32(
8002
8628
  float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
8003
8629
  float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
8004
8630
 
8005
- for (int64_t i0 = 0; i0 < ne0; ++i0) {
8006
- const int64_t i10 = i0 % ne10;
8007
- float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
8631
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
8632
+ const int64_t i10 = i0 % ne10;
8633
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
8634
+
8635
+ dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
8636
+ }
8637
+ }
8638
+ }
8639
+ }
8640
+
8641
+ static void ggml_compute_forward_add_f16_f32(
8642
+ const struct ggml_compute_params * params,
8643
+ struct ggml_tensor * dst) {
8644
+
8645
+ const struct ggml_tensor * src0 = dst->src[0];
8646
+ const struct ggml_tensor * src1 = dst->src[1];
8647
+
8648
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
8649
+
8650
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
8651
+ return;
8652
+ }
8653
+
8654
+ const int ith = params->ith;
8655
+ const int nth = params->nth;
8656
+
8657
+ const int nr = ggml_nrows(src0);
8658
+
8659
+ GGML_TENSOR_BINARY_OP_LOCALS
8660
+
8661
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
8662
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
8663
+
8664
+ if (dst->type == GGML_TYPE_F32) {
8665
+ GGML_ASSERT( nb0 == sizeof(float));
8666
+ }
8667
+ else {
8668
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
8669
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
8670
+ }
8671
+
8672
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
8673
+
8674
+ // rows per thread
8675
+ const int dr = (nr + nth - 1)/nth;
8676
+
8677
+ // row range for this thread
8678
+ const int ir0 = dr*ith;
8679
+ const int ir1 = MIN(ir0 + dr, nr);
8680
+
8681
+ if (nb10 == sizeof(float)) {
8682
+ if (dst->type == GGML_TYPE_F16) {
8683
+ for (int ir = ir0; ir < ir1; ++ir) {
8684
+ // src0, src1 and dst are same shape => same indices
8685
+ const int i3 = ir/(ne2*ne1);
8686
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
8687
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8688
+
8689
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8690
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8691
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8692
+
8693
+ for (int i = 0; i < ne0; i++) {
8694
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
8695
+ }
8696
+ }
8697
+ } else {
8698
+ for (int ir = ir0; ir < ir1; ++ir) {
8699
+ // src0, src1 and dst are same shape => same indices
8700
+ const int i3 = ir/(ne2*ne1);
8701
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
8702
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8703
+
8704
+ float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8705
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8706
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8008
8707
 
8009
- dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
8708
+ for (int i = 0; i < ne0; i++) {
8709
+ dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
8710
+ }
8010
8711
  }
8011
8712
  }
8012
8713
  }
8714
+ else {
8715
+ // src1 is not contiguous
8716
+ GGML_ASSERT(false);
8717
+ }
8013
8718
  }
8014
8719
 
8015
- static void ggml_compute_forward_add_f16_f32(
8720
+ static void ggml_compute_forward_add_bf16_f32(
8016
8721
  const struct ggml_compute_params * params,
8017
8722
  struct ggml_tensor * dst) {
8018
8723
 
@@ -8032,18 +8737,18 @@ static void ggml_compute_forward_add_f16_f32(
8032
8737
 
8033
8738
  GGML_TENSOR_BINARY_OP_LOCALS
8034
8739
 
8035
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
8740
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
8036
8741
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
8037
8742
 
8038
8743
  if (dst->type == GGML_TYPE_F32) {
8039
8744
  GGML_ASSERT( nb0 == sizeof(float));
8040
8745
  }
8041
8746
  else {
8042
- GGML_ASSERT(dst->type == GGML_TYPE_F16);
8043
- GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
8747
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
8748
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
8044
8749
  }
8045
8750
 
8046
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
8751
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
8047
8752
 
8048
8753
  // rows per thread
8049
8754
  const int dr = (nr + nth - 1)/nth;
@@ -8053,19 +8758,19 @@ static void ggml_compute_forward_add_f16_f32(
8053
8758
  const int ir1 = MIN(ir0 + dr, nr);
8054
8759
 
8055
8760
  if (nb10 == sizeof(float)) {
8056
- if (dst->type == GGML_TYPE_F16) {
8761
+ if (dst->type == GGML_TYPE_BF16) {
8057
8762
  for (int ir = ir0; ir < ir1; ++ir) {
8058
8763
  // src0, src1 and dst are same shape => same indices
8059
8764
  const int i3 = ir/(ne2*ne1);
8060
8765
  const int i2 = (ir - i3*ne2*ne1)/ne1;
8061
8766
  const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8062
8767
 
8063
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8064
- ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8768
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8769
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8065
8770
  float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8066
8771
 
8067
8772
  for (int i = 0; i < ne0; i++) {
8068
- dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
8773
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
8069
8774
  }
8070
8775
  }
8071
8776
  } else {
@@ -8076,11 +8781,11 @@ static void ggml_compute_forward_add_f16_f32(
8076
8781
  const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8077
8782
 
8078
8783
  float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8079
- ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8784
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8080
8785
  float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8081
8786
 
8082
8787
  for (int i = 0; i < ne0; i++) {
8083
- dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
8788
+ dst_ptr[i] = GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
8084
8789
  }
8085
8790
  }
8086
8791
  }
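
Note: ggml_compute_forward_add_bf16_f32 mirrors the existing f16_f32 kernel: widen each src0 element to f32, add the f32 src1 element, then either store the f32 result directly or narrow it back to bf16 depending on dst->type. A sketch of one row for the bf16-destination case (illustrative only; the real kernel above also handles the f32 destination and the row indexing):

    static void add_row_bf16_f32(int64_t n, ggml_bf16_t * dst,
                                 const ggml_bf16_t * src0, const float * src1) {
        for (int64_t i = 0; i < n; ++i) {
            dst[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0[i]) + src1[i]);
        }
    }
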
@@ -8147,6 +8852,62 @@ static void ggml_compute_forward_add_f16_f16(
8147
8852
  }
8148
8853
  }
8149
8854
 
8855
+ static void ggml_compute_forward_add_bf16_bf16(
8856
+ const struct ggml_compute_params * params,
8857
+ struct ggml_tensor * dst) {
8858
+
8859
+ const struct ggml_tensor * src0 = dst->src[0];
8860
+ const struct ggml_tensor * src1 = dst->src[1];
8861
+
8862
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
8863
+
8864
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
8865
+ return;
8866
+ }
8867
+
8868
+ const int ith = params->ith;
8869
+ const int nth = params->nth;
8870
+
8871
+ const int nr = ggml_nrows(src0);
8872
+
8873
+ GGML_TENSOR_BINARY_OP_LOCALS
8874
+
8875
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
8876
+ GGML_ASSERT(src1->type == GGML_TYPE_BF16);
8877
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
8878
+
8879
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
8880
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
8881
+
8882
+ // rows per thread
8883
+ const int dr = (nr + nth - 1)/nth;
8884
+
8885
+ // row range for this thread
8886
+ const int ir0 = dr*ith;
8887
+ const int ir1 = MIN(ir0 + dr, nr);
8888
+
8889
+ if (nb10 == sizeof(ggml_bf16_t)) {
8890
+ for (int ir = ir0; ir < ir1; ++ir) {
8891
+ // src0, src1 and dst are same shape => same indices
8892
+ const int i3 = ir/(ne2*ne1);
8893
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
8894
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8895
+
8896
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8897
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8898
+ ggml_bf16_t * src1_ptr = (ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8899
+
8900
+ for (int i = 0; i < ne0; i++) {
8901
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + GGML_BF16_TO_FP32(src1_ptr[i]));
8902
+ }
8903
+ }
8904
+ }
8905
+ else {
8906
+ // src1 is not contiguous
8907
+ GGML_ASSERT(false);
8908
+ }
8909
+ }
8910
+
8150
8911
  static void ggml_compute_forward_add_q_f32(
8151
8912
  const struct ggml_compute_params * params,
8152
8913
  struct ggml_tensor * dst) {
@@ -8256,6 +9017,18 @@ static void ggml_compute_forward_add(
8256
9017
  GGML_ASSERT(false);
8257
9018
  }
8258
9019
  } break;
9020
+ case GGML_TYPE_BF16:
9021
+ {
9022
+ if (src1->type == GGML_TYPE_BF16) {
9023
+ ggml_compute_forward_add_bf16_bf16(params, dst);
9024
+ }
9025
+ else if (src1->type == GGML_TYPE_F32) {
9026
+ ggml_compute_forward_add_bf16_f32(params, dst);
9027
+ }
9028
+ else {
9029
+ GGML_ASSERT(false);
9030
+ }
9031
+ } break;
8259
9032
  case GGML_TYPE_Q4_0:
8260
9033
  case GGML_TYPE_Q4_1:
8261
9034
  case GGML_TYPE_Q5_0:
@@ -8514,6 +9287,110 @@ static void ggml_compute_forward_add1_q_f32(
8514
9287
  }
8515
9288
  }
8516
9289
 
9290
+ static void ggml_compute_forward_add1_bf16_f32(
9291
+ const struct ggml_compute_params * params,
9292
+ struct ggml_tensor * dst) {
9293
+
9294
+ const struct ggml_tensor * src0 = dst->src[0];
9295
+ const struct ggml_tensor * src1 = dst->src[1];
9296
+
9297
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
9298
+ GGML_ASSERT(ggml_is_scalar(src1));
9299
+
9300
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
9301
+ return;
9302
+ }
9303
+
9304
+ // scalar to add
9305
+ const float v = *(float *) src1->data;
9306
+
9307
+ const int ith = params->ith;
9308
+ const int nth = params->nth;
9309
+
9310
+ const int nr = ggml_nrows(src0);
9311
+
9312
+ GGML_TENSOR_UNARY_OP_LOCALS
9313
+
9314
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
9315
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
9316
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
9317
+
9318
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
9319
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
9320
+
9321
+ // rows per thread
9322
+ const int dr = (nr + nth - 1)/nth;
9323
+
9324
+ // row range for this thread
9325
+ const int ir0 = dr*ith;
9326
+ const int ir1 = MIN(ir0 + dr, nr);
9327
+
9328
+ for (int ir = ir0; ir < ir1; ++ir) {
9329
+ // src0 and dst are same shape => same indices
9330
+ const int i3 = ir/(ne2*ne1);
9331
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9332
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9333
+
9334
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
9335
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9336
+ for (int i = 0; i < ne0; i++) {
9337
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
9338
+ }
9339
+ }
9340
+ }
9341
+
9342
+ static void ggml_compute_forward_add1_bf16_bf16(
9343
+ const struct ggml_compute_params * params,
9344
+ struct ggml_tensor * dst) {
9345
+
9346
+ const struct ggml_tensor * src0 = dst->src[0];
9347
+ const struct ggml_tensor * src1 = dst->src[1];
9348
+
9349
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
9350
+ GGML_ASSERT(ggml_is_scalar(src1));
9351
+
9352
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
9353
+ return;
9354
+ }
9355
+
9356
+ // scalar to add
9357
+ const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
9358
+
9359
+ const int ith = params->ith;
9360
+ const int nth = params->nth;
9361
+
9362
+ const int nr = ggml_nrows(src0);
9363
+
9364
+ GGML_TENSOR_UNARY_OP_LOCALS
9365
+
9366
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
9367
+ GGML_ASSERT(src1->type == GGML_TYPE_BF16);
9368
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
9369
+
9370
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
9371
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
9372
+
9373
+ // rows per thread
9374
+ const int dr = (nr + nth - 1)/nth;
9375
+
9376
+ // row range for this thread
9377
+ const int ir0 = dr*ith;
9378
+ const int ir1 = MIN(ir0 + dr, nr);
9379
+
9380
+ for (int ir = ir0; ir < ir1; ++ir) {
9381
+ // src0 and dst are same shape => same indices
9382
+ const int i3 = ir/(ne2*ne1);
9383
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9384
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9385
+
9386
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
9387
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9388
+ for (int i = 0; i < ne0; i++) {
9389
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
9390
+ }
9391
+ }
9392
+ }
9393
+
8517
9394
  static void ggml_compute_forward_add1(
8518
9395
  const struct ggml_compute_params * params,
8519
9396
  struct ggml_tensor * dst) {
@@ -8538,6 +9415,18 @@ static void ggml_compute_forward_add1(
8538
9415
  GGML_ASSERT(false);
8539
9416
  }
8540
9417
  } break;
9418
+ case GGML_TYPE_BF16:
9419
+ {
9420
+ if (src1->type == GGML_TYPE_BF16) {
9421
+ ggml_compute_forward_add1_bf16_bf16(params, dst);
9422
+ }
9423
+ else if (src1->type == GGML_TYPE_F32) {
9424
+ ggml_compute_forward_add1_bf16_f32(params, dst);
9425
+ }
9426
+ else {
9427
+ GGML_ASSERT(false);
9428
+ }
9429
+ } break;
8541
9430
  case GGML_TYPE_Q4_0:
8542
9431
  case GGML_TYPE_Q4_1:
8543
9432
  case GGML_TYPE_Q5_0:
@@ -8666,6 +9555,7 @@ static void ggml_compute_forward_acc(
8666
9555
  ggml_compute_forward_acc_f32(params, dst);
8667
9556
  } break;
8668
9557
  case GGML_TYPE_F16:
9558
+ case GGML_TYPE_BF16:
8669
9559
  case GGML_TYPE_Q4_0:
8670
9560
  case GGML_TYPE_Q4_1:
8671
9561
  case GGML_TYPE_Q5_0:
@@ -9187,6 +10077,40 @@ static void ggml_compute_forward_sum_f16(
9187
10077
  ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
9188
10078
  }
9189
10079
 
10080
+ static void ggml_compute_forward_sum_bf16(
10081
+ const struct ggml_compute_params * params,
10082
+ struct ggml_tensor * dst) {
10083
+
10084
+ const struct ggml_tensor * src0 = dst->src[0];
10085
+
10086
+ assert(params->ith == 0);
10087
+ assert(ggml_is_scalar(dst));
10088
+
10089
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
10090
+ return;
10091
+ }
10092
+
10093
+ assert(src0->nb[0] == sizeof(ggml_bf16_t));
10094
+
10095
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
10096
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
10097
+
10098
+ float sum = 0;
10099
+ float row_sum = 0;
10100
+
10101
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
10102
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
10103
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
10104
+ ggml_vec_sum_bf16_ggf(ne00,
10105
+ &row_sum,
10106
+ (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
10107
+ sum += row_sum;
10108
+ }
10109
+ }
10110
+ }
10111
+ ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
10112
+ }
10113
+
9190
10114
  static void ggml_compute_forward_sum(
9191
10115
  const struct ggml_compute_params * params,
9192
10116
  struct ggml_tensor * dst) {
@@ -9202,6 +10126,10 @@ static void ggml_compute_forward_sum(
9202
10126
  {
9203
10127
  ggml_compute_forward_sum_f16(params, dst);
9204
10128
  } break;
10129
+ case GGML_TYPE_BF16:
10130
+ {
10131
+ ggml_compute_forward_sum_bf16(params, dst);
10132
+ } break;
9205
10133
  default:
9206
10134
  {
9207
10135
  GGML_ASSERT(false);
@@ -9476,6 +10404,7 @@ static void ggml_compute_forward_repeat(
9476
10404
 
9477
10405
  switch (src0->type) {
9478
10406
  case GGML_TYPE_F16:
10407
+ case GGML_TYPE_BF16:
9479
10408
  case GGML_TYPE_I16:
9480
10409
  {
9481
10410
  ggml_compute_forward_repeat_f16(params, dst);
@@ -11793,6 +12722,7 @@ static void ggml_compute_forward_set(
11793
12722
  ggml_compute_forward_set_f32(params, dst);
11794
12723
  } break;
11795
12724
  case GGML_TYPE_F16:
12725
+ case GGML_TYPE_BF16:
11796
12726
  case GGML_TYPE_Q4_0:
11797
12727
  case GGML_TYPE_Q4_1:
11798
12728
  case GGML_TYPE_Q5_0:
@@ -11967,6 +12897,49 @@ static void ggml_compute_forward_get_rows_f16(
11967
12897
  }
11968
12898
  }
11969
12899
 
12900
+ static void ggml_compute_forward_get_rows_bf16(
12901
+ const struct ggml_compute_params * params,
12902
+ struct ggml_tensor * dst) {
12903
+
12904
+ const struct ggml_tensor * src0 = dst->src[0];
12905
+ const struct ggml_tensor * src1 = dst->src[1];
12906
+
12907
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
12908
+ return;
12909
+ }
12910
+
12911
+ GGML_TENSOR_BINARY_OP_LOCALS
12912
+
12913
+ const int64_t nc = ne00;
12914
+ const int64_t nr = ggml_nelements(src1);
12915
+
12916
+ assert(ne0 == nc);
12917
+ assert(ne02 == ne11);
12918
+ assert(nb00 == sizeof(ggml_bf16_t));
12919
+ assert(ggml_nrows(dst) == nr);
12920
+
12921
+ const int ith = params->ith;
12922
+ const int nth = params->nth;
12923
+
12924
+ // rows per thread
12925
+ const int dr = (nr + nth - 1)/nth;
12926
+
12927
+ // row range for this thread
12928
+ const int ir0 = dr*ith;
12929
+ const int ir1 = MIN(ir0 + dr, nr);
12930
+
12931
+ for (int64_t i = ir0; i < ir1; ++i) {
12932
+ const int64_t i12 = i/(ne11*ne10);
12933
+ const int64_t i11 = (i - i12*ne11*ne10)/ne10;
12934
+ const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
12935
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
12936
+
12937
+ ggml_bf16_to_fp32_row(
12938
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
12939
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
12940
+ }
12941
+ }
12942
+
11970
12943
  static void ggml_compute_forward_get_rows_f32(
11971
12944
  const struct ggml_compute_params * params,
11972
12945
  struct ggml_tensor * dst) {
@@ -12044,6 +13017,10 @@ static void ggml_compute_forward_get_rows(
12044
13017
  {
12045
13018
  ggml_compute_forward_get_rows_f16(params, dst);
12046
13019
  } break;
13020
+ case GGML_TYPE_BF16:
13021
+ {
13022
+ ggml_compute_forward_get_rows_bf16(params, dst);
13023
+ } break;
12047
13024
  case GGML_TYPE_F32:
12048
13025
  case GGML_TYPE_I32:
12049
13026
  {
@@ -12739,6 +13716,7 @@ static void ggml_compute_forward_alibi(
12739
13716
  {
12740
13717
  ggml_compute_forward_alibi_f32(params, dst);
12741
13718
  } break;
13719
+ case GGML_TYPE_BF16:
12742
13720
  case GGML_TYPE_Q4_0:
12743
13721
  case GGML_TYPE_Q4_1:
12744
13722
  case GGML_TYPE_Q5_0:
@@ -12828,6 +13806,7 @@ static void ggml_compute_forward_clamp(
12828
13806
  ggml_compute_forward_clamp_f32(params, dst);
12829
13807
  } break;
12830
13808
  case GGML_TYPE_F16:
13809
+ case GGML_TYPE_BF16:
12831
13810
  case GGML_TYPE_Q4_0:
12832
13811
  case GGML_TYPE_Q4_1:
12833
13812
  case GGML_TYPE_Q5_0:
@@ -15921,6 +16900,7 @@ static void ggml_compute_forward_get_rel_pos(
15921
16900
 
15922
16901
  switch (src0->type) {
15923
16902
  case GGML_TYPE_F16:
16903
+ case GGML_TYPE_BF16:
15924
16904
  {
15925
16905
  ggml_compute_forward_get_rel_pos_f16(params, dst);
15926
16906
  } break;
@@ -18785,7 +19765,10 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18785
19765
  case GGML_OP_CPY:
18786
19766
  case GGML_OP_DUP:
18787
19767
  {
18788
- if (ggml_is_quantized(node->type)) {
19768
+ if (ggml_is_quantized(node->type) ||
19769
+ // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
19770
+ (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
19771
+ (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
18789
19772
  cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
18790
19773
  }
18791
19774
  } break;
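
Note: the extra CPY/DUP condition reserves the same per-task scratch that quantized copies already use, because F16 <-> BF16 copies go through an intermediate f32 row. A sketch of the planner logic as a single helper (hypothetical name; the condition and size mirror the hunk above):

    static size_t cpy_f32_work_size(const struct ggml_tensor * node, int n_tasks) {
        const bool needs_f32 =
            ggml_is_quantized(node->type) ||
            (node->src[0]->type == GGML_TYPE_F16  && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
            (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16);
        return needs_f32 ? ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks : 0;
    }
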
@@ -18864,7 +19847,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18864
19847
  const int64_t ne10 = node->src[1]->ne[0]; // L
18865
19848
  const int64_t ne11 = node->src[1]->ne[1]; // Cin
18866
19849
 
18867
- if (node->src[0]->type == GGML_TYPE_F16 &&
19850
+ if ((node->src[0]->type == GGML_TYPE_F16 ||
19851
+ node->src[0]->type == GGML_TYPE_BF16) &&
18868
19852
  node->src[1]->type == GGML_TYPE_F32) {
18869
19853
  cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
18870
19854
  cur += sizeof(ggml_fp16_t)*ne10*ne11;
@@ -18900,6 +19884,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18900
19884
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18901
19885
  cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
18902
19886
  cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19887
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19888
+ cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19889
+ cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
18903
19890
  }
18904
19891
  } break;
18905
19892
  case GGML_OP_FLASH_ATTN_EXT:
@@ -18916,6 +19903,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18916
19903
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18917
19904
  cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
18918
19905
  cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19906
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19907
+ cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19908
+ cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
18919
19909
  }
18920
19910
  } break;
18921
19911
  case GGML_OP_FLASH_ATTN_BACK:
@@ -18929,6 +19919,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18929
19919
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18930
19920
  cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
18931
19921
  cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
19922
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19923
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
19924
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
18932
19925
  }
18933
19926
  } break;
18934
19927
 
@@ -19705,7 +20698,9 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
19705
20698
  if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
19706
20699
  fprintf(fp, "%d", ggml_get_i32_1d(node, j));
19707
20700
  }
19708
- else if (node->type == GGML_TYPE_F32 || node->type == GGML_TYPE_F16) {
20701
+ else if (node->type == GGML_TYPE_F32 ||
20702
+ node->type == GGML_TYPE_F16 ||
20703
+ node->type == GGML_TYPE_BF16) {
19709
20704
  fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j));
19710
20705
  }
19711
20706
  else {
@@ -20763,6 +21758,12 @@ size_t ggml_quantize_chunk(
20763
21758
  ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
20764
21759
  result = n * elemsize;
20765
21760
  } break;
21761
+ case GGML_TYPE_BF16:
21762
+ {
21763
+ size_t elemsize = sizeof(ggml_bf16_t);
21764
+ ggml_fp32_to_bf16_row(src + start, (ggml_bf16_t *)dst + start, n);
21765
+ result = n * elemsize;
21766
+ } break;
20766
21767
  case GGML_TYPE_F32:
20767
21768
  {
20768
21769
  size_t elemsize = sizeof(float);
@@ -21139,7 +22140,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
21139
22140
  }
21140
22141
 
21141
22142
  // read the tensor infos
21142
- {
22143
+ if (ctx->header.n_tensors > 0) {
21143
22144
  ctx->infos = GGML_CALLOC(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));
21144
22145
 
21145
22146
  for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {