@inferrlm/react-native-mlx 0.2.0-inferrlm.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/MLXReactNative.podspec +42 -0
- package/ios/Bridge.h +8 -0
- package/ios/Sources/HybridLLM.swift +245 -0
- package/ios/Sources/HybridModelManager.swift +77 -0
- package/ios/Sources/LLMError.swift +6 -0
- package/ios/Sources/MLXReactNative.h +16 -0
- package/ios/Sources/ModelDownloader.swift +103 -0
- package/lib/module/index.js +6 -0
- package/lib/module/index.js.map +1 -0
- package/lib/module/llm.js +125 -0
- package/lib/module/llm.js.map +1 -0
- package/lib/module/modelManager.js +79 -0
- package/lib/module/modelManager.js.map +1 -0
- package/lib/module/models.js +41 -0
- package/lib/module/models.js.map +1 -0
- package/lib/module/package.json +1 -0
- package/lib/module/specs/LLM.nitro.js +4 -0
- package/lib/module/specs/LLM.nitro.js.map +1 -0
- package/lib/module/specs/ModelManager.nitro.js +4 -0
- package/lib/module/specs/ModelManager.nitro.js.map +1 -0
- package/lib/typescript/package.json +1 -0
- package/lib/typescript/src/index.d.ts +6 -0
- package/lib/typescript/src/index.d.ts.map +1 -0
- package/lib/typescript/src/llm.d.ts +87 -0
- package/lib/typescript/src/llm.d.ts.map +1 -0
- package/lib/typescript/src/modelManager.d.ts +53 -0
- package/lib/typescript/src/modelManager.d.ts.map +1 -0
- package/lib/typescript/src/models.d.ts +29 -0
- package/lib/typescript/src/models.d.ts.map +1 -0
- package/lib/typescript/src/specs/LLM.nitro.d.ts +88 -0
- package/lib/typescript/src/specs/LLM.nitro.d.ts.map +1 -0
- package/lib/typescript/src/specs/ModelManager.nitro.d.ts +41 -0
- package/lib/typescript/src/specs/ModelManager.nitro.d.ts.map +1 -0
- package/nitrogen/generated/.gitattributes +1 -0
- package/nitrogen/generated/ios/MLXReactNative+autolinking.rb +60 -0
- package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Bridge.cpp +98 -0
- package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Bridge.hpp +399 -0
- package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Umbrella.hpp +62 -0
- package/nitrogen/generated/ios/MLXReactNativeAutolinking.mm +41 -0
- package/nitrogen/generated/ios/MLXReactNativeAutolinking.swift +40 -0
- package/nitrogen/generated/ios/c++/HybridLLMSpecSwift.cpp +11 -0
- package/nitrogen/generated/ios/c++/HybridLLMSpecSwift.hpp +160 -0
- package/nitrogen/generated/ios/c++/HybridModelManagerSpecSwift.cpp +11 -0
- package/nitrogen/generated/ios/c++/HybridModelManagerSpecSwift.hpp +116 -0
- package/nitrogen/generated/ios/swift/Func_void.swift +47 -0
- package/nitrogen/generated/ios/swift/Func_void_bool.swift +47 -0
- package/nitrogen/generated/ios/swift/Func_void_double.swift +47 -0
- package/nitrogen/generated/ios/swift/Func_void_std__exception_ptr.swift +47 -0
- package/nitrogen/generated/ios/swift/Func_void_std__string.swift +47 -0
- package/nitrogen/generated/ios/swift/Func_void_std__vector_std__string_.swift +47 -0
- package/nitrogen/generated/ios/swift/GenerationStats.swift +69 -0
- package/nitrogen/generated/ios/swift/HybridLLMSpec.swift +67 -0
- package/nitrogen/generated/ios/swift/HybridLLMSpec_cxx.swift +285 -0
- package/nitrogen/generated/ios/swift/HybridModelManagerSpec.swift +60 -0
- package/nitrogen/generated/ios/swift/HybridModelManagerSpec_cxx.swift +234 -0
- package/nitrogen/generated/ios/swift/LLMLoadOptions.swift +138 -0
- package/nitrogen/generated/ios/swift/LLMMessage.swift +47 -0
- package/nitrogen/generated/shared/c++/GenerationStats.hpp +87 -0
- package/nitrogen/generated/shared/c++/HybridLLMSpec.cpp +35 -0
- package/nitrogen/generated/shared/c++/HybridLLMSpec.hpp +87 -0
- package/nitrogen/generated/shared/c++/HybridModelManagerSpec.cpp +27 -0
- package/nitrogen/generated/shared/c++/HybridModelManagerSpec.hpp +70 -0
- package/nitrogen/generated/shared/c++/LLMLoadOptions.hpp +87 -0
- package/nitrogen/generated/shared/c++/LLMMessage.hpp +79 -0
- package/package.json +142 -0
- package/src/index.ts +6 -0
- package/src/llm.ts +144 -0
- package/src/modelManager.ts +88 -0
- package/src/models.ts +45 -0
- package/src/specs/LLM.nitro.ts +98 -0
- package/src/specs/ModelManager.nitro.ts +44 -0
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
require "json"

# Read package metadata from package.json so the podspec stays in sync
# with the published npm package (version, description, license, etc.).
package = JSON.parse(File.read(File.join(__dir__, "package.json")))

Pod::Spec.new do |s|
  s.name = "MLXReactNative"
  s.version = package["version"]
  s.summary = package["description"]
  s.homepage = package["homepage"]
  s.license = package["license"]
  s.authors = package["author"]

  # CocoaPods expects deployment-target versions as strings; numeric
  # literals can lose trailing zeros (e.g. 16.10 would become "16.1").
  s.platforms = { :ios => "26.0", :visionos => "1.0" }
  s.source = { :git => "https://github.com/corasan/react-native-nitro-mlx.git", :tag => "#{s.version}" }

  s.source_files = [
    # Implementation (Swift)
    "ios/Sources/**/*.{swift}",
    # Autolinking/Registration (Objective-C++)
    "ios/**/*.{m,mm}",
    # Implementation (C++ objects)
    "cpp/**/*.{hpp,cpp}",
  ]

  # Pull the MLX Swift LM packages in via Swift Package Manager.
  spm_dependency(s,
    url: "https://github.com/ml-explore/mlx-swift-lm.git",
    requirement: {kind: "upToNextMinorVersion", minimumVersion: "2.29.2"},
    products: ["MLXLLM", "MLXLMCommon"]
  )

  s.pod_target_xcconfig = {
    # C++ compiler flags, mainly for folly.
    "GCC_PREPROCESSOR_DEFINITIONS" => "$(inherited) FOLLY_NO_CONFIG FOLLY_CFG_NO_COROUTINES"
  }

  # Register the Nitrogen-generated sources with this pod.
  load 'nitrogen/generated/ios/MLXReactNative+autolinking.rb'
  add_nitrogen_files(s)

  s.dependency 'React-jsi'
  s.dependency 'React-callinvoker'
  install_modules_dependencies(s)
end
|
package/ios/Bridge.h
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
import Foundation
|
|
2
|
+
import NitroModules
|
|
3
|
+
internal import MLX
|
|
4
|
+
internal import MLXLLM
|
|
5
|
+
internal import MLXLMCommon
|
|
6
|
+
|
|
7
|
+
/// Nitro hybrid object exposing on-device LLM inference (via MLX) to JS.
///
/// Lifecycle: `load` builds a `ChatSession` from a locally downloaded model,
/// `generate`/`stream` run inference through that session, `stop` cancels the
/// in-flight task, and `unload` drops all references and clears the GPU cache.
class HybridLLM: HybridLLMSpec {
  // Active chat session; nil until `load` completes (drives `isLoaded`).
  private var session: ChatSession?
  // In-flight generation task; non-nil only while generating (drives `isGenerating`).
  private var currentTask: Task<String, Error>?
  // Keeps the loaded model container alive; stored as Any to avoid exposing
  // the MLX type through the spec.
  private var container: Any?
  // Stats from the most recent *streamed* generation.
  // NOTE(review): `generate` never updates this — only `stream` does. Confirm
  // whether that asymmetry is intentional.
  private var lastStats: GenerationStats = GenerationStats(
    tokenCount: 0,
    tokensPerSecond: 0,
    timeToFirstToken: 0,
    totalTime: 0
  )
  private var modelFactory: ModelFactory = LLMModelFactory.shared
  // When true, user/assistant turns are appended to `messageHistory`.
  private var manageHistory: Bool = false
  private var messageHistory: [LLMMessage] = []

  var isLoaded: Bool { session != nil }
  var isGenerating: Bool { currentTask != nil }
  // ID of the currently loaded model; empty string when none.
  var modelId: String = ""
  // Enables console logging via `log(_:)`.
  var debug: Bool = false
  // System prompt applied when the session is created; changing it after
  // `load` has no effect until the model is reloaded.
  var systemPrompt: String = "You are a helpful assistant."
  var additionalContext: LLMMessage = LLMMessage()

  // Debug-gated console logger.
  private func log(_ message: String) {
    if debug {
      print("[MLXReactNative.HybridLLM] \(message)")
    }
  }

  /// Returns this process's resident memory as a human-readable string,
  /// or "unknown" if the Mach call fails. Used only for debug logging.
  private func getMemoryUsage() -> String {
    var taskInfo = mach_task_basic_info()
    // task_info expects the count in units of integer_t (4 bytes).
    var count = mach_msg_type_number_t(MemoryLayout<mach_task_basic_info>.size)/4
    let result: kern_return_t = withUnsafeMutablePointer(to: &taskInfo) {
      $0.withMemoryRebound(to: integer_t.self, capacity: 1) {
        task_info(mach_task_self_,
                  task_flavor_t(MACH_TASK_BASIC_INFO),
                  $0,
                  &count)
      }
    }

    if result == KERN_SUCCESS {
      let usedMB = Float(taskInfo.resident_size) / 1024.0 / 1024.0
      return String(format: "%.1f MB", usedMB)
    } else {
      return "unknown"
    }
  }

  /// Returns MLX GPU memory usage (active/cache/peak) as a human-readable
  /// string. Used only for debug logging.
  private func getGPUMemoryUsage() -> String {
    let snapshot = GPU.snapshot()
    let allocatedMB = Float(snapshot.activeMemory) / 1024.0 / 1024.0
    let cacheMB = Float(snapshot.cacheMemory) / 1024.0 / 1024.0
    let peakMB = Float(snapshot.peakMemory) / 1024.0 / 1024.0
    return String(format: "Allocated: %.1f MB, Cache: %.1f MB, Peak: %.1f MB",
                  allocatedMB, cacheMB, peakMB)
  }

  /// Loads the model with `modelId` from its local download directory,
  /// tearing down any previously loaded model first.
  ///
  /// - Parameters:
  ///   - modelId: HuggingFace-style model id (must already be downloaded).
  ///   - options: optional load progress callback, initial chat context,
  ///     and history-management flag.
  /// - Returns: a promise that resolves once the session is ready.
  func load(modelId: String, options: LLMLoadOptions?) throws -> Promise<Void> {
    return Promise.async { [self] in
      // NOTE(review): cacheLimit is presumably in bytes (~2 MB) — confirm
      // against the MLX GPU.set(cacheLimit:) documentation.
      MLX.GPU.set(cacheLimit: 2000000)

      // Tear down any previous model before loading a new one.
      self.currentTask?.cancel()
      self.currentTask = nil
      self.session = nil
      self.container = nil
      MLX.GPU.clearCache()

      let memoryAfterCleanup = self.getMemoryUsage()
      let gpuAfterCleanup = self.getGPUMemoryUsage()
      log("After cleanup - Host: \(memoryAfterCleanup), GPU: \(gpuAfterCleanup)")

      let modelDir = await ModelDownloader.shared.getModelDirectory(modelId: modelId)
      log("Loading from directory: \(modelDir.path)")

      let config = ModelConfiguration(directory: modelDir)
      let loadedContainer = try await modelFactory.loadContainer(
        configuration: config
      ) { progress in
        // Forward fractional load progress (0-1) to JS, if a callback was given.
        options?.onProgress?(progress.fractionCompleted)
      }

      let memoryAfterContainer = self.getMemoryUsage()
      let gpuAfterContainer = self.getGPUMemoryUsage()
      log("Model loaded - Host: \(memoryAfterContainer), GPU: \(gpuAfterContainer)")

      // Convert [LLMMessage]? into the dictionary shape ChatSession expects.
      let additionalContextDict: [String: Any]? = if let messages = options?.additionalContext {
        ["messages": messages.map { ["role": $0.role, "content": $0.content] }]
      } else {
        nil
      }

      self.container = loadedContainer
      self.session = ChatSession(loadedContainer, instructions: self.systemPrompt, additionalContext: additionalContextDict)
      self.modelId = modelId

      self.manageHistory = options?.manageHistory ?? false
      // Seed the managed history with the initial context, if any.
      self.messageHistory = options?.additionalContext ?? []

      if self.manageHistory {
        log("History management enabled with \(self.messageHistory.count) initial messages")
      }
    }
  }

  /// Generates a complete response for `prompt` (non-streaming).
  ///
  /// - Throws: `LLMError.notLoaded` (synchronously) if no model is loaded.
  /// - Returns: a promise resolving with the full generated text.
  ///
  /// NOTE(review): the session is captured before the async block runs, so a
  /// concurrent `load`/`unload` won't change which session this call uses.
  func generate(prompt: String) throws -> Promise<String> {
    guard let session = session else {
      throw LLMError.notLoaded
    }

    return Promise.async { [self] in
      if self.manageHistory {
        self.messageHistory.append(LLMMessage(role: "user", content: prompt))
      }

      let task = Task<String, Error> {
        log("Generating response for: \(prompt.prefix(50))...")
        let result = try await session.respond(to: prompt)
        log("Generation complete")
        return result
      }

      // Expose the task so `stop()` can cancel it and `isGenerating` is true.
      self.currentTask = task

      do {
        let result = try await task.value
        self.currentTask = nil

        if self.manageHistory {
          self.messageHistory.append(LLMMessage(role: "assistant", content: result))
        }

        return result
      } catch {
        // Always clear the task reference, even on failure/cancellation.
        self.currentTask = nil
        throw error
      }
    }
  }

  /// Streams a response token by token, invoking `onToken` for each chunk,
  /// and records timing statistics into `lastStats`.
  ///
  /// - Throws: `LLMError.notLoaded` (synchronously) if no model is loaded.
  /// - Returns: a promise resolving with the concatenated full text.
  func stream(prompt: String, onToken: @escaping (String) -> Void) throws -> Promise<String> {
    guard let session = session else {
      throw LLMError.notLoaded
    }

    return Promise.async { [self] in
      if self.manageHistory {
        self.messageHistory.append(LLMMessage(role: "user", content: prompt))
      }

      let task = Task<String, Error> {
        var result = ""
        var tokenCount = 0
        let startTime = Date()
        var firstTokenTime: Date?

        log("Streaming response for: \(prompt.prefix(50))...")
        for try await chunk in session.streamResponse(to: prompt) {
          // Cancellation (via stop()) ends the stream but still returns
          // whatever text was produced so far.
          if Task.isCancelled { break }

          if firstTokenTime == nil {
            firstTokenTime = Date()
          }
          tokenCount += 1
          result += chunk
          onToken(chunk)
        }

        // Times below are in milliseconds.
        let endTime = Date()
        let totalTime = endTime.timeIntervalSince(startTime) * 1000
        let timeToFirstToken = (firstTokenTime ?? endTime).timeIntervalSince(startTime) * 1000
        let tokensPerSecond = totalTime > 0 ? Double(tokenCount) / (totalTime / 1000) : 0

        self.lastStats = GenerationStats(
          tokenCount: Double(tokenCount),
          tokensPerSecond: tokensPerSecond,
          timeToFirstToken: timeToFirstToken,
          totalTime: totalTime
        )

        log("Stream complete - \(tokenCount) tokens, \(String(format: "%.1f", tokensPerSecond)) tokens/s")
        return result
      }

      self.currentTask = task

      do {
        let result = try await task.value
        self.currentTask = nil

        if self.manageHistory {
          self.messageHistory.append(LLMMessage(role: "assistant", content: result))
        }

        return result
      } catch {
        self.currentTask = nil
        throw error
      }
    }
  }

  /// Cancels the current generation, if any. Safe to call when idle.
  func stop() throws {
    currentTask?.cancel()
    currentTask = nil
  }

  /// Releases the loaded model, history, and GPU cache. After this call
  /// `isLoaded` is false and `modelId` is empty.
  func unload() throws {
    let memoryBefore = getMemoryUsage()
    let gpuBefore = getGPUMemoryUsage()
    log("Before unload - Host: \(memoryBefore), GPU: \(gpuBefore)")

    currentTask?.cancel()
    currentTask = nil
    session = nil
    container = nil
    messageHistory = []
    manageHistory = false
    modelId = ""

    MLX.GPU.clearCache()

    let memoryAfter = getMemoryUsage()
    let gpuAfter = getGPUMemoryUsage()
    log("After unload - Host: \(memoryAfter), GPU: \(gpuAfter)")
  }

  /// Returns stats from the last streamed generation (zeros if none yet).
  func getLastGenerationStats() throws -> GenerationStats {
    return lastStats
  }

  /// Returns the managed message history (empty unless manageHistory was
  /// enabled at load time).
  func getHistory() throws -> [LLMMessage] {
    return messageHistory
  }

  /// Clears the managed message history without unloading the model.
  func clearHistory() throws {
    messageHistory = []
    log("Message history cleared")
  }
}
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import Foundation
|
|
2
|
+
import NitroModules
|
|
3
|
+
internal import MLXLMCommon
|
|
4
|
+
internal import MLXLLM
|
|
5
|
+
|
|
6
|
+
/// Nitro hybrid object exposing model download / cache management to JS.
/// All real work is delegated to the `ModelDownloader` actor.
class HybridModelManager: HybridModelManagerSpec {
  private let fileManager = FileManager.default

  /// Debug flag is shared with the downloader so both sides log consistently.
  var debug: Bool {
    get { ModelDownloader.debug }
    set { ModelDownloader.debug = newValue }
  }

  // Debug-gated console logger.
  private func log(_ message: String) {
    guard debug else { return }
    print("[MLXReactNative.HybridModelManager] \(message)")
  }

  /// Downloads the files for `modelId`, forwarding fractional progress
  /// (0-1) to `progressCallback`, and resolves with the local directory path.
  func download(
    modelId: String,
    progressCallback: @escaping (Double) -> Void
  ) throws -> Promise<String> {
    return Promise.async { [self] in
      log("Starting download for: \(modelId)")

      let directory = try await ModelDownloader.shared.download(
        modelId: modelId,
        progressCallback: progressCallback
      )

      log("Download complete: \(directory.path)")
      return directory.path
    }
  }

  /// Resolves with true when all required files for `modelId` exist locally.
  func isDownloaded(modelId: String) throws -> Promise<Bool> {
    return Promise.async {
      await ModelDownloader.shared.isDownloaded(modelId: modelId)
    }
  }

  /// Lists the ids of all locally cached models by scanning the
  /// Documents/huggingface/models directory.
  func getDownloadedModels() throws -> Promise<[String]> {
    return Promise.async { [self] in
      let documents = fileManager.urls(for: .documentDirectory, in: .userDomainMask).first!
      let modelsDir = documents.appendingPathComponent("huggingface/models")

      guard fileManager.fileExists(atPath: modelsDir.path) else {
        return []
      }

      let entries = try fileManager.contentsOfDirectory(
        at: modelsDir,
        includingPropertiesForKeys: [.isDirectoryKey]
      )

      // Directory names encode model ids with "/" replaced by "_";
      // reverse that mapping when reporting ids back to JS.
      // NOTE(review): this is lossy for ids that legitimately contain "_".
      return entries.compactMap { entry -> String? in
        var isDirectory: ObjCBool = false
        guard fileManager.fileExists(atPath: entry.path, isDirectory: &isDirectory),
              isDirectory.boolValue else {
          return nil
        }
        return entry.lastPathComponent.replacingOccurrences(of: "_", with: "/")
      }
    }
  }

  /// Deletes the cached files for `modelId`, if any exist.
  func deleteModel(modelId: String) throws -> Promise<Void> {
    return Promise.async {
      try await ModelDownloader.shared.deleteModel(modelId: modelId)
    }
  }

  /// Resolves with the local directory path the model occupies (or would
  /// occupy once downloaded).
  func getModelPath(modelId: String) throws -> Promise<String> {
    return Promise.async {
      await ModelDownloader.shared.getModelDirectory(modelId: modelId).path
    }
  }
}
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
//
// MLXReactNative.h
// MLXReactNative
//
// Created by Henry on 2/20/25.
//
// Umbrella header for the MLXReactNative framework target: exposes the
// framework's version symbols and is the place to re-export public headers.

#import <Foundation/Foundation.h>

//! Project version number for MLXReactNative.
FOUNDATION_EXPORT double MLXReactNativeVersionNumber;

//! Project version string for MLXReactNative.
FOUNDATION_EXPORT const unsigned char MLXReactNativeVersionString[];

// In this header, you should import all the public headers of your framework using statements like #import <MLXReactNative/PublicHeader.h>
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
import Foundation
|
|
2
|
+
|
|
3
|
+
/// Downloads model weight/tokenizer files from the Hugging Face Hub into the
/// app's Documents directory and manages the on-disk model cache.
/// The `actor` serializes access to downloads and file-system mutations.
actor ModelDownloader: NSObject {
  /// Errors surfaced when a required model file cannot be fetched.
  /// Previously these cases were only logged and skipped, which could leave a
  /// partial, unloadable model directory while still reporting success.
  enum DownloadError: Error, LocalizedError {
    case invalidURL(String)
    case invalidResponse(file: String)
    case httpStatus(file: String, code: Int)

    var errorDescription: String? {
      switch self {
      case .invalidURL(let url):
        return "Invalid download URL: \(url)"
      case .invalidResponse(let file):
        return "Non-HTTP response while downloading \(file)"
      case .httpStatus(let file, let code):
        return "Failed to download \(file) (HTTP \(code))"
      }
    }
  }

  static let shared = ModelDownloader()
  // Enables console logging for all downloader instances.
  static var debug: Bool = false

  private let fileManager = FileManager.default

  // Debug-gated console logger.
  private func log(_ message: String) {
    if Self.debug {
      print("[Downloader] \(message)")
    }
  }

  /// Downloads every required model file for `modelId`, reporting fractional
  /// progress (0-1) after each file. Files already on disk are skipped.
  ///
  /// - Parameters:
  ///   - modelId: HuggingFace-style id, e.g. "mlx-community/Qwen3-0.6B-4bit".
  ///   - progressCallback: invoked after each file with completed/total.
  /// - Throws: `DownloadError` if any required file cannot be downloaded,
  ///   plus any file-system errors from creating/moving files.
  /// - Returns: the directory containing the downloaded model files.
  func download(
    modelId: String,
    progressCallback: @escaping (Double) -> Void
  ) async throws -> URL {
    // Minimal file set needed to load a single-shard MLX model.
    // NOTE(review): multi-shard models (model-0000x-of-0000y.safetensors)
    // are not covered by this list — confirm supported model formats.
    let requiredFiles = [
      "config.json",
      "tokenizer.json",
      "tokenizer_config.json",
      "model.safetensors"
    ]

    let modelDir = getModelDirectory(modelId: modelId)
    try fileManager.createDirectory(at: modelDir, withIntermediateDirectories: true)

    log("Model directory: \(modelDir.path)")
    log("Files to download: \(requiredFiles)")

    var downloaded = 0

    for file in requiredFiles {
      let destURL = modelDir.appendingPathComponent(file)

      // Resume-friendly: skip files already present from a previous run.
      if fileManager.fileExists(atPath: destURL.path) {
        log("File exists, skipping: \(file)")
        downloaded += 1
        progressCallback(Double(downloaded) / Double(requiredFiles.count))
        continue
      }

      let urlString = "https://huggingface.co/\(modelId)/resolve/main/\(file)"
      guard let url = URL(string: urlString) else {
        log("Invalid URL: \(urlString)")
        throw DownloadError.invalidURL(urlString)
      }

      log("Downloading: \(file)")

      let (tempURL, response) = try await URLSession.shared.download(from: url)

      guard let httpResponse = response as? HTTPURLResponse else {
        log("Invalid response for: \(file)")
        throw DownloadError.invalidResponse(file: file)
      }

      log("Response status: \(httpResponse.statusCode) for \(file)")

      // Fail fast on any non-200: every file in the list is required, so
      // continuing would leave a broken model directory.
      guard httpResponse.statusCode == 200 else {
        log("Failed to download: \(file) - Status: \(httpResponse.statusCode)")
        throw DownloadError.httpStatus(file: file, code: httpResponse.statusCode)
      }

      // Replace any stale file before moving the fresh download into place.
      if fileManager.fileExists(atPath: destURL.path) {
        try fileManager.removeItem(at: destURL)
      }
      try fileManager.moveItem(at: tempURL, to: destURL)
      log("Saved: \(file)")

      downloaded += 1
      progressCallback(Double(downloaded) / Double(requiredFiles.count))
    }

    return modelDir
  }

  /// Returns true when the minimum set of files needed to load the model
  /// is present on disk.
  func isDownloaded(modelId: String) -> Bool {
    let modelDir = getModelDirectory(modelId: modelId)
    let requiredFiles = ["config.json", "model.safetensors", "tokenizer.json"]

    let allExist = requiredFiles.allSatisfy { file in
      fileManager.fileExists(atPath: modelDir.appendingPathComponent(file).path)
    }

    log("isDownloaded(\(modelId)): \(allExist)")
    return allExist
  }

  /// Maps a model id like "org/name" onto
  /// Documents/huggingface/models/org_name.
  /// NOTE(review): replacing "/" with "_" is lossy for ids that already
  /// contain "_"; callers that reverse this mapping should be aware.
  func getModelDirectory(modelId: String) -> URL {
    let docsDir = fileManager.urls(for: .documentDirectory, in: .userDomainMask).first!
    return docsDir
      .appendingPathComponent("huggingface/models")
      .appendingPathComponent(modelId.replacingOccurrences(of: "/", with: "_"))
  }

  /// Removes the model's directory (and everything inside it) if it exists.
  func deleteModel(modelId: String) throws {
    let modelDir = getModelDirectory(modelId: modelId)
    if fileManager.fileExists(atPath: modelDir.path) {
      try fileManager.removeItem(at: modelDir)
    }
  }
}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"names":["LLM","ModelManager","MLXModel"],"sourceRoot":"../../src","sources":["index.ts"],"mappings":";;AAAA,SAASA,GAAG,QAAsB,UAAO;AACzC,SAASC,YAAY,QAAQ,mBAAgB;AAC7C,SAASC,QAAQ,QAAQ,aAAU","ignoreList":[]}
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
"use strict";

import { NitroModules } from 'react-native-nitro-modules';

// Lazily-created native hybrid object. Creation is deferred until first use
// so importing this module does not require the native side to be ready.
let nativeLLM = null;
function native() {
  if (!nativeLLM) {
    nativeLLM = NitroModules.createHybridObject('LLM');
  }
  return nativeLLM;
}

/**
 * LLM text generation using MLX on Apple Silicon.
 *
 * @example
 * ```ts
 * import { LLM } from 'react-native-nitro-mlx'
 *
 * // Load a model
 * await LLM.load('mlx-community/Qwen3-0.6B-4bit', progress => {
 *   console.log(`Loading: ${(progress * 100).toFixed(0)}%`)
 * })
 *
 * // Stream a response
 * await LLM.stream('Hello!', token => {
 *   process.stdout.write(token)
 * })
 *
 * // Get generation stats
 * const stats = LLM.getLastGenerationStats()
 * console.log(`${stats.tokensPerSecond} tokens/sec`)
 * ```
 */
export const LLM = {
  /**
   * Load a model into memory. Downloads the model from HuggingFace if not
   * already cached.
   * @param modelId - HuggingFace model ID (e.g., 'mlx-community/Qwen3-0.6B-4bit')
   * @param options - Callback invoked with loading progress (0-1)
   */
  load(modelId, options) {
    return native().load(modelId, options);
  },

  /**
   * Generate a complete response for a prompt. Blocks until generation is
   * complete. For streaming responses, use `stream()` instead.
   * @param prompt - The input text to generate a response for
   * @returns The complete generated text
   */
  generate(prompt) {
    return native().generate(prompt);
  },

  /**
   * Stream a response token by token.
   * @param prompt - The input text to generate a response for
   * @param onToken - Callback invoked for each generated token
   * @returns The complete generated text
   */
  stream(prompt, onToken) {
    return native().stream(prompt, onToken);
  },

  /** Stop the current generation. Safe to call even if not generating. */
  stop() {
    native().stop();
  },

  /**
   * Unload the current model and release memory. Call this when you're done
   * with the model to free up memory.
   */
  unload() {
    native().unload();
  },

  /**
   * Get statistics from the last generation.
   * @returns Statistics including token count, tokens/sec, TTFT, and total time
   */
  getLastGenerationStats() {
    return native().getLastGenerationStats();
  },

  /**
   * Get the message history if management is enabled.
   * @returns Array of messages in the history
   */
  getHistory() {
    return native().getHistory();
  },

  /** Clear the message history. */
  clearHistory() {
    native().clearHistory();
  },

  /** Whether a model is currently loaded and ready for generation */
  get isLoaded() {
    return native().isLoaded;
  },

  /** Whether text is currently being generated */
  get isGenerating() {
    return native().isGenerating;
  },

  /** The ID of the currently loaded model, or empty string if none */
  get modelId() {
    return native().modelId;
  },

  /** Enable debug logging to console */
  get debug() {
    return native().debug;
  },
  set debug(value) {
    native().debug = value;
  },

  /**
   * System prompt used when loading the model.
   * Set this before calling `load()`. Changes require reloading the model.
   * @default "You are a helpful assistant."
   */
  get systemPrompt() {
    return native().systemPrompt;
  },
  set systemPrompt(value) {
    native().systemPrompt = value;
  }
};
//# sourceMappingURL=llm.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"names":["NitroModules","instance","getInstance","createHybridObject","LLM","load","modelId","options","generate","prompt","stream","onToken","stop","unload","getLastGenerationStats","getHistory","clearHistory","isLoaded","isGenerating","debug","value","systemPrompt"],"sourceRoot":"../../src","sources":["llm.ts"],"mappings":";;AAAA,SAASA,YAAY,QAAQ,4BAA4B;AAGzD,IAAIC,QAAwB,GAAG,IAAI;AAOnC,SAASC,WAAWA,CAAA,EAAY;EAC9B,IAAI,CAACD,QAAQ,EAAE;IACbA,QAAQ,GAAGD,YAAY,CAACG,kBAAkB,CAAU,KAAK,CAAC;EAC5D;EACA,OAAOF,QAAQ;AACjB;;AAEA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA;AACA,OAAO,MAAMG,GAAG,GAAG;EACjB;AACF;AACA;AACA;AACA;EACEC,IAAIA,CAACC,OAAe,EAAEC,OAAuB,EAAiB;IAC5D,OAAOL,WAAW,CAAC,CAAC,CAACG,IAAI,CAACC,OAAO,EAAEC,OAAO,CAAC;EAC7C,CAAC;EAED;AACF;AACA;AACA;AACA;AACA;EACEC,QAAQA,CAACC,MAAc,EAAmB;IACxC,OAAOP,WAAW,CAAC,CAAC,CAACM,QAAQ,CAACC,MAAM,CAAC;EACvC,CAAC;EAED;AACF;AACA;AACA;AACA;AACA;EACEC,MAAMA,CAACD,MAAc,EAAEE,OAAgC,EAAmB;IACxE,OAAOT,WAAW,CAAC,CAAC,CAACQ,MAAM,CAACD,MAAM,EAAEE,OAAO,CAAC;EAC9C,CAAC;EAED;AACF;AACA;EACEC,IAAIA,CAAA,EAAS;IACXV,WAAW,CAAC,CAAC,CAACU,IAAI,CAAC,CAAC;EACtB,CAAC;EAED;AACF;AACA;AACA;EACEC,MAAMA,CAAA,EAAS;IACbX,WAAW,CAAC,CAAC,CAACW,MAAM,CAAC,CAAC;EACxB,CAAC;EAED;AACF;AACA;AACA;EACEC,sBAAsBA,CAAA,EAAoB;IACxC,OAAOZ,WAAW,CAAC,CAAC,CAACY,sBAAsB,CAAC,CAAC;EAC/C,CAAC;EAED;AACF;AACA;AACA;EACEC,UAAUA,CAAA,EAAc;IACtB,OAAOb,WAAW,CAAC,CAAC,CAACa,UAAU,CAAC,CAAC;EACnC,CAAC;EAED;AACF;AACA;EACEC,YAAYA,CAAA,EAAS;IACnBd,WAAW,CAAC,CAAC,CAACc,YAAY,CAAC,CAAC;EAC9B,CAAC;EAED;EACA,IAAIC,QAAQA,CAAA,EAAY;IACtB,OAAOf,WAAW,CAAC,CAAC,CAACe,QAAQ;EAC/B,CAAC;EAED;EACA,IAAIC,YAAYA,CAAA,EAAY;IAC1B,OAAOhB,WAAW,CAAC,CAAC,CAACgB,YAAY;EACnC,CAAC;EAED;EACA,IAAIZ,OAAOA,CAAA,EAAW;IACpB,OAAOJ,WAAW,CAAC,CAAC,CAACI,OAAO;EAC9B,CAAC;EAED;EACA,IAAIa,KAAKA,CAAA,EAAY;IACnB,OAAOjB,WAAW,CAAC,CAAC,CAACiB,KAAK;EAC5B,CAAC;EAED,IAAIA,KAAKA,CAACC,KAAc,EAAE;IACxBlB,WAAW,CAAC,CAAC,CAACiB,KAAK,GAAGC,KAAK;EAC7B,CAAC;EAED;AACF;AACA;AA
CA;AACA;EACE,IAAIC,YAAYA,CAAA,EAAW;IACzB,OAAOnB,WAAW,CAAC,CAAC,CAACmB,YAAY;EACnC,CAAC;EAED,IAAIA,YAAYA,CAACD,KAAa,EAAE;IAC9BlB,WAAW,CAAC,CAAC,CAACmB,YAAY,GAAGD,KAAK;EACpC;AACF,CAAC","ignoreList":[]}
|