langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145)
  1. langfun/__init__.py +22 -2
  2. langfun/core/__init__.py +17 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -28
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +69 -2
  18. langfun/core/component_test.py +54 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +17 -0
  24. langfun/core/eval/base.py +767 -140
  25. langfun/core/eval/base_test.py +238 -53
  26. langfun/core/eval/matching.py +80 -76
  27. langfun/core/eval/matching_test.py +19 -9
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +37 -28
  31. langfun/core/eval/scoring_test.py +21 -3
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +3 -21
  55. langfun/core/langfunc_test.py +26 -8
  56. langfun/core/language_model.py +686 -48
  57. langfun/core/language_model_test.py +681 -44
  58. langfun/core/llms/__init__.py +100 -12
  59. langfun/core/llms/anthropic.py +488 -0
  60. langfun/core/llms/anthropic_test.py +235 -0
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +88 -28
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +39 -26
  69. langfun/core/llms/fake_test.py +136 -11
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -197
  74. langfun/core/llms/groq.py +276 -0
  75. langfun/core/llms/groq_test.py +64 -0
  76. langfun/core/llms/llama_cpp.py +15 -40
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +436 -226
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +35 -174
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -23
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +15 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +9 -8
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +278 -0
  112. langfun/core/structured/function_generation_test.py +399 -0
  113. langfun/core/structured/mapping.py +150 -46
  114. langfun/core/structured/mapping_test.py +105 -0
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +71 -22
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
  119. langfun/core/structured/schema.py +208 -99
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_generation_test.py +2 -2
  122. langfun/core/structured/schema_test.py +133 -34
  123. langfun/core/structured/scoring.py +125 -19
  124. langfun/core/structured/scoring_test.py +30 -0
  125. langfun/core/structured/tokenization.py +64 -0
  126. langfun/core/structured/tokenization_test.py +48 -0
  127. langfun/core/template.py +240 -11
  128. langfun/core/template_test.py +146 -1
  129. langfun/core/templates/conversation.py +9 -0
  130. langfun/core/templates/conversation_test.py +4 -3
  131. langfun/core/templates/selfplay_test.py +14 -2
  132. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  133. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  134. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  135. langfun/core/coding/python/errors.py +0 -108
  136. langfun/core/coding/python/errors_test.py +0 -99
  137. langfun/core/coding/python/permissions.py +0 -90
  138. langfun/core/coding/python/permissions_test.py +0 -86
  139. langfun/core/structured/prompting.py +0 -217
  140. langfun/core/text_formatting.py +0 -162
  141. langfun/core/text_formatting_test.py +0 -47
  142. langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
  143. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  144. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  145. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,357 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Common metrics for Langfun evaluation."""
15
+
16
+
17
+ import abc
18
+ import collections
19
+ import threading
20
+ from typing import Annotated, Any
21
+
22
+ from langfun.core.eval.v2 import example as example_lib
23
+ from langfun.core.eval.v2 import metric_values
24
+ import pyglove as pg
25
+
26
+
27
+ Rate = metric_values.Rate
28
+ Average = metric_values.Average
29
+
30
+
31
class Metric(pg.Object, pg.views.HtmlTreeView.Extension):
  """Interface for an evaluation metric.

  A metric audits processed examples one at a time, accumulates its metric
  values (returned by `values()`), and can render itself as an HTML label
  group that is refreshed in place as new examples are audited.
  """

  name: Annotated[
      str,
      (
          'Name of the metric, which will be used as the key in the dict '
          'returned by `Experiment.metric_values()`'
      )
  ]

  def _on_bound(self):
    super()._on_bound()
    # Created lazily on first interactive HTML rendering; `_update_view`
    # refreshes it in place afterwards.
    self._label_group = None
    # Guards concurrent metric-value updates from `audit`.
    self._lock = threading.Lock()

  def audit(self, example: example_lib.Example) -> dict[str, Any]:
    """Audits a processed example and returns metric metadata for it."""
    # NOTE(daiyip): the metric values are being updated concurrently, so we
    # uses a lock to avoid race condition. We might consider relaxing the lock
    # later if metric auditing becomes a bottleneck.
    with self._lock:
      for v in self.values():
        v.increment_total()

      metadata = self._audit(example)

    self._update_view()
    return metadata

  @abc.abstractmethod
  def _audit(self, example: example_lib.Example) -> dict[str, Any]:
    """Subclasses should override this method to implement the metric logic."""

  @abc.abstractmethod
  def values(self) -> list[metric_values.MetricValue]:
    """Returns all the values computed by this metric."""

  def reset(self) -> None:
    """Resets the metric values."""
    for v in self.values():
      v.reset()

  def _update_view(self):
    """Refreshes the rendered labels with the current metric values."""
    if self._label_group is None:
      return

    # The labels were created from `values()` in the same order (see
    # `_html_tree_view`), so zipping pairs each value with its label.
    for label, value in zip(self._label_group.labels, self.values()):
      label.update(
          text=self._metric_value_text(value),
          tooltip=self._metric_value_tooltip(value),
      )

  def _metric_value_text(self, metric_value: metric_values.MetricValue) -> str:
    """Returns the label text for the metric value."""
    return str(metric_value)

  def _metric_value_tooltip(
      self, metric_value: metric_values.MetricValue) -> str:
    """Returns the tooltip text for the metric value."""
    with pg.str_format(verbose=True):
      return f'{metric_value.sym_path.key}: {metric_value}'

  def _metric_label_text(self) -> str:
    # Abbreviates the class name using only its uppercase letters and
    # digits, e.g. `MyMatch` -> 'MM'.
    return ''.join(
        c for c in self.__class__.__name__
        if c.isalnum() and not c.islower()
    )

  def _metric_label_tooltip(self) -> str:
    return self.__class__.__type_name__

  def _html_tree_view(
      self,
      *,
      view: pg.views.HtmlTreeView,
      extra_flags: dict[str, Any] | None = None,
      **kwargs,
  ) -> pg.Html:
    """Renders the content of the metric value."""
    extra_flags = extra_flags or {}
    interactive = extra_flags.get('interactive', True)
    label_group = self._label_group
    if label_group is None:
      label_group = pg.views.html.controls.LabelGroup(
          [
              pg.views.html.controls.Label(
                  self._metric_value_text(mv),
                  tooltip=self._metric_value_tooltip(mv),
                  css_classes=[mv.sym_path.key, 'metric-value'],
                  interactive=interactive,
              ) for mv in self.values()
          ],
          name=pg.views.html.controls.Label(
              self._metric_label_text(),
              tooltip=self._metric_label_tooltip(),
              css_classes=[
                  'metric-name',
                  pg.object_utils.camel_to_snake(self.__class__.__name__, '-')
              ],
              interactive=False,
          ),
          css_classes=['metric-container'],
      )
    if interactive:
      # Cache the label group so `_update_view` can refresh it in place.
      self._label_group = label_group
    return label_group.to_html()

  @classmethod
  def _html_tree_view_css_styles(cls) -> list[str]:
    return super()._html_tree_view_css_styles() + [
        """
        .metric-container {
          display: inline-flex;
          overflow: hidden;
          border-radius: 5px;
          border: 0px;
          margin: 5px;
          padding: 0px;
        }
        .metric-container .label-container {
          vertical-align: middle;
        }
        .metric-value.oop_errors {
          color: magenta;
          background-color: #f9e6eb;
        }
        .metric-value.non_oop_errors {
          color: red;
          background-color: #fdcccc;
        }
        """
    ]
165
+
166
+ #
167
+ # Common metrics.
168
+ #
169
+
170
+
171
class MetricBase(Metric):
  """Base class for common metrics.

  Tracks OOP errors (tags starting with 'MappingError') and non-OOP errors
  as separate rates, and keeps a per-tag breakdown of failed example IDs.
  Subclasses implement `_audit_processed` for successfully processed
  examples.
  """

  oop_errors: Rate | None = Rate()
  non_oop_errors: Rate | None = Rate()

  def _on_bound(self) -> None:
    super()._on_bound()
    # Maps an error tag to the IDs of the examples that failed with it.
    self._error_breakdown = collections.defaultdict(list)

  def reset(self) -> None:
    """Resets the metric."""
    super().reset()
    self._error_breakdown = collections.defaultdict(list)

  def _audit(self, example: example_lib.Example) -> dict[str, Any]:
    """Routes the example to processed- or error-auditing."""
    if example.error is None:
      return self._audit_processed(example)
    else:
      return self._audit_error(example)

  def _audit_error(self, example: example_lib.Example) -> dict[str, Any]:
    """Audits an example that failed with an error."""
    assert example.error is not None
    tag = example.error.tag
    # 'MappingError' tags are treated as OOP errors; everything else is
    # a non-OOP error.
    if tag.startswith('MappingError'):
      self.oop_errors.add(example.id, 1)
    else:
      self.non_oop_errors.add(example.id, 1)
    self._error_breakdown[tag].append(example.id)
    return dict(error=tag)

  @abc.abstractmethod
  def _audit_processed(self, example: example_lib.Example) -> dict[str, Any]:
    """Audits an example that was processed without error."""

  def _oop_errors_breakdown(self) -> str | None:
    """Returns the OOP error breakdown as a string."""
    return '\n'.join(
        [
            f'- {k}: {len(v)}' for k, v in self._error_breakdown.items()
            if k.startswith('MappingError')
        ]
    ) or None

  def _non_oop_errors_breakdown(self) -> str | None:
    """Returns the non-OOP error breakdown as a string."""
    return '\n'.join(
        [
            f'- {k}: {len(v)}' for k, v in self._error_breakdown.items()
            if not k.startswith('MappingError')
        ]
    ) or None

  def _sym_nondefault(self) -> dict[str, Any]:
    """Overrides nondefault values so volatile values are not included."""
    return dict()
229
+
230
+
231
class Match(MetricBase):
  """Metric for matching outputs against groundtruth.

  By default the example input is expected to carry a `groundtruth`
  attribute; subclasses may override `match` to support other formats.
  """

  name = 'match'
  matches: Rate = Rate()
  mismatches: Rate = Rate()

  def match(
      self, example_input: Any, output: Any
  ) -> bool | tuple[bool, dict[str, Any]]:
    """Returns whether the output matches the groundtruth from the example.

    Args:
      example_input: The example input which contains the groundtruth.
      output: The output to match against.

    Returns:
      True if the output matches the groundtruth, False otherwise.
      Or a tuple of (match, metadata).

    Raises:
      ValueError: If the example input has no `groundtruth` attribute.
    """
    groundtruth = getattr(example_input, 'groundtruth', pg.MISSING_VALUE)
    if pg.MISSING_VALUE == groundtruth:
      # Fixed grammar of the original message ("Please subclassing ...").
      raise ValueError(
          f'`groundtruth` is not present in the example ({example_input}). '
          'Please subclass `Match` and override the `match` method to '
          'support custom example formats.'
      )
    return pg.eq(output, groundtruth)

  def _audit_processed(self, example: example_lib.Example) -> dict[str, Any]:
    """Audits the evaluation example after processing."""
    metadata = {}
    is_match = self.match(example.input, example.output)
    # `match` may return a bare bool or a (bool, metadata) tuple.
    if isinstance(is_match, tuple):
      is_match, metadata = is_match
    if is_match:
      self.matches.add(example.id, 1)
      metadata['match'] = True
    else:
      self.mismatches.add(example.id, 1)
      metadata['mismatch'] = True
    return metadata

  def values(self) -> list[metric_values.MetricValue]:
    """Returns all the values computed by this metric."""
    return [
        self.matches,
        self.mismatches,
        self.oop_errors,
        self.non_oop_errors
    ]

  @classmethod
  def _html_tree_view_css_styles(cls) -> list[str]:
    return super()._html_tree_view_css_styles() + [
        """
        .metric-name.match {
          padding: 5px;
          color: white;
          background-color: purple;
        }
        .metric-value.matches {
          color: green;
          background-color: #dcefbe;
        }
        .metric-value.mismatches {
          color: orange;
          background-color: #ffefc4;
        }
        """
    ]
302
+
303
+
304
class Score(MetricBase):
  """Base class for scoring metrics.

  Subclasses implement `score` to produce a numeric score per processed
  example; this metric tracks the running average of those scores.
  """

  name = 'score'
  average_score: Average = Average()

  @abc.abstractmethod
  def score(
      self,
      example_input: Any,
      output: Any) -> float | tuple[float, dict[str, Any]]:
    """Returns the score based on the example and output.

    Args:
      example_input: The example input based on which the output is generated.
      output: The output to score.

    Returns:
      A float score. Or a tuple of (score, metadata).
    """

  def _audit_processed(self, example: example_lib.Example) -> dict[str, Any]:
    """Audits the evaluation example after processing."""
    result = self.score(example.input, example.output)
    # `score` may return either a bare float or a (score, metadata) tuple.
    if isinstance(result, tuple):
      example_score, metadata = result
    else:
      example_score, metadata = result, {}
    self.average_score.add(example.id, example_score)
    metadata['score'] = example_score
    return metadata

  def values(self) -> list[metric_values.MetricValue]:
    """Returns all the values computed by this metric."""
    return [self.average_score, self.oop_errors, self.non_oop_errors]

  @classmethod
  def _html_tree_view_css_styles(cls) -> list[str]:
    return super()._html_tree_view_css_styles() + [
        """
        .metric-name.score {
          padding: 5px;
          color: white;
          background-color: blue;
        }
        .metric-value.average_score {
          color: blue;
          background-color: #b0c7f6;
        }
        """
    ]
@@ -0,0 +1,203 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import unittest
15
+
16
+ from langfun.core.eval.v2 import example as example_lib
17
+ from langfun.core.eval.v2 import metrics
18
+ import pyglove as pg
19
+
20
+ Example = example_lib.Example
21
+
22
+
23
class MatchTest(unittest.TestCase):
  """Tests for `metrics.Match`."""

  def test_basic(self):
    metric = metrics.Match()
    # A correct output is audited as a match.
    self.assertEqual(
        metric.audit(Example(id=1, input=pg.Dict(groundtruth=1), output=1)),
        dict(match=True)
    )
    # A wrong output is audited as a mismatch.
    self.assertEqual(
        metric.audit(Example(id=2, input=pg.Dict(groundtruth=1), output=2)),
        dict(mismatch=True)
    )
    non_oop_example = Example(
        id=3,
        input=pg.Dict(groundtruth=1),
        error=pg.object_utils.ErrorInfo(
            tag='ValueError',
            description='Bad input.',
            stacktrace='...',
        )
    )
    self.assertEqual(metric.audit(non_oop_example), dict(error='ValueError'))
    oop_example = Example(
        id=3,
        input=pg.Dict(groundtruth=1),
        error=pg.object_utils.ErrorInfo(
            tag='MappingError.CodeError',
            description='Bad input.',
            stacktrace='...',
        )
    )
    self.assertEqual(
        metric.audit(oop_example), dict(error='MappingError.CodeError')
    )
    # Four examples were audited, one landing in each bucket.
    for rate in (metric.matches, metric.mismatches,
                 metric.oop_errors, metric.non_oop_errors):
      self.assertEqual(rate, 0.25)

    self.assertEqual(
        metric.values(),
        [metric.matches, metric.mismatches,
         metric.oop_errors, metric.non_oop_errors]
    )
    metric.reset()
    for rate in metric.values():
      self.assertEqual(len(rate.data_points), 0)

  def test_bad_case(self):
    metric = metrics.Match()
    with self.assertRaisesRegex(ValueError, '`groundtruth` is not present'):
      metric.audit(Example(id=1, input=pg.Dict(x=1), output=1))

  def test_custom_metadata(self):

    class MyMatch(metrics.Match):

      def match(self, example_input, output):
        return example_input.x == output, dict(x=example_input.x)

    metric = MyMatch()
    self.assertEqual(
        metric.audit(Example(id=1, input=pg.Dict(x=1), output=1)),
        dict(match=True, x=1)
    )
    self.assertEqual(metric.matches, 1.0)

  def test_html_view(self):
    metric = metrics.Match()
    metric.audit(Example(id=1, input=pg.Dict(groundtruth=1), output=1))
    self.assertIn('100.0%', metric.to_html().content)
    with pg.views.html.controls.HtmlControl.track_scripts() as scripts:
      metric.audit(Example(id=2, input=pg.Dict(groundtruth=1), output=2))
    self.assertEqual(len(scripts), 12)
108
+
109
+
110
class ScoreTest(unittest.TestCase):
  """Tests for `metrics.Score`."""

  def test_basic(self):

    class MyScore(metrics.Score):

      def score(self, example_input, output) -> float:
        return example_input.x * output

    metric = MyScore()
    self.assertEqual(
        metric.audit(Example(id=1, input=pg.Dict(x=1), output=1)),
        dict(score=1 * 1)
    )
    self.assertEqual(
        metric.audit(Example(id=2, input=pg.Dict(x=2), output=2)),
        dict(score=2 * 2)
    )
    non_oop_example = Example(
        id=3,
        input=pg.Dict(x=1),
        error=pg.object_utils.ErrorInfo(
            tag='ValueError',
            description='Bad input.',
            stacktrace='...',
        )
    )
    self.assertEqual(metric.audit(non_oop_example), dict(error='ValueError'))
    oop_example = Example(
        id=3,
        input=pg.Dict(x=1),
        error=pg.object_utils.ErrorInfo(
            tag='MappingError.CodeError',
            description='Bad input.',
            stacktrace='...',
        )
    )
    self.assertEqual(
        metric.audit(oop_example), dict(error='MappingError.CodeError')
    )
    # Two processed examples scored 1 and 4, averaging 2.5.
    self.assertEqual(metric.average_score, 2.5)
    self.assertEqual(metric.oop_errors, 0.25)
    self.assertEqual(metric.non_oop_errors, 0.25)

    self.assertEqual(
        metric.values(),
        [metric.average_score, metric.oop_errors, metric.non_oop_errors]
    )
    metric.reset()
    for value in metric.values():
      self.assertEqual(len(value.data_points), 0)

  def test_custom_metadata(self):

    class MyScore(metrics.Score):

      def score(self, example_input, output):
        return example_input.x * output, dict(x=example_input.x)

    metric = MyScore()
    self.assertEqual(
        metric.audit(Example(id=1, input=pg.Dict(x=1), output=1)),
        dict(score=1 * 1, x=1)
    )
    self.assertEqual(metric.average_score, 1.0)

  def test_html_view(self):

    class MyScore(metrics.Score):

      def score(self, example_input, output) -> float:
        return example_input.x * output

    metric = MyScore()
    metric.audit(Example(id=1, input=pg.Dict(x=1), output=2))
    self.assertIn('2.000', metric.to_html().content)
    with pg.views.html.controls.HtmlControl.track_scripts() as scripts:
      metric.audit(Example(id=2, input=pg.Dict(x=1), output=2))
    self.assertEqual(len(scripts), 9)
200
+
201
+
202
# Allows running this test module directly via `python metrics_test.py`.
if __name__ == '__main__':
  unittest.main()