liter_llm 1.0.0.pre.rc.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78) hide show
  1. checksums.yaml +7 -0
  2. data/README.md +239 -0
  3. data/ext/liter_llm_rb/extconf.rb +65 -0
  4. data/ext/liter_llm_rb/native/.cargo/config.toml +23 -0
  5. data/ext/liter_llm_rb/native/Cargo.lock +3713 -0
  6. data/ext/liter_llm_rb/native/Cargo.toml +32 -0
  7. data/ext/liter_llm_rb/native/build.rs +15 -0
  8. data/ext/liter_llm_rb/native/src/lib.rs +1079 -0
  9. data/lib/liter_llm.rb +8 -0
  10. data/sig/liter_llm.rbs +416 -0
  11. data/vendor/Cargo.toml +54 -0
  12. data/vendor/liter-llm/Cargo.toml +92 -0
  13. data/vendor/liter-llm/README.md +252 -0
  14. data/vendor/liter-llm/schemas/pricing.json +40 -0
  15. data/vendor/liter-llm/schemas/providers.json +1662 -0
  16. data/vendor/liter-llm/src/auth/azure_ad.rs +264 -0
  17. data/vendor/liter-llm/src/auth/bedrock_sts.rs +353 -0
  18. data/vendor/liter-llm/src/auth/mod.rs +68 -0
  19. data/vendor/liter-llm/src/auth/vertex_oauth.rs +353 -0
  20. data/vendor/liter-llm/src/client/config.rs +351 -0
  21. data/vendor/liter-llm/src/client/managed.rs +622 -0
  22. data/vendor/liter-llm/src/client/mod.rs +864 -0
  23. data/vendor/liter-llm/src/cost.rs +212 -0
  24. data/vendor/liter-llm/src/error.rs +190 -0
  25. data/vendor/liter-llm/src/http/eventstream.rs +860 -0
  26. data/vendor/liter-llm/src/http/mod.rs +12 -0
  27. data/vendor/liter-llm/src/http/request.rs +438 -0
  28. data/vendor/liter-llm/src/http/retry.rs +72 -0
  29. data/vendor/liter-llm/src/http/streaming.rs +289 -0
  30. data/vendor/liter-llm/src/lib.rs +37 -0
  31. data/vendor/liter-llm/src/provider/anthropic.rs +2250 -0
  32. data/vendor/liter-llm/src/provider/azure.rs +579 -0
  33. data/vendor/liter-llm/src/provider/bedrock.rs +1543 -0
  34. data/vendor/liter-llm/src/provider/cohere.rs +654 -0
  35. data/vendor/liter-llm/src/provider/custom.rs +404 -0
  36. data/vendor/liter-llm/src/provider/google_ai.rs +281 -0
  37. data/vendor/liter-llm/src/provider/mistral.rs +188 -0
  38. data/vendor/liter-llm/src/provider/mod.rs +616 -0
  39. data/vendor/liter-llm/src/provider/vertex.rs +1504 -0
  40. data/vendor/liter-llm/src/tests.rs +1425 -0
  41. data/vendor/liter-llm/src/tokenizer.rs +281 -0
  42. data/vendor/liter-llm/src/tower/budget.rs +599 -0
  43. data/vendor/liter-llm/src/tower/cache.rs +502 -0
  44. data/vendor/liter-llm/src/tower/cache_opendal.rs +270 -0
  45. data/vendor/liter-llm/src/tower/cooldown.rs +231 -0
  46. data/vendor/liter-llm/src/tower/cost.rs +404 -0
  47. data/vendor/liter-llm/src/tower/fallback.rs +121 -0
  48. data/vendor/liter-llm/src/tower/health.rs +219 -0
  49. data/vendor/liter-llm/src/tower/hooks.rs +369 -0
  50. data/vendor/liter-llm/src/tower/mod.rs +77 -0
  51. data/vendor/liter-llm/src/tower/rate_limit.rs +300 -0
  52. data/vendor/liter-llm/src/tower/router.rs +436 -0
  53. data/vendor/liter-llm/src/tower/service.rs +181 -0
  54. data/vendor/liter-llm/src/tower/tests.rs +539 -0
  55. data/vendor/liter-llm/src/tower/tests_common.rs +252 -0
  56. data/vendor/liter-llm/src/tower/tracing.rs +209 -0
  57. data/vendor/liter-llm/src/tower/types.rs +170 -0
  58. data/vendor/liter-llm/src/types/audio.rs +52 -0
  59. data/vendor/liter-llm/src/types/batch.rs +77 -0
  60. data/vendor/liter-llm/src/types/chat.rs +214 -0
  61. data/vendor/liter-llm/src/types/common.rs +244 -0
  62. data/vendor/liter-llm/src/types/embedding.rs +84 -0
  63. data/vendor/liter-llm/src/types/files.rs +58 -0
  64. data/vendor/liter-llm/src/types/image.rs +40 -0
  65. data/vendor/liter-llm/src/types/mod.rs +27 -0
  66. data/vendor/liter-llm/src/types/models.rs +21 -0
  67. data/vendor/liter-llm/src/types/moderation.rs +80 -0
  68. data/vendor/liter-llm/src/types/ocr.rs +87 -0
  69. data/vendor/liter-llm/src/types/rerank.rs +46 -0
  70. data/vendor/liter-llm/src/types/responses.rs +55 -0
  71. data/vendor/liter-llm/src/types/search.rs +45 -0
  72. data/vendor/liter-llm/tests/contract.rs +332 -0
  73. data/vendor/liter-llm-ffi/Cargo.toml +30 -0
  74. data/vendor/liter-llm-ffi/build.rs +66 -0
  75. data/vendor/liter-llm-ffi/cbindgen.toml +60 -0
  76. data/vendor/liter-llm-ffi/liter_llm.h +850 -0
  77. data/vendor/liter-llm-ffi/src/lib.rs +2488 -0
  78. metadata +286 -0
@@ -0,0 +1,1504 @@
1
+ use std::borrow::Cow;
2
+ use std::sync::atomic::{AtomicU64, Ordering};
3
+
4
+ use crate::error::{LiterLlmError, Result};
5
+ use crate::provider::Provider;
6
+ use crate::types::ChatCompletionChunk;
7
+
8
/// Default Vertex AI location when none is specified.
const DEFAULT_LOCATION: &str = "us-central1";

/// Location name of Vertex AI's global (non-regional) endpoint.
const GLOBAL_LOCATION: &str = "global";

/// Global counter for generating unique tool call IDs.
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Google Vertex AI / Gemini provider.
///
/// Differences from the OpenAI-compatible baseline:
/// - Auth uses `Authorization: Bearer <token>` where the token is a Google
///   Cloud OAuth2 access token (obtained via ADC, service account, or
///   `gcloud auth print-access-token`).
/// - The base URL is constructed from `VERTEXAI_PROJECT` and `VERTEXAI_LOCATION`
///   environment variables, or can be overridden via `base_url` in [`ClientConfig`].
///   The resulting URL follows the pattern:
///   `https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}`
///   (the `global` location uses the un-prefixed host `aiplatform.googleapis.com`).
/// - Model names are routed via the `vertex_ai/` prefix which is stripped
///   before being sent in the request body.
/// - The native Gemini `generateContent` format is used, not the OpenAI
///   `/chat/completions` path. Request and response are translated accordingly.
/// - Streaming uses SSE with `?alt=sse`; each chunk is a full `generateContent`
///   response JSON wrapped in a standard SSE `data:` line.
///
/// # Token management
///
/// Supply a pre-obtained access token as the `api_key` parameter.
/// Token refresh is the caller's responsibility. A future release will add
/// ADC / service-account-based automatic refresh.
///
/// # Environment variables
///
/// - `VERTEXAI_PROJECT` (required): Google Cloud project ID.
/// - `VERTEXAI_LOCATION` (optional): GCP region, defaults to `us-central1`.
///
/// # Configuration
///
/// ```rust,ignore
/// // Option 1: Use environment variables (recommended).
/// // export VERTEXAI_PROJECT=my-project
/// // export VERTEXAI_LOCATION=us-central1
/// let config = ClientConfigBuilder::new("ya29.your-access-token").build();
/// let client = DefaultClient::new(config, Some("vertex_ai/gemini-2.0-flash"))?;
///
/// // Option 2: Explicit base_url override (bypasses env var resolution).
/// let config = ClientConfigBuilder::new("ya29.your-access-token")
///     .base_url(
///         "https://us-central1-aiplatform.googleapis.com/v1/\
///          projects/my-project/locations/us-central1",
///     )
///     .build();
/// let client = DefaultClient::new(config, Some("vertex_ai/gemini-2.0-flash"))?;
/// ```
pub struct VertexAiProvider {
    /// Cached base URL: `https://{host}/v1/projects/{project}/locations/{location}`.
    ///
    /// Empty when no project was available (see [`VertexAiProvider::from_env`]).
    base_url: String,
}

impl VertexAiProvider {
    /// Construct with an explicit project and location.
    ///
    /// Regional locations use the `{location}-aiplatform.googleapis.com` host.
    /// The special `global` location has no regional host prefix and is served
    /// from `aiplatform.googleapis.com`.
    #[must_use]
    pub fn new(project: impl Into<String>, location: impl Into<String>) -> Self {
        let project = project.into();
        let location = location.into();
        let host = if location == GLOBAL_LOCATION {
            "aiplatform.googleapis.com".to_owned()
        } else {
            format!("{location}-aiplatform.googleapis.com")
        };
        let base_url = format!("https://{host}/v1/projects/{project}/locations/{location}");
        Self { base_url }
    }

    /// Construct from environment variables.
    ///
    /// Reads `VERTEXAI_PROJECT` and `VERTEXAI_LOCATION` (defaults to `us-central1`).
    /// Values are trimmed; an unset, empty, or whitespace-only
    /// `VERTEXAI_LOCATION` falls back to the default instead of producing a
    /// malformed host name. If `VERTEXAI_PROJECT` is not set (or blank), the
    /// base URL will be empty and [`validate`] will return an error.
    #[must_use]
    pub fn from_env() -> Self {
        let project = std::env::var("VERTEXAI_PROJECT")
            .map(|p| p.trim().to_owned())
            .unwrap_or_default();
        if project.is_empty() {
            return Self {
                base_url: String::new(),
            };
        }
        let location = std::env::var("VERTEXAI_LOCATION")
            .ok()
            .map(|l| l.trim().to_owned())
            .filter(|l| !l.is_empty())
            .unwrap_or_else(|| DEFAULT_LOCATION.to_owned());
        Self::new(project, location)
    }
}
93
+
94
+ impl Provider for VertexAiProvider {
95
+ fn name(&self) -> &str {
96
+ "vertex_ai"
97
+ }
98
+
99
+ /// Vertex AI base URL constructed from project and location.
100
+ ///
101
+ /// Returns an empty string when the provider was constructed without a
102
+ /// valid project (e.g. `VERTEXAI_PROJECT` not set). The [`validate`]
103
+ /// method catches this at client construction time.
104
+ fn base_url(&self) -> &str {
105
+ &self.base_url
106
+ }
107
+
108
+ /// Validate that required configuration is present.
109
+ ///
110
+ /// Checks that the base URL was successfully constructed from environment
111
+ /// variables (`VERTEXAI_PROJECT` is required, `VERTEXAI_LOCATION` defaults
112
+ /// to `us-central1`).
113
+ fn validate(&self) -> Result<()> {
114
+ if self.base_url.is_empty() {
115
+ return Err(LiterLlmError::BadRequest {
116
+ message: "Vertex AI requires a project ID. \
117
+ Set VERTEXAI_PROJECT (and optionally VERTEXAI_LOCATION) \
118
+ in the environment, or provide an explicit base_url in \
119
+ ClientConfig."
120
+ .into(),
121
+ });
122
+ }
123
+ Ok(())
124
+ }
125
+
126
+ fn auth_header<'a>(&'a self, api_key: &'a str) -> Option<(Cow<'static, str>, Cow<'a, str>)> {
127
+ // Vertex AI requires an OAuth2 Bearer token.
128
+ Some((Cow::Borrowed("Authorization"), Cow::Owned(format!("Bearer {api_key}"))))
129
+ }
130
+
131
+ fn matches_model(&self, model: &str) -> bool {
132
+ model.starts_with("vertex_ai/")
133
+ }
134
+
135
+ fn strip_model_prefix<'m>(&self, model: &'m str) -> &'m str {
136
+ model.strip_prefix("vertex_ai/").unwrap_or(model)
137
+ }
138
+
139
+ /// Build the full URL for a Gemini API request.
140
+ ///
141
+ /// Chat completions → `{base}/publishers/google/models/{model}:generateContent`
142
+ /// Embeddings → `{base}/publishers/google/models/{model}:predict`
143
+ /// Other paths → `{base}{endpoint_path}`
144
+ fn build_url(&self, endpoint_path: &str, model: &str) -> String {
145
+ let base = self.base_url();
146
+ if base.is_empty() {
147
+ // Caller must supply a base_url; will fail at validate() / HTTP layer.
148
+ return String::new();
149
+ }
150
+ let base = base.trim_end_matches('/');
151
+ if endpoint_path.contains("chat/completions") {
152
+ format!("{base}/publishers/google/models/{model}:generateContent")
153
+ } else if endpoint_path.contains("embeddings") {
154
+ format!("{base}/publishers/google/models/{model}:predict")
155
+ } else {
156
+ format!("{base}{endpoint_path}")
157
+ }
158
+ }
159
+
160
+ fn transform_request(&self, body: &mut serde_json::Value) -> Result<()> {
161
+ transform_gemini_request(body)
162
+ }
163
+
164
+ fn transform_response(&self, body: &mut serde_json::Value) -> Result<()> {
165
+ transform_gemini_response(body)
166
+ }
167
+
168
+ /// Build the streaming URL: appends `?alt=sse` to enable SSE streaming.
169
+ ///
170
+ /// Gemini's streaming endpoint uses the same path as the non-streaming
171
+ /// `generateContent` endpoint but requires `?alt=sse` to switch to
172
+ /// Server-Sent Events mode.
173
+ fn build_stream_url(&self, endpoint_path: &str, model: &str) -> String {
174
+ let url = self.build_url(endpoint_path, model);
175
+ if url.is_empty() {
176
+ return url;
177
+ }
178
+ format!("{url}?alt=sse")
179
+ }
180
+
181
+ fn parse_stream_event(&self, event_data: &str) -> Result<Option<ChatCompletionChunk>> {
182
+ parse_gemini_stream_event(event_data)
183
+ }
184
+ }
185
+
186
+ // ── Shared Gemini transform functions ────────────────────────────────────────
187
+ //
188
+ // These are `pub(crate)` so that both `VertexAiProvider` and `GoogleAiProvider`
189
+ // can reuse the same Gemini request/response translation logic.
190
+
191
+ /// Convert an OpenAI-style chat request to Gemini `generateContent` format.
192
+ ///
193
+ /// Key translations:
194
+ /// - System messages → `systemInstruction.parts[]`.
195
+ /// - Assistant role → `model` role.
196
+ /// - Tool calls → `functionCall` parts; tool results → `functionResponse` parts.
197
+ /// - Generation parameters → `generationConfig`.
198
+ /// - Multimodal content arrays → Gemini's `inlineData` / `fileData` format.
199
+ /// - `response_format` → `generationConfig.responseMimeType`.
200
+ /// - `tool_choice` → `toolConfig.functionCallingConfig.mode`.
201
+ /// - `extra_body.safety_settings` → top-level `safetySettings` array.
202
+ /// - `extra_body.grounding_config` / `google_search_retrieval` → `tools` entry.
203
+ /// - `extra_body.cached_content` → top-level `cachedContent` field.
204
+ /// - `ContentPart::Document` → `inlineData` with the document's MIME type.
205
+ pub(crate) fn transform_gemini_request(body: &mut serde_json::Value) -> Result<()> {
206
+ use serde_json::json;
207
+
208
+ // Extract extra_body before taking ownership of fields, since it may contain
209
+ // Gemini-specific extensions (safety_settings, grounding_config, cached_content).
210
+ let extra_body = body
211
+ .as_object_mut()
212
+ .and_then(|o| o.remove("extra_body"))
213
+ .and_then(|v| match v {
214
+ serde_json::Value::Object(map) => Some(map),
215
+ _ => None,
216
+ });
217
+
218
+ // Take ownership of the messages array to avoid cloning.
219
+ let messages = body
220
+ .as_object_mut()
221
+ .and_then(|o| o.remove("messages"))
222
+ .and_then(|v| match v {
223
+ serde_json::Value::Array(arr) => Some(arr),
224
+ _ => None,
225
+ })
226
+ .unwrap_or_default();
227
+
228
+ let mut system_parts: Vec<serde_json::Value> = vec![];
229
+ let mut contents: Vec<serde_json::Value> = vec![];
230
+
231
+ for msg in &messages {
232
+ let role = msg.get("role").and_then(|r| r.as_str()).unwrap_or("");
233
+ let content = msg.get("content");
234
+
235
+ match role {
236
+ "system" | "developer" => {
237
+ if let Some(text) = content.and_then(|c| c.as_str()) {
238
+ system_parts.push(json!({"text": text}));
239
+ }
240
+ }
241
+ "user" => {
242
+ let parts = convert_user_content_to_gemini(content);
243
+ contents.push(json!({"role": "user", "parts": parts}));
244
+ }
245
+ "assistant" => {
246
+ let mut parts: Vec<serde_json::Value> = vec![];
247
+ if let Some(text) = content.and_then(|c| c.as_str())
248
+ && !text.is_empty()
249
+ {
250
+ parts.push(json!({"text": text}));
251
+ }
252
+ // Convert OpenAI tool_calls to Gemini functionCall parts.
253
+ if let Some(tool_calls) = msg.get("tool_calls").and_then(|t| t.as_array()) {
254
+ for tc in tool_calls {
255
+ let args: serde_json::Value = tc
256
+ .pointer("/function/arguments")
257
+ .and_then(|a| a.as_str())
258
+ .and_then(|s| serde_json::from_str(s).ok())
259
+ .unwrap_or_else(|| json!({}));
260
+ parts.push(json!({
261
+ "functionCall": {
262
+ "name": tc.pointer("/function/name"),
263
+ "args": args
264
+ }
265
+ }));
266
+ }
267
+ }
268
+ if parts.is_empty() {
269
+ parts.push(json!({"text": ""}));
270
+ }
271
+ // Gemini uses "model" role for assistant turns.
272
+ contents.push(json!({"role": "model", "parts": parts}));
273
+ }
274
+ "tool" => {
275
+ // Map tool result back to a user turn with a functionResponse part.
276
+ // Gemini requires the function name — use the `name` field only.
277
+ // The `tool_call_id` is an OpenAI correlation ID, not a function name,
278
+ // so we must not fall back to it.
279
+ let name = msg.get("name").and_then(|n| n.as_str()).unwrap_or("tool");
280
+ let result_content = content.cloned().unwrap_or(json!(null));
281
+ contents.push(json!({
282
+ "role": "user",
283
+ "parts": [{
284
+ "functionResponse": {
285
+ "name": name,
286
+ "response": {"result": result_content}
287
+ }
288
+ }]
289
+ }));
290
+ }
291
+ _ => {}
292
+ }
293
+ }
294
+
295
+ // Build generationConfig from OpenAI parameters.
296
+ let mut gen_config = json!({});
297
+ // Support both max_tokens (legacy) and max_completion_tokens (newer OpenAI spec).
298
+ if let Some(max_tokens) = body.get("max_completion_tokens").or_else(|| body.get("max_tokens")) {
299
+ gen_config["maxOutputTokens"] = max_tokens.clone();
300
+ }
301
+ if let Some(temp) = body.get("temperature") {
302
+ gen_config["temperature"] = temp.clone();
303
+ }
304
+ if let Some(top_p) = body.get("top_p") {
305
+ gen_config["topP"] = top_p.clone();
306
+ }
307
+ if let Some(stop) = body.get("stop") {
308
+ let sequences = if let Some(s) = stop.as_str() {
309
+ vec![json!(s)]
310
+ } else {
311
+ stop.as_array().cloned().unwrap_or_default()
312
+ };
313
+ gen_config["stopSequences"] = json!(sequences);
314
+ }
315
+
316
+ // Translate response_format to Gemini's responseMimeType.
317
+ if let Some(rf) = body.get("response_format") {
318
+ let rf_type = rf.get("type").and_then(|t| t.as_str()).unwrap_or("");
319
+ match rf_type {
320
+ "json_object" => {
321
+ gen_config["responseMimeType"] = json!("application/json");
322
+ }
323
+ "json_schema" => {
324
+ gen_config["responseMimeType"] = json!("application/json");
325
+ // If a JSON schema is provided, pass it through.
326
+ if let Some(schema) = rf.get("json_schema").and_then(|s| s.get("schema")) {
327
+ gen_config["responseSchema"] = schema.clone();
328
+ }
329
+ }
330
+ // "text" or unknown types: no special handling needed.
331
+ _ => {}
332
+ }
333
+ }
334
+
335
+ // Translate OpenAI tools array to Gemini functionDeclarations.
336
+ let mut tools_value = body.get("tools").and_then(|t| t.as_array()).map(|arr| {
337
+ let declarations: Vec<serde_json::Value> = arr
338
+ .iter()
339
+ .map(|t| {
340
+ let name = t.pointer("/function/name").cloned().unwrap_or(json!("unknown"));
341
+ let description = t.pointer("/function/description").cloned().unwrap_or(json!(""));
342
+ let parameters = t
343
+ .pointer("/function/parameters")
344
+ .cloned()
345
+ .unwrap_or_else(|| json!({"type": "object"}));
346
+ json!({
347
+ "name": name,
348
+ "description": description,
349
+ "parameters": parameters
350
+ })
351
+ })
352
+ .collect();
353
+ json!([{"functionDeclarations": declarations}])
354
+ });
355
+
356
+ // Translate tool_choice to Gemini toolConfig.functionCallingConfig.mode.
357
+ let tool_config = translate_tool_choice(body.get("tool_choice"));
358
+
359
+ // ── extra_body extensions ────────────────────────────────────────────────
360
+ let mut safety_settings: Option<serde_json::Value> = None;
361
+ let mut cached_content: Option<serde_json::Value> = None;
362
+
363
+ if let Some(ref eb) = extra_body {
364
+ // Safety settings: inject as top-level safetySettings array.
365
+ if let Some(ss) = eb.get("safety_settings") {
366
+ safety_settings = Some(ss.clone());
367
+ }
368
+
369
+ // Grounding / Google Search: add google_search_retrieval to tools array.
370
+ if eb.contains_key("grounding_config") || eb.contains_key("google_search_retrieval") {
371
+ let grounding_tool = json!({"google_search_retrieval": {}});
372
+ match &mut tools_value {
373
+ Some(existing) => {
374
+ if let Some(arr) = existing.as_array_mut() {
375
+ arr.push(grounding_tool);
376
+ }
377
+ }
378
+ None => {
379
+ tools_value = Some(json!([grounding_tool]));
380
+ }
381
+ }
382
+ }
383
+
384
+ // Context caching: inject as top-level cachedContent field.
385
+ if let Some(cc) = eb.get("cached_content") {
386
+ cached_content = Some(cc.clone());
387
+ }
388
+ }
389
+
390
+ let mut new_body = json!({"contents": contents});
391
+ if !system_parts.is_empty() {
392
+ // Gemini API requires camelCase: systemInstruction.
393
+ new_body["systemInstruction"] = json!({"parts": system_parts});
394
+ }
395
+ if let Some(obj) = gen_config.as_object()
396
+ && !obj.is_empty()
397
+ {
398
+ new_body["generationConfig"] = gen_config;
399
+ }
400
+ if let Some(tools) = tools_value {
401
+ new_body["tools"] = tools;
402
+ }
403
+ if let Some(tc) = tool_config {
404
+ new_body["toolConfig"] = tc;
405
+ }
406
+ if let Some(ss) = safety_settings {
407
+ new_body["safetySettings"] = ss;
408
+ }
409
+ if let Some(cc) = cached_content {
410
+ new_body["cachedContent"] = cc;
411
+ }
412
+
413
+ *body = new_body;
414
+ Ok(())
415
+ }
416
+
417
+ /// Normalize a Gemini `generateContent` response to OpenAI chat completion format.
418
+ ///
419
+ /// Gemini wraps the response in `candidates[0].content.parts[]`.
420
+ /// Finish reasons use Gemini terminology (`STOP`, `MAX_TOKENS`, `SAFETY`, ...)
421
+ /// and are mapped to the OpenAI `finish_reason` set.
422
+ ///
423
+ /// If `groundingMetadata` is present on the candidate, it is included in the
424
+ /// response as `_grounding_metadata` for supplementary use by callers.
425
+ ///
426
+ /// **Known limitation:** The `model` field in the normalized response is
427
+ /// always `""`. Gemini/Vertex AI does not include the model name in its
428
+ /// response body -- the model is only present in the request URL path.
429
+ pub(crate) fn transform_gemini_response(body: &mut serde_json::Value) -> Result<()> {
430
+ use serde_json::json;
431
+
432
+ // Check for a blocked prompt (no candidates, but promptFeedback.blockReason set).
433
+ let candidates = body.get("candidates").and_then(|c| c.as_array());
434
+ if candidates.is_none_or(|c| c.is_empty()) {
435
+ let block_reason = body
436
+ .pointer("/promptFeedback/blockReason")
437
+ .and_then(|r| r.as_str())
438
+ .unwrap_or("UNKNOWN");
439
+ let prompt_tokens = body
440
+ .pointer("/usageMetadata/promptTokenCount")
441
+ .and_then(|v| v.as_u64())
442
+ .unwrap_or(0);
443
+ *body = json!({
444
+ "id": "gemini-resp",
445
+ "object": "chat.completion",
446
+ "created": super::unix_timestamp_secs(),
447
+ "model": "",
448
+ "choices": [{
449
+ "index": 0,
450
+ "message": {"role": "assistant", "content": null},
451
+ "finish_reason": "content_filter"
452
+ }],
453
+ "usage": {
454
+ "prompt_tokens": prompt_tokens,
455
+ "completion_tokens": 0,
456
+ "total_tokens": prompt_tokens
457
+ },
458
+ "system_fingerprint": null,
459
+ "_block_reason": block_reason
460
+ });
461
+ return Ok(());
462
+ }
463
+
464
+ let candidate = body.pointer("/candidates/0").cloned();
465
+ let finish_reason_raw = candidate
466
+ .as_ref()
467
+ .and_then(|c| c.get("finishReason"))
468
+ .and_then(|f| f.as_str())
469
+ .unwrap_or("STOP");
470
+ let parts = candidate
471
+ .as_ref()
472
+ .and_then(|c| c.pointer("/content/parts"))
473
+ .and_then(|p| p.as_array())
474
+ .cloned()
475
+ .unwrap_or_default();
476
+
477
+ // Collect text content from parts.
478
+ let text: String = parts
479
+ .iter()
480
+ .filter_map(|p| p.get("text").and_then(|t| t.as_str()))
481
+ .collect::<Vec<_>>()
482
+ .join("");
483
+
484
+ // Collect functionCall parts and convert to OpenAI tool_calls.
485
+ // Each call gets a unique ID via an atomic counter to avoid collisions
486
+ // when the same function is called multiple times.
487
+ let tool_calls: Vec<serde_json::Value> = parts
488
+ .iter()
489
+ .filter_map(|p| {
490
+ p.get("functionCall").map(|fc| {
491
+ let name = fc.get("name").and_then(|n| n.as_str()).unwrap_or("unknown");
492
+ let call_id = TOOL_CALL_COUNTER.fetch_add(1, Ordering::Relaxed);
493
+ let arguments = serde_json::to_string(fc.get("args").unwrap_or(&json!({}))).unwrap_or_default();
494
+ json!({
495
+ "id": format!("call_{name}_{call_id}"),
496
+ "type": "function",
497
+ "function": {
498
+ "name": fc.get("name"),
499
+ "arguments": arguments
500
+ }
501
+ })
502
+ })
503
+ })
504
+ .collect();
505
+
506
+ let finish_reason = match finish_reason_raw {
507
+ "STOP" => "stop",
508
+ "MAX_TOKENS" => "length",
509
+ "SAFETY" | "RECITATION" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "SPII" | "IMAGE_SAFETY" => "content_filter",
510
+ "LANGUAGE" | "OTHER" => "stop",
511
+ "TOOL_CODE" | "FUNCTION_CALL" => "tool_calls",
512
+ _ => "stop",
513
+ };
514
+
515
+ let prompt_tokens = body
516
+ .pointer("/usageMetadata/promptTokenCount")
517
+ .and_then(|v| v.as_u64())
518
+ .unwrap_or(0);
519
+ let completion_tokens = body
520
+ .pointer("/usageMetadata/candidatesTokenCount")
521
+ .and_then(|v| v.as_u64())
522
+ .unwrap_or(0);
523
+
524
+ let response_id = body.get("responseId").cloned().unwrap_or_else(|| json!("gemini-resp"));
525
+
526
+ let content_value: serde_json::Value = if text.is_empty() { json!(null) } else { json!(text) };
527
+
528
+ let mut message = json!({"role": "assistant", "content": content_value});
529
+ if !tool_calls.is_empty() {
530
+ message["tool_calls"] = json!(tool_calls);
531
+ }
532
+
533
+ // Extract grounding metadata if present (supplementary data from Google Search grounding).
534
+ let grounding_metadata = candidate.as_ref().and_then(|c| c.get("groundingMetadata")).cloned();
535
+
536
+ let mut result = json!({
537
+ "id": response_id,
538
+ "object": "chat.completion",
539
+ "created": super::unix_timestamp_secs(),
540
+ "model": "",
541
+ "choices": [{
542
+ "index": 0,
543
+ "message": message,
544
+ "finish_reason": finish_reason
545
+ }],
546
+ "usage": {
547
+ "prompt_tokens": prompt_tokens,
548
+ "completion_tokens": completion_tokens,
549
+ "total_tokens": prompt_tokens + completion_tokens
550
+ }
551
+ });
552
+
553
+ if let Some(gm) = grounding_metadata {
554
+ result["_grounding_metadata"] = gm;
555
+ }
556
+
557
+ *body = result;
558
+
559
+ Ok(())
560
+ }
561
+
562
+ /// Parse a single SSE event from Gemini's streaming endpoint.
563
+ ///
564
+ /// Gemini streaming uses SSE with `?alt=sse`. Each event data is a complete
565
+ /// `generateContent` JSON response. We reuse `transform_gemini_response` to
566
+ /// normalize it into OpenAI format, then build a `ChatCompletionChunk` from
567
+ /// the first choice's message content.
568
+ ///
569
+ /// **Note:** The `id` and `model` fields are empty strings on every chunk
570
+ /// because Gemini's streaming payloads do not include them, and this parser
571
+ /// is stateless.
572
+ pub(crate) fn parse_gemini_stream_event(event_data: &str) -> Result<Option<ChatCompletionChunk>> {
573
+ // NOTE: `[DONE]` is handled at the SSE parser level; no check needed here.
574
+ if event_data.trim().is_empty() {
575
+ return Ok(None);
576
+ }
577
+
578
+ let mut body: serde_json::Value = serde_json::from_str(event_data).map_err(|e| LiterLlmError::Streaming {
579
+ message: format!("failed to parse Gemini SSE data: {e}"),
580
+ })?;
581
+
582
+ // Normalize to OpenAI chat completion format.
583
+ transform_gemini_response(&mut body)?;
584
+
585
+ // Extract fields from the normalized response.
586
+ let id = body
587
+ .get("id")
588
+ .and_then(|v| v.as_str())
589
+ .unwrap_or("gemini-resp")
590
+ .to_owned();
591
+ let model = body.get("model").and_then(|v| v.as_str()).unwrap_or("").to_owned();
592
+
593
+ let choice = body.pointer("/choices/0");
594
+ let content = choice
595
+ .and_then(|c| c.pointer("/message/content"))
596
+ .and_then(|v| v.as_str())
597
+ .map(ToOwned::to_owned);
598
+ let finish_reason_str = choice
599
+ .and_then(|c| c.get("finish_reason"))
600
+ .and_then(|v| v.as_str())
601
+ .unwrap_or("");
602
+
603
+ // Extract tool_calls from the normalized message if present.
604
+ let stream_tool_calls = choice
605
+ .and_then(|c| c.pointer("/message/tool_calls"))
606
+ .and_then(|v| v.as_array())
607
+ .filter(|arr| !arr.is_empty())
608
+ .map(|arr| {
609
+ use crate::types::{StreamFunctionCall, StreamToolCall, ToolType};
610
+ arr.iter()
611
+ .enumerate()
612
+ .map(|(idx, tc)| StreamToolCall {
613
+ index: idx as u32,
614
+ id: tc.get("id").and_then(|v| v.as_str()).map(ToOwned::to_owned),
615
+ call_type: Some(ToolType::Function),
616
+ function: tc.get("function").map(|f| StreamFunctionCall {
617
+ name: f.get("name").and_then(|v| v.as_str()).map(ToOwned::to_owned),
618
+ arguments: f.get("arguments").and_then(|v| v.as_str()).map(ToOwned::to_owned),
619
+ }),
620
+ })
621
+ .collect::<Vec<_>>()
622
+ });
623
+
624
+ use crate::types::{FinishReason, StreamChoice, StreamDelta};
625
+
626
+ let finish_reason = match finish_reason_str {
627
+ "stop" => Some(FinishReason::Stop),
628
+ "length" => Some(FinishReason::Length),
629
+ "tool_calls" => Some(FinishReason::ToolCalls),
630
+ "content_filter" => Some(FinishReason::ContentFilter),
631
+ _ => None,
632
+ };
633
+
634
+ let chunk = ChatCompletionChunk {
635
+ id,
636
+ object: "chat.completion.chunk".to_owned(),
637
+ created: super::unix_timestamp_secs(),
638
+ model,
639
+ choices: vec![StreamChoice {
640
+ index: 0,
641
+ delta: StreamDelta {
642
+ role: Some("assistant".to_owned()),
643
+ content,
644
+ tool_calls: stream_tool_calls,
645
+ function_call: None,
646
+ refusal: None,
647
+ },
648
+ finish_reason,
649
+ }],
650
+ usage: None,
651
+ system_fingerprint: None,
652
+ service_tier: None,
653
+ };
654
+
655
+ Ok(Some(chunk))
656
+ }
657
+
658
+ // ── Helper functions ──────────────────────────────────────────────────────────
659
+
660
+ /// Convert OpenAI user content (string or content-part array) to Gemini parts.
661
+ ///
662
+ /// Handles four cases:
663
+ /// 1. Plain string -> single text part.
664
+ /// 2. Array of content parts -> each part converted to Gemini format.
665
+ /// 3. `ContentPart::Document` -> Gemini `inlineData` with the document's MIME type.
666
+ /// 4. None/null -> single empty text part.
667
+ pub(crate) fn convert_user_content_to_gemini(content: Option<&serde_json::Value>) -> Vec<serde_json::Value> {
668
+ use serde_json::json;
669
+
670
+ match content {
671
+ Some(serde_json::Value::String(s)) => vec![json!({"text": s})],
672
+ Some(serde_json::Value::Array(parts)) => {
673
+ parts
674
+ .iter()
675
+ .filter_map(|part| {
676
+ let part_type = part.get("type").and_then(|t| t.as_str())?;
677
+ match part_type {
678
+ "text" => {
679
+ let text = part.get("text").and_then(|t| t.as_str()).unwrap_or("");
680
+ Some(json!({"text": text}))
681
+ }
682
+ "image_url" => {
683
+ let url = part.pointer("/image_url/url").and_then(|u| u.as_str())?;
684
+ if url.starts_with("data:") {
685
+ // data:<media_type>;base64,<data>
686
+ if let Some((header, data)) = url.split_once(',') {
687
+ let mime_type = header.trim_start_matches("data:").trim_end_matches(";base64");
688
+ return Some(json!({
689
+ "inlineData": {
690
+ "mimeType": mime_type,
691
+ "data": data
692
+ }
693
+ }));
694
+ }
695
+ }
696
+ // Plain URL -- use Gemini's fileData format.
697
+ Some(json!({
698
+ "fileData": {
699
+ "mimeType": "image/jpeg",
700
+ "fileUri": url
701
+ }
702
+ }))
703
+ }
704
+ "document" => {
705
+ // ContentPart::Document -> Gemini inlineData.
706
+ let doc = part.get("document")?;
707
+ let data = doc.get("data").and_then(|d| d.as_str())?;
708
+ let media_type = doc
709
+ .get("media_type")
710
+ .and_then(|m| m.as_str())
711
+ .unwrap_or("application/pdf");
712
+ Some(json!({
713
+ "inlineData": {
714
+ "mimeType": media_type,
715
+ "data": data
716
+ }
717
+ }))
718
+ }
719
+ _ => {
720
+ // Unknown content part types: fall back to text representation.
721
+ let text = part.get("text").and_then(|t| t.as_str()).unwrap_or("");
722
+ if text.is_empty() {
723
+ None
724
+ } else {
725
+ Some(json!({"text": text}))
726
+ }
727
+ }
728
+ }
729
+ })
730
+ .collect()
731
+ }
732
+ _ => vec![json!({"text": ""})],
733
+ }
734
+ }
735
+
736
+ /// Translate OpenAI `tool_choice` to Gemini `toolConfig.functionCallingConfig`.
737
+ ///
738
+ /// OpenAI `tool_choice` values:
739
+ /// - `"none"` -> `NONE`
740
+ /// - `"auto"` -> `AUTO`
741
+ /// - `"required"` -> `ANY`
742
+ /// - `{"type": "function", "function": {"name": "..."}}` -> `ANY` with `allowedFunctionNames`
743
+ fn translate_tool_choice(tool_choice: Option<&serde_json::Value>) -> Option<serde_json::Value> {
744
+ use serde_json::json;
745
+
746
+ let tc = tool_choice?;
747
+
748
+ if let Some(s) = tc.as_str() {
749
+ let mode = match s {
750
+ "none" => "NONE",
751
+ "auto" => "AUTO",
752
+ "required" => "ANY",
753
+ _ => return None,
754
+ };
755
+ return Some(json!({
756
+ "functionCallingConfig": {
757
+ "mode": mode
758
+ }
759
+ }));
760
+ }
761
+
762
+ // Object form: {"type": "function", "function": {"name": "specific_fn"}}
763
+ if let Some(name) = tc.pointer("/function/name").and_then(|n| n.as_str()) {
764
+ return Some(json!({
765
+ "functionCallingConfig": {
766
+ "mode": "ANY",
767
+ "allowedFunctionNames": [name]
768
+ }
769
+ }));
770
+ }
771
+
772
+ None
773
+ }
774
+
775
+ // ── Unit tests ────────────────────────────────────────────────────────────────
776
+
777
+ #[cfg(test)]
778
+ mod tests {
779
+ use serde_json::json;
780
+
781
+ use super::*;
782
+ use crate::provider::Provider;
783
+
784
+ fn provider() -> VertexAiProvider {
785
+ VertexAiProvider::new("test-project", "us-central1")
786
+ }
787
+
788
+ fn provider_without_project() -> VertexAiProvider {
789
+ VertexAiProvider {
790
+ base_url: String::new(),
791
+ }
792
+ }
793
+
794
+ // ── validate ──────────────────────────────────────────────────────────────
795
+
796
/// `validate` returns `Ok` for the standard fixture provider.
#[test]
fn validate_succeeds_with_project() {
    let vertex = provider();
    let outcome = vertex.validate();
    assert!(outcome.is_ok());
}
801
+
802
/// `validate` fails for the empty-base-URL fixture and the error text names
/// `VERTEXAI_PROJECT`.
#[test]
fn validate_fails_without_project() {
    let vertex = provider_without_project();
    let message = vertex.validate().unwrap_err().to_string();
    assert!(
        message.contains("VERTEXAI_PROJECT"),
        "error should mention VERTEXAI_PROJECT"
    );
}
811
+
812
+ // ── base_url ──────────────────────────────────────────────────────────────
813
+
814
/// The fixture provider's base URL embeds both the project and the location.
#[test]
fn base_url_constructed_from_project_and_location() {
    let expected =
        "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1";
    let vertex = provider();
    assert_eq!(vertex.base_url(), expected);
}
822
+
823
/// A non-default location appears in both the host name and the path.
#[test]
fn base_url_custom_location() {
    let expected =
        "https://europe-west1-aiplatform.googleapis.com/v1/projects/my-proj/locations/europe-west1";
    let vertex = VertexAiProvider::new("my-proj", "europe-west1");
    assert_eq!(vertex.base_url(), expected);
}
831
+
832
+ // ── build_url ─────────────────────────────────────────────────────────────
833
+
834
/// With an empty base URL, `build_url` yields an empty string.
#[test]
fn build_url_returns_empty_without_base() {
    let vertex = provider_without_project();
    let built = vertex.build_url("/chat/completions", "gemini-2.0-flash");
    assert!(built.is_empty(), "should return empty string without a base URL");
}
840
+
841
/// The chat-completions endpoint maps to the model's `:generateContent` URL.
#[test]
fn build_url_chat_completions() {
    let expected_suffix = "/publishers/google/models/gemini-2.0-flash:generateContent";
    let vertex = provider();
    let built = vertex.build_url("/chat/completions", "gemini-2.0-flash");
    assert!(built.ends_with(expected_suffix));
}
847
+
848
/// The embeddings endpoint maps to the model's `:predict` URL.
#[test]
fn build_url_embeddings() {
    let expected_suffix = "/publishers/google/models/text-embedding-004:predict";
    let vertex = provider();
    let built = vertex.build_url("/embeddings", "text-embedding-004");
    assert!(built.ends_with(expected_suffix));
}
854
+
855
+ // ── transform_request ─────────────────────────────────────────────────────
856
+
857
/// A system + user conversation with sampling parameters is rewritten into
/// `systemInstruction`, `contents` and `generationConfig`.
#[test]
fn transform_request_basic_chat() {
    let vertex = provider();
    let mut request = json!({
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"}
        ],
        "max_tokens": 200,
        "temperature": 0.5
    });

    vertex.transform_request(&mut request).unwrap();

    // The system prompt is read back from the camelCase `systemInstruction` key.
    let system_part = &request["systemInstruction"]["parts"][0];
    assert_eq!(system_part["text"], "You are a helpful assistant.");

    // The user message becomes the first entry of `contents`.
    let first_turn = &request["contents"][0];
    assert_eq!(first_turn["role"], "user");
    assert_eq!(first_turn["parts"][0]["text"], "Hello!");

    // Sampling parameters are read back from `generationConfig`.
    let generation = &request["generationConfig"];
    assert_eq!(generation["maxOutputTokens"], 200);
    assert_eq!(generation["temperature"], 0.5);
}
885
+
886
/// An OpenAI `assistant` message is emitted with the Gemini role `model`.
#[test]
fn transform_request_assistant_becomes_model_role() {
    let vertex = provider();
    let mut request = json!({
        "messages": [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello there!"}
        ]
    });

    vertex.transform_request(&mut request).unwrap();

    let second_turn = &request["contents"][1];
    assert_eq!(second_turn["role"], "model");
    assert_eq!(second_turn["parts"][0]["text"], "Hello there!");
}
901
+
902
/// Assistant tool calls become `functionCall` parts and tool results become
/// `functionResponse` parts on a `user` turn.
#[test]
fn transform_request_with_tool_calls() {
    let vertex = provider();
    let mut request = json!({
        "messages": [
            {"role": "user", "content": "What is the weather in Berlin?"},
            {
                "role": "assistant",
                "content": null,
                "tool_calls": [{
                    "id": "call_1",
                    "type": "function",
                    "function": {"name": "get_weather", "arguments": "{\"city\":\"Berlin\"}"}
                }]
            },
            {
                "role": "tool",
                "name": "get_weather",
                "tool_call_id": "call_1",
                "content": "Sunny, 22°C"
            }
        ]
    });

    vertex.transform_request(&mut request).unwrap();

    let turns = request["contents"].as_array().unwrap();
    assert_eq!(turns.len(), 3);

    // Second turn: the assistant's call, with arguments decoded into an object.
    let assistant_turn = &turns[1];
    assert_eq!(assistant_turn["role"], "model");
    let call = &assistant_turn["parts"][0]["functionCall"];
    assert_eq!(call["name"], "get_weather");
    assert_eq!(call["args"]["city"], "Berlin");

    // Third turn: the tool output, carried on a user turn.
    let result_turn = &turns[2];
    assert_eq!(result_turn["role"], "user");
    let response = &result_turn["parts"][0]["functionResponse"];
    assert_eq!(response["name"], "get_weather");
}
944
+
945
/// The OpenAI `stop` array is carried over to `generationConfig.stopSequences`
/// in the same order.
#[test]
fn transform_request_stop_sequences() {
    let vertex = provider();
    let mut request = json!({
        "messages": [{"role": "user", "content": "hi"}],
        "stop": ["END", "STOP"]
    });

    vertex.transform_request(&mut request).unwrap();

    let sequences = request["generationConfig"]["stopSequences"]
        .as_array()
        .unwrap();
    assert_eq!(sequences.len(), 2);
    assert_eq!(sequences[0], "END");
    assert_eq!(sequences[1], "STOP");
}
960
+
961
+ // ── transform_request: safety settings ───────────────────────────────────
962
+
963
/// `extra_body.safety_settings` is surfaced as the top-level `safetySettings`
/// array with its entries intact.
#[test]
fn transform_request_safety_settings_from_extra_body() {
    let vertex = provider();
    let mut request = json!({
        "messages": [{"role": "user", "content": "hi"}],
        "extra_body": {
            "safety_settings": [
                {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
                {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"}
            ]
        }
    });

    vertex.transform_request(&mut request).unwrap();

    let safety = request["safetySettings"].as_array().unwrap();
    assert_eq!(safety.len(), 2);

    let first = &safety[0];
    assert_eq!(first["category"], "HARM_CATEGORY_HATE_SPEECH");
    assert_eq!(first["threshold"], "BLOCK_MEDIUM_AND_ABOVE");

    let second = &safety[1];
    assert_eq!(second["category"], "HARM_CATEGORY_DANGEROUS_CONTENT");
}
984
+
985
+ // ── transform_request: grounding / Google Search ─────────────────────────
986
+
987
/// Supplying `extra_body.grounding_config` results in a
/// `google_search_retrieval` entry in `tools`.
#[test]
fn transform_request_grounding_config_adds_google_search() {
    let vertex = provider();
    let mut request = json!({
        "messages": [{"role": "user", "content": "What happened today?"}],
        "extra_body": {
            "grounding_config": {}
        }
    });

    vertex.transform_request(&mut request).unwrap();

    let tool_list = request["tools"].as_array().unwrap();
    let has_search_tool = tool_list
        .iter()
        .any(|tool| tool.get("google_search_retrieval").is_some());
    assert!(has_search_tool, "tools should contain google_search_retrieval");
}
1005
+
1006
/// When function tools are already present, the search tool is appended after
/// the `functionDeclarations` entry.
#[test]
fn transform_request_google_search_retrieval_with_existing_tools() {
    let vertex = provider();
    let mut request = json!({
        "messages": [{"role": "user", "content": "hi"}],
        "tools": [{"type": "function", "function": {"name": "f", "parameters": {}}}],
        "extra_body": {
            "google_search_retrieval": {}
        }
    });

    vertex.transform_request(&mut request).unwrap();

    // Expect exactly two entries: functionDeclarations, then google_search_retrieval.
    let tool_list = request["tools"].as_array().unwrap();
    assert_eq!(tool_list.len(), 2);

    let declarations = tool_list[0].get("functionDeclarations");
    assert!(declarations.is_some());

    let search = tool_list[1].get("google_search_retrieval");
    assert!(search.is_some());
}
1025
+
1026
+ // ── transform_request: context caching ───────────────────────────────────
1027
+
1028
/// `extra_body.cached_content` is surfaced as the top-level `cachedContent`.
#[test]
fn transform_request_cached_content_from_extra_body() {
    let cache_name = "projects/xxx/locations/xxx/cachedContents/abc123";
    let vertex = provider();
    let mut request = json!({
        "messages": [{"role": "user", "content": "hi"}],
        "extra_body": {
            "cached_content": cache_name
        }
    });

    vertex.transform_request(&mut request).unwrap();

    assert_eq!(request["cachedContent"], cache_name);
}
1043
+
1044
+ // ── transform_request: document handling ─────────────────────────────────
1045
+
1046
/// A `document` content part in a user message is emitted as an `inlineData`
/// part alongside the text part.
#[test]
fn transform_request_document_content_part() {
    let vertex = provider();
    let mut request = json!({
        "messages": [{
            "role": "user",
            "content": [
                {"type": "text", "text": "Summarize this document."},
                {
                    "type": "document",
                    "document": {
                        "data": "JVBERi0xLjQ=",
                        "media_type": "application/pdf"
                    }
                }
            ]
        }]
    });

    vertex.transform_request(&mut request).unwrap();

    let user_parts = request["contents"][0]["parts"].as_array().unwrap();
    assert_eq!(user_parts.len(), 2);
    assert_eq!(user_parts[0]["text"], "Summarize this document.");

    let inline = &user_parts[1]["inlineData"];
    assert_eq!(inline["mimeType"], "application/pdf");
    assert_eq!(inline["data"], "JVBERi0xLjQ=");
}
1073
+
1074
+ // ── transform_response ────────────────────────────────────────────────────
1075
+
1076
/// A plain text Gemini response is reshaped into an OpenAI `chat.completion`
/// with id, message content, finish reason and token usage.
#[test]
fn transform_response_basic() {
    let vertex = provider();
    let mut response = json!({
        "responseId": "resp-gemini-123",
        "candidates": [{
            "content": {
                "role": "model",
                "parts": [{"text": "Hello from Gemini!"}]
            },
            "finishReason": "STOP"
        }],
        "usageMetadata": {
            "promptTokenCount": 8,
            "candidatesTokenCount": 6
        }
    });

    vertex.transform_response(&mut response).unwrap();

    assert_eq!(response["object"], "chat.completion");
    assert_eq!(response["id"], "resp-gemini-123");

    let first_choice = &response["choices"][0];
    assert_eq!(first_choice["message"]["content"], "Hello from Gemini!");
    assert_eq!(first_choice["finish_reason"], "stop");

    let usage = &response["usage"];
    assert_eq!(usage["prompt_tokens"], 8);
    assert_eq!(usage["completion_tokens"], 6);
    assert_eq!(usage["total_tokens"], 14);
}
1104
+
1105
/// Two calls to the same function in one response receive distinct IDs and
/// keep their own arguments.
#[test]
fn transform_response_tool_calls_have_unique_ids() {
    let vertex = provider();
    let mut response = json!({
        "candidates": [{
            "content": {
                "role": "model",
                "parts": [
                    {
                        "functionCall": {
                            "name": "get_weather",
                            "args": {"city": "Berlin"}
                        }
                    },
                    {
                        "functionCall": {
                            "name": "get_weather",
                            "args": {"city": "Paris"}
                        }
                    }
                ]
            },
            "finishReason": "STOP"
        }],
        "usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 5}
    });

    vertex.transform_response(&mut response).unwrap();

    let calls = response["choices"][0]["message"]["tool_calls"]
        .as_array()
        .unwrap();
    assert_eq!(calls.len(), 2);

    // Same function name on both calls, so the IDs share a prefix but must differ.
    let first_id = calls[0]["id"].as_str().unwrap();
    let second_id = calls[1]["id"].as_str().unwrap();
    assert_ne!(
        first_id, second_id,
        "tool call IDs must be unique even for the same function"
    );
    assert!(first_id.starts_with("call_get_weather_"));
    assert!(second_id.starts_with("call_get_weather_"));

    // Arguments are JSON-encoded strings; decode them before checking.
    let first_raw = calls[0]["function"]["arguments"].as_str().unwrap();
    let first_args: serde_json::Value = serde_json::from_str(first_raw).unwrap();
    let second_raw = calls[1]["function"]["arguments"].as_str().unwrap();
    let second_args: serde_json::Value = serde_json::from_str(second_raw).unwrap();
    assert_eq!(first_args["city"], "Berlin");
    assert_eq!(second_args["city"], "Paris");
}
1152
+
1153
/// A single `functionCall` part yields one tool call whose ID starts with
/// `call_<function name>_`.
#[test]
fn transform_response_single_tool_call() {
    let vertex = provider();
    let mut response = json!({
        "candidates": [{
            "content": {
                "role": "model",
                "parts": [{
                    "functionCall": {
                        "name": "get_weather",
                        "args": {"city": "Berlin"}
                    }
                }]
            },
            "finishReason": "STOP"
        }],
        "usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 5}
    });

    vertex.transform_response(&mut response).unwrap();

    let calls = response["choices"][0]["message"]["tool_calls"]
        .as_array()
        .unwrap();
    assert_eq!(calls.len(), 1);

    let only_call = &calls[0];
    assert_eq!(only_call["function"]["name"], "get_weather");

    // Only the prefix is checked; the trailing part of the ID is not asserted.
    let call_id = only_call["id"].as_str().unwrap();
    assert!(
        call_id.starts_with("call_get_weather_"),
        "id should start with call_get_weather_, got: {id}",
        id = call_id
    );
}
1184
+
1185
/// Each Gemini `finishReason` maps to the expected OpenAI `finish_reason`;
/// an unrecognised reason is expected to map to `stop`.
#[test]
fn transform_response_finish_reason_mapping() {
    // (Gemini finishReason, expected OpenAI finish_reason)
    const REASON_TABLE: [(&str, &str); 7] = [
        ("STOP", "stop"),
        ("MAX_TOKENS", "length"),
        ("SAFETY", "content_filter"),
        ("RECITATION", "content_filter"),
        ("BLOCKLIST", "content_filter"),
        ("PROHIBITED_CONTENT", "content_filter"),
        ("UNKNOWN_FUTURE_REASON", "stop"),
    ];

    let vertex = provider();

    for (gemini_reason, expected_oai_reason) in REASON_TABLE {
        let mut response = json!({
            "candidates": [{
                "content": {"role": "model", "parts": [{"text": ""}]},
                "finishReason": gemini_reason
            }],
            "usageMetadata": {"promptTokenCount": 0, "candidatesTokenCount": 0}
        });

        vertex.transform_response(&mut response).unwrap();

        let mapped = &response["choices"][0]["finish_reason"];
        assert_eq!(
            *mapped, expected_oai_reason,
            "Gemini finishReason '{gemini_reason}' should map to '{expected_oai_reason}'"
        );
    }
}
1212
+
1213
/// Grounding metadata attached to a candidate is kept on the transformed
/// response under the `_grounding_metadata` key.
#[test]
fn transform_response_grounding_metadata_preserved() {
    let p = provider();
    let mut body = json!({
        "candidates": [{
            "content": {
                "role": "model",
                "parts": [{"text": "grounded answer"}]
            },
            "finishReason": "STOP",
            "groundingMetadata": {
                "searchEntryPoint": {"renderedContent": "<html>...</html>"},
                "groundingChunks": [{"web": {"uri": "https://example.com", "title": "Example"}}]
            }
        }],
        "usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 3}
    });

    p.transform_response(&mut body).unwrap();

    assert_eq!(body["choices"][0]["message"]["content"], "grounded answer");

    // Look the metadata up once; `expect` reports the same message the
    // original presence check used if the key is missing.
    let metadata = body
        .get("_grounding_metadata")
        .expect("grounding metadata should be preserved");

    // `assert_eq!` prints the actual chunk count on failure, which a bare
    // `assert!(len == 1)` would not.
    let chunks = metadata["groundingChunks"].as_array().unwrap();
    assert_eq!(chunks.len(), 1);
}
1240
+
1241
+ // ── parse_stream_event ────────────────────────────────────────────────────
1242
+
1243
/// An empty event payload parses successfully to `None`.
#[test]
fn parse_stream_event_empty_returns_none() {
    let vertex = provider();
    let parsed = vertex.parse_stream_event("");
    assert!(parsed.unwrap().is_none());
}
1249
+
1250
/// The `[DONE]` sentinel is expected to be consumed by the SSE parser before
/// the provider sees it; if handed to the provider directly it is not valid
/// JSON and must produce an error.
#[test]
fn parse_stream_event_done_is_handled_at_sse_level() {
    let vertex = provider();
    let outcome = vertex.parse_stream_event("[DONE]");
    let rejected = outcome.is_err();
    assert!(
        rejected,
        "[DONE] is not valid JSON and should error if it reaches the provider"
    );
}
1261
+
1262
/// A single-candidate text event parses into a `chat.completion.chunk` whose
/// first choice carries the text as delta content.
#[test]
fn parse_stream_event_basic_chunk() {
    let raw_event = r#"{
        "candidates": [{
            "content": {"role": "model", "parts": [{"text": "Hello"}]},
            "finishReason": "STOP"
        }],
        "usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 2}
    }"#;
    let vertex = provider();

    let parsed = vertex.parse_stream_event(raw_event).unwrap().unwrap();

    assert_eq!(parsed.object, "chat.completion.chunk");
    let delta_text = parsed.choices[0].delta.content.as_deref();
    assert_eq!(delta_text, Some("Hello"));
}
1278
+
1279
+ // ── model prefix / matching ───────────────────────────────────────────────
1280
+
1281
/// The `vertex_ai/` prefix is removed when present; an unprefixed name is
/// returned unchanged.
#[test]
fn strip_model_prefix() {
    let vertex = provider();

    let prefixed = vertex.strip_model_prefix("vertex_ai/gemini-2.0-flash");
    assert_eq!(prefixed, "gemini-2.0-flash");

    let bare = vertex.strip_model_prefix("gemini-2.0-flash");
    assert_eq!(bare, "gemini-2.0-flash");
}
1287
+
1288
/// Only model names carrying the `vertex_ai/` prefix are claimed.
#[test]
fn matches_model() {
    let vertex = provider();

    let claims_prefixed = vertex.matches_model("vertex_ai/gemini-2.0-flash");
    assert!(claims_prefixed);

    let claims_bare_gemini = vertex.matches_model("gemini-2.0-flash");
    assert!(!claims_bare_gemini);

    let claims_gpt = vertex.matches_model("gpt-4");
    assert!(!claims_gpt);
}
1295
+
1296
+ // ── multimodal content ────────────────────────────────────────────────────
1297
+
1298
/// A base64 `data:` image URL in a user message is split into an
/// `inlineData` part with MIME type and payload.
#[test]
fn transform_request_multimodal_user_content() {
    let vertex = provider();
    let mut request = json!({
        "messages": [{
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,/9j/abc=="}}
            ]
        }]
    });

    vertex.transform_request(&mut request).unwrap();

    let user_parts = request["contents"][0]["parts"].as_array().unwrap();
    assert_eq!(user_parts.len(), 2);
    assert_eq!(user_parts[0]["text"], "What is in this image?");

    let inline = &user_parts[1]["inlineData"];
    assert_eq!(inline["mimeType"], "image/jpeg");
    assert_eq!(inline["data"], "/9j/abc==");
}
1319
+
1320
/// A plain (non-`data:`) image URL is emitted as a `fileData` part that
/// carries the URL in `fileUri`.
#[test]
fn transform_request_multimodal_url_image() {
    let vertex = provider();
    let mut request = json!({
        "messages": [{
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this."},
                {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}
            ]
        }]
    });

    vertex.transform_request(&mut request).unwrap();

    let user_parts = request["contents"][0]["parts"].as_array().unwrap();
    assert_eq!(user_parts.len(), 2);

    let file_ref = &user_parts[1]["fileData"];
    assert_eq!(file_ref["fileUri"], "https://example.com/image.jpg");
}
1339
+
1340
+ // ── response_format translation ───────────────────────────────────────────
1341
+
1342
/// `response_format: json_object` sets the response MIME type to
/// `application/json`.
#[test]
fn transform_request_response_format_json_object() {
    let vertex = provider();
    let mut request = json!({
        "messages": [{"role": "user", "content": "hi"}],
        "response_format": {"type": "json_object"}
    });

    vertex.transform_request(&mut request).unwrap();

    let generation = &request["generationConfig"];
    assert_eq!(generation["responseMimeType"], "application/json");
}
1354
+
1355
/// `response_format: json_schema` sets the JSON MIME type and forwards the
/// schema as `responseSchema`.
#[test]
fn transform_request_response_format_json_schema() {
    let vertex = provider();
    let mut request = json!({
        "messages": [{"role": "user", "content": "hi"}],
        "response_format": {
            "type": "json_schema",
            "json_schema": {
                "name": "test",
                "schema": {"type": "object", "properties": {"name": {"type": "string"}}}
            }
        }
    });

    vertex.transform_request(&mut request).unwrap();

    let generation = &request["generationConfig"];
    assert_eq!(generation["responseMimeType"], "application/json");
    assert_eq!(generation["responseSchema"]["type"], "object");
}
1374
+
1375
+ // ── tool_choice translation ───────────────────────────────────────────────
1376
+
1377
/// `tool_choice: "auto"` produces function-calling mode `AUTO`.
#[test]
fn transform_request_tool_choice_auto() {
    let vertex = provider();
    let mut request = json!({
        "messages": [{"role": "user", "content": "hi"}],
        "tools": [{"type": "function", "function": {"name": "f", "parameters": {}}}],
        "tool_choice": "auto"
    });

    vertex.transform_request(&mut request).unwrap();

    let calling_config = &request["toolConfig"]["functionCallingConfig"];
    assert_eq!(calling_config["mode"], "AUTO");
}
1390
+
1391
/// `tool_choice: "none"` produces function-calling mode `NONE`.
#[test]
fn transform_request_tool_choice_none() {
    let vertex = provider();
    let mut request = json!({
        "messages": [{"role": "user", "content": "hi"}],
        "tools": [{"type": "function", "function": {"name": "f", "parameters": {}}}],
        "tool_choice": "none"
    });

    vertex.transform_request(&mut request).unwrap();

    let calling_config = &request["toolConfig"]["functionCallingConfig"];
    assert_eq!(calling_config["mode"], "NONE");
}
1404
+
1405
/// `tool_choice: "required"` produces function-calling mode `ANY`.
#[test]
fn transform_request_tool_choice_required() {
    let vertex = provider();
    let mut request = json!({
        "messages": [{"role": "user", "content": "hi"}],
        "tools": [{"type": "function", "function": {"name": "f", "parameters": {}}}],
        "tool_choice": "required"
    });

    vertex.transform_request(&mut request).unwrap();

    let calling_config = &request["toolConfig"]["functionCallingConfig"];
    assert_eq!(calling_config["mode"], "ANY");
}
1418
+
1419
/// Naming a specific function in `tool_choice` produces mode `ANY` with that
/// function as the first allowed name.
#[test]
fn transform_request_tool_choice_specific_function() {
    let vertex = provider();
    let mut request = json!({
        "messages": [{"role": "user", "content": "hi"}],
        "tools": [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
        "tool_choice": {"type": "function", "function": {"name": "get_weather"}}
    });

    vertex.transform_request(&mut request).unwrap();

    let calling_config = &request["toolConfig"]["functionCallingConfig"];
    assert_eq!(calling_config["mode"], "ANY");
    assert_eq!(calling_config["allowedFunctionNames"][0], "get_weather");
}
1436
+
1437
+ // ── helper function tests ─────────────────────────────────────────────────
1438
+
1439
/// A plain string content value converts to a single text part.
#[test]
fn convert_user_content_string() {
    let message_content = json!("Hello!");
    let converted = convert_user_content_to_gemini(Some(&message_content));
    assert_eq!(converted.len(), 1);
    assert_eq!(converted[0]["text"], "Hello!");
}
1446
+
1447
/// A text part followed by a base64 `data:` image converts to a text part and
/// an `inlineData` part.
#[test]
fn convert_user_content_array_with_image() {
    let message_content = json!([
        {"type": "text", "text": "What is this?"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR"}}
    ]);

    let converted = convert_user_content_to_gemini(Some(&message_content));

    assert_eq!(converted.len(), 2);
    assert_eq!(converted[0]["text"], "What is this?");

    let inline = &converted[1]["inlineData"];
    assert_eq!(inline["mimeType"], "image/png");
    assert_eq!(inline["data"], "iVBOR");
}
1459
+
1460
/// Missing content converts to a single part with empty text.
#[test]
fn convert_user_content_none() {
    let converted = convert_user_content_to_gemini(None);
    assert_eq!(converted.len(), 1);

    let only_part = &converted[0];
    assert_eq!(only_part["text"], "");
}
1466
+
1467
/// A `document` part converts to `inlineData` carrying its media type and data.
#[test]
fn convert_user_content_document_part() {
    let message_content = json!([
        {"type": "text", "text": "Read this PDF."},
        {"type": "document", "document": {"data": "base64data==", "media_type": "application/pdf"}}
    ]);

    let converted = convert_user_content_to_gemini(Some(&message_content));

    assert_eq!(converted.len(), 2);
    assert_eq!(converted[0]["text"], "Read this PDF.");

    let inline = &converted[1]["inlineData"];
    assert_eq!(inline["mimeType"], "application/pdf");
    assert_eq!(inline["data"], "base64data==");
}
1479
+
1480
/// Each recognised string `tool_choice` maps to its Gemini mode.
#[test]
fn translate_tool_choice_string_values() {
    // (OpenAI tool_choice string, expected Gemini mode), checked in this order.
    let cases = [("auto", "AUTO"), ("none", "NONE"), ("required", "ANY")];

    for (choice, expected_mode) in cases {
        let input = json!(choice);
        let config = translate_tool_choice(Some(&input)).unwrap();
        assert_eq!(config["functionCallingConfig"]["mode"], expected_mode);
    }
}
1491
+
1492
/// The object form yields mode `ANY` and lists the named function first in
/// `allowedFunctionNames`.
#[test]
fn translate_tool_choice_specific_function() {
    let choice = json!({"type": "function", "function": {"name": "my_fn"}});

    let config = translate_tool_choice(Some(&choice)).unwrap();

    let calling_config = &config["functionCallingConfig"];
    assert_eq!(calling_config["mode"], "ANY");
    assert_eq!(calling_config["allowedFunctionNames"][0], "my_fn");
}
1499
+
1500
/// With no `tool_choice` supplied, no tool configuration is produced.
#[test]
fn translate_tool_choice_none_input() {
    let config = translate_tool_choice(None);
    assert!(config.is_none());
}
1504
+ }