langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (144)
  1. langfun/__init__.py +20 -2
  2. langfun/core/__init__.py +16 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -21
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +63 -2
  18. langfun/core/component_test.py +53 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +16 -1
  24. langfun/core/eval/base.py +622 -174
  25. langfun/core/eval/base_test.py +200 -54
  26. langfun/core/eval/matching.py +63 -76
  27. langfun/core/eval/matching_test.py +17 -8
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +26 -26
  31. langfun/core/eval/scoring_test.py +19 -2
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +4 -17
  55. langfun/core/langfunc_test.py +22 -6
  56. langfun/core/language_model.py +577 -39
  57. langfun/core/language_model_test.py +470 -56
  58. langfun/core/llms/__init__.py +87 -16
  59. langfun/core/llms/anthropic.py +312 -87
  60. langfun/core/llms/anthropic_test.py +71 -3
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +53 -2
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +11 -7
  69. langfun/core/llms/fake_test.py +14 -0
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -202
  74. langfun/core/llms/groq.py +160 -144
  75. langfun/core/llms/groq_test.py +31 -137
  76. langfun/core/llms/llama_cpp.py +15 -42
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +395 -203
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +30 -395
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -26
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +12 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +7 -6
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +60 -27
  112. langfun/core/structured/function_generation_test.py +72 -2
  113. langfun/core/structured/mapping.py +97 -47
  114. langfun/core/structured/mapping_test.py +90 -2
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +53 -9
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
  119. langfun/core/structured/schema.py +204 -97
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_test.py +130 -29
  122. langfun/core/structured/scoring.py +125 -19
  123. langfun/core/structured/scoring_test.py +30 -0
  124. langfun/core/structured/tokenization.py +64 -0
  125. langfun/core/structured/tokenization_test.py +48 -0
  126. langfun/core/template.py +115 -1
  127. langfun/core/template_test.py +71 -1
  128. langfun/core/templates/conversation.py +9 -0
  129. langfun/core/templates/conversation_test.py +4 -3
  130. langfun/core/templates/selfplay_test.py +10 -2
  131. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  132. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  133. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  134. langfun/core/coding/python/errors.py +0 -108
  135. langfun/core/coding/python/errors_test.py +0 -99
  136. langfun/core/coding/python/permissions.py +0 -90
  137. langfun/core/coding/python/permissions_test.py +0 -86
  138. langfun/core/structured/prompting.py +0 -238
  139. langfun/core/text_formatting.py +0 -162
  140. langfun/core/text_formatting_test.py +0 -47
  141. langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
  142. langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
  143. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  144. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -194,6 +194,7 @@ class EvaluationTest(unittest.TestCase):
  cache_seed=0,
  score=1.0,
  logprobs=None,
+ is_cached=False,
  usage=lf.LMSamplingUsage(387, 24, 411),
  tags=['lm-response', 'lm-output', 'transformed'],
  ),
@@ -220,7 +221,18 @@ class EvaluationTest(unittest.TestCase):
  cache_stats=dict(
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
- metrics=dict(total=2, failures=1, failure_rate=0.5),
+ metrics=dict(
+ total=2,
+ failures=1,
+ failure_rate=0.5,
+ oop_failures=1,
+ oop_failure_rate=0.5,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={
+ 'MappingError.SchemaError.TypeError': 1
+ }
+ ),
  usage=dict(
  total_prompt_tokens=774,
  total_completion_tokens=25,
@@ -235,12 +247,20 @@ class EvaluationTest(unittest.TestCase):
  os.path.exists(os.path.join(s.dir, base.Evaluation.EXPERIMENT_JSON)))
  self.assertTrue(
  os.path.exists(os.path.join(s.dir, base.Evaluation.RESULT_JSON)))
+ self.assertTrue(
+ os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_JSON)))
+ self.assertTrue(
+ os.path.exists(
+ os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_JSON)))
  self.assertTrue(
  os.path.exists(os.path.join(s.dir, base.Evaluation.CACHE_JSON)))
  self.assertTrue(
  os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
  self.assertTrue(
- os.path.exists(os.path.join(s.dir, base.Evaluation.FAILURES_HTML)))
+ os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_HTML)))
+ self.assertTrue(
+ os.path.exists(
+ os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_HTML)))
  self.assertTrue(
  os.path.exists(os.path.join(s.root_dir, base.Evaluation.SUMMARY_HTML))
  )
@@ -249,7 +269,7 @@ class EvaluationTest(unittest.TestCase):
  s.root_dir, base.Evaluation.SUMMARY_HTML.replace('.html', '.json')
  )
  self.assertTrue(os.path.exists(summary_json))
- summary = pg.load(summary_json, force_dict=True)
+ summary = pg.load(summary_json, auto_dict=True)
  self.assertIn('Evaluation', summary)
  self.assertEqual(len(summary['Evaluation']), 1)
  self.assertIsNotNone(summary['Evaluation'][0].experiment)
@@ -274,7 +294,10 @@ class EvaluationTest(unittest.TestCase):
  self.assertFalse(
  os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
  self.assertFalse(
- os.path.exists(os.path.join(s.dir, base.Evaluation.FAILURES_HTML)))
+ os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_HTML)))
+ self.assertFalse(
+ os.path.exists(
+ os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_HTML)))

  def test_load(self):
  lm = fake.StaticResponse('Solution(final_answer=2)')
@@ -312,7 +335,16 @@ class EvaluationTest(unittest.TestCase):
  cache_stats=dict(
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
- metrics=dict(total=2, failures=0, failure_rate=0.0),
+ metrics=dict(
+ total=2,
+ failures=0,
+ failure_rate=0.0,
+ oop_failures=0,
+ oop_failure_rate=0.0,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={},
+ ),
  usage=s.children[1].result.usage,
  ),
  },
@@ -363,7 +395,18 @@ class EvaluationTest(unittest.TestCase):
  cache_stats=dict(
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
- metrics=dict(total=2, failures=1, failure_rate=0.5),
+ metrics=dict(
+ total=2,
+ failures=1,
+ failure_rate=0.5,
+ oop_failures=1,
+ oop_failure_rate=0.5,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={
+ 'MappingError.SchemaError.TypeError': 1
+ }
+ ),
  usage=s.children[0].result.usage,
  ),
  s.children[1].id: dict(
@@ -378,7 +421,18 @@ class EvaluationTest(unittest.TestCase):
  cache_stats=dict(
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
- metrics=dict(total=2, failures=1, failure_rate=0.5),
+ metrics=dict(
+ total=2,
+ failures=1,
+ failure_rate=0.5,
+ oop_failures=1,
+ oop_failure_rate=0.5,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={
+ 'MappingError.SchemaError.TypeError': 1
+ }
+ ),
  usage=s.children[1].result.usage,
  ),
  },
@@ -475,7 +529,7 @@ class SuiteTest(unittest.TestCase):
  self.assertEqual(s.hash, '26e6cc25')
  s.run()
  expected = {
- s.children[0].id: dict(
+ 'Evaluation@0fade07d': dict(
  experiment_setup=dict(
  id=s.children[0].id,
  dir=s.children[0].dir,
@@ -487,48 +541,46 @@ class SuiteTest(unittest.TestCase):
  cache_stats=dict(
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
- metrics=dict(total=2, failures=1, failure_rate=0.5),
+ metrics=dict(
+ total=2,
+ failures=1,
+ failure_rate=0.5,
+ oop_failures=1,
+ oop_failure_rate=0.5,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={
+ 'MappingError.SchemaError.TypeError': 1
+ }
+ ),
  usage=s.children[0].result.usage,
  ),
- s.children[1].id: {
- s.children[1]
- .children[0]
- .id: dict(
- experiment_setup=dict(
- id=s.children[1].children[0].id,
- dir=s.children[1].children[0].dir,
- model='StaticSequence',
- prompt_template='{{example.question}}',
- method='call',
- schema_fn='answer_schema()',
- ),
- cache_stats=dict(
- use_cache=True, num_queries=4, num_hits=1, num_updates=3
- ),
- metrics=dict(total=2, failures=2, failure_rate=1.0),
- usage=s.children[1].children[0].result.usage,
+ 'Evaluation@ae86c703': dict(
+ experiment_setup=dict(
+ id=s.children[1].children[0].id,
+ dir=s.children[1].children[0].dir,
+ model='StaticSequence',
+ prompt_template='{{example.question}}',
+ method='call',
+ schema_fn='answer_schema()',
  ),
- s.children[1]
- .children[2]
- .id: dict(
- experiment_setup=dict(
- id=s.children[1].children[2].id,
- dir=s.children[1].children[2].dir,
- model='StaticSequence',
- prompt_template='{{example.question}}',
- method='query',
- schema_fn='answer_schema()',
- ),
- cache_stats=dict(
- use_cache=True,
- num_queries=2,
- num_hits=0,
- num_updates=2,
- ),
- metrics=dict(total=2, failures=1, failure_rate=0.5),
- usage=s.children[1].children[2].result.usage,
+ cache_stats=dict(
+ use_cache=True, num_queries=4, num_hits=0, num_updates=4
  ),
- },
+ metrics=dict(
+ total=2,
+ failures=2,
+ failure_rate=1.0,
+ oop_failures=2,
+ oop_failure_rate=1.0,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={
+ 'MappingError.SchemaError.TypeError': 2
+ }
+ ),
+ usage=s.children[1].children[0].result.usage,
+ ),
  }
  self.assertEqual(s.result, expected)
@@ -547,6 +599,14 @@ class InputsFrom(unittest.TestCase):
  pg.save([1, 2, 3], path)
  self.assertEqual(base.inputs_from(path)(), [1, 2, 3])

+ path = os.path.join(tmp_dir, 'input_file.jsonl')
+ with pg.open_jsonl(path, 'w') as f:
+ f.add(pg.Dict(x=1))
+ f.add(dict(y=2))
+ self.assertEqual(
+ base.inputs_from(path)(), [pg.Dict(x=1), dict(y=2)]
+ )
+
  def test_inputs_from_multiple_files(self):
  tmp_dir = tempfile.gettempdir()
  path1 = os.path.join(tmp_dir, 'input_file1.json')
@@ -698,16 +758,102 @@ class SummaryTest(unittest.TestCase):
  self.assertTrue(pg.io.path_exists(summary_file))


- class AppRunTest(unittest.TestCase):
+ class NamedEvaluationTest(unittest.TestCase):

- def test_app_run(self):
- lm = fake.StaticSequence(['two', 'Solution(final_answer=2)'])
- try:
- base.app_run(
- eval_set('app_run_test', 'query', schema_fn=answer_schema(), lm=lm)
+ def test_named_eval_class(self):
+
+ @base.register('named_eval/class_test')
+ class MyEval(base.Evaluation):
+ inputs = base.as_inputs([
+ pg.Dict(question='Compute 1 + 1'),
+ ])
+ method = 'query'
+ prompt = pg.oneof([
+ lf.Template('{{example.question}}'),
+ lf.Template('Hello {{example.question}}'),
+ ])
+ schema_fn = answer_schema()
+
+ [evaluation] = base.get_evaluations('named_eval/class_test')
+ self.assertIsInstance(evaluation, MyEval)
+ self.assertIsNone(evaluation.dir)
+ self.assertIsNone(evaluation.root_dir)
+ self.assertIn('named_eval/class_test', base.registered_names())
+
+ with self.assertRaisesRegex(ValueError, 'Unsupported type.*'):
+ @base.register('named_eval/bad_class')
+ class Foo: # pylint: disable=unused-variable
+ pass
+
+ def test_named_eval_functor(self):
+
+ @base.register('named_eval/functor_test')
+ def my_eval():
+ return base.Evaluation(
+ inputs=base.as_inputs([
+ pg.Dict(question='Compute 1 + 1'),
+ ]),
+ method='query',
+ prompt=pg.oneof([
+ lf.Template('{{example.question}}'),
+ lf.Template('Hello {{example.question}}'),
+ ]),
+ schema_fn=answer_schema(),
  )
- except SystemExit:
- pass
+
+ self.assertTrue(issubclass(my_eval, base.Evaluable))
+ [evaluation] = base.get_evaluations('named_eval/functor_test')
+ self.assertIn('named_eval/functor_test', base.registered_names())
+ self.assertIsInstance(evaluation, my_eval)
+ self.assertIsNone(evaluation.root_dir, None)
+
+ self.assertTrue(
+ pg.eq(base.get_evaluations('named_eval/functor.*'), [evaluation])
+ )
+ self.assertEqual(base.get_evaluations('named_eval/non_existent'), [])
+
+ with self.assertRaisesRegex(TypeError, 'The return value .*'):
+ @base.register('named_eval/bad_return_type')
+ def bad_eval(): # pylint: disable=unused-variable
+ return 1
+
+ def test_run(self):
+ @base.register('test/run')
+ def test_run(): # pylint: disable=unused-variable
+ lm = fake.StaticResponse('Solution(final_answer=2)')
+ return eval_set('run_test', 'query', schema_fn=answer_schema(), lm=lm)
+
+ e = base.run(
+ tempfile.gettempdir(),
+ ['test/run'],
+ id_regex='run_test.*',
+ mode='dryrun',
+ print_definition=True,
+ )
+ self.assertEqual(
+ e.leaf_nodes[0].dir,
+ os.path.join(tempfile.gettempdir(), e.leaf_nodes[0].id),
+ )
+ self.assertTrue(
+ pg.eq(
+ e.leaf_nodes[0].lm, fake.StaticResponse('Solution(final_answer=2)')
+ )
+ )
+
+ @pg.patcher()
+ def bad_lm(unused_eval): # pylint: disable=unused-variable
+ return dict(lm=fake.StaticResponse('efg'))
+
+ e = base.run(
+ tempfile.gettempdir(),
+ [test_run()],
+ filter='Evaluation.*',
+ patches=['bad_lm']
+ )
+ self.assertTrue(pg.eq(e.leaf_nodes[0].lm, fake.StaticResponse('efg')))
+
+ with self.assertRaisesRegex(ValueError, 'No evaluations found'):
+ base.run(tempfile.gettempdir(), ['test/non_existent'])


  if __name__ == '__main__':
@@ -41,8 +41,8 @@ class Matching(base.Evaluation):
  """Returns the answer from the structure output."""

  @property
- def matches(self) -> list[tuple[Any, Any, lf.Message]]:
- """Returns the matches examples, outputs and the output messages."""
+ def matches(self) -> list[tuple[int, Any, Any, lf.Message]]:
+ """Returns the matches IDs, examples, outputs and the output messages."""
  return self._matches

  @property
@@ -57,7 +57,7 @@ class Matching(base.Evaluation):
  return self.num_matches / self.num_completed

  @property
- def mismatches(self) -> list[tuple[Any, Any, lf.Message]]:
+ def mismatches(self) -> list[tuple[int, Any, Any, lf.Message]]:
  """Returns the mismatches examples, outputs and output messages."""
  return self._mismatches

@@ -87,7 +87,8 @@ class Matching(base.Evaluation):
  self._mismatches = []

  def audit_processed(
- self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+ self, example_idx: int, example: Any, output: Any, message: lf.Message,
+ dryrun: bool = False
  ) -> None:
  groundtruth = self.groundtruth(example)
  answer = self.answer(output, example)
@@ -107,30 +108,29 @@ class Matching(base.Evaluation):
  )

  if self.match(answer, groundtruth):
- self._matches.append((example, output, message))
+ self._matches.append((example_idx, example, output, message))
  else:
- self._mismatches.append((example, output, message))
+ self._mismatches.append((example_idx, example, output, message))

  def match(self, answer: Any, groundtruth: Any) -> bool:
  """Matches answer against the groundtruth. Subclasses can override."""
  return pg.eq(answer, groundtruth)

- def _status(self, progress: lf.concurrent.Progress) -> dict[str, Any]:
+ def _eval_status(self, progress: lf.concurrent.Progress) -> dict[str, Any]:
  del progress
  return {
- 'Model': self.lm.model_id,
- 'Matches': f'%.{self.report_precision}f%% (%d/%d)' % (
- self.match_rate * 100,
+ 'Matches': '%s (%d/%d)' % (
+ self._format_rate(self.match_rate),
  self.num_matches,
  self.num_completed,
  ),
- 'Mismatches': f'%.{self.report_precision}f%% (%d/%d)' % (
- self.mismatch_rate * 100,
+ 'Mismatches': '%s (%d/%d)' % (
+ self._format_rate(self.mismatch_rate),
  self.num_mismatches,
  self.num_completed,
  ),
- 'Failed': f'%.{self.report_precision}f%% (%d/%d)' % (
- self.failure_rate * 100,
+ 'Failed': '%s (%d/%d)' % (
+ self._format_rate(self.failure_rate),
  self.num_failures,
  self.num_completed,
  ),
@@ -140,24 +140,25 @@ class Matching(base.Evaluation):
  assert self.result is not None
  m = self.result.metrics
  return (
- f'COMPLETED(%s): Matches=%.{self.report_precision}f%% (%d/%d)'
- f' Mismatches=%.{self.report_precision}f%% (%d/%d)'
- f' Failures=%.{self.report_precision}f%% (%d/%d)'
+ 'COMPLETED(%s):'
+ ' Matches=%s (%d/%d)'
+ ' Mismatches=%s (%d/%d)'
+ ' Failures=%s (%d/%d)'
  ) % (
  run_status,
- m.match_rate * 100,
+ self._format_rate(m.match_rate),
  m.num_matches,
  m.total,
- m.mismatch_rate * 100,
+ self._format_rate(m.mismatch_rate),
  m.num_mismatches,
  m.total,
- m.failure_rate * 100,
+ self._format_rate(m.failure_rate),
  m.failures,
  m.total,
  )

- def summarize(self) -> pg.Dict:
- result = super().summarize()
+ def finalize(self) -> pg.Dict:
+ result = super().finalize()
  result.metrics.update(
  num_matches=self.num_matches,
  match_rate=self.match_rate,
@@ -171,33 +172,6 @@ class Matching(base.Evaluation):
  ) -> None:
  super().save(definition, result, report)

- if result:
- # Save matches.
- pg.save(
- [
- pg.Dict(input=input, output=output)
- for input, output, _ in self.matches
- ],
- os.path.join(self.dir, Matching.MATCHES_JSON),
- # We force the input and output to be dict so it does not depend on
- # the downstream to serialize.
- force_dict=True,
- )
-
- # Save mismatches.
- pg.save(
- [
- # We force the output to be dict as its type may be defined
- # within functors which could be deserialized.
- pg.Dict(input=input, output=output)
- for input, output, _ in self.mismatches
- ],
- os.path.join(self.dir, Matching.MISMATCHES_JSON),
- # We force the input and output to be dict so it does not depend on
- # the downstream to serialize.
- force_dict=True,
- )
-
  if report:
  pg.save(
  self._html([self._render_result, self._render_matches]),
@@ -218,9 +192,9 @@ class Matching(base.Evaluation):
  def _render_result_row(self, s: io.StringIO):
  super()._render_result_row(s)
  s.write(
- '<td><span style="color:red">%s</span>%s</td>'
+ '<td><span style="color:orange">%s</span>%s</td>'
  % (
- f'%.{self.report_precision}f%% ' % (self.mismatch_rate * 100),
+ self._format_rate(self.mismatch_rate),
  '<a href="%s">(%d/%d)</a>'
  % (self.mismatches_link, self.num_mismatches, self.num_completed),
  )
@@ -228,37 +202,33 @@ class Matching(base.Evaluation):
  s.write(
  '<td><span style="color:green">%s</span>%s</td>'
  % (
- f'%.{self.report_precision}f%% ' % (self.match_rate * 100),
+ self._format_rate(self.match_rate),
  '<a href="%s">(%d/%d)</a>'
  % (self.matches_link, self.num_matches, self.num_completed),
  )
  )

- def _render_metric(self, s: io.StringIO) -> None:
+ def _render_summary_metrics(self, s: io.StringIO) -> None:
  """Renders metrics in HTML."""
  assert self.result is not None
  m = self.result.metrics
- s.write(
- '<a title="Matches (%d/%d)" href="%s" style="color:green">%s</a>'
- % (
- m.num_matches,
- m.total,
- self.matches_link,
- f'%.{self.report_precision}f%% ' % (m.match_rate * 100),
- )
+ self._render_link(
+ s,
+ 'Matches (%d/%d)' % (m.num_matches, m.total),
+ self._format_rate(m.match_rate),
+ 'color:green',
+ lambda: self.matches_link,
  )
  s.write(' | ')
- s.write(
- '<a title="Mismatches (%d/%d)" href="%s" style="color:orange">%s</a>'
- % (
- m.num_mismatches,
- m.total,
- self.mismatches_link,
- f'%.{self.report_precision}f%% ' % (m.mismatch_rate * 100),
- )
+ self._render_link(
+ s,
+ 'Mismatches (%d/%d)' % (m.num_mismatches, m.total),
+ self._format_rate(m.mismatch_rate),
+ 'color:orange',
+ lambda: self.mismatches_link,
  )
  s.write(' | ')
- super()._render_metric(s)
+ super()._render_summary_metrics(s)

  def _render_matches(self, s: io.StringIO) -> None:
  """Formats the matched cases into html."""
@@ -271,12 +241,29 @@ class Matching(base.Evaluation):
  '<td>Prompt/Response Chain</td>'
  '</tr>'
  )
- for i, (example, output, message) in enumerate(self.matches):
+ def _maybe_html(v, root_indent: int):
+ del root_indent
+ if hasattr(v, '_repr_html_'):
+ return v._repr_html_() # pylint: disable=protected-access
+ # Fall back to the default format.
+ return None
+
+ for i, (_, example, output, message) in enumerate(self.matches):
  bgcolor = 'white' if i % 2 == 0 else '#DDDDDD'
  s.write(f'<tr style="background-color: {bgcolor}"><td>{i + 1}</td>')
- input_str = pg.format(example, verbose=False)
+ input_str = pg.Html.escape(
+ pg.format(
+ example, verbose=False, max_bytes_len=32,
+ custom_format=_maybe_html
+ )
+ )
  s.write(f'<td style="color:green;white-space:pre-wrap">{input_str}</td>')
- output_str = pg.format(output, verbose=False)
+ output_str = pg.Html.escape(
+ pg.format(
+ output, verbose=False, max_bytes_len=32,
+ custom_format=_maybe_html
+ )
+ )
  s.write(f'<td style="color:blue;white-space:pre-wrap">{output_str}</td>')
  s.write('<td>')
  self._render_message(message, s)
@@ -296,12 +283,12 @@ class Matching(base.Evaluation):
  '</tr>'
  )

- for i, (example, output, message) in enumerate(self.mismatches):
+ for i, (_, example, output, message) in enumerate(self.mismatches):
  bgcolor = 'white' if i % 2 == 0 else '#DDDDDD'
  s.write(f'<tr style="background-color: {bgcolor}"><td>{i + 1}</td>')
- input_str = pg.format(example, verbose=False)
+ input_str = pg.format(example, verbose=False, max_bytes_len=32)
  s.write(f'<td style="color:green;white-space:pre-wrap">{input_str}</td>')
- output_str = pg.format(output, verbose=False)
+ output_str = pg.format(output, verbose=False, max_bytes_len=32)
  s.write(
  f'<td style="color:magenta;white-space:pre-wrap">{output_str}</td>'
  )
@@ -120,6 +120,13 @@ class MatchingTest(unittest.TestCase):
  total=4,
  failures=1,
  failure_rate=0.25,
+ oop_failures=1,
+ oop_failure_rate=0.25,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={
+ 'MappingError.SchemaError.TypeError': 1
+ },
  num_matches=2,
  match_rate=0.5,
  num_mismatches=1,
@@ -145,22 +152,17 @@ class MatchingTest(unittest.TestCase):
  os.path.join(s.dir, matching.Matching.CACHE_JSON)
  )
  )
- self.assertTrue(
- os.path.exists(
- os.path.join(s.dir, matching.Matching.MATCHES_JSON)
- )
- )
  self.assertTrue(
  os.path.exists(
  os.path.join(
- s.dir, matching.Matching.MISMATCHES_JSON
+ s.dir, matching.Matching.OOP_FAILURES_JSON
  )
  )
  )
  self.assertTrue(
  os.path.exists(
  os.path.join(
- s.dir, matching.Matching.FAILURES_JSON
+ s.dir, matching.Matching.NON_OOP_FAILURES_JSON
  )
  )
  )
@@ -175,7 +177,14 @@ class MatchingTest(unittest.TestCase):
  self.assertTrue(
  os.path.exists(
  os.path.join(
- s.dir, matching.Matching.FAILURES_HTML
+ s.dir, matching.Matching.OOP_FAILURES_HTML
+ )
+ )
+ )
+ self.assertTrue(
+ os.path.exists(
+ os.path.join(
+ s.dir, matching.Matching.NON_OOP_FAILURES_HTML
  )
  )
  )