langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- langfun/__init__.py +20 -2
- langfun/core/__init__.py +16 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -21
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +63 -2
- langfun/core/component_test.py +53 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +16 -1
- langfun/core/eval/base.py +622 -174
- langfun/core/eval/base_test.py +200 -54
- langfun/core/eval/matching.py +63 -76
- langfun/core/eval/matching_test.py +17 -8
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +26 -26
- langfun/core/eval/scoring_test.py +19 -2
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +4 -17
- langfun/core/langfunc_test.py +22 -6
- langfun/core/language_model.py +577 -39
- langfun/core/language_model_test.py +470 -56
- langfun/core/llms/__init__.py +87 -16
- langfun/core/llms/anthropic.py +312 -87
- langfun/core/llms/anthropic_test.py +71 -3
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +53 -2
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +11 -7
- langfun/core/llms/fake_test.py +14 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -202
- langfun/core/llms/groq.py +160 -144
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +15 -42
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +395 -203
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +30 -395
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -26
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +12 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +7 -6
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +60 -27
- langfun/core/structured/function_generation_test.py +72 -2
- langfun/core/structured/mapping.py +97 -47
- langfun/core/structured/mapping_test.py +90 -2
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +53 -9
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
- langfun/core/structured/schema.py +204 -97
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_test.py +130 -29
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +115 -1
- langfun/core/template_test.py +71 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +10 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -238
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
- langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ langfun/core/llms/gemini_test.py
@@ -0,0 +1,195 @@
+# Copyright 2025 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for Gemini API."""
+
+import base64
+from typing import Any
+import unittest
+from unittest import mock
+
+import langfun.core as lf
+from langfun.core import modalities as lf_modalities
+from langfun.core.llms import gemini
+import pyglove as pg
+import requests
+
+
+example_image = (
+    b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
+    b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
+    b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
+    b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
+    b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
+    b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
+    b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
+    b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
+)
+
+
+def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  c = pg.Dict(json['generationConfig'])
+  content = json['contents'][0]['parts'][0]['text']
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'candidates': [
+          {
+              'content': {
+                  'role': 'model',
+                  'parts': [
+                      {
+                          'text': (
+                              f'This is a response to {content} with '
+                              f'temperature={c.temperature}, '
+                              f'top_p={c.topP}, '
+                              f'top_k={c.topK}, '
+                              f'max_tokens={c.maxOutputTokens}, '
+                              f'stop={"".join(c.stopSequences)}.'
+                          ),
+                      },
+                      {
+                          'text': 'This is the thought.',
+                          'thought': True,
+                      }
+                  ],
+              },
+          },
+      ],
+      'usageMetadata': {
+          'promptTokenCount': 3,
+          'candidatesTokenCount': 4,
+      }
+  }).encode()
+  return response
+
+
+class GeminiTest(unittest.TestCase):
+  """Tests for Vertex model with REST API."""
+
+  def test_content_from_message_text_only(self):
+    text = 'This is a beautiful day'
+    model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+    chunks = model._content_from_message(lf.UserMessage(text))
+    self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]})
+
+  def test_content_from_message_mm(self):
+    image = lf_modalities.Image.from_bytes(example_image)
+    message = lf.UserMessage(
+        'This is an <<[[image]]>>, what is it?', image=image
+    )
+
+    # Non-multimodal model.
+    with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
+      gemini.Gemini(
+          'gemini-1.0-pro', api_endpoint=''
+      )._content_from_message(message)
+
+    model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+    content = model._content_from_message(message)
+    self.assertEqual(
+        content,
+        {
+            'role': 'user',
+            'parts': [
+                {'text': 'This is an'},
+                {
+                    'inlineData': {
+                        'data': base64.b64encode(example_image).decode(),
+                        'mimeType': 'image/png',
+                    }
+                },
+                {'text': ', what is it?'},
+            ],
+        },
+    )
+
+  def test_generation_config(self):
+    model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+    json_schema = {
+        'type': 'object',
+        'properties': {
+            'name': {'type': 'string'},
+        },
+        'required': ['name'],
+        'title': 'Person',
+    }
+    actual = model._generation_config(
+        lf.UserMessage('hi', json_schema=json_schema),
+        lf.LMSamplingOptions(
+            temperature=2.0,
+            top_p=1.0,
+            top_k=20,
+            max_tokens=1024,
+            stop=['\n'],
+        ),
+    )
+    self.assertEqual(
+        actual,
+        dict(
+            candidateCount=1,
+            temperature=2.0,
+            topP=1.0,
+            topK=20,
+            maxOutputTokens=1024,
+            stopSequences=['\n'],
+            responseLogprobs=False,
+            logprobs=None,
+            seed=None,
+            responseMimeType='application/json',
+            responseSchema={
+                'type': 'object',
+                'properties': {
+                    'name': {'type': 'string'}
+                },
+                'required': ['name'],
+                'title': 'Person',
+            }
+        ),
+    )
+    with self.assertRaisesRegex(
+        ValueError, '`json_schema` must be a dict, got'
+    ):
+      model._generation_config(
+          lf.UserMessage('hi', json_schema='not a dict'),
+          lf.LMSamplingOptions(),
+      )
+
+  def test_call_model(self):
+    with mock.patch('requests.Session.post') as mock_generate:
+      mock_generate.side_effect = mock_requests_post
+
+      lm = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+      r = lm(
+          'hello',
+          temperature=2.0,
+          top_p=1.0,
+          top_k=20,
+          max_tokens=1024,
+          stop='\n',
+      )
+      self.assertEqual(
+          r.text,
+          (
+              'This is a response to hello with temperature=2.0, '
+              'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.'
+          ),
+      )
+      self.assertEqual(r.metadata.thought, 'This is the thought.')
+      self.assertEqual(r.metadata.usage.prompt_tokens, 3)
+      self.assertEqual(r.metadata.usage.completion_tokens, 4)
+
+
+if __name__ == '__main__':
+  unittest.main()
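The new gemini_test.py above never touches the network: it patches `requests.Session.post` so the REST-based client is exercised against a canned response. A minimal, self-contained sketch of that pattern (the `fake_post` helper below is hypothetical, standing in for the diff's `mock_requests_post`):

import unittest
from unittest import mock

import requests


def fake_post(url, json=None, **kwargs):
  # Hypothetical stand-in for mock_requests_post: build an in-memory
  # requests.Response instead of issuing an HTTP request.
  del url, json, kwargs
  response = requests.Response()
  response.status_code = 200
  response._content = b'{"candidates": []}'  # Simplified payload.
  return response


class PatchedPostTest(unittest.TestCase):

  def test_offline_post(self):
    # Any code that posts through a requests.Session now receives fake_post's
    # canned response, so assertions run fully offline.
    with mock.patch('requests.Session.post') as mock_post:
      mock_post.side_effect = fake_post
      r = requests.Session().post('https://example.invalid', json={})
      self.assertEqual(r.status_code, 200)
      self.assertEqual(r.json(), {'candidates': []})


if __name__ == '__main__':
  unittest.main()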
--- langfun/core/llms/google_genai.py
+++ langfun/core/llms/google_genai.py
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The Langfun Authors
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,33 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
+"""Language models from Google GenAI."""
 
-import abc
-import functools
 import os
-from typing import Annotated,
+from typing import Annotated, Literal
 
-import google.generativeai as genai
 import langfun.core as lf
-from langfun.core import
+from langfun.core.llms import gemini
 import pyglove as pg
 
 
 @lf.use_init_args(['model'])
-
+@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
+class GenAI(gemini.Gemini):
   """Language models provided by Google GenAI."""
 
-  model: Annotated[
-      Literal[
-          'gemini-pro',
-          'gemini-pro-vision',
-          'text-bison-001',
-          'chat-bison-001',
-      ],
-      'Model name.',
-  ]
-
   api_key: Annotated[
       str | None,
       (
@@ -47,19 +35,18 @@ class GenAI(lf.LanguageModel):
       ),
   ] = None
 
-
-
-
-
-  # Set the default max concurrency to 8 workers.
-  max_concurrency = 8
+  api_version: Annotated[
+      Literal['v1beta', 'v1alpha'],
+      'The API version to use.'
+  ] = 'v1beta'
 
-
-
-
+  @property
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return f'GenAI({self.model})'
 
-  @
-  def
+  @property
+  def api_endpoint(self) -> str:
     api_key = self.api_key or os.environ.get('GOOGLE_API_KEY', None)
     if not api_key:
       raise ValueError(
@@ -69,219 +56,76 @@ class GenAI(lf.LanguageModel):
           'https://cloud.google.com/api-keys/docs/create-manage-api-keys '
           'for more details.'
       )
-
-
-
-
-
-    """Lists generative models."""
-    return [
-        m.name.lstrip('models/')
-        for m in genai.list_models()
-        if (
-            'generateContent' in m.supported_generation_methods
-            or 'generateText' in m.supported_generation_methods
-            or 'generateMessage' in m.supported_generation_methods
-        )
-    ]
+    return (
+        f'https://generativelanguage.googleapis.com/{self.api_version}'
+        f'/models/{self.model}:generateContent?'
+        f'key={api_key}'
+    )
 
-  @property
-  def model_id(self) -> str:
-    """Returns a string to identify the model."""
-    return self.model
 
-
-
-    """Returns a string to identify the resource for rate control."""
-    return self.model_id
-
-  def _generation_config(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
-    """Creates generation config from langfun sampling options."""
-    return genai.GenerationConfig(
-        candidate_count=options.n,
-        temperature=options.temperature,
-        top_p=options.top_p,
-        top_k=options.top_k,
-        max_output_tokens=options.max_tokens,
-        stop_sequences=options.stop,
-    )
+class GeminiFlash2_0ThinkingExp_20241219(GenAI):  # pylint: disable=invalid-name
+  """Gemini Flash 2.0 Thinking model launched on 12/19/2024."""
 
-
-
-
-    """Gets Evergreen formatted content from langfun message."""
-    formatted = lf.UserMessage(prompt.text)
-    formatted.source = prompt
-
-    chunks = []
-    for lf_chunk in formatted.chunk():
-      if isinstance(lf_chunk, str):
-        chunk = lf_chunk
-      elif self.multimodal and isinstance(lf_chunk, lf_modalities.MimeType):
-        chunk = genai.types.BlobDict(
-            data=lf_chunk.to_bytes(), mime_type=lf_chunk.mime_type
-        )
-      else:
-        raise ValueError(f'Unsupported modality: {lf_chunk!r}')
-      chunks.append(chunk)
-    return chunks
-
-  def _response_to_result(
-      self, response: genai.types.GenerateContentResponse | pg.Dict
-  ) -> lf.LMSamplingResult:
-    """Parses generative response into message."""
-    samples = []
-    for candidate in response.candidates:
-      chunks = []
-      for part in candidate.content.parts:
-        # TODO(daiyip): support multi-modal parts when they are available via
-        # Gemini API.
-        if hasattr(part, 'text'):
-          chunks.append(part.text)
-      samples.append(lf.LMSample(lf.AIMessage.from_chunks(chunks), score=0.0))
-    return lf.LMSamplingResult(samples)
-
-  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-    assert self._api_initialized, 'Vertex AI API is not initialized.'
-    return self._parallel_execute_with_currency_control(
-        self._sample_single,
-        prompts,
-    )
+  api_version = 'v1alpha'
+  model = 'gemini-2.0-flash-thinking-exp-1219'
+  timeout = None
 
-  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    """Samples a single prompt."""
-    model = _GOOGLE_GENAI_MODEL_HUB.get(self.model)
-    input_content = self._content_from_message(prompt)
-    response = model.generate_content(
-        input_content,
-        generation_config=self._generation_config(self.sampling_options),
-    )
-    return self._response_to_result(response)
-
-
-class _LegacyGenerativeModel(pg.Object):
-  """Base for legacy GenAI generative model."""
-
-  model: str
-
-  def generate_content(
-      self,
-      input_content: list[str | genai.types.BlobDict],
-      generation_config: genai.GenerationConfig,
-  ) -> pg.Dict:
-    """Generate content."""
-    segments = []
-    for s in input_content:
-      if not isinstance(s, str):
-        raise ValueError(f'Unsupported modality: {s!r}')
-      segments.append(s)
-    return self.generate(' '.join(segments), generation_config)
-
-  @abc.abstractmethod
-  def generate(
-      self, prompt: str, generation_config: genai.GenerationConfig) -> pg.Dict:
-    """Generate response based on prompt."""
-
-
-class _LegacyCompletionModel(_LegacyGenerativeModel):
-  """Legacy GenAI completion model."""
-
-  def generate(
-      self, prompt: str, generation_config: genai.GenerationConfig
-  ) -> pg.Dict:
-    completion: genai.types.Completion = genai.generate_text(
-        model=f'models/{self.model}',
-        prompt=prompt,
-        temperature=generation_config.temperature,
-        top_k=generation_config.top_k,
-        top_p=generation_config.top_p,
-        candidate_count=generation_config.candidate_count,
-        max_output_tokens=generation_config.max_output_tokens,
-        stop_sequences=generation_config.stop_sequences,
-    )
-    return pg.Dict(
-        candidates=[
-            pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['output'])]))
-            for c in completion.candidates
-        ]
-    )
 
+class GeminiFlash2_0Exp(GenAI):  # pylint: disable=invalid-name
+  """Gemini Flash 2.0 model launched on 12/11/2024."""
 
-
-  """Legacy GenAI chat model."""
+  model = 'gemini-2.0-flash-exp'
 
-  def generate(
-      self, prompt: str, generation_config: genai.GenerationConfig
-  ) -> pg.Dict:
-    response: genai.types.ChatResponse = genai.chat(
-        model=f'models/{self.model}',
-        messages=prompt,
-        temperature=generation_config.temperature,
-        top_k=generation_config.top_k,
-        top_p=generation_config.top_p,
-        candidate_count=generation_config.candidate_count,
-    )
-    return pg.Dict(
-        candidates=[
-            pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['content'])]))
-            for c in response.candidates
-        ]
-    )
 
+class GeminiExp_20241206(GenAI):  # pylint: disable=invalid-name
+  """Gemini Experimental model launched on 12/06/2024."""
 
-
-  """Google Generative AI model hub."""
+  model = 'gemini-exp-1206'
 
-  def __init__(self):
-    self._model_cache = {}
 
-
-
-  ) -> genai.GenerativeModel | _LegacyGenerativeModel:
-    """Gets a generative model by model id."""
-    model = self._model_cache.get(model_name, None)
-    if model is None:
-      model_info = genai.get_model(f'models/{model_name}')
-      if 'generateContent' in model_info.supported_generation_methods:
-        model = genai.GenerativeModel(model_name)
-      elif 'generateText' in model_info.supported_generation_methods:
-        model = _LegacyCompletionModel(model_name)
-      elif 'generateMessage' in model_info.supported_generation_methods:
-        model = _LegacyChatModel(model_name)
-      else:
-        raise ValueError(f'Unsupported model: {model_name!r}')
-      self._model_cache[model_name] = model
-    return model
+class GeminiExp_20241114(GenAI):  # pylint: disable=invalid-name
+  """Gemini Experimental model launched on 11/14/2024."""
 
+  model = 'gemini-exp-1114'
 
-_GOOGLE_GENAI_MODEL_HUB = _ModelHub()
 
+class GeminiPro1_5(GenAI):  # pylint: disable=invalid-name
+  """Gemini Pro latest model."""
 
-
-
-
+  model = 'gemini-1.5-pro-latest'
+
+
+class GeminiPro1_5_002(GenAI):  # pylint: disable=invalid-name
+  """Gemini Pro latest model."""
+
+  model = 'gemini-1.5-pro-002'
+
+
+class GeminiPro1_5_001(GenAI):  # pylint: disable=invalid-name
+  """Gemini Pro latest model."""
+
+  model = 'gemini-1.5-pro-001'
 
 
-class
-  """Gemini
+class GeminiFlash1_5(GenAI):  # pylint: disable=invalid-name
+  """Gemini Flash latest model."""
 
-  model = 'gemini-
+  model = 'gemini-1.5-flash-latest'
 
 
-class
-  """Gemini
+class GeminiFlash1_5_002(GenAI):  # pylint: disable=invalid-name
+  """Gemini Flash 1.5 model stable version 002."""
 
-  model = 'gemini-
-  multimodal = True
+  model = 'gemini-1.5-flash-002'
 
 
-class
-  """
+class GeminiFlash1_5_001(GenAI):  # pylint: disable=invalid-name
+  """Gemini Flash 1.5 model stable version 001."""
 
-  model = '
+  model = 'gemini-1.5-flash-001'
 
 
-class
-  """
+class GeminiPro1(GenAI):  # pylint: disable=invalid-name
+  """Gemini 1.0 Pro model."""
 
-  model = '
+  model = 'gemini-1.0-pro'