cactus-react-native 1.0.2 → 1.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. package/README.md +378 -21
  2. package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusCrypto.kt +23 -15
  3. package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusDeviceInfo.kt +12 -9
  4. package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusFileSystem.kt +42 -41
  5. package/android/src/main/java/com/margelo/nitro/cactus/HybridCactusImage.kt +81 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libcactus.a +0 -0
  7. package/cpp/HybridCactus.cpp +105 -0
  8. package/cpp/HybridCactus.hpp +13 -0
  9. package/cpp/cactus_ffi.h +27 -0
  10. package/ios/HybridCactusImage.swift +53 -0
  11. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +27 -0
  12. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h +37 -5
  13. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ffi_utils.h +10 -9
  14. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h +49 -7
  15. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h +31 -0
  16. package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
  17. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h +27 -0
  18. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h +37 -5
  19. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/ffi_utils.h +10 -9
  20. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h +49 -7
  21. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h +31 -0
  22. package/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus +0 -0
  23. package/lib/module/api/Database.js +23 -0
  24. package/lib/module/api/Database.js.map +1 -1
  25. package/lib/module/api/RemoteLM.js +201 -0
  26. package/lib/module/api/RemoteLM.js.map +1 -0
  27. package/lib/module/classes/CactusLM.js +52 -26
  28. package/lib/module/classes/CactusLM.js.map +1 -1
  29. package/lib/module/classes/CactusSTT.js +139 -0
  30. package/lib/module/classes/CactusSTT.js.map +1 -0
  31. package/lib/module/config/CactusConfig.js +4 -0
  32. package/lib/module/config/CactusConfig.js.map +1 -1
  33. package/lib/module/constants/packageVersion.js +1 -1
  34. package/lib/module/hooks/useCactusLM.js +33 -10
  35. package/lib/module/hooks/useCactusLM.js.map +1 -1
  36. package/lib/module/hooks/useCactusSTT.js +234 -0
  37. package/lib/module/hooks/useCactusSTT.js.map +1 -0
  38. package/lib/module/index.js +2 -0
  39. package/lib/module/index.js.map +1 -1
  40. package/lib/module/native/Cactus.js +50 -1
  41. package/lib/module/native/Cactus.js.map +1 -1
  42. package/lib/module/native/CactusFileSystem.js +2 -3
  43. package/lib/module/native/CactusFileSystem.js.map +1 -1
  44. package/lib/module/native/CactusImage.js +13 -0
  45. package/lib/module/native/CactusImage.js.map +1 -0
  46. package/lib/module/native/index.js +1 -0
  47. package/lib/module/native/index.js.map +1 -1
  48. package/lib/module/specs/CactusImage.nitro.js +4 -0
  49. package/lib/module/specs/CactusImage.nitro.js.map +1 -0
  50. package/lib/module/telemetry/Telemetry.js +53 -1
  51. package/lib/module/telemetry/Telemetry.js.map +1 -1
  52. package/lib/module/types/CactusSTT.js +2 -0
  53. package/lib/module/types/CactusSTT.js.map +1 -0
  54. package/lib/typescript/src/api/Database.d.ts +1 -0
  55. package/lib/typescript/src/api/Database.d.ts.map +1 -1
  56. package/lib/typescript/src/api/RemoteLM.d.ts +14 -0
  57. package/lib/typescript/src/api/RemoteLM.d.ts.map +1 -0
  58. package/lib/typescript/src/classes/CactusLM.d.ts +6 -4
  59. package/lib/typescript/src/classes/CactusLM.d.ts.map +1 -1
  60. package/lib/typescript/src/classes/CactusSTT.d.ts +26 -0
  61. package/lib/typescript/src/classes/CactusSTT.d.ts.map +1 -0
  62. package/lib/typescript/src/config/CactusConfig.d.ts +1 -0
  63. package/lib/typescript/src/config/CactusConfig.d.ts.map +1 -1
  64. package/lib/typescript/src/constants/packageVersion.d.ts +1 -1
  65. package/lib/typescript/src/hooks/useCactusLM.d.ts +4 -3
  66. package/lib/typescript/src/hooks/useCactusLM.d.ts.map +1 -1
  67. package/lib/typescript/src/hooks/useCactusSTT.d.ts +20 -0
  68. package/lib/typescript/src/hooks/useCactusSTT.d.ts.map +1 -0
  69. package/lib/typescript/src/index.d.ts +4 -1
  70. package/lib/typescript/src/index.d.ts.map +1 -1
  71. package/lib/typescript/src/native/Cactus.d.ts +9 -2
  72. package/lib/typescript/src/native/Cactus.d.ts.map +1 -1
  73. package/lib/typescript/src/native/CactusFileSystem.d.ts +1 -1
  74. package/lib/typescript/src/native/CactusFileSystem.d.ts.map +1 -1
  75. package/lib/typescript/src/native/CactusImage.d.ts +6 -0
  76. package/lib/typescript/src/native/CactusImage.d.ts.map +1 -0
  77. package/lib/typescript/src/native/index.d.ts +1 -0
  78. package/lib/typescript/src/native/index.d.ts.map +1 -1
  79. package/lib/typescript/src/specs/Cactus.nitro.d.ts +3 -0
  80. package/lib/typescript/src/specs/Cactus.nitro.d.ts.map +1 -1
  81. package/lib/typescript/src/specs/CactusImage.nitro.d.ts +9 -0
  82. package/lib/typescript/src/specs/CactusImage.nitro.d.ts.map +1 -0
  83. package/lib/typescript/src/telemetry/Telemetry.d.ts +5 -1
  84. package/lib/typescript/src/telemetry/Telemetry.d.ts.map +1 -1
  85. package/lib/typescript/src/types/CactusLM.d.ts +8 -5
  86. package/lib/typescript/src/types/CactusLM.d.ts.map +1 -1
  87. package/lib/typescript/src/types/CactusSTT.d.ts +37 -0
  88. package/lib/typescript/src/types/CactusSTT.d.ts.map +1 -0
  89. package/nitro.json +4 -0
  90. package/nitrogen/generated/android/c++/JHybridCactusImageSpec.cpp +81 -0
  91. package/nitrogen/generated/android/c++/JHybridCactusImageSpec.hpp +66 -0
  92. package/nitrogen/generated/android/cactus+autolinking.cmake +2 -0
  93. package/nitrogen/generated/android/cactusOnLoad.cpp +10 -0
  94. package/nitrogen/generated/android/kotlin/com/margelo/nitro/cactus/HybridCactusImageSpec.kt +62 -0
  95. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.cpp +17 -0
  96. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Bridge.hpp +17 -0
  97. package/nitrogen/generated/ios/Cactus-Swift-Cxx-Umbrella.hpp +5 -0
  98. package/nitrogen/generated/ios/CactusAutolinking.mm +8 -0
  99. package/nitrogen/generated/ios/CactusAutolinking.swift +15 -0
  100. package/nitrogen/generated/ios/c++/HybridCactusImageSpecSwift.cpp +11 -0
  101. package/nitrogen/generated/ios/c++/HybridCactusImageSpecSwift.hpp +85 -0
  102. package/nitrogen/generated/ios/swift/HybridCactusImageSpec.swift +58 -0
  103. package/nitrogen/generated/ios/swift/HybridCactusImageSpec_cxx.swift +158 -0
  104. package/nitrogen/generated/shared/c++/HybridCactusImageSpec.cpp +22 -0
  105. package/nitrogen/generated/shared/c++/HybridCactusImageSpec.hpp +64 -0
  106. package/nitrogen/generated/shared/c++/HybridCactusSpec.cpp +3 -0
  107. package/nitrogen/generated/shared/c++/HybridCactusSpec.hpp +3 -0
  108. package/package.json +1 -1
  109. package/src/api/Database.ts +27 -0
  110. package/src/api/RemoteLM.ts +273 -0
  111. package/src/classes/CactusLM.ts +72 -38
  112. package/src/classes/CactusSTT.ts +188 -0
  113. package/src/config/CactusConfig.ts +4 -0
  114. package/src/constants/packageVersion.ts +1 -1
  115. package/src/hooks/useCactusLM.ts +45 -17
  116. package/src/hooks/useCactusSTT.ts +285 -0
  117. package/src/index.tsx +14 -2
  118. package/src/native/Cactus.ts +94 -4
  119. package/src/native/CactusFileSystem.ts +2 -2
  120. package/src/native/CactusImage.ts +20 -0
  121. package/src/native/index.ts +1 -0
  122. package/src/specs/Cactus.nitro.ts +9 -0
  123. package/src/specs/CactusImage.nitro.ts +12 -0
  124. package/src/telemetry/Telemetry.ts +78 -1
  125. package/src/types/CactusLM.ts +9 -5
  126. package/src/types/CactusSTT.ts +42 -0
@@ -32,7 +32,7 @@ enum class OpType {
32
32
  SUM, MEAN, VARIANCE, MIN, MAX,
33
33
  RMS_NORM, ROPE, SOFTMAX, ATTENTION, CONV1D_CAUSAL, CONV1D_K3,
34
34
  SCALAR_ADD, SCALAR_SUBTRACT, SCALAR_MULTIPLY, SCALAR_DIVIDE, SCALAR_EXP, SCALAR_SQRT, SCALAR_COS, SCALAR_SIN,
35
- SILU, GELU,
35
+ SILU, GELU, GELU_ERF,
36
36
  SAMPLE, CONCAT,
37
37
  SCATTER_TOPK,
38
38
  TOPK, LAYERNORM,
@@ -92,32 +92,44 @@ struct TensorConfig {
92
92
  struct BroadcastInfo {
93
93
  std::vector<size_t> output_shape;
94
94
  bool needs_broadcasting;
95
-
95
+
96
96
  static BroadcastInfo compute(const std::vector<size_t>& lhs, const std::vector<size_t>& rhs);
97
97
  };
98
98
 
99
+ class BufferPool;
100
+
99
101
  struct BufferDesc {
100
102
  std::vector<size_t> shape;
101
103
  size_t total_size;
102
104
  size_t byte_size;
103
105
  std::unique_ptr<char[]> data;
104
106
  void* external_data;
107
+ char* pooled_data;
105
108
  Precision precision;
106
109
  float quantization_scale;
107
-
110
+
108
111
  BufferDesc();
109
112
  BufferDesc(const std::vector<size_t>& s, Precision prec = Precision::INT8, float scale = 1.0f);
110
-
113
+ ~BufferDesc();
114
+
115
+ BufferDesc(BufferDesc&& other) noexcept;
116
+ BufferDesc& operator=(BufferDesc&& other) noexcept;
117
+
118
+ BufferDesc(const BufferDesc&) = delete;
119
+ BufferDesc& operator=(const BufferDesc&) = delete;
120
+
111
121
  void* get_data();
112
122
  const void* get_data() const;
113
-
123
+
114
124
  template<typename T>
115
125
  T* data_as() { return static_cast<T*>(get_data()); }
116
-
126
+
117
127
  template<typename T>
118
128
  const T* data_as() const { return static_cast<const T*>(get_data()); }
119
-
129
+
120
130
  void allocate();
131
+ void allocate_from_pool(BufferPool& pool);
132
+ void release_to_pool(BufferPool& pool);
121
133
  void set_external(void* ptr);
122
134
  };
123
135
 
@@ -181,6 +193,33 @@ void compute_topk_node(GraphNode& node, const std::vector<std::unique_ptr<GraphN
181
193
  void compute_layernorm_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
182
194
  void compute_index_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
183
195
 
196
+ void shrink_thread_local_buffers();
197
+
198
+ class BufferPool {
199
+ public:
200
+ BufferPool() = default;
201
+ ~BufferPool() = default;
202
+
203
+ BufferPool(const BufferPool&) = delete;
204
+ BufferPool& operator=(const BufferPool&) = delete;
205
+
206
+ char* acquire(size_t byte_size);
207
+ void release(char* ptr, size_t byte_size);
208
+ void clear();
209
+
210
+ size_t active_bytes() const { return active_bytes_; }
211
+ size_t pool_bytes() const { return pool_bytes_; }
212
+ size_t peak_bytes() const { return peak_bytes_; }
213
+
214
+ private:
215
+ std::unordered_map<size_t, std::vector<std::unique_ptr<char[]>>> free_buffers_;
216
+ size_t active_bytes_ = 0;
217
+ size_t pool_bytes_ = 0;
218
+ size_t peak_bytes_ = 0;
219
+
220
+ size_t round_up_size(size_t size) const;
221
+ };
222
+
184
223
  namespace ValidationUtils {
185
224
  void validate_tensor_dims(const std::vector<size_t>& shape, size_t required_dims, const std::string& op_name);
186
225
  void validate_precision(Precision actual, Precision required, const std::string& op_name);
@@ -219,6 +258,7 @@ public:
219
258
 
220
259
  size_t silu(size_t input);
221
260
  size_t gelu(size_t input);
261
+ size_t gelu_erf(size_t input);
222
262
 
223
263
  size_t matmul(size_t input1, size_t input2, bool pretransposed_rhs = false, ComputeBackend backend = ComputeBackend::CPU);
224
264
  size_t transpose(size_t input, ComputeBackend backend = ComputeBackend::CPU);
@@ -236,6 +276,7 @@ public:
236
276
  size_t gather(size_t embeddings, size_t indices);
237
277
  size_t mmap_embeddings(const std::string& filename);
238
278
  size_t mmap_weights(const std::string& filename);
279
+ size_t load_weights(const std::string& filename);
239
280
  void set_quantization_scale(size_t node_id, float scale);
240
281
  size_t embedding(const std::string& filename, size_t indices);
241
282
  size_t embedding(size_t embedding_tensor, size_t indices);
@@ -284,6 +325,7 @@ private:
284
325
  std::vector<std::unique_ptr<GraphFile::MappedFile>> mapped_files_;
285
326
  std::unordered_map<std::string, size_t> weight_cache_;
286
327
  std::vector<DebugNodeEntry> debug_nodes_;
328
+ BufferPool buffer_pool_;
287
329
  };
288
330
 
289
331
 
@@ -174,6 +174,15 @@ void cactus_gelu_f16(const __fp16* input, __fp16* output, size_t num_elements);
174
174
  void cactus_gelu_int8(const int8_t* input, int8_t* output, size_t num_elements,
175
175
  float input_scale, float output_scale);
176
176
 
177
+ void cactus_gelu_f32_erf(const float* input, float* output, size_t num_elements);
178
+ void cactus_gelu_f16_erf(const __fp16* input, __fp16* output, size_t num_elements);
179
+ void cactus_gelu_int8_erf(
180
+ const int8_t* input,
181
+ int8_t* output,
182
+ size_t num_elements,
183
+ float scale_in,
184
+ float scale_out);
185
+
177
186
 
178
187
  void cactus_attention_int8(const int8_t* queries, const int8_t* keys, const int8_t* values, int8_t* output,
179
188
  size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
@@ -225,6 +234,28 @@ void cactus_conv1d_causal_depthwise_int8(
225
234
  float weight_scale,
226
235
  float output_scale);
227
236
 
237
+ void cactus_conv1d_f32_k3(
238
+ const float* input,
239
+ const float* weight,
240
+ float* output,
241
+ size_t N,
242
+ size_t L,
243
+ size_t C_in,
244
+ size_t C_out,
245
+ size_t stride
246
+ );
247
+
248
+ void cactus_conv1d_f16_k3(
249
+ const __fp16* input,
250
+ const __fp16* weight,
251
+ __fp16* output,
252
+ size_t N,
253
+ size_t L,
254
+ size_t C_in,
255
+ size_t C_out,
256
+ size_t stride
257
+ );
258
+
228
259
  void cactus_conv1d_f32_k3(
229
260
  const float* input,
230
261
  const float* weight,
@@ -33,6 +33,17 @@ CACTUS_FFI_EXPORT int cactus_complete(
33
33
  void* user_data
34
34
  );
35
35
 
36
+ CACTUS_FFI_EXPORT int cactus_transcribe(
37
+ cactus_model_t model,
38
+ const char* audio_file_path,
39
+ const char* prompt,
40
+ char* response_buffer,
41
+ size_t buffer_size,
42
+ const char* options_json,
43
+ cactus_token_callback callback,
44
+ void* user_data
45
+ );
46
+
36
47
 
37
48
  CACTUS_FFI_EXPORT int cactus_embed(
38
49
  cactus_model_t model,
@@ -42,6 +53,22 @@ CACTUS_FFI_EXPORT int cactus_embed(
42
53
  size_t* embedding_dim
43
54
  );
44
55
 
56
+ CACTUS_FFI_EXPORT int cactus_image_embed(
57
+ cactus_model_t model,
58
+ const char* image_path,
59
+ float* embeddings_buffer,
60
+ size_t buffer_size,
61
+ size_t* embedding_dim
62
+ );
63
+
64
+ CACTUS_FFI_EXPORT int cactus_audio_embed(
65
+ cactus_model_t model,
66
+ const char* audio_path,
67
+ float* embeddings_buffer,
68
+ size_t buffer_size,
69
+ size_t* embedding_dim
70
+ );
71
+
45
72
  CACTUS_FFI_EXPORT void cactus_reset(cactus_model_t model);
46
73
 
47
74
  CACTUS_FFI_EXPORT void cactus_stop(cactus_model_t model);
@@ -7,11 +7,28 @@
7
7
  #include <cstdint>
8
8
 
9
9
  #include "../graph/graph.h"
10
+
11
+ #ifdef __clang__
12
+ #pragma clang diagnostic push
13
+ #pragma clang diagnostic ignored "-Wc99-extensions"
14
+ #pragma clang diagnostic ignored "-Wunused-parameter"
15
+ #elif defined(__GNUC__)
16
+ #pragma GCC diagnostic push
17
+ #pragma GCC diagnostic ignored "-Wpedantic"
18
+ #pragma GCC diagnostic ignored "-Wunused-parameter"
19
+ #endif
20
+
10
21
  extern "C" {
11
22
  #include "../../libs/stb/stb_image.h"
12
23
  #include "../../libs/stb/stb_image_resize2.h"
13
24
  }
14
25
 
26
+ #ifdef __clang__
27
+ #pragma clang diagnostic pop
28
+ #elif defined(__GNUC__)
29
+ #pragma GCC diagnostic pop
30
+ #endif
31
+
15
32
  class CactusGraph;
16
33
 
17
34
  namespace cactus {
@@ -68,7 +85,7 @@ struct Config {
68
85
  float max_pixels_tolerance = 2.0f;
69
86
  bool do_image_splitting = true;
70
87
 
71
- enum class ModelType {QWEN = 0, GEMMA = 1, SMOL = 2, NOMIC = 3, LFM2 = 5, SIGLIP2 = 6};
88
+ enum class ModelType {QWEN = 0, GEMMA = 1, SMOL = 2, NOMIC = 3, LFM2 = 5, SIGLIP2 = 6, WHISPER = 7};
72
89
  ModelType model_type = ModelType::QWEN;
73
90
 
74
91
  enum class ModelVariant {DEFAULT = 0, VLM = 1, EXTRACT = 2, RAG = 3};
@@ -139,7 +156,7 @@ public:
139
156
  void set_corpus_dir(const std::string& dir) { corpus_dir_ = dir; }
140
157
 
141
158
  protected:
142
- enum class ModelType { UNKNOWN, QWEN, GEMMA, LFM2, SMOL, BERT };
159
+ enum class ModelType { UNKNOWN, QWEN, GEMMA, LFM2, SMOL, BERT, WHISPER};
143
160
  ModelType model_type_ = ModelType::UNKNOWN;
144
161
  enum class ModelVariant { DEFAULT, VLM, EXTRACT, RAG};
145
162
  ModelVariant model_variant_ = ModelVariant::DEFAULT;
@@ -302,7 +319,7 @@ private:
302
319
  };
303
320
 
304
321
  struct KVCache {
305
- static constexpr size_t DEFAULT_WINDOW_SIZE = 512;
322
+ static constexpr size_t DEFAULT_WINDOW_SIZE = 1024;
306
323
  static constexpr size_t DEFAULT_SINK_SIZE = 4;
307
324
 
308
325
  struct LayerCache {
@@ -365,28 +382,43 @@ public:
365
382
  const std::vector<DebugNode>& get_debug_nodes() const;
366
383
 
367
384
  virtual bool init(const std::string& model_folder, size_t context_size, const std::string& system_prompt = "", bool do_warmup = true);
385
+
368
386
  virtual bool init(CactusGraph* external_graph, const std::string& model_folder, size_t context_size,
369
387
  const std::string& system_prompt = "", bool do_warmup = true);
388
+
370
389
  virtual uint32_t generate(const std::vector<uint32_t>& tokens, float temperature = -1.0f, float top_p = -1.0f,
371
- size_t top_k = 0, const std::string& profile_file = "");
390
+ size_t top_k = 0, const std::string& profile_file = "", bool prefill_only = false);
372
391
 
373
392
  virtual uint32_t generate_with_images(const std::vector<uint32_t>& tokens, const std::vector<std::string>& image_paths,
374
393
  float temperature = -1.0f, float top_p = -1.0f,
375
394
  size_t top_k = 0, const std::string& profile_file = "");
395
+
396
+ virtual uint32_t generate_with_audio(const std::vector<uint32_t>& tokens, const std::vector<float>& mel_bins, float temperature = 0.0f, float top_p = 0.0f,
397
+ size_t top_k = 0, const std::string& profile_file = "");
376
398
 
377
399
  std::vector<float> get_embeddings(const std::vector<uint32_t>& tokens, bool pooled = true, const std::string& profile_file = "");
400
+
401
+ virtual std::vector<float> get_image_embeddings(const std::string& image_path);
402
+
403
+ virtual std::vector<float> get_audio_embeddings(const std::vector<float>& mel_bins);
378
404
 
379
405
  virtual void reset_cache() { kv_cache_.reset(); }
406
+
380
407
  void set_cache_window(size_t window_size, size_t sink_size = 4) { kv_cache_.set_window_size(window_size, sink_size); }
381
408
 
382
409
  void* graph_handle_;
383
410
 
384
411
  protected:
385
412
  virtual size_t forward(const std::vector<uint32_t>& tokens, bool use_cache = false) = 0;
413
+
414
+ virtual size_t forward(const std::vector<float>& mel_bins, const std::vector<uint32_t>& tokens, bool use_cache = false);
415
+
386
416
  virtual void load_weights_to_graph(CactusGraph* gb) = 0;
417
+
387
418
  virtual size_t build_attention(CactusGraph* gb, size_t normalized_input, uint32_t layer_idx,
388
419
  ComputeBackend backend, bool use_cache = false, size_t position_offset = 0) = 0;
389
- virtual size_t build_mlp(CactusGraph* gb, size_t normalized_h, uint32_t layer_idx,
420
+
421
+ virtual size_t build_mlp(CactusGraph* gb, size_t normalized_h, uint32_t layer_idx,
390
422
  ComputeBackend backend) const = 0;
391
423
  virtual size_t build_transformer_block(CactusGraph* gb, size_t hidden, uint32_t layer_idx,
392
424
  ComputeBackend backend, bool use_cache = false, size_t position_offset = 0) = 0;
@@ -8,6 +8,8 @@
8
8
  #include <stdexcept>
9
9
  #include <sstream>
10
10
  #include <iomanip>
11
+ #include <fstream>
12
+ #include <iostream>
11
13
  #include <filesystem>
12
14
  #include <cctype>
13
15
 
@@ -177,8 +179,8 @@ inline void parse_options_json(const std::string& json,
177
179
  float& temperature, float& top_p,
178
180
  size_t& top_k, size_t& max_tokens,
179
181
  std::vector<std::string>& stop_sequences) {
180
- temperature = -1.0f;
181
- top_p = -1.0f;
182
+ temperature = 0.0f;
183
+ top_p = 0.0f;
182
184
  top_k = 0;
183
185
  max_tokens = 100;
184
186
  stop_sequences.clear();
@@ -233,15 +235,14 @@ inline std::string format_tools_for_prompt(const std::vector<ToolFunction>& tool
233
235
  std::string formatted_tools_json;
234
236
  for (size_t i = 0; i < tools.size(); i++) {
235
237
  if (i > 0) formatted_tools_json += ",\n";
236
- formatted_tools_json += " {\n";
237
- formatted_tools_json += " \"type\": \"function\",\n";
238
- formatted_tools_json += " \"function\": {\n";
239
- formatted_tools_json += " \"name\": \"" + tools[i].name + "\",\n";
240
- formatted_tools_json += " \"description\": \"" + tools[i].description + "\"";
238
+ formatted_tools_json += "{\"type\":\"function\",\"function\":{\"name\":\""
239
+ + tools[i].name
240
+ + "\",\"description\":\""
241
+ + tools[i].description + "\"";
241
242
  if (tools[i].parameters.find("schema") != tools[i].parameters.end()) {
242
- formatted_tools_json += ",\n \"parameters\": " + tools[i].parameters.at("schema");
243
+ formatted_tools_json += ",\"parameters\":" + tools[i].parameters.at("schema");
243
244
  }
244
- formatted_tools_json += "\n }\n }";
245
+ formatted_tools_json += "}}";
245
246
  }
246
247
  return formatted_tools_json;
247
248
  }
@@ -32,7 +32,7 @@ enum class OpType {
32
32
  SUM, MEAN, VARIANCE, MIN, MAX,
33
33
  RMS_NORM, ROPE, SOFTMAX, ATTENTION, CONV1D_CAUSAL, CONV1D_K3,
34
34
  SCALAR_ADD, SCALAR_SUBTRACT, SCALAR_MULTIPLY, SCALAR_DIVIDE, SCALAR_EXP, SCALAR_SQRT, SCALAR_COS, SCALAR_SIN,
35
- SILU, GELU,
35
+ SILU, GELU, GELU_ERF,
36
36
  SAMPLE, CONCAT,
37
37
  SCATTER_TOPK,
38
38
  TOPK, LAYERNORM,
@@ -92,32 +92,44 @@ struct TensorConfig {
92
92
  struct BroadcastInfo {
93
93
  std::vector<size_t> output_shape;
94
94
  bool needs_broadcasting;
95
-
95
+
96
96
  static BroadcastInfo compute(const std::vector<size_t>& lhs, const std::vector<size_t>& rhs);
97
97
  };
98
98
 
99
+ class BufferPool;
100
+
99
101
  struct BufferDesc {
100
102
  std::vector<size_t> shape;
101
103
  size_t total_size;
102
104
  size_t byte_size;
103
105
  std::unique_ptr<char[]> data;
104
106
  void* external_data;
107
+ char* pooled_data;
105
108
  Precision precision;
106
109
  float quantization_scale;
107
-
110
+
108
111
  BufferDesc();
109
112
  BufferDesc(const std::vector<size_t>& s, Precision prec = Precision::INT8, float scale = 1.0f);
110
-
113
+ ~BufferDesc();
114
+
115
+ BufferDesc(BufferDesc&& other) noexcept;
116
+ BufferDesc& operator=(BufferDesc&& other) noexcept;
117
+
118
+ BufferDesc(const BufferDesc&) = delete;
119
+ BufferDesc& operator=(const BufferDesc&) = delete;
120
+
111
121
  void* get_data();
112
122
  const void* get_data() const;
113
-
123
+
114
124
  template<typename T>
115
125
  T* data_as() { return static_cast<T*>(get_data()); }
116
-
126
+
117
127
  template<typename T>
118
128
  const T* data_as() const { return static_cast<const T*>(get_data()); }
119
-
129
+
120
130
  void allocate();
131
+ void allocate_from_pool(BufferPool& pool);
132
+ void release_to_pool(BufferPool& pool);
121
133
  void set_external(void* ptr);
122
134
  };
123
135
 
@@ -181,6 +193,33 @@ void compute_topk_node(GraphNode& node, const std::vector<std::unique_ptr<GraphN
181
193
  void compute_layernorm_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
182
194
  void compute_index_node(GraphNode& node, const std::vector<std::unique_ptr<GraphNode>>& nodes, const std::unordered_map<size_t, size_t>& node_index_map);
183
195
 
196
+ void shrink_thread_local_buffers();
197
+
198
+ class BufferPool {
199
+ public:
200
+ BufferPool() = default;
201
+ ~BufferPool() = default;
202
+
203
+ BufferPool(const BufferPool&) = delete;
204
+ BufferPool& operator=(const BufferPool&) = delete;
205
+
206
+ char* acquire(size_t byte_size);
207
+ void release(char* ptr, size_t byte_size);
208
+ void clear();
209
+
210
+ size_t active_bytes() const { return active_bytes_; }
211
+ size_t pool_bytes() const { return pool_bytes_; }
212
+ size_t peak_bytes() const { return peak_bytes_; }
213
+
214
+ private:
215
+ std::unordered_map<size_t, std::vector<std::unique_ptr<char[]>>> free_buffers_;
216
+ size_t active_bytes_ = 0;
217
+ size_t pool_bytes_ = 0;
218
+ size_t peak_bytes_ = 0;
219
+
220
+ size_t round_up_size(size_t size) const;
221
+ };
222
+
184
223
  namespace ValidationUtils {
185
224
  void validate_tensor_dims(const std::vector<size_t>& shape, size_t required_dims, const std::string& op_name);
186
225
  void validate_precision(Precision actual, Precision required, const std::string& op_name);
@@ -219,6 +258,7 @@ public:
219
258
 
220
259
  size_t silu(size_t input);
221
260
  size_t gelu(size_t input);
261
+ size_t gelu_erf(size_t input);
222
262
 
223
263
  size_t matmul(size_t input1, size_t input2, bool pretransposed_rhs = false, ComputeBackend backend = ComputeBackend::CPU);
224
264
  size_t transpose(size_t input, ComputeBackend backend = ComputeBackend::CPU);
@@ -236,6 +276,7 @@ public:
236
276
  size_t gather(size_t embeddings, size_t indices);
237
277
  size_t mmap_embeddings(const std::string& filename);
238
278
  size_t mmap_weights(const std::string& filename);
279
+ size_t load_weights(const std::string& filename);
239
280
  void set_quantization_scale(size_t node_id, float scale);
240
281
  size_t embedding(const std::string& filename, size_t indices);
241
282
  size_t embedding(size_t embedding_tensor, size_t indices);
@@ -284,6 +325,7 @@ private:
284
325
  std::vector<std::unique_ptr<GraphFile::MappedFile>> mapped_files_;
285
326
  std::unordered_map<std::string, size_t> weight_cache_;
286
327
  std::vector<DebugNodeEntry> debug_nodes_;
328
+ BufferPool buffer_pool_;
287
329
  };
288
330
 
289
331
 
@@ -174,6 +174,15 @@ void cactus_gelu_f16(const __fp16* input, __fp16* output, size_t num_elements);
174
174
  void cactus_gelu_int8(const int8_t* input, int8_t* output, size_t num_elements,
175
175
  float input_scale, float output_scale);
176
176
 
177
+ void cactus_gelu_f32_erf(const float* input, float* output, size_t num_elements);
178
+ void cactus_gelu_f16_erf(const __fp16* input, __fp16* output, size_t num_elements);
179
+ void cactus_gelu_int8_erf(
180
+ const int8_t* input,
181
+ int8_t* output,
182
+ size_t num_elements,
183
+ float scale_in,
184
+ float scale_out);
185
+
177
186
 
178
187
  void cactus_attention_int8(const int8_t* queries, const int8_t* keys, const int8_t* values, int8_t* output,
179
188
  size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
@@ -225,6 +234,28 @@ void cactus_conv1d_causal_depthwise_int8(
225
234
  float weight_scale,
226
235
  float output_scale);
227
236
 
237
+ void cactus_conv1d_f32_k3(
238
+ const float* input,
239
+ const float* weight,
240
+ float* output,
241
+ size_t N,
242
+ size_t L,
243
+ size_t C_in,
244
+ size_t C_out,
245
+ size_t stride
246
+ );
247
+
248
+ void cactus_conv1d_f16_k3(
249
+ const __fp16* input,
250
+ const __fp16* weight,
251
+ __fp16* output,
252
+ size_t N,
253
+ size_t L,
254
+ size_t C_in,
255
+ size_t C_out,
256
+ size_t stride
257
+ );
258
+
228
259
  void cactus_conv1d_f32_k3(
229
260
  const float* input,
230
261
  const float* weight,
@@ -33,6 +33,29 @@ export class Database {
33
33
  }
34
34
  return await CactusUtil.registerApp(await response.text());
35
35
  }
36
+ static async getModel(slug) {
37
+ const response = await fetch(`${this.url}/functions/v1/get-models?slug=${slug}&sdk_name=react&sdk_version=${packageVersion}`, {
38
+ headers: {
39
+ apikey: this.key,
40
+ Authorization: `Bearer ${this.key}`
41
+ }
42
+ });
43
+ if (!response.ok) {
44
+ throw new Error('Getting model failed');
45
+ }
46
+ const model = await response.json();
47
+ return {
48
+ name: model.name,
49
+ slug: model.slug,
50
+ quantization: model.quantization,
51
+ sizeMb: model.size_mb,
52
+ downloadUrl: model.download_url,
53
+ supportsToolCalling: model.supports_tool_calling,
54
+ supportsVision: model.supports_vision,
55
+ createdAt: model.created_at,
56
+ isDownloaded: false
57
+ };
58
+ }
36
59
  static async getModels() {
37
60
  const response = await fetch(`${this.url}/functions/v1/get-models?sdk_name=react&sdk_version=${packageVersion}`, {
38
61
  headers: {
@@ -1 +1 @@
1
- {"version":3,"names":["CactusUtil","packageVersion","Database","url","key","sendLogRecords","records","response","fetch","method","headers","body","JSON","stringify","ok","Error","registerDevice","device_data","registerApp","text","getModels","apikey","Authorization","models","json","map","model","name","slug","quantization","sizeMb","size_mb","downloadUrl","download_url","supportsToolCalling","supports_tool_calling","supportsVision","supports_vision","createdAt","created_at","isDownloaded"],"sourceRoot":"../../../src","sources":["api/Database.ts"],"mappings":";;AAAA,SAASA,UAAU,QAAQ,oBAAW;AAGtC,SAASC,cAAc,QAAQ,gCAA6B;AAc5D,OAAO,MAAMC,QAAQ,CAAC;EACpB,OAAwBC,GAAG,GAAG,0CAA0C;EACxE,OAAwBC,GAAG,GACzB,kNAAkN;EAEpN,aAAoBC,cAAcA,CAACC,OAAoB,EAAiB;IACtE,MAAMC,QAAQ,GAAG,MAAMC,KAAK,CAAC,GAAG,IAAI,CAACL,GAAG,eAAe,EAAE;MACvDM,MAAM,EAAE,MAAM;MACdC,OAAO,EAAE;QACP,QAAQ,EAAE,IAAI,CAACN,GAAG;QAClB,eAAe,EAAE,UAAU,IAAI,CAACA,GAAG,EAAE;QACrC,cAAc,EAAE,kBAAkB;QAClC,iBAAiB,EAAE,QAAQ;QAC3B,QAAQ,EAAE;MACZ,CAAC;MACDO,IAAI,EAAEC,IAAI,CAACC,SAAS,CAACP,OAAO;IAC9B,CAAC,CAAC;IAEF,IAAI,CAACC,QAAQ,CAACO,EAAE,EAAE;MAChB,MAAM,IAAIC,KAAK,CAAC,qBAAqB,CAAC;IACxC;EACF;EAEA,aAAoBC,cAAcA,CAACC,WAAuB,EAAmB;IAC3E,MAAMV,QAAQ,GAAG,MAAMC,KAAK,CAC1B,GAAG,IAAI,CAACL,GAAG,mCAAmC,EAC9C;MACEM,MAAM,EAAE,MAAM;MACdE,IAAI,EAAEC,IAAI,CAACC,SAAS,CAAC;QAAEI;MAAY,CAAC;IACtC,CACF,CAAC;IAED,IAAI,CAACV,QAAQ,CAACO,EAAE,EAAE;MAChB,MAAM,IAAIC,KAAK,CAAC,2BAA2B,CAAC;IAC9C;IAEA,OAAO,MAAMf,UAAU,CAACkB,WAAW,CAAC,MAAMX,QAAQ,CAACY,IAAI,CAAC,CAAC,CAAC;EAC5D;EAEA,aAAoBC,SAASA,CAAA,EAA2B;IACtD,MAAMb,QAAQ,GAAG,MAAMC,KAAK,CAC1B,GAAG,IAAI,CAACL,GAAG,uDAAuDF,cAAc,EAAE,EAClF;MACES,OAAO,EAAE;QAAEW,MAAM,EAAE,IAAI,CAACjB,GAAG;QAAEkB,aAAa,EAAE,UAAU,IAAI,CAAClB,GAAG;MAAG;IACnE,CACF,CAAC;IAED,IAAI,CAACG,QAAQ,CAACO,EAAE,EAAE;MAChB,MAAM,IAAIC,KAAK,CAAC,uBAAuB,CAAC;IAC1C;IAEA,MAAMQ,MAAM,GAAI,MAAMhB,QAAQ,CAACiB,IAAI,CAAC,CAA2B;IAE/D,OAAOD,MAAM,CAACE,GAAG,CAAEC,KAAK,KAAM;MAC5BC,IAAI,EAAED,KAAK,CAACC,IAAI;MAChBC,IAAI,EAAEF,KAAK,CAACE,IAAI;MAChBC,YAAY,EA
AEH,KAAK,CAACG,YAAY;MAChCC,MAAM,EAAEJ,KAAK,CAACK,OAAO;MACrBC,WAAW,EAAEN,KAAK,CAACO,YAAY;MAC/BC,mBAAmB,EAAER,KAAK,CAACS,qBAAqB;MAChDC,cAAc,EAAEV,KAAK,CAACW,eAAe;MACrCC,SAAS,EAAEZ,KAAK,CAACa,UAAU;MAC3BC,YAAY,EAAE;IAChB,CAAC,CAAC,CAAC;EACL;AACF","ignoreList":[]}
1
+ {"version":3,"names":["CactusUtil","packageVersion","Database","url","key","sendLogRecords","records","response","fetch","method","headers","body","JSON","stringify","ok","Error","registerDevice","device_data","registerApp","text","getModel","slug","apikey","Authorization","model","json","name","quantization","sizeMb","size_mb","downloadUrl","download_url","supportsToolCalling","supports_tool_calling","supportsVision","supports_vision","createdAt","created_at","isDownloaded","getModels","models","map"],"sourceRoot":"../../../src","sources":["api/Database.ts"],"mappings":";;AAAA,SAASA,UAAU,QAAQ,oBAAW;AAGtC,SAASC,cAAc,QAAQ,gCAA6B;AAc5D,OAAO,MAAMC,QAAQ,CAAC;EACpB,OAAwBC,GAAG,GAAG,0CAA0C;EACxE,OAAwBC,GAAG,GACzB,kNAAkN;EAEpN,aAAoBC,cAAcA,CAACC,OAAoB,EAAiB;IACtE,MAAMC,QAAQ,GAAG,MAAMC,KAAK,CAAC,GAAG,IAAI,CAACL,GAAG,eAAe,EAAE;MACvDM,MAAM,EAAE,MAAM;MACdC,OAAO,EAAE;QACP,QAAQ,EAAE,IAAI,CAACN,GAAG;QAClB,eAAe,EAAE,UAAU,IAAI,CAACA,GAAG,EAAE;QACrC,cAAc,EAAE,kBAAkB;QAClC,iBAAiB,EAAE,QAAQ;QAC3B,QAAQ,EAAE;MACZ,CAAC;MACDO,IAAI,EAAEC,IAAI,CAACC,SAAS,CAACP,OAAO;IAC9B,CAAC,CAAC;IAEF,IAAI,CAACC,QAAQ,CAACO,EAAE,EAAE;MAChB,MAAM,IAAIC,KAAK,CAAC,qBAAqB,CAAC;IACxC;EACF;EAEA,aAAoBC,cAAcA,CAACC,WAAuB,EAAmB;IAC3E,MAAMV,QAAQ,GAAG,MAAMC,KAAK,CAC1B,GAAG,IAAI,CAACL,GAAG,mCAAmC,EAC9C;MACEM,MAAM,EAAE,MAAM;MACdE,IAAI,EAAEC,IAAI,CAACC,SAAS,CAAC;QAAEI;MAAY,CAAC;IACtC,CACF,CAAC;IAED,IAAI,CAACV,QAAQ,CAACO,EAAE,EAAE;MAChB,MAAM,IAAIC,KAAK,CAAC,2BAA2B,CAAC;IAC9C;IAEA,OAAO,MAAMf,UAAU,CAACkB,WAAW,CAAC,MAAMX,QAAQ,CAACY,IAAI,CAAC,CAAC,CAAC;EAC5D;EAEA,aAAoBC,QAAQA,CAACC,IAAY,EAAwB;IAC/D,MAAMd,QAAQ,GAAG,MAAMC,KAAK,CAC1B,GAAG,IAAI,CAACL,GAAG,iCAAiCkB,IAAI,+BAA+BpB,cAAc,EAAE,EAC/F;MACES,OAAO,EAAE;QAAEY,MAAM,EAAE,IAAI,CAAClB,GAAG;QAAEmB,aAAa,EAAE,UAAU,IAAI,CAACnB,GAAG;MAAG;IACnE,CACF,CAAC;IAED,IAAI,CAACG,QAAQ,CAACO,EAAE,EAAE;MAChB,MAAM,IAAIC,KAAK,CAAC,sBAAsB,CAAC;IACzC;IAEA,MAAMS,KAAK,GAAI,MAAMjB,QAAQ,CAACkB,IAAI,CAAC,CAAyB;IAE5D,OAAO;MACLC,IAAI,EAAEF,KAAK,CAACE,IAAI;MAChBL,IAAI,EAAEG,KAAK,CAACH,IAAI;MAChBM,YAAY,EAAE
H,KAAK,CAACG,YAAY;MAChCC,MAAM,EAAEJ,KAAK,CAACK,OAAO;MACrBC,WAAW,EAAEN,KAAK,CAACO,YAAY;MAC/BC,mBAAmB,EAAER,KAAK,CAACS,qBAAqB;MAChDC,cAAc,EAAEV,KAAK,CAACW,eAAe;MACrCC,SAAS,EAAEZ,KAAK,CAACa,UAAU;MAC3BC,YAAY,EAAE;IAChB,CAAC;EACH;EAEA,aAAoBC,SAASA,CAAA,EAA2B;IACtD,MAAMhC,QAAQ,GAAG,MAAMC,KAAK,CAC1B,GAAG,IAAI,CAACL,GAAG,uDAAuDF,cAAc,EAAE,EAClF;MACES,OAAO,EAAE;QAAEY,MAAM,EAAE,IAAI,CAAClB,GAAG;QAAEmB,aAAa,EAAE,UAAU,IAAI,CAACnB,GAAG;MAAG;IACnE,CACF,CAAC;IAED,IAAI,CAACG,QAAQ,CAACO,EAAE,EAAE;MAChB,MAAM,IAAIC,KAAK,CAAC,uBAAuB,CAAC;IAC1C;IAEA,MAAMyB,MAAM,GAAI,MAAMjC,QAAQ,CAACkB,IAAI,CAAC,CAA2B;IAE/D,OAAOe,MAAM,CAACC,GAAG,CAAEjB,KAAK,KAAM;MAC5BE,IAAI,EAAEF,KAAK,CAACE,IAAI;MAChBL,IAAI,EAAEG,KAAK,CAACH,IAAI;MAChBM,YAAY,EAAEH,KAAK,CAACG,YAAY;MAChCC,MAAM,EAAEJ,KAAK,CAACK,OAAO;MACrBC,WAAW,EAAEN,KAAK,CAACO,YAAY;MAC/BC,mBAAmB,EAAER,KAAK,CAACS,qBAAqB;MAChDC,cAAc,EAAEV,KAAK,CAACW,eAAe;MACrCC,SAAS,EAAEZ,KAAK,CAACa,UAAU;MAC3BC,YAAY,EAAE;IAChB,CAAC,CAAC,CAAC;EACL;AACF","ignoreList":[]}