langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff reflects the changes between two publicly released versions of this package, as they appear in the public registry.
Files changed (145)
  1. langfun/__init__.py +22 -2
  2. langfun/core/__init__.py +17 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -28
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +69 -2
  18. langfun/core/component_test.py +54 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +17 -0
  24. langfun/core/eval/base.py +767 -140
  25. langfun/core/eval/base_test.py +238 -53
  26. langfun/core/eval/matching.py +80 -76
  27. langfun/core/eval/matching_test.py +19 -9
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +37 -28
  31. langfun/core/eval/scoring_test.py +21 -3
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +3 -21
  55. langfun/core/langfunc_test.py +26 -8
  56. langfun/core/language_model.py +686 -48
  57. langfun/core/language_model_test.py +681 -44
  58. langfun/core/llms/__init__.py +100 -12
  59. langfun/core/llms/anthropic.py +488 -0
  60. langfun/core/llms/anthropic_test.py +235 -0
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +88 -28
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +39 -26
  69. langfun/core/llms/fake_test.py +136 -11
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -197
  74. langfun/core/llms/groq.py +276 -0
  75. langfun/core/llms/groq_test.py +64 -0
  76. langfun/core/llms/llama_cpp.py +15 -40
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +436 -226
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +35 -174
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -23
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +15 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +9 -8
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +278 -0
  112. langfun/core/structured/function_generation_test.py +399 -0
  113. langfun/core/structured/mapping.py +150 -46
  114. langfun/core/structured/mapping_test.py +105 -0
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +71 -22
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
  119. langfun/core/structured/schema.py +208 -99
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_generation_test.py +2 -2
  122. langfun/core/structured/schema_test.py +133 -34
  123. langfun/core/structured/scoring.py +125 -19
  124. langfun/core/structured/scoring_test.py +30 -0
  125. langfun/core/structured/tokenization.py +64 -0
  126. langfun/core/structured/tokenization_test.py +48 -0
  127. langfun/core/template.py +240 -11
  128. langfun/core/template_test.py +146 -1
  129. langfun/core/templates/conversation.py +9 -0
  130. langfun/core/templates/conversation_test.py +4 -3
  131. langfun/core/templates/selfplay_test.py +14 -2
  132. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  133. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  134. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  135. langfun/core/coding/python/errors.py +0 -108
  136. langfun/core/coding/python/errors_test.py +0 -99
  137. langfun/core/coding/python/permissions.py +0 -90
  138. langfun/core/coding/python/permissions_test.py +0 -86
  139. langfun/core/structured/prompting.py +0 -217
  140. langfun/core/text_formatting.py +0 -162
  141. langfun/core/text_formatting_test.py +0 -47
  142. langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
  143. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  144. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  145. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -101,7 +101,7 @@ class EvaluationTest(unittest.TestCase):
     self.assertEqual(s.dir, os.path.join(s.root_dir, s.id))
     self.assertEqual(s.hash, s.clone().hash)
     # Test persistent hash.
-    self.assertEqual(s.hash, 'abc7c29a')
+    self.assertEqual(s.hash, 'ae86c703')
     self.assertEqual(
         s.hash, s.clone(override={'max_workers': 2, 'lm.timeout': 20}).hash
     )
@@ -194,6 +194,8 @@ class EvaluationTest(unittest.TestCase):
             cache_seed=0,
             score=1.0,
             logprobs=None,
+            is_cached=False,
+            usage=lf.LMSamplingUsage(387, 24, 411),
             tags=['lm-response', 'lm-output', 'transformed'],
         ),
     )
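
Note on the two added fields: each LM response in the evaluation now carries cache status (`is_cached`) and token accounting. A minimal sketch of what the assertion above implies, assuming `lf.LMSamplingUsage`'s positional fields are prompt, completion, and total tokens (387 + 24 = 411):

import langfun as lf

# Assumed field order, inferred from the values above: 387 + 24 = 411.
usage = lf.LMSamplingUsage(387, 24, 411)
assert usage.prompt_tokens + usage.completion_tokens == usage.total_tokens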
@@ -209,7 +211,7 @@ class EvaluationTest(unittest.TestCase):
         s.result,
         dict(
             experiment_setup=dict(
-                id='Evaluation@17915dc6',
+                id='Evaluation@0fade07d',
                 dir=s.dir,
                 model='StaticSequence',
                 prompt_template='{{example.question}}',
@@ -219,7 +221,26 @@ class EvaluationTest(unittest.TestCase):
             cache_stats=dict(
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
-            metrics=dict(total=2, failures=1, failure_rate=0.5),
+            metrics=dict(
+                total=2,
+                failures=1,
+                failure_rate=0.5,
+                oop_failures=1,
+                oop_failure_rate=0.5,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={
+                    'MappingError.SchemaError.TypeError': 1
+                }
+            ),
+            usage=dict(
+                total_prompt_tokens=774,
+                total_completion_tokens=25,
+                num_usages=2,
+                average_prompt_tokens=387,
+                average_completion_tokens=12,
+                average_total_tokens=399,
+            ),
         ),
     )
     self.assertTrue(
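
The expanded `metrics` now distinguish OOP failures (errors raised while mapping LM output to the requested schema, e.g. `MappingError.SchemaError.TypeError`) from non-OOP failures, and add a per-error-type `failure_breakdown`; the result also gains an aggregate `usage` block. A hedged sketch of reading these fields off a completed evaluation `s` like the one above:

# Field names are taken directly from the expected result above.
m = s.result.metrics
print(m.failure_rate, m.oop_failure_rate, m.non_oop_failure_rate)
print(dict(m.failure_breakdown))  # e.g. {'MappingError.SchemaError.TypeError': 1}

u = s.result.usage
# Averages are per-usage: 774 prompt tokens over 2 usages -> 387.
assert u.average_prompt_tokens == u.total_prompt_tokens // u.num_usages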
@@ -227,14 +248,32 @@ class EvaluationTest(unittest.TestCase):
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.RESULT_JSON)))
     self.assertTrue(
-        os.path.exists(os.path.join(s.dir, base.Evaluation.CACHE_JSON)))
+        os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_JSON)))
     self.assertTrue(
-        os.path.exists(os.path.join(s.root_dir, base.Evaluation.SUMMARY_HTML))
-    )
+        os.path.exists(
+            os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_JSON)))
+    self.assertTrue(
+        os.path.exists(os.path.join(s.dir, base.Evaluation.CACHE_JSON)))
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
     self.assertTrue(
-        os.path.exists(os.path.join(s.dir, base.Evaluation.FAILURES_HTML)))
+        os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_HTML)))
+    self.assertTrue(
+        os.path.exists(
+            os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_HTML)))
+    self.assertTrue(
+        os.path.exists(os.path.join(s.root_dir, base.Evaluation.SUMMARY_HTML))
+    )
+    # Check summary JSON.
+    summary_json = os.path.join(
+        s.root_dir, base.Evaluation.SUMMARY_HTML.replace('.html', '.json')
+    )
+    self.assertTrue(os.path.exists(summary_json))
+    summary = pg.load(summary_json, auto_dict=True)
+    self.assertIn('Evaluation', summary)
+    self.assertEqual(len(summary['Evaluation']), 1)
+    self.assertIsNotNone(summary['Evaluation'][0].experiment)
+    self.assertIsNotNone(summary['Evaluation'][0].metrics)
 
   def test_run_wihtout_save(self):
     lm = fake.StaticSequence([
@@ -255,7 +294,10 @@ class EvaluationTest(unittest.TestCase):
     self.assertFalse(
         os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
     self.assertFalse(
-        os.path.exists(os.path.join(s.dir, base.Evaluation.FAILURES_HTML)))
+        os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_HTML)))
+    self.assertFalse(
+        os.path.exists(
+            os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_HTML)))
 
   def test_load(self):
     lm = fake.StaticResponse('Solution(final_answer=2)')
@@ -274,8 +316,11 @@ class EvaluationTest(unittest.TestCase):
     s = eval_set(
         'run_filter_test', pg.oneof(['call', 'query']),
         schema_fn=answer_schema(), lm=lm)
+    result = s.run(
+        filter=lambda x: x.method == 'query', dryrun=True, summary=False
+    )
     self.assertEqual(
-        s.run(filter=lambda x: x.method == 'query', dryrun=True, summary=False),
+        result,
         {
             s.children[0].id: None,
             s.children[1].id: dict(
@@ -290,8 +335,18 @@ class EvaluationTest(unittest.TestCase):
                 cache_stats=dict(
                     use_cache=True, num_queries=2, num_hits=0, num_updates=2
                 ),
-                metrics=dict(total=2, failures=0, failure_rate=0.0),
-            )
+                metrics=dict(
+                    total=2,
+                    failures=0,
+                    failure_rate=0.0,
+                    oop_failures=0,
+                    oop_failure_rate=0.0,
+                    non_oop_failures=0,
+                    non_oop_failure_rate=0.0,
+                    failure_breakdown={},
+                ),
+                usage=s.children[1].result.usage,
+            ),
         },
     )
 
@@ -321,11 +376,10 @@ class EvaluationTest(unittest.TestCase):
         s.children[0].dir, os.path.join(s.root_dir, s.children[0].id)
     )
     # Test persistent hash.
-    self.assertEqual(s.hash, 'ca7f722b')
+    self.assertEqual(s.hash, 'b66a4e88')
 
     summary = s.run(verbose=True)
     self.assertEqual(len(summary.evaluations), 2)
-
     self.assertEqual(
         s.result,
         {
@@ -341,7 +395,19 @@ class EvaluationTest(unittest.TestCase):
                 cache_stats=dict(
                     use_cache=True, num_queries=2, num_hits=0, num_updates=2
                 ),
-                metrics=dict(total=2, failures=1, failure_rate=0.5),
+                metrics=dict(
+                    total=2,
+                    failures=1,
+                    failure_rate=0.5,
+                    oop_failures=1,
+                    oop_failure_rate=0.5,
+                    non_oop_failures=0,
+                    non_oop_failure_rate=0.0,
+                    failure_breakdown={
+                        'MappingError.SchemaError.TypeError': 1
+                    }
+                ),
+                usage=s.children[0].result.usage,
             ),
             s.children[1].id: dict(
                 experiment_setup=dict(
@@ -355,7 +421,19 @@ class EvaluationTest(unittest.TestCase):
                 cache_stats=dict(
                     use_cache=True, num_queries=2, num_hits=0, num_updates=2
                 ),
-                metrics=dict(total=2, failures=1, failure_rate=0.5),
+                metrics=dict(
+                    total=2,
+                    failures=1,
+                    failure_rate=0.5,
+                    oop_failures=1,
+                    oop_failure_rate=0.5,
+                    non_oop_failures=0,
+                    non_oop_failure_rate=0.0,
+                    failure_breakdown={
+                        'MappingError.SchemaError.TypeError': 1
+                    }
+                ),
+                usage=s.children[1].result.usage,
             ),
         },
     )
@@ -448,10 +526,10 @@ class SuiteTest(unittest.TestCase):
         lm=lm
     )
     # Test for persistent hash.
-    self.assertEqual(s.hash, '7285e52b')
+    self.assertEqual(s.hash, '26e6cc25')
     s.run()
     expected = {
-        s.children[0].id: dict(
+        'Evaluation@0fade07d': dict(
             experiment_setup=dict(
                 id=s.children[0].id,
                 dir=s.children[0].dir,
@@ -463,45 +541,46 @@ class SuiteTest(unittest.TestCase):
             cache_stats=dict(
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
-            metrics=dict(total=2, failures=1, failure_rate=0.5),
+            metrics=dict(
+                total=2,
+                failures=1,
+                failure_rate=0.5,
+                oop_failures=1,
+                oop_failure_rate=0.5,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={
+                    'MappingError.SchemaError.TypeError': 1
+                }
+            ),
+            usage=s.children[0].result.usage,
         ),
-        s.children[1].id: {
-            s.children[1]
-            .children[0]
-            .id: dict(
-                experiment_setup=dict(
-                    id=s.children[1].children[0].id,
-                    dir=s.children[1].children[0].dir,
-                    model='StaticSequence',
-                    prompt_template='{{example.question}}',
-                    method='call',
-                    schema_fn='answer_schema()',
-                ),
-                cache_stats=dict(
-                    use_cache=True, num_queries=4, num_hits=1, num_updates=3
-                ),
-                metrics=dict(total=2, failures=2, failure_rate=1.0),
+        'Evaluation@ae86c703': dict(
+            experiment_setup=dict(
+                id=s.children[1].children[0].id,
+                dir=s.children[1].children[0].dir,
+                model='StaticSequence',
+                prompt_template='{{example.question}}',
+                method='call',
+                schema_fn='answer_schema()',
             ),
-            s.children[1]
-            .children[2]
-            .id: dict(
-                experiment_setup=dict(
-                    id=s.children[1].children[2].id,
-                    dir=s.children[1].children[2].dir,
-                    model='StaticSequence',
-                    prompt_template='{{example.question}}',
-                    method='query',
-                    schema_fn='answer_schema()',
-                ),
-                cache_stats=dict(
-                    use_cache=True,
-                    num_queries=2,
-                    num_hits=0,
-                    num_updates=2,
-                ),
-                metrics=dict(total=2, failures=1, failure_rate=0.5),
+            cache_stats=dict(
+                use_cache=True, num_queries=4, num_hits=0, num_updates=4
             ),
-        },
+            metrics=dict(
+                total=2,
+                failures=2,
+                failure_rate=1.0,
+                oop_failures=2,
+                oop_failure_rate=1.0,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={
+                    'MappingError.SchemaError.TypeError': 2
+                }
+            ),
+            usage=s.children[1].children[0].result.usage,
+        ),
     }
     self.assertEqual(s.result, expected)
 
@@ -520,6 +599,14 @@ class InputsFrom(unittest.TestCase):
     pg.save([1, 2, 3], path)
     self.assertEqual(base.inputs_from(path)(), [1, 2, 3])
 
+    path = os.path.join(tmp_dir, 'input_file.jsonl')
+    with pg.open_jsonl(path, 'w') as f:
+      f.add(pg.Dict(x=1))
+      f.add(dict(y=2))
+    self.assertEqual(
+        base.inputs_from(path)(), [pg.Dict(x=1), dict(y=2)]
+    )
+
   def test_inputs_from_multiple_files(self):
     tmp_dir = tempfile.gettempdir()
     path1 = os.path.join(tmp_dir, 'input_file1.json')
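
`inputs_from` now accepts `.jsonl` files in addition to JSON, as the added test above shows. A minimal sketch mirroring it:

import os
import tempfile
import pyglove as pg
from langfun.core.eval import base

path = os.path.join(tempfile.gettempdir(), 'examples.jsonl')
with pg.open_jsonl(path, 'w') as f:
  f.add(pg.Dict(question='Compute 1 + 1'))  # one JSON record per line

examples = base.inputs_from(path)()  # inputs_from returns a functor; call it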
@@ -671,5 +758,103 @@ class SummaryTest(unittest.TestCase):
     self.assertTrue(pg.io.path_exists(summary_file))
 
 
+class NamedEvaluationTest(unittest.TestCase):
+
+  def test_named_eval_class(self):
+
+    @base.register('named_eval/class_test')
+    class MyEval(base.Evaluation):
+      inputs = base.as_inputs([
+          pg.Dict(question='Compute 1 + 1'),
+      ])
+      method = 'query'
+      prompt = pg.oneof([
+          lf.Template('{{example.question}}'),
+          lf.Template('Hello {{example.question}}'),
+      ])
+      schema_fn = answer_schema()
+
+    [evaluation] = base.get_evaluations('named_eval/class_test')
+    self.assertIsInstance(evaluation, MyEval)
+    self.assertIsNone(evaluation.dir)
+    self.assertIsNone(evaluation.root_dir)
+    self.assertIn('named_eval/class_test', base.registered_names())
+
+    with self.assertRaisesRegex(ValueError, 'Unsupported type.*'):
+      @base.register('named_eval/bad_class')
+      class Foo:  # pylint: disable=unused-variable
+        pass
+
+  def test_named_eval_functor(self):
+
+    @base.register('named_eval/functor_test')
+    def my_eval():
+      return base.Evaluation(
+          inputs=base.as_inputs([
+              pg.Dict(question='Compute 1 + 1'),
+          ]),
+          method='query',
+          prompt=pg.oneof([
+              lf.Template('{{example.question}}'),
+              lf.Template('Hello {{example.question}}'),
+          ]),
+          schema_fn=answer_schema(),
+      )
+
+    self.assertTrue(issubclass(my_eval, base.Evaluable))
+    [evaluation] = base.get_evaluations('named_eval/functor_test')
+    self.assertIn('named_eval/functor_test', base.registered_names())
+    self.assertIsInstance(evaluation, my_eval)
+    self.assertIsNone(evaluation.root_dir, None)
+
+    self.assertTrue(
+        pg.eq(base.get_evaluations('named_eval/functor.*'), [evaluation])
+    )
+    self.assertEqual(base.get_evaluations('named_eval/non_existent'), [])
+
+    with self.assertRaisesRegex(TypeError, 'The return value .*'):
+      @base.register('named_eval/bad_return_type')
+      def bad_eval():  # pylint: disable=unused-variable
+        return 1
+
+  def test_run(self):
+    @base.register('test/run')
+    def test_run():  # pylint: disable=unused-variable
+      lm = fake.StaticResponse('Solution(final_answer=2)')
+      return eval_set('run_test', 'query', schema_fn=answer_schema(), lm=lm)
+
+    e = base.run(
+        tempfile.gettempdir(),
+        ['test/run'],
+        id_regex='run_test.*',
+        mode='dryrun',
+        print_definition=True,
+    )
+    self.assertEqual(
+        e.leaf_nodes[0].dir,
+        os.path.join(tempfile.gettempdir(), e.leaf_nodes[0].id),
+    )
+    self.assertTrue(
+        pg.eq(
+            e.leaf_nodes[0].lm, fake.StaticResponse('Solution(final_answer=2)')
+        )
+    )
+
+    @pg.patcher()
+    def bad_lm(unused_eval):  # pylint: disable=unused-variable
+      return dict(lm=fake.StaticResponse('efg'))
+
+    e = base.run(
+        tempfile.gettempdir(),
+        [test_run()],
+        filter='Evaluation.*',
+        patches=['bad_lm']
+    )
+    self.assertTrue(pg.eq(e.leaf_nodes[0].lm, fake.StaticResponse('efg')))
+
+    with self.assertRaisesRegex(ValueError, 'No evaluations found'):
+      base.run(tempfile.gettempdir(), ['test/non_existent'])
+
+
 if __name__ == '__main__':
   unittest.main()
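
The new `NamedEvaluationTest` documents a registry-based workflow: decorate an `Evaluation` subclass or a factory function with `@base.register(name)`, then look it up with `base.get_evaluations` or run it by name via `base.run`. A hedged sketch of the pattern (the registration name is illustrative, and `answer_schema()` is a helper defined in this test module):

import langfun as lf
import pyglove as pg
from langfun.core.eval import base

@base.register('my_project/arithmetic')  # illustrative name
def arithmetic_eval():
  return base.Evaluation(
      inputs=base.as_inputs([pg.Dict(question='Compute 1 + 1')]),
      method='query',
      prompt=lf.Template('{{example.question}}'),
      schema_fn=answer_schema(),  # test-module helper, as in the tests above
  )

assert 'my_project/arithmetic' in base.registered_names()
[evaluation] = base.get_evaluations('my_project/arithmetic')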
@@ -41,8 +41,8 @@ class Matching(base.Evaluation):
     """Returns the answer from the structure output."""
 
   @property
-  def matches(self) -> list[tuple[Any, Any, lf.Message]]:
-    """Returns the matches examples, outputs and the output messages."""
+  def matches(self) -> list[tuple[int, Any, Any, lf.Message]]:
+    """Returns the matches IDs, examples, outputs and the output messages."""
     return self._matches
 
   @property
@@ -57,7 +57,7 @@ class Matching(base.Evaluation):
     return self.num_matches / self.num_completed
 
   @property
-  def mismatches(self) -> list[tuple[Any, Any, lf.Message]]:
+  def mismatches(self) -> list[tuple[int, Any, Any, lf.Message]]:
     """Returns the mismatches examples, outputs and output messages."""
     return self._mismatches
 
@@ -86,34 +86,51 @@ class Matching(base.Evaluation):
     self._matches = []
     self._mismatches = []
 
-  def audit(self, example: Any, output: Any, message: lf.Message) -> None:
+  def audit_processed(
+      self, example_idx: int, example: Any, output: Any, message: lf.Message,
+      dryrun: bool = False
+  ) -> None:
     groundtruth = self.groundtruth(example)
     answer = self.answer(output, example)
+
+    if dryrun:
+      lf.console.write('')
+      lf.console.write(
+          str(groundtruth),
+          title='GROUDTRUTH',
+          color='green',
+      )
+      lf.console.write('')
+      lf.console.write(
+          str(answer),
+          title='ANSWER',
+          color='blue',
+      )
+
     if self.match(answer, groundtruth):
-      self._matches.append((example, output, message))
+      self._matches.append((example_idx, example, output, message))
     else:
-      self._mismatches.append((example, output, message))
+      self._mismatches.append((example_idx, example, output, message))
 
   def match(self, answer: Any, groundtruth: Any) -> bool:
     """Matches answer against the groundtruth. Subclasses can override."""
     return pg.eq(answer, groundtruth)
 
-  def _status(self, progress: lf.concurrent.Progress) -> dict[str, Any]:
+  def _eval_status(self, progress: lf.concurrent.Progress) -> dict[str, Any]:
     del progress
     return {
-        'Model': self.lm.model_id,
-        'Matches': f'%.{self.report_precision}f%% (%d/%d)' % (
-            self.match_rate * 100,
+        'Matches': '%s (%d/%d)' % (
+            self._format_rate(self.match_rate),
            self.num_matches,
            self.num_completed,
        ),
-        'Mismatches': f'%.{self.report_precision}f%% (%d/%d)' % (
-            self.mismatch_rate * 100,
+        'Mismatches': '%s (%d/%d)' % (
+            self._format_rate(self.mismatch_rate),
            self.num_mismatches,
            self.num_completed,
        ),
-        'Failed': f'%.{self.report_precision}f%% (%d/%d)' % (
-            self.failure_rate * 100,
+        'Failed': '%s (%d/%d)' % (
+            self._format_rate(self.failure_rate),
            self.num_failures,
            self.num_completed,
        ),
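
Since `matches` and `mismatches` now lead each tuple with the example index, downstream consumers need a four-way unpack. A short sketch of the new shape, where `matching_eval` stands for any completed `Matching` evaluation:

# Tuples are now (example_idx, example, output, message).
for example_idx, example, output, message in matching_eval.matches:
  print(example_idx, message.text)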
@@ -123,24 +140,25 @@ class Matching(base.Evaluation):
     assert self.result is not None
     m = self.result.metrics
     return (
-        f'COMPLETED(%s): Matches=%.{self.report_precision}f%% (%d/%d)'
-        f' Mismatches=%.{self.report_precision}f%% (%d/%d)'
-        f' Failures=%.{self.report_precision}f%% (%d/%d)'
+        'COMPLETED(%s):'
+        ' Matches=%s (%d/%d)'
+        ' Mismatches=%s (%d/%d)'
+        ' Failures=%s (%d/%d)'
     ) % (
         run_status,
-        m.match_rate * 100,
+        self._format_rate(m.match_rate),
         m.num_matches,
         m.total,
-        m.mismatch_rate * 100,
+        self._format_rate(m.mismatch_rate),
         m.num_mismatches,
         m.total,
-        m.failure_rate * 100,
+        self._format_rate(m.failure_rate),
         m.failures,
         m.total,
     )
 
-  def summarize(self) -> pg.Dict:
-    result = super().summarize()
+  def finalize(self) -> pg.Dict:
+    result = super().finalize()
     result.metrics.update(
         num_matches=self.num_matches,
         match_rate=self.match_rate,
@@ -154,33 +172,6 @@ class Matching(base.Evaluation):
   ) -> None:
     super().save(definition, result, report)
 
-    if result:
-
-      def force_dict(v):
-        return pg.object_utils.json_conversion.strip_types(pg.to_json(v))
-
-      # Save matches.
-      pg.save(
-          [
-              # We force the output to be dict as its type may be defined
-              # within functors which could be deserialized.
-              pg.Dict(input=input, output=force_dict(output))
-              for input, output, _ in self.matches
-          ],
-          os.path.join(self.dir, Matching.MATCHES_JSON),
-      )
-
-      # Save mismatches.
-      pg.save(
-          [
-              # We force the output to be dict as its type may be defined
-              # within functors which could be deserialized.
-              pg.Dict(input=input, output=force_dict(output))
-              for input, output, _ in self.mismatches
-          ],
-          os.path.join(self.dir, Matching.MISMATCHES_JSON),
-      )
-
     if report:
       pg.save(
           self._html([self._render_result, self._render_matches]),
@@ -201,9 +192,9 @@ class Matching(base.Evaluation):
   def _render_result_row(self, s: io.StringIO):
     super()._render_result_row(s)
     s.write(
-        '<td><span style="color:red">%s</span>%s</td>'
+        '<td><span style="color:orange">%s</span>%s</td>'
         % (
-            f'%.{self.report_precision}f%% ' % (self.mismatch_rate * 100),
+            self._format_rate(self.mismatch_rate),
             '<a href="%s">(%d/%d)</a>'
             % (self.mismatches_link, self.num_mismatches, self.num_completed),
         )
@@ -211,37 +202,33 @@ class Matching(base.Evaluation):
     s.write(
         '<td><span style="color:green">%s</span>%s</td>'
         % (
-            f'%.{self.report_precision}f%% ' % (self.match_rate * 100),
+            self._format_rate(self.match_rate),
             '<a href="%s">(%d/%d)</a>'
             % (self.matches_link, self.num_matches, self.num_completed),
         )
     )
 
-  def _render_metric(self, s: io.StringIO) -> None:
+  def _render_summary_metrics(self, s: io.StringIO) -> None:
     """Renders metrics in HTML."""
     assert self.result is not None
     m = self.result.metrics
-    s.write(
-        '<a title="Matches (%d/%d)" href="%s" style="color:green">%s</a>'
-        % (
-            m.num_matches,
-            m.total,
-            self.matches_link,
-            f'%.{self.report_precision}f%% ' % (m.match_rate * 100),
-        )
+    self._render_link(
+        s,
+        'Matches (%d/%d)' % (m.num_matches, m.total),
+        self._format_rate(m.match_rate),
+        'color:green',
+        lambda: self.matches_link,
     )
     s.write(' | ')
-    s.write(
-        '<a title="Mismatches (%d/%d)" href="%s" style="color:orange">%s</a>'
-        % (
-            m.num_mismatches,
-            m.total,
-            self.mismatches_link,
-            f'%.{self.report_precision}f%% ' % (m.mismatch_rate * 100),
-        )
+    self._render_link(
+        s,
+        'Mismatches (%d/%d)' % (m.num_mismatches, m.total),
+        self._format_rate(m.mismatch_rate),
+        'color:orange',
+        lambda: self.mismatches_link,
     )
     s.write(' | ')
-    super()._render_metric(s)
+    super()._render_summary_metrics(s)
 
   def _render_matches(self, s: io.StringIO) -> None:
     """Formats the matched cases into html."""
@@ -254,12 +241,29 @@ class Matching(base.Evaluation):
         '<td>Prompt/Response Chain</td>'
         '</tr>'
     )
-    for i, (example, output, message) in enumerate(self.matches):
+    def _maybe_html(v, root_indent: int):
+      del root_indent
+      if hasattr(v, '_repr_html_'):
+        return v._repr_html_()  # pylint: disable=protected-access
+      # Fall back to the default format.
+      return None
+
+    for i, (_, example, output, message) in enumerate(self.matches):
       bgcolor = 'white' if i % 2 == 0 else '#DDDDDD'
       s.write(f'<tr style="background-color: {bgcolor}"><td>{i + 1}</td>')
-      input_str = pg.format(example, verbose=False)
+      input_str = pg.Html.escape(
+          pg.format(
+              example, verbose=False, max_bytes_len=32,
+              custom_format=_maybe_html
+          )
+      )
       s.write(f'<td style="color:green;white-space:pre-wrap">{input_str}</td>')
-      output_str = pg.format(output, verbose=False)
+      output_str = pg.Html.escape(
+          pg.format(
+              output, verbose=False, max_bytes_len=32,
+              custom_format=_maybe_html
+          )
+      )
       s.write(f'<td style="color:blue;white-space:pre-wrap">{output_str}</td>')
       s.write('<td>')
       self._render_message(message, s)
@@ -279,12 +283,12 @@ class Matching(base.Evaluation):
         '</tr>'
     )
 
-    for i, (example, output, message) in enumerate(self.mismatches):
+    for i, (_, example, output, message) in enumerate(self.mismatches):
       bgcolor = 'white' if i % 2 == 0 else '#DDDDDD'
       s.write(f'<tr style="background-color: {bgcolor}"><td>{i + 1}</td>')
-      input_str = pg.format(example, verbose=False)
+      input_str = pg.format(example, verbose=False, max_bytes_len=32)
       s.write(f'<td style="color:green;white-space:pre-wrap">{input_str}</td>')
-      output_str = pg.format(output, verbose=False)
+      output_str = pg.format(output, verbose=False, max_bytes_len=32)
       s.write(
           f'<td style="color:magenta;white-space:pre-wrap">{output_str}</td>'
       )
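
Both renderers now clamp long byte values with `pg.format(..., max_bytes_len=32)`, and the match renderer additionally escapes the formatted text with `pg.Html.escape` before embedding it in table cells. A small sketch of that combination (the example value is hypothetical):

import pyglove as pg

example = pg.Dict(name='demo', payload=b'\x00' * 1024)  # hypothetical value
text = pg.format(example, verbose=False, max_bytes_len=32)  # long bytes elided
safe = pg.Html.escape(text)  # escaped before writing into the report HTML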