langfun 0.1.2.dev202411270804__py3-none-any.whl → 0.1.2.dev202412020805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/eval/v2/progress_tracking.py +2 -1
- langfun/core/eval/v2/progress_tracking_test.py +10 -0
- langfun/core/llms/__init__.py +2 -0
- langfun/core/llms/openai.py +6 -1
- langfun/core/llms/vertexai.py +252 -6
- langfun/core/llms/vertexai_test.py +203 -9
- {langfun-0.1.2.dev202411270804.dist-info → langfun-0.1.2.dev202412020805.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202411270804.dist-info → langfun-0.1.2.dev202412020805.dist-info}/RECORD +11 -11
- {langfun-0.1.2.dev202411270804.dist-info → langfun-0.1.2.dev202412020805.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202411270804.dist-info → langfun-0.1.2.dev202412020805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202411270804.dist-info → langfun-0.1.2.dev202412020805.dist-info}/top_level.txt +0 -0
@@ -87,7 +87,8 @@ class _TqdmProgressTracker(experiment_lib.Plugin):
|
|
87
87
|
self._leaf_progresses = {
|
88
88
|
leaf.id: lf.concurrent.ProgressBar.install(
|
89
89
|
label=f'[#{i + 1} - {leaf.id}]',
|
90
|
-
total=
|
90
|
+
total=(len(runner.current_run.example_ids)
|
91
|
+
if runner.current_run.example_ids else leaf.num_examples),
|
91
92
|
color='cyan',
|
92
93
|
status=None
|
93
94
|
)
|
@@ -51,6 +51,16 @@ class TqdmProgressTrackerTest(unittest.TestCase):
|
|
51
51
|
_ = experiment.run(root_dir, 'new', plugins=[])
|
52
52
|
self.assertIn('All: 100%', string_io.getvalue())
|
53
53
|
|
54
|
+
def test_with_example_ids(self):
  """Running on a subset of example IDs still completes the overall bar."""
  output_dir = os.path.join(
      tempfile.gettempdir(), 'test_tqdm_progress_tracker_with_example_ids'
  )
  exp = test_helper.test_experiment()
  captured_stderr = io.StringIO()
  with contextlib.redirect_stderr(captured_stderr):
    # Only example #1 is selected; the tracker must size its bars from
    # `example_ids` rather than from the full example count.
    _ = exp.run(output_dir, 'new', example_ids=[1], plugins=[])
  self.assertIn('All: 100%', captured_stderr.getvalue())
|
63
|
+
|
54
64
|
|
55
65
|
if __name__ == '__main__':
|
56
66
|
unittest.main()
|
langfun/core/llms/__init__.py
CHANGED
@@ -120,6 +120,8 @@ from langfun.core.llms.groq import GroqWhisper_Large_v3
|
|
120
120
|
from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo
|
121
121
|
|
122
122
|
from langfun.core.llms.vertexai import VertexAI
|
123
|
+
from langfun.core.llms.vertexai import VertexAIRest
|
124
|
+
from langfun.core.llms.vertexai import VertexAIRestGemini1_5
|
123
125
|
from langfun.core.llms.vertexai import VertexAIGemini1_5
|
124
126
|
from langfun.core.llms.vertexai import VertexAIGeminiPro1_5
|
125
127
|
from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_Latest
|
langfun/core/llms/openai.py
CHANGED
@@ -530,7 +530,12 @@ class OpenAI(lf.LanguageModel):
|
|
530
530
|
if isinstance(chunk, str):
|
531
531
|
item = dict(type='text', text=chunk)
|
532
532
|
elif isinstance(chunk, lf_modalities.Image):
|
533
|
-
|
533
|
+
if chunk.uri and chunk.uri.lower().startswith(
|
534
|
+
('http:', 'https:', 'ftp:')
|
535
|
+
):
|
536
|
+
uri = chunk.uri
|
537
|
+
else:
|
538
|
+
uri = chunk.content_uri
|
534
539
|
item = dict(type='image_url', image_url=dict(url=uri))
|
535
540
|
else:
|
536
541
|
raise ValueError(f'Unsupported modality object: {chunk!r}.')
|
langfun/core/llms/vertexai.py
CHANGED
@@ -13,17 +13,21 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Vertex AI generative models."""
|
15
15
|
|
16
|
+
import base64
|
16
17
|
import functools
|
17
18
|
import os
|
18
19
|
from typing import Annotated, Any
|
19
20
|
|
20
21
|
import langfun.core as lf
|
21
22
|
from langfun.core import modalities as lf_modalities
|
23
|
+
from langfun.core.llms import rest
|
22
24
|
import pyglove as pg
|
23
25
|
|
24
26
|
try:
|
25
27
|
# pylint: disable=g-import-not-at-top
|
28
|
+
from google import auth as google_auth
|
26
29
|
from google.auth import credentials as credentials_lib
|
30
|
+
from google.auth.transport import requests as auth_requests
|
27
31
|
import vertexai
|
28
32
|
from google.cloud.aiplatform import models as aiplatform_models
|
29
33
|
from vertexai import generative_models
|
@@ -32,6 +36,8 @@ try:
|
|
32
36
|
|
33
37
|
Credentials = credentials_lib.Credentials
|
34
38
|
except ImportError:
|
39
|
+
google_auth = None
|
40
|
+
auth_requests = None
|
35
41
|
credentials_lib = None # pylint: disable=invalid-name
|
36
42
|
vertexai = None
|
37
43
|
generative_models = None
|
@@ -449,6 +455,238 @@ class VertexAI(lf.LanguageModel):
|
|
449
455
|
])
|
450
456
|
|
451
457
|
|
458
|
+
@lf.use_init_args(['model'])
@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
class VertexAIRest(rest.REST):
  """Language model served on VertexAI with REST API.

  Requests target the model's `generateContent` endpoint (`api_endpoint`) and
  are authenticated through a `google.auth` authorized session (`_session`).
  `project` and `location` may be passed at construction time or supplied via
  the `VERTEXAI_PROJECT` / `VERTEXAI_LOCATION` environment variables; both are
  resolved in `_initialize`.
  """

  model: pg.typing.Annotated[
      pg.typing.Enum(
          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
      ),
      (
          'Vertex AI model name with REST API support. See '
          'https://cloud.google.com/vertex-ai/generative-ai/docs/'
          'model-reference/inference#supported-models'
          ' for details.'
      ),
  ]

  project: Annotated[
      str | None,
      (
          'Vertex AI project ID. Or set from environment variable '
          'VERTEXAI_PROJECT.'
      ),
  ] = None

  location: Annotated[
      str | None,
      (
          'Vertex AI service location. Or set from environment variable '
          'VERTEXAI_LOCATION.'
      ),
  ] = None

  credentials: Annotated[
      Credentials | None,
      (
          'Credentials to use. If None, the default credentials to the '
          'environment will be used.'
      ),
  ] = None

  supported_modalities: Annotated[
      list[str],
      'A list of MIME types for supported modalities'
  ] = []

  def _on_bound(self):
    super()._on_bound()
    if google_auth is None:
      raise ValueError(
          'Please install "langfun[llm-google-vertex]" to use Vertex AI models.'
      )
    # Effective settings, filled in by `_initialize`.
    self._project = None
    self._location = None
    self._credentials = None

  def _initialize(self):
    """Resolves project, location and credentials.

    Raises:
      ValueError: If the project or location is given neither at construction
        time nor through the corresponding environment variable.
    """
    project = self.project or os.environ.get('VERTEXAI_PROJECT', None)
    if not project:
      raise ValueError(
          'Please specify `project` during `__init__` or set environment '
          'variable `VERTEXAI_PROJECT` with your Vertex AI project ID.'
      )

    location = self.location or os.environ.get('VERTEXAI_LOCATION', None)
    if not location:
      raise ValueError(
          'Please specify `location` during `__init__` or set environment '
          'variable `VERTEXAI_LOCATION` with your Vertex AI service location.'
      )

    self._project = project
    # Keep the resolved location too, so `api_endpoint` works when it only
    # comes from the environment.
    self._location = location

    credentials = self.credentials
    if credentials is None:
      # Use default credentials. `google.auth.default` returns a
      # `(credentials, project_id)` tuple; only the credentials are needed.
      credentials, _ = google_auth.default(
          scopes=['https://www.googleapis.com/auth/cloud-platform']
      )
    self._credentials = credentials

  @property
  def max_concurrency(self) -> int:
    """Returns the maximum number of concurrent requests."""
    return self.rate_to_max_concurrency(
        requests_per_min=SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm,
        tokens_per_min=0,
    )

  def estimate_cost(
      self,
      num_input_tokens: int,
      num_output_tokens: int
  ) -> float | None:
    """Estimate the cost based on usage.

    Pricing is expressed per 1k characters, so token counts are converted to
    characters with `AVGERAGE_CHARS_PER_TOEKN`.

    Returns:
      The estimated cost, or None if the model has no pricing information.
    """
    cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
        'cost_per_1k_input_chars', None
    )
    cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
        'cost_per_1k_output_chars', None
    )
    if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None:
      return None
    return (
        cost_per_1k_input_chars * num_input_tokens
        + cost_per_1k_output_chars * num_output_tokens
    ) * AVGERAGE_CHARS_PER_TOEKN / 1000

  @functools.cached_property
  def _session(self):
    """Authorized HTTP session, created once per instance after init."""
    assert self._api_initialized
    assert self._credentials is not None
    assert auth_requests is not None
    s = auth_requests.AuthorizedSession(self._credentials)
    s.headers.update(self.headers or {})
    return s

  @property
  def headers(self):
    """HTTP headers sent with every request."""
    return {
        'Content-Type': 'application/json; charset=utf-8',
    }

  @property
  def api_endpoint(self) -> str:
    """Returns the `generateContent` URL for this model.

    Prefers the project/location resolved by `_initialize`, which may come
    from environment variables, and falls back to the constructor values.
    """
    # `getattr` guards against access before `_on_bound` has set them.
    project = getattr(self, '_project', None) or self.project
    location = getattr(self, '_location', None) or self.location
    return (
        f'https://{location}-aiplatform.googleapis.com/v1/projects/'
        f'{project}/locations/{location}/publishers/google/'
        f'models/{self.model}:generateContent'
    )

  def request(
      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
  ) -> dict[str, Any]:
    """Builds the JSON request body for `prompt`."""
    request = dict(
        generationConfig=self._generation_config(prompt, sampling_options)
    )
    request['contents'] = [self._content_from_message(prompt)]
    return request

  def _generation_config(
      self, prompt: lf.Message, options: lf.LMSamplingOptions
  ) -> dict[str, Any]:
    """Returns a dict as generation config for prompt and LMSamplingOptions.

    If the prompt carries a `json_schema` metadata entry, structured JSON
    output is requested, and `prompt.metadata.formatted_text` is updated to
    show the schema.

    Raises:
      ValueError: If `json_schema` metadata is present but is not a dict.
    """
    config = dict(
        temperature=options.temperature,
        maxOutputTokens=options.max_tokens,
        candidateCount=options.n,
        topK=options.top_k,
        topP=options.top_p,
        stopSequences=options.stop,
        seed=options.random_seed,
        responseLogprobs=options.logprobs,
        logprobs=options.top_logprobs,
    )

    if json_schema := prompt.metadata.get('json_schema'):
      if not isinstance(json_schema, dict):
        raise ValueError(
            f'`json_schema` must be a dict, got {json_schema!r}.'
        )
      json_schema = pg.to_json(json_schema)
      config['responseSchema'] = json_schema
      config['responseMimeType'] = 'application/json'
      prompt.metadata.formatted_text = (
          prompt.text
          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
          + pg.to_json_str(json_schema, json_indent=2)
      )
    return config

  def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
    """Gets generation content from langfun message.

    Text chunks become `text` parts. Modality objects are converted to a
    supported MIME type (or plain text) and sent inline as base64 data.

    Raises:
      lf.ModalityError: If a chunk cannot be made compatible with this
        model's supported modalities.
    """
    parts = []
    for lf_chunk in prompt.chunk():
      if isinstance(lf_chunk, str):
        parts.append({'text': lf_chunk})
      elif isinstance(lf_chunk, lf_modalities.Mime):
        try:
          modalities = lf_chunk.make_compatible(
              self.supported_modalities + ['text/plain']
          )
          if isinstance(modalities, lf_modalities.Mime):
            modalities = [modalities]
          for modality in modalities:
            if modality.is_text:
              parts.append({'text': modality.to_text()})
            else:
              parts.append({
                  'inlineData': {
                      'data': base64.b64encode(modality.to_bytes()).decode(),
                      'mimeType': modality.mime_type,
                  }
              })
        except lf.ModalityError as e:
          raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
      else:
        raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
    return dict(role='user', parts=parts)

  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
    """Converts a `generateContent` JSON response into an LMSamplingResult."""
    messages = [
        self._message_from_content_parts(candidate['content']['parts'])
        for candidate in json['candidates']
    ]
    usage = json['usageMetadata']
    # Zero-valued counters may be omitted from the JSON payload (e.g. an empty
    # completion), so default them to 0 rather than failing with KeyError.
    input_tokens = usage.get('promptTokenCount', 0)
    output_tokens = usage.get('candidatesTokenCount', 0)
    return lf.LMSamplingResult(
        [lf.LMSample(message) for message in messages],
        usage=lf.LMSamplingUsage(
            prompt_tokens=input_tokens,
            completion_tokens=output_tokens,
            total_tokens=input_tokens + output_tokens,
            estimated_cost=self.estimate_cost(
                num_input_tokens=input_tokens,
                num_output_tokens=output_tokens,
            ),
        ),
    )

  def _message_from_content_parts(
      self, parts: list[dict[str, Any]]
  ) -> lf.Message:
    """Converts Vertex AI's content parts protocol to message.

    Raises:
      ValueError: If a part has no `text` field (non-text parts are not
        supported yet).
    """
    chunks = []
    for part in parts:
      text_part = part.get('text')
      # Compare against None: an empty string is still a valid text part.
      if text_part is None:
        raise ValueError(f'Unsupported part: {part}')
      chunks.append(text_part)
    return lf.AIMessage.from_chunks(chunks)
|
688
|
+
|
689
|
+
|
452
690
|
class _ModelHub:
|
453
691
|
"""Vertex AI model hub."""
|
454
692
|
|
@@ -547,13 +785,21 @@ class VertexAIGeminiPro1_5(VertexAIGemini1_5): # pylint: disable=invalid-name
|
|
547
785
|
model = 'gemini-1.5-pro'
|
548
786
|
|
549
787
|
|
550
|
-
class
|
788
|
+
class VertexAIRestGemini1_5(VertexAIRest):  # pylint: disable=invalid-name
  """Gemini 1.5 model family on Vertex AI, accessed through the REST API.

  Besides plain text, document, image, audio and video inputs are accepted.
  """

  supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
      _DOCUMENT_TYPES
      + _IMAGE_TYPES
      + _AUDIO_TYPES
      + _VIDEO_TYPES
  )
|
794
|
+
|
795
|
+
|
796
|
+
class VertexAIGeminiPro1_5_002(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
  """Vertex AI Gemini 1.5 Pro (version 002) model, served via the REST API."""

  model = 'gemini-1.5-pro-002'
|
554
800
|
|
555
801
|
|
556
|
-
class VertexAIGeminiPro1_5_001(
|
802
|
+
class VertexAIGeminiPro1_5_001(VertexAIRestGemini1_5): # pylint: disable=invalid-name
|
557
803
|
"""Vertex AI Gemini 1.5 Pro model."""
|
558
804
|
|
559
805
|
model = 'gemini-1.5-pro-001'
|
@@ -583,13 +829,13 @@ class VertexAIGeminiFlash1_5(VertexAIGemini1_5): # pylint: disable=invalid-name
|
|
583
829
|
model = 'gemini-1.5-flash'
|
584
830
|
|
585
831
|
|
586
|
-
class VertexAIGeminiFlash1_5_002(
|
832
|
+
class VertexAIGeminiFlash1_5_002(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
  """Vertex AI Gemini 1.5 Flash (version 002) model, served via the REST API."""

  model = 'gemini-1.5-flash-002'
|
590
836
|
|
591
837
|
|
592
|
-
class VertexAIGeminiFlash1_5_001(
|
838
|
+
class VertexAIGeminiFlash1_5_001(VertexAIRestGemini1_5): # pylint: disable=invalid-name
|
593
839
|
"""Vertex AI Gemini 1.5 Flash model."""
|
594
840
|
|
595
841
|
model = 'gemini-1.5-flash-001'
|
@@ -601,14 +847,14 @@ class VertexAIGeminiFlash1_5_0514(VertexAIGemini1_5): # pylint: disable=invalid
|
|
601
847
|
model = 'gemini-1.5-flash-preview-0514'
|
602
848
|
|
603
849
|
|
604
|
-
class VertexAIGeminiPro1(
|
850
|
+
class VertexAIGeminiPro1(VertexAIRest):  # pylint: disable=invalid-name
  """Vertex AI Gemini 1.0 Pro model, served via the REST API.

  The empty `supported_modalities` default is inherited, so only text input
  is accepted.
  """

  model = 'gemini-1.0-pro'
|
608
854
|
|
609
855
|
|
610
856
|
class VertexAIGeminiPro1Vision(VertexAI): # pylint: disable=invalid-name
|
611
|
-
"""Vertex AI Gemini 1.0 Pro model."""
|
857
|
+
"""Vertex AI Gemini 1.0 Pro Vision model."""
|
612
858
|
|
613
859
|
model = 'gemini-1.0-pro-vision'
|
614
860
|
supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation
|
@@ -14,6 +14,7 @@
|
|
14
14
|
"""Tests for Gemini models."""
|
15
15
|
|
16
16
|
import os
|
17
|
+
from typing import Any
|
17
18
|
import unittest
|
18
19
|
from unittest import mock
|
19
20
|
|
@@ -23,6 +24,7 @@ import langfun.core as lf
|
|
23
24
|
from langfun.core import modalities as lf_modalities
|
24
25
|
from langfun.core.llms import vertexai
|
25
26
|
import pyglove as pg
|
27
|
+
import requests
|
26
28
|
|
27
29
|
|
28
30
|
example_image = (
|
@@ -64,6 +66,40 @@ def mock_generate_content(content, generation_config, **kwargs):
|
|
64
66
|
})
|
65
67
|
|
66
68
|
|
69
|
+
def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
  """Fake `requests.Session.post` returning a canned `generateContent` reply.

  The reply text echoes the first text part of the request together with the
  sampling options from its `generationConfig`, so tests can check that both
  were serialized into the request.
  """
  del url, kwargs
  gen_config = pg.Dict(json['generationConfig'])
  prompt_text = json['contents'][0]['parts'][0]['text']
  reply_text = (
      f'This is a response to {prompt_text} with '
      f'temperature={gen_config.temperature}, '
      f'top_p={gen_config.topP}, '
      f'top_k={gen_config.topK}, '
      f'max_tokens={gen_config.maxOutputTokens}, '
      f'stop={"".join(gen_config.stopSequences)}.'
  )
  payload = {
      'candidates': [
          {'content': {'role': 'model', 'parts': [{'text': reply_text}]}},
      ],
      'usageMetadata': {'promptTokenCount': 3, 'candidatesTokenCount': 4},
  }
  fake_response = requests.Response()
  fake_response.status_code = 200
  fake_response._content = pg.to_json_str(payload).encode()  # pylint: disable=protected-access
  return fake_response
|
101
|
+
|
102
|
+
|
67
103
|
def mock_endpoint_predict(instances, **kwargs):
|
68
104
|
del kwargs
|
69
105
|
assert len(instances) == 1
|
@@ -83,7 +119,7 @@ class VertexAITest(unittest.TestCase):
|
|
83
119
|
|
84
120
|
def test_content_from_message_text_only(self):
|
85
121
|
text = 'This is a beautiful day'
|
86
|
-
model = vertexai.
|
122
|
+
model = vertexai.VertexAIGeminiPro1Vision()
|
87
123
|
chunks = model._content_from_message(lf.UserMessage(text))
|
88
124
|
self.assertEqual(chunks, [text])
|
89
125
|
|
@@ -95,7 +131,7 @@ class VertexAITest(unittest.TestCase):
|
|
95
131
|
|
96
132
|
# Non-multimodal model.
|
97
133
|
with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
|
98
|
-
vertexai.
|
134
|
+
vertexai.VertexAIPalm2()._content_from_message(message)
|
99
135
|
|
100
136
|
model = vertexai.VertexAIGeminiPro1Vision()
|
101
137
|
chunks = model._content_from_message(message)
|
@@ -119,7 +155,7 @@ class VertexAITest(unittest.TestCase):
|
|
119
155
|
},
|
120
156
|
],
|
121
157
|
})
|
122
|
-
model = vertexai.
|
158
|
+
model = vertexai.VertexAIGeminiPro1Vision()
|
123
159
|
message = model._generation_response_to_message(response)
|
124
160
|
self.assertEqual(message, lf.AIMessage('hello world'))
|
125
161
|
|
@@ -158,25 +194,25 @@ class VertexAITest(unittest.TestCase):
|
|
158
194
|
|
159
195
|
def test_project_and_location_check(self):
|
160
196
|
with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
|
161
|
-
_ = vertexai.
|
197
|
+
_ = vertexai.VertexAIGeminiPro1Vision()._api_initialized
|
162
198
|
|
163
199
|
with self.assertRaisesRegex(ValueError, 'Please specify `location`'):
|
164
|
-
_ = vertexai.
|
200
|
+
_ = vertexai.VertexAIGeminiPro1Vision(project='abc')._api_initialized
|
165
201
|
|
166
202
|
self.assertTrue(
|
167
|
-
vertexai.
|
203
|
+
vertexai.VertexAIGeminiPro1Vision(
|
168
204
|
project='abc', location='us-central1'
|
169
205
|
)._api_initialized
|
170
206
|
)
|
171
207
|
|
172
208
|
os.environ['VERTEXAI_PROJECT'] = 'abc'
|
173
209
|
os.environ['VERTEXAI_LOCATION'] = 'us-central1'
|
174
|
-
self.assertTrue(vertexai.
|
210
|
+
self.assertTrue(vertexai.VertexAIGeminiPro1Vision()._api_initialized)
|
175
211
|
del os.environ['VERTEXAI_PROJECT']
|
176
212
|
del os.environ['VERTEXAI_LOCATION']
|
177
213
|
|
178
214
|
def test_generation_config(self):
|
179
|
-
model = vertexai.
|
215
|
+
model = vertexai.VertexAIGeminiPro1Vision()
|
180
216
|
json_schema = {
|
181
217
|
'type': 'object',
|
182
218
|
'properties': {
|
@@ -245,7 +281,9 @@ class VertexAITest(unittest.TestCase):
|
|
245
281
|
) as mock_generate:
|
246
282
|
mock_generate.side_effect = mock_generate_content
|
247
283
|
|
248
|
-
lm = vertexai.
|
284
|
+
lm = vertexai.VertexAIGeminiPro1Vision(
|
285
|
+
project='abc', location='us-central1'
|
286
|
+
)
|
249
287
|
self.assertEqual(
|
250
288
|
lm(
|
251
289
|
'hello',
|
@@ -328,5 +366,161 @@ class VertexAITest(unittest.TestCase):
|
|
328
366
|
)
|
329
367
|
|
330
368
|
|
369
|
+
class VertexRestfulAITest(unittest.TestCase):
  """Tests for Vertex model with REST API."""

  def test_content_from_message_text_only(self):
    text = 'This is a beautiful day'
    model = vertexai.VertexAIGeminiPro1_5_002()
    chunks = model._content_from_message(lf.UserMessage(text))
    self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]})

  def test_content_from_message_mm(self):
    import base64  # pylint: disable=g-import-not-at-top

    image = lf_modalities.Image.from_bytes(example_image)
    message = lf.UserMessage(
        'This is an <<[[image]]>>, what is it?',
        image=image,
    )

    # Non-multimodal REST model.
    with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
      vertexai.VertexAIGeminiPro1()._content_from_message(message)

    # Multimodal REST model: the image is sent inline as base64 data between
    # the surrounding text parts.
    model = vertexai.VertexAIGeminiPro1_5_002()
    content = model._content_from_message(message)
    self.maxDiff = None
    self.assertEqual(
        content,
        {
            'role': 'user',
            'parts': [
                {'text': 'This is an'},
                {
                    'inlineData': {
                        'data': base64.b64encode(example_image).decode(),
                        'mimeType': image.mime_type,
                    }
                },
                {'text': ', what is it?'},
            ],
        },
    )

  def test_message_from_content_parts_text_only(self):
    model = vertexai.VertexAIGeminiPro1_5_002()
    message = model._message_from_content_parts([{'text': 'hello world'}])
    self.assertEqual(message, lf.AIMessage('hello world'))

    with self.assertRaisesRegex(ValueError, 'Unsupported part'):
      model._message_from_content_parts([{'functionCall': {'name': 'f'}}])

  def test_model_hub(self):
    with mock.patch(
        'vertexai.generative_models.'
        'GenerativeModel.__init__'
    ) as mock_model_init:
      mock_model_init.side_effect = lambda *args, **kwargs: None
      model = vertexai._VERTEXAI_MODEL_HUB.get_generative_model(
          'gemini-1.0-pro'
      )
      self.assertIsNotNone(model)
      self.assertIs(
          vertexai._VERTEXAI_MODEL_HUB.get_generative_model('gemini-1.0-pro'),
          model,
      )

  def test_project_and_location_check(self):
    # `patch.dict` restores the original environment even if an assertion
    # fails, and ambient settings are removed so they cannot mask the checks.
    with mock.patch.dict(os.environ):
      os.environ.pop('VERTEXAI_PROJECT', None)
      os.environ.pop('VERTEXAI_LOCATION', None)

      with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
        _ = vertexai.VertexAIGeminiPro1()._api_initialized

      with self.assertRaisesRegex(ValueError, 'Please specify `location`'):
        _ = vertexai.VertexAIGeminiPro1(project='abc')._api_initialized

      self.assertTrue(
          vertexai.VertexAIGeminiPro1(
              project='abc', location='us-central1'
          )._api_initialized
      )

      os.environ['VERTEXAI_PROJECT'] = 'abc'
      os.environ['VERTEXAI_LOCATION'] = 'us-central1'
      self.assertTrue(vertexai.VertexAIGeminiPro1()._api_initialized)

  def test_generation_config(self):
    model = vertexai.VertexAIGeminiPro1()
    json_schema = {
        'type': 'object',
        'properties': {
            'name': {'type': 'string'},
        },
        'required': ['name'],
        'title': 'Person',
    }
    actual = model._generation_config(
        lf.UserMessage('hi', json_schema=json_schema),
        lf.LMSamplingOptions(
            temperature=2.0,
            top_p=1.0,
            top_k=20,
            max_tokens=1024,
            stop=['\n'],
        ),
    )
    self.assertEqual(
        actual,
        dict(
            candidateCount=1,
            temperature=2.0,
            topP=1.0,
            topK=20,
            maxOutputTokens=1024,
            stopSequences=['\n'],
            responseLogprobs=False,
            logprobs=None,
            seed=None,
            responseMimeType='application/json',
            responseSchema={
                'type': 'object',
                'properties': {
                    'name': {'type': 'string'}
                },
                'required': ['name'],
                'title': 'Person',
            }
        ),
    )
    with self.assertRaisesRegex(
        ValueError, '`json_schema` must be a dict, got'
    ):
      model._generation_config(
          lf.UserMessage('hi', json_schema='not a dict'),
          lf.LMSamplingOptions(),
      )

  def test_call_model(self):
    with mock.patch('requests.Session.post') as mock_generate:
      mock_generate.side_effect = mock_requests_post

      lm = vertexai.VertexAIGeminiPro1_5_002(
          project='abc', location='us-central1'
      )
      r = lm(
          'hello',
          temperature=2.0,
          top_p=1.0,
          top_k=20,
          max_tokens=1024,
          stop='\n',
      )
      self.assertEqual(
          r.text,
          (
              'This is a response to hello with temperature=2.0, '
              'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.'
          ),
      )
      self.assertEqual(r.metadata.usage.prompt_tokens, 3)
      self.assertEqual(r.metadata.usage.completion_tokens, 4)
|
523
|
+
|
524
|
+
|
331
525
|
if __name__ == '__main__':
|
332
526
|
unittest.main()
|
@@ -72,14 +72,14 @@ langfun/core/eval/v2/metrics.py,sha256=bl8i6u-ZHRBz4hAc3LzsZ2Dc7ZRQcuTYeUhhH-Gxf
|
|
72
72
|
langfun/core/eval/v2/metrics_test.py,sha256=p4FzLJsE8XAzAQuyP9hfEf9YeKWZ__PO_ue8a9P0-cc,6082
|
73
73
|
langfun/core/eval/v2/progress.py,sha256=azZgssQgNdv3IgjKEaQBuGI5ucFDNbdi02P4z_nQ8GE,10292
|
74
74
|
langfun/core/eval/v2/progress_test.py,sha256=YU7VHzmy5knPZwj9vpBN3rQQH2tukj9eKHkuBCI62h8,2540
|
75
|
-
langfun/core/eval/v2/progress_tracking.py,sha256=
|
76
|
-
langfun/core/eval/v2/progress_tracking_test.py,sha256=
|
75
|
+
langfun/core/eval/v2/progress_tracking.py,sha256=l9fEkz4oP5McpZzf72Ua7PYm3lAWtRru7gRWNf8H0ms,6083
|
76
|
+
langfun/core/eval/v2/progress_tracking_test.py,sha256=iO-DslCJWncU7-27XaMKxDeKrsGbwdk_tKfoRk3KboE,2271
|
77
77
|
langfun/core/eval/v2/reporting.py,sha256=TGkli1IDwqfqsCJ_WslOMGk_24JDg7oRRTGXlAJlWpc,4361
|
78
78
|
langfun/core/eval/v2/reporting_test.py,sha256=JxffbUPWInUyLjo-AQVFrllga884Mdfm05R86FtxSss,1482
|
79
79
|
langfun/core/eval/v2/runners.py,sha256=zJmu-amUiYv1g0Ek4c3mXkBgp-AFvSF7WpXVZCCf7Y4,14245
|
80
80
|
langfun/core/eval/v2/runners_test.py,sha256=UeiUNygux_U6iGVG18rhp68ZE4hoWeoT6XsXvSjxNQg,11620
|
81
81
|
langfun/core/eval/v2/test_helper.py,sha256=pDpZTBnWRR5xjJv3Uy3NWEzArqlL8FTMOgeR4C53F5M,2348
|
82
|
-
langfun/core/llms/__init__.py,sha256=
|
82
|
+
langfun/core/llms/__init__.py,sha256=5djybv30-27qaVjZY50IkT66UdvLa-2xfGX-vOfeM0I,6573
|
83
83
|
langfun/core/llms/anthropic.py,sha256=uJXVgaFONL8okOSVQ4VGMGht_VZ30m1hoLzmDbIjmks,13990
|
84
84
|
langfun/core/llms/anthropic_test.py,sha256=-2U4kc_pgBM7wqxu8RuxzyHPGww1EAWqKUvN4PW8Btw,8058
|
85
85
|
langfun/core/llms/compositional.py,sha256=csW_FLlgL-tpeyCOTVvfUQkMa_zCN5Y2I-YbSNuK27U,2872
|
@@ -92,12 +92,12 @@ langfun/core/llms/groq.py,sha256=dCnR3eAECEKuKKAAj-PDTs8NRHl6CQPdf57m1f6a79U,103
|
|
92
92
|
langfun/core/llms/groq_test.py,sha256=GYF_Qtq5S1H1TrKH38t6_lkdroqT7v-joYLDKnmS9e0,5274
|
93
93
|
langfun/core/llms/llama_cpp.py,sha256=9tXQntSCDtjTF3bnyJrAPCr4N6wycy5nXYvp9uduygE,2843
|
94
94
|
langfun/core/llms/llama_cpp_test.py,sha256=MWO_qaOeKjRniGjcaWPDScd7HPaIJemqUZoslrt4FPs,1806
|
95
|
-
langfun/core/llms/openai.py,sha256=
|
95
|
+
langfun/core/llms/openai.py,sha256=_VwOSuDsyXDngUM2iiES0CW1aN0BzMjXNBMegLzm4J4,23209
|
96
96
|
langfun/core/llms/openai_test.py,sha256=_8cd3VRNEUfE0-Ko1RiM6MlC5hjalRj7nYTJNhG1p3E,18907
|
97
97
|
langfun/core/llms/rest.py,sha256=sWbYUV8S3SuOg9giq7xwD-xDRfaF7NP_ig7bI52-Rj4,3442
|
98
98
|
langfun/core/llms/rest_test.py,sha256=NZ3Nf0XQVpT9kLP5cBVo_yBHLI7vWTYhWQxYEJVMGs4,3472
|
99
|
-
langfun/core/llms/vertexai.py,sha256=
|
100
|
-
langfun/core/llms/vertexai_test.py,sha256=
|
99
|
+
langfun/core/llms/vertexai.py,sha256=EZhJrdN-SsZVV0KT3NHzaJLVKsNMxCT6M3W6f5fpIWQ,27068
|
100
|
+
langfun/core/llms/vertexai_test.py,sha256=nGv59yE4xu1zUxqmP_U941QjSBrr_sW15Q2YakuxMv4,16982
|
101
101
|
langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
|
102
102
|
langfun/core/llms/cache/base.py,sha256=rt3zwmyw0y9jsSGW-ZbV1vAfLxQ7_3AVk0l2EySlse4,3918
|
103
103
|
langfun/core/llms/cache/in_memory.py,sha256=l6b-iU9OTfTRo9Zmg4VrQIuArs4cCJDOpXiEpvNocjo,5004
|
@@ -148,8 +148,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
|
|
148
148
|
langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
|
149
149
|
langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
|
150
150
|
langfun/core/templates/selfplay_test.py,sha256=Ot__1P1M8oJfoTp-M9-PQ6HUXqZKyMwvZ5f7yQ3yfyM,2326
|
151
|
-
langfun-0.1.2.
|
152
|
-
langfun-0.1.2.
|
153
|
-
langfun-0.1.2.
|
154
|
-
langfun-0.1.2.
|
155
|
-
langfun-0.1.2.
|
151
|
+
langfun-0.1.2.dev202412020805.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
152
|
+
langfun-0.1.2.dev202412020805.dist-info/METADATA,sha256=c3yjg186RyrDaIHGLMpmXsI7-Kqj4V1vLGxYsjJJN2Y,8890
|
153
|
+
langfun-0.1.2.dev202412020805.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
154
|
+
langfun-0.1.2.dev202412020805.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
155
|
+
langfun-0.1.2.dev202412020805.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{langfun-0.1.2.dev202411270804.dist-info → langfun-0.1.2.dev202412020805.dist-info}/top_level.txt
RENAMED
File without changes
|