@dvai-bridge/android-litert-core 4.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,199 @@
1
+ package co.deepvoiceai.bridge.litert.core
2
+
3
+ import co.deepvoiceai.bridge.litert.core.Internal.HFTokenizerJson
4
+ import co.deepvoiceai.bridge.litert.core.Internal.LiteRTEngine
5
+ import co.deepvoiceai.bridge.litert.core.Internal.LiteRTGenerator
6
+ import co.deepvoiceai.bridge.litert.core.Internal.LiteRTSampler
7
+ import co.deepvoiceai.bridge.shared.core.CorsConfig
8
+ import co.deepvoiceai.bridge.shared.core.HandlerContext
9
+ import co.deepvoiceai.bridge.shared.core.HttpServer
10
+ import com.google.ai.edge.litert.Accelerator
11
+ import kotlinx.coroutines.sync.Mutex
12
+ import kotlinx.coroutines.sync.withLock
13
+ import kotlin.random.Random
14
+
15
/**
 * Owns the running state of the LiteRT core: the engine, the tokenizer,
 * the HTTP server, and model metadata. [start] and [stop] are serialized
 * through a [Mutex] so concurrent start/stop calls can never race against
 * the underlying [com.google.ai.edge.litert.CompiledModel] (whose internal
 * KV-cache state is single-conversation by design). [statusInfo] is a
 * lock-free read; see its documentation for the consistency caveat.
 *
 * Capacitor-free: opts are plain `Map<String, Any?>` and return values
 * are plain `Map<String, Any?>`. The Capacitor wrapper translates
 * JSObject ↔ Map at the JS bridge boundary (consistent with the
 * existing llama-core / mediapipe-core PluginState shape).
 *
 * Required `start()` opts:
 *   - `modelPath`            String — absolute path to the .tflite / .litertlm file.
 *   - `tokenizerPath`        String — absolute path to the HF `tokenizer.json`.
 *
 * Optional `start()` opts (all have sane defaults):
 *   - `contextSize`          Int     default 2048
 *   - `temperature`          Float   default 0.0 (greedy)
 *   - `topP`                 Float   default 1.0
 *   - `topK`                 Int     default 0 (disabled)
 *   - `maxNewTokens`         Int     default 256
 *   - `eosTokenId`           Int     override tokenizer.json's discovered EOS
 *   - `accelerator`          String  "cpu" | "gpu" | "npu"  default "cpu"
 *   - `chatTemplate`         String  "llama3" | "plain"     default "llama3"
 *   - `litertInputName`      String  default "input_ids"
 *   - `litertCausalMaskName` String  default "causal_mask"
 *   - `litertOutputName`     String  default "logits"
 *   - `httpBasePort`         Int     default 38883
 *   - `httpMaxPortAttempts`  Int     default 16
 *   - `corsOrigin`           (see [CorsConfig.fromOpt])
 *   - `samplerSeed`          Long    default System.nanoTime()
 *   - `modelId`              String  default modelPath basename
 */
class LiteRTPluginState {
    private val mutex = Mutex()
    private var server: HttpServer? = null
    private var engine: LiteRTEngine? = null
    private var modelId: String = ""
    private var isRunning: Boolean = false
    private var baseUrl: String? = null
    private var port: Int? = null

    /**
     * Loads the tokenizer and engine described by [opts], binds the local
     * HTTP server, and records the new running state. If an instance is
     * already running it is stopped first (before the new opts are
     * validated).
     *
     * @param opts option map; see the class documentation for the keys.
     * @return a map with `baseUrl`, `port`, `backend` and `modelId`.
     * @throws IllegalArgumentException when `modelPath` / `tokenizerPath`
     *   is missing or empty, or `accelerator` / `chatTemplate` has an
     *   unsupported value.
     *
     * Tokenizer and engine load errors (expected to be
     * `LiteRTBackendError.TokenizerLoadFailed` / `ModelLoadFailed` —
     * confirm against the loaders) propagate to the caller unchanged.
     */
    suspend fun start(opts: Map<String, Any?>): Map<String, Any?> = mutex.withLock {
        if (isRunning) stopInternal()

        val modelPath = (opts["modelPath"] as? String)?.takeIf { it.isNotEmpty() }
            ?: throw IllegalArgumentException("modelPath is required for litert backend")
        val tokenizerPath = (opts["tokenizerPath"] as? String)?.takeIf { it.isNotEmpty() }
            ?: throw IllegalArgumentException("tokenizerPath is required for litert backend")

        val contextSize = (opts["contextSize"] as? Number)?.toInt() ?: 2048
        val temperature = (opts["temperature"] as? Number)?.toFloat() ?: 0.0f
        val topP = (opts["topP"] as? Number)?.toFloat() ?: 1.0f
        val topK = (opts["topK"] as? Number)?.toInt() ?: 0
        val maxNewTokens = (opts["maxNewTokens"] as? Number)?.toInt() ?: 256
        val eosOverride = (opts["eosTokenId"] as? Number)?.toInt()
        val acceleratorOpt = (opts["accelerator"] as? String)?.lowercase()
        val accelerator = when (acceleratorOpt) {
            "gpu" -> Accelerator.GPU
            "npu" -> Accelerator.NPU
            null, "cpu" -> Accelerator.CPU
            else -> throw IllegalArgumentException("invalid 'accelerator' opt: $acceleratorOpt (expected cpu|gpu|npu)")
        }
        val chatTemplateOpt = (opts["chatTemplate"] as? String)?.lowercase()
        val chatTemplate = when (chatTemplateOpt) {
            null, "llama3" -> ChatTemplateRenderer.LLAMA3
            "plain" -> ChatTemplateRenderer.PLAIN
            else -> throw IllegalArgumentException("invalid 'chatTemplate' opt: $chatTemplateOpt (expected llama3|plain)")
        }
        val inputName = (opts["litertInputName"] as? String) ?: "input_ids"
        val causalMaskName = (opts["litertCausalMaskName"] as? String) ?: "causal_mask"
        val outputName = (opts["litertOutputName"] as? String) ?: "logits"
        val httpBasePort = (opts["httpBasePort"] as? Number)?.toInt() ?: 38883
        val httpMaxPortAttempts = (opts["httpMaxPortAttempts"] as? Number)?.toInt() ?: 16
        val corsConfig = CorsConfig.fromOpt(opts["corsOrigin"])
        val samplerSeed = (opts["samplerSeed"] as? Number)?.toLong() ?: System.nanoTime()
        val resolvedModelId = (opts["modelId"] as? String)?.takeIf { it.isNotEmpty() }
            ?: deriveModelId(modelPath)

        // Load tokenizer FIRST so we can pick up its discovered EOS for the
        // engine constructor. Tokenizer load failure is cheap; engine load
        // is expensive (loads the full .tflite into memory).
        val tokenizer = HFTokenizerJson.load(tokenizerPath, eosTokenIdOverride = eosOverride)

        val newEngine = LiteRTEngine(
            modelPath = modelPath,
            inputName = inputName,
            causalMaskName = causalMaskName,
            outputName = outputName,
            contextSize = contextSize,
            eosTokenId = tokenizer.eosTokenId,
            accelerator = accelerator,
        )

        // From here on the engine is live. ANY failure before the state
        // fields are committed (pipeline construction as well as the port
        // bind) must release it, otherwise the native handle leaks.
        val (newServer, boundPort) = try {
            val sampler = LiteRTSampler(
                temperature = temperature,
                topP = topP,
                topK = topK,
                random = Random(samplerSeed),
            )
            val generator = LiteRTGenerator(
                engine = newEngine,
                tokenizer = tokenizer,
                sampler = sampler,
                maxNewTokens = maxNewTokens,
            )
            val handlers = LiteRTHandlers(
                generator = generator,
                modelId = resolvedModelId,
                chatTemplate = chatTemplate,
                maxNewTokensDefault = maxNewTokens,
            )
            val ctx = HandlerContext(modelId = resolvedModelId, backendName = BACKEND_NAME)

            val httpServer = HttpServer()
            httpServer to httpServer.startWithRoutes(
                basePort = httpBasePort,
                maxAttempts = httpMaxPortAttempts,
                host = LOOPBACK_HOST,
                handlers = handlers,
                ctx = ctx,
                corsConfig = corsConfig,
            )
        } catch (t: Throwable) {
            // Best-effort release; keep the original failure as the primary
            // error and attach any close() failure instead of dropping it.
            runCatching { newEngine.close() }
                .exceptionOrNull()
                ?.let { closeFailure -> t.addSuppressed(closeFailure) }
            throw t
        }

        val url = "http://$LOOPBACK_HOST:$boundPort/v1"
        engine = newEngine
        server = newServer
        modelId = resolvedModelId
        port = boundPort
        baseUrl = url
        isRunning = true

        mapOf(
            "baseUrl" to url,
            "port" to boundPort,
            "backend" to BACKEND_NAME,
            "modelId" to resolvedModelId,
        )
    }

    /** Stops the HTTP server, closes the engine and clears all state, under the mutex. */
    suspend fun stop() = mutex.withLock { stopInternal() }

    /**
     * Tears down the server and engine and resets every state field.
     * Must be called with [mutex] held.
     *
     * The engine close and the state reset run in a `finally` block so a
     * failing `server.stop()` can neither leak the engine nor leave the
     * fields claiming the service is still running; the stop failure itself
     * still propagates to the caller.
     */
    private suspend fun stopInternal() {
        try {
            server?.stop()
        } finally {
            runCatching { engine?.close() }
            server = null
            engine = null
            modelId = ""
            baseUrl = null
            port = null
            isRunning = false
        }
    }

    /**
     * Returns a status snapshot: `running` always, `baseUrl` when one is
     * recorded, and `backend` while running.
     *
     * NOTE(review): this reads the state fields without taking [mutex], so
     * a call that overlaps [start] / [stop] may observe a partially updated
     * snapshot — confirm callers tolerate that.
     */
    fun statusInfo(): Map<String, Any?> = buildMap {
        put("running", isRunning)
        baseUrl?.let { put("baseUrl", it) }
        if (isRunning) put("backend", BACKEND_NAME)
    }

    /**
     * Best-effort default model id from a model file path: strip the
     * directory prefix (either `/` or `\` separated) and any `.tflite` /
     * `.litertlm` extension. Falls back to `"litert-default"` when nothing
     * is left. Mirrors `MediaPipePluginState.deriveModelId`.
     */
    private fun deriveModelId(modelPath: String): String {
        val name = modelPath.substringAfterLast('/').substringAfterLast('\\')
        val stripped = name.removeSuffix(".tflite").removeSuffix(".litertlm")
        return stripped.ifEmpty { "litert-default" }
    }

    private companion object {
        /** Backend identifier reported to callers and to the handler context. */
        const val BACKEND_NAME = "litert"

        /** The server only ever binds to the loopback interface. */
        const val LOOPBACK_HOST = "127.0.0.1"
    }
}
@@ -0,0 +1,234 @@
1
+ package co.deepvoiceai.bridge.litert.core.Internal
2
+
3
+ import kotlinx.coroutines.runBlocking
4
+ import org.junit.Assert.assertEquals
5
+ import org.junit.Assert.assertTrue
6
+ import org.junit.Test
7
+ import kotlin.random.Random
8
+
9
/**
 * Control-flow tests for the [LiteRTGenerator] decode loop. A fake
 * [LiteRTEngineApi] serves scripted logits, which lets the tests check that:
 *
 *   - Greedy decoding emits tokens in the order the scripted logits dictate.
 *   - Sampling EOS ends the loop early.
 *   - maxNewTokens bounds the loop when EOS never shows up.
 *   - kvCachePosition grows by one on every runStep call.
 *
 * The tokenizer is a real [HFTokenizerJson] loaded from a tiny synthetic
 * tokenizer.json written to a temp file, so no model assets are needed.
 */
class LiteRTGeneratorMockTest {

    /** Accumulates each (token, kvCachePosition) pair handed to the engine. */
    private class CallLog {
        val calls = mutableListOf<Pair<Int, Int>>()
    }

    /**
     * Scripted engine. The first `promptLen - 1` calls are early prefill
     * steps whose logits the generator discards, so they carry no
     * preference. Starting with the last prefill call, the argmax of the
     * returned logits walks through [predictionsAfterPrefill]; once the
     * script is exhausted the argmax is [eosTokenId].
     *
     * @param promptLen number of prefill calls the generator is expected to
     *                  make (= prompt token count).
     */
    private class FakeEngine(
        override val vocabSize: Int,
        override val eosTokenId: Int,
        private val promptLen: Int,
        private val predictionsAfterPrefill: List<Int>,
        private val log: CallLog,
    ) : LiteRTEngineApi {
        private var callIdx = 0

        override fun runStep(token: Int, kvCachePosition: Int): FloatArray {
            log.calls += token to kvCachePosition
            val logits = FloatArray(vocabSize) { -10f }

            // Script index 0 corresponds to the LAST prefill call; anything
            // before it is negative and gets the flat, preference-free logits.
            val scriptIdx = callIdx - (promptLen - 1)
            callIdx += 1
            if (scriptIdx >= 0) {
                val winner = predictionsAfterPrefill.getOrElse(scriptIdx) { eosTokenId }
                logits[winner] = 100f // dominant
            }
            return logits
        }

        override fun close() {}
    }

    /**
     * Loads a minimal BPE tokenizer through the normal JSON loader: the
     * vocab is the five letters 'a'..'e' (ids 0..4), there are no merges,
     * and `<|eot_id|>` is registered as a special added token with id
     * [eosId]. The JSON goes to a temp file that is removed on JVM exit.
     */
    private fun buildTinyTokenizer(eosId: Int = 99): HFTokenizerJson {
        val json = """
            {
              "model": {
                "type": "BPE",
                "vocab": { "a": 0, "b": 1, "c": 2, "d": 3, "e": 4 },
                "merges": []
              },
              "added_tokens": [
                { "id": $eosId, "content": "<|eot_id|>", "special": true }
              ]
            }
        """.trimIndent()
        val tokenizerFile = java.io.File.createTempFile("litert-test-tokenizer", ".json").apply {
            deleteOnExit()
            writeText(json)
        }
        return HFTokenizerJson.load(tokenizerFile.absolutePath)
    }

    /** Greedy (temperature 0) sampler shared by every test in this class. */
    private fun greedySampler() =
        LiteRTSampler(temperature = 0f, topP = 1f, topK = 0, random = Random(42))

    /**
     * Wires a generator around a [FakeEngine] (vocab 100, EOS 99) that
     * follows [script], using the tiny tokenizer and the greedy sampler.
     */
    private fun scriptedGenerator(
        promptLen: Int,
        script: List<Int>,
        maxNewTokens: Int,
        log: CallLog,
    ): LiteRTGenerator {
        val tokenizer = buildTinyTokenizer(eosId = 99)
        val engine = FakeEngine(
            vocabSize = 100,
            eosTokenId = 99,
            promptLen = promptLen,
            predictionsAfterPrefill = script,
            log = log,
        )
        return LiteRTGenerator(engine, tokenizer, greedySampler(), maxNewTokens = maxNewTokens)
    }

    @Test
    fun `generator emits scripted tokens in order`() = runBlocking<Unit> {
        // "abc" -> [0, 1, 2], so three prefill calls. The script then yields
        // 3 ('d'), 4 ('e') and finally 99 (EOS): decoded text is "de".
        val gen = scriptedGenerator(
            promptLen = 3,
            script = listOf(3, 4, 99),
            maxNewTokens = 16,
            log = CallLog(),
        )

        assertEquals("de", gen.generate("abc"))
    }

    @Test
    fun `generator stops on EOS`() = runBlocking<Unit> {
        // "a" -> [0]: a single prefill call whose argmax is already EOS, so
        // nothing is emitted and the engine is never fed a decode token.
        val log = CallLog()
        val gen = scriptedGenerator(
            promptLen = 1,
            script = listOf(99),
            maxNewTokens = 99,
            log = log,
        )

        val output = gen.generate("a")

        assertEquals("", output)
        assertEquals(1, log.calls.size)
        assertEquals(0 to 0, log.calls[0])
    }

    @Test
    fun `generator caps at maxNewTokens when EOS never emitted`() = runBlocking<Unit> {
        // The script is 100 consecutive 'd' predictions and never reaches
        // EOS; maxNewTokens = 5 is the only thing that ends the loop.
        val gen = scriptedGenerator(
            promptLen = 3,
            script = List(100) { 3 },
            maxNewTokens = 5,
            log = CallLog(),
        )

        assertEquals("ddddd", gen.generate("abc"))
    }

    @Test
    fun `kvCachePosition increments monotonically across calls`() = runBlocking<Unit> {
        val log = CallLog()
        val gen = scriptedGenerator(
            promptLen = 3,
            script = listOf(3, 4, 99),
            maxNewTokens = 16,
            log = log,
        )

        gen.generate("abc")

        // (token, kvPos) per call:
        //   (0, 0), (1, 1)  early prefill — logits ignored
        //   (2, 2)          last prefill  -> samples 3
        //   (3, 3)          decode feed   -> samples 4
        //   (4, 4)          decode feed   -> samples 99 (EOS), loop halts
        val expectedCalls = listOf(0 to 0, 1 to 1, 2 to 2, 3 to 3, 4 to 4)
        assertEquals(5, log.calls.size)
        expectedCalls.forEachIndexed { idx, call ->
            assertEquals(call, log.calls[idx])
        }
        log.calls.forEachIndexed { idx, (_, kvPos) ->
            assertEquals(idx, kvPos)
        }
    }

    @Test
    fun `empty prompt throws GenerationFailed`() = runBlocking<Unit> {
        val log = CallLog()
        val gen = scriptedGenerator(
            promptLen = 0,
            script = listOf(3),
            maxNewTokens = 4,
            log = log,
        )

        try {
            gen.generate("")
            org.junit.Assert.fail("expected GenerationFailed for empty prompt")
        } catch (failure: co.deepvoiceai.bridge.litert.core.LiteRTBackendError.GenerationFailed) {
            assertTrue(failure.message!!.contains("empty"))
        }
        // Validation happens before the first runStep, so the engine must
        // not have been touched at all.
        assertEquals(0, log.calls.size)
    }
}
@@ -0,0 +1,136 @@
1
+ package co.deepvoiceai.bridge.litert.core.Internal
2
+
3
+ import org.junit.Assert.assertEquals
4
+ import org.junit.Assert.assertNotEquals
5
+ import org.junit.Assert.assertTrue
6
+ import org.junit.Test
7
+ import kotlin.random.Random
8
+
9
/**
 * Unit tests for [LiteRTSampler]. The sampler is pure Kotlin, so plain
 * JUnit 4 is enough — no Robolectric or Android runtime.
 *
 * The sampler mirrors `CoreMLSampler.swift` (iOS); the cases here are meant
 * to line up with the iOS suite so that a behavioural divergence shows up
 * on both platforms.
 */
class LiteRTSamplerTest {

    /** Builds a sampler whose RNG is seeded with [seed]. */
    private fun samplerOf(temperature: Float, topP: Float, topK: Int, seed: Int): LiteRTSampler =
        LiteRTSampler(temperature = temperature, topP = topP, topK = topK, random = Random(seed))

    /** Draws [draws] times from [sampler] and requires every draw to equal [expected]. */
    private fun assertEveryDrawIs(expected: Int, sampler: LiteRTSampler, logits: FloatArray, draws: Int) {
        repeat(draws) {
            assertEquals(expected, sampler.sample(logits))
        }
    }

    /**
     * With temperature == 0 the sampler must be a pure argmax no matter
     * what topP / topK say. This is the deterministic default used for
     * chat and function calling.
     */
    @Test
    fun `temperature zero is argmax`() {
        val greedy = samplerOf(temperature = 0f, topP = 1f, topK = 0, seed = 42)
        // Index 1 holds the maximum (5.0); repeated draws must never vary.
        val logits = floatArrayOf(0.1f, 5.0f, -1.0f, 4.99f)
        assertEveryDrawIs(expected = 1, sampler = greedy, logits = logits, draws = 8)
    }

    /**
     * topK = 1 leaves a single candidate after truncation, so even with a
     * positive temperature the multinomial draw can only land on the argmax.
     */
    @Test
    fun `topK one with positive temperature is deterministic argmax`() {
        val topOne = samplerOf(temperature = 1.0f, topP = 1f, topK = 1, seed = 42)
        // Index 2 holds the maximum (3.7).
        val logits = floatArrayOf(-2.0f, 0.5f, 3.7f, 1.2f)
        assertEveryDrawIs(expected = 2, sampler = topOne, logits = logits, draws = 16)
    }

    /**
     * topP truncation: when one logit alone carries more than the 0.5
     * cumulative-probability cutoff, the nucleus contains just that index
     * and sampling degenerates to argmax.
     */
    @Test
    fun `topP truncation keeps cumulative probability cutoff`() {
        val nucleus = samplerOf(temperature = 1.0f, topP = 0.5f, topK = 0, seed = 42)
        // logits[3] = 10 takes well over 99% of the softmax mass, so the
        // top-p = 0.5 nucleus is exactly {3}.
        val logits = floatArrayOf(0f, 0f, 0f, 10f, 0f)
        assertEveryDrawIs(expected = 3, sampler = nucleus, logits = logits, draws = 16)
    }

    /**
     * On a flat distribution with no truncation, two differently seeded
     * samplers must disagree somewhere — a regression guard against the
     * multinomial path being accidentally replaced by argmax.
     */
    @Test
    fun `flat distribution with different seeds yields varied samples`() {
        val flat = FloatArray(8) { 0f } // softmax = uniform 1/8
        val first = samplerOf(temperature = 1f, topP = 1f, topK = 0, seed = 1)
        val second = samplerOf(temperature = 1f, topP = 1f, topK = 0, seed = 99)
        // Independent uniform draws over 8 buckets agree on all 5 rounds
        // with probability (1/8)^5 ~ 3e-5. All 5 rounds are always drawn.
        val disagreements = (0 until 5).count { first.sample(flat) != second.sample(flat) }
        assertTrue("expected varied samples across different seeds", disagreements > 0)
    }

    /**
     * Same seed and same logits must give the same sample sequence. Guards
     * against hidden static state or an unstable sort inside the sampler.
     */
    @Test
    fun `same seed produces identical sequence`() {
        val logits = floatArrayOf(1f, 2f, 3f, 4f, 5f)
        val left = samplerOf(temperature = 0.7f, topP = 0.9f, topK = 4, seed = 42)
        val right = samplerOf(temperature = 0.7f, topP = 0.9f, topK = 4, seed = 42)
        repeat(32) { i ->
            val fromLeft = left.sample(logits)
            val fromRight = right.sample(logits)
            assertEquals("sample $i", fromLeft, fromRight)
        }
    }

    /**
     * A one-element logits array leaves index 0 as the only possible
     * answer, whatever the sampler configuration.
     */
    @Test
    fun `single element logits returns index zero`() {
        val logits = floatArrayOf(2.0f)
        val samplers = listOf(
            samplerOf(temperature = 0f, topP = 1f, topK = 0, seed = 42),
            samplerOf(temperature = 1f, topP = 1f, topK = 0, seed = 42),
            samplerOf(temperature = 1f, topP = 0.5f, topK = 1, seed = 42),
        )
        samplers.forEach { sampler ->
            assertEquals(0, sampler.sample(logits))
        }
    }

    /**
     * Soft sanity check: on a nearly uniform distribution with a positive
     * temperature, 50 draws must include at least one index other than the
     * argmax — otherwise the sampler is silently behaving greedily.
     */
    @Test
    fun `non greedy sampling is not always argmax`() {
        val logits = floatArrayOf(1f, 1.05f, 1f, 1f) // near-uniform with slight tilt
        val sampler = samplerOf(temperature = 1f, topP = 1f, topK = 0, seed = 7)
        // Index 1 is the argmax; count every draw that lands elsewhere.
        val nonArgmax = (0 until 50).count { sampler.sample(logits) != 1 }
        assertNotEquals("expected at least some non-argmax draws", 0, nonArgmax)
    }
}
package/package.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "name": "@dvai-bridge/android-litert-core",
3
+ "version": "4.0.0",
4
+ "description": "DVAI-Bridge Android LiteRT core — wraps Google's LiteRT (TFLite) runtime with HuggingFace tokenizers, exposes the OpenAI-compatible HTTP surface via android-shared-core. Capacitor-free.",
5
+ "author": "Deep Chakraborty <https://github.com/dk013>",
6
+ "license": "Custom (See LICENSE)",
7
+ "files": [
8
+ "android/src",
9
+ "android/build.gradle",
10
+ "android/gradle.properties",
11
+ "android/settings.gradle",
12
+ "README.md",
13
+ "LICENSE"
14
+ ],
15
+ "publishConfig": {
16
+ "registry": "https://registry.npmjs.org/",
17
+ "access": "public"
18
+ }
19
+ }