react-native-litert-lm 0.3.6 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +207 -158
- package/android/build.gradle +12 -0
- package/android/src/main/AndroidManifest.xml +5 -0
- package/android/src/main/java/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLM.kt +316 -63
- package/android/src/main/java/dev/litert/litertlm/LiteRTLMPackage.kt +19 -2
- package/android/src/test/java/com/margelo/nitro/core/Promise.kt +46 -0
- package/android/src/test/java/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLMTest.kt +83 -0
- package/cpp/include/README.md +9 -11
- package/ios/HybridLiteRTLM.swift +1058 -0
- package/ios/Tests/HybridLiteRTLMTests.swift +67 -0
- package/lib/__mocks__/react-native-nitro-modules.d.ts +61 -0
- package/lib/__mocks__/react-native-nitro-modules.js +50 -0
- package/lib/__tests__/hooks.test.d.ts +1 -0
- package/lib/__tests__/hooks.test.js +124 -0
- package/lib/__tests__/memoryTracker.test.d.ts +1 -0
- package/lib/__tests__/memoryTracker.test.js +74 -0
- package/lib/__tests__/modelFactory.test.d.ts +1 -0
- package/lib/__tests__/modelFactory.test.js +52 -0
- package/lib/hooks.js +1 -1
- package/lib/index.d.ts +2 -4
- package/lib/index.js +12 -7
- package/lib/modelFactory.js +62 -63
- package/lib/specs/LiteRTLM.nitro.d.ts +71 -2
- package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.cpp +62 -7
- package/nitrogen/generated/android/c++/JHybridLiteRTLMSpec.hpp +3 -1
- package/nitrogen/generated/android/c++/JLLMConfig.hpp +40 -3
- package/nitrogen/generated/android/c++/JMultimodalPart.hpp +74 -0
- package/nitrogen/generated/android/c++/JPartType.hpp +61 -0
- package/nitrogen/generated/android/c++/JToolDefinition.hpp +65 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/GenerationStats.kt +23 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/HybridLiteRTLMSpec.kt +10 -2
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/LLMConfig.kt +46 -3
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/MemoryUsage.kt +19 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/Message.kt +15 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/MultimodalPart.kt +66 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/PartType.kt +24 -0
- package/nitrogen/generated/android/kotlin/com/margelo/nitro/dev/litert/litertlm/ToolDefinition.kt +61 -0
- package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Bridge.cpp +57 -1
- package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Bridge.hpp +414 -3
- package/nitrogen/generated/ios/LiteRTLM-Swift-Cxx-Umbrella.hpp +41 -3
- package/nitrogen/generated/ios/LiteRTLMAutolinking.mm +4 -6
- package/nitrogen/generated/ios/LiteRTLMAutolinking.swift +10 -0
- package/nitrogen/generated/ios/c++/HybridLiteRTLMSpecSwift.cpp +11 -0
- package/nitrogen/generated/ios/c++/HybridLiteRTLMSpecSwift.hpp +224 -0
- package/nitrogen/generated/ios/swift/Backend.swift +44 -0
- package/nitrogen/generated/ios/swift/Func_void.swift +46 -0
- package/nitrogen/generated/ios/swift/Func_void_double.swift +46 -0
- package/nitrogen/generated/ios/swift/Func_void_std__exception_ptr.swift +46 -0
- package/nitrogen/generated/ios/swift/Func_void_std__string.swift +46 -0
- package/nitrogen/generated/ios/swift/Func_void_std__string_bool.swift +46 -0
- package/nitrogen/generated/ios/swift/GenerationStats.swift +54 -0
- package/nitrogen/generated/ios/swift/HybridLiteRTLMSpec.swift +69 -0
- package/nitrogen/generated/ios/swift/HybridLiteRTLMSpec_cxx.swift +383 -0
- package/nitrogen/generated/ios/swift/LLMConfig.swift +203 -0
- package/nitrogen/generated/ios/swift/MemoryUsage.swift +44 -0
- package/nitrogen/generated/ios/swift/Message.swift +34 -0
- package/nitrogen/generated/ios/swift/MultimodalPart.swift +83 -0
- package/nitrogen/generated/ios/swift/PartType.swift +44 -0
- package/nitrogen/generated/ios/swift/Role.swift +44 -0
- package/nitrogen/generated/ios/swift/ToolDefinition.swift +39 -0
- package/nitrogen/generated/shared/c++/HybridLiteRTLMSpec.cpp +2 -0
- package/nitrogen/generated/shared/c++/HybridLiteRTLMSpec.hpp +7 -2
- package/nitrogen/generated/shared/c++/LLMConfig.hpp +22 -2
- package/nitrogen/generated/shared/c++/MultimodalPart.hpp +99 -0
- package/nitrogen/generated/shared/c++/PartType.hpp +80 -0
- package/nitrogen/generated/shared/c++/ToolDefinition.hpp +91 -0
- package/package.json +16 -8
- package/react-native-litert-lm.podspec +15 -19
- package/scripts/download-ios-frameworks.sh +14 -48
- package/scripts/postinstall.js +1 -2
- package/src/__mocks__/react-native-nitro-modules.ts +48 -0
- package/src/__tests__/hooks.test.ts +153 -0
- package/src/__tests__/memoryTracker.test.ts +87 -0
- package/src/__tests__/modelFactory.test.ts +68 -0
- package/src/hooks.ts +1 -1
- package/src/index.ts +12 -9
- package/src/modelFactory.ts +82 -80
- package/src/specs/LiteRTLM.nitro.ts +80 -2
- package/cpp/HybridLiteRTLM.cpp +0 -838
- package/cpp/HybridLiteRTLM.hpp +0 -167
- package/cpp/IOSDownloadHelper.h +0 -24
- package/ios/IOSDownloadHelper.mm +0 -129
- package/scripts/build-ios-engine.sh +0 -302
- package/scripts/stubs/cxx_bridge_stubs.cc +0 -224
- package/scripts/stubs/gemma_model_constraint_provider.cc +0 -46
- package/scripts/stubs/llguidance_stubs.c +0 -101
- package/src/templates.ts +0 -105
|
@@ -0,0 +1,1058 @@
|
|
|
1
|
+
//
|
|
2
|
+
// HybridLiteRTLM.swift
|
|
3
|
+
// react-native-litert-lm
|
|
4
|
+
//
|
|
5
|
+
// Created by Antigravity on 2026-05-19.
|
|
6
|
+
// Copyright © 2026 Margelo. All rights reserved.
|
|
7
|
+
//
|
|
8
|
+
|
|
9
|
+
import Foundation
|
|
10
|
+
import NitroModules
|
|
11
|
+
import CLiteRTLM
|
|
12
|
+
import os
|
|
13
|
+
|
|
14
|
+
/// Per-request state for `HybridLiteRTLM.sendMessageAsync`.
///
/// An instance is retained with `Unmanaged.passRetained` and handed to the C streaming
/// callback as its opaque `callback_data`; `sendMessageAsync` releases it after the final
/// chunk or an error has been handled (or immediately if the stream fails to start).
private final class StreamContext {
    // MARK: Inputs fixed when the stream starts

    /// Text the user sent; also used to strip an echoed prompt from the reply.
    let userMessage: String
    /// Timestamp taken just before the stream was started.
    let startTime: Date
    /// JS callback receiving text deltas and a completion flag.
    let onToken: (_ token: String, _ done: Bool) -> Void
    /// Promise settled when the stream finishes or fails.
    let promise: Promise<Void>
    /// Owning hybrid object (held strongly for the lifetime of the stream).
    let parent: HybridLiteRTLM

    // MARK: Accumulators mutated as chunks arrive

    /// Concatenation of every raw chunk received so far.
    var rawResponse: String = ""
    /// Cleaned, final reply text (filled in when the last chunk arrives).
    var fullResponse: String = ""
    /// Number of characters of the cleaned reply already delivered to `onToken`.
    var lastEmittedLength: Int = 0
    /// Number of `onToken` deltas emitted during streaming.
    var tokenCount: Int = 0

    init(
        userMessage: String,
        startTime: Date,
        onToken: @escaping (_ token: String, _ done: Bool) -> Void,
        promise: Promise<Void>,
        parent: HybridLiteRTLM
    ) {
        self.parent = parent
        self.promise = promise
        self.onToken = onToken
        self.startTime = startTime
        self.userMessage = userMessage
    }
}
|
|
41
|
+
|
|
42
|
+
public class HybridLiteRTLM: HybridLiteRTLMSpec_base, HybridLiteRTLMSpec_protocol {
|
|
43
|
+
|
|
44
|
+
/// Serial queue on which this class's public methods run their engine and state work
/// (`queue.sync` for the cheap getters, `queue.async` for loading and inference), keeping
/// the blocking native calls off the JS thread.
private let queue: DispatchQueue = DispatchQueue(label: "dev.litert.engine", qos: .userInteractive)

/// Handle to the native LiteRT-LM engine; `nil` until `loadModel` succeeds.
private var engine: OpaquePointer? = nil

/// Handle to the native conversation created from `engine`.
private var conversation: OpaquePointer? = nil

/// Set by `loadModel` once both the engine and a conversation exist.
private var isLoaded: Bool = false

/// User/model turns appended after each completed request.
private var history: [Message] = []

/// Statistics of the most recent completed generation; zeroed initially and by
/// `resetConversation()`.
private var lastStats: GenerationStats = GenerationStats(
    promptTokens: 0.0,
    completionTokens: 0.0,
    totalTokens: 0.0,
    timeToFirstToken: 0.0,
    totalTime: 0.0,
    tokensPerSecond: 0.0
)

// MARK: Generation settings (defaults; `loadModel` merges `LLMConfig` into these)

private var backend: Backend = .cpu
private var temperature: Double = 0.7
private var topK: Int = 40
private var topP: Double = 0.95
private var maxTokens: Int = 1024
private var systemPrompt: String? = nil
private var tools: [ToolDefinition]? = nil
private var enableSpeculativeDecoding: Bool = false
|
|
78
|
+
|
|
79
|
+
/// Fixed, unmeasured hint (1 GiB) of the native memory attributed to this object,
/// intended to make the JS engine's garbage collector account for the model weights.
public var memorySize: Int {
    1 << 30 // ~1GB proxy
}
|
|
83
|
+
|
|
84
|
+
/// Frees the native engine and conversation when the hybrid object is destroyed.
///
/// NOTE(review): unlike `close()`, this calls `closeInternal()` directly instead of through
/// `queue`. The closures this class enqueues capture `self` strongly, and `StreamContext`
/// keeps a strong `parent` reference, so none of that work should still be outstanding
/// once deinit runs.
deinit {
    self.closeInternal()
}
|
|
87
|
+
|
|
88
|
+
// MARK: - Core Hybrid Object API
|
|
89
|
+
|
|
90
|
+
/// Returns the `isLoaded` flag, which `loadModel` sets after creating the engine and a
/// conversation.
///
/// NOTE(review): `queue.sync` waits behind anything already queued on `queue`; for example,
/// `sendMessage` runs its native call synchronously there. The caller can therefore block
/// for the length of a generation.
public func isReady() throws -> Bool {
    let ready: Bool = queue.sync { self.isLoaded }
    return ready
}
|
|
93
|
+
|
|
94
|
+
/// Returns a snapshot (value copy) of the recorded user/model turns, read on `queue`.
public func getHistory() throws -> [Message] {
    let snapshot: [Message] = queue.sync { self.history }
    return snapshot
}
|
|
97
|
+
|
|
98
|
+
/// Clears the recorded history and zeroes the generation stats on `queue`.
///
/// If a model is loaded (`isLoaded` and a non-nil `engine`), it also starts a fresh native
/// conversation via `createNewConversation()`.
public func resetConversation() throws {
    queue.sync {
        self.history = []
        self.lastStats = GenerationStats(
            promptTokens: 0.0,
            completionTokens: 0.0,
            totalTokens: 0.0,
            timeToFirstToken: 0.0,
            totalTime: 0.0,
            tokensPerSecond: 0.0
        )

        guard self.isLoaded, self.engine != nil else { return }
        self.createNewConversation()
    }
}
|
|
114
|
+
|
|
115
|
+
/// Returns the statistics of the most recent completed generation, read on `queue`.
public func getStats() throws -> GenerationStats {
    let stats: GenerationStats = queue.sync { self.lastStats }
    return stats
}
|
|
118
|
+
|
|
119
|
+
/// Counts the tokens the loaded model's tokenizer produces for `text`.
///
/// - Parameter text: Input to tokenize.
/// - Returns: The token count, or `-1` when no engine is loaded or tokenization fails.
public func countTokens(text: String) throws -> Double {
    queue.sync { () -> Double in
        guard let engine = self.engine,
              let tokenized = litert_lm_engine_tokenize(engine, text) else {
            return -1.0
        }
        // The native tokenize result must be freed on every path once it exists.
        defer { litert_lm_tokenize_result_delete(tokenized) }
        return Double(litert_lm_tokenize_result_get_num_tokens(tokenized))
    }
}
|
|
132
|
+
|
|
133
|
+
/// Takes a snapshot of this process's memory so callers can react before iOS terminates the app.
///
/// - `residentBytes`: resident set size from `MACH_TASK_BASIC_INFO`, or 0 if the query fails.
/// - `nativeHeapBytes`: physical footprint from `TASK_VM_INFO`. This is dirty plus compressed
///   memory charged to the process. It falls back to `residentBytes` if that query fails.
/// - `availableMemoryBytes`: headroom reported by `os_proc_available_memory()`.
/// - `isLowMemory`: true when that headroom is known (non-zero) and below ~200 MB.
public func getMemoryUsage() throws -> MemoryUsage {
    // Resident set size via Mach basic task info.
    var residentBytes = 0.0
    var basicInfo = mach_task_basic_info()
    var basicCount = mach_msg_type_number_t(MemoryLayout<mach_task_basic_info>.size / MemoryLayout<integer_t>.size)
    let basicResult = withUnsafeMutablePointer(to: &basicInfo) {
        $0.withMemoryRebound(to: integer_t.self, capacity: Int(basicCount)) {
            task_info(mach_task_self_, task_flavor_t(MACH_TASK_BASIC_INFO), $0, &basicCount)
        }
    }
    if basicResult == KERN_SUCCESS {
        residentBytes = Double(basicInfo.resident_size)
    }

    // RSS also counts clean, file-backed pages, such as memory-mapped model weights, which
    // the kernel can drop at will. phys_footprint is the metric memory limits are enforced
    // against, so it is a meaningful "native memory in use" figure. The old code simply
    // repeated RSS here.
    var nativeHeapBytes = residentBytes
    var vmInfo = task_vm_info_data_t()
    var vmCount = mach_msg_type_number_t(MemoryLayout<task_vm_info_data_t>.size / MemoryLayout<integer_t>.size)
    let vmResult = withUnsafeMutablePointer(to: &vmInfo) {
        $0.withMemoryRebound(to: integer_t.self, capacity: Int(vmCount)) {
            task_info(mach_task_self_, task_flavor_t(TASK_VM_INFO), $0, &vmCount)
        }
    }
    if vmResult == KERN_SUCCESS {
        nativeHeapBytes = Double(vmInfo.phys_footprint)
    }

    // os_proc_available_memory reports headroom before Jetsam termination (iOS 13+).
    let availableBytes = Double(os_proc_available_memory())

    // Flag low memory at ~200MB remaining headroom. A value of 0 is treated as "no known
    // limit" rather than "out of memory". NOTE(review): 0 is reportedly returned where no
    // limit applies (e.g. simulator); confirm on target devices.
    let lowMemoryThreshold = 200.0 * 1024.0 * 1024.0
    let isLowMemory = availableBytes > 0.0 && availableBytes < lowMemoryThreshold

    return MemoryUsage(
        nativeHeapBytes: nativeHeapBytes,
        residentBytes: residentBytes,
        availableMemoryBytes: availableBytes,
        isLowMemory: isLowMemory
    )
}
|
|
164
|
+
|
|
165
|
+
/// Releases the native engine and conversation via `closeInternal()`.
///
/// The call is made with `queue.sync`, so it runs after any work already queued there and
/// never concurrently with it.
public func close() throws {
    queue.sync {
        self.closeInternal()
    }
}
|
|
170
|
+
|
|
171
|
+
// MARK: - Async Operations
|
|
172
|
+
|
|
173
|
+
/// Loads the model at `modelPath`, replacing any previously loaded engine and conversation.
///
/// All work runs on `queue`. Non-`nil` fields of `config` overwrite the stored settings, and
/// `nil` fields keep their previous values. The exceptions are `tools` and
/// `enableSpeculativeDecoding`, which are always reset (to `nil` / `false` when absent).
///
/// Engine creation tries the requested backend first, then CPU-based fallbacks:
/// - Multimodal models try CPU with GPU vision, then all-CPU, then text-only CPU.
/// - Text-only models try text-only CPU.
///
/// Fallbacks run even when the primary backend is already CPU. Each distinct configuration
/// is attempted at most once. If a fallback succeeds, the stored backend becomes `.cpu`.
///
/// - Parameters:
///   - modelPath: Filesystem path of the model file. Its parent directory is used as the
///     engine cache directory.
///   - config: Optional overrides for backend, sampling, limits, tools and multimodality.
/// - Returns: A promise that resolves once an engine and a conversation exist. It rejects
///   with code 404 if `modelPath` does not exist, or 500 if the engine or conversation
///   cannot be created.
public func loadModel(modelPath: String, config: LLMConfig?) throws -> Promise<Void> {
    let promise = Promise<Void>()

    queue.async {
        // Teardown any previous contexts
        self.closeInternal()

        // Merge configuration into the stored defaults.
        if let config = config {
            if let b = config.backend { self.backend = b }
            if let t = config.temperature { self.temperature = t }
            if let k = config.topK { self.topK = Int(k) }
            if let p = config.topP { self.topP = p }
            if let m = config.maxTokens { self.maxTokens = Int(m) }
            if let s = config.systemPrompt { self.systemPrompt = s }
            self.tools = config.tools
            self.enableSpeculativeDecoding = config.enableSpeculativeDecoding ?? false
        } else {
            self.tools = nil
            self.enableSpeculativeDecoding = false
        }

        // Fail fast with a clear message. Otherwise an unreadable path would walk the whole
        // fallback chain, building a full engine on each attempt.
        guard FileManager.default.fileExists(atPath: modelPath) else {
            promise.reject(withError: NSError(domain: "LiteRTLM", code: 404, userInfo: [NSLocalizedDescriptionKey: "Model file not found: \(modelPath)"]))
            return
        }

        // Main executor. NOTE(review): `.npu` is mapped to "gpu" (no separate NPU executor
        // string is passed on iOS); confirm this is intended.
        let mainBackend = self.backend == .gpu ? "gpu" : (self.backend == .npu ? "gpu" : "cpu")

        // Multimodal support: explicit config wins, otherwise a file-name heuristic.
        let lowercasedPath = modelPath.lowercased()
        let isMultimodal = config?.multimodal ?? (lowercasedPath.contains("3n") || lowercasedPath.contains("gemma3"))

        // Set LiteRT C Log Level to WARNING (2) for clean production output
        litert_lm_set_min_log_level(2)

        // Probe speculative-decoding (MTP) support once. The answer depends only on the
        // file, so there is no need to reopen it for every backend attempt.
        let useSpeculativeDecoding: Bool = {
            guard self.enableSpeculativeDecoding,
                  let loadedFile = litert_lm_loaded_file_create((modelPath as NSString).utf8String) else {
                return false
            }
            defer { litert_lm_loaded_file_delete(loadedFile) }
            return litert_lm_loaded_file_has_speculative_decoding_support(loadedFile)
        }()

        // Cache dir set to parent directory of model path
        let cacheDir = (modelPath as NSString).deletingLastPathComponent

        // Builds one engine for an executor combination (nil disables that modality).
        // The C strings are only valid inside the nested closures, so settings are created there.
        let createEngine = { (main: String, vision: String?, audio: String?) -> OpaquePointer? in
            let settings = modelPath.withCString { modelC in
                self.withOptionalCString(main) { mainC in
                    self.withOptionalCString(vision) { visionC in
                        self.withOptionalCString(audio) { audioC in
                            return litert_lm_engine_settings_create(modelC, mainC, visionC, audioC)
                        }
                    }
                }
            }

            guard let s = settings else { return nil }
            defer { litert_lm_engine_settings_delete(s) }

            litert_lm_engine_settings_set_max_num_tokens(s, Int32(self.maxTokens))
            litert_lm_engine_settings_enable_benchmark(s)
            if useSpeculativeDecoding {
                litert_lm_engine_settings_set_enable_speculative_decoding(s, true)
            }
            cacheDir.withCString { cacheC in
                litert_lm_engine_settings_set_cache_dir(s, cacheC)
            }

            return litert_lm_engine_create(s)
        }

        // Primary configuration first, then the CPU fallbacks that differ from it.
        var attempts: [(main: String, vision: String?, audio: String?)] = [
            (main: mainBackend, vision: isMultimodal ? "gpu" : nil, audio: isMultimodal ? "cpu" : nil)
        ]
        let fallbacks: [(main: String, vision: String?, audio: String?)]
        if isMultimodal {
            fallbacks = [
                (main: "cpu", vision: "gpu", audio: "cpu"),  // CPU text, GPU vision
                (main: "cpu", vision: "cpu", audio: "cpu"),  // full CPU, all modalities
                (main: "cpu", vision: nil, audio: nil)       // text-only CPU
            ]
        } else {
            fallbacks = [(main: "cpu", vision: nil, audio: nil)]
        }
        for candidate in fallbacks {
            let alreadyPlanned = attempts.contains {
                $0.main == candidate.main && $0.vision == candidate.vision && $0.audio == candidate.audio
            }
            if !alreadyPlanned {
                attempts.append(candidate)
            }
        }

        var rawEngine: OpaquePointer? = nil
        var usedFallback = false
        for (index, attempt) in attempts.enumerated() {
            if let created = createEngine(attempt.main, attempt.vision, attempt.audio) {
                rawEngine = created
                usedFallback = index > 0
                break
            }
        }

        // Every fallback runs the main executor on CPU; record that.
        if usedFallback {
            self.backend = .cpu
        }

        guard let engine = rawEngine else {
            promise.reject(withError: NSError(domain: "LiteRTLM", code: 500, userInfo: [NSLocalizedDescriptionKey: "Failed to construct LiteRT-LM engine. Checked backends and fallback chains."]))
            return
        }

        self.engine = engine
        self.createNewConversation()

        guard self.conversation != nil else {
            self.closeInternal()
            promise.reject(withError: NSError(domain: "LiteRTLM", code: 500, userInfo: [NSLocalizedDescriptionKey: "Failed to create conversation context."]))
            return
        }

        self.isLoaded = true
        promise.resolve()
    }

    return promise
}
|
|
288
|
+
|
|
289
|
+
/// Sends a text message and returns the model's complete reply.
///
/// Runs on `queue`; the native call blocks that queue, not the JS thread.
/// - Statistics come from the conversation's benchmark info when available. Otherwise they
///   are estimated at ~4 characters per token, with throughput as tokens / wall time.
/// - On success, the user turn and the model reply are appended to the history.
///
/// - Returns: A promise resolving to the trimmed reply. It rejects with code 400 when no
///   model is loaded, or 500 when the native call fails.
public func sendMessage(message: String) throws -> Promise<String> {
    let pending = Promise<String>()

    queue.async {
        guard let activeConversation = self.conversation else {
            pending.reject(withError: NSError(domain: "LiteRTLM", code: 400, userInfo: [NSLocalizedDescriptionKey: "LiteRTLM: No model loaded. Call loadModel() first."]))
            return
        }

        let payload = self.buildTextMessageJson(text: message)
        let began = Date()

        guard let reply = litert_lm_conversation_send_message(activeConversation, payload, nil, nil) else {
            pending.reject(withError: NSError(domain: "LiteRTLM", code: 500, userInfo: [NSLocalizedDescriptionKey: "LiteRT-LM: sendMessage failed"]))
            return
        }
        // Freed when this work item finishes, after the promise is settled.
        defer { litert_lm_json_response_delete(reply) }

        let answer: String = litert_lm_json_response_get_string(reply).map {
            self.extractTextFromResponse(String(cString: $0))
                .trimmingCharacters(in: .whitespacesAndNewlines)
        } ?? ""

        let elapsed = Date().timeIntervalSince(began)

        var decodeRate = 0.0
        var generated = 0.0
        var firstTokenLatency = 0.0
        if let bench = litert_lm_conversation_get_benchmark_info(activeConversation) {
            let turns = litert_lm_benchmark_info_get_num_decode_turns(bench)
            if turns > 0 {
                decodeRate = litert_lm_benchmark_info_get_decode_tokens_per_sec_at(bench, turns - 1)
                generated = Double(litert_lm_benchmark_info_get_decode_token_count_at(bench, turns - 1))
            }
            firstTokenLatency = litert_lm_benchmark_info_get_time_to_first_token(bench)
            litert_lm_benchmark_info_delete(bench)
        }

        let estimatedPrompt = Double(message.count) / 4.0
        if generated == 0.0 {
            generated = Double(answer.count) / 4.0
        }
        let throughput = decodeRate > 0.0 ? decodeRate : (generated / elapsed)

        self.lastStats = GenerationStats(
            promptTokens: estimatedPrompt,
            completionTokens: generated,
            totalTokens: estimatedPrompt + generated,
            timeToFirstToken: firstTokenLatency,
            totalTime: elapsed,
            tokensPerSecond: throughput
        )

        self.history.append(Message(role: .user, content: message))
        self.history.append(Message(role: .model, content: answer))

        pending.resolve(withResult: answer)
    }

    return pending
}
|
|
354
|
+
|
|
355
|
+
/// Streams a reply to `message`, delivering incremental text through `onToken`.
///
/// The stream is started on `queue`. Chunks arrive through a C callback that receives a
/// retained `StreamContext` as `callback_data`.
/// - `onToken` receives text deltas with `done == false`, then `("", true)` on success.
/// - On a native error, `onToken` receives `("Error: <message>", true)` and the promise rejects.
///
/// The callback may run on a native thread. Stats and history are therefore published on
/// `queue` before the promise resolves, which serializes them with every other accessor of
/// that state. If the conversation has been closed by then, benchmark numbers are skipped.
///
/// - Returns: A promise that resolves when the stream completes. It rejects with 400 if no
///   model is loaded, with the native status code if the stream cannot be started, and
///   with 500 on an engine error.
public func sendMessageAsync(
    message: String,
    onToken: @escaping (_ token: String, _ done: Bool) -> Void
) throws -> Promise<Void> {
    let promise = Promise<Void>()

    queue.async {
        guard let conversation = self.conversation else {
            promise.reject(withError: NSError(domain: "LiteRTLM", code: 400, userInfo: [NSLocalizedDescriptionKey: "LiteRTLM: No model loaded. Call loadModel() first."]))
            return
        }

        let msgJson = self.buildTextMessageJson(text: message)
        let startTime = Date()

        let context = StreamContext(
            userMessage: message,
            startTime: startTime,
            onToken: onToken,
            promise: promise,
            parent: self
        )

        // +1 retain owned by the native stream, balanced by exactly one release below.
        let callbackData = Unmanaged.passRetained(context).toOpaque()

        // C function pointer: this closure must not capture anything. Everything it needs
        // is reached through `callbackData`.
        let callback: LiteRtLmStreamCallback = { callbackData, chunk, isFinal, errorMsg in
            guard let callbackData = callbackData else { return }
            let retained = Unmanaged<StreamContext>.fromOpaque(callbackData)
            let ctx = retained.takeUnretainedValue()

            if let errorMsg = errorMsg {
                let errorStr = String(cString: errorMsg)
                ctx.onToken("Error: \(errorStr)", true)
                ctx.promise.reject(withError: NSError(domain: "LiteRTLM", code: 500, userInfo: [NSLocalizedDescriptionKey: errorStr]))
                retained.release()
                return
            }

            if isFinal {
                let totalTime = Date().timeIntervalSince(ctx.startTime)

                var finalCleaned = ctx.parent.stripControlTokens(ctx.rawResponse)
                    .trimmingCharacters(in: .whitespacesAndNewlines)
                if !ctx.userMessage.isEmpty && finalCleaned.hasPrefix(ctx.userMessage) {
                    finalCleaned = String(finalCleaned.dropFirst(ctx.userMessage.count))
                        .trimmingCharacters(in: .whitespacesAndNewlines)
                }

                // Flush any text the incremental path held back.
                if finalCleaned.count > ctx.lastEmittedLength {
                    let startIdx = finalCleaned.index(finalCleaned.startIndex, offsetBy: ctx.lastEmittedLength)
                    ctx.onToken(String(finalCleaned[startIdx...]), false)
                }
                ctx.fullResponse = finalCleaned

                // Shared HybridLiteRTLM state (conversation, lastStats, history) is only
                // touched on the serial queue, never from the callback thread.
                let parent = ctx.parent
                parent.queue.async {
                    var completionTokens = Double(ctx.tokenCount)
                    var tokensPerSecond = 0.0
                    var ttft = 0.0

                    if let activeConversation = parent.conversation,
                       let benchInfo = litert_lm_conversation_get_benchmark_info(activeConversation) {
                        let numDecodeTurns = litert_lm_benchmark_info_get_num_decode_turns(benchInfo)
                        if numDecodeTurns > 0 {
                            let lastIdx = numDecodeTurns - 1
                            tokensPerSecond = litert_lm_benchmark_info_get_decode_tokens_per_sec_at(benchInfo, lastIdx)
                            completionTokens = Double(litert_lm_benchmark_info_get_decode_token_count_at(benchInfo, lastIdx))
                        }
                        ttft = litert_lm_benchmark_info_get_time_to_first_token(benchInfo)
                        litert_lm_benchmark_info_delete(benchInfo)
                    }

                    let promptTokens = Double(ctx.userMessage.count) / 4.0
                    if completionTokens == 0.0 {
                        completionTokens = Double(ctx.fullResponse.count) / 4.0
                    }
                    if tokensPerSecond <= 0.0 {
                        // Avoid inf/NaN when the measured duration is zero.
                        tokensPerSecond = totalTime > 0.0 ? completionTokens / totalTime : 0.0
                    }

                    parent.lastStats = GenerationStats(
                        promptTokens: promptTokens,
                        completionTokens: completionTokens,
                        totalTokens: promptTokens + completionTokens,
                        timeToFirstToken: ttft,
                        totalTime: totalTime,
                        tokensPerSecond: tokensPerSecond
                    )

                    parent.history.append(Message(role: .user, content: ctx.userMessage))
                    parent.history.append(Message(role: .model, content: ctx.fullResponse))

                    ctx.onToken("", true)
                    ctx.promise.resolve()
                    retained.release()
                }
                return
            }

            if let chunk = chunk {
                let token = String(cString: chunk)
                // Some chunks arrive as a JSON message object rather than raw text.
                let raw: String
                if token.hasPrefix("{") && token.contains("\"role\"") {
                    raw = ctx.parent.extractTextFromResponse(token)
                } else {
                    raw = token
                }

                ctx.rawResponse += raw
                let cleaned = ctx.parent.stripControlTokens(ctx.rawResponse)
                    .trimmingLeadingCharacters(in: .whitespacesAndNewlines)

                var processed = cleaned
                if !ctx.userMessage.isEmpty && processed.hasPrefix(ctx.userMessage) {
                    processed = String(processed.dropFirst(ctx.userMessage.count))
                        .trimmingLeadingCharacters(in: .whitespacesAndNewlines)
                }

                // Only emit up to a length `safeEmitLength` considers stable.
                let safeLen = ctx.parent.safeEmitLength(processed)
                if safeLen > ctx.lastEmittedLength {
                    let chars = Array(processed)
                    let newText = String(chars[ctx.lastEmittedLength..<safeLen])
                    ctx.lastEmittedLength = safeLen
                    ctx.tokenCount += 1
                    ctx.onToken(newText, false)
                }
            }
        }

        let status = litert_lm_conversation_send_message_stream(
            conversation,
            msgJson,
            nil,
            nil,
            callback,
            callbackData
        )

        if status != 0 {
            Unmanaged<StreamContext>.fromOpaque(callbackData).release()
            promise.reject(withError: NSError(domain: "LiteRTLM", code: Int(status), userInfo: [NSLocalizedDescriptionKey: "Failed to start streaming conversation."]))
        }
    }

    return promise
}
|
|
495
|
+
|
|
496
|
+
/// Sends `message` together with the image at `imagePath` and returns the model's reply.
///
/// Runs on `queue`. Statistics are read from the conversation's benchmark info (decode
/// token count, decode tokens/sec, time-to-first-token), matching `sendMessage`.
/// Character-count estimates (~4 characters per token) are used only where the benchmark
/// supplies no value. NOTE(review): the prompt estimate covers the text only, not image
/// tokens.
///
/// - Returns: A promise resolving to the trimmed reply. It rejects with 400 if no model is
///   loaded, 404 if the image file is missing, or 500 if the native call fails.
public func sendMessageWithImage(message: String, imagePath: String) throws -> Promise<String> {
    let promise = Promise<String>()

    queue.async {
        guard let conversation = self.conversation else {
            promise.reject(withError: NSError(domain: "LiteRTLM", code: 400, userInfo: [NSLocalizedDescriptionKey: "LiteRTLM: No model loaded. Call loadModel() first."]))
            return
        }

        if !FileManager.default.fileExists(atPath: imagePath) {
            promise.reject(withError: NSError(domain: "LiteRTLM", code: 404, userInfo: [NSLocalizedDescriptionKey: "Image file not found: \(imagePath)"]))
            return
        }

        let msgJson = self.buildImageMessageJson(text: message, imagePath: imagePath)
        let startTime = Date()

        guard let response = litert_lm_conversation_send_message(conversation, msgJson, nil, nil) else {
            promise.reject(withError: NSError(domain: "LiteRTLM", code: 500, userInfo: [NSLocalizedDescriptionKey: "LiteRT-LM: sendMessageWithImage failed"]))
            return
        }
        defer { litert_lm_json_response_delete(response) }

        var result = ""
        if let responseStr = litert_lm_json_response_get_string(response) {
            result = self.extractTextFromResponse(String(cString: responseStr))
                .trimmingCharacters(in: .whitespacesAndNewlines)
        }

        let totalTime = Date().timeIntervalSince(startTime)

        // The engine is created with benchmarking enabled, so prefer measured numbers.
        var completionTokens = 0.0
        var tokensPerSecond = 0.0
        var ttft = 0.0
        if let benchInfo = litert_lm_conversation_get_benchmark_info(conversation) {
            let numDecodeTurns = litert_lm_benchmark_info_get_num_decode_turns(benchInfo)
            if numDecodeTurns > 0 {
                let lastIdx = numDecodeTurns - 1
                tokensPerSecond = litert_lm_benchmark_info_get_decode_tokens_per_sec_at(benchInfo, lastIdx)
                completionTokens = Double(litert_lm_benchmark_info_get_decode_token_count_at(benchInfo, lastIdx))
            }
            ttft = litert_lm_benchmark_info_get_time_to_first_token(benchInfo)
            litert_lm_benchmark_info_delete(benchInfo)
        }

        let promptTokens = Double(message.count) / 4.0
        if completionTokens == 0.0 {
            completionTokens = Double(result.count) / 4.0
        }
        if tokensPerSecond <= 0.0 {
            // Avoid inf/NaN when the measured duration is zero.
            tokensPerSecond = totalTime > 0.0 ? completionTokens / totalTime : 0.0
        }

        self.lastStats = GenerationStats(
            promptTokens: promptTokens,
            completionTokens: completionTokens,
            totalTokens: promptTokens + completionTokens,
            timeToFirstToken: ttft,
            totalTime: totalTime,
            tokensPerSecond: tokensPerSecond
        )

        self.history.append(Message(role: .user, content: message + " [image: \(imagePath)]"))
        self.history.append(Message(role: .model, content: result))

        promise.resolve(withResult: result)
    }

    return promise
}
|
|
545
|
+
|
|
546
|
+
/// Sends `message` together with the audio file at `audioPath` and returns the model's reply.
///
/// Runs on `queue`. Statistics are read from the conversation's benchmark info (decode
/// token count, decode tokens/sec, time-to-first-token), matching `sendMessage`.
/// Character-count estimates (~4 characters per token) are used only where the benchmark
/// supplies no value. NOTE(review): the prompt estimate covers the text only, not audio
/// tokens.
///
/// - Returns: A promise resolving to the trimmed reply. It rejects with 400 if no model is
///   loaded, 404 if the audio file is missing, or 500 if the native call fails.
public func sendMessageWithAudio(message: String, audioPath: String) throws -> Promise<String> {
    let promise = Promise<String>()

    queue.async {
        guard let conversation = self.conversation else {
            promise.reject(withError: NSError(domain: "LiteRTLM", code: 400, userInfo: [NSLocalizedDescriptionKey: "LiteRTLM: No model loaded. Call loadModel() first."]))
            return
        }

        if !FileManager.default.fileExists(atPath: audioPath) {
            promise.reject(withError: NSError(domain: "LiteRTLM", code: 404, userInfo: [NSLocalizedDescriptionKey: "Audio file not found: \(audioPath)"]))
            return
        }

        let msgJson = self.buildAudioMessageJson(text: message, audioPath: audioPath)
        let startTime = Date()

        guard let response = litert_lm_conversation_send_message(conversation, msgJson, nil, nil) else {
            promise.reject(withError: NSError(domain: "LiteRTLM", code: 500, userInfo: [NSLocalizedDescriptionKey: "LiteRT-LM: sendMessageWithAudio failed"]))
            return
        }
        defer { litert_lm_json_response_delete(response) }

        var result = ""
        if let responseStr = litert_lm_json_response_get_string(response) {
            result = self.extractTextFromResponse(String(cString: responseStr))
                .trimmingCharacters(in: .whitespacesAndNewlines)
        }

        let totalTime = Date().timeIntervalSince(startTime)

        // The engine is created with benchmarking enabled, so prefer measured numbers.
        var completionTokens = 0.0
        var tokensPerSecond = 0.0
        var ttft = 0.0
        if let benchInfo = litert_lm_conversation_get_benchmark_info(conversation) {
            let numDecodeTurns = litert_lm_benchmark_info_get_num_decode_turns(benchInfo)
            if numDecodeTurns > 0 {
                let lastIdx = numDecodeTurns - 1
                tokensPerSecond = litert_lm_benchmark_info_get_decode_tokens_per_sec_at(benchInfo, lastIdx)
                completionTokens = Double(litert_lm_benchmark_info_get_decode_token_count_at(benchInfo, lastIdx))
            }
            ttft = litert_lm_benchmark_info_get_time_to_first_token(benchInfo)
            litert_lm_benchmark_info_delete(benchInfo)
        }

        let promptTokens = Double(message.count) / 4.0
        if completionTokens == 0.0 {
            completionTokens = Double(result.count) / 4.0
        }
        if tokensPerSecond <= 0.0 {
            // Avoid inf/NaN when the measured duration is zero.
            tokensPerSecond = totalTime > 0.0 ? completionTokens / totalTime : 0.0
        }

        self.lastStats = GenerationStats(
            promptTokens: promptTokens,
            completionTokens: completionTokens,
            totalTokens: promptTokens + completionTokens,
            timeToFirstToken: ttft,
            totalTime: totalTime,
            tokensPerSecond: tokensPerSecond
        )

        self.history.append(Message(role: .user, content: message + " [audio: \(audioPath)]"))
        self.history.append(Message(role: .model, content: result))

        promise.resolve(withResult: result)
    }

    return promise
}
|
|
595
|
+
|
|
596
|
+
/// Runs a one-shot generation over an ordered list of text / image / audio parts.
///
/// The work is done on `queue`. A dedicated LiteRT-LM session is created for this call. It is
/// separate from `self.conversation`, is configured from the current `maxTokens`, `topK`,
/// `topP` and `temperature`, and is deleted before the closure returns. On success:
/// - the first candidate's text is trimmed and becomes the result;
/// - `lastStats` is refreshed;
/// - a textual summary of the user turn and the model reply is appended to `history`.
///
/// - Parameter parts: Parts in the order they are fed to the model. A text part with a `nil`
///   `text`, or an image/audio part with a `nil` buffer, is skipped.
/// - Returns: A promise that resolves to the generated text. The text is empty if the engine
///   returned no candidates. The promise rejects with an `NSError` in domain `"LiteRTLM"` when:
///   no model is loaded (400); no usable parts were supplied (400); a text copy cannot be
///   allocated (500); or session creation or generation fails (500).
public func sendMultimodalMessage(parts: [MultimodalPart]) throws -> Promise<String> {
    let promise = Promise<String>()

    queue.async {
        guard let engine = self.engine else {
            promise.reject(withError: NSError(domain: "LiteRTLM", code: 400, userInfo: [NSLocalizedDescriptionKey: "LiteRTLM: No model loaded. Call loadModel() first."]))
            return
        }

        // Build the C input array first, so unusable input is rejected before a session is created.
        // Text is copied with strdup so every pointer stays valid until generation has finished.
        // The defer below releases those copies. Image/audio entries point at the parts' own
        // buffers. Those buffers are presumably kept alive by `parts` for this closure's lifetime
        // (TODO confirm).
        var inputs: [LiteRtLmInputData] = []
        inputs.reserveCapacity(parts.count)
        var allocatedStrings: [UnsafeMutablePointer<CChar>] = []
        defer {
            for ptr in allocatedStrings {
                free(ptr)
            }
        }

        for part in parts {
            switch part.type {
            case .text:
                if let text = part.text {
                    guard let cStr = strdup(text) else {
                        promise.reject(withError: NSError(domain: "LiteRTLM", code: 500, userInfo: [NSLocalizedDescriptionKey: "LiteRTLM: Failed to allocate memory for text input."]))
                        return
                    }
                    allocatedStrings.append(cStr)
                    inputs.append(LiteRtLmInputData(type: kLiteRtLmInputDataTypeText, data: cStr, size: text.utf8.count))
                }
            case .image:
                if let imageBuffer = part.imageBuffer {
                    inputs.append(LiteRtLmInputData(type: kLiteRtLmInputDataTypeImage, data: imageBuffer.data, size: imageBuffer.size))
                }
            case .audio:
                if let audioBuffer = part.audioBuffer {
                    inputs.append(LiteRtLmInputData(type: kLiteRtLmInputDataTypeAudio, data: audioBuffer.data, size: audioBuffer.size))
                }
            }
        }

        // Without this check, a zero-length input array would be handed to the engine.
        guard !inputs.isEmpty else {
            promise.reject(withError: NSError(domain: "LiteRTLM", code: 400, userInfo: [NSLocalizedDescriptionKey: "LiteRTLM: No usable message parts were provided."]))
            return
        }

        // Create session config
        guard let sessionConfig = litert_lm_session_config_create() else {
            promise.reject(withError: NSError(domain: "LiteRTLM", code: 500, userInfo: [NSLocalizedDescriptionKey: "LiteRTLM: Failed to create session config."]))
            return
        }
        defer { litert_lm_session_config_delete(sessionConfig) }

        litert_lm_session_config_set_max_output_tokens(sessionConfig, Int32(self.maxTokens))

        var sampler = LiteRtLmSamplerParams()
        sampler.type = kLiteRtLmSamplerTypeTopP
        sampler.top_k = Int32(self.topK)
        sampler.top_p = Float(self.topP)
        sampler.temperature = Float(self.temperature)
        sampler.seed = 0
        withUnsafePointer(to: &sampler) { samplerPtr in
            litert_lm_session_config_set_sampler_params(sessionConfig, samplerPtr)
        }

        guard let session = litert_lm_engine_create_session(engine, sessionConfig) else {
            promise.reject(withError: NSError(domain: "LiteRTLM", code: 500, userInfo: [NSLocalizedDescriptionKey: "LiteRTLM: Failed to create session."]))
            return
        }
        defer { litert_lm_session_delete(session) }

        let startTime = Date()

        // Run session inference
        guard let responses = litert_lm_session_generate_content(session, inputs, inputs.count) else {
            promise.reject(withError: NSError(domain: "LiteRTLM", code: 500, userInfo: [NSLocalizedDescriptionKey: "LiteRTLM: Session generate content failed."]))
            return
        }
        defer { litert_lm_responses_delete(responses) }

        // Only the first candidate is used.
        var result = ""
        if litert_lm_responses_get_num_candidates(responses) > 0,
           let responseStr = litert_lm_responses_get_response_text_at(responses, 0) {
            result = String(cString: responseStr).trimmingCharacters(in: .whitespacesAndNewlines)
        }

        let totalTime = Date().timeIntervalSince(startTime)

        // Prefer the engine's benchmark figures for the most recent decode turn.
        var completionTokens = 0.0
        var tokensPerSecond = 0.0
        var ttft = 0.0
        if let benchInfo = litert_lm_session_get_benchmark_info(session) {
            let numDecodeTurns = litert_lm_benchmark_info_get_num_decode_turns(benchInfo)
            if numDecodeTurns > 0 {
                let lastIdx = numDecodeTurns - 1
                tokensPerSecond = litert_lm_benchmark_info_get_decode_tokens_per_sec_at(benchInfo, lastIdx)
                completionTokens = Double(litert_lm_benchmark_info_get_decode_token_count_at(benchInfo, lastIdx))
            }
            ttft = litert_lm_benchmark_info_get_time_to_first_token(benchInfo)
            litert_lm_benchmark_info_delete(benchInfo)
        }

        // No prompt token count is read back here, so estimate it at ~4 characters per token.
        // Only text parts count, because they are the only text actually sent to the model.
        let promptCharacters = parts.reduce(0) { total, part in
            part.type == .text ? total + (part.text?.count ?? 0) : total
        }
        let promptTokens = Double(promptCharacters) / 4.0
        if completionTokens == 0.0 {
            completionTokens = Double(result.count) / 4.0
        }
        // A zero elapsed time would otherwise produce inf/NaN.
        let measuredRate = totalTime > 0 ? completionTokens / totalTime : 0.0

        self.lastStats = GenerationStats(
            promptTokens: promptTokens,
            completionTokens: completionTokens,
            totalTokens: promptTokens + completionTokens,
            timeToFirstToken: ttft,
            totalTime: totalTime,
            tokensPerSecond: tokensPerSecond > 0.0 ? tokensPerSecond : measuredRate
        )

        // Binary parts are recorded in history as placeholders.
        let userTextRepresentation = parts.compactMap { part -> String? in
            switch part.type {
            case .text: return part.text
            case .image: return "[Image Buffer]"
            case .audio: return "[Audio Buffer]"
            }
        }
        .joined(separator: " ")
        .trimmingCharacters(in: .whitespacesAndNewlines)

        self.history.append(Message(role: .user, content: userTextRepresentation))
        self.history.append(Message(role: .model, content: result))

        promise.resolve(withResult: result)
    }

    return promise
}
|
|
731
|
+
|
|
732
|
+
/// Downloads a model over HTTPS into `<Caches>/litert_models/<fileName>` and resolves with its path.
///
/// If the app Caches directory cannot be found, the temporary directory is used as the base.
/// If a non-empty regular file already exists at the destination, it is reused without any
/// network access.
///
/// - Parameters:
///   - url: Source URL. Only `https` is accepted.
///   - fileName: A bare file name. It must not be empty or `"."`, and must not contain `..`,
///     `/` or `\`.
///   - onProgress: Optional progress callback with values in `0.0...1.0`. It is called with
///     `0.0` before the transfer starts, with throttled intermediate values (at most ~10 per
///     second, and only when the server reports a length), and with `1.0` on success. It may be
///     invoked off the main thread.
/// - Returns: A promise that resolves to the absolute destination path. It rejects on:
///   - an invalid file name or non-HTTPS URL (400);
///   - a transport error;
///   - a non-2xx HTTP status (code = status);
///   - a missing temporary file (500);
///   - a file-system error.
public func downloadModel(
    url: String,
    fileName: String,
    onProgress: ((Double) -> Void)?
) throws -> Promise<String> {
    let promise = Promise<String>()

    queue.async {
        do {
            // "" or "." would resolve to the models directory itself. That directory would then pass
            // the cache check, or be removed when the download completes.
            if fileName.isEmpty || fileName == "." {
                promise.reject(withError: NSError(domain: "LiteRTLM", code: 400, userInfo: [NSLocalizedDescriptionKey: "Invalid filename: a non-empty file name is required."]))
                return
            }
            if fileName.contains("..") || fileName.contains("/") || fileName.contains("\\") {
                promise.reject(withError: NSError(domain: "LiteRTLM", code: 400, userInfo: [NSLocalizedDescriptionKey: "Invalid filename: path traversal or directory separators are not allowed."]))
                return
            }

            let cachesDir = NSSearchPathForDirectoriesInDomains(.cachesDirectory, .userDomainMask, true).first ?? NSTemporaryDirectory()
            let modelsDir = (cachesDir as NSString).appendingPathComponent("litert_models")

            let fileManager = FileManager.default
            if !fileManager.fileExists(atPath: modelsDir) {
                try fileManager.createDirectory(atPath: modelsDir, withIntermediateDirectories: true, attributes: nil)
            }

            let destPath = (modelsDir as NSString).appendingPathComponent(fileName)

            // Fast cache check: reuse only a non-empty regular file, never a directory.
            var isDirectory: ObjCBool = false
            if fileManager.fileExists(atPath: destPath, isDirectory: &isDirectory), !isDirectory.boolValue {
                let attrs = try fileManager.attributesOfItem(atPath: destPath)
                if let fileSize = attrs[.size] as? UInt64, fileSize > 0 {
                    onProgress?(1.0)
                    promise.resolve(withResult: destPath)
                    return
                }
            }

            guard let downloadUrl = URL(string: url), downloadUrl.scheme?.lowercased() == "https" else {
                promise.reject(withError: NSError(domain: "LiteRTLM", code: 400, userInfo: [NSLocalizedDescriptionKey: "Invalid download URL: HTTPS is required for security."]))
                return
            }

            onProgress?(0.0)

            let sessionConfig = URLSessionConfiguration.default
            sessionConfig.timeoutIntervalForRequest = 30
            sessionConfig.timeoutIntervalForResource = 3600

            let session = URLSession(configuration: sessionConfig)
            // Assigned below, before resume(). The completion handler invalidates it.
            var progressHandler: NSKeyValueObservation?

            let task = session.downloadTask(with: downloadUrl) { location, response, error in
                progressHandler?.invalidate()

                if let error = error {
                    promise.reject(withError: error)
                    return
                }

                // Only 2xx responses are treated as a model payload. Other statuses (e.g. 204/304)
                // would otherwise be moved into place as the model file.
                if let httpResponse = response as? HTTPURLResponse, !(200...299).contains(httpResponse.statusCode) {
                    promise.reject(withError: NSError(domain: "LiteRTLM", code: httpResponse.statusCode, userInfo: [NSLocalizedDescriptionKey: "HTTP \(httpResponse.statusCode)"]))
                    return
                }

                guard let location = location else {
                    promise.reject(withError: NSError(domain: "LiteRTLM", code: 500, userInfo: [NSLocalizedDescriptionKey: "No download location found."]))
                    return
                }

                // The temporary file must be moved before this handler returns.
                do {
                    if fileManager.fileExists(atPath: destPath) {
                        try fileManager.removeItem(atPath: destPath)
                    }
                    try fileManager.moveItem(at: location, to: URL(fileURLWithPath: destPath))
                    onProgress?(1.0)
                    promise.resolve(withResult: destPath)
                } catch {
                    promise.reject(withError: error)
                }
            }

            if let onProgress = onProgress {
                var lastUpdate = Date()
                progressHandler = task.observe(\.countOfBytesReceived, options: [.new]) { task, _ in
                    let expected = task.countOfBytesExpectedToReceive
                    // expected is <= 0 when the server does not report a length.
                    if expected > 0 {
                        let now = Date()
                        // Throttled progress notifications to 10Hz
                        if now.timeIntervalSince(lastUpdate) > 0.1 {
                            let progress = Double(task.countOfBytesReceived) / Double(expected)
                            onProgress(progress)
                            lastUpdate = now
                        }
                    }
                }
            }

            task.resume()
            // Let the in-flight task finish, then release the session.
            session.finishTasksAndInvalidate()
        } catch {
            promise.reject(withError: error)
        }
    }

    return promise
}
|
|
835
|
+
|
|
836
|
+
/// Deletes `<Caches>/litert_models/<fileName>` if it exists.
///
/// If the app Caches directory cannot be found, the temporary directory is used as the base.
/// Deleting a file that does not exist is not an error.
///
/// NOTE(review): whenever the file existed and a model is loaded, this calls `closeInternal()`.
/// It does so even if the loaded model came from a different file, because no path
/// comparison is made here. Confirm that this is intended.
///
/// - Parameter fileName: A bare file name. It must not be empty or `"."`, and must not contain
///   `..`, `/` or `\`.
/// - Returns: A promise that resolves when done. It rejects with a 400 `NSError` for an invalid
///   name, or with the file-system error if removal fails.
public func deleteModel(fileName: String) throws -> Promise<Void> {
    let promise = Promise<Void>()

    queue.async {
        do {
            // "" or "." would resolve to the models directory itself, and removeItem would then
            // wipe every downloaded model.
            if fileName.isEmpty || fileName == "." {
                promise.reject(withError: NSError(domain: "LiteRTLM", code: 400, userInfo: [NSLocalizedDescriptionKey: "Invalid filename: a non-empty file name is required."]))
                return
            }
            if fileName.contains("..") || fileName.contains("/") || fileName.contains("\\") {
                promise.reject(withError: NSError(domain: "LiteRTLM", code: 400, userInfo: [NSLocalizedDescriptionKey: "Invalid filename: path traversal or directory separators are not allowed."]))
                return
            }

            let cachesDir = NSSearchPathForDirectoriesInDomains(.cachesDirectory, .userDomainMask, true).first ?? NSTemporaryDirectory()
            let modelsDir = (cachesDir as NSString).appendingPathComponent("litert_models")
            let destPath = (modelsDir as NSString).appendingPathComponent(fileName)

            let fileManager = FileManager.default
            if fileManager.fileExists(atPath: destPath) {
                try fileManager.removeItem(atPath: destPath)
                if self.isLoaded {
                    self.closeInternal()
                }
            }
            promise.resolve()
        } catch {
            promise.reject(withError: error)
        }
    }

    return promise
}
|
|
865
|
+
|
|
866
|
+
// MARK: - Internal Engine Helpers
|
|
867
|
+
|
|
868
|
+
/// Replaces `conversation` with a new one built from the current engine and settings.
///
/// The new conversation is configured with:
/// - `maxTokens`;
/// - a top-p sampler using `topK`, `topP`, `temperature` and seed 0;
/// - the optional `systemPrompt`, sent as a `{"role":"system","content":...}` message;
/// - the optional `tools`, serialized as `[{"type":"function","function":{name, description, parameters}}]`.
///   `parameters` is omitted when a tool's `parametersJson` does not parse.
///
/// Failures are silent:
/// - Without an engine, this returns immediately and leaves `conversation` untouched.
/// - If a config cannot be created, the previous conversation has already been deleted,
///   so `conversation` is `nil`.
/// - `conversation` is whatever `litert_lm_conversation_create` returns, presumably `nil` on failure.
///
/// Callers must check `conversation` afterwards.
private func createNewConversation() {
    guard let engine = engine else { return }

    if let previous = conversation {
        litert_lm_conversation_delete(previous)
        conversation = nil
    }

    // Declaration order fixes the deferred cleanup order: session config first, then conversation config.
    guard let convConfig = litert_lm_conversation_config_create() else { return }
    defer { litert_lm_conversation_config_delete(convConfig) }

    guard let sessionConfig = litert_lm_session_config_create() else { return }
    defer { litert_lm_session_config_delete(sessionConfig) }

    litert_lm_session_config_set_max_output_tokens(sessionConfig, Int32(maxTokens))

    var samplerParams = LiteRtLmSamplerParams()
    samplerParams.type = kLiteRtLmSamplerTypeTopP
    samplerParams.top_k = Int32(topK)
    samplerParams.top_p = Float(topP)
    samplerParams.temperature = Float(temperature)
    samplerParams.seed = 0
    withUnsafePointer(to: &samplerParams) { litert_lm_session_config_set_sampler_params(sessionConfig, $0) }

    litert_lm_conversation_config_set_session_config(convConfig, sessionConfig)

    if let prompt = systemPrompt {
        let systemMessage = "{\"role\":\"system\",\"content\":\"" + escapeJson(prompt) + "\"}"
        systemMessage.withCString { litert_lm_conversation_config_set_system_message(convConfig, $0) }
    }

    if let toolList = tools, !toolList.isEmpty {
        let toolEntries: [[String: Any]] = toolList.map { tool in
            var function: [String: Any] = ["name": tool.name, "description": tool.description]
            if let paramData = tool.parametersJson.data(using: .utf8),
               let schema = try? JSONSerialization.jsonObject(with: paramData, options: []) {
                function["parameters"] = schema
            }
            return ["type": "function", "function": function]
        }
        if let encoded = try? JSONSerialization.data(withJSONObject: toolEntries, options: []),
           let toolsJson = String(data: encoded, encoding: .utf8) {
            toolsJson.withCString { litert_lm_conversation_config_set_tools(convConfig, $0) }
        }
    }

    conversation = litert_lm_conversation_create(engine, convConfig)
}
|
|
923
|
+
|
|
924
|
+
/// Releases the native conversation and engine and resets local state.
///
/// Afterwards, `isLoaded` is false, `history` is empty, `conversation` and `engine` are `nil`,
/// and `lastStats` is all zeros. The conversation is released before the engine it was
/// created from.
private func closeInternal() {
    isLoaded = false
    history.removeAll()

    if let activeConversation = conversation {
        litert_lm_conversation_delete(activeConversation)
        conversation = nil
    }
    if let activeEngine = engine {
        litert_lm_engine_delete(activeEngine)
        engine = nil
    }

    let zero = 0.0
    lastStats = GenerationStats(
        promptTokens: zero,
        completionTokens: zero,
        totalTokens: zero,
        timeToFirstToken: zero,
        totalTime: zero,
        tokensPerSecond: zero
    )
}
|
|
946
|
+
|
|
947
|
+
// MARK: - String and JSON Preprocessing Helpers
|
|
948
|
+
|
|
949
|
+
/// Chat-template control tokens that are removed from model output.
/// The order matters: the longer `<start_of_turn>model` / `<start_of_turn>user` entries come
/// before the bare `<start_of_turn>`. A sequential removal pass therefore strips the role
/// suffix together with the marker.
private let kControlTokens: [String] = [
    "<end_of_turn>",
    "<start_of_turn>model",
    "<start_of_turn>user",
    "<start_of_turn>",
    "<eos>",
]
|
|
956
|
+
|
|
957
|
+
/// Escapes `input` for embedding between double quotes in a JSON string literal.
///
/// The function walks Unicode scalars, not `Character`s. `"\r\n"` is a single grapheme cluster
/// in Swift, so a `Character`-based switch would let CRLF through raw. The escaping rules are:
/// - `"`, `\`, `\n`, `\r`, `\t`, backspace and form feed use their short escapes.
/// - Any other U+0000–U+001F scalar becomes `\uXXXX`, as RFC 8259 §7 requires.
/// - Everything else is copied unchanged.
///
/// - Parameter input: Arbitrary text.
/// - Returns: The escaped text, without surrounding quotes.
private func escapeJson(_ input: String) -> String {
    var output = ""
    output.reserveCapacity(input.utf8.count)
    for scalar in input.unicodeScalars {
        switch scalar {
        case "\"": output += "\\\""
        case "\\": output += "\\\\"
        case "\n": output += "\\n"
        case "\r": output += "\\r"
        case "\t": output += "\\t"
        case "\u{0008}": output += "\\b"
        case "\u{000C}": output += "\\f"
        default:
            if scalar.value < 0x20 {
                // Remaining C0 control characters must be \u-escaped to keep the JSON valid.
                let hex = String(scalar.value, radix: 16, uppercase: true)
                output += "\\u" + String(repeating: "0", count: 4 - hex.count) + hex
            } else {
                output.unicodeScalars.append(scalar)
            }
        }
    }
    return output
}
|
|
973
|
+
|
|
974
|
+
/// Builds a user-role chat message whose content is the plain string `text`.
/// Result shape: `{"role":"user","content":"<escaped text>"}`.
private func buildTextMessageJson(text: String) -> String {
    let escapedText = escapeJson(text)
    return #"{"role":"user","content":"\#(escapedText)"}"#
}
|
|
977
|
+
|
|
978
|
+
/// Builds a user-role chat message that has a text part followed by an image part referenced by file path.
/// Result shape: `{"role":"user","content":[{"type":"text","text":"…"},{"type":"image","path":"…"}]}`.
private func buildImageMessageJson(text: String, imagePath: String) -> String {
    let textPart = #"{"type":"text","text":"\#(escapeJson(text))"}"#
    let imagePart = #"{"type":"image","path":"\#(escapeJson(imagePath))"}"#
    return #"{"role":"user","content":["# + textPart + "," + imagePart + "]}"
}
|
|
984
|
+
|
|
985
|
+
/// Builds a user-role chat message that has a text part followed by an audio part referenced by file path.
/// Result shape: `{"role":"user","content":[{"type":"text","text":"…"},{"type":"audio","path":"…"}]}`.
private func buildAudioMessageJson(text: String, audioPath: String) -> String {
    let textPart = #"{"type":"text","text":"\#(escapeJson(text))"}"#
    let audioPart = #"{"type":"audio","path":"\#(escapeJson(audioPath))"}"#
    return #"{"role":"user","content":["# + textPart + "," + audioPart + "]}"
}
|
|
991
|
+
|
|
992
|
+
/// Removes every occurrence of each entry in `kControlTokens` from `text`.
/// Tokens are removed one after another, in array order.
private func stripControlTokens(_ text: String) -> String {
    return kControlTokens.reduce(text) { remaining, token in
        remaining.replacingOccurrences(of: token, with: "")
    }
}
|
|
999
|
+
|
|
1000
|
+
/// Returns how many leading `Character`s of `text` can be emitted without risking a partial control token.
///
/// A control token can only be unfinished if it starts at the last `<`. When the text from
/// that `<` to the end is a strict prefix of some entry in `kControlTokens`, the returned
/// length stops just before the `<`. Otherwise the whole character count is returned. A
/// fully formed token is not held back.
private func safeEmitLength(_ text: String) -> Int {
    let fullLength = text.count
    guard let angleIndex = text.lastIndex(of: "<") else {
        return fullLength
    }
    let tail = String(text[angleIndex...])
    let mightBePartialToken = kControlTokens.contains { token in
        tail.count < token.count && token.hasPrefix(tail)
    }
    return mightBePartialToken ? text.distance(from: text.startIndex, to: angleIndex) : fullLength
}
|
|
1013
|
+
|
|
1014
|
+
/// Pulls the assistant text out of a JSON chat message and strips control tokens from it.
///
/// The top-level object's `"content"` is handled as follows:
/// - If it is a string, that string is used.
/// - If it is an array of objects, the `"text"` values of the `"type":"text"` entries are
///   concatenated with no separator.
///
/// In every other case the raw input is used instead: unparsable JSON, a non-object top level,
/// a missing `"content"`, or any other content shape. `stripControlTokens` is always applied
/// to the text that is returned.
private func extractTextFromResponse(_ jsonResponse: String) -> String {
    guard
        let data = jsonResponse.data(using: .utf8),
        let parsed = try? JSONSerialization.jsonObject(with: data, options: []),
        let message = parsed as? [String: Any],
        let content = message["content"]
    else {
        return stripControlTokens(jsonResponse)
    }

    if let plainText = content as? String {
        return stripControlTokens(plainText)
    }

    if let contentParts = content as? [[String: Any]] {
        let joinedText = contentParts
            .compactMap { part -> String? in
                guard part["type"] as? String == "text" else { return nil }
                return part["text"] as? String
            }
            .joined()
        return stripControlTokens(joinedText)
    }

    return stripControlTokens(jsonResponse)
}
|
|
1037
|
+
|
|
1038
|
+
/// Calls `block` with a temporary NUL-terminated C string for `string`, or with `nil` when `string` is `nil`.
/// The pointer is only valid for the duration of `block`.
private func withOptionalCString<R>(_ string: String?, _ block: (UnsafePointer<CChar>?) -> R) -> R {
    guard let value = string else {
        return block(nil)
    }
    return value.withCString { pointer in
        block(pointer)
    }
}
|
|
1045
|
+
}
|
|
1046
|
+
|
|
1047
|
+
// MARK: - String Trimming Extension
|
|
1048
|
+
|
|
1049
|
+
private extension String {
    /// Returns the string without its leading run of characters whose Unicode scalars all belong to `characterSet`.
    ///
    /// Returns an empty string when every character qualifies. Trailing characters are never touched.
    func trimmingLeadingCharacters(in characterSet: CharacterSet) -> String {
        let remainder = drop(while: { character in
            character.unicodeScalars.allSatisfy { characterSet.contains($0) }
        })
        return String(remainder)
    }
}
|