langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +22 -2
- langfun/core/__init__.py +17 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -28
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +69 -2
- langfun/core/component_test.py +54 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +17 -0
- langfun/core/eval/base.py +767 -140
- langfun/core/eval/base_test.py +238 -53
- langfun/core/eval/matching.py +80 -76
- langfun/core/eval/matching_test.py +19 -9
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +37 -28
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +3 -21
- langfun/core/langfunc_test.py +26 -8
- langfun/core/language_model.py +686 -48
- langfun/core/language_model_test.py +681 -44
- langfun/core/llms/__init__.py +100 -12
- langfun/core/llms/anthropic.py +488 -0
- langfun/core/llms/anthropic_test.py +235 -0
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +88 -28
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +39 -26
- langfun/core/llms/fake_test.py +136 -11
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -197
- langfun/core/llms/groq.py +276 -0
- langfun/core/llms/groq_test.py +64 -0
- langfun/core/llms/llama_cpp.py +15 -40
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +436 -226
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +35 -174
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -23
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +15 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +9 -8
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +278 -0
- langfun/core/structured/function_generation_test.py +399 -0
- langfun/core/structured/mapping.py +150 -46
- langfun/core/structured/mapping_test.py +105 -0
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +71 -22
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
- langfun/core/structured/schema.py +208 -99
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +133 -34
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +240 -11
- langfun/core/template_test.py +146 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +14 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -217
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ langfun/core/eval/v2/metrics.py
@@ -0,0 +1,357 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Common metrics for Langfun evaluation."""
+
+
+import abc
+import collections
+import threading
+from typing import Annotated, Any
+
+from langfun.core.eval.v2 import example as example_lib
+from langfun.core.eval.v2 import metric_values
+import pyglove as pg
+
+
+Rate = metric_values.Rate
+Average = metric_values.Average
+
+
+class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
+  """Interface for an evaluation metric."""
+
+  name: Annotated[
+      str,
+      (
+          'Name of the metric, which will be used as the key in the dict '
+          'returned by `Experiment.metric_values()`'
+      )
+  ]
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._label_group = None
+    self._lock = threading.Lock()
+
+  def audit(self, example: example_lib.Example) -> dict[str, Any]:
+    """Audits a processed example and returns metric metadata for it."""
+    # NOTE(daiyip): the metric values are updated concurrently, so we use a
+    # lock to avoid race conditions. We might consider relaxing the lock
+    # later if metric auditing becomes a bottleneck.
+    with self._lock:
+      for v in self.values():
+        v.increment_total()
+
+      metadata = self._audit(example)
+
+    self._update_view()
+    return metadata
+
+  @abc.abstractmethod
+  def _audit(self, example: example_lib.Example) -> dict[str, Any]:
+    """Subclasses should override this method to implement the metric logic."""
+
+  @abc.abstractmethod
+  def values(self) -> list[metric_values.MetricValue]:
+    """Returns all the values computed by this metric."""
+
+  def reset(self) -> None:
+    """Resets the metric values."""
+    for v in self.values():
+      v.reset()
+
+  def _update_view(self):
+    """Refreshes the metric values."""
+    if self._label_group is None:
+      return
+
+    for label, value in zip(self._label_group.labels, self.values()):
+      label.update(
+          text=self._metric_value_text(value),
+          tooltip=self._metric_value_tooltip(value),
+      )
+
+  def _metric_value_text(self, metric_value: metric_values.MetricValue) -> str:
+    """Returns the label text for the metric value."""
+    return str(metric_value)
+
+  def _metric_value_tooltip(
+      self, metric_value: metric_values.MetricValue) -> str:
+    """Returns the tooltip text for the metric value."""
+    with pg.str_format(verbose=True):
+      return f'{metric_value.sym_path.key}: {metric_value}'
+
+  def _metric_label_text(self) -> str:
+    return ''.join(
+        c for c in self.__class__.__name__
+        if c.isalnum() and not c.islower()
+    )
+
+  def _metric_label_tooltip(self) -> str:
+    return self.__class__.__type_name__
+
+  def _html_tree_view(
+      self,
+      *,
+      view: pg.views.HtmlTreeView,
+      extra_flags: dict[str, Any] | None = None,
+      **kwargs,
+  ) -> pg.Html:
+    """Renders the content of the metric value."""
+    extra_flags = extra_flags or {}
+    interactive = extra_flags.get('interactive', True)
+    label_group = self._label_group
+    if label_group is None:
+      label_group = pg.views.html.controls.LabelGroup(
+          [
+              pg.views.html.controls.Label(
+                  self._metric_value_text(mv),
+                  tooltip=self._metric_value_tooltip(mv),
+                  css_classes=[mv.sym_path.key, 'metric-value'],
+                  interactive=interactive,
+              ) for mv in self.values()
+          ],
+          name=pg.views.html.controls.Label(
+              self._metric_label_text(),
+              tooltip=self._metric_label_tooltip(),
+              css_classes=[
+                  'metric-name',
+                  pg.object_utils.camel_to_snake(self.__class__.__name__, '-')
+              ],
+              interactive=False,
+          ),
+          css_classes=['metric-container'],
+      )
+    if interactive:
+      self._label_group = label_group
+    return label_group.to_html()
+
+  @classmethod
+  def _html_tree_view_css_styles(cls) -> list[str]:
+    return super()._html_tree_view_css_styles() + [
+        """
+        .metric-container {
+          display: inline-flex;
+          overflow: hidden;
+          border-radius: 5px;
+          border: 0px;
+          margin: 5px;
+          padding: 0px;
+        }
+        .metric-container .label-container {
+          vertical-align: middle;
+        }
+        .metric-value.oop_errors {
+          color: magenta;
+          background-color: #f9e6eb;
+        }
+        .metric-value.non_oop_errors {
+          color: red;
+          background-color: #fdcccc;
+        }
+        """
+    ]
+
+#
+# Common metrics.
+#
+
+
+class MetricBase(Metric):
+  """Base class for common metrics."""
+
+  oop_errors: Rate | None = Rate()
+  non_oop_errors: Rate | None = Rate()
+
+  def _on_bound(self) -> None:
+    super()._on_bound()
+    self._error_breakdown = collections.defaultdict(list)
+
+  def reset(self) -> None:
+    """Resets the metric."""
+    super().reset()
+    self._error_breakdown = collections.defaultdict(list)
+
+  def _audit(self, example: example_lib.Example) -> dict[str, Any]:
+    """Audits the evaluation example after processing."""
+    if example.error is None:
+      return self._audit_processed(example)
+    else:
+      return self._audit_error(example)
+
+  def _audit_error(self, example: example_lib.Example) -> dict[str, Any]:
+    """Audits an evaluation example that failed with an error."""
+    assert example.error is not None
+    tag = example.error.tag
+    if tag.startswith('MappingError'):
+      self.oop_errors.add(example.id, 1)
+    else:
+      self.non_oop_errors.add(example.id, 1)
+    self._error_breakdown[tag].append(example.id)
+    return dict(error=tag)
+
+  @abc.abstractmethod
+  def _audit_processed(self, example: example_lib.Example) -> dict[str, Any]:
+    """Audits the evaluation example after processing."""
+
+  def _oop_errors_breakdown(self) -> str | None:
+    """Returns the OOP error breakdown as a string."""
+    return '\n'.join(
+        [
+            f'- {k}: {len(v)}' for k, v in self._error_breakdown.items()
+            if k.startswith('MappingError')
+        ]
+    ) or None
+
+  def _non_oop_errors_breakdown(self) -> str | None:
+    """Returns the non-OOP error breakdown as a string."""
+    return '\n'.join(
+        [
+            f'- {k}: {len(v)}' for k, v in self._error_breakdown.items()
+            if not k.startswith('MappingError')
+        ]
+    ) or None
+
+  def _sym_nondefault(self) -> dict[str, Any]:
+    """Overrides nondefault values so volatile values are not included."""
+    return dict()
+
+
+class Match(MetricBase):
+  """Metric for matching outputs against groundtruth."""
+
+  name = 'match'
+  matches: Rate = Rate()
+  mismatches: Rate = Rate()
+
+  def match(
+      self, example_input: Any, output: Any
+  ) -> bool | tuple[bool, dict[str, Any]]:
+    """Returns whether the output matches the groundtruth from the example.
+
+    Args:
+      example_input: The example input which contains the groundtruth.
+      output: The output to match against.
+
+    Returns:
+      True if the output matches the groundtruth, False otherwise.
+      Or a tuple of (match, metadata).
+    """
+    groundtruth = getattr(example_input, 'groundtruth', pg.MISSING_VALUE)
+    if pg.MISSING_VALUE == groundtruth:
+      raise ValueError(
+          f'`groundtruth` is not present in the example ({example_input}). '
+          'Please subclass `Match` and override the `match` method to '
+          'support custom example formats.'
+      )
+    return pg.eq(output, groundtruth)
+
+  def _audit_processed(self, example: example_lib.Example) -> dict[str, Any]:
+    """Audits the evaluation example after processing."""
+    metadata = {}
+    is_match = self.match(example.input, example.output)
+    if isinstance(is_match, tuple):
+      is_match, metadata = is_match
+    if is_match:
+      self.matches.add(example.id, 1)
+      metadata['match'] = True
+    else:
+      self.mismatches.add(example.id, 1)
+      metadata['mismatch'] = True
+    return metadata
+
+  def values(self) -> list[metric_values.MetricValue]:
+    """Returns all the values computed by this metric."""
+    return [
+        self.matches,
+        self.mismatches,
+        self.oop_errors,
+        self.non_oop_errors
+    ]
+
+  @classmethod
+  def _html_tree_view_css_styles(cls) -> list[str]:
+    return super()._html_tree_view_css_styles() + [
+        """
+        .metric-name.match {
+          padding: 5px;
+          color: white;
+          background-color: purple;
+        }
+        .metric-value.matches {
+          color: green;
+          background-color: #dcefbe;
+        }
+        .metric-value.mismatches {
+          color: orange;
+          background-color: #ffefc4;
+        }
+        """
+    ]
+
+
+class Score(MetricBase):
+  """Base class for scoring."""
+
+  name = 'score'
+  average_score: Average = Average()
+
+  @abc.abstractmethod
+  def score(
+      self,
+      example_input: Any,
+      output: Any) -> float | tuple[float, dict[str, Any]]:
+    """Returns the score based on the example and output.
+
+    Args:
+      example_input: The example input based on which the output is generated.
+      output: The output to score.
+
+    Returns:
+      A float score. Or a tuple of (score, metadata).
+    """
+
+  def _audit_processed(self, example: example_lib.Example) -> dict[str, Any]:
+    """Audits the evaluation example after processing."""
+    metadata = {}
+    score = self.score(example.input, example.output)
+    if isinstance(score, tuple):
+      score, metadata = score
+    self.average_score.add(example.id, score)
+    metadata['score'] = score
+    return metadata
+
+  def values(self) -> list[metric_values.MetricValue]:
+    """Returns all the values computed by this metric."""
+    return [
+        self.average_score,
+        self.oop_errors,
+        self.non_oop_errors
+    ]
+
+  @classmethod
+  def _html_tree_view_css_styles(cls) -> list[str]:
+    return super()._html_tree_view_css_styles() + [
+        """
+        .metric-name.score {
+          padding: 5px;
+          color: white;
+          background-color: blue;
+        }
+        .metric-value.average_score {
+          color: blue;
+          background-color: #b0c7f6;
+        }
+        """
+    ]
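For orientation, here is a minimal usage sketch of the `Match` metric added above (not part of the diff; the IDs and values are illustrative, mirroring the tests in the next hunk):

import pyglove as pg
from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import metrics

# `audit` increments the total of every metric value, then routes the
# example into `matches` or `mismatches` based on `match()`.
m = metrics.Match()
m.audit(example_lib.Example(id=1, input=pg.Dict(groundtruth=42), output=42))
m.audit(example_lib.Example(id=2, input=pg.Dict(groundtruth=42), output=7))
assert m.matches == 0.5      # 1 of 2 audited examples matched.
assert m.mismatches == 0.5   # 1 of 2 audited examples mismatched.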
--- /dev/null
+++ langfun/core/eval/v2/metrics_test.py
@@ -0,0 +1,203 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+from langfun.core.eval.v2 import example as example_lib
+from langfun.core.eval.v2 import metrics
+import pyglove as pg
+
+Example = example_lib.Example
+
+
+class MatchTest(unittest.TestCase):
+
+  def test_basic(self):
+    m = metrics.Match()  # pylint: disable=invalid-name
+    self.assertEqual(
+        m.audit(Example(id=1, input=pg.Dict(groundtruth=1), output=1)),
+        dict(match=True)
+    )
+    self.assertEqual(
+        m.audit(Example(id=2, input=pg.Dict(groundtruth=1), output=2)),
+        dict(mismatch=True)
+    )
+    self.assertEqual(
+        m.audit(
+            Example(
+                id=3,
+                input=pg.Dict(groundtruth=1),
+                error=pg.object_utils.ErrorInfo(
+                    tag='ValueError',
+                    description='Bad input.',
+                    stacktrace='...',
+                )
+            )
+        ),
+        dict(error='ValueError')
+    )
+    self.assertEqual(
+        m.audit(
+            Example(
+                id=3,
+                input=pg.Dict(groundtruth=1),
+                error=pg.object_utils.ErrorInfo(
+                    tag='MappingError.CodeError',
+                    description='Bad input.',
+                    stacktrace='...',
+                )
+            )
+        ),
+        dict(error='MappingError.CodeError')
+    )
+    self.assertEqual(m.matches, 0.25)
+    self.assertEqual(m.mismatches, 0.25)
+    self.assertEqual(m.oop_errors, 0.25)
+    self.assertEqual(m.non_oop_errors, 0.25)
+
+    self.assertEqual(m.values(), [
+        m.matches,
+        m.mismatches,
+        m.oop_errors,
+        m.non_oop_errors
+    ])
+    m.reset()
+    self.assertEqual(len(m.matches.data_points), 0)
+    self.assertEqual(len(m.mismatches.data_points), 0)
+    self.assertEqual(len(m.oop_errors.data_points), 0)
+    self.assertEqual(len(m.non_oop_errors.data_points), 0)
+
+  def test_bad_case(self):
+    m = metrics.Match()  # pylint: disable=invalid-name
+    with self.assertRaisesRegex(ValueError, '`groundtruth` is not present'):
+      m.audit(Example(id=1, input=pg.Dict(x=1), output=1))
+
+  def test_custom_metadata(self):
+
+    class MyMatch(metrics.Match):
+      def match(self, example_input, output):
+        return example_input.x == output, dict(x=example_input.x)
+
+    m = MyMatch()  # pylint: disable=invalid-name
+    self.assertEqual(
+        m.audit(Example(id=1, input=pg.Dict(x=1), output=1)),
+        dict(match=True, x=1)
+    )
+    self.assertEqual(m.matches, 1.0)
+
+  def test_html_view(self):
+    m = metrics.Match()  # pylint: disable=invalid-name
+    m.audit(Example(id=1, input=pg.Dict(groundtruth=1), output=1))
+    self.assertIn(
+        '100.0%',
+        m.to_html().content,
+    )
+    with pg.views.html.controls.HtmlControl.track_scripts() as scripts:
+      m.audit(Example(id=2, input=pg.Dict(groundtruth=1), output=2))
+      self.assertEqual(len(scripts), 12)
+
+
+class ScoreTest(unittest.TestCase):
+
+  def test_basic(self):
+
+    class MyScore(metrics.Score):
+
+      def score(self, example_input, output) -> float:
+        return example_input.x * output
+
+    m = MyScore()  # pylint: disable=invalid-name
+    self.assertEqual(
+        m.audit(Example(id=1, input=pg.Dict(x=1), output=1)),
+        dict(score=1 * 1)
+    )
+    self.assertEqual(
+        m.audit(Example(id=2, input=pg.Dict(x=2), output=2)),
+        dict(score=2 * 2)
+    )
+    self.assertEqual(
+        m.audit(
+            Example(
+                id=3,
+                input=pg.Dict(x=1),
+                error=pg.object_utils.ErrorInfo(
+                    tag='ValueError',
+                    description='Bad input.',
+                    stacktrace='...',
+                )
+            )
+        ),
+        dict(error='ValueError')
+    )
+    self.assertEqual(
+        m.audit(
+            Example(
+                id=3,
+                input=pg.Dict(x=1),
+                error=pg.object_utils.ErrorInfo(
+                    tag='MappingError.CodeError',
+                    description='Bad input.',
+                    stacktrace='...',
+                )
+            )
+        ),
+        dict(error='MappingError.CodeError')
+    )
+    self.assertEqual(m.average_score, 2.5)
+    self.assertEqual(m.oop_errors, 0.25)
+    self.assertEqual(m.non_oop_errors, 0.25)
+
+    self.assertEqual(m.values(), [
+        m.average_score,
+        m.oop_errors,
+        m.non_oop_errors
+    ])
+    m.reset()
+    self.assertEqual(len(m.average_score.data_points), 0)
+    self.assertEqual(len(m.oop_errors.data_points), 0)
+    self.assertEqual(len(m.non_oop_errors.data_points), 0)
+
+  def test_custom_metadata(self):
+
+    class MyScore(metrics.Score):
+
+      def score(self, example_input, output):
+        return example_input.x * output, dict(x=example_input.x)
+
+    m = MyScore()  # pylint: disable=invalid-name
+    self.assertEqual(
+        m.audit(Example(id=1, input=pg.Dict(x=1), output=1)),
+        dict(score=1 * 1, x=1)
+    )
+    self.assertEqual(m.average_score, 1.0)
+
+  def test_html_view(self):
+
+    class MyScore(metrics.Score):
+
+      def score(self, example_input, output) -> float:
+        return example_input.x * output
+
+    m = MyScore()  # pylint: disable=invalid-name
+    m.audit(Example(id=1, input=pg.Dict(x=1), output=2))
+    self.assertIn(
+        '2.000',
+        m.to_html().content,
+    )
+    with pg.views.html.controls.HtmlControl.track_scripts() as scripts:
+      m.audit(Example(id=2, input=pg.Dict(x=1), output=2))
+      self.assertEqual(len(scripts), 9)
+
+
+if __name__ == '__main__':
+  unittest.main()
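As the tests above show, a concrete `Score` subclass only needs to implement `score()`, and returning a `(score, metadata)` tuple surfaces extra per-example metadata. A minimal sketch along the same lines (not part of the diff; `OutputLengthScore` and its length-based scoring rule are purely illustrative):

import pyglove as pg
from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import metrics


class OutputLengthScore(metrics.Score):
  """Scores an output by its string length (illustrative only)."""

  def score(self, example_input, output):
    # A bare float is also accepted; `_audit_processed` unpacks tuples
    # into (score, metadata).
    text = str(output)
    return float(len(text)), dict(length=len(text))


m = OutputLengthScore()
metadata = m.audit(example_lib.Example(id=1, input=pg.Dict(x=1), output='hello'))
assert metadata == dict(length=5, score=5.0)
assert m.average_score == 5.0  # Average over one audited example.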