langfun 0.0.2.dev20240429__py3-none-any.whl → 0.0.2.dev20240511__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of langfun might be problematic.

Files changed (37)
  1. langfun/__init__.py +5 -0
  2. langfun/core/eval/__init__.py +14 -1
  3. langfun/core/eval/base.py +503 -112
  4. langfun/core/eval/base_test.py +185 -53
  5. langfun/core/eval/matching.py +22 -21
  6. langfun/core/eval/matching_test.py +23 -2
  7. langfun/core/eval/patching.py +130 -0
  8. langfun/core/eval/patching_test.py +170 -0
  9. langfun/core/eval/scoring.py +4 -4
  10. langfun/core/eval/scoring_test.py +19 -2
  11. langfun/core/langfunc.py +1 -17
  12. langfun/core/langfunc_test.py +4 -0
  13. langfun/core/language_model.py +6 -0
  14. langfun/core/llms/__init__.py +8 -0
  15. langfun/core/llms/fake.py +6 -6
  16. langfun/core/llms/google_genai.py +8 -0
  17. langfun/core/llms/openai.py +3 -2
  18. langfun/core/llms/openai_test.py +2 -1
  19. langfun/core/llms/vertexai.py +291 -0
  20. langfun/core/llms/vertexai_test.py +233 -0
  21. langfun/core/modalities/image.py +1 -3
  22. langfun/core/modalities/mime.py +6 -0
  23. langfun/core/modalities/video.py +1 -3
  24. langfun/core/structured/__init__.py +2 -0
  25. langfun/core/structured/mapping.py +5 -1
  26. langfun/core/structured/prompting.py +39 -11
  27. langfun/core/structured/prompting_test.py +43 -0
  28. langfun/core/structured/schema.py +34 -4
  29. langfun/core/structured/schema_test.py +32 -1
  30. langfun/core/structured/scoring.py +4 -1
  31. langfun/core/structured/scoring_test.py +6 -0
  32. langfun/core/template.py +22 -1
  33. {langfun-0.0.2.dev20240429.dist-info → langfun-0.0.2.dev20240511.dist-info}/METADATA +2 -2
  34. {langfun-0.0.2.dev20240429.dist-info → langfun-0.0.2.dev20240511.dist-info}/RECORD +37 -33
  35. {langfun-0.0.2.dev20240429.dist-info → langfun-0.0.2.dev20240511.dist-info}/LICENSE +0 -0
  36. {langfun-0.0.2.dev20240429.dist-info → langfun-0.0.2.dev20240511.dist-info}/WHEEL +0 -0
  37. {langfun-0.0.2.dev20240429.dist-info → langfun-0.0.2.dev20240511.dist-info}/top_level.txt +0 -0
langfun/core/eval/base.py CHANGED
@@ -18,16 +18,16 @@ import collections
  import dataclasses
  import functools
  import hashlib
+ import html
  import inspect
  import io
  import os
  import re
  import threading
  import time
+ import types
  from typing import Annotated, Any, Callable, Iterator, Literal, Optional, Sequence, Type, Union

- from absl import app
- from absl import flags
  import langfun.core as lf
  import langfun.core.coding as lf_coding
  from langfun.core.llms.cache import in_memory
@@ -40,7 +40,8 @@ class Evaluable(lf.Component):

  EXPERIMENT_JSON = 'experiment.json'
  RESULT_JSON = 'result.json'
- FAILURES_JSON = 'failures.json'
+ OOP_FAILURES_JSON = 'oop_failures.json'
+ NON_OOP_FAILURES_JSON = 'non_oop_failures.json'
  INDEX_HTML = 'index.html'
  SUMMARY_HTML = 'summary.html'

@@ -358,7 +359,7 @@ class Evaluable(lf.Component):
  color='yellow')

  for node in self.nonleaf_nodes:
- node._result = {c.id: c.result for c in node.children} # pylint: disable=protected-access
+ node._result = {c.id: c.result for c in node.leaf_nodes} # pylint: disable=protected-access
  if should_save:
  node.save(result=False, report=False)

@@ -540,15 +541,15 @@ class Evaluable(lf.Component):
  f'<div style="color: {text_color}; white-space: pre-wrap;'
  'padding: 10px; border: 1px solid; margin-top: 10px">'
  )
- s.write(m.get('formatted_text', m.text))
+ s.write(html.escape(m.get('formatted_text', m.text)))
  if m.result is not None:
  s.write(
  '<div style="color: magenta; white-space: pre-wrap;'
  'padding: 10px; border: 1px solid; margin: 10px">'
  )
- s.write(pg.format(m.result))
+ s.write(html.escape(pg.format(m.result)))
  s.write('</div>')
- if 'usage' in m.metadata:
+ if 'usage' in m.metadata and m.usage is not None:
  s.write(
  '<div style="background-color: #EEEEEE; color: black; '
  'white-space: pre-wrap; padding: 10px; border: 0px solid; '
@@ -598,7 +599,6 @@ class _LeafNode:
  @pg.use_init_args(['children'])
  class Suite(Evaluable):
  """Evaluation suite."""
-
  children: Annotated[list[Evaluable], 'Child evaluation sets or suites.']

  # Use empty ID as suite is just a container of child evaluations.
@@ -753,10 +753,12 @@ class Evaluation(Evaluable):

  # Constants.
  CACHE_JSON = 'cache.json'
- FAILURES_HTML = 'failures.html'
+ OOP_FAILURES_HTML = 'oop_failures.html'
+ NON_OOP_FAILURES_HTML = 'non_oop_failures.html'

  @functools.cached_property
  def hash(self) -> str:
+ """Returns the semantic-based hash of the evaluation."""
  if self.is_deterministic:
  identity = pg.format(self._identifiers(), compact=True)
  else:
@@ -805,6 +807,10 @@ class Evaluation(Evaluable):
  """Returns the complete rate."""
  return self.num_completed / self.num_examples

+ #
+ # Properties on failures.
+ #
+
  @property
  def failures(self) -> list[tuple[Any, Exception]]:
  """Returns the failed examples and their errors."""
@@ -815,6 +821,15 @@ class Evaluation(Evaluable):
  """Returns the number of failed examples."""
  return len(self.failures)

+ @functools.cached_property
+ def failure_breakdown(self) -> dict[str, int]:
+ """Returns the breakdown of failures."""
+ breakdown = collections.defaultdict(int)
+ for _, error in self.failures:
+ breakdown[_error_key(error)] += 1
+ sorted_items = sorted(breakdown.items(), key=lambda x: x[1], reverse=True)
+ return pg.Dict({x[0]: x[1] for x in sorted_items})
+
  @property
  def failure_rate(self) -> float:
  """Returns the failure rate in range [0, 1]."""
@@ -822,6 +837,46 @@ class Evaluation(Evaluable):
  return 0.0
  return self.num_failures / self.num_completed

+ @functools.cached_property
+ def oop_failures(self) -> list[tuple[Any, lf_structured.MappingError]]:
+ """Returns the OOP failures."""
+ return [item for item in self.failures
+ if isinstance(item[1], lf_structured.MappingError)]
+
+ @property
+ def num_oop_failures(self) -> int:
+ """Returns the number of OOP failures."""
+ return len(self.oop_failures)
+
+ @property
+ def oop_failure_rate(self) -> float:
+ """Returns the OOP failure rate in range [0, 1]."""
+ if self.num_completed == 0:
+ return 0.0
+ return self.num_oop_failures / self.num_completed
+
+ @functools.cached_property
+ def non_oop_failures(self) -> list[tuple[Any, Exception]]:
+ """Returns the OOP failures."""
+ return [item for item in self.failures
+ if not isinstance(item[1], lf_structured.MappingError)]
+
+ @property
+ def num_non_oop_failures(self) -> int:
+ """Returns the number of non-OOP failures."""
+ return len(self.non_oop_failures)
+
+ @property
+ def non_oop_failure_rate(self) -> float:
+ """Returns the non-OOP failure rate in range [0, 1]."""
+ if self.num_completed == 0:
+ return 0.0
+ return self.num_non_oop_failures / self.num_completed
+
+ #
+ # Properties on usage.
+ #
+
  @property
  def has_usage(self) -> bool:
  """Returns True if token usage is enabled."""
@@ -976,13 +1031,22 @@ class Evaluation(Evaluable):
  self._total_prompt_tokens = 0
  self._total_completion_tokens = 0
  self._num_usages = 0
+ self.__dict__.pop('oop_failures', None)
+ self.__dict__.pop('non_oop_failures', None)

  @property
- def failures_link(self) -> str | None:
- """Returns the link to the failures page."""
+ def oop_failures_link(self) -> str | None:
+ """Returns the link to the OOP failures page."""
  if self.dir is None:
  return None
- return self.link(os.path.join(self.dir, Evaluation.FAILURES_HTML))
+ return self.link(os.path.join(self.dir, Evaluation.OOP_FAILURES_HTML))
+
+ @property
+ def non_oop_failures_link(self) -> str | None:
+ """Returns the link to then non-OOP failures page."""
+ if self.dir is None:
+ return None
+ return self.link(os.path.join(self.dir, Evaluation.NON_OOP_FAILURES_HTML))

  def _dryrun(
  self,
@@ -992,11 +1056,11 @@ class Evaluation(Evaluable):
  verbose: bool,
  **kwargs,
  ) -> None:
- # Set the example for dryrun.
- example = example or self.examples[0]
-
  # We make a copy to avoid pollute the state of current object.
  copy: Evaluation = self.clone()
+
+ # Set the example for dryrun.
+ example = example or copy.examples[0]
  copy.__dict__['examples'] = [example]

  # We set the symbolic parent of the cloned to access contextual information
@@ -1011,23 +1075,34 @@ class Evaluation(Evaluable):
  color='green',
  )

- with lf.use_settings(debug=debug):
- output_message = copy.process(example, **(self.additional_args or {}))
- if self.schema is None:
- output = output_message.text
- else:
- output = output_message.result
+ error, output_message = None, None

- if verbose:
+ try:
+ with lf.use_settings(debug=debug):
+ output_message = copy.process(example, **(self.additional_args or {}))
+ if self.schema is None:
+ output = output_message.text
+ else:
+ output = output_message.result
+
+ if verbose:
+ lf.console.write('')
+ lf.console.write(
+ str(output),
+ title='OUTPUT',
+ color='blue',
+ )
+ except lf_structured.MappingError as e:
  lf.console.write('')
  lf.console.write(
- str(output),
- title='OUTPUT',
- color='blue',
+ str(e),
+ title='ERROR',
+ color='red',
  )
+ error = e

- copy.audit(example, output_message, None, dryrun=True)
- result = copy.summarize()
+ copy.audit(example, output_message, error, dryrun=True)
+ result = copy.finalize()

  if verbose:
  lf.console.write('')
@@ -1051,6 +1126,9 @@ class Evaluation(Evaluable):
  **kwargs,
  ) -> None:
  # Setup examples.
+ # Reset examples so it could be read from the input functor.
+ self.__dict__.pop('examples', None)
+
  if end is None:
  end = len(self.examples)
  examples = self.examples[start:end]
@@ -1087,7 +1165,7 @@ class Evaluation(Evaluable):
  self.cache.save()

  # Summarize result.
- self._result = self.summarize()
+ self._result = self.finalize()
  if verbose:
  lf.console.write(
  str(self.result),
@@ -1101,7 +1179,7 @@ class Evaluation(Evaluable):

  def process(self, example: Any, **kwargs) -> lf.Message:
  """Process an example and returns its output."""
- prompt = self.prompt.render(example=example).text
+ prompt = lf.Template.from_value(self.prompt, example=example)
  if self.method == 'call':
  return lf_structured.call(
  prompt,
@@ -1129,7 +1207,9 @@ class Evaluation(Evaluable):
  else:
  assert self.method == 'complete', self.method
  assert isinstance(self.schema.spec, pg.typing.Object), self.schema
- input_value = self.schema.spec.cls.partial(prompt)
+ # TODO(daiyip): Currently multi-modal inputs within the prompt for
+ # completion is not supported.
+ input_value = self.schema.spec.cls.partial(prompt.render().text)
  return lf_structured.complete(
  input_value,
  lm=self.lm,
@@ -1143,13 +1223,13 @@ class Evaluation(Evaluable):
  def _status(self, progress: lf.concurrent.Progress) -> dict[str, Any]:
  return {
  'Model': self.lm.model_id,
- 'Succeeded': f'%.{self.report_precision}f%% (%d/%d)' % (
- progress.success_rate * 100,
+ 'Succeeded': '%s (%d/%d)' % (
+ self._format_rate(progress.success_rate),
  progress.succeeded,
  progress.completed,
  ),
- 'Failed': f'%.{self.report_precision}f%% (%d/%d)' % (
- progress.failure_rate * 100,
+ 'Failed': '%s (%d/%d)' % (
+ self._format_rate(progress.failure_rate),
  progress.failed,
  progress.completed,
  ),
@@ -1159,21 +1239,20 @@ class Evaluation(Evaluable):
  assert self.result is not None
  m = self.result.metrics
  return (
- f'COMPLETED(%s): Successes=%.{self.report_precision}f%% (%d/%d)'
- f' Failures=%.{self.report_precision}f%% (%d/%d)'
+ 'COMPLETED(%s): Successes=%s(%d/%d) Failures=%s (%d/%d)'
  % (
  run_status,
- (1 - m.failure_rate) * 100,
+ self._format_rate(1 - m.failure_rate),
  m.total - m.failures,
  m.total,
- m.failure_rate * 100,
+ self._format_rate(m.failure_rate),
  m.failures,
  m.total,
  )
  )

- def summarize(self) -> pg.Dict:
- """Summarizes the evaluation result."""
+ def finalize(self) -> pg.Dict:
+ """Finalizes the evaluation result."""
  if self.cache:
  cache_stats = dict(
  use_cache=True,
@@ -1210,12 +1289,18 @@ class Evaluation(Evaluable):
  total=self.num_completed,
  failures=self.num_failures,
  failure_rate=self.failure_rate,
+ oop_failures=self.num_oop_failures,
+ oop_failure_rate=self.oop_failure_rate,
+ non_oop_failures=self.num_non_oop_failures,
+ non_oop_failure_rate=self.non_oop_failure_rate,
+ failure_breakdown=self.failure_breakdown,
  ),
  usage=usage,
  )
  return result

- def summarize_html(self) -> str:
+ def summary_card(self) -> str:
+ """Returns summary card in HTML."""
  s = io.StringIO()
  definition = _html_repr(self, compact=False, escape=True)
  s.write('<div><table><tr><td>')
@@ -1230,18 +1315,19 @@ class Evaluation(Evaluable):
  s.write(
  f'<a target="_blank" title="{definition}" '
  f'href="{self.index_link}">{self.hash}</a>'
+ f' &nbsp;[<a href="{self.link(self.dir)}">dir</a>]'
  '</td></tr><tr><td>'
  )
- self._render_metric(s)
+ self._render_summary_metrics(s)

  # Summarize average usage.
  if self.result.usage is not None:
- self._render_usage(s)
+ self._render_summary_usage(s)

  s.write('</td></tr></table></div>')
  return s.getvalue()

- def _render_usage(self, s: io.StringIO) -> None:
+ def _render_summary_usage(self, s: io.StringIO) -> None:
  """Renders usage in HTML."""
  usage = self.result.usage
  total = usage.total_prompt_tokens + usage.total_completion_tokens
@@ -1255,19 +1341,65 @@ class Evaluation(Evaluable):
  f'" style="color:gray">({total} tokens)</a>'
  )

- def _render_metric(self, s: io.StringIO) -> None:
+ def _render_summary_metrics(self, s: io.StringIO) -> None:
  """Renders metrics in HTML."""
  assert self.result is not None
  m = self.result.metrics
+
+ # OOP failures.
+ oop_failure_title = f'OOP failures ({m.oop_failures}/{m.total})'
+ if m.oop_failures:
+ oop_failure_title += '&#013;'
+ for name, count in m.failure_breakdown.items():
+ if name.startswith('MappingError'):
+ oop_failure_title += '&#013;%s: %s (%d/%d)' % (
+ name.removeprefix('MappingError.'),
+ self._format_rate(count / m.total),
+ count,
+ m.total,
+ )
+
+ extra_style = ''
+ if m.oop_failure_rate > 0.1 and m.oop_failures > 3:
+ extra_style = ';font-weight:bold'
  s.write(
- '<a title="Failures (%d/%d)" href="%s" style="color:red">%s</a>'
+ '<a title="%s" href="%s" style="color:magenta%s">%s</a>'
  % (
- m.failures,
- m.total,
- self.failures_link,
- f'%.{self.report_precision}f%% ' % (m.failure_rate * 100),
+ oop_failure_title,
+ self.oop_failures_link,
+ extra_style,
+ self._format_rate(m.oop_failure_rate),
  )
  )
+ s.write(' | ')
+
+ # Non-OOP failures.
+ non_oop_failure_title = f'Non-OOP failures ({m.non_oop_failures}/{m.total})'
+ if m.non_oop_failures:
+ non_oop_failure_title += '&#013;'
+ for name, count in m.failure_breakdown.items():
+ if not name.startswith('MappingError'):
+ non_oop_failure_title += '&#013;%s: %s (%d/%d)' % (
+ name,
+ self._format_rate(count / m.total),
+ count,
+ m.total,
+ )
+
+ extra_style = ';font-weight:bold' if m.non_oop_failures > 0 else ''
+ s.write(
+ '<a title="%s" href="%s" style="color:red%s">%s</a>'
+ % (
+ non_oop_failure_title,
+ self.non_oop_failures_link,
+ extra_style,
+ self._format_rate(m.non_oop_failure_rate),
+ )
+ )
+
+ def _format_rate(self, rate: float) -> str:
+ """Formats a rate."""
+ return f'%.{self.report_precision}f%% ' % (rate * 100)

  def audit(
  self,
@@ -1287,7 +1419,13 @@ class Evaluation(Evaluable):
  dryrun: Whether or not audition takes place during dryrun.
  """
  if error is not None:
- self._failures.append((example, str(error)))
+ self._failures.append((example, error))
+
+ # Invalid cache of num_oop_failures.
+ self.__dict__.pop('oop_failures', None)
+ self.__dict__.pop('non_oop_failures', None)
+ self.__dict__.pop('failure_breakdown', None)
+
  if isinstance(error, lf_structured.MappingError):
  message = error.lm_response
  else:
@@ -1301,8 +1439,9 @@ class Evaluation(Evaluable):
  self._num_completed += 1

  def audit_usage(self, message: lf.Message, dryrun: bool = False) -> None:
+ del dryrun
  for m in message.trace():
- if 'usage' in m.metadata:
+ if m.metadata.get('usage', None) is not None:
  self._total_prompt_tokens += m.usage.prompt_tokens
  self._total_completion_tokens += m.usage.completion_tokens
  self._num_usages += 1
@@ -1333,16 +1472,26 @@ class Evaluation(Evaluable):
  # Save failures.
  pg.save(
  [
- pg.Dict(
- input=input, error=lf.text_formatting.decolored(str(error))
- )
- for input, error in self.failures
+ pg.Dict(input=input, error=_format_error(error))
+ for input, error in self.oop_failures
+ ],
+ os.path.join(self.dir, Evaluation.OOP_FAILURES_JSON),
+ )
+ pg.save(
+ self._html([self._render_result, self._render_oop_failures]),
+ os.path.join(self.dir, Evaluation.OOP_FAILURES_HTML),
+ file_format='txt',
+ )
+ pg.save(
+ [
+ pg.Dict(input=input, error=_format_error(error))
+ for input, error in self.non_oop_failures
  ],
- os.path.join(self.dir, Evaluation.FAILURES_JSON),
+ os.path.join(self.dir, Evaluation.NON_OOP_FAILURES_JSON),
  )
  pg.save(
- self._html([self._render_result, self._render_failures]),
- os.path.join(self.dir, Evaluation.FAILURES_HTML),
+ self._html([self._render_result, self._render_non_oop_failures]),
+ os.path.join(self.dir, Evaluation.NON_OOP_FAILURES_HTML),
  file_format='txt',
  )

@@ -1357,7 +1506,8 @@ class Evaluation(Evaluable):
  )
  if self.result.usage is not None:
  s.write('<td>Usage</td>')
- s.write('<td>Failures</td>')
+ s.write('<td>OOP Failures</td>')
+ s.write('<td>Non-OOP Failures</td>')

  def _render_result_row(self, s: io.StringIO) -> None:
  s.write(
@@ -1385,16 +1535,29 @@ class Evaluation(Evaluable):
  # Usage.
  if self.result.usage is not None:
  s.write('<td>')
- self._render_usage(s)
+ self._render_summary_usage(s)
  s.write('</td>')

- # Failures.
+ # OOP failures.
+ s.write(
+ '<td><span style="color:magenta">%s</span>%s</td>'
+ % (
+ self._format_rate(self.oop_failure_rate),
+ '<a href="%s">(%d/%d)</a>'
+ % (self.oop_failures_link,
+ self.num_oop_failures,
+ self.num_completed),
+ )
+ )
+ # Non-OOP failures.
  s.write(
- '<td><span style="color:orange">%s</span>%s</td>'
+ '<td><span style="color:red">%s</span>%s</td>'
  % (
- f'%.{self.report_precision}f%%' % (self.failure_rate * 100),
+ self._format_rate(self.non_oop_failure_rate),
  '<a href="%s">(%d/%d)</a>'
- % (self.failures_link, self.num_failures, self.num_completed),
+ % (self.non_oop_failures_link,
+ self.num_non_oop_failures,
+ self.num_completed),
  )
  )

@@ -1408,24 +1571,77 @@ class Evaluation(Evaluable):
  else:
  return 'cyan'

- def _render_failures(self, s: io.StringIO) -> None:
+ def _render_oop_failures(self, s: io.StringIO) -> None:
+ self._render_failures(s, '^MappingError.*', error_color='magenta')
+
+ def _render_non_oop_failures(self, s: io.StringIO) -> None:
+ self._render_failures(s, '^(?!MappingError).*', error_color='red')
+
+ def _render_failures(
+ self, s: io.StringIO, error_regex: str, error_color: str) -> None:
  """Formats the failed cases into html."""
+ # Failure summary.
  s.write(
- '<h2> Failed Cases </h2>'
+ '<h2> Error Summary </h2>'
  '<div style="white-space:pre">\n'
  '<table style="border:1px solid">'
- '<tr class="header"><td>No.</td><td>Input</td><td>Error</td></tr>'
+ '<tr class="header"><td>Error type</td><td>Stats</td></tr>'
  )
+ error_regex = re.compile(error_regex)
+ if self.result.metrics.failure_breakdown:
+ for name, count in self.result.metrics.failure_breakdown.items():
+ if not error_regex.match(name):
+ continue
+
+ link = f'<a href="#{name}">{name}</a>'
+ error_rate = self._format_rate(count / self.result.metrics.total)
+ stats = (f'<span style="color:{error_color}">{error_rate} '
+ f'({count}/{self.result.metrics.total})</span>')
+ s.write(f'<tr><td>{link}</td><td>{stats})</td></tr>')
+ s.write(
+ '</table></div>'
+ '<h2> Failed Cases </h2>'
+ '<div style="white-space:pre">'
+ )
+ # Failure details by error type.
+ failures_by_error = collections.defaultdict(list)
+ for example, error in self.failures:
+ error_name = _error_key(error)
+ if error_regex.match(error_name):
+ failures_by_error[error_name].append((example, error))
+
+ for error_key, failures in failures_by_error.items():
+ s.write(
+ f'<h3 id="{error_key}"><a href="#{error_key}">{error_key}</a> '
+ f'(count={len(failures)})</h3>'
+ '<table style="border:1px solid">'
+ '<tr class="header"><td>No.</td><td>Input</td>'
+ '<td>LM invocation</td><td>Error</td></tr>'
+ )
+ for i, (example, error) in enumerate(failures):
+ lm_response = None
+ if isinstance(error, lf.structured.MappingError):
+ lm_response = error.lm_response
+ error = error.cause
+
+ bgcolor = 'white' if i % 2 == 0 else '#DDDDDD'
+ s.write(f'<tr style="background-color: {bgcolor}"><td>{i + 1}</td>')
+ s.write('<td style="color:green;white-space:pre-wrap">')
+ s.write(pg.format(example, verbose=False))
+ s.write('</td><td>')
+ if lm_response is not None:
+ self._render_message(lm_response, s)
+ s.write(f'</td><td style="color:{error_color};white-space:pre">')
+ s.write(_format_error(error))
+ s.write('</td></tr>')
+ s.write('</table>')
+ s.write('</div>')

- for i, (example, error) in enumerate(self.failures):
- bgcolor = 'white' if i % 2 == 0 else '#DDDDDD'
- s.write(f'<tr style="background-color: {bgcolor}"><td>{i + 1}</td>')
- input_str = pg.format(example, verbose=False)
- s.write(f'<td style="color:green;white-space:pre-wrap">{input_str}</td>')
- error_str = lf.text_formatting.decolored(str(error))
- s.write(f'<td style="color:red;white-space:pre">{error_str}</td>')
- s.write('</tr>')
- s.write('</table></div>')
+ @classmethod
+ def visualize(cls, evaluations: list['Evaluation']) -> str | None:
+ """Visualize the a list of evaluations of this task in HTML."""
+ del evaluations
+ return None


  @pg.functor()
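
The new `visualize` classmethod above is a per-task hook: `Summary.html()` (see the Summary hunk below) calls `task.visualize(evaluations)` and inlines whatever HTML string it returns above that task's results table. A rough sketch of how a subclass might use it (hypothetical subclass; its required evaluation fields are assumed to be defined elsewhere):

    class MyTask(Evaluation):  # assumes an Evaluation subclass with its fields set

      @classmethod
      def visualize(cls, evaluations):
        # Render a simple HTML list of failure rates; a real task might emit a chart.
        items = ''.join(
            f'<li>{e.id}: {e.result.metrics.failure_rate:.1%}</li>'
            for e in evaluations if e.result)
        return f'<ul>{items}</ul>'
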
@@ -1578,7 +1794,7 @@ class Summary(pg.Object):
  if e is None:
  s.write('<span style="color: gray">N/A<span>')
  else:
- s.write(e.summarize_html())
+ s.write(e.summary_card())
  s.write('</td>')
  s.write('</tr>')
  s.write('</table>')
@@ -1653,13 +1869,22 @@ class Summary(pg.Object):
  s.write('<html><body>')
  for task in sorted(self.tasks(), key=lambda cls: cls.__name__):
  table_id = task.__name__.lower()
+ evaluations = self.select(task=task).evaluations
+ table = Summary.Table.from_evaluations(evaluations, pivot_field)
  s.write('<div>')
- s.write(f'<a id="{table_id}"')
- s.write(f'<h2><a href="#{table_id}">{task.__name__}</a></h2>')
- s.write('</a>')
- table = Summary.Table.from_evaluations(
- self.select(task=task).evaluations, pivot_field
+ s.write(
+ f'<a id="{table_id}" href="#{table_id}">'
+ f'<h2>{task.__name__}</h2></a>'
  )
+
+ # Allow users to plugin visualization code (e.g. matplot) in the summary
+ # page.
+ visual_part = task.visualize(evaluations)
+ if visual_part:
+ s.write(visual_part)
+
+ s.write(f'<h4 style="color:gray">{len(evaluations)} experiments</h4>')
+ s.write('<hr/>')
  s.write(table.html())
  s.write('</div>')
  s.write('</body></html>')
@@ -1685,6 +1910,7 @@ class Summary(pg.Object):
  experiment=entry,
  dir=entry.dir,
  metrics=entry.result.metrics if entry.result else None,
+ usage=entry.result.usage if entry.result else None,
  )
  )
  task_results[task.__name__] = results
@@ -1833,6 +2059,21 @@ class Summary(pg.Object):
  return result.join()


+ def _format_error(error: Exception):
+ """Formats an error into a string."""
+ return (f'({error.__class__.__name__}) '
+ + lf.text_formatting.decolored(str(error)))
+
+
+ def _error_key(error: Exception) -> str:
+ """Returns the key for an error."""
+ error_names = []
+ while error is not None:
+ error_names.append(error.__class__.__name__)
+ error = getattr(error, 'cause', None)
+ return '.'.join(error_names)
+
+

  def _html_repr(value: Any, compact: bool = True, escape: bool = False) -> str:
  """Formats prompt in HTML."""
@@ -1909,41 +2150,191 @@ def monitor_async(
  )


- def app_run(target: Evaluable):
- """Runs the target evaluation as an absl app.
+ #
+ # Named evaluations and experiments support.
+ #
+
+
+ class _NamedEvaluationRegistry:
+ """Named evaluation registry."""

- Args:
- target: An Langfun evaluable object.
- """
- flags.DEFINE_string(
- 'root_dir', None, 'Root directory for running the evaluation.'
- )
+ def __init__(self):
+ self._registry = {}

- flags.DEFINE_bool(
- 'dryrun', False, 'If True, dryrun the experiment instead of running it.'
- )
+ def names(self) -> list[str]:
+ """Returns all registered names."""
+ return sorted(self._registry.keys())

- flags.DEFINE_bool(
- 'debug', False, 'If True, output prompt and response to the console.'
- )
+ def get(self, name: str) -> Type[Evaluable]:
+ """Gets an evaluation by name."""
+ if name not in self._registry:
+ raise ValueError(
+ f'Evaluation {name!r} not found. '
+ 'Did you forget to import the module that registers it?'
+ )
+ return self._registry[name]

- flags.DEFINE_bool(
- 'rerun',
- False,
- 'If True, rerun the experiment even a cached result is found.',
- )
+ def register(
+ self,
+ name: str,
+ experiment_cls: Type[Evaluable],
+ ):
+ """Register an experiment class."""
+ self._registry[name] = experiment_cls
+
+
+ _eval_registry = _NamedEvaluationRegistry()
+
+
+ def registered_names() -> list[str]:
+ """Returns all registered names."""
+ return _eval_registry.names()

- FLAGS = flags.FLAGS # pylint: disable=invalid-name

- def _main(argv):
- if len(argv) > 1:
- raise app.UsageError('Too many command-line arguments.')
+ def get_evaluation(evaluation: str | Evaluable) -> Evaluable:
+ """Gets an evaluation experiment by name."""
+ if isinstance(evaluation, str):
+ return _eval_registry.get(evaluation)()
+ return evaluation

- if FLAGS.root_dir:
- target.rebind(root_dir=FLAGS.root_dir, raise_on_no_change=False)
- if FLAGS.dryrun:
- target.dryrun(debug=FLAGS.debug)
+
+ def register(name: str):
+ """Decorator to create a named evaluation class."""
+
+ def _register(func_or_cls: Type[Evaluation] | types.FunctionType):
+ if inspect.isfunction(func_or_cls):
+ e = func_or_cls()
+ if not isinstance(e, Evaluable):
+ raise TypeError(
+ f'The return value of `{func_or_cls}` should be an instance of '
+ '`lf.eval.Evaluable` subclass.'
+ )
+
+ class GeneratedSuite(Suite):
+ # NOTE(daiyip): Delay serialization key registration for generated
+ # class.
+ auto_register = False
+ children = e.children if isinstance(e, Suite) else [e]
+
+ cls = GeneratedSuite
+ cls.__name__ = func_or_cls.__name__
+ cls.__doc__ = func_or_cls.__doc__
+ cls.__qualname__ = func_or_cls.__qualname__
+ cls.__module__ = getattr(func_or_cls, '__module__', 'wrapper')
+ cls.register_for_deserialization(cls.__type_name__)
+
+ elif issubclass(func_or_cls, Evaluable):
+ cls = func_or_cls
  else:
- target.run(debug=FLAGS.debug, rerun=FLAGS.rerun)
+ raise ValueError(f'Unsupported type: {type(func_or_cls)}')
+
+ _eval_registry.register(name, cls)
+ return cls
+
+ return _register
+
+
+ def get(
+ root_dir: str,
+ evaluations: list[str | Evaluable],
+ filter: Union[ # pylint: disable=redefined-builtin
+ str, # Regex to filter evaluation based on ID.
+ Callable[[Evaluable], bool], # Custom filter function.
+ None # No filtering (Default).
+ ] = None, # pylint: disable=bad-whitespace
+ patches: list[Union[
+ str, # String-based PyGlove patcher.
+ pg.patching.Patcher, # PyGlove patcher object.
+ Callable[[pg.KeyPath, Any, Any], Any], # PyGlove rebind function.
+ ]] | None = None, # pylint: disable=bad-whitespace
+ ) -> Suite:
+ """Gets a suite from a list of patched evaluations.
+
+ Args:
+ root_dir: The root directory of the experiment.
+ evaluations: A list of evaluations to be included in the suite.
+ filter: A regular expression (str) for selecting sub-experiments of matched
+ IDs, or a filter function to filter the evaluations.
+ patches: A list of patches to be applied to the suite. Each element can be
+ a string (for string-based patcher), a `pg.patching.Patcher` object, or
+ a rebind function (e.g. `pg.rebind`). See `lf.eval.patch_*` for more
+ details.
+
+ Returns:
+ A suite of selected `lf.eval.Evaluation` objects.
+ """
+ evaluations = [get_evaluation(e) for e in evaluations]
+ suite = Suite(evaluations, root_dir=root_dir)
+ if patches:
+ suite = pg.patch(suite, patches)
+
+ if isinstance(filter, str):
+ regex = re.compile(filter)
+ filter = lambda x: bool(regex.match(x.id))
+
+ if filter:
+ suite = Suite(
+ [leaf for leaf in suite.leaf_nodes if filter(leaf)], root_dir=root_dir)
+ return suite
+
+
+ def run(
+ root_dir: str,
+ evaluations: list[str | Evaluable],
+ filter: Union[ # pylint: disable=redefined-builtin
+ str, # Regex to filter evaluation based on ID.
+ Callable[[Evaluable], bool], # Custom filter function.
+ None # No filtering (Default).
+ ] = None, # pylint: disable=bad-whitespace
+ patches: list[Union[
+ str, # String-based PyGlove patcher.
+ pg.patching.Patcher, # PyGlove patcher object.
+ Callable[[pg.KeyPath, Any, Any], Any], # PyGlove rebind function.
+ ]] | None = None, # pylint: disable=bad-whitespace
+ mode: Literal['run', 'rerun', 'dryrun', 'noop'] = 'run',
+ debug: bool = False,
+ print_definition: bool = False,
+ **kwargs,
+ ) -> Suite:
+ """Run selected evaluations with patching.
+
+ Args:
+ root_dir: The root directory of the experiment.
+ evaluations: A list of evaluations to be included in the suite.
+ filter: A regular expression (str) for selecting sub-experiments of matched
+ IDs, or a filter function to filter the evaluations.
+ patches: A list of patches to be applied to the suite. Each element can be
+ a string (for string-based patcher), a `pg.patching.Patcher` object, or
+ a rebind function (e.g. `pg.rebind`). See `lf.eval.patch_*` for more
+ details.
+ mode: The mode to run the suite. "run" to run the suite, with reusing
+ existing results if available; "rerun" to rerun all evaluations even if
+ there are existing results; "dryrun" to dryrun the suite; and "noop"
+ to do nothing.
+ debug: Whether to run in debug mode.
+ print_definition: Whether to print the experiment definition.
+ **kwargs: Additional arguments to be passed to dryrun/run the suite.
+
+ Returns:
+ A suite of selected `lf.eval.Evaluation` objects.
+ """
+ suite = get(root_dir, evaluations, patches=patches, filter=filter)
+ if print_definition:
+ lf.console.write(
+ pg.format(
+ suite,
+ compact=False,
+ verbose=False,
+ hide_default_values=True,
+ python_format=True,
+ ),
+ title='[EXPERIMENT DEFINITION]',
+ color='blue',
+ )

- app.run(_main)
+ if mode == 'run':
+ rerun = mode == 'rerun'
+ suite.run(debug=debug, rerun=rerun, **kwargs)
+ elif mode == 'dryrun':
+ suite.dryrun(debug=debug, **kwargs)
+ return suite
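
With these helpers, the absl-based `app_run` entry point is replaced by a named-evaluation registry plus the module-level `get`/`run` functions shown above. A hedged usage sketch (hypothetical task name and directory; it assumes `register` and `run` are re-exported under `lf.eval`, which the `langfun/core/eval/__init__.py` change listed above suggests):

    import langfun as lf

    @lf.eval.register('my_task')     # registers the experiment under a name
    def my_task():
      return MyTask(...)             # an `lf.eval.Evaluation` subclass defined elsewhere

    # Dry-run the registered experiment by name; mode='run' or 'rerun' would execute it.
    suite = lf.eval.run(
        '/tmp/my_eval_root',
        ['my_task'],
        mode='dryrun',
    )
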