cui-llama.rn 1.4.6 → 1.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. package/android/src/main/CMakeLists.txt +9 -2
  2. package/android/src/main/jni.cpp +52 -34
  3. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  11. package/cpp/binary-ops.cpp +158 -0
  12. package/cpp/binary-ops.h +16 -0
  13. package/cpp/chat.cpp +1769 -1779
  14. package/cpp/chat.h +9 -1
  15. package/cpp/common.cpp +20 -522
  16. package/cpp/common.h +13 -36
  17. package/cpp/cpu-common.h +72 -0
  18. package/cpp/ggml-common.h +12 -6
  19. package/cpp/ggml-cpu-aarch64.cpp +1557 -80
  20. package/cpp/ggml-cpu-impl.h +2 -21
  21. package/cpp/ggml-cpu-quants.c +904 -405
  22. package/cpp/ggml-cpu.c +909 -13237
  23. package/cpp/ggml-impl.h +50 -23
  24. package/cpp/ggml-metal-impl.h +77 -3
  25. package/cpp/ggml-metal.m +794 -580
  26. package/cpp/ggml.c +92 -3
  27. package/cpp/ggml.h +29 -5
  28. package/cpp/gguf.cpp +1 -0
  29. package/cpp/llama-adapter.cpp +55 -20
  30. package/cpp/llama-adapter.h +11 -9
  31. package/cpp/llama-arch.cpp +217 -16
  32. package/cpp/llama-arch.h +25 -0
  33. package/cpp/llama-batch.h +2 -2
  34. package/cpp/llama-chat.cpp +54 -2
  35. package/cpp/llama-chat.h +3 -0
  36. package/cpp/llama-context.cpp +2294 -1238
  37. package/cpp/llama-context.h +214 -77
  38. package/cpp/llama-cparams.h +1 -0
  39. package/cpp/llama-graph.cpp +1695 -0
  40. package/cpp/llama-graph.h +592 -0
  41. package/cpp/llama-hparams.cpp +8 -0
  42. package/cpp/llama-hparams.h +17 -0
  43. package/cpp/llama-io.cpp +15 -0
  44. package/cpp/llama-io.h +35 -0
  45. package/cpp/llama-kv-cache.cpp +965 -303
  46. package/cpp/llama-kv-cache.h +145 -151
  47. package/cpp/llama-memory.cpp +1 -0
  48. package/cpp/llama-memory.h +21 -0
  49. package/cpp/llama-mmap.cpp +1 -1
  50. package/cpp/llama-model-loader.cpp +10 -5
  51. package/cpp/llama-model-loader.h +5 -3
  52. package/cpp/llama-model.cpp +9194 -201
  53. package/cpp/llama-model.h +40 -1
  54. package/cpp/llama-sampling.cpp +5 -0
  55. package/cpp/llama-vocab.cpp +36 -5
  56. package/cpp/llama.cpp +51 -9984
  57. package/cpp/llama.h +102 -22
  58. package/cpp/log.cpp +34 -0
  59. package/cpp/minja/chat-template.hpp +15 -7
  60. package/cpp/minja/minja.hpp +120 -94
  61. package/cpp/ops.cpp +8723 -0
  62. package/cpp/ops.h +128 -0
  63. package/cpp/rn-llama.cpp +44 -53
  64. package/cpp/rn-llama.h +2 -12
  65. package/cpp/sampling.cpp +3 -0
  66. package/cpp/sgemm.cpp +533 -88
  67. package/cpp/simd-mappings.h +888 -0
  68. package/cpp/speculative.cpp +4 -4
  69. package/cpp/unary-ops.cpp +186 -0
  70. package/cpp/unary-ops.h +28 -0
  71. package/cpp/vec.cpp +258 -0
  72. package/cpp/vec.h +802 -0
  73. package/ios/CMakeLists.txt +5 -2
  74. package/ios/RNLlama.mm +2 -2
  75. package/ios/RNLlamaContext.mm +40 -24
  76. package/package.json +1 -1
  77. package/src/NativeRNLlama.ts +6 -4
  78. package/src/index.ts +3 -1
  79. package/cpp/chat-template.hpp +0 -529
  80. package/cpp/minja.hpp +0 -2915

package/android/src/main/CMakeLists.txt
@@ -19,6 +19,10 @@ set(
   ${RNLLAMA_LIB_DIR}/ggml-alloc.c
   ${RNLLAMA_LIB_DIR}/ggml-backend.cpp
   ${RNLLAMA_LIB_DIR}/ggml-backend-reg.cpp
+  ${RNLLAMA_LIB_DIR}/ops.cpp
+  ${RNLLAMA_LIB_DIR}/unary-ops.cpp
+  ${RNLLAMA_LIB_DIR}/binary-ops.cpp
+  ${RNLLAMA_LIB_DIR}/vec.cpp
   ${RNLLAMA_LIB_DIR}/ggml-cpu.c
   ${RNLLAMA_LIB_DIR}/ggml-cpu.cpp
   ${RNLLAMA_LIB_DIR}/ggml-cpu-aarch64.cpp
@@ -46,16 +50,19 @@ set(
   ${RNLLAMA_LIB_DIR}/llama-model-loader.cpp
   ${RNLLAMA_LIB_DIR}/llama-mmap.cpp
   ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
+  ${RNLLAMA_LIB_DIR}/llama-memory.cpp
+  ${RNLLAMA_LIB_DIR}/llama-io.cpp
+  ${RNLLAMA_LIB_DIR}/llama-graph.cpp
   ${RNLLAMA_LIB_DIR}/sampling.cpp
   ${RNLLAMA_LIB_DIR}/unicode-data.cpp
   ${RNLLAMA_LIB_DIR}/unicode.cpp
   ${RNLLAMA_LIB_DIR}/sgemm.cpp
   ${RNLLAMA_LIB_DIR}/common.cpp
   ${RNLLAMA_LIB_DIR}/chat.cpp
-  ${RNLLAMA_LIB_DIR}/minja/chat-template.hpp
   ${RNLLAMA_LIB_DIR}/json-schema-to-grammar.cpp
-  ${RNLLAMA_LIB_DIR}/minja/minja.hpp
   ${RNLLAMA_LIB_DIR}/json.hpp
+  ${RNLLAMA_LIB_DIR}/minja/minja.hpp
+  ${RNLLAMA_LIB_DIR}/minja/chat-template.hpp
   ${RNLLAMA_LIB_DIR}/rn-llama.cpp
   ${CMAKE_SOURCE_DIR}/jni-utils.h
   ${CMAKE_SOURCE_DIR}/jni.cpp

package/android/src/main/jni.cpp
@@ -264,7 +264,7 @@ Java_com_rnllama_LlamaContext_initContext(
     }
 
     const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
-    defaultParams.model = model_path_chars;
+    defaultParams.model = { model_path_chars };
 
     const char *chat_template_chars = env->GetStringUTFChars(chat_template, nullptr);
     defaultParams.chat_template = chat_template_chars;
@@ -446,7 +446,7 @@ Java_com_rnllama_LlamaContext_loadModelDetails(
 
     auto default_caps = createWriteableMap(env);
 
-    auto default_tmpl = llama -> templates -> template_default.get();
+    auto default_tmpl = llama->templates.get()->template_default.get();
     auto default_tmpl_caps = default_tmpl->original_caps();
     putBoolean(env, default_caps, "tools", default_tmpl_caps.supports_tools);
     putBoolean(env, default_caps, "toolCalls", default_tmpl_caps.supports_tool_calls);
@@ -457,7 +457,7 @@ Java_com_rnllama_LlamaContext_loadModelDetails(
     putMap(env, minja, "defaultCaps", default_caps);
 
     putBoolean(env, minja, "toolUse", llama->validateModelChatTemplate(true, "tool_use"));
-    auto tool_use_tmpl = llama-> templates -> template_tool_use.get();
+    auto tool_use_tmpl = llama->templates.get()->template_tool_use.get();
     if (tool_use_tmpl != nullptr) {
         auto tool_use_caps = createWriteableMap(env);
         auto tool_use_tmpl_caps = tool_use_tmpl->original_caps();
@@ -518,9 +518,9 @@ Java_com_rnllama_LlamaContext_getFormattedChatWithJinja(
     auto grammar_triggers = createWritableArray(env);
     for (const auto &trigger : formatted.grammar_triggers) {
         auto trigger_map = createWriteableMap(env);
-
-        putString(env, trigger_map, "word", trigger.value.c_str());
-        putBoolean(env, trigger_map, "at_start", trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START);
+        putInt(env, trigger_map, "type", trigger.type);
+        putString(env, trigger_map, "value", trigger.value.c_str());
+        putInt(env, trigger_map, "token", trigger.token);
         pushMap(env, grammar_triggers, trigger_map);
     }
     putArray(env, result, "grammar_triggers", grammar_triggers);
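
Note: the trigger map written here mirrors the fields of llama.cpp's common_grammar_trigger, whose declaration is not part of this diff; a sketch reconstructed from the usage above, for reference only:

    // Reconstructed for reference; field names inferred from trigger.type / trigger.value /
    // trigger.token as used in the hunk above. Not part of this diff.
    struct common_grammar_trigger {
        common_grammar_trigger_type type;   // e.g. COMMON_GRAMMAR_TRIGGER_TYPE_WORD or _TOKEN
        std::string                 value;  // the trigger text
        llama_token                 token;  // only meaningful for token-type triggers
    };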
@@ -734,23 +734,53 @@ Java_com_rnllama_LlamaContext_doCompletion(
         sparams.grammar = grammar_chars;
     }
     sparams.grammar_lazy = grammar_lazy;
+
+    if (preserved_tokens != nullptr) {
+        int preserved_tokens_size = readablearray::size(env, preserved_tokens);
+        for (int i = 0; i < preserved_tokens_size; i++) {
+            jstring preserved_token = readablearray::getString(env, preserved_tokens, i);
+            auto ids = common_tokenize(llama->ctx, env->GetStringUTFChars(preserved_token, nullptr), /* add_special= */ false, /* parse_special= */ true);
+            if (ids.size() == 1) {
+                sparams.preserved_tokens.insert(ids[0]);
+            } else {
+                LOGI("[RNLlama] Not preserved because more than 1 token (wrong chat template override?): %s", env->GetStringUTFChars(preserved_token, nullptr));
+            }
+        }
+    }
+
     if (grammar_triggers != nullptr) {
         int grammar_triggers_size = readablearray::size(env, grammar_triggers);
         for (int i = 0; i < grammar_triggers_size; i++) {
-            common_grammar_trigger trigger;
             auto trigger_map = readablearray::getMap(env, grammar_triggers, i);
-            jstring trigger_word = readablemap::getString(env, trigger_map, "word", nullptr);
-            jboolean trigger_at_start = readablemap::getBool(env, trigger_map, "at_start", false);
-            trigger.value = env->GetStringUTFChars(trigger_word, nullptr);
-            trigger.type = trigger_at_start ? COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START : COMMON_GRAMMAR_TRIGGER_TYPE_WORD;
-
-            auto ids = common_tokenize(llama->ctx, trigger.value, /* add_special= */ false, /* parse_special= */ true);
-            if (ids.size() == 1) {
-                sparams.grammar_triggers.push_back(trigger);
-                sparams.preserved_tokens.insert(ids[0]);
-                continue;
+            const auto type = static_cast<common_grammar_trigger_type>(readablemap::getInt(env, trigger_map, "type", 0));
+            jstring trigger_word = readablemap::getString(env, trigger_map, "value", nullptr);
+            auto word = env->GetStringUTFChars(trigger_word, nullptr);
+
+            if (type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
+                auto ids = common_tokenize(llama->ctx, word, /* add_special= */ false, /* parse_special= */ true);
+                if (ids.size() == 1) {
+                    auto token = ids[0];
+                    if (std::find(sparams.preserved_tokens.begin(), sparams.preserved_tokens.end(), (llama_token) token) == sparams.preserved_tokens.end()) {
+                        throw std::runtime_error("Grammar trigger word should be marked as preserved token");
+                    }
+                    common_grammar_trigger trigger;
+                    trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
+                    trigger.value = word;
+                    trigger.token = token;
+                    sparams.grammar_triggers.push_back(std::move(trigger));
+                } else {
+                    sparams.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
+                }
+            } else {
+                common_grammar_trigger trigger;
+                trigger.type = type;
+                trigger.value = word;
+                if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
+                    const auto token = (llama_token) readablemap::getInt(env, trigger_map, "token", 0);
+                    trigger.token = token;
+                }
+                sparams.grammar_triggers.push_back(std::move(trigger));
             }
-            sparams.grammar_triggers.push_back(trigger);
         }
     }
 
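
In short, preserved tokens are now registered before the triggers are parsed, and a word trigger that tokenizes to a single preserved token is promoted to a token trigger. A minimal sketch of that flow, using only names visible in the hunk above (the "<tool_call>" trigger word is a hypothetical example):

    // Hypothetical illustration; assumes sparams and llama->ctx as set up in the hunk above.
    auto ids = common_tokenize(llama->ctx, "<tool_call>", /* add_special= */ false, /* parse_special= */ true);
    bool preserved = ids.size() == 1 &&
        std::find(sparams.preserved_tokens.begin(), sparams.preserved_tokens.end(), ids[0]) != sparams.preserved_tokens.end();
    if (preserved) {
        common_grammar_trigger trigger;
        trigger.type  = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;   // promoted from a WORD trigger
        trigger.value = "<tool_call>";
        trigger.token = ids[0];
        sparams.grammar_triggers.push_back(std::move(trigger));
    } else {
        // Falls back to a plain word trigger with no token id attached.
        sparams.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<tool_call>"});
    }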
@@ -761,18 +791,6 @@ Java_com_rnllama_LlamaContext_doCompletion(
     }
     env->ReleaseStringUTFChars(json_schema, json_schema_chars);
 
-    if (preserved_tokens != nullptr) {
-        int preserved_tokens_size = readablearray::size(env, preserved_tokens);
-        for (int i = 0; i < preserved_tokens_size; i++) {
-            jstring preserved_token = readablearray::getString(env, preserved_tokens, i);
-            auto ids = common_tokenize(llama->ctx, env->GetStringUTFChars(preserved_token, nullptr), /* add_special= */ false, /* parse_special= */ true);
-            if (ids.size() == 1) {
-                sparams.preserved_tokens.insert(ids[0]);
-            } else {
-                LOGI("[RNLlama] Not preserved because more than 1 token (wrong chat template override?): %s", env->GetStringUTFChars(preserved_token, nullptr));
-            }
-        }
-    }
 
     const llama_model * model = llama_get_model(llama->ctx);
     const llama_vocab * vocab = llama_model_get_vocab(model);
@@ -904,7 +922,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
 
     auto toolCalls = createWritableArray(env);
     std::string reasoningContent = "";
-    std::string *content = nullptr;
+    std::string content;
     auto toolCallsSize = 0;
     if (!llama->is_interrupted) {
         try {
@@ -912,7 +930,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
             if (!message.reasoning_content.empty()) {
                 reasoningContent = message.reasoning_content;
             }
-            content = &message.content;
+            content = message.content;
             for (const auto &tc : message.tool_calls) {
                 auto toolCall = createWriteableMap(env);
                 putString(env, toolCall, "type", "function");
@@ -933,8 +951,8 @@ Java_com_rnllama_LlamaContext_doCompletion(
 
     auto result = createWriteableMap(env);
     putString(env, result, "text", llama->generated_text.c_str());
-    if (content) {
-        putString(env, result, "content", content->c_str());
+    if (!content.empty()) {
+        putString(env, result, "content", content.c_str());
     }
     if (!reasoningContent.empty()) {
         putString(env, result, "reasoning_content", reasoningContent.c_str());

package/cpp/binary-ops.cpp (new file)
@@ -0,0 +1,158 @@
+#include "binary-ops.h"
+
+#if defined(LM_GGML_USE_ACCELERATE)
+#include <Accelerate/Accelerate.h>
+
+using vDSP_fn_t = void (*)(const float *, vDSP_Stride, const float *, vDSP_Stride, float *, vDSP_Stride, vDSP_Length);
+#endif
+
+static inline float op_add(float a, float b) {
+    return a + b;
+}
+
+static inline float op_sub(float a, float b) {
+    return a - b;
+}
+
+static inline float op_mul(float a, float b) {
+    return a * b;
+}
+
+static inline float op_div(float a, float b) {
+    return a / b;
+}
+
+template <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>
+static inline void vec_binary_op_contiguous(const int64_t n, dst_t * z, const src0_t * x, const src1_t * y) {
+    constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
+    constexpr auto src1_to_f32 = type_conversion_table<src1_t>::to_f32;
+    constexpr auto f32_to_dst  = type_conversion_table<dst_t >::from_f32;
+
+    for (int i = 0; i < n; i++) {
+        z[i] = f32_to_dst(op(src0_to_f32(x[i]), src1_to_f32(y[i])));
+    }
+}
+
+template <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>
+static inline void vec_binary_op_non_contiguous(const int64_t n, const int64_t ne10, const int64_t nb10, dst_t * z, const src0_t * x, const src1_t * y) {
+    constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
+    constexpr auto src1_to_f32 = type_conversion_table<src1_t>::to_f32;
+    constexpr auto f32_to_dst  = type_conversion_table<dst_t >::from_f32;
+
+    for (int i = 0; i < n; i++) {
+        int i10 = i % ne10;
+        const src1_t * y_ptr = (const src1_t *)((const char *)y + i10*nb10);
+        z[i] = f32_to_dst(op(src0_to_f32(x[i]), src1_to_f32(*y_ptr)));
+    }
+}
+
+template <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>
+static void apply_binary_op(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
+    const lm_ggml_tensor * src0 = dst->src[0];
+    const lm_ggml_tensor * src1 = dst->src[1];
+
+    LM_GGML_ASSERT(lm_ggml_can_repeat(src1, src0) && lm_ggml_are_same_shape(src0, dst));
+
+    LM_GGML_TENSOR_BINARY_OP_LOCALS
+
+    LM_GGML_ASSERT( nb0 == sizeof(dst_t));
+    LM_GGML_ASSERT(nb00 == sizeof(src0_t));
+
+    const auto [ir0, ir1] = get_thread_range(params, src0);
+    const bool is_src1_contiguous = (nb10 == sizeof(src1_t));
+
+    if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous
+        LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1));
+    }
+
+#ifdef LM_GGML_USE_ACCELERATE
+    vDSP_fn_t vDSP_op = nullptr;
+    // TODO - avoid the f32-only check using type 'trait' lookup tables and row-based src-to-float conversion functions
+    if (src0->type == LM_GGML_TYPE_F32 && src1->type == LM_GGML_TYPE_F32 && dst->type == LM_GGML_TYPE_F32) {
+        if (op == op_add) {
+            vDSP_op = vDSP_vadd;
+        } else if (op == op_sub) {
+            vDSP_op = vDSP_vsub;
+        } else if (op == op_mul) {
+            vDSP_op = vDSP_vmul;
+        } else if (op == op_div) {
+            vDSP_op = vDSP_vdiv;
+        }
+    }
+#endif
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        const int64_t i13 = i03 % ne13;
+        const int64_t i12 = i02 % ne12;
+        const int64_t i11 = i01 % ne11;
+
+        dst_t        * dst_ptr  = (dst_t        *) ((char *)       dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+        const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+        const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+        if (is_src1_contiguous) {
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t nr0 = ne00 / ne10;
+
+            for (int64_t r = 0; r < nr0; ++r) {
+#ifdef LM_GGML_USE_ACCELERATE
+                if constexpr (std::is_same_v<src0_t, float> && std::is_same_v<src1_t, float> && std::is_same_v<dst_t, float>) {
+                    if (vDSP_op != nullptr) {
+                        vDSP_op(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
+                        continue;
+                    }
+                }
+#endif
+                vec_binary_op_contiguous<op>(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+            }
+        } else {
+            vec_binary_op_non_contiguous<op>(ne0, ne10, nb10, dst_ptr, src0_ptr, src1_ptr);
+        }
+    }
+}
+
+// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
+template <float (*op)(float, float)>
+static void binary_op(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
+    const lm_ggml_tensor * src0 = dst->src[0];
+    const lm_ggml_tensor * src1 = dst->src[1];
+
+    /*  */ if (src0->type == LM_GGML_TYPE_F32 && src1->type == LM_GGML_TYPE_F32 && dst->type == LM_GGML_TYPE_F32) { // all f32
+        apply_binary_op<op, float, float, float>(params, dst);
+    } else if (src0->type == LM_GGML_TYPE_F16 && src1->type == LM_GGML_TYPE_F16 && dst->type == LM_GGML_TYPE_F16) { // all f16
+        apply_binary_op<op, lm_ggml_fp16_t, lm_ggml_fp16_t, lm_ggml_fp16_t>(params, dst);
+    } else if (src0->type == LM_GGML_TYPE_BF16 && src1->type == LM_GGML_TYPE_BF16 && dst->type == LM_GGML_TYPE_BF16) { // all bf16
+        apply_binary_op<op, lm_ggml_bf16_t, lm_ggml_bf16_t, lm_ggml_bf16_t>(params, dst);
+    } else if (src0->type == LM_GGML_TYPE_BF16 && src1->type == LM_GGML_TYPE_F32 && dst->type == LM_GGML_TYPE_BF16) {
+        apply_binary_op<op, lm_ggml_bf16_t, float, lm_ggml_bf16_t>(params, dst);
+    } else if (src0->type == LM_GGML_TYPE_BF16 && src1->type == LM_GGML_TYPE_F32 && dst->type == LM_GGML_TYPE_F32) {
+        apply_binary_op<op, lm_ggml_bf16_t, float, float>(params, dst);
+    } else if (src0->type == LM_GGML_TYPE_F16 && src1->type == LM_GGML_TYPE_F32 && dst->type == LM_GGML_TYPE_F16) {
+        apply_binary_op<op, lm_ggml_fp16_t, float, lm_ggml_fp16_t>(params, dst);
+    } else if (src0->type == LM_GGML_TYPE_F16 && src1->type == LM_GGML_TYPE_F32 && dst->type == LM_GGML_TYPE_F32) {
+        apply_binary_op<op, lm_ggml_fp16_t, float, float>(params, dst);
+    } else {
+        LM_GGML_ABORT("%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+            lm_ggml_type_name(dst->type), lm_ggml_type_name(src0->type), lm_ggml_type_name(src1->type));
+    }
+}
+
+void lm_ggml_compute_forward_add_non_quantized(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
+    binary_op<op_add>(params, dst);
+}
+
+void lm_ggml_compute_forward_sub(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
+    binary_op<op_sub>(params, dst);
+}
+
+void lm_ggml_compute_forward_mul(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
+    binary_op<op_mul>(params, dst);
+}
+
+void lm_ggml_compute_forward_div(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
+    binary_op<op_div>(params, dst);
+}

package/cpp/binary-ops.h (new file)
@@ -0,0 +1,16 @@
+#pragma once
+
+#include "cpu-common.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void lm_ggml_compute_forward_add_non_quantized(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
+void lm_ggml_compute_forward_sub(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
+void lm_ggml_compute_forward_mul(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
+void lm_ggml_compute_forward_div(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
+
+#ifdef __cplusplus
+}
+#endif
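
Note: every element-wise binary op in the new binary-ops.cpp goes through the same templated binary_op dispatcher, so adding another op is mechanical. A hypothetical sketch (op_max is not part of this diff or of the exported header above):

    // Hypothetical extension following the pattern of binary-ops.cpp; not part of this diff.
    static inline float op_max(float a, float b) {
        return a > b ? a : b;
    }

    void lm_ggml_compute_forward_max_example(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
        binary_op<op_max>(params, dst);  // reuses the type dispatch in binary_op<>
    }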