cactus-react-native 1.5.0 → 1.10.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Cactus.podspec +1 -1
- package/README.md +347 -241
- package/android/CMakeLists.txt +24 -5
- package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libcurl.a +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libmbedcrypto.a +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libmbedtls.a +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libmbedx509.a +0 -0
- package/cpp/HybridCactus.cpp +197 -117
- package/cpp/HybridCactus.hpp +18 -9
- package/cpp/cactus_ffi.h +66 -42
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus.h +0 -1
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_cloud.h +48 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +66 -42
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_utils.h +568 -135
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +148 -17
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +145 -36
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +187 -6
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel_utils.h +49 -149
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Info.plist +0 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus.h +0 -1
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_cloud.h +48 -0
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +66 -42
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_utils.h +568 -135
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +148 -17
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +145 -36
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +187 -6
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel_utils.h +49 -149
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Info.plist +0 -0
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/_CodeSignature/CodeResources +1 -1
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
- package/lib/module/classes/CactusLM.js +16 -49
- package/lib/module/classes/CactusLM.js.map +1 -1
- package/lib/module/classes/CactusSTT.js +41 -75
- package/lib/module/classes/CactusSTT.js.map +1 -1
- package/lib/module/classes/CactusVAD.js +95 -0
- package/lib/module/classes/CactusVAD.js.map +1 -0
- package/lib/module/hooks/useCactusLM.js +10 -11
- package/lib/module/hooks/useCactusLM.js.map +1 -1
- package/lib/module/hooks/useCactusSTT.js +23 -62
- package/lib/module/hooks/useCactusSTT.js.map +1 -1
- package/lib/module/hooks/useCactusVAD.js +171 -0
- package/lib/module/hooks/useCactusVAD.js.map +1 -0
- package/lib/module/index.js +2 -3
- package/lib/module/index.js.map +1 -1
- package/lib/module/modelRegistry.js +52 -0
- package/lib/module/modelRegistry.js.map +1 -0
- package/lib/module/native/Cactus.js +103 -23
- package/lib/module/native/Cactus.js.map +1 -1
- package/lib/module/native/CactusIndex.js.map +1 -1
- package/lib/module/native/index.js +0 -3
- package/lib/module/native/index.js.map +1 -1
- package/lib/module/types/CactusVAD.js +4 -0
- package/lib/module/{specs/CactusUtil.nitro.js.map → types/CactusVAD.js.map} +1 -1
- package/lib/typescript/src/classes/CactusLM.d.ts +5 -7
- package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
- package/lib/typescript/src/classes/CactusSTT.d.ts +9 -12
- package/lib/typescript/src/classes/CactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/classes/CactusVAD.d.ts +20 -0
- package/lib/typescript/src/classes/CactusVAD.d.ts.map +1 -0
- package/lib/typescript/src/hooks/useCactusLM.d.ts +2 -2
- package/lib/typescript/src/hooks/useCactusLM.d.ts.map +1 -1
- package/lib/typescript/src/hooks/useCactusSTT.d.ts +6 -8
- package/lib/typescript/src/hooks/useCactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/hooks/useCactusVAD.d.ts +15 -0
- package/lib/typescript/src/hooks/useCactusVAD.d.ts.map +1 -0
- package/lib/typescript/src/index.d.ts +7 -5
- package/lib/typescript/src/index.d.ts.map +1 -1
- package/lib/typescript/src/modelRegistry.d.ts +5 -0
- package/lib/typescript/src/modelRegistry.d.ts.map +1 -0
- package/lib/typescript/src/native/Cactus.d.ts +13 -11
- package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
- package/lib/typescript/src/native/CactusIndex.d.ts +2 -2
- package/lib/typescript/src/native/CactusIndex.d.ts.map +1 -1
- package/lib/typescript/src/native/index.d.ts +0 -3
- package/lib/typescript/src/native/index.d.ts.map +1 -1
- package/lib/typescript/src/specs/Cactus.nitro.d.ts +7 -6
- package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusIndex.d.ts +2 -2
- package/lib/typescript/src/types/CactusIndex.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusLM.d.ts +19 -11
- package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusSTT.d.ts +44 -12
- package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusVAD.d.ts +34 -0
- package/lib/typescript/src/types/CactusVAD.d.ts.map +1 -0
- package/lib/typescript/src/types/common.d.ts +1 -6
- package/lib/typescript/src/types/common.d.ts.map +1 -1
- package/nitro.json +0 -11
- package/nitrogen/generated/android/cactus+autolinking.cmake +0 -5
- package/nitrogen/generated/android/cactusOnLoad.cpp +0 -30
- package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.cpp +0 -50
- package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.hpp +9 -147
- package/nitrogen/generated/ios/Cactus-Swift-Cxx-Umbrella.hpp +0 -13
- package/nitrogen/generated/ios/CactusAutolinking.mm +0 -26
- package/nitrogen/generated/ios/CactusAutolinking.swift +0 -30
- package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +5 -4
- package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +7 -6
- package/package.json +3 -3
- package/src/classes/CactusLM.ts +18 -65
- package/src/classes/CactusSTT.ts +52 -90
- package/src/classes/CactusVAD.ts +129 -0
- package/src/hooks/useCactusLM.ts +14 -17
- package/src/hooks/useCactusSTT.ts +47 -98
- package/src/hooks/useCactusVAD.ts +215 -0
- package/src/index.tsx +21 -12
- package/src/modelRegistry.ts +65 -0
- package/src/native/Cactus.ts +131 -38
- package/src/native/CactusIndex.ts +2 -2
- package/src/native/index.ts +0 -3
- package/src/specs/Cactus.nitro.ts +16 -7
- package/src/types/CactusIndex.ts +2 -2
- package/src/types/CactusLM.ts +19 -11
- package/src/types/CactusSTT.ts +47 -13
- package/src/types/CactusVAD.ts +39 -0
- package/src/types/common.ts +1 -6
- package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusCrypto.kt +0 -46
- package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusDeviceInfo.kt +0 -27
- package/android/src/main/jniLibs/arm64-v8a/libcactus_util.a +0 -0
- package/cpp/HybridCactusUtil.cpp +0 -47
- package/cpp/HybridCactusUtil.hpp +0 -27
- package/cpp/cactus_util.h +0 -25
- package/ios/HybridCactusCrypto.swift +0 -37
- package/ios/HybridCactusDeviceInfo.swift +0 -32
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_telemetry.h +0 -656
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_telemetry.h +0 -656
- package/ios/cactus_util.xcframework/Info.plist +0 -39
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/cactus_util.h +0 -25
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/database.h +0 -27
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/ios_utils.h +0 -10
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/logging.h +0 -25
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Info.plist +0 -0
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/cactus_util +0 -0
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/cactus_util.h +0 -25
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/database.h +0 -27
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/ios_utils.h +0 -10
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/logging.h +0 -25
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Info.plist +0 -0
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/_CodeSignature/CodeResources +0 -135
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/cactus_util +0 -0
- package/lib/module/api/Database.js +0 -45
- package/lib/module/api/Database.js.map +0 -1
- package/lib/module/api/RemoteLM.js +0 -201
- package/lib/module/api/RemoteLM.js.map +0 -1
- package/lib/module/config/CactusConfig.js +0 -12
- package/lib/module/config/CactusConfig.js.map +0 -1
- package/lib/module/models.js +0 -336
- package/lib/module/models.js.map +0 -1
- package/lib/module/native/CactusCrypto.js +0 -10
- package/lib/module/native/CactusCrypto.js.map +0 -1
- package/lib/module/native/CactusDeviceInfo.js +0 -13
- package/lib/module/native/CactusDeviceInfo.js.map +0 -1
- package/lib/module/native/CactusUtil.js +0 -36
- package/lib/module/native/CactusUtil.js.map +0 -1
- package/lib/module/specs/CactusCrypto.nitro.js +0 -4
- package/lib/module/specs/CactusCrypto.nitro.js.map +0 -1
- package/lib/module/specs/CactusDeviceInfo.nitro.js +0 -4
- package/lib/module/specs/CactusDeviceInfo.nitro.js.map +0 -1
- package/lib/module/specs/CactusUtil.nitro.js +0 -4
- package/lib/module/telemetry/Telemetry.js +0 -154
- package/lib/module/telemetry/Telemetry.js.map +0 -1
- package/lib/typescript/src/api/Database.d.ts +0 -12
- package/lib/typescript/src/api/Database.d.ts.map +0 -1
- package/lib/typescript/src/api/RemoteLM.d.ts +0 -14
- package/lib/typescript/src/api/RemoteLM.d.ts.map +0 -1
- package/lib/typescript/src/config/CactusConfig.d.ts +0 -7
- package/lib/typescript/src/config/CactusConfig.d.ts.map +0 -1
- package/lib/typescript/src/models.d.ts +0 -6
- package/lib/typescript/src/models.d.ts.map +0 -1
- package/lib/typescript/src/native/CactusCrypto.d.ts +0 -5
- package/lib/typescript/src/native/CactusCrypto.d.ts.map +0 -1
- package/lib/typescript/src/native/CactusDeviceInfo.d.ts +0 -7
- package/lib/typescript/src/native/CactusDeviceInfo.d.ts.map +0 -1
- package/lib/typescript/src/native/CactusUtil.d.ts +0 -6
- package/lib/typescript/src/native/CactusUtil.d.ts.map +0 -1
- package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts +0 -8
- package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts.map +0 -1
- package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts +0 -16
- package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts.map +0 -1
- package/lib/typescript/src/specs/CactusUtil.nitro.d.ts +0 -10
- package/lib/typescript/src/specs/CactusUtil.nitro.d.ts.map +0 -1
- package/lib/typescript/src/telemetry/Telemetry.d.ts +0 -34
- package/lib/typescript/src/telemetry/Telemetry.d.ts.map +0 -1
- package/nitrogen/generated/android/c++/JDeviceInfo.hpp +0 -74
- package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.cpp +0 -65
- package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.hpp +0 -65
- package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.cpp +0 -85
- package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.hpp +0 -66
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/DeviceInfo.kt +0 -50
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusCryptoSpec.kt +0 -58
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusDeviceInfoSpec.kt +0 -62
- package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.cpp +0 -11
- package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.hpp +0 -77
- package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.cpp +0 -11
- package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.hpp +0 -88
- package/nitrogen/generated/ios/swift/DeviceInfo.swift +0 -98
- package/nitrogen/generated/ios/swift/Func_void_DeviceInfo.swift +0 -47
- package/nitrogen/generated/ios/swift/Func_void_std__optional_std__string_.swift +0 -54
- package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec.swift +0 -57
- package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec_cxx.swift +0 -139
- package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec.swift +0 -58
- package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec_cxx.swift +0 -164
- package/nitrogen/generated/shared/c++/DeviceInfo.hpp +0 -92
- package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.cpp +0 -21
- package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.hpp +0 -63
- package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.cpp +0 -22
- package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.hpp +0 -67
- package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.cpp +0 -23
- package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.hpp +0 -66
- package/src/api/Database.ts +0 -55
- package/src/api/RemoteLM.ts +0 -273
- package/src/config/CactusConfig.ts +0 -11
- package/src/models.ts +0 -344
- package/src/native/CactusCrypto.ts +0 -11
- package/src/native/CactusDeviceInfo.ts +0 -18
- package/src/native/CactusUtil.ts +0 -43
- package/src/specs/CactusCrypto.nitro.ts +0 -6
- package/src/specs/CactusDeviceInfo.nitro.ts +0 -15
- package/src/specs/CactusUtil.nitro.ts +0 -8
- package/src/telemetry/Telemetry.ts +0 -236
|
@@ -56,6 +56,12 @@ struct Config {
|
|
|
56
56
|
uint32_t num_shared_experts = 0;
|
|
57
57
|
uint32_t num_top_experts = 0;
|
|
58
58
|
uint32_t moe_every_n_layers = 0;
|
|
59
|
+
uint32_t moe_intermediate_dim = 0;
|
|
60
|
+
uint32_t num_dense_layers = 0;
|
|
61
|
+
uint32_t num_experts_per_tok = 0;
|
|
62
|
+
bool norm_topk_prob = false;
|
|
63
|
+
bool use_expert_bias = false;
|
|
64
|
+
float routed_scaling_factor = 1.0f;
|
|
59
65
|
bool tie_word_embeddings = true;
|
|
60
66
|
|
|
61
67
|
uint32_t vision_hidden_dim = 0;
|
|
@@ -84,12 +90,31 @@ struct Config {
|
|
|
84
90
|
bool use_thumbnail = true;
|
|
85
91
|
uint32_t min_image_tokens = 64;
|
|
86
92
|
uint32_t max_image_tokens = 256;
|
|
87
|
-
|
|
93
|
+
uint32_t max_num_patches = 1024;
|
|
88
94
|
uint32_t tile_size = 512;
|
|
89
95
|
float max_pixels_tolerance = 2.0f;
|
|
90
96
|
bool do_image_splitting = true;
|
|
91
|
-
|
|
92
|
-
|
|
97
|
+
bool encoder_act_gelu = false;
|
|
98
|
+
bool decoder_act_gelu = false;
|
|
99
|
+
uint32_t num_encoder_layers = 0;
|
|
100
|
+
uint32_t num_decoder_layers = 0;
|
|
101
|
+
float partial_rotary_factor = 0.0f;
|
|
102
|
+
uint32_t pad_token_id = 0;
|
|
103
|
+
uint32_t conv_kernel_size = 0;
|
|
104
|
+
uint32_t subsampling_conv_kernel_size = 0;
|
|
105
|
+
uint32_t subsampling_conv_stride = 0;
|
|
106
|
+
uint32_t subsampling_conv_channels = 0;
|
|
107
|
+
uint32_t subsampling_factor = 0;
|
|
108
|
+
uint32_t num_mel_bins = 80;
|
|
109
|
+
std::string encoder_hidden_act = "silu";
|
|
110
|
+
uint32_t predictor_hidden_dim = 0;
|
|
111
|
+
uint32_t predictor_num_layers = 0;
|
|
112
|
+
uint32_t tdt_joint_dim = 0;
|
|
113
|
+
uint32_t tdt_num_durations = 0;
|
|
114
|
+
uint32_t tdt_blank_id = 0;
|
|
115
|
+
std::vector<uint32_t> tdt_durations;
|
|
116
|
+
|
|
117
|
+
enum class ModelType {QWEN = 0, GEMMA = 1, NOMIC = 3, LFM2 = 5, SIGLIP2 = 6, WHISPER = 7, MOONSHINE = 8, SILERO_VAD = 9, PARAKEET = 10, PARAKEET_TDT = 11};
|
|
93
118
|
ModelType model_type = ModelType::QWEN;
|
|
94
119
|
|
|
95
120
|
enum class ModelVariant {DEFAULT = 0, VLM = 1, EXTRACT = 2, RAG = 3};
|
|
@@ -107,6 +132,8 @@ struct Config {
|
|
|
107
132
|
float default_temperature = 0.6f;
|
|
108
133
|
float default_top_p = 0.95f;
|
|
109
134
|
size_t default_top_k = 20;
|
|
135
|
+
float default_max_tps = -1.0f;
|
|
136
|
+
float default_cloud_handoff_threshold = 0.0f;
|
|
110
137
|
|
|
111
138
|
std::vector<std::string> layer_types;
|
|
112
139
|
size_t conv_L_cache = 0;
|
|
@@ -152,6 +179,7 @@ public:
|
|
|
152
179
|
virtual uint32_t get_bos_token() const = 0;
|
|
153
180
|
virtual uint32_t get_eos_token() const = 0;
|
|
154
181
|
virtual bool has_chat_template() const { return has_chat_template_; }
|
|
182
|
+
std::string get_default_stop_sequence() const;
|
|
155
183
|
|
|
156
184
|
virtual bool load_vocabulary_with_config(const std::string& vocab_file, const std::string& merges_file, const std::string& config_file) = 0;
|
|
157
185
|
|
|
@@ -159,11 +187,8 @@ public:
|
|
|
159
187
|
uint32_t get_fake_token_id() const { return fake_token_id_; }
|
|
160
188
|
uint32_t get_global_img_token_id() const { return global_img_token_id_; }
|
|
161
189
|
|
|
162
|
-
|
|
163
|
-
void set_corpus_dir(const std::string& dir) { corpus_dir_ = dir; }
|
|
164
|
-
|
|
165
190
|
protected:
|
|
166
|
-
enum class ModelType { UNKNOWN, QWEN, GEMMA, LFM2,
|
|
191
|
+
enum class ModelType { UNKNOWN, QWEN, GEMMA, LFM2, BERT, WHISPER, PARAKEET};
|
|
167
192
|
ModelType model_type_ = ModelType::UNKNOWN;
|
|
168
193
|
enum class ModelVariant { DEFAULT, VLM, EXTRACT, RAG};
|
|
169
194
|
ModelVariant model_variant_ = ModelVariant::DEFAULT;
|
|
@@ -173,14 +198,12 @@ protected:
|
|
|
173
198
|
uint32_t image_token_id_ = 396;
|
|
174
199
|
uint32_t fake_token_id_ = 49189;
|
|
175
200
|
uint32_t global_img_token_id_ = 49152;
|
|
176
|
-
std::string corpus_dir_;
|
|
177
201
|
|
|
178
202
|
void detect_model_type(const std::string& config_path);
|
|
179
203
|
std::string format_qwen_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
|
|
180
204
|
std::string format_gemma_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
|
|
181
205
|
std::string format_lfm2_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
|
|
182
206
|
std::string format_lfm2_vl_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
|
|
183
|
-
std::string format_smol_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
|
|
184
207
|
};
|
|
185
208
|
|
|
186
209
|
class BPETokenizer : public Tokenizer {
|
|
@@ -363,7 +386,6 @@ struct KVCache {
|
|
|
363
386
|
size_t num_tokens, size_t kv_heads, size_t head_dim);
|
|
364
387
|
|
|
365
388
|
bool is_empty() const { return current_seq_len == 0; }
|
|
366
|
-
bool is_int8() const { return precision == Precision::INT8; }
|
|
367
389
|
void* get_key_ptr(size_t layer);
|
|
368
390
|
void* get_value_ptr(size_t layer);
|
|
369
391
|
|
|
@@ -471,6 +493,8 @@ private:
|
|
|
471
493
|
void compute_bias();
|
|
472
494
|
void tokenize_grammar_elements();
|
|
473
495
|
void add_tokens_for_string(const std::string& str, std::unordered_set<uint32_t>& token_set);
|
|
496
|
+
void tokenize_function_names(bool quote_names);
|
|
497
|
+
void init_common_tokens();
|
|
474
498
|
};
|
|
475
499
|
|
|
476
500
|
class Model {
|
|
@@ -495,22 +519,22 @@ public:
|
|
|
495
519
|
const std::string& system_prompt = "", bool do_warmup = true);
|
|
496
520
|
|
|
497
521
|
virtual uint32_t decode(const std::vector<uint32_t>& tokens, float temperature = -1.0f, float top_p = -1.0f,
|
|
498
|
-
size_t top_k = 0, const std::string& profile_file = "");
|
|
522
|
+
size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr);
|
|
499
523
|
|
|
500
524
|
virtual void prefill(const std::vector<uint32_t>& tokens, size_t chunk_size = 256, const std::string& profile_file = "");
|
|
501
525
|
|
|
502
526
|
virtual uint32_t decode_with_images(const std::vector<uint32_t>& tokens, const std::vector<std::string>& image_paths,
|
|
503
527
|
float temperature = -1.0f, float top_p = -1.0f,
|
|
504
|
-
size_t top_k = 0, const std::string& profile_file = "");
|
|
528
|
+
size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr);
|
|
505
529
|
|
|
506
|
-
virtual uint32_t decode_with_audio(const std::vector<uint32_t>& tokens, const std::vector<float>&
|
|
507
|
-
size_t top_k = 0, const std::string& profile_file = "");
|
|
530
|
+
virtual uint32_t decode_with_audio(const std::vector<uint32_t>& tokens, const std::vector<float>& audio_features, float temperature = 0.0f, float top_p = 0.0f,
|
|
531
|
+
size_t top_k = 0, const std::string& profile_file = "", float* out_entropy = nullptr);
|
|
508
532
|
|
|
509
533
|
std::vector<float> get_embeddings(const std::vector<uint32_t>& tokens, bool pooled = true, bool normalize = false, const std::string& profile_file = "");
|
|
510
534
|
|
|
511
535
|
virtual std::vector<float> get_image_embeddings(const std::string& image_path);
|
|
512
536
|
|
|
513
|
-
virtual std::vector<float> get_audio_embeddings(const std::vector<float>&
|
|
537
|
+
virtual std::vector<float> get_audio_embeddings(const std::vector<float>& audio_features);
|
|
514
538
|
|
|
515
539
|
virtual void reset_cache() { kv_cache_.reset(); }
|
|
516
540
|
|
|
@@ -533,7 +557,7 @@ public:
|
|
|
533
557
|
protected:
|
|
534
558
|
virtual size_t forward(const std::vector<uint32_t>& tokens, bool use_cache = false) = 0;
|
|
535
559
|
|
|
536
|
-
virtual size_t forward(const std::vector<float>&
|
|
560
|
+
virtual size_t forward(const std::vector<float>& audio_features, const std::vector<uint32_t>& tokens, bool use_cache = false);
|
|
537
561
|
|
|
538
562
|
virtual void load_weights_to_graph(CactusGraph* gb) = 0;
|
|
539
563
|
|
|
@@ -645,6 +669,7 @@ public:
|
|
|
645
669
|
private:
|
|
646
670
|
Config config_;
|
|
647
671
|
|
|
672
|
+
std::pair<int64_t, int64_t> compute_pixel_limits() const;
|
|
648
673
|
std::vector<unsigned char> convert_to_rgb(const unsigned char* img_data, int width, int height, int channels);
|
|
649
674
|
std::pair<int, int> smart_resize(int height, int width);
|
|
650
675
|
bool is_image_too_large(int height, int width);
|
|
@@ -678,6 +703,8 @@ public:
|
|
|
678
703
|
float reference = 1.0f;
|
|
679
704
|
float min_value = 1e-10f;
|
|
680
705
|
bool remove_dc_offset = false;
|
|
706
|
+
float preemphasis = 0.0f;
|
|
707
|
+
bool hann_periodic = true;
|
|
681
708
|
};
|
|
682
709
|
|
|
683
710
|
AudioProcessor();
|
|
@@ -690,6 +717,11 @@ public:
|
|
|
690
717
|
const std::vector<float>& waveform,
|
|
691
718
|
const SpectrogramConfig& config);
|
|
692
719
|
|
|
720
|
+
static std::vector<float> compute_irfft(
|
|
721
|
+
const std::vector<float>& complex_input,
|
|
722
|
+
size_t n,
|
|
723
|
+
const char* norm = "backward");
|
|
724
|
+
|
|
693
725
|
const std::vector<float>& get_mel_filters() const { return mel_filters_; }
|
|
694
726
|
|
|
695
727
|
size_t get_num_mel_filters() const { return num_mel_filters_; }
|
|
@@ -701,5 +733,104 @@ private:
|
|
|
701
733
|
size_t num_mel_filters_;
|
|
702
734
|
};
|
|
703
735
|
|
|
736
|
+
namespace index {
|
|
737
|
+
constexpr uint32_t MAGIC = 0x43414354;
|
|
738
|
+
constexpr uint32_t VERSION = 1;
|
|
739
|
+
|
|
740
|
+
struct Document {
|
|
741
|
+
int id;
|
|
742
|
+
std::vector<float> embedding;
|
|
743
|
+
std::string content;
|
|
744
|
+
std::string metadata;
|
|
745
|
+
};
|
|
746
|
+
|
|
747
|
+
struct QueryResult {
|
|
748
|
+
int doc_id;
|
|
749
|
+
float score;
|
|
750
|
+
|
|
751
|
+
QueryResult(int doc_id, float score) : doc_id(doc_id), score(score) {}
|
|
752
|
+
};
|
|
753
|
+
|
|
754
|
+
struct QueryOptions {
|
|
755
|
+
size_t top_k = 10;
|
|
756
|
+
float score_threshold = -1.0f;
|
|
757
|
+
};
|
|
758
|
+
|
|
759
|
+
class Index {
|
|
760
|
+
public:
|
|
761
|
+
Index(const std::string& index_path, const std::string& data_path, size_t embedding_dim);
|
|
762
|
+
~Index();
|
|
763
|
+
|
|
764
|
+
Index(const Index&) = delete;
|
|
765
|
+
Index& operator=(const Index&) = delete;
|
|
766
|
+
Index(Index&&) = delete;
|
|
767
|
+
Index& operator=(Index&&) = delete;
|
|
768
|
+
|
|
769
|
+
void add_documents(const std::vector<Document>& documents);
|
|
770
|
+
void delete_documents(const std::vector<int>& doc_ids);
|
|
771
|
+
std::vector<Document> get_documents(const std::vector<int>& doc_ids);
|
|
772
|
+
std::vector<std::vector<QueryResult>> query(const std::vector<std::vector<float>>& embeddings, const QueryOptions& options);
|
|
773
|
+
void compact();
|
|
774
|
+
|
|
775
|
+
private:
|
|
776
|
+
struct IndexHeader {
|
|
777
|
+
uint32_t magic;
|
|
778
|
+
uint32_t version;
|
|
779
|
+
uint32_t embedding_dim;
|
|
780
|
+
uint32_t num_documents;
|
|
781
|
+
};
|
|
782
|
+
|
|
783
|
+
struct IndexEntry {
|
|
784
|
+
int32_t doc_id;
|
|
785
|
+
uint64_t data_offset;
|
|
786
|
+
uint8_t flags; // bit 0: tombstone
|
|
787
|
+
|
|
788
|
+
const __fp16* embedding() const {
|
|
789
|
+
return reinterpret_cast<const __fp16*>(this + 1);
|
|
790
|
+
}
|
|
791
|
+
|
|
792
|
+
static size_t size(size_t embedding_dim) {
|
|
793
|
+
return sizeof(IndexEntry) + embedding_dim * sizeof(__fp16);
|
|
794
|
+
}
|
|
795
|
+
};
|
|
796
|
+
|
|
797
|
+
struct DataHeader {
|
|
798
|
+
uint32_t magic;
|
|
799
|
+
uint32_t version;
|
|
800
|
+
};
|
|
801
|
+
|
|
802
|
+
struct DataEntry {
|
|
803
|
+
uint16_t content_len;
|
|
804
|
+
uint16_t metadata_len;
|
|
805
|
+
|
|
806
|
+
const char* content() const {
|
|
807
|
+
return reinterpret_cast<const char*>(this + 1);
|
|
808
|
+
}
|
|
809
|
+
|
|
810
|
+
const char* metadata() const {
|
|
811
|
+
return content() + content_len;
|
|
812
|
+
}
|
|
813
|
+
};
|
|
814
|
+
|
|
815
|
+
void parse_index_header();
|
|
816
|
+
void parse_data_header();
|
|
817
|
+
void build_doc_id_map();
|
|
818
|
+
void validate_documents(const std::vector<Document>& documents);
|
|
819
|
+
void validate_doc_ids(const std::vector<int>& doc_ids);
|
|
820
|
+
ssize_t write_full(int fd, const void* buf, size_t count);
|
|
821
|
+
|
|
822
|
+
std::unordered_map<int, uint32_t> doc_id_map_;
|
|
823
|
+
|
|
824
|
+
std::string index_path_, data_path_;
|
|
825
|
+
size_t embedding_dim_;
|
|
826
|
+
size_t index_entry_size_;
|
|
827
|
+
uint32_t num_documents_;
|
|
828
|
+
|
|
829
|
+
int index_fd_, data_fd_;
|
|
830
|
+
void *mapped_index_, *mapped_data_;
|
|
831
|
+
size_t index_file_size_, data_file_size_;
|
|
832
|
+
};
|
|
833
|
+
} // namespace index
|
|
834
|
+
|
|
835
|
+
}
|
|
704
836
|
}
|
|
705
|
-
}
|
|
@@ -4,7 +4,9 @@
|
|
|
4
4
|
#include <vector>
|
|
5
5
|
#include <memory>
|
|
6
6
|
#include <unordered_map>
|
|
7
|
+
#include <unordered_set>
|
|
7
8
|
#include <functional>
|
|
9
|
+
#include <cassert>
|
|
8
10
|
#include <cstring>
|
|
9
11
|
#include <stdexcept>
|
|
10
12
|
#include <string>
|
|
@@ -108,23 +110,36 @@ enum class ComputeBackend {
|
|
|
108
110
|
NPU
|
|
109
111
|
};
|
|
110
112
|
|
|
113
|
+
enum class Activation {
|
|
114
|
+
SILU,
|
|
115
|
+
GELU,
|
|
116
|
+
GELU_ERF,
|
|
117
|
+
RELU,
|
|
118
|
+
SIGMOID,
|
|
119
|
+
TANH
|
|
120
|
+
};
|
|
121
|
+
|
|
111
122
|
enum class OpType {
|
|
112
123
|
INPUT, PRECISION_CAST,
|
|
113
124
|
ADD, ADD_CLIPPED, SUBTRACT, MULTIPLY, DIVIDE,
|
|
114
125
|
MATMUL, TRANSPOSE, RESHAPE, SLICE, GATHER, EMBEDDING,
|
|
115
126
|
BILINEAR_INTERPOLATION,
|
|
116
127
|
SUM, MEAN, VARIANCE, MIN, MAX,
|
|
117
|
-
RMS_NORM, ROPE, SOFTMAX, ATTENTION, ATTENTION_INT8_HYBRID, CONV1D_CAUSAL, CONV1D_K3,
|
|
118
|
-
SCALAR_ADD, SCALAR_SUBTRACT, SCALAR_MULTIPLY, SCALAR_DIVIDE, SCALAR_EXP, SCALAR_SQRT, SCALAR_COS, SCALAR_SIN,
|
|
119
|
-
SILU, GELU, GELU_ERF,
|
|
128
|
+
RMS_NORM, ROPE, ROPE_GPTJ, SOFTMAX, ATTENTION, ATTENTION_INT8_HYBRID, REL_POS_BIAS, CONV1D_CAUSAL, CONV1D_K3, CONV1D_K7S3, CONV1D, CONV1D_SAME_DEPTHWISE_K9, CONV1D_POINTWISE, CONV2D_K3S2P1, CONV2D_DEPTHWISE_K3S2P1, CONV2D_POINTWISE_1X1, GLU, BATCHNORM,
|
|
129
|
+
SCALAR_ADD, SCALAR_SUBTRACT, SCALAR_MULTIPLY, SCALAR_DIVIDE, SCALAR_EXP, SCALAR_SQRT, SCALAR_COS, SCALAR_SIN, SCALAR_LOG,
|
|
130
|
+
RELU, SILU, GELU, GELU_ERF, SIGMOID, TANH,
|
|
120
131
|
SAMPLE, CONCAT,
|
|
121
132
|
SCATTER_TOPK,
|
|
122
|
-
TOPK, LAYERNORM,
|
|
133
|
+
TOPK, LAYERNORM, GROUPNORM,
|
|
134
|
+
MOE_LAYER,
|
|
123
135
|
INDEX,
|
|
136
|
+
PERSISTENT,
|
|
137
|
+
QUANTIZE_ACTIVATIONS,
|
|
138
|
+
LSTM_CELL,
|
|
139
|
+
STFT
|
|
124
140
|
};
|
|
125
141
|
|
|
126
142
|
struct PrecisionTraits {
|
|
127
|
-
// Returns in-memory element size (INT4 unpacks to INT8, so returns 1)
|
|
128
143
|
static constexpr size_t size_of(Precision prec) {
|
|
129
144
|
switch (prec) {
|
|
130
145
|
case Precision::INT8: return 1;
|
|
@@ -137,11 +152,20 @@ struct PrecisionTraits {
|
|
|
137
152
|
|
|
138
153
|
static constexpr size_t packed_size_of(Precision prec, size_t count) {
|
|
139
154
|
switch (prec) {
|
|
140
|
-
case Precision::INT4: return (count + 1) / 2;
|
|
155
|
+
case Precision::INT4: return (count + 1) / 2;
|
|
141
156
|
default: return count * size_of(prec);
|
|
142
157
|
}
|
|
143
158
|
}
|
|
144
159
|
|
|
160
|
+
static size_t byte_offset_of(Precision prec, size_t element_offset) {
|
|
161
|
+
switch (prec) {
|
|
162
|
+
case Precision::INT4:
|
|
163
|
+
assert(element_offset % 32 == 0 && "INT4 byte offset must be group-aligned (multiple of 32)");
|
|
164
|
+
return element_offset / 2;
|
|
165
|
+
default: return element_offset * size_of(prec);
|
|
166
|
+
}
|
|
167
|
+
}
|
|
168
|
+
|
|
145
169
|
static constexpr bool is_integer(Precision prec) {
|
|
146
170
|
switch (prec) {
|
|
147
171
|
case Precision::INT8: return true;
|
|
@@ -177,7 +201,6 @@ struct TensorConfig {
|
|
|
177
201
|
Precision compute_precision = Precision::INT8;
|
|
178
202
|
Precision output_precision = Precision::INT8;
|
|
179
203
|
bool auto_mixed_precision = false;
|
|
180
|
-
bool enable_int4_packing = true;
|
|
181
204
|
|
|
182
205
|
static TensorConfig& global();
|
|
183
206
|
};
|
|
@@ -205,8 +228,12 @@ struct BufferDesc {
|
|
|
205
228
|
void* scales_data = nullptr;
|
|
206
229
|
std::unique_ptr<char[]> owned_scales;
|
|
207
230
|
|
|
208
|
-
|
|
209
|
-
size_t
|
|
231
|
+
bool is_interleaved = false;
|
|
232
|
+
size_t original_N = 0;
|
|
233
|
+
|
|
234
|
+
void* activation_scales_data = nullptr;
|
|
235
|
+
std::unique_ptr<char[]> owned_activation_scales;
|
|
236
|
+
size_t num_rows_for_activation_scales = 0;
|
|
210
237
|
|
|
211
238
|
BufferDesc();
|
|
212
239
|
BufferDesc(const std::vector<size_t>& s, Precision prec = Precision::INT8);
|
|
@@ -230,23 +257,43 @@ struct BufferDesc {
|
|
|
230
257
|
const __fp16* scales_as_fp16() const {
|
|
231
258
|
return reinterpret_cast<const __fp16*>(scales_data);
|
|
232
259
|
}
|
|
260
|
+
|
|
233
261
|
bool is_grouped_int8() const {
|
|
234
262
|
return precision == Precision::INT8 && group_size > 0;
|
|
235
263
|
}
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
const uint8_t* packed_int4_as_uint8() const {
|
|
240
|
-
return reinterpret_cast<const uint8_t*>(packed_int4_data);
|
|
264
|
+
|
|
265
|
+
bool is_grouped_int4() const {
|
|
266
|
+
return precision == Precision::INT4 && group_size > 0;
|
|
241
267
|
}
|
|
268
|
+
|
|
242
269
|
void set_grouped_scales(size_t gs, size_t ng, void* scales_ptr) {
|
|
243
270
|
group_size = gs;
|
|
244
271
|
num_groups = ng;
|
|
245
272
|
scales_data = scales_ptr;
|
|
246
273
|
}
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
274
|
+
|
|
275
|
+
void set_interleaved(bool interleaved, size_t orig_n) {
|
|
276
|
+
is_interleaved = interleaved;
|
|
277
|
+
original_N = orig_n;
|
|
278
|
+
}
|
|
279
|
+
|
|
280
|
+
bool has_activation_scales() const {
|
|
281
|
+
return activation_scales_data != nullptr && num_rows_for_activation_scales > 0;
|
|
282
|
+
}
|
|
283
|
+
const float* activation_scales_as_float() const {
|
|
284
|
+
return reinterpret_cast<const float*>(activation_scales_data);
|
|
285
|
+
}
|
|
286
|
+
float* activation_scales_as_float() {
|
|
287
|
+
return reinterpret_cast<float*>(activation_scales_data);
|
|
288
|
+
}
|
|
289
|
+
void allocate_activation_scales(size_t num_rows) {
|
|
290
|
+
num_rows_for_activation_scales = num_rows;
|
|
291
|
+
owned_activation_scales = std::make_unique<char[]>(num_rows * sizeof(float));
|
|
292
|
+
activation_scales_data = owned_activation_scales.get();
|
|
293
|
+
}
|
|
294
|
+
void set_activation_scales(void* scales_ptr, size_t num_rows) {
|
|
295
|
+
activation_scales_data = scales_ptr;
|
|
296
|
+
num_rows_for_activation_scales = num_rows;
|
|
250
297
|
}
|
|
251
298
|
|
|
252
299
|
void allocate();
|
|
@@ -267,6 +314,7 @@ struct OpParams {
|
|
|
267
314
|
size_t slice_length = 0;
|
|
268
315
|
size_t window_size = 0;
|
|
269
316
|
bool is_causal = true;
|
|
317
|
+
bool attention_mask_is_additive = false;
|
|
270
318
|
std::vector<size_t> new_shape;
|
|
271
319
|
std::vector<size_t> permutation;
|
|
272
320
|
Precision output_precision = Precision::INT8;
|
|
@@ -282,8 +330,14 @@ struct OpParams {
|
|
|
282
330
|
|
|
283
331
|
size_t index_value = 0;
|
|
284
332
|
size_t num_classes = 0;
|
|
333
|
+
size_t num_groups = 0;
|
|
285
334
|
size_t dst_height = 0;
|
|
286
335
|
size_t dst_width = 0;
|
|
336
|
+
bool normalize_routing = false;
|
|
337
|
+
size_t num_experts = 0;
|
|
338
|
+
size_t num_experts_per_tok = 0;
|
|
339
|
+
bool moe_gated = true;
|
|
340
|
+
Activation activation = Activation::SILU;
|
|
287
341
|
|
|
288
342
|
std::vector<float> bias_values;
|
|
289
343
|
std::vector<uint32_t> bias_indices;
|
|
@@ -295,6 +349,7 @@ struct OpParams {
|
|
|
295
349
|
size_t cache_seq_len = 0;
|
|
296
350
|
size_t num_kv_heads = 0;
|
|
297
351
|
size_t head_dim = 0;
|
|
352
|
+
size_t num_fft_bins = 0;
|
|
298
353
|
};
|
|
299
354
|
|
|
300
355
|
struct GraphNode {
|
|
@@ -324,10 +379,12 @@ void compute_sample_node(GraphNode& node, const std::vector<std::unique_ptr<Grap
|
|
|
324
379
|
void compute_scatter_topk_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
325
380
|
void compute_topk_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
326
381
|
void compute_layernorm_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
382
|
+
void compute_groupnorm_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
383
|
+
void compute_persistent_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
327
384
|
void compute_index_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
385
|
+
void compute_lstm_cell_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
328
386
|
|
|
329
387
|
void shrink_thread_local_buffers();
|
|
330
|
-
|
|
331
388
|
class BufferPool {
|
|
332
389
|
public:
|
|
333
390
|
BufferPool() = default;
|
|
@@ -372,6 +429,7 @@ public:
|
|
|
372
429
|
|
|
373
430
|
size_t input(const std::vector<size_t>& shape, Precision precision = Precision::INT8);
|
|
374
431
|
size_t precision_cast(size_t input, Precision target_precision);
|
|
432
|
+
size_t quantize_activations(size_t input);
|
|
375
433
|
|
|
376
434
|
size_t add(size_t input1, size_t input2);
|
|
377
435
|
size_t add_clipped(size_t input1, size_t input2);
|
|
@@ -388,10 +446,15 @@ public:
|
|
|
388
446
|
size_t scalar_sqrt(size_t input);
|
|
389
447
|
size_t scalar_cos(size_t input);
|
|
390
448
|
size_t scalar_sin(size_t input);
|
|
449
|
+
size_t scalar_log(size_t input);
|
|
391
450
|
|
|
451
|
+
size_t relu(size_t input);
|
|
392
452
|
size_t silu(size_t input);
|
|
393
453
|
size_t gelu(size_t input);
|
|
394
454
|
size_t gelu_erf(size_t input);
|
|
455
|
+
size_t sigmoid(size_t input);
|
|
456
|
+
size_t tanh(size_t input);
|
|
457
|
+
size_t glu(size_t input, int axis = -1);
|
|
395
458
|
|
|
396
459
|
size_t matmul(size_t input1, size_t input2, bool pretransposed_rhs = false, ComputeBackend backend = ComputeBackend::CPU);
|
|
397
460
|
size_t transpose(size_t input, ComputeBackend backend = ComputeBackend::CPU);
|
|
@@ -409,8 +472,8 @@ public:
|
|
|
409
472
|
size_t gather(size_t embeddings, size_t indices);
|
|
410
473
|
size_t mmap_embeddings(const std::string& filename);
|
|
411
474
|
size_t mmap_weights(const std::string& filename);
|
|
412
|
-
size_t load_weights(const std::string& filename);
|
|
413
475
|
void set_grouped_scales(size_t node_id, size_t group_size, size_t num_groups, void* scales_ptr);
|
|
476
|
+
void set_interleaved(size_t node_id, bool interleaved, size_t original_N);
|
|
414
477
|
|
|
415
478
|
void release_weight_pages(size_t node_id);
|
|
416
479
|
void prefetch_weight_pages(size_t node_id);
|
|
@@ -420,22 +483,68 @@ public:
|
|
|
420
483
|
size_t bilinear_interpolation(size_t pos_embeds, size_t dst_height, size_t dst_width);
|
|
421
484
|
|
|
422
485
|
size_t layernorm(size_t input, size_t weight, size_t bias, float epsilon = 1e-5f);
|
|
486
|
+
size_t layernorm(size_t input, size_t weight, float epsilon = 1e-5f); // No bias version
|
|
487
|
+
size_t groupnorm(size_t input, size_t weight, size_t bias, size_t num_groups = 32, float epsilon = 1e-5f);
|
|
488
|
+
size_t batchnorm(size_t input, size_t weight, size_t bias, size_t running_mean, size_t running_var, int axis = 1, float epsilon = 1e-5f);
|
|
423
489
|
size_t topk(size_t input, size_t k);
|
|
490
|
+
size_t moe_layer(size_t hidden,
|
|
491
|
+
size_t routing_probs,
|
|
492
|
+
size_t topk_indices,
|
|
493
|
+
const std::vector<size_t>& w1_weights,
|
|
494
|
+
const std::vector<size_t>& w3_weights,
|
|
495
|
+
const std::vector<size_t>& w2_weights,
|
|
496
|
+
size_t num_experts,
|
|
497
|
+
size_t num_experts_per_tok,
|
|
498
|
+
bool normalize_routing,
|
|
499
|
+
float epsilon,
|
|
500
|
+
float routed_scaling_factor);
|
|
501
|
+
size_t moe_layer(size_t hidden,
|
|
502
|
+
size_t routing_probs,
|
|
503
|
+
size_t topk_indices,
|
|
504
|
+
const std::vector<size_t>& w1_weights,
|
|
505
|
+
const std::vector<size_t>& w2_weights,
|
|
506
|
+
size_t num_experts,
|
|
507
|
+
size_t num_experts_per_tok,
|
|
508
|
+
bool normalize_routing,
|
|
509
|
+
float epsilon,
|
|
510
|
+
float routed_scaling_factor,
|
|
511
|
+
Activation activation);
|
|
424
512
|
size_t rms_norm(size_t input, size_t weight, float epsilon = 1e-5f);
|
|
425
513
|
size_t rope(size_t input, float theta, size_t position_offset = 0, ComputeBackend backend = ComputeBackend::CPU);
|
|
514
|
+
size_t rope_gptj(size_t input, float theta, size_t position_offset = 0, size_t rot_dim = 0, ComputeBackend backend = ComputeBackend::CPU);
|
|
426
515
|
size_t softmax(size_t input, int axis = -1);
|
|
427
516
|
size_t attention(size_t query, size_t key, size_t value, float scale, bool is_causal = true, ComputeBackend backend = ComputeBackend::CPU);
|
|
428
517
|
size_t attention(size_t query, size_t key, size_t value, float scale, size_t position_offset, ComputeBackend backend = ComputeBackend::CPU);
|
|
429
518
|
size_t attention(size_t query, size_t key, size_t value, float scale, size_t position_offset, size_t window_size, ComputeBackend backend = ComputeBackend::CPU);
|
|
519
|
+
size_t attention_masked(size_t query, size_t key, size_t value, size_t mask, float scale,
|
|
520
|
+
bool is_causal = true, ComputeBackend backend = ComputeBackend::CPU,
|
|
521
|
+
bool additive_mask = false, size_t position_offset = 0, size_t window_size = 0);
|
|
522
|
+
size_t rel_pos_bias(size_t query, size_t relative_key, float scale);
|
|
430
523
|
|
|
431
524
|
size_t attention_int8_hybrid(size_t query, size_t key_new, size_t value_new, float scale, size_t position_offset,
|
|
432
525
|
const int8_t* cached_keys, const int8_t* cached_values,
|
|
433
526
|
const float* k_scales, const float* v_scales,
|
|
434
|
-
size_t cache_len, size_t num_kv_heads, size_t head_dim);
|
|
527
|
+
size_t cache_len, size_t num_kv_heads, size_t head_dim, size_t window_size = 0);
|
|
435
528
|
|
|
436
529
|
size_t conv1d_causal(size_t input, size_t weight, size_t kernel_size, size_t dilation = 1);
|
|
437
530
|
size_t conv1d_k3(size_t input, size_t weight, size_t stride);
|
|
438
|
-
|
|
531
|
+
size_t conv1d_k7s3(size_t input, size_t weight, size_t bias);
|
|
532
|
+
size_t conv1d(size_t input, size_t weight, size_t stride);
|
|
533
|
+
size_t conv1d(size_t input, size_t weight, size_t bias, size_t stride);
|
|
534
|
+
size_t conv1d_same_depthwise_k9(size_t input, size_t weight);
|
|
535
|
+
size_t conv1d_same_depthwise_k9(size_t input, size_t weight, size_t bias);
|
|
536
|
+
size_t conv1d_pointwise(size_t input, size_t weight);
|
|
537
|
+
size_t conv1d_pointwise(size_t input, size_t weight, size_t bias);
|
|
538
|
+
size_t conv2d_k3s2p1(size_t input, size_t weight);
|
|
539
|
+
size_t conv2d_k3s2p1(size_t input, size_t weight, size_t bias);
|
|
540
|
+
size_t conv2d_depthwise_k3s2p1(size_t input, size_t weight);
|
|
541
|
+
size_t conv2d_depthwise_k3s2p1(size_t input, size_t weight, size_t bias);
|
|
542
|
+
size_t conv2d_pointwise_1x1(size_t input, size_t weight);
|
|
543
|
+
size_t conv2d_pointwise_1x1(size_t input, size_t weight, size_t bias);
|
|
544
|
+
|
|
545
|
+
size_t lstm_cell(size_t input, size_t h_prev, size_t c_prev, size_t weight_ih, size_t weight_hh, size_t bias_ih, size_t bias_hh);
|
|
546
|
+
size_t stft(size_t input, size_t weight, size_t stride, size_t num_fft_bins);
|
|
547
|
+
|
|
439
548
|
size_t sample(size_t logits, float temperature = 0.6f, float top_p = 0.95f, size_t top_k = 20,
|
|
440
549
|
const std::unordered_map<uint32_t, float>& logit_bias = {});
|
|
441
550
|
|
|
@@ -462,6 +571,10 @@ public:
|
|
|
462
571
|
void allocate_buffers();
|
|
463
572
|
size_t get_node_count() const;
|
|
464
573
|
|
|
574
|
+
size_t persistent(size_t source_node);
|
|
575
|
+
bool is_populated(size_t persistent_node_id) const;
|
|
576
|
+
void invalidate_persistent(size_t persistent_node_id);
|
|
577
|
+
|
|
465
578
|
std::vector<std::unique_ptr<GraphNode>> nodes_;
|
|
466
579
|
std::unordered_map<size_t, size_t> node_index_map_;
|
|
467
580
|
|
|
@@ -473,6 +586,9 @@ private:
|
|
|
473
586
|
std::vector<DebugNodeEntry> debug_nodes_;
|
|
474
587
|
BufferPool buffer_pool_;
|
|
475
588
|
bool prefill_mode_ = false;
|
|
589
|
+
|
|
590
|
+
std::unordered_set<size_t> persistent_node_ids_;
|
|
591
|
+
std::unordered_set<size_t> populated_node_ids_;
|
|
476
592
|
};
|
|
477
593
|
|
|
478
594
|
|
|
@@ -485,7 +601,6 @@ namespace GraphFile {
|
|
|
485
601
|
};
|
|
486
602
|
|
|
487
603
|
void save_node(CactusGraph& graph, size_t node_id, const std::string& filename);
|
|
488
|
-
LoadedNode load_into_graph(CactusGraph& graph, const std::string& filename);
|
|
489
604
|
|
|
490
605
|
class MappedFile {
|
|
491
606
|
public:
|
|
@@ -499,16 +614,14 @@ namespace GraphFile {
|
|
|
499
614
|
|
|
500
615
|
const std::vector<size_t>& shape() const;
|
|
501
616
|
Precision precision() const;
|
|
502
|
-
Precision effective_precision() const {
|
|
503
|
-
return is_int4_ ? Precision::INT8 : precision_;
|
|
504
|
-
}
|
|
505
617
|
size_t byte_size() const;
|
|
506
618
|
|
|
507
619
|
size_t group_size() const { return group_size_; }
|
|
508
620
|
size_t num_groups() const { return num_groups_; }
|
|
509
621
|
const void* scales_data() const;
|
|
510
|
-
|
|
511
|
-
bool
|
|
622
|
+
|
|
623
|
+
bool is_interleaved() const { return is_interleaved_; }
|
|
624
|
+
size_t original_N() const { return original_N_; }
|
|
512
625
|
|
|
513
626
|
void* data();
|
|
514
627
|
const void* data() const;
|
|
@@ -516,8 +629,6 @@ namespace GraphFile {
|
|
|
516
629
|
template<typename T>
|
|
517
630
|
const T* typed_data() const;
|
|
518
631
|
|
|
519
|
-
LoadedNode load_into_graph(CactusGraph& graph) const;
|
|
520
|
-
|
|
521
632
|
void release_pages();
|
|
522
633
|
void prefetch_pages();
|
|
523
634
|
|
|
@@ -532,16 +643,14 @@ namespace GraphFile {
|
|
|
532
643
|
size_t num_groups_ = 0;
|
|
533
644
|
size_t scales_offset_ = 0;
|
|
534
645
|
size_t scales_bytes_ = 0;
|
|
535
|
-
uint32_t version_ = 1;
|
|
536
646
|
uint32_t alignment_ = 32;
|
|
537
|
-
|
|
538
|
-
|
|
647
|
+
|
|
648
|
+
bool is_interleaved_ = false;
|
|
649
|
+
size_t original_N_ = 0;
|
|
650
|
+
|
|
539
651
|
void parse_header();
|
|
540
652
|
void apply_madvise_hints();
|
|
541
|
-
void unpack_int4_if_needed() const;
|
|
542
653
|
};
|
|
543
|
-
|
|
544
|
-
MappedFile mmap_load(const std::string& filename);
|
|
545
654
|
}
|
|
546
655
|
|
|
547
|
-
#endif
|
|
656
|
+
#endif
|