langfun 0.1.2.dev202508250805__py3-none-any.whl → 0.1.2.dev202511110805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic.
- langfun/__init__.py +1 -1
- langfun/core/__init__.py +6 -1
- langfun/core/agentic/__init__.py +4 -0
- langfun/core/agentic/action.py +412 -103
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +68 -6
- langfun/core/async_support.py +104 -5
- langfun/core/async_support_test.py +23 -0
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +9 -2
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +9 -2
- langfun/core/data/conversion/gemini_test.py +12 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +47 -43
- langfun/core/eval/base_test.py +4 -4
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +1 -0
- langfun/core/eval/v2/checkpointing.py +30 -4
- langfun/core/eval/v2/eval_test_helper.py +1 -1
- langfun/core/eval/v2/evaluation.py +60 -14
- langfun/core/eval/v2/example.py +22 -11
- langfun/core/eval/v2/experiment.py +51 -8
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +39 -4
- langfun/core/eval/v2/metrics_test.py +14 -0
- langfun/core/eval/v2/progress.py +30 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking_test.py +6 -0
- langfun/core/eval/v2/reporting.py +90 -71
- langfun/core/eval/v2/reporting_test.py +20 -6
- langfun/core/eval/v2/runners.py +27 -7
- langfun/core/eval/v2/runners_test.py +3 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +6 -4
- langfun/core/language_model.py +151 -31
- langfun/core/language_model_test.py +9 -3
- langfun/core/llms/__init__.py +12 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +39 -1
- langfun/core/llms/fake_test.py +9 -0
- langfun/core/llms/gemini.py +43 -7
- langfun/core/llms/google_genai.py +34 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +93 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +16 -1
- langfun/core/llms/vertexai.py +59 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/__init__.py +10 -0
- langfun/core/mcp/client.py +177 -0
- langfun/core/mcp/client_test.py +71 -0
- langfun/core/mcp/session.py +241 -0
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/testing/simple_mcp_client.py +33 -0
- langfun/core/mcp/testing/simple_mcp_server.py +33 -0
- langfun/core/mcp/tool.py +256 -0
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +19 -1
- langfun/core/modalities/mime.py +62 -3
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +74 -28
- langfun/core/structured/parsing.py +90 -74
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +242 -156
- langfun/core/structured/querying_test.py +95 -64
- langfun/core/structured/schema.py +70 -10
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +45 -34
- langfun/core/structured/tokenization.py +24 -9
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +175 -50
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +43 -0
- langfun/env/base_environment.py +827 -0
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +304 -0
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +842 -0
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +14 -0
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +472 -0
- langfun/env/event_handlers/event_logger_test.py +304 -0
- langfun/env/event_handlers/metric_writer.py +726 -0
- langfun/env/event_handlers/metric_writer_test.py +214 -0
- langfun/env/interface.py +1640 -0
- langfun/env/interface_test.py +151 -0
- langfun/env/load_balancers.py +59 -0
- langfun/env/load_balancers_test.py +139 -0
- langfun/env/test_utils.py +497 -0
- {langfun-0.1.2.dev202508250805.dist-info → langfun-0.1.2.dev202511110805.dist-info}/METADATA +7 -3
- langfun-0.1.2.dev202511110805.dist-info/RECORD +200 -0
- langfun-0.1.2.dev202508250805.dist-info/RECORD +0 -172
- {langfun-0.1.2.dev202508250805.dist-info → langfun-0.1.2.dev202511110805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202508250805.dist-info → langfun-0.1.2.dev202511110805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202508250805.dist-info → langfun-0.1.2.dev202511110805.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/__init__.py
CHANGED
@@ -38,6 +38,7 @@ from langfun.core.eval.v2 import runners
 from langfun.core.eval.v2.checkpointing import BulkCheckpointer
 from langfun.core.eval.v2.checkpointing import PerExampleCheckpointer
 from langfun.core.eval.v2.reporting import HtmlReporter
+from langfun.core.eval.v2.reporting import ExampleHtmlGenerator


 # pylint: enable=g-bad-import-order
langfun/core/eval/v2/checkpointing.py
CHANGED
@@ -29,7 +29,17 @@ Runner = experiment_lib.Runner


 class Checkpointer(experiment_lib.Plugin):
-  """Base class for checkpointing evaluation examples."""
+  """Base class for checkpointing evaluation examples.
+
+  `Checkpointer` is a plugin that saves the state of processed examples
+  incrementally during an experiment run, allowing the experiment to be resumed
+  later. When an experiment starts, the checkpointer loads any previously saved
+  examples from an earlier run (or a warm-start run) into `experiment.state`,
+  so the runner can skip processing them again.
+  Subclasses should implement `_list_checkpoint_filenames` to identify
+  checkpoint files to load, and `_save_example` to save a newly processed
+  example.
+  """

   checkpoint_filename: Annotated[
       str,
@@ -170,7 +180,12 @@ class Checkpointer(experiment_lib.Plugin):


 class PerExampleCheckpointer(Checkpointer):
-  """Checkpointer that saves each example to a separate file."""
+  """Checkpointer that saves each example to a separate file.
+
+  This checkpointer saves each processed example to its own checkpoint file,
+  named using the pattern `<checkpoint_filename_prefix>_<example_id>.<ext>`.
+  For example, `checkpoint_1.bagz`, `checkpoint_2.bagz`, etc.
+  """

   def _on_bound(self):
     super()._on_bound()
@@ -235,7 +250,13 @@ class PerExampleCheckpointer(Checkpointer):


 class BulkCheckpointer(Checkpointer):
-  """Checkpointer that saves all examples to a single file."""
+  """Checkpointer that saves all examples of an evaluation to a single file.
+
+  This checkpointer appends newly processed examples of an evaluation to a
+  single sequence file (e.g., `checkpoint.bagz`). This is often more efficient
+  than `PerExampleCheckpointer` when dealing with a large number of examples
+  or when file system overhead is a concern.
+  """

   def _on_bound(self):
     super()._on_bound()
@@ -341,7 +362,12 @@ class BulkCheckpointer(Checkpointer):


 class SequenceWriter:
-  """
+  """A thread-safe writer for sequence files (e.g., Bagz).
+
+  `SequenceWriter` wraps a `pg.io.SequenceWriter` to provide thread-safe
+  `add` and `close` operations, ensuring that examples can be written
+  concurrently from multiple threads without corrupting the sequence file.
+  """

   def __init__(self, path: str):
     self._lock = threading.Lock()
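To make the trade-off between the two checkpointers concrete, here is a minimal sketch of selecting one for a run. `MyEval` and `my_lm` are placeholders (the `Evaluation` docstring in the next file defines a similar `MyEval`), and the `run(root_dir, 'new', plugins=...)` call shape is borrowed from the progress-tracking tests at the end of this diff rather than from documented API, so treat it as an assumption.

```python
from langfun.core.eval.v2 import checkpointing

experiment = MyEval(lm=my_lm)  # placeholder Evaluation instance
run_info = experiment.run(
    '/tmp/my_eval_run',
    'new',
    plugins=[
        # One checkpoint file per example: checkpoint_1.bagz, checkpoint_2.bagz, ...
        checkpointing.PerExampleCheckpointer(),
        # Or a single sequence file per evaluation, which the docstring above
        # notes is often cheaper on the file system:
        # checkpointing.BulkCheckpointer(),
    ],
)
```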
langfun/core/eval/v2/evaluation.py
CHANGED
@@ -32,17 +32,63 @@ import pyglove as pg


 class Evaluation(experiment_lib.Experiment):
-  """
-
-
-
-
-
-
-
-
-
-
+  """Base class for Langfun evaluations.
+
+  `lf.eval.Evaluation` is the base class for defining evaluation tasks in
+  Langfun. Users typically subclass it to implement custom evaluation logic by
+  overriding `inputs` and `process` methods.
+
+  An `Evaluation` object encapsulates:
+
+  * **`inputs`**: A callable that returns an iterable of input examples to be
+    processed. This is usually provided by implementing an `inputs(self)`
+    method in the subclass, which yields input items for evaluation one by
+    one.
+  * **`process(self, example)`**: An abstract method that processes one
+    example and returns the output, or a tuple of (output, metadata).
+    The output will be used for computing metrics.
+  * **`metrics`**: A list of metrics (e.g., `lf.metrics.Accuracy`) to compute
+    based on the outputs from `process`. Some metrics may require users to
+    implement a `ground_truth(self, example)` method in the subclass to
+    compute metrics against ground truth.
+  * **Hyperparameters**: Any other attributes of the class serve as
+    hyperparameters for the evaluation (e.g., the language model to use).
+
+  **Running Evaluations:**
+
+  Evaluations are executed via `lf.eval.Suite` or by calling the `.run()`
+  method on an `Evaluation` instance, which returns a `Run` object
+  containing the evaluation run information and results. If an evaluation
+  contains sweeable parameters (using `pg.oneof`), `.run()` will expand it
+  into multiple evaluation sub-tasks -- one for each combination of
+  hyperparameters -- all managed within the same `Run`.
+
+  **Example:**
+
+  ```python
+  import langfun as lf
+  import pyglove as pg
+
+  class MyEval(lf.eval.Evaluation):
+    lm: lf.LanguageModel
+    prompt: str = '1 + 1 = '
+
+    def inputs(self):
+      yield 2
+
+    def process(self, example: lf.eval.Example):
+      return int(lf.query(self.prompt, lm=self.lm))
+
+    def ground_truth(self, example: lf.eval.Example) -> int:
+      return example.input
+
+  # Run evaluation using two different LMs
+  evaluation = MyEval(
+      lm=pg.oneof([lf.llms.Gpt4(), lf.llms.Gemini()]),
+      metrics=[lf.metrics.Accuracy()]
+  )
+  run_info = evaluation.run()
+  ```
   """

   inputs: Annotated[
@@ -137,7 +183,7 @@ class Evaluation(experiment_lib.Experiment):

     Args:
       example: An example object to process. `example.input` is an object
-
+        yielded from `inputs()` method.

     Returns:
       A processed output. Or a tuple of (output, metadata).
@@ -287,7 +333,7 @@ class Evaluation(experiment_lib.Experiment):
       A unique string representing the resource required.
     """
     return {
-        v.resource_id for _, v in self.sym_init_args.
+        v.resource_id for _, v in self.sym_init_args.sym_items()
         if isinstance(v, lf.LanguageModel)
     }

@@ -760,7 +806,7 @@ class Evaluation(experiment_lib.Experiment):


 class EvaluationState:
-  """
+  """In-memory state of an evaluation."""

   class ExampleStatus(pg.Object):
     """Example state."""
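The `process` docstring above also allows returning `(output, metadata)`; a short sketch of that variant, modeled on the docstring's own `MyEval` example (the `ArithmeticEval` name and the metadata dict are illustrative, not part of the package):

```python
import langfun as lf


class ArithmeticEval(lf.eval.Evaluation):  # illustrative subclass
  lm: lf.LanguageModel
  prompt: str = '1 + 1 = '

  def inputs(self):
    yield 2

  def process(self, example: lf.eval.Example):
    answer = lf.query(self.prompt, int, lm=self.lm)
    # The second element of the tuple becomes the example's metadata.
    return answer, dict(prompt=self.prompt)

  def ground_truth(self, example: lf.eval.Example) -> int:
    return example.input
```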
langfun/core/eval/v2/example.py
CHANGED
@@ -22,19 +22,30 @@ import pyglove as pg

 @dataclasses.dataclass
 class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
-  """An
+  """An example for evaluation.
+
+  An evaluation example contains the input and output of an evaluation task,
+  as well as metadata about the evaluation process, such as execution time,
+  LLM usage, and metric results.

   Attributes:
-    id: The 1-based ID of the
-    input: An element returned from the `Evaluable.inputs` functor
-
-
-
-
-
-
-
-
+    id: The 1-based ID of the example in the evaluation set.
+    input: An element returned from the `Evaluable.inputs` functor, which serves
+      as the input for `lf.Evaluable.process`.
+    output: The output of `lf.Evaluable.process` method. If `pg.MISSING_VALUE`,
+      it indicates the example has not been processed yet.
+    error: The error encountered during `lf.Evaluable.process`. If None, it
+      indicates the process was successful.
+    metadata: The metadata of the example produced by `lf.Evaluable.process`.
+    metric_metadata: The dictionary returned from `Metric.audit`, which contains
+      metadata about metric computation for this example.
+    newly_processed: Whether this example is processed in the current run. If
+      False, it indicates the example was loaded from a checkpoint from previous
+      runs.
+    start_time: The start time of processing this example.
+    end_time: The end time of processing this example.
+    usage_summary: The summary of LLM usages for processing this example.
+    execution_status: The timeit status of processing this example.
   """
   id: int
   input: Any = pg.MISSING_VALUE
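A minimal sketch tying these fields to how the metric tests later in this diff construct examples; the `groundtruth` key is only a convention that the default `Match` metric looks up on `example.input`:

```python
import pyglove as pg
from langfun.core.eval.v2 import example as example_lib

# id/input/output are the user-facing fields; newly_processed, start_time,
# end_time, usage_summary, etc. are filled in by the runner during a run.
ex = example_lib.Example(
    id=1,
    input=pg.Dict(question='1 + 1 = ?', groundtruth=2),
    output=2,
)
```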
langfun/core/eval/v2/experiment.py
CHANGED
@@ -139,10 +139,10 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):

   # Checkpointing

-  Experiments support checkpointing, which is enabled by default. It allows
+  Experiments support checkpointing, which is enabled by default. It allows
   users to resume their experiments from a saved state. When an experiment runs,
-  it creates a new directory for that run and saves
-
+  it creates a new directory for that run and saves its progress to checkpoint
+  files. If the experiment is interrupted or fails, users can resume
   it by specifying the 'id' or 'warm_start_from' argument (shown above) to
   seamlessly continue from previously saved state without starting over.

@@ -169,7 +169,7 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):

   # Experiment Plugins

-
+  Experiments can be extended by plugins. Plugins can listen to the events of
   experiment execution and produce additional outputs. For example, a plugin
   can be added to an experiment to generate additional metrics or to save
   additional data to a database. More details will be added in the future.
@@ -657,7 +657,30 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):

 @pg.use_init_args(['children'])
 class Suite(Experiment):
-  """A suite of evaluations."""
+  """A suite of evaluations.
+
+  `lf.eval.Suite` groups multiple `lf.eval.Evaluation` or other `Suite`
+  objects into a single experiment, allowing them to be run, managed, and
+  reported together.
+
+  **Example:**
+
+  ```python
+  import langfun as lf
+
+  suite = lf.eval.Suite([
+      MyEval(lm=lf.llms.Gpt4()),
+      MyEval(lm=lf.llms.Gemini()),
+      lf.eval.Suite([
+          AnotherEval(lm=lf.llms.Gpt4()),
+          AnotherEval(lm=lf.llms.Gemini())
+      ])
+  ])
+
+  # Run all evaluations in the suite
+  run_info = suite.run('/path/to/my/suite_run')
+  ```
+  """

   children: Annotated[
       list[Experiment], 'A list of child experiments.'
@@ -791,7 +814,14 @@ class RunId(pg.Object):


 class Run(pg.Object, pg.views.html.HtmlTreeView.Extension):
-  """
+  """Represents a single run of an experiment.
+
+  A `Run` object holds all the configurations for executing an experiment,
+  such as the experiment definition, input/output directories, and flags
+  controlling the execution behavior (e.g., error handling, checkpointing).
+  It also provides utility methods for accessing run-specific paths and
+  filtering examples for evaluation.
+  """

   root_dir: Annotated[
       str,
@@ -971,7 +1001,13 @@ class Run(pg.Object, pg.views.html.HtmlTreeView.Extension):


 class Runner(pg.Object):
-  """Interface for experiment runner."""
+  """Interface for experiment runner.
+
+  A runner is responsible for executing the evaluations within an experiment
+  based on the configuration specified in a `Run` object. Different runners
+  can implement different execution strategies, such as sequential or parallel
+  processing of examples and evaluations.
+  """

   # Class-level variable for registering the runner.
   NAME = None
@@ -1010,7 +1046,14 @@ class Runner(pg.Object):


 class Plugin(lf.Component):
-  """Base class for experiment plugins."""
+  """Base class for experiment plugins.
+
+  Plugins provide a mechanism to extend the behavior of an experiment run
+  by hooking into various events during the lifecycle of experiment and
+  example execution, such as `on_run_start`, `on_experiment_complete`,
+  `on_example_start`, etc. They can be used for custom logging, monitoring,
+  or result processing.
+  """

   def on_run_start(
       self,
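The checkpointing and warm-start notes above translate roughly into the sketch below. The positional `root_dir` / run-id call shape follows the `experiment.run(root_dir, 'new', ...)` calls in this diff's tests; the `'latest'` run id and the value passed to `warm_start_from` are assumptions inferred from the docstring, not verified against the API.

```python
# `experiment` stands for any Evaluation or Suite instance (placeholder).
root_dir = '/tmp/my_experiment'

# Fresh run: a new run directory is created and progress is checkpointed.
run_info = experiment.run(root_dir, 'new')

# After an interruption, resume the most recent run from its checkpoints
# (assumed run id value):
run_info = experiment.run(root_dir, 'latest')

# Or start a new run that warm-starts from a previous run's saved examples
# (assumed argument value):
run_info = experiment.run(
    root_dir, 'new', warm_start_from='/tmp/my_experiment/run_20250101_0')
```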
langfun/core/eval/v2/metric_values.py
CHANGED
@@ -20,7 +20,15 @@ import pyglove as pg


 class MetricValue(pg.Object):
-  """Base class for metric values."""
+  """Base class for metric values.
+
+  `MetricValue` is the base class for representing aggregated metric values
+  in an evaluation. It accumulates data points from individual examples,
+  each consisting of a value and an optional weight, associated with an example
+  ID. Subclasses must implement `reduce` method to compute a single float value
+  from accumulated data points, and `scalar_repr` to provide a string
+  representation of the reduced value.
+  """

   class DataPoint(pg.Object):
     """A data point for a metric value."""
@@ -88,6 +96,14 @@ class MetricValue(pg.Object):
     self.increment_total()
     return self

+  def merge_from(self, other: 'MetricValue') -> 'MetricValue':
+    """Merges the values from another metric value."""
+    self._weighted_sum += other._weighted_sum  # pylint: disable=protected-access
+    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+      self.data_points.extend(other.data_points)
+    self.increment_total(other.total)
+    return self
+
   def __gt__(self, other: Union['MetricValue', float]) -> bool:
     if isinstance(other, self.__class__):
       return float(self) > float(other)
@@ -133,7 +149,13 @@ class MetricValue(pg.Object):


 class Rate(MetricValue):
-  """
+  """Metric value representing a rate in range [0, 1].
+
+  `Rate` is used for metrics that compute a rate, such as accuracy or error
+  rate. The final value is computed as the weighted sum of accumulated values
+  divided by the total number of examples. It's displayed as a percentage
+  (e.g., 90.0%).
+  """

   def reduce(self) -> float:
     return self._weighted_sum / self.total
@@ -145,7 +167,13 @@ class Rate(MetricValue):


 class Average(MetricValue):
-  """
+  """Metric value representing an average of accumulated values.
+
+  `Average` is used for metrics that compute an average score across examples
+  (e.g., average quality score). The final value is computed as the weighted
+  sum of accumulated values divided by the number of data points.
+  It's displayed as a float with 3 decimal places (e.g., 4.750).
+  """

   def reduce(self) -> float:
     if not self.data_points:
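The difference between the two `reduce` formulas above is easiest to see with the numbers used in the tests that follow; a plain-Python sketch of the arithmetic:

```python
# Each metric value accumulates (value, weight) data points per example.
rate_points = [(1.0, 1.0), (0.0, 1.0)]  # from the Rate merge_from test below
avg_points = [(1.0, 0.5), (0.0, 1.0)]   # from the Average merge_from test below

# Rate.reduce(): weighted sum divided by the example total (increment_total).
rate = sum(v * w for v, w in rate_points) / 2                  # 1.0 / 2 -> 0.5 (50.0%)

# Average.reduce(): weighted sum divided by the number of data points.
average = sum(v * w for v, w in avg_points) / len(avg_points)  # 0.5 / 2 -> 0.25 (0.250)

assert (rate, average) == (0.5, 0.25)
```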
langfun/core/eval/v2/metric_values_test.py
CHANGED
@@ -51,6 +51,22 @@ class RateTest(unittest.TestCase):
     self.assertEqual(rate.total, 0)
     self.assertTrue(math.isnan(float(rate)))

+  def test_merge_from(self):
+    rate1 = metric_values.Rate()
+    rate1.add(1, 1.0, 1.0, increment_total=True)
+    rate2 = metric_values.Rate()
+    rate2.add(2, 0.0, 1.0, increment_total=True)
+    rate1.merge_from(rate2)
+    self.assertEqual(rate1.total, 2)
+    self.assertEqual(float(rate1), 0.5)
+    self.assertEqual(
+        rate1.data_points,
+        [
+            metric_values.MetricValue.DataPoint(1, 1.0, 1.0),
+            metric_values.MetricValue.DataPoint(2, 0.0, 1.0),
+        ],
+    )
+

 class AverageTest(unittest.TestCase):

@@ -75,6 +91,22 @@ class AverageTest(unittest.TestCase):
     average.reset()
     self.assertEqual(average.total, 0)

+  def test_merge_from(self):
+    avg1 = metric_values.Average()
+    avg1.add(1, 1.0, 0.5, increment_total=True)
+    avg2 = metric_values.Average()
+    avg2.add(2, 0.0, 1.0, increment_total=True)
+    avg1.merge_from(avg2)
+    self.assertEqual(avg1.total, 2)
+    self.assertEqual(float(avg1), 0.25)
+    self.assertEqual(
+        avg1.data_points,
+        [
+            metric_values.MetricValue.DataPoint(1, 1.0, 0.5),
+            metric_values.MetricValue.DataPoint(2, 0.0, 1.0),
+        ],
+    )
+

 if __name__ == '__main__':
   unittest.main()
langfun/core/eval/v2/metrics.py
CHANGED
@@ -29,7 +29,15 @@ Average = metric_values.Average


 class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
-  """Interface for an evaluation metric."""
+  """Interface for an evaluation metric.
+
+  A metric is used to evaluate the quality of the outputs produced by an
+  evaluation. It works by auditing each processed example via its `audit`
+  method, which in turn calls the user-overridable `_audit` method to perform
+  metric-specific logic and update metric values. Metrics can compute multiple
+  values (e.g., precision, recall, F1 score) which are exposed via the
+  `values` method.
+  """

   name: Annotated[
       str,
@@ -71,6 +79,12 @@ class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
     for v in self.values():
       v.reset()

+  def merge_from(self, other: 'Metric') -> 'Metric':
+    """Merges the values from another metric."""
+    for v1, v2 in zip(self.values(), other.values()):
+      v1.merge_from(v2)
+    return self
+
   def _update_view(self):
     """Refreshes the metric values."""
     if self._label_group is None:
@@ -169,7 +183,15 @@ class Metric(pg.Object, pg.views.HtmlTreeView.Extension):


 class MetricBase(Metric):
-  """Base class for common metrics."""
+  """Base class for common metrics.
+
+  `MetricBase` provides common functionalities for metrics, such as automatic
+  error counting based on whether an example has an error during evaluation.
+  It distinguishes between Object-Oriented Programming (OOP) errors
+  (e.g. `MappingError` during structured output generation) and other errors.
+  Subclasses should implement `_audit_processed` for metric computation on
+  successfully processed examples.
+  """

   oop_errors: Rate | None = Rate()
   non_oop_errors: Rate | None = Rate()
@@ -229,7 +251,13 @@ class MetricBase(Metric):


 class Match(MetricBase):
-  """Metric for matching outputs against
+  """Metric for matching outputs against ground truth.
+
+  This metric computes match and mismatch rates by comparing the output of
+  an example with its ground truth. By default, it looks for a `groundtruth`
+  attribute in `example.input` for comparison. Users can customize this behavior
+  by subclassing `Match` and overriding the `match` method.
+  """

   name = 'match'
   matches: Rate = Rate()
@@ -302,7 +330,14 @@ class Match(MetricBase):


 class Score(MetricBase):
-  """Base class for scoring."""
+  """Base class for scoring metrics.
+
+  `Score` is a base class for metrics that assign a numerical score to each
+  example's output (e.g., evaluating quality on a scale of 1-5).
+  It automatically computes the average score across all examples.
+  Subclasses must implement the `score` method to define how an example
+  should be scored.
+  """

   name = 'score'
   average_score: Average = Average()
langfun/core/eval/v2/metrics_test.py
CHANGED
@@ -106,6 +106,20 @@ class MatchTest(unittest.TestCase):
     m.audit(Example(id=2, input=pg.Dict(groundtruth=1), output=2))
     self.assertEqual(len(scripts), 12)

+  def test_merge_from(self):
+    m1 = metrics.Match()
+    m1.audit(Example(id=1, input=pg.Dict(groundtruth=1), output=1))
+    m2 = metrics.Match()
+    m2.audit(Example(id=2, input=pg.Dict(groundtruth=1), output=2))
+    m1.merge_from(m2)
+    self.assertEqual(m1.matches, 0.5)
+    self.assertEqual(m1.mismatches, 0.5)
+    self.assertEqual(m1.oop_errors, 0.0)
+    self.assertEqual(m1.non_oop_errors, 0.0)
+    self.assertEqual(m1.matches.total, 2)
+    self.assertEqual(len(m1.matches.data_points), 1)
+    self.assertEqual(len(m1.mismatches.data_points), 1)
+

 class ScoreTest(unittest.TestCase):

langfun/core/eval/v2/progress.py
CHANGED
@@ -21,7 +21,15 @@ import pyglove as pg


 class Progress(pg.Object, pg.views.HtmlTreeView.Extension):
-  """
+  """Represents and tracks the progress of an evaluation.
+
+  The `Progress` class maintains counts of processed, failed, and skipped
+  items in an evaluation, along with timing information (start time, stop time,
+  duration) and an execution summary. It provides properties to check the
+  status of the evaluation (e.g., `is_started`, `is_completed`) and methods
+  to update progress as items are evaluated.
+  It also supports HTML rendering as a progress bar for visualization.
+  """

   num_total: Annotated[
       int | None,
@@ -216,6 +224,27 @@ class Progress(pg.Object, pg.views.HtmlTreeView.Extension):
     """Overrides nondefault values so volatile values are not included."""
     return dict()

+  def merge_from(self, other: 'Progress') -> None:
+    """Merges the progress from another progress."""
+    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+      if other.start_time is not None and (
+          self.start_time is None or self.start_time > other.start_time):
+        self.start_time = other.start_time
+
+      if other.stop_time is not None and (
+          self.stop_time is None or self.stop_time < other.stop_time):
+        self.stop_time = other.stop_time
+
+      if other.num_total is not None:
+        if self.num_total is None:
+          self.num_total = other.num_total
+        else:
+          assert self.num_total == other.num_total, (self, other)
+      self.num_processed += other.num_processed
+      self.num_failed += other.num_failed
+      self.num_skipped += other.num_skipped
+      self.execution_summary.aggregate(other.execution_summary.breakdown)
+
   #
   # HTML view.
   #
langfun/core/eval/v2/progress_test.py
CHANGED
@@ -77,6 +77,33 @@ class ProgressTest(unittest.TestCase):
     self.assertTrue(p.is_stopped)
     self.assertIsNotNone(p.stop_time_str)

+  def test_merge_from(self):
+    p1 = Progress()
+    p1.start(10)
+    p1.increment_processed()
+    p1.increment_failed()
+    p1.stop()
+
+    p2 = Progress()
+    p2.start(10)
+    p2.increment_skipped()
+    p2.stop()
+
+    with pg.allow_writable_accessors(True):
+      p1.start_time = 2.0
+      p1.stop_time = 4.0
+      p2.start_time = 1.0
+      p2.stop_time = 5.0
+
+    p1.merge_from(p2)
+    self.assertEqual(p1.num_total, 10)
+    self.assertEqual(p1.num_processed, 1)
+    self.assertEqual(p1.num_failed, 1)
+    self.assertEqual(p1.num_skipped, 1)
+    self.assertEqual(p1.num_completed, 3)
+    self.assertEqual(p1.start_time, 1.0)
+    self.assertEqual(p1.stop_time, 5.0)
+

 if __name__ == '__main__':
   unittest.main()
langfun/core/eval/v2/progress_tracking_test.py
CHANGED
@@ -14,9 +14,11 @@
 import contextlib
 import io
 import os
+import sys
 import tempfile
 import unittest

+from langfun.core import concurrent as lf_concurrent
 from langfun.core import console as lf_console
 from langfun.core.eval.v2 import eval_test_helper
 from langfun.core.eval.v2 import progress_tracking  # pylint: disable=unused-import
@@ -49,6 +51,8 @@ class TqdmProgressTrackerTest(unittest.TestCase):
     string_io = io.StringIO()
     with contextlib.redirect_stderr(string_io):
       _ = experiment.run(root_dir, 'new', plugins=[])
+      sys.stderr.flush()
+      lf_concurrent.ProgressBar.refresh()
     self.assertIn('All: 100%', string_io.getvalue())

   def test_with_example_ids(self):
@@ -59,6 +63,8 @@ class TqdmProgressTrackerTest(unittest.TestCase):
     string_io = io.StringIO()
     with contextlib.redirect_stderr(string_io):
       _ = experiment.run(root_dir, 'new', example_ids=[1], plugins=[])
+      sys.stderr.flush()
+      lf_concurrent.ProgressBar.refresh()
     self.assertIn('All: 100%', string_io.getvalue())

