langfun 0.1.2.dev202412010804__py3-none-any.whl → 0.1.2.dev202412030000__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/eval/v2/progress_tracking.py +2 -1
- langfun/core/eval/v2/progress_tracking_test.py +10 -0
- langfun/core/llms/__init__.py +2 -0
- langfun/core/llms/openai.py +142 -202
- langfun/core/llms/openai_test.py +160 -224
- langfun/core/llms/vertexai.py +252 -6
- langfun/core/llms/vertexai_test.py +205 -9
- {langfun-0.1.2.dev202412010804.dist-info → langfun-0.1.2.dev202412030000.dist-info}/METADATA +1 -6
- {langfun-0.1.2.dev202412010804.dist-info → langfun-0.1.2.dev202412030000.dist-info}/RECORD +12 -12
- {langfun-0.1.2.dev202412010804.dist-info → langfun-0.1.2.dev202412030000.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202412010804.dist-info → langfun-0.1.2.dev202412030000.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202412010804.dist-info → langfun-0.1.2.dev202412030000.dist-info}/top_level.txt +0 -0
langfun/core/llms/vertexai.py
CHANGED
@@ -13,17 +13,21 @@
 # limitations under the License.
 """Vertex AI generative models."""
 
+import base64
 import functools
 import os
 from typing import Annotated, Any
 
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
 import pyglove as pg
 
 try:
   # pylint: disable=g-import-not-at-top
+  from google import auth as google_auth
   from google.auth import credentials as credentials_lib
+  from google.auth.transport import requests as auth_requests
   import vertexai
   from google.cloud.aiplatform import models as aiplatform_models
   from vertexai import generative_models
@@ -32,6 +36,8 @@ try:
 
   Credentials = credentials_lib.Credentials
 except ImportError:
+  google_auth = None
+  auth_requests = None
   credentials_lib = None  # pylint: disable=invalid-name
   vertexai = None
   generative_models = None
@@ -449,6 +455,238 @@ class VertexAI(lf.LanguageModel):
 ])
 
 
+@lf.use_init_args(['model'])
+@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
+class VertexAIRest(rest.REST):
+  """Language model served on VertexAI with REST API."""
+
+  model: pg.typing.Annotated[
+      pg.typing.Enum(
+          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+      ),
+      (
+          'Vertex AI model name with REST API support. See '
+          'https://cloud.google.com/vertex-ai/generative-ai/docs/'
+          'model-reference/inference#supported-models'
+          ' for details.'
+      ),
+  ]
+
+  project: Annotated[
+      str | None,
+      (
+          'Vertex AI project ID. Or set from environment variable '
+          'VERTEXAI_PROJECT.'
+      ),
+  ] = None
+
+  location: Annotated[
+      str | None,
+      (
+          'Vertex AI service location. Or set from environment variable '
+          'VERTEXAI_LOCATION.'
+      ),
+  ] = None
+
+  credentials: Annotated[
+      Credentials | None,
+      (
+          'Credentials to use. If None, the default credentials to the '
+          'environment will be used.'
+      ),
+  ] = None
+
+  supported_modalities: Annotated[
+      list[str],
+      'A list of MIME types for supported modalities'
+  ] = []
+
+  def _on_bound(self):
+    super()._on_bound()
+    if google_auth is None:
+      raise ValueError(
+          'Please install "langfun[llm-google-vertex]" to use Vertex AI models.'
+      )
+    self._project = None
+    self._credentials = None
+
+  def _initialize(self):
+    project = self.project or os.environ.get('VERTEXAI_PROJECT', None)
+    if not project:
+      raise ValueError(
+          'Please specify `project` during `__init__` or set environment '
+          'variable `VERTEXAI_PROJECT` with your Vertex AI project ID.'
+      )
+
+    location = self.location or os.environ.get('VERTEXAI_LOCATION', None)
+    if not location:
+      raise ValueError(
+          'Please specify `location` during `__init__` or set environment '
+          'variable `VERTEXAI_LOCATION` with your Vertex AI service location.'
+      )
+
+    self._project = project
+    credentials = self.credentials
+    if credentials is None:
+      # Use default credentials.
+      credentials = google_auth.default(
+          scopes=['https://www.googleapis.com/auth/cloud-platform']
+      )
+    self._credentials = credentials
+
+  @property
+  def max_concurrency(self) -> int:
+    """Returns the maximum number of concurrent requests."""
+    return self.rate_to_max_concurrency(
+        requests_per_min=SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm,
+        tokens_per_min=0,
+    )
+
+  def estimate_cost(
+      self,
+      num_input_tokens: int,
+      num_output_tokens: int
+  ) -> float | None:
+    """Estimate the cost based on usage."""
+    cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+        'cost_per_1k_input_chars', None
+    )
+    cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+        'cost_per_1k_output_chars', None
+    )
+    if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None:
+      return None
+    return (
+        cost_per_1k_input_chars * num_input_tokens
+        + cost_per_1k_output_chars * num_output_tokens
+    ) * AVGERAGE_CHARS_PER_TOEKN / 1000
+
+  @functools.cached_property
+  def _session(self):
+    assert self._api_initialized
+    assert self._credentials is not None
+    assert auth_requests is not None
+    s = auth_requests.AuthorizedSession(self._credentials)
+    s.headers.update(self.headers or {})
+    return s
+
+  @property
+  def headers(self):
+    return {
+        'Content-Type': 'application/json; charset=utf-8',
+    }
+
+  @property
+  def api_endpoint(self) -> str:
+    return (
+        f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
+        f'{self.project}/locations/{self.location}/publishers/google/'
+        f'models/{self.model}:generateContent'
+    )
+
+  def request(
+      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    request = dict(
+        generationConfig=self._generation_config(prompt, sampling_options)
+    )
+    request['contents'] = [self._content_from_message(prompt)]
+    return request
+
+  def _generation_config(
+      self, prompt: lf.Message, options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns a dict as generation config for prompt and LMSamplingOptions."""
+    config = dict(
+        temperature=options.temperature,
+        maxOutputTokens=options.max_tokens,
+        candidateCount=options.n,
+        topK=options.top_k,
+        topP=options.top_p,
+        stopSequences=options.stop,
+        seed=options.random_seed,
+        responseLogprobs=options.logprobs,
+        logprobs=options.top_logprobs,
+    )
+
+    if json_schema := prompt.metadata.get('json_schema'):
+      if not isinstance(json_schema, dict):
+        raise ValueError(
+            f'`json_schema` must be a dict, got {json_schema!r}.'
+        )
+      json_schema = pg.to_json(json_schema)
+      config['responseSchema'] = json_schema
+      config['responseMimeType'] = 'application/json'
+      prompt.metadata.formatted_text = (
+          prompt.text
+          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
+          + pg.to_json_str(json_schema, json_indent=2)
+      )
+    return config
+
+  def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
+    """Gets generation content from langfun message."""
+    parts = []
+    for lf_chunk in prompt.chunk():
+      if isinstance(lf_chunk, str):
+        parts.append({'text': lf_chunk})
+      elif isinstance(lf_chunk, lf_modalities.Mime):
+        try:
+          modalities = lf_chunk.make_compatible(
+              self.supported_modalities + ['text/plain']
+          )
+          if isinstance(modalities, lf_modalities.Mime):
+            modalities = [modalities]
+          for modality in modalities:
+            if modality.is_text:
+              parts.append({'text': modality.to_text()})
+            else:
+              parts.append({
+                  'inlineData': {
+                      'data': base64.b64encode(modality.to_bytes()).decode(),
+                      'mimeType': modality.mime_type,
+                  }
+              })
+        except lf.ModalityError as e:
+          raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
+      else:
+        raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
+    return dict(role='user', parts=parts)
+
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    messages = [
+        self._message_from_content_parts(candidate['content']['parts'])
+        for candidate in json['candidates']
+    ]
+    usage = json['usageMetadata']
+    input_tokens = usage['promptTokenCount']
+    output_tokens = usage['candidatesTokenCount']
+    return lf.LMSamplingResult(
+        [lf.LMSample(message) for message in messages],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=input_tokens,
+            completion_tokens=output_tokens,
+            total_tokens=input_tokens + output_tokens,
+            estimated_cost=self.estimate_cost(
+                num_input_tokens=input_tokens,
+                num_output_tokens=output_tokens,
+            ),
+        ),
+    )
+
+  def _message_from_content_parts(
+      self, parts: list[dict[str, Any]]
+  ) -> lf.Message:
+    """Converts Vertex AI's content parts protocol to message."""
+    chunks = []
+    for part in parts:
+      if text_part := part.get('text'):
+        chunks.append(text_part)
+      else:
+        raise ValueError(f'Unsupported part: {part}')
+    return lf.AIMessage.from_chunks(chunks)
+
+
 class _ModelHub:
   """Vertex AI model hub."""
 
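To make the new request path concrete, the sketch below shows the JSON payload that `VertexAIRest.request()` assembles for a plain-text prompt. The `generationConfig` and `contents` keys mirror `_generation_config()` and `_content_from_message()` in the hunk above; the option values are illustrative inputs, not package defaults.

# Hypothetical payload produced by VertexAIRest.request() for a text prompt.
# Keys mirror _generation_config() and _content_from_message() above;
# the option values are example inputs, not langfun defaults.
payload = {
    'generationConfig': {
        'temperature': 0.0,         # options.temperature
        'maxOutputTokens': 1024,    # options.max_tokens
        'candidateCount': 1,        # options.n
        'topK': 40,                 # options.top_k
        'topP': None,               # options.top_p
        'stopSequences': None,      # options.stop
        'seed': None,               # options.random_seed
        'responseLogprobs': False,  # options.logprobs
        'logprobs': None,           # options.top_logprobs
    },
    'contents': [
        {'role': 'user', 'parts': [{'text': 'What is the capital of France?'}]},
    ],
}
# The AuthorizedSession then POSTs this to api_endpoint, i.e.
# https://{location}-aiplatform.googleapis.com/v1/projects/{project}/
#     locations/{location}/publishers/google/models/{model}:generateContent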
@@ -547,13 +785,21 @@ class VertexAIGeminiPro1_5(VertexAIGemini1_5):  # pylint: disable=invalid-name
   model = 'gemini-1.5-pro'
 
 
-class
+class VertexAIRestGemini1_5(VertexAIRest):  # pylint: disable=invalid-name
+  """Vertex AI Gemini 1.5 model with REST API."""
+
+  supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
+      _DOCUMENT_TYPES + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES
+  )
+
+
+class VertexAIGeminiPro1_5_002(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""
 
   model = 'gemini-1.5-pro-002'
 
 
-class VertexAIGeminiPro1_5_001(
+class VertexAIGeminiPro1_5_001(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""
 
   model = 'gemini-1.5-pro-001'
@@ -583,13 +829,13 @@ class VertexAIGeminiFlash1_5(VertexAIGemini1_5):  # pylint: disable=invalid-name
   model = 'gemini-1.5-flash'
 
 
-class VertexAIGeminiFlash1_5_002(
+class VertexAIGeminiFlash1_5_002(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""
 
   model = 'gemini-1.5-flash-002'
 
 
-class VertexAIGeminiFlash1_5_001(
+class VertexAIGeminiFlash1_5_001(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""
 
   model = 'gemini-1.5-flash-001'
@@ -601,14 +847,14 @@ class VertexAIGeminiFlash1_5_0514(VertexAIGemini1_5):  # pylint: disable=invalid
   model = 'gemini-1.5-flash-preview-0514'
 
 
-class VertexAIGeminiPro1(
+class VertexAIGeminiPro1(VertexAIRest):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.0 Pro model."""
 
   model = 'gemini-1.0-pro'
 
 
 class VertexAIGeminiPro1Vision(VertexAI):  # pylint: disable=invalid-name
-  """Vertex AI Gemini 1.0 Pro model."""
+  """Vertex AI Gemini 1.0 Pro Vision model."""
 
   model = 'gemini-1.0-pro-vision'
   supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
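`estimate_cost()` in the new REST class prices by characters rather than tokens: the per-model rates are quoted per 1k characters, so token counts are scaled by the module's average characters-per-token constant. A worked sketch of that arithmetic, using placeholder rates rather than real entries from `SUPPORTED_MODELS_AND_SETTINGS`:

# Sketch of the estimate_cost() arithmetic from the class above.
# The rates and the chars-per-token value are placeholders, not the
# actual entries in SUPPORTED_MODELS_AND_SETTINGS.
avg_chars_per_token = 4
cost_per_1k_input_chars = 0.00125   # placeholder $/1k input characters
cost_per_1k_output_chars = 0.00375  # placeholder $/1k output characters
num_input_tokens, num_output_tokens = 1000, 250

cost = (
    cost_per_1k_input_chars * num_input_tokens
    + cost_per_1k_output_chars * num_output_tokens
) * avg_chars_per_token / 1000
print(f'estimated cost: ${cost:.5f}')  # -> $0.00875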
langfun/core/llms/vertexai_test.py
CHANGED
@@ -14,6 +14,7 @@
 """Tests for Gemini models."""
 
 import os
+from typing import Any
 import unittest
 from unittest import mock
 
@@ -23,6 +24,7 @@ import langfun.core as lf
 from langfun.core import modalities as lf_modalities
 from langfun.core.llms import vertexai
 import pyglove as pg
+import requests
 
 
 example_image = (
@@ -64,6 +66,40 @@ def mock_generate_content(content, generation_config, **kwargs):
 })
 
 
+def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  c = pg.Dict(json['generationConfig'])
+  content = json['contents'][0]['parts'][0]['text']
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'candidates': [
+          {
+              'content': {
+                  'role': 'model',
+                  'parts': [
+                      {
+                          'text': (
+                              f'This is a response to {content} with '
+                              f'temperature={c.temperature}, '
+                              f'top_p={c.topP}, '
+                              f'top_k={c.topK}, '
+                              f'max_tokens={c.maxOutputTokens}, '
+                              f'stop={"".join(c.stopSequences)}.'
+                          )
+                      },
+                  ],
+              },
+          },
+      ],
+      'usageMetadata': {
+          'promptTokenCount': 3,
+          'candidatesTokenCount': 4,
+      }
+  }).encode()
+  return response
+
+
 def mock_endpoint_predict(instances, **kwargs):
   del kwargs
   assert len(instances) == 1
@@ -83,7 +119,7 @@ class VertexAITest(unittest.TestCase):
 
   def test_content_from_message_text_only(self):
     text = 'This is a beautiful day'
-    model = vertexai.
+    model = vertexai.VertexAIGeminiPro1Vision()
     chunks = model._content_from_message(lf.UserMessage(text))
     self.assertEqual(chunks, [text])
 
@@ -95,7 +131,7 @@ class VertexAITest(unittest.TestCase):
 
     # Non-multimodal model.
     with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
-      vertexai.
+      vertexai.VertexAIPalm2()._content_from_message(message)
 
     model = vertexai.VertexAIGeminiPro1Vision()
     chunks = model._content_from_message(message)
@@ -119,7 +155,7 @@ class VertexAITest(unittest.TestCase):
             },
         ],
     })
-    model = vertexai.
+    model = vertexai.VertexAIGeminiPro1Vision()
     message = model._generation_response_to_message(response)
     self.assertEqual(message, lf.AIMessage('hello world'))
 
@@ -158,25 +194,25 @@ class VertexAITest(unittest.TestCase):
 
   def test_project_and_location_check(self):
     with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
-      _ = vertexai.
+      _ = vertexai.VertexAIGeminiPro1Vision()._api_initialized
 
     with self.assertRaisesRegex(ValueError, 'Please specify `location`'):
-      _ = vertexai.
+      _ = vertexai.VertexAIGeminiPro1Vision(project='abc')._api_initialized
 
     self.assertTrue(
-        vertexai.
+        vertexai.VertexAIGeminiPro1Vision(
             project='abc', location='us-central1'
         )._api_initialized
     )
 
     os.environ['VERTEXAI_PROJECT'] = 'abc'
     os.environ['VERTEXAI_LOCATION'] = 'us-central1'
-    self.assertTrue(vertexai.
+    self.assertTrue(vertexai.VertexAIGeminiPro1Vision()._api_initialized)
     del os.environ['VERTEXAI_PROJECT']
     del os.environ['VERTEXAI_LOCATION']
 
   def test_generation_config(self):
-    model = vertexai.
+    model = vertexai.VertexAIGeminiPro1Vision()
     json_schema = {
         'type': 'object',
         'properties': {
@@ -245,7 +281,9 @@ class VertexAITest(unittest.TestCase):
     ) as mock_generate:
       mock_generate.side_effect = mock_generate_content
 
-      lm = vertexai.
+      lm = vertexai.VertexAIGeminiPro1Vision(
+          project='abc', location='us-central1'
+      )
       self.assertEqual(
           lm(
               'hello',
@@ -328,5 +366,163 @@ class VertexAITest(unittest.TestCase):
     )
 
 
+class VertexRestfulAITest(unittest.TestCase):
+  """Tests for Vertex model with REST API."""
+
+  def test_content_from_message_text_only(self):
+    text = 'This is a beautiful day'
+    model = vertexai.VertexAIGeminiPro1_5_002()
+    chunks = model._content_from_message(lf.UserMessage(text))
+    self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]})
+
+  def test_content_from_message_mm(self):
+    message = lf.UserMessage(
+        'This is an <<[[image]]>>, what is it?',
+        image=lf_modalities.Image.from_bytes(example_image),
+    )
+
+    # Non-multimodal model.
+    with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
+      vertexai.VertexAIGeminiPro1()._content_from_message(message)
+
+    model = vertexai.VertexAIGeminiPro1Vision()
+    chunks = model._content_from_message(message)
+    self.maxDiff = None
+    self.assertEqual([chunks[0], chunks[2]], ['This is an', ', what is it?'])
+    self.assertIsInstance(chunks[1], generative_models.Part)
+
+  def test_generation_response_to_message_text_only(self):
+    response = generative_models.GenerationResponse.from_dict({
+        'candidates': [
+            {
+                'index': 0,
+                'content': {
+                    'role': 'model',
+                    'parts': [
+                        {
+                            'text': 'hello world',
+                        },
+                    ],
+                },
+            },
+        ],
+    })
+    model = vertexai.VertexAIGeminiPro1Vision()
+    message = model._generation_response_to_message(response)
+    self.assertEqual(message, lf.AIMessage('hello world'))
+
+  def test_model_hub(self):
+    with mock.patch(
+        'vertexai.generative_models.'
+        'GenerativeModel.__init__'
+    ) as mock_model_init:
+      mock_model_init.side_effect = lambda *args, **kwargs: None
+      model = vertexai._VERTEXAI_MODEL_HUB.get_generative_model(
+          'gemini-1.0-pro'
+      )
+      self.assertIsNotNone(model)
+      self.assertIs(
+          vertexai._VERTEXAI_MODEL_HUB.get_generative_model('gemini-1.0-pro'),
+          model,
+      )
+
+  @mock.patch.object(vertexai.VertexAIRest, 'credentials', new=True)
+  def test_project_and_location_check(self):
+    with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
+      _ = vertexai.VertexAIGeminiPro1()._api_initialized
+
+    with self.assertRaisesRegex(ValueError, 'Please specify `location`'):
+      _ = vertexai.VertexAIGeminiPro1(project='abc')._api_initialized
+
+    self.assertTrue(
+        vertexai.VertexAIGeminiPro1(
+            project='abc', location='us-central1'
+        )._api_initialized
+    )
+
+    os.environ['VERTEXAI_PROJECT'] = 'abc'
+    os.environ['VERTEXAI_LOCATION'] = 'us-central1'
+    self.assertTrue(vertexai.VertexAIGeminiPro1()._api_initialized)
+    del os.environ['VERTEXAI_PROJECT']
+    del os.environ['VERTEXAI_LOCATION']
+
+  def test_generation_config(self):
+    model = vertexai.VertexAIGeminiPro1()
+    json_schema = {
+        'type': 'object',
+        'properties': {
+            'name': {'type': 'string'},
+        },
+        'required': ['name'],
+        'title': 'Person',
+    }
+    actual = model._generation_config(
+        lf.UserMessage('hi', json_schema=json_schema),
+        lf.LMSamplingOptions(
+            temperature=2.0,
+            top_p=1.0,
+            top_k=20,
+            max_tokens=1024,
+            stop=['\n'],
+        ),
+    )
+    self.assertEqual(
+        actual,
+        dict(
+            candidateCount=1,
+            temperature=2.0,
+            topP=1.0,
+            topK=20,
+            maxOutputTokens=1024,
+            stopSequences=['\n'],
+            responseLogprobs=False,
+            logprobs=None,
+            seed=None,
+            responseMimeType='application/json',
+            responseSchema={
+                'type': 'object',
+                'properties': {
+                    'name': {'type': 'string'}
+                },
+                'required': ['name'],
+                'title': 'Person',
+            }
+        ),
+    )
+    with self.assertRaisesRegex(
+        ValueError, '`json_schema` must be a dict, got'
+    ):
+      model._generation_config(
+          lf.UserMessage('hi', json_schema='not a dict'),
+          lf.LMSamplingOptions(),
+      )
+
+  @mock.patch.object(vertexai.VertexAIRest, 'credentials', new=True)
+  def test_call_model(self):
+    with mock.patch('requests.Session.post') as mock_generate:
+      mock_generate.side_effect = mock_requests_post
+
+      lm = vertexai.VertexAIGeminiPro1_5_002(
+          project='abc', location='us-central1'
+      )
+      r = lm(
+          'hello',
+          temperature=2.0,
+          top_p=1.0,
+          top_k=20,
+          max_tokens=1024,
+          stop='\n',
+      )
+      self.assertEqual(
+          r.text,
+          (
+              'This is a response to hello with temperature=2.0, '
+              'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.'
+          ),
+      )
+      self.assertEqual(r.metadata.usage.prompt_tokens, 3)
+      self.assertEqual(r.metadata.usage.completion_tokens, 4)
+
+
 if __name__ == '__main__':
   unittest.main()
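`test_call_model` above exercises the same path an end user would take: construct a REST-backed model class and call it like any other langfun language model. A minimal usage sketch, assuming application-default Google Cloud credentials are available; `my-project` is a placeholder, and `VERTEXAI_PROJECT`/`VERTEXAI_LOCATION` may be set instead of passing arguments:

# Minimal usage sketch for the REST-backed Vertex AI models in this release.
# 'my-project' is a placeholder project ID; real calls need valid credentials.
from langfun.core.llms import vertexai

lm = vertexai.VertexAIGeminiPro1_5_002(
    project='my-project', location='us-central1'
)
r = lm('hello', temperature=0.0, max_tokens=1024)
print(r.text)
print(r.metadata.usage.prompt_tokens, r.metadata.usage.completion_tokens)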
{langfun-0.1.2.dev202412010804.dist-info → langfun-0.1.2.dev202412030000.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: langfun
-Version: 0.1.2.dev202412010804
+Version: 0.1.2.dev202412030000
 Summary: Langfun: Language as Functions.
 Home-page: https://github.com/google/langfun
 Author: Langfun Authors
@@ -32,7 +32,6 @@ Requires-Dist: termcolor==1.1.0; extra == "all"
 Requires-Dist: tqdm>=4.64.1; extra == "all"
 Requires-Dist: google-cloud-aiplatform>=1.5.0; extra == "all"
 Requires-Dist: google-generativeai>=0.3.2; extra == "all"
-Requires-Dist: openai>=0.27.2; extra == "all"
 Requires-Dist: python-magic>=0.4.27; extra == "all"
 Requires-Dist: python-docx>=0.8.11; extra == "all"
 Requires-Dist: pillow>=10.0.0; extra == "all"
@@ -44,7 +43,6 @@ Requires-Dist: tqdm>=4.64.1; extra == "ui"
 Provides-Extra: llm
 Requires-Dist: google-cloud-aiplatform>=1.5.0; extra == "llm"
 Requires-Dist: google-generativeai>=0.3.2; extra == "llm"
-Requires-Dist: openai>=0.27.2; extra == "llm"
 Provides-Extra: llm-google
 Requires-Dist: google-cloud-aiplatform>=1.5.0; extra == "llm-google"
 Requires-Dist: google-generativeai>=0.3.2; extra == "llm-google"
@@ -52,8 +50,6 @@ Provides-Extra: llm-google-vertex
 Requires-Dist: google-cloud-aiplatform>=1.5.0; extra == "llm-google-vertex"
 Provides-Extra: llm-google-genai
 Requires-Dist: google-generativeai>=0.3.2; extra == "llm-google-genai"
-Provides-Extra: llm-openai
-Requires-Dist: openai>=0.27.2; extra == "llm-openai"
 Provides-Extra: mime
 Requires-Dist: python-magic>=0.4.27; extra == "mime"
 Requires-Dist: python-docx>=0.8.11; extra == "mime"
@@ -214,7 +210,6 @@ If you want to customize your installation, you can select specific features usi
 | llm-google | All supported Google-powered LLMs. |
 | llm-google-vertexai | LLMs powered by Google Cloud VertexAI |
 | llm-google-genai | LLMs powered by Google Generative AI API |
-| llm-openai | LLMs powered by OpenAI |
 | mime | All MIME supports. |
 | mime-auto | Automatic MIME type detection. |
 | mime-docx | DocX format support. |
|