langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240511__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of langfun might be problematic.
- langfun/__init__.py +7 -0
- langfun/core/__init__.py +1 -0
- langfun/core/coding/python/correction.py +0 -7
- langfun/core/component.py +6 -0
- langfun/core/component_test.py +1 -0
- langfun/core/eval/__init__.py +15 -0
- langfun/core/eval/base.py +665 -95
- langfun/core/eval/base_test.py +224 -53
- langfun/core/eval/matching.py +48 -30
- langfun/core/eval/matching_test.py +25 -3
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +19 -10
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/langfunc.py +1 -22
- langfun/core/langfunc_test.py +10 -4
- langfun/core/language_model.py +130 -24
- langfun/core/language_model_test.py +249 -26
- langfun/core/llms/__init__.py +27 -2
- langfun/core/llms/anthropic.py +263 -0
- langfun/core/llms/anthropic_test.py +167 -0
- langfun/core/llms/cache/in_memory_test.py +37 -28
- langfun/core/llms/fake.py +34 -25
- langfun/core/llms/fake_test.py +122 -11
- langfun/core/llms/google_genai.py +8 -0
- langfun/core/llms/google_genai_test.py +8 -3
- langfun/core/llms/groq.py +260 -0
- langfun/core/llms/groq_test.py +170 -0
- langfun/core/llms/llama_cpp.py +3 -1
- langfun/core/llms/openai.py +100 -81
- langfun/core/llms/openai_test.py +287 -60
- langfun/core/llms/vertexai.py +291 -0
- langfun/core/llms/vertexai_test.py +233 -0
- langfun/core/modalities/image.py +1 -3
- langfun/core/modalities/mime.py +6 -0
- langfun/core/modalities/video.py +6 -5
- langfun/core/structured/__init__.py +5 -0
- langfun/core/structured/completion_test.py +2 -2
- langfun/core/structured/function_generation.py +245 -0
- langfun/core/structured/function_generation_test.py +329 -0
- langfun/core/structured/mapping.py +61 -3
- langfun/core/structured/mapping_test.py +17 -0
- langfun/core/structured/parsing_test.py +18 -13
- langfun/core/structured/prompting.py +61 -12
- langfun/core/structured/prompting_test.py +122 -12
- langfun/core/structured/schema.py +38 -6
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +36 -7
- langfun/core/structured/scoring.py +4 -1
- langfun/core/structured/scoring_test.py +6 -0
- langfun/core/template.py +147 -11
- langfun/core/template_test.py +75 -0
- langfun/core/templates/selfplay_test.py +6 -2
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/METADATA +3 -2
- langfun-0.0.2.dev20240511.dist-info/RECORD +112 -0
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/top_level.txt +0 -0
langfun/core/eval/base_test.py
CHANGED
@@ -101,7 +101,7 @@ class EvaluationTest(unittest.TestCase):
     self.assertEqual(s.dir, os.path.join(s.root_dir, s.id))
     self.assertEqual(s.hash, s.clone().hash)
     # Test persistent hash.
-    self.assertEqual(s.hash, '
+    self.assertEqual(s.hash, 'ae86c703')
     self.assertEqual(
         s.hash, s.clone(override={'max_workers': 2, 'lm.timeout': 20}).hash
     )
@@ -194,6 +194,7 @@ class EvaluationTest(unittest.TestCase):
             cache_seed=0,
             score=1.0,
             logprobs=None,
+            usage=lf.LMSamplingUsage(387, 24, 411),
             tags=['lm-response', 'lm-output', 'transformed'],
         ),
     )
@@ -209,7 +210,7 @@ class EvaluationTest(unittest.TestCase):
         s.result,
         dict(
             experiment_setup=dict(
-                id='Evaluation@
+                id='Evaluation@0fade07d',
                 dir=s.dir,
                 model='StaticSequence',
                 prompt_template='{{example.question}}',
@@ -219,7 +220,26 @@ class EvaluationTest(unittest.TestCase):
             cache_stats=dict(
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
-            metrics=dict(
+            metrics=dict(
+                total=2,
+                failures=1,
+                failure_rate=0.5,
+                oop_failures=1,
+                oop_failure_rate=0.5,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={
+                    'MappingError.SchemaError.TypeError': 1
+                }
+            ),
+            usage=dict(
+                total_prompt_tokens=774,
+                total_completion_tokens=25,
+                num_usages=2,
+                average_prompt_tokens=387,
+                average_completion_tokens=12,
+                average_total_tokens=399,
+            ),
         ),
     )
     self.assertTrue(
@@ -227,14 +247,32 @@ class EvaluationTest(unittest.TestCase):
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.RESULT_JSON)))
     self.assertTrue(
-        os.path.exists(os.path.join(s.dir, base.Evaluation.
+        os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_JSON)))
     self.assertTrue(
-        os.path.exists(
-
+        os.path.exists(
+            os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_JSON)))
+    self.assertTrue(
+        os.path.exists(os.path.join(s.dir, base.Evaluation.CACHE_JSON)))
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
     self.assertTrue(
-        os.path.exists(os.path.join(s.dir, base.Evaluation.
+        os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_HTML)))
+    self.assertTrue(
+        os.path.exists(
+            os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_HTML)))
+    self.assertTrue(
+        os.path.exists(os.path.join(s.root_dir, base.Evaluation.SUMMARY_HTML))
+    )
+    # Check summary JSON.
+    summary_json = os.path.join(
+        s.root_dir, base.Evaluation.SUMMARY_HTML.replace('.html', '.json')
+    )
+    self.assertTrue(os.path.exists(summary_json))
+    summary = pg.load(summary_json, force_dict=True)
+    self.assertIn('Evaluation', summary)
+    self.assertEqual(len(summary['Evaluation']), 1)
+    self.assertIsNotNone(summary['Evaluation'][0].experiment)
+    self.assertIsNotNone(summary['Evaluation'][0].metrics)

   def test_run_wihtout_save(self):
     lm = fake.StaticSequence([
@@ -255,7 +293,10 @@ class EvaluationTest(unittest.TestCase):
     self.assertFalse(
         os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
     self.assertFalse(
-        os.path.exists(os.path.join(s.dir, base.Evaluation.
+        os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_HTML)))
+    self.assertFalse(
+        os.path.exists(
+            os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_HTML)))

   def test_load(self):
     lm = fake.StaticResponse('Solution(final_answer=2)')
@@ -274,8 +315,11 @@ class EvaluationTest(unittest.TestCase):
     s = eval_set(
         'run_filter_test', pg.oneof(['call', 'query']),
         schema_fn=answer_schema(), lm=lm)
+    result = s.run(
+        filter=lambda x: x.method == 'query', dryrun=True, summary=False
+    )
     self.assertEqual(
-
+        result,
         {
             s.children[0].id: None,
             s.children[1].id: dict(
@@ -290,8 +334,18 @@ class EvaluationTest(unittest.TestCase):
                 cache_stats=dict(
                     use_cache=True, num_queries=2, num_hits=0, num_updates=2
                 ),
-                metrics=dict(
-
+                metrics=dict(
+                    total=2,
+                    failures=0,
+                    failure_rate=0.0,
+                    oop_failures=0,
+                    oop_failure_rate=0.0,
+                    non_oop_failures=0,
+                    non_oop_failure_rate=0.0,
+                    failure_breakdown={},
+                ),
+                usage=s.children[1].result.usage,
+            ),
         },
     )

@@ -321,11 +375,10 @@ class EvaluationTest(unittest.TestCase):
         s.children[0].dir, os.path.join(s.root_dir, s.children[0].id)
     )
     # Test persistent hash.
-    self.assertEqual(s.hash, '
+    self.assertEqual(s.hash, 'b66a4e88')

     summary = s.run(verbose=True)
     self.assertEqual(len(summary.evaluations), 2)
-
     self.assertEqual(
         s.result,
         {
@@ -341,7 +394,19 @@ class EvaluationTest(unittest.TestCase):
                 cache_stats=dict(
                     use_cache=True, num_queries=2, num_hits=0, num_updates=2
                 ),
-                metrics=dict(
+                metrics=dict(
+                    total=2,
+                    failures=1,
+                    failure_rate=0.5,
+                    oop_failures=1,
+                    oop_failure_rate=0.5,
+                    non_oop_failures=0,
+                    non_oop_failure_rate=0.0,
+                    failure_breakdown={
+                        'MappingError.SchemaError.TypeError': 1
+                    }
+                ),
+                usage=s.children[0].result.usage,
             ),
             s.children[1].id: dict(
                 experiment_setup=dict(
@@ -355,7 +420,19 @@ class EvaluationTest(unittest.TestCase):
                 cache_stats=dict(
                     use_cache=True, num_queries=2, num_hits=0, num_updates=2
                 ),
-                metrics=dict(
+                metrics=dict(
+                    total=2,
+                    failures=1,
+                    failure_rate=0.5,
+                    oop_failures=1,
+                    oop_failure_rate=0.5,
+                    non_oop_failures=0,
+                    non_oop_failure_rate=0.0,
+                    failure_breakdown={
+                        'MappingError.SchemaError.TypeError': 1
+                    }
+                ),
+                usage=s.children[1].result.usage,
             ),
         },
     )
@@ -448,10 +525,10 @@ class SuiteTest(unittest.TestCase):
         lm=lm
     )
     # Test for persistent hash.
-    self.assertEqual(s.hash, '
+    self.assertEqual(s.hash, '26e6cc25')
     s.run()
     expected = {
-
+        'Evaluation@0fade07d': dict(
             experiment_setup=dict(
                 id=s.children[0].id,
                 dir=s.children[0].dir,
@@ -463,45 +540,46 @@ class SuiteTest(unittest.TestCase):
             cache_stats=dict(
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
-            metrics=dict(
+            metrics=dict(
+                total=2,
+                failures=1,
+                failure_rate=0.5,
+                oop_failures=1,
+                oop_failure_rate=0.5,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={
+                    'MappingError.SchemaError.TypeError': 1
+                }
+            ),
+            usage=s.children[0].result.usage,
         ),
-
-
-
-
-
-
-
-
-                prompt_template='{{example.question}}',
-                method='call',
-                schema_fn='answer_schema()',
-            ),
-            cache_stats=dict(
-                use_cache=True, num_queries=4, num_hits=1, num_updates=3
-            ),
-            metrics=dict(total=2, failures=2, failure_rate=1.0),
+        'Evaluation@ae86c703': dict(
+            experiment_setup=dict(
+                id=s.children[1].children[0].id,
+                dir=s.children[1].children[0].dir,
+                model='StaticSequence',
+                prompt_template='{{example.question}}',
+                method='call',
+                schema_fn='answer_schema()',
+            ),
             ),
-
-
-            .id: dict(
-                experiment_setup=dict(
-                    id=s.children[1].children[2].id,
-                    dir=s.children[1].children[2].dir,
-                    model='StaticSequence',
-                    prompt_template='{{example.question}}',
-                    method='query',
-                    schema_fn='answer_schema()',
-                ),
-                cache_stats=dict(
-                    use_cache=True,
-                    num_queries=2,
-                    num_hits=0,
-                    num_updates=2,
-                ),
-                metrics=dict(total=2, failures=1, failure_rate=0.5),
+            cache_stats=dict(
+                use_cache=True, num_queries=4, num_hits=1, num_updates=3
            ),
-
+            metrics=dict(
+                total=2,
+                failures=2,
+                failure_rate=1.0,
+                oop_failures=2,
+                oop_failure_rate=1.0,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={
+                    'MappingError.SchemaError.TypeError': 2
+                }
+            ),
+            usage=s.children[1].children[0].result.usage,
+        ),
     }
     self.assertEqual(s.result, expected)

@@ -671,5 +749,98 @@ class SummaryTest(unittest.TestCase):
     self.assertTrue(pg.io.path_exists(summary_file))


+class NamedEvaluationTest(unittest.TestCase):
+
+  def test_named_eval_class(self):
+
+    @base.register('named_eval/class_test')
+    class MyEval(base.Evaluation):
+      inputs = base.as_inputs([
+          pg.Dict(question='Compute 1 + 1'),
+      ])
+      method = 'query'
+      prompt = pg.oneof([
+          lf.Template('{{example.question}}'),
+          lf.Template('Hello {{example.question}}'),
+      ])
+      schema_fn = answer_schema()
+
+    evaluation = base.get_evaluation('named_eval/class_test')
+    self.assertIsInstance(evaluation, MyEval)
+    self.assertIsNone(evaluation.dir)
+    self.assertIsNone(evaluation.root_dir)
+    self.assertIn('named_eval/class_test', base.registered_names())
+
+    with self.assertRaisesRegex(ValueError, 'Unsupported type.*'):
+      @base.register('named_eval/bad_class')
+      class Foo:  # pylint: disable=unused-variable
+        pass
+
+  def test_named_eval_functor(self):
+
+    @base.register('named_eval/functor_test')
+    def my_eval():
+      return base.Evaluation(
+          inputs=base.as_inputs([
+              pg.Dict(question='Compute 1 + 1'),
+          ]),
+          method='query',
+          prompt=pg.oneof([
+              lf.Template('{{example.question}}'),
+              lf.Template('Hello {{example.question}}'),
+          ]),
+          schema_fn=answer_schema(),
+      )
+
+    self.assertTrue(issubclass(my_eval, base.Evaluable))
+    evaluation = base.get_evaluation('named_eval/functor_test')
+    self.assertIn('named_eval/functor_test', base.registered_names())
+    self.assertIsInstance(evaluation, my_eval)
+    self.assertIsNone(evaluation.root_dir, None)
+
+    with self.assertRaisesRegex(ValueError, 'Evaluation .* not found'):
+      base.get_evaluation('named_eval/non_existent')
+
+    with self.assertRaisesRegex(TypeError, 'The return value .*'):
+      @base.register('named_eval/bad_return_type')
+      def bad_eval():  # pylint: disable=unused-variable
+        return 1
+
+  def test_run(self):
+    @base.register('test/run')
+    def test_run():  # pylint: disable=unused-variable
+      lm = fake.StaticResponse('Solution(final_answer=2)')
+      return eval_set('run_test', 'query', schema_fn=answer_schema(), lm=lm)
+
+    e = base.run(
+        tempfile.gettempdir(),
+        ['test/run'],
+        id_regex='run_test.*',
+        mode='dryrun',
+        print_definition=True,
+    )
+    self.assertEqual(
+        e.leaf_nodes[0].dir,
+        os.path.join(tempfile.gettempdir(), e.leaf_nodes[0].id),
+    )
+    self.assertTrue(
+        pg.eq(
+            e.leaf_nodes[0].lm, fake.StaticResponse('Solution(final_answer=2)')
+        )
+    )
+
+    @pg.patcher()
+    def bad_lm(unused_eval):  # pylint: disable=unused-variable
+      return dict(lm=fake.StaticResponse('efg'))
+
+    e = base.run(
+        tempfile.gettempdir(),
+        [test_run()],
+        filter='Evaluation.*',
+        patches=['bad_lm']
+    )
+    self.assertTrue(pg.eq(e.leaf_nodes[0].lm, fake.StaticResponse('efg')))
+
+
 if __name__ == '__main__':
   unittest.main()
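The new `NamedEvaluationTest` above exercises a named-evaluation registry that this release adds to the eval package (`base.register`, `base.get_evaluation`, `base.registered_names`, and `base.run` with `mode='dryrun'` and `patches=...`). Below is a minimal sketch of how that API appears to be used, inferred only from the test code in this diff; the `Solution` class and `answer_schema` functor are hypothetical stand-ins for helpers defined in the test module, and the registered name, `lm=` field, and output directory are assumptions rather than documented usage.

```python
# Sketch based on NamedEvaluationTest above; not an official langfun example.
import tempfile

import langfun as lf
import pyglove as pg
from langfun.core.eval import base
from langfun.core.llms import fake


class Solution(pg.Object):
  # Hypothetical answer schema, mirroring 'Solution(final_answer=2)' in the tests.
  final_answer: int


@pg.functor()
def answer_schema():
  # Hypothetical schema functor; the real one lives in base_test.py.
  return Solution


@base.register('my_project/arithmetic')
def arithmetic_eval():
  # The decorated functor becomes a registered, named evaluation.
  return base.Evaluation(
      inputs=base.as_inputs([pg.Dict(question='Compute 1 + 1')]),
      method='query',
      prompt=lf.Template('{{example.question}}'),
      schema_fn=answer_schema(),
      lm=fake.StaticResponse('Solution(final_answer=2)'),  # assumed field
  )


# Registered evaluations can be looked up and run by name.
assert 'my_project/arithmetic' in base.registered_names()
evaluation = base.get_evaluation('my_project/arithmetic')
experiment = base.run(
    tempfile.gettempdir(), ['my_project/arithmetic'], mode='dryrun'
)
```

The `patches=['bad_lm']` argument in the test appears to rely on the new `langfun/core/eval/patching.py` module: a function decorated with `pg.patcher()` returns field overrides (here a replacement `lm`) that `base.run` applies to an already-registered evaluation.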
langfun/core/eval/matching.py
CHANGED
@@ -86,9 +86,26 @@ class Matching(base.Evaluation):
     self._matches = []
     self._mismatches = []

-  def
+  def audit_processed(
+      self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+  ) -> None:
     groundtruth = self.groundtruth(example)
     answer = self.answer(output, example)
+
+    if dryrun:
+      lf.console.write('')
+      lf.console.write(
+          str(groundtruth),
+          title='GROUDTRUTH',
+          color='green',
+      )
+      lf.console.write('')
+      lf.console.write(
+          str(answer),
+          title='ANSWER',
+          color='blue',
+      )
+
     if self.match(answer, groundtruth):
       self._matches.append((example, output, message))
     else:
@@ -102,18 +119,18 @@ class Matching(base.Evaluation):
     del progress
     return {
         'Model': self.lm.model_id,
-        'Matches':
-            self.match_rate
+        'Matches': '%s (%d/%d)' % (
+            self._format_rate(self.match_rate),
             self.num_matches,
             self.num_completed,
         ),
-        'Mismatches':
-            self.mismatch_rate
+        'Mismatches': '%s (%d/%d)' % (
+            self._format_rate(self.mismatch_rate),
             self.num_mismatches,
             self.num_completed,
         ),
-        'Failed':
-            self.failure_rate
+        'Failed': '%s (%d/%d)' % (
+            self._format_rate(self.failure_rate),
             self.num_failures,
             self.num_completed,
         ),
@@ -123,24 +140,25 @@ class Matching(base.Evaluation):
     assert self.result is not None
     m = self.result.metrics
     return (
-
-
-
+        'COMPLETED(%s):'
+        ' Matches=%s (%d/%d)'
+        ' Mismatches=%s (%d/%d)'
+        ' Failures=%s (%d/%d)'
     ) % (
         run_status,
-        m.match_rate
+        self._format_rate(m.match_rate),
         m.num_matches,
         m.total,
-        m.mismatch_rate
+        self._format_rate(m.mismatch_rate),
         m.num_mismatches,
         m.total,
-        m.failure_rate
+        self._format_rate(m.failure_rate),
         m.failures,
         m.total,
     )

-  def
-    result = super().
+  def finalize(self) -> pg.Dict:
+    result = super().finalize()
     result.metrics.update(
         num_matches=self.num_matches,
         match_rate=self.match_rate,
@@ -155,19 +173,16 @@ class Matching(base.Evaluation):
     super().save(definition, result, report)

     if result:
-
-      def force_dict(v):
-        return pg.object_utils.json_conversion.strip_types(pg.to_json(v))
-
       # Save matches.
       pg.save(
           [
-
-              # within functors which could be deserialized.
-              pg.Dict(input=input, output=force_dict(output))
+              pg.Dict(input=input, output=output)
              for input, output, _ in self.matches
          ],
          os.path.join(self.dir, Matching.MATCHES_JSON),
+          # We force the input and output to be dict so it does not depend on
+          # the downstream to serialize.
+          force_dict=True,
      )

      # Save mismatches.
@@ -175,10 +190,13 @@ class Matching(base.Evaluation):
          [
              # We force the output to be dict as its type may be defined
              # within functors which could be deserialized.
-              pg.Dict(input=input, output=
+              pg.Dict(input=input, output=output)
              for input, output, _ in self.mismatches
          ],
          os.path.join(self.dir, Matching.MISMATCHES_JSON),
+          # We force the input and output to be dict so it does not depend on
+          # the downstream to serialize.
+          force_dict=True,
      )

      if report:
@@ -201,9 +219,9 @@ class Matching(base.Evaluation):
   def _render_result_row(self, s: io.StringIO):
     super()._render_result_row(s)
     s.write(
-        '<td><span style="color:
+        '<td><span style="color:orange">%s</span>%s</td>'
         % (
-
+            self._format_rate(self.mismatch_rate),
             '<a href="%s">(%d/%d)</a>'
             % (self.mismatches_link, self.num_mismatches, self.num_completed),
         )
@@ -211,13 +229,13 @@ class Matching(base.Evaluation):
     s.write(
         '<td><span style="color:green">%s</span>%s</td>'
         % (
-
+            self._format_rate(self.match_rate),
            '<a href="%s">(%d/%d)</a>'
            % (self.matches_link, self.num_matches, self.num_completed),
        )
    )

-  def
+  def _render_summary_metrics(self, s: io.StringIO) -> None:
     """Renders metrics in HTML."""
     assert self.result is not None
     m = self.result.metrics
@@ -227,7 +245,7 @@ class Matching(base.Evaluation):
             m.num_matches,
             m.total,
             self.matches_link,
-
+            self._format_rate(m.match_rate),
         )
     )
     s.write(' | ')
@@ -237,11 +255,11 @@ class Matching(base.Evaluation):
             m.num_mismatches,
             m.total,
             self.mismatches_link,
-
+            self._format_rate(m.mismatch_rate),
         )
     )
     s.write(' | ')
-    super().
+    super()._render_summary_metrics(s)

   def _render_matches(self, s: io.StringIO) -> None:
     """Formats the matched cases into html."""

langfun/core/eval/matching_test.py
CHANGED
@@ -103,7 +103,7 @@ class MatchingTest(unittest.TestCase):
         s.result,
         dict(
             experiment_setup=dict(
-                id='MyTask@
+                id='MyTask@739a174b',
                 dir=s.dir,
                 model='StaticSequence',
                 prompt_template='{{example.question}}',
@@ -120,11 +120,19 @@ class MatchingTest(unittest.TestCase):
                 total=4,
                 failures=1,
                 failure_rate=0.25,
+                oop_failures=1,
+                oop_failure_rate=0.25,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={
+                    'MappingError.SchemaError.TypeError': 1
+                },
                 num_matches=2,
                 match_rate=0.5,
                 num_mismatches=1,
                 mismatch_rate=0.25,
             ),
+            usage=s.result.usage,
         ),
     )
     self.assertTrue(
@@ -159,7 +167,14 @@ class MatchingTest(unittest.TestCase):
     self.assertTrue(
         os.path.exists(
             os.path.join(
-                s.dir, matching.Matching.
+                s.dir, matching.Matching.OOP_FAILURES_JSON
+            )
+        )
+    )
+    self.assertTrue(
+        os.path.exists(
+            os.path.join(
+                s.dir, matching.Matching.NON_OOP_FAILURES_JSON
             )
         )
     )
@@ -174,7 +189,14 @@ class MatchingTest(unittest.TestCase):
     self.assertTrue(
         os.path.exists(
             os.path.join(
-                s.dir, matching.Matching.
+                s.dir, matching.Matching.OOP_FAILURES_HTML
+            )
+        )
+    )
+    self.assertTrue(
+        os.path.exists(
+            os.path.join(
+                s.dir, matching.Matching.NON_OOP_FAILURES_HTML
             )
        )
    )
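Taken together, the expected dictionaries in these tests outline the result payload an evaluation now reports: failures broken down by category (including OOP vs. non-OOP counts and a per-exception breakdown) plus an aggregated LM usage block. A small sketch of reading those fields programmatically follows; the field names are copied from the expected results above, while `s` (a completed evaluation) and the surrounding setup are assumptions, not part of this diff.

```python
# Assumes `s` is a completed evaluation (e.g. after s.run()); field names are
# taken verbatim from the expected results asserted in the tests above.
m = s.result.metrics
print(m.total, m.failures, m.failure_rate)
print(m.oop_failures, m.oop_failure_rate)
print(m.non_oop_failures, m.non_oop_failure_rate)
print(dict(m.failure_breakdown))  # e.g. {'MappingError.SchemaError.TypeError': 1}

# Matching evaluations additionally report match statistics.
print(m.num_matches, m.match_rate, m.num_mismatches, m.mismatch_rate)

# Aggregated LM token usage attached to the result, new in this release.
u = s.result.usage
print(u.total_prompt_tokens, u.total_completion_tokens, u.num_usages)
print(u.average_prompt_tokens, u.average_completion_tokens, u.average_total_tokens)
```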