react-native-gemma-agent 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +457 -0
- package/package.json +52 -0
- package/skills/calculator.ts +47 -0
- package/skills/deviceLocation.ts +180 -0
- package/skills/index.ts +3 -0
- package/skills/queryWikipedia.ts +96 -0
- package/skills/readCalendar.ts +74 -0
- package/skills/webSearch.ts +75 -0
- package/src/AgentOrchestrator.ts +315 -0
- package/src/BM25Scorer.ts +118 -0
- package/src/FunctionCallParser.ts +113 -0
- package/src/GemmaAgentProvider.tsx +101 -0
- package/src/InferenceEngine.ts +301 -0
- package/src/ModelManager.ts +244 -0
- package/src/SkillRegistry.ts +60 -0
- package/src/SkillSandbox.tsx +155 -0
- package/src/index.ts +52 -0
- package/src/types.ts +197 -0
- package/src/useGemmaAgent.ts +222 -0
- package/src/useModelDownload.ts +80 -0
- package/src/useSkillRegistry.ts +58 -0
|
@@ -0,0 +1,301 @@
|
|
|
1
|
+
import {
|
|
2
|
+
initLlama,
|
|
3
|
+
releaseAllLlama,
|
|
4
|
+
type LlamaContext,
|
|
5
|
+
type TokenData,
|
|
6
|
+
type NativeCompletionResult,
|
|
7
|
+
} from 'llama.rn';
|
|
8
|
+
import type {
|
|
9
|
+
Message,
|
|
10
|
+
CompletionResult,
|
|
11
|
+
CompletionTimings,
|
|
12
|
+
ContextUsage,
|
|
13
|
+
GenerateOptions,
|
|
14
|
+
TokenEvent,
|
|
15
|
+
ToolCall,
|
|
16
|
+
ToolDefinition,
|
|
17
|
+
InferenceEngineConfig,
|
|
18
|
+
} from './types';
|
|
19
|
+
|
|
20
|
+
/**
 * Engine settings used for every InferenceEngineConfig field the caller
 * does not override. Each value is forwarded to initLlama() by loadModel().
 */
const DEFAULT_CONFIG: Required<InferenceEngineConfig> = {
  // Forwarded as `n_ctx`; also the denominator in getContextUsage().
  contextSize: 4_096,
  // Forwarded as `n_batch`.
  batchSize: 512,
  // Forwarded as `n_threads`.
  threads: 4,
  // Forwarded as `flash_attn_type`.
  flashAttn: 'auto',
  // Forwarded as `use_mlock`.
  useMlock: true,
  // Forwarded as `n_gpu_layers` (NOTE(review): -1 presumably means "offload all layers" — confirm in llama.rn docs).
  gpuLayers: -1,
};
|
|
28
|
+
|
|
29
|
+
/**
 * Sampling defaults applied by generate() whenever the corresponding
 * GenerateOptions field is missing (resolved with `??`).
 */
const DEFAULT_GENERATE: Required<Pick<GenerateOptions, 'maxTokens' | 'temperature' | 'topP' | 'topK'>> = {
  maxTokens: 1_024, // sent to llama.rn as `n_predict`
  temperature: 0.7, // sent as `temperature`
  topP: 0.9, // sent as `top_p`
  topK: 40, // sent as `top_k`
};
|
|
35
|
+
|
|
36
|
+
/**
 * Snapshot of model/device facts captured from the llama.rn context right
 * after loadModel() succeeds; cleared again by unload().
 */
interface LoadedModelInfo {
  /** Whether the context reported GPU acceleration. */
  gpu: boolean;
  /** The context's explanation for not using the GPU, as reported by llama.rn. */
  reasonNoGPU: string;
  /** Model description string, or null when the context exposes none. */
  description: string | null;
  /** Parameter count, or null when the context exposes none. */
  nParams: number | null;
}
|
|
42
|
+
|
|
43
|
+
/**
 * Wrapper around a llama.rn context: loads a GGUF model, runs chat
 * completions (optionally with tool calling), and reports timing and
 * context-usage statistics from the most recent generation.
 */
export class InferenceEngine {
  private context: LlamaContext | null = null;
  private config: Required<InferenceEngineConfig>;
  private modelInfo: LoadedModelInfo | null = null;
  private _isGenerating = false;
  /** True while initLlama() is awaited; rejects a second concurrent loadModel(). */
  private _isLoading = false;
  private _lastPromptTokens = 0;
  private _lastPredictedTokens = 0;

  /**
   * @param config — optional overrides. Keys that are absent OR explicitly
   *   `undefined` fall back to DEFAULT_CONFIG. A plain object spread would let
   *   an `undefined` value replace the default and be passed on to initLlama().
   */
  constructor(config?: InferenceEngineConfig) {
    const overrides = Object.fromEntries(
      Object.entries(config ?? {}).filter(([, value]) => value !== undefined),
    ) as InferenceEngineConfig;
    this.config = { ...DEFAULT_CONFIG, ...overrides };
  }

  /** Whether a model context is currently loaded. */
  get isLoaded(): boolean {
    return this.context !== null;
  }

  /** Whether a generate() call is currently marked as running. */
  get isGenerating(): boolean {
    return this._isGenerating;
  }

  /** Whether the loaded context reported GPU use (false when nothing is loaded). */
  get gpu(): boolean {
    return this.modelInfo?.gpu ?? false;
  }

  /**
   * Load a GGUF model into memory.
   * @param modelPath — absolute path to the .gguf file on device
   * @param onProgress — loading progress callback (0-100)
   * @returns load time in ms
   * @throws if a model is already loaded or another load is still in progress
   */
  async loadModel(
    modelPath: string,
    onProgress?: (percent: number) => void,
  ): Promise<number> {
    if (this.context) {
      throw new Error('Model already loaded. Call unload() first.');
    }
    if (this._isLoading) {
      throw new Error('Model is already loading.');
    }

    this._isLoading = true;
    try {
      const start = Date.now();

      const context = await initLlama(
        {
          model: modelPath,
          n_ctx: this.config.contextSize,
          n_batch: this.config.batchSize,
          n_threads: this.config.threads,
          flash_attn_type: this.config.flashAttn,
          use_mlock: this.config.useMlock,
          n_gpu_layers: this.config.gpuLayers,
        },
        (progress: number) => {
          onProgress?.(progress);
        },
      );

      const loadTimeMs = Date.now() - start;

      this.context = context;
      this.modelInfo = {
        gpu: context.gpu,
        reasonNoGPU: context.reasonNoGPU,
        description: context.model?.desc ?? null,
        nParams: context.model?.nParams ?? null,
      };

      return loadTimeMs;
    } finally {
      this._isLoading = false;
    }
  }

  /**
   * Run inference with messages and optional tools.
   * Returns the full completion result including any tool calls.
   * @param messages — conversation history, converted to llama.rn chat messages
   * @param options — sampling / tool options; missing sampling values use DEFAULT_GENERATE
   * @param onToken — called for each streamed token that carries text
   * @throws if no model is loaded or a generation is already running
   */
  async generate(
    messages: Message[],
    options?: GenerateOptions,
    onToken?: (event: TokenEvent) => void,
  ): Promise<CompletionResult> {
    if (!this.context) {
      throw new Error('No model loaded. Call loadModel() first.');
    }
    if (this._isGenerating) {
      throw new Error('Generation already in progress. Call stopGeneration() first.');
    }

    this._isGenerating = true;

    try {
      const llamaMessages = messages.map(msg => this.toLlamaMessage(msg));

      const completionParams: Record<string, unknown> = {
        messages: llamaMessages,
        n_predict: options?.maxTokens ?? DEFAULT_GENERATE.maxTokens,
        temperature: options?.temperature ?? DEFAULT_GENERATE.temperature,
        top_p: options?.topP ?? DEFAULT_GENERATE.topP,
        top_k: options?.topK ?? DEFAULT_GENERATE.topK,
        stop: options?.stop ?? ['<end_of_turn>', '<eos>'],
      };

      if (options?.tools && options.tools.length > 0) {
        completionParams.tools = options.tools;
        completionParams.tool_choice = options.toolChoice ?? 'auto';
      }

      const result: NativeCompletionResult = await this.context.completion(
        // eslint-disable-next-line @typescript-eslint/no-explicit-any -- params are assembled dynamically above
        completionParams as any,
        (data: TokenData) => {
          if (onToken && data.token) {
            onToken({
              token: data.token,
              toolCalls: data.tool_calls as ToolCall[] | undefined,
            });
          }
        },
      );

      const mapped = this.mapResult(result);
      this._lastPromptTokens = mapped.timings.promptTokens;
      this._lastPredictedTokens = mapped.timings.predictedTokens;
      return mapped;
    } finally {
      this._isGenerating = false;
    }
  }

  /**
   * Stop an in-progress generation. The engine is marked idle as soon as
   * stopCompletion() resolves; the pending generate() call settles on its own.
   */
  async stopGeneration(): Promise<void> {
    if (this.context && this._isGenerating) {
      await this.context.stopCompletion();
      this._isGenerating = false;
    }
  }

  /**
   * Unload the model and free memory. Also resets the last-generation token
   * counters so getContextUsage() does not report usage of a released context.
   * NOTE(review): releaseAllLlama() is, by its name, a global release —
   * presumably every llama.rn context in the process, not only this engine's.
   */
  async unload(): Promise<void> {
    if (!this.context) {
      return;
    }
    await releaseAllLlama();
    this.context = null;
    this.modelInfo = null;
    this._isGenerating = false;
    this._lastPromptTokens = 0;
    this._lastPredictedTokens = 0;
  }

  /**
   * Get info about the loaded model. All fields fall back to false/null when
   * no model is loaded.
   */
  getInfo(): {
    loaded: boolean;
    gpu: boolean;
    reasonNoGPU: string | null;
    description: string | null;
    nParams: number | null;
  } {
    return {
      loaded: this.isLoaded,
      gpu: this.modelInfo?.gpu ?? false,
      reasonNoGPU: this.modelInfo?.reasonNoGPU ?? null,
      description: this.modelInfo?.description ?? null,
      nParams: this.modelInfo?.nParams ?? null,
    };
  }

  /**
   * Estimate context-window usage from the most recent generate() call.
   * `used` is that call's prompt tokens (the conversation as fed to the
   * model) plus its predicted tokens (the reply); `total` is the configured
   * contextSize. Both counters are 0 before the first generation and after
   * unload().
   */
  getContextUsage(): ContextUsage {
    const total = this.config.contextSize;
    const used = this._lastPromptTokens + this._lastPredictedTokens;
    const percent = total > 0 ? Math.round((used / total) * 100) : 0;
    return { used, total, percent };
  }

  /**
   * Run a benchmark.
   * Returns prompt processing and token generation speeds, or null when no
   * model is loaded or the native bench call fails (best-effort by design).
   */
  async bench(
    pp = 512,
    tg = 128,
    pl = 1,
    nr = 3,
  ): Promise<{ ppSpeed: number; tgSpeed: number; flashAttn: boolean } | null> {
    if (!this.context) {
      return null;
    }

    try {
      const result = await this.context.bench(pp, tg, pl, nr);
      return {
        ppSpeed: result.speedPp ?? 0,
        tgSpeed: result.speedTg ?? 0,
        flashAttn: Boolean(result.flashAttn),
      };
    } catch {
      return null;
    }
  }

  /**
   * Convert one public Message into the plain object llama.rn expects,
   * omitting optional fields that have no real value.
   */
  private toLlamaMessage(msg: Message): Record<string, unknown> {
    const m: Record<string, unknown> = {
      role: msg.role,
      content: msg.content ?? '',
    };
    // Only include fields with actual string values — undefined/null
    // fields become JSON null and crash llama.cpp's Jinja parser
    if (msg.tool_calls && msg.tool_calls.length > 0) {
      m.tool_calls = msg.tool_calls.map((tc, i) => ({
        type: tc.type,
        // Per-index fallback (same scheme as mapResult) so several id-less
        // calls don't all collide on one id and become unmatchable by
        // the tool_call_id of their results.
        id: tc.id ?? `call_${i}`,
        function: {
          name: tc.function.name,
          arguments: tc.function.arguments ?? '{}',
        },
      }));
    }
    if (typeof msg.tool_call_id === 'string') {
      m.tool_call_id = msg.tool_call_id;
    }
    if (typeof msg.name === 'string') {
      m.name = msg.name;
    }
    return m;
  }

  /** Translate llama.rn's native completion result into the public CompletionResult shape. */
  private mapResult(result: NativeCompletionResult): CompletionResult {
    const timings: CompletionTimings = {
      promptTokens: result.timings.prompt_n,
      promptMs: result.timings.prompt_ms,
      promptPerSecond: result.timings.prompt_per_second,
      predictedTokens: result.timings.predicted_n,
      predictedMs: result.timings.predicted_ms,
      predictedPerSecond: result.timings.predicted_per_second,
    };

    const toolCalls: ToolCall[] = (result.tool_calls ?? []).map((tc, i) => ({
      type: 'function' as const,
      id: tc.id ?? `call_${i}`,
      function: {
        name: tc.function.name,
        arguments: tc.function.arguments ?? '{}',
      },
    }));

    return {
      text: result.text,
      content: result.content ?? result.text,
      reasoning: result.reasoning_content || null,
      toolCalls,
      timings,
      stoppedEos: result.stopped_eos,
      stoppedLimit: result.stopped_limit > 0,
      contextFull: result.context_full,
    };
  }
}
|
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
import RNFS from 'react-native-fs';
|
|
2
|
+
import type { ModelStatus, ModelConfig, DownloadProgress, ModelInfo } from './types';
|
|
3
|
+
|
|
4
|
+
/** Host prefix for Hugging Face `resolve/main` download URLs. */
const DEFAULT_HF_BASE = 'https://huggingface.co';

/** Read timeout, in milliseconds, passed to RNFS.downloadFile (one minute). */
const DOWNLOAD_CHUNK_TIMEOUT = 60_000;

/** Receives the new status every time the manager changes state. */
type StatusListener = (nextStatus: ModelStatus) => void;

/** Receives periodic byte/percent snapshots while a download runs. */
type ProgressListener = (update: DownloadProgress) => void;
|
|
9
|
+
|
|
10
|
+
/**
 * Manages one GGUF model file on device: finds an existing copy, downloads
 * it from Hugging Face, and notifies subscribers whenever its status changes.
 */
export class ModelManager {
  private _status: ModelStatus = 'not_downloaded';
  private _modelPath: string | null = null;
  private _config: ModelConfig;
  private _downloadJobId: number | null = null;
  /**
   * Set synchronously by download() and cleared when it settles. Overlapping
   * calls (including a retry fired before a cancelled download has finished
   * unwinding) are rejected instead of racing on the same `.partial` file.
   */
  private _downloadInFlight = false;
  private _statusListeners: Set<StatusListener> = new Set();
  private _sizeBytes: number | null = null;

  constructor(config: ModelConfig) {
    this._config = config;
  }

  get status(): ModelStatus {
    return this._status;
  }

  get modelPath(): string | null {
    return this._modelPath;
  }

  /**
   * Subscribe to status changes.
   * @returns a function that removes the listener
   */
  onStatusChange(listener: StatusListener): () => void {
    this._statusListeners.add(listener);
    return () => this._statusListeners.delete(listener);
  }

  private setStatus(status: ModelStatus): void {
    this._status = status;
    for (const listener of this._statusListeners) {
      listener(status);
    }
  }

  /**
   * Check if the model file exists at common locations.
   * Returns the path if found (and marks the manager 'ready'), null otherwise.
   */
  async findModel(): Promise<string | null> {
    const candidates = [
      `${RNFS.DocumentDirectoryPath}/${this._config.filename}`,
      `/data/local/tmp/${this._config.filename}`,
      `${RNFS.CachesDirectoryPath}/${this._config.filename}`,
    ];

    for (const path of candidates) {
      if (await RNFS.exists(path)) {
        const stat = await RNFS.stat(path);
        this._sizeBytes = Number(stat.size);
        this._modelPath = path;
        this.setStatus('ready');
        return path;
      }
    }

    return null;
  }

  /**
   * Check if model exists and update status accordingly.
   */
  async checkModel(): Promise<boolean> {
    const path = await this.findModel();
    if (path) {
      return true;
    }
    this.setStatus('not_downloaded');
    return false;
  }

  /**
   * Download the model from Hugging Face, reporting progress.
   *
   * Bytes are written to `<filename>.partial` in the documents directory and
   * moved into place only after a 2xx response, so a half-written file is
   * never reported as ready. An interrupted download restarts from zero:
   * RNFS.downloadFile overwrites its `toFile`, so a `Range` request would
   * replace the partial bytes with only the tail and yield a corrupt model.
   *
   * @returns absolute path of the model file
   * @throws if a download is already running, the HTTP status is not 2xx,
   *   or the transfer fails or is cancelled
   */
  async download(onProgress?: ProgressListener): Promise<string> {
    if (this._downloadInFlight || this._status === 'downloading') {
      throw new Error('Download already in progress');
    }

    this._downloadInFlight = true;
    try {
      return await this.runDownload(onProgress);
    } finally {
      this._downloadInFlight = false;
    }
  }

  /**
   * Cancel an in-progress download. Status becomes 'not_downloaded' and stays
   * that way; the pending download() call settles through its error path.
   */
  cancelDownload(): void {
    if (this._downloadJobId !== null) {
      RNFS.stopDownload(this._downloadJobId);
      this._downloadJobId = null;
      this.setStatus('not_downloaded');
    }
  }

  /**
   * Set a custom model path (for pre-downloaded models).
   * @throws if no file exists at `path`
   */
  async setModelPath(path: string): Promise<void> {
    if (!(await RNFS.exists(path))) {
      throw new Error(`Model file not found at ${path}`);
    }
    const stat = await RNFS.stat(path);
    this._sizeBytes = Number(stat.size);
    this._modelPath = path;
    this.setStatus('ready');
  }

  /**
   * Delete the current model file (if any) plus any leftover partial download.
   */
  async deleteModel(): Promise<void> {
    if (this._modelPath) {
      await this.removeIfExists(this._modelPath);
    }

    // Also clean up partial downloads
    await this.removeIfExists(this.partialPath());

    this._modelPath = null;
    this._sizeBytes = null;
    this.setStatus('not_downloaded');
  }

  /**
   * Get model info. Model metadata fields are always null here; this class
   * only tracks the file, not its contents.
   */
  getInfo(): ModelInfo {
    return {
      status: this._status,
      path: this._modelPath,
      sizeBytes: this._sizeBytes,
      description: null,
      nParams: null,
      nEmbd: null,
    };
  }

  /**
   * Check available storage space against `expectedSize` plus a 10% buffer.
   * Always reports sufficient when no expected size is configured.
   */
  async checkStorage(): Promise<{ available: number; required: number; sufficient: boolean }> {
    const fsInfo = await RNFS.getFSInfo();
    const required = this._config.expectedSize ?? 0;
    return {
      available: fsInfo.freeSpace,
      required,
      sufficient: required === 0 || fsInfo.freeSpace > required * 1.1, // 10% buffer
    };
  }

  /** Body of download(); the caller owns the in-flight flag. */
  private async runDownload(onProgress?: ProgressListener): Promise<string> {
    // Check if already downloaded
    const existing = await this.findModel();
    if (existing) {
      const size = this._sizeBytes ?? 0;
      onProgress?.({ bytesDownloaded: size, totalBytes: size, percent: 100 });
      return existing;
    }

    this.setStatus('downloading');

    const destPath = this.downloadPath();
    const partialPath = this.partialPath();
    const url = this.buildDownloadUrl();

    try {
      // Leftover bytes can't be appended to (see download() docs); start clean.
      await this.removeIfExists(partialPath);

      const expectedBytes = this._config.expectedSize ?? 0;

      const job = RNFS.downloadFile({
        fromUrl: url,
        toFile: partialPath,
        progress: (res) => {
          const downloaded = res.bytesWritten;
          const total = expectedBytes || res.contentLength;
          const percent = total > 0 ? Math.round((downloaded / total) * 100) : 0;
          onProgress?.({ bytesDownloaded: downloaded, totalBytes: total, percent });
        },
        progressInterval: 500,
        progressDivider: 0,
        readTimeout: DOWNLOAD_CHUNK_TIMEOUT,
        connectionTimeout: 15_000,
      });

      this._downloadJobId = job.jobId;
      const response = await job.promise;
      this._downloadJobId = null;

      if (response.statusCode < 200 || response.statusCode >= 300) {
        throw new Error(`Download failed with status ${response.statusCode}`);
      }

      // Move partial to final destination
      await this.removeIfExists(destPath);
      await RNFS.moveFile(partialPath, destPath);

      const stat = await RNFS.stat(destPath);
      this._modelPath = destPath;
      this._sizeBytes = Number(stat.size);
      this.setStatus('ready');

      onProgress?.({ bytesDownloaded: this._sizeBytes, totalBytes: this._sizeBytes, percent: 100 });
      return destPath;
    } catch (err) {
      this._downloadJobId = null;
      // The partial file cannot be resumed, so reclaim the space (best effort).
      await this.removeIfExists(partialPath).catch(() => undefined);
      // cancelDownload() has already set 'not_downloaded' — don't overwrite it;
      // only a failure while still downloading is an error.
      if (this._status === 'downloading') {
        this.setStatus('error');
      }
      throw err;
    }
  }

  /** Final on-device location used for downloads. */
  private downloadPath(): string {
    return `${RNFS.DocumentDirectoryPath}/${this._config.filename}`;
  }

  /** Temporary file a download is written to before being moved into place. */
  private partialPath(): string {
    return `${this.downloadPath()}.partial`;
  }

  /** Delete `path` when it exists; no-op otherwise. */
  private async removeIfExists(path: string): Promise<void> {
    if (await RNFS.exists(path)) {
      await RNFS.unlink(path);
    }
  }

  private buildDownloadUrl(): string {
    const { repoId, filename } = this._config;
    return `${DEFAULT_HF_BASE}/${repoId}/resolve/main/${filename}`;
  }
}
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import type { SkillManifest, ToolDefinition } from './types';
|
|
2
|
+
|
|
3
|
+
/**
 * In-memory catalogue of agent skills keyed by name. Registering a name that
 * already exists replaces the earlier skill.
 */
export class SkillRegistry {
  private skills: Map<string, SkillManifest> = new Map();

  /**
   * Add (or replace) a skill after checking its type-specific requirement.
   * @throws if a 'js' skill has no `html`, or a 'native' skill has no `execute`
   */
  registerSkill(skill: SkillManifest): void {
    switch (skill.type) {
      case 'js':
        if (!skill.html) {
          throw new Error(`JS skill "${skill.name}" requires an html field`);
        }
        break;
      case 'native':
        if (!skill.execute) {
          throw new Error(`Native skill "${skill.name}" requires an execute function`);
        }
        break;
    }
    this.skills.set(skill.name, skill);
  }

  /** Remove a skill by name; unknown names are ignored. */
  unregisterSkill(name: string): void {
    this.skills.delete(name);
  }

  /** Look up a skill by name, returning null (not undefined) when absent. */
  getSkill(name: string): SkillManifest | null {
    const found = this.skills.get(name);
    return found === undefined ? null : found;
  }

  /** All registered skills, in registration order. */
  getSkills(): SkillManifest[] {
    return Array.from(this.skills.values());
  }

  /** Whether a skill with this name is registered. */
  hasSkill(name: string): boolean {
    return this.skills.has(name);
  }

  /**
   * Build OpenAI-style function tool definitions from the registered skills,
   * ready to pass as the `tools` option of InferenceEngine.generate().
   * A skill's `instructions`, when non-empty, are appended to its description
   * on a new line.
   */
  toToolDefinitions(): ToolDefinition[] {
    return this.getSkills().map(skill => {
      const instructionSuffix = skill.instructions ? `\n${skill.instructions}` : '';
      const parameters = {
        type: 'object' as const,
        properties: skill.parameters,
        required: skill.requiredParameters,
      };
      return {
        type: 'function' as const,
        function: {
          name: skill.name,
          description: skill.description + instructionSuffix,
          parameters,
        },
      };
    });
  }

  /** Remove every registered skill. */
  clear(): void {
    this.skills.clear();
  }
}
|