langfun 0.1.2.dev202502110804__py3-none-any.whl → 0.1.2.dev202502120804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,386 +14,518 @@
14
14
  """Gemini REST API (Shared by Google GenAI and Vertex AI)."""
15
15
 
16
16
  import base64
17
- from typing import Any
17
+ import datetime
18
+ import functools
19
+ from typing import Annotated, Any
18
20
 
19
21
  import langfun.core as lf
20
22
  from langfun.core import modalities as lf_modalities
21
23
  from langfun.core.llms import rest
22
24
  import pyglove as pg
23
25
 
24
- # Supported modalities.
25
26
 
26
- IMAGE_TYPES = [
27
- 'image/png',
28
- 'image/jpeg',
29
- 'image/webp',
30
- 'image/heic',
31
- 'image/heif',
32
- ]
27
+ class GeminiModelInfo(lf.ModelInfo):
28
+ """Gemini model info."""
33
29
 
34
- AUDIO_TYPES = [
35
- 'audio/aac',
36
- 'audio/flac',
37
- 'audio/mp3',
38
- 'audio/m4a',
39
- 'audio/mpeg',
40
- 'audio/mpga',
41
- 'audio/mp4',
42
- 'audio/opus',
43
- 'audio/pcm',
44
- 'audio/wav',
45
- 'audio/webm',
46
- ]
30
+ # Constants for supported MIME types.
31
+ INPUT_IMAGE_TYPES = [
32
+ 'image/png',
33
+ 'image/jpeg',
34
+ 'image/webp',
35
+ 'image/heic',
36
+ 'image/heif',
37
+ ]
47
38
 
48
- VIDEO_TYPES = [
49
- 'video/mov',
50
- 'video/mpeg',
51
- 'video/mpegps',
52
- 'video/mpg',
53
- 'video/mp4',
54
- 'video/webm',
55
- 'video/wmv',
56
- 'video/x-flv',
57
- 'video/3gpp',
58
- 'video/quicktime',
59
- ]
39
+ INPUT_AUDIO_TYPES = [
40
+ 'audio/aac',
41
+ 'audio/flac',
42
+ 'audio/mp3',
43
+ 'audio/m4a',
44
+ 'audio/mpeg',
45
+ 'audio/mpga',
46
+ 'audio/mp4',
47
+ 'audio/opus',
48
+ 'audio/pcm',
49
+ 'audio/wav',
50
+ 'audio/webm',
51
+ ]
60
52
 
61
- DOCUMENT_TYPES = [
62
- 'application/pdf',
63
- 'text/plain',
64
- 'text/csv',
65
- 'text/html',
66
- 'text/xml',
67
- 'text/x-script.python',
68
- 'application/json',
69
- ]
53
+ INPUT_VIDEO_TYPES = [
54
+ 'video/mov',
55
+ 'video/mpeg',
56
+ 'video/mpegps',
57
+ 'video/mpg',
58
+ 'video/mp4',
59
+ 'video/webm',
60
+ 'video/wmv',
61
+ 'video/x-flv',
62
+ 'video/3gpp',
63
+ 'video/quicktime',
64
+ ]
70
65
 
71
- TEXT_ONLY = []
72
-
73
- ALL_MODALITIES = (
74
- IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES + DOCUMENT_TYPES
75
- )
76
-
77
- SUPPORTED_MODELS_AND_SETTINGS = {
78
- # For automatically rate control and cost estimation, we explicitly register
79
- # supported models here. This may be inconvenient for new models, but it
80
- # helps us to keep track of the models and their pricing.
81
- # Models and RPM are from
82
- # https://ai.google.dev/gemini-api/docs/models/gemini?_gl=1*114hbho*_up*MQ..&gclid=Cj0KCQiAst67BhCEARIsAKKdWOljBY5aQdNQ41zOPkXFCwymUfMNFl_7ukm1veAf75ZTD9qWFrFr11IaApL3EALw_wcB
83
- # Pricing in US dollars, from https://ai.google.dev/pricing
84
- # as of 2025-01-03.
85
- # NOTE: Please update google_genai.py, vertexai.py, __init__.py when
86
- # adding new models.
87
- # !!! PLEASE KEEP MODELS SORTED BY RELEASE DATE !!!
88
- 'gemini-2.0-flash-001': pg.Dict(
89
- latest_update='2025-02-05',
90
- experimental=True,
91
- in_service=True,
92
- supported_modalities=ALL_MODALITIES,
93
- rpm_free=10_000,
94
- tpm_free=40_000_000,
95
- rpm_paid=0,
96
- tpm_paid=0,
97
- cost_per_1m_input_tokens_up_to_128k=0.1,
98
- cost_per_1m_output_tokens_up_to_128k=0.4,
99
- cost_per_1m_cached_tokens_up_to_128k=0.025,
100
- cost_per_1m_input_tokens_longer_than_128k=0.1,
101
- cost_per_1m_output_tokens_longer_than_128k=0.4,
102
- cost_per_1m_cached_tokens_longer_than_128k=0.025,
103
- ),
104
- 'gemini-2.0-flash': pg.Dict(
105
- latest_update='2025-02-05',
106
- experimental=True,
107
- in_service=True,
108
- supported_modalities=ALL_MODALITIES,
109
- rpm_free=10_000,
110
- tpm_free=40_000_000,
111
- rpm_paid=0,
112
- tpm_paid=0,
113
- cost_per_1m_input_tokens_up_to_128k=0.1,
114
- cost_per_1m_output_tokens_up_to_128k=0.4,
115
- cost_per_1m_cached_tokens_up_to_128k=0.025,
116
- cost_per_1m_input_tokens_longer_than_128k=0.1,
117
- cost_per_1m_output_tokens_longer_than_128k=0.4,
118
- cost_per_1m_cached_tokens_longer_than_128k=0.025,
119
- ),
120
- 'gemini-2.0-pro-exp-02-05': pg.Dict(
121
- latest_update='2025-02-05',
122
- experimental=True,
123
- in_service=True,
124
- supported_modalities=ALL_MODALITIES,
125
- rpm_free=10,
126
- tpm_free=4_000_000,
127
- rpm_paid=0,
128
- tpm_paid=0,
129
- cost_per_1m_input_tokens_up_to_128k=0.1,
130
- cost_per_1m_output_tokens_up_to_128k=0.4,
131
- cost_per_1m_cached_tokens_up_to_128k=0.025,
132
- cost_per_1m_input_tokens_longer_than_128k=0.1,
133
- cost_per_1m_output_tokens_longer_than_128k=0.4,
134
- cost_per_1m_cached_tokens_longer_than_128k=0.025,
135
- ),
136
- 'gemini-2.0-flash-thinking-exp-01-21': pg.Dict(
137
- latest_update='2024-01-21',
138
- experimental=True,
139
- in_service=True,
140
- supported_modalities=ALL_MODALITIES,
141
- rpm_free=10,
142
- tpm_free=40_000_000,
143
- rpm_paid=0,
144
- tpm_paid=0,
145
- cost_per_1m_input_tokens_up_to_128k=0,
146
- cost_per_1m_output_tokens_up_to_128k=0,
147
- cost_per_1m_cached_tokens_up_to_128k=0,
148
- cost_per_1m_input_tokens_longer_than_128k=0,
149
- cost_per_1m_output_tokens_longer_than_128k=0,
150
- cost_per_1m_cached_tokens_longer_than_128k=0,
151
- ),
152
- 'gemini-2.0-flash-thinking-exp-1219': pg.Dict(
153
- latest_update='2024-12-19',
154
- experimental=True,
66
+ INPUT_DOC_TYPES = [
67
+ 'application/pdf',
68
+ 'text/plain',
69
+ 'text/csv',
70
+ 'text/html',
71
+ 'text/xml',
72
+ 'text/x-script.python',
73
+ 'application/json',
74
+ ]
75
+
76
+ ALL_SUPPORTED_INPUT_TYPES = (
77
+ INPUT_IMAGE_TYPES
78
+ + INPUT_AUDIO_TYPES
79
+ + INPUT_VIDEO_TYPES
80
+ + INPUT_DOC_TYPES
81
+ )
82
+
83
+ LINKS = dict(
84
+ models='https://ai.google.dev/gemini-api/docs/models/gemini',
85
+ pricing='https://ai.google.dev/gemini-api/docs/pricing',
86
+ rate_limits='https://ai.google.dev/gemini-api/docs/models/gemini',
87
+ error_codes='https://ai.google.dev/gemini-api/docs/troubleshooting?lang=python#error-codes',
88
+ )
89
+
90
+ class Pricing(lf.ModelInfo.Pricing):
91
+ """Pricing for Gemini models."""
92
+
93
+ cost_per_1m_cached_input_tokens_with_prompt_longer_than_128k: Annotated[
94
+ float | None,
95
+ (
96
+ 'The cost per 1M cached input tokens for prompts longer than 128K. '
97
+ 'If None, the 128k constraint is not applicable.'
98
+ )
99
+ ] = None
100
+
101
+ cost_per_1m_input_tokens_with_prompt_longer_than_128k: Annotated[
102
+ float | None,
103
+ (
104
+ 'The cost per 1M input tokens for prompts longer than 128K. '
105
+ 'If None, the 128k constraint is not applicable.'
106
+ )
107
+ ] = None
108
+
109
+ cost_per_1m_output_tokens_with_prompt_longer_than_128k: Annotated[
110
+ float | None,
111
+ (
112
+ 'The cost per 1M output tokens for prompts longer than 128K.'
113
+ 'If None, the 128k constraint is not applicable.'
114
+ )
115
+ ] = None
116
+
117
+ def estimate_cost(self, usage: lf.LMSamplingUsage) -> float | None:
118
+ """Estimates the cost of using the model. Subclass could override.
119
+
120
+ Args:
121
+ usage: The usage information of the model.
122
+
123
+ Returns:
124
+ The estimated cost in US dollars. If None, cost estimating is not
125
+ supported on the model.
126
+ """
127
+ if (usage.prompt_tokens is None
128
+ or usage.prompt_tokens < 128_000
129
+ or not self.cost_per_1m_input_tokens_with_prompt_longer_than_128k):
130
+ return super().estimate_cost(usage)
131
+
132
+ return (
133
+ self.cost_per_1m_input_tokens_with_prompt_longer_than_128k
134
+ * usage.prompt_tokens
135
+ + self.cost_per_1m_output_tokens_with_prompt_longer_than_128k
136
+ * usage.completion_tokens
137
+ ) / 1000_000
138
+
139
+ experimental: Annotated[
140
+ bool,
141
+ (
142
+ 'If True, the model is experimental and may retire without notice.'
143
+ )
144
+ ] = False
145
+
146
+
147
+ # !!! PLEASE KEEP MODELS SORTED BY MODEL FAMILY AND RELEASE DATE !!!
148
+
149
+
150
+ SUPPORTED_MODELS = [
151
+
152
+ #
153
+ # Production models.
154
+ #
155
+
156
+ # Gemini 2.0 Flash.
157
+ GeminiModelInfo(
158
+ model_id='gemini-2.0-flash',
155
159
  in_service=True,
156
- supported_modalities=ALL_MODALITIES,
157
- rpm_free=10,
158
- tpm_free=4_000_000,
159
- rpm_paid=0,
160
- tpm_paid=0,
161
- cost_per_1m_input_tokens_up_to_128k=0,
162
- cost_per_1m_output_tokens_up_to_128k=0,
163
- cost_per_1m_cached_tokens_up_to_128k=0,
164
- cost_per_1m_input_tokens_longer_than_128k=0,
165
- cost_per_1m_output_tokens_longer_than_128k=0,
166
- cost_per_1m_cached_tokens_longer_than_128k=0,
160
+ provider=pg.oneof(['Google GenAI', 'VertexAI']),
161
+ model_type='instruction-tuned',
162
+ description=(
163
+ 'Gemini 2.0 Flash model.'
164
+ ),
165
+ release_date=datetime.datetime(2025, 2, 5),
166
+ input_modalities=GeminiModelInfo.ALL_SUPPORTED_INPUT_TYPES,
167
+ context_length=lf.ModelInfo.ContextLength(
168
+ max_input_tokens=1_048_576,
169
+ max_output_tokens=8_192,
170
+ ),
171
+ pricing=GeminiModelInfo.Pricing(
172
+ cost_per_1m_cached_input_tokens=0.025,
173
+ cost_per_1m_input_tokens=0.1,
174
+ cost_per_1m_output_tokens=0.4,
175
+ ),
176
+ rate_limits=lf.ModelInfo.RateLimits(
177
+ # Tier 4 rate limits
178
+ max_requests_per_minute=2000,
179
+ max_tokens_per_minute=4_000_000,
180
+ ),
167
181
  ),
168
- 'gemini-2.0-flash-exp': pg.Dict(
169
- latest_update='2024-12-11',
170
- experimental=True,
182
+ GeminiModelInfo(
183
+ model_id='gemini-2.0-flash-001',
171
184
  in_service=True,
172
- supported_modalities=ALL_MODALITIES,
173
- rpm_free=10,
174
- tpm_free=4_000_000,
175
- rpm_paid=0,
176
- tpm_paid=0,
177
- cost_per_1m_input_tokens_up_to_128k=0,
178
- cost_per_1m_output_tokens_up_to_128k=0,
179
- cost_per_1m_cached_tokens_up_to_128k=0,
180
- cost_per_1m_input_tokens_longer_than_128k=0,
181
- cost_per_1m_output_tokens_longer_than_128k=0,
182
- cost_per_1m_cached_tokens_longer_than_128k=0,
185
+ provider=pg.oneof(['Google GenAI', 'VertexAI']),
186
+ model_type='instruction-tuned',
187
+ description=(
188
+ 'Gemini 2.0 Flash model (version 001).'
189
+ ),
190
+ release_date=datetime.datetime(2025, 2, 5),
191
+ input_modalities=GeminiModelInfo.ALL_SUPPORTED_INPUT_TYPES,
192
+ context_length=lf.ModelInfo.ContextLength(
193
+ max_input_tokens=1_048_576,
194
+ max_output_tokens=8_192,
195
+ ),
196
+ pricing=GeminiModelInfo.Pricing(
197
+ cost_per_1m_cached_input_tokens=0.025,
198
+ cost_per_1m_input_tokens=0.1,
199
+ cost_per_1m_output_tokens=0.4,
200
+ ),
201
+ rate_limits=lf.ModelInfo.RateLimits(
202
+ # Tier 4 rate limits
203
+ max_requests_per_minute=2000,
204
+ max_tokens_per_minute=4_000_000,
205
+ ),
183
206
  ),
184
- 'gemini-exp-1206': pg.Dict(
185
- latest_update='2024-12-06',
186
- experimental=True,
207
+ # Gemini 2.0 Flash Lite.
208
+ GeminiModelInfo(
209
+ model_id='gemini-2.0-flash-lite-preview-02-05',
187
210
  in_service=True,
188
- supported_modalities=ALL_MODALITIES,
189
- rpm_free=10,
190
- tpm_free=4_000_000,
191
- rpm_paid=0,
192
- tpm_paid=0,
193
- cost_per_1m_input_tokens_up_to_128k=0,
194
- cost_per_1m_output_tokens_up_to_128k=0,
195
- cost_per_1m_cached_tokens_up_to_128k=0,
196
- cost_per_1m_input_tokens_longer_than_128k=0,
197
- cost_per_1m_output_tokens_longer_than_128k=0,
198
- cost_per_1m_cached_tokens_longer_than_128k=0,
211
+ provider=pg.oneof(['Google GenAI', 'VertexAI']),
212
+ model_type='instruction-tuned',
213
+ description=(
214
+ 'Gemini 2.0 Lite preview model.'
215
+ ),
216
+ release_date=datetime.datetime(2025, 2, 5),
217
+ input_modalities=GeminiModelInfo.ALL_SUPPORTED_INPUT_TYPES,
218
+ context_length=lf.ModelInfo.ContextLength(
219
+ max_input_tokens=1_048_576,
220
+ max_output_tokens=8_192,
221
+ ),
222
+ pricing=GeminiModelInfo.Pricing(
223
+ cost_per_1m_cached_input_tokens=0.01875,
224
+ cost_per_1m_input_tokens=0.075,
225
+ cost_per_1m_output_tokens=0.3,
226
+ ),
227
+ rate_limits=lf.ModelInfo.RateLimits(
228
+ # Tier 4 rate limits
229
+ max_requests_per_minute=4000,
230
+ max_tokens_per_minute=4_000_000,
231
+ ),
199
232
  ),
200
- 'learnlm-1.5-pro-experimental': pg.Dict(
201
- latest_update='2024-11-19',
202
- experimental=True,
233
+ # Gemini 1.5 Flash.
234
+ GeminiModelInfo(
235
+ model_id='gemini-1.5-flash',
236
+ alias_for='gemini-1.5-flash-002',
237
+ provider=pg.oneof(['Google GenAI', 'VertexAI']),
203
238
  in_service=True,
204
- supported_modalities=ALL_MODALITIES,
205
- rpm_free=10,
206
- tpm_free=4_000_000,
207
- rpm_paid=0,
208
- tpm_paid=0,
209
- cost_per_1m_input_tokens_up_to_128k=0,
210
- cost_per_1m_output_tokens_up_to_128k=0,
211
- cost_per_1m_cached_tokens_up_to_128k=0,
212
- cost_per_1m_input_tokens_longer_than_128k=0,
213
- cost_per_1m_output_tokens_longer_than_128k=0,
214
- cost_per_1m_cached_tokens_longer_than_128k=0,
239
+ model_type='instruction-tuned',
240
+ description=(
241
+ 'Gemini 1.5 Flash model (latest stable).'
242
+ ),
243
+ release_date=datetime.datetime(2024, 9, 30),
244
+ input_modalities=GeminiModelInfo.ALL_SUPPORTED_INPUT_TYPES,
245
+ context_length=lf.ModelInfo.ContextLength(
246
+ max_input_tokens=1_048_576,
247
+ max_output_tokens=8_192,
248
+ ),
249
+ pricing=GeminiModelInfo.Pricing(
250
+ cost_per_1m_cached_input_tokens=0.01875,
251
+ cost_per_1m_input_tokens=0.075,
252
+ cost_per_1m_output_tokens=0.3,
253
+ cost_per_1m_cached_input_tokens_with_prompt_longer_than_128k=0.0375,
254
+ cost_per_1m_input_tokens_with_prompt_longer_than_128k=0.15,
255
+ cost_per_1m_output_tokens_with_prompt_longer_than_128k=0.6,
256
+ ),
257
+ rate_limits=lf.ModelInfo.RateLimits(
258
+ # Tier 4 rate limits
259
+ max_requests_per_minute=2000,
260
+ max_tokens_per_minute=4_000_000,
261
+ ),
215
262
  ),
216
- 'gemini-exp-1114': pg.Dict(
217
- latest_update='2024-11-14',
218
- experimental=True,
263
+ GeminiModelInfo(
264
+ model_id='gemini-1.5-flash-001',
219
265
  in_service=True,
220
- supported_modalities=ALL_MODALITIES,
221
- rpm_free=10,
222
- tpm_free=4_000_000,
223
- rpm_paid=0,
224
- tpm_paid=0,
225
- cost_per_1m_input_tokens_up_to_128k=0,
226
- cost_per_1m_output_tokens_up_to_128k=0,
227
- cost_per_1m_cached_tokens_up_to_128k=0,
228
- cost_per_1m_input_tokens_longer_than_128k=0,
229
- cost_per_1m_output_tokens_longer_than_128k=0,
230
- cost_per_1m_cached_tokens_longer_than_128k=0,
266
+ provider=pg.oneof(['Google GenAI', 'VertexAI']),
267
+ model_type='instruction-tuned',
268
+ description=(
269
+ 'Gemini 1.5 Flash model (version 001).'
270
+ ),
271
+ input_modalities=GeminiModelInfo.ALL_SUPPORTED_INPUT_TYPES,
272
+ context_length=lf.ModelInfo.ContextLength(
273
+ max_input_tokens=1_048_576,
274
+ max_output_tokens=8_192,
275
+ ),
276
+ pricing=GeminiModelInfo.Pricing(
277
+ cost_per_1m_cached_input_tokens=0.01875,
278
+ cost_per_1m_input_tokens=0.075,
279
+ cost_per_1m_output_tokens=0.3,
280
+ cost_per_1m_cached_input_tokens_with_prompt_longer_than_128k=0.0375,
281
+ cost_per_1m_input_tokens_with_prompt_longer_than_128k=0.15,
282
+ cost_per_1m_output_tokens_with_prompt_longer_than_128k=0.6,
283
+ ),
284
+ rate_limits=lf.ModelInfo.RateLimits(
285
+ # Tier 4 rate limits
286
+ max_requests_per_minute=2000,
287
+ max_tokens_per_minute=4_000_000,
288
+ ),
231
289
  ),
232
- 'gemini-1.5-flash-latest': pg.Dict(
233
- latest_update='2024-09-30',
290
+ GeminiModelInfo(
291
+ model_id='gemini-1.5-flash-002',
234
292
  in_service=True,
235
- supported_modalities=ALL_MODALITIES,
236
- rpm_free=15,
237
- tpm_free=1_000_000,
238
- rpm_paid=2000,
239
- tpm_paid=4_000_000,
240
- cost_per_1m_input_tokens_up_to_128k=0.075,
241
- cost_per_1m_output_tokens_up_to_128k=0.3,
242
- cost_per_1m_cached_tokens_up_to_128k=0.01875,
243
- cost_per_1m_input_tokens_longer_than_128k=0.15,
244
- cost_per_1m_output_tokens_longer_than_128k=0.6,
245
- cost_per_1m_cached_tokens_longer_than_128k=0.0375,
293
+ provider=pg.oneof(['Google GenAI', 'VertexAI']),
294
+ model_type='instruction-tuned',
295
+ description=(
296
+ 'Gemini 1.5 Flash model (version 002).'
297
+ ),
298
+ input_modalities=GeminiModelInfo.ALL_SUPPORTED_INPUT_TYPES,
299
+ context_length=lf.ModelInfo.ContextLength(
300
+ max_input_tokens=1_048_576,
301
+ max_output_tokens=8_192,
302
+ ),
303
+ pricing=GeminiModelInfo.Pricing(
304
+ cost_per_1m_cached_input_tokens=0.01875,
305
+ cost_per_1m_input_tokens=0.075,
306
+ cost_per_1m_output_tokens=0.3,
307
+ cost_per_1m_cached_input_tokens_with_prompt_longer_than_128k=0.0375,
308
+ cost_per_1m_input_tokens_with_prompt_longer_than_128k=0.15,
309
+ cost_per_1m_output_tokens_with_prompt_longer_than_128k=0.6,
310
+ ),
311
+ rate_limits=lf.ModelInfo.RateLimits(
312
+ # Tier 4 rate limits
313
+ max_requests_per_minute=2000,
314
+ max_tokens_per_minute=4_000_000,
315
+ ),
246
316
  ),
247
- 'gemini-1.5-flash': pg.Dict(
248
- latest_update='2024-09-30',
317
+ # Gemini 1.5 Flash-8B.
318
+ GeminiModelInfo(
319
+ model_id='gemini-1.5-flash-8b',
249
320
  in_service=True,
250
- supported_modalities=ALL_MODALITIES,
251
- rpm_free=15,
252
- tpm_free=1_000_000,
253
- rpm_paid=2000,
254
- tpm_paid=4_000_000,
255
- cost_per_1m_input_tokens_up_to_128k=0.075,
256
- cost_per_1m_output_tokens_up_to_128k=0.3,
257
- cost_per_1m_cached_tokens_up_to_128k=0.01875,
258
- cost_per_1m_input_tokens_longer_than_128k=0.15,
259
- cost_per_1m_output_tokens_longer_than_128k=0.6,
260
- cost_per_1m_cached_tokens_longer_than_128k=0.0375,
321
+ provider='Google GenAI',
322
+ model_type='instruction-tuned',
323
+ description=(
324
+ 'Gemini 1.5 Flash 8B model (latest stable).'
325
+ ),
326
+ input_modalities=GeminiModelInfo.ALL_SUPPORTED_INPUT_TYPES,
327
+ context_length=lf.ModelInfo.ContextLength(
328
+ max_input_tokens=1_048_576,
329
+ max_output_tokens=8_192,
330
+ ),
331
+ pricing=GeminiModelInfo.Pricing(
332
+ cost_per_1m_cached_input_tokens=0.01,
333
+ cost_per_1m_input_tokens=0.0375,
334
+ cost_per_1m_output_tokens=0.15,
335
+ cost_per_1m_cached_input_tokens_with_prompt_longer_than_128k=0.02,
336
+ cost_per_1m_input_tokens_with_prompt_longer_than_128k=0.075,
337
+ cost_per_1m_output_tokens_with_prompt_longer_than_128k=0.3,
338
+ ),
339
+ rate_limits=lf.ModelInfo.RateLimits(
340
+ # Tier 4 rate limits
341
+ max_requests_per_minute=4000,
342
+ max_tokens_per_minute=4_000_000,
343
+ ),
261
344
  ),
262
- 'gemini-1.5-flash-001': pg.Dict(
263
- latest_update='2024-09-30',
345
+ GeminiModelInfo(
346
+ model_id='gemini-1.5-flash-8b-001',
264
347
  in_service=True,
265
- supported_modalities=ALL_MODALITIES,
266
- rpm_free=15,
267
- tpm_free=1_000_000,
268
- rpm_paid=2000,
269
- tpm_paid=4_000_000,
270
- cost_per_1m_input_tokens_up_to_128k=0.075,
271
- cost_per_1m_output_tokens_up_to_128k=0.3,
272
- cost_per_1m_cached_tokens_up_to_128k=0.01875,
273
- cost_per_1m_input_tokens_longer_than_128k=0.15,
274
- cost_per_1m_output_tokens_longer_than_128k=0.6,
275
- cost_per_1m_cached_tokens_longer_than_128k=0.0375,
348
+ provider='Google GenAI',
349
+ model_type='instruction-tuned',
350
+ description=(
351
+ 'Gemini 1.5 Flash 8B model (version 001).'
352
+ ),
353
+ input_modalities=GeminiModelInfo.ALL_SUPPORTED_INPUT_TYPES,
354
+ context_length=lf.ModelInfo.ContextLength(
355
+ max_input_tokens=1_048_576,
356
+ max_output_tokens=8_192,
357
+ ),
358
+ pricing=GeminiModelInfo.Pricing(
359
+ cost_per_1m_cached_input_tokens=0.01,
360
+ cost_per_1m_input_tokens=0.0375,
361
+ cost_per_1m_output_tokens=0.15,
362
+ cost_per_1m_cached_input_tokens_with_prompt_longer_than_128k=0.02,
363
+ cost_per_1m_input_tokens_with_prompt_longer_than_128k=0.075,
364
+ cost_per_1m_output_tokens_with_prompt_longer_than_128k=0.3,
365
+ ),
366
+ rate_limits=lf.ModelInfo.RateLimits(
367
+ # Tier 4 rate limits
368
+ max_requests_per_minute=4000,
369
+ max_tokens_per_minute=4_000_000,
370
+ ),
276
371
  ),
277
- 'gemini-1.5-flash-002': pg.Dict(
278
- latest_update='2024-09-30',
372
+ # Gemini 1.5 Pro.
373
+ GeminiModelInfo(
374
+ model_id='gemini-1.5-pro',
375
+ alias_for='gemini-1.5-pro-002',
279
376
  in_service=True,
280
- supported_modalities=ALL_MODALITIES,
281
- rpm_free=15,
282
- tpm_free=1_000_000,
283
- rpm_paid=2000,
284
- tpm_paid=4_000_000,
285
- cost_per_1m_input_tokens_up_to_128k=0.075,
286
- cost_per_1m_output_tokens_up_to_128k=0.3,
287
- cost_per_1m_cached_tokens_up_to_128k=0.01875,
288
- cost_per_1m_input_tokens_longer_than_128k=0.15,
289
- cost_per_1m_output_tokens_longer_than_128k=0.6,
290
- cost_per_1m_cached_tokens_longer_than_128k=0.0375,
377
+ provider=pg.oneof(['Google GenAI', 'VertexAI']),
378
+ model_type='instruction-tuned',
379
+ description=(
380
+ 'Gemini 1.5 Pro model (latest stable).'
381
+ ),
382
+ input_modalities=GeminiModelInfo.ALL_SUPPORTED_INPUT_TYPES,
383
+ context_length=lf.ModelInfo.ContextLength(
384
+ max_input_tokens=2_097_152,
385
+ max_output_tokens=8_192,
386
+ ),
387
+ pricing=GeminiModelInfo.Pricing(
388
+ cost_per_1m_cached_input_tokens=0.3125,
389
+ cost_per_1m_input_tokens=1.25,
390
+ cost_per_1m_output_tokens=5,
391
+ cost_per_1m_cached_input_tokens_with_prompt_longer_than_128k=0.625,
392
+ cost_per_1m_input_tokens_with_prompt_longer_than_128k=2.5,
393
+ cost_per_1m_output_tokens_with_prompt_longer_than_128k=10,
394
+ ),
395
+ rate_limits=lf.ModelInfo.RateLimits(
396
+ # Tier 4 rate limits
397
+ max_requests_per_minute=1000,
398
+ max_tokens_per_minute=4_000_000,
399
+ ),
291
400
  ),
292
- 'gemini-1.5-flash-8b': pg.Dict(
293
- latest_update='2024-10-30',
401
+ GeminiModelInfo(
402
+ model_id='gemini-1.5-pro-001',
294
403
  in_service=True,
295
- supported_modalities=ALL_MODALITIES,
296
- rpm_free=15,
297
- tpm_free=1_000_000,
298
- rpm_paid=4000,
299
- tpm_paid=4_000_000,
300
- cost_per_1m_input_tokens_up_to_128k=0.0375,
301
- cost_per_1m_output_tokens_up_to_128k=0.15,
302
- cost_per_1m_cached_tokens_up_to_128k=0.01,
303
- cost_per_1m_input_tokens_longer_than_128k=0.075,
304
- cost_per_1m_output_tokens_longer_than_128k=0.3,
305
- cost_per_1m_cached_tokens_longer_than_128k=0.02,
404
+ provider=pg.oneof(['Google GenAI', 'VertexAI']),
405
+ model_type='instruction-tuned',
406
+ description=(
407
+ 'Gemini 1.5 Pro model (version 001).'
408
+ ),
409
+ input_modalities=GeminiModelInfo.ALL_SUPPORTED_INPUT_TYPES,
410
+ context_length=lf.ModelInfo.ContextLength(
411
+ max_input_tokens=2_097_152,
412
+ max_output_tokens=8_192,
413
+ ),
414
+ pricing=GeminiModelInfo.Pricing(
415
+ cost_per_1m_cached_input_tokens=0.3125,
416
+ cost_per_1m_input_tokens=1.25,
417
+ cost_per_1m_output_tokens=5,
418
+ cost_per_1m_cached_input_tokens_with_prompt_longer_than_128k=0.625,
419
+ cost_per_1m_input_tokens_with_prompt_longer_than_128k=2.5,
420
+ cost_per_1m_output_tokens_with_prompt_longer_than_128k=10,
421
+ ),
422
+ rate_limits=lf.ModelInfo.RateLimits(
423
+ # Tier 4 rate limits
424
+ max_requests_per_minute=1000,
425
+ max_tokens_per_minute=4_000_000,
426
+ ),
306
427
  ),
307
- 'gemini-1.5-flash-8b-001': pg.Dict(
308
- latest_update='2024-10-30',
428
+ GeminiModelInfo(
429
+ model_id='gemini-1.5-pro-002',
309
430
  in_service=True,
310
- supported_modalities=ALL_MODALITIES,
311
- rpm_free=15,
312
- tpm_free=1_000_000,
313
- rpm_paid=4000,
314
- tpm_paid=4_000_000,
315
- cost_per_1m_input_tokens_up_to_128k=0.0375,
316
- cost_per_1m_output_tokens_up_to_128k=0.15,
317
- cost_per_1m_cached_tokens_up_to_128k=0.01,
318
- cost_per_1m_input_tokens_longer_than_128k=0.075,
319
- cost_per_1m_output_tokens_longer_than_128k=0.3,
320
- cost_per_1m_cached_tokens_longer_than_128k=0.02,
431
+ provider=pg.oneof(['Google GenAI', 'VertexAI']),
432
+ model_type='instruction-tuned',
433
+ description=(
434
+ 'Gemini 1.5 Pro model (version 002).'
435
+ ),
436
+ input_modalities=GeminiModelInfo.ALL_SUPPORTED_INPUT_TYPES,
437
+ context_length=lf.ModelInfo.ContextLength(
438
+ max_input_tokens=2_097_152,
439
+ max_output_tokens=8_192,
440
+ ),
441
+ pricing=GeminiModelInfo.Pricing(
442
+ cost_per_1m_cached_input_tokens=0.3125,
443
+ cost_per_1m_input_tokens=1.25,
444
+ cost_per_1m_output_tokens=5,
445
+ cost_per_1m_cached_input_tokens_with_prompt_longer_than_128k=0.625,
446
+ cost_per_1m_input_tokens_with_prompt_longer_than_128k=2.5,
447
+ cost_per_1m_output_tokens_with_prompt_longer_than_128k=10,
448
+ ),
449
+ rate_limits=lf.ModelInfo.RateLimits(
450
+ # Tier 4 rate limits
451
+ max_requests_per_minute=1000,
452
+ max_tokens_per_minute=4_000_000,
453
+ ),
321
454
  ),
322
- 'gemini-1.5-pro-latest': pg.Dict(
323
- latest_update='2024-09-30',
455
+
456
+ #
457
+ # Experimental models.
458
+ #
459
+
460
+ GeminiModelInfo(
461
+ model_id='gemini-2.0-pro-exp-02-05',
324
462
  in_service=True,
325
- supported_modalities=ALL_MODALITIES,
326
- rpm_free=2,
327
- tpm_free=32_000,
328
- rpm_paid=1000,
329
- tpm_paid=4_000_000,
330
- cost_per_1m_input_tokens_up_to_128k=1.25,
331
- cost_per_1m_output_tokens_up_to_128k=5.00,
332
- cost_per_1m_cached_tokens_up_to_128k=0.3125,
333
- cost_per_1m_input_tokens_longer_than_128k=2.5,
334
- cost_per_1m_output_tokens_longer_than_128k=10.00,
335
- cost_per_1m_cached_tokens_longer_than_128k=0.625,
463
+ experimental=True,
464
+ provider=pg.oneof(['Google GenAI', 'VertexAI']),
465
+ model_type='instruction-tuned',
466
+ description=(
467
+ 'Gemini 2.0 Pro experimental model (02/05/2025).'
468
+ ),
469
+ release_date=datetime.datetime(2025, 2, 5),
470
+ input_modalities=GeminiModelInfo.ALL_SUPPORTED_INPUT_TYPES,
471
+ context_length=lf.ModelInfo.ContextLength(
472
+ max_input_tokens=1_048_576,
473
+ max_output_tokens=8_192,
474
+ ),
336
475
  ),
337
- 'gemini-1.5-pro': pg.Dict(
338
- latest_update='2024-09-30',
476
+ GeminiModelInfo(
477
+ model_id='gemini-2.0-flash-thinking-exp-01-21',
339
478
  in_service=True,
340
- supported_modalities=ALL_MODALITIES,
341
- rpm_free=2,
342
- tpm_free=32_000,
343
- rpm_paid=1000,
344
- tpm_paid=4_000_000,
345
- cost_per_1m_input_tokens_up_to_128k=1.25,
346
- cost_per_1m_output_tokens_up_to_128k=5.00,
347
- cost_per_1m_cached_tokens_up_to_128k=0.3125,
348
- cost_per_1m_input_tokens_longer_than_128k=2.5,
349
- cost_per_1m_output_tokens_longer_than_128k=10.00,
350
- cost_per_1m_cached_tokens_longer_than_128k=0.625,
479
+ experimental=True,
480
+ provider=pg.oneof(['Google GenAI', 'VertexAI']),
481
+ model_type='thinking',
482
+ description=(
483
+ 'Gemini 2.0 Flash thinking experimental model (01/21/2025).'
484
+ ),
485
+ release_date=datetime.datetime(2025, 1, 21),
486
+ input_modalities=GeminiModelInfo.INPUT_IMAGE_TYPES,
487
+ context_length=lf.ModelInfo.ContextLength(
488
+ max_input_tokens=1_048_576,
489
+ max_output_tokens=8_192,
490
+ ),
351
491
  ),
352
- 'gemini-1.5-pro-001': pg.Dict(
353
- latest_update='2024-09-30',
492
+ GeminiModelInfo(
493
+ model_id='gemini-exp-1206',
354
494
  in_service=True,
355
- supported_modalities=ALL_MODALITIES,
356
- rpm_free=2,
357
- tpm_free=32_000,
358
- rpm_paid=1000,
359
- tpm_paid=4_000_000,
360
- cost_per_1m_input_tokens_up_to_128k=1.25,
361
- cost_per_1m_output_tokens_up_to_128k=5.00,
362
- cost_per_1m_cached_tokens_up_to_128k=0.3125,
363
- cost_per_1m_input_tokens_longer_than_128k=2.5,
364
- cost_per_1m_output_tokens_longer_than_128k=10.00,
365
- cost_per_1m_cached_tokens_longer_than_128k=0.625,
495
+ experimental=True,
496
+ provider='Google GenAI',
497
+ model_type='instruction-tuned',
498
+ description=(
499
+ 'Gemini year 1 experimental model (12/06/2024)'
500
+ ),
501
+ release_date=datetime.datetime(2025, 1, 21),
502
+ input_modalities=GeminiModelInfo.ALL_SUPPORTED_INPUT_TYPES,
503
+ context_length=lf.ModelInfo.ContextLength(
504
+ max_input_tokens=1_048_576,
505
+ max_output_tokens=8_192,
506
+ ),
366
507
  ),
367
- 'gemini-1.5-pro-002': pg.Dict(
368
- latest_update='2024-09-30',
508
+ GeminiModelInfo(
509
+ model_id='learnlm-1.5-pro-experimental',
369
510
  in_service=True,
370
- supported_modalities=ALL_MODALITIES,
371
- rpm_free=2,
372
- tpm_free=32_000,
373
- rpm_paid=1000,
374
- tpm_paid=4_000_000,
375
- cost_per_1m_input_tokens_up_to_128k=1.25,
376
- cost_per_1m_output_tokens_up_to_128k=5.00,
377
- cost_per_1m_cached_tokens_up_to_128k=0.3125,
378
- cost_per_1m_input_tokens_longer_than_128k=2.5,
379
- cost_per_1m_output_tokens_longer_than_128k=10.00,
380
- cost_per_1m_cached_tokens_longer_than_128k=0.625,
381
- ),
382
- 'gemini-1.0-pro': pg.Dict(
383
- in_service=False,
384
- supported_modalities=TEXT_ONLY,
385
- rpm_free=15,
386
- tpm_free=32_000,
387
- rpm_paid=360,
388
- tpm_paid=120_000,
389
- cost_per_1m_input_tokens_up_to_128k=0.5,
390
- cost_per_1m_output_tokens_up_to_128k=1.5,
391
- cost_per_1m_cached_tokens_up_to_128k=0,
392
- cost_per_1m_input_tokens_longer_than_128k=0.5,
393
- cost_per_1m_output_tokens_longer_than_128k=1.5,
394
- cost_per_1m_cached_tokens_longer_than_128k=0,
511
+ experimental=True,
512
+ provider='Google GenAI',
513
+ model_type='instruction-tuned',
514
+ description=(
515
+ 'Gemini experimental model on learning science principles.'
516
+ ),
517
+ url='https://ai.google.dev/gemini-api/docs/learnlm',
518
+ release_date=datetime.datetime(2024, 11, 19),
519
+ input_modalities=GeminiModelInfo.ALL_SUPPORTED_INPUT_TYPES,
520
+ context_length=lf.ModelInfo.ContextLength(
521
+ max_input_tokens=1_048_576,
522
+ max_output_tokens=8_192,
523
+ ),
395
524
  ),
396
- }
525
+ ]
526
+
527
+
528
+ _SUPPORTED_MODELS_BY_ID = {m.model_id: m for m in SUPPORTED_MODELS}
397
529
 
398
530
 
399
531
  @pg.use_init_args(['model'])
@@ -402,58 +534,18 @@ class Gemini(rest.REST):
402
534
 
403
535
  model: pg.typing.Annotated[
404
536
  pg.typing.Enum(
405
- pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
537
+ pg.MISSING_VALUE, [m.model_id for m in SUPPORTED_MODELS]
406
538
  ),
407
539
  'The name of the model to use.',
408
540
  ]
409
541
 
410
- @property
411
- def supported_modalities(self) -> list[str]:
412
- """Returns the list of supported modalities."""
413
- return SUPPORTED_MODELS_AND_SETTINGS[self.model].supported_modalities
414
-
415
- @property
416
- def max_concurrency(self) -> int:
417
- """Returns the maximum number of concurrent requests."""
418
- return self.rate_to_max_concurrency(
419
- requests_per_min=max(
420
- SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm_free,
421
- SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm_paid
422
- ),
423
- tokens_per_min=max(
424
- SUPPORTED_MODELS_AND_SETTINGS[self.model].tpm_free,
425
- SUPPORTED_MODELS_AND_SETTINGS[self.model].tpm_paid,
426
- ),
427
- )
428
-
429
- def estimate_cost(
430
- self,
431
- num_input_tokens: int,
432
- num_output_tokens: int
433
- ) -> float | None:
434
- """Estimate the cost based on usage."""
435
- entry = SUPPORTED_MODELS_AND_SETTINGS[self.model]
436
- if num_input_tokens < 128_000:
437
- cost_per_1m_input_tokens = entry.cost_per_1m_input_tokens_up_to_128k
438
- cost_per_1m_output_tokens = entry.cost_per_1m_output_tokens_up_to_128k
439
- else:
440
- cost_per_1m_input_tokens = entry.cost_per_1m_input_tokens_longer_than_128k
441
- cost_per_1m_output_tokens = (
442
- entry.cost_per_1m_output_tokens_longer_than_128k
443
- )
444
- return (
445
- cost_per_1m_input_tokens * num_input_tokens
446
- + cost_per_1m_output_tokens * num_output_tokens
447
- ) / 1000_000
448
-
449
- @property
450
- def model_id(self) -> str:
451
- """Returns a string to identify the model."""
452
- return self.model
542
+ @functools.cached_property
543
+ def model_info(self) -> GeminiModelInfo:
544
+ return _SUPPORTED_MODELS_BY_ID[self.model]
453
545
 
454
546
  @classmethod
455
547
  def dir(cls):
456
- return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]
548
+ return [m.model_id for m in SUPPORTED_MODELS if m.in_service]
457
549
 
458
550
  @property
459
551
  def headers(self):
@@ -510,7 +602,7 @@ class Gemini(rest.REST):
510
602
  elif isinstance(lf_chunk, lf_modalities.Mime):
511
603
  try:
512
604
  modalities = lf_chunk.make_compatible(
513
- self.supported_modalities + ['text/plain']
605
+ self.model_info.input_modalities + ['text/plain']
514
606
  )
515
607
  if isinstance(modalities, lf_modalities.Mime):
516
608
  modalities = [modalities]
@@ -527,7 +619,9 @@ class Gemini(rest.REST):
527
619
  except lf.ModalityError as e:
528
620
  raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
529
621
  else:
530
- raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
622
+ raise NotImplementedError(
623
+ f'Input conversion not implemented: {lf_chunk!r}'
624
+ )
531
625
  return dict(role='user', parts=parts)
532
626
 
533
627
  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
@@ -544,10 +638,6 @@ class Gemini(rest.REST):
544
638
  prompt_tokens=input_tokens,
545
639
  completion_tokens=output_tokens,
546
640
  total_tokens=input_tokens + output_tokens,
547
- estimated_cost=self.estimate_cost(
548
- num_input_tokens=input_tokens,
549
- num_output_tokens=output_tokens,
550
- ),
551
641
  ),
552
642
  )
553
643