langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240429__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
Files changed (49)
  1. langfun/__init__.py +2 -0
  2. langfun/core/__init__.py +1 -0
  3. langfun/core/coding/python/correction.py +0 -7
  4. langfun/core/component.py +6 -0
  5. langfun/core/component_test.py +1 -0
  6. langfun/core/eval/__init__.py +2 -0
  7. langfun/core/eval/base.py +202 -23
  8. langfun/core/eval/base_test.py +49 -10
  9. langfun/core/eval/matching.py +26 -9
  10. langfun/core/eval/matching_test.py +2 -1
  11. langfun/core/eval/scoring.py +15 -6
  12. langfun/core/eval/scoring_test.py +2 -1
  13. langfun/core/langfunc.py +0 -5
  14. langfun/core/langfunc_test.py +6 -4
  15. langfun/core/language_model.py +124 -24
  16. langfun/core/language_model_test.py +249 -26
  17. langfun/core/llms/__init__.py +19 -2
  18. langfun/core/llms/anthropic.py +263 -0
  19. langfun/core/llms/anthropic_test.py +167 -0
  20. langfun/core/llms/cache/in_memory_test.py +37 -28
  21. langfun/core/llms/fake.py +31 -22
  22. langfun/core/llms/fake_test.py +122 -11
  23. langfun/core/llms/google_genai_test.py +8 -3
  24. langfun/core/llms/groq.py +260 -0
  25. langfun/core/llms/groq_test.py +170 -0
  26. langfun/core/llms/llama_cpp.py +3 -1
  27. langfun/core/llms/openai.py +97 -79
  28. langfun/core/llms/openai_test.py +285 -59
  29. langfun/core/modalities/video.py +5 -2
  30. langfun/core/structured/__init__.py +3 -0
  31. langfun/core/structured/completion_test.py +2 -2
  32. langfun/core/structured/function_generation.py +245 -0
  33. langfun/core/structured/function_generation_test.py +329 -0
  34. langfun/core/structured/mapping.py +56 -2
  35. langfun/core/structured/mapping_test.py +17 -0
  36. langfun/core/structured/parsing_test.py +18 -13
  37. langfun/core/structured/prompting.py +27 -6
  38. langfun/core/structured/prompting_test.py +79 -12
  39. langfun/core/structured/schema.py +4 -2
  40. langfun/core/structured/schema_generation_test.py +2 -2
  41. langfun/core/structured/schema_test.py +4 -6
  42. langfun/core/template.py +125 -10
  43. langfun/core/template_test.py +75 -0
  44. langfun/core/templates/selfplay_test.py +6 -2
  45. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/METADATA +3 -2
  46. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/RECORD +49 -43
  47. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/LICENSE +0 -0
  48. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/WHEEL +0 -0
  49. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/top_level.txt +0 -0
langfun/__init__.py CHANGED
@@ -34,6 +34,7 @@ score = structured.score
34
34
  generate_class = structured.generate_class
35
35
 
36
36
  source_form = structured.source_form
37
+ function_gen = structured.function_gen
37
38
 
38
39
  from langfun.core import eval # pylint: disable=redefined-builtin
39
40
  from langfun.core import templates
@@ -54,6 +55,7 @@ Video = modalities.Video
54
55
  PDF = modalities.PDF
55
56
 
56
57
  # Error types.
58
+ MappingError = structured.MappingError
57
59
  SchemaError = structured.SchemaError
58
60
  JsonError = structured.JsonError
59
61
  CodeError = coding.CodeError
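The two new top-level aliases above (`function_gen` and `MappingError`) make the function-generation helper and the structured-mapping error reachable directly from `langfun`. A minimal sketch of catching the newly exposed error; the prompt and fake response are illustrative, not taken from the package:

import langfun as lf
from langfun.core.llms import fake

# Hypothetical example: a fake LM whose reply cannot be parsed into `int`,
# so the structured query raises the newly exposed lf.MappingError.
lm = fake.StaticResponse('not a number')
try:
  lf.query('Give me a number between 1 and 10.', int, lm=lm)
except lf.MappingError as e:
  print(e.lm_response.text)  # The raw LM response attached to the error.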
langfun/core/__init__.py CHANGED
@@ -99,6 +99,7 @@ from langfun.core.modality import ModalityRef
99
99
  from langfun.core.language_model import LanguageModel
100
100
  from langfun.core.language_model import LMSample
101
101
  from langfun.core.language_model import LMSamplingOptions
102
+ from langfun.core.language_model import LMSamplingUsage
102
103
  from langfun.core.language_model import LMSamplingResult
103
104
  from langfun.core.language_model import LMScoringResult
104
105
  from langfun.core.language_model import LMCache
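`LMSamplingUsage` is now part of the public `langfun.core` surface. A quick sketch, with field names taken from the usage metadata read elsewhere in this diff:

import langfun.core as lf

# Construct a usage record the way the updated tests do (prompt, completion,
# total token counts).
usage = lf.LMSamplingUsage(
    prompt_tokens=387, completion_tokens=24, total_tokens=411)
assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens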
langfun/core/coding/python/correction.py CHANGED
@@ -12,7 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  """Python code error correction."""
15
- import re
16
15
  from typing import Any
17
16
  import langfun.core as lf
18
17
  from langfun.core.coding.python import errors
@@ -31,11 +30,6 @@ class CorrectedCode(pg.Object):
31
30
  corrected_code: str
32
31
 
33
32
 
34
- def remove_docstrings(code):
35
- pattern = re.compile(r"(def .+?:\s*?)('''|\"\"\")((.|\s)*?)(\2)", re.DOTALL)
36
- return pattern.sub(r"\1", code)
37
-
38
-
39
33
  def run_with_correction(
40
34
  code: str,
41
35
  error: str | None = None,
@@ -86,7 +80,6 @@ def run_with_correction(
86
80
  # pytype: enable=import-error
87
81
  # pylint: enable=g-import-not-at-top
88
82
 
89
- code = remove_docstrings(code)
90
83
  if max_attempts == 0:
91
84
  result = execution.run(
92
85
  code,
langfun/core/component.py CHANGED
@@ -210,6 +210,12 @@ def get_contextual_override(var_name: str) -> ContextualOverride | None:
210
210
  return _get_scoped_value(_global_tls, _CONTEXT_OVERRIDES, var_name)
211
211
 
212
212
 
213
+ def all_contextual_values() -> dict[str, Any]:
214
+ """Returns all contextual values provided from `lf.context` in scope."""
215
+ overrides = getattr(_global_tls, _CONTEXT_OVERRIDES, {})
216
+ return {k: v.value for k, v in overrides.items()}
217
+
218
+
213
219
  @contextlib.contextmanager
214
220
  def _contextual_scope(
215
221
  tls: threading.local, tls_key, **variables
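A minimal sketch of the new `all_contextual_values` helper, mirroring the test hunk below; the module-level import path is an assumption based on where the function is defined:

from langfun.core import component

with component.context(x=3, y=3, z=3):
  # Returns every value provided through the enclosing `context` scope.
  assert component.all_contextual_values() == dict(x=3, y=3, z=3)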
langfun/core/component_test.py CHANGED
@@ -84,6 +84,7 @@ class ComponentContextTest(unittest.TestCase):
84
84
  lf.get_contextual_override('y'),
85
85
  lf.ContextualOverride(3, cascade=False, override_attrs=False),
86
86
  )
87
+ self.assertEqual(lf.all_contextual_values(), dict(x=3, y=3, z=3))
87
88
 
88
89
  # Member attributes take precedence over `lf.context`.
89
90
  self.assertEqual(a1.x, 1)
langfun/core/eval/__init__.py CHANGED
@@ -16,6 +16,8 @@
16
16
  # pylint: disable=g-importing-member
17
17
  # pylint: disable=g-bad-import-order
18
18
 
19
+ from langfun.core.eval.base import app_run
20
+
19
21
  from langfun.core.eval.base import Evaluable
20
22
  from langfun.core.eval.base import Evaluation
21
23
  from langfun.core.eval.base import Suite
langfun/core/eval/base.py CHANGED
@@ -26,6 +26,8 @@ import threading
26
26
  import time
27
27
  from typing import Annotated, Any, Callable, Iterator, Literal, Optional, Sequence, Type, Union
28
28
 
29
+ from absl import app
30
+ from absl import flags
29
31
  import langfun.core as lf
30
32
  import langfun.core.coding as lf_coding
31
33
  from langfun.core.llms.cache import in_memory
@@ -538,7 +540,7 @@ class Evaluable(lf.Component):
538
540
  f'<div style="color: {text_color}; white-space: pre-wrap;'
539
541
  'padding: 10px; border: 1px solid; margin-top: 10px">'
540
542
  )
541
- s.write(m.text)
543
+ s.write(m.get('formatted_text', m.text))
542
544
  if m.result is not None:
543
545
  s.write(
544
546
  '<div style="color: magenta; white-space: pre-wrap;'
@@ -546,6 +548,16 @@ class Evaluable(lf.Component):
546
548
  )
547
549
  s.write(pg.format(m.result))
548
550
  s.write('</div>')
551
+ if 'usage' in m.metadata:
552
+ s.write(
553
+ '<div style="background-color: #EEEEEE; color: black; '
554
+ 'white-space: pre-wrap; padding: 10px; border: 0px solid; '
555
+ 'margin: 10px">'
556
+ f'prompt: {m.usage.prompt_tokens} tokens, '
557
+ f'response: {m.usage.completion_tokens} tokens, '
558
+ f'total: {m.usage.total_tokens} tokens'
559
+ '</div>'
560
+ )
549
561
  s.write('</div>')
550
562
 
551
563
  @classmethod
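The renderer above reads the new per-message `usage` metadata. A sketch of inspecting it directly on a response message; the fake backend is used for illustration, real backends report actual token counts:

from langfun.core.llms import fake

lm = fake.StaticResponse('hi')
message = lm('hello')
if 'usage' in message.metadata:  # Populated by backends that report usage.
  usage = message.usage
  print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)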
@@ -810,17 +822,36 @@ class Evaluation(Evaluable):
810
822
  return 0.0
811
823
  return self.num_failures / self.num_completed
812
824
 
825
+ @property
826
+ def has_usage(self) -> bool:
827
+ """Returns True if token usage is enabled."""
828
+ return self._num_usages > 0
829
+
830
+ @property
831
+ def average_prompt_tokens(self) -> int:
832
+ """Returns the average prompt tokens."""
833
+ if not self.has_usage:
834
+ return 0
835
+ return self._total_prompt_tokens // self._num_usages
836
+
837
+ @property
838
+ def average_completion_tokens(self) -> int:
839
+ """Returns the average completion tokens."""
840
+ if not self.has_usage:
841
+ return 0
842
+ return self._total_completion_tokens // self._num_usages
843
+
844
+ @property
845
+ def average_total_tokens(self) -> int:
846
+ """Returns the average total tokens."""
847
+ return self.average_prompt_tokens + self.average_completion_tokens
848
+
813
849
  @functools.cached_property
814
850
  def schema(self) -> lf_structured.Schema | None:
815
851
  """Schema."""
816
852
  if self.schema_fn is None:
817
853
  return None
818
854
 
819
- kwargs = {}
820
- # Allow schema to be a function based on current evaluation.
821
- if 'evaluation' in self.schema_fn.__signature__.arg_names:
822
- kwargs['evaluation'] = self
823
-
824
855
  schema = self._call_schema_fn()
825
856
  fewshot_examples = None
826
857
  if isinstance(schema, tuple):
@@ -861,7 +892,11 @@ class Evaluation(Evaluable):
861
892
  'Encountered: {annotation!r}.'
862
893
  )
863
894
  self._maybe_adjust_schema_for_completion(annotation)
864
- return lf_structured.Schema.from_value(annotation)
895
+ schema = lf_structured.Schema.from_value(annotation)
896
+ # NOTE(daiyip): add references to the dependent classes of the returned type
897
+ # to prevent unused subclasses get garbage collected by Python.
898
+ setattr(schema, '__dependencies__', schema.class_dependencies())
899
+ return schema
865
900
 
866
901
  def _maybe_adjust_schema_for_completion(self, cls):
867
902
  if (self.completion_prompt_field is None
@@ -938,6 +973,10 @@ class Evaluation(Evaluable):
938
973
  self._failures = []
939
974
  self._num_completed = 0
940
975
 
976
+ self._total_prompt_tokens = 0
977
+ self._total_completion_tokens = 0
978
+ self._num_usages = 0
979
+
941
980
  @property
942
981
  def failures_link(self) -> str | None:
943
982
  """Returns the link to the failures page."""
@@ -957,7 +996,7 @@ class Evaluation(Evaluable):
957
996
  example = example or self.examples[0]
958
997
 
959
998
  # We make a copy to avoid pollute the state of current object.
960
- copy = self.clone()
999
+ copy: Evaluation = self.clone()
961
1000
  copy.__dict__['examples'] = [example]
962
1001
 
963
1002
  # We set the symbolic parent of the cloned to access contextual information
@@ -987,9 +1026,9 @@ class Evaluation(Evaluable):
987
1026
  color='blue',
988
1027
  )
989
1028
 
990
- # Audit the result.
991
- copy.audit(example, output, output_message)
1029
+ copy.audit(example, output_message, None, dryrun=True)
992
1030
  result = copy.summarize()
1031
+
993
1032
  if verbose:
994
1033
  lf.console.write('')
995
1034
  lf.console.write(
@@ -1036,11 +1075,12 @@ class Evaluation(Evaluable):
1036
1075
  status_fn=self._status,
1037
1076
  ):
1038
1077
  if error is not None:
1039
- self._failures.append((example, str(error)))
1040
- else:
1041
- output = message.text if self.schema is None else message.result
1042
- self.audit(example, output, message)
1043
- self._num_completed += 1
1078
+ message = (
1079
+ error.lm_response
1080
+ if isinstance(error, lf_structured.MappingError)
1081
+ else None
1082
+ )
1083
+ self.audit(example, message, error)
1044
1084
  finally:
1045
1085
  # Save cache upon completion or interruption.
1046
1086
  if self.dir and self.cache:
@@ -1143,6 +1183,19 @@ class Evaluation(Evaluable):
1143
1183
  )
1144
1184
  else:
1145
1185
  cache_stats = dict(use_cache=False)
1186
+
1187
+ if self.has_usage:
1188
+ usage = pg.Dict(
1189
+ total_prompt_tokens=self._total_prompt_tokens,
1190
+ total_completion_tokens=self._total_completion_tokens,
1191
+ num_usages=self._num_usages,
1192
+ average_prompt_tokens=self.average_prompt_tokens,
1193
+ average_completion_tokens=self.average_completion_tokens,
1194
+ average_total_tokens=self.average_total_tokens,
1195
+ )
1196
+ else:
1197
+ usage = None
1198
+
1146
1199
  result = pg.Dict(
1147
1200
  experiment_setup=pg.Dict(
1148
1201
  id=self.id,
@@ -1158,6 +1211,7 @@ class Evaluation(Evaluable):
1158
1211
  failures=self.num_failures,
1159
1212
  failure_rate=self.failure_rate,
1160
1213
  ),
1214
+ usage=usage,
1161
1215
  )
1162
1216
  return result
1163
1217
 
@@ -1179,9 +1233,28 @@ class Evaluation(Evaluable):
1179
1233
  '</td></tr><tr><td>'
1180
1234
  )
1181
1235
  self._render_metric(s)
1236
+
1237
+ # Summarize average usage.
1238
+ if self.result.usage is not None:
1239
+ self._render_usage(s)
1240
+
1182
1241
  s.write('</td></tr></table></div>')
1183
1242
  return s.getvalue()
1184
1243
 
1244
+ def _render_usage(self, s: io.StringIO) -> None:
1245
+ """Renders usage in HTML."""
1246
+ usage = self.result.usage
1247
+ total = usage.total_prompt_tokens + usage.total_completion_tokens
1248
+ s.write(
1249
+ '&nbsp;<a title="'
1250
+ f'# of usages: {usage.num_usages}&#013;'
1251
+ f'total prompt: {usage.total_prompt_tokens}&#013;'
1252
+ f'total response: {usage.total_completion_tokens}&#013;'
1253
+ f'avg prompt: {usage.average_prompt_tokens}&#013;'
1254
+ f'avg response: {usage.average_completion_tokens}'
1255
+ f'" style="color:gray">({total} tokens)</a>'
1256
+ )
1257
+
1185
1258
  def _render_metric(self, s: io.StringIO) -> None:
1186
1259
  """Renders metrics in HTML."""
1187
1260
  assert self.result is not None
@@ -1196,17 +1269,48 @@ class Evaluation(Evaluable):
1196
1269
  )
1197
1270
  )
1198
1271
 
1199
- def audit(self, example: Any, output: Any, message: lf.Message) -> None:
1272
+ def audit(
1273
+ self,
1274
+ example: Any,
1275
+ message: lf.Message | None,
1276
+ error: Exception | None = None,
1277
+ dryrun: bool = False,
1278
+ ) -> None:
1200
1279
  """Audits the example against the output. Subclasses should override.
1201
1280
 
1202
1281
  Args:
1203
1282
  example: The input object.
1204
- output: The output from LM. For `lf.call`, if `schema_fn` is not provided,
1205
- it will be the raw LM response string. Otherwise it will be the
1206
- structured output from the LM.
1207
1283
  message: The entire message returned by the LM, which could be used to
1208
- trace the LM input, response and parsed structure.
1284
+ trace the LM input, response and parsed structure. If error is raised
1285
+ before LLM could return a response, None will be its value.
1286
+ error: The exception during processing the example.
1287
+ dryrun: Whether or not audition takes place during dryrun.
1209
1288
  """
1289
+ if error is not None:
1290
+ self._failures.append((example, str(error)))
1291
+ if isinstance(error, lf_structured.MappingError):
1292
+ message = error.lm_response
1293
+ else:
1294
+ assert message is not None
1295
+ output = message.text if self.schema is None else message.result
1296
+ self.audit_processed(example, output, message, dryrun=dryrun)
1297
+
1298
+ # Audit usage.
1299
+ if message is not None:
1300
+ self.audit_usage(message, dryrun=dryrun)
1301
+ self._num_completed += 1
1302
+
1303
+ def audit_usage(self, message: lf.Message, dryrun: bool = False) -> None:
1304
+ for m in message.trace():
1305
+ if 'usage' in m.metadata:
1306
+ self._total_prompt_tokens += m.usage.prompt_tokens
1307
+ self._total_completion_tokens += m.usage.completion_tokens
1308
+ self._num_usages += 1
1309
+
1310
+ def audit_processed(
1311
+ self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
1312
+ ) -> None:
1313
+ """Audits a successfully processed example. Subclass should override."""
1210
1314
 
1211
1315
  def save(
1212
1316
  self, definition: bool = True, result: bool = True, report: bool = True
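For subclasses, the practical effect of the reworked audit API is that per-example success handling moves into `audit_processed`, while `audit` itself now also receives the message, the error (if any) and a `dryrun` flag. A hedged sketch; the class and its bookkeeping are hypothetical:

import langfun.core as lf
from langfun.core.eval import base


class MyEval(base.Evaluation):  # Other required fields omitted in this sketch.

  def audit_processed(
      self, example, output, message: lf.Message, dryrun: bool = False
  ) -> None:
    del message, dryrun
    # Compare `output` against the expectation derived from `example` here.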
@@ -1250,8 +1354,10 @@ class Evaluation(Evaluable):
1250
1354
  '<td>Prompt</td>'
1251
1355
  '<td>Schema</td>'
1252
1356
  '<td>Additional Args</td>'
1253
- '<td>Failures</td>'
1254
1357
  )
1358
+ if self.result.usage is not None:
1359
+ s.write('<td>Usage</td>')
1360
+ s.write('<td>Failures</td>')
1255
1361
 
1256
1362
  def _render_result_row(self, s: io.StringIO) -> None:
1257
1363
  s.write(
@@ -1276,6 +1382,12 @@ class Evaluation(Evaluable):
1276
1382
  '<td style="color:purple" '
1277
1383
  f'{_html_repr(self.additional_args, compact=False)}</td>'
1278
1384
  )
1385
+ # Usage.
1386
+ if self.result.usage is not None:
1387
+ s.write('<td>')
1388
+ self._render_usage(s)
1389
+ s.write('</td>')
1390
+
1279
1391
  # Failures.
1280
1392
  s.write(
1281
1393
  '<td><span style="color:orange">%s</span>%s</td>'
@@ -1374,8 +1486,8 @@ class Summary(pg.Object):
1374
1486
  Type[lf.LanguageModel],
1375
1487
  tuple[lf.LanguageModel | Type[lf.LanguageModel], ...],
1376
1488
  ] = lf.LanguageModel,
1377
- method: Union[str, tuple[str], None] = None,
1378
- schema_fn: Union[pg.Functor, tuple[pg.Functor], None] = None,
1489
+ method: Union[str, tuple[str, ...], None] = None,
1490
+ schema_fn: Union[pg.Functor, tuple[pg.Functor, ...], None] = None,
1379
1491
  completed: bool | None = None,
1380
1492
  pivot_field: str | None = None,
1381
1493
  ) -> 'Summary':
@@ -1556,8 +1668,35 @@ class Summary(pg.Object):
1556
1668
  def _repr_html_(self) -> str:
1557
1669
  return self.html()
1558
1670
 
1671
+ def json(
1672
+ self,
1673
+ ) -> dict[
1674
+ str, # Task name
1675
+ list[pg.Dict], # List of pg.Dict with `experiment` and `metrics`.
1676
+ ]:
1677
+ """Returns the JSON representation of the summary."""
1678
+ task_results = {}
1679
+ for task in sorted(self.tasks(), key=lambda cls: cls.__name__):
1680
+ results = []
1681
+ for entry in self.select(task=task).evaluations:
1682
+ results.append(
1683
+ pg.Dict(
1684
+ id=entry.id,
1685
+ experiment=entry,
1686
+ dir=entry.dir,
1687
+ metrics=entry.result.metrics if entry.result else None,
1688
+ )
1689
+ )
1690
+ task_results[task.__name__] = results
1691
+ return task_results
1692
+
1559
1693
  def save(self, file: str, pivot_field: str | None = None) -> None:
1560
1694
  pg.save(self.html(pivot_field), file, file_format='txt')
1695
+ if file.endswith('.html'):
1696
+ json_file = file.replace('.html', '.json')
1697
+ else:
1698
+ json_file = os.path.join(file, '.json')
1699
+ pg.save(self.json(), json_file)
1561
1700
 
1562
1701
  @classmethod
1563
1702
  def from_dirs(
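`Summary.save` now writes a machine-readable JSON file next to the HTML report. A sketch of consuming it; the path is hypothetical and the access pattern follows the updated base_test:

import pyglove as pg

summary = pg.load('/tmp/my_evals/summary.json', force_dict=True)
for task_name, entries in summary.items():
  for entry in entries:
    print(task_name, entry.id, entry.dir, entry.metrics)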
@@ -1768,3 +1907,43 @@ def monitor_async(
1768
1907
  scan_interval=scan_interval,
1769
1908
  refresh_when_stop=refresh_when_stop,
1770
1909
  )
1910
+
1911
+
1912
+ def app_run(target: Evaluable):
1913
+ """Runs the target evaluation as an absl app.
1914
+
1915
+ Args:
1916
+ target: An Langfun evaluable object.
1917
+ """
1918
+ flags.DEFINE_string(
1919
+ 'root_dir', None, 'Root directory for running the evaluation.'
1920
+ )
1921
+
1922
+ flags.DEFINE_bool(
1923
+ 'dryrun', False, 'If True, dryrun the experiment instead of running it.'
1924
+ )
1925
+
1926
+ flags.DEFINE_bool(
1927
+ 'debug', False, 'If True, output prompt and response to the console.'
1928
+ )
1929
+
1930
+ flags.DEFINE_bool(
1931
+ 'rerun',
1932
+ False,
1933
+ 'If True, rerun the experiment even a cached result is found.',
1934
+ )
1935
+
1936
+ FLAGS = flags.FLAGS # pylint: disable=invalid-name
1937
+
1938
+ def _main(argv):
1939
+ if len(argv) > 1:
1940
+ raise app.UsageError('Too many command-line arguments.')
1941
+
1942
+ if FLAGS.root_dir:
1943
+ target.rebind(root_dir=FLAGS.root_dir, raise_on_no_change=False)
1944
+ if FLAGS.dryrun:
1945
+ target.dryrun(debug=FLAGS.debug)
1946
+ else:
1947
+ target.run(debug=FLAGS.debug, rerun=FLAGS.rerun)
1948
+
1949
+ app.run(_main)
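A minimal sketch of wiring the new `app_run` entry point into a script; `my_experiment` is a hypothetical factory, and the flags are the ones defined above (`--root_dir`, `--dryrun`, `--debug`, `--rerun`):

# run_eval.py
from langfun.core import eval as lf_eval


def my_experiment() -> lf_eval.Evaluable:
  raise NotImplementedError('Construct your Suite/Evaluation here.')


if __name__ == '__main__':
  lf_eval.app_run(my_experiment())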
langfun/core/eval/base_test.py CHANGED
@@ -101,7 +101,7 @@ class EvaluationTest(unittest.TestCase):
101
101
  self.assertEqual(s.dir, os.path.join(s.root_dir, s.id))
102
102
  self.assertEqual(s.hash, s.clone().hash)
103
103
  # Test persistent hash.
104
- self.assertEqual(s.hash, 'abc7c29a')
104
+ self.assertEqual(s.hash, 'ae86c703')
105
105
  self.assertEqual(
106
106
  s.hash, s.clone(override={'max_workers': 2, 'lm.timeout': 20}).hash
107
107
  )
@@ -194,6 +194,7 @@ class EvaluationTest(unittest.TestCase):
194
194
  cache_seed=0,
195
195
  score=1.0,
196
196
  logprobs=None,
197
+ usage=lf.LMSamplingUsage(387, 24, 411),
197
198
  tags=['lm-response', 'lm-output', 'transformed'],
198
199
  ),
199
200
  )
@@ -209,7 +210,7 @@ class EvaluationTest(unittest.TestCase):
209
210
  s.result,
210
211
  dict(
211
212
  experiment_setup=dict(
212
- id='Evaluation@17915dc6',
213
+ id='Evaluation@0fade07d',
213
214
  dir=s.dir,
214
215
  model='StaticSequence',
215
216
  prompt_template='{{example.question}}',
@@ -220,6 +221,14 @@ class EvaluationTest(unittest.TestCase):
220
221
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
221
222
  ),
222
223
  metrics=dict(total=2, failures=1, failure_rate=0.5),
224
+ usage=dict(
225
+ total_prompt_tokens=774,
226
+ total_completion_tokens=25,
227
+ num_usages=2,
228
+ average_prompt_tokens=387,
229
+ average_completion_tokens=12,
230
+ average_total_tokens=399,
231
+ ),
223
232
  ),
224
233
  )
225
234
  self.assertTrue(
@@ -228,13 +237,23 @@ class EvaluationTest(unittest.TestCase):
228
237
  os.path.exists(os.path.join(s.dir, base.Evaluation.RESULT_JSON)))
229
238
  self.assertTrue(
230
239
  os.path.exists(os.path.join(s.dir, base.Evaluation.CACHE_JSON)))
231
- self.assertTrue(
232
- os.path.exists(os.path.join(s.root_dir, base.Evaluation.SUMMARY_HTML))
233
- )
234
240
  self.assertTrue(
235
241
  os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
236
242
  self.assertTrue(
237
243
  os.path.exists(os.path.join(s.dir, base.Evaluation.FAILURES_HTML)))
244
+ self.assertTrue(
245
+ os.path.exists(os.path.join(s.root_dir, base.Evaluation.SUMMARY_HTML))
246
+ )
247
+ # Check summary JSON.
248
+ summary_json = os.path.join(
249
+ s.root_dir, base.Evaluation.SUMMARY_HTML.replace('.html', '.json')
250
+ )
251
+ self.assertTrue(os.path.exists(summary_json))
252
+ summary = pg.load(summary_json, force_dict=True)
253
+ self.assertIn('Evaluation', summary)
254
+ self.assertEqual(len(summary['Evaluation']), 1)
255
+ self.assertIsNotNone(summary['Evaluation'][0].experiment)
256
+ self.assertIsNotNone(summary['Evaluation'][0].metrics)
238
257
 
239
258
  def test_run_wihtout_save(self):
240
259
  lm = fake.StaticSequence([
@@ -274,8 +293,11 @@ class EvaluationTest(unittest.TestCase):
274
293
  s = eval_set(
275
294
  'run_filter_test', pg.oneof(['call', 'query']),
276
295
  schema_fn=answer_schema(), lm=lm)
296
+ result = s.run(
297
+ filter=lambda x: x.method == 'query', dryrun=True, summary=False
298
+ )
277
299
  self.assertEqual(
278
- s.run(filter=lambda x: x.method == 'query', dryrun=True, summary=False),
300
+ result,
279
301
  {
280
302
  s.children[0].id: None,
281
303
  s.children[1].id: dict(
@@ -291,7 +313,8 @@ class EvaluationTest(unittest.TestCase):
291
313
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
292
314
  ),
293
315
  metrics=dict(total=2, failures=0, failure_rate=0.0),
294
- )
316
+ usage=s.children[1].result.usage,
317
+ ),
295
318
  },
296
319
  )
297
320
 
@@ -321,11 +344,10 @@ class EvaluationTest(unittest.TestCase):
321
344
  s.children[0].dir, os.path.join(s.root_dir, s.children[0].id)
322
345
  )
323
346
  # Test persistent hash.
324
- self.assertEqual(s.hash, 'ca7f722b')
347
+ self.assertEqual(s.hash, 'b66a4e88')
325
348
 
326
349
  summary = s.run(verbose=True)
327
350
  self.assertEqual(len(summary.evaluations), 2)
328
-
329
351
  self.assertEqual(
330
352
  s.result,
331
353
  {
@@ -342,6 +364,7 @@ class EvaluationTest(unittest.TestCase):
342
364
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
343
365
  ),
344
366
  metrics=dict(total=2, failures=1, failure_rate=0.5),
367
+ usage=s.children[0].result.usage,
345
368
  ),
346
369
  s.children[1].id: dict(
347
370
  experiment_setup=dict(
@@ -356,6 +379,7 @@ class EvaluationTest(unittest.TestCase):
356
379
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
357
380
  ),
358
381
  metrics=dict(total=2, failures=1, failure_rate=0.5),
382
+ usage=s.children[1].result.usage,
359
383
  ),
360
384
  },
361
385
  )
@@ -448,7 +472,7 @@ class SuiteTest(unittest.TestCase):
448
472
  lm=lm
449
473
  )
450
474
  # Test for persistent hash.
451
- self.assertEqual(s.hash, '7285e52b')
475
+ self.assertEqual(s.hash, '26e6cc25')
452
476
  s.run()
453
477
  expected = {
454
478
  s.children[0].id: dict(
@@ -464,6 +488,7 @@ class SuiteTest(unittest.TestCase):
464
488
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
465
489
  ),
466
490
  metrics=dict(total=2, failures=1, failure_rate=0.5),
491
+ usage=s.children[0].result.usage,
467
492
  ),
468
493
  s.children[1].id: {
469
494
  s.children[1]
@@ -481,6 +506,7 @@ class SuiteTest(unittest.TestCase):
481
506
  use_cache=True, num_queries=4, num_hits=1, num_updates=3
482
507
  ),
483
508
  metrics=dict(total=2, failures=2, failure_rate=1.0),
509
+ usage=s.children[1].children[0].result.usage,
484
510
  ),
485
511
  s.children[1]
486
512
  .children[2]
@@ -500,6 +526,7 @@ class SuiteTest(unittest.TestCase):
500
526
  num_updates=2,
501
527
  ),
502
528
  metrics=dict(total=2, failures=1, failure_rate=0.5),
529
+ usage=s.children[1].children[2].result.usage,
503
530
  ),
504
531
  },
505
532
  }
@@ -671,5 +698,17 @@ class SummaryTest(unittest.TestCase):
671
698
  self.assertTrue(pg.io.path_exists(summary_file))
672
699
 
673
700
 
701
+ class AppRunTest(unittest.TestCase):
702
+
703
+ def test_app_run(self):
704
+ lm = fake.StaticSequence(['two', 'Solution(final_answer=2)'])
705
+ try:
706
+ base.app_run(
707
+ eval_set('app_run_test', 'query', schema_fn=answer_schema(), lm=lm)
708
+ )
709
+ except SystemExit:
710
+ pass
711
+
712
+
674
713
  if __name__ == '__main__':
675
714
  unittest.main()
langfun/core/eval/matching.py CHANGED
@@ -86,9 +86,26 @@ class Matching(base.Evaluation):
86
86
  self._matches = []
87
87
  self._mismatches = []
88
88
 
89
- def audit(self, example: Any, output: Any, message: lf.Message) -> None:
89
+ def audit_processed(
90
+ self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
91
+ ) -> None:
90
92
  groundtruth = self.groundtruth(example)
91
93
  answer = self.answer(output, example)
94
+
95
+ if dryrun:
96
+ lf.console.write('')
97
+ lf.console.write(
98
+ str(groundtruth),
99
+ title='GROUDTRUTH',
100
+ color='green',
101
+ )
102
+ lf.console.write('')
103
+ lf.console.write(
104
+ str(answer),
105
+ title='ANSWER',
106
+ color='blue',
107
+ )
108
+
92
109
  if self.match(answer, groundtruth):
93
110
  self._matches.append((example, output, message))
94
111
  else:
@@ -155,19 +172,16 @@ class Matching(base.Evaluation):
155
172
  super().save(definition, result, report)
156
173
 
157
174
  if result:
158
-
159
- def force_dict(v):
160
- return pg.object_utils.json_conversion.strip_types(pg.to_json(v))
161
-
162
175
  # Save matches.
163
176
  pg.save(
164
177
  [
165
- # We force the output to be dict as its type may be defined
166
- # within functors which could be deserialized.
167
- pg.Dict(input=input, output=force_dict(output))
178
+ pg.Dict(input=input, output=output)
168
179
  for input, output, _ in self.matches
169
180
  ],
170
181
  os.path.join(self.dir, Matching.MATCHES_JSON),
182
+ # We force the input and output to be dict so it does not depend on
183
+ # the downstream to serialize.
184
+ force_dict=True,
171
185
  )
172
186
 
173
187
  # Save mismatches.
@@ -175,10 +189,13 @@ class Matching(base.Evaluation):
175
189
  [
176
190
  # We force the output to be dict as its type may be defined
177
191
  # within functors which could be deserialized.
178
- pg.Dict(input=input, output=force_dict(output))
192
+ pg.Dict(input=input, output=output)
179
193
  for input, output, _ in self.mismatches
180
194
  ],
181
195
  os.path.join(self.dir, Matching.MISMATCHES_JSON),
196
+ # We force the input and output to be dict so it does not depend on
197
+ # the downstream to serialize.
198
+ force_dict=True,
182
199
  )
183
200
 
184
201
  if report: