langfun 0.1.2.dev202412010804__py3-none-any.whl → 0.1.2.dev202412030804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/eval/v2/progress_tracking.py +2 -1
- langfun/core/eval/v2/progress_tracking_test.py +10 -0
- langfun/core/llms/__init__.py +1 -5
- langfun/core/llms/openai.py +142 -202
- langfun/core/llms/openai_test.py +160 -224
- langfun/core/llms/vertexai.py +133 -286
- langfun/core/llms/vertexai_test.py +74 -194
- {langfun-0.1.2.dev202412010804.dist-info → langfun-0.1.2.dev202412030804.dist-info}/METADATA +1 -12
- {langfun-0.1.2.dev202412010804.dist-info → langfun-0.1.2.dev202412030804.dist-info}/RECORD +12 -12
- {langfun-0.1.2.dev202412010804.dist-info → langfun-0.1.2.dev202412030804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202412010804.dist-info → langfun-0.1.2.dev202412030804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202412010804.dist-info → langfun-0.1.2.dev202412030804.dist-info}/top_level.txt +0 -0
langfun/core/llms/vertexai.py
CHANGED
@@ -13,30 +13,28 @@
 # limitations under the License.
 """Vertex AI generative models."""
 
+import base64
 import functools
 import os
 from typing import Annotated, Any
 
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
 import pyglove as pg
 
 try:
   # pylint: disable=g-import-not-at-top
+  from google import auth as google_auth
   from google.auth import credentials as credentials_lib
-  import vertexai
-  from google.cloud.aiplatform import models as aiplatform_models
-  from vertexai import generative_models
-  from vertexai import language_models
+  from google.auth.transport import requests as auth_requests
   # pylint: enable=g-import-not-at-top
 
   Credentials = credentials_lib.Credentials
 except ImportError:
-  credentials_lib = None
-  vertexai = None
-  generative_models = None
-  language_models = None
-  aiplatform_models = None
+  google_auth = None
+  credentials_lib = None
+  auth_requests = None
   Credentials = Any
 
 
@@ -50,122 +48,86 @@ AVGERAGE_CHARS_PER_TOEKN = 4
 # as of 2024-10-10.
 SUPPORTED_MODELS_AND_SETTINGS = {
     'gemini-1.5-pro-001': pg.Dict(
-        api='gemini',
         rpm=100,
         cost_per_1k_input_chars=0.0003125,
         cost_per_1k_output_chars=0.00125,
     ),
     'gemini-1.5-pro-002': pg.Dict(
-        api='gemini',
         rpm=100,
         cost_per_1k_input_chars=0.0003125,
         cost_per_1k_output_chars=0.00125,
     ),
     'gemini-1.5-flash-002': pg.Dict(
-        api='gemini',
         rpm=500,
         cost_per_1k_input_chars=0.00001875,
         cost_per_1k_output_chars=0.000075,
     ),
     'gemini-1.5-flash-001': pg.Dict(
-        api='gemini',
         rpm=500,
         cost_per_1k_input_chars=0.00001875,
         cost_per_1k_output_chars=0.000075,
     ),
     'gemini-1.5-pro': pg.Dict(
-        api='gemini',
         rpm=100,
         cost_per_1k_input_chars=0.0003125,
         cost_per_1k_output_chars=0.00125,
     ),
     'gemini-1.5-flash': pg.Dict(
-        api='gemini',
-        rpm=500,
-        cost_per_1k_input_chars=0.00001875,
-        cost_per_1k_output_chars=0.000075,
-    ),
-    'gemini-1.5-pro-latest': pg.Dict(
-        api='gemini',
-        rpm=100,
-        cost_per_1k_input_chars=0.0003125,
-        cost_per_1k_output_chars=0.00125,
-    ),
-    'gemini-1.5-flash-latest': pg.Dict(
-        api='gemini',
         rpm=500,
         cost_per_1k_input_chars=0.00001875,
         cost_per_1k_output_chars=0.000075,
     ),
     'gemini-1.5-pro-preview-0514': pg.Dict(
-        api='gemini',
         rpm=50,
         cost_per_1k_input_chars=0.0003125,
         cost_per_1k_output_chars=0.00125,
     ),
     'gemini-1.5-pro-preview-0409': pg.Dict(
-        api='gemini',
         rpm=50,
         cost_per_1k_input_chars=0.0003125,
         cost_per_1k_output_chars=0.00125,
     ),
     'gemini-1.5-flash-preview-0514': pg.Dict(
-        api='gemini',
         rpm=200,
         cost_per_1k_input_chars=0.00001875,
         cost_per_1k_output_chars=0.000075,
     ),
     'gemini-1.0-pro': pg.Dict(
-        api='gemini',
         rpm=300,
         cost_per_1k_input_chars=0.000125,
         cost_per_1k_output_chars=0.000375,
     ),
     'gemini-1.0-pro-vision': pg.Dict(
-        api='gemini',
         rpm=100,
         cost_per_1k_input_chars=0.000125,
         cost_per_1k_output_chars=0.000375,
     ),
-    # PaLM APIs.
-    'text-bison': pg.Dict(
-        api='palm',
-        rpm=1600
-    ),
-    'text-bison-32k': pg.Dict(
-        api='palm',
-        rpm=300
-    ),
-    'text-unicorn': pg.Dict(
-        api='palm',
-        rpm=100
-    ),
-    # Endpoint
     # TODO(chengrun): Set a more appropriate rpm for endpoint.
-    'custom': pg.Dict(api='endpoint', rpm=20),
+    'vertexai-endpoint': pg.Dict(
+        rpm=20,
+        cost_per_1k_input_chars=0.0000125,
+        cost_per_1k_output_chars=0.0000375,
+    ),
 }
 
 
 @lf.use_init_args(['model'])
-class VertexAI(lf.LanguageModel):
-  """Language model served on Vertex AI."""
+@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
+class VertexAI(rest.REST):
+  """Language model served on VertexAI with REST API."""
 
   model: pg.typing.Annotated[
       pg.typing.Enum(
          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
       ),
       (
-          'Vertex AI model name. See '
-          'https://cloud.google.com/vertex-ai/generative-ai/docs/
-          '
+          'Vertex AI model name with REST API support. See '
+          'https://cloud.google.com/vertex-ai/generative-ai/docs/'
+          'model-reference/inference#supported-models'
+          ' for details.'
       ),
   ]
 
-  endpoint_name: pg.typing.Annotated[
-      str | None,
-      'Vertex Endpoint name or ID.',
-  ]
-
   project: Annotated[
       str | None,
       (
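Note: the settings table above prices models per 1,000 characters, while the API reports usage in tokens; the `estimate_cost()` context lines in a later hunk bridge the two with the `AVGERAGE_CHARS_PER_TOEKN = 4` constant (the misspelling is the source's own identifier). A worked sketch of that arithmetic, with invented token counts:

```python
# Illustrative only: reproduces the estimate_cost() arithmetic with the
# gemini-1.5-pro rates from the table above and made-up token counts.
AVGERAGE_CHARS_PER_TOEKN = 4  # sic: identifier as spelled in the source

cost_per_1k_input_chars = 0.0003125
cost_per_1k_output_chars = 0.00125
num_input_tokens, num_output_tokens = 1000, 200

estimated_cost = (
    cost_per_1k_input_chars * num_input_tokens
    + cost_per_1k_output_chars * num_output_tokens
) * AVGERAGE_CHARS_PER_TOEKN / 1000
print(estimated_cost)  # 0.00225 USD: (0.3125 + 0.25) * 4 / 1000
```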
@@ -197,14 +159,14 @@ class VertexAI(lf.LanguageModel):
 
   def _on_bound(self):
     super()._on_bound()
-
-    if generative_models is None:
-      raise RuntimeError(
+    if google_auth is None:
+      raise ValueError(
           'Please install "langfun[llm-google-vertex]" to use Vertex AI models.'
       )
+    self._project = None
+    self._credentials = None
 
-  @functools.cached_property
-  def _api_initialized(self):
+  def _initialize(self):
     project = self.project or os.environ.get('VERTEXAI_PROJECT', None)
     if not project:
       raise ValueError(
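Note: with the SDK imports gone, the guard in `_on_bound` keys off `google_auth` and now raises `ValueError` rather than `RuntimeError`. A hypothetical sketch of the failure mode (class name and error message are from this diff; the values are assumed):

```python
# Hypothetical: constructing a Vertex AI model without the optional
# dependency installed now fails at binding time with a ValueError.
from langfun.core.llms import vertexai

try:
  lm = vertexai.VertexAIGeminiFlash1_5(
      project='my-project', location='us-central1'  # assumed values
  )
except ValueError as e:
  # Prints: Please install "langfun[llm-google-vertex]" to use Vertex AI models.
  print(e)
```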
@@ -219,21 +181,14 @@ class VertexAI(lf.LanguageModel):
           'variable `VERTEXAI_LOCATION` with your Vertex AI service location.'
       )
 
+    self._project = project
     credentials = self.credentials
-    # Placeholder for Google-internal credentials.
-    assert vertexai is not None
-    vertexai.init(project=project, location=location, credentials=credentials)
-    return True
-
-  @property
-  def model_id(self) -> str:
-    """Returns a string to identify the model."""
-    return f'VertexAI({self.model})'
-
-  @property
-  def resource_id(self) -> str:
-    """Returns a string to identify the resource for rate control."""
-    return self.model_id
+    if credentials is None:
+      # Use default credentials.
+      credentials = google_auth.default(
+          scopes=['https://www.googleapis.com/auth/cloud-platform']
+      )
+    self._credentials = credentials
 
   @property
   def max_concurrency(self) -> int:
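Note: when no explicit credentials are supplied, `_initialize` falls back to Application Default Credentials. One caveat worth flagging: `google.auth.default()` returns a `(credentials, project_id)` tuple, so a standalone sketch of the same fallback unpacks it before building the authorized session that appears in the next hunk:

```python
# Standalone sketch; assumes google-auth is installed and ADC is configured
# (e.g. via `gcloud auth application-default login`).
from google import auth as google_auth
from google.auth.transport import requests as auth_requests

# google_auth.default() returns a (credentials, project_id) tuple.
credentials, _ = google_auth.default(
    scopes=['https://www.googleapis.com/auth/cloud-platform']
)
session = auth_requests.AuthorizedSession(credentials)  # signs each request
```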
@@ -262,47 +217,75 @@ class VertexAI(lf.LanguageModel):
         + cost_per_1k_output_chars * num_output_tokens
     ) * AVGERAGE_CHARS_PER_TOEKN / 1000
 
+  @functools.cached_property
+  def _session(self):
+    assert self._api_initialized
+    assert self._credentials is not None
+    assert auth_requests is not None
+    s = auth_requests.AuthorizedSession(self._credentials)
+    s.headers.update(self.headers or {})
+    return s
+
+  @property
+  def headers(self):
+    return {
+        'Content-Type': 'application/json; charset=utf-8',
+    }
+
+  @property
+  def api_endpoint(self) -> str:
+    return (
+        f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
+        f'{self.project}/locations/{self.location}/publishers/google/'
+        f'models/{self.model}:generateContent'
+    )
+
+  def request(
+      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    request = dict(
+        generationConfig=self._generation_config(prompt, sampling_options)
+    )
+    request['contents'] = [self._content_from_message(prompt)]
+    return request
+
   def _generation_config(
       self, prompt: lf.Message, options: lf.LMSamplingOptions
-  ) -> Any:
-    """
-
-
-
-
-
-
+  ) -> dict[str, Any]:
+    """Returns a dict as generation config for prompt and LMSamplingOptions."""
+    config = dict(
+        temperature=options.temperature,
+        maxOutputTokens=options.max_tokens,
+        candidateCount=options.n,
+        topK=options.top_k,
+        topP=options.top_p,
+        stopSequences=options.stop,
+        seed=options.random_seed,
+        responseLogprobs=options.logprobs,
+        logprobs=options.top_logprobs,
+    )
+
+    if json_schema := prompt.metadata.get('json_schema'):
       if not isinstance(json_schema, dict):
         raise ValueError(
             f'`json_schema` must be a dict, got {json_schema!r}.'
         )
-
+      json_schema = pg.to_json(json_schema)
+      config['responseSchema'] = json_schema
+      config['responseMimeType'] = 'application/json'
       prompt.metadata.formatted_text = (
          prompt.text
          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
          + pg.to_json_str(json_schema, json_indent=2)
      )
+    return config
 
-    return generative_models.GenerationConfig(
-        temperature=options.temperature,
-        top_p=options.top_p,
-        top_k=options.top_k,
-        max_output_tokens=options.max_tokens,
-        stop_sequences=options.stop,
-        response_mime_type=response_mime_type,
-        response_schema=json_schema,
-    )
-
-  def _content_from_message(
-      self, prompt: lf.Message
-  ) -> list[str | Any]:
-    """Gets generation input from langfun message."""
-    assert generative_models is not None
-    chunks = []
-
+  def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
+    """Gets generation content from langfun message."""
+    parts = []
     for lf_chunk in prompt.chunk():
       if isinstance(lf_chunk, str):
-        chunks.append(lf_chunk)
+        parts.append({'text': lf_chunk})
       elif isinstance(lf_chunk, lf_modalities.Mime):
         try:
           modalities = lf_chunk.make_compatible(
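Note: `request()` assembles the JSON body for the `generateContent` REST call, pairing a `generationConfig` mapped from `LMSamplingOptions` with a single-turn `contents` list built by `_content_from_message`. A hand-written illustration (all values assumed) of the payload shape those methods produce:

```python
import base64

fake_png = b'\x89PNG\r\n\x1a\n'  # placeholder bytes, not a real image
payload = {
    'generationConfig': {         # subset of fields set in _generation_config
        'temperature': 0.7,       # options.temperature
        'maxOutputTokens': 1024,  # options.max_tokens
        'candidateCount': 1,      # options.n
    },
    'contents': [{                # built by _content_from_message
        'role': 'user',
        'parts': [
            {'text': 'Describe this image.'},
            {'inlineData': {
                'data': base64.b64encode(fake_png).decode(),
                'mimeType': 'image/png',
            }},
        ],
    }],
}
```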
@@ -312,174 +295,52 @@ class VertexAI(lf.LanguageModel):
             modalities = [modalities]
           for modality in modalities:
             if modality.is_text:
-              chunk = modality.to_text()
+              parts.append({'text': modality.to_text()})
             else:
-              chunk = generative_models.Part.from_data(
-                  modality.to_bytes(), modality.mime_type
-              )
-            chunks.append(chunk)
+              parts.append({
+                  'inlineData': {
+                      'data': base64.b64encode(modality.to_bytes()).decode(),
+                      'mimeType': modality.mime_type,
+                  }
+              })
       except lf.ModalityError as e:
         raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
       else:
         raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
-    return chunks
-
-  def _generation_response_to_message(
-      self,
-      response: Any,  # generative_models.GenerationResponse
-  ) -> lf.Message:
-    """Parses generative response into message."""
-    return lf.AIMessage(response.text)
+    return dict(role='user', parts=parts)
 
-  def _generation_endpoint_response_to_message(
-      self,
-      response: Any,
-  ) -> lf.Message:
-    """Parses Endpoint response into message."""
-    return lf.AIMessage(response.predictions[0])
-
-  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-    assert self._api_initialized, 'Vertex AI API is not initialized.'
-    # TODO(yifenglu): It seems this exception is due to the instability of the
-    # API. We should revisit this later.
-    retry_on_errors = [
-        (Exception, 'InternalServerError'),
-        (Exception, 'ResourceExhausted'),
-        (Exception, '_InactiveRpcError'),
-        (Exception, 'ValueError'),
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    messages = [
+        self._message_from_content_parts(candidate['content']['parts'])
+        for candidate in json['candidates']
     ]
-
-
-
-        prompts,
-        retry_on_errors=retry_on_errors,
-    )
-
-  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    if self.sampling_options.n > 1:
-      raise ValueError(
-          f'`n` greater than 1 is not supported: {self.sampling_options.n}.'
-      )
-    api = SUPPORTED_MODELS_AND_SETTINGS[self.model].api
-    match api:
-      case 'gemini':
-        return self._sample_generative_model(prompt)
-      case 'palm':
-        return self._sample_text_generation_model(prompt)
-      case 'endpoint':
-        return self._sample_endpoint_model(prompt)
-      case _:
-        raise ValueError(f'Unsupported API: {api}')
-
-  def _sample_generative_model(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    """Samples a generative model."""
-    model = _VERTEXAI_MODEL_HUB.get_generative_model(self.model)
-    input_content = self._content_from_message(prompt)
-    response = model.generate_content(
-        input_content,
-        generation_config=self._generation_config(
-            prompt, self.sampling_options
-        ),
-    )
-    usage_metadata = response.usage_metadata
-    usage = lf.LMSamplingUsage(
-        prompt_tokens=usage_metadata.prompt_token_count,
-        completion_tokens=usage_metadata.candidates_token_count,
-        total_tokens=usage_metadata.total_token_count,
-        estimated_cost=self.estimate_cost(
-            num_input_tokens=usage_metadata.prompt_token_count,
-            num_output_tokens=usage_metadata.candidates_token_count,
-        ),
-    )
+    usage = json['usageMetadata']
+    input_tokens = usage['promptTokenCount']
+    output_tokens = usage['candidatesTokenCount']
     return lf.LMSamplingResult(
-        [
-
-
-
+        [lf.LMSample(message) for message in messages],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=input_tokens,
+            completion_tokens=output_tokens,
+            total_tokens=input_tokens + output_tokens,
+            estimated_cost=self.estimate_cost(
+                num_input_tokens=input_tokens,
+                num_output_tokens=output_tokens,
            ),
-
-        usage=usage,
-    )
-
-  def _sample_text_generation_model(
-      self, prompt: lf.Message
-  ) -> lf.LMSamplingResult:
-    """Samples a text generation model."""
-    model = _VERTEXAI_MODEL_HUB.get_text_generation_model(self.model)
-    predict_options = dict(
-        temperature=self.sampling_options.temperature,
-        top_k=self.sampling_options.top_k,
-        top_p=self.sampling_options.top_p,
-        max_output_tokens=self.sampling_options.max_tokens,
-        stop_sequences=self.sampling_options.stop,
-    )
-    response = model.predict(prompt.text, **predict_options)
-    return lf.LMSamplingResult([
-        # Scoring is not supported.
-        lf.LMSample(lf.AIMessage(response.text), score=0.0)
-    ])
-
-  def _sample_endpoint_model(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    """Samples a text generation model."""
-    assert aiplatform_models is not None
-    model = aiplatform_models.Endpoint(self.endpoint_name)
-    # TODO(chengrun): Add support for stop_sequences.
-    predict_options = dict(
-        temperature=self.sampling_options.temperature
-        if self.sampling_options.temperature is not None
-        else 1.0,
-        top_k=self.sampling_options.top_k
-        if self.sampling_options.top_k is not None
-        else 32,
-        top_p=self.sampling_options.top_p
-        if self.sampling_options.top_p is not None
-        else 1,
-        max_tokens=self.sampling_options.max_tokens
-        if self.sampling_options.max_tokens is not None
-        else 8192,
+        ),
     )
-    instances = [{'prompt': prompt.text, **predict_options}]
-    response = model.predict(instances=instances)
-
-    return lf.LMSamplingResult([
-        # Scoring is not supported.
-        lf.LMSample(
-            self._generation_endpoint_response_to_message(response), score=0.0
-        )
-    ])
-
-
-class _ModelHub:
-  """Vertex AI model hub."""
-
-  def __init__(self):
-    self._generative_model_cache = {}
-    self._text_generation_model_cache = {}
-
-  def get_generative_model(
-      self, model_id: str
-  ) -> Any:  # generative_models.GenerativeModel:
-    """Gets a generative model by model id."""
-    model = self._generative_model_cache.get(model_id, None)
-    if model is None:
-      assert generative_models is not None
-      model = generative_models.GenerativeModel(model_id)
-      self._generative_model_cache[model_id] = model
-    return model
 
-  def
-      self,
-  ) ->
-    """
-
-
-
-
-
-
-
-
-_VERTEXAI_MODEL_HUB = _ModelHub()
+  def _message_from_content_parts(
+      self, parts: list[dict[str, Any]]
+  ) -> lf.Message:
+    """Converts Vertex AI's content parts protocol to message."""
+    chunks = []
+    for part in parts:
+      if text_part := part.get('text'):
+        chunks.append(text_part)
+      else:
+        raise ValueError(f'Unsupported part: {part}')
+    return lf.AIMessage.from_chunks(chunks)
 
 
 _IMAGE_TYPES = [
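Note: `result()` performs the inverse mapping on the response JSON: each candidate's `content.parts` becomes an `lf.AIMessage`, and `usageMetadata` feeds the usage and cost accounting. A fabricated response in the shape it consumes:

```python
# Fabricated example of the response JSON that result() expects.
response_json = {
    'candidates': [
        {'content': {'role': 'model', 'parts': [{'text': 'Hello back!'}]}},
    ],
    'usageMetadata': {'promptTokenCount': 4, 'candidatesTokenCount': 3},
}
# result() would yield one lf.LMSample with text 'Hello back!' and usage of
# prompt_tokens=4, completion_tokens=3, total_tokens=7 -- the total is the
# sum of the two counts, not a field read from the response.
```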
@@ -535,12 +396,6 @@ class VertexAIGemini1_5(VertexAI):  # pylint: disable=invalid-name
 )
 
 
-class VertexAIGeminiPro1_5_Latest(VertexAIGemini1_5):  # pylint: disable=invalid-name
-  """Vertex AI Gemini 1.5 Pro model."""
-
-  model = 'gemini-1.5-pro-latest'
-
-
 class VertexAIGeminiPro1_5(VertexAIGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""
 
@@ -571,12 +426,6 @@ class VertexAIGeminiPro1_5_0409(VertexAIGemini1_5):  # pylint: disable=invalid-name
   model = 'gemini-1.5-pro-preview-0409'
 
 
-class VertexAIGeminiFlash1_5_Latest(VertexAIGemini1_5):  # pylint: disable=invalid-name
-  """Vertex AI Gemini 1.5 Flash model."""
-
-  model = 'gemini-1.5-flash-latest'
-
-
 class VertexAIGeminiFlash1_5(VertexAIGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""
 
@@ -608,7 +457,7 @@ class VertexAIGeminiPro1(VertexAI):  # pylint: disable=invalid-name
 
 
 class VertexAIGeminiPro1Vision(VertexAI):  # pylint: disable=invalid-name
-  """Vertex AI Gemini 1.0 Pro model."""
+  """Vertex AI Gemini 1.0 Pro Vision model."""
 
   model = 'gemini-1.0-pro-vision'
   supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
@@ -616,19 +465,17 @@ class VertexAIGeminiPro1Vision(VertexAI):  # pylint: disable=invalid-name
 )
 
 
-class VertexAIPalm2(VertexAI):  # pylint: disable=invalid-name
-  """Vertex AI PaLM2 text generation model."""
-
-  model = 'text-bison'
-
+class VertexAIEndpoint(VertexAI):  # pylint: disable=invalid-name
+  """Vertex AI Endpoint model."""
 
-class VertexAIPalm2_32K(VertexAI):  # pylint: disable=invalid-name
-  """Vertex AI PaLM2 text generation model (32K context length)."""
+  model = 'vertexai-endpoint'
 
-  model = 'text-bison-32k'
+  endpoint: Annotated[str, 'Vertex AI Endpoint ID.']
 
-
-
-
-
-
+  @property
+  def api_endpoint(self) -> str:
+    return (
+        f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
+        f'{self.project}/locations/{self.location}/'
+        f'endpoints/{self.endpoint}:generateContent'
+    )