llama_cpp 0.15.0 → 0.15.1: diff of the bundled ggml.c (mainly adding CPU support for the GGML_TYPE_BF16 tensor type)

@@ -322,7 +322,7 @@ static ggml_fp16_t ggml_table_exp_f16[1 << 16];
322
322
  // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
323
323
  float ggml_table_f32_f16[1 << 16];
324
324
 
325
- const char * ggml_status_to_string(enum ggml_status status) {
325
+ GGML_CALL const char * ggml_status_to_string(enum ggml_status status) {
326
326
  switch (status) {
327
327
  case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
328
328
  case GGML_STATUS_FAILED: return "GGML status: error (operation failed)";
@@ -333,16 +333,26 @@ const char * ggml_status_to_string(enum ggml_status status) {
333
333
  return "GGML status: unknown";
334
334
  }
335
335
 
336
- // note: do not use these inside ggml.c
337
- // these are meant to be used via the ggml.h API
338
336
  float ggml_fp16_to_fp32(ggml_fp16_t x) {
337
+ #define ggml_fp16_to_fp32 do_not_use__ggml_fp16_to_fp32__in_ggml
339
338
  return GGML_FP16_TO_FP32(x);
340
339
  }
341
340
 
342
341
  ggml_fp16_t ggml_fp32_to_fp16(float x) {
342
+ #define ggml_fp32_to_fp16 do_not_use__ggml_fp32_to_fp16__in_ggml
343
343
  return GGML_FP32_TO_FP16(x);
344
344
  }
345
345
 
346
+ float ggml_bf16_to_fp32(ggml_bf16_t x) {
347
+ #define ggml_bf16_to_fp32 do_not_use__ggml_bf16_to_fp32__in_ggml
348
+ return GGML_BF16_TO_FP32(x); // it just left shifts
349
+ }
350
+
351
+ ggml_bf16_t ggml_fp32_to_bf16(float x) {
352
+ #define ggml_fp32_to_bf16 do_not_use__ggml_fp32_to_bf16__in_ggml
353
+ return GGML_FP32_TO_BF16(x);
354
+ }
355
+
346
356
  void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
347
357
  for (int64_t i = 0; i < n; i++) {
348
358
  y[i] = GGML_FP16_TO_FP32(x[i]);
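Note: bfloat16 is simply the upper 16 bits of an IEEE-754 binary32 value, which is why ggml_bf16_to_fp32 above can "just left shift", and the do_not_use__* defines poison the public names inside ggml.c so that the file keeps using the inline GGML_BF16_* macros instead (the same intent as the removed "do not use these inside ggml.c" comment). A minimal standalone sketch of the bit manipulation; the type and helper names here are illustrative, and the conversion down is shown as plain truncation, whereas the GGML_FP32_TO_BF16 path in this release also rounds to nearest even and special-cases NaN:

    #include <stdint.h>
    #include <string.h>

    typedef struct { uint16_t bits; } bf16_t;    // struct wrapper, like ggml_bf16_t

    static float bf16_to_f32(bf16_t h) {
        uint32_t u = (uint32_t) h.bits << 16;    // low mantissa bits become zero
        float f;
        memcpy(&f, &u, sizeof f);
        return f;
    }

    static bf16_t f32_to_bf16(float f) {         // truncating sketch only
        uint32_t u;
        memcpy(&u, &f, sizeof u);
        return (bf16_t) { (uint16_t) (u >> 16) };
    }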
@@ -368,6 +378,49 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
368
378
  }
369
379
  }
370
380
 
381
+ void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
382
+ int64_t i = 0;
383
+ #if defined(__AVX512F__)
384
+ for (; i + 16 <= n; i += 16) {
385
+ _mm512_storeu_ps(y + i,
386
+ _mm512_castsi512_ps(
387
+ _mm512_slli_epi32(
388
+ _mm512_cvtepu16_epi32(
389
+ _mm256_loadu_si256(
390
+ (const __m256i *)(x + i))),
391
+ 16)));
392
+ }
393
+ #elif defined(__AVX2__)
394
+ for (; i + 8 <= n; i += 8) {
395
+ _mm256_storeu_ps(y + i,
396
+ _mm256_castsi256_ps(
397
+ _mm256_slli_epi32(
398
+ _mm256_cvtepu16_epi32(
399
+ _mm_loadu_si128(
400
+ (const __m128i *)(x + i))),
401
+ 16)));
402
+ }
403
+ #endif
404
+ for (; i < n; i++) {
405
+ y[i] = GGML_BF16_TO_FP32(x[i]);
406
+ }
407
+ }
408
+
409
+ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
410
+ int i = 0;
411
+ #if defined(__AVX512BF16__)
412
+ for (; i + 32 <= n; i += 32) {
413
+ _mm512_storeu_ps(
414
+ (__m512 *)(y + i),
415
+ (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
416
+ _mm512_loadu_ps(x + i)));
417
+ }
418
+ #endif
419
+ for (; i < n; i++) {
420
+ y[i] = GGML_FP32_TO_BF16(x[i]);
421
+ }
422
+ }
423
+
371
424
  bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
372
425
  return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
373
426
  }
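Note: the new row converters follow ggml's usual SIMD pattern: process as many elements per iteration as the widest available extension allows (16 per loop with AVX-512F, 8 with AVX2, and 32 on the f32→bf16 side with AVX-512 BF16), then fall through to a scalar tail so the result is the same for any n and on machines without those extensions. A scalar-only reference for the same contract, reusing the includes and bit layout from the previous sketch (helper name is illustrative):

    /* y[0..n) = f32 widening of the bf16 row x[0..n); matches what the
       SIMD branches above compute, one element at a time. */
    static void bf16_to_f32_row_ref(const uint16_t * x, float * y, int64_t n) {
        for (int64_t i = 0; i < n; i++) {
            uint32_t u = (uint32_t) x[i] << 16;
            memcpy(&y[i], &u, sizeof(float));
        }
    }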
@@ -503,6 +556,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
503
556
 
504
557
  static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
505
558
  static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
559
+ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
506
560
 
507
561
  static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
508
562
  [GGML_TYPE_I8] = {
@@ -845,6 +899,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
845
899
  .type_size = sizeof(block_q8_K),
846
900
  .is_quantized = true,
847
901
  .from_float = quantize_row_q8_K,
902
+ },
903
+ [GGML_TYPE_BF16] = {
904
+ .type_name = "bf16",
905
+ .blck_size = 1,
906
+ .type_size = sizeof(ggml_bf16_t),
907
+ .is_quantized = false,
908
+ .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row,
909
+ .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row,
910
+ .from_float_reference = (ggml_from_float_t) ggml_fp32_to_bf16_row,
911
+ .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
912
+ .vec_dot_type = GGML_TYPE_BF16,
913
+ .nrows = 1,
848
914
  }
849
915
  };
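Note: registering a GGML_TYPE_BF16 entry in type_traits is what lets the rest of ggml.c handle bf16 generically: the dequantize, copy and mat-mul paths look up .to_float, .from_float and .vec_dot by type instead of switching on it. An illustrative consumer (the function name is made up; the types, table and macro it uses exist in ggml.c, assuming its headers are in scope):

    /* Sketch: generic row dequantization through the trait table. */
    static void dequantize_row_any(enum ggml_type type, const void * src, float * dst, int64_t n) {
        ggml_to_float_t const to_float = type_traits[type].to_float;
        GGML_ASSERT(to_float != NULL);    // not every type provides one
        to_float(src, dst, n);
    }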
850
916
 
@@ -1480,6 +1546,8 @@ inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) {
1480
1546
 
1481
1547
  inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1482
1548
 
1549
+ inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1550
+
1483
1551
  inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
1484
1552
  inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
1485
1553
  inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
@@ -1498,7 +1566,7 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float *
1498
1566
  UNUSED(by);
1499
1567
  UNUSED(bs);
1500
1568
 
1501
- #ifdef GGML_SIMD
1569
+ #if defined(GGML_SIMD)
1502
1570
  float sumf = 0.0f;
1503
1571
  const int np = (n & ~(GGML_F32_STEP - 1));
1504
1572
 
@@ -1534,6 +1602,70 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float *
1534
1602
  *s = sumf;
1535
1603
  }
1536
1604
 
1605
+ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) {
1606
+ assert(nrc == 1);
1607
+ UNUSED(nrc);
1608
+ UNUSED(bx);
1609
+ UNUSED(by);
1610
+ UNUSED(bs);
1611
+ int i = 0;
1612
+ ggml_float sumf = 0;
1613
+
1614
+ #if defined(__AVX512BF16__)
1615
+ __m512 c1 = _mm512_setzero_ps();
1616
+ __m512 c2 = _mm512_setzero_ps();
1617
+ for (; i + 64 <= n; i += 64) {
1618
+ c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
1619
+ (__m512bh)_mm512_loadu_ps((const float *)(y + i)));
1620
+ c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
1621
+ (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
1622
+ }
1623
+ sumf += (ggml_float)_mm512_reduce_add_ps(c1);
1624
+ sumf += (ggml_float)_mm512_reduce_add_ps(c2);
1625
+
1626
+ #elif defined(__AVX512F__)
1627
+ #define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
1628
+ __m512 c1 = _mm512_setzero_ps();
1629
+ __m512 c2 = _mm512_setzero_ps();
1630
+ for (; i + 32 <= n; i += 32) {
1631
+ c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
1632
+ c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
1633
+ }
1634
+ sumf += (ggml_float)_mm512_reduce_add_ps(c1);
1635
+ sumf += (ggml_float)_mm512_reduce_add_ps(c2);
1636
+
1637
+ #undef LOAD
1638
+ #elif defined(__AVX2__)
1639
+ #define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
1640
+ __m256 c1 = _mm256_setzero_ps();
1641
+ __m256 c2 = _mm256_setzero_ps();
1642
+ __m256 c3 = _mm256_setzero_ps();
1643
+ __m256 c4 = _mm256_setzero_ps();
1644
+ for (; i + 32 <= n; i += 32) {
1645
+ c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
1646
+ c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
1647
+ c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
1648
+ c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
1649
+ }
1650
+ __m128 g;
1651
+ c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
1652
+ _mm256_add_ps(c2, c4));
1653
+ g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
1654
+ _mm256_castps256_ps128(c1));
1655
+ g = _mm_add_ps(g, _mm_movehl_ps(g, g));
1656
+ g = _mm_add_ss(g, _mm_movehdup_ps(g));
1657
+ sumf += (ggml_float)_mm_cvtss_f32(g);
1658
+
1659
+ #undef LOAD
1660
+ #endif
1661
+
1662
+ for (; i < n; ++i) {
1663
+ sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
1664
+ GGML_BF16_TO_FP32(y[i]));
1665
+ }
1666
+ *s = sumf;
1667
+ }
1668
+
1537
1669
  static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
1538
1670
  assert(nrc == 1);
1539
1671
  UNUSED(nrc);
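Note: ggml_vec_dot_bf16 keeps two (AVX-512) or four (AVX2) independent accumulators so back-to-back accumulations are not serialized on a single register, reduces them at the end, and finishes any remainder with the scalar loop. On AVX-512 BF16 hardware the products come from the native _mm512_dpbf16_ps instruction; the other branches widen the inputs to f32 first. A scalar reference for what every branch must agree with (helper name is illustrative; accumulation is in double, matching ggml_float here):

    static float vec_dot_bf16_ref(int n, const uint16_t * x, const uint16_t * y) {
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            uint32_t ux = (uint32_t) x[i] << 16;
            uint32_t uy = (uint32_t) y[i] << 16;
            float fx, fy;
            memcpy(&fx, &ux, sizeof fx);
            memcpy(&fy, &uy, sizeof fy);
            sum += (double) (fx * fy);    // multiply in f32, accumulate wide
        }
        return (float) sum;
    }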
@@ -1967,6 +2099,14 @@ inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_
1967
2099
  *s = sum;
1968
2100
  }
1969
2101
 
2102
+ inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {
2103
+ float sum = 0.0f;
2104
+ for (int i = 0; i < n; ++i) {
2105
+ sum += GGML_BF16_TO_FP32(x[i]);
2106
+ }
2107
+ *s = sum;
2108
+ }
2109
+
1970
2110
  inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
1971
2111
  #ifndef GGML_USE_ACCELERATE
1972
2112
  float max = -INFINITY;
@@ -2377,7 +2517,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
2377
2517
  // figure out which node we're on
2378
2518
  uint current_cpu;
2379
2519
  int getcpu_ret = 0;
2380
- #if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28)
2520
+ #if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__)
2381
2521
  getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node);
2382
2522
  #else
2383
2523
  // old glibc doesn't have a wrapper for this call. Fall back on direct syscall
@@ -2588,6 +2728,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
2588
2728
  switch (ftype) {
2589
2729
  case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break;
2590
2730
  case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break;
2731
+ case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break;
2591
2732
  case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break;
2592
2733
  case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break;
2593
2734
  case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
@@ -2729,15 +2870,16 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
2729
2870
  {
2730
2871
  const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
2731
2872
 
2732
- ggml_fp16_t ii;
2733
2873
  for (int i = 0; i < (1 << 16); ++i) {
2734
- uint16_t ui = i;
2735
- memcpy(&ii, &ui, sizeof(ii));
2736
- const float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
2874
+ union {
2875
+ uint16_t u16;
2876
+ ggml_fp16_t fp16;
2877
+ } u = {i};
2878
+ float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
2737
2879
  ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
2738
2880
  ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
2739
2881
  ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
2740
- ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
2882
+ ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
2741
2883
  }
2742
2884
 
2743
2885
  const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
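Note: the table-initialization rewrite replaces the memcpy round-trip with a union; in C, reading a union member other than the one last written reinterprets the stored bytes, so this is an accepted way to view the loop counter's low 16 bits as an fp16 value. The idiom in isolation (names here are illustrative, not ggml's):

    #include <stdint.h>

    typedef uint16_t fp16_bits_t;    // stands in for ggml_fp16_t

    static fp16_bits_t fp16_from_bits(uint16_t bits) {
        union { uint16_t u16; fp16_bits_t fp16; } u = { bits };
        return u.fp16;               // same bytes, viewed as the fp16 type
    }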
@@ -3201,6 +3343,13 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
3201
3343
  ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
3202
3344
  }
3203
3345
  } break;
3346
+ case GGML_TYPE_BF16:
3347
+ {
3348
+ assert(tensor->nb[0] == sizeof(ggml_bf16_t));
3349
+ for (int i = 0; i < n; i++) {
3350
+ ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
3351
+ }
3352
+ } break;
3204
3353
  case GGML_TYPE_F32:
3205
3354
  {
3206
3355
  assert(tensor->nb[0] == sizeof(float));
@@ -3253,6 +3402,13 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3253
3402
  ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
3254
3403
  }
3255
3404
  } break;
3405
+ case GGML_TYPE_BF16:
3406
+ {
3407
+ assert(tensor->nb[0] == sizeof(ggml_bf16_t));
3408
+ for (int i = 0; i < n; i++) {
3409
+ ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
3410
+ }
3411
+ } break;
3256
3412
  case GGML_TYPE_F32:
3257
3413
  {
3258
3414
  assert(tensor->nb[0] == sizeof(float));
@@ -3320,6 +3476,11 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3320
3476
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3321
3477
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
3322
3478
  }
3479
+ case GGML_TYPE_BF16:
3480
+ {
3481
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3482
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
3483
+ }
3323
3484
  case GGML_TYPE_F32:
3324
3485
  {
3325
3486
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3362,6 +3523,11 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3362
3523
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3363
3524
  ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
3364
3525
  } break;
3526
+ case GGML_TYPE_BF16:
3527
+ {
3528
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3529
+ ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
3530
+ } break;
3365
3531
  case GGML_TYPE_F32:
3366
3532
  {
3367
3533
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3385,6 +3551,8 @@ int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i
3385
3551
  return ((int32_t *) data)[0];
3386
3552
  case GGML_TYPE_F16:
3387
3553
  return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
3554
+ case GGML_TYPE_BF16:
3555
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
3388
3556
  case GGML_TYPE_F32:
3389
3557
  return ((float *) data)[0];
3390
3558
  default:
@@ -3413,6 +3581,10 @@ void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3413
3581
  {
3414
3582
  ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
3415
3583
  } break;
3584
+ case GGML_TYPE_BF16:
3585
+ {
3586
+ ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
3587
+ } break;
3416
3588
  case GGML_TYPE_F32:
3417
3589
  {
3418
3590
  ((float *)(data))[0] = value;
@@ -3451,6 +3623,11 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3451
3623
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3452
3624
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
3453
3625
  }
3626
+ case GGML_TYPE_BF16:
3627
+ {
3628
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3629
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
3630
+ }
3454
3631
  case GGML_TYPE_F32:
3455
3632
  {
3456
3633
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3493,6 +3670,11 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
3493
3670
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3494
3671
  ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
3495
3672
  } break;
3673
+ case GGML_TYPE_BF16:
3674
+ {
3675
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3676
+ ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
3677
+ } break;
3496
3678
  case GGML_TYPE_F32:
3497
3679
  {
3498
3680
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3516,6 +3698,8 @@ float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3516
3698
  return ((int32_t *) data)[0];
3517
3699
  case GGML_TYPE_F16:
3518
3700
  return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
3701
+ case GGML_TYPE_BF16:
3702
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
3519
3703
  case GGML_TYPE_F32:
3520
3704
  return ((float *) data)[0];
3521
3705
  default:
@@ -3544,6 +3728,10 @@ void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3544
3728
  {
3545
3729
  ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
3546
3730
  } break;
3731
+ case GGML_TYPE_BF16:
3732
+ {
3733
+ ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
3734
+ } break;
3547
3735
  case GGML_TYPE_F32:
3548
3736
  {
3549
3737
  ((float *)(data))[0] = value;
@@ -3738,7 +3926,11 @@ static struct ggml_tensor * ggml_add_cast_impl(
3738
3926
  // TODO: support less-strict constraint
3739
3927
  // GGML_ASSERT(ggml_can_repeat(b, a));
3740
3928
  GGML_ASSERT(ggml_can_repeat_rows(b, a));
3741
- GGML_ASSERT(ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16); // currently only supported for quantized input and f16
3929
+
3930
+ // currently only supported for quantized input and f16
3931
+ GGML_ASSERT(ggml_is_quantized(a->type) ||
3932
+ a->type == GGML_TYPE_F16 ||
3933
+ a->type == GGML_TYPE_BF16);
3742
3934
 
3743
3935
  bool is_node = false;
3744
3936
 
@@ -7215,8 +7407,8 @@ static void ggml_compute_forward_dup_same_cont(
7215
7407
  ((char *) src0->data + ie0*nb00),
7216
7408
  (ie1 - ie0) * ggml_type_size(src0->type));
7217
7409
  }
7218
-
7219
7410
  }
7411
+
7220
7412
  static void ggml_compute_forward_dup_f16(
7221
7413
  const struct ggml_compute_params * params,
7222
7414
  struct ggml_tensor * dst) {
@@ -7490,7 +7682,7 @@ static void ggml_compute_forward_dup_f16(
7490
7682
  }
7491
7683
  }
7492
7684
 
7493
- static void ggml_compute_forward_dup_f32(
7685
+ static void ggml_compute_forward_dup_bf16(
7494
7686
  const struct ggml_compute_params * params,
7495
7687
  struct ggml_tensor * dst) {
7496
7688
 
@@ -7538,10 +7730,11 @@ static void ggml_compute_forward_dup_f32(
7538
7730
  return;
7539
7731
  }
7540
7732
 
7733
+ // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
7734
+
7541
7735
  if (ggml_is_contiguous(dst)) {
7542
- // TODO: simplify
7543
- if (nb00 == sizeof(float)) {
7544
- if (dst->type == GGML_TYPE_F32) {
7736
+ if (nb00 == sizeof(ggml_bf16_t)) {
7737
+ if (dst->type == GGML_TYPE_BF16) {
7545
7738
  size_t id = 0;
7546
7739
  const size_t rs = ne00 * nb00;
7547
7740
  char * dst_ptr = (char *) dst->data;
@@ -7557,8 +7750,43 @@ static void ggml_compute_forward_dup_f32(
7557
7750
  id += rs * (ne01 - ir1);
7558
7751
  }
7559
7752
  }
7753
+ } else if (dst->type == GGML_TYPE_F16) {
7754
+ size_t id = 0;
7755
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
7756
+
7757
+ for (int i03 = 0; i03 < ne03; i03++) {
7758
+ for (int i02 = 0; i02 < ne02; i02++) {
7759
+ id += ne00 * ir0;
7760
+ for (int i01 = ir0; i01 < ir1; i01++) {
7761
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7762
+ for (int i00 = 0; i00 < ne00; i00++) {
7763
+ dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
7764
+ id++;
7765
+ }
7766
+ }
7767
+ id += ne00 * (ne01 - ir1);
7768
+ }
7769
+ }
7770
+ } else if (dst->type == GGML_TYPE_F32) {
7771
+ size_t id = 0;
7772
+ float * dst_ptr = (float *) dst->data;
7773
+
7774
+ for (int i03 = 0; i03 < ne03; i03++) {
7775
+ for (int i02 = 0; i02 < ne02; i02++) {
7776
+ id += ne00 * ir0;
7777
+ for (int i01 = ir0; i01 < ir1; i01++) {
7778
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7779
+ for (int i00 = 0; i00 < ne00; i00++) {
7780
+ dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
7781
+ id++;
7782
+ }
7783
+ }
7784
+ id += ne00 * (ne01 - ir1);
7785
+ }
7786
+ }
7560
7787
  } else if (type_traits[dst->type].from_float) {
7561
7788
  ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
7789
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
7562
7790
 
7563
7791
  size_t id = 0;
7564
7792
  size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
@@ -7568,8 +7796,13 @@ static void ggml_compute_forward_dup_f32(
7568
7796
  for (int i02 = 0; i02 < ne02; i02++) {
7569
7797
  id += rs * ir0;
7570
7798
  for (int i01 = ir0; i01 < ir1; i01++) {
7571
- const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7572
- quantize_row_q(src0_ptr, dst_ptr + id, ne00);
7799
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7800
+
7801
+ for (int i00 = 0; i00 < ne00; i00++) {
7802
+ src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
7803
+ }
7804
+
7805
+ quantize_row_q(src0_f32, dst_ptr + id, ne00);
7573
7806
  id += rs;
7574
7807
  }
7575
7808
  id += rs * (ne01 - ir1);
@@ -7590,7 +7823,25 @@ static void ggml_compute_forward_dup_f32(
7590
7823
  id += ne00 * ir0;
7591
7824
  for (int i01 = ir0; i01 < ir1; i01++) {
7592
7825
  for (int i00 = 0; i00 < ne00; i00++) {
7593
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7826
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7827
+
7828
+ dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
7829
+ id++;
7830
+ }
7831
+ }
7832
+ id += ne00 * (ne01 - ir1);
7833
+ }
7834
+ }
7835
+ } else if (dst->type == GGML_TYPE_BF16) {
7836
+ size_t id = 0;
7837
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
7838
+
7839
+ for (int i03 = 0; i03 < ne03; i03++) {
7840
+ for (int i02 = 0; i02 < ne02; i02++) {
7841
+ id += ne00 * ir0;
7842
+ for (int i01 = ir0; i01 < ir1; i01++) {
7843
+ for (int i00 = 0; i00 < ne00; i00++) {
7844
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7594
7845
 
7595
7846
  dst_ptr[id] = *src0_ptr;
7596
7847
  id++;
@@ -7608,9 +7859,9 @@ static void ggml_compute_forward_dup_f32(
7608
7859
  id += ne00 * ir0;
7609
7860
  for (int i01 = ir0; i01 < ir1; i01++) {
7610
7861
  for (int i00 = 0; i00 < ne00; i00++) {
7611
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7862
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7612
7863
 
7613
- dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
7864
+ dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
7614
7865
  id++;
7615
7866
  }
7616
7867
  }
@@ -7621,18 +7872,16 @@ static void ggml_compute_forward_dup_f32(
7621
7872
  GGML_ASSERT(false); // TODO: implement
7622
7873
  }
7623
7874
  }
7624
-
7625
7875
  return;
7626
7876
  }
7627
7877
 
7628
7878
  // dst counters
7629
-
7630
7879
  int64_t i10 = 0;
7631
7880
  int64_t i11 = 0;
7632
7881
  int64_t i12 = 0;
7633
7882
  int64_t i13 = 0;
7634
7883
 
7635
- if (dst->type == GGML_TYPE_F32) {
7884
+ if (dst->type == GGML_TYPE_BF16) {
7636
7885
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7637
7886
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7638
7887
  i10 += ne00 * ir0;
@@ -7653,7 +7902,59 @@ static void ggml_compute_forward_dup_f32(
7653
7902
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7654
7903
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
7655
7904
 
7656
- memcpy(dst_ptr, src0_ptr, sizeof(float));
7905
+ memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
7906
+
7907
+ if (++i10 == ne00) {
7908
+ i10 = 0;
7909
+ if (++i11 == ne01) {
7910
+ i11 = 0;
7911
+ if (++i12 == ne02) {
7912
+ i12 = 0;
7913
+ if (++i13 == ne03) {
7914
+ i13 = 0;
7915
+ }
7916
+ }
7917
+ }
7918
+ }
7919
+ }
7920
+ }
7921
+ i10 += ne00 * (ne01 - ir1);
7922
+ while (i10 >= ne0) {
7923
+ i10 -= ne0;
7924
+ if (++i11 == ne1) {
7925
+ i11 = 0;
7926
+ if (++i12 == ne2) {
7927
+ i12 = 0;
7928
+ if (++i13 == ne3) {
7929
+ i13 = 0;
7930
+ }
7931
+ }
7932
+ }
7933
+ }
7934
+ }
7935
+ }
7936
+ } else if (dst->type == GGML_TYPE_F16) {
7937
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
7938
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
7939
+ i10 += ne00 * ir0;
7940
+ while (i10 >= ne0) {
7941
+ i10 -= ne0;
7942
+ if (++i11 == ne1) {
7943
+ i11 = 0;
7944
+ if (++i12 == ne2) {
7945
+ i12 = 0;
7946
+ if (++i13 == ne3) {
7947
+ i13 = 0;
7948
+ }
7949
+ }
7950
+ }
7951
+ }
7952
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
7953
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
7954
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7955
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
7956
+
7957
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
7657
7958
 
7658
7959
  if (++i10 == ne0) {
7659
7960
  i10 = 0;
@@ -7684,7 +7985,7 @@ static void ggml_compute_forward_dup_f32(
7684
7985
  }
7685
7986
  }
7686
7987
  }
7687
- } else if (dst->type == GGML_TYPE_F16) {
7988
+ } else if (dst->type == GGML_TYPE_F32) {
7688
7989
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7689
7990
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7690
7991
  i10 += ne00 * ir0;
@@ -7705,7 +8006,7 @@ static void ggml_compute_forward_dup_f32(
7705
8006
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7706
8007
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
7707
8008
 
7708
- *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
8009
+ *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
7709
8010
 
7710
8011
  if (++i10 == ne0) {
7711
8012
  i10 = 0;
@@ -7741,31 +8042,27 @@ static void ggml_compute_forward_dup_f32(
7741
8042
  }
7742
8043
  }
7743
8044
 
7744
- // A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
7745
- static void ggml_compute_forward_dup_bytes(
8045
+ static void ggml_compute_forward_dup_f32(
7746
8046
  const struct ggml_compute_params * params,
7747
8047
  struct ggml_tensor * dst) {
7748
8048
 
7749
8049
  const struct ggml_tensor * src0 = dst->src[0];
7750
8050
 
7751
8051
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
7752
- GGML_ASSERT(src0->type == dst->type);
7753
8052
 
7754
8053
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
7755
8054
  return;
7756
8055
  }
7757
8056
 
7758
- if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
7759
- ggml_compute_forward_dup_same_cont(params, dst);
7760
- return;
7761
- }
7762
-
7763
- GGML_TENSOR_UNARY_OP_LOCALS;
8057
+ GGML_TENSOR_UNARY_OP_LOCALS
7764
8058
 
7765
- const size_t type_size = ggml_type_size(src0->type);
7766
8059
  const int ith = params->ith; // thread index
7767
8060
  const int nth = params->nth; // number of threads
7768
8061
 
8062
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
8063
+ ggml_compute_forward_dup_same_cont(params, dst);
8064
+ return;
8065
+ }
7769
8066
 
7770
8067
  // parallelize by rows
7771
8068
  const int nr = ne01;
@@ -7777,9 +8074,9 @@ static void ggml_compute_forward_dup_bytes(
7777
8074
 
7778
8075
  if (src0->type == dst->type &&
7779
8076
  ne00 == ne0 &&
7780
- nb00 == type_size && nb0 == type_size) {
8077
+ nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
7781
8078
  // copy by rows
7782
- const size_t rs = ne00 * type_size;
8079
+ const size_t rs = ne00*nb00;
7783
8080
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7784
8081
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7785
8082
  for (int64_t i01 = ir0; i01 < ir1; i01++) {
@@ -7794,41 +8091,366 @@ static void ggml_compute_forward_dup_bytes(
7794
8091
  }
7795
8092
 
7796
8093
  if (ggml_is_contiguous(dst)) {
7797
- size_t id = 0;
7798
- char * dst_ptr = (char *) dst->data;
7799
- const size_t rs = ne00 * type_size;
7800
-
7801
- if (nb00 == type_size) {
7802
- // src0 is contigous on first dimension, copy by rows
7803
- for (int64_t i03 = 0; i03 < ne03; i03++) {
7804
- for (int64_t i02 = 0; i02 < ne02; i02++) {
7805
- id += rs * ir0;
7806
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
7807
- const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
7808
- memcpy(dst_ptr + id, src0_ptr, rs);
7809
- id += rs;
7810
- }
7811
- id += rs * (ne01 - ir1);
7812
- }
7813
- }
7814
- } else {
7815
- //printf("%s: this is not optimal - fix me\n", __func__);
7816
-
7817
- for (int64_t i03 = 0; i03 < ne03; i03++) {
7818
- for (int64_t i02 = 0; i02 < ne02; i02++) {
7819
- id += rs * ir0;
7820
- for (int64_t i01 = ir0; i01 < ir1; i01++) {
7821
- for (int64_t i00 = 0; i00 < ne00; i00++) {
7822
- const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
7823
- memcpy(dst_ptr + id, src0_ptr, type_size);
8094
+ // TODO: simplify
8095
+ if (nb00 == sizeof(float)) {
8096
+ if (dst->type == GGML_TYPE_F32) {
8097
+ size_t id = 0;
8098
+ const size_t rs = ne00 * nb00;
8099
+ char * dst_ptr = (char *) dst->data;
7824
8100
 
7825
- id += type_size;
8101
+ for (int i03 = 0; i03 < ne03; i03++) {
8102
+ for (int i02 = 0; i02 < ne02; i02++) {
8103
+ id += rs * ir0;
8104
+ for (int i01 = ir0; i01 < ir1; i01++) {
8105
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
8106
+ memcpy(dst_ptr + id, src0_ptr, rs);
8107
+ id += rs;
7826
8108
  }
8109
+ id += rs * (ne01 - ir1);
7827
8110
  }
7828
- id += rs * (ne01 - ir1);
7829
8111
  }
7830
- }
7831
- }
8112
+ } else if (type_traits[dst->type].from_float) {
8113
+ ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
8114
+
8115
+ size_t id = 0;
8116
+ size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
8117
+ char * dst_ptr = (char *) dst->data;
8118
+
8119
+ for (int i03 = 0; i03 < ne03; i03++) {
8120
+ for (int i02 = 0; i02 < ne02; i02++) {
8121
+ id += rs * ir0;
8122
+ for (int i01 = ir0; i01 < ir1; i01++) {
8123
+ const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
8124
+ quantize_row_q(src0_ptr, dst_ptr + id, ne00);
8125
+ id += rs;
8126
+ }
8127
+ id += rs * (ne01 - ir1);
8128
+ }
8129
+ }
8130
+ } else {
8131
+ GGML_ASSERT(false); // TODO: implement
8132
+ }
8133
+ } else {
8134
+ //printf("%s: this is not optimal - fix me\n", __func__);
8135
+
8136
+ if (dst->type == GGML_TYPE_F32) {
8137
+ size_t id = 0;
8138
+ float * dst_ptr = (float *) dst->data;
8139
+
8140
+ for (int i03 = 0; i03 < ne03; i03++) {
8141
+ for (int i02 = 0; i02 < ne02; i02++) {
8142
+ id += ne00 * ir0;
8143
+ for (int i01 = ir0; i01 < ir1; i01++) {
8144
+ for (int i00 = 0; i00 < ne00; i00++) {
8145
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8146
+
8147
+ dst_ptr[id] = *src0_ptr;
8148
+ id++;
8149
+ }
8150
+ }
8151
+ id += ne00 * (ne01 - ir1);
8152
+ }
8153
+ }
8154
+ } else if (dst->type == GGML_TYPE_F16) {
8155
+ size_t id = 0;
8156
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
8157
+
8158
+ for (int i03 = 0; i03 < ne03; i03++) {
8159
+ for (int i02 = 0; i02 < ne02; i02++) {
8160
+ id += ne00 * ir0;
8161
+ for (int i01 = ir0; i01 < ir1; i01++) {
8162
+ for (int i00 = 0; i00 < ne00; i00++) {
8163
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8164
+
8165
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
8166
+ id++;
8167
+ }
8168
+ }
8169
+ id += ne00 * (ne01 - ir1);
8170
+ }
8171
+ }
8172
+ } else if (dst->type == GGML_TYPE_BF16) {
8173
+ size_t id = 0;
8174
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
8175
+
8176
+ for (int i03 = 0; i03 < ne03; i03++) {
8177
+ for (int i02 = 0; i02 < ne02; i02++) {
8178
+ id += ne00 * ir0;
8179
+ for (int i01 = ir0; i01 < ir1; i01++) {
8180
+ for (int i00 = 0; i00 < ne00; i00++) {
8181
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8182
+
8183
+ dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
8184
+ id++;
8185
+ }
8186
+ }
8187
+ id += ne00 * (ne01 - ir1);
8188
+ }
8189
+ }
8190
+ } else {
8191
+ GGML_ASSERT(false); // TODO: implement
8192
+ }
8193
+ }
8194
+
8195
+ return;
8196
+ }
8197
+
8198
+ // dst counters
8199
+
8200
+ int64_t i10 = 0;
8201
+ int64_t i11 = 0;
8202
+ int64_t i12 = 0;
8203
+ int64_t i13 = 0;
8204
+
8205
+ if (dst->type == GGML_TYPE_F32) {
8206
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8207
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8208
+ i10 += ne00 * ir0;
8209
+ while (i10 >= ne0) {
8210
+ i10 -= ne0;
8211
+ if (++i11 == ne1) {
8212
+ i11 = 0;
8213
+ if (++i12 == ne2) {
8214
+ i12 = 0;
8215
+ if (++i13 == ne3) {
8216
+ i13 = 0;
8217
+ }
8218
+ }
8219
+ }
8220
+ }
8221
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8222
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8223
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8224
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8225
+
8226
+ memcpy(dst_ptr, src0_ptr, sizeof(float));
8227
+
8228
+ if (++i10 == ne0) {
8229
+ i10 = 0;
8230
+ if (++i11 == ne1) {
8231
+ i11 = 0;
8232
+ if (++i12 == ne2) {
8233
+ i12 = 0;
8234
+ if (++i13 == ne3) {
8235
+ i13 = 0;
8236
+ }
8237
+ }
8238
+ }
8239
+ }
8240
+ }
8241
+ }
8242
+ i10 += ne00 * (ne01 - ir1);
8243
+ while (i10 >= ne0) {
8244
+ i10 -= ne0;
8245
+ if (++i11 == ne1) {
8246
+ i11 = 0;
8247
+ if (++i12 == ne2) {
8248
+ i12 = 0;
8249
+ if (++i13 == ne3) {
8250
+ i13 = 0;
8251
+ }
8252
+ }
8253
+ }
8254
+ }
8255
+ }
8256
+ }
8257
+ } else if (dst->type == GGML_TYPE_F16) {
8258
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8259
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8260
+ i10 += ne00 * ir0;
8261
+ while (i10 >= ne0) {
8262
+ i10 -= ne0;
8263
+ if (++i11 == ne1) {
8264
+ i11 = 0;
8265
+ if (++i12 == ne2) {
8266
+ i12 = 0;
8267
+ if (++i13 == ne3) {
8268
+ i13 = 0;
8269
+ }
8270
+ }
8271
+ }
8272
+ }
8273
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8274
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8275
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8276
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8277
+
8278
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
8279
+
8280
+ if (++i10 == ne0) {
8281
+ i10 = 0;
8282
+ if (++i11 == ne1) {
8283
+ i11 = 0;
8284
+ if (++i12 == ne2) {
8285
+ i12 = 0;
8286
+ if (++i13 == ne3) {
8287
+ i13 = 0;
8288
+ }
8289
+ }
8290
+ }
8291
+ }
8292
+ }
8293
+ }
8294
+ i10 += ne00 * (ne01 - ir1);
8295
+ while (i10 >= ne0) {
8296
+ i10 -= ne0;
8297
+ if (++i11 == ne1) {
8298
+ i11 = 0;
8299
+ if (++i12 == ne2) {
8300
+ i12 = 0;
8301
+ if (++i13 == ne3) {
8302
+ i13 = 0;
8303
+ }
8304
+ }
8305
+ }
8306
+ }
8307
+ }
8308
+ }
8309
+ } else if (dst->type == GGML_TYPE_BF16) {
8310
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8311
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8312
+ i10 += ne00 * ir0;
8313
+ while (i10 >= ne0) {
8314
+ i10 -= ne0;
8315
+ if (++i11 == ne1) {
8316
+ i11 = 0;
8317
+ if (++i12 == ne2) {
8318
+ i12 = 0;
8319
+ if (++i13 == ne3) {
8320
+ i13 = 0;
8321
+ }
8322
+ }
8323
+ }
8324
+ }
8325
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8326
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8327
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8328
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8329
+
8330
+ *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
8331
+
8332
+ if (++i10 == ne0) {
8333
+ i10 = 0;
8334
+ if (++i11 == ne1) {
8335
+ i11 = 0;
8336
+ if (++i12 == ne2) {
8337
+ i12 = 0;
8338
+ if (++i13 == ne3) {
8339
+ i13 = 0;
8340
+ }
8341
+ }
8342
+ }
8343
+ }
8344
+ }
8345
+ }
8346
+ i10 += ne00 * (ne01 - ir1);
8347
+ while (i10 >= ne0) {
8348
+ i10 -= ne0;
8349
+ if (++i11 == ne1) {
8350
+ i11 = 0;
8351
+ if (++i12 == ne2) {
8352
+ i12 = 0;
8353
+ if (++i13 == ne3) {
8354
+ i13 = 0;
8355
+ }
8356
+ }
8357
+ }
8358
+ }
8359
+ }
8360
+ }
8361
+ } else {
8362
+ GGML_ASSERT(false); // TODO: implement
8363
+ }
8364
+ }
8365
+
8366
+ // A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
8367
+ static void ggml_compute_forward_dup_bytes(
8368
+ const struct ggml_compute_params * params,
8369
+ struct ggml_tensor * dst) {
8370
+
8371
+ const struct ggml_tensor * src0 = dst->src[0];
8372
+
8373
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
8374
+ GGML_ASSERT(src0->type == dst->type);
8375
+
8376
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
8377
+ return;
8378
+ }
8379
+
8380
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
8381
+ ggml_compute_forward_dup_same_cont(params, dst);
8382
+ return;
8383
+ }
8384
+
8385
+ GGML_TENSOR_UNARY_OP_LOCALS;
8386
+
8387
+ const size_t type_size = ggml_type_size(src0->type);
8388
+ const int ith = params->ith; // thread index
8389
+ const int nth = params->nth; // number of threads
8390
+
8391
+
8392
+ // parallelize by rows
8393
+ const int nr = ne01;
8394
+ // number of rows per thread
8395
+ const int dr = (nr + nth - 1) / nth;
8396
+ // row range for this thread
8397
+ const int ir0 = dr * ith;
8398
+ const int ir1 = MIN(ir0 + dr, nr);
8399
+
8400
+ if (src0->type == dst->type &&
8401
+ ne00 == ne0 &&
8402
+ nb00 == type_size && nb0 == type_size) {
8403
+ // copy by rows
8404
+ const size_t rs = ne00 * type_size;
8405
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8406
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8407
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8408
+ memcpy(
8409
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
8410
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
8411
+ rs);
8412
+ }
8413
+ }
8414
+ }
8415
+ return;
8416
+ }
8417
+
8418
+ if (ggml_is_contiguous(dst)) {
8419
+ size_t id = 0;
8420
+ char * dst_ptr = (char *) dst->data;
8421
+ const size_t rs = ne00 * type_size;
8422
+
8423
+ if (nb00 == type_size) {
8424
+ // src0 is contiguous on first dimension, copy by rows
8425
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8426
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8427
+ id += rs * ir0;
8428
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8429
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
8430
+ memcpy(dst_ptr + id, src0_ptr, rs);
8431
+ id += rs;
8432
+ }
8433
+ id += rs * (ne01 - ir1);
8434
+ }
8435
+ }
8436
+ } else {
8437
+ //printf("%s: this is not optimal - fix me\n", __func__);
8438
+
8439
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8440
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8441
+ id += rs * ir0;
8442
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8443
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8444
+ const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
8445
+ memcpy(dst_ptr + id, src0_ptr, type_size);
8446
+
8447
+ id += type_size;
8448
+ }
8449
+ }
8450
+ id += rs * (ne01 - ir1);
8451
+ }
8452
+ }
8453
+ }
7832
8454
 
7833
8455
  return;
7834
8456
  }
@@ -7909,6 +8531,10 @@ static void ggml_compute_forward_dup(
7909
8531
  {
7910
8532
  ggml_compute_forward_dup_f16(params, dst);
7911
8533
  } break;
8534
+ case GGML_TYPE_BF16:
8535
+ {
8536
+ ggml_compute_forward_dup_bf16(params, dst);
8537
+ } break;
7912
8538
  case GGML_TYPE_F32:
7913
8539
  {
7914
8540
  ggml_compute_forward_dup_f32(params, dst);
@@ -8002,17 +8628,96 @@ static void ggml_compute_forward_add_f32(
8002
8628
  float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
8003
8629
  float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
8004
8630
 
8005
- for (int64_t i0 = 0; i0 < ne0; ++i0) {
8006
- const int64_t i10 = i0 % ne10;
8007
- float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
8631
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
8632
+ const int64_t i10 = i0 % ne10;
8633
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
8634
+
8635
+ dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
8636
+ }
8637
+ }
8638
+ }
8639
+ }
8640
+
8641
+ static void ggml_compute_forward_add_f16_f32(
8642
+ const struct ggml_compute_params * params,
8643
+ struct ggml_tensor * dst) {
8644
+
8645
+ const struct ggml_tensor * src0 = dst->src[0];
8646
+ const struct ggml_tensor * src1 = dst->src[1];
8647
+
8648
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
8649
+
8650
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
8651
+ return;
8652
+ }
8653
+
8654
+ const int ith = params->ith;
8655
+ const int nth = params->nth;
8656
+
8657
+ const int nr = ggml_nrows(src0);
8658
+
8659
+ GGML_TENSOR_BINARY_OP_LOCALS
8660
+
8661
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
8662
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
8663
+
8664
+ if (dst->type == GGML_TYPE_F32) {
8665
+ GGML_ASSERT( nb0 == sizeof(float));
8666
+ }
8667
+ else {
8668
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
8669
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
8670
+ }
8671
+
8672
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
8673
+
8674
+ // rows per thread
8675
+ const int dr = (nr + nth - 1)/nth;
8676
+
8677
+ // row range for this thread
8678
+ const int ir0 = dr*ith;
8679
+ const int ir1 = MIN(ir0 + dr, nr);
8680
+
8681
+ if (nb10 == sizeof(float)) {
8682
+ if (dst->type == GGML_TYPE_F16) {
8683
+ for (int ir = ir0; ir < ir1; ++ir) {
8684
+ // src0, src1 and dst are same shape => same indices
8685
+ const int i3 = ir/(ne2*ne1);
8686
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
8687
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8688
+
8689
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8690
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8691
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8692
+
8693
+ for (int i = 0; i < ne0; i++) {
8694
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
8695
+ }
8696
+ }
8697
+ } else {
8698
+ for (int ir = ir0; ir < ir1; ++ir) {
8699
+ // src0, src1 and dst are same shape => same indices
8700
+ const int i3 = ir/(ne2*ne1);
8701
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
8702
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8703
+
8704
+ float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8705
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8706
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8008
8707
 
8009
- dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
8708
+ for (int i = 0; i < ne0; i++) {
8709
+ dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
8710
+ }
8010
8711
  }
8011
8712
  }
8012
8713
  }
8714
+ else {
8715
+ // src1 is not contiguous
8716
+ GGML_ASSERT(false);
8717
+ }
8013
8718
  }
8014
8719
 
8015
- static void ggml_compute_forward_add_f16_f32(
8720
+ static void ggml_compute_forward_add_bf16_f32(
8016
8721
  const struct ggml_compute_params * params,
8017
8722
  struct ggml_tensor * dst) {
8018
8723
 
@@ -8032,18 +8737,18 @@ static void ggml_compute_forward_add_f16_f32(
8032
8737
 
8033
8738
  GGML_TENSOR_BINARY_OP_LOCALS
8034
8739
 
8035
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
8740
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
8036
8741
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
8037
8742
 
8038
8743
  if (dst->type == GGML_TYPE_F32) {
8039
8744
  GGML_ASSERT( nb0 == sizeof(float));
8040
8745
  }
8041
8746
  else {
8042
- GGML_ASSERT(dst->type == GGML_TYPE_F16);
8043
- GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
8747
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
8748
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
8044
8749
  }
8045
8750
 
8046
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
8751
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
8047
8752
 
8048
8753
  // rows per thread
8049
8754
  const int dr = (nr + nth - 1)/nth;
@@ -8053,19 +8758,19 @@ static void ggml_compute_forward_add_f16_f32(
8053
8758
  const int ir1 = MIN(ir0 + dr, nr);
8054
8759
 
8055
8760
  if (nb10 == sizeof(float)) {
8056
- if (dst->type == GGML_TYPE_F16) {
8761
+ if (dst->type == GGML_TYPE_BF16) {
8057
8762
  for (int ir = ir0; ir < ir1; ++ir) {
8058
8763
  // src0, src1 and dst are same shape => same indices
8059
8764
  const int i3 = ir/(ne2*ne1);
8060
8765
  const int i2 = (ir - i3*ne2*ne1)/ne1;
8061
8766
  const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8062
8767
 
8063
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8064
- ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8768
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8769
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8065
8770
  float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8066
8771
 
8067
8772
  for (int i = 0; i < ne0; i++) {
8068
- dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
8773
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
8069
8774
  }
8070
8775
  }
8071
8776
  } else {
@@ -8076,11 +8781,11 @@ static void ggml_compute_forward_add_f16_f32(
8076
8781
  const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8077
8782
 
8078
8783
  float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8079
- ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8784
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8080
8785
  float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8081
8786
 
8082
8787
  for (int i = 0; i < ne0; i++) {
8083
- dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
8788
+ dst_ptr[i] = GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
8084
8789
  }
8085
8790
  }
8086
8791
  }
@@ -8147,6 +8852,62 @@ static void ggml_compute_forward_add_f16_f16(
8147
8852
  }
8148
8853
  }
8149
8854
 
8855
+ static void ggml_compute_forward_add_bf16_bf16(
8856
+ const struct ggml_compute_params * params,
8857
+ struct ggml_tensor * dst) {
8858
+
8859
+ const struct ggml_tensor * src0 = dst->src[0];
8860
+ const struct ggml_tensor * src1 = dst->src[1];
8861
+
8862
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
8863
+
8864
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
8865
+ return;
8866
+ }
8867
+
8868
+ const int ith = params->ith;
8869
+ const int nth = params->nth;
8870
+
8871
+ const int nr = ggml_nrows(src0);
8872
+
8873
+ GGML_TENSOR_BINARY_OP_LOCALS
8874
+
8875
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
8876
+ GGML_ASSERT(src1->type == GGML_TYPE_BF16);
8877
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
8878
+
8879
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
8880
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
8881
+
8882
+ // rows per thread
8883
+ const int dr = (nr + nth - 1)/nth;
8884
+
8885
+ // row range for this thread
8886
+ const int ir0 = dr*ith;
8887
+ const int ir1 = MIN(ir0 + dr, nr);
8888
+
8889
+ if (nb10 == sizeof(ggml_bf16_t)) {
8890
+ for (int ir = ir0; ir < ir1; ++ir) {
8891
+ // src0, src1 and dst are same shape => same indices
8892
+ const int i3 = ir/(ne2*ne1);
8893
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
8894
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8895
+
8896
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8897
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8898
+ ggml_bf16_t * src1_ptr = (ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8899
+
8900
+ for (int i = 0; i < ne0; i++) {
8901
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + GGML_BF16_TO_FP32(src1_ptr[i]));
8902
+ }
8903
+ }
8904
+ }
8905
+ else {
8906
+ // src1 is not contiguous
8907
+ GGML_ASSERT(false);
8908
+ }
8909
+ }
8910
+
8150
8911
  static void ggml_compute_forward_add_q_f32(
8151
8912
  const struct ggml_compute_params * params,
8152
8913
  struct ggml_tensor * dst) {
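Note: none of the new bf16 kernels do arithmetic in bf16; each element is widened to f32, combined there, and narrowed back only when the destination is bf16 (ggml_compute_forward_add_bf16_f32 can also leave the result in f32). The element-wise pattern, written out once as a sketch that assumes the ggml conversion macros are in scope (the helper name is illustrative):

    static inline ggml_bf16_t bf16_add(ggml_bf16_t a, ggml_bf16_t b) {
        return GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(a) + GGML_BF16_TO_FP32(b));
    }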
@@ -8256,6 +9017,18 @@ static void ggml_compute_forward_add(
8256
9017
  GGML_ASSERT(false);
8257
9018
  }
8258
9019
  } break;
9020
+ case GGML_TYPE_BF16:
9021
+ {
9022
+ if (src1->type == GGML_TYPE_BF16) {
9023
+ ggml_compute_forward_add_bf16_bf16(params, dst);
9024
+ }
9025
+ else if (src1->type == GGML_TYPE_F32) {
9026
+ ggml_compute_forward_add_bf16_f32(params, dst);
9027
+ }
9028
+ else {
9029
+ GGML_ASSERT(false);
9030
+ }
9031
+ } break;
8259
9032
  case GGML_TYPE_Q4_0:
8260
9033
  case GGML_TYPE_Q4_1:
8261
9034
  case GGML_TYPE_Q5_0:
@@ -8514,6 +9287,110 @@ static void ggml_compute_forward_add1_q_f32(
8514
9287
  }
8515
9288
  }
8516
9289
 
9290
+ static void ggml_compute_forward_add1_bf16_f32(
9291
+ const struct ggml_compute_params * params,
9292
+ struct ggml_tensor * dst) {
9293
+
9294
+ const struct ggml_tensor * src0 = dst->src[0];
9295
+ const struct ggml_tensor * src1 = dst->src[1];
9296
+
9297
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
9298
+ GGML_ASSERT(ggml_is_scalar(src1));
9299
+
9300
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
9301
+ return;
9302
+ }
9303
+
9304
+ // scalar to add
9305
+ const float v = *(float *) src1->data;
9306
+
9307
+ const int ith = params->ith;
9308
+ const int nth = params->nth;
9309
+
9310
+ const int nr = ggml_nrows(src0);
9311
+
9312
+ GGML_TENSOR_UNARY_OP_LOCALS
9313
+
9314
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
9315
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
9316
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
9317
+
9318
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
9319
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
9320
+
9321
+ // rows per thread
9322
+ const int dr = (nr + nth - 1)/nth;
9323
+
9324
+ // row range for this thread
9325
+ const int ir0 = dr*ith;
9326
+ const int ir1 = MIN(ir0 + dr, nr);
9327
+
9328
+ for (int ir = ir0; ir < ir1; ++ir) {
9329
+ // src0 and dst are same shape => same indices
9330
+ const int i3 = ir/(ne2*ne1);
9331
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9332
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9333
+
9334
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
9335
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9336
+ for (int i = 0; i < ne0; i++) {
9337
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
9338
+ }
9339
+ }
9340
+ }
9341
+
9342
+ static void ggml_compute_forward_add1_bf16_bf16(
9343
+ const struct ggml_compute_params * params,
9344
+ struct ggml_tensor * dst) {
9345
+
9346
+ const struct ggml_tensor * src0 = dst->src[0];
9347
+ const struct ggml_tensor * src1 = dst->src[1];
9348
+
9349
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
9350
+ GGML_ASSERT(ggml_is_scalar(src1));
9351
+
9352
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
9353
+ return;
9354
+ }
9355
+
9356
+ // scalar to add
9357
+ const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
9358
+
9359
+ const int ith = params->ith;
9360
+ const int nth = params->nth;
9361
+
9362
+ const int nr = ggml_nrows(src0);
9363
+
9364
+ GGML_TENSOR_UNARY_OP_LOCALS
9365
+
9366
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
9367
+ GGML_ASSERT(src1->type == GGML_TYPE_BF16);
9368
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
9369
+
9370
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
9371
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
9372
+
9373
+ // rows per thread
9374
+ const int dr = (nr + nth - 1)/nth;
9375
+
9376
+ // row range for this thread
9377
+ const int ir0 = dr*ith;
9378
+ const int ir1 = MIN(ir0 + dr, nr);
9379
+
9380
+ for (int ir = ir0; ir < ir1; ++ir) {
9381
+ // src0 and dst are same shape => same indices
9382
+ const int i3 = ir/(ne2*ne1);
9383
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9384
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9385
+
9386
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
9387
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9388
+ for (int i = 0; i < ne0; i++) {
9389
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
9390
+ }
9391
+ }
9392
+ }
9393
+
8517
9394
  static void ggml_compute_forward_add1(
8518
9395
  const struct ggml_compute_params * params,
8519
9396
  struct ggml_tensor * dst) {
@@ -8538,6 +9415,18 @@ static void ggml_compute_forward_add1(
8538
9415
  GGML_ASSERT(false);
8539
9416
  }
8540
9417
  } break;
9418
+ case GGML_TYPE_BF16:
9419
+ {
9420
+ if (src1->type == GGML_TYPE_BF16) {
9421
+ ggml_compute_forward_add1_bf16_bf16(params, dst);
9422
+ }
9423
+ else if (src1->type == GGML_TYPE_F32) {
9424
+ ggml_compute_forward_add1_bf16_f32(params, dst);
9425
+ }
9426
+ else {
9427
+ GGML_ASSERT(false);
9428
+ }
9429
+ } break;
8541
9430
  case GGML_TYPE_Q4_0:
8542
9431
  case GGML_TYPE_Q4_1:
8543
9432
  case GGML_TYPE_Q5_0:
@@ -8666,6 +9555,7 @@ static void ggml_compute_forward_acc(
8666
9555
  ggml_compute_forward_acc_f32(params, dst);
8667
9556
  } break;
8668
9557
  case GGML_TYPE_F16:
9558
+ case GGML_TYPE_BF16:
8669
9559
  case GGML_TYPE_Q4_0:
8670
9560
  case GGML_TYPE_Q4_1:
8671
9561
  case GGML_TYPE_Q5_0:
@@ -9187,6 +10077,40 @@ static void ggml_compute_forward_sum_f16(
9187
10077
  ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
9188
10078
  }
9189
10079
 
10080
+ static void ggml_compute_forward_sum_bf16(
10081
+ const struct ggml_compute_params * params,
10082
+ struct ggml_tensor * dst) {
10083
+
10084
+ const struct ggml_tensor * src0 = dst->src[0];
10085
+
10086
+ assert(params->ith == 0);
10087
+ assert(ggml_is_scalar(dst));
10088
+
10089
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
10090
+ return;
10091
+ }
10092
+
10093
+ assert(src0->nb[0] == sizeof(ggml_bf16_t));
10094
+
10095
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
10096
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
10097
+
10098
+ float sum = 0;
10099
+ float row_sum = 0;
10100
+
10101
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
10102
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
10103
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
10104
+ ggml_vec_sum_bf16_ggf(ne00,
10105
+ &row_sum,
10106
+ (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
10107
+ sum += row_sum;
10108
+ }
10109
+ }
10110
+ }
10111
+ ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
10112
+ }
10113
+
9190
10114
  static void ggml_compute_forward_sum(
9191
10115
  const struct ggml_compute_params * params,
9192
10116
  struct ggml_tensor * dst) {
@@ -9202,6 +10126,10 @@ static void ggml_compute_forward_sum(
9202
10126
  {
9203
10127
  ggml_compute_forward_sum_f16(params, dst);
9204
10128
  } break;
10129
+ case GGML_TYPE_BF16:
10130
+ {
10131
+ ggml_compute_forward_sum_bf16(params, dst);
10132
+ } break;
9205
10133
  default:
9206
10134
  {
9207
10135
  GGML_ASSERT(false);
@@ -9476,6 +10404,7 @@ static void ggml_compute_forward_repeat(
9476
10404
 
9477
10405
  switch (src0->type) {
9478
10406
  case GGML_TYPE_F16:
10407
+ case GGML_TYPE_BF16:
9479
10408
  case GGML_TYPE_I16:
9480
10409
  {
9481
10410
  ggml_compute_forward_repeat_f16(params, dst);
@@ -11793,6 +12722,7 @@ static void ggml_compute_forward_set(
11793
12722
  ggml_compute_forward_set_f32(params, dst);
11794
12723
  } break;
11795
12724
  case GGML_TYPE_F16:
12725
+ case GGML_TYPE_BF16:
11796
12726
  case GGML_TYPE_Q4_0:
11797
12727
  case GGML_TYPE_Q4_1:
11798
12728
  case GGML_TYPE_Q5_0:
@@ -11967,6 +12897,49 @@ static void ggml_compute_forward_get_rows_f16(
11967
12897
  }
11968
12898
  }
11969
12899
 
12900
+ static void ggml_compute_forward_get_rows_bf16(
12901
+ const struct ggml_compute_params * params,
12902
+ struct ggml_tensor * dst) {
12903
+
12904
+ const struct ggml_tensor * src0 = dst->src[0];
12905
+ const struct ggml_tensor * src1 = dst->src[1];
12906
+
12907
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
12908
+ return;
12909
+ }
12910
+
12911
+ GGML_TENSOR_BINARY_OP_LOCALS
12912
+
12913
+ const int64_t nc = ne00;
12914
+ const int64_t nr = ggml_nelements(src1);
12915
+
12916
+ assert(ne0 == nc);
12917
+ assert(ne02 == ne11);
12918
+ assert(nb00 == sizeof(ggml_bf16_t));
12919
+ assert(ggml_nrows(dst) == nr);
12920
+
12921
+ const int ith = params->ith;
12922
+ const int nth = params->nth;
12923
+
12924
+ // rows per thread
12925
+ const int dr = (nr + nth - 1)/nth;
12926
+
12927
+ // row range for this thread
12928
+ const int ir0 = dr*ith;
12929
+ const int ir1 = MIN(ir0 + dr, nr);
12930
+
12931
+ for (int64_t i = ir0; i < ir1; ++i) {
12932
+ const int64_t i12 = i/(ne11*ne10);
12933
+ const int64_t i11 = (i - i12*ne11*ne10)/ne10;
12934
+ const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
12935
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
12936
+
12937
+ ggml_bf16_to_fp32_row(
12938
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
12939
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
12940
+ }
12941
+ }
12942
+
11970
12943
  static void ggml_compute_forward_get_rows_f32(
11971
12944
  const struct ggml_compute_params * params,
11972
12945
  struct ggml_tensor * dst) {
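Note: for BF16 the row lookup is a gather of whole rows followed by a widen to f32, using the row converter added earlier. Conceptually (the helper name and flattened signature are illustrative; ggml_bf16_to_fp32_row and the types are real, assuming the ggml headers):

    static void gather_rows_bf16(const ggml_bf16_t * weights, int64_t ncols,
                                 const int32_t * indices, int64_t nrows, float * out) {
        for (int64_t r = 0; r < nrows; ++r) {
            ggml_bf16_to_fp32_row(weights + (size_t) indices[r] * ncols,
                                  out + r * ncols, ncols);
        }
    }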
@@ -12044,6 +13017,10 @@ static void ggml_compute_forward_get_rows(
12044
13017
  {
12045
13018
  ggml_compute_forward_get_rows_f16(params, dst);
12046
13019
  } break;
13020
+ case GGML_TYPE_BF16:
13021
+ {
13022
+ ggml_compute_forward_get_rows_bf16(params, dst);
13023
+ } break;
12047
13024
  case GGML_TYPE_F32:
12048
13025
  case GGML_TYPE_I32:
12049
13026
  {
@@ -12739,6 +13716,7 @@ static void ggml_compute_forward_alibi(
12739
13716
  {
12740
13717
  ggml_compute_forward_alibi_f32(params, dst);
12741
13718
  } break;
13719
+ case GGML_TYPE_BF16:
12742
13720
  case GGML_TYPE_Q4_0:
12743
13721
  case GGML_TYPE_Q4_1:
12744
13722
  case GGML_TYPE_Q5_0:
@@ -12828,6 +13806,7 @@ static void ggml_compute_forward_clamp(
12828
13806
  ggml_compute_forward_clamp_f32(params, dst);
12829
13807
  } break;
12830
13808
  case GGML_TYPE_F16:
13809
+ case GGML_TYPE_BF16:
12831
13810
  case GGML_TYPE_Q4_0:
12832
13811
  case GGML_TYPE_Q4_1:
12833
13812
  case GGML_TYPE_Q5_0:
@@ -15921,6 +16900,7 @@ static void ggml_compute_forward_get_rel_pos(
15921
16900
 
15922
16901
  switch (src0->type) {
15923
16902
  case GGML_TYPE_F16:
16903
+ case GGML_TYPE_BF16:
15924
16904
  {
15925
16905
  ggml_compute_forward_get_rel_pos_f16(params, dst);
15926
16906
  } break;
@@ -18785,7 +19765,10 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18785
19765
  case GGML_OP_CPY:
18786
19766
  case GGML_OP_DUP:
18787
19767
  {
18788
- if (ggml_is_quantized(node->type)) {
19768
+ if (ggml_is_quantized(node->type) ||
19769
+ // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
19770
+ (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
19771
+ (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
18789
19772
  cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
18790
19773
  }
18791
19774
  } break;
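Note: this reservation matches the comment above: F16↔BF16 copies, like copies into quantized types, stage each source row as f32 in params->wdata (see the src0_f32 buffer added to the dup kernels), so the planner reserves one f32 row per task. The same condition as a standalone expression; the helper name is illustrative, the functions and types it calls are ggml's own, assuming its headers are in scope:

    static size_t dup_work_size(const struct ggml_tensor * node, int n_tasks) {
        const struct ggml_tensor * s0 = node->src[0];
        const struct ggml_tensor * s1 = node->src[1];
        const bool needs_f32_stage =
            ggml_is_quantized(node->type) ||
            (s0->type == GGML_TYPE_F16  && s1 && s1->type == GGML_TYPE_BF16) ||
            (s0->type == GGML_TYPE_BF16 && s1 && s1->type == GGML_TYPE_F16);
        return needs_f32_stage ? ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks : 0;
    }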
@@ -18864,7 +19847,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18864
19847
  const int64_t ne10 = node->src[1]->ne[0]; // L
18865
19848
  const int64_t ne11 = node->src[1]->ne[1]; // Cin
18866
19849
 
18867
- if (node->src[0]->type == GGML_TYPE_F16 &&
19850
+ if ((node->src[0]->type == GGML_TYPE_F16 ||
19851
+ node->src[0]->type == GGML_TYPE_BF16) &&
18868
19852
  node->src[1]->type == GGML_TYPE_F32) {
18869
19853
  cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
18870
19854
  cur += sizeof(ggml_fp16_t)*ne10*ne11;
@@ -18900,6 +19884,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18900
19884
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18901
19885
  cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
18902
19886
  cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19887
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19888
+ cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19889
+ cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
18903
19890
  }
18904
19891
  } break;
18905
19892
  case GGML_OP_FLASH_ATTN_EXT:
@@ -18916,6 +19903,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18916
19903
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18917
19904
  cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
18918
19905
  cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19906
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19907
+ cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19908
+ cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
18919
19909
  }
18920
19910
  } break;
18921
19911
  case GGML_OP_FLASH_ATTN_BACK:
@@ -18929,6 +19919,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18929
19919
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18930
19920
  cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
18931
19921
  cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
19922
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19923
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
19924
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
18932
19925
  }
18933
19926
  } break;
18934
19927
 
@@ -19705,7 +20698,9 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
19705
20698
  if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
19706
20699
  fprintf(fp, "%d", ggml_get_i32_1d(node, j));
19707
20700
  }
19708
- else if (node->type == GGML_TYPE_F32 || node->type == GGML_TYPE_F16) {
20701
+ else if (node->type == GGML_TYPE_F32 ||
20702
+ node->type == GGML_TYPE_F16 ||
20703
+ node->type == GGML_TYPE_BF16) {
19709
20704
  fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j));
19710
20705
  }
19711
20706
  else {
@@ -20763,6 +21758,12 @@ size_t ggml_quantize_chunk(
20763
21758
  ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
20764
21759
  result = n * elemsize;
20765
21760
  } break;
21761
+ case GGML_TYPE_BF16:
21762
+ {
21763
+ size_t elemsize = sizeof(ggml_bf16_t);
21764
+ ggml_fp32_to_bf16_row(src + start, (ggml_bf16_t *)dst + start, n);
21765
+ result = n * elemsize;
21766
+ } break;
20766
21767
  case GGML_TYPE_F32:
20767
21768
  {
20768
21769
  size_t elemsize = sizeof(float);
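Note: with this case in place, BF16 can be produced through the same generic entry point as every other target type; for a non-quantized type the chunk conversion is just the row converter plus a byte count. A usage sketch under the assumption that ggml_quantize_chunk keeps the (type, src, dst, start, nrows, n_per_row, imatrix) argument order in this release's ggml.h; the imatrix pointer is unused for BF16:

    float src[256];                    // one row of f32 weights, filled elsewhere
    ggml_bf16_t dst[256];
    size_t bytes = ggml_quantize_chunk(GGML_TYPE_BF16, src, dst,
                                       /*start=*/0, /*nrows=*/1, /*n_per_row=*/256,
                                       /*imatrix=*/NULL);
    // expected: bytes == 256 * sizeof(ggml_bf16_t)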
@@ -21139,7 +22140,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
21139
22140
  }
21140
22141
 
21141
22142
  // read the tensor infos
21142
- {
22143
+ if (ctx->header.n_tensors > 0) {
21143
22144
  ctx->infos = GGML_CALLOC(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));
21144
22145
 
21145
22146
  for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {