cactus-react-native 0.2.2 → 0.2.4

Files changed (80)
  1. package/README.md +1 -1
  2. package/android/src/main/java/com/cactus/Cactus.java +35 -0
  3. package/android/src/main/java/com/cactus/LlamaContext.java +18 -1
  4. package/android/src/main/jni.cpp +11 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libcactus.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libcactus_v8.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libcactus_v8_2.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libcactus_v8_2_dotprod.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/libcactus_v8_2_dotprod_i8mm.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/libcactus_v8_2_i8mm.so +0 -0
  11. package/android/src/newarch/java/com/cactus/CactusModule.java +5 -0
  12. package/android/src/oldarch/java/com/cactus/CactusModule.java +5 -0
  13. package/ios/Cactus.mm +21 -0
  14. package/ios/CactusContext.h +1 -0
  15. package/ios/CactusContext.mm +4 -0
  16. package/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h +0 -12
  17. package/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus +0 -0
  18. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/Headers/cactus_ffi.h +0 -12
  19. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/cactus +0 -0
  20. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/Headers/cactus_ffi.h +0 -12
  21. package/ios/cactus.xcframework/tvos-arm64/cactus.framework/cactus +0 -0
  22. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/Headers/cactus_ffi.h +0 -12
  23. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/cactus +0 -0
  24. package/lib/commonjs/NativeCactus.js +0 -1
  25. package/lib/commonjs/NativeCactus.js.map +1 -1
  26. package/lib/commonjs/chat.js +33 -0
  27. package/lib/commonjs/chat.js.map +1 -1
  28. package/lib/commonjs/index.js +0 -23
  29. package/lib/commonjs/index.js.map +1 -1
  30. package/lib/commonjs/lm.js +69 -32
  31. package/lib/commonjs/lm.js.map +1 -1
  32. package/lib/commonjs/tools.js +0 -7
  33. package/lib/commonjs/tools.js.map +1 -1
  34. package/lib/commonjs/tts.js +1 -4
  35. package/lib/commonjs/tts.js.map +1 -1
  36. package/lib/commonjs/vlm.js +25 -7
  37. package/lib/commonjs/vlm.js.map +1 -1
  38. package/lib/module/NativeCactus.js +0 -3
  39. package/lib/module/NativeCactus.js.map +1 -1
  40. package/lib/module/chat.js +31 -0
  41. package/lib/module/chat.js.map +1 -1
  42. package/lib/module/index.js +1 -10
  43. package/lib/module/index.js.map +1 -1
  44. package/lib/module/lm.js +70 -32
  45. package/lib/module/lm.js.map +1 -1
  46. package/lib/module/tools.js +0 -7
  47. package/lib/module/tools.js.map +1 -1
  48. package/lib/module/tts.js +1 -4
  49. package/lib/module/tts.js.map +1 -1
  50. package/lib/module/vlm.js +25 -7
  51. package/lib/module/vlm.js.map +1 -1
  52. package/lib/typescript/NativeCactus.d.ts +1 -142
  53. package/lib/typescript/NativeCactus.d.ts.map +1 -1
  54. package/lib/typescript/chat.d.ts +10 -0
  55. package/lib/typescript/chat.d.ts.map +1 -1
  56. package/lib/typescript/index.d.ts +2 -4
  57. package/lib/typescript/index.d.ts.map +1 -1
  58. package/lib/typescript/lm.d.ts +13 -7
  59. package/lib/typescript/lm.d.ts.map +1 -1
  60. package/lib/typescript/tools.d.ts.map +1 -1
  61. package/lib/typescript/tts.d.ts.map +1 -1
  62. package/lib/typescript/vlm.d.ts +3 -1
  63. package/lib/typescript/vlm.d.ts.map +1 -1
  64. package/package.json +1 -1
  65. package/src/NativeCactus.ts +6 -175
  66. package/src/chat.ts +42 -1
  67. package/src/index.ts +6 -17
  68. package/src/lm.ts +81 -26
  69. package/src/tools.ts +0 -5
  70. package/src/tts.ts +1 -4
  71. package/src/vlm.ts +35 -13
  72. package/android/src/main/jniLibs/x86_64/libcactus.so +0 -0
  73. package/android/src/main/jniLibs/x86_64/libcactus_x86_64.so +0 -0
  74. package/lib/commonjs/grammar.js +0 -560
  75. package/lib/commonjs/grammar.js.map +0 -1
  76. package/lib/module/grammar.js +0 -553
  77. package/lib/module/grammar.js.map +0 -1
  78. package/lib/typescript/grammar.d.ts +0 -37
  79. package/lib/typescript/grammar.d.ts.map +0 -1
  80. package/src/grammar.ts +0 -854
@@ -7,68 +7,28 @@ export type NativeEmbeddingParams = {
 
 export type NativeContextParams = {
   model: string
-  /**
-   * Chat template to override the default one from the model.
-   */
   chat_template?: string
-
   reasoning_format?: string
-
   is_model_asset?: boolean
   use_progress_callback?: boolean
-
   n_ctx?: number
   n_batch?: number
   n_ubatch?: number
-
   n_threads?: number
-
-  /**
-   * Number of layers to store in VRAM (Currently only for iOS)
-   */
   n_gpu_layers?: number
-  /**
-   * Skip GPU devices (iOS only)
-   */
   no_gpu_devices?: boolean
-
-  /**
-   * Enable flash attention, only recommended in GPU device (Experimental in llama.cpp)
-   */
   flash_attn?: boolean
-
-  /**
-   * KV cache data type for the K (Experimental in llama.cpp)
-   */
   cache_type_k?: string
-  /**
-   * KV cache data type for the V (Experimental in llama.cpp)
-   */
   cache_type_v?: string
-
   use_mlock?: boolean
   use_mmap?: boolean
   vocab_only?: boolean
-
-  /**
-   * Single LoRA adapter path
-   */
   lora?: string
-  /**
-   * Single LoRA adapter scale
-   */
   lora_scaled?: number
-  /**
-   * LoRA adapter list
-   */
   lora_list?: Array<{ path: string; scaled?: number }>
-
   rope_freq_base?: number
   rope_freq_scale?: number
-
   pooling_type?: number
-
-  // Embedding params
   embedding?: boolean
   embd_normalize?: number
 }
@@ -76,22 +36,9 @@ export type NativeContextParams = {
 export type NativeCompletionParams = {
   prompt: string
   n_threads?: number
-  /**
-   * JSON schema for convert to grammar for structured JSON output.
-   * It will be override by grammar if both are set.
-   */
   json_schema?: string
-  /**
-   * Set grammar for grammar-based sampling. Default: no grammar
-   */
   grammar?: string
-  /**
-   * Lazy grammar sampling, trigger by grammar_triggers. Default: false
-   */
   grammar_lazy?: boolean
-  /**
-   * Lazy grammar triggers. Default: []
-   */
   grammar_triggers?: Array<{
     type: number
     value: string
@@ -99,121 +46,32 @@ export type NativeCompletionParams = {
   }>
   preserved_tokens?: Array<string>
   chat_format?: number
-  /**
-   * Specify a JSON array of stopping strings.
-   * These words will not be included in the completion, so make sure to add them to the prompt for the next iteration. Default: `[]`
-   */
   stop?: Array<string>
-  /**
-   * Set the maximum number of tokens to predict when generating text.
-   * **Note:** May exceed the set limit slightly if the last token is a partial multibyte character.
-   * When 0,no tokens will be generated but the prompt is evaluated into the cache. Default: `-1`, where `-1` is infinity.
-   */
   n_predict?: number
-  /**
-   * If greater than 0, the response also contains the probabilities of top N tokens for each generated token given the sampling settings.
-   * Note that for temperature < 0 the tokens are sampled greedily but token probabilities are still being calculated via a simple softmax of the logits without considering any other sampler settings.
-   * Default: `0`
-   */
   n_probs?: number
-  /**
-   * Limit the next token selection to the K most probable tokens. Default: `40`
-   */
   top_k?: number
-  /**
-   * Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P. Default: `0.95`
-   */
   top_p?: number
-  /**
-   * The minimum probability for a token to be considered, relative to the probability of the most likely token. Default: `0.05`
-   */
   min_p?: number
-  /**
-   * Set the chance for token removal via XTC sampler. Default: `0.0`, which is disabled.
-   */
   xtc_probability?: number
-  /**
-   * Set a minimum probability threshold for tokens to be removed via XTC sampler. Default: `0.1` (> `0.5` disables XTC)
-   */
   xtc_threshold?: number
-  /**
-   * Enable locally typical sampling with parameter p. Default: `1.0`, which is disabled.
-   */
   typical_p?: number
-  /**
-   * Adjust the randomness of the generated text. Default: `0.8`
-   */
   temperature?: number
-  /**
-   * Last n tokens to consider for penalizing repetition. Default: `64`, where `0` is disabled and `-1` is ctx-size.
-   */
   penalty_last_n?: number
-  /**
-   * Control the repetition of token sequences in the generated text. Default: `1.0`
-   */
   penalty_repeat?: number
-  /**
-   * Repeat alpha frequency penalty. Default: `0.0`, which is disabled.
-   */
   penalty_freq?: number
-  /**
-   * Repeat alpha presence penalty. Default: `0.0`, which is disabled.
-   */
   penalty_present?: number
-  /**
-   * Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0.
-   */
   mirostat?: number
-  /**
-   * Set the Mirostat target entropy, parameter tau. Default: `5.0`
-   */
   mirostat_tau?: number
-  /**
-   * Set the Mirostat learning rate, parameter eta. Default: `0.1`
-   */
   mirostat_eta?: number
-  /**
-   * Set the DRY (Don't Repeat Yourself) repetition penalty multiplier. Default: `0.0`, which is disabled.
-   */
   dry_multiplier?: number
-  /**
-   * Set the DRY repetition penalty base value. Default: `1.75`
-   */
   dry_base?: number
-  /**
-   * Tokens that extend repetition beyond this receive exponentially increasing penalty: multiplier * base ^ (length of repeating sequence before token - allowed length). Default: `2`
-   */
   dry_allowed_length?: number
-  /**
-   * How many tokens to scan for repetitions. Default: `-1`, where `0` is disabled and `-1` is context size.
-   */
   dry_penalty_last_n?: number
-  /**
-   * Specify an array of sequence breakers for DRY sampling. Only a JSON array of strings is accepted. Default: `['\n', ':', '"', '*']`
-   */
   dry_sequence_breakers?: Array<string>
-  /**
-   * Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641. Default: `-1.0` (Disabled)
-   */
   top_n_sigma?: number
-
-  /**
-   * Ignore end of stream token and continue generating. Default: `false`
-   */
   ignore_eos?: boolean
-  /**
-   * Modify the likelihood of a token appearing in the generated text completion.
-   * For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood.
-   * Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings,
-   * e.g.`[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does.
-   * Default: `[]`
-   */
   logit_bias?: Array<Array<number>>
-  /**
-   * Set the random number generator (RNG) seed. Default: `-1`, which is a random seed.
-   */
   seed?: number
-
   emit_partial_completion: boolean
 }
 
@@ -239,17 +97,8 @@ export type NativeCompletionResultTimings = {
 }
 
 export type NativeCompletionResult = {
-  /**
-   * Original text (Ignored reasoning_content / tool_calls)
-   */
   text: string
-  /**
-   * Reasoning content (parsed for reasoning model)
-   */
   reasoning_content: string
-  /**
-   * Tool calls
-   */
   tool_calls: Array<{
     type: 'function'
     function: {
@@ -258,11 +107,7 @@ export type NativeCompletionResult = {
     }
     id?: string
   }>
-  /**
-   * Content text (Filtered text by reasoning_content / tool_calls)
-   */
   content: string
-
   tokens_predicted: number
   tokens_evaluated: number
   truncated: boolean
@@ -272,13 +117,11 @@ export type NativeCompletionResult = {
   stopping_word: string
   tokens_cached: number
   timings: NativeCompletionResultTimings
-
   completion_probabilities?: Array<NativeCompletionTokenProb>
 }
 
 export type NativeTokenizeResult = {
   tokens: Array<number>
-  // New multimodal support
   has_media?: boolean
   bitmap_hashes?: Array<string>
   chunk_pos?: Array<number>
@@ -289,9 +132,8 @@ export type NativeEmbeddingResult = {
   embedding: Array<number>
 }
 
-// New TTS/Audio types
 export type NativeTTSType = {
-  type: number // TTS_UNKNOWN = -1, TTS_OUTETTS_V0_2 = 1, TTS_OUTETTS_V0_3 = 2
+  type: number
 }
 
 export type NativeAudioCompletionResult = {
@@ -303,7 +145,7 @@ export type NativeAudioTokensResult = {
 }
 
 export type NativeAudioDecodeResult = {
-  audio_data: Array<number> // Float array of audio samples
+  audio_data: Array<number>
 }
 
 export type NativeDeviceInfo = {
@@ -312,6 +154,7 @@ export type NativeDeviceInfo = {
   make: string
   os: string
 }
+
 export type NativeLlamaContext = {
   contextId: number
   model: {
@@ -320,9 +163,8 @@ export type NativeLlamaContext = {
     nEmbd: number
     nParams: number
     chatTemplates: {
-      llamaChat: boolean // Chat template in llama-chat.cpp
+      llamaChat: boolean
       minja: {
-        // Chat template supported by minja.hpp
        default: boolean
        defaultCaps: {
          tools: boolean
@@ -344,11 +186,8 @@ export type NativeLlamaContext = {
      }
    }
    metadata: Object
-    isChatTemplateSupported: boolean // Deprecated
+    isChatTemplateSupported: boolean
  }
-  /**
-   * Loaded library name for Android
-   */
  androidLib?: string
  gpu: boolean
  reasonNoGPU: string
@@ -381,13 +220,11 @@ export type JinjaFormattedChatResult = {
 export interface Spec extends TurboModule {
   toggleNativeLog(enabled: boolean): Promise<void>
   setContextLimit(limit: number): Promise<void>
-
   modelInfo(path: string, skip?: string[]): Promise<Object>
   initContext(
     contextId: number,
     params: NativeContextParams,
   ): Promise<NativeLlamaContext>
-
   getFormattedChat(
     contextId: number,
     messages: string,
@@ -434,7 +271,6 @@ export interface Spec extends TurboModule {
     pl: number,
     nr: number,
   ): Promise<string>
-
   applyLoraAdapters(
     contextId: number,
     loraAdapters: Array<{ path: string; scaled?: number }>,
@@ -443,8 +279,6 @@ export interface Spec extends TurboModule {
   getLoadedLoraAdapters(
     contextId: number,
   ): Promise<Array<{ path: string; scaled?: number }>>
-
-  // New Multimodal Methods
   initMultimodal(
     contextId: number,
     mmprojPath: string,
@@ -454,8 +288,6 @@ export interface Spec extends TurboModule {
   isMultimodalSupportVision(contextId: number): Promise<boolean>
   isMultimodalSupportAudio(contextId: number): Promise<boolean>
   releaseMultimodal(contextId: number): Promise<void>
-
-  // New TTS/Vocoder Methods
   initVocoder(
     contextId: number,
     vocoderModelPath: string,
@@ -477,9 +309,8 @@ export interface Spec extends TurboModule {
   ): Promise<NativeAudioDecodeResult>
   getDeviceInfo(contextId: number): Promise<NativeDeviceInfo>
   releaseVocoder(contextId: number): Promise<void>
-
+  rewind(contextId: number): Promise<void>
   releaseContext(contextId: number): Promise<void>
-
   releaseAllContexts(): Promise<void>
 }
 
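Note: nearly everything removed in these hunks is inline JSDoc on NativeContextParams and NativeCompletionParams; the fields themselves are unchanged, and the notable functional addition is rewind(contextId) on the Spec interface. Below is a minimal, illustrative sketch of the surviving sampling fields, with defaults quoted from the deleted doc comments; it assumes NativeCompletionParams is still re-exported from the package root, which this diff does not show.

import type { NativeCompletionParams } from 'cactus-react-native'

// Defaults per the removed JSDoc: n_predict -1 (infinite), temperature 0.8,
// top_k 40, top_p 0.95, min_p 0.05, penalty_last_n 64, seed -1 (random).
const sampling: Partial<NativeCompletionParams> = {
  prompt: 'Hello',                // placeholder prompt
  n_predict: 128,
  temperature: 0.8,
  top_k: 40,
  top_p: 0.95,
  stop: ['</s>'],                 // illustrative stop string
  emit_partial_completion: false,
}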
package/src/chat.ts CHANGED
@@ -6,7 +6,7 @@ export type CactusMessagePart = {
 
 export type CactusOAICompatibleMessage = {
   role: string
-  content?: string | CactusMessagePart[] | any // any for check invalid content type
+  content?: string | CactusMessagePart[] | any
 }
 
 export function formatChat(
@@ -42,3 +42,44 @@ export function formatChat(
   })
   return chat
 }
+
+export interface ProcessedMessages {
+  newMessages: CactusOAICompatibleMessage[];
+  requiresReset: boolean;
+}
+
+export class ConversationHistoryManager {
+  private history: CactusOAICompatibleMessage[] = [];
+
+  public processNewMessages(
+    fullMessageHistory: CactusOAICompatibleMessage[]
+  ): ProcessedMessages {
+    let divergent = fullMessageHistory.length < this.history.length;
+    if (!divergent) {
+      for (let i = 0; i < this.history.length; i++) {
+        if (JSON.stringify(this.history[i]) !== JSON.stringify(fullMessageHistory[i])) {
+          divergent = true;
+          break;
+        }
+      }
+    }
+
+    if (divergent) {
+      return { newMessages: fullMessageHistory, requiresReset: true };
+    }
+
+    const newMessages = fullMessageHistory.slice(this.history.length);
+    return { newMessages, requiresReset: false };
+  }
+
+  public update(
+    newMessages: CactusOAICompatibleMessage[],
+    assistantResponse: CactusOAICompatibleMessage
+  ) {
+    this.history.push(...newMessages, assistantResponse);
+  }
+
+  public reset() {
+    this.history = [];
+  }
+}
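The new ConversationHistoryManager compares the full incoming message history against what it has already recorded and returns only the unseen suffix, or flags a reset when the two diverge. A short usage sketch (imported here via the in-package chat module; a root re-export is not shown in this diff, and CactusLM wires this up internally):

import { ConversationHistoryManager } from './chat'

const history = new ConversationHistoryManager()

// First turn: nothing recorded yet, so the whole history counts as new.
const turn1 = [{ role: 'user', content: 'Hi' }]
const { newMessages, requiresReset } = history.processNewMessages(turn1)
// newMessages === turn1, requiresReset === false

history.update(newMessages, { role: 'assistant', content: 'Hello!' })

// A history that no longer matches the recorded prefix flags a reset, which
// CactusLM translates into a native rewind() followed by history.reset().
const edited = history.processNewMessages([{ role: 'user', content: 'Hi there' }])
// edited.requiresReset === true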
package/src/index.ts CHANGED
@@ -1,6 +1,7 @@
 import { NativeEventEmitter, DeviceEventEmitter, Platform } from 'react-native'
 import type { DeviceEventEmitterStatic } from 'react-native'
 import Cactus from './NativeCactus'
+
 import type {
   NativeContextParams,
   NativeLlamaContext,
@@ -20,15 +21,13 @@ import type {
   NativeAudioDecodeResult,
   NativeDeviceInfo,
 } from './NativeCactus'
-import type {
-  SchemaGrammarConverterPropOrder,
-  SchemaGrammarConverterBuiltinRule,
-} from './grammar'
-import { SchemaGrammarConverter, convertJsonSchemaToGrammar } from './grammar'
+
+
 import type { CactusMessagePart, CactusOAICompatibleMessage } from './chat'
 import { formatChat } from './chat'
 import { Tools, parseAndExecuteTool } from './tools'
 import { Telemetry, type TelemetryParams } from './telemetry'
+
 export type {
   NativeContextParams,
   NativeLlamaContext,
@@ -45,13 +44,9 @@ export type {
   CactusOAICompatibleMessage,
   JinjaFormattedChatResult,
   NativeAudioDecodeResult,
-
-  // Deprecated
-  SchemaGrammarConverterPropOrder,
-  SchemaGrammarConverterBuiltinRule,
 }
 
-export { SchemaGrammarConverter, convertJsonSchemaToGrammar, Tools }
+export {Tools }
 export * from './remote'
 
 const EVENT_ON_INIT_CONTEXT_PROGRESS = '@Cactus_onInitContextProgress'
@@ -254,7 +249,6 @@ export class LlamaContext {
       return this.completion(params, callback);
     }
     if (recursionCount >= recursionLimit) {
-      // console.log(`Recursion limit reached (${recursionCount}/${recursionLimit}), returning default completion`)
       return this.completion({
         ...params,
         jinja: true,
@@ -264,14 +258,12 @@ export class LlamaContext {
 
     const messages = [...params.messages]; // avoid mutating the original messages
 
-    // console.log('Calling completion...')
     const result = await this.completion({
       ...params,
       messages: messages,
       jinja: true,
       tools: params.tools.getSchemas()
     }, callback);
-    // console.log('Completion result:', result);
 
     const {toolCalled, toolName, toolInput, toolOutput} =
       await parseAndExecuteTool(result, params.tools);
@@ -294,8 +286,6 @@ export class LlamaContext {
 
     messages.push(toolMessage);
 
-    // console.log('Messages being sent to next completion:', JSON.stringify(messages, null, 2));
-
     return await this.completionWithTools(
       {...params, messages: messages},
       callback,
@@ -471,8 +461,7 @@ export class LlamaContext {
   }
 
   async rewind(): Promise<void> {
-    // @ts-ignore
-    return (Cactus as any).rewind(this.id)
+    return Cactus.rewind(this.id)
   }
 }
 
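Two user-visible changes in this file: the grammar helpers (SchemaGrammarConverter, convertJsonSchemaToGrammar) are no longer exported, and LlamaContext.rewind() now calls the typed Cactus.rewind native method directly. A hedged migration sketch; it assumes the wrapper's completion params still forward json_schema to the native layer, where that field remains defined:

import type { LlamaContext } from 'cactus-react-native'

declare const context: LlamaContext // obtained from initLlama(...) elsewhere

// Structured output without the removed grammar helpers: pass the JSON schema
// through json_schema instead of pre-converting it to a GBNF grammar.
const result = await context.completion({
  messages: [{ role: 'user', content: 'List three fruits as a JSON array of strings.' }],
  json_schema: JSON.stringify({ type: 'array', items: { type: 'string' } }),
})

// New in 0.2.4: drop the cached conversation state without releasing the context.
await context.rewind()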
package/src/lm.ts CHANGED
@@ -9,8 +9,10 @@ import type {
   EmbeddingParams,
   NativeEmbeddingResult,
 } from './index'
+
 import { Telemetry } from './telemetry'
 import { setCactusToken, getVertexAIEmbedding } from './remote'
+import { ConversationHistoryManager } from './chat'
 
 interface CactusLMReturn {
   lm: CactusLM | null
@@ -18,51 +20,105 @@ interface CactusLMReturn {
 }
 
 export class CactusLM {
-  private context: LlamaContext
+  protected context: LlamaContext
+  protected conversationHistoryManager: ConversationHistoryManager
 
-  private constructor(context: LlamaContext) {
+  protected constructor(context: LlamaContext) {
     this.context = context
+    this.conversationHistoryManager = new ConversationHistoryManager()
   }
 
   static async init(
     params: ContextParams,
     onProgress?: (progress: number) => void,
     cactusToken?: string,
+    retryOptions?: { maxRetries?: number; delayMs?: number },
   ): Promise<CactusLMReturn> {
     if (cactusToken) {
       setCactusToken(cactusToken);
     }
 
-    // Avoid two back-to-back loads on devices where GPU off-load is unsupported (Android).
-    const needGpuAttempt = Platform.OS !== 'android' && (params.n_gpu_layers ?? 0) > 0
-    const configs = needGpuAttempt
-      ? [params, { ...params, n_gpu_layers: 0 }]
-      : [{ ...params, n_gpu_layers: 0 }]
+    const maxRetries = retryOptions?.maxRetries ?? 3;
+    const delayMs = retryOptions?.delayMs ?? 1000;
+
+    const configs = [
+      params,
+      { ...params, n_gpu_layers: 0 }
+    ];
+
+    const sleep = (ms: number): Promise<void> => {
+      return new Promise(resolve => {
+        const start = Date.now();
+        const wait = () => {
+          if (Date.now() - start >= ms) {
+            resolve();
+          } else {
+            Promise.resolve().then(wait);
+          }
+        };
+        wait();
+      });
+    };
 
     for (const config of configs) {
-      try {
-        const context = await initLlama(config, onProgress);
-        return { lm: new CactusLM(context), error: null };
-      } catch (e) {
-        Telemetry.error(e as Error, {
-          n_gpu_layers: config.n_gpu_layers ?? null,
-          n_ctx: config.n_ctx ?? null,
-          model: config.model ?? null,
-        });
-        if (configs.indexOf(config) === configs.length - 1) {
-          return { lm: null, error: e as Error };
+      let lastError: Error | null = null;
+
+      for (let attempt = 1; attempt <= maxRetries; attempt++) {
+        try {
+          const context = await initLlama(config, onProgress);
+          return { lm: new CactusLM(context), error: null };
+        } catch (e) {
+          lastError = e as Error;
+          const isLastConfig = configs.indexOf(config) === configs.length - 1;
+          const isLastAttempt = attempt === maxRetries;
+
+          Telemetry.error(e as Error, {
+            n_gpu_layers: config.n_gpu_layers ?? null,
+            n_ctx: config.n_ctx ?? null,
+            model: config.model ?? null,
+          });
+
+          if (!isLastAttempt) {
+            const delay = delayMs * Math.pow(2, attempt - 1);
+            await sleep(delay);
+          } else if (!isLastConfig) {
+            break;
+          }
         }
       }
+
+      if (configs.indexOf(config) === configs.length - 1 && lastError) {
+        return { lm: null, error: lastError };
+      }
     }
-    return { lm: null, error: new Error('Failed to initialize CactusLM') };
+    return { lm: null, error: new Error('Failed to initialize CactusLM after all retries') };
   }
 
-  async completion(
+  completion = async (
     messages: CactusOAICompatibleMessage[],
     params: CompletionParams = {},
     callback?: (data: any) => void,
-  ): Promise<NativeCompletionResult> {
-    return await this.context.completion({ messages, ...params }, callback);
+  ): Promise<NativeCompletionResult> => {
+    const { newMessages, requiresReset } =
+      this.conversationHistoryManager.processNewMessages(messages);
+
+    if (requiresReset) {
+      this.context?.rewind();
+      this.conversationHistoryManager.reset();
+    }
+
+    if (newMessages.length === 0) {
+      console.warn('No messages to complete!');
+    }
+
+    const result = await this.context.completion({ messages: newMessages, ...params }, callback);
+
+    this.conversationHistoryManager.update(newMessages, {
+      role: 'assistant',
+      content: result.content,
+    });
+
+    return result;
   }
 
   async embedding(
@@ -105,19 +161,18 @@ export class CactusLM {
     return result;
   }
 
-  private async _handleLocalEmbedding(text: string, params?: EmbeddingParams): Promise<NativeEmbeddingResult> {
+  protected async _handleLocalEmbedding(text: string, params?: EmbeddingParams): Promise<NativeEmbeddingResult> {
     return this.context.embedding(text, params);
   }
 
-  private async _handleRemoteEmbedding(text: string): Promise<NativeEmbeddingResult> {
+  protected async _handleRemoteEmbedding(text: string): Promise<NativeEmbeddingResult> {
     const embeddingValues = await getVertexAIEmbedding(text);
     return {
       embedding: embeddingValues,
     };
   }
 
-  async rewind(): Promise<void> {
-    // @ts-ignore
+  rewind = async (): Promise<void> => {
     return this.context?.rewind()
   }
 
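CactusLM.init gains an optional fourth retryOptions argument (defaults: 3 attempts per config, 1000 ms base delay with exponential backoff), and completion is now stateful: it expects the full message history on every call and forwards only the previously unseen suffix to the native context. A usage sketch with placeholder model path and messages, assuming the root re-export of CactusLM as in earlier releases:

import { CactusLM } from 'cactus-react-native'

const { lm, error } = await CactusLM.init(
  { model: '/data/local/tmp/model.gguf', n_ctx: 2048 }, // placeholder path
  (progress) => console.log('load progress', progress),
  undefined,                        // cactusToken, unchanged third argument
  { maxRetries: 2, delayMs: 500 },  // new in 0.2.4
)
if (error || !lm) throw error ?? new Error('init failed')

// First turn: the whole history is new, so everything is sent.
const first = await lm.completion([{ role: 'user', content: 'Hi!' }])

// Second turn: resend the full history; only the final user message is new,
// so only it reaches the native context (the rest is tracked internally).
const second = await lm.completion([
  { role: 'user', content: 'Hi!' },
  { role: 'assistant', content: first.content },
  { role: 'user', content: 'And in one word?' },
])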
package/src/tools.ts CHANGED
@@ -56,22 +56,18 @@ export class Tools {
 
 export async function parseAndExecuteTool(result: NativeCompletionResult, tools: Tools): Promise<{toolCalled: boolean, toolName?: string, toolInput?: any, toolOutput?: any}> {
   if (!result.tool_calls || result.tool_calls.length === 0) {
-    // console.log('No tool calls found');
     return {toolCalled: false};
   }
 
   try {
     const toolCall = result.tool_calls[0];
     if (!toolCall) {
-      // console.log('No tool call found');
       return {toolCalled: false};
     }
     const toolName = toolCall.function.name;
     const toolInput = JSON.parse(toolCall.function.arguments);
 
-    // console.log('Calling tool:', toolName, toolInput);
     const toolOutput = await tools.execute(toolName, toolInput);
-    // console.log('Tool called result:', toolOutput);
 
     return {
       toolCalled: true,
@@ -80,7 +76,6 @@ export async function parseAndExecuteTool(result: NativeCompletionResult, tools:
       toolOutput
     };
   } catch (error) {
-    // console.error('Error parsing tool call:', error);
     return {toolCalled: false};
   }
 }
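Only debug logging is removed here, so failures in parseAndExecuteTool now surface solely through its return value. A small consumption sketch mirroring what completionWithTools does internally (in-package import paths; result and tools are assumed to come from earlier code):

import { parseAndExecuteTool, Tools } from './tools'
import type { NativeCompletionResult } from './NativeCactus'

declare const result: NativeCompletionResult // from a prior completion call
declare const tools: Tools                   // the registered tool set

const { toolCalled, toolName, toolInput, toolOutput } =
  await parseAndExecuteTool(result, tools)

if (toolCalled) {
  console.log(`tool ${toolName} called with`, toolInput, '->', toolOutput)
} else {
  // Previously this branch also logged to the console; it is now silent.
}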
package/src/tts.ts CHANGED
@@ -31,10 +31,7 @@ export class CactusTTS {
       speakerJsonStr,
       textToSpeak,
     )
-    // This part is simplified. In a real scenario, the tokens from
-    // the main model would be generated and passed to decodeAudioTokens.
-    // For now, we are assuming a direct path which may not be fully functional
-    // without the main model's token output for TTS.
+    // To-DO: Fix
     const tokens = (await this.context.tokenize(formatted_prompt)).tokens
     return decodeAudioTokens(this.context.id, tokens)
   }