langfun 0.1.1.dev20240812__py3-none-any.whl → 0.1.1.dev20240817__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/eval/__init__.py +1 -1
- langfun/core/eval/base.py +21 -12
- langfun/core/eval/base_test.py +9 -4
- langfun/core/llms/__init__.py +1 -0
- langfun/core/llms/openai.py +60 -9
- langfun/core/llms/openai_test.py +63 -3
- langfun/core/llms/vertexai.py +29 -9
- langfun/core/llms/vertexai_test.py +73 -22
- {langfun-0.1.1.dev20240812.dist-info → langfun-0.1.1.dev20240817.dist-info}/METADATA +1 -1
- {langfun-0.1.1.dev20240812.dist-info → langfun-0.1.1.dev20240817.dist-info}/RECORD +13 -13
- {langfun-0.1.1.dev20240812.dist-info → langfun-0.1.1.dev20240817.dist-info}/WHEEL +1 -1
- {langfun-0.1.1.dev20240812.dist-info → langfun-0.1.1.dev20240817.dist-info}/LICENSE +0 -0
- {langfun-0.1.1.dev20240812.dist-info → langfun-0.1.1.dev20240817.dist-info}/top_level.txt +0 -0
langfun/core/eval/__init__.py
CHANGED
@@ -18,7 +18,7 @@
|
|
18
18
|
|
19
19
|
from langfun.core.eval.base import register
|
20
20
|
from langfun.core.eval.base import registered_names
|
21
|
-
from langfun.core.eval.base import
|
21
|
+
from langfun.core.eval.base import get_evaluations
|
22
22
|
from langfun.core.eval.base import get
|
23
23
|
from langfun.core.eval.base import run
|
24
24
|
|
langfun/core/eval/base.py
CHANGED
@@ -2159,14 +2159,17 @@ class _NamedEvaluationRegistry:
|
|
2159
2159
|
"""Returns all registered names."""
|
2160
2160
|
return sorted(self._registry.keys())
|
2161
2161
|
|
2162
|
-
def get(self, name: str) -> Type[Evaluable]:
|
2162
|
+
def get(self, name: str) -> list[Type[Evaluable]]:
|
2163
2163
|
"""Gets an evaluation by name."""
|
2164
|
-
|
2165
|
-
|
2166
|
-
|
2167
|
-
|
2168
|
-
)
|
2169
|
-
|
2164
|
+
matches = []
|
2165
|
+
if name in self._registry:
|
2166
|
+
matches.append(self._registry[name])
|
2167
|
+
else:
|
2168
|
+
regex = re.compile(name)
|
2169
|
+
for key, cls in self._registry.items():
|
2170
|
+
if regex.match(key):
|
2171
|
+
matches.append(cls)
|
2172
|
+
return matches
|
2170
2173
|
|
2171
2174
|
def register(
|
2172
2175
|
self,
|
@@ -2185,11 +2188,11 @@ def registered_names() -> list[str]:
|
|
2185
2188
|
return _eval_registry.names()
|
2186
2189
|
|
2187
2190
|
|
2188
|
-
def
|
2191
|
+
def get_evaluations(evaluation: str | Evaluable) -> list[Evaluable]:
|
2189
2192
|
"""Gets an evaluation experiment by name."""
|
2190
2193
|
if isinstance(evaluation, str):
|
2191
|
-
return _eval_registry.get(evaluation)
|
2192
|
-
return evaluation
|
2194
|
+
return [e() for e in _eval_registry.get(evaluation)]
|
2195
|
+
return [evaluation]
|
2193
2196
|
|
2194
2197
|
|
2195
2198
|
def register(name: str):
|
@@ -2257,8 +2260,14 @@ def get(
|
|
2257
2260
|
Returns:
|
2258
2261
|
A suite of selected `lf.eval.Evaluation` objects.
|
2259
2262
|
"""
|
2260
|
-
|
2261
|
-
|
2263
|
+
matches = []
|
2264
|
+
for e in evaluations:
|
2265
|
+
matches.extend(get_evaluations(e))
|
2266
|
+
|
2267
|
+
if not matches:
|
2268
|
+
raise ValueError('No evaluations found.')
|
2269
|
+
|
2270
|
+
suite = Suite(matches, root_dir=root_dir)
|
2262
2271
|
if patches:
|
2263
2272
|
suite = pg.patch(suite, patches)
|
2264
2273
|
|
langfun/core/eval/base_test.py
CHANGED
@@ -765,7 +765,7 @@ class NamedEvaluationTest(unittest.TestCase):
|
|
765
765
|
])
|
766
766
|
schema_fn = answer_schema()
|
767
767
|
|
768
|
-
evaluation = base.
|
768
|
+
[evaluation] = base.get_evaluations('named_eval/class_test')
|
769
769
|
self.assertIsInstance(evaluation, MyEval)
|
770
770
|
self.assertIsNone(evaluation.dir)
|
771
771
|
self.assertIsNone(evaluation.root_dir)
|
@@ -793,13 +793,15 @@ class NamedEvaluationTest(unittest.TestCase):
|
|
793
793
|
)
|
794
794
|
|
795
795
|
self.assertTrue(issubclass(my_eval, base.Evaluable))
|
796
|
-
evaluation = base.
|
796
|
+
[evaluation] = base.get_evaluations('named_eval/functor_test')
|
797
797
|
self.assertIn('named_eval/functor_test', base.registered_names())
|
798
798
|
self.assertIsInstance(evaluation, my_eval)
|
799
799
|
self.assertIsNone(evaluation.root_dir, None)
|
800
800
|
|
801
|
-
|
802
|
-
|
801
|
+
self.assertTrue(
|
802
|
+
pg.eq(base.get_evaluations('named_eval/functor.*'), [evaluation])
|
803
|
+
)
|
804
|
+
self.assertEqual(base.get_evaluations('named_eval/non_existent'), [])
|
803
805
|
|
804
806
|
with self.assertRaisesRegex(TypeError, 'The return value .*'):
|
805
807
|
@base.register('named_eval/bad_return_type')
|
@@ -841,6 +843,9 @@ class NamedEvaluationTest(unittest.TestCase):
|
|
841
843
|
)
|
842
844
|
self.assertTrue(pg.eq(e.leaf_nodes[0].lm, fake.StaticResponse('efg')))
|
843
845
|
|
846
|
+
with self.assertRaisesRegex(ValueError, 'No evaluations found'):
|
847
|
+
base.run(tempfile.gettempdir(), ['test/non_existent'])
|
848
|
+
|
844
849
|
|
845
850
|
if __name__ == '__main__':
|
846
851
|
unittest.main()
|
langfun/core/llms/__init__.py
CHANGED
@@ -42,6 +42,7 @@ from langfun.core.llms.openai import OpenAI
|
|
42
42
|
from langfun.core.llms.openai import Gpt4oMini
|
43
43
|
from langfun.core.llms.openai import Gpt4oMini_20240718
|
44
44
|
from langfun.core.llms.openai import Gpt4o
|
45
|
+
from langfun.core.llms.openai import Gpt4o_20240806
|
45
46
|
from langfun.core.llms.openai import Gpt4o_20240513
|
46
47
|
|
47
48
|
from langfun.core.llms.openai import Gpt4Turbo
|
langfun/core/llms/openai.py
CHANGED
@@ -37,6 +37,7 @@ SUPPORTED_MODELS_AND_SETTINGS = {
|
|
37
37
|
'gpt-4o-mini': pg.Dict(rpm=10000, tpm=5000000),
|
38
38
|
'gpt-4o-mini-2024-07-18': pg.Dict(rpm=10000, tpm=5000000),
|
39
39
|
'gpt-4o': pg.Dict(rpm=10000, tpm=5000000),
|
40
|
+
'gpt-4o-2024-08-06': pg.Dict(rpm=10000, tpm=5000000),
|
40
41
|
'gpt-4o-2024-05-13': pg.Dict(rpm=10000, tpm=5000000),
|
41
42
|
# GPT-4-Turbo models
|
42
43
|
'gpt-4-turbo': pg.Dict(rpm=10000, tpm=2000000),
|
@@ -178,6 +179,8 @@ class OpenAI(lf.LanguageModel):
|
|
178
179
|
args['top_p'] = options.top_p
|
179
180
|
if options.stop:
|
180
181
|
args['stop'] = options.stop
|
182
|
+
if options.random_seed is not None:
|
183
|
+
args['seed'] = options.random_seed
|
181
184
|
return args
|
182
185
|
|
183
186
|
def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
|
@@ -231,10 +234,10 @@ class OpenAI(lf.LanguageModel):
|
|
231
234
|
def _chat_complete_batch(
|
232
235
|
self, prompts: list[lf.Message]
|
233
236
|
) -> list[lf.LMSamplingResult]:
|
234
|
-
def
|
237
|
+
def _content_from_message(message: lf.Message):
|
235
238
|
if self.multimodal:
|
236
239
|
content = []
|
237
|
-
for chunk in
|
240
|
+
for chunk in message.chunk():
|
238
241
|
if isinstance(chunk, str):
|
239
242
|
item = dict(type='text', text=chunk)
|
240
243
|
elif isinstance(chunk, lf_modalities.Image):
|
@@ -244,14 +247,56 @@ class OpenAI(lf.LanguageModel):
|
|
244
247
|
raise ValueError(f'Unsupported modality object: {chunk!r}.')
|
245
248
|
content.append(item)
|
246
249
|
else:
|
247
|
-
content =
|
250
|
+
content = message.text
|
251
|
+
return content
|
248
252
|
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
+
def _open_ai_chat_completion(prompt: lf.Message):
|
254
|
+
request_args = self._get_request_args(self.sampling_options)
|
255
|
+
# Users could use `metadata_json_schema` to pass additional
|
256
|
+
# request arguments.
|
257
|
+
json_schema = prompt.metadata.get('json_schema')
|
258
|
+
if json_schema is not None:
|
259
|
+
if not isinstance(json_schema, dict):
|
260
|
+
raise ValueError(
|
261
|
+
f'`json_schema` must be a dict, got {json_schema!r}.'
|
262
|
+
)
|
263
|
+
if 'title' not in json_schema:
|
264
|
+
raise ValueError(
|
265
|
+
f'The root of `json_schema` must have a `title` field, '
|
266
|
+
f'got {json_schema!r}.'
|
267
|
+
)
|
268
|
+
request_args.update(
|
269
|
+
response_format=dict(
|
270
|
+
type='json_schema',
|
271
|
+
json_schema=dict(
|
272
|
+
schema=json_schema,
|
273
|
+
name=json_schema['title'],
|
274
|
+
strict=True,
|
275
|
+
)
|
276
|
+
)
|
277
|
+
)
|
278
|
+
prompt.metadata.formatted_text = (
|
279
|
+
prompt.text
|
280
|
+
+ '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
|
281
|
+
+ pg.to_json_str(request_args['response_format'], json_indent=2)
|
282
|
+
)
|
283
|
+
|
284
|
+
# Prepare messages.
|
285
|
+
messages = []
|
286
|
+
# Users could use `metadata_system_message` to pass system message.
|
287
|
+
system_message = prompt.metadata.get('system_message')
|
288
|
+
if system_message:
|
289
|
+
system_message = lf.SystemMessage.from_value(system_message)
|
290
|
+
messages.append(
|
291
|
+
dict(role='system', content=_content_from_message(system_message))
|
292
|
+
)
|
293
|
+
messages.append(dict(role='user', content=_content_from_message(prompt)))
|
294
|
+
|
295
|
+
response = cast(
|
296
|
+
openai_object.OpenAIObject,
|
297
|
+
openai.ChatCompletion.create(messages=messages, **request_args)
|
253
298
|
)
|
254
|
-
|
299
|
+
|
255
300
|
samples = []
|
256
301
|
for choice in response.choices:
|
257
302
|
logprobs = None
|
@@ -367,8 +412,14 @@ class Gpt4o(OpenAI):
|
|
367
412
|
multimodal = True
|
368
413
|
|
369
414
|
|
415
|
+
class Gpt4o_20240806(OpenAI): # pylint:disable=invalid-name
|
416
|
+
"""GPT-4o version 2024-08-06."""
|
417
|
+
model = 'gpt-4o-2024-08-06'
|
418
|
+
multimodal = True
|
419
|
+
|
420
|
+
|
370
421
|
class Gpt4o_20240513(OpenAI): # pylint:disable=invalid-name
|
371
|
-
"""GPT-4o."""
|
422
|
+
"""GPT-4o version 2024-05-13."""
|
372
423
|
model = 'gpt-4o-2024-05-13'
|
373
424
|
multimodal = True
|
374
425
|
|
langfun/core/llms/openai_test.py
CHANGED
@@ -43,12 +43,23 @@ def mock_completion_query(prompt, *, n=1, **kwargs):
|
|
43
43
|
|
44
44
|
|
45
45
|
def mock_chat_completion_query(messages, *, n=1, **kwargs):
|
46
|
-
|
46
|
+
if len(messages) > 1:
|
47
|
+
system_message = f' system={messages[0]["content"]}'
|
48
|
+
else:
|
49
|
+
system_message = ''
|
50
|
+
|
51
|
+
if 'response_format' in kwargs:
|
52
|
+
response_format = f' format={kwargs["response_format"]["type"]}'
|
53
|
+
else:
|
54
|
+
response_format = ''
|
55
|
+
|
47
56
|
choices = []
|
48
57
|
for k in range(n):
|
49
58
|
choices.append(pg.Dict(
|
50
59
|
message=pg.Dict(
|
51
|
-
content=
|
60
|
+
content=(
|
61
|
+
f'Sample {k} for message.{system_message}{response_format}'
|
62
|
+
)
|
52
63
|
),
|
53
64
|
logprobs=None,
|
54
65
|
))
|
@@ -123,7 +134,9 @@ class OpenAITest(unittest.TestCase):
|
|
123
134
|
)
|
124
135
|
self.assertEqual(
|
125
136
|
openai.Gpt4(api_key='test_key')._get_request_args(
|
126
|
-
lf.LMSamplingOptions(
|
137
|
+
lf.LMSamplingOptions(
|
138
|
+
temperature=1.0, stop=['\n'], n=1, random_seed=123
|
139
|
+
)
|
127
140
|
),
|
128
141
|
dict(
|
129
142
|
model='gpt-4',
|
@@ -134,6 +147,7 @@ class OpenAITest(unittest.TestCase):
|
|
134
147
|
stream=False,
|
135
148
|
timeout=120.0,
|
136
149
|
stop=['\n'],
|
150
|
+
seed=123,
|
137
151
|
),
|
138
152
|
)
|
139
153
|
|
@@ -461,6 +475,52 @@ class OpenAITest(unittest.TestCase):
|
|
461
475
|
),
|
462
476
|
)
|
463
477
|
|
478
|
+
def test_call_with_system_message(self):
|
479
|
+
with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
|
480
|
+
mock_chat_completion.side_effect = mock_chat_completion_query
|
481
|
+
lm = openai.OpenAI(api_key='test_key', model='gpt-4')
|
482
|
+
self.assertEqual(
|
483
|
+
lm(
|
484
|
+
lf.UserMessage(
|
485
|
+
'hello',
|
486
|
+
system_message='hi',
|
487
|
+
),
|
488
|
+
sampling_options=lf.LMSamplingOptions(n=2)
|
489
|
+
),
|
490
|
+
'Sample 0 for message. system=hi',
|
491
|
+
)
|
492
|
+
|
493
|
+
def test_call_with_json_schema(self):
|
494
|
+
with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
|
495
|
+
mock_chat_completion.side_effect = mock_chat_completion_query
|
496
|
+
lm = openai.OpenAI(api_key='test_key', model='gpt-4')
|
497
|
+
self.assertEqual(
|
498
|
+
lm(
|
499
|
+
lf.UserMessage(
|
500
|
+
'hello',
|
501
|
+
json_schema={
|
502
|
+
'type': 'object',
|
503
|
+
'properties': {
|
504
|
+
'name': {'type': 'string'},
|
505
|
+
},
|
506
|
+
'required': ['name'],
|
507
|
+
'title': 'Person',
|
508
|
+
}
|
509
|
+
),
|
510
|
+
sampling_options=lf.LMSamplingOptions(n=2)
|
511
|
+
),
|
512
|
+
'Sample 0 for message. format=json_schema',
|
513
|
+
)
|
514
|
+
|
515
|
+
# Test bad json schema.
|
516
|
+
with self.assertRaisesRegex(ValueError, '`json_schema` must be a dict'):
|
517
|
+
lm(lf.UserMessage('hello', json_schema='foo'))
|
518
|
+
|
519
|
+
with self.assertRaisesRegex(
|
520
|
+
ValueError, 'The root of `json_schema` must have a `title` field'
|
521
|
+
):
|
522
|
+
lm(lf.UserMessage('hello', json_schema={}))
|
523
|
+
|
464
524
|
|
465
525
|
if __name__ == '__main__':
|
466
526
|
unittest.main()
|
langfun/core/llms/vertexai.py
CHANGED
@@ -18,18 +18,17 @@ import os
|
|
18
18
|
from typing import Annotated, Any
|
19
19
|
|
20
20
|
from google.auth import credentials as credentials_lib
|
21
|
-
from google.cloud.aiplatform import aiplatform
|
22
21
|
import langfun.core as lf
|
23
22
|
from langfun.core import modalities as lf_modalities
|
24
23
|
import pyglove as pg
|
25
24
|
|
26
25
|
|
27
26
|
SUPPORTED_MODELS_AND_SETTINGS = {
|
28
|
-
'gemini-1.5-pro-001': pg.Dict(api='gemini', rpm=
|
29
|
-
'gemini-1.5-flash-001': pg.Dict(api='gemini', rpm=
|
30
|
-
'gemini-1.5-pro-preview-0514': pg.Dict(api='gemini', rpm=
|
31
|
-
'gemini-1.5-pro-preview-0409': pg.Dict(api='gemini', rpm=
|
32
|
-
'gemini-1.5-flash-preview-0514': pg.Dict(api='gemini', rpm=
|
27
|
+
'gemini-1.5-pro-001': pg.Dict(api='gemini', rpm=50),
|
28
|
+
'gemini-1.5-flash-001': pg.Dict(api='gemini', rpm=200),
|
29
|
+
'gemini-1.5-pro-preview-0514': pg.Dict(api='gemini', rpm=50),
|
30
|
+
'gemini-1.5-pro-preview-0409': pg.Dict(api='gemini', rpm=50),
|
31
|
+
'gemini-1.5-flash-preview-0514': pg.Dict(api='gemini', rpm=200),
|
33
32
|
'gemini-1.0-pro': pg.Dict(api='gemini', rpm=300),
|
34
33
|
'gemini-1.0-pro-vision': pg.Dict(api='gemini', rpm=100),
|
35
34
|
# PaLM APIs.
|
@@ -136,16 +135,34 @@ class VertexAI(lf.LanguageModel):
|
|
136
135
|
)
|
137
136
|
|
138
137
|
def _generation_config(
|
139
|
-
self, options: lf.LMSamplingOptions
|
138
|
+
self, prompt: lf.Message, options: lf.LMSamplingOptions
|
140
139
|
) -> Any: # generative_models.GenerationConfig
|
141
140
|
"""Creates generation config from langfun sampling options."""
|
142
141
|
from vertexai import generative_models
|
142
|
+
# Users could use `metadata_json_schema` to pass additional
|
143
|
+
# request arguments.
|
144
|
+
json_schema = prompt.metadata.get('json_schema')
|
145
|
+
response_mime_type = None
|
146
|
+
if json_schema is not None:
|
147
|
+
if not isinstance(json_schema, dict):
|
148
|
+
raise ValueError(
|
149
|
+
f'`json_schema` must be a dict, got {json_schema!r}.'
|
150
|
+
)
|
151
|
+
response_mime_type = 'application/json'
|
152
|
+
prompt.metadata.formatted_text = (
|
153
|
+
prompt.text
|
154
|
+
+ '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
|
155
|
+
+ pg.to_json_str(json_schema, json_indent=2)
|
156
|
+
)
|
157
|
+
|
143
158
|
return generative_models.GenerationConfig(
|
144
159
|
temperature=options.temperature,
|
145
160
|
top_p=options.top_p,
|
146
161
|
top_k=options.top_k,
|
147
162
|
max_output_tokens=options.max_tokens,
|
148
163
|
stop_sequences=options.stop,
|
164
|
+
response_mime_type=response_mime_type,
|
165
|
+
response_schema=json_schema,
|
149
166
|
)
|
150
167
|
|
151
168
|
def _content_from_message(
|
@@ -239,7 +256,9 @@ class VertexAI(lf.LanguageModel):
|
|
239
256
|
input_content = self._content_from_message(prompt)
|
240
257
|
response = model.generate_content(
|
241
258
|
input_content,
|
242
|
-
generation_config=self._generation_config(
|
259
|
+
generation_config=self._generation_config(
|
260
|
+
prompt, self.sampling_options
|
261
|
+
),
|
243
262
|
)
|
244
263
|
usage_metadata = response.usage_metadata
|
245
264
|
usage = lf.LMSamplingUsage(
|
@@ -277,7 +296,8 @@ class VertexAI(lf.LanguageModel):
|
|
277
296
|
|
278
297
|
def _sample_endpoint_model(self, prompt: lf.Message) -> lf.LMSamplingResult:
|
279
298
|
"""Samples a text generation model."""
|
280
|
-
|
299
|
+
from google.cloud.aiplatform import models
|
300
|
+
model = models.Endpoint(self.endpoint_name)
|
281
301
|
# TODO(chengrun): Add support for stop_sequences.
|
282
302
|
predict_options = dict(
|
283
303
|
temperature=self.sampling_options.temperature
|
@@ -17,7 +17,7 @@ import os
|
|
17
17
|
import unittest
|
18
18
|
from unittest import mock
|
19
19
|
|
20
|
-
from google.cloud.aiplatform
|
20
|
+
from google.cloud.aiplatform import models as aiplatform_models
|
21
21
|
from vertexai import generative_models
|
22
22
|
import langfun.core as lf
|
23
23
|
from langfun.core import modalities as lf_modalities
|
@@ -175,6 +175,53 @@ class VertexAITest(unittest.TestCase):
|
|
175
175
|
del os.environ['VERTEXAI_PROJECT']
|
176
176
|
del os.environ['VERTEXAI_LOCATION']
|
177
177
|
|
178
|
+
def test_generation_config(self):
|
179
|
+
model = vertexai.VertexAIGeminiPro1()
|
180
|
+
json_schema = {
|
181
|
+
'type': 'object',
|
182
|
+
'properties': {
|
183
|
+
'name': {'type': 'string'},
|
184
|
+
},
|
185
|
+
'required': ['name'],
|
186
|
+
'title': 'Person',
|
187
|
+
}
|
188
|
+
config = model._generation_config(
|
189
|
+
lf.UserMessage('hi', json_schema=json_schema),
|
190
|
+
lf.LMSamplingOptions(
|
191
|
+
temperature=2.0,
|
192
|
+
top_p=1.0,
|
193
|
+
top_k=20,
|
194
|
+
max_tokens=1024,
|
195
|
+
stop=['\n'],
|
196
|
+
),
|
197
|
+
)
|
198
|
+
self.assertEqual(
|
199
|
+
config.to_dict(),
|
200
|
+
dict(
|
201
|
+
temperature=2.0,
|
202
|
+
top_p=1.0,
|
203
|
+
top_k=20.0,
|
204
|
+
max_output_tokens=1024,
|
205
|
+
stop_sequences=['\n'],
|
206
|
+
response_mime_type='application/json',
|
207
|
+
response_schema={
|
208
|
+
'type_': 'OBJECT',
|
209
|
+
'properties': {
|
210
|
+
'name': {'type_': 'STRING'}
|
211
|
+
},
|
212
|
+
'required': ['name'],
|
213
|
+
'title': 'Person',
|
214
|
+
}
|
215
|
+
),
|
216
|
+
)
|
217
|
+
with self.assertRaisesRegex(
|
218
|
+
ValueError, '`json_schema` must be a dict, got'
|
219
|
+
):
|
220
|
+
model._generation_config(
|
221
|
+
lf.UserMessage('hi', json_schema='not a dict'),
|
222
|
+
lf.LMSamplingOptions(),
|
223
|
+
)
|
224
|
+
|
178
225
|
def test_call_generative_model(self):
|
179
226
|
with mock.patch(
|
180
227
|
'vertexai.generative_models.'
|
@@ -244,27 +291,31 @@ class VertexAITest(unittest.TestCase):
|
|
244
291
|
|
245
292
|
def test_call_endpoint_model(self):
|
246
293
|
with mock.patch(
|
247
|
-
'google.cloud.aiplatform.
|
248
|
-
) as
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
294
|
+
'google.cloud.aiplatform.models.Endpoint.__init__'
|
295
|
+
) as mock_model_init:
|
296
|
+
mock_model_init.side_effect = lambda *args, **kwargs: None
|
297
|
+
with mock.patch(
|
298
|
+
'google.cloud.aiplatform.models.Endpoint.predict'
|
299
|
+
) as mock_model_predict:
|
300
|
+
|
301
|
+
mock_model_predict.side_effect = mock_endpoint_predict
|
302
|
+
lm = vertexai.VertexAI(
|
303
|
+
'custom',
|
304
|
+
endpoint_name='123',
|
305
|
+
project='abc',
|
306
|
+
location='us-central1',
|
307
|
+
)
|
308
|
+
self.assertEqual(
|
309
|
+
lm(
|
310
|
+
'hello',
|
311
|
+
temperature=2.0,
|
312
|
+
top_p=1.0,
|
313
|
+
top_k=20,
|
314
|
+
max_tokens=50,
|
315
|
+
),
|
316
|
+
'This is a response to hello with temperature=2.0, top_p=1.0,'
|
317
|
+
' top_k=20, max_tokens=50.',
|
318
|
+
)
|
268
319
|
|
269
320
|
|
270
321
|
if __name__ == '__main__':
|
@@ -43,16 +43,16 @@ langfun/core/coding/python/parsing.py,sha256=LMg8REP4VDY0YQjtPAGNAW4rKlMNdSXF8m1
|
|
43
43
|
langfun/core/coding/python/parsing_test.py,sha256=9vAWF484kWIm6JZq8NFiMgKUDhXV-deRl1QMmNERfAA,7386
|
44
44
|
langfun/core/coding/python/permissions.py,sha256=1QWGHvzL8MM0Ok_auQ9tURqZHtdOfJaDpBzZ29GUE-c,2544
|
45
45
|
langfun/core/coding/python/permissions_test.py,sha256=w5EDb8QxpxgJyZkojyzVWQvDfg366zn99-g__6TbPQ0,2699
|
46
|
-
langfun/core/eval/__init__.py,sha256=
|
47
|
-
langfun/core/eval/base.py,sha256=
|
48
|
-
langfun/core/eval/base_test.py,sha256=
|
46
|
+
langfun/core/eval/__init__.py,sha256=Ogdr9OtTywhhLPHi3AZzOD2mXX2oyaHWflrSTMm96uA,1899
|
47
|
+
langfun/core/eval/base.py,sha256=0_iaKuQhS49PlbWqCQ5EABUMKavr2R4ltcJZWCVoZZg,73816
|
48
|
+
langfun/core/eval/base_test.py,sha256=p1EfqviHMz_ppQY8FU67h5OCgL0tzhLvXzGIsq0sVyI,26930
|
49
49
|
langfun/core/eval/matching.py,sha256=9GX8HfO9jKxgNLAivgy5K88Xhoh6Z7Pptq65pe7vht8,9762
|
50
50
|
langfun/core/eval/matching_test.py,sha256=f7iVyXH5KGJBWt4Wp14Bt9J3X59A6Ayfog9MbuFvPew,5532
|
51
51
|
langfun/core/eval/patching.py,sha256=R0s2eAd1m97exQt06dmUL0V_MBG0W2Hxg7fhNB7cXW0,3866
|
52
52
|
langfun/core/eval/patching_test.py,sha256=8kCd54Egjju22FMgtJuxEsrXkW8ifs-UUBHtrCG1L6w,4775
|
53
53
|
langfun/core/eval/scoring.py,sha256=AlCwEVrU6nvURDB1aPxA2XBUmOjWxuNJDXJoS4-6VbU,6386
|
54
54
|
langfun/core/eval/scoring_test.py,sha256=O8olHbrUEg60gMxwOkWzKBJZpZoUlmVnBANX5Se2SXM,4546
|
55
|
-
langfun/core/llms/__init__.py,sha256=
|
55
|
+
langfun/core/llms/__init__.py,sha256=a1AV3XWi2gY4UvmmaPP1GapaQxygA6xzQJvVQRp6EPA,4818
|
56
56
|
langfun/core/llms/anthropic.py,sha256=Gon3fOi31RhZFgNd0ijyTnKnUdp9hrWrCoSXyO4UaLw,7316
|
57
57
|
langfun/core/llms/anthropic_test.py,sha256=T-swuMkfnlgs8Fpif4rtXs579exGk0TsbLMirXDZCkg,5533
|
58
58
|
langfun/core/llms/fake.py,sha256=Dd7-6ka9pFf3fcWZyczamjOqQ91MOI-m7We3Oc9Ffmo,2927
|
@@ -63,12 +63,12 @@ langfun/core/llms/groq.py,sha256=pqtyOZ_1_OJMOg8xATWT_B_SVbuT9nMRf4VkH9GzW8g,630
|
|
63
63
|
langfun/core/llms/groq_test.py,sha256=GYF_Qtq5S1H1TrKH38t6_lkdroqT7v-joYLDKnmS9e0,5274
|
64
64
|
langfun/core/llms/llama_cpp.py,sha256=9tXQntSCDtjTF3bnyJrAPCr4N6wycy5nXYvp9uduygE,2843
|
65
65
|
langfun/core/llms/llama_cpp_test.py,sha256=MWO_qaOeKjRniGjcaWPDScd7HPaIJemqUZoslrt4FPs,1806
|
66
|
-
langfun/core/llms/openai.py,sha256=
|
67
|
-
langfun/core/llms/openai_test.py,sha256=
|
66
|
+
langfun/core/llms/openai.py,sha256=lRfR2iim7OdzMCJCf1DXB5YVfSwflvUucMWY3dsMaRA,15798
|
67
|
+
langfun/core/llms/openai_test.py,sha256=02KeysRppXcAwA4SBUow8hKolFiEU9_lTSdlVHZletM,17518
|
68
68
|
langfun/core/llms/rest.py,sha256=laopuq-zD8V-3Y6eFDngftHEbE66VlUkCD2-rvvRaLU,3388
|
69
69
|
langfun/core/llms/rest_test.py,sha256=NZ3Nf0XQVpT9kLP5cBVo_yBHLI7vWTYhWQxYEJVMGs4,3472
|
70
|
-
langfun/core/llms/vertexai.py,sha256
|
71
|
-
langfun/core/llms/vertexai_test.py,sha256=
|
70
|
+
langfun/core/llms/vertexai.py,sha256=tXAnP357XhcsETTnk6M-hH4xyFi7tk6fsaf3tjzsY6E,14501
|
71
|
+
langfun/core/llms/vertexai_test.py,sha256=EPR-mB2hNUpvpf7E8m_k5bh04epdQTVUuYU6hPgZyu8,10321
|
72
72
|
langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
|
73
73
|
langfun/core/llms/cache/base.py,sha256=cFfYvOIUae842pncqCAsRvqXCk2AnAsRYVx0mcIoAeY,3338
|
74
74
|
langfun/core/llms/cache/in_memory.py,sha256=YfFyJEhLs73cUiB0ZfhMxYpdE8Iuxxw-dvMFwGHTSHw,4742
|
@@ -117,8 +117,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
|
|
117
117
|
langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
|
118
118
|
langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
|
119
119
|
langfun/core/templates/selfplay_test.py,sha256=rBW2Qr8yi-aWYwoTwRR-n1peKyMX9QXPZXURjLgoiRs,2264
|
120
|
-
langfun-0.1.1.
|
121
|
-
langfun-0.1.1.
|
122
|
-
langfun-0.1.1.
|
123
|
-
langfun-0.1.1.
|
124
|
-
langfun-0.1.1.
|
120
|
+
langfun-0.1.1.dev20240817.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
121
|
+
langfun-0.1.1.dev20240817.dist-info/METADATA,sha256=_j3RVHGW7vU400lTH9B4ncDzlqEmCHOs4GfpwO96quE,5234
|
122
|
+
langfun-0.1.1.dev20240817.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
|
123
|
+
langfun-0.1.1.dev20240817.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
124
|
+
langfun-0.1.1.dev20240817.dist-info/RECORD,,
|
File without changes
|
File without changes
|