react-native-nitro-mlx 0.3.0 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/MLXReactNative.podspec +7 -1
- package/ios/Sources/AudioCaptureManager.swift +110 -0
- package/ios/Sources/HybridLLM.swift +309 -68
- package/ios/Sources/HybridSTT.swift +202 -0
- package/ios/Sources/HybridTTS.swift +145 -0
- package/ios/Sources/JSONHelpers.swift +9 -0
- package/ios/Sources/ModelDownloader.swift +26 -12
- package/ios/Sources/StreamEventEmitter.swift +132 -0
- package/ios/Sources/ThinkingStateMachine.swift +206 -0
- package/lib/module/index.js +2 -0
- package/lib/module/index.js.map +1 -1
- package/lib/module/llm.js +39 -1
- package/lib/module/llm.js.map +1 -1
- package/lib/module/models.js +97 -26
- package/lib/module/models.js.map +1 -1
- package/lib/module/specs/STT.nitro.js +4 -0
- package/lib/module/specs/STT.nitro.js.map +1 -0
- package/lib/module/specs/TTS.nitro.js +4 -0
- package/lib/module/specs/TTS.nitro.js.map +1 -0
- package/lib/module/stt.js +49 -0
- package/lib/module/stt.js.map +1 -0
- package/lib/module/tts.js +40 -0
- package/lib/module/tts.js.map +1 -0
- package/lib/typescript/src/index.d.ts +7 -3
- package/lib/typescript/src/index.d.ts.map +1 -1
- package/lib/typescript/src/llm.d.ts +32 -2
- package/lib/typescript/src/llm.d.ts.map +1 -1
- package/lib/typescript/src/models.d.ts +13 -4
- package/lib/typescript/src/models.d.ts.map +1 -1
- package/lib/typescript/src/specs/LLM.nitro.d.ts +49 -4
- package/lib/typescript/src/specs/LLM.nitro.d.ts.map +1 -1
- package/lib/typescript/src/specs/STT.nitro.d.ts +28 -0
- package/lib/typescript/src/specs/STT.nitro.d.ts.map +1 -0
- package/lib/typescript/src/specs/TTS.nitro.d.ts +22 -0
- package/lib/typescript/src/specs/TTS.nitro.d.ts.map +1 -0
- package/lib/typescript/src/stt.d.ts +16 -0
- package/lib/typescript/src/stt.d.ts.map +1 -0
- package/lib/typescript/src/tts.d.ts +13 -0
- package/lib/typescript/src/tts.d.ts.map +1 -0
- package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Bridge.cpp +42 -0
- package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Bridge.hpp +165 -0
- package/nitrogen/generated/ios/MLXReactNative-Swift-Cxx-Umbrella.hpp +20 -0
- package/nitrogen/generated/ios/MLXReactNativeAutolinking.mm +16 -0
- package/nitrogen/generated/ios/MLXReactNativeAutolinking.swift +30 -0
- package/nitrogen/generated/ios/c++/HybridLLMSpecSwift.hpp +8 -0
- package/nitrogen/generated/ios/c++/HybridSTTSpecSwift.cpp +11 -0
- package/nitrogen/generated/ios/c++/HybridSTTSpecSwift.hpp +149 -0
- package/nitrogen/generated/ios/c++/HybridTTSSpecSwift.cpp +11 -0
- package/nitrogen/generated/ios/c++/HybridTTSSpecSwift.hpp +128 -0
- package/nitrogen/generated/ios/swift/Func_void_std__shared_ptr_ArrayBuffer_.swift +47 -0
- package/nitrogen/generated/ios/swift/GenerationStats.swift +13 -2
- package/nitrogen/generated/ios/swift/HybridLLMSpec.swift +1 -0
- package/nitrogen/generated/ios/swift/HybridLLMSpec_cxx.swift +24 -0
- package/nitrogen/generated/ios/swift/HybridSTTSpec.swift +66 -0
- package/nitrogen/generated/ios/swift/HybridSTTSpec_cxx.swift +286 -0
- package/nitrogen/generated/ios/swift/HybridTTSSpec.swift +63 -0
- package/nitrogen/generated/ios/swift/HybridTTSSpec_cxx.swift +229 -0
- package/nitrogen/generated/ios/swift/STTLoadOptions.swift +66 -0
- package/nitrogen/generated/ios/swift/TTSGenerateOptions.swift +78 -0
- package/nitrogen/generated/ios/swift/TTSLoadOptions.swift +66 -0
- package/nitrogen/generated/shared/c++/GenerationStats.hpp +6 -2
- package/nitrogen/generated/shared/c++/HybridLLMSpec.cpp +1 -0
- package/nitrogen/generated/shared/c++/HybridLLMSpec.hpp +1 -0
- package/nitrogen/generated/shared/c++/HybridSTTSpec.cpp +32 -0
- package/nitrogen/generated/shared/c++/HybridSTTSpec.hpp +78 -0
- package/nitrogen/generated/shared/c++/HybridTTSSpec.cpp +29 -0
- package/nitrogen/generated/shared/c++/HybridTTSSpec.hpp +78 -0
- package/nitrogen/generated/shared/c++/STTLoadOptions.hpp +76 -0
- package/nitrogen/generated/shared/c++/TTSGenerateOptions.hpp +80 -0
- package/nitrogen/generated/shared/c++/TTSLoadOptions.hpp +76 -0
- package/package.json +8 -4
- package/src/index.ts +31 -1
- package/src/llm.ts +48 -2
- package/src/models.ts +81 -1
- package/src/specs/LLM.nitro.ts +74 -4
- package/src/specs/STT.nitro.ts +35 -0
- package/src/specs/TTS.nitro.ts +30 -0
- package/src/stt.ts +67 -0
- package/src/tts.ts +60 -0
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
import Foundation
|
|
2
|
+
import NitroModules
|
|
3
|
+
internal import MLX
|
|
4
|
+
internal import MLXAudioSTT
|
|
5
|
+
internal import MLXAudioCore
|
|
6
|
+
|
|
7
|
+
/// Errors thrown by `HybridSTT` and surfaced across the Nitro bridge to JS.
enum STTError: Error {
    case notLoaded         // an operation was requested before `load` completed
    case notListening      // buffer/stop requested while no capture session is active
    case alreadyListening  // `startListening` called while capture is already running
}
|
|
12
|
+
|
|
13
|
+
/// Nitro hybrid object exposing on-device speech-to-text (GLM ASR via MLX)
/// to React Native: one-shot transcription of PCM buffers plus a live
/// microphone listen/snapshot/stop API.
class HybridSTT: HybridSTTSpec {
    // MARK: - State

    /// Loaded ASR model; `nil` until `load` completes successfully.
    private var model: GLMASRModel?
    /// In-flight transcription task, tracked so `stop()`/`unload()` can cancel it.
    private var activeTask: Task<String, Error>?
    /// In-flight model-load task, cancelled when a newer load starts.
    private var loadTask: Task<Void, Error>?
    /// Microphone capture session backing the listen APIs.
    private var captureManager: AudioCaptureManager?

    var isLoaded: Bool { model != nil }
    var isTranscribing: Bool { activeTask != nil }
    var isListening: Bool { captureManager?.isCapturing ?? false }
    var modelId: String = ""

    // MARK: - Helpers

    /// Reinterprets the buffer's raw bytes as Float32 samples and copies them
    /// into a 1-D `MLXArray`. Trailing bytes that do not form a whole Float
    /// are ignored; an empty buffer yields an empty array.
    private func arrayBufferToMLXArray(_ buffer: ArrayBuffer) -> MLXArray {
        let count = buffer.size / MemoryLayout<Float>.size
        let rawPtr = UnsafeRawPointer(buffer.data)
        let floatPtr = rawPtr.bindMemory(to: Float.self, capacity: count)
        let floatBuffer = UnsafeBufferPointer(start: floatPtr, count: count)
        return MLXArray(Array(floatBuffer))
    }

    /// Cancels any running transcription and tears down the capture session.
    /// Shared teardown path for `stop()` and `unload()`.
    private func cancelWorkAndCapture() {
        activeTask?.cancel()
        activeTask = nil
        if let manager = captureManager, manager.isCapturing {
            _ = manager.stopCapturing()
        }
        captureManager = nil
    }

    // MARK: - Loading

    /// Loads the model identified by `modelId`, replacing any previously
    /// loaded model. A load already in progress is cancelled first; a
    /// superseded load never installs its model.
    func load(modelId: String, options: STTLoadOptions?) throws -> Promise<Void> {
        self.loadTask?.cancel()

        return Promise.async { [self] in
            let task = Task { @MainActor in
                // Drop running work and free the old model's memory before
                // loading the replacement.
                self.activeTask?.cancel()
                self.activeTask = nil
                self.model = nil
                MLX.Memory.clearCache()

                let loadedModel = try await GLMASRModel.fromPretrained(modelId)

                // Do not install the model if this load was superseded.
                try Task.checkCancellation()

                self.model = loadedModel
                self.modelId = modelId

                options?.onProgress?(1.0)
            }

            self.loadTask = task
            try await task.value
        }
    }

    // MARK: - One-shot transcription

    /// Transcribes a complete Float32 PCM buffer and resolves with the text.
    /// - Throws: `STTError.notLoaded` when no model is loaded.
    func transcribe(audio: ArrayBuffer) throws -> Promise<String> {
        guard let model else {
            throw STTError.notLoaded
        }

        return Promise.async { [self] in
            let task = Task<String, Error> {
                let mlxAudio = self.arrayBufferToMLXArray(audio)
                let output = model.generate(audio: mlxAudio)
                return output.text
            }

            self.activeTask = task
            defer { self.activeTask = nil }

            return try await task.value
        }
    }

    /// Transcribes a complete buffer, calling `onToken` for each streamed
    /// token, and resolves with the final text. Cancellation stops the stream
    /// early and resolves with whatever final text was seen so far.
    func transcribeStream(
        audio: ArrayBuffer,
        onToken: @escaping (_ token: String) -> Void
    ) throws -> Promise<String> {
        guard let model else {
            throw STTError.notLoaded
        }

        return Promise.async { [self] in
            let task = Task<String, Error> {
                let mlxAudio = self.arrayBufferToMLXArray(audio)
                let stream = model.generateStream(audio: mlxAudio)
                var finalText = ""

                for try await event in stream {
                    if Task.isCancelled { break }

                    switch event {
                    case .token(let token):
                        onToken(token)
                    case .result(let output):
                        finalText = output.text
                    case .info:
                        break
                    }
                }

                return finalText
            }

            self.activeTask = task
            defer { self.activeTask = nil }

            return try await task.value
        }
    }

    // MARK: - Live capture

    /// Starts capturing microphone audio.
    /// - Throws: `STTError.notLoaded` when no model is loaded;
    ///   `STTError.alreadyListening` when a capture session is running.
    func startListening() throws -> Promise<Void> {
        guard model != nil else {
            throw STTError.notLoaded
        }
        // Throw only when a manager exists AND it is actively capturing
        // (avoids the previous force-unwrap of `captureManager`).
        guard captureManager?.isCapturing != true else {
            throw STTError.alreadyListening
        }

        return Promise.async { [self] in
            let manager = AudioCaptureManager()
            self.captureManager = manager
            try await manager.startCapturing()
        }
    }

    /// Transcribes the audio captured so far without stopping the session.
    /// Resolves with "" when nothing has been captured yet.
    func transcribeBuffer() throws -> Promise<String> {
        guard let model else {
            throw STTError.notLoaded
        }
        guard let manager = captureManager, manager.isCapturing else {
            throw STTError.notListening
        }
        guard let audio = manager.snapshot() else {
            return Promise.resolved(withResult: "")
        }

        return Promise.async { [self] in
            let task = Task<String, Error> {
                let output = model.generate(audio: audio)
                return output.text
            }

            self.activeTask = task
            defer { self.activeTask = nil }

            let result = try await task.value
            MLX.Memory.clearCache()
            return result
        }
    }

    /// Stops the capture session and resolves with the transcription of
    /// everything that was recorded.
    func stopListening() throws -> Promise<String> {
        guard let model else {
            throw STTError.notLoaded
        }
        guard let manager = captureManager, manager.isCapturing else {
            throw STTError.notListening
        }

        let audio = manager.stopCapturing()
        self.captureManager = nil

        return Promise.async { [self] in
            let task = Task<String, Error> {
                let output = model.generate(audio: audio)
                return output.text
            }

            self.activeTask = task
            defer { self.activeTask = nil }

            let result = try await task.value
            MLX.Memory.clearCache()
            return result
        }
    }

    // MARK: - Teardown

    /// Cancels any in-flight transcription and stops listening; the model
    /// stays loaded.
    func stop() throws {
        cancelWorkAndCapture()
    }

    /// Cancels all work, stops listening, releases the model, and frees MLX
    /// cache memory.
    func unload() throws {
        loadTask?.cancel()
        loadTask = nil
        cancelWorkAndCapture()
        model = nil
        modelId = ""
        // Fully qualified for consistency with the rest of the class.
        MLX.Memory.clearCache()
    }
}
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
import Foundation
|
|
2
|
+
import NitroModules
|
|
3
|
+
internal import MLX
|
|
4
|
+
internal import MLXAudioTTS
|
|
5
|
+
internal import MLXAudioCore
|
|
6
|
+
|
|
7
|
+
/// Errors thrown by `HybridTTS` and surfaced across the Nitro bridge to JS.
enum TTSError: Error {
    case notLoaded  // generation was requested before `load` completed
}
|
|
10
|
+
|
|
11
|
+
/// Nitro hybrid object exposing on-device text-to-speech (MLX speech
/// generation) to React Native: one-shot generation into an ArrayBuffer and
/// chunked streaming via a callback.
class HybridTTS: HybridTTSSpec {
    /// Loaded speech-generation model; `nil` until `load` completes.
    private var model: SpeechGenerationModel?
    /// In-flight generation task. Erased to `Any` so one property can track
    /// both `generate` (ArrayBuffer result) and `stream` (Void result) tasks.
    private var activeTask: Task<Any, Error>?
    /// In-flight model-load task, cancelled when a newer load starts.
    private var loadTask: Task<Void, Error>?

    var isLoaded: Bool { model != nil }
    var isGenerating: Bool { activeTask != nil }
    var modelId: String = ""
    /// Output sample rate of the loaded model; falls back to 24 kHz before a
    /// model is loaded.
    var sampleRate: Double {
        Double(model?.sampleRate ?? 24000)
    }

    /// Evaluates `audio` as Float32 and copies its bytes into a fresh
    /// `ArrayBuffer` for the JS side. Handles empty audio without crashing
    /// (previously force-unwrapped a nil `baseAddress` on empty data).
    private func mlxArrayToArrayBuffer(_ audio: MLXArray) -> ArrayBuffer {
        let evaluated = audio.asType(.float32)
        MLX.eval(evaluated)
        let arrayData = evaluated.asData(access: .copy)
        let byteSize = arrayData.data.count
        let buffer = ArrayBuffer.allocate(size: byteSize)
        if byteSize > 0 {
            arrayData.data.withUnsafeBytes { srcPtr in
                // baseAddress is non-nil here because byteSize > 0.
                UnsafeMutableRawPointer(buffer.data).copyMemory(
                    from: srcPtr.baseAddress!,
                    byteCount: byteSize
                )
            }
        }
        return buffer
    }

    /// Loads the model identified by `modelId`, replacing any previously
    /// loaded model. A load already in progress is cancelled first; a
    /// superseded load never installs its model.
    func load(modelId: String, options: TTSLoadOptions?) throws -> Promise<Void> {
        self.loadTask?.cancel()

        return Promise.async { [self] in
            let task = Task { @MainActor in
                // Drop running work and free the old model's memory before
                // loading the replacement.
                self.activeTask?.cancel()
                self.activeTask = nil
                self.model = nil
                MLX.Memory.clearCache()

                let loadedModel = try await TTSModelUtils.loadModel(modelRepo: modelId)

                // Do not install the model if this load was superseded.
                try Task.checkCancellation()

                self.model = loadedModel
                self.modelId = modelId

                options?.onProgress?(1.0)
            }

            self.loadTask = task
            try await task.value
        }
    }

    /// Synthesizes `text` into a Float32 PCM ArrayBuffer at `sampleRate`.
    /// - Throws: `TTSError.notLoaded` when no model is loaded.
    func generate(
        text: String,
        options: TTSGenerateOptions?
    ) throws -> Promise<ArrayBuffer> {
        guard let model else {
            throw TTSError.notLoaded
        }

        return Promise.async { [self] in
            let task = Task<Any, Error> {
                let audio = try await model.generate(
                    text: text,
                    voice: options?.voice,
                    refAudio: nil,
                    refText: nil,
                    language: nil
                )
                return self.mlxArrayToArrayBuffer(audio) as Any
            }

            self.activeTask = task
            defer { self.activeTask = nil }

            // Safe cast: the task closure above returns the ArrayBuffer
            // produced by mlxArrayToArrayBuffer on its only path.
            return try await task.value as! ArrayBuffer
        }
    }

    /// Synthesizes `text`, invoking `onAudioChunk` with each PCM chunk as it
    /// is produced. Resolves when the stream ends or the task is cancelled.
    func stream(
        text: String,
        onAudioChunk: @escaping (ArrayBuffer) -> Void,
        options: TTSGenerateOptions?
    ) throws -> Promise<Void> {
        guard let model else {
            throw TTSError.notLoaded
        }

        return Promise.async { [self] in
            let task = Task<Any, Error> {
                let stream = model.generateStream(
                    text: text,
                    voice: options?.voice,
                    refAudio: nil,
                    refText: nil,
                    language: nil,
                    generationParameters: model.defaultGenerationParameters
                )

                for try await event in stream {
                    if Task.isCancelled { break }

                    switch event {
                    case .audio(let audio):
                        let buffer = self.mlxArrayToArrayBuffer(audio)
                        onAudioChunk(buffer)
                    case .token, .info:
                        break
                    }
                }
                // Void result, erased to Any to match `activeTask`'s type.
                return () as Any
            }

            self.activeTask = task
            defer { self.activeTask = nil }

            _ = try await task.value
        }
    }

    /// Cancels any in-flight generation; the model stays loaded.
    func stop() throws {
        activeTask?.cancel()
        activeTask = nil
    }

    /// Cancels all work, releases the model, and frees MLX cache memory.
    func unload() throws {
        loadTask?.cancel()
        loadTask = nil
        activeTask?.cancel()
        activeTask = nil
        model = nil
        modelId = ""
        // Fully qualified for consistency with the rest of the class.
        MLX.Memory.clearCache()
    }
}
|
|
@@ -12,32 +12,46 @@ actor ModelDownloader: NSObject {
|
|
|
12
12
|
}
|
|
13
13
|
}
|
|
14
14
|
|
|
15
|
+
// File extensions worth fetching from a Hugging Face model repo (weights,
// tokenizer assets, configs). Everything else listed by the API is skipped.
private let downloadableExtensions: Set<String> = [
    "json", "safetensors", "txt", "model", "tiktoken", "py"
]

/// Queries the Hugging Face Hub model-info API for `modelId` and returns the
/// repo filenames whose extension is in `downloadableExtensions`.
///
/// NOTE(review): the HTTP response status is ignored and any malformed URL or
/// unexpected JSON shape yields an empty list rather than an error — a 404 or
/// rate-limit response therefore looks like "nothing to download" to the
/// caller; confirm the download path treats an empty list as a failure.
private func fetchFileList(modelId: String) async throws -> [String] {
    let urlString = "https://huggingface.co/api/models/\(modelId)"
    guard let url = URL(string: urlString) else { return [] }

    let (data, _) = try await URLSession.shared.data(from: url)
    // The Hub API lists repo files under "siblings", each with an "rfilename".
    guard let json = try JSONSerialization.jsonObject(with: data) as? [String: Any],
          let siblings = json["siblings"] as? [[String: Any]]
    else { return [] }

    return siblings.compactMap { $0["rfilename"] as? String }
        .filter { name in
            let ext = (name as NSString).pathExtension.lowercased()
            return downloadableExtensions.contains(ext)
        }
}
|
|
34
|
+
|
|
15
35
|
func download(
|
|
16
36
|
modelId: String,
|
|
17
37
|
progressCallback: @escaping (Double) -> Void
|
|
18
38
|
) async throws -> URL {
|
|
19
|
-
let
|
|
20
|
-
"config.json",
|
|
21
|
-
"tokenizer.json",
|
|
22
|
-
"tokenizer_config.json",
|
|
23
|
-
"model.safetensors"
|
|
24
|
-
]
|
|
25
|
-
|
|
39
|
+
let files = try await fetchFileList(modelId: modelId)
|
|
26
40
|
let modelDir = getModelDirectory(modelId: modelId)
|
|
27
41
|
try fileManager.createDirectory(at: modelDir, withIntermediateDirectories: true)
|
|
28
42
|
|
|
29
43
|
log("Model directory: \(modelDir.path)")
|
|
30
|
-
log("Files to download: \(
|
|
44
|
+
log("Files to download: \(files)")
|
|
31
45
|
|
|
32
46
|
var downloaded = 0
|
|
33
47
|
|
|
34
|
-
for file in
|
|
48
|
+
for file in files {
|
|
35
49
|
let destURL = modelDir.appendingPathComponent(file)
|
|
36
50
|
|
|
37
51
|
if fileManager.fileExists(atPath: destURL.path) {
|
|
38
52
|
log("File exists, skipping: \(file)")
|
|
39
53
|
downloaded += 1
|
|
40
|
-
progressCallback(Double(downloaded) / Double(
|
|
54
|
+
progressCallback(Double(downloaded) / Double(files.count))
|
|
41
55
|
continue
|
|
42
56
|
}
|
|
43
57
|
|
|
@@ -69,7 +83,7 @@ actor ModelDownloader: NSObject {
|
|
|
69
83
|
}
|
|
70
84
|
|
|
71
85
|
downloaded += 1
|
|
72
|
-
progressCallback(Double(downloaded) / Double(
|
|
86
|
+
progressCallback(Double(downloaded) / Double(files.count))
|
|
73
87
|
}
|
|
74
88
|
|
|
75
89
|
return modelDir
|
|
@@ -77,7 +91,7 @@ actor ModelDownloader: NSObject {
|
|
|
77
91
|
|
|
78
92
|
func isDownloaded(modelId: String) -> Bool {
|
|
79
93
|
let modelDir = getModelDirectory(modelId: modelId)
|
|
80
|
-
let requiredFiles = ["config.json", "model.safetensors"
|
|
94
|
+
let requiredFiles = ["config.json", "model.safetensors"]
|
|
81
95
|
|
|
82
96
|
let allExist = requiredFiles.allSatisfy { file in
|
|
83
97
|
fileManager.fileExists(atPath: modelDir.appendingPathComponent(file).path)
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
import Foundation
|
|
2
|
+
import NitroModules
|
|
3
|
+
|
|
4
|
+
/// Serializes LLM stream lifecycle events (tokens, thinking, tool calls,
/// generation stats) to JSON strings and forwards each one to a single
/// string callback.
struct StreamEventEmitter {
    /// Receives every event as a JSON string.
    private let callback: (String) -> Void
    private let encoder = JSONEncoder()

    init(callback: @escaping (String) -> Void) {
        self.callback = callback
    }

    /// Encodes `event` and hands the resulting JSON to the callback.
    /// Events that fail to encode are dropped silently.
    private func emit<E: Encodable>(_ event: E) {
        if let data = try? encoder.encode(event),
           let json = String(data: data, encoding: .utf8) {
            callback(json)
        }
    }

    /// Wall-clock time in milliseconds since the Unix epoch.
    private func timestamp() -> Double {
        Date().timeIntervalSince1970 * 1000
    }

    // MARK: - Event payloads
    // Property declaration order is load-bearing: it fixes the JSON key
    // order produced by the synthesized Encodable conformances.

    struct GenerationStartEvent: Encodable {
        let type = "generation_start"
        let timestamp: Double
    }

    struct TokenEvent: Encodable {
        let type = "token"
        let token: String
    }

    struct ThinkingStartEvent: Encodable {
        let type = "thinking_start"
        let timestamp: Double
    }

    struct ThinkingChunkEvent: Encodable {
        let type = "thinking_chunk"
        let chunk: String
    }

    struct ThinkingEndEvent: Encodable {
        let type = "thinking_end"
        let content: String
        let timestamp: Double
    }

    struct ToolCallStartEvent: Encodable {
        let type = "tool_call_start"
        let id: String
        let name: String
        let arguments: String
    }

    struct ToolCallExecutingEvent: Encodable {
        let type = "tool_call_executing"
        let id: String
    }

    struct ToolCallCompletedEvent: Encodable {
        let type = "tool_call_completed"
        let id: String
        let result: String
    }

    struct ToolCallFailedEvent: Encodable {
        let type = "tool_call_failed"
        let id: String
        let error: String
    }

    struct StatsPayload: Encodable {
        let tokenCount: Double
        let tokensPerSecond: Double
        let timeToFirstToken: Double
        let totalTime: Double
    }

    struct GenerationEndEvent: Encodable {
        let type = "generation_end"
        let content: String
        let stats: StatsPayload
    }

    // MARK: - Emitters

    /// Announces that generation has begun.
    func emitGenerationStart() {
        emit(GenerationStartEvent(timestamp: timestamp()))
    }

    /// Forwards one generated token.
    func emitToken(_ token: String) {
        emit(TokenEvent(token: token))
    }

    /// Marks the start of a thinking block.
    func emitThinkingStart() {
        emit(ThinkingStartEvent(timestamp: timestamp()))
    }

    /// Forwards one chunk of thinking text.
    func emitThinkingChunk(_ chunk: String) {
        emit(ThinkingChunkEvent(chunk: chunk))
    }

    /// Marks the end of a thinking block, carrying its full content.
    func emitThinkingEnd(_ content: String) {
        emit(ThinkingEndEvent(content: content, timestamp: timestamp()))
    }

    /// Announces a tool call the model has requested.
    func emitToolCallStart(id: String, name: String, arguments: String) {
        emit(ToolCallStartEvent(id: id, name: name, arguments: arguments))
    }

    /// Announces that a tool call is now executing.
    func emitToolCallExecuting(id: String) {
        emit(ToolCallExecutingEvent(id: id))
    }

    /// Reports a tool call's successful result.
    func emitToolCallCompleted(id: String, result: String) {
        emit(ToolCallCompletedEvent(id: id, result: result))
    }

    /// Reports a tool call failure.
    func emitToolCallFailed(id: String, error: String) {
        emit(ToolCallFailedEvent(id: id, error: error))
    }

    /// Announces the end of generation with the final content and stats.
    func emitGenerationEnd(content: String, stats: GenerationStats) {
        let payload = StatsPayload(
            tokenCount: stats.tokenCount,
            tokensPerSecond: stats.tokensPerSecond,
            timeToFirstToken: stats.timeToFirstToken,
            totalTime: stats.totalTime
        )
        emit(GenerationEndEvent(content: content, stats: payload))
    }
}
|