langfun 0.1.1.dev20240812__py3-none-any.whl → 0.1.1.dev20240813__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
@@ -18,18 +18,17 @@ import os
18
18
  from typing import Annotated, Any
19
19
 
20
20
  from google.auth import credentials as credentials_lib
21
- from google.cloud.aiplatform import aiplatform
22
21
  import langfun.core as lf
23
22
  from langfun.core import modalities as lf_modalities
24
23
  import pyglove as pg
25
24
 
26
25
 
27
26
  SUPPORTED_MODELS_AND_SETTINGS = {
28
- 'gemini-1.5-pro-001': pg.Dict(api='gemini', rpm=5),
29
- 'gemini-1.5-flash-001': pg.Dict(api='gemini', rpm=5),
30
- 'gemini-1.5-pro-preview-0514': pg.Dict(api='gemini', rpm=5),
31
- 'gemini-1.5-pro-preview-0409': pg.Dict(api='gemini', rpm=5),
32
- 'gemini-1.5-flash-preview-0514': pg.Dict(api='gemini', rpm=5),
27
+ 'gemini-1.5-pro-001': pg.Dict(api='gemini', rpm=50),
28
+ 'gemini-1.5-flash-001': pg.Dict(api='gemini', rpm=200),
29
+ 'gemini-1.5-pro-preview-0514': pg.Dict(api='gemini', rpm=50),
30
+ 'gemini-1.5-pro-preview-0409': pg.Dict(api='gemini', rpm=50),
31
+ 'gemini-1.5-flash-preview-0514': pg.Dict(api='gemini', rpm=200),
33
32
  'gemini-1.0-pro': pg.Dict(api='gemini', rpm=300),
34
33
  'gemini-1.0-pro-vision': pg.Dict(api='gemini', rpm=100),
35
34
  # PaLM APIs.
@@ -136,16 +135,34 @@ class VertexAI(lf.LanguageModel):
136
135
  )
137
136
 
138
137
  def _generation_config(
139
- self, options: lf.LMSamplingOptions
138
+ self, prompt: lf.Message, options: lf.LMSamplingOptions
140
139
  ) -> Any: # generative_models.GenerationConfig
141
140
  """Creates generation config from langfun sampling options."""
142
141
  from vertexai import generative_models
142
+ # Users could use `metadata_json_schema` to pass additional
143
+ # request arguments.
144
+ json_schema = prompt.metadata.get('json_schema')
145
+ response_mime_type = None
146
+ if json_schema is not None:
147
+ if not isinstance(json_schema, dict):
148
+ raise ValueError(
149
+ f'`json_schema` must be a dict, got {json_schema!r}.'
150
+ )
151
+ response_mime_type = 'application/json'
152
+ prompt.metadata.formatted_text = (
153
+ prompt.text
154
+ + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
155
+ + pg.to_json_str(json_schema, json_indent=2)
156
+ )
157
+
143
158
  return generative_models.GenerationConfig(
144
159
  temperature=options.temperature,
145
160
  top_p=options.top_p,
146
161
  top_k=options.top_k,
147
162
  max_output_tokens=options.max_tokens,
148
163
  stop_sequences=options.stop,
164
+ response_mime_type=response_mime_type,
165
+ response_schema=json_schema,
149
166
  )
150
167
 
151
168
  def _content_from_message(
@@ -239,7 +256,9 @@ class VertexAI(lf.LanguageModel):
239
256
  input_content = self._content_from_message(prompt)
240
257
  response = model.generate_content(
241
258
  input_content,
242
- generation_config=self._generation_config(self.sampling_options),
259
+ generation_config=self._generation_config(
260
+ prompt, self.sampling_options
261
+ ),
243
262
  )
244
263
  usage_metadata = response.usage_metadata
245
264
  usage = lf.LMSamplingUsage(
@@ -277,7 +296,8 @@ class VertexAI(lf.LanguageModel):
277
296
 
278
297
  def _sample_endpoint_model(self, prompt: lf.Message) -> lf.LMSamplingResult:
279
298
  """Samples a text generation model."""
280
- model = aiplatform.Endpoint(self.endpoint_name)
299
+ from google.cloud.aiplatform import models
300
+ model = models.Endpoint(self.endpoint_name)
281
301
  # TODO(chengrun): Add support for stop_sequences.
282
302
  predict_options = dict(
283
303
  temperature=self.sampling_options.temperature
@@ -17,7 +17,7 @@ import os
17
17
  import unittest
18
18
  from unittest import mock
19
19
 
20
- from google.cloud.aiplatform.aiplatform import models as aiplatform_models
20
+ from google.cloud.aiplatform import models as aiplatform_models
21
21
  from vertexai import generative_models
22
22
  import langfun.core as lf
23
23
  from langfun.core import modalities as lf_modalities
@@ -175,6 +175,53 @@ class VertexAITest(unittest.TestCase):
175
175
  del os.environ['VERTEXAI_PROJECT']
176
176
  del os.environ['VERTEXAI_LOCATION']
177
177
 
178
+ def test_generation_config(self):
179
+ model = vertexai.VertexAIGeminiPro1()
180
+ json_schema = {
181
+ 'type': 'object',
182
+ 'properties': {
183
+ 'name': {'type': 'string'},
184
+ },
185
+ 'required': ['name'],
186
+ 'title': 'Person',
187
+ }
188
+ config = model._generation_config(
189
+ lf.UserMessage('hi', json_schema=json_schema),
190
+ lf.LMSamplingOptions(
191
+ temperature=2.0,
192
+ top_p=1.0,
193
+ top_k=20,
194
+ max_tokens=1024,
195
+ stop=['\n'],
196
+ ),
197
+ )
198
+ self.assertEqual(
199
+ config.to_dict(),
200
+ dict(
201
+ temperature=2.0,
202
+ top_p=1.0,
203
+ top_k=20.0,
204
+ max_output_tokens=1024,
205
+ stop_sequences=['\n'],
206
+ response_mime_type='application/json',
207
+ response_schema={
208
+ 'type_': 'OBJECT',
209
+ 'properties': {
210
+ 'name': {'type_': 'STRING'}
211
+ },
212
+ 'required': ['name'],
213
+ 'title': 'Person',
214
+ }
215
+ ),
216
+ )
217
+ with self.assertRaisesRegex(
218
+ ValueError, '`json_schema` must be a dict, got'
219
+ ):
220
+ model._generation_config(
221
+ lf.UserMessage('hi', json_schema='not a dict'),
222
+ lf.LMSamplingOptions(),
223
+ )
224
+
178
225
  def test_call_generative_model(self):
179
226
  with mock.patch(
180
227
  'vertexai.generative_models.'
@@ -244,27 +291,31 @@ class VertexAITest(unittest.TestCase):
244
291
 
245
292
  def test_call_endpoint_model(self):
246
293
  with mock.patch(
247
- 'google.cloud.aiplatform.aiplatform.Endpoint.predict'
248
- ) as mock_model_predict:
249
-
250
- mock_model_predict.side_effect = mock_endpoint_predict
251
- lm = vertexai.VertexAI(
252
- 'custom',
253
- endpoint_name='123',
254
- project='abc',
255
- location='us-central1',
256
- )
257
- self.assertEqual(
258
- lm(
259
- 'hello',
260
- temperature=2.0,
261
- top_p=1.0,
262
- top_k=20,
263
- max_tokens=50,
264
- ),
265
- 'This is a response to hello with temperature=2.0, top_p=1.0,'
266
- ' top_k=20, max_tokens=50.',
267
- )
294
+ 'google.cloud.aiplatform.models.Endpoint.__init__'
295
+ ) as mock_model_init:
296
+ mock_model_init.side_effect = lambda *args, **kwargs: None
297
+ with mock.patch(
298
+ 'google.cloud.aiplatform.models.Endpoint.predict'
299
+ ) as mock_model_predict:
300
+
301
+ mock_model_predict.side_effect = mock_endpoint_predict
302
+ lm = vertexai.VertexAI(
303
+ 'custom',
304
+ endpoint_name='123',
305
+ project='abc',
306
+ location='us-central1',
307
+ )
308
+ self.assertEqual(
309
+ lm(
310
+ 'hello',
311
+ temperature=2.0,
312
+ top_p=1.0,
313
+ top_k=20,
314
+ max_tokens=50,
315
+ ),
316
+ 'This is a response to hello with temperature=2.0, top_p=1.0,'
317
+ ' top_k=20, max_tokens=50.',
318
+ )
268
319
 
269
320
 
270
321
  if __name__ == '__main__':
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.1.1.dev20240812
3
+ Version: 0.1.1.dev20240813
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -67,8 +67,8 @@ langfun/core/llms/openai.py,sha256=jILxfFb3vBuyf1u_2-LVfs_wekPF2RVuNFzNVg25pEA,1
67
67
  langfun/core/llms/openai_test.py,sha256=3muDTnW7UBOSHq694Fi2bofqhe8Pkj0Tl8IShoLCTOM,15525
68
68
  langfun/core/llms/rest.py,sha256=laopuq-zD8V-3Y6eFDngftHEbE66VlUkCD2-rvvRaLU,3388
69
69
  langfun/core/llms/rest_test.py,sha256=NZ3Nf0XQVpT9kLP5cBVo_yBHLI7vWTYhWQxYEJVMGs4,3472
70
- langfun/core/llms/vertexai.py,sha256=-YoEUlG19CWIhJb8S6puPqdX9SoiT5NNAItefwdCfsk,13781
71
- langfun/core/llms/vertexai_test.py,sha256=N3k4N9_bVjC6_Qtg4WO9jYNv8M9xmv5UdODvIKG2upo,8835
70
+ langfun/core/llms/vertexai.py,sha256=tXAnP357XhcsETTnk6M-hH4xyFi7tk6fsaf3tjzsY6E,14501
71
+ langfun/core/llms/vertexai_test.py,sha256=EPR-mB2hNUpvpf7E8m_k5bh04epdQTVUuYU6hPgZyu8,10321
72
72
  langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
73
73
  langfun/core/llms/cache/base.py,sha256=cFfYvOIUae842pncqCAsRvqXCk2AnAsRYVx0mcIoAeY,3338
74
74
  langfun/core/llms/cache/in_memory.py,sha256=YfFyJEhLs73cUiB0ZfhMxYpdE8Iuxxw-dvMFwGHTSHw,4742
@@ -117,8 +117,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
117
117
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
118
118
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
119
119
  langfun/core/templates/selfplay_test.py,sha256=rBW2Qr8yi-aWYwoTwRR-n1peKyMX9QXPZXURjLgoiRs,2264
120
- langfun-0.1.1.dev20240812.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
121
- langfun-0.1.1.dev20240812.dist-info/METADATA,sha256=Kk1dXZKetEZkE1Ycs26AdvbUuJY6TIA7Z75LKGqDlq0,5234
122
- langfun-0.1.1.dev20240812.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
123
- langfun-0.1.1.dev20240812.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
124
- langfun-0.1.1.dev20240812.dist-info/RECORD,,
120
+ langfun-0.1.1.dev20240813.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
121
+ langfun-0.1.1.dev20240813.dist-info/METADATA,sha256=pRn2OuwICzmonFmbZuzx6LzEgF_LE44B0ceTIa76fLs,5234
122
+ langfun-0.1.1.dev20240813.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
123
+ langfun-0.1.1.dev20240813.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
124
+ langfun-0.1.1.dev20240813.dist-info/RECORD,,