@dvai-bridge/android-litert-core 4.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,380 @@
1
+ package co.deepvoiceai.bridge.litert.core.Internal
2
+
3
+ import co.deepvoiceai.bridge.litert.core.LiteRTBackendError
4
+ import kotlinx.serialization.json.Json
5
+ import kotlinx.serialization.json.JsonArray
6
+ import kotlinx.serialization.json.JsonObject
7
+ import kotlinx.serialization.json.JsonPrimitive
8
+ import kotlinx.serialization.json.booleanOrNull
9
+ import kotlinx.serialization.json.contentOrNull
10
+ import kotlinx.serialization.json.intOrNull
11
+ import kotlinx.serialization.json.jsonArray
12
+ import kotlinx.serialization.json.jsonObject
13
+ import kotlinx.serialization.json.jsonPrimitive
14
+ import java.io.File
15
+
16
/**
 * Pure-Kotlin BPE tokenizer that loads HuggingFace's standard
 * `tokenizer.json` schema. No JNI, no native library — works on every
 * Android ABI without surprise UnsatisfiedLinkErrors.
 *
 * Why a custom parser instead of an off-the-shelf artifact?
 *  - `com.github.huggingface:tokenizers-android` (JitPack) does not exist —
 *    the URL returns 401.
 *  - `ai.djl.huggingface:tokenizers:0.36.0` (Maven Central) is JVM-only:
 *    DJL ships `libtokenizers.so` for x86_64 + aarch64-linux-gnu + macOS +
 *    Windows but NOT for Android (`*-linux-android`). Pulling DJL would
 *    crash the first encode() with UnsatisfiedLinkError on every Android
 *    target ABI.
 *  - HF's official Rust crate has no Android JNI wrapper on Maven.
 *
 * What's supported:
 *  - BPE merges (byte-pair encoding) — the standard Llama-3 / GPT-2
 *    tokenizer.json shape: `model.type == "BPE"`, `model.vocab` as
 *    {token: id}, `model.merges` as space-separated `"A B"` pairs OR
 *    array-of-pair tuples (HF v0.21+ format).
 *  - `model.ignore_merges`: when true, a pre-token that is already a vocab
 *    entry is emitted directly without running merges (HF semantics).
 *  - Added tokens via the `added_tokens` array (each entry is
 *    `{ id, content, special }`). Entries flagged `special` (or with the
 *    flag missing) are matched verbatim by encode() — leftmost occurrence
 *    first, longest token on ties — and skipped by decode() when
 *    `skipSpecialTokens` is true.
 *  - Pre-tokenization read from `pre_tokenizer`: `ByteLevel` (applies the
 *    GPT-2 split regex when `use_regex` is true or absent), `Split` with
 *    `behavior == "Isolated"` and `invert == false`, and a `Sequence` of
 *    those (the Llama-3 shape). Every resulting piece is mapped through
 *    GPT-2's byte→unicode permutation and BPE'd on its own, so merges never
 *    cross pre-token boundaries. With no recognised pre-tokenizer the whole
 *    segment is BPE'd as one piece.
 *  - decode() reverses the byte-level mapping token by token. A token with
 *    any character outside the byte-level alphabet (e.g. a raw added token)
 *    is emitted as its own UTF-8 bytes, like HF's ByteLevel decoder.
 *
 * What's NOT supported (call sites must avoid these models):
 *  - SentencePiece / Unigram tokenizers (`model.type == "Unigram"`) — Gemma
 *    uses these; for Gemma checkpoints the consumer should use the
 *    mediapipe backend instead, which uses LiteRT-LM's bundled SentencePiece.
 *  - Jinja chat templates from `tokenizer_config.json`. The handler layer
 *    formats messages with a hard-coded Llama-3-style template
 *    (`<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>`)
 *    which works for Llama-3 family checkpoints; non-Llama checkpoints
 *    require the consumer to pre-render the prompt themselves.
 *  - Other pre-tokenizer types (Whitespace, Metaspace, Digits, ...) and other
 *    `Split` behaviors. These are ignored, i.e. the text is not split by them.
 *  - Normalizers and added-token flags (lstrip / rstrip / single_word / normalized).
 *  - Split regexes run on java.util.regex. Character classes such as `\s` may
 *    differ from HF's Oniguruma engine for non-ASCII input.
 *
 * If the loader encounters an unsupported tokenizer.json shape it throws
 * [LiteRTBackendError.TokenizerLoadFailed] with a precise reason so the
 * caller can fall back to a different backend.
 *
 * Mirrors the role of `CoreMLTokenizer.swift` (iOS) without the
 * swift-transformers dependency — there's no equivalent maintained library
 * on Android.
 */
internal class HFTokenizerJson private constructor(
    private val vocab: Map<String, Int>,
    private val idToToken: Map<Int, String>,
    private val mergeRanks: Map<Pair<String, String>, Int>,
    private val specialTokens: Set<String>,
    private val byteToUnicode: Map<Int, Char>,
    private val unicodeToByte: Map<Char, Int>,
    val bosTokenId: Int?,
    val eosTokenId: Int,
    val padTokenId: Int?,
    /** Split regexes applied in order before byte-level mapping; empty = no split. */
    private val preTokenizePatterns: List<Regex> = emptyList(),
    /** Mirrors `model.ignore_merges`: emit whole-piece vocab hits without merging. */
    private val ignoreMerges: Boolean = false,
) {

    /**
     * Encode a string to a token-id list using BPE.
     *
     * Pipeline: special-token split -> pre-tokenize -> UTF-8 bytes ->
     * GPT-2 byte→unicode permutation -> BPE merges per piece -> vocab lookup.
     * Special tokens in the input are matched verbatim before BPE runs (so
     * `<|eot_id|>` resolves to a single id rather than being split into pieces).
     *
     * @return token ids; empty for an empty [text].
     */
    fun encode(text: String): List<Int> {
        if (text.isEmpty()) return emptyList()
        val out = mutableListOf<Int>()
        var cursor = 0
        while (cursor < text.length) {
            val match = findNextSpecial(text, cursor)
            // BPE the plain gap up to the next special token (or the end).
            val segmentEnd = match?.start ?: text.length
            if (segmentEnd > cursor) {
                out.addAll(encodeBpe(text.substring(cursor, segmentEnd)))
            }
            if (match == null) break
            // load() only registers non-empty specials that have an id, so the
            // lookup cannot fail and the cursor always advances.
            out.add(vocab.getValue(match.token))
            cursor = match.start + match.token.length
        }
        return out
    }

    private data class SpecialMatch(val token: String, val start: Int)

    /**
     * Earliest-occurrence special token at or after [from], preferring the
     * longest token when several start at the same offset. Null if none.
     */
    private fun findNextSpecial(text: String, from: Int): SpecialMatch? {
        var best: SpecialMatch? = null
        for (special in specialTokens) {
            if (special.isEmpty()) continue
            val idx = text.indexOf(special, from)
            if (idx < 0) continue
            if (best == null ||
                idx < best.start ||
                (idx == best.start && special.length > best.token.length)
            ) {
                best = SpecialMatch(special, idx)
            }
        }
        return best
    }

    /** Pre-tokenize [text], then byte-level-map and BPE each piece independently. */
    private fun encodeBpe(text: String): List<Int> =
        preTokenize(text).flatMap { piece -> bpe(toByteLevel(piece)) }

    /** Apply every configured split regex in order ("Isolated" behavior). */
    private fun preTokenize(text: String): List<String> =
        preTokenizePatterns.fold(listOf(text)) { pieces, regex ->
            pieces.flatMap { splitIsolated(it, regex) }
        }

    /** Map each UTF-8 byte of [piece] to its GPT-2 printable stand-in character. */
    private fun toByteLevel(piece: String): String {
        val bytes = piece.toByteArray(Charsets.UTF_8)
        val mapped = StringBuilder(bytes.size)
        for (b in bytes) {
            mapped.append(byteToUnicode.getValue(b.toInt() and 0xFF))
        }
        return mapped.toString()
    }

    /**
     * Apply BPE merges greedily to a single byte-level-encoded pre-token.
     *
     * Standard HF BPE algorithm:
     *  1. (ignore_merges) If the whole piece is a vocab entry, emit it directly.
     *  2. Split the piece into individual chars.
     *  3. Find the pair with the lowest merge rank among adjacent pairs.
     *  4. Merge that pair everywhere it occurs (left-to-right, non-overlapping).
     *  5. Repeat until no more merges apply, then look up each symbol.
     */
    private fun bpe(word: String): List<Int> {
        if (word.isEmpty()) return emptyList()
        if (ignoreMerges) {
            vocab[word]?.let { return listOf(it) }
        }

        var symbols: List<String> = word.map { it.toString() }
        while (symbols.size >= 2) {
            var bestRank = Int.MAX_VALUE
            var bestIdx = -1
            for (i in 0 until symbols.size - 1) {
                val rank = mergeRanks[symbols[i] to symbols[i + 1]] ?: continue
                if (rank < bestRank) {
                    bestRank = rank
                    bestIdx = i
                }
            }
            if (bestIdx < 0) break

            val left = symbols[bestIdx]
            val right = symbols[bestIdx + 1]
            val merged = left + right
            val rebuilt = ArrayList<String>(symbols.size)
            var r = 0
            while (r < symbols.size) {
                if (r < symbols.size - 1 && symbols[r] == left && symbols[r + 1] == right) {
                    rebuilt.add(merged)
                    r += 2
                } else {
                    rebuilt.add(symbols[r])
                    r += 1
                }
            }
            symbols = rebuilt
        }

        return symbols.map { lookupSymbol(it) }
    }

    /** Vocab id for [symbol], falling back to `<unk>`; fails loudly if neither exists. */
    private fun lookupSymbol(symbol: String): Int =
        vocab[symbol] ?: vocab["<unk>"] ?: error("token '$symbol' not in vocab and no <unk> fallback")

    /**
     * Decode a list of token ids back to a UTF-8 string.
     *
     * Unknown ids are ignored. Special tokens are skipped when
     * [skipSpecialTokens] is true (the default for chat output). Each token is
     * reversed through the byte-level mapping. A token containing any
     * character outside that alphabet contributes its raw UTF-8 bytes instead
     * (HF ByteLevel decoder behavior). Invalid UTF-8 sequences decode to U+FFFD.
     */
    fun decode(tokens: List<Int>, skipSpecialTokens: Boolean = true): String {
        val bytes = java.io.ByteArrayOutputStream()
        for (id in tokens) {
            val tok = idToToken[id] ?: continue
            if (skipSpecialTokens && tok in specialTokens) continue
            bytes.write(tokenBytes(tok))
        }
        return String(bytes.toByteArray(), Charsets.UTF_8)
    }

    /** Byte-level reverse of one token, or its raw UTF-8 if any char is unmapped. */
    private fun tokenBytes(tok: String): ByteArray {
        val out = ByteArray(tok.length)
        for ((i, ch) in tok.withIndex()) {
            val byteVal = unicodeToByte[ch] ?: return tok.toByteArray(Charsets.UTF_8)
            out[i] = byteVal.toByte()
        }
        return out
    }

    fun decode(token: Int): String = decode(listOf(token), skipSpecialTokens = true)

    companion object {
        private val parser = Json { ignoreUnknownKeys = true }

        /** GPT-2's split regex, used by ByteLevel when `use_regex` is true. */
        private const val GPT2_SPLIT_PATTERN =
            """'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

        /**
         * Load a tokenizer.json from disk. Throws
         * [LiteRTBackendError.TokenizerLoadFailed] on any parse / structure
         * failure with a precise reason.
         *
         * @param tokenizerJsonPath path to a HuggingFace `tokenizer.json`.
         * @param eosTokenIdOverride EOS id to use instead of the discovered one.
         */
        @Throws(LiteRTBackendError.TokenizerLoadFailed::class)
        fun load(tokenizerJsonPath: String, eosTokenIdOverride: Int? = null): HFTokenizerJson {
            val file = File(tokenizerJsonPath)
            if (!file.isFile) {
                throw LiteRTBackendError.TokenizerLoadFailed(
                    "tokenizer.json not found at $tokenizerJsonPath",
                )
            }
            val root = try {
                parser.parseToJsonElement(file.readText()).jsonObject
            } catch (t: Throwable) {
                throw LiteRTBackendError.TokenizerLoadFailed("failed to parse tokenizer.json: ${t.message}")
            }

            val model = root["model"] as? JsonObject
                ?: throw LiteRTBackendError.TokenizerLoadFailed("tokenizer.json: missing 'model' object")
            val type = (model["type"] as? JsonPrimitive)?.contentOrNull
            if (type != null && type != "BPE") {
                throw LiteRTBackendError.TokenizerLoadFailed(
                    "tokenizer.json: model.type='$type' is not supported (only BPE). Use the mediapipe backend for SentencePiece/Unigram models.",
                )
            }
            val ignoreMerges = (model["ignore_merges"] as? JsonPrimitive)?.booleanOrNull ?: false

            val vocabRaw = model["vocab"] as? JsonObject
                ?: throw LiteRTBackendError.TokenizerLoadFailed("tokenizer.json: missing 'model.vocab'")
            val vocab = HashMap<String, Int>(vocabRaw.size)
            val idToToken = HashMap<Int, String>(vocabRaw.size)
            for ((tok, idEl) in vocabRaw) {
                val id = (idEl as? JsonPrimitive)?.intOrNull
                    ?: throw LiteRTBackendError.TokenizerLoadFailed("tokenizer.json: vocab entry '$tok' is not an int")
                vocab[tok] = id
                idToToken[id] = tok
            }

            val mergesRaw = model["merges"] as? JsonArray
                ?: throw LiteRTBackendError.TokenizerLoadFailed("tokenizer.json: missing 'model.merges'")
            val mergeRanks = HashMap<Pair<String, String>, Int>(mergesRaw.size)
            for ((rank, mEl) in mergesRaw.withIndex()) {
                val pair = parseMergeEntry(mEl)
                    ?: throw LiteRTBackendError.TokenizerLoadFailed(
                        "tokenizer.json: merges[$rank] is not a 'A B' string or [A,B] pair",
                    )
                mergeRanks[pair] = rank
            }

            // Added tokens: each entry shape is { id, content, special, ... }.
            // Anything with `special: true` (or no flag) is a special token:
            // matched verbatim by encode(), skipped by decode() when
            // skipSpecialTokens=true.
            val specialTokens = mutableSetOf<String>()
            (root["added_tokens"] as? JsonArray)?.forEach { entry ->
                val obj = entry as? JsonObject ?: return@forEach
                val content = (obj["content"] as? JsonPrimitive)?.contentOrNull ?: return@forEach
                // An empty token would match at every offset and stall encode().
                if (content.isEmpty()) return@forEach
                // Without an id encode() could not emit the token; fall back to
                // the model vocab, otherwise ignore the entry.
                val id = (obj["id"] as? JsonPrimitive)?.intOrNull ?: vocab[content] ?: return@forEach
                vocab[content] = id
                idToToken[id] = content
                val isSpecial = (obj["special"] as? JsonPrimitive)?.booleanOrNull ?: true
                if (isSpecial) specialTokens.add(content)
            }

            // Discover BOS / EOS / PAD ids by their conventional names. The
            // caller can override EOS via opts.
            val bosTokenId = vocab["<|begin_of_text|>"] ?: vocab["<s>"] ?: vocab["<bos>"]
            val discoveredEos = vocab["<|eot_id|>"]
                ?: vocab["<|end_of_text|>"]
                ?: vocab["</s>"]
                ?: vocab["<eos>"]
            val eosTokenId = eosTokenIdOverride ?: discoveredEos
                ?: throw LiteRTBackendError.TokenizerLoadFailed(
                    "tokenizer.json: no EOS-like token in added_tokens (looked for <|eot_id|>, <|end_of_text|>, </s>, <eos>) — pass eosTokenId in start opts to override",
                )
            val padTokenId = vocab["<pad>"] ?: vocab["<|pad|>"]

            val preTokenizePatterns = parsePreTokenizerPatterns(root["pre_tokenizer"])
            val (b2u, u2b) = buildByteLevelMap()

            return HFTokenizerJson(
                vocab = vocab,
                idToToken = idToToken,
                mergeRanks = mergeRanks,
                specialTokens = specialTokens,
                byteToUnicode = b2u,
                unicodeToByte = u2b,
                bosTokenId = bosTokenId,
                eosTokenId = eosTokenId,
                padTokenId = padTokenId,
                preTokenizePatterns = preTokenizePatterns,
                ignoreMerges = ignoreMerges,
            )
        }

        /**
         * Collect the split regexes implied by a `pre_tokenizer` object, in
         * application order. Unrecognised types contribute nothing.
         */
        private fun parsePreTokenizerPatterns(el: kotlinx.serialization.json.JsonElement?): List<Regex> {
            val obj = el as? JsonObject ?: return emptyList()
            return when ((obj["type"] as? JsonPrimitive)?.contentOrNull) {
                "Sequence" -> (obj["pretokenizers"] as? JsonArray).orEmpty()
                    .flatMap { parsePreTokenizerPatterns(it) }
                "Split" -> listOfNotNull(parseSplitPattern(obj))
                "ByteLevel" -> {
                    // HF's ByteLevel defaults `use_regex` to true.
                    val useRegex = (obj["use_regex"] as? JsonPrimitive)?.booleanOrNull ?: true
                    if (useRegex) listOf(compilePattern(GPT2_SPLIT_PATTERN)) else emptyList()
                }
                else -> emptyList()
            }
        }

        /** Regex for a `Split` pre-tokenizer with Isolated, non-inverted behavior; null otherwise. */
        private fun parseSplitPattern(obj: JsonObject): Regex? {
            val behavior = (obj["behavior"] as? JsonPrimitive)?.contentOrNull
            val invert = (obj["invert"] as? JsonPrimitive)?.booleanOrNull ?: false
            if (behavior != "Isolated" || invert) return null
            val pattern = obj["pattern"] as? JsonObject ?: return null
            val source = (pattern["Regex"] as? JsonPrimitive)?.contentOrNull
                ?: (pattern["String"] as? JsonPrimitive)?.contentOrNull?.let { Regex.escape(it) }
                ?: return null
            if (source.isEmpty()) return null
            return compilePattern(source)
        }

        /** Compile a tokenizer.json regex, reporting syntax the JVM engine rejects. */
        private fun compilePattern(source: String): Regex = try {
            Regex(source)
        } catch (e: IllegalArgumentException) {
            throw LiteRTBackendError.TokenizerLoadFailed(
                "tokenizer.json: pre_tokenizer regex is not supported by java.util.regex: ${e.message}",
            )
        }

        /**
         * HF "Isolated" split: every non-empty match becomes its own piece,
         * and the unmatched gaps between matches are kept as pieces too.
         */
        private fun splitIsolated(text: String, regex: Regex): List<String> {
            val pieces = ArrayList<String>()
            var last = 0
            for (m in regex.findAll(text)) {
                if (m.range.isEmpty()) continue
                if (m.range.first > last) pieces.add(text.substring(last, m.range.first))
                pieces.add(m.value)
                last = m.range.last + 1
            }
            if (last < text.length) pieces.add(text.substring(last))
            return pieces
        }

        /**
         * Parse one entry of tokenizer.json's `model.merges` array. Two
         * shapes are seen in the wild:
         *  - String: "A B" (older HF, Llama-2-style). Split on first space.
         *  - Array of two strings: ["A", "B"] (HF v0.21+ default).
         */
        private fun parseMergeEntry(el: kotlinx.serialization.json.JsonElement): Pair<String, String>? {
            return when (el) {
                is JsonPrimitive -> {
                    val s = el.contentOrNull ?: return null
                    val sp = s.indexOf(' ')
                    if (sp < 0) return null
                    s.substring(0, sp) to s.substring(sp + 1)
                }
                is JsonArray -> {
                    if (el.size != 2) return null
                    val a = (el[0] as? JsonPrimitive)?.contentOrNull ?: return null
                    val b = (el[1] as? JsonPrimitive)?.contentOrNull ?: return null
                    a to b
                }
                else -> null
            }
        }

        /**
         * Construct GPT-2's reversible byte→unicode permutation. Maps each
         * of the 256 byte values to a printable unicode codepoint:
         *  - Printable ASCII (33..126) and printable Latin-1 Supplement bytes
         *    (161..172, 174..255) map to themselves.
         *  - The remaining 68 bytes (0..32, 127..160, 173) map, in ascending
         *    order, to U+0100..U+0143 (Latin Extended-A).
         *
         * Reference: HuggingFace tokenizers' ByteLevel `bytes_to_unicode()`.
         */
        private fun buildByteLevelMap(): Pair<Map<Int, Char>, Map<Char, Int>> {
            val printable = (33..126) + (161..172) + (174..255)
            val printableSet = printable.toHashSet()
            val bytes = printable.toMutableList()
            val codepoints = printable.toMutableList()
            var next = 256
            for (b in 0..255) {
                if (b !in printableSet) {
                    bytes.add(b)
                    codepoints.add(next)
                    next += 1
                }
            }
            val byteToChar = HashMap<Int, Char>(512)
            val charToByte = HashMap<Char, Int>(512)
            for (i in bytes.indices) {
                val ch = codepoints[i].toChar()
                byteToChar[bytes[i]] = ch
                charToByte[ch] = bytes[i]
            }
            return byteToChar to charToByte
        }
    }
}
@@ -0,0 +1,241 @@
1
+ package co.deepvoiceai.bridge.litert.core.Internal
2
+
3
+ import co.deepvoiceai.bridge.litert.core.LiteRTBackendError
4
+ import com.google.ai.edge.litert.Accelerator
5
+ import com.google.ai.edge.litert.CompiledModel
6
+ import com.google.ai.edge.litert.TensorBuffer
7
+ import com.google.ai.edge.litert.TensorType
8
+ import java.io.File
9
+
10
/**
 * Minimal surface of a single-token decode engine, pulled out so tests can
 * swap the native LiteRT runtime for a fake.
 *
 * [LiteRTEngine] is the production implementation backed by a LiteRT
 * `CompiledModel`. The mock test for [LiteRTGenerator] plugs in a
 * canned-logits implementation that never loads a .tflite file.
 */
internal interface LiteRTEngineApi {
    /** Id of the model's end-of-sequence token; the generator uses it to stop decoding. */
    val eosTokenId: Int

    /** Number of entries in every logits row returned by [runStep]. */
    val vocabSize: Int

    /**
     * Runs one forward pass feeding [token] at cache slot [kvCachePosition]
     * and returns the next-token logits ([vocabSize] floats).
     *
     * @throws LiteRTBackendError.GenerationFailed when the native runtime fails.
     */
    fun runStep(token: Int, kvCachePosition: Int): FloatArray

    /** Frees native resources. Safe to call more than once. */
    fun close()
}
32
+
33
/**
 * Wraps Google's LiteRT [CompiledModel] for a stateful Llama-style
 * autoregressive checkpoint. Drives single-token decoding via named-tensor
 * `run(inputs, outputs)` calls.
 *
 * Why not use LiteRT-LM? We deliberately depend on bare `litert` (see
 * `android/build.gradle` top-of-file comment), so the KV-cache / sampler
 * loop is implemented here in Kotlin. The Llama-style .tflite checkpoints
 * we target carry the cache as graph-internal state, exposed through
 * named inputs/outputs that the runtime maintains across calls within one
 * [CompiledModel] instance — same shape Apple's CoreML stateful Llama
 * checkpoints follow on iOS (see `CoreMLEngine.swift`).
 *
 * Tensor convention (auto-detected at init via [CompiledModel.getInputTensorType]):
 *  - [inputName]      `input_ids`   INT32 or INT64 [1, 1]     (default name, overridable)
 *  - [causalMaskName] `causal_mask` FLOAT [1, 1, 1, kv_len]   (optional)
 *  - [outputName]     `logits`      FLOAT [1, 1, vocab] or [1, vocab]
 *
 * If the model declares no `causal_mask` input, writing it is skipped.
 * When the mask has a fully static shape, each step writes the whole tensor:
 * 0 (visible) for cache slots `0..kvCachePosition` and [MASKED_VALUE] for
 * later slots that have not been written yet. If the declared shape has
 * dynamic/unknown dims, the legacy behavior of writing `kvCachePosition + 1`
 * zeros is kept.
 *
 * [close] is idempotent; [runStep] after [close] throws
 * [LiteRTBackendError.GenerationFailed] instead of touching freed native state.
 *
 * This class is NOT thread-safe. [LiteRTHandlers] serializes all calls
 * behind a mutex; do the same in any other call site.
 */
internal class LiteRTEngine(
    modelPath: String,
    private val inputName: String = "input_ids",
    private val causalMaskName: String = "causal_mask",
    private val outputName: String = "logits",
    /** Surface override so the handler / config layer can lift it from start opts. Not read by this class yet. */
    @Suppress("UNUSED_PARAMETER")
    private val contextSize: Int = 2048,
    eosTokenId: Int,
    accelerator: Accelerator = Accelerator.CPU,
) : LiteRTEngineApi, AutoCloseable {

    private val model: CompiledModel
    override val vocabSize: Int
    override val eosTokenId: Int = eosTokenId
    private val hasCausalMask: Boolean
    private val causalMaskRank: Int

    /** Element count of the mask tensor when every declared dim is static (> 0); 0 otherwise. */
    private val causalMaskElementCount: Int

    /** Declared length of the mask's last (key/value) axis when static; 0 otherwise. */
    private val causalMaskKvLen: Int
    private val inputIsInt64: Boolean

    /** Set by [close]; guards against double-close and use-after-close. */
    private var closed = false

    init {
        val f = File(modelPath)
        if (!f.isFile) {
            throw LiteRTBackendError.ModelLoadFailed("model file not found at $modelPath")
        }

        model = try {
            CompiledModel.create(modelPath, CompiledModel.Options(accelerator))
        } catch (t: Throwable) {
            throw LiteRTBackendError.ModelLoadFailed("CompiledModel.create failed: ${t.message ?: t::class.java.simpleName}")
        }

        // Validate that the input_ids tensor exists and capture its dtype for the
        // writeInt / writeLong choice. Shape is not enforced here; the runtime
        // reports a mismatch itself.
        val inputType = try {
            model.getInputTensorType(inputName)
        } catch (t: Throwable) {
            // Fail fast: a silent fallback would only surface the mismatch
            // deep inside the native run call.
            model.close()
            throw LiteRTBackendError.ModelLoadFailed(
                "input tensor '$inputName' not found on model (override via litertInputName opt). " +
                    "Cause: ${t.message ?: t::class.java.simpleName}",
            )
        }
        // Both INT32 and INT64 input_ids are accepted; anything else is rejected.
        if (inputType.elementType != TensorType.ElementType.INT &&
            inputType.elementType != TensorType.ElementType.INT64
        ) {
            model.close()
            throw LiteRTBackendError.ModelLoadFailed(
                "input tensor '$inputName' has unsupported elementType=${inputType.elementType}; expected INT or INT64",
            )
        }
        inputIsInt64 = inputType.elementType == TensorType.ElementType.INT64

        // causal_mask is optional. Probe it with try/catch because the lookup
        // throws for a missing name.
        var maskRank = 0
        var maskDims: List<Int> = emptyList()
        val maskPresent = try {
            val maskType = model.getInputTensorType(causalMaskName)
            maskRank = maskType.layout?.rank ?: 0
            maskDims = maskType.layout?.dimensions ?: emptyList()
            true
        } catch (_: Throwable) {
            false
        }
        hasCausalMask = maskPresent
        causalMaskRank = maskRank
        val staticMaskCount = staticElementCount(maskDims)
        causalMaskElementCount = staticMaskCount
        causalMaskKvLen = if (staticMaskCount > 0) maskDims.last() else 0

        // Discover the vocab size from the logits tensor: [1, 1, V] or [1, V].
        val outputType = try {
            model.getOutputTensorType(outputName)
        } catch (t: Throwable) {
            model.close()
            throw LiteRTBackendError.ModelLoadFailed(
                "output tensor '$outputName' not found on model (override via litertOutputName opt). " +
                    "Cause: ${t.message ?: t::class.java.simpleName}",
            )
        }
        val outDims = outputType.layout?.dimensions ?: emptyList()
        vocabSize = when (outDims.size) {
            3 -> outDims[2] // [1, 1, V] — Llama-3 style
            2 -> outDims[1] // [1, V] — some Gemma exports
            else -> {
                model.close()
                throw LiteRTBackendError.ModelLoadFailed(
                    "output tensor '$outputName' has unsupported rank=${outDims.size}; expected 2 or 3 with vocab as last dim",
                )
            }
        }
        if (vocabSize <= 0) {
            model.close()
            throw LiteRTBackendError.ModelLoadFailed(
                "output tensor '$outputName' reports non-positive vocab size: $vocabSize",
            )
        }
    }

    override fun runStep(token: Int, kvCachePosition: Int): FloatArray {
        if (closed) {
            throw LiteRTBackendError.GenerationFailed("runStep called after close()")
        }
        val inputs = mutableMapOf<String, TensorBuffer>()
        val outputs = mutableMapOf<String, TensorBuffer>()
        val opened = mutableListOf<TensorBuffer>()
        try {
            // input_ids: the new token, written with the dtype captured at init.
            val inputBuf = model.createInputBuffer(inputName)
            opened.add(inputBuf)
            if (inputIsInt64) {
                inputBuf.writeLong(longArrayOf(token.toLong()))
            } else {
                inputBuf.writeInt(intArrayOf(token))
            }
            inputs[inputName] = inputBuf

            // causal_mask: always written as FP32 via writeFloat.
            // NOTE(review): assumes the runtime converts FP32 input if the model
            // declares an FP16 mask — confirm per checkpoint.
            if (hasCausalMask && causalMaskRank > 0) {
                val maskBuf = model.createInputBuffer(causalMaskName)
                opened.add(maskBuf)
                maskBuf.writeFloat(buildCausalMask(kvCachePosition))
                inputs[causalMaskName] = maskBuf
            }

            val outputBuf = model.createOutputBuffer(outputName)
            opened.add(outputBuf)
            outputs[outputName] = outputBuf

            try {
                model.run(inputs, outputs)
            } catch (t: Throwable) {
                throw LiteRTBackendError.GenerationFailed(
                    "model.run failed at kvPos=$kvCachePosition token=$token: ${t.message ?: t::class.java.simpleName}",
                )
            }

            val raw = outputBuf.readFloat()
            // A [1, T, V] output would hold several rows; the last V floats are
            // the next-token prediction. For [1, 1, V] this is the whole buffer.
            return if (raw.size == vocabSize) {
                raw
            } else {
                val start = raw.size - vocabSize
                if (start < 0) {
                    throw LiteRTBackendError.GenerationFailed(
                        "logits buffer length ${raw.size} is smaller than vocabSize $vocabSize",
                    )
                }
                raw.copyOfRange(start, raw.size)
            }
        } finally {
            // createInputBuffer / createOutputBuffer allocate fresh native handles
            // on every call; release them all.
            for (buf in opened) {
                runCatching { buf.close() }
            }
        }
    }

    /**
     * Mask contents for a decode step at [kvCachePosition].
     *
     * Static shape: slots `0..kvCachePosition` are 0 (visible), later slots are
     * [MASKED_VALUE]. The pattern repeats per row if the mask has several rows.
     * Dynamic/unknown shape: `kvCachePosition + 1` zeros (legacy behavior).
     *
     * @throws LiteRTBackendError.GenerationFailed if the position is beyond a
     *   static mask's length.
     */
    private fun buildCausalMask(kvCachePosition: Int): FloatArray {
        if (causalMaskElementCount <= 0 || causalMaskKvLen <= 0) {
            return FloatArray(maxOf(1, kvCachePosition + 1))
        }
        if (kvCachePosition >= causalMaskKvLen) {
            throw LiteRTBackendError.GenerationFailed(
                "kvCachePosition=$kvCachePosition exceeds causal mask length $causalMaskKvLen",
            )
        }
        return FloatArray(causalMaskElementCount) { i ->
            if (i % causalMaskKvLen <= kvCachePosition) 0f else MASKED_VALUE
        }
    }

    override fun close() {
        if (closed) return
        closed = true
        runCatching { model.close() }
    }

    companion object {
        /**
         * Additive mask value for not-yet-visible cache slots. It is finite
         * (HF `finfo.min` convention) rather than -inf to avoid inf arithmetic
         * producing NaN.
         */
        private const val MASKED_VALUE: Float = -Float.MAX_VALUE

        /** Product of [dims] when all are positive and the total fits in an Int; 0 otherwise. */
        private fun staticElementCount(dims: List<Int>): Int {
            if (dims.isEmpty() || dims.any { it <= 0 }) return 0
            val total = dims.fold(1L) { acc, d -> acc * d }
            return if (total <= Int.MAX_VALUE) total.toInt() else 0
        }
    }
}