langfun 0.0.2.dev20240430__py3-none-any.whl → 0.0.2.dev20240501__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/eval/__init__.py +14 -1
- langfun/core/eval/base.py +180 -32
- langfun/core/eval/base_test.py +89 -8
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- {langfun-0.0.2.dev20240430.dist-info → langfun-0.0.2.dev20240501.dist-info}/METADATA +1 -2
- {langfun-0.0.2.dev20240430.dist-info → langfun-0.0.2.dev20240501.dist-info}/RECORD +10 -8
- {langfun-0.0.2.dev20240430.dist-info → langfun-0.0.2.dev20240501.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240430.dist-info → langfun-0.0.2.dev20240501.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240430.dist-info → langfun-0.0.2.dev20240501.dist-info}/top_level.txt +0 -0
langfun/core/eval/__init__.py
CHANGED
@@ -16,7 +16,11 @@
|
|
16
16
|
# pylint: disable=g-importing-member
|
17
17
|
# pylint: disable=g-bad-import-order
|
18
18
|
|
19
|
-
from langfun.core.eval.base import
|
19
|
+
from langfun.core.eval.base import register
|
20
|
+
from langfun.core.eval.base import registered_names
|
21
|
+
from langfun.core.eval.base import get_evaluation
|
22
|
+
from langfun.core.eval.base import get
|
23
|
+
from langfun.core.eval.base import run
|
20
24
|
|
21
25
|
from langfun.core.eval.base import Evaluable
|
22
26
|
from langfun.core.eval.base import Evaluation
|
@@ -34,6 +38,15 @@ from langfun.core.eval.base import as_inputs
|
|
34
38
|
from langfun.core.eval.matching import Matching
|
35
39
|
from langfun.core.eval.scoring import Scoring
|
36
40
|
|
41
|
+
# Experiment patching.
|
42
|
+
from langfun.core.eval.patching import patch_member
|
37
43
|
|
44
|
+
from langfun.core.eval.patching import patch_lm
|
45
|
+
from langfun.core.eval.patching import patch_parsing_lm
|
46
|
+
from langfun.core.eval.patching import patch_inputs
|
47
|
+
from langfun.core.eval.patching import patch_prompt
|
48
|
+
from langfun.core.eval.patching import patch_schema_fn
|
49
|
+
|
50
|
+
# Placeholder for Google-internal imports.
|
38
51
|
# pylint: enable=g-bad-import-order
|
39
52
|
# pylint: enable=g-importing-member
|
langfun/core/eval/base.py
CHANGED
@@ -25,10 +25,9 @@ import os
|
|
25
25
|
import re
|
26
26
|
import threading
|
27
27
|
import time
|
28
|
+
import types
|
28
29
|
from typing import Annotated, Any, Callable, Iterator, Literal, Optional, Sequence, Type, Union
|
29
30
|
|
30
|
-
from absl import app
|
31
|
-
from absl import flags
|
32
31
|
import langfun.core as lf
|
33
32
|
import langfun.core.coding as lf_coding
|
34
33
|
from langfun.core.llms.cache import in_memory
|
@@ -600,7 +599,6 @@ class _LeafNode:
|
|
600
599
|
@pg.use_init_args(['children'])
|
601
600
|
class Suite(Evaluable):
|
602
601
|
"""Evaluation suite."""
|
603
|
-
|
604
602
|
children: Annotated[list[Evaluable], 'Child evaluation sets or suites.']
|
605
603
|
|
606
604
|
# Use empty ID as suite is just a container of child evaluations.
|
@@ -2146,41 +2144,191 @@ def monitor_async(
|
|
2146
2144
|
)
|
2147
2145
|
|
2148
2146
|
|
2149
|
-
|
2150
|
-
|
2147
|
+
#
|
2148
|
+
# Named evaluations and experiments support.
|
2149
|
+
#
|
2151
2150
|
|
2152
|
-
Args:
|
2153
|
-
target: An Langfun evaluable object.
|
2154
|
-
"""
|
2155
|
-
flags.DEFINE_string(
|
2156
|
-
'root_dir', None, 'Root directory for running the evaluation.'
|
2157
|
-
)
|
2158
2151
|
|
2159
|
-
|
2160
|
-
|
2161
|
-
)
|
2152
|
+
class _NamedEvaluationRegistry:
  """Registry that maps evaluation names to `Evaluable` classes."""

  def __init__(self):
    # Maps a registered name to the class registered under it.
    self._registry = {}

  def names(self) -> list[str]:
    """Returns every registered name, sorted alphabetically."""
    return sorted(self._registry)

  def get(self, name: str) -> Type[Evaluable]:
    """Returns the class registered under `name`.

    Args:
      name: The name used at registration time.

    Returns:
      The registered class.

    Raises:
      ValueError: If nothing has been registered under `name`.
    """
    if name in self._registry:
      return self._registry[name]
    raise ValueError(
        f'Evaluation {name!r} not found. '
        'Did you forget to import the module that registers it?'
    )

  def register(
      self,
      name: str,
      experiment_cls: Type[Evaluable],
  ):
    """Associates `experiment_cls` with `name`, replacing any prior entry."""
    self._registry[name] = experiment_cls
|
2178
|
+
|
2179
|
+
|
2180
|
+
_eval_registry = _NamedEvaluationRegistry()
|
2172
2181
|
|
2173
|
-
FLAGS = flags.FLAGS # pylint: disable=invalid-name
|
2174
2182
|
|
2175
|
-
|
2176
|
-
|
2177
|
-
|
2183
|
+
def registered_names() -> list[str]:
  """Returns the names of all evaluations in the module-level registry."""
  names = _eval_registry.names()
  return names
|
2178
2186
|
|
2179
|
-
|
2180
|
-
|
2181
|
-
|
2182
|
-
|
2187
|
+
|
2188
|
+
def get_evaluation(evaluation: str | Evaluable) -> Evaluable:
  """Resolves an evaluation given either its registered name or an object.

  Args:
    evaluation: A registered evaluation name, or an `Evaluable` object.

  Returns:
    A new instance of the class registered under the name (created with no
    arguments) when a string is given; otherwise `evaluation` itself.
  """
  if not isinstance(evaluation, str):
    return evaluation
  evaluation_cls = _eval_registry.get(evaluation)
  return evaluation_cls()
|
2193
|
+
|
2194
|
+
|
2195
|
+
def register(name: str):
  """Decorator to create a named evaluation class.

  The decorated object may be either:

  * an `Evaluable` subclass, which is registered as is; or
  * a function taking no arguments and returning an `Evaluable` instance. The
    function is called once at decoration time and a `Suite` subclass is
    generated whose `children` defaults to the returned suite's children (or
    to a one-element list holding the returned evaluation).

  Args:
    name: The name under which the evaluation is registered.

  Returns:
    A decorator that registers its argument under `name` and returns the
    registered class (for functions, the generated `Suite` subclass).
  """

  def _register(func_or_cls: Type[Evaluation] | types.FunctionType):
    if inspect.isfunction(func_or_cls):
      e = func_or_cls()
      if not isinstance(e, Evaluable):
        raise TypeError(
            f'The return value of `{func_or_cls}` should be an instance of '
            '`lf.eval.Evaluable` subclass.'
        )

      class GeneratedSuite(Suite):
        # NOTE(daiyip): Delay serialization key registration for generated
        # class.
        auto_register = False
        children = e.children if isinstance(e, Suite) else [e]

      # Make the generated class look like the decorated function before
      # registering it for deserialization.
      cls = GeneratedSuite
      cls.__name__ = func_or_cls.__name__
      cls.__doc__ = func_or_cls.__doc__
      cls.__qualname__ = func_or_cls.__qualname__
      cls.__module__ = getattr(func_or_cls, '__module__', 'wrapper')
      cls.register_for_deserialization(cls.__type_name__)

    # `issubclass` raises TypeError for non-class arguments, so check
    # `isclass` first to report unsupported objects with the ValueError below.
    elif inspect.isclass(func_or_cls) and issubclass(func_or_cls, Evaluable):
      cls = func_or_cls
    else:
      raise ValueError(f'Unsupported type: {type(func_or_cls)}')

    _eval_registry.register(name, cls)
    return cls

  return _register
|
2229
|
+
|
2230
|
+
|
2231
|
+
def get(
    root_dir: str,
    evaluations: list[str | Evaluable],
    filter: Union[                    # pylint: disable=redefined-builtin
        str,                          # Regex to filter evaluation based on ID.
        Callable[[Evaluable], bool],  # Custom filter function.
        None                          # No filtering (Default).
    ] = None,  # pylint: disable=bad-whitespace
    patches: list[Union[
        str,                                    # String-based PyGlove patcher.
        pg.patching.Patcher,                    # PyGlove patcher object.
        Callable[[pg.KeyPath, Any, Any], Any],  # PyGlove rebind function.
    ]] | None = None,  # pylint: disable=bad-whitespace
) -> Suite:
  """Builds a suite out of named or concrete evaluations.

  Patches are applied to the whole suite first; the filter is then evaluated
  against the leaf nodes of the patched suite.

  Args:
    root_dir: The root directory of the experiment.
    evaluations: A list of evaluations to be included in the suite. Each item
      is either a registered name or an `Evaluable` object.
    filter: A regular expression (str) matched against the beginning of each
      leaf evaluation's ID, or a predicate on the leaf evaluation. When
      omitted, no filtering takes place.
    patches: A list of patches to be applied to the suite. Each element can be
      a string (for string-based patcher), a `pg.patching.Patcher` object, or
      a rebind function (e.g. `pg.rebind`). See `lf.eval.patch_*` for more
      details.

  Returns:
    A suite of selected `lf.eval.Evaluation` objects.
  """
  suite = Suite([get_evaluation(e) for e in evaluations], root_dir=root_dir)
  if patches:
    suite = pg.patch(suite, patches)

  # Normalize a regex string into a predicate over leaf evaluations.
  if isinstance(filter, str):
    pattern = re.compile(filter)
    filter = lambda x: bool(pattern.match(x.id))

  if not filter:
    return suite

  selected = [leaf for leaf in suite.leaf_nodes if filter(leaf)]
  return Suite(selected, root_dir=root_dir)
|
2273
|
+
|
2274
|
+
|
2275
|
+
def run(
    root_dir: str,
    evaluations: list[str | Evaluable],
    filter: Union[                    # pylint: disable=redefined-builtin
        str,                          # Regex to filter evaluation based on ID.
        Callable[[Evaluable], bool],  # Custom filter function.
        None                          # No filtering (Default).
    ] = None,  # pylint: disable=bad-whitespace
    patches: list[Union[
        str,                                    # String-based PyGlove patcher.
        pg.patching.Patcher,                    # PyGlove patcher object.
        Callable[[pg.KeyPath, Any, Any], Any],  # PyGlove rebind function.
    ]] | None = None,  # pylint: disable=bad-whitespace
    mode: Literal['run', 'rerun', 'dryrun', 'noop'] = 'run',
    debug: bool = False,
    print_definition: bool = False,
    **kwargs,
) -> Suite:
  """Run selected evaluations with patching.

  Args:
    root_dir: The root directory of the experiment.
    evaluations: A list of evaluations to be included in the suite.
    filter: A regular expression (str) for selecting sub-experiments of matched
      IDs, or a filter function to filter the evaluations.
    patches: A list of patches to be applied to the suite. Each element can be
      a string (for string-based patcher), a `pg.patching.Patcher` object, or
      a rebind function (e.g. `pg.rebind`). See `lf.eval.patch_*` for more
      details.
    mode: The mode to run the suite. "run" to run the suite, with reusing
      existing results if available; "rerun" to rerun all evaluations even if
      there are existing results; "dryrun" to dryrun the suite; and "noop"
      to do nothing.
    debug: Whether to run in debug mode.
    print_definition: Whether to print the experiment definition.
    **kwargs: Additional arguments to be passed to dryrun/run the suite.

  Returns:
    A suite of selected `lf.eval.Evaluation` objects.
  """
  suite = get(root_dir, evaluations, patches=patches, filter=filter)
  if print_definition:
    lf.console.write(
        pg.format(
            suite,
            compact=False,
            verbose=False,
            hide_default_values=True,
            python_format=True,
        ),
        title='[EXPERIMENT DEFINITION]',
        color='blue',
    )

  # Both 'run' and 'rerun' execute the suite; they differ only in whether
  # existing results are reused.
  if mode in ('run', 'rerun'):
    suite.run(debug=debug, rerun=(mode == 'rerun'), **kwargs)
  elif mode == 'dryrun':
    suite.dryrun(debug=debug, **kwargs)
  # Any other mode (e.g. 'noop') only builds and returns the suite.
  return suite
|
langfun/core/eval/base_test.py
CHANGED
@@ -749,16 +749,97 @@ class SummaryTest(unittest.TestCase):
|
|
749
749
|
self.assertTrue(pg.io.path_exists(summary_file))
|
750
750
|
|
751
751
|
|
752
|
-
class
|
752
|
+
class NamedEvaluationTest(unittest.TestCase):
  """Tests for `base.register`, `base.get_evaluation` and `base.run`."""

  def test_named_eval_class(self):
    """Registers an `Evaluation` subclass and looks it up by name."""

    @base.register('named_eval/class_test')
    class MyEval(base.Evaluation):
      inputs = base.as_inputs([
          pg.Dict(question='Compute 1 + 1'),
      ])
      method = 'query'
      prompt = pg.oneof([
          lf.Template('{{example.question}}'),
          lf.Template('Hello {{example.question}}'),
      ])
      schema_fn = answer_schema()

    evaluation = base.get_evaluation('named_eval/class_test')
    self.assertIsInstance(evaluation, MyEval)
    self.assertIsNone(evaluation.dir)
    self.assertIsNone(evaluation.root_dir)
    self.assertIn('named_eval/class_test', base.registered_names())

    # A class that is not an `Evaluable` subclass is rejected.
    with self.assertRaisesRegex(ValueError, 'Unsupported type.*'):
      @base.register('named_eval/bad_class')
      class Foo:  # pylint: disable=unused-variable
        pass

  def test_named_eval_functor(self):
    """Registers a function returning an evaluation and looks it up."""

    @base.register('named_eval/functor_test')
    def my_eval():
      return base.Evaluation(
          inputs=base.as_inputs([
              pg.Dict(question='Compute 1 + 1'),
          ]),
          method='query',
          prompt=pg.oneof([
              lf.Template('{{example.question}}'),
              lf.Template('Hello {{example.question}}'),
          ]),
          schema_fn=answer_schema(),
      )

    # The decorator replaces the function with a generated class.
    self.assertTrue(issubclass(my_eval, base.Evaluable))
    evaluation = base.get_evaluation('named_eval/functor_test')
    self.assertIn('named_eval/functor_test', base.registered_names())
    self.assertIsInstance(evaluation, my_eval)
    self.assertIsNone(evaluation.root_dir)

    with self.assertRaisesRegex(ValueError, 'Evaluation .* not found'):
      base.get_evaluation('named_eval/non_existent')

    # The decorated function must return an `Evaluable`.
    with self.assertRaisesRegex(TypeError, 'The return value .*'):
      @base.register('named_eval/bad_return_type')
      def bad_eval():  # pylint: disable=unused-variable
        return 1

  def test_run(self):
    """Runs a named evaluation by name and a patched one by object."""

    @base.register('test/run')
    def test_run():  # pylint: disable=unused-variable
      lm = fake.StaticResponse('Solution(final_answer=2)')
      return eval_set('run_test', 'query', schema_fn=answer_schema(), lm=lm)

    e = base.run(
        tempfile.gettempdir(),
        ['test/run'],
        id_regex='run_test.*',
        mode='dryrun',
        print_definition=True,
    )
    self.assertEqual(
        e.leaf_nodes[0].dir,
        os.path.join(tempfile.gettempdir(), e.leaf_nodes[0].id),
    )
    self.assertTrue(
        pg.eq(
            e.leaf_nodes[0].lm, fake.StaticResponse('Solution(final_answer=2)')
        )
    )

    # String patches are resolved by name, so define a patcher named `bad_lm`.
    @pg.patcher()
    def bad_lm(unused_eval):  # pylint: disable=unused-variable
      return dict(lm=fake.StaticResponse('efg'))

    e = base.run(
        tempfile.gettempdir(),
        [test_run()],
        filter='Evaluation.*',
        patches=['bad_lm']
    )
    self.assertTrue(pg.eq(e.leaf_nodes[0].lm, fake.StaticResponse('efg')))
|
762
843
|
|
763
844
|
|
764
845
|
if __name__ == '__main__':
|
@@ -0,0 +1,130 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Experiment patching for Langfun evaluations."""
|
15
|
+
|
16
|
+
import inspect
|
17
|
+
from typing import Union
|
18
|
+
import langfun.core as lf
|
19
|
+
from langfun.core import llms as lf_llms
|
20
|
+
from langfun.core.eval import base
|
21
|
+
import pyglove as pg
|
22
|
+
|
23
|
+
|
24
|
+
#
|
25
|
+
# Program-based patchers.
|
26
|
+
#
|
27
|
+
|
28
|
+
|
29
|
+
def patch_member(cls, key, value, parent_key: str | None = None):
|
30
|
+
"""Patches a member of a class."""
|
31
|
+
|
32
|
+
def _rebind_fn(k, v, p):
|
33
|
+
if (
|
34
|
+
isinstance(p, cls)
|
35
|
+
and k.key == key
|
36
|
+
and (parent_key is None or (p and p.sym_path.key == parent_key))
|
37
|
+
):
|
38
|
+
if inspect.isfunction(value):
|
39
|
+
return value(k, v, p)
|
40
|
+
return value
|
41
|
+
return v
|
42
|
+
|
43
|
+
return _rebind_fn
|
44
|
+
|
45
|
+
|
46
|
+
def patch_lm(lm: Union[lf.LanguageModel, pg.hyper.OneOf]):  # pylint: disable=redefined-outer-name
  """Returns a rebind function that sets `lm` on `Evaluable` objects."""
  return patch_member(base.Evaluable, key="lm", value=lm)
|
49
|
+
|
50
|
+
|
51
|
+
def patch_parsing_lm(lm: Union[lf.LanguageModel, pg.hyper.OneOf]):  # pylint: disable=redefined-outer-name
  """Returns a rebind function that sets `parsing_lm` on `Evaluable` objects."""
  return patch_member(base.Evaluable, key="parsing_lm", value=lm)
|
54
|
+
|
55
|
+
|
56
|
+
def patch_schema_fn(schema_fn: Union[pg.Functor, pg.hyper.OneOf]):
  """Returns a rebind function that sets `schema_fn` on `Evaluable` objects."""
  return patch_member(base.Evaluable, key="schema_fn", value=schema_fn)
|
59
|
+
|
60
|
+
|
61
|
+
def patch_prompt(prompt: Union[str, lf.Template, pg.hyper.OneOf]):
  """Returns a rebind function that sets `prompt` on `Evaluable` objects."""
  return patch_member(base.Evaluable, key="prompt", value=prompt)
|
64
|
+
|
65
|
+
|
66
|
+
def patch_inputs(inputs: Union[pg.Functor, pg.hyper.OneOf]):
  """Returns a rebind function that sets `inputs` on `Evaluable` objects."""
  return patch_member(base.Evaluable, key="inputs", value=inputs)
|
69
|
+
|
70
|
+
|
71
|
+
def patch_additional_args(**kwargs):
  """Returns a rebind function that merges `kwargs` into `additional_args`.

  Existing entries are kept unless overridden by a keyword of the same name.

  Args:
    **kwargs: Entries to add to (or override in) `additional_args`.

  Returns:
    A rebind function targeting `additional_args` of `Evaluable` objects.
  """

  def value_fn(k, unused_v, p):
    # We infer the symbolic value for the old args, as it might be a
    # contextual attribute referring to its containing object.
    existing = p.sym_inferred(k.key)
    if not existing:
      return kwargs
    merged = dict(existing)
    merged.update(kwargs)
    return merged

  return patch_member(base.Evaluable, "additional_args", value_fn)
|
85
|
+
|
86
|
+
|
87
|
+
#
|
88
|
+
# String-based patching.
|
89
|
+
#
|
90
|
+
|
91
|
+
# Maps lower-case model names accepted by string-based patching (e.g.
# `lm?haiku:gpt4`) to the language model class to instantiate. Short and long
# aliases of the same model must point to the same class.
_NAMED_MODELS = {
    # GPT models.
    "gpt35turbo": lf_llms.Gpt35Turbo,
    "gpt35turbo16k": lf_llms.Gpt35Turbo16K,
    "gpt4": lf_llms.Gpt4,
    "gpt4turbo": lf_llms.Gpt4Turbo,
    # Anthropic models.
    "haiku": lf_llms.Claude3Haiku,
    "claude3haiku": lf_llms.Claude3Haiku,
    "opus": lf_llms.Claude3Opus,
    "claude3opus": lf_llms.Claude3Opus,
    "sonnet": lf_llms.Claude3Sonnet,
    "claude3sonnet": lf_llms.Claude3Sonnet,
}
|
105
|
+
|
106
|
+
|
107
|
+
def model_by_name(name: str) -> lf.LanguageModel:
  """Instantiates the language model registered under `name`.

  Args:
    name: A model name; surrounding whitespace and letter case are ignored.

  Returns:
    A new instance of the matching model class, created with no arguments.

  Raises:
    ValueError: If the normalized name is not in `_NAMED_MODELS`.
  """
  name = name.strip().lower()
  model_cls = _NAMED_MODELS.get(name)
  if model_cls is None:
    raise ValueError(f"Unknown model name: {name}")
  return model_cls()
|
113
|
+
|
114
|
+
|
115
|
+
@pg.patcher(auto_typing=True)
def lm(unused_eval, models: list[str]):
  """Patch the LM used for benchmarking."""
  candidates = [model_by_name(name) for name in models]
  return patch_lm(pg.oneof(candidates))
|
119
|
+
|
120
|
+
|
121
|
+
@pg.patcher(auto_typing=True)
def temperature(unused_eval, value: float):
  """Patch the temperature used for benchmarking."""
  return patch_member(lf.LMSamplingOptions, key="temperature", value=value)
|
125
|
+
|
126
|
+
|
127
|
+
@pg.patcher(auto_typing=True)
def max_tokens(unused_eval, value: int | None):
  """Patch the max tokens used for benchmarking."""
  return patch_member(lf.LMSamplingOptions, "max_tokens", value)
|
@@ -0,0 +1,170 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Tests for evaluation patching."""
|
15
|
+
|
16
|
+
import unittest
|
17
|
+
from langfun.core import llms as lf_llms
|
18
|
+
from langfun.core.eval import base
|
19
|
+
from langfun.core.eval import patching
|
20
|
+
import pyglove as pg
|
21
|
+
|
22
|
+
|
23
|
+
class PatchingCommonTest(unittest.TestCase):
  """Tests for the program-based `patching.patch_*` helpers."""

  def test_patch_member(self):
    """Checks that `patch_member` replaces a field of a nested object."""
    class A(pg.Object):
      x: int = 1

    class B(pg.Object):
      a: A

    b = B(A())
    # The assertion below inspects `b` itself, not the return value.
    pg.patch(b, [patching.patch_member(A, 'x', 2)])
    self.assertEqual(b, B(A(2)))

  def test_patch_args(self):
    """Checks that `patch_additional_args` merges into existing args."""
    s = base.Suite(
        [base.Evaluation(inputs=base.as_inputs([1]))],
        additional_args=dict(x=1, y=2),
    )
    pg.patch(s, [patching.patch_additional_args(x=3, z=4)])
    # `x` is overridden, `y` is kept and `z` is added, on both the suite and
    # its child evaluation.
    self.assertTrue(
        pg.eq(
            s,
            base.Suite(
                [
                    base.Evaluation(
                        inputs=base.as_inputs([1]),
                        additional_args=dict(x=3, y=2, z=4),
                    )
                ],
                additional_args=dict(x=3, y=2, z=4),
            ),
        )
    )

  def test_patch_lm(self):
    """Checks that `patch_lm` replaces `lm` on the suite and its child."""
    s = base.Suite(
        [base.Evaluation(inputs=base.as_inputs([1]))],
        lm=lf_llms.Gpt35Turbo(),
    )
    pg.patch(
        s, [patching.patch_lm(pg.oneof([lf_llms.Gpt35Turbo(), lf_llms.Gpt4()]))]
    )
    self.assertTrue(
        pg.eq(
            s,
            base.Suite(
                [
                    base.Evaluation(
                        inputs=base.as_inputs([1]),
                        lm=pg.oneof([lf_llms.Gpt35Turbo(), lf_llms.Gpt4()]),
                    )
                ],
                lm=pg.oneof([lf_llms.Gpt35Turbo(), lf_llms.Gpt4()]),
            ),
        )
    )

  def test_patch_parsing_lm(self):
    """Checks that `patch_parsing_lm` sets `parsing_lm` on the evaluation."""
    s = base.Suite(
        [base.Evaluation(inputs=base.as_inputs([1]))],
        lm=lf_llms.Gpt4(),
    )
    pg.patch(s, [patching.patch_parsing_lm(lf_llms.Gpt35Turbo())])
    self.assertTrue(
        pg.eq(
            s,
            base.Suite(
                [
                    base.Evaluation(
                        inputs=base.as_inputs([1]),
                        lm=lf_llms.Gpt4(),
                        parsing_lm=lf_llms.Gpt35Turbo(),
                    )
                ],
                # NOTE(daiyip): Suite does not have `parsing_lm` as one of its
                # variable keyword fields yet, so patching does not add to it.
                # This is okay since we only care about the leaf nodes.
                lm=lf_llms.Gpt4(),
            ),
        )
    )

  def test_patch_prompt(self):
    """Checks that `patch_prompt` sets the prompt of an evaluation."""
    e = base.Evaluation(inputs=base.as_inputs([1]))
    pg.patch(e, [patching.patch_prompt('Q: {{example.question}}')])
    self.assertTrue(
        pg.eq(
            e,
            base.Evaluation(
                inputs=base.as_inputs([1]),
                prompt='Q: {{example.question}}',
            ),
        )
    )

  def test_patch_inputs(self):
    """Checks that `patch_inputs` replaces the inputs of an evaluation."""
    e = base.Evaluation(inputs=base.as_inputs([1]))
    pg.patch(e, [patching.patch_inputs(base.as_inputs([2]))])
    self.assertTrue(
        pg.eq(
            e,
            base.Evaluation(
                inputs=base.as_inputs([2]),
            ),
        )
    )

  def test_patch_schema_fn(self):
    """Checks that `patch_schema_fn` sets the schema_fn of an evaluation."""
    @pg.functor()
    def int_schema():
      return int

    e = base.Evaluation(inputs=base.as_inputs([1]))
    pg.patch(e, [patching.patch_schema_fn(int_schema())])
    self.assertTrue(
        pg.eq(
            e,
            base.Evaluation(
                inputs=base.as_inputs([1]),
                schema_fn=int_schema(),
            ),
        )
    )
|
146
|
+
|
147
|
+
|
148
|
+
class StringPatcheTest(unittest.TestCase):
  """Tests for the string-based patchers (`lm`, `max_tokens`, `temperature`)."""

  def test_lm(self):
    """Applies string patches and checks the resulting `lm` search space."""
    target = pg.patch(
        base.Evaluation(inputs=base.as_inputs([1])),
        ['lm?haiku:gpt4', 'max_tokens?1024', 'temperature?0.7'],
    )
    # Both candidate models are expected to carry the patched sampling options.
    self.assertEqual(
        target.lm,
        pg.oneof([
            lf_llms.Claude3Haiku(temperature=0.7, max_tokens=1024),
            lf_llms.Gpt4(temperature=0.7, max_tokens=1024),
        ]),
    )
    # A name missing from the named-model table is rejected.
    with self.assertRaisesRegex(ValueError, 'Unknown model name'):
      pg.patch(
          base.Evaluation(inputs=base.as_inputs([1])),
          ['lm?gpt2'],
      )
|
167
|
+
|
168
|
+
|
169
|
+
if __name__ == '__main__':
|
170
|
+
unittest.main()
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: langfun
|
3
|
-
Version: 0.0.2.
|
3
|
+
Version: 0.0.2.dev20240501
|
4
4
|
Summary: Langfun: Language as Functions.
|
5
5
|
Home-page: https://github.com/google/langfun
|
6
6
|
Author: Langfun Authors
|
@@ -21,7 +21,6 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
21
21
|
Classifier: Topic :: Software Development :: Libraries
|
22
22
|
Description-Content-Type: text/markdown
|
23
23
|
License-File: LICENSE
|
24
|
-
Requires-Dist: absl-py >=1.0.0
|
25
24
|
Requires-Dist: google-generativeai >=0.3.2
|
26
25
|
Requires-Dist: jinja2 >=3.1.2
|
27
26
|
Requires-Dist: openai ==0.27.2
|
@@ -39,11 +39,13 @@ langfun/core/coding/python/parsing.py,sha256=uyvI1c5OLZhMVK2Oltkl3oJxSLlG0wadlpQ
|
|
39
39
|
langfun/core/coding/python/parsing_test.py,sha256=9vAWF484kWIm6JZq8NFiMgKUDhXV-deRl1QMmNERfAA,7386
|
40
40
|
langfun/core/coding/python/permissions.py,sha256=1QWGHvzL8MM0Ok_auQ9tURqZHtdOfJaDpBzZ29GUE-c,2544
|
41
41
|
langfun/core/coding/python/permissions_test.py,sha256=w5EDb8QxpxgJyZkojyzVWQvDfg366zn99-g__6TbPQ0,2699
|
42
|
-
langfun/core/eval/__init__.py,sha256=
|
43
|
-
langfun/core/eval/base.py,sha256=
|
44
|
-
langfun/core/eval/base_test.py,sha256=
|
42
|
+
langfun/core/eval/__init__.py,sha256=Evt-E4FEhZF2tXL6-byh_AyA7Cc_ZoGmvnN7vkAZedk,1898
|
43
|
+
langfun/core/eval/base.py,sha256=VgHdnfkHeGPp0XjIGHw9LDZsR0Z4-yuWIkzn4pqJj3Y,73967
|
44
|
+
langfun/core/eval/base_test.py,sha256=cHOTIWVW4Dp8gKKIKcZrAcJ-w84j2GIozTzJoiAX7p4,26743
|
45
45
|
langfun/core/eval/matching.py,sha256=Y4vFoNTQEOwko6IA8l9OZ52-vt52e3VGmcTtvLA67wM,9782
|
46
46
|
langfun/core/eval/matching_test.py,sha256=f7iVyXH5KGJBWt4Wp14Bt9J3X59A6Ayfog9MbuFvPew,5532
|
47
|
+
langfun/core/eval/patching.py,sha256=R0s2eAd1m97exQt06dmUL0V_MBG0W2Hxg7fhNB7cXW0,3866
|
48
|
+
langfun/core/eval/patching_test.py,sha256=8kCd54Egjju22FMgtJuxEsrXkW8ifs-UUBHtrCG1L6w,4775
|
47
49
|
langfun/core/eval/scoring.py,sha256=1J7IATo-8FXUR0SBqk9icztHiM0lWkBFcWUo-vUURgQ,6376
|
48
50
|
langfun/core/eval/scoring_test.py,sha256=O8olHbrUEg60gMxwOkWzKBJZpZoUlmVnBANX5Se2SXM,4546
|
49
51
|
langfun/core/llms/__init__.py,sha256=1bPg1QI8duOZCYINm-jWi094x0JtLmsk4KX60qIC_gs,3245
|
@@ -101,8 +103,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
|
|
101
103
|
langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
|
102
104
|
langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
|
103
105
|
langfun/core/templates/selfplay_test.py,sha256=DYVrkk7uNKCqJGEHH31HssU2BPuMItU1vJLzfcXIlYg,2156
|
104
|
-
langfun-0.0.2.
|
105
|
-
langfun-0.0.2.
|
106
|
-
langfun-0.0.2.
|
107
|
-
langfun-0.0.2.
|
108
|
-
langfun-0.0.2.
|
106
|
+
langfun-0.0.2.dev20240501.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
107
|
+
langfun-0.0.2.dev20240501.dist-info/METADATA,sha256=SUhJ4RRQcyqLKu16sGip7Z2D875PI5EarCo3VDAGxuQ,3405
|
108
|
+
langfun-0.0.2.dev20240501.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
109
|
+
langfun-0.0.2.dev20240501.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
110
|
+
langfun-0.0.2.dev20240501.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|