@fugood/llama.node 1.3.0-rc.2 → 1.3.0-rc.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -37,11 +37,9 @@ import { loadModel } from '@fugood/llama.node'
  // Initial a Llama context with the model (may take a while)
  const context = await loadModel({
  model: 'path/to/gguf/model',
- use_mlock: true,
  n_ctx: 2048,
- n_gpu_layers: 1, // > 0: enable GPU
- // embedding: true, // use embedding
- // lib_variant: 'opencl', // Change backend
+ n_gpu_layers: 99, // > 0: enable GPU
+ // lib_variant: 'vulkan', // Change backend
  })

  // Do completion
package/lib/binding.js CHANGED
@@ -42,7 +42,7 @@ var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, ge
  });
  };
  Object.defineProperty(exports, "__esModule", { value: true });
- exports.loadModule = void 0;
+ exports.isLibVariantAvailable = exports.loadModule = void 0;
  const getPlatformPackageName = (variant) => {
  const platform = process.platform;
  const arch = process.arch;
@@ -72,3 +72,21 @@ const loadModule = (variant) => __awaiter(void 0, void 0, void 0, function* () {
  return (yield Promise.resolve().then(() => __importStar(require('../build/Release/index.node'))));
  });
  exports.loadModule = loadModule;
+ const isLibVariantAvailable = (variant) => __awaiter(void 0, void 0, void 0, function* () {
+ if (variant && variant !== 'default') {
+ const module = yield loadPlatformPackage(getPlatformPackageName(variant));
+ return module != null;
+ }
+ const defaultModule = yield loadPlatformPackage(getPlatformPackageName());
+ if (defaultModule)
+ return true;
+ try {
+ // @ts-ignore
+ yield Promise.resolve().then(() => __importStar(require('../build/Release/index.node')));
+ return true;
+ }
+ catch (error) {
+ return false;
+ }
+ });
+ exports.isLibVariantAvailable = isLibVariantAvailable;
package/lib/binding.ts CHANGED
@@ -375,7 +375,7 @@ export type ToolCall = {
  }
 
  export interface LlamaContext {
- new (options: LlamaModelOptions): LlamaContext
+ new (options: LlamaModelOptions, onProgress?: (progress: number) => void): LlamaContext
  getSystemInfo(): string
  getModelInfo(): ModelInfo
  getFormattedChat(
@@ -587,3 +587,21 @@ export const loadModule = async (variant?: LibVariant): Promise<Module> => {
  // @ts-ignore
  return (await import('../build/Release/index.node')) as Module
  }
+
+ export const isLibVariantAvailable = async (variant?: LibVariant): Promise<boolean> => {
+ if (variant && variant !== 'default') {
+ const module = await loadPlatformPackage(getPlatformPackageName(variant))
+ return module != null
+ }
+
+ const defaultModule = await loadPlatformPackage(getPlatformPackageName())
+ if (defaultModule) return true
+
+ try {
+ // @ts-ignore
+ await import('../build/Release/index.node')
+ return true
+ } catch (error) {
+ return false
+ }
+ }
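The new `isLibVariantAvailable` export checks whether the prebuilt platform package for a given backend variant (or the local default build) can be resolved before a model is loaded with it. A minimal sketch of how it might be combined with `loadModel` to pick a backend at runtime; the fallback order here is illustrative, not prescribed by the package:

```js
import { isLibVariantAvailable, loadModel } from '@fugood/llama.node'

// Prefer the Vulkan backend when its platform package is installed,
// otherwise fall back to the default (CPU) build.
const lib_variant = (await isLibVariantAvailable('vulkan')) ? 'vulkan' : 'default'

const context = await loadModel({
  model: 'path/to/gguf/model',
  n_ctx: 2048,
  n_gpu_layers: 99, // > 0: enable GPU
  lib_variant,
})
```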
package/lib/index.js CHANGED
@@ -23,14 +23,14 @@ var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, ge
  });
  };
  Object.defineProperty(exports, "__esModule", { value: true });
- exports.BuildInfo = exports.getBackendDevicesInfo = exports.loadLlamaModelInfo = exports.initLlama = exports.loadModel = exports.toggleNativeLog = exports.MTMD_DEFAULT_MEDIA_MARKER = exports.LlamaParallelAPI = void 0;
+ exports.BuildInfo = exports.getBackendDevicesInfo = exports.loadLlamaModelInfo = exports.initLlama = exports.loadModel = exports.toggleNativeLog = exports.LlamaParallelAPI = void 0;
  exports.addNativeLogListener = addNativeLogListener;
  const binding_1 = require("./binding");
  const version_1 = require("./version");
  const parallel_1 = require("./parallel");
  Object.defineProperty(exports, "LlamaParallelAPI", { enumerable: true, get: function () { return parallel_1.LlamaParallelAPI; } });
+ const utils_1 = require("./utils");
  __exportStar(require("./binding"), exports);
- exports.MTMD_DEFAULT_MEDIA_MARKER = '<__media__>';
  const mods = {};
  const logListeners = [];
  const logCallback = (level, text) => {
@@ -83,60 +83,9 @@ class LlamaContextWrapper {
  isLlamaChatSupported() {
  return !!this.ctx.getModelInfo().chatTemplates.llamaChat;
  }
- _formatMediaChat(messages) {
- if (!messages)
- return {
- messages,
- has_media: false,
- };
- const mediaPaths = [];
- return {
- messages: messages.map((msg) => {
- if (Array.isArray(msg.content)) {
- const content = msg.content.map((part) => {
- var _a;
- // Handle multimodal content
- if (part.type === 'image_url') {
- let path = ((_a = part.image_url) === null || _a === void 0 ? void 0 : _a.url) || '';
- mediaPaths.push(path);
- return {
- type: 'text',
- text: exports.MTMD_DEFAULT_MEDIA_MARKER,
- };
- }
- else if (part.type === 'input_audio') {
- const { input_audio: audio } = part;
- if (!audio)
- throw new Error('input_audio is required');
- const { format } = audio;
- if (format != 'wav' && format != 'mp3') {
- throw new Error(`Unsupported audio format: ${format}`);
- }
- if (audio.url) {
- const path = audio.url.replace(/file:\/\//, '');
- mediaPaths.push(path);
- }
- else if (audio.data) {
- mediaPaths.push(audio.data);
- }
- return {
- type: 'text',
- text: exports.MTMD_DEFAULT_MEDIA_MARKER,
- };
- }
- return part;
- });
- return Object.assign(Object.assign({}, msg), { content });
- }
- return msg;
- }),
- has_media: mediaPaths.length > 0,
- media_paths: mediaPaths,
- };
- }
  getFormattedChat(messages, template, params) {
  var _a;
- const { messages: chat, has_media, media_paths, } = this._formatMediaChat(messages);
+ const { messages: chat, has_media, media_paths, } = (0, utils_1.formatMediaChat)(messages);
  const useJinja = this.isJinjaSupported() && (params === null || params === void 0 ? void 0 : params.jinja);
  let tmpl;
  if (template)
@@ -170,7 +119,7 @@ class LlamaContextWrapper {
  media_paths }, jinjaResult);
  }
  completion(options, callback) {
- const { messages, media_paths = options.media_paths } = this._formatMediaChat(options.messages);
+ const { messages, media_paths = options.media_paths } = (0, utils_1.formatMediaChat)(options.messages);
  return this.ctx.completion(Object.assign(Object.assign({}, options), { messages, media_paths: options.media_paths || media_paths }), callback || (() => { }));
  }
  stopCompletion() {
@@ -244,12 +193,12 @@ class LlamaContextWrapper {
  return this.ctx.decodeAudioTokens(tokens);
  }
  }
- const loadModel = (options) => __awaiter(void 0, void 0, void 0, function* () {
+ const loadModel = (options, onProgress) => __awaiter(void 0, void 0, void 0, function* () {
  var _a, _b;
  const variant = (_a = options.lib_variant) !== null && _a !== void 0 ? _a : 'default';
  (_b = mods[variant]) !== null && _b !== void 0 ? _b : (mods[variant] = yield (0, binding_1.loadModule)(options.lib_variant));
  refreshNativeLogSetup();
- const nativeCtx = new mods[variant].LlamaContext(options);
+ const nativeCtx = new mods[variant].LlamaContext(options, onProgress);
  return new LlamaContextWrapper(nativeCtx);
  });
  exports.loadModel = loadModel;
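`loadModel` now accepts an optional progress callback as its second argument and forwards it to the native `LlamaContext` constructor; per the native changes further down, it is invoked with an increasing percentage (0–100) while the model file is loading. A minimal usage sketch:

```js
import { loadModel } from '@fugood/llama.node'

// Second argument: model-loading progress callback (percentage, 0–100).
const context = await loadModel(
  {
    model: 'path/to/gguf/model',
    n_ctx: 2048,
    n_gpu_layers: 99,
  },
  (progress) => {
    console.log(`Loading model: ${progress}%`)
  },
)
```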
package/lib/index.ts CHANGED
@@ -19,12 +19,11 @@ import type {
  } from './binding'
  import { BUILD_NUMBER, BUILD_COMMIT } from './version'
  import { LlamaParallelAPI } from './parallel'
+ import { formatMediaChat } from './utils'
 
  export * from './binding'
  export { LlamaParallelAPI }
 
- export const MTMD_DEFAULT_MEDIA_MARKER = '<__media__>'
-
  export interface LlamaModelOptionsExtended extends LlamaModelOptions {
  lib_variant?: LibVariant
  }
@@ -104,63 +103,6 @@ class LlamaContextWrapper {
  return !!this.ctx.getModelInfo().chatTemplates.llamaChat
  }
 
- _formatMediaChat(messages: ChatMessage[] | undefined): {
- messages: ChatMessage[] | undefined
- has_media: boolean
- media_paths?: string[]
- } {
- if (!messages)
- return {
- messages,
- has_media: false,
- }
- const mediaPaths: string[] = []
- return {
- messages: messages.map((msg) => {
- if (Array.isArray(msg.content)) {
- const content = msg.content.map((part) => {
- // Handle multimodal content
- if (part.type === 'image_url') {
- let path = part.image_url?.url || ''
- mediaPaths.push(path)
- return {
- type: 'text',
- text: MTMD_DEFAULT_MEDIA_MARKER,
- }
- } else if (part.type === 'input_audio') {
- const { input_audio: audio } = part
- if (!audio) throw new Error('input_audio is required')
-
- const { format } = audio
- if (format != 'wav' && format != 'mp3') {
- throw new Error(`Unsupported audio format: ${format}`)
- }
- if (audio.url) {
- const path = audio.url.replace(/file:\/\//, '')
- mediaPaths.push(path)
- } else if (audio.data) {
- mediaPaths.push(audio.data)
- }
- return {
- type: 'text',
- text: MTMD_DEFAULT_MEDIA_MARKER,
- }
- }
- return part
- })
-
- return {
- ...msg,
- content,
- }
- }
- return msg
- }),
- has_media: mediaPaths.length > 0,
- media_paths: mediaPaths,
- }
- }
-
  getFormattedChat(
  messages: ChatMessage[],
  template?: string,
@@ -180,7 +122,7 @@ class LlamaContextWrapper {
  messages: chat,
  has_media,
  media_paths,
- } = this._formatMediaChat(messages)
+ } = formatMediaChat(messages)
 
  const useJinja = this.isJinjaSupported() && params?.jinja
  let tmpl
@@ -228,7 +170,7 @@ class LlamaContextWrapper {
  callback?: (token: LlamaCompletionToken) => void,
  ): Promise<LlamaCompletionResult> {
  const { messages, media_paths = options.media_paths } =
- this._formatMediaChat(options.messages)
+ formatMediaChat(options.messages)
  return this.ctx.completion(
  {
  ...options,
@@ -357,12 +299,13 @@ class LlamaContextWrapper {
 
  export const loadModel = async (
  options: LlamaModelOptionsExtended,
+ onProgress?: (progress: number) => void,
  ): Promise<LlamaContextWrapper> => {
  const variant = options.lib_variant ?? 'default'
  mods[variant] ??= await loadModule(options.lib_variant)
  refreshNativeLogSetup()
 
- const nativeCtx = new mods[variant].LlamaContext(options)
+ const nativeCtx = new mods[variant].LlamaContext(options, onProgress)
  return new LlamaContextWrapper(nativeCtx)
  }
 
package/lib/parallel.js CHANGED
@@ -10,6 +10,7 @@ var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, ge
  };
  Object.defineProperty(exports, "__esModule", { value: true });
  exports.LlamaParallelAPI = void 0;
+ const utils_1 = require("./utils");
  class LlamaParallelAPI {
  constructor(context) {
  this.enabled = false;
@@ -87,8 +88,9 @@ class LlamaParallelAPI {
  }
  }
  : undefined;
+ const { messages, media_paths = options.media_paths } = (0, utils_1.formatMediaChat)(options.messages);
  // Queue the completion immediately (this is synchronous!)
- const { requestId } = this.context.queueCompletion(options, tokenCallback ||
+ const { requestId } = this.context.queueCompletion(Object.assign(Object.assign({}, options), { messages, media_paths: media_paths }), tokenCallback ||
  ((error, result) => {
  if (error) {
  const pendingReq = this.pendingRequests.get(result === null || result === void 0 ? void 0 : result.requestId);
package/lib/parallel.ts CHANGED
@@ -5,6 +5,7 @@ import type {
  LlamaCompletionToken,
  RerankParams,
  } from './binding'
+ import { formatMediaChat } from './utils'
 
  export class LlamaParallelAPI {
  private context: LlamaContext
@@ -109,9 +110,16 @@ export class LlamaParallelAPI {
  }
  : undefined
 
+ const { messages, media_paths = options.media_paths } = formatMediaChat(
+ options.messages,
+ )
  // Queue the completion immediately (this is synchronous!)
  const { requestId } = this.context.queueCompletion(
- options,
+ {
+ ...options,
+ messages,
+ media_paths: media_paths,
+ },
  tokenCallback ||
  ((error, result) => {
  if (error) {
package/lib/utils.js ADDED
@@ -0,0 +1,56 @@
+ "use strict";
+ Object.defineProperty(exports, "__esModule", { value: true });
+ exports.formatMediaChat = exports.MTMD_DEFAULT_MEDIA_MARKER = void 0;
+ exports.MTMD_DEFAULT_MEDIA_MARKER = '<__media__>';
+ const formatMediaChat = (messages) => {
+ if (!messages)
+ return {
+ messages,
+ has_media: false,
+ };
+ const mediaPaths = [];
+ return {
+ messages: messages.map((msg) => {
+ if (Array.isArray(msg.content)) {
+ const content = msg.content.map((part) => {
+ var _a;
+ // Handle multimodal content
+ if (part.type === 'image_url') {
+ let path = ((_a = part.image_url) === null || _a === void 0 ? void 0 : _a.url) || '';
+ mediaPaths.push(path);
+ return {
+ type: 'text',
+ text: exports.MTMD_DEFAULT_MEDIA_MARKER,
+ };
+ }
+ else if (part.type === 'input_audio') {
+ const { input_audio: audio } = part;
+ if (!audio)
+ throw new Error('input_audio is required');
+ const { format } = audio;
+ if (format != 'wav' && format != 'mp3') {
+ throw new Error(`Unsupported audio format: ${format}`);
+ }
+ if (audio.url) {
+ const path = audio.url.replace(/file:\/\//, '');
+ mediaPaths.push(path);
+ }
+ else if (audio.data) {
+ mediaPaths.push(audio.data);
+ }
+ return {
+ type: 'text',
+ text: exports.MTMD_DEFAULT_MEDIA_MARKER,
+ };
+ }
+ return part;
+ });
+ return Object.assign(Object.assign({}, msg), { content });
+ }
+ return msg;
+ }),
+ has_media: mediaPaths.length > 0,
+ media_paths: mediaPaths,
+ };
+ };
+ exports.formatMediaChat = formatMediaChat;
package/lib/utils.ts ADDED
@@ -0,0 +1,63 @@
+
+ import type {
+ ChatMessage,
+ } from './binding'
+
+ export const MTMD_DEFAULT_MEDIA_MARKER = '<__media__>'
+
+ export const formatMediaChat = (messages: ChatMessage[] | undefined): {
+ messages: ChatMessage[] | undefined
+ has_media: boolean
+ media_paths?: string[]
+ } => {
+ if (!messages)
+ return {
+ messages,
+ has_media: false,
+ }
+ const mediaPaths: string[] = []
+ return {
+ messages: messages.map((msg) => {
+ if (Array.isArray(msg.content)) {
+ const content = msg.content.map((part) => {
+ // Handle multimodal content
+ if (part.type === 'image_url') {
+ let path = part.image_url?.url || ''
+ mediaPaths.push(path)
+ return {
+ type: 'text',
+ text: MTMD_DEFAULT_MEDIA_MARKER,
+ }
+ } else if (part.type === 'input_audio') {
+ const { input_audio: audio } = part
+ if (!audio) throw new Error('input_audio is required')
+
+ const { format } = audio
+ if (format != 'wav' && format != 'mp3') {
+ throw new Error(`Unsupported audio format: ${format}`)
+ }
+ if (audio.url) {
+ const path = audio.url.replace(/file:\/\//, '')
+ mediaPaths.push(path)
+ } else if (audio.data) {
+ mediaPaths.push(audio.data)
+ }
+ return {
+ type: 'text',
+ text: MTMD_DEFAULT_MEDIA_MARKER,
+ }
+ }
+ return part
+ })
+
+ return {
+ ...msg,
+ content,
+ }
+ }
+ return msg
+ }),
+ has_media: mediaPaths.length > 0,
+ media_paths: mediaPaths,
+ }
+ }
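For reference, `formatMediaChat` is the media-handling logic previously inlined as `LlamaContextWrapper._formatMediaChat`, now shared by `index` and `parallel`: it swaps `image_url` and `input_audio` parts for the `<__media__>` text marker and collects the referenced paths. A small sketch of the transformation with a made-up message (note the helper lives in `lib/utils` inside the package and is not re-exported from the package root):

```js
const { formatMediaChat, MTMD_DEFAULT_MEDIA_MARKER } = require('@fugood/llama.node/lib/utils')

const { messages, has_media, media_paths } = formatMediaChat([
  {
    role: 'user',
    content: [
      { type: 'text', text: 'Describe this image' },
      { type: 'image_url', image_url: { url: 'path/to/image.png' } },
    ],
  },
])

// has_media === true
// media_paths === ['path/to/image.png']
// The image part is rewritten to { type: 'text', text: MTMD_DEFAULT_MEDIA_MARKER }
```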
package/package.json CHANGED
@@ -1,7 +1,7 @@
  {
  "name": "@fugood/llama.node",
  "access": "public",
- "version": "1.3.0-rc.2",
+ "version": "1.3.0-rc.5",
  "description": "An another Node binding of llama.cpp",
  "main": "lib/index.js",
  "scripts": {
@@ -72,19 +72,19 @@
  "CMakeLists.txt"
  ],
  "optionalDependencies": {
- "@fugood/node-llama-linux-x64": "1.3.0-rc.2",
- "@fugood/node-llama-linux-x64-vulkan": "1.3.0-rc.2",
- "@fugood/node-llama-linux-x64-cuda": "1.3.0-rc.2",
- "@fugood/node-llama-linux-arm64": "1.3.0-rc.2",
- "@fugood/node-llama-linux-arm64-vulkan": "1.3.0-rc.2",
- "@fugood/node-llama-linux-arm64-cuda": "1.3.0-rc.2",
- "@fugood/node-llama-win32-x64": "1.3.0-rc.2",
- "@fugood/node-llama-win32-x64-vulkan": "1.3.0-rc.2",
- "@fugood/node-llama-win32-x64-cuda": "1.3.0-rc.2",
- "@fugood/node-llama-win32-arm64": "1.3.0-rc.2",
- "@fugood/node-llama-win32-arm64-vulkan": "1.3.0-rc.2",
- "@fugood/node-llama-darwin-x64": "1.3.0-rc.2",
- "@fugood/node-llama-darwin-arm64": "1.3.0-rc.2"
+ "@fugood/node-llama-linux-x64": "1.3.0-rc.5",
+ "@fugood/node-llama-linux-x64-vulkan": "1.3.0-rc.5",
+ "@fugood/node-llama-linux-x64-cuda": "1.3.0-rc.5",
+ "@fugood/node-llama-linux-arm64": "1.3.0-rc.5",
+ "@fugood/node-llama-linux-arm64-vulkan": "1.3.0-rc.5",
+ "@fugood/node-llama-linux-arm64-cuda": "1.3.0-rc.5",
+ "@fugood/node-llama-win32-x64": "1.3.0-rc.5",
+ "@fugood/node-llama-win32-x64-vulkan": "1.3.0-rc.5",
+ "@fugood/node-llama-win32-x64-cuda": "1.3.0-rc.5",
+ "@fugood/node-llama-win32-arm64": "1.3.0-rc.5",
+ "@fugood/node-llama-win32-arm64-vulkan": "1.3.0-rc.5",
+ "@fugood/node-llama-darwin-x64": "1.3.0-rc.5",
+ "@fugood/node-llama-darwin-arm64": "1.3.0-rc.5"
  },
  "devDependencies": {
  "@babel/preset-env": "^7.24.4",
@@ -221,7 +221,7 @@ static int32_t pooling_type_from_str(const std::string &s) {
  }
 
  // construct({ model, embedding, n_ctx, n_batch, n_threads, n_gpu_layers,
- // use_mlock, use_mmap }): LlamaContext throws error
+ // use_mlock, use_mmap }, onProgress?: (progress: number) => void): LlamaContext throws error
  LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
  : Napi::ObjectWrap<LlamaContext>(info) {
  Napi::Env env = info.Env();
@@ -230,6 +230,16 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
  }
  auto options = info[0].As<Napi::Object>();
 
+ // Check if progress callback is provided
+ bool has_progress_callback = info.Length() >= 2 && info[1].IsFunction();
+ if (has_progress_callback) {
+ _progress_tsfn = Napi::ThreadSafeFunction::New(
+ env, info[1].As<Napi::Function>(), "Model Loading Progress", 0, 1,
+ [](Napi::Env) {
+ // Finalizer callback
+ });
+ }
+
  common_params params;
  params.model.path = get_option<std::string>(options, "model", "");
  if (params.model.path.empty()) {
@@ -323,12 +333,55 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
 
  // Use rn-llama context instead of direct session
  _rn_ctx = new llama_rn_context();
+ _rn_ctx->is_load_interrupted = false;
+ _rn_ctx->loading_progress = 0;
+
+ // Set up progress callback if provided
+ if (has_progress_callback) {
+ params.load_progress_callback = [](float progress, void *user_data) {
+ LlamaContext *self = static_cast<LlamaContext *>(user_data);
+ unsigned int percentage = static_cast<unsigned int>(100 * progress);
+
+ // Only call callback if progress increased
+ if (percentage > self->_rn_ctx->loading_progress) {
+ self->_rn_ctx->loading_progress = percentage;
+
+ // Create a heap-allocated copy of the percentage
+ auto *data = new unsigned int(percentage);
+
+ // Queue callback to be executed on the JavaScript thread
+ auto status = self->_progress_tsfn.NonBlockingCall(
+ data, [](Napi::Env env, Napi::Function jsCallback, unsigned int *data) {
+ jsCallback.Call({Napi::Number::New(env, *data)});
+ delete data;
+ });
+
+ // If the call failed, clean up the data
+ if (status != napi_ok) {
+ delete data;
+ }
+ }
+
+ // Return true to continue loading, false to interrupt
+ return !self->_rn_ctx->is_load_interrupted;
+ };
+ params.load_progress_callback_user_data = this;
+ }
+
  if (!_rn_ctx->loadModel(params)) {
+ if (has_progress_callback) {
+ _progress_tsfn.Release();
+ }
  delete _rn_ctx;
  _rn_ctx = nullptr;
  Napi::TypeError::New(env, "Failed to load model").ThrowAsJavaScriptException();
  }
 
+ // Release progress callback after model is loaded
+ if (has_progress_callback) {
+ _progress_tsfn.Release();
+ }
+
  // Handle LoRA adapters through rn-llama
  if (!lora.empty()) {
  _rn_ctx->applyLoraAdapters(lora);
@@ -343,6 +396,11 @@ LlamaContext::~LlamaContext() {
  _context_valid->store(false);
  }
 
+ // Interrupt model loading if in progress
+ if (_rn_ctx) {
+ _rn_ctx->is_load_interrupted = true;
+ }
+
  // The DisposeWorker is responsible for cleanup of _rn_ctx
  // If _rn_ctx is still not null here, it means disposal was not properly initiated
  if (_rn_ctx) {
@@ -78,4 +78,7 @@ private:
  // Validity flag for async callbacks to prevent use-after-free
  // Shared pointer ensures callbacks can safely check if context is still alive
  std::shared_ptr<std::atomic<bool>> _context_valid;
+
+ // Progress callback support for model loading
+ Napi::ThreadSafeFunction _progress_tsfn;
  };