@dvai-bridge/android-litert-core 4.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +51 -0
- package/README.md +199 -0
- package/android/build.gradle +180 -0
- package/android/gradle.properties +5 -0
- package/android/settings.gradle +1 -0
- package/android/src/main/AndroidManifest.xml +3 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/Internal/HFTokenizerJson.kt +380 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/Internal/LiteRTEngine.kt +241 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/Internal/LiteRTGenerator.kt +71 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/Internal/LiteRTSampler.kt +105 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/LiteRTBackendError.kt +13 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/LiteRTHandlers.kt +378 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/LiteRTPluginState.kt +199 -0
- package/android/src/test/java/co/deepvoiceai/bridge/litert/core/Internal/LiteRTGeneratorMockTest.kt +234 -0
- package/android/src/test/java/co/deepvoiceai/bridge/litert/core/Internal/LiteRTSamplerTest.kt +136 -0
- package/package.json +19 -0
|
@@ -0,0 +1,380 @@
|
|
|
1
|
+
package co.deepvoiceai.bridge.litert.core.Internal
|
|
2
|
+
|
|
3
|
+
import co.deepvoiceai.bridge.litert.core.LiteRTBackendError
|
|
4
|
+
import kotlinx.serialization.json.Json
|
|
5
|
+
import kotlinx.serialization.json.JsonArray
|
|
6
|
+
import kotlinx.serialization.json.JsonObject
|
|
7
|
+
import kotlinx.serialization.json.JsonPrimitive
|
|
8
|
+
import kotlinx.serialization.json.booleanOrNull
|
|
9
|
+
import kotlinx.serialization.json.contentOrNull
|
|
10
|
+
import kotlinx.serialization.json.intOrNull
|
|
11
|
+
import kotlinx.serialization.json.jsonArray
|
|
12
|
+
import kotlinx.serialization.json.jsonObject
|
|
13
|
+
import kotlinx.serialization.json.jsonPrimitive
|
|
14
|
+
import java.io.File
|
|
15
|
+
|
|
16
|
+
/**
 * Pure-Kotlin BPE tokenizer that loads HuggingFace's standard
 * `tokenizer.json` schema. No JNI, no native library — works on every
 * Android ABI without surprise UnsatisfiedLinkErrors.
 *
 * Why a custom parser instead of an off-the-shelf artifact?
 *  - `com.github.huggingface:tokenizers-android` (JitPack) does not exist —
 *    the URL 401s. The plan's original guess was wrong.
 *  - `ai.djl.huggingface:tokenizers:0.36.0` (Maven Central) is JVM-only:
 *    DJL ships `libtokenizers.so` for x86_64 + aarch64-linux-gnu + macOS +
 *    Windows but NOT for Android (`*-linux-android`). Pulling DJL would
 *    crash the first encode() with UnsatisfiedLinkError on every Android
 *    target ABI.
 *  - HF's official Rust crate has no Android JNI wrapper on Maven.
 *
 * What's supported:
 *  - GPT-2-style byte-level BPE (the Llama-3 tokenizer.json shape):
 *    `model.type == "BPE"`, `model.vocab` as {token: id}, `model.merges` as
 *    space-separated `"A B"` strings OR `["A", "B"]` pairs (HF v0.21+
 *    format), plus the `model.ignore_merges` and `model.unk_token` fields.
 *  - Added tokens via the `added_tokens` array (each entry is
 *    `{ id, content, special }`). Every added token is matched verbatim
 *    (leftmost occurrence first, longest on ties) before pre-tokenization,
 *    so `<|eot_id|>` resolves to a single id.
 *  - Pre-tokenization from `pre_tokenizer`: a `Split` with a `Regex` or
 *    `String` pattern, a `ByteLevel` with `use_regex` (GPT-2 split
 *    pattern), or a `Sequence` of these (the first splitting step found is
 *    used). Every regex match and every gap between matches becomes its
 *    own BPE word ("Isolated" semantics). Its UTF-8 bytes are mapped through
 *    GPT-2's printable byte permutation before merging. Patterns are
 *    compiled with the platform `java.util.regex` engine; on a desktop JVM
 *    its `\s` class is ASCII-only, narrower than HF's.
 *  - decode() reverses the byte-level mapping per token. A token containing
 *    characters outside the byte-level alphabet is emitted as its raw UTF-8
 *    bytes, which is the fallback HF's ByteLevel decoder uses.
 *
 * What's NOT supported (call sites must avoid these models):
 *  - SentencePiece-style tokenizers: `model.type == "Unigram"`, or BPE with
 *    `model.byte_fallback: true` or a `Metaspace` pre-tokenizer/decoder
 *    (e.g. Llama-2 / Gemma exports). load() rejects these; for Gemma
 *    checkpoints the consumer should use the mediapipe backend instead,
 *    which uses LiteRT-LM's bundled SentencePiece.
 *  - Normalizers, `Split` behaviours other than "Isolated" (they are
 *    treated as "Isolated"), and added-token flags such as `lstrip`,
 *    `rstrip`, `single_word` and `normalized`.
 *  - Jinja chat templates from `tokenizer_config.json`. The handler layer
 *    formats messages with a hard-coded Llama-3-style template
 *    (`<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>`)
 *    which works for Llama-3 family checkpoints; non-Llama checkpoints
 *    require the consumer to pre-render the prompt themselves.
 *
 * If the loader encounters an unsupported tokenizer.json shape it throws
 * [LiteRTBackendError.TokenizerLoadFailed] with a precise reason so the
 * caller can fall back to a different backend.
 *
 * Instances hold only read-only state once [load] returns.
 *
 * Mirrors the role of `CoreMLTokenizer.swift` (iOS) without the
 * swift-transformers dependency — there's no equivalent maintained library
 * on Android.
 */
internal class HFTokenizerJson private constructor(
    private val vocab: Map<String, Int>,
    private val idToToken: Map<Int, String>,
    private val mergeRanks: Map<Pair<String, String>, Int>,
    private val specialTokens: Set<String>,
    private val byteToUnicode: Map<Int, Char>,
    private val unicodeToByte: Map<Char, Int>,
    val bosTokenId: Int?,
    val eosTokenId: Int,
    val padTokenId: Int?,
    /** Every `added_tokens` content (special or not); each is matched verbatim by [encode]. */
    private val addedTokens: Set<String> = specialTokens,
    /** Pre-tokenizer split regex; null means each plain segment is BPE'd as a single word. */
    private val preTokenizer: Regex? = null,
    /** `model.ignore_merges`: a pre-tokenized word that is already a vocab entry skips merging. */
    private val ignoreMerges: Boolean = false,
    /** Id used for BPE symbols missing from the vocab; null makes such symbols an error. */
    private val unkTokenId: Int? = vocab["<unk>"],
) {

    /**
     * Encode a string to a token-id list.
     *
     * Pipeline: added-token extraction (verbatim) -> pre-tokenizer split ->
     * UTF-8 bytes -> GPT-2 byte→unicode permutation -> BPE merges -> vocab
     * lookup.
     *
     * @return token ids in input order; empty for an empty string.
     * @throws IllegalStateException if BPE yields a symbol that is not in the
     *   vocab and the tokenizer has no unknown token.
     */
    fun encode(text: String): List<Int> {
        if (text.isEmpty()) return emptyList()
        val out = ArrayList<Int>()
        var cursor = 0
        while (cursor < text.length) {
            val match = findNextAddedToken(text, cursor)
            val plainEnd = match?.start ?: text.length
            // BPE the ordinary text in front of the next added token (or the tail).
            if (plainEnd > cursor) encodePlain(text.substring(cursor, plainEnd), out)
            if (match == null) break
            // load() registers an id for every added token, so getValue cannot miss.
            out.add(vocab.getValue(match.token))
            cursor = match.start + match.token.length
        }
        return out
    }

    private data class TokenMatch(val token: String, val start: Int)

    /**
     * Earliest added-token occurrence at or after [from]. When several tokens
     * start at the same offset, the longest wins. Null if none occur.
     */
    private fun findNextAddedToken(text: String, from: Int): TokenMatch? {
        var best: TokenMatch? = null
        for (candidate in addedTokens) {
            // Guard: an empty token would match at every offset and stall encode().
            if (candidate.isEmpty()) continue
            val idx = text.indexOf(candidate, from)
            if (idx < 0) continue
            val current = best
            val better = current == null ||
                idx < current.start ||
                (idx == current.start && candidate.length > current.token.length)
            if (better) best = TokenMatch(candidate, idx)
        }
        return best
    }

    /** Pre-tokenize [segment] (text without added tokens) and append the BPE ids of every piece to [out]. */
    private fun encodePlain(segment: String, out: MutableList<Int>) {
        val splitter = preTokenizer
        if (splitter == null) {
            out.addAll(bpe(toByteLevel(segment)))
            return
        }
        var consumed = 0
        for (m in splitter.findAll(segment)) {
            val start = m.range.first
            // Text the regex skipped is kept as its own piece ("Isolated" semantics).
            if (start > consumed) out.addAll(bpe(toByteLevel(segment.substring(consumed, start))))
            if (m.value.isNotEmpty()) out.addAll(bpe(toByteLevel(m.value)))
            consumed = m.range.last + 1
        }
        if (consumed < segment.length) out.addAll(bpe(toByteLevel(segment.substring(consumed))))
    }

    /** Map every UTF-8 byte of [piece] to its GPT-2 printable stand-in character. */
    private fun toByteLevel(piece: String): String {
        val bytes = piece.toByteArray(Charsets.UTF_8)
        val mapped = StringBuilder(bytes.size)
        for (b in bytes) {
            mapped.append(byteToUnicode.getValue(b.toInt() and 0xFF))
        }
        return mapped.toString()
    }

    /**
     * Apply BPE merges greedily to a single byte-level-encoded "word".
     *
     * Standard HF BPE algorithm:
     *  0. With `ignore_merges`, a word that is already a vocab entry is
     *     returned as-is.
     *  1. Split the word into individual chars.
     *  2. Find the pair with the lowest merge-rank among adjacent pairs.
     *  3. Merge that pair everywhere it occurs in the symbol list, left-to-right.
     *  4. Repeat until no more merges apply.
     *  5. Look up each resulting symbol in the vocab (falling back to the
     *     unknown token).
     */
    private fun bpe(word: String): List<Int> {
        if (word.isEmpty()) return emptyList()
        if (ignoreMerges) {
            vocab[word]?.let { return listOf(it) }
        }
        val symbols = word.map { it.toString() }.toMutableList()

        while (symbols.size >= 2) {
            var bestRank = Int.MAX_VALUE
            var bestIdx = -1
            for (i in 0 until symbols.size - 1) {
                val rank = mergeRanks[symbols[i] to symbols[i + 1]] ?: continue
                if (rank < bestRank) {
                    bestRank = rank
                    bestIdx = i
                }
            }
            if (bestIdx < 0) break

            val left = symbols[bestIdx]
            val right = symbols[bestIdx + 1]
            val merged = left + right
            val rebuilt = ArrayList<String>(symbols.size)
            var r = 0
            while (r < symbols.size) {
                if (r < symbols.size - 1 && symbols[r] == left && symbols[r + 1] == right) {
                    rebuilt.add(merged)
                    r += 2
                } else {
                    rebuilt.add(symbols[r])
                    r += 1
                }
            }
            symbols.clear()
            symbols.addAll(rebuilt)
        }

        return symbols.map { sym ->
            vocab[sym] ?: unkTokenId ?: error("token '$sym' not in vocab and no <unk> fallback")
        }
    }

    /**
     * Decode a list of token ids back to a UTF-8 string.
     *
     * Each token's characters are mapped back to bytes through the byte-level
     * alphabet. A token with any character outside that alphabet (e.g. an
     * added token containing a space) contributes its raw UTF-8 bytes instead.
     * Unknown ids are ignored. Special tokens are skipped when
     * [skipSpecialTokens] is true (the default for chat output). Byte
     * sequences that are not valid UTF-8 decode to U+FFFD.
     */
    fun decode(tokens: List<Int>, skipSpecialTokens: Boolean = true): String {
        val bytes = java.io.ByteArrayOutputStream()
        for (id in tokens) {
            val tok = idToToken[id] ?: continue
            if (skipSpecialTokens && tok in specialTokens) continue
            appendTokenBytes(tok, bytes)
        }
        return String(bytes.toByteArray(), Charsets.UTF_8)
    }

    /** Append [token]'s bytes to [sink]: byte-level reverse mapping, or raw UTF-8 if any char is unmapped. */
    private fun appendTokenBytes(token: String, sink: java.io.ByteArrayOutputStream) {
        val mapped = ByteArray(token.length)
        for (i in token.indices) {
            val byteVal = unicodeToByte[token[i]]
            if (byteVal == null) {
                val raw = token.toByteArray(Charsets.UTF_8)
                sink.write(raw, 0, raw.size)
                return
            }
            mapped[i] = byteVal.toByte()
        }
        sink.write(mapped, 0, mapped.size)
    }

    /**
     * Decode a single token with special tokens skipped. A token may carry only
     * part of a multi-byte UTF-8 character; the partial bytes then decode to
     * U+FFFD. Streaming callers that need exact text should decode accumulated ids.
     */
    fun decode(token: Int): String = decode(listOf(token), skipSpecialTokens = true)

    companion object {
        private val parser = Json { ignoreUnknownKeys = true }

        /** GPT-2 split regex, applied by a `ByteLevel` pre-tokenizer with `use_regex: true` (HF default). */
        private const val GPT2_SPLIT_PATTERN =
            """'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

        /**
         * Load a tokenizer.json from disk.
         *
         * @param tokenizerJsonPath path to the `tokenizer.json` file.
         * @param eosTokenIdOverride EOS id to use instead of the discovered one.
         * @throws LiteRTBackendError.TokenizerLoadFailed on any parse or
         *   structure failure, or an unsupported tokenizer shape, with a precise reason.
         */
        @Throws(LiteRTBackendError.TokenizerLoadFailed::class)
        fun load(tokenizerJsonPath: String, eosTokenIdOverride: Int? = null): HFTokenizerJson {
            val root = readRoot(tokenizerJsonPath)

            val model = root["model"] as? JsonObject
                ?: throw LiteRTBackendError.TokenizerLoadFailed("tokenizer.json: missing 'model' object")
            val type = model["type"].asStringOrNull()
            if (type != null && type != "BPE") {
                throw LiteRTBackendError.TokenizerLoadFailed(
                    "tokenizer.json: model.type='$type' is not supported (only BPE). Use the mediapipe backend for SentencePiece/Unigram models.",
                )
            }
            rejectSentencePieceBpe(root, model)

            val vocabRaw = model["vocab"] as? JsonObject
                ?: throw LiteRTBackendError.TokenizerLoadFailed("tokenizer.json: missing 'model.vocab'")
            val vocab = HashMap<String, Int>(vocabRaw.size)
            val idToToken = HashMap<Int, String>(vocabRaw.size)
            for ((tok, idEl) in vocabRaw) {
                val id = (idEl as? JsonPrimitive)?.intOrNull
                    ?: throw LiteRTBackendError.TokenizerLoadFailed("tokenizer.json: vocab entry '$tok' is not an int")
                vocab[tok] = id
                idToToken[id] = tok
            }

            val mergeRanks = parseMerges(model)

            // Added tokens. Each entry shape: { id, content, special, ... }.
            // Every added token is matched verbatim by encode(). Those with
            // `special: true` (or with no `special` flag at all) are also
            // skipped by decode() when skipSpecialTokens=true.
            val specialTokens = HashSet<String>()
            val addedTokens = HashSet<String>()
            for (entry in (root["added_tokens"] as? JsonArray).orEmpty()) {
                val obj = entry as? JsonObject ?: continue
                // Empty content would match at every offset and stall encode().
                val content = obj["content"].asStringOrNull()?.takeIf { it.isNotEmpty() } ?: continue
                // An entry without its own id may reuse a model.vocab id; otherwise
                // it cannot be emitted and is ignored.
                val id = (obj["id"] as? JsonPrimitive)?.intOrNull ?: vocab[content] ?: continue
                vocab[content] = id
                idToToken[id] = content
                addedTokens.add(content)
                if (obj["special"].asBooleanOrNull() ?: true) specialTokens.add(content)
            }

            // Discover BOS / EOS / PAD ids from the standard token names
            // (model vocab or added_tokens). The caller can override EOS via opts.
            val bosTokenId = vocab["<|begin_of_text|>"] ?: vocab["<s>"] ?: vocab["<bos>"]
            val discoveredEos = vocab["<|eot_id|>"]
                ?: vocab["<|end_of_text|>"]
                ?: vocab["</s>"]
                ?: vocab["<eos>"]
            val eosTokenId = eosTokenIdOverride ?: discoveredEos
                ?: throw LiteRTBackendError.TokenizerLoadFailed(
                    "tokenizer.json: no EOS-like token in added_tokens (looked for <|eot_id|>, <|end_of_text|>, </s>, <eos>) — pass eosTokenId in start opts to override",
                )
            val padTokenId = vocab["<pad>"] ?: vocab["<|pad|>"]

            val preTokenizer = parsePreTokenizer(root["pre_tokenizer"])
            val ignoreMerges = model["ignore_merges"].asBooleanOrNull() ?: false
            val unkTokenId = model["unk_token"].asStringOrNull()?.let { vocab[it] } ?: vocab["<unk>"]
            val (b2u, u2b) = buildByteLevelMap()

            return HFTokenizerJson(
                vocab = vocab,
                idToToken = idToToken,
                mergeRanks = mergeRanks,
                specialTokens = specialTokens,
                byteToUnicode = b2u,
                unicodeToByte = u2b,
                bosTokenId = bosTokenId,
                eosTokenId = eosTokenId,
                padTokenId = padTokenId,
                addedTokens = addedTokens,
                preTokenizer = preTokenizer,
                ignoreMerges = ignoreMerges,
                unkTokenId = unkTokenId,
            )
        }

        /** Read and parse [path] into the root JSON object, mapping failures to TokenizerLoadFailed. */
        private fun readRoot(path: String): JsonObject {
            val file = File(path)
            if (!file.isFile) {
                throw LiteRTBackendError.TokenizerLoadFailed(
                    "tokenizer.json not found at $path",
                )
            }
            return try {
                parser.parseToJsonElement(file.readText()).jsonObject
            } catch (t: Throwable) {
                throw LiteRTBackendError.TokenizerLoadFailed("failed to parse tokenizer.json: ${t.message}")
            }
        }

        /**
         * Fail fast on SentencePiece-style BPE (`byte_fallback` or Metaspace).
         * Its vocab is not byte-level, so byte-level encoding would silently
         * produce wrong ids.
         */
        private fun rejectSentencePieceBpe(root: JsonObject, model: JsonObject) {
            val reason = when {
                model["byte_fallback"].asBooleanOrNull() == true -> "model.byte_fallback=true"
                usesComponent(root["pre_tokenizer"], "Metaspace") -> "pre_tokenizer uses Metaspace"
                usesComponent(root["decoder"], "Metaspace") -> "decoder uses Metaspace"
                else -> return
            }
            throw LiteRTBackendError.TokenizerLoadFailed(
                "tokenizer.json: $reason (SentencePiece-style BPE) is not supported (only GPT-2 byte-level BPE). Use the mediapipe backend for SentencePiece models.",
            )
        }

        /** True if [el], or any child of a `Sequence` inside it, has `type == wanted`. */
        private fun usesComponent(el: kotlinx.serialization.json.JsonElement?, wanted: String): Boolean {
            val obj = el as? JsonObject ?: return false
            if (obj["type"].asStringOrNull() == wanted) return true
            val children = (obj["pretokenizers"] as? JsonArray)
                ?: (obj["decoders"] as? JsonArray)
                ?: return false
            return children.any { usesComponent(it, wanted) }
        }

        /** Build the merge → rank table from `model.merges` (a later duplicate overrides an earlier one). */
        private fun parseMerges(model: JsonObject): Map<Pair<String, String>, Int> {
            val mergesRaw = model["merges"] as? JsonArray
                ?: throw LiteRTBackendError.TokenizerLoadFailed("tokenizer.json: missing 'model.merges'")
            val mergeRanks = HashMap<Pair<String, String>, Int>(mergesRaw.size)
            for ((rank, mEl) in mergesRaw.withIndex()) {
                val pair = parseMergeEntry(mEl)
                    ?: throw LiteRTBackendError.TokenizerLoadFailed(
                        "tokenizer.json: merges[$rank] is not a 'A B' string or [A,B] pair",
                    )
                mergeRanks[pair] = rank
            }
            return mergeRanks
        }

        /**
         * Parse one entry of tokenizer.json's `model.merges` array. Two
         * shapes are seen in the wild:
         *  - String: "A B" (older HF files). Split on first space.
         *  - Array of two strings: ["A", "B"] (HF v0.21+ default).
         */
        private fun parseMergeEntry(el: kotlinx.serialization.json.JsonElement): Pair<String, String>? {
            return when (el) {
                is JsonPrimitive -> {
                    val s = el.contentOrNull ?: return null
                    val sp = s.indexOf(' ')
                    if (sp < 0) return null
                    s.substring(0, sp) to s.substring(sp + 1)
                }
                is JsonArray -> {
                    if (el.size != 2) return null
                    val a = (el[0] as? JsonPrimitive)?.contentOrNull ?: return null
                    val b = (el[1] as? JsonPrimitive)?.contentOrNull ?: return null
                    a to b
                }
                else -> null
            }
        }

        /**
         * Derive the split regex from a `pre_tokenizer` node. Returns null when
         * the node is absent or has no splitting step, which keeps the
         * whole-segment behaviour.
         */
        private fun parsePreTokenizer(el: kotlinx.serialization.json.JsonElement?): Regex? {
            val obj = el as? JsonObject ?: return null
            return when (obj["type"].asStringOrNull()) {
                "Sequence" -> (obj["pretokenizers"] as? JsonArray).orEmpty()
                    .asSequence()
                    .mapNotNull { parsePreTokenizer(it) }
                    .firstOrNull()
                "Split" -> splitRegex(obj)
                // HF's ByteLevel pre-tokenizer defaults to use_regex = true.
                "ByteLevel" ->
                    if (obj["use_regex"].asBooleanOrNull() != false) compileSplitPattern(GPT2_SPLIT_PATTERN) else null
                else -> null
            }
        }

        /** Regex for a `Split` pre-tokenizer: `pattern.Regex` as-is, or `pattern.String` escaped to a literal. */
        private fun splitRegex(split: JsonObject): Regex? {
            val pattern = split["pattern"] as? JsonObject ?: return null
            val regex = pattern["Regex"].asStringOrNull()
            if (regex != null) return compileSplitPattern(regex)
            val literal = pattern["String"].asStringOrNull()?.takeIf { it.isNotEmpty() } ?: return null
            return Regex(Regex.escape(literal))
        }

        private fun compileSplitPattern(pattern: String): Regex = try {
            Regex(pattern)
        } catch (e: IllegalArgumentException) {
            // PatternSyntaxException: the HF (Oniguruma) pattern uses syntax java.util.regex lacks.
            throw LiteRTBackendError.TokenizerLoadFailed(
                "tokenizer.json: pre_tokenizer regex is not supported by java.util.regex: ${e.message}",
            )
        }

        private fun kotlinx.serialization.json.JsonElement?.asStringOrNull(): String? =
            (this as? JsonPrimitive)?.contentOrNull

        private fun kotlinx.serialization.json.JsonElement?.asBooleanOrNull(): Boolean? =
            (this as? JsonPrimitive)?.booleanOrNull

        /**
         * Construct GPT-2's reversible byte→unicode permutation. Maps each
         * of the 256 byte values to a printable unicode codepoint:
         *  - Bytes that are already printable ASCII (33..126) or printable
         *    Latin-1 Supplement (161..172, 174..255) map to themselves.
         *  - All other bytes (0..32, 127..160, 173) map, in increasing byte
         *    order, to consecutive code points starting at U+0100 (Latin
         *    Extended-A).
         *
         * Reference: HuggingFace tokenizers' ByteLevel `bytes_to_unicode()`
         * helper; the resulting table is the one shared by HF Python,
         * tokenizers Rust, and this Kotlin port.
         */
        private fun buildByteLevelMap(): Pair<Map<Int, Char>, Map<Char, Int>> {
            val printable = ((33..126) + (161..172) + (174..255)).toSet()
            val byteToChar = HashMap<Int, Char>(512)
            var nextShifted = 256
            for (b in 0..255) {
                byteToChar[b] = if (b in printable) b.toChar() else (nextShifted++).toChar()
            }
            val charToByte = HashMap<Char, Int>(512)
            for ((b, ch) in byteToChar) charToByte[ch] = b
            return byteToChar to charToByte
        }
    }
}
|
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
package co.deepvoiceai.bridge.litert.core.Internal
|
|
2
|
+
|
|
3
|
+
import co.deepvoiceai.bridge.litert.core.LiteRTBackendError
|
|
4
|
+
import com.google.ai.edge.litert.Accelerator
|
|
5
|
+
import com.google.ai.edge.litert.CompiledModel
|
|
6
|
+
import com.google.ai.edge.litert.TensorBuffer
|
|
7
|
+
import com.google.ai.edge.litert.TensorType
|
|
8
|
+
import java.io.File
|
|
9
|
+
|
|
10
|
+
/**
 * Seam between the decode loop and the LiteRT runtime.
 *
 * Production code uses [LiteRTEngine], which drives the native
 * [CompiledModel]. [LiteRTGenerator]'s mock test swaps in a fake that
 * returns canned logits, so no .tflite has to be loaded.
 */
internal interface LiteRTEngineApi {

    /**
     * Executes one forward pass for [token] at KV-cache slot [kvCachePosition]
     * and hands back the next-token logits, one entry per vocabulary id (so
     * the array length equals [vocabSize]).
     *
     * Native failures surface as [LiteRTBackendError.GenerationFailed].
     */
    fun runStep(token: Int, kvCachePosition: Int): FloatArray

    /** Frees any native resources; calling it more than once must be harmless. */
    fun close()

    /** Number of logits produced by every [runStep] call. */
    val vocabSize: Int

    /** Vocabulary id that marks end-of-sequence; the generator stops decoding on it. */
    val eosTokenId: Int
}
|
|
32
|
+
|
|
33
|
+
/**
 * Wraps Google's LiteRT [CompiledModel] for a stateful Llama-style
 * autoregressive checkpoint. Drives single-token decoding via named-tensor
 * `run(inputs, outputs)` calls.
 *
 * Why not use LiteRT-LM? We deliberately depend on bare `litert` (see
 * `android/build.gradle` top-of-file comment), so the KV-cache / sampler
 * loop is implemented here in Kotlin. The Llama-style .tflite checkpoints
 * we target carry the cache as graph-internal state, exposed through
 * named inputs/outputs that the runtime maintains across calls within one
 * [CompiledModel] instance — same shape Apple's CoreML stateful Llama
 * checkpoints follow on iOS (see `CoreMLEngine.swift`).
 *
 * Tensor convention (probed at init via [CompiledModel.getInputTensorType] /
 * [CompiledModel.getOutputTensorType]):
 *  - [inputName]      `input_ids`   INT32 or INT64 (default name, overridable)
 *  - [causalMaskName] `causal_mask` FLOAT, optional; last dim = kv length
 *  - [outputName]     `logits`      FLOAT [1, 1, vocab] or [1, vocab]
 *
 * The causal mask is additive: 0 = visible, large negative = masked.
 *  - If every declared mask dimension is positive (a static shape), the full
 *    tensor is written. Positions `0..kvCachePosition` of every row are
 *    visible and later positions are masked. A `kvCachePosition` at or
 *    beyond the mask's last dimension fails with
 *    [LiteRTBackendError.GenerationFailed].
 *  - Otherwise (unknown or non-positive dims) a `kvCachePosition + 1`-long
 *    all-zero row is written and the runtime is relied on to size the tensor.
 *  - If the model declares no `causal_mask` input, writing it is skipped —
 *    many simpler stateful checkpoints don't expose one.
 *
 * The native model is released if construction fails after it was created.
 * After [close], [runStep] throws instead of touching freed native handles.
 *
 * This class is NOT thread-safe. [LiteRTHandlers] serializes all calls
 * behind a mutex; do the same in any other call site.
 */
internal class LiteRTEngine(
    modelPath: String,
    private val inputName: String = "input_ids",
    private val causalMaskName: String = "causal_mask",
    private val outputName: String = "logits",
    /** Surface override so the handler / config layer can lift it from start opts. */
    @Suppress("UNUSED_PARAMETER")
    private val contextSize: Int = 2048,
    eosTokenId: Int,
    accelerator: Accelerator = Accelerator.CPU,
) : LiteRTEngineApi, AutoCloseable {

    private val model: CompiledModel
    override val vocabSize: Int
    override val eosTokenId: Int = eosTokenId
    private val hasCausalMask: Boolean
    private val causalMaskRank: Int

    /** Declared causal_mask dims when all are static (positive); null when absent or dynamic. */
    private val causalMaskDims: List<Int>?
    private val inputIsInt64: Boolean

    /** Set by [close]; guards [runStep] against use-after-free of native handles. */
    @Volatile
    private var closed = false

    init {
        if (!File(modelPath).isFile) {
            throw LiteRTBackendError.ModelLoadFailed("model file not found at $modelPath")
        }

        val compiled = try {
            CompiledModel.create(modelPath, CompiledModel.Options(accelerator))
        } catch (t: Throwable) {
            throw LiteRTBackendError.ModelLoadFailed("CompiledModel.create failed: ${t.message ?: t::class.java.simpleName}")
        }

        // The constructor never hands `compiled` back on failure. Any
        // exception while probing — not just the explicit checks — must
        // therefore release the native model.
        val signature = try {
            probeSignature(compiled)
        } catch (t: Throwable) {
            runCatching { compiled.close() }
            throw t
        }

        model = compiled
        inputIsInt64 = signature.inputIsInt64
        hasCausalMask = signature.hasCausalMask
        causalMaskRank = signature.causalMaskRank
        causalMaskDims = signature.causalMaskDims
        vocabSize = signature.vocabSize
    }

    /** Tensor facts captured once at init. */
    private data class Signature(
        val inputIsInt64: Boolean,
        val hasCausalMask: Boolean,
        val causalMaskRank: Int,
        val causalMaskDims: List<Int>?,
        val vocabSize: Int,
    )

    /**
     * Inspect the model's declared tensors.
     *
     * @throws LiteRTBackendError.ModelLoadFailed if the input/output tensors
     *   are missing or have an unsupported dtype/rank.
     */
    private fun probeSignature(m: CompiledModel): Signature {
        // Validate the input_ids tensor exists and capture its dtype. We
        // don't enforce shape == [1,1]: the model owns its declared signature;
        // we feed [token] as a 1-element array and let the runtime error out.
        val inputType = try {
            m.getInputTensorType(inputName)
        } catch (t: Throwable) {
            // Fail fast with a precise message — a silent fallback would only
            // surface the real mismatch deep inside nativeRun().
            throw LiteRTBackendError.ModelLoadFailed(
                "input tensor '$inputName' not found on model (override via litertInputName opt). " +
                    "Cause: ${t.message ?: t::class.java.simpleName}",
            )
        }
        // Both INT32 and INT64 input_ids are accepted; runStep dispatches to
        // writeInt vs writeLong to satisfy LiteRT's strict dtype check.
        if (inputType.elementType != TensorType.ElementType.INT &&
            inputType.elementType != TensorType.ElementType.INT64
        ) {
            throw LiteRTBackendError.ModelLoadFailed(
                "input tensor '$inputName' has unsupported elementType=${inputType.elementType}; expected INT or INT64",
            )
        }

        // causal_mask is optional. Probe via try/catch on getInputTensorType,
        // since the LiteRT API has no non-throwing "does this input exist?" call.
        var maskRank = 0
        var maskDims: List<Int>? = null
        val maskPresent = try {
            val layout = m.getInputTensorType(causalMaskName).layout
            maskRank = layout?.rank ?: 0
            maskDims = layout?.dimensions?.takeIf { dims -> dims.isNotEmpty() && dims.all { it > 0 } }
            true
        } catch (_: Throwable) {
            false
        }

        // Discover vocab size from the logits layout: [1, 1, V] or [1, V].
        val outputType = try {
            m.getOutputTensorType(outputName)
        } catch (t: Throwable) {
            throw LiteRTBackendError.ModelLoadFailed(
                "output tensor '$outputName' not found on model (override via litertOutputName opt). " +
                    "Cause: ${t.message ?: t::class.java.simpleName}",
            )
        }
        val outDims = outputType.layout?.dimensions ?: emptyList()
        val vocab = when (outDims.size) {
            3 -> outDims[2] // [1, 1, V] — Llama-3 style
            2 -> outDims[1] // [1, V] — some Gemma exports
            else -> throw LiteRTBackendError.ModelLoadFailed(
                "output tensor '$outputName' has unsupported rank=${outDims.size}; expected 2 or 3 with vocab as last dim",
            )
        }
        if (vocab <= 0) {
            throw LiteRTBackendError.ModelLoadFailed(
                "output tensor '$outputName' reports non-positive vocab size: $vocab",
            )
        }

        return Signature(
            inputIsInt64 = inputType.elementType == TensorType.ElementType.INT64,
            hasCausalMask = maskPresent,
            causalMaskRank = maskRank,
            causalMaskDims = maskDims,
            vocabSize = vocab,
        )
    }

    override fun runStep(token: Int, kvCachePosition: Int): FloatArray {
        if (closed) {
            throw LiteRTBackendError.GenerationFailed(
                "runStep called after close() (kvPos=$kvCachePosition token=$token)",
            )
        }
        val inputs = mutableMapOf<String, TensorBuffer>()
        val outputs = mutableMapOf<String, TensorBuffer>()
        // Every per-call TensorBuffer is a fresh native handle; all are released in `finally`.
        val opened = mutableListOf<TensorBuffer>()
        try {
            val inputBuf = model.createInputBuffer(inputName)
            opened.add(inputBuf)
            if (inputIsInt64) {
                inputBuf.writeLong(longArrayOf(token.toLong()))
            } else {
                inputBuf.writeInt(intArrayOf(token))
            }
            inputs[inputName] = inputBuf

            if (hasCausalMask && causalMaskRank > 0) {
                // Built first so an out-of-range position fails before allocating.
                val mask = buildCausalMask(kvCachePosition)
                val maskBuf = model.createInputBuffer(causalMaskName)
                opened.add(maskBuf)
                // LiteRT only exposes writeFloat for FP tensors; the original
                // integration relies on the runtime accepting FP32 data for the mask.
                maskBuf.writeFloat(mask)
                inputs[causalMaskName] = maskBuf
            }

            val outputBuf = model.createOutputBuffer(outputName)
            opened.add(outputBuf)
            outputs[outputName] = outputBuf

            try {
                model.run(inputs, outputs)
            } catch (t: Throwable) {
                throw LiteRTBackendError.GenerationFailed(
                    "model.run failed at kvPos=$kvCachePosition token=$token: ${t.message ?: t::class.java.simpleName}",
                )
            }

            val raw = outputBuf.readFloat()
            // With a [1, 1, V] shape raw IS the next-token row. If a checkpoint
            // emits [1, T, V], the LAST V entries are the next-token prediction.
            if (raw.size == vocabSize) return raw
            val start = raw.size - vocabSize
            if (start < 0) {
                throw LiteRTBackendError.GenerationFailed(
                    "logits buffer length ${raw.size} is smaller than vocabSize $vocabSize",
                )
            }
            return raw.copyOfRange(start, raw.size)
        } catch (e: LiteRTBackendError.GenerationFailed) {
            throw e
        } catch (t: Throwable) {
            // Buffer creation / write failures are native failures too; the
            // LiteRTEngineApi contract promises GenerationFailed for those.
            throw LiteRTBackendError.GenerationFailed(
                "runStep failed at kvPos=$kvCachePosition token=$token: ${t.message ?: t::class.java.simpleName}",
            )
        } finally {
            for (buf in opened) {
                runCatching { buf.close() }
            }
        }
    }

    /**
     * Additive causal mask for a single decode step at [kvCachePosition].
     *
     * With static dims, the whole tensor is filled: in every row (last dim =
     * kv length), positions up to `kvCachePosition` are 0 and the rest are
     * [MASKED_VALUE]. Position 0 always stays visible so a row is never fully
     * masked. Without static dims, this returns `max(1, kvCachePosition + 1)`
     * zeros, as before.
     */
    private fun buildCausalMask(kvCachePosition: Int): FloatArray {
        val dims = causalMaskDims ?: return FloatArray(maxOf(1, kvCachePosition + 1))
        val rowLength = dims.last()
        if (kvCachePosition >= rowLength) {
            throw LiteRTBackendError.GenerationFailed(
                "kvPos=$kvCachePosition is outside causal_mask length $rowLength (model context window exhausted)",
            )
        }
        val lastVisible = maxOf(0, kvCachePosition)
        val total = dims.fold(1) { acc, d -> acc * d }
        return FloatArray(total) { i -> if (i % rowLength <= lastVisible) 0f else MASKED_VALUE }
    }

    override fun close() {
        if (closed) return
        closed = true
        runCatching { model.close() }
    }

    private companion object {
        /** Additive bias for masked-out positions ("large negative" in the mask convention). */
        const val MASKED_VALUE = -1.0e9f
    }
}
|