@dvai-bridge/android-litert-core 4.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,71 @@
1
+ package co.deepvoiceai.bridge.litert.core.Internal
2
+
3
import co.deepvoiceai.bridge.litert.core.LiteRTBackendError
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ensureActive
import kotlinx.coroutines.withContext
6
+
7
/**
 * Orchestrates [LiteRTEngineApi] + [HFTokenizerJson] + [LiteRTSampler]
 * into a single autoregressive generate() call.
 *
 * Mirrors `CoreMLGenerator.swift` (iOS) — same prefill+decode unification:
 * each [LiteRTEngineApi.runStep] returns logits for the *next* token
 * (position kv+1), so after feeding all prompt tokens the last logits
 * give the first generated token. We do NOT re-feed `prompt.last()` —
 * that would double-count it in the KV cache.
 *
 * Concurrency: the whole generation runs inside
 * `withContext(Dispatchers.Default)` so blocking `engine.runStep` calls do
 * not pin the caller's thread. Caller must serialize calls behind a mutex
 * (see [LiteRTHandlers.generatorMutex]).
 *
 * Cancellation: the coroutine's job is checked between engine steps, so a
 * cancelled request stops after the forward pass that is currently running
 * instead of continuing until EOS / [maxNewTokens].
 */
internal class LiteRTGenerator(
    private val engine: LiteRTEngineApi,
    private val tokenizer: HFTokenizerJson,
    private val sampler: LiteRTSampler,
    private val maxNewTokens: Int,
) {

    /**
     * Tokenize [prompt], run the prefill+decode loop, and return the
     * decoded completion (without re-emitting the prompt). Stops on EOS or
     * once [maxNewTokens] generated tokens have been produced.
     *
     * @param prompt fully rendered prompt text handed to the tokenizer.
     * @return the generated tokens decoded with `skipSpecialTokens = true`.
     * @throws LiteRTBackendError.GenerationFailed when the prompt tokenizes
     *   to an empty list. Exceptions raised by [engine], [tokenizer] or
     *   [sampler] are not caught here and propagate to the caller.
     */
    suspend fun generate(prompt: String): String = withContext(Dispatchers.Default) {
        val promptTokens = tokenizer.encode(prompt)
        if (promptTokens.isEmpty()) {
            throw LiteRTBackendError.GenerationFailed("prompt tokenized to empty list")
        }

        // Prefill: feed each prompt token at consecutive kv positions. The
        // logits returned for the last prompt token predict the FIRST
        // output token — no separate "decode kickoff" step needed.
        var kvPos = 0
        var lastLogits = engine.runStep(promptTokens[0], kvCachePosition = kvPos)
        kvPos += 1
        for (i in 1 until promptTokens.size) {
            // Bail out between forward passes if the request was cancelled.
            ensureActive()
            lastLogits = engine.runStep(promptTokens[i], kvCachePosition = kvPos)
            kvPos += 1
        }

        // Decode: `generated.size` doubles as the produced-token counter.
        val generated = mutableListOf<Int>()
        var nextToken = sampler.sample(lastLogits)
        while (generated.size < maxNewTokens) {
            if (nextToken == engine.eosTokenId) break
            generated.add(nextToken)
            // Stop early once the cap is reached — saves one forward pass
            // whose logits would just be discarded.
            if (generated.size >= maxNewTokens) break
            ensureActive()
            lastLogits = engine.runStep(nextToken, kvCachePosition = kvPos)
            kvPos += 1
            nextToken = sampler.sample(lastLogits)
        }

        tokenizer.decode(generated, skipSpecialTokens = true)
    }
}
@@ -0,0 +1,105 @@
1
+ package co.deepvoiceai.bridge.litert.core.Internal
2
+
3
+ import kotlin.math.exp
4
+ import kotlin.random.Random
5
+
6
+ /**
7
+ * Greedy + temperature/top-p/top-k sampler for LiteRT logits. Pure Kotlin,
8
+ * no LiteRT or HuggingFace tokenizer dependency — purely numerical.
9
+ *
10
+ * Mirrors `CoreMLSampler.swift` (iOS) 1:1; algorithm comments are kept in
11
+ * lockstep with the iOS implementation so future fixes apply uniformly.
12
+ *
13
+ * @param temperature 0.0 = pure argmax (deterministic). >0.0 enables
14
+ * temperature scaling.
15
+ * @param topP Nucleus-sampling cutoff. 1.0 = disabled (consider all
16
+ * tokens). 0.0 < topP < 1.0 = keep tokens whose
17
+ * cumulative probability covers `topP`.
18
+ * @param topK 0 = disabled. >0 = keep only the K highest-probability
19
+ * tokens before sampling.
20
+ * @param random Source of randomness. Test code seeds for determinism.
21
+ */
22
+ internal class LiteRTSampler(
23
+ private val temperature: Float,
24
+ private val topP: Float,
25
+ private val topK: Int,
26
+ private val random: Random = Random.Default,
27
+ ) {
28
+ fun sample(logits: FloatArray): Int {
29
+ // Fast path: temperature == 0 -> pure argmax. Avoids softmax allocation.
30
+ if (temperature <= 0f) return argmax(logits)
31
+
32
+ // Apply temperature: divide each logit by T, then softmax over the
33
+ // result. Larger T flattens the distribution, smaller T sharpens.
34
+ val scaled = FloatArray(logits.size) { logits[it] / temperature }
35
+ val probs = softmax(scaled)
36
+
37
+ // Build (idx, prob) pairs, sorted descending by prob. Allocates an
38
+ // index array proportional to vocab size — typically 32k-256k.
39
+ val sortedIdxByProbDesc = (probs.indices)
40
+ .sortedByDescending { probs[it] }
41
+
42
+ // top-k truncation: keep only the first K entries.
43
+ val afterTopK = if (topK > 0 && topK < sortedIdxByProbDesc.size) {
44
+ sortedIdxByProbDesc.subList(0, topK)
45
+ } else {
46
+ sortedIdxByProbDesc
47
+ }
48
+
49
+ // top-p (nucleus) truncation: keep prefix of indices whose cumulative
50
+ // probability covers `topP`. Always keep at least one token.
51
+ val afterTopP = if (topP > 0f && topP < 1f) {
52
+ val kept = mutableListOf<Int>()
53
+ var cum = 0f
54
+ for (idx in afterTopK) {
55
+ kept.add(idx)
56
+ cum += probs[idx]
57
+ if (cum >= topP) break
58
+ }
59
+ kept
60
+ } else {
61
+ afterTopK
62
+ }
63
+
64
+ // Renormalize the kept distribution and sample multinomially.
65
+ val keptProbs = afterTopP.map { probs[it] }
66
+ val total = keptProbs.sum()
67
+ val r = random.nextFloat() * total
68
+ var acc = 0f
69
+ for ((i, p) in keptProbs.withIndex()) {
70
+ acc += p
71
+ if (r <= acc) return afterTopP[i]
72
+ }
73
+ // Numerical fall-through: return last kept index.
74
+ return afterTopP.last()
75
+ }
76
+
77
+ private fun argmax(arr: FloatArray): Int {
78
+ var best = 0
79
+ var bestVal = arr[0]
80
+ for (i in 1 until arr.size) {
81
+ if (arr[i] > bestVal) {
82
+ bestVal = arr[i]
83
+ best = i
84
+ }
85
+ }
86
+ return best
87
+ }
88
+
89
+ /**
90
+ * Numerically-stable softmax: subtract max, exp, normalize. Avoids
91
+ * overflow on large logits. Returns a new FloatArray of the same size.
92
+ */
93
+ private fun softmax(arr: FloatArray): FloatArray {
94
+ var max = arr[0]
95
+ for (i in 1 until arr.size) if (arr[i] > max) max = arr[i]
96
+ var sum = 0f
97
+ val out = FloatArray(arr.size) { i ->
98
+ val e = exp((arr[i] - max).toDouble()).toFloat()
99
+ sum += e
100
+ e
101
+ }
102
+ for (i in out.indices) out[i] /= sum
103
+ return out
104
+ }
105
+ }
@@ -0,0 +1,13 @@
1
+ package co.deepvoiceai.bridge.litert.core
2
+
3
+ /**
4
+ * Errors surfaced by the LiteRT backend. Mirrors `CoreMLBackendError` (iOS)
5
+ * and `MediaPipeBackendError` (mediapipe-core) — when the umbrella SDK
6
+ * (`@dvai-bridge/android`) maps a backend exception to a public
7
+ * `DVAIBridgeError`, it pattern-matches on the sealed type here.
8
+ */
9
+ sealed class LiteRTBackendError(message: String) : Exception(message) {
10
+ class ModelLoadFailed(reason: String) : LiteRTBackendError("LiteRT model load failed: $reason")
11
+ class TokenizerLoadFailed(reason: String) : LiteRTBackendError("LiteRT tokenizer load failed: $reason")
12
+ class GenerationFailed(reason: String) : LiteRTBackendError("LiteRT generation failed: $reason")
13
+ }
@@ -0,0 +1,378 @@
1
+ package co.deepvoiceai.bridge.litert.core
2
+
3
+ import co.deepvoiceai.bridge.litert.core.Internal.LiteRTGenerator
4
+ import co.deepvoiceai.bridge.shared.core.DvaiHandlers
5
+ import co.deepvoiceai.bridge.shared.core.HandlerContext
6
+ import co.deepvoiceai.bridge.shared.core.HandlerResponse
7
+ import kotlinx.coroutines.flow.flow
8
+ import kotlinx.coroutines.sync.Mutex
9
+ import kotlinx.coroutines.sync.withLock
10
+ import kotlinx.serialization.json.JsonArray
11
+ import kotlinx.serialization.json.JsonNull
12
+ import kotlinx.serialization.json.JsonObject
13
+ import kotlinx.serialization.json.JsonPrimitive
14
+ import kotlinx.serialization.json.addJsonObject
15
+ import kotlinx.serialization.json.booleanOrNull
16
+ import kotlinx.serialization.json.buildJsonObject
17
+ import kotlinx.serialization.json.contentOrNull
18
+ import kotlinx.serialization.json.intOrNull
19
+ import kotlinx.serialization.json.put
20
+ import kotlinx.serialization.json.putJsonArray
21
+ import kotlinx.serialization.json.putJsonObject
22
+
23
+ /**
24
+ * OpenAI-compatible handler set for the LiteRT backend.
25
+ *
26
+ * Wires [LiteRTGenerator] into the four standard routes:
27
+ * POST /v1/chat/completions → tokenize messages, generate, wrap.
28
+ * POST /v1/completions → adapt chat output to the legacy text shape.
29
+ * POST /v1/embeddings → 501 Not Implemented (LiteRT raw graphs
30
+ * don't expose an embedding head).
31
+ * GET /v1/models → standard list-with-one-item shape.
32
+ *
33
+ * Streaming for chat: same 4-frame shape (role / content / finish / [DONE])
34
+ * as `LlamaHandlers` — Ktor's responseChannel buffers the body anyway, so
35
+ * 4-chunk vs 1-chunk is identical to the client. Real per-token streaming
36
+ * lands when the dispatch layer flushes per chunk (out of scope here).
37
+ *
38
+ * Chat-template rendering: the LiteRT backend ships a hard-coded Llama-3
39
+ * template (`<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>`).
40
+ * Most Llama-3.x family checkpoints accept it. Non-Llama checkpoints
41
+ * require the consumer to pre-render the prompt themselves and send a
42
+ * single user message — see the consumer guide.
43
+ *
44
+ * All generator-touching paths are serialized via [generatorMutex] because
45
+ * the underlying LiteRT [com.google.ai.edge.litert.CompiledModel] keeps
46
+ * an internal KV-cache state across calls; concurrent requests would
47
+ * interleave tokens from different conversations.
48
+ */
49
+ class LiteRTHandlers internal constructor(
50
+ private val generator: LiteRTGenerator,
51
+ private val modelId: String,
52
+ /** Extra opt to override the default Llama-3 chat-template renderer. */
53
+ private val chatTemplate: ChatTemplateRenderer = ChatTemplateRenderer.LLAMA3,
54
+ private val maxNewTokensDefault: Int = 256,
55
+ ) : DvaiHandlers {
56
+
57
+ private val generatorMutex = Mutex()
58
+
59
+ override suspend fun handleChatCompletion(body: JsonObject, ctx: HandlerContext): HandlerResponse {
60
+ val messagesJson = body["messages"] as? JsonArray
61
+ ?: return HandlerResponse.Error(400, "Missing 'messages' field")
62
+ val messages = mutableListOf<Pair<String, String>>()
63
+ for (m in messagesJson) {
64
+ val obj = m as? JsonObject
65
+ ?: return HandlerResponse.Error(400, "messages entry is not an object")
66
+ val role = (obj["role"] as? JsonPrimitive)?.contentOrNull
67
+ ?: return HandlerResponse.Error(400, "messages entry missing 'role'")
68
+ // Only string `content` is accepted in the LiteRT backend —
69
+ // multimodal content parts (image_url / input_audio) are not
70
+ // routable through bare LiteRT graphs (no mmproj equivalent).
71
+ val content = (obj["content"] as? JsonPrimitive)?.contentOrNull
72
+ ?: return HandlerResponse.Error(
73
+ 400,
74
+ "LiteRT backend only accepts string `content` (not multimodal arrays)",
75
+ )
76
+ messages.add(role to content)
77
+ }
78
+
79
+ val prompt = chatTemplate.render(messages)
80
+
81
+ val stream = (body["stream"] as? JsonPrimitive)?.booleanOrNull ?: false
82
+ @Suppress("UNUSED_VARIABLE") // kept for future per-call sampler overrides
83
+ val maxTokens = (body["max_tokens"] as? JsonPrimitive)?.intOrNull ?: maxNewTokensDefault
84
+
85
+ val completion: String = try {
86
+ generatorMutex.withLock { generator.generate(prompt) }
87
+ } catch (e: LiteRTBackendError.GenerationFailed) {
88
+ return HandlerResponse.Error(500, e.message ?: "generation failed")
89
+ } catch (e: Throwable) {
90
+ return HandlerResponse.Error(500, "unexpected error: ${e.message ?: e::class.java.simpleName}")
91
+ }
92
+
93
+ val id = "chatcmpl-" + java.util.UUID.randomUUID().toString().take(24).lowercase()
94
+ val created = System.currentTimeMillis() / 1000L
95
+
96
+ if (stream) {
97
+ val frames = buildChatStreamFrames(id = id, created = created, completion = completion)
98
+ return HandlerResponse.Sse(
99
+ flow {
100
+ for (frame in frames) emit(frame)
101
+ },
102
+ )
103
+ }
104
+
105
+ val response = buildJsonObject {
106
+ put("id", id)
107
+ put("object", "chat.completion")
108
+ put("created", created)
109
+ put("model", modelId)
110
+ putJsonArray("choices") {
111
+ addJsonObject {
112
+ put("index", 0)
113
+ putJsonObject("message") {
114
+ put("role", "assistant")
115
+ put("content", completion)
116
+ }
117
+ put("finish_reason", "stop")
118
+ }
119
+ }
120
+ putJsonObject("usage") {
121
+ put("prompt_tokens", 0)
122
+ put("completion_tokens", 0)
123
+ put("total_tokens", 0)
124
+ }
125
+ }
126
+ return HandlerResponse.Json(200, response)
127
+ }
128
+
129
+ override suspend fun handleCompletion(body: JsonObject, ctx: HandlerContext): HandlerResponse {
130
+ val promptField = body["prompt"]
131
+ val prompt: String = when {
132
+ promptField == null || promptField is JsonNull -> ""
133
+ promptField is JsonPrimitive && promptField.contentOrNull != null -> promptField.content
134
+ promptField is JsonArray -> promptField.joinToString("\n") {
135
+ (it as? JsonPrimitive)?.contentOrNull ?: ""
136
+ }
137
+ else -> return HandlerResponse.Error(400, "'prompt' must be a string or array of strings")
138
+ }
139
+
140
+ val chatBody = buildJsonObject {
141
+ for ((k, v) in body) {
142
+ if (k == "prompt") continue
143
+ put(k, v)
144
+ }
145
+ putJsonArray("messages") {
146
+ addJsonObject {
147
+ put("role", "user")
148
+ put("content", prompt)
149
+ }
150
+ }
151
+ }
152
+
153
+ val chatResp = handleChatCompletion(chatBody, ctx)
154
+ return when (chatResp) {
155
+ is HandlerResponse.Json -> {
156
+ // Kotlin can't smart-cast `chatResp.body` to JsonObject because
157
+ // HandlerResponse lives in shared-core (different module) — its
158
+ // public `val body` could in principle be a custom getter. Bind
159
+ // to a local val and cast once. (Same pattern as LlamaHandlers.)
160
+ val respBody = chatResp.body
161
+ if (chatResp.status != 200 || respBody !is JsonObject) {
162
+ chatResp
163
+ } else {
164
+ HandlerResponse.Json(200, chatToLegacyCompletion(respBody))
165
+ }
166
+ }
167
+ is HandlerResponse.Sse -> {
168
+ val model = (body["model"] as? JsonPrimitive)?.contentOrNull ?: modelId
169
+ HandlerResponse.Sse(
170
+ flow {
171
+ chatResp.flow.collect { chunk ->
172
+ emit(adaptChunkToLegacy(chunk, model))
173
+ }
174
+ },
175
+ )
176
+ }
177
+ is HandlerResponse.Error -> chatResp
178
+ }
179
+ }
180
+
181
+ override suspend fun handleEmbeddings(body: JsonObject, ctx: HandlerContext): HandlerResponse {
182
+ // LiteRT raw .tflite graphs don't expose a native embedding head —
183
+ // the consumer-facing model is purely a logits-producing language
184
+ // model. Surfacing this as 501 Not Implemented matches the
185
+ // OpenAI-compatible "this server doesn't support that endpoint"
186
+ // shape; consumers wanting embeddings should use the llama
187
+ // backend with `embeddingMode: true`.
188
+ return HandlerResponse.Error(501, "embeddings not supported by LiteRT backend")
189
+ }
190
+
191
+ override suspend fun handleModels(ctx: HandlerContext): HandlerResponse =
192
+ HandlerResponse.Json(
193
+ 200,
194
+ buildJsonObject {
195
+ put("object", "list")
196
+ putJsonArray("data") {
197
+ addJsonObject {
198
+ put("id", ctx.modelId)
199
+ put("object", "model")
200
+ put("owned_by", "dvai-bridge")
201
+ }
202
+ }
203
+ },
204
+ )
205
+
206
+ // ----- helpers -----
207
+
208
+ /**
209
+ * 4 SSE frames: role / content / finish / [DONE]. Same shape as
210
+ * `LlamaHandlers` and `FoundationHandlers` — see the comparison in
211
+ * `docs/development/handler-parity.md`.
212
+ */
213
+ private fun buildChatStreamFrames(id: String, created: Long, completion: String): List<String> {
214
+ val out = mutableListOf<String>()
215
+ val role = buildJsonObject {
216
+ put("id", id)
217
+ put("object", "chat.completion.chunk")
218
+ put("created", created)
219
+ put("model", modelId)
220
+ putJsonArray("choices") {
221
+ addJsonObject {
222
+ put("index", 0)
223
+ putJsonObject("delta") { put("role", "assistant") }
224
+ }
225
+ }
226
+ }
227
+ out += "data: $role\n\n"
228
+
229
+ val content = buildJsonObject {
230
+ put("id", id)
231
+ put("object", "chat.completion.chunk")
232
+ put("created", created)
233
+ put("model", modelId)
234
+ putJsonArray("choices") {
235
+ addJsonObject {
236
+ put("index", 0)
237
+ putJsonObject("delta") { put("content", completion) }
238
+ }
239
+ }
240
+ }
241
+ out += "data: $content\n\n"
242
+
243
+ val finish = buildJsonObject {
244
+ put("id", id)
245
+ put("object", "chat.completion.chunk")
246
+ put("created", created)
247
+ put("model", modelId)
248
+ putJsonArray("choices") {
249
+ addJsonObject {
250
+ put("index", 0)
251
+ putJsonObject("delta") { /* empty */ }
252
+ put("finish_reason", "stop")
253
+ }
254
+ }
255
+ }
256
+ out += "data: $finish\n\n"
257
+
258
+ out += "data: [DONE]\n\n"
259
+ return out
260
+ }
261
+
262
+ /** Mirrors `chatToLegacyCompletion()` from `LlamaHandlers.kt` 1:1. */
263
+ private fun chatToLegacyCompletion(chat: JsonObject): JsonObject = buildJsonObject {
264
+ val chatId = (chat["id"] as? JsonPrimitive)?.contentOrNull ?: ""
265
+ val cmplId = if (chatId.isEmpty()) "cmpl-${System.currentTimeMillis() / 1000L}"
266
+ else chatId.replace("chatcmpl-", "cmpl-")
267
+ put("id", cmplId)
268
+ put("object", "text_completion")
269
+ chat["created"]?.let { put("created", it) }
270
+ ?: put("created", System.currentTimeMillis() / 1000L)
271
+ put("model", (chat["model"] as? JsonPrimitive)?.contentOrNull ?: modelId)
272
+ putJsonArray("choices") {
273
+ val chatChoices = chat["choices"] as? JsonArray ?: JsonArray(emptyList())
274
+ for (c in chatChoices) {
275
+ val co = c as? JsonObject ?: continue
276
+ addJsonObject {
277
+ val msg = co["message"] as? JsonObject
278
+ put("text", (msg?.get("content") as? JsonPrimitive)?.contentOrNull ?: "")
279
+ put("index", (co["index"] as? JsonPrimitive)?.intOrNull ?: 0)
280
+ put(
281
+ "finish_reason",
282
+ (co["finish_reason"] as? JsonPrimitive)?.contentOrNull ?: "stop",
283
+ )
284
+ put("logprobs", JsonNull)
285
+ }
286
+ }
287
+ }
288
+ val usage = chat["usage"] as? JsonObject
289
+ if (usage != null) {
290
+ put("usage", usage)
291
+ } else {
292
+ putJsonObject("usage") {
293
+ put("prompt_tokens", 0)
294
+ put("completion_tokens", 0)
295
+ put("total_tokens", 0)
296
+ }
297
+ }
298
+ }
299
+
300
+ /** Adapt a single SSE frame from chat.completion.chunk -> text_completion.chunk. */
301
+ private fun adaptChunkToLegacy(chunk: String, model: String): String {
302
+ val trimmed = chunk.trim()
303
+ if (!trimmed.startsWith("data:")) return chunk
304
+ val payload = trimmed.removePrefix("data:").trim()
305
+ if (payload == "[DONE]") return "data: [DONE]\n\n"
306
+ val parsed = try {
307
+ kotlinx.serialization.json.Json.parseToJsonElement(payload) as? JsonObject ?: return chunk
308
+ } catch (_: Exception) {
309
+ return chunk
310
+ }
311
+ val chatId = (parsed["id"] as? JsonPrimitive)?.contentOrNull ?: ""
312
+ val id = chatId.replace("chatcmpl-", "cmpl-")
313
+ val legacy = buildJsonObject {
314
+ put("id", id)
315
+ put("object", "text_completion.chunk")
316
+ parsed["created"]?.let { put("created", it) }
317
+ ?: put("created", System.currentTimeMillis() / 1000L)
318
+ put("model", (parsed["model"] as? JsonPrimitive)?.contentOrNull ?: model)
319
+ putJsonArray("choices") {
320
+ val chatChoices = parsed["choices"] as? JsonArray ?: JsonArray(emptyList())
321
+ for (c in chatChoices) {
322
+ val co = c as? JsonObject ?: continue
323
+ addJsonObject {
324
+ val delta = co["delta"] as? JsonObject
325
+ put("text", (delta?.get("content") as? JsonPrimitive)?.contentOrNull ?: "")
326
+ put("index", (co["index"] as? JsonPrimitive)?.intOrNull ?: 0)
327
+ val fr = co["finish_reason"]
328
+ if (fr is JsonPrimitive && fr.contentOrNull != null) {
329
+ put("finish_reason", fr.content)
330
+ } else {
331
+ put("finish_reason", JsonNull)
332
+ }
333
+ put("logprobs", JsonNull)
334
+ }
335
+ }
336
+ }
337
+ }
338
+ return "data: $legacy\n\n"
339
+ }
340
+ }
341
+
342
+ /**
343
+ * Stringifier for OpenAI-style messages. The LiteRT backend has no Jinja
344
+ * engine, so we ship a couple of well-known templates as Kotlin code and
345
+ * let the consumer pick. Defaults to LLAMA3 — the dominant tokenizer
346
+ * family on Maven Central LiteRT checkpoints in 2026.
347
+ */
348
+ public enum class ChatTemplateRenderer {
349
+ LLAMA3 {
350
+ override fun render(messages: List<Pair<String, String>>): String {
351
+ val sb = StringBuilder()
352
+ sb.append("<|begin_of_text|>")
353
+ for ((role, content) in messages) {
354
+ sb.append("<|start_header_id|>").append(role).append("<|end_header_id|>\n\n")
355
+ sb.append(content).append("<|eot_id|>")
356
+ }
357
+ sb.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
358
+ return sb.toString()
359
+ }
360
+ },
361
+
362
+ /**
363
+ * Dumb concatenation — appends each message as `role: content\n` and
364
+ * a trailing `assistant:`. For checkpoints with no chat template at
365
+ * all (raw text-completion .tflite files), or when the consumer
366
+ * already pre-rendered the prompt and sends a single user message.
367
+ */
368
+ PLAIN {
369
+ override fun render(messages: List<Pair<String, String>>): String {
370
+ val sb = StringBuilder()
371
+ for ((role, content) in messages) sb.append(role).append(": ").append(content).append("\n")
372
+ sb.append("assistant: ")
373
+ return sb.toString()
374
+ }
375
+ };
376
+
377
+ abstract fun render(messages: List<Pair<String, String>>): String
378
+ }