@dvai-bridge/android-litert-core 4.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +51 -0
- package/README.md +199 -0
- package/android/build.gradle +180 -0
- package/android/gradle.properties +5 -0
- package/android/settings.gradle +1 -0
- package/android/src/main/AndroidManifest.xml +3 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/Internal/HFTokenizerJson.kt +380 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/Internal/LiteRTEngine.kt +241 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/Internal/LiteRTGenerator.kt +71 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/Internal/LiteRTSampler.kt +105 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/LiteRTBackendError.kt +13 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/LiteRTHandlers.kt +378 -0
- package/android/src/main/java/co/deepvoiceai/bridge/litert/core/LiteRTPluginState.kt +199 -0
- package/android/src/test/java/co/deepvoiceai/bridge/litert/core/Internal/LiteRTGeneratorMockTest.kt +234 -0
- package/android/src/test/java/co/deepvoiceai/bridge/litert/core/Internal/LiteRTSamplerTest.kt +136 -0
- package/package.json +19 -0
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
package co.deepvoiceai.bridge.litert.core.Internal
|
|
2
|
+
|
|
3
|
+
import co.deepvoiceai.bridge.litert.core.LiteRTBackendError
|
|
4
|
+
import kotlinx.coroutines.Dispatchers
|
|
5
|
+
import kotlinx.coroutines.withContext
|
|
6
|
+
|
|
7
|
+
/**
 * Drives one autoregressive generation: tokenize the prompt, feed it through
 * [LiteRTEngineApi.runStep] one token at a time, then sample and feed back
 * tokens until a stop condition is met.
 *
 * Prefill and decode share the same step call. The logits returned for the
 * last prompt token are used to sample the first generated token, so the
 * last prompt token is never fed a second time.
 *
 * Concurrency: the whole of [generate] runs inside
 * `withContext(Dispatchers.Default)`. This class does no locking of its own;
 * [LiteRTHandlers] wraps every call in a mutex.
 */
internal class LiteRTGenerator(
    private val engine: LiteRTEngineApi,
    private val tokenizer: HFTokenizerJson,
    private val sampler: LiteRTSampler,
    private val maxNewTokens: Int,
) {

    /**
     * Generates a completion for [prompt] and returns only the newly
     * produced text (decoded with `skipSpecialTokens = true`).
     *
     * Generation ends when the sampled token equals `engine.eosTokenId` or
     * when [maxNewTokens] tokens have been collected, whichever comes first.
     *
     * @throws LiteRTBackendError.GenerationFailed when [prompt] tokenizes to
     *   nothing. Exceptions raised by the engine, sampler or tokenizer are
     *   not caught here and propagate as-is.
     */
    suspend fun generate(prompt: String): String = withContext(Dispatchers.Default) {
        val tokenIds = tokenizer.encode(prompt)
        if (tokenIds.isEmpty()) {
            throw LiteRTBackendError.GenerationFailed("prompt tokenized to empty list")
        }

        // Prefill: during this phase the KV-cache position equals the token's
        // index in the prompt.
        var logits = engine.runStep(tokenIds[0], kvCachePosition = 0)
        for (position in 1 until tokenIds.size) {
            logits = engine.runStep(tokenIds[position], kvCachePosition = position)
        }
        // Next free KV-cache slot once the whole prompt has been consumed.
        var cursor = tokenIds.size

        // Decode: the list size doubles as the "tokens produced" counter.
        val completionIds = mutableListOf<Int>()
        var candidate = sampler.sample(logits)
        while (completionIds.size < maxNewTokens && candidate != engine.eosTokenId) {
            completionIds += candidate
            // Cap reached: skip the forward pass whose output would be thrown away.
            if (completionIds.size >= maxNewTokens) break
            candidate = sampler.sample(engine.runStep(candidate, kvCachePosition = cursor))
            cursor += 1
        }

        tokenizer.decode(completionIds, skipSpecialTokens = true)
    }
}
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
package co.deepvoiceai.bridge.litert.core.Internal
|
|
2
|
+
|
|
3
|
+
import kotlin.math.exp
|
|
4
|
+
import kotlin.random.Random
|
|
5
|
+
|
|
6
|
+
/**
 * Greedy + temperature/top-p/top-k sampler for LiteRT logits. Pure Kotlin,
 * no LiteRT or HuggingFace tokenizer dependency — purely numerical.
 *
 * Mirrors `CoreMLSampler.swift` (iOS); algorithm comments are kept in
 * lockstep with the iOS implementation so future fixes apply uniformly.
 *
 * @param temperature `<= 0` selects pure argmax (deterministic). `> 0`
 *                    enables temperature scaling followed by sampling.
 * @param topP Nucleus-sampling cutoff. Active only when `0 < topP < 1`;
 *             any other value (including 1.0) disables it.
 * @param topK `<= 0` = disabled. `> 0` = keep only the K highest-probability
 *             tokens before sampling.
 * @param random Source of randomness. Test code seeds for determinism.
 */
internal class LiteRTSampler(
    private val temperature: Float,
    private val topP: Float,
    private val topK: Int,
    private val random: Random = Random.Default,
) {
    /**
     * Picks the next token id (an index into [logits]).
     *
     * If the filtered distribution is numerically unusable — the kept
     * probability mass is NaN, infinite or zero, which happens when
     * `logits / temperature` overflows or the logits contain NaN — the
     * method falls back to the argmax of the raw logits instead of
     * returning an arbitrary token.
     *
     * @param logits raw scores, one per vocabulary entry; must be non-empty.
     * @return index of the selected token.
     * @throws IllegalArgumentException if [logits] is empty.
     */
    fun sample(logits: FloatArray): Int {
        require(logits.isNotEmpty()) { "logits must not be empty" }

        // Fast path: temperature <= 0 -> pure argmax. Avoids softmax allocation.
        if (temperature <= 0f) return argmax(logits)

        // Apply temperature: divide each logit by T, then softmax over the
        // result. Larger T flattens the distribution, smaller T sharpens.
        val scaled = FloatArray(logits.size) { logits[it] / temperature }
        val probs = softmax(scaled)

        // Token indices ordered from most to least probable. Allocates a
        // list proportional to the vocabulary size.
        val sortedIdxByProbDesc = probs.indices.sortedByDescending { probs[it] }

        // top-k truncation: keep only the first K entries.
        val afterTopK = if (topK > 0 && topK < sortedIdxByProbDesc.size) {
            sortedIdxByProbDesc.subList(0, topK)
        } else {
            sortedIdxByProbDesc
        }

        // top-p (nucleus) truncation: keep prefix of indices whose cumulative
        // probability covers `topP`. Always keeps at least one token because
        // the index is added before the threshold is checked.
        val afterTopP = if (topP > 0f && topP < 1f) {
            val kept = mutableListOf<Int>()
            var cum = 0f
            for (idx in afterTopK) {
                kept.add(idx)
                cum += probs[idx]
                if (cum >= topP) break
            }
            kept
        } else {
            afterTopK
        }

        // Renormalize the kept distribution and sample multinomially.
        val keptProbs = afterTopP.map { probs[it] }
        val total = keptProbs.sum()

        // Degenerate distribution: with a NaN/Inf/zero mass every comparison
        // below would be false and the loop would fall through to the
        // least-likely kept token. Greedy selection is the sane fallback.
        if (!total.isFinite() || total <= 0f) return argmax(logits)

        val r = random.nextFloat() * total
        var acc = 0f
        for ((i, p) in keptProbs.withIndex()) {
            acc += p
            if (r <= acc) return afterTopP[i]
        }
        // Numerical fall-through (float rounding left acc slightly below r):
        // return last kept index.
        return afterTopP.last()
    }

    /**
     * Index of the largest value in [arr]; the first one wins on ties.
     * NaN entries are never preferred over a real number. Callers guarantee
     * [arr] is non-empty.
     */
    private fun argmax(arr: FloatArray): Int {
        var best = 0
        var bestVal = arr[0]
        for (i in 1 until arr.size) {
            // `bestVal.isNaN()` lets a real number displace a leading NaN,
            // which a plain `>` comparison could never do.
            if (arr[i] > bestVal || (bestVal.isNaN() && !arr[i].isNaN())) {
                bestVal = arr[i]
                best = i
            }
        }
        return best
    }

    /**
     * Numerically-stable softmax: subtract max, exp, normalize. Avoids
     * overflow on large finite logits. Returns a new FloatArray of the same
     * size. Callers guarantee [arr] is non-empty.
     */
    private fun softmax(arr: FloatArray): FloatArray {
        var max = arr[0]
        for (i in 1 until arr.size) if (arr[i] > max) max = arr[i]
        var sum = 0f
        val out = FloatArray(arr.size) { i ->
            val e = exp((arr[i] - max).toDouble()).toFloat()
            sum += e
            e
        }
        for (i in out.indices) out[i] /= sum
        return out
    }
}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
package co.deepvoiceai.bridge.litert.core
|
|
2
|
+
|
|
3
|
+
/**
 * Errors surfaced by the LiteRT backend. Mirrors `CoreMLBackendError` (iOS)
 * and `MediaPipeBackendError` (mediapipe-core) — when the umbrella SDK
 * (`@dvai-bridge/android`) maps a backend exception to a public
 * `DVAIBridgeError`, it pattern-matches on the sealed type here.
 *
 * Every subtype accepts an optional `cause` so the original exception (and
 * its stack trace) can travel with the wrapped error. Omitting it yields
 * exactly the same exception as before the parameter existed.
 *
 * @param message fully formatted, human-readable description.
 * @param cause underlying failure, or `null` when there is none.
 */
sealed class LiteRTBackendError(message: String, cause: Throwable? = null) : Exception(message) {

    init {
        // Attach the cause only when one was given: passing `null` to the
        // two-argument Exception constructor would permanently mark the
        // cause as initialized and make a later initCause() call throw.
        if (cause != null) initCause(cause)
    }

    /** The model could not be loaded; [reason] is appended to the message. */
    class ModelLoadFailed(reason: String, cause: Throwable? = null) :
        LiteRTBackendError("LiteRT model load failed: $reason", cause)

    /** The tokenizer could not be loaded; [reason] is appended to the message. */
    class TokenizerLoadFailed(reason: String, cause: Throwable? = null) :
        LiteRTBackendError("LiteRT tokenizer load failed: $reason", cause)

    /** Text generation failed; [reason] is appended to the message. */
    class GenerationFailed(reason: String, cause: Throwable? = null) :
        LiteRTBackendError("LiteRT generation failed: $reason", cause)
}
|
|
@@ -0,0 +1,378 @@
|
|
|
1
|
+
package co.deepvoiceai.bridge.litert.core
|
|
2
|
+
|
|
3
|
+
import co.deepvoiceai.bridge.litert.core.Internal.LiteRTGenerator
|
|
4
|
+
import co.deepvoiceai.bridge.shared.core.DvaiHandlers
|
|
5
|
+
import co.deepvoiceai.bridge.shared.core.HandlerContext
|
|
6
|
+
import co.deepvoiceai.bridge.shared.core.HandlerResponse
|
|
7
|
+
import kotlinx.coroutines.flow.flow
|
|
8
|
+
import kotlinx.coroutines.sync.Mutex
|
|
9
|
+
import kotlinx.coroutines.sync.withLock
|
|
10
|
+
import kotlinx.serialization.json.JsonArray
|
|
11
|
+
import kotlinx.serialization.json.JsonNull
|
|
12
|
+
import kotlinx.serialization.json.JsonObject
|
|
13
|
+
import kotlinx.serialization.json.JsonPrimitive
|
|
14
|
+
import kotlinx.serialization.json.addJsonObject
|
|
15
|
+
import kotlinx.serialization.json.booleanOrNull
|
|
16
|
+
import kotlinx.serialization.json.buildJsonObject
|
|
17
|
+
import kotlinx.serialization.json.contentOrNull
|
|
18
|
+
import kotlinx.serialization.json.intOrNull
|
|
19
|
+
import kotlinx.serialization.json.put
|
|
20
|
+
import kotlinx.serialization.json.putJsonArray
|
|
21
|
+
import kotlinx.serialization.json.putJsonObject
|
|
22
|
+
|
|
23
|
+
/**
 * OpenAI-compatible handler set for the LiteRT backend.
 *
 * Wires [LiteRTGenerator] into the four standard routes:
 *   POST /v1/chat/completions → tokenize messages, generate, wrap.
 *   POST /v1/completions      → adapt chat output to the legacy text shape.
 *   POST /v1/embeddings       → 501 Not Implemented (LiteRT raw graphs
 *                               don't expose an embedding head).
 *   GET  /v1/models           → standard list-with-one-item shape.
 *
 * Streaming for chat: the completion is generated in full first and then
 * emitted as 4 SSE frames (role / content / finish / [DONE]), the same shape
 * as `LlamaHandlers`. Real per-token streaming is out of scope here.
 *
 * Chat-template rendering: prompts are rendered by [chatTemplate], which
 * defaults to the hard-coded Llama-3 template
 * (`<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>`).
 * Checkpoints that need a different template require the consumer to
 * pre-render the prompt themselves and send a single user message — see the
 * consumer guide.
 *
 * All generator-touching paths are serialized via [generatorMutex] because
 * the underlying LiteRT model keeps KV-cache state across calls; concurrent
 * requests would interleave tokens from different conversations.
 */
class LiteRTHandlers internal constructor(
    private val generator: LiteRTGenerator,
    private val modelId: String,
    /** Extra opt to override the default Llama-3 chat-template renderer. */
    private val chatTemplate: ChatTemplateRenderer = ChatTemplateRenderer.LLAMA3,
    private val maxNewTokensDefault: Int = 256,
) : DvaiHandlers {

    /** Serializes every call into [generator]. */
    private val generatorMutex = Mutex()

    /**
     * Handles a chat-completion request.
     *
     * Validates `messages` (each entry must be an object with a string
     * `role` and a string `content`), renders them with [chatTemplate],
     * generates under [generatorMutex], and returns either a
     * `chat.completion` JSON body or, when `stream` is true, the 4-frame SSE
     * sequence.
     *
     * @return 400 on malformed input, 500 when generation fails, otherwise
     *   a JSON or SSE success response.
     */
    override suspend fun handleChatCompletion(body: JsonObject, ctx: HandlerContext): HandlerResponse {
        val messagesJson = body["messages"] as? JsonArray
            ?: return HandlerResponse.Error(400, "Missing 'messages' field")
        val messages = mutableListOf<Pair<String, String>>()
        for (m in messagesJson) {
            val obj = m as? JsonObject
                ?: return HandlerResponse.Error(400, "messages entry is not an object")
            val role = (obj["role"] as? JsonPrimitive)?.contentOrNull
                ?: return HandlerResponse.Error(400, "messages entry missing 'role'")
            // Only string `content` is accepted in the LiteRT backend —
            // multimodal content parts (image_url / input_audio) are not
            // routable through bare LiteRT graphs (no mmproj equivalent).
            val content = (obj["content"] as? JsonPrimitive)?.contentOrNull
                ?: return HandlerResponse.Error(
                    400,
                    "LiteRT backend only accepts string `content` (not multimodal arrays)",
                )
            messages.add(role to content)
        }

        val prompt = chatTemplate.render(messages)

        val stream = (body["stream"] as? JsonPrimitive)?.booleanOrNull ?: false
        // NOTE: `max_tokens` is parsed but not yet forwarded — the generator's
        // own cap applies regardless of what the request asks for.
        @Suppress("UNUSED_VARIABLE") // kept for future per-call sampler overrides
        val maxTokens = (body["max_tokens"] as? JsonPrimitive)?.intOrNull ?: maxNewTokensDefault

        val completion: String = try {
            generatorMutex.withLock { generator.generate(prompt) }
        } catch (e: java.util.concurrent.CancellationException) {
            // Coroutine cancellation must propagate; converting it into a 500
            // response would break structured concurrency for the caller.
            throw e
        } catch (e: LiteRTBackendError.GenerationFailed) {
            return HandlerResponse.Error(500, e.message ?: "generation failed")
        } catch (e: Throwable) {
            return HandlerResponse.Error(500, "unexpected error: ${e.message ?: e::class.java.simpleName}")
        }

        val id = "chatcmpl-" + java.util.UUID.randomUUID().toString().take(24).lowercase()
        val created = System.currentTimeMillis() / 1000L

        if (stream) {
            val frames = buildChatStreamFrames(id = id, created = created, completion = completion)
            return HandlerResponse.Sse(
                flow {
                    for (frame in frames) emit(frame)
                },
            )
        }

        val response = buildJsonObject {
            put("id", id)
            put("object", "chat.completion")
            put("created", created)
            put("model", modelId)
            putJsonArray("choices") {
                addJsonObject {
                    put("index", 0)
                    putJsonObject("message") {
                        put("role", "assistant")
                        put("content", completion)
                    }
                    put("finish_reason", "stop")
                }
            }
            // Token accounting is not implemented; counts are reported as 0.
            putJsonObject("usage") {
                put("prompt_tokens", 0)
                put("completion_tokens", 0)
                put("total_tokens", 0)
            }
        }
        return HandlerResponse.Json(200, response)
    }

    /**
     * Handles a legacy text-completion request by wrapping `prompt` in a
     * single `user` chat message, delegating to [handleChatCompletion], and
     * converting the result back to the `text_completion` shape.
     *
     * `prompt` may be absent/null (treated as empty), a string, or an array
     * whose string elements are joined with newlines; anything else is a 400.
     */
    override suspend fun handleCompletion(body: JsonObject, ctx: HandlerContext): HandlerResponse {
        val promptField = body["prompt"]
        val prompt: String = when {
            promptField == null || promptField is JsonNull -> ""
            promptField is JsonPrimitive && promptField.contentOrNull != null -> promptField.content
            promptField is JsonArray -> promptField.joinToString("\n") {
                (it as? JsonPrimitive)?.contentOrNull ?: ""
            }
            else -> return HandlerResponse.Error(400, "'prompt' must be a string or array of strings")
        }

        // Copy every request field except `prompt`, then add the synthetic
        // `messages` array (overriding any `messages` the caller sent).
        val chatBody = buildJsonObject {
            for ((k, v) in body) {
                if (k == "prompt") continue
                put(k, v)
            }
            putJsonArray("messages") {
                addJsonObject {
                    put("role", "user")
                    put("content", prompt)
                }
            }
        }

        val chatResp = handleChatCompletion(chatBody, ctx)
        return when (chatResp) {
            is HandlerResponse.Json -> {
                // Kotlin can't smart-cast `chatResp.body` to JsonObject because
                // HandlerResponse lives in shared-core (different module) — its
                // public `val body` could in principle be a custom getter. Bind
                // to a local val and cast once. (Same pattern as LlamaHandlers.)
                val respBody = chatResp.body
                if (chatResp.status != 200 || respBody !is JsonObject) {
                    chatResp
                } else {
                    HandlerResponse.Json(200, chatToLegacyCompletion(respBody))
                }
            }
            is HandlerResponse.Sse -> {
                val model = (body["model"] as? JsonPrimitive)?.contentOrNull ?: modelId
                HandlerResponse.Sse(
                    flow {
                        chatResp.flow.collect { chunk ->
                            emit(adaptChunkToLegacy(chunk, model))
                        }
                    },
                )
            }
            is HandlerResponse.Error -> chatResp
        }
    }

    /** Always answers 501: this backend does not provide embeddings. */
    override suspend fun handleEmbeddings(body: JsonObject, ctx: HandlerContext): HandlerResponse {
        // LiteRT raw .tflite graphs don't expose a native embedding head —
        // the consumer-facing model is purely a logits-producing language
        // model. Surfacing this as 501 Not Implemented matches the
        // OpenAI-compatible "this server doesn't support that endpoint"
        // shape; consumers wanting embeddings should use the llama
        // backend with `embeddingMode: true`.
        return HandlerResponse.Error(501, "embeddings not supported by LiteRT backend")
    }

    /** Returns a one-element model list whose id is taken from `ctx.modelId`. */
    override suspend fun handleModels(ctx: HandlerContext): HandlerResponse =
        HandlerResponse.Json(
            200,
            buildJsonObject {
                put("object", "list")
                putJsonArray("data") {
                    addJsonObject {
                        put("id", ctx.modelId)
                        put("object", "model")
                        put("owned_by", "dvai-bridge")
                    }
                }
            },
        )

    // ----- helpers -----

    /**
     * 4 SSE frames: role / content / finish / [DONE]. Same shape as
     * `LlamaHandlers` and `FoundationHandlers` — see the comparison in
     * `docs/development/handler-parity.md`.
     *
     * @param id completion id shared by all frames.
     * @param created Unix timestamp (seconds) shared by all frames.
     * @param completion full generated text, sent in a single content frame.
     * @return the frames, each already formatted as `data: …\n\n`.
     */
    private fun buildChatStreamFrames(id: String, created: Long, completion: String): List<String> {
        val out = mutableListOf<String>()
        val role = buildJsonObject {
            put("id", id)
            put("object", "chat.completion.chunk")
            put("created", created)
            put("model", modelId)
            putJsonArray("choices") {
                addJsonObject {
                    put("index", 0)
                    putJsonObject("delta") { put("role", "assistant") }
                }
            }
        }
        out += "data: $role\n\n"

        val content = buildJsonObject {
            put("id", id)
            put("object", "chat.completion.chunk")
            put("created", created)
            put("model", modelId)
            putJsonArray("choices") {
                addJsonObject {
                    put("index", 0)
                    putJsonObject("delta") { put("content", completion) }
                }
            }
        }
        out += "data: $content\n\n"

        val finish = buildJsonObject {
            put("id", id)
            put("object", "chat.completion.chunk")
            put("created", created)
            put("model", modelId)
            putJsonArray("choices") {
                addJsonObject {
                    put("index", 0)
                    putJsonObject("delta") { /* empty */ }
                    put("finish_reason", "stop")
                }
            }
        }
        out += "data: $finish\n\n"

        out += "data: [DONE]\n\n"
        return out
    }

    /**
     * Converts a `chat.completion` body into the legacy `text_completion`
     * shape. Mirrors `chatToLegacyCompletion()` from `LlamaHandlers.kt`.
     *
     * Fields missing from [chat] get defaults: a time-based `cmpl-` id, the
     * current time for `created`, [modelId] for `model`, and a zeroed
     * `usage` object.
     */
    private fun chatToLegacyCompletion(chat: JsonObject): JsonObject = buildJsonObject {
        val chatId = (chat["id"] as? JsonPrimitive)?.contentOrNull ?: ""
        val cmplId = if (chatId.isEmpty()) "cmpl-${System.currentTimeMillis() / 1000L}"
        else chatId.replace("chatcmpl-", "cmpl-")
        put("id", cmplId)
        put("object", "text_completion")
        // Resolve the value first, then put exactly once. `put` returns the
        // PREVIOUS value (null on first insert), so chaining `?.let { put }`
        // with `?:` would always run the fallback and clobber the timestamp.
        put("created", chat["created"] ?: JsonPrimitive(System.currentTimeMillis() / 1000L))
        put("model", (chat["model"] as? JsonPrimitive)?.contentOrNull ?: modelId)
        putJsonArray("choices") {
            val chatChoices = chat["choices"] as? JsonArray ?: JsonArray(emptyList())
            for (c in chatChoices) {
                val co = c as? JsonObject ?: continue
                addJsonObject {
                    val msg = co["message"] as? JsonObject
                    put("text", (msg?.get("content") as? JsonPrimitive)?.contentOrNull ?: "")
                    put("index", (co["index"] as? JsonPrimitive)?.intOrNull ?: 0)
                    put(
                        "finish_reason",
                        (co["finish_reason"] as? JsonPrimitive)?.contentOrNull ?: "stop",
                    )
                    put("logprobs", JsonNull)
                }
            }
        }
        val usage = chat["usage"] as? JsonObject
        if (usage != null) {
            put("usage", usage)
        } else {
            putJsonObject("usage") {
                put("prompt_tokens", 0)
                put("completion_tokens", 0)
                put("total_tokens", 0)
            }
        }
    }

    /**
     * Adapt a single SSE frame from chat.completion.chunk -> text_completion.chunk.
     *
     * Frames that do not start with `data:`, or whose payload is not a JSON
     * object, are returned unchanged; the `[DONE]` sentinel is re-emitted in
     * canonical form.
     *
     * @param chunk one SSE frame as produced by the chat handler.
     * @param model model name used when the frame carries none.
     */
    private fun adaptChunkToLegacy(chunk: String, model: String): String {
        val trimmed = chunk.trim()
        if (!trimmed.startsWith("data:")) return chunk
        val payload = trimmed.removePrefix("data:").trim()
        if (payload == "[DONE]") return "data: [DONE]\n\n"
        val parsed = try {
            kotlinx.serialization.json.Json.parseToJsonElement(payload) as? JsonObject ?: return chunk
        } catch (_: Exception) {
            return chunk
        }
        val chatId = (parsed["id"] as? JsonPrimitive)?.contentOrNull ?: ""
        val id = chatId.replace("chatcmpl-", "cmpl-")
        val legacy = buildJsonObject {
            put("id", id)
            put("object", "text_completion.chunk")
            // Single put: keeps the upstream timestamp when present (see the
            // note in chatToLegacyCompletion about `put`'s return value).
            put("created", parsed["created"] ?: JsonPrimitive(System.currentTimeMillis() / 1000L))
            put("model", (parsed["model"] as? JsonPrimitive)?.contentOrNull ?: model)
            putJsonArray("choices") {
                val chatChoices = parsed["choices"] as? JsonArray ?: JsonArray(emptyList())
                for (c in chatChoices) {
                    val co = c as? JsonObject ?: continue
                    addJsonObject {
                        val delta = co["delta"] as? JsonObject
                        put("text", (delta?.get("content") as? JsonPrimitive)?.contentOrNull ?: "")
                        put("index", (co["index"] as? JsonPrimitive)?.intOrNull ?: 0)
                        val fr = co["finish_reason"]
                        if (fr is JsonPrimitive && fr.contentOrNull != null) {
                            put("finish_reason", fr.content)
                        } else {
                            put("finish_reason", JsonNull)
                        }
                        put("logprobs", JsonNull)
                    }
                }
            }
        }
        return "data: $legacy\n\n"
    }
}
|
|
341
|
+
|
|
342
|
+
/**
 * Turns a list of OpenAI-style `(role, content)` pairs into a single prompt
 * string. Each entry is a fixed template written directly in Kotlin, and
 * the consumer chooses which one to use.
 */
public enum class ChatTemplateRenderer {
    /**
     * Llama-3 layout: `<|begin_of_text|>`, then one
     * `<|start_header_id|>role<|end_header_id|>\n\ncontent<|eot_id|>` group
     * per message, then an open assistant header for the model to continue.
     */
    LLAMA3 {
        override fun render(messages: List<Pair<String, String>>): String = buildString {
            append("<|begin_of_text|>")
            messages.forEach { (speaker, text) ->
                append("<|start_header_id|>")
                append(speaker)
                append("<|end_header_id|>\n\n")
                append(text)
                append("<|eot_id|>")
            }
            append("<|start_header_id|>assistant<|end_header_id|>\n\n")
        }
    },

    /**
     * Plain concatenation: one `role: content` line per message followed by
     * a trailing `assistant: `. Intended for checkpoints without a chat
     * template, or for prompts the consumer has already rendered and sends
     * as a single user message.
     */
    PLAIN {
        override fun render(messages: List<Pair<String, String>>): String {
            val transcript = messages.joinToString(separator = "") { (speaker, text) ->
                "$speaker: $text\n"
            }
            return transcript + "assistant: "
        }
    };

    /**
     * Renders [messages] (pairs of role to content, in conversation order)
     * into the prompt text handed to the tokenizer.
     */
    abstract fun render(messages: List<Pair<String, String>>): String
}
|