langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501150804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +20 -2
- langfun/core/__init__.py +16 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -21
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +63 -2
- langfun/core/component_test.py +53 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +17 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +16 -1
- langfun/core/eval/base.py +622 -174
- langfun/core/eval/base_test.py +200 -54
- langfun/core/eval/matching.py +63 -76
- langfun/core/eval/matching_test.py +17 -8
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +26 -26
- langfun/core/eval/scoring_test.py +19 -2
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +4 -17
- langfun/core/langfunc_test.py +22 -6
- langfun/core/language_model.py +577 -39
- langfun/core/language_model_test.py +470 -56
- langfun/core/llms/__init__.py +87 -16
- langfun/core/llms/anthropic.py +312 -87
- langfun/core/llms/anthropic_test.py +71 -3
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +53 -2
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +11 -7
- langfun/core/llms/fake_test.py +14 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -202
- langfun/core/llms/groq.py +160 -144
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +15 -42
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +395 -203
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +30 -395
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -26
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +12 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +7 -6
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +60 -27
- langfun/core/structured/function_generation_test.py +72 -2
- langfun/core/structured/mapping.py +97 -47
- langfun/core/structured/mapping_test.py +90 -2
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +53 -9
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
- langfun/core/structured/schema.py +204 -97
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_test.py +130 -29
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +115 -1
- langfun/core/template_test.py +71 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +10 -2
- langfun-0.1.2.dev202501150804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501150804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -238
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
- langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,433 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import datetime
|
15
|
+
import os
|
16
|
+
import tempfile
|
17
|
+
import unittest
|
18
|
+
|
19
|
+
from langfun.core.eval.v2 import evaluation as evaluation_lib
|
20
|
+
from langfun.core.eval.v2 import experiment as experiment_lib
|
21
|
+
from langfun.core.eval.v2 import metrics as metrics_lib
|
22
|
+
|
23
|
+
import pyglove as pg
|
24
|
+
|
25
|
+
# Short aliases for the classes under test.
Evaluation = evaluation_lib.Evaluation

Experiment = experiment_lib.Experiment
Suite = experiment_lib.Suite
Run = experiment_lib.Run
RunId = experiment_lib.RunId
Runner = experiment_lib.Runner
|
31
|
+
|
32
|
+
|
33
|
+
@pg.functor()
def sample_inputs(num_examples: int = 1):
  """Returns a list of `num_examples` example inputs, each `pg.Dict(x=1)`.

  The list is built by multiplication, so every entry refers to the same
  `pg.Dict` object.
  """
  template = pg.Dict(x=1)
  return [template] * num_examples
|
38
|
+
|
39
|
+
|
40
|
+
class MyEvaluation(Evaluation):
  """Minimal evaluation fixture whose `process` always returns 1."""

  NAME = 'my_eval'
  RUN_ARGS = {'runner': 'test'}

  replica_id: int = 0
  inputs = sample_inputs()
  metrics = [metrics_lib.Match()]

  def process(self, example):
    del example  # The output does not depend on the example.
    return 1
|
52
|
+
|
53
|
+
|
54
|
+
class ExperimentTest(unittest.TestCase):

  def _make_suite(self):
    """Builds a suite holding a nested suite (1 leaf) and a 5-way replica."""
    nested = Suite([MyEvaluation(replica_id=0)])
    replicas = MyEvaluation(replica_id=pg.oneof(range(5)))
    return Suite([nested, replicas])

  def test_hierarchy(self):
    root = self._make_suite()

    self.assertIsNotNone(root.id)
    self.assertTrue(root.id.startswith('Suite@'))
    self.assertEqual(len(root.children), 2)
    self.assertEqual(len(root.leaf_nodes), 6)
    self.assertEqual(len(root.nonleaf_nodes), 3)
    self.assertFalse(root.is_leaf)
    self.assertFalse(root.empty())
    self.assertEqual(len(root.nodes), 9)

    first_group = root.children[0]
    second_group = root.children[1]
    single = first_group.children[0]
    self.assertTrue(single.id.startswith('MyEvaluation@'))
    self.assertTrue(single.is_leaf)
    self.assertEqual(len(single.leaf_nodes), 1)
    self.assertFalse(second_group.is_leaf)
    self.assertEqual(len(second_group.children), 5)
    self.assertEqual(len(second_group.leaf_nodes), 5)
    self.assertEqual(root.leaf_nodes[-1].replica_id, 4)
    self.assertNotEqual(root.leaf_nodes[1].hash, root.leaf_nodes[2].hash)

    # Parent links and lookup by id.
    self.assertIsNone(root.parent)
    self.assertIs(first_group.parent, root)
    self.assertIs(single.parent, first_group)
    self.assertIs(second_group.children[0].parent, second_group)
    last_leaf = root.leaf_nodes[-1]
    self.assertIs(root.get(last_leaf.id), last_leaf)

  def test_html_view(self):
    root = self._make_suite()
    self.assertIn(root.id, root.to_html().content)

    run = Run('/root', RunId.from_id('20241102_0'), pg.Ref(root))
    expected_id = str(run.id)
    html = run.to_html(extra_flags=dict(current_run=run)).content
    self.assertIn(expected_id, html)

  def test_find(self):
    for pattern in ('my_eval', '.*_eval'):
      self.assertIsInstance(Experiment.find(pattern), MyEvaluation)
    self.assertTrue(pg.eq(Experiment.find('foo'), Suite([])))
|
113
|
+
|
114
|
+
|
115
|
+
class RunIdTest(unittest.TestCase):
  """Tests for `RunId`.

  Each test works inside its own auto-cleaned temporary directory. This keeps
  tests from leaking directories into the shared system temp dir, and from
  depending on what else happens to live there.
  """

  def setUp(self):
    super().setUp()
    tmp = tempfile.TemporaryDirectory()
    self.addCleanup(tmp.cleanup)
    self.tmp_dir = tmp.name

  def _path(self, name):
    """Returns the (not yet created) path `name` under this test's temp dir."""
    return os.path.join(self.tmp_dir, name)

  def test_basic(self):
    rid = RunId.from_id('20241102_0')
    self.assertEqual(
        rid.dirname('/root'), os.path.join('/root', 'run_20241102_0')
    )
    self.assertEqual(str(rid), '20241102_0')
    self.assertEqual(rid.date, datetime.date(2024, 11, 2))
    self.assertEqual(rid.number, 0)

  def test_comparison(self):
    self.assertEqual(
        RunId.from_id('20241102_0'), RunId.from_id('20241102_0')
    )
    self.assertLess(
        RunId.from_id('20241102_0'), RunId.from_id('20241102_1')
    )
    self.assertLess(
        RunId.from_id('20241101_0'), RunId.from_id('20241102_1')
    )
    self.assertGreater(
        RunId.from_id('20241102_0'), RunId.from_id('20241101_0')
    )
    self.assertLessEqual(
        RunId.from_id('20241102_0'), RunId.from_id('20241102_0')
    )
    self.assertEqual(
        RunId.from_id('20241102_0').next(),
        RunId.from_id('20241102_1')
    )

  def test_get_latest(self):
    root_dir = self._path('test_eval')
    pg.io.mkdirs(os.path.join(root_dir, 'run_20241102_0'))
    pg.io.mkdirs(os.path.join(root_dir, 'run_20241101_0'))
    self.assertEqual(
        RunId.get_latest(root_dir),
        RunId.from_id('20241102_0')
    )
    self.assertIsNone(RunId.get_latest('/notexist'))
    # An existing directory that has no `run_*` children.
    self.assertIsNone(RunId.get_latest(self.tmp_dir))

  def test_new(self):
    rid = RunId(date=datetime.date.today(), number=1)
    self.assertEqual(
        RunId.new(root_dir=self._path('test_new')),
        rid
    )
    root_dir = self._path('test_eval2')
    pg.io.mkdirs(rid.dirname(root_dir))
    self.assertEqual(RunId.new(root_dir), rid.next())

  def test_is_valid(self):
    self.assertTrue(RunId.is_valid('latest'))
    self.assertTrue(RunId.is_valid('new'))
    self.assertTrue(RunId.is_valid('20241102_0'))
    self.assertFalse(RunId.is_valid('20241102-0'))

  def test_from_id(self):
    with self.assertRaisesRegex(
        ValueError, '.* must be one of'
    ):
      RunId.from_id('abc')

    with self.assertRaisesRegex(
        ValueError, '`root_dir` must be provided'
    ):
      RunId.from_id('latest')

    # `self.tmp_dir` is freshly created and empty at this point.
    with self.assertRaisesRegex(
        ValueError, '.* no previous runs'
    ):
      RunId.from_id('latest', root_dir=self.tmp_dir)

    self.assertEqual(
        RunId.from_id('20241102_1'),
        RunId(date=datetime.date(2024, 11, 2), number=1)
    )
    root_dir = self._path('test_eval3')
    rid = RunId.from_id('20241102_1')
    pg.io.mkdirs(rid.dirname(root_dir))
    self.assertEqual(
        RunId.from_id('latest', root_dir=root_dir), rid
    )
    self.assertEqual(
        RunId.from_id('new', root_dir=root_dir),
        RunId(datetime.date.today(), 1)
    )
    self.assertEqual(
        RunId.from_id(None, root_dir=root_dir), rid
    )
|
207
|
+
|
208
|
+
|
209
|
+
class RunTest(unittest.TestCase):

  def _make_run(self, num_examples=None, **run_kwargs):
    """Creates a '/root' run over a suite with a single `MyEvaluation`.

    Args:
      num_examples: If given, the evaluation uses `sample_inputs(num_examples)`
        as inputs; otherwise it keeps its default inputs.
      **run_kwargs: Extra keyword arguments forwarded to `Run`.
    """
    if num_examples is None:
      evaluation = MyEvaluation(replica_id=0)
    else:
      evaluation = MyEvaluation(
          replica_id=0, inputs=sample_inputs(num_examples)
      )
    return Run(
        '/root',
        RunId.from_id('20241102_0'),
        pg.Ref(Suite([evaluation])),
        **run_kwargs,
    )

  def _check_example_sets(
      self, run, node, *, evaluate, reprocess, load, load_metadata
  ):
    """Asserts the example-id sets `run` computes for `node`."""
    self.assertEqual(run.examples_to_evaluate(node), set(evaluate))
    self.assertEqual(run.examples_to_reprocess(node), set(reprocess))
    self.assertEqual(run.examples_to_load(node), set(load))
    self.assertEqual(run.examples_to_load_metadata(node), set(load_metadata))

  def test_input_output_paths(self):
    run = self._make_run()
    self.assertEqual(run.output_root, '/root/run_20241102_0')
    self.assertEqual(run.input_root, '/root/run_20241102_0')
    leaf = run.experiment.leaf_nodes[0]
    self.assertEqual(
        run.output_dir(leaf),
        '/root/run_20241102_0/MyEvaluation/' + leaf.hash,
    )
    self.assertEqual(
        run.input_path_for(run.experiment, 'a.txt'),
        '/root/run_20241102_0/a.txt',
    )
    self.assertEqual(
        run.input_path_for(leaf, 'a.txt'),
        '/root/run_20241102_0/MyEvaluation/%s/a.txt' % leaf.hash,
    )

    # With `warm_start_from`, inputs come from the warm-start run while
    # outputs still go under the current run.
    run = self._make_run(warm_start_from='/root2/run_20241103_1')
    self.assertEqual(run.output_root, '/root/run_20241102_0')
    self.assertEqual(run.input_root, '/root2/run_20241103_1')
    leaf = run.experiment.leaf_nodes[0]
    self.assertEqual(
        run.output_dir(leaf),
        '/root/run_20241102_0/MyEvaluation/' + leaf.hash,
    )
    self.assertEqual(
        run.input_dir(leaf),
        '/root2/run_20241103_1/MyEvaluation/' + leaf.hash,
    )
    self.assertEqual(
        run.input_path_for(run.experiment, 'a.txt'),
        '/root2/run_20241103_1/a.txt',
    )
    self.assertEqual(
        run.input_path_for(leaf, 'a.txt'),
        '/root2/run_20241103_1/MyEvaluation/%s/a.txt' % leaf.hash,
    )

  def test_examples_start_from_scratch(self):
    run = self._make_run(num_examples=10)
    root = run.experiment
    self._check_example_sets(
        run, root, evaluate=(), reprocess=(), load=(), load_metadata=()
    )
    self._check_example_sets(
        run,
        root.leaf_nodes[0],
        evaluate=range(1, 11),
        reprocess=(),
        load=range(1, 11),
        load_metadata=(),
    )

  def test_examples_with_example_ids(self):
    run = self._make_run(num_examples=10, example_ids=[1, 3, 5])
    self._check_example_sets(
        run,
        run.experiment.leaf_nodes[0],
        evaluate=[1, 3, 5],
        reprocess=(),
        load=[1, 3, 5],
        load_metadata=(),
    )

  def test_examples_with_reprocess_all(self):
    run = self._make_run(
        num_examples=10, example_ids=[1, 3, 5], reprocess=True
    )
    self._check_example_sets(
        run,
        run.experiment.leaf_nodes[0],
        evaluate=[1, 3, 5],
        reprocess=[1, 3, 5],
        load=(),
        load_metadata=(),
    )

  def test_examples_with_reprocess_some(self):
    run = self._make_run(
        num_examples=10, example_ids=[1, 3, 5], reprocess=[1]
    )
    self._check_example_sets(
        run,
        run.experiment.leaf_nodes[0],
        evaluate=[1, 3, 5],
        reprocess=[1],
        load=[3, 5],
        load_metadata=(),
    )

  def test_examples_with_generate_example_html_all(self):
    run = self._make_run(
        num_examples=10,
        example_ids=[1, 3, 5],
        reprocess=[1],
        generate_example_html='all',
    )
    self._check_example_sets(
        run,
        run.experiment.leaf_nodes[0],
        evaluate=[1, 3, 5],
        reprocess=[1],
        load=[3, 5],
        load_metadata=[3, 5],
    )

  def test_examples_with_generate_example_html_new(self):
    run = self._make_run(
        num_examples=10,
        example_ids=[1, 3, 5],
        reprocess=[1],
        generate_example_html='new',
    )
    self._check_example_sets(
        run,
        run.experiment.leaf_nodes[0],
        evaluate=[1, 3, 5],
        reprocess=[1],
        load=[3, 5],
        load_metadata=(),
    )

  def test_examples_with_generate_example_html_some(self):
    run = self._make_run(
        num_examples=10,
        example_ids=[1, 3, 5],
        reprocess=[1],
        generate_example_html=[1, 2, 3],
    )
    self._check_example_sets(
        run,
        run.experiment.leaf_nodes[0],
        evaluate=[1, 3, 5],
        reprocess=[1],
        load=[2, 3, 5],
        load_metadata=[2, 3],
    )
|
390
|
+
|
391
|
+
|
392
|
+
class RunnerTest(unittest.TestCase):

  def test_basic(self):

    class TestRunner(Runner):
      NAME = 'test'

      def run(self):
        """Does nothing."""

    placeholder_run = Run(
        '/root', RunId.from_id('20241102_0'), pg.Ref(Suite([])),
    )
    created = Runner.create('test', current_run=placeholder_run)
    self.assertIsInstance(created, TestRunner)

    root_dir = os.path.join(tempfile.gettempdir(), 'my_eval')

    # Standard run with an explicit runner name.
    MyEvaluation(replica_id=0).run(root_dir, id='20241101_0', runner='test')

    # Preconfigured run; presumably picks up `MyEvaluation.RUN_ARGS`.
    MyEvaluation(replica_id=0).run_preconfigured(
        root_dir=root_dir, id='20241101_1'
    )

    # Defining a runner subclass without a `NAME` is rejected.
    with self.assertRaisesRegex(
        ValueError, 'Runner class must define a NAME constant'
    ):

      class AnotherRunner(Runner):  # pylint: disable=unused-variable

        def run(self):
          """Does nothing."""
|
430
|
+
|
431
|
+
|
432
|
+
if __name__ == '__main__':
  # Run every TestCase defined in this module.
  unittest.main()
|
@@ -0,0 +1,156 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Common value types for evaluation metrics and metadata."""
|
15
|
+
|
16
|
+
|
17
|
+
import abc
|
18
|
+
from typing import Annotated, Any, Union
|
19
|
+
import pyglove as pg
|
20
|
+
|
21
|
+
|
22
|
+
class MetricValue(pg.Object):
  """Base class for metric values.

  A metric value accumulates per-example `DataPoint`s together with the total
  number of evaluated examples. Subclasses turn them into a single float via
  `reduce()`. The weighted sum (value * weight) of all data points is cached in
  `_weighted_sum` and kept in sync by `_on_bound`, `add` and `reset`.
  """

  class DataPoint(pg.Object):
    """A data point for a metric value."""
    example_id: int
    value: float
    weight: float = 1.0

  # NOTE(daiyip): For evaluations, usually the number of examples is within 10K,
  # therefore it's beneficial to store all accumulated values with their example
  # IDs so we are able to track the individual examples that contributed to this
  # metric value. If this premise changes, we might consider using a more
  # efficient data structure.
  data_points: Annotated[
      list[DataPoint],
      'Accumulated computed values with example IDs and weights.'
  ] = []

  total: Annotated[
      int,
      'The total number of examples being evaluated. Including errors.'
  ] = 0

  def _on_bound(self):
    super()._on_bound()
    # Recompute the cached weighted sum from the bound data points.
    self._weighted_sum = sum(dp.value * dp.weight for dp in self.data_points)

  def reset(self) -> None:
    """Resets the value to its initial state."""
    self._sync_members(data_points=[], total=0)
    self._weighted_sum = 0.0

  def _sync_members(self, **kwargs) -> None:
    """Synchronizes the members of this object.

    Rebinds with `skip_notification=True`, so callers that change
    `data_points` are responsible for updating `_weighted_sum` themselves
    (as `reset` does).
    """
    self.rebind(**kwargs, skip_notification=True, raise_on_no_change=False)

  def __float__(self) -> float:
    """Returns the float representation of this object (NaN if total is 0)."""
    if self.total == 0:
      return float('nan')
    return self.reduce()

  @abc.abstractmethod
  def reduce(self) -> float:
    """Reduces the accumulated values into a single value."""

  def increment_total(self, delta: int = 1) -> 'MetricValue':
    """Increments the total number of examples being evaluated.

    Args:
      delta: Amount to add to `total`.

    Returns:
      This object, for chaining.
    """
    self._sync_members(total=self.total + delta)
    return self

  def add(
      self,
      example_id: int,
      value: float,
      weight: float = 1.0,
      increment_total: bool = False,
  ) -> 'MetricValue':
    """Adds a value to the accumulated values.

    Args:
      example_id: ID of the example that produced this value.
      value: The value to accumulate.
      weight: Weight of the value in the running weighted sum.
      increment_total: If True, also increments `total` by 1.

    Returns:
      This object, for chaining.
    """
    self._weighted_sum += value * weight
    # Append in place without triggering pyglove change notifications.
    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
      self.data_points.append(
          MetricValue.DataPoint(example_id, value, weight)
      )
    if increment_total:
      self.increment_total()
    return self

  def __gt__(self, other: Union['MetricValue', float]) -> bool:
    if isinstance(other, self.__class__):
      return float(self) > float(other)
    return float(self) > other

  def __lt__(self, other: Union['MetricValue', float]) -> bool:
    if isinstance(other, self.__class__):
      return float(self) < float(other)
    return float(self) < other

  def __eq__(self, other: Union['MetricValue', float]) -> bool:
    # Same-class comparisons use pyglove's (symbolic) equality; anything else
    # is compared against the reduced float value.
    if isinstance(other, self.__class__):
      return super().__eq__(other)
    return float(self) == other

  # NOTE(review): `__nonzero__` is the Python 2 truth-value hook; Python 3
  # only consults `__bool__`/`__len__`, so this method is not used by
  # `bool(...)`. Left as-is to avoid changing truthiness for existing callers.
  def __nonzero__(self) -> bool:
    return float(self) != 0

  def format(
      self,
      compact: bool = False,
      verbose: bool = True,
      *args,
      **kwargs
  ) -> str:
    """Formats this metric value.

    Args:
      compact: If True, defer to the base class's symbolic formatting.
      verbose: In non-compact mode, whether to append
        `(num_data_points/total)`. It is forwarded to the base class in
        compact mode.
      *args: Extra positional arguments forwarded to the base class.
      **kwargs: Extra keyword arguments forwarded to the base class.

    Returns:
      The formatted string; 'n/a' in non-compact mode when `total` is 0.
    """
    if compact:
      # Forward `verbose` too: dropping it ignored the caller's choice and
      # shifted any extra positional args into the base class's `verbose` slot.
      return super().format(compact, verbose, *args, **kwargs)
    if self.total == 0:
      return 'n/a'
    if verbose:
      return (
          f'{self.scalar_repr()} ({len(self.data_points)}/{self.total})'
      )
    return self.scalar_repr()

  @abc.abstractmethod
  def scalar_repr(self) -> str:
    """Returns the format string for the value."""

  def _sym_nondefault(self) -> dict[str, Any]:
    """Overrides non-default values so volatile values are not included."""
    return dict()
|
133
|
+
|
134
|
+
|
135
|
+
class Rate(MetricValue):
  """A rate: the weighted sum of data point values divided by `total`.

  With values in [0, 1] and weights of at most 1, the result lies in [0, 1].
  """

  def reduce(self) -> float:
    """Returns `weighted_sum / total`, or NaN when `total` is 0."""
    if self.total == 0:
      # Consistent with `MetricValue.__float__` and `Average.reduce`: an empty
      # rate is undefined, not a ZeroDivisionError.
      return float('nan')
    return self._weighted_sum / self.total

  def scalar_repr(self):
    """Returns the rate as a percentage with one decimal, or 'n/a'."""
    if self.total == 0:
      return 'n/a'
    return f'{self.reduce():.1%}'
|
145
|
+
|
146
|
+
|
147
|
+
class Average(MetricValue):
  """Average of accumulated values.

  Computed as the weighted sum of data point values divided by the number of
  data points (not by the sum of their weights).
  """

  def reduce(self) -> float:
    count = len(self.data_points)
    return self._weighted_sum / count if count else float('nan')

  def scalar_repr(self):
    return '{:.3f}'.format(self.reduce())
|
@@ -0,0 +1,80 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import math
|
15
|
+
import unittest
|
16
|
+
|
17
|
+
from langfun.core.eval.v2 import metric_values
|
18
|
+
import pyglove as pg
|
19
|
+
|
20
|
+
|
21
|
+
class RateTest(unittest.TestCase):

  def test_basic(self):
    data_point = metric_values.MetricValue.DataPoint
    rate = metric_values.Rate()

    # A fresh rate has no examples and is undefined.
    self.assertEqual(rate.total, 0)
    self.assertTrue(math.isnan(float(rate)))
    self.assertEqual(pg.format(rate), 'n/a')

    # One evaluated example with no data point yet.
    rate.increment_total()
    self.assertEqual(rate.total, 1)
    self.assertEqual(float(rate), 0.0)

    # Record a full-weight hit for example 1.
    rate.add(1, 1.0, 1.0)
    self.assertEqual(float(rate), 1.0)
    self.assertEqual(pg.format(rate, verbose=False), '100.0%')
    self.assertEqual(pg.format(rate, verbose=True), '100.0% (1/1)')
    self.assertEqual(rate.data_points, [data_point(1, 1.0, 1.0)])

    # Comparisons against plain floats.
    self.assertEqual(rate, 1.0)
    self.assertGreater(rate, 0.5)
    self.assertLess(rate, 1.5)

    # Comparisons against other rates.
    self.assertEqual(
        rate, metric_values.Rate([data_point(1, 1.0, 1.0)], 1)
    )
    empty_rate = metric_values.Rate([], 1)
    self.assertGreater(rate, empty_rate)
    self.assertLess(empty_rate, rate)

    rate.reset()
    self.assertEqual(rate.total, 0)
    self.assertTrue(math.isnan(float(rate)))
|
53
|
+
|
54
|
+
|
55
|
+
class AverageTest(unittest.TestCase):

  def test_basic(self):
    average = metric_values.Average()
    self.assertEqual(average.total, 0)
    self.assertTrue(math.isnan(float(average)))
    self.assertEqual(pg.format(average, verbose=False), 'n/a')

    average.add(1, 1.0, 0.5, increment_total=True)
    average.add(1, 0.0, 1.0, increment_total=True)
    self.assertEqual(average.total, 2)
    # Weighted sum (1.0 * 0.5 + 0.0 * 1.0) divided by the data point count.
    self.assertEqual(float(average), 0.25)
    self.assertEqual(pg.format(average, verbose=False), '0.250')
    self.assertEqual(pg.format(average, verbose=True), '0.250 (2/2)')
    self.assertEqual(
        average.data_points,
        [
            metric_values.MetricValue.DataPoint(1, 1.0, 0.5),
            metric_values.MetricValue.DataPoint(1, 0.0, 1.0),
        ]
    )

    # After reset the value is back to its initial, undefined state
    # (mirrors the post-reset checks in RateTest).
    average.reset()
    self.assertEqual(average.total, 0)
    self.assertEqual(average.data_points, [])
    self.assertTrue(math.isnan(float(average)))
    self.assertEqual(pg.format(average, verbose=False), 'n/a')
|
77
|
+
|
78
|
+
|
79
|
+
if __name__ == '__main__':
  # Run every TestCase defined in this module.
  unittest.main()
|