langfun 0.0.2.dev20240319__py3-none-any.whl → 0.0.2.dev20240429__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in that public registry.
Files changed (52)
  1. langfun/__init__.py +2 -0
  2. langfun/core/__init__.py +1 -0
  3. langfun/core/coding/python/correction.py +0 -7
  4. langfun/core/component.py +6 -0
  5. langfun/core/component_test.py +1 -0
  6. langfun/core/eval/__init__.py +2 -0
  7. langfun/core/eval/base.py +240 -37
  8. langfun/core/eval/base_test.py +52 -18
  9. langfun/core/eval/matching.py +26 -9
  10. langfun/core/eval/matching_test.py +3 -4
  11. langfun/core/eval/scoring.py +15 -6
  12. langfun/core/eval/scoring_test.py +2 -2
  13. langfun/core/langfunc.py +0 -5
  14. langfun/core/langfunc_test.py +6 -4
  15. langfun/core/language_model.py +124 -24
  16. langfun/core/language_model_test.py +249 -26
  17. langfun/core/llms/__init__.py +24 -5
  18. langfun/core/llms/anthropic.py +263 -0
  19. langfun/core/llms/anthropic_test.py +167 -0
  20. langfun/core/llms/cache/in_memory_test.py +37 -28
  21. langfun/core/llms/fake.py +31 -22
  22. langfun/core/llms/fake_test.py +122 -11
  23. langfun/core/llms/{gemini.py → google_genai.py} +117 -15
  24. langfun/core/llms/{gemini_test.py → google_genai_test.py} +83 -15
  25. langfun/core/llms/groq.py +260 -0
  26. langfun/core/llms/groq_test.py +170 -0
  27. langfun/core/llms/llama_cpp.py +3 -1
  28. langfun/core/llms/openai.py +97 -79
  29. langfun/core/llms/openai_test.py +285 -59
  30. langfun/core/modalities/video.py +5 -2
  31. langfun/core/structured/__init__.py +3 -0
  32. langfun/core/structured/completion_test.py +2 -2
  33. langfun/core/structured/function_generation.py +245 -0
  34. langfun/core/structured/function_generation_test.py +329 -0
  35. langfun/core/structured/mapping.py +59 -3
  36. langfun/core/structured/mapping_test.py +17 -0
  37. langfun/core/structured/parsing.py +2 -1
  38. langfun/core/structured/parsing_test.py +18 -13
  39. langfun/core/structured/prompting.py +27 -6
  40. langfun/core/structured/prompting_test.py +79 -12
  41. langfun/core/structured/schema.py +25 -22
  42. langfun/core/structured/schema_generation.py +2 -3
  43. langfun/core/structured/schema_generation_test.py +2 -2
  44. langfun/core/structured/schema_test.py +42 -27
  45. langfun/core/template.py +125 -10
  46. langfun/core/template_test.py +75 -0
  47. langfun/core/templates/selfplay_test.py +6 -2
  48. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/METADATA +3 -2
  49. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/RECORD +52 -46
  50. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/LICENSE +0 -0
  51. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/WHEEL +0 -0
  52. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/top_level.txt +0 -0
langfun/core/eval/base_test.py CHANGED
@@ -70,8 +70,7 @@ def eval_set(
  """Creates an evaluation object for testing."""
  tmp_dir = tempfile.gettempdir()
  return cls(
- id=eval_id,
- root_dir=tmp_dir,
+ root_dir=os.path.join(tmp_dir, eval_id),
  inputs=base.as_inputs([
  pg.Dict(question='Compute 1 + 1'),
  pg.Dict(question='Compute 1 + 2'),
@@ -102,7 +101,7 @@ class EvaluationTest(unittest.TestCase):
  self.assertEqual(s.dir, os.path.join(s.root_dir, s.id))
  self.assertEqual(s.hash, s.clone().hash)
  # Test persistent hash.
- self.assertEqual(s.hash, 'abc7c29a')
+ self.assertEqual(s.hash, 'ae86c703')
  self.assertEqual(
  s.hash, s.clone(override={'max_workers': 2, 'lm.timeout': 20}).hash
  )
@@ -195,6 +194,7 @@ class EvaluationTest(unittest.TestCase):
  cache_seed=0,
  score=1.0,
  logprobs=None,
+ usage=lf.LMSamplingUsage(387, 24, 411),
  tags=['lm-response', 'lm-output', 'transformed'],
  ),
  )
@@ -210,7 +210,7 @@ class EvaluationTest(unittest.TestCase):
  s.result,
  dict(
  experiment_setup=dict(
- id='run_test',
+ id='Evaluation@0fade07d',
  dir=s.dir,
  model='StaticSequence',
  prompt_template='{{example.question}}',
@@ -221,6 +221,14 @@ class EvaluationTest(unittest.TestCase):
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
  metrics=dict(total=2, failures=1, failure_rate=0.5),
+ usage=dict(
+ total_prompt_tokens=774,
+ total_completion_tokens=25,
+ num_usages=2,
+ average_prompt_tokens=387,
+ average_completion_tokens=12,
+ average_total_tokens=399,
+ ),
  ),
  )
  self.assertTrue(
@@ -229,13 +237,23 @@ class EvaluationTest(unittest.TestCase):
  os.path.exists(os.path.join(s.dir, base.Evaluation.RESULT_JSON)))
  self.assertTrue(
  os.path.exists(os.path.join(s.dir, base.Evaluation.CACHE_JSON)))
- self.assertTrue(
- os.path.exists(os.path.join(s.root_dir, base.Evaluation.SUMMARY_HTML))
- )
  self.assertTrue(
  os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
  self.assertTrue(
  os.path.exists(os.path.join(s.dir, base.Evaluation.FAILURES_HTML)))
+ self.assertTrue(
+ os.path.exists(os.path.join(s.root_dir, base.Evaluation.SUMMARY_HTML))
+ )
+ # Check summary JSON.
+ summary_json = os.path.join(
+ s.root_dir, base.Evaluation.SUMMARY_HTML.replace('.html', '.json')
+ )
+ self.assertTrue(os.path.exists(summary_json))
+ summary = pg.load(summary_json, force_dict=True)
+ self.assertIn('Evaluation', summary)
+ self.assertEqual(len(summary['Evaluation']), 1)
+ self.assertIsNotNone(summary['Evaluation'][0].experiment)
+ self.assertIsNotNone(summary['Evaluation'][0].metrics)

  def test_run_wihtout_save(self):
  lm = fake.StaticSequence([
@@ -275,8 +293,11 @@ class EvaluationTest(unittest.TestCase):
  s = eval_set(
  'run_filter_test', pg.oneof(['call', 'query']),
  schema_fn=answer_schema(), lm=lm)
+ result = s.run(
+ filter=lambda x: x.method == 'query', dryrun=True, summary=False
+ )
  self.assertEqual(
- s.run(filter=lambda x: x.method == 'query', dryrun=True, summary=False),
+ result,
  {
  s.children[0].id: None,
  s.children[1].id: dict(
@@ -292,7 +313,8 @@ class EvaluationTest(unittest.TestCase):
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
  metrics=dict(total=2, failures=0, failure_rate=0.0),
- )
+ usage=s.children[1].result.usage,
+ ),
  },
  )

@@ -302,7 +324,6 @@ class EvaluationTest(unittest.TestCase):
  '3',
  ])
  s = base.Evaluation(
- id='search_space_test',
  root_dir=tempfile.gettempdir(),
  inputs=base.as_inputs([
  pg.Dict(question='Compute 1 + 1'),
@@ -323,11 +344,10 @@ class EvaluationTest(unittest.TestCase):
  s.children[0].dir, os.path.join(s.root_dir, s.children[0].id)
  )
  # Test persistent hash.
- self.assertEqual(s.hash, 'ca7f722b')
+ self.assertEqual(s.hash, 'b66a4e88')

  summary = s.run(verbose=True)
  self.assertEqual(len(summary.evaluations), 2)
-
  self.assertEqual(
  s.result,
  {
@@ -344,6 +364,7 @@ class EvaluationTest(unittest.TestCase):
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
  metrics=dict(total=2, failures=1, failure_rate=0.5),
+ usage=s.children[0].result.usage,
  ),
  s.children[1].id: dict(
  experiment_setup=dict(
@@ -358,6 +379,7 @@ class EvaluationTest(unittest.TestCase):
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
  metrics=dict(total=2, failures=1, failure_rate=0.5),
+ usage=s.children[1].result.usage,
  ),
  },
  )
@@ -439,7 +461,6 @@ class SuiteTest(unittest.TestCase):
  '3',
  ] * 5)
  s = base.Suite(
- 'suite_run_test',
  [
  eval_set('run_test_1', 'query', schema_fn=answer_schema()),
  # A suite of search space. Two of the sub-experiments are identical,
@@ -451,7 +472,7 @@ class SuiteTest(unittest.TestCase):
  lm=lm
  )
  # Test for persistent hash.
- self.assertEqual(s.hash, '7285e52b')
+ self.assertEqual(s.hash, '26e6cc25')
  s.run()
  expected = {
  s.children[0].id: dict(
@@ -467,6 +488,7 @@ class SuiteTest(unittest.TestCase):
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
  metrics=dict(total=2, failures=1, failure_rate=0.5),
+ usage=s.children[0].result.usage,
  ),
  s.children[1].id: {
  s.children[1]
@@ -484,6 +506,7 @@ class SuiteTest(unittest.TestCase):
  use_cache=True, num_queries=4, num_hits=1, num_updates=3
  ),
  metrics=dict(total=2, failures=2, failure_rate=1.0),
+ usage=s.children[1].children[0].result.usage,
  ),
  s.children[1]
  .children[2]
@@ -503,6 +526,7 @@ class SuiteTest(unittest.TestCase):
  num_updates=2,
  ),
  metrics=dict(total=2, failures=1, failure_rate=0.5),
+ usage=s.children[1].children[2].result.usage,
  ),
  },
  }
@@ -548,7 +572,6 @@ class SummaryTest(unittest.TestCase):
  def _eval_set(self, root_dir):
  return base.Suite(id='select_test', children=[
  TaskA(
- id='task_a',
  inputs=base.as_inputs([
  pg.Dict(question='Compute 1 + 1'),
  ]),
@@ -569,7 +592,6 @@ class SummaryTest(unittest.TestCase):
  max_workers=1,
  ),
  TaskB(
- id='task_b',
  inputs=base.as_inputs([
  pg.Dict(question='Compute 1 + 1'),
  ]),
@@ -650,10 +672,10 @@ class SummaryTest(unittest.TestCase):
  len(base.Summary.from_dirs(root_dir)), 2 * 2 * 2 * 2 + 2 * 1 * 1 * 2
  )
  self.assertEqual(
- len(base.Summary.from_dirs(root_dir, 'task_b')), 2 * 1 * 1 * 2
+ len(base.Summary.from_dirs(root_dir, 'TaskB')), 2 * 1 * 1 * 2
  )
  self.assertEqual(
- len(base.Summary.from_dirs(root_dir, ('task_a'))), 2 * 2 * 2 * 2
+ len(base.Summary.from_dirs(root_dir, ('TaskA'))), 2 * 2 * 2 * 2
  )

  def test_monitor(self):
@@ -676,5 +698,17 @@ class SummaryTest(unittest.TestCase):
  self.assertTrue(pg.io.path_exists(summary_file))


+ class AppRunTest(unittest.TestCase):
+
+ def test_app_run(self):
+ lm = fake.StaticSequence(['two', 'Solution(final_answer=2)'])
+ try:
+ base.app_run(
+ eval_set('app_run_test', 'query', schema_fn=answer_schema(), lm=lm)
+ )
+ except SystemExit:
+ pass
+
+
  if __name__ == '__main__':
  unittest.main()
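The new usage block in the expected results above shows how per-call LM usage is rolled up per evaluation: totals are summed over all queries, and the averages appear to use integer division (774 prompt tokens over 2 queries gives 387, 25 completion tokens gives 12, 799 total tokens gives 399). A minimal sketch of that aggregation follows; the Usage dataclass and summarize_usage helper are illustrative only (not langfun's implementation), and the second call's usage is inferred from the totals.

```python
from dataclasses import dataclass


@dataclass
class Usage:
  # Field order here follows lf.LMSamplingUsage(387, 24, 411) above,
  # read as (prompt, completion, total).
  prompt_tokens: int
  completion_tokens: int
  total_tokens: int


def summarize_usage(usages):
  """Aggregates per-call usage the way the expected test results suggest."""
  n = len(usages)
  total_prompt = sum(u.prompt_tokens for u in usages)
  total_completion = sum(u.completion_tokens for u in usages)
  total = sum(u.total_tokens for u in usages)
  return dict(
      total_prompt_tokens=total_prompt,
      total_completion_tokens=total_completion,
      num_usages=n,
      average_prompt_tokens=total_prompt // n,
      average_completion_tokens=total_completion // n,
      average_total_tokens=total // n,
  )


# Two queries, as in the test above; the second usage (387, 1, 388) is inferred
# from the totals 774 / 25 / 799.
print(summarize_usage([Usage(387, 24, 411), Usage(387, 1, 388)]))
# {'total_prompt_tokens': 774, 'total_completion_tokens': 25, 'num_usages': 2,
#  'average_prompt_tokens': 387, 'average_completion_tokens': 12,
#  'average_total_tokens': 399}
```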
langfun/core/eval/matching.py CHANGED
@@ -86,9 +86,26 @@ class Matching(base.Evaluation):
  self._matches = []
  self._mismatches = []

- def audit(self, example: Any, output: Any, message: lf.Message) -> None:
+ def audit_processed(
+ self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+ ) -> None:
  groundtruth = self.groundtruth(example)
  answer = self.answer(output, example)
+
+ if dryrun:
+ lf.console.write('')
+ lf.console.write(
+ str(groundtruth),
+ title='GROUDTRUTH',
+ color='green',
+ )
+ lf.console.write('')
+ lf.console.write(
+ str(answer),
+ title='ANSWER',
+ color='blue',
+ )
+
  if self.match(answer, groundtruth):
  self._matches.append((example, output, message))
  else:
@@ -155,19 +172,16 @@ class Matching(base.Evaluation):
  super().save(definition, result, report)

  if result:
-
- def force_dict(v):
- return pg.object_utils.json_conversion.strip_types(pg.to_json(v))
-
  # Save matches.
  pg.save(
  [
- # We force the output to be dict as its type may be defined
- # within functors which could be deserialized.
- pg.Dict(input=input, output=force_dict(output))
+ pg.Dict(input=input, output=output)
  for input, output, _ in self.matches
  ],
  os.path.join(self.dir, Matching.MATCHES_JSON),
+ # We force the input and output to be dict so it does not depend on
+ # the downstream to serialize.
+ force_dict=True,
  )

  # Save mismatches.
@@ -175,10 +189,13 @@ class Matching(base.Evaluation):
  [
  # We force the output to be dict as its type may be defined
  # within functors which could be deserialized.
- pg.Dict(input=input, output=force_dict(output))
+ pg.Dict(input=input, output=output)
  for input, output, _ in self.mismatches
  ],
  os.path.join(self.dir, Matching.MISMATCHES_JSON),
+ # We force the input and output to be dict so it does not depend on
+ # the downstream to serialize.
+ force_dict=True,
  )

  if report:
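The save-path change above drops the local force_dict helper (which called pg.object_utils.json_conversion.strip_types) in favor of pg.save(..., force_dict=True), so matches and mismatches are written as plain JSON dicts that can be read back without importing the classes that produced the outputs. A rough sketch of that round trip under this assumption follows; the Answer class is invented here for illustration, while pg.save and pg.load with force_dict=True are used as they appear in this diff.

```python
import pyglove as pg


class Answer(pg.Object):
  # Illustrative output type; in practice it may be defined inside a functor
  # and therefore not importable by downstream readers of the JSON file.
  final_answer: int


records = [
    pg.Dict(input=pg.Dict(question='Compute 1 + 1'), output=Answer(final_answer=2)),
]

# force_dict=True strips symbolic type information at save time, so the file
# holds plain dicts/lists instead of typed object payloads.
pg.save(records, '/tmp/matches.json', force_dict=True)

# Reading back does not require the Answer class to be on the import path.
loaded = pg.load('/tmp/matches.json', force_dict=True)
print(loaded[0]['output']['final_answer'])  # 2
```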
langfun/core/eval/matching_test.py CHANGED
@@ -65,10 +65,8 @@ def eval_set(
  use_cache: bool = True,
  ):
  """Creates an evaluation object for testing."""
- tmp_dir = tempfile.gettempdir()
  return MyTask(
- id=eval_id,
- root_dir=tmp_dir,
+ root_dir=os.path.join(tempfile.gettempdir(), eval_id),
  inputs=base.as_inputs([
  pg.Dict(question='Compute 1 + 1', groundtruth=2),
  pg.Dict(question='Compute 1 + 2', groundtruth=3),
@@ -105,7 +103,7 @@ class MatchingTest(unittest.TestCase):
  s.result,
  dict(
  experiment_setup=dict(
- id='match_run_test',
+ id='MyTask@739a174b',
  dir=s.dir,
  model='StaticSequence',
  prompt_template='{{example.question}}',
@@ -127,6 +125,7 @@ class MatchingTest(unittest.TestCase):
  num_mismatches=1,
  mismatch_rate=0.25,
  ),
+ usage=s.result.usage,
  ),
  )
  self.assertTrue(
langfun/core/eval/scoring.py CHANGED
@@ -61,8 +61,18 @@ class Scoring(base.Evaluation):
  super()._reset()
  self._scored = []

- def audit(self, example: Any, output: Any, message: lf.Message) -> None:
+ def audit_processed(
+ self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+ ) -> None:
  score = self.score(example, output)
+
+ if dryrun:
+ lf.console.write('')
+ lf.console.write(
+ str(score),
+ title='SCORE',
+ color='blue',
+ )
  self._scored.append((example, output, score, message))

  @abc.abstractmethod
@@ -118,19 +128,18 @@ class Scoring(base.Evaluation):
  super().save(definition, result, report)

  if result:
-
- def force_dict(v):
- return pg.object_utils.json_conversion.strip_types(pg.to_json(v))
-
  # Save scored.
  pg.save(
  [
  # We force the output to be dict as its type may be defined
  # within functors which could be deserialized.
- pg.Dict(input=input, output=force_dict(output), score=score)
+ pg.Dict(input=input, output=output, score=score)
  for input, output, score, _ in self.scored
  ],
  os.path.join(self.dir, Scoring.SCORED_JSON),
+ # We force the input and output to be dict so it does not depend on
+ # the downstream to serialize.
+ force_dict=True,
  )

  if report:
langfun/core/eval/scoring_test.py CHANGED
@@ -43,7 +43,6 @@ def constrained_by_upperbound(upper_bound: int):


  class ConstraintFollowing(scoring.Scoring):
- id = 'constraint_following'
  inputs = constrained_by_upperbound(1)
  prompt = '{{example}}'
  method = 'query'
@@ -82,7 +81,7 @@ class ScoringTest(unittest.TestCase):
  s.result,
  dict(
  experiment_setup=dict(
- id='constraint_following',
+ id='ConstraintFollowing@5c88a5eb',
  dir=s.dir,
  model='StaticSequence',
  prompt_template='{{example}}',
@@ -103,6 +102,7 @@ class ScoringTest(unittest.TestCase):
  score_rate=1.0,
  avg_score=0.5,
  ),
+ usage=s.result.usage,
  ),
  )
  self.assertTrue(
langfun/core/langfunc.py CHANGED
@@ -261,7 +261,6 @@ class LangFunc(
  if lm_input is None:
  lm_input = self.render(**kwargs)

- lm_input.tag(message_lib.Message.TAG_LM_INPUT)
  if skip_lm:
  return lm_input

@@ -270,10 +269,6 @@ class LangFunc(
  # Send rendered text to LM.
  lm_output = self.lm(lm_input, cache_seed=cache_seed)

- # Track the input as the source of the output.
- lm_output.source = lm_input
- lm_output.tag(message_lib.Message.TAG_LM_RESPONSE)
-
  # Transform the output message.
  lm_output = self.transform_output(lm_output)
  lm_output.tag(message_lib.Message.TAG_LM_OUTPUT)
langfun/core/langfunc_test.py CHANGED
@@ -82,7 +82,9 @@ class LangFuncCallTest(unittest.TestCase):
  self.assertEqual(i.tags, ['rendered'])

  r = l()
- self.assertEqual(r, message.AIMessage('Hello!!!', score=0.0, logprobs=None))
+ self.assertEqual(
+ r, message.AIMessage('Hello!!!', score=0.0, logprobs=None, usage=None)
+ )
  self.assertEqual(r.tags, ['lm-response', 'lm-output'])
  self.assertEqual(r.source, message.UserMessage('Hello'))
  self.assertEqual(r.source.tags, ['rendered', 'lm-input'])
@@ -92,8 +94,8 @@ class LangFuncCallTest(unittest.TestCase):
  self.assertEqual(
  repr(l),
  "LangFunc(template_str='Hello', clean=True,"
- ' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=0.0,'
- ' max_tokens=1024, n=1, top_k=40, top_p=None, stop=None,'
+ ' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=None,'
+ ' max_tokens=None, n=1, top_k=40, top_p=None, stop=None,'
  ' random_seed=None, logprobs=False, top_logprobs=None), cache=None,'
  ' max_concurrency=None, timeout=120.0, max_attempts=5,'
  ' retry_interval=(5, 60), exponential_backoff=True, debug=False))',
@@ -106,7 +108,7 @@ class LangFuncCallTest(unittest.TestCase):
  self.assertEqual(l.render(), 'Hello')
  r = l()
  self.assertEqual(
- r, message.AIMessage('Hello!!!', score=0.0, logprobs=None)
+ r, message.AIMessage('Hello!!!', score=0.0, logprobs=None, usage=None)
  )
  self.assertEqual(r.tags, ['lm-response', 'lm-output'])
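The updated repr above also records a defaults change: LMSamplingOptions now leaves temperature and max_tokens as None instead of 0.0 and 1024, so unset values can fall through to whatever each model treats as its own default. A small sketch of overriding them explicitly follows; it assumes LMSamplingOptions is exposed on the top-level langfun package, and the field names are taken from the repr in the test above.

```python
import langfun as lf

# Unset fields stay None after this change; override only what you need.
options = lf.LMSamplingOptions(temperature=0.7, max_tokens=256)
print(options.temperature, options.max_tokens, options.n)  # 0.7 256 1
```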