langfun 0.1.2.dev202412010804__py3-none-any.whl → 0.1.2.dev202412030804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,30 +13,28 @@
 # limitations under the License.
 """Vertex AI generative models."""

+import base64
 import functools
 import os
 from typing import Annotated, Any

 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
 import pyglove as pg

 try:
   # pylint: disable=g-import-not-at-top
+  from google import auth as google_auth
   from google.auth import credentials as credentials_lib
-  import vertexai
-  from google.cloud.aiplatform import models as aiplatform_models
-  from vertexai import generative_models
-  from vertexai import language_models
+  from google.auth.transport import requests as auth_requests
   # pylint: enable=g-import-not-at-top

   Credentials = credentials_lib.Credentials
 except ImportError:
-  credentials_lib = None  # pylint: disable=invalid-name
-  vertexai = None
-  generative_models = None
-  language_models = None
-  aiplatform_models = None
+  google_auth = None
+  credentials_lib = None
+  auth_requests = None
   Credentials = Any

@@ -50,122 +48,86 @@ AVGERAGE_CHARS_PER_TOEKN = 4
 # as of 2024-10-10.
 SUPPORTED_MODELS_AND_SETTINGS = {
     'gemini-1.5-pro-001': pg.Dict(
-        api='gemini',
         rpm=100,
         cost_per_1k_input_chars=0.0003125,
         cost_per_1k_output_chars=0.00125,
     ),
     'gemini-1.5-pro-002': pg.Dict(
-        api='gemini',
         rpm=100,
         cost_per_1k_input_chars=0.0003125,
         cost_per_1k_output_chars=0.00125,
     ),
     'gemini-1.5-flash-002': pg.Dict(
-        api='gemini',
         rpm=500,
         cost_per_1k_input_chars=0.00001875,
         cost_per_1k_output_chars=0.000075,
     ),
     'gemini-1.5-flash-001': pg.Dict(
-        api='gemini',
         rpm=500,
         cost_per_1k_input_chars=0.00001875,
         cost_per_1k_output_chars=0.000075,
     ),
     'gemini-1.5-pro': pg.Dict(
-        api='gemini',
         rpm=100,
         cost_per_1k_input_chars=0.0003125,
         cost_per_1k_output_chars=0.00125,
     ),
     'gemini-1.5-flash': pg.Dict(
-        api='gemini',
-        rpm=500,
-        cost_per_1k_input_chars=0.00001875,
-        cost_per_1k_output_chars=0.000075,
-    ),
-    'gemini-1.5-pro-latest': pg.Dict(
-        api='gemini',
-        rpm=100,
-        cost_per_1k_input_chars=0.0003125,
-        cost_per_1k_output_chars=0.00125,
-    ),
-    'gemini-1.5-flash-latest': pg.Dict(
-        api='gemini',
         rpm=500,
         cost_per_1k_input_chars=0.00001875,
         cost_per_1k_output_chars=0.000075,
     ),
     'gemini-1.5-pro-preview-0514': pg.Dict(
-        api='gemini',
         rpm=50,
         cost_per_1k_input_chars=0.0003125,
         cost_per_1k_output_chars=0.00125,
     ),
     'gemini-1.5-pro-preview-0409': pg.Dict(
-        api='gemini',
         rpm=50,
         cost_per_1k_input_chars=0.0003125,
         cost_per_1k_output_chars=0.00125,
     ),
     'gemini-1.5-flash-preview-0514': pg.Dict(
-        api='gemini',
         rpm=200,
         cost_per_1k_input_chars=0.00001875,
         cost_per_1k_output_chars=0.000075,
     ),
     'gemini-1.0-pro': pg.Dict(
-        api='gemini',
         rpm=300,
         cost_per_1k_input_chars=0.000125,
         cost_per_1k_output_chars=0.000375,
     ),
     'gemini-1.0-pro-vision': pg.Dict(
-        api='gemini',
         rpm=100,
         cost_per_1k_input_chars=0.000125,
         cost_per_1k_output_chars=0.000375,
     ),
-    # PaLM APIs.
-    'text-bison': pg.Dict(
-        api='palm',
-        rpm=1600
-    ),
-    'text-bison-32k': pg.Dict(
-        api='palm',
-        rpm=300
-    ),
-    'text-unicorn': pg.Dict(
-        api='palm',
-        rpm=100
-    ),
-    # Endpoint
     # TODO(chengrun): Set a more appropriate rpm for endpoint.
-    'custom': pg.Dict(api='endpoint', rpm=20),
+    'vertexai-endpoint': pg.Dict(
+        rpm=20,
+        cost_per_1k_input_chars=0.0000125,
+        cost_per_1k_output_chars=0.0000375,
+    ),
 }


 @lf.use_init_args(['model'])
-class VertexAI(lf.LanguageModel):
-  """Language model served on VertexAI."""
+@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
+class VertexAI(rest.REST):
+  """Language model served on VertexAI with REST API."""

   model: pg.typing.Annotated[
       pg.typing.Enum(
           pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
       ),
       (
-          'Vertex AI model name. See '
-          'https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models '
-          'for details.'
+          'Vertex AI model name with REST API support. See '
+          'https://cloud.google.com/vertex-ai/generative-ai/docs/'
+          'model-reference/inference#supported-models'
+          ' for details.'
       ),
   ]

-  endpoint_name: pg.typing.Annotated[
-      str | None,
-      'Vertex Endpoint name or ID.',
-  ]
-
   project: Annotated[
       str | None,
       (
@@ -197,14 +159,14 @@ class VertexAI(lf.LanguageModel):

   def _on_bound(self):
     super()._on_bound()
-    self.__dict__.pop('_api_initialized', None)
-    if generative_models is None:
-      raise RuntimeError(
+    if google_auth is None:
+      raise ValueError(
          'Please install "langfun[llm-google-vertex]" to use Vertex AI models.'
       )
+    self._project = None
+    self._credentials = None

-  @functools.cached_property
-  def _api_initialized(self):
+  def _initialize(self):
     project = self.project or os.environ.get('VERTEXAI_PROJECT', None)
     if not project:
       raise ValueError(
@@ -219,21 +181,14 @@ class VertexAI(lf.LanguageModel):
          'variable `VERTEXAI_LOCATION` with your Vertex AI service location.'
       )

+    self._project = project
     credentials = self.credentials
-    # Placeholder for Google-internal credentials.
-    assert vertexai is not None
-    vertexai.init(project=project, location=location, credentials=credentials)
-    return True
-
-  @property
-  def model_id(self) -> str:
-    """Returns a string to identify the model."""
-    return f'VertexAI({self.model})'
-
-  @property
-  def resource_id(self) -> str:
-    """Returns a string to identify the resource for rate control."""
-    return self.model_id
+    if credentials is None:
+      # Use default credentials.
+      credentials = google_auth.default(
+          scopes=['https://www.googleapis.com/auth/cloud-platform']
+      )
+    self._credentials = credentials

   @property
   def max_concurrency(self) -> int:
@@ -262,47 +217,75 @@ class VertexAI(lf.LanguageModel):
         + cost_per_1k_output_chars * num_output_tokens
     ) * AVGERAGE_CHARS_PER_TOEKN / 1000

+  @functools.cached_property
+  def _session(self):
+    assert self._api_initialized
+    assert self._credentials is not None
+    assert auth_requests is not None
+    s = auth_requests.AuthorizedSession(self._credentials)
+    s.headers.update(self.headers or {})
+    return s
+
+  @property
+  def headers(self):
+    return {
+        'Content-Type': 'application/json; charset=utf-8',
+    }
+
+  @property
+  def api_endpoint(self) -> str:
+    return (
+        f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
+        f'{self.project}/locations/{self.location}/publishers/google/'
+        f'models/{self.model}:generateContent'
+    )
+
+  def request(
+      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    request = dict(
+        generationConfig=self._generation_config(prompt, sampling_options)
+    )
+    request['contents'] = [self._content_from_message(prompt)]
+    return request
+
   def _generation_config(
       self, prompt: lf.Message, options: lf.LMSamplingOptions
-  ) -> Any:  # generative_models.GenerationConfig
-    """Creates generation config from langfun sampling options."""
-    assert generative_models is not None
-    # Users could use `metadata_json_schema` to pass additional
-    # request arguments.
-    json_schema = prompt.metadata.get('json_schema')
-    response_mime_type = None
-    if json_schema is not None:
+  ) -> dict[str, Any]:
+    """Returns a dict as generation config for prompt and LMSamplingOptions."""
+    config = dict(
+        temperature=options.temperature,
+        maxOutputTokens=options.max_tokens,
+        candidateCount=options.n,
+        topK=options.top_k,
+        topP=options.top_p,
+        stopSequences=options.stop,
+        seed=options.random_seed,
+        responseLogprobs=options.logprobs,
+        logprobs=options.top_logprobs,
+    )
+
+    if json_schema := prompt.metadata.get('json_schema'):
       if not isinstance(json_schema, dict):
         raise ValueError(
            f'`json_schema` must be a dict, got {json_schema!r}.'
         )
-      response_mime_type = 'application/json'
+      json_schema = pg.to_json(json_schema)
+      config['responseSchema'] = json_schema
+      config['responseMimeType'] = 'application/json'
       prompt.metadata.formatted_text = (
          prompt.text
          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
          + pg.to_json_str(json_schema, json_indent=2)
       )
+    return config

-    return generative_models.GenerationConfig(
-        temperature=options.temperature,
-        top_p=options.top_p,
-        top_k=options.top_k,
-        max_output_tokens=options.max_tokens,
-        stop_sequences=options.stop,
-        response_mime_type=response_mime_type,
-        response_schema=json_schema,
-    )
-
-  def _content_from_message(
-      self, prompt: lf.Message
-  ) -> list[str | Any]:
-    """Gets generation input from langfun message."""
-    assert generative_models is not None
-    chunks = []
-
+  def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
+    """Gets generation content from langfun message."""
+    parts = []
     for lf_chunk in prompt.chunk():
       if isinstance(lf_chunk, str):
-        chunks.append(lf_chunk)
+        parts.append({'text': lf_chunk})
       elif isinstance(lf_chunk, lf_modalities.Mime):
         try:
          modalities = lf_chunk.make_compatible(
@@ -312,174 +295,52 @@ class VertexAI(lf.LanguageModel):
            modalities = [modalities]
          for modality in modalities:
            if modality.is_text:
-              chunk = modality.to_text()
+              parts.append({'text': modality.to_text()})
            else:
-              chunk = generative_models.Part.from_data(
-                  modality.to_bytes(), modality.mime_type
-              )
-            chunks.append(chunk)
+              parts.append({
+                  'inlineData': {
+                      'data': base64.b64encode(modality.to_bytes()).decode(),
+                      'mimeType': modality.mime_type,
+                  }
+              })
        except lf.ModalityError as e:
          raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
       else:
        raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
-    return chunks
-
-  def _generation_response_to_message(
-      self,
-      response: Any,  # generative_models.GenerationResponse
-  ) -> lf.Message:
-    """Parses generative response into message."""
-    return lf.AIMessage(response.text)
+    return dict(role='user', parts=parts)

-  def _generation_endpoint_response_to_message(
-      self,
-      response: Any,  # google.cloud.aiplatform.aiplatform.models.Prediction
-  ) -> lf.Message:
-    """Parses Endpoint response into message."""
-    return lf.AIMessage(response.predictions[0])
-
-  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-    assert self._api_initialized, 'Vertex AI API is not initialized.'
-    # TODO(yifenglu): It seems this exception is due to the instability of the
-    # API. We should revisit this later.
-    retry_on_errors = [
-        (Exception, 'InternalServerError'),
-        (Exception, 'ResourceExhausted'),
-        (Exception, '_InactiveRpcError'),
-        (Exception, 'ValueError'),
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    messages = [
+        self._message_from_content_parts(candidate['content']['parts'])
+        for candidate in json['candidates']
     ]
-
-    return self._parallel_execute_with_currency_control(
-        self._sample_single,
-        prompts,
-        retry_on_errors=retry_on_errors,
-    )
-
-  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    if self.sampling_options.n > 1:
-      raise ValueError(
-          f'`n` greater than 1 is not supported: {self.sampling_options.n}.'
-      )
-    api = SUPPORTED_MODELS_AND_SETTINGS[self.model].api
-    match api:
-      case 'gemini':
-        return self._sample_generative_model(prompt)
-      case 'palm':
-        return self._sample_text_generation_model(prompt)
-      case 'endpoint':
-        return self._sample_endpoint_model(prompt)
-      case _:
-        raise ValueError(f'Unsupported API: {api}')
-
-  def _sample_generative_model(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    """Samples a generative model."""
-    model = _VERTEXAI_MODEL_HUB.get_generative_model(self.model)
-    input_content = self._content_from_message(prompt)
-    response = model.generate_content(
-        input_content,
-        generation_config=self._generation_config(
-            prompt, self.sampling_options
-        ),
-    )
-    usage_metadata = response.usage_metadata
-    usage = lf.LMSamplingUsage(
-        prompt_tokens=usage_metadata.prompt_token_count,
-        completion_tokens=usage_metadata.candidates_token_count,
-        total_tokens=usage_metadata.total_token_count,
-        estimated_cost=self.estimate_cost(
-            num_input_tokens=usage_metadata.prompt_token_count,
-            num_output_tokens=usage_metadata.candidates_token_count,
-        ),
-    )
+    usage = json['usageMetadata']
+    input_tokens = usage['promptTokenCount']
+    output_tokens = usage['candidatesTokenCount']
     return lf.LMSamplingResult(
-        [
-            # Scoring is not supported.
-            lf.LMSample(
-                self._generation_response_to_message(response), score=0.0
+        [lf.LMSample(message) for message in messages],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=input_tokens,
+            completion_tokens=output_tokens,
+            total_tokens=input_tokens + output_tokens,
+            estimated_cost=self.estimate_cost(
+                num_input_tokens=input_tokens,
+                num_output_tokens=output_tokens,
            ),
-        ],
-        usage=usage,
-    )
-
-  def _sample_text_generation_model(
-      self, prompt: lf.Message
-  ) -> lf.LMSamplingResult:
-    """Samples a text generation model."""
-    model = _VERTEXAI_MODEL_HUB.get_text_generation_model(self.model)
-    predict_options = dict(
-        temperature=self.sampling_options.temperature,
-        top_k=self.sampling_options.top_k,
-        top_p=self.sampling_options.top_p,
-        max_output_tokens=self.sampling_options.max_tokens,
-        stop_sequences=self.sampling_options.stop,
-    )
-    response = model.predict(prompt.text, **predict_options)
-    return lf.LMSamplingResult([
-        # Scoring is not supported.
-        lf.LMSample(lf.AIMessage(response.text), score=0.0)
-    ])
-
-  def _sample_endpoint_model(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    """Samples a text generation model."""
-    assert aiplatform_models is not None
-    model = aiplatform_models.Endpoint(self.endpoint_name)
-    # TODO(chengrun): Add support for stop_sequences.
-    predict_options = dict(
-        temperature=self.sampling_options.temperature
-        if self.sampling_options.temperature is not None
-        else 1.0,
-        top_k=self.sampling_options.top_k
-        if self.sampling_options.top_k is not None
-        else 32,
-        top_p=self.sampling_options.top_p
-        if self.sampling_options.top_p is not None
-        else 1,
-        max_tokens=self.sampling_options.max_tokens
-        if self.sampling_options.max_tokens is not None
-        else 8192,
+        ),
     )
-    instances = [{'prompt': prompt.text, **predict_options}]
-    response = model.predict(instances=instances)
-
-    return lf.LMSamplingResult([
-        # Scoring is not supported.
-        lf.LMSample(
-            self._generation_endpoint_response_to_message(response), score=0.0
-        )
-    ])
-
-
-class _ModelHub:
-  """Vertex AI model hub."""
-
-  def __init__(self):
-    self._generative_model_cache = {}
-    self._text_generation_model_cache = {}
-
-  def get_generative_model(
-      self, model_id: str
-  ) -> Any:  # generative_models.GenerativeModel:
-    """Gets a generative model by model id."""
-    model = self._generative_model_cache.get(model_id, None)
-    if model is None:
-      assert generative_models is not None
-      model = generative_models.GenerativeModel(model_id)
-      self._generative_model_cache[model_id] = model
-    return model

-  def get_text_generation_model(
-      self, model_id: str
-  ) -> Any:  # language_models.TextGenerationModel
-    """Gets a text generation model by model id."""
-    model = self._text_generation_model_cache.get(model_id, None)
-    if model is None:
-      assert language_models is not None
-      model = language_models.TextGenerationModel.from_pretrained(model_id)
-      self._text_generation_model_cache[model_id] = model
-    return model
-
-
-_VERTEXAI_MODEL_HUB = _ModelHub()
+  def _message_from_content_parts(
+      self, parts: list[dict[str, Any]]
+  ) -> lf.Message:
+    """Converts Vertex AI's content parts protocol to message."""
+    chunks = []
+    for part in parts:
+      if text_part := part.get('text'):
+        chunks.append(text_part)
+      else:
+        raise ValueError(f'Unsupported part: {part}')
+    return lf.AIMessage.from_chunks(chunks)


 _IMAGE_TYPES = [
@@ -535,12 +396,6 @@ class VertexAIGemini1_5(VertexAI):  # pylint: disable=invalid-name
   )


-class VertexAIGeminiPro1_5_Latest(VertexAIGemini1_5):  # pylint: disable=invalid-name
-  """Vertex AI Gemini 1.5 Pro model."""
-
-  model = 'gemini-1.5-pro-latest'
-
-
 class VertexAIGeminiPro1_5(VertexAIGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""

@@ -571,12 +426,6 @@ class VertexAIGeminiPro1_5_0409(VertexAIGemini1_5):  # pylint: disable=invalid-n
   model = 'gemini-1.5-pro-preview-0409'


-class VertexAIGeminiFlash1_5_Latest(VertexAIGemini1_5):  # pylint: disable=invalid-name
-  """Vertex AI Gemini 1.5 Flash model."""
-
-  model = 'gemini-1.5-flash-latest'
-
-
 class VertexAIGeminiFlash1_5(VertexAIGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""

@@ -608,7 +457,7 @@ class VertexAIGeminiPro1(VertexAI):  # pylint: disable=invalid-name


 class VertexAIGeminiPro1Vision(VertexAI):  # pylint: disable=invalid-name
-  """Vertex AI Gemini 1.0 Pro model."""
+  """Vertex AI Gemini 1.0 Pro Vision model."""

   model = 'gemini-1.0-pro-vision'
   supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
@@ -616,19 +465,17 @@ class VertexAIGeminiPro1Vision(VertexAI):  # pylint: disable=invalid-name
   )


-class VertexAIPalm2(VertexAI):  # pylint: disable=invalid-name
-  """Vertex AI PaLM2 text generation model."""
-
-  model = 'text-bison'
-
+class VertexAIEndpoint(VertexAI):  # pylint: disable=invalid-name
+  """Vertex AI Endpoint model."""

-class VertexAIPalm2_32K(VertexAI):  # pylint: disable=invalid-name
-  """Vertex AI PaLM2 text generation model (32K context length)."""
+  model = 'vertexai-endpoint'

-  model = 'text-bison-32k'
+  endpoint: Annotated[str, 'Vertex AI Endpoint ID.']

-
-class VertexAICustom(VertexAI):  # pylint: disable=invalid-name
-  """Vertex AI Custom model (Endpoint)."""
-
-  model = 'custom'
+  @property
+  def api_endpoint(self) -> str:
+    return (
+        f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
+        f'{self.project}/locations/{self.location}/'
+        f'endpoints/{self.endpoint}:generateContent'
+    )
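
Net effect of this diff: `VertexAI` drops the `vertexai` SDK and PaLM code paths and now subclasses `rest.REST`, issuing `generateContent` REST calls through an authorized `requests` session. As a minimal sketch (not code from the package), the JSON body that the new `request()` method assembles for a text-only prompt looks roughly like the following; the field names mirror `_generation_config()` and `_content_from_message()` above, while the prompt text and sampling values are invented for illustration, and fields such as topK, topP, stopSequences, seed, and the logprobs settings are omitted here but populated the same way:

    # Illustrative sketch of the request body; values are made up.
    payload = {
        'generationConfig': {
            'temperature': 0.7,       # from options.temperature
            'maxOutputTokens': 1024,  # from options.max_tokens
            'candidateCount': 1,      # from options.n
        },
        'contents': [{
            'role': 'user',
            'parts': [{'text': 'Hello, Gemini.'}],
        }],
    }
    # POSTed to the api_endpoint property, i.e.
    # https://{location}-aiplatform.googleapis.com/v1/projects/{project}/
    #   locations/{location}/publishers/google/models/{model}:generateContent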