langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240511__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (59)
  1. langfun/__init__.py +7 -0
  2. langfun/core/__init__.py +1 -0
  3. langfun/core/coding/python/correction.py +0 -7
  4. langfun/core/component.py +6 -0
  5. langfun/core/component_test.py +1 -0
  6. langfun/core/eval/__init__.py +15 -0
  7. langfun/core/eval/base.py +665 -95
  8. langfun/core/eval/base_test.py +224 -53
  9. langfun/core/eval/matching.py +48 -30
  10. langfun/core/eval/matching_test.py +25 -3
  11. langfun/core/eval/patching.py +130 -0
  12. langfun/core/eval/patching_test.py +170 -0
  13. langfun/core/eval/scoring.py +19 -10
  14. langfun/core/eval/scoring_test.py +21 -3
  15. langfun/core/langfunc.py +1 -22
  16. langfun/core/langfunc_test.py +10 -4
  17. langfun/core/language_model.py +130 -24
  18. langfun/core/language_model_test.py +249 -26
  19. langfun/core/llms/__init__.py +27 -2
  20. langfun/core/llms/anthropic.py +263 -0
  21. langfun/core/llms/anthropic_test.py +167 -0
  22. langfun/core/llms/cache/in_memory_test.py +37 -28
  23. langfun/core/llms/fake.py +34 -25
  24. langfun/core/llms/fake_test.py +122 -11
  25. langfun/core/llms/google_genai.py +8 -0
  26. langfun/core/llms/google_genai_test.py +8 -3
  27. langfun/core/llms/groq.py +260 -0
  28. langfun/core/llms/groq_test.py +170 -0
  29. langfun/core/llms/llama_cpp.py +3 -1
  30. langfun/core/llms/openai.py +100 -81
  31. langfun/core/llms/openai_test.py +287 -60
  32. langfun/core/llms/vertexai.py +291 -0
  33. langfun/core/llms/vertexai_test.py +233 -0
  34. langfun/core/modalities/image.py +1 -3
  35. langfun/core/modalities/mime.py +6 -0
  36. langfun/core/modalities/video.py +6 -5
  37. langfun/core/structured/__init__.py +5 -0
  38. langfun/core/structured/completion_test.py +2 -2
  39. langfun/core/structured/function_generation.py +245 -0
  40. langfun/core/structured/function_generation_test.py +329 -0
  41. langfun/core/structured/mapping.py +61 -3
  42. langfun/core/structured/mapping_test.py +17 -0
  43. langfun/core/structured/parsing_test.py +18 -13
  44. langfun/core/structured/prompting.py +61 -12
  45. langfun/core/structured/prompting_test.py +122 -12
  46. langfun/core/structured/schema.py +38 -6
  47. langfun/core/structured/schema_generation_test.py +2 -2
  48. langfun/core/structured/schema_test.py +36 -7
  49. langfun/core/structured/scoring.py +4 -1
  50. langfun/core/structured/scoring_test.py +6 -0
  51. langfun/core/template.py +147 -11
  52. langfun/core/template_test.py +75 -0
  53. langfun/core/templates/selfplay_test.py +6 -2
  54. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/METADATA +3 -2
  55. langfun-0.0.2.dev20240511.dist-info/RECORD +112 -0
  56. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  57. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/LICENSE +0 -0
  58. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/WHEEL +0 -0
  59. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/top_level.txt +0 -0
@@ -101,7 +101,7 @@ class EvaluationTest(unittest.TestCase):
  self.assertEqual(s.dir, os.path.join(s.root_dir, s.id))
  self.assertEqual(s.hash, s.clone().hash)
  # Test persistent hash.
- self.assertEqual(s.hash, 'abc7c29a')
+ self.assertEqual(s.hash, 'ae86c703')
  self.assertEqual(
  s.hash, s.clone(override={'max_workers': 2, 'lm.timeout': 20}).hash
  )
@@ -194,6 +194,7 @@ class EvaluationTest(unittest.TestCase):
  cache_seed=0,
  score=1.0,
  logprobs=None,
+ usage=lf.LMSamplingUsage(387, 24, 411),
  tags=['lm-response', 'lm-output', 'transformed'],
  ),
  )
@@ -209,7 +210,7 @@ class EvaluationTest(unittest.TestCase):
  s.result,
  dict(
  experiment_setup=dict(
- id='Evaluation@17915dc6',
+ id='Evaluation@0fade07d',
  dir=s.dir,
  model='StaticSequence',
  prompt_template='{{example.question}}',
@@ -219,7 +220,26 @@ class EvaluationTest(unittest.TestCase):
  cache_stats=dict(
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
- metrics=dict(total=2, failures=1, failure_rate=0.5),
+ metrics=dict(
+ total=2,
+ failures=1,
+ failure_rate=0.5,
+ oop_failures=1,
+ oop_failure_rate=0.5,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={
+ 'MappingError.SchemaError.TypeError': 1
+ }
+ ),
+ usage=dict(
+ total_prompt_tokens=774,
+ total_completion_tokens=25,
+ num_usages=2,
+ average_prompt_tokens=387,
+ average_completion_tokens=12,
+ average_total_tokens=399,
+ ),
  ),
  )
  self.assertTrue(
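
Note: the expected usage block above appears to be a straightforward aggregation of the per-sample usage records (e.g. the lf.LMSamplingUsage(387, 24, 411) asserted earlier). A minimal sketch, not the library's implementation, that reproduces the asserted numbers, assuming the second sample's completion is the 1 token implied by the totals:

    from dataclasses import dataclass

    @dataclass
    class Usage:
      prompt_tokens: int
      completion_tokens: int
      total_tokens: int

    def summarize(usages):
      # Aggregate per-sample usages into the shape of the expected dict above.
      n = len(usages)
      total_prompt = sum(u.prompt_tokens for u in usages)
      total_completion = sum(u.completion_tokens for u in usages)
      return dict(
          total_prompt_tokens=total_prompt,
          total_completion_tokens=total_completion,
          num_usages=n,
          average_prompt_tokens=total_prompt // n,
          average_completion_tokens=total_completion // n,
          average_total_tokens=(total_prompt + total_completion) // n,
      )

    # Two samples consistent with the asserted totals (774 prompt / 25 completion tokens).
    print(summarize([Usage(387, 24, 411), Usage(387, 1, 388)]))
    # -> total_prompt_tokens=774, total_completion_tokens=25, num_usages=2,
    #    average_prompt_tokens=387, average_completion_tokens=12,
    #    average_total_tokens=399
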
@@ -227,14 +247,32 @@ class EvaluationTest(unittest.TestCase):
  self.assertTrue(
  os.path.exists(os.path.join(s.dir, base.Evaluation.RESULT_JSON)))
  self.assertTrue(
- os.path.exists(os.path.join(s.dir, base.Evaluation.CACHE_JSON)))
+ os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_JSON)))
  self.assertTrue(
- os.path.exists(os.path.join(s.root_dir, base.Evaluation.SUMMARY_HTML))
- )
+ os.path.exists(
+ os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_JSON)))
+ self.assertTrue(
+ os.path.exists(os.path.join(s.dir, base.Evaluation.CACHE_JSON)))
  self.assertTrue(
  os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
  self.assertTrue(
- os.path.exists(os.path.join(s.dir, base.Evaluation.FAILURES_HTML)))
+ os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_HTML)))
+ self.assertTrue(
+ os.path.exists(
+ os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_HTML)))
+ self.assertTrue(
+ os.path.exists(os.path.join(s.root_dir, base.Evaluation.SUMMARY_HTML))
+ )
+ # Check summary JSON.
+ summary_json = os.path.join(
+ s.root_dir, base.Evaluation.SUMMARY_HTML.replace('.html', '.json')
+ )
+ self.assertTrue(os.path.exists(summary_json))
+ summary = pg.load(summary_json, force_dict=True)
+ self.assertIn('Evaluation', summary)
+ self.assertEqual(len(summary['Evaluation']), 1)
+ self.assertIsNotNone(summary['Evaluation'][0].experiment)
+ self.assertIsNotNone(summary['Evaluation'][0].metrics)

  def test_run_wihtout_save(self):
  lm = fake.StaticSequence([
@@ -255,7 +293,10 @@ class EvaluationTest(unittest.TestCase):
  self.assertFalse(
  os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
  self.assertFalse(
- os.path.exists(os.path.join(s.dir, base.Evaluation.FAILURES_HTML)))
+ os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_HTML)))
+ self.assertFalse(
+ os.path.exists(
+ os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_HTML)))

  def test_load(self):
  lm = fake.StaticResponse('Solution(final_answer=2)')
@@ -274,8 +315,11 @@ class EvaluationTest(unittest.TestCase):
  s = eval_set(
  'run_filter_test', pg.oneof(['call', 'query']),
  schema_fn=answer_schema(), lm=lm)
+ result = s.run(
+ filter=lambda x: x.method == 'query', dryrun=True, summary=False
+ )
  self.assertEqual(
- s.run(filter=lambda x: x.method == 'query', dryrun=True, summary=False),
+ result,
  {
  s.children[0].id: None,
  s.children[1].id: dict(
@@ -290,8 +334,18 @@ class EvaluationTest(unittest.TestCase):
  cache_stats=dict(
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
- metrics=dict(total=2, failures=0, failure_rate=0.0),
- )
+ metrics=dict(
+ total=2,
+ failures=0,
+ failure_rate=0.0,
+ oop_failures=0,
+ oop_failure_rate=0.0,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={},
+ ),
+ usage=s.children[1].result.usage,
+ ),
  },
  )

@@ -321,11 +375,10 @@ class EvaluationTest(unittest.TestCase):
  s.children[0].dir, os.path.join(s.root_dir, s.children[0].id)
  )
  # Test persistent hash.
- self.assertEqual(s.hash, 'ca7f722b')
+ self.assertEqual(s.hash, 'b66a4e88')

  summary = s.run(verbose=True)
  self.assertEqual(len(summary.evaluations), 2)
-
  self.assertEqual(
  s.result,
  {
@@ -341,7 +394,19 @@ class EvaluationTest(unittest.TestCase):
  cache_stats=dict(
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
- metrics=dict(total=2, failures=1, failure_rate=0.5),
+ metrics=dict(
+ total=2,
+ failures=1,
+ failure_rate=0.5,
+ oop_failures=1,
+ oop_failure_rate=0.5,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={
+ 'MappingError.SchemaError.TypeError': 1
+ }
+ ),
+ usage=s.children[0].result.usage,
  ),
  s.children[1].id: dict(
  experiment_setup=dict(
@@ -355,7 +420,19 @@ class EvaluationTest(unittest.TestCase):
  cache_stats=dict(
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
- metrics=dict(total=2, failures=1, failure_rate=0.5),
+ metrics=dict(
+ total=2,
+ failures=1,
+ failure_rate=0.5,
+ oop_failures=1,
+ oop_failure_rate=0.5,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={
+ 'MappingError.SchemaError.TypeError': 1
+ }
+ ),
+ usage=s.children[1].result.usage,
  ),
  },
  )
@@ -448,10 +525,10 @@ class SuiteTest(unittest.TestCase):
  lm=lm
  )
  # Test for persistent hash.
- self.assertEqual(s.hash, '7285e52b')
+ self.assertEqual(s.hash, '26e6cc25')
  s.run()
  expected = {
- s.children[0].id: dict(
+ 'Evaluation@0fade07d': dict(
  experiment_setup=dict(
  id=s.children[0].id,
  dir=s.children[0].dir,
@@ -463,45 +540,46 @@ class SuiteTest(unittest.TestCase):
  cache_stats=dict(
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
- metrics=dict(total=2, failures=1, failure_rate=0.5),
+ metrics=dict(
+ total=2,
+ failures=1,
+ failure_rate=0.5,
+ oop_failures=1,
+ oop_failure_rate=0.5,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={
+ 'MappingError.SchemaError.TypeError': 1
+ }
+ ),
+ usage=s.children[0].result.usage,
  ),
- s.children[1].id: {
- s.children[1]
- .children[0]
- .id: dict(
- experiment_setup=dict(
- id=s.children[1].children[0].id,
- dir=s.children[1].children[0].dir,
- model='StaticSequence',
- prompt_template='{{example.question}}',
- method='call',
- schema_fn='answer_schema()',
- ),
- cache_stats=dict(
- use_cache=True, num_queries=4, num_hits=1, num_updates=3
- ),
- metrics=dict(total=2, failures=2, failure_rate=1.0),
+ 'Evaluation@ae86c703': dict(
+ experiment_setup=dict(
+ id=s.children[1].children[0].id,
+ dir=s.children[1].children[0].dir,
+ model='StaticSequence',
+ prompt_template='{{example.question}}',
+ method='call',
+ schema_fn='answer_schema()',
  ),
- s.children[1]
- .children[2]
- .id: dict(
- experiment_setup=dict(
- id=s.children[1].children[2].id,
- dir=s.children[1].children[2].dir,
- model='StaticSequence',
- prompt_template='{{example.question}}',
- method='query',
- schema_fn='answer_schema()',
- ),
- cache_stats=dict(
- use_cache=True,
- num_queries=2,
- num_hits=0,
- num_updates=2,
- ),
- metrics=dict(total=2, failures=1, failure_rate=0.5),
+ cache_stats=dict(
+ use_cache=True, num_queries=4, num_hits=1, num_updates=3
  ),
- },
+ metrics=dict(
+ total=2,
+ failures=2,
+ failure_rate=1.0,
+ oop_failures=2,
+ oop_failure_rate=1.0,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={
+ 'MappingError.SchemaError.TypeError': 2
+ }
+ ),
+ usage=s.children[1].children[0].result.usage,
+ ),
  }
  self.assertEqual(s.result, expected)

@@ -671,5 +749,98 @@ class SummaryTest(unittest.TestCase):
  self.assertTrue(pg.io.path_exists(summary_file))


+ class NamedEvaluationTest(unittest.TestCase):
+
+ def test_named_eval_class(self):
+
+ @base.register('named_eval/class_test')
+ class MyEval(base.Evaluation):
+ inputs = base.as_inputs([
+ pg.Dict(question='Compute 1 + 1'),
+ ])
+ method = 'query'
+ prompt = pg.oneof([
+ lf.Template('{{example.question}}'),
+ lf.Template('Hello {{example.question}}'),
+ ])
+ schema_fn = answer_schema()
+
+ evaluation = base.get_evaluation('named_eval/class_test')
+ self.assertIsInstance(evaluation, MyEval)
+ self.assertIsNone(evaluation.dir)
+ self.assertIsNone(evaluation.root_dir)
+ self.assertIn('named_eval/class_test', base.registered_names())
+
+ with self.assertRaisesRegex(ValueError, 'Unsupported type.*'):
+ @base.register('named_eval/bad_class')
+ class Foo: # pylint: disable=unused-variable
+ pass
+
+ def test_named_eval_functor(self):
+
+ @base.register('named_eval/functor_test')
+ def my_eval():
+ return base.Evaluation(
+ inputs=base.as_inputs([
+ pg.Dict(question='Compute 1 + 1'),
+ ]),
+ method='query',
+ prompt=pg.oneof([
+ lf.Template('{{example.question}}'),
+ lf.Template('Hello {{example.question}}'),
+ ]),
+ schema_fn=answer_schema(),
+ )
+
+ self.assertTrue(issubclass(my_eval, base.Evaluable))
+ evaluation = base.get_evaluation('named_eval/functor_test')
+ self.assertIn('named_eval/functor_test', base.registered_names())
+ self.assertIsInstance(evaluation, my_eval)
+ self.assertIsNone(evaluation.root_dir, None)
+
+ with self.assertRaisesRegex(ValueError, 'Evaluation .* not found'):
+ base.get_evaluation('named_eval/non_existent')
+
+ with self.assertRaisesRegex(TypeError, 'The return value .*'):
+ @base.register('named_eval/bad_return_type')
+ def bad_eval(): # pylint: disable=unused-variable
+ return 1
+
+ def test_run(self):
+ @base.register('test/run')
+ def test_run(): # pylint: disable=unused-variable
+ lm = fake.StaticResponse('Solution(final_answer=2)')
+ return eval_set('run_test', 'query', schema_fn=answer_schema(), lm=lm)
+
+ e = base.run(
+ tempfile.gettempdir(),
+ ['test/run'],
+ id_regex='run_test.*',
+ mode='dryrun',
+ print_definition=True,
+ )
+ self.assertEqual(
+ e.leaf_nodes[0].dir,
+ os.path.join(tempfile.gettempdir(), e.leaf_nodes[0].id),
+ )
+ self.assertTrue(
+ pg.eq(
+ e.leaf_nodes[0].lm, fake.StaticResponse('Solution(final_answer=2)')
+ )
+ )
+
+ @pg.patcher()
+ def bad_lm(unused_eval): # pylint: disable=unused-variable
+ return dict(lm=fake.StaticResponse('efg'))
+
+ e = base.run(
+ tempfile.gettempdir(),
+ [test_run()],
+ filter='Evaluation.*',
+ patches=['bad_lm']
+ )
+ self.assertTrue(pg.eq(e.leaf_nodes[0].lm, fake.StaticResponse('efg')))
+
+

  if __name__ == '__main__':
  unittest.main()
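
Note: the NamedEvaluationTest added above exercises a name-based registry for evaluations (base.register, base.get_evaluation, base.registered_names, base.run). A rough user-facing sketch pieced together from that test code; the registry name and the answer_schema functor below are illustrative assumptions, not package fixtures:

    import langfun as lf
    import pyglove as pg
    from langfun.core.eval import base

    @pg.functor()
    def answer_schema():
      # Assumed schema functor, modeled on the fake response
      # 'Solution(final_answer=2)' used throughout the tests.
      class Solution(pg.Object):
        final_answer: int
      return Solution

    @base.register('my_suite/arithmetic')  # hypothetical registry name
    class ArithmeticEval(base.Evaluation):
      inputs = base.as_inputs([pg.Dict(question='Compute 1 + 1')])
      method = 'query'
      prompt = lf.Template('{{example.question}}')
      schema_fn = answer_schema()

    # Registered evaluations can be looked up by name...
    assert 'my_suite/arithmetic' in base.registered_names()
    evaluation = base.get_evaluation('my_suite/arithmetic')

    # ...and run by name under a root directory, as test_run does above
    # (an LM still has to be configured on the evaluation or patched in):
    # base.run('/tmp/evals', ['my_suite/arithmetic'], mode='dryrun')
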
@@ -86,9 +86,26 @@ class Matching(base.Evaluation):
  self._matches = []
  self._mismatches = []

- def audit(self, example: Any, output: Any, message: lf.Message) -> None:
+ def audit_processed(
+ self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+ ) -> None:
  groundtruth = self.groundtruth(example)
  answer = self.answer(output, example)
+
+ if dryrun:
+ lf.console.write('')
+ lf.console.write(
+ str(groundtruth),
+ title='GROUDTRUTH',
+ color='green',
+ )
+ lf.console.write('')
+ lf.console.write(
+ str(answer),
+ title='ANSWER',
+ color='blue',
+ )
+
  if self.match(answer, groundtruth):
  self._matches.append((example, output, message))
  else:
@@ -102,18 +119,18 @@ class Matching(base.Evaluation):
  del progress
  return {
  'Model': self.lm.model_id,
- 'Matches': f'%.{self.report_precision}f%% (%d/%d)' % (
- self.match_rate * 100,
+ 'Matches': '%s (%d/%d)' % (
+ self._format_rate(self.match_rate),
  self.num_matches,
  self.num_completed,
  ),
- 'Mismatches': f'%.{self.report_precision}f%% (%d/%d)' % (
- self.mismatch_rate * 100,
+ 'Mismatches': '%s (%d/%d)' % (
+ self._format_rate(self.mismatch_rate),
  self.num_mismatches,
  self.num_completed,
  ),
- 'Failed': f'%.{self.report_precision}f%% (%d/%d)' % (
- self.failure_rate * 100,
+ 'Failed': '%s (%d/%d)' % (
+ self._format_rate(self.failure_rate),
  self.num_failures,
  self.num_completed,
  ),
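
Note: the changes above (and in the renderers below) replace repeated inline patterns like f'%.{self.report_precision}f%%' % (rate * 100) with a shared self._format_rate(...) helper. Its definition is not part of these hunks; judging from the code it replaces, it presumably formats a ratio as a percentage at report_precision decimal places, roughly equivalent to this standalone sketch:

    def format_rate(rate: float, precision: int = 1) -> str:
      # Assumed equivalent of Matching._format_rate, inferred from the inline
      # formatting it replaces; e.g. format_rate(0.5, 2) -> '50.00%'.
      return f'%.{precision}f%%' % (rate * 100)
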
@@ -123,24 +140,25 @@ class Matching(base.Evaluation):
  assert self.result is not None
  m = self.result.metrics
  return (
- f'COMPLETED(%s): Matches=%.{self.report_precision}f%% (%d/%d)'
- f' Mismatches=%.{self.report_precision}f%% (%d/%d)'
- f' Failures=%.{self.report_precision}f%% (%d/%d)'
+ 'COMPLETED(%s):'
+ ' Matches=%s (%d/%d)'
+ ' Mismatches=%s (%d/%d)'
+ ' Failures=%s (%d/%d)'
  ) % (
  run_status,
- m.match_rate * 100,
+ self._format_rate(m.match_rate),
  m.num_matches,
  m.total,
- m.mismatch_rate * 100,
+ self._format_rate(m.mismatch_rate),
  m.num_mismatches,
  m.total,
- m.failure_rate * 100,
+ self._format_rate(m.failure_rate),
  m.failures,
  m.total,
  )

- def summarize(self) -> pg.Dict:
- result = super().summarize()
+ def finalize(self) -> pg.Dict:
+ result = super().finalize()
  result.metrics.update(
  num_matches=self.num_matches,
  match_rate=self.match_rate,
@@ -155,19 +173,16 @@ class Matching(base.Evaluation):
  super().save(definition, result, report)

  if result:
-
- def force_dict(v):
- return pg.object_utils.json_conversion.strip_types(pg.to_json(v))
-
  # Save matches.
  pg.save(
  [
- # We force the output to be dict as its type may be defined
- # within functors which could be deserialized.
- pg.Dict(input=input, output=force_dict(output))
+ pg.Dict(input=input, output=output)
  for input, output, _ in self.matches
  ],
  os.path.join(self.dir, Matching.MATCHES_JSON),
+ # We force the input and output to be dict so it does not depend on
+ # the downstream to serialize.
+ force_dict=True,
  )

  # Save mismatches.
@@ -175,10 +190,13 @@ class Matching(base.Evaluation):
  [
  # We force the output to be dict as its type may be defined
  # within functors which could be deserialized.
- pg.Dict(input=input, output=force_dict(output))
+ pg.Dict(input=input, output=output)
  for input, output, _ in self.mismatches
  ],
  os.path.join(self.dir, Matching.MISMATCHES_JSON),
+ # We force the input and output to be dict so it does not depend on
+ # the downstream to serialize.
+ force_dict=True,
  )

  if report:
@@ -201,9 +219,9 @@ class Matching(base.Evaluation):
  def _render_result_row(self, s: io.StringIO):
  super()._render_result_row(s)
  s.write(
- '<td><span style="color:red">%s</span>%s</td>'
+ '<td><span style="color:orange">%s</span>%s</td>'
  % (
- f'%.{self.report_precision}f%% ' % (self.mismatch_rate * 100),
+ self._format_rate(self.mismatch_rate),
  '<a href="%s">(%d/%d)</a>'
  % (self.mismatches_link, self.num_mismatches, self.num_completed),
  )
@@ -211,13 +229,13 @@ class Matching(base.Evaluation):
  s.write(
  '<td><span style="color:green">%s</span>%s</td>'
  % (
- f'%.{self.report_precision}f%% ' % (self.match_rate * 100),
+ self._format_rate(self.match_rate),
  '<a href="%s">(%d/%d)</a>'
  % (self.matches_link, self.num_matches, self.num_completed),
  )
  )

- def _render_metric(self, s: io.StringIO) -> None:
+ def _render_summary_metrics(self, s: io.StringIO) -> None:
  """Renders metrics in HTML."""
  assert self.result is not None
  m = self.result.metrics
@@ -227,7 +245,7 @@ class Matching(base.Evaluation):
  m.num_matches,
  m.total,
  self.matches_link,
- f'%.{self.report_precision}f%% ' % (m.match_rate * 100),
+ self._format_rate(m.match_rate),
  )
  )
  s.write(' | ')
@@ -237,11 +255,11 @@ class Matching(base.Evaluation):
  m.num_mismatches,
  m.total,
  self.mismatches_link,
- f'%.{self.report_precision}f%% ' % (m.mismatch_rate * 100),
+ self._format_rate(m.mismatch_rate),
  )
  )
  s.write(' | ')
- super()._render_metric(s)
+ super()._render_summary_metrics(s)

  def _render_matches(self, s: io.StringIO) -> None:
  """Formats the matched cases into html."""
@@ -103,7 +103,7 @@ class MatchingTest(unittest.TestCase):
  s.result,
  dict(
  experiment_setup=dict(
- id='MyTask@3d87f97f',
+ id='MyTask@739a174b',
  dir=s.dir,
  model='StaticSequence',
  prompt_template='{{example.question}}',
@@ -120,11 +120,19 @@ class MatchingTest(unittest.TestCase):
  total=4,
  failures=1,
  failure_rate=0.25,
+ oop_failures=1,
+ oop_failure_rate=0.25,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={
+ 'MappingError.SchemaError.TypeError': 1
+ },
  num_matches=2,
  match_rate=0.5,
  num_mismatches=1,
  mismatch_rate=0.25,
  ),
+ usage=s.result.usage,
  ),
  )
  self.assertTrue(
@@ -159,7 +167,14 @@ class MatchingTest(unittest.TestCase):
  self.assertTrue(
  os.path.exists(
  os.path.join(
- s.dir, matching.Matching.FAILURES_JSON
+ s.dir, matching.Matching.OOP_FAILURES_JSON
+ )
+ )
+ )
+ self.assertTrue(
+ os.path.exists(
+ os.path.join(
+ s.dir, matching.Matching.NON_OOP_FAILURES_JSON
  )
  )
  )
@@ -174,7 +189,14 @@ class MatchingTest(unittest.TestCase):
  self.assertTrue(
  os.path.exists(
  os.path.join(
- s.dir, matching.Matching.FAILURES_HTML
+ s.dir, matching.Matching.OOP_FAILURES_HTML
+ )
+ )
+ )
+ self.assertTrue(
+ os.path.exists(
+ os.path.join(
+ s.dir, matching.Matching.NON_OOP_FAILURES_HTML
  )
  )
  )