langfun 0.1.2.dev202501050804__py3-none-any.whl → 0.1.2.dev202501090804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/__init__.py +0 -5
- langfun/core/coding/python/correction.py +4 -3
- langfun/core/coding/python/errors.py +10 -9
- langfun/core/coding/python/execution.py +23 -12
- langfun/core/coding/python/execution_test.py +21 -2
- langfun/core/coding/python/generation.py +18 -9
- langfun/core/concurrent.py +2 -3
- langfun/core/console.py +8 -3
- langfun/core/eval/base.py +2 -3
- langfun/core/eval/v2/reporting.py +15 -6
- langfun/core/language_model.py +7 -4
- langfun/core/language_model_test.py +15 -0
- langfun/core/llms/__init__.py +25 -26
- langfun/core/llms/cache/in_memory.py +6 -0
- langfun/core/llms/cache/in_memory_test.py +5 -0
- langfun/core/llms/deepseek.py +261 -0
- langfun/core/llms/deepseek_test.py +438 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +46 -320
- langfun/core/llms/google_genai_test.py +9 -204
- langfun/core/llms/openai.py +5 -0
- langfun/core/llms/vertexai.py +31 -359
- langfun/core/llms/vertexai_test.py +6 -166
- langfun/core/structured/mapping.py +13 -13
- langfun/core/structured/mapping_test.py +2 -2
- langfun/core/structured/schema.py +16 -8
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/METADATA +19 -14
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/RECORD +32 -30
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/WHEEL +1 -1
- langfun/core/text_formatting.py +0 -168
- langfun/core/text_formatting_test.py +0 -65
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,507 @@
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Gemini REST API (Shared by Google GenAI and Vertex AI)."""
|
15
|
+
|
16
|
+
import base64
|
17
|
+
from typing import Any
|
18
|
+
|
19
|
+
import langfun.core as lf
|
20
|
+
from langfun.core import modalities as lf_modalities
|
21
|
+
from langfun.core.llms import rest
|
22
|
+
import pyglove as pg
|
23
|
+
|
24
|
+
# MIME types accepted as prompt input, grouped by modality.

# Still-image formats.
IMAGE_TYPES = [
    'image/png', 'image/jpeg', 'image/webp', 'image/heic', 'image/heif',
]

# Audio formats.
AUDIO_TYPES = [
    'audio/aac', 'audio/flac', 'audio/mp3', 'audio/m4a', 'audio/mpeg',
    'audio/mpga', 'audio/mp4', 'audio/opus', 'audio/pcm', 'audio/wav',
    'audio/webm',
]

# Video formats.
VIDEO_TYPES = [
    'video/mov', 'video/mpeg', 'video/mpegps', 'video/mpg', 'video/mp4',
    'video/webm', 'video/wmv', 'video/x-flv', 'video/3gpp',
    'video/quicktime',
]

# Document and structured-text formats.
DOCUMENT_TYPES = [
    'application/pdf', 'text/plain', 'text/csv', 'text/html', 'text/xml',
    'text/x-script.python', 'application/json',
]

# Modality list for models that accept no MIME input at all.
TEXT_ONLY = []

# Every MIME type above, in image / audio / video / document order.
ALL_MODALITIES = [
    *IMAGE_TYPES,
    *AUDIO_TYPES,
    *VIDEO_TYPES,
    *DOCUMENT_TYPES,
]
|
76
|
+
|
77
|
+
# Settings shared by the free, experimental models (no paid tier, no cost).
_EXPERIMENTAL_SETTINGS = dict(
    experimental=True,
    in_service=True,
    supported_modalities=ALL_MODALITIES,
    rpm_free=10,
    tpm_free=4_000_000,
    rpm_paid=0,
    tpm_paid=0,
    cost_per_1m_input_tokens_up_to_128k=0,
    cost_per_1m_output_tokens_up_to_128k=0,
    cost_per_1m_cached_tokens_up_to_128k=0,
    cost_per_1m_input_tokens_longer_than_128k=0,
    cost_per_1m_output_tokens_longer_than_128k=0,
    cost_per_1m_cached_tokens_longer_than_128k=0,
)

# Settings shared by the Gemini 1.5 Flash family.
_FLASH_1_5_SETTINGS = dict(
    in_service=True,
    supported_modalities=ALL_MODALITIES,
    rpm_free=15,
    tpm_free=1_000_000,
    rpm_paid=2000,
    tpm_paid=4_000_000,
    cost_per_1m_input_tokens_up_to_128k=0.075,
    cost_per_1m_output_tokens_up_to_128k=0.3,
    cost_per_1m_cached_tokens_up_to_128k=0.01875,
    cost_per_1m_input_tokens_longer_than_128k=0.15,
    cost_per_1m_output_tokens_longer_than_128k=0.6,
    cost_per_1m_cached_tokens_longer_than_128k=0.0375,
)

# Settings shared by the Gemini 1.5 Flash-8B family.
_FLASH_8B_1_5_SETTINGS = dict(
    in_service=True,
    supported_modalities=ALL_MODALITIES,
    rpm_free=15,
    tpm_free=1_000_000,
    rpm_paid=4000,
    tpm_paid=4_000_000,
    cost_per_1m_input_tokens_up_to_128k=0.0375,
    cost_per_1m_output_tokens_up_to_128k=0.15,
    cost_per_1m_cached_tokens_up_to_128k=0.01,
    cost_per_1m_input_tokens_longer_than_128k=0.075,
    cost_per_1m_output_tokens_longer_than_128k=0.3,
    cost_per_1m_cached_tokens_longer_than_128k=0.02,
)

# Settings shared by the Gemini 1.5 Pro family.
_PRO_1_5_SETTINGS = dict(
    in_service=True,
    supported_modalities=ALL_MODALITIES,
    rpm_free=2,
    tpm_free=32_000,
    rpm_paid=1000,
    tpm_paid=4_000_000,
    cost_per_1m_input_tokens_up_to_128k=1.25,
    cost_per_1m_output_tokens_up_to_128k=5.00,
    cost_per_1m_cached_tokens_up_to_128k=0.3125,
    cost_per_1m_input_tokens_longer_than_128k=2.5,
    cost_per_1m_output_tokens_longer_than_128k=10.00,
    cost_per_1m_cached_tokens_longer_than_128k=0.625,
)

SUPPORTED_MODELS_AND_SETTINGS = {
    # For automatic rate control and cost estimation, we explicitly register
    # supported models here. This may be inconvenient for new models, but it
    # helps us keep track of the models and their pricing.
    # Models and RPM are from
    # https://ai.google.dev/gemini-api/docs/models/gemini
    # Pricing in US dollars, from https://ai.google.dev/pricing
    # as of 2025-01-03.
    # NOTE: Please update google_genai.py, vertexai.py, __init__.py when
    # adding new models.
    # !!! PLEASE KEEP MODELS SORTED BY RELEASE DATE !!!
    'gemini-2.0-flash-thinking-exp-1219': pg.Dict(
        latest_update='2024-12-19', **_EXPERIMENTAL_SETTINGS
    ),
    'gemini-2.0-flash-exp': pg.Dict(
        latest_update='2024-12-11', **_EXPERIMENTAL_SETTINGS
    ),
    'gemini-exp-1206': pg.Dict(
        latest_update='2024-12-06', **_EXPERIMENTAL_SETTINGS
    ),
    'learnlm-1.5-pro-experimental': pg.Dict(
        latest_update='2024-11-19', **_EXPERIMENTAL_SETTINGS
    ),
    'gemini-exp-1114': pg.Dict(
        latest_update='2024-11-14', **_EXPERIMENTAL_SETTINGS
    ),
    'gemini-1.5-flash-latest': pg.Dict(
        latest_update='2024-09-30', **_FLASH_1_5_SETTINGS
    ),
    'gemini-1.5-flash': pg.Dict(
        latest_update='2024-09-30', **_FLASH_1_5_SETTINGS
    ),
    'gemini-1.5-flash-001': pg.Dict(
        latest_update='2024-09-30', **_FLASH_1_5_SETTINGS
    ),
    'gemini-1.5-flash-002': pg.Dict(
        latest_update='2024-09-30', **_FLASH_1_5_SETTINGS
    ),
    'gemini-1.5-flash-8b': pg.Dict(
        latest_update='2024-10-30', **_FLASH_8B_1_5_SETTINGS
    ),
    'gemini-1.5-flash-8b-001': pg.Dict(
        latest_update='2024-10-30', **_FLASH_8B_1_5_SETTINGS
    ),
    'gemini-1.5-pro-latest': pg.Dict(
        latest_update='2024-09-30', **_PRO_1_5_SETTINGS
    ),
    'gemini-1.5-pro': pg.Dict(
        latest_update='2024-09-30', **_PRO_1_5_SETTINGS
    ),
    'gemini-1.5-pro-001': pg.Dict(
        latest_update='2024-09-30', **_PRO_1_5_SETTINGS
    ),
    'gemini-1.5-pro-002': pg.Dict(
        latest_update='2024-09-30', **_PRO_1_5_SETTINGS
    ),
    # Out of service and text-only; carries no `latest_update` entry.
    'gemini-1.0-pro': pg.Dict(
        in_service=False,
        supported_modalities=TEXT_ONLY,
        rpm_free=15,
        tpm_free=32_000,
        rpm_paid=360,
        tpm_paid=120_000,
        cost_per_1m_input_tokens_up_to_128k=0.5,
        cost_per_1m_output_tokens_up_to_128k=1.5,
        cost_per_1m_cached_tokens_up_to_128k=0,
        cost_per_1m_input_tokens_longer_than_128k=0.5,
        cost_per_1m_output_tokens_longer_than_128k=1.5,
        cost_per_1m_cached_tokens_longer_than_128k=0,
    ),
}
|
333
|
+
|
334
|
+
|
335
|
+
@pg.use_init_args(['model'])
class Gemini(rest.REST):
  """Language model served through the Gemini REST API.

  Holds the request/response translation shared by the Google GenAI and
  Vertex AI backends: it turns a langfun message plus sampling options into a
  Gemini JSON request, and a Gemini JSON response back into langfun samples
  with usage and an estimated cost.
  """

  model: pg.typing.Annotated[
      pg.typing.Enum(
          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
      ),
      'The name of the model to use.',
  ]

  @property
  def supported_modalities(self) -> list[str]:
    """Returns the list of supported modalities."""
    return SUPPORTED_MODELS_AND_SETTINGS[self.model].supported_modalities

  @property
  def max_concurrency(self) -> int:
    """Returns the maximum number of concurrent requests."""
    settings = SUPPORTED_MODELS_AND_SETTINGS[self.model]
    # Use the larger of the free and paid quotas registered for the model.
    return self.rate_to_max_concurrency(
        requests_per_min=max(settings.rpm_free, settings.rpm_paid),
        tokens_per_min=max(settings.tpm_free, settings.tpm_paid),
    )

  def estimate_cost(
      self,
      num_input_tokens: int,
      num_output_tokens: int
  ) -> float | None:
    """Estimates the cost in US dollars based on token usage.

    Args:
      num_input_tokens: Number of prompt tokens. Also selects the pricing
        tier: prompts of up to 128k tokens use the `*_up_to_128k` rates,
        longer prompts use the `*_longer_than_128k` rates.
      num_output_tokens: Number of generated tokens.

    Returns:
      The estimated cost, computed from the per-1M-token rates registered for
      the model in `SUPPORTED_MODELS_AND_SETTINGS`.
    """
    entry = SUPPORTED_MODELS_AND_SETTINGS[self.model]
    # A prompt of exactly 128k tokens still belongs to the "up to 128k" tier.
    if num_input_tokens <= 128_000:
      cost_per_1m_input_tokens = entry.cost_per_1m_input_tokens_up_to_128k
      cost_per_1m_output_tokens = entry.cost_per_1m_output_tokens_up_to_128k
    else:
      cost_per_1m_input_tokens = entry.cost_per_1m_input_tokens_longer_than_128k
      cost_per_1m_output_tokens = (
          entry.cost_per_1m_output_tokens_longer_than_128k
      )
    # Rates are quoted per one million tokens.
    return (
        cost_per_1m_input_tokens * num_input_tokens
        + cost_per_1m_output_tokens * num_output_tokens
    ) / 1_000_000

  @property
  def model_id(self) -> str:
    """Returns a string to identify the model."""
    return self.model

  @classmethod
  def dir(cls):
    """Returns the names of the registered models that are in service."""
    return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]

  @property
  def headers(self):
    """Returns the HTTP headers declaring a UTF-8 JSON request body."""
    return {
        'Content-Type': 'application/json; charset=utf-8',
    }

  def request(
      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
  ) -> dict[str, Any]:
    """Builds the JSON request body for `prompt` and `sampling_options`."""
    request = dict(
        generationConfig=self._generation_config(prompt, sampling_options)
    )
    request['contents'] = [self._content_from_message(prompt)]
    return request

  def _generation_config(
      self, prompt: lf.Message, options: lf.LMSamplingOptions
  ) -> dict[str, Any]:
    """Returns a dict as generation config for prompt and LMSamplingOptions.

    When the prompt carries a `json_schema` metadata entry, the schema is
    added as the response schema with a JSON response MIME type, and
    `prompt.metadata.formatted_text` is set to the prompt text followed by the
    schema.

    Args:
      prompt: The prompt message; its metadata may contain `json_schema`.
      options: Sampling options mapped onto Gemini's generation config keys.

    Returns:
      The `generationConfig` dict for the request.

    Raises:
      ValueError: If `json_schema` is present but is not a dict.
    """
    config = dict(
        temperature=options.temperature,
        maxOutputTokens=options.max_tokens,
        candidateCount=options.n,
        topK=options.top_k,
        topP=options.top_p,
        stopSequences=options.stop,
        seed=options.random_seed,
        responseLogprobs=options.logprobs,
        logprobs=options.top_logprobs,
    )

    if json_schema := prompt.metadata.get('json_schema'):
      if not isinstance(json_schema, dict):
        raise ValueError(
            f'`json_schema` must be a dict, got {json_schema!r}.'
        )
      json_schema = pg.to_json(json_schema)
      config['responseSchema'] = json_schema
      config['responseMimeType'] = 'application/json'
      prompt.metadata.formatted_text = (
          prompt.text
          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
          + pg.to_json_str(json_schema, json_indent=2)
      )
    return config

  def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
    """Gets generation content from langfun message.

    Args:
      prompt: The message whose chunks (strings and MIME objects) are
        converted into Gemini content parts.

    Returns:
      A dict with `role='user'` and the list of converted `parts`.

    Raises:
      lf.ModalityError: If a chunk is neither a string nor a MIME object, or
        if a MIME object cannot be made compatible with the modalities the
        model supports.
    """
    parts = []
    for lf_chunk in prompt.chunk():
      if isinstance(lf_chunk, str):
        parts.append({'text': lf_chunk})
      elif isinstance(lf_chunk, lf_modalities.Mime):
        try:
          # Plain text is always acceptable, whatever the model's modalities.
          modalities = lf_chunk.make_compatible(
              self.supported_modalities + ['text/plain']
          )
          if isinstance(modalities, lf_modalities.Mime):
            modalities = [modalities]
          for modality in modalities:
            if modality.is_text:
              parts.append({'text': modality.to_text()})
            else:
              parts.append({
                  'inlineData': {
                      'data': base64.b64encode(modality.to_bytes()).decode(),
                      'mimeType': modality.mime_type,
                  }
              })
        except lf.ModalityError as e:
          raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
      else:
        raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
    return dict(role='user', parts=parts)

  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
    """Converts a Gemini response payload into an `lf.LMSamplingResult`.

    Args:
      json: The decoded response body. Must contain `candidates` and
        `usageMetadata` with `promptTokenCount`.

    Returns:
      One sample per candidate, plus token usage and the estimated cost.
    """
    messages = [
        self._message_from_content_parts(candidate['content']['parts'])
        for candidate in json['candidates']
    ]
    usage = json['usageMetadata']
    input_tokens = usage['promptTokenCount']
    # `candidatesTokenCount` can be absent from the usage metadata (e.g. when
    # nothing was generated); count that as zero output tokens.
    output_tokens = usage.get('candidatesTokenCount', 0)
    return lf.LMSamplingResult(
        [lf.LMSample(message) for message in messages],
        usage=lf.LMSamplingUsage(
            prompt_tokens=input_tokens,
            completion_tokens=output_tokens,
            total_tokens=input_tokens + output_tokens,
            estimated_cost=self.estimate_cost(
                num_input_tokens=input_tokens,
                num_output_tokens=output_tokens,
            ),
        ),
    )

  def _message_from_content_parts(
      self, parts: list[dict[str, Any]]
  ) -> lf.Message:
    """Converts Gemini content parts into a langfun message.

    Text parts flagged with `thought` are collected separately and attached to
    the returned message under the `thought` key.

    Args:
      parts: The `parts` list of a response candidate's content.

    Returns:
      An `lf.AIMessage` built from the non-thought text parts.

    Raises:
      ValueError: If a part has no `text` field.
    """
    chunks = []
    thought_chunks = []
    for part in parts:
      text_part = part.get('text')
      if text_part is None:
        raise ValueError(f'Unsupported part: {part}')
      if not text_part:
        # An empty text part is a valid part that contributes no content.
        continue
      if part.get('thought'):
        thought_chunks.append(text_part)
      else:
        chunks.append(text_part)
    message = lf.AIMessage.from_chunks(chunks)
    if thought_chunks:
      message.set('thought', lf.AIMessage.from_chunks(thought_chunks))
    return message
|
@@ -0,0 +1,195 @@
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Tests for Gemini API."""
|
15
|
+
|
16
|
+
import base64
|
17
|
+
from typing import Any
|
18
|
+
import unittest
|
19
|
+
from unittest import mock
|
20
|
+
|
21
|
+
import langfun.core as lf
|
22
|
+
from langfun.core import modalities as lf_modalities
|
23
|
+
from langfun.core.llms import gemini
|
24
|
+
import pyglove as pg
|
25
|
+
import requests
|
26
|
+
|
27
|
+
|
28
|
+
# Raw bytes of a small PNG image, used as the multimodal test fixture.
example_image = b''.join((
    b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04',
    b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00',
    b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS',
    b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx',
    b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10',
    b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3',
    b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9',
    b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82',
))
|
38
|
+
|
39
|
+
|
40
|
+
def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
  """Stands in for `requests.Session.post`, echoing the request back.

  The fake response carries one candidate with two parts: a text part that
  quotes the first prompt part and the sampling settings found in the
  request's `generationConfig`, and a second part flagged as a thought.

  Args:
    url: Ignored.
    json: The request body; `generationConfig` and the text of the first part
      of the first content are read from it.
    **kwargs: Ignored.

  Returns:
    A `requests.Response` with status 200 and a JSON-encoded body.
  """
  del url, kwargs
  config = pg.Dict(json['generationConfig'])
  prompt_text = json['contents'][0]['parts'][0]['text']
  stop = ''.join(config.stopSequences)
  echoed_text = (
      f'This is a response to {prompt_text} with '
      f'temperature={config.temperature}, '
      f'top_p={config.topP}, '
      f'top_k={config.topK}, '
      f'max_tokens={config.maxOutputTokens}, '
      f'stop={stop}.'
  )
  answer_part = {'text': echoed_text}
  thought_part = {'text': 'This is the thought.', 'thought': True}
  payload = {
      'candidates': [
          {
              'content': {
                  'role': 'model',
                  'parts': [answer_part, thought_part],
              },
          },
      ],
      'usageMetadata': {
          'promptTokenCount': 3,
          'candidatesTokenCount': 4,
      },
  }
  response = requests.Response()
  response.status_code = 200
  response._content = pg.to_json_str(payload).encode()
  return response
|
76
|
+
|
77
|
+
|
78
|
+
class GeminiTest(unittest.TestCase):
  """Tests for the shared Gemini REST model."""

  def test_content_from_message_text_only(self):
    """A plain-text message becomes a single text part."""
    text = 'This is a beautiful day'
    lm = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
    content = lm._content_from_message(lf.UserMessage(text))
    self.assertEqual(content, {'role': 'user', 'parts': [{'text': text}]})

  def test_content_from_message_mm(self):
    """An embedded image becomes an inline-data part between text parts."""
    message = lf.UserMessage(
        'This is an <<[[image]]>>, what is it?',
        image=lf_modalities.Image.from_bytes(example_image),
    )

    # Non-multimodal model.
    text_only_lm = gemini.Gemini('gemini-1.0-pro', api_endpoint='')
    with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
      text_only_lm._content_from_message(message)

    multimodal_lm = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
    expected_image_part = {
        'inlineData': {
            'data': base64.b64encode(example_image).decode(),
            'mimeType': 'image/png',
        }
    }
    self.assertEqual(
        multimodal_lm._content_from_message(message),
        {
            'role': 'user',
            'parts': [
                {'text': 'This is an'},
                expected_image_part,
                {'text': ', what is it?'},
            ],
        },
    )

  def test_generation_config(self):
    """Sampling options and the JSON schema map onto the generation config."""
    lm = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
    json_schema = {
        'type': 'object',
        'properties': {
            'name': {'type': 'string'},
        },
        'required': ['name'],
        'title': 'Person',
    }
    sampling_options = lf.LMSamplingOptions(
        temperature=2.0,
        top_p=1.0,
        top_k=20,
        max_tokens=1024,
        stop=['\n'],
    )
    config = lm._generation_config(
        lf.UserMessage('hi', json_schema=json_schema), sampling_options
    )
    expected = {
        'candidateCount': 1,
        'temperature': 2.0,
        'topP': 1.0,
        'topK': 20,
        'maxOutputTokens': 1024,
        'stopSequences': ['\n'],
        'responseLogprobs': False,
        'logprobs': None,
        'seed': None,
        'responseMimeType': 'application/json',
        'responseSchema': {
            'type': 'object',
            'properties': {
                'name': {'type': 'string'}
            },
            'required': ['name'],
            'title': 'Person',
        },
    }
    self.assertEqual(config, expected)

    # A schema that is not a dict is rejected.
    with self.assertRaisesRegex(
        ValueError, '`json_schema` must be a dict, got'
    ):
      lm._generation_config(
          lf.UserMessage('hi', json_schema='not a dict'),
          lf.LMSamplingOptions(),
      )

  def test_call_model(self):
    """A full call returns the echoed text, the thought and token usage."""
    with mock.patch(
        'requests.Session.post', side_effect=mock_requests_post
    ):
      lm = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
      response = lm(
          'hello',
          temperature=2.0,
          top_p=1.0,
          top_k=20,
          max_tokens=1024,
          stop='\n',
      )
      expected_text = (
          'This is a response to hello with temperature=2.0, '
          'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.'
      )
      self.assertEqual(response.text, expected_text)
      self.assertEqual(response.metadata.thought, 'This is the thought.')
      self.assertEqual(response.metadata.usage.prompt_tokens, 3)
      self.assertEqual(response.metadata.usage.completion_tokens, 4)
|
192
|
+
|
193
|
+
|
194
|
+
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
  unittest.main()
|