cactus-react-native 1.2.1 → 1.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238) hide show
  1. package/README.md +765 -33
  2. package/android/CMakeLists.txt +4 -3
  3. package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusFileSystem.kt +20 -1
  4. package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libcactus_util.a +0 -0
  6. package/cpp/HybridCactus.cpp +231 -19
  7. package/cpp/HybridCactus.hpp +25 -3
  8. package/cpp/HybridCactusIndex.cpp +325 -0
  9. package/cpp/HybridCactusIndex.hpp +43 -0
  10. package/cpp/HybridCactusUtil.cpp +3 -3
  11. package/cpp/HybridCactusUtil.hpp +2 -1
  12. package/cpp/cactus_ffi.h +107 -2
  13. package/cpp/cactus_util.h +1 -1
  14. package/ios/HybridCactusFileSystem.swift +23 -2
  15. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus.h +2 -0
  16. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +107 -2
  17. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_telemetry.h +656 -0
  18. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/{ffi_utils.h → cactus_utils.h} +145 -18
  19. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +135 -7
  20. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/gemma_tools.h +549 -0
  21. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +193 -26
  22. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +54 -195
  23. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel_utils.h +399 -140
  24. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Info.plist +0 -0
  25. package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
  26. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus.h +2 -0
  27. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +107 -2
  28. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_telemetry.h +656 -0
  29. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/{ffi_utils.h → cactus_utils.h} +145 -18
  30. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +135 -7
  31. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/gemma_tools.h +549 -0
  32. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +193 -26
  33. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +54 -195
  34. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel_utils.h +399 -140
  35. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Info.plist +0 -0
  36. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/_CodeSignature/CodeResources +1 -1
  37. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
  38. package/ios/cactus_util.xcframework/Info.plist +4 -4
  39. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/cactus_util.h +1 -1
  40. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Headers/database.h +27 -0
  41. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/Info.plist +0 -0
  42. package/ios/cactus_util.xcframework/ios-arm64/cactus_util.framework/cactus_util +0 -0
  43. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/cactus_util.h +1 -1
  44. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Headers/database.h +27 -0
  45. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/Info.plist +0 -0
  46. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/_CodeSignature/CodeResources +3 -3
  47. package/ios/cactus_util.xcframework/ios-arm64-simulator/cactus_util.framework/cactus_util +0 -0
  48. package/lib/module/api/Database.js +12 -95
  49. package/lib/module/api/Database.js.map +1 -1
  50. package/lib/module/classes/CactusIndex.js +45 -0
  51. package/lib/module/classes/CactusIndex.js.map +1 -0
  52. package/lib/module/classes/CactusLM.js +65 -17
  53. package/lib/module/classes/CactusLM.js.map +1 -1
  54. package/lib/module/classes/CactusSTT.js +104 -17
  55. package/lib/module/classes/CactusSTT.js.map +1 -1
  56. package/lib/module/config/CactusConfig.js +2 -0
  57. package/lib/module/config/CactusConfig.js.map +1 -1
  58. package/lib/module/constants/packageVersion.js +1 -1
  59. package/lib/module/hooks/useCactusIndex.js +175 -0
  60. package/lib/module/hooks/useCactusIndex.js.map +1 -0
  61. package/lib/module/hooks/useCactusLM.js +68 -7
  62. package/lib/module/hooks/useCactusLM.js.map +1 -1
  63. package/lib/module/hooks/useCactusSTT.js +102 -6
  64. package/lib/module/hooks/useCactusSTT.js.map +1 -1
  65. package/lib/module/index.js +2 -0
  66. package/lib/module/index.js.map +1 -1
  67. package/lib/module/models.js +336 -0
  68. package/lib/module/models.js.map +1 -0
  69. package/lib/module/native/Cactus.js +61 -13
  70. package/lib/module/native/Cactus.js.map +1 -1
  71. package/lib/module/native/CactusFileSystem.js +3 -0
  72. package/lib/module/native/CactusFileSystem.js.map +1 -1
  73. package/lib/module/native/CactusIndex.js +32 -0
  74. package/lib/module/native/CactusIndex.js.map +1 -0
  75. package/lib/module/native/CactusUtil.js +16 -3
  76. package/lib/module/native/CactusUtil.js.map +1 -1
  77. package/lib/module/native/index.js +1 -0
  78. package/lib/module/native/index.js.map +1 -1
  79. package/lib/module/specs/CactusIndex.nitro.js +4 -0
  80. package/lib/module/specs/CactusIndex.nitro.js.map +1 -0
  81. package/lib/module/telemetry/Telemetry.js +3 -1
  82. package/lib/module/telemetry/Telemetry.js.map +1 -1
  83. package/lib/module/types/CactusIndex.js +2 -0
  84. package/lib/module/types/{CactusModel.js.map → CactusIndex.js.map} +1 -1
  85. package/lib/module/types/CactusLM.js +2 -0
  86. package/lib/module/types/CactusSTT.js +2 -0
  87. package/lib/module/types/common.js +2 -0
  88. package/lib/module/types/{CactusSTTModel.js.map → common.js.map} +1 -1
  89. package/lib/typescript/src/api/Database.d.ts +4 -7
  90. package/lib/typescript/src/api/Database.d.ts.map +1 -1
  91. package/lib/typescript/src/classes/CactusIndex.d.ts +15 -0
  92. package/lib/typescript/src/classes/CactusIndex.d.ts.map +1 -0
  93. package/lib/typescript/src/classes/CactusLM.d.ts +12 -5
  94. package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
  95. package/lib/typescript/src/classes/CactusSTT.d.ts +15 -5
  96. package/lib/typescript/src/classes/CactusSTT.d.ts.map +1 -1
  97. package/lib/typescript/src/config/CactusConfig.d.ts +1 -0
  98. package/lib/typescript/src/config/CactusConfig.d.ts.map +1 -1
  99. package/lib/typescript/src/constants/packageVersion.d.ts +1 -1
  100. package/lib/typescript/src/hooks/useCactusIndex.d.ts +14 -0
  101. package/lib/typescript/src/hooks/useCactusIndex.d.ts.map +1 -0
  102. package/lib/typescript/src/hooks/useCactusLM.d.ts +6 -4
  103. package/lib/typescript/src/hooks/useCactusLM.d.ts.map +1 -1
  104. package/lib/typescript/src/hooks/useCactusSTT.d.ts +13 -5
  105. package/lib/typescript/src/hooks/useCactusSTT.d.ts.map +1 -1
  106. package/lib/typescript/src/index.d.ts +6 -4
  107. package/lib/typescript/src/index.d.ts.map +1 -1
  108. package/lib/typescript/src/models.d.ts +6 -0
  109. package/lib/typescript/src/models.d.ts.map +1 -0
  110. package/lib/typescript/src/native/Cactus.d.ts +10 -3
  111. package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
  112. package/lib/typescript/src/native/CactusFileSystem.d.ts +1 -0
  113. package/lib/typescript/src/native/CactusFileSystem.d.ts.map +1 -1
  114. package/lib/typescript/src/native/CactusIndex.d.ts +12 -0
  115. package/lib/typescript/src/native/CactusIndex.d.ts.map +1 -0
  116. package/lib/typescript/src/native/CactusUtil.d.ts.map +1 -1
  117. package/lib/typescript/src/native/index.d.ts +1 -0
  118. package/lib/typescript/src/native/index.d.ts.map +1 -1
  119. package/lib/typescript/src/specs/Cactus.nitro.d.ts +9 -2
  120. package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
  121. package/lib/typescript/src/specs/CactusFileSystem.nitro.d.ts +1 -0
  122. package/lib/typescript/src/specs/CactusFileSystem.nitro.d.ts.map +1 -1
  123. package/lib/typescript/src/specs/CactusIndex.nitro.d.ts +24 -0
  124. package/lib/typescript/src/specs/CactusIndex.nitro.d.ts.map +1 -0
  125. package/lib/typescript/src/specs/CactusUtil.nitro.d.ts +1 -1
  126. package/lib/typescript/src/specs/CactusUtil.nitro.d.ts.map +1 -1
  127. package/lib/typescript/src/types/CactusIndex.d.ts +34 -0
  128. package/lib/typescript/src/types/CactusIndex.d.ts.map +1 -0
  129. package/lib/typescript/src/types/CactusLM.d.ts +19 -0
  130. package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
  131. package/lib/typescript/src/types/CactusSTT.d.ts +21 -1
  132. package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -1
  133. package/lib/typescript/src/types/common.d.ts +28 -0
  134. package/lib/typescript/src/types/common.d.ts.map +1 -0
  135. package/nitro.json +3 -0
  136. package/nitrogen/generated/android/c++/JDeviceInfo.hpp +1 -1
  137. package/nitrogen/generated/android/c++/JFunc_void_double.hpp +1 -1
  138. package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.cpp +1 -1
  139. package/nitrogen/generated/android/c++/JHybridCactusCryptoSpec.hpp +1 -1
  140. package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.cpp +1 -1
  141. package/nitrogen/generated/android/c++/JHybridCactusDeviceInfoSpec.hpp +1 -1
  142. package/nitrogen/generated/android/c++/JHybridCactusFileSystemSpec.cpp +17 -1
  143. package/nitrogen/generated/android/c++/JHybridCactusFileSystemSpec.hpp +2 -1
  144. package/nitrogen/generated/android/c++/JHybridCactusImageSpec.cpp +1 -1
  145. package/nitrogen/generated/android/c++/JHybridCactusImageSpec.hpp +1 -1
  146. package/nitrogen/generated/android/cactus+autolinking.cmake +2 -1
  147. package/nitrogen/generated/android/cactus+autolinking.gradle +1 -1
  148. package/nitrogen/generated/android/cactusOnLoad.cpp +11 -1
  149. package/nitrogen/generated/android/cactusOnLoad.hpp +1 -1
  150. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/DeviceInfo.kt +1 -1
  151. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/Func_void_double.kt +1 -1
  152. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusCryptoSpec.kt +1 -1
  153. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusDeviceInfoSpec.kt +1 -1
  154. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusFileSystemSpec.kt +5 -1
  155. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusImageSpec.kt +1 -1
  156. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/cactusOnLoad.kt +1 -1
  157. package/nitrogen/generated/ios/Cactus+autolinking.rb +1 -1
  158. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.cpp +1 -1
  159. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.hpp +1 -1
  160. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Umbrella.hpp +1 -1
  161. package/nitrogen/generated/ios/CactusAutolinking.mm +11 -1
  162. package/nitrogen/generated/ios/CactusAutolinking.swift +1 -1
  163. package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.cpp +1 -1
  164. package/nitrogen/generated/ios/c++/HybridCactusCryptoSpecSwift.hpp +1 -1
  165. package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.cpp +1 -1
  166. package/nitrogen/generated/ios/c++/HybridCactusDeviceInfoSpecSwift.hpp +1 -1
  167. package/nitrogen/generated/ios/c++/HybridCactusFileSystemSpecSwift.cpp +1 -1
  168. package/nitrogen/generated/ios/c++/HybridCactusFileSystemSpecSwift.hpp +9 -1
  169. package/nitrogen/generated/ios/c++/HybridCactusImageSpecSwift.cpp +1 -1
  170. package/nitrogen/generated/ios/c++/HybridCactusImageSpecSwift.hpp +1 -1
  171. package/nitrogen/generated/ios/swift/DeviceInfo.swift +1 -1
  172. package/nitrogen/generated/ios/swift/Func_void.swift +1 -1
  173. package/nitrogen/generated/ios/swift/Func_void_DeviceInfo.swift +1 -1
  174. package/nitrogen/generated/ios/swift/Func_void_bool.swift +1 -1
  175. package/nitrogen/generated/ios/swift/Func_void_double.swift +1 -1
  176. package/nitrogen/generated/ios/swift/Func_void_std__exception_ptr.swift +1 -1
  177. package/nitrogen/generated/ios/swift/Func_void_std__optional_std__string_.swift +1 -1
  178. package/nitrogen/generated/ios/swift/Func_void_std__string.swift +1 -1
  179. package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec.swift +1 -1
  180. package/nitrogen/generated/ios/swift/HybridCactusCryptoSpec_cxx.swift +1 -1
  181. package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec.swift +1 -1
  182. package/nitrogen/generated/ios/swift/HybridCactusDeviceInfoSpec_cxx.swift +1 -1
  183. package/nitrogen/generated/ios/swift/HybridCactusFileSystemSpec.swift +2 -1
  184. package/nitrogen/generated/ios/swift/HybridCactusFileSystemSpec_cxx.swift +20 -1
  185. package/nitrogen/generated/ios/swift/HybridCactusImageSpec.swift +1 -1
  186. package/nitrogen/generated/ios/swift/HybridCactusImageSpec_cxx.swift +1 -1
  187. package/nitrogen/generated/shared/c++/CactusIndexGetResult.hpp +84 -0
  188. package/nitrogen/generated/shared/c++/CactusIndexQueryResult.hpp +79 -0
  189. package/nitrogen/generated/shared/c++/DeviceInfo.hpp +1 -1
  190. package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.cpp +1 -1
  191. package/nitrogen/generated/shared/c++/HybridCactusCryptoSpec.hpp +1 -1
  192. package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.cpp +1 -1
  193. package/nitrogen/generated/shared/c++/HybridCactusDeviceInfoSpec.hpp +1 -1
  194. package/nitrogen/generated/shared/c++/HybridCactusFileSystemSpec.cpp +2 -1
  195. package/nitrogen/generated/shared/c++/HybridCactusFileSystemSpec.hpp +2 -1
  196. package/nitrogen/generated/shared/c++/HybridCactusImageSpec.cpp +1 -1
  197. package/nitrogen/generated/shared/c++/HybridCactusImageSpec.hpp +1 -1
  198. package/nitrogen/generated/shared/c++/HybridCactusIndexSpec.cpp +27 -0
  199. package/nitrogen/generated/shared/c++/HybridCactusIndexSpec.hpp +76 -0
  200. package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +8 -1
  201. package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +11 -3
  202. package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.cpp +1 -1
  203. package/nitrogen/generated/shared/c++/HybridCactusUtilSpec.hpp +2 -2
  204. package/package.json +2 -2
  205. package/src/api/Database.ts +14 -135
  206. package/src/classes/CactusIndex.ts +58 -0
  207. package/src/classes/CactusLM.ts +87 -19
  208. package/src/classes/CactusSTT.ts +134 -20
  209. package/src/config/CactusConfig.ts +3 -0
  210. package/src/constants/packageVersion.ts +1 -1
  211. package/src/hooks/useCactusIndex.ts +195 -0
  212. package/src/hooks/useCactusLM.ts +88 -8
  213. package/src/hooks/useCactusSTT.ts +119 -7
  214. package/src/index.tsx +22 -2
  215. package/src/models.ts +344 -0
  216. package/src/native/Cactus.ts +95 -13
  217. package/src/native/CactusFileSystem.ts +4 -0
  218. package/src/native/CactusIndex.ts +54 -0
  219. package/src/native/CactusUtil.ts +19 -3
  220. package/src/native/index.ts +1 -0
  221. package/src/specs/Cactus.nitro.ts +18 -2
  222. package/src/specs/CactusFileSystem.nitro.ts +2 -0
  223. package/src/specs/CactusIndex.nitro.ts +31 -0
  224. package/src/specs/CactusUtil.nitro.ts +1 -1
  225. package/src/telemetry/Telemetry.ts +1 -1
  226. package/src/types/CactusIndex.ts +40 -0
  227. package/src/types/CactusLM.ts +24 -0
  228. package/src/types/CactusSTT.ts +27 -1
  229. package/src/types/common.ts +28 -0
  230. package/android/src/main/jniLibs/arm64-v8a/libcactus_util.so +0 -0
  231. package/lib/module/types/CactusModel.js +0 -2
  232. package/lib/module/types/CactusSTTModel.js +0 -2
  233. package/lib/typescript/src/types/CactusModel.d.ts +0 -13
  234. package/lib/typescript/src/types/CactusModel.d.ts.map +0 -1
  235. package/lib/typescript/src/types/CactusSTTModel.d.ts +0 -8
  236. package/lib/typescript/src/types/CactusSTTModel.d.ts.map +0 -1
  237. package/src/types/CactusModel.ts +0 -15
  238. package/src/types/CactusSTTModel.ts +0 -10
@@ -8,15 +8,99 @@
8
8
  #include <atomic>
  #include <cstring>
  #include <functional>
  #include <iostream>
  #include <mutex>
  #include <sstream>
  #include <stdexcept>
  #include <string>
  #include <utility>

  #include <arm_neon.h>
15
+
16
namespace cactus {

/// Severity of a log record. Values increase with severity; a record is
/// emitted only when its level is >= the logger's threshold, so a threshold
/// of `NONE` (above every real level) suppresses everything.
enum class LogLevel {
    DEBUG = 0,
    INFO = 1,
    WARN = 2,
    ERROR = 3,
    NONE = 4
};

/// Process-wide logging sink (function-local static singleton).
///
/// Records at or above the threshold go to the installed callback, or to
/// `std::cerr` when no callback is set. The most recent ERROR record is also
/// kept as text and can be read back through `last_error()`.
///
/// Threading: the threshold is atomic; the callback, the stderr write and the
/// stored error text are guarded by one mutex. The callback runs while that
/// mutex is held, so it must not call `log()`, `set_callback()` or
/// `clear_error()` on this logger (the mutex is not recursive).
class Logger {
public:
    /// Signature of a user-installed sink: (level, component, message).
    using Callback = std::function<void(LogLevel, const std::string&, const std::string&)>;

    /// Returns the single shared logger instance.
    static Logger& instance() {
        static Logger logger;
        return logger;
    }

    /// Sets the minimum level that will be emitted.
    void set_level(LogLevel level) { min_level_.store(level, std::memory_order_relaxed); }

    /// Returns the current minimum level.
    LogLevel get_level() const { return min_level_.load(std::memory_order_relaxed); }

    /// Installs `cb` as the sink; an empty function restores stderr output.
    void set_callback(Callback cb) {
        std::lock_guard<std::mutex> lock(mutex_);
        callback_ = std::move(cb);
    }

    /// Emits one record if `level` passes the threshold.
    /// @param level     severity of the record
    /// @param component short tag naming the subsystem that logs
    /// @param message   already-formatted message text
    void log(LogLevel level, const std::string& component, const std::string& message) {
        // Cheap lock-free rejection of filtered records.
        if (level < get_level()) return;

        std::lock_guard<std::mutex> lock(mutex_);

        if (callback_) {
            callback_(level, component, message);
        } else {
            std::cerr << "[" << level_string(level) << "] [" << component << "] " << message << std::endl;
        }

        if (level == LogLevel::ERROR) {
            last_error_ = "[" + component + "] " + message;
        }
    }

    /// Text of the most recent ERROR record ("[component] message"), or empty.
    /// NOTE(review): this hands out a reference without taking the mutex, so a
    /// concurrent `log(ERROR, ...)` can modify the string while it is being
    /// read. The signature is kept as-is for existing callers; confirm whether
    /// they only read it on the thread that produced the error.
    const std::string& last_error() const { return last_error_; }

    /// Forgets the stored error text. Takes the same mutex `log()` uses when
    /// it writes the text, so the two cannot modify the string concurrently.
    void clear_error() {
        std::lock_guard<std::mutex> lock(mutex_);
        last_error_.clear();
    }

private:
    Logger() : min_level_(LogLevel::WARN) {}

    /// Fixed display name for a level; "?" for anything without a name
    /// (including `NONE`).
    static const char* level_string(LogLevel level) {
        switch (level) {
            case LogLevel::DEBUG: return "DEBUG";
            case LogLevel::INFO: return "INFO";
            case LogLevel::WARN: return "WARN";
            case LogLevel::ERROR: return "ERROR";
            default: return "?";
        }
    }

    std::atomic<LogLevel> min_level_;  // read on every log call without the mutex
    std::mutex mutex_;                 // guards callback_, last_error_ writes and stderr output
    std::string last_error_;
    Callback callback_;
};

} // namespace cactus
80
+
81
// Logging front-end macros.
//
// CACTUS_LOG(level, component, msg) streams `msg` into a temporary
// std::ostringstream and forwards the text to cactus::Logger::log(). `msg` is
// pasted after `<<` unparenthesized on purpose, so callers can chain pieces:
//     CACTUS_LOG_WARN("graph", "node " << id << " missing");
// The stream is only built when `level` passes the logger's threshold, so the
// `msg` expression is not evaluated for filtered records.
//
// `level` and `component` are each evaluated exactly once (an argument with
// side effects would otherwise run twice: once in the guard, once in the call).
#define CACTUS_LOG(level, component, msg) \
    do { \
        const cactus::LogLevel _cactus_log_level = (level); \
        cactus::Logger& _cactus_logger = cactus::Logger::instance(); \
        if (static_cast<int>(_cactus_log_level) >= static_cast<int>(_cactus_logger.get_level())) { \
            std::ostringstream _cactus_log_ss; \
            _cactus_log_ss << msg; \
            _cactus_logger.log(_cactus_log_level, (component), _cactus_log_ss.str()); \
        } \
    } while(0)

// Per-level shorthands for CACTUS_LOG.
#define CACTUS_LOG_DEBUG(component, msg) CACTUS_LOG(cactus::LogLevel::DEBUG, component, msg)
#define CACTUS_LOG_INFO(component, msg) CACTUS_LOG(cactus::LogLevel::INFO, component, msg)
#define CACTUS_LOG_WARN(component, msg) CACTUS_LOG(cactus::LogLevel::WARN, component, msg)
#define CACTUS_LOG_ERROR(component, msg) CACTUS_LOG(cactus::LogLevel::ERROR, component, msg)
11
94
 
12
95
  namespace GraphFile {
13
96
  class MappedFile;
14
97
  }
15
98
 
16
99
  enum class Precision {
17
- INT8,
100
+ INT8,
18
101
  FP16,
19
- FP32
102
+ FP32,
103
+ INT4
20
104
  };
21
105
 
22
106
  enum class ComputeBackend {
@@ -30,7 +114,7 @@ enum class OpType {
30
114
  MATMUL, TRANSPOSE, RESHAPE, SLICE, GATHER, EMBEDDING,
31
115
  BILINEAR_INTERPOLATION,
32
116
  SUM, MEAN, VARIANCE, MIN, MAX,
33
- RMS_NORM, ROPE, SOFTMAX, ATTENTION, CONV1D_CAUSAL, CONV1D_K3,
117
+ RMS_NORM, ROPE, SOFTMAX, ATTENTION, ATTENTION_INT8_HYBRID, CONV1D_CAUSAL, CONV1D_K3,
34
118
  SCALAR_ADD, SCALAR_SUBTRACT, SCALAR_MULTIPLY, SCALAR_DIVIDE, SCALAR_EXP, SCALAR_SQRT, SCALAR_COS, SCALAR_SIN,
35
119
  SILU, GELU, GELU_ERF,
36
120
  SAMPLE, CONCAT,
@@ -40,27 +124,38 @@ enum class OpType {
40
124
  };
41
125
 
42
126
  struct PrecisionTraits {
127
+ // Returns in-memory element size (INT4 unpacks to INT8, so returns 1)
43
128
  static constexpr size_t size_of(Precision prec) {
44
129
  switch (prec) {
45
130
  case Precision::INT8: return 1;
46
131
  case Precision::FP16: return 2;
47
132
  case Precision::FP32: return 4;
133
+ case Precision::INT4: return 1;
48
134
  }
49
135
  return 1;
50
136
  }
51
-
137
+
138
+ static constexpr size_t packed_size_of(Precision prec, size_t count) {
139
+ switch (prec) {
140
+ case Precision::INT4: return (count + 1) / 2;
141
+ default: return count * size_of(prec);
142
+ }
143
+ }
144
+
52
145
  static constexpr bool is_integer(Precision prec) {
53
146
  switch (prec) {
54
147
  case Precision::INT8: return true;
148
+ case Precision::INT4: return true;
55
149
  case Precision::FP16: return false;
56
150
  case Precision::FP32: return false;
57
151
  }
58
152
  return true;
59
153
  }
60
-
154
+
61
155
  static constexpr bool is_floating_point(Precision prec) {
62
156
  switch (prec) {
63
157
  case Precision::INT8: return false;
158
+ case Precision::INT4: return false;
64
159
  case Precision::FP16: return true;
65
160
  case Precision::FP32: return true;
66
161
  }
@@ -71,8 +166,6 @@ struct PrecisionTraits {
71
166
  namespace Quantization {
72
167
  void int8_to_fp32(const int8_t* src, float* dst, size_t count, float scale = 1.0f);
73
168
  void fp32_to_int8(const float* src, int8_t* dst, size_t count, float scale = 1.0f);
74
- void dynamic_quantize_fp32_to_int8(const float* src, int8_t* dst, size_t count,
75
- float* computed_scale);
76
169
  void fp16_to_fp32(const __fp16* src, float* dst, size_t count);
77
170
  void fp32_to_fp16(const float* src, __fp16* dst, size_t count);
78
171
  void int8_to_fp16(const int8_t* src, __fp16* dst, size_t count, float scale = 1.0f);
@@ -106,10 +199,17 @@ struct BufferDesc {
106
199
  void* external_data;
107
200
  char* pooled_data;
108
201
  Precision precision;
109
- float quantization_scale;
202
+
203
+ size_t group_size = 0;
204
+ size_t num_groups = 0;
205
+ void* scales_data = nullptr;
206
+ std::unique_ptr<char[]> owned_scales;
207
+
208
+ const void* packed_int4_data = nullptr;
209
+ size_t packed_int4_size = 0;
110
210
 
111
211
  BufferDesc();
112
- BufferDesc(const std::vector<size_t>& s, Precision prec = Precision::INT8, float scale = 1.0f);
212
+ BufferDesc(const std::vector<size_t>& s, Precision prec = Precision::INT8);
113
213
  ~BufferDesc();
114
214
 
115
215
  BufferDesc(BufferDesc&& other) noexcept;
@@ -127,6 +227,28 @@ struct BufferDesc {
127
227
  template<typename T>
128
228
  const T* data_as() const { return static_cast<const T*>(get_data()); }
129
229
 
230
+ const __fp16* scales_as_fp16() const {
231
+ return reinterpret_cast<const __fp16*>(scales_data);
232
+ }
233
+ bool is_grouped_int8() const {
234
+ return precision == Precision::INT8 && group_size > 0;
235
+ }
236
+ bool is_packed_int4() const {
237
+ return packed_int4_data != nullptr && packed_int4_size > 0;
238
+ }
239
+ const uint8_t* packed_int4_as_uint8() const {
240
+ return reinterpret_cast<const uint8_t*>(packed_int4_data);
241
+ }
242
+ void set_grouped_scales(size_t gs, size_t ng, void* scales_ptr) {
243
+ group_size = gs;
244
+ num_groups = ng;
245
+ scales_data = scales_ptr;
246
+ }
247
+ void set_packed_int4(const void* packed_data, size_t packed_size) {
248
+ packed_int4_data = packed_data;
249
+ packed_int4_size = packed_size;
250
+ }
251
+
130
252
  void allocate();
131
253
  void allocate_from_pool(BufferPool& pool);
132
254
  void release_to_pool(BufferPool& pool);
@@ -144,7 +266,7 @@ struct OpParams {
144
266
  size_t slice_start = 0;
145
267
  size_t slice_length = 0;
146
268
  size_t window_size = 0;
147
- bool is_causal = true; // Default to causal for backward compatibility
269
+ bool is_causal = true;
148
270
  std::vector<size_t> new_shape;
149
271
  std::vector<size_t> permutation;
150
272
  Precision output_precision = Precision::INT8;
@@ -158,10 +280,21 @@ struct OpParams {
158
280
  size_t top_k = 0;
159
281
  size_t random_seed = 0;
160
282
 
161
- size_t index_value = 0; // For INDEX operation
162
- size_t num_classes = 0; // For scatter operations
283
+ size_t index_value = 0;
284
+ size_t num_classes = 0;
163
285
  size_t dst_height = 0;
164
- size_t dst_width = 0;
286
+ size_t dst_width = 0;
287
+
288
+ std::vector<float> bias_values;
289
+ std::vector<uint32_t> bias_indices;
290
+
291
+ const int8_t* cached_keys_int8 = nullptr;
292
+ const int8_t* cached_values_int8 = nullptr;
293
+ const float* cached_k_scales = nullptr;
294
+ const float* cached_v_scales = nullptr;
295
+ size_t cache_seq_len = 0;
296
+ size_t num_kv_heads = 0;
297
+ size_t head_dim = 0;
165
298
  };
166
299
 
167
300
  struct GraphNode {
@@ -241,7 +374,7 @@ public:
241
374
  size_t precision_cast(size_t input, Precision target_precision);
242
375
 
243
376
  size_t add(size_t input1, size_t input2);
244
- size_t add_clipped(size_t input1, size_t input2); // For FP16 residual connections (Gemma)
377
+ size_t add_clipped(size_t input1, size_t input2);
245
378
  size_t subtract(size_t input1, size_t input2);
246
379
  size_t multiply(size_t input1, size_t input2);
247
380
  size_t divide(size_t input1, size_t input2);
@@ -276,8 +409,12 @@ public:
276
409
  size_t gather(size_t embeddings, size_t indices);
277
410
  size_t mmap_embeddings(const std::string& filename);
278
411
  size_t mmap_weights(const std::string& filename);
279
- size_t load_weights(const std::string& filename);
280
- void set_quantization_scale(size_t node_id, float scale);
412
+ size_t load_weights(const std::string& filename);
413
+ void set_grouped_scales(size_t node_id, size_t group_size, size_t num_groups, void* scales_ptr);
414
+
415
+ void release_weight_pages(size_t node_id);
416
+ void prefetch_weight_pages(size_t node_id);
417
+ void release_all_weight_pages();
281
418
  size_t embedding(const std::string& filename, size_t indices);
282
419
  size_t embedding(size_t embedding_tensor, size_t indices);
283
420
  size_t bilinear_interpolation(size_t pos_embeds, size_t dst_height, size_t dst_width);
@@ -291,10 +428,16 @@ public:
291
428
  size_t attention(size_t query, size_t key, size_t value, float scale, size_t position_offset, ComputeBackend backend = ComputeBackend::CPU);
292
429
  size_t attention(size_t query, size_t key, size_t value, float scale, size_t position_offset, size_t window_size, ComputeBackend backend = ComputeBackend::CPU);
293
430
 
431
+ size_t attention_int8_hybrid(size_t query, size_t key_new, size_t value_new, float scale, size_t position_offset,
432
+ const int8_t* cached_keys, const int8_t* cached_values,
433
+ const float* k_scales, const float* v_scales,
434
+ size_t cache_len, size_t num_kv_heads, size_t head_dim);
435
+
294
436
  size_t conv1d_causal(size_t input, size_t weight, size_t kernel_size, size_t dilation = 1);
295
437
  size_t conv1d_k3(size_t input, size_t weight, size_t stride);
296
438
 
297
- size_t sample(size_t logits, float temperature = 0.6f, float top_p = 0.95f, size_t top_k = 20);
439
+ size_t sample(size_t logits, float temperature = 0.6f, float top_p = 0.95f, size_t top_k = 20,
440
+ const std::unordered_map<uint32_t, float>& logit_bias = {});
298
441
 
299
442
  size_t concat(size_t input1, size_t input2, int axis = 0);
300
443
  size_t scatter_topk(size_t indices, size_t values, size_t num_classes);
@@ -306,6 +449,8 @@ public:
306
449
  void execute(const std::string& profile_file = "");
307
450
  void hard_reset();
308
451
  void soft_reset();
452
+ void soft_reset_keep_pool();
453
+ void set_prefill_mode(bool enabled) { prefill_mode_ = enabled; }
309
454
 
310
455
  void register_debug_node(uint32_t layer_idx, const std::string& name, size_t node_id);
311
456
  void capture_debug_node(uint32_t layer_idx, const std::string& name, size_t node_id);
@@ -324,8 +469,10 @@ private:
324
469
  size_t next_node_id_;
325
470
  std::vector<std::unique_ptr<GraphFile::MappedFile>> mapped_files_;
326
471
  std::unordered_map<std::string, size_t> weight_cache_;
472
+ std::unordered_map<size_t, size_t> node_to_mapped_file_;
327
473
  std::vector<DebugNodeEntry> debug_nodes_;
328
474
  BufferPool buffer_pool_;
475
+ bool prefill_mode_ = false;
329
476
  };
330
477
 
331
478
 
@@ -344,25 +491,36 @@ namespace GraphFile {
344
491
  public:
345
492
  MappedFile(const std::string& filename);
346
493
  ~MappedFile();
347
-
494
+
348
495
  MappedFile(const MappedFile&) = delete;
349
496
  MappedFile& operator=(const MappedFile&) = delete;
350
497
  MappedFile(MappedFile&& other) noexcept;
351
498
  MappedFile& operator=(MappedFile&& other) noexcept;
352
-
499
+
353
500
  const std::vector<size_t>& shape() const;
354
501
  Precision precision() const;
502
+ Precision effective_precision() const {
503
+ return is_int4_ ? Precision::INT8 : precision_;
504
+ }
355
505
  size_t byte_size() const;
356
- float quantization_scale() const;
357
-
506
+
507
+ size_t group_size() const { return group_size_; }
508
+ size_t num_groups() const { return num_groups_; }
509
+ const void* scales_data() const;
510
+ const void* raw_packed_data() const; // Get raw mmap'd data without unpacking (for INT4)
511
+ bool is_int4() const { return is_int4_; }
512
+
358
513
  void* data();
359
514
  const void* data() const;
360
-
515
+
361
516
  template<typename T>
362
517
  const T* typed_data() const;
363
-
518
+
364
519
  LoadedNode load_into_graph(CactusGraph& graph) const;
365
-
520
+
521
+ void release_pages();
522
+ void prefetch_pages();
523
+
366
524
  private:
367
525
  int fd_;
368
526
  void* mapped_data_;
@@ -370,10 +528,19 @@ namespace GraphFile {
370
528
  std::vector<size_t> shape_;
371
529
  Precision precision_;
372
530
  size_t byte_size_;
373
- float quantization_scale_;
531
+ size_t group_size_ = 0;
532
+ size_t num_groups_ = 0;
533
+ size_t scales_offset_ = 0;
534
+ size_t scales_bytes_ = 0;
535
+ uint32_t version_ = 1;
536
+ uint32_t alignment_ = 32;
537
+ bool is_int4_ = false;
538
+ mutable std::unique_ptr<int8_t[]> unpacked_int4_data_;
374
539
  void parse_header();
540
+ void apply_madvise_hints();
541
+ void unpack_int4_if_needed() const;
375
542
  };
376
-
543
+
377
544
  MappedFile mmap_load(const std::string& filename);
378
545
  }
379
546