@inferrlm/react-native-mlx 0.2.0-inferrlm.2 → 0.4.0
- package/MLXReactNative.podspec +9 -3
- package/ios/Sources/AudioCaptureManager.swift +110 -0
- package/ios/Sources/HybridLLM.swift +562 -74
- package/ios/Sources/HybridSTT.swift +202 -0
- package/ios/Sources/HybridTTS.swift +145 -0
- package/ios/Sources/JSONHelpers.swift +9 -0
- package/ios/Sources/ModelDownloader.swift +26 -12
- package/ios/Sources/StreamEventEmitter.swift +132 -0
- package/ios/Sources/ThinkingStateMachine.swift +206 -0
- package/nitrogen/generated/ios/MLXReactNative+autolinking.rb +1 -1
- package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Bridge.cpp +76 -1
- package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Bridge.hpp +338 -1
- package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Umbrella.hpp +28 -1
- package/nitrogen/generated/ios/MLXReactNativeAutolinking.mm +17 -1
- package/nitrogen/generated/ios/MLXReactNativeAutolinking.swift +31 -1
- package/nitrogen/generated/ios/c++/HybridLLMSpecSwift.cpp +1 -1
- package/nitrogen/generated/ios/c++/HybridLLMSpecSwift.hpp +18 -3
- package/nitrogen/generated/ios/c++/HybridModelManagerSpecSwift.cpp +1 -1
- package/nitrogen/generated/ios/c++/HybridModelManagerSpecSwift.hpp +1 -1
- package/nitrogen/generated/ios/c++/HybridSTTSpecSwift.cpp +11 -0
- package/nitrogen/generated/ios/c++/HybridSTTSpecSwift.hpp +149 -0
- package/nitrogen/generated/ios/c++/HybridTTSSpecSwift.cpp +11 -0
- package/nitrogen/generated/ios/c++/HybridTTSSpecSwift.hpp +128 -0
- package/nitrogen/generated/ios/swift/Func_std__shared_ptr_Promise_std__shared_ptr_Promise_std__shared_ptr_AnyMap______std__shared_ptr_AnyMap_.swift +62 -0
- package/nitrogen/generated/ios/swift/Func_void.swift +1 -1
- package/nitrogen/generated/ios/swift/Func_void_bool.swift +1 -1
- package/nitrogen/generated/ios/swift/Func_void_double.swift +1 -1
- package/nitrogen/generated/ios/swift/Func_void_std__exception_ptr.swift +1 -1
- package/nitrogen/generated/ios/swift/Func_void_std__shared_ptr_AnyMap_.swift +47 -0
- package/nitrogen/generated/ios/swift/Func_void_std__shared_ptr_ArrayBuffer_.swift +47 -0
- package/nitrogen/generated/ios/swift/Func_void_std__shared_ptr_Promise_std__shared_ptr_AnyMap___.swift +67 -0
- package/nitrogen/generated/ios/swift/Func_void_std__string.swift +1 -1
- package/nitrogen/generated/ios/swift/Func_void_std__string_std__string.swift +47 -0
- package/nitrogen/generated/ios/swift/Func_void_std__vector_std__string_.swift +1 -1
- package/nitrogen/generated/ios/swift/GenerationStats.swift +14 -3
- package/nitrogen/generated/ios/swift/HybridLLMSpec.swift +3 -2
- package/nitrogen/generated/ios/swift/HybridLLMSpec_cxx.swift +38 -2
- package/nitrogen/generated/ios/swift/HybridModelManagerSpec.swift +1 -1
- package/nitrogen/generated/ios/swift/HybridModelManagerSpec_cxx.swift +1 -1
- package/nitrogen/generated/ios/swift/HybridSTTSpec.swift +66 -0
- package/nitrogen/generated/ios/swift/HybridSTTSpec_cxx.swift +286 -0
- package/nitrogen/generated/ios/swift/HybridTTSSpec.swift +63 -0
- package/nitrogen/generated/ios/swift/HybridTTSSpec_cxx.swift +229 -0
- package/nitrogen/generated/ios/swift/LLMLoadOptions.swift +44 -2
- package/nitrogen/generated/ios/swift/LLMMessage.swift +1 -1
- package/nitrogen/generated/ios/swift/STTLoadOptions.swift +66 -0
- package/nitrogen/generated/ios/swift/TTSGenerateOptions.swift +78 -0
- package/nitrogen/generated/ios/swift/TTSLoadOptions.swift +66 -0
- package/nitrogen/generated/ios/swift/ToolDefinition.swift +113 -0
- package/nitrogen/generated/ios/swift/ToolParameter.swift +69 -0
- package/nitrogen/generated/shared/c++/GenerationStats.hpp +7 -3
- package/nitrogen/generated/shared/c++/HybridLLMSpec.cpp +2 -1
- package/nitrogen/generated/shared/c++/HybridLLMSpec.hpp +3 -2
- package/nitrogen/generated/shared/c++/HybridModelManagerSpec.cpp +1 -1
- package/nitrogen/generated/shared/c++/HybridModelManagerSpec.hpp +1 -1
- package/nitrogen/generated/shared/c++/HybridSTTSpec.cpp +32 -0
- package/nitrogen/generated/shared/c++/HybridSTTSpec.hpp +78 -0
- package/nitrogen/generated/shared/c++/HybridTTSSpec.cpp +29 -0
- package/nitrogen/generated/shared/c++/HybridTTSSpec.hpp +78 -0
- package/nitrogen/generated/shared/c++/LLMLoadOptions.hpp +10 -3
- package/nitrogen/generated/shared/c++/LLMMessage.hpp +1 -1
- package/nitrogen/generated/shared/c++/STTLoadOptions.hpp +76 -0
- package/nitrogen/generated/shared/c++/TTSGenerateOptions.hpp +80 -0
- package/nitrogen/generated/shared/c++/TTSLoadOptions.hpp +76 -0
- package/nitrogen/generated/shared/c++/ToolDefinition.hpp +93 -0
- package/nitrogen/generated/shared/c++/ToolParameter.hpp +87 -0
- package/package.json +13 -8
- package/src/index.ts +48 -4
- package/src/llm.ts +90 -5
- package/src/models.ts +347 -0
- package/src/specs/LLM.nitro.ts +111 -7
- package/src/specs/STT.nitro.ts +35 -0
- package/src/specs/TTS.nitro.ts +30 -0
- package/src/stt.ts +67 -0
- package/src/tool-utils.ts +74 -0
- package/src/tts.ts +60 -0
- package/lib/module/index.js +0 -6
- package/lib/module/index.js.map +0 -1
- package/lib/module/llm.js +0 -125
- package/lib/module/llm.js.map +0 -1
- package/lib/module/modelManager.js +0 -79
- package/lib/module/modelManager.js.map +0 -1
- package/lib/module/models.js +0 -41
- package/lib/module/models.js.map +0 -1
- package/lib/module/package.json +0 -1
- package/lib/module/specs/LLM.nitro.js +0 -4
- package/lib/module/specs/LLM.nitro.js.map +0 -1
- package/lib/module/specs/ModelManager.nitro.js +0 -4
- package/lib/module/specs/ModelManager.nitro.js.map +0 -1
- package/lib/typescript/package.json +0 -1
- package/lib/typescript/src/index.d.ts +0 -6
- package/lib/typescript/src/index.d.ts.map +0 -1
- package/lib/typescript/src/llm.d.ts +0 -87
- package/lib/typescript/src/llm.d.ts.map +0 -1
- package/lib/typescript/src/modelManager.d.ts +0 -53
- package/lib/typescript/src/modelManager.d.ts.map +0 -1
- package/lib/typescript/src/models.d.ts +0 -29
- package/lib/typescript/src/models.d.ts.map +0 -1
- package/lib/typescript/src/specs/LLM.nitro.d.ts +0 -88
- package/lib/typescript/src/specs/LLM.nitro.d.ts.map +0 -1
- package/lib/typescript/src/specs/ModelManager.nitro.d.ts +0 -41
- package/lib/typescript/src/specs/ModelManager.nitro.d.ts.map +0 -1
package/ios/Sources/HybridLLM.swift

```diff
@@ -3,20 +3,26 @@ import NitroModules
 internal import MLX
 internal import MLXLLM
 internal import MLXLMCommon
+internal import Tokenizers
 
 class HybridLLM: HybridLLMSpec {
   private var session: ChatSession?
   private var currentTask: Task<String, Error>?
-  private var container:
+  private var container: ModelContainer?
   private var lastStats: GenerationStats = GenerationStats(
     tokenCount: 0,
     tokensPerSecond: 0,
     timeToFirstToken: 0,
-    totalTime: 0
+    totalTime: 0,
+    toolExecutionTime: 0
   )
   private var modelFactory: ModelFactory = LLMModelFactory.shared
   private var manageHistory: Bool = false
   private var messageHistory: [LLMMessage] = []
+  private var loadTask: Task<Void, Error>?
+
+  private var tools: [ToolDefinition] = []
+  private var toolSchemas: [ToolSpec] = []
 
   var isLoaded: Bool { session != nil }
   var isGenerating: Bool { currentTask != nil }
```
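The new `tools` and `toolSchemas` properties lean on two nitrogen-generated types added in this release (`ToolDefinition.swift`, `ToolParameter.swift`). The generated code is not shown in this excerpt; the sketch below infers the relevant shape purely from how `HybridLLM` uses the fields, including the doubly-nested promise implied by the new `Func_..._Promise_..._Promise_..._AnyMap_` bridge file. `ToolSpec` itself is MLXLMCommon's loosely typed schema dictionary.

```swift
// Inferred sketch only: the real structs are generated by nitrogen and
// live under nitrogen/generated/ios/swift/. Field names are taken from
// their usage in HybridLLM.swift below.
import NitroModules

struct ToolParameter {
  let name: String
  let type: String          // JSON Schema type: "string", "number", ...
  let description: String
  let required: Bool
}

struct ToolDefinition {
  let name: String
  let description: String
  let parameters: [ToolParameter]
  // The JS handler crosses the Nitro bridge as a promise of a promise;
  // executeToolCall (further down) awaits it twice.
  let handler: (AnyMap) -> Promise<Promise<AnyMap>>
}
```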
```diff
@@ -52,7 +58,7 @@ class HybridLLM: HybridLLMSpec {
   }
 
   private func getGPUMemoryUsage() -> String {
-    let snapshot =
+    let snapshot = Memory.snapshot()
     let allocatedMB = Float(snapshot.activeMemory) / 1024.0 / 1024.0
     let cacheMB = Float(snapshot.cacheMemory) / 1024.0 / 1024.0
     let peakMB = Float(snapshot.peakMemory) / 1024.0 / 1024.0
```
```diff
@@ -60,51 +66,95 @@ class HybridLLM: HybridLLMSpec {
         allocatedMB, cacheMB, peakMB)
   }
 
+  private func buildToolSchema(from tool: ToolDefinition) -> ToolSpec {
+    var properties: [String: [String: Any]] = [:]
+    var required: [String] = []
+
+    for param in tool.parameters {
+      properties[param.name] = [
+        "type": param.type,
+        "description": param.description
+      ]
+      if param.required {
+        required.append(param.name)
+      }
+    }
+
+    return [
+      "type": "function",
+      "function": [
+        "name": tool.name,
+        "description": tool.description,
+        "parameters": [
+          "type": "object",
+          "properties": properties,
+          "required": required
+        ]
+      ]
+    ] as ToolSpec
+  }
+
```
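`buildToolSchema` emits the OpenAI-style function-calling layout that the chat templates consume. For an illustrative `get_weather` tool (not part of this package), the resulting `ToolSpec` dictionary would be:

```swift
// Illustrative output of buildToolSchema for a hypothetical tool with one
// required "city" parameter:
let schema: [String: Any] = [
  "type": "function",
  "function": [
    "name": "get_weather",
    "description": "Look up the current weather",
    "parameters": [
      "type": "object",
      "properties": [
        "city": ["type": "string", "description": "City name"]
      ],
      "required": ["city"]
    ]
  ]
]
```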
```diff
   func load(modelId: String, options: LLMLoadOptions?) throws -> Promise<Void> {
-
-    MLX.GPU.set(cacheLimit: 2000000)
+    self.loadTask?.cancel()
 
-
-
-
-    self.container = nil
-    MLX.GPU.clearCache()
+    return Promise.async { [self] in
+      let task = Task { @MainActor in
+        Memory.cacheLimit = 2000000
 
-
-
-
+        self.currentTask?.cancel()
+        self.currentTask = nil
+        self.session = nil
+        self.container = nil
+        self.tools = []
+        self.toolSchemas = []
+        Memory.clearCache()
+
+        let memoryAfterCleanup = self.getMemoryUsage()
+        let gpuAfterCleanup = self.getGPUMemoryUsage()
+        log("After cleanup - Host: \(memoryAfterCleanup), GPU: \(gpuAfterCleanup)")
+
+        let modelDir = await ModelDownloader.shared.getModelDirectory(modelId: modelId)
+        log("Loading from directory: \(modelDir.path)")
+
+        let config = ModelConfiguration(directory: modelDir)
+        let loadedContainer = try await self.modelFactory.loadContainer(
+          configuration: config
+        ) { progress in
+          options?.onProgress?(progress.fractionCompleted)
+        }
 
-
-    log("Loading from directory: \(modelDir.path)")
+        try Task.checkCancellation()
 
-
-
-
-    ) { progress in
-      options?.onProgress?(progress.fractionCompleted)
-    }
+        let memoryAfterContainer = self.getMemoryUsage()
+        let gpuAfterContainer = self.getGPUMemoryUsage()
+        log("Model loaded - Host: \(memoryAfterContainer), GPU: \(gpuAfterContainer)")
 
-
-
-
+        if let jsTools = options?.tools {
+          self.tools = jsTools
+          self.toolSchemas = jsTools.map { self.buildToolSchema(from: $0) }
+          log("Loaded \(self.tools.count) tools: \(self.tools.map { $0.name })")
+        }
 
-
-
-
-
-
-    }
+        let additionalContextDict: [String: Any]? = if let messages = options?.additionalContext {
+          ["messages": messages.map { ["role": $0.role, "content": $0.content] }]
+        } else {
+          nil
+        }
 
-
-
-
+        self.container = loadedContainer
+        self.session = ChatSession(loadedContainer, instructions: self.systemPrompt, additionalContext: additionalContextDict)
+        self.modelId = modelId
 
-
-
+        self.manageHistory = options?.manageHistory ?? false
+        self.messageHistory = options?.additionalContext ?? []
 
-
-
+        if self.manageHistory {
+          log("History management enabled with \(self.messageHistory.count) initial messages")
+        }
       }
+
+      self.loadTask = task
+      try await task.value
     }
   }
 
```
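The reload path is now cancellation-safe: a second `load` call cancels the in-flight one via `loadTask`, and `Task.checkCancellation()` aborts a superseded load before it installs its container. Reduced to its core, the pattern is:

```swift
// A minimal sketch of the serialized-reload pattern above; `Loader` and
// `work` are illustrative names, not part of the package.
final class Loader {
  private var loadTask: Task<Void, Error>?

  func load(_ work: @escaping () async throws -> Void) async throws {
    loadTask?.cancel()                // abandon any in-flight load
    let task = Task { @MainActor in
      try await work()
      try Task.checkCancellation()    // bail if a newer load superseded us
    }
    loadTask = task
    try await task.value              // rethrows CancellationError if cancelled
  }
}
```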
```diff
@@ -126,25 +176,26 @@ class HybridLLM: HybridLLMSpec {
       }
 
       self.currentTask = task
+      defer { self.currentTask = nil }
 
-
-      let result = try await task.value
-      self.currentTask = nil
+      let result = try await task.value
 
-
-
-      }
-
-      return result
-      } catch {
-        self.currentTask = nil
-        throw error
+      if self.manageHistory {
+        self.messageHistory.append(LLMMessage(role: "assistant", content: result))
       }
+
+      return result
     }
   }
 
-
-
+  private let maxToolCallDepth = 10
+
+  func stream(
+    prompt: String,
+    onToken: @escaping (String) -> Void,
+    onToolCall: ((String, String) -> Void)?
+  ) throws -> Promise<String> {
+    guard let container else {
       throw LLMError.notLoaded
     }
 
```
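The same cleanup refactor recurs in this method and again in `stream` and `streamWithEvents` below: the old code reset `currentTask` on the success path and once more in a `catch`, while the new code hangs it on a single `defer`, which also fires on cancellation. Schematically:

```swift
// Before: cleanup duplicated across exit paths
do {
  let result = try await task.value
  self.currentTask = nil
  return result
} catch {
  self.currentTask = nil
  throw error
}

// After: one defer covers return, throw, and cancellation alike
self.currentTask = task
defer { self.currentTask = nil }
return try await task.value
```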
```diff
@@ -154,22 +205,24 @@ class HybridLLM: HybridLLMSpec {
     }
 
       let task = Task<String, Error> {
-        var result = ""
-        var tokenCount = 0
         let startTime = Date()
         var firstTokenTime: Date?
+        var tokenCount = 0
 
-
-
-
-
-
-
-
-
-
-
-
+        let result = try await self.performGeneration(
+          container: container,
+          prompt: prompt,
+          toolResults: nil,
+          depth: 0,
+          onToken: { token in
+            if firstTokenTime == nil {
+              firstTokenTime = Date()
+            }
+            tokenCount += 1
+            onToken(token)
+          },
+          onToolCall: onToolCall ?? { _, _ in }
+        )
 
         let endTime = Date()
         let totalTime = endTime.timeIntervalSince(startTime) * 1000
```
```diff
@@ -180,7 +233,8 @@ class HybridLLM: HybridLLMSpec {
           tokenCount: Double(tokenCount),
           tokensPerSecond: tokensPerSecond,
           timeToFirstToken: timeToFirstToken,
-          totalTime: totalTime
+          totalTime: totalTime,
+          toolExecutionTime: 0
         )
 
         log("Stream complete - \(tokenCount) tokens, \(String(format: "%.1f", tokensPerSecond)) tokens/s")
```
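All the `GenerationStats` timing fields are in milliseconds, so the throughput math divides by 1000 first. Note that this legacy `stream` path hard-codes `toolExecutionTime: 0`, whereas `streamWithEvents` in the next hunk measures it. A worked example with the formula used there:

```swift
// Worked example of the stats arithmetic (times in milliseconds):
let tokenCount = 120.0
let generationTime = 2_400.0   // ms spent generating
let timeToFirstToken = 350.0   // ms until the first chunk

let tokensPerSecond = generationTime > 0
  ? tokenCount / (generationTime / 1000)   // 120 / 2.4 = 50 tok/s
  : 0
```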
```diff
@@ -188,21 +242,450 @@ class HybridLLM: HybridLLMSpec {
       }
 
       self.currentTask = task
+      defer { self.currentTask = nil }
 
-
-      let result = try await task.value
-      self.currentTask = nil
+      let result = try await task.value
 
-
-
-
+      if self.manageHistory {
+        self.messageHistory.append(LLMMessage(role: "assistant", content: result))
+      }
+
+      return result
+    }
+  }
+
+  func streamWithEvents(
+    prompt: String,
+    onEvent: @escaping (String) -> Void
+  ) throws -> Promise<String> {
+    guard let container else {
+      throw LLMError.notLoaded
+    }
 
+    return Promise.async { [self] in
+      if self.manageHistory {
+        self.messageHistory.append(LLMMessage(role: "user", content: prompt))
+      }
+
+      let task = Task<String, Error> {
+        let startTime = Date()
+        var firstTokenTime: Date?
+        var outputTokenCount = 0
+        var mlxTokenCount = 0
+        var mlxGenerationTime: Double = 0
+        var toolExecutionTime: Double = 0
+        let emitter = StreamEventEmitter(callback: onEvent)
+
+        emitter.emitGenerationStart()
+
+        let result = try await self.performGenerationWithEvents(
+          container: container,
+          prompt: prompt,
+          toolResults: nil,
+          depth: 0,
+          emitter: emitter,
+          onTokenProcessed: {
+            if firstTokenTime == nil {
+              firstTokenTime = Date()
+            }
+            outputTokenCount += 1
+          },
+          onGenerationInfo: { tokens, time in
+            mlxTokenCount += tokens
+            mlxGenerationTime += time
+          },
+          toolExecutionTime: &toolExecutionTime
+        )
+
+        let endTime = Date()
+        let totalTime = endTime.timeIntervalSince(startTime) * 1000
+        let timeToFirstToken = (firstTokenTime ?? endTime).timeIntervalSince(startTime) * 1000
+        let tokensPerSecond = mlxGenerationTime > 0 ? Double(mlxTokenCount) / (mlxGenerationTime / 1000) : 0
+
+        let stats = GenerationStats(
+          tokenCount: Double(mlxTokenCount),
+          tokensPerSecond: tokensPerSecond,
+          timeToFirstToken: timeToFirstToken,
+          totalTime: totalTime,
+          toolExecutionTime: toolExecutionTime
+        )
+
+        self.lastStats = stats
+        emitter.emitGenerationEnd(content: result, stats: stats)
+
+        log("StreamWithEvents complete - \(mlxTokenCount) tokens, \(String(format: "%.1f", tokensPerSecond)) tokens/s (tool execution: \(String(format: "%.0f", toolExecutionTime))ms)")
         return result
-    }
-
-
+      }
+
+      self.currentTask = task
+      defer { self.currentTask = nil }
+
+      let result = try await task.value
+
+      if self.manageHistory {
+        self.messageHistory.append(LLMMessage(role: "assistant", content: result))
+      }
+
+      return result
+    }
+  }
+
```
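`streamWithEvents` funnels everything through a single string callback instead of the token/tool-call callback pair. The wire format lives in the new `StreamEventEmitter.swift`, which this excerpt does not show; assuming it emits one JSON object per call with a `type` discriminator, a consumer might dispatch like this (event and field names here are guesses, so check the emitter before relying on them):

```swift
import Foundation

// Hypothetical consumer of the onEvent callback. The actual event names
// and payload keys are defined by StreamEventEmitter and may differ.
func handleEvent(_ json: String) {
  guard
    let data = json.data(using: .utf8),
    let event = (try? JSONSerialization.jsonObject(with: data)) as? [String: Any],
    let type = event["type"] as? String
  else { return }

  switch type {
  case "token":
    print(event["token"] as? String ?? "", terminator: "")
  case "thinking_start":
    print("[thinking...]")
  case "tool_call_start":
    print("[calling \(event["name"] as? String ?? "?")]")
  case "generation_end":
    print("\n[done]")
  default:
    break
  }
}
```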
```diff
+  private func buildChatMessages(
+    prompt: String,
+    toolResults: [String]?,
+    depth: Int
+  ) -> [Chat.Message] {
+    var chat: [Chat.Message] = []
+
+    if !self.systemPrompt.isEmpty {
+      chat.append(.system(self.systemPrompt))
+    }
+
+    for msg in self.messageHistory {
+      switch msg.role {
+      case "user": chat.append(.user(msg.content))
+      case "assistant": chat.append(.assistant(msg.content))
+      case "system": chat.append(.system(msg.content))
+      case "tool": chat.append(.tool(msg.content))
+      default: break
+      }
+    }
+
+    if depth == 0 {
+      chat.append(.user(prompt))
+    }
+
+    if let toolResults {
+      for result in toolResults {
+        chat.append(.tool(result))
       }
     }
+
+    return chat
+  }
+
```
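`buildChatMessages` is the single place where history, system prompt, the fresh prompt, and tool results get assembled, and the `depth` parameter keeps recursive tool-call turns from re-appending the user prompt. For a first pass (depth 0) with history management on, the assembled array looks like this (contents invented for illustration):

```swift
import MLXLMCommon

// Illustrative result of buildChatMessages at depth 0:
let chat: [Chat.Message] = [
  .system("You are a helpful assistant."),   // systemPrompt, if non-empty
  .user("What's the weather in Paris?"),     // prior messageHistory
  .assistant("Let me check."),               // prior messageHistory
  .user("And in Berlin?"),                   // the new prompt (depth == 0 only)
]
// On a tool-call continuation (depth > 0) the prompt is not re-added;
// each tool-result JSON string is appended as .tool(...) instead.
```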
```diff
+  private func executeToolCall(
+    tool: ToolDefinition,
+    argsDict: [String: Any]
+  ) async throws -> String {
+    let argsAnyMap = self.dictionaryToAnyMap(argsDict)
+    let outerPromise = tool.handler(argsAnyMap)
+    let innerPromise = try await outerPromise.await()
+    let resultAnyMap = try await innerPromise.await()
+    let resultDict = self.anyMapToDictionary(resultAnyMap)
+    return dictionaryToJson(resultDict)
+  }
+
```
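The double `await` in `executeToolCall` is not an accident: the nitrogen bridge types the JS handler as `(AnyMap) -> Promise<Promise<AnyMap>>` (see the new `Func_..._Promise_..._Promise_..._AnyMap_` file in this release), so the outer promise settles when the JS function has handed back its own promise, and the inner one settles when that JS promise resolves. Annotated:

```swift
// Why executeToolCall awaits twice (annotation of the code above):
let outer = tool.handler(argsAnyMap)   // invokes the JS function over the bridge
let inner = try await outer.await()    // resolves once JS returns its Promise
let result = try await inner.await()   // resolves once that JS Promise settles
```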
```diff
+  private func performGenerationWithEvents(
+    container: ModelContainer,
+    prompt: String,
+    toolResults: [String]?,
+    depth: Int,
+    emitter: StreamEventEmitter,
+    onTokenProcessed: @escaping () -> Void,
+    onGenerationInfo: @escaping (Int, Double) -> Void,
+    toolExecutionTime: inout Double
+  ) async throws -> String {
+    if depth >= maxToolCallDepth {
+      log("Max tool call depth reached (\(maxToolCallDepth))")
+      return ""
+    }
+
+    var output = ""
+    var thinkingMachine = ThinkingStateMachine()
+    var pendingToolCalls: [(id: String, tool: ToolDefinition, args: [String: Any], argsJson: String)] = []
+
+    let chat = buildChatMessages(prompt: prompt, toolResults: toolResults, depth: depth)
+    let userInput = UserInput(
+      chat: chat,
+      tools: !self.toolSchemas.isEmpty ? self.toolSchemas : nil
+    )
+
+    let lmInput = try await container.prepare(input: userInput)
+
+    let stream = try await container.perform { context in
+      let parameters = GenerateParameters(maxTokens: 2048, temperature: 0.7)
+      return try MLXLMCommon.generate(
+        input: lmInput,
+        parameters: parameters,
+        context: context
+      )
+    }
+
+    for await generation in stream {
+      if Task.isCancelled { break }
+
+      switch generation {
+      case .chunk(let text):
+        let outputs = thinkingMachine.process(token: text)
+
+        for machineOutput in outputs {
+          switch machineOutput {
+          case .token(let token):
+            output += token
+            emitter.emitToken(token)
+            onTokenProcessed()
+
+          case .thinkingStart:
+            emitter.emitThinkingStart()
+
+          case .thinkingChunk(let chunk):
+            emitter.emitThinkingChunk(chunk)
+
+          case .thinkingEnd(let content):
+            emitter.emitThinkingEnd(content)
+          }
+        }
+
+      case .toolCall(let toolCall):
+        log("Tool call detected: \(toolCall.function.name)")
+
+        guard let tool = self.tools.first(where: { $0.name == toolCall.function.name }) else {
+          log("Unknown tool: \(toolCall.function.name)")
+          continue
+        }
+
+        let toolCallId = UUID().uuidString
+        let argsDict = self.convertToolCallArguments(toolCall.function.arguments)
+        let argsJson = dictionaryToJson(argsDict)
+
+        emitter.emitToolCallStart(id: toolCallId, name: toolCall.function.name, arguments: argsJson)
+        pendingToolCalls.append((id: toolCallId, tool: tool, args: argsDict, argsJson: argsJson))
+
+      case .info(let info):
+        log("Generation info: \(info.generationTokenCount) tokens, \(String(format: "%.1f", info.tokensPerSecond)) tokens/s")
+        let generationTime = info.tokensPerSecond > 0 ? Double(info.generationTokenCount) / info.tokensPerSecond * 1000 : 0
+        onGenerationInfo(info.generationTokenCount, generationTime)
+      }
+    }
+
+    let flushOutputs = thinkingMachine.flush()
+    for machineOutput in flushOutputs {
+      switch machineOutput {
+      case .token(let token):
+        output += token
+        emitter.emitToken(token)
+        onTokenProcessed()
+      case .thinkingStart:
+        emitter.emitThinkingStart()
+      case .thinkingChunk(let chunk):
+        emitter.emitThinkingChunk(chunk)
+      case .thinkingEnd(let content):
+        emitter.emitThinkingEnd(content)
+      }
+    }
+
+    if !pendingToolCalls.isEmpty {
+      log("Executing \(pendingToolCalls.count) tool call(s)")
+
+      let toolStartTime = Date()
+
+      for call in pendingToolCalls {
+        emitter.emitToolCallExecuting(id: call.id)
+      }
+
+      let allToolResults: [String] = await withTaskGroup(of: (Int, String).self) { group in
+        for (index, call) in pendingToolCalls.enumerated() {
+          group.addTask { [self] in
+            do {
+              let resultJson = try await self.executeToolCall(tool: call.tool, argsDict: call.args)
+              self.log("Tool result for \(call.tool.name): \(resultJson.prefix(100))...")
+              emitter.emitToolCallCompleted(id: call.id, result: resultJson)
+              return (index, resultJson)
+            } catch {
+              self.log("Tool execution error for \(call.tool.name): \(error)")
+              emitter.emitToolCallFailed(id: call.id, error: error.localizedDescription)
+              return (index, "{\"error\": \"Tool execution failed\"}")
+            }
+          }
+        }
+
+        var results = Array(repeating: "", count: pendingToolCalls.count)
+        for await (index, result) in group {
+          results[index] = result
+        }
+        return results
+      }
+
+      toolExecutionTime += Date().timeIntervalSince(toolStartTime) * 1000
+
+      if !output.isEmpty {
+        self.messageHistory.append(LLMMessage(role: "assistant", content: output))
+      }
+
+      for result in allToolResults {
+        self.messageHistory.append(LLMMessage(role: "tool", content: result))
+      }
+
+      let continuation = try await self.performGenerationWithEvents(
+        container: container,
+        prompt: prompt,
+        toolResults: allToolResults,
+        depth: depth + 1,
+        emitter: emitter,
+        onTokenProcessed: onTokenProcessed,
+        onGenerationInfo: onGenerationInfo,
+        toolExecutionTime: &toolExecutionTime
+      )
+
+      return output + continuation
+    }
+
+    return output
+  }
+
```
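Tool calls execute concurrently, but their results must line up with `pendingToolCalls` when they are replayed into the chat, so each child task is tagged with its index and slotted back on completion. The pattern in isolation:

```swift
// The index-tagging pattern used above, extracted into a standalone helper
// (illustrative; not part of the package). Child tasks finish in arbitrary
// order, so each returns (index, value) and results are placed by position.
func runAllOrdered(_ jobs: [() async -> String]) async -> [String] {
  await withTaskGroup(of: (Int, String).self) { group in
    for (index, job) in jobs.enumerated() {
      group.addTask { (index, await job()) }
    }
    var results = Array(repeating: "", count: jobs.count)
    for await (index, value) in group {
      results[index] = value   // completion order no longer matters
    }
    return results
  }
}
```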
```diff
+  private func performGeneration(
+    container: ModelContainer,
+    prompt: String,
+    toolResults: [String]?,
+    depth: Int,
+    onToken: @escaping (String) -> Void,
+    onToolCall: @escaping (String, String) -> Void
+  ) async throws -> String {
+    if depth >= maxToolCallDepth {
+      log("Max tool call depth reached (\(maxToolCallDepth))")
+      return ""
+    }
+
+    var output = ""
+    var pendingToolCalls: [(tool: ToolDefinition, args: [String: Any], argsJson: String)] = []
+
+    let chat = buildChatMessages(prompt: prompt, toolResults: toolResults, depth: depth)
+    let userInput = UserInput(
+      chat: chat,
+      tools: !self.toolSchemas.isEmpty ? self.toolSchemas : nil
+    )
+
+    let lmInput = try await container.prepare(input: userInput)
+
+    let stream = try await container.perform { context in
+      let parameters = GenerateParameters(maxTokens: 2048, temperature: 0.7)
+      return try MLXLMCommon.generate(
+        input: lmInput,
+        parameters: parameters,
+        context: context
+      )
+    }
+
+    for await generation in stream {
+      if Task.isCancelled { break }
+
+      switch generation {
+      case .chunk(let text):
+        output += text
+        onToken(text)
+
+      case .toolCall(let toolCall):
+        log("Tool call detected: \(toolCall.function.name)")
+
+        guard let tool = self.tools.first(where: { $0.name == toolCall.function.name }) else {
+          log("Unknown tool: \(toolCall.function.name)")
+          continue
+        }
+
+        let argsDict = self.convertToolCallArguments(toolCall.function.arguments)
+        let argsJson = dictionaryToJson(argsDict)
+
+        pendingToolCalls.append((tool: tool, args: argsDict, argsJson: argsJson))
+        onToolCall(toolCall.function.name, argsJson)
+
+      case .info(let info):
+        log("Generation info: \(info.generationTokenCount) tokens, \(String(format: "%.1f", info.tokensPerSecond)) tokens/s")
+      }
+    }
+
+    if !pendingToolCalls.isEmpty {
+      log("Executing \(pendingToolCalls.count) tool call(s)")
+
+      let allToolResults: [String] = await withTaskGroup(of: (Int, String).self) { group in
+        for (index, call) in pendingToolCalls.enumerated() {
+          group.addTask { [self] in
+            do {
+              let resultJson = try await self.executeToolCall(tool: call.tool, argsDict: call.args)
+              self.log("Tool result for \(call.tool.name): \(resultJson.prefix(100))...")
+              return (index, resultJson)
+            } catch {
+              self.log("Tool execution error for \(call.tool.name): \(error)")
+              return (index, "{\"error\": \"Tool execution failed\"}")
+            }
+          }
+        }
+
+        var results = Array(repeating: "", count: pendingToolCalls.count)
+        for await (index, result) in group {
+          results[index] = result
+        }
+        return results
+      }
+
+      if !output.isEmpty {
+        self.messageHistory.append(LLMMessage(role: "assistant", content: output))
+      }
+
+      if depth == 0 {
+        self.messageHistory.append(LLMMessage(role: "user", content: prompt))
+      }
+
+      for result in allToolResults {
+        self.messageHistory.append(LLMMessage(role: "tool", content: result))
+      }
+
+      onToken("\u{200B}")
+
+      let continuation = try await self.performGeneration(
+        container: container,
+        prompt: prompt,
+        toolResults: allToolResults,
+        depth: depth + 1,
+        onToken: onToken,
+        onToolCall: onToolCall
+      )
+
+      return output + continuation
+    }
+
+    return output
+  }
+
```
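One quirk worth flagging: before recursing after tool execution, `performGeneration` pushes a zero-width space (`\u{200B}`) through `onToken`, apparently so the JS side gets a callback between the tool round-trip and the continuation tokens. It is invisible but real, so consumers that accumulate raw tokens may want to strip it:

```swift
import Foundation

// Consumer-side cleanup (assumption: model output contains no legitimate
// U+200B, so stripping the sentinel is safe):
func stripSentinels(_ text: String) -> String {
  text.replacingOccurrences(of: "\u{200B}", with: "")
}
```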
```diff
+  private func convertToolCallArguments(_ arguments: [String: JSONValue]) -> [String: Any] {
+    var result: [String: Any] = [:]
+    for (key, value) in arguments {
+      result[key] = value.anyValue
+    }
+    return result
+  }
+
+  private func dictionaryToAnyMap(_ dict: [String: Any]) -> AnyMap {
+    let anyMap = AnyMap()
+    for (key, value) in dict {
+      switch value {
+      case let stringValue as String:
+        anyMap.setString(key: key, value: stringValue)
+      case let doubleValue as Double:
+        anyMap.setDouble(key: key, value: doubleValue)
+      case let intValue as Int:
+        anyMap.setDouble(key: key, value: Double(intValue))
+      case let boolValue as Bool:
+        anyMap.setBoolean(key: key, value: boolValue)
+      default:
+        anyMap.setString(key: key, value: String(describing: value))
+      }
+    }
+    return anyMap
+  }
+
+  private func anyMapToDictionary(_ anyMap: AnyMap) -> [String: Any] {
+    var dict: [String: Any] = [:]
+    for key in anyMap.getAllKeys() {
+      if anyMap.isString(key: key) {
+        dict[key] = anyMap.getString(key: key)
+      } else if anyMap.isDouble(key: key) {
+        dict[key] = anyMap.getDouble(key: key)
+      } else if anyMap.isBool(key: key) {
+        dict[key] = anyMap.getBoolean(key: key)
+      }
+    }
+    return dict
   }
 
   func stop() throws {
```
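The AnyMap bridges are deliberately shallow: `dictionaryToAnyMap` handles strings, numbers, and booleans, stringifies anything else via `String(describing:)`, and `anyMapToDictionary` only reads those three primitive types back. Tool arguments and results therefore round-trip cleanly only when they are flat and primitive-valued:

```swift
// Round-trip behavior implied by the helpers above (illustrative values):
let args: [String: Any] = [
  "city": "Paris",       // String -> setString, survives unchanged
  "days": 3,             // Int    -> setDouble(3.0), comes back as Double
  "units": ["C", "F"],   // Array  -> default branch, String(describing:)
]
// After dictionaryToAnyMap -> anyMapToDictionary you get roughly:
//   ["city": "Paris", "days": 3.0, "units": "[\"C\", \"F\"]"]
// so JS tool handlers should keep argument/result shapes flat and primitive.
```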
```diff
@@ -211,6 +694,9 @@ class HybridLLM: HybridLLMSpec {
   }
 
   func unload() throws {
+    loadTask?.cancel()
+    loadTask = nil
+
     let memoryBefore = getMemoryUsage()
     let gpuBefore = getGPUMemoryUsage()
     log("Before unload - Host: \(memoryBefore), GPU: \(gpuBefore)")
```
```diff
@@ -219,11 +705,13 @@ class HybridLLM: HybridLLMSpec {
     currentTask = nil
     session = nil
     container = nil
+    tools = []
+    toolSchemas = []
     messageHistory = []
     manageHistory = false
     modelId = ""
 
-    MLX.
+    MLX.Memory.clearCache()
 
     let memoryAfter = getMemoryUsage()
     let gpuAfter = getGPUMemoryUsage()
```