langfun 0.1.2.dev202411090804__py3-none-any.whl → 0.1.2.dev202411140804__py3-none-any.whl
Potentially problematic release.
- langfun/core/console.py +10 -2
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +2 -0
- langfun/core/eval/v2/__init__.py +38 -0
- langfun/core/eval/v2/checkpointing.py +135 -0
- langfun/core/eval/v2/checkpointing_test.py +89 -0
- langfun/core/eval/v2/evaluation.py +627 -0
- langfun/core/eval/v2/evaluation_test.py +156 -0
- langfun/core/eval/v2/example.py +295 -0
- langfun/core/eval/v2/example_test.py +114 -0
- langfun/core/eval/v2/experiment.py +949 -0
- langfun/core/eval/v2/experiment_test.py +304 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +209 -0
- langfun/core/eval/v2/progress_tracking_test.py +56 -0
- langfun/core/eval/v2/reporting.py +144 -0
- langfun/core/eval/v2/reporting_test.py +41 -0
- langfun/core/eval/v2/runners.py +417 -0
- langfun/core/eval/v2/runners_test.py +311 -0
- langfun/core/eval/v2/test_helper.py +80 -0
- langfun/core/language_model.py +122 -11
- langfun/core/language_model_test.py +97 -4
- langfun/core/llms/__init__.py +3 -0
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/vertexai.py +4 -4
- {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/RECORD +36 -12
- {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/WHEEL +1 -1
- {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/reporting.py
@@ -0,0 +1,144 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Reporting evaluation results."""
+
+import time
+from typing import Annotated
+
+from langfun.core.eval.v2 import example as example_lib
+from langfun.core.eval.v2 import experiment as experiment_lib
+
+Runner = experiment_lib.Runner
+Experiment = experiment_lib.Experiment
+Example = example_lib.Example
+
+
+_SUMMARY_FILE = 'summary.html'
+_EVALULATION_DETAIL_FILE = 'index.html'
+
+
+class HtmlReporter(experiment_lib.Plugin):
+  """Plugin for periodically generating HTML reports for the experiment."""
+
+  summary_interval: Annotated[
+      int,
+      'The interval of writing summary in seconds.'
+  ] = 60
+
+  experiment_report_interval: Annotated[
+      int,
+      'The interval of writing report for individual experiments in seconds.'
+  ] = 60
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._last_summary_time = 0
+    self._last_experiment_report_time = {}
+
+  def on_run_start(
+      self,
+      runner: Runner,
+      root: Experiment
+  ) -> None:
+    self._maybe_update_summary(runner)
+    self._last_experiment_report_time = {leaf.id: 0 for leaf in root.leaf_nodes}
+
+  def on_run_complete(
+      self,
+      runner: Runner,
+      root: Experiment
+  ) -> None:
+    self._maybe_update_summary(runner, force=True)
+
+  def on_experiment_start(
+      self,
+      runner: Runner,
+      experiment: Experiment
+  ) -> None:
+    if experiment.is_leaf:
+      self._maybe_update_experiment_html(runner, experiment)
+
+  def on_experiment_complete(
+      self, runner: Runner, experiment: Experiment
+  ):
+    if experiment.is_leaf:
+      self._maybe_update_experiment_html(runner, experiment, force=True)
+
+  def on_example_complete(
+      self, runner: Runner, experiment: Experiment, example: Example
+  ):
+    self._save_example_html(runner, experiment, example)
+    self._maybe_update_experiment_html(runner, experiment)
+    self._maybe_update_summary(runner)
+
+  def _maybe_update_summary(self, runner: Runner, force: bool = False) -> None:
+    """Maybe update the summary of the current run."""
+    run = runner.current_run
+    def _summary():
+      run.experiment.to_html(
+          collapse_level=None,
+          extra_flags=dict(
+              current_run=run, interactive=False, card_view=True,
+          )
+      ).save(
+          run.output_path_for(run.experiment, _SUMMARY_FILE)
+      )
+
+    if force or (time.time() - self._last_summary_time > self.summary_interval):
+      runner.background_run(_summary)
+      self._last_summary_time = time.time()
+
+  def _maybe_update_experiment_html(
+      self, runner: Runner, experiment: Experiment, force: bool = False
+  ) -> None:
+    def _save():
+      html = experiment.to_html(
+          collapse_level=None,
+          extra_flags=dict(
+              current_run=runner.current_run,
+              interactive=False,
+              card_view=False,
+          ),
+      )
+      html.save(
+          runner.current_run.output_path_for(
+              experiment, _EVALULATION_DETAIL_FILE
+          )
+      )
+    if force or (
+        time.time() - self._last_experiment_report_time[experiment.id]
+        > self.experiment_report_interval
+    ):
+      runner.background_run(_save)
+      self._last_experiment_report_time[experiment.id] = time.time()
+
+  def _save_example_html(
+      self, runner: Runner, experiment: Experiment, example: Example
+  ) -> None:
+    """Saves the example."""
+    def _save():
+      html = example.to_html(
+          collapse_level=None,
+          enable_summary_tooltip=False,
+          extra_flags=dict(
+              # For properly rendering the next link.
+              num_examples=getattr(experiment, 'num_examples', None)
+          ),
+      )
+      html.save(
+          runner.current_run.output_path_for(
+              experiment, f'{example.id}.html'
+          )
+      )
+    runner.background_run(_save)
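For context, here is a minimal usage sketch for the new HtmlReporter plugin, modeled on reporting_test.py below. It assumes that, as a PyGlove symbolic object, the plugin's summary_interval and experiment_report_interval fields can be set via the constructor; nothing in this sketch is part of the diff itself:

    import os
    import tempfile

    from langfun.core.eval.v2 import reporting
    from langfun.core.eval.v2 import test_helper

    experiment = test_helper.test_experiment()
    # Refresh summary.html and each leaf's index.html every 30 seconds
    # instead of the default 60.
    reporter = reporting.HtmlReporter(
        summary_interval=30, experiment_report_interval=30)
    run = experiment.run(
        os.path.join(tempfile.gettempdir(), 'my_eval'), 'new',
        plugins=[reporter])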
langfun/core/eval/v2/reporting_test.py
@@ -0,0 +1,41 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import tempfile
+import unittest
+
+from langfun.core.eval.v2 import reporting
+from langfun.core.eval.v2 import runners as runners_lib  # pylint: disable=unused-import
+from langfun.core.eval.v2 import test_helper
+import pyglove as pg
+
+
+class ReportingTest(unittest.TestCase):
+
+  def test_reporting(self):
+    root_dir = os.path.join(tempfile.gettempdir(), 'test_reporting')
+    experiment = test_helper.test_experiment()
+    reporter = reporting.HtmlReporter()
+    run = experiment.run(root_dir, 'new', plugins=[reporter])
+    self.assertTrue(
+        pg.io.path_exists(run.output_path_for(experiment, 'summary.html')))
+    for leaf in experiment.leaf_nodes:
+      self.assertTrue(
+          pg.io.path_exists(run.output_path_for(leaf, 'index.html'))
+      )
+      for i in range(leaf.num_examples):
+        self.assertTrue(
+            pg.io.path_exists(run.output_path_for(leaf, f'{i + 1}.html')))
+
+
+if __name__ == '__main__':
+  unittest.main()
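As a side note, the generated reports are plain HTML files, so after a run like the one above they can be opened directly for manual inspection. A small sketch, assuming local file paths and with run and experiment as defined in the test:

    import webbrowser

    # Hypothetical follow-up to the test above: open the run-level summary.
    summary_path = run.output_path_for(experiment, 'summary.html')
    if pg.io.path_exists(summary_path):
      webbrowser.open('file://' + summary_path)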
langfun/core/eval/v2/runners.py
@@ -0,0 +1,417 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Evaluation experiment runners."""
+import abc
+import collections
+import concurrent.futures
+from typing import Any, Annotated, Callable, Iterator
+
+from langfun import core as lf
+from langfun.core.eval.v2 import checkpointing
+from langfun.core.eval.v2 import evaluation as evaluation_lib
+from langfun.core.eval.v2 import example as example_lib
+from langfun.core.eval.v2 import experiment as experiment_lib
+from langfun.core.eval.v2 import progress_tracking
+from langfun.core.eval.v2 import reporting
+from langfun.core.llms.cache import in_memory
+import pyglove as pg
+
+Runner = experiment_lib.Runner
+Example = example_lib.Example
+Evaluation = evaluation_lib.Evaluation
+Experiment = experiment_lib.Experiment
+Plugin = experiment_lib.Plugin
+
+
+_RUN_MANIFEST = 'run.json'
+
+
+class RunnerBase(Runner):
+  """Base class for evaluation experiment runners."""
+
+  tqdm: Annotated[
+      bool,
+      (
+          'If True, force using tqdm for progress update. Otherwise, determine '
+          'it automatically based on the running environment (console vs. '
+          'notebook).'
+      )
+  ] = False
+
+  plugins = [
+      checkpointing.Checkpointer(),
+      reporting.HtmlReporter(),
+  ]
+
+  def _on_bound(self):
+    super()._on_bound()
+
+    # Install the tqdm plugin if needed.
+    with pg.notify_on_change(False):
+      self.plugins.append(progress_tracking.progress_tracker(self.tqdm))
+
+    self._io_pool = concurrent.futures.ThreadPoolExecutor(max_workers=16)
+    # TODO(daiyip): render background errors.
+    self._background_last_error = None
+
+  def background_run(self, func: Callable[..., Any], *args, **kwargs) -> None:
+    """Runs the function with the IO pool."""
+    def _background_run(*args, **kwargs):
+      try:
+        func(*args, **kwargs)
+      except Exception as e:  # pylint: disable=broad-except
+        self._background_last_error = e
+    self._io_pool.submit(_background_run, *args, **kwargs)
+
+  def _all_plugins(self, experiment: Experiment) -> Iterator[Plugin]:
+    """Returns all plugins for the experiment."""
+    for plugin in self.plugins:
+      yield plugin
+    for plugin in experiment.plugins:
+      yield plugin
+
+  #
+  # IO operations for saving run files.
+  #
+
+  def _save_run_manifest(self) -> None:
+    def _save():
+      pg.symbolic.deref(self.current_run.clone(), recursive=True).save(
+          self.current_run.output_path_for(
+              self.current_run.experiment, _RUN_MANIFEST
+          ),
+          hide_default_values=True
+      )
+    self.background_run(_save)
+
+  def on_run_start(self) -> None:
+    """Called when a runner is started."""
+    self._save_run_manifest()
+
+    for plugin in self._all_plugins(self.current_run.experiment):
+      plugin.on_run_start(self, self.current_run.experiment)
+
+  def on_run_complete(self) -> None:
+    """Called when a runner is complete."""
+    for plugin in self._all_plugins(self.current_run.experiment):
+      plugin.on_run_complete(self, self.current_run.experiment)
+
+  def on_run_abort(self, error: Exception) -> None:
+    """Called when a runner is aborted."""
+    for plugin in self._all_plugins(self.current_run.experiment):
+      plugin.on_run_abort(self, self.current_run.experiment, error)
+
+  def on_experiment_start(self, experiment: Experiment) -> None:
+    """Called when an evaluation is started."""
+    # Start the progress of the evaluation.
+    if experiment.is_leaf:
+      assert isinstance(experiment, Evaluation)
+      experiment.progress.start(
+          total=(len(self.current_run.example_ids)
+                 if self.current_run.example_ids else experiment.num_examples)
+      )
+    else:
+      experiment.progress.start(total=len(experiment.leaf_nodes))
+
+    # Notify the plugins of the experiment start.
+    for plugin in self._all_plugins(experiment):
+      plugin.on_experiment_start(self, experiment)
+
+  def on_experiment_skipped(self, experiment: Experiment) -> None:
+    """Called when an evaluation is skipped."""
+    # The skip event will only be triggered for leaf evaluations.
+    assert experiment.is_leaf
+    experiment.progress.start(total=1)
+    experiment.progress.increment_skipped(1)
+
+    # Notify the plugins of the experiment skip.
+    for plugin in self._all_plugins(experiment):
+      plugin.on_experiment_skipped(self, experiment)
+
+    # Only leaf evaluations will trigger the complete notification of the
+    # ancestors.
+    if experiment.is_leaf:
+      self._update_ancestor_progresses(experiment)
+
+  def on_experiment_complete(self, experiment: Experiment) -> None:
+    """Called when an evaluation is complete."""
+    progress = experiment.progress
+    progress.stop()
+
+    # Notify the plugins of the experiment completion.
+    for plugin in self._all_plugins(experiment):
+      plugin.on_experiment_complete(self, experiment)
+
+    # Only leaf evaluations will trigger the complete notification of the
+    # ancestors.
+    if experiment.is_leaf:
+      self._update_ancestor_progresses(experiment)
+
+  def _update_ancestor_progresses(self, experiment: Experiment):
+    """Updates the progresses of the parent nodes of the experiment."""
+    parent = experiment.parent
+    progress = experiment.progress
+    while parent is not None:
+      parent_progress = parent.progress
+      if progress.is_failed:
+        parent_progress.increment_failed()
+      elif progress.is_skipped:
+        parent_progress.increment_skipped()
+      else:
+        # An evaluation is considered done once it has processed all the
+        # examples specified by `example_ids`.
+        assert progress.is_completed
+        parent_progress.increment_processed()
+
+      if parent_progress.is_completed:
+        self.on_experiment_complete(parent)
+      elif parent_progress.is_skipped:
+        self.on_experiment_skipped(parent)
+      parent = parent.parent
+
+  def on_example_start(
+      self,
+      experiment: Experiment,
+      example: Example
+  ) -> None:
+    """Called when an evaluation example is started."""
+    for plugin in self._all_plugins(experiment):
+      plugin.on_example_start(self, experiment, example)
+
+  def on_example_complete(
+      self,
+      experiment: Experiment,
+      example: Example
+  ) -> None:
+    """Called when an evaluation example is complete."""
+    if example.newly_processed:
+      if example.error is None:
+        experiment.progress.increment_processed()
+      else:
+        experiment.progress.increment_failed()
+    else:
+      experiment.progress.increment_skipped()
+
+    experiment.usage_summary.merge(example.usage_summary)
+    experiment.progress.update_execution_summary(example.execution_status)
+
+    parent = experiment.parent
+    while parent is not None:
+      parent.usage_summary.merge(example.usage_summary)
+      parent = parent.parent
+
+    for plugin in self._all_plugins(experiment):
+      plugin.on_example_complete(self, experiment, example)
+
+  def run(self) -> None:
+    """Runs the experiment."""
+    # Reset the experiment before getting started.
+    for node in self.current_run.experiment.nodes:
+      node.reset()
+
+    # Start the run.
+    self.on_run_start()
+    cache = None
+
+    try:
+      # Start the non-leaf nodes.
+      for node in self.current_run.experiment.nonleaf_nodes:
+        self.on_experiment_start(node)
+
+      # Skip evaluations if needed.
+      if self.current_run.filter is not None:
+        targets = []
+        for evaluation in self.current_run.experiment.leaf_nodes:
+          if self.current_run.filter(evaluation):
+            targets.append(evaluation)
+          else:
+            self.on_experiment_skipped(evaluation)
+      else:
+        targets = self.current_run.experiment.leaf_nodes
+
+      # Prepare the global cache if needed.
+      global_settings = {}
+      if self.current_run.use_cache == 'global':
+        cache = self._load_or_create_cache(self.current_run.experiment)
+        global_settings['cache'] = cache
+
+      # Evaluate the leaf evaluations if not skipped.
+      with lf.use_settings(**global_settings):
+        self._run(targets)
+
+      self.on_run_complete()
+    except Exception as e:  # pylint: disable=broad-except
+      self.on_run_abort(e)
+      raise e
+    finally:
+      if cache is not None:
+        self.background_run(cache.save)
+
+  @abc.abstractmethod
+  def _run(self, evaluations: list[Evaluation]) -> None:
+    """Runs multiple evaluations."""
+
+  def run_evaluation(self, evaluation: Evaluation) -> None:
+    """Runs the evaluation."""
+    self.on_experiment_start(evaluation)
+
+    per_evaluation_settings = {}
+    cache = None
+    if self.current_run.use_cache == 'per_dataset':
+      cache = self._load_or_create_cache(evaluation)
+      per_evaluation_settings['cache'] = cache
+
+    with lf.use_settings(**per_evaluation_settings):
+      if self.current_run.example_ids is None:
+        items = (
+            Example(id=i + 1, input=ex) for i, ex in enumerate(
+                evaluation.example_inputs)
+        )
+      else:
+        items = (
+            Example(
+                id=example_id, input=evaluation.example_input_by_id(example_id)
+            ) for example_id in self.current_run.example_ids
+        )
+      self._evaluate_items(evaluation, items)
+
+    if cache:
+      self.background_run(cache.save)
+    self.on_experiment_complete(evaluation)
+
+  @abc.abstractmethod
+  def _evaluate_items(
+      self, evaluation: Evaluation, items: Iterator[Example]
+  ) -> None:
+    """Evaluates the items of an evaluation."""
+
+  def evaluate_item(
+      self,
+      evaluation: Evaluation,
+      item: Example
+  ) -> Example:
+    """Runs the evaluation example."""
+    self.on_example_start(evaluation, item)
+    item = evaluation.evaluate(
+        item, raise_if_has_error=self.current_run.raise_if_has_error
+    )
+    self.on_example_complete(evaluation, item)
+    return item
+
+  def _load_or_create_cache(self, experiment: Experiment) -> lf.LMCache | None:
+    """Loads or creates the cache."""
+    return in_memory.InMemory(
+        self.current_run.output_path_for(experiment, 'cache.json')
+    )
+
+
+class SequentialRunner(RunnerBase):
+  """Sequential runner.
+
+  The sequential runner runs all evaluations and their examples in sequence,
+  including the background tasks. This allows the developer to catch all
+  exceptions thrown from background tasks, making debugging easier.
+  """
+
+  NAME = 'sequential'
+
+  def background_run(
+      self, func: Callable[..., Any], *args: Any, **kwargs: Any
+  ) -> None:
+    """Runs the function inline instead of submitting it to the IO pool."""
+    func(*args, **kwargs)
+
+  def _run(self, evaluations: list[Evaluation]) -> None:
+    """Runs the experiment in sequence."""
+    for e in evaluations:
+      self.run_evaluation(e)
+
+  def _evaluate_items(
+      self, evaluation: Evaluation, items: Iterator[Example]
+  ) -> None:
+    """Runs the evaluation items in sequence."""
+    for item in items:
+      self.evaluate_item(evaluation, item)
+
+
+class DebugRunner(SequentialRunner):
+  """Debug runner."""
+
+  NAME = 'debug'
+
+  # Do not use the checkpointer for the debug runner.
+  plugins = []
+
+  def _on_bound(self):
+    super()._on_bound()
+    if self.current_run.example_ids is None:
+      self.current_run.rebind(example_ids=[1], skip_notification=True)
+    self.current_run.rebind(raise_if_has_error=True, skip_notification=True)
+
+  def _save_run_manifest(self) -> None:
+    """Does nothing, to avoid overwriting existing runs."""
+
+
+class ParallelRunner(RunnerBase):
+  """Parallel runner."""
+
+  NAME = 'parallel'
+
+  timeout: Annotated[
+      int | None,
+      'Timeout for each evaluation example.'
+  ] = None
+
+  def _run(self, evaluations: list[Evaluation]) -> None:
+    """Runs the evaluations in parallel."""
+    def _run_group(evaluation_group: list[Evaluation]):
+      for e in evaluation_group:
+        self.run_evaluation(e)
+
+    # Run evaluations in parallel, grouped by resource key.
+    groups: dict[str, list[Evaluation]] = collections.defaultdict(list)
+    for e in evaluations:
+      resource_ids = e.resource_ids()
+      if not resource_ids:
+        group_id = e.id
+      else:
+        # TODO(daiyip): support groups that require multiple resources.
+        group_id = resource_ids.pop()
+      groups[group_id].append(e)
+
+    for _, _, _ in lf.concurrent_map(
+        _run_group,
+        groups.values(),
+        max_workers=max(64, len(groups)),
+        timeout=self.timeout,
+        silence_on_errors=(
+            None if self.current_run.raise_if_has_error else BaseException
+        )
+    ):
+      pass
+
+  def _evaluate_items(
+      self, evaluation: Evaluation, items: Iterator[Example]
+  ) -> None:
+    """Overrides to evaluate the items in parallel."""
+    for _, _, _ in lf.concurrent_map(
+        lambda item: self.evaluate_item(evaluation, item),
+        items,
+        max_workers=evaluation.max_workers,
+        timeout=self.timeout,
+        silence_on_errors=(
+            None if self.current_run.raise_if_has_error else BaseException
+        )
+    ):
+      pass
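To close, a hedged sketch of how these runners might be selected. The diff shows only the runner classes and their NAME constants ('sequential', 'debug', 'parallel'); the assumption below is that Experiment.run dispatches on those names via a runner argument, which this diff does not itself show:

    import os
    import tempfile

    from langfun.core.eval.v2 import test_helper

    experiment = test_helper.test_experiment()
    root_dir = os.path.join(tempfile.gettempdir(), 'my_eval')

    # Hypothetical: pick the parallel runner, which groups evaluations by
    # resource id and fans out examples via lf.concurrent_map.
    run = experiment.run(root_dir, 'new', runner='parallel')

    # Hypothetical: the debug runner narrows the run to example_ids=[1],
    # drops the default plugins, and re-raises example errors.
    run = experiment.run(root_dir, 'new', runner='debug')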