langfun 0.1.2.dev202501010804__py3-none-any.whl → 0.1.2.dev202501060804__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (33)
  1. langfun/core/__init__.py +0 -4
  2. langfun/core/eval/matching.py +2 -2
  3. langfun/core/eval/scoring.py +6 -2
  4. langfun/core/eval/v2/checkpointing.py +106 -72
  5. langfun/core/eval/v2/checkpointing_test.py +108 -3
  6. langfun/core/eval/v2/eval_test_helper.py +56 -0
  7. langfun/core/eval/v2/evaluation.py +25 -4
  8. langfun/core/eval/v2/evaluation_test.py +11 -0
  9. langfun/core/eval/v2/example.py +11 -1
  10. langfun/core/eval/v2/example_test.py +16 -2
  11. langfun/core/eval/v2/experiment.py +83 -19
  12. langfun/core/eval/v2/experiment_test.py +121 -3
  13. langfun/core/eval/v2/reporting.py +67 -20
  14. langfun/core/eval/v2/reporting_test.py +119 -2
  15. langfun/core/eval/v2/runners.py +7 -4
  16. langfun/core/llms/__init__.py +23 -24
  17. langfun/core/llms/anthropic.py +12 -0
  18. langfun/core/llms/cache/in_memory.py +6 -0
  19. langfun/core/llms/cache/in_memory_test.py +5 -0
  20. langfun/core/llms/gemini.py +507 -0
  21. langfun/core/llms/gemini_test.py +195 -0
  22. langfun/core/llms/google_genai.py +46 -310
  23. langfun/core/llms/google_genai_test.py +9 -204
  24. langfun/core/llms/openai.py +23 -37
  25. langfun/core/llms/vertexai.py +28 -348
  26. langfun/core/llms/vertexai_test.py +6 -166
  27. {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/METADATA +7 -13
  28. {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/RECORD +31 -31
  29. {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/WHEEL +1 -1
  30. langfun/core/repr_utils.py +0 -204
  31. langfun/core/repr_utils_test.py +0 -90
  32. {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/LICENSE +0 -0
  33. {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/top_level.txt +0 -0
langfun/core/llms/anthropic.py
@@ -67,6 +67,13 @@ SUPPORTED_MODELS_AND_SETTINGS = {
         cost_per_1k_input_tokens=0.001,
         cost_per_1k_output_tokens=0.005,
     ),
+    'claude-3-opus@20240229': pg.Dict(
+        max_tokens=4096,
+        rpm=4000,
+        tpm=400000,
+        cost_per_1k_input_tokens=0.015,
+        cost_per_1k_output_tokens=0.075,
+    ),
     # Anthropic hosted models.
     'claude-3-5-sonnet-20241022': pg.Dict(
         max_tokens=8192,
@@ -461,6 +468,11 @@ class VertexAIAnthropic(Anthropic):
     return request
 
 
+class VertexAIClaude3_Opus_20240229(VertexAIAnthropic):  # pylint: disable=invalid-name
+  """Anthropic's Claude 3 Opus model on VertexAI."""
+  model = 'claude-3-opus@20240229'
+
+
 class VertexAIClaude3_5_Sonnet_20241022(VertexAIAnthropic):  # pylint: disable=invalid-name
   """Anthropic's Claude 3.5 Sonnet model on VertexAI."""
   model = 'claude-3-5-sonnet-v2@20241022'
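Together, the two anthropic.py hunks register Claude 3 Opus for the Vertex AI integration: the settings entry drives rate limiting and cost accounting (at the listed rates, a request with 1,000 input tokens and 500 output tokens would be estimated at 0.015 + 0.0375 = $0.0525), while the subclass exposes it as a model class. A minimal usage sketch, assuming Vertex AI credentials are already configured; the `project` and `location` arguments mirror langfun's other VertexAI model classes, and the values here are placeholders:

```python
# Sketch: calling the newly added Claude 3 Opus class on Vertex AI.
# Assumes application-default GCP credentials; project/location are
# placeholder values, not part of this diff.
from langfun.core.llms import anthropic

lm = anthropic.VertexAIClaude3_Opus_20240229(
    project='my-gcp-project',
    location='us-east5',
)
print(lm.model)              # 'claude-3-opus@20240229'
print(lm('What is 1 + 1?'))  # langfun LanguageModel instances are callable.
```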
langfun/core/llms/cache/in_memory.py
@@ -15,6 +15,7 @@
 
 import collections
 import contextlib
+import json
 from typing import Annotated, Any, Iterator
 import langfun.core as lf
 from langfun.core.llms.cache import base
@@ -49,6 +50,11 @@ class InMemory(base.LMCacheBase):
           "Creating a new cache as cache file '%s' does not exist.",
           self.filename,
       )
+    except json.JSONDecodeError:
+      pg.logging.warning(
+          "Creating a new cache as cache file '%s' is corrupted.",
+          self.filename,
+      )
 
   def model_ids(self) -> list[str]:
     """Returns the model ids of cached queires."""
langfun/core/llms/cache/in_memory_test.py
@@ -295,6 +295,11 @@ class InMemoryLMCacheTest(unittest.TestCase):
     self.assertEqual(cache2.stats.num_updates, 2)
     cache2.save()
 
+    # Corrupted file.
+    pg.io.writefile(path, 'bad_content')
+    cache3 = in_memory.InMemory(path)
+    self.assertEqual(len(cache3), 0)
+
 
 class LmCacheTest(unittest.TestCase):
 
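The matching in_memory.py and in_memory_test.py changes make cache loading resilient to corruption: a cache file that exists but fails JSON parsing now logs a warning and starts an empty cache instead of raising. A minimal sketch of the new behavior, reusing the calls from the test above (the path is a placeholder):

```python
# Sketch: a corrupted cache file now degrades to an empty cache with a
# logged warning, rather than raising json.JSONDecodeError at load time.
import pyglove as pg
from langfun.core.llms.cache import in_memory

path = '/tmp/lm_cache.json'           # placeholder cache location
pg.io.writefile(path, 'bad_content')  # simulate a partial/corrupted write
cache = in_memory.InMemory(path)      # previously: json.JSONDecodeError
assert len(cache) == 0                # loads as a fresh, empty cache
```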
langfun/core/llms/gemini.py (new file)
@@ -0,0 +1,507 @@
+# Copyright 2025 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Gemini REST API (Shared by Google GenAI and Vertex AI)."""
+
+import base64
+from typing import Any
+
+import langfun.core as lf
+from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
+import pyglove as pg
+
+# Supported modalities.
+
+IMAGE_TYPES = [
+    'image/png',
+    'image/jpeg',
+    'image/webp',
+    'image/heic',
+    'image/heif',
+]
+
+AUDIO_TYPES = [
+    'audio/aac',
+    'audio/flac',
+    'audio/mp3',
+    'audio/m4a',
+    'audio/mpeg',
+    'audio/mpga',
+    'audio/mp4',
+    'audio/opus',
+    'audio/pcm',
+    'audio/wav',
+    'audio/webm',
+]
+
+VIDEO_TYPES = [
+    'video/mov',
+    'video/mpeg',
+    'video/mpegps',
+    'video/mpg',
+    'video/mp4',
+    'video/webm',
+    'video/wmv',
+    'video/x-flv',
+    'video/3gpp',
+    'video/quicktime',
+]
+
+DOCUMENT_TYPES = [
+    'application/pdf',
+    'text/plain',
+    'text/csv',
+    'text/html',
+    'text/xml',
+    'text/x-script.python',
+    'application/json',
+]
+
+TEXT_ONLY = []
+
+ALL_MODALITIES = (
+    IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES + DOCUMENT_TYPES
+)
+
+SUPPORTED_MODELS_AND_SETTINGS = {
+    # For automatically rate control and cost estimation, we explicitly register
+    # supported models here. This may be inconvenient for new models, but it
+    # helps us to keep track of the models and their pricing.
+    # Models and RPM are from
+    # https://ai.google.dev/gemini-api/docs/models/gemini?_gl=1*114hbho*_up*MQ..&gclid=Cj0KCQiAst67BhCEARIsAKKdWOljBY5aQdNQ41zOPkXFCwymUfMNFl_7ukm1veAf75ZTD9qWFrFr11IaApL3EALw_wcB
+    # Pricing in US dollars, from https://ai.google.dev/pricing
+    # as of 2025-01-03.
+    # NOTE: Please update google_genai.py, vertexai.py, __init__.py when
+    # adding new models.
+    # !!! PLEASE KEEP MODELS SORTED BY RELEASE DATE !!!
+    'gemini-2.0-flash-thinking-exp-1219': pg.Dict(
+        latest_update='2024-12-19',
+        experimental=True,
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=10,
+        tpm_free=4_000_000,
+        rpm_paid=0,
+        tpm_paid=0,
+        cost_per_1m_input_tokens_up_to_128k=0,
+        cost_per_1m_output_tokens_up_to_128k=0,
+        cost_per_1m_cached_tokens_up_to_128k=0,
+        cost_per_1m_input_tokens_longer_than_128k=0,
+        cost_per_1m_output_tokens_longer_than_128k=0,
+        cost_per_1m_cached_tokens_longer_than_128k=0,
+    ),
+    'gemini-2.0-flash-exp': pg.Dict(
+        latest_update='2024-12-11',
+        experimental=True,
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=10,
+        tpm_free=4_000_000,
+        rpm_paid=0,
+        tpm_paid=0,
+        cost_per_1m_input_tokens_up_to_128k=0,
+        cost_per_1m_output_tokens_up_to_128k=0,
+        cost_per_1m_cached_tokens_up_to_128k=0,
+        cost_per_1m_input_tokens_longer_than_128k=0,
+        cost_per_1m_output_tokens_longer_than_128k=0,
+        cost_per_1m_cached_tokens_longer_than_128k=0,
+    ),
+    'gemini-exp-1206': pg.Dict(
+        latest_update='2024-12-06',
+        experimental=True,
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=10,
+        tpm_free=4_000_000,
+        rpm_paid=0,
+        tpm_paid=0,
+        cost_per_1m_input_tokens_up_to_128k=0,
+        cost_per_1m_output_tokens_up_to_128k=0,
+        cost_per_1m_cached_tokens_up_to_128k=0,
+        cost_per_1m_input_tokens_longer_than_128k=0,
+        cost_per_1m_output_tokens_longer_than_128k=0,
+        cost_per_1m_cached_tokens_longer_than_128k=0,
+    ),
+    'learnlm-1.5-pro-experimental': pg.Dict(
+        latest_update='2024-11-19',
+        experimental=True,
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=10,
+        tpm_free=4_000_000,
+        rpm_paid=0,
+        tpm_paid=0,
+        cost_per_1m_input_tokens_up_to_128k=0,
+        cost_per_1m_output_tokens_up_to_128k=0,
+        cost_per_1m_cached_tokens_up_to_128k=0,
+        cost_per_1m_input_tokens_longer_than_128k=0,
+        cost_per_1m_output_tokens_longer_than_128k=0,
+        cost_per_1m_cached_tokens_longer_than_128k=0,
+    ),
+    'gemini-exp-1114': pg.Dict(
+        latest_update='2024-11-14',
+        experimental=True,
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=10,
+        tpm_free=4_000_000,
+        rpm_paid=0,
+        tpm_paid=0,
+        cost_per_1m_input_tokens_up_to_128k=0,
+        cost_per_1m_output_tokens_up_to_128k=0,
+        cost_per_1m_cached_tokens_up_to_128k=0,
+        cost_per_1m_input_tokens_longer_than_128k=0,
+        cost_per_1m_output_tokens_longer_than_128k=0,
+        cost_per_1m_cached_tokens_longer_than_128k=0,
+    ),
+    'gemini-1.5-flash-latest': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=15,
+        tpm_free=1_000_000,
+        rpm_paid=2000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=0.075,
+        cost_per_1m_output_tokens_up_to_128k=0.3,
+        cost_per_1m_cached_tokens_up_to_128k=0.01875,
+        cost_per_1m_input_tokens_longer_than_128k=0.15,
+        cost_per_1m_output_tokens_longer_than_128k=0.6,
+        cost_per_1m_cached_tokens_longer_than_128k=0.0375,
+    ),
+    'gemini-1.5-flash': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=15,
+        tpm_free=1_000_000,
+        rpm_paid=2000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=0.075,
+        cost_per_1m_output_tokens_up_to_128k=0.3,
+        cost_per_1m_cached_tokens_up_to_128k=0.01875,
+        cost_per_1m_input_tokens_longer_than_128k=0.15,
+        cost_per_1m_output_tokens_longer_than_128k=0.6,
+        cost_per_1m_cached_tokens_longer_than_128k=0.0375,
+    ),
+    'gemini-1.5-flash-001': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=15,
+        tpm_free=1_000_000,
+        rpm_paid=2000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=0.075,
+        cost_per_1m_output_tokens_up_to_128k=0.3,
+        cost_per_1m_cached_tokens_up_to_128k=0.01875,
+        cost_per_1m_input_tokens_longer_than_128k=0.15,
+        cost_per_1m_output_tokens_longer_than_128k=0.6,
+        cost_per_1m_cached_tokens_longer_than_128k=0.0375,
+    ),
+    'gemini-1.5-flash-002': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=15,
+        tpm_free=1_000_000,
+        rpm_paid=2000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=0.075,
+        cost_per_1m_output_tokens_up_to_128k=0.3,
+        cost_per_1m_cached_tokens_up_to_128k=0.01875,
+        cost_per_1m_input_tokens_longer_than_128k=0.15,
+        cost_per_1m_output_tokens_longer_than_128k=0.6,
+        cost_per_1m_cached_tokens_longer_than_128k=0.0375,
+    ),
+    'gemini-1.5-flash-8b': pg.Dict(
+        latest_update='2024-10-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=15,
+        tpm_free=1_000_000,
+        rpm_paid=4000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=0.0375,
+        cost_per_1m_output_tokens_up_to_128k=0.15,
+        cost_per_1m_cached_tokens_up_to_128k=0.01,
+        cost_per_1m_input_tokens_longer_than_128k=0.075,
+        cost_per_1m_output_tokens_longer_than_128k=0.3,
+        cost_per_1m_cached_tokens_longer_than_128k=0.02,
+    ),
+    'gemini-1.5-flash-8b-001': pg.Dict(
+        latest_update='2024-10-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=15,
+        tpm_free=1_000_000,
+        rpm_paid=4000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=0.0375,
+        cost_per_1m_output_tokens_up_to_128k=0.15,
+        cost_per_1m_cached_tokens_up_to_128k=0.01,
+        cost_per_1m_input_tokens_longer_than_128k=0.075,
+        cost_per_1m_output_tokens_longer_than_128k=0.3,
+        cost_per_1m_cached_tokens_longer_than_128k=0.02,
+    ),
+    'gemini-1.5-pro-latest': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=2,
+        tpm_free=32_000,
+        rpm_paid=1000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=1.25,
+        cost_per_1m_output_tokens_up_to_128k=5.00,
+        cost_per_1m_cached_tokens_up_to_128k=0.3125,
+        cost_per_1m_input_tokens_longer_than_128k=2.5,
+        cost_per_1m_output_tokens_longer_than_128k=10.00,
+        cost_per_1m_cached_tokens_longer_than_128k=0.625,
+    ),
+    'gemini-1.5-pro': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=2,
+        tpm_free=32_000,
+        rpm_paid=1000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=1.25,
+        cost_per_1m_output_tokens_up_to_128k=5.00,
+        cost_per_1m_cached_tokens_up_to_128k=0.3125,
+        cost_per_1m_input_tokens_longer_than_128k=2.5,
+        cost_per_1m_output_tokens_longer_than_128k=10.00,
+        cost_per_1m_cached_tokens_longer_than_128k=0.625,
+    ),
+    'gemini-1.5-pro-001': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=2,
+        tpm_free=32_000,
+        rpm_paid=1000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=1.25,
+        cost_per_1m_output_tokens_up_to_128k=5.00,
+        cost_per_1m_cached_tokens_up_to_128k=0.3125,
+        cost_per_1m_input_tokens_longer_than_128k=2.5,
+        cost_per_1m_output_tokens_longer_than_128k=10.00,
+        cost_per_1m_cached_tokens_longer_than_128k=0.625,
+    ),
+    'gemini-1.5-pro-002': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=2,
+        tpm_free=32_000,
+        rpm_paid=1000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=1.25,
+        cost_per_1m_output_tokens_up_to_128k=5.00,
+        cost_per_1m_cached_tokens_up_to_128k=0.3125,
+        cost_per_1m_input_tokens_longer_than_128k=2.5,
+        cost_per_1m_output_tokens_longer_than_128k=10.00,
+        cost_per_1m_cached_tokens_longer_than_128k=0.625,
+    ),
+    'gemini-1.0-pro': pg.Dict(
+        in_service=False,
+        supported_modalities=TEXT_ONLY,
+        rpm_free=15,
+        tpm_free=32_000,
+        rpm_paid=360,
+        tpm_paid=120_000,
+        cost_per_1m_input_tokens_up_to_128k=0.5,
+        cost_per_1m_output_tokens_up_to_128k=1.5,
+        cost_per_1m_cached_tokens_up_to_128k=0,
+        cost_per_1m_input_tokens_longer_than_128k=0.5,
+        cost_per_1m_output_tokens_longer_than_128k=1.5,
+        cost_per_1m_cached_tokens_longer_than_128k=0,
+    ),
+}
+
+
+@pg.use_init_args(['model'])
+class Gemini(rest.REST):
+  """Language models provided by Google GenAI."""
+
+  model: pg.typing.Annotated[
+      pg.typing.Enum(
+          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+      ),
+      'The name of the model to use.',
+  ]
+
+  @property
+  def supported_modalities(self) -> list[str]:
+    """Returns the list of supported modalities."""
+    return SUPPORTED_MODELS_AND_SETTINGS[self.model].supported_modalities
+
+  @property
+  def max_concurrency(self) -> int:
+    """Returns the maximum number of concurrent requests."""
+    return self.rate_to_max_concurrency(
+        requests_per_min=max(
+            SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm_free,
+            SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm_paid
+        ),
+        tokens_per_min=max(
+            SUPPORTED_MODELS_AND_SETTINGS[self.model].tpm_free,
+            SUPPORTED_MODELS_AND_SETTINGS[self.model].tpm_paid,
+        ),
+    )
+
+  def estimate_cost(
+      self,
+      num_input_tokens: int,
+      num_output_tokens: int
+  ) -> float | None:
+    """Estimate the cost based on usage."""
+    entry = SUPPORTED_MODELS_AND_SETTINGS[self.model]
+    if num_input_tokens < 128_000:
+      cost_per_1m_input_tokens = entry.cost_per_1m_input_tokens_up_to_128k
+      cost_per_1m_output_tokens = entry.cost_per_1m_output_tokens_up_to_128k
+    else:
+      cost_per_1m_input_tokens = entry.cost_per_1m_input_tokens_longer_than_128k
+      cost_per_1m_output_tokens = (
+          entry.cost_per_1m_output_tokens_longer_than_128k
+      )
+    return (
+        cost_per_1m_input_tokens * num_input_tokens
+        + cost_per_1m_output_tokens * num_output_tokens
+    ) / 1000_000
+
+  @property
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return self.model
+
+  @classmethod
+  def dir(cls):
+    return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]
+
+  @property
+  def headers(self):
+    return {
+        'Content-Type': 'application/json; charset=utf-8',
+    }
+
+  def request(
+      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    request = dict(
+        generationConfig=self._generation_config(prompt, sampling_options)
+    )
+    request['contents'] = [self._content_from_message(prompt)]
+    return request
+
+  def _generation_config(
+      self, prompt: lf.Message, options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns a dict as generation config for prompt and LMSamplingOptions."""
+    config = dict(
+        temperature=options.temperature,
+        maxOutputTokens=options.max_tokens,
+        candidateCount=options.n,
+        topK=options.top_k,
+        topP=options.top_p,
+        stopSequences=options.stop,
+        seed=options.random_seed,
+        responseLogprobs=options.logprobs,
+        logprobs=options.top_logprobs,
+    )
+
+    if json_schema := prompt.metadata.get('json_schema'):
+      if not isinstance(json_schema, dict):
+        raise ValueError(
+            f'`json_schema` must be a dict, got {json_schema!r}.'
+        )
+      json_schema = pg.to_json(json_schema)
+      config['responseSchema'] = json_schema
+      config['responseMimeType'] = 'application/json'
+      prompt.metadata.formatted_text = (
+          prompt.text
+          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
+          + pg.to_json_str(json_schema, json_indent=2)
+      )
+    return config
+
+  def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
+    """Gets generation content from langfun message."""
+    parts = []
+    for lf_chunk in prompt.chunk():
+      if isinstance(lf_chunk, str):
+        parts.append({'text': lf_chunk})
+      elif isinstance(lf_chunk, lf_modalities.Mime):
+        try:
+          modalities = lf_chunk.make_compatible(
+              self.supported_modalities + ['text/plain']
+          )
+          if isinstance(modalities, lf_modalities.Mime):
+            modalities = [modalities]
+          for modality in modalities:
+            if modality.is_text:
+              parts.append({'text': modality.to_text()})
+            else:
+              parts.append({
+                  'inlineData': {
+                      'data': base64.b64encode(modality.to_bytes()).decode(),
+                      'mimeType': modality.mime_type,
+                  }
+              })
+        except lf.ModalityError as e:
+          raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
+      else:
+        raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
+    return dict(role='user', parts=parts)
+
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    messages = [
+        self._message_from_content_parts(candidate['content']['parts'])
+        for candidate in json['candidates']
+    ]
+    usage = json['usageMetadata']
+    input_tokens = usage['promptTokenCount']
+    output_tokens = usage['candidatesTokenCount']
+    return lf.LMSamplingResult(
+        [lf.LMSample(message) for message in messages],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=input_tokens,
+            completion_tokens=output_tokens,
+            total_tokens=input_tokens + output_tokens,
+            estimated_cost=self.estimate_cost(
+                num_input_tokens=input_tokens,
+                num_output_tokens=output_tokens,
+            ),
+        ),
+    )
+
+  def _message_from_content_parts(
+      self, parts: list[dict[str, Any]]
+  ) -> lf.Message:
+    """Converts Vertex AI's content parts protocol to message."""
+    chunks = []
+    thought_chunks = []
+    for part in parts:
+      if text_part := part.get('text'):
+        if part.get('thought'):
+          thought_chunks.append(text_part)
+        else:
+          chunks.append(text_part)
+      else:
+        raise ValueError(f'Unsupported part: {part}')
+    message = lf.AIMessage.from_chunks(chunks)
+    if thought_chunks:
+      message.set('thought', lf.AIMessage.from_chunks(thought_chunks))
+    return message
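The new gemini.py consolidates the REST protocol that google_genai.py and vertexai.py previously duplicated (hence their large deletions in the file list above), and its `estimate_cost` applies tiered per-million-token pricing: the prompt size alone selects the tier, and the input and output rates switch together. A standalone sketch of the same arithmetic, using the `gemini-1.5-pro` rates from the table above (`tiered_cost` is an illustrative helper, not a langfun API):

```python
# Standalone sketch of gemini.py's tiered cost arithmetic. Rates are USD per
# 1M tokens for 'gemini-1.5-pro', taken from SUPPORTED_MODELS_AND_SETTINGS.
RATES = {
    'up_to_128k': (1.25, 5.00),       # (input rate, output rate)
    'longer_than_128k': (2.50, 10.00),
}


def tiered_cost(num_input_tokens: int, num_output_tokens: int) -> float:
  # As in `Gemini.estimate_cost`, the prompt size alone picks the tier.
  tier = 'up_to_128k' if num_input_tokens < 128_000 else 'longer_than_128k'
  input_rate, output_rate = RATES[tier]
  return (
      input_rate * num_input_tokens + output_rate * num_output_tokens
  ) / 1_000_000


print(tiered_cost(100_000, 2_000))  # 0.125 + 0.010 = 0.135 USD (short tier)
print(tiered_cost(200_000, 2_000))  # 0.500 + 0.020 = 0.520 USD (long tier)
```

One caveat worth noting: cached-token rates are registered in the settings table but are not yet consumed by `estimate_cost`, which prices only input and output tokens.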