liter_llm 1.0.0.pre.rc.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +239 -0
- data/ext/liter_llm_rb/extconf.rb +65 -0
- data/ext/liter_llm_rb/native/.cargo/config.toml +23 -0
- data/ext/liter_llm_rb/native/Cargo.lock +3713 -0
- data/ext/liter_llm_rb/native/Cargo.toml +32 -0
- data/ext/liter_llm_rb/native/build.rs +15 -0
- data/ext/liter_llm_rb/native/src/lib.rs +1079 -0
- data/lib/liter_llm.rb +8 -0
- data/sig/liter_llm.rbs +416 -0
- data/vendor/Cargo.toml +54 -0
- data/vendor/liter-llm/Cargo.toml +92 -0
- data/vendor/liter-llm/README.md +252 -0
- data/vendor/liter-llm/schemas/pricing.json +40 -0
- data/vendor/liter-llm/schemas/providers.json +1662 -0
- data/vendor/liter-llm/src/auth/azure_ad.rs +264 -0
- data/vendor/liter-llm/src/auth/bedrock_sts.rs +353 -0
- data/vendor/liter-llm/src/auth/mod.rs +68 -0
- data/vendor/liter-llm/src/auth/vertex_oauth.rs +353 -0
- data/vendor/liter-llm/src/client/config.rs +351 -0
- data/vendor/liter-llm/src/client/managed.rs +622 -0
- data/vendor/liter-llm/src/client/mod.rs +864 -0
- data/vendor/liter-llm/src/cost.rs +212 -0
- data/vendor/liter-llm/src/error.rs +190 -0
- data/vendor/liter-llm/src/http/eventstream.rs +860 -0
- data/vendor/liter-llm/src/http/mod.rs +12 -0
- data/vendor/liter-llm/src/http/request.rs +438 -0
- data/vendor/liter-llm/src/http/retry.rs +72 -0
- data/vendor/liter-llm/src/http/streaming.rs +289 -0
- data/vendor/liter-llm/src/lib.rs +37 -0
- data/vendor/liter-llm/src/provider/anthropic.rs +2250 -0
- data/vendor/liter-llm/src/provider/azure.rs +579 -0
- data/vendor/liter-llm/src/provider/bedrock.rs +1543 -0
- data/vendor/liter-llm/src/provider/cohere.rs +654 -0
- data/vendor/liter-llm/src/provider/custom.rs +404 -0
- data/vendor/liter-llm/src/provider/google_ai.rs +281 -0
- data/vendor/liter-llm/src/provider/mistral.rs +188 -0
- data/vendor/liter-llm/src/provider/mod.rs +616 -0
- data/vendor/liter-llm/src/provider/vertex.rs +1504 -0
- data/vendor/liter-llm/src/tests.rs +1425 -0
- data/vendor/liter-llm/src/tokenizer.rs +281 -0
- data/vendor/liter-llm/src/tower/budget.rs +599 -0
- data/vendor/liter-llm/src/tower/cache.rs +502 -0
- data/vendor/liter-llm/src/tower/cache_opendal.rs +270 -0
- data/vendor/liter-llm/src/tower/cooldown.rs +231 -0
- data/vendor/liter-llm/src/tower/cost.rs +404 -0
- data/vendor/liter-llm/src/tower/fallback.rs +121 -0
- data/vendor/liter-llm/src/tower/health.rs +219 -0
- data/vendor/liter-llm/src/tower/hooks.rs +369 -0
- data/vendor/liter-llm/src/tower/mod.rs +77 -0
- data/vendor/liter-llm/src/tower/rate_limit.rs +300 -0
- data/vendor/liter-llm/src/tower/router.rs +436 -0
- data/vendor/liter-llm/src/tower/service.rs +181 -0
- data/vendor/liter-llm/src/tower/tests.rs +539 -0
- data/vendor/liter-llm/src/tower/tests_common.rs +252 -0
- data/vendor/liter-llm/src/tower/tracing.rs +209 -0
- data/vendor/liter-llm/src/tower/types.rs +170 -0
- data/vendor/liter-llm/src/types/audio.rs +52 -0
- data/vendor/liter-llm/src/types/batch.rs +77 -0
- data/vendor/liter-llm/src/types/chat.rs +214 -0
- data/vendor/liter-llm/src/types/common.rs +244 -0
- data/vendor/liter-llm/src/types/embedding.rs +84 -0
- data/vendor/liter-llm/src/types/files.rs +58 -0
- data/vendor/liter-llm/src/types/image.rs +40 -0
- data/vendor/liter-llm/src/types/mod.rs +27 -0
- data/vendor/liter-llm/src/types/models.rs +21 -0
- data/vendor/liter-llm/src/types/moderation.rs +80 -0
- data/vendor/liter-llm/src/types/ocr.rs +87 -0
- data/vendor/liter-llm/src/types/rerank.rs +46 -0
- data/vendor/liter-llm/src/types/responses.rs +55 -0
- data/vendor/liter-llm/src/types/search.rs +45 -0
- data/vendor/liter-llm/tests/contract.rs +332 -0
- data/vendor/liter-llm-ffi/Cargo.toml +30 -0
- data/vendor/liter-llm-ffi/build.rs +66 -0
- data/vendor/liter-llm-ffi/cbindgen.toml +60 -0
- data/vendor/liter-llm-ffi/liter_llm.h +850 -0
- data/vendor/liter-llm-ffi/src/lib.rs +2488 -0
- metadata +286 -0
|
@@ -0,0 +1,1504 @@
|
|
|
1
|
+
use std::borrow::Cow;
|
|
2
|
+
use std::sync::atomic::{AtomicU64, Ordering};
|
|
3
|
+
|
|
4
|
+
use crate::error::{LiterLlmError, Result};
|
|
5
|
+
use crate::provider::Provider;
|
|
6
|
+
use crate::types::ChatCompletionChunk;
|
|
7
|
+
|
|
8
|
+
/// Region used to build the Vertex AI endpoint when `VERTEXAI_LOCATION` is not provided.
const DEFAULT_LOCATION: &str = "us-central1";

/// Process-wide counter, bumped once per generated tool call, used to keep
/// tool call IDs unique even when the same function is called repeatedly.
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
|
|
13
|
+
|
|
14
|
+
/// Provider for Google Vertex AI (Gemini models).
///
/// How this provider departs from the OpenAI-compatible baseline:
///
/// - **Authentication**: every request carries `Authorization: Bearer <token>`,
///   where the token is a Google Cloud OAuth2 access token (for example one
///   obtained via ADC, a service account, or `gcloud auth print-access-token`).
/// - **Endpoint**: the base URL is derived from the `VERTEXAI_PROJECT` and
///   `VERTEXAI_LOCATION` environment variables, unless `base_url` is set
///   explicitly in `ClientConfig`. The derived URL follows the pattern
///   `https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}`.
/// - **Model routing**: models are selected with a `vertex_ai/` prefix, which is
///   removed before the model name is sent upstream.
/// - **Wire format**: requests and responses use the native Gemini
///   `generateContent` schema rather than OpenAI's `/chat/completions`, and are
///   translated in both directions.
/// - **Streaming**: Server-Sent Events (`alt=sse`); each event's `data:` line
///   carries a complete Gemini response JSON object.
///
/// # Token management
///
/// Pass an already-obtained access token as the `api_key` parameter. This type
/// performs no token refresh; it forwards whatever token it is given.
/// NOTE(review): the crate also ships an `auth::vertex_oauth` module that may
/// provide automatic refresh — confirm and cross-reference it here.
///
/// # Environment variables
///
/// - `VERTEXAI_PROJECT` (required): Google Cloud project ID.
/// - `VERTEXAI_LOCATION` (optional): GCP region, defaults to `us-central1`.
///
/// # Configuration
///
/// ```rust,ignore
/// // Option 1: resolve project/location from the environment (recommended).
/// // export VERTEXAI_PROJECT=my-project
/// // export VERTEXAI_LOCATION=us-central1
/// let config = ClientConfigBuilder::new("ya29.your-access-token").build();
/// let client = DefaultClient::new(config, Some("vertex_ai/gemini-2.0-flash"))?;
///
/// // Option 2: supply the base URL directly (environment is not consulted).
/// let config = ClientConfigBuilder::new("ya29.your-access-token")
///     .base_url(
///         "https://us-central1-aiplatform.googleapis.com/v1/\
///          projects/my-project/locations/us-central1",
///     )
///     .build();
/// let client = DefaultClient::new(config, Some("vertex_ai/gemini-2.0-flash"))?;
/// ```
pub struct VertexAiProvider {
    /// Precomputed endpoint root, e.g.
    /// `https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}`;
    /// empty when no project was configured.
    base_url: String,
}
|
|
64
|
+
|
|
65
|
+
impl VertexAiProvider {
|
|
66
|
+
/// Construct with an explicit project and location.
|
|
67
|
+
#[must_use]
|
|
68
|
+
pub fn new(project: impl Into<String>, location: impl Into<String>) -> Self {
|
|
69
|
+
let project = project.into();
|
|
70
|
+
let location = location.into();
|
|
71
|
+
let base_url =
|
|
72
|
+
format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}");
|
|
73
|
+
Self { base_url }
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
/// Construct from environment variables.
|
|
77
|
+
///
|
|
78
|
+
/// Reads `VERTEXAI_PROJECT` and `VERTEXAI_LOCATION` (defaults to `us-central1`).
|
|
79
|
+
/// If `VERTEXAI_PROJECT` is not set, the base URL will be empty and
|
|
80
|
+
/// [`validate`] will return an error.
|
|
81
|
+
#[must_use]
|
|
82
|
+
pub fn from_env() -> Self {
|
|
83
|
+
let project = std::env::var("VERTEXAI_PROJECT").unwrap_or_default();
|
|
84
|
+
let location = std::env::var("VERTEXAI_LOCATION").unwrap_or_else(|_| DEFAULT_LOCATION.to_owned());
|
|
85
|
+
if project.is_empty() {
|
|
86
|
+
return Self {
|
|
87
|
+
base_url: String::new(),
|
|
88
|
+
};
|
|
89
|
+
}
|
|
90
|
+
Self::new(project, location)
|
|
91
|
+
}
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
impl Provider for VertexAiProvider {
|
|
95
|
+
fn name(&self) -> &str {
|
|
96
|
+
"vertex_ai"
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
/// Vertex AI base URL constructed from project and location.
|
|
100
|
+
///
|
|
101
|
+
/// Returns an empty string when the provider was constructed without a
|
|
102
|
+
/// valid project (e.g. `VERTEXAI_PROJECT` not set). The [`validate`]
|
|
103
|
+
/// method catches this at client construction time.
|
|
104
|
+
fn base_url(&self) -> &str {
|
|
105
|
+
&self.base_url
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
/// Validate that required configuration is present.
|
|
109
|
+
///
|
|
110
|
+
/// Checks that the base URL was successfully constructed from environment
|
|
111
|
+
/// variables (`VERTEXAI_PROJECT` is required, `VERTEXAI_LOCATION` defaults
|
|
112
|
+
/// to `us-central1`).
|
|
113
|
+
fn validate(&self) -> Result<()> {
|
|
114
|
+
if self.base_url.is_empty() {
|
|
115
|
+
return Err(LiterLlmError::BadRequest {
|
|
116
|
+
message: "Vertex AI requires a project ID. \
|
|
117
|
+
Set VERTEXAI_PROJECT (and optionally VERTEXAI_LOCATION) \
|
|
118
|
+
in the environment, or provide an explicit base_url in \
|
|
119
|
+
ClientConfig."
|
|
120
|
+
.into(),
|
|
121
|
+
});
|
|
122
|
+
}
|
|
123
|
+
Ok(())
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
fn auth_header<'a>(&'a self, api_key: &'a str) -> Option<(Cow<'static, str>, Cow<'a, str>)> {
|
|
127
|
+
// Vertex AI requires an OAuth2 Bearer token.
|
|
128
|
+
Some((Cow::Borrowed("Authorization"), Cow::Owned(format!("Bearer {api_key}"))))
|
|
129
|
+
}
|
|
130
|
+
|
|
131
|
+
fn matches_model(&self, model: &str) -> bool {
|
|
132
|
+
model.starts_with("vertex_ai/")
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
fn strip_model_prefix<'m>(&self, model: &'m str) -> &'m str {
|
|
136
|
+
model.strip_prefix("vertex_ai/").unwrap_or(model)
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
/// Build the full URL for a Gemini API request.
|
|
140
|
+
///
|
|
141
|
+
/// Chat completions → `{base}/publishers/google/models/{model}:generateContent`
|
|
142
|
+
/// Embeddings → `{base}/publishers/google/models/{model}:predict`
|
|
143
|
+
/// Other paths → `{base}{endpoint_path}`
|
|
144
|
+
fn build_url(&self, endpoint_path: &str, model: &str) -> String {
|
|
145
|
+
let base = self.base_url();
|
|
146
|
+
if base.is_empty() {
|
|
147
|
+
// Caller must supply a base_url; will fail at validate() / HTTP layer.
|
|
148
|
+
return String::new();
|
|
149
|
+
}
|
|
150
|
+
let base = base.trim_end_matches('/');
|
|
151
|
+
if endpoint_path.contains("chat/completions") {
|
|
152
|
+
format!("{base}/publishers/google/models/{model}:generateContent")
|
|
153
|
+
} else if endpoint_path.contains("embeddings") {
|
|
154
|
+
format!("{base}/publishers/google/models/{model}:predict")
|
|
155
|
+
} else {
|
|
156
|
+
format!("{base}{endpoint_path}")
|
|
157
|
+
}
|
|
158
|
+
}
|
|
159
|
+
|
|
160
|
+
fn transform_request(&self, body: &mut serde_json::Value) -> Result<()> {
|
|
161
|
+
transform_gemini_request(body)
|
|
162
|
+
}
|
|
163
|
+
|
|
164
|
+
fn transform_response(&self, body: &mut serde_json::Value) -> Result<()> {
|
|
165
|
+
transform_gemini_response(body)
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
/// Build the streaming URL: appends `?alt=sse` to enable SSE streaming.
|
|
169
|
+
///
|
|
170
|
+
/// Gemini's streaming endpoint uses the same path as the non-streaming
|
|
171
|
+
/// `generateContent` endpoint but requires `?alt=sse` to switch to
|
|
172
|
+
/// Server-Sent Events mode.
|
|
173
|
+
fn build_stream_url(&self, endpoint_path: &str, model: &str) -> String {
|
|
174
|
+
let url = self.build_url(endpoint_path, model);
|
|
175
|
+
if url.is_empty() {
|
|
176
|
+
return url;
|
|
177
|
+
}
|
|
178
|
+
format!("{url}?alt=sse")
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
fn parse_stream_event(&self, event_data: &str) -> Result<Option<ChatCompletionChunk>> {
|
|
182
|
+
parse_gemini_stream_event(event_data)
|
|
183
|
+
}
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
// ── Shared Gemini transform functions ────────────────────────────────────────
|
|
187
|
+
//
|
|
188
|
+
// These are `pub(crate)` so that both `VertexAiProvider` and `GoogleAiProvider`
|
|
189
|
+
// can reuse the same Gemini request/response translation logic.
|
|
190
|
+
|
|
191
|
+
/// Convert an OpenAI-style chat request to Gemini `generateContent` format.
|
|
192
|
+
///
|
|
193
|
+
/// Key translations:
|
|
194
|
+
/// - System messages → `systemInstruction.parts[]`.
|
|
195
|
+
/// - Assistant role → `model` role.
|
|
196
|
+
/// - Tool calls → `functionCall` parts; tool results → `functionResponse` parts.
|
|
197
|
+
/// - Generation parameters → `generationConfig`.
|
|
198
|
+
/// - Multimodal content arrays → Gemini's `inlineData` / `fileData` format.
|
|
199
|
+
/// - `response_format` → `generationConfig.responseMimeType`.
|
|
200
|
+
/// - `tool_choice` → `toolConfig.functionCallingConfig.mode`.
|
|
201
|
+
/// - `extra_body.safety_settings` → top-level `safetySettings` array.
|
|
202
|
+
/// - `extra_body.grounding_config` / `google_search_retrieval` → `tools` entry.
|
|
203
|
+
/// - `extra_body.cached_content` → top-level `cachedContent` field.
|
|
204
|
+
/// - `ContentPart::Document` → `inlineData` with the document's MIME type.
|
|
205
|
+
pub(crate) fn transform_gemini_request(body: &mut serde_json::Value) -> Result<()> {
|
|
206
|
+
use serde_json::json;
|
|
207
|
+
|
|
208
|
+
// Extract extra_body before taking ownership of fields, since it may contain
|
|
209
|
+
// Gemini-specific extensions (safety_settings, grounding_config, cached_content).
|
|
210
|
+
let extra_body = body
|
|
211
|
+
.as_object_mut()
|
|
212
|
+
.and_then(|o| o.remove("extra_body"))
|
|
213
|
+
.and_then(|v| match v {
|
|
214
|
+
serde_json::Value::Object(map) => Some(map),
|
|
215
|
+
_ => None,
|
|
216
|
+
});
|
|
217
|
+
|
|
218
|
+
// Take ownership of the messages array to avoid cloning.
|
|
219
|
+
let messages = body
|
|
220
|
+
.as_object_mut()
|
|
221
|
+
.and_then(|o| o.remove("messages"))
|
|
222
|
+
.and_then(|v| match v {
|
|
223
|
+
serde_json::Value::Array(arr) => Some(arr),
|
|
224
|
+
_ => None,
|
|
225
|
+
})
|
|
226
|
+
.unwrap_or_default();
|
|
227
|
+
|
|
228
|
+
let mut system_parts: Vec<serde_json::Value> = vec![];
|
|
229
|
+
let mut contents: Vec<serde_json::Value> = vec![];
|
|
230
|
+
|
|
231
|
+
for msg in &messages {
|
|
232
|
+
let role = msg.get("role").and_then(|r| r.as_str()).unwrap_or("");
|
|
233
|
+
let content = msg.get("content");
|
|
234
|
+
|
|
235
|
+
match role {
|
|
236
|
+
"system" | "developer" => {
|
|
237
|
+
if let Some(text) = content.and_then(|c| c.as_str()) {
|
|
238
|
+
system_parts.push(json!({"text": text}));
|
|
239
|
+
}
|
|
240
|
+
}
|
|
241
|
+
"user" => {
|
|
242
|
+
let parts = convert_user_content_to_gemini(content);
|
|
243
|
+
contents.push(json!({"role": "user", "parts": parts}));
|
|
244
|
+
}
|
|
245
|
+
"assistant" => {
|
|
246
|
+
let mut parts: Vec<serde_json::Value> = vec![];
|
|
247
|
+
if let Some(text) = content.and_then(|c| c.as_str())
|
|
248
|
+
&& !text.is_empty()
|
|
249
|
+
{
|
|
250
|
+
parts.push(json!({"text": text}));
|
|
251
|
+
}
|
|
252
|
+
// Convert OpenAI tool_calls to Gemini functionCall parts.
|
|
253
|
+
if let Some(tool_calls) = msg.get("tool_calls").and_then(|t| t.as_array()) {
|
|
254
|
+
for tc in tool_calls {
|
|
255
|
+
let args: serde_json::Value = tc
|
|
256
|
+
.pointer("/function/arguments")
|
|
257
|
+
.and_then(|a| a.as_str())
|
|
258
|
+
.and_then(|s| serde_json::from_str(s).ok())
|
|
259
|
+
.unwrap_or_else(|| json!({}));
|
|
260
|
+
parts.push(json!({
|
|
261
|
+
"functionCall": {
|
|
262
|
+
"name": tc.pointer("/function/name"),
|
|
263
|
+
"args": args
|
|
264
|
+
}
|
|
265
|
+
}));
|
|
266
|
+
}
|
|
267
|
+
}
|
|
268
|
+
if parts.is_empty() {
|
|
269
|
+
parts.push(json!({"text": ""}));
|
|
270
|
+
}
|
|
271
|
+
// Gemini uses "model" role for assistant turns.
|
|
272
|
+
contents.push(json!({"role": "model", "parts": parts}));
|
|
273
|
+
}
|
|
274
|
+
"tool" => {
|
|
275
|
+
// Map tool result back to a user turn with a functionResponse part.
|
|
276
|
+
// Gemini requires the function name — use the `name` field only.
|
|
277
|
+
// The `tool_call_id` is an OpenAI correlation ID, not a function name,
|
|
278
|
+
// so we must not fall back to it.
|
|
279
|
+
let name = msg.get("name").and_then(|n| n.as_str()).unwrap_or("tool");
|
|
280
|
+
let result_content = content.cloned().unwrap_or(json!(null));
|
|
281
|
+
contents.push(json!({
|
|
282
|
+
"role": "user",
|
|
283
|
+
"parts": [{
|
|
284
|
+
"functionResponse": {
|
|
285
|
+
"name": name,
|
|
286
|
+
"response": {"result": result_content}
|
|
287
|
+
}
|
|
288
|
+
}]
|
|
289
|
+
}));
|
|
290
|
+
}
|
|
291
|
+
_ => {}
|
|
292
|
+
}
|
|
293
|
+
}
|
|
294
|
+
|
|
295
|
+
// Build generationConfig from OpenAI parameters.
|
|
296
|
+
let mut gen_config = json!({});
|
|
297
|
+
// Support both max_tokens (legacy) and max_completion_tokens (newer OpenAI spec).
|
|
298
|
+
if let Some(max_tokens) = body.get("max_completion_tokens").or_else(|| body.get("max_tokens")) {
|
|
299
|
+
gen_config["maxOutputTokens"] = max_tokens.clone();
|
|
300
|
+
}
|
|
301
|
+
if let Some(temp) = body.get("temperature") {
|
|
302
|
+
gen_config["temperature"] = temp.clone();
|
|
303
|
+
}
|
|
304
|
+
if let Some(top_p) = body.get("top_p") {
|
|
305
|
+
gen_config["topP"] = top_p.clone();
|
|
306
|
+
}
|
|
307
|
+
if let Some(stop) = body.get("stop") {
|
|
308
|
+
let sequences = if let Some(s) = stop.as_str() {
|
|
309
|
+
vec![json!(s)]
|
|
310
|
+
} else {
|
|
311
|
+
stop.as_array().cloned().unwrap_or_default()
|
|
312
|
+
};
|
|
313
|
+
gen_config["stopSequences"] = json!(sequences);
|
|
314
|
+
}
|
|
315
|
+
|
|
316
|
+
// Translate response_format to Gemini's responseMimeType.
|
|
317
|
+
if let Some(rf) = body.get("response_format") {
|
|
318
|
+
let rf_type = rf.get("type").and_then(|t| t.as_str()).unwrap_or("");
|
|
319
|
+
match rf_type {
|
|
320
|
+
"json_object" => {
|
|
321
|
+
gen_config["responseMimeType"] = json!("application/json");
|
|
322
|
+
}
|
|
323
|
+
"json_schema" => {
|
|
324
|
+
gen_config["responseMimeType"] = json!("application/json");
|
|
325
|
+
// If a JSON schema is provided, pass it through.
|
|
326
|
+
if let Some(schema) = rf.get("json_schema").and_then(|s| s.get("schema")) {
|
|
327
|
+
gen_config["responseSchema"] = schema.clone();
|
|
328
|
+
}
|
|
329
|
+
}
|
|
330
|
+
// "text" or unknown types: no special handling needed.
|
|
331
|
+
_ => {}
|
|
332
|
+
}
|
|
333
|
+
}
|
|
334
|
+
|
|
335
|
+
// Translate OpenAI tools array to Gemini functionDeclarations.
|
|
336
|
+
let mut tools_value = body.get("tools").and_then(|t| t.as_array()).map(|arr| {
|
|
337
|
+
let declarations: Vec<serde_json::Value> = arr
|
|
338
|
+
.iter()
|
|
339
|
+
.map(|t| {
|
|
340
|
+
let name = t.pointer("/function/name").cloned().unwrap_or(json!("unknown"));
|
|
341
|
+
let description = t.pointer("/function/description").cloned().unwrap_or(json!(""));
|
|
342
|
+
let parameters = t
|
|
343
|
+
.pointer("/function/parameters")
|
|
344
|
+
.cloned()
|
|
345
|
+
.unwrap_or_else(|| json!({"type": "object"}));
|
|
346
|
+
json!({
|
|
347
|
+
"name": name,
|
|
348
|
+
"description": description,
|
|
349
|
+
"parameters": parameters
|
|
350
|
+
})
|
|
351
|
+
})
|
|
352
|
+
.collect();
|
|
353
|
+
json!([{"functionDeclarations": declarations}])
|
|
354
|
+
});
|
|
355
|
+
|
|
356
|
+
// Translate tool_choice to Gemini toolConfig.functionCallingConfig.mode.
|
|
357
|
+
let tool_config = translate_tool_choice(body.get("tool_choice"));
|
|
358
|
+
|
|
359
|
+
// ── extra_body extensions ────────────────────────────────────────────────
|
|
360
|
+
let mut safety_settings: Option<serde_json::Value> = None;
|
|
361
|
+
let mut cached_content: Option<serde_json::Value> = None;
|
|
362
|
+
|
|
363
|
+
if let Some(ref eb) = extra_body {
|
|
364
|
+
// Safety settings: inject as top-level safetySettings array.
|
|
365
|
+
if let Some(ss) = eb.get("safety_settings") {
|
|
366
|
+
safety_settings = Some(ss.clone());
|
|
367
|
+
}
|
|
368
|
+
|
|
369
|
+
// Grounding / Google Search: add google_search_retrieval to tools array.
|
|
370
|
+
if eb.contains_key("grounding_config") || eb.contains_key("google_search_retrieval") {
|
|
371
|
+
let grounding_tool = json!({"google_search_retrieval": {}});
|
|
372
|
+
match &mut tools_value {
|
|
373
|
+
Some(existing) => {
|
|
374
|
+
if let Some(arr) = existing.as_array_mut() {
|
|
375
|
+
arr.push(grounding_tool);
|
|
376
|
+
}
|
|
377
|
+
}
|
|
378
|
+
None => {
|
|
379
|
+
tools_value = Some(json!([grounding_tool]));
|
|
380
|
+
}
|
|
381
|
+
}
|
|
382
|
+
}
|
|
383
|
+
|
|
384
|
+
// Context caching: inject as top-level cachedContent field.
|
|
385
|
+
if let Some(cc) = eb.get("cached_content") {
|
|
386
|
+
cached_content = Some(cc.clone());
|
|
387
|
+
}
|
|
388
|
+
}
|
|
389
|
+
|
|
390
|
+
let mut new_body = json!({"contents": contents});
|
|
391
|
+
if !system_parts.is_empty() {
|
|
392
|
+
// Gemini API requires camelCase: systemInstruction.
|
|
393
|
+
new_body["systemInstruction"] = json!({"parts": system_parts});
|
|
394
|
+
}
|
|
395
|
+
if let Some(obj) = gen_config.as_object()
|
|
396
|
+
&& !obj.is_empty()
|
|
397
|
+
{
|
|
398
|
+
new_body["generationConfig"] = gen_config;
|
|
399
|
+
}
|
|
400
|
+
if let Some(tools) = tools_value {
|
|
401
|
+
new_body["tools"] = tools;
|
|
402
|
+
}
|
|
403
|
+
if let Some(tc) = tool_config {
|
|
404
|
+
new_body["toolConfig"] = tc;
|
|
405
|
+
}
|
|
406
|
+
if let Some(ss) = safety_settings {
|
|
407
|
+
new_body["safetySettings"] = ss;
|
|
408
|
+
}
|
|
409
|
+
if let Some(cc) = cached_content {
|
|
410
|
+
new_body["cachedContent"] = cc;
|
|
411
|
+
}
|
|
412
|
+
|
|
413
|
+
*body = new_body;
|
|
414
|
+
Ok(())
|
|
415
|
+
}
|
|
416
|
+
|
|
417
|
+
/// Normalize a Gemini `generateContent` response to OpenAI chat completion format.
|
|
418
|
+
///
|
|
419
|
+
/// Gemini wraps the response in `candidates[0].content.parts[]`.
|
|
420
|
+
/// Finish reasons use Gemini terminology (`STOP`, `MAX_TOKENS`, `SAFETY`, ...)
|
|
421
|
+
/// and are mapped to the OpenAI `finish_reason` set.
|
|
422
|
+
///
|
|
423
|
+
/// If `groundingMetadata` is present on the candidate, it is included in the
|
|
424
|
+
/// response as `_grounding_metadata` for supplementary use by callers.
|
|
425
|
+
///
|
|
426
|
+
/// **Known limitation:** The `model` field in the normalized response is
|
|
427
|
+
/// always `""`. Gemini/Vertex AI does not include the model name in its
|
|
428
|
+
/// response body -- the model is only present in the request URL path.
|
|
429
|
+
pub(crate) fn transform_gemini_response(body: &mut serde_json::Value) -> Result<()> {
|
|
430
|
+
use serde_json::json;
|
|
431
|
+
|
|
432
|
+
// Check for a blocked prompt (no candidates, but promptFeedback.blockReason set).
|
|
433
|
+
let candidates = body.get("candidates").and_then(|c| c.as_array());
|
|
434
|
+
if candidates.is_none_or(|c| c.is_empty()) {
|
|
435
|
+
let block_reason = body
|
|
436
|
+
.pointer("/promptFeedback/blockReason")
|
|
437
|
+
.and_then(|r| r.as_str())
|
|
438
|
+
.unwrap_or("UNKNOWN");
|
|
439
|
+
let prompt_tokens = body
|
|
440
|
+
.pointer("/usageMetadata/promptTokenCount")
|
|
441
|
+
.and_then(|v| v.as_u64())
|
|
442
|
+
.unwrap_or(0);
|
|
443
|
+
*body = json!({
|
|
444
|
+
"id": "gemini-resp",
|
|
445
|
+
"object": "chat.completion",
|
|
446
|
+
"created": super::unix_timestamp_secs(),
|
|
447
|
+
"model": "",
|
|
448
|
+
"choices": [{
|
|
449
|
+
"index": 0,
|
|
450
|
+
"message": {"role": "assistant", "content": null},
|
|
451
|
+
"finish_reason": "content_filter"
|
|
452
|
+
}],
|
|
453
|
+
"usage": {
|
|
454
|
+
"prompt_tokens": prompt_tokens,
|
|
455
|
+
"completion_tokens": 0,
|
|
456
|
+
"total_tokens": prompt_tokens
|
|
457
|
+
},
|
|
458
|
+
"system_fingerprint": null,
|
|
459
|
+
"_block_reason": block_reason
|
|
460
|
+
});
|
|
461
|
+
return Ok(());
|
|
462
|
+
}
|
|
463
|
+
|
|
464
|
+
let candidate = body.pointer("/candidates/0").cloned();
|
|
465
|
+
let finish_reason_raw = candidate
|
|
466
|
+
.as_ref()
|
|
467
|
+
.and_then(|c| c.get("finishReason"))
|
|
468
|
+
.and_then(|f| f.as_str())
|
|
469
|
+
.unwrap_or("STOP");
|
|
470
|
+
let parts = candidate
|
|
471
|
+
.as_ref()
|
|
472
|
+
.and_then(|c| c.pointer("/content/parts"))
|
|
473
|
+
.and_then(|p| p.as_array())
|
|
474
|
+
.cloned()
|
|
475
|
+
.unwrap_or_default();
|
|
476
|
+
|
|
477
|
+
// Collect text content from parts.
|
|
478
|
+
let text: String = parts
|
|
479
|
+
.iter()
|
|
480
|
+
.filter_map(|p| p.get("text").and_then(|t| t.as_str()))
|
|
481
|
+
.collect::<Vec<_>>()
|
|
482
|
+
.join("");
|
|
483
|
+
|
|
484
|
+
// Collect functionCall parts and convert to OpenAI tool_calls.
|
|
485
|
+
// Each call gets a unique ID via an atomic counter to avoid collisions
|
|
486
|
+
// when the same function is called multiple times.
|
|
487
|
+
let tool_calls: Vec<serde_json::Value> = parts
|
|
488
|
+
.iter()
|
|
489
|
+
.filter_map(|p| {
|
|
490
|
+
p.get("functionCall").map(|fc| {
|
|
491
|
+
let name = fc.get("name").and_then(|n| n.as_str()).unwrap_or("unknown");
|
|
492
|
+
let call_id = TOOL_CALL_COUNTER.fetch_add(1, Ordering::Relaxed);
|
|
493
|
+
let arguments = serde_json::to_string(fc.get("args").unwrap_or(&json!({}))).unwrap_or_default();
|
|
494
|
+
json!({
|
|
495
|
+
"id": format!("call_{name}_{call_id}"),
|
|
496
|
+
"type": "function",
|
|
497
|
+
"function": {
|
|
498
|
+
"name": fc.get("name"),
|
|
499
|
+
"arguments": arguments
|
|
500
|
+
}
|
|
501
|
+
})
|
|
502
|
+
})
|
|
503
|
+
})
|
|
504
|
+
.collect();
|
|
505
|
+
|
|
506
|
+
let finish_reason = match finish_reason_raw {
|
|
507
|
+
"STOP" => "stop",
|
|
508
|
+
"MAX_TOKENS" => "length",
|
|
509
|
+
"SAFETY" | "RECITATION" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "SPII" | "IMAGE_SAFETY" => "content_filter",
|
|
510
|
+
"LANGUAGE" | "OTHER" => "stop",
|
|
511
|
+
"TOOL_CODE" | "FUNCTION_CALL" => "tool_calls",
|
|
512
|
+
_ => "stop",
|
|
513
|
+
};
|
|
514
|
+
|
|
515
|
+
let prompt_tokens = body
|
|
516
|
+
.pointer("/usageMetadata/promptTokenCount")
|
|
517
|
+
.and_then(|v| v.as_u64())
|
|
518
|
+
.unwrap_or(0);
|
|
519
|
+
let completion_tokens = body
|
|
520
|
+
.pointer("/usageMetadata/candidatesTokenCount")
|
|
521
|
+
.and_then(|v| v.as_u64())
|
|
522
|
+
.unwrap_or(0);
|
|
523
|
+
|
|
524
|
+
let response_id = body.get("responseId").cloned().unwrap_or_else(|| json!("gemini-resp"));
|
|
525
|
+
|
|
526
|
+
let content_value: serde_json::Value = if text.is_empty() { json!(null) } else { json!(text) };
|
|
527
|
+
|
|
528
|
+
let mut message = json!({"role": "assistant", "content": content_value});
|
|
529
|
+
if !tool_calls.is_empty() {
|
|
530
|
+
message["tool_calls"] = json!(tool_calls);
|
|
531
|
+
}
|
|
532
|
+
|
|
533
|
+
// Extract grounding metadata if present (supplementary data from Google Search grounding).
|
|
534
|
+
let grounding_metadata = candidate.as_ref().and_then(|c| c.get("groundingMetadata")).cloned();
|
|
535
|
+
|
|
536
|
+
let mut result = json!({
|
|
537
|
+
"id": response_id,
|
|
538
|
+
"object": "chat.completion",
|
|
539
|
+
"created": super::unix_timestamp_secs(),
|
|
540
|
+
"model": "",
|
|
541
|
+
"choices": [{
|
|
542
|
+
"index": 0,
|
|
543
|
+
"message": message,
|
|
544
|
+
"finish_reason": finish_reason
|
|
545
|
+
}],
|
|
546
|
+
"usage": {
|
|
547
|
+
"prompt_tokens": prompt_tokens,
|
|
548
|
+
"completion_tokens": completion_tokens,
|
|
549
|
+
"total_tokens": prompt_tokens + completion_tokens
|
|
550
|
+
}
|
|
551
|
+
});
|
|
552
|
+
|
|
553
|
+
if let Some(gm) = grounding_metadata {
|
|
554
|
+
result["_grounding_metadata"] = gm;
|
|
555
|
+
}
|
|
556
|
+
|
|
557
|
+
*body = result;
|
|
558
|
+
|
|
559
|
+
Ok(())
|
|
560
|
+
}
|
|
561
|
+
|
|
562
|
+
/// Parse a single SSE event from Gemini's streaming endpoint.
|
|
563
|
+
///
|
|
564
|
+
/// Gemini streaming uses SSE with `?alt=sse`. Each event data is a complete
|
|
565
|
+
/// `generateContent` JSON response. We reuse `transform_gemini_response` to
|
|
566
|
+
/// normalize it into OpenAI format, then build a `ChatCompletionChunk` from
|
|
567
|
+
/// the first choice's message content.
|
|
568
|
+
///
|
|
569
|
+
/// **Note:** The `id` and `model` fields are empty strings on every chunk
|
|
570
|
+
/// because Gemini's streaming payloads do not include them, and this parser
|
|
571
|
+
/// is stateless.
|
|
572
|
+
pub(crate) fn parse_gemini_stream_event(event_data: &str) -> Result<Option<ChatCompletionChunk>> {
|
|
573
|
+
// NOTE: `[DONE]` is handled at the SSE parser level; no check needed here.
|
|
574
|
+
if event_data.trim().is_empty() {
|
|
575
|
+
return Ok(None);
|
|
576
|
+
}
|
|
577
|
+
|
|
578
|
+
let mut body: serde_json::Value = serde_json::from_str(event_data).map_err(|e| LiterLlmError::Streaming {
|
|
579
|
+
message: format!("failed to parse Gemini SSE data: {e}"),
|
|
580
|
+
})?;
|
|
581
|
+
|
|
582
|
+
// Normalize to OpenAI chat completion format.
|
|
583
|
+
transform_gemini_response(&mut body)?;
|
|
584
|
+
|
|
585
|
+
// Extract fields from the normalized response.
|
|
586
|
+
let id = body
|
|
587
|
+
.get("id")
|
|
588
|
+
.and_then(|v| v.as_str())
|
|
589
|
+
.unwrap_or("gemini-resp")
|
|
590
|
+
.to_owned();
|
|
591
|
+
let model = body.get("model").and_then(|v| v.as_str()).unwrap_or("").to_owned();
|
|
592
|
+
|
|
593
|
+
let choice = body.pointer("/choices/0");
|
|
594
|
+
let content = choice
|
|
595
|
+
.and_then(|c| c.pointer("/message/content"))
|
|
596
|
+
.and_then(|v| v.as_str())
|
|
597
|
+
.map(ToOwned::to_owned);
|
|
598
|
+
let finish_reason_str = choice
|
|
599
|
+
.and_then(|c| c.get("finish_reason"))
|
|
600
|
+
.and_then(|v| v.as_str())
|
|
601
|
+
.unwrap_or("");
|
|
602
|
+
|
|
603
|
+
// Extract tool_calls from the normalized message if present.
|
|
604
|
+
let stream_tool_calls = choice
|
|
605
|
+
.and_then(|c| c.pointer("/message/tool_calls"))
|
|
606
|
+
.and_then(|v| v.as_array())
|
|
607
|
+
.filter(|arr| !arr.is_empty())
|
|
608
|
+
.map(|arr| {
|
|
609
|
+
use crate::types::{StreamFunctionCall, StreamToolCall, ToolType};
|
|
610
|
+
arr.iter()
|
|
611
|
+
.enumerate()
|
|
612
|
+
.map(|(idx, tc)| StreamToolCall {
|
|
613
|
+
index: idx as u32,
|
|
614
|
+
id: tc.get("id").and_then(|v| v.as_str()).map(ToOwned::to_owned),
|
|
615
|
+
call_type: Some(ToolType::Function),
|
|
616
|
+
function: tc.get("function").map(|f| StreamFunctionCall {
|
|
617
|
+
name: f.get("name").and_then(|v| v.as_str()).map(ToOwned::to_owned),
|
|
618
|
+
arguments: f.get("arguments").and_then(|v| v.as_str()).map(ToOwned::to_owned),
|
|
619
|
+
}),
|
|
620
|
+
})
|
|
621
|
+
.collect::<Vec<_>>()
|
|
622
|
+
});
|
|
623
|
+
|
|
624
|
+
use crate::types::{FinishReason, StreamChoice, StreamDelta};
|
|
625
|
+
|
|
626
|
+
let finish_reason = match finish_reason_str {
|
|
627
|
+
"stop" => Some(FinishReason::Stop),
|
|
628
|
+
"length" => Some(FinishReason::Length),
|
|
629
|
+
"tool_calls" => Some(FinishReason::ToolCalls),
|
|
630
|
+
"content_filter" => Some(FinishReason::ContentFilter),
|
|
631
|
+
_ => None,
|
|
632
|
+
};
|
|
633
|
+
|
|
634
|
+
let chunk = ChatCompletionChunk {
|
|
635
|
+
id,
|
|
636
|
+
object: "chat.completion.chunk".to_owned(),
|
|
637
|
+
created: super::unix_timestamp_secs(),
|
|
638
|
+
model,
|
|
639
|
+
choices: vec![StreamChoice {
|
|
640
|
+
index: 0,
|
|
641
|
+
delta: StreamDelta {
|
|
642
|
+
role: Some("assistant".to_owned()),
|
|
643
|
+
content,
|
|
644
|
+
tool_calls: stream_tool_calls,
|
|
645
|
+
function_call: None,
|
|
646
|
+
refusal: None,
|
|
647
|
+
},
|
|
648
|
+
finish_reason,
|
|
649
|
+
}],
|
|
650
|
+
usage: None,
|
|
651
|
+
system_fingerprint: None,
|
|
652
|
+
service_tier: None,
|
|
653
|
+
};
|
|
654
|
+
|
|
655
|
+
Ok(Some(chunk))
|
|
656
|
+
}
|
|
657
|
+
|
|
658
|
+
// ── Helper functions ──────────────────────────────────────────────────────────
|
|
659
|
+
|
|
660
|
+
/// Convert OpenAI user content (string or content-part array) to Gemini parts.
|
|
661
|
+
///
|
|
662
|
+
/// Handles four cases:
|
|
663
|
+
/// 1. Plain string -> single text part.
|
|
664
|
+
/// 2. Array of content parts -> each part converted to Gemini format.
|
|
665
|
+
/// 3. `ContentPart::Document` -> Gemini `inlineData` with the document's MIME type.
|
|
666
|
+
/// 4. None/null -> single empty text part.
|
|
667
|
+
pub(crate) fn convert_user_content_to_gemini(content: Option<&serde_json::Value>) -> Vec<serde_json::Value> {
|
|
668
|
+
use serde_json::json;
|
|
669
|
+
|
|
670
|
+
match content {
|
|
671
|
+
Some(serde_json::Value::String(s)) => vec![json!({"text": s})],
|
|
672
|
+
Some(serde_json::Value::Array(parts)) => {
|
|
673
|
+
parts
|
|
674
|
+
.iter()
|
|
675
|
+
.filter_map(|part| {
|
|
676
|
+
let part_type = part.get("type").and_then(|t| t.as_str())?;
|
|
677
|
+
match part_type {
|
|
678
|
+
"text" => {
|
|
679
|
+
let text = part.get("text").and_then(|t| t.as_str()).unwrap_or("");
|
|
680
|
+
Some(json!({"text": text}))
|
|
681
|
+
}
|
|
682
|
+
"image_url" => {
|
|
683
|
+
let url = part.pointer("/image_url/url").and_then(|u| u.as_str())?;
|
|
684
|
+
if url.starts_with("data:") {
|
|
685
|
+
// data:<media_type>;base64,<data>
|
|
686
|
+
if let Some((header, data)) = url.split_once(',') {
|
|
687
|
+
let mime_type = header.trim_start_matches("data:").trim_end_matches(";base64");
|
|
688
|
+
return Some(json!({
|
|
689
|
+
"inlineData": {
|
|
690
|
+
"mimeType": mime_type,
|
|
691
|
+
"data": data
|
|
692
|
+
}
|
|
693
|
+
}));
|
|
694
|
+
}
|
|
695
|
+
}
|
|
696
|
+
// Plain URL -- use Gemini's fileData format.
|
|
697
|
+
Some(json!({
|
|
698
|
+
"fileData": {
|
|
699
|
+
"mimeType": "image/jpeg",
|
|
700
|
+
"fileUri": url
|
|
701
|
+
}
|
|
702
|
+
}))
|
|
703
|
+
}
|
|
704
|
+
"document" => {
|
|
705
|
+
// ContentPart::Document -> Gemini inlineData.
|
|
706
|
+
let doc = part.get("document")?;
|
|
707
|
+
let data = doc.get("data").and_then(|d| d.as_str())?;
|
|
708
|
+
let media_type = doc
|
|
709
|
+
.get("media_type")
|
|
710
|
+
.and_then(|m| m.as_str())
|
|
711
|
+
.unwrap_or("application/pdf");
|
|
712
|
+
Some(json!({
|
|
713
|
+
"inlineData": {
|
|
714
|
+
"mimeType": media_type,
|
|
715
|
+
"data": data
|
|
716
|
+
}
|
|
717
|
+
}))
|
|
718
|
+
}
|
|
719
|
+
_ => {
|
|
720
|
+
// Unknown content part types: fall back to text representation.
|
|
721
|
+
let text = part.get("text").and_then(|t| t.as_str()).unwrap_or("");
|
|
722
|
+
if text.is_empty() {
|
|
723
|
+
None
|
|
724
|
+
} else {
|
|
725
|
+
Some(json!({"text": text}))
|
|
726
|
+
}
|
|
727
|
+
}
|
|
728
|
+
}
|
|
729
|
+
})
|
|
730
|
+
.collect()
|
|
731
|
+
}
|
|
732
|
+
_ => vec![json!({"text": ""})],
|
|
733
|
+
}
|
|
734
|
+
}
|
|
735
|
+
|
|
736
|
+
/// Translate OpenAI `tool_choice` to Gemini `toolConfig.functionCallingConfig`.
|
|
737
|
+
///
|
|
738
|
+
/// OpenAI `tool_choice` values:
|
|
739
|
+
/// - `"none"` -> `NONE`
|
|
740
|
+
/// - `"auto"` -> `AUTO`
|
|
741
|
+
/// - `"required"` -> `ANY`
|
|
742
|
+
/// - `{"type": "function", "function": {"name": "..."}}` -> `ANY` with `allowedFunctionNames`
|
|
743
|
+
fn translate_tool_choice(tool_choice: Option<&serde_json::Value>) -> Option<serde_json::Value> {
|
|
744
|
+
use serde_json::json;
|
|
745
|
+
|
|
746
|
+
let tc = tool_choice?;
|
|
747
|
+
|
|
748
|
+
if let Some(s) = tc.as_str() {
|
|
749
|
+
let mode = match s {
|
|
750
|
+
"none" => "NONE",
|
|
751
|
+
"auto" => "AUTO",
|
|
752
|
+
"required" => "ANY",
|
|
753
|
+
_ => return None,
|
|
754
|
+
};
|
|
755
|
+
return Some(json!({
|
|
756
|
+
"functionCallingConfig": {
|
|
757
|
+
"mode": mode
|
|
758
|
+
}
|
|
759
|
+
}));
|
|
760
|
+
}
|
|
761
|
+
|
|
762
|
+
// Object form: {"type": "function", "function": {"name": "specific_fn"}}
|
|
763
|
+
if let Some(name) = tc.pointer("/function/name").and_then(|n| n.as_str()) {
|
|
764
|
+
return Some(json!({
|
|
765
|
+
"functionCallingConfig": {
|
|
766
|
+
"mode": "ANY",
|
|
767
|
+
"allowedFunctionNames": [name]
|
|
768
|
+
}
|
|
769
|
+
}));
|
|
770
|
+
}
|
|
771
|
+
|
|
772
|
+
None
|
|
773
|
+
}
|
|
774
|
+
|
|
775
|
+
// ── Unit tests ────────────────────────────────────────────────────────────────
|
|
776
|
+
|
|
777
|
+
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::provider::Provider;

    fn provider() -> VertexAiProvider {
        VertexAiProvider::new("test-project", "us-central1")
    }

    fn provider_without_project() -> VertexAiProvider {
        VertexAiProvider {
            base_url: String::new(),
        }
    }

    // ── validate ──────────────────────────────────────────────────────────────

    #[test]
    fn validate_succeeds_with_project() {
        let p = provider();
        assert!(p.validate().is_ok());
    }

    #[test]
    fn validate_fails_without_project() {
        let p = provider_without_project();
        let err = p.validate().unwrap_err();
        assert!(
            err.to_string().contains("VERTEXAI_PROJECT"),
            "error should mention VERTEXAI_PROJECT, got: {err}"
        );
    }

    // ── base_url ──────────────────────────────────────────────────────────────

    #[test]
    fn base_url_constructed_from_project_and_location() {
        let p = provider();
        assert_eq!(
            p.base_url(),
            "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1"
        );
    }

    #[test]
    fn base_url_custom_location() {
        let p = VertexAiProvider::new("my-proj", "europe-west1");
        assert_eq!(
            p.base_url(),
            "https://europe-west1-aiplatform.googleapis.com/v1/projects/my-proj/locations/europe-west1"
        );
    }

    // ── build_url ─────────────────────────────────────────────────────────────

    #[test]
    fn build_url_returns_empty_without_base() {
        let p = provider_without_project();
        let url = p.build_url("/chat/completions", "gemini-2.0-flash");
        assert!(url.is_empty(), "should return empty string without a base URL");
    }

    #[test]
    fn build_url_chat_completions() {
        let p = provider();
        let url = p.build_url("/chat/completions", "gemini-2.0-flash");
        assert!(url.ends_with("/publishers/google/models/gemini-2.0-flash:generateContent"));
    }

    #[test]
    fn build_url_embeddings() {
        let p = provider();
        let url = p.build_url("/embeddings", "text-embedding-004");
        assert!(url.ends_with("/publishers/google/models/text-embedding-004:predict"));
    }

    // ── transform_request ─────────────────────────────────────────────────────

    #[test]
    fn transform_request_basic_chat() {
        let p = provider();
        let mut body = json!({
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello!"}
            ],
            "max_tokens": 200,
            "temperature": 0.5
        });

        p.transform_request(&mut body).unwrap();

        // System instruction extracted with camelCase key required by Gemini API.
        assert_eq!(
            body["systemInstruction"]["parts"][0]["text"],
            "You are a helpful assistant."
        );

        // User message converted to Gemini format.
        assert_eq!(body["contents"][0]["role"], "user");
        assert_eq!(body["contents"][0]["parts"][0]["text"], "Hello!");

        // Generation config set.
        assert_eq!(body["generationConfig"]["maxOutputTokens"], 200);
        assert_eq!(body["generationConfig"]["temperature"], 0.5);
    }

    #[test]
    fn transform_request_assistant_becomes_model_role() {
        let p = provider();
        let mut body = json!({
            "messages": [
                {"role": "user", "content": "Hi"},
                {"role": "assistant", "content": "Hello there!"}
            ]
        });

        p.transform_request(&mut body).unwrap();

        assert_eq!(body["contents"][1]["role"], "model");
        assert_eq!(body["contents"][1]["parts"][0]["text"], "Hello there!");
    }

    #[test]
    fn transform_request_with_tool_calls() {
        let p = provider();
        let mut body = json!({
            "messages": [
                {"role": "user", "content": "What is the weather in Berlin?"},
                {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": "call_1",
                        "type": "function",
                        "function": {"name": "get_weather", "arguments": "{\"city\":\"Berlin\"}"}
                    }]
                },
                {
                    "role": "tool",
                    "name": "get_weather",
                    "tool_call_id": "call_1",
                    "content": "Sunny, 22°C"
                }
            ]
        });

        p.transform_request(&mut body).unwrap();

        let contents = body["contents"].as_array().unwrap();
        assert_eq!(contents.len(), 3);

        // Assistant turn with functionCall part.
        let model_turn = &contents[1];
        assert_eq!(model_turn["role"], "model");
        let fn_call = &model_turn["parts"][0]["functionCall"];
        assert_eq!(fn_call["name"], "get_weather");
        assert_eq!(fn_call["args"]["city"], "Berlin");

        // Tool result as user turn with functionResponse.
        let tool_turn = &contents[2];
        assert_eq!(tool_turn["role"], "user");
        let fn_resp = &tool_turn["parts"][0]["functionResponse"];
        assert_eq!(fn_resp["name"], "get_weather");
    }

    #[test]
    fn transform_request_stop_sequences() {
        let p = provider();
        let mut body = json!({
            "messages": [{"role": "user", "content": "hi"}],
            "stop": ["END", "STOP"]
        });

        p.transform_request(&mut body).unwrap();

        let stop_seqs = body["generationConfig"]["stopSequences"].as_array().unwrap();
        assert_eq!(stop_seqs.len(), 2);
        assert_eq!(stop_seqs[0], "END");
        assert_eq!(stop_seqs[1], "STOP");
    }

    // ── transform_request: safety settings ───────────────────────────────────

    #[test]
    fn transform_request_safety_settings_from_extra_body() {
        let p = provider();
        let mut body = json!({
            "messages": [{"role": "user", "content": "hi"}],
            "extra_body": {
                "safety_settings": [
                    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
                    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"}
                ]
            }
        });

        p.transform_request(&mut body).unwrap();

        let settings = body["safetySettings"].as_array().unwrap();
        assert_eq!(settings.len(), 2);
        assert_eq!(settings[0]["category"], "HARM_CATEGORY_HATE_SPEECH");
        assert_eq!(settings[0]["threshold"], "BLOCK_MEDIUM_AND_ABOVE");
        assert_eq!(settings[1]["category"], "HARM_CATEGORY_DANGEROUS_CONTENT");
    }

    // ── transform_request: grounding / Google Search ─────────────────────────

    #[test]
    fn transform_request_grounding_config_adds_google_search() {
        let p = provider();
        let mut body = json!({
            "messages": [{"role": "user", "content": "What happened today?"}],
            "extra_body": {
                "grounding_config": {}
            }
        });

        p.transform_request(&mut body).unwrap();

        let tools = body["tools"].as_array().unwrap();
        assert!(
            tools.iter().any(|t| t.get("google_search_retrieval").is_some()),
            "tools should contain google_search_retrieval"
        );
    }

    #[test]
    fn transform_request_google_search_retrieval_with_existing_tools() {
        let p = provider();
        let mut body = json!({
            "messages": [{"role": "user", "content": "hi"}],
            "tools": [{"type": "function", "function": {"name": "f", "parameters": {}}}],
            "extra_body": {
                "google_search_retrieval": {}
            }
        });

        p.transform_request(&mut body).unwrap();

        let tools = body["tools"].as_array().unwrap();
        // Should have functionDeclarations + google_search_retrieval.
        assert_eq!(tools.len(), 2);
        assert!(tools[0].get("functionDeclarations").is_some());
        assert!(tools[1].get("google_search_retrieval").is_some());
    }

    // ── transform_request: context caching ───────────────────────────────────

    #[test]
    fn transform_request_cached_content_from_extra_body() {
        let p = provider();
        let cached = "projects/xxx/locations/xxx/cachedContents/abc123";
        let mut body = json!({
            "messages": [{"role": "user", "content": "hi"}],
            "extra_body": {
                "cached_content": cached
            }
        });

        p.transform_request(&mut body).unwrap();

        assert_eq!(body["cachedContent"], cached);
    }

    // ── transform_request: document handling ─────────────────────────────────

    #[test]
    fn transform_request_document_content_part() {
        let p = provider();
        let mut body = json!({
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "text", "text": "Summarize this document."},
                    {
                        "type": "document",
                        "document": {
                            "data": "JVBERi0xLjQ=",
                            "media_type": "application/pdf"
                        }
                    }
                ]
            }]
        });

        p.transform_request(&mut body).unwrap();

        let parts = body["contents"][0]["parts"].as_array().unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0]["text"], "Summarize this document.");
        assert_eq!(parts[1]["inlineData"]["mimeType"], "application/pdf");
        assert_eq!(parts[1]["inlineData"]["data"], "JVBERi0xLjQ=");
    }

    // ── transform_response ────────────────────────────────────────────────────

    #[test]
    fn transform_response_basic() {
        let p = provider();
        let mut body = json!({
            "responseId": "resp-gemini-123",
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [{"text": "Hello from Gemini!"}]
                },
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 8,
                "candidatesTokenCount": 6
            }
        });

        p.transform_response(&mut body).unwrap();

        assert_eq!(body["object"], "chat.completion");
        assert_eq!(body["id"], "resp-gemini-123");
        assert_eq!(body["choices"][0]["message"]["content"], "Hello from Gemini!");
        assert_eq!(body["choices"][0]["finish_reason"], "stop");
        assert_eq!(body["usage"]["prompt_tokens"], 8);
        assert_eq!(body["usage"]["completion_tokens"], 6);
        assert_eq!(body["usage"]["total_tokens"], 14);
    }

    #[test]
    fn transform_response_tool_calls_have_unique_ids() {
        let p = provider();
        let mut body = json!({
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "Berlin"}
                            }
                        },
                        {
                            "functionCall": {
                                "name": "get_weather",
                                "args": {"city": "Paris"}
                            }
                        }
                    ]
                },
                "finishReason": "STOP"
            }],
            "usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 5}
        });

        p.transform_response(&mut body).unwrap();

        let tool_calls = body["choices"][0]["message"]["tool_calls"].as_array().unwrap();
        assert_eq!(tool_calls.len(), 2);

        // Both calls should have the function name "get_weather" but different IDs.
        let id0 = tool_calls[0]["id"].as_str().unwrap();
        let id1 = tool_calls[1]["id"].as_str().unwrap();
        assert_ne!(id0, id1, "tool call IDs must be unique even for the same function");
        assert!(id0.starts_with("call_get_weather_"));
        assert!(id1.starts_with("call_get_weather_"));

        // Verify arguments are correct.
        let args0: serde_json::Value =
            serde_json::from_str(tool_calls[0]["function"]["arguments"].as_str().unwrap()).unwrap();
        let args1: serde_json::Value =
            serde_json::from_str(tool_calls[1]["function"]["arguments"].as_str().unwrap()).unwrap();
        assert_eq!(args0["city"], "Berlin");
        assert_eq!(args1["city"], "Paris");
    }

    #[test]
    fn transform_response_single_tool_call() {
        let p = provider();
        let mut body = json!({
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [{
                        "functionCall": {
                            "name": "get_weather",
                            "args": {"city": "Berlin"}
                        }
                    }]
                },
                "finishReason": "STOP"
            }],
            "usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 5}
        });

        p.transform_response(&mut body).unwrap();

        let tool_calls = body["choices"][0]["message"]["tool_calls"].as_array().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0]["function"]["name"], "get_weather");
        // ID should contain the function name and a unique counter.
        let id = tool_calls[0]["id"].as_str().unwrap();
        assert!(
            id.starts_with("call_get_weather_"),
            "id should start with call_get_weather_, got: {id}"
        );
    }

    #[test]
    fn transform_response_finish_reason_mapping() {
        let p = provider();

        for (gemini_reason, expected_oai_reason) in [
            ("STOP", "stop"),
            ("MAX_TOKENS", "length"),
            ("SAFETY", "content_filter"),
            ("RECITATION", "content_filter"),
            ("BLOCKLIST", "content_filter"),
            ("PROHIBITED_CONTENT", "content_filter"),
            ("UNKNOWN_FUTURE_REASON", "stop"),
        ] {
            let mut body = json!({
                "candidates": [{
                    "content": {"role": "model", "parts": [{"text": ""}]},
                    "finishReason": gemini_reason
                }],
                "usageMetadata": {"promptTokenCount": 0, "candidatesTokenCount": 0}
            });
            p.transform_response(&mut body).unwrap();
            assert_eq!(
                body["choices"][0]["finish_reason"], expected_oai_reason,
                "Gemini finishReason '{gemini_reason}' should map to '{expected_oai_reason}'"
            );
        }
    }

    #[test]
    fn transform_response_grounding_metadata_preserved() {
        let p = provider();
        let mut body = json!({
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [{"text": "grounded answer"}]
                },
                "finishReason": "STOP",
                "groundingMetadata": {
                    "searchEntryPoint": {"renderedContent": "<html>...</html>"},
                    "groundingChunks": [{"web": {"uri": "https://example.com", "title": "Example"}}]
                }
            }],
            "usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 3}
        });

        p.transform_response(&mut body).unwrap();

        assert_eq!(body["choices"][0]["message"]["content"], "grounded answer");
        assert!(
            body.get("_grounding_metadata").is_some(),
            "grounding metadata should be preserved"
        );
        // assert_eq! (rather than assert!(.. == 1)) reports the actual length on failure.
        assert_eq!(
            body["_grounding_metadata"]["groundingChunks"]
                .as_array()
                .unwrap()
                .len(),
            1
        );
    }

    // ── parse_stream_event ────────────────────────────────────────────────────

    #[test]
    fn parse_stream_event_empty_returns_none() {
        let p = provider();
        let result = p.parse_stream_event("").unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn parse_stream_event_done_is_handled_at_sse_level() {
        // `[DONE]` is now caught by the SSE parser before reaching the provider.
        // If it were to reach the provider, it would be invalid JSON.
        let p = provider();
        let result = p.parse_stream_event("[DONE]");
        assert!(
            result.is_err(),
            "[DONE] is not valid JSON and should error if it reaches the provider"
        );
    }

    #[test]
    fn parse_stream_event_basic_chunk() {
        let p = provider();
        let event_data = r#"{
            "candidates": [{
                "content": {"role": "model", "parts": [{"text": "Hello"}]},
                "finishReason": "STOP"
            }],
            "usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 2}
        }"#;

        let chunk = p.parse_stream_event(event_data).unwrap().unwrap();

        assert_eq!(chunk.object, "chat.completion.chunk");
        assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hello"));
    }

    // ── model prefix / matching ───────────────────────────────────────────────

    #[test]
    fn strip_model_prefix() {
        let p = provider();
        assert_eq!(p.strip_model_prefix("vertex_ai/gemini-2.0-flash"), "gemini-2.0-flash");
        assert_eq!(p.strip_model_prefix("gemini-2.0-flash"), "gemini-2.0-flash");
    }

    #[test]
    fn matches_model() {
        let p = provider();
        assert!(p.matches_model("vertex_ai/gemini-2.0-flash"));
        assert!(!p.matches_model("gemini-2.0-flash"));
        assert!(!p.matches_model("gpt-4"));
    }

    // ── multimodal content ────────────────────────────────────────────────────

    #[test]
    fn transform_request_multimodal_user_content() {
        let p = provider();
        let mut body = json!({
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is in this image?"},
                    {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,/9j/abc=="}}
                ]
            }]
        });

        p.transform_request(&mut body).unwrap();

        let parts = body["contents"][0]["parts"].as_array().unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0]["text"], "What is in this image?");
        assert_eq!(parts[1]["inlineData"]["mimeType"], "image/jpeg");
        assert_eq!(parts[1]["inlineData"]["data"], "/9j/abc==");
    }

    #[test]
    fn transform_request_multimodal_url_image() {
        let p = provider();
        let mut body = json!({
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this."},
                    {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}
                ]
            }]
        });

        p.transform_request(&mut body).unwrap();

        let parts = body["contents"][0]["parts"].as_array().unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[1]["fileData"]["fileUri"], "https://example.com/image.jpg");
    }

    // ── response_format translation ───────────────────────────────────────────

    #[test]
    fn transform_request_response_format_json_object() {
        let p = provider();
        let mut body = json!({
            "messages": [{"role": "user", "content": "hi"}],
            "response_format": {"type": "json_object"}
        });

        p.transform_request(&mut body).unwrap();

        assert_eq!(body["generationConfig"]["responseMimeType"], "application/json");
    }

    #[test]
    fn transform_request_response_format_json_schema() {
        let p = provider();
        let mut body = json!({
            "messages": [{"role": "user", "content": "hi"}],
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "test",
                    "schema": {"type": "object", "properties": {"name": {"type": "string"}}}
                }
            }
        });

        p.transform_request(&mut body).unwrap();

        assert_eq!(body["generationConfig"]["responseMimeType"], "application/json");
        assert_eq!(body["generationConfig"]["responseSchema"]["type"], "object");
    }

    // ── tool_choice translation ───────────────────────────────────────────────

    #[test]
    fn transform_request_tool_choice_auto() {
        let p = provider();
        let mut body = json!({
            "messages": [{"role": "user", "content": "hi"}],
            "tools": [{"type": "function", "function": {"name": "f", "parameters": {}}}],
            "tool_choice": "auto"
        });

        p.transform_request(&mut body).unwrap();

        assert_eq!(body["toolConfig"]["functionCallingConfig"]["mode"], "AUTO");
    }

    #[test]
    fn transform_request_tool_choice_none() {
        let p = provider();
        let mut body = json!({
            "messages": [{"role": "user", "content": "hi"}],
            "tools": [{"type": "function", "function": {"name": "f", "parameters": {}}}],
            "tool_choice": "none"
        });

        p.transform_request(&mut body).unwrap();

        assert_eq!(body["toolConfig"]["functionCallingConfig"]["mode"], "NONE");
    }

    #[test]
    fn transform_request_tool_choice_required() {
        let p = provider();
        let mut body = json!({
            "messages": [{"role": "user", "content": "hi"}],
            "tools": [{"type": "function", "function": {"name": "f", "parameters": {}}}],
            "tool_choice": "required"
        });

        p.transform_request(&mut body).unwrap();

        assert_eq!(body["toolConfig"]["functionCallingConfig"]["mode"], "ANY");
    }

    #[test]
    fn transform_request_tool_choice_specific_function() {
        let p = provider();
        let mut body = json!({
            "messages": [{"role": "user", "content": "hi"}],
            "tools": [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
            "tool_choice": {"type": "function", "function": {"name": "get_weather"}}
        });

        p.transform_request(&mut body).unwrap();

        assert_eq!(body["toolConfig"]["functionCallingConfig"]["mode"], "ANY");
        assert_eq!(
            body["toolConfig"]["functionCallingConfig"]["allowedFunctionNames"][0],
            "get_weather"
        );
    }

    // ── helper function tests ─────────────────────────────────────────────────

    #[test]
    fn convert_user_content_string() {
        let content = json!("Hello!");
        let parts = convert_user_content_to_gemini(Some(&content));
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0]["text"], "Hello!");
    }

    #[test]
    fn convert_user_content_array_with_image() {
        let content = json!([
            {"type": "text", "text": "What is this?"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR"}}
        ]);
        let parts = convert_user_content_to_gemini(Some(&content));
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0]["text"], "What is this?");
        assert_eq!(parts[1]["inlineData"]["mimeType"], "image/png");
        assert_eq!(parts[1]["inlineData"]["data"], "iVBOR");
    }

    #[test]
    fn convert_user_content_none() {
        let parts = convert_user_content_to_gemini(None);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0]["text"], "");
    }

    #[test]
    fn convert_user_content_document_part() {
        let content = json!([
            {"type": "text", "text": "Read this PDF."},
            {"type": "document", "document": {"data": "base64data==", "media_type": "application/pdf"}}
        ]);
        let parts = convert_user_content_to_gemini(Some(&content));
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0]["text"], "Read this PDF.");
        assert_eq!(parts[1]["inlineData"]["mimeType"], "application/pdf");
        assert_eq!(parts[1]["inlineData"]["data"], "base64data==");
    }

    #[test]
    fn translate_tool_choice_string_values() {
        let auto = translate_tool_choice(Some(&json!("auto"))).unwrap();
        assert_eq!(auto["functionCallingConfig"]["mode"], "AUTO");

        let none = translate_tool_choice(Some(&json!("none"))).unwrap();
        assert_eq!(none["functionCallingConfig"]["mode"], "NONE");

        let required = translate_tool_choice(Some(&json!("required"))).unwrap();
        assert_eq!(required["functionCallingConfig"]["mode"], "ANY");
    }

    #[test]
    fn translate_tool_choice_specific_function() {
        let tc = json!({"type": "function", "function": {"name": "my_fn"}});
        let result = translate_tool_choice(Some(&tc)).unwrap();
        assert_eq!(result["functionCallingConfig"]["mode"], "ANY");
        assert_eq!(result["functionCallingConfig"]["allowedFunctionNames"][0], "my_fn");
    }

    #[test]
    fn translate_tool_choice_none_input() {
        assert!(translate_tool_choice(None).is_none());
    }
}
|
|
821
|
+
}
|
|
822
|
+
|
|
823
|
+
#[test]
|
|
824
|
+
fn base_url_custom_location() {
|
|
825
|
+
let p = VertexAiProvider::new("my-proj", "europe-west1");
|
|
826
|
+
assert_eq!(
|
|
827
|
+
p.base_url(),
|
|
828
|
+
"https://europe-west1-aiplatform.googleapis.com/v1/projects/my-proj/locations/europe-west1"
|
|
829
|
+
);
|
|
830
|
+
}
|
|
831
|
+
|
|
832
|
+
// ── build_url ─────────────────────────────────────────────────────────────
|
|
833
|
+
|
|
834
|
+
#[test]
|
|
835
|
+
fn build_url_returns_empty_without_base() {
|
|
836
|
+
let p = provider_without_project();
|
|
837
|
+
let url = p.build_url("/chat/completions", "gemini-2.0-flash");
|
|
838
|
+
assert!(url.is_empty(), "should return empty string without a base URL");
|
|
839
|
+
}
|
|
840
|
+
|
|
841
|
+
#[test]
|
|
842
|
+
fn build_url_chat_completions() {
|
|
843
|
+
let p = provider();
|
|
844
|
+
let url = p.build_url("/chat/completions", "gemini-2.0-flash");
|
|
845
|
+
assert!(url.ends_with("/publishers/google/models/gemini-2.0-flash:generateContent"));
|
|
846
|
+
}
|
|
847
|
+
|
|
848
|
+
#[test]
|
|
849
|
+
fn build_url_embeddings() {
|
|
850
|
+
let p = provider();
|
|
851
|
+
let url = p.build_url("/embeddings", "text-embedding-004");
|
|
852
|
+
assert!(url.ends_with("/publishers/google/models/text-embedding-004:predict"));
|
|
853
|
+
}
|
|
854
|
+
|
|
855
|
+
// ── transform_request ─────────────────────────────────────────────────────
|
|
856
|
+
|
|
857
|
+
#[test]
|
|
858
|
+
fn transform_request_basic_chat() {
|
|
859
|
+
let p = provider();
|
|
860
|
+
let mut body = json!({
|
|
861
|
+
"messages": [
|
|
862
|
+
{"role": "system", "content": "You are a helpful assistant."},
|
|
863
|
+
{"role": "user", "content": "Hello!"}
|
|
864
|
+
],
|
|
865
|
+
"max_tokens": 200,
|
|
866
|
+
"temperature": 0.5
|
|
867
|
+
});
|
|
868
|
+
|
|
869
|
+
p.transform_request(&mut body).unwrap();
|
|
870
|
+
|
|
871
|
+
// System instruction extracted with camelCase key required by Gemini API.
|
|
872
|
+
assert_eq!(
|
|
873
|
+
body["systemInstruction"]["parts"][0]["text"],
|
|
874
|
+
"You are a helpful assistant."
|
|
875
|
+
);
|
|
876
|
+
|
|
877
|
+
// User message converted to Gemini format.
|
|
878
|
+
assert_eq!(body["contents"][0]["role"], "user");
|
|
879
|
+
assert_eq!(body["contents"][0]["parts"][0]["text"], "Hello!");
|
|
880
|
+
|
|
881
|
+
// Generation config set.
|
|
882
|
+
assert_eq!(body["generationConfig"]["maxOutputTokens"], 200);
|
|
883
|
+
assert_eq!(body["generationConfig"]["temperature"], 0.5);
|
|
884
|
+
}
|
|
885
|
+
|
|
886
|
+
#[test]
|
|
887
|
+
fn transform_request_assistant_becomes_model_role() {
|
|
888
|
+
let p = provider();
|
|
889
|
+
let mut body = json!({
|
|
890
|
+
"messages": [
|
|
891
|
+
{"role": "user", "content": "Hi"},
|
|
892
|
+
{"role": "assistant", "content": "Hello there!"}
|
|
893
|
+
]
|
|
894
|
+
});
|
|
895
|
+
|
|
896
|
+
p.transform_request(&mut body).unwrap();
|
|
897
|
+
|
|
898
|
+
assert_eq!(body["contents"][1]["role"], "model");
|
|
899
|
+
assert_eq!(body["contents"][1]["parts"][0]["text"], "Hello there!");
|
|
900
|
+
}
|
|
901
|
+
|
|
902
|
+
#[test]
|
|
903
|
+
fn transform_request_with_tool_calls() {
|
|
904
|
+
let p = provider();
|
|
905
|
+
let mut body = json!({
|
|
906
|
+
"messages": [
|
|
907
|
+
{"role": "user", "content": "What is the weather in Berlin?"},
|
|
908
|
+
{
|
|
909
|
+
"role": "assistant",
|
|
910
|
+
"content": null,
|
|
911
|
+
"tool_calls": [{
|
|
912
|
+
"id": "call_1",
|
|
913
|
+
"type": "function",
|
|
914
|
+
"function": {"name": "get_weather", "arguments": "{\"city\":\"Berlin\"}"}
|
|
915
|
+
}]
|
|
916
|
+
},
|
|
917
|
+
{
|
|
918
|
+
"role": "tool",
|
|
919
|
+
"name": "get_weather",
|
|
920
|
+
"tool_call_id": "call_1",
|
|
921
|
+
"content": "Sunny, 22°C"
|
|
922
|
+
}
|
|
923
|
+
]
|
|
924
|
+
});
|
|
925
|
+
|
|
926
|
+
p.transform_request(&mut body).unwrap();
|
|
927
|
+
|
|
928
|
+
let contents = body["contents"].as_array().unwrap();
|
|
929
|
+
assert_eq!(contents.len(), 3);
|
|
930
|
+
|
|
931
|
+
// Assistant turn with functionCall part.
|
|
932
|
+
let model_turn = &contents[1];
|
|
933
|
+
assert_eq!(model_turn["role"], "model");
|
|
934
|
+
let fn_call = &model_turn["parts"][0]["functionCall"];
|
|
935
|
+
assert_eq!(fn_call["name"], "get_weather");
|
|
936
|
+
assert_eq!(fn_call["args"]["city"], "Berlin");
|
|
937
|
+
|
|
938
|
+
// Tool result as user turn with functionResponse.
|
|
939
|
+
let tool_turn = &contents[2];
|
|
940
|
+
assert_eq!(tool_turn["role"], "user");
|
|
941
|
+
let fn_resp = &tool_turn["parts"][0]["functionResponse"];
|
|
942
|
+
assert_eq!(fn_resp["name"], "get_weather");
|
|
943
|
+
}
|
|
944
|
+
|
|
945
|
+
#[test]
|
|
946
|
+
fn transform_request_stop_sequences() {
|
|
947
|
+
let p = provider();
|
|
948
|
+
let mut body = json!({
|
|
949
|
+
"messages": [{"role": "user", "content": "hi"}],
|
|
950
|
+
"stop": ["END", "STOP"]
|
|
951
|
+
});
|
|
952
|
+
|
|
953
|
+
p.transform_request(&mut body).unwrap();
|
|
954
|
+
|
|
955
|
+
let stop_seqs = body["generationConfig"]["stopSequences"].as_array().unwrap();
|
|
956
|
+
assert_eq!(stop_seqs.len(), 2);
|
|
957
|
+
assert_eq!(stop_seqs[0], "END");
|
|
958
|
+
assert_eq!(stop_seqs[1], "STOP");
|
|
959
|
+
}
|
|
960
|
+
|
|
961
|
+
// ── transform_request: safety settings ───────────────────────────────────
|
|
962
|
+
|
|
963
|
+
#[test]
|
|
964
|
+
fn transform_request_safety_settings_from_extra_body() {
|
|
965
|
+
let p = provider();
|
|
966
|
+
let mut body = json!({
|
|
967
|
+
"messages": [{"role": "user", "content": "hi"}],
|
|
968
|
+
"extra_body": {
|
|
969
|
+
"safety_settings": [
|
|
970
|
+
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
|
|
971
|
+
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"}
|
|
972
|
+
]
|
|
973
|
+
}
|
|
974
|
+
});
|
|
975
|
+
|
|
976
|
+
p.transform_request(&mut body).unwrap();
|
|
977
|
+
|
|
978
|
+
let settings = body["safetySettings"].as_array().unwrap();
|
|
979
|
+
assert_eq!(settings.len(), 2);
|
|
980
|
+
assert_eq!(settings[0]["category"], "HARM_CATEGORY_HATE_SPEECH");
|
|
981
|
+
assert_eq!(settings[0]["threshold"], "BLOCK_MEDIUM_AND_ABOVE");
|
|
982
|
+
assert_eq!(settings[1]["category"], "HARM_CATEGORY_DANGEROUS_CONTENT");
|
|
983
|
+
}
|
|
984
|
+
|
|
985
|
+
// ── transform_request: grounding / Google Search ─────────────────────────
|
|
986
|
+
|
|
987
|
+
#[test]
|
|
988
|
+
fn transform_request_grounding_config_adds_google_search() {
|
|
989
|
+
let p = provider();
|
|
990
|
+
let mut body = json!({
|
|
991
|
+
"messages": [{"role": "user", "content": "What happened today?"}],
|
|
992
|
+
"extra_body": {
|
|
993
|
+
"grounding_config": {}
|
|
994
|
+
}
|
|
995
|
+
});
|
|
996
|
+
|
|
997
|
+
p.transform_request(&mut body).unwrap();
|
|
998
|
+
|
|
999
|
+
let tools = body["tools"].as_array().unwrap();
|
|
1000
|
+
assert!(
|
|
1001
|
+
tools.iter().any(|t| t.get("google_search_retrieval").is_some()),
|
|
1002
|
+
"tools should contain google_search_retrieval"
|
|
1003
|
+
);
|
|
1004
|
+
}
|
|
1005
|
+
|
|
1006
|
+
#[test]
|
|
1007
|
+
fn transform_request_google_search_retrieval_with_existing_tools() {
|
|
1008
|
+
let p = provider();
|
|
1009
|
+
let mut body = json!({
|
|
1010
|
+
"messages": [{"role": "user", "content": "hi"}],
|
|
1011
|
+
"tools": [{"type": "function", "function": {"name": "f", "parameters": {}}}],
|
|
1012
|
+
"extra_body": {
|
|
1013
|
+
"google_search_retrieval": {}
|
|
1014
|
+
}
|
|
1015
|
+
});
|
|
1016
|
+
|
|
1017
|
+
p.transform_request(&mut body).unwrap();
|
|
1018
|
+
|
|
1019
|
+
let tools = body["tools"].as_array().unwrap();
|
|
1020
|
+
// Should have functionDeclarations + google_search_retrieval.
|
|
1021
|
+
assert_eq!(tools.len(), 2);
|
|
1022
|
+
assert!(tools[0].get("functionDeclarations").is_some());
|
|
1023
|
+
assert!(tools[1].get("google_search_retrieval").is_some());
|
|
1024
|
+
}
|
|
1025
|
+
|
|
1026
|
+
// ── transform_request: context caching ───────────────────────────────────
|
|
1027
|
+
|
|
1028
|
+
#[test]
|
|
1029
|
+
fn transform_request_cached_content_from_extra_body() {
|
|
1030
|
+
let p = provider();
|
|
1031
|
+
let cached = "projects/xxx/locations/xxx/cachedContents/abc123";
|
|
1032
|
+
let mut body = json!({
|
|
1033
|
+
"messages": [{"role": "user", "content": "hi"}],
|
|
1034
|
+
"extra_body": {
|
|
1035
|
+
"cached_content": cached
|
|
1036
|
+
}
|
|
1037
|
+
});
|
|
1038
|
+
|
|
1039
|
+
p.transform_request(&mut body).unwrap();
|
|
1040
|
+
|
|
1041
|
+
assert_eq!(body["cachedContent"], cached);
|
|
1042
|
+
}
|
|
1043
|
+
|
|
1044
|
+
// ── transform_request: document handling ─────────────────────────────────
|
|
1045
|
+
|
|
1046
|
+
#[test]
|
|
1047
|
+
fn transform_request_document_content_part() {
|
|
1048
|
+
let p = provider();
|
|
1049
|
+
let mut body = json!({
|
|
1050
|
+
"messages": [{
|
|
1051
|
+
"role": "user",
|
|
1052
|
+
"content": [
|
|
1053
|
+
{"type": "text", "text": "Summarize this document."},
|
|
1054
|
+
{
|
|
1055
|
+
"type": "document",
|
|
1056
|
+
"document": {
|
|
1057
|
+
"data": "JVBERi0xLjQ=",
|
|
1058
|
+
"media_type": "application/pdf"
|
|
1059
|
+
}
|
|
1060
|
+
}
|
|
1061
|
+
]
|
|
1062
|
+
}]
|
|
1063
|
+
});
|
|
1064
|
+
|
|
1065
|
+
p.transform_request(&mut body).unwrap();
|
|
1066
|
+
|
|
1067
|
+
let parts = body["contents"][0]["parts"].as_array().unwrap();
|
|
1068
|
+
assert_eq!(parts.len(), 2);
|
|
1069
|
+
assert_eq!(parts[0]["text"], "Summarize this document.");
|
|
1070
|
+
assert_eq!(parts[1]["inlineData"]["mimeType"], "application/pdf");
|
|
1071
|
+
assert_eq!(parts[1]["inlineData"]["data"], "JVBERi0xLjQ=");
|
|
1072
|
+
}
|
|
1073
|
+
|
|
1074
|
+
// ── transform_response ────────────────────────────────────────────────────
|
|
1075
|
+
|
|
1076
|
+
#[test]
|
|
1077
|
+
fn transform_response_basic() {
|
|
1078
|
+
let p = provider();
|
|
1079
|
+
let mut body = json!({
|
|
1080
|
+
"responseId": "resp-gemini-123",
|
|
1081
|
+
"candidates": [{
|
|
1082
|
+
"content": {
|
|
1083
|
+
"role": "model",
|
|
1084
|
+
"parts": [{"text": "Hello from Gemini!"}]
|
|
1085
|
+
},
|
|
1086
|
+
"finishReason": "STOP"
|
|
1087
|
+
}],
|
|
1088
|
+
"usageMetadata": {
|
|
1089
|
+
"promptTokenCount": 8,
|
|
1090
|
+
"candidatesTokenCount": 6
|
|
1091
|
+
}
|
|
1092
|
+
});
|
|
1093
|
+
|
|
1094
|
+
p.transform_response(&mut body).unwrap();
|
|
1095
|
+
|
|
1096
|
+
assert_eq!(body["object"], "chat.completion");
|
|
1097
|
+
assert_eq!(body["id"], "resp-gemini-123");
|
|
1098
|
+
assert_eq!(body["choices"][0]["message"]["content"], "Hello from Gemini!");
|
|
1099
|
+
assert_eq!(body["choices"][0]["finish_reason"], "stop");
|
|
1100
|
+
assert_eq!(body["usage"]["prompt_tokens"], 8);
|
|
1101
|
+
assert_eq!(body["usage"]["completion_tokens"], 6);
|
|
1102
|
+
assert_eq!(body["usage"]["total_tokens"], 14);
|
|
1103
|
+
}
|
|
1104
|
+
|
|
1105
|
+
#[test]
|
|
1106
|
+
fn transform_response_tool_calls_have_unique_ids() {
|
|
1107
|
+
let p = provider();
|
|
1108
|
+
let mut body = json!({
|
|
1109
|
+
"candidates": [{
|
|
1110
|
+
"content": {
|
|
1111
|
+
"role": "model",
|
|
1112
|
+
"parts": [
|
|
1113
|
+
{
|
|
1114
|
+
"functionCall": {
|
|
1115
|
+
"name": "get_weather",
|
|
1116
|
+
"args": {"city": "Berlin"}
|
|
1117
|
+
}
|
|
1118
|
+
},
|
|
1119
|
+
{
|
|
1120
|
+
"functionCall": {
|
|
1121
|
+
"name": "get_weather",
|
|
1122
|
+
"args": {"city": "Paris"}
|
|
1123
|
+
}
|
|
1124
|
+
}
|
|
1125
|
+
]
|
|
1126
|
+
},
|
|
1127
|
+
"finishReason": "STOP"
|
|
1128
|
+
}],
|
|
1129
|
+
"usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 5}
|
|
1130
|
+
});
|
|
1131
|
+
|
|
1132
|
+
p.transform_response(&mut body).unwrap();
|
|
1133
|
+
|
|
1134
|
+
let tool_calls = body["choices"][0]["message"]["tool_calls"].as_array().unwrap();
|
|
1135
|
+
assert_eq!(tool_calls.len(), 2);
|
|
1136
|
+
|
|
1137
|
+
// Both calls should have the function name "get_weather" but different IDs.
|
|
1138
|
+
let id0 = tool_calls[0]["id"].as_str().unwrap();
|
|
1139
|
+
let id1 = tool_calls[1]["id"].as_str().unwrap();
|
|
1140
|
+
assert_ne!(id0, id1, "tool call IDs must be unique even for the same function");
|
|
1141
|
+
assert!(id0.starts_with("call_get_weather_"));
|
|
1142
|
+
assert!(id1.starts_with("call_get_weather_"));
|
|
1143
|
+
|
|
1144
|
+
// Verify arguments are correct.
|
|
1145
|
+
let args0: serde_json::Value =
|
|
1146
|
+
serde_json::from_str(tool_calls[0]["function"]["arguments"].as_str().unwrap()).unwrap();
|
|
1147
|
+
let args1: serde_json::Value =
|
|
1148
|
+
serde_json::from_str(tool_calls[1]["function"]["arguments"].as_str().unwrap()).unwrap();
|
|
1149
|
+
assert_eq!(args0["city"], "Berlin");
|
|
1150
|
+
assert_eq!(args1["city"], "Paris");
|
|
1151
|
+
}
|
|
1152
|
+
|
|
1153
|
+
#[test]
|
|
1154
|
+
fn transform_response_single_tool_call() {
|
|
1155
|
+
let p = provider();
|
|
1156
|
+
let mut body = json!({
|
|
1157
|
+
"candidates": [{
|
|
1158
|
+
"content": {
|
|
1159
|
+
"role": "model",
|
|
1160
|
+
"parts": [{
|
|
1161
|
+
"functionCall": {
|
|
1162
|
+
"name": "get_weather",
|
|
1163
|
+
"args": {"city": "Berlin"}
|
|
1164
|
+
}
|
|
1165
|
+
}]
|
|
1166
|
+
},
|
|
1167
|
+
"finishReason": "STOP"
|
|
1168
|
+
}],
|
|
1169
|
+
"usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 5}
|
|
1170
|
+
});
|
|
1171
|
+
|
|
1172
|
+
p.transform_response(&mut body).unwrap();
|
|
1173
|
+
|
|
1174
|
+
let tool_calls = body["choices"][0]["message"]["tool_calls"].as_array().unwrap();
|
|
1175
|
+
assert_eq!(tool_calls.len(), 1);
|
|
1176
|
+
assert_eq!(tool_calls[0]["function"]["name"], "get_weather");
|
|
1177
|
+
// ID should contain the function name and a unique counter.
|
|
1178
|
+
let id = tool_calls[0]["id"].as_str().unwrap();
|
|
1179
|
+
assert!(
|
|
1180
|
+
id.starts_with("call_get_weather_"),
|
|
1181
|
+
"id should start with call_get_weather_, got: {id}"
|
|
1182
|
+
);
|
|
1183
|
+
}
|
|
1184
|
+
|
|
1185
|
+
#[test]
|
|
1186
|
+
fn transform_response_finish_reason_mapping() {
|
|
1187
|
+
let p = provider();
|
|
1188
|
+
|
|
1189
|
+
for (gemini_reason, expected_oai_reason) in [
|
|
1190
|
+
("STOP", "stop"),
|
|
1191
|
+
("MAX_TOKENS", "length"),
|
|
1192
|
+
("SAFETY", "content_filter"),
|
|
1193
|
+
("RECITATION", "content_filter"),
|
|
1194
|
+
("BLOCKLIST", "content_filter"),
|
|
1195
|
+
("PROHIBITED_CONTENT", "content_filter"),
|
|
1196
|
+
("UNKNOWN_FUTURE_REASON", "stop"),
|
|
1197
|
+
] {
|
|
1198
|
+
let mut body = json!({
|
|
1199
|
+
"candidates": [{
|
|
1200
|
+
"content": {"role": "model", "parts": [{"text": ""}]},
|
|
1201
|
+
"finishReason": gemini_reason
|
|
1202
|
+
}],
|
|
1203
|
+
"usageMetadata": {"promptTokenCount": 0, "candidatesTokenCount": 0}
|
|
1204
|
+
});
|
|
1205
|
+
p.transform_response(&mut body).unwrap();
|
|
1206
|
+
assert_eq!(
|
|
1207
|
+
body["choices"][0]["finish_reason"], expected_oai_reason,
|
|
1208
|
+
"Gemini finishReason '{gemini_reason}' should map to '{expected_oai_reason}'"
|
|
1209
|
+
);
|
|
1210
|
+
}
|
|
1211
|
+
}
|
|
1212
|
+
|
|
1213
|
+
#[test]
|
|
1214
|
+
fn transform_response_grounding_metadata_preserved() {
|
|
1215
|
+
let p = provider();
|
|
1216
|
+
let mut body = json!({
|
|
1217
|
+
"candidates": [{
|
|
1218
|
+
"content": {
|
|
1219
|
+
"role": "model",
|
|
1220
|
+
"parts": [{"text": "grounded answer"}]
|
|
1221
|
+
},
|
|
1222
|
+
"finishReason": "STOP",
|
|
1223
|
+
"groundingMetadata": {
|
|
1224
|
+
"searchEntryPoint": {"renderedContent": "<html>...</html>"},
|
|
1225
|
+
"groundingChunks": [{"web": {"uri": "https://example.com", "title": "Example"}}]
|
|
1226
|
+
}
|
|
1227
|
+
}],
|
|
1228
|
+
"usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 3}
|
|
1229
|
+
});
|
|
1230
|
+
|
|
1231
|
+
p.transform_response(&mut body).unwrap();
|
|
1232
|
+
|
|
1233
|
+
assert_eq!(body["choices"][0]["message"]["content"], "grounded answer");
|
|
1234
|
+
assert!(
|
|
1235
|
+
body.get("_grounding_metadata").is_some(),
|
|
1236
|
+
"grounding metadata should be preserved"
|
|
1237
|
+
);
|
|
1238
|
+
assert!(body["_grounding_metadata"]["groundingChunks"].as_array().unwrap().len() == 1);
|
|
1239
|
+
}
|
|
1240
|
+
|
|
1241
|
+
// ── parse_stream_event ────────────────────────────────────────────────────
|
|
1242
|
+
|
|
1243
|
+
#[test]
|
|
1244
|
+
fn parse_stream_event_empty_returns_none() {
|
|
1245
|
+
let p = provider();
|
|
1246
|
+
let result = p.parse_stream_event("").unwrap();
|
|
1247
|
+
assert!(result.is_none());
|
|
1248
|
+
}
|
|
1249
|
+
|
|
1250
|
+
#[test]
|
|
1251
|
+
fn parse_stream_event_done_is_handled_at_sse_level() {
|
|
1252
|
+
// `[DONE]` is now caught by the SSE parser before reaching the provider.
|
|
1253
|
+
// If it were to reach the provider, it would be invalid JSON.
|
|
1254
|
+
let p = provider();
|
|
1255
|
+
let result = p.parse_stream_event("[DONE]");
|
|
1256
|
+
assert!(
|
|
1257
|
+
result.is_err(),
|
|
1258
|
+
"[DONE] is not valid JSON and should error if it reaches the provider"
|
|
1259
|
+
);
|
|
1260
|
+
}
|
|
1261
|
+
|
|
1262
|
+
#[test]
|
|
1263
|
+
fn parse_stream_event_basic_chunk() {
|
|
1264
|
+
let p = provider();
|
|
1265
|
+
let event_data = r#"{
|
|
1266
|
+
"candidates": [{
|
|
1267
|
+
"content": {"role": "model", "parts": [{"text": "Hello"}]},
|
|
1268
|
+
"finishReason": "STOP"
|
|
1269
|
+
}],
|
|
1270
|
+
"usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 2}
|
|
1271
|
+
}"#;
|
|
1272
|
+
|
|
1273
|
+
let chunk = p.parse_stream_event(event_data).unwrap().unwrap();
|
|
1274
|
+
|
|
1275
|
+
assert_eq!(chunk.object, "chat.completion.chunk");
|
|
1276
|
+
assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hello"));
|
|
1277
|
+
}
|
|
1278
|
+
|
|
1279
|
+
// ── model prefix / matching ───────────────────────────────────────────────
|
|
1280
|
+
|
|
1281
|
+
#[test]
|
|
1282
|
+
fn strip_model_prefix() {
|
|
1283
|
+
let p = provider();
|
|
1284
|
+
assert_eq!(p.strip_model_prefix("vertex_ai/gemini-2.0-flash"), "gemini-2.0-flash");
|
|
1285
|
+
assert_eq!(p.strip_model_prefix("gemini-2.0-flash"), "gemini-2.0-flash");
|
|
1286
|
+
}
|
|
1287
|
+
|
|
1288
|
+
#[test]
|
|
1289
|
+
fn matches_model() {
|
|
1290
|
+
let p = provider();
|
|
1291
|
+
assert!(p.matches_model("vertex_ai/gemini-2.0-flash"));
|
|
1292
|
+
assert!(!p.matches_model("gemini-2.0-flash"));
|
|
1293
|
+
assert!(!p.matches_model("gpt-4"));
|
|
1294
|
+
}
|
|
1295
|
+
|
|
1296
|
+
// ── multimodal content ────────────────────────────────────────────────────
|
|
1297
|
+
|
|
1298
|
+
#[test]
|
|
1299
|
+
fn transform_request_multimodal_user_content() {
|
|
1300
|
+
let p = provider();
|
|
1301
|
+
let mut body = json!({
|
|
1302
|
+
"messages": [{
|
|
1303
|
+
"role": "user",
|
|
1304
|
+
"content": [
|
|
1305
|
+
{"type": "text", "text": "What is in this image?"},
|
|
1306
|
+
{"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,/9j/abc=="}}
|
|
1307
|
+
]
|
|
1308
|
+
}]
|
|
1309
|
+
});
|
|
1310
|
+
|
|
1311
|
+
p.transform_request(&mut body).unwrap();
|
|
1312
|
+
|
|
1313
|
+
let parts = body["contents"][0]["parts"].as_array().unwrap();
|
|
1314
|
+
assert_eq!(parts.len(), 2);
|
|
1315
|
+
assert_eq!(parts[0]["text"], "What is in this image?");
|
|
1316
|
+
assert_eq!(parts[1]["inlineData"]["mimeType"], "image/jpeg");
|
|
1317
|
+
assert_eq!(parts[1]["inlineData"]["data"], "/9j/abc==");
|
|
1318
|
+
}
|
|
1319
|
+
|
|
1320
|
+
#[test]
|
|
1321
|
+
fn transform_request_multimodal_url_image() {
|
|
1322
|
+
let p = provider();
|
|
1323
|
+
let mut body = json!({
|
|
1324
|
+
"messages": [{
|
|
1325
|
+
"role": "user",
|
|
1326
|
+
"content": [
|
|
1327
|
+
{"type": "text", "text": "Describe this."},
|
|
1328
|
+
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}
|
|
1329
|
+
]
|
|
1330
|
+
}]
|
|
1331
|
+
});
|
|
1332
|
+
|
|
1333
|
+
p.transform_request(&mut body).unwrap();
|
|
1334
|
+
|
|
1335
|
+
let parts = body["contents"][0]["parts"].as_array().unwrap();
|
|
1336
|
+
assert_eq!(parts.len(), 2);
|
|
1337
|
+
assert_eq!(parts[1]["fileData"]["fileUri"], "https://example.com/image.jpg");
|
|
1338
|
+
}
|
|
1339
|
+
|
|
1340
|
+
// ── response_format translation ───────────────────────────────────────────
|
|
1341
|
+
|
|
1342
|
+
#[test]
|
|
1343
|
+
fn transform_request_response_format_json_object() {
|
|
1344
|
+
let p = provider();
|
|
1345
|
+
let mut body = json!({
|
|
1346
|
+
"messages": [{"role": "user", "content": "hi"}],
|
|
1347
|
+
"response_format": {"type": "json_object"}
|
|
1348
|
+
});
|
|
1349
|
+
|
|
1350
|
+
p.transform_request(&mut body).unwrap();
|
|
1351
|
+
|
|
1352
|
+
assert_eq!(body["generationConfig"]["responseMimeType"], "application/json");
|
|
1353
|
+
}
|
|
1354
|
+
|
|
1355
|
+
#[test]
|
|
1356
|
+
fn transform_request_response_format_json_schema() {
|
|
1357
|
+
let p = provider();
|
|
1358
|
+
let mut body = json!({
|
|
1359
|
+
"messages": [{"role": "user", "content": "hi"}],
|
|
1360
|
+
"response_format": {
|
|
1361
|
+
"type": "json_schema",
|
|
1362
|
+
"json_schema": {
|
|
1363
|
+
"name": "test",
|
|
1364
|
+
"schema": {"type": "object", "properties": {"name": {"type": "string"}}}
|
|
1365
|
+
}
|
|
1366
|
+
}
|
|
1367
|
+
});
|
|
1368
|
+
|
|
1369
|
+
p.transform_request(&mut body).unwrap();
|
|
1370
|
+
|
|
1371
|
+
assert_eq!(body["generationConfig"]["responseMimeType"], "application/json");
|
|
1372
|
+
assert_eq!(body["generationConfig"]["responseSchema"]["type"], "object");
|
|
1373
|
+
}
|
|
1374
|
+
|
|
1375
|
+
// ── tool_choice translation ───────────────────────────────────────────────
|
|
1376
|
+
|
|
1377
|
+
#[test]
|
|
1378
|
+
fn transform_request_tool_choice_auto() {
|
|
1379
|
+
let p = provider();
|
|
1380
|
+
let mut body = json!({
|
|
1381
|
+
"messages": [{"role": "user", "content": "hi"}],
|
|
1382
|
+
"tools": [{"type": "function", "function": {"name": "f", "parameters": {}}}],
|
|
1383
|
+
"tool_choice": "auto"
|
|
1384
|
+
});
|
|
1385
|
+
|
|
1386
|
+
p.transform_request(&mut body).unwrap();
|
|
1387
|
+
|
|
1388
|
+
assert_eq!(body["toolConfig"]["functionCallingConfig"]["mode"], "AUTO");
|
|
1389
|
+
}
|
|
1390
|
+
|
|
1391
|
+
#[test]
|
|
1392
|
+
fn transform_request_tool_choice_none() {
|
|
1393
|
+
let p = provider();
|
|
1394
|
+
let mut body = json!({
|
|
1395
|
+
"messages": [{"role": "user", "content": "hi"}],
|
|
1396
|
+
"tools": [{"type": "function", "function": {"name": "f", "parameters": {}}}],
|
|
1397
|
+
"tool_choice": "none"
|
|
1398
|
+
});
|
|
1399
|
+
|
|
1400
|
+
p.transform_request(&mut body).unwrap();
|
|
1401
|
+
|
|
1402
|
+
assert_eq!(body["toolConfig"]["functionCallingConfig"]["mode"], "NONE");
|
|
1403
|
+
}
|
|
1404
|
+
|
|
1405
|
+
#[test]
|
|
1406
|
+
fn transform_request_tool_choice_required() {
|
|
1407
|
+
let p = provider();
|
|
1408
|
+
let mut body = json!({
|
|
1409
|
+
"messages": [{"role": "user", "content": "hi"}],
|
|
1410
|
+
"tools": [{"type": "function", "function": {"name": "f", "parameters": {}}}],
|
|
1411
|
+
"tool_choice": "required"
|
|
1412
|
+
});
|
|
1413
|
+
|
|
1414
|
+
p.transform_request(&mut body).unwrap();
|
|
1415
|
+
|
|
1416
|
+
assert_eq!(body["toolConfig"]["functionCallingConfig"]["mode"], "ANY");
|
|
1417
|
+
}
|
|
1418
|
+
|
|
1419
|
+
#[test]
|
|
1420
|
+
fn transform_request_tool_choice_specific_function() {
|
|
1421
|
+
let p = provider();
|
|
1422
|
+
let mut body = json!({
|
|
1423
|
+
"messages": [{"role": "user", "content": "hi"}],
|
|
1424
|
+
"tools": [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
|
|
1425
|
+
"tool_choice": {"type": "function", "function": {"name": "get_weather"}}
|
|
1426
|
+
});
|
|
1427
|
+
|
|
1428
|
+
p.transform_request(&mut body).unwrap();
|
|
1429
|
+
|
|
1430
|
+
assert_eq!(body["toolConfig"]["functionCallingConfig"]["mode"], "ANY");
|
|
1431
|
+
assert_eq!(
|
|
1432
|
+
body["toolConfig"]["functionCallingConfig"]["allowedFunctionNames"][0],
|
|
1433
|
+
"get_weather"
|
|
1434
|
+
);
|
|
1435
|
+
}
|
|
1436
|
+
|
|
1437
|
+
// ── helper function tests ─────────────────────────────────────────────────
|
|
1438
|
+
|
|
1439
|
+
#[test]
|
|
1440
|
+
fn convert_user_content_string() {
|
|
1441
|
+
let content = json!("Hello!");
|
|
1442
|
+
let parts = convert_user_content_to_gemini(Some(&content));
|
|
1443
|
+
assert_eq!(parts.len(), 1);
|
|
1444
|
+
assert_eq!(parts[0]["text"], "Hello!");
|
|
1445
|
+
}
|
|
1446
|
+
|
|
1447
|
+
#[test]
|
|
1448
|
+
fn convert_user_content_array_with_image() {
|
|
1449
|
+
let content = json!([
|
|
1450
|
+
{"type": "text", "text": "What is this?"},
|
|
1451
|
+
{"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR"}}
|
|
1452
|
+
]);
|
|
1453
|
+
let parts = convert_user_content_to_gemini(Some(&content));
|
|
1454
|
+
assert_eq!(parts.len(), 2);
|
|
1455
|
+
assert_eq!(parts[0]["text"], "What is this?");
|
|
1456
|
+
assert_eq!(parts[1]["inlineData"]["mimeType"], "image/png");
|
|
1457
|
+
assert_eq!(parts[1]["inlineData"]["data"], "iVBOR");
|
|
1458
|
+
}
|
|
1459
|
+
|
|
1460
|
+
#[test]
|
|
1461
|
+
fn convert_user_content_none() {
|
|
1462
|
+
let parts = convert_user_content_to_gemini(None);
|
|
1463
|
+
assert_eq!(parts.len(), 1);
|
|
1464
|
+
assert_eq!(parts[0]["text"], "");
|
|
1465
|
+
}
|
|
1466
|
+
|
|
1467
|
+
#[test]
|
|
1468
|
+
fn convert_user_content_document_part() {
|
|
1469
|
+
let content = json!([
|
|
1470
|
+
{"type": "text", "text": "Read this PDF."},
|
|
1471
|
+
{"type": "document", "document": {"data": "base64data==", "media_type": "application/pdf"}}
|
|
1472
|
+
]);
|
|
1473
|
+
let parts = convert_user_content_to_gemini(Some(&content));
|
|
1474
|
+
assert_eq!(parts.len(), 2);
|
|
1475
|
+
assert_eq!(parts[0]["text"], "Read this PDF.");
|
|
1476
|
+
assert_eq!(parts[1]["inlineData"]["mimeType"], "application/pdf");
|
|
1477
|
+
assert_eq!(parts[1]["inlineData"]["data"], "base64data==");
|
|
1478
|
+
}
|
|
1479
|
+
|
|
1480
|
+
#[test]
|
|
1481
|
+
fn translate_tool_choice_string_values() {
|
|
1482
|
+
let auto = translate_tool_choice(Some(&json!("auto"))).unwrap();
|
|
1483
|
+
assert_eq!(auto["functionCallingConfig"]["mode"], "AUTO");
|
|
1484
|
+
|
|
1485
|
+
let none = translate_tool_choice(Some(&json!("none"))).unwrap();
|
|
1486
|
+
assert_eq!(none["functionCallingConfig"]["mode"], "NONE");
|
|
1487
|
+
|
|
1488
|
+
let required = translate_tool_choice(Some(&json!("required"))).unwrap();
|
|
1489
|
+
assert_eq!(required["functionCallingConfig"]["mode"], "ANY");
|
|
1490
|
+
}
|
|
1491
|
+
|
|
1492
|
+
#[test]
|
|
1493
|
+
fn translate_tool_choice_specific_function() {
|
|
1494
|
+
let tc = json!({"type": "function", "function": {"name": "my_fn"}});
|
|
1495
|
+
let result = translate_tool_choice(Some(&tc)).unwrap();
|
|
1496
|
+
assert_eq!(result["functionCallingConfig"]["mode"], "ANY");
|
|
1497
|
+
assert_eq!(result["functionCallingConfig"]["allowedFunctionNames"][0], "my_fn");
|
|
1498
|
+
}
|
|
1499
|
+
|
|
1500
|
+
#[test]
|
|
1501
|
+
fn translate_tool_choice_none_input() {
|
|
1502
|
+
assert!(translate_tool_choice(None).is_none());
|
|
1503
|
+
}
|
|
1504
|
+
}
|