langfun 0.1.2.dev202412020805__py3-none-any.whl → 0.1.2.dev202412030804__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
@@ -28,21 +28,13 @@ try:
   from google import auth as google_auth
   from google.auth import credentials as credentials_lib
   from google.auth.transport import requests as auth_requests
-  import vertexai
-  from google.cloud.aiplatform import models as aiplatform_models
-  from vertexai import generative_models
-  from vertexai import language_models
   # pylint: enable=g-import-not-at-top
 
   Credentials = credentials_lib.Credentials
 except ImportError:
   google_auth = None
+  credentials_lib = None
   auth_requests = None
-  credentials_lib = None  # pylint: disable=invalid-name
-  vertexai = None
-  generative_models = None
-  language_models = None
-  aiplatform_models = None
   Credentials = Any
 
 
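Note: after this change only the google-auth imports remain optional; the `vertexai` SDK and `aiplatform` imports are dropped entirely. A minimal sketch of how the guard reads after the change, assembled from the context and `+` lines above (the enclosing `try:` and the matching `pylint: disable` line sit just outside this hunk):

```python
from typing import Any

try:
  # pylint: disable=g-import-not-at-top
  from google import auth as google_auth
  from google.auth import credentials as credentials_lib
  from google.auth.transport import requests as auth_requests
  # pylint: enable=g-import-not-at-top

  Credentials = credentials_lib.Credentials
except ImportError:
  # google-auth is optional; fall back to sentinels so the module still loads.
  google_auth = None
  credentials_lib = None
  auth_requests = None
  Credentials = Any
```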
@@ -56,408 +48,72 @@ AVGERAGE_CHARS_PER_TOEKN = 4
 # as of 2024-10-10.
 SUPPORTED_MODELS_AND_SETTINGS = {
     'gemini-1.5-pro-001': pg.Dict(
-        api='gemini',
         rpm=100,
         cost_per_1k_input_chars=0.0003125,
         cost_per_1k_output_chars=0.00125,
     ),
     'gemini-1.5-pro-002': pg.Dict(
-        api='gemini',
         rpm=100,
         cost_per_1k_input_chars=0.0003125,
         cost_per_1k_output_chars=0.00125,
     ),
     'gemini-1.5-flash-002': pg.Dict(
-        api='gemini',
         rpm=500,
         cost_per_1k_input_chars=0.00001875,
         cost_per_1k_output_chars=0.000075,
     ),
     'gemini-1.5-flash-001': pg.Dict(
-        api='gemini',
         rpm=500,
         cost_per_1k_input_chars=0.00001875,
         cost_per_1k_output_chars=0.000075,
     ),
     'gemini-1.5-pro': pg.Dict(
-        api='gemini',
         rpm=100,
         cost_per_1k_input_chars=0.0003125,
         cost_per_1k_output_chars=0.00125,
     ),
     'gemini-1.5-flash': pg.Dict(
-        api='gemini',
-        rpm=500,
-        cost_per_1k_input_chars=0.00001875,
-        cost_per_1k_output_chars=0.000075,
-    ),
-    'gemini-1.5-pro-latest': pg.Dict(
-        api='gemini',
-        rpm=100,
-        cost_per_1k_input_chars=0.0003125,
-        cost_per_1k_output_chars=0.00125,
-    ),
-    'gemini-1.5-flash-latest': pg.Dict(
-        api='gemini',
         rpm=500,
         cost_per_1k_input_chars=0.00001875,
         cost_per_1k_output_chars=0.000075,
     ),
     'gemini-1.5-pro-preview-0514': pg.Dict(
-        api='gemini',
         rpm=50,
         cost_per_1k_input_chars=0.0003125,
         cost_per_1k_output_chars=0.00125,
     ),
     'gemini-1.5-pro-preview-0409': pg.Dict(
-        api='gemini',
         rpm=50,
         cost_per_1k_input_chars=0.0003125,
         cost_per_1k_output_chars=0.00125,
     ),
     'gemini-1.5-flash-preview-0514': pg.Dict(
-        api='gemini',
         rpm=200,
         cost_per_1k_input_chars=0.00001875,
         cost_per_1k_output_chars=0.000075,
     ),
     'gemini-1.0-pro': pg.Dict(
-        api='gemini',
         rpm=300,
         cost_per_1k_input_chars=0.000125,
         cost_per_1k_output_chars=0.000375,
     ),
     'gemini-1.0-pro-vision': pg.Dict(
-        api='gemini',
         rpm=100,
         cost_per_1k_input_chars=0.000125,
         cost_per_1k_output_chars=0.000375,
     ),
-    # PaLM APIs.
-    'text-bison': pg.Dict(
-        api='palm',
-        rpm=1600
-    ),
-    'text-bison-32k': pg.Dict(
-        api='palm',
-        rpm=300
-    ),
-    'text-unicorn': pg.Dict(
-        api='palm',
-        rpm=100
-    ),
-    # Endpoint
     # TODO(chengrun): Set a more appropriate rpm for endpoint.
-    'custom': pg.Dict(api='endpoint', rpm=20),
+    'vertexai-endpoint': pg.Dict(
+        rpm=20,
+        cost_per_1k_input_chars=0.0000125,
+        cost_per_1k_output_chars=0.0000375,
+    ),
 }
 
 
-@lf.use_init_args(['model'])
-class VertexAI(lf.LanguageModel):
-  """Language model served on VertexAI."""
-
-  model: pg.typing.Annotated[
-      pg.typing.Enum(
-          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
-      ),
-      (
-          'Vertex AI model name. See '
-          'https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models '
-          'for details.'
-      ),
-  ]
-
-  endpoint_name: pg.typing.Annotated[
-      str | None,
-      'Vertex Endpoint name or ID.',
-  ]
-
-  project: Annotated[
-      str | None,
-      (
-          'Vertex AI project ID. Or set from environment variable '
-          'VERTEXAI_PROJECT.'
-      ),
-  ] = None
-
-  location: Annotated[
-      str | None,
-      (
-          'Vertex AI service location. Or set from environment variable '
-          'VERTEXAI_LOCATION.'
-      ),
-  ] = None
-
-  credentials: Annotated[
-      Credentials | None,
-      (
-          'Credentials to use. If None, the default credentials to the '
-          'environment will be used.'
-      ),
-  ] = None
-
-  supported_modalities: Annotated[
-      list[str],
-      'A list of MIME types for supported modalities'
-  ] = []
-
-  def _on_bound(self):
-    super()._on_bound()
-    self.__dict__.pop('_api_initialized', None)
-    if generative_models is None:
-      raise RuntimeError(
-          'Please install "langfun[llm-google-vertex]" to use Vertex AI models.'
-      )
-
-  @functools.cached_property
-  def _api_initialized(self):
-    project = self.project or os.environ.get('VERTEXAI_PROJECT', None)
-    if not project:
-      raise ValueError(
-          'Please specify `project` during `__init__` or set environment '
-          'variable `VERTEXAI_PROJECT` with your Vertex AI project ID.'
-      )
-
-    location = self.location or os.environ.get('VERTEXAI_LOCATION', None)
-    if not location:
-      raise ValueError(
-          'Please specify `location` during `__init__` or set environment '
-          'variable `VERTEXAI_LOCATION` with your Vertex AI service location.'
-      )
-
-    credentials = self.credentials
-    # Placeholder for Google-internal credentials.
-    assert vertexai is not None
-    vertexai.init(project=project, location=location, credentials=credentials)
-    return True
-
-  @property
-  def model_id(self) -> str:
-    """Returns a string to identify the model."""
-    return f'VertexAI({self.model})'
-
-  @property
-  def resource_id(self) -> str:
-    """Returns a string to identify the resource for rate control."""
-    return self.model_id
-
-  @property
-  def max_concurrency(self) -> int:
-    """Returns the maximum number of concurrent requests."""
-    return self.rate_to_max_concurrency(
-        requests_per_min=SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm,
-        tokens_per_min=0,
-    )
-
-  def estimate_cost(
-      self,
-      num_input_tokens: int,
-      num_output_tokens: int
-  ) -> float | None:
-    """Estimate the cost based on usage."""
-    cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
-        'cost_per_1k_input_chars', None
-    )
-    cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
-        'cost_per_1k_output_chars', None
-    )
-    if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None:
-      return None
-    return (
-        cost_per_1k_input_chars * num_input_tokens
-        + cost_per_1k_output_chars * num_output_tokens
-    ) * AVGERAGE_CHARS_PER_TOEKN / 1000
-
-  def _generation_config(
-      self, prompt: lf.Message, options: lf.LMSamplingOptions
-  ) -> Any:  # generative_models.GenerationConfig
-    """Creates generation config from langfun sampling options."""
-    assert generative_models is not None
-    # Users could use `metadata_json_schema` to pass additional
-    # request arguments.
-    json_schema = prompt.metadata.get('json_schema')
-    response_mime_type = None
-    if json_schema is not None:
-      if not isinstance(json_schema, dict):
-        raise ValueError(
-            f'`json_schema` must be a dict, got {json_schema!r}.'
-        )
-      response_mime_type = 'application/json'
-      prompt.metadata.formatted_text = (
-          prompt.text
-          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
-          + pg.to_json_str(json_schema, json_indent=2)
-      )
-
-    return generative_models.GenerationConfig(
-        temperature=options.temperature,
-        top_p=options.top_p,
-        top_k=options.top_k,
-        max_output_tokens=options.max_tokens,
-        stop_sequences=options.stop,
-        response_mime_type=response_mime_type,
-        response_schema=json_schema,
-    )
-
-  def _content_from_message(
-      self, prompt: lf.Message
-  ) -> list[str | Any]:
-    """Gets generation input from langfun message."""
-    assert generative_models is not None
-    chunks = []
-
-    for lf_chunk in prompt.chunk():
-      if isinstance(lf_chunk, str):
-        chunks.append(lf_chunk)
-      elif isinstance(lf_chunk, lf_modalities.Mime):
-        try:
-          modalities = lf_chunk.make_compatible(
-              self.supported_modalities + ['text/plain']
-          )
-          if isinstance(modalities, lf_modalities.Mime):
-            modalities = [modalities]
-          for modality in modalities:
-            if modality.is_text:
-              chunk = modality.to_text()
-            else:
-              chunk = generative_models.Part.from_data(
-                  modality.to_bytes(), modality.mime_type
-              )
-            chunks.append(chunk)
-        except lf.ModalityError as e:
-          raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
-      else:
-        raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
-    return chunks
-
-  def _generation_response_to_message(
-      self,
-      response: Any,  # generative_models.GenerationResponse
-  ) -> lf.Message:
-    """Parses generative response into message."""
-    return lf.AIMessage(response.text)
-
-  def _generation_endpoint_response_to_message(
-      self,
-      response: Any,  # google.cloud.aiplatform.aiplatform.models.Prediction
-  ) -> lf.Message:
-    """Parses Endpoint response into message."""
-    return lf.AIMessage(response.predictions[0])
-
-  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-    assert self._api_initialized, 'Vertex AI API is not initialized.'
-    # TODO(yifenglu): It seems this exception is due to the instability of the
-    # API. We should revisit this later.
-    retry_on_errors = [
-        (Exception, 'InternalServerError'),
-        (Exception, 'ResourceExhausted'),
-        (Exception, '_InactiveRpcError'),
-        (Exception, 'ValueError'),
-    ]
-
-    return self._parallel_execute_with_currency_control(
-        self._sample_single,
-        prompts,
-        retry_on_errors=retry_on_errors,
-    )
-
-  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    if self.sampling_options.n > 1:
-      raise ValueError(
-          f'`n` greater than 1 is not supported: {self.sampling_options.n}.'
-      )
-    api = SUPPORTED_MODELS_AND_SETTINGS[self.model].api
-    match api:
-      case 'gemini':
-        return self._sample_generative_model(prompt)
-      case 'palm':
-        return self._sample_text_generation_model(prompt)
-      case 'endpoint':
-        return self._sample_endpoint_model(prompt)
-      case _:
-        raise ValueError(f'Unsupported API: {api}')
-
-  def _sample_generative_model(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    """Samples a generative model."""
-    model = _VERTEXAI_MODEL_HUB.get_generative_model(self.model)
-    input_content = self._content_from_message(prompt)
-    response = model.generate_content(
-        input_content,
-        generation_config=self._generation_config(
-            prompt, self.sampling_options
-        ),
-    )
-    usage_metadata = response.usage_metadata
-    usage = lf.LMSamplingUsage(
-        prompt_tokens=usage_metadata.prompt_token_count,
-        completion_tokens=usage_metadata.candidates_token_count,
-        total_tokens=usage_metadata.total_token_count,
-        estimated_cost=self.estimate_cost(
-            num_input_tokens=usage_metadata.prompt_token_count,
-            num_output_tokens=usage_metadata.candidates_token_count,
-        ),
-    )
-    return lf.LMSamplingResult(
-        [
-            # Scoring is not supported.
-            lf.LMSample(
-                self._generation_response_to_message(response), score=0.0
-            ),
-        ],
-        usage=usage,
-    )
-
-  def _sample_text_generation_model(
-      self, prompt: lf.Message
-  ) -> lf.LMSamplingResult:
-    """Samples a text generation model."""
-    model = _VERTEXAI_MODEL_HUB.get_text_generation_model(self.model)
-    predict_options = dict(
-        temperature=self.sampling_options.temperature,
-        top_k=self.sampling_options.top_k,
-        top_p=self.sampling_options.top_p,
-        max_output_tokens=self.sampling_options.max_tokens,
-        stop_sequences=self.sampling_options.stop,
-    )
-    response = model.predict(prompt.text, **predict_options)
-    return lf.LMSamplingResult([
-        # Scoring is not supported.
-        lf.LMSample(lf.AIMessage(response.text), score=0.0)
-    ])
-
-  def _sample_endpoint_model(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    """Samples a text generation model."""
-    assert aiplatform_models is not None
-    model = aiplatform_models.Endpoint(self.endpoint_name)
-    # TODO(chengrun): Add support for stop_sequences.
-    predict_options = dict(
-        temperature=self.sampling_options.temperature
-        if self.sampling_options.temperature is not None
-        else 1.0,
-        top_k=self.sampling_options.top_k
-        if self.sampling_options.top_k is not None
-        else 32,
-        top_p=self.sampling_options.top_p
-        if self.sampling_options.top_p is not None
-        else 1,
-        max_tokens=self.sampling_options.max_tokens
-        if self.sampling_options.max_tokens is not None
-        else 8192,
-    )
-    instances = [{'prompt': prompt.text, **predict_options}]
-    response = model.predict(instances=instances)
-
-    return lf.LMSamplingResult([
-        # Scoring is not supported.
-        lf.LMSample(
-            self._generation_endpoint_response_to_message(response), score=0.0
-        )
-    ])
-
-
 @lf.use_init_args(['model'])
 @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
-class VertexAIRest(rest.REST):
+class VertexAI(rest.REST):
   """Language model served on VertexAI with REST API."""
 
   model: pg.typing.Annotated[
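The removed `estimate_cost` method pairs the per-1k-character rates in `SUPPORTED_MODELS_AND_SETTINGS` with the module constant `AVGERAGE_CHARS_PER_TOEKN = 4` (the misspelling is verbatim from the source). A self-contained sketch of that formula using the `gemini-1.5-pro` rates; the standalone function and its defaults are illustrative, not part of the package:

```python
AVGERAGE_CHARS_PER_TOEKN = 4  # spelling verbatim from the source


def estimate_cost(
    num_input_tokens: int,
    num_output_tokens: int,
    cost_per_1k_input_chars: float = 0.0003125,  # gemini-1.5-pro input rate
    cost_per_1k_output_chars: float = 0.00125,   # gemini-1.5-pro output rate
) -> float:
  # Tokens are converted to approximate characters (x4), then billed per 1k.
  return (
      cost_per_1k_input_chars * num_input_tokens
      + cost_per_1k_output_chars * num_output_tokens
  ) * AVGERAGE_CHARS_PER_TOEKN / 1000


# 10k input tokens and 1k output tokens:
# (0.0003125 * 10000 + 0.00125 * 1000) * 4 / 1000 = 0.0175 USD
print(estimate_cost(10_000, 1_000))  # 0.0175
```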
@@ -687,39 +343,6 @@ class VertexAIRest(rest.REST):
     return lf.AIMessage.from_chunks(chunks)
 
 
-class _ModelHub:
-  """Vertex AI model hub."""
-
-  def __init__(self):
-    self._generative_model_cache = {}
-    self._text_generation_model_cache = {}
-
-  def get_generative_model(
-      self, model_id: str
-  ) -> Any:  # generative_models.GenerativeModel:
-    """Gets a generative model by model id."""
-    model = self._generative_model_cache.get(model_id, None)
-    if model is None:
-      assert generative_models is not None
-      model = generative_models.GenerativeModel(model_id)
-      self._generative_model_cache[model_id] = model
-    return model
-
-  def get_text_generation_model(
-      self, model_id: str
-  ) -> Any:  # language_models.TextGenerationModel
-    """Gets a text generation model by model id."""
-    model = self._text_generation_model_cache.get(model_id, None)
-    if model is None:
-      assert language_models is not None
-      model = language_models.TextGenerationModel.from_pretrained(model_id)
-      self._text_generation_model_cache[model_id] = model
-    return model
-
-
-_VERTEXAI_MODEL_HUB = _ModelHub()
-
-
 _IMAGE_TYPES = [
     'image/png',
     'image/jpeg',
@@ -773,33 +396,19 @@ class VertexAIGemini1_5(VertexAI):  # pylint: disable=invalid-name
   )
 
 
-class VertexAIGeminiPro1_5_Latest(VertexAIGemini1_5):  # pylint: disable=invalid-name
-  """Vertex AI Gemini 1.5 Pro model."""
-
-  model = 'gemini-1.5-pro-latest'
-
-
 class VertexAIGeminiPro1_5(VertexAIGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""
 
   model = 'gemini-1.5-pro'
 
 
-class VertexAIRestGemini1_5(VertexAIRest):  # pylint: disable=invalid-name
-  """Vertex AI Gemini 1.5 model with REST API."""
-
-  supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
-      _DOCUMENT_TYPES + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES
-  )
-
-
-class VertexAIGeminiPro1_5_002(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
+class VertexAIGeminiPro1_5_002(VertexAIGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""
 
   model = 'gemini-1.5-pro-002'
 
 
-class VertexAIGeminiPro1_5_001(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
+class VertexAIGeminiPro1_5_001(VertexAIGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""
 
   model = 'gemini-1.5-pro-001'
@@ -817,25 +426,19 @@ class VertexAIGeminiPro1_5_0409(VertexAIGemini1_5):  # pylint: disable=invalid-name
   model = 'gemini-1.5-pro-preview-0409'
 
 
-class VertexAIGeminiFlash1_5_Latest(VertexAIGemini1_5):  # pylint: disable=invalid-name
-  """Vertex AI Gemini 1.5 Flash model."""
-
-  model = 'gemini-1.5-flash-latest'
-
-
 class VertexAIGeminiFlash1_5(VertexAIGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""
 
   model = 'gemini-1.5-flash'
 
 
-class VertexAIGeminiFlash1_5_002(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
+class VertexAIGeminiFlash1_5_002(VertexAIGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""
 
   model = 'gemini-1.5-flash-002'
 
 
-class VertexAIGeminiFlash1_5_001(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
+class VertexAIGeminiFlash1_5_001(VertexAIGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""
 
   model = 'gemini-1.5-flash-001'
@@ -847,7 +450,7 @@ class VertexAIGeminiFlash1_5_0514(VertexAIGemini1_5):  # pylint: disable=invalid-name
   model = 'gemini-1.5-flash-preview-0514'
 
 
-class VertexAIGeminiPro1(VertexAIRest):  # pylint: disable=invalid-name
+class VertexAIGeminiPro1(VertexAI):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.0 Pro model."""
 
   model = 'gemini-1.0-pro'
@@ -862,19 +465,17 @@ class VertexAIGeminiPro1Vision(VertexAI):  # pylint: disable=invalid-name
   )
 
 
-class VertexAIPalm2(VertexAI):  # pylint: disable=invalid-name
-  """Vertex AI PaLM2 text generation model."""
-
-  model = 'text-bison'
-
-
-class VertexAIPalm2_32K(VertexAI):  # pylint: disable=invalid-name
-  """Vertex AI PaLM2 text generation model (32K context length)."""
+class VertexAIEndpoint(VertexAI):  # pylint: disable=invalid-name
+  """Vertex AI Endpoint model."""
 
-  model = 'text-bison-32k'
+  model = 'vertexai-endpoint'
 
+  endpoint: Annotated[str, 'Vertex AI Endpoint ID.']
 
-class VertexAICustom(VertexAI):  # pylint: disable=invalid-name
-  """Vertex AI Custom model (Endpoint)."""
-
-  model = 'custom'
+  @property
+  def api_endpoint(self) -> str:
+    return (
+        f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
+        f'{self.project}/locations/{self.location}/'
+        f'endpoints/{self.endpoint}:generateContent'
+    )
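The new `VertexAIEndpoint` class replaces `VertexAICustom`: the endpoint ID becomes a first-class field, and the `generateContent` URL is derived directly rather than routed through the aiplatform SDK. A hypothetical usage sketch, assuming `project` and `location` are fields inherited from the `VertexAI` base class as the `api_endpoint` property implies; all values are placeholders:

```python
# All values below are placeholders; only the URL template comes from the diff.
lm = VertexAIEndpoint(
    endpoint='1234567890',    # Vertex AI Endpoint ID
    project='my-project',     # GCP project (assumed base-class field)
    location='us-central1',   # region (assumed base-class field)
)

# The property assembles the REST URL from the + lines above:
#   https://us-central1-aiplatform.googleapis.com/v1/projects/my-project
#   /locations/us-central1/endpoints/1234567890:generateContent
print(lm.api_endpoint)
```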