langfun 0.1.2.dev202501010804__py3-none-any.whl → 0.1.2.dev202501060804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/__init__.py +0 -4
- langfun/core/eval/matching.py +2 -2
- langfun/core/eval/scoring.py +6 -2
- langfun/core/eval/v2/checkpointing.py +106 -72
- langfun/core/eval/v2/checkpointing_test.py +108 -3
- langfun/core/eval/v2/eval_test_helper.py +56 -0
- langfun/core/eval/v2/evaluation.py +25 -4
- langfun/core/eval/v2/evaluation_test.py +11 -0
- langfun/core/eval/v2/example.py +11 -1
- langfun/core/eval/v2/example_test.py +16 -2
- langfun/core/eval/v2/experiment.py +83 -19
- langfun/core/eval/v2/experiment_test.py +121 -3
- langfun/core/eval/v2/reporting.py +67 -20
- langfun/core/eval/v2/reporting_test.py +119 -2
- langfun/core/eval/v2/runners.py +7 -4
- langfun/core/llms/__init__.py +23 -24
- langfun/core/llms/anthropic.py +12 -0
- langfun/core/llms/cache/in_memory.py +6 -0
- langfun/core/llms/cache/in_memory_test.py +5 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +46 -310
- langfun/core/llms/google_genai_test.py +9 -204
- langfun/core/llms/openai.py +23 -37
- langfun/core/llms/vertexai.py +28 -348
- langfun/core/llms/vertexai_test.py +6 -166
- {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/METADATA +7 -13
- {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/RECORD +31 -31
- {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/WHEEL +1 -1
- langfun/core/repr_utils.py +0 -204
- langfun/core/repr_utils_test.py +0 -90
- {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/top_level.txt +0 -0
langfun/core/llms/vertexai.py
CHANGED
@@ -13,14 +13,12 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Vertex AI generative models."""
|
15
15
|
|
16
|
-
import base64
|
17
16
|
import functools
|
18
17
|
import os
|
19
18
|
from typing import Annotated, Any
|
20
19
|
|
21
20
|
import langfun.core as lf
|
22
|
-
from langfun.core import
|
23
|
-
from langfun.core.llms import rest
|
21
|
+
from langfun.core.llms import gemini
|
24
22
|
import pyglove as pg
|
25
23
|
|
26
24
|
try:
|
@@ -38,108 +36,11 @@ except ImportError:
|
|
38
36
|
Credentials = Any
|
39
37
|
|
40
38
|
|
41
|
-
# https://cloud.google.com/vertex-ai/generative-ai/pricing
|
42
|
-
# describes that the average number of characters per token is about 4.
|
43
|
-
AVGERAGE_CHARS_PER_TOKEN = 4
|
44
|
-
|
45
|
-
|
46
|
-
# Price in US dollars,
|
47
|
-
# from https://cloud.google.com/vertex-ai/generative-ai/pricing
|
48
|
-
# as of 2024-10-10.
|
49
|
-
SUPPORTED_MODELS_AND_SETTINGS = {
|
50
|
-
'gemini-1.5-pro-001': pg.Dict(
|
51
|
-
rpm=100,
|
52
|
-
cost_per_1k_input_chars=0.0003125,
|
53
|
-
cost_per_1k_output_chars=0.00125,
|
54
|
-
),
|
55
|
-
'gemini-1.5-pro-002': pg.Dict(
|
56
|
-
rpm=100,
|
57
|
-
cost_per_1k_input_chars=0.0003125,
|
58
|
-
cost_per_1k_output_chars=0.00125,
|
59
|
-
),
|
60
|
-
'gemini-1.5-flash-002': pg.Dict(
|
61
|
-
rpm=500,
|
62
|
-
cost_per_1k_input_chars=0.00001875,
|
63
|
-
cost_per_1k_output_chars=0.000075,
|
64
|
-
),
|
65
|
-
'gemini-1.5-flash-001': pg.Dict(
|
66
|
-
rpm=500,
|
67
|
-
cost_per_1k_input_chars=0.00001875,
|
68
|
-
cost_per_1k_output_chars=0.000075,
|
69
|
-
),
|
70
|
-
'gemini-1.5-pro': pg.Dict(
|
71
|
-
rpm=100,
|
72
|
-
cost_per_1k_input_chars=0.0003125,
|
73
|
-
cost_per_1k_output_chars=0.00125,
|
74
|
-
),
|
75
|
-
'gemini-1.5-flash': pg.Dict(
|
76
|
-
rpm=500,
|
77
|
-
cost_per_1k_input_chars=0.00001875,
|
78
|
-
cost_per_1k_output_chars=0.000075,
|
79
|
-
),
|
80
|
-
'gemini-1.5-pro-preview-0514': pg.Dict(
|
81
|
-
rpm=50,
|
82
|
-
cost_per_1k_input_chars=0.0003125,
|
83
|
-
cost_per_1k_output_chars=0.00125,
|
84
|
-
),
|
85
|
-
'gemini-1.5-pro-preview-0409': pg.Dict(
|
86
|
-
rpm=50,
|
87
|
-
cost_per_1k_input_chars=0.0003125,
|
88
|
-
cost_per_1k_output_chars=0.00125,
|
89
|
-
),
|
90
|
-
'gemini-1.5-flash-preview-0514': pg.Dict(
|
91
|
-
rpm=200,
|
92
|
-
cost_per_1k_input_chars=0.00001875,
|
93
|
-
cost_per_1k_output_chars=0.000075,
|
94
|
-
),
|
95
|
-
'gemini-1.0-pro': pg.Dict(
|
96
|
-
rpm=300,
|
97
|
-
cost_per_1k_input_chars=0.000125,
|
98
|
-
cost_per_1k_output_chars=0.000375,
|
99
|
-
),
|
100
|
-
'gemini-1.0-pro-vision': pg.Dict(
|
101
|
-
rpm=100,
|
102
|
-
cost_per_1k_input_chars=0.000125,
|
103
|
-
cost_per_1k_output_chars=0.000375,
|
104
|
-
),
|
105
|
-
# TODO(sharatsharat): Update costs when published
|
106
|
-
'gemini-exp-1206': pg.Dict(
|
107
|
-
rpm=20,
|
108
|
-
cost_per_1k_input_chars=0.000,
|
109
|
-
cost_per_1k_output_chars=0.000,
|
110
|
-
),
|
111
|
-
# TODO(sharatsharat): Update costs when published
|
112
|
-
'gemini-2.0-flash-exp': pg.Dict(
|
113
|
-
rpm=20,
|
114
|
-
cost_per_1k_input_chars=0.000,
|
115
|
-
cost_per_1k_output_chars=0.000,
|
116
|
-
),
|
117
|
-
# TODO(chengrun): Set a more appropriate rpm for endpoint.
|
118
|
-
'vertexai-endpoint': pg.Dict(
|
119
|
-
rpm=20,
|
120
|
-
cost_per_1k_input_chars=0.0000125,
|
121
|
-
cost_per_1k_output_chars=0.0000375,
|
122
|
-
),
|
123
|
-
}
|
124
|
-
|
125
|
-
|
126
39
|
@lf.use_init_args(['model'])
|
127
40
|
@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
|
128
|
-
class VertexAI(
|
41
|
+
class VertexAI(gemini.Gemini):
|
129
42
|
"""Language model served on VertexAI with REST API."""
|
130
43
|
|
131
|
-
model: pg.typing.Annotated[
|
132
|
-
pg.typing.Enum(
|
133
|
-
pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
|
134
|
-
),
|
135
|
-
(
|
136
|
-
'Vertex AI model name with REST API support. See '
|
137
|
-
'https://cloud.google.com/vertex-ai/generative-ai/docs/'
|
138
|
-
'model-reference/inference#supported-models'
|
139
|
-
' for details.'
|
140
|
-
),
|
141
|
-
]
|
142
|
-
|
143
44
|
project: Annotated[
|
144
45
|
str | None,
|
145
46
|
(
|
@@ -164,11 +65,6 @@ class VertexAI(rest.REST):
|
|
164
65
|
),
|
165
66
|
] = None
|
166
67
|
|
167
|
-
supported_modalities: Annotated[
|
168
|
-
list[str],
|
169
|
-
'A list of MIME types for supported modalities'
|
170
|
-
] = []
|
171
|
-
|
172
68
|
def _on_bound(self):
|
173
69
|
super()._on_bound()
|
174
70
|
if google_auth is None:
|
@@ -203,31 +99,9 @@ class VertexAI(rest.REST):
|
|
203
99
|
self._credentials = credentials
|
204
100
|
|
205
101
|
@property
|
206
|
-
def
|
207
|
-
"""Returns
|
208
|
-
return self.
|
209
|
-
requests_per_min=SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm,
|
210
|
-
tokens_per_min=0,
|
211
|
-
)
|
212
|
-
|
213
|
-
def estimate_cost(
|
214
|
-
self,
|
215
|
-
num_input_tokens: int,
|
216
|
-
num_output_tokens: int
|
217
|
-
) -> float | None:
|
218
|
-
"""Estimate the cost based on usage."""
|
219
|
-
cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
|
220
|
-
'cost_per_1k_input_chars', None
|
221
|
-
)
|
222
|
-
cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
|
223
|
-
'cost_per_1k_output_chars', None
|
224
|
-
)
|
225
|
-
if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None:
|
226
|
-
return None
|
227
|
-
return (
|
228
|
-
cost_per_1k_input_chars * num_input_tokens
|
229
|
-
+ cost_per_1k_output_chars * num_output_tokens
|
230
|
-
) * AVGERAGE_CHARS_PER_TOKEN / 1000
|
102
|
+
def model_id(self) -> str:
|
103
|
+
"""Returns a string to identify the model."""
|
104
|
+
return f'VertexAI({self.model})'
|
231
105
|
|
232
106
|
@functools.cached_property
|
233
107
|
def _session(self):
|
@@ -238,12 +112,6 @@ class VertexAI(rest.REST):
|
|
238
112
|
s.headers.update(self.headers or {})
|
239
113
|
return s
|
240
114
|
|
241
|
-
@property
|
242
|
-
def headers(self):
|
243
|
-
return {
|
244
|
-
'Content-Type': 'application/json; charset=utf-8',
|
245
|
-
}
|
246
|
-
|
247
115
|
@property
|
248
116
|
def api_endpoint(self) -> str:
|
249
117
|
return (
|
@@ -252,257 +120,69 @@ class VertexAI(rest.REST):
|
|
252
120
|
f'models/{self.model}:generateContent'
|
253
121
|
)
|
254
122
|
|
255
|
-
def request(
|
256
|
-
self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
|
257
|
-
) -> dict[str, Any]:
|
258
|
-
request = dict(
|
259
|
-
generationConfig=self._generation_config(prompt, sampling_options)
|
260
|
-
)
|
261
|
-
request['contents'] = [self._content_from_message(prompt)]
|
262
|
-
return request
|
263
|
-
|
264
|
-
def _generation_config(
|
265
|
-
self, prompt: lf.Message, options: lf.LMSamplingOptions
|
266
|
-
) -> dict[str, Any]:
|
267
|
-
"""Returns a dict as generation config for prompt and LMSamplingOptions."""
|
268
|
-
config = dict(
|
269
|
-
temperature=options.temperature,
|
270
|
-
maxOutputTokens=options.max_tokens,
|
271
|
-
candidateCount=options.n,
|
272
|
-
topK=options.top_k,
|
273
|
-
topP=options.top_p,
|
274
|
-
stopSequences=options.stop,
|
275
|
-
seed=options.random_seed,
|
276
|
-
responseLogprobs=options.logprobs,
|
277
|
-
logprobs=options.top_logprobs,
|
278
|
-
)
|
279
123
|
|
280
|
-
|
281
|
-
|
282
|
-
raise ValueError(
|
283
|
-
f'`json_schema` must be a dict, got {json_schema!r}.'
|
284
|
-
)
|
285
|
-
json_schema = pg.to_json(json_schema)
|
286
|
-
config['responseSchema'] = json_schema
|
287
|
-
config['responseMimeType'] = 'application/json'
|
288
|
-
prompt.metadata.formatted_text = (
|
289
|
-
prompt.text
|
290
|
-
+ '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
|
291
|
-
+ pg.to_json_str(json_schema, json_indent=2)
|
292
|
-
)
|
293
|
-
return config
|
294
|
-
|
295
|
-
def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
|
296
|
-
"""Gets generation content from langfun message."""
|
297
|
-
parts = []
|
298
|
-
for lf_chunk in prompt.chunk():
|
299
|
-
if isinstance(lf_chunk, str):
|
300
|
-
parts.append({'text': lf_chunk})
|
301
|
-
elif isinstance(lf_chunk, lf_modalities.Mime):
|
302
|
-
try:
|
303
|
-
modalities = lf_chunk.make_compatible(
|
304
|
-
self.supported_modalities + ['text/plain']
|
305
|
-
)
|
306
|
-
if isinstance(modalities, lf_modalities.Mime):
|
307
|
-
modalities = [modalities]
|
308
|
-
for modality in modalities:
|
309
|
-
if modality.is_text:
|
310
|
-
parts.append({'text': modality.to_text()})
|
311
|
-
else:
|
312
|
-
parts.append({
|
313
|
-
'inlineData': {
|
314
|
-
'data': base64.b64encode(modality.to_bytes()).decode(),
|
315
|
-
'mimeType': modality.mime_type,
|
316
|
-
}
|
317
|
-
})
|
318
|
-
except lf.ModalityError as e:
|
319
|
-
raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
|
320
|
-
else:
|
321
|
-
raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
|
322
|
-
return dict(role='user', parts=parts)
|
323
|
-
|
324
|
-
def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
|
325
|
-
messages = [
|
326
|
-
self._message_from_content_parts(candidate['content']['parts'])
|
327
|
-
for candidate in json['candidates']
|
328
|
-
]
|
329
|
-
usage = json['usageMetadata']
|
330
|
-
input_tokens = usage['promptTokenCount']
|
331
|
-
output_tokens = usage['candidatesTokenCount']
|
332
|
-
return lf.LMSamplingResult(
|
333
|
-
[lf.LMSample(message) for message in messages],
|
334
|
-
usage=lf.LMSamplingUsage(
|
335
|
-
prompt_tokens=input_tokens,
|
336
|
-
completion_tokens=output_tokens,
|
337
|
-
total_tokens=input_tokens + output_tokens,
|
338
|
-
estimated_cost=self.estimate_cost(
|
339
|
-
num_input_tokens=input_tokens,
|
340
|
-
num_output_tokens=output_tokens,
|
341
|
-
),
|
342
|
-
),
|
343
|
-
)
|
124
|
+
class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAI): # pylint: disable=invalid-name
|
125
|
+
"""Vertex AI Gemini Flash 2.0 Thinking model launched on 12/19/2024."""
|
344
126
|
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
for part in parts:
|
351
|
-
if text_part := part.get('text'):
|
352
|
-
chunks.append(text_part)
|
353
|
-
else:
|
354
|
-
raise ValueError(f'Unsupported part: {part}')
|
355
|
-
return lf.AIMessage.from_chunks(chunks)
|
356
|
-
|
357
|
-
|
358
|
-
IMAGE_TYPES = [
|
359
|
-
'image/png',
|
360
|
-
'image/jpeg',
|
361
|
-
'image/webp',
|
362
|
-
'image/heic',
|
363
|
-
'image/heif',
|
364
|
-
]
|
365
|
-
|
366
|
-
AUDIO_TYPES = [
|
367
|
-
'audio/aac',
|
368
|
-
'audio/flac',
|
369
|
-
'audio/mp3',
|
370
|
-
'audio/m4a',
|
371
|
-
'audio/mpeg',
|
372
|
-
'audio/mpga',
|
373
|
-
'audio/mp4',
|
374
|
-
'audio/opus',
|
375
|
-
'audio/pcm',
|
376
|
-
'audio/wav',
|
377
|
-
'audio/webm',
|
378
|
-
]
|
379
|
-
|
380
|
-
VIDEO_TYPES = [
|
381
|
-
'video/mov',
|
382
|
-
'video/mpeg',
|
383
|
-
'video/mpegps',
|
384
|
-
'video/mpg',
|
385
|
-
'video/mp4',
|
386
|
-
'video/webm',
|
387
|
-
'video/wmv',
|
388
|
-
'video/x-flv',
|
389
|
-
'video/3gpp',
|
390
|
-
'video/quicktime',
|
391
|
-
]
|
392
|
-
|
393
|
-
DOCUMENT_TYPES = [
|
394
|
-
'application/pdf',
|
395
|
-
'text/plain',
|
396
|
-
'text/csv',
|
397
|
-
'text/html',
|
398
|
-
'text/xml',
|
399
|
-
'text/x-script.python',
|
400
|
-
'application/json',
|
401
|
-
]
|
402
|
-
|
403
|
-
|
404
|
-
class VertexAIGemini2_0(VertexAI): # pylint: disable=invalid-name
|
405
|
-
"""Vertex AI Gemini 2.0 model."""
|
406
|
-
|
407
|
-
supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation
|
408
|
-
DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES
|
409
|
-
)
|
410
|
-
|
411
|
-
|
412
|
-
class VertexAIGeminiFlash2_0Exp(VertexAIGemini2_0): # pylint: disable=invalid-name
|
127
|
+
api_version = 'v1alpha'
|
128
|
+
model = 'gemini-2.0-flash-thinking-exp-1219'
|
129
|
+
|
130
|
+
|
131
|
+
class VertexAIGeminiFlash2_0Exp(VertexAI): # pylint: disable=invalid-name
|
413
132
|
"""Vertex AI Gemini 2.0 Flash model."""
|
414
133
|
|
415
134
|
model = 'gemini-2.0-flash-exp'
|
416
135
|
|
417
136
|
|
418
|
-
class
|
419
|
-
"""Vertex AI Gemini
|
137
|
+
class VertexAIGeminiExp_20241206(VertexAI): # pylint: disable=invalid-name
|
138
|
+
"""Vertex AI Gemini Experimental model launched on 12/06/2024."""
|
420
139
|
|
421
|
-
|
422
|
-
DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES
|
423
|
-
)
|
140
|
+
model = 'gemini-exp-1206'
|
424
141
|
|
425
142
|
|
426
|
-
class
|
427
|
-
"""Vertex AI Gemini
|
143
|
+
class VertexAIGeminiExp_20241114(VertexAI): # pylint: disable=invalid-name
|
144
|
+
"""Vertex AI Gemini Experimental model launched on 11/14/2024."""
|
428
145
|
|
429
|
-
model = 'gemini-
|
146
|
+
model = 'gemini-exp-1114'
|
430
147
|
|
431
148
|
|
432
|
-
class
|
149
|
+
class VertexAIGeminiPro1_5(VertexAI): # pylint: disable=invalid-name
|
433
150
|
"""Vertex AI Gemini 1.5 Pro model."""
|
434
151
|
|
435
|
-
model = 'gemini-1.5-pro-
|
152
|
+
model = 'gemini-1.5-pro-latest'
|
436
153
|
|
437
154
|
|
438
|
-
class
|
155
|
+
class VertexAIGeminiPro1_5_002(VertexAI): # pylint: disable=invalid-name
|
439
156
|
"""Vertex AI Gemini 1.5 Pro model."""
|
440
157
|
|
441
|
-
model = 'gemini-1.5-pro-
|
442
|
-
|
443
|
-
|
444
|
-
class VertexAIGeminiPro1_5_0514(VertexAIGemini1_5): # pylint: disable=invalid-name
|
445
|
-
"""Vertex AI Gemini 1.5 Pro preview model."""
|
446
|
-
|
447
|
-
model = 'gemini-1.5-pro-preview-0514'
|
158
|
+
model = 'gemini-1.5-pro-002'
|
448
159
|
|
449
160
|
|
450
|
-
class
|
451
|
-
"""Vertex AI Gemini 1.5 Pro
|
161
|
+
class VertexAIGeminiPro1_5_001(VertexAI): # pylint: disable=invalid-name
|
162
|
+
"""Vertex AI Gemini 1.5 Pro model."""
|
452
163
|
|
453
|
-
model = 'gemini-1.5-pro-
|
164
|
+
model = 'gemini-1.5-pro-001'
|
454
165
|
|
455
166
|
|
456
|
-
class VertexAIGeminiFlash1_5(
|
167
|
+
class VertexAIGeminiFlash1_5(VertexAI): # pylint: disable=invalid-name
|
457
168
|
"""Vertex AI Gemini 1.5 Flash model."""
|
458
169
|
|
459
170
|
model = 'gemini-1.5-flash'
|
460
171
|
|
461
172
|
|
462
|
-
class VertexAIGeminiFlash1_5_002(
|
173
|
+
class VertexAIGeminiFlash1_5_002(VertexAI): # pylint: disable=invalid-name
|
463
174
|
"""Vertex AI Gemini 1.5 Flash model."""
|
464
175
|
|
465
176
|
model = 'gemini-1.5-flash-002'
|
466
177
|
|
467
178
|
|
468
|
-
class VertexAIGeminiFlash1_5_001(
|
179
|
+
class VertexAIGeminiFlash1_5_001(VertexAI): # pylint: disable=invalid-name
|
469
180
|
"""Vertex AI Gemini 1.5 Flash model."""
|
470
181
|
|
471
182
|
model = 'gemini-1.5-flash-001'
|
472
183
|
|
473
184
|
|
474
|
-
class VertexAIGeminiFlash1_5_0514(VertexAIGemini1_5): # pylint: disable=invalid-name
|
475
|
-
"""Vertex AI Gemini 1.5 Flash preview model."""
|
476
|
-
|
477
|
-
model = 'gemini-1.5-flash-preview-0514'
|
478
|
-
|
479
|
-
|
480
185
|
class VertexAIGeminiPro1(VertexAI): # pylint: disable=invalid-name
|
481
186
|
"""Vertex AI Gemini 1.0 Pro model."""
|
482
187
|
|
483
188
|
model = 'gemini-1.0-pro'
|
484
|
-
|
485
|
-
|
486
|
-
class VertexAIGeminiPro1Vision(VertexAI): # pylint: disable=invalid-name
|
487
|
-
"""Vertex AI Gemini 1.0 Pro Vision model."""
|
488
|
-
|
489
|
-
model = 'gemini-1.0-pro-vision'
|
490
|
-
supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation
|
491
|
-
IMAGE_TYPES + VIDEO_TYPES
|
492
|
-
)
|
493
|
-
|
494
|
-
|
495
|
-
class VertexAIEndpoint(VertexAI): # pylint: disable=invalid-name
|
496
|
-
"""Vertex AI Endpoint model."""
|
497
|
-
|
498
|
-
model = 'vertexai-endpoint'
|
499
|
-
|
500
|
-
endpoint: Annotated[str, 'Vertex AI Endpoint ID.']
|
501
|
-
|
502
|
-
@property
|
503
|
-
def api_endpoint(self) -> str:
|
504
|
-
return (
|
505
|
-
f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
|
506
|
-
f'{self.project}/locations/{self.location}/'
|
507
|
-
f'endpoints/{self.endpoint}:generateContent'
|
508
|
-
)
|
@@ -11,105 +11,18 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""Tests for
|
14
|
+
"""Tests for VertexAI models."""
|
15
15
|
|
16
|
-
import base64
|
17
16
|
import os
|
18
|
-
from typing import Any
|
19
17
|
import unittest
|
20
18
|
from unittest import mock
|
21
19
|
|
22
|
-
import langfun.core as lf
|
23
|
-
from langfun.core import modalities as lf_modalities
|
24
20
|
from langfun.core.llms import vertexai
|
25
|
-
import pyglove as pg
|
26
|
-
import requests
|
27
|
-
|
28
|
-
|
29
|
-
example_image = (
|
30
|
-
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
|
31
|
-
b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
|
32
|
-
b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
|
33
|
-
b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
|
34
|
-
b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
|
35
|
-
b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
|
36
|
-
b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
|
37
|
-
b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
|
38
|
-
)
|
39
|
-
|
40
|
-
|
41
|
-
def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
|
42
|
-
del url, kwargs
|
43
|
-
c = pg.Dict(json['generationConfig'])
|
44
|
-
content = json['contents'][0]['parts'][0]['text']
|
45
|
-
response = requests.Response()
|
46
|
-
response.status_code = 200
|
47
|
-
response._content = pg.to_json_str({
|
48
|
-
'candidates': [
|
49
|
-
{
|
50
|
-
'content': {
|
51
|
-
'role': 'model',
|
52
|
-
'parts': [
|
53
|
-
{
|
54
|
-
'text': (
|
55
|
-
f'This is a response to {content} with '
|
56
|
-
f'temperature={c.temperature}, '
|
57
|
-
f'top_p={c.topP}, '
|
58
|
-
f'top_k={c.topK}, '
|
59
|
-
f'max_tokens={c.maxOutputTokens}, '
|
60
|
-
f'stop={"".join(c.stopSequences)}.'
|
61
|
-
)
|
62
|
-
},
|
63
|
-
],
|
64
|
-
},
|
65
|
-
},
|
66
|
-
],
|
67
|
-
'usageMetadata': {
|
68
|
-
'promptTokenCount': 3,
|
69
|
-
'candidatesTokenCount': 4,
|
70
|
-
}
|
71
|
-
}).encode()
|
72
|
-
return response
|
73
21
|
|
74
22
|
|
75
23
|
class VertexAITest(unittest.TestCase):
|
76
24
|
"""Tests for Vertex model with REST API."""
|
77
25
|
|
78
|
-
def test_content_from_message_text_only(self):
|
79
|
-
text = 'This is a beautiful day'
|
80
|
-
model = vertexai.VertexAIGeminiPro1_5_002()
|
81
|
-
chunks = model._content_from_message(lf.UserMessage(text))
|
82
|
-
self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]})
|
83
|
-
|
84
|
-
def test_content_from_message_mm(self):
|
85
|
-
image = lf_modalities.Image.from_bytes(example_image)
|
86
|
-
message = lf.UserMessage(
|
87
|
-
'This is an <<[[image]]>>, what is it?', image=image
|
88
|
-
)
|
89
|
-
|
90
|
-
# Non-multimodal model.
|
91
|
-
with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
|
92
|
-
vertexai.VertexAIGeminiPro1()._content_from_message(message)
|
93
|
-
|
94
|
-
model = vertexai.VertexAIGeminiPro1Vision()
|
95
|
-
content = model._content_from_message(message)
|
96
|
-
self.assertEqual(
|
97
|
-
content,
|
98
|
-
{
|
99
|
-
'role': 'user',
|
100
|
-
'parts': [
|
101
|
-
{'text': 'This is an'},
|
102
|
-
{
|
103
|
-
'inlineData': {
|
104
|
-
'data': base64.b64encode(example_image).decode(),
|
105
|
-
'mimeType': 'image/png',
|
106
|
-
}
|
107
|
-
},
|
108
|
-
{'text': ', what is it?'},
|
109
|
-
],
|
110
|
-
},
|
111
|
-
)
|
112
|
-
|
113
26
|
@mock.patch.object(vertexai.VertexAI, 'credentials', new=True)
|
114
27
|
def test_project_and_location_check(self):
|
115
28
|
with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
|
@@ -126,87 +39,14 @@ class VertexAITest(unittest.TestCase):
|
|
126
39
|
|
127
40
|
os.environ['VERTEXAI_PROJECT'] = 'abc'
|
128
41
|
os.environ['VERTEXAI_LOCATION'] = 'us-central1'
|
129
|
-
|
42
|
+
model = vertexai.VertexAIGeminiPro1()
|
43
|
+
self.assertTrue(model.model_id.startswith('VertexAI('))
|
44
|
+
self.assertIsNotNone(model.api_endpoint)
|
45
|
+
self.assertTrue(model._api_initialized)
|
46
|
+
self.assertIsNotNone(model._session)
|
130
47
|
del os.environ['VERTEXAI_PROJECT']
|
131
48
|
del os.environ['VERTEXAI_LOCATION']
|
132
49
|
|
133
|
-
def test_generation_config(self):
|
134
|
-
model = vertexai.VertexAIGeminiPro1()
|
135
|
-
json_schema = {
|
136
|
-
'type': 'object',
|
137
|
-
'properties': {
|
138
|
-
'name': {'type': 'string'},
|
139
|
-
},
|
140
|
-
'required': ['name'],
|
141
|
-
'title': 'Person',
|
142
|
-
}
|
143
|
-
actual = model._generation_config(
|
144
|
-
lf.UserMessage('hi', json_schema=json_schema),
|
145
|
-
lf.LMSamplingOptions(
|
146
|
-
temperature=2.0,
|
147
|
-
top_p=1.0,
|
148
|
-
top_k=20,
|
149
|
-
max_tokens=1024,
|
150
|
-
stop=['\n'],
|
151
|
-
),
|
152
|
-
)
|
153
|
-
self.assertEqual(
|
154
|
-
actual,
|
155
|
-
dict(
|
156
|
-
candidateCount=1,
|
157
|
-
temperature=2.0,
|
158
|
-
topP=1.0,
|
159
|
-
topK=20,
|
160
|
-
maxOutputTokens=1024,
|
161
|
-
stopSequences=['\n'],
|
162
|
-
responseLogprobs=False,
|
163
|
-
logprobs=None,
|
164
|
-
seed=None,
|
165
|
-
responseMimeType='application/json',
|
166
|
-
responseSchema={
|
167
|
-
'type': 'object',
|
168
|
-
'properties': {
|
169
|
-
'name': {'type': 'string'}
|
170
|
-
},
|
171
|
-
'required': ['name'],
|
172
|
-
'title': 'Person',
|
173
|
-
}
|
174
|
-
),
|
175
|
-
)
|
176
|
-
with self.assertRaisesRegex(
|
177
|
-
ValueError, '`json_schema` must be a dict, got'
|
178
|
-
):
|
179
|
-
model._generation_config(
|
180
|
-
lf.UserMessage('hi', json_schema='not a dict'),
|
181
|
-
lf.LMSamplingOptions(),
|
182
|
-
)
|
183
|
-
|
184
|
-
@mock.patch.object(vertexai.VertexAI, 'credentials', new=True)
|
185
|
-
def test_call_model(self):
|
186
|
-
with mock.patch('requests.Session.post') as mock_generate:
|
187
|
-
mock_generate.side_effect = mock_requests_post
|
188
|
-
|
189
|
-
lm = vertexai.VertexAIGeminiPro1_5_002(
|
190
|
-
project='abc', location='us-central1'
|
191
|
-
)
|
192
|
-
r = lm(
|
193
|
-
'hello',
|
194
|
-
temperature=2.0,
|
195
|
-
top_p=1.0,
|
196
|
-
top_k=20,
|
197
|
-
max_tokens=1024,
|
198
|
-
stop='\n',
|
199
|
-
)
|
200
|
-
self.assertEqual(
|
201
|
-
r.text,
|
202
|
-
(
|
203
|
-
'This is a response to hello with temperature=2.0, '
|
204
|
-
'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.'
|
205
|
-
),
|
206
|
-
)
|
207
|
-
self.assertEqual(r.metadata.usage.prompt_tokens, 3)
|
208
|
-
self.assertEqual(r.metadata.usage.completion_tokens, 4)
|
209
|
-
|
210
50
|
|
211
51
|
if __name__ == '__main__':
|
212
52
|
unittest.main()
|