langfun 0.1.1.dev20240808__py3-none-any.whl → 0.1.1.dev20240811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/llms/__init__.py +1 -0
- langfun/core/llms/vertexai.py +52 -0
- langfun/core/llms/vertexai_test.py +39 -0
- {langfun-0.1.1.dev20240808.dist-info → langfun-0.1.1.dev20240811.dist-info}/METADATA +1 -1
- {langfun-0.1.1.dev20240808.dist-info → langfun-0.1.1.dev20240811.dist-info}/RECORD +8 -8
- {langfun-0.1.1.dev20240808.dist-info → langfun-0.1.1.dev20240811.dist-info}/LICENSE +0 -0
- {langfun-0.1.1.dev20240808.dist-info → langfun-0.1.1.dev20240811.dist-info}/WHEEL +0 -0
- {langfun-0.1.1.dev20240808.dist-info → langfun-0.1.1.dev20240811.dist-info}/top_level.txt +0 -0
langfun/core/llms/__init__.py
CHANGED
@@ -105,6 +105,7 @@ from langfun.core.llms.vertexai import VertexAIGeminiPro1
|
|
105
105
|
from langfun.core.llms.vertexai import VertexAIGeminiPro1Vision
|
106
106
|
from langfun.core.llms.vertexai import VertexAIPalm2
|
107
107
|
from langfun.core.llms.vertexai import VertexAIPalm2_32K
|
108
|
+
from langfun.core.llms.vertexai import VertexAICustom
|
108
109
|
|
109
110
|
|
110
111
|
# LLaMA C++ models.
|
langfun/core/llms/vertexai.py
CHANGED
@@ -18,6 +18,7 @@ import os
|
|
18
18
|
from typing import Annotated, Any
|
19
19
|
|
20
20
|
from google.auth import credentials as credentials_lib
|
21
|
+
from google.cloud.aiplatform import aiplatform
|
21
22
|
import langfun.core as lf
|
22
23
|
from langfun.core import modalities as lf_modalities
|
23
24
|
import pyglove as pg
|
@@ -35,6 +36,9 @@ SUPPORTED_MODELS_AND_SETTINGS = {
|
|
35
36
|
'text-bison': pg.Dict(api='palm', rpm=1600),
|
36
37
|
'text-bison-32k': pg.Dict(api='palm', rpm=300),
|
37
38
|
'text-unicorn': pg.Dict(api='palm', rpm=100),
|
39
|
+
# Custom model served through a Vertex AI Endpoint.
|
40
|
+
# TODO(chengrun): Set a more appropriate rpm for endpoint.
|
41
|
+
'custom': pg.Dict(api='endpoint', rpm=20),
|
38
42
|
}
|
39
43
|
|
40
44
|
|
@@ -53,6 +57,11 @@ class VertexAI(lf.LanguageModel):
|
|
53
57
|
),
|
54
58
|
]
|
55
59
|
|
60
|
+
endpoint_name: pg.typing.Annotated[
|
61
|
+
str | None,
|
62
|
+
'Vertex Endpoint name or ID.',
|
63
|
+
]
|
64
|
+
|
56
65
|
project: Annotated[
|
57
66
|
str | None,
|
58
67
|
(
|
@@ -177,6 +186,13 @@ class VertexAI(lf.LanguageModel):
|
|
177
186
|
"""Parses generative response into message."""
|
178
187
|
return lf.AIMessage(response.text)
|
179
188
|
|
189
|
+
def _generation_endpoint_response_to_message(
|
190
|
+
self,
|
191
|
+
response: Any, # google.cloud.aiplatform.aiplatform.models.Prediction
|
192
|
+
) -> lf.Message:
|
193
|
+
"""Parses Endpoint response into message."""
|
194
|
+
return lf.AIMessage(response.predictions[0])
|
195
|
+
|
180
196
|
def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
|
181
197
|
assert self._api_initialized, 'Vertex AI API is not initialized.'
|
182
198
|
# TODO(yifenglu): It seems this exception is due to the instability of the
|
@@ -212,6 +228,8 @@ class VertexAI(lf.LanguageModel):
|
|
212
228
|
return self._sample_generative_model(prompt)
|
213
229
|
case 'palm':
|
214
230
|
return self._sample_text_generation_model(prompt)
|
231
|
+
case 'endpoint':
|
232
|
+
return self._sample_endpoint_model(prompt)
|
215
233
|
case _:
|
216
234
|
raise ValueError(f'Unsupported API: {api}')
|
217
235
|
|
@@ -257,6 +275,34 @@ class VertexAI(lf.LanguageModel):
|
|
257
275
|
lf.LMSample(lf.AIMessage(response.text), score=0.0)
|
258
276
|
])
|
259
277
|
|
278
|
+
def _sample_endpoint_model(self, prompt: lf.Message) -> lf.LMSamplingResult:
|
279
|
+
"""Samples a custom model served from a Vertex AI Endpoint."""
|
280
|
+
model = aiplatform.Endpoint(self.endpoint_name)
|
281
|
+
# TODO(chengrun): Add support for stop_sequences.
|
282
|
+
predict_options = dict(
|
283
|
+
temperature=self.sampling_options.temperature
|
284
|
+
if self.sampling_options.temperature is not None
|
285
|
+
else 1.0,
|
286
|
+
top_k=self.sampling_options.top_k
|
287
|
+
if self.sampling_options.top_k is not None
|
288
|
+
else 32,
|
289
|
+
top_p=self.sampling_options.top_p
|
290
|
+
if self.sampling_options.top_p is not None
|
291
|
+
else 1,
|
292
|
+
max_tokens=self.sampling_options.max_tokens
|
293
|
+
if self.sampling_options.max_tokens is not None
|
294
|
+
else 8192,
|
295
|
+
)
|
296
|
+
instances = [{'prompt': prompt.text, **predict_options}]
|
297
|
+
response = model.predict(instances=instances)
|
298
|
+
|
299
|
+
return lf.LMSamplingResult([
|
300
|
+
# Scoring is not supported.
|
301
|
+
lf.LMSample(
|
302
|
+
self._generation_endpoint_response_to_message(response), score=0.0
|
303
|
+
)
|
304
|
+
])
|
305
|
+
|
260
306
|
|
261
307
|
class _ModelHub:
|
262
308
|
"""Vertex AI model hub."""
|
@@ -387,3 +433,9 @@ class VertexAIPalm2_32K(VertexAI): # pylint: disable=invalid-name
|
|
387
433
|
"""Vertex AI PaLM2 text generation model (32K context length)."""
|
388
434
|
|
389
435
|
model = 'text-bison-32k'
|
436
|
+
|
437
|
+
|
438
|
+
class VertexAICustom(VertexAI): # pylint: disable=invalid-name
|
439
|
+
"""Vertex AI Custom model (Endpoint)."""
|
440
|
+
|
441
|
+
model = 'custom'
|
@@ -17,6 +17,7 @@ import os
|
|
17
17
|
import unittest
|
18
18
|
from unittest import mock
|
19
19
|
|
20
|
+
from google.cloud.aiplatform.aiplatform import models as aiplatform_models
|
20
21
|
from vertexai import generative_models
|
21
22
|
import langfun.core as lf
|
22
23
|
from langfun.core import modalities as lf_modalities
|
@@ -63,6 +64,20 @@ def mock_generate_content(content, generation_config, **kwargs):
|
|
63
64
|
})
|
64
65
|
|
65
66
|
|
67
|
+
def mock_endpoint_predict(instances, **kwargs):
|
68
|
+
del kwargs
|
69
|
+
assert len(instances) == 1
|
70
|
+
return aiplatform_models.Prediction(
|
71
|
+
predictions=[
|
72
|
+
f"This is a response to {instances[0]['prompt']} with"
|
73
|
+
f" temperature={instances[0]['temperature']},"
|
74
|
+
f" top_p={instances[0]['top_p']}, top_k={instances[0]['top_k']},"
|
75
|
+
f" max_tokens={instances[0]['max_tokens']}."
|
76
|
+
],
|
77
|
+
deployed_model_id='',
|
78
|
+
)
|
79
|
+
|
80
|
+
|
66
81
|
class VertexAITest(unittest.TestCase):
|
67
82
|
"""Tests for Vertex model."""
|
68
83
|
|
@@ -227,6 +242,30 @@ class VertexAITest(unittest.TestCase):
|
|
227
242
|
),
|
228
243
|
)
|
229
244
|
|
245
|
+
def test_call_endpoint_model(self):
|
246
|
+
with mock.patch(
|
247
|
+
'google.cloud.aiplatform.aiplatform.Endpoint.predict'
|
248
|
+
) as mock_model_predict:
|
249
|
+
|
250
|
+
mock_model_predict.side_effect = mock_endpoint_predict
|
251
|
+
lm = vertexai.VertexAI(
|
252
|
+
'custom',
|
253
|
+
endpoint_name='123',
|
254
|
+
project='abc',
|
255
|
+
location='us-central1',
|
256
|
+
)
|
257
|
+
self.assertEqual(
|
258
|
+
lm(
|
259
|
+
'hello',
|
260
|
+
temperature=2.0,
|
261
|
+
top_p=1.0,
|
262
|
+
top_k=20,
|
263
|
+
max_tokens=50,
|
264
|
+
),
|
265
|
+
'This is a response to hello with temperature=2.0, top_p=1.0,'
|
266
|
+
' top_k=20, max_tokens=50.',
|
267
|
+
)
|
268
|
+
|
230
269
|
|
231
270
|
if __name__ == '__main__':
|
232
271
|
unittest.main()
|
@@ -52,7 +52,7 @@ langfun/core/eval/patching.py,sha256=R0s2eAd1m97exQt06dmUL0V_MBG0W2Hxg7fhNB7cXW0
|
|
52
52
|
langfun/core/eval/patching_test.py,sha256=8kCd54Egjju22FMgtJuxEsrXkW8ifs-UUBHtrCG1L6w,4775
|
53
53
|
langfun/core/eval/scoring.py,sha256=AlCwEVrU6nvURDB1aPxA2XBUmOjWxuNJDXJoS4-6VbU,6386
|
54
54
|
langfun/core/eval/scoring_test.py,sha256=O8olHbrUEg60gMxwOkWzKBJZpZoUlmVnBANX5Se2SXM,4546
|
55
|
-
langfun/core/llms/__init__.py,sha256=
|
55
|
+
langfun/core/llms/__init__.py,sha256=ggPvkKq8l-B8rN3ZD6d7r3_d5DYAd5fC7FgC6ZI2Wzo,4766
|
56
56
|
langfun/core/llms/anthropic.py,sha256=Gon3fOi31RhZFgNd0ijyTnKnUdp9hrWrCoSXyO4UaLw,7316
|
57
57
|
langfun/core/llms/anthropic_test.py,sha256=T-swuMkfnlgs8Fpif4rtXs579exGk0TsbLMirXDZCkg,5533
|
58
58
|
langfun/core/llms/fake.py,sha256=Dd7-6ka9pFf3fcWZyczamjOqQ91MOI-m7We3Oc9Ffmo,2927
|
@@ -67,8 +67,8 @@ langfun/core/llms/openai.py,sha256=jILxfFb3vBuyf1u_2-LVfs_wekPF2RVuNFzNVg25pEA,1
|
|
67
67
|
langfun/core/llms/openai_test.py,sha256=3muDTnW7UBOSHq694Fi2bofqhe8Pkj0Tl8IShoLCTOM,15525
|
68
68
|
langfun/core/llms/rest.py,sha256=laopuq-zD8V-3Y6eFDngftHEbE66VlUkCD2-rvvRaLU,3388
|
69
69
|
langfun/core/llms/rest_test.py,sha256=NZ3Nf0XQVpT9kLP5cBVo_yBHLI7vWTYhWQxYEJVMGs4,3472
|
70
|
-
langfun/core/llms/vertexai.py,sha256
|
71
|
-
langfun/core/llms/vertexai_test.py,sha256=
|
70
|
+
langfun/core/llms/vertexai.py,sha256=-YoEUlG19CWIhJb8S6puPqdX9SoiT5NNAItefwdCfsk,13781
|
71
|
+
langfun/core/llms/vertexai_test.py,sha256=N3k4N9_bVjC6_Qtg4WO9jYNv8M9xmv5UdODvIKG2upo,8835
|
72
72
|
langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
|
73
73
|
langfun/core/llms/cache/base.py,sha256=cFfYvOIUae842pncqCAsRvqXCk2AnAsRYVx0mcIoAeY,3338
|
74
74
|
langfun/core/llms/cache/in_memory.py,sha256=YfFyJEhLs73cUiB0ZfhMxYpdE8Iuxxw-dvMFwGHTSHw,4742
|
@@ -117,8 +117,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
|
|
117
117
|
langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
|
118
118
|
langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
|
119
119
|
langfun/core/templates/selfplay_test.py,sha256=rBW2Qr8yi-aWYwoTwRR-n1peKyMX9QXPZXURjLgoiRs,2264
|
120
|
-
langfun-0.1.1.
|
121
|
-
langfun-0.1.1.
|
122
|
-
langfun-0.1.1.
|
123
|
-
langfun-0.1.1.
|
124
|
-
langfun-0.1.1.
|
120
|
+
langfun-0.1.1.dev20240811.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
121
|
+
langfun-0.1.1.dev20240811.dist-info/METADATA,sha256=zo0-slFX-N_vCKGKOHamespDwn3SbGpM0TsIkdwB0Fk,5234
|
122
|
+
langfun-0.1.1.dev20240811.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
123
|
+
langfun-0.1.1.dev20240811.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
124
|
+
langfun-0.1.1.dev20240811.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|