react-native-nitro-mlx 0.2.2 → 0.4.0

This diff shows the published contents of these two package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (110)
  1. package/MLXReactNative.podspec +7 -1
  2. package/ios/Sources/AudioCaptureManager.swift +110 -0
  3. package/ios/Sources/HybridLLM.swift +518 -42
  4. package/ios/Sources/HybridSTT.swift +202 -0
  5. package/ios/Sources/HybridTTS.swift +145 -0
  6. package/ios/Sources/JSONHelpers.swift +9 -0
  7. package/ios/Sources/ModelDownloader.swift +26 -12
  8. package/ios/Sources/StreamEventEmitter.swift +132 -0
  9. package/ios/Sources/ThinkingStateMachine.swift +206 -0
  10. package/lib/module/index.js +3 -0
  11. package/lib/module/index.js.map +1 -1
  12. package/lib/module/llm.js +72 -4
  13. package/lib/module/llm.js.map +1 -1
  14. package/lib/module/models.js +97 -26
  15. package/lib/module/models.js.map +1 -1
  16. package/lib/module/specs/STT.nitro.js +4 -0
  17. package/lib/module/specs/STT.nitro.js.map +1 -0
  18. package/lib/module/specs/TTS.nitro.js +4 -0
  19. package/lib/module/specs/TTS.nitro.js.map +1 -0
  20. package/lib/module/stt.js +49 -0
  21. package/lib/module/stt.js.map +1 -0
  22. package/lib/module/tool-utils.js +56 -0
  23. package/lib/module/tool-utils.js.map +1 -0
  24. package/lib/module/tts.js +40 -0
  25. package/lib/module/tts.js.map +1 -0
  26. package/lib/typescript/src/index.d.ts +8 -3
  27. package/lib/typescript/src/index.d.ts.map +1 -1
  28. package/lib/typescript/src/llm.d.ts +46 -4
  29. package/lib/typescript/src/llm.d.ts.map +1 -1
  30. package/lib/typescript/src/models.d.ts +13 -4
  31. package/lib/typescript/src/models.d.ts.map +1 -1
  32. package/lib/typescript/src/specs/LLM.nitro.d.ts +79 -7
  33. package/lib/typescript/src/specs/LLM.nitro.d.ts.map +1 -1
  34. package/lib/typescript/src/specs/STT.nitro.d.ts +28 -0
  35. package/lib/typescript/src/specs/STT.nitro.d.ts.map +1 -0
  36. package/lib/typescript/src/specs/TTS.nitro.d.ts +22 -0
  37. package/lib/typescript/src/specs/TTS.nitro.d.ts.map +1 -0
  38. package/lib/typescript/src/stt.d.ts +16 -0
  39. package/lib/typescript/src/stt.d.ts.map +1 -0
  40. package/lib/typescript/src/tool-utils.d.ts +13 -0
  41. package/lib/typescript/src/tool-utils.d.ts.map +1 -0
  42. package/lib/typescript/src/tts.d.ts +13 -0
  43. package/lib/typescript/src/tts.d.ts.map +1 -0
  44. package/nitrogen/generated/ios/MLXReactNative+autolinking.rb +1 -1
  45. package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Bridge.cpp +76 -1
  46. package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Bridge.hpp +338 -1
  47. package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Umbrella.hpp +28 -1
  48. package/nitrogen/generated/ios/MLXReactNativeAutolinking.mm +17 -1
  49. package/nitrogen/generated/ios/MLXReactNativeAutolinking.swift +31 -1
  50. package/nitrogen/generated/ios/c++/HybridLLMSpecSwift.cpp +1 -1
  51. package/nitrogen/generated/ios/c++/HybridLLMSpecSwift.hpp +18 -3
  52. package/nitrogen/generated/ios/c++/HybridModelManagerSpecSwift.cpp +1 -1
  53. package/nitrogen/generated/ios/c++/HybridModelManagerSpecSwift.hpp +1 -1
  54. package/nitrogen/generated/ios/c++/HybridSTTSpecSwift.cpp +11 -0
  55. package/nitrogen/generated/ios/c++/HybridSTTSpecSwift.hpp +149 -0
  56. package/nitrogen/generated/ios/c++/HybridTTSSpecSwift.cpp +11 -0
  57. package/nitrogen/generated/ios/c++/HybridTTSSpecSwift.hpp +128 -0
  58. package/nitrogen/generated/ios/swift/Func_std__shared_ptr_Promise_std__shared_ptr_Promise_std__shared_ptr_AnyMap______std__shared_ptr_AnyMap_.swift +62 -0
  59. package/nitrogen/generated/ios/swift/Func_void.swift +1 -1
  60. package/nitrogen/generated/ios/swift/Func_void_bool.swift +1 -1
  61. package/nitrogen/generated/ios/swift/Func_void_double.swift +1 -1
  62. package/nitrogen/generated/ios/swift/Func_void_std__exception_ptr.swift +1 -1
  63. package/nitrogen/generated/ios/swift/Func_void_std__shared_ptr_AnyMap_.swift +47 -0
  64. package/nitrogen/generated/ios/swift/Func_void_std__shared_ptr_ArrayBuffer_.swift +47 -0
  65. package/nitrogen/generated/ios/swift/Func_void_std__shared_ptr_Promise_std__shared_ptr_AnyMap___.swift +67 -0
  66. package/nitrogen/generated/ios/swift/Func_void_std__string.swift +1 -1
  67. package/nitrogen/generated/ios/swift/Func_void_std__string_std__string.swift +47 -0
  68. package/nitrogen/generated/ios/swift/Func_void_std__vector_std__string_.swift +1 -1
  69. package/nitrogen/generated/ios/swift/GenerationStats.swift +14 -3
  70. package/nitrogen/generated/ios/swift/HybridLLMSpec.swift +3 -2
  71. package/nitrogen/generated/ios/swift/HybridLLMSpec_cxx.swift +38 -2
  72. package/nitrogen/generated/ios/swift/HybridModelManagerSpec.swift +1 -1
  73. package/nitrogen/generated/ios/swift/HybridModelManagerSpec_cxx.swift +1 -1
  74. package/nitrogen/generated/ios/swift/HybridSTTSpec.swift +66 -0
  75. package/nitrogen/generated/ios/swift/HybridSTTSpec_cxx.swift +286 -0
  76. package/nitrogen/generated/ios/swift/HybridTTSSpec.swift +63 -0
  77. package/nitrogen/generated/ios/swift/HybridTTSSpec_cxx.swift +229 -0
  78. package/nitrogen/generated/ios/swift/LLMLoadOptions.swift +44 -2
  79. package/nitrogen/generated/ios/swift/LLMMessage.swift +1 -1
  80. package/nitrogen/generated/ios/swift/STTLoadOptions.swift +66 -0
  81. package/nitrogen/generated/ios/swift/TTSGenerateOptions.swift +78 -0
  82. package/nitrogen/generated/ios/swift/TTSLoadOptions.swift +66 -0
  83. package/nitrogen/generated/ios/swift/ToolDefinition.swift +113 -0
  84. package/nitrogen/generated/ios/swift/ToolParameter.swift +69 -0
  85. package/nitrogen/generated/shared/c++/GenerationStats.hpp +7 -3
  86. package/nitrogen/generated/shared/c++/HybridLLMSpec.cpp +2 -1
  87. package/nitrogen/generated/shared/c++/HybridLLMSpec.hpp +3 -2
  88. package/nitrogen/generated/shared/c++/HybridModelManagerSpec.cpp +1 -1
  89. package/nitrogen/generated/shared/c++/HybridModelManagerSpec.hpp +1 -1
  90. package/nitrogen/generated/shared/c++/HybridSTTSpec.cpp +32 -0
  91. package/nitrogen/generated/shared/c++/HybridSTTSpec.hpp +78 -0
  92. package/nitrogen/generated/shared/c++/HybridTTSSpec.cpp +29 -0
  93. package/nitrogen/generated/shared/c++/HybridTTSSpec.hpp +78 -0
  94. package/nitrogen/generated/shared/c++/LLMLoadOptions.hpp +10 -3
  95. package/nitrogen/generated/shared/c++/LLMMessage.hpp +1 -1
  96. package/nitrogen/generated/shared/c++/STTLoadOptions.hpp +76 -0
  97. package/nitrogen/generated/shared/c++/TTSGenerateOptions.hpp +80 -0
  98. package/nitrogen/generated/shared/c++/TTSLoadOptions.hpp +76 -0
  99. package/nitrogen/generated/shared/c++/ToolDefinition.hpp +93 -0
  100. package/nitrogen/generated/shared/c++/ToolParameter.hpp +87 -0
  101. package/package.json +13 -8
  102. package/src/index.ts +40 -3
  103. package/src/llm.ts +90 -5
  104. package/src/models.ts +81 -1
  105. package/src/specs/LLM.nitro.ts +111 -7
  106. package/src/specs/STT.nitro.ts +35 -0
  107. package/src/specs/TTS.nitro.ts +30 -0
  108. package/src/stt.ts +67 -0
  109. package/src/tool-utils.ts +74 -0
  110. package/src/tts.ts +60 -0
package/ios/Sources/HybridLLM.swift
@@ -3,22 +3,27 @@ import NitroModules
 internal import MLX
 internal import MLXLLM
 internal import MLXLMCommon
+internal import Tokenizers
 
 class HybridLLM: HybridLLMSpec {
     private var session: ChatSession?
     private var currentTask: Task<String, Error>?
-    private var container: Any?
+    private var container: ModelContainer?
     private var lastStats: GenerationStats = GenerationStats(
         tokenCount: 0,
         tokensPerSecond: 0,
         timeToFirstToken: 0,
-        totalTime: 0
+        totalTime: 0,
+        toolExecutionTime: 0
     )
     private var modelFactory: ModelFactory = LLMModelFactory.shared
     private var manageHistory: Bool = false
     private var messageHistory: [LLMMessage] = []
     private var loadTask: Task<Void, Error>?
 
+    private var tools: [ToolDefinition] = []
+    private var toolSchemas: [ToolSpec] = []
+
     var isLoaded: Bool { session != nil }
     var isGenerating: Bool { currentTask != nil }
     var modelId: String = ""
@@ -53,7 +58,7 @@ class HybridLLM: HybridLLMSpec {
     }
 
     private func getGPUMemoryUsage() -> String {
-        let snapshot = GPU.snapshot()
+        let snapshot = Memory.snapshot()
         let allocatedMB = Float(snapshot.activeMemory) / 1024.0 / 1024.0
         let cacheMB = Float(snapshot.cacheMemory) / 1024.0 / 1024.0
         let peakMB = Float(snapshot.peakMemory) / 1024.0 / 1024.0
@@ -61,18 +66,48 @@
                      allocatedMB, cacheMB, peakMB)
     }
 
+    private func buildToolSchema(from tool: ToolDefinition) -> ToolSpec {
+        var properties: [String: [String: Any]] = [:]
+        var required: [String] = []
+
+        for param in tool.parameters {
+            properties[param.name] = [
+                "type": param.type,
+                "description": param.description
+            ]
+            if param.required {
+                required.append(param.name)
+            }
+        }
+
+        return [
+            "type": "function",
+            "function": [
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": [
+                    "type": "object",
+                    "properties": properties,
+                    "required": required
+                ]
+            ]
+        ] as ToolSpec
+    }
+
     func load(modelId: String, options: LLMLoadOptions?) throws -> Promise<Void> {
         self.loadTask?.cancel()
 
         return Promise.async { [self] in
             let task = Task { @MainActor in
-                MLX.GPU.set(cacheLimit: 2000000)
+                Memory.cacheLimit = 2000000
 
                 self.currentTask?.cancel()
                 self.currentTask = nil
                 self.session = nil
                 self.container = nil
-                MLX.GPU.clearCache()
+                self.tools = []
+                self.toolSchemas = []
+                Memory.clearCache()
 
                 let memoryAfterCleanup = self.getMemoryUsage()
                 let gpuAfterCleanup = self.getGPUMemoryUsage()
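
Note on buildToolSchema: it converts each ToolDefinition arriving over the bridge into the OpenAI-style function schema that this file later passes to UserInput as a ToolSpec. As a worked example (the tool name, description, and parameter are illustrative, not part of the package), a get_weather tool with one required string parameter would serialize to the equivalent of this Swift literal:

    let weatherSchema: [String: Any] = [
        "type": "function",
        "function": [
            "name": "get_weather",                      // hypothetical tool name
            "description": "Look up current weather",   // hypothetical description
            "parameters": [
                "type": "object",
                "properties": [
                    "city": ["type": "string", "description": "City to query"]
                ],
                "required": ["city"]
            ]
        ]
    ]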
@@ -94,6 +129,12 @@
                 let gpuAfterContainer = self.getGPUMemoryUsage()
                 log("Model loaded - Host: \(memoryAfterContainer), GPU: \(gpuAfterContainer)")
 
+                if let jsTools = options?.tools {
+                    self.tools = jsTools
+                    self.toolSchemas = jsTools.map { self.buildToolSchema(from: $0) }
+                    log("Loaded \(self.tools.count) tools: \(self.tools.map { $0.name })")
+                }
+
                 let additionalContextDict: [String: Any]? = if let messages = options?.additionalContext {
                     ["messages": messages.map { ["role": $0.role, "content": $0.content] }]
                 } else {
@@ -135,25 +176,26 @@
             }
 
             self.currentTask = task
+            defer { self.currentTask = nil }
 
-            do {
-                let result = try await task.value
-                self.currentTask = nil
+            let result = try await task.value
 
-                if self.manageHistory {
-                    self.messageHistory.append(LLMMessage(role: "assistant", content: result))
-                }
-
-                return result
-            } catch {
-                self.currentTask = nil
-                throw error
+            if self.manageHistory {
+                self.messageHistory.append(LLMMessage(role: "assistant", content: result))
             }
+
+            return result
         }
     }
 
-    func stream(prompt: String, onToken: @escaping (String) -> Void) throws -> Promise<String> {
-        guard let session = session else {
+    private let maxToolCallDepth = 10
+
+    func stream(
+        prompt: String,
+        onToken: @escaping (String) -> Void,
+        onToolCall: ((String, String) -> Void)?
+    ) throws -> Promise<String> {
+        guard let container else {
            throw LLMError.notLoaded
         }
 
@@ -163,22 +205,24 @@
             }
 
             let task = Task<String, Error> {
-                var result = ""
-                var tokenCount = 0
                 let startTime = Date()
                 var firstTokenTime: Date?
+                var tokenCount = 0
 
-                log("Streaming response for: \(prompt.prefix(50))...")
-                for try await chunk in session.streamResponse(to: prompt) {
-                    if Task.isCancelled { break }
-
-                    if firstTokenTime == nil {
-                        firstTokenTime = Date()
-                    }
-                    tokenCount += 1
-                    result += chunk
-                    onToken(chunk)
-                }
+                let result = try await self.performGeneration(
+                    container: container,
+                    prompt: prompt,
+                    toolResults: nil,
+                    depth: 0,
+                    onToken: { token in
+                        if firstTokenTime == nil {
+                            firstTokenTime = Date()
+                        }
+                        tokenCount += 1
+                        onToken(token)
+                    },
+                    onToolCall: onToolCall ?? { _, _ in }
+                )
 
                 let endTime = Date()
                 let totalTime = endTime.timeIntervalSince(startTime) * 1000
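
Note: stream now guards on the retained ModelContainer instead of the ChatSession and delegates generation to performGeneration (added further below), threading token counting through the onToken wrapper. A minimal call-site sketch, assuming an already-loaded HybridLLM instance (the prompt and closure bodies are illustrative):

    let promise = try llm.stream(
        prompt: "What is the weather in Oslo?",
        onToken: { token in
            print(token, terminator: "")               // incremental text chunks
        },
        onToolCall: { name, argsJson in
            print("\n[tool call] \(name) \(argsJson)") // fired when the model requests a tool
        }
    )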
@@ -189,7 +233,8 @@
                     tokenCount: Double(tokenCount),
                     tokensPerSecond: tokensPerSecond,
                     timeToFirstToken: timeToFirstToken,
-                    totalTime: totalTime
+                    totalTime: totalTime,
+                    toolExecutionTime: 0
                 )
 
                 log("Stream complete - \(tokenCount) tokens, \(String(format: "%.1f", tokensPerSecond)) tokens/s")
@@ -197,21 +242,450 @@
             }
 
             self.currentTask = task
+            defer { self.currentTask = nil }
 
-            do {
-                let result = try await task.value
-                self.currentTask = nil
+            let result = try await task.value
 
-                if self.manageHistory {
-                    self.messageHistory.append(LLMMessage(role: "assistant", content: result))
-                }
+            if self.manageHistory {
+                self.messageHistory.append(LLMMessage(role: "assistant", content: result))
+            }
+
+            return result
+        }
+    }
 
+    func streamWithEvents(
+        prompt: String,
+        onEvent: @escaping (String) -> Void
+    ) throws -> Promise<String> {
+        guard let container else {
+            throw LLMError.notLoaded
+        }
+
+        return Promise.async { [self] in
+            if self.manageHistory {
+                self.messageHistory.append(LLMMessage(role: "user", content: prompt))
+            }
+
+            let task = Task<String, Error> {
+                let startTime = Date()
+                var firstTokenTime: Date?
+                var outputTokenCount = 0
+                var mlxTokenCount = 0
+                var mlxGenerationTime: Double = 0
+                var toolExecutionTime: Double = 0
+                let emitter = StreamEventEmitter(callback: onEvent)
+
+                emitter.emitGenerationStart()
+
+                let result = try await self.performGenerationWithEvents(
+                    container: container,
+                    prompt: prompt,
+                    toolResults: nil,
+                    depth: 0,
+                    emitter: emitter,
+                    onTokenProcessed: {
+                        if firstTokenTime == nil {
+                            firstTokenTime = Date()
+                        }
+                        outputTokenCount += 1
+                    },
+                    onGenerationInfo: { tokens, time in
+                        mlxTokenCount += tokens
+                        mlxGenerationTime += time
+                    },
+                    toolExecutionTime: &toolExecutionTime
+                )
+
+                let endTime = Date()
+                let totalTime = endTime.timeIntervalSince(startTime) * 1000
+                let timeToFirstToken = (firstTokenTime ?? endTime).timeIntervalSince(startTime) * 1000
+                let tokensPerSecond = mlxGenerationTime > 0 ? Double(mlxTokenCount) / (mlxGenerationTime / 1000) : 0
+
+                let stats = GenerationStats(
+                    tokenCount: Double(mlxTokenCount),
+                    tokensPerSecond: tokensPerSecond,
+                    timeToFirstToken: timeToFirstToken,
+                    totalTime: totalTime,
+                    toolExecutionTime: toolExecutionTime
+                )
+
+                self.lastStats = stats
+                emitter.emitGenerationEnd(content: result, stats: stats)
+
+                log("StreamWithEvents complete - \(mlxTokenCount) tokens, \(String(format: "%.1f", tokensPerSecond)) tokens/s (tool execution: \(String(format: "%.0f", toolExecutionTime))ms)")
                 return result
-            } catch {
-                self.currentTask = nil
-                throw error
+            }
+
+            self.currentTask = task
+            defer { self.currentTask = nil }
+
+            let result = try await task.value
+
+            if self.manageHistory {
+                self.messageHistory.append(LLMMessage(role: "assistant", content: result))
+            }
+
+            return result
+        }
+    }
+
+    private func buildChatMessages(
+        prompt: String,
+        toolResults: [String]?,
+        depth: Int
+    ) -> [Chat.Message] {
+        var chat: [Chat.Message] = []
+
+        if !self.systemPrompt.isEmpty {
+            chat.append(.system(self.systemPrompt))
+        }
+
+        for msg in self.messageHistory {
+            switch msg.role {
+            case "user": chat.append(.user(msg.content))
+            case "assistant": chat.append(.assistant(msg.content))
+            case "system": chat.append(.system(msg.content))
+            case "tool": chat.append(.tool(msg.content))
+            default: break
+            }
+        }
+
+        if depth == 0 {
+            chat.append(.user(prompt))
+        }
+
+        if let toolResults {
+            for result in toolResults {
+                chat.append(.tool(result))
+            }
+        }
+
+        return chat
+    }
+
+    private func executeToolCall(
+        tool: ToolDefinition,
+        argsDict: [String: Any]
+    ) async throws -> String {
+        let argsAnyMap = self.dictionaryToAnyMap(argsDict)
+        let outerPromise = tool.handler(argsAnyMap)
+        let innerPromise = try await outerPromise.await()
+        let resultAnyMap = try await innerPromise.await()
+        let resultDict = self.anyMapToDictionary(resultAnyMap)
+        return dictionaryToJson(resultDict)
+    }
+
+    private func performGenerationWithEvents(
+        container: ModelContainer,
+        prompt: String,
+        toolResults: [String]?,
+        depth: Int,
+        emitter: StreamEventEmitter,
+        onTokenProcessed: @escaping () -> Void,
+        onGenerationInfo: @escaping (Int, Double) -> Void,
+        toolExecutionTime: inout Double
+    ) async throws -> String {
+        if depth >= maxToolCallDepth {
+            log("Max tool call depth reached (\(maxToolCallDepth))")
+            return ""
+        }
+
+        var output = ""
+        var thinkingMachine = ThinkingStateMachine()
+        var pendingToolCalls: [(id: String, tool: ToolDefinition, args: [String: Any], argsJson: String)] = []
+
+        let chat = buildChatMessages(prompt: prompt, toolResults: toolResults, depth: depth)
+        let userInput = UserInput(
+            chat: chat,
+            tools: !self.toolSchemas.isEmpty ? self.toolSchemas : nil
+        )
+
+        let lmInput = try await container.prepare(input: userInput)
+
+        let stream = try await container.perform { context in
+            let parameters = GenerateParameters(maxTokens: 2048, temperature: 0.7)
+            return try MLXLMCommon.generate(
+                input: lmInput,
+                parameters: parameters,
+                context: context
+            )
+        }
+
+        for await generation in stream {
+            if Task.isCancelled { break }
+
+            switch generation {
+            case .chunk(let text):
+                let outputs = thinkingMachine.process(token: text)
+
+                for machineOutput in outputs {
+                    switch machineOutput {
+                    case .token(let token):
+                        output += token
+                        emitter.emitToken(token)
+                        onTokenProcessed()
+
+                    case .thinkingStart:
+                        emitter.emitThinkingStart()
+
+                    case .thinkingChunk(let chunk):
+                        emitter.emitThinkingChunk(chunk)
+
+                    case .thinkingEnd(let content):
+                        emitter.emitThinkingEnd(content)
+                    }
+                }
+
+            case .toolCall(let toolCall):
+                log("Tool call detected: \(toolCall.function.name)")
+
+                guard let tool = self.tools.first(where: { $0.name == toolCall.function.name }) else {
+                    log("Unknown tool: \(toolCall.function.name)")
+                    continue
+                }
+
+                let toolCallId = UUID().uuidString
+                let argsDict = self.convertToolCallArguments(toolCall.function.arguments)
+                let argsJson = dictionaryToJson(argsDict)
+
+                emitter.emitToolCallStart(id: toolCallId, name: toolCall.function.name, arguments: argsJson)
+                pendingToolCalls.append((id: toolCallId, tool: tool, args: argsDict, argsJson: argsJson))
+
+            case .info(let info):
+                log("Generation info: \(info.generationTokenCount) tokens, \(String(format: "%.1f", info.tokensPerSecond)) tokens/s")
+                let generationTime = info.tokensPerSecond > 0 ? Double(info.generationTokenCount) / info.tokensPerSecond * 1000 : 0
+                onGenerationInfo(info.generationTokenCount, generationTime)
+            }
+        }
+
+        let flushOutputs = thinkingMachine.flush()
+        for machineOutput in flushOutputs {
+            switch machineOutput {
+            case .token(let token):
+                output += token
+                emitter.emitToken(token)
+                onTokenProcessed()
+            case .thinkingStart:
+                emitter.emitThinkingStart()
+            case .thinkingChunk(let chunk):
+                emitter.emitThinkingChunk(chunk)
+            case .thinkingEnd(let content):
+                emitter.emitThinkingEnd(content)
+            }
+        }
+
+        if !pendingToolCalls.isEmpty {
+            log("Executing \(pendingToolCalls.count) tool call(s)")
+
+            let toolStartTime = Date()
+
+            for call in pendingToolCalls {
+                emitter.emitToolCallExecuting(id: call.id)
+            }
+
+            let allToolResults: [String] = await withTaskGroup(of: (Int, String).self) { group in
+                for (index, call) in pendingToolCalls.enumerated() {
+                    group.addTask { [self] in
+                        do {
+                            let resultJson = try await self.executeToolCall(tool: call.tool, argsDict: call.args)
+                            self.log("Tool result for \(call.tool.name): \(resultJson.prefix(100))...")
+                            emitter.emitToolCallCompleted(id: call.id, result: resultJson)
+                            return (index, resultJson)
+                        } catch {
+                            self.log("Tool execution error for \(call.tool.name): \(error)")
+                            emitter.emitToolCallFailed(id: call.id, error: error.localizedDescription)
+                            return (index, "{\"error\": \"Tool execution failed\"}")
+                        }
+                    }
+                }
+
+                var results = Array(repeating: "", count: pendingToolCalls.count)
+                for await (index, result) in group {
+                    results[index] = result
+                }
+                return results
+            }
+
+            toolExecutionTime += Date().timeIntervalSince(toolStartTime) * 1000
+
+            if !output.isEmpty {
+                self.messageHistory.append(LLMMessage(role: "assistant", content: output))
+            }
+
+            for result in allToolResults {
+                self.messageHistory.append(LLMMessage(role: "tool", content: result))
+            }
+
+            let continuation = try await self.performGenerationWithEvents(
+                container: container,
+                prompt: prompt,
+                toolResults: allToolResults,
+                depth: depth + 1,
+                emitter: emitter,
+                onTokenProcessed: onTokenProcessed,
+                onGenerationInfo: onGenerationInfo,
+                toolExecutionTime: &toolExecutionTime
+            )
+
+            return output + continuation
+        }
+
+        return output
+    }
+
+    private func performGeneration(
+        container: ModelContainer,
+        prompt: String,
+        toolResults: [String]?,
+        depth: Int,
+        onToken: @escaping (String) -> Void,
+        onToolCall: @escaping (String, String) -> Void
+    ) async throws -> String {
+        if depth >= maxToolCallDepth {
+            log("Max tool call depth reached (\(maxToolCallDepth))")
+            return ""
+        }
+
+        var output = ""
+        var pendingToolCalls: [(tool: ToolDefinition, args: [String: Any], argsJson: String)] = []
+
+        let chat = buildChatMessages(prompt: prompt, toolResults: toolResults, depth: depth)
+        let userInput = UserInput(
+            chat: chat,
+            tools: !self.toolSchemas.isEmpty ? self.toolSchemas : nil
+        )
+
+        let lmInput = try await container.prepare(input: userInput)
+
+        let stream = try await container.perform { context in
+            let parameters = GenerateParameters(maxTokens: 2048, temperature: 0.7)
+            return try MLXLMCommon.generate(
+                input: lmInput,
+                parameters: parameters,
+                context: context
+            )
+        }
+
+        for await generation in stream {
+            if Task.isCancelled { break }
+
+            switch generation {
+            case .chunk(let text):
+                output += text
+                onToken(text)
+
+            case .toolCall(let toolCall):
+                log("Tool call detected: \(toolCall.function.name)")
+
+                guard let tool = self.tools.first(where: { $0.name == toolCall.function.name }) else {
+                    log("Unknown tool: \(toolCall.function.name)")
+                    continue
+                }
+
+                let argsDict = self.convertToolCallArguments(toolCall.function.arguments)
+                let argsJson = dictionaryToJson(argsDict)
+
+                pendingToolCalls.append((tool: tool, args: argsDict, argsJson: argsJson))
+                onToolCall(toolCall.function.name, argsJson)
+
+            case .info(let info):
+                log("Generation info: \(info.generationTokenCount) tokens, \(String(format: "%.1f", info.tokensPerSecond)) tokens/s")
+            }
+        }
+
+        if !pendingToolCalls.isEmpty {
+            log("Executing \(pendingToolCalls.count) tool call(s)")
+
+            let allToolResults: [String] = await withTaskGroup(of: (Int, String).self) { group in
+                for (index, call) in pendingToolCalls.enumerated() {
+                    group.addTask { [self] in
+                        do {
+                            let resultJson = try await self.executeToolCall(tool: call.tool, argsDict: call.args)
+                            self.log("Tool result for \(call.tool.name): \(resultJson.prefix(100))...")
+                            return (index, resultJson)
+                        } catch {
+                            self.log("Tool execution error for \(call.tool.name): \(error)")
+                            return (index, "{\"error\": \"Tool execution failed\"}")
+                        }
+                    }
+                }
+
+                var results = Array(repeating: "", count: pendingToolCalls.count)
+                for await (index, result) in group {
+                    results[index] = result
+                }
+                return results
+            }
+
+            if !output.isEmpty {
+                self.messageHistory.append(LLMMessage(role: "assistant", content: output))
+            }
+
+            if depth == 0 {
+                self.messageHistory.append(LLMMessage(role: "user", content: prompt))
+            }
+
+            for result in allToolResults {
+                self.messageHistory.append(LLMMessage(role: "tool", content: result))
+            }
+
+            onToken("\u{200B}")
+
+            let continuation = try await self.performGeneration(
+                container: container,
+                prompt: prompt,
+                toolResults: allToolResults,
+                depth: depth + 1,
+                onToken: onToken,
+                onToolCall: onToolCall
+            )
+
+            return output + continuation
+        }
+
+        return output
+    }
+
+    private func convertToolCallArguments(_ arguments: [String: JSONValue]) -> [String: Any] {
+        var result: [String: Any] = [:]
+        for (key, value) in arguments {
+            result[key] = value.anyValue
+        }
+        return result
+    }
+
+    private func dictionaryToAnyMap(_ dict: [String: Any]) -> AnyMap {
+        let anyMap = AnyMap()
+        for (key, value) in dict {
+            switch value {
+            case let stringValue as String:
+                anyMap.setString(key: key, value: stringValue)
+            case let doubleValue as Double:
+                anyMap.setDouble(key: key, value: doubleValue)
+            case let intValue as Int:
+                anyMap.setDouble(key: key, value: Double(intValue))
+            case let boolValue as Bool:
+                anyMap.setBoolean(key: key, value: boolValue)
+            default:
+                anyMap.setString(key: key, value: String(describing: value))
+            }
+        }
+        return anyMap
+    }
+
+    private func anyMapToDictionary(_ anyMap: AnyMap) -> [String: Any] {
+        var dict: [String: Any] = [:]
+        for key in anyMap.getAllKeys() {
+            if anyMap.isString(key: key) {
+                dict[key] = anyMap.getString(key: key)
+            } else if anyMap.isDouble(key: key) {
+                dict[key] = anyMap.getDouble(key: key)
+            } else if anyMap.isBool(key: key) {
+                dict[key] = anyMap.getBoolean(key: key)
             }
         }
+        return dict
     }
 
     func stop() throws {
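
Reading the control flow above, performGenerationWithEvents drives StreamEventEmitter in a fixed order per turn: emitGenerationStart, then emitToken and the emitThinking* trio as chunks pass through ThinkingStateMachine, then emitToolCallStart, emitToolCallExecuting, and emitToolCallCompleted (or emitToolCallFailed) for each pending call, recursing up to maxToolCallDepth before emitGenerationEnd delivers the final content and stats.

One subtlety in the AnyMap helpers: dictionaryToAnyMap stores Int values via setDouble, so integers come back as Double after a round trip. A sketch with illustrative values:

    let args: [String: Any] = ["city": "Oslo", "days": 3, "metric": true]
    let map = dictionaryToAnyMap(args)   // "days" is stored as setDouble(3.0)
    let back = anyMapToDictionary(map)   // back["days"] is Double(3.0), not Int(3)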
@@ -231,11 +705,13 @@
         currentTask = nil
         session = nil
         container = nil
+        tools = []
+        toolSchemas = []
         messageHistory = []
         manageHistory = false
         modelId = ""
 
-        MLX.GPU.clearCache()
+        MLX.Memory.clearCache()
 
         let memoryAfter = getMemoryUsage()
         let gpuAfter = getGPUMemoryUsage()
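
Finally, note the handler shape implied by executeToolCall above: the tool handler is awaited twice, i.e. it returns a promise of a promise of an AnyMap, matching the newly generated Func_std__shared_ptr_Promise_std__shared_ptr_Promise_std__shared_ptr_AnyMap______std__shared_ptr_AnyMap_.swift bridge in the file list. A minimal Swift-side sketch under that assumption, using only the Promise.async and AnyMap calls already seen in this file (the payload key is illustrative):

    let handler: (AnyMap) -> Promise<Promise<AnyMap>> = { _ in
        Promise.async {
            Promise.async {
                let result = AnyMap()
                result.setString(key: "forecast", value: "sunny") // illustrative payload
                return result
            }
        }
    }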