@inferrlm/react-native-mlx 0.4.2 → 0.4.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,13 +24,19 @@ Pod::Spec.new do |s|
24
24
 
25
25
  spm_dependency(s,
26
26
  url: "https://github.com/ml-explore/mlx-swift-lm.git",
27
- requirement: {kind: "branch", branch: "main"},
27
+ requirement: {kind: "upToNextMinorVersion", minimumVersion: "3.31.3"},
28
28
  products: ["MLXLLM", "MLXLMCommon"]
29
29
  )
30
30
 
31
+ spm_dependency(s,
32
+ url: "https://github.com/huggingface/swift-transformers",
33
+ requirement: {kind: "upToNextMinorVersion", minimumVersion: "1.2.0"},
34
+ products: ["Tokenizers"]
35
+ )
36
+
31
37
  spm_dependency(s,
32
38
  url: "https://github.com/Blaizzy/mlx-audio-swift.git",
33
- requirement: {kind: "branch", branch: "main"},
39
+ requirement: {kind: "revision", revision: "856e04afb3c6eb931d92bb0d6ae7bbfbdfa89b15"},
34
40
  products: ["MLXAudioTTS", "MLXAudioSTT", "MLXAudioCore"]
35
41
  )
36
42
 
package/app.plugin.js ADDED
@@ -0,0 +1,63 @@
1
+ const { withPodfile } = require('@expo/config-plugins');
2
+
3
+ function disableDeterministicPodUuids(contents) {
4
+ const line = "install! 'cocoapods', :deterministic_uuids => false";
5
+
6
+ if (contents.includes(line)) {
7
+ return contents;
8
+ }
9
+
10
+ const anchor = 'prepare_react_native_project!';
11
+
12
+ if (!contents.includes(anchor)) {
13
+ return contents;
14
+ }
15
+
16
+ return contents.replace(anchor, `${line}\n\n${anchor}`);
17
+ }
18
+
19
+ function injectSpmRootFix(contents) {
20
+ const block = [
21
+ 'if defined?(::SPMManager) && ::SPMManager.instance_methods.include?(:add_spm_to_target)',
22
+ ' ::SPMManager.class_eval do',
23
+ ' unless method_defined?(:inferrlm_add_spm_to_target)',
24
+ ' alias_method :inferrlm_add_spm_to_target, :add_spm_to_target',
25
+ '',
26
+ ' def add_spm_to_target(project, target, url, requirement, products)',
27
+ ' root = project.root_object',
28
+ ' if root && project.objects_by_uuid[root.uuid] != root',
29
+ ' root.add_referrer(project)',
30
+ ' end',
31
+ '',
32
+ ' inferrlm_add_spm_to_target(project, target, url, requirement, products)',
33
+ ' end',
34
+ ' end',
35
+ ' end',
36
+ 'end',
37
+ ].join('\n');
38
+
39
+ if (contents.includes('::SPMManager.class_eval do')) {
40
+ return contents;
41
+ }
42
+
43
+ const anchor = 'prepare_react_native_project!';
44
+
45
+ if (!contents.includes(anchor)) {
46
+ return contents;
47
+ }
48
+
49
+ return contents.replace(anchor, `${block}\n\n${anchor}`);
50
+ }
51
+
52
+ module.exports = function withMlxIosPods(config) {
53
+ return withPodfile(config, (modConfig) => {
54
+ modConfig.modResults.contents = disableDeterministicPodUuids(modConfig.modResults.contents);
55
+ modConfig.modResults.contents = injectSpmRootFix(modConfig.modResults.contents);
56
+ return modConfig;
57
+ });
58
+ };
59
+
60
+ module.exports._helpers = {
61
+ disableDeterministicPodUuids,
62
+ injectSpmRootFix,
63
+ };
@@ -1,9 +1,149 @@
1
1
  import Foundation
2
2
  import NitroModules
3
+ import Tokenizers
3
4
  internal import MLX
4
5
  internal import MLXLLM
5
6
  internal import MLXLMCommon
6
- internal import Tokenizers
7
+
8
+ private typealias ToolSpec = [String: Any]
9
+
10
+ private func isJANGModel(at dir: URL) -> Bool {
11
+ FileManager.default.fileExists(atPath: dir.appendingPathComponent("jang_config.json").path)
12
+ }
13
+
14
+ private protocol ChatTemplateConfigurableTokenizer: MLXLMCommon.Tokenizer {
15
+ func applyChatTemplate(
16
+ messages: [[String: any Sendable]],
17
+ chatTemplate: Tokenizers.ChatTemplateArgument?,
18
+ addGenerationPrompt: Bool,
19
+ truncation: Bool,
20
+ maxLength: Int?,
21
+ tools: [[String: any Sendable]]?,
22
+ additionalContext: [String: any Sendable]?
23
+ ) throws -> [Int]
24
+ }
25
+
26
+ private extension MLXLMCommon.Tokenizer {
27
+ func applyChatTemplate(
28
+ messages: [[String: any Sendable]],
29
+ chatTemplate: Tokenizers.ChatTemplateArgument?,
30
+ addGenerationPrompt: Bool,
31
+ truncation: Bool,
32
+ maxLength: Int?,
33
+ tools: [[String: any Sendable]]?,
34
+ additionalContext: [String: any Sendable]?
35
+ ) throws -> [Int] {
36
+ guard let tokenizer = self as? any ChatTemplateConfigurableTokenizer else {
37
+ throw MLXLMCommon.TokenizerError.missingChatTemplate
38
+ }
39
+
40
+ return try tokenizer.applyChatTemplate(
41
+ messages: messages,
42
+ chatTemplate: chatTemplate,
43
+ addGenerationPrompt: addGenerationPrompt,
44
+ truncation: truncation,
45
+ maxLength: maxLength,
46
+ tools: tools,
47
+ additionalContext: additionalContext
48
+ )
49
+ }
50
+ }
51
+
52
+ private struct TransformersTokenizerBridge: ChatTemplateConfigurableTokenizer {
53
+ private let upstream: any Tokenizers.Tokenizer
54
+
55
+ init(_ upstream: any Tokenizers.Tokenizer) {
56
+ self.upstream = upstream
57
+ }
58
+
59
+ func encode(text: String, addSpecialTokens: Bool) -> [Int] {
60
+ upstream.encode(text: text, addSpecialTokens: addSpecialTokens)
61
+ }
62
+
63
+ func decode(tokenIds: [Int], skipSpecialTokens: Bool) -> String {
64
+ upstream.decode(tokens: tokenIds, skipSpecialTokens: skipSpecialTokens)
65
+ }
66
+
67
+ func convertTokenToId(_ token: String) -> Int? {
68
+ upstream.convertTokenToId(token)
69
+ }
70
+
71
+ func convertIdToToken(_ id: Int) -> String? {
72
+ upstream.convertIdToToken(id)
73
+ }
74
+
75
+ var bosToken: String? { upstream.bosToken }
76
+ var eosToken: String? { upstream.eosToken }
77
+ var unknownToken: String? { upstream.unknownToken }
78
+
79
+ func applyChatTemplate(
80
+ messages: [[String: any Sendable]],
81
+ tools: [[String: any Sendable]]?,
82
+ additionalContext: [String: any Sendable]?
83
+ ) throws -> [Int] {
84
+ do {
85
+ return try upstream.applyChatTemplate(
86
+ messages: messages,
87
+ tools: tools,
88
+ additionalContext: additionalContext
89
+ )
90
+ } catch Tokenizers.TokenizerError.missingChatTemplate {
91
+ throw MLXLMCommon.TokenizerError.missingChatTemplate
92
+ }
93
+ }
94
+
95
+ func applyChatTemplate(
96
+ messages: [[String: any Sendable]],
97
+ chatTemplate: Tokenizers.ChatTemplateArgument?,
98
+ addGenerationPrompt: Bool,
99
+ truncation: Bool,
100
+ maxLength: Int?,
101
+ tools: [[String: any Sendable]]?,
102
+ additionalContext: [String: any Sendable]?
103
+ ) throws -> [Int] {
104
+ do {
105
+ return try upstream.applyChatTemplate(
106
+ messages: messages,
107
+ chatTemplate: chatTemplate,
108
+ addGenerationPrompt: addGenerationPrompt,
109
+ truncation: truncation,
110
+ maxLength: maxLength,
111
+ tools: tools,
112
+ additionalContext: additionalContext
113
+ )
114
+ } catch Tokenizers.TokenizerError.missingChatTemplate {
115
+ throw MLXLMCommon.TokenizerError.missingChatTemplate
116
+ }
117
+ }
118
+ }
119
+
120
+ private struct TransformersTokenizerLoader: MLXLMCommon.TokenizerLoader {
121
+ func load(from directory: URL) async throws -> any MLXLMCommon.Tokenizer {
122
+ let upstream = try await Tokenizers.AutoTokenizer.from(modelFolder: directory)
123
+ return TransformersTokenizerBridge(upstream)
124
+ }
125
+ }
126
+
127
+ @MainActor
128
+ private final class LoadProgress {
129
+ private let callback: ((Double) -> Void)?
130
+ private(set) var value: Double = 0
131
+
132
+ init(callback: ((Double) -> Void)?) {
133
+ self.callback = callback
134
+ }
135
+
136
+ func set(_ nextValue: Double) {
137
+ let clamped = min(1.0, max(value, nextValue))
138
+ guard clamped > value else { return }
139
+ value = clamped
140
+ callback?(clamped)
141
+ }
142
+
143
+ func tick(toward upperBound: Double, step: Double) {
144
+ set(min(upperBound, value + step))
145
+ }
146
+ }
7
147
 
8
148
  class HybridLLM: HybridLLMSpec {
9
149
  private var session: ChatSession?
@@ -16,7 +156,6 @@ class HybridLLM: HybridLLMSpec {
16
156
  totalTime: 0,
17
157
  toolExecutionTime: 0
18
158
  )
19
- private var modelFactory: ModelFactory = LLMModelFactory.shared
20
159
  private var manageHistory: Bool = false
21
160
  private var messageHistory: [LLMMessage] = []
22
161
  private var loadTask: Task<Void, Error>?
@@ -142,6 +281,9 @@ class HybridLLM: HybridLLMSpec {
142
281
 
143
282
  return Promise.async { [self] in
144
283
  let task = Task { @MainActor in
284
+ let progress = LoadProgress(callback: options?.onProgress)
285
+ progress.set(0.02)
286
+
145
287
  Memory.cacheLimit = 2000000
146
288
 
147
289
  self.currentTask?.cancel()
@@ -151,6 +293,7 @@ class HybridLLM: HybridLLMSpec {
151
293
  self.tools = []
152
294
  self.toolSchemas = []
153
295
  Memory.clearCache()
296
+ progress.set(0.12)
154
297
 
155
298
  let memoryAfterCleanup = self.getMemoryUsage()
156
299
  let gpuAfterCleanup = self.getGPUMemoryUsage()
@@ -158,13 +301,30 @@ class HybridLLM: HybridLLMSpec {
158
301
 
159
302
  let modelDir = await ModelDownloader.shared.getModelDirectory(modelId: modelId)
160
303
  log("Loading from directory: \(modelDir.path)")
304
+ progress.set(0.18)
161
305
 
162
- let config = ModelConfiguration(directory: modelDir)
163
- let loadedContainer = try await self.modelFactory.loadContainer(
164
- configuration: config
165
- ) { progress in
166
- options?.onProgress?(progress.fractionCompleted)
306
+ if isJANGModel(at: modelDir) {
307
+ throw LLMError.unsupportedModel("JANG model format is not supported by the current MLX dependency set")
308
+ }
309
+
310
+ var loadingPulse: Task<Void, Never>?
311
+ if options?.onProgress != nil {
312
+ loadingPulse = Task { @MainActor in
313
+ while !Task.isCancelled {
314
+ try? await Task.sleep(nanoseconds: 250_000_000)
315
+ progress.tick(toward: 0.78, step: 0.04)
316
+ }
317
+ }
167
318
  }
319
+ defer { loadingPulse?.cancel() }
320
+
321
+ progress.set(0.22)
322
+
323
+ let loadedContainer = try await LLMModelFactory.shared.loadContainer(
324
+ from: modelDir,
325
+ using: TransformersTokenizerLoader()
326
+ )
327
+ progress.set(0.86)
168
328
 
169
329
  try Task.checkCancellation()
170
330
 
@@ -194,6 +354,9 @@ class HybridLLM: HybridLLMSpec {
194
354
  let updatedExtra = await loadedContainer.configuration.extraEOSTokens
195
355
  log("EOS patched - ids: \(updated), extra: \(updatedExtra)")
196
356
  }
357
+ progress.set(0.92)
358
+
359
+ progress.set(0.95)
197
360
 
198
361
  let memoryAfterContainer = self.getMemoryUsage()
199
362
  let gpuAfterContainer = self.getGPUMemoryUsage()
@@ -205,10 +368,14 @@ class HybridLLM: HybridLLMSpec {
205
368
  log("Loaded \(self.tools.count) tools: \(self.tools.map { $0.name })")
206
369
  }
207
370
 
208
- let additionalContextDict: [String: Any]? = if let messages = options?.additionalContext {
209
- ["messages": messages.map { ["role": $0.role, "content": $0.content] }]
371
+ let additionalContextDict: [String: any Sendable]?
372
+ if let messages = options?.additionalContext {
373
+ let contextMessages: [[String: String]] = messages.map {
374
+ ["role": $0.role, "content": $0.content]
375
+ }
376
+ additionalContextDict = ["messages": contextMessages]
210
377
  } else {
211
- nil
378
+ additionalContextDict = nil
212
379
  }
213
380
 
214
381
  self.container = loadedContainer
@@ -221,6 +388,8 @@ class HybridLLM: HybridLLMSpec {
221
388
  if self.manageHistory {
222
389
  log("History management enabled with \(self.messageHistory.count) initial messages")
223
390
  }
391
+
392
+ progress.set(1.0)
224
393
  }
225
394
 
226
395
  self.loadTask = task
@@ -471,7 +640,7 @@ class HybridLLM: HybridLLMSpec {
471
640
  additionalContext: additionalCtx
472
641
  )
473
642
  self.log("template_applied token_count=\(result.count)")
474
- let decoded = context.tokenizer.decode(tokens: Array(result.suffix(60)))
643
+ let decoded = context.tokenizer.decode(tokenIds: Array(result.suffix(60)))
475
644
  self.log("input_tail_decoded: \(decoded)")
476
645
  self.lastInputContainedThinkTag = decoded.contains("<think>")
477
646
  return result
@@ -487,7 +656,7 @@ class HybridLLM: HybridLLMSpec {
487
656
  additionalContext: additionalCtx
488
657
  )
489
658
  self.log("fallback_template_applied token_count=\(result.count)")
490
- let decoded = context.tokenizer.decode(tokens: Array(result.suffix(60)))
659
+ let decoded = context.tokenizer.decode(tokenIds: Array(result.suffix(60)))
491
660
  self.log("fallback_input_tail_decoded: \(decoded)")
492
661
  self.lastInputContainedThinkTag = decoded.contains("<think>")
493
662
  return result
@@ -3,4 +3,5 @@ import Foundation
3
3
  public enum LLMError: Error {
4
4
  case notLoaded
5
5
  case generationFailed(String)
6
+ case unsupportedModel(String)
6
7
  }
package/package.json CHANGED
@@ -1,7 +1,7 @@
1
1
  {
2
2
  "name": "@inferrlm/react-native-mlx",
3
3
  "description": "MLX Swift integration for React Native - InferrLM fork with enhanced features",
4
- "version": "0.4.2",
4
+ "version": "0.4.8",
5
5
  "main": "./lib/module/index.js",
6
6
  "module": "./lib/module/index.js",
7
7
  "types": "./lib/typescript/src/index.d.ts",