langfun-0.1.2.dev202411270804-py3-none-any.whl → langfun-0.1.2.dev202412020805-py3-none-any.whl

This diff shows the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
--- a/langfun/core/eval/v2/progress_tracking.py
+++ b/langfun/core/eval/v2/progress_tracking.py
@@ -87,7 +87,8 @@ class _TqdmProgressTracker(experiment_lib.Plugin):
     self._leaf_progresses = {
         leaf.id: lf.concurrent.ProgressBar.install(
             label=f'[#{i + 1} - {leaf.id}]',
-            total=leaf.num_examples,
+            total=(len(runner.current_run.example_ids)
+                   if runner.current_run.example_ids else leaf.num_examples),
             color='cyan',
             status=None
         )
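
Note: the change above sizes each leaf's progress bar by the number of examples actually scheduled for the run, not the leaf's full example count. A standalone sketch of that rule (the names here are hypothetical, not langfun APIs):

def progress_total(example_ids, num_examples):
  # An explicit example subset caps the bar; a falsy value (None or an
  # empty list) falls back to the leaf's full example count.
  return len(example_ids) if example_ids else num_examples

assert progress_total([1], 100) == 1
assert progress_total(None, 100) == 100
assert progress_total([], 100) == 100
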
--- a/langfun/core/eval/v2/progress_tracking_test.py
+++ b/langfun/core/eval/v2/progress_tracking_test.py
@@ -51,6 +51,16 @@ class TqdmProgressTrackerTest(unittest.TestCase):
       _ = experiment.run(root_dir, 'new', plugins=[])
     self.assertIn('All: 100%', string_io.getvalue())
 
+  def test_with_example_ids(self):
+    root_dir = os.path.join(
+        tempfile.gettempdir(), 'test_tqdm_progress_tracker_with_example_ids'
+    )
+    experiment = test_helper.test_experiment()
+    string_io = io.StringIO()
+    with contextlib.redirect_stderr(string_io):
+      _ = experiment.run(root_dir, 'new', example_ids=[1], plugins=[])
+    self.assertIn('All: 100%', string_io.getvalue())
+
 
 if __name__ == '__main__':
   unittest.main()
--- a/langfun/core/llms/__init__.py
+++ b/langfun/core/llms/__init__.py
@@ -120,6 +120,8 @@ from langfun.core.llms.groq import GroqWhisper_Large_v3
 from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo
 
 from langfun.core.llms.vertexai import VertexAI
+from langfun.core.llms.vertexai import VertexAIRest
+from langfun.core.llms.vertexai import VertexAIRestGemini1_5
 from langfun.core.llms.vertexai import VertexAIGemini1_5
 from langfun.core.llms.vertexai import VertexAIGeminiPro1_5
 from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_Latest
--- a/langfun/core/llms/openai.py
+++ b/langfun/core/llms/openai.py
@@ -530,7 +530,12 @@ class OpenAI(lf.LanguageModel):
       if isinstance(chunk, str):
         item = dict(type='text', text=chunk)
       elif isinstance(chunk, lf_modalities.Image):
-        uri = chunk.uri or chunk.content_uri
+        if chunk.uri and chunk.uri.lower().startswith(
+            ('http:', 'https:', 'ftp:')
+        ):
+          uri = chunk.uri
+        else:
+          uri = chunk.content_uri
         item = dict(type='image_url', image_url=dict(url=uri))
       else:
         raise ValueError(f'Unsupported modality object: {chunk!r}.')
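
Note: the OpenAI hunk above stops passing non-web URIs (e.g. local file paths) straight to the API; only URIs the API server can fetch itself go through verbatim, everything else is inlined via content_uri. A condensed sketch of the selection rule, assuming uri may be a local path or web URL and content_uri is a base64 data-URI rendering of the image:

def image_url(uri, content_uri):
  # Remote locations are fetchable by the API server directly.
  if uri and uri.lower().startswith(('http:', 'https:', 'ftp:')):
    return uri
  # Local paths (or a missing uri) must be inlined as a data URI.
  return content_uri

assert image_url('HTTPS://host/cat.png', 'data:image/png;base64,AAAA') == 'HTTPS://host/cat.png'
assert image_url('/tmp/cat.png', 'data:image/png;base64,AAAA') == 'data:image/png;base64,AAAA'
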
--- a/langfun/core/llms/vertexai.py
+++ b/langfun/core/llms/vertexai.py
@@ -13,17 +13,21 @@
 # limitations under the License.
 """Vertex AI generative models."""
 
+import base64
 import functools
 import os
 from typing import Annotated, Any
 
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
 import pyglove as pg
 
 try:
   # pylint: disable=g-import-not-at-top
+  from google import auth as google_auth
   from google.auth import credentials as credentials_lib
+  from google.auth.transport import requests as auth_requests
   import vertexai
   from google.cloud.aiplatform import models as aiplatform_models
   from vertexai import generative_models
@@ -32,6 +36,8 @@ try:
 
   Credentials = credentials_lib.Credentials
 except ImportError:
+  google_auth = None
+  auth_requests = None
   credentials_lib = None  # pylint: disable=invalid-name
   vertexai = None
   generative_models = None
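
Note: the two vertexai.py hunks above extend the module's optional-dependency guard: the google.auth names are imported at module load, replaced with None sentinels on ImportError, and only checked when a model is actually constructed (the new _on_bound below raises an actionable "Please install" error). A generic sketch of that pattern, not langfun code:

try:
  import optional_lib  # Hypothetical optional dependency.
except ImportError:
  optional_lib = None  # Sentinel, checked before first use.

def feature():
  if optional_lib is None:
    # Fail late with an actionable message, so importing the module
    # stays cheap for users who never touch this feature.
    raise ValueError('Please install "optional_lib" to use this feature.')
  return optional_lib
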
@@ -449,6 +455,238 @@ class VertexAI(lf.LanguageModel):
   ])
 
 
+@lf.use_init_args(['model'])
+@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
+class VertexAIRest(rest.REST):
+  """Language model served on VertexAI with REST API."""
+
+  model: pg.typing.Annotated[
+      pg.typing.Enum(
+          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+      ),
+      (
+          'Vertex AI model name with REST API support. See '
+          'https://cloud.google.com/vertex-ai/generative-ai/docs/'
+          'model-reference/inference#supported-models'
+          ' for details.'
+      ),
+  ]
+
+  project: Annotated[
+      str | None,
+      (
+          'Vertex AI project ID. Or set from environment variable '
+          'VERTEXAI_PROJECT.'
+      ),
+  ] = None
+
+  location: Annotated[
+      str | None,
+      (
+          'Vertex AI service location. Or set from environment variable '
+          'VERTEXAI_LOCATION.'
+      ),
+  ] = None
+
+  credentials: Annotated[
+      Credentials | None,
+      (
+          'Credentials to use. If None, the default credentials to the '
+          'environment will be used.'
+      ),
+  ] = None
+
+  supported_modalities: Annotated[
+      list[str],
+      'A list of MIME types for supported modalities'
+  ] = []
+
+  def _on_bound(self):
+    super()._on_bound()
+    if google_auth is None:
+      raise ValueError(
+          'Please install "langfun[llm-google-vertex]" to use Vertex AI models.'
+      )
+    self._project = None
+    self._credentials = None
+
+  def _initialize(self):
+    project = self.project or os.environ.get('VERTEXAI_PROJECT', None)
+    if not project:
+      raise ValueError(
+          'Please specify `project` during `__init__` or set environment '
+          'variable `VERTEXAI_PROJECT` with your Vertex AI project ID.'
+      )
+
+    location = self.location or os.environ.get('VERTEXAI_LOCATION', None)
+    if not location:
+      raise ValueError(
+          'Please specify `location` during `__init__` or set environment '
+          'variable `VERTEXAI_LOCATION` with your Vertex AI service location.'
+      )
+
+    self._project = project
+    credentials = self.credentials
+    if credentials is None:
+      # Use default credentials.
+      credentials = google_auth.default(
+          scopes=['https://www.googleapis.com/auth/cloud-platform']
+      )
+    self._credentials = credentials
+
+  @property
+  def max_concurrency(self) -> int:
+    """Returns the maximum number of concurrent requests."""
+    return self.rate_to_max_concurrency(
+        requests_per_min=SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm,
+        tokens_per_min=0,
+    )
+
+  def estimate_cost(
+      self,
+      num_input_tokens: int,
+      num_output_tokens: int
+  ) -> float | None:
+    """Estimate the cost based on usage."""
+    cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+        'cost_per_1k_input_chars', None
+    )
+    cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+        'cost_per_1k_output_chars', None
+    )
+    if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None:
+      return None
+    return (
+        cost_per_1k_input_chars * num_input_tokens
+        + cost_per_1k_output_chars * num_output_tokens
+    ) * AVGERAGE_CHARS_PER_TOEKN / 1000
+
+  @functools.cached_property
+  def _session(self):
+    assert self._api_initialized
+    assert self._credentials is not None
+    assert auth_requests is not None
+    s = auth_requests.AuthorizedSession(self._credentials)
+    s.headers.update(self.headers or {})
+    return s
+
+  @property
+  def headers(self):
+    return {
+        'Content-Type': 'application/json; charset=utf-8',
+    }
+
+  @property
+  def api_endpoint(self) -> str:
+    return (
+        f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
+        f'{self.project}/locations/{self.location}/publishers/google/'
+        f'models/{self.model}:generateContent'
+    )
+
+  def request(
+      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    request = dict(
+        generationConfig=self._generation_config(prompt, sampling_options)
+    )
+    request['contents'] = [self._content_from_message(prompt)]
+    return request
+
+  def _generation_config(
+      self, prompt: lf.Message, options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns a dict as generation config for prompt and LMSamplingOptions."""
+    config = dict(
+        temperature=options.temperature,
+        maxOutputTokens=options.max_tokens,
+        candidateCount=options.n,
+        topK=options.top_k,
+        topP=options.top_p,
+        stopSequences=options.stop,
+        seed=options.random_seed,
+        responseLogprobs=options.logprobs,
+        logprobs=options.top_logprobs,
+    )
+
+    if json_schema := prompt.metadata.get('json_schema'):
+      if not isinstance(json_schema, dict):
+        raise ValueError(
+            f'`json_schema` must be a dict, got {json_schema!r}.'
+        )
+      json_schema = pg.to_json(json_schema)
+      config['responseSchema'] = json_schema
+      config['responseMimeType'] = 'application/json'
+      prompt.metadata.formatted_text = (
+          prompt.text
+          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
+          + pg.to_json_str(json_schema, json_indent=2)
+      )
+    return config
+
+  def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
+    """Gets generation content from langfun message."""
+    parts = []
+    for lf_chunk in prompt.chunk():
+      if isinstance(lf_chunk, str):
+        parts.append({'text': lf_chunk})
+      elif isinstance(lf_chunk, lf_modalities.Mime):
+        try:
+          modalities = lf_chunk.make_compatible(
+              self.supported_modalities + ['text/plain']
+          )
+          if isinstance(modalities, lf_modalities.Mime):
+            modalities = [modalities]
+          for modality in modalities:
+            if modality.is_text:
+              parts.append({'text': modality.to_text()})
+            else:
+              parts.append({
+                  'inlineData': {
+                      'data': base64.b64encode(modality.to_bytes()).decode(),
+                      'mimeType': modality.mime_type,
+                  }
+              })
+        except lf.ModalityError as e:
+          raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
+      else:
+        raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
+    return dict(role='user', parts=parts)
+
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    messages = [
+        self._message_from_content_parts(candidate['content']['parts'])
+        for candidate in json['candidates']
+    ]
+    usage = json['usageMetadata']
+    input_tokens = usage['promptTokenCount']
+    output_tokens = usage['candidatesTokenCount']
+    return lf.LMSamplingResult(
+        [lf.LMSample(message) for message in messages],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=input_tokens,
+            completion_tokens=output_tokens,
+            total_tokens=input_tokens + output_tokens,
+            estimated_cost=self.estimate_cost(
+                num_input_tokens=input_tokens,
+                num_output_tokens=output_tokens,
+            ),
+        ),
+    )
+
+  def _message_from_content_parts(
+      self, parts: list[dict[str, Any]]
+  ) -> lf.Message:
+    """Converts Vertex AI's content parts protocol to message."""
+    chunks = []
+    for part in parts:
+      if text_part := part.get('text'):
+        chunks.append(text_part)
+      else:
+        raise ValueError(f'Unsupported part: {part}')
+    return lf.AIMessage.from_chunks(chunks)
+
+
 class _ModelHub:
   """Vertex AI model hub."""
 
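
Note: in the new VertexAIRest class above, api_endpoint targets the publisher model's generateContent REST method, request() assembles the JSON payload from _generation_config and _content_from_message, and result() maps the response back into an lf.LMSamplingResult. A hedged usage sketch (placeholder project/location; assumes default application credentials and mirrors the call style used in the tests below):

from langfun.core.llms import vertexai

lm = vertexai.VertexAIRest(
    'gemini-1.5-pro-002', project='my-project', location='us-central1'
)
# Per api_endpoint, requests go to:
#   https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/
#   locations/us-central1/publishers/google/models/
#   gemini-1.5-pro-002:generateContent
r = lm('hello', temperature=0.0, max_tokens=256)
print(r.text)            # The sampled reply.
print(r.metadata.usage)  # Token counts, plus estimated cost if priced.
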
@@ -547,13 +785,21 @@ class VertexAIGeminiPro1_5(VertexAIGemini1_5): # pylint: disable=invalid-name
   model = 'gemini-1.5-pro'
 
 
-class VertexAIGeminiPro1_5_002(VertexAIGemini1_5):  # pylint: disable=invalid-name
+class VertexAIRestGemini1_5(VertexAIRest):  # pylint: disable=invalid-name
+  """Vertex AI Gemini 1.5 model with REST API."""
+
+  supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
+      _DOCUMENT_TYPES + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES
+  )
+
+
+class VertexAIGeminiPro1_5_002(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""
 
   model = 'gemini-1.5-pro-002'
 
 
-class VertexAIGeminiPro1_5_001(VertexAIGemini1_5):  # pylint: disable=invalid-name
+class VertexAIGeminiPro1_5_001(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""
 
   model = 'gemini-1.5-pro-001'
@@ -583,13 +829,13 @@ class VertexAIGeminiFlash1_5(VertexAIGemini1_5): # pylint: disable=invalid-name
   model = 'gemini-1.5-flash'
 
 
-class VertexAIGeminiFlash1_5_002(VertexAIGemini1_5):  # pylint: disable=invalid-name
+class VertexAIGeminiFlash1_5_002(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""
 
   model = 'gemini-1.5-flash-002'
 
 
-class VertexAIGeminiFlash1_5_001(VertexAIGemini1_5):  # pylint: disable=invalid-name
+class VertexAIGeminiFlash1_5_001(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""
 
   model = 'gemini-1.5-flash-001'
@@ -601,14 +847,14 @@ class VertexAIGeminiFlash1_5_0514(VertexAIGemini1_5): # pylint: disable=invalid
   model = 'gemini-1.5-flash-preview-0514'
 
 
-class VertexAIGeminiPro1(VertexAI):  # pylint: disable=invalid-name
+class VertexAIGeminiPro1(VertexAIRest):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.0 Pro model."""
 
   model = 'gemini-1.0-pro'
 
 
 class VertexAIGeminiPro1Vision(VertexAI):  # pylint: disable=invalid-name
-  """Vertex AI Gemini 1.0 Pro model."""
+  """Vertex AI Gemini 1.0 Pro Vision model."""
 
   model = 'gemini-1.0-pro-vision'
   supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
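
Note: with the rewiring above, the pinned Gemini 1.5 releases (-001/-002) and Gemini 1.0 Pro now route through the REST classes, while the unpinned 1.5 aliases and 1.0 Pro Vision stay on the vertexai SDK path. A quick sketch of what the new hierarchy implies:

from langfun.core.llms import vertexai

assert issubclass(vertexai.VertexAIGeminiPro1_5_002, vertexai.VertexAIRest)
assert issubclass(vertexai.VertexAIGeminiFlash1_5_001, vertexai.VertexAIRest)
assert issubclass(vertexai.VertexAIGeminiPro1, vertexai.VertexAIRest)
# The Vision model keeps the SDK-based code path:
assert not issubclass(vertexai.VertexAIGeminiPro1Vision, vertexai.VertexAIRest)
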
--- a/langfun/core/llms/vertexai_test.py
+++ b/langfun/core/llms/vertexai_test.py
@@ -14,6 +14,7 @@
 """Tests for Gemini models."""
 
 import os
+from typing import Any
 import unittest
 from unittest import mock
 
@@ -23,6 +24,7 @@ import langfun.core as lf
 from langfun.core import modalities as lf_modalities
 from langfun.core.llms import vertexai
 import pyglove as pg
+import requests
 
 
 example_image = (
@@ -64,6 +66,40 @@ def mock_generate_content(content, generation_config, **kwargs):
   })
 
 
+def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  c = pg.Dict(json['generationConfig'])
+  content = json['contents'][0]['parts'][0]['text']
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'candidates': [
+          {
+              'content': {
+                  'role': 'model',
+                  'parts': [
+                      {
+                          'text': (
+                              f'This is a response to {content} with '
+                              f'temperature={c.temperature}, '
+                              f'top_p={c.topP}, '
+                              f'top_k={c.topK}, '
+                              f'max_tokens={c.maxOutputTokens}, '
+                              f'stop={"".join(c.stopSequences)}.'
+                          )
+                      },
+                  ],
+              },
+          },
+      ],
+      'usageMetadata': {
+          'promptTokenCount': 3,
+          'candidatesTokenCount': 4,
+      }
+  }).encode()
+  return response
+
+
 def mock_endpoint_predict(instances, **kwargs):
   del kwargs
   assert len(instances) == 1
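
Note: mock_requests_post above fabricates a requests.Response (a 200 status plus a Gemini-style JSON body in _content) so the REST code path can be exercised offline; test_call_model below installs it as the side effect of a patched requests.Session.post. A condensed, self-contained sketch of the same pattern:

from unittest import mock

import requests

def fake_post(url, json=None, **kwargs):
  resp = requests.Response()
  resp.status_code = 200
  resp._content = b'{"ok": true}'  # The real test builds a full payload.
  return resp

with mock.patch('requests.Session.post', side_effect=fake_post):
  assert requests.Session().post('https://example.test').json() == {'ok': True}
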
@@ -83,7 +119,7 @@ class VertexAITest(unittest.TestCase):
 
   def test_content_from_message_text_only(self):
     text = 'This is a beautiful day'
-    model = vertexai.VertexAIGeminiPro1()
+    model = vertexai.VertexAIGeminiPro1Vision()
     chunks = model._content_from_message(lf.UserMessage(text))
     self.assertEqual(chunks, [text])
 
@@ -95,7 +131,7 @@ class VertexAITest(unittest.TestCase):
 
     # Non-multimodal model.
     with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
-      vertexai.VertexAIGeminiPro1()._content_from_message(message)
+      vertexai.VertexAIPalm2()._content_from_message(message)
 
     model = vertexai.VertexAIGeminiPro1Vision()
     chunks = model._content_from_message(message)
@@ -119,7 +155,7 @@ class VertexAITest(unittest.TestCase):
             },
         ],
     })
-    model = vertexai.VertexAIGeminiPro1()
+    model = vertexai.VertexAIGeminiPro1Vision()
     message = model._generation_response_to_message(response)
     self.assertEqual(message, lf.AIMessage('hello world'))
 
@@ -158,25 +194,25 @@ class VertexAITest(unittest.TestCase):
 
   def test_project_and_location_check(self):
     with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
-      _ = vertexai.VertexAIGeminiPro1()._api_initialized
+      _ = vertexai.VertexAIGeminiPro1Vision()._api_initialized
 
     with self.assertRaisesRegex(ValueError, 'Please specify `location`'):
-      _ = vertexai.VertexAIGeminiPro1(project='abc')._api_initialized
+      _ = vertexai.VertexAIGeminiPro1Vision(project='abc')._api_initialized
 
     self.assertTrue(
-        vertexai.VertexAIGeminiPro1(
+        vertexai.VertexAIGeminiPro1Vision(
             project='abc', location='us-central1'
         )._api_initialized
     )
 
     os.environ['VERTEXAI_PROJECT'] = 'abc'
     os.environ['VERTEXAI_LOCATION'] = 'us-central1'
-    self.assertTrue(vertexai.VertexAIGeminiPro1()._api_initialized)
+    self.assertTrue(vertexai.VertexAIGeminiPro1Vision()._api_initialized)
     del os.environ['VERTEXAI_PROJECT']
     del os.environ['VERTEXAI_LOCATION']
 
   def test_generation_config(self):
-    model = vertexai.VertexAIGeminiPro1()
+    model = vertexai.VertexAIGeminiPro1Vision()
     json_schema = {
         'type': 'object',
         'properties': {
@@ -245,7 +281,9 @@ class VertexAITest(unittest.TestCase):
     ) as mock_generate:
       mock_generate.side_effect = mock_generate_content
 
-      lm = vertexai.VertexAIGeminiPro1(project='abc', location='us-central1')
+      lm = vertexai.VertexAIGeminiPro1Vision(
+          project='abc', location='us-central1'
+      )
       self.assertEqual(
           lm(
               'hello',
@@ -328,5 +366,161 @@ class VertexAITest(unittest.TestCase):
     )
 
 
+class VertexRestfulAITest(unittest.TestCase):
+  """Tests for Vertex model with REST API."""
+
+  def test_content_from_message_text_only(self):
+    text = 'This is a beautiful day'
+    model = vertexai.VertexAIGeminiPro1_5_002()
+    chunks = model._content_from_message(lf.UserMessage(text))
+    self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]})
+
+  def test_content_from_message_mm(self):
+    message = lf.UserMessage(
+        'This is an <<[[image]]>>, what is it?',
+        image=lf_modalities.Image.from_bytes(example_image),
+    )
+
+    # Non-multimodal model.
+    with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
+      vertexai.VertexAIGeminiPro1()._content_from_message(message)
+
+    model = vertexai.VertexAIGeminiPro1Vision()
+    chunks = model._content_from_message(message)
+    self.maxDiff = None
+    self.assertEqual([chunks[0], chunks[2]], ['This is an', ', what is it?'])
+    self.assertIsInstance(chunks[1], generative_models.Part)
+
+  def test_generation_response_to_message_text_only(self):
+    response = generative_models.GenerationResponse.from_dict({
+        'candidates': [
+            {
+                'index': 0,
+                'content': {
+                    'role': 'model',
+                    'parts': [
+                        {
+                            'text': 'hello world',
+                        },
+                    ],
+                },
+            },
+        ],
+    })
+    model = vertexai.VertexAIGeminiPro1Vision()
+    message = model._generation_response_to_message(response)
+    self.assertEqual(message, lf.AIMessage('hello world'))
+
+  def test_model_hub(self):
+    with mock.patch(
+        'vertexai.generative_models.'
+        'GenerativeModel.__init__'
+    ) as mock_model_init:
+      mock_model_init.side_effect = lambda *args, **kwargs: None
+      model = vertexai._VERTEXAI_MODEL_HUB.get_generative_model(
+          'gemini-1.0-pro'
+      )
+      self.assertIsNotNone(model)
+      self.assertIs(
+          vertexai._VERTEXAI_MODEL_HUB.get_generative_model('gemini-1.0-pro'),
+          model,
+      )
+
+  def test_project_and_location_check(self):
+    with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
+      _ = vertexai.VertexAIGeminiPro1()._api_initialized
+
+    with self.assertRaisesRegex(ValueError, 'Please specify `location`'):
+      _ = vertexai.VertexAIGeminiPro1(project='abc')._api_initialized
+
+    self.assertTrue(
+        vertexai.VertexAIGeminiPro1(
+            project='abc', location='us-central1'
+        )._api_initialized
+    )
+
+    os.environ['VERTEXAI_PROJECT'] = 'abc'
+    os.environ['VERTEXAI_LOCATION'] = 'us-central1'
+    self.assertTrue(vertexai.VertexAIGeminiPro1()._api_initialized)
+    del os.environ['VERTEXAI_PROJECT']
+    del os.environ['VERTEXAI_LOCATION']
+
+  def test_generation_config(self):
+    model = vertexai.VertexAIGeminiPro1()
+    json_schema = {
+        'type': 'object',
+        'properties': {
+            'name': {'type': 'string'},
+        },
+        'required': ['name'],
+        'title': 'Person',
+    }
+    actual = model._generation_config(
+        lf.UserMessage('hi', json_schema=json_schema),
+        lf.LMSamplingOptions(
+            temperature=2.0,
+            top_p=1.0,
+            top_k=20,
+            max_tokens=1024,
+            stop=['\n'],
+        ),
+    )
+    self.assertEqual(
+        actual,
+        dict(
+            candidateCount=1,
+            temperature=2.0,
+            topP=1.0,
+            topK=20,
+            maxOutputTokens=1024,
+            stopSequences=['\n'],
+            responseLogprobs=False,
+            logprobs=None,
+            seed=None,
+            responseMimeType='application/json',
+            responseSchema={
+                'type': 'object',
+                'properties': {
+                    'name': {'type': 'string'}
+                },
+                'required': ['name'],
+                'title': 'Person',
+            }
+        ),
+    )
+    with self.assertRaisesRegex(
+        ValueError, '`json_schema` must be a dict, got'
+    ):
+      model._generation_config(
+          lf.UserMessage('hi', json_schema='not a dict'),
+          lf.LMSamplingOptions(),
+      )
+
+  def test_call_model(self):
+    with mock.patch('requests.Session.post') as mock_generate:
+      mock_generate.side_effect = mock_requests_post
+
+      lm = vertexai.VertexAIGeminiPro1_5_002(
+          project='abc', location='us-central1'
+      )
+      r = lm(
+          'hello',
+          temperature=2.0,
+          top_p=1.0,
+          top_k=20,
+          max_tokens=1024,
+          stop='\n',
+      )
+      self.assertEqual(
+          r.text,
+          (
+              'This is a response to hello with temperature=2.0, '
+              'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.'
+          ),
+      )
+      self.assertEqual(r.metadata.usage.prompt_tokens, 3)
+      self.assertEqual(r.metadata.usage.completion_tokens, 4)
+
+
 if __name__ == '__main__':
   unittest.main()
--- a/langfun-0.1.2.dev202411270804.dist-info/METADATA
+++ b/langfun-0.1.2.dev202412020805.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: langfun
-Version: 0.1.2.dev202411270804
+Version: 0.1.2.dev202412020805
 Summary: Langfun: Language as Functions.
 Home-page: https://github.com/google/langfun
 Author: Langfun Authors
--- a/langfun-0.1.2.dev202411270804.dist-info/RECORD
+++ b/langfun-0.1.2.dev202412020805.dist-info/RECORD
@@ -72,14 +72,14 @@ langfun/core/eval/v2/metrics.py,sha256=bl8i6u-ZHRBz4hAc3LzsZ2Dc7ZRQcuTYeUhhH-Gxf
 langfun/core/eval/v2/metrics_test.py,sha256=p4FzLJsE8XAzAQuyP9hfEf9YeKWZ__PO_ue8a9P0-cc,6082
 langfun/core/eval/v2/progress.py,sha256=azZgssQgNdv3IgjKEaQBuGI5ucFDNbdi02P4z_nQ8GE,10292
 langfun/core/eval/v2/progress_test.py,sha256=YU7VHzmy5knPZwj9vpBN3rQQH2tukj9eKHkuBCI62h8,2540
-langfun/core/eval/v2/progress_tracking.py,sha256=1imwSbllxHWG3zYrzo2NvytBZsVtjqum6bmXGGsvT1E,5987
-langfun/core/eval/v2/progress_tracking_test.py,sha256=eY2HvZeEXDA5Zyfi2m5NDWO_9kSfQsaAOEcIhkSbWCY,1874
+langfun/core/eval/v2/progress_tracking.py,sha256=l9fEkz4oP5McpZzf72Ua7PYm3lAWtRru7gRWNf8H0ms,6083
+langfun/core/eval/v2/progress_tracking_test.py,sha256=iO-DslCJWncU7-27XaMKxDeKrsGbwdk_tKfoRk3KboE,2271
 langfun/core/eval/v2/reporting.py,sha256=TGkli1IDwqfqsCJ_WslOMGk_24JDg7oRRTGXlAJlWpc,4361
 langfun/core/eval/v2/reporting_test.py,sha256=JxffbUPWInUyLjo-AQVFrllga884Mdfm05R86FtxSss,1482
 langfun/core/eval/v2/runners.py,sha256=zJmu-amUiYv1g0Ek4c3mXkBgp-AFvSF7WpXVZCCf7Y4,14245
 langfun/core/eval/v2/runners_test.py,sha256=UeiUNygux_U6iGVG18rhp68ZE4hoWeoT6XsXvSjxNQg,11620
 langfun/core/eval/v2/test_helper.py,sha256=pDpZTBnWRR5xjJv3Uy3NWEzArqlL8FTMOgeR4C53F5M,2348
-langfun/core/llms/__init__.py,sha256=tM9LjVQDFFG7Et2vjUr45GlYNEu73LAcWsd_1DO0KcQ,6460
+langfun/core/llms/__init__.py,sha256=5djybv30-27qaVjZY50IkT66UdvLa-2xfGX-vOfeM0I,6573
 langfun/core/llms/anthropic.py,sha256=uJXVgaFONL8okOSVQ4VGMGht_VZ30m1hoLzmDbIjmks,13990
 langfun/core/llms/anthropic_test.py,sha256=-2U4kc_pgBM7wqxu8RuxzyHPGww1EAWqKUvN4PW8Btw,8058
 langfun/core/llms/compositional.py,sha256=csW_FLlgL-tpeyCOTVvfUQkMa_zCN5Y2I-YbSNuK27U,2872
@@ -92,12 +92,12 @@ langfun/core/llms/groq.py,sha256=dCnR3eAECEKuKKAAj-PDTs8NRHl6CQPdf57m1f6a79U,103
 langfun/core/llms/groq_test.py,sha256=GYF_Qtq5S1H1TrKH38t6_lkdroqT7v-joYLDKnmS9e0,5274
 langfun/core/llms/llama_cpp.py,sha256=9tXQntSCDtjTF3bnyJrAPCr4N6wycy5nXYvp9uduygE,2843
 langfun/core/llms/llama_cpp_test.py,sha256=MWO_qaOeKjRniGjcaWPDScd7HPaIJemqUZoslrt4FPs,1806
-langfun/core/llms/openai.py,sha256=HVJct35hWwer7iq_1ZfTy0F20ZXSB2X-H4CVwStR0rg,23054
+langfun/core/llms/openai.py,sha256=_VwOSuDsyXDngUM2iiES0CW1aN0BzMjXNBMegLzm4J4,23209
 langfun/core/llms/openai_test.py,sha256=_8cd3VRNEUfE0-Ko1RiM6MlC5hjalRj7nYTJNhG1p3E,18907
 langfun/core/llms/rest.py,sha256=sWbYUV8S3SuOg9giq7xwD-xDRfaF7NP_ig7bI52-Rj4,3442
 langfun/core/llms/rest_test.py,sha256=NZ3Nf0XQVpT9kLP5cBVo_yBHLI7vWTYhWQxYEJVMGs4,3472
-langfun/core/llms/vertexai.py,sha256=sl8TIhTuz8ZZFSGo97K6zqXcJQNgF2CVAOLTEiMuiDA,19006
-langfun/core/llms/vertexai_test.py,sha256=I8gEHLRXZZGq_d2VDtJAkAIzf-lNSCoB8y2lwFckY-w,10885
+langfun/core/llms/vertexai.py,sha256=EZhJrdN-SsZVV0KT3NHzaJLVKsNMxCT6M3W6f5fpIWQ,27068
+langfun/core/llms/vertexai_test.py,sha256=nGv59yE4xu1zUxqmP_U941QjSBrr_sW15Q2YakuxMv4,16982
 langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
 langfun/core/llms/cache/base.py,sha256=rt3zwmyw0y9jsSGW-ZbV1vAfLxQ7_3AVk0l2EySlse4,3918
 langfun/core/llms/cache/in_memory.py,sha256=l6b-iU9OTfTRo9Zmg4VrQIuArs4cCJDOpXiEpvNocjo,5004
@@ -148,8 +148,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
 langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
 langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
 langfun/core/templates/selfplay_test.py,sha256=Ot__1P1M8oJfoTp-M9-PQ6HUXqZKyMwvZ5f7yQ3yfyM,2326
-langfun-0.1.2.dev202411270804.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
-langfun-0.1.2.dev202411270804.dist-info/METADATA,sha256=8mjOah3dljZjQz41eKvz-MAQ2Bbn0dsqKCWBm1IXUJM,8890
-langfun-0.1.2.dev202411270804.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
-langfun-0.1.2.dev202411270804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
-langfun-0.1.2.dev202411270804.dist-info/RECORD,,
+langfun-0.1.2.dev202412020805.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
+langfun-0.1.2.dev202412020805.dist-info/METADATA,sha256=c3yjg186RyrDaIHGLMpmXsI7-Kqj4V1vLGxYsjJJN2Y,8890
+langfun-0.1.2.dev202412020805.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+langfun-0.1.2.dev202412020805.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
+langfun-0.1.2.dev202412020805.dist-info/RECORD,,