cactus-react-native 1.0.1 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +609 -56
- package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusCrypto.kt +23 -15
- package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusDeviceInfo.kt +12 -9
- package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusFileSystem.kt +42 -41
- package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusImage.kt +81 -0
- package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
- package/cpp/HybridCactus.cpp +161 -44
- package/cpp/HybridCactus.hpp +34 -14
- package/cpp/HybridCactusUtil.cpp +13 -11
- package/cpp/HybridCactusUtil.hpp +9 -9
- package/cpp/cactus_ffi.h +28 -1
- package/ios/HybridCactusImage.swift +53 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +28 -1
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +237 -7
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ffi_utils.h +158 -43
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +23 -2
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +52 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +28 -1
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +237 -7
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/ffi_utils.h +158 -43
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +23 -2
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +52 -0
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
- package/lib/module/api/Database.js +23 -0
- package/lib/module/api/Database.js.map +1 -1
- package/lib/module/api/RemoteLM.js +201 -0
- package/lib/module/api/RemoteLM.js.map +1 -0
- package/lib/module/classes/CactusLM.js +56 -28
- package/lib/module/classes/CactusLM.js.map +1 -1
- package/lib/module/classes/CactusSTT.js +137 -0
- package/lib/module/classes/CactusSTT.js.map +1 -0
- package/lib/module/config/CactusConfig.js +4 -0
- package/lib/module/config/CactusConfig.js.map +1 -1
- package/lib/module/constants/packageVersion.js +1 -1
- package/lib/module/hooks/useCactusLM.js +44 -16
- package/lib/module/hooks/useCactusLM.js.map +1 -1
- package/lib/module/hooks/useCactusSTT.js +234 -0
- package/lib/module/hooks/useCactusSTT.js.map +1 -0
- package/lib/module/index.js +2 -0
- package/lib/module/index.js.map +1 -1
- package/lib/module/native/Cactus.js +52 -3
- package/lib/module/native/Cactus.js.map +1 -1
- package/lib/module/native/CactusFileSystem.js +2 -3
- package/lib/module/native/CactusFileSystem.js.map +1 -1
- package/lib/module/native/CactusImage.js +13 -0
- package/lib/module/native/CactusImage.js.map +1 -0
- package/lib/module/native/index.js +1 -0
- package/lib/module/native/index.js.map +1 -1
- package/lib/module/specs/CactusImage.nitro.js +4 -0
- package/lib/module/specs/CactusImage.nitro.js.map +1 -0
- package/lib/module/telemetry/Telemetry.js +53 -1
- package/lib/module/telemetry/Telemetry.js.map +1 -1
- package/lib/module/types/CactusSTT.js +2 -0
- package/lib/module/types/CactusSTT.js.map +1 -0
- package/lib/typescript/src/api/Database.d.ts +1 -0
- package/lib/typescript/src/api/Database.d.ts.map +1 -1
- package/lib/typescript/src/api/RemoteLM.d.ts +14 -0
- package/lib/typescript/src/api/RemoteLM.d.ts.map +1 -0
- package/lib/typescript/src/classes/CactusLM.d.ts +8 -5
- package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
- package/lib/typescript/src/classes/CactusSTT.d.ts +25 -0
- package/lib/typescript/src/classes/CactusSTT.d.ts.map +1 -0
- package/lib/typescript/src/config/CactusConfig.d.ts +1 -0
- package/lib/typescript/src/config/CactusConfig.d.ts.map +1 -1
- package/lib/typescript/src/constants/packageVersion.d.ts +1 -1
- package/lib/typescript/src/hooks/useCactusLM.d.ts +5 -4
- package/lib/typescript/src/hooks/useCactusLM.d.ts.map +1 -1
- package/lib/typescript/src/hooks/useCactusSTT.d.ts +20 -0
- package/lib/typescript/src/hooks/useCactusSTT.d.ts.map +1 -0
- package/lib/typescript/src/index.d.ts +4 -1
- package/lib/typescript/src/index.d.ts.map +1 -1
- package/lib/typescript/src/native/Cactus.d.ts +10 -3
- package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
- package/lib/typescript/src/native/CactusFileSystem.d.ts +1 -1
- package/lib/typescript/src/native/CactusFileSystem.d.ts.map +1 -1
- package/lib/typescript/src/native/CactusImage.d.ts +6 -0
- package/lib/typescript/src/native/CactusImage.d.ts.map +1 -0
- package/lib/typescript/src/native/index.d.ts +1 -0
- package/lib/typescript/src/native/index.d.ts.map +1 -1
- package/lib/typescript/src/specs/Cactus.nitro.d.ts +4 -1
- package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
- package/lib/typescript/src/specs/CactusImage.nitro.d.ts +9 -0
- package/lib/typescript/src/specs/CactusImage.nitro.d.ts.map +1 -0
- package/lib/typescript/src/telemetry/Telemetry.d.ts +5 -1
- package/lib/typescript/src/telemetry/Telemetry.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusLM.d.ts +11 -6
- package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusSTT.d.ts +37 -0
- package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -0
- package/nitro.json +4 -0
- package/nitrogen/generated/android/c++/JHybridCactusImageSpec.cpp +81 -0
- package/nitrogen/generated/android/c++/JHybridCactusImageSpec.hpp +66 -0
- package/nitrogen/generated/android/cactus+autolinking.cmake +2 -0
- package/nitrogen/generated/android/cactusOnLoad.cpp +10 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusImageSpec.kt +62 -0
- package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.cpp +17 -0
- package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.hpp +17 -0
- package/nitrogen/generated/ios/Cactus-Swift-Cxx-Umbrella.hpp +5 -0
- package/nitrogen/generated/ios/CactusAutolinking.mm +8 -0
- package/nitrogen/generated/ios/CactusAutolinking.swift +15 -0
- package/nitrogen/generated/ios/c++/HybridCactusImageSpecSwift.cpp +11 -0
- package/nitrogen/generated/ios/c++/HybridCactusImageSpecSwift.hpp +85 -0
- package/nitrogen/generated/ios/swift/HybridCactusImageSpec.swift +58 -0
- package/nitrogen/generated/ios/swift/HybridCactusImageSpec_cxx.swift +158 -0
- package/nitrogen/generated/shared/c++/HybridCactusImageSpec.cpp +22 -0
- package/nitrogen/generated/shared/c++/HybridCactusImageSpec.hpp +64 -0
- package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +3 -0
- package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +4 -1
- package/package.json +1 -1
- package/src/api/Database.ts +27 -0
- package/src/api/RemoteLM.ts +273 -0
- package/src/classes/CactusLM.ts +76 -40
- package/src/classes/CactusSTT.ts +182 -0
- package/src/config/CactusConfig.ts +4 -0
- package/src/constants/packageVersion.ts +1 -1
- package/src/hooks/useCactusLM.ts +53 -22
- package/src/hooks/useCactusSTT.ts +285 -0
- package/src/index.tsx +14 -2
- package/src/native/Cactus.ts +100 -6
- package/src/native/CactusFileSystem.ts +2 -2
- package/src/native/CactusImage.ts +20 -0
- package/src/native/index.ts +1 -0
- package/src/specs/Cactus.nitro.ts +14 -1
- package/src/specs/CactusImage.nitro.ts +12 -0
- package/src/telemetry/Telemetry.ts +78 -1
- package/src/types/CactusLM.ts +12 -6
- package/src/types/CactusSTT.ts +42 -0
|
@@ -8,6 +8,10 @@
|
|
|
8
8
|
#include <stdexcept>
|
|
9
9
|
#include <sstream>
|
|
10
10
|
#include <iomanip>
|
|
11
|
+
#include <fstream>
|
|
12
|
+
#include <iostream>
|
|
13
|
+
#include <filesystem>
|
|
14
|
+
#include <cctype>
|
|
11
15
|
|
|
12
16
|
namespace cactus {
|
|
13
17
|
namespace ffi {
|
|
@@ -30,8 +34,10 @@ inline void handle_error_response(const std::string& error_message, char* respon
|
|
|
30
34
|
}
|
|
31
35
|
}
|
|
32
36
|
|
|
33
|
-
inline std::vector<cactus::engine::ChatMessage> parse_messages_json(const std::string& json
|
|
37
|
+
inline std::vector<cactus::engine::ChatMessage> parse_messages_json(const std::string& json,
|
|
38
|
+
std::vector<std::string>& out_image_paths) {
|
|
34
39
|
std::vector<cactus::engine::ChatMessage> messages;
|
|
40
|
+
out_image_paths.clear();
|
|
35
41
|
|
|
36
42
|
size_t pos = json.find('[');
|
|
37
43
|
if (pos == std::string::npos) {
|
|
@@ -42,42 +48,79 @@ inline std::vector<cactus::engine::ChatMessage> parse_messages_json(const std::s
|
|
|
42
48
|
while (pos != std::string::npos) {
|
|
43
49
|
cactus::engine::ChatMessage msg;
|
|
44
50
|
|
|
51
|
+
size_t obj_start = pos;
|
|
52
|
+
int brace_count = 1;
|
|
53
|
+
size_t obj_end = obj_start + 1;
|
|
54
|
+
while (obj_end < json.length() && brace_count > 0) {
|
|
55
|
+
if (json[obj_end] == '{') brace_count++;
|
|
56
|
+
else if (json[obj_end] == '}') brace_count--;
|
|
57
|
+
obj_end++;
|
|
58
|
+
}
|
|
59
|
+
|
|
45
60
|
size_t role_pos = json.find("\"role\"", pos);
|
|
46
|
-
if (role_pos == std::string::npos) break;
|
|
61
|
+
if (role_pos == std::string::npos || role_pos >= obj_end) break;
|
|
47
62
|
|
|
48
63
|
size_t role_start = json.find('"', role_pos + 6) + 1;
|
|
49
64
|
size_t role_end = json.find('"', role_start);
|
|
50
65
|
msg.role = json.substr(role_start, role_end - role_start);
|
|
51
66
|
|
|
52
67
|
size_t content_pos = json.find("\"content\"", role_end);
|
|
53
|
-
if (content_pos
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
68
|
+
if (content_pos != std::string::npos && content_pos < obj_end) {
|
|
69
|
+
size_t content_start = json.find('"', content_pos + 9) + 1;
|
|
70
|
+
size_t content_end = content_start;
|
|
71
|
+
|
|
72
|
+
while (content_end < json.length()) {
|
|
73
|
+
content_end = json.find('"', content_end);
|
|
74
|
+
if (content_end == std::string::npos) break;
|
|
75
|
+
if (json[content_end - 1] != '\\') break;
|
|
76
|
+
content_end++;
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
msg.content = json.substr(content_start, content_end - content_start);
|
|
80
|
+
|
|
81
|
+
size_t escape_pos = 0;
|
|
82
|
+
while ((escape_pos = msg.content.find("\\n", escape_pos)) != std::string::npos) {
|
|
83
|
+
msg.content.replace(escape_pos, 2, "\n");
|
|
84
|
+
escape_pos += 1;
|
|
85
|
+
}
|
|
86
|
+
escape_pos = 0;
|
|
87
|
+
while ((escape_pos = msg.content.find("\\\"", escape_pos)) != std::string::npos) {
|
|
88
|
+
msg.content.replace(escape_pos, 2, "\"");
|
|
89
|
+
escape_pos += 1;
|
|
90
|
+
}
|
|
63
91
|
}
|
|
64
92
|
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
93
|
+
size_t images_pos = json.find("\"images\"", pos);
|
|
94
|
+
if (images_pos != std::string::npos && images_pos < obj_end) {
|
|
95
|
+
size_t array_start = json.find('[', images_pos);
|
|
96
|
+
if (array_start != std::string::npos && array_start < obj_end) {
|
|
97
|
+
size_t array_end = json.find(']', array_start);
|
|
98
|
+
if (array_end != std::string::npos && array_end < obj_end) {
|
|
99
|
+
size_t img_pos = array_start;
|
|
100
|
+
while (true) {
|
|
101
|
+
img_pos = json.find('"', img_pos + 1);
|
|
102
|
+
if (img_pos == std::string::npos || img_pos >= array_end) break;
|
|
103
|
+
|
|
104
|
+
size_t img_start = img_pos + 1;
|
|
105
|
+
size_t img_end = json.find('"', img_start);
|
|
106
|
+
if (img_end == std::string::npos || img_end > array_end) break;
|
|
107
|
+
|
|
108
|
+
std::string img_path = json.substr(img_start, img_end - img_start);
|
|
109
|
+
|
|
110
|
+
std::filesystem::path p(img_path);
|
|
111
|
+
img_path = std::filesystem::absolute(p).string();
|
|
112
|
+
|
|
113
|
+
msg.images.push_back(img_path);
|
|
114
|
+
out_image_paths.push_back(img_path);
|
|
115
|
+
img_pos = img_end;
|
|
116
|
+
}
|
|
117
|
+
}
|
|
118
|
+
}
|
|
76
119
|
}
|
|
77
120
|
|
|
78
121
|
messages.push_back(msg);
|
|
79
122
|
|
|
80
|
-
pos = json.find('{',
|
|
123
|
+
pos = json.find('{', obj_end);
|
|
81
124
|
}
|
|
82
125
|
|
|
83
126
|
return messages;
|
|
@@ -136,10 +179,10 @@ inline void parse_options_json(const std::string& json,
|
|
|
136
179
|
float& temperature, float& top_p,
|
|
137
180
|
size_t& top_k, size_t& max_tokens,
|
|
138
181
|
std::vector<std::string>& stop_sequences) {
|
|
139
|
-
temperature =
|
|
140
|
-
top_p =
|
|
141
|
-
top_k = 0;
|
|
142
|
-
max_tokens = 100;
|
|
182
|
+
temperature = 0.0f;
|
|
183
|
+
top_p = 0.0f;
|
|
184
|
+
top_k = 0;
|
|
185
|
+
max_tokens = 100;
|
|
143
186
|
stop_sequences.clear();
|
|
144
187
|
|
|
145
188
|
if (json.empty()) return;
|
|
@@ -192,47 +235,119 @@ inline std::string format_tools_for_prompt(const std::vector<ToolFunction>& tool
|
|
|
192
235
|
std::string formatted_tools_json;
|
|
193
236
|
for (size_t i = 0; i < tools.size(); i++) {
|
|
194
237
|
if (i > 0) formatted_tools_json += ",\n";
|
|
195
|
-
formatted_tools_json += "
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
formatted_tools_json += " \"description\": \"" + tools[i].description + "\"";
|
|
238
|
+
formatted_tools_json += "{\"type\":\"function\",\"function\":{\"name\":\""
|
|
239
|
+
+ tools[i].name
|
|
240
|
+
+ "\",\"description\":\""
|
|
241
|
+
+ tools[i].description + "\"";
|
|
200
242
|
if (tools[i].parameters.find("schema") != tools[i].parameters.end()) {
|
|
201
|
-
formatted_tools_json += ",\
|
|
243
|
+
formatted_tools_json += ",\"parameters\":" + tools[i].parameters.at("schema");
|
|
202
244
|
}
|
|
203
|
-
formatted_tools_json += "
|
|
245
|
+
formatted_tools_json += "}}";
|
|
204
246
|
}
|
|
205
247
|
return formatted_tools_json;
|
|
206
248
|
}
|
|
207
249
|
|
|
208
|
-
inline void parse_function_calls_from_response(const std::string& response_text,
|
|
209
|
-
std::string& regular_response,
|
|
250
|
+
inline void parse_function_calls_from_response(const std::string& response_text,
|
|
251
|
+
std::string& regular_response,
|
|
210
252
|
std::vector<std::string>& function_calls) {
|
|
211
253
|
regular_response = response_text;
|
|
212
254
|
function_calls.clear();
|
|
213
255
|
|
|
256
|
+
const std::string TOOL_CALL_START = "<|tool_call_start|>";
|
|
257
|
+
const std::string TOOL_CALL_END = "<|tool_call_end|>";
|
|
258
|
+
size_t tool_start_pos = 0;
|
|
259
|
+
|
|
260
|
+
while ((tool_start_pos = response_text.find(TOOL_CALL_START, tool_start_pos)) != std::string::npos) {
|
|
261
|
+
size_t content_start = tool_start_pos + TOOL_CALL_START.length();
|
|
262
|
+
size_t tool_end_pos = response_text.find(TOOL_CALL_END, content_start);
|
|
263
|
+
|
|
264
|
+
if (tool_end_pos != std::string::npos) {
|
|
265
|
+
std::string tool_content = response_text.substr(content_start, tool_end_pos - content_start);
|
|
266
|
+
|
|
267
|
+
if (tool_content.size() > 2 && tool_content[0] == '[' && tool_content[tool_content.size()-1] == ']') {
|
|
268
|
+
tool_content = tool_content.substr(1, tool_content.size() - 2);
|
|
269
|
+
|
|
270
|
+
size_t paren_pos = tool_content.find('(');
|
|
271
|
+
if (paren_pos != std::string::npos) {
|
|
272
|
+
std::string func_name = tool_content.substr(0, paren_pos);
|
|
273
|
+
std::string args_str = tool_content.substr(paren_pos + 1);
|
|
274
|
+
|
|
275
|
+
if (!args_str.empty() && args_str.back() == ')') {
|
|
276
|
+
args_str.pop_back();
|
|
277
|
+
}
|
|
278
|
+
|
|
279
|
+
std::string json_call = "{\"name\":\"" + func_name + "\",\"arguments\":{";
|
|
280
|
+
|
|
281
|
+
size_t arg_pos = 0;
|
|
282
|
+
bool first_arg = true;
|
|
283
|
+
while (arg_pos < args_str.length()) {
|
|
284
|
+
while (arg_pos < args_str.length() && std::isspace(args_str[arg_pos])) arg_pos++;
|
|
285
|
+
|
|
286
|
+
size_t eq_pos = args_str.find('=', arg_pos);
|
|
287
|
+
if (eq_pos == std::string::npos) break;
|
|
288
|
+
|
|
289
|
+
std::string arg_name = args_str.substr(arg_pos, eq_pos - arg_pos);
|
|
290
|
+
|
|
291
|
+
size_t val_start = eq_pos + 1;
|
|
292
|
+
size_t val_end = val_start;
|
|
293
|
+
|
|
294
|
+
if (val_start < args_str.length() && args_str[val_start] == '"') {
|
|
295
|
+
val_start++;
|
|
296
|
+
val_end = args_str.find('"', val_start);
|
|
297
|
+
if (val_end == std::string::npos) break;
|
|
298
|
+
} else {
|
|
299
|
+
val_end = args_str.find(',', val_start);
|
|
300
|
+
if (val_end == std::string::npos) val_end = args_str.length();
|
|
301
|
+
}
|
|
302
|
+
|
|
303
|
+
std::string arg_value = args_str.substr(val_start, val_end - val_start);
|
|
304
|
+
|
|
305
|
+
if (!first_arg) json_call += ",";
|
|
306
|
+
json_call += "\"" + arg_name + "\":\"" + arg_value + "\"";
|
|
307
|
+
first_arg = false;
|
|
308
|
+
|
|
309
|
+
arg_pos = args_str.find(',', val_end);
|
|
310
|
+
if (arg_pos != std::string::npos) {
|
|
311
|
+
arg_pos++;
|
|
312
|
+
} else {
|
|
313
|
+
break;
|
|
314
|
+
}
|
|
315
|
+
}
|
|
316
|
+
|
|
317
|
+
json_call += "}}";
|
|
318
|
+
function_calls.push_back(json_call);
|
|
319
|
+
}
|
|
320
|
+
}
|
|
321
|
+
|
|
322
|
+
regular_response.erase(tool_start_pos, tool_end_pos + TOOL_CALL_END.length() - tool_start_pos);
|
|
323
|
+
tool_start_pos = tool_end_pos + TOOL_CALL_END.length();
|
|
324
|
+
} else {
|
|
325
|
+
break;
|
|
326
|
+
}
|
|
327
|
+
}
|
|
328
|
+
|
|
214
329
|
const char* FUNCTION_CALL_MARKER = "\"function_call\"";
|
|
215
330
|
size_t search_pos = 0;
|
|
216
|
-
const size_t text_len =
|
|
331
|
+
const size_t text_len = regular_response.length();
|
|
217
332
|
|
|
218
333
|
while (search_pos < text_len) {
|
|
219
|
-
size_t marker_pos =
|
|
334
|
+
size_t marker_pos = regular_response.find(FUNCTION_CALL_MARKER, search_pos);
|
|
220
335
|
if (marker_pos == std::string::npos) break;
|
|
221
336
|
|
|
222
|
-
size_t json_start =
|
|
337
|
+
size_t json_start = regular_response.find('{', marker_pos);
|
|
223
338
|
if (json_start == std::string::npos) break;
|
|
224
339
|
|
|
225
340
|
int brace_count = 1;
|
|
226
341
|
size_t json_end = json_start + 1;
|
|
227
342
|
while (json_end < text_len && brace_count > 0) {
|
|
228
|
-
char c =
|
|
343
|
+
char c = regular_response[json_end];
|
|
229
344
|
brace_count += (c == '{') - (c == '}');
|
|
230
345
|
json_end++;
|
|
231
346
|
}
|
|
232
347
|
|
|
233
348
|
if (brace_count == 0) {
|
|
234
|
-
function_calls.push_back(
|
|
235
|
-
regular_response =
|
|
349
|
+
function_calls.push_back(regular_response.substr(json_start, json_end - json_start));
|
|
350
|
+
regular_response = regular_response.substr(0, marker_pos);
|
|
236
351
|
size_t last_bracket = regular_response.rfind('{');
|
|
237
352
|
if(last_bracket != std::string::npos) {
|
|
238
353
|
regular_response = regular_response.substr(0, last_bracket);
|
|
@@ -28,10 +28,11 @@ enum class OpType {
|
|
|
28
28
|
INPUT, PRECISION_CAST,
|
|
29
29
|
ADD, ADD_CLIPPED, SUBTRACT, MULTIPLY, DIVIDE,
|
|
30
30
|
MATMUL, TRANSPOSE, RESHAPE, SLICE, GATHER, EMBEDDING,
|
|
31
|
+
BILINEAR_INTERPOLATION,
|
|
31
32
|
SUM, MEAN, VARIANCE, MIN, MAX,
|
|
32
|
-
RMS_NORM, ROPE, SOFTMAX, ATTENTION, CONV1D_CAUSAL,
|
|
33
|
+
RMS_NORM, ROPE, SOFTMAX, ATTENTION, CONV1D_CAUSAL, CONV1D_K3,
|
|
33
34
|
SCALAR_ADD, SCALAR_SUBTRACT, SCALAR_MULTIPLY, SCALAR_DIVIDE, SCALAR_EXP, SCALAR_SQRT, SCALAR_COS, SCALAR_SIN,
|
|
34
|
-
SILU, GELU,
|
|
35
|
+
SILU, GELU, GELU_ERF,
|
|
35
36
|
SAMPLE, CONCAT,
|
|
36
37
|
SCATTER_TOPK,
|
|
37
38
|
TOPK, LAYERNORM,
|
|
@@ -139,6 +140,7 @@ struct OpParams {
|
|
|
139
140
|
ComputeBackend backend = ComputeBackend::CPU;
|
|
140
141
|
|
|
141
142
|
size_t dilation = 1;
|
|
143
|
+
size_t stride = 1;
|
|
142
144
|
float temperature = 1.0f;
|
|
143
145
|
float top_p = 1.0f;
|
|
144
146
|
size_t top_k = 0;
|
|
@@ -146,6 +148,8 @@ struct OpParams {
|
|
|
146
148
|
|
|
147
149
|
size_t index_value = 0; // For INDEX operation
|
|
148
150
|
size_t num_classes = 0; // For scatter operations
|
|
151
|
+
size_t dst_height = 0;
|
|
152
|
+
size_t dst_width = 0;
|
|
149
153
|
};
|
|
150
154
|
|
|
151
155
|
struct GraphNode {
|
|
@@ -187,6 +191,12 @@ namespace ValidationUtils {
|
|
|
187
191
|
class CactusGraph {
|
|
188
192
|
public:
|
|
189
193
|
CactusGraph();
|
|
194
|
+
|
|
195
|
+
struct DebugNodeEntry {
|
|
196
|
+
uint32_t layer_idx;
|
|
197
|
+
std::string name;
|
|
198
|
+
size_t node_id;
|
|
199
|
+
};
|
|
190
200
|
|
|
191
201
|
size_t input(const std::vector<size_t>& shape, Precision precision = Precision::INT8);
|
|
192
202
|
size_t precision_cast(size_t input, Precision target_precision);
|
|
@@ -209,9 +219,11 @@ public:
|
|
|
209
219
|
|
|
210
220
|
size_t silu(size_t input);
|
|
211
221
|
size_t gelu(size_t input);
|
|
222
|
+
size_t gelu_erf(size_t input);
|
|
212
223
|
|
|
213
224
|
size_t matmul(size_t input1, size_t input2, bool pretransposed_rhs = false, ComputeBackend backend = ComputeBackend::CPU);
|
|
214
225
|
size_t transpose(size_t input, ComputeBackend backend = ComputeBackend::CPU);
|
|
226
|
+
size_t transposeN(size_t input, const std::vector<size_t>& permutation, ComputeBackend backend = ComputeBackend::CPU);
|
|
215
227
|
size_t reshape(size_t input, const std::vector<size_t>& new_shape);
|
|
216
228
|
size_t slice(size_t input, int axis, size_t start, size_t length);
|
|
217
229
|
size_t index(size_t input, size_t index_value, int dim);
|
|
@@ -225,9 +237,11 @@ public:
|
|
|
225
237
|
size_t gather(size_t embeddings, size_t indices);
|
|
226
238
|
size_t mmap_embeddings(const std::string& filename);
|
|
227
239
|
size_t mmap_weights(const std::string& filename);
|
|
240
|
+
size_t load_weights(const std::string& filename);
|
|
228
241
|
void set_quantization_scale(size_t node_id, float scale);
|
|
229
242
|
size_t embedding(const std::string& filename, size_t indices);
|
|
230
243
|
size_t embedding(size_t embedding_tensor, size_t indices);
|
|
244
|
+
size_t bilinear_interpolation(size_t pos_embeds, size_t dst_height, size_t dst_width);
|
|
231
245
|
|
|
232
246
|
size_t layernorm(size_t input, size_t weight, size_t bias, float epsilon = 1e-5f);
|
|
233
247
|
size_t topk(size_t input, size_t k);
|
|
@@ -239,6 +253,7 @@ public:
|
|
|
239
253
|
size_t attention(size_t query, size_t key, size_t value, float scale, size_t position_offset, size_t window_size, ComputeBackend backend = ComputeBackend::CPU);
|
|
240
254
|
|
|
241
255
|
size_t conv1d_causal(size_t input, size_t weight, size_t kernel_size, size_t dilation = 1);
|
|
256
|
+
size_t conv1d_k3(size_t input, size_t weight, size_t stride);
|
|
242
257
|
|
|
243
258
|
size_t sample(size_t logits, float temperature = 0.6f, float top_p = 0.95f, size_t top_k = 20);
|
|
244
259
|
|
|
@@ -252,6 +267,11 @@ public:
|
|
|
252
267
|
void execute(const std::string& profile_file = "");
|
|
253
268
|
void hard_reset();
|
|
254
269
|
void soft_reset();
|
|
270
|
+
|
|
271
|
+
void register_debug_node(uint32_t layer_idx, const std::string& name, size_t node_id);
|
|
272
|
+
void capture_debug_node(uint32_t layer_idx, const std::string& name, size_t node_id);
|
|
273
|
+
const std::vector<DebugNodeEntry>& get_debug_nodes() const;
|
|
274
|
+
void clear_debug_nodes();
|
|
255
275
|
|
|
256
276
|
size_t add_node(OpType op_type, const std::vector<size_t>& inputs, const std::vector<size_t>& output_shape, const OpParams& params = {});
|
|
257
277
|
const BufferDesc& get_output_buffer(size_t node_id) const;
|
|
@@ -265,6 +285,7 @@ private:
|
|
|
265
285
|
size_t next_node_id_;
|
|
266
286
|
std::vector<std::unique_ptr<GraphFile::MappedFile>> mapped_files_;
|
|
267
287
|
std::unordered_map<std::string, size_t> weight_cache_;
|
|
288
|
+
std::vector<DebugNodeEntry> debug_nodes_;
|
|
268
289
|
};
|
|
269
290
|
|
|
270
291
|
|
|
@@ -174,6 +174,15 @@ void cactus_gelu_f16(const __fp16* input, __fp16* output, size_t num_elements);
|
|
|
174
174
|
void cactus_gelu_int8(const int8_t* input, int8_t* output, size_t num_elements,
|
|
175
175
|
float input_scale, float output_scale);
|
|
176
176
|
|
|
177
|
+
void cactus_gelu_f32_erf(const float* input, float* output, size_t num_elements);
|
|
178
|
+
void cactus_gelu_f16_erf(const __fp16* input, __fp16* output, size_t num_elements);
|
|
179
|
+
void cactus_gelu_int8_erf(
|
|
180
|
+
const int8_t* input,
|
|
181
|
+
int8_t* output,
|
|
182
|
+
size_t num_elements,
|
|
183
|
+
float scale_in,
|
|
184
|
+
float scale_out);
|
|
185
|
+
|
|
177
186
|
|
|
178
187
|
void cactus_attention_int8(const int8_t* queries, const int8_t* keys, const int8_t* values, int8_t* output,
|
|
179
188
|
size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
|
|
@@ -225,6 +234,49 @@ void cactus_conv1d_causal_depthwise_int8(
|
|
|
225
234
|
float weight_scale,
|
|
226
235
|
float output_scale);
|
|
227
236
|
|
|
237
|
+
void cactus_conv1d_f32_k3(
|
|
238
|
+
const float* input,
|
|
239
|
+
const float* weight,
|
|
240
|
+
float* output,
|
|
241
|
+
size_t N,
|
|
242
|
+
size_t L,
|
|
243
|
+
size_t C_in,
|
|
244
|
+
size_t C_out,
|
|
245
|
+
size_t stride
|
|
246
|
+
);
|
|
247
|
+
|
|
248
|
+
void cactus_conv1d_f16_k3(
|
|
249
|
+
const __fp16* input,
|
|
250
|
+
const __fp16* weight,
|
|
251
|
+
__fp16* output,
|
|
252
|
+
size_t N,
|
|
253
|
+
size_t L,
|
|
254
|
+
size_t C_in,
|
|
255
|
+
size_t C_out,
|
|
256
|
+
size_t stride
|
|
257
|
+
);
|
|
258
|
+
|
|
259
|
+
void cactus_conv1d_f32_k3(
|
|
260
|
+
const float* input,
|
|
261
|
+
const float* weight,
|
|
262
|
+
float* output,
|
|
263
|
+
size_t N, size_t L,
|
|
264
|
+
size_t C_in, size_t C_out,
|
|
265
|
+
size_t stride
|
|
266
|
+
);
|
|
267
|
+
|
|
268
|
+
void cactus_conv1d_f16_k3(
|
|
269
|
+
const __fp16* input,
|
|
270
|
+
const __fp16* weight,
|
|
271
|
+
__fp16* output,
|
|
272
|
+
size_t N, size_t L,
|
|
273
|
+
size_t C_in, size_t C_out,
|
|
274
|
+
size_t stride
|
|
275
|
+
);
|
|
276
|
+
|
|
277
|
+
void cactus_bilinear_interpolation_fp32(const float* input, float* output, size_t src_height, size_t src_width, size_t embed_dim,
|
|
278
|
+
size_t dst_height, size_t dst_width);
|
|
279
|
+
|
|
228
280
|
void cactus_sample_f32(const float* logits, uint32_t* output, size_t vocab_size,
|
|
229
281
|
float temperature, float top_p, size_t top_k, size_t random_seed);
|
|
230
282
|
void cactus_sample_f16(const __fp16* logits, uint32_t* output, size_t vocab_size,
|
|
Binary file
|
|
@@ -20,7 +20,7 @@ typedef void* cactus_model_t;
|
|
|
20
20
|
|
|
21
21
|
typedef void (*cactus_token_callback)(const char* token, uint32_t token_id, void* user_data);
|
|
22
22
|
|
|
23
|
-
CACTUS_FFI_EXPORT cactus_model_t cactus_init(const char* model_path, size_t context_size);
|
|
23
|
+
CACTUS_FFI_EXPORT cactus_model_t cactus_init(const char* model_path, size_t context_size, const char* corpus_dir);
|
|
24
24
|
|
|
25
25
|
CACTUS_FFI_EXPORT int cactus_complete(
|
|
26
26
|
cactus_model_t model,
|
|
@@ -33,6 +33,17 @@ CACTUS_FFI_EXPORT int cactus_complete(
|
|
|
33
33
|
void* user_data
|
|
34
34
|
);
|
|
35
35
|
|
|
36
|
+
CACTUS_FFI_EXPORT int cactus_transcribe(
|
|
37
|
+
cactus_model_t model,
|
|
38
|
+
const char* audio_file_path,
|
|
39
|
+
const char* prompt,
|
|
40
|
+
char* response_buffer,
|
|
41
|
+
size_t buffer_size,
|
|
42
|
+
const char* options_json,
|
|
43
|
+
cactus_token_callback callback,
|
|
44
|
+
void* user_data
|
|
45
|
+
);
|
|
46
|
+
|
|
36
47
|
|
|
37
48
|
CACTUS_FFI_EXPORT int cactus_embed(
|
|
38
49
|
cactus_model_t model,
|
|
@@ -42,6 +53,22 @@ CACTUS_FFI_EXPORT int cactus_embed(
|
|
|
42
53
|
size_t* embedding_dim
|
|
43
54
|
);
|
|
44
55
|
|
|
56
|
+
CACTUS_FFI_EXPORT int cactus_image_embed(
|
|
57
|
+
cactus_model_t model,
|
|
58
|
+
const char* image_path,
|
|
59
|
+
float* embeddings_buffer,
|
|
60
|
+
size_t buffer_size,
|
|
61
|
+
size_t* embedding_dim
|
|
62
|
+
);
|
|
63
|
+
|
|
64
|
+
CACTUS_FFI_EXPORT int cactus_audio_embed(
|
|
65
|
+
cactus_model_t model,
|
|
66
|
+
const char* audio_path,
|
|
67
|
+
float* embeddings_buffer,
|
|
68
|
+
size_t buffer_size,
|
|
69
|
+
size_t* embedding_dim
|
|
70
|
+
);
|
|
71
|
+
|
|
45
72
|
CACTUS_FFI_EXPORT void cactus_reset(cactus_model_t model);
|
|
46
73
|
|
|
47
74
|
CACTUS_FFI_EXPORT void cactus_stop(cactus_model_t model);
|