langfun 0.1.2.dev202412020805__py3-none-any.whl → 0.1.2.dev202412030804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/llms/__init__.py +1 -7
- langfun/core/llms/openai.py +142 -207
- langfun/core/llms/openai_test.py +160 -224
- langfun/core/llms/vertexai.py +23 -422
- langfun/core/llms/vertexai_test.py +21 -335
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412030804.dist-info}/METADATA +1 -12
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412030804.dist-info}/RECORD +10 -10
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412030804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412030804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412030804.dist-info}/top_level.txt +0 -0
langfun/core/llms/vertexai.py
CHANGED
@@ -28,21 +28,13 @@ try:
   from google import auth as google_auth
   from google.auth import credentials as credentials_lib
   from google.auth.transport import requests as auth_requests
-  import vertexai
-  from google.cloud.aiplatform import models as aiplatform_models
-  from vertexai import generative_models
-  from vertexai import language_models
   # pylint: enable=g-import-not-at-top

   Credentials = credentials_lib.Credentials
 except ImportError:
   google_auth = None
+  credentials_lib = None
   auth_requests = None
-  credentials_lib = None  # pylint: disable=invalid-name
-  vertexai = None
-  generative_models = None
-  language_models = None
-  aiplatform_models = None
   Credentials = Any

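After this hunk, the optional-import guard covers only the google-auth pieces. A self-contained sketch of the resulting pattern, assembled from the new side of the hunk:

```python
from typing import Any

# Optional-dependency guard: if the google-auth extra isn't installed,
# fall back to None placeholders so importing the module never raises.
try:
  from google import auth as google_auth
  from google.auth import credentials as credentials_lib
  from google.auth.transport import requests as auth_requests

  Credentials = credentials_lib.Credentials
except ImportError:
  google_auth = None
  credentials_lib = None
  auth_requests = None
  Credentials = Any
```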
@@ -56,408 +48,72 @@ AVGERAGE_CHARS_PER_TOEKN = 4
 # as of 2024-10-10.
 SUPPORTED_MODELS_AND_SETTINGS = {
     'gemini-1.5-pro-001': pg.Dict(
-        api='gemini',
         rpm=100,
         cost_per_1k_input_chars=0.0003125,
         cost_per_1k_output_chars=0.00125,
     ),
     'gemini-1.5-pro-002': pg.Dict(
-        api='gemini',
         rpm=100,
         cost_per_1k_input_chars=0.0003125,
         cost_per_1k_output_chars=0.00125,
     ),
     'gemini-1.5-flash-002': pg.Dict(
-        api='gemini',
         rpm=500,
         cost_per_1k_input_chars=0.00001875,
         cost_per_1k_output_chars=0.000075,
     ),
     'gemini-1.5-flash-001': pg.Dict(
-        api='gemini',
         rpm=500,
         cost_per_1k_input_chars=0.00001875,
         cost_per_1k_output_chars=0.000075,
     ),
     'gemini-1.5-pro': pg.Dict(
-        api='gemini',
         rpm=100,
         cost_per_1k_input_chars=0.0003125,
         cost_per_1k_output_chars=0.00125,
     ),
     'gemini-1.5-flash': pg.Dict(
-        api='gemini',
-        rpm=500,
-        cost_per_1k_input_chars=0.00001875,
-        cost_per_1k_output_chars=0.000075,
-    ),
-    'gemini-1.5-pro-latest': pg.Dict(
-        api='gemini',
-        rpm=100,
-        cost_per_1k_input_chars=0.0003125,
-        cost_per_1k_output_chars=0.00125,
-    ),
-    'gemini-1.5-flash-latest': pg.Dict(
-        api='gemini',
         rpm=500,
         cost_per_1k_input_chars=0.00001875,
         cost_per_1k_output_chars=0.000075,
     ),
     'gemini-1.5-pro-preview-0514': pg.Dict(
-        api='gemini',
         rpm=50,
         cost_per_1k_input_chars=0.0003125,
         cost_per_1k_output_chars=0.00125,
     ),
     'gemini-1.5-pro-preview-0409': pg.Dict(
-        api='gemini',
         rpm=50,
         cost_per_1k_input_chars=0.0003125,
         cost_per_1k_output_chars=0.00125,
     ),
     'gemini-1.5-flash-preview-0514': pg.Dict(
-        api='gemini',
         rpm=200,
         cost_per_1k_input_chars=0.00001875,
         cost_per_1k_output_chars=0.000075,
     ),
     'gemini-1.0-pro': pg.Dict(
-        api='gemini',
         rpm=300,
         cost_per_1k_input_chars=0.000125,
         cost_per_1k_output_chars=0.000375,
     ),
     'gemini-1.0-pro-vision': pg.Dict(
-        api='gemini',
         rpm=100,
         cost_per_1k_input_chars=0.000125,
         cost_per_1k_output_chars=0.000375,
     ),
-    # PaLM APIs.
-    'text-bison': pg.Dict(
-        api='palm',
-        rpm=1600
-    ),
-    'text-bison-32k': pg.Dict(
-        api='palm',
-        rpm=300
-    ),
-    'text-unicorn': pg.Dict(
-        api='palm',
-        rpm=100
-    ),
-    # Endpoint
     # TODO(chengrun): Set a more appropriate rpm for endpoint.
-    '
+    'vertexai-endpoint': pg.Dict(
+        rpm=20,
+        cost_per_1k_input_chars=0.0000125,
+        cost_per_1k_output_chars=0.0000375,
+    ),
 }


-@lf.use_init_args(['model'])
-class VertexAI(lf.LanguageModel):
-  """Language model served on VertexAI."""
-
-  model: pg.typing.Annotated[
-      pg.typing.Enum(
-          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
-      ),
-      (
-          'Vertex AI model name. See '
-          'https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models '
-          'for details.'
-      ),
-  ]
-
-  endpoint_name: pg.typing.Annotated[
-      str | None,
-      'Vertex Endpoint name or ID.',
-  ]
-
-  project: Annotated[
-      str | None,
-      (
-          'Vertex AI project ID. Or set from environment variable '
-          'VERTEXAI_PROJECT.'
-      ),
-  ] = None
-
-  location: Annotated[
-      str | None,
-      (
-          'Vertex AI service location. Or set from environment variable '
-          'VERTEXAI_LOCATION.'
-      ),
-  ] = None
-
-  credentials: Annotated[
-      Credentials | None,
-      (
-          'Credentials to use. If None, the default credentials to the '
-          'environment will be used.'
-      ),
-  ] = None
-
-  supported_modalities: Annotated[
-      list[str],
-      'A list of MIME types for supported modalities'
-  ] = []
-
-  def _on_bound(self):
-    super()._on_bound()
-    self.__dict__.pop('_api_initialized', None)
-    if generative_models is None:
-      raise RuntimeError(
-          'Please install "langfun[llm-google-vertex]" to use Vertex AI models.'
-      )
-
-  @functools.cached_property
-  def _api_initialized(self):
-    project = self.project or os.environ.get('VERTEXAI_PROJECT', None)
-    if not project:
-      raise ValueError(
-          'Please specify `project` during `__init__` or set environment '
-          'variable `VERTEXAI_PROJECT` with your Vertex AI project ID.'
-      )
-
-    location = self.location or os.environ.get('VERTEXAI_LOCATION', None)
-    if not location:
-      raise ValueError(
-          'Please specify `location` during `__init__` or set environment '
-          'variable `VERTEXAI_LOCATION` with your Vertex AI service location.'
-      )
-
-    credentials = self.credentials
-    # Placeholder for Google-internal credentials.
-    assert vertexai is not None
-    vertexai.init(project=project, location=location, credentials=credentials)
-    return True
-
-  @property
-  def model_id(self) -> str:
-    """Returns a string to identify the model."""
-    return f'VertexAI({self.model})'
-
-  @property
-  def resource_id(self) -> str:
-    """Returns a string to identify the resource for rate control."""
-    return self.model_id
-
-  @property
-  def max_concurrency(self) -> int:
-    """Returns the maximum number of concurrent requests."""
-    return self.rate_to_max_concurrency(
-        requests_per_min=SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm,
-        tokens_per_min=0,
-    )
-
-  def estimate_cost(
-      self,
-      num_input_tokens: int,
-      num_output_tokens: int
-  ) -> float | None:
-    """Estimate the cost based on usage."""
-    cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
-        'cost_per_1k_input_chars', None
-    )
-    cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
-        'cost_per_1k_output_chars', None
-    )
-    if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None:
-      return None
-    return (
-        cost_per_1k_input_chars * num_input_tokens
-        + cost_per_1k_output_chars * num_output_tokens
-    ) * AVGERAGE_CHARS_PER_TOEKN / 1000
-
-  def _generation_config(
-      self, prompt: lf.Message, options: lf.LMSamplingOptions
-  ) -> Any:  # generative_models.GenerationConfig
-    """Creates generation config from langfun sampling options."""
-    assert generative_models is not None
-    # Users could use `metadata_json_schema` to pass additional
-    # request arguments.
-    json_schema = prompt.metadata.get('json_schema')
-    response_mime_type = None
-    if json_schema is not None:
-      if not isinstance(json_schema, dict):
-        raise ValueError(
-            f'`json_schema` must be a dict, got {json_schema!r}.'
-        )
-      response_mime_type = 'application/json'
-      prompt.metadata.formatted_text = (
-          prompt.text
-          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
-          + pg.to_json_str(json_schema, json_indent=2)
-      )
-
-    return generative_models.GenerationConfig(
-        temperature=options.temperature,
-        top_p=options.top_p,
-        top_k=options.top_k,
-        max_output_tokens=options.max_tokens,
-        stop_sequences=options.stop,
-        response_mime_type=response_mime_type,
-        response_schema=json_schema,
-    )
-
-  def _content_from_message(
-      self, prompt: lf.Message
-  ) -> list[str | Any]:
-    """Gets generation input from langfun message."""
-    assert generative_models is not None
-    chunks = []
-
-    for lf_chunk in prompt.chunk():
-      if isinstance(lf_chunk, str):
-        chunks.append(lf_chunk)
-      elif isinstance(lf_chunk, lf_modalities.Mime):
-        try:
-          modalities = lf_chunk.make_compatible(
-              self.supported_modalities + ['text/plain']
-          )
-          if isinstance(modalities, lf_modalities.Mime):
-            modalities = [modalities]
-          for modality in modalities:
-            if modality.is_text:
-              chunk = modality.to_text()
-            else:
-              chunk = generative_models.Part.from_data(
-                  modality.to_bytes(), modality.mime_type
-              )
-            chunks.append(chunk)
-        except lf.ModalityError as e:
-          raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
-      else:
-        raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
-    return chunks
-
-  def _generation_response_to_message(
-      self,
-      response: Any,  # generative_models.GenerationResponse
-  ) -> lf.Message:
-    """Parses generative response into message."""
-    return lf.AIMessage(response.text)
-
-  def _generation_endpoint_response_to_message(
-      self,
-      response: Any,  # google.cloud.aiplatform.aiplatform.models.Prediction
-  ) -> lf.Message:
-    """Parses Endpoint response into message."""
-    return lf.AIMessage(response.predictions[0])
-
-  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-    assert self._api_initialized, 'Vertex AI API is not initialized.'
-    # TODO(yifenglu): It seems this exception is due to the instability of the
-    # API. We should revisit this later.
-    retry_on_errors = [
-        (Exception, 'InternalServerError'),
-        (Exception, 'ResourceExhausted'),
-        (Exception, '_InactiveRpcError'),
-        (Exception, 'ValueError'),
-    ]
-
-    return self._parallel_execute_with_currency_control(
-        self._sample_single,
-        prompts,
-        retry_on_errors=retry_on_errors,
-    )
-
-  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    if self.sampling_options.n > 1:
-      raise ValueError(
-          f'`n` greater than 1 is not supported: {self.sampling_options.n}.'
-      )
-    api = SUPPORTED_MODELS_AND_SETTINGS[self.model].api
-    match api:
-      case 'gemini':
-        return self._sample_generative_model(prompt)
-      case 'palm':
-        return self._sample_text_generation_model(prompt)
-      case 'endpoint':
-        return self._sample_endpoint_model(prompt)
-      case _:
-        raise ValueError(f'Unsupported API: {api}')
-
-  def _sample_generative_model(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    """Samples a generative model."""
-    model = _VERTEXAI_MODEL_HUB.get_generative_model(self.model)
-    input_content = self._content_from_message(prompt)
-    response = model.generate_content(
-        input_content,
-        generation_config=self._generation_config(
-            prompt, self.sampling_options
-        ),
-    )
-    usage_metadata = response.usage_metadata
-    usage = lf.LMSamplingUsage(
-        prompt_tokens=usage_metadata.prompt_token_count,
-        completion_tokens=usage_metadata.candidates_token_count,
-        total_tokens=usage_metadata.total_token_count,
-        estimated_cost=self.estimate_cost(
-            num_input_tokens=usage_metadata.prompt_token_count,
-            num_output_tokens=usage_metadata.candidates_token_count,
-        ),
-    )
-    return lf.LMSamplingResult(
-        [
-            # Scoring is not supported.
-            lf.LMSample(
-                self._generation_response_to_message(response), score=0.0
-            ),
-        ],
-        usage=usage,
-    )
-
-  def _sample_text_generation_model(
-      self, prompt: lf.Message
-  ) -> lf.LMSamplingResult:
-    """Samples a text generation model."""
-    model = _VERTEXAI_MODEL_HUB.get_text_generation_model(self.model)
-    predict_options = dict(
-        temperature=self.sampling_options.temperature,
-        top_k=self.sampling_options.top_k,
-        top_p=self.sampling_options.top_p,
-        max_output_tokens=self.sampling_options.max_tokens,
-        stop_sequences=self.sampling_options.stop,
-    )
-    response = model.predict(prompt.text, **predict_options)
-    return lf.LMSamplingResult([
-        # Scoring is not supported.
-        lf.LMSample(lf.AIMessage(response.text), score=0.0)
-    ])
-
-  def _sample_endpoint_model(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    """Samples a text generation model."""
-    assert aiplatform_models is not None
-    model = aiplatform_models.Endpoint(self.endpoint_name)
-    # TODO(chengrun): Add support for stop_sequences.
-    predict_options = dict(
-        temperature=self.sampling_options.temperature
-        if self.sampling_options.temperature is not None
-        else 1.0,
-        top_k=self.sampling_options.top_k
-        if self.sampling_options.top_k is not None
-        else 32,
-        top_p=self.sampling_options.top_p
-        if self.sampling_options.top_p is not None
-        else 1,
-        max_tokens=self.sampling_options.max_tokens
-        if self.sampling_options.max_tokens is not None
-        else 8192,
-    )
-    instances = [{'prompt': prompt.text, **predict_options}]
-    response = model.predict(instances=instances)
-
-    return lf.LMSamplingResult([
-        # Scoring is not supported.
-        lf.LMSample(
-            self._generation_endpoint_response_to_message(response), score=0.0
-        )
-    ])
-
-
 @lf.use_init_args(['model'])
 @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
-class
+class VertexAI(rest.REST):
   """Language model served on VertexAI with REST API."""

   model: pg.typing.Annotated[
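The removed estimate_cost method priced a request from the character-based rates in SUPPORTED_MODELS_AND_SETTINGS, converting token counts to characters with the module constant AVGERAGE_CHARS_PER_TOEKN = 4. A minimal standalone sketch of that calculation (hypothetical helper name, rates copied from the table above):

```python
# Mirrors the removed VertexAI.estimate_cost: rates are per 1k characters,
# and tokens are converted to characters with the module's constant
# (spelled AVGERAGE_CHARS_PER_TOEKN = 4 in the source).
AVERAGE_CHARS_PER_TOKEN = 4


def estimate_cost_chars(
    num_input_tokens: int,
    num_output_tokens: int,
    cost_per_1k_input_chars: float,
    cost_per_1k_output_chars: float,
) -> float:
  """Estimates request cost in USD from character-based rates."""
  return (
      cost_per_1k_input_chars * num_input_tokens
      + cost_per_1k_output_chars * num_output_tokens
  ) * AVERAGE_CHARS_PER_TOKEN / 1000


# gemini-1.5-pro rates from the table: 1,000 input + 1,000 output tokens
# -> (0.0003125 * 1000 + 0.00125 * 1000) * 4 / 1000 = 0.00625 USD.
print(estimate_cost_chars(1000, 1000, 0.0003125, 0.00125))
```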
@@ -687,39 +343,6 @@ class VertexAIRest(rest.REST):
     return lf.AIMessage.from_chunks(chunks)


-class _ModelHub:
-  """Vertex AI model hub."""
-
-  def __init__(self):
-    self._generative_model_cache = {}
-    self._text_generation_model_cache = {}
-
-  def get_generative_model(
-      self, model_id: str
-  ) -> Any:  # generative_models.GenerativeModel:
-    """Gets a generative model by model id."""
-    model = self._generative_model_cache.get(model_id, None)
-    if model is None:
-      assert generative_models is not None
-      model = generative_models.GenerativeModel(model_id)
-      self._generative_model_cache[model_id] = model
-    return model
-
-  def get_text_generation_model(
-      self, model_id: str
-  ) -> Any:  # language_models.TextGenerationModel
-    """Gets a text generation model by model id."""
-    model = self._text_generation_model_cache.get(model_id, None)
-    if model is None:
-      assert language_models is not None
-      model = language_models.TextGenerationModel.from_pretrained(model_id)
-      self._text_generation_model_cache[model_id] = model
-    return model
-
-
-_VERTEXAI_MODEL_HUB = _ModelHub()
-
-
 _IMAGE_TYPES = [
     'image/png',
     'image/jpeg',
@@ -773,33 +396,19 @@ class VertexAIGemini1_5(VertexAI):  # pylint: disable=invalid-name
   )


-class VertexAIGeminiPro1_5_Latest(VertexAIGemini1_5):  # pylint: disable=invalid-name
-  """Vertex AI Gemini 1.5 Pro model."""
-
-  model = 'gemini-1.5-pro-latest'
-
-
 class VertexAIGeminiPro1_5(VertexAIGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""

   model = 'gemini-1.5-pro'


-class
-  """Vertex AI Gemini 1.5 model with REST API."""
-
-  supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
-      _DOCUMENT_TYPES + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES
-  )
-
-
-class VertexAIGeminiPro1_5_002(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
+class VertexAIGeminiPro1_5_002(VertexAIGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""

   model = 'gemini-1.5-pro-002'


-class VertexAIGeminiPro1_5_001(
+class VertexAIGeminiPro1_5_001(VertexAIGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""

   model = 'gemini-1.5-pro-001'
@@ -817,25 +426,19 @@ class VertexAIGeminiPro1_5_0409(VertexAIGemini1_5):  # pylint: disable=invalid-n
   model = 'gemini-1.5-pro-preview-0409'


-class VertexAIGeminiFlash1_5_Latest(VertexAIGemini1_5):  # pylint: disable=invalid-name
-  """Vertex AI Gemini 1.5 Flash model."""
-
-  model = 'gemini-1.5-flash-latest'
-
-
 class VertexAIGeminiFlash1_5(VertexAIGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""

   model = 'gemini-1.5-flash'


-class VertexAIGeminiFlash1_5_002(
+class VertexAIGeminiFlash1_5_002(VertexAIGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""

   model = 'gemini-1.5-flash-002'


-class VertexAIGeminiFlash1_5_001(
+class VertexAIGeminiFlash1_5_001(VertexAIGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""

   model = 'gemini-1.5-flash-001'
@@ -847,7 +450,7 @@ class VertexAIGeminiFlash1_5_0514(VertexAIGemini1_5):  # pylint: disable=invalid
   model = 'gemini-1.5-flash-preview-0514'


-class VertexAIGeminiPro1(
+class VertexAIGeminiPro1(VertexAI):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.0 Pro model."""

   model = 'gemini-1.0-pro'
@@ -862,19 +465,17 @@ class VertexAIGeminiPro1Vision(VertexAI):  # pylint: disable=invalid-name
   )


-class
-  """Vertex AI
-
-  model = 'text-bison'
-
-
-class VertexAIPalm2_32K(VertexAI):  # pylint: disable=invalid-name
-  """Vertex AI PaLM2 text generation model (32K context length)."""
+class VertexAIEndpoint(VertexAI):  # pylint: disable=invalid-name
+  """Vertex AI Endpoint model."""

-  model = '
+  model = 'vertexai-endpoint'

+  endpoint: Annotated[str, 'Vertex AI Endpoint ID.']

-
-
-
-
+  @property
+  def api_endpoint(self) -> str:
+    return (
+        f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
+        f'{self.project}/locations/{self.location}/'
+        f'endpoints/{self.endpoint}:generateContent'
+    )
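For illustration, the api_endpoint property added above builds the regional generateContent URL from the instance's project, location, and endpoint fields. A rough standalone sketch with made-up placeholder values:

```python
# Placeholder values; in the class they come from self.project, self.location
# and self.endpoint.
project = 'my-project'
location = 'us-central1'
endpoint = '1234567890'

api_endpoint = (
    f'https://{location}-aiplatform.googleapis.com/v1/projects/'
    f'{project}/locations/{location}/'
    f'endpoints/{endpoint}:generateContent'
)
print(api_endpoint)
# https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/endpoints/1234567890:generateContent
```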