cactus-react-native 0.1.3 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. package/README.md +551 -720
  2. package/android/src/main/java/com/cactus/Cactus.java +41 -0
  3. package/android/src/main/java/com/cactus/LlamaContext.java +19 -0
  4. package/android/src/newarch/java/com/cactus/CactusModule.java +5 -0
  5. package/android/src/oldarch/java/com/cactus/CactusModule.java +5 -0
  6. package/ios/Cactus.mm +14 -0
  7. package/ios/CactusContext.h +1 -0
  8. package/ios/CactusContext.mm +18 -0
  9. package/ios/cactus.xcframework/ios-arm64_x86_64-simulator/cactus.framework/cactus +0 -0
  10. package/ios/cactus.xcframework/tvos-arm64_x86_64-simulator/cactus.framework/cactus +0 -0
  11. package/lib/commonjs/NativeCactus.js +10 -0
  12. package/lib/commonjs/NativeCactus.js.map +1 -1
  13. package/lib/commonjs/chat.js +37 -0
  14. package/lib/commonjs/grammar.js +560 -0
  15. package/lib/commonjs/index.js +545 -0
  16. package/lib/commonjs/index.js.map +1 -1
  17. package/lib/commonjs/lm.js +106 -0
  18. package/lib/commonjs/lm.js.map +1 -1
  19. package/lib/commonjs/projectId.js +8 -0
  20. package/lib/commonjs/projectId.js.map +1 -0
  21. package/lib/commonjs/remote.js +153 -0
  22. package/lib/commonjs/remote.js.map +1 -0
  23. package/lib/commonjs/telemetry.js +103 -0
  24. package/lib/commonjs/telemetry.js.map +1 -0
  25. package/lib/commonjs/tools.js +79 -0
  26. package/lib/commonjs/tools.js.map +1 -0
  27. package/lib/commonjs/tts.js +32 -0
  28. package/lib/commonjs/tts.js.map +1 -1
  29. package/lib/commonjs/vlm.js +150 -0
  30. package/lib/commonjs/vlm.js.map +1 -0
  31. package/lib/module/NativeCactus.js +8 -0
  32. package/lib/module/NativeCactus.js.map +1 -1
  33. package/lib/module/chat.js +33 -0
  34. package/lib/module/grammar.js +553 -0
  35. package/lib/module/index.js +435 -0
  36. package/lib/module/index.js.map +1 -1
  37. package/lib/module/lm.js +101 -0
  38. package/lib/module/lm.js.map +1 -0
  39. package/lib/module/projectId.js +4 -0
  40. package/lib/module/projectId.js.map +1 -0
  41. package/lib/module/remote.js +144 -0
  42. package/lib/module/remote.js.map +1 -0
  43. package/lib/module/telemetry.js +98 -0
  44. package/lib/module/telemetry.js.map +1 -0
  45. package/lib/module/tools.js +73 -0
  46. package/lib/module/tools.js.map +1 -0
  47. package/lib/module/tts.js +27 -0
  48. package/lib/module/tts.js.map +1 -1
  49. package/lib/module/vlm.js +145 -0
  50. package/lib/module/vlm.js.map +1 -1
  51. package/lib/typescript/NativeCactus.d.ts +7 -0
  52. package/lib/typescript/NativeCactus.d.ts.map +1 -1
  53. package/lib/typescript/index.d.ts +3 -1
  54. package/lib/typescript/index.d.ts.map +1 -1
  55. package/lib/typescript/lm.d.ts +11 -34
  56. package/lib/typescript/lm.d.ts.map +1 -1
  57. package/lib/typescript/projectId.d.ts +2 -0
  58. package/lib/typescript/projectId.d.ts.map +1 -0
  59. package/lib/typescript/remote.d.ts +7 -0
  60. package/lib/typescript/remote.d.ts.map +1 -0
  61. package/lib/typescript/telemetry.d.ts +25 -0
  62. package/lib/typescript/telemetry.d.ts.map +1 -0
  63. package/lib/typescript/tools.d.ts +0 -3
  64. package/lib/typescript/tools.d.ts.map +1 -1
  65. package/lib/typescript/tts.d.ts.map +1 -1
  66. package/lib/typescript/vlm.d.ts +14 -34
  67. package/lib/typescript/vlm.d.ts.map +1 -1
  68. package/package.json +4 -4
  69. package/scripts/postInstall.js +33 -0
  70. package/src/NativeCactus.ts +7 -0
  71. package/src/index.ts +122 -46
  72. package/src/lm.ts +80 -5
  73. package/src/projectId.ts +1 -0
  74. package/src/remote.ts +175 -0
  75. package/src/telemetry.ts +138 -0
  76. package/src/tools.ts +17 -58
  77. package/src/vlm.ts +129 -8
  78. package/android/src/main/jniLibs/x86_64/libcactus.so +0 -0
  79. package/android/src/main/jniLibs/x86_64/libcactus_x86_64.so +0 -0
package/src/lm.ts CHANGED
@@ -7,6 +7,13 @@ import type {
   EmbeddingParams,
   NativeEmbeddingResult,
 } from './index'
+import { Telemetry } from './telemetry'
+import { setCactusToken, getVertexAIEmbedding } from './remote'
+
+interface CactusLMReturn {
+  lm: CactusLM | null
+  error: Error | null
+}
 
 export class CactusLM {
   private context: LlamaContext
@@ -18,9 +25,33 @@ export class CactusLM {
   static async init(
     params: ContextParams,
     onProgress?: (progress: number) => void,
-  ): Promise<CactusLM> {
-    const context = await initLlama(params, onProgress)
-    return new CactusLM(context)
+    cactusToken?: string,
+  ): Promise<CactusLMReturn> {
+    if (cactusToken) {
+      setCactusToken(cactusToken);
+    }
+
+    const configs = [
+      params,
+      { ...params, n_gpu_layers: 0 }
+    ];
+
+    for (const config of configs) {
+      try {
+        const context = await initLlama(config, onProgress);
+        return { lm: new CactusLM(context), error: null };
+      } catch (e) {
+        Telemetry.error(e as Error, {
+          n_gpu_layers: config.n_gpu_layers ?? null,
+          n_ctx: config.n_ctx ?? null,
+          model: config.model ?? null,
+        });
+        if (configs.indexOf(config) === configs.length - 1) {
+          return { lm: null, error: e as Error };
+        }
+      }
+    }
+    return { lm: null, error: new Error('Failed to initialize CactusLM') };
   }
 
   async completion(
@@ -28,14 +59,58 @@ export class CactusLM {
     params: CompletionParams = {},
     callback?: (data: any) => void,
   ): Promise<NativeCompletionResult> {
-    return this.context.completion({ messages, ...params }, callback)
+    return await this.context.completion({ messages, ...params }, callback);
   }
 
   async embedding(
     text: string,
     params?: EmbeddingParams,
+    mode: string = 'local',
   ): Promise<NativeEmbeddingResult> {
-    return this.context.embedding(text, params)
+    let result: NativeEmbeddingResult;
+    let lastError: Error | null = null;
+
+    if (mode === 'remote') {
+      result = await this._handleRemoteEmbedding(text);
+    } else if (mode === 'local') {
+      result = await this._handleLocalEmbedding(text, params);
+    } else if (mode === 'localfirst') {
+      try {
+        result = await this._handleLocalEmbedding(text, params);
+      } catch (e) {
+        lastError = e as Error;
+        try {
+          result = await this._handleRemoteEmbedding(text);
+        } catch (remoteError) {
+          throw lastError;
+        }
+      }
+    } else if (mode === 'remotefirst') {
+      try {
+        result = await this._handleRemoteEmbedding(text);
+      } catch (e) {
+        lastError = e as Error;
+        try {
+          result = await this._handleLocalEmbedding(text, params);
+        } catch (localError) {
+          throw lastError;
+        }
+      }
+    } else {
+      throw new Error('Invalid mode: ' + mode + '. Must be "local", "remote", "localfirst", or "remotefirst"');
+    }
+    return result;
+  }
+
+  private async _handleLocalEmbedding(text: string, params?: EmbeddingParams): Promise<NativeEmbeddingResult> {
+    return this.context.embedding(text, params);
+  }
+
+  private async _handleRemoteEmbedding(text: string): Promise<NativeEmbeddingResult> {
+    const embeddingValues = await getVertexAIEmbedding(text);
+    return {
+      embedding: embeddingValues,
+    };
   }
 
   async rewind(): Promise<void> {
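
The net effect is a new calling convention: `CactusLM.init` no longer throws, it resolves to a `{ lm, error }` pair and retries once with `n_gpu_layers: 0` before giving up, while `embedding` gains a routing mode. A minimal usage sketch; the model path, context size, and token are illustrative values, not part of the package:

import { CactusLM } from 'cactus-react-native';

// init resolves to { lm, error } instead of throwing; a GPU-config failure
// is retried automatically on CPU before error is returned.
const { lm, error } = await CactusLM.init(
  { model: '/data/models/example.gguf', n_ctx: 2048 }, // hypothetical path
  (progress) => console.log('load:', progress),
  'my-cactus-token', // optional; only needed for the remote embedding modes
);
if (error || !lm) throw error ?? new Error('init failed');

// mode is 'local' (default), 'remote', 'localfirst', or 'remotefirst'
const { embedding } = await lm.embedding('hello world', undefined, 'localfirst');
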
package/src/projectId.ts ADDED
@@ -0,0 +1 @@
+export const PROJECT_ID = 'not-set';
package/src/remote.ts ADDED
@@ -0,0 +1,175 @@
+let _cactusToken: string | null = null;
+
+export function setCactusToken(token: string | null): void {
+  _cactusToken = token;
+}
+
+export async function getVertexAIEmbedding(text: string): Promise<number[]> {
+  if (_cactusToken === null) {
+    throw new Error('CactusToken not set. Please call CactusLM.init with cactusToken parameter.');
+  }
+
+  const projectId = 'cactus-v1-452518';
+  const location = 'us-central1';
+  const modelId = 'text-embedding-005';
+
+  const endpoint = `https://${location}-aiplatform.googleapis.com/v1/projects/${projectId}/locations/${location}/publishers/google/models/${modelId}:predict`;
+
+  const headers = {
+    'Authorization': `Bearer ${_cactusToken}`,
+    'Content-Type': 'application/json',
+  };
+
+  const requestBody = {
+    instances: [{ content: text }]
+  };
+
+  const response = await fetch(endpoint, {
+    method: 'POST',
+    headers,
+    body: JSON.stringify(requestBody),
+  });
+
+  if (response.status === 401) {
+    _cactusToken = null;
+    throw new Error('Authentication failed. Please update your cactusToken.');
+  } else if (!response.ok) {
+    const errorText = await response.text();
+    throw new Error(`HTTP ${response.status}: ${errorText}`);
+  }
+
+  const responseBody = await response.json();
+
+  if (responseBody.error) {
+    throw new Error(`API Error: ${responseBody.error.message}`);
+  }
+
+  const predictions = responseBody.predictions;
+  if (!predictions || predictions.length === 0) {
+    throw new Error('No predictions in response');
+  }
+
+  const embeddings = predictions[0].embeddings;
+  const values = embeddings.values;
+
+  return values;
+}
+
+export async function getVertexAICompletion(
+  textPrompt: string,
+  imageData?: string,
+  imagePath?: string,
+  mimeType?: string,
+): Promise<string> {
+  if (_cactusToken === null) {
+    throw new Error('CactusToken not set. Please call CactusVLM.init with cactusToken parameter.');
+  }
+
+  const projectId = 'cactus-v1-452518';
+  const location = 'global';
+  const modelId = 'gemini-2.5-flash-lite-preview-06-17';
+
+  const endpoint = `https://aiplatform.googleapis.com/v1/projects/${projectId}/locations/${location}/publishers/google/models/${modelId}:generateContent`;
+
+  const headers = {
+    'Authorization': `Bearer ${_cactusToken}`,
+    'Content-Type': 'application/json',
+  };
+
+  const parts: any[] = [];
+
+  if (imageData) {
+    const detectedMimeType = mimeType || 'image/jpeg';
+    parts.push({
+      inlineData: {
+        mimeType: detectedMimeType,
+        data: imageData
+      }
+    });
+  } else if (imagePath) {
+    const detectedMimeType = mimeType || detectMimeType(imagePath);
+    const RNFS = require('react-native-fs');
+    const base64Data = await RNFS.readFile(imagePath, 'base64');
+    parts.push({
+      inlineData: {
+        mimeType: detectedMimeType,
+        data: base64Data
+      }
+    });
+  }
+
+  parts.push({ text: textPrompt });
+
+  const requestBody = {
+    contents: {
+      role: 'user',
+      parts: parts,
+    }
+  };
+
+  const response = await fetch(endpoint, {
+    method: 'POST',
+    headers,
+    body: JSON.stringify(requestBody),
+  });
+
+  if (response.status === 401) {
+    _cactusToken = null;
+    throw new Error('Authentication failed. Please update your cactusToken.');
+  } else if (!response.ok) {
+    const errorText = await response.text();
+    throw new Error(`HTTP ${response.status}: ${errorText}`);
+  }
+
+  const responseBody = await response.json();
+
+  if (Array.isArray(responseBody)) {
+    throw new Error('Unexpected response format: received array instead of object');
+  }
+
+  if (responseBody.error) {
+    throw new Error(`API Error: ${responseBody.error.message}`);
+  }
+
+  const candidates = responseBody.candidates;
+  if (!candidates || candidates.length === 0) {
+    throw new Error('No candidates in response');
+  }
+
+  const content = candidates[0].content;
+  const responseParts = content.parts;
+  if (!responseParts || responseParts.length === 0) {
+    throw new Error('No parts in response');
+  }
+
+  return responseParts[0].text || '';
+}
+
+export async function getTextCompletion(prompt: string): Promise<string> {
+  return getVertexAICompletion(prompt);
+}
+
+export async function getVisionCompletion(prompt: string, imagePath: string): Promise<string> {
+  return getVertexAICompletion(prompt, undefined, imagePath);
+}
+
+export async function getVisionCompletionFromData(prompt: string, imageData: string, mimeType?: string): Promise<string> {
+  return getVertexAICompletion(prompt, imageData, undefined, mimeType);
+}
+
+function detectMimeType(filePath: string): string {
+  const extension = filePath.toLowerCase().split('.').pop();
+  switch (extension) {
+    case 'jpg':
+    case 'jpeg':
+      return 'image/jpeg';
+    case 'png':
+      return 'image/png';
+    case 'gif':
+      return 'image/gif';
+    case 'webp':
+      return 'image/webp';
+    default:
+      return 'image/jpeg';
+  }
+}
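
These helpers are plain `fetch` wrappers around fixed Vertex AI REST endpoints (`text-embedding-005` for embeddings, a `gemini-2.5-flash-lite` preview for completions), and a 401 clears the stored token so subsequent calls fail fast. A sketch of calling them directly, using the same in-package import path lm.ts and vlm.ts use; the token value is illustrative:

import { setCactusToken, getVertexAIEmbedding, getTextCompletion } from './remote';

setCactusToken('my-cactus-token'); // hypothetical value; normally supplied via init()

// POSTs to the :predict endpoint, returns predictions[0].embeddings.values
const vector: number[] = await getVertexAIEmbedding('some text to embed');

// POSTs to the :generateContent endpoint, returns the first candidate's text
const reply: string = await getTextCompletion('Summarize this sentence.');
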
package/src/telemetry.ts ADDED
@@ -0,0 +1,138 @@
+import { Platform } from 'react-native'
+// Import package.json to get version
+const packageJson = require('../package.json');
+import { PROJECT_ID } from './projectId';
+
+export interface TelemetryParams {
+  n_gpu_layers: number | null
+  n_ctx: number | null
+  model: string | null
+}
+
+interface TelemetryRecord {
+  project_id: string;
+  device_id?: string;
+  device_manufacturer?: string;
+  device_model?: string;
+  os: 'iOS' | 'Android';
+  os_version: string;
+  framework: string;
+  framework_version: string;
+  telemetry_payload?: Record<string, any>;
+  error_payload?: Record<string, any>;
+  timestamp: string;
+  model_filename: string | null;
+  n_ctx: number | null;
+  n_gpu_layers: number | null;
+}
+
+interface TelemetryConfig {
+  supabaseUrl: string;
+  supabaseKey: string;
+  table?: string;
+}
+
+export class Telemetry {
+  private static instance: Telemetry | null = null;
+  private config: Required<TelemetryConfig>;
+
+  private constructor(config: TelemetryConfig) {
+    this.config = {
+      table: 'telemetry',
+      ...config
+    };
+  }
+
+  private static getFilename(path: string): string {
+    try {
+      return path.split('/').pop() || path.split('\\').pop() || 'unknown';
+    } catch {
+      return 'unknown';
+    }
+  }
+
+  static autoInit(): void {
+    if (!Telemetry.instance) {
+      Telemetry.instance = new Telemetry({
+        supabaseUrl: 'https://vlqqczxwyaodtcdmdmlw.supabase.co',
+        supabaseKey: 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6InZscXFjenh3eWFvZHRjZG1kbWx3Iiwicm9sZSI6ImFub24iLCJpYXQiOjE3NTE1MTg2MzIsImV4cCI6MjA2NzA5NDYzMn0.nBzqGuK9j6RZ6mOPWU2boAC_5H9XDs-fPpo5P3WZYbI', // Anon!
+      });
+    }
+  }
+
+  static init(config: TelemetryConfig): void {
+    if (!Telemetry.instance) {
+      Telemetry.instance = new Telemetry(config);
+    }
+  }
+
+  static track(payload: Record<string, any>, options: TelemetryParams, deviceMetadata?: Record<string, any>): void {
+    if (!Telemetry.instance) {
+      Telemetry.autoInit();
+    }
+    Telemetry.instance!.trackInternal(payload, options, deviceMetadata);
+  }
+
+  static error(error: Error, options: TelemetryParams): void {
+    if (!Telemetry.instance) {
+      Telemetry.autoInit();
+    }
+    Telemetry.instance!.errorInternal(error, options);
+  }
+
+  private trackInternal(payload: Record<string, any>, options: TelemetryParams, deviceMetadata?: Record<string, any>): void {
+    const record: TelemetryRecord = {
+      project_id: PROJECT_ID,
+      device_id: deviceMetadata?.deviceId,
+      device_manufacturer: deviceMetadata?.make,
+      device_model: deviceMetadata?.model,
+      os: Platform.OS === 'ios' ? 'iOS' : 'Android',
+      os_version: Platform.Version.toString(),
+      framework: 'react-native',
+      framework_version: packageJson.version,
+      telemetry_payload: payload,
+      timestamp: new Date().toISOString(),
+      model_filename: Telemetry.getFilename(options.model || ''),
+      n_ctx: options.n_ctx,
+      n_gpu_layers: options.n_gpu_layers,
+    };
+
+    this.sendRecord(record).catch(() => {});
+  }
+
+  private errorInternal(error: Error, options: TelemetryParams): void {
+    const errorPayload = {
+      message: error.message,
+      stack: error.stack,
+      name: error.name,
+    };
+
+    const record: TelemetryRecord = {
+      project_id: PROJECT_ID,
+      os: Platform.OS === 'ios' ? 'iOS' : 'Android',
+      os_version: Platform.Version.toString(),
+      framework: 'react-native',
+      framework_version: packageJson.version,
+      error_payload: errorPayload,
+      timestamp: new Date().toISOString(),
+      model_filename: Telemetry.getFilename(options.model || ''),
+      n_ctx: options.n_ctx,
+      n_gpu_layers: options.n_gpu_layers
+    };
+
+    this.sendRecord(record).catch(() => {});
+  }
+
+  private async sendRecord(record: TelemetryRecord): Promise<void> {
+    await (globalThis as any).fetch(`${this.config.supabaseUrl}/rest/v1/${this.config.table}`, {
+      method: 'POST',
+      headers: {
+        'apikey': this.config.supabaseKey,
+        'Authorization': `Bearer ${this.config.supabaseKey}`,
+        'Content-Type': 'application/json',
+        'Prefer': 'return=minimal'
+      },
+      body: JSON.stringify([record])
+    });
+  }
+}
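
`Telemetry` is a lazy singleton: the first `track`/`error` call runs `autoInit` with the hard-coded Supabase project unless `init` ran first, and delivery failures are silently swallowed (`.catch(() => {})`). A sketch of pointing it at your own table, with placeholder credentials:

import { Telemetry } from './telemetry';

// Must run before the first track()/error(): init() is a no-op once an
// instance exists.
Telemetry.init({
  supabaseUrl: 'https://example-project.supabase.co', // placeholder
  supabaseKey: 'public-anon-key',                     // placeholder
  table: 'telemetry',                                 // optional; this is the default
});

Telemetry.error(new Error('model load failed'), {
  n_gpu_layers: 0,
  n_ctx: 2048,
  model: '/data/models/example.gguf', // only the filename is recorded
});
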
package/src/tools.ts CHANGED
@@ -1,4 +1,3 @@
-import type { CactusOAICompatibleMessage } from "./chat";
 import type { NativeCompletionResult } from "./NativeCactus";
 
 interface Parameter {
@@ -55,73 +54,33 @@ export class Tools {
   }
 }
 
-export function injectToolsIntoMessages(messages: CactusOAICompatibleMessage[], tools: Tools): CactusOAICompatibleMessage[] {
-  const newMessages = [...messages];
-  const toolsSchemas = tools.getSchemas();
-  const promptToolInjection = `You have access to the following functions. Use them if required -
-${JSON.stringify(toolsSchemas, null, 2)}
-Only use an available tool if needed. If a tool is chosen, respond ONLY with a JSON object matching the following schema:
-\`\`\`json
-{
-  "tool_name": "<name of the tool>",
-  "tool_input": {
-    "<parameter_name>": "<parameter_value>",
-    ...
-  }
-}
-\`\`\`
-Remember, if you are calling a tool, you must respond with the JSON object and the JSON object ONLY!
-If no tool is needed, respond normally.
-`;
-
-  const systemMessage = newMessages.find(m => m.role === 'system');
-  if (!systemMessage) {
-    newMessages.unshift({
-      role: 'system',
-      content: promptToolInjection
-    });
-  } else {
-    systemMessage.content = `${systemMessage.content}\n\n${promptToolInjection}`;
-  }
-
-  return newMessages;
-}
-
 export async function parseAndExecuteTool(result: NativeCompletionResult, tools: Tools): Promise<{toolCalled: boolean, toolName?: string, toolInput?: any, toolOutput?: any}> {
-  const match = result.content.match(/```json\s*([\s\S]*?)\s*```/);
-
-  if (!match || !match[1]) return {toolCalled: false};
+  if (!result.tool_calls || result.tool_calls.length === 0) {
+    // console.log('No tool calls found');
+    return {toolCalled: false};
+  }
 
   try {
-    const jsonContent = JSON.parse(match[1]);
-    const { tool_name, tool_input } = jsonContent;
-    // console.log('Calling tool:', tool_name, tool_input);
-    const toolOutput = await tools.execute(tool_name, tool_input) || true;
+    const toolCall = result.tool_calls[0];
+    if (!toolCall) {
+      // console.log('No tool call found');
+      return {toolCalled: false};
+    }
+    const toolName = toolCall.function.name;
+    const toolInput = JSON.parse(toolCall.function.arguments);
+
+    // console.log('Calling tool:', toolName, toolInput);
+    const toolOutput = await tools.execute(toolName, toolInput);
    // console.log('Tool called result:', toolOutput);
 
    return {
      toolCalled: true,
-      toolName: tool_name,
-      toolInput: tool_input,
+      toolName,
+      toolInput,
      toolOutput
    };
  } catch (error) {
-    // console.error('Error parsing JSON:', match, error);
+    // console.error('Error parsing tool call:', error);
    return {toolCalled: false};
  }
-}
-
-export function updateMessagesWithToolCall(messages: CactusOAICompatibleMessage[], toolName: string, toolInput: any, toolOutput: any): CactusOAICompatibleMessage[] {
-  const newMessages = [...messages];
-
-  newMessages.push({
-    role: 'function-call',
-    content: JSON.stringify({name: toolName, arguments: toolInput}, null, 2)
-  })
-  newMessages.push({
-    role: 'function-response',
-    content: JSON.stringify(toolOutput, null, 2)
-  })
-
-  return newMessages;
 }
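
`parseAndExecuteTool` now reads the structured `tool_calls` array from the completion result instead of regex-scraping a fenced JSON block out of `content`, and only the first call is executed. A sketch of the consuming side, assuming you already hold a populated `Tools` instance and a completion result that surfaced `tool_calls` (tool registration itself is unchanged by this diff):

import { Tools, parseAndExecuteTool } from './tools';
import type { NativeCompletionResult } from './NativeCactus';

declare const tools: Tools;                   // registered elsewhere
declare const result: NativeCompletionResult; // returned by a completion call

const { toolCalled, toolName, toolInput, toolOutput } =
  await parseAndExecuteTool(result, tools);
if (toolCalled) {
  console.log(`executed ${toolName}(${JSON.stringify(toolInput)}) ->`, toolOutput);
} else {
  console.log('model answered directly:', result.content);
}
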
package/src/vlm.ts CHANGED
@@ -10,6 +10,13 @@ import type {
   CactusOAICompatibleMessage,
   NativeCompletionResult,
 } from './index'
+import { Telemetry } from './telemetry'
+import { setCactusToken, getTextCompletion, getVisionCompletion } from './remote'
+
+interface CactusVLMReturn {
+  vlm: CactusVLM | null
+  error: Error | null
+}
 
 export type VLMContextParams = ContextParams & {
   mmproj: string
@@ -17,11 +24,12 @@ export type VLMContextParams = ContextParams & {
 
 export type VLMCompletionParams = Omit<CompletionParams, 'prompt'> & {
   images?: string[]
+  mode?: string
 }
 
 export class CactusVLM {
   private context: LlamaContext
-
+
   private constructor(context: LlamaContext) {
     this.context = context
   }
@@ -29,19 +37,84 @@ export class CactusVLM {
   static async init(
     params: VLMContextParams,
     onProgress?: (progress: number) => void,
-  ): Promise<CactusVLM> {
-    const context = await initLlama(params, onProgress)
+    cactusToken?: string,
+  ): Promise<CactusVLMReturn> {
+    if (cactusToken) {
+      setCactusToken(cactusToken);
+    }
+
+    const configs = [
+      params,
+      { ...params, n_gpu_layers: 0 }
+    ];
 
-    // Explicitly disable GPU for the multimodal projector for stability.
-    await initMultimodal(context.id, params.mmproj, false)
+    for (const config of configs) {
+      try {
+        const context = await initLlama(config, onProgress)
+        await initMultimodal(context.id, params.mmproj, false)
+        return {vlm: new CactusVLM(context), error: null}
+      } catch (e) {
+        Telemetry.error(e as Error, {
+          n_gpu_layers: config.n_gpu_layers ?? null,
+          n_ctx: config.n_ctx ?? null,
+          model: config.model ?? null,
+        });
+        if (configs.indexOf(config) === configs.length - 1) {
+          return {vlm: null, error: e as Error}
+        }
+      }
+    }
 
-    return new CactusVLM(context)
+    return {vlm: null, error: new Error('Failed to initialize CactusVLM')}
   }
 
   async completion(
     messages: CactusOAICompatibleMessage[],
     params: VLMCompletionParams = {},
     callback?: (data: any) => void,
+  ): Promise<NativeCompletionResult> {
+    const mode = params.mode || 'local';
+
+    let result: NativeCompletionResult;
+    let lastError: Error | null = null;
+
+    if (mode === 'remote') {
+      result = await this._handleRemoteCompletion(messages, params, callback);
+    } else if (mode === 'local') {
+      result = await this._handleLocalCompletion(messages, params, callback);
+    } else if (mode === 'localfirst') {
+      try {
+        result = await this._handleLocalCompletion(messages, params, callback);
+      } catch (e) {
+        lastError = e as Error;
+        try {
+          result = await this._handleRemoteCompletion(messages, params, callback);
+        } catch (remoteError) {
+          throw lastError;
+        }
+      }
+    } else if (mode === 'remotefirst') {
+      try {
+        result = await this._handleRemoteCompletion(messages, params, callback);
+      } catch (e) {
+        lastError = e as Error;
+        try {
+          result = await this._handleLocalCompletion(messages, params, callback);
+        } catch (localError) {
+          throw lastError;
+        }
+      }
+    } else {
+      throw new Error('Invalid mode: ' + mode + '. Must be "local", "remote", "localfirst", or "remotefirst"');
+    }
+
+    return result;
+  }
+
+  private async _handleLocalCompletion(
+    messages: CactusOAICompatibleMessage[],
+    params: VLMCompletionParams,
+    callback?: (data: any) => void,
   ): Promise<NativeCompletionResult> {
     if (params.images && params.images.length > 0) {
       const formattedPrompt = await this.context.getFormattedChat(messages)
@@ -49,14 +122,62 @@ export class CactusVLM {
         typeof formattedPrompt === 'string'
           ? formattedPrompt
           : formattedPrompt.prompt
-      return multimodalCompletion(
+      return await multimodalCompletion(
        this.context.id,
        prompt,
        params.images,
        { ...params, prompt, emit_partial_completion: !!callback },
      )
+    } else {
+      return await this.context.completion({ messages, ...params }, callback)
+    }
+  }
+
+  private async _handleRemoteCompletion(
+    messages: CactusOAICompatibleMessage[],
+    params: VLMCompletionParams,
+    callback?: (data: any) => void,
+  ): Promise<NativeCompletionResult> {
+    const prompt = messages.map((m) => `${m.role}: ${m.content}`).join('\n');
+    const imagePath = params.images && params.images.length > 0 ? params.images[0] : '';
+
+    let responseText: string;
+    if (imagePath) {
+      responseText = await getVisionCompletion(prompt, imagePath);
+    } else {
+      responseText = await getTextCompletion(prompt);
+    }
+
+    if (callback) {
+      for (let i = 0; i < responseText.length; i++) {
+        callback({ token: responseText[i] });
+      }
     }
-    return this.context.completion({ messages, ...params }, callback)
+
+    return {
+      text: responseText,
+      reasoning_content: '',
+      tool_calls: [],
+      content: responseText,
+      tokens_predicted: responseText.split(' ').length,
+      tokens_evaluated: prompt.split(' ').length,
+      truncated: false,
+      stopped_eos: true,
+      stopped_word: '',
+      stopped_limit: 0,
+      stopping_word: '',
+      tokens_cached: 0,
+      timings: {
+        prompt_n: prompt.split(' ').length,
+        prompt_ms: 0,
+        prompt_per_token_ms: 0,
+        prompt_per_second: 0,
+        predicted_n: responseText.split(' ').length,
+        predicted_ms: 0,
+        predicted_per_token_ms: 0,
+        predicted_per_second: 0,
+      },
+    };
   }
 
   async rewind(): Promise<void> {
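
As with `CactusLM`, `init` now returns `{ vlm, error }` with an automatic CPU retry, and `completion` accepts `mode` in its params; note that in remote mode the streaming callback is invoked once per character of the response. A minimal sketch; the paths and token are illustrative:

import { CactusVLM } from 'cactus-react-native';

const { vlm, error } = await CactusVLM.init(
  { model: '/data/models/vlm.gguf', mmproj: '/data/models/mmproj.gguf' }, // hypothetical paths
  undefined,
  'my-cactus-token', // needed for 'remote', 'localfirst', 'remotefirst'
);
if (error || !vlm) throw error ?? new Error('init failed');

const result = await vlm.completion(
  [{ role: 'user', content: 'Describe this image.' }],
  { images: ['/data/photos/cat.jpg'], mode: 'localfirst' }, // hypothetical image path
  (data) => console.log('token:', data.token),
);
console.log(result.text);
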