langfun 0.1.2.dev202501050804__py3-none-any.whl → 0.1.2.dev202501060804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/eval/v2/reporting.py +7 -2
- langfun/core/llms/__init__.py +21 -26
- langfun/core/llms/cache/in_memory.py +6 -0
- langfun/core/llms/cache/in_memory_test.py +5 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +45 -320
- langfun/core/llms/google_genai_test.py +9 -204
- langfun/core/llms/vertexai.py +25 -357
- langfun/core/llms/vertexai_test.py +6 -166
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/METADATA +7 -13
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/RECORD +15 -13
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/WHEEL +1 -1
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/top_level.txt +0 -0
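The bulk of this release is a refactor: shared Gemini REST logic is extracted into the new `langfun/core/llms/gemini.py` module, and both `google_genai.py` and `vertexai.py` shrink to thin subclasses of `gemini.Gemini`, which is why their diffs below are almost all deletions. A minimal sketch of the consolidated usage, inferred from the tests in this diff (everything beyond `api_endpoint` and `model_id`, which the tests exercise directly, should be treated as illustrative):

```python
import os
from langfun.core.llms import google_genai

# Without `api_key` and without the GOOGLE_API_KEY environment variable,
# resolving the endpoint raises: ValueError('Please specify `api_key`').
lm = google_genai.GeminiPro1_5(api_key=os.environ.get('GOOGLE_API_KEY'))
print(lm.model_id)      # Model ids are now prefixed, e.g. 'GenAI(...)'.
print(lm.api_endpoint)  # REST endpoint assembled by the shared gemini.Gemini base.
```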
langfun/core/llms/google_genai_test.py
CHANGED
@@ -11,223 +11,28 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for
+"""Tests for Google GenAI models."""
 
 import os
 import unittest
-from unittest import mock
-
-from google import generativeai as genai
-import langfun.core as lf
-from langfun.core import modalities as lf_modalities
 from langfun.core.llms import google_genai
-import pyglove as pg
-
-
-example_image = (
-    b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
-    b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
-    b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
-    b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
-    b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
-    b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
-    b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
-    b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
-)
-
-
-def mock_get_model(model_name, *args, **kwargs):
-  del args, kwargs
-  if 'gemini' in model_name:
-    method = 'generateContent'
-  elif 'chat' in model_name:
-    method = 'generateMessage'
-  else:
-    method = 'generateText'
-  return pg.Dict(supported_generation_methods=[method])
-
-
-def mock_generate_text(*, model, prompt, **kwargs):
-  return pg.Dict(
-      candidates=[pg.Dict(output=f'{prompt} to {model} with {kwargs}')]
-  )
-
-
-def mock_chat(*, model, messages, **kwargs):
-  return pg.Dict(
-      candidates=[pg.Dict(content=f'{messages} to {model} with {kwargs}')]
-  )
-
-
-def mock_generate_content(content, generation_config, **kwargs):
-  del kwargs
-  c = generation_config
-  return genai.types.GenerateContentResponse(
-      done=True,
-      iterator=None,
-      chunks=[],
-      result=pg.Dict(
-          prompt_feedback=pg.Dict(block_reason=None),
-          candidates=[
-              pg.Dict(
-                  content=pg.Dict(
-                      parts=[
-                          pg.Dict(
-                              text=(
-                                  f'This is a response to {content[0]} with '
-                                  f'n={c.candidate_count}, '
-                                  f'temperature={c.temperature}, '
-                                  f'top_p={c.top_p}, '
-                                  f'top_k={c.top_k}, '
-                                  f'max_tokens={c.max_output_tokens}, '
-                                  f'stop={c.stop_sequences}.'
-                              )
-                          )
-                      ]
-                  ),
-              ),
-          ],
-      ),
-  )
 
 
 class GenAITest(unittest.TestCase):
-  """Tests for
-
-  def test_content_from_message_text_only(self):
-    text = 'This is a beautiful day'
-    model = google_genai.GeminiPro()
-    chunks = model._content_from_message(lf.UserMessage(text))
-    self.assertEqual(chunks, [text])
-
-  def test_content_from_message_mm(self):
-    message = lf.UserMessage(
-        'This is an <<[[image]]>>, what is it?',
-        image=lf_modalities.Image.from_bytes(example_image),
-    )
+  """Tests for GenAI model."""
 
-
-    with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
-      google_genai.GeminiPro()._content_from_message(message)
-
-    model = google_genai.GeminiProVision()
-    chunks = model._content_from_message(message)
-    self.maxDiff = None
-    self.assertEqual(
-        chunks,
-        [
-            'This is an',
-            genai.types.BlobDict(mime_type='image/png', data=example_image),
-            ', what is it?',
-        ],
-    )
-
-  def test_response_to_result_text_only(self):
-    response = genai.types.GenerateContentResponse(
-        done=True,
-        iterator=None,
-        chunks=[],
-        result=pg.Dict(
-            prompt_feedback=pg.Dict(block_reason=None),
-            candidates=[
-                pg.Dict(
-                    content=pg.Dict(
-                        parts=[pg.Dict(text='This is response 1.')]
-                    ),
-                ),
-                pg.Dict(
-                    content=pg.Dict(parts=[pg.Dict(text='This is response 2.')])
-                ),
-            ],
-        ),
-    )
-    model = google_genai.GeminiProVision()
-    result = model._response_to_result(response)
-    self.assertEqual(
-        result,
-        lf.LMSamplingResult([
-            lf.LMSample(lf.AIMessage('This is response 1.'), score=0.0),
-            lf.LMSample(lf.AIMessage('This is response 2.'), score=0.0),
-        ]),
-    )
-
-  def test_model_hub(self):
-    orig_get_model = genai.get_model
-    genai.get_model = mock_get_model
-
-    model = google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
-    self.assertIsNotNone(model)
-    self.assertIs(google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)
-
-    genai.get_model = orig_get_model
-
-  def test_api_key_check(self):
+  def test_basics(self):
     with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
-      _ = google_genai.
+      _ = google_genai.GeminiPro1_5().api_endpoint
+
+    self.assertIsNotNone(google_genai.GeminiPro1_5(api_key='abc').api_endpoint)
 
-    self.assertTrue(google_genai.GeminiPro(api_key='abc')._api_initialized)
     os.environ['GOOGLE_API_KEY'] = 'abc'
-
+    lm = google_genai.GeminiPro1_5()
+    self.assertIsNotNone(lm.api_endpoint)
+    self.assertTrue(lm.model_id.startswith('GenAI('))
     del os.environ['GOOGLE_API_KEY']
 
-  def test_call(self):
-    with mock.patch(
-        'google.generativeai.GenerativeModel.generate_content',
-    ) as mock_generate:
-      orig_get_model = genai.get_model
-      genai.get_model = mock_get_model
-      mock_generate.side_effect = mock_generate_content
-
-      lm = google_genai.GeminiPro(api_key='test_key')
-      self.maxDiff = None
-      self.assertEqual(
-          lm('hello', temperature=2.0, top_k=20, max_tokens=1024).text,
-          (
-              'This is a response to hello with n=1, temperature=2.0, '
-              'top_p=None, top_k=20, max_tokens=1024, stop=None.'
-          ),
-      )
-      genai.get_model = orig_get_model
-
-  def test_call_with_legacy_completion_model(self):
-    orig_get_model = genai.get_model
-    genai.get_model = mock_get_model
-    orig_generate_text = getattr(genai, 'generate_text', None)
-    if orig_generate_text is not None:
-      genai.generate_text = mock_generate_text
-
-    lm = google_genai.Palm2(api_key='test_key')
-    self.maxDiff = None
-    self.assertEqual(
-        lm('hello', temperature=2.0, top_k=20).text,
-        (
-            "hello to models/text-bison-001 with {'temperature': 2.0, "
-            "'top_k': 20, 'top_p': None, 'candidate_count': 1, "
-            "'max_output_tokens': None, 'stop_sequences': None}"
-        ),
-    )
-    genai.generate_text = orig_generate_text
-    genai.get_model = orig_get_model
-
-  def test_call_with_legacy_chat_model(self):
-    orig_get_model = genai.get_model
-    genai.get_model = mock_get_model
-    orig_chat = getattr(genai, 'chat', None)
-    if orig_chat is not None:
-      genai.chat = mock_chat
-
-    lm = google_genai.Palm2_IT(api_key='test_key')
-    self.maxDiff = None
-    self.assertEqual(
-        lm('hello', temperature=2.0, top_k=20).text,
-        (
-            "hello to models/chat-bison-001 with {'temperature': 2.0, "
-            "'top_k': 20, 'top_p': None, 'candidate_count': 1}"
-        ),
-    )
-    genai.chat = orig_chat
-    genai.get_model = orig_get_model
-
 
 if __name__ == '__main__':
   unittest.main()
langfun/core/llms/vertexai.py
CHANGED
@@ -13,14 +13,12 @@
 # limitations under the License.
 """Vertex AI generative models."""
 
-import base64
 import functools
 import os
 from typing import Annotated, Any
 
 import langfun.core as lf
-from langfun.core import
-from langfun.core.llms import rest
+from langfun.core.llms import gemini
 import pyglove as pg
 
 try:
@@ -38,114 +36,11 @@ except ImportError:
   Credentials = Any
 
 
-# https://cloud.google.com/vertex-ai/generative-ai/pricing
-# describes that the average number of characters per token is about 4.
-AVGERAGE_CHARS_PER_TOKEN = 4
-
-
-# Price in US dollars,
-# from https://cloud.google.com/vertex-ai/generative-ai/pricing
-# as of 2024-10-10.
-SUPPORTED_MODELS_AND_SETTINGS = {
-    'gemini-1.5-pro-001': pg.Dict(
-        rpm=100,
-        cost_per_1k_input_chars=0.0003125,
-        cost_per_1k_output_chars=0.00125,
-    ),
-    'gemini-1.5-pro-002': pg.Dict(
-        rpm=100,
-        cost_per_1k_input_chars=0.0003125,
-        cost_per_1k_output_chars=0.00125,
-    ),
-    'gemini-1.5-flash-002': pg.Dict(
-        rpm=500,
-        cost_per_1k_input_chars=0.00001875,
-        cost_per_1k_output_chars=0.000075,
-    ),
-    'gemini-1.5-flash-001': pg.Dict(
-        rpm=500,
-        cost_per_1k_input_chars=0.00001875,
-        cost_per_1k_output_chars=0.000075,
-    ),
-    'gemini-1.5-pro': pg.Dict(
-        rpm=100,
-        cost_per_1k_input_chars=0.0003125,
-        cost_per_1k_output_chars=0.00125,
-    ),
-    'gemini-1.5-flash': pg.Dict(
-        rpm=500,
-        cost_per_1k_input_chars=0.00001875,
-        cost_per_1k_output_chars=0.000075,
-    ),
-    'gemini-1.5-pro-preview-0514': pg.Dict(
-        rpm=50,
-        cost_per_1k_input_chars=0.0003125,
-        cost_per_1k_output_chars=0.00125,
-    ),
-    'gemini-1.5-pro-preview-0409': pg.Dict(
-        rpm=50,
-        cost_per_1k_input_chars=0.0003125,
-        cost_per_1k_output_chars=0.00125,
-    ),
-    'gemini-1.5-flash-preview-0514': pg.Dict(
-        rpm=200,
-        cost_per_1k_input_chars=0.00001875,
-        cost_per_1k_output_chars=0.000075,
-    ),
-    'gemini-1.0-pro': pg.Dict(
-        rpm=300,
-        cost_per_1k_input_chars=0.000125,
-        cost_per_1k_output_chars=0.000375,
-    ),
-    'gemini-1.0-pro-vision': pg.Dict(
-        rpm=100,
-        cost_per_1k_input_chars=0.000125,
-        cost_per_1k_output_chars=0.000375,
-    ),
-    # TODO(sharatsharat): Update costs when published
-    'gemini-exp-1206': pg.Dict(
-        rpm=20,
-        cost_per_1k_input_chars=0.000,
-        cost_per_1k_output_chars=0.000,
-    ),
-    # TODO(sharatsharat): Update costs when published
-    'gemini-2.0-flash-exp': pg.Dict(
-        rpm=10,
-        cost_per_1k_input_chars=0.000,
-        cost_per_1k_output_chars=0.000,
-    ),
-    # TODO(yifenglu): Update costs when published
-    'gemini-2.0-flash-thinking-exp-1219': pg.Dict(
-        rpm=10,
-        cost_per_1k_input_chars=0.000,
-        cost_per_1k_output_chars=0.000,
-    ),
-    # TODO(chengrun): Set a more appropriate rpm for endpoint.
-    'vertexai-endpoint': pg.Dict(
-        rpm=20,
-        cost_per_1k_input_chars=0.0000125,
-        cost_per_1k_output_chars=0.0000375,
-    ),
-}
-
-
 @lf.use_init_args(['model'])
 @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
-class VertexAI(rest.REST):
+class VertexAI(gemini.Gemini):
   """Language model served on VertexAI with REST API."""
 
-  model: pg.typing.Annotated[
-      pg.typing.Enum(
-          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
-      ),
-      (
-          'Vertex AI model name with REST API support. See '
-          'https://cloud.google.com/vertex-ai/generative-ai/docs/'
-          'model-reference/inference#supported-models'
-          ' for details.'
-      ),
-  ]
-
   project: Annotated[
       str | None,
       (
@@ -170,11 +65,6 @@ class VertexAI(rest.REST):
       ),
   ] = None
 
-  supported_modalities: Annotated[
-      list[str],
-      'A list of MIME types for supported modalities'
-  ] = []
-
   def _on_bound(self):
     super()._on_bound()
     if google_auth is None:
@@ -209,31 +99,9 @@ class VertexAI(rest.REST):
     self._credentials = credentials
 
   @property
-  def
-    """Returns
-    return self.
-        requests_per_min=SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm,
-        tokens_per_min=0,
-    )
-
-  def estimate_cost(
-      self,
-      num_input_tokens: int,
-      num_output_tokens: int
-  ) -> float | None:
-    """Estimate the cost based on usage."""
-    cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
-        'cost_per_1k_input_chars', None
-    )
-    cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
-        'cost_per_1k_output_chars', None
-    )
-    if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None:
-      return None
-    return (
-        cost_per_1k_input_chars * num_input_tokens
-        + cost_per_1k_output_chars * num_output_tokens
-    ) * AVGERAGE_CHARS_PER_TOKEN / 1000
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return f'VertexAI({self.model})'
 
   @functools.cached_property
   def _session(self):
@@ -244,12 +112,6 @@ class VertexAI(rest.REST):
     s.headers.update(self.headers or {})
     return s
 
-  @property
-  def headers(self):
-    return {
-        'Content-Type': 'application/json; charset=utf-8',
-    }
-
   @property
   def api_endpoint(self) -> str:
     return (
@@ -258,263 +120,69 @@ class VertexAI(rest.REST):
         f'models/{self.model}:generateContent'
     )
 
-  def request(
-      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
-  ) -> dict[str, Any]:
-    request = dict(
-        generationConfig=self._generation_config(prompt, sampling_options)
-    )
-    request['contents'] = [self._content_from_message(prompt)]
-    return request
-
-  def _generation_config(
-      self, prompt: lf.Message, options: lf.LMSamplingOptions
-  ) -> dict[str, Any]:
-    """Returns a dict as generation config for prompt and LMSamplingOptions."""
-    config = dict(
-        temperature=options.temperature,
-        maxOutputTokens=options.max_tokens,
-        candidateCount=options.n,
-        topK=options.top_k,
-        topP=options.top_p,
-        stopSequences=options.stop,
-        seed=options.random_seed,
-        responseLogprobs=options.logprobs,
-        logprobs=options.top_logprobs,
-    )
 
-
-
-        raise ValueError(
-            f'`json_schema` must be a dict, got {json_schema!r}.'
-        )
-      json_schema = pg.to_json(json_schema)
-      config['responseSchema'] = json_schema
-      config['responseMimeType'] = 'application/json'
-      prompt.metadata.formatted_text = (
-          prompt.text
-          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
-          + pg.to_json_str(json_schema, json_indent=2)
-      )
-    return config
-
-  def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
-    """Gets generation content from langfun message."""
-    parts = []
-    for lf_chunk in prompt.chunk():
-      if isinstance(lf_chunk, str):
-        parts.append({'text': lf_chunk})
-      elif isinstance(lf_chunk, lf_modalities.Mime):
-        try:
-          modalities = lf_chunk.make_compatible(
-              self.supported_modalities + ['text/plain']
-          )
-          if isinstance(modalities, lf_modalities.Mime):
-            modalities = [modalities]
-          for modality in modalities:
-            if modality.is_text:
-              parts.append({'text': modality.to_text()})
-            else:
-              parts.append({
-                  'inlineData': {
-                      'data': base64.b64encode(modality.to_bytes()).decode(),
-                      'mimeType': modality.mime_type,
-                  }
-              })
-        except lf.ModalityError as e:
-          raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
-      else:
-        raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
-    return dict(role='user', parts=parts)
-
-  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
-    messages = [
-        self._message_from_content_parts(candidate['content']['parts'])
-        for candidate in json['candidates']
-    ]
-    usage = json['usageMetadata']
-    input_tokens = usage['promptTokenCount']
-    output_tokens = usage['candidatesTokenCount']
-    return lf.LMSamplingResult(
-        [lf.LMSample(message) for message in messages],
-        usage=lf.LMSamplingUsage(
-            prompt_tokens=input_tokens,
-            completion_tokens=output_tokens,
-            total_tokens=input_tokens + output_tokens,
-            estimated_cost=self.estimate_cost(
-                num_input_tokens=input_tokens,
-                num_output_tokens=output_tokens,
-            ),
-        ),
-    )
+class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAI):  # pylint: disable=invalid-name
+  """Vertex AI Gemini Flash 2.0 Thinking model launched on 12/19/2024."""
 
-
-
-
-
-
-    for part in parts:
-      if text_part := part.get('text'):
-        chunks.append(text_part)
-      else:
-        raise ValueError(f'Unsupported part: {part}')
-    return lf.AIMessage.from_chunks(chunks)
-
-
-IMAGE_TYPES = [
-    'image/png',
-    'image/jpeg',
-    'image/webp',
-    'image/heic',
-    'image/heif',
-]
-
-AUDIO_TYPES = [
-    'audio/aac',
-    'audio/flac',
-    'audio/mp3',
-    'audio/m4a',
-    'audio/mpeg',
-    'audio/mpga',
-    'audio/mp4',
-    'audio/opus',
-    'audio/pcm',
-    'audio/wav',
-    'audio/webm',
-]
-
-VIDEO_TYPES = [
-    'video/mov',
-    'video/mpeg',
-    'video/mpegps',
-    'video/mpg',
-    'video/mp4',
-    'video/webm',
-    'video/wmv',
-    'video/x-flv',
-    'video/3gpp',
-    'video/quicktime',
-]
-
-DOCUMENT_TYPES = [
-    'application/pdf',
-    'text/plain',
-    'text/csv',
-    'text/html',
-    'text/xml',
-    'text/x-script.python',
-    'application/json',
-]
-
-
-class VertexAIGemini2_0(VertexAI):  # pylint: disable=invalid-name
-  """Vertex AI Gemini 2.0 model."""
-
-  supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
-      DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES
-  )
-
-
-class VertexAIGeminiFlash2_0Exp(VertexAIGemini2_0):  # pylint: disable=invalid-name
+  api_version = 'v1alpha'
+  model = 'gemini-2.0-flash-thinking-exp-1219'
+
+
+class VertexAIGeminiFlash2_0Exp(VertexAI):  # pylint: disable=invalid-name
   """Vertex AI Gemini 2.0 Flash model."""
 
   model = 'gemini-2.0-flash-exp'
 
 
-class
-  """Vertex AI Gemini
+class VertexAIGeminiExp_20241206(VertexAI):  # pylint: disable=invalid-name
+  """Vertex AI Gemini Experimental model launched on 12/06/2024."""
 
-  model = 'gemini-
+  model = 'gemini-exp-1206'
 
 
-class
-  """Vertex AI Gemini
+class VertexAIGeminiExp_20241114(VertexAI):  # pylint: disable=invalid-name
+  """Vertex AI Gemini Experimental model launched on 11/14/2024."""
 
-
-      DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES
-  )
+  model = 'gemini-exp-1114'
 
 
-class VertexAIGeminiPro1_5(
+class VertexAIGeminiPro1_5(VertexAI):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""
 
-  model = 'gemini-1.5-pro'
+  model = 'gemini-1.5-pro-latest'
 
 
-class VertexAIGeminiPro1_5_002(
+class VertexAIGeminiPro1_5_002(VertexAI):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""
 
   model = 'gemini-1.5-pro-002'
 
 
-class VertexAIGeminiPro1_5_001(
+class VertexAIGeminiPro1_5_001(VertexAI):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""
 
   model = 'gemini-1.5-pro-001'
 
 
-class
-  """Vertex AI Gemini 1.5 Pro preview model."""
-
-  model = 'gemini-1.5-pro-preview-0514'
-
-
-class VertexAIGeminiPro1_5_0409(VertexAIGemini1_5):  # pylint: disable=invalid-name
-  """Vertex AI Gemini 1.5 Pro preview model."""
-
-  model = 'gemini-1.5-pro-preview-0409'
-
-
-class VertexAIGeminiFlash1_5(VertexAIGemini1_5):  # pylint: disable=invalid-name
+class VertexAIGeminiFlash1_5(VertexAI):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""
 
   model = 'gemini-1.5-flash'
 
 
-class VertexAIGeminiFlash1_5_002(
+class VertexAIGeminiFlash1_5_002(VertexAI):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""
 
   model = 'gemini-1.5-flash-002'
 
 
-class VertexAIGeminiFlash1_5_001(
+class VertexAIGeminiFlash1_5_001(VertexAI):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""
 
   model = 'gemini-1.5-flash-001'
 
 
-class VertexAIGeminiFlash1_5_0514(VertexAIGemini1_5):  # pylint: disable=invalid-name
-  """Vertex AI Gemini 1.5 Flash preview model."""
-
-  model = 'gemini-1.5-flash-preview-0514'
-
-
 class VertexAIGeminiPro1(VertexAI):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.0 Pro model."""
 
   model = 'gemini-1.0-pro'
-
-
-class VertexAIGeminiPro1Vision(VertexAI):  # pylint: disable=invalid-name
-  """Vertex AI Gemini 1.0 Pro Vision model."""
-
-  model = 'gemini-1.0-pro-vision'
-  supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
-      IMAGE_TYPES + VIDEO_TYPES
-  )
-
-
-class VertexAIEndpoint(VertexAI):  # pylint: disable=invalid-name
-  """Vertex AI Endpoint model."""
-
-  model = 'vertexai-endpoint'
-
-  endpoint: Annotated[str, 'Vertex AI Endpoint ID.']
-
-  @property
-  def api_endpoint(self) -> str:
-    return (
-        f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
-        f'{self.project}/locations/{self.location}/'
-        f'endpoints/{self.endpoint}:generateContent'
-    )