langfun-0.0.2.dev20240330-py3-none-any.whl → langfun-0.1.2.dev202501140804-py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- langfun/__init__.py +22 -2
- langfun/core/__init__.py +17 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -28
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +69 -2
- langfun/core/component_test.py +54 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +17 -0
- langfun/core/eval/base.py +767 -140
- langfun/core/eval/base_test.py +238 -53
- langfun/core/eval/matching.py +80 -76
- langfun/core/eval/matching_test.py +19 -9
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +37 -28
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +3 -21
- langfun/core/langfunc_test.py +26 -8
- langfun/core/language_model.py +686 -48
- langfun/core/language_model_test.py +681 -44
- langfun/core/llms/__init__.py +100 -12
- langfun/core/llms/anthropic.py +488 -0
- langfun/core/llms/anthropic_test.py +235 -0
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +88 -28
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +39 -26
- langfun/core/llms/fake_test.py +136 -11
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -197
- langfun/core/llms/groq.py +276 -0
- langfun/core/llms/groq_test.py +64 -0
- langfun/core/llms/llama_cpp.py +15 -40
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +436 -226
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +35 -174
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -23
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +15 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +9 -8
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +278 -0
- langfun/core/structured/function_generation_test.py +399 -0
- langfun/core/structured/mapping.py +150 -46
- langfun/core/structured/mapping_test.py +105 -0
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +71 -22
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
- langfun/core/structured/schema.py +208 -99
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +133 -34
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +240 -11
- langfun/core/template_test.py +146 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +14 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -217
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
langfun/core/eval/base_test.py
CHANGED
@@ -101,7 +101,7 @@ class EvaluationTest(unittest.TestCase):
     self.assertEqual(s.dir, os.path.join(s.root_dir, s.id))
     self.assertEqual(s.hash, s.clone().hash)
     # Test persistent hash.
-    self.assertEqual(s.hash, '
+    self.assertEqual(s.hash, 'ae86c703')
     self.assertEqual(
         s.hash, s.clone(override={'max_workers': 2, 'lm.timeout': 20}).hash
     )
@@ -194,6 +194,8 @@ class EvaluationTest(unittest.TestCase):
             cache_seed=0,
             score=1.0,
             logprobs=None,
+            is_cached=False,
+            usage=lf.LMSamplingUsage(387, 24, 411),
             tags=['lm-response', 'lm-output', 'transformed'],
         ),
     )
@@ -209,7 +211,7 @@ class EvaluationTest(unittest.TestCase):
         s.result,
         dict(
             experiment_setup=dict(
-                id='Evaluation@
+                id='Evaluation@0fade07d',
                 dir=s.dir,
                 model='StaticSequence',
                 prompt_template='{{example.question}}',
@@ -219,7 +221,26 @@ class EvaluationTest(unittest.TestCase):
             cache_stats=dict(
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
-            metrics=dict(
+            metrics=dict(
+                total=2,
+                failures=1,
+                failure_rate=0.5,
+                oop_failures=1,
+                oop_failure_rate=0.5,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={
+                    'MappingError.SchemaError.TypeError': 1
+                }
+            ),
+            usage=dict(
+                total_prompt_tokens=774,
+                total_completion_tokens=25,
+                num_usages=2,
+                average_prompt_tokens=387,
+                average_completion_tokens=12,
+                average_total_tokens=399,
+            ),
         ),
     )
     self.assertTrue(
@@ -227,14 +248,32 @@ class EvaluationTest(unittest.TestCase):
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.RESULT_JSON)))
     self.assertTrue(
-        os.path.exists(os.path.join(s.dir, base.Evaluation.
+        os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_JSON)))
     self.assertTrue(
-        os.path.exists(
-
+        os.path.exists(
+            os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_JSON)))
+    self.assertTrue(
+        os.path.exists(os.path.join(s.dir, base.Evaluation.CACHE_JSON)))
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
     self.assertTrue(
-        os.path.exists(os.path.join(s.dir, base.Evaluation.
+        os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_HTML)))
+    self.assertTrue(
+        os.path.exists(
+            os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_HTML)))
+    self.assertTrue(
+        os.path.exists(os.path.join(s.root_dir, base.Evaluation.SUMMARY_HTML))
+    )
+    # Check summary JSON.
+    summary_json = os.path.join(
+        s.root_dir, base.Evaluation.SUMMARY_HTML.replace('.html', '.json')
+    )
+    self.assertTrue(os.path.exists(summary_json))
+    summary = pg.load(summary_json, auto_dict=True)
+    self.assertIn('Evaluation', summary)
+    self.assertEqual(len(summary['Evaluation']), 1)
+    self.assertIsNotNone(summary['Evaluation'][0].experiment)
+    self.assertIsNotNone(summary['Evaluation'][0].metrics)

   def test_run_wihtout_save(self):
     lm = fake.StaticSequence([
@@ -255,7 +294,10 @@ class EvaluationTest(unittest.TestCase):
     self.assertFalse(
         os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
     self.assertFalse(
-        os.path.exists(os.path.join(s.dir, base.Evaluation.
+        os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_HTML)))
+    self.assertFalse(
+        os.path.exists(
+            os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_HTML)))

   def test_load(self):
     lm = fake.StaticResponse('Solution(final_answer=2)')
@@ -274,8 +316,11 @@ class EvaluationTest(unittest.TestCase):
     s = eval_set(
         'run_filter_test', pg.oneof(['call', 'query']),
         schema_fn=answer_schema(), lm=lm)
+    result = s.run(
+        filter=lambda x: x.method == 'query', dryrun=True, summary=False
+    )
     self.assertEqual(
-
+        result,
         {
             s.children[0].id: None,
             s.children[1].id: dict(
@@ -290,8 +335,18 @@ class EvaluationTest(unittest.TestCase):
                 cache_stats=dict(
                     use_cache=True, num_queries=2, num_hits=0, num_updates=2
                 ),
-                metrics=dict(
-
+                metrics=dict(
+                    total=2,
+                    failures=0,
+                    failure_rate=0.0,
+                    oop_failures=0,
+                    oop_failure_rate=0.0,
+                    non_oop_failures=0,
+                    non_oop_failure_rate=0.0,
+                    failure_breakdown={},
+                ),
+                usage=s.children[1].result.usage,
+            ),
         },
     )

@@ -321,11 +376,10 @@ class EvaluationTest(unittest.TestCase):
         s.children[0].dir, os.path.join(s.root_dir, s.children[0].id)
     )
     # Test persistent hash.
-    self.assertEqual(s.hash, '
+    self.assertEqual(s.hash, 'b66a4e88')

     summary = s.run(verbose=True)
     self.assertEqual(len(summary.evaluations), 2)
-
     self.assertEqual(
         s.result,
         {
@@ -341,7 +395,19 @@ class EvaluationTest(unittest.TestCase):
                 cache_stats=dict(
                     use_cache=True, num_queries=2, num_hits=0, num_updates=2
                 ),
-                metrics=dict(
+                metrics=dict(
+                    total=2,
+                    failures=1,
+                    failure_rate=0.5,
+                    oop_failures=1,
+                    oop_failure_rate=0.5,
+                    non_oop_failures=0,
+                    non_oop_failure_rate=0.0,
+                    failure_breakdown={
+                        'MappingError.SchemaError.TypeError': 1
+                    }
+                ),
+                usage=s.children[0].result.usage,
             ),
             s.children[1].id: dict(
                 experiment_setup=dict(
@@ -355,7 +421,19 @@ class EvaluationTest(unittest.TestCase):
                 cache_stats=dict(
                     use_cache=True, num_queries=2, num_hits=0, num_updates=2
                 ),
-                metrics=dict(
+                metrics=dict(
+                    total=2,
+                    failures=1,
+                    failure_rate=0.5,
+                    oop_failures=1,
+                    oop_failure_rate=0.5,
+                    non_oop_failures=0,
+                    non_oop_failure_rate=0.0,
+                    failure_breakdown={
+                        'MappingError.SchemaError.TypeError': 1
+                    }
+                ),
+                usage=s.children[1].result.usage,
             ),
         },
     )
@@ -448,10 +526,10 @@ class SuiteTest(unittest.TestCase):
         lm=lm
     )
     # Test for persistent hash.
-    self.assertEqual(s.hash, '
+    self.assertEqual(s.hash, '26e6cc25')
     s.run()
     expected = {
-
+        'Evaluation@0fade07d': dict(
             experiment_setup=dict(
                 id=s.children[0].id,
                 dir=s.children[0].dir,
@@ -463,45 +541,46 @@ class SuiteTest(unittest.TestCase):
             cache_stats=dict(
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
-            metrics=dict(
+            metrics=dict(
+                total=2,
+                failures=1,
+                failure_rate=0.5,
+                oop_failures=1,
+                oop_failure_rate=0.5,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={
+                    'MappingError.SchemaError.TypeError': 1
+                }
+            ),
+            usage=s.children[0].result.usage,
         ),
-
-
-
-
-
-
-
-
-                prompt_template='{{example.question}}',
-                method='call',
-                schema_fn='answer_schema()',
-            ),
-            cache_stats=dict(
-                use_cache=True, num_queries=4, num_hits=1, num_updates=3
-            ),
-            metrics=dict(total=2, failures=2, failure_rate=1.0),
+        'Evaluation@ae86c703': dict(
+            experiment_setup=dict(
+                id=s.children[1].children[0].id,
+                dir=s.children[1].children[0].dir,
+                model='StaticSequence',
+                prompt_template='{{example.question}}',
+                method='call',
+                schema_fn='answer_schema()',
             ),
-
-
-        .id: dict(
-            experiment_setup=dict(
-                id=s.children[1].children[2].id,
-                dir=s.children[1].children[2].dir,
-                model='StaticSequence',
-                prompt_template='{{example.question}}',
-                method='query',
-                schema_fn='answer_schema()',
-            ),
-            cache_stats=dict(
-                use_cache=True,
-                num_queries=2,
-                num_hits=0,
-                num_updates=2,
-            ),
-            metrics=dict(total=2, failures=1, failure_rate=0.5),
+            cache_stats=dict(
+                use_cache=True, num_queries=4, num_hits=0, num_updates=4
             ),
-
+            metrics=dict(
+                total=2,
+                failures=2,
+                failure_rate=1.0,
+                oop_failures=2,
+                oop_failure_rate=1.0,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={
+                    'MappingError.SchemaError.TypeError': 2
+                }
+            ),
+            usage=s.children[1].children[0].result.usage,
+        ),
     }
     self.assertEqual(s.result, expected)

@@ -520,6 +599,14 @@ class InputsFrom(unittest.TestCase):
     pg.save([1, 2, 3], path)
     self.assertEqual(base.inputs_from(path)(), [1, 2, 3])

+    path = os.path.join(tmp_dir, 'input_file.jsonl')
+    with pg.open_jsonl(path, 'w') as f:
+      f.add(pg.Dict(x=1))
+      f.add(dict(y=2))
+    self.assertEqual(
+        base.inputs_from(path)(), [pg.Dict(x=1), dict(y=2)]
+    )
+
   def test_inputs_from_multiple_files(self):
     tmp_dir = tempfile.gettempdir()
     path1 = os.path.join(tmp_dir, 'input_file1.json')
@@ -671,5 +758,103 @@ class SummaryTest(unittest.TestCase):
     self.assertTrue(pg.io.path_exists(summary_file))


+class NamedEvaluationTest(unittest.TestCase):
+
+  def test_named_eval_class(self):
+
+    @base.register('named_eval/class_test')
+    class MyEval(base.Evaluation):
+      inputs = base.as_inputs([
+          pg.Dict(question='Compute 1 + 1'),
+      ])
+      method = 'query'
+      prompt = pg.oneof([
+          lf.Template('{{example.question}}'),
+          lf.Template('Hello {{example.question}}'),
+      ])
+      schema_fn = answer_schema()
+
+    [evaluation] = base.get_evaluations('named_eval/class_test')
+    self.assertIsInstance(evaluation, MyEval)
+    self.assertIsNone(evaluation.dir)
+    self.assertIsNone(evaluation.root_dir)
+    self.assertIn('named_eval/class_test', base.registered_names())
+
+    with self.assertRaisesRegex(ValueError, 'Unsupported type.*'):
+      @base.register('named_eval/bad_class')
+      class Foo:  # pylint: disable=unused-variable
+        pass
+
+  def test_named_eval_functor(self):
+
+    @base.register('named_eval/functor_test')
+    def my_eval():
+      return base.Evaluation(
+          inputs=base.as_inputs([
+              pg.Dict(question='Compute 1 + 1'),
+          ]),
+          method='query',
+          prompt=pg.oneof([
+              lf.Template('{{example.question}}'),
+              lf.Template('Hello {{example.question}}'),
+          ]),
+          schema_fn=answer_schema(),
+      )
+
+    self.assertTrue(issubclass(my_eval, base.Evaluable))
+    [evaluation] = base.get_evaluations('named_eval/functor_test')
+    self.assertIn('named_eval/functor_test', base.registered_names())
+    self.assertIsInstance(evaluation, my_eval)
+    self.assertIsNone(evaluation.root_dir, None)
+
+    self.assertTrue(
+        pg.eq(base.get_evaluations('named_eval/functor.*'), [evaluation])
+    )
+    self.assertEqual(base.get_evaluations('named_eval/non_existent'), [])
+
+    with self.assertRaisesRegex(TypeError, 'The return value .*'):
+      @base.register('named_eval/bad_return_type')
+      def bad_eval():  # pylint: disable=unused-variable
+        return 1
+
+  def test_run(self):
+    @base.register('test/run')
+    def test_run():  # pylint: disable=unused-variable
+      lm = fake.StaticResponse('Solution(final_answer=2)')
+      return eval_set('run_test', 'query', schema_fn=answer_schema(), lm=lm)
+
+    e = base.run(
+        tempfile.gettempdir(),
+        ['test/run'],
+        id_regex='run_test.*',
+        mode='dryrun',
+        print_definition=True,
+    )
+    self.assertEqual(
+        e.leaf_nodes[0].dir,
+        os.path.join(tempfile.gettempdir(), e.leaf_nodes[0].id),
+    )
+    self.assertTrue(
+        pg.eq(
+            e.leaf_nodes[0].lm, fake.StaticResponse('Solution(final_answer=2)')
+        )
+    )
+
+    @pg.patcher()
+    def bad_lm(unused_eval):  # pylint: disable=unused-variable
+      return dict(lm=fake.StaticResponse('efg'))
+
+    e = base.run(
+        tempfile.gettempdir(),
+        [test_run()],
+        filter='Evaluation.*',
+        patches=['bad_lm']
+    )
+    self.assertTrue(pg.eq(e.leaf_nodes[0].lm, fake.StaticResponse('efg')))
+
+    with self.assertRaisesRegex(ValueError, 'No evaluations found'):
+      base.run(tempfile.gettempdir(), ['test/non_existent'])
+
+
 if __name__ == '__main__':
   unittest.main()
langfun/core/eval/matching.py
CHANGED
@@ -41,8 +41,8 @@ class Matching(base.Evaluation):
     """Returns the answer from the structure output."""

   @property
-  def matches(self) -> list[tuple[Any, Any, lf.Message]]:
-    """Returns the matches examples, outputs and the output messages."""
+  def matches(self) -> list[tuple[int, Any, Any, lf.Message]]:
+    """Returns the matches IDs, examples, outputs and the output messages."""
     return self._matches

   @property
@@ -57,7 +57,7 @@ class Matching(base.Evaluation):
     return self.num_matches / self.num_completed

   @property
-  def mismatches(self) -> list[tuple[Any, Any, lf.Message]]:
+  def mismatches(self) -> list[tuple[int, Any, Any, lf.Message]]:
     """Returns the mismatches examples, outputs and output messages."""
     return self._mismatches

@@ -86,34 +86,51 @@ class Matching(base.Evaluation):
     self._matches = []
     self._mismatches = []

-  def
+  def audit_processed(
+      self, example_idx: int, example: Any, output: Any, message: lf.Message,
+      dryrun: bool = False
+  ) -> None:
     groundtruth = self.groundtruth(example)
     answer = self.answer(output, example)
+
+    if dryrun:
+      lf.console.write('')
+      lf.console.write(
+          str(groundtruth),
+          title='GROUDTRUTH',
+          color='green',
+      )
+      lf.console.write('')
+      lf.console.write(
+          str(answer),
+          title='ANSWER',
+          color='blue',
+      )
+
     if self.match(answer, groundtruth):
-      self._matches.append((example, output, message))
+      self._matches.append((example_idx, example, output, message))
     else:
-      self._mismatches.append((example, output, message))
+      self._mismatches.append((example_idx, example, output, message))

   def match(self, answer: Any, groundtruth: Any) -> bool:
     """Matches answer against the groundtruth. Subclasses can override."""
     return pg.eq(answer, groundtruth)

-  def
+  def _eval_status(self, progress: lf.concurrent.Progress) -> dict[str, Any]:
     del progress
     return {
-        '
-
-            self.match_rate * 100,
+        'Matches': '%s (%d/%d)' % (
+            self._format_rate(self.match_rate),
             self.num_matches,
             self.num_completed,
         ),
-        'Mismatches':
-            self.mismatch_rate
+        'Mismatches': '%s (%d/%d)' % (
+            self._format_rate(self.mismatch_rate),
             self.num_mismatches,
             self.num_completed,
         ),
-        'Failed':
-            self.failure_rate
+        'Failed': '%s (%d/%d)' % (
+            self._format_rate(self.failure_rate),
             self.num_failures,
             self.num_completed,
         ),
@@ -123,24 +140,25 @@ class Matching(base.Evaluation):
     assert self.result is not None
     m = self.result.metrics
     return (
-
-
-
+        'COMPLETED(%s):'
+        ' Matches=%s (%d/%d)'
+        ' Mismatches=%s (%d/%d)'
+        ' Failures=%s (%d/%d)'
     ) % (
         run_status,
-        m.match_rate
+        self._format_rate(m.match_rate),
         m.num_matches,
         m.total,
-        m.mismatch_rate
+        self._format_rate(m.mismatch_rate),
         m.num_mismatches,
         m.total,
-        m.failure_rate
+        self._format_rate(m.failure_rate),
         m.failures,
         m.total,
     )

-  def
-    result = super().
+  def finalize(self) -> pg.Dict:
+    result = super().finalize()
     result.metrics.update(
         num_matches=self.num_matches,
         match_rate=self.match_rate,
@@ -154,33 +172,6 @@ class Matching(base.Evaluation):
   ) -> None:
     super().save(definition, result, report)

-    if result:
-
-      def force_dict(v):
-        return pg.object_utils.json_conversion.strip_types(pg.to_json(v))
-
-      # Save matches.
-      pg.save(
-          [
-              # We force the output to be dict as its type may be defined
-              # within functors which could be deserialized.
-              pg.Dict(input=input, output=force_dict(output))
-              for input, output, _ in self.matches
-          ],
-          os.path.join(self.dir, Matching.MATCHES_JSON),
-      )
-
-      # Save mismatches.
-      pg.save(
-          [
-              # We force the output to be dict as its type may be defined
-              # within functors which could be deserialized.
-              pg.Dict(input=input, output=force_dict(output))
-              for input, output, _ in self.mismatches
-          ],
-          os.path.join(self.dir, Matching.MISMATCHES_JSON),
-      )
-
     if report:
       pg.save(
           self._html([self._render_result, self._render_matches]),
@@ -201,9 +192,9 @@ class Matching(base.Evaluation):
   def _render_result_row(self, s: io.StringIO):
     super()._render_result_row(s)
     s.write(
-        '<td><span style="color:
+        '<td><span style="color:orange">%s</span>%s</td>'
         % (
-
+            self._format_rate(self.mismatch_rate),
             '<a href="%s">(%d/%d)</a>'
             % (self.mismatches_link, self.num_mismatches, self.num_completed),
         )
@@ -211,37 +202,33 @@ class Matching(base.Evaluation):
     s.write(
         '<td><span style="color:green">%s</span>%s</td>'
         % (
-
+            self._format_rate(self.match_rate),
             '<a href="%s">(%d/%d)</a>'
             % (self.matches_link, self.num_matches, self.num_completed),
         )
     )

-  def
+  def _render_summary_metrics(self, s: io.StringIO) -> None:
     """Renders metrics in HTML."""
     assert self.result is not None
     m = self.result.metrics
-
-
-    % (
-
-
-
-        f'%.{self.report_precision}f%% ' % (m.match_rate * 100),
-    )
+    self._render_link(
+        s,
+        'Matches (%d/%d)' % (m.num_matches, m.total),
+        self._format_rate(m.match_rate),
+        'color:green',
+        lambda: self.matches_link,
     )
     s.write(' | ')
-
-
-    % (
-
-
-
-        f'%.{self.report_precision}f%% ' % (m.mismatch_rate * 100),
-    )
+    self._render_link(
+        s,
+        'Mismatches (%d/%d)' % (m.num_mismatches, m.total),
+        self._format_rate(m.mismatch_rate),
+        'color:orange',
+        lambda: self.mismatches_link,
     )
     s.write(' | ')
-    super().
+    super()._render_summary_metrics(s)

   def _render_matches(self, s: io.StringIO) -> None:
     """Formats the matched cases into html."""
@@ -254,12 +241,29 @@ class Matching(base.Evaluation):
         '<td>Prompt/Response Chain</td>'
         '</tr>'
     )
-
+    def _maybe_html(v, root_indent: int):
+      del root_indent
+      if hasattr(v, '_repr_html_'):
+        return v._repr_html_()  # pylint: disable=protected-access
+      # Fall back to the default format.
+      return None
+
+    for i, (_, example, output, message) in enumerate(self.matches):
       bgcolor = 'white' if i % 2 == 0 else '#DDDDDD'
       s.write(f'<tr style="background-color: {bgcolor}"><td>{i + 1}</td>')
-      input_str = pg.
+      input_str = pg.Html.escape(
+          pg.format(
+              example, verbose=False, max_bytes_len=32,
+              custom_format=_maybe_html
+          )
+      )
       s.write(f'<td style="color:green;white-space:pre-wrap">{input_str}</td>')
-      output_str = pg.
+      output_str = pg.Html.escape(
+          pg.format(
+              output, verbose=False, max_bytes_len=32,
+              custom_format=_maybe_html
+          )
+      )
       s.write(f'<td style="color:blue;white-space:pre-wrap">{output_str}</td>')
       s.write('<td>')
       self._render_message(message, s)
@@ -279,12 +283,12 @@ class Matching(base.Evaluation):
         '</tr>'
     )

-    for i, (example, output, message) in enumerate(self.mismatches):
+    for i, (_, example, output, message) in enumerate(self.mismatches):
       bgcolor = 'white' if i % 2 == 0 else '#DDDDDD'
       s.write(f'<tr style="background-color: {bgcolor}"><td>{i + 1}</td>')
-      input_str = pg.format(example, verbose=False)
+      input_str = pg.format(example, verbose=False, max_bytes_len=32)
       s.write(f'<td style="color:green;white-space:pre-wrap">{input_str}</td>')
-      output_str = pg.format(output, verbose=False)
+      output_str = pg.format(output, verbose=False, max_bytes_len=32)
       s.write(
           f'<td style="color:magenta;white-space:pre-wrap">{output_str}</td>'
       )