langfun 0.1.2.dev202412010804__py3-none-any.whl → 0.1.2.dev202412030000__py3-none-any.whl

This diff compares the contents of two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
@@ -13,17 +13,21 @@
 # limitations under the License.
 """Vertex AI generative models."""

+import base64
 import functools
 import os
 from typing import Annotated, Any

 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
 import pyglove as pg

 try:
   # pylint: disable=g-import-not-at-top
+  from google import auth as google_auth
   from google.auth import credentials as credentials_lib
+  from google.auth.transport import requests as auth_requests
   import vertexai
   from google.cloud.aiplatform import models as aiplatform_models
   from vertexai import generative_models
@@ -32,6 +36,8 @@ try:

   Credentials = credentials_lib.Credentials
 except ImportError:
+  google_auth = None
+  auth_requests = None
   credentials_lib = None  # pylint: disable=invalid-name
   vertexai = None
   generative_models = None
@@ -449,6 +455,238 @@ class VertexAI(lf.LanguageModel):
 ])


+@lf.use_init_args(['model'])
+@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
+class VertexAIRest(rest.REST):
+  """Language model served on VertexAI with REST API."""
+
+  model: pg.typing.Annotated[
+      pg.typing.Enum(
+          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+      ),
+      (
+          'Vertex AI model name with REST API support. See '
+          'https://cloud.google.com/vertex-ai/generative-ai/docs/'
+          'model-reference/inference#supported-models'
+          ' for details.'
+      ),
+  ]
+
+  project: Annotated[
+      str | None,
+      (
+          'Vertex AI project ID. Or set from environment variable '
+          'VERTEXAI_PROJECT.'
+      ),
+  ] = None
+
+  location: Annotated[
+      str | None,
+      (
+          'Vertex AI service location. Or set from environment variable '
+          'VERTEXAI_LOCATION.'
+      ),
+  ] = None
+
+  credentials: Annotated[
+      Credentials | None,
+      (
+          'Credentials to use. If None, the default credentials to the '
+          'environment will be used.'
+      ),
+  ] = None
+
+  supported_modalities: Annotated[
+      list[str],
+      'A list of MIME types for supported modalities'
+  ] = []
+
+  def _on_bound(self):
+    super()._on_bound()
+    if google_auth is None:
+      raise ValueError(
+          'Please install "langfun[llm-google-vertex]" to use Vertex AI models.'
+      )
+    self._project = None
+    self._credentials = None
+
+  def _initialize(self):
+    project = self.project or os.environ.get('VERTEXAI_PROJECT', None)
+    if not project:
+      raise ValueError(
+          'Please specify `project` during `__init__` or set environment '
+          'variable `VERTEXAI_PROJECT` with your Vertex AI project ID.'
+      )
+
+    location = self.location or os.environ.get('VERTEXAI_LOCATION', None)
+    if not location:
+      raise ValueError(
+          'Please specify `location` during `__init__` or set environment '
+          'variable `VERTEXAI_LOCATION` with your Vertex AI service location.'
+      )
+
+    self._project = project
+    credentials = self.credentials
+    if credentials is None:
+      # Use default credentials.
+      credentials = google_auth.default(
+          scopes=['https://www.googleapis.com/auth/cloud-platform']
+      )
+    self._credentials = credentials
+
+  @property
+  def max_concurrency(self) -> int:
+    """Returns the maximum number of concurrent requests."""
+    return self.rate_to_max_concurrency(
+        requests_per_min=SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm,
+        tokens_per_min=0,
+    )
+
+  def estimate_cost(
+      self,
+      num_input_tokens: int,
+      num_output_tokens: int
+  ) -> float | None:
+    """Estimate the cost based on usage."""
+    cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+        'cost_per_1k_input_chars', None
+    )
+    cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+        'cost_per_1k_output_chars', None
+    )
+    if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None:
+      return None
+    return (
+        cost_per_1k_input_chars * num_input_tokens
+        + cost_per_1k_output_chars * num_output_tokens
+    ) * AVGERAGE_CHARS_PER_TOEKN / 1000
+
+  @functools.cached_property
+  def _session(self):
+    assert self._api_initialized
+    assert self._credentials is not None
+    assert auth_requests is not None
+    s = auth_requests.AuthorizedSession(self._credentials)
+    s.headers.update(self.headers or {})
+    return s
+
+  @property
+  def headers(self):
+    return {
+        'Content-Type': 'application/json; charset=utf-8',
+    }
+
+  @property
+  def api_endpoint(self) -> str:
+    return (
+        f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
+        f'{self.project}/locations/{self.location}/publishers/google/'
+        f'models/{self.model}:generateContent'
+    )
+
+  def request(
+      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    request = dict(
+        generationConfig=self._generation_config(prompt, sampling_options)
+    )
+    request['contents'] = [self._content_from_message(prompt)]
+    return request
+
+  def _generation_config(
+      self, prompt: lf.Message, options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns a dict as generation config for prompt and LMSamplingOptions."""
+    config = dict(
+        temperature=options.temperature,
+        maxOutputTokens=options.max_tokens,
+        candidateCount=options.n,
+        topK=options.top_k,
+        topP=options.top_p,
+        stopSequences=options.stop,
+        seed=options.random_seed,
+        responseLogprobs=options.logprobs,
+        logprobs=options.top_logprobs,
+    )
+
+    if json_schema := prompt.metadata.get('json_schema'):
+      if not isinstance(json_schema, dict):
+        raise ValueError(
+            f'`json_schema` must be a dict, got {json_schema!r}.'
+        )
+      json_schema = pg.to_json(json_schema)
+      config['responseSchema'] = json_schema
+      config['responseMimeType'] = 'application/json'
+      prompt.metadata.formatted_text = (
+          prompt.text
+          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
+          + pg.to_json_str(json_schema, json_indent=2)
+      )
+    return config
+
+  def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
+    """Gets generation content from langfun message."""
+    parts = []
+    for lf_chunk in prompt.chunk():
+      if isinstance(lf_chunk, str):
+        parts.append({'text': lf_chunk})
+      elif isinstance(lf_chunk, lf_modalities.Mime):
+        try:
+          modalities = lf_chunk.make_compatible(
+              self.supported_modalities + ['text/plain']
+          )
+          if isinstance(modalities, lf_modalities.Mime):
+            modalities = [modalities]
+          for modality in modalities:
+            if modality.is_text:
+              parts.append({'text': modality.to_text()})
+            else:
+              parts.append({
+                  'inlineData': {
+                      'data': base64.b64encode(modality.to_bytes()).decode(),
+                      'mimeType': modality.mime_type,
+                  }
+              })
+        except lf.ModalityError as e:
+          raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
+      else:
+        raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
+    return dict(role='user', parts=parts)
+
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    messages = [
+        self._message_from_content_parts(candidate['content']['parts'])
+        for candidate in json['candidates']
+    ]
+    usage = json['usageMetadata']
+    input_tokens = usage['promptTokenCount']
+    output_tokens = usage['candidatesTokenCount']
+    return lf.LMSamplingResult(
+        [lf.LMSample(message) for message in messages],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=input_tokens,
+            completion_tokens=output_tokens,
+            total_tokens=input_tokens + output_tokens,
+            estimated_cost=self.estimate_cost(
+                num_input_tokens=input_tokens,
+                num_output_tokens=output_tokens,
+            ),
+        ),
+    )
+
+  def _message_from_content_parts(
+      self, parts: list[dict[str, Any]]
+  ) -> lf.Message:
+    """Converts Vertex AI's content parts protocol to message."""
+    chunks = []
+    for part in parts:
+      if text_part := part.get('text'):
+        chunks.append(text_part)
+      else:
+        raise ValueError(f'Unsupported part: {part}')
+    return lf.AIMessage.from_chunks(chunks)
+
+
 class _ModelHub:
   """Vertex AI model hub."""

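Note: the hunk above adds `VertexAIRest`, a `rest.REST` subclass that calls the Vertex AI `generateContent` endpoint through an authorized `requests` session instead of the `vertexai` SDK. Below is a minimal usage sketch, assuming a GCP project with Vertex AI enabled and application-default credentials available; the project and location values are placeholders, not part of the diff.

```python
# Sketch only: exercises the REST-backed path added in this release.
# 'my-project' / 'us-central1' are placeholders; the concrete subclasses
# reparented later in this diff (e.g. VertexAIGeminiPro1_5_002) work the same.
from langfun.core.llms import vertexai

lm = vertexai.VertexAIRest(
    'gemini-1.5-pro-002',      # positional, via @lf.use_init_args(['model'])
    project='my-project',      # or set env var VERTEXAI_PROJECT
    location='us-central1',    # or set env var VERTEXAI_LOCATION
)

# Sampling options are mapped by `_generation_config` onto the REST
# generationConfig (max_tokens -> maxOutputTokens, stop -> stopSequences, ...).
r = lm('What is the capital of France?', temperature=0.7, max_tokens=256)
print(r.text)
print(r.metadata.usage.prompt_tokens, r.metadata.usage.completion_tokens)
```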
@@ -547,13 +785,21 @@ class VertexAIGeminiPro1_5(VertexAIGemini1_5): # pylint: disable=invalid-name
   model = 'gemini-1.5-pro'


-class VertexAIGeminiPro1_5_002(VertexAIGemini1_5):  # pylint: disable=invalid-name
+class VertexAIRestGemini1_5(VertexAIRest):  # pylint: disable=invalid-name
+  """Vertex AI Gemini 1.5 model with REST API."""
+
+  supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
+      _DOCUMENT_TYPES + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES
+  )
+
+
+class VertexAIGeminiPro1_5_002(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""

   model = 'gemini-1.5-pro-002'


-class VertexAIGeminiPro1_5_001(VertexAIGemini1_5):  # pylint: disable=invalid-name
+class VertexAIGeminiPro1_5_001(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""

   model = 'gemini-1.5-pro-001'
@@ -583,13 +829,13 @@ class VertexAIGeminiFlash1_5(VertexAIGemini1_5): # pylint: disable=invalid-name
   model = 'gemini-1.5-flash'


-class VertexAIGeminiFlash1_5_002(VertexAIGemini1_5):  # pylint: disable=invalid-name
+class VertexAIGeminiFlash1_5_002(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""

   model = 'gemini-1.5-flash-002'


-class VertexAIGeminiFlash1_5_001(VertexAIGemini1_5):  # pylint: disable=invalid-name
+class VertexAIGeminiFlash1_5_001(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""

   model = 'gemini-1.5-flash-001'
@@ -601,14 +847,14 @@ class VertexAIGeminiFlash1_5_0514(VertexAIGemini1_5): # pylint: disable=invalid
   model = 'gemini-1.5-flash-preview-0514'


-class VertexAIGeminiPro1(VertexAI):  # pylint: disable=invalid-name
+class VertexAIGeminiPro1(VertexAIRest):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.0 Pro model."""

   model = 'gemini-1.0-pro'


 class VertexAIGeminiPro1Vision(VertexAI):  # pylint: disable=invalid-name
-  """Vertex AI Gemini 1.0 Pro model."""
+  """Vertex AI Gemini 1.0 Pro Vision model."""

   model = 'gemini-1.0-pro-vision'
   supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
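Note: the net effect of the reparenting in the three hunks above is that the Gemini 1.5 Pro/Flash point releases (001/002) and Gemini 1.0 Pro now take the new REST path, while the preview and Vision models stay on the SDK-based `VertexAI` path. A quick check of which path a class takes (a sketch; class names are taken from this diff):

```python
from langfun.core.llms import vertexai

# REST-backed after this release.
assert issubclass(vertexai.VertexAIGeminiPro1_5_002, vertexai.VertexAIRest)
assert issubclass(vertexai.VertexAIGeminiPro1, vertexai.VertexAIRest)

# Still served through the vertexai SDK wrapper.
assert issubclass(vertexai.VertexAIGeminiPro1Vision, vertexai.VertexAI)
assert not issubclass(vertexai.VertexAIGeminiPro1Vision, vertexai.VertexAIRest)
```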
@@ -14,6 +14,7 @@
 """Tests for Gemini models."""

 import os
+from typing import Any
 import unittest
 from unittest import mock

@@ -23,6 +24,7 @@ import langfun.core as lf
 from langfun.core import modalities as lf_modalities
 from langfun.core.llms import vertexai
 import pyglove as pg
+import requests


 example_image = (
@@ -64,6 +66,40 @@ def mock_generate_content(content, generation_config, **kwargs):
   })


+def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  c = pg.Dict(json['generationConfig'])
+  content = json['contents'][0]['parts'][0]['text']
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'candidates': [
+          {
+              'content': {
+                  'role': 'model',
+                  'parts': [
+                      {
+                          'text': (
+                              f'This is a response to {content} with '
+                              f'temperature={c.temperature}, '
+                              f'top_p={c.topP}, '
+                              f'top_k={c.topK}, '
+                              f'max_tokens={c.maxOutputTokens}, '
+                              f'stop={"".join(c.stopSequences)}.'
+                          )
+                      },
+                  ],
+              },
+          },
+      ],
+      'usageMetadata': {
+          'promptTokenCount': 3,
+          'candidatesTokenCount': 4,
+      }
+  }).encode()
+  return response
+
+
 def mock_endpoint_predict(instances, **kwargs):
   del kwargs
   assert len(instances) == 1
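For reference, the `json` argument `mock_requests_post` receives is the body assembled by `VertexAIRest.request()`: a `generationConfig` dict plus a single-turn `contents` list. A sketch of that payload shape, with illustrative values mirroring `test_generation_config` later in this file:

```python
# Illustrative generateContent request body (values are examples only).
payload = {
    'generationConfig': {
        'temperature': 2.0,
        'maxOutputTokens': 1024,
        'candidateCount': 1,
        'topK': 20,
        'topP': 1.0,
        'stopSequences': ['\n'],
        'seed': None,
        'responseLogprobs': False,
        'logprobs': None,
    },
    'contents': [{'role': 'user', 'parts': [{'text': 'hello'}]}],
}
```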
@@ -83,7 +119,7 @@ class VertexAITest(unittest.TestCase):

   def test_content_from_message_text_only(self):
     text = 'This is a beautiful day'
-    model = vertexai.VertexAIGeminiPro1()
+    model = vertexai.VertexAIGeminiPro1Vision()
     chunks = model._content_from_message(lf.UserMessage(text))
     self.assertEqual(chunks, [text])

@@ -95,7 +131,7 @@ class VertexAITest(unittest.TestCase):

     # Non-multimodal model.
     with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
-      vertexai.VertexAIGeminiPro1()._content_from_message(message)
+      vertexai.VertexAIPalm2()._content_from_message(message)

     model = vertexai.VertexAIGeminiPro1Vision()
     chunks = model._content_from_message(message)
@@ -119,7 +155,7 @@ class VertexAITest(unittest.TestCase):
             },
         ],
     })
-    model = vertexai.VertexAIGeminiPro1()
+    model = vertexai.VertexAIGeminiPro1Vision()
     message = model._generation_response_to_message(response)
     self.assertEqual(message, lf.AIMessage('hello world'))

@@ -158,25 +194,25 @@ class VertexAITest(unittest.TestCase):

   def test_project_and_location_check(self):
     with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
-      _ = vertexai.VertexAIGeminiPro1()._api_initialized
+      _ = vertexai.VertexAIGeminiPro1Vision()._api_initialized

     with self.assertRaisesRegex(ValueError, 'Please specify `location`'):
-      _ = vertexai.VertexAIGeminiPro1(project='abc')._api_initialized
+      _ = vertexai.VertexAIGeminiPro1Vision(project='abc')._api_initialized

     self.assertTrue(
-        vertexai.VertexAIGeminiPro1(
+        vertexai.VertexAIGeminiPro1Vision(
             project='abc', location='us-central1'
         )._api_initialized
     )

     os.environ['VERTEXAI_PROJECT'] = 'abc'
     os.environ['VERTEXAI_LOCATION'] = 'us-central1'
-    self.assertTrue(vertexai.VertexAIGeminiPro1()._api_initialized)
+    self.assertTrue(vertexai.VertexAIGeminiPro1Vision()._api_initialized)
     del os.environ['VERTEXAI_PROJECT']
     del os.environ['VERTEXAI_LOCATION']

   def test_generation_config(self):
-    model = vertexai.VertexAIGeminiPro1()
+    model = vertexai.VertexAIGeminiPro1Vision()
     json_schema = {
         'type': 'object',
         'properties': {
@@ -245,7 +281,9 @@ class VertexAITest(unittest.TestCase):
     ) as mock_generate:
       mock_generate.side_effect = mock_generate_content

-      lm = vertexai.VertexAIGeminiPro1(project='abc', location='us-central1')
+      lm = vertexai.VertexAIGeminiPro1Vision(
+          project='abc', location='us-central1'
+      )
       self.assertEqual(
           lm(
               'hello',
@@ -328,5 +366,163 @@ class VertexAITest(unittest.TestCase):
     )


+class VertexRestfulAITest(unittest.TestCase):
+  """Tests for Vertex model with REST API."""
+
+  def test_content_from_message_text_only(self):
+    text = 'This is a beautiful day'
+    model = vertexai.VertexAIGeminiPro1_5_002()
+    chunks = model._content_from_message(lf.UserMessage(text))
+    self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]})
+
+  def test_content_from_message_mm(self):
+    message = lf.UserMessage(
+        'This is an <<[[image]]>>, what is it?',
+        image=lf_modalities.Image.from_bytes(example_image),
+    )
+
+    # Non-multimodal model.
+    with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
+      vertexai.VertexAIGeminiPro1()._content_from_message(message)
+
+    model = vertexai.VertexAIGeminiPro1Vision()
+    chunks = model._content_from_message(message)
+    self.maxDiff = None
+    self.assertEqual([chunks[0], chunks[2]], ['This is an', ', what is it?'])
+    self.assertIsInstance(chunks[1], generative_models.Part)
+
+  def test_generation_response_to_message_text_only(self):
+    response = generative_models.GenerationResponse.from_dict({
+        'candidates': [
+            {
+                'index': 0,
+                'content': {
+                    'role': 'model',
+                    'parts': [
+                        {
+                            'text': 'hello world',
+                        },
+                    ],
+                },
+            },
+        ],
+    })
+    model = vertexai.VertexAIGeminiPro1Vision()
+    message = model._generation_response_to_message(response)
+    self.assertEqual(message, lf.AIMessage('hello world'))
+
+  def test_model_hub(self):
+    with mock.patch(
+        'vertexai.generative_models.'
+        'GenerativeModel.__init__'
+    ) as mock_model_init:
+      mock_model_init.side_effect = lambda *args, **kwargs: None
+      model = vertexai._VERTEXAI_MODEL_HUB.get_generative_model(
+          'gemini-1.0-pro'
+      )
+      self.assertIsNotNone(model)
+      self.assertIs(
+          vertexai._VERTEXAI_MODEL_HUB.get_generative_model('gemini-1.0-pro'),
+          model,
+      )
+
+  @mock.patch.object(vertexai.VertexAIRest, 'credentials', new=True)
+  def test_project_and_location_check(self):
+    with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
+      _ = vertexai.VertexAIGeminiPro1()._api_initialized
+
+    with self.assertRaisesRegex(ValueError, 'Please specify `location`'):
+      _ = vertexai.VertexAIGeminiPro1(project='abc')._api_initialized
+
+    self.assertTrue(
+        vertexai.VertexAIGeminiPro1(
+            project='abc', location='us-central1'
+        )._api_initialized
+    )
+
+    os.environ['VERTEXAI_PROJECT'] = 'abc'
+    os.environ['VERTEXAI_LOCATION'] = 'us-central1'
+    self.assertTrue(vertexai.VertexAIGeminiPro1()._api_initialized)
+    del os.environ['VERTEXAI_PROJECT']
+    del os.environ['VERTEXAI_LOCATION']
+
+  def test_generation_config(self):
+    model = vertexai.VertexAIGeminiPro1()
+    json_schema = {
+        'type': 'object',
+        'properties': {
+            'name': {'type': 'string'},
+        },
+        'required': ['name'],
+        'title': 'Person',
+    }
+    actual = model._generation_config(
+        lf.UserMessage('hi', json_schema=json_schema),
+        lf.LMSamplingOptions(
+            temperature=2.0,
+            top_p=1.0,
+            top_k=20,
+            max_tokens=1024,
+            stop=['\n'],
+        ),
+    )
+    self.assertEqual(
+        actual,
+        dict(
+            candidateCount=1,
+            temperature=2.0,
+            topP=1.0,
+            topK=20,
+            maxOutputTokens=1024,
+            stopSequences=['\n'],
+            responseLogprobs=False,
+            logprobs=None,
+            seed=None,
+            responseMimeType='application/json',
+            responseSchema={
+                'type': 'object',
+                'properties': {
+                    'name': {'type': 'string'}
+                },
+                'required': ['name'],
+                'title': 'Person',
+            }
+        ),
+    )
+    with self.assertRaisesRegex(
+        ValueError, '`json_schema` must be a dict, got'
+    ):
+      model._generation_config(
+          lf.UserMessage('hi', json_schema='not a dict'),
+          lf.LMSamplingOptions(),
+      )
+
+  @mock.patch.object(vertexai.VertexAIRest, 'credentials', new=True)
+  def test_call_model(self):
+    with mock.patch('requests.Session.post') as mock_generate:
+      mock_generate.side_effect = mock_requests_post
+
+      lm = vertexai.VertexAIGeminiPro1_5_002(
+          project='abc', location='us-central1'
+      )
+      r = lm(
+          'hello',
+          temperature=2.0,
+          top_p=1.0,
+          top_k=20,
+          max_tokens=1024,
+          stop='\n',
+      )
+      self.assertEqual(
+          r.text,
+          (
+              'This is a response to hello with temperature=2.0, '
+              'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.'
+          ),
+      )
+      self.assertEqual(r.metadata.usage.prompt_tokens, 3)
+      self.assertEqual(r.metadata.usage.completion_tokens, 4)
+
+
 if __name__ == '__main__':
   unittest.main()
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: langfun
-Version: 0.1.2.dev202412010804
+Version: 0.1.2.dev202412030000
 Summary: Langfun: Language as Functions.
 Home-page: https://github.com/google/langfun
 Author: Langfun Authors
@@ -32,7 +32,6 @@ Requires-Dist: termcolor==1.1.0; extra == "all"
 Requires-Dist: tqdm>=4.64.1; extra == "all"
 Requires-Dist: google-cloud-aiplatform>=1.5.0; extra == "all"
 Requires-Dist: google-generativeai>=0.3.2; extra == "all"
-Requires-Dist: openai>=0.27.2; extra == "all"
 Requires-Dist: python-magic>=0.4.27; extra == "all"
 Requires-Dist: python-docx>=0.8.11; extra == "all"
 Requires-Dist: pillow>=10.0.0; extra == "all"
@@ -44,7 +43,6 @@ Requires-Dist: tqdm>=4.64.1; extra == "ui"
 Provides-Extra: llm
 Requires-Dist: google-cloud-aiplatform>=1.5.0; extra == "llm"
 Requires-Dist: google-generativeai>=0.3.2; extra == "llm"
-Requires-Dist: openai>=0.27.2; extra == "llm"
 Provides-Extra: llm-google
 Requires-Dist: google-cloud-aiplatform>=1.5.0; extra == "llm-google"
 Requires-Dist: google-generativeai>=0.3.2; extra == "llm-google"
@@ -52,8 +50,6 @@ Provides-Extra: llm-google-vertex
 Requires-Dist: google-cloud-aiplatform>=1.5.0; extra == "llm-google-vertex"
 Provides-Extra: llm-google-genai
 Requires-Dist: google-generativeai>=0.3.2; extra == "llm-google-genai"
-Provides-Extra: llm-openai
-Requires-Dist: openai>=0.27.2; extra == "llm-openai"
 Provides-Extra: mime
 Requires-Dist: python-magic>=0.4.27; extra == "mime"
 Requires-Dist: python-docx>=0.8.11; extra == "mime"
@@ -214,7 +210,6 @@ If you want to customize your installation, you can select specific features usi
 | llm-google          | All supported Google-powered LLMs. |
 | llm-google-vertexai | LLMs powered by Google Cloud VertexAI |
 | llm-google-genai    | LLMs powered by Google Generative AI API |
-| llm-openai          | LLMs powered by OpenAI |
 | mime                | All MIME supports. |
 | mime-auto           | Automatic MIME type detection. |
 | mime-docx           | DocX format support. |