langfun 0.1.2.dev202510230805__py3-none-any.whl → 0.1.2.dev202511270805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/__init__.py +2 -0
- langfun/core/agentic/__init__.py +4 -1
- langfun/core/agentic/action.py +447 -29
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +149 -21
- langfun/core/async_support.py +32 -3
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +1 -0
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +9 -2
- langfun/core/data/conversion/gemini_test.py +12 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +47 -43
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +1 -0
- langfun/core/eval/v2/checkpointing.py +64 -6
- langfun/core/eval/v2/checkpointing_test.py +9 -2
- langfun/core/eval/v2/eval_test_helper.py +103 -2
- langfun/core/eval/v2/evaluation.py +91 -16
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +50 -40
- langfun/core/eval/v2/example_test.py +16 -8
- langfun/core/eval/v2/experiment.py +74 -8
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +30 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking.py +12 -3
- langfun/core/eval/v2/progress_tracking_test.py +6 -1
- langfun/core/eval/v2/reporting.py +90 -71
- langfun/core/eval/v2/reporting_test.py +24 -6
- langfun/core/eval/v2/runners/__init__.py +30 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +59 -142
- langfun/core/eval/v2/runners/beam.py +341 -0
- langfun/core/eval/v2/runners/beam_test.py +131 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +76 -0
- langfun/core/eval/v2/runners/parallel.py +100 -0
- langfun/core/eval/v2/runners/parallel_test.py +95 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +172 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +7 -5
- langfun/core/language_model.py +141 -21
- langfun/core/language_model_test.py +54 -3
- langfun/core/llms/__init__.py +9 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +55 -17
- langfun/core/llms/gemini_test.py +84 -0
- langfun/core/llms/google_genai.py +34 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +36 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +12 -1
- langfun/core/llms/vertexai.py +58 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/client.py +77 -22
- langfun/core/mcp/client_test.py +8 -35
- langfun/core/mcp/session.py +94 -29
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/tool.py +151 -22
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +19 -1
- langfun/core/modalities/mime.py +64 -3
- langfun/core/modalities/mime_test.py +11 -0
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +215 -142
- langfun/core/structured/querying_test.py +65 -29
- langfun/core/structured/schema/__init__.py +49 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +174 -49
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +8 -2
- langfun/env/base_environment.py +320 -128
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +92 -15
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +84 -361
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +1 -1
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +95 -98
- langfun/env/event_handlers/event_logger_test.py +21 -21
- langfun/env/event_handlers/metric_writer.py +225 -140
- langfun/env/event_handlers/metric_writer_test.py +23 -6
- langfun/env/interface.py +854 -40
- langfun/env/interface_test.py +112 -2
- langfun/env/load_balancers_test.py +23 -2
- langfun/env/test_utils.py +126 -84
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/METADATA +1 -1
- langfun-0.1.2.dev202511270805.dist-info/RECORD +215 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun/env/base_test.py +0 -1481
- langfun/env/event_handlers/base.py +0 -350
- langfun-0.1.2.dev202510230805.dist-info/RECORD +0 -195
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/example.py
CHANGED
@@ -22,19 +22,30 @@ import pyglove as pg
 
 @dataclasses.dataclass
 class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
-  """An
+  """An example for evaluation.
+
+  An evaluation example contains the input and output of an evaluation task,
+  as well as metadata about the evaluation process, such as execution time,
+  LLM usage, and metric results.
 
   Attributes:
-    id: The 1-based ID of the
-    input: An element returned from the `Evaluable.inputs` functor
-
-
-
-
-
-
-
-
+    id: The 1-based ID of the example in the evaluation set.
+    input: An element returned from the `Evaluable.inputs` functor, which serves
+      as the input for `lf.Evaluable.process`.
+    output: The output of `lf.Evaluable.process` method. If `pg.MISSING_VALUE`,
+      it indicates the example has not been processed yet.
+    error: The error raised from `lf.Evaluable.process`. If None, it
+      indicates the process was successful.
+    metadata: The metadata of the example produced by `lf.Evaluable.process`.
+    metric_metadata: The dictionary returned from `Metric.audit`, which contains
+      metadata about metric computation for this example.
+    newly_processed: Whether this example is processed in the current run. If
+      False, it indicates the example was loaded from a checkpoint from previous
+      runs.
+    start_time: The start time of processing this example.
+    end_time: The end time of processing this example.
+    usage_summary: The summary of LLM usages for processing this example.
+    execution_status: The timeit status of processing this example.
   """
   id: int
   input: Any = pg.MISSING_VALUE
@@ -49,14 +60,6 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
   usage_summary: lf.UsageSummary | None = None
   execution_status: dict[str, pg.utils.TimeIt.Status] | None = None
 
-  def __post_init__(self):
-    if self.execution_status is not None:
-      for status in self.execution_status.values():
-        if status.has_error:
-          assert isinstance(status.error, pg.ErrorInfo)
-          self.error = status.error
-          break
-
   @property
   def is_processed(self) -> bool:
     """Returns whether the item has been processed."""
@@ -152,6 +155,8 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
       ckpt_file: str | list[str],
       example_input_by_id: Callable[[int], Any] | None = None,
       load_example_metadata: bool = True,
+      convert_unknown: bool = True,
+      **kwargs
   ) -> Iterator['Example']:
     """Iterates Examples from the checkpoint files."""
     ckpt_files = [ckpt_file] if isinstance(ckpt_file, str) else ckpt_file
@@ -161,7 +166,9 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
         example = pg.from_json_str(
            record,
            example_input_by_id=example_input_by_id,
-           load_example_metadata=load_example_metadata
+           load_example_metadata=load_example_metadata,
+           convert_unknown=convert_unknown,
+           **kwargs
        )
        assert isinstance(example, cls), example
        yield example
@@ -182,15 +189,23 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
     extra_flags = extra_flags or {}
     num_examples = extra_flags.get('num_examples', None)
 
-    def
-
-
-
-
-
-
-
+    def _metric_label_group(metric_metadata: dict[str, Any] | None):
+      """Renders a label group for metric metadata."""
+      badges = []
+      if metric_metadata:
+        for metric_name, metadata in metric_metadata.items():
+          assert isinstance(metadata, dict), (metric_name, metadata)
+          for k, v in metadata.items():
+            css_class = k
+            if isinstance(v, bool):
+              css_class += '_true' if v else '_false'
+            badge = pg.views.html.controls.Badge(
+                f'{k}:{v}',
+                tooltip=f'{metric_name}: {k}',
+                css_classes=[css_class],
+            )
+            badges.append(badge)
+      return pg.views.html.controls.LabelGroup(badges)
 
     def _render_header():
       return pg.Html.element(
@@ -229,12 +244,7 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
               extra_flags=dict(as_badge=True)
          ) if self.usage_summary is not None else None,
          # Metric metadata.
-
-          [  # pylint: disable=g-long-ternary
-              _metric_metadata_badge(k, v)
-              for k, v in self.metric_metadata.items()
-          ] if self.metric_metadata else []
-          ),
+          _metric_label_group(self.metric_metadata)
       ],
       css_classes=['example-container'],
   )
@@ -305,18 +315,18 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
       color: black;
     }
     /* Badge styles. */
-    .eval-example .badge.
+    .eval-example .badge.is_correct_true {
      color: green;
      background-color: #dcefbe;
    }
+    .eval-example .badge.is_correct_false {
+      color: orange;
+      background-color: #ffefc4;
+    }
    .eval-example .badge.error {
      color: red;
      background-color: #fdcccc;
    }
-    .eval-example .badge.mismatch {
-      color: orange;
-      background-color: #ffefc4;
-    }
    .eval-example .badge.score {
      color: blue;
      background-color: #c4dced;
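Taken together, the new docstring and the `_metric_label_group` helper imply a per-metric layout for `metric_metadata`: a dict keyed by metric name whose values are the dicts returned from `Metric.audit` (the matching test update below changes `dict(match=True)` to `dict(match=dict(match=True))`). A minimal sketch of constructing an `Example` with that layout; the field values are illustrative only and the import alias is an assumption, not part of this diff.

```python
# Minimal sketch of the nested `metric_metadata` layout implied by this diff.
# Values are illustrative; only the shape (metric name -> audit dict) matters.
import pyglove as pg
from langfun.core.eval.v2 import example as example_lib

ex = example_lib.Example(
    id=1,
    input=pg.Dict(a=1, b=2),
    output=3,
    metadata=dict(sum=3),
    # Each metric contributes its own sub-dict (previously a flat dict).
    metric_metadata=dict(match=dict(match=True)),
)
assert not ex.has_error  # no error was recorded for this example
```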
langfun/core/eval/v2/example_test.py
CHANGED
@@ -32,9 +32,9 @@ class ExampleTest(unittest.TestCase):
             name='evaluation', elapse=1.0, error=error
         )
     })
-    self.
+    self.assertIsNone(ex.error)
     self.assertFalse(ex.is_processed)
-    self.
+    self.assertFalse(ex.has_error)
     self.assertEqual(ex.elapse, 1.0)
 
     ex = Example(id=2, output=1)
@@ -94,15 +94,23 @@ class ExampleTest(unittest.TestCase):
     pg.JSONConvertible._TYPE_REGISTRY._type_to_cls_map.pop(
         inputs[0].b.__type_name__
     )
-    v = pg.from_json_str(
-
-
+    v = pg.from_json_str(
+        json_str,
+        convert_unknown=True,
+        load_example_metadata=True
+    )
     self.assertEqual(
         v,
         Example(
             id=1,
-            output=pg.
-
+            output=pg.symbolic.UnknownTypedObject(
+                inputs[0].a.__type_name__, x=1
+            ),
+            metadata=dict(
+                b=pg.symbolic.UnknownTypedObject(
+                    inputs[0].b.__type_name__, x=1, y=2
+                )
+            ),
         )
     )
     # Serialize with input.
@@ -116,7 +124,7 @@ class ExampleTest(unittest.TestCase):
         input=pg.Dict(a=1, b=2),
         output=3,
         metadata=dict(sum=3),
-        metric_metadata=dict(match=True),
+        metric_metadata=dict(match=dict(match=True)),
     )
     self.assertNotIn(
         'next',
langfun/core/eval/v2/experiment.py
CHANGED
@@ -139,10 +139,10 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
 
   # Checkpointing
 
-  Experiments support checkpointing, which is enabled by default. It allows
+  Experiments support checkpointing, which is enabled by default. It allows
   users to resume their experiments from a saved state. When an experiment runs,
-  it creates a new directory for that run and saves
-
+  it creates a new directory for that run and saves its progress to checkpoint
+  files. If the experiment is interrupted or fails, users can resume
   it by specifying the 'id' or 'warm_start_from' argument (shown above) to
   seamlessly continue from previously saved state without starting over.
 
@@ -169,7 +169,7 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
 
   # Experiment Plugins
 
-
+  Experiments can be extended by plugins. Plugins can listen to the events of
   experiment execution and produce additional outputs. For example, a plugin
   can be added to an experiment to generate additional metrics or to save
   additional data to a database. More details will be added in the future.
@@ -657,7 +657,30 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
 
 @pg.use_init_args(['children'])
 class Suite(Experiment):
-  """A suite of evaluations.
+  """A suite of evaluations.
+
+  `lf.eval.Suite` groups multiple `lf.eval.Evaluation` or other `Suite`
+  objects into a single experiment, allowing them to be run, managed, and
+  reported together.
+
+  **Example:**
+
+  ```python
+  import langfun as lf
+
+  suite = lf.eval.Suite([
+      MyEval(lm=lf.llms.Gpt4()),
+      MyEval(lm=lf.llms.Gemini()),
+      lf.eval.Suite([
+          AnotherEval(lm=lf.llms.Gpt4()),
+          AnotherEval(lm=lf.llms.Gemini())
+      ])
+  ])
+
+  # Run all evaluations in the suite
+  run_info = suite.run('/path/to/my/suite_run')
+  ```
+  """
 
   children: Annotated[
       list[Experiment], 'A list of child experiments.'
@@ -791,7 +814,14 @@ class RunId(pg.Object):
 
 
 class Run(pg.Object, pg.views.html.HtmlTreeView.Extension):
-  """
+  """Represents a single run of an experiment.
+
+  A `Run` object holds all the configurations for executing an experiment,
+  such as the experiment definition, input/output directories, and flags
+  controlling the execution behavior (e.g., error handling, checkpointing).
+  It also provides utility methods for accessing run-specific paths and
+  filtering examples for evaluation.
+  """
 
   root_dir: Annotated[
       str,
@@ -971,7 +1001,13 @@ class Run(pg.Object, pg.views.html.HtmlTreeView.Extension):
 
 
 class Runner(pg.Object):
-  """Interface for experiment runner.
+  """Interface for experiment runner.
+
+  A runner is responsible for executing the evaluations within an experiment
+  based on the configuration specified in a `Run` object. Different runners
+  can implement different execution strategies, such as sequential or parallel
+  processing of examples and evaluations.
+  """
 
   # Class-level variable for registering the runner.
   NAME = None
@@ -1010,7 +1046,37 @@ class Runner(pg.Object):
 
 
 class Plugin(lf.Component):
-  """Base class for experiment plugins.
+  """Base class for experiment plugins.
+
+  Plugins provide a mechanism to extend the behavior of an experiment run
+  by hooking into various events during the lifecycle of experiment and
+  example execution, such as `on_run_start`, `on_experiment_complete`,
+  `on_example_start`, etc. They can be used for custom logging, monitoring,
+  or result processing.
+  """
+
+  @classmethod
+  def is_per_example(cls) -> bool:
+    """Returns whether the plugin is per example only.
+
+    Per-example plugins can be installed on individual workers when examples
+    are evaluated by multiple processes in parallel.
+    """
+
+    def same_code(method1, method2):
+      return method1.__code__ == method2.__code__
+    return all(
+        same_code(method1, method2)
+        for method1, method2 in [
+            (Plugin.on_run_start, cls.on_run_start),
+            (Plugin.on_run_complete, cls.on_run_complete),
+            (Plugin.on_run_abort, cls.on_run_abort),
+            (Plugin.on_experiment_start, cls.on_experiment_start),
+            (Plugin.on_experiment_skipped, cls.on_experiment_skipped),
+            (Plugin.on_experiment_complete, cls.on_experiment_complete),
+            (Plugin.on_experiment_abort, cls.on_experiment_abort),
+        ]
+    )
 
   def on_run_start(
       self,
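The new `Plugin.is_per_example()` classifies a plugin by comparing the `__code__` objects of its run- and experiment-level hooks against the base class: if none of them are overridden, the plugin only reacts to example-level events and can be installed on individual workers. A small sketch mirroring the new `PluginTest` below; the plugin class names here are hypothetical.

```python
# Sketch of the per-example check added to `Plugin` in this diff.
# `ExampleLogger` / `RunSummarizer` are hypothetical plugin names.
from langfun.core.eval.v2 import experiment as experiment_lib


class ExampleLogger(experiment_lib.Plugin):
  """Only overrides an example-level hook, so it is per-example."""

  def on_example_complete(self, runner, experiment, example):
    print(f'example {example.id} completed')


class RunSummarizer(experiment_lib.Plugin):
  """Overrides an experiment-level hook, so it is not per-example."""

  def on_experiment_complete(self, runner, experiment):
    print('experiment completed')


assert ExampleLogger.is_per_example()
assert not RunSummarizer.is_per_example()
```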
langfun/core/eval/v2/experiment_test.py
CHANGED
@@ -433,5 +433,24 @@ class RunnerTest(unittest.TestCase):
     pass
 
 
+class PluginTest(unittest.TestCase):
+
+  def test_per_example_only(self):
+
+    class PerExamplePlugin(experiment_lib.Plugin):
+
+      def on_example_complete(self, runner, experiment, example):
+        print('on_example_complete')
+
+    self.assertTrue(PerExamplePlugin.is_per_example())
+
+    class NonPerExamplePlugin(experiment_lib.Plugin):
+
+      def on_experiment_complete(self, runner, experiment):
+        print('on_example_complete')
+
+    self.assertFalse(NonPerExamplePlugin.is_per_example())
+
+
 if __name__ == '__main__':
   unittest.main()
langfun/core/eval/v2/metric_values.py
CHANGED
@@ -20,7 +20,15 @@ import pyglove as pg
 
 
 class MetricValue(pg.Object):
-  """Base class for metric values.
+  """Base class for metric values.
+
+  `MetricValue` is the base class for representing aggregated metric values
+  in an evaluation. It accumulates data points from individual examples,
+  each consisting of a value and an optional weight, associated with an example
+  ID. Subclasses must implement `reduce` method to compute a single float value
+  from accumulated data points, and `scalar_repr` to provide a string
+  representation of the reduced value.
+  """
 
   class DataPoint(pg.Object):
     """A data point for a metric value."""
@@ -88,6 +96,14 @@ class MetricValue(pg.Object):
     self.increment_total()
     return self
 
+  def merge_from(self, other: 'MetricValue') -> 'MetricValue':
+    """Merges the values from another metric value."""
+    self._weighted_sum += other._weighted_sum  # pylint: disable=protected-access
+    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+      self.data_points.extend(other.data_points)
+    self.increment_total(other.total)
+    return self
+
   def __gt__(self, other: Union['MetricValue', float]) -> bool:
     if isinstance(other, self.__class__):
       return float(self) > float(other)
@@ -133,7 +149,13 @@ class MetricValue(pg.Object):
 
 
 class Rate(MetricValue):
-  """
+  """Metric value representing a rate in range [0, 1].
+
+  `Rate` is used for metrics that compute a rate, such as accuracy or error
+  rate. The final value is computed as the weighted sum of accumulated values
+  divided by the total number of examples. It's displayed as a percentage
+  (e.g., 90.0%).
+  """
 
   def reduce(self) -> float:
     return self._weighted_sum / self.total
@@ -145,7 +167,13 @@ class Rate(MetricValue):
 
 
 class Average(MetricValue):
-  """
+  """Metric value representing an average of accumulated values.
+
+  `Average` is used for metrics that compute an average score across examples
+  (e.g., average quality score). The final value is computed as the weighted
+  sum of accumulated values divided by the number of data points.
+  It's displayed as a float with 3 decimal places (e.g., 4.750).
+  """
 
   def reduce(self) -> float:
     if not self.data_points:
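The two subclasses reduce the same accumulated data points differently: `Rate` divides the weighted sum by `total` (the count maintained via `increment_total`), while `Average` divides by the number of data points. A short sketch of that difference, relying only on the `add` / `float()` behavior exercised in metric_values_test.py; the concrete scores are made up.

```python
# Sketch: `Rate` vs. `Average` reduction, per the `reduce` methods in this
# diff. `add(example_id, value, weight, increment_total=...)` follows the
# usage in metric_values_test.py; the concrete values are made up.
from langfun.core.eval.v2 import metric_values

rate = metric_values.Rate()
rate.add(1, 1.0, 1.0, increment_total=True)   # example 1 counted as a match
rate.add(2, 0.0, 1.0, increment_total=True)   # example 2 counted as a miss
assert float(rate) == 0.5                     # weighted sum (1.0) / total (2)

avg = metric_values.Average()
avg.add(1, 4.0, 1.0, increment_total=True)    # e.g. a quality score of 4.0
avg.add(2, 5.0, 1.0, increment_total=True)    # e.g. a quality score of 5.0
assert float(avg) == 4.5                      # weighted sum (9.0) / 2 data points
```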
langfun/core/eval/v2/metric_values_test.py
CHANGED
@@ -51,6 +51,22 @@ class RateTest(unittest.TestCase):
     self.assertEqual(rate.total, 0)
     self.assertTrue(math.isnan(float(rate)))
 
+  def test_merge_from(self):
+    rate1 = metric_values.Rate()
+    rate1.add(1, 1.0, 1.0, increment_total=True)
+    rate2 = metric_values.Rate()
+    rate2.add(2, 0.0, 1.0, increment_total=True)
+    rate1.merge_from(rate2)
+    self.assertEqual(rate1.total, 2)
+    self.assertEqual(float(rate1), 0.5)
+    self.assertEqual(
+        rate1.data_points,
+        [
+            metric_values.MetricValue.DataPoint(1, 1.0, 1.0),
+            metric_values.MetricValue.DataPoint(2, 0.0, 1.0),
+        ],
+    )
+
 
 class AverageTest(unittest.TestCase):
 
@@ -75,6 +91,22 @@ class AverageTest(unittest.TestCase):
     average.reset()
     self.assertEqual(average.total, 0)
 
+  def test_merge_from(self):
+    avg1 = metric_values.Average()
+    avg1.add(1, 1.0, 0.5, increment_total=True)
+    avg2 = metric_values.Average()
+    avg2.add(2, 0.0, 1.0, increment_total=True)
+    avg1.merge_from(avg2)
+    self.assertEqual(avg1.total, 2)
+    self.assertEqual(float(avg1), 0.25)
+    self.assertEqual(
+        avg1.data_points,
+        [
+            metric_values.MetricValue.DataPoint(1, 1.0, 0.5),
+            metric_values.MetricValue.DataPoint(2, 0.0, 1.0),
+        ],
+    )
+
 
 if __name__ == '__main__':
   unittest.main()
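Beyond the unit tests above, `merge_from` makes it possible to combine metric values that were accumulated independently, with the merged data points keeping their original example IDs. The per-worker framing below is an inference from this release's new `runners/parallel.py` and `runners/beam.py`, not something the diff states; the sketch only uses behavior exercised by the tests.

```python
# Sketch: combining independently accumulated Rate values with the new
# `merge_from`. The "per-worker" framing is an assumption, not stated in
# the diff; only the merge semantics below are exercised by the tests.
from langfun.core.eval.v2 import metric_values

worker_a = metric_values.Rate()
worker_a.add(1, 1.0, 1.0, increment_total=True)
worker_a.add(2, 1.0, 1.0, increment_total=True)

worker_b = metric_values.Rate()
worker_b.add(3, 0.0, 1.0, increment_total=True)

worker_a.merge_from(worker_b)        # in-place merge; returns self
assert worker_a.total == 3
assert abs(float(worker_a) - 2 / 3) < 1e-9
assert len(worker_a.data_points) == 3
```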