@fugood/llama.node 1.2.6 → 1.3.0-rc.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/CMakeLists.txt CHANGED
@@ -124,6 +124,8 @@ include_directories(
   ${CMAKE_JS_INC}
   "src/llama.cpp"
   "src/llama.cpp/src"
+  "src/llama.cpp/ggml/include"
+  "src/llama.cpp/ggml/src"
   "src/tools/mtmd"
 )
 
@@ -137,6 +139,7 @@ file(
   "src/LlamaCompletionWorker.h"
   "src/LlamaContext.cpp"
   "src/LlamaContext.h"
+  "src/LlamaContext_parallel.cpp"
   "src/EmbeddingWorker.cpp"
   "src/EmbeddingWorker.h"
   "src/RerankWorker.cpp"
package/lib/binding.js CHANGED
@@ -15,23 +15,13 @@ var __setModuleDefault = (this && this.__setModuleDefault) || (Object.create ? (
 }) : function(o, v) {
     o["default"] = v;
 });
-var __importStar = (this && this.__importStar) || (function () {
-    var ownKeys = function(o) {
-        ownKeys = Object.getOwnPropertyNames || function (o) {
-            var ar = [];
-            for (var k in o) if (Object.prototype.hasOwnProperty.call(o, k)) ar[ar.length] = k;
-            return ar;
-        };
-        return ownKeys(o);
-    };
-    return function (mod) {
-        if (mod && mod.__esModule) return mod;
-        var result = {};
-        if (mod != null) for (var k = ownKeys(mod), i = 0; i < k.length; i++) if (k[i] !== "default") __createBinding(result, mod, k[i]);
-        __setModuleDefault(result, mod);
-        return result;
-    };
-})();
+var __importStar = (this && this.__importStar) || function (mod) {
+    if (mod && mod.__esModule) return mod;
+    var result = {};
+    if (mod != null) for (var k in mod) if (k !== "default" && Object.prototype.hasOwnProperty.call(mod, k)) __createBinding(result, mod, k);
+    __setModuleDefault(result, mod);
+    return result;
+};
 var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
     function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
     return new (P || (P = Promise))(function (resolve, reject) {
package/lib/binding.ts CHANGED
@@ -25,6 +25,12 @@ export type LlamaModelOptions = {
   n_ctx?: number
   n_batch?: number
   n_ubatch?: number
+  /**
+   * Number of parallel sequences to support (sets n_seq_max).
+   * This determines the maximum number of parallel slots that can be used.
+   * Default: 8
+   */
+  n_parallel?: number
   n_threads?: number
   n_gpu_layers?: number
   flash_attn_type?: 'auto' | 'on' | 'off'
@@ -157,6 +163,36 @@ export type LlamaCompletionOptions = {
   n_probs?: number
 }
 
+/**
+ * Parameters for parallel completion requests (queueCompletion).
+ * Extends LlamaCompletionOptions with parallel-mode specific options.
+ */
+export type LlamaParallelCompletionOptions = LlamaCompletionOptions & {
+  /**
+   * File path to load session state from before processing.
+   * This allows you to resume from a previously saved completion state.
+   * Use with `save_state_path` to enable conversation continuity across requests.
+   * Example: `'/path/to/session.bin'` or `'file:///path/to/session.bin'`
+   */
+  load_state_path?: string
+
+  /**
+   * File path to save session state to after completion.
+   * The session state will be saved to this file path when the completion finishes.
+   * You can then pass this path to `load_state_path` in a subsequent request to resume.
+   * Example: `'/path/to/session.bin'` or `'file:///path/to/session.bin'`
+   */
+  save_state_path?: string
+
+  /**
+   * Number of tokens to save when saving session state.
+   * If not specified or <= 0, all tokens will be saved.
+   * Use this to limit the size of saved session files.
+   * Example: `512` to save only the last 512 tokens
+   */
+  save_state_size?: number
+}
+
 export type TokenProbability = {
   tok_str: string
   prob: number
@@ -200,6 +236,36 @@ export type LlamaCompletionToken = {
   completion_probabilities?: CompletionProbability[]
 }
 
+/**
+ * Result from a parallel completion request (queueCompletion callback).
+ * Extends the basic completion result with per-slot timing information.
+ */
+export type LlamaParallelCompletionResult = {
+  requestId: number
+  text: string
+  reasoning_content?: string
+  content?: string
+  tool_calls?: ToolCall[]
+  chat_format: number
+  stopped_eos: boolean
+  stopped_limit: boolean
+  stopped_word: boolean
+  context_full: boolean
+  tokens_evaluated: number
+  tokens_predicted: number
+  timings: {
+    cache_n: number
+    prompt_n: number
+    prompt_ms: number
+    prompt_per_token_ms: number
+    prompt_per_second: number
+    predicted_n: number
+    predicted_ms: number
+    predicted_per_token_ms: number
+    predicted_per_second: number
+  }
+}
+
 export type TokenizeResult = {
   tokens: Int32Array
   has_media: boolean
@@ -221,6 +287,14 @@ export type RerankResult = {
   index: number
 }
 
+export type BackendDeviceInfo = {
+  backend: string
+  type: string
+  deviceName: string
+  maxMemorySize: number
+  metadata?: Record<string, any>
+}
+
 export type ModelInfo = {
   desc: string
   nEmbd: number
@@ -271,7 +345,7 @@ export type JinjaFormattedChatResult = {
   prompt: string
   chat_format: number
   grammar: string
-  grammea_lazy: boolean
+  grammar_lazy: boolean
   grammar_triggers: Array<{
     type: number
     value: string
@@ -404,12 +478,76 @@ export interface LlamaContext {
    */
  decodeAudioTokens(tokens: number[]|Int32Array): Promise<Float32Array>
 
+  // Parallel decoding methods
+
+  /**
+   * Enable parallel decoding mode
+   * @param params Configuration for parallel mode
+   * @returns boolean indicating if successful
+   */
+  enableParallelMode(params: { n_parallel?: number, n_batch?: number }): boolean
+
+  /**
+   * Disable parallel decoding mode
+   */
+  disableParallelMode(): void
+
+  /**
+   * Queue a completion request for parallel processing
+   * @param options Completion options with parallel-specific state management
+   * @param callback Optional callback that receives tokens during generation and final result
+   * @returns Object with requestId
+   */
+  queueCompletion(
+    options: LlamaParallelCompletionOptions,
+    callback?: (error: any, result: LlamaParallelCompletionResult) => void,
+  ): { requestId: number }
+
+  /**
+   * Queue an embedding request for parallel processing
+   * @param text Text to embed
+   * @param params Optional embedding parameters
+   * @param callback Optional result callback
+   * @returns Object with requestId
+   */
+  queueEmbedding(
+    text: string,
+    params?: { embd_normalize?: number },
+    callback?: (error: any, result: any) => void,
+  ): { requestId: number }
+
+  /**
+   * Queue a rerank request for parallel processing
+   * @param query Query text
+   * @param documents Documents to rank
+   * @param params Optional rerank parameters
+   * @param callback Optional result callback
+   * @returns Object with requestId
+   */
+  queueRerank(
+    query: string,
+    documents: string[],
+    params?: RerankParams,
+    callback?: (error: any, result: any) => void,
+  ): { requestId: number }
+
+  /**
+   * Cancel a queued request
+   * @param requestId Request ID to cancel
+   */
+  cancelRequest(requestId: number): void
+
   // static
   loadModelInfo(path: string, skip: string[]): Promise<GGUFModelInfo>
   toggleNativeLog(
     enable: boolean,
     callback: (level: string, text: string) => void,
   ): void
+  /**
+   * Get information about available backend devices
+   * @returns Array of backend device information
+   */
+  getBackendDevicesInfo(): BackendDeviceInfo[]
 }
 
 export interface Module {
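The declarations above are the full parallel-decoding surface of the native addon. Below is a minimal TypeScript sketch (not taken from the package docs) of how it might be driven, assuming a native `LlamaContext` handle is already available, that `prompt` and `n_predict` exist on `LlamaCompletionOptions` (they are not shown in this diff), and that the file paths are placeholders.

```ts
import type {
  LlamaContext,
  LlamaParallelCompletionOptions,
  LlamaParallelCompletionResult,
} from '@fugood/llama.node'

function runParallel(ctx: LlamaContext) {
  // Turn on parallel decoding with 4 slots sharing a 512-token batch.
  if (!ctx.enableParallelMode({ n_parallel: 4, n_batch: 512 })) {
    throw new Error('parallel mode could not be enabled')
  }

  const options: LlamaParallelCompletionOptions = {
    // `prompt` / `n_predict` are assumed completion options, not shown in this diff.
    prompt: 'Write a haiku about GPUs.',
    n_predict: 64,
    // Persist and resume session state between requests (placeholder path).
    load_state_path: '/tmp/session.bin',
    save_state_path: '/tmp/session.bin',
    save_state_size: 512,
  }

  // queueCompletion returns synchronously with a request id; the callback
  // receives streamed tokens and, at the end, the final result object.
  const { requestId } = ctx.queueCompletion(
    options,
    (err, result: LlamaParallelCompletionResult) => {
      if (err) return console.error(err)
      if (result?.text !== undefined) {
        console.log(`request ${result.requestId} finished:`, result.text)
        ctx.disableParallelMode()
      }
    },
  )

  // A queued request can be aborted by id:
  // ctx.cancelRequest(requestId)
  console.log('queued request', requestId)
}
```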
package/lib/index.js CHANGED
@@ -23,10 +23,12 @@ var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, ge
     });
 };
 Object.defineProperty(exports, "__esModule", { value: true });
-exports.BuildInfo = exports.loadLlamaModelInfo = exports.initLlama = exports.loadModel = exports.toggleNativeLog = exports.MTMD_DEFAULT_MEDIA_MARKER = void 0;
+exports.BuildInfo = exports.getBackendDevicesInfo = exports.loadLlamaModelInfo = exports.initLlama = exports.loadModel = exports.toggleNativeLog = exports.MTMD_DEFAULT_MEDIA_MARKER = exports.LlamaParallelAPI = void 0;
 exports.addNativeLogListener = addNativeLogListener;
 const binding_1 = require("./binding");
 const version_1 = require("./version");
+const parallel_1 = require("./parallel");
+Object.defineProperty(exports, "LlamaParallelAPI", { enumerable: true, get: function () { return parallel_1.LlamaParallelAPI; } });
 __exportStar(require("./binding"), exports);
 exports.MTMD_DEFAULT_MEDIA_MARKER = '<__media__>';
 const mods = {};
@@ -66,6 +68,7 @@ const getJsonSchema = (responseFormat) => {
 class LlamaContextWrapper {
     constructor(nativeCtx) {
         this.ctx = nativeCtx;
+        this.parallel = new parallel_1.LlamaParallelAPI(nativeCtx);
     }
     getSystemInfo() {
         return this.ctx.getSystemInfo();
@@ -138,7 +141,6 @@ class LlamaContextWrapper {
         let tmpl;
         if (template)
             tmpl = template; // Force replace if provided
-        const jsonSchema = getJsonSchema(params === null || params === void 0 ? void 0 : params.response_format);
         const result = this.ctx.getFormattedChat(chat, tmpl, {
             jinja: useJinja,
             response_format: params === null || params === void 0 ? void 0 : params.response_format,
@@ -267,6 +269,14 @@ const loadLlamaModelInfo = (path) => __awaiter(void 0, void 0, void 0, function*
     return mods[variant].LlamaContext.loadModelInfo(path, modelInfoSkip);
 });
 exports.loadLlamaModelInfo = loadLlamaModelInfo;
+const getBackendDevicesInfo = (...args_1) => __awaiter(void 0, [...args_1], void 0, function* (variant = 'default') {
+    var _a;
+    (_a = mods[variant]) !== null && _a !== void 0 ? _a : (mods[variant] = yield (0, binding_1.loadModule)(variant));
+    refreshNativeLogSetup();
+    const jsonString = mods[variant].LlamaContext.getBackendDevicesInfo();
+    return JSON.parse(jsonString);
+});
+exports.getBackendDevicesInfo = getBackendDevicesInfo;
 exports.BuildInfo = {
     number: version_1.BUILD_NUMBER,
     commit: version_1.BUILD_COMMIT,
package/lib/index.ts CHANGED
@@ -18,8 +18,10 @@ import type {
   GGUFModelInfo,
 } from './binding'
 import { BUILD_NUMBER, BUILD_COMMIT } from './version'
+import { LlamaParallelAPI } from './parallel'
 
 export * from './binding'
+export { LlamaParallelAPI }
 
 export const MTMD_DEFAULT_MEDIA_MARKER = '<__media__>'
 
@@ -78,9 +80,11 @@ export type FormattedChatResult = {
 
 class LlamaContextWrapper {
   ctx: LlamaContext
+  parallel: LlamaParallelAPI
 
   constructor(nativeCtx: LlamaContext) {
     this.ctx = nativeCtx
+    this.parallel = new LlamaParallelAPI(nativeCtx)
   }
 
   getSystemInfo(): string {
@@ -181,7 +185,6 @@ class LlamaContextWrapper {
     const useJinja = this.isJinjaSupported() && params?.jinja
     let tmpl
     if (template) tmpl = template // Force replace if provided
-    const jsonSchema = getJsonSchema(params?.response_format)
 
     const result = this.ctx.getFormattedChat(chat!, tmpl, {
       jinja: useJinja,
@@ -382,6 +385,15 @@ export const loadLlamaModelInfo = async (
   return mods[variant].LlamaContext.loadModelInfo(path, modelInfoSkip)
 }
 
+export const getBackendDevicesInfo = async (
+  variant: LibVariant = 'default'
+): Promise<import('./binding').BackendDeviceInfo[]> => {
+  mods[variant] ??= await loadModule(variant)
+  refreshNativeLogSetup()
+  const jsonString = mods[variant].LlamaContext.getBackendDevicesInfo()
+  return JSON.parse(jsonString as any)
+}
+
 export const BuildInfo = {
   number: BUILD_NUMBER,
   commit: BUILD_COMMIT,
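For the new `getBackendDevicesInfo` export above, a small usage sketch follows; the `'vulkan'` variant name is only an illustration, since this diff only shows `'default'`:

```ts
import { getBackendDevicesInfo } from '@fugood/llama.node'

const main = async () => {
  // Probe the default library variant for its backends and devices.
  const devices = await getBackendDevicesInfo()
  for (const d of devices) {
    console.log(`${d.backend} / ${d.type}: ${d.deviceName} (maxMemorySize: ${d.maxMemorySize})`)
  }
  // A specific prebuilt variant could be probed too, e.g.:
  // const vulkanDevices = await getBackendDevicesInfo('vulkan')
}

main().catch(console.error)
```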
@@ -0,0 +1,214 @@
+"use strict";
+var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
+    function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
+    return new (P || (P = Promise))(function (resolve, reject) {
+        function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
+        function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
+        function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
+        step((generator = generator.apply(thisArg, _arguments || [])).next());
+    });
+};
+Object.defineProperty(exports, "__esModule", { value: true });
+exports.LlamaParallelAPI = void 0;
+class LlamaParallelAPI {
+    constructor(context) {
+        this.enabled = false;
+        this.pendingRequests = new Map();
+        this.context = context;
+    }
+    /**
+     * Enable parallel decoding mode
+     * @param config Configuration for parallel mode
+     * @returns boolean indicating if successful
+     */
+    enable(config) {
+        return __awaiter(this, void 0, void 0, function* () {
+            const defaultConfig = { n_parallel: 2, n_batch: 512 };
+            const result = this.context.enableParallelMode(Object.assign(Object.assign({}, defaultConfig), config));
+            this.enabled = result;
+            return result;
+        });
+    }
+    /**
+     * Disable parallel decoding mode
+     */
+    disable() {
+        this.context.disableParallelMode();
+        this.enabled = false;
+    }
+    /**
+     * Configure parallel decoding mode (enables if not already enabled)
+     * @param config Configuration for parallel mode
+     * @returns boolean indicating if successful
+     */
+    configure(config) {
+        return __awaiter(this, void 0, void 0, function* () {
+            return this.enable(config);
+        });
+    }
+    /**
+     * Queue a completion request for parallel processing
+     * @param options Completion options
+     * @param onToken Optional callback for each token
+     * @returns Object with requestId, promise for result, and stop function
+     */
+    completion(options, onToken) {
+        return __awaiter(this, void 0, void 0, function* () {
+            if (!this.enabled) {
+                throw new Error('Parallel mode is not enabled. Call enable() first.');
+            }
+            const tokenCallback = onToken
+                ? (error, result) => {
+                    if (error) {
+                        console.error('Token callback error:', error);
+                        // Handle completion error
+                        const pendingReq = this.pendingRequests.get(result === null || result === void 0 ? void 0 : result.requestId);
+                        if (pendingReq) {
+                            pendingReq.reject(error);
+                            this.pendingRequests.delete(result === null || result === void 0 ? void 0 : result.requestId);
+                        }
+                        return;
+                    }
+                    // Check if this is a token callback or final result
+                    if (result) {
+                        if (result.token !== undefined) {
+                            // This is a token callback
+                            onToken(result.requestId, result);
+                        }
+                        else if (result.text !== undefined ||
+                            result.content !== undefined) {
+                            // This is the final result
+                            const pendingReq = this.pendingRequests.get(result.requestId);
+                            if (pendingReq) {
+                                pendingReq.resolve(result);
+                                this.pendingRequests.delete(result.requestId);
+                            }
+                        }
+                    }
+                }
+                : undefined;
+            // Queue the completion immediately (this is synchronous!)
+            const { requestId } = this.context.queueCompletion(options, tokenCallback ||
+                ((error, result) => {
+                    if (error) {
+                        const pendingReq = this.pendingRequests.get(result === null || result === void 0 ? void 0 : result.requestId);
+                        if (pendingReq) {
+                            pendingReq.reject(error);
+                            this.pendingRequests.delete(result === null || result === void 0 ? void 0 : result.requestId);
+                        }
+                    }
+                    else if (result &&
+                        (result.text !== undefined || result.content !== undefined)) {
+                        // Final result for non-streaming
+                        const pendingReq = this.pendingRequests.get(result.requestId);
+                        if (pendingReq) {
+                            pendingReq.resolve(result);
+                            this.pendingRequests.delete(result.requestId);
+                        }
+                    }
+                }));
+            // Create promise for final result
+            const promise = new Promise((resolveResult, rejectResult) => {
+                this.pendingRequests.set(requestId, {
+                    resolve: resolveResult,
+                    reject: rejectResult,
+                });
+            });
+            // Create stop function
+            const stop = () => {
+                this.context.cancelRequest(requestId);
+                const pendingReq = this.pendingRequests.get(requestId);
+                if (pendingReq) {
+                    pendingReq.reject(new Error('Request cancelled'));
+                    this.pendingRequests.delete(requestId);
+                }
+            };
+            // Return immediately without wrapping in a Promise
+            return {
+                requestId,
+                promise,
+                stop,
+            };
+        });
+    }
+    /**
+     * Queue an embedding request for parallel processing
+     * @param text Text to embed
+     * @param params Optional embedding parameters
+     * @returns Object with requestId and promise for result
+     */
+    embedding(text, params) {
+        return __awaiter(this, void 0, void 0, function* () {
+            if (!this.enabled) {
+                throw new Error('Parallel mode is not enabled. Call enable() first.');
+            }
+            // Create promise for result
+            let resolveResult;
+            let rejectResult;
+            const promise = new Promise((res, rej) => {
+                resolveResult = res;
+                rejectResult = rej;
+            });
+            // Queue the embedding immediately (this is synchronous!)
+            const { requestId } = this.context.queueEmbedding(text, params, (error, result) => {
+                if (error) {
+                    rejectResult(error);
+                }
+                else {
+                    resolveResult(result);
+                }
+            });
+            // Return immediately without wrapping in a Promise
+            return {
+                requestId,
+                promise,
+            };
+        });
+    }
+    /**
+     * Queue a rerank request for parallel processing
+     * @param query Query text
+     * @param documents Documents to rank
+     * @param params Optional rerank parameters
+     * @returns Object with requestId and promise for results
+     */
+    rerank(query, documents, params) {
+        return __awaiter(this, void 0, void 0, function* () {
+            if (!this.enabled) {
+                throw new Error('Parallel mode is not enabled. Call enable() first.');
+            }
+            // Create promise for result
+            let resolveResult;
+            let rejectResult;
+            const promise = new Promise((res, rej) => {
+                resolveResult = res;
+                rejectResult = rej;
+            });
+            // Queue the rerank immediately (this is synchronous!)
+            const { requestId } = this.context.queueRerank(query, documents, params, (error, result) => {
+                if (error) {
+                    rejectResult(error);
+                }
+                else {
+                    // Add document text to results and sort by score
+                    const enrichedResults = result.results
+                        .map((r) => (Object.assign(Object.assign({}, r), { document: documents[r.index] })))
+                        .sort((a, b) => b.score - a.score);
+                    resolveResult(enrichedResults);
+                }
+            });
+            // Return immediately without wrapping in a Promise
+            return {
+                requestId,
+                promise,
+            };
+        });
+    }
+    /**
+     * Check if parallel mode is enabled
+     */
+    isEnabled() {
+        return this.enabled;
+    }
+}
+exports.LlamaParallelAPI = LlamaParallelAPI;
@@ -0,0 +1,273 @@
+// Parallel decoding API implementation for llama.node
+import type {
+  LlamaContext,
+  LlamaCompletionOptions,
+  LlamaCompletionToken,
+  RerankParams,
+} from './binding'
+
+export class LlamaParallelAPI {
+  private context: LlamaContext
+  private enabled: boolean = false
+  private pendingRequests = new Map<
+    number,
+    {
+      resolve: (value: any) => void
+      reject: (reason?: any) => void
+    }
+  >()
+
+  constructor(context: LlamaContext) {
+    this.context = context
+  }
+
+  /**
+   * Enable parallel decoding mode
+   * @param config Configuration for parallel mode
+   * @returns boolean indicating if successful
+   */
+  async enable(config?: {
+    n_parallel?: number
+    n_batch?: number
+  }): Promise<boolean> {
+    const defaultConfig = { n_parallel: 2, n_batch: 512 }
+    const result = this.context.enableParallelMode({
+      ...defaultConfig,
+      ...config,
+    })
+    this.enabled = result
+    return result
+  }
+
+  /**
+   * Disable parallel decoding mode
+   */
+  disable(): void {
+    this.context.disableParallelMode()
+    this.enabled = false
+  }
+
+  /**
+   * Configure parallel decoding mode (enables if not already enabled)
+   * @param config Configuration for parallel mode
+   * @returns boolean indicating if successful
+   */
+  async configure(config: {
+    n_parallel?: number
+    n_batch?: number
+  }): Promise<boolean> {
+    return this.enable(config)
+  }
+
+  /**
+   * Queue a completion request for parallel processing
+   * @param options Completion options
+   * @param onToken Optional callback for each token
+   * @returns Object with requestId, promise for result, and stop function
+   */
+  async completion(
+    options: LlamaCompletionOptions,
+    onToken?: (requestId: number, data: LlamaCompletionToken) => void,
+  ): Promise<{
+    requestId: number
+    promise: Promise<any>
+    stop: () => void
+  }> {
+    if (!this.enabled) {
+      throw new Error('Parallel mode is not enabled. Call enable() first.')
+    }
+
+    const tokenCallback = onToken
+      ? (error: any, result: any) => {
+          if (error) {
+            console.error('Token callback error:', error)
+            // Handle completion error
+            const pendingReq = this.pendingRequests.get(result?.requestId)
+            if (pendingReq) {
+              pendingReq.reject(error)
+              this.pendingRequests.delete(result?.requestId)
+            }
+            return
+          }
+          // Check if this is a token callback or final result
+          if (result) {
+            if (result.token !== undefined) {
+              // This is a token callback
+              onToken(result.requestId, result)
+            } else if (
+              result.text !== undefined ||
+              result.content !== undefined
+            ) {
+              // This is the final result
+              const pendingReq = this.pendingRequests.get(result.requestId)
+              if (pendingReq) {
+                pendingReq.resolve(result)
+                this.pendingRequests.delete(result.requestId)
+              }
+            }
+          }
+        }
+      : undefined
+
+    // Queue the completion immediately (this is synchronous!)
+    const { requestId } = this.context.queueCompletion(
+      options,
+      tokenCallback ||
+        ((error, result) => {
+          if (error) {
+            const pendingReq = this.pendingRequests.get(result?.requestId)
+            if (pendingReq) {
+              pendingReq.reject(error)
+              this.pendingRequests.delete(result?.requestId)
+            }
+          } else if (
+            result &&
+            (result.text !== undefined || result.content !== undefined)
+          ) {
+            // Final result for non-streaming
+            const pendingReq = this.pendingRequests.get(result.requestId)
+            if (pendingReq) {
+              pendingReq.resolve(result)
+              this.pendingRequests.delete(result.requestId)
+            }
+          }
+        }),
+    )
+
+    // Create promise for final result
+    const promise = new Promise((resolveResult, rejectResult) => {
+      this.pendingRequests.set(requestId, {
+        resolve: resolveResult,
+        reject: rejectResult,
+      })
+    })
+
+    // Create stop function
+    const stop = () => {
+      this.context.cancelRequest(requestId)
+      const pendingReq = this.pendingRequests.get(requestId)
+      if (pendingReq) {
+        pendingReq.reject(new Error('Request cancelled'))
+        this.pendingRequests.delete(requestId)
+      }
+    }
+
+    // Return immediately without wrapping in a Promise
+    return {
+      requestId,
+      promise,
+      stop,
+    }
+  }
+
+  /**
+   * Queue an embedding request for parallel processing
+   * @param text Text to embed
+   * @param params Optional embedding parameters
+   * @returns Object with requestId and promise for result
+   */
+  async embedding(
+    text: string,
+    params?: { embd_normalize?: number },
+  ): Promise<{
+    requestId: number
+    promise: Promise<{ embedding: number[] }>
+  }> {
+    if (!this.enabled) {
+      throw new Error('Parallel mode is not enabled. Call enable() first.')
+    }
+
+    // Create promise for result
+    let resolveResult: (value: any) => void
+    let rejectResult: (reason?: any) => void
+
+    const promise = new Promise<{ embedding: number[] }>((res, rej) => {
+      resolveResult = res
+      rejectResult = rej
+    })
+
+    // Queue the embedding immediately (this is synchronous!)
+    const { requestId } = this.context.queueEmbedding(
+      text,
+      params,
+      (error, result) => {
+        if (error) {
+          rejectResult(error)
+        } else {
+          resolveResult(result)
+        }
+      },
+    )
+
+    // Return immediately without wrapping in a Promise
+    return {
+      requestId,
+      promise,
+    }
+  }
+
+  /**
+   * Queue a rerank request for parallel processing
+   * @param query Query text
+   * @param documents Documents to rank
+   * @param params Optional rerank parameters
+   * @returns Object with requestId and promise for results
+   */
+  async rerank(
+    query: string,
+    documents: string[],
+    params?: RerankParams,
+  ): Promise<{
+    requestId: number
+    promise: Promise<Array<{ score: number; index: number; document: string }>>
+  }> {
+    if (!this.enabled) {
+      throw new Error('Parallel mode is not enabled. Call enable() first.')
+    }
+
+    // Create promise for result
+    let resolveResult: (value: any) => void
+    let rejectResult: (reason?: any) => void
+
+    const promise = new Promise<
+      Array<{ score: number; index: number; document: string }>
+    >((res, rej) => {
+      resolveResult = res
+      rejectResult = rej
+    })
+
+    // Queue the rerank immediately (this is synchronous!)
+    const { requestId } = this.context.queueRerank(
+      query,
+      documents,
+      params,
+      (error, result) => {
+        if (error) {
+          rejectResult(error)
+        } else {
+          // Add document text to results and sort by score
+          const enrichedResults = result.results
+            .map((r: any) => ({
+              ...r,
+              document: documents[r.index],
+            }))
+            .sort((a: any, b: any) => b.score - a.score)
+          resolveResult(enrichedResults)
+        }
+      },
+    )
+
+    // Return immediately without wrapping in a Promise
+    return {
+      requestId,
+      promise,
+    }
+  }
+
+  /**
+   * Check if parallel mode is enabled
+   */
+  isEnabled(): boolean {
+    return this.enabled
+  }
+}
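Taken together with the wrapper changes in `lib/index.ts`, this class is reachable as `context.parallel` on a loaded model. A hypothetical end-to-end sketch follows; the model path, the `model` option name, the `prompt` option, and the streamed `token` field are assumptions not confirmed by this diff, and `loadModel` is assumed to resolve to the context wrapper shown earlier.

```ts
import { loadModel } from '@fugood/llama.node'

const main = async () => {
  // `model` / prompt option names are placeholders for illustration.
  const context = await loadModel({ model: './model.gguf', n_parallel: 4 })

  await context.parallel.enable({ n_parallel: 4, n_batch: 512 })

  // Queue two completions; each returns a request id, a promise for the
  // final result, and a stop() function.
  const a = await context.parallel.completion(
    { prompt: 'Name three prime numbers.' },
    (requestId, data) => process.stdout.write(String(data.token ?? '')),
  )
  const b = await context.parallel.completion({
    prompt: 'Translate "hello" to French.',
  })

  const [resultA, resultB] = await Promise.all([a.promise, b.promise])
  console.log('\nA:', resultA.text, '\nB:', resultB.text)

  // Embedding and rerank requests share the same slot queue.
  const { promise: embPromise } = await context.parallel.embedding('hello world')
  console.log('embedding length:', (await embPromise).embedding.length)

  context.parallel.disable()
}

main().catch(console.error)
```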
package/package.json CHANGED
@@ -1,7 +1,7 @@
 {
   "name": "@fugood/llama.node",
   "access": "public",
-  "version": "1.2.6",
+  "version": "1.3.0-rc.1",
   "description": "An another Node binding of llama.cpp",
   "main": "lib/index.js",
   "scripts": {
@@ -72,19 +72,19 @@
     "CMakeLists.txt"
   ],
   "optionalDependencies": {
-    "@fugood/node-llama-linux-x64": "1.2.6",
-    "@fugood/node-llama-linux-x64-vulkan": "1.2.6",
-    "@fugood/node-llama-linux-x64-cuda": "1.2.6",
-    "@fugood/node-llama-linux-arm64": "1.2.6",
-    "@fugood/node-llama-linux-arm64-vulkan": "1.2.6",
-    "@fugood/node-llama-linux-arm64-cuda": "1.2.6",
-    "@fugood/node-llama-win32-x64": "1.2.6",
-    "@fugood/node-llama-win32-x64-vulkan": "1.2.6",
-    "@fugood/node-llama-win32-x64-cuda": "1.2.6",
-    "@fugood/node-llama-win32-arm64": "1.2.6",
-    "@fugood/node-llama-win32-arm64-vulkan": "1.2.6",
-    "@fugood/node-llama-darwin-x64": "1.2.6",
-    "@fugood/node-llama-darwin-arm64": "1.2.6"
+    "@fugood/node-llama-linux-x64": "1.3.0-rc.1",
+    "@fugood/node-llama-linux-x64-vulkan": "1.3.0-rc.1",
+    "@fugood/node-llama-linux-x64-cuda": "1.3.0-rc.1",
+    "@fugood/node-llama-linux-arm64": "1.3.0-rc.1",
+    "@fugood/node-llama-linux-arm64-vulkan": "1.3.0-rc.1",
+    "@fugood/node-llama-linux-arm64-cuda": "1.3.0-rc.1",
+    "@fugood/node-llama-win32-x64": "1.3.0-rc.1",
+    "@fugood/node-llama-win32-x64-vulkan": "1.3.0-rc.1",
+    "@fugood/node-llama-win32-x64-cuda": "1.3.0-rc.1",
+    "@fugood/node-llama-win32-arm64": "1.3.0-rc.1",
+    "@fugood/node-llama-win32-arm64-vulkan": "1.3.0-rc.1",
+    "@fugood/node-llama-darwin-x64": "1.3.0-rc.1",
+    "@fugood/node-llama-darwin-arm64": "1.3.0-rc.1"
   },
   "devDependencies": {
     "@babel/preset-env": "^7.24.4",
@@ -89,6 +89,13 @@ Napi::Value LlamaContext::ModelInfo(const Napi::CallbackInfo &info) {
   return metadata;
 }
 
+// getBackendDevicesInfo(): string
+Napi::Value LlamaContext::GetBackendDevicesInfo(const Napi::CallbackInfo &info) {
+  Napi::Env env = info.Env();
+  std::string devices_json = rnllama::get_backend_devices_info();
+  return Napi::String::New(env, devices_json);
+}
+
 void LlamaContext::Init(Napi::Env env, Napi::Object &exports) {
   Napi::Function func = DefineClass(
       env, "LlamaContext",
@@ -148,6 +155,9 @@ void LlamaContext::Init(Napi::Env env, Napi::Object &exports) {
        StaticMethod<&LlamaContext::ToggleNativeLog>(
            "toggleNativeLog",
            static_cast<napi_property_attributes>(napi_enumerable)),
+       StaticMethod<&LlamaContext::GetBackendDevicesInfo>(
+           "getBackendDevicesInfo",
+           static_cast<napi_property_attributes>(napi_enumerable)),
        InstanceMethod<&LlamaContext::GetMultimodalSupport>(
            "getMultimodalSupport",
            static_cast<napi_property_attributes>(napi_enumerable)),
@@ -168,6 +178,25 @@ void LlamaContext::Init(Napi::Env env, Napi::Object &exports) {
            static_cast<napi_property_attributes>(napi_enumerable)),
        InstanceMethod<&LlamaContext::DecodeAudioTokens>(
            "decodeAudioTokens",
+           static_cast<napi_property_attributes>(napi_enumerable)),
+       // Parallel decoding methods
+       InstanceMethod<&LlamaContext::EnableParallelMode>(
+           "enableParallelMode",
+           static_cast<napi_property_attributes>(napi_enumerable)),
+       InstanceMethod<&LlamaContext::DisableParallelMode>(
+           "disableParallelMode",
+           static_cast<napi_property_attributes>(napi_enumerable)),
+       InstanceMethod<&LlamaContext::QueueCompletion>(
+           "queueCompletion",
+           static_cast<napi_property_attributes>(napi_enumerable)),
+       InstanceMethod<&LlamaContext::QueueEmbedding>(
+           "queueEmbedding",
+           static_cast<napi_property_attributes>(napi_enumerable)),
+       InstanceMethod<&LlamaContext::QueueRerank>(
+           "queueRerank",
+           static_cast<napi_property_attributes>(napi_enumerable)),
+       InstanceMethod<&LlamaContext::CancelRequest>(
+           "cancelRequest",
            static_cast<napi_property_attributes>(napi_enumerable))});
   Napi::FunctionReference *constructor = new Napi::FunctionReference();
   *constructor = Napi::Persistent(func);
@@ -217,6 +246,7 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
   params.n_ctx = get_option<int32_t>(options, "n_ctx", 512);
   params.n_batch = get_option<int32_t>(options, "n_batch", 2048);
   params.n_ubatch = get_option<int32_t>(options, "n_ubatch", 512);
+  params.n_parallel = get_option<int32_t>(options, "n_parallel", 1); // Default to 1 for compatibility
   params.embedding = get_option<bool>(options, "embedding", false);
   if (params.embedding) {
     // For non-causal models, batch size must be equal to ubatch size
@@ -288,6 +318,9 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
       }
     }
   }
+  // Initialize validity flag for async callback safety
+  _context_valid = std::make_shared<std::atomic<bool>>(true);
+
   // Use rn-llama context instead of direct session
   _rn_ctx = new llama_rn_context();
   if (!_rn_ctx->loadModel(params)) {
@@ -305,6 +338,11 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
 }
 
 LlamaContext::~LlamaContext() {
+  // Invalidate the context to prevent use-after-free in async callbacks
+  if (_context_valid) {
+    _context_valid->store(false);
+  }
+
   // The DisposeWorker is responsible for cleanup of _rn_ctx
   // If _rn_ctx is still not null here, it means disposal was not properly initiated
   if (_rn_ctx) {
@@ -579,7 +617,7 @@ Napi::Value LlamaContext::GetFormattedChat(const Napi::CallbackInfo &info) {
   // grammar: string
   result.Set("grammar", chatParams.grammar);
   // grammar_lazy: boolean
-  result.Set("grammea_lazy", chatParams.grammar_lazy);
+  result.Set("grammar_lazy", chatParams.grammar_lazy);
   // grammar_triggers: [{ value: string, token: number }]
   Napi::Array grammar_triggers = Napi::Array::New(env);
   for (size_t i = 0; i < chatParams.grammar_triggers.size(); i++) {
@@ -1135,6 +1173,11 @@ Napi::Value LlamaContext::Release(const Napi::CallbackInfo &info) {
    _wip->SetStop();
  }
 
+  // stop_processing_loop
+  if (_rn_ctx && _rn_ctx->slot_manager) {
+    _rn_ctx->slot_manager->stop_processing_loop();
+  }
+
  if (_rn_ctx == nullptr) {
    auto promise = Napi::Promise::Deferred(env);
    promise.Resolve(env.Undefined());
@@ -4,6 +4,10 @@
 #include "rn-llama/rn-llama.h"
 #include "rn-llama/rn-completion.h"
 #include "rn-llama/rn-tts.h"
+#include "rn-llama/rn-slot.h"
+#include "rn-llama/rn-slot-manager.h"
+#include <atomic>
+#include <memory>
 
 using namespace rnllama;
 
@@ -21,6 +25,7 @@ public:
   ~LlamaContext();
   static void ToggleNativeLog(const Napi::CallbackInfo &info);
   static Napi::Value ModelInfo(const Napi::CallbackInfo &info);
+  static Napi::Value GetBackendDevicesInfo(const Napi::CallbackInfo &info);
   static void Init(Napi::Env env, Napi::Object &exports);
 
 private:
@@ -55,10 +60,22 @@ private:
   Napi::Value GetAudioCompletionGuideTokens(const Napi::CallbackInfo &info);
   Napi::Value DecodeAudioTokens(const Napi::CallbackInfo &info);
 
+  // Parallel decoding methods
+  Napi::Value EnableParallelMode(const Napi::CallbackInfo &info);
+  void DisableParallelMode(const Napi::CallbackInfo &info);
+  Napi::Value QueueCompletion(const Napi::CallbackInfo &info);
+  Napi::Value QueueEmbedding(const Napi::CallbackInfo &info);
+  Napi::Value QueueRerank(const Napi::CallbackInfo &info);
+  void CancelRequest(const Napi::CallbackInfo &info);
+
   std::string _info;
   Napi::Object _meta;
   LlamaCompletionWorker *_wip = nullptr;
 
   // Use rn-llama context instead of direct llama.cpp types
   llama_rn_context *_rn_ctx = nullptr;
+
+  // Validity flag for async callbacks to prevent use-after-free
+  // Shared pointer ensures callbacks can safely check if context is still alive
+  std::shared_ptr<std::atomic<bool>> _context_valid;
 };
package/src/common.hpp CHANGED
@@ -16,11 +16,12 @@ static bool is_nil(const Napi::Value &value) {
   return value.IsNull() || value.IsUndefined();
 }
 
-static std::string json_stringify(const Napi::Object &obj) {
-  Napi::Env env = obj.Env();
+// Overload for Napi::Value to handle both arrays and objects
+static std::string json_stringify(const Napi::Value &value) {
+  Napi::Env env = value.Env();
   Napi::Object json = env.Global().Get("JSON").As<Napi::Object>();
   Napi::Function stringify = json.Get("stringify").As<Napi::Function>();
-  return stringify.Call(json, {obj}).As<Napi::String>().ToString();
+  return stringify.Call(json, {value}).As<Napi::String>().ToString();
 }
 
 static void console_log(Napi::Env env, const std::string &message) {
@@ -1760,7 +1760,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP}));
     add_opt(common_arg(
         {"-t", "--threads"}, "N",
-        string_format("number of threads to use during generation (default: %d)", params.cpuparams.n_threads),
+        string_format("number of CPU threads to use during generation (default: %d)", params.cpuparams.n_threads),
         [](common_params & params, int value) {
             params.cpuparams.n_threads = value;
             if (params.cpuparams.n_threads <= 0) {
@@ -577,6 +577,10 @@ extern "C" {
         GGML_UNARY_OP_EXP,
         GGML_UNARY_OP_GELU_ERF,
         GGML_UNARY_OP_XIELU,
+        GGML_UNARY_OP_FLOOR,
+        GGML_UNARY_OP_CEIL,
+        GGML_UNARY_OP_ROUND,
+        GGML_UNARY_OP_TRUNC,
 
         GGML_UNARY_OP_COUNT,
     };
@@ -1151,6 +1155,46 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor * a);
 
+    GGML_API struct ggml_tensor * ggml_floor(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_floor_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_ceil(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_ceil_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_round(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_round_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    /**
+     * Truncates the fractional part of each element in the tensor (towards zero).
+     * For example: trunc(3.7) = 3.0, trunc(-2.9) = -2.0
+     * Similar to std::trunc in C/C++.
+     */
+
+    GGML_API struct ggml_tensor * ggml_trunc(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_trunc_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+
+
     // xIELU activation function
     // x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0)
     // where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions
@@ -2184,6 +2184,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_UNARY_OP_HARDSWISH:
         case GGML_UNARY_OP_HARDSIGMOID:
         case GGML_UNARY_OP_EXP:
+        case GGML_UNARY_OP_FLOOR:
+        case GGML_UNARY_OP_CEIL:
+        case GGML_UNARY_OP_ROUND:
+        case GGML_UNARY_OP_TRUNC:
             {
                 n_tasks = 1;
             } break;
@@ -3563,13 +3567,17 @@ void ggml_cpu_init(void) {
 #ifdef GGML_USE_OPENMP
     //if (!getenv("OMP_WAIT_POLICY")) {
     //    // set the wait policy to active, so that OpenMP threads don't sleep
-    //    putenv("OMP_WAIT_POLICY=active");
+    //    setenv("OMP_WAIT_POLICY", "active", 0)
     //}
 
     if (!getenv("KMP_BLOCKTIME")) {
         // set the time to wait before sleeping a thread
         // this is less aggressive than setting the wait policy to active, but should achieve similar results in most cases
-        putenv("KMP_BLOCKTIME=200"); // 200ms
+#ifdef _WIN32
+        _putenv_s("KMP_BLOCKTIME", "200"); // 200ms
+#else
+        setenv("KMP_BLOCKTIME", "200", 0); // 200ms
+#endif
     }
 #endif
 }
@@ -8993,6 +8993,22 @@ void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_exp(params, dst);
             } break;
+        case GGML_UNARY_OP_FLOOR:
+            {
+                ggml_compute_forward_floor(params, dst);
+            } break;
+        case GGML_UNARY_OP_CEIL:
+            {
+                ggml_compute_forward_ceil(params, dst);
+            } break;
+        case GGML_UNARY_OP_ROUND:
+            {
+                ggml_compute_forward_round(params, dst);
+            } break;
+        case GGML_UNARY_OP_TRUNC:
+            {
+                ggml_compute_forward_trunc(params, dst);
+            } break;
         case GGML_UNARY_OP_XIELU:
             {
                 ggml_compute_forward_xielu(params, dst);
@@ -73,6 +73,22 @@ static inline float op_log(float x) {
     return logf(x);
 }
 
+static inline float op_floor(float x) {
+    return floorf(x);
+}
+
+static inline float op_ceil(float x) {
+    return ceilf(x);
+}
+
+static inline float op_round(float x) {
+    return roundf(x);
+}
+
+static inline float op_trunc(float x) {
+    return truncf(x);
+}
+
 template <float (*op)(float), typename src0_t, typename dst_t>
 static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
     constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
@@ -274,6 +290,22 @@ void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor *
     unary_op<op_log>(params, dst);
 }
 
+void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_floor>(params, dst);
+}
+
+void ggml_compute_forward_ceil(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_ceil>(params, dst);
+}
+
+void ggml_compute_forward_round(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_round>(params, dst);
+}
+
+void ggml_compute_forward_trunc(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_trunc>(params, dst);
+}
+
 void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {
     const float alpha_n = ggml_get_op_params_f32(dst, 1);
     const float alpha_p = ggml_get_op_params_f32(dst, 2);
@@ -22,6 +22,10 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct
 void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_floor(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_trunc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 
 #ifdef __cplusplus
@@ -5,6 +5,7 @@
 #include <map>
 
 static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
+    { LLM_ARCH_CLIP, "clip" }, // dummy, only used by llama-quantize
     { LLM_ARCH_LLAMA, "llama" },
    { LLM_ARCH_LLAMA4, "llama4" },
    { LLM_ARCH_DECI, "deci" },
@@ -275,6 +276,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
 };
 
 static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
+    {
+        LLM_ARCH_CLIP,
+        {},
+    },
     {
         LLM_ARCH_LLAMA,
         {
@@ -9,6 +9,7 @@
 //
 
 enum llm_arch {
+    LLM_ARCH_CLIP,
     LLM_ARCH_LLAMA,
     LLM_ARCH_LLAMA4,
     LLM_ARCH_DECI,
@@ -478,7 +478,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     ml.get_key(LLM_KV_GENERAL_NAME, name, false);
 
     // everything past this point is not vocab-related
-    if (hparams.vocab_only) {
+    // for CLIP models, we only need to load tensors, no hparams
+    if (hparams.vocab_only || ml.get_arch() == LLM_ARCH_CLIP) {
         return;
     }
 
@@ -20013,6 +20014,7 @@ int32_t llama_n_head(const llama_model * model) {
 llama_rope_type llama_model_rope_type(const llama_model * model) {
     switch (model->arch) {
         // these models do not use RoPE
+        case LLM_ARCH_CLIP:
         case LLM_ARCH_GPT2:
         case LLM_ARCH_GPTJ:
         case LLM_ARCH_MPT:
@@ -701,6 +701,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         });
     }
 
+    bool is_clip_model = false;
     for (const auto * it : tensors) {
         const struct ggml_tensor * tensor = it->tensor;
 
@@ -714,12 +715,14 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
             qs.has_output = true;
         }
+
+        is_clip_model |= name.rfind("mm.", 0) == 0; // check the "mm." prefix
     }
 
     qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
 
     // sanity checks for models that have attention layers
-    if (qs.n_attention_wv != 0)
+    if (qs.n_attention_wv != 0 && !is_clip_model)
     {
         const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
         // attention layers have a non-zero number of kv heads
@@ -881,6 +884,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         // do not quantize relative position bias (T5)
         quantize &= name.find("attn_rel_b.weight") == std::string::npos;
 
+        // do not quantize specific multimodal tensors
+        quantize &= name.find(".position_embd.") == std::string::npos;
+
         ggml_type new_type;
         void * new_data;
         size_t new_size;
@@ -124,6 +124,9 @@ static int llama_model_load(const std::string & fname, std::vector<std::string>
     } catch(const std::exception & e) {
         throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
     }
+    if (model.arch == LLM_ARCH_CLIP) {
+        throw std::runtime_error("CLIP cannot be used as main model, use it with --mmproj instead");
+    }
     try {
         model.load_vocab(ml);
     } catch(const std::exception & e) {