@inferrlm/react-native-mlx 0.4.2 → 0.4.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/MLXReactNative.podspec +8 -2
- package/app.plugin.js +63 -0
- package/ios/Sources/HybridLLM.swift +181 -12
- package/ios/Sources/LLMError.swift +1 -0
- package/package.json +1 -1
package/MLXReactNative.podspec
CHANGED
|
@@ -24,13 +24,19 @@ Pod::Spec.new do |s|
|
|
|
24
24
|
|
|
25
25
|
spm_dependency(s,
|
|
26
26
|
url: "https://github.com/ml-explore/mlx-swift-lm.git",
|
|
27
|
-
requirement: {kind: "
|
|
27
|
+
requirement: {kind: "upToNextMinorVersion", minimumVersion: "3.31.3"},
|
|
28
28
|
products: ["MLXLLM", "MLXLMCommon"]
|
|
29
29
|
)
|
|
30
30
|
|
|
31
|
+
spm_dependency(s,
|
|
32
|
+
url: "https://github.com/huggingface/swift-transformers",
|
|
33
|
+
requirement: {kind: "upToNextMinorVersion", minimumVersion: "1.2.0"},
|
|
34
|
+
products: ["Tokenizers"]
|
|
35
|
+
)
|
|
36
|
+
|
|
31
37
|
spm_dependency(s,
|
|
32
38
|
url: "https://github.com/Blaizzy/mlx-audio-swift.git",
|
|
33
|
-
requirement: {kind: "
|
|
39
|
+
requirement: {kind: "revision", revision: "856e04afb3c6eb931d92bb0d6ae7bbfbdfa89b15"},
|
|
34
40
|
products: ["MLXAudioTTS", "MLXAudioSTT", "MLXAudioCore"]
|
|
35
41
|
)
|
|
36
42
|
|
package/app.plugin.js
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
const { withPodfile } = require('@expo/config-plugins');
|
|
2
|
+
|
|
3
|
+
/**
 * Inserts the `install! 'cocoapods', :deterministic_uuids => false` directive
 * into Podfile text, immediately before the first
 * `prepare_react_native_project!` call.
 *
 * The input is returned untouched when the directive is already present or
 * when the anchor call cannot be found, so the transform is safe to re-run.
 *
 * @param {string} contents Full Podfile source text.
 * @returns {string} The patched (or unchanged) Podfile text.
 */
function disableDeterministicPodUuids(contents) {
  const directive = "install! 'cocoapods', :deterministic_uuids => false";
  const marker = 'prepare_react_native_project!';

  const markerIndex = contents.indexOf(marker);
  const alreadyPresent = contents.includes(directive);

  if (alreadyPresent || markerIndex === -1) {
    return contents;
  }

  // Splice the directive (plus a blank line) in front of the first marker.
  const head = contents.slice(0, markerIndex);
  const tail = contents.slice(markerIndex);
  return [head, directive, '\n\n', tail].join('');
}
|
|
18
|
+
|
|
19
|
+
/**
 * Injects a Ruby monkey-patch for `::SPMManager#add_spm_to_target` into
 * Podfile text, immediately before the first `prepare_react_native_project!`.
 *
 * The injected Ruby aliases the existing method to
 * `inferrlm_add_spm_to_target` and installs a wrapper that calls
 * `root.add_referrer(project)` when `project.objects_by_uuid[root.uuid]` is
 * not the project's root object, then delegates to the aliased method.
 *
 * The input is returned untouched when it already contains
 * `::SPMManager.class_eval do` or when the anchor call is missing.
 *
 * @param {string} contents Full Podfile source text.
 * @returns {string} The patched (or unchanged) Podfile text.
 */
function injectSpmRootFix(contents) {
  const marker = 'prepare_react_native_project!';
  const markerIndex = contents.indexOf(marker);

  // Either a previous run already patched the file, or there is nowhere to hook in.
  if (contents.includes('::SPMManager.class_eval do') || markerIndex === -1) {
    return contents;
  }

  const rubyLines = [
    'if defined?(::SPMManager) && ::SPMManager.instance_methods.include?(:add_spm_to_target)',
    ' ::SPMManager.class_eval do',
    ' unless method_defined?(:inferrlm_add_spm_to_target)',
    ' alias_method :inferrlm_add_spm_to_target, :add_spm_to_target',
    '',
    ' def add_spm_to_target(project, target, url, requirement, products)',
    ' root = project.root_object',
    ' if root && project.objects_by_uuid[root.uuid] != root',
    ' root.add_referrer(project)',
    ' end',
    '',
    ' inferrlm_add_spm_to_target(project, target, url, requirement, products)',
    ' end',
    ' end',
    ' end',
    'end',
  ];
  const rubyPatch = rubyLines.join('\n');

  return contents.slice(0, markerIndex) + rubyPatch + '\n\n' + contents.slice(markerIndex);
}
|
|
51
|
+
|
|
52
|
+
module.exports = function withMlxIosPods(config) {
|
|
53
|
+
return withPodfile(config, (modConfig) => {
|
|
54
|
+
modConfig.modResults.contents = disableDeterministicPodUuids(modConfig.modResults.contents);
|
|
55
|
+
modConfig.modResults.contents = injectSpmRootFix(modConfig.modResults.contents);
|
|
56
|
+
return modConfig;
|
|
57
|
+
});
|
|
58
|
+
};
|
|
59
|
+
|
|
60
|
+
module.exports._helpers = {
|
|
61
|
+
disableDeterministicPodUuids,
|
|
62
|
+
injectSpmRootFix,
|
|
63
|
+
};
|
|
package/ios/Sources/HybridLLM.swift
CHANGED
|
@@ -1,9 +1,149 @@
|
|
|
1
1
|
import Foundation
|
|
2
2
|
import NitroModules
|
|
3
|
+
import Tokenizers
|
|
3
4
|
internal import MLX
|
|
4
5
|
internal import MLXLLM
|
|
5
6
|
internal import MLXLMCommon
|
|
6
|
-
|
|
7
|
+
|
|
8
|
+
/// File-private shorthand for a loosely typed, string-keyed dictionary.
/// NOTE(review): presumably carries a tool/function schema payload — confirm against its uses elsewhere in this file.
private typealias ToolSpec = [String: Any]
|
|
9
|
+
|
|
10
|
+
/// Reports whether `dir` contains a `jang_config.json` entry.
///
/// - Parameter modelDirectory: Directory to inspect.
/// - Returns: `true` when something exists at `<dir>/jang_config.json`.
private func isJANGModel(at modelDirectory: URL) -> Bool {
    let markerURL = modelDirectory.appendingPathComponent("jang_config.json")
    return FileManager.default.fileExists(atPath: markerURL.path)
}
|
|
13
|
+
|
|
14
|
+
/// A `MLXLMCommon.Tokenizer` that additionally accepts the full set of
/// chat-template arguments (explicit template, generation prompt flag,
/// truncation, maximum length, tools and extra context).
private protocol ChatTemplateConfigurableTokenizer: MLXLMCommon.Tokenizer {
    /// Renders `messages` through a chat template and returns the token ids.
    ///
    /// - Throws: Whatever the conforming tokenizer throws while applying the template.
    func applyChatTemplate(
        messages: [[String: any Sendable]],
        chatTemplate: Tokenizers.ChatTemplateArgument?,
        addGenerationPrompt: Bool,
        truncation: Bool,
        maxLength: Int?,
        tools: [[String: any Sendable]]?,
        additionalContext: [String: any Sendable]?
    ) throws -> [Int]
}
|
|
25
|
+
|
|
26
|
+
private extension MLXLMCommon.Tokenizer {
    /// Full-argument chat-template entry point available on any
    /// `MLXLMCommon.Tokenizer`.
    ///
    /// Forwards to the receiver when it also conforms to
    /// `ChatTemplateConfigurableTokenizer`.
    ///
    /// - Returns: The token ids produced by the underlying tokenizer.
    /// - Throws: `MLXLMCommon.TokenizerError.missingChatTemplate` when the
    ///   receiver does not conform to `ChatTemplateConfigurableTokenizer`;
    ///   otherwise whatever the forwarded call throws.
    func applyChatTemplate(
        messages: [[String: any Sendable]],
        chatTemplate: Tokenizers.ChatTemplateArgument?,
        addGenerationPrompt: Bool,
        truncation: Bool,
        maxLength: Int?,
        tools: [[String: any Sendable]]?,
        additionalContext: [String: any Sendable]?
    ) throws -> [Int] {
        if let configurable = self as? any ChatTemplateConfigurableTokenizer {
            return try configurable.applyChatTemplate(
                messages: messages,
                chatTemplate: chatTemplate,
                addGenerationPrompt: addGenerationPrompt,
                truncation: truncation,
                maxLength: maxLength,
                tools: tools,
                additionalContext: additionalContext
            )
        }

        // No configurable implementation behind this tokenizer.
        throw MLXLMCommon.TokenizerError.missingChatTemplate
    }
}
|
|
51
|
+
|
|
52
|
+
/// Adapts a swift-transformers tokenizer (`Tokenizers.Tokenizer`) to the
/// `MLXLMCommon.Tokenizer` surface, including the full-argument chat-template
/// call required by `ChatTemplateConfigurableTokenizer`.
///
/// Every member forwards to the wrapped tokenizer. The only translation is
/// that `Tokenizers.TokenizerError.missingChatTemplate` is rethrown as
/// `MLXLMCommon.TokenizerError.missingChatTemplate`.
private struct TransformersTokenizerBridge: ChatTemplateConfigurableTokenizer {
    private let upstream: any Tokenizers.Tokenizer

    init(_ upstream: any Tokenizers.Tokenizer) {
        self.upstream = upstream
    }

    // MARK: - Special tokens

    var bosToken: String? { upstream.bosToken }
    var eosToken: String? { upstream.eosToken }
    var unknownToken: String? { upstream.unknownToken }

    // MARK: - Encoding / decoding

    func encode(text: String, addSpecialTokens: Bool) -> [Int] {
        return upstream.encode(text: text, addSpecialTokens: addSpecialTokens)
    }

    func decode(tokenIds: [Int], skipSpecialTokens: Bool) -> String {
        return upstream.decode(tokens: tokenIds, skipSpecialTokens: skipSpecialTokens)
    }

    func convertTokenToId(_ token: String) -> Int? {
        return upstream.convertTokenToId(token)
    }

    func convertIdToToken(_ id: Int) -> String? {
        return upstream.convertIdToToken(id)
    }

    // MARK: - Chat templates

    /// Applies the tokenizer's chat template with tools and extra context.
    ///
    /// - Throws: `MLXLMCommon.TokenizerError.missingChatTemplate` when the
    ///   wrapped tokenizer reports a missing template; other errors propagate unchanged.
    func applyChatTemplate(
        messages: [[String: any Sendable]],
        tools: [[String: any Sendable]]?,
        additionalContext: [String: any Sendable]?
    ) throws -> [Int] {
        try translatingMissingTemplate {
            try upstream.applyChatTemplate(
                messages: messages,
                tools: tools,
                additionalContext: additionalContext
            )
        }
    }

    /// Applies a chat template with every argument forwarded to the wrapped tokenizer.
    ///
    /// - Throws: `MLXLMCommon.TokenizerError.missingChatTemplate` when the
    ///   wrapped tokenizer reports a missing template; other errors propagate unchanged.
    func applyChatTemplate(
        messages: [[String: any Sendable]],
        chatTemplate: Tokenizers.ChatTemplateArgument?,
        addGenerationPrompt: Bool,
        truncation: Bool,
        maxLength: Int?,
        tools: [[String: any Sendable]]?,
        additionalContext: [String: any Sendable]?
    ) throws -> [Int] {
        try translatingMissingTemplate {
            try upstream.applyChatTemplate(
                messages: messages,
                chatTemplate: chatTemplate,
                addGenerationPrompt: addGenerationPrompt,
                truncation: truncation,
                maxLength: maxLength,
                tools: tools,
                additionalContext: additionalContext
            )
        }
    }

    /// Runs `operation`, converting the swift-transformers "missing chat
    /// template" error into its `MLXLMCommon` counterpart.
    private func translatingMissingTemplate<Result>(
        _ operation: () throws -> Result
    ) throws -> Result {
        do {
            return try operation()
        } catch Tokenizers.TokenizerError.missingChatTemplate {
            throw MLXLMCommon.TokenizerError.missingChatTemplate
        }
    }
}
|
|
119
|
+
|
|
120
|
+
/// `MLXLMCommon.TokenizerLoader` that builds tokenizers via swift-transformers'
/// `AutoTokenizer` and wraps them in `TransformersTokenizerBridge`.
private struct TransformersTokenizerLoader: MLXLMCommon.TokenizerLoader {
    /// Loads the tokenizer found in `directory`.
    ///
    /// - Parameter directory: Model folder handed to `AutoTokenizer.from(modelFolder:)`.
    /// - Returns: The loaded tokenizer wrapped as an `MLXLMCommon.Tokenizer`.
    /// - Throws: Any error raised by `AutoTokenizer.from(modelFolder:)`.
    func load(from directory: URL) async throws -> any MLXLMCommon.Tokenizer {
        try await TransformersTokenizerBridge(
            Tokenizers.AutoTokenizer.from(modelFolder: directory)
        )
    }
}
|
|
126
|
+
|
|
127
|
+
/// Main-actor progress tracker that only ever moves forward.
///
/// `value` starts at 0, never decreases and never exceeds 1.0. The optional
/// callback fires only when the value actually increases.
@MainActor
private final class LoadProgress {
    private let callback: ((Double) -> Void)?
    private(set) var value: Double = 0

    /// - Parameter callback: Invoked with each new, strictly larger progress value.
    init(callback: ((Double) -> Void)?) {
        self.callback = callback
    }

    /// Raises the progress to `nextValue`, capped at 1.0.
    ///
    /// Requests that would not increase the current value are ignored and
    /// do not invoke the callback.
    func set(_ nextValue: Double) {
        let notBelowCurrent = max(value, nextValue) // never move backwards
        let capped = min(1.0, notBelowCurrent)
        if capped > value {
            value = capped
            callback?(capped)
        }
    }

    /// Advances the progress by `step`, without going past `upperBound`.
    func tick(toward upperBound: Double, step: Double) {
        let advanced = value + step
        set(min(upperBound, advanced))
    }
}
|
|
7
147
|
|
|
8
148
|
class HybridLLM: HybridLLMSpec {
|
|
9
149
|
private var session: ChatSession?
|
|
@@ -16,7 +156,6 @@ class HybridLLM: HybridLLMSpec {
|
|
|
16
156
|
totalTime: 0,
|
|
17
157
|
toolExecutionTime: 0
|
|
18
158
|
)
|
|
19
|
-
private var modelFactory: ModelFactory = LLMModelFactory.shared
|
|
20
159
|
private var manageHistory: Bool = false
|
|
21
160
|
private var messageHistory: [LLMMessage] = []
|
|
22
161
|
private var loadTask: Task<Void, Error>?
|
|
@@ -142,6 +281,9 @@ class HybridLLM: HybridLLMSpec {
|
|
|
142
281
|
|
|
143
282
|
return Promise.async { [self] in
|
|
144
283
|
let task = Task { @MainActor in
|
|
284
|
+
let progress = LoadProgress(callback: options?.onProgress)
|
|
285
|
+
progress.set(0.02)
|
|
286
|
+
|
|
145
287
|
Memory.cacheLimit = 2000000
|
|
146
288
|
|
|
147
289
|
self.currentTask?.cancel()
|
|
@@ -151,6 +293,7 @@ class HybridLLM: HybridLLMSpec {
|
|
|
151
293
|
self.tools = []
|
|
152
294
|
self.toolSchemas = []
|
|
153
295
|
Memory.clearCache()
|
|
296
|
+
progress.set(0.12)
|
|
154
297
|
|
|
155
298
|
let memoryAfterCleanup = self.getMemoryUsage()
|
|
156
299
|
let gpuAfterCleanup = self.getGPUMemoryUsage()
|
|
@@ -158,13 +301,30 @@ class HybridLLM: HybridLLMSpec {
|
|
|
158
301
|
|
|
159
302
|
let modelDir = await ModelDownloader.shared.getModelDirectory(modelId: modelId)
|
|
160
303
|
log("Loading from directory: \(modelDir.path)")
|
|
304
|
+
progress.set(0.18)
|
|
161
305
|
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
306
|
+
if isJANGModel(at: modelDir) {
|
|
307
|
+
throw LLMError.unsupportedModel("JANG model format is not supported by the current MLX dependency set")
|
|
308
|
+
}
|
|
309
|
+
|
|
310
|
+
var loadingPulse: Task<Void, Never>?
|
|
311
|
+
if options?.onProgress != nil {
|
|
312
|
+
loadingPulse = Task { @MainActor in
|
|
313
|
+
while !Task.isCancelled {
|
|
314
|
+
try? await Task.sleep(nanoseconds: 250_000_000)
|
|
315
|
+
progress.tick(toward: 0.78, step: 0.04)
|
|
316
|
+
}
|
|
317
|
+
}
|
|
167
318
|
}
|
|
319
|
+
defer { loadingPulse?.cancel() }
|
|
320
|
+
|
|
321
|
+
progress.set(0.22)
|
|
322
|
+
|
|
323
|
+
let loadedContainer = try await LLMModelFactory.shared.loadContainer(
|
|
324
|
+
from: modelDir,
|
|
325
|
+
using: TransformersTokenizerLoader()
|
|
326
|
+
)
|
|
327
|
+
progress.set(0.86)
|
|
168
328
|
|
|
169
329
|
try Task.checkCancellation()
|
|
170
330
|
|
|
@@ -194,6 +354,9 @@ class HybridLLM: HybridLLMSpec {
|
|
|
194
354
|
let updatedExtra = await loadedContainer.configuration.extraEOSTokens
|
|
195
355
|
log("EOS patched - ids: \(updated), extra: \(updatedExtra)")
|
|
196
356
|
}
|
|
357
|
+
progress.set(0.92)
|
|
358
|
+
|
|
359
|
+
progress.set(0.95)
|
|
197
360
|
|
|
198
361
|
let memoryAfterContainer = self.getMemoryUsage()
|
|
199
362
|
let gpuAfterContainer = self.getGPUMemoryUsage()
|
|
@@ -205,10 +368,14 @@ class HybridLLM: HybridLLMSpec {
|
|
|
205
368
|
log("Loaded \(self.tools.count) tools: \(self.tools.map { $0.name })")
|
|
206
369
|
}
|
|
207
370
|
|
|
208
|
-
let additionalContextDict: [String:
|
|
209
|
-
|
|
371
|
+
let additionalContextDict: [String: any Sendable]?
|
|
372
|
+
if let messages = options?.additionalContext {
|
|
373
|
+
let contextMessages: [[String: String]] = messages.map {
|
|
374
|
+
["role": $0.role, "content": $0.content]
|
|
375
|
+
}
|
|
376
|
+
additionalContextDict = ["messages": contextMessages]
|
|
210
377
|
} else {
|
|
211
|
-
nil
|
|
378
|
+
additionalContextDict = nil
|
|
212
379
|
}
|
|
213
380
|
|
|
214
381
|
self.container = loadedContainer
|
|
@@ -221,6 +388,8 @@ class HybridLLM: HybridLLMSpec {
|
|
|
221
388
|
if self.manageHistory {
|
|
222
389
|
log("History management enabled with \(self.messageHistory.count) initial messages")
|
|
223
390
|
}
|
|
391
|
+
|
|
392
|
+
progress.set(1.0)
|
|
224
393
|
}
|
|
225
394
|
|
|
226
395
|
self.loadTask = task
|
|
@@ -471,7 +640,7 @@ class HybridLLM: HybridLLMSpec {
|
|
|
471
640
|
additionalContext: additionalCtx
|
|
472
641
|
)
|
|
473
642
|
self.log("template_applied token_count=\(result.count)")
|
|
474
|
-
let decoded = context.tokenizer.decode(
|
|
643
|
+
let decoded = context.tokenizer.decode(tokenIds: Array(result.suffix(60)))
|
|
475
644
|
self.log("input_tail_decoded: \(decoded)")
|
|
476
645
|
self.lastInputContainedThinkTag = decoded.contains("<think>")
|
|
477
646
|
return result
|
|
@@ -487,7 +656,7 @@ class HybridLLM: HybridLLMSpec {
|
|
|
487
656
|
additionalContext: additionalCtx
|
|
488
657
|
)
|
|
489
658
|
self.log("fallback_template_applied token_count=\(result.count)")
|
|
490
|
-
let decoded = context.tokenizer.decode(
|
|
659
|
+
let decoded = context.tokenizer.decode(tokenIds: Array(result.suffix(60)))
|
|
491
660
|
self.log("fallback_input_tail_decoded: \(decoded)")
|
|
492
661
|
self.lastInputContainedThinkTag = decoded.contains("<think>")
|
|
493
662
|
return result
|
package/package.json
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@inferrlm/react-native-mlx",
|
|
3
3
|
"description": "MLX Swift integration for React Native - InferrLM fork with enhanced features",
|
|
4
|
-
"version": "0.4.
|
|
4
|
+
"version": "0.4.8",
|
|
5
5
|
"main": "./lib/module/index.js",
|
|
6
6
|
"module": "./lib/module/index.js",
|
|
7
7
|
"types": "./lib/typescript/src/index.d.ts",
|