langfun 0.0.2.dev20240429__py3-none-any.whl → 0.0.2.dev20240511__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of langfun has been flagged as a potentially problematic release.
- langfun/__init__.py +5 -0
- langfun/core/eval/__init__.py +14 -1
- langfun/core/eval/base.py +503 -112
- langfun/core/eval/base_test.py +185 -53
- langfun/core/eval/matching.py +22 -21
- langfun/core/eval/matching_test.py +23 -2
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +4 -4
- langfun/core/eval/scoring_test.py +19 -2
- langfun/core/langfunc.py +1 -17
- langfun/core/langfunc_test.py +4 -0
- langfun/core/language_model.py +6 -0
- langfun/core/llms/__init__.py +8 -0
- langfun/core/llms/fake.py +6 -6
- langfun/core/llms/google_genai.py +8 -0
- langfun/core/llms/openai.py +3 -2
- langfun/core/llms/openai_test.py +2 -1
- langfun/core/llms/vertexai.py +291 -0
- langfun/core/llms/vertexai_test.py +233 -0
- langfun/core/modalities/image.py +1 -3
- langfun/core/modalities/mime.py +6 -0
- langfun/core/modalities/video.py +1 -3
- langfun/core/structured/__init__.py +2 -0
- langfun/core/structured/mapping.py +5 -1
- langfun/core/structured/prompting.py +39 -11
- langfun/core/structured/prompting_test.py +43 -0
- langfun/core/structured/schema.py +34 -4
- langfun/core/structured/schema_test.py +32 -1
- langfun/core/structured/scoring.py +4 -1
- langfun/core/structured/scoring_test.py +6 -0
- langfun/core/template.py +22 -1
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.0.2.dev20240511.dist-info}/METADATA +2 -2
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.0.2.dev20240511.dist-info}/RECORD +37 -33
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.0.2.dev20240511.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.0.2.dev20240511.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.0.2.dev20240511.dist-info}/top_level.txt +0 -0
langfun/core/eval/base_test.py
CHANGED
@@ -220,7 +220,18 @@ class EvaluationTest(unittest.TestCase):
             cache_stats=dict(
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
-            metrics=dict(total=2, failures=1, failure_rate=0.5),
+            metrics=dict(
+                total=2,
+                failures=1,
+                failure_rate=0.5,
+                oop_failures=1,
+                oop_failure_rate=0.5,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={
+                    'MappingError.SchemaError.TypeError': 1
+                }
+            ),
             usage=dict(
                 total_prompt_tokens=774,
                 total_completion_tokens=25,
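The new metrics block splits failures into OOP failures (errors raised by Langfun's structured-output mapping, reported under dotted keys such as 'MappingError.SchemaError.TypeError') and non-OOP failures (everything else). A minimal sketch of the accounting these assertions imply; helper names like `is_oop` and `error_key` are illustrative, not part of base.py:

# Sketch only: reproduces the metric fields asserted above.
# `errors` is the list of per-example exceptions; `is_oop(e)` and
# `error_key(e)` are hypothetical helpers (not in this diff).
import collections

def failure_metrics(total, errors, is_oop, error_key):
  breakdown = collections.Counter(error_key(e) for e in errors)
  oop = sum(1 for e in errors if is_oop(e))
  failures = len(errors)
  rate = lambda n: n / total if total else 0.0
  return dict(
      total=total,
      failures=failures,
      failure_rate=rate(failures),
      oop_failures=oop,
      oop_failure_rate=rate(oop),
      non_oop_failures=failures - oop,
      non_oop_failure_rate=rate(failures - oop),
      # e.g. {'MappingError.SchemaError.TypeError': 1}
      failure_breakdown=dict(breakdown),
  )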
@@ -235,12 +246,20 @@ class EvaluationTest(unittest.TestCase):
         os.path.exists(os.path.join(s.dir, base.Evaluation.EXPERIMENT_JSON)))
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.RESULT_JSON)))
+    self.assertTrue(
+        os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_JSON)))
+    self.assertTrue(
+        os.path.exists(
+            os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_JSON)))
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.CACHE_JSON)))
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
     self.assertTrue(
-        os.path.exists(os.path.join(s.dir, base.Evaluation.FAILURES_HTML)))
+        os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_HTML)))
+    self.assertTrue(
+        os.path.exists(
+            os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_HTML)))
     self.assertTrue(
         os.path.exists(os.path.join(s.root_dir, base.Evaluation.SUMMARY_HTML))
     )
@@ -274,7 +293,10 @@ class EvaluationTest(unittest.TestCase):
     self.assertFalse(
         os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
     self.assertFalse(
-        os.path.exists(os.path.join(s.dir, base.Evaluation.FAILURES_HTML)))
+        os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_HTML)))
+    self.assertFalse(
+        os.path.exists(
+            os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_HTML)))

   def test_load(self):
     lm = fake.StaticResponse('Solution(final_answer=2)')
@@ -312,7 +334,16 @@ class EvaluationTest(unittest.TestCase):
             cache_stats=dict(
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
-            metrics=dict(total=2, failures=0, failure_rate=0.0),
+            metrics=dict(
+                total=2,
+                failures=0,
+                failure_rate=0.0,
+                oop_failures=0,
+                oop_failure_rate=0.0,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={},
+            ),
             usage=s.children[1].result.usage,
         ),
     },
@@ -363,7 +394,18 @@ class EvaluationTest(unittest.TestCase):
             cache_stats=dict(
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
-            metrics=dict(total=2, failures=1, failure_rate=0.5),
+            metrics=dict(
+                total=2,
+                failures=1,
+                failure_rate=0.5,
+                oop_failures=1,
+                oop_failure_rate=0.5,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={
+                    'MappingError.SchemaError.TypeError': 1
+                }
+            ),
             usage=s.children[0].result.usage,
         ),
         s.children[1].id: dict(
@@ -378,7 +420,18 @@ class EvaluationTest(unittest.TestCase):
             cache_stats=dict(
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
-            metrics=dict(total=2, failures=1, failure_rate=0.5),
+            metrics=dict(
+                total=2,
+                failures=1,
+                failure_rate=0.5,
+                oop_failures=1,
+                oop_failure_rate=0.5,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={
+                    'MappingError.SchemaError.TypeError': 1
+                }
+            ),
             usage=s.children[1].result.usage,
         ),
     },
@@ -475,7 +528,7 @@ class SuiteTest(unittest.TestCase):
     self.assertEqual(s.hash, '26e6cc25')
     s.run()
     expected = {
-        s.children[0].id: dict(
+        'Evaluation@0fade07d': dict(
             experiment_setup=dict(
                 id=s.children[0].id,
                 dir=s.children[0].dir,
@@ -487,48 +540,46 @@ class SuiteTest(unittest.TestCase):
             cache_stats=dict(
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
-            metrics=dict(total=2, failures=1, failure_rate=0.5),
+            metrics=dict(
+                total=2,
+                failures=1,
+                failure_rate=0.5,
+                oop_failures=1,
+                oop_failure_rate=0.5,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={
+                    'MappingError.SchemaError.TypeError': 1
+                }
+            ),
             usage=s.children[0].result.usage,
         ),
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-                prompt_template='{{example.question}}',
-                method='call',
-                schema_fn='answer_schema()',
-            ),
-            cache_stats=dict(
-                use_cache=True, num_queries=4, num_hits=1, num_updates=3
-            ),
-            metrics=dict(total=2, failures=2, failure_rate=1.0),
-            usage=s.children[1].children[0].result.usage,
+        'Evaluation@ae86c703': dict(
+            experiment_setup=dict(
+                id=s.children[1].children[0].id,
+                dir=s.children[1].children[0].dir,
+                model='StaticSequence',
+                prompt_template='{{example.question}}',
+                method='call',
+                schema_fn='answer_schema()',
             ),
-        …
-        …
-        ….id: dict(
-            experiment_setup=dict(
-                id=s.children[1].children[2].id,
-                dir=s.children[1].children[2].dir,
-                model='StaticSequence',
-                prompt_template='{{example.question}}',
-                method='query',
-                schema_fn='answer_schema()',
-            ),
-            cache_stats=dict(
-                use_cache=True,
-                num_queries=2,
-                num_hits=0,
-                num_updates=2,
-            ),
-            metrics=dict(total=2, failures=1, failure_rate=0.5),
-            usage=s.children[1].children[2].result.usage,
+            cache_stats=dict(
+                use_cache=True, num_queries=4, num_hits=1, num_updates=3
             ),
-        …
+            metrics=dict(
+                total=2,
+                failures=2,
+                failure_rate=1.0,
+                oop_failures=2,
+                oop_failure_rate=1.0,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={
+                    'MappingError.SchemaError.TypeError': 2
+                }
+            ),
+            usage=s.children[1].children[0].result.usage,
+        ),
     }
     self.assertEqual(s.result, expected)

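Note the result keys in SuiteTest changed from positional `s.children[...].id` expressions to literal hash-suffixed ids ('Evaluation@0fade07d', 'Evaluation@ae86c703'), and the entry for the duplicate sub-experiment (`s.children[1].children[2]`) is gone: sub-evaluations that are symbolically identical now collapse into a single result entry. A small sketch of why duplicates arise when an evaluation is declared with `pg.oneof` (the values below are illustrative):

# pg.oneof turns one evaluation definition into several concrete children.
# When two choices are equal, the expanded children are symbolically equal
# too, so they hash to the same id and share one result entry.
import pyglove as pg

space = pg.Dict(
    method=pg.oneof(['call', 'query']),
    schema_fn=pg.oneof(['answer_schema()', 'answer_schema()']),  # equal choices
)
children = list(pg.iter(space))
assert len(children) == 4                    # 2 x 2 raw combinations
assert len({str(c) for c in children}) == 2  # only 2 distinct evaluations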
@@ -698,16 +749,97 @@ class SummaryTest(unittest.TestCase):
     self.assertTrue(pg.io.path_exists(summary_file))


-class …
+class NamedEvaluationTest(unittest.TestCase):

-  def …
-  …
-  …
-  …
-  …
+  def test_named_eval_class(self):
+
+    @base.register('named_eval/class_test')
+    class MyEval(base.Evaluation):
+      inputs = base.as_inputs([
+          pg.Dict(question='Compute 1 + 1'),
+      ])
+      method = 'query'
+      prompt = pg.oneof([
+          lf.Template('{{example.question}}'),
+          lf.Template('Hello {{example.question}}'),
+      ])
+      schema_fn = answer_schema()
+
+    evaluation = base.get_evaluation('named_eval/class_test')
+    self.assertIsInstance(evaluation, MyEval)
+    self.assertIsNone(evaluation.dir)
+    self.assertIsNone(evaluation.root_dir)
+    self.assertIn('named_eval/class_test', base.registered_names())
+
+    with self.assertRaisesRegex(ValueError, 'Unsupported type.*'):
+      @base.register('named_eval/bad_class')
+      class Foo:  # pylint: disable=unused-variable
+        pass
+
+  def test_named_eval_functor(self):
+
+    @base.register('named_eval/functor_test')
+    def my_eval():
+      return base.Evaluation(
+          inputs=base.as_inputs([
+              pg.Dict(question='Compute 1 + 1'),
+          ]),
+          method='query',
+          prompt=pg.oneof([
+              lf.Template('{{example.question}}'),
+              lf.Template('Hello {{example.question}}'),
+          ]),
+          schema_fn=answer_schema(),
       )
-  …
-  …
+
+    self.assertTrue(issubclass(my_eval, base.Evaluable))
+    evaluation = base.get_evaluation('named_eval/functor_test')
+    self.assertIn('named_eval/functor_test', base.registered_names())
+    self.assertIsInstance(evaluation, my_eval)
+    self.assertIsNone(evaluation.root_dir, None)
+
+    with self.assertRaisesRegex(ValueError, 'Evaluation .* not found'):
+      base.get_evaluation('named_eval/non_existent')
+
+    with self.assertRaisesRegex(TypeError, 'The return value .*'):
+      @base.register('named_eval/bad_return_type')
+      def bad_eval():  # pylint: disable=unused-variable
+        return 1
+
+  def test_run(self):
+    @base.register('test/run')
+    def test_run():  # pylint: disable=unused-variable
+      lm = fake.StaticResponse('Solution(final_answer=2)')
+      return eval_set('run_test', 'query', schema_fn=answer_schema(), lm=lm)
+
+    e = base.run(
+        tempfile.gettempdir(),
+        ['test/run'],
+        id_regex='run_test.*',
+        mode='dryrun',
+        print_definition=True,
+    )
+    self.assertEqual(
+        e.leaf_nodes[0].dir,
+        os.path.join(tempfile.gettempdir(), e.leaf_nodes[0].id),
+    )
+    self.assertTrue(
+        pg.eq(
+            e.leaf_nodes[0].lm, fake.StaticResponse('Solution(final_answer=2)')
+        )
+    )
+
+    @pg.patcher()
+    def bad_lm(unused_eval):  # pylint: disable=unused-variable
+      return dict(lm=fake.StaticResponse('efg'))
+
+    e = base.run(
+        tempfile.gettempdir(),
+        [test_run()],
+        filter='Evaluation.*',
+        patches=['bad_lm']
+    )
+    self.assertTrue(pg.eq(e.leaf_nodes[0].lm, fake.StaticResponse('efg')))


 if __name__ == '__main__':
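NamedEvaluationTest covers the registry API this release adds to base.py: @base.register(name) publishes an Evaluation subclass or functor under a string name, base.get_evaluation(name) instantiates it, base.registered_names() lists it, and base.run(...) resolves names plus id filters and patchers. A condensed usage sketch based on the test above; `eval_set` and `answer_schema` are helpers defined in base_test.py, not in the library, and the registered name is made up:

# Condensed from the tests above; runs an evaluation registered by name.
import tempfile
from langfun.core.eval import base
from langfun.core.llms import fake

@base.register('demo/arithmetic')  # hypothetical name
def arithmetic():
  lm = fake.StaticResponse('Solution(final_answer=2)')
  return eval_set('run_test', 'query', schema_fn=answer_schema(), lm=lm)

assert 'demo/arithmetic' in base.registered_names()
evaluation = base.get_evaluation('demo/arithmetic')

# Dry-run the named evaluation under a temp root dir.
base.run(tempfile.gettempdir(), ['demo/arithmetic'], mode='dryrun')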
langfun/core/eval/matching.py
CHANGED
@@ -119,18 +119,18 @@ class Matching(base.Evaluation):
     del progress
     return {
         'Model': self.lm.model_id,
-        'Matches': '%.2f%% (%d/%d)' % (
-            self.match_rate * 100,
+        'Matches': '%s (%d/%d)' % (
+            self._format_rate(self.match_rate),
             self.num_matches,
             self.num_completed,
         ),
-        'Mismatches': '%.2f%% (%d/%d)' % (
-            self.mismatch_rate * 100,
+        'Mismatches': '%s (%d/%d)' % (
+            self._format_rate(self.mismatch_rate),
             self.num_mismatches,
             self.num_completed,
         ),
-        'Failed': '%.2f%% (%d/%d)' % (
-            self.failure_rate * 100,
+        'Failed': '%s (%d/%d)' % (
+            self._format_rate(self.failure_rate),
             self.num_failures,
             self.num_completed,
         ),
@@ -140,24 +140,25 @@ class Matching(base.Evaluation):
     assert self.result is not None
     m = self.result.metrics
     return (
-        'COMPLETED(%s):'
-        ' Matches=%.2f%% (%d/%d) Mismatches=%.2f%% (%d/%d)'
-        ' Failures=%.2f%% (%d/%d)'
+        'COMPLETED(%s):'
+        ' Matches=%s (%d/%d)'
+        ' Mismatches=%s (%d/%d)'
+        ' Failures=%s (%d/%d)'
     ) % (
         run_status,
-        m.match_rate * 100,
+        self._format_rate(m.match_rate),
         m.num_matches,
         m.total,
-        m.mismatch_rate * 100,
+        self._format_rate(m.mismatch_rate),
         m.num_mismatches,
         m.total,
-        m.failure_rate * 100,
+        self._format_rate(m.failure_rate),
         m.failures,
         m.total,
     )

-  def summarize(self) -> pg.Dict:
-    result = super().summarize()
+  def finalize(self) -> pg.Dict:
+    result = super().finalize()
     result.metrics.update(
         num_matches=self.num_matches,
         match_rate=self.match_rate,
@@ -218,9 +219,9 @@ class Matching(base.Evaluation):
   def _render_result_row(self, s: io.StringIO):
     super()._render_result_row(s)
     s.write(
-        '<td><span style="color:orange">%.2f%%</span>%s</td>'
+        '<td><span style="color:orange">%s</span>%s</td>'
         % (
-            self.mismatch_rate * 100,
+            self._format_rate(self.mismatch_rate),
             '<a href="%s">(%d/%d)</a>'
             % (self.mismatches_link, self.num_mismatches, self.num_completed),
         )
@@ -228,13 +229,13 @@ class Matching(base.Evaluation):
     s.write(
         '<td><span style="color:green">%s</span>%s</td>'
         % (
-            '%.2f%% ' % (self.match_rate * 100),
+            self._format_rate(self.match_rate),
             '<a href="%s">(%d/%d)</a>'
             % (self.matches_link, self.num_matches, self.num_completed),
         )
     )

-  def _render_metric(self, s: io.StringIO) -> None:
+  def _render_summary_metrics(self, s: io.StringIO) -> None:
     """Renders metrics in HTML."""
     assert self.result is not None
     m = self.result.metrics
@@ -244,7 +245,7 @@ class Matching(base.Evaluation):
             m.num_matches,
             m.total,
             self.matches_link,
-            '%.2f%% ' % (m.match_rate * 100),
+            self._format_rate(m.match_rate),
         )
     )
     s.write(' | ')
@@ -254,11 +255,11 @@ class Matching(base.Evaluation):
             m.num_mismatches,
             m.total,
             self.mismatches_link,
-            '%.2f%% ' % (m.mismatch_rate * 100),
+            self._format_rate(m.mismatch_rate),
         )
     )
     s.write(' | ')
-    super()._render_metric(s)
+    super()._render_summary_metrics(s)

   def _render_matches(self, s: io.StringIO) -> None:
     """Formats the matched cases into html."""
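All of these call sites delegate percentage formatting to self._format_rate, whose definition lands in base.py (that hunk is not shown in this view; see the base.py entry in the file list). Judging from the call sites, it is presumably along these lines:

# Presumed shape of the helper, inferred from the call sites above; the
# actual definition is in langfun/core/eval/base.py (hunk not shown here).
def _format_rate(self, rate: float) -> str:
  """Formats a [0, 1] rate as a percentage string, e.g. 0.25 -> '25.0%'."""
  return '%.1f%%' % (rate * 100)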
langfun/core/eval/matching_test.py
CHANGED
@@ -120,6 +120,13 @@ class MatchingTest(unittest.TestCase):
             total=4,
             failures=1,
             failure_rate=0.25,
+            oop_failures=1,
+            oop_failure_rate=0.25,
+            non_oop_failures=0,
+            non_oop_failure_rate=0.0,
+            failure_breakdown={
+                'MappingError.SchemaError.TypeError': 1
+            },
             num_matches=2,
             match_rate=0.5,
             num_mismatches=1,
@@ -160,7 +167,14 @@ class MatchingTest(unittest.TestCase):
     self.assertTrue(
         os.path.exists(
             os.path.join(
-                s.dir, matching.Matching.FAILURES_JSON
+                s.dir, matching.Matching.OOP_FAILURES_JSON
+            )
+        )
+    )
+    self.assertTrue(
+        os.path.exists(
+            os.path.join(
+                s.dir, matching.Matching.NON_OOP_FAILURES_JSON
             )
         )
     )
@@ -175,7 +189,14 @@ class MatchingTest(unittest.TestCase):
     self.assertTrue(
         os.path.exists(
             os.path.join(
-                s.dir, matching.Matching.FAILURES_HTML
+                s.dir, matching.Matching.OOP_FAILURES_HTML
+            )
+        )
+    )
+    self.assertTrue(
+        os.path.exists(
+            os.path.join(
+                s.dir, matching.Matching.NON_OOP_FAILURES_HTML
             )
         )
     )
langfun/core/eval/patching.py
ADDED
@@ -0,0 +1,130 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Experiment patching for Langfun evaluations."""
+
+import inspect
+from typing import Union
+import langfun.core as lf
+from langfun.core import llms as lf_llms
+from langfun.core.eval import base
+import pyglove as pg
+
+
+#
+# Program-based patchers.
+#
+
+
+def patch_member(cls, key, value, parent_key: str | None = None):
+  """Patches a member of a class."""
+
+  def _rebind_fn(k, v, p):
+    if (
+        isinstance(p, cls)
+        and k.key == key
+        and (parent_key is None or (p and p.sym_path.key == parent_key))
+    ):
+      if inspect.isfunction(value):
+        return value(k, v, p)
+      return value
+    return v
+
+  return _rebind_fn
+
+
+def patch_lm(lm: Union[lf.LanguageModel, pg.hyper.OneOf]):  # pylint: disable=redefined-outer-name
+  """Patches the LLM of evaluations."""
+  return patch_member(base.Evaluable, "lm", lm)
+
+
+def patch_parsing_lm(lm: Union[lf.LanguageModel, pg.hyper.OneOf]):  # pylint: disable=redefined-outer-name
+  """Patches the parsing LLM of evaluations."""
+  return patch_member(base.Evaluable, "parsing_lm", lm)
+
+
+def patch_schema_fn(schema_fn: Union[pg.Functor, pg.hyper.OneOf]):
+  """Patches the schema_fn of evaluations."""
+  return patch_member(base.Evaluable, "schema_fn", schema_fn)
+
+
+def patch_prompt(prompt: Union[str, lf.Template, pg.hyper.OneOf]):
+  """Patches the prompt of evaluations."""
+  return patch_member(base.Evaluable, "prompt", prompt)
+
+
+def patch_inputs(inputs: Union[pg.Functor, pg.hyper.OneOf]):
+  """Patches the inputs used in evaluations."""
+  return patch_member(base.Evaluable, "inputs", inputs)
+
+
+def patch_additional_args(**kwargs):
+  """Patches additional_args."""
+
+  def value_fn(k, unused_v, p):
+    # We infer the symbolic value for the old args, as it might be a
+    # contextual attribute referring to its containing object.
+    old_args = p.sym_inferred(k.key)
+    if old_args:
+      old_args = dict(old_args)
+      old_args.update(kwargs)
+      return old_args
+    return kwargs
+
+  return patch_member(base.Evaluable, "additional_args", value_fn)
+
+
+#
+# String-based patching.
+#
+
+_NAMED_MODELS = {
+    # GPT models.
+    "gpt35turbo": lf_llms.Gpt35Turbo,
+    "gpt35turbo16k": lf_llms.Gpt35Turbo16K,
+    "gpt4": lf_llms.Gpt4,
+    "gpt4turbo": lf_llms.Gpt4Turbo,
+    # Anthropic models.
+    "haiku": lf_llms.Claude3Haiku,
+    "claude3haiku": lf_llms.Claude3Haiku,
+    "opus": lf_llms.Claude3Opus,
+    "claude3opus": lf_llms.Claude3Opus,
+    "sonnet": lf_llms.Claude3Sonnet,
+    "claude3sonnet": lf_llms.Claude3Opus,
+}
+
+
+def model_by_name(name: str) -> lf.LanguageModel:
+  """Gets model by name."""
+  name = name.strip().lower()
+  if name in _NAMED_MODELS:
+    return _NAMED_MODELS[name]()
+  raise ValueError(f"Unknown model name: {name}")
+
+
+@pg.patcher(auto_typing=True)
+def lm(unused_eval, models: list[str]):
+  """Patch the LM used for benchmarking."""
+  return patch_lm(pg.oneof([model_by_name(name) for name in models]))
+
+
+@pg.patcher(auto_typing=True)
+def temperature(unused_eval, value: float):
+  """Patch the temperature used for benchmarking."""
+  return patch_member(lf.LMSamplingOptions, "temperature", value)
+
+
+@pg.patcher(auto_typing=True)
+def max_tokens(unused_eval, value: int | None):
+  """Patch the temperature used for benchmarking."""
+  return patch_member(lf.LMSamplingOptions, "max_tokens", value)