react-native-nitro-mlx 0.3.0 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/MLXReactNative.podspec +7 -1
- package/ios/Sources/AudioCaptureManager.swift +110 -0
- package/ios/Sources/HybridLLM.swift +309 -68
- package/ios/Sources/HybridSTT.swift +202 -0
- package/ios/Sources/HybridTTS.swift +145 -0
- package/ios/Sources/JSONHelpers.swift +9 -0
- package/ios/Sources/ModelDownloader.swift +26 -12
- package/ios/Sources/StreamEventEmitter.swift +132 -0
- package/ios/Sources/ThinkingStateMachine.swift +206 -0
- package/lib/module/index.js +2 -0
- package/lib/module/index.js.map +1 -1
- package/lib/module/llm.js +39 -1
- package/lib/module/llm.js.map +1 -1
- package/lib/module/models.js +97 -26
- package/lib/module/models.js.map +1 -1
- package/lib/module/specs/STT.nitro.js +4 -0
- package/lib/module/specs/STT.nitro.js.map +1 -0
- package/lib/module/specs/TTS.nitro.js +4 -0
- package/lib/module/specs/TTS.nitro.js.map +1 -0
- package/lib/module/stt.js +49 -0
- package/lib/module/stt.js.map +1 -0
- package/lib/module/tts.js +40 -0
- package/lib/module/tts.js.map +1 -0
- package/lib/typescript/src/index.d.ts +7 -3
- package/lib/typescript/src/index.d.ts.map +1 -1
- package/lib/typescript/src/llm.d.ts +32 -2
- package/lib/typescript/src/llm.d.ts.map +1 -1
- package/lib/typescript/src/models.d.ts +13 -4
- package/lib/typescript/src/models.d.ts.map +1 -1
- package/lib/typescript/src/specs/LLM.nitro.d.ts +49 -4
- package/lib/typescript/src/specs/LLM.nitro.d.ts.map +1 -1
- package/lib/typescript/src/specs/STT.nitro.d.ts +28 -0
- package/lib/typescript/src/specs/STT.nitro.d.ts.map +1 -0
- package/lib/typescript/src/specs/TTS.nitro.d.ts +22 -0
- package/lib/typescript/src/specs/TTS.nitro.d.ts.map +1 -0
- package/lib/typescript/src/stt.d.ts +16 -0
- package/lib/typescript/src/stt.d.ts.map +1 -0
- package/lib/typescript/src/tts.d.ts +13 -0
- package/lib/typescript/src/tts.d.ts.map +1 -0
- package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Bridge.cpp +42 -0
- package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Bridge.hpp +165 -0
- package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Umbrella.hpp +20 -0
- package/nitrogen/generated/ios/MLXReactNativeAutolinking.mm +16 -0
- package/nitrogen/generated/ios/MLXReactNativeAutolinking.swift +30 -0
- package/nitrogen/generated/ios/c++/HybridLLMSpecSwift.hpp +8 -0
- package/nitrogen/generated/ios/c++/HybridSTTSpecSwift.cpp +11 -0
- package/nitrogen/generated/ios/c++/HybridSTTSpecSwift.hpp +149 -0
- package/nitrogen/generated/ios/c++/HybridTTSSpecSwift.cpp +11 -0
- package/nitrogen/generated/ios/c++/HybridTTSSpecSwift.hpp +128 -0
- package/nitrogen/generated/ios/swift/Func_void_std__shared_ptr_ArrayBuffer_.swift +47 -0
- package/nitrogen/generated/ios/swift/GenerationStats.swift +13 -2
- package/nitrogen/generated/ios/swift/HybridLLMSpec.swift +1 -0
- package/nitrogen/generated/ios/swift/HybridLLMSpec_cxx.swift +24 -0
- package/nitrogen/generated/ios/swift/HybridSTTSpec.swift +66 -0
- package/nitrogen/generated/ios/swift/HybridSTTSpec_cxx.swift +286 -0
- package/nitrogen/generated/ios/swift/HybridTTSSpec.swift +63 -0
- package/nitrogen/generated/ios/swift/HybridTTSSpec_cxx.swift +229 -0
- package/nitrogen/generated/ios/swift/STTLoadOptions.swift +66 -0
- package/nitrogen/generated/ios/swift/TTSGenerateOptions.swift +78 -0
- package/nitrogen/generated/ios/swift/TTSLoadOptions.swift +66 -0
- package/nitrogen/generated/shared/c++/GenerationStats.hpp +6 -2
- package/nitrogen/generated/shared/c++/HybridLLMSpec.cpp +1 -0
- package/nitrogen/generated/shared/c++/HybridLLMSpec.hpp +1 -0
- package/nitrogen/generated/shared/c++/HybridSTTSpec.cpp +32 -0
- package/nitrogen/generated/shared/c++/HybridSTTSpec.hpp +78 -0
- package/nitrogen/generated/shared/c++/HybridTTSSpec.cpp +29 -0
- package/nitrogen/generated/shared/c++/HybridTTSSpec.hpp +78 -0
- package/nitrogen/generated/shared/c++/STTLoadOptions.hpp +76 -0
- package/nitrogen/generated/shared/c++/TTSGenerateOptions.hpp +80 -0
- package/nitrogen/generated/shared/c++/TTSLoadOptions.hpp +76 -0
- package/package.json +8 -4
- package/src/index.ts +31 -1
- package/src/llm.ts +48 -2
- package/src/models.ts +81 -1
- package/src/specs/LLM.nitro.ts +74 -4
- package/src/specs/STT.nitro.ts +35 -0
- package/src/specs/TTS.nitro.ts +30 -0
- package/src/stt.ts +67 -0
- package/src/tts.ts +60 -0
package/MLXReactNative.podspec
CHANGED
|
@@ -24,10 +24,16 @@ Pod::Spec.new do |s|
|
|
|
24
24
|
|
|
25
25
|
spm_dependency(s,
|
|
26
26
|
url: "https://github.com/ml-explore/mlx-swift-lm.git",
|
|
27
|
-
requirement: {kind: "upToNextMinorVersion", minimumVersion: "2.
|
|
27
|
+
requirement: {kind: "upToNextMinorVersion", minimumVersion: "2.30.3"},
|
|
28
28
|
products: ["MLXLLM", "MLXLMCommon"]
|
|
29
29
|
)
|
|
30
30
|
|
|
31
|
+
spm_dependency(s,
|
|
32
|
+
url: "https://github.com/Blaizzy/mlx-audio-swift.git",
|
|
33
|
+
requirement: {kind: "branch", branch: "main"},
|
|
34
|
+
products: ["MLXAudioTTS", "MLXAudioSTT", "MLXAudioCore"]
|
|
35
|
+
)
|
|
36
|
+
|
|
31
37
|
s.pod_target_xcconfig = {
|
|
32
38
|
# C++ compiler flags, mainly for folly.
|
|
33
39
|
"GCC_PREPROCESSOR_DEFINITIONS" => "$(inherited) FOLLY_NO_CONFIG FOLLY_CFG_NO_COROUTINES"
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
import AVFoundation
|
|
2
|
+
import Foundation
|
|
3
|
+
internal import MLX
|
|
4
|
+
|
|
5
|
+
/// Records microphone input with `AVAudioEngine`, resamples it to 16 kHz mono
/// Float32 PCM, and accumulates the converted samples so callers can read them
/// out as an `MLXArray`, either while recording or when recording stops.
///
/// The input tap callback runs on an audio thread, so `audioBuffer` is only
/// read or written while holding `bufferLock`.
class AudioCaptureManager {
    private let audioEngine = AVAudioEngine()
    private var audioBuffer: [Float] = []
    private let bufferLock = NSLock()
    private let targetSampleRate: Double = 16000
    /// Whether our tap is currently installed on input bus 0. Installing a second
    /// tap on the same bus is rejected by AVAudioEngine with an exception.
    private var isTapInstalled = false

    /// Minimum sample count `snapshotAndClear()` returns: half a second of audio.
    private var halfSecondOfSamples: Int { Int(targetSampleRate) / 2 }
    /// Minimum sample count `snapshot()` returns: one second of audio.
    private var oneSecondOfSamples: Int { Int(targetSampleRate) }

    /// True while the underlying audio engine is running.
    var isCapturing: Bool { audioEngine.isRunning }

    /// Configures the shared audio session for recording and starts appending
    /// converted microphone samples to the internal buffer.
    ///
    /// Samples left over from a previous capture are discarded. Calling this
    /// while a tap is already installed replaces that tap instead of crashing.
    ///
    /// - Throws: Audio-session or engine-start errors, or an `NSError` in the
    ///   `AudioCaptureManager` domain when no usable input, output format, or
    ///   converter is available.
    func startCapturing() async throws {
        let session = AVAudioSession.sharedInstance()
        try session.setCategory(.record, mode: .measurement)
        try session.setActive(true)

        let inputNode = audioEngine.inputNode
        let inputFormat = inputNode.outputFormat(forBus: 0)
        // A zero sample rate or channel count means there is no usable input
        // device; the per-buffer capacity computation would divide by zero.
        guard inputFormat.sampleRate > 0, inputFormat.channelCount > 0 else {
            throw Self.makeError("No usable audio input available")
        }

        guard
            let outputFormat = AVAudioFormat(
                commonFormat: .pcmFormatFloat32,
                sampleRate: targetSampleRate,
                channels: 1,
                interleaved: false
            )
        else {
            throw Self.makeError("Failed to create audio format")
        }

        guard let converter = AVAudioConverter(from: inputFormat, to: outputFormat) else {
            throw Self.makeError("Failed to create audio converter")
        }

        // Start every capture from an empty buffer.
        copySamples(clearing: true)

        removeTapIfInstalled()
        let inputSampleRate = inputFormat.sampleRate
        inputNode.installTap(onBus: 0, bufferSize: 4096, format: inputFormat) { [weak self] buffer, _ in
            self?.appendConverted(
                buffer,
                using: converter,
                to: outputFormat,
                inputSampleRate: inputSampleRate
            )
        }
        isTapInstalled = true

        audioEngine.prepare()
        do {
            try audioEngine.start()
        } catch {
            // Leave the engine tap-free so a later retry can install a fresh tap.
            removeTapIfInstalled()
            throw error
        }
    }

    /// Returns all buffered samples and empties the buffer.
    ///
    /// The buffer is emptied even when `nil` is returned.
    ///
    /// - Returns: `nil` when less than half a second (8000 samples at 16 kHz)
    ///   was captured.
    func snapshotAndClear() -> MLXArray? {
        let samples = copySamples(clearing: true)
        guard samples.count >= halfSecondOfSamples else { return nil }
        return MLXArray(samples)
    }

    /// Returns a copy of all buffered samples without clearing them.
    ///
    /// - Returns: `nil` when less than one second (16000 samples at 16 kHz)
    ///   has been captured.
    func snapshot() -> MLXArray? {
        let samples = copySamples(clearing: false)
        guard samples.count >= oneSecondOfSamples else { return nil }
        return MLXArray(samples)
    }

    /// Stops recording and returns every buffered sample, emptying the buffer.
    func stopCapturing() -> MLXArray {
        removeTapIfInstalled()
        audioEngine.stop()
        return MLXArray(copySamples(clearing: true))
    }

    // MARK: - Private helpers

    /// Resamples one tap buffer into `outputFormat` and appends the result to
    /// `audioBuffer`. Called from the input tap callback.
    private func appendConverted(
        _ buffer: AVAudioPCMBuffer,
        using converter: AVAudioConverter,
        to outputFormat: AVAudioFormat,
        inputSampleRate: Double
    ) {
        // Round up so the output buffer can hold every resampled frame.
        let capacity = AVAudioFrameCount(
            (targetSampleRate * Double(buffer.frameLength) / inputSampleRate).rounded(.up)
        )
        guard capacity > 0,
              let convertedBuffer = AVAudioPCMBuffer(pcmFormat: outputFormat, frameCapacity: capacity)
        else { return }

        // The converter may invoke the input block several times per convert().
        // Hand over this tap buffer only once and report `.noDataNow` afterwards,
        // so the same audio is never converted (and recorded) twice.
        var didSupplyInput = false
        var conversionError: NSError?
        let status = converter.convert(to: convertedBuffer, error: &conversionError) { _, outStatus in
            if didSupplyInput {
                outStatus.pointee = .noDataNow
                return nil
            }
            didSupplyInput = true
            outStatus.pointee = .haveData
            return buffer
        }

        guard status != .error, conversionError == nil,
              let channelData = convertedBuffer.floatChannelData
        else { return }

        let frameCount = Int(convertedBuffer.frameLength)
        guard frameCount > 0 else { return }
        let samples = UnsafeBufferPointer(start: channelData[0], count: frameCount)

        bufferLock.lock()
        defer { bufferLock.unlock() }
        audioBuffer.append(contentsOf: samples)
    }

    /// Copies the accumulated samples under the lock, optionally emptying the buffer.
    @discardableResult
    private func copySamples(clearing: Bool) -> [Float] {
        bufferLock.lock()
        defer { bufferLock.unlock() }
        let samples = audioBuffer
        if clearing {
            audioBuffer.removeAll()
        }
        return samples
    }

    /// Removes our tap from input bus 0 if one is installed.
    private func removeTapIfInstalled() {
        guard isTapInstalled else { return }
        audioEngine.inputNode.removeTap(onBus: 0)
        isTapInstalled = false
    }

    /// Builds an `NSError` in this type's error domain.
    private static func makeError(_ message: String) -> NSError {
        NSError(
            domain: "AudioCaptureManager",
            code: -1,
            userInfo: [NSLocalizedDescriptionKey: message]
        )
    }
}
|
|
@@ -8,12 +8,13 @@ internal import Tokenizers
|
|
|
8
8
|
class HybridLLM: HybridLLMSpec {
|
|
9
9
|
private var session: ChatSession?
|
|
10
10
|
private var currentTask: Task<String, Error>?
|
|
11
|
-
private var container:
|
|
11
|
+
private var container: ModelContainer?
|
|
12
12
|
private var lastStats: GenerationStats = GenerationStats(
|
|
13
13
|
tokenCount: 0,
|
|
14
14
|
tokensPerSecond: 0,
|
|
15
15
|
timeToFirstToken: 0,
|
|
16
|
-
totalTime: 0
|
|
16
|
+
totalTime: 0,
|
|
17
|
+
toolExecutionTime: 0
|
|
17
18
|
)
|
|
18
19
|
private var modelFactory: ModelFactory = LLMModelFactory.shared
|
|
19
20
|
private var manageHistory: Bool = false
|
|
@@ -57,7 +58,7 @@ class HybridLLM: HybridLLMSpec {
|
|
|
57
58
|
}
|
|
58
59
|
|
|
59
60
|
private func getGPUMemoryUsage() -> String {
|
|
60
|
-
let snapshot =
|
|
61
|
+
let snapshot = Memory.snapshot()
|
|
61
62
|
let allocatedMB = Float(snapshot.activeMemory) / 1024.0 / 1024.0
|
|
62
63
|
let cacheMB = Float(snapshot.cacheMemory) / 1024.0 / 1024.0
|
|
63
64
|
let peakMB = Float(snapshot.peakMemory) / 1024.0 / 1024.0
|
|
@@ -98,7 +99,7 @@ class HybridLLM: HybridLLMSpec {
|
|
|
98
99
|
|
|
99
100
|
return Promise.async { [self] in
|
|
100
101
|
let task = Task { @MainActor in
|
|
101
|
-
|
|
102
|
+
Memory.cacheLimit = 2000000
|
|
102
103
|
|
|
103
104
|
self.currentTask?.cancel()
|
|
104
105
|
self.currentTask = nil
|
|
@@ -106,7 +107,7 @@ class HybridLLM: HybridLLMSpec {
|
|
|
106
107
|
self.container = nil
|
|
107
108
|
self.tools = []
|
|
108
109
|
self.toolSchemas = []
|
|
109
|
-
|
|
110
|
+
Memory.clearCache()
|
|
110
111
|
|
|
111
112
|
let memoryAfterCleanup = self.getMemoryUsage()
|
|
112
113
|
let gpuAfterCleanup = self.getGPUMemoryUsage()
|
|
@@ -175,20 +176,15 @@ class HybridLLM: HybridLLMSpec {
|
|
|
175
176
|
}
|
|
176
177
|
|
|
177
178
|
self.currentTask = task
|
|
179
|
+
defer { self.currentTask = nil }
|
|
178
180
|
|
|
179
|
-
|
|
180
|
-
let result = try await task.value
|
|
181
|
-
self.currentTask = nil
|
|
182
|
-
|
|
183
|
-
if self.manageHistory {
|
|
184
|
-
self.messageHistory.append(LLMMessage(role: "assistant", content: result))
|
|
185
|
-
}
|
|
181
|
+
let result = try await task.value
|
|
186
182
|
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
self.currentTask = nil
|
|
190
|
-
throw error
|
|
183
|
+
if self.manageHistory {
|
|
184
|
+
self.messageHistory.append(LLMMessage(role: "assistant", content: result))
|
|
191
185
|
}
|
|
186
|
+
|
|
187
|
+
return result
|
|
192
188
|
}
|
|
193
189
|
}
|
|
194
190
|
|
|
@@ -199,7 +195,7 @@ class HybridLLM: HybridLLMSpec {
|
|
|
199
195
|
onToken: @escaping (String) -> Void,
|
|
200
196
|
onToolCall: ((String, String) -> Void)?
|
|
201
197
|
) throws -> Promise<String> {
|
|
202
|
-
guard let container
|
|
198
|
+
guard let container else {
|
|
203
199
|
throw LLMError.notLoaded
|
|
204
200
|
}
|
|
205
201
|
|
|
@@ -237,7 +233,8 @@ class HybridLLM: HybridLLMSpec {
|
|
|
237
233
|
tokenCount: Double(tokenCount),
|
|
238
234
|
tokensPerSecond: tokensPerSecond,
|
|
239
235
|
timeToFirstToken: timeToFirstToken,
|
|
240
|
-
totalTime: totalTime
|
|
236
|
+
totalTime: totalTime,
|
|
237
|
+
toolExecutionTime: 0
|
|
241
238
|
)
|
|
242
239
|
|
|
243
240
|
log("Stream complete - \(tokenCount) tokens, \(String(format: "%.1f", tokensPerSecond)) tokens/s")
|
|
@@ -245,39 +242,99 @@ class HybridLLM: HybridLLMSpec {
|
|
|
245
242
|
}
|
|
246
243
|
|
|
247
244
|
self.currentTask = task
|
|
245
|
+
defer { self.currentTask = nil }
|
|
248
246
|
|
|
249
|
-
|
|
250
|
-
let result = try await task.value
|
|
251
|
-
self.currentTask = nil
|
|
252
|
-
|
|
253
|
-
if self.manageHistory {
|
|
254
|
-
self.messageHistory.append(LLMMessage(role: "assistant", content: result))
|
|
255
|
-
}
|
|
247
|
+
let result = try await task.value
|
|
256
248
|
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
self.currentTask = nil
|
|
260
|
-
throw error
|
|
249
|
+
if self.manageHistory {
|
|
250
|
+
self.messageHistory.append(LLMMessage(role: "assistant", content: result))
|
|
261
251
|
}
|
|
252
|
+
|
|
253
|
+
return result
|
|
262
254
|
}
|
|
263
255
|
}
|
|
264
256
|
|
|
265
|
-
|
|
266
|
-
container: ModelContainer,
|
|
257
|
+
/// Streams a response for `prompt`, reporting progress as serialized events
/// through `onEvent`: generation start/end, tokens, thinking segments, and
/// tool-call lifecycle.
///
/// When `manageHistory` is enabled, the user prompt is appended to
/// `messageHistory` before generation and the final output is appended after it.
///
/// - Parameters:
///   - prompt: The user prompt for this turn.
///   - onEvent: Receives each event produced by `StreamEventEmitter`.
/// - Returns: A promise resolving to the full visible output, including text
///   generated after any tool calls.
/// - Throws: `LLMError.notLoaded` if no model is loaded.
func streamWithEvents(
    prompt: String,
    onEvent: @escaping (String) -> Void
) throws -> Promise<String> {
    guard let container else {
        throw LLMError.notLoaded
    }

    return Promise.async { [self] in
        if self.manageHistory {
            self.messageHistory.append(LLMMessage(role: "user", content: prompt))
        }

        let task = Task<String, Error> {
            let startTime = Date()
            var firstTokenTime: Date?
            var mlxTokenCount = 0
            var mlxGenerationTime: Double = 0  // milliseconds
            var toolExecutionTime: Double = 0  // milliseconds
            let emitter = StreamEventEmitter(callback: onEvent)

            emitter.emitGenerationStart()

            let result = try await self.performGenerationWithEvents(
                container: container,
                prompt: prompt,
                toolResults: nil,
                depth: 0,
                emitter: emitter,
                onTokenProcessed: {
                    // Only the first visible token matters here (time-to-first-token).
                    if firstTokenTime == nil {
                        firstTokenTime = Date()
                    }
                },
                onGenerationInfo: { tokens, time in
                    mlxTokenCount += tokens
                    mlxGenerationTime += time
                },
                toolExecutionTime: &toolExecutionTime
            )

            let endTime = Date()
            let totalTime = endTime.timeIntervalSince(startTime) * 1000
            let timeToFirstToken = (firstTokenTime ?? endTime).timeIntervalSince(startTime) * 1000
            let tokensPerSecond = mlxGenerationTime > 0 ? Double(mlxTokenCount) / (mlxGenerationTime / 1000) : 0

            let stats = GenerationStats(
                tokenCount: Double(mlxTokenCount),
                tokensPerSecond: tokensPerSecond,
                timeToFirstToken: timeToFirstToken,
                totalTime: totalTime,
                toolExecutionTime: toolExecutionTime
            )

            self.lastStats = stats
            emitter.emitGenerationEnd(content: result, stats: stats)

            log("StreamWithEvents complete - \(mlxTokenCount) tokens, \(String(format: "%.1f", tokensPerSecond)) tokens/s (tool execution: \(String(format: "%.0f", toolExecutionTime))ms)")
            return result
        }

        self.currentTask = task
        // Clear the slot only if it still holds this call's task. An overlapping
        // newer call may have replaced it, and clearing that task would make it
        // uncancellable.
        defer {
            if self.currentTask == task {
                self.currentTask = nil
            }
        }

        let result = try await task.value

        if self.manageHistory {
            self.messageHistory.append(LLMMessage(role: "assistant", content: result))
        }

        return result
    }
}
|
|
280
332
|
|
|
333
|
+
private func buildChatMessages(
|
|
334
|
+
prompt: String,
|
|
335
|
+
toolResults: [String]?,
|
|
336
|
+
depth: Int
|
|
337
|
+
) -> [Chat.Message] {
|
|
281
338
|
var chat: [Chat.Message] = []
|
|
282
339
|
|
|
283
340
|
if !self.systemPrompt.isEmpty {
|
|
@@ -298,12 +355,202 @@ class HybridLLM: HybridLLMSpec {
|
|
|
298
355
|
chat.append(.user(prompt))
|
|
299
356
|
}
|
|
300
357
|
|
|
301
|
-
if let toolResults
|
|
358
|
+
if let toolResults {
|
|
302
359
|
for result in toolResults {
|
|
303
360
|
chat.append(.tool(result))
|
|
304
361
|
}
|
|
305
362
|
}
|
|
306
363
|
|
|
364
|
+
return chat
|
|
365
|
+
}
|
|
366
|
+
|
|
367
|
+
/// Invokes a registered tool's handler with model-supplied arguments and
/// returns the handler's result serialized as JSON.
///
/// The handler's promise resolves to a second promise, so the result is
/// awaited in two steps.
///
/// - Parameters:
///   - tool: The tool definition whose handler is invoked.
///   - argsDict: Arguments decoded from the model's tool call.
/// - Returns: The handler result encoded with `dictionaryToJson`.
/// - Throws: Any error raised while awaiting either promise.
private func executeToolCall(
    tool: ToolDefinition,
    argsDict: [String: Any]
) async throws -> String {
    let pendingResult = try await tool.handler(self.dictionaryToAnyMap(argsDict)).await()
    let resolvedMap = try await pendingResult.await()
    return dictionaryToJson(self.anyMapToDictionary(resolvedMap))
}
|
|
378
|
+
|
|
379
|
+
/// Runs one generation pass for `streamWithEvents`.
///
/// Forwards visible tokens, thinking segments, and tool-call lifecycle events
/// to `emitter`. When the model requests tools, executes them concurrently and
/// recurses with their results until no tools are requested or
/// `maxToolCallDepth` is reached.
///
/// - Parameters:
///   - container: Loaded model container used to prepare input and generate.
///   - prompt: The user prompt for this turn.
///   - toolResults: JSON results from the previous tool round, or `nil` on the first pass.
///   - depth: Current recursion depth; at `maxToolCallDepth` this returns "" immediately.
///   - emitter: Receives token, thinking, and tool-call events.
///   - onTokenProcessed: Called once per visible token appended to the output.
///   - onGenerationInfo: Called with (generated token count, generation time in ms) per `.info` event.
///   - toolExecutionTime: Accumulates wall-clock milliseconds spent running tool calls.
/// - Returns: This pass's visible output concatenated with any follow-up passes' output.
///   If the task is cancelled, returns the partial output without running pending tools.
/// - Throws: Errors from preparing input, starting generation, or a follow-up pass.
private func performGenerationWithEvents(
    container: ModelContainer,
    prompt: String,
    toolResults: [String]?,
    depth: Int,
    emitter: StreamEventEmitter,
    onTokenProcessed: @escaping () -> Void,
    onGenerationInfo: @escaping (Int, Double) -> Void,
    toolExecutionTime: inout Double
) async throws -> String {
    if depth >= maxToolCallDepth {
        log("Max tool call depth reached (\(maxToolCallDepth))")
        return ""
    }

    var output = ""
    var thinkingMachine = ThinkingStateMachine()
    var pendingToolCalls: [(id: String, tool: ToolDefinition, args: [String: Any], argsJson: String)] = []

    let chat = buildChatMessages(prompt: prompt, toolResults: toolResults, depth: depth)
    let userInput = UserInput(
        chat: chat,
        tools: !self.toolSchemas.isEmpty ? self.toolSchemas : nil
    )

    let lmInput = try await container.prepare(input: userInput)

    let stream = try await container.perform { context in
        let parameters = GenerateParameters(maxTokens: 2048, temperature: 0.7)
        return try MLXLMCommon.generate(
            input: lmInput,
            parameters: parameters,
            context: context
        )
    }

    for await generation in stream {
        if Task.isCancelled { break }

        switch generation {
        case .chunk(let text):
            // Split raw model text into visible tokens and thinking segments.
            let outputs = thinkingMachine.process(token: text)

            for machineOutput in outputs {
                switch machineOutput {
                case .token(let token):
                    output += token
                    emitter.emitToken(token)
                    onTokenProcessed()

                case .thinkingStart:
                    emitter.emitThinkingStart()

                case .thinkingChunk(let chunk):
                    emitter.emitThinkingChunk(chunk)

                case .thinkingEnd(let content):
                    emitter.emitThinkingEnd(content)
                }
            }

        case .toolCall(let toolCall):
            log("Tool call detected: \(toolCall.function.name)")

            guard let tool = self.tools.first(where: { $0.name == toolCall.function.name }) else {
                log("Unknown tool: \(toolCall.function.name)")
                continue
            }

            let toolCallId = UUID().uuidString
            let argsDict = self.convertToolCallArguments(toolCall.function.arguments)
            let argsJson = dictionaryToJson(argsDict)

            emitter.emitToolCallStart(id: toolCallId, name: toolCall.function.name, arguments: argsJson)
            pendingToolCalls.append((id: toolCallId, tool: tool, args: argsDict, argsJson: argsJson))

        case .info(let info):
            log("Generation info: \(info.generationTokenCount) tokens, \(String(format: "%.1f", info.tokensPerSecond)) tokens/s")
            // Derive generation time in milliseconds from tokens / tokens-per-second.
            let generationTime = info.tokensPerSecond > 0 ? Double(info.generationTokenCount) / info.tokensPerSecond * 1000 : 0
            onGenerationInfo(info.generationTokenCount, generationTime)
        }
    }

    // Emit whatever the state machine still holds once the stream has ended.
    let flushOutputs = thinkingMachine.flush()
    for machineOutput in flushOutputs {
        switch machineOutput {
        case .token(let token):
            output += token
            emitter.emitToken(token)
            onTokenProcessed()
        case .thinkingStart:
            emitter.emitThinkingStart()
        case .thinkingChunk(let chunk):
            emitter.emitThinkingChunk(chunk)
        case .thinkingEnd(let content):
            emitter.emitThinkingEnd(content)
        }
    }

    // A cancelled generation must not run tool handlers or start another
    // generation pass. Close out the tool calls already announced via
    // emitToolCallStart so listeners do not wait on them forever.
    if Task.isCancelled {
        if !pendingToolCalls.isEmpty {
            log("Generation cancelled - skipping \(pendingToolCalls.count) tool call(s)")
            for call in pendingToolCalls {
                emitter.emitToolCallFailed(id: call.id, error: "Generation cancelled")
            }
        }
        return output
    }

    if !pendingToolCalls.isEmpty {
        log("Executing \(pendingToolCalls.count) tool call(s)")

        let toolStartTime = Date()

        for call in pendingToolCalls {
            emitter.emitToolCallExecuting(id: call.id)
        }

        // Tools run concurrently; results are written back by index so they keep
        // the order in which the model requested them.
        // NOTE(review): assumes StreamEventEmitter tolerates calls from concurrent
        // child tasks — confirm.
        let allToolResults: [String] = await withTaskGroup(of: (Int, String).self) { group in
            for (index, call) in pendingToolCalls.enumerated() {
                group.addTask { [self] in
                    do {
                        let resultJson = try await self.executeToolCall(tool: call.tool, argsDict: call.args)
                        self.log("Tool result for \(call.tool.name): \(resultJson.prefix(100))...")
                        emitter.emitToolCallCompleted(id: call.id, result: resultJson)
                        return (index, resultJson)
                    } catch {
                        self.log("Tool execution error for \(call.tool.name): \(error)")
                        emitter.emitToolCallFailed(id: call.id, error: error.localizedDescription)
                        return (index, "{\"error\": \"Tool execution failed\"}")
                    }
                }
            }

            var results = Array(repeating: "", count: pendingToolCalls.count)
            for await (index, result) in group {
                results[index] = result
            }
            return results
        }

        toolExecutionTime += Date().timeIntervalSince(toolStartTime) * 1000

        // NOTE(review): appended regardless of `manageHistory`; presumably
        // buildChatMessages reads this history on the follow-up pass — confirm.
        if !output.isEmpty {
            self.messageHistory.append(LLMMessage(role: "assistant", content: output))
        }

        for result in allToolResults {
            self.messageHistory.append(LLMMessage(role: "tool", content: result))
        }

        let continuation = try await self.performGenerationWithEvents(
            container: container,
            prompt: prompt,
            toolResults: allToolResults,
            depth: depth + 1,
            emitter: emitter,
            onTokenProcessed: onTokenProcessed,
            onGenerationInfo: onGenerationInfo,
            toolExecutionTime: &toolExecutionTime
        )

        return output + continuation
    }

    return output
}
|
|
536
|
+
|
|
537
|
+
private func performGeneration(
|
|
538
|
+
container: ModelContainer,
|
|
539
|
+
prompt: String,
|
|
540
|
+
toolResults: [String]?,
|
|
541
|
+
depth: Int,
|
|
542
|
+
onToken: @escaping (String) -> Void,
|
|
543
|
+
onToolCall: @escaping (String, String) -> Void
|
|
544
|
+
) async throws -> String {
|
|
545
|
+
if depth >= maxToolCallDepth {
|
|
546
|
+
log("Max tool call depth reached (\(maxToolCallDepth))")
|
|
547
|
+
return ""
|
|
548
|
+
}
|
|
549
|
+
|
|
550
|
+
var output = ""
|
|
551
|
+
var pendingToolCalls: [(tool: ToolDefinition, args: [String: Any], argsJson: String)] = []
|
|
552
|
+
|
|
553
|
+
let chat = buildChatMessages(prompt: prompt, toolResults: toolResults, depth: depth)
|
|
307
554
|
let userInput = UserInput(
|
|
308
555
|
chat: chat,
|
|
309
556
|
tools: !self.toolSchemas.isEmpty ? self.toolSchemas : nil
|
|
@@ -337,7 +584,7 @@ class HybridLLM: HybridLLMSpec {
|
|
|
337
584
|
}
|
|
338
585
|
|
|
339
586
|
let argsDict = self.convertToolCallArguments(toolCall.function.arguments)
|
|
340
|
-
let argsJson =
|
|
587
|
+
let argsJson = dictionaryToJson(argsDict)
|
|
341
588
|
|
|
342
589
|
pendingToolCalls.append((tool: tool, args: argsDict, argsJson: argsJson))
|
|
343
590
|
onToolCall(toolCall.function.name, argsJson)
|
|
@@ -350,23 +597,25 @@ class HybridLLM: HybridLLMSpec {
|
|
|
350
597
|
if !pendingToolCalls.isEmpty {
|
|
351
598
|
log("Executing \(pendingToolCalls.count) tool call(s)")
|
|
352
599
|
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
allToolResults.append(resultJson)
|
|
366
|
-
} catch {
|
|
367
|
-
log("Tool execution error for \(tool.name): \(error)")
|
|
368
|
-
allToolResults.append("{\"error\": \"Tool execution failed\"}")
|
|
600
|
+
let allToolResults: [String] = await withTaskGroup(of: (Int, String).self) { group in
|
|
601
|
+
for (index, call) in pendingToolCalls.enumerated() {
|
|
602
|
+
group.addTask { [self] in
|
|
603
|
+
do {
|
|
604
|
+
let resultJson = try await self.executeToolCall(tool: call.tool, argsDict: call.args)
|
|
605
|
+
self.log("Tool result for \(call.tool.name): \(resultJson.prefix(100))...")
|
|
606
|
+
return (index, resultJson)
|
|
607
|
+
} catch {
|
|
608
|
+
self.log("Tool execution error for \(call.tool.name): \(error)")
|
|
609
|
+
return (index, "{\"error\": \"Tool execution failed\"}")
|
|
610
|
+
}
|
|
611
|
+
}
|
|
369
612
|
}
|
|
613
|
+
|
|
614
|
+
var results = Array(repeating: "", count: pendingToolCalls.count)
|
|
615
|
+
for await (index, result) in group {
|
|
616
|
+
results[index] = result
|
|
617
|
+
}
|
|
618
|
+
return results
|
|
370
619
|
}
|
|
371
620
|
|
|
372
621
|
if !output.isEmpty {
|
|
@@ -406,14 +655,6 @@ class HybridLLM: HybridLLMSpec {
|
|
|
406
655
|
return result
|
|
407
656
|
}
|
|
408
657
|
|
|
409
|
-
private func dictionaryToJson(_ dict: [String: Any]) -> String {
|
|
410
|
-
guard let data = try? JSONSerialization.data(withJSONObject: dict),
|
|
411
|
-
let json = String(data: data, encoding: .utf8) else {
|
|
412
|
-
return "{}"
|
|
413
|
-
}
|
|
414
|
-
return json
|
|
415
|
-
}
|
|
416
|
-
|
|
417
658
|
private func dictionaryToAnyMap(_ dict: [String: Any]) -> AnyMap {
|
|
418
659
|
let anyMap = AnyMap()
|
|
419
660
|
for (key, value) in dict {
|
|
@@ -470,7 +711,7 @@ class HybridLLM: HybridLLMSpec {
|
|
|
470
711
|
manageHistory = false
|
|
471
712
|
modelId = ""
|
|
472
713
|
|
|
473
|
-
MLX.
|
|
714
|
+
MLX.Memory.clearCache()
|
|
474
715
|
|
|
475
716
|
let memoryAfter = getMemoryUsage()
|
|
476
717
|
let gpuAfter = getGPUMemoryUsage()
|