langfun 0.1.1.dev20240809__py3-none-any.whl → 0.1.1.dev20240812__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- langfun/core/llms/__init__.py
+++ langfun/core/llms/__init__.py
@@ -105,6 +105,7 @@ from langfun.core.llms.vertexai import VertexAIGeminiPro1
 from langfun.core.llms.vertexai import VertexAIGeminiPro1Vision
 from langfun.core.llms.vertexai import VertexAIPalm2
 from langfun.core.llms.vertexai import VertexAIPalm2_32K
+from langfun.core.llms.vertexai import VertexAICustom
 
 
 # LLaMA C++ models.
--- langfun/core/llms/vertexai.py
+++ langfun/core/llms/vertexai.py
@@ -18,6 +18,7 @@ import os
 from typing import Annotated, Any
 
 from google.auth import credentials as credentials_lib
+from google.cloud.aiplatform import aiplatform
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
 import pyglove as pg
@@ -35,6 +36,9 @@ SUPPORTED_MODELS_AND_SETTINGS = {
     'text-bison': pg.Dict(api='palm', rpm=1600),
     'text-bison-32k': pg.Dict(api='palm', rpm=300),
     'text-unicorn': pg.Dict(api='palm', rpm=100),
+    # Endpoint
+    # TODO(chengrun): Set a more appropriate rpm for endpoint.
+    'custom': pg.Dict(api='endpoint', rpm=20),
 }
 
 
@@ -53,6 +57,11 @@ class VertexAI(lf.LanguageModel):
       ),
   ]
 
+  endpoint_name: pg.typing.Annotated[
+      str | None,
+      'Vertex Endpoint name or ID.',
+  ]
+
   project: Annotated[
       str | None,
       (
@@ -177,6 +186,13 @@ class VertexAI(lf.LanguageModel):
     """Parses generative response into message."""
     return lf.AIMessage(response.text)
 
+  def _generation_endpoint_response_to_message(
+      self,
+      response: Any,  # google.cloud.aiplatform.aiplatform.models.Prediction
+  ) -> lf.Message:
+    """Parses Endpoint response into message."""
+    return lf.AIMessage(response.predictions[0])
+
   def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
     assert self._api_initialized, 'Vertex AI API is not initialized.'
     # TODO(yifenglu): It seems this exception is due to the instability of the
@@ -212,6 +228,8 @@ class VertexAI(lf.LanguageModel):
         return self._sample_generative_model(prompt)
       case 'palm':
         return self._sample_text_generation_model(prompt)
+      case 'endpoint':
+        return self._sample_endpoint_model(prompt)
       case _:
         raise ValueError(f'Unsupported API: {api}')
 
@@ -257,6 +275,34 @@ class VertexAI(lf.LanguageModel):
         lf.LMSample(lf.AIMessage(response.text), score=0.0)
     ])
 
+  def _sample_endpoint_model(self, prompt: lf.Message) -> lf.LMSamplingResult:
+    """Samples a text generation model."""
+    model = aiplatform.Endpoint(self.endpoint_name)
+    # TODO(chengrun): Add support for stop_sequences.
+    predict_options = dict(
+        temperature=self.sampling_options.temperature
+        if self.sampling_options.temperature is not None
+        else 1.0,
+        top_k=self.sampling_options.top_k
+        if self.sampling_options.top_k is not None
+        else 32,
+        top_p=self.sampling_options.top_p
+        if self.sampling_options.top_p is not None
+        else 1,
+        max_tokens=self.sampling_options.max_tokens
+        if self.sampling_options.max_tokens is not None
+        else 8192,
+    )
+    instances = [{'prompt': prompt.text, **predict_options}]
+    response = model.predict(instances=instances)
+
+    return lf.LMSamplingResult([
+        # Scoring is not supported.
+        lf.LMSample(
+            self._generation_endpoint_response_to_message(response), score=0.0
+        )
+    ])
+
 
 class _ModelHub:
   """Vertex AI model hub."""
@@ -387,3 +433,9 @@ class VertexAIPalm2_32K(VertexAI):  # pylint: disable=invalid-name
   """Vertex AI PaLM2 text generation model (32K context length)."""
 
   model = 'text-bison-32k'
+
+
+class VertexAICustom(VertexAI):  # pylint: disable=invalid-name
+  """Vertex AI Custom model (Endpoint)."""
+
+  model = 'custom'
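
For reference, a minimal usage sketch of the new endpoint-backed model, modeled on the test added later in this diff. The endpoint ID, project, and location values are placeholders, and the sketch assumes the deployed Vertex AI Endpoint accepts instances with a 'prompt' field.

```python
# Sketch only: endpoint ID, project, and location are placeholder values.
from langfun.core.llms import vertexai

# The new 'custom' model entry routes sampling through
# google.cloud.aiplatform.Endpoint.predict().
lm = vertexai.VertexAICustom(
    endpoint_name='1234567890',   # Vertex AI Endpoint name or ID
    project='my-gcp-project',
    location='us-central1',
)

# Sampling options are forwarded as prediction parameters
# (temperature, top_p, top_k, max_tokens).
print(lm('hello', temperature=1.0, max_tokens=256))
```
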
--- langfun/core/llms/vertexai_test.py
+++ langfun/core/llms/vertexai_test.py
@@ -17,6 +17,7 @@ import os
 import unittest
 from unittest import mock
 
+from google.cloud.aiplatform.aiplatform import models as aiplatform_models
 from vertexai import generative_models
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
@@ -63,6 +64,20 @@ def mock_generate_content(content, generation_config, **kwargs):
   })
 
 
+def mock_endpoint_predict(instances, **kwargs):
+  del kwargs
+  assert len(instances) == 1
+  return aiplatform_models.Prediction(
+      predictions=[
+          f"This is a response to {instances[0]['prompt']} with"
+          f" temperature={instances[0]['temperature']},"
+          f" top_p={instances[0]['top_p']}, top_k={instances[0]['top_k']},"
+          f" max_tokens={instances[0]['max_tokens']}."
+      ],
+      deployed_model_id='',
+  )
+
+
 class VertexAITest(unittest.TestCase):
   """Tests for Vertex model."""
 
@@ -227,6 +242,30 @@ class VertexAITest(unittest.TestCase):
         ),
     )
 
+  def test_call_endpoint_model(self):
+    with mock.patch(
+        'google.cloud.aiplatform.aiplatform.Endpoint.predict'
+    ) as mock_model_predict:
+
+      mock_model_predict.side_effect = mock_endpoint_predict
+      lm = vertexai.VertexAI(
+          'custom',
+          endpoint_name='123',
+          project='abc',
+          location='us-central1',
+      )
+      self.assertEqual(
+          lm(
+              'hello',
+              temperature=2.0,
+              top_p=1.0,
+              top_k=20,
+              max_tokens=50,
+          ),
+          'This is a response to hello with temperature=2.0, top_p=1.0,'
+          ' top_k=20, max_tokens=50.',
+      )
+
 
 if __name__ == '__main__':
   unittest.main()
--- langfun-0.1.1.dev20240809.dist-info/METADATA
+++ langfun-0.1.1.dev20240812.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: langfun
-Version: 0.1.1.dev20240809
+Version: 0.1.1.dev20240812
 Summary: Langfun: Language as Functions.
 Home-page: https://github.com/google/langfun
 Author: Langfun Authors
--- langfun-0.1.1.dev20240809.dist-info/RECORD
+++ langfun-0.1.1.dev20240812.dist-info/RECORD
@@ -52,7 +52,7 @@ langfun/core/eval/patching.py,sha256=R0s2eAd1m97exQt06dmUL0V_MBG0W2Hxg7fhNB7cXW0
 langfun/core/eval/patching_test.py,sha256=8kCd54Egjju22FMgtJuxEsrXkW8ifs-UUBHtrCG1L6w,4775
 langfun/core/eval/scoring.py,sha256=AlCwEVrU6nvURDB1aPxA2XBUmOjWxuNJDXJoS4-6VbU,6386
 langfun/core/eval/scoring_test.py,sha256=O8olHbrUEg60gMxwOkWzKBJZpZoUlmVnBANX5Se2SXM,4546
-langfun/core/llms/__init__.py,sha256=YGILcGi2QTxDG0v-0Gd4uAj1HL_zRhtllOM9EURxzDg,4712
+langfun/core/llms/__init__.py,sha256=ggPvkKq8l-B8rN3ZD6d7r3_d5DYAd5fC7FgC6ZI2Wzo,4766
 langfun/core/llms/anthropic.py,sha256=Gon3fOi31RhZFgNd0ijyTnKnUdp9hrWrCoSXyO4UaLw,7316
 langfun/core/llms/anthropic_test.py,sha256=T-swuMkfnlgs8Fpif4rtXs579exGk0TsbLMirXDZCkg,5533
 langfun/core/llms/fake.py,sha256=Dd7-6ka9pFf3fcWZyczamjOqQ91MOI-m7We3Oc9Ffmo,2927
@@ -67,8 +67,8 @@ langfun/core/llms/openai.py,sha256=jILxfFb3vBuyf1u_2-LVfs_wekPF2RVuNFzNVg25pEA,1
 langfun/core/llms/openai_test.py,sha256=3muDTnW7UBOSHq694Fi2bofqhe8Pkj0Tl8IShoLCTOM,15525
 langfun/core/llms/rest.py,sha256=laopuq-zD8V-3Y6eFDngftHEbE66VlUkCD2-rvvRaLU,3388
 langfun/core/llms/rest_test.py,sha256=NZ3Nf0XQVpT9kLP5cBVo_yBHLI7vWTYhWQxYEJVMGs4,3472
-langfun/core/llms/vertexai.py,sha256=pSvEWZ2jqS21TEEoztn3VWaqT09abkMcj07Oe2BJR1s,12017
-langfun/core/llms/vertexai_test.py,sha256=G18BG36h5KvmX2zutDTLjtYCRjTuP_nWIFm4FMnLnyY,7651
+langfun/core/llms/vertexai.py,sha256=-YoEUlG19CWIhJb8S6puPqdX9SoiT5NNAItefwdCfsk,13781
+langfun/core/llms/vertexai_test.py,sha256=N3k4N9_bVjC6_Qtg4WO9jYNv8M9xmv5UdODvIKG2upo,8835
 langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
 langfun/core/llms/cache/base.py,sha256=cFfYvOIUae842pncqCAsRvqXCk2AnAsRYVx0mcIoAeY,3338
 langfun/core/llms/cache/in_memory.py,sha256=YfFyJEhLs73cUiB0ZfhMxYpdE8Iuxxw-dvMFwGHTSHw,4742
@@ -117,8 +117,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
 langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
 langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
 langfun/core/templates/selfplay_test.py,sha256=rBW2Qr8yi-aWYwoTwRR-n1peKyMX9QXPZXURjLgoiRs,2264
-langfun-0.1.1.dev20240809.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
-langfun-0.1.1.dev20240809.dist-info/METADATA,sha256=ZfbP7-dJs45zSvUsxwQN-CYQabeGJbX0wDeqb6A23wU,5234
-langfun-0.1.1.dev20240809.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-langfun-0.1.1.dev20240809.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
-langfun-0.1.1.dev20240809.dist-info/RECORD,,
+langfun-0.1.1.dev20240812.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
+langfun-0.1.1.dev20240812.dist-info/METADATA,sha256=Kk1dXZKetEZkE1Ycs26AdvbUuJY6TIA7Z75LKGqDlq0,5234
+langfun-0.1.1.dev20240812.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+langfun-0.1.1.dev20240812.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
+langfun-0.1.1.dev20240812.dist-info/RECORD,,