react-native-nitro-mlx 0.3.0 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. package/MLXReactNative.podspec +7 -1
  2. package/ios/Sources/AudioCaptureManager.swift +110 -0
  3. package/ios/Sources/HybridLLM.swift +309 -68
  4. package/ios/Sources/HybridSTT.swift +202 -0
  5. package/ios/Sources/HybridTTS.swift +145 -0
  6. package/ios/Sources/JSONHelpers.swift +9 -0
  7. package/ios/Sources/ModelDownloader.swift +26 -12
  8. package/ios/Sources/StreamEventEmitter.swift +132 -0
  9. package/ios/Sources/ThinkingStateMachine.swift +206 -0
  10. package/lib/module/index.js +2 -0
  11. package/lib/module/index.js.map +1 -1
  12. package/lib/module/llm.js +39 -1
  13. package/lib/module/llm.js.map +1 -1
  14. package/lib/module/models.js +97 -26
  15. package/lib/module/models.js.map +1 -1
  16. package/lib/module/specs/STT.nitro.js +4 -0
  17. package/lib/module/specs/STT.nitro.js.map +1 -0
  18. package/lib/module/specs/TTS.nitro.js +4 -0
  19. package/lib/module/specs/TTS.nitro.js.map +1 -0
  20. package/lib/module/stt.js +49 -0
  21. package/lib/module/stt.js.map +1 -0
  22. package/lib/module/tts.js +40 -0
  23. package/lib/module/tts.js.map +1 -0
  24. package/lib/typescript/src/index.d.ts +7 -3
  25. package/lib/typescript/src/index.d.ts.map +1 -1
  26. package/lib/typescript/src/llm.d.ts +32 -2
  27. package/lib/typescript/src/llm.d.ts.map +1 -1
  28. package/lib/typescript/src/models.d.ts +13 -4
  29. package/lib/typescript/src/models.d.ts.map +1 -1
  30. package/lib/typescript/src/specs/LLM.nitro.d.ts +49 -4
  31. package/lib/typescript/src/specs/LLM.nitro.d.ts.map +1 -1
  32. package/lib/typescript/src/specs/STT.nitro.d.ts +28 -0
  33. package/lib/typescript/src/specs/STT.nitro.d.ts.map +1 -0
  34. package/lib/typescript/src/specs/TTS.nitro.d.ts +22 -0
  35. package/lib/typescript/src/specs/TTS.nitro.d.ts.map +1 -0
  36. package/lib/typescript/src/stt.d.ts +16 -0
  37. package/lib/typescript/src/stt.d.ts.map +1 -0
  38. package/lib/typescript/src/tts.d.ts +13 -0
  39. package/lib/typescript/src/tts.d.ts.map +1 -0
  40. package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Bridge.cpp +42 -0
  41. package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Bridge.hpp +165 -0
  42. package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Umbrella.hpp +20 -0
  43. package/nitrogen/generated/ios/MLXReactNativeAutolinking.mm +16 -0
  44. package/nitrogen/generated/ios/MLXReactNativeAutolinking.swift +30 -0
  45. package/nitrogen/generated/ios/c++/HybridLLMSpecSwift.hpp +8 -0
  46. package/nitrogen/generated/ios/c++/HybridSTTSpecSwift.cpp +11 -0
  47. package/nitrogen/generated/ios/c++/HybridSTTSpecSwift.hpp +149 -0
  48. package/nitrogen/generated/ios/c++/HybridTTSSpecSwift.cpp +11 -0
  49. package/nitrogen/generated/ios/c++/HybridTTSSpecSwift.hpp +128 -0
  50. package/nitrogen/generated/ios/swift/Func_void_std__shared_ptr_ArrayBuffer_.swift +47 -0
  51. package/nitrogen/generated/ios/swift/GenerationStats.swift +13 -2
  52. package/nitrogen/generated/ios/swift/HybridLLMSpec.swift +1 -0
  53. package/nitrogen/generated/ios/swift/HybridLLMSpec_cxx.swift +24 -0
  54. package/nitrogen/generated/ios/swift/HybridSTTSpec.swift +66 -0
  55. package/nitrogen/generated/ios/swift/HybridSTTSpec_cxx.swift +286 -0
  56. package/nitrogen/generated/ios/swift/HybridTTSSpec.swift +63 -0
  57. package/nitrogen/generated/ios/swift/HybridTTSSpec_cxx.swift +229 -0
  58. package/nitrogen/generated/ios/swift/STTLoadOptions.swift +66 -0
  59. package/nitrogen/generated/ios/swift/TTSGenerateOptions.swift +78 -0
  60. package/nitrogen/generated/ios/swift/TTSLoadOptions.swift +66 -0
  61. package/nitrogen/generated/shared/c++/GenerationStats.hpp +6 -2
  62. package/nitrogen/generated/shared/c++/HybridLLMSpec.cpp +1 -0
  63. package/nitrogen/generated/shared/c++/HybridLLMSpec.hpp +1 -0
  64. package/nitrogen/generated/shared/c++/HybridSTTSpec.cpp +32 -0
  65. package/nitrogen/generated/shared/c++/HybridSTTSpec.hpp +78 -0
  66. package/nitrogen/generated/shared/c++/HybridTTSSpec.cpp +29 -0
  67. package/nitrogen/generated/shared/c++/HybridTTSSpec.hpp +78 -0
  68. package/nitrogen/generated/shared/c++/STTLoadOptions.hpp +76 -0
  69. package/nitrogen/generated/shared/c++/TTSGenerateOptions.hpp +80 -0
  70. package/nitrogen/generated/shared/c++/TTSLoadOptions.hpp +76 -0
  71. package/package.json +8 -4
  72. package/src/index.ts +31 -1
  73. package/src/llm.ts +48 -2
  74. package/src/models.ts +81 -1
  75. package/src/specs/LLM.nitro.ts +74 -4
  76. package/src/specs/STT.nitro.ts +35 -0
  77. package/src/specs/TTS.nitro.ts +30 -0
  78. package/src/stt.ts +67 -0
  79. package/src/tts.ts +60 -0
@@ -24,10 +24,16 @@ Pod::Spec.new do |s|
24
24
 
25
25
  spm_dependency(s,
26
26
  url: "https://github.com/ml-explore/mlx-swift-lm.git",
27
- requirement: {kind: "upToNextMinorVersion", minimumVersion: "2.29.3"},
27
+ requirement: {kind: "upToNextMinorVersion", minimumVersion: "2.30.3"},
28
28
  products: ["MLXLLM", "MLXLMCommon"]
29
29
  )
30
30
 
31
+ spm_dependency(s,
32
+ url: "https://github.com/Blaizzy/mlx-audio-swift.git",
33
+ requirement: {kind: "branch", branch: "main"},
34
+ products: ["MLXAudioTTS", "MLXAudioSTT", "MLXAudioCore"]
35
+ )
36
+
31
37
  s.pod_target_xcconfig = {
32
38
  # C++ compiler flags, mainly for folly.
33
39
  "GCC_PREPROCESSOR_DEFINITIONS" => "$(inherited) FOLLY_NO_CONFIG FOLLY_CFG_NO_COROUTINES"
@@ -0,0 +1,110 @@
1
+ import AVFoundation
2
+ import Foundation
3
+ internal import MLX
4
+
5
+ class AudioCaptureManager {
6
+ private let audioEngine = AVAudioEngine()
7
+ private var audioBuffer: [Float] = []
8
+ private let bufferLock = NSLock()
9
+ private let targetSampleRate: Double = 16000
10
+
11
+ var isCapturing: Bool { audioEngine.isRunning }
12
+
13
+ func startCapturing() async throws {
14
+ let session = AVAudioSession.sharedInstance()
15
+ try session.setCategory(.record, mode: .measurement)
16
+ try session.setActive(true)
17
+
18
+ let inputNode = audioEngine.inputNode
19
+ let inputFormat = inputNode.outputFormat(forBus: 0)
20
+ let outputFormat = AVAudioFormat(
21
+ commonFormat: .pcmFormatFloat32,
22
+ sampleRate: targetSampleRate,
23
+ channels: 1,
24
+ interleaved: false
25
+ )!
26
+
27
+ guard
28
+ let converter = AVAudioConverter(
29
+ from: inputFormat, to: outputFormat)
30
+ else {
31
+ throw NSError(
32
+ domain: "AudioCaptureManager",
33
+ code: -1,
34
+ userInfo: [
35
+ NSLocalizedDescriptionKey:
36
+ "Failed to create audio converter"
37
+ ]
38
+ )
39
+ }
40
+
41
+ bufferLock.lock()
42
+ audioBuffer.removeAll()
43
+ bufferLock.unlock()
44
+
45
+ inputNode.installTap(
46
+ onBus: 0, bufferSize: 4096, format: inputFormat
47
+ ) { [weak self] buffer, _ in
48
+ guard let self else { return }
49
+
50
+ let frameCount = AVAudioFrameCount(
51
+ targetSampleRate * Double(buffer.frameLength)
52
+ / inputFormat.sampleRate
53
+ )
54
+ guard
55
+ let convertedBuffer = AVAudioPCMBuffer(
56
+ pcmFormat: outputFormat, frameCapacity: frameCount)
57
+ else { return }
58
+
59
+ var error: NSError?
60
+ converter.convert(to: convertedBuffer, error: &error) {
61
+ _, outStatus in
62
+ outStatus.pointee = .haveData
63
+ return buffer
64
+ }
65
+
66
+ if error == nil, let channelData = convertedBuffer.floatChannelData {
67
+ let frames = Int(convertedBuffer.frameLength)
68
+ self.bufferLock.lock()
69
+ self.audioBuffer.append(
70
+ contentsOf: UnsafeBufferPointer(
71
+ start: channelData[0], count: frames))
72
+ self.bufferLock.unlock()
73
+ }
74
+ }
75
+
76
+ audioEngine.prepare()
77
+ try audioEngine.start()
78
+ }
79
+
80
+ func snapshotAndClear() -> MLXArray? {
81
+ bufferLock.lock()
82
+ let samples = audioBuffer
83
+ audioBuffer.removeAll()
84
+ bufferLock.unlock()
85
+
86
+ guard samples.count >= 8000 else { return nil }
87
+ return MLXArray(samples)
88
+ }
89
+
90
+ func snapshot() -> MLXArray? {
91
+ bufferLock.lock()
92
+ let samples = audioBuffer
93
+ bufferLock.unlock()
94
+
95
+ guard samples.count >= 16000 else { return nil }
96
+ return MLXArray(samples)
97
+ }
98
+
99
+ func stopCapturing() -> MLXArray {
100
+ audioEngine.inputNode.removeTap(onBus: 0)
101
+ audioEngine.stop()
102
+
103
+ bufferLock.lock()
104
+ let samples = audioBuffer
105
+ audioBuffer.removeAll()
106
+ bufferLock.unlock()
107
+
108
+ return MLXArray(samples)
109
+ }
110
+ }
@@ -8,12 +8,13 @@ internal import Tokenizers
8
8
  class HybridLLM: HybridLLMSpec {
9
9
  private var session: ChatSession?
10
10
  private var currentTask: Task<String, Error>?
11
- private var container: Any?
11
+ private var container: ModelContainer?
12
12
  private var lastStats: GenerationStats = GenerationStats(
13
13
  tokenCount: 0,
14
14
  tokensPerSecond: 0,
15
15
  timeToFirstToken: 0,
16
- totalTime: 0
16
+ totalTime: 0,
17
+ toolExecutionTime: 0
17
18
  )
18
19
  private var modelFactory: ModelFactory = LLMModelFactory.shared
19
20
  private var manageHistory: Bool = false
@@ -57,7 +58,7 @@ class HybridLLM: HybridLLMSpec {
57
58
  }
58
59
 
59
60
  private func getGPUMemoryUsage() -> String {
60
- let snapshot = GPU.snapshot()
61
+ let snapshot = Memory.snapshot()
61
62
  let allocatedMB = Float(snapshot.activeMemory) / 1024.0 / 1024.0
62
63
  let cacheMB = Float(snapshot.cacheMemory) / 1024.0 / 1024.0
63
64
  let peakMB = Float(snapshot.peakMemory) / 1024.0 / 1024.0
@@ -98,7 +99,7 @@ class HybridLLM: HybridLLMSpec {
98
99
 
99
100
  return Promise.async { [self] in
100
101
  let task = Task { @MainActor in
101
- MLX.GPU.set(cacheLimit: 2000000)
102
+ Memory.cacheLimit = 2000000
102
103
 
103
104
  self.currentTask?.cancel()
104
105
  self.currentTask = nil
@@ -106,7 +107,7 @@ class HybridLLM: HybridLLMSpec {
106
107
  self.container = nil
107
108
  self.tools = []
108
109
  self.toolSchemas = []
109
- MLX.GPU.clearCache()
110
+ Memory.clearCache()
110
111
 
111
112
  let memoryAfterCleanup = self.getMemoryUsage()
112
113
  let gpuAfterCleanup = self.getGPUMemoryUsage()
@@ -175,20 +176,15 @@ class HybridLLM: HybridLLMSpec {
175
176
  }
176
177
 
177
178
  self.currentTask = task
179
+ defer { self.currentTask = nil }
178
180
 
179
- do {
180
- let result = try await task.value
181
- self.currentTask = nil
182
-
183
- if self.manageHistory {
184
- self.messageHistory.append(LLMMessage(role: "assistant", content: result))
185
- }
181
+ let result = try await task.value
186
182
 
187
- return result
188
- } catch {
189
- self.currentTask = nil
190
- throw error
183
+ if self.manageHistory {
184
+ self.messageHistory.append(LLMMessage(role: "assistant", content: result))
191
185
  }
186
+
187
+ return result
192
188
  }
193
189
  }
194
190
 
@@ -199,7 +195,7 @@ class HybridLLM: HybridLLMSpec {
199
195
  onToken: @escaping (String) -> Void,
200
196
  onToolCall: ((String, String) -> Void)?
201
197
  ) throws -> Promise<String> {
202
- guard let container = container as? ModelContainer else {
198
+ guard let container else {
203
199
  throw LLMError.notLoaded
204
200
  }
205
201
 
@@ -237,7 +233,8 @@ class HybridLLM: HybridLLMSpec {
237
233
  tokenCount: Double(tokenCount),
238
234
  tokensPerSecond: tokensPerSecond,
239
235
  timeToFirstToken: timeToFirstToken,
240
- totalTime: totalTime
236
+ totalTime: totalTime,
237
+ toolExecutionTime: 0
241
238
  )
242
239
 
243
240
  log("Stream complete - \(tokenCount) tokens, \(String(format: "%.1f", tokensPerSecond)) tokens/s")
@@ -245,39 +242,99 @@ class HybridLLM: HybridLLMSpec {
245
242
  }
246
243
 
247
244
  self.currentTask = task
245
+ defer { self.currentTask = nil }
248
246
 
249
- do {
250
- let result = try await task.value
251
- self.currentTask = nil
252
-
253
- if self.manageHistory {
254
- self.messageHistory.append(LLMMessage(role: "assistant", content: result))
255
- }
247
+ let result = try await task.value
256
248
 
257
- return result
258
- } catch {
259
- self.currentTask = nil
260
- throw error
249
+ if self.manageHistory {
250
+ self.messageHistory.append(LLMMessage(role: "assistant", content: result))
261
251
  }
252
+
253
+ return result
262
254
  }
263
255
  }
264
256
 
265
- private func performGeneration(
266
- container: ModelContainer,
257
+ func streamWithEvents(
267
258
  prompt: String,
268
- toolResults: [String]?,
269
- depth: Int,
270
- onToken: @escaping (String) -> Void,
271
- onToolCall: @escaping (String, String) -> Void
272
- ) async throws -> String {
273
- if depth >= maxToolCallDepth {
274
- log("Max tool call depth reached (\(maxToolCallDepth))")
275
- return ""
259
+ onEvent: @escaping (String) -> Void
260
+ ) throws -> Promise<String> {
261
+ guard let container else {
262
+ throw LLMError.notLoaded
276
263
  }
277
264
 
278
- var output = ""
279
- var pendingToolCalls: [(tool: ToolDefinition, args: [String: Any], argsJson: String)] = []
265
+ return Promise.async { [self] in
266
+ if self.manageHistory {
267
+ self.messageHistory.append(LLMMessage(role: "user", content: prompt))
268
+ }
269
+
270
+ let task = Task<String, Error> {
271
+ let startTime = Date()
272
+ var firstTokenTime: Date?
273
+ var outputTokenCount = 0
274
+ var mlxTokenCount = 0
275
+ var mlxGenerationTime: Double = 0
276
+ var toolExecutionTime: Double = 0
277
+ let emitter = StreamEventEmitter(callback: onEvent)
278
+
279
+ emitter.emitGenerationStart()
280
+
281
+ let result = try await self.performGenerationWithEvents(
282
+ container: container,
283
+ prompt: prompt,
284
+ toolResults: nil,
285
+ depth: 0,
286
+ emitter: emitter,
287
+ onTokenProcessed: {
288
+ if firstTokenTime == nil {
289
+ firstTokenTime = Date()
290
+ }
291
+ outputTokenCount += 1
292
+ },
293
+ onGenerationInfo: { tokens, time in
294
+ mlxTokenCount += tokens
295
+ mlxGenerationTime += time
296
+ },
297
+ toolExecutionTime: &toolExecutionTime
298
+ )
299
+
300
+ let endTime = Date()
301
+ let totalTime = endTime.timeIntervalSince(startTime) * 1000
302
+ let timeToFirstToken = (firstTokenTime ?? endTime).timeIntervalSince(startTime) * 1000
303
+ let tokensPerSecond = mlxGenerationTime > 0 ? Double(mlxTokenCount) / (mlxGenerationTime / 1000) : 0
304
+
305
+ let stats = GenerationStats(
306
+ tokenCount: Double(mlxTokenCount),
307
+ tokensPerSecond: tokensPerSecond,
308
+ timeToFirstToken: timeToFirstToken,
309
+ totalTime: totalTime,
310
+ toolExecutionTime: toolExecutionTime
311
+ )
312
+
313
+ self.lastStats = stats
314
+ emitter.emitGenerationEnd(content: result, stats: stats)
315
+
316
+ log("StreamWithEvents complete - \(mlxTokenCount) tokens, \(String(format: "%.1f", tokensPerSecond)) tokens/s (tool execution: \(String(format: "%.0f", toolExecutionTime))ms)")
317
+ return result
318
+ }
319
+
320
+ self.currentTask = task
321
+ defer { self.currentTask = nil }
322
+
323
+ let result = try await task.value
324
+
325
+ if self.manageHistory {
326
+ self.messageHistory.append(LLMMessage(role: "assistant", content: result))
327
+ }
328
+
329
+ return result
330
+ }
331
+ }
280
332
 
333
+ private func buildChatMessages(
334
+ prompt: String,
335
+ toolResults: [String]?,
336
+ depth: Int
337
+ ) -> [Chat.Message] {
281
338
  var chat: [Chat.Message] = []
282
339
 
283
340
  if !self.systemPrompt.isEmpty {
@@ -298,12 +355,202 @@ class HybridLLM: HybridLLMSpec {
298
355
  chat.append(.user(prompt))
299
356
  }
300
357
 
301
- if let toolResults = toolResults {
358
+ if let toolResults {
302
359
  for result in toolResults {
303
360
  chat.append(.tool(result))
304
361
  }
305
362
  }
306
363
 
364
+ return chat
365
+ }
366
+
367
+ private func executeToolCall(
368
+ tool: ToolDefinition,
369
+ argsDict: [String: Any]
370
+ ) async throws -> String {
371
+ let argsAnyMap = self.dictionaryToAnyMap(argsDict)
372
+ let outerPromise = tool.handler(argsAnyMap)
373
+ let innerPromise = try await outerPromise.await()
374
+ let resultAnyMap = try await innerPromise.await()
375
+ let resultDict = self.anyMapToDictionary(resultAnyMap)
376
+ return dictionaryToJson(resultDict)
377
+ }
378
+
379
+ private func performGenerationWithEvents(
380
+ container: ModelContainer,
381
+ prompt: String,
382
+ toolResults: [String]?,
383
+ depth: Int,
384
+ emitter: StreamEventEmitter,
385
+ onTokenProcessed: @escaping () -> Void,
386
+ onGenerationInfo: @escaping (Int, Double) -> Void,
387
+ toolExecutionTime: inout Double
388
+ ) async throws -> String {
389
+ if depth >= maxToolCallDepth {
390
+ log("Max tool call depth reached (\(maxToolCallDepth))")
391
+ return ""
392
+ }
393
+
394
+ var output = ""
395
+ var thinkingMachine = ThinkingStateMachine()
396
+ var pendingToolCalls: [(id: String, tool: ToolDefinition, args: [String: Any], argsJson: String)] = []
397
+
398
+ let chat = buildChatMessages(prompt: prompt, toolResults: toolResults, depth: depth)
399
+ let userInput = UserInput(
400
+ chat: chat,
401
+ tools: !self.toolSchemas.isEmpty ? self.toolSchemas : nil
402
+ )
403
+
404
+ let lmInput = try await container.prepare(input: userInput)
405
+
406
+ let stream = try await container.perform { context in
407
+ let parameters = GenerateParameters(maxTokens: 2048, temperature: 0.7)
408
+ return try MLXLMCommon.generate(
409
+ input: lmInput,
410
+ parameters: parameters,
411
+ context: context
412
+ )
413
+ }
414
+
415
+ for await generation in stream {
416
+ if Task.isCancelled { break }
417
+
418
+ switch generation {
419
+ case .chunk(let text):
420
+ let outputs = thinkingMachine.process(token: text)
421
+
422
+ for machineOutput in outputs {
423
+ switch machineOutput {
424
+ case .token(let token):
425
+ output += token
426
+ emitter.emitToken(token)
427
+ onTokenProcessed()
428
+
429
+ case .thinkingStart:
430
+ emitter.emitThinkingStart()
431
+
432
+ case .thinkingChunk(let chunk):
433
+ emitter.emitThinkingChunk(chunk)
434
+
435
+ case .thinkingEnd(let content):
436
+ emitter.emitThinkingEnd(content)
437
+ }
438
+ }
439
+
440
+ case .toolCall(let toolCall):
441
+ log("Tool call detected: \(toolCall.function.name)")
442
+
443
+ guard let tool = self.tools.first(where: { $0.name == toolCall.function.name }) else {
444
+ log("Unknown tool: \(toolCall.function.name)")
445
+ continue
446
+ }
447
+
448
+ let toolCallId = UUID().uuidString
449
+ let argsDict = self.convertToolCallArguments(toolCall.function.arguments)
450
+ let argsJson = dictionaryToJson(argsDict)
451
+
452
+ emitter.emitToolCallStart(id: toolCallId, name: toolCall.function.name, arguments: argsJson)
453
+ pendingToolCalls.append((id: toolCallId, tool: tool, args: argsDict, argsJson: argsJson))
454
+
455
+ case .info(let info):
456
+ log("Generation info: \(info.generationTokenCount) tokens, \(String(format: "%.1f", info.tokensPerSecond)) tokens/s")
457
+ let generationTime = info.tokensPerSecond > 0 ? Double(info.generationTokenCount) / info.tokensPerSecond * 1000 : 0
458
+ onGenerationInfo(info.generationTokenCount, generationTime)
459
+ }
460
+ }
461
+
462
+ let flushOutputs = thinkingMachine.flush()
463
+ for machineOutput in flushOutputs {
464
+ switch machineOutput {
465
+ case .token(let token):
466
+ output += token
467
+ emitter.emitToken(token)
468
+ onTokenProcessed()
469
+ case .thinkingStart:
470
+ emitter.emitThinkingStart()
471
+ case .thinkingChunk(let chunk):
472
+ emitter.emitThinkingChunk(chunk)
473
+ case .thinkingEnd(let content):
474
+ emitter.emitThinkingEnd(content)
475
+ }
476
+ }
477
+
478
+ if !pendingToolCalls.isEmpty {
479
+ log("Executing \(pendingToolCalls.count) tool call(s)")
480
+
481
+ let toolStartTime = Date()
482
+
483
+ for call in pendingToolCalls {
484
+ emitter.emitToolCallExecuting(id: call.id)
485
+ }
486
+
487
+ let allToolResults: [String] = await withTaskGroup(of: (Int, String).self) { group in
488
+ for (index, call) in pendingToolCalls.enumerated() {
489
+ group.addTask { [self] in
490
+ do {
491
+ let resultJson = try await self.executeToolCall(tool: call.tool, argsDict: call.args)
492
+ self.log("Tool result for \(call.tool.name): \(resultJson.prefix(100))...")
493
+ emitter.emitToolCallCompleted(id: call.id, result: resultJson)
494
+ return (index, resultJson)
495
+ } catch {
496
+ self.log("Tool execution error for \(call.tool.name): \(error)")
497
+ emitter.emitToolCallFailed(id: call.id, error: error.localizedDescription)
498
+ return (index, "{\"error\": \"Tool execution failed\"}")
499
+ }
500
+ }
501
+ }
502
+
503
+ var results = Array(repeating: "", count: pendingToolCalls.count)
504
+ for await (index, result) in group {
505
+ results[index] = result
506
+ }
507
+ return results
508
+ }
509
+
510
+ toolExecutionTime += Date().timeIntervalSince(toolStartTime) * 1000
511
+
512
+ if !output.isEmpty {
513
+ self.messageHistory.append(LLMMessage(role: "assistant", content: output))
514
+ }
515
+
516
+ for result in allToolResults {
517
+ self.messageHistory.append(LLMMessage(role: "tool", content: result))
518
+ }
519
+
520
+ let continuation = try await self.performGenerationWithEvents(
521
+ container: container,
522
+ prompt: prompt,
523
+ toolResults: allToolResults,
524
+ depth: depth + 1,
525
+ emitter: emitter,
526
+ onTokenProcessed: onTokenProcessed,
527
+ onGenerationInfo: onGenerationInfo,
528
+ toolExecutionTime: &toolExecutionTime
529
+ )
530
+
531
+ return output + continuation
532
+ }
533
+
534
+ return output
535
+ }
536
+
537
+ private func performGeneration(
538
+ container: ModelContainer,
539
+ prompt: String,
540
+ toolResults: [String]?,
541
+ depth: Int,
542
+ onToken: @escaping (String) -> Void,
543
+ onToolCall: @escaping (String, String) -> Void
544
+ ) async throws -> String {
545
+ if depth >= maxToolCallDepth {
546
+ log("Max tool call depth reached (\(maxToolCallDepth))")
547
+ return ""
548
+ }
549
+
550
+ var output = ""
551
+ var pendingToolCalls: [(tool: ToolDefinition, args: [String: Any], argsJson: String)] = []
552
+
553
+ let chat = buildChatMessages(prompt: prompt, toolResults: toolResults, depth: depth)
307
554
  let userInput = UserInput(
308
555
  chat: chat,
309
556
  tools: !self.toolSchemas.isEmpty ? self.toolSchemas : nil
@@ -337,7 +584,7 @@ class HybridLLM: HybridLLMSpec {
337
584
  }
338
585
 
339
586
  let argsDict = self.convertToolCallArguments(toolCall.function.arguments)
340
- let argsJson = self.dictionaryToJson(argsDict)
587
+ let argsJson = dictionaryToJson(argsDict)
341
588
 
342
589
  pendingToolCalls.append((tool: tool, args: argsDict, argsJson: argsJson))
343
590
  onToolCall(toolCall.function.name, argsJson)
@@ -350,23 +597,25 @@ class HybridLLM: HybridLLMSpec {
350
597
  if !pendingToolCalls.isEmpty {
351
598
  log("Executing \(pendingToolCalls.count) tool call(s)")
352
599
 
353
- var allToolResults: [String] = []
354
-
355
- for (tool, argsDict, _) in pendingToolCalls {
356
- do {
357
- let argsAnyMap = self.dictionaryToAnyMap(argsDict)
358
- let outerPromise = tool.handler(argsAnyMap)
359
- let innerPromise = try await outerPromise.await()
360
- let resultAnyMap = try await innerPromise.await()
361
- let resultDict = self.anyMapToDictionary(resultAnyMap)
362
- let resultJson = self.dictionaryToJson(resultDict)
363
-
364
- log("Tool result for \(tool.name): \(resultJson.prefix(100))...")
365
- allToolResults.append(resultJson)
366
- } catch {
367
- log("Tool execution error for \(tool.name): \(error)")
368
- allToolResults.append("{\"error\": \"Tool execution failed\"}")
600
+ let allToolResults: [String] = await withTaskGroup(of: (Int, String).self) { group in
601
+ for (index, call) in pendingToolCalls.enumerated() {
602
+ group.addTask { [self] in
603
+ do {
604
+ let resultJson = try await self.executeToolCall(tool: call.tool, argsDict: call.args)
605
+ self.log("Tool result for \(call.tool.name): \(resultJson.prefix(100))...")
606
+ return (index, resultJson)
607
+ } catch {
608
+ self.log("Tool execution error for \(call.tool.name): \(error)")
609
+ return (index, "{\"error\": \"Tool execution failed\"}")
610
+ }
611
+ }
369
612
  }
613
+
614
+ var results = Array(repeating: "", count: pendingToolCalls.count)
615
+ for await (index, result) in group {
616
+ results[index] = result
617
+ }
618
+ return results
370
619
  }
371
620
 
372
621
  if !output.isEmpty {
@@ -406,14 +655,6 @@ class HybridLLM: HybridLLMSpec {
406
655
  return result
407
656
  }
408
657
 
409
- private func dictionaryToJson(_ dict: [String: Any]) -> String {
410
- guard let data = try? JSONSerialization.data(withJSONObject: dict),
411
- let json = String(data: data, encoding: .utf8) else {
412
- return "{}"
413
- }
414
- return json
415
- }
416
-
417
658
  private func dictionaryToAnyMap(_ dict: [String: Any]) -> AnyMap {
418
659
  let anyMap = AnyMap()
419
660
  for (key, value) in dict {
@@ -470,7 +711,7 @@ class HybridLLM: HybridLLMSpec {
470
711
  manageHistory = false
471
712
  modelId = ""
472
713
 
473
- MLX.GPU.clearCache()
714
+ MLX.Memory.clearCache()
474
715
 
475
716
  let memoryAfter = getMemoryUsage()
476
717
  let gpuAfter = getGPUMemoryUsage()