cactus-react-native 1.4.0 → 1.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Cactus.podspec +1 -1
- package/README.md +465 -174
- package/android/CMakeLists.txt +24 -5
- package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libcurl.a +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libmbedcrypto.a +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libmbedtls.a +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libmbedx509.a +0 -0
- package/cpp/HybridCactus.cpp +157 -6
- package/cpp/HybridCactus.hpp +20 -3
- package/cpp/cactus_ffi.h +65 -30
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus.h +0 -1
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +65 -30
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_utils.h +357 -122
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +184 -63
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/gemma_tools.h +549 -0
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +153 -27
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +90 -178
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel_utils.h +276 -151
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus.h +0 -1
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +65 -30
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_utils.h +357 -122
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +184 -63
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/gemma_tools.h +549 -0
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +153 -27
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +90 -178
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel_utils.h +276 -151
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
- package/lib/module/classes/CactusLM.js +43 -58
- package/lib/module/classes/CactusLM.js.map +1 -1
- package/lib/module/classes/CactusSTT.js +64 -38
- package/lib/module/classes/CactusSTT.js.map +1 -1
- package/lib/module/classes/CactusVAD.js +95 -0
- package/lib/module/classes/CactusVAD.js.map +1 -0
- package/lib/module/hooks/useCactusLM.js +23 -15
- package/lib/module/hooks/useCactusLM.js.map +1 -1
- package/lib/module/hooks/useCactusSTT.js +85 -28
- package/lib/module/hooks/useCactusSTT.js.map +1 -1
- package/lib/module/hooks/useCactusVAD.js +171 -0
- package/lib/module/hooks/useCactusVAD.js.map +1 -0
- package/lib/module/index.js +2 -3
- package/lib/module/index.js.map +1 -1
- package/lib/module/modelRegistry.js +52 -0
- package/lib/module/modelRegistry.js.map +1 -0
- package/lib/module/native/Cactus.js +107 -8
- package/lib/module/native/Cactus.js.map +1 -1
- package/lib/module/native/CactusIndex.js.map +1 -1
- package/lib/module/native/index.js +0 -3
- package/lib/module/native/index.js.map +1 -1
- package/lib/module/types/CactusLM.js +2 -0
- package/lib/module/types/CactusSTT.js +2 -0
- package/lib/module/types/CactusVAD.js +4 -0
- package/lib/module/types/{CactusModel.js.map → CactusVAD.js.map} +1 -1
- package/lib/module/types/common.js +2 -0
- package/lib/module/types/{CactusSTTModel.js.map → common.js.map} +1 -1
- package/lib/typescript/src/classes/CactusLM.d.ts +8 -6
- package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
- package/lib/typescript/src/classes/CactusSTT.d.ts +11 -6
- package/lib/typescript/src/classes/CactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/classes/CactusVAD.d.ts +20 -0
- package/lib/typescript/src/classes/CactusVAD.d.ts.map +1 -0
- package/lib/typescript/src/hooks/useCactusLM.d.ts +3 -3
- package/lib/typescript/src/hooks/useCactusLM.d.ts.map +1 -1
- package/lib/typescript/src/hooks/useCactusSTT.d.ts +11 -5
- package/lib/typescript/src/hooks/useCactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/hooks/useCactusVAD.d.ts +15 -0
- package/lib/typescript/src/hooks/useCactusVAD.d.ts.map +1 -0
- package/lib/typescript/src/index.d.ts +7 -6
- package/lib/typescript/src/index.d.ts.map +1 -1
- package/lib/typescript/src/modelRegistry.d.ts +5 -0
- package/lib/typescript/src/modelRegistry.d.ts.map +1 -0
- package/lib/typescript/src/native/Cactus.d.ts +12 -6
- package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
- package/lib/typescript/src/native/CactusIndex.d.ts +2 -2
- package/lib/typescript/src/native/CactusIndex.d.ts.map +1 -1
- package/lib/typescript/src/native/index.d.ts +0 -3
- package/lib/typescript/src/native/index.d.ts.map +1 -1
- package/lib/typescript/src/specs/Cactus.nitro.d.ts +6 -1
- package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusIndex.d.ts +2 -2
- package/lib/typescript/src/types/CactusIndex.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusLM.d.ts +19 -9
- package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusSTT.d.ts +45 -4
- package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -1
- package/lib/typescript/src/types/CactusVAD.d.ts +34 -0
- package/lib/typescript/src/types/CactusVAD.d.ts.map +1 -0
- package/lib/typescript/src/types/common.d.ts +23 -0
- package/lib/typescript/src/types/common.d.ts.map +1 -0
- package/nitro.json +0 -11
- package/nitrogen/generated/android/cactus+autolinking.cmake +0 -5
- package/nitrogen/generated/android/cactusOnLoad.cpp +0 -30
- package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.cpp +0 -50
- package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.hpp +9 -147
- package/nitrogen/generated/ios/Cactus-Swift-Cxx-Umbrella.hpp +0 -13
- package/nitrogen/generated/ios/CactusAutolinking.mm +0 -26
- package/nitrogen/generated/ios/CactusAutolinking.swift +0 -30
- package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +5 -0
- package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +6 -1
- package/package.json +3 -3
- package/src/classes/CactusLM.ts +59 -74
- package/src/classes/CactusSTT.ts +92 -49
- package/src/classes/CactusVAD.ts +129 -0
- package/src/hooks/useCactusLM.ts +26 -9
- package/src/hooks/useCactusSTT.ts +105 -44
- package/src/hooks/useCactusVAD.ts +215 -0
- package/src/index.tsx +20 -10
- package/src/modelRegistry.ts +65 -0
- package/src/native/Cactus.ts +130 -14
- package/src/native/CactusIndex.ts +2 -2
- package/src/native/index.ts +0 -3
- package/src/specs/Cactus.nitro.ts +11 -2
- package/src/types/CactusIndex.ts +2 -2
- package/src/types/CactusLM.ts +20 -9
- package/src/types/CactusSTT.ts +50 -4
- package/src/types/CactusVAD.ts +39 -0
- package/src/types/common.ts +23 -0
- package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusCrypto.kt +0 -46
- package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusDeviceInfo.kt +0 -27
- package/android/src/main/jniLibs/arm64-v8a/libcactus_util.a +0 -0
- package/cpp/HybridCactusUtil.cpp +0 -47
- package/cpp/HybridCactusUtil.hpp +0 -27
- package/cpp/cactus_util.h +0 -25
- package/ios/HybridCactusCrypto.swift +0 -37
- package/ios/HybridCactusDeviceInfo.swift +0 -32
- package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_telemetry.h +0 -656
- package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_telemetry.h +0 -656
- package/ios/cactus_util.xcframework/Info.plist +0 -39
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/cactus_util.h +0 -25
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/database.h +0 -27
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/ios_utils.h +0 -10
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/logging.h +0 -25
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Info.plist +0 -0
- package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/cactus_util +0 -0
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/cactus_util.h +0 -25
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/database.h +0 -27
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/ios_utils.h +0 -10
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/logging.h +0 -25
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Info.plist +0 -0
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/_CodeSignature/CodeResources +0 -135
- package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/cactus_util +0 -0
- package/lib/module/api/Database.js +0 -137
- package/lib/module/api/Database.js.map +0 -1
- package/lib/module/api/RemoteLM.js +0 -201
- package/lib/module/api/RemoteLM.js.map +0 -1
- package/lib/module/config/CactusConfig.js +0 -12
- package/lib/module/config/CactusConfig.js.map +0 -1
- package/lib/module/native/CactusCrypto.js +0 -10
- package/lib/module/native/CactusCrypto.js.map +0 -1
- package/lib/module/native/CactusDeviceInfo.js +0 -13
- package/lib/module/native/CactusDeviceInfo.js.map +0 -1
- package/lib/module/native/CactusUtil.js +0 -36
- package/lib/module/native/CactusUtil.js.map +0 -1
- package/lib/module/specs/CactusCrypto.nitro.js +0 -4
- package/lib/module/specs/CactusCrypto.nitro.js.map +0 -1
- package/lib/module/specs/CactusDeviceInfo.nitro.js +0 -4
- package/lib/module/specs/CactusDeviceInfo.nitro.js.map +0 -1
- package/lib/module/specs/CactusUtil.nitro.js +0 -4
- package/lib/module/specs/CactusUtil.nitro.js.map +0 -1
- package/lib/module/telemetry/Telemetry.js +0 -154
- package/lib/module/telemetry/Telemetry.js.map +0 -1
- package/lib/module/types/CactusModel.js +0 -2
- package/lib/module/types/CactusSTTModel.js +0 -2
- package/lib/typescript/src/api/Database.d.ts +0 -18
- package/lib/typescript/src/api/Database.d.ts.map +0 -1
- package/lib/typescript/src/api/RemoteLM.d.ts +0 -14
- package/lib/typescript/src/api/RemoteLM.d.ts.map +0 -1
- package/lib/typescript/src/config/CactusConfig.d.ts +0 -7
- package/lib/typescript/src/config/CactusConfig.d.ts.map +0 -1
- package/lib/typescript/src/native/CactusCrypto.d.ts +0 -5
- package/lib/typescript/src/native/CactusCrypto.d.ts.map +0 -1
- package/lib/typescript/src/native/CactusDeviceInfo.d.ts +0 -7
- package/lib/typescript/src/native/CactusDeviceInfo.d.ts.map +0 -1
- package/lib/typescript/src/native/CactusUtil.d.ts +0 -6
- package/lib/typescript/src/native/CactusUtil.d.ts.map +0 -1
- package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts +0 -8
- package/lib/typescript/src/specs/CactusCrypto.nitro.d.ts.map +0 -1
- package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts +0 -16
- package/lib/typescript/src/specs/CactusDeviceInfo.nitro.d.ts.map +0 -1
- package/lib/typescript/src/specs/CactusUtil.nitro.d.ts +0 -10
- package/lib/typescript/src/specs/CactusUtil.nitro.d.ts.map +0 -1
- package/lib/typescript/src/telemetry/Telemetry.d.ts +0 -34
- package/lib/typescript/src/telemetry/Telemetry.d.ts.map +0 -1
- package/lib/typescript/src/types/CactusModel.d.ts +0 -13
- package/lib/typescript/src/types/CactusModel.d.ts.map +0 -1
- package/lib/typescript/src/types/CactusSTTModel.d.ts +0 -8
- package/lib/typescript/src/types/CactusSTTModel.d.ts.map +0 -1
- package/nitrogen/generated/android/c++/JDeviceInfo.hpp +0 -74
- package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.cpp +0 -65
- package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.hpp +0 -65
- package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.cpp +0 -85
- package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.hpp +0 -66
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/DeviceInfo.kt +0 -50
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusCryptoSpec.kt +0 -58
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusDeviceInfoSpec.kt +0 -62
- package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.cpp +0 -11
- package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.hpp +0 -77
- package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.cpp +0 -11
- package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.hpp +0 -88
- package/nitrogen/generated/ios/swift/DeviceInfo.swift +0 -98
- package/nitrogen/generated/ios/swift/Func_void_DeviceInfo.swift +0 -47
- package/nitrogen/generated/ios/swift/Func_void_std__optional_std__string_.swift +0 -54
- package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec.swift +0 -57
- package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec_cxx.swift +0 -139
- package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec.swift +0 -58
- package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec_cxx.swift +0 -164
- package/nitrogen/generated/shared/c++/DeviceInfo.hpp +0 -92
- package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.cpp +0 -21
- package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.hpp +0 -63
- package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.cpp +0 -22
- package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.hpp +0 -67
- package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.cpp +0 -23
- package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.hpp +0 -66
- package/src/api/Database.ts +0 -188
- package/src/api/RemoteLM.ts +0 -273
- package/src/config/CactusConfig.ts +0 -11
- package/src/native/CactusCrypto.ts +0 -11
- package/src/native/CactusDeviceInfo.ts +0 -18
- package/src/native/CactusUtil.ts +0 -43
- package/src/specs/CactusCrypto.nitro.ts +0 -6
- package/src/specs/CactusDeviceInfo.nitro.ts +0 -15
- package/src/specs/CactusUtil.nitro.ts +0 -8
- package/src/telemetry/Telemetry.ts +0 -236
- package/src/types/CactusModel.ts +0 -15
- package/src/types/CactusSTTModel.ts +0 -10
|
@@ -15,12 +15,7 @@ enum class ScalarOpType {
|
|
|
15
15
|
SIN
|
|
16
16
|
};
|
|
17
17
|
|
|
18
|
-
|
|
19
|
-
void cactus_add_int8(const int8_t* a, const int8_t* b, int8_t* output, size_t num_elements);
|
|
20
|
-
void cactus_subtract_int8(const int8_t* a, const int8_t* b, int8_t* output, size_t num_elements);
|
|
21
|
-
void cactus_multiply_int8(const int8_t* a, const int8_t* b, int8_t* output, size_t num_elements);
|
|
22
|
-
void cactus_divide_int8(const int8_t* a, const int8_t* b, int8_t* output, size_t num_elements);
|
|
23
|
-
|
|
18
|
+
constexpr size_t KV_QUANT_GROUP_SIZE = 32;
|
|
24
19
|
|
|
25
20
|
void cactus_add_f16(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
|
|
26
21
|
void cactus_add_f16_clipped(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
|
|
@@ -28,27 +23,6 @@ void cactus_subtract_f16(const __fp16* a, const __fp16* b, __fp16* output, size_
|
|
|
28
23
|
void cactus_multiply_f16(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
|
|
29
24
|
void cactus_divide_f16(const __fp16* a, const __fp16* b, __fp16* output, size_t num_elements);
|
|
30
25
|
|
|
31
|
-
|
|
32
|
-
void cactus_add_f32(const float* a, const float* b, float* output, size_t num_elements);
|
|
33
|
-
void cactus_subtract_f32(const float* a, const float* b, float* output, size_t num_elements);
|
|
34
|
-
void cactus_multiply_f32(const float* a, const float* b, float* output, size_t num_elements);
|
|
35
|
-
void cactus_divide_f32(const float* a, const float* b, float* output, size_t num_elements);
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
void cactus_add_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output,
|
|
39
|
-
const size_t* a_strides, const size_t* b_strides,
|
|
40
|
-
const size_t* output_shape, size_t ndim);
|
|
41
|
-
void cactus_subtract_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output,
|
|
42
|
-
const size_t* a_strides, const size_t* b_strides,
|
|
43
|
-
const size_t* output_shape, size_t ndim);
|
|
44
|
-
void cactus_multiply_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output,
|
|
45
|
-
const size_t* a_strides, const size_t* b_strides,
|
|
46
|
-
const size_t* output_shape, size_t ndim);
|
|
47
|
-
void cactus_divide_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output,
|
|
48
|
-
const size_t* a_strides, const size_t* b_strides,
|
|
49
|
-
const size_t* output_shape, size_t ndim);
|
|
50
|
-
|
|
51
|
-
|
|
52
26
|
void cactus_add_broadcast_f16(const __fp16* a, const __fp16* b, __fp16* output,
|
|
53
27
|
const size_t* a_strides, const size_t* b_strides,
|
|
54
28
|
const size_t* output_shape, size_t ndim);
|
|
@@ -62,159 +36,85 @@ void cactus_divide_broadcast_f16(const __fp16* a, const __fp16* b, __fp16* outpu
|
|
|
62
36
|
const size_t* a_strides, const size_t* b_strides,
|
|
63
37
|
const size_t* output_shape, size_t ndim);
|
|
64
38
|
|
|
65
|
-
|
|
66
|
-
void cactus_add_broadcast_f32(const float* a, const float* b, float* output,
|
|
67
|
-
const size_t* a_strides, const size_t* b_strides,
|
|
68
|
-
const size_t* output_shape, size_t ndim);
|
|
69
|
-
void cactus_subtract_broadcast_f32(const float* a, const float* b, float* output,
|
|
70
|
-
const size_t* a_strides, const size_t* b_strides,
|
|
71
|
-
const size_t* output_shape, size_t ndim);
|
|
72
|
-
void cactus_multiply_broadcast_f32(const float* a, const float* b, float* output,
|
|
73
|
-
const size_t* a_strides, const size_t* b_strides,
|
|
74
|
-
const size_t* output_shape, size_t ndim);
|
|
75
|
-
void cactus_divide_broadcast_f32(const float* a, const float* b, float* output,
|
|
76
|
-
const size_t* a_strides, const size_t* b_strides,
|
|
77
|
-
const size_t* output_shape, size_t ndim);
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
void cactus_scalar_op_int8(const int8_t* input, int8_t* output, size_t num_elements, float scalar_value, ScalarOpType op_type);
|
|
81
39
|
void cactus_scalar_op_f16(const __fp16* input, __fp16* output, size_t num_elements, float scalar_value, ScalarOpType op_type);
|
|
82
|
-
void cactus_scalar_op_f32(const float* input, float* output, size_t num_elements, float scalar_value, ScalarOpType op_type);
|
|
83
40
|
|
|
41
|
+
void cactus_gemv_int8(const int8_t* A, float A_scale,
|
|
42
|
+
const int8_t* B, const __fp16* B_scales,
|
|
43
|
+
__fp16* C, size_t K, size_t N, size_t group_size);
|
|
84
44
|
|
|
85
|
-
void
|
|
86
|
-
|
|
87
|
-
|
|
45
|
+
void cactus_gemm_int8(const int8_t* A, const float* A_scales,
|
|
46
|
+
const int8_t* B, const __fp16* B_scales,
|
|
47
|
+
__fp16* C, size_t M, size_t K, size_t N, size_t group_size);
|
|
88
48
|
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
#define cactus_matmul_int8_to_int32 cactus_matmul_int8_to_int32_i8mm
|
|
93
|
-
#else
|
|
94
|
-
void cactus_matmul_int8_to_int32(const int8_t* a, const int8_t* b_transposed, int32_t* c,
|
|
95
|
-
size_t M, size_t K, size_t N);
|
|
96
|
-
#endif
|
|
49
|
+
void cactus_matmul_int8(const int8_t* A, const float* A_scales,
|
|
50
|
+
const int8_t* B, const __fp16* B_scales,
|
|
51
|
+
__fp16* C, size_t M, size_t K, size_t N, size_t group_size);
|
|
97
52
|
|
|
98
53
|
void cactus_matmul_f16(const __fp16* a, const __fp16* b_transposed, __fp16* c,
|
|
99
54
|
size_t M, size_t K, size_t N);
|
|
100
55
|
|
|
101
|
-
void cactus_matmul_f32(const float* a, const float* b_transposed, float* c,
|
|
102
|
-
size_t M, size_t K, size_t N);
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
void cactus_transpose_2d_int8(const int8_t* source, int8_t* destination,
|
|
106
|
-
size_t num_rows, size_t num_cols, size_t start_row, size_t end_row);
|
|
107
56
|
void cactus_transpose_2d_f16(const __fp16* source, __fp16* destination,
|
|
108
57
|
size_t num_rows, size_t num_cols, size_t start_row, size_t end_row);
|
|
109
|
-
void cactus_transpose_2d_f32(const float* source, float* destination,
|
|
110
|
-
size_t num_rows, size_t num_cols, size_t start_row, size_t end_row);
|
|
111
|
-
|
|
112
|
-
void cactus_transpose_int8(const int8_t* source, int8_t* destination, const size_t* shape,
|
|
113
|
-
const size_t* permutation, size_t ndim, size_t start_idx, size_t end_idx);
|
|
114
58
|
void cactus_transpose_f16(const __fp16* source, __fp16* destination, const size_t* shape,
|
|
115
59
|
const size_t* permutation, size_t ndim, size_t start_idx, size_t end_idx);
|
|
116
|
-
void cactus_transpose_f32(const float* source, float* destination, const size_t* shape,
|
|
117
|
-
const size_t* permutation, size_t ndim, size_t start_idx, size_t end_idx);
|
|
118
60
|
|
|
119
|
-
int64_t cactus_sum_all_int8(const int8_t* data, size_t num_elements);
|
|
120
|
-
void cactus_sum_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
121
61
|
double cactus_sum_all_f16(const __fp16* data, size_t num_elements);
|
|
122
|
-
|
|
123
|
-
void cactus_sum_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
62
|
+
void cactus_sum_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
124
63
|
|
|
125
|
-
double cactus_mean_all_int8(const int8_t* data, size_t num_elements);
|
|
126
|
-
void cactus_mean_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
127
64
|
double cactus_mean_all_f16(const __fp16* data, size_t num_elements);
|
|
128
65
|
void cactus_mean_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
129
|
-
double cactus_mean_all_f32(const float* data, size_t num_elements);
|
|
130
|
-
void cactus_mean_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
131
66
|
|
|
132
|
-
double
|
|
133
|
-
void
|
|
134
|
-
double cactus_variance_all_f32(const float* data, size_t num_elements);
|
|
135
|
-
void cactus_variance_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
67
|
+
double cactus_variance_all_f16(const __fp16* data, size_t num_elements);
|
|
68
|
+
void cactus_variance_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
136
69
|
|
|
137
|
-
|
|
138
|
-
void
|
|
139
|
-
float cactus_min_all_f32(const float* data, size_t num_elements);
|
|
140
|
-
void cactus_min_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
70
|
+
__fp16 cactus_min_all_f16(const __fp16* data, size_t num_elements);
|
|
71
|
+
void cactus_min_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
141
72
|
|
|
142
|
-
|
|
143
|
-
void
|
|
144
|
-
float cactus_max_all_f32(const float* data, size_t num_elements);
|
|
145
|
-
void cactus_max_axis_f32(const float* input, float* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
73
|
+
__fp16 cactus_max_all_f16(const __fp16* data, size_t num_elements);
|
|
74
|
+
void cactus_max_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, size_t axis_size, size_t inner_size);
|
|
146
75
|
|
|
147
76
|
void cactus_rms_norm_f16(const __fp16* input, const __fp16* weight, __fp16* output,
|
|
148
77
|
size_t batch_size, size_t dims, float eps);
|
|
149
|
-
|
|
150
|
-
void cactus_rms_norm_f32(const float* input, const float* weight, float* output,
|
|
151
|
-
size_t batch_size, size_t dims, float eps);
|
|
152
|
-
|
|
153
|
-
void cactus_rms_norm_i8_f32(const int8_t* input, const float* weight, float* output,
|
|
154
|
-
size_t batch_size, size_t dims, float eps, float input_scale);
|
|
155
78
|
|
|
156
79
|
void cactus_rope_f16(const __fp16* input, __fp16* output, size_t batch_size, size_t seq_len,
|
|
157
80
|
size_t num_heads, size_t head_dim, size_t start_pos, float theta);
|
|
158
81
|
|
|
159
|
-
void
|
|
160
|
-
|
|
82
|
+
void cactus_gpt_j_rope_f16(const __fp16* input, __fp16* output, size_t batch_size, size_t seq_len,
|
|
83
|
+
size_t num_heads, size_t head_dim, size_t rot_dim, size_t start_pos, float theta);
|
|
161
84
|
|
|
162
|
-
void
|
|
163
|
-
size_t num_heads, size_t head_dim, size_t start_pos, float theta,
|
|
164
|
-
float input_scale, float output_scale);
|
|
165
|
-
|
|
166
|
-
void cactus_softmax_f16(const __fp16* input, __fp16* output, size_t batch_size,
|
|
85
|
+
void cactus_softmax_f16(const __fp16* input, __fp16* output, size_t batch_size,
|
|
167
86
|
size_t seq_len, size_t vocab_size);
|
|
168
87
|
|
|
169
|
-
void
|
|
170
|
-
size_t seq_len, size_t vocab_size);
|
|
88
|
+
void cactus_relu_f16(const __fp16* input, __fp16* output, size_t num_elements);
|
|
171
89
|
|
|
172
|
-
void cactus_silu_f32(const float* input, float* output, size_t num_elements);
|
|
173
90
|
void cactus_silu_f16(const __fp16* input, __fp16* output, size_t num_elements);
|
|
174
|
-
void cactus_silu_int8(const int8_t* input, int8_t* output, size_t num_elements,
|
|
175
|
-
float input_scale, float output_scale);
|
|
176
91
|
|
|
177
|
-
void cactus_gelu_f32(const float* input, float* output, size_t num_elements);
|
|
178
92
|
void cactus_gelu_f16(const __fp16* input, __fp16* output, size_t num_elements);
|
|
179
|
-
void cactus_gelu_int8(const int8_t* input, int8_t* output, size_t num_elements,
|
|
180
|
-
float input_scale, float output_scale);
|
|
181
93
|
|
|
182
|
-
void cactus_gelu_f32_erf(const float* input, float* output, size_t num_elements);
|
|
183
94
|
void cactus_gelu_f16_erf(const __fp16* input, __fp16* output, size_t num_elements);
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
float scale_in,
|
|
189
|
-
float scale_out);
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
void cactus_attention_int8(const int8_t* queries, const int8_t* keys, const int8_t* values, int8_t* output,
|
|
193
|
-
size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
|
|
194
|
-
size_t head_dim, float scale, const int8_t* mask,
|
|
195
|
-
float q_scale, float k_scale, float v_scale, float output_scale, size_t position_offset = 0, size_t window_size = 0,
|
|
196
|
-
bool is_causal = true);
|
|
95
|
+
|
|
96
|
+
void cactus_sigmoid_f16(const __fp16* input, __fp16* output, size_t num_elements);
|
|
97
|
+
|
|
98
|
+
void cactus_tanh_f16(const __fp16* input, __fp16* output, size_t num_elements);
|
|
197
99
|
|
|
198
100
|
void cactus_attention_f16(const __fp16* queries, const __fp16* keys, const __fp16* values, __fp16* output,
|
|
199
101
|
size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
|
|
200
102
|
size_t head_dim, float scale, const __fp16* mask, size_t position_offset = 0, size_t window_size = 0,
|
|
201
103
|
bool is_causal = true);
|
|
202
104
|
|
|
203
|
-
void
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
const
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
size_t
|
|
214
|
-
size_t
|
|
215
|
-
size_t
|
|
216
|
-
size_t K,
|
|
217
|
-
size_t dilation);
|
|
105
|
+
void cactus_attention_hybrid_int8_fp16(
|
|
106
|
+
const __fp16* queries,
|
|
107
|
+
const int8_t* keys_cached,
|
|
108
|
+
const int8_t* values_cached,
|
|
109
|
+
const float* k_scales,
|
|
110
|
+
const float* v_scales,
|
|
111
|
+
const __fp16* keys_new,
|
|
112
|
+
const __fp16* values_new,
|
|
113
|
+
__fp16* output,
|
|
114
|
+
size_t batch_size, size_t seq_len, size_t cache_len, size_t new_len,
|
|
115
|
+
size_t num_q_heads, size_t num_kv_heads, size_t head_dim,
|
|
116
|
+
float scale, size_t position_offset = 0, bool is_causal = true, size_t window_size = 0,
|
|
117
|
+
size_t group_size = KV_QUANT_GROUP_SIZE);
|
|
218
118
|
|
|
219
119
|
void cactus_conv1d_causal_depthwise_f16(
|
|
220
120
|
const __fp16* input,
|
|
@@ -226,23 +126,10 @@ void cactus_conv1d_causal_depthwise_f16(
|
|
|
226
126
|
size_t K,
|
|
227
127
|
size_t dilation);
|
|
228
128
|
|
|
229
|
-
void
|
|
230
|
-
const
|
|
231
|
-
const
|
|
232
|
-
|
|
233
|
-
size_t N,
|
|
234
|
-
size_t L,
|
|
235
|
-
size_t C,
|
|
236
|
-
size_t K,
|
|
237
|
-
size_t dilation,
|
|
238
|
-
float input_scale,
|
|
239
|
-
float weight_scale,
|
|
240
|
-
float output_scale);
|
|
241
|
-
|
|
242
|
-
void cactus_conv1d_f32_k3(
|
|
243
|
-
const float* input,
|
|
244
|
-
const float* weight,
|
|
245
|
-
float* output,
|
|
129
|
+
void cactus_conv1d_f16_k3(
|
|
130
|
+
const __fp16* input,
|
|
131
|
+
const __fp16* weight,
|
|
132
|
+
__fp16* output,
|
|
246
133
|
size_t N,
|
|
247
134
|
size_t L,
|
|
248
135
|
size_t C_in,
|
|
@@ -250,37 +137,42 @@ void cactus_conv1d_f32_k3(
|
|
|
250
137
|
size_t stride
|
|
251
138
|
);
|
|
252
139
|
|
|
253
|
-
void
|
|
140
|
+
void cactus_conv1d_f16(
|
|
254
141
|
const __fp16* input,
|
|
255
142
|
const __fp16* weight,
|
|
143
|
+
const __fp16* bias,
|
|
256
144
|
__fp16* output,
|
|
257
145
|
size_t N,
|
|
258
146
|
size_t L,
|
|
259
147
|
size_t C_in,
|
|
260
148
|
size_t C_out,
|
|
149
|
+
size_t K,
|
|
261
150
|
size_t stride
|
|
262
151
|
);
|
|
263
152
|
|
|
264
|
-
void
|
|
265
|
-
const
|
|
266
|
-
const
|
|
267
|
-
|
|
153
|
+
void cactus_stft_magnitude_f16(
|
|
154
|
+
const __fp16* input,
|
|
155
|
+
const __fp16* weight,
|
|
156
|
+
__fp16* output,
|
|
268
157
|
size_t N, size_t L,
|
|
269
158
|
size_t C_in, size_t C_out,
|
|
270
|
-
size_t stride
|
|
159
|
+
size_t K, size_t stride,
|
|
160
|
+
size_t num_fft_bins
|
|
271
161
|
);
|
|
272
162
|
|
|
273
|
-
void
|
|
163
|
+
void cactus_conv1d_f16_k7s3_oc8(
|
|
274
164
|
const __fp16* input,
|
|
275
|
-
const __fp16*
|
|
165
|
+
const __fp16* Wpack,
|
|
166
|
+
const __fp16* bias,
|
|
276
167
|
__fp16* output,
|
|
277
|
-
size_t N,
|
|
278
|
-
size_t
|
|
279
|
-
size_t
|
|
168
|
+
size_t N,
|
|
169
|
+
size_t L,
|
|
170
|
+
size_t C_in,
|
|
171
|
+
size_t C_out
|
|
280
172
|
);
|
|
281
173
|
|
|
282
|
-
void
|
|
283
|
-
|
|
174
|
+
void cactus_bilinear_interpolation_f16(const __fp16* input, __fp16* output, size_t src_height, size_t src_width, size_t embed_dim,
|
|
175
|
+
size_t dst_height, size_t dst_width);
|
|
284
176
|
|
|
285
177
|
void cactus_sample_f32(const float* logits, uint32_t* output, size_t vocab_size,
|
|
286
178
|
float temperature, float top_p, size_t top_k, size_t random_seed,
|
|
@@ -291,25 +183,45 @@ void cactus_sample_f16(const __fp16* logits, uint32_t* output, size_t vocab_size
|
|
|
291
183
|
const float* bias_values = nullptr, const uint32_t* bias_indices = nullptr,
|
|
292
184
|
size_t bias_count = 0);
|
|
293
185
|
|
|
294
|
-
|
|
295
|
-
void cactus_concat_f32(const float* input1, const float* input2, float* output,
|
|
296
|
-
const size_t* shape1, const size_t* shape2, const size_t* output_shape,
|
|
297
|
-
size_t ndims, int axis);
|
|
298
186
|
void cactus_concat_f16(const __fp16* input1, const __fp16* input2, __fp16* output,
|
|
299
187
|
const size_t* shape1, const size_t* shape2, const size_t* output_shape,
|
|
300
188
|
size_t ndims, int axis);
|
|
301
|
-
void cactus_concat_int8(const int8_t* input1, const int8_t* input2, int8_t* output,
|
|
302
|
-
const size_t* shape1, const size_t* shape2, const size_t* output_shape,
|
|
303
|
-
size_t ndims, int axis);
|
|
304
189
|
|
|
305
190
|
void cactus_int8_to_fp32(const int8_t* src, float* dst, size_t count, float scale = 1.0f);
|
|
306
191
|
void cactus_fp32_to_int8(const float* src, int8_t* dst, size_t count, float scale = 1.0f);
|
|
307
|
-
void cactus_dynamic_quantize_fp32_to_int8(const float* src, int8_t* dst, size_t count, float* computed_scale);
|
|
308
192
|
void cactus_fp16_to_fp32(const __fp16* src, float* dst, size_t count);
|
|
309
193
|
void cactus_fp32_to_fp16(const float* src, __fp16* dst, size_t count);
|
|
310
194
|
void cactus_int8_to_fp16(const int8_t* src, __fp16* dst, size_t count, float scale = 1.0f);
|
|
311
195
|
void cactus_fp16_to_int8(const __fp16* src, int8_t* dst, size_t count, float scale = 1.0f);
|
|
312
196
|
float cactus_fp16_max_abs(const __fp16* src, size_t count);
|
|
313
|
-
void cactus_int32_to_fp16_scaled(const int32_t* src, __fp16* dst, size_t count, float scale);
|
|
314
197
|
|
|
315
|
-
|
|
198
|
+
void cactus_quantize_kv_fp16_to_int8(
|
|
199
|
+
const __fp16* src,
|
|
200
|
+
int8_t* dst,
|
|
201
|
+
float* scales,
|
|
202
|
+
size_t seq_len, size_t kv_heads, size_t head_dim,
|
|
203
|
+
size_t group_size = KV_QUANT_GROUP_SIZE);
|
|
204
|
+
|
|
205
|
+
inline size_t kv_scales_count(size_t seq_len, size_t kv_heads, size_t head_dim, size_t group_size = KV_QUANT_GROUP_SIZE) {
|
|
206
|
+
size_t num_groups = (head_dim + group_size - 1) / group_size;
|
|
207
|
+
return seq_len * kv_heads * num_groups;
|
|
208
|
+
}
|
|
209
|
+
|
|
210
|
+
void cactus_unpack_int4_to_int8(const uint8_t* packed, int8_t* unpacked, size_t unpacked_count);
|
|
211
|
+
|
|
212
|
+
void cactus_lstm_cell_f16(
|
|
213
|
+
const __fp16* x_input,
|
|
214
|
+
const __fp16* h_prev,
|
|
215
|
+
const __fp16* c_prev,
|
|
216
|
+
const __fp16* weight_ih,
|
|
217
|
+
const __fp16* weight_hh,
|
|
218
|
+
const __fp16* bias_ih,
|
|
219
|
+
const __fp16* bias_hh,
|
|
220
|
+
__fp16* h_new,
|
|
221
|
+
__fp16* c_new,
|
|
222
|
+
size_t batch_size,
|
|
223
|
+
size_t input_size,
|
|
224
|
+
size_t hidden_size
|
|
225
|
+
);
|
|
226
|
+
|
|
227
|
+
#endif
|