langfun 0.1.2.dev202501050804__py3-none-any.whl → 0.1.2.dev202501090804__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (34)
  1. langfun/core/__init__.py +0 -5
  2. langfun/core/coding/python/correction.py +4 -3
  3. langfun/core/coding/python/errors.py +10 -9
  4. langfun/core/coding/python/execution.py +23 -12
  5. langfun/core/coding/python/execution_test.py +21 -2
  6. langfun/core/coding/python/generation.py +18 -9
  7. langfun/core/concurrent.py +2 -3
  8. langfun/core/console.py +8 -3
  9. langfun/core/eval/base.py +2 -3
  10. langfun/core/eval/v2/reporting.py +15 -6
  11. langfun/core/language_model.py +7 -4
  12. langfun/core/language_model_test.py +15 -0
  13. langfun/core/llms/__init__.py +25 -26
  14. langfun/core/llms/cache/in_memory.py +6 -0
  15. langfun/core/llms/cache/in_memory_test.py +5 -0
  16. langfun/core/llms/deepseek.py +261 -0
  17. langfun/core/llms/deepseek_test.py +438 -0
  18. langfun/core/llms/gemini.py +507 -0
  19. langfun/core/llms/gemini_test.py +195 -0
  20. langfun/core/llms/google_genai.py +46 -320
  21. langfun/core/llms/google_genai_test.py +9 -204
  22. langfun/core/llms/openai.py +5 -0
  23. langfun/core/llms/vertexai.py +31 -359
  24. langfun/core/llms/vertexai_test.py +6 -166
  25. langfun/core/structured/mapping.py +13 -13
  26. langfun/core/structured/mapping_test.py +2 -2
  27. langfun/core/structured/schema.py +16 -8
  28. {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/METADATA +19 -14
  29. {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/RECORD +32 -30
  30. {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/WHEEL +1 -1
  31. langfun/core/text_formatting.py +0 -168
  32. langfun/core/text_formatting_test.py +0 -65
  33. {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/LICENSE +0 -0
  34. {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,507 @@
+ # Copyright 2025 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Gemini REST API (Shared by Google GenAI and Vertex AI)."""
+
+ import base64
+ from typing import Any
+
+ import langfun.core as lf
+ from langfun.core import modalities as lf_modalities
+ from langfun.core.llms import rest
+ import pyglove as pg
+
+ # Supported modalities.
+
+ IMAGE_TYPES = [
+     'image/png',
+     'image/jpeg',
+     'image/webp',
+     'image/heic',
+     'image/heif',
+ ]
+
+ AUDIO_TYPES = [
+     'audio/aac',
+     'audio/flac',
+     'audio/mp3',
+     'audio/m4a',
+     'audio/mpeg',
+     'audio/mpga',
+     'audio/mp4',
+     'audio/opus',
+     'audio/pcm',
+     'audio/wav',
+     'audio/webm',
+ ]
+
+ VIDEO_TYPES = [
+     'video/mov',
+     'video/mpeg',
+     'video/mpegps',
+     'video/mpg',
+     'video/mp4',
+     'video/webm',
+     'video/wmv',
+     'video/x-flv',
+     'video/3gpp',
+     'video/quicktime',
+ ]
+
+ DOCUMENT_TYPES = [
+     'application/pdf',
+     'text/plain',
+     'text/csv',
+     'text/html',
+     'text/xml',
+     'text/x-script.python',
+     'application/json',
+ ]
+
+ TEXT_ONLY = []
+
+ ALL_MODALITIES = (
+     IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES + DOCUMENT_TYPES
+ )
+
+ SUPPORTED_MODELS_AND_SETTINGS = {
+     # For automatic rate control and cost estimation, we explicitly register
+     # supported models here. This may be inconvenient for new models, but it
+     # helps us keep track of the models and their pricing.
+     # Models and RPM are from
+     # https://ai.google.dev/gemini-api/docs/models/gemini
+     # Pricing in US dollars, from https://ai.google.dev/pricing
+     # as of 2025-01-03.
+     # NOTE: Please update google_genai.py, vertexai.py and __init__.py when
+     # adding new models.
+     # !!! PLEASE KEEP MODELS SORTED BY RELEASE DATE !!!
+     'gemini-2.0-flash-thinking-exp-1219': pg.Dict(
+         latest_update='2024-12-19',
+         experimental=True,
+         in_service=True,
+         supported_modalities=ALL_MODALITIES,
+         rpm_free=10,
+         tpm_free=4_000_000,
+         rpm_paid=0,
+         tpm_paid=0,
+         cost_per_1m_input_tokens_up_to_128k=0,
+         cost_per_1m_output_tokens_up_to_128k=0,
+         cost_per_1m_cached_tokens_up_to_128k=0,
+         cost_per_1m_input_tokens_longer_than_128k=0,
+         cost_per_1m_output_tokens_longer_than_128k=0,
+         cost_per_1m_cached_tokens_longer_than_128k=0,
+     ),
+     'gemini-2.0-flash-exp': pg.Dict(
+         latest_update='2024-12-11',
+         experimental=True,
+         in_service=True,
+         supported_modalities=ALL_MODALITIES,
+         rpm_free=10,
+         tpm_free=4_000_000,
+         rpm_paid=0,
+         tpm_paid=0,
+         cost_per_1m_input_tokens_up_to_128k=0,
+         cost_per_1m_output_tokens_up_to_128k=0,
+         cost_per_1m_cached_tokens_up_to_128k=0,
+         cost_per_1m_input_tokens_longer_than_128k=0,
+         cost_per_1m_output_tokens_longer_than_128k=0,
+         cost_per_1m_cached_tokens_longer_than_128k=0,
+     ),
+     'gemini-exp-1206': pg.Dict(
+         latest_update='2024-12-06',
+         experimental=True,
+         in_service=True,
+         supported_modalities=ALL_MODALITIES,
+         rpm_free=10,
+         tpm_free=4_000_000,
+         rpm_paid=0,
+         tpm_paid=0,
+         cost_per_1m_input_tokens_up_to_128k=0,
+         cost_per_1m_output_tokens_up_to_128k=0,
+         cost_per_1m_cached_tokens_up_to_128k=0,
+         cost_per_1m_input_tokens_longer_than_128k=0,
+         cost_per_1m_output_tokens_longer_than_128k=0,
+         cost_per_1m_cached_tokens_longer_than_128k=0,
+     ),
+     'learnlm-1.5-pro-experimental': pg.Dict(
+         latest_update='2024-11-19',
+         experimental=True,
+         in_service=True,
+         supported_modalities=ALL_MODALITIES,
+         rpm_free=10,
+         tpm_free=4_000_000,
+         rpm_paid=0,
+         tpm_paid=0,
+         cost_per_1m_input_tokens_up_to_128k=0,
+         cost_per_1m_output_tokens_up_to_128k=0,
+         cost_per_1m_cached_tokens_up_to_128k=0,
+         cost_per_1m_input_tokens_longer_than_128k=0,
+         cost_per_1m_output_tokens_longer_than_128k=0,
+         cost_per_1m_cached_tokens_longer_than_128k=0,
+     ),
+     'gemini-exp-1114': pg.Dict(
+         latest_update='2024-11-14',
+         experimental=True,
+         in_service=True,
+         supported_modalities=ALL_MODALITIES,
+         rpm_free=10,
+         tpm_free=4_000_000,
+         rpm_paid=0,
+         tpm_paid=0,
+         cost_per_1m_input_tokens_up_to_128k=0,
+         cost_per_1m_output_tokens_up_to_128k=0,
+         cost_per_1m_cached_tokens_up_to_128k=0,
+         cost_per_1m_input_tokens_longer_than_128k=0,
+         cost_per_1m_output_tokens_longer_than_128k=0,
+         cost_per_1m_cached_tokens_longer_than_128k=0,
+     ),
+     'gemini-1.5-flash-latest': pg.Dict(
+         latest_update='2024-09-30',
+         in_service=True,
+         supported_modalities=ALL_MODALITIES,
+         rpm_free=15,
+         tpm_free=1_000_000,
+         rpm_paid=2000,
+         tpm_paid=4_000_000,
+         cost_per_1m_input_tokens_up_to_128k=0.075,
+         cost_per_1m_output_tokens_up_to_128k=0.3,
+         cost_per_1m_cached_tokens_up_to_128k=0.01875,
+         cost_per_1m_input_tokens_longer_than_128k=0.15,
+         cost_per_1m_output_tokens_longer_than_128k=0.6,
+         cost_per_1m_cached_tokens_longer_than_128k=0.0375,
+     ),
+     'gemini-1.5-flash': pg.Dict(
+         latest_update='2024-09-30',
+         in_service=True,
+         supported_modalities=ALL_MODALITIES,
+         rpm_free=15,
+         tpm_free=1_000_000,
+         rpm_paid=2000,
+         tpm_paid=4_000_000,
+         cost_per_1m_input_tokens_up_to_128k=0.075,
+         cost_per_1m_output_tokens_up_to_128k=0.3,
+         cost_per_1m_cached_tokens_up_to_128k=0.01875,
+         cost_per_1m_input_tokens_longer_than_128k=0.15,
+         cost_per_1m_output_tokens_longer_than_128k=0.6,
+         cost_per_1m_cached_tokens_longer_than_128k=0.0375,
+     ),
+     'gemini-1.5-flash-001': pg.Dict(
+         latest_update='2024-09-30',
+         in_service=True,
+         supported_modalities=ALL_MODALITIES,
+         rpm_free=15,
+         tpm_free=1_000_000,
+         rpm_paid=2000,
+         tpm_paid=4_000_000,
+         cost_per_1m_input_tokens_up_to_128k=0.075,
+         cost_per_1m_output_tokens_up_to_128k=0.3,
+         cost_per_1m_cached_tokens_up_to_128k=0.01875,
+         cost_per_1m_input_tokens_longer_than_128k=0.15,
+         cost_per_1m_output_tokens_longer_than_128k=0.6,
+         cost_per_1m_cached_tokens_longer_than_128k=0.0375,
+     ),
+     'gemini-1.5-flash-002': pg.Dict(
+         latest_update='2024-09-30',
+         in_service=True,
+         supported_modalities=ALL_MODALITIES,
+         rpm_free=15,
+         tpm_free=1_000_000,
+         rpm_paid=2000,
+         tpm_paid=4_000_000,
+         cost_per_1m_input_tokens_up_to_128k=0.075,
+         cost_per_1m_output_tokens_up_to_128k=0.3,
+         cost_per_1m_cached_tokens_up_to_128k=0.01875,
+         cost_per_1m_input_tokens_longer_than_128k=0.15,
+         cost_per_1m_output_tokens_longer_than_128k=0.6,
+         cost_per_1m_cached_tokens_longer_than_128k=0.0375,
+     ),
+     'gemini-1.5-flash-8b': pg.Dict(
+         latest_update='2024-10-30',
+         in_service=True,
+         supported_modalities=ALL_MODALITIES,
+         rpm_free=15,
+         tpm_free=1_000_000,
+         rpm_paid=4000,
+         tpm_paid=4_000_000,
+         cost_per_1m_input_tokens_up_to_128k=0.0375,
+         cost_per_1m_output_tokens_up_to_128k=0.15,
+         cost_per_1m_cached_tokens_up_to_128k=0.01,
+         cost_per_1m_input_tokens_longer_than_128k=0.075,
+         cost_per_1m_output_tokens_longer_than_128k=0.3,
+         cost_per_1m_cached_tokens_longer_than_128k=0.02,
+     ),
+     'gemini-1.5-flash-8b-001': pg.Dict(
+         latest_update='2024-10-30',
+         in_service=True,
+         supported_modalities=ALL_MODALITIES,
+         rpm_free=15,
+         tpm_free=1_000_000,
+         rpm_paid=4000,
+         tpm_paid=4_000_000,
+         cost_per_1m_input_tokens_up_to_128k=0.0375,
+         cost_per_1m_output_tokens_up_to_128k=0.15,
+         cost_per_1m_cached_tokens_up_to_128k=0.01,
+         cost_per_1m_input_tokens_longer_than_128k=0.075,
+         cost_per_1m_output_tokens_longer_than_128k=0.3,
+         cost_per_1m_cached_tokens_longer_than_128k=0.02,
+     ),
+     'gemini-1.5-pro-latest': pg.Dict(
+         latest_update='2024-09-30',
+         in_service=True,
+         supported_modalities=ALL_MODALITIES,
+         rpm_free=2,
+         tpm_free=32_000,
+         rpm_paid=1000,
+         tpm_paid=4_000_000,
+         cost_per_1m_input_tokens_up_to_128k=1.25,
+         cost_per_1m_output_tokens_up_to_128k=5.00,
+         cost_per_1m_cached_tokens_up_to_128k=0.3125,
+         cost_per_1m_input_tokens_longer_than_128k=2.5,
+         cost_per_1m_output_tokens_longer_than_128k=10.00,
+         cost_per_1m_cached_tokens_longer_than_128k=0.625,
+     ),
+     'gemini-1.5-pro': pg.Dict(
+         latest_update='2024-09-30',
+         in_service=True,
+         supported_modalities=ALL_MODALITIES,
+         rpm_free=2,
+         tpm_free=32_000,
+         rpm_paid=1000,
+         tpm_paid=4_000_000,
+         cost_per_1m_input_tokens_up_to_128k=1.25,
+         cost_per_1m_output_tokens_up_to_128k=5.00,
+         cost_per_1m_cached_tokens_up_to_128k=0.3125,
+         cost_per_1m_input_tokens_longer_than_128k=2.5,
+         cost_per_1m_output_tokens_longer_than_128k=10.00,
+         cost_per_1m_cached_tokens_longer_than_128k=0.625,
+     ),
+     'gemini-1.5-pro-001': pg.Dict(
+         latest_update='2024-09-30',
+         in_service=True,
+         supported_modalities=ALL_MODALITIES,
+         rpm_free=2,
+         tpm_free=32_000,
+         rpm_paid=1000,
+         tpm_paid=4_000_000,
+         cost_per_1m_input_tokens_up_to_128k=1.25,
+         cost_per_1m_output_tokens_up_to_128k=5.00,
+         cost_per_1m_cached_tokens_up_to_128k=0.3125,
+         cost_per_1m_input_tokens_longer_than_128k=2.5,
+         cost_per_1m_output_tokens_longer_than_128k=10.00,
+         cost_per_1m_cached_tokens_longer_than_128k=0.625,
+     ),
+     'gemini-1.5-pro-002': pg.Dict(
+         latest_update='2024-09-30',
+         in_service=True,
+         supported_modalities=ALL_MODALITIES,
+         rpm_free=2,
+         tpm_free=32_000,
+         rpm_paid=1000,
+         tpm_paid=4_000_000,
+         cost_per_1m_input_tokens_up_to_128k=1.25,
+         cost_per_1m_output_tokens_up_to_128k=5.00,
+         cost_per_1m_cached_tokens_up_to_128k=0.3125,
+         cost_per_1m_input_tokens_longer_than_128k=2.5,
+         cost_per_1m_output_tokens_longer_than_128k=10.00,
+         cost_per_1m_cached_tokens_longer_than_128k=0.625,
+     ),
+     'gemini-1.0-pro': pg.Dict(
+         in_service=False,
+         supported_modalities=TEXT_ONLY,
+         rpm_free=15,
+         tpm_free=32_000,
+         rpm_paid=360,
+         tpm_paid=120_000,
+         cost_per_1m_input_tokens_up_to_128k=0.5,
+         cost_per_1m_output_tokens_up_to_128k=1.5,
+         cost_per_1m_cached_tokens_up_to_128k=0,
+         cost_per_1m_input_tokens_longer_than_128k=0.5,
+         cost_per_1m_output_tokens_longer_than_128k=1.5,
+         cost_per_1m_cached_tokens_longer_than_128k=0,
+     ),
+ }
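
For illustration, a minimal sketch of how this registry is consumed (the names come from the diff above; the snippet itself is not part of the package):

    # Sketch: registry entries are pg.Dict objects, so attributes read naturally.
    entry = SUPPORTED_MODELS_AND_SETTINGS['gemini-1.5-flash']
    print(entry.rpm_paid, entry.tpm_paid)  # -> 2000 4000000
    # Models currently served (same filter as Gemini.dir() below).
    in_service = [m for m, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]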
+
+
+ @pg.use_init_args(['model'])
+ class Gemini(rest.REST):
+   """Language models provided by the Gemini API (Google GenAI and Vertex AI)."""
+
+   model: pg.typing.Annotated[
+       pg.typing.Enum(
+           pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+       ),
+       'The name of the model to use.',
+   ]
+
+   @property
+   def supported_modalities(self) -> list[str]:
+     """Returns the list of supported modalities."""
+     return SUPPORTED_MODELS_AND_SETTINGS[self.model].supported_modalities
+
+   @property
+   def max_concurrency(self) -> int:
+     """Returns the maximum number of concurrent requests."""
+     return self.rate_to_max_concurrency(
+         requests_per_min=max(
+             SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm_free,
+             SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm_paid,
+         ),
+         tokens_per_min=max(
+             SUPPORTED_MODELS_AND_SETTINGS[self.model].tpm_free,
+             SUPPORTED_MODELS_AND_SETTINGS[self.model].tpm_paid,
+         ),
+     )
+
+   def estimate_cost(
+       self,
+       num_input_tokens: int,
+       num_output_tokens: int
+   ) -> float | None:
+     """Estimates the cost based on usage."""
+     entry = SUPPORTED_MODELS_AND_SETTINGS[self.model]
+     if num_input_tokens < 128_000:
+       cost_per_1m_input_tokens = entry.cost_per_1m_input_tokens_up_to_128k
+       cost_per_1m_output_tokens = entry.cost_per_1m_output_tokens_up_to_128k
+     else:
+       cost_per_1m_input_tokens = entry.cost_per_1m_input_tokens_longer_than_128k
+       cost_per_1m_output_tokens = (
+           entry.cost_per_1m_output_tokens_longer_than_128k
+       )
+     # Prices are per 1M tokens, so divide by 1,000,000 (was `1000_1000`,
+     # i.e. 10,001,000, in the original).
+     return (
+         cost_per_1m_input_tokens * num_input_tokens
+         + cost_per_1m_output_tokens * num_output_tokens
+     ) / 1_000_000
+
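As a quick sanity check of the tiered pricing (a hedged sketch, assuming `from langfun.core.llms import gemini`; not part of the diff): for `gemini-1.5-pro`, 1,000 input and 500 output tokens fall in the up-to-128k tier, so:

    # Sketch: expected estimate_cost() result for 'gemini-1.5-pro' (<128k tier).
    lm = gemini.Gemini('gemini-1.5-pro', api_endpoint='')  # empty endpoint, as in the tests
    cost = lm.estimate_cost(num_input_tokens=1_000, num_output_tokens=500)
    # (1.25 * 1_000 + 5.00 * 500) / 1_000_000 == 0.00375 dollars
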
+   @property
+   def model_id(self) -> str:
+     """Returns a string to identify the model."""
+     return self.model
+
+   @classmethod
+   def dir(cls):
+     return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]
+
+   @property
+   def headers(self):
+     return {
+         'Content-Type': 'application/json; charset=utf-8',
+     }
+
+   def request(
+       self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
+   ) -> dict[str, Any]:
+     request = dict(
+         generationConfig=self._generation_config(prompt, sampling_options)
+     )
+     request['contents'] = [self._content_from_message(prompt)]
+     return request
+
+   def _generation_config(
+       self, prompt: lf.Message, options: lf.LMSamplingOptions
+   ) -> dict[str, Any]:
+     """Returns a dict as generation config for prompt and LMSamplingOptions."""
+     config = dict(
+         temperature=options.temperature,
+         maxOutputTokens=options.max_tokens,
+         candidateCount=options.n,
+         topK=options.top_k,
+         topP=options.top_p,
+         stopSequences=options.stop,
+         seed=options.random_seed,
+         responseLogprobs=options.logprobs,
+         logprobs=options.top_logprobs,
+     )
+
+     if json_schema := prompt.metadata.get('json_schema'):
+       if not isinstance(json_schema, dict):
+         raise ValueError(
+             f'`json_schema` must be a dict, got {json_schema!r}.'
+         )
+       json_schema = pg.to_json(json_schema)
+       config['responseSchema'] = json_schema
+       config['responseMimeType'] = 'application/json'
+       prompt.metadata.formatted_text = (
+           prompt.text
+           + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
+           + pg.to_json_str(json_schema, json_indent=2)
+       )
+     return config
+
+   def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
+     """Returns generation content from a langfun message."""
+     parts = []
+     for lf_chunk in prompt.chunk():
+       if isinstance(lf_chunk, str):
+         parts.append({'text': lf_chunk})
+       elif isinstance(lf_chunk, lf_modalities.Mime):
+         try:
+           modalities = lf_chunk.make_compatible(
+               self.supported_modalities + ['text/plain']
+           )
+           if isinstance(modalities, lf_modalities.Mime):
+             modalities = [modalities]
+           for modality in modalities:
+             if modality.is_text:
+               parts.append({'text': modality.to_text()})
+             else:
+               parts.append({
+                   'inlineData': {
+                       'data': base64.b64encode(modality.to_bytes()).decode(),
+                       'mimeType': modality.mime_type,
+                   }
+               })
+         except lf.ModalityError as e:
+           raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
+       else:
+         raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
+     return dict(role='user', parts=parts)
+
+   def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+     messages = [
+         self._message_from_content_parts(candidate['content']['parts'])
+         for candidate in json['candidates']
+     ]
+     usage = json['usageMetadata']
+     input_tokens = usage['promptTokenCount']
+     output_tokens = usage['candidatesTokenCount']
+     return lf.LMSamplingResult(
+         [lf.LMSample(message) for message in messages],
+         usage=lf.LMSamplingUsage(
+             prompt_tokens=input_tokens,
+             completion_tokens=output_tokens,
+             total_tokens=input_tokens + output_tokens,
+             estimated_cost=self.estimate_cost(
+                 num_input_tokens=input_tokens,
+                 num_output_tokens=output_tokens,
+             ),
+         ),
+     )
+
+   def _message_from_content_parts(
+       self, parts: list[dict[str, Any]]
+   ) -> lf.Message:
+     """Converts Gemini content parts to a langfun message."""
+     chunks = []
+     thought_chunks = []
+     for part in parts:
+       if text_part := part.get('text'):
+         if part.get('thought'):
+           thought_chunks.append(text_part)
+         else:
+           chunks.append(text_part)
+       else:
+         raise ValueError(f'Unsupported part: {part}')
+     message = lf.AIMessage.from_chunks(chunks)
+     if thought_chunks:
+       message.set('thought', lf.AIMessage.from_chunks(thought_chunks))
+     return message
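
Taken together, a minimal usage sketch of the new shared base (mirroring the test file below; the empty `api_endpoint` is a test-only convenience, real endpoints come from the google_genai.py and vertexai.py subclasses changed in this release):

    import langfun.core as lf
    from langfun.core.llms import gemini

    lm = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
    # Build the REST payload without sending it anywhere.
    payload = lm.request(lf.UserMessage('hi'), lf.LMSamplingOptions(temperature=0.0))
    # payload == {'generationConfig': {...},
    #             'contents': [{'role': 'user', 'parts': [{'text': 'hi'}]}]}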
@@ -0,0 +1,195 @@
+ # Copyright 2025 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tests for Gemini API."""
+
+ import base64
+ from typing import Any
+ import unittest
+ from unittest import mock
+
+ import langfun.core as lf
+ from langfun.core import modalities as lf_modalities
+ from langfun.core.llms import gemini
+ import pyglove as pg
+ import requests
+
+
+ example_image = (
+     b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
+     b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
+     b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
+     b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
+     b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
+     b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
+     b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
+     b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
+ )
+
+
+ def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
+   del url, kwargs
+   c = pg.Dict(json['generationConfig'])
+   content = json['contents'][0]['parts'][0]['text']
+   response = requests.Response()
+   response.status_code = 200
+   response._content = pg.to_json_str({
+       'candidates': [
+           {
+               'content': {
+                   'role': 'model',
+                   'parts': [
+                       {
+                           'text': (
+                               f'This is a response to {content} with '
+                               f'temperature={c.temperature}, '
+                               f'top_p={c.topP}, '
+                               f'top_k={c.topK}, '
+                               f'max_tokens={c.maxOutputTokens}, '
+                               f'stop={"".join(c.stopSequences)}.'
+                           ),
+                       },
+                       {
+                           'text': 'This is the thought.',
+                           'thought': True,
+                       }
+                   ],
+               },
+           },
+       ],
+       'usageMetadata': {
+           'promptTokenCount': 3,
+           'candidatesTokenCount': 4,
+       }
+   }).encode()
+   return response
+
+
+ class GeminiTest(unittest.TestCase):
+   """Tests for the Gemini model via the REST API."""
+
+   def test_content_from_message_text_only(self):
+     text = 'This is a beautiful day'
+     model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+     chunks = model._content_from_message(lf.UserMessage(text))
+     self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]})
+
+   def test_content_from_message_mm(self):
+     image = lf_modalities.Image.from_bytes(example_image)
+     message = lf.UserMessage(
+         'This is an <<[[image]]>>, what is it?', image=image
+     )
+
+     # Non-multimodal model.
+     with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
+       gemini.Gemini(
+           'gemini-1.0-pro', api_endpoint=''
+       )._content_from_message(message)
+
+     model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+     content = model._content_from_message(message)
+     self.assertEqual(
+         content,
+         {
+             'role': 'user',
+             'parts': [
+                 {'text': 'This is an'},
+                 {
+                     'inlineData': {
+                         'data': base64.b64encode(example_image).decode(),
+                         'mimeType': 'image/png',
+                     }
+                 },
+                 {'text': ', what is it?'},
+             ],
+         },
+     )
+
+   def test_generation_config(self):
+     model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+     json_schema = {
+         'type': 'object',
+         'properties': {
+             'name': {'type': 'string'},
+         },
+         'required': ['name'],
+         'title': 'Person',
+     }
+     actual = model._generation_config(
+         lf.UserMessage('hi', json_schema=json_schema),
+         lf.LMSamplingOptions(
+             temperature=2.0,
+             top_p=1.0,
+             top_k=20,
+             max_tokens=1024,
+             stop=['\n'],
+         ),
+     )
+     self.assertEqual(
+         actual,
+         dict(
+             candidateCount=1,
+             temperature=2.0,
+             topP=1.0,
+             topK=20,
+             maxOutputTokens=1024,
+             stopSequences=['\n'],
+             responseLogprobs=False,
+             logprobs=None,
+             seed=None,
+             responseMimeType='application/json',
+             responseSchema={
+                 'type': 'object',
+                 'properties': {
+                     'name': {'type': 'string'}
+                 },
+                 'required': ['name'],
+                 'title': 'Person',
+             }
+         ),
+     )
+     with self.assertRaisesRegex(
+         ValueError, '`json_schema` must be a dict, got'
+     ):
+       model._generation_config(
+           lf.UserMessage('hi', json_schema='not a dict'),
+           lf.LMSamplingOptions(),
+       )
+
+   def test_call_model(self):
+     with mock.patch('requests.Session.post') as mock_generate:
+       mock_generate.side_effect = mock_requests_post
+
+       lm = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+       r = lm(
+           'hello',
+           temperature=2.0,
+           top_p=1.0,
+           top_k=20,
+           max_tokens=1024,
+           stop='\n',
+       )
+       self.assertEqual(
+           r.text,
+           (
+               'This is a response to hello with temperature=2.0, '
+               'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.'
+           ),
+       )
+       self.assertEqual(r.metadata.thought, 'This is the thought.')
+       self.assertEqual(r.metadata.usage.prompt_tokens, 3)
+       self.assertEqual(r.metadata.usage.completion_tokens, 4)
+
+
+ if __name__ == '__main__':
+   unittest.main()
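
Since all network calls are mocked via `requests.Session.post`, the test module should be runnable on its own, e.g.:

    python -m unittest langfun.core.llms.gemini_test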