langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240511__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of langfun might be problematic; details are available on the registry page.

Files changed (59)
  1. langfun/__init__.py +7 -0
  2. langfun/core/__init__.py +1 -0
  3. langfun/core/coding/python/correction.py +0 -7
  4. langfun/core/component.py +6 -0
  5. langfun/core/component_test.py +1 -0
  6. langfun/core/eval/__init__.py +15 -0
  7. langfun/core/eval/base.py +665 -95
  8. langfun/core/eval/base_test.py +224 -53
  9. langfun/core/eval/matching.py +48 -30
  10. langfun/core/eval/matching_test.py +25 -3
  11. langfun/core/eval/patching.py +130 -0
  12. langfun/core/eval/patching_test.py +170 -0
  13. langfun/core/eval/scoring.py +19 -10
  14. langfun/core/eval/scoring_test.py +21 -3
  15. langfun/core/langfunc.py +1 -22
  16. langfun/core/langfunc_test.py +10 -4
  17. langfun/core/language_model.py +130 -24
  18. langfun/core/language_model_test.py +249 -26
  19. langfun/core/llms/__init__.py +27 -2
  20. langfun/core/llms/anthropic.py +263 -0
  21. langfun/core/llms/anthropic_test.py +167 -0
  22. langfun/core/llms/cache/in_memory_test.py +37 -28
  23. langfun/core/llms/fake.py +34 -25
  24. langfun/core/llms/fake_test.py +122 -11
  25. langfun/core/llms/google_genai.py +8 -0
  26. langfun/core/llms/google_genai_test.py +8 -3
  27. langfun/core/llms/groq.py +260 -0
  28. langfun/core/llms/groq_test.py +170 -0
  29. langfun/core/llms/llama_cpp.py +3 -1
  30. langfun/core/llms/openai.py +100 -81
  31. langfun/core/llms/openai_test.py +287 -60
  32. langfun/core/llms/vertexai.py +291 -0
  33. langfun/core/llms/vertexai_test.py +233 -0
  34. langfun/core/modalities/image.py +1 -3
  35. langfun/core/modalities/mime.py +6 -0
  36. langfun/core/modalities/video.py +6 -5
  37. langfun/core/structured/__init__.py +5 -0
  38. langfun/core/structured/completion_test.py +2 -2
  39. langfun/core/structured/function_generation.py +245 -0
  40. langfun/core/structured/function_generation_test.py +329 -0
  41. langfun/core/structured/mapping.py +61 -3
  42. langfun/core/structured/mapping_test.py +17 -0
  43. langfun/core/structured/parsing_test.py +18 -13
  44. langfun/core/structured/prompting.py +61 -12
  45. langfun/core/structured/prompting_test.py +122 -12
  46. langfun/core/structured/schema.py +38 -6
  47. langfun/core/structured/schema_generation_test.py +2 -2
  48. langfun/core/structured/schema_test.py +36 -7
  49. langfun/core/structured/scoring.py +4 -1
  50. langfun/core/structured/scoring_test.py +6 -0
  51. langfun/core/template.py +147 -11
  52. langfun/core/template_test.py +75 -0
  53. langfun/core/templates/selfplay_test.py +6 -2
  54. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/METADATA +3 -2
  55. langfun-0.0.2.dev20240511.dist-info/RECORD +112 -0
  56. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  57. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/LICENSE +0 -0
  58. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/WHEEL +0 -0
  59. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/top_level.txt +0 -0
langfun/core/eval/base.py CHANGED
@@ -18,12 +18,14 @@ import collections
  import dataclasses
  import functools
  import hashlib
+ import html
  import inspect
  import io
  import os
  import re
  import threading
  import time
+ import types
  from typing import Annotated, Any, Callable, Iterator, Literal, Optional, Sequence, Type, Union

  import langfun.core as lf
@@ -38,7 +40,8 @@ class Evaluable(lf.Component):

  EXPERIMENT_JSON = 'experiment.json'
  RESULT_JSON = 'result.json'
- FAILURES_JSON = 'failures.json'
+ OOP_FAILURES_JSON = 'oop_failures.json'
+ NON_OOP_FAILURES_JSON = 'non_oop_failures.json'
  INDEX_HTML = 'index.html'
  SUMMARY_HTML = 'summary.html'

@@ -356,7 +359,7 @@ class Evaluable(lf.Component):
  color='yellow')

  for node in self.nonleaf_nodes:
- node._result = {c.id: c.result for c in node.children} # pylint: disable=protected-access
+ node._result = {c.id: c.result for c in node.leaf_nodes} # pylint: disable=protected-access
  if should_save:
  node.save(result=False, report=False)

@@ -538,14 +541,24 @@ class Evaluable(lf.Component):
  f'<div style="color: {text_color}; white-space: pre-wrap;'
  'padding: 10px; border: 1px solid; margin-top: 10px">'
  )
- s.write(m.text)
+ s.write(html.escape(m.get('formatted_text', m.text)))
  if m.result is not None:
  s.write(
  '<div style="color: magenta; white-space: pre-wrap;'
  'padding: 10px; border: 1px solid; margin: 10px">'
  )
- s.write(pg.format(m.result))
+ s.write(html.escape(pg.format(m.result)))
  s.write('</div>')
+ if 'usage' in m.metadata and m.usage is not None:
+ s.write(
+ '<div style="background-color: #EEEEEE; color: black; '
+ 'white-space: pre-wrap; padding: 10px; border: 0px solid; '
+ 'margin: 10px">'
+ f'prompt: {m.usage.prompt_tokens} tokens, '
+ f'response: {m.usage.completion_tokens} tokens, '
+ f'total: {m.usage.total_tokens} tokens'
+ '</div>'
+ )
  s.write('</div>')

  @classmethod
@@ -586,7 +599,6 @@ class _LeafNode:
  @pg.use_init_args(['children'])
  class Suite(Evaluable):
  """Evaluation suite."""
-
  children: Annotated[list[Evaluable], 'Child evaluation sets or suites.']

  # Use empty ID as suite is just a container of child evaluations.
@@ -741,10 +753,12 @@ class Evaluation(Evaluable):

  # Constants.
  CACHE_JSON = 'cache.json'
- FAILURES_HTML = 'failures.html'
+ OOP_FAILURES_HTML = 'oop_failures.html'
+ NON_OOP_FAILURES_HTML = 'non_oop_failures.html'

  @functools.cached_property
  def hash(self) -> str:
+ """Returns the semantic-based hash of the evaluation."""
  if self.is_deterministic:
  identity = pg.format(self._identifiers(), compact=True)
  else:
@@ -793,6 +807,10 @@ class Evaluation(Evaluable):
  """Returns the complete rate."""
  return self.num_completed / self.num_examples

+ #
+ # Properties on failures.
+ #
+
  @property
  def failures(self) -> list[tuple[Any, Exception]]:
  """Returns the failed examples and their errors."""
@@ -803,6 +821,15 @@ class Evaluation(Evaluable):
  """Returns the number of failed examples."""
  return len(self.failures)

+ @functools.cached_property
+ def failure_breakdown(self) -> dict[str, int]:
+ """Returns the breakdown of failures."""
+ breakdown = collections.defaultdict(int)
+ for _, error in self.failures:
+ breakdown[_error_key(error)] += 1
+ sorted_items = sorted(breakdown.items(), key=lambda x: x[1], reverse=True)
+ return pg.Dict({x[0]: x[1] for x in sorted_items})
+
  @property
  def failure_rate(self) -> float:
  """Returns the failure rate in range [0, 1]."""
@@ -810,17 +837,76 @@ class Evaluation(Evaluable):
  return 0.0
  return self.num_failures / self.num_completed

+ @functools.cached_property
+ def oop_failures(self) -> list[tuple[Any, lf_structured.MappingError]]:
+ """Returns the OOP failures."""
+ return [item for item in self.failures
+ if isinstance(item[1], lf_structured.MappingError)]
+
+ @property
+ def num_oop_failures(self) -> int:
+ """Returns the number of OOP failures."""
+ return len(self.oop_failures)
+
+ @property
+ def oop_failure_rate(self) -> float:
+ """Returns the OOP failure rate in range [0, 1]."""
+ if self.num_completed == 0:
+ return 0.0
+ return self.num_oop_failures / self.num_completed
+
+ @functools.cached_property
+ def non_oop_failures(self) -> list[tuple[Any, Exception]]:
+ """Returns the OOP failures."""
+ return [item for item in self.failures
+ if not isinstance(item[1], lf_structured.MappingError)]
+
+ @property
+ def num_non_oop_failures(self) -> int:
+ """Returns the number of non-OOP failures."""
+ return len(self.non_oop_failures)
+
+ @property
+ def non_oop_failure_rate(self) -> float:
+ """Returns the non-OOP failure rate in range [0, 1]."""
+ if self.num_completed == 0:
+ return 0.0
+ return self.num_non_oop_failures / self.num_completed
+
+ #
+ # Properties on usage.
+ #
+
+ @property
+ def has_usage(self) -> bool:
+ """Returns True if token usage is enabled."""
+ return self._num_usages > 0
+
+ @property
+ def average_prompt_tokens(self) -> int:
+ """Returns the average prompt tokens."""
+ if not self.has_usage:
+ return 0
+ return self._total_prompt_tokens // self._num_usages
+
+ @property
+ def average_completion_tokens(self) -> int:
+ """Returns the average completion tokens."""
+ if not self.has_usage:
+ return 0
+ return self._total_completion_tokens // self._num_usages
+
+ @property
+ def average_total_tokens(self) -> int:
+ """Returns the average total tokens."""
+ return self.average_prompt_tokens + self.average_completion_tokens
+
  @functools.cached_property
  def schema(self) -> lf_structured.Schema | None:
  """Schema."""
  if self.schema_fn is None:
  return None

- kwargs = {}
- # Allow schema to be a function based on current evaluation.
- if 'evaluation' in self.schema_fn.__signature__.arg_names:
- kwargs['evaluation'] = self
-
  schema = self._call_schema_fn()
  fewshot_examples = None
  if isinstance(schema, tuple):
@@ -861,7 +947,11 @@ class Evaluation(Evaluable):
  'Encountered: {annotation!r}.'
  )
  self._maybe_adjust_schema_for_completion(annotation)
- return lf_structured.Schema.from_value(annotation)
+ schema = lf_structured.Schema.from_value(annotation)
+ # NOTE(daiyip): add references to the dependent classes of the returned type
+ # to prevent unused subclasses get garbage collected by Python.
+ setattr(schema, '__dependencies__', schema.class_dependencies())
+ return schema

  def _maybe_adjust_schema_for_completion(self, cls):
  if (self.completion_prompt_field is None
@@ -938,12 +1028,25 @@ class Evaluation(Evaluable):
  self._failures = []
  self._num_completed = 0

+ self._total_prompt_tokens = 0
+ self._total_completion_tokens = 0
+ self._num_usages = 0
+ self.__dict__.pop('oop_failures', None)
+ self.__dict__.pop('non_oop_failures', None)
+
+ @property
+ def oop_failures_link(self) -> str | None:
+ """Returns the link to the OOP failures page."""
+ if self.dir is None:
+ return None
+ return self.link(os.path.join(self.dir, Evaluation.OOP_FAILURES_HTML))
+
  @property
- def failures_link(self) -> str | None:
- """Returns the link to the failures page."""
+ def non_oop_failures_link(self) -> str | None:
+ """Returns the link to then non-OOP failures page."""
  if self.dir is None:
  return None
- return self.link(os.path.join(self.dir, Evaluation.FAILURES_HTML))
+ return self.link(os.path.join(self.dir, Evaluation.NON_OOP_FAILURES_HTML))

  def _dryrun(
  self,
@@ -953,11 +1056,11 @@ class Evaluation(Evaluable):
  verbose: bool,
  **kwargs,
  ) -> None:
- # Set the example for dryrun.
- example = example or self.examples[0]
-
  # We make a copy to avoid pollute the state of current object.
- copy = self.clone()
+ copy: Evaluation = self.clone()
+
+ # Set the example for dryrun.
+ example = example or copy.examples[0]
  copy.__dict__['examples'] = [example]

  # We set the symbolic parent of the cloned to access contextual information
@@ -972,24 +1075,35 @@ class Evaluation(Evaluable):
  color='green',
  )

- with lf.use_settings(debug=debug):
- output_message = copy.process(example, **(self.additional_args or {}))
- if self.schema is None:
- output = output_message.text
- else:
- output = output_message.result
+ error, output_message = None, None

- if verbose:
+ try:
+ with lf.use_settings(debug=debug):
+ output_message = copy.process(example, **(self.additional_args or {}))
+ if self.schema is None:
+ output = output_message.text
+ else:
+ output = output_message.result
+
+ if verbose:
+ lf.console.write('')
+ lf.console.write(
+ str(output),
+ title='OUTPUT',
+ color='blue',
+ )
+ except lf_structured.MappingError as e:
  lf.console.write('')
  lf.console.write(
- str(output),
- title='OUTPUT',
- color='blue',
+ str(e),
+ title='ERROR',
+ color='red',
  )
+ error = e
+
+ copy.audit(example, output_message, error, dryrun=True)
+ result = copy.finalize()

- # Audit the result.
- copy.audit(example, output, output_message)
- result = copy.summarize()
  if verbose:
  lf.console.write('')
  lf.console.write(
@@ -1012,6 +1126,9 @@ class Evaluation(Evaluable):
  **kwargs,
  ) -> None:
  # Setup examples.
+ # Reset examples so it could be read from the input functor.
+ self.__dict__.pop('examples', None)
+
  if end is None:
  end = len(self.examples)
  examples = self.examples[start:end]
@@ -1036,18 +1153,19 @@ class Evaluation(Evaluable):
  status_fn=self._status,
  ):
  if error is not None:
- self._failures.append((example, str(error)))
- else:
- output = message.text if self.schema is None else message.result
- self.audit(example, output, message)
- self._num_completed += 1
+ message = (
+ error.lm_response
+ if isinstance(error, lf_structured.MappingError)
+ else None
+ )
+ self.audit(example, message, error)
  finally:
  # Save cache upon completion or interruption.
  if self.dir and self.cache:
  self.cache.save()

  # Summarize result.
- self._result = self.summarize()
+ self._result = self.finalize()
  if verbose:
  lf.console.write(
  str(self.result),
@@ -1061,7 +1179,7 @@ class Evaluation(Evaluable):

  def process(self, example: Any, **kwargs) -> lf.Message:
  """Process an example and returns its output."""
- prompt = self.prompt.render(example=example).text
+ prompt = lf.Template.from_value(self.prompt, example=example)
  if self.method == 'call':
  return lf_structured.call(
  prompt,
@@ -1089,7 +1207,9 @@ class Evaluation(Evaluable):
  else:
  assert self.method == 'complete', self.method
  assert isinstance(self.schema.spec, pg.typing.Object), self.schema
- input_value = self.schema.spec.cls.partial(prompt)
+ # TODO(daiyip): Currently multi-modal inputs within the prompt for
+ # completion is not supported.
+ input_value = self.schema.spec.cls.partial(prompt.render().text)
  return lf_structured.complete(
  input_value,
  lm=self.lm,
@@ -1103,13 +1223,13 @@ class Evaluation(Evaluable):
  def _status(self, progress: lf.concurrent.Progress) -> dict[str, Any]:
  return {
  'Model': self.lm.model_id,
- 'Succeeded': f'%.{self.report_precision}f%% (%d/%d)' % (
- progress.success_rate * 100,
+ 'Succeeded': '%s (%d/%d)' % (
+ self._format_rate(progress.success_rate),
  progress.succeeded,
  progress.completed,
  ),
- 'Failed': f'%.{self.report_precision}f%% (%d/%d)' % (
- progress.failure_rate * 100,
+ 'Failed': '%s (%d/%d)' % (
+ self._format_rate(progress.failure_rate),
  progress.failed,
  progress.completed,
  ),
@@ -1119,21 +1239,20 @@ class Evaluation(Evaluable):
  assert self.result is not None
  m = self.result.metrics
  return (
- f'COMPLETED(%s): Successes=%.{self.report_precision}f%% (%d/%d)'
- f' Failures=%.{self.report_precision}f%% (%d/%d)'
+ 'COMPLETED(%s): Successes=%s(%d/%d) Failures=%s (%d/%d)'
  % (
  run_status,
- (1 - m.failure_rate) * 100,
+ self._format_rate(1 - m.failure_rate),
  m.total - m.failures,
  m.total,
- m.failure_rate * 100,
+ self._format_rate(m.failure_rate),
  m.failures,
  m.total,
  )
  )

- def summarize(self) -> pg.Dict:
- """Summarizes the evaluation result."""
+ def finalize(self) -> pg.Dict:
+ """Finalizes the evaluation result."""
  if self.cache:
  cache_stats = dict(
  use_cache=True,
@@ -1143,6 +1262,19 @@ class Evaluation(Evaluable):
  )
  else:
  cache_stats = dict(use_cache=False)
+
+ if self.has_usage:
+ usage = pg.Dict(
+ total_prompt_tokens=self._total_prompt_tokens,
+ total_completion_tokens=self._total_completion_tokens,
+ num_usages=self._num_usages,
+ average_prompt_tokens=self.average_prompt_tokens,
+ average_completion_tokens=self.average_completion_tokens,
+ average_total_tokens=self.average_total_tokens,
+ )
+ else:
+ usage = None
+
  result = pg.Dict(
  experiment_setup=pg.Dict(
  id=self.id,
@@ -1157,11 +1289,18 @@ class Evaluation(Evaluable):
  total=self.num_completed,
  failures=self.num_failures,
  failure_rate=self.failure_rate,
+ oop_failures=self.num_oop_failures,
+ oop_failure_rate=self.oop_failure_rate,
+ non_oop_failures=self.num_non_oop_failures,
+ non_oop_failure_rate=self.non_oop_failure_rate,
+ failure_breakdown=self.failure_breakdown,
  ),
+ usage=usage,
  )
  return result

- def summarize_html(self) -> str:
+ def summary_card(self) -> str:
+ """Returns summary card in HTML."""
  s = io.StringIO()
  definition = _html_repr(self, compact=False, escape=True)
  s.write('<div><table><tr><td>')
@@ -1176,37 +1315,141 @@ class Evaluation(Evaluable):
  s.write(
  f'<a target="_blank" title="{definition}" '
  f'href="{self.index_link}">{self.hash}</a>'
+ f' &nbsp;[<a href="{self.link(self.dir)}">dir</a>]'
  '</td></tr><tr><td>'
  )
- self._render_metric(s)
+ self._render_summary_metrics(s)
+
+ # Summarize average usage.
+ if self.result.usage is not None:
+ self._render_summary_usage(s)
+
  s.write('</td></tr></table></div>')
  return s.getvalue()

- def _render_metric(self, s: io.StringIO) -> None:
+ def _render_summary_usage(self, s: io.StringIO) -> None:
+ """Renders usage in HTML."""
+ usage = self.result.usage
+ total = usage.total_prompt_tokens + usage.total_completion_tokens
+ s.write(
+ '&nbsp;<a title="'
+ f'# of usages: {usage.num_usages}&#013;'
+ f'total prompt: {usage.total_prompt_tokens}&#013;'
+ f'total response: {usage.total_completion_tokens}&#013;'
+ f'avg prompt: {usage.average_prompt_tokens}&#013;'
+ f'avg response: {usage.average_completion_tokens}'
+ f'" style="color:gray">({total} tokens)</a>'
+ )
+
+ def _render_summary_metrics(self, s: io.StringIO) -> None:
  """Renders metrics in HTML."""
  assert self.result is not None
  m = self.result.metrics
+
+ # OOP failures.
+ oop_failure_title = f'OOP failures ({m.oop_failures}/{m.total})'
+ if m.oop_failures:
+ oop_failure_title += '&#013;'
+ for name, count in m.failure_breakdown.items():
+ if name.startswith('MappingError'):
+ oop_failure_title += '&#013;%s: %s (%d/%d)' % (
+ name.removeprefix('MappingError.'),
+ self._format_rate(count / m.total),
+ count,
+ m.total,
+ )
+
+ extra_style = ''
+ if m.oop_failure_rate > 0.1 and m.oop_failures > 3:
+ extra_style = ';font-weight:bold'
  s.write(
- '<a title="Failures (%d/%d)" href="%s" style="color:red">%s</a>'
+ '<a title="%s" href="%s" style="color:magenta%s">%s</a>'
  % (
- m.failures,
- m.total,
- self.failures_link,
- f'%.{self.report_precision}f%% ' % (m.failure_rate * 100),
+ oop_failure_title,
+ self.oop_failures_link,
+ extra_style,
+ self._format_rate(m.oop_failure_rate),
  )
  )
+ s.write(' | ')
+
+ # Non-OOP failures.
+ non_oop_failure_title = f'Non-OOP failures ({m.non_oop_failures}/{m.total})'
+ if m.non_oop_failures:
+ non_oop_failure_title += '&#013;'
+ for name, count in m.failure_breakdown.items():
+ if not name.startswith('MappingError'):
+ non_oop_failure_title += '&#013;%s: %s (%d/%d)' % (
+ name,
+ self._format_rate(count / m.total),
+ count,
+ m.total,
+ )

- def audit(self, example: Any, output: Any, message: lf.Message) -> None:
+ extra_style = ';font-weight:bold' if m.non_oop_failures > 0 else ''
+ s.write(
+ '<a title="%s" href="%s" style="color:red%s">%s</a>'
+ % (
+ non_oop_failure_title,
+ self.non_oop_failures_link,
+ extra_style,
+ self._format_rate(m.non_oop_failure_rate),
+ )
+ )
+
+ def _format_rate(self, rate: float) -> str:
+ """Formats a rate."""
+ return f'%.{self.report_precision}f%% ' % (rate * 100)
+
+ def audit(
+ self,
+ example: Any,
+ message: lf.Message | None,
+ error: Exception | None = None,
+ dryrun: bool = False,
+ ) -> None:
  """Audits the example against the output. Subclasses should override.

  Args:
  example: The input object.
- output: The output from LM. For `lf.call`, if `schema_fn` is not provided,
- it will be the raw LM response string. Otherwise it will be the
- structured output from the LM.
  message: The entire message returned by the LM, which could be used to
- trace the LM input, response and parsed structure.
+ trace the LM input, response and parsed structure. If error is raised
+ before LLM could return a response, None will be its value.
+ error: The exception during processing the example.
+ dryrun: Whether or not audition takes place during dryrun.
  """
+ if error is not None:
+ self._failures.append((example, error))
+
+ # Invalid cache of num_oop_failures.
+ self.__dict__.pop('oop_failures', None)
+ self.__dict__.pop('non_oop_failures', None)
+ self.__dict__.pop('failure_breakdown', None)
+
+ if isinstance(error, lf_structured.MappingError):
+ message = error.lm_response
+ else:
+ assert message is not None
+ output = message.text if self.schema is None else message.result
+ self.audit_processed(example, output, message, dryrun=dryrun)
+
+ # Audit usage.
+ if message is not None:
+ self.audit_usage(message, dryrun=dryrun)
+ self._num_completed += 1
+
+ def audit_usage(self, message: lf.Message, dryrun: bool = False) -> None:
+ del dryrun
+ for m in message.trace():
+ if m.metadata.get('usage', None) is not None:
+ self._total_prompt_tokens += m.usage.prompt_tokens
+ self._total_completion_tokens += m.usage.completion_tokens
+ self._num_usages += 1
+
+ def audit_processed(
+ self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+ ) -> None:
+ """Audits a successfully processed example. Subclass should override."""

  def save(
  self, definition: bool = True, result: bool = True, report: bool = True
@@ -1229,16 +1472,26 @@ class Evaluation(Evaluable):
  # Save failures.
  pg.save(
  [
- pg.Dict(
- input=input, error=lf.text_formatting.decolored(str(error))
- )
- for input, error in self.failures
+ pg.Dict(input=input, error=_format_error(error))
+ for input, error in self.oop_failures
  ],
- os.path.join(self.dir, Evaluation.FAILURES_JSON),
+ os.path.join(self.dir, Evaluation.OOP_FAILURES_JSON),
  )
  pg.save(
- self._html([self._render_result, self._render_failures]),
- os.path.join(self.dir, Evaluation.FAILURES_HTML),
+ self._html([self._render_result, self._render_oop_failures]),
+ os.path.join(self.dir, Evaluation.OOP_FAILURES_HTML),
+ file_format='txt',
+ )
+ pg.save(
+ [
+ pg.Dict(input=input, error=_format_error(error))
+ for input, error in self.non_oop_failures
+ ],
+ os.path.join(self.dir, Evaluation.NON_OOP_FAILURES_JSON),
+ )
+ pg.save(
+ self._html([self._render_result, self._render_non_oop_failures]),
+ os.path.join(self.dir, Evaluation.NON_OOP_FAILURES_HTML),
  file_format='txt',
  )

@@ -1250,8 +1503,11 @@ class Evaluation(Evaluable):
  '<td>Prompt</td>'
  '<td>Schema</td>'
  '<td>Additional Args</td>'
- '<td>Failures</td>'
  )
+ if self.result.usage is not None:
+ s.write('<td>Usage</td>')
+ s.write('<td>OOP Failures</td>')
+ s.write('<td>Non-OOP Failures</td>')

  def _render_result_row(self, s: io.StringIO) -> None:
  s.write(
@@ -1276,13 +1532,32 @@ class Evaluation(Evaluable):
  '<td style="color:purple" '
  f'{_html_repr(self.additional_args, compact=False)}</td>'
  )
- # Failures.
+ # Usage.
+ if self.result.usage is not None:
+ s.write('<td>')
+ self._render_summary_usage(s)
+ s.write('</td>')
+
+ # OOP failures.
+ s.write(
+ '<td><span style="color:magenta">%s</span>%s</td>'
+ % (
+ self._format_rate(self.oop_failure_rate),
+ '<a href="%s">(%d/%d)</a>'
+ % (self.oop_failures_link,
+ self.num_oop_failures,
+ self.num_completed),
+ )
+ )
+ # Non-OOP failures.
  s.write(
- '<td><span style="color:orange">%s</span>%s</td>'
+ '<td><span style="color:red">%s</span>%s</td>'
  % (
- f'%.{self.report_precision}f%%' % (self.failure_rate * 100),
+ self._format_rate(self.non_oop_failure_rate),
  '<a href="%s">(%d/%d)</a>'
- % (self.failures_link, self.num_failures, self.num_completed),
+ % (self.non_oop_failures_link,
+ self.num_non_oop_failures,
+ self.num_completed),
  )
  )

@@ -1296,24 +1571,77 @@ class Evaluation(Evaluable):
  else:
  return 'cyan'

- def _render_failures(self, s: io.StringIO) -> None:
+ def _render_oop_failures(self, s: io.StringIO) -> None:
+ self._render_failures(s, '^MappingError.*', error_color='magenta')
+
+ def _render_non_oop_failures(self, s: io.StringIO) -> None:
+ self._render_failures(s, '^(?!MappingError).*', error_color='red')
+
+ def _render_failures(
+ self, s: io.StringIO, error_regex: str, error_color: str) -> None:
  """Formats the failed cases into html."""
+ # Failure summary.
  s.write(
- '<h2> Failed Cases </h2>'
+ '<h2> Error Summary </h2>'
  '<div style="white-space:pre">\n'
  '<table style="border:1px solid">'
- '<tr class="header"><td>No.</td><td>Input</td><td>Error</td></tr>'
+ '<tr class="header"><td>Error type</td><td>Stats</td></tr>'
  )
+ error_regex = re.compile(error_regex)
+ if self.result.metrics.failure_breakdown:
+ for name, count in self.result.metrics.failure_breakdown.items():
+ if not error_regex.match(name):
+ continue
+
+ link = f'<a href="#{name}">{name}</a>'
+ error_rate = self._format_rate(count / self.result.metrics.total)
+ stats = (f'<span style="color:{error_color}">{error_rate} '
+ f'({count}/{self.result.metrics.total})</span>')
+ s.write(f'<tr><td>{link}</td><td>{stats})</td></tr>')
+ s.write(
+ '</table></div>'
+ '<h2> Failed Cases </h2>'
+ '<div style="white-space:pre">'
+ )
+ # Failure details by error type.
+ failures_by_error = collections.defaultdict(list)
+ for example, error in self.failures:
+ error_name = _error_key(error)
+ if error_regex.match(error_name):
+ failures_by_error[error_name].append((example, error))
+
+ for error_key, failures in failures_by_error.items():
+ s.write(
+ f'<h3 id="{error_key}"><a href="#{error_key}">{error_key}</a> '
+ f'(count={len(failures)})</h3>'
+ '<table style="border:1px solid">'
+ '<tr class="header"><td>No.</td><td>Input</td>'
+ '<td>LM invocation</td><td>Error</td></tr>'
+ )
+ for i, (example, error) in enumerate(failures):
+ lm_response = None
+ if isinstance(error, lf.structured.MappingError):
+ lm_response = error.lm_response
+ error = error.cause
+
+ bgcolor = 'white' if i % 2 == 0 else '#DDDDDD'
+ s.write(f'<tr style="background-color: {bgcolor}"><td>{i + 1}</td>')
+ s.write('<td style="color:green;white-space:pre-wrap">')
+ s.write(pg.format(example, verbose=False))
+ s.write('</td><td>')
+ if lm_response is not None:
+ self._render_message(lm_response, s)
+ s.write(f'</td><td style="color:{error_color};white-space:pre">')
+ s.write(_format_error(error))
+ s.write('</td></tr>')
+ s.write('</table>')
+ s.write('</div>')

- for i, (example, error) in enumerate(self.failures):
- bgcolor = 'white' if i % 2 == 0 else '#DDDDDD'
- s.write(f'<tr style="background-color: {bgcolor}"><td>{i + 1}</td>')
- input_str = pg.format(example, verbose=False)
- s.write(f'<td style="color:green;white-space:pre-wrap">{input_str}</td>')
- error_str = lf.text_formatting.decolored(str(error))
- s.write(f'<td style="color:red;white-space:pre">{error_str}</td>')
- s.write('</tr>')
- s.write('</table></div>')
+ @classmethod
+ def visualize(cls, evaluations: list['Evaluation']) -> str | None:
+ """Visualize the a list of evaluations of this task in HTML."""
+ del evaluations
+ return None


  @pg.functor()
@@ -1374,8 +1702,8 @@ class Summary(pg.Object):
  Type[lf.LanguageModel],
  tuple[lf.LanguageModel | Type[lf.LanguageModel], ...],
  ] = lf.LanguageModel,
- method: Union[str, tuple[str], None] = None,
- schema_fn: Union[pg.Functor, tuple[pg.Functor], None] = None,
+ method: Union[str, tuple[str, ...], None] = None,
+ schema_fn: Union[pg.Functor, tuple[pg.Functor, ...], None] = None,
  completed: bool | None = None,
  pivot_field: str | None = None,
  ) -> 'Summary':
@@ -1466,7 +1794,7 @@ class Summary(pg.Object):
  if e is None:
  s.write('<span style="color: gray">N/A<span>')
  else:
- s.write(e.summarize_html())
+ s.write(e.summary_card())
  s.write('</td>')
  s.write('</tr>')
  s.write('</table>')
@@ -1541,13 +1869,22 @@ class Summary(pg.Object):
  s.write('<html><body>')
  for task in sorted(self.tasks(), key=lambda cls: cls.__name__):
  table_id = task.__name__.lower()
+ evaluations = self.select(task=task).evaluations
+ table = Summary.Table.from_evaluations(evaluations, pivot_field)
  s.write('<div>')
- s.write(f'<a id="{table_id}"')
- s.write(f'<h2><a href="#{table_id}">{task.__name__}</a></h2>')
- s.write('</a>')
- table = Summary.Table.from_evaluations(
- self.select(task=task).evaluations, pivot_field
+ s.write(
+ f'<a id="{table_id}" href="#{table_id}">'
+ f'<h2>{task.__name__}</h2></a>'
  )
+
+ # Allow users to plugin visualization code (e.g. matplot) in the summary
+ # page.
+ visual_part = task.visualize(evaluations)
+ if visual_part:
+ s.write(visual_part)
+
+ s.write(f'<h4 style="color:gray">{len(evaluations)} experiments</h4>')
+ s.write('<hr/>')
  s.write(table.html())
  s.write('</div>')
  s.write('</body></html>')
@@ -1556,8 +1893,36 @@ class Summary(pg.Object):
  def _repr_html_(self) -> str:
  return self.html()

+ def json(
+ self,
+ ) -> dict[
+ str, # Task name
+ list[pg.Dict], # List of pg.Dict with `experiment` and `metrics`.
+ ]:
+ """Returns the JSON representation of the summary."""
+ task_results = {}
+ for task in sorted(self.tasks(), key=lambda cls: cls.__name__):
+ results = []
+ for entry in self.select(task=task).evaluations:
+ results.append(
+ pg.Dict(
+ id=entry.id,
+ experiment=entry,
+ dir=entry.dir,
+ metrics=entry.result.metrics if entry.result else None,
+ usage=entry.result.usage if entry.result else None,
+ )
+ )
+ task_results[task.__name__] = results
+ return task_results
+
  def save(self, file: str, pivot_field: str | None = None) -> None:
  pg.save(self.html(pivot_field), file, file_format='txt')
+ if file.endswith('.html'):
+ json_file = file.replace('.html', '.json')
+ else:
+ json_file = os.path.join(file, '.json')
+ pg.save(self.json(), json_file)

  @classmethod
  def from_dirs(
@@ -1694,6 +2059,21 @@ class Summary(pg.Object):
  return result.join()


+ def _format_error(error: Exception):
+ """Formats an error into a string."""
+ return (f'({error.__class__.__name__}) '
+ + lf.text_formatting.decolored(str(error)))
+
+
+ def _error_key(error: Exception) -> str:
+ """Returns the key for an error."""
+ error_names = []
+ while error is not None:
+ error_names.append(error.__class__.__name__)
+ error = getattr(error, 'cause', None)
+ return '.'.join(error_names)
+
+
  def _html_repr(value: Any, compact: bool = True, escape: bool = False) -> str:
  """Formats prompt in HTML."""
  if type(value) is lf.Template: # pylint: disable=unidiomatic-typecheck
@@ -1768,3 +2148,193 @@ def monitor_async(
  scan_interval=scan_interval,
  refresh_when_stop=refresh_when_stop,
  )
+
+
+ #
+ # Named evaluations and experiments support.
+ #
+
+
+ class _NamedEvaluationRegistry:
+ """Named evaluation registry."""
+
+ def __init__(self):
+ self._registry = {}
+
+ def names(self) -> list[str]:
+ """Returns all registered names."""
+ return sorted(self._registry.keys())
+
+ def get(self, name: str) -> Type[Evaluable]:
+ """Gets an evaluation by name."""
+ if name not in self._registry:
+ raise ValueError(
+ f'Evaluation {name!r} not found. '
+ 'Did you forget to import the module that registers it?'
+ )
+ return self._registry[name]
+
+ def register(
+ self,
+ name: str,
+ experiment_cls: Type[Evaluable],
+ ):
+ """Register an experiment class."""
+ self._registry[name] = experiment_cls
+
+
+ _eval_registry = _NamedEvaluationRegistry()
+
+
+ def registered_names() -> list[str]:
+ """Returns all registered names."""
+ return _eval_registry.names()
+
+
+ def get_evaluation(evaluation: str | Evaluable) -> Evaluable:
+ """Gets an evaluation experiment by name."""
+ if isinstance(evaluation, str):
+ return _eval_registry.get(evaluation)()
+ return evaluation
+
+
+ def register(name: str):
+ """Decorator to create a named evaluation class."""
+
+ def _register(func_or_cls: Type[Evaluation] | types.FunctionType):
+ if inspect.isfunction(func_or_cls):
+ e = func_or_cls()
+ if not isinstance(e, Evaluable):
+ raise TypeError(
+ f'The return value of `{func_or_cls}` should be an instance of '
+ '`lf.eval.Evaluable` subclass.'
+ )
+
+ class GeneratedSuite(Suite):
+ # NOTE(daiyip): Delay serialization key registration for generated
+ # class.
+ auto_register = False
+ children = e.children if isinstance(e, Suite) else [e]
+
+ cls = GeneratedSuite
+ cls.__name__ = func_or_cls.__name__
+ cls.__doc__ = func_or_cls.__doc__
+ cls.__qualname__ = func_or_cls.__qualname__
+ cls.__module__ = getattr(func_or_cls, '__module__', 'wrapper')
+ cls.register_for_deserialization(cls.__type_name__)
+
+ elif issubclass(func_or_cls, Evaluable):
+ cls = func_or_cls
+ else:
+ raise ValueError(f'Unsupported type: {type(func_or_cls)}')
+
+ _eval_registry.register(name, cls)
+ return cls
+
+ return _register
+
+
+ def get(
+ root_dir: str,
+ evaluations: list[str | Evaluable],
+ filter: Union[ # pylint: disable=redefined-builtin
+ str, # Regex to filter evaluation based on ID.
+ Callable[[Evaluable], bool], # Custom filter function.
+ None # No filtering (Default).
+ ] = None, # pylint: disable=bad-whitespace
+ patches: list[Union[
+ str, # String-based PyGlove patcher.
+ pg.patching.Patcher, # PyGlove patcher object.
+ Callable[[pg.KeyPath, Any, Any], Any], # PyGlove rebind function.
+ ]] | None = None, # pylint: disable=bad-whitespace
+ ) -> Suite:
+ """Gets a suite from a list of patched evaluations.
+
+ Args:
+ root_dir: The root directory of the experiment.
+ evaluations: A list of evaluations to be included in the suite.
+ filter: A regular expression (str) for selecting sub-experiments of matched
+ IDs, or a filter function to filter the evaluations.
+ patches: A list of patches to be applied to the suite. Each element can be
+ a string (for string-based patcher), a `pg.patching.Patcher` object, or
+ a rebind function (e.g. `pg.rebind`). See `lf.eval.patch_*` for more
+ details.
+
+ Returns:
+ A suite of selected `lf.eval.Evaluation` objects.
+ """
+ evaluations = [get_evaluation(e) for e in evaluations]
+ suite = Suite(evaluations, root_dir=root_dir)
+ if patches:
+ suite = pg.patch(suite, patches)
+
+ if isinstance(filter, str):
+ regex = re.compile(filter)
+ filter = lambda x: bool(regex.match(x.id))
+
+ if filter:
+ suite = Suite(
+ [leaf for leaf in suite.leaf_nodes if filter(leaf)], root_dir=root_dir)
+ return suite
+
+
+ def run(
+ root_dir: str,
+ evaluations: list[str | Evaluable],
+ filter: Union[ # pylint: disable=redefined-builtin
+ str, # Regex to filter evaluation based on ID.
+ Callable[[Evaluable], bool], # Custom filter function.
+ None # No filtering (Default).
+ ] = None, # pylint: disable=bad-whitespace
+ patches: list[Union[
+ str, # String-based PyGlove patcher.
+ pg.patching.Patcher, # PyGlove patcher object.
+ Callable[[pg.KeyPath, Any, Any], Any], # PyGlove rebind function.
+ ]] | None = None, # pylint: disable=bad-whitespace
+ mode: Literal['run', 'rerun', 'dryrun', 'noop'] = 'run',
+ debug: bool = False,
+ print_definition: bool = False,
+ **kwargs,
+ ) -> Suite:
+ """Run selected evaluations with patching.
+
+ Args:
+ root_dir: The root directory of the experiment.
+ evaluations: A list of evaluations to be included in the suite.
+ filter: A regular expression (str) for selecting sub-experiments of matched
+ IDs, or a filter function to filter the evaluations.
+ patches: A list of patches to be applied to the suite. Each element can be
+ a string (for string-based patcher), a `pg.patching.Patcher` object, or
+ a rebind function (e.g. `pg.rebind`). See `lf.eval.patch_*` for more
+ details.
+ mode: The mode to run the suite. "run" to run the suite, with reusing
+ existing results if available; "rerun" to rerun all evaluations even if
+ there are existing results; "dryrun" to dryrun the suite; and "noop"
+ to do nothing.
+ debug: Whether to run in debug mode.
+ print_definition: Whether to print the experiment definition.
+ **kwargs: Additional arguments to be passed to dryrun/run the suite.
+
+ Returns:
+ A suite of selected `lf.eval.Evaluation` objects.
+ """
+ suite = get(root_dir, evaluations, patches=patches, filter=filter)
+ if print_definition:
+ lf.console.write(
+ pg.format(
+ suite,
+ compact=False,
+ verbose=False,
+ hide_default_values=True,
+ python_format=True,
+ ),
+ title='[EXPERIMENT DEFINITION]',
+ color='blue',
+ )
+
+ if mode == 'run':
+ rerun = mode == 'rerun'
+ suite.run(debug=debug, rerun=rerun, **kwargs)
+ elif mode == 'dryrun':
+ suite.dryrun(debug=debug, **kwargs)
+ return suite
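
The most user-visible addition in this file is the named-evaluation registry and the get/run entry points. The sketch below is only an illustration of how they might be exercised; it assumes the new helpers (register, registered_names, get_evaluation, run) and Suite are re-exported under lf.eval (the diff's own messages refer to lf.eval.Evaluable and lf.eval.patch_*), and the registry name, the empty placeholder suite, and the root directory are made up for the example rather than a real benchmark.

import langfun as lf

# Register a named experiment. The decorated function must return an
# Evaluable; `register` wraps it into a named suite class so it can later be
# referred to by string. An empty Suite is used purely as a placeholder here;
# a real experiment would return a concrete lf.eval.Evaluation subclass.
@lf.eval.register('demo/empty_suite')
def demo_suite():
  return lf.eval.Suite([])

print(lf.eval.registered_names())                  # ['demo/empty_suite', ...]
print(lf.eval.get_evaluation('demo/empty_suite'))  # Instance of the named suite.

# Resolve names into a runnable suite. With mode='noop' the suite is only
# assembled (per the docstring above), so no model calls are made.
lf.eval.run(
    '/tmp/eval_root',      # root_dir where results and HTML reports would go
    ['demo/empty_suite'],
    mode='noop',
)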