cui-llama.rn 1.4.6 → 1.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +9 -2
- package/android/src/main/jni.cpp +52 -34
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/cpp/binary-ops.cpp +158 -0
- package/cpp/binary-ops.h +16 -0
- package/cpp/chat.cpp +1769 -1779
- package/cpp/chat.h +9 -1
- package/cpp/common.cpp +20 -522
- package/cpp/common.h +13 -36
- package/cpp/cpu-common.h +72 -0
- package/cpp/ggml-common.h +12 -6
- package/cpp/ggml-cpu-aarch64.cpp +1557 -80
- package/cpp/ggml-cpu-impl.h +2 -21
- package/cpp/ggml-cpu-quants.c +904 -405
- package/cpp/ggml-cpu.c +909 -13237
- package/cpp/ggml-impl.h +50 -23
- package/cpp/ggml-metal-impl.h +77 -3
- package/cpp/ggml-metal.m +794 -580
- package/cpp/ggml.c +92 -3
- package/cpp/ggml.h +29 -5
- package/cpp/gguf.cpp +1 -0
- package/cpp/llama-adapter.cpp +55 -20
- package/cpp/llama-adapter.h +11 -9
- package/cpp/llama-arch.cpp +217 -16
- package/cpp/llama-arch.h +25 -0
- package/cpp/llama-batch.h +2 -2
- package/cpp/llama-chat.cpp +54 -2
- package/cpp/llama-chat.h +3 -0
- package/cpp/llama-context.cpp +2294 -1238
- package/cpp/llama-context.h +214 -77
- package/cpp/llama-cparams.h +1 -0
- package/cpp/llama-graph.cpp +1695 -0
- package/cpp/llama-graph.h +592 -0
- package/cpp/llama-hparams.cpp +8 -0
- package/cpp/llama-hparams.h +17 -0
- package/cpp/llama-io.cpp +15 -0
- package/cpp/llama-io.h +35 -0
- package/cpp/llama-kv-cache.cpp +965 -303
- package/cpp/llama-kv-cache.h +145 -151
- package/cpp/llama-memory.cpp +1 -0
- package/cpp/llama-memory.h +21 -0
- package/cpp/llama-mmap.cpp +1 -1
- package/cpp/llama-model-loader.cpp +10 -5
- package/cpp/llama-model-loader.h +5 -3
- package/cpp/llama-model.cpp +9194 -201
- package/cpp/llama-model.h +40 -1
- package/cpp/llama-sampling.cpp +5 -0
- package/cpp/llama-vocab.cpp +36 -5
- package/cpp/llama.cpp +51 -9984
- package/cpp/llama.h +102 -22
- package/cpp/log.cpp +34 -0
- package/cpp/minja/chat-template.hpp +15 -7
- package/cpp/minja/minja.hpp +120 -94
- package/cpp/ops.cpp +8723 -0
- package/cpp/ops.h +128 -0
- package/cpp/rn-llama.cpp +44 -53
- package/cpp/rn-llama.h +2 -12
- package/cpp/sampling.cpp +3 -0
- package/cpp/sgemm.cpp +533 -88
- package/cpp/simd-mappings.h +888 -0
- package/cpp/speculative.cpp +4 -4
- package/cpp/unary-ops.cpp +186 -0
- package/cpp/unary-ops.h +28 -0
- package/cpp/vec.cpp +258 -0
- package/cpp/vec.h +802 -0
- package/ios/CMakeLists.txt +5 -2
- package/ios/RNLlama.mm +2 -2
- package/ios/RNLlamaContext.mm +40 -24
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +6 -4
- package/src/index.ts +3 -1
- package/cpp/chat-template.hpp +0 -529
- package/cpp/minja.hpp +0 -2915
@@ -19,6 +19,10 @@ set(
|
|
19
19
|
${RNLLAMA_LIB_DIR}/ggml-alloc.c
|
20
20
|
${RNLLAMA_LIB_DIR}/ggml-backend.cpp
|
21
21
|
${RNLLAMA_LIB_DIR}/ggml-backend-reg.cpp
|
22
|
+
${RNLLAMA_LIB_DIR}/ops.cpp
|
23
|
+
${RNLLAMA_LIB_DIR}/unary-ops.cpp
|
24
|
+
${RNLLAMA_LIB_DIR}/binary-ops.cpp
|
25
|
+
${RNLLAMA_LIB_DIR}/vec.cpp
|
22
26
|
${RNLLAMA_LIB_DIR}/ggml-cpu.c
|
23
27
|
${RNLLAMA_LIB_DIR}/ggml-cpu.cpp
|
24
28
|
${RNLLAMA_LIB_DIR}/ggml-cpu-aarch64.cpp
|
@@ -46,16 +50,19 @@ set(
|
|
46
50
|
${RNLLAMA_LIB_DIR}/llama-model-loader.cpp
|
47
51
|
${RNLLAMA_LIB_DIR}/llama-mmap.cpp
|
48
52
|
${RNLLAMA_LIB_DIR}/llama-vocab.cpp
|
53
|
+
${RNLLAMA_LIB_DIR}/llama-memory.cpp
|
54
|
+
${RNLLAMA_LIB_DIR}/llama-io.cpp
|
55
|
+
${RNLLAMA_LIB_DIR}/llama-graph.cpp
|
49
56
|
${RNLLAMA_LIB_DIR}/sampling.cpp
|
50
57
|
${RNLLAMA_LIB_DIR}/unicode-data.cpp
|
51
58
|
${RNLLAMA_LIB_DIR}/unicode.cpp
|
52
59
|
${RNLLAMA_LIB_DIR}/sgemm.cpp
|
53
60
|
${RNLLAMA_LIB_DIR}/common.cpp
|
54
61
|
${RNLLAMA_LIB_DIR}/chat.cpp
|
55
|
-
${RNLLAMA_LIB_DIR}/minja/chat-template.hpp
|
56
62
|
${RNLLAMA_LIB_DIR}/json-schema-to-grammar.cpp
|
57
|
-
${RNLLAMA_LIB_DIR}/minja/minja.hpp
|
58
63
|
${RNLLAMA_LIB_DIR}/json.hpp
|
64
|
+
${RNLLAMA_LIB_DIR}/minja/minja.hpp
|
65
|
+
${RNLLAMA_LIB_DIR}/minja/chat-template.hpp
|
59
66
|
${RNLLAMA_LIB_DIR}/rn-llama.cpp
|
60
67
|
${CMAKE_SOURCE_DIR}/jni-utils.h
|
61
68
|
${CMAKE_SOURCE_DIR}/jni.cpp
|
package/android/src/main/jni.cpp
CHANGED
@@ -264,7 +264,7 @@ Java_com_rnllama_LlamaContext_initContext(
|
|
264
264
|
}
|
265
265
|
|
266
266
|
const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
|
267
|
-
defaultParams.model = model_path_chars;
|
267
|
+
defaultParams.model = { model_path_chars };
|
268
268
|
|
269
269
|
const char *chat_template_chars = env->GetStringUTFChars(chat_template, nullptr);
|
270
270
|
defaultParams.chat_template = chat_template_chars;
|
@@ -446,7 +446,7 @@ Java_com_rnllama_LlamaContext_loadModelDetails(
|
|
446
446
|
|
447
447
|
auto default_caps = createWriteableMap(env);
|
448
448
|
|
449
|
-
auto default_tmpl = llama
|
449
|
+
auto default_tmpl = llama->templates.get()->template_default.get();
|
450
450
|
auto default_tmpl_caps = default_tmpl->original_caps();
|
451
451
|
putBoolean(env, default_caps, "tools", default_tmpl_caps.supports_tools);
|
452
452
|
putBoolean(env, default_caps, "toolCalls", default_tmpl_caps.supports_tool_calls);
|
@@ -457,7 +457,7 @@ Java_com_rnllama_LlamaContext_loadModelDetails(
|
|
457
457
|
putMap(env, minja, "defaultCaps", default_caps);
|
458
458
|
|
459
459
|
putBoolean(env, minja, "toolUse", llama->validateModelChatTemplate(true, "tool_use"));
|
460
|
-
auto tool_use_tmpl = llama->
|
460
|
+
auto tool_use_tmpl = llama->templates.get()->template_tool_use.get();
|
461
461
|
if (tool_use_tmpl != nullptr) {
|
462
462
|
auto tool_use_caps = createWriteableMap(env);
|
463
463
|
auto tool_use_tmpl_caps = tool_use_tmpl->original_caps();
|
@@ -518,9 +518,9 @@ Java_com_rnllama_LlamaContext_getFormattedChatWithJinja(
|
|
518
518
|
auto grammar_triggers = createWritableArray(env);
|
519
519
|
for (const auto &trigger : formatted.grammar_triggers) {
|
520
520
|
auto trigger_map = createWriteableMap(env);
|
521
|
-
|
522
|
-
putString(env, trigger_map, "
|
523
|
-
|
521
|
+
putInt(env, trigger_map, "type", trigger.type);
|
522
|
+
putString(env, trigger_map, "value", trigger.value.c_str());
|
523
|
+
putInt(env, trigger_map, "token", trigger.token);
|
524
524
|
pushMap(env, grammar_triggers, trigger_map);
|
525
525
|
}
|
526
526
|
putArray(env, result, "grammar_triggers", grammar_triggers);
|
@@ -734,23 +734,53 @@ Java_com_rnllama_LlamaContext_doCompletion(
|
|
734
734
|
sparams.grammar = grammar_chars;
|
735
735
|
}
|
736
736
|
sparams.grammar_lazy = grammar_lazy;
|
737
|
+
|
738
|
+
if (preserved_tokens != nullptr) {
|
739
|
+
int preserved_tokens_size = readablearray::size(env, preserved_tokens);
|
740
|
+
for (int i = 0; i < preserved_tokens_size; i++) {
|
741
|
+
jstring preserved_token = readablearray::getString(env, preserved_tokens, i);
|
742
|
+
auto ids = common_tokenize(llama->ctx, env->GetStringUTFChars(preserved_token, nullptr), /* add_special= */ false, /* parse_special= */ true);
|
743
|
+
if (ids.size() == 1) {
|
744
|
+
sparams.preserved_tokens.insert(ids[0]);
|
745
|
+
} else {
|
746
|
+
LOGI("[RNLlama] Not preserved because more than 1 token (wrong chat template override?): %s", env->GetStringUTFChars(preserved_token, nullptr));
|
747
|
+
}
|
748
|
+
}
|
749
|
+
}
|
750
|
+
|
737
751
|
if (grammar_triggers != nullptr) {
|
738
752
|
int grammar_triggers_size = readablearray::size(env, grammar_triggers);
|
739
753
|
for (int i = 0; i < grammar_triggers_size; i++) {
|
740
|
-
common_grammar_trigger trigger;
|
741
754
|
auto trigger_map = readablearray::getMap(env, grammar_triggers, i);
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
755
|
+
const auto type = static_cast<common_grammar_trigger_type>(readablemap::getInt(env, trigger_map, "type", 0));
|
756
|
+
jstring trigger_word = readablemap::getString(env, trigger_map, "value", nullptr);
|
757
|
+
auto word = env->GetStringUTFChars(trigger_word, nullptr);
|
758
|
+
|
759
|
+
if (type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
|
760
|
+
auto ids = common_tokenize(llama->ctx, word, /* add_special= */ false, /* parse_special= */ true);
|
761
|
+
if (ids.size() == 1) {
|
762
|
+
auto token = ids[0];
|
763
|
+
if (std::find(sparams.preserved_tokens.begin(), sparams.preserved_tokens.end(), (llama_token) token) == sparams.preserved_tokens.end()) {
|
764
|
+
throw std::runtime_error("Grammar trigger word should be marked as preserved token");
|
765
|
+
}
|
766
|
+
common_grammar_trigger trigger;
|
767
|
+
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
|
768
|
+
trigger.value = word;
|
769
|
+
trigger.token = token;
|
770
|
+
sparams.grammar_triggers.push_back(std::move(trigger));
|
771
|
+
} else {
|
772
|
+
sparams.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
|
773
|
+
}
|
774
|
+
} else {
|
775
|
+
common_grammar_trigger trigger;
|
776
|
+
trigger.type = type;
|
777
|
+
trigger.value = word;
|
778
|
+
if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
|
779
|
+
const auto token = (llama_token) readablemap::getInt(env, trigger_map, "token", 0);
|
780
|
+
trigger.token = token;
|
781
|
+
}
|
782
|
+
sparams.grammar_triggers.push_back(std::move(trigger));
|
752
783
|
}
|
753
|
-
sparams.grammar_triggers.push_back(trigger);
|
754
784
|
}
|
755
785
|
}
|
756
786
|
|
@@ -761,18 +791,6 @@ Java_com_rnllama_LlamaContext_doCompletion(
|
|
761
791
|
}
|
762
792
|
env->ReleaseStringUTFChars(json_schema, json_schema_chars);
|
763
793
|
|
764
|
-
if (preserved_tokens != nullptr) {
|
765
|
-
int preserved_tokens_size = readablearray::size(env, preserved_tokens);
|
766
|
-
for (int i = 0; i < preserved_tokens_size; i++) {
|
767
|
-
jstring preserved_token = readablearray::getString(env, preserved_tokens, i);
|
768
|
-
auto ids = common_tokenize(llama->ctx, env->GetStringUTFChars(preserved_token, nullptr), /* add_special= */ false, /* parse_special= */ true);
|
769
|
-
if (ids.size() == 1) {
|
770
|
-
sparams.preserved_tokens.insert(ids[0]);
|
771
|
-
} else {
|
772
|
-
LOGI("[RNLlama] Not preserved because more than 1 token (wrong chat template override?): %s", env->GetStringUTFChars(preserved_token, nullptr));
|
773
|
-
}
|
774
|
-
}
|
775
|
-
}
|
776
794
|
|
777
795
|
const llama_model * model = llama_get_model(llama->ctx);
|
778
796
|
const llama_vocab * vocab = llama_model_get_vocab(model);
|
@@ -904,7 +922,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
|
|
904
922
|
|
905
923
|
auto toolCalls = createWritableArray(env);
|
906
924
|
std::string reasoningContent = "";
|
907
|
-
std::string
|
925
|
+
std::string content;
|
908
926
|
auto toolCallsSize = 0;
|
909
927
|
if (!llama->is_interrupted) {
|
910
928
|
try {
|
@@ -912,7 +930,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
|
|
912
930
|
if (!message.reasoning_content.empty()) {
|
913
931
|
reasoningContent = message.reasoning_content;
|
914
932
|
}
|
915
|
-
content =
|
933
|
+
content = message.content;
|
916
934
|
for (const auto &tc : message.tool_calls) {
|
917
935
|
auto toolCall = createWriteableMap(env);
|
918
936
|
putString(env, toolCall, "type", "function");
|
@@ -933,8 +951,8 @@ Java_com_rnllama_LlamaContext_doCompletion(
|
|
933
951
|
|
934
952
|
auto result = createWriteableMap(env);
|
935
953
|
putString(env, result, "text", llama->generated_text.c_str());
|
936
|
-
if (content) {
|
937
|
-
putString(env, result, "content", content
|
954
|
+
if (!content.empty()) {
|
955
|
+
putString(env, result, "content", content.c_str());
|
938
956
|
}
|
939
957
|
if (!reasoningContent.empty()) {
|
940
958
|
putString(env, result, "reasoning_content", reasoningContent.c_str());
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
@@ -0,0 +1,158 @@
|
|
1
|
+
#include "binary-ops.h"
|
2
|
+
|
3
|
+
#if defined(LM_GGML_USE_ACCELERATE)
|
4
|
+
#include <Accelerate/Accelerate.h>
|
5
|
+
|
6
|
+
using vDSP_fn_t = void (*)(const float *, vDSP_Stride, const float *, vDSP_Stride, float *, vDSP_Stride, vDSP_Length);
|
7
|
+
#endif
|
8
|
+
|
9
|
+
static inline float op_add(float a, float b) {
|
10
|
+
return a + b;
|
11
|
+
}
|
12
|
+
|
13
|
+
static inline float op_sub(float a, float b) {
|
14
|
+
return a - b;
|
15
|
+
}
|
16
|
+
|
17
|
+
static inline float op_mul(float a, float b) {
|
18
|
+
return a * b;
|
19
|
+
}
|
20
|
+
|
21
|
+
static inline float op_div(float a, float b) {
|
22
|
+
return a / b;
|
23
|
+
}
|
24
|
+
|
25
|
+
template <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>
|
26
|
+
static inline void vec_binary_op_contiguous(const int64_t n, dst_t * z, const src0_t * x, const src1_t * y) {
|
27
|
+
constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
|
28
|
+
constexpr auto src1_to_f32 = type_conversion_table<src1_t>::to_f32;
|
29
|
+
constexpr auto f32_to_dst = type_conversion_table<dst_t >::from_f32;
|
30
|
+
|
31
|
+
for (int i = 0; i < n; i++) {
|
32
|
+
z[i] = f32_to_dst(op(src0_to_f32(x[i]), src1_to_f32(y[i])));
|
33
|
+
}
|
34
|
+
}
|
35
|
+
|
36
|
+
template <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>
|
37
|
+
static inline void vec_binary_op_non_contiguous(const int64_t n, const int64_t ne10, const int64_t nb10, dst_t * z, const src0_t * x, const src1_t * y) {
|
38
|
+
constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
|
39
|
+
constexpr auto src1_to_f32 = type_conversion_table<src1_t>::to_f32;
|
40
|
+
constexpr auto f32_to_dst = type_conversion_table<dst_t >::from_f32;
|
41
|
+
|
42
|
+
for (int i = 0; i < n; i++) {
|
43
|
+
int i10 = i % ne10;
|
44
|
+
const src1_t * y_ptr = (const src1_t *)((const char *)y + i10*nb10);
|
45
|
+
z[i] = f32_to_dst(op(src0_to_f32(x[i]), src1_to_f32(*y_ptr)));
|
46
|
+
}
|
47
|
+
}
|
48
|
+
|
49
|
+
template <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>
|
50
|
+
static void apply_binary_op(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
|
51
|
+
const lm_ggml_tensor * src0 = dst->src[0];
|
52
|
+
const lm_ggml_tensor * src1 = dst->src[1];
|
53
|
+
|
54
|
+
LM_GGML_ASSERT(lm_ggml_can_repeat(src1, src0) && lm_ggml_are_same_shape(src0, dst));
|
55
|
+
|
56
|
+
LM_GGML_TENSOR_BINARY_OP_LOCALS
|
57
|
+
|
58
|
+
LM_GGML_ASSERT( nb0 == sizeof(dst_t));
|
59
|
+
LM_GGML_ASSERT(nb00 == sizeof(src0_t));
|
60
|
+
|
61
|
+
const auto [ir0, ir1] = get_thread_range(params, src0);
|
62
|
+
const bool is_src1_contiguous = (nb10 == sizeof(src1_t));
|
63
|
+
|
64
|
+
if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous
|
65
|
+
LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1));
|
66
|
+
}
|
67
|
+
|
68
|
+
#ifdef LM_GGML_USE_ACCELERATE
|
69
|
+
vDSP_fn_t vDSP_op = nullptr;
|
70
|
+
// TODO - avoid the f32-only check using type 'trait' lookup tables and row-based src-to-float conversion functions
|
71
|
+
if (src0->type == LM_GGML_TYPE_F32 && src1->type == LM_GGML_TYPE_F32 && dst->type == LM_GGML_TYPE_F32) {
|
72
|
+
if (op == op_add) {
|
73
|
+
vDSP_op = vDSP_vadd;
|
74
|
+
} else if (op == op_sub) {
|
75
|
+
vDSP_op = vDSP_vsub;
|
76
|
+
} else if (op == op_mul) {
|
77
|
+
vDSP_op = vDSP_vmul;
|
78
|
+
} else if (op == op_div) {
|
79
|
+
vDSP_op = vDSP_vdiv;
|
80
|
+
}
|
81
|
+
}
|
82
|
+
#endif
|
83
|
+
|
84
|
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
85
|
+
const int64_t i03 = ir/(ne02*ne01);
|
86
|
+
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
87
|
+
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
88
|
+
|
89
|
+
const int64_t i13 = i03 % ne13;
|
90
|
+
const int64_t i12 = i02 % ne12;
|
91
|
+
const int64_t i11 = i01 % ne11;
|
92
|
+
|
93
|
+
dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
|
94
|
+
const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
95
|
+
const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
|
96
|
+
|
97
|
+
if (is_src1_contiguous) {
|
98
|
+
// src1 is broadcastable across src0 and dst in i1, i2, i3
|
99
|
+
const int64_t nr0 = ne00 / ne10;
|
100
|
+
|
101
|
+
for (int64_t r = 0; r < nr0; ++r) {
|
102
|
+
#ifdef LM_GGML_USE_ACCELERATE
|
103
|
+
if constexpr (std::is_same_v<src0_t, float> && std::is_same_v<src1_t, float> && std::is_same_v<dst_t, float>) {
|
104
|
+
if (vDSP_op != nullptr) {
|
105
|
+
vDSP_op(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
|
106
|
+
continue;
|
107
|
+
}
|
108
|
+
}
|
109
|
+
#endif
|
110
|
+
vec_binary_op_contiguous<op>(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
|
111
|
+
}
|
112
|
+
} else {
|
113
|
+
vec_binary_op_non_contiguous<op>(ne0, ne10, nb10, dst_ptr, src0_ptr, src1_ptr);
|
114
|
+
}
|
115
|
+
}
|
116
|
+
}
|
117
|
+
|
118
|
+
// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
|
119
|
+
template <float (*op)(float, float)>
|
120
|
+
static void binary_op(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
|
121
|
+
const lm_ggml_tensor * src0 = dst->src[0];
|
122
|
+
const lm_ggml_tensor * src1 = dst->src[1];
|
123
|
+
|
124
|
+
/* */ if (src0->type == LM_GGML_TYPE_F32 && src1->type == LM_GGML_TYPE_F32 && dst->type == LM_GGML_TYPE_F32) { // all f32
|
125
|
+
apply_binary_op<op, float, float, float>(params, dst);
|
126
|
+
} else if (src0->type == LM_GGML_TYPE_F16 && src1->type == LM_GGML_TYPE_F16 && dst->type == LM_GGML_TYPE_F16) { // all f16
|
127
|
+
apply_binary_op<op, lm_ggml_fp16_t, lm_ggml_fp16_t, lm_ggml_fp16_t>(params, dst);
|
128
|
+
} else if (src0->type == LM_GGML_TYPE_BF16 && src1->type == LM_GGML_TYPE_BF16 && dst->type == LM_GGML_TYPE_BF16) { // all bf16
|
129
|
+
apply_binary_op<op, lm_ggml_bf16_t, lm_ggml_bf16_t, lm_ggml_bf16_t>(params, dst);
|
130
|
+
} else if (src0->type == LM_GGML_TYPE_BF16 && src1->type == LM_GGML_TYPE_F32 && dst->type == LM_GGML_TYPE_BF16) {
|
131
|
+
apply_binary_op<op, lm_ggml_bf16_t, float, lm_ggml_bf16_t>(params, dst);
|
132
|
+
} else if (src0->type == LM_GGML_TYPE_BF16 && src1->type == LM_GGML_TYPE_F32 && dst->type == LM_GGML_TYPE_F32) {
|
133
|
+
apply_binary_op<op, lm_ggml_bf16_t, float, float>(params, dst);
|
134
|
+
} else if (src0->type == LM_GGML_TYPE_F16 && src1->type == LM_GGML_TYPE_F32 && dst->type == LM_GGML_TYPE_F16) {
|
135
|
+
apply_binary_op<op, lm_ggml_fp16_t, float, lm_ggml_fp16_t>(params, dst);
|
136
|
+
} else if (src0->type == LM_GGML_TYPE_F16 && src1->type == LM_GGML_TYPE_F32 && dst->type == LM_GGML_TYPE_F32) {
|
137
|
+
apply_binary_op<op, lm_ggml_fp16_t, float, float>(params, dst);
|
138
|
+
} else {
|
139
|
+
LM_GGML_ABORT("%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
|
140
|
+
lm_ggml_type_name(dst->type), lm_ggml_type_name(src0->type), lm_ggml_type_name(src1->type));
|
141
|
+
}
|
142
|
+
}
|
143
|
+
|
144
|
+
void lm_ggml_compute_forward_add_non_quantized(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
|
145
|
+
binary_op<op_add>(params, dst);
|
146
|
+
}
|
147
|
+
|
148
|
+
void lm_ggml_compute_forward_sub(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
|
149
|
+
binary_op<op_sub>(params, dst);
|
150
|
+
}
|
151
|
+
|
152
|
+
void lm_ggml_compute_forward_mul(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
|
153
|
+
binary_op<op_mul>(params, dst);
|
154
|
+
}
|
155
|
+
|
156
|
+
void lm_ggml_compute_forward_div(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
|
157
|
+
binary_op<op_div>(params, dst);
|
158
|
+
}
|
package/cpp/binary-ops.h
ADDED
@@ -0,0 +1,16 @@
|
|
1
|
+
#pragma once
|
2
|
+
|
3
|
+
#include "cpu-common.h"
|
4
|
+
|
5
|
+
#ifdef __cplusplus
|
6
|
+
extern "C" {
|
7
|
+
#endif
|
8
|
+
|
9
|
+
void lm_ggml_compute_forward_add_non_quantized(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
|
10
|
+
void lm_ggml_compute_forward_sub(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
|
11
|
+
void lm_ggml_compute_forward_mul(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
|
12
|
+
void lm_ggml_compute_forward_div(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
|
13
|
+
|
14
|
+
#ifdef __cplusplus
|
15
|
+
}
|
16
|
+
#endif
|