langfun 0.1.2.dev202411100803__py3-none-any.whl → 0.1.2.dev202411120804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. langfun/core/console.py +10 -2
  2. langfun/core/console_test.py +17 -0
  3. langfun/core/eval/__init__.py +2 -0
  4. langfun/core/eval/v2/__init__.py +34 -0
  5. langfun/core/eval/v2/checkpointing.py +130 -0
  6. langfun/core/eval/v2/checkpointing_test.py +89 -0
  7. langfun/core/eval/v2/evaluation.py +615 -0
  8. langfun/core/eval/v2/evaluation_test.py +143 -0
  9. langfun/core/eval/v2/example.py +286 -0
  10. langfun/core/eval/v2/example_test.py +92 -0
  11. langfun/core/eval/v2/experiment.py +949 -0
  12. langfun/core/eval/v2/experiment_test.py +304 -0
  13. langfun/core/eval/v2/metric_values.py +156 -0
  14. langfun/core/eval/v2/metric_values_test.py +80 -0
  15. langfun/core/eval/v2/metrics.py +357 -0
  16. langfun/core/eval/v2/metrics_test.py +203 -0
  17. langfun/core/eval/v2/progress.py +348 -0
  18. langfun/core/eval/v2/progress_test.py +82 -0
  19. langfun/core/eval/v2/progress_tracking.py +209 -0
  20. langfun/core/eval/v2/progress_tracking_test.py +56 -0
  21. langfun/core/eval/v2/reporting.py +144 -0
  22. langfun/core/eval/v2/reporting_test.py +41 -0
  23. langfun/core/eval/v2/runners.py +417 -0
  24. langfun/core/eval/v2/runners_test.py +311 -0
  25. langfun/core/eval/v2/test_helper.py +78 -0
  26. langfun/core/language_model.py +122 -11
  27. langfun/core/language_model_test.py +97 -4
  28. langfun/core/llms/__init__.py +3 -0
  29. langfun/core/llms/compositional.py +101 -0
  30. langfun/core/llms/compositional_test.py +73 -0
  31. {langfun-0.1.2.dev202411100803.dist-info → langfun-0.1.2.dev202411120804.dist-info}/METADATA +1 -1
  32. {langfun-0.1.2.dev202411100803.dist-info → langfun-0.1.2.dev202411120804.dist-info}/RECORD +35 -11
  33. {langfun-0.1.2.dev202411100803.dist-info → langfun-0.1.2.dev202411120804.dist-info}/WHEEL +1 -1
  34. {langfun-0.1.2.dev202411100803.dist-info → langfun-0.1.2.dev202411120804.dist-info}/LICENSE +0 -0
  35. {langfun-0.1.2.dev202411100803.dist-info → langfun-0.1.2.dev202411120804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,143 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import tempfile
16
+ import unittest
17
+
18
+ from langfun.core.eval.v2 import evaluation as evaluation_lib
19
+ from langfun.core.eval.v2 import example as example_lib
20
+ from langfun.core.eval.v2 import experiment as experiment_lib
21
+
22
+ from langfun.core.eval.v2 import test_helper
23
+
24
+ import pyglove as pg
25
+
26
# Short module-level aliases for the evaluation/experiment classes used below.
Example, Evaluation = example_lib.Example, evaluation_lib.Evaluation
RunId, Run = experiment_lib.RunId, experiment_lib.Run
30
+
31
+
32
class EvaluationTest(unittest.TestCase):
  """Tests for `Evaluation`, driven by the fixtures in `test_helper`."""

  def test_hyper_evaluation(self):
    # An evaluation whose LM has a `pg.oneof` hyper-value is a non-leaf node
    # that expands into one leaf child per candidate value.
    exp = test_helper.TestEvaluation(
        lm=test_helper.TestLLM(offset=pg.oneof(range(3)))
    )
    self.assertFalse(exp.is_leaf)
    self.assertTrue(
        pg.eq(
            exp.children,
            [
                test_helper.TestEvaluation(lm=test_helper.TestLLM(offset=0)),
                test_helper.TestEvaluation(lm=test_helper.TestLLM(offset=1)),
                test_helper.TestEvaluation(lm=test_helper.TestLLM(offset=2)),
            ]
        )
    )
    self.assertEqual(exp.children[0].num_examples, 10)
    self.assertEqual(
        [c.is_leaf for c in exp.children],
        [True] * len(exp.children)
    )
    self.assertEqual(
        [r.resource_ids() for r in exp.leaf_nodes],
        [{'test_llm:0'}, {'test_llm:1'}, {'test_llm:2'}]
    )

  def test_evaluate(self):
    # Evaluate by `Example` object.
    exp = test_helper.TestEvaluation()
    example = exp.evaluate(Example(id=3))
    self.assertTrue(example.newly_processed)
    self.assertEqual(example.input, pg.Dict(x=2, y=4, groundtruth=6))
    self.assertEqual(example.output, 6)
    self.assertIsNone(example.error)
    self.assertEqual(example.metadata, {})
    self.assertEqual(example.metric_metadata, dict(match=True))
    self.assertIsNotNone(example.usage_summary)
    self.assertGreater(example.usage_summary.total.total_tokens, 0)
    self.assertEqual(example.usage_summary.total.num_requests, 1)
    self.assertIsNotNone(example.execution_status)
    self.assertIsNotNone(example.start_time)
    self.assertIsNotNone(example.end_time)

    # Evaluate by example ID, with an LM whose answers are off by one.
    exp = test_helper.TestEvaluation(lm=test_helper.TestLLM(offset=1))
    example = exp.evaluate(3)
    self.assertTrue(example.newly_processed)
    self.assertEqual(example.input, pg.Dict(x=2, y=4, groundtruth=6))
    self.assertEqual(example.output, 7)
    self.assertIsNone(example.error)
    self.assertEqual(example.metadata, {})
    self.assertEqual(example.metric_metadata, dict(mismatch=True))

    # Errors are raised on request, and otherwise recorded on the example.
    with self.assertRaisesRegex(ValueError, 'x should not be 5'):
      _ = exp.evaluate(6, raise_if_has_error=True)
    example = exp.evaluate(6)
    self.assertTrue(example.newly_processed)
    self.assertEqual(example.input, pg.Dict(x=5, y=25, groundtruth=30))
    self.assertEqual(pg.MISSING_VALUE, example.output)
    self.assertEqual(example.error.tag, 'ValueError')
    self.assertEqual(example.metadata, {})
    self.assertEqual(example.metric_metadata, dict(error='ValueError'))

  def test_evaluate_with_state(self):
    # Use a fresh temporary directory instead of a fixed path under the system
    # temp dir. Repeated or concurrent runs then do not share (and clobber) the
    # same state file, and nothing is left behind once the test finishes.
    with tempfile.TemporaryDirectory() as eval_dir:
      state_file = os.path.join(eval_dir, 'state.jsonl')
      with pg.io.open_sequence(state_file, 'w') as f:
        exp = test_helper.TestEvaluation()
        example = exp.evaluate(3)
        self.assertTrue(example.newly_processed)
        self.assertEqual(example.input, pg.Dict(x=2, y=4, groundtruth=6))
        self.assertEqual(example.output, 6)
        self.assertEqual(len(exp._state.evaluated_examples), 1)
        f.add(pg.to_json_str(example))

      # After a reset, loading the saved state restores the evaluated example,
      # and re-evaluating it reuses the cached result.
      exp.reset()
      self.assertEqual(len(exp._state.evaluated_examples), 0)
      exp.load_state(state_file)
      self.assertEqual(len(exp._state.evaluated_examples), 1)
      example = exp.evaluate(3)
      self.assertFalse(example.newly_processed)
      self.assertEqual(example.input, pg.Dict(x=2, y=4, groundtruth=6))
      self.assertEqual(example.output, 6)
      self.assertGreater(example.usage_summary.total.total_tokens, 0)
      self.assertGreater(example.usage_summary.cached.total.total_tokens, 0)
      self.assertEqual(example.usage_summary.cached.total.num_requests, 1)
      self.assertEqual(example.usage_summary.uncached.total.total_tokens, 0)
      self.assertEqual(example.usage_summary.uncached.total.num_requests, 0)

  def test_html_view(self):
    exp = test_helper.TestEvaluation()
    self.assertIn(
        exp.id,
        exp.to_html(extra_flags=dict(card_view=True, current_run=None)).content
    )
    self.assertIn(
        exp.id,
        exp.to_html(
            extra_flags=dict(
                card_view=False,
                current_run=Run(
                    root_dir='/tmp/test_run',
                    id=RunId.from_id('20241031_1'),
                    experiment=pg.Ref(exp),
                )
            )
        ).content
    )
140
+
141
+
142
if __name__ == '__main__':
  # Run the tests defined in this module when executed as a script.
  unittest.main()
@@ -0,0 +1,286 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Base classes for Langfun evaluation."""
15
+
16
+ import dataclasses
17
+ import inspect
18
+ from typing import Any, Callable
19
+ import langfun.core as lf
20
+ import pyglove as pg
21
+
22
+
23
@dataclasses.dataclass
class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
  """An item for the evaluation.

  Attributes:
    id: The 1-based ID of the item in the evaluation set.
    input: An element returned from the `Evaluable.inputs` functor.
    output: The output of the `process` method. If `pg.MISSING_VALUE`, it has
      not been processed yet.
    error: The error information of the item, if any. If `execution_status`
      contains a status with an error, the first such error overrides this
      field in `__post_init__`.
    metadata: The metadata of the item produced by the `process` method.
    metric_metadata: The dictionary returned from `Metric.audit`.
    newly_processed: Whether the item was processed in the current run, as
      opposed to being restored from a previously saved state.
    start_time: The start time of the evaluation item.
    end_time: The end time of the evaluation item.
    usage_summary: The summary of LLM usages of the evaluation item.
    execution_status: The timeit status of the evaluation item, keyed by the
      name of the timed step (e.g. 'evaluate').
  """
  id: int
  input: Any = pg.MISSING_VALUE
  output: Any = pg.MISSING_VALUE
  error: pg.object_utils.ErrorInfo | None = None
  metadata: dict[str, Any] = dataclasses.field(default_factory=dict)
  metric_metadata: dict[str, Any] | None = None
  # Execution information.
  newly_processed: bool = True
  start_time: float | None = None
  end_time: float | None = None
  usage_summary: lf.UsageSummary | None = None
  execution_status: dict[str, pg.object_utils.TimeIt.Status] | None = None

  def __post_init__(self):
    # Surface the first error recorded by any timed step, so that callers can
    # rely on `error`/`has_error` without inspecting `execution_status`.
    if self.execution_status is not None:
      for status in self.execution_status.values():
        if status.has_error:
          self.error = status.error
          break

  @property
  def is_processed(self) -> bool:
    """Returns whether the item has been processed (i.e. has an output)."""
    return pg.MISSING_VALUE != self.output

  @property
  def has_error(self) -> bool:
    """Returns whether the item has an error."""
    return self.error is not None

  @property
  def elapse(self) -> float | None:
    """Returns the elapse time of the item.

    Returns:
      The `elapse` of the 'evaluate' entry in `execution_status`, or None if
      there is no execution status or it has no 'evaluate' entry.
    """
    if self.execution_status is None:
      return None
    # Tolerate statuses without an 'evaluate' entry rather than raising
    # KeyError, in line with the `float | None` return type.
    status = self.execution_status.get('evaluate')
    return None if status is None else status.elapse

  def to_json(self, **kwargs) -> dict[str, Any]:
    """Returns the JSON representation of the item."""
    return self.to_json_dict(
        fields=dict(
            id=(self.id, None),
            # NOTE(daiyip): We do not write `input` to JSON as it will be
            # loaded from the input functor. This allows us to support
            # non-serializable examples.
            output=(self.output, pg.MISSING_VALUE),
            error=(self.error, None),
            metadata=(self.metadata, {}),
            metric_metadata=(self.metric_metadata, None),
            start_time=(self.start_time, None),
            end_time=(self.end_time, None),
            usage_summary=(self.usage_summary, None),
            execution_status=(self.execution_status, None),
        ),
        exclude_default=True,
        **kwargs,
    )

  @classmethod
  def from_json(
      cls,
      json_value: dict[str, Any],
      *,
      example_input_by_id: Callable[[int], Any],
      **kwargs
  ) -> 'Example':
    """Creates an example from the JSON representation.

    Args:
      json_value: The JSON dict produced by `to_json` (which omits `input`).
      example_input_by_id: A function that maps an example ID to its input.
        Since `input` is not serialized, it is re-attached from this function.
      **kwargs: Keyword arguments passed to `pg.from_json` for each field.

    Returns:
      The deserialized `Example`.
    """
    # Work on a copy so that the caller's JSON dict is not mutated.
    json_value = dict(json_value)
    example_id = json_value.get('id')
    example_input = example_input_by_id(example_id)
    json_value['input'] = example_input

    # NOTE(daiyip): We need to load the types of the examples into the
    # deserialization context, otherwise the deserialization will fail if the
    # types are not registered.
    def example_class_defs(example) -> list[type[Any]]:
      referred_types = set()
      def _visit(k, v, p):
        del k, p
        if inspect.isclass(v):
          referred_types.add(v)
        elif isinstance(v, pg.Object):
          referred_types.add(v.__class__)
        return pg.TraverseAction.ENTER
      pg.traverse(example, _visit)
      return list(referred_types)

    with pg.JSONConvertible.load_types_for_deserialization(
        *example_class_defs(example_input)
    ):
      return cls(
          **{k: pg.from_json(v, **kwargs) for k, v in json_value.items()}
      )

  #
  # HTML rendering.
  #

  def _html_tree_view_content(
      self,
      *,
      view: pg.views.HtmlTreeView,
      root_path: pg.KeyPath | None = None,
      extra_flags: dict[str, Any] | None = None,
      **kwargs
  ):
    """Renders the example as a header (navigation + badges) and field tabs.

    Args:
      view: The HTML tree view used to render individual fields.
      root_path: The key path of this example; defaults to the root path.
      extra_flags: Optional rendering flags. `num_examples`, when provided,
        suppresses the "next" link on the last example.
      **kwargs: Passthrough arguments forwarded to `view.render`.

    Returns:
      A `pg.Html` element for the example.
    """
    root_path = root_path or pg.KeyPath()
    extra_flags = extra_flags or {}
    num_examples = extra_flags.get('num_examples', None)

    def _metric_metadata_badge(key, value):
      # A `True` flag (e.g. `match=True`) is shown as just its key; any other
      # value, including `False`, is shown as `key:value`.
      if isinstance(value, bool) and value:
        text = key
      else:
        text = f'{key}:{value}'
      return pg.views.html.controls.Badge(
          text,
          css_classes=[pg.object_utils.camel_to_snake(key, '-')],
      )

    def _render_header():
      return pg.Html.element(
          'div',
          [
              pg.Html.element(
                  'div',
                  [
                      # Previous button.
                      pg.views.html.controls.Label(  # pylint: disable=g-long-ternary
                          '◀',
                          link=f'{self.id - 1}.html',
                          css_classes=['previous'],
                      ) if self.id > 1 else None,
                      # Current example ID.
                      pg.views.html.controls.Label(
                          f'#{self.id}',
                          css_classes=['example-id'],
                      ),
                      # Next button.
                      pg.views.html.controls.Label(  # pylint: disable=g-long-ternary
                          '▶',
                          link=f'{self.id + 1}.html',
                          css_classes=['next'],
                      ) if (num_examples is None
                            or self.id < num_examples) else None,

                  ]
              ),
              pg.Html.element(
                  'div',
                  [
                      # Usage summary.
                      pg.view(  # pylint: disable=g-long-ternary
                          self.usage_summary,
                          extra_flags=dict(as_badge=True)
                      ) if self.usage_summary is not None else None,
                      # Metric metadata.
                      pg.views.html.controls.LabelGroup(
                          [  # pylint: disable=g-long-ternary
                              _metric_metadata_badge(k, v)
                              for k, v in self.metric_metadata.items()
                          ] if self.metric_metadata else []
                      ),
                  ],
                  css_classes=['example-container'],
              )
          ]
      )

    def _render_content():
      def _tab(label, key):
        field = getattr(self, key)
        # Skip absent values and empty containers, but keep falsy yet
        # meaningful values such as an output of `0`, `False` or `''`.
        if pg.MISSING_VALUE == field or field is None:
          return None
        if isinstance(field, (dict, list)) and not field:
          return None
        return pg.views.html.controls.Tab(
            label=label,
            content=view.render(
                field,
                root_path=root_path + key,
                **view.get_passthrough_kwargs(**kwargs),
            ),
        )
      tabs = [
          _tab('Input', 'input'),
          _tab('Output', 'output'),
          _tab('Output Metadata', 'metadata'),
          _tab('Error', 'error'),
      ]
      return pg.views.html.controls.TabControl(
          [tab for tab in tabs if tab is not None]
      )

    return pg.Html.element(
        'div',
        [
            _render_header(),
            _render_content(),
        ],
        css_classes=['eval-example']
    )

  def _html_tree_view_summary(self, *, view, **kwargs):
    """Returns no summary: the example is rendered by its content only."""
    return None

  @classmethod
  def _html_tree_view_css_styles(cls) -> list[str]:
    """Returns the CSS styles used by the example's HTML view."""
    return super()._html_tree_view_css_styles() + [
        """
        .example-container {
          display: block;
          padding: 10px;
        }
        .example-id {
          font-weight: bold;
          font-size: 40px;
          margin: 0 10px;
          vertical-align: middle;
        }
        a.previous, a.next {
          text-decoration: none;
          vertical-align: middle;
          display: inline-block;
          padding: 8px 8px;
          color: #DDD;
        }
        a.previous:hover, a.next:hover {
          background-color: #ddd;
          color: black;
        }
        /* Badge styles. */
        .eval-example .badge.match {
          color: green;
          background-color: #dcefbe;
        }
        .eval-example .badge.error {
          color: red;
          background-color: #fdcccc;
        }
        .eval-example .badge.mismatch {
          color: orange;
          background-color: #ffefc4;
        }
        .eval-example .badge.score {
          color: blue;
          background-color: #c4dced;
        }
        """
    ]
+
@@ -0,0 +1,92 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import unittest
15
+
16
+ from langfun.core.eval.v2 import example as example_lib
17
+ import pyglove as pg
18
+
19
# Short alias for the class under test.
Example = example_lib.Example
20
+
21
+
22
class ExampleTest(unittest.TestCase):
  """Unit tests for the `Example` evaluation item."""

  def test_basic(self):
    # An example whose execution status carries an error adopts that error in
    # `__post_init__`; without an output it does not count as processed.
    err_info = pg.object_utils.ErrorInfo(
        tag='ValueError', description='Bad input', stacktrace='...'
    )
    failed_status = pg.object_utils.TimeIt.Status(
        name='evaluation', elapse=1.0, error=err_info
    )
    failed = Example(id=1, execution_status={'evaluate': failed_status})
    self.assertEqual(failed.error, err_info)
    self.assertFalse(failed.is_processed)
    self.assertTrue(failed.has_error)
    self.assertEqual(failed.elapse, 1.0)

    # An example with an output but no execution status.
    succeeded = Example(id=2, output=1)
    self.assertTrue(succeeded.is_processed)
    self.assertFalse(succeeded.has_error)
    self.assertIsNone(succeeded.elapse)

  def test_json_conversion(self):
    # The input references locally defined classes, which are not registered
    # globally; deserialization must still resolve them from the input.
    def input_func():
      class A(pg.Object):
        x: int

      class B(pg.Object):
        x: int = 1
        y: int = 2

      return [pg.Dict(a=A, b=B)]

    example_inputs = input_func()
    first_input = example_inputs[0]
    original = Example(
        id=1,
        input=first_input,
        output=first_input.a(1),
        metadata=dict(b=first_input.b()),
    )
    serialized = pg.to_json_str(original)
    restored = pg.from_json_str(
        serialized,
        example_input_by_id=lambda i: example_inputs[i - 1],
    )
    self.assertEqual(restored, original)

  def test_html_view(self):
    # The last example (id == num_examples) must not render a "next" link.
    last = Example(
        id=1,
        input=pg.Dict(a=1, b=2),
        output=3,
        metadata=dict(sum=3),
        metric_metadata=dict(match=True),
    )
    rendered = last.to_html(extra_flags=dict(num_examples=1)).content
    self.assertNotIn('next', rendered)
89
+
90
+
91
if __name__ == '__main__':
  # Run the tests defined in this module when executed as a script.
  unittest.main()