langfun 0.1.2.dev202501050804__py3-none-any.whl → 0.1.2.dev202501070804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/eval/v2/reporting.py +7 -2
- langfun/core/language_model.py +4 -1
- langfun/core/language_model_test.py +15 -0
- langfun/core/llms/__init__.py +21 -26
- langfun/core/llms/cache/in_memory.py +6 -0
- langfun/core/llms/cache/in_memory_test.py +5 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +46 -320
- langfun/core/llms/google_genai_test.py +9 -204
- langfun/core/llms/openai.py +5 -0
- langfun/core/llms/vertexai.py +26 -357
- langfun/core/llms/vertexai_test.py +6 -166
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501070804.dist-info}/METADATA +7 -13
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501070804.dist-info}/RECORD +18 -16
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501070804.dist-info}/WHEEL +1 -1
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501070804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501070804.dist-info}/top_level.txt +0 -0
langfun/core/llms/vertexai.py
CHANGED
@@ -13,14 +13,12 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Vertex AI generative models."""
|
15
15
|
|
16
|
-
import base64
|
17
16
|
import functools
|
18
17
|
import os
|
19
18
|
from typing import Annotated, Any
|
20
19
|
|
21
20
|
import langfun.core as lf
|
22
|
-
from langfun.core import
|
23
|
-
from langfun.core.llms import rest
|
21
|
+
from langfun.core.llms import gemini
|
24
22
|
import pyglove as pg
|
25
23
|
|
26
24
|
try:
|
@@ -38,114 +36,11 @@ except ImportError:
|
|
38
36
|
Credentials = Any
|
39
37
|
|
40
38
|
|
41
|
-
# https://cloud.google.com/vertex-ai/generative-ai/pricing
|
42
|
-
# describes that the average number of characters per token is about 4.
|
43
|
-
AVGERAGE_CHARS_PER_TOKEN = 4
|
44
|
-
|
45
|
-
|
46
|
-
# Price in US dollars,
|
47
|
-
# from https://cloud.google.com/vertex-ai/generative-ai/pricing
|
48
|
-
# as of 2024-10-10.
|
49
|
-
SUPPORTED_MODELS_AND_SETTINGS = {
|
50
|
-
'gemini-1.5-pro-001': pg.Dict(
|
51
|
-
rpm=100,
|
52
|
-
cost_per_1k_input_chars=0.0003125,
|
53
|
-
cost_per_1k_output_chars=0.00125,
|
54
|
-
),
|
55
|
-
'gemini-1.5-pro-002': pg.Dict(
|
56
|
-
rpm=100,
|
57
|
-
cost_per_1k_input_chars=0.0003125,
|
58
|
-
cost_per_1k_output_chars=0.00125,
|
59
|
-
),
|
60
|
-
'gemini-1.5-flash-002': pg.Dict(
|
61
|
-
rpm=500,
|
62
|
-
cost_per_1k_input_chars=0.00001875,
|
63
|
-
cost_per_1k_output_chars=0.000075,
|
64
|
-
),
|
65
|
-
'gemini-1.5-flash-001': pg.Dict(
|
66
|
-
rpm=500,
|
67
|
-
cost_per_1k_input_chars=0.00001875,
|
68
|
-
cost_per_1k_output_chars=0.000075,
|
69
|
-
),
|
70
|
-
'gemini-1.5-pro': pg.Dict(
|
71
|
-
rpm=100,
|
72
|
-
cost_per_1k_input_chars=0.0003125,
|
73
|
-
cost_per_1k_output_chars=0.00125,
|
74
|
-
),
|
75
|
-
'gemini-1.5-flash': pg.Dict(
|
76
|
-
rpm=500,
|
77
|
-
cost_per_1k_input_chars=0.00001875,
|
78
|
-
cost_per_1k_output_chars=0.000075,
|
79
|
-
),
|
80
|
-
'gemini-1.5-pro-preview-0514': pg.Dict(
|
81
|
-
rpm=50,
|
82
|
-
cost_per_1k_input_chars=0.0003125,
|
83
|
-
cost_per_1k_output_chars=0.00125,
|
84
|
-
),
|
85
|
-
'gemini-1.5-pro-preview-0409': pg.Dict(
|
86
|
-
rpm=50,
|
87
|
-
cost_per_1k_input_chars=0.0003125,
|
88
|
-
cost_per_1k_output_chars=0.00125,
|
89
|
-
),
|
90
|
-
'gemini-1.5-flash-preview-0514': pg.Dict(
|
91
|
-
rpm=200,
|
92
|
-
cost_per_1k_input_chars=0.00001875,
|
93
|
-
cost_per_1k_output_chars=0.000075,
|
94
|
-
),
|
95
|
-
'gemini-1.0-pro': pg.Dict(
|
96
|
-
rpm=300,
|
97
|
-
cost_per_1k_input_chars=0.000125,
|
98
|
-
cost_per_1k_output_chars=0.000375,
|
99
|
-
),
|
100
|
-
'gemini-1.0-pro-vision': pg.Dict(
|
101
|
-
rpm=100,
|
102
|
-
cost_per_1k_input_chars=0.000125,
|
103
|
-
cost_per_1k_output_chars=0.000375,
|
104
|
-
),
|
105
|
-
# TODO(sharatsharat): Update costs when published
|
106
|
-
'gemini-exp-1206': pg.Dict(
|
107
|
-
rpm=20,
|
108
|
-
cost_per_1k_input_chars=0.000,
|
109
|
-
cost_per_1k_output_chars=0.000,
|
110
|
-
),
|
111
|
-
# TODO(sharatsharat): Update costs when published
|
112
|
-
'gemini-2.0-flash-exp': pg.Dict(
|
113
|
-
rpm=10,
|
114
|
-
cost_per_1k_input_chars=0.000,
|
115
|
-
cost_per_1k_output_chars=0.000,
|
116
|
-
),
|
117
|
-
# TODO(yifenglu): Update costs when published
|
118
|
-
'gemini-2.0-flash-thinking-exp-1219': pg.Dict(
|
119
|
-
rpm=10,
|
120
|
-
cost_per_1k_input_chars=0.000,
|
121
|
-
cost_per_1k_output_chars=0.000,
|
122
|
-
),
|
123
|
-
# TODO(chengrun): Set a more appropriate rpm for endpoint.
|
124
|
-
'vertexai-endpoint': pg.Dict(
|
125
|
-
rpm=20,
|
126
|
-
cost_per_1k_input_chars=0.0000125,
|
127
|
-
cost_per_1k_output_chars=0.0000375,
|
128
|
-
),
|
129
|
-
}
|
130
|
-
|
131
|
-
|
132
39
|
@lf.use_init_args(['model'])
|
133
40
|
@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
|
134
|
-
class VertexAI(
|
41
|
+
class VertexAI(gemini.Gemini):
|
135
42
|
"""Language model served on VertexAI with REST API."""
|
136
43
|
|
137
|
-
model: pg.typing.Annotated[
|
138
|
-
pg.typing.Enum(
|
139
|
-
pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
|
140
|
-
),
|
141
|
-
(
|
142
|
-
'Vertex AI model name with REST API support. See '
|
143
|
-
'https://cloud.google.com/vertex-ai/generative-ai/docs/'
|
144
|
-
'model-reference/inference#supported-models'
|
145
|
-
' for details.'
|
146
|
-
),
|
147
|
-
]
|
148
|
-
|
149
44
|
project: Annotated[
|
150
45
|
str | None,
|
151
46
|
(
|
@@ -170,11 +65,6 @@ class VertexAI(rest.REST):
|
|
170
65
|
),
|
171
66
|
] = None
|
172
67
|
|
173
|
-
supported_modalities: Annotated[
|
174
|
-
list[str],
|
175
|
-
'A list of MIME types for supported modalities'
|
176
|
-
] = []
|
177
|
-
|
178
68
|
def _on_bound(self):
|
179
69
|
super()._on_bound()
|
180
70
|
if google_auth is None:
|
@@ -209,31 +99,9 @@ class VertexAI(rest.REST):
|
|
209
99
|
self._credentials = credentials
|
210
100
|
|
211
101
|
@property
|
212
|
-
def
|
213
|
-
"""Returns
|
214
|
-
return self.
|
215
|
-
requests_per_min=SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm,
|
216
|
-
tokens_per_min=0,
|
217
|
-
)
|
218
|
-
|
219
|
-
def estimate_cost(
|
220
|
-
self,
|
221
|
-
num_input_tokens: int,
|
222
|
-
num_output_tokens: int
|
223
|
-
) -> float | None:
|
224
|
-
"""Estimate the cost based on usage."""
|
225
|
-
cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
|
226
|
-
'cost_per_1k_input_chars', None
|
227
|
-
)
|
228
|
-
cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
|
229
|
-
'cost_per_1k_output_chars', None
|
230
|
-
)
|
231
|
-
if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None:
|
232
|
-
return None
|
233
|
-
return (
|
234
|
-
cost_per_1k_input_chars * num_input_tokens
|
235
|
-
+ cost_per_1k_output_chars * num_output_tokens
|
236
|
-
) * AVGERAGE_CHARS_PER_TOKEN / 1000
|
102
|
+
def model_id(self) -> str:
|
103
|
+
"""Returns a string to identify the model."""
|
104
|
+
return f'VertexAI({self.model})'
|
237
105
|
|
238
106
|
@functools.cached_property
|
239
107
|
def _session(self):
|
@@ -244,12 +112,6 @@ class VertexAI(rest.REST):
|
|
244
112
|
s.headers.update(self.headers or {})
|
245
113
|
return s
|
246
114
|
|
247
|
-
@property
|
248
|
-
def headers(self):
|
249
|
-
return {
|
250
|
-
'Content-Type': 'application/json; charset=utf-8',
|
251
|
-
}
|
252
|
-
|
253
115
|
@property
|
254
116
|
def api_endpoint(self) -> str:
|
255
117
|
return (
|
@@ -258,263 +120,70 @@ class VertexAI(rest.REST):
|
|
258
120
|
f'models/{self.model}:generateContent'
|
259
121
|
)
|
260
122
|
|
261
|
-
def request(
|
262
|
-
self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
|
263
|
-
) -> dict[str, Any]:
|
264
|
-
request = dict(
|
265
|
-
generationConfig=self._generation_config(prompt, sampling_options)
|
266
|
-
)
|
267
|
-
request['contents'] = [self._content_from_message(prompt)]
|
268
|
-
return request
|
269
|
-
|
270
|
-
def _generation_config(
|
271
|
-
self, prompt: lf.Message, options: lf.LMSamplingOptions
|
272
|
-
) -> dict[str, Any]:
|
273
|
-
"""Returns a dict as generation config for prompt and LMSamplingOptions."""
|
274
|
-
config = dict(
|
275
|
-
temperature=options.temperature,
|
276
|
-
maxOutputTokens=options.max_tokens,
|
277
|
-
candidateCount=options.n,
|
278
|
-
topK=options.top_k,
|
279
|
-
topP=options.top_p,
|
280
|
-
stopSequences=options.stop,
|
281
|
-
seed=options.random_seed,
|
282
|
-
responseLogprobs=options.logprobs,
|
283
|
-
logprobs=options.top_logprobs,
|
284
|
-
)
|
285
123
|
|
286
|
-
|
287
|
-
|
288
|
-
raise ValueError(
|
289
|
-
f'`json_schema` must be a dict, got {json_schema!r}.'
|
290
|
-
)
|
291
|
-
json_schema = pg.to_json(json_schema)
|
292
|
-
config['responseSchema'] = json_schema
|
293
|
-
config['responseMimeType'] = 'application/json'
|
294
|
-
prompt.metadata.formatted_text = (
|
295
|
-
prompt.text
|
296
|
-
+ '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
|
297
|
-
+ pg.to_json_str(json_schema, json_indent=2)
|
298
|
-
)
|
299
|
-
return config
|
300
|
-
|
301
|
-
def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
|
302
|
-
"""Gets generation content from langfun message."""
|
303
|
-
parts = []
|
304
|
-
for lf_chunk in prompt.chunk():
|
305
|
-
if isinstance(lf_chunk, str):
|
306
|
-
parts.append({'text': lf_chunk})
|
307
|
-
elif isinstance(lf_chunk, lf_modalities.Mime):
|
308
|
-
try:
|
309
|
-
modalities = lf_chunk.make_compatible(
|
310
|
-
self.supported_modalities + ['text/plain']
|
311
|
-
)
|
312
|
-
if isinstance(modalities, lf_modalities.Mime):
|
313
|
-
modalities = [modalities]
|
314
|
-
for modality in modalities:
|
315
|
-
if modality.is_text:
|
316
|
-
parts.append({'text': modality.to_text()})
|
317
|
-
else:
|
318
|
-
parts.append({
|
319
|
-
'inlineData': {
|
320
|
-
'data': base64.b64encode(modality.to_bytes()).decode(),
|
321
|
-
'mimeType': modality.mime_type,
|
322
|
-
}
|
323
|
-
})
|
324
|
-
except lf.ModalityError as e:
|
325
|
-
raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
|
326
|
-
else:
|
327
|
-
raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
|
328
|
-
return dict(role='user', parts=parts)
|
329
|
-
|
330
|
-
def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
|
331
|
-
messages = [
|
332
|
-
self._message_from_content_parts(candidate['content']['parts'])
|
333
|
-
for candidate in json['candidates']
|
334
|
-
]
|
335
|
-
usage = json['usageMetadata']
|
336
|
-
input_tokens = usage['promptTokenCount']
|
337
|
-
output_tokens = usage['candidatesTokenCount']
|
338
|
-
return lf.LMSamplingResult(
|
339
|
-
[lf.LMSample(message) for message in messages],
|
340
|
-
usage=lf.LMSamplingUsage(
|
341
|
-
prompt_tokens=input_tokens,
|
342
|
-
completion_tokens=output_tokens,
|
343
|
-
total_tokens=input_tokens + output_tokens,
|
344
|
-
estimated_cost=self.estimate_cost(
|
345
|
-
num_input_tokens=input_tokens,
|
346
|
-
num_output_tokens=output_tokens,
|
347
|
-
),
|
348
|
-
),
|
349
|
-
)
|
124
|
+
class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAI): # pylint: disable=invalid-name
|
125
|
+
"""Vertex AI Gemini Flash 2.0 Thinking model launched on 12/19/2024."""
|
350
126
|
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
if text_part := part.get('text'):
|
358
|
-
chunks.append(text_part)
|
359
|
-
else:
|
360
|
-
raise ValueError(f'Unsupported part: {part}')
|
361
|
-
return lf.AIMessage.from_chunks(chunks)
|
362
|
-
|
363
|
-
|
364
|
-
IMAGE_TYPES = [
|
365
|
-
'image/png',
|
366
|
-
'image/jpeg',
|
367
|
-
'image/webp',
|
368
|
-
'image/heic',
|
369
|
-
'image/heif',
|
370
|
-
]
|
371
|
-
|
372
|
-
AUDIO_TYPES = [
|
373
|
-
'audio/aac',
|
374
|
-
'audio/flac',
|
375
|
-
'audio/mp3',
|
376
|
-
'audio/m4a',
|
377
|
-
'audio/mpeg',
|
378
|
-
'audio/mpga',
|
379
|
-
'audio/mp4',
|
380
|
-
'audio/opus',
|
381
|
-
'audio/pcm',
|
382
|
-
'audio/wav',
|
383
|
-
'audio/webm',
|
384
|
-
]
|
385
|
-
|
386
|
-
VIDEO_TYPES = [
|
387
|
-
'video/mov',
|
388
|
-
'video/mpeg',
|
389
|
-
'video/mpegps',
|
390
|
-
'video/mpg',
|
391
|
-
'video/mp4',
|
392
|
-
'video/webm',
|
393
|
-
'video/wmv',
|
394
|
-
'video/x-flv',
|
395
|
-
'video/3gpp',
|
396
|
-
'video/quicktime',
|
397
|
-
]
|
398
|
-
|
399
|
-
DOCUMENT_TYPES = [
|
400
|
-
'application/pdf',
|
401
|
-
'text/plain',
|
402
|
-
'text/csv',
|
403
|
-
'text/html',
|
404
|
-
'text/xml',
|
405
|
-
'text/x-script.python',
|
406
|
-
'application/json',
|
407
|
-
]
|
408
|
-
|
409
|
-
|
410
|
-
class VertexAIGemini2_0(VertexAI): # pylint: disable=invalid-name
|
411
|
-
"""Vertex AI Gemini 2.0 model."""
|
412
|
-
|
413
|
-
supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation
|
414
|
-
DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES
|
415
|
-
)
|
416
|
-
|
417
|
-
|
418
|
-
class VertexAIGeminiFlash2_0Exp(VertexAIGemini2_0): # pylint: disable=invalid-name
|
127
|
+
api_version = 'v1alpha'
|
128
|
+
model = 'gemini-2.0-flash-thinking-exp-1219'
|
129
|
+
timeout = None
|
130
|
+
|
131
|
+
|
132
|
+
class VertexAIGeminiFlash2_0Exp(VertexAI): # pylint: disable=invalid-name
|
419
133
|
"""Vertex AI Gemini 2.0 Flash model."""
|
420
134
|
|
421
135
|
model = 'gemini-2.0-flash-exp'
|
422
136
|
|
423
137
|
|
424
|
-
class
|
425
|
-
"""Vertex AI Gemini
|
138
|
+
class VertexAIGeminiExp_20241206(VertexAI): # pylint: disable=invalid-name
|
139
|
+
"""Vertex AI Gemini Experimental model launched on 12/06/2024."""
|
426
140
|
|
427
|
-
model = 'gemini-
|
141
|
+
model = 'gemini-exp-1206'
|
428
142
|
|
429
143
|
|
430
|
-
class
|
431
|
-
"""Vertex AI Gemini
|
144
|
+
class VertexAIGeminiExp_20241114(VertexAI): # pylint: disable=invalid-name
|
145
|
+
"""Vertex AI Gemini Experimental model launched on 11/14/2024."""
|
432
146
|
|
433
|
-
|
434
|
-
DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES
|
435
|
-
)
|
147
|
+
model = 'gemini-exp-1114'
|
436
148
|
|
437
149
|
|
438
|
-
class VertexAIGeminiPro1_5(
|
150
|
+
class VertexAIGeminiPro1_5(VertexAI): # pylint: disable=invalid-name
|
439
151
|
"""Vertex AI Gemini 1.5 Pro model."""
|
440
152
|
|
441
|
-
model = 'gemini-1.5-pro'
|
153
|
+
model = 'gemini-1.5-pro-latest'
|
442
154
|
|
443
155
|
|
444
|
-
class VertexAIGeminiPro1_5_002(
|
156
|
+
class VertexAIGeminiPro1_5_002(VertexAI): # pylint: disable=invalid-name
|
445
157
|
"""Vertex AI Gemini 1.5 Pro model."""
|
446
158
|
|
447
159
|
model = 'gemini-1.5-pro-002'
|
448
160
|
|
449
161
|
|
450
|
-
class VertexAIGeminiPro1_5_001(
|
162
|
+
class VertexAIGeminiPro1_5_001(VertexAI): # pylint: disable=invalid-name
|
451
163
|
"""Vertex AI Gemini 1.5 Pro model."""
|
452
164
|
|
453
165
|
model = 'gemini-1.5-pro-001'
|
454
166
|
|
455
167
|
|
456
|
-
class
|
457
|
-
"""Vertex AI Gemini 1.5 Pro preview model."""
|
458
|
-
|
459
|
-
model = 'gemini-1.5-pro-preview-0514'
|
460
|
-
|
461
|
-
|
462
|
-
class VertexAIGeminiPro1_5_0409(VertexAIGemini1_5): # pylint: disable=invalid-name
|
463
|
-
"""Vertex AI Gemini 1.5 Pro preview model."""
|
464
|
-
|
465
|
-
model = 'gemini-1.5-pro-preview-0409'
|
466
|
-
|
467
|
-
|
468
|
-
class VertexAIGeminiFlash1_5(VertexAIGemini1_5): # pylint: disable=invalid-name
|
168
|
+
class VertexAIGeminiFlash1_5(VertexAI): # pylint: disable=invalid-name
|
469
169
|
"""Vertex AI Gemini 1.5 Flash model."""
|
470
170
|
|
471
171
|
model = 'gemini-1.5-flash'
|
472
172
|
|
473
173
|
|
474
|
-
class VertexAIGeminiFlash1_5_002(
|
174
|
+
class VertexAIGeminiFlash1_5_002(VertexAI): # pylint: disable=invalid-name
|
475
175
|
"""Vertex AI Gemini 1.5 Flash model."""
|
476
176
|
|
477
177
|
model = 'gemini-1.5-flash-002'
|
478
178
|
|
479
179
|
|
480
|
-
class VertexAIGeminiFlash1_5_001(
|
180
|
+
class VertexAIGeminiFlash1_5_001(VertexAI): # pylint: disable=invalid-name
|
481
181
|
"""Vertex AI Gemini 1.5 Flash model."""
|
482
182
|
|
483
183
|
model = 'gemini-1.5-flash-001'
|
484
184
|
|
485
185
|
|
486
|
-
class VertexAIGeminiFlash1_5_0514(VertexAIGemini1_5): # pylint: disable=invalid-name
|
487
|
-
"""Vertex AI Gemini 1.5 Flash preview model."""
|
488
|
-
|
489
|
-
model = 'gemini-1.5-flash-preview-0514'
|
490
|
-
|
491
|
-
|
492
186
|
class VertexAIGeminiPro1(VertexAI): # pylint: disable=invalid-name
|
493
187
|
"""Vertex AI Gemini 1.0 Pro model."""
|
494
188
|
|
495
189
|
model = 'gemini-1.0-pro'
|
496
|
-
|
497
|
-
|
498
|
-
class VertexAIGeminiPro1Vision(VertexAI): # pylint: disable=invalid-name
|
499
|
-
"""Vertex AI Gemini 1.0 Pro Vision model."""
|
500
|
-
|
501
|
-
model = 'gemini-1.0-pro-vision'
|
502
|
-
supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation
|
503
|
-
IMAGE_TYPES + VIDEO_TYPES
|
504
|
-
)
|
505
|
-
|
506
|
-
|
507
|
-
class VertexAIEndpoint(VertexAI): # pylint: disable=invalid-name
|
508
|
-
"""Vertex AI Endpoint model."""
|
509
|
-
|
510
|
-
model = 'vertexai-endpoint'
|
511
|
-
|
512
|
-
endpoint: Annotated[str, 'Vertex AI Endpoint ID.']
|
513
|
-
|
514
|
-
@property
|
515
|
-
def api_endpoint(self) -> str:
|
516
|
-
return (
|
517
|
-
f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
|
518
|
-
f'{self.project}/locations/{self.location}/'
|
519
|
-
f'endpoints/{self.endpoint}:generateContent'
|
520
|
-
)
|
@@ -11,105 +11,18 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""Tests for
|
14
|
+
"""Tests for VertexAI models."""
|
15
15
|
|
16
|
-
import base64
|
17
16
|
import os
|
18
|
-
from typing import Any
|
19
17
|
import unittest
|
20
18
|
from unittest import mock
|
21
19
|
|
22
|
-
import langfun.core as lf
|
23
|
-
from langfun.core import modalities as lf_modalities
|
24
20
|
from langfun.core.llms import vertexai
|
25
|
-
import pyglove as pg
|
26
|
-
import requests
|
27
|
-
|
28
|
-
|
29
|
-
example_image = (
|
30
|
-
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
|
31
|
-
b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
|
32
|
-
b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
|
33
|
-
b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
|
34
|
-
b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
|
35
|
-
b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
|
36
|
-
b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
|
37
|
-
b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
|
38
|
-
)
|
39
|
-
|
40
|
-
|
41
|
-
def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
|
42
|
-
del url, kwargs
|
43
|
-
c = pg.Dict(json['generationConfig'])
|
44
|
-
content = json['contents'][0]['parts'][0]['text']
|
45
|
-
response = requests.Response()
|
46
|
-
response.status_code = 200
|
47
|
-
response._content = pg.to_json_str({
|
48
|
-
'candidates': [
|
49
|
-
{
|
50
|
-
'content': {
|
51
|
-
'role': 'model',
|
52
|
-
'parts': [
|
53
|
-
{
|
54
|
-
'text': (
|
55
|
-
f'This is a response to {content} with '
|
56
|
-
f'temperature={c.temperature}, '
|
57
|
-
f'top_p={c.topP}, '
|
58
|
-
f'top_k={c.topK}, '
|
59
|
-
f'max_tokens={c.maxOutputTokens}, '
|
60
|
-
f'stop={"".join(c.stopSequences)}.'
|
61
|
-
)
|
62
|
-
},
|
63
|
-
],
|
64
|
-
},
|
65
|
-
},
|
66
|
-
],
|
67
|
-
'usageMetadata': {
|
68
|
-
'promptTokenCount': 3,
|
69
|
-
'candidatesTokenCount': 4,
|
70
|
-
}
|
71
|
-
}).encode()
|
72
|
-
return response
|
73
21
|
|
74
22
|
|
75
23
|
class VertexAITest(unittest.TestCase):
|
76
24
|
"""Tests for Vertex model with REST API."""
|
77
25
|
|
78
|
-
def test_content_from_message_text_only(self):
|
79
|
-
text = 'This is a beautiful day'
|
80
|
-
model = vertexai.VertexAIGeminiPro1_5_002()
|
81
|
-
chunks = model._content_from_message(lf.UserMessage(text))
|
82
|
-
self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]})
|
83
|
-
|
84
|
-
def test_content_from_message_mm(self):
|
85
|
-
image = lf_modalities.Image.from_bytes(example_image)
|
86
|
-
message = lf.UserMessage(
|
87
|
-
'This is an <<[[image]]>>, what is it?', image=image
|
88
|
-
)
|
89
|
-
|
90
|
-
# Non-multimodal model.
|
91
|
-
with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
|
92
|
-
vertexai.VertexAIGeminiPro1()._content_from_message(message)
|
93
|
-
|
94
|
-
model = vertexai.VertexAIGeminiPro1Vision()
|
95
|
-
content = model._content_from_message(message)
|
96
|
-
self.assertEqual(
|
97
|
-
content,
|
98
|
-
{
|
99
|
-
'role': 'user',
|
100
|
-
'parts': [
|
101
|
-
{'text': 'This is an'},
|
102
|
-
{
|
103
|
-
'inlineData': {
|
104
|
-
'data': base64.b64encode(example_image).decode(),
|
105
|
-
'mimeType': 'image/png',
|
106
|
-
}
|
107
|
-
},
|
108
|
-
{'text': ', what is it?'},
|
109
|
-
],
|
110
|
-
},
|
111
|
-
)
|
112
|
-
|
113
26
|
@mock.patch.object(vertexai.VertexAI, 'credentials', new=True)
|
114
27
|
def test_project_and_location_check(self):
|
115
28
|
with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
|
@@ -126,87 +39,14 @@ class VertexAITest(unittest.TestCase):
|
|
126
39
|
|
127
40
|
os.environ['VERTEXAI_PROJECT'] = 'abc'
|
128
41
|
os.environ['VERTEXAI_LOCATION'] = 'us-central1'
|
129
|
-
|
42
|
+
model = vertexai.VertexAIGeminiPro1()
|
43
|
+
self.assertTrue(model.model_id.startswith('VertexAI('))
|
44
|
+
self.assertIsNotNone(model.api_endpoint)
|
45
|
+
self.assertTrue(model._api_initialized)
|
46
|
+
self.assertIsNotNone(model._session)
|
130
47
|
del os.environ['VERTEXAI_PROJECT']
|
131
48
|
del os.environ['VERTEXAI_LOCATION']
|
132
49
|
|
133
|
-
def test_generation_config(self):
|
134
|
-
model = vertexai.VertexAIGeminiPro1()
|
135
|
-
json_schema = {
|
136
|
-
'type': 'object',
|
137
|
-
'properties': {
|
138
|
-
'name': {'type': 'string'},
|
139
|
-
},
|
140
|
-
'required': ['name'],
|
141
|
-
'title': 'Person',
|
142
|
-
}
|
143
|
-
actual = model._generation_config(
|
144
|
-
lf.UserMessage('hi', json_schema=json_schema),
|
145
|
-
lf.LMSamplingOptions(
|
146
|
-
temperature=2.0,
|
147
|
-
top_p=1.0,
|
148
|
-
top_k=20,
|
149
|
-
max_tokens=1024,
|
150
|
-
stop=['\n'],
|
151
|
-
),
|
152
|
-
)
|
153
|
-
self.assertEqual(
|
154
|
-
actual,
|
155
|
-
dict(
|
156
|
-
candidateCount=1,
|
157
|
-
temperature=2.0,
|
158
|
-
topP=1.0,
|
159
|
-
topK=20,
|
160
|
-
maxOutputTokens=1024,
|
161
|
-
stopSequences=['\n'],
|
162
|
-
responseLogprobs=False,
|
163
|
-
logprobs=None,
|
164
|
-
seed=None,
|
165
|
-
responseMimeType='application/json',
|
166
|
-
responseSchema={
|
167
|
-
'type': 'object',
|
168
|
-
'properties': {
|
169
|
-
'name': {'type': 'string'}
|
170
|
-
},
|
171
|
-
'required': ['name'],
|
172
|
-
'title': 'Person',
|
173
|
-
}
|
174
|
-
),
|
175
|
-
)
|
176
|
-
with self.assertRaisesRegex(
|
177
|
-
ValueError, '`json_schema` must be a dict, got'
|
178
|
-
):
|
179
|
-
model._generation_config(
|
180
|
-
lf.UserMessage('hi', json_schema='not a dict'),
|
181
|
-
lf.LMSamplingOptions(),
|
182
|
-
)
|
183
|
-
|
184
|
-
@mock.patch.object(vertexai.VertexAI, 'credentials', new=True)
|
185
|
-
def test_call_model(self):
|
186
|
-
with mock.patch('requests.Session.post') as mock_generate:
|
187
|
-
mock_generate.side_effect = mock_requests_post
|
188
|
-
|
189
|
-
lm = vertexai.VertexAIGeminiPro1_5_002(
|
190
|
-
project='abc', location='us-central1'
|
191
|
-
)
|
192
|
-
r = lm(
|
193
|
-
'hello',
|
194
|
-
temperature=2.0,
|
195
|
-
top_p=1.0,
|
196
|
-
top_k=20,
|
197
|
-
max_tokens=1024,
|
198
|
-
stop='\n',
|
199
|
-
)
|
200
|
-
self.assertEqual(
|
201
|
-
r.text,
|
202
|
-
(
|
203
|
-
'This is a response to hello with temperature=2.0, '
|
204
|
-
'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.'
|
205
|
-
),
|
206
|
-
)
|
207
|
-
self.assertEqual(r.metadata.usage.prompt_tokens, 3)
|
208
|
-
self.assertEqual(r.metadata.usage.completion_tokens, 4)
|
209
|
-
|
210
50
|
|
211
51
|
if __name__ == '__main__':
|
212
52
|
unittest.main()
|