langfun 0.0.2.dev20240429__py3-none-any.whl → 0.0.2.dev20240511__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (37)
  1. langfun/__init__.py +5 -0
  2. langfun/core/eval/__init__.py +14 -1
  3. langfun/core/eval/base.py +503 -112
  4. langfun/core/eval/base_test.py +185 -53
  5. langfun/core/eval/matching.py +22 -21
  6. langfun/core/eval/matching_test.py +23 -2
  7. langfun/core/eval/patching.py +130 -0
  8. langfun/core/eval/patching_test.py +170 -0
  9. langfun/core/eval/scoring.py +4 -4
  10. langfun/core/eval/scoring_test.py +19 -2
  11. langfun/core/langfunc.py +1 -17
  12. langfun/core/langfunc_test.py +4 -0
  13. langfun/core/language_model.py +6 -0
  14. langfun/core/llms/__init__.py +8 -0
  15. langfun/core/llms/fake.py +6 -6
  16. langfun/core/llms/google_genai.py +8 -0
  17. langfun/core/llms/openai.py +3 -2
  18. langfun/core/llms/openai_test.py +2 -1
  19. langfun/core/llms/vertexai.py +291 -0
  20. langfun/core/llms/vertexai_test.py +233 -0
  21. langfun/core/modalities/image.py +1 -3
  22. langfun/core/modalities/mime.py +6 -0
  23. langfun/core/modalities/video.py +1 -3
  24. langfun/core/structured/__init__.py +2 -0
  25. langfun/core/structured/mapping.py +5 -1
  26. langfun/core/structured/prompting.py +39 -11
  27. langfun/core/structured/prompting_test.py +43 -0
  28. langfun/core/structured/schema.py +34 -4
  29. langfun/core/structured/schema_test.py +32 -1
  30. langfun/core/structured/scoring.py +4 -1
  31. langfun/core/structured/scoring_test.py +6 -0
  32. langfun/core/template.py +22 -1
  33. {langfun-0.0.2.dev20240429.dist-info → langfun-0.0.2.dev20240511.dist-info}/METADATA +2 -2
  34. {langfun-0.0.2.dev20240429.dist-info → langfun-0.0.2.dev20240511.dist-info}/RECORD +37 -33
  35. {langfun-0.0.2.dev20240429.dist-info → langfun-0.0.2.dev20240511.dist-info}/LICENSE +0 -0
  36. {langfun-0.0.2.dev20240429.dist-info → langfun-0.0.2.dev20240511.dist-info}/WHEEL +0 -0
  37. {langfun-0.0.2.dev20240429.dist-info → langfun-0.0.2.dev20240511.dist-info}/top_level.txt +0 -0
langfun/core/eval/base_test.py

@@ -220,7 +220,18 @@ class EvaluationTest(unittest.TestCase):
  cache_stats=dict(
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
- metrics=dict(total=2, failures=1, failure_rate=0.5),
+ metrics=dict(
+ total=2,
+ failures=1,
+ failure_rate=0.5,
+ oop_failures=1,
+ oop_failure_rate=0.5,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={
+ 'MappingError.SchemaError.TypeError': 1
+ }
+ ),
  usage=dict(
  total_prompt_tokens=774,
  total_completion_tokens=25,
@@ -235,12 +246,20 @@ class EvaluationTest(unittest.TestCase):
  os.path.exists(os.path.join(s.dir, base.Evaluation.EXPERIMENT_JSON)))
  self.assertTrue(
  os.path.exists(os.path.join(s.dir, base.Evaluation.RESULT_JSON)))
+ self.assertTrue(
+ os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_JSON)))
+ self.assertTrue(
+ os.path.exists(
+ os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_JSON)))
  self.assertTrue(
  os.path.exists(os.path.join(s.dir, base.Evaluation.CACHE_JSON)))
  self.assertTrue(
  os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
  self.assertTrue(
- os.path.exists(os.path.join(s.dir, base.Evaluation.FAILURES_HTML)))
+ os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_HTML)))
+ self.assertTrue(
+ os.path.exists(
+ os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_HTML)))
  self.assertTrue(
  os.path.exists(os.path.join(s.root_dir, base.Evaluation.SUMMARY_HTML))
  )
@@ -274,7 +293,10 @@ class EvaluationTest(unittest.TestCase):
  self.assertFalse(
  os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
  self.assertFalse(
- os.path.exists(os.path.join(s.dir, base.Evaluation.FAILURES_HTML)))
+ os.path.exists(os.path.join(s.dir, base.Evaluation.OOP_FAILURES_HTML)))
+ self.assertFalse(
+ os.path.exists(
+ os.path.join(s.dir, base.Evaluation.NON_OOP_FAILURES_HTML)))

  def test_load(self):
  lm = fake.StaticResponse('Solution(final_answer=2)')
@@ -312,7 +334,16 @@ class EvaluationTest(unittest.TestCase):
  cache_stats=dict(
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
- metrics=dict(total=2, failures=0, failure_rate=0.0),
+ metrics=dict(
+ total=2,
+ failures=0,
+ failure_rate=0.0,
+ oop_failures=0,
+ oop_failure_rate=0.0,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={},
+ ),
  usage=s.children[1].result.usage,
  ),
  },
@@ -363,7 +394,18 @@ class EvaluationTest(unittest.TestCase):
  cache_stats=dict(
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
- metrics=dict(total=2, failures=1, failure_rate=0.5),
+ metrics=dict(
+ total=2,
+ failures=1,
+ failure_rate=0.5,
+ oop_failures=1,
+ oop_failure_rate=0.5,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={
+ 'MappingError.SchemaError.TypeError': 1
+ }
+ ),
  usage=s.children[0].result.usage,
  ),
  s.children[1].id: dict(
@@ -378,7 +420,18 @@ class EvaluationTest(unittest.TestCase):
  cache_stats=dict(
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
- metrics=dict(total=2, failures=1, failure_rate=0.5),
+ metrics=dict(
+ total=2,
+ failures=1,
+ failure_rate=0.5,
+ oop_failures=1,
+ oop_failure_rate=0.5,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={
+ 'MappingError.SchemaError.TypeError': 1
+ }
+ ),
  usage=s.children[1].result.usage,
  ),
  },
@@ -475,7 +528,7 @@ class SuiteTest(unittest.TestCase):
  self.assertEqual(s.hash, '26e6cc25')
  s.run()
  expected = {
- s.children[0].id: dict(
+ 'Evaluation@0fade07d': dict(
  experiment_setup=dict(
  id=s.children[0].id,
  dir=s.children[0].dir,
@@ -487,48 +540,46 @@ class SuiteTest(unittest.TestCase):
  cache_stats=dict(
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
- metrics=dict(total=2, failures=1, failure_rate=0.5),
+ metrics=dict(
+ total=2,
+ failures=1,
+ failure_rate=0.5,
+ oop_failures=1,
+ oop_failure_rate=0.5,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={
+ 'MappingError.SchemaError.TypeError': 1
+ }
+ ),
  usage=s.children[0].result.usage,
  ),
- s.children[1].id: {
- s.children[1]
- .children[0]
- .id: dict(
- experiment_setup=dict(
- id=s.children[1].children[0].id,
- dir=s.children[1].children[0].dir,
- model='StaticSequence',
- prompt_template='{{example.question}}',
- method='call',
- schema_fn='answer_schema()',
- ),
- cache_stats=dict(
- use_cache=True, num_queries=4, num_hits=1, num_updates=3
- ),
- metrics=dict(total=2, failures=2, failure_rate=1.0),
- usage=s.children[1].children[0].result.usage,
+ 'Evaluation@ae86c703': dict(
+ experiment_setup=dict(
+ id=s.children[1].children[0].id,
+ dir=s.children[1].children[0].dir,
+ model='StaticSequence',
+ prompt_template='{{example.question}}',
+ method='call',
+ schema_fn='answer_schema()',
  ),
- s.children[1]
- .children[2]
- .id: dict(
- experiment_setup=dict(
- id=s.children[1].children[2].id,
- dir=s.children[1].children[2].dir,
- model='StaticSequence',
- prompt_template='{{example.question}}',
- method='query',
- schema_fn='answer_schema()',
- ),
- cache_stats=dict(
- use_cache=True,
- num_queries=2,
- num_hits=0,
- num_updates=2,
- ),
- metrics=dict(total=2, failures=1, failure_rate=0.5),
- usage=s.children[1].children[2].result.usage,
+ cache_stats=dict(
+ use_cache=True, num_queries=4, num_hits=1, num_updates=3
  ),
- },
+ metrics=dict(
+ total=2,
+ failures=2,
+ failure_rate=1.0,
+ oop_failures=2,
+ oop_failure_rate=1.0,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={
+ 'MappingError.SchemaError.TypeError': 2
+ }
+ ),
+ usage=s.children[1].children[0].result.usage,
+ ),
  }
  self.assertEqual(s.result, expected)

@@ -698,16 +749,97 @@ class SummaryTest(unittest.TestCase):
  self.assertTrue(pg.io.path_exists(summary_file))


- class AppRunTest(unittest.TestCase):
+ class NamedEvaluationTest(unittest.TestCase):

- def test_app_run(self):
- lm = fake.StaticSequence(['two', 'Solution(final_answer=2)'])
- try:
- base.app_run(
- eval_set('app_run_test', 'query', schema_fn=answer_schema(), lm=lm)
+ def test_named_eval_class(self):
+
+ @base.register('named_eval/class_test')
+ class MyEval(base.Evaluation):
+ inputs = base.as_inputs([
+ pg.Dict(question='Compute 1 + 1'),
+ ])
+ method = 'query'
+ prompt = pg.oneof([
+ lf.Template('{{example.question}}'),
+ lf.Template('Hello {{example.question}}'),
+ ])
+ schema_fn = answer_schema()
+
+ evaluation = base.get_evaluation('named_eval/class_test')
+ self.assertIsInstance(evaluation, MyEval)
+ self.assertIsNone(evaluation.dir)
+ self.assertIsNone(evaluation.root_dir)
+ self.assertIn('named_eval/class_test', base.registered_names())
+
+ with self.assertRaisesRegex(ValueError, 'Unsupported type.*'):
+ @base.register('named_eval/bad_class')
+ class Foo: # pylint: disable=unused-variable
+ pass
+
+ def test_named_eval_functor(self):
+
+ @base.register('named_eval/functor_test')
+ def my_eval():
+ return base.Evaluation(
+ inputs=base.as_inputs([
+ pg.Dict(question='Compute 1 + 1'),
+ ]),
+ method='query',
+ prompt=pg.oneof([
+ lf.Template('{{example.question}}'),
+ lf.Template('Hello {{example.question}}'),
+ ]),
+ schema_fn=answer_schema(),
  )
- except SystemExit:
- pass
+
+ self.assertTrue(issubclass(my_eval, base.Evaluable))
+ evaluation = base.get_evaluation('named_eval/functor_test')
+ self.assertIn('named_eval/functor_test', base.registered_names())
+ self.assertIsInstance(evaluation, my_eval)
+ self.assertIsNone(evaluation.root_dir, None)
+
+ with self.assertRaisesRegex(ValueError, 'Evaluation .* not found'):
+ base.get_evaluation('named_eval/non_existent')
+
+ with self.assertRaisesRegex(TypeError, 'The return value .*'):
+ @base.register('named_eval/bad_return_type')
+ def bad_eval(): # pylint: disable=unused-variable
+ return 1
+
+ def test_run(self):
+ @base.register('test/run')
+ def test_run(): # pylint: disable=unused-variable
+ lm = fake.StaticResponse('Solution(final_answer=2)')
+ return eval_set('run_test', 'query', schema_fn=answer_schema(), lm=lm)
+
+ e = base.run(
+ tempfile.gettempdir(),
+ ['test/run'],
+ id_regex='run_test.*',
+ mode='dryrun',
+ print_definition=True,
+ )
+ self.assertEqual(
+ e.leaf_nodes[0].dir,
+ os.path.join(tempfile.gettempdir(), e.leaf_nodes[0].id),
+ )
+ self.assertTrue(
+ pg.eq(
+ e.leaf_nodes[0].lm, fake.StaticResponse('Solution(final_answer=2)')
+ )
+ )
+
+ @pg.patcher()
+ def bad_lm(unused_eval): # pylint: disable=unused-variable
+ return dict(lm=fake.StaticResponse('efg'))
+
+ e = base.run(
+ tempfile.gettempdir(),
+ [test_run()],
+ filter='Evaluation.*',
+ patches=['bad_lm']
+ )
+ self.assertTrue(pg.eq(e.leaf_nodes[0].lm, fake.StaticResponse('efg')))


  if __name__ == '__main__':
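
The hunks above replace the old AppRunTest with a NamedEvaluationTest that exercises a new registry for named evaluations (base.register, base.get_evaluation, base.registered_names) and a new base.run entry point. As a rough illustration of how that API appears to be used, here is a minimal sketch; the registration name 'my_bench/arithmetic' and the schema functor my_solution_schema are illustrative stand-ins (the tests rely on an answer_schema() helper that is not shown in this diff):

import tempfile

import langfun.core as lf
from langfun.core.eval import base
from langfun.core.llms import fake
import pyglove as pg


@pg.functor()
def my_solution_schema():
  # Hypothetical stand-in for the tests' answer_schema() helper.
  class Solution(pg.Object):
    final_answer: int
  return Solution


@base.register('my_bench/arithmetic')  # Name is illustrative.
def arithmetic_eval():
  # Build an evaluation the same way the functor-based test above does.
  return base.Evaluation(
      inputs=base.as_inputs([pg.Dict(question='Compute 1 + 1')]),
      method='query',
      prompt=lf.Template('{{example.question}}'),
      schema_fn=my_solution_schema(),
      lm=fake.StaticResponse('Solution(final_answer=2)'),
  )


# Dry-run the registered evaluation by name, mirroring
# NamedEvaluationTest.test_run above.
experiment = base.run(
    tempfile.gettempdir(), ['my_bench/arithmetic'], mode='dryrun')
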
langfun/core/eval/matching.py

@@ -119,18 +119,18 @@ class Matching(base.Evaluation):
  del progress
  return {
  'Model': self.lm.model_id,
- 'Matches': f'%.{self.report_precision}f%% (%d/%d)' % (
- self.match_rate * 100,
+ 'Matches': '%s (%d/%d)' % (
+ self._format_rate(self.match_rate),
  self.num_matches,
  self.num_completed,
  ),
- 'Mismatches': f'%.{self.report_precision}f%% (%d/%d)' % (
- self.mismatch_rate * 100,
+ 'Mismatches': '%s (%d/%d)' % (
+ self._format_rate(self.mismatch_rate),
  self.num_mismatches,
  self.num_completed,
  ),
- 'Failed': f'%.{self.report_precision}f%% (%d/%d)' % (
- self.failure_rate * 100,
+ 'Failed': '%s (%d/%d)' % (
+ self._format_rate(self.failure_rate),
  self.num_failures,
  self.num_completed,
  ),
@@ -140,24 +140,25 @@ class Matching(base.Evaluation):
  assert self.result is not None
  m = self.result.metrics
  return (
- f'COMPLETED(%s): Matches=%.{self.report_precision}f%% (%d/%d)'
- f' Mismatches=%.{self.report_precision}f%% (%d/%d)'
- f' Failures=%.{self.report_precision}f%% (%d/%d)'
+ 'COMPLETED(%s):'
+ ' Matches=%s (%d/%d)'
+ ' Mismatches=%s (%d/%d)'
+ ' Failures=%s (%d/%d)'
  ) % (
  run_status,
- m.match_rate * 100,
+ self._format_rate(m.match_rate),
  m.num_matches,
  m.total,
- m.mismatch_rate * 100,
+ self._format_rate(m.mismatch_rate),
  m.num_mismatches,
  m.total,
- m.failure_rate * 100,
+ self._format_rate(m.failure_rate),
  m.failures,
  m.total,
  )

- def summarize(self) -> pg.Dict:
- result = super().summarize()
+ def finalize(self) -> pg.Dict:
+ result = super().finalize()
  result.metrics.update(
  num_matches=self.num_matches,
  match_rate=self.match_rate,
@@ -218,9 +219,9 @@ class Matching(base.Evaluation):
  def _render_result_row(self, s: io.StringIO):
  super()._render_result_row(s)
  s.write(
- '<td><span style="color:red">%s</span>%s</td>'
+ '<td><span style="color:orange">%s</span>%s</td>'
  % (
- f'%.{self.report_precision}f%% ' % (self.mismatch_rate * 100),
+ self._format_rate(self.mismatch_rate),
  '<a href="%s">(%d/%d)</a>'
  % (self.mismatches_link, self.num_mismatches, self.num_completed),
  )
@@ -228,13 +229,13 @@ class Matching(base.Evaluation):
  s.write(
  '<td><span style="color:green">%s</span>%s</td>'
  % (
- f'%.{self.report_precision}f%% ' % (self.match_rate * 100),
+ self._format_rate(self.match_rate),
  '<a href="%s">(%d/%d)</a>'
  % (self.matches_link, self.num_matches, self.num_completed),
  )
  )

- def _render_metric(self, s: io.StringIO) -> None:
+ def _render_summary_metrics(self, s: io.StringIO) -> None:
  """Renders metrics in HTML."""
  assert self.result is not None
  m = self.result.metrics
@@ -244,7 +245,7 @@ class Matching(base.Evaluation):
  m.num_matches,
  m.total,
  self.matches_link,
- f'%.{self.report_precision}f%% ' % (m.match_rate * 100),
+ self._format_rate(m.match_rate),
  )
  )
  s.write(' | ')
@@ -254,11 +255,11 @@ class Matching(base.Evaluation):
  m.num_mismatches,
  m.total,
  self.mismatches_link,
- f'%.{self.report_precision}f%% ' % (m.mismatch_rate * 100),
+ self._format_rate(m.mismatch_rate),
  )
  )
  s.write(' | ')
- super()._render_metric(s)
+ super()._render_summary_metrics(s)

  def _render_matches(self, s: io.StringIO) -> None:
  """Formats the matched cases into html."""
langfun/core/eval/matching_test.py

@@ -120,6 +120,13 @@ class MatchingTest(unittest.TestCase):
  total=4,
  failures=1,
  failure_rate=0.25,
+ oop_failures=1,
+ oop_failure_rate=0.25,
+ non_oop_failures=0,
+ non_oop_failure_rate=0.0,
+ failure_breakdown={
+ 'MappingError.SchemaError.TypeError': 1
+ },
  num_matches=2,
  match_rate=0.5,
  num_mismatches=1,
@@ -160,7 +167,14 @@ class MatchingTest(unittest.TestCase):
  self.assertTrue(
  os.path.exists(
  os.path.join(
- s.dir, matching.Matching.FAILURES_JSON
+ s.dir, matching.Matching.OOP_FAILURES_JSON
+ )
+ )
+ )
+ self.assertTrue(
+ os.path.exists(
+ os.path.join(
+ s.dir, matching.Matching.NON_OOP_FAILURES_JSON
  )
  )
  )
@@ -175,7 +189,14 @@ class MatchingTest(unittest.TestCase):
  self.assertTrue(
  os.path.exists(
  os.path.join(
- s.dir, matching.Matching.FAILURES_HTML
+ s.dir, matching.Matching.OOP_FAILURES_HTML
+ )
+ )
+ )
+ self.assertTrue(
+ os.path.exists(
+ os.path.join(
+ s.dir, matching.Matching.NON_OOP_FAILURES_HTML
  )
  )
  )
langfun/core/eval/patching.py (new file)

@@ -0,0 +1,130 @@
+ # Copyright 2024 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Experiment patching for Langfun evaluations."""
+
+ import inspect
+ from typing import Union
+ import langfun.core as lf
+ from langfun.core import llms as lf_llms
+ from langfun.core.eval import base
+ import pyglove as pg
+
+
+ #
+ # Program-based patchers.
+ #
+
+
+ def patch_member(cls, key, value, parent_key: str | None = None):
+ """Patches a member of a class."""
+
+ def _rebind_fn(k, v, p):
+ if (
+ isinstance(p, cls)
+ and k.key == key
+ and (parent_key is None or (p and p.sym_path.key == parent_key))
+ ):
+ if inspect.isfunction(value):
+ return value(k, v, p)
+ return value
+ return v
+
+ return _rebind_fn
+
+
+ def patch_lm(lm: Union[lf.LanguageModel, pg.hyper.OneOf]): # pylint: disable=redefined-outer-name
+ """Patches the LLM of evaluations."""
+ return patch_member(base.Evaluable, "lm", lm)
+
+
+ def patch_parsing_lm(lm: Union[lf.LanguageModel, pg.hyper.OneOf]): # pylint: disable=redefined-outer-name
+ """Patches the parsing LLM of evaluations."""
+ return patch_member(base.Evaluable, "parsing_lm", lm)
+
+
+ def patch_schema_fn(schema_fn: Union[pg.Functor, pg.hyper.OneOf]):
+ """Patches the schema_fn of evaluations."""
+ return patch_member(base.Evaluable, "schema_fn", schema_fn)
+
+
+ def patch_prompt(prompt: Union[str, lf.Template, pg.hyper.OneOf]):
+ """Patches the prompt of evaluations."""
+ return patch_member(base.Evaluable, "prompt", prompt)
+
+
+ def patch_inputs(inputs: Union[pg.Functor, pg.hyper.OneOf]):
+ """Patches the inputs used in evaluations."""
+ return patch_member(base.Evaluable, "inputs", inputs)
+
+
+ def patch_additional_args(**kwargs):
+ """Patches additional_args."""
+
+ def value_fn(k, unused_v, p):
+ # We infer the symbolic value for the old args, as it might be a
+ # contextual attribute referring to its containing object.
+ old_args = p.sym_inferred(k.key)
+ if old_args:
+ old_args = dict(old_args)
+ old_args.update(kwargs)
+ return old_args
+ return kwargs
+
+ return patch_member(base.Evaluable, "additional_args", value_fn)
+
+
+ #
+ # String-based patching.
+ #
+
+ _NAMED_MODELS = {
+ # GPT models.
+ "gpt35turbo": lf_llms.Gpt35Turbo,
+ "gpt35turbo16k": lf_llms.Gpt35Turbo16K,
+ "gpt4": lf_llms.Gpt4,
+ "gpt4turbo": lf_llms.Gpt4Turbo,
+ # Anthropic models.
+ "haiku": lf_llms.Claude3Haiku,
+ "claude3haiku": lf_llms.Claude3Haiku,
+ "opus": lf_llms.Claude3Opus,
+ "claude3opus": lf_llms.Claude3Opus,
+ "sonnet": lf_llms.Claude3Sonnet,
+ "claude3sonnet": lf_llms.Claude3Opus,
+ }
+
+
+ def model_by_name(name: str) -> lf.LanguageModel:
+ """Gets model by name."""
+ name = name.strip().lower()
+ if name in _NAMED_MODELS:
+ return _NAMED_MODELS[name]()
+ raise ValueError(f"Unknown model name: {name}")
+
+
+ @pg.patcher(auto_typing=True)
+ def lm(unused_eval, models: list[str]):
+ """Patch the LM used for benchmarking."""
+ return patch_lm(pg.oneof([model_by_name(name) for name in models]))
+
+
+ @pg.patcher(auto_typing=True)
+ def temperature(unused_eval, value: float):
+ """Patch the temperature used for benchmarking."""
+ return patch_member(lf.LMSamplingOptions, "temperature", value)
+
+
+ @pg.patcher(auto_typing=True)
+ def max_tokens(unused_eval, value: int | None):
+ """Patch the temperature used for benchmarking."""
+ return patch_member(lf.LMSamplingOptions, "max_tokens", value)
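
The new patching.py module builds rebind functions (via patch_member) that can be applied to an experiment before it runs; in the base_test.py hunk above, a pg.patcher-decorated function is passed to base.run by name through patches=['bad_lm']. The following is a minimal sketch of that same pattern, not taken from the diff: the patcher name cheap_lm and the target 'my_bench/arithmetic' (from the earlier sketch) are illustrative choices.

import tempfile

from langfun.core.eval import base
from langfun.core.llms import fake
import pyglove as pg


@pg.patcher()
def cheap_lm(unused_eval):
  # Override the LM of every evaluation with a canned response,
  # the same way bad_lm does in the test above.
  return dict(lm=fake.StaticResponse('Solution(final_answer=2)'))


# Registered patchers can be referenced by name when launching a run.
experiment = base.run(
    tempfile.gettempdir(),
    ['my_bench/arithmetic'],  # Hypothetical registration from the earlier sketch.
    patches=['cheap_lm'],
    mode='dryrun',
)
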