langfun 0.1.1.dev20240812__py3-none-any.whl → 0.1.1.dev20240813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/llms/vertexai.py +29 -9
- langfun/core/llms/vertexai_test.py +73 -22
- {langfun-0.1.1.dev20240812.dist-info → langfun-0.1.1.dev20240813.dist-info}/METADATA +1 -1
- {langfun-0.1.1.dev20240812.dist-info → langfun-0.1.1.dev20240813.dist-info}/RECORD +7 -7
- {langfun-0.1.1.dev20240812.dist-info → langfun-0.1.1.dev20240813.dist-info}/LICENSE +0 -0
- {langfun-0.1.1.dev20240812.dist-info → langfun-0.1.1.dev20240813.dist-info}/WHEEL +0 -0
- {langfun-0.1.1.dev20240812.dist-info → langfun-0.1.1.dev20240813.dist-info}/top_level.txt +0 -0
langfun/core/llms/vertexai.py
CHANGED
@@ -18,18 +18,17 @@ import os
|
|
18
18
|
from typing import Annotated, Any
|
19
19
|
|
20
20
|
from google.auth import credentials as credentials_lib
|
21
|
-
from google.cloud.aiplatform import aiplatform
|
22
21
|
import langfun.core as lf
|
23
22
|
from langfun.core import modalities as lf_modalities
|
24
23
|
import pyglove as pg
|
25
24
|
|
26
25
|
|
27
26
|
SUPPORTED_MODELS_AND_SETTINGS = {
|
28
|
-
'gemini-1.5-pro-001': pg.Dict(api='gemini', rpm=
|
29
|
-
'gemini-1.5-flash-001': pg.Dict(api='gemini', rpm=
|
30
|
-
'gemini-1.5-pro-preview-0514': pg.Dict(api='gemini', rpm=
|
31
|
-
'gemini-1.5-pro-preview-0409': pg.Dict(api='gemini', rpm=
|
32
|
-
'gemini-1.5-flash-preview-0514': pg.Dict(api='gemini', rpm=
|
27
|
+
'gemini-1.5-pro-001': pg.Dict(api='gemini', rpm=50),
|
28
|
+
'gemini-1.5-flash-001': pg.Dict(api='gemini', rpm=200),
|
29
|
+
'gemini-1.5-pro-preview-0514': pg.Dict(api='gemini', rpm=50),
|
30
|
+
'gemini-1.5-pro-preview-0409': pg.Dict(api='gemini', rpm=50),
|
31
|
+
'gemini-1.5-flash-preview-0514': pg.Dict(api='gemini', rpm=200),
|
33
32
|
'gemini-1.0-pro': pg.Dict(api='gemini', rpm=300),
|
34
33
|
'gemini-1.0-pro-vision': pg.Dict(api='gemini', rpm=100),
|
35
34
|
# PaLM APIs.
|
@@ -136,16 +135,34 @@ class VertexAI(lf.LanguageModel):
|
|
136
135
|
)
|
137
136
|
|
138
137
|
def _generation_config(
|
139
|
-
self, options: lf.LMSamplingOptions
|
138
|
+
self, prompt: lf.Message, options: lf.LMSamplingOptions
|
140
139
|
) -> Any: # generative_models.GenerationConfig
|
141
140
|
"""Creates generation config from langfun sampling options."""
|
142
141
|
from vertexai import generative_models
|
142
|
+
# Users could use `metadata_json_schema` to pass additional
|
143
|
+
# request arguments.
|
144
|
+
json_schema = prompt.metadata.get('json_schema')
|
145
|
+
response_mime_type = None
|
146
|
+
if json_schema is not None:
|
147
|
+
if not isinstance(json_schema, dict):
|
148
|
+
raise ValueError(
|
149
|
+
f'`json_schema` must be a dict, got {json_schema!r}.'
|
150
|
+
)
|
151
|
+
response_mime_type = 'application/json'
|
152
|
+
prompt.metadata.formatted_text = (
|
153
|
+
prompt.text
|
154
|
+
+ '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
|
155
|
+
+ pg.to_json_str(json_schema, json_indent=2)
|
156
|
+
)
|
157
|
+
|
143
158
|
return generative_models.GenerationConfig(
|
144
159
|
temperature=options.temperature,
|
145
160
|
top_p=options.top_p,
|
146
161
|
top_k=options.top_k,
|
147
162
|
max_output_tokens=options.max_tokens,
|
148
163
|
stop_sequences=options.stop,
|
164
|
+
response_mime_type=response_mime_type,
|
165
|
+
response_schema=json_schema,
|
149
166
|
)
|
150
167
|
|
151
168
|
def _content_from_message(
|
@@ -239,7 +256,9 @@ class VertexAI(lf.LanguageModel):
|
|
239
256
|
input_content = self._content_from_message(prompt)
|
240
257
|
response = model.generate_content(
|
241
258
|
input_content,
|
242
|
-
generation_config=self._generation_config(
|
259
|
+
generation_config=self._generation_config(
|
260
|
+
prompt, self.sampling_options
|
261
|
+
),
|
243
262
|
)
|
244
263
|
usage_metadata = response.usage_metadata
|
245
264
|
usage = lf.LMSamplingUsage(
|
@@ -277,7 +296,8 @@ class VertexAI(lf.LanguageModel):
|
|
277
296
|
|
278
297
|
def _sample_endpoint_model(self, prompt: lf.Message) -> lf.LMSamplingResult:
|
279
298
|
"""Samples a text generation model."""
|
280
|
-
|
299
|
+
from google.cloud.aiplatform import models
|
300
|
+
model = models.Endpoint(self.endpoint_name)
|
281
301
|
# TODO(chengrun): Add support for stop_sequences.
|
282
302
|
predict_options = dict(
|
283
303
|
temperature=self.sampling_options.temperature
|
@@ -17,7 +17,7 @@ import os
|
|
17
17
|
import unittest
|
18
18
|
from unittest import mock
|
19
19
|
|
20
|
-
from google.cloud.aiplatform
|
20
|
+
from google.cloud.aiplatform import models as aiplatform_models
|
21
21
|
from vertexai import generative_models
|
22
22
|
import langfun.core as lf
|
23
23
|
from langfun.core import modalities as lf_modalities
|
@@ -175,6 +175,53 @@ class VertexAITest(unittest.TestCase):
|
|
175
175
|
del os.environ['VERTEXAI_PROJECT']
|
176
176
|
del os.environ['VERTEXAI_LOCATION']
|
177
177
|
|
178
|
+
def test_generation_config(self):
|
179
|
+
model = vertexai.VertexAIGeminiPro1()
|
180
|
+
json_schema = {
|
181
|
+
'type': 'object',
|
182
|
+
'properties': {
|
183
|
+
'name': {'type': 'string'},
|
184
|
+
},
|
185
|
+
'required': ['name'],
|
186
|
+
'title': 'Person',
|
187
|
+
}
|
188
|
+
config = model._generation_config(
|
189
|
+
lf.UserMessage('hi', json_schema=json_schema),
|
190
|
+
lf.LMSamplingOptions(
|
191
|
+
temperature=2.0,
|
192
|
+
top_p=1.0,
|
193
|
+
top_k=20,
|
194
|
+
max_tokens=1024,
|
195
|
+
stop=['\n'],
|
196
|
+
),
|
197
|
+
)
|
198
|
+
self.assertEqual(
|
199
|
+
config.to_dict(),
|
200
|
+
dict(
|
201
|
+
temperature=2.0,
|
202
|
+
top_p=1.0,
|
203
|
+
top_k=20.0,
|
204
|
+
max_output_tokens=1024,
|
205
|
+
stop_sequences=['\n'],
|
206
|
+
response_mime_type='application/json',
|
207
|
+
response_schema={
|
208
|
+
'type_': 'OBJECT',
|
209
|
+
'properties': {
|
210
|
+
'name': {'type_': 'STRING'}
|
211
|
+
},
|
212
|
+
'required': ['name'],
|
213
|
+
'title': 'Person',
|
214
|
+
}
|
215
|
+
),
|
216
|
+
)
|
217
|
+
with self.assertRaisesRegex(
|
218
|
+
ValueError, '`json_schema` must be a dict, got'
|
219
|
+
):
|
220
|
+
model._generation_config(
|
221
|
+
lf.UserMessage('hi', json_schema='not a dict'),
|
222
|
+
lf.LMSamplingOptions(),
|
223
|
+
)
|
224
|
+
|
178
225
|
def test_call_generative_model(self):
|
179
226
|
with mock.patch(
|
180
227
|
'vertexai.generative_models.'
|
@@ -244,27 +291,31 @@ class VertexAITest(unittest.TestCase):
|
|
244
291
|
|
245
292
|
def test_call_endpoint_model(self):
|
246
293
|
with mock.patch(
|
247
|
-
'google.cloud.aiplatform.
|
248
|
-
) as
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
294
|
+
'google.cloud.aiplatform.models.Endpoint.__init__'
|
295
|
+
) as mock_model_init:
|
296
|
+
mock_model_init.side_effect = lambda *args, **kwargs: None
|
297
|
+
with mock.patch(
|
298
|
+
'google.cloud.aiplatform.models.Endpoint.predict'
|
299
|
+
) as mock_model_predict:
|
300
|
+
|
301
|
+
mock_model_predict.side_effect = mock_endpoint_predict
|
302
|
+
lm = vertexai.VertexAI(
|
303
|
+
'custom',
|
304
|
+
endpoint_name='123',
|
305
|
+
project='abc',
|
306
|
+
location='us-central1',
|
307
|
+
)
|
308
|
+
self.assertEqual(
|
309
|
+
lm(
|
310
|
+
'hello',
|
311
|
+
temperature=2.0,
|
312
|
+
top_p=1.0,
|
313
|
+
top_k=20,
|
314
|
+
max_tokens=50,
|
315
|
+
),
|
316
|
+
'This is a response to hello with temperature=2.0, top_p=1.0,'
|
317
|
+
' top_k=20, max_tokens=50.',
|
318
|
+
)
|
268
319
|
|
269
320
|
|
270
321
|
if __name__ == '__main__':
|
@@ -67,8 +67,8 @@ langfun/core/llms/openai.py,sha256=jILxfFb3vBuyf1u_2-LVfs_wekPF2RVuNFzNVg25pEA,1
|
|
67
67
|
langfun/core/llms/openai_test.py,sha256=3muDTnW7UBOSHq694Fi2bofqhe8Pkj0Tl8IShoLCTOM,15525
|
68
68
|
langfun/core/llms/rest.py,sha256=laopuq-zD8V-3Y6eFDngftHEbE66VlUkCD2-rvvRaLU,3388
|
69
69
|
langfun/core/llms/rest_test.py,sha256=NZ3Nf0XQVpT9kLP5cBVo_yBHLI7vWTYhWQxYEJVMGs4,3472
|
70
|
-
langfun/core/llms/vertexai.py,sha256
|
71
|
-
langfun/core/llms/vertexai_test.py,sha256=
|
70
|
+
langfun/core/llms/vertexai.py,sha256=tXAnP357XhcsETTnk6M-hH4xyFi7tk6fsaf3tjzsY6E,14501
|
71
|
+
langfun/core/llms/vertexai_test.py,sha256=EPR-mB2hNUpvpf7E8m_k5bh04epdQTVUuYU6hPgZyu8,10321
|
72
72
|
langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
|
73
73
|
langfun/core/llms/cache/base.py,sha256=cFfYvOIUae842pncqCAsRvqXCk2AnAsRYVx0mcIoAeY,3338
|
74
74
|
langfun/core/llms/cache/in_memory.py,sha256=YfFyJEhLs73cUiB0ZfhMxYpdE8Iuxxw-dvMFwGHTSHw,4742
|
@@ -117,8 +117,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
|
|
117
117
|
langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
|
118
118
|
langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
|
119
119
|
langfun/core/templates/selfplay_test.py,sha256=rBW2Qr8yi-aWYwoTwRR-n1peKyMX9QXPZXURjLgoiRs,2264
|
120
|
-
langfun-0.1.1.
|
121
|
-
langfun-0.1.1.
|
122
|
-
langfun-0.1.1.
|
123
|
-
langfun-0.1.1.
|
124
|
-
langfun-0.1.1.
|
120
|
+
langfun-0.1.1.dev20240813.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
121
|
+
langfun-0.1.1.dev20240813.dist-info/METADATA,sha256=pRn2OuwICzmonFmbZuzx6LzEgF_LE44B0ceTIa76fLs,5234
|
122
|
+
langfun-0.1.1.dev20240813.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
123
|
+
langfun-0.1.1.dev20240813.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
124
|
+
langfun-0.1.1.dev20240813.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|