cactus-react-native 1.4.0 → 1.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Cactus.podspec +1 -1
- package/README.md +465 -174
- package/android/CMakeLists.txt +24 -5
- package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libcurl.a +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libmbedcrypto.a +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libmbedtls.a +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libmbedx509.a +0 -0
- package/cpp/HybridCactus.cpp +157 -6
- package/cpp/HybridCactus.hpp +20 -3
- package/cpp/cactus_ffi.h +65 -30
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus.h +0 -1
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +65 -30
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_utils.h +357 -122
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +184 -63
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/gemma_tools.h +549 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +153 -27
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +90 -178
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel_utils.h +276 -151
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus.h +0 -1
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +65 -30
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_utils.h +357 -122
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +184 -63
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/gemma_tools.h +549 -0
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +153 -27
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +90 -178
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel_utils.h +276 -151
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
- package/lib/module/classes/CactusLM.js +43 -58
- package/lib/module/classes/CactusLM.js.map +1 -1
- package/lib/module/classes/CactusSTT.js +64 -38
- package/lib/module/classes/CactusSTT.js.map +1 -1
- package/lib/module/classes/CactusVAD.js +95 -0
- package/lib/module/classes/CactusVAD.js.map +1 -0
- package/lib/module/hooks/useCactusLM.js +23 -15
- package/lib/module/hooks/useCactusLM.js.map +1 -1
- package/lib/module/hooks/useCactusSTT.js +85 -28
- package/lib/module/hooks/useCactusSTT.js.map +1 -1
- package/lib/module/hooks/useCactusVAD.js +171 -0
- package/lib/module/hooks/useCactusVAD.js.map +1 -0
- package/lib/module/index.js +2 -3
- package/lib/module/index.js.map +1 -1
- package/lib/module/modelRegistry.js +52 -0
- package/lib/module/modelRegistry.js.map +1 -0
- package/lib/module/native/Cactus.js +107 -8
- package/lib/module/native/Cactus.js.map +1 -1
- package/lib/module/native/CactusIndex.js.map +1 -1
- package/lib/module/native/index.js +0 -3
- package/lib/module/native/index.js.map +1 -1
- package/lib/module/types/CactusLM.js +2 -0
- package/lib/module/types/CactusSTT.js +2 -0
- package/lib/module/types/CactusVAD.js +4 -0
- package/lib/module/types/{CactusModel.js.map → CactusVAD.js.map} +1 -1
- package/lib/module/types/common.js +2 -0
- package/lib/module/types/{CactusSTTModel.js.map → common.js.map} +1 -1
- package/lib/typescript/src/classes/CactusLM.d.ts +8 -6
- package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
- package/lib/typescript/src/classes/CactusSTT.d.ts +11 -6
- package/lib/typescript/src/classes/CactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/classes/CactusVAD.d.ts +20 -0
- package/lib/typescript/src/classes/CactusVAD.d.ts.map +1 -0
- package/lib/typescript/src/hooks/useCactusLM.d.ts +3 -3
- package/lib/typescript/src/hooks/useCactusLM.d.ts.map +1 -1
- package/lib/typescript/src/hooks/useCactusSTT.d.ts +11 -5
- package/lib/typescript/src/hooks/useCactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/hooks/useCactusVAD.d.ts +15 -0
- package/lib/typescript/src/hooks/useCactusVAD.d.ts.map +1 -0
- package/lib/typescript/src/index.d.ts +7 -6
- package/lib/typescript/src/index.d.ts.map +1 -1
- package/lib/typescript/src/modelRegistry.d.ts +5 -0
- package/lib/typescript/src/modelRegistry.d.ts.map +1 -0
- package/lib/typescript/src/native/Cactus.d.ts +12 -6
- package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
- package/lib/typescript/src/native/CactusIndex.d.ts +2 -2
- package/lib/typescript/src/native/CactusIndex.d.ts.map +1 -1
- package/lib/typescript/src/native/index.d.ts +0 -3
- package/lib/typescript/src/native/index.d.ts.map +1 -1
- package/lib/typescript/src/specs/Cactus.nitro.d.ts +6 -1
- package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusIndex.d.ts +2 -2
- package/lib/typescript/src/types/CactusIndex.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusLM.d.ts +19 -9
- package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusSTT.d.ts +45 -4
- package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusVAD.d.ts +34 -0
- package/lib/typescript/src/types/CactusVAD.d.ts.map +1 -0
- package/lib/typescript/src/types/common.d.ts +23 -0
- package/lib/typescript/src/types/common.d.ts.map +1 -0
- package/nitro.json +0 -11
- package/nitrogen/generated/android/cactus+autolinking.cmake +0 -5
- package/nitrogen/generated/android/cactusOnLoad.cpp +0 -30
- package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.cpp +0 -50
- package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.hpp +9 -147
- package/nitrogen/generated/ios/Cactus-Swift-Cxx-Umbrella.hpp +0 -13
- package/nitrogen/generated/ios/CactusAutolinking.mm +0 -26
- package/nitrogen/generated/ios/CactusAutolinking.swift +0 -30
- package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +5 -0
- package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +6 -1
- package/package.json +3 -3
- package/src/classes/CactusLM.ts +59 -74
- package/src/classes/CactusSTT.ts +92 -49
- package/src/classes/CactusVAD.ts +129 -0
- package/src/hooks/useCactusLM.ts +26 -9
- package/src/hooks/useCactusSTT.ts +105 -44
- package/src/hooks/useCactusVAD.ts +215 -0
- package/src/index.tsx +20 -10
- package/src/modelRegistry.ts +65 -0
- package/src/native/Cactus.ts +130 -14
- package/src/native/CactusIndex.ts +2 -2
- package/src/native/index.ts +0 -3
- package/src/specs/Cactus.nitro.ts +11 -2
- package/src/types/CactusIndex.ts +2 -2
- package/src/types/CactusLM.ts +20 -9
- package/src/types/CactusSTT.ts +50 -4
- package/src/types/CactusVAD.ts +39 -0
- package/src/types/common.ts +23 -0
- package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusCrypto.kt +0 -46
- package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusDeviceInfo.kt +0 -27
- package/android/src/main/jniLibs/arm64-v8a/libcactus_util.a +0 -0
- package/cpp/HybridCactusUtil.cpp +0 -47
- package/cpp/HybridCactusUtil.hpp +0 -27
- package/cpp/cactus_util.h +0 -25
- package/ios/HybridCactusCrypto.swift +0 -37
- package/ios/HybridCactusDeviceInfo.swift +0 -32
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_telemetry.h +0 -656
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_telemetry.h +0 -656
- package/ios/cactus_util.xcframework/Info.plist +0 -39
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/cactus_util.h +0 -25
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/database.h +0 -27
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/ios_utils.h +0 -10
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/logging.h +0 -25
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Info.plist +0 -0
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/cactus_util +0 -0
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/cactus_util.h +0 -25
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/database.h +0 -27
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/ios_utils.h +0 -10
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/logging.h +0 -25
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Info.plist +0 -0
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/_CodeSignature/CodeResources +0 -135
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/cactus_util +0 -0
- package/lib/module/api/Database.js +0 -137
- package/lib/module/api/Database.js.map +0 -1
- package/lib/module/api/RemoteLM.js +0 -201
- package/lib/module/api/RemoteLM.js.map +0 -1
- package/lib/module/config/CactusConfig.js +0 -12
- package/lib/module/config/CactusConfig.js.map +0 -1
- package/lib/module/native/CactusCrypto.js +0 -10
- package/lib/module/native/CactusCrypto.js.map +0 -1
- package/lib/module/native/CactusDeviceInfo.js +0 -13
- package/lib/module/native/CactusDeviceInfo.js.map +0 -1
- package/lib/module/native/CactusUtil.js +0 -36
- package/lib/module/native/CactusUtil.js.map +0 -1
- package/lib/module/specs/CactusCrypto.nitro.js +0 -4
- package/lib/module/specs/CactusCrypto.nitro.js.map +0 -1
- package/lib/module/specs/CactusDeviceInfo.nitro.js +0 -4
- package/lib/module/specs/CactusDeviceInfo.nitro.js.map +0 -1
- package/lib/module/specs/CactusUtil.nitro.js +0 -4
- package/lib/module/specs/CactusUtil.nitro.js.map +0 -1
- package/lib/module/telemetry/Telemetry.js +0 -154
- package/lib/module/telemetry/Telemetry.js.map +0 -1
- package/lib/module/types/CactusModel.js +0 -2
- package/lib/module/types/CactusSTTModel.js +0 -2
- package/lib/typescript/src/api/Database.d.ts +0 -18
- package/lib/typescript/src/api/Database.d.ts.map +0 -1
- package/lib/typescript/src/api/RemoteLM.d.ts +0 -14
- package/lib/typescript/src/api/RemoteLM.d.ts.map +0 -1
- package/lib/typescript/src/config/CactusConfig.d.ts +0 -7
- package/lib/typescript/src/config/CactusConfig.d.ts.map +0 -1
- package/lib/typescript/src/native/CactusCrypto.d.ts +0 -5
- package/lib/typescript/src/native/CactusCrypto.d.ts.map +0 -1
- package/lib/typescript/src/native/CactusDeviceInfo.d.ts +0 -7
- package/lib/typescript/src/native/CactusDeviceInfo.d.ts.map +0 -1
- package/lib/typescript/src/native/CactusUtil.d.ts +0 -6
- package/lib/typescript/src/native/CactusUtil.d.ts.map +0 -1
- package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts +0 -8
- package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts.map +0 -1
- package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts +0 -16
- package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts.map +0 -1
- package/lib/typescript/src/specs/CactusUtil.nitro.d.ts +0 -10
- package/lib/typescript/src/specs/CactusUtil.nitro.d.ts.map +0 -1
- package/lib/typescript/src/telemetry/Telemetry.d.ts +0 -34
- package/lib/typescript/src/telemetry/Telemetry.d.ts.map +0 -1
- package/lib/typescript/src/types/CactusModel.d.ts +0 -13
- package/lib/typescript/src/types/CactusModel.d.ts.map +0 -1
- package/lib/typescript/src/types/CactusSTTModel.d.ts +0 -8
- package/lib/typescript/src/types/CactusSTTModel.d.ts.map +0 -1
- package/nitrogen/generated/android/c++/JDeviceInfo.hpp +0 -74
- package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.cpp +0 -65
- package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.hpp +0 -65
- package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.cpp +0 -85
- package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.hpp +0 -66
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/DeviceInfo.kt +0 -50
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusCryptoSpec.kt +0 -58
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusDeviceInfoSpec.kt +0 -62
- package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.cpp +0 -11
- package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.hpp +0 -77
- package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.cpp +0 -11
- package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.hpp +0 -88
- package/nitrogen/generated/ios/swift/DeviceInfo.swift +0 -98
- package/nitrogen/generated/ios/swift/Func_void_DeviceInfo.swift +0 -47
- package/nitrogen/generated/ios/swift/Func_void_std__optional_std__string_.swift +0 -54
- package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec.swift +0 -57
- package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec_cxx.swift +0 -139
- package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec.swift +0 -58
- package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec_cxx.swift +0 -164
- package/nitrogen/generated/shared/c++/DeviceInfo.hpp +0 -92
- package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.cpp +0 -21
- package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.hpp +0 -63
- package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.cpp +0 -22
- package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.hpp +0 -67
- package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.cpp +0 -23
- package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.hpp +0 -66
- package/src/api/Database.ts +0 -188
- package/src/api/RemoteLM.ts +0 -273
- package/src/config/CactusConfig.ts +0 -11
- package/src/native/CactusCrypto.ts +0 -11
- package/src/native/CactusDeviceInfo.ts +0 -18
- package/src/native/CactusUtil.ts +0 -43
- package/src/specs/CactusCrypto.nitro.ts +0 -6
- package/src/specs/CactusDeviceInfo.nitro.ts +0 -15
- package/src/specs/CactusUtil.nitro.ts +0 -8
- package/src/telemetry/Telemetry.ts +0 -236
- package/src/types/CactusModel.ts +0 -15
- package/src/types/CactusSTTModel.ts +0 -10
|
@@ -4,6 +4,7 @@
|
|
|
4
4
|
#include <vector>
|
|
5
5
|
#include <memory>
|
|
6
6
|
#include <unordered_map>
|
|
7
|
+
#include <unordered_set>
|
|
7
8
|
#include <functional>
|
|
8
9
|
#include <cstring>
|
|
9
10
|
#include <stdexcept>
|
|
@@ -11,6 +12,7 @@
|
|
|
11
12
|
#include <mutex>
|
|
12
13
|
#include <sstream>
|
|
13
14
|
#include <iostream>
|
|
15
|
+
#include <arm_neon.h>
|
|
14
16
|
|
|
15
17
|
namespace cactus {
|
|
16
18
|
|
|
@@ -96,9 +98,10 @@ namespace GraphFile {
|
|
|
96
98
|
}
|
|
97
99
|
|
|
98
100
|
enum class Precision {
|
|
99
|
-
INT8,
|
|
101
|
+
INT8,
|
|
100
102
|
FP16,
|
|
101
|
-
FP32
|
|
103
|
+
FP32,
|
|
104
|
+
INT4
|
|
102
105
|
};
|
|
103
106
|
|
|
104
107
|
enum class ComputeBackend {
|
|
@@ -112,13 +115,17 @@ enum class OpType {
|
|
|
112
115
|
MATMUL, TRANSPOSE, RESHAPE, SLICE, GATHER, EMBEDDING,
|
|
113
116
|
BILINEAR_INTERPOLATION,
|
|
114
117
|
SUM, MEAN, VARIANCE, MIN, MAX,
|
|
115
|
-
RMS_NORM, ROPE, SOFTMAX, ATTENTION, CONV1D_CAUSAL, CONV1D_K3,
|
|
118
|
+
RMS_NORM, ROPE, ROPE_GPTJ, SOFTMAX, ATTENTION, ATTENTION_INT8_HYBRID, CONV1D_CAUSAL, CONV1D_K3, CONV1D_K7S3, CONV1D,
|
|
116
119
|
SCALAR_ADD, SCALAR_SUBTRACT, SCALAR_MULTIPLY, SCALAR_DIVIDE, SCALAR_EXP, SCALAR_SQRT, SCALAR_COS, SCALAR_SIN,
|
|
117
|
-
SILU, GELU, GELU_ERF,
|
|
120
|
+
RELU, SILU, GELU, GELU_ERF, SIGMOID, TANH,
|
|
118
121
|
SAMPLE, CONCAT,
|
|
119
122
|
SCATTER_TOPK,
|
|
120
|
-
TOPK, LAYERNORM,
|
|
123
|
+
TOPK, LAYERNORM, GROUPNORM,
|
|
121
124
|
INDEX,
|
|
125
|
+
PERSISTENT,
|
|
126
|
+
QUANTIZE_ACTIVATIONS,
|
|
127
|
+
LSTM_CELL,
|
|
128
|
+
STFT_MAGNITUDE
|
|
122
129
|
};
|
|
123
130
|
|
|
124
131
|
struct PrecisionTraits {
|
|
@@ -127,22 +134,32 @@ struct PrecisionTraits {
|
|
|
127
134
|
case Precision::INT8: return 1;
|
|
128
135
|
case Precision::FP16: return 2;
|
|
129
136
|
case Precision::FP32: return 4;
|
|
137
|
+
case Precision::INT4: return 1;
|
|
130
138
|
}
|
|
131
139
|
return 1;
|
|
132
140
|
}
|
|
133
|
-
|
|
141
|
+
|
|
142
|
+
static constexpr size_t packed_size_of(Precision prec, size_t count) {
|
|
143
|
+
switch (prec) {
|
|
144
|
+
case Precision::INT4: return (count + 1) / 2;
|
|
145
|
+
default: return count * size_of(prec);
|
|
146
|
+
}
|
|
147
|
+
}
|
|
148
|
+
|
|
134
149
|
static constexpr bool is_integer(Precision prec) {
|
|
135
150
|
switch (prec) {
|
|
136
151
|
case Precision::INT8: return true;
|
|
152
|
+
case Precision::INT4: return true;
|
|
137
153
|
case Precision::FP16: return false;
|
|
138
154
|
case Precision::FP32: return false;
|
|
139
155
|
}
|
|
140
156
|
return true;
|
|
141
157
|
}
|
|
142
|
-
|
|
158
|
+
|
|
143
159
|
static constexpr bool is_floating_point(Precision prec) {
|
|
144
160
|
switch (prec) {
|
|
145
161
|
case Precision::INT8: return false;
|
|
162
|
+
case Precision::INT4: return false;
|
|
146
163
|
case Precision::FP16: return true;
|
|
147
164
|
case Precision::FP32: return true;
|
|
148
165
|
}
|
|
@@ -153,8 +170,6 @@ struct PrecisionTraits {
|
|
|
153
170
|
namespace Quantization {
|
|
154
171
|
void int8_to_fp32(const int8_t* src, float* dst, size_t count, float scale = 1.0f);
|
|
155
172
|
void fp32_to_int8(const float* src, int8_t* dst, size_t count, float scale = 1.0f);
|
|
156
|
-
void dynamic_quantize_fp32_to_int8(const float* src, int8_t* dst, size_t count,
|
|
157
|
-
float* computed_scale);
|
|
158
173
|
void fp16_to_fp32(const __fp16* src, float* dst, size_t count);
|
|
159
174
|
void fp32_to_fp16(const float* src, __fp16* dst, size_t count);
|
|
160
175
|
void int8_to_fp16(const int8_t* src, __fp16* dst, size_t count, float scale = 1.0f);
|
|
@@ -188,10 +203,21 @@ struct BufferDesc {
|
|
|
188
203
|
void* external_data;
|
|
189
204
|
char* pooled_data;
|
|
190
205
|
Precision precision;
|
|
191
|
-
|
|
206
|
+
|
|
207
|
+
size_t group_size = 0;
|
|
208
|
+
size_t num_groups = 0;
|
|
209
|
+
void* scales_data = nullptr;
|
|
210
|
+
std::unique_ptr<char[]> owned_scales;
|
|
211
|
+
|
|
212
|
+
bool is_interleaved = false;
|
|
213
|
+
size_t original_N = 0;
|
|
214
|
+
|
|
215
|
+
void* activation_scales_data = nullptr;
|
|
216
|
+
std::unique_ptr<char[]> owned_activation_scales;
|
|
217
|
+
size_t num_rows_for_activation_scales = 0;
|
|
192
218
|
|
|
193
219
|
BufferDesc();
|
|
194
|
-
BufferDesc(const std::vector<size_t>& s, Precision prec = Precision::INT8
|
|
220
|
+
BufferDesc(const std::vector<size_t>& s, Precision prec = Precision::INT8);
|
|
195
221
|
~BufferDesc();
|
|
196
222
|
|
|
197
223
|
BufferDesc(BufferDesc&& other) noexcept;
|
|
@@ -209,6 +235,44 @@ struct BufferDesc {
|
|
|
209
235
|
template<typename T>
|
|
210
236
|
const T* data_as() const { return static_cast<const T*>(get_data()); }
|
|
211
237
|
|
|
238
|
+
const __fp16* scales_as_fp16() const {
|
|
239
|
+
return reinterpret_cast<const __fp16*>(scales_data);
|
|
240
|
+
}
|
|
241
|
+
|
|
242
|
+
bool is_grouped_int8() const {
|
|
243
|
+
return precision == Precision::INT8 && group_size > 0;
|
|
244
|
+
}
|
|
245
|
+
|
|
246
|
+
void set_grouped_scales(size_t gs, size_t ng, void* scales_ptr) {
|
|
247
|
+
group_size = gs;
|
|
248
|
+
num_groups = ng;
|
|
249
|
+
scales_data = scales_ptr;
|
|
250
|
+
}
|
|
251
|
+
|
|
252
|
+
void set_interleaved(bool interleaved, size_t orig_n) {
|
|
253
|
+
is_interleaved = interleaved;
|
|
254
|
+
original_N = orig_n;
|
|
255
|
+
}
|
|
256
|
+
|
|
257
|
+
bool has_activation_scales() const {
|
|
258
|
+
return activation_scales_data != nullptr && num_rows_for_activation_scales > 0;
|
|
259
|
+
}
|
|
260
|
+
const float* activation_scales_as_float() const {
|
|
261
|
+
return reinterpret_cast<const float*>(activation_scales_data);
|
|
262
|
+
}
|
|
263
|
+
float* activation_scales_as_float() {
|
|
264
|
+
return reinterpret_cast<float*>(activation_scales_data);
|
|
265
|
+
}
|
|
266
|
+
void allocate_activation_scales(size_t num_rows) {
|
|
267
|
+
num_rows_for_activation_scales = num_rows;
|
|
268
|
+
owned_activation_scales = std::make_unique<char[]>(num_rows * sizeof(float));
|
|
269
|
+
activation_scales_data = owned_activation_scales.get();
|
|
270
|
+
}
|
|
271
|
+
void set_activation_scales(void* scales_ptr, size_t num_rows) {
|
|
272
|
+
activation_scales_data = scales_ptr;
|
|
273
|
+
num_rows_for_activation_scales = num_rows;
|
|
274
|
+
}
|
|
275
|
+
|
|
212
276
|
void allocate();
|
|
213
277
|
void allocate_from_pool(BufferPool& pool);
|
|
214
278
|
void release_to_pool(BufferPool& pool);
|
|
@@ -242,11 +306,21 @@ struct OpParams {
|
|
|
242
306
|
|
|
243
307
|
size_t index_value = 0;
|
|
244
308
|
size_t num_classes = 0;
|
|
309
|
+
size_t num_groups = 0;
|
|
245
310
|
size_t dst_height = 0;
|
|
246
311
|
size_t dst_width = 0;
|
|
247
312
|
|
|
248
313
|
std::vector<float> bias_values;
|
|
249
314
|
std::vector<uint32_t> bias_indices;
|
|
315
|
+
|
|
316
|
+
const int8_t* cached_keys_int8 = nullptr;
|
|
317
|
+
const int8_t* cached_values_int8 = nullptr;
|
|
318
|
+
const float* cached_k_scales = nullptr;
|
|
319
|
+
const float* cached_v_scales = nullptr;
|
|
320
|
+
size_t cache_seq_len = 0;
|
|
321
|
+
size_t num_kv_heads = 0;
|
|
322
|
+
size_t head_dim = 0;
|
|
323
|
+
size_t num_fft_bins = 0;
|
|
250
324
|
};
|
|
251
325
|
|
|
252
326
|
struct GraphNode {
|
|
@@ -276,7 +350,10 @@ void compute_sample_node(GraphNode& node, const std::vector<std::unique_ptr<Grap
|
|
|
276
350
|
void compute_scatter_topk_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
277
351
|
void compute_topk_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
278
352
|
void compute_layernorm_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
353
|
+
void compute_groupnorm_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
354
|
+
void compute_persistent_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
279
355
|
void compute_index_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
356
|
+
void compute_lstm_cell_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
|
|
280
357
|
|
|
281
358
|
void shrink_thread_local_buffers();
|
|
282
359
|
|
|
@@ -324,9 +401,10 @@ public:
|
|
|
324
401
|
|
|
325
402
|
size_t input(const std::vector<size_t>& shape, Precision precision = Precision::INT8);
|
|
326
403
|
size_t precision_cast(size_t input, Precision target_precision);
|
|
404
|
+
size_t quantize_activations(size_t input);
|
|
327
405
|
|
|
328
406
|
size_t add(size_t input1, size_t input2);
|
|
329
|
-
size_t add_clipped(size_t input1, size_t input2);
|
|
407
|
+
size_t add_clipped(size_t input1, size_t input2);
|
|
330
408
|
size_t subtract(size_t input1, size_t input2);
|
|
331
409
|
size_t multiply(size_t input1, size_t input2);
|
|
332
410
|
size_t divide(size_t input1, size_t input2);
|
|
@@ -341,9 +419,12 @@ public:
|
|
|
341
419
|
size_t scalar_cos(size_t input);
|
|
342
420
|
size_t scalar_sin(size_t input);
|
|
343
421
|
|
|
422
|
+
size_t relu(size_t input);
|
|
344
423
|
size_t silu(size_t input);
|
|
345
424
|
size_t gelu(size_t input);
|
|
346
425
|
size_t gelu_erf(size_t input);
|
|
426
|
+
size_t sigmoid(size_t input);
|
|
427
|
+
size_t tanh(size_t input);
|
|
347
428
|
|
|
348
429
|
size_t matmul(size_t input1, size_t input2, bool pretransposed_rhs = false, ComputeBackend backend = ComputeBackend::CPU);
|
|
349
430
|
size_t transpose(size_t input, ComputeBackend backend = ComputeBackend::CPU);
|
|
@@ -361,24 +442,42 @@ public:
|
|
|
361
442
|
size_t gather(size_t embeddings, size_t indices);
|
|
362
443
|
size_t mmap_embeddings(const std::string& filename);
|
|
363
444
|
size_t mmap_weights(const std::string& filename);
|
|
364
|
-
|
|
365
|
-
void
|
|
445
|
+
void set_grouped_scales(size_t node_id, size_t group_size, size_t num_groups, void* scales_ptr);
|
|
446
|
+
void set_interleaved(size_t node_id, bool interleaved, size_t original_N);
|
|
447
|
+
|
|
448
|
+
void release_weight_pages(size_t node_id);
|
|
449
|
+
void prefetch_weight_pages(size_t node_id);
|
|
450
|
+
void release_all_weight_pages();
|
|
366
451
|
size_t embedding(const std::string& filename, size_t indices);
|
|
367
452
|
size_t embedding(size_t embedding_tensor, size_t indices);
|
|
368
453
|
size_t bilinear_interpolation(size_t pos_embeds, size_t dst_height, size_t dst_width);
|
|
369
454
|
|
|
370
455
|
size_t layernorm(size_t input, size_t weight, size_t bias, float epsilon = 1e-5f);
|
|
456
|
+
size_t layernorm(size_t input, size_t weight, float epsilon = 1e-5f); // No bias version
|
|
457
|
+
size_t groupnorm(size_t input, size_t weight, size_t bias, size_t num_groups = 32, float epsilon = 1e-5f);
|
|
371
458
|
size_t topk(size_t input, size_t k);
|
|
372
459
|
size_t rms_norm(size_t input, size_t weight, float epsilon = 1e-5f);
|
|
373
460
|
size_t rope(size_t input, float theta, size_t position_offset = 0, ComputeBackend backend = ComputeBackend::CPU);
|
|
461
|
+
size_t rope_gptj(size_t input, float theta, size_t position_offset = 0, size_t rot_dim = 0, ComputeBackend backend = ComputeBackend::CPU);
|
|
374
462
|
size_t softmax(size_t input, int axis = -1);
|
|
375
463
|
size_t attention(size_t query, size_t key, size_t value, float scale, bool is_causal = true, ComputeBackend backend = ComputeBackend::CPU);
|
|
376
464
|
size_t attention(size_t query, size_t key, size_t value, float scale, size_t position_offset, ComputeBackend backend = ComputeBackend::CPU);
|
|
377
465
|
size_t attention(size_t query, size_t key, size_t value, float scale, size_t position_offset, size_t window_size, ComputeBackend backend = ComputeBackend::CPU);
|
|
378
466
|
|
|
467
|
+
size_t attention_int8_hybrid(size_t query, size_t key_new, size_t value_new, float scale, size_t position_offset,
|
|
468
|
+
const int8_t* cached_keys, const int8_t* cached_values,
|
|
469
|
+
const float* k_scales, const float* v_scales,
|
|
470
|
+
size_t cache_len, size_t num_kv_heads, size_t head_dim, size_t window_size = 0);
|
|
471
|
+
|
|
379
472
|
size_t conv1d_causal(size_t input, size_t weight, size_t kernel_size, size_t dilation = 1);
|
|
380
473
|
size_t conv1d_k3(size_t input, size_t weight, size_t stride);
|
|
381
|
-
|
|
474
|
+
size_t conv1d_k7s3(size_t input, size_t weight, size_t bias);
|
|
475
|
+
size_t conv1d(size_t input, size_t weight, size_t stride);
|
|
476
|
+
size_t conv1d(size_t input, size_t weight, size_t bias, size_t stride);
|
|
477
|
+
|
|
478
|
+
size_t lstm_cell(size_t input, size_t h_prev, size_t c_prev, size_t weight_ih, size_t weight_hh, size_t bias_ih, size_t bias_hh);
|
|
479
|
+
size_t stft_magnitude(size_t input, size_t weight, size_t stride, size_t num_fft_bins);
|
|
480
|
+
|
|
382
481
|
size_t sample(size_t logits, float temperature = 0.6f, float top_p = 0.95f, size_t top_k = 20,
|
|
383
482
|
const std::unordered_map<uint32_t, float>& logit_bias = {});
|
|
384
483
|
|
|
@@ -392,6 +491,8 @@ public:
|
|
|
392
491
|
void execute(const std::string& profile_file = "");
|
|
393
492
|
void hard_reset();
|
|
394
493
|
void soft_reset();
|
|
494
|
+
void soft_reset_keep_pool();
|
|
495
|
+
void set_prefill_mode(bool enabled) { prefill_mode_ = enabled; }
|
|
395
496
|
|
|
396
497
|
void register_debug_node(uint32_t layer_idx, const std::string& name, size_t node_id);
|
|
397
498
|
void capture_debug_node(uint32_t layer_idx, const std::string& name, size_t node_id);
|
|
@@ -403,6 +504,10 @@ public:
|
|
|
403
504
|
void allocate_buffers();
|
|
404
505
|
size_t get_node_count() const;
|
|
405
506
|
|
|
507
|
+
size_t persistent(size_t source_node);
|
|
508
|
+
bool is_populated(size_t persistent_node_id) const;
|
|
509
|
+
void invalidate_persistent(size_t persistent_node_id);
|
|
510
|
+
|
|
406
511
|
std::vector<std::unique_ptr<GraphNode>> nodes_;
|
|
407
512
|
std::unordered_map<size_t, size_t> node_index_map_;
|
|
408
513
|
|
|
@@ -410,8 +515,13 @@ private:
|
|
|
410
515
|
size_t next_node_id_;
|
|
411
516
|
std::vector<std::unique_ptr<GraphFile::MappedFile>> mapped_files_;
|
|
412
517
|
std::unordered_map<std::string, size_t> weight_cache_;
|
|
518
|
+
std::unordered_map<size_t, size_t> node_to_mapped_file_;
|
|
413
519
|
std::vector<DebugNodeEntry> debug_nodes_;
|
|
414
520
|
BufferPool buffer_pool_;
|
|
521
|
+
bool prefill_mode_ = false;
|
|
522
|
+
|
|
523
|
+
std::unordered_set<size_t> persistent_node_ids_;
|
|
524
|
+
std::unordered_set<size_t> populated_node_ids_;
|
|
415
525
|
};
|
|
416
526
|
|
|
417
527
|
|
|
@@ -424,31 +534,37 @@ namespace GraphFile {
|
|
|
424
534
|
};
|
|
425
535
|
|
|
426
536
|
void save_node(CactusGraph& graph, size_t node_id, const std::string& filename);
|
|
427
|
-
LoadedNode load_into_graph(CactusGraph& graph, const std::string& filename);
|
|
428
537
|
|
|
429
538
|
class MappedFile {
|
|
430
539
|
public:
|
|
431
540
|
MappedFile(const std::string& filename);
|
|
432
541
|
~MappedFile();
|
|
433
|
-
|
|
542
|
+
|
|
434
543
|
MappedFile(const MappedFile&) = delete;
|
|
435
544
|
MappedFile& operator=(const MappedFile&) = delete;
|
|
436
545
|
MappedFile(MappedFile&& other) noexcept;
|
|
437
546
|
MappedFile& operator=(MappedFile&& other) noexcept;
|
|
438
|
-
|
|
547
|
+
|
|
439
548
|
const std::vector<size_t>& shape() const;
|
|
440
549
|
Precision precision() const;
|
|
441
550
|
size_t byte_size() const;
|
|
442
|
-
|
|
443
|
-
|
|
551
|
+
|
|
552
|
+
size_t group_size() const { return group_size_; }
|
|
553
|
+
size_t num_groups() const { return num_groups_; }
|
|
554
|
+
const void* scales_data() const;
|
|
555
|
+
|
|
556
|
+
bool is_interleaved() const { return is_interleaved_; }
|
|
557
|
+
size_t original_N() const { return original_N_; }
|
|
558
|
+
|
|
444
559
|
void* data();
|
|
445
560
|
const void* data() const;
|
|
446
|
-
|
|
561
|
+
|
|
447
562
|
template<typename T>
|
|
448
563
|
const T* typed_data() const;
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
564
|
+
|
|
565
|
+
void release_pages();
|
|
566
|
+
void prefetch_pages();
|
|
567
|
+
|
|
452
568
|
private:
|
|
453
569
|
int fd_;
|
|
454
570
|
void* mapped_data_;
|
|
@@ -456,11 +572,21 @@ namespace GraphFile {
|
|
|
456
572
|
std::vector<size_t> shape_;
|
|
457
573
|
Precision precision_;
|
|
458
574
|
size_t byte_size_;
|
|
459
|
-
|
|
575
|
+
size_t group_size_ = 0;
|
|
576
|
+
size_t num_groups_ = 0;
|
|
577
|
+
size_t scales_offset_ = 0;
|
|
578
|
+
size_t scales_bytes_ = 0;
|
|
579
|
+
uint32_t alignment_ = 32;
|
|
580
|
+
|
|
581
|
+
bool is_interleaved_ = false;
|
|
582
|
+
size_t original_N_ = 0;
|
|
583
|
+
|
|
584
|
+
std::unique_ptr<int8_t[]> unpacked_data_;
|
|
585
|
+
|
|
460
586
|
void parse_header();
|
|
587
|
+
void apply_madvise_hints();
|
|
588
|
+
void unpack_int4_data();
|
|
461
589
|
};
|
|
462
|
-
|
|
463
|
-
MappedFile mmap_load(const std::string& filename);
|
|
464
590
|
}
|
|
465
591
|
|
|
466
592
|
#endif
|