langfun 0.0.2.dev20240319__py3-none-any.whl → 0.0.2.dev20240429__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. langfun/__init__.py +2 -0
  2. langfun/core/__init__.py +1 -0
  3. langfun/core/coding/python/correction.py +0 -7
  4. langfun/core/component.py +6 -0
  5. langfun/core/component_test.py +1 -0
  6. langfun/core/eval/__init__.py +2 -0
  7. langfun/core/eval/base.py +240 -37
  8. langfun/core/eval/base_test.py +52 -18
  9. langfun/core/eval/matching.py +26 -9
  10. langfun/core/eval/matching_test.py +3 -4
  11. langfun/core/eval/scoring.py +15 -6
  12. langfun/core/eval/scoring_test.py +2 -2
  13. langfun/core/langfunc.py +0 -5
  14. langfun/core/langfunc_test.py +6 -4
  15. langfun/core/language_model.py +124 -24
  16. langfun/core/language_model_test.py +249 -26
  17. langfun/core/llms/__init__.py +24 -5
  18. langfun/core/llms/anthropic.py +263 -0
  19. langfun/core/llms/anthropic_test.py +167 -0
  20. langfun/core/llms/cache/in_memory_test.py +37 -28
  21. langfun/core/llms/fake.py +31 -22
  22. langfun/core/llms/fake_test.py +122 -11
  23. langfun/core/llms/{gemini.py → google_genai.py} +117 -15
  24. langfun/core/llms/{gemini_test.py → google_genai_test.py} +83 -15
  25. langfun/core/llms/groq.py +260 -0
  26. langfun/core/llms/groq_test.py +170 -0
  27. langfun/core/llms/llama_cpp.py +3 -1
  28. langfun/core/llms/openai.py +97 -79
  29. langfun/core/llms/openai_test.py +285 -59
  30. langfun/core/modalities/video.py +5 -2
  31. langfun/core/structured/__init__.py +3 -0
  32. langfun/core/structured/completion_test.py +2 -2
  33. langfun/core/structured/function_generation.py +245 -0
  34. langfun/core/structured/function_generation_test.py +329 -0
  35. langfun/core/structured/mapping.py +59 -3
  36. langfun/core/structured/mapping_test.py +17 -0
  37. langfun/core/structured/parsing.py +2 -1
  38. langfun/core/structured/parsing_test.py +18 -13
  39. langfun/core/structured/prompting.py +27 -6
  40. langfun/core/structured/prompting_test.py +79 -12
  41. langfun/core/structured/schema.py +25 -22
  42. langfun/core/structured/schema_generation.py +2 -3
  43. langfun/core/structured/schema_generation_test.py +2 -2
  44. langfun/core/structured/schema_test.py +42 -27
  45. langfun/core/template.py +125 -10
  46. langfun/core/template_test.py +75 -0
  47. langfun/core/templates/selfplay_test.py +6 -2
  48. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/METADATA +3 -2
  49. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/RECORD +52 -46
  50. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/LICENSE +0 -0
  51. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/WHEEL +0 -0
  52. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/top_level.txt +0 -0
langfun/__init__.py CHANGED
@@ -34,6 +34,7 @@ score = structured.score
  generate_class = structured.generate_class

  source_form = structured.source_form
+ function_gen = structured.function_gen

  from langfun.core import eval # pylint: disable=redefined-builtin
  from langfun.core import templates
@@ -54,6 +55,7 @@ Video = modalities.Video
  PDF = modalities.PDF

  # Error types.
+ MappingError = structured.MappingError
  SchemaError = structured.SchemaError
  JsonError = structured.JsonError
  CodeError = coding.CodeError
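With MappingError now re-exported at the package root, callers of the structured APIs can catch schema-mapping failures without importing langfun.core.structured. A minimal sketch (the prompt and target schema are illustrative, and language model selection is omitted):

import langfun as lf

try:
  answer = lf.query('How many days are in a leap year?', int)
except lf.MappingError as e:
  # Per this release, the raw LM response that failed to map is attached
  # to the error (see `error.lm_response` in the eval changes below).
  print('mapping failed:', e)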
langfun/core/__init__.py CHANGED
@@ -99,6 +99,7 @@ from langfun.core.modality import ModalityRef
  from langfun.core.language_model import LanguageModel
  from langfun.core.language_model import LMSample
  from langfun.core.language_model import LMSamplingOptions
+ from langfun.core.language_model import LMSamplingUsage
  from langfun.core.language_model import LMSamplingResult
  from langfun.core.language_model import LMScoringResult
  from langfun.core.language_model import LMCache
langfun/core/coding/python/correction.py CHANGED
@@ -12,7 +12,6 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  """Python code error correction."""
- import re
  from typing import Any
  import langfun.core as lf
  from langfun.core.coding.python import errors
@@ -31,11 +30,6 @@ class CorrectedCode(pg.Object):
  corrected_code: str


- def remove_docstrings(code):
- pattern = re.compile(r"(def .+?:\s*?)('''|\"\"\")((.|\s)*?)(\2)", re.DOTALL)
- return pattern.sub(r"\1", code)
-
-
  def run_with_correction(
  code: str,
  error: str | None = None,
@@ -86,7 +80,6 @@ def run_with_correction(
  # pytype: enable=import-error
  # pylint: enable=g-import-not-at-top

- code = remove_docstrings(code)
  if max_attempts == 0:
  result = execution.run(
  code,
langfun/core/component.py CHANGED
@@ -210,6 +210,12 @@ def get_contextual_override(var_name: str) -> ContextualOverride | None:
  return _get_scoped_value(_global_tls, _CONTEXT_OVERRIDES, var_name)


+ def all_contextual_values() -> dict[str, Any]:
+ """Returns all contextual values provided from `lf.context` in scope."""
+ overrides = getattr(_global_tls, _CONTEXT_OVERRIDES, {})
+ return {k: v.value for k, v in overrides.items()}
+
+
  @contextlib.contextmanager
  def _contextual_scope(
  tls: threading.local, tls_key, **variables
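The new all_contextual_values helper complements get_contextual_override by returning every value injected through lf.context in the current scope, as exercised by the test change below. A minimal sketch, assuming the helper is reachable from the same `lf` namespace the test uses (variable names are illustrative):

import langfun.core as lf

with lf.context(x=1, y=2):
  # Collects every contextual override currently in scope into a plain dict.
  assert lf.all_contextual_values() == dict(x=1, y=2)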
langfun/core/component_test.py CHANGED
@@ -84,6 +84,7 @@ class ComponentContextTest(unittest.TestCase):
  lf.get_contextual_override('y'),
  lf.ContextualOverride(3, cascade=False, override_attrs=False),
  )
+ self.assertEqual(lf.all_contextual_values(), dict(x=3, y=3, z=3))

  # Member attributes take precedence over `lf.context`.
  self.assertEqual(a1.x, 1)
langfun/core/eval/__init__.py CHANGED
@@ -16,6 +16,8 @@
  # pylint: disable=g-importing-member
  # pylint: disable=g-bad-import-order

+ from langfun.core.eval.base import app_run
+
  from langfun.core.eval.base import Evaluable
  from langfun.core.eval.base import Evaluation
  from langfun.core.eval.base import Suite
langfun/core/eval/base.py CHANGED
@@ -26,7 +26,10 @@ import threading
  import time
  from typing import Annotated, Any, Callable, Iterator, Literal, Optional, Sequence, Type, Union

+ from absl import app
+ from absl import flags
  import langfun.core as lf
+ import langfun.core.coding as lf_coding
  from langfun.core.llms.cache import in_memory
  import langfun.core.structured as lf_structured
  import pyglove as pg
@@ -41,14 +44,6 @@ class Evaluable(lf.Component):
  INDEX_HTML = 'index.html'
  SUMMARY_HTML = 'summary.html'

- id: Annotated[
- str,
- (
- 'The ID of the evaluation, which should be unique across all '
- 'evaluations.'
- ),
- ]
-
  root_dir: Annotated[
  str | None,
  (
@@ -61,6 +56,18 @@
  int, 'Number of decimals when reporting precision.'
  ] = lf.contextual(default=1)

+ @property
+ @abc.abstractmethod
+ def id(self) -> str:
+ """Returns the ID of the task.
+
+ Returns:
+ Evaluation task ID. Different evaluation task should have their unique
+ task IDs, for each task will be stored in sub-directoreis identified by
+ their IDs. For suites, the ID could be an empty string as they will not
+ produce sub-directories
+ """
+
  @property
  def dir(self) -> str | None:
  """Returns the directory for saving results and details."""
@@ -533,7 +540,7 @@
  f'<div style="color: {text_color}; white-space: pre-wrap;'
  'padding: 10px; border: 1px solid; margin-top: 10px">'
  )
- s.write(m.text)
+ s.write(m.get('formatted_text', m.text))
  if m.result is not None:
  s.write(
  '<div style="color: magenta; white-space: pre-wrap;'
@@ -541,6 +548,16 @@
  )
  s.write(pg.format(m.result))
  s.write('</div>')
+ if 'usage' in m.metadata:
+ s.write(
+ '<div style="background-color: #EEEEEE; color: black; '
+ 'white-space: pre-wrap; padding: 10px; border: 0px solid; '
+ 'margin: 10px">'
+ f'prompt: {m.usage.prompt_tokens} tokens, '
+ f'response: {m.usage.completion_tokens} tokens, '
+ f'total: {m.usage.total_tokens} tokens'
+ '</div>'
+ )
  s.write('</div>')

  @classmethod
@@ -578,12 +595,15 @@ class _LeafNode:
  progress_bar: int | None = None


- @pg.use_init_args(['id', 'children'])
+ @pg.use_init_args(['children'])
  class Suite(Evaluable):
  """Evaluation suite."""

  children: Annotated[list[Evaluable], 'Child evaluation sets or suites.']

+ # Use empty ID as suite is just a container of child evaluations.
+ id: str = ''
+
  __kwargs__: Annotated[
  Any,
  (
@@ -802,17 +822,36 @@ class Evaluation(Evaluable):
  return 0.0
  return self.num_failures / self.num_completed

+ @property
+ def has_usage(self) -> bool:
+ """Returns True if token usage is enabled."""
+ return self._num_usages > 0
+
+ @property
+ def average_prompt_tokens(self) -> int:
+ """Returns the average prompt tokens."""
+ if not self.has_usage:
+ return 0
+ return self._total_prompt_tokens // self._num_usages
+
+ @property
+ def average_completion_tokens(self) -> int:
+ """Returns the average completion tokens."""
+ if not self.has_usage:
+ return 0
+ return self._total_completion_tokens // self._num_usages
+
+ @property
+ def average_total_tokens(self) -> int:
+ """Returns the average total tokens."""
+ return self.average_prompt_tokens + self.average_completion_tokens
+
  @functools.cached_property
  def schema(self) -> lf_structured.Schema | None:
  """Schema."""
  if self.schema_fn is None:
  return None

- kwargs = {}
- # Allow schema to be a function based on current evaluation.
- if 'evaluation' in self.schema_fn.__signature__.arg_names:
- kwargs['evaluation'] = self
-
  schema = self._call_schema_fn()
  fewshot_examples = None
  if isinstance(schema, tuple):
@@ -841,8 +880,10 @@
  kwargs['evaluation'] = self
  return self.schema_fn(**kwargs)

- def _formalize_schema(self, annotation) -> lf_structured.Schema:
+ def _formalize_schema(self, annotation) -> lf_structured.Schema | None:
  """Formalizes schema from annotation."""
+ if annotation in (str, None):
+ return None
  if self.method == 'complete':
  if not hasattr(annotation, '__schema__'):
  raise TypeError(
@@ -851,7 +892,11 @@
  'Encountered: {annotation!r}.'
  )
  self._maybe_adjust_schema_for_completion(annotation)
- return lf_structured.Schema.from_value(annotation)
+ schema = lf_structured.Schema.from_value(annotation)
+ # NOTE(daiyip): add references to the dependent classes of the returned type
+ # to prevent unused subclasses get garbage collected by Python.
+ setattr(schema, '__dependencies__', schema.class_dependencies())
+ return schema

  def _maybe_adjust_schema_for_completion(self, cls):
  if (self.completion_prompt_field is None
@@ -883,6 +928,14 @@
  completion_examples.append(ex)
  return completion_examples

+ @property
+ def id(self) -> str:
+ """Returns the ID of this evaluation."""
+ id_prefix = self.__class__.__name__
+ if not self.is_deterministic:
+ return id_prefix
+ return f'{id_prefix}@{self.hash}'
+
  @functools.cached_property
  def children(self) -> list['Evaluation']:
  """Returns the trials as child evaluations if this evaluation is a space."""
@@ -892,7 +945,6 @@
  for i, child in enumerate(pg.iter(self)):
  child.sym_setparent(self)
  child.sym_setpath(self.sym_path + f'children[{i}]')
- child.rebind(id=f'{self.id}@{child.hash}', skip_notification=True)
  children.append(child)
  return children

@@ -921,6 +973,10 @@
  self._failures = []
  self._num_completed = 0

+ self._total_prompt_tokens = 0
+ self._total_completion_tokens = 0
+ self._num_usages = 0
+
  @property
  def failures_link(self) -> str | None:
  """Returns the link to the failures page."""
@@ -940,7 +996,7 @@
  example = example or self.examples[0]

  # We make a copy to avoid pollute the state of current object.
- copy = self.clone()
+ copy: Evaluation = self.clone()
  copy.__dict__['examples'] = [example]

  # We set the symbolic parent of the cloned to access contextual information
@@ -970,9 +1026,9 @@
  color='blue',
  )

- # Audit the result.
- copy.audit(example, output, output_message)
+ copy.audit(example, output_message, None, dryrun=True)
  result = copy.summarize()
+
  if verbose:
  lf.console.write('')
  lf.console.write(
@@ -1004,7 +1060,11 @@
  self._reset()

  def _process(example: Any):
- return self.process(example, **(self.additional_args or {}))
+ # NOTE(daiyip): set the `input` symbol of the globals to None, so LLM
+ # generated code with calls to `input` will raise an error, thus not
+ # blocking the evaluation.
+ with lf_coding.context(input=None):
+ return self.process(example, **(self.additional_args or {}))

  try:
  for example, message, error in lf.concurrent_map(
@@ -1015,11 +1075,12 @@
  status_fn=self._status,
  ):
  if error is not None:
- self._failures.append((example, str(error)))
- else:
- output = message.text if self.schema is None else message.result
- self.audit(example, output, message)
- self._num_completed += 1
+ message = (
+ error.lm_response
+ if isinstance(error, lf_structured.MappingError)
+ else None
+ )
+ self.audit(example, message, error)
  finally:
  # Save cache upon completion or interruption.
  if self.dir and self.cache:
@@ -1122,6 +1183,19 @@
  )
  else:
  cache_stats = dict(use_cache=False)
+
+ if self.has_usage:
+ usage = pg.Dict(
+ total_prompt_tokens=self._total_prompt_tokens,
+ total_completion_tokens=self._total_completion_tokens,
+ num_usages=self._num_usages,
+ average_prompt_tokens=self.average_prompt_tokens,
+ average_completion_tokens=self.average_completion_tokens,
+ average_total_tokens=self.average_total_tokens,
+ )
+ else:
+ usage = None
+
  result = pg.Dict(
  experiment_setup=pg.Dict(
  id=self.id,
@@ -1137,6 +1211,7 @@
  failures=self.num_failures,
  failure_rate=self.failure_rate,
  ),
+ usage=usage,
  )
  return result

@@ -1158,9 +1233,28 @@
  '</td></tr><tr><td>'
  )
  self._render_metric(s)
+
+ # Summarize average usage.
+ if self.result.usage is not None:
+ self._render_usage(s)
+
  s.write('</td></tr></table></div>')
  return s.getvalue()

+ def _render_usage(self, s: io.StringIO) -> None:
+ """Renders usage in HTML."""
+ usage = self.result.usage
+ total = usage.total_prompt_tokens + usage.total_completion_tokens
+ s.write(
+ '&nbsp;<a title="'
+ f'# of usages: {usage.num_usages}&#013;'
+ f'total prompt: {usage.total_prompt_tokens}&#013;'
+ f'total response: {usage.total_completion_tokens}&#013;'
+ f'avg prompt: {usage.average_prompt_tokens}&#013;'
+ f'avg response: {usage.average_completion_tokens}'
+ f'" style="color:gray">({total} tokens)</a>'
+ )
+
  def _render_metric(self, s: io.StringIO) -> None:
  """Renders metrics in HTML."""
  assert self.result is not None
@@ -1175,17 +1269,48 @@
  )
  )

- def audit(self, example: Any, output: Any, message: lf.Message) -> None:
+ def audit(
+ self,
+ example: Any,
+ message: lf.Message | None,
+ error: Exception | None = None,
+ dryrun: bool = False,
+ ) -> None:
  """Audits the example against the output. Subclasses should override.

  Args:
  example: The input object.
- output: The output from LM. For `lf.call`, if `schema_fn` is not provided,
- it will be the raw LM response string. Otherwise it will be the
- structured output from the LM.
  message: The entire message returned by the LM, which could be used to
- trace the LM input, response and parsed structure.
+ trace the LM input, response and parsed structure. If error is raised
+ before LLM could return a response, None will be its value.
+ error: The exception during processing the example.
+ dryrun: Whether or not audition takes place during dryrun.
  """
+ if error is not None:
+ self._failures.append((example, str(error)))
+ if isinstance(error, lf_structured.MappingError):
+ message = error.lm_response
+ else:
+ assert message is not None
+ output = message.text if self.schema is None else message.result
+ self.audit_processed(example, output, message, dryrun=dryrun)
+
+ # Audit usage.
+ if message is not None:
+ self.audit_usage(message, dryrun=dryrun)
+ self._num_completed += 1
+
+ def audit_usage(self, message: lf.Message, dryrun: bool = False) -> None:
+ for m in message.trace():
+ if 'usage' in m.metadata:
+ self._total_prompt_tokens += m.usage.prompt_tokens
+ self._total_completion_tokens += m.usage.completion_tokens
+ self._num_usages += 1
+
+ def audit_processed(
+ self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+ ) -> None:
+ """Audits a successfully processed example. Subclass should override."""

  def save(
  self, definition: bool = True, result: bool = True, report: bool = True
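The auditing flow is now split in three: audit records failures (attaching the lm_response carried by lf_structured.MappingError when available), audit_usage tallies token counts from the message trace, and audit_processed receives successfully processed examples. Subclasses that previously overrode audit would now override audit_processed; a minimal sketch (the MyEval class is illustrative and omits the other fields an Evaluation requires):

import langfun as lf

class MyEval(lf.eval.Evaluation):

  def audit_processed(self, example, output, message, dryrun=False):
    # `output` is message.text when no schema is configured, otherwise the
    # parsed structured result; `message` carries the full LM trace.
    if dryrun:
      lf.console.write(f'example={example!r} -> output={output!r}')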
@@ -1229,8 +1354,10 @@
  '<td>Prompt</td>'
  '<td>Schema</td>'
  '<td>Additional Args</td>'
- '<td>Failures</td>'
  )
+ if self.result.usage is not None:
+ s.write('<td>Usage</td>')
+ s.write('<td>Failures</td>')

  def _render_result_row(self, s: io.StringIO) -> None:
  s.write(
@@ -1255,6 +1382,12 @@
  '<td style="color:purple" '
  f'{_html_repr(self.additional_args, compact=False)}</td>'
  )
+ # Usage.
+ if self.result.usage is not None:
+ s.write('<td>')
+ self._render_usage(s)
+ s.write('</td>')
+
  # Failures.
  s.write(
  '<td><span style="color:orange">%s</span>%s</td>'
@@ -1353,8 +1486,8 @@
  Type[lf.LanguageModel],
  tuple[lf.LanguageModel | Type[lf.LanguageModel], ...],
  ] = lf.LanguageModel,
- method: Union[str, tuple[str], None] = None,
- schema_fn: Union[pg.Functor, tuple[pg.Functor], None] = None,
+ method: Union[str, tuple[str, ...], None] = None,
+ schema_fn: Union[pg.Functor, tuple[pg.Functor, ...], None] = None,
  completed: bool | None = None,
  pivot_field: str | None = None,
  ) -> 'Summary':
@@ -1518,9 +1651,12 @@
  pivot_field = pivot_field or self.pivot_field
  s = io.StringIO()
  s.write('<html><body>')
- for task in self.tasks():
+ for task in sorted(self.tasks(), key=lambda cls: cls.__name__):
+ table_id = task.__name__.lower()
  s.write('<div>')
- s.write(f'<h2>{task.__name__}</h2>')
+ s.write(f'<a id="{table_id}"')
+ s.write(f'<h2><a href="#{table_id}">{task.__name__}</a></h2>')
+ s.write('</a>')
  table = Summary.Table.from_evaluations(
  self.select(task=task).evaluations, pivot_field
  )
@@ -1532,8 +1668,35 @@
  def _repr_html_(self) -> str:
  return self.html()

+ def json(
+ self,
+ ) -> dict[
+ str, # Task name
+ list[pg.Dict], # List of pg.Dict with `experiment` and `metrics`.
+ ]:
+ """Returns the JSON representation of the summary."""
+ task_results = {}
+ for task in sorted(self.tasks(), key=lambda cls: cls.__name__):
+ results = []
+ for entry in self.select(task=task).evaluations:
+ results.append(
+ pg.Dict(
+ id=entry.id,
+ experiment=entry,
+ dir=entry.dir,
+ metrics=entry.result.metrics if entry.result else None,
+ )
+ )
+ task_results[task.__name__] = results
+ return task_results
+
  def save(self, file: str, pivot_field: str | None = None) -> None:
  pg.save(self.html(pivot_field), file, file_format='txt')
+ if file.endswith('.html'):
+ json_file = file.replace('.html', '.json')
+ else:
+ json_file = os.path.join(file, '.json')
+ pg.save(self.json(), json_file)

  @classmethod
  def from_dirs(
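Summary.save now writes a JSON sidecar next to the HTML report, built from the per-task records returned by the new json method. A minimal sketch, assuming `summary` is an lf.eval.Summary obtained elsewhere (the path is illustrative):

summary.save('/tmp/eval/summary.html')
# Writes summary.html as before, plus summary.json mapping each task name to
# records of {id, experiment, dir, metrics}.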
@@ -1744,3 +1907,43 @@
  scan_interval=scan_interval,
  refresh_when_stop=refresh_when_stop,
  )
+
+
+ def app_run(target: Evaluable):
+ """Runs the target evaluation as an absl app.
+
+ Args:
+ target: An Langfun evaluable object.
+ """
+ flags.DEFINE_string(
+ 'root_dir', None, 'Root directory for running the evaluation.'
+ )
+
+ flags.DEFINE_bool(
+ 'dryrun', False, 'If True, dryrun the experiment instead of running it.'
+ )
+
+ flags.DEFINE_bool(
+ 'debug', False, 'If True, output prompt and response to the console.'
+ )
+
+ flags.DEFINE_bool(
+ 'rerun',
+ False,
+ 'If True, rerun the experiment even a cached result is found.',
+ )
+
+ FLAGS = flags.FLAGS # pylint: disable=invalid-name
+
+ def _main(argv):
+ if len(argv) > 1:
+ raise app.UsageError('Too many command-line arguments.')
+
+ if FLAGS.root_dir:
+ target.rebind(root_dir=FLAGS.root_dir, raise_on_no_change=False)
+ if FLAGS.dryrun:
+ target.dryrun(debug=FLAGS.debug)
+ else:
+ target.run(debug=FLAGS.debug, rerun=FLAGS.rerun)
+
+ app.run(_main)
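app_run turns any Evaluable into an absl launcher with --root_dir, --dryrun, --debug and --rerun flags. A minimal sketch of a launcher script (the my_project module and evaluation object are hypothetical):

import langfun as lf
from my_project import my_evaluation  # hypothetical module providing an lf.eval.Evaluable

if __name__ == '__main__':
  # Parses the flags above, then dryruns or runs the target accordingly.
  lf.eval.app_run(my_evaluation)

Invoked as, for example: python launcher.py --root_dir=/tmp/exp --dryrun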