langfun 0.0.2.dev20240430__py3-none-any.whl → 0.0.2.dev20240501__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,7 +16,11 @@
16
16
  # pylint: disable=g-importing-member
17
17
  # pylint: disable=g-bad-import-order
18
18
 
19
- from langfun.core.eval.base import app_run
19
+ from langfun.core.eval.base import register
20
+ from langfun.core.eval.base import registered_names
21
+ from langfun.core.eval.base import get_evaluation
22
+ from langfun.core.eval.base import get
23
+ from langfun.core.eval.base import run
20
24
 
21
25
  from langfun.core.eval.base import Evaluable
22
26
  from langfun.core.eval.base import Evaluation
@@ -34,6 +38,15 @@ from langfun.core.eval.base import as_inputs
34
38
  from langfun.core.eval.matching import Matching
35
39
  from langfun.core.eval.scoring import Scoring
36
40
 
41
+ # Experiment patching.
42
+ from langfun.core.eval.patching import patch_member
37
43
 
44
+ from langfun.core.eval.patching import patch_lm
45
+ from langfun.core.eval.patching import patch_parsing_lm
46
+ from langfun.core.eval.patching import patch_inputs
47
+ from langfun.core.eval.patching import patch_prompt
48
+ from langfun.core.eval.patching import patch_schema_fn
49
+
50
+ # Placeholder for Google-internal imports.
38
51
  # pylint: enable=g-bad-import-order
39
52
  # pylint: enable=g-importing-member
langfun/core/eval/base.py CHANGED
@@ -25,10 +25,9 @@ import os
25
25
  import re
26
26
  import threading
27
27
  import time
28
+ import types
28
29
  from typing import Annotated, Any, Callable, Iterator, Literal, Optional, Sequence, Type, Union
29
30
 
30
- from absl import app
31
- from absl import flags
32
31
  import langfun.core as lf
33
32
  import langfun.core.coding as lf_coding
34
33
  from langfun.core.llms.cache import in_memory
@@ -600,7 +599,6 @@ class _LeafNode:
600
599
  @pg.use_init_args(['children'])
601
600
  class Suite(Evaluable):
602
601
  """Evaluation suite."""
603
-
604
602
  children: Annotated[list[Evaluable], 'Child evaluation sets or suites.']
605
603
 
606
604
  # Use empty ID as suite is just a container of child evaluations.
@@ -2146,41 +2144,191 @@ def monitor_async(
2146
2144
  )
2147
2145
 
2148
2146
 
2149
- def app_run(target: Evaluable):
2150
- """Runs the target evaluation as an absl app.
2147
+ #
2148
+ # Named evaluations and experiments support.
2149
+ #
2151
2150
 
2152
- Args:
2153
- target: An Langfun evaluable object.
2154
- """
2155
- flags.DEFINE_string(
2156
- 'root_dir', None, 'Root directory for running the evaluation.'
2157
- )
2158
2151
 
2159
- flags.DEFINE_bool(
2160
- 'dryrun', False, 'If True, dryrun the experiment instead of running it.'
2161
- )
2152
+ class _NamedEvaluationRegistry:
2153
+ """Named evaluation registry."""
2162
2154
 
2163
- flags.DEFINE_bool(
2164
- 'debug', False, 'If True, output prompt and response to the console.'
2165
- )
2155
+ def __init__(self):
2156
+ self._registry = {}
2166
2157
 
2167
- flags.DEFINE_bool(
2168
- 'rerun',
2169
- False,
2170
- 'If True, rerun the experiment even a cached result is found.',
2171
- )
2158
+ def names(self) -> list[str]:
2159
+ """Returns all registered names."""
2160
+ return sorted(self._registry.keys())
2161
+
2162
+ def get(self, name: str) -> Type[Evaluable]:
2163
+ """Gets an evaluation by name."""
2164
+ if name not in self._registry:
2165
+ raise ValueError(
2166
+ f'Evaluation {name!r} not found. '
2167
+ 'Did you forget to import the module that registers it?'
2168
+ )
2169
+ return self._registry[name]
2170
+
2171
+ def register(
2172
+ self,
2173
+ name: str,
2174
+ experiment_cls: Type[Evaluable],
2175
+ ):
2176
+ """Register an experiment class."""
2177
+ self._registry[name] = experiment_cls
2178
+
2179
+
2180
+ _eval_registry = _NamedEvaluationRegistry()
2172
2181
 
2173
- FLAGS = flags.FLAGS # pylint: disable=invalid-name
2174
2182
 
2175
- def _main(argv):
2176
- if len(argv) > 1:
2177
- raise app.UsageError('Too many command-line arguments.')
2183
+ def registered_names() -> list[str]:
2184
+ """Returns all registered names."""
2185
+ return _eval_registry.names()
2178
2186
 
2179
- if FLAGS.root_dir:
2180
- target.rebind(root_dir=FLAGS.root_dir, raise_on_no_change=False)
2181
- if FLAGS.dryrun:
2182
- target.dryrun(debug=FLAGS.debug)
2187
+
2188
+ def get_evaluation(evaluation: str | Evaluable) -> Evaluable:
2189
+ """Gets an evaluation by registered name; a non-string input is returned as is."""
2190
+ if isinstance(evaluation, str):
2191
+ return _eval_registry.get(evaluation)()
2192
+ return evaluation
2193
+
2194
+
2195
+ def register(name: str):
2196
+ """Decorator to create a named evaluation class."""
2197
+
2198
+ def _register(func_or_cls: Type[Evaluation] | types.FunctionType):
2199
+ if inspect.isfunction(func_or_cls):
2200
+ e = func_or_cls()
2201
+ if not isinstance(e, Evaluable):
2202
+ raise TypeError(
2203
+ f'The return value of `{func_or_cls}` should be an instance of '
2204
+ '`lf.eval.Evaluable` subclass.'
2205
+ )
2206
+
2207
+ class GeneratedSuite(Suite):
2208
+ # NOTE(daiyip): Delay serialization key registration for generated
2209
+ # class.
2210
+ auto_register = False
2211
+ children = e.children if isinstance(e, Suite) else [e]
2212
+
2213
+ cls = GeneratedSuite
2214
+ cls.__name__ = func_or_cls.__name__
2215
+ cls.__doc__ = func_or_cls.__doc__
2216
+ cls.__qualname__ = func_or_cls.__qualname__
2217
+ cls.__module__ = getattr(func_or_cls, '__module__', 'wrapper')
2218
+ cls.register_for_deserialization(cls.__type_name__)
2219
+
2220
+ elif issubclass(func_or_cls, Evaluable):
2221
+ cls = func_or_cls
2183
2222
  else:
2184
- target.run(debug=FLAGS.debug, rerun=FLAGS.rerun)
2223
+ raise ValueError(f'Unsupported type: {type(func_or_cls)}')
2224
+
2225
+ _eval_registry.register(name, cls)
2226
+ return cls
2227
+
2228
+ return _register
2229
+
2230
+
2231
+ def get(
2232
+ root_dir: str,
2233
+ evaluations: list[str | Evaluable],
2234
+ filter: Union[ # pylint: disable=redefined-builtin
2235
+ str, # Regex to filter evaluation based on ID.
2236
+ Callable[[Evaluable], bool], # Custom filter function.
2237
+ None # No filtering (Default).
2238
+ ] = None, # pylint: disable=bad-whitespace
2239
+ patches: list[Union[
2240
+ str, # String-based PyGlove patcher.
2241
+ pg.patching.Patcher, # PyGlove patcher object.
2242
+ Callable[[pg.KeyPath, Any, Any], Any], # PyGlove rebind function.
2243
+ ]] | None = None, # pylint: disable=bad-whitespace
2244
+ ) -> Suite:
2245
+ """Gets a suite from a list of patched evaluations.
2246
+
2247
+ Args:
2248
+ root_dir: The root directory of the experiment.
2249
+ evaluations: A list of evaluations to be included in the suite.
2250
+ filter: A regular expression (str) for selecting sub-experiments of matched
2251
+ IDs, or a filter function to filter the evaluations.
2252
+ patches: A list of patches to be applied to the suite. Each element can be
2253
+ a string (for string-based patcher), a `pg.patching.Patcher` object, or
2254
+ a rebind function (e.g. `pg.rebind`). See `lf.eval.patch_*` for more
2255
+ details.
2256
+
2257
+ Returns:
2258
+ A suite of selected `lf.eval.Evaluation` objects.
2259
+ """
2260
+ evaluations = [get_evaluation(e) for e in evaluations]
2261
+ suite = Suite(evaluations, root_dir=root_dir)
2262
+ if patches:
2263
+ suite = pg.patch(suite, patches)
2264
+
2265
+ if isinstance(filter, str):
2266
+ regex = re.compile(filter)
2267
+ filter = lambda x: bool(regex.match(x.id))
2268
+
2269
+ if filter:
2270
+ suite = Suite(
2271
+ [leaf for leaf in suite.leaf_nodes if filter(leaf)], root_dir=root_dir)
2272
+ return suite
2273
+
2274
+
2275
+ def run(
2276
+ root_dir: str,
2277
+ evaluations: list[str | Evaluable],
2278
+ filter: Union[ # pylint: disable=redefined-builtin
2279
+ str, # Regex to filter evaluation based on ID.
2280
+ Callable[[Evaluable], bool], # Custom filter function.
2281
+ None # No filtering (Default).
2282
+ ] = None, # pylint: disable=bad-whitespace
2283
+ patches: list[Union[
2284
+ str, # String-based PyGlove patcher.
2285
+ pg.patching.Patcher, # PyGlove patcher object.
2286
+ Callable[[pg.KeyPath, Any, Any], Any], # PyGlove rebind function.
2287
+ ]] | None = None, # pylint: disable=bad-whitespace
2288
+ mode: Literal['run', 'rerun', 'dryrun', 'noop'] = 'run',
2289
+ debug: bool = False,
2290
+ print_definition: bool = False,
2291
+ **kwargs,
2292
+ ) -> Suite:
2293
+ """Run selected evaluations with patching.
2294
+
2295
+ Args:
2296
+ root_dir: The root directory of the experiment.
2297
+ evaluations: A list of evaluations to be included in the suite.
2298
+ filter: A regular expression (str) for selecting sub-experiments of matched
2299
+ IDs, or a filter function to filter the evaluations.
2300
+ patches: A list of patches to be applied to the suite. Each element can be
2301
+ a string (for string-based patcher), a `pg.patching.Patcher` object, or
2302
+ a rebind function (e.g. `pg.rebind`). See `lf.eval.patch_*` for more
2303
+ details.
2304
+ mode: The mode to run the suite. "run" to run the suite, with reusing
2305
+ existing results if available; "rerun" to rerun all evaluations even if
2306
+ there are existing results; "dryrun" to dryrun the suite; and "noop"
2307
+ to do nothing.
2308
+ debug: Whether to run in debug mode.
2309
+ print_definition: Whether to print the experiment definition.
2310
+ **kwargs: Additional arguments to be passed to dryrun/run the suite.
2311
+
2312
+ Returns:
2313
+ A suite of selected `lf.eval.Evaluation` objects.
2314
+ """
2315
+ suite = get(root_dir, evaluations, patches=patches, filter=filter)
2316
+ if print_definition:
2317
+ lf.console.write(
2318
+ pg.format(
2319
+ suite,
2320
+ compact=False,
2321
+ verbose=False,
2322
+ hide_default_values=True,
2323
+ python_format=True,
2324
+ ),
2325
+ title='[EXPERIMENT DEFINITION]',
2326
+ color='blue',
2327
+ )
2185
2328
 
2186
- app.run(_main)
2329
+ if mode == 'run':
2330
+ rerun = mode == 'rerun'
2331
+ suite.run(debug=debug, rerun=rerun, **kwargs)
2332
+ elif mode == 'dryrun':
2333
+ suite.dryrun(debug=debug, **kwargs)
2334
+ return suite
@@ -749,16 +749,97 @@ class SummaryTest(unittest.TestCase):
749
749
  self.assertTrue(pg.io.path_exists(summary_file))
750
750
 
751
751
 
752
- class AppRunTest(unittest.TestCase):
752
+ class NamedEvaluationTest(unittest.TestCase):
753
753
 
754
- def test_app_run(self):
755
- lm = fake.StaticSequence(['two', 'Solution(final_answer=2)'])
756
- try:
757
- base.app_run(
758
- eval_set('app_run_test', 'query', schema_fn=answer_schema(), lm=lm)
754
+ def test_named_eval_class(self):
755
+
756
+ @base.register('named_eval/class_test')
757
+ class MyEval(base.Evaluation):
758
+ inputs = base.as_inputs([
759
+ pg.Dict(question='Compute 1 + 1'),
760
+ ])
761
+ method = 'query'
762
+ prompt = pg.oneof([
763
+ lf.Template('{{example.question}}'),
764
+ lf.Template('Hello {{example.question}}'),
765
+ ])
766
+ schema_fn = answer_schema()
767
+
768
+ evaluation = base.get_evaluation('named_eval/class_test')
769
+ self.assertIsInstance(evaluation, MyEval)
770
+ self.assertIsNone(evaluation.dir)
771
+ self.assertIsNone(evaluation.root_dir)
772
+ self.assertIn('named_eval/class_test', base.registered_names())
773
+
774
+ with self.assertRaisesRegex(ValueError, 'Unsupported type.*'):
775
+ @base.register('named_eval/bad_class')
776
+ class Foo: # pylint: disable=unused-variable
777
+ pass
778
+
779
+ def test_named_eval_functor(self):
780
+
781
+ @base.register('named_eval/functor_test')
782
+ def my_eval():
783
+ return base.Evaluation(
784
+ inputs=base.as_inputs([
785
+ pg.Dict(question='Compute 1 + 1'),
786
+ ]),
787
+ method='query',
788
+ prompt=pg.oneof([
789
+ lf.Template('{{example.question}}'),
790
+ lf.Template('Hello {{example.question}}'),
791
+ ]),
792
+ schema_fn=answer_schema(),
759
793
  )
760
- except SystemExit:
761
- pass
794
+
795
+ self.assertTrue(issubclass(my_eval, base.Evaluable))
796
+ evaluation = base.get_evaluation('named_eval/functor_test')
797
+ self.assertIn('named_eval/functor_test', base.registered_names())
798
+ self.assertIsInstance(evaluation, my_eval)
799
+ self.assertIsNone(evaluation.root_dir, None)
800
+
801
+ with self.assertRaisesRegex(ValueError, 'Evaluation .* not found'):
802
+ base.get_evaluation('named_eval/non_existent')
803
+
804
+ with self.assertRaisesRegex(TypeError, 'The return value .*'):
805
+ @base.register('named_eval/bad_return_type')
806
+ def bad_eval(): # pylint: disable=unused-variable
807
+ return 1
808
+
809
+ def test_run(self):
810
+ @base.register('test/run')
811
+ def test_run(): # pylint: disable=unused-variable
812
+ lm = fake.StaticResponse('Solution(final_answer=2)')
813
+ return eval_set('run_test', 'query', schema_fn=answer_schema(), lm=lm)
814
+
815
+ e = base.run(
816
+ tempfile.gettempdir(),
817
+ ['test/run'],
818
+ id_regex='run_test.*',
819
+ mode='dryrun',
820
+ print_definition=True,
821
+ )
822
+ self.assertEqual(
823
+ e.leaf_nodes[0].dir,
824
+ os.path.join(tempfile.gettempdir(), e.leaf_nodes[0].id),
825
+ )
826
+ self.assertTrue(
827
+ pg.eq(
828
+ e.leaf_nodes[0].lm, fake.StaticResponse('Solution(final_answer=2)')
829
+ )
830
+ )
831
+
832
+ @pg.patcher()
833
+ def bad_lm(unused_eval): # pylint: disable=unused-variable
834
+ return dict(lm=fake.StaticResponse('efg'))
835
+
836
+ e = base.run(
837
+ tempfile.gettempdir(),
838
+ [test_run()],
839
+ filter='Evaluation.*',
840
+ patches=['bad_lm']
841
+ )
842
+ self.assertTrue(pg.eq(e.leaf_nodes[0].lm, fake.StaticResponse('efg')))
762
843
 
763
844
 
764
845
  if __name__ == '__main__':
@@ -0,0 +1,130 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Experiment patching for Langfun evaluations."""
15
+
16
+ import inspect
17
+ from typing import Union
18
+ import langfun.core as lf
19
+ from langfun.core import llms as lf_llms
20
+ from langfun.core.eval import base
21
+ import pyglove as pg
22
+
23
+
24
+ #
25
+ # Program-based patchers.
26
+ #
27
+
28
+
29
+ def patch_member(cls, key, value, parent_key: str | None = None):
30
+ """Patches a member of a class."""
31
+
32
+ def _rebind_fn(k, v, p):
33
+ if (
34
+ isinstance(p, cls)
35
+ and k.key == key
36
+ and (parent_key is None or (p and p.sym_path.key == parent_key))
37
+ ):
38
+ if inspect.isfunction(value):
39
+ return value(k, v, p)
40
+ return value
41
+ return v
42
+
43
+ return _rebind_fn
44
+
45
+
46
+ def patch_lm(lm: Union[lf.LanguageModel, pg.hyper.OneOf]): # pylint: disable=redefined-outer-name
47
+ """Patches the LLM of evaluations."""
48
+ return patch_member(base.Evaluable, "lm", lm)
49
+
50
+
51
+ def patch_parsing_lm(lm: Union[lf.LanguageModel, pg.hyper.OneOf]): # pylint: disable=redefined-outer-name
52
+ """Patches the parsing LLM of evaluations."""
53
+ return patch_member(base.Evaluable, "parsing_lm", lm)
54
+
55
+
56
+ def patch_schema_fn(schema_fn: Union[pg.Functor, pg.hyper.OneOf]):
57
+ """Patches the schema_fn of evaluations."""
58
+ return patch_member(base.Evaluable, "schema_fn", schema_fn)
59
+
60
+
61
+ def patch_prompt(prompt: Union[str, lf.Template, pg.hyper.OneOf]):
62
+ """Patches the prompt of evaluations."""
63
+ return patch_member(base.Evaluable, "prompt", prompt)
64
+
65
+
66
+ def patch_inputs(inputs: Union[pg.Functor, pg.hyper.OneOf]):
67
+ """Patches the inputs used in evaluations."""
68
+ return patch_member(base.Evaluable, "inputs", inputs)
69
+
70
+
71
+ def patch_additional_args(**kwargs):
72
+ """Patches additional_args."""
73
+
74
+ def value_fn(k, unused_v, p):
75
+ # We infer the symbolic value for the old args, as it might be a
76
+ # contextual attribute referring to its containing object.
77
+ old_args = p.sym_inferred(k.key)
78
+ if old_args:
79
+ old_args = dict(old_args)
80
+ old_args.update(kwargs)
81
+ return old_args
82
+ return kwargs
83
+
84
+ return patch_member(base.Evaluable, "additional_args", value_fn)
85
+
86
+
87
+ #
88
+ # String-based patching.
89
+ #
90
+
91
+ _NAMED_MODELS = {
92
+ # GPT models.
93
+ "gpt35turbo": lf_llms.Gpt35Turbo,
94
+ "gpt35turbo16k": lf_llms.Gpt35Turbo16K,
95
+ "gpt4": lf_llms.Gpt4,
96
+ "gpt4turbo": lf_llms.Gpt4Turbo,
97
+ # Anthropic models.
98
+ "haiku": lf_llms.Claude3Haiku,
99
+ "claude3haiku": lf_llms.Claude3Haiku,
100
+ "opus": lf_llms.Claude3Opus,
101
+ "claude3opus": lf_llms.Claude3Opus,
102
+ "sonnet": lf_llms.Claude3Sonnet,
103
+ "claude3sonnet": lf_llms.Claude3Opus,
104
+ }
105
+
106
+
107
+ def model_by_name(name: str) -> lf.LanguageModel:
108
+ """Gets model by name."""
109
+ name = name.strip().lower()
110
+ if name in _NAMED_MODELS:
111
+ return _NAMED_MODELS[name]()
112
+ raise ValueError(f"Unknown model name: {name}")
113
+
114
+
115
+ @pg.patcher(auto_typing=True)
116
+ def lm(unused_eval, models: list[str]):
117
+ """Patch the LM used for benchmarking."""
118
+ return patch_lm(pg.oneof([model_by_name(name) for name in models]))
119
+
120
+
121
+ @pg.patcher(auto_typing=True)
122
+ def temperature(unused_eval, value: float):
123
+ """Patch the temperature used for benchmarking."""
124
+ return patch_member(lf.LMSamplingOptions, "temperature", value)
125
+
126
+
127
+ @pg.patcher(auto_typing=True)
128
+ def max_tokens(unused_eval, value: int | None):
129
+ """Patch the max tokens used for benchmarking."""
130
+ return patch_member(lf.LMSamplingOptions, "max_tokens", value)
@@ -0,0 +1,170 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for evaluation patching."""
15
+
16
+ import unittest
17
+ from langfun.core import llms as lf_llms
18
+ from langfun.core.eval import base
19
+ from langfun.core.eval import patching
20
+ import pyglove as pg
21
+
22
+
23
+ class PatchingCommonTest(unittest.TestCase):
24
+
25
+ def test_patch_member(self):
26
+ class A(pg.Object):
27
+ x: int = 1
28
+
29
+ class B(pg.Object):
30
+ a: A
31
+
32
+ b = B(A())
33
+ pg.patch(b, [patching.patch_member(A, 'x', 2)])
34
+ self.assertEqual(b, B(A(2)))
35
+
36
+ def test_patch_args(self):
37
+ s = base.Suite(
38
+ [base.Evaluation(inputs=base.as_inputs([1]))],
39
+ additional_args=dict(x=1, y=2),
40
+ )
41
+ pg.patch(s, [patching.patch_additional_args(x=3, z=4)])
42
+ self.assertTrue(
43
+ pg.eq(
44
+ s,
45
+ base.Suite(
46
+ [
47
+ base.Evaluation(
48
+ inputs=base.as_inputs([1]),
49
+ additional_args=dict(x=3, y=2, z=4),
50
+ )
51
+ ],
52
+ additional_args=dict(x=3, y=2, z=4),
53
+ ),
54
+ )
55
+ )
56
+
57
+ def test_patch_lm(self):
58
+ s = base.Suite(
59
+ [base.Evaluation(inputs=base.as_inputs([1]))],
60
+ lm=lf_llms.Gpt35Turbo(),
61
+ )
62
+ pg.patch(
63
+ s, [patching.patch_lm(pg.oneof([lf_llms.Gpt35Turbo(), lf_llms.Gpt4()]))]
64
+ )
65
+ self.assertTrue(
66
+ pg.eq(
67
+ s,
68
+ base.Suite(
69
+ [
70
+ base.Evaluation(
71
+ inputs=base.as_inputs([1]),
72
+ lm=pg.oneof([lf_llms.Gpt35Turbo(), lf_llms.Gpt4()]),
73
+ )
74
+ ],
75
+ lm=pg.oneof([lf_llms.Gpt35Turbo(), lf_llms.Gpt4()]),
76
+ ),
77
+ )
78
+ )
79
+
80
+ def test_patch_parsing_lm(self):
81
+ s = base.Suite(
82
+ [base.Evaluation(inputs=base.as_inputs([1]))],
83
+ lm=lf_llms.Gpt4(),
84
+ )
85
+ pg.patch(s, [patching.patch_parsing_lm(lf_llms.Gpt35Turbo())])
86
+ self.assertTrue(
87
+ pg.eq(
88
+ s,
89
+ base.Suite(
90
+ [
91
+ base.Evaluation(
92
+ inputs=base.as_inputs([1]),
93
+ lm=lf_llms.Gpt4(),
94
+ parsing_lm=lf_llms.Gpt35Turbo(),
95
+ )
96
+ ],
97
+ # NOTE(daiyip): Suite does not have `parsing_lm` as one of its
98
+ # variable keyword fields yet, so patching does not add to it.
99
+ # This is okay since we only care about the leaf nodes.
100
+ lm=lf_llms.Gpt4(),
101
+ ),
102
+ )
103
+ )
104
+
105
+ def test_patch_prompt(self):
106
+ e = base.Evaluation(inputs=base.as_inputs([1]))
107
+ pg.patch(e, [patching.patch_prompt('Q: {{example.question}}')])
108
+ self.assertTrue(
109
+ pg.eq(
110
+ e,
111
+ base.Evaluation(
112
+ inputs=base.as_inputs([1]),
113
+ prompt='Q: {{example.question}}',
114
+ ),
115
+ )
116
+ )
117
+
118
+ def test_patch_inputs(self):
119
+ e = base.Evaluation(inputs=base.as_inputs([1]))
120
+ pg.patch(e, [patching.patch_inputs(base.as_inputs([2]))])
121
+ self.assertTrue(
122
+ pg.eq(
123
+ e,
124
+ base.Evaluation(
125
+ inputs=base.as_inputs([2]),
126
+ ),
127
+ )
128
+ )
129
+
130
+ def test_patch_schema_fn(self):
131
+ @pg.functor()
132
+ def int_schema():
133
+ return int
134
+
135
+ e = base.Evaluation(inputs=base.as_inputs([1]))
136
+ pg.patch(e, [patching.patch_schema_fn(int_schema())])
137
+ self.assertTrue(
138
+ pg.eq(
139
+ e,
140
+ base.Evaluation(
141
+ inputs=base.as_inputs([1]),
142
+ schema_fn=int_schema(),
143
+ ),
144
+ )
145
+ )
146
+
147
+
148
+ class StringPatcheTest(unittest.TestCase):
149
+
150
+ def test_lm(self):
151
+ target = pg.patch(
152
+ base.Evaluation(inputs=base.as_inputs([1])),
153
+ ['lm?haiku:gpt4', 'max_tokens?1024', 'temperature?0.7'],
154
+ )
155
+ self.assertEqual(
156
+ target.lm,
157
+ pg.oneof([
158
+ lf_llms.Claude3Haiku(temperature=0.7, max_tokens=1024),
159
+ lf_llms.Gpt4(temperature=0.7, max_tokens=1024),
160
+ ]),
161
+ )
162
+ with self.assertRaisesRegex(ValueError, 'Unknown model name'):
163
+ pg.patch(
164
+ base.Evaluation(inputs=base.as_inputs([1])),
165
+ ['lm?gpt2'],
166
+ )
167
+
168
+
169
+ if __name__ == '__main__':
170
+ unittest.main()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.0.2.dev20240430
3
+ Version: 0.0.2.dev20240501
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -21,7 +21,6 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
21
21
  Classifier: Topic :: Software Development :: Libraries
22
22
  Description-Content-Type: text/markdown
23
23
  License-File: LICENSE
24
- Requires-Dist: absl-py >=1.0.0
25
24
  Requires-Dist: google-generativeai >=0.3.2
26
25
  Requires-Dist: jinja2 >=3.1.2
27
26
  Requires-Dist: openai ==0.27.2
@@ -39,11 +39,13 @@ langfun/core/coding/python/parsing.py,sha256=uyvI1c5OLZhMVK2Oltkl3oJxSLlG0wadlpQ
39
39
  langfun/core/coding/python/parsing_test.py,sha256=9vAWF484kWIm6JZq8NFiMgKUDhXV-deRl1QMmNERfAA,7386
40
40
  langfun/core/coding/python/permissions.py,sha256=1QWGHvzL8MM0Ok_auQ9tURqZHtdOfJaDpBzZ29GUE-c,2544
41
41
  langfun/core/coding/python/permissions_test.py,sha256=w5EDb8QxpxgJyZkojyzVWQvDfg366zn99-g__6TbPQ0,2699
42
- langfun/core/eval/__init__.py,sha256=NSmPe2lxdxFoY4h8VkNyONPAFtOTUpK9WhmZRaqUgiI,1335
43
- langfun/core/eval/base.py,sha256=ImIdyjh89yWUbFoSI12xzpcSmvB34y8_F0WAcUi-4sg,68405
44
- langfun/core/eval/base_test.py,sha256=SEo43ftMscpZ5QV6AGaywrA6SobVaG_P7sUbjoBGqg8,24081
42
+ langfun/core/eval/__init__.py,sha256=Evt-E4FEhZF2tXL6-byh_AyA7Cc_ZoGmvnN7vkAZedk,1898
43
+ langfun/core/eval/base.py,sha256=VgHdnfkHeGPp0XjIGHw9LDZsR0Z4-yuWIkzn4pqJj3Y,73967
44
+ langfun/core/eval/base_test.py,sha256=cHOTIWVW4Dp8gKKIKcZrAcJ-w84j2GIozTzJoiAX7p4,26743
45
45
  langfun/core/eval/matching.py,sha256=Y4vFoNTQEOwko6IA8l9OZ52-vt52e3VGmcTtvLA67wM,9782
46
46
  langfun/core/eval/matching_test.py,sha256=f7iVyXH5KGJBWt4Wp14Bt9J3X59A6Ayfog9MbuFvPew,5532
47
+ langfun/core/eval/patching.py,sha256=R0s2eAd1m97exQt06dmUL0V_MBG0W2Hxg7fhNB7cXW0,3866
48
+ langfun/core/eval/patching_test.py,sha256=8kCd54Egjju22FMgtJuxEsrXkW8ifs-UUBHtrCG1L6w,4775
47
49
  langfun/core/eval/scoring.py,sha256=1J7IATo-8FXUR0SBqk9icztHiM0lWkBFcWUo-vUURgQ,6376
48
50
  langfun/core/eval/scoring_test.py,sha256=O8olHbrUEg60gMxwOkWzKBJZpZoUlmVnBANX5Se2SXM,4546
49
51
  langfun/core/llms/__init__.py,sha256=1bPg1QI8duOZCYINm-jWi094x0JtLmsk4KX60qIC_gs,3245
@@ -101,8 +103,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
101
103
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
102
104
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
103
105
  langfun/core/templates/selfplay_test.py,sha256=DYVrkk7uNKCqJGEHH31HssU2BPuMItU1vJLzfcXIlYg,2156
104
- langfun-0.0.2.dev20240430.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
105
- langfun-0.0.2.dev20240430.dist-info/METADATA,sha256=RpEIB1auHihqOoDrPnFQaYqgpqxFKA9_Z9iuCfPxe5s,3436
106
- langfun-0.0.2.dev20240430.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
107
- langfun-0.0.2.dev20240430.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
108
- langfun-0.0.2.dev20240430.dist-info/RECORD,,
106
+ langfun-0.0.2.dev20240501.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
107
+ langfun-0.0.2.dev20240501.dist-info/METADATA,sha256=SUhJ4RRQcyqLKu16sGip7Z2D875PI5EarCo3VDAGxuQ,3405
108
+ langfun-0.0.2.dev20240501.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
109
+ langfun-0.0.2.dev20240501.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
110
+ langfun-0.0.2.dev20240501.dist-info/RECORD,,