react-native-nitro-mlx 0.2.2 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/MLXReactNative.podspec +7 -1
- package/ios/Sources/AudioCaptureManager.swift +110 -0
- package/ios/Sources/HybridLLM.swift +518 -42
- package/ios/Sources/HybridSTT.swift +202 -0
- package/ios/Sources/HybridTTS.swift +145 -0
- package/ios/Sources/JSONHelpers.swift +9 -0
- package/ios/Sources/ModelDownloader.swift +26 -12
- package/ios/Sources/StreamEventEmitter.swift +132 -0
- package/ios/Sources/ThinkingStateMachine.swift +206 -0
- package/lib/module/index.js +3 -0
- package/lib/module/index.js.map +1 -1
- package/lib/module/llm.js +72 -4
- package/lib/module/llm.js.map +1 -1
- package/lib/module/models.js +97 -26
- package/lib/module/models.js.map +1 -1
- package/lib/module/specs/STT.nitro.js +4 -0
- package/lib/module/specs/STT.nitro.js.map +1 -0
- package/lib/module/specs/TTS.nitro.js +4 -0
- package/lib/module/specs/TTS.nitro.js.map +1 -0
- package/lib/module/stt.js +49 -0
- package/lib/module/stt.js.map +1 -0
- package/lib/module/tool-utils.js +56 -0
- package/lib/module/tool-utils.js.map +1 -0
- package/lib/module/tts.js +40 -0
- package/lib/module/tts.js.map +1 -0
- package/lib/typescript/src/index.d.ts +8 -3
- package/lib/typescript/src/index.d.ts.map +1 -1
- package/lib/typescript/src/llm.d.ts +46 -4
- package/lib/typescript/src/llm.d.ts.map +1 -1
- package/lib/typescript/src/models.d.ts +13 -4
- package/lib/typescript/src/models.d.ts.map +1 -1
- package/lib/typescript/src/specs/LLM.nitro.d.ts +79 -7
- package/lib/typescript/src/specs/LLM.nitro.d.ts.map +1 -1
- package/lib/typescript/src/specs/STT.nitro.d.ts +28 -0
- package/lib/typescript/src/specs/STT.nitro.d.ts.map +1 -0
- package/lib/typescript/src/specs/TTS.nitro.d.ts +22 -0
- package/lib/typescript/src/specs/TTS.nitro.d.ts.map +1 -0
- package/lib/typescript/src/stt.d.ts +16 -0
- package/lib/typescript/src/stt.d.ts.map +1 -0
- package/lib/typescript/src/tool-utils.d.ts +13 -0
- package/lib/typescript/src/tool-utils.d.ts.map +1 -0
- package/lib/typescript/src/tts.d.ts +13 -0
- package/lib/typescript/src/tts.d.ts.map +1 -0
- package/nitrogen/generated/ios/MLXReactNative+autolinking.rb +1 -1
- package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Bridge.cpp +76 -1
- package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Bridge.hpp +338 -1
- package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Umbrella.hpp +28 -1
- package/nitrogen/generated/ios/MLXReactNativeAutolinking.mm +17 -1
- package/nitrogen/generated/ios/MLXReactNativeAutolinking.swift +31 -1
- package/nitrogen/generated/ios/c++/HybridLLMSpecSwift.cpp +1 -1
- package/nitrogen/generated/ios/c++/HybridLLMSpecSwift.hpp +18 -3
- package/nitrogen/generated/ios/c++/HybridModelManagerSpecSwift.cpp +1 -1
- package/nitrogen/generated/ios/c++/HybridModelManagerSpecSwift.hpp +1 -1
- package/nitrogen/generated/ios/c++/HybridSTTSpecSwift.cpp +11 -0
- package/nitrogen/generated/ios/c++/HybridSTTSpecSwift.hpp +149 -0
- package/nitrogen/generated/ios/c++/HybridTTSSpecSwift.cpp +11 -0
- package/nitrogen/generated/ios/c++/HybridTTSSpecSwift.hpp +128 -0
- package/nitrogen/generated/ios/swift/Func_std__shared_ptr_Promise_std__shared_ptr_Promise_std__shared_ptr_AnyMap______std__shared_ptr_AnyMap_.swift +62 -0
- package/nitrogen/generated/ios/swift/Func_void.swift +1 -1
- package/nitrogen/generated/ios/swift/Func_void_bool.swift +1 -1
- package/nitrogen/generated/ios/swift/Func_void_double.swift +1 -1
- package/nitrogen/generated/ios/swift/Func_void_std__exception_ptr.swift +1 -1
- package/nitrogen/generated/ios/swift/Func_void_std__shared_ptr_AnyMap_.swift +47 -0
- package/nitrogen/generated/ios/swift/Func_void_std__shared_ptr_ArrayBuffer_.swift +47 -0
- package/nitrogen/generated/ios/swift/Func_void_std__shared_ptr_Promise_std__shared_ptr_AnyMap___.swift +67 -0
- package/nitrogen/generated/ios/swift/Func_void_std__string.swift +1 -1
- package/nitrogen/generated/ios/swift/Func_void_std__string_std__string.swift +47 -0
- package/nitrogen/generated/ios/swift/Func_void_std__vector_std__string_.swift +1 -1
- package/nitrogen/generated/ios/swift/GenerationStats.swift +14 -3
- package/nitrogen/generated/ios/swift/HybridLLMSpec.swift +3 -2
- package/nitrogen/generated/ios/swift/HybridLLMSpec_cxx.swift +38 -2
- package/nitrogen/generated/ios/swift/HybridModelManagerSpec.swift +1 -1
- package/nitrogen/generated/ios/swift/HybridModelManagerSpec_cxx.swift +1 -1
- package/nitrogen/generated/ios/swift/HybridSTTSpec.swift +66 -0
- package/nitrogen/generated/ios/swift/HybridSTTSpec_cxx.swift +286 -0
- package/nitrogen/generated/ios/swift/HybridTTSSpec.swift +63 -0
- package/nitrogen/generated/ios/swift/HybridTTSSpec_cxx.swift +229 -0
- package/nitrogen/generated/ios/swift/LLMLoadOptions.swift +44 -2
- package/nitrogen/generated/ios/swift/LLMMessage.swift +1 -1
- package/nitrogen/generated/ios/swift/STTLoadOptions.swift +66 -0
- package/nitrogen/generated/ios/swift/TTSGenerateOptions.swift +78 -0
- package/nitrogen/generated/ios/swift/TTSLoadOptions.swift +66 -0
- package/nitrogen/generated/ios/swift/ToolDefinition.swift +113 -0
- package/nitrogen/generated/ios/swift/ToolParameter.swift +69 -0
- package/nitrogen/generated/shared/c++/GenerationStats.hpp +7 -3
- package/nitrogen/generated/shared/c++/HybridLLMSpec.cpp +2 -1
- package/nitrogen/generated/shared/c++/HybridLLMSpec.hpp +3 -2
- package/nitrogen/generated/shared/c++/HybridModelManagerSpec.cpp +1 -1
- package/nitrogen/generated/shared/c++/HybridModelManagerSpec.hpp +1 -1
- package/nitrogen/generated/shared/c++/HybridSTTSpec.cpp +32 -0
- package/nitrogen/generated/shared/c++/HybridSTTSpec.hpp +78 -0
- package/nitrogen/generated/shared/c++/HybridTTSSpec.cpp +29 -0
- package/nitrogen/generated/shared/c++/HybridTTSSpec.hpp +78 -0
- package/nitrogen/generated/shared/c++/LLMLoadOptions.hpp +10 -3
- package/nitrogen/generated/shared/c++/LLMMessage.hpp +1 -1
- package/nitrogen/generated/shared/c++/STTLoadOptions.hpp +76 -0
- package/nitrogen/generated/shared/c++/TTSGenerateOptions.hpp +80 -0
- package/nitrogen/generated/shared/c++/TTSLoadOptions.hpp +76 -0
- package/nitrogen/generated/shared/c++/ToolDefinition.hpp +93 -0
- package/nitrogen/generated/shared/c++/ToolParameter.hpp +87 -0
- package/package.json +13 -8
- package/src/index.ts +40 -3
- package/src/llm.ts +90 -5
- package/src/models.ts +81 -1
- package/src/specs/LLM.nitro.ts +111 -7
- package/src/specs/STT.nitro.ts +35 -0
- package/src/specs/TTS.nitro.ts +30 -0
- package/src/stt.ts +67 -0
- package/src/tool-utils.ts +74 -0
- package/src/tts.ts +60 -0
|
@@ -3,22 +3,27 @@ import NitroModules
|
|
|
3
3
|
internal import MLX
|
|
4
4
|
internal import MLXLLM
|
|
5
5
|
internal import MLXLMCommon
|
|
6
|
+
internal import Tokenizers
|
|
6
7
|
|
|
7
8
|
class HybridLLM: HybridLLMSpec {
|
|
8
9
|
private var session: ChatSession?
|
|
9
10
|
private var currentTask: Task<String, Error>?
|
|
10
|
-
private var container:
|
|
11
|
+
private var container: ModelContainer?
|
|
11
12
|
private var lastStats: GenerationStats = GenerationStats(
|
|
12
13
|
tokenCount: 0,
|
|
13
14
|
tokensPerSecond: 0,
|
|
14
15
|
timeToFirstToken: 0,
|
|
15
|
-
totalTime: 0
|
|
16
|
+
totalTime: 0,
|
|
17
|
+
toolExecutionTime: 0
|
|
16
18
|
)
|
|
17
19
|
private var modelFactory: ModelFactory = LLMModelFactory.shared
|
|
18
20
|
private var manageHistory: Bool = false
|
|
19
21
|
private var messageHistory: [LLMMessage] = []
|
|
20
22
|
private var loadTask: Task<Void, Error>?
|
|
21
23
|
|
|
24
|
+
private var tools: [ToolDefinition] = []
|
|
25
|
+
private var toolSchemas: [ToolSpec] = []
|
|
26
|
+
|
|
22
27
|
var isLoaded: Bool { session != nil }
|
|
23
28
|
var isGenerating: Bool { currentTask != nil }
|
|
24
29
|
var modelId: String = ""
|
|
@@ -53,7 +58,7 @@ class HybridLLM: HybridLLMSpec {
|
|
|
53
58
|
}
|
|
54
59
|
|
|
55
60
|
private func getGPUMemoryUsage() -> String {
|
|
56
|
-
let snapshot =
|
|
61
|
+
let snapshot = Memory.snapshot()
|
|
57
62
|
let allocatedMB = Float(snapshot.activeMemory) / 1024.0 / 1024.0
|
|
58
63
|
let cacheMB = Float(snapshot.cacheMemory) / 1024.0 / 1024.0
|
|
59
64
|
let peakMB = Float(snapshot.peakMemory) / 1024.0 / 1024.0
|
|
@@ -61,18 +66,48 @@ class HybridLLM: HybridLLMSpec {
|
|
|
61
66
|
allocatedMB, cacheMB, peakMB)
|
|
62
67
|
}
|
|
63
68
|
|
|
69
|
+
/// Converts a JS-provided tool definition into the function-calling schema
/// dictionary handed to the model.
///
/// The result has the shape
/// `{"type": "function", "function": {"name", "description", "parameters"}}`,
/// where `parameters` is an object schema listing each tool parameter's
/// `type` and `description`, plus the names of the parameters flagged as required.
///
/// - Parameter tool: The tool whose name, description and parameters are described.
/// - Returns: The schema dictionary for `tool`.
private func buildToolSchema(from tool: ToolDefinition) -> ToolSpec {
    // One {"type", "description"} entry per parameter, keyed by parameter name.
    let properties = tool.parameters.reduce(into: [String: [String: Any]]()) { schemas, parameter in
        schemas[parameter.name] = [
            "type": parameter.type,
            "description": parameter.description
        ]
    }

    // Names of required parameters, in declaration order.
    let required: [String] = tool.parameters
        .filter { $0.required }
        .map { $0.name }

    return [
        "type": "function",
        "function": [
            "name": tool.name,
            "description": tool.description,
            "parameters": [
                "type": "object",
                "properties": properties,
                "required": required
            ]
        ]
    ] as ToolSpec
}
|
|
96
|
+
|
|
64
97
|
func load(modelId: String, options: LLMLoadOptions?) throws -> Promise<Void> {
|
|
65
98
|
self.loadTask?.cancel()
|
|
66
99
|
|
|
67
100
|
return Promise.async { [self] in
|
|
68
101
|
let task = Task { @MainActor in
|
|
69
|
-
|
|
102
|
+
Memory.cacheLimit = 2000000
|
|
70
103
|
|
|
71
104
|
self.currentTask?.cancel()
|
|
72
105
|
self.currentTask = nil
|
|
73
106
|
self.session = nil
|
|
74
107
|
self.container = nil
|
|
75
|
-
|
|
108
|
+
self.tools = []
|
|
109
|
+
self.toolSchemas = []
|
|
110
|
+
Memory.clearCache()
|
|
76
111
|
|
|
77
112
|
let memoryAfterCleanup = self.getMemoryUsage()
|
|
78
113
|
let gpuAfterCleanup = self.getGPUMemoryUsage()
|
|
@@ -94,6 +129,12 @@ class HybridLLM: HybridLLMSpec {
|
|
|
94
129
|
let gpuAfterContainer = self.getGPUMemoryUsage()
|
|
95
130
|
log("Model loaded - Host: \(memoryAfterContainer), GPU: \(gpuAfterContainer)")
|
|
96
131
|
|
|
132
|
+
if let jsTools = options?.tools {
|
|
133
|
+
self.tools = jsTools
|
|
134
|
+
self.toolSchemas = jsTools.map { self.buildToolSchema(from: $0) }
|
|
135
|
+
log("Loaded \(self.tools.count) tools: \(self.tools.map { $0.name })")
|
|
136
|
+
}
|
|
137
|
+
|
|
97
138
|
let additionalContextDict: [String: Any]? = if let messages = options?.additionalContext {
|
|
98
139
|
["messages": messages.map { ["role": $0.role, "content": $0.content] }]
|
|
99
140
|
} else {
|
|
@@ -135,25 +176,26 @@ class HybridLLM: HybridLLMSpec {
|
|
|
135
176
|
}
|
|
136
177
|
|
|
137
178
|
self.currentTask = task
|
|
179
|
+
defer { self.currentTask = nil }
|
|
138
180
|
|
|
139
|
-
|
|
140
|
-
let result = try await task.value
|
|
141
|
-
self.currentTask = nil
|
|
181
|
+
let result = try await task.value
|
|
142
182
|
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
}
|
|
146
|
-
|
|
147
|
-
return result
|
|
148
|
-
} catch {
|
|
149
|
-
self.currentTask = nil
|
|
150
|
-
throw error
|
|
183
|
+
if self.manageHistory {
|
|
184
|
+
self.messageHistory.append(LLMMessage(role: "assistant", content: result))
|
|
151
185
|
}
|
|
186
|
+
|
|
187
|
+
return result
|
|
152
188
|
}
|
|
153
189
|
}
|
|
154
190
|
|
|
155
|
-
|
|
156
|
-
|
|
191
|
+
/// Upper bound on nested generation rounds triggered by tool calls.
/// `performGeneration` and `performGenerationWithEvents` log a message and
/// return an empty string once their `depth` argument reaches this value.
private let maxToolCallDepth = 10
|
|
192
|
+
|
|
193
|
+
func stream(
|
|
194
|
+
prompt: String,
|
|
195
|
+
onToken: @escaping (String) -> Void,
|
|
196
|
+
onToolCall: ((String, String) -> Void)?
|
|
197
|
+
) throws -> Promise<String> {
|
|
198
|
+
guard let container else {
|
|
157
199
|
throw LLMError.notLoaded
|
|
158
200
|
}
|
|
159
201
|
|
|
@@ -163,22 +205,24 @@ class HybridLLM: HybridLLMSpec {
|
|
|
163
205
|
}
|
|
164
206
|
|
|
165
207
|
let task = Task<String, Error> {
|
|
166
|
-
var result = ""
|
|
167
|
-
var tokenCount = 0
|
|
168
208
|
let startTime = Date()
|
|
169
209
|
var firstTokenTime: Date?
|
|
210
|
+
var tokenCount = 0
|
|
170
211
|
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
212
|
+
let result = try await self.performGeneration(
|
|
213
|
+
container: container,
|
|
214
|
+
prompt: prompt,
|
|
215
|
+
toolResults: nil,
|
|
216
|
+
depth: 0,
|
|
217
|
+
onToken: { token in
|
|
218
|
+
if firstTokenTime == nil {
|
|
219
|
+
firstTokenTime = Date()
|
|
220
|
+
}
|
|
221
|
+
tokenCount += 1
|
|
222
|
+
onToken(token)
|
|
223
|
+
},
|
|
224
|
+
onToolCall: onToolCall ?? { _, _ in }
|
|
225
|
+
)
|
|
182
226
|
|
|
183
227
|
let endTime = Date()
|
|
184
228
|
let totalTime = endTime.timeIntervalSince(startTime) * 1000
|
|
@@ -189,7 +233,8 @@ class HybridLLM: HybridLLMSpec {
|
|
|
189
233
|
tokenCount: Double(tokenCount),
|
|
190
234
|
tokensPerSecond: tokensPerSecond,
|
|
191
235
|
timeToFirstToken: timeToFirstToken,
|
|
192
|
-
totalTime: totalTime
|
|
236
|
+
totalTime: totalTime,
|
|
237
|
+
toolExecutionTime: 0
|
|
193
238
|
)
|
|
194
239
|
|
|
195
240
|
log("Stream complete - \(tokenCount) tokens, \(String(format: "%.1f", tokensPerSecond)) tokens/s")
|
|
@@ -197,21 +242,450 @@ class HybridLLM: HybridLLMSpec {
|
|
|
197
242
|
}
|
|
198
243
|
|
|
199
244
|
self.currentTask = task
|
|
245
|
+
defer { self.currentTask = nil }
|
|
200
246
|
|
|
201
|
-
|
|
202
|
-
let result = try await task.value
|
|
203
|
-
self.currentTask = nil
|
|
247
|
+
let result = try await task.value
|
|
204
248
|
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
249
|
+
if self.manageHistory {
|
|
250
|
+
self.messageHistory.append(LLMMessage(role: "assistant", content: result))
|
|
251
|
+
}
|
|
252
|
+
|
|
253
|
+
return result
|
|
254
|
+
}
|
|
255
|
+
}
|
|
208
256
|
|
|
257
|
+
/// Streams a response for `prompt`, reporting progress through `onEvent`.
///
/// A `StreamEventEmitter` wrapping `onEvent` is told when generation starts
/// and ends; everything in between is emitted by `performGenerationWithEvents`.
/// When `manageHistory` is on, the prompt is appended to `messageHistory`
/// before generation and the final text is appended afterwards.
/// `lastStats` is updated once generation finishes.
///
/// - Parameters:
///   - prompt: The user prompt to answer.
///   - onEvent: Callback handed to the `StreamEventEmitter`.
/// - Returns: A promise resolving to the generated text.
/// - Throws: `LLMError.notLoaded` when no model container is loaded.
func streamWithEvents(
    prompt: String,
    onEvent: @escaping (String) -> Void
) throws -> Promise<String> {
    guard let container = self.container else {
        throw LLMError.notLoaded
    }

    return Promise.async { [self] in
        if self.manageHistory {
            self.messageHistory.append(LLMMessage(role: "user", content: prompt))
        }

        let generationTask = Task<String, Error> {
            let startedAt = Date()
            var firstTokenAt: Date?
            var outputTokenCount = 0
            var mlxTokenCount = 0
            var mlxGenerationTime: Double = 0
            var toolExecutionTime: Double = 0
            let emitter = StreamEventEmitter(callback: onEvent)

            emitter.emitGenerationStart()

            let text = try await self.performGenerationWithEvents(
                container: container,
                prompt: prompt,
                toolResults: nil,
                depth: 0,
                emitter: emitter,
                onTokenProcessed: {
                    // Only the first visible token fixes the time-to-first-token mark.
                    if firstTokenAt == nil {
                        firstTokenAt = Date()
                    }
                    outputTokenCount += 1
                },
                onGenerationInfo: { tokens, time in
                    // Summed across every generation round (initial + tool continuations).
                    mlxTokenCount += tokens
                    mlxGenerationTime += time
                },
                toolExecutionTime: &toolExecutionTime
            )

            let finishedAt = Date()
            let totalTime = finishedAt.timeIntervalSince(startedAt) * 1000
            let timeToFirstToken = (firstTokenAt ?? finishedAt).timeIntervalSince(startedAt) * 1000

            // Throughput uses the accumulated generation time (milliseconds),
            // not the wall-clock total.
            let tokensPerSecond: Double
            if mlxGenerationTime > 0 {
                tokensPerSecond = Double(mlxTokenCount) / (mlxGenerationTime / 1000)
            } else {
                tokensPerSecond = 0
            }

            let stats = GenerationStats(
                tokenCount: Double(mlxTokenCount),
                tokensPerSecond: tokensPerSecond,
                timeToFirstToken: timeToFirstToken,
                totalTime: totalTime,
                toolExecutionTime: toolExecutionTime
            )

            self.lastStats = stats
            emitter.emitGenerationEnd(content: text, stats: stats)

            log("StreamWithEvents complete - \(mlxTokenCount) tokens, \(String(format: "%.1f", tokensPerSecond)) tokens/s (tool execution: \(String(format: "%.0f", toolExecutionTime))ms)")
            return text
        }

        self.currentTask = generationTask
        // Clear the in-flight marker on success and on failure alike.
        defer { self.currentTask = nil }

        let text = try await generationTask.value

        if self.manageHistory {
            self.messageHistory.append(LLMMessage(role: "assistant", content: text))
        }

        return text
    }
}
|
|
332
|
+
|
|
333
|
+
/// Assembles the chat transcript sent to the model for one generation round.
///
/// Order: the system prompt (when non-empty), every entry of `messageHistory`
/// with a recognised role, the user prompt (first round only), then any tool
/// results.
///
/// Two duplications are avoided, because callers in this class record turns
/// in `messageHistory` before asking for the transcript:
/// - `streamWithEvents` appends the prompt to history when `manageHistory` is
///   on, so the prompt is not appended again if it is already the last
///   history entry.
/// - The generation routines append tool results to history and then pass the
///   same results as `toolResults`, so they are not appended again if they
///   already form the tail of history.
///
/// - Parameters:
///   - prompt: The user prompt for this request.
///   - toolResults: Tool result payloads for a continuation round, if any.
///   - depth: Generation round; `0` is the initial round.
/// - Returns: The ordered chat messages.
private func buildChatMessages(
    prompt: String,
    toolResults: [String]?,
    depth: Int
) -> [Chat.Message] {
    var chat: [Chat.Message] = []

    if !self.systemPrompt.isEmpty {
        chat.append(.system(self.systemPrompt))
    }

    for msg in self.messageHistory {
        switch msg.role {
        case "user": chat.append(.user(msg.content))
        case "assistant": chat.append(.assistant(msg.content))
        case "system": chat.append(.system(msg.content))
        case "tool": chat.append(.tool(msg.content))
        default: break  // Roles outside these four are not forwarded.
        }
    }

    if depth == 0 {
        // Skip the prompt when history already ends with this exact user turn.
        let promptAlreadyRecorded = self.messageHistory.last.map {
            $0.role == "user" && $0.content == prompt
        } ?? false

        if !promptAlreadyRecorded {
            chat.append(.user(prompt))
        }
    }

    if let toolResults, !toolResults.isEmpty {
        // Skip the results when history already ends with exactly these tool turns.
        let historyTail = self.messageHistory.suffix(toolResults.count)
        let resultsAlreadyRecorded = historyTail.count == toolResults.count
            && zip(historyTail, toolResults).allSatisfy { pair in
                pair.0.role == "tool" && pair.0.content == pair.1
            }

        if !resultsAlreadyRecorded {
            for result in toolResults {
                chat.append(.tool(result))
            }
        }
    }

    return chat
}
|
|
366
|
+
|
|
367
|
+
/// Invokes a tool's JS handler and returns its result as a JSON string.
///
/// The arguments are converted to an `AnyMap`, the handler's promise is
/// awaited, the promise it resolves to is awaited in turn, and the resulting
/// `AnyMap` is converted back to a dictionary and serialised with
/// `dictionaryToJson`.
///
/// - Parameters:
///   - tool: The tool whose `handler` is called.
///   - argsDict: Arguments for the call.
/// - Returns: The handler result serialised as JSON.
/// - Throws: Whatever either awaited promise rejects with.
private func executeToolCall(
    tool: ToolDefinition,
    argsDict: [String: Any]
) async throws -> String {
    let handlerArguments = self.dictionaryToAnyMap(argsDict)

    // The handler yields a promise that itself resolves to a promise.
    let pendingResult = try await tool.handler(handlerArguments).await()
    let resolvedMap = try await pendingResult.await()

    return dictionaryToJson(self.anyMapToDictionary(resolvedMap))
}
|
|
378
|
+
|
|
379
|
+
/// Runs one generation round, emitting structured events, and recurses while
/// the model keeps requesting tools.
///
/// Text chunks pass through a `ThinkingStateMachine`, which splits them into
/// visible tokens and thinking events. Tool calls naming a registered tool are
/// collected during the round, executed concurrently afterwards, recorded in
/// `messageHistory`, and followed by a continuation round at `depth + 1`.
///
/// Before the continuation, the user prompt is recorded in `messageHistory`
/// (first round only, and only if it is not already the last entry). Without
/// it the continuation transcript would have no user turn when the caller did
/// not record the prompt itself. This matches `performGeneration`.
///
/// - Parameters:
///   - container: Model container used to prepare input and generate.
///   - prompt: The user prompt for this request.
///   - toolResults: Tool results to include in this round's transcript, if any.
///   - depth: Current round; generation stops once it reaches `maxToolCallDepth`.
///   - emitter: Sink for token, thinking and tool-call events.
///   - onTokenProcessed: Called once per visible token.
///   - onGenerationInfo: Called with (token count, generation time in ms) when
///     the stream reports generation info.
///   - toolExecutionTime: Accumulates wall-clock time (ms) spent executing tools.
/// - Returns: The visible text of this round plus all continuation rounds;
///   an empty string when the depth limit is hit.
/// - Throws: Errors from input preparation or generation setup. Tool failures
///   are not thrown; they are reported via the emitter and replaced by an
///   error JSON payload.
private func performGenerationWithEvents(
    container: ModelContainer,
    prompt: String,
    toolResults: [String]?,
    depth: Int,
    emitter: StreamEventEmitter,
    onTokenProcessed: @escaping () -> Void,
    onGenerationInfo: @escaping (Int, Double) -> Void,
    toolExecutionTime: inout Double
) async throws -> String {
    if depth >= maxToolCallDepth {
        log("Max tool call depth reached (\(maxToolCallDepth))")
        return ""
    }

    var output = ""
    var thinkingMachine = ThinkingStateMachine()
    var pendingToolCalls: [(id: String, tool: ToolDefinition, args: [String: Any], argsJson: String)] = []

    let chat = buildChatMessages(prompt: prompt, toolResults: toolResults, depth: depth)
    let userInput = UserInput(
        chat: chat,
        tools: !self.toolSchemas.isEmpty ? self.toolSchemas : nil
    )

    let lmInput = try await container.prepare(input: userInput)

    let stream = try await container.perform { context in
        let parameters = GenerateParameters(maxTokens: 2048, temperature: 0.7)
        return try MLXLMCommon.generate(
            input: lmInput,
            parameters: parameters,
            context: context
        )
    }

    for await generation in stream {
        if Task.isCancelled { break }

        switch generation {
        case .chunk(let text):
            let outputs = thinkingMachine.process(token: text)

            for machineOutput in outputs {
                switch machineOutput {
                case .token(let token):
                    output += token
                    emitter.emitToken(token)
                    onTokenProcessed()

                case .thinkingStart:
                    emitter.emitThinkingStart()

                case .thinkingChunk(let chunk):
                    emitter.emitThinkingChunk(chunk)

                case .thinkingEnd(let content):
                    emitter.emitThinkingEnd(content)
                }
            }

        case .toolCall(let toolCall):
            log("Tool call detected: \(toolCall.function.name)")

            guard let tool = self.tools.first(where: { $0.name == toolCall.function.name }) else {
                log("Unknown tool: \(toolCall.function.name)")
                continue
            }

            let toolCallId = UUID().uuidString
            let argsDict = self.convertToolCallArguments(toolCall.function.arguments)
            let argsJson = dictionaryToJson(argsDict)

            emitter.emitToolCallStart(id: toolCallId, name: toolCall.function.name, arguments: argsJson)
            pendingToolCalls.append((id: toolCallId, tool: tool, args: argsDict, argsJson: argsJson))

        case .info(let info):
            log("Generation info: \(info.generationTokenCount) tokens, \(String(format: "%.1f", info.tokensPerSecond)) tokens/s")
            // Derive elapsed generation time (ms) from the reported count and rate.
            let generationTime = info.tokensPerSecond > 0 ? Double(info.generationTokenCount) / info.tokensPerSecond * 1000 : 0
            onGenerationInfo(info.generationTokenCount, generationTime)
        }
    }

    // Drain whatever the state machine is still holding after the stream ends.
    let flushOutputs = thinkingMachine.flush()
    for machineOutput in flushOutputs {
        switch machineOutput {
        case .token(let token):
            output += token
            emitter.emitToken(token)
            onTokenProcessed()
        case .thinkingStart:
            emitter.emitThinkingStart()
        case .thinkingChunk(let chunk):
            emitter.emitThinkingChunk(chunk)
        case .thinkingEnd(let content):
            emitter.emitThinkingEnd(content)
        }
    }

    guard !pendingToolCalls.isEmpty else {
        return output
    }

    log("Executing \(pendingToolCalls.count) tool call(s)")

    let toolStartTime = Date()

    for call in pendingToolCalls {
        emitter.emitToolCallExecuting(id: call.id)
    }

    // Run all tools concurrently; results are slotted back by index so their
    // order matches the order in which the model requested them.
    let allToolResults: [String] = await withTaskGroup(of: (Int, String).self) { group in
        for (index, call) in pendingToolCalls.enumerated() {
            group.addTask { [self] in
                do {
                    let resultJson = try await self.executeToolCall(tool: call.tool, argsDict: call.args)
                    self.log("Tool result for \(call.tool.name): \(resultJson.prefix(100))...")
                    emitter.emitToolCallCompleted(id: call.id, result: resultJson)
                    return (index, resultJson)
                } catch {
                    self.log("Tool execution error for \(call.tool.name): \(error)")
                    emitter.emitToolCallFailed(id: call.id, error: error.localizedDescription)
                    return (index, "{\"error\": \"Tool execution failed\"}")
                }
            }
        }

        var results = Array(repeating: "", count: pendingToolCalls.count)
        for await (index, result) in group {
            results[index] = result
        }
        return results
    }

    toolExecutionTime += Date().timeIntervalSince(toolStartTime) * 1000

    if depth == 0 {
        // Make sure the user turn precedes the assistant/tool turns recorded
        // below; skip it when history already ends with this prompt.
        let promptAlreadyRecorded = self.messageHistory.last.map {
            $0.role == "user" && $0.content == prompt
        } ?? false

        if !promptAlreadyRecorded {
            self.messageHistory.append(LLMMessage(role: "user", content: prompt))
        }
    }

    if !output.isEmpty {
        self.messageHistory.append(LLMMessage(role: "assistant", content: output))
    }

    for result in allToolResults {
        self.messageHistory.append(LLMMessage(role: "tool", content: result))
    }

    // The tool results were just recorded in messageHistory, which the
    // transcript builder replays in full; passing them again would put each
    // result into the continuation transcript twice.
    let continuation = try await self.performGenerationWithEvents(
        container: container,
        prompt: prompt,
        toolResults: nil,
        depth: depth + 1,
        emitter: emitter,
        onTokenProcessed: onTokenProcessed,
        onGenerationInfo: onGenerationInfo,
        toolExecutionTime: &toolExecutionTime
    )

    return output + continuation
}
|
|
536
|
+
|
|
537
|
+
/// Runs one plain (token-callback) generation round and recurses while the
/// model keeps requesting tools.
///
/// Every text chunk is appended to the result and forwarded to `onToken`.
/// Tool calls naming a registered tool are reported through `onToolCall`,
/// executed concurrently after the round, recorded in `messageHistory`, and
/// followed by a continuation round at `depth + 1`.
///
/// History is recorded in conversational order: user prompt (first round
/// only, and only if it is not already the last entry), then the assistant
/// text of this round, then the tool results.
///
/// - Parameters:
///   - container: Model container used to prepare input and generate.
///   - prompt: The user prompt for this request.
///   - toolResults: Tool results to include in this round's transcript, if any.
///   - depth: Current round; generation stops once it reaches `maxToolCallDepth`.
///   - onToken: Called with each generated text chunk, and with a zero-width
///     space (U+200B) before each continuation round.
///   - onToolCall: Called with (tool name, arguments JSON) for each accepted tool call.
/// - Returns: The text of this round plus all continuation rounds; an empty
///   string when the depth limit is hit.
/// - Throws: Errors from input preparation or generation setup. Tool failures
///   are not thrown; they are replaced by an error JSON payload.
private func performGeneration(
    container: ModelContainer,
    prompt: String,
    toolResults: [String]?,
    depth: Int,
    onToken: @escaping (String) -> Void,
    onToolCall: @escaping (String, String) -> Void
) async throws -> String {
    if depth >= maxToolCallDepth {
        log("Max tool call depth reached (\(maxToolCallDepth))")
        return ""
    }

    var output = ""
    var pendingToolCalls: [(tool: ToolDefinition, args: [String: Any], argsJson: String)] = []

    let chat = buildChatMessages(prompt: prompt, toolResults: toolResults, depth: depth)
    let userInput = UserInput(
        chat: chat,
        tools: !self.toolSchemas.isEmpty ? self.toolSchemas : nil
    )

    let lmInput = try await container.prepare(input: userInput)

    let stream = try await container.perform { context in
        let parameters = GenerateParameters(maxTokens: 2048, temperature: 0.7)
        return try MLXLMCommon.generate(
            input: lmInput,
            parameters: parameters,
            context: context
        )
    }

    for await generation in stream {
        if Task.isCancelled { break }

        switch generation {
        case .chunk(let text):
            output += text
            onToken(text)

        case .toolCall(let toolCall):
            log("Tool call detected: \(toolCall.function.name)")

            guard let tool = self.tools.first(where: { $0.name == toolCall.function.name }) else {
                log("Unknown tool: \(toolCall.function.name)")
                continue
            }

            let argsDict = self.convertToolCallArguments(toolCall.function.arguments)
            let argsJson = dictionaryToJson(argsDict)

            pendingToolCalls.append((tool: tool, args: argsDict, argsJson: argsJson))
            onToolCall(toolCall.function.name, argsJson)

        case .info(let info):
            log("Generation info: \(info.generationTokenCount) tokens, \(String(format: "%.1f", info.tokensPerSecond)) tokens/s")
        }
    }

    guard !pendingToolCalls.isEmpty else {
        return output
    }

    log("Executing \(pendingToolCalls.count) tool call(s)")

    // Run all tools concurrently; results are slotted back by index so their
    // order matches the order in which the model requested them.
    let allToolResults: [String] = await withTaskGroup(of: (Int, String).self) { group in
        for (index, call) in pendingToolCalls.enumerated() {
            group.addTask { [self] in
                do {
                    let resultJson = try await self.executeToolCall(tool: call.tool, argsDict: call.args)
                    self.log("Tool result for \(call.tool.name): \(resultJson.prefix(100))...")
                    return (index, resultJson)
                } catch {
                    self.log("Tool execution error for \(call.tool.name): \(error)")
                    return (index, "{\"error\": \"Tool execution failed\"}")
                }
            }
        }

        var results = Array(repeating: "", count: pendingToolCalls.count)
        for await (index, result) in group {
            results[index] = result
        }
        return results
    }

    if depth == 0 {
        // The user turn must come before the assistant turn it produced.
        // Skip it when history already ends with this prompt.
        // NOTE(review): presumably the caller records the prompt itself when
        // history management is on — confirm against `stream`.
        let promptAlreadyRecorded = self.messageHistory.last.map {
            $0.role == "user" && $0.content == prompt
        } ?? false

        if !promptAlreadyRecorded {
            self.messageHistory.append(LLMMessage(role: "user", content: prompt))
        }
    }

    if !output.isEmpty {
        self.messageHistory.append(LLMMessage(role: "assistant", content: output))
    }

    for result in allToolResults {
        self.messageHistory.append(LLMMessage(role: "tool", content: result))
    }

    // Zero-width space sent to the token callback between rounds.
    onToken("\u{200B}")

    // The tool results were just recorded in messageHistory, which the
    // transcript builder replays in full; passing them again would put each
    // result into the continuation transcript twice.
    let continuation = try await self.performGeneration(
        container: container,
        prompt: prompt,
        toolResults: nil,
        depth: depth + 1,
        onToken: onToken,
        onToolCall: onToolCall
    )

    return output + continuation
}
|
|
649
|
+
|
|
650
|
+
/// Unwraps model-produced tool-call arguments into a plain dictionary.
///
/// - Parameter arguments: Argument values keyed by parameter name.
/// - Returns: The same keys, each mapped to the value's `anyValue`.
private func convertToolCallArguments(_ arguments: [String: JSONValue]) -> [String: Any] {
    return arguments.reduce(into: [String: Any]()) { unwrapped, entry in
        unwrapped[entry.key] = entry.value.anyValue
    }
}
|
|
657
|
+
|
|
658
|
+
/// Copies a dictionary into a new `AnyMap` for handing to a JS tool handler.
///
/// Strings, doubles and booleans are stored with their own setters; integers
/// are stored as doubles. Any other value is stored as its
/// `String(describing:)` text.
///
/// - Parameter dict: The values to copy.
/// - Returns: A new `AnyMap` holding one entry per key of `dict`.
private func dictionaryToAnyMap(_ dict: [String: Any]) -> AnyMap {
    let map = AnyMap()

    for (name, rawValue) in dict {
        // The cast order is significant and mirrors String → Double → Int → Bool.
        if let text = rawValue as? String {
            map.setString(key: name, value: text)
        } else if let number = rawValue as? Double {
            map.setDouble(key: name, value: number)
        } else if let integer = rawValue as? Int {
            map.setDouble(key: name, value: Double(integer))
        } else if let flag = rawValue as? Bool {
            map.setBoolean(key: name, value: flag)
        } else {
            map.setString(key: name, value: String(describing: rawValue))
        }
    }

    return map
}
|
|
676
|
+
|
|
677
|
+
/// Copies the string, double and boolean entries of an `AnyMap` into a dictionary.
///
/// Entries of any other kind are left out of the result.
///
/// - Parameter anyMap: The map to read.
/// - Returns: A dictionary with the copied entries.
private func anyMapToDictionary(_ anyMap: AnyMap) -> [String: Any] {
    return anyMap.getAllKeys().reduce(into: [String: Any]()) { collected, key in
        if anyMap.isString(key: key) {
            collected[key] = anyMap.getString(key: key)
        } else if anyMap.isDouble(key: key) {
            collected[key] = anyMap.getDouble(key: key)
        } else if anyMap.isBool(key: key) {
            collected[key] = anyMap.getBoolean(key: key)
        }
    }
}
|
|
216
690
|
|
|
217
691
|
func stop() throws {
|
|
@@ -231,11 +705,13 @@ class HybridLLM: HybridLLMSpec {
|
|
|
231
705
|
currentTask = nil
|
|
232
706
|
session = nil
|
|
233
707
|
container = nil
|
|
708
|
+
tools = []
|
|
709
|
+
toolSchemas = []
|
|
234
710
|
messageHistory = []
|
|
235
711
|
manageHistory = false
|
|
236
712
|
modelId = ""
|
|
237
713
|
|
|
238
|
-
MLX.
|
|
714
|
+
MLX.Memory.clearCache()
|
|
239
715
|
|
|
240
716
|
let memoryAfter = getMemoryUsage()
|
|
241
717
|
let gpuAfter = getGPUMemoryUsage()
|