langfun 0.1.2.dev202501050804__py3-none-any.whl → 0.1.2.dev202501070804__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
@@ -51,6 +51,7 @@ class HtmlReporter(experiment_lib.Plugin):
     self._update_thread = None
     self._stop_update = False
     self._stop_update_experiment_ids = set()
+    self._experiment_index_lock = None
 
   def on_run_start(
       self,
@@ -61,6 +62,9 @@ class HtmlReporter(experiment_lib.Plugin):
     self._last_experiment_report_time = {leaf.id: 0 for leaf in root.leaf_nodes}
     self._stop_update = False
     self._stop_update_experiment_ids = set()
+    self._experiment_index_lock = {
+        leaf.id: threading.Lock() for leaf in root.leaf_nodes
+    }
     self._update_thread = threading.Thread(
         target=self._update_thread_func, args=(runner,)
     )
@@ -170,7 +174,8 @@ class HtmlReporter(experiment_lib.Plugin):
               card_view=False,
           ),
       )
-      html.save(index_html_path)
+      with self._experiment_index_lock[experiment.id]:
+        html.save(index_html_path)
       experiment.info(
          f'Updated {index_html_path!r} in {t.elapse:.2f} seconds.',
      )
@@ -185,11 +190,11 @@ class HtmlReporter(experiment_lib.Plugin):
         time.time() - self._last_experiment_report_time[experiment.id]
         > self.experiment_report_interval
     ):
+      self._last_experiment_report_time[experiment.id] = time.time()
       if background:
         runner.background_run(_save)
       else:
         _save()
-      self._last_experiment_report_time[experiment.id] = time.time()
 
   def _save_example_html(
       self, runner: Runner, experiment: Experiment, example: Example
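
Note on the HtmlReporter hunks above: each experiment's index.html can now be written from both the periodic update thread and progress callbacks, so a per-experiment threading.Lock serializes html.save(), and the report timestamp is stamped before the save is dispatched so that overlapping callers within experiment_report_interval skip rather than queue duplicate saves. A minimal standalone sketch of the pattern (hypothetical names, not langfun's actual classes):

    import threading
    import time

    class IndexSaver:
      """Sketch: one lock per experiment serializes index writes."""

      def __init__(self, experiment_ids, report_interval=60.0):
        self._locks = {eid: threading.Lock() for eid in experiment_ids}
        self._last_report_time = {eid: 0.0 for eid in experiment_ids}
        self._report_interval = report_interval

      def maybe_save(self, experiment_id, save_fn):
        now = time.time()
        if now - self._last_report_time[experiment_id] <= self._report_interval:
          return
        # Stamp before dispatching so concurrent callers skip instead of
        # piling up duplicate background saves.
        self._last_report_time[experiment_id] = now
        with self._locks[experiment_id]:
          save_fn()  # e.g. html.save(index_html_path)
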
@@ -434,7 +434,10 @@ class LanguageModel(component.Component):
   def __init__(self, *args, **kwargs) -> None:
     """Overrides __init__ to pass through **kwargs to sampling options."""
 
-    sampling_options = kwargs.pop('sampling_options', LMSamplingOptions())
+    sampling_options = kwargs.pop(
+        'sampling_options',
+        pg.clone(self.__schema__.fields['sampling_options'].default_value)
+    )
     sampling_options_delta = {}
 
     for k, v in kwargs.items():
@@ -117,6 +117,21 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(lm.sampling_options.top_k, 2)
     self.assertEqual(lm.max_attempts, 2)
 
+  def test_subclassing(self):
+
+    class ChildModel(lm_lib.LanguageModel):
+
+      sampling_options = lm_lib.LMSamplingOptions(
+          temperature=0.5, top_k=20
+      )
+
+      def _sample(self, *args, **kwargs):
+        pass
+
+    lm = ChildModel(top_k=10)
+    self.assertEqual(lm.sampling_options.temperature, 0.5)
+    self.assertEqual(lm.sampling_options.top_k, 10)
+
   def test_sample(self):
     lm = MockModel(top_k=1)
     self.assertEqual(
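
The LanguageModel hunk above fixes default resolution for subclasses: the old fallback LMSamplingOptions() silently discarded any class-level default a subclass declared for sampling_options, while cloning self.__schema__.fields['sampling_options'].default_value honors it, as test_subclassing verifies. A short sketch of the observable behavior (assuming the module path langfun.core.language_model, matching the test's lm_lib alias):

    import pyglove as pg
    from langfun.core import language_model as lm_lib

    class ChildModel(lm_lib.LanguageModel):
      sampling_options = lm_lib.LMSamplingOptions(temperature=0.5, top_k=20)

      def _sample(self, *args, **kwargs):
        pass

    # The fallback default now comes from the subclass's schema rather than
    # a fresh LMSamplingOptions(), so the class-level temperature survives
    # while per-kwarg overrides such as top_k still apply.
    default = ChildModel.__schema__.fields['sampling_options'].default_value
    assert pg.clone(default).temperature == 0.5
    assert ChildModel(top_k=10).sampling_options.top_k == 10
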
@@ -32,16 +32,30 @@ from langfun.core.llms.rest import REST
 
 # Gemini models.
 from langfun.core.llms.google_genai import GenAI
-from langfun.core.llms.google_genai import GeminiFlash2_0ThinkingExp
+from langfun.core.llms.google_genai import GeminiFlash2_0ThinkingExp_20241219
 from langfun.core.llms.google_genai import GeminiFlash2_0Exp
-from langfun.core.llms.google_genai import GeminiExp_20241114
 from langfun.core.llms.google_genai import GeminiExp_20241206
-from langfun.core.llms.google_genai import GeminiFlash1_5
+from langfun.core.llms.google_genai import GeminiExp_20241114
 from langfun.core.llms.google_genai import GeminiPro1_5
-from langfun.core.llms.google_genai import GeminiPro
-from langfun.core.llms.google_genai import GeminiProVision
-from langfun.core.llms.google_genai import Palm2
-from langfun.core.llms.google_genai import Palm2_IT
+from langfun.core.llms.google_genai import GeminiPro1_5_002
+from langfun.core.llms.google_genai import GeminiPro1_5_001
+from langfun.core.llms.google_genai import GeminiFlash1_5
+from langfun.core.llms.google_genai import GeminiFlash1_5_002
+from langfun.core.llms.google_genai import GeminiFlash1_5_001
+from langfun.core.llms.google_genai import GeminiPro1
+
+from langfun.core.llms.vertexai import VertexAI
+from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0ThinkingExp_20241219
+from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0Exp
+from langfun.core.llms.vertexai import VertexAIGeminiExp_20241206
+from langfun.core.llms.vertexai import VertexAIGeminiExp_20241114
+from langfun.core.llms.vertexai import VertexAIGeminiPro1_5
+from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_002
+from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_001
+from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5
+from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002
+from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001
+from langfun.core.llms.vertexai import VertexAIGeminiPro1
 
 # OpenAI models.
 from langfun.core.llms.openai import OpenAI
@@ -124,25 +138,6 @@ from langfun.core.llms.groq import GroqGemma_7B_IT
 from langfun.core.llms.groq import GroqWhisper_Large_v3
 from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo
 
-from langfun.core.llms.vertexai import VertexAI
-from langfun.core.llms.vertexai import VertexAIGemini2_0
-from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0Exp
-from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0ThinkingExp
-from langfun.core.llms.vertexai import VertexAIGemini1_5
-from langfun.core.llms.vertexai import VertexAIGeminiPro1_5
-from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_001
-from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_002
-from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0514
-from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0409
-from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5
-from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001
-from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002
-from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_0514
-from langfun.core.llms.vertexai import VertexAIGeminiPro1
-from langfun.core.llms.vertexai import VertexAIGeminiPro1Vision
-from langfun.core.llms.vertexai import VertexAIEndpoint
-
-
 # LLaMA C++ models.
 from langfun.core.llms.llama_cpp import LlamaCppRemote
 
@@ -15,6 +15,7 @@
 
 import collections
 import contextlib
+import json
 from typing import Annotated, Any, Iterator
 import langfun.core as lf
 from langfun.core.llms.cache import base
@@ -49,6 +50,11 @@ class InMemory(base.LMCacheBase):
           "Creating a new cache as cache file '%s' does not exist.",
           self.filename,
       )
+    except json.JSONDecodeError:
+      pg.logging.warning(
+          "Creating a new cache as cache file '%s' is corrupted.",
+          self.filename,
+      )
 
   def model_ids(self) -> list[str]:
     """Returns the model ids of cached queries."""
@@ -295,6 +295,11 @@ class InMemoryLMCacheTest(unittest.TestCase):
     self.assertEqual(cache2.stats.num_updates, 2)
     cache2.save()
 
+    # Corrupted file.
+    pg.io.writefile(path, 'bad_content')
+    cache3 = in_memory.InMemory(path)
+    self.assertEqual(len(cache3), 0)
+
 
 class LmCacheTest(unittest.TestCase):
 
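
The two cache hunks above make InMemory treat a corrupted cache file like a missing one: a json.JSONDecodeError during load is caught, a warning is logged, and an empty cache is started, which is what the new test asserts with len(cache3) == 0. The standalone shape of the pattern (a sketch only, not langfun's exact loader):

    import json

    def load_cache_or_reset(path: str) -> dict:
      # Fall back to an empty cache on a missing or corrupted file.
      try:
        with open(path) as f:
          return json.load(f)
      except FileNotFoundError:
        print(f"Creating a new cache as cache file {path!r} does not exist.")
      except json.JSONDecodeError:
        print(f"Creating a new cache as cache file {path!r} is corrupted.")
      return {}
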
@@ -0,0 +1,507 @@
+# Copyright 2025 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Gemini REST API (Shared by Google GenAI and Vertex AI)."""
+
+import base64
+from typing import Any
+
+import langfun.core as lf
+from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
+import pyglove as pg
+
+# Supported modalities.
+
+IMAGE_TYPES = [
+    'image/png',
+    'image/jpeg',
+    'image/webp',
+    'image/heic',
+    'image/heif',
+]
+
+AUDIO_TYPES = [
+    'audio/aac',
+    'audio/flac',
+    'audio/mp3',
+    'audio/m4a',
+    'audio/mpeg',
+    'audio/mpga',
+    'audio/mp4',
+    'audio/opus',
+    'audio/pcm',
+    'audio/wav',
+    'audio/webm',
+]
+
+VIDEO_TYPES = [
+    'video/mov',
+    'video/mpeg',
+    'video/mpegps',
+    'video/mpg',
+    'video/mp4',
+    'video/webm',
+    'video/wmv',
+    'video/x-flv',
+    'video/3gpp',
+    'video/quicktime',
+]
+
+DOCUMENT_TYPES = [
+    'application/pdf',
+    'text/plain',
+    'text/csv',
+    'text/html',
+    'text/xml',
+    'text/x-script.python',
+    'application/json',
+]
+
+TEXT_ONLY = []
+
+ALL_MODALITIES = (
+    IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES + DOCUMENT_TYPES
+)
+
+SUPPORTED_MODELS_AND_SETTINGS = {
+    # For automatic rate control and cost estimation, we explicitly register
+    # supported models here. This may be inconvenient for new models, but it
+    # helps us keep track of the models and their pricing.
+    # Models and RPM are from
+    # https://ai.google.dev/gemini-api/docs/models/gemini?_gl=1*114hbho*_up*MQ..&gclid=Cj0KCQiAst67BhCEARIsAKKdWOljBY5aQdNQ41zOPkXFCwymUfMNFl_7ukm1veAf75ZTD9qWFrFr11IaApL3EALw_wcB
+    # Pricing in US dollars, from https://ai.google.dev/pricing
+    # as of 2025-01-03.
+    # NOTE: Please update google_genai.py, vertexai.py, __init__.py when
+    # adding new models.
+    # !!! PLEASE KEEP MODELS SORTED BY RELEASE DATE !!!
+    'gemini-2.0-flash-thinking-exp-1219': pg.Dict(
+        latest_update='2024-12-19',
+        experimental=True,
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=10,
+        tpm_free=4_000_000,
+        rpm_paid=0,
+        tpm_paid=0,
+        cost_per_1m_input_tokens_up_to_128k=0,
+        cost_per_1m_output_tokens_up_to_128k=0,
+        cost_per_1m_cached_tokens_up_to_128k=0,
+        cost_per_1m_input_tokens_longer_than_128k=0,
+        cost_per_1m_output_tokens_longer_than_128k=0,
+        cost_per_1m_cached_tokens_longer_than_128k=0,
+    ),
+    'gemini-2.0-flash-exp': pg.Dict(
+        latest_update='2024-12-11',
+        experimental=True,
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=10,
+        tpm_free=4_000_000,
+        rpm_paid=0,
+        tpm_paid=0,
+        cost_per_1m_input_tokens_up_to_128k=0,
+        cost_per_1m_output_tokens_up_to_128k=0,
+        cost_per_1m_cached_tokens_up_to_128k=0,
+        cost_per_1m_input_tokens_longer_than_128k=0,
+        cost_per_1m_output_tokens_longer_than_128k=0,
+        cost_per_1m_cached_tokens_longer_than_128k=0,
+    ),
+    'gemini-exp-1206': pg.Dict(
+        latest_update='2024-12-06',
+        experimental=True,
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=10,
+        tpm_free=4_000_000,
+        rpm_paid=0,
+        tpm_paid=0,
+        cost_per_1m_input_tokens_up_to_128k=0,
+        cost_per_1m_output_tokens_up_to_128k=0,
+        cost_per_1m_cached_tokens_up_to_128k=0,
+        cost_per_1m_input_tokens_longer_than_128k=0,
+        cost_per_1m_output_tokens_longer_than_128k=0,
+        cost_per_1m_cached_tokens_longer_than_128k=0,
+    ),
+    'learnlm-1.5-pro-experimental': pg.Dict(
+        latest_update='2024-11-19',
+        experimental=True,
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=10,
+        tpm_free=4_000_000,
+        rpm_paid=0,
+        tpm_paid=0,
+        cost_per_1m_input_tokens_up_to_128k=0,
+        cost_per_1m_output_tokens_up_to_128k=0,
+        cost_per_1m_cached_tokens_up_to_128k=0,
+        cost_per_1m_input_tokens_longer_than_128k=0,
+        cost_per_1m_output_tokens_longer_than_128k=0,
+        cost_per_1m_cached_tokens_longer_than_128k=0,
+    ),
+    'gemini-exp-1114': pg.Dict(
+        latest_update='2024-11-14',
+        experimental=True,
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=10,
+        tpm_free=4_000_000,
+        rpm_paid=0,
+        tpm_paid=0,
+        cost_per_1m_input_tokens_up_to_128k=0,
+        cost_per_1m_output_tokens_up_to_128k=0,
+        cost_per_1m_cached_tokens_up_to_128k=0,
+        cost_per_1m_input_tokens_longer_than_128k=0,
+        cost_per_1m_output_tokens_longer_than_128k=0,
+        cost_per_1m_cached_tokens_longer_than_128k=0,
+    ),
+    'gemini-1.5-flash-latest': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=15,
+        tpm_free=1_000_000,
+        rpm_paid=2000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=0.075,
+        cost_per_1m_output_tokens_up_to_128k=0.3,
+        cost_per_1m_cached_tokens_up_to_128k=0.01875,
+        cost_per_1m_input_tokens_longer_than_128k=0.15,
+        cost_per_1m_output_tokens_longer_than_128k=0.6,
+        cost_per_1m_cached_tokens_longer_than_128k=0.0375,
+    ),
+    'gemini-1.5-flash': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=15,
+        tpm_free=1_000_000,
+        rpm_paid=2000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=0.075,
+        cost_per_1m_output_tokens_up_to_128k=0.3,
+        cost_per_1m_cached_tokens_up_to_128k=0.01875,
+        cost_per_1m_input_tokens_longer_than_128k=0.15,
+        cost_per_1m_output_tokens_longer_than_128k=0.6,
+        cost_per_1m_cached_tokens_longer_than_128k=0.0375,
+    ),
+    'gemini-1.5-flash-001': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=15,
+        tpm_free=1_000_000,
+        rpm_paid=2000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=0.075,
+        cost_per_1m_output_tokens_up_to_128k=0.3,
+        cost_per_1m_cached_tokens_up_to_128k=0.01875,
+        cost_per_1m_input_tokens_longer_than_128k=0.15,
+        cost_per_1m_output_tokens_longer_than_128k=0.6,
+        cost_per_1m_cached_tokens_longer_than_128k=0.0375,
+    ),
+    'gemini-1.5-flash-002': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=15,
+        tpm_free=1_000_000,
+        rpm_paid=2000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=0.075,
+        cost_per_1m_output_tokens_up_to_128k=0.3,
+        cost_per_1m_cached_tokens_up_to_128k=0.01875,
+        cost_per_1m_input_tokens_longer_than_128k=0.15,
+        cost_per_1m_output_tokens_longer_than_128k=0.6,
+        cost_per_1m_cached_tokens_longer_than_128k=0.0375,
+    ),
+    'gemini-1.5-flash-8b': pg.Dict(
+        latest_update='2024-10-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=15,
+        tpm_free=1_000_000,
+        rpm_paid=4000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=0.0375,
+        cost_per_1m_output_tokens_up_to_128k=0.15,
+        cost_per_1m_cached_tokens_up_to_128k=0.01,
+        cost_per_1m_input_tokens_longer_than_128k=0.075,
+        cost_per_1m_output_tokens_longer_than_128k=0.3,
+        cost_per_1m_cached_tokens_longer_than_128k=0.02,
+    ),
+    'gemini-1.5-flash-8b-001': pg.Dict(
+        latest_update='2024-10-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=15,
+        tpm_free=1_000_000,
+        rpm_paid=4000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=0.0375,
+        cost_per_1m_output_tokens_up_to_128k=0.15,
+        cost_per_1m_cached_tokens_up_to_128k=0.01,
+        cost_per_1m_input_tokens_longer_than_128k=0.075,
+        cost_per_1m_output_tokens_longer_than_128k=0.3,
+        cost_per_1m_cached_tokens_longer_than_128k=0.02,
+    ),
+    'gemini-1.5-pro-latest': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=2,
+        tpm_free=32_000,
+        rpm_paid=1000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=1.25,
+        cost_per_1m_output_tokens_up_to_128k=5.00,
+        cost_per_1m_cached_tokens_up_to_128k=0.3125,
+        cost_per_1m_input_tokens_longer_than_128k=2.5,
+        cost_per_1m_output_tokens_longer_than_128k=10.00,
+        cost_per_1m_cached_tokens_longer_than_128k=0.625,
+    ),
+    'gemini-1.5-pro': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=2,
+        tpm_free=32_000,
+        rpm_paid=1000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=1.25,
+        cost_per_1m_output_tokens_up_to_128k=5.00,
+        cost_per_1m_cached_tokens_up_to_128k=0.3125,
+        cost_per_1m_input_tokens_longer_than_128k=2.5,
+        cost_per_1m_output_tokens_longer_than_128k=10.00,
+        cost_per_1m_cached_tokens_longer_than_128k=0.625,
+    ),
+    'gemini-1.5-pro-001': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=2,
+        tpm_free=32_000,
+        rpm_paid=1000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=1.25,
+        cost_per_1m_output_tokens_up_to_128k=5.00,
+        cost_per_1m_cached_tokens_up_to_128k=0.3125,
+        cost_per_1m_input_tokens_longer_than_128k=2.5,
+        cost_per_1m_output_tokens_longer_than_128k=10.00,
+        cost_per_1m_cached_tokens_longer_than_128k=0.625,
+    ),
+    'gemini-1.5-pro-002': pg.Dict(
+        latest_update='2024-09-30',
+        in_service=True,
+        supported_modalities=ALL_MODALITIES,
+        rpm_free=2,
+        tpm_free=32_000,
+        rpm_paid=1000,
+        tpm_paid=4_000_000,
+        cost_per_1m_input_tokens_up_to_128k=1.25,
+        cost_per_1m_output_tokens_up_to_128k=5.00,
+        cost_per_1m_cached_tokens_up_to_128k=0.3125,
+        cost_per_1m_input_tokens_longer_than_128k=2.5,
+        cost_per_1m_output_tokens_longer_than_128k=10.00,
+        cost_per_1m_cached_tokens_longer_than_128k=0.625,
+    ),
+    'gemini-1.0-pro': pg.Dict(
+        in_service=False,
+        supported_modalities=TEXT_ONLY,
+        rpm_free=15,
+        tpm_free=32_000,
+        rpm_paid=360,
+        tpm_paid=120_000,
+        cost_per_1m_input_tokens_up_to_128k=0.5,
+        cost_per_1m_output_tokens_up_to_128k=1.5,
+        cost_per_1m_cached_tokens_up_to_128k=0,
+        cost_per_1m_input_tokens_longer_than_128k=0.5,
+        cost_per_1m_output_tokens_longer_than_128k=1.5,
+        cost_per_1m_cached_tokens_longer_than_128k=0,
+    ),
+}
+
+
+@pg.use_init_args(['model'])
+class Gemini(rest.REST):
+  """Language models provided by Google GenAI."""
+
+  model: pg.typing.Annotated[
+      pg.typing.Enum(
+          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+      ),
+      'The name of the model to use.',
+  ]
+
+  @property
+  def supported_modalities(self) -> list[str]:
+    """Returns the list of supported modalities."""
+    return SUPPORTED_MODELS_AND_SETTINGS[self.model].supported_modalities
+
+  @property
+  def max_concurrency(self) -> int:
+    """Returns the maximum number of concurrent requests."""
+    return self.rate_to_max_concurrency(
+        requests_per_min=max(
+            SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm_free,
+            SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm_paid
+        ),
+        tokens_per_min=max(
+            SUPPORTED_MODELS_AND_SETTINGS[self.model].tpm_free,
+            SUPPORTED_MODELS_AND_SETTINGS[self.model].tpm_paid,
+        ),
+    )
+
+  def estimate_cost(
+      self,
+      num_input_tokens: int,
+      num_output_tokens: int
+  ) -> float | None:
+    """Estimate the cost based on usage."""
+    entry = SUPPORTED_MODELS_AND_SETTINGS[self.model]
+    if num_input_tokens < 128_000:
+      cost_per_1m_input_tokens = entry.cost_per_1m_input_tokens_up_to_128k
+      cost_per_1m_output_tokens = entry.cost_per_1m_output_tokens_up_to_128k
+    else:
+      cost_per_1m_input_tokens = entry.cost_per_1m_input_tokens_longer_than_128k
+      cost_per_1m_output_tokens = (
+          entry.cost_per_1m_output_tokens_longer_than_128k
+      )
+    return (
+        cost_per_1m_input_tokens * num_input_tokens
+        + cost_per_1m_output_tokens * num_output_tokens
+    ) / 1_000_000
+
+  @property
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return self.model
+
+  @classmethod
+  def dir(cls):
+    return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]
+
+  @property
+  def headers(self):
+    return {
+        'Content-Type': 'application/json; charset=utf-8',
+    }
+
+  def request(
+      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    request = dict(
+        generationConfig=self._generation_config(prompt, sampling_options)
+    )
+    request['contents'] = [self._content_from_message(prompt)]
+    return request
+
+  def _generation_config(
+      self, prompt: lf.Message, options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns a dict as generation config for prompt and LMSamplingOptions."""
+    config = dict(
+        temperature=options.temperature,
+        maxOutputTokens=options.max_tokens,
+        candidateCount=options.n,
+        topK=options.top_k,
+        topP=options.top_p,
+        stopSequences=options.stop,
+        seed=options.random_seed,
+        responseLogprobs=options.logprobs,
+        logprobs=options.top_logprobs,
+    )
+
+    if json_schema := prompt.metadata.get('json_schema'):
+      if not isinstance(json_schema, dict):
+        raise ValueError(
+            f'`json_schema` must be a dict, got {json_schema!r}.'
+        )
+      json_schema = pg.to_json(json_schema)
+      config['responseSchema'] = json_schema
+      config['responseMimeType'] = 'application/json'
+      prompt.metadata.formatted_text = (
+          prompt.text
+          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
+          + pg.to_json_str(json_schema, json_indent=2)
+      )
+    return config
+
+  def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
+    """Gets generation content from langfun message."""
+    parts = []
+    for lf_chunk in prompt.chunk():
+      if isinstance(lf_chunk, str):
+        parts.append({'text': lf_chunk})
+      elif isinstance(lf_chunk, lf_modalities.Mime):
+        try:
+          modalities = lf_chunk.make_compatible(
+              self.supported_modalities + ['text/plain']
+          )
+          if isinstance(modalities, lf_modalities.Mime):
+            modalities = [modalities]
+          for modality in modalities:
+            if modality.is_text:
+              parts.append({'text': modality.to_text()})
+            else:
+              parts.append({
+                  'inlineData': {
+                      'data': base64.b64encode(modality.to_bytes()).decode(),
+                      'mimeType': modality.mime_type,
+                  }
+              })
+        except lf.ModalityError as e:
+          raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
+      else:
+        raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
+    return dict(role='user', parts=parts)
+
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    messages = [
+        self._message_from_content_parts(candidate['content']['parts'])
+        for candidate in json['candidates']
+    ]
+    usage = json['usageMetadata']
+    input_tokens = usage['promptTokenCount']
+    output_tokens = usage['candidatesTokenCount']
+    return lf.LMSamplingResult(
+        [lf.LMSample(message) for message in messages],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=input_tokens,
+            completion_tokens=output_tokens,
+            total_tokens=input_tokens + output_tokens,
+            estimated_cost=self.estimate_cost(
+                num_input_tokens=input_tokens,
+                num_output_tokens=output_tokens,
+            ),
+        ),
+    )
+
+  def _message_from_content_parts(
+      self, parts: list[dict[str, Any]]
+  ) -> lf.Message:
+    """Converts Vertex AI's content parts protocol to message."""
+    chunks = []
+    thought_chunks = []
+    for part in parts:
+      if text_part := part.get('text'):
+        if part.get('thought'):
+          thought_chunks.append(text_part)
+        else:
+          chunks.append(text_part)
+      else:
+        raise ValueError(f'Unsupported part: {part}')
+    message = lf.AIMessage.from_chunks(chunks)
+    if thought_chunks:
+      message.set('thought', lf.AIMessage.from_chunks(thought_chunks))
+    return message
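
As a sanity check on Gemini.estimate_cost above, a worked example with hypothetical token counts for 'gemini-1.5-pro' in the up-to-128k tier ($1.25 per 1M input tokens, $5.00 per 1M output tokens):

    num_input_tokens, num_output_tokens = 50_000, 2_000
    cost = (1.25 * num_input_tokens + 5.00 * num_output_tokens) / 1_000_000
    assert round(cost, 4) == 0.0725  # $0.0725 for this request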