langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501150804__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registries. It is provided for informational purposes only.
- langfun/__init__.py +20 -2
- langfun/core/__init__.py +16 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -21
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +63 -2
- langfun/core/component_test.py +53 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +17 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +16 -1
- langfun/core/eval/base.py +622 -174
- langfun/core/eval/base_test.py +200 -54
- langfun/core/eval/matching.py +63 -76
- langfun/core/eval/matching_test.py +17 -8
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +26 -26
- langfun/core/eval/scoring_test.py +19 -2
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +4 -17
- langfun/core/langfunc_test.py +22 -6
- langfun/core/language_model.py +577 -39
- langfun/core/language_model_test.py +470 -56
- langfun/core/llms/__init__.py +87 -16
- langfun/core/llms/anthropic.py +312 -87
- langfun/core/llms/anthropic_test.py +71 -3
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +53 -2
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +11 -7
- langfun/core/llms/fake_test.py +14 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -202
- langfun/core/llms/groq.py +160 -144
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +15 -42
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +395 -203
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +30 -395
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -26
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +12 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +7 -6
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +60 -27
- langfun/core/structured/function_generation_test.py +72 -2
- langfun/core/structured/mapping.py +97 -47
- langfun/core/structured/mapping_test.py +90 -2
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +53 -9
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
- langfun/core/structured/schema.py +204 -97
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_test.py +130 -29
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +115 -1
- langfun/core/template_test.py +71 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +10 -2
- langfun-0.1.2.dev202501150804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501150804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -238
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
- langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/top_level.txt +0 -0
langfun/core/eval/base_test.py
CHANGED
@@ -194,6 +194,7 @@ class EvaluationTest(unittest.TestCase):
             cache_seed=0,
             score=1.0,
             logprobs=None,
+            is_cached=False,
             usage=lf.LMSamplingUsage(387, 24, 411),
             tags=['lm-response', 'lm-output', 'transformed'],
         ),
@@ -220,7 +221,18 @@ class EvaluationTest(unittest.TestCase):
         cache_stats=dict(
             use_cache=True, num_queries=2, num_hits=0, num_updates=2
         ),
-        metrics=dict(
+        metrics=dict(
+            total=2,
+            failures=1,
+            failure_rate=0.5,
+            oop_failures=1,
+            oop_failure_rate=0.5,
+            non_oop_failures=0,
+            non_oop_failure_rate=0.0,
+            failure_breakdown={
+                'MappingError.SchemaError.TypeError': 1
+            }
+        ),
         usage=dict(
             total_prompt_tokens=774,
             total_completion_tokens=25,
@@ -235,12 +247,20 @@ class EvaluationTest(unittest.TestCase):
         os.path.exists(os.path.join(s.dir, base.Evaluation.EXPERIMENT_JSON)))
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.RESULT_JSON)))
+    self.assertTrue(
+        os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_JSON)))
+    self.assertTrue(
+        os.path.exists(
+            os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_JSON)))
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.CACHE_JSON)))
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
     self.assertTrue(
-        os.path.exists(os.path.join(s.dir, base.Evaluation.
+        os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_HTML)))
+    self.assertTrue(
+        os.path.exists(
+            os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_HTML)))
     self.assertTrue(
         os.path.exists(os.path.join(s.root_dir, base.Evaluation.SUMMARY_HTML))
     )
@@ -249,7 +269,7 @@ class EvaluationTest(unittest.TestCase):
         s.root_dir, base.Evaluation.SUMMARY_HTML.replace('.html', '.json')
     )
     self.assertTrue(os.path.exists(summary_json))
-    summary = pg.load(summary_json,
+    summary = pg.load(summary_json, auto_dict=True)
     self.assertIn('Evaluation', summary)
     self.assertEqual(len(summary['Evaluation']), 1)
     self.assertIsNotNone(summary['Evaluation'][0].experiment)
@@ -274,7 +294,10 @@ class EvaluationTest(unittest.TestCase):
     self.assertFalse(
         os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
     self.assertFalse(
-        os.path.exists(os.path.join(s.dir, base.Evaluation.
+        os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_HTML)))
+    self.assertFalse(
+        os.path.exists(
+            os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_HTML)))

   def test_load(self):
     lm = fake.StaticResponse('Solution(final_answer=2)')
@@ -312,7 +335,16 @@ class EvaluationTest(unittest.TestCase):
             cache_stats=dict(
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
-            metrics=dict(
+            metrics=dict(
+                total=2,
+                failures=0,
+                failure_rate=0.0,
+                oop_failures=0,
+                oop_failure_rate=0.0,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={},
+            ),
             usage=s.children[1].result.usage,
         ),
     },
@@ -363,7 +395,18 @@ class EvaluationTest(unittest.TestCase):
             cache_stats=dict(
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
-            metrics=dict(
+            metrics=dict(
+                total=2,
+                failures=1,
+                failure_rate=0.5,
+                oop_failures=1,
+                oop_failure_rate=0.5,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={
+                    'MappingError.SchemaError.TypeError': 1
+                }
+            ),
             usage=s.children[0].result.usage,
         ),
         s.children[1].id: dict(
@@ -378,7 +421,18 @@ class EvaluationTest(unittest.TestCase):
             cache_stats=dict(
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
-            metrics=dict(
+            metrics=dict(
+                total=2,
+                failures=1,
+                failure_rate=0.5,
+                oop_failures=1,
+                oop_failure_rate=0.5,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={
+                    'MappingError.SchemaError.TypeError': 1
+                }
+            ),
             usage=s.children[1].result.usage,
         ),
     },
@@ -475,7 +529,7 @@ class SuiteTest(unittest.TestCase):
     self.assertEqual(s.hash, '26e6cc25')
     s.run()
     expected = {
-
+        'Evaluation@0fade07d': dict(
             experiment_setup=dict(
                 id=s.children[0].id,
                 dir=s.children[0].dir,
@@ -487,48 +541,46 @@ class SuiteTest(unittest.TestCase):
             cache_stats=dict(
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
-            metrics=dict(
+            metrics=dict(
+                total=2,
+                failures=1,
+                failure_rate=0.5,
+                oop_failures=1,
+                oop_failure_rate=0.5,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={
+                    'MappingError.SchemaError.TypeError': 1
+                }
+            ),
             usage=s.children[0].result.usage,
         ),
-
-
-
-
-
-
-
-
-                prompt_template='{{example.question}}',
-                method='call',
-                schema_fn='answer_schema()',
-            ),
-            cache_stats=dict(
-                use_cache=True, num_queries=4, num_hits=1, num_updates=3
-            ),
-            metrics=dict(total=2, failures=2, failure_rate=1.0),
-            usage=s.children[1].children[0].result.usage,
+        'Evaluation@ae86c703': dict(
+            experiment_setup=dict(
+                id=s.children[1].children[0].id,
+                dir=s.children[1].children[0].dir,
+                model='StaticSequence',
+                prompt_template='{{example.question}}',
+                method='call',
+                schema_fn='answer_schema()',
             ),
-
-
-        .id: dict(
-            experiment_setup=dict(
-                id=s.children[1].children[2].id,
-                dir=s.children[1].children[2].dir,
-                model='StaticSequence',
-                prompt_template='{{example.question}}',
-                method='query',
-                schema_fn='answer_schema()',
-            ),
-            cache_stats=dict(
-                use_cache=True,
-                num_queries=2,
-                num_hits=0,
-                num_updates=2,
-            ),
-            metrics=dict(total=2, failures=1, failure_rate=0.5),
-            usage=s.children[1].children[2].result.usage,
+            cache_stats=dict(
+                use_cache=True, num_queries=4, num_hits=0, num_updates=4
             ),
-
+            metrics=dict(
+                total=2,
+                failures=2,
+                failure_rate=1.0,
+                oop_failures=2,
+                oop_failure_rate=1.0,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={
+                    'MappingError.SchemaError.TypeError': 2
+                }
+            ),
+            usage=s.children[1].children[0].result.usage,
+        ),
     }
     self.assertEqual(s.result, expected)

@@ -547,6 +599,14 @@ class InputsFrom(unittest.TestCase):
     pg.save([1, 2, 3], path)
     self.assertEqual(base.inputs_from(path)(), [1, 2, 3])

+    path = os.path.join(tmp_dir, 'input_file.jsonl')
+    with pg.open_jsonl(path, 'w') as f:
+      f.add(pg.Dict(x=1))
+      f.add(dict(y=2))
+    self.assertEqual(
+        base.inputs_from(path)(), [pg.Dict(x=1), dict(y=2)]
+    )
+
   def test_inputs_from_multiple_files(self):
     tmp_dir = tempfile.gettempdir()
     path1 = os.path.join(tmp_dir, 'input_file1.json')
@@ -698,16 +758,102 @@ class SummaryTest(unittest.TestCase):
     self.assertTrue(pg.io.path_exists(summary_file))


-class
+class NamedEvaluationTest(unittest.TestCase):

-  def
-
-
-
-
+  def test_named_eval_class(self):
+
+    @base.register('named_eval/class_test')
+    class MyEval(base.Evaluation):
+      inputs = base.as_inputs([
+          pg.Dict(question='Compute 1 + 1'),
+      ])
+      method = 'query'
+      prompt = pg.oneof([
+          lf.Template('{{example.question}}'),
+          lf.Template('Hello {{example.question}}'),
+      ])
+      schema_fn = answer_schema()
+
+    [evaluation] = base.get_evaluations('named_eval/class_test')
+    self.assertIsInstance(evaluation, MyEval)
+    self.assertIsNone(evaluation.dir)
+    self.assertIsNone(evaluation.root_dir)
+    self.assertIn('named_eval/class_test', base.registered_names())
+
+    with self.assertRaisesRegex(ValueError, 'Unsupported type.*'):
+      @base.register('named_eval/bad_class')
+      class Foo:  # pylint: disable=unused-variable
+        pass
+
+  def test_named_eval_functor(self):
+
+    @base.register('named_eval/functor_test')
+    def my_eval():
+      return base.Evaluation(
+          inputs=base.as_inputs([
+              pg.Dict(question='Compute 1 + 1'),
+          ]),
+          method='query',
+          prompt=pg.oneof([
+              lf.Template('{{example.question}}'),
+              lf.Template('Hello {{example.question}}'),
+          ]),
+          schema_fn=answer_schema(),
       )
-
-
+
+    self.assertTrue(issubclass(my_eval, base.Evaluable))
+    [evaluation] = base.get_evaluations('named_eval/functor_test')
+    self.assertIn('named_eval/functor_test', base.registered_names())
+    self.assertIsInstance(evaluation, my_eval)
+    self.assertIsNone(evaluation.root_dir, None)
+
+    self.assertTrue(
+        pg.eq(base.get_evaluations('named_eval/functor.*'), [evaluation])
+    )
+    self.assertEqual(base.get_evaluations('named_eval/non_existent'), [])
+
+    with self.assertRaisesRegex(TypeError, 'The return value .*'):
+      @base.register('named_eval/bad_return_type')
+      def bad_eval():  # pylint: disable=unused-variable
+        return 1
+
+  def test_run(self):
+    @base.register('test/run')
+    def test_run():  # pylint: disable=unused-variable
+      lm = fake.StaticResponse('Solution(final_answer=2)')
+      return eval_set('run_test', 'query', schema_fn=answer_schema(), lm=lm)
+
+    e = base.run(
+        tempfile.gettempdir(),
+        ['test/run'],
+        id_regex='run_test.*',
+        mode='dryrun',
+        print_definition=True,
+    )
+    self.assertEqual(
+        e.leaf_nodes[0].dir,
+        os.path.join(tempfile.gettempdir(), e.leaf_nodes[0].id),
+    )
+    self.assertTrue(
+        pg.eq(
+            e.leaf_nodes[0].lm, fake.StaticResponse('Solution(final_answer=2)')
+        )
+    )
+
+    @pg.patcher()
+    def bad_lm(unused_eval):  # pylint: disable=unused-variable
+      return dict(lm=fake.StaticResponse('efg'))
+
+    e = base.run(
+        tempfile.gettempdir(),
+        [test_run()],
+        filter='Evaluation.*',
+        patches=['bad_lm']
+    )
+    self.assertTrue(pg.eq(e.leaf_nodes[0].lm, fake.StaticResponse('efg')))
+
+    with self.assertRaisesRegex(ValueError, 'No evaluations found'):
+      base.run(tempfile.gettempdir(), ['test/non_existent'])


 if __name__ == '__main__':
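The base_test.py changes above do two things: the expected `metrics` dict gains a failure split (`oop_failures`, `non_oop_failures`, `failure_breakdown`, where the breakdown keys suggest "OOP" failures are structured-output `MappingError`s and "non-OOP" failures are everything else), and a new `NamedEvaluationTest` exercises a registry for named evaluations (`base.register`, `base.get_evaluations`, `base.registered_names`, `base.run`). Below is a minimal sketch of that registry API, not the library's documented usage: the registry name `my_suite/solver` and class `MySolver` are illustrative, and `Solution`/`answer_schema` mirror fixtures defined in the test module.

```python
import langfun as lf
import pyglove as pg

from langfun.core.eval import base
from langfun.core.llms import fake


# Mirrors the Solution / answer_schema fixtures used by base_test.py.
class Solution(pg.Object):
  final_answer: int


@pg.functor()
def answer_schema():
  return Solution


# Register an evaluation under a name ('my_suite/solver' is illustrative).
@base.register('my_suite/solver')
class MySolver(base.Evaluation):
  inputs = base.as_inputs([pg.Dict(question='Compute 1 + 1')])
  method = 'query'
  prompt = lf.Template('{{example.question}}')
  schema_fn = answer_schema()
  lm = fake.StaticResponse('Solution(final_answer=2)')


# Look up registered evaluations by name (or regex)...
[evaluation] = base.get_evaluations('my_suite/solver')
assert 'my_suite/solver' in base.registered_names()

# ...or run them by name under a root directory, here in dryrun mode as the
# tests do ('/tmp/my_evals' is an arbitrary output directory).
base.run('/tmp/my_evals', ['my_suite/solver'], mode='dryrun')
```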
langfun/core/eval/matching.py
CHANGED
@@ -41,8 +41,8 @@ class Matching(base.Evaluation):
     """Returns the answer from the structure output."""

   @property
-  def matches(self) -> list[tuple[Any, Any, lf.Message]]:
-    """Returns the matches examples, outputs and the output messages."""
+  def matches(self) -> list[tuple[int, Any, Any, lf.Message]]:
+    """Returns the matches IDs, examples, outputs and the output messages."""
     return self._matches

   @property
@@ -57,7 +57,7 @@ class Matching(base.Evaluation):
     return self.num_matches / self.num_completed

   @property
-  def mismatches(self) -> list[tuple[Any, Any, lf.Message]]:
+  def mismatches(self) -> list[tuple[int, Any, Any, lf.Message]]:
     """Returns the mismatches examples, outputs and output messages."""
     return self._mismatches

@@ -87,7 +87,8 @@ class Matching(base.Evaluation):
     self._mismatches = []

   def audit_processed(
-      self, example: Any, output: Any, message: lf.Message,
+      self, example_idx: int, example: Any, output: Any, message: lf.Message,
+      dryrun: bool = False
   ) -> None:
     groundtruth = self.groundtruth(example)
     answer = self.answer(output, example)
@@ -107,30 +108,29 @@ class Matching(base.Evaluation):
     )

     if self.match(answer, groundtruth):
-      self._matches.append((example, output, message))
+      self._matches.append((example_idx, example, output, message))
     else:
-      self._mismatches.append((example, output, message))
+      self._mismatches.append((example_idx, example, output, message))

   def match(self, answer: Any, groundtruth: Any) -> bool:
     """Matches answer against the groundtruth. Subclasses can override."""
     return pg.eq(answer, groundtruth)

-  def
+  def _eval_status(self, progress: lf.concurrent.Progress) -> dict[str, Any]:
     del progress
     return {
-        '
-
-            self.match_rate * 100,
+        'Matches': '%s (%d/%d)' % (
+            self._format_rate(self.match_rate),
             self.num_matches,
             self.num_completed,
         ),
-        'Mismatches':
-            self.mismatch_rate
+        'Mismatches': '%s (%d/%d)' % (
+            self._format_rate(self.mismatch_rate),
             self.num_mismatches,
             self.num_completed,
         ),
-        'Failed':
-            self.failure_rate
+        'Failed': '%s (%d/%d)' % (
+            self._format_rate(self.failure_rate),
             self.num_failures,
             self.num_completed,
         ),
@@ -140,24 +140,25 @@ class Matching(base.Evaluation):
     assert self.result is not None
     m = self.result.metrics
     return (
-
-
-
+        'COMPLETED(%s):'
+        ' Matches=%s (%d/%d)'
+        ' Mismatches=%s (%d/%d)'
+        ' Failures=%s (%d/%d)'
     ) % (
         run_status,
-        m.match_rate
+        self._format_rate(m.match_rate),
         m.num_matches,
         m.total,
-        m.mismatch_rate
+        self._format_rate(m.mismatch_rate),
         m.num_mismatches,
         m.total,
-        m.failure_rate
+        self._format_rate(m.failure_rate),
         m.failures,
         m.total,
     )

-  def
-    result = super().
+  def finalize(self) -> pg.Dict:
+    result = super().finalize()
     result.metrics.update(
         num_matches=self.num_matches,
         match_rate=self.match_rate,
@@ -171,33 +172,6 @@ class Matching(base.Evaluation):
   ) -> None:
     super().save(definition, result, report)

-    if result:
-      # Save matches.
-      pg.save(
-          [
-              pg.Dict(input=input, output=output)
-              for input, output, _ in self.matches
-          ],
-          os.path.join(self.dir, Matching.MATCHES_JSON),
-          # We force the input and output to be dict so it does not depend on
-          # the downstream to serialize.
-          force_dict=True,
-      )
-
-      # Save mismatches.
-      pg.save(
-          [
-              # We force the output to be dict as its type may be defined
-              # within functors which could be deserialized.
-              pg.Dict(input=input, output=output)
-              for input, output, _ in self.mismatches
-          ],
-          os.path.join(self.dir, Matching.MISMATCHES_JSON),
-          # We force the input and output to be dict so it does not depend on
-          # the downstream to serialize.
-          force_dict=True,
-      )
-
     if report:
       pg.save(
           self._html([self._render_result, self._render_matches]),
@@ -218,9 +192,9 @@ class Matching(base.Evaluation):
   def _render_result_row(self, s: io.StringIO):
     super()._render_result_row(s)
     s.write(
-        '<td><span style="color:
+        '<td><span style="color:orange">%s</span>%s</td>'
         % (
-
+            self._format_rate(self.mismatch_rate),
             '<a href="%s">(%d/%d)</a>'
             % (self.mismatches_link, self.num_mismatches, self.num_completed),
         )
@@ -228,37 +202,33 @@ class Matching(base.Evaluation):
     s.write(
         '<td><span style="color:green">%s</span>%s</td>'
         % (
-
+            self._format_rate(self.match_rate),
             '<a href="%s">(%d/%d)</a>'
             % (self.matches_link, self.num_matches, self.num_completed),
         )
     )

-  def
+  def _render_summary_metrics(self, s: io.StringIO) -> None:
     """Renders metrics in HTML."""
     assert self.result is not None
     m = self.result.metrics
-
-
-    % (
-
-
-
-        f'%.{self.report_precision}f%% ' % (m.match_rate * 100),
-    )
+    self._render_link(
+        s,
+        'Matches (%d/%d)' % (m.num_matches, m.total),
+        self._format_rate(m.match_rate),
+        'color:green',
+        lambda: self.matches_link,
     )
     s.write(' | ')
-
-
-    % (
-
-
-
-        f'%.{self.report_precision}f%% ' % (m.mismatch_rate * 100),
-    )
+    self._render_link(
+        s,
+        'Mismatches (%d/%d)' % (m.num_mismatches, m.total),
+        self._format_rate(m.mismatch_rate),
+        'color:orange',
+        lambda: self.mismatches_link,
     )
     s.write(' | ')
-    super().
+    super()._render_summary_metrics(s)

   def _render_matches(self, s: io.StringIO) -> None:
     """Formats the matched cases into html."""
@@ -271,12 +241,29 @@ class Matching(base.Evaluation):
         '<td>Prompt/Response Chain</td>'
         '</tr>'
     )
-
+    def _maybe_html(v, root_indent: int):
+      del root_indent
+      if hasattr(v, '_repr_html_'):
+        return v._repr_html_()  # pylint: disable=protected-access
+      # Fall back to the default format.
+      return None
+
+    for i, (_, example, output, message) in enumerate(self.matches):
       bgcolor = 'white' if i % 2 == 0 else '#DDDDDD'
       s.write(f'<tr style="background-color: {bgcolor}"><td>{i + 1}</td>')
-      input_str = pg.
+      input_str = pg.Html.escape(
+          pg.format(
+              example, verbose=False, max_bytes_len=32,
+              custom_format=_maybe_html
+          )
+      )
       s.write(f'<td style="color:green;white-space:pre-wrap">{input_str}</td>')
-      output_str = pg.
+      output_str = pg.Html.escape(
+          pg.format(
+              output, verbose=False, max_bytes_len=32,
+              custom_format=_maybe_html
+          )
+      )
       s.write(f'<td style="color:blue;white-space:pre-wrap">{output_str}</td>')
       s.write('<td>')
       self._render_message(message, s)
@@ -296,12 +283,12 @@ class Matching(base.Evaluation):
         '</tr>'
     )

-    for i, (example, output, message) in enumerate(self.mismatches):
+    for i, (_, example, output, message) in enumerate(self.mismatches):
       bgcolor = 'white' if i % 2 == 0 else '#DDDDDD'
       s.write(f'<tr style="background-color: {bgcolor}"><td>{i + 1}</td>')
-      input_str = pg.format(example, verbose=False)
+      input_str = pg.format(example, verbose=False, max_bytes_len=32)
       s.write(f'<td style="color:green;white-space:pre-wrap">{input_str}</td>')
-      output_str = pg.format(output, verbose=False)
+      output_str = pg.format(output, verbose=False, max_bytes_len=32)
       s.write(
           f'<td style="color:magenta;white-space:pre-wrap">{output_str}</td>'
       )
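Two behavioral notes on the matching.py hunks above: entries of `Matching.matches` and `Matching.mismatches` now lead with the example index (and `audit_processed` gains an `example_idx` parameter plus a `dryrun` flag), and the explicit matches/mismatches JSON dumps were removed from `Matching.save()`. A small sketch of consuming the new tuple shape, where `s` stands for a hypothetical, already-run `Matching` evaluation:

```python
# Each entry is now (example_idx, example, output, message); before this
# change the tuples were (example, output, message).
for example_idx, example, output, message in s.matches:
  print(f'match #{example_idx}: {output}')

for example_idx, example, output, message in s.mismatches:
  # groundtruth() is the existing accessor used by audit_processed() above.
  print(f'mismatch #{example_idx}: {output!r} != {s.groundtruth(example)!r}')
```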
langfun/core/eval/matching_test.py
CHANGED
@@ -120,6 +120,13 @@ class MatchingTest(unittest.TestCase):
         total=4,
         failures=1,
         failure_rate=0.25,
+        oop_failures=1,
+        oop_failure_rate=0.25,
+        non_oop_failures=0,
+        non_oop_failure_rate=0.0,
+        failure_breakdown={
+            'MappingError.SchemaError.TypeError': 1
+        },
         num_matches=2,
         match_rate=0.5,
         num_mismatches=1,
@@ -145,22 +152,17 @@ class MatchingTest(unittest.TestCase):
             os.path.join(s.dir, matching.Matching.CACHE_JSON)
         )
     )
-    self.assertTrue(
-        os.path.exists(
-            os.path.join(s.dir, matching.Matching.MATCHES_JSON)
-        )
-    )
     self.assertTrue(
         os.path.exists(
             os.path.join(
-                s.dir, matching.Matching.
+                s.dir, matching.Matching.OOP_FAILURES_JSON
             )
         )
     )
     self.assertTrue(
         os.path.exists(
             os.path.join(
-                s.dir, matching.Matching.
+                s.dir, matching.Matching.NON_OOP_FAILURES_JSON
             )
         )
     )
@@ -175,7 +177,14 @@ class MatchingTest(unittest.TestCase):
     self.assertTrue(
         os.path.exists(
             os.path.join(
-                s.dir, matching.Matching.
+                s.dir, matching.Matching.OOP_FAILURES_HTML
+            )
+        )
+    )
+    self.assertTrue(
+        os.path.exists(
+            os.path.join(
+                s.dir, matching.Matching.NON_OOP_FAILURES_HTML
             )
         )
     )
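For reference, after this release the metrics dict asserted throughout these tests takes the shape below. This is a sketch assembled from the assertions above, not an excerpt from the library docs: the values come from the four-example `MatchingTest` fixture, and the `oop_`/`non_oop_` split appears to separate structured-output mapping failures (`MappingError` subclasses) from all other errors.

```python
# Shape of result.metrics for a Matching evaluation, per the assertions above.
metrics = dict(
    total=4,                   # examples evaluated
    failures=1,
    failure_rate=0.25,
    oop_failures=1,            # e.g. 'MappingError.SchemaError.TypeError'
    oop_failure_rate=0.25,
    non_oop_failures=0,        # failures of any other kind
    non_oop_failure_rate=0.0,
    failure_breakdown={'MappingError.SchemaError.TypeError': 1},
    num_matches=2,             # added by Matching.finalize()
    match_rate=0.5,
    num_mismatches=1,
    # ...plus mismatch_rate, also added by Matching.finalize().
)
```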