@inferrlm/react-native-mlx 0.2.0-inferrlm.1 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. package/MLXReactNative.podspec +7 -1
  2. package/ios/Sources/AudioCaptureManager.swift +110 -0
  3. package/ios/Sources/HybridLLM.swift +562 -74
  4. package/ios/Sources/HybridSTT.swift +202 -0
  5. package/ios/Sources/HybridTTS.swift +145 -0
  6. package/ios/Sources/JSONHelpers.swift +9 -0
  7. package/ios/Sources/ModelDownloader.swift +26 -12
  8. package/ios/Sources/StreamEventEmitter.swift +132 -0
  9. package/ios/Sources/ThinkingStateMachine.swift +206 -0
  10. package/nitrogen/generated/ios/MLXReactNative+autolinking.rb +1 -1
  11. package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Bridge.cpp +76 -1
  12. package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Bridge.hpp +338 -1
  13. package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Umbrella.hpp +28 -1
  14. package/nitrogen/generated/ios/MLXReactNativeAutolinking.mm +17 -1
  15. package/nitrogen/generated/ios/MLXReactNativeAutolinking.swift +31 -1
  16. package/nitrogen/generated/ios/c++/HybridLLMSpecSwift.cpp +1 -1
  17. package/nitrogen/generated/ios/c++/HybridLLMSpecSwift.hpp +18 -3
  18. package/nitrogen/generated/ios/c++/HybridModelManagerSpecSwift.cpp +1 -1
  19. package/nitrogen/generated/ios/c++/HybridModelManagerSpecSwift.hpp +1 -1
  20. package/nitrogen/generated/ios/c++/HybridSTTSpecSwift.cpp +11 -0
  21. package/nitrogen/generated/ios/c++/HybridSTTSpecSwift.hpp +149 -0
  22. package/nitrogen/generated/ios/c++/HybridTTSSpecSwift.cpp +11 -0
  23. package/nitrogen/generated/ios/c++/HybridTTSSpecSwift.hpp +128 -0
  24. package/nitrogen/generated/ios/swift/Func_std__shared_ptr_Promise_std__shared_ptr_Promise_std__shared_ptr_AnyMap______std__shared_ptr_AnyMap_.swift +62 -0
  25. package/nitrogen/generated/ios/swift/Func_void.swift +1 -1
  26. package/nitrogen/generated/ios/swift/Func_void_bool.swift +1 -1
  27. package/nitrogen/generated/ios/swift/Func_void_double.swift +1 -1
  28. package/nitrogen/generated/ios/swift/Func_void_std__exception_ptr.swift +1 -1
  29. package/nitrogen/generated/ios/swift/Func_void_std__shared_ptr_AnyMap_.swift +47 -0
  30. package/nitrogen/generated/ios/swift/Func_void_std__shared_ptr_ArrayBuffer_.swift +47 -0
  31. package/nitrogen/generated/ios/swift/Func_void_std__shared_ptr_Promise_std__shared_ptr_AnyMap___.swift +67 -0
  32. package/nitrogen/generated/ios/swift/Func_void_std__string.swift +1 -1
  33. package/nitrogen/generated/ios/swift/Func_void_std__string_std__string.swift +47 -0
  34. package/nitrogen/generated/ios/swift/Func_void_std__vector_std__string_.swift +1 -1
  35. package/nitrogen/generated/ios/swift/GenerationStats.swift +14 -3
  36. package/nitrogen/generated/ios/swift/HybridLLMSpec.swift +3 -2
  37. package/nitrogen/generated/ios/swift/HybridLLMSpec_cxx.swift +38 -2
  38. package/nitrogen/generated/ios/swift/HybridModelManagerSpec.swift +1 -1
  39. package/nitrogen/generated/ios/swift/HybridModelManagerSpec_cxx.swift +1 -1
  40. package/nitrogen/generated/ios/swift/HybridSTTSpec.swift +66 -0
  41. package/nitrogen/generated/ios/swift/HybridSTTSpec_cxx.swift +286 -0
  42. package/nitrogen/generated/ios/swift/HybridTTSSpec.swift +63 -0
  43. package/nitrogen/generated/ios/swift/HybridTTSSpec_cxx.swift +229 -0
  44. package/nitrogen/generated/ios/swift/LLMLoadOptions.swift +44 -2
  45. package/nitrogen/generated/ios/swift/LLMMessage.swift +1 -1
  46. package/nitrogen/generated/ios/swift/STTLoadOptions.swift +66 -0
  47. package/nitrogen/generated/ios/swift/TTSGenerateOptions.swift +78 -0
  48. package/nitrogen/generated/ios/swift/TTSLoadOptions.swift +66 -0
  49. package/nitrogen/generated/ios/swift/ToolDefinition.swift +113 -0
  50. package/nitrogen/generated/ios/swift/ToolParameter.swift +69 -0
  51. package/nitrogen/generated/shared/c++/GenerationStats.hpp +7 -3
  52. package/nitrogen/generated/shared/c++/HybridLLMSpec.cpp +2 -1
  53. package/nitrogen/generated/shared/c++/HybridLLMSpec.hpp +3 -2
  54. package/nitrogen/generated/shared/c++/HybridModelManagerSpec.cpp +1 -1
  55. package/nitrogen/generated/shared/c++/HybridModelManagerSpec.hpp +1 -1
  56. package/nitrogen/generated/shared/c++/HybridSTTSpec.cpp +32 -0
  57. package/nitrogen/generated/shared/c++/HybridSTTSpec.hpp +78 -0
  58. package/nitrogen/generated/shared/c++/HybridTTSSpec.cpp +29 -0
  59. package/nitrogen/generated/shared/c++/HybridTTSSpec.hpp +78 -0
  60. package/nitrogen/generated/shared/c++/LLMLoadOptions.hpp +10 -3
  61. package/nitrogen/generated/shared/c++/LLMMessage.hpp +1 -1
  62. package/nitrogen/generated/shared/c++/STTLoadOptions.hpp +76 -0
  63. package/nitrogen/generated/shared/c++/TTSGenerateOptions.hpp +80 -0
  64. package/nitrogen/generated/shared/c++/TTSLoadOptions.hpp +76 -0
  65. package/nitrogen/generated/shared/c++/ToolDefinition.hpp +93 -0
  66. package/nitrogen/generated/shared/c++/ToolParameter.hpp +87 -0
  67. package/package.json +13 -8
  68. package/src/index.ts +48 -4
  69. package/src/llm.ts +90 -5
  70. package/src/models.ts +347 -0
  71. package/src/specs/LLM.nitro.ts +111 -7
  72. package/src/specs/STT.nitro.ts +35 -0
  73. package/src/specs/TTS.nitro.ts +30 -0
  74. package/src/stt.ts +67 -0
  75. package/src/tool-utils.ts +74 -0
  76. package/src/tts.ts +60 -0
  77. package/lib/module/index.js +0 -6
  78. package/lib/module/index.js.map +0 -1
  79. package/lib/module/llm.js +0 -125
  80. package/lib/module/llm.js.map +0 -1
  81. package/lib/module/modelManager.js +0 -79
  82. package/lib/module/modelManager.js.map +0 -1
  83. package/lib/module/models.js +0 -41
  84. package/lib/module/models.js.map +0 -1
  85. package/lib/module/package.json +0 -1
  86. package/lib/module/specs/LLM.nitro.js +0 -4
  87. package/lib/module/specs/LLM.nitro.js.map +0 -1
  88. package/lib/module/specs/ModelManager.nitro.js +0 -4
  89. package/lib/module/specs/ModelManager.nitro.js.map +0 -1
  90. package/lib/typescript/package.json +0 -1
  91. package/lib/typescript/src/index.d.ts +0 -6
  92. package/lib/typescript/src/index.d.ts.map +0 -1
  93. package/lib/typescript/src/llm.d.ts +0 -87
  94. package/lib/typescript/src/llm.d.ts.map +0 -1
  95. package/lib/typescript/src/modelManager.d.ts +0 -53
  96. package/lib/typescript/src/modelManager.d.ts.map +0 -1
  97. package/lib/typescript/src/models.d.ts +0 -29
  98. package/lib/typescript/src/models.d.ts.map +0 -1
  99. package/lib/typescript/src/specs/LLM.nitro.d.ts +0 -88
  100. package/lib/typescript/src/specs/LLM.nitro.d.ts.map +0 -1
  101. package/lib/typescript/src/specs/ModelManager.nitro.d.ts +0 -41
  102. package/lib/typescript/src/specs/ModelManager.nitro.d.ts.map +0 -1
package/ios/Sources/HybridLLM.swift
@@ -3,20 +3,26 @@ import NitroModules
 internal import MLX
 internal import MLXLLM
 internal import MLXLMCommon
+internal import Tokenizers
 
 class HybridLLM: HybridLLMSpec {
     private var session: ChatSession?
     private var currentTask: Task<String, Error>?
-    private var container: Any?
+    private var container: ModelContainer?
     private var lastStats: GenerationStats = GenerationStats(
         tokenCount: 0,
         tokensPerSecond: 0,
         timeToFirstToken: 0,
-        totalTime: 0
+        totalTime: 0,
+        toolExecutionTime: 0
     )
     private var modelFactory: ModelFactory = LLMModelFactory.shared
     private var manageHistory: Bool = false
     private var messageHistory: [LLMMessage] = []
+    private var loadTask: Task<Void, Error>?
+
+    private var tools: [ToolDefinition] = []
+    private var toolSchemas: [ToolSpec] = []
 
     var isLoaded: Bool { session != nil }
     var isGenerating: Bool { currentTask != nil }
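The retyped container (previously `Any?`, now `ModelContainer?`) is what lets the new generation paths further down call MLXLMCommon directly instead of routing everything through `ChatSession`. A minimal sketch of that access pattern, assuming this file's imports and the `ModelContainer` API (`prepare(input:)`, `perform(_:)`) the later hunks use; parameter values are illustrative:

    // Sketch only; mirrors the container usage in the hunks below.
    func generateOnce(container: ModelContainer, input: UserInput) async throws -> String {
        let lmInput = try await container.prepare(input: input)
        let stream = try await container.perform { context in
            try MLXLMCommon.generate(
                input: lmInput,
                parameters: GenerateParameters(maxTokens: 256, temperature: 0.7),
                context: context
            )
        }
        var text = ""
        for await generation in stream {
            if case .chunk(let chunk) = generation { text += chunk }
        }
        return text
    }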
@@ -52,7 +58,7 @@ class HybridLLM: HybridLLMSpec {
     }
 
     private func getGPUMemoryUsage() -> String {
-        let snapshot = GPU.snapshot()
+        let snapshot = Memory.snapshot()
         let allocatedMB = Float(snapshot.activeMemory) / 1024.0 / 1024.0
         let cacheMB = Float(snapshot.cacheMemory) / 1024.0 / 1024.0
         let peakMB = Float(snapshot.peakMemory) / 1024.0 / 1024.0
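This hunk, together with the matching changes in `load()` and `unload()` below, migrates from the `MLX.GPU.*` helpers to `MLX.Memory.*`. The snapshot fields are byte counts, hence the `/ 1024.0 / 1024.0` conversions above. A sketch of the renamed API as this diff uses it:

    // Assumes the MLX.Memory API adopted in this diff; all values are bytes.
    Memory.cacheLimit = 2000000        // same limit load() sets below
    let snapshot = Memory.snapshot()
    let activeMB = Float(snapshot.activeMemory) / 1024.0 / 1024.0
    Memory.clearCache()                // drop cached buffers, as unload() does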
@@ -60,51 +66,95 @@ class HybridLLM: HybridLLMSpec {
                       allocatedMB, cacheMB, peakMB)
     }
 
+    private func buildToolSchema(from tool: ToolDefinition) -> ToolSpec {
+        var properties: [String: [String: Any]] = [:]
+        var required: [String] = []
+
+        for param in tool.parameters {
+            properties[param.name] = [
+                "type": param.type,
+                "description": param.description
+            ]
+            if param.required {
+                required.append(param.name)
+            }
+        }
+
+        return [
+            "type": "function",
+            "function": [
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": [
+                    "type": "object",
+                    "properties": properties,
+                    "required": required
+                ]
+            ]
+        ] as ToolSpec
+    }
+
     func load(modelId: String, options: LLMLoadOptions?) throws -> Promise<Void> {
-        return Promise.async { [self] in
-            MLX.GPU.set(cacheLimit: 2000000)
+        self.loadTask?.cancel()
 
-            self.currentTask?.cancel()
-            self.currentTask = nil
-            self.session = nil
-            self.container = nil
-            MLX.GPU.clearCache()
+        return Promise.async { [self] in
+            let task = Task { @MainActor in
+                Memory.cacheLimit = 2000000
 
-            let memoryAfterCleanup = self.getMemoryUsage()
-            let gpuAfterCleanup = self.getGPUMemoryUsage()
-            log("After cleanup - Host: \(memoryAfterCleanup), GPU: \(gpuAfterCleanup)")
+                self.currentTask?.cancel()
+                self.currentTask = nil
+                self.session = nil
+                self.container = nil
+                self.tools = []
+                self.toolSchemas = []
+                Memory.clearCache()
+
+                let memoryAfterCleanup = self.getMemoryUsage()
+                let gpuAfterCleanup = self.getGPUMemoryUsage()
+                log("After cleanup - Host: \(memoryAfterCleanup), GPU: \(gpuAfterCleanup)")
+
+                let modelDir = await ModelDownloader.shared.getModelDirectory(modelId: modelId)
+                log("Loading from directory: \(modelDir.path)")
+
+                let config = ModelConfiguration(directory: modelDir)
+                let loadedContainer = try await self.modelFactory.loadContainer(
+                    configuration: config
+                ) { progress in
+                    options?.onProgress?(progress.fractionCompleted)
+                }
 
-            let modelDir = await ModelDownloader.shared.getModelDirectory(modelId: modelId)
-            log("Loading from directory: \(modelDir.path)")
+                try Task.checkCancellation()
 
-            let config = ModelConfiguration(directory: modelDir)
-            let loadedContainer = try await modelFactory.loadContainer(
-                configuration: config
-            ) { progress in
-                options?.onProgress?(progress.fractionCompleted)
-            }
+                let memoryAfterContainer = self.getMemoryUsage()
+                let gpuAfterContainer = self.getGPUMemoryUsage()
+                log("Model loaded - Host: \(memoryAfterContainer), GPU: \(gpuAfterContainer)")
 
-            let memoryAfterContainer = self.getMemoryUsage()
-            let gpuAfterContainer = self.getGPUMemoryUsage()
-            log("Model loaded - Host: \(memoryAfterContainer), GPU: \(gpuAfterContainer)")
+                if let jsTools = options?.tools {
+                    self.tools = jsTools
+                    self.toolSchemas = jsTools.map { self.buildToolSchema(from: $0) }
+                    log("Loaded \(self.tools.count) tools: \(self.tools.map { $0.name })")
+                }
 
-            // Convert [LLMMessage]? to [String: Any]?
-            let additionalContextDict: [String: Any]? = if let messages = options?.additionalContext {
-                ["messages": messages.map { ["role": $0.role, "content": $0.content] }]
-            } else {
-                nil
-            }
+                let additionalContextDict: [String: Any]? = if let messages = options?.additionalContext {
+                    ["messages": messages.map { ["role": $0.role, "content": $0.content] }]
+                } else {
+                    nil
+                }
 
-            self.container = loadedContainer
-            self.session = ChatSession(loadedContainer, instructions: self.systemPrompt, additionalContext: additionalContextDict)
-            self.modelId = modelId
+                self.container = loadedContainer
+                self.session = ChatSession(loadedContainer, instructions: self.systemPrompt, additionalContext: additionalContextDict)
+                self.modelId = modelId
 
-            self.manageHistory = options?.manageHistory ?? false
-            self.messageHistory = options?.additionalContext ?? []
+                self.manageHistory = options?.manageHistory ?? false
+                self.messageHistory = options?.additionalContext ?? []
 
-            if self.manageHistory {
-                log("History management enabled with \(self.messageHistory.count) initial messages")
+                if self.manageHistory {
+                    log("History management enabled with \(self.messageHistory.count) initial messages")
+                }
             }
+
+            self.loadTask = task
+            try await task.value
         }
     }
 
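`buildToolSchema(from:)` turns each JS-registered `ToolDefinition` into the OpenAI-style function schema passed to `UserInput(chat:tools:)` during generation. For a hypothetical `get_weather` tool with one required string parameter, the resulting `ToolSpec` would be:

    // Hypothetical tool; the dictionary shape follows buildToolSchema above.
    let schema: [String: Any] = [
        "type": "function",
        "function": [
            "name": "get_weather",
            "description": "Look up the current weather for a city",
            "parameters": [
                "type": "object",
                "properties": [
                    "city": ["type": "string", "description": "City name"]
                ],
                "required": ["city"]
            ]
        ]
    ]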
@@ -126,25 +176,26 @@ class HybridLLM: HybridLLMSpec {
             }
 
             self.currentTask = task
+            defer { self.currentTask = nil }
 
-            do {
-                let result = try await task.value
-                self.currentTask = nil
+            let result = try await task.value
 
-                if self.manageHistory {
-                    self.messageHistory.append(LLMMessage(role: "assistant", content: result))
-                }
-
-                return result
-            } catch {
-                self.currentTask = nil
-                throw error
+            if self.manageHistory {
+                self.messageHistory.append(LLMMessage(role: "assistant", content: result))
             }
+
+            return result
         }
     }
 
-    func stream(prompt: String, onToken: @escaping (String) -> Void) throws -> Promise<String> {
-        guard let session = session else {
+    private let maxToolCallDepth = 10
+
+    func stream(
+        prompt: String,
+        onToken: @escaping (String) -> Void,
+        onToolCall: ((String, String) -> Void)?
+    ) throws -> Promise<String> {
+        guard let container else {
             throw LLMError.notLoaded
         }
 
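In `generate`, `stream`, and `streamWithEvents` alike, the do/catch whose only job was resetting `currentTask` on both the success and failure paths is collapsed into one `defer`, which runs whether `task.value` returns or throws. The pattern in isolation:

    // One defer replaces the duplicated cleanup in both branches of do/catch.
    final class TaskHolder {
        var currentTask: Task<String, Error>?

        func run(_ task: Task<String, Error>) async throws -> String {
            currentTask = task
            defer { currentTask = nil }   // executes on return and on throw
            return try await task.value
        }
    }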
@@ -154,22 +205,24 @@ class HybridLLM: HybridLLMSpec {
             }
 
             let task = Task<String, Error> {
-                var result = ""
-                var tokenCount = 0
                 let startTime = Date()
                 var firstTokenTime: Date?
+                var tokenCount = 0
 
-                log("Streaming response for: \(prompt.prefix(50))...")
-                for try await chunk in session.streamResponse(to: prompt) {
-                    if Task.isCancelled { break }
-
-                    if firstTokenTime == nil {
-                        firstTokenTime = Date()
-                    }
-                    tokenCount += 1
-                    result += chunk
-                    onToken(chunk)
-                }
+                let result = try await self.performGeneration(
+                    container: container,
+                    prompt: prompt,
+                    toolResults: nil,
+                    depth: 0,
+                    onToken: { token in
+                        if firstTokenTime == nil {
+                            firstTokenTime = Date()
+                        }
+                        tokenCount += 1
+                        onToken(token)
+                    },
+                    onToolCall: onToolCall ?? { _, _ in }
+                )
 
                 let endTime = Date()
                 let totalTime = endTime.timeIntervalSince(startTime) * 1000
@@ -180,7 +233,8 @@ class HybridLLM: HybridLLMSpec {
                     tokenCount: Double(tokenCount),
                     tokensPerSecond: tokensPerSecond,
                     timeToFirstToken: timeToFirstToken,
-                    totalTime: totalTime
+                    totalTime: totalTime,
+                    toolExecutionTime: 0
                 )
 
                 log("Stream complete - \(tokenCount) tokens, \(String(format: "%.1f", tokensPerSecond)) tokens/s")
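All `GenerationStats` timing fields are in milliseconds; the plain `stream()` path reports `toolExecutionTime: 0`, and only `streamWithEvents` below measures it. Worked numbers for the tokens/s formula (the ms-based form appears in `streamWithEvents`):

    // 100 tokens generated over 2.5 seconds:
    let tokenCount = 100.0
    let totalTime = 2_500.0                                  // ms
    let tokensPerSecond = tokenCount / (totalTime / 1_000)   // 40.0 tokens/s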
@@ -188,21 +242,450 @@ class HybridLLM: HybridLLMSpec {
             }
 
             self.currentTask = task
+            defer { self.currentTask = nil }
 
-            do {
-                let result = try await task.value
-                self.currentTask = nil
+            let result = try await task.value
 
-                if self.manageHistory {
-                    self.messageHistory.append(LLMMessage(role: "assistant", content: result))
-                }
+            if self.manageHistory {
+                self.messageHistory.append(LLMMessage(role: "assistant", content: result))
+            }
+
+            return result
+        }
+    }
+
+    func streamWithEvents(
+        prompt: String,
+        onEvent: @escaping (String) -> Void
+    ) throws -> Promise<String> {
+        guard let container else {
+            throw LLMError.notLoaded
+        }
 
+        return Promise.async { [self] in
+            if self.manageHistory {
+                self.messageHistory.append(LLMMessage(role: "user", content: prompt))
+            }
+
+            let task = Task<String, Error> {
+                let startTime = Date()
+                var firstTokenTime: Date?
+                var outputTokenCount = 0
+                var mlxTokenCount = 0
+                var mlxGenerationTime: Double = 0
+                var toolExecutionTime: Double = 0
+                let emitter = StreamEventEmitter(callback: onEvent)
+
+                emitter.emitGenerationStart()
+
+                let result = try await self.performGenerationWithEvents(
+                    container: container,
+                    prompt: prompt,
+                    toolResults: nil,
+                    depth: 0,
+                    emitter: emitter,
+                    onTokenProcessed: {
+                        if firstTokenTime == nil {
+                            firstTokenTime = Date()
+                        }
+                        outputTokenCount += 1
+                    },
+                    onGenerationInfo: { tokens, time in
+                        mlxTokenCount += tokens
+                        mlxGenerationTime += time
+                    },
+                    toolExecutionTime: &toolExecutionTime
+                )
+
+                let endTime = Date()
+                let totalTime = endTime.timeIntervalSince(startTime) * 1000
+                let timeToFirstToken = (firstTokenTime ?? endTime).timeIntervalSince(startTime) * 1000
+                let tokensPerSecond = mlxGenerationTime > 0 ? Double(mlxTokenCount) / (mlxGenerationTime / 1000) : 0
+
+                let stats = GenerationStats(
+                    tokenCount: Double(mlxTokenCount),
+                    tokensPerSecond: tokensPerSecond,
+                    timeToFirstToken: timeToFirstToken,
+                    totalTime: totalTime,
+                    toolExecutionTime: toolExecutionTime
+                )
+
+                self.lastStats = stats
+                emitter.emitGenerationEnd(content: result, stats: stats)
+
+                log("StreamWithEvents complete - \(mlxTokenCount) tokens, \(String(format: "%.1f", tokensPerSecond)) tokens/s (tool execution: \(String(format: "%.0f", toolExecutionTime))ms)")
                 return result
-            } catch {
-                self.currentTask = nil
-                throw error
+            }
+
+            self.currentTask = task
+            defer { self.currentTask = nil }
+
+            let result = try await task.value
+
+            if self.manageHistory {
+                self.messageHistory.append(LLMMessage(role: "assistant", content: result))
+            }
+
+            return result
+        }
+    }
+
+    private func buildChatMessages(
+        prompt: String,
+        toolResults: [String]?,
+        depth: Int
+    ) -> [Chat.Message] {
+        var chat: [Chat.Message] = []
+
+        if !self.systemPrompt.isEmpty {
+            chat.append(.system(self.systemPrompt))
+        }
+
+        for msg in self.messageHistory {
+            switch msg.role {
+            case "user": chat.append(.user(msg.content))
+            case "assistant": chat.append(.assistant(msg.content))
+            case "system": chat.append(.system(msg.content))
+            case "tool": chat.append(.tool(msg.content))
+            default: break
+            }
+        }
+
+        if depth == 0 {
+            chat.append(.user(prompt))
+        }
+
+        if let toolResults {
+            for result in toolResults {
+                chat.append(.tool(result))
             }
         }
+
+        return chat
+    }
+
+    private func executeToolCall(
+        tool: ToolDefinition,
+        argsDict: [String: Any]
+    ) async throws -> String {
+        let argsAnyMap = self.dictionaryToAnyMap(argsDict)
+        let outerPromise = tool.handler(argsAnyMap)
+        let innerPromise = try await outerPromise.await()
+        let resultAnyMap = try await innerPromise.await()
+        let resultDict = self.anyMapToDictionary(resultAnyMap)
+        return dictionaryToJson(resultDict)
+    }
+
+    private func performGenerationWithEvents(
+        container: ModelContainer,
+        prompt: String,
+        toolResults: [String]?,
+        depth: Int,
+        emitter: StreamEventEmitter,
+        onTokenProcessed: @escaping () -> Void,
+        onGenerationInfo: @escaping (Int, Double) -> Void,
+        toolExecutionTime: inout Double
+    ) async throws -> String {
+        if depth >= maxToolCallDepth {
+            log("Max tool call depth reached (\(maxToolCallDepth))")
+            return ""
+        }
+
+        var output = ""
+        var thinkingMachine = ThinkingStateMachine()
+        var pendingToolCalls: [(id: String, tool: ToolDefinition, args: [String: Any], argsJson: String)] = []
+
+        let chat = buildChatMessages(prompt: prompt, toolResults: toolResults, depth: depth)
+        let userInput = UserInput(
+            chat: chat,
+            tools: !self.toolSchemas.isEmpty ? self.toolSchemas : nil
+        )
+
+        let lmInput = try await container.prepare(input: userInput)
+
+        let stream = try await container.perform { context in
+            let parameters = GenerateParameters(maxTokens: 2048, temperature: 0.7)
+            return try MLXLMCommon.generate(
+                input: lmInput,
+                parameters: parameters,
+                context: context
+            )
+        }
+
+        for await generation in stream {
+            if Task.isCancelled { break }
+
+            switch generation {
+            case .chunk(let text):
+                let outputs = thinkingMachine.process(token: text)
+
+                for machineOutput in outputs {
+                    switch machineOutput {
+                    case .token(let token):
+                        output += token
+                        emitter.emitToken(token)
+                        onTokenProcessed()
+
+                    case .thinkingStart:
+                        emitter.emitThinkingStart()
+
+                    case .thinkingChunk(let chunk):
+                        emitter.emitThinkingChunk(chunk)
+
+                    case .thinkingEnd(let content):
+                        emitter.emitThinkingEnd(content)
+                    }
+                }
+
+            case .toolCall(let toolCall):
+                log("Tool call detected: \(toolCall.function.name)")
+
+                guard let tool = self.tools.first(where: { $0.name == toolCall.function.name }) else {
+                    log("Unknown tool: \(toolCall.function.name)")
+                    continue
+                }
+
+                let toolCallId = UUID().uuidString
+                let argsDict = self.convertToolCallArguments(toolCall.function.arguments)
+                let argsJson = dictionaryToJson(argsDict)
+
+                emitter.emitToolCallStart(id: toolCallId, name: toolCall.function.name, arguments: argsJson)
+                pendingToolCalls.append((id: toolCallId, tool: tool, args: argsDict, argsJson: argsJson))
+
+            case .info(let info):
+                log("Generation info: \(info.generationTokenCount) tokens, \(String(format: "%.1f", info.tokensPerSecond)) tokens/s")
+                let generationTime = info.tokensPerSecond > 0 ? Double(info.generationTokenCount) / info.tokensPerSecond * 1000 : 0
+                onGenerationInfo(info.generationTokenCount, generationTime)
+            }
+        }
+
+        let flushOutputs = thinkingMachine.flush()
+        for machineOutput in flushOutputs {
+            switch machineOutput {
+            case .token(let token):
+                output += token
+                emitter.emitToken(token)
+                onTokenProcessed()
+            case .thinkingStart:
+                emitter.emitThinkingStart()
+            case .thinkingChunk(let chunk):
+                emitter.emitThinkingChunk(chunk)
+            case .thinkingEnd(let content):
+                emitter.emitThinkingEnd(content)
+            }
+        }
+
+        if !pendingToolCalls.isEmpty {
+            log("Executing \(pendingToolCalls.count) tool call(s)")
+
+            let toolStartTime = Date()
+
+            for call in pendingToolCalls {
+                emitter.emitToolCallExecuting(id: call.id)
+            }
+
+            let allToolResults: [String] = await withTaskGroup(of: (Int, String).self) { group in
+                for (index, call) in pendingToolCalls.enumerated() {
+                    group.addTask { [self] in
+                        do {
+                            let resultJson = try await self.executeToolCall(tool: call.tool, argsDict: call.args)
+                            self.log("Tool result for \(call.tool.name): \(resultJson.prefix(100))...")
+                            emitter.emitToolCallCompleted(id: call.id, result: resultJson)
+                            return (index, resultJson)
+                        } catch {
+                            self.log("Tool execution error for \(call.tool.name): \(error)")
+                            emitter.emitToolCallFailed(id: call.id, error: error.localizedDescription)
+                            return (index, "{\"error\": \"Tool execution failed\"}")
+                        }
+                    }
+                }
+
+                var results = Array(repeating: "", count: pendingToolCalls.count)
+                for await (index, result) in group {
+                    results[index] = result
+                }
+                return results
+            }
+
+            toolExecutionTime += Date().timeIntervalSince(toolStartTime) * 1000
+
+            if !output.isEmpty {
+                self.messageHistory.append(LLMMessage(role: "assistant", content: output))
+            }
+
+            for result in allToolResults {
+                self.messageHistory.append(LLMMessage(role: "tool", content: result))
+            }
+
+            let continuation = try await self.performGenerationWithEvents(
+                container: container,
+                prompt: prompt,
+                toolResults: allToolResults,
+                depth: depth + 1,
+                emitter: emitter,
+                onTokenProcessed: onTokenProcessed,
+                onGenerationInfo: onGenerationInfo,
+                toolExecutionTime: &toolExecutionTime
+            )
+
+            return output + continuation
+        }
+
+        return output
+    }
+
+    private func performGeneration(
+        container: ModelContainer,
+        prompt: String,
+        toolResults: [String]?,
+        depth: Int,
+        onToken: @escaping (String) -> Void,
+        onToolCall: @escaping (String, String) -> Void
+    ) async throws -> String {
+        if depth >= maxToolCallDepth {
+            log("Max tool call depth reached (\(maxToolCallDepth))")
+            return ""
+        }
+
+        var output = ""
+        var pendingToolCalls: [(tool: ToolDefinition, args: [String: Any], argsJson: String)] = []
+
+        let chat = buildChatMessages(prompt: prompt, toolResults: toolResults, depth: depth)
+        let userInput = UserInput(
+            chat: chat,
+            tools: !self.toolSchemas.isEmpty ? self.toolSchemas : nil
+        )
+
+        let lmInput = try await container.prepare(input: userInput)
+
+        let stream = try await container.perform { context in
+            let parameters = GenerateParameters(maxTokens: 2048, temperature: 0.7)
+            return try MLXLMCommon.generate(
+                input: lmInput,
+                parameters: parameters,
+                context: context
+            )
+        }
+
+        for await generation in stream {
+            if Task.isCancelled { break }
+
+            switch generation {
+            case .chunk(let text):
+                output += text
+                onToken(text)
+
+            case .toolCall(let toolCall):
+                log("Tool call detected: \(toolCall.function.name)")
+
+                guard let tool = self.tools.first(where: { $0.name == toolCall.function.name }) else {
+                    log("Unknown tool: \(toolCall.function.name)")
+                    continue
+                }
+
+                let argsDict = self.convertToolCallArguments(toolCall.function.arguments)
+                let argsJson = dictionaryToJson(argsDict)
+
+                pendingToolCalls.append((tool: tool, args: argsDict, argsJson: argsJson))
+                onToolCall(toolCall.function.name, argsJson)
+
+            case .info(let info):
+                log("Generation info: \(info.generationTokenCount) tokens, \(String(format: "%.1f", info.tokensPerSecond)) tokens/s")
+            }
+        }
+
+        if !pendingToolCalls.isEmpty {
+            log("Executing \(pendingToolCalls.count) tool call(s)")
+
+            let allToolResults: [String] = await withTaskGroup(of: (Int, String).self) { group in
+                for (index, call) in pendingToolCalls.enumerated() {
+                    group.addTask { [self] in
+                        do {
+                            let resultJson = try await self.executeToolCall(tool: call.tool, argsDict: call.args)
+                            self.log("Tool result for \(call.tool.name): \(resultJson.prefix(100))...")
+                            return (index, resultJson)
+                        } catch {
+                            self.log("Tool execution error for \(call.tool.name): \(error)")
+                            return (index, "{\"error\": \"Tool execution failed\"}")
+                        }
+                    }
+                }
+
+                var results = Array(repeating: "", count: pendingToolCalls.count)
+                for await (index, result) in group {
+                    results[index] = result
+                }
+                return results
+            }
+
+            if !output.isEmpty {
+                self.messageHistory.append(LLMMessage(role: "assistant", content: output))
+            }
+
+            if depth == 0 {
+                self.messageHistory.append(LLMMessage(role: "user", content: prompt))
+            }
+
+            for result in allToolResults {
+                self.messageHistory.append(LLMMessage(role: "tool", content: result))
+            }
+
+            onToken("\u{200B}")
+
+            let continuation = try await self.performGeneration(
+                container: container,
+                prompt: prompt,
+                toolResults: allToolResults,
+                depth: depth + 1,
+                onToken: onToken,
+                onToolCall: onToolCall
+            )
+
+            return output + continuation
+        }
+
+        return output
+    }
+
+    private func convertToolCallArguments(_ arguments: [String: JSONValue]) -> [String: Any] {
+        var result: [String: Any] = [:]
+        for (key, value) in arguments {
+            result[key] = value.anyValue
+        }
+        return result
+    }
+
+    private func dictionaryToAnyMap(_ dict: [String: Any]) -> AnyMap {
+        let anyMap = AnyMap()
+        for (key, value) in dict {
+            switch value {
+            case let stringValue as String:
+                anyMap.setString(key: key, value: stringValue)
+            case let doubleValue as Double:
+                anyMap.setDouble(key: key, value: doubleValue)
+            case let intValue as Int:
+                anyMap.setDouble(key: key, value: Double(intValue))
+            case let boolValue as Bool:
+                anyMap.setBoolean(key: key, value: boolValue)
+            default:
+                anyMap.setString(key: key, value: String(describing: value))
+            }
+        }
+        return anyMap
+    }
+
+    private func anyMapToDictionary(_ anyMap: AnyMap) -> [String: Any] {
+        var dict: [String: Any] = [:]
+        for key in anyMap.getAllKeys() {
+            if anyMap.isString(key: key) {
+                dict[key] = anyMap.getString(key: key)
+            } else if anyMap.isDouble(key: key) {
+                dict[key] = anyMap.getDouble(key: key)
+            } else if anyMap.isBool(key: key) {
+                dict[key] = anyMap.getBoolean(key: key)
+            }
+        }
+        return dict
     }
 
     func stop() throws {
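Two details worth noting in the tool-call plumbing above. First, `executeToolCall` awaits a promise twice because the Nitro-bridged JS handler is an async function that itself resolves to a promise of an `AnyMap` (see the generated `Func_std__shared_ptr_Promise_std__shared_ptr_Promise...` binding in the file list). Second, tool calls run concurrently, and `withTaskGroup` yields child results in completion order, so both generation paths tag each task with its index and write into a pre-sized array to restore the order in which the model requested the tools. That pattern in isolation:

    // Order-preserving fan-out, as used by both performGeneration variants.
    func runAll(_ jobs: [@Sendable () async -> String]) async -> [String] {
        await withTaskGroup(of: (Int, String).self) { group in
            for (index, job) in jobs.enumerated() {
                group.addTask { (index, await job()) }
            }
            var results = Array(repeating: "", count: jobs.count)
            for await (index, value) in group {
                results[index] = value   // completion order varies; index restores it
            }
            return results
        }
    }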
@@ -211,6 +694,9 @@ class HybridLLM: HybridLLMSpec {
     }
 
     func unload() throws {
+        loadTask?.cancel()
+        loadTask = nil
+
         let memoryBefore = getMemoryUsage()
         let gpuBefore = getGPUMemoryUsage()
         log("Before unload - Host: \(memoryBefore), GPU: \(gpuBefore)")
@@ -219,11 +705,13 @@ class HybridLLM: HybridLLMSpec {
         currentTask = nil
         session = nil
         container = nil
+        tools = []
+        toolSchemas = []
         messageHistory = []
         manageHistory = false
         modelId = ""
 
-        MLX.GPU.clearCache()
+        MLX.Memory.clearCache()
 
         let memoryAfter = getMemoryUsage()
         let gpuAfter = getGPUMemoryUsage()
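End to end, the reworked native surface is `load(modelId:options:)`, then `stream`/`streamWithEvents`, then `stop()`/`unload()`. An illustrative sketch from the Swift side, assuming the model has already been fetched by `ModelDownloader` (the model id is a placeholder, and promises are awaited with the Nitro `Promise.await()` seen in `executeToolCall`):

    // Illustrative only; construction and error handling depend on the Nitro runtime.
    let llm = HybridLLM()
    try await llm.load(modelId: "some-model-id", options: nil).await()
    let reply = try await llm.stream(
        prompt: "Hello",
        onToken: { token in print(token, terminator: "") },
        onToolCall: { name, args in print("\ntool call:", name, args) }
    ).await()
    print("\nfinal:", reply)
    try llm.unload()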