@dvai-bridge/android-litert-core 4.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +51 -0
- package/README.md +199 -0
- package/android/build.gradle +180 -0
- package/android/gradle.properties +5 -0
- package/android/settings.gradle +1 -0
- package/android/src/main/AndroidManifest.xml +3 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/Internal/HFTokenizerJson.kt +380 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/Internal/LiteRTEngine.kt +241 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/Internal/LiteRTGenerator.kt +71 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/Internal/LiteRTSampler.kt +105 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/LiteRTBackendError.kt +13 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/LiteRTHandlers.kt +378 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/LiteRTPluginState.kt +199 -0
- package/android/src/test/java/co/deepvoiceai/bridge/litert/core/Internal/LiteRTGeneratorMockTest.kt +234 -0
- package/android/src/test/java/co/deepvoiceai/bridge/litert/core/Internal/LiteRTSamplerTest.kt +136 -0
- package/package.json +19 -0
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
package co.deepvoiceai.bridge.litert.core
|
|
2
|
+
|
|
3
|
+
import co.deepvoiceai.bridge.litert.core.Internal.HFTokenizerJson
|
|
4
|
+
import co.deepvoiceai.bridge.litert.core.Internal.LiteRTEngine
|
|
5
|
+
import co.deepvoiceai.bridge.litert.core.Internal.LiteRTGenerator
|
|
6
|
+
import co.deepvoiceai.bridge.litert.core.Internal.LiteRTSampler
|
|
7
|
+
import co.deepvoiceai.bridge.shared.core.CorsConfig
|
|
8
|
+
import co.deepvoiceai.bridge.shared.core.HandlerContext
|
|
9
|
+
import co.deepvoiceai.bridge.shared.core.HttpServer
|
|
10
|
+
import com.google.ai.edge.litert.Accelerator
|
|
11
|
+
import kotlinx.coroutines.sync.Mutex
|
|
12
|
+
import kotlinx.coroutines.sync.withLock
|
|
13
|
+
import kotlin.random.Random
|
|
14
|
+
|
|
15
|
+
/**
 * Owns the running state of the LiteRT core: the engine, the tokenizer
 * (held through the generator), the HTTP server, and model metadata.
 *
 * [start] and [stop] are serialized through a [Mutex] so concurrent
 * start/stop calls can never race against the underlying
 * [com.google.ai.edge.litert.CompiledModel] (whose internal KV-cache state
 * is single-conversation by design). [statusInfo] is non-suspending and
 * does NOT take the mutex: the fields it reads are `@Volatile`, so it sees
 * the most recently published values, but the snapshot can be momentarily
 * out of date while a start/stop is in flight.
 *
 * Capacitor-free: opts are plain `Map<String, Any?>` and return values
 * are plain `Map<String, Any?>`. The Capacitor wrapper translates
 * JSObject ↔ Map at the JS bridge boundary (consistent with the
 * existing llama-core / mediapipe-core PluginState shape).
 *
 * Required `start()` opts:
 * - `modelPath` String — absolute path to the .tflite / .litertlm file.
 * - `tokenizerPath` String — absolute path to the HF `tokenizer.json`.
 *
 * Optional `start()` opts (all have sane defaults):
 * - `contextSize` Int default 2048
 * - `temperature` Float default 0.0 (greedy)
 * - `topP` Float default 1.0
 * - `topK` Int default 0 (disabled)
 * - `maxNewTokens` Int default 256
 * - `eosTokenId` Int override tokenizer.json's discovered EOS
 * - `accelerator` String "cpu" | "gpu" | "npu" default "cpu"
 * - `chatTemplate` String "llama3" | "plain" default "llama3"
 * - `litertInputName` String default "input_ids"
 * - `litertCausalMaskName` String default "causal_mask"
 * - `litertOutputName` String default "logits"
 * - `httpBasePort` Int default 38883
 * - `httpMaxPortAttempts` Int default 16
 * - `corsOrigin` (see [CorsConfig.fromOpt])
 * - `samplerSeed` Long default System.nanoTime()
 * - `modelId` String default modelPath basename
 *
 * Optional opts of the wrong runtime type are treated as absent and fall
 * back to their defaults (the parsing uses safe casts).
 */
class LiteRTPluginState {
    private val mutex = Mutex()
    private var server: HttpServer? = null
    private var engine: LiteRTEngine? = null
    private var modelId: String = ""

    // Read by statusInfo() without holding the mutex, hence @Volatile.
    @Volatile
    private var isRunning: Boolean = false

    @Volatile
    private var baseUrl: String? = null
    private var port: Int? = null

    /**
     * Starts (or restarts) the backend: parses [opts], loads the tokenizer
     * and the engine, then binds the loopback HTTP server.
     *
     * If an instance is already running it is stopped first, before the
     * new opts are validated.
     *
     * @param opts option map; see the class documentation for the keys.
     * @return a map with `baseUrl`, `port`, `backend` and `modelId`.
     * @throws IllegalArgumentException when a required opt is missing or
     *   empty, or when `accelerator` / `chatTemplate` has an unknown value.
     *   Failures from tokenizer load, engine construction and server
     *   start-up propagate unchanged.
     */
    suspend fun start(opts: Map<String, Any?>): Map<String, Any?> = mutex.withLock {
        if (isRunning) stopInternal()

        val modelPath = requiredString(opts, "modelPath")
        val tokenizerPath = requiredString(opts, "tokenizerPath")

        val contextSize = (opts["contextSize"] as? Number)?.toInt() ?: 2048
        val temperature = (opts["temperature"] as? Number)?.toFloat() ?: 0.0f
        val topP = (opts["topP"] as? Number)?.toFloat() ?: 1.0f
        val topK = (opts["topK"] as? Number)?.toInt() ?: 0
        val maxNewTokens = (opts["maxNewTokens"] as? Number)?.toInt() ?: 256
        val eosOverride = (opts["eosTokenId"] as? Number)?.toInt()
        val acceleratorOpt = (opts["accelerator"] as? String)?.lowercase()
        val accelerator = when (acceleratorOpt) {
            "gpu" -> Accelerator.GPU
            "npu" -> Accelerator.NPU
            null, "cpu" -> Accelerator.CPU
            else -> throw IllegalArgumentException("invalid 'accelerator' opt: $acceleratorOpt (expected cpu|gpu|npu)")
        }
        val chatTemplateOpt = (opts["chatTemplate"] as? String)?.lowercase()
        val chatTemplate = when (chatTemplateOpt) {
            null, "llama3" -> ChatTemplateRenderer.LLAMA3
            "plain" -> ChatTemplateRenderer.PLAIN
            else -> throw IllegalArgumentException("invalid 'chatTemplate' opt: $chatTemplateOpt (expected llama3|plain)")
        }
        val inputName = (opts["litertInputName"] as? String) ?: "input_ids"
        val causalMaskName = (opts["litertCausalMaskName"] as? String) ?: "causal_mask"
        val outputName = (opts["litertOutputName"] as? String) ?: "logits"
        val httpBasePort = (opts["httpBasePort"] as? Number)?.toInt() ?: 38883
        val httpMaxPortAttempts = (opts["httpMaxPortAttempts"] as? Number)?.toInt() ?: 16
        val corsConfig = CorsConfig.fromOpt(opts["corsOrigin"])
        val samplerSeed = (opts["samplerSeed"] as? Number)?.toLong() ?: System.nanoTime()
        val resolvedModelId = (opts["modelId"] as? String)?.takeIf { it.isNotEmpty() }
            ?: deriveModelId(modelPath)

        // Load the tokenizer FIRST so its discovered EOS can be handed to
        // the engine constructor, and so a bad tokenizer path fails before
        // the engine is constructed.
        val tokenizer = HFTokenizerJson.load(tokenizerPath, eosTokenIdOverride = eosOverride)

        val newEngine = LiteRTEngine(
            modelPath = modelPath,
            inputName = inputName,
            causalMaskName = causalMaskName,
            outputName = outputName,
            contextSize = contextSize,
            eosTokenId = tokenizer.eosTokenId,
            accelerator = accelerator,
        )

        // From here on the engine exists. ANY failure before the state is
        // published (not just a failed bind) must release it, otherwise the
        // engine handle leaks.
        val (newServer, boundPort) = try {
            val sampler = LiteRTSampler(
                temperature = temperature,
                topP = topP,
                topK = topK,
                random = Random(samplerSeed),
            )
            val generator = LiteRTGenerator(
                engine = newEngine,
                tokenizer = tokenizer,
                sampler = sampler,
                maxNewTokens = maxNewTokens,
            )
            val handlers = LiteRTHandlers(
                generator = generator,
                modelId = resolvedModelId,
                chatTemplate = chatTemplate,
                maxNewTokensDefault = maxNewTokens,
            )
            val ctx = HandlerContext(modelId = resolvedModelId, backendName = "litert")

            val candidate = HttpServer()
            val bound = candidate.startWithRoutes(
                basePort = httpBasePort,
                maxAttempts = httpMaxPortAttempts,
                host = "127.0.0.1",
                handlers = handlers,
                ctx = ctx,
                corsConfig = corsConfig,
            )
            candidate to bound
        } catch (t: Throwable) {
            // Best-effort release; the original failure is what the caller
            // needs to see, so a secondary close() error is dropped.
            runCatching { newEngine.close() }
            throw t
        }

        val url = "http://127.0.0.1:$boundPort/v1"
        this.engine = newEngine
        this.server = newServer
        this.modelId = resolvedModelId
        this.port = boundPort
        this.baseUrl = url
        this.isRunning = true

        return@withLock mapOf(
            "baseUrl" to url,
            "port" to boundPort,
            "backend" to "litert",
            "modelId" to resolvedModelId,
        )
    }

    /** Stops the server and releases the engine; a no-op when nothing is running. */
    suspend fun stop() = mutex.withLock { stopInternal() }

    /**
     * Tears down the server and engine and resets all state. Must be called
     * with [mutex] held.
     *
     * The engine close and the state reset run in a `finally` block so a
     * failing `server.stop()` cannot leave the engine open or leave this
     * object reporting itself as running. The stop failure still
     * propagates to the caller.
     */
    private suspend fun stopInternal() {
        try {
            server?.stop()
        } finally {
            runCatching { engine?.close() }
            server = null
            engine = null
            modelId = ""
            baseUrl = null
            port = null
            isRunning = false
        }
    }

    /**
     * Returns a lock-free status snapshot: always `running`, plus `baseUrl`
     * when one is set and `backend` when running.
     */
    fun statusInfo(): Map<String, Any?> {
        // Read each field once so the returned map is built from a single
        // observation of each value.
        val running = isRunning
        val url = baseUrl
        return buildMap {
            put("running", running)
            if (url != null) put("baseUrl", url)
            if (running) put("backend", "litert")
        }
    }

    /**
     * Reads a required, non-empty String opt.
     *
     * @throws IllegalArgumentException when the key is absent, not a
     *   String, or empty.
     */
    private fun requiredString(opts: Map<String, Any?>, key: String): String =
        (opts[key] as? String)?.takeIf { it.isNotEmpty() }
            ?: throw IllegalArgumentException("$key is required for litert backend")

    /**
     * Best-effort default model id from a model file path: strip the
     * directory prefix (either `/` or `\` separators) and any `.tflite` /
     * `.litertlm` extension. Falls back to `"litert-default"` when nothing
     * is left. Mirrors `MediaPipePluginState.deriveModelId`.
     */
    private fun deriveModelId(modelPath: String): String {
        val name = modelPath.substringAfterLast('/').substringAfterLast('\\')
        val stripped = name.removeSuffix(".tflite").removeSuffix(".litertlm")
        return stripped.ifEmpty { "litert-default" }
    }
}
|
package/android/src/test/java/co/deepvoiceai/bridge/litert/core/Internal/LiteRTGeneratorMockTest.kt
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
package co.deepvoiceai.bridge.litert.core.Internal
|
|
2
|
+
|
|
3
|
+
import kotlinx.coroutines.runBlocking
|
|
4
|
+
import org.junit.Assert.assertEquals
|
|
5
|
+
import org.junit.Assert.assertTrue
|
|
6
|
+
import org.junit.Test
|
|
7
|
+
import kotlin.random.Random
|
|
8
|
+
|
|
9
|
+
/**
 * Control-flow tests for [LiteRTGenerator]'s decode loop. A scripted fake
 * [LiteRTEngineApi] hands back canned logits, which lets the tests pin:
 *
 * - greedy output follows the order dictated by the engine's logits;
 * - an EOS prediction ends the loop early;
 * - `maxNewTokens` bounds the loop when EOS never shows up;
 * - `kvCachePosition` advances by one on every `runStep` call.
 *
 * The tokenizer is a real [HFTokenizerJson] loaded from a tiny synthetic
 * tokenizer.json written to a temp file, so no model assets are needed.
 */
class LiteRTGeneratorMockTest {

    /** Collects the (token, kvPos) pair of every `runStep` invocation. */
    private class CallLog {
        val calls = mutableListOf<Pair<Int, Int>>()
    }

    /**
     * Scripted engine. The first `promptLen - 1` calls are treated as early
     * prefill and return flat logits (every entry -10, i.e. no preferred
     * token). Starting with call number `promptLen - 1` (the last prefill
     * step), each call's argmax is the next entry of
     * [predictionsAfterPrefill]; once the script is exhausted the argmax is
     * [eosTokenId].
     *
     * @param promptLen number of prefill calls the generator is expected to
     *   make (= prompt token count).
     */
    private class FakeEngine(
        override val vocabSize: Int,
        override val eosTokenId: Int,
        private val promptLen: Int,
        private val predictionsAfterPrefill: List<Int>,
        private val log: CallLog,
    ) : LiteRTEngineApi {
        private var callIdx = 0

        override fun runStep(token: Int, kvCachePosition: Int): FloatArray {
            log.calls += token to kvCachePosition

            // Negative while still in early prefill; 0 on the last prefill
            // call, whose argmax becomes the first sampled token.
            val scriptPos = callIdx++ - (promptLen - 1)

            val logits = FloatArray(vocabSize) { -10f }
            if (scriptPos >= 0) {
                val winner = predictionsAfterPrefill.getOrElse(scriptPos) { eosTokenId }
                logits[winner] = 100f // dominant
            }
            return logits
        }

        override fun close() {}
    }

    /**
     * Loads a tokenizer over a tiny alphabet by writing a minimal valid
     * tokenizer.json to a temp file and running it through the regular
     * loader. Going through the loader avoids reflecting on a private
     * constructor, which would be brittle.
     */
    private fun buildTinyTokenizer(eosId: Int = 99): HFTokenizerJson {
        // BPE model whose vocab is the single chars 'a'..'e' (per the
        // original author's note these survive byte-level mapping unchanged
        // because they sit in printable ASCII 33..126), plus one special
        // added token acting as EOS.
        val tokenizerJson = """
            {
              "model": {
                "type": "BPE",
                "vocab": { "a": 0, "b": 1, "c": 2, "d": 3, "e": 4 },
                "merges": []
              },
              "added_tokens": [
                { "id": $eosId, "content": "<|eot_id|>", "special": true }
              ]
            }
        """.trimIndent()
        val file = java.io.File.createTempFile("litert-test-tokenizer", ".json")
        file.deleteOnExit()
        file.writeText(tokenizerJson)
        return HFTokenizerJson.load(file.absolutePath)
    }

    /**
     * Shared fixture: greedy sampler (temperature 0) over a 100-entry vocab
     * with EOS id 99, driven by a [FakeEngine] following [script].
     */
    private fun scriptedGenerator(
        promptLen: Int,
        script: List<Int>,
        maxNewTokens: Int,
        log: CallLog,
    ): LiteRTGenerator {
        val tokenizer = buildTinyTokenizer(eosId = 99)
        val engine = FakeEngine(
            vocabSize = 100,
            eosTokenId = 99,
            promptLen = promptLen,
            predictionsAfterPrefill = script,
            log = log,
        )
        val sampler = LiteRTSampler(temperature = 0f, topP = 1f, topK = 0, random = Random(42))
        return LiteRTGenerator(engine, tokenizer, sampler, maxNewTokens = maxNewTokens)
    }

    @Test
    fun `generator emits scripted tokens in order`() = runBlocking {
        // "abc" -> [0, 1, 2], so three prefill calls. The script then yields
        // 3 ('d'), 4 ('e') and finally 99 (EOS): decoded text is "de".
        val gen = scriptedGenerator(
            promptLen = 3,
            script = listOf(3, 4, 99),
            maxNewTokens = 16,
            log = CallLog(),
        )

        assertEquals("de", gen.generate("abc"))
    }

    @Test
    fun `generator stops on EOS`() = runBlocking {
        // "a" -> [0]: one prefill call whose argmax is already EOS, so
        // nothing is emitted.
        val log = CallLog()
        val gen = scriptedGenerator(
            promptLen = 1,
            script = listOf(99),
            maxNewTokens = 99,
            log = log,
        )

        val text = gen.generate("a")

        assertEquals("", text)
        // Only the prefill step reached the engine; the decode loop never
        // fed a token back.
        assertEquals(1, log.calls.size)
        assertEquals(0 to 0, log.calls[0])
    }

    @Test
    fun `generator caps at maxNewTokens when EOS never emitted`() = runBlocking {
        // The script is 100 consecutive 'd' predictions with no EOS, so the
        // only thing that can stop the loop is maxNewTokens = 5.
        val gen = scriptedGenerator(
            promptLen = 3,
            script = List(100) { 3 },
            maxNewTokens = 5,
            log = CallLog(),
        )

        assertEquals("ddddd", gen.generate("abc"))
    }

    @Test
    fun `kvCachePosition increments monotonically across calls`() = runBlocking {
        val log = CallLog()
        val gen = scriptedGenerator(
            promptLen = 3,
            script = listOf(3, 4, 99),
            maxNewTokens = 16,
            log = log,
        )

        gen.generate("abc")

        // (token, kvPos) per call:
        //   prefill 'a', prefill 'b'   — early prefill, argmax ignored
        //   prefill 'c'                — last prefill, sample gives 3
        //   feed 3                     — sample gives 4
        //   feed 4                     — sample gives 99, EOS halts
        val expectedCalls = listOf(0 to 0, 1 to 1, 2 to 2, 3 to 3, 4 to 4)
        assertEquals(expectedCalls.size, log.calls.size)
        expectedCalls.forEachIndexed { idx, call ->
            assertEquals(call, log.calls[idx])
        }
        log.calls.forEachIndexed { idx, call ->
            assertEquals(idx, call.second)
        }
    }

    @Test
    fun `empty prompt throws GenerationFailed`() = runBlocking {
        val log = CallLog()
        val gen = scriptedGenerator(
            promptLen = 0,
            script = listOf(3),
            maxNewTokens = 4,
            log = log,
        )

        try {
            gen.generate("")
            org.junit.Assert.fail("expected GenerationFailed for empty prompt")
        } catch (e: co.deepvoiceai.bridge.litert.core.LiteRTBackendError.GenerationFailed) {
            assertTrue(e.message!!.contains("empty"))
        }
        // Validation must reject the prompt before any runStep call.
        assertEquals(0, log.calls.size)
    }
}
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
package co.deepvoiceai.bridge.litert.core.Internal
|
|
2
|
+
|
|
3
|
+
import org.junit.Assert.assertEquals
|
|
4
|
+
import org.junit.Assert.assertNotEquals
|
|
5
|
+
import org.junit.Assert.assertTrue
|
|
6
|
+
import org.junit.Test
|
|
7
|
+
import kotlin.random.Random
|
|
8
|
+
|
|
9
|
+
/**
 * Plain JUnit 4 tests for [LiteRTSampler]; the sampler is pure Kotlin, so
 * neither Robolectric nor an Android runtime is involved.
 *
 * The sampler is the Android counterpart of `CoreMLSampler.swift` (iOS).
 * These cases are written to line up with the equivalent iOS cases so a
 * behavioural divergence between the two shows up in both suites.
 */
class LiteRTSamplerTest {

    /**
     * With temperature == 0 the sampler must behave as pure argmax, whatever
     * topP / topK are. This is the deterministic default used for chat and
     * function calling.
     */
    @Test
    fun `temperature zero is argmax`() {
        val greedy = LiteRTSampler(temperature = 0f, topP = 1f, topK = 0, random = Random(42))
        val logits = floatArrayOf(0.1f, 5.0f, -1.0f, 4.99f)
        // Largest logit is 5.0 at index 1; repeated draws must never vary.
        for (attempt in 0 until 8) {
            assertEquals(1, greedy.sample(logits))
        }
    }

    /**
     * topK = 1 leaves a single candidate, so even with temperature > 0 the
     * multinomial draw can only land on the argmax index.
     */
    @Test
    fun `topK one with positive temperature is deterministic argmax`() {
        val topOne = LiteRTSampler(temperature = 1.0f, topP = 1f, topK = 1, random = Random(42))
        val logits = floatArrayOf(-2.0f, 0.5f, 3.7f, 1.2f)
        // Largest logit is 3.7 at index 2.
        for (attempt in 0 until 16) {
            assertEquals(2, topOne.sample(logits))
        }
    }

    /**
     * topP truncation: when one logit alone holds more than the cumulative
     * cutoff (0.5 here), the nucleus contains just that index.
     */
    @Test
    fun `topP truncation keeps cumulative probability cutoff`() {
        val nucleus = LiteRTSampler(temperature = 1.0f, topP = 0.5f, topK = 0, random = Random(42))
        // Index 3 carries logit 10 against four zeros, i.e. well over 99% of
        // the softmax mass, so top-p = 0.5 keeps only index 3.
        val logits = floatArrayOf(0f, 0f, 0f, 10f, 0f)
        for (attempt in 0 until 16) {
            assertEquals(3, nucleus.sample(logits))
        }
    }

    /**
     * On a flat distribution with no truncation, two different RNG seeds
     * should not produce the same draws. Guards against the multinomial path
     * being accidentally replaced by argmax.
     */
    @Test
    fun `flat distribution with different seeds yields varied samples`() {
        val uniformLogits = FloatArray(8) { 0f } // softmax = uniform 1/8
        val seededA = LiteRTSampler(temperature = 1f, topP = 1f, topK = 0, random = Random(1))
        val seededB = LiteRTSampler(temperature = 1f, topP = 1f, topK = 0, random = Random(99))
        // For independent uniform draws over 8 buckets, five straight
        // agreements have probability (1/8)^5 ~ 3e-5. All five pairs are
        // always drawn; count() does not short-circuit.
        val disagreements = (1..5).count { seededA.sample(uniformLogits) != seededB.sample(uniformLogits) }
        assertTrue("expected varied samples across different seeds", disagreements > 0)
    }

    /**
     * Reproducibility: identical seed and identical logits must give an
     * identical sample sequence. Catches hidden static state or an unstable
     * sort inside the sampler.
     */
    @Test
    fun `same seed produces identical sequence`() {
        val logits = floatArrayOf(1f, 2f, 3f, 4f, 5f)
        val first = LiteRTSampler(temperature = 0.7f, topP = 0.9f, topK = 4, random = Random(42))
        val second = LiteRTSampler(temperature = 0.7f, topP = 0.9f, topK = 4, random = Random(42))
        repeat(32) { i ->
            val fromFirst = first.sample(logits)
            val fromSecond = second.sample(logits)
            assertEquals("sample $i", fromFirst, fromSecond)
        }
    }

    /**
     * A one-element logits array leaves index 0 as the only possible answer
     * for every sampler configuration.
     */
    @Test
    fun `single element logits returns index zero`() {
        val logits = floatArrayOf(2.0f)
        listOf(
            LiteRTSampler(temperature = 0f, topP = 1f, topK = 0, random = Random(42)),
            LiteRTSampler(temperature = 1f, topP = 1f, topK = 0, random = Random(42)),
            LiteRTSampler(temperature = 1f, topP = 0.5f, topK = 1, random = Random(42)),
        ).forEach { sampler ->
            assertEquals(0, sampler.sample(logits))
        }
    }

    /**
     * Soft sanity check: on a nearly uniform distribution with
     * temperature > 0, fifty draws should include at least one index other
     * than the argmax, i.e. the sampler is not silently greedy.
     */
    @Test
    fun `non greedy sampling is not always argmax`() {
        val logits = floatArrayOf(1f, 1.05f, 1f, 1f) // near-uniform with slight tilt
        val sampler = LiteRTSampler(temperature = 1f, topP = 1f, topK = 0, random = Random(7))
        // Index 1 is the argmax; count every draw that lands elsewhere.
        val nonArgmax = (0 until 50).count { sampler.sample(logits) != 1 }
        assertNotEquals("expected at least some non-argmax draws", 0, nonArgmax)
    }
}
|
package/package.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
{
|
|
2
|
+
"name": "@dvai-bridge/android-litert-core",
|
|
3
|
+
"version": "4.0.0",
|
|
4
|
+
"description": "DVAI-Bridge Android LiteRT core — wraps Google's LiteRT (TFLite) runtime with HuggingFace tokenizers, exposes the OpenAI-compatible HTTP surface via android-shared-core. Capacitor-free.",
|
|
5
|
+
"author": "Deep Chakraborty <https://github.com/dk013>",
|
|
6
|
+
"license": "Custom (See LICENSE)",
|
|
7
|
+
"files": [
|
|
8
|
+
"android/src",
|
|
9
|
+
"android/build.gradle",
|
|
10
|
+
"android/gradle.properties",
|
|
11
|
+
"android/settings.gradle",
|
|
12
|
+
"README.md",
|
|
13
|
+
"LICENSE"
|
|
14
|
+
],
|
|
15
|
+
"publishConfig": {
|
|
16
|
+
"registry": "https://registry.npmjs.org/",
|
|
17
|
+
"access": "public"
|
|
18
|
+
}
|
|
19
|
+
}
|