cactus-react-native 1.5.0 → 1.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Cactus.podspec +1 -1
- package/README.md +347 -241
- package/android/CMakeLists.txt +24 -5
- package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libcurl.a +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libmbedcrypto.a +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libmbedtls.a +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libmbedx509.a +0 -0
- package/cpp/HybridCactus.cpp +149 -117
- package/cpp/HybridCactus.hpp +14 -10
- package/cpp/cactus_ffi.h +54 -43
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus.h +0 -1
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +54 -43
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_utils.h +318 -123
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +118 -15
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +77 -32
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +68 -6
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel_utils.h +21 -155
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus.h +0 -1
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +54 -43
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_utils.h +318 -123
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +118 -15
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +77 -32
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +68 -6
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel_utils.h +21 -155
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
- package/lib/module/classes/CactusLM.js +16 -49
- package/lib/module/classes/CactusLM.js.map +1 -1
- package/lib/module/classes/CactusSTT.js +30 -79
- package/lib/module/classes/CactusSTT.js.map +1 -1
- package/lib/module/classes/CactusVAD.js +95 -0
- package/lib/module/classes/CactusVAD.js.map +1 -0
- package/lib/module/hooks/useCactusLM.js +10 -11
- package/lib/module/hooks/useCactusLM.js.map +1 -1
- package/lib/module/hooks/useCactusSTT.js +23 -62
- package/lib/module/hooks/useCactusSTT.js.map +1 -1
- package/lib/module/hooks/useCactusVAD.js +171 -0
- package/lib/module/hooks/useCactusVAD.js.map +1 -0
- package/lib/module/index.js +2 -3
- package/lib/module/index.js.map +1 -1
- package/lib/module/modelRegistry.js +52 -0
- package/lib/module/modelRegistry.js.map +1 -0
- package/lib/module/native/Cactus.js +85 -23
- package/lib/module/native/Cactus.js.map +1 -1
- package/lib/module/native/CactusIndex.js.map +1 -1
- package/lib/module/native/index.js +0 -3
- package/lib/module/native/index.js.map +1 -1
- package/lib/module/types/CactusVAD.js +4 -0
- package/lib/module/{specs/CactusUtil.nitro.js.map → types/CactusVAD.js.map} +1 -1
- package/lib/typescript/src/classes/CactusLM.d.ts +5 -7
- package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
- package/lib/typescript/src/classes/CactusSTT.d.ts +8 -12
- package/lib/typescript/src/classes/CactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/classes/CactusVAD.d.ts +20 -0
- package/lib/typescript/src/classes/CactusVAD.d.ts.map +1 -0
- package/lib/typescript/src/hooks/useCactusLM.d.ts +2 -2
- package/lib/typescript/src/hooks/useCactusLM.d.ts.map +1 -1
- package/lib/typescript/src/hooks/useCactusSTT.d.ts +6 -8
- package/lib/typescript/src/hooks/useCactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/hooks/useCactusVAD.d.ts +15 -0
- package/lib/typescript/src/hooks/useCactusVAD.d.ts.map +1 -0
- package/lib/typescript/src/index.d.ts +7 -5
- package/lib/typescript/src/index.d.ts.map +1 -1
- package/lib/typescript/src/modelRegistry.d.ts +5 -0
- package/lib/typescript/src/modelRegistry.d.ts.map +1 -0
- package/lib/typescript/src/native/Cactus.d.ts +12 -11
- package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
- package/lib/typescript/src/native/CactusIndex.d.ts +2 -2
- package/lib/typescript/src/native/CactusIndex.d.ts.map +1 -1
- package/lib/typescript/src/native/index.d.ts +0 -3
- package/lib/typescript/src/native/index.d.ts.map +1 -1
- package/lib/typescript/src/specs/Cactus.nitro.d.ts +6 -6
- package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusIndex.d.ts +2 -2
- package/lib/typescript/src/types/CactusIndex.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusLM.d.ts +19 -11
- package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusSTT.d.ts +33 -12
- package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusVAD.d.ts +34 -0
- package/lib/typescript/src/types/CactusVAD.d.ts.map +1 -0
- package/lib/typescript/src/types/common.d.ts +1 -6
- package/lib/typescript/src/types/common.d.ts.map +1 -1
- package/nitro.json +0 -11
- package/nitrogen/generated/android/cactus+autolinking.cmake +0 -5
- package/nitrogen/generated/android/cactusOnLoad.cpp +0 -30
- package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.cpp +0 -50
- package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.hpp +9 -147
- package/nitrogen/generated/ios/Cactus-Swift-Cxx-Umbrella.hpp +0 -13
- package/nitrogen/generated/ios/CactusAutolinking.mm +0 -26
- package/nitrogen/generated/ios/CactusAutolinking.swift +0 -30
- package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +4 -4
- package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +6 -6
- package/package.json +3 -3
- package/src/classes/CactusLM.ts +18 -65
- package/src/classes/CactusSTT.ts +39 -97
- package/src/classes/CactusVAD.ts +129 -0
- package/src/hooks/useCactusLM.ts +14 -17
- package/src/hooks/useCactusSTT.ts +47 -98
- package/src/hooks/useCactusVAD.ts +215 -0
- package/src/index.tsx +18 -12
- package/src/modelRegistry.ts +65 -0
- package/src/native/Cactus.ts +102 -41
- package/src/native/CactusIndex.ts +2 -2
- package/src/native/index.ts +0 -3
- package/src/specs/Cactus.nitro.ts +11 -7
- package/src/types/CactusIndex.ts +2 -2
- package/src/types/CactusLM.ts +19 -11
- package/src/types/CactusSTT.ts +33 -13
- package/src/types/CactusVAD.ts +39 -0
- package/src/types/common.ts +1 -6
- package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusCrypto.kt +0 -46
- package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusDeviceInfo.kt +0 -27
- package/android/src/main/jniLibs/arm64-v8a/libcactus_util.a +0 -0
- package/cpp/HybridCactusUtil.cpp +0 -47
- package/cpp/HybridCactusUtil.hpp +0 -27
- package/cpp/cactus_util.h +0 -25
- package/ios/HybridCactusCrypto.swift +0 -37
- package/ios/HybridCactusDeviceInfo.swift +0 -32
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_telemetry.h +0 -656
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_telemetry.h +0 -656
- package/ios/cactus_util.xcframework/Info.plist +0 -39
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/cactus_util.h +0 -25
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/database.h +0 -27
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/ios_utils.h +0 -10
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/logging.h +0 -25
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Info.plist +0 -0
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/cactus_util +0 -0
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/cactus_util.h +0 -25
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/database.h +0 -27
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/ios_utils.h +0 -10
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/logging.h +0 -25
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Info.plist +0 -0
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/_CodeSignature/CodeResources +0 -135
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/cactus_util +0 -0
- package/lib/module/api/Database.js +0 -45
- package/lib/module/api/Database.js.map +0 -1
- package/lib/module/api/RemoteLM.js +0 -201
- package/lib/module/api/RemoteLM.js.map +0 -1
- package/lib/module/config/CactusConfig.js +0 -12
- package/lib/module/config/CactusConfig.js.map +0 -1
- package/lib/module/models.js +0 -336
- package/lib/module/models.js.map +0 -1
- package/lib/module/native/CactusCrypto.js +0 -10
- package/lib/module/native/CactusCrypto.js.map +0 -1
- package/lib/module/native/CactusDeviceInfo.js +0 -13
- package/lib/module/native/CactusDeviceInfo.js.map +0 -1
- package/lib/module/native/CactusUtil.js +0 -36
- package/lib/module/native/CactusUtil.js.map +0 -1
- package/lib/module/specs/CactusCrypto.nitro.js +0 -4
- package/lib/module/specs/CactusCrypto.nitro.js.map +0 -1
- package/lib/module/specs/CactusDeviceInfo.nitro.js +0 -4
- package/lib/module/specs/CactusDeviceInfo.nitro.js.map +0 -1
- package/lib/module/specs/CactusUtil.nitro.js +0 -4
- package/lib/module/telemetry/Telemetry.js +0 -154
- package/lib/module/telemetry/Telemetry.js.map +0 -1
- package/lib/typescript/src/api/Database.d.ts +0 -12
- package/lib/typescript/src/api/Database.d.ts.map +0 -1
- package/lib/typescript/src/api/RemoteLM.d.ts +0 -14
- package/lib/typescript/src/api/RemoteLM.d.ts.map +0 -1
- package/lib/typescript/src/config/CactusConfig.d.ts +0 -7
- package/lib/typescript/src/config/CactusConfig.d.ts.map +0 -1
- package/lib/typescript/src/models.d.ts +0 -6
- package/lib/typescript/src/models.d.ts.map +0 -1
- package/lib/typescript/src/native/CactusCrypto.d.ts +0 -5
- package/lib/typescript/src/native/CactusCrypto.d.ts.map +0 -1
- package/lib/typescript/src/native/CactusDeviceInfo.d.ts +0 -7
- package/lib/typescript/src/native/CactusDeviceInfo.d.ts.map +0 -1
- package/lib/typescript/src/native/CactusUtil.d.ts +0 -6
- package/lib/typescript/src/native/CactusUtil.d.ts.map +0 -1
- package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts +0 -8
- package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts.map +0 -1
- package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts +0 -16
- package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts.map +0 -1
- package/lib/typescript/src/specs/CactusUtil.nitro.d.ts +0 -10
- package/lib/typescript/src/specs/CactusUtil.nitro.d.ts.map +0 -1
- package/lib/typescript/src/telemetry/Telemetry.d.ts +0 -34
- package/lib/typescript/src/telemetry/Telemetry.d.ts.map +0 -1
- package/nitrogen/generated/android/c++/JDeviceInfo.hpp +0 -74
- package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.cpp +0 -65
- package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.hpp +0 -65
- package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.cpp +0 -85
- package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.hpp +0 -66
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/DeviceInfo.kt +0 -50
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusCryptoSpec.kt +0 -58
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusDeviceInfoSpec.kt +0 -62
- package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.cpp +0 -11
- package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.hpp +0 -77
- package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.cpp +0 -11
- package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.hpp +0 -88
- package/nitrogen/generated/ios/swift/DeviceInfo.swift +0 -98
- package/nitrogen/generated/ios/swift/Func_void_DeviceInfo.swift +0 -47
- package/nitrogen/generated/ios/swift/Func_void_std__optional_std__string_.swift +0 -54
- package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec.swift +0 -57
- package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec_cxx.swift +0 -139
- package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec.swift +0 -58
- package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec_cxx.swift +0 -164
- package/nitrogen/generated/shared/c++/DeviceInfo.hpp +0 -92
- package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.cpp +0 -21
- package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.hpp +0 -63
- package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.cpp +0 -22
- package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.hpp +0 -67
- package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.cpp +0 -23
- package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.hpp +0 -66
- package/src/api/Database.ts +0 -55
- package/src/api/RemoteLM.ts +0 -273
- package/src/config/CactusConfig.ts +0 -11
- package/src/models.ts +0 -344
- package/src/native/CactusCrypto.ts +0 -11
- package/src/native/CactusDeviceInfo.ts +0 -18
- package/src/native/CactusUtil.ts +0 -43
- package/src/specs/CactusCrypto.nitro.ts +0 -6
- package/src/specs/CactusDeviceInfo.nitro.ts +0 -15
- package/src/specs/CactusUtil.nitro.ts +0 -8
- package/src/telemetry/Telemetry.ts +0 -236
|
@@ -84,12 +84,17 @@ struct Config {
|
|
|
84
84
|
bool use_thumbnail = true;
|
|
85
85
|
uint32_t min_image_tokens = 64;
|
|
86
86
|
uint32_t max_image_tokens = 256;
|
|
87
|
-
|
|
87
|
+
uint32_t max_num_patches = 1024;
|
|
88
88
|
uint32_t tile_size = 512;
|
|
89
89
|
float max_pixels_tolerance = 2.0f;
|
|
90
90
|
bool do_image_splitting = true;
|
|
91
|
+
bool encoder_act_gelu = false;
|
|
92
|
+
bool decoder_act_gelu = false;
|
|
93
|
+
uint32_t num_encoder_layers = 0;
|
|
94
|
+
uint32_t num_decoder_layers = 0;
|
|
95
|
+
float partial_rotary_factor = 0.0f;
|
|
91
96
|
|
|
92
|
-
enum class ModelType {QWEN = 0, GEMMA = 1,
|
|
97
|
+
enum class ModelType {QWEN = 0, GEMMA = 1, NOMIC = 3, LFM2 = 5, SIGLIP2 = 6, WHISPER = 7, MOONSHINE = 8, SILERO_VAD = 9};
|
|
93
98
|
ModelType model_type = ModelType::QWEN;
|
|
94
99
|
|
|
95
100
|
enum class ModelVariant {DEFAULT = 0, VLM = 1, EXTRACT = 2, RAG = 3};
|
|
@@ -107,6 +112,8 @@ struct Config {
|
|
|
107
112
|
float default_temperature = 0.6f;
|
|
108
113
|
float default_top_p = 0.95f;
|
|
109
114
|
size_t default_top_k = 20;
|
|
115
|
+
float default_max_tps = -1.0f;
|
|
116
|
+
float default_cloud_handoff_threshold = 0.0f;
|
|
110
117
|
|
|
111
118
|
std::vector<std::string> layer_types;
|
|
112
119
|
size_t conv_L_cache = 0;
|
|
@@ -152,6 +159,7 @@ public:
|
|
|
152
159
|
virtual uint32_t get_bos_token() const = 0;
|
|
153
160
|
virtual uint32_t get_eos_token() const = 0;
|
|
154
161
|
virtual bool has_chat_template() const { return has_chat_template_; }
|
|
162
|
+
std::string get_default_stop_sequence() const;
|
|
155
163
|
|
|
156
164
|
virtual bool load_vocabulary_with_config(const std::string& vocab_file, const std::string& merges_file, const std::string& config_file) = 0;
|
|
157
165
|
|
|
@@ -159,11 +167,8 @@ public:
|
|
|
159
167
|
uint32_t get_fake_token_id() const { return fake_token_id_; }
|
|
160
168
|
uint32_t get_global_img_token_id() const { return global_img_token_id_; }
|
|
161
169
|
|
|
162
|
-
|
|
163
|
-
void set_corpus_dir(const std::string& dir) { corpus_dir_ = dir; }
|
|
164
|
-
|
|
165
170
|
protected:
|
|
166
|
-
enum class ModelType { UNKNOWN, QWEN, GEMMA, LFM2,
|
|
171
|
+
enum class ModelType { UNKNOWN, QWEN, GEMMA, LFM2, BERT, WHISPER};
|
|
167
172
|
ModelType model_type_ = ModelType::UNKNOWN;
|
|
168
173
|
enum class ModelVariant { DEFAULT, VLM, EXTRACT, RAG};
|
|
169
174
|
ModelVariant model_variant_ = ModelVariant::DEFAULT;
|
|
@@ -173,14 +178,12 @@ protected:
|
|
|
173
178
|
uint32_t image_token_id_ = 396;
|
|
174
179
|
uint32_t fake_token_id_ = 49189;
|
|
175
180
|
uint32_t global_img_token_id_ = 49152;
|
|
176
|
-
std::string corpus_dir_;
|
|
177
181
|
|
|
178
182
|
void detect_model_type(const std::string& config_path);
|
|
179
183
|
std::string format_qwen_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
|
|
180
184
|
std::string format_gemma_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
|
|
181
185
|
std::string format_lfm2_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
|
|
182
186
|
std::string format_lfm2_vl_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
|
|
183
|
-
std::string format_smol_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
|
|
184
187
|
};
|
|
185
188
|
|
|
186
189
|
class BPETokenizer : public Tokenizer {
|
|
@@ -471,6 +474,8 @@ private:
|
|
|
471
474
|
void compute_bias();
|
|
472
475
|
void tokenize_grammar_elements();
|
|
473
476
|
void add_tokens_for_string(const std::string& str, std::unordered_set<uint32_t>& token_set);
|
|
477
|
+
void tokenize_function_names(bool quote_names);
|
|
478
|
+
void init_common_tokens();
|
|
474
479
|
};
|
|
475
480
|
|
|
476
481
|
class Model {
|
|
@@ -495,22 +500,22 @@ public:
|
|
|
495
500
|
const std::string& system_prompt = "", bool do_warmup = true);
|
|
496
501
|
|
|
497
502
|
virtual uint32_t decode(const std::vector<uint32_t>& tokens, float temperature = -1.0f, float top_p = -1.0f,
|
|
498
|
-
size_t top_k = 0, const std::string& profile_file = "");
|
|
503
|
+
size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr);
|
|
499
504
|
|
|
500
505
|
virtual void prefill(const std::vector<uint32_t>& tokens, size_t chunk_size = 256, const std::string& profile_file = "");
|
|
501
506
|
|
|
502
507
|
virtual uint32_t decode_with_images(const std::vector<uint32_t>& tokens, const std::vector<std::string>& image_paths,
|
|
503
508
|
float temperature = -1.0f, float top_p = -1.0f,
|
|
504
|
-
size_t top_k = 0, const std::string& profile_file = "");
|
|
509
|
+
size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr);
|
|
505
510
|
|
|
506
|
-
virtual uint32_t decode_with_audio(const std::vector<uint32_t>& tokens, const std::vector<float>&
|
|
507
|
-
size_t top_k = 0, const std::string& profile_file = "");
|
|
511
|
+
virtual uint32_t decode_with_audio(const std::vector<uint32_t>& tokens, const std::vector<float>& audio_features, float temperature = 0.0f, float top_p = 0.0f,
|
|
512
|
+
size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr);
|
|
508
513
|
|
|
509
514
|
std::vector<float> get_embeddings(const std::vector<uint32_t>& tokens, bool pooled = true, bool normalize = false, const std::string& profile_file = "");
|
|
510
515
|
|
|
511
516
|
virtual std::vector<float> get_image_embeddings(const std::string& image_path);
|
|
512
517
|
|
|
513
|
-
virtual std::vector<float> get_audio_embeddings(const std::vector<float>&
|
|
518
|
+
virtual std::vector<float> get_audio_embeddings(const std::vector<float>& audio_features);
|
|
514
519
|
|
|
515
520
|
virtual void reset_cache() { kv_cache_.reset(); }
|
|
516
521
|
|
|
@@ -533,7 +538,7 @@ public:
|
|
|
533
538
|
protected:
|
|
534
539
|
virtual size_t forward(const std::vector<uint32_t>& tokens, bool use_cache = false) = 0;
|
|
535
540
|
|
|
536
|
-
virtual size_t forward(const std::vector<float>&
|
|
541
|
+
virtual size_t forward(const std::vector<float>& audio_features, const std::vector<uint32_t>& tokens, bool use_cache = false);
|
|
537
542
|
|
|
538
543
|
virtual void load_weights_to_graph(CactusGraph* gb) = 0;
|
|
539
544
|
|
|
@@ -645,6 +650,7 @@ public:
|
|
|
645
650
|
private:
|
|
646
651
|
Config config_;
|
|
647
652
|
|
|
653
|
+
std::pair<int64_t, int64_t> compute_pixel_limits() const;
|
|
648
654
|
std::vector<unsigned char> convert_to_rgb(const unsigned char* img_data, int width, int height, int channels);
|
|
649
655
|
std::pair<int, int> smart_resize(int height, int width);
|
|
650
656
|
bool is_image_too_large(int height, int width);
|
|
@@ -701,5 +707,102 @@ private:
|
|
|
701
707
|
size_t num_mel_filters_;
|
|
702
708
|
};
|
|
703
709
|
|
|
710
|
+
namespace index {
|
|
711
|
+
constexpr uint32_t MAGIC = 0x43414354;
|
|
712
|
+
constexpr uint32_t VERSION = 1;
|
|
713
|
+
|
|
714
|
+
struct Document {
|
|
715
|
+
int id;
|
|
716
|
+
std::vector<float> embedding;
|
|
717
|
+
std::string content;
|
|
718
|
+
std::string metadata;
|
|
719
|
+
};
|
|
720
|
+
|
|
721
|
+
struct QueryResult {
|
|
722
|
+
int doc_id;
|
|
723
|
+
float score;
|
|
724
|
+
};
|
|
725
|
+
|
|
726
|
+
struct QueryOptions {
|
|
727
|
+
size_t top_k = 10;
|
|
728
|
+
float score_threshold = -1.0f;
|
|
729
|
+
};
|
|
730
|
+
|
|
731
|
+
class Index {
|
|
732
|
+
public:
|
|
733
|
+
Index(const std::string& index_path, const std::string& data_path, size_t embedding_dim);
|
|
734
|
+
~Index();
|
|
735
|
+
|
|
736
|
+
Index(const Index&) = delete;
|
|
737
|
+
Index& operator=(const Index&) = delete;
|
|
738
|
+
Index(Index&&) = delete;
|
|
739
|
+
Index& operator=(Index&&) = delete;
|
|
740
|
+
|
|
741
|
+
void add_documents(const std::vector<Document>& documents);
|
|
742
|
+
void delete_documents(const std::vector<int>& doc_ids);
|
|
743
|
+
std::vector<Document> get_documents(const std::vector<int>& doc_ids);
|
|
744
|
+
std::vector<std::vector<QueryResult>> query(const std::vector<std::vector<float>>& embeddings, const QueryOptions& options);
|
|
745
|
+
void compact();
|
|
746
|
+
|
|
747
|
+
private:
|
|
748
|
+
struct IndexHeader {
|
|
749
|
+
uint32_t magic;
|
|
750
|
+
uint32_t version;
|
|
751
|
+
uint32_t embedding_dim;
|
|
752
|
+
uint32_t num_documents;
|
|
753
|
+
};
|
|
754
|
+
|
|
755
|
+
struct IndexEntry {
|
|
756
|
+
int32_t doc_id;
|
|
757
|
+
uint64_t data_offset;
|
|
758
|
+
uint8_t flags; // bit 0: tombstone
|
|
759
|
+
|
|
760
|
+
const __fp16* embedding() const {
|
|
761
|
+
return reinterpret_cast<const __fp16*>(this + 1);
|
|
762
|
+
}
|
|
763
|
+
|
|
764
|
+
static size_t size(size_t embedding_dim) {
|
|
765
|
+
return sizeof(IndexEntry) + embedding_dim * sizeof(__fp16);
|
|
766
|
+
}
|
|
767
|
+
};
|
|
768
|
+
|
|
769
|
+
struct DataHeader {
|
|
770
|
+
uint32_t magic;
|
|
771
|
+
uint32_t version;
|
|
772
|
+
};
|
|
773
|
+
|
|
774
|
+
struct DataEntry {
|
|
775
|
+
uint16_t content_len;
|
|
776
|
+
uint16_t metadata_len;
|
|
777
|
+
|
|
778
|
+
const char* content() const {
|
|
779
|
+
return reinterpret_cast<const char*>(this + 1);
|
|
780
|
+
}
|
|
781
|
+
|
|
782
|
+
const char* metadata() const {
|
|
783
|
+
return content() + content_len;
|
|
784
|
+
}
|
|
785
|
+
};
|
|
786
|
+
|
|
787
|
+
void parse_index_header();
|
|
788
|
+
void parse_data_header();
|
|
789
|
+
void build_doc_id_map();
|
|
790
|
+
void validate_documents(const std::vector<Document>& documents);
|
|
791
|
+
void validate_doc_ids(const std::vector<int>& doc_ids);
|
|
792
|
+
ssize_t write_full(int fd, const void* buf, size_t count);
|
|
793
|
+
|
|
794
|
+
std::unordered_map<int, uint32_t> doc_id_map_;
|
|
795
|
+
|
|
796
|
+
std::string index_path_, data_path_;
|
|
797
|
+
size_t embedding_dim_;
|
|
798
|
+
size_t index_entry_size_;
|
|
799
|
+
uint32_t num_documents_;
|
|
800
|
+
|
|
801
|
+
int index_fd_, data_fd_;
|
|
802
|
+
void *mapped_index_, *mapped_data_;
|
|
803
|
+
size_t index_file_size_, data_file_size_;
|
|
804
|
+
};
|
|
805
|
+
} // namespace index
|
|
806
|
+
|
|
807
|
+
}
|
|
704
808
|
}
|
|
705
|
-
}
|
|
@@ -4,6 +4,7 @@
|
|
|
4
4
|
#include <vector>
|
|
5
5
|
#include <memory>
|
|
6
6
|
#include <unordered_map>
|
|
7
|
+
#include <unordered_set>
|
|
7
8
|
#include <functional>
|
|
8
9
|
#include <cstring>
|
|
9
10
|
#include <stdexcept>
|
|
@@ -114,17 +115,20 @@ enum class OpType {
|
|
|
114
115
|
MATMUL, TRANSPOSE, RESHAPE, SLICE, GATHER, EMBEDDING,
|
|
115
116
|
BILINEAR_INTERPOLATION,
|
|
116
117
|
SUM, MEAN, VARIANCE, MIN, MAX,
|
|
117
|
-
RMS_NORM, ROPE, SOFTMAX, ATTENTION, ATTENTION_INT8_HYBRID, CONV1D_CAUSAL, CONV1D_K3,
|
|
118
|
+
RMS_NORM, ROPE, ROPE_GPTJ, SOFTMAX, ATTENTION, ATTENTION_INT8_HYBRID, CONV1D_CAUSAL, CONV1D_K3, CONV1D_K7S3, CONV1D,
|
|
118
119
|
SCALAR_ADD, SCALAR_SUBTRACT, SCALAR_MULTIPLY, SCALAR_DIVIDE, SCALAR_EXP, SCALAR_SQRT, SCALAR_COS, SCALAR_SIN,
|
|
119
|
-
SILU, GELU, GELU_ERF,
|
|
120
|
+
RELU, SILU, GELU, GELU_ERF, SIGMOID, TANH,
|
|
120
121
|
SAMPLE, CONCAT,
|
|
121
122
|
SCATTER_TOPK,
|
|
122
|
-
TOPK, LAYERNORM,
|
|
123
|
+
TOPK, LAYERNORM, GROUPNORM,
|
|
123
124
|
INDEX,
|
|
125
|
+
PERSISTENT,
|
|
126
|
+
QUANTIZE_ACTIVATIONS,
|
|
127
|
+
LSTM_CELL,
|
|
128
|
+
STFT_MAGNITUDE
|
|
124
129
|
};
|
|
125
130
|
|
|
126
131
|
struct PrecisionTraits {
|
|
127
|
-
// Returns in-memory element size (INT4 unpacks to INT8, so returns 1)
|
|
128
132
|
static constexpr size_t size_of(Precision prec) {
|
|
129
133
|
switch (prec) {
|
|
130
134
|
case Precision::INT8: return 1;
|
|
@@ -205,8 +209,12 @@ struct BufferDesc {
|
|
|
205
209
|
void* scales_data = nullptr;
|
|
206
210
|
std::unique_ptr<char[]> owned_scales;
|
|
207
211
|
|
|
208
|
-
|
|
209
|
-
size_t
|
|
212
|
+
bool is_interleaved = false;
|
|
213
|
+
size_t original_N = 0;
|
|
214
|
+
|
|
215
|
+
void* activation_scales_data = nullptr;
|
|
216
|
+
std::unique_ptr<char[]> owned_activation_scales;
|
|
217
|
+
size_t num_rows_for_activation_scales = 0;
|
|
210
218
|
|
|
211
219
|
BufferDesc();
|
|
212
220
|
BufferDesc(const std::vector<size_t>& s, Precision prec = Precision::INT8);
|
|
@@ -230,23 +238,39 @@ struct BufferDesc {
|
|
|
230
238
|
const __fp16* scales_as_fp16() const {
|
|
231
239
|
return reinterpret_cast<const __fp16*>(scales_data);
|
|
232
240
|
}
|
|
241
|
+
|
|
233
242
|
bool is_grouped_int8() const {
|
|
234
243
|
return precision == Precision::INT8 && group_size > 0;
|
|
235
244
|
}
|
|
236
|
-
|
|
237
|
-
return packed_int4_data != nullptr && packed_int4_size > 0;
|
|
238
|
-
}
|
|
239
|
-
const uint8_t* packed_int4_as_uint8() const {
|
|
240
|
-
return reinterpret_cast<const uint8_t*>(packed_int4_data);
|
|
241
|
-
}
|
|
245
|
+
|
|
242
246
|
void set_grouped_scales(size_t gs, size_t ng, void* scales_ptr) {
|
|
243
247
|
group_size = gs;
|
|
244
248
|
num_groups = ng;
|
|
245
249
|
scales_data = scales_ptr;
|
|
246
250
|
}
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
251
|
+
|
|
252
|
+
void set_interleaved(bool interleaved, size_t orig_n) {
|
|
253
|
+
is_interleaved = interleaved;
|
|
254
|
+
original_N = orig_n;
|
|
255
|
+
}
|
|
256
|
+
|
|
257
|
+
bool has_activation_scales() const {
|
|
258
|
+
return activation_scales_data != nullptr && num_rows_for_activation_scales > 0;
|
|
259
|
+
}
|
|
260
|
+
const float* activation_scales_as_float() const {
|
|
261
|
+
return reinterpret_cast<const float*>(activation_scales_data);
|
|
262
|
+
}
|
|
263
|
+
float* activation_scales_as_float() {
|
|
264
|
+
return reinterpret_cast<float*>(activation_scales_data);
|
|
265
|
+
}
|
|
266
|
+
void allocate_activation_scales(size_t num_rows) {
|
|
267
|
+
num_rows_for_activation_scales = num_rows;
|
|
268
|
+
owned_activation_scales = std::make_unique<char[]>(num_rows * sizeof(float));
|
|
269
|
+
activation_scales_data = owned_activation_scales.get();
|
|
270
|
+
}
|
|
271
|
+
void set_activation_scales(void* scales_ptr, size_t num_rows) {
|
|
272
|
+
activation_scales_data = scales_ptr;
|
|
273
|
+
num_rows_for_activation_scales = num_rows;
|
|
250
274
|
}
|
|
251
275
|
|
|
252
276
|
void allocate();
|
|
@@ -282,6 +306,7 @@ struct OpParams {
|
|
|
282
306
|
|
|
283
307
|
size_t index_value = 0;
|
|
284
308
|
size_t num_classes = 0;
|
|
309
|
+
size_t num_groups = 0;
|
|
285
310
|
size_t dst_height = 0;
|
|
286
311
|
size_t dst_width = 0;
|
|
287
312
|
|
|
@@ -295,6 +320,7 @@ struct OpParams {
|
|
|
295
320
|
size_t cache_seq_len = 0;
|
|
296
321
|
size_t num_kv_heads = 0;
|
|
297
322
|
size_t head_dim = 0;
|
|
323
|
+
size_t num_fft_bins = 0;
|
|
298
324
|
};
|
|
299
325
|
|
|
300
326
|
struct GraphNode {
|
|
@@ -324,7 +350,10 @@ void compute_sample_node(GraphNode& node, const std::vector<std::unique_ptr<Grap
|
|
|
324
350
|
void compute_scatter_topk_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
325
351
|
void compute_topk_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
326
352
|
void compute_layernorm_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
353
|
+
void compute_groupnorm_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
354
|
+
void compute_persistent_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
327
355
|
void compute_index_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
356
|
+
void compute_lstm_cell_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
328
357
|
|
|
329
358
|
void shrink_thread_local_buffers();
|
|
330
359
|
|
|
@@ -372,6 +401,7 @@ public:
|
|
|
372
401
|
|
|
373
402
|
size_t input(const std::vector<size_t>& shape, Precision precision = Precision::INT8);
|
|
374
403
|
size_t precision_cast(size_t input, Precision target_precision);
|
|
404
|
+
size_t quantize_activations(size_t input);
|
|
375
405
|
|
|
376
406
|
size_t add(size_t input1, size_t input2);
|
|
377
407
|
size_t add_clipped(size_t input1, size_t input2);
|
|
@@ -389,9 +419,12 @@ public:
|
|
|
389
419
|
size_t scalar_cos(size_t input);
|
|
390
420
|
size_t scalar_sin(size_t input);
|
|
391
421
|
|
|
422
|
+
size_t relu(size_t input);
|
|
392
423
|
size_t silu(size_t input);
|
|
393
424
|
size_t gelu(size_t input);
|
|
394
425
|
size_t gelu_erf(size_t input);
|
|
426
|
+
size_t sigmoid(size_t input);
|
|
427
|
+
size_t tanh(size_t input);
|
|
395
428
|
|
|
396
429
|
size_t matmul(size_t input1, size_t input2, bool pretransposed_rhs = false, ComputeBackend backend = ComputeBackend::CPU);
|
|
397
430
|
size_t transpose(size_t input, ComputeBackend backend = ComputeBackend::CPU);
|
|
@@ -409,8 +442,8 @@ public:
|
|
|
409
442
|
size_t gather(size_t embeddings, size_t indices);
|
|
410
443
|
size_t mmap_embeddings(const std::string& filename);
|
|
411
444
|
size_t mmap_weights(const std::string& filename);
|
|
412
|
-
size_t load_weights(const std::string& filename);
|
|
413
445
|
void set_grouped_scales(size_t node_id, size_t group_size, size_t num_groups, void* scales_ptr);
|
|
446
|
+
void set_interleaved(size_t node_id, bool interleaved, size_t original_N);
|
|
414
447
|
|
|
415
448
|
void release_weight_pages(size_t node_id);
|
|
416
449
|
void prefetch_weight_pages(size_t node_id);
|
|
@@ -420,9 +453,12 @@ public:
|
|
|
420
453
|
size_t bilinear_interpolation(size_t pos_embeds, size_t dst_height, size_t dst_width);
|
|
421
454
|
|
|
422
455
|
size_t layernorm(size_t input, size_t weight, size_t bias, float epsilon = 1e-5f);
|
|
456
|
+
size_t layernorm(size_t input, size_t weight, float epsilon = 1e-5f); // No bias version
|
|
457
|
+
size_t groupnorm(size_t input, size_t weight, size_t bias, size_t num_groups = 32, float epsilon = 1e-5f);
|
|
423
458
|
size_t topk(size_t input, size_t k);
|
|
424
459
|
size_t rms_norm(size_t input, size_t weight, float epsilon = 1e-5f);
|
|
425
460
|
size_t rope(size_t input, float theta, size_t position_offset = 0, ComputeBackend backend = ComputeBackend::CPU);
|
|
461
|
+
size_t rope_gptj(size_t input, float theta, size_t position_offset = 0, size_t rot_dim = 0, ComputeBackend backend = ComputeBackend::CPU);
|
|
426
462
|
size_t softmax(size_t input, int axis = -1);
|
|
427
463
|
size_t attention(size_t query, size_t key, size_t value, float scale, bool is_causal = true, ComputeBackend backend = ComputeBackend::CPU);
|
|
428
464
|
size_t attention(size_t query, size_t key, size_t value, float scale, size_t position_offset, ComputeBackend backend = ComputeBackend::CPU);
|
|
@@ -431,11 +467,17 @@ public:
|
|
|
431
467
|
size_t attention_int8_hybrid(size_t query, size_t key_new, size_t value_new, float scale, size_t position_offset,
|
|
432
468
|
const int8_t* cached_keys, const int8_t* cached_values,
|
|
433
469
|
const float* k_scales, const float* v_scales,
|
|
434
|
-
size_t cache_len, size_t num_kv_heads, size_t head_dim);
|
|
470
|
+
size_t cache_len, size_t num_kv_heads, size_t head_dim, size_t window_size = 0);
|
|
435
471
|
|
|
436
472
|
size_t conv1d_causal(size_t input, size_t weight, size_t kernel_size, size_t dilation = 1);
|
|
437
473
|
size_t conv1d_k3(size_t input, size_t weight, size_t stride);
|
|
438
|
-
|
|
474
|
+
size_t conv1d_k7s3(size_t input, size_t weight, size_t bias);
|
|
475
|
+
size_t conv1d(size_t input, size_t weight, size_t stride);
|
|
476
|
+
size_t conv1d(size_t input, size_t weight, size_t bias, size_t stride);
|
|
477
|
+
|
|
478
|
+
size_t lstm_cell(size_t input, size_t h_prev, size_t c_prev, size_t weight_ih, size_t weight_hh, size_t bias_ih, size_t bias_hh);
|
|
479
|
+
size_t stft_magnitude(size_t input, size_t weight, size_t stride, size_t num_fft_bins);
|
|
480
|
+
|
|
439
481
|
size_t sample(size_t logits, float temperature = 0.6f, float top_p = 0.95f, size_t top_k = 20,
|
|
440
482
|
const std::unordered_map<uint32_t, float>& logit_bias = {});
|
|
441
483
|
|
|
@@ -462,6 +504,10 @@ public:
|
|
|
462
504
|
void allocate_buffers();
|
|
463
505
|
size_t get_node_count() const;
|
|
464
506
|
|
|
507
|
+
size_t persistent(size_t source_node);
|
|
508
|
+
bool is_populated(size_t persistent_node_id) const;
|
|
509
|
+
void invalidate_persistent(size_t persistent_node_id);
|
|
510
|
+
|
|
465
511
|
std::vector<std::unique_ptr<GraphNode>> nodes_;
|
|
466
512
|
std::unordered_map<size_t, size_t> node_index_map_;
|
|
467
513
|
|
|
@@ -473,6 +519,9 @@ private:
|
|
|
473
519
|
std::vector<DebugNodeEntry> debug_nodes_;
|
|
474
520
|
BufferPool buffer_pool_;
|
|
475
521
|
bool prefill_mode_ = false;
|
|
522
|
+
|
|
523
|
+
std::unordered_set<size_t> persistent_node_ids_;
|
|
524
|
+
std::unordered_set<size_t> populated_node_ids_;
|
|
476
525
|
};
|
|
477
526
|
|
|
478
527
|
|
|
@@ -485,7 +534,6 @@ namespace GraphFile {
|
|
|
485
534
|
};
|
|
486
535
|
|
|
487
536
|
void save_node(CactusGraph& graph, size_t node_id, const std::string& filename);
|
|
488
|
-
LoadedNode load_into_graph(CactusGraph& graph, const std::string& filename);
|
|
489
537
|
|
|
490
538
|
class MappedFile {
|
|
491
539
|
public:
|
|
@@ -499,16 +547,14 @@ namespace GraphFile {
|
|
|
499
547
|
|
|
500
548
|
const std::vector<size_t>& shape() const;
|
|
501
549
|
Precision precision() const;
|
|
502
|
-
Precision effective_precision() const {
|
|
503
|
-
return is_int4_ ? Precision::INT8 : precision_;
|
|
504
|
-
}
|
|
505
550
|
size_t byte_size() const;
|
|
506
551
|
|
|
507
552
|
size_t group_size() const { return group_size_; }
|
|
508
553
|
size_t num_groups() const { return num_groups_; }
|
|
509
554
|
const void* scales_data() const;
|
|
510
|
-
|
|
511
|
-
bool
|
|
555
|
+
|
|
556
|
+
bool is_interleaved() const { return is_interleaved_; }
|
|
557
|
+
size_t original_N() const { return original_N_; }
|
|
512
558
|
|
|
513
559
|
void* data();
|
|
514
560
|
const void* data() const;
|
|
@@ -516,8 +562,6 @@ namespace GraphFile {
|
|
|
516
562
|
template<typename T>
|
|
517
563
|
const T* typed_data() const;
|
|
518
564
|
|
|
519
|
-
LoadedNode load_into_graph(CactusGraph& graph) const;
|
|
520
|
-
|
|
521
565
|
void release_pages();
|
|
522
566
|
void prefetch_pages();
|
|
523
567
|
|
|
@@ -532,16 +576,17 @@ namespace GraphFile {
|
|
|
532
576
|
size_t num_groups_ = 0;
|
|
533
577
|
size_t scales_offset_ = 0;
|
|
534
578
|
size_t scales_bytes_ = 0;
|
|
535
|
-
uint32_t version_ = 1;
|
|
536
579
|
uint32_t alignment_ = 32;
|
|
537
|
-
|
|
538
|
-
|
|
580
|
+
|
|
581
|
+
bool is_interleaved_ = false;
|
|
582
|
+
size_t original_N_ = 0;
|
|
583
|
+
|
|
584
|
+
std::unique_ptr<int8_t[]> unpacked_data_;
|
|
585
|
+
|
|
539
586
|
void parse_header();
|
|
540
587
|
void apply_madvise_hints();
|
|
541
|
-
void
|
|
588
|
+
void unpack_int4_data();
|
|
542
589
|
};
|
|
543
|
-
|
|
544
|
-
MappedFile mmap_load(const std::string& filename);
|
|
545
590
|
}
|
|
546
591
|
|
|
547
592
|
#endif
|
|
@@ -15,7 +15,7 @@ enum class ScalarOpType {
|
|
|
15
15
|
SIN
|
|
16
16
|
};
|
|
17
17
|
|
|
18
|
-
constexpr size_t KV_QUANT_GROUP_SIZE =
|
|
18
|
+
constexpr size_t KV_QUANT_GROUP_SIZE = 32;
|
|
19
19
|
|
|
20
20
|
void cactus_add_f16(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
|
|
21
21
|
void cactus_add_f16_clipped(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
|
|
@@ -38,14 +38,18 @@ void cactus_divide_broadcast_f16(const __fp16* a, const __fp16* b, __fp16* outpu
|
|
|
38
38
|
|
|
39
39
|
void cactus_scalar_op_f16(const __fp16* input, __fp16* output, size_t num_elements, float scalar_value, ScalarOpType op_type);
|
|
40
40
|
|
|
41
|
+
void cactus_gemv_int8(const int8_t* A, float A_scale,
|
|
42
|
+
const int8_t* B, const __fp16* B_scales,
|
|
43
|
+
__fp16* C, size_t K, size_t N, size_t group_size);
|
|
44
|
+
|
|
45
|
+
void cactus_gemm_int8(const int8_t* A, const float* A_scales,
|
|
46
|
+
const int8_t* B, const __fp16* B_scales,
|
|
47
|
+
__fp16* C, size_t M, size_t K, size_t N, size_t group_size);
|
|
48
|
+
|
|
41
49
|
void cactus_matmul_int8(const int8_t* A, const float* A_scales,
|
|
42
50
|
const int8_t* B, const __fp16* B_scales,
|
|
43
51
|
__fp16* C, size_t M, size_t K, size_t N, size_t group_size);
|
|
44
52
|
|
|
45
|
-
void cactus_matmul_int4(const int8_t* A, const float* A_scales,
|
|
46
|
-
const uint8_t* B_packed, const __fp16* B_scales,
|
|
47
|
-
__fp16* C, size_t M, size_t K, size_t N, size_t group_size);
|
|
48
|
-
|
|
49
53
|
void cactus_matmul_f16(const __fp16* a, const __fp16* b_transposed, __fp16* c,
|
|
50
54
|
size_t M, size_t K, size_t N);
|
|
51
55
|
|
|
@@ -75,15 +79,24 @@ void cactus_rms_norm_f16(const __fp16* input, const __fp16* weight, __fp16* outp
|
|
|
75
79
|
void cactus_rope_f16(const __fp16* input, __fp16* output, size_t batch_size, size_t seq_len,
|
|
76
80
|
size_t num_heads, size_t head_dim, size_t start_pos, float theta);
|
|
77
81
|
|
|
82
|
+
void cactus_gpt_j_rope_f16(const __fp16* input, __fp16* output, size_t batch_size, size_t seq_len,
|
|
83
|
+
size_t num_heads, size_t head_dim, size_t rot_dim, size_t start_pos, float theta);
|
|
84
|
+
|
|
78
85
|
void cactus_softmax_f16(const __fp16* input, __fp16* output, size_t batch_size,
|
|
79
86
|
size_t seq_len, size_t vocab_size);
|
|
80
87
|
|
|
88
|
+
void cactus_relu_f16(const __fp16* input, __fp16* output, size_t num_elements);
|
|
89
|
+
|
|
81
90
|
void cactus_silu_f16(const __fp16* input, __fp16* output, size_t num_elements);
|
|
82
91
|
|
|
83
92
|
void cactus_gelu_f16(const __fp16* input, __fp16* output, size_t num_elements);
|
|
84
93
|
|
|
85
94
|
void cactus_gelu_f16_erf(const __fp16* input, __fp16* output, size_t num_elements);
|
|
86
95
|
|
|
96
|
+
void cactus_sigmoid_f16(const __fp16* input, __fp16* output, size_t num_elements);
|
|
97
|
+
|
|
98
|
+
void cactus_tanh_f16(const __fp16* input, __fp16* output, size_t num_elements);
|
|
99
|
+
|
|
87
100
|
void cactus_attention_f16(const __fp16* queries, const __fp16* keys, const __fp16* values, __fp16* output,
|
|
88
101
|
size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
|
|
89
102
|
size_t head_dim, float scale, const __fp16* mask, size_t position_offset = 0, size_t window_size = 0,
|
|
@@ -100,7 +113,7 @@ void cactus_attention_hybrid_int8_fp16(
|
|
|
100
113
|
__fp16* output,
|
|
101
114
|
size_t batch_size, size_t seq_len, size_t cache_len, size_t new_len,
|
|
102
115
|
size_t num_q_heads, size_t num_kv_heads, size_t head_dim,
|
|
103
|
-
float scale, size_t position_offset = 0, bool is_causal = true,
|
|
116
|
+
float scale, size_t position_offset = 0, bool is_causal = true, size_t window_size = 0,
|
|
104
117
|
size_t group_size = KV_QUANT_GROUP_SIZE);
|
|
105
118
|
|
|
106
119
|
void cactus_conv1d_causal_depthwise_f16(
|
|
@@ -124,6 +137,40 @@ void cactus_conv1d_f16_k3(
|
|
|
124
137
|
size_t stride
|
|
125
138
|
);
|
|
126
139
|
|
|
140
|
+
void cactus_conv1d_f16(
|
|
141
|
+
const __fp16* input,
|
|
142
|
+
const __fp16* weight,
|
|
143
|
+
const __fp16* bias,
|
|
144
|
+
__fp16* output,
|
|
145
|
+
size_t N,
|
|
146
|
+
size_t L,
|
|
147
|
+
size_t C_in,
|
|
148
|
+
size_t C_out,
|
|
149
|
+
size_t K,
|
|
150
|
+
size_t stride
|
|
151
|
+
);
|
|
152
|
+
|
|
153
|
+
void cactus_stft_magnitude_f16(
|
|
154
|
+
const __fp16* input,
|
|
155
|
+
const __fp16* weight,
|
|
156
|
+
__fp16* output,
|
|
157
|
+
size_t N, size_t L,
|
|
158
|
+
size_t C_in, size_t C_out,
|
|
159
|
+
size_t K, size_t stride,
|
|
160
|
+
size_t num_fft_bins
|
|
161
|
+
);
|
|
162
|
+
|
|
163
|
+
void cactus_conv1d_f16_k7s3_oc8(
|
|
164
|
+
const __fp16* input,
|
|
165
|
+
const __fp16* Wpack,
|
|
166
|
+
const __fp16* bias,
|
|
167
|
+
__fp16* output,
|
|
168
|
+
size_t N,
|
|
169
|
+
size_t L,
|
|
170
|
+
size_t C_in,
|
|
171
|
+
size_t C_out
|
|
172
|
+
);
|
|
173
|
+
|
|
127
174
|
void cactus_bilinear_interpolation_f16(const __fp16* input, __fp16* output, size_t src_height, size_t src_width, size_t embed_dim,
|
|
128
175
|
size_t dst_height, size_t dst_width);
|
|
129
176
|
|
|
@@ -162,4 +209,19 @@ inline size_t kv_scales_count(size_t seq_len, size_t kv_heads, size_t head_dim,
|
|
|
162
209
|
|
|
163
210
|
void cactus_unpack_int4_to_int8(const uint8_t* packed, int8_t* unpacked, size_t unpacked_count);
|
|
164
211
|
|
|
212
|
+
void cactus_lstm_cell_f16(
|
|
213
|
+
const __fp16* x_input,
|
|
214
|
+
const __fp16* h_prev,
|
|
215
|
+
const __fp16* c_prev,
|
|
216
|
+
const __fp16* weight_ih,
|
|
217
|
+
const __fp16* weight_hh,
|
|
218
|
+
const __fp16* bias_ih,
|
|
219
|
+
const __fp16* bias_hh,
|
|
220
|
+
__fp16* h_new,
|
|
221
|
+
__fp16* c_new,
|
|
222
|
+
size_t batch_size,
|
|
223
|
+
size_t input_size,
|
|
224
|
+
size_t hidden_size
|
|
225
|
+
);
|
|
226
|
+
|
|
165
227
|
#endif
|