@fugood/llama.node 0.0.1-alpha.4 → 0.2.0

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. package/CMakeLists.txt +42 -7
  2. package/README.md +10 -0
  3. package/bin/darwin/arm64/default.metallib +0 -0
  4. package/bin/darwin/arm64/llama-node.node +0 -0
  5. package/bin/darwin/x64/default.metallib +0 -0
  6. package/bin/darwin/x64/llama-node.node +0 -0
  7. package/bin/linux/arm64/llama-node.node +0 -0
  8. package/bin/linux/x64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  10. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  11. package/lib/binding.js +1 -1
  12. package/lib/binding.ts +16 -2
  13. package/lib/index.ts +2 -2
  14. package/package.json +15 -3
  15. package/src/DetokenizeWorker.cpp +22 -0
  16. package/src/DetokenizeWorker.h +19 -0
  17. package/src/EmbeddingWorker.cpp +46 -0
  18. package/src/EmbeddingWorker.h +23 -0
  19. package/src/LlamaCompletionWorker.cpp +5 -1
  20. package/src/LlamaCompletionWorker.h +4 -0
  21. package/src/LlamaContext.cpp +80 -1
  22. package/src/LlamaContext.h +3 -0
  23. package/src/TokenizeWorker.cpp +26 -0
  24. package/src/TokenizeWorker.h +23 -0
  25. package/src/common.hpp +12 -7
  26. package/src/llama.cpp/CMakeLists.txt +13 -7
  27. package/src/llama.cpp/common/common.cpp +221 -173
  28. package/src/llama.cpp/common/common.h +19 -8
  29. package/src/llama.cpp/common/json-schema-to-grammar.h +4 -0
  30. package/src/llama.cpp/common/log.h +2 -2
  31. package/src/llama.cpp/common/sampling.cpp +17 -1
  32. package/src/llama.cpp/common/sampling.h +28 -20
  33. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +17 -11
  34. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +5 -5
  35. package/src/llama.cpp/examples/finetune/finetune.cpp +1 -1
  36. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -4
  37. package/src/llama.cpp/examples/imatrix/imatrix.cpp +72 -39
  38. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +42 -3
  39. package/src/llama.cpp/examples/llava/clip.cpp +74 -23
  40. package/src/llama.cpp/examples/llava/llava-cli.cpp +37 -28
  41. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +0 -1
  42. package/src/llama.cpp/examples/lookup/lookup.cpp +0 -1
  43. package/src/llama.cpp/examples/main/main.cpp +10 -8
  44. package/src/llama.cpp/examples/perplexity/perplexity.cpp +175 -55
  45. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  46. package/src/llama.cpp/examples/quantize/quantize.cpp +74 -47
  47. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +1 -1
  48. package/src/llama.cpp/examples/server/server.cpp +97 -86
  49. package/src/llama.cpp/examples/server/utils.hpp +17 -15
  50. package/src/llama.cpp/ggml-backend.c +7 -5
  51. package/src/llama.cpp/ggml-impl.h +339 -4
  52. package/src/llama.cpp/ggml-kompute.cpp +7 -0
  53. package/src/llama.cpp/ggml-opencl.cpp +1 -0
  54. package/src/llama.cpp/ggml-quants.c +302 -293
  55. package/src/llama.cpp/ggml-sycl.cpp +28 -16
  56. package/src/llama.cpp/ggml-vulkan-shaders.hpp +46843 -39205
  57. package/src/llama.cpp/ggml-vulkan.cpp +951 -263
  58. package/src/llama.cpp/ggml.c +1469 -116
  59. package/src/llama.cpp/ggml.h +37 -7
  60. package/src/llama.cpp/llama.cpp +969 -432
  61. package/src/llama.cpp/llama.h +46 -14
  62. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf-update.txt +2 -0
  63. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf.txt +0 -1
  64. package/src/llama.cpp/requirements/requirements-convert.txt +2 -2
  65. package/src/llama.cpp/requirements.txt +1 -0
  66. package/src/llama.cpp/sgemm.cpp +134 -103
  67. package/src/llama.cpp/sgemm.h +4 -2
  68. package/src/llama.cpp/tests/CMakeLists.txt +96 -36
  69. package/src/llama.cpp/tests/test-backend-ops.cpp +56 -6
  70. package/src/llama.cpp/tests/test-chat-template.cpp +4 -0
  71. package/src/llama.cpp/tests/test-grammar-integration.cpp +225 -136
  72. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +1 -0
  73. package/src/llama.cpp/tests/test-tokenizer-0.cpp +292 -0
  74. package/src/llama.cpp/tests/{test-tokenizer-1-llama.cpp → test-tokenizer-1-spm.cpp} +1 -1
  75. package/src/llama.cpp/unicode-data.cpp +1188 -656
  76. package/src/llama.cpp/unicode-data.h +4 -3
  77. package/src/llama.cpp/unicode.cpp +590 -49
  78. package/src/llama.cpp/unicode.h +6 -3
  79. package/bin/win32/arm64/llama-node.node +0 -0
  80. package/bin/win32/arm64/node.lib +0 -0
  81. package/bin/win32/x64/llama-node.node +0 -0
  82. package/bin/win32/x64/node.lib +0 -0
  83. package/src/llama.cpp/tests/test-tokenizer-0-falcon.cpp +0 -187
  84. package/src/llama.cpp/tests/test-tokenizer-0-llama.cpp +0 -190
@@ -322,7 +322,7 @@ static ggml_fp16_t ggml_table_exp_f16[1 << 16];
  // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
  float ggml_table_f32_f16[1 << 16];

- const char * ggml_status_to_string(enum ggml_status status) {
+ GGML_CALL const char * ggml_status_to_string(enum ggml_status status) {
  switch (status) {
  case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
  case GGML_STATUS_FAILED: return "GGML status: error (operation failed)";
@@ -333,16 +333,26 @@ const char * ggml_status_to_string(enum ggml_status status) {
  return "GGML status: unknown";
  }

- // note: do not use these inside ggml.c
- // these are meant to be used via the ggml.h API
  float ggml_fp16_to_fp32(ggml_fp16_t x) {
+ #define ggml_fp16_to_fp32 do_not_use__ggml_fp16_to_fp32__in_ggml
  return GGML_FP16_TO_FP32(x);
  }

  ggml_fp16_t ggml_fp32_to_fp16(float x) {
+ #define ggml_fp32_to_fp16 do_not_use__ggml_fp32_to_fp16__in_ggml
  return GGML_FP32_TO_FP16(x);
  }

+ float ggml_bf16_to_fp32(ggml_bf16_t x) {
+ #define ggml_bf16_to_fp32 do_not_use__ggml_bf16_to_fp32__in_ggml
+ return GGML_BF16_TO_FP32(x); // it just left shifts
+ }
+
+ ggml_bf16_t ggml_fp32_to_bf16(float x) {
+ #define ggml_fp32_to_bf16 do_not_use__ggml_fp32_to_bf16__in_ggml
+ return GGML_FP32_TO_BF16(x);
+ }
+
  void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
  for (int64_t i = 0; i < n; i++) {
  y[i] = GGML_FP16_TO_FP32(x[i]);
@@ -368,6 +378,49 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
  }
  }

+ void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
+ int64_t i = 0;
+ #if defined(__AVX512F__)
+ for (; i + 16 <= n; i += 16) {
+ _mm512_storeu_ps(y + i,
+ _mm512_castsi512_ps(
+ _mm512_slli_epi32(
+ _mm512_cvtepu16_epi32(
+ _mm256_loadu_si256(
+ (const __m256i *)(x + i))),
+ 16)));
+ }
+ #elif defined(__AVX2__)
+ for (; i + 8 <= n; i += 8) {
+ _mm256_storeu_ps(y + i,
+ _mm256_castsi256_ps(
+ _mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(
+ _mm_loadu_si128(
+ (const __m128i *)(x + i))),
+ 16)));
+ }
+ #endif
+ for (; i < n; i++) {
+ y[i] = GGML_BF16_TO_FP32(x[i]);
+ }
+ }
+
+ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
+ int i = 0;
+ #if defined(__AVX512BF16__)
+ for (; i + 32 <= n; i += 32) {
+ _mm512_storeu_ps(
+ (__m512 *)(y + i),
+ (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
+ _mm512_loadu_ps(x + i)));
+ }
+ #endif
+ for (; i < n; i++) {
+ y[i] = GGML_FP32_TO_BF16(x[i]);
+ }
+ }
+
  bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
  return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
  }
@@ -503,6 +556,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);

  static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
  static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
+ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);

  static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
  [GGML_TYPE_I8] = {
@@ -845,6 +899,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
  .type_size = sizeof(block_q8_K),
  .is_quantized = true,
  .from_float = quantize_row_q8_K,
+ },
+ [GGML_TYPE_BF16] = {
+ .type_name = "bf16",
+ .blck_size = 1,
+ .type_size = sizeof(ggml_bf16_t),
+ .is_quantized = false,
+ .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row,
+ .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row,
+ .from_float_reference = (ggml_from_float_t) ggml_fp32_to_bf16_row,
+ .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
+ .vec_dot_type = GGML_TYPE_BF16,
+ .nrows = 1,
  }
  };

@@ -858,18 +924,6 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
  // simd mappings
  //

- #if defined(__ARM_NEON)
- #if !defined(__aarch64__)
-
- // 64-bit compatibility
-
- inline static float vaddvq_f32(float32x4_t v) {
- return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
- }
-
- #endif
- #endif
-
  // we define a common set of C macros which map to specific intrinsics based on the current architecture
  // we then implement the fundamental computation operations below using only these macros
  // adding support for new architectures requires to define the corresponding SIMD macros
@@ -963,7 +1017,7 @@ inline static float vaddvq_f32(float32x4_t v) {
  #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
  #define GGML_F16_VEC_SET1 GGML_F16x8_SET1
  #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p)
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), r[i])
  #define GGML_F16_VEC_FMA GGML_F16x8_FMA
  #define GGML_F16_VEC_ADD GGML_F16x8_ADD
  #define GGML_F16_VEC_MUL GGML_F16x8_MUL
@@ -989,7 +1043,7 @@ inline static float vaddvq_f32(float32x4_t v) {
  #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
  #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
  #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i])
  #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
  #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
  #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
@@ -1058,7 +1112,7 @@ do { \

  // unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
  // so F16C guard isn't required
- #define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x)))
+ #define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
  #define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))

  #define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
@@ -1156,7 +1210,7 @@ do { \

  #if defined(__F16C__)
  // the _mm256_cvt intrinsics require F16C
- #define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
+ #define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
  #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
  #else
  static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
@@ -1492,6 +1546,8 @@ inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) {

  inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }

+ inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
  inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
  inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
  inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
@@ -1510,7 +1566,7 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float *
  UNUSED(by);
  UNUSED(bs);

- #ifdef GGML_SIMD
+ #if defined(GGML_SIMD)
  float sumf = 0.0f;
  const int np = (n & ~(GGML_F32_STEP - 1));

@@ -1546,6 +1602,70 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float *
  *s = sumf;
  }

+ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) {
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+ int i = 0;
+ ggml_float sumf = 0;
+
+ #if defined(__AVX512BF16__)
+ __m512 c1 = _mm512_setzero_ps();
+ __m512 c2 = _mm512_setzero_ps();
+ for (; i + 64 <= n; i += 64) {
+ c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
+ (__m512bh)_mm512_loadu_ps((const float *)(y + i)));
+ c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
+ (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
+ }
+ sumf += (ggml_float)_mm512_reduce_add_ps(c1);
+ sumf += (ggml_float)_mm512_reduce_add_ps(c2);
+
+ #elif defined(__AVX512F__)
+ #define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
+ __m512 c1 = _mm512_setzero_ps();
+ __m512 c2 = _mm512_setzero_ps();
+ for (; i + 32 <= n; i += 32) {
+ c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
+ c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
+ }
+ sumf += (ggml_float)_mm512_reduce_add_ps(c1);
+ sumf += (ggml_float)_mm512_reduce_add_ps(c2);
+
+ #undef LOAD
+ #elif defined(__AVX2__)
+ #define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
+ __m256 c1 = _mm256_setzero_ps();
+ __m256 c2 = _mm256_setzero_ps();
+ __m256 c3 = _mm256_setzero_ps();
+ __m256 c4 = _mm256_setzero_ps();
+ for (; i + 32 <= n; i += 32) {
+ c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
+ c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
+ c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
+ c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
+ }
+ __m128 g;
+ c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
+ _mm256_add_ps(c2, c4));
+ g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
+ _mm256_castps256_ps128(c1));
+ g = _mm_add_ps(g, _mm_movehl_ps(g, g));
+ g = _mm_add_ss(g, _mm_movehdup_ps(g));
+ sumf += (ggml_float)_mm_cvtss_f32(g);
+
+ #undef LOAD
+ #endif
+
+ for (; i < n; ++i) {
+ sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
+ GGML_BF16_TO_FP32(y[i]));
+ }
+ *s = sumf;
+ }
+
  static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
  assert(nrc == 1);
  UNUSED(nrc);
@@ -1674,6 +1794,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
  #endif
  }

+ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
+ #if defined(GGML_SIMD)
+ const int np = (n & ~(GGML_F16_STEP - 1));
+
+ GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+ GGML_F16_VEC ax[GGML_F16_ARR];
+ GGML_F16_VEC ay[GGML_F16_ARR];
+
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
+ for (int j = 0; j < GGML_F16_ARR; j++) {
+ ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+ ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+
+ GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+ }
+ }
+
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+ }
+ #else
+ // scalar
+ for (int i = 0; i < n; ++i) {
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+ }
+ #endif
+ }
+
  // xs and vs are byte strides of x and v
  inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {

@@ -1758,6 +1909,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
  #endif
  }

+ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
+ #if defined(GGML_SIMD)
+ const int np = (n & ~(GGML_F16_STEP - 1));
+
+ GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+ GGML_F16_VEC ay[GGML_F16_ARR];
+
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
+ for (int j = 0; j < GGML_F16_ARR; j++) {
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+ ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
+
+ GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+ }
+ }
+
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+ }
+ #else
+ // scalar
+ for (int i = 0; i < n; ++i) {
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+ }
+ #endif
+ }
+
  inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
  inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
  inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
@@ -1919,6 +2099,14 @@ inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_
  *s = sum;
  }

+ inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {
+ float sum = 0.0f;
+ for (int i = 0; i < n; ++i) {
+ sum += GGML_BF16_TO_FP32(x[i]);
+ }
+ *s = sum;
+ }
+
  inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
  #ifndef GGML_USE_ACCELERATE
  float max = -INFINITY;
@@ -2012,6 +2200,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
  "LEAKY_RELU",

  "FLASH_ATTN",
+ "FLASH_ATTN_EXT",
  "FLASH_FF",
  "FLASH_ATTN_BACK",
  "SSM_CONV",
@@ -2038,7 +2227,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
  "CROSS_ENTROPY_LOSS_BACK",
  };

- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+ static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");

  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "none",
@@ -2102,6 +2291,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "leaky_relu(x)",

  "flash_attn(x)",
+ "flash_attn_ext(x)",
  "flash_ff(x)",
  "flash_attn_back(x)",
  "ssm_conv(x)",
@@ -2128,7 +2318,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "cross_entropy_loss_back(x,y)",
  };

- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+ static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");

  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -2327,7 +2517,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
  // figure out which node we're on
  uint current_cpu;
  int getcpu_ret = 0;
- #if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28)
+ #if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__)
  getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node);
  #else
  // old glibc doesn't have a wrapper for this call. Fall back on direct syscall
@@ -2538,6 +2728,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
  switch (ftype) {
  case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break;
  case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break;
+ case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break;
  case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break;
  case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break;
  case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
@@ -2679,15 +2870,16 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
  {
  const uint64_t t_start = ggml_time_us(); UNUSED(t_start);

- ggml_fp16_t ii;
  for (int i = 0; i < (1 << 16); ++i) {
- uint16_t ui = i;
- memcpy(&ii, &ui, sizeof(ii));
- const float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
+ union {
+ uint16_t u16;
+ ggml_fp16_t fp16;
+ } u = {i};
+ float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
  ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
  ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
  ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
- ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
+ ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
  }

  const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
@@ -3151,6 +3343,13 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
3151
3343
  ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
3152
3344
  }
3153
3345
  } break;
3346
+ case GGML_TYPE_BF16:
3347
+ {
3348
+ assert(tensor->nb[0] == sizeof(ggml_fp16_t));
3349
+ for (int i = 0; i < n; i++) {
3350
+ ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
3351
+ }
3352
+ } break;
3154
3353
  case GGML_TYPE_F32:
3155
3354
  {
3156
3355
  assert(tensor->nb[0] == sizeof(float));
@@ -3203,6 +3402,13 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3203
3402
  ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
3204
3403
  }
3205
3404
  } break;
3405
+ case GGML_TYPE_BF16:
3406
+ {
3407
+ assert(tensor->nb[0] == sizeof(ggml_bf16_t));
3408
+ for (int i = 0; i < n; i++) {
3409
+ ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
3410
+ }
3411
+ } break;
3206
3412
  case GGML_TYPE_F32:
3207
3413
  {
3208
3414
  assert(tensor->nb[0] == sizeof(float));
@@ -3270,6 +3476,11 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3270
3476
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3271
3477
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
3272
3478
  }
3479
+ case GGML_TYPE_BF16:
3480
+ {
3481
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3482
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
3483
+ }
3273
3484
  case GGML_TYPE_F32:
3274
3485
  {
3275
3486
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3312,6 +3523,11 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3312
3523
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3313
3524
  ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
3314
3525
  } break;
3526
+ case GGML_TYPE_BF16:
3527
+ {
3528
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3529
+ ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
3530
+ } break;
3315
3531
  case GGML_TYPE_F32:
3316
3532
  {
3317
3533
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3335,6 +3551,8 @@ int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i
3335
3551
  return ((int32_t *) data)[0];
3336
3552
  case GGML_TYPE_F16:
3337
3553
  return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
3554
+ case GGML_TYPE_BF16:
3555
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
3338
3556
  case GGML_TYPE_F32:
3339
3557
  return ((float *) data)[0];
3340
3558
  default:
@@ -3363,6 +3581,10 @@ void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3363
3581
  {
3364
3582
  ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
3365
3583
  } break;
3584
+ case GGML_TYPE_BF16:
3585
+ {
3586
+ ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
3587
+ } break;
3366
3588
  case GGML_TYPE_F32:
3367
3589
  {
3368
3590
  ((float *)(data))[0] = value;
@@ -3401,6 +3623,11 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3401
3623
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3402
3624
  return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
3403
3625
  }
3626
+ case GGML_TYPE_BF16:
3627
+ {
3628
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3629
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
3630
+ }
3404
3631
  case GGML_TYPE_F32:
3405
3632
  {
3406
3633
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3443,6 +3670,11 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
3443
3670
  GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
3444
3671
  ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
3445
3672
  } break;
3673
+ case GGML_TYPE_BF16:
3674
+ {
3675
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
3676
+ ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
3677
+ } break;
3446
3678
  case GGML_TYPE_F32:
3447
3679
  {
3448
3680
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
@@ -3466,6 +3698,8 @@ float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3466
3698
  return ((int32_t *) data)[0];
3467
3699
  case GGML_TYPE_F16:
3468
3700
  return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
3701
+ case GGML_TYPE_BF16:
3702
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
3469
3703
  case GGML_TYPE_F32:
3470
3704
  return ((float *) data)[0];
3471
3705
  default:
@@ -3494,6 +3728,10 @@ void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2,
3494
3728
  {
3495
3729
  ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
3496
3730
  } break;
3731
+ case GGML_TYPE_BF16:
3732
+ {
3733
+ ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
3734
+ } break;
3497
3735
  case GGML_TYPE_F32:
3498
3736
  {
3499
3737
  ((float *)(data))[0] = value;
@@ -3688,7 +3926,11 @@ static struct ggml_tensor * ggml_add_cast_impl(
3688
3926
  // TODO: support less-strict constraint
3689
3927
  // GGML_ASSERT(ggml_can_repeat(b, a));
3690
3928
  GGML_ASSERT(ggml_can_repeat_rows(b, a));
3691
- GGML_ASSERT(ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16); // currently only supported for quantized input and f16
3929
+
3930
+ // currently only supported for quantized input and f16
3931
+ GGML_ASSERT(ggml_is_quantized(a->type) ||
3932
+ a->type == GGML_TYPE_F16 ||
3933
+ a->type == GGML_TYPE_BF16);
3692
3934
 
3693
3935
  bool is_node = false;
3694
3936
 
@@ -4571,6 +4813,8 @@ struct ggml_tensor * ggml_mul_mat(
4571
4813
  void ggml_mul_mat_set_prec(
4572
4814
  struct ggml_tensor * a,
4573
4815
  enum ggml_prec prec) {
4816
+ GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
4817
+
4574
4818
  const int32_t prec_i32 = (int32_t) prec;
4575
4819
 
4576
4820
  ggml_set_op_params_i32(a, 0, prec_i32);
@@ -5409,17 +5653,23 @@ static struct ggml_tensor * ggml_soft_max_impl(
5409
5653
  GGML_ASSERT(ggml_is_contiguous(a));
5410
5654
 
5411
5655
  if (mask) {
5656
+ GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
5412
5657
  GGML_ASSERT(ggml_is_contiguous(mask));
5413
5658
  GGML_ASSERT(ggml_is_matrix(mask));
5414
- GGML_ASSERT(ggml_can_repeat_rows(mask, a));
5659
+ GGML_ASSERT(mask->ne[0] == a->ne[0]);
5660
+ GGML_ASSERT(mask->ne[1] >= a->ne[1]);
5415
5661
  }
5416
5662
 
5417
5663
  if (pos) {
5418
5664
  GGML_ASSERT(ggml_is_vector(pos));
5419
- GGML_ASSERT(pos->type == GGML_TYPE_F32);
5665
+ GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
5420
5666
  GGML_ASSERT(pos->ne[0] == a->ne[0]);
5421
5667
  }
5422
5668
 
5669
+ if (pos && mask) {
5670
+ GGML_ASSERT(pos->type == mask->type);
5671
+ }
5672
+
5423
5673
  if (max_bias > 0.0f) {
5424
5674
  GGML_ASSERT(pos);
5425
5675
  }
@@ -6228,6 +6478,59 @@ struct ggml_tensor * ggml_flash_attn(
6228
6478
  return result;
6229
6479
  }
6230
6480
 
6481
+ // ggml_flash_attn_ext
6482
+
6483
+ struct ggml_tensor * ggml_flash_attn_ext(
6484
+ struct ggml_context * ctx,
6485
+ struct ggml_tensor * q,
6486
+ struct ggml_tensor * k,
6487
+ struct ggml_tensor * v,
6488
+ struct ggml_tensor * mask,
6489
+ float scale) {
6490
+ GGML_ASSERT(ggml_can_mul_mat(k, q));
6491
+ // TODO: check if vT can be multiplied by (k*qT)
6492
+ if (mask) {
6493
+ GGML_ASSERT(ggml_is_contiguous(mask));
6494
+ GGML_ASSERT(mask->ne[2] == 1);
6495
+ GGML_ASSERT(mask->ne[3] == 1);
6496
+ GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
6497
+ "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
6498
+ //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
6499
+ }
6500
+
6501
+ bool is_node = false;
6502
+
6503
+ if (q->grad || k->grad || v->grad) {
6504
+ is_node = true;
6505
+ }
6506
+
6507
+ // permute(0, 2, 1, 3)
6508
+ int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
6509
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
6510
+
6511
+ float params[] = { scale };
6512
+ ggml_set_op_params(result, params, sizeof(params));
6513
+
6514
+ result->op = GGML_OP_FLASH_ATTN_EXT;
6515
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6516
+ result->src[0] = q;
6517
+ result->src[1] = k;
6518
+ result->src[2] = v;
6519
+ result->src[3] = mask;
6520
+
6521
+ return result;
6522
+ }
6523
+
6524
+ void ggml_flash_attn_ext_set_prec(
6525
+ struct ggml_tensor * a,
6526
+ enum ggml_prec prec) {
6527
+ GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
6528
+
6529
+ const int32_t prec_i32 = (int32_t) prec;
6530
+
6531
+ ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
6532
+ }
6533
+
6231
6534
  // ggml_flash_ff
6232
6535
 
6233
6536
  struct ggml_tensor * ggml_flash_ff(
@@ -7104,8 +7407,8 @@ static void ggml_compute_forward_dup_same_cont(
7104
7407
  ((char *) src0->data + ie0*nb00),
7105
7408
  (ie1 - ie0) * ggml_type_size(src0->type));
7106
7409
  }
7107
-
7108
7410
  }
7411
+
7109
7412
  static void ggml_compute_forward_dup_f16(
7110
7413
  const struct ggml_compute_params * params,
7111
7414
  struct ggml_tensor * dst) {
@@ -7379,7 +7682,7 @@ static void ggml_compute_forward_dup_f16(
7379
7682
  }
7380
7683
  }
7381
7684
 
7382
- static void ggml_compute_forward_dup_f32(
7685
+ static void ggml_compute_forward_dup_bf16(
7383
7686
  const struct ggml_compute_params * params,
7384
7687
  struct ggml_tensor * dst) {
7385
7688
 
@@ -7427,10 +7730,11 @@ static void ggml_compute_forward_dup_f32(
7427
7730
  return;
7428
7731
  }
7429
7732
 
7733
+ // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
7734
+
7430
7735
  if (ggml_is_contiguous(dst)) {
7431
- // TODO: simplify
7432
- if (nb00 == sizeof(float)) {
7433
- if (dst->type == GGML_TYPE_F32) {
7736
+ if (nb00 == sizeof(ggml_bf16_t)) {
7737
+ if (dst->type == GGML_TYPE_BF16) {
7434
7738
  size_t id = 0;
7435
7739
  const size_t rs = ne00 * nb00;
7436
7740
  char * dst_ptr = (char *) dst->data;
@@ -7446,8 +7750,43 @@ static void ggml_compute_forward_dup_f32(
7446
7750
  id += rs * (ne01 - ir1);
7447
7751
  }
7448
7752
  }
7753
+ } else if (dst->type == GGML_TYPE_F16) {
7754
+ size_t id = 0;
7755
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
7756
+
7757
+ for (int i03 = 0; i03 < ne03; i03++) {
7758
+ for (int i02 = 0; i02 < ne02; i02++) {
7759
+ id += ne00 * ir0;
7760
+ for (int i01 = ir0; i01 < ir1; i01++) {
7761
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7762
+ for (int i00 = 0; i00 < ne00; i00++) {
7763
+ dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
7764
+ id++;
7765
+ }
7766
+ }
7767
+ id += ne00 * (ne01 - ir1);
7768
+ }
7769
+ }
7770
+ } else if (dst->type == GGML_TYPE_F32) {
7771
+ size_t id = 0;
7772
+ float * dst_ptr = (float *) dst->data;
7773
+
7774
+ for (int i03 = 0; i03 < ne03; i03++) {
7775
+ for (int i02 = 0; i02 < ne02; i02++) {
7776
+ id += ne00 * ir0;
7777
+ for (int i01 = ir0; i01 < ir1; i01++) {
7778
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7779
+ for (int i00 = 0; i00 < ne00; i00++) {
7780
+ dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
7781
+ id++;
7782
+ }
7783
+ }
7784
+ id += ne00 * (ne01 - ir1);
7785
+ }
7786
+ }
7449
7787
  } else if (type_traits[dst->type].from_float) {
7450
7788
  ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
7789
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
7451
7790
 
7452
7791
  size_t id = 0;
7453
7792
  size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
@@ -7457,8 +7796,13 @@ static void ggml_compute_forward_dup_f32(
7457
7796
  for (int i02 = 0; i02 < ne02; i02++) {
7458
7797
  id += rs * ir0;
7459
7798
  for (int i01 = ir0; i01 < ir1; i01++) {
7460
- const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7461
- quantize_row_q(src0_ptr, dst_ptr + id, ne00);
7799
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
7800
+
7801
+ for (int i00 = 0; i00 < ne00; i00++) {
7802
+ src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
7803
+ }
7804
+
7805
+ quantize_row_q(src0_f32, dst_ptr + id, ne00);
7462
7806
  id += rs;
7463
7807
  }
7464
7808
  id += rs * (ne01 - ir1);
@@ -7479,7 +7823,25 @@ static void ggml_compute_forward_dup_f32(
7479
7823
  id += ne00 * ir0;
7480
7824
  for (int i01 = ir0; i01 < ir1; i01++) {
7481
7825
  for (int i00 = 0; i00 < ne00; i00++) {
7482
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7826
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7827
+
7828
+ dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
7829
+ id++;
7830
+ }
7831
+ }
7832
+ id += ne00 * (ne01 - ir1);
7833
+ }
7834
+ }
7835
+ } else if (dst->type == GGML_TYPE_BF16) {
7836
+ size_t id = 0;
7837
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
7838
+
7839
+ for (int i03 = 0; i03 < ne03; i03++) {
7840
+ for (int i02 = 0; i02 < ne02; i02++) {
7841
+ id += ne00 * ir0;
7842
+ for (int i01 = ir0; i01 < ir1; i01++) {
7843
+ for (int i00 = 0; i00 < ne00; i00++) {
7844
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7483
7845
 
7484
7846
  dst_ptr[id] = *src0_ptr;
7485
7847
  id++;
@@ -7497,9 +7859,9 @@ static void ggml_compute_forward_dup_f32(
7497
7859
  id += ne00 * ir0;
7498
7860
  for (int i01 = ir0; i01 < ir1; i01++) {
7499
7861
  for (int i00 = 0; i00 < ne00; i00++) {
7500
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7862
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7501
7863
 
7502
- dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
7864
+ dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
7503
7865
  id++;
7504
7866
  }
7505
7867
  }
@@ -7510,18 +7872,16 @@ static void ggml_compute_forward_dup_f32(
7510
7872
  GGML_ASSERT(false); // TODO: implement
7511
7873
  }
7512
7874
  }
7513
-
7514
7875
  return;
7515
7876
  }
7516
7877
 
7517
7878
  // dst counters
7518
-
7519
7879
  int64_t i10 = 0;
7520
7880
  int64_t i11 = 0;
7521
7881
  int64_t i12 = 0;
7522
7882
  int64_t i13 = 0;
7523
7883
 
7524
- if (dst->type == GGML_TYPE_F32) {
7884
+ if (dst->type == GGML_TYPE_BF16) {
7525
7885
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7526
7886
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7527
7887
  i10 += ne00 * ir0;
@@ -7542,15 +7902,15 @@ static void ggml_compute_forward_dup_f32(
7542
7902
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7543
7903
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
7544
7904
 
7545
- memcpy(dst_ptr, src0_ptr, sizeof(float));
7905
+ memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
7546
7906
 
7547
- if (++i10 == ne0) {
7907
+ if (++i10 == ne00) {
7548
7908
  i10 = 0;
7549
- if (++i11 == ne1) {
7909
+ if (++i11 == ne01) {
7550
7910
  i11 = 0;
7551
- if (++i12 == ne2) {
7911
+ if (++i12 == ne02) {
7552
7912
  i12 = 0;
7553
- if (++i13 == ne3) {
7913
+ if (++i13 == ne03) {
7554
7914
  i13 = 0;
7555
7915
  }
7556
7916
  }
@@ -7594,7 +7954,7 @@ static void ggml_compute_forward_dup_f32(
7594
7954
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7595
7955
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
7596
7956
 
7597
- *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
7957
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
7598
7958
 
7599
7959
  if (++i10 == ne0) {
7600
7960
  i10 = 0;
@@ -7625,22 +7985,395 @@ static void ggml_compute_forward_dup_f32(
7625
7985
  }
7626
7986
  }
7627
7987
  }
7628
- } else {
7629
- GGML_ASSERT(false); // TODO: implement
7630
- }
7631
- }
7632
-
7633
- // A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
7634
- static void ggml_compute_forward_dup_bytes(
7635
- const struct ggml_compute_params * params,
7636
- struct ggml_tensor * dst) {
7637
-
7638
- const struct ggml_tensor * src0 = dst->src[0];
7639
-
7640
- GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
7641
- GGML_ASSERT(src0->type == dst->type);
7642
-
7643
- if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
7988
+ } else if (dst->type == GGML_TYPE_F32) {
7989
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
7990
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
7991
+ i10 += ne00 * ir0;
7992
+ while (i10 >= ne0) {
7993
+ i10 -= ne0;
7994
+ if (++i11 == ne1) {
7995
+ i11 = 0;
7996
+ if (++i12 == ne2) {
7997
+ i12 = 0;
7998
+ if (++i13 == ne3) {
7999
+ i13 = 0;
8000
+ }
8001
+ }
8002
+ }
8003
+ }
8004
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8005
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8006
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8007
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8008
+
8009
+ *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
8010
+
8011
+ if (++i10 == ne0) {
8012
+ i10 = 0;
8013
+ if (++i11 == ne1) {
8014
+ i11 = 0;
8015
+ if (++i12 == ne2) {
8016
+ i12 = 0;
8017
+ if (++i13 == ne3) {
8018
+ i13 = 0;
8019
+ }
8020
+ }
8021
+ }
8022
+ }
8023
+ }
8024
+ }
8025
+ i10 += ne00 * (ne01 - ir1);
8026
+ while (i10 >= ne0) {
8027
+ i10 -= ne0;
8028
+ if (++i11 == ne1) {
8029
+ i11 = 0;
8030
+ if (++i12 == ne2) {
8031
+ i12 = 0;
8032
+ if (++i13 == ne3) {
8033
+ i13 = 0;
8034
+ }
8035
+ }
8036
+ }
8037
+ }
8038
+ }
8039
+ }
8040
+ } else {
8041
+ GGML_ASSERT(false); // TODO: implement
8042
+ }
8043
+ }
8044
+
8045
+ static void ggml_compute_forward_dup_f32(
8046
+ const struct ggml_compute_params * params,
8047
+ struct ggml_tensor * dst) {
8048
+
8049
+ const struct ggml_tensor * src0 = dst->src[0];
8050
+
8051
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
8052
+
8053
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
8054
+ return;
8055
+ }
8056
+
8057
+ GGML_TENSOR_UNARY_OP_LOCALS
8058
+
8059
+ const int ith = params->ith; // thread index
8060
+ const int nth = params->nth; // number of threads
8061
+
8062
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
8063
+ ggml_compute_forward_dup_same_cont(params, dst);
8064
+ return;
8065
+ }
8066
+
8067
+ // parallelize by rows
8068
+ const int nr = ne01;
8069
+ // number of rows per thread
8070
+ const int dr = (nr + nth - 1) / nth;
8071
+ // row range for this thread
8072
+ const int ir0 = dr * ith;
8073
+ const int ir1 = MIN(ir0 + dr, nr);
8074
+
8075
+ if (src0->type == dst->type &&
8076
+ ne00 == ne0 &&
8077
+ nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
8078
+ // copy by rows
8079
+ const size_t rs = ne00*nb00;
8080
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8081
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8082
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8083
+ memcpy(
8084
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
8085
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
8086
+ rs);
8087
+ }
8088
+ }
8089
+ }
8090
+ return;
8091
+ }
8092
+
8093
+ if (ggml_is_contiguous(dst)) {
8094
+ // TODO: simplify
8095
+ if (nb00 == sizeof(float)) {
8096
+ if (dst->type == GGML_TYPE_F32) {
8097
+ size_t id = 0;
8098
+ const size_t rs = ne00 * nb00;
8099
+ char * dst_ptr = (char *) dst->data;
8100
+
8101
+ for (int i03 = 0; i03 < ne03; i03++) {
8102
+ for (int i02 = 0; i02 < ne02; i02++) {
8103
+ id += rs * ir0;
8104
+ for (int i01 = ir0; i01 < ir1; i01++) {
8105
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
8106
+ memcpy(dst_ptr + id, src0_ptr, rs);
8107
+ id += rs;
8108
+ }
8109
+ id += rs * (ne01 - ir1);
8110
+ }
8111
+ }
8112
+ } else if (type_traits[dst->type].from_float) {
8113
+ ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
8114
+
8115
+ size_t id = 0;
8116
+ size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
8117
+ char * dst_ptr = (char *) dst->data;
8118
+
8119
+ for (int i03 = 0; i03 < ne03; i03++) {
8120
+ for (int i02 = 0; i02 < ne02; i02++) {
8121
+ id += rs * ir0;
8122
+ for (int i01 = ir0; i01 < ir1; i01++) {
8123
+ const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
8124
+ quantize_row_q(src0_ptr, dst_ptr + id, ne00);
8125
+ id += rs;
8126
+ }
8127
+ id += rs * (ne01 - ir1);
8128
+ }
8129
+ }
8130
+ } else {
8131
+ GGML_ASSERT(false); // TODO: implement
8132
+ }
8133
+ } else {
8134
+ //printf("%s: this is not optimal - fix me\n", __func__);
8135
+
8136
+ if (dst->type == GGML_TYPE_F32) {
8137
+ size_t id = 0;
8138
+ float * dst_ptr = (float *) dst->data;
8139
+
8140
+ for (int i03 = 0; i03 < ne03; i03++) {
8141
+ for (int i02 = 0; i02 < ne02; i02++) {
8142
+ id += ne00 * ir0;
8143
+ for (int i01 = ir0; i01 < ir1; i01++) {
8144
+ for (int i00 = 0; i00 < ne00; i00++) {
8145
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8146
+
8147
+ dst_ptr[id] = *src0_ptr;
8148
+ id++;
8149
+ }
8150
+ }
8151
+ id += ne00 * (ne01 - ir1);
8152
+ }
8153
+ }
8154
+ } else if (dst->type == GGML_TYPE_F16) {
8155
+ size_t id = 0;
8156
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
8157
+
8158
+ for (int i03 = 0; i03 < ne03; i03++) {
8159
+ for (int i02 = 0; i02 < ne02; i02++) {
8160
+ id += ne00 * ir0;
8161
+ for (int i01 = ir0; i01 < ir1; i01++) {
8162
+ for (int i00 = 0; i00 < ne00; i00++) {
8163
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8164
+
8165
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
8166
+ id++;
8167
+ }
8168
+ }
8169
+ id += ne00 * (ne01 - ir1);
8170
+ }
8171
+ }
8172
+ } else if (dst->type == GGML_TYPE_BF16) {
8173
+ size_t id = 0;
8174
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
8175
+
8176
+ for (int i03 = 0; i03 < ne03; i03++) {
8177
+ for (int i02 = 0; i02 < ne02; i02++) {
8178
+ id += ne00 * ir0;
8179
+ for (int i01 = ir0; i01 < ir1; i01++) {
8180
+ for (int i00 = 0; i00 < ne00; i00++) {
8181
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8182
+
8183
+ dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
8184
+ id++;
8185
+ }
8186
+ }
8187
+ id += ne00 * (ne01 - ir1);
8188
+ }
8189
+ }
8190
+ } else {
8191
+ GGML_ASSERT(false); // TODO: implement
8192
+ }
8193
+ }
8194
+
8195
+ return;
8196
+ }
8197
+
8198
+ // dst counters
8199
+
8200
+ int64_t i10 = 0;
8201
+ int64_t i11 = 0;
8202
+ int64_t i12 = 0;
8203
+ int64_t i13 = 0;
8204
+
8205
+ if (dst->type == GGML_TYPE_F32) {
8206
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8207
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8208
+ i10 += ne00 * ir0;
8209
+ while (i10 >= ne0) {
8210
+ i10 -= ne0;
8211
+ if (++i11 == ne1) {
8212
+ i11 = 0;
8213
+ if (++i12 == ne2) {
8214
+ i12 = 0;
8215
+ if (++i13 == ne3) {
8216
+ i13 = 0;
8217
+ }
8218
+ }
8219
+ }
8220
+ }
8221
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8222
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8223
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8224
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8225
+
8226
+ memcpy(dst_ptr, src0_ptr, sizeof(float));
8227
+
8228
+ if (++i10 == ne0) {
8229
+ i10 = 0;
8230
+ if (++i11 == ne1) {
8231
+ i11 = 0;
8232
+ if (++i12 == ne2) {
8233
+ i12 = 0;
8234
+ if (++i13 == ne3) {
8235
+ i13 = 0;
8236
+ }
8237
+ }
8238
+ }
8239
+ }
8240
+ }
8241
+ }
8242
+ i10 += ne00 * (ne01 - ir1);
8243
+ while (i10 >= ne0) {
8244
+ i10 -= ne0;
8245
+ if (++i11 == ne1) {
8246
+ i11 = 0;
8247
+ if (++i12 == ne2) {
8248
+ i12 = 0;
8249
+ if (++i13 == ne3) {
8250
+ i13 = 0;
8251
+ }
8252
+ }
8253
+ }
8254
+ }
8255
+ }
8256
+ }
8257
+ } else if (dst->type == GGML_TYPE_F16) {
8258
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8259
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8260
+ i10 += ne00 * ir0;
8261
+ while (i10 >= ne0) {
8262
+ i10 -= ne0;
8263
+ if (++i11 == ne1) {
8264
+ i11 = 0;
8265
+ if (++i12 == ne2) {
8266
+ i12 = 0;
8267
+ if (++i13 == ne3) {
8268
+ i13 = 0;
8269
+ }
8270
+ }
8271
+ }
8272
+ }
8273
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8274
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8275
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8276
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8277
+
8278
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
8279
+
8280
+ if (++i10 == ne0) {
8281
+ i10 = 0;
8282
+ if (++i11 == ne1) {
8283
+ i11 = 0;
8284
+ if (++i12 == ne2) {
8285
+ i12 = 0;
8286
+ if (++i13 == ne3) {
8287
+ i13 = 0;
8288
+ }
8289
+ }
8290
+ }
8291
+ }
8292
+ }
8293
+ }
8294
+ i10 += ne00 * (ne01 - ir1);
8295
+ while (i10 >= ne0) {
8296
+ i10 -= ne0;
8297
+ if (++i11 == ne1) {
8298
+ i11 = 0;
8299
+ if (++i12 == ne2) {
8300
+ i12 = 0;
8301
+ if (++i13 == ne3) {
8302
+ i13 = 0;
8303
+ }
8304
+ }
8305
+ }
8306
+ }
8307
+ }
8308
+ }
8309
+ } else if (dst->type == GGML_TYPE_BF16) {
8310
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
8311
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
8312
+ i10 += ne00 * ir0;
8313
+ while (i10 >= ne0) {
8314
+ i10 -= ne0;
8315
+ if (++i11 == ne1) {
8316
+ i11 = 0;
8317
+ if (++i12 == ne2) {
8318
+ i12 = 0;
8319
+ if (++i13 == ne3) {
8320
+ i13 = 0;
8321
+ }
8322
+ }
8323
+ }
8324
+ }
8325
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
8326
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
8327
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
8328
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
8329
+
8330
+ *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
8331
+
8332
+ if (++i10 == ne0) {
8333
+ i10 = 0;
8334
+ if (++i11 == ne1) {
8335
+ i11 = 0;
8336
+ if (++i12 == ne2) {
8337
+ i12 = 0;
8338
+ if (++i13 == ne3) {
8339
+ i13 = 0;
8340
+ }
8341
+ }
8342
+ }
8343
+ }
8344
+ }
8345
+ }
8346
+ i10 += ne00 * (ne01 - ir1);
8347
+ while (i10 >= ne0) {
8348
+ i10 -= ne0;
8349
+ if (++i11 == ne1) {
8350
+ i11 = 0;
8351
+ if (++i12 == ne2) {
8352
+ i12 = 0;
8353
+ if (++i13 == ne3) {
8354
+ i13 = 0;
8355
+ }
8356
+ }
8357
+ }
8358
+ }
8359
+ }
8360
+ }
8361
+ } else {
8362
+ GGML_ASSERT(false); // TODO: implement
8363
+ }
8364
+ }
8365
+
8366
+ // A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
8367
+ static void ggml_compute_forward_dup_bytes(
8368
+ const struct ggml_compute_params * params,
8369
+ struct ggml_tensor * dst) {
8370
+
8371
+ const struct ggml_tensor * src0 = dst->src[0];
8372
+
8373
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
8374
+ GGML_ASSERT(src0->type == dst->type);
8375
+
8376
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
7644
8377
  return;
7645
8378
  }
7646
8379
 
@@ -7798,6 +8531,10 @@ static void ggml_compute_forward_dup(
7798
8531
  {
7799
8532
  ggml_compute_forward_dup_f16(params, dst);
7800
8533
  } break;
8534
+ case GGML_TYPE_BF16:
8535
+ {
8536
+ ggml_compute_forward_dup_bf16(params, dst);
8537
+ } break;
7801
8538
  case GGML_TYPE_F32:
7802
8539
  {
7803
8540
  ggml_compute_forward_dup_f32(params, dst);
@@ -7980,6 +8717,85 @@ static void ggml_compute_forward_add_f16_f32(
7980
8717
  }
7981
8718
  }
7982
8719
 
8720
+ static void ggml_compute_forward_add_bf16_f32(
8721
+ const struct ggml_compute_params * params,
8722
+ struct ggml_tensor * dst) {
8723
+
8724
+ const struct ggml_tensor * src0 = dst->src[0];
8725
+ const struct ggml_tensor * src1 = dst->src[1];
8726
+
8727
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
8728
+
8729
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
8730
+ return;
8731
+ }
8732
+
8733
+ const int ith = params->ith;
8734
+ const int nth = params->nth;
8735
+
8736
+ const int nr = ggml_nrows(src0);
8737
+
8738
+ GGML_TENSOR_BINARY_OP_LOCALS
8739
+
8740
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
8741
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
8742
+
8743
+ if (dst->type == GGML_TYPE_F32) {
8744
+ GGML_ASSERT( nb0 == sizeof(float));
8745
+ }
8746
+ else {
8747
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
8748
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
8749
+ }
8750
+
8751
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
8752
+
8753
+ // rows per thread
8754
+ const int dr = (nr + nth - 1)/nth;
8755
+
8756
+ // row range for this thread
8757
+ const int ir0 = dr*ith;
8758
+ const int ir1 = MIN(ir0 + dr, nr);
8759
+
8760
+ if (nb10 == sizeof(float)) {
8761
+ if (dst->type == GGML_TYPE_BF16) {
8762
+ for (int ir = ir0; ir < ir1; ++ir) {
8763
+ // src0, src1 and dst are same shape => same indices
8764
+ const int i3 = ir/(ne2*ne1);
8765
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
8766
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8767
+
8768
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8769
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8770
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8771
+
8772
+ for (int i = 0; i < ne0; i++) {
8773
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
8774
+ }
8775
+ }
8776
+ } else {
8777
+ for (int ir = ir0; ir < ir1; ++ir) {
8778
+ // src0, src1 and dst are same shape => same indices
8779
+ const int i3 = ir/(ne2*ne1);
8780
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
8781
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8782
+
8783
+ float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8784
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8785
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8786
+
8787
+ for (int i = 0; i < ne0; i++) {
8788
+ dst_ptr[i] = GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
8789
+ }
8790
+ }
8791
+ }
8792
+ }
8793
+ else {
8794
+ // src1 is not contiguous
8795
+ GGML_ASSERT(false);
8796
+ }
8797
+ }
8798
+
7983
8799
  static void ggml_compute_forward_add_f16_f16(
7984
8800
  const struct ggml_compute_params * params,
7985
8801
  struct ggml_tensor * dst) {
@@ -8036,6 +8852,62 @@ static void ggml_compute_forward_add_f16_f16(
8036
8852
  }
8037
8853
  }
8038
8854
 
8855
+ static void ggml_compute_forward_add_bf16_bf16(
8856
+ const struct ggml_compute_params * params,
8857
+ struct ggml_tensor * dst) {
8858
+
8859
+ const struct ggml_tensor * src0 = dst->src[0];
8860
+ const struct ggml_tensor * src1 = dst->src[1];
8861
+
8862
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
8863
+
8864
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
8865
+ return;
8866
+ }
8867
+
8868
+ const int ith = params->ith;
8869
+ const int nth = params->nth;
8870
+
8871
+ const int nr = ggml_nrows(src0);
8872
+
8873
+ GGML_TENSOR_BINARY_OP_LOCALS
8874
+
8875
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
8876
+ GGML_ASSERT(src1->type == GGML_TYPE_BF16);
8877
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
8878
+
8879
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
8880
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
8881
+
8882
+ // rows per thread
8883
+ const int dr = (nr + nth - 1)/nth;
8884
+
8885
+ // row range for this thread
8886
+ const int ir0 = dr*ith;
8887
+ const int ir1 = MIN(ir0 + dr, nr);
8888
+
8889
+ if (nb10 == sizeof(ggml_bf16_t)) {
8890
+ for (int ir = ir0; ir < ir1; ++ir) {
8891
+ // src0, src1 and dst are same shape => same indices
8892
+ const int i3 = ir/(ne2*ne1);
8893
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
8894
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8895
+
8896
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
8897
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8898
+ ggml_bf16_t * src1_ptr = (ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
8899
+
8900
+ for (int i = 0; i < ne0; i++) {
8901
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + GGML_BF16_TO_FP32(src1_ptr[i]));
8902
+ }
8903
+ }
8904
+ }
8905
+ else {
8906
+ // src1 is not contiguous
8907
+ GGML_ASSERT(false);
8908
+ }
8909
+ }
8910
+
8039
8911
  static void ggml_compute_forward_add_q_f32(
8040
8912
  const struct ggml_compute_params * params,
8041
8913
  struct ggml_tensor * dst) {
@@ -8145,6 +9017,18 @@ static void ggml_compute_forward_add(
8145
9017
  GGML_ASSERT(false);
8146
9018
  }
8147
9019
  } break;
9020
+ case GGML_TYPE_BF16:
9021
+ {
9022
+ if (src1->type == GGML_TYPE_BF16) {
9023
+ ggml_compute_forward_add_bf16_bf16(params, dst);
9024
+ }
9025
+ else if (src1->type == GGML_TYPE_F32) {
9026
+ ggml_compute_forward_add_bf16_f32(params, dst);
9027
+ }
9028
+ else {
9029
+ GGML_ASSERT(false);
9030
+ }
9031
+ } break;
8148
9032
  case GGML_TYPE_Q4_0:
8149
9033
  case GGML_TYPE_Q4_1:
8150
9034
  case GGML_TYPE_Q5_0:
@@ -8358,21 +9242,133 @@ static void ggml_compute_forward_add1_q_f32(
8358
9242
 
8359
9243
  GGML_TENSOR_UNARY_OP_LOCALS
8360
9244
 
8361
- const enum ggml_type type = src0->type;
8362
- ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
8363
- ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
8364
-
8365
- // we don't support permuted src0
8366
- GGML_ASSERT(nb00 == ggml_type_size(type));
8367
-
8368
- // dst cannot be transposed or permuted
8369
- GGML_ASSERT(nb0 <= nb1);
8370
- GGML_ASSERT(nb1 <= nb2);
8371
- GGML_ASSERT(nb2 <= nb3);
9245
+ const enum ggml_type type = src0->type;
9246
+ ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
9247
+ ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
9248
+
9249
+ // we don't support permuted src0
9250
+ GGML_ASSERT(nb00 == ggml_type_size(type));
9251
+
9252
+ // dst cannot be transposed or permuted
9253
+ GGML_ASSERT(nb0 <= nb1);
9254
+ GGML_ASSERT(nb1 <= nb2);
9255
+ GGML_ASSERT(nb2 <= nb3);
9256
+
9257
+ GGML_ASSERT(ggml_is_quantized(src0->type));
9258
+ GGML_ASSERT(dst->type == src0->type);
9259
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
9260
+
9261
+ // rows per thread
9262
+ const int dr = (nr + nth - 1)/nth;
9263
+
9264
+ // row range for this thread
9265
+ const int ir0 = dr*ith;
9266
+ const int ir1 = MIN(ir0 + dr, nr);
9267
+
9268
+ float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
9269
+
9270
+ for (int ir = ir0; ir < ir1; ++ir) {
9271
+ // src0 and dst are same shape => same indices
9272
+ const int i3 = ir/(ne2*ne1);
9273
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9274
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9275
+
9276
+ void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
9277
+ void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb0 ));
9278
+
9279
+ assert(ne0 % 32 == 0);
9280
+
9281
+ // unquantize row from src0 to temp buffer
9282
+ dequantize_row_q(src0_row, wdata, ne0);
9283
+ // add src1
9284
+ ggml_vec_acc1_f32(ne0, wdata, v);
9285
+ // quantize row to dst
9286
+ quantize_row_q(wdata, dst_row, ne0);
9287
+ }
9288
+ }
9289
+
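add1 on a quantized src0 cannot work in place, so each row above is round-tripped through a per-thread float scratch buffer: dequantize, add the scalar, requantize. A minimal sketch of that pattern follows; the to_float/from_float callbacks are hypothetical stand-ins for the type_traits function pointers the real code looks up.

#include <stddef.h>

/* hypothetical stand-ins for type_traits[type].to_float / .from_float */
typedef void (*to_float_t)  (const void * x, float * y, size_t n);
typedef void (*from_float_t)(const float * x, void * y, size_t n);

/* Round-trip one quantized row through a float work buffer. wdata must hold
 * at least ne0 floats and belong to the calling thread, which is why the code
 * above offsets params->wdata by (ne0 + CACHE_LINE_SIZE_F32)*ith: every
 * thread gets its own cache-line-padded slice. */
static void add1_quantized_row(const void * src_row, void * dst_row,
                               float * wdata, size_t ne0, float v,
                               to_float_t   dequantize_row,
                               from_float_t quantize_row) {
    dequantize_row(src_row, wdata, ne0);   /* unquantize row to temp buffer */
    for (size_t i = 0; i < ne0; ++i) {
        wdata[i] += v;                     /* add the scalar                */
    }
    quantize_row(wdata, dst_row, ne0);     /* quantize row back into dst    */
}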
9290
+ static void ggml_compute_forward_add1_bf16_f32(
9291
+ const struct ggml_compute_params * params,
9292
+ struct ggml_tensor * dst) {
9293
+
9294
+ const struct ggml_tensor * src0 = dst->src[0];
9295
+ const struct ggml_tensor * src1 = dst->src[1];
9296
+
9297
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
9298
+ GGML_ASSERT(ggml_is_scalar(src1));
9299
+
9300
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
9301
+ return;
9302
+ }
9303
+
9304
+ // scalar to add
9305
+ const float v = *(float *) src1->data;
9306
+
9307
+ const int ith = params->ith;
9308
+ const int nth = params->nth;
9309
+
9310
+ const int nr = ggml_nrows(src0);
9311
+
9312
+ GGML_TENSOR_UNARY_OP_LOCALS
9313
+
9314
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
9315
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
9316
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
9317
+
9318
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
9319
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
9320
+
9321
+ // rows per thread
9322
+ const int dr = (nr + nth - 1)/nth;
9323
+
9324
+ // row range for this thread
9325
+ const int ir0 = dr*ith;
9326
+ const int ir1 = MIN(ir0 + dr, nr);
9327
+
9328
+ for (int ir = ir0; ir < ir1; ++ir) {
9329
+ // src0 and dst are same shape => same indices
9330
+ const int i3 = ir/(ne2*ne1);
9331
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
9332
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9333
+
9334
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
9335
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9336
+ for (int i = 0; i < ne0; i++) {
9337
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
9338
+ }
9339
+ }
9340
+ }
9341
+
9342
+ static void ggml_compute_forward_add1_bf16_bf16(
9343
+ const struct ggml_compute_params * params,
9344
+ struct ggml_tensor * dst) {
9345
+
9346
+ const struct ggml_tensor * src0 = dst->src[0];
9347
+ const struct ggml_tensor * src1 = dst->src[1];
9348
+
9349
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
9350
+ GGML_ASSERT(ggml_is_scalar(src1));
9351
+
9352
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
9353
+ return;
9354
+ }
9355
+
9356
+ // scalar to add
9357
+ const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
9358
+
9359
+ const int ith = params->ith;
9360
+ const int nth = params->nth;
9361
+
9362
+ const int nr = ggml_nrows(src0);
9363
+
9364
+ GGML_TENSOR_UNARY_OP_LOCALS
9365
+
9366
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
9367
+ GGML_ASSERT(src1->type == GGML_TYPE_BF16);
9368
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
8372
9369
 
8373
- GGML_ASSERT(ggml_is_quantized(src0->type));
8374
- GGML_ASSERT(dst->type == src0->type);
8375
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
9370
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
9371
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
8376
9372
 
8377
9373
  // rows per thread
8378
9374
  const int dr = (nr + nth - 1)/nth;
@@ -8381,25 +9377,17 @@ static void ggml_compute_forward_add1_q_f32(
8381
9377
  const int ir0 = dr*ith;
8382
9378
  const int ir1 = MIN(ir0 + dr, nr);
8383
9379
 
8384
- float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
8385
-
8386
9380
  for (int ir = ir0; ir < ir1; ++ir) {
8387
9381
  // src0 and dst are same shape => same indices
8388
9382
  const int i3 = ir/(ne2*ne1);
8389
9383
  const int i2 = (ir - i3*ne2*ne1)/ne1;
8390
9384
  const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8391
9385
 
8392
- void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
8393
- void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb0 ));
8394
-
8395
- assert(ne0 % 32 == 0);
8396
-
8397
- // unquantize row from src0 to temp buffer
8398
- dequantize_row_q(src0_row, wdata, ne0);
8399
- // add src1
8400
- ggml_vec_acc1_f32(ne0, wdata, v);
8401
- // quantize row to dst
8402
- quantize_row_q(wdata, dst_row, ne0);
9386
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
9387
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9388
+ for (int i = 0; i < ne0; i++) {
9389
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
9390
+ }
8403
9391
  }
8404
9392
  }
8405
9393
 
@@ -8427,6 +9415,18 @@ static void ggml_compute_forward_add1(
8427
9415
  GGML_ASSERT(false);
8428
9416
  }
8429
9417
  } break;
9418
+ case GGML_TYPE_BF16:
9419
+ {
9420
+ if (src1->type == GGML_TYPE_BF16) {
9421
+ ggml_compute_forward_add1_bf16_bf16(params, dst);
9422
+ }
9423
+ else if (src1->type == GGML_TYPE_F32) {
9424
+ ggml_compute_forward_add1_bf16_f32(params, dst);
9425
+ }
9426
+ else {
9427
+ GGML_ASSERT(false);
9428
+ }
9429
+ } break;
8430
9430
  case GGML_TYPE_Q4_0:
8431
9431
  case GGML_TYPE_Q4_1:
8432
9432
  case GGML_TYPE_Q5_0:
@@ -8555,6 +9555,7 @@ static void ggml_compute_forward_acc(
8555
9555
  ggml_compute_forward_acc_f32(params, dst);
8556
9556
  } break;
8557
9557
  case GGML_TYPE_F16:
9558
+ case GGML_TYPE_BF16:
8558
9559
  case GGML_TYPE_Q4_0:
8559
9560
  case GGML_TYPE_Q4_1:
8560
9561
  case GGML_TYPE_Q5_0:
@@ -9076,6 +10077,40 @@ static void ggml_compute_forward_sum_f16(
9076
10077
  ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
9077
10078
  }
9078
10079
 
10080
+ static void ggml_compute_forward_sum_bf16(
10081
+ const struct ggml_compute_params * params,
10082
+ struct ggml_tensor * dst) {
10083
+
10084
+ const struct ggml_tensor * src0 = dst->src[0];
10085
+
10086
+ assert(params->ith == 0);
10087
+ assert(ggml_is_scalar(dst));
10088
+
10089
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
10090
+ return;
10091
+ }
10092
+
10093
+ assert(src0->nb[0] == sizeof(ggml_bf16_t));
10094
+
10095
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
10096
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
10097
+
10098
+ float sum = 0;
10099
+ float row_sum = 0;
10100
+
10101
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
10102
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
10103
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
10104
+ ggml_vec_sum_bf16_ggf(ne00,
10105
+ &row_sum,
10106
+ (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
10107
+ sum += row_sum;
10108
+ }
10109
+ }
10110
+ }
10111
+ ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
10112
+ }
10113
+
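ggml_compute_forward_sum_bf16 keeps the running total in a plain float and converts to bf16 only once, at the very end: with 8 bits of mantissa, a bf16 accumulator would quickly drop small contributions. A from-scratch sketch of a row-sum helper in the spirit of the ggml_vec_sum_bf16_ggf call above (not its actual implementation):

#include <stddef.h>
#include <stdint.h>
#include <string.h>

typedef struct { uint16_t bits; } bf16_t;   /* stand-in for ggml_bf16_t */

static inline float bf16_to_f32(bf16_t h) {
    uint32_t u = (uint32_t) h.bits << 16;   /* bf16 = high 16 bits of f32 */
    float f;
    memcpy(&f, &u, sizeof(f));
    return f;
}

/* Sum n bf16 values into *s, accumulating in f32 so partial sums are never
 * rounded back to bf16 mid-row. */
static void vec_sum_bf16_into_f32(size_t n, float * s, const bf16_t * x) {
    float sum = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        sum += bf16_to_f32(x[i]);
    }
    *s = sum;
}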
9079
10114
  static void ggml_compute_forward_sum(
9080
10115
  const struct ggml_compute_params * params,
9081
10116
  struct ggml_tensor * dst) {
@@ -9091,6 +10126,10 @@ static void ggml_compute_forward_sum(
9091
10126
  {
9092
10127
  ggml_compute_forward_sum_f16(params, dst);
9093
10128
  } break;
10129
+ case GGML_TYPE_BF16:
10130
+ {
10131
+ ggml_compute_forward_sum_bf16(params, dst);
10132
+ } break;
9094
10133
  default:
9095
10134
  {
9096
10135
  GGML_ASSERT(false);
@@ -9365,6 +10404,7 @@ static void ggml_compute_forward_repeat(
9365
10404
 
9366
10405
  switch (src0->type) {
9367
10406
  case GGML_TYPE_F16:
10407
+ case GGML_TYPE_BF16:
9368
10408
  case GGML_TYPE_I16:
9369
10409
  {
9370
10410
  ggml_compute_forward_repeat_f16(params, dst);
@@ -11682,6 +12722,7 @@ static void ggml_compute_forward_set(
11682
12722
  ggml_compute_forward_set_f32(params, dst);
11683
12723
  } break;
11684
12724
  case GGML_TYPE_F16:
12725
+ case GGML_TYPE_BF16:
11685
12726
  case GGML_TYPE_Q4_0:
11686
12727
  case GGML_TYPE_Q4_1:
11687
12728
  case GGML_TYPE_Q5_0:
@@ -11856,6 +12897,49 @@ static void ggml_compute_forward_get_rows_f16(
11856
12897
  }
11857
12898
  }
11858
12899
 
12900
+ static void ggml_compute_forward_get_rows_bf16(
12901
+ const struct ggml_compute_params * params,
12902
+ struct ggml_tensor * dst) {
12903
+
12904
+ const struct ggml_tensor * src0 = dst->src[0];
12905
+ const struct ggml_tensor * src1 = dst->src[1];
12906
+
12907
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
12908
+ return;
12909
+ }
12910
+
12911
+ GGML_TENSOR_BINARY_OP_LOCALS
12912
+
12913
+ const int64_t nc = ne00;
12914
+ const int64_t nr = ggml_nelements(src1);
12915
+
12916
+ assert(ne0 == nc);
12917
+ assert(ne02 == ne11);
12918
+ assert(nb00 == sizeof(ggml_bf16_t));
12919
+ assert(ggml_nrows(dst) == nr);
12920
+
12921
+ const int ith = params->ith;
12922
+ const int nth = params->nth;
12923
+
12924
+ // rows per thread
12925
+ const int dr = (nr + nth - 1)/nth;
12926
+
12927
+ // row range for this thread
12928
+ const int ir0 = dr*ith;
12929
+ const int ir1 = MIN(ir0 + dr, nr);
12930
+
12931
+ for (int64_t i = ir0; i < ir1; ++i) {
12932
+ const int64_t i12 = i/(ne11*ne10);
12933
+ const int64_t i11 = (i - i12*ne11*ne10)/ne10;
12934
+ const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
12935
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
12936
+
12937
+ ggml_bf16_to_fp32_row(
12938
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
12939
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
12940
+ }
12941
+ }
12942
+
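get_rows is a gather: src1 holds int32 row indices and, for each one, the matching bf16 row of src0 is widened into a float row of dst (the ggml_bf16_to_fp32_row call above). Stripped of the i11/i12 broadcasting and the byte-stride arithmetic, the operation reduces to the sketch below:

#include <stddef.h>
#include <stdint.h>
#include <string.h>

typedef struct { uint16_t bits; } bf16_t;   /* stand-in for ggml_bf16_t */

static inline float bf16_to_f32(bf16_t h) {
    uint32_t u = (uint32_t) h.bits << 16;
    float f;
    memcpy(&f, &u, sizeof(f));
    return f;
}

/* dst[r] = widen(src[rows[r]]) for nr rows of length nc;
 * src is laid out as contiguous bf16 rows, dst as nr contiguous f32 rows */
static void gather_rows_bf16(const bf16_t * src, const int32_t * rows,
                             float * dst, size_t nr, size_t nc) {
    for (size_t r = 0; r < nr; ++r) {
        const bf16_t * src_row = src + (size_t) rows[r]*nc;
        float        * dst_row = dst + r*nc;
        for (size_t c = 0; c < nc; ++c) {
            dst_row[c] = bf16_to_f32(src_row[c]);   /* widen one element */
        }
    }
}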
11859
12943
  static void ggml_compute_forward_get_rows_f32(
11860
12944
  const struct ggml_compute_params * params,
11861
12945
  struct ggml_tensor * dst) {
@@ -11933,6 +13017,10 @@ static void ggml_compute_forward_get_rows(
11933
13017
  {
11934
13018
  ggml_compute_forward_get_rows_f16(params, dst);
11935
13019
  } break;
13020
+ case GGML_TYPE_BF16:
13021
+ {
13022
+ ggml_compute_forward_get_rows_bf16(params, dst);
13023
+ } break;
11936
13024
  case GGML_TYPE_F32:
11937
13025
  case GGML_TYPE_I32:
11938
13026
  {
@@ -12267,7 +13355,7 @@ static void ggml_compute_forward_soft_max_f32(
12267
13355
 
12268
13356
  GGML_TENSOR_UNARY_OP_LOCALS
12269
13357
 
12270
- const int64_t ne11 = src1 ? src1->ne[1] : 1;
13358
+ //const int64_t ne11 = src1 ? src1->ne[1] : 1;
12271
13359
 
12272
13360
  // TODO: is this supposed to be ceil instead of floor?
12273
13361
  // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -12290,19 +13378,31 @@ static void ggml_compute_forward_soft_max_f32(
12290
13378
  float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
12291
13379
 
12292
13380
  // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
12293
- float * pos = src2 ? (float *) src2->data : src0->data;
13381
+ ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
13382
+ float * pos_f32 = src2 ? (float *) src2->data : src0->data;
13383
+
13384
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
12294
13385
 
12295
13386
  for (int i1 = ir0; i1 < ir1; i1++) {
12296
13387
  float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
12297
13388
  float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
12298
13389
 
12299
13390
  // broadcast the mask across rows
12300
- float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
13391
+ ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
13392
+ float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
12301
13393
 
12302
13394
  ggml_vec_cpy_f32 (nc, wp, sp);
12303
13395
  ggml_vec_scale_f32(nc, wp, scale);
12304
- if (mp) {
12305
- ggml_vec_acc_f32(nc, wp, mp);
13396
+ if (mp_f32) {
13397
+ if (use_f16) {
13398
+ for (int i = 0; i < nc; ++i) {
13399
+ wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
13400
+ }
13401
+ } else {
13402
+ for (int i = 0; i < nc; ++i) {
13403
+ wp[i] += mp_f32[i];
13404
+ }
13405
+ }
12306
13406
  }
12307
13407
 
12308
13408
  // ALiBi bias
@@ -12310,8 +13410,14 @@ static void ggml_compute_forward_soft_max_f32(
12310
13410
  const uint32_t h = (i1/ne01)%ne02; // head
12311
13411
  const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
12312
13412
 
12313
- for (int i = 0; i < nc; i++) {
12314
- wp[i] = wp[i] + slope*pos[i];
13413
+ if (use_f16) {
13414
+ for (int i = 0; i < nc; ++i) {
13415
+ wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
13416
+ }
13417
+ } else {
13418
+ for (int i = 0; i < nc; ++i) {
13419
+ wp[i] += slope*pos_f32[i];
13420
+ }
12315
13421
  }
12316
13422
  }
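Before the exponentials, soft_max now assembles each working row as wp[i] = sp[i]*scale + mask[i] + slope*pos[i], where the mask and the positions may arrive as either f16 or f32 (the use_f16 branches above). A compact f32-only sketch of that pre-softmax step followed by a standard max-subtracted softmax is shown below; the slope is passed in ready-made, since m0, m1 and n_head_log2 are computed earlier in the function and are not part of this hunk.

#include <math.h>
#include <stddef.h>

/* wp[i] = sp[i]*scale + mask[i] + slope*pos[i], then softmax over the row;
 * mask and pos may be NULL when unused */
static void softmax_row(float * wp, const float * sp, size_t nc, float scale,
                        const float * mask, const float * pos, float slope) {
    for (size_t i = 0; i < nc; ++i) {
        wp[i] = sp[i]*scale
              + (mask ? mask[i]      : 0.0f)
              + (pos  ? slope*pos[i] : 0.0f);
    }

    float max = -INFINITY;
    for (size_t i = 0; i < nc; ++i) {
        if (wp[i] > max) max = wp[i];    /* subtract the max for stability */
    }
    float sum = 0.0f;
    for (size_t i = 0; i < nc; ++i) {
        wp[i] = expf(wp[i] - max);
        sum  += wp[i];
    }
    for (size_t i = 0; i < nc; ++i) {
        wp[i] /= sum;
    }
}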
12317
13423
 
@@ -12610,6 +13716,7 @@ static void ggml_compute_forward_alibi(
12610
13716
  {
12611
13717
  ggml_compute_forward_alibi_f32(params, dst);
12612
13718
  } break;
13719
+ case GGML_TYPE_BF16:
12613
13720
  case GGML_TYPE_Q4_0:
12614
13721
  case GGML_TYPE_Q4_1:
12615
13722
  case GGML_TYPE_Q5_0:
@@ -12699,6 +13806,7 @@ static void ggml_compute_forward_clamp(
12699
13806
  ggml_compute_forward_clamp_f32(params, dst);
12700
13807
  } break;
12701
13808
  case GGML_TYPE_F16:
13809
+ case GGML_TYPE_BF16:
12702
13810
  case GGML_TYPE_Q4_0:
12703
13811
  case GGML_TYPE_Q4_1:
12704
13812
  case GGML_TYPE_Q5_0:
@@ -14581,6 +15689,198 @@ static void ggml_compute_forward_flash_attn(
14581
15689
  }
14582
15690
  }
14583
15691
 
15692
+ // ggml_compute_forward_flash_attn_ext
15693
+
15694
+ static void ggml_compute_forward_flash_attn_ext_f16(
15695
+ const struct ggml_compute_params * params,
15696
+ const struct ggml_tensor * q,
15697
+ const struct ggml_tensor * k,
15698
+ const struct ggml_tensor * v,
15699
+ const struct ggml_tensor * mask,
15700
+ struct ggml_tensor * dst) {
15701
+ int64_t t0 = ggml_perf_time_us();
15702
+ UNUSED(t0);
15703
+
15704
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15705
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15706
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15707
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15708
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15709
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15710
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15711
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15712
+
15713
+ const int ith = params->ith;
15714
+ const int nth = params->nth;
15715
+
15716
+ const int64_t D = neq0;
15717
+ const int64_t N = neq1;
15718
+
15719
+ GGML_ASSERT(ne0 == D);
15720
+ GGML_ASSERT(ne2 == N);
15721
+
15722
+ GGML_ASSERT(nbq0 == sizeof(float));
15723
+ GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
15724
+ GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
15725
+
15726
+ GGML_ASSERT(neq0 == D);
15727
+ GGML_ASSERT(nek0 == D);
15728
+ GGML_ASSERT(nev0 == D);
15729
+
15730
+ GGML_ASSERT(neq1 == N);
15731
+ GGML_ASSERT(nev0 == D);
15732
+
15733
+ // dst cannot be transposed or permuted
15734
+ GGML_ASSERT(nb0 == sizeof(float));
15735
+ GGML_ASSERT(nb0 <= nb1);
15736
+ GGML_ASSERT(nb1 <= nb2);
15737
+ GGML_ASSERT(nb2 <= nb3);
15738
+
15739
+ // broadcast factors
15740
+ const int64_t rk2 = neq2/nek2;
15741
+ const int64_t rk3 = neq3/nek3;
15742
+
15743
+ const int64_t rv2 = neq2/nev2;
15744
+ const int64_t rv3 = neq3/nev3;
15745
+
15746
+ if (params->type == GGML_TASK_TYPE_INIT) {
15747
+ return;
15748
+ }
15749
+
15750
+ if (params->type == GGML_TASK_TYPE_FINALIZE) {
15751
+ return;
15752
+ }
15753
+
15754
+ // parallelize by q rows using ggml_vec_dot_f32
15755
+
15756
+ // total rows in q
15757
+ const int nr = neq1*neq2*neq3;
15758
+
15759
+ // rows per thread
15760
+ const int dr = (nr + nth - 1)/nth;
15761
+
15762
+ // row range for this thread
15763
+ const int ir0 = dr*ith;
15764
+ const int ir1 = MIN(ir0 + dr, nr);
15765
+
15766
+ float scale = 1.0f;
15767
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
15768
+
15769
+ // loop over n_batch and n_head
15770
+ for (int ir = ir0; ir < ir1; ++ir) {
15771
+ // q indices
15772
+ const int iq3 = ir/(neq2*neq1);
15773
+ const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15774
+ const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15775
+
15776
+ float S = 0.0f;
15777
+ float M = -INFINITY;
15778
+
15779
+ float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
15780
+ ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
15781
+ ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
15782
+
15783
+ memset(V16, 0, D*sizeof(ggml_fp16_t));
15784
+
15785
+ const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
15786
+
15787
+ // k indices
15788
+ const int ik3 = iq3 / rk3;
15789
+ const int ik2 = iq2 / rk2;
15790
+
15791
+ // v indices
15792
+ const int iv3 = iq3 / rv3;
15793
+ const int iv2 = iq2 / rv2;
15794
+
15795
+ // online softmax / attention
15796
+ // loop over n_kv and n_head_kv
15797
+ // ref: https://arxiv.org/pdf/2112.05682.pdf
15798
+ for (int64_t ic = 0; ic < nek1; ++ic) {
15799
+ const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
15800
+ if (mv == -INFINITY) {
15801
+ continue;
15802
+ }
15803
+
15804
+ float s;
15805
+
15806
+ // convert Q to F16 in V32
15807
+ {
15808
+ const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
15809
+
15810
+ for (int64_t d = 0; d < D; ++d) {
15811
+ Q16[d] = GGML_FP32_TO_FP16(pq[d]);
15812
+ }
15813
+ }
15814
+
15815
+ ggml_vec_dot_f16(D,
15816
+ &s, 0,
15817
+ (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15818
+ Q16, 0, 1);
15819
+
15820
+ s = s*scale + mv;
15821
+
15822
+ const float Mold = M;
15823
+
15824
+ float ms = 1.0f;
15825
+ float vs = 1.0f;
15826
+
15827
+ if (s > M) {
15828
+ M = s;
15829
+ ms = expf(Mold - M);
15830
+
15831
+ // V = V*expf(Mold - M)
15832
+ ggml_vec_scale_f16(D, V16, ms);
15833
+ } else {
15834
+ vs = expf(s - M);
15835
+ }
15836
+
15837
+ const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
15838
+
15839
+ // V += v*expf(s - M)
15840
+ ggml_vec_mad_f16(D, V16, v16, vs);
15841
+
15842
+ S = S*ms + vs;
15843
+ }
15844
+
15845
+ // V /= S
15846
+ for (int64_t d = 0; d < D; ++d) {
15847
+ V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
15848
+ }
15849
+
15850
+ // dst indices
15851
+ const int i1 = iq1;
15852
+ const int i2 = iq2;
15853
+ const int i3 = iq3;
15854
+
15855
+ // original
15856
+ //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
15857
+
15858
+ // permute(0, 2, 1, 3)
15859
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
15860
+ }
15861
+ }
15862
+
15863
+ static void ggml_compute_forward_flash_attn_ext(
15864
+ const struct ggml_compute_params * params,
15865
+ const struct ggml_tensor * q,
15866
+ const struct ggml_tensor * k,
15867
+ const struct ggml_tensor * v,
15868
+ const struct ggml_tensor * mask,
15869
+ struct ggml_tensor * dst) {
15870
+ switch (dst->op_params[1]) {
15871
+ case GGML_PREC_DEFAULT:
15872
+ case GGML_PREC_F32:
15873
+ {
15874
+ // uses F32 accumulators
15875
+ ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
15876
+ } break;
15877
+ default:
15878
+ {
15879
+ GGML_ASSERT(false);
15880
+ } break;
15881
+ }
15882
+ }
15883
+
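The new flash-attention kernel never materializes the full score row: it keeps a running maximum M, a running denominator S and a running value accumulator, rescaling them whenever a larger score shows up (the ms and vs factors above, following the referenced online-softmax approach). The scalar sketch below applies the same update rule to one query against a list of scores and 1-D values; it is a distilled illustration, not the ggml kernel.

#include <math.h>
#include <stddef.h>

/* one-pass softmax-weighted average: returns dot(softmax(s), v)
 * without ever storing the softmax row */
static float online_softmax_attend(const float * s, const float * v, size_t n) {
    float M = -INFINITY;   /* running max of the scores        */
    float S = 0.0f;        /* running softmax denominator      */
    float V = 0.0f;        /* running numerator (value accum.) */

    for (size_t i = 0; i < n; ++i) {
        float ms = 1.0f;   /* rescale factor for the old state */
        float vs = 1.0f;   /* weight of the new element        */

        if (s[i] > M) {
            ms = expf(M - s[i]);   /* shrink previous contributions */
            M  = s[i];
        } else {
            vs = expf(s[i] - M);
        }

        V = V*ms + vs*v[i];
        S = S*ms + vs;
    }
    return V/S;
}

The invariant after each step is V = sum_j exp(s[j] - M)*v[j] and S = sum_j exp(s[j] - M), so V/S equals the softmax-weighted average while needing only a single pass over the keys; the kernel does the same with D-wide vectors via ggml_vec_scale_f16 and ggml_vec_mad_f16.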
14584
15884
  // ggml_compute_forward_flash_ff
14585
15885
 
14586
15886
  static void ggml_compute_forward_flash_ff_f16(
@@ -15600,6 +16900,7 @@ static void ggml_compute_forward_get_rel_pos(
15600
16900
 
15601
16901
  switch (src0->type) {
15602
16902
  case GGML_TYPE_F16:
16903
+ case GGML_TYPE_BF16:
15603
16904
  {
15604
16905
  ggml_compute_forward_get_rel_pos_f16(params, dst);
15605
16906
  } break;
@@ -16388,6 +17689,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
16388
17689
  const bool masked = t != 0;
16389
17690
  ggml_compute_forward_flash_attn(params, masked, tensor);
16390
17691
  } break;
17692
+ case GGML_OP_FLASH_ATTN_EXT:
17693
+ {
17694
+ ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
17695
+ } break;
16391
17696
  case GGML_OP_FLASH_FF:
16392
17697
  {
16393
17698
  ggml_compute_forward_flash_ff(params, tensor);
@@ -17400,6 +18705,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
17400
18705
  GGML_ASSERT(false); // TODO: not implemented
17401
18706
  } break;
17402
18707
  case GGML_OP_FLASH_ATTN:
18708
+ case GGML_OP_FLASH_ATTN_EXT:
17403
18709
  {
17404
18710
  struct ggml_tensor * flash_grad = NULL;
17405
18711
  if (src0->grad || src1->grad || tensor->src[2]->grad) {
@@ -18172,6 +19478,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
18172
19478
  n_tasks = n_threads;
18173
19479
  } break;
18174
19480
  case GGML_OP_FLASH_ATTN:
19481
+ case GGML_OP_FLASH_ATTN_EXT:
18175
19482
  {
18176
19483
  n_tasks = n_threads;
18177
19484
  } break;
@@ -18458,7 +19765,10 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18458
19765
  case GGML_OP_CPY:
18459
19766
  case GGML_OP_DUP:
18460
19767
  {
18461
- if (ggml_is_quantized(node->type)) {
19768
+ if (ggml_is_quantized(node->type) ||
19769
+ // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
19770
+ (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
19771
+ (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
18462
19772
  cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
18463
19773
  }
18464
19774
  } break;
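The planner reserves ne0 floats per task for these copies because neither half-precision format converts directly into the other: each element is widened to f32 first and then narrowed into the destination type. A self-contained sketch of the per-element path for the F16 -> BF16 direction is below; the conversion helpers are written out here for illustration only (ggml has its own, including table-based f16 widening).

#include <stddef.h>
#include <stdint.h>
#include <string.h>

typedef struct { uint16_t bits; } f16_t;    /* stand-in for ggml_fp16_t */
typedef struct { uint16_t bits; } bf16_t;   /* stand-in for ggml_bf16_t */

/* IEEE binary16 -> binary32 (zero, subnormal, inf and NaN all handled) */
static float f16_to_f32(f16_t h) {
    const uint32_t sign = (uint32_t)(h.bits & 0x8000u) << 16;
    uint32_t exp  = (h.bits >> 10) & 0x1Fu;
    uint32_t mant =  h.bits        & 0x3FFu;
    uint32_t bits;

    if (exp == 0) {
        if (mant == 0) {
            bits = sign;                               /* +/- zero */
        } else {
            exp = 127 - 15 + 1;                        /* subnormal: renormalize */
            while ((mant & 0x400u) == 0) {
                mant <<= 1;
                exp--;
            }
            mant &= 0x3FFu;
            bits = sign | (exp << 23) | (mant << 13);
        }
    } else if (exp == 0x1Fu) {
        bits = sign | 0x7F800000u | (mant << 13);      /* inf / NaN */
    } else {
        bits = sign | ((exp - 15 + 127) << 23) | (mant << 13);
    }

    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}

/* binary32 -> bf16, round to nearest-even (NaN not special-cased here) */
static bf16_t f32_to_bf16(float f) {
    uint32_t u;
    memcpy(&u, &f, sizeof(u));
    bf16_t h;
    h.bits = (uint16_t)((u + 0x7FFFu + ((u >> 16) & 1u)) >> 16);
    return h;
}

/* copy one row F16 -> BF16 through an f32 scratch row, mirroring the
 * per-task float buffer the plan reserves for this case */
static void copy_row_f16_to_bf16(const f16_t * src, bf16_t * dst,
                                 float * wdata, size_t ne0) {
    for (size_t i = 0; i < ne0; ++i) {
        wdata[i] = f16_to_f32(src[i]);
    }
    for (size_t i = 0; i < ne0; ++i) {
        dst[i] = f32_to_bf16(wdata[i]);
    }
}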
@@ -18537,7 +19847,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18537
19847
  const int64_t ne10 = node->src[1]->ne[0]; // L
18538
19848
  const int64_t ne11 = node->src[1]->ne[1]; // Cin
18539
19849
 
18540
- if (node->src[0]->type == GGML_TYPE_F16 &&
19850
+ if ((node->src[0]->type == GGML_TYPE_F16 ||
19851
+ node->src[0]->type == GGML_TYPE_BF16) &&
18541
19852
  node->src[1]->type == GGML_TYPE_F32) {
18542
19853
  cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
18543
19854
  cur += sizeof(ggml_fp16_t)*ne10*ne11;
@@ -18573,8 +19884,17 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18573
19884
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18574
19885
  cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
18575
19886
  cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19887
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19888
+ cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19889
+ cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
18576
19890
  }
18577
19891
  } break;
19892
+ case GGML_OP_FLASH_ATTN_EXT:
19893
+ {
19894
+ const int64_t ne00 = node->src[0]->ne[0]; // D
19895
+
19896
+ cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
19897
+ } break;
18578
19898
  case GGML_OP_FLASH_FF:
18579
19899
  {
18580
19900
  if (node->src[1]->type == GGML_TYPE_F32) {
@@ -18583,6 +19903,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18583
19903
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18584
19904
  cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
18585
19905
  cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19906
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19907
+ cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19908
+ cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
18586
19909
  }
18587
19910
  } break;
18588
19911
  case GGML_OP_FLASH_ATTN_BACK:
@@ -18596,6 +19919,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18596
19919
  } else if (node->src[1]->type == GGML_TYPE_F16) {
18597
19920
  cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
18598
19921
  cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
19922
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
19923
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
19924
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
18599
19925
  }
18600
19926
  } break;
18601
19927
 
@@ -19372,7 +20698,9 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
19372
20698
  if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
19373
20699
  fprintf(fp, "%d", ggml_get_i32_1d(node, j));
19374
20700
  }
19375
- else if (node->type == GGML_TYPE_F32 || node->type == GGML_TYPE_F16) {
20701
+ else if (node->type == GGML_TYPE_F32 ||
20702
+ node->type == GGML_TYPE_F16 ||
20703
+ node->type == GGML_TYPE_BF16) {
19376
20704
  fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j));
19377
20705
  }
19378
20706
  else {
@@ -20430,6 +21758,12 @@ size_t ggml_quantize_chunk(
20430
21758
  ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
20431
21759
  result = n * elemsize;
20432
21760
  } break;
21761
+ case GGML_TYPE_BF16:
21762
+ {
21763
+ size_t elemsize = sizeof(ggml_bf16_t);
21764
+ ggml_fp32_to_bf16_row(src + start, (ggml_bf16_t *)dst + start, n);
21765
+ result = n * elemsize;
21766
+ } break;
20433
21767
  case GGML_TYPE_F32:
20434
21768
  {
20435
21769
  size_t elemsize = sizeof(float);
@@ -20626,7 +21960,7 @@ static void gguf_free_kv(struct gguf_kv * kv) {
20626
21960
  }
20627
21961
 
20628
21962
  struct gguf_context * gguf_init_empty(void) {
20629
- struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context));
21963
+ struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context));
20630
21964
 
20631
21965
  memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic));
20632
21966
  ctx->header.version = GGUF_VERSION;
@@ -20671,7 +22005,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20671
22005
 
20672
22006
  bool ok = true;
20673
22007
 
20674
- struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context));
22008
+ struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context));
20675
22009
 
20676
22010
  // read the header
20677
22011
  {
@@ -20708,9 +22042,13 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20708
22042
 
20709
22043
  // read the kv pairs
20710
22044
  {
20711
- ctx->kv = GGML_MALLOC(ctx->header.n_kv * sizeof(struct gguf_kv));
22045
+ const uint64_t n_kv = ctx->header.n_kv;
20712
22046
 
20713
- for (uint64_t i = 0; i < ctx->header.n_kv; ++i) {
22047
+ // header.n_kv will hold the actual number of pairs that were successfully read in the loop below
22048
+ ctx->header.n_kv = 0;
22049
+ ctx->kv = GGML_CALLOC(n_kv, sizeof(struct gguf_kv));
22050
+
22051
+ for (uint64_t i = 0; i < n_kv; ++i) {
20714
22052
  struct gguf_kv * kv = &ctx->kv[i];
20715
22053
 
20716
22054
  //fprintf(stderr, "%s: reading kv %d\n", __func__, i);
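Two deliberate changes here: the kv array is now allocated with GGML_CALLOC rather than GGML_MALLOC, and header.n_kv is reset to zero and re-incremented only after a pair has been read successfully. Zeroed memory plus an accurate count means a truncated or malformed file leaves only initialized entries for gguf_free to walk. The sketch below illustrates, with hypothetical names, why a calloc-style call is the safer shape for allocations sized by untrusted header fields: it zero-fills and it cannot silently wrap the n*size product.

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

/* hypothetical checked allocator, shown only to illustrate the idea */
static void * checked_calloc(size_t n, size_t size) {
    if (n == 0 || size == 0) {
        return NULL;                       /* nothing to allocate          */
    }
    if (n > SIZE_MAX / size) {             /* n*size would overflow:       */
        fprintf(stderr, "refusing %zu x %zu element allocation\n", n, size);
        return NULL;                       /* malloc(n*size) would wrap    */
    }
    void * p = calloc(n, size);            /* zero-filled on success       */
    if (p == NULL) {
        fprintf(stderr, "out of memory for %zu x %zu bytes\n", n, size);
    }
    return p;
}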
@@ -20759,7 +22097,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20759
22097
  return NULL;
20760
22098
  }
20761
22099
 
20762
- kv->value.arr.data = GGML_MALLOC(kv->value.arr.n * gguf_type_size(kv->value.arr.type));
22100
+ kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, gguf_type_size(kv->value.arr.type));
20763
22101
 
20764
22102
  ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type), &offset);
20765
22103
  } break;
@@ -20773,7 +22111,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20773
22111
  return NULL;
20774
22112
  }
20775
22113
 
20776
- kv->value.arr.data = GGML_MALLOC(kv->value.arr.n * sizeof(struct gguf_str));
22114
+ kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, sizeof(struct gguf_str));
20777
22115
 
20778
22116
  for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
20779
22117
  ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset);
@@ -20789,6 +22127,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20789
22127
  if (!ok) {
20790
22128
  break;
20791
22129
  }
22130
+
22131
+ ctx->header.n_kv++;
20792
22132
  }
20793
22133
 
20794
22134
  if (!ok) {
@@ -20800,8 +22140,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20800
22140
  }
20801
22141
 
20802
22142
  // read the tensor infos
20803
- {
20804
- ctx->infos = GGML_MALLOC(ctx->header.n_tensors * sizeof(struct gguf_tensor_info));
22143
+ if (ctx->header.n_tensors > 0) {
22144
+ ctx->infos = GGML_CALLOC(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));
20805
22145
 
20806
22146
  for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
20807
22147
  struct gguf_tensor_info * info = &ctx->infos[i];
@@ -20822,8 +22162,17 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
20822
22162
  ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset);
20823
22163
  ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset);
20824
22164
 
22165
+ // TODO: return an error instead of crashing with GGML_ASSERT
20825
22166
  gguf_tensor_info_sanitize(info);
20826
22167
 
22168
+ // make sure there are no duplicated tensor names
22169
+ for (uint64_t j = 0; j < i; ++j) {
22170
+ if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) {
22171
+ fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data);
22172
+ ok = false;
22173
+ }
22174
+ }
22175
+
20827
22176
  if (!ok) {
20828
22177
  fprintf(stderr, "%s: failed to read tensor info\n", __func__);
20829
22178
  fclose(file);
@@ -20992,7 +22341,7 @@ void gguf_free(struct gguf_context * ctx) {
20992
22341
  GGML_FREE(ctx->infos);
20993
22342
  }
20994
22343
 
20995
- GGML_ALIGNED_FREE(ctx);
22344
+ GGML_FREE(ctx);
20996
22345
  }
20997
22346
 
20998
22347
  const char * gguf_type_name(enum gguf_type type) {
@@ -21303,7 +22652,7 @@ void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_ty
21303
22652
  ctx->kv[idx].type = GGUF_TYPE_ARRAY;
21304
22653
  ctx->kv[idx].value.arr.type = type;
21305
22654
  ctx->kv[idx].value.arr.n = n;
21306
- ctx->kv[idx].value.arr.data = GGML_MALLOC(n*gguf_type_size(type));
22655
+ ctx->kv[idx].value.arr.data = GGML_CALLOC(n, gguf_type_size(type));
21307
22656
  memcpy(ctx->kv[idx].value.arr.data, data, n*gguf_type_size(type));
21308
22657
  }
21309
22658
 
@@ -21313,7 +22662,7 @@ void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char **
21313
22662
  ctx->kv[idx].type = GGUF_TYPE_ARRAY;
21314
22663
  ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING;
21315
22664
  ctx->kv[idx].value.arr.n = n;
21316
- ctx->kv[idx].value.arr.data = GGML_MALLOC(n*sizeof(struct gguf_str));
22665
+ ctx->kv[idx].value.arr.data = GGML_CALLOC(n, sizeof(struct gguf_str));
21317
22666
  for (int i = 0; i < n; i++) {
21318
22667
  struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i];
21319
22668
  str->n = strlen(data[i]);
@@ -21340,7 +22689,7 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
21340
22689
  case GGUF_TYPE_ARRAY:
21341
22690
  {
21342
22691
  if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) {
21343
- const char ** data = GGML_MALLOC(src->kv[i].value.arr.n*sizeof(char *));
22692
+ const char ** data = GGML_CALLOC(src->kv[i].value.arr.n, sizeof(char *));
21344
22693
  for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) {
21345
22694
  data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data;
21346
22695
  }
@@ -21360,6 +22709,10 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
21360
22709
  void gguf_add_tensor(
21361
22710
  struct gguf_context * ctx,
21362
22711
  const struct ggml_tensor * tensor) {
22712
+ if (gguf_find_tensor(ctx, tensor->name) != -1) {
22713
+ GGML_ASSERT(false && "duplicated tensor name");
22714
+ }
22715
+
21363
22716
  const int idx = ctx->header.n_tensors;
21364
22717
  ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info));
21365
22718
 
@@ -21428,7 +22781,7 @@ struct gguf_buf {
21428
22781
 
21429
22782
  static struct gguf_buf gguf_buf_init(size_t size) {
21430
22783
  struct gguf_buf buf = {
21431
- /*buf.data =*/ size == 0 ? NULL : GGML_MALLOC(size),
22784
+ /*buf.data =*/ size == 0 ? NULL : GGML_CALLOC(1, size),
21432
22785
  /*buf.size =*/ size,
21433
22786
  /*buf.offset =*/ 0,
21434
22787
  };