langfun 0.1.2.dev202507140805__py3-none-any.whl → 0.1.2.dev202507150805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,6 +17,7 @@ import collections
17
17
  from concurrent import futures
18
18
  import contextlib
19
19
  import io
20
+ import sys
20
21
  import time
21
22
  import unittest
22
23
  from langfun.core import component
@@ -330,6 +331,8 @@ class ProgressBarTest(unittest.TestCase):
330
331
  with self.assertRaisesRegex(ValueError, 'Unsupported status'):
331
332
  concurrent.ProgressBar.update(bar_id, 0, status=1)
332
333
  concurrent.ProgressBar.uninstall(bar_id)
334
+ sys.stderr.flush()
335
+ time.sleep(1)
333
336
  self.assertIn('1/4', string_io.getvalue())
334
337
  self.assertIn('2/4', string_io.getvalue())
335
338
  self.assertIn('hello', string_io.getvalue())
@@ -28,7 +28,7 @@ Example = example_lib.Example
28
28
  class SequenceWriterTest(unittest.TestCase):
29
29
 
30
30
  def test_basic(self):
31
- file = os.path.join(tempfile.gettempdir(), 'test.jsonl')
31
+ file = os.path.join(tempfile.mkdtemp(), 'test.jsonl')
32
32
  writer = checkpointing.SequenceWriter(file)
33
33
  example = Example(id=1, input=pg.Dict(x=1), output=2)
34
34
  writer.add(example)
@@ -36,7 +36,7 @@ class SequenceWriterTest(unittest.TestCase):
36
36
  self.assertTrue(pg.io.path_exists(file))
37
37
 
38
38
  def test_error_handling(self):
39
- file = os.path.join(tempfile.gettempdir(), 'test_error_handling.jsonl')
39
+ file = os.path.join(tempfile.mkdtemp(), 'test_error_handling.jsonl')
40
40
  writer = checkpointing.SequenceWriter(file)
41
41
  writer.add(Example(id=1, input=pg.Dict(x=1), output=2))
42
42
 
@@ -87,7 +87,7 @@ class CheckpointerTest(unittest.TestCase):
87
87
  class PerExampleCheckpointerTest(CheckpointerTest):
88
88
 
89
89
  def test_checkpointing(self):
90
- root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer')
90
+ root_dir = os.path.join(tempfile.mkdtemp(), 'per_example_checkpointer')
91
91
  experiment = eval_test_helper.test_experiment()
92
92
  checkpoint_filename = 'checkpoint.jsonl'
93
93
  checkpointer = checkpointing.PerExampleCheckpointer(checkpoint_filename)
@@ -119,7 +119,7 @@ class PerExampleCheckpointerTest(CheckpointerTest):
119
119
  self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])
120
120
 
121
121
  # Test warm start without reprocess.
122
- root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer2')
122
+ root_dir = os.path.join(tempfile.mkdtemp(), 'per_example_checkpointer2')
123
123
  experiment = eval_test_helper.test_experiment()
124
124
  _ = experiment.run(
125
125
  root_dir, 'new', runner='sequential', plugins=[checkpointer],
@@ -129,7 +129,7 @@ class PerExampleCheckpointerTest(CheckpointerTest):
129
129
  self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])
130
130
 
131
131
  # Test warm start with reprocess.
132
- root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer3')
132
+ root_dir = os.path.join(tempfile.mkdtemp(), 'per_example_checkpointer3')
133
133
  experiment = eval_test_helper.test_experiment()
134
134
  _ = experiment.run(
135
135
  root_dir, 'new', runner='sequential', plugins=[checkpointer],
@@ -139,7 +139,7 @@ class PerExampleCheckpointerTest(CheckpointerTest):
139
139
  for leaf in experiment.leaf_nodes:
140
140
  self.assertEqual(leaf.progress.num_skipped, 0)
141
141
 
142
- root_dir = os.path.join(tempfile.gettempdir(), 'per_example_checkpointer4')
142
+ root_dir = os.path.join(tempfile.mkdtemp(), 'per_example_checkpointer4')
143
143
  experiment = eval_test_helper.test_experiment()
144
144
  _ = experiment.run(
145
145
  root_dir, 'new', runner='sequential', plugins=[checkpointer],
@@ -151,7 +151,7 @@ class PerExampleCheckpointerTest(CheckpointerTest):
151
151
 
152
152
  def test_loading_corrupted_checkpoint(self):
153
153
  root_dir = os.path.join(
154
- tempfile.gettempdir(),
154
+ tempfile.mkdtemp(),
155
155
  'per_example_checkpointer_with_corrupted_checkpoint'
156
156
  )
157
157
  experiment = eval_test_helper.TestEvaluation()
@@ -178,7 +178,7 @@ class PerExampleCheckpointerTest(CheckpointerTest):
178
178
  num_processed[example.id] = i + 1
179
179
 
180
180
  root_dir = os.path.join(
181
- tempfile.gettempdir(),
181
+ tempfile.mkdtemp(),
182
182
  'per_example_checkpointer_with_corrupted_checkpoint_warm_start'
183
183
  )
184
184
  experiment = eval_test_helper.TestEvaluation()
@@ -192,7 +192,7 @@ class PerExampleCheckpointerTest(CheckpointerTest):
192
192
 
193
193
  def test_checkpointing_error(self):
194
194
  root_dir = os.path.join(
195
- tempfile.gettempdir(),
195
+ tempfile.mkdtemp(),
196
196
  'per_example_checkpointer_with_checkpointing_error'
197
197
  )
198
198
  experiment = (eval_test_helper
@@ -207,7 +207,7 @@ class PerExampleCheckpointerTest(CheckpointerTest):
207
207
  class BulkCheckpointerTest(CheckpointerTest):
208
208
 
209
209
  def test_checkpointing(self):
210
- root_dir = os.path.join(tempfile.gettempdir(), 'test_bulk_checkpointer')
210
+ root_dir = os.path.join(tempfile.mkdtemp(), 'test_bulk_checkpointer')
211
211
  experiment = eval_test_helper.test_experiment()
212
212
  checkpoint_filename = 'checkpoint.jsonl'
213
213
  checkpointer = checkpointing.BulkCheckpointer(checkpoint_filename)
@@ -238,7 +238,7 @@ class BulkCheckpointerTest(CheckpointerTest):
238
238
 
239
239
  def test_checkpointing_error(self):
240
240
  root_dir = os.path.join(
241
- tempfile.gettempdir(),
241
+ tempfile.mkdtemp(),
242
242
  'bulk_checkpointer_with_checkpointing_error'
243
243
  )
244
244
  experiment = (eval_test_helper
@@ -116,7 +116,7 @@ class EvaluationTest(unittest.TestCase):
116
116
  self.assertEqual(example.metric_metadata, dict(error='ValueError'))
117
117
 
118
118
  def test_evaluate_withstate(self):
119
- eval_dir = os.path.join(tempfile.gettempdir(), 'test_eval')
119
+ eval_dir = os.path.join(tempfile.mkdtemp(), 'test_eval')
120
120
  pg.io.mkdirs(eval_dir, exist_ok=True)
121
121
  state_file = os.path.join(eval_dir, 'state.jsonl')
122
122
  with pg.io.open_sequence(state_file, 'w') as f:
@@ -145,7 +145,7 @@ class RunIdTest(unittest.TestCase):
145
145
  )
146
146
 
147
147
  def test_get_latest(self):
148
- root_dir = os.path.join(tempfile.gettempdir(), 'test_eval')
148
+ root_dir = os.path.join(tempfile.mkdtemp(), 'test_eval')
149
149
  pg.io.mkdirs(os.path.join(root_dir, 'run_20241102_0'))
150
150
  pg.io.mkdirs(os.path.join(root_dir, 'run_20241101_0'))
151
151
  self.assertEqual(
@@ -153,15 +153,15 @@ class RunIdTest(unittest.TestCase):
153
153
  RunId.from_id('20241102_0')
154
154
  )
155
155
  self.assertIsNone(RunId.get_latest('/notexist'))
156
- self.assertIsNone(RunId.get_latest(tempfile.gettempdir()))
156
+ self.assertIsNone(RunId.get_latest(tempfile.mkdtemp()))
157
157
 
158
158
  def test_new(self):
159
159
  rid = RunId(date=datetime.date.today(), number=1)
160
160
  self.assertEqual(
161
- RunId.new(root_dir=os.path.join(tempfile.gettempdir(), 'test_new')),
161
+ RunId.new(root_dir=os.path.join(tempfile.mkdtemp(), 'test_new')),
162
162
  rid
163
163
  )
164
- root_dir = os.path.join(tempfile.gettempdir(), 'test_eval2')
164
+ root_dir = os.path.join(tempfile.mkdtemp(), 'test_eval2')
165
165
  pg.io.mkdirs(rid.dirname(root_dir))
166
166
  self.assertEqual(RunId.new(root_dir), rid.next())
167
167
 
@@ -185,13 +185,13 @@ class RunIdTest(unittest.TestCase):
185
185
  with self.assertRaisesRegex(
186
186
  ValueError, '.* no previous runs'
187
187
  ):
188
- RunId.from_id('latest', root_dir=tempfile.gettempdir())
188
+ RunId.from_id('latest', root_dir=tempfile.mkdtemp())
189
189
 
190
190
  self.assertEqual(
191
191
  RunId.from_id('20241102_1'),
192
192
  RunId(date=datetime.date(2024, 11, 2), number=1)
193
193
  )
194
- root_dir = os.path.join(tempfile.gettempdir(), 'test_eval3')
194
+ root_dir = os.path.join(tempfile.mkdtemp(), 'test_eval3')
195
195
  rid = RunId.from_id('20241102_1')
196
196
  pg.io.mkdirs(rid.dirname(root_dir))
197
197
  self.assertEqual(
@@ -413,7 +413,7 @@ class RunnerTest(unittest.TestCase):
413
413
  ),
414
414
  TestRunner
415
415
  )
416
- root_dir = os.path.join(tempfile.gettempdir(), 'my_eval')
416
+ root_dir = os.path.join(tempfile.mkdtemp(), 'my_eval')
417
417
 
418
418
  # Test standard run.
419
419
  MyEvaluation(replica_id=0).run(
@@ -34,7 +34,7 @@ class HtmlProgressTrackerTest(unittest.TestCase):
34
34
  lf_console._notebook = pg.Dict(
35
35
  display=display
36
36
  )
37
- root_dir = os.path.join(tempfile.gettempdir(), 'test_html_progress_tracker')
37
+ root_dir = os.path.join(tempfile.mkdtemp(), 'test_html_progress_tracker')
38
38
  experiment = eval_test_helper.test_experiment()
39
39
  _ = experiment.run(root_dir, 'new', plugins=[])
40
40
  self.assertIsInstance(result['view'], pg.Html)
@@ -44,7 +44,7 @@ class HtmlProgressTrackerTest(unittest.TestCase):
44
44
  class TqdmProgressTrackerTest(unittest.TestCase):
45
45
 
46
46
  def test_basic(self):
47
- root_dir = os.path.join(tempfile.gettempdir(), 'test_tqdm_progress_tracker')
47
+ root_dir = os.path.join(tempfile.mkdtemp(), 'test_tqdm_progress_tracker')
48
48
  experiment = eval_test_helper.test_experiment()
49
49
  string_io = io.StringIO()
50
50
  with contextlib.redirect_stderr(string_io):
@@ -53,7 +53,7 @@ class TqdmProgressTrackerTest(unittest.TestCase):
53
53
 
54
54
  def test_with_example_ids(self):
55
55
  root_dir = os.path.join(
56
- tempfile.gettempdir(), 'test_tqdm_progress_tracker_with_example_ids'
56
+ tempfile.mkdtemp(), 'test_tqdm_progress_tracker_with_example_ids'
57
57
  )
58
58
  experiment = eval_test_helper.test_experiment()
59
59
  string_io = io.StringIO()
@@ -25,7 +25,7 @@ import pyglove as pg
25
25
  class ReportingTest(unittest.TestCase):
26
26
 
27
27
  def test_reporting(self):
28
- root_dir = os.path.join(tempfile.gettempdir(), 'test_reporting')
28
+ root_dir = os.path.join(tempfile.mkdtemp(), 'test_reporting')
29
29
  experiment = eval_test_helper.test_experiment()
30
30
  checkpointer = checkpointing.BulkCheckpointer('checkpoint.jsonl')
31
31
  reporter = reporting.HtmlReporter()
@@ -49,7 +49,7 @@ class ReportingTest(unittest.TestCase):
49
49
  self.assertTrue(found_generation_log)
50
50
 
51
51
  # Test warm start.
52
- root_dir = os.path.join(tempfile.gettempdir(), 'test_reporting2')
52
+ root_dir = os.path.join(tempfile.mkdtemp(), 'test_reporting2')
53
53
  experiment = eval_test_helper.test_experiment()
54
54
  run = experiment.run(
55
55
  root_dir, 'new', plugins=[checkpointer, reporter],
@@ -75,7 +75,7 @@ class ReportingTest(unittest.TestCase):
75
75
 
76
76
  def test_index_html_generation_error(self):
77
77
  root_dir = os.path.join(
78
- tempfile.gettempdir(),
78
+ tempfile.mkdtemp(),
79
79
  'test_reporting_with_index_html_generation_error'
80
80
  )
81
81
  experiment = (eval_test_helper
@@ -98,7 +98,7 @@ class ReportingTest(unittest.TestCase):
98
98
 
99
99
  def test_example_html_generation_error(self):
100
100
  root_dir = os.path.join(
101
- tempfile.gettempdir(),
101
+ tempfile.mkdtemp(),
102
102
  'test_reporting_with_example_html_generation_error'
103
103
  )
104
104
  experiment = (eval_test_helper
@@ -126,7 +126,7 @@ class ReportingTest(unittest.TestCase):
126
126
 
127
127
  # Test warm start.
128
128
  root_dir = os.path.join(
129
- tempfile.gettempdir(),
129
+ tempfile.mkdtemp(),
130
130
  'test_reporting_with_example_html_generation_error2'
131
131
  )
132
132
  experiment = (eval_test_helper
@@ -103,7 +103,7 @@ class RunnerTest(unittest.TestCase):
103
103
  def test_basic(self):
104
104
  plugin = TestPlugin()
105
105
  exp = eval_test_helper.test_experiment()
106
- root_dir = os.path.join(tempfile.gettempdir(), 'test_sequential_runner')
106
+ root_dir = os.path.join(tempfile.mkdtemp(), 'test_sequential_runner')
107
107
  run = exp.run(root_dir, runner='sequential', plugins=[plugin])
108
108
 
109
109
  self.assertIsNotNone(plugin.start_time)
@@ -143,7 +143,7 @@ class RunnerTest(unittest.TestCase):
143
143
  self.assertEqual(node.progress.num_processed, node.progress.num_total)
144
144
 
145
145
  def test_raise_if_has_error(self):
146
- root_dir = os.path.join(tempfile.gettempdir(), 'test_raise_if_has_error')
146
+ root_dir = os.path.join(tempfile.mkdtemp(), 'test_raise_if_has_error')
147
147
  exp = eval_test_helper.TestEvaluation()
148
148
  with self.assertRaisesRegex(ValueError, 'x should not be 5'):
149
149
  exp.run(
@@ -154,7 +154,7 @@ class RunnerTest(unittest.TestCase):
154
154
  exp.run(root_dir, runner='parallel', plugins=[], raise_if_has_error=True)
155
155
 
156
156
  def test_example_ids(self):
157
- root_dir = os.path.join(tempfile.gettempdir(), 'test_example_ids')
157
+ root_dir = os.path.join(tempfile.mkdtemp(), 'test_example_ids')
158
158
  exp = eval_test_helper.test_experiment()
159
159
  plugin = TestPlugin()
160
160
  _ = exp.run(
@@ -164,7 +164,7 @@ class RunnerTest(unittest.TestCase):
164
164
  self.assertEqual(plugin.completed_example_ids, [5, 7, 9] * 6)
165
165
 
166
166
  def test_shuffle_inputs(self):
167
- root_dir = os.path.join(tempfile.gettempdir(), 'test_shuffle_inputs')
167
+ root_dir = os.path.join(tempfile.mkdtemp(), 'test_shuffle_inputs')
168
168
  exp = eval_test_helper.test_experiment()
169
169
  plugin = TestPlugin()
170
170
  run = exp.run(
@@ -175,7 +175,7 @@ class RunnerTest(unittest.TestCase):
175
175
  def test_filter(self):
176
176
  plugin = TestPlugin()
177
177
  exp = eval_test_helper.test_experiment()
178
- root_dir = os.path.join(tempfile.gettempdir(), 'test_filter')
178
+ root_dir = os.path.join(tempfile.mkdtemp(), 'test_filter')
179
179
 
180
180
  _ = exp.run(
181
181
  root_dir, runner='sequential', plugins=[plugin],
@@ -207,7 +207,7 @@ class RunnerTest(unittest.TestCase):
207
207
  inputs=test_inputs(num_examples=pg.oneof([2, 4]))
208
208
  )
209
209
  # Global cache.
210
- root_dir = os.path.join(tempfile.gettempdir(), 'global_cache')
210
+ root_dir = os.path.join(tempfile.mkdtemp(), 'global_cache')
211
211
  run = exp.run(
212
212
  root_dir, 'new', runner='sequential', use_cache='global', plugins=[]
213
213
  )
@@ -216,7 +216,7 @@ class RunnerTest(unittest.TestCase):
216
216
  self.assertEqual(exp.usage_summary.uncached.total.num_requests, 2)
217
217
 
218
218
  # Per-dataset cache.
219
- root_dir = os.path.join(tempfile.gettempdir(), 'per_dataset')
219
+ root_dir = os.path.join(tempfile.mkdtemp(), 'per_dataset')
220
220
  run = exp.run(
221
221
  root_dir, 'new', runner='sequential',
222
222
  use_cache='per_dataset', plugins=[]
@@ -229,7 +229,7 @@ class RunnerTest(unittest.TestCase):
229
229
  self.assertEqual(exp.usage_summary.uncached.total.num_requests, 3)
230
230
 
231
231
  # No cache.
232
- root_dir = os.path.join(tempfile.gettempdir(), 'no')
232
+ root_dir = os.path.join(tempfile.mkdtemp(), 'no')
233
233
  run = exp.run(root_dir, runner='sequential', use_cache='no', plugins=[])
234
234
  self.assertFalse(pg.io.path_exists(run.output_path_for(exp, 'cache.json')))
235
235
  for leaf in exp.leaf_nodes:
@@ -245,7 +245,7 @@ class ParallelRunnerTest(RunnerTest):
245
245
  def test_parallel_runner(self):
246
246
  plugin = TestPlugin()
247
247
  exp = eval_test_helper.test_experiment()
248
- root_dir = os.path.join(tempfile.gettempdir(), 'test_parallel_runner')
248
+ root_dir = os.path.join(tempfile.mkdtemp(), 'test_parallel_runner')
249
249
  run = exp.run(root_dir, runner='parallel', plugins=[plugin])
250
250
 
251
251
  self.assertIsNotNone(plugin.start_time)
@@ -286,7 +286,7 @@ class ParallelRunnerTest(RunnerTest):
286
286
  plugin = TestPlugin()
287
287
  exp = eval_test_helper.test_experiment()
288
288
  root_dir = os.path.join(
289
- tempfile.gettempdir(), 'test_concurrent_startup_delay'
289
+ tempfile.mkdtemp(), 'test_concurrent_startup_delay'
290
290
  )
291
291
  _ = exp.run(
292
292
  root_dir,
@@ -301,7 +301,7 @@ class DebugRunnerTest(RunnerTest):
301
301
  def test_debug_runner(self):
302
302
  plugin = TestPlugin()
303
303
  exp = eval_test_helper.test_experiment()
304
- root_dir = os.path.join(tempfile.gettempdir(), 'test_debug_runner')
304
+ root_dir = os.path.join(tempfile.mkdtemp(), 'test_debug_runner')
305
305
  run = exp.run(root_dir, runner='debug', plugins=[plugin])
306
306
 
307
307
  self.assertIsNotNone(plugin.start_time)
@@ -667,10 +667,7 @@ class LMDebugMode(enum.IntFlag):
667
667
  PROMPT = enum.auto()
668
668
  RESPONSE = enum.auto()
669
669
 
670
- @classmethod
671
- @property
672
- def ALL(cls) -> 'LMDebugMode': # pylint: disable=invalid-name
673
- return LMDebugMode.INFO | LMDebugMode.PROMPT | LMDebugMode.RESPONSE
670
+ ALL = INFO | PROMPT | RESPONSE
674
671
 
675
672
 
676
673
  class LanguageModel(component.Component):
@@ -1101,7 +1098,10 @@ class LanguageModel(component.Component):
1101
1098
  return [job.result for job in executed_jobs]
1102
1099
 
1103
1100
  def __call__(
1104
- self, prompt: message_lib.Message, *, cache_seed: int = 0, **kwargs
1101
+ self,
1102
+ prompt: str | message_lib.Message,
1103
+ *,
1104
+ cache_seed: int = 0, **kwargs
1105
1105
  ) -> message_lib.Message:
1106
1106
  """Returns the first candidate."""
1107
1107
  prompt = message_lib.UserMessage.from_value(prompt)
@@ -683,16 +683,14 @@ class LanguageModelTest(unittest.TestCase):
683
683
 
684
684
  debug_info = string_io.getvalue()
685
685
  expected_included = [
686
- debug_prints[f]
687
- for f in lm_lib.LMDebugMode
688
- if f != lm_lib.LMDebugMode.NONE and f in debug_mode
686
+ debug_prints[f] for f in (info_flag, prompt_flag, response_flag)
687
+ if f in debug_mode
689
688
  ]
690
689
  expected_excluded = [
691
690
  debug_prints[f]
692
- for f in lm_lib.LMDebugMode
693
- if f != lm_lib.LMDebugMode.NONE and f not in debug_mode
691
+ for f in (info_flag, prompt_flag, response_flag)
692
+ if f not in debug_mode
694
693
  ]
695
-
696
694
  for expected_include in expected_included:
697
695
  self.assertIn('[0] ' + expected_include, debug_info)
698
696
  for expected_exclude in expected_excluded:
@@ -750,13 +748,13 @@ class LanguageModelTest(unittest.TestCase):
750
748
  debug_info = string_io.getvalue()
751
749
  expected_included = [
752
750
  debug_prints[f]
753
- for f in lm_lib.LMDebugMode
754
- if f != lm_lib.LMDebugMode.NONE and f in debug_mode
751
+ for f in (info_flag, prompt_flag, response_flag)
752
+ if f in debug_mode
755
753
  ]
756
754
  expected_excluded = [
757
755
  debug_prints[f]
758
- for f in lm_lib.LMDebugMode
759
- if f != lm_lib.LMDebugMode.NONE and f not in debug_mode
756
+ for f in (info_flag, prompt_flag, response_flag)
757
+ if f not in debug_mode
760
758
  ]
761
759
 
762
760
  for expected_include in expected_included:
@@ -813,13 +811,13 @@ class LanguageModelTest(unittest.TestCase):
813
811
  debug_info = string_io.getvalue()
814
812
  expected_included = [
815
813
  debug_prints[f]
816
- for f in lm_lib.LMDebugMode
817
- if f != lm_lib.LMDebugMode.NONE and f in debug_mode
814
+ for f in (info_flag, prompt_flag, response_flag)
815
+ if f in debug_mode
818
816
  ]
819
817
  expected_excluded = [
820
818
  debug_prints[f]
821
- for f in lm_lib.LMDebugMode
822
- if f != lm_lib.LMDebugMode.NONE and f not in debug_mode
819
+ for f in (info_flag, prompt_flag, response_flag)
820
+ if f not in debug_mode
823
821
  ]
824
822
 
825
823
  for expected_include in expected_included:
@@ -78,7 +78,7 @@ class RandomChoice(lf.LanguageModel):
78
78
  )
79
79
 
80
80
  def __call__(
81
- self, prompt: lf.Message, *, cache_seed: int = 0, **kwargs
81
+ self, prompt: str | lf.Message, *, cache_seed: int = 0, **kwargs
82
82
  ) -> lf.Message:
83
83
  return self._select_lm()(prompt, cache_seed=cache_seed, **kwargs)
84
84
 
langfun/core/message.py CHANGED
@@ -225,13 +225,11 @@ class Message(
225
225
  return MessageConverter.get(format_or_type, **kwargs).to_value(self)
226
226
 
227
227
  @classmethod
228
- @property
229
228
  def convertible_formats(cls) -> list[str]:
230
229
  """Returns supported format for message conversion."""
231
230
  return MessageConverter.convertible_formats()
232
231
 
233
232
  @classmethod
234
- @property
235
233
  def convertible_types(cls) -> list[str]:
236
234
  """Returns supported types for message conversion."""
237
235
  return MessageConverter.convertible_types()
@@ -938,8 +936,8 @@ class MessageConverter(pg.Object):
938
936
  """Converts a Langfun message to other formats."""
939
937
 
940
938
  @abc.abstractmethod
941
- def from_value(self, value: Message) -> Message:
942
- """Returns a MessageConverter from a Langfun message."""
939
+ def from_value(self, value: Any) -> Message:
940
+ """Returns a Langfun message from other formats."""
943
941
 
944
942
  @classmethod
945
943
  def _safe_read(
@@ -521,12 +521,12 @@ class MessageConverterTest(unittest.TestCase):
521
521
  def from_value(self, value: tuple[int, ...]) -> message.Message:
522
522
  return message.UserMessage(','.join(str(x) for x in value))
523
523
 
524
- self.assertIn('test_format1', message.Message.convertible_formats)
525
- self.assertIn('test_format2', message.Message.convertible_formats)
526
- self.assertIn('test_format3', message.Message.convertible_formats)
524
+ self.assertIn('test_format1', message.Message.convertible_formats())
525
+ self.assertIn('test_format2', message.Message.convertible_formats())
526
+ self.assertIn('test_format3', message.Message.convertible_formats())
527
527
 
528
- self.assertIn(int, message.Message.convertible_types)
529
- self.assertIn(tuple, message.Message.convertible_types)
528
+ self.assertIn(int, message.Message.convertible_types())
529
+ self.assertIn(tuple, message.Message.convertible_types())
530
530
  self.assertEqual(
531
531
  message.Message.from_value(1, format='test_format1'),
532
532
  message.UserMessage('1')
@@ -433,8 +433,18 @@ class Mapping(lf.LangFunc):
433
433
  schema = self.mapping_request.schema
434
434
  if schema is None:
435
435
  return None
436
+ response_text = lm_output.text
437
+
438
+ # For Gemini, we might have tool calls in the metadata, use tool call codes
439
+ # to construct the response text if it's present.
440
+ # NOTE(daiyip): This logic is subject to change.
441
+ if 'tool_calls' in lm_output.metadata:
442
+ assert lm_output.metadata['tool_calls'], lm_output.metadata
443
+ response_text = '\n'.join(
444
+ tc.text for tc in lm_output.metadata['tool_calls']
445
+ )
436
446
  return schema.parse(
437
- lm_output.text,
447
+ response_text,
438
448
  protocol=self.protocol,
439
449
  additional_context=self.globals(),
440
450
  autofix=self.autofix,
@@ -114,6 +114,21 @@ class QueryTest(unittest.TestCase):
114
114
  ),
115
115
  'The answer is one.',
116
116
  )
117
+ # Testing tool calls in the response.
118
+ self.assertEqual(
119
+ querying.query(
120
+ 'abc',
121
+ Activity,
122
+ lm=fake.StaticResponse(
123
+ lf.AIMessage(
124
+ 'Here is the answer.',
125
+ tool_calls=[lf.AIMessage('Activity(description="hello")')],
126
+ ),
127
+ ),
128
+ ),
129
+ Activity(description='hello'),
130
+ )
131
+ # Test completing a partial object.
117
132
  self.assertEqual(
118
133
  querying.query(
119
134
  Activity.partial(),
@@ -17,6 +17,7 @@ import abc
17
17
  import inspect
18
18
  import io
19
19
  import re
20
+ import sys
20
21
  import textwrap
21
22
  import typing
22
23
  from typing import Any, Literal, Sequence, Type, Union
@@ -451,6 +452,15 @@ def class_definition(
451
452
  out.write(f' """{cls.__doc__}"""\n')
452
453
  else:
453
454
  out.write(' """')
455
+
456
+ # Since Python 3.13, the indentation of docstring lines is removed.
457
+ # Therefore, we add two spaces to each non-empty line to keep the
458
+ # indentation consistent with the class definition.
459
+ if sys.version_info >= (3, 13):
460
+ for i in range(1, len(doc_lines)):
461
+ if doc_lines[i]:
462
+ doc_lines[i] = ' ' * 2 + doc_lines[i]
463
+
454
464
  for line in doc_lines:
455
465
  out.write(line)
456
466
  out.write('\n')
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langfun
3
- Version: 0.1.2.dev202507140805
3
+ Version: 0.1.2.dev202507150805
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -3,18 +3,18 @@ langfun/core/__init__.py,sha256=pW4prpiyWNkRbtWBGYF1thn7_0F_TgDVfAIZPvGn6HA,4758
3
3
  langfun/core/component.py,sha256=g1kQM0bryYYYWVDrSMnHfc74wIBbpfe5_B3s-UIP5GE,3028
4
4
  langfun/core/component_test.py,sha256=0CxTgjAud3aj8wBauFhG2FHDqrxCTl4OI4gzQTad-40,9254
5
5
  langfun/core/concurrent.py,sha256=zY-pXqlGqss_GI20tM1gXvyW8QepVPUuFNmutcIdhbI,32760
6
- langfun/core/concurrent_test.py,sha256=roMFze0EKuyPbmG6DZzz8K8VGsZwWzc0F1uJZTFROC4,17572
6
+ langfun/core/concurrent_test.py,sha256=fjVcxD_OSH9fBqBEpDpuIVfcfoKZWDtwmkoM2ZMHqy8,17628
7
7
  langfun/core/console.py,sha256=cLQEf84aDxItA9fStJV22xJch0TqFLNf9hLqwJ0RHmU,2652
8
8
  langfun/core/console_test.py,sha256=pBOcuNMJdVELywvroptfcRtJMsegMm3wSlHAL2TdxVk,1679
9
9
  langfun/core/langfunc.py,sha256=G50YgoVZ0y1GFw2ev41MlOqr6qa8YakbvNC0h_E0PiA,11140
10
10
  langfun/core/langfunc_test.py,sha256=CDn-gJCa5EnjN7cotAVCfSCbuzddq2o-HzEt7kV8HbY,8882
11
- langfun/core/language_model.py,sha256=fJeYDz_TD1feiUvSysXXeo2bV-cq5T34HWeYgsTICP4,49680
12
- langfun/core/language_model_test.py,sha256=VyiHwrUtJGkoLyzsjhVqawijwtwoRqsYOvQD57n8Iv8,37413
11
+ langfun/core/language_model.py,sha256=5i0Je5526JO2YY6qExi6Yf7VQVgSVeZIKOjt3I8kxqQ,49573
12
+ langfun/core/language_model_test.py,sha256=9EofP3_gTH28SNWiOKzTUYMHH0EYtbi9xGuT1KZT1XU,37330
13
13
  langfun/core/logging.py,sha256=7IGAhp7mGokZxxqtL-XZvFLKaZ5k3F5_Xp2NUtR4GwE,9136
14
14
  langfun/core/logging_test.py,sha256=vbVGOQxwMmVSiFfbt2897gUt-8nqDpV64jCAeUG_q5U,6924
15
15
  langfun/core/memory.py,sha256=vyXVvfvSdLLJAzdIupnbn3k26OgclCx-OJ7gddS5e1Y,2070
16
- langfun/core/message.py,sha256=qUJZ9NfrlYG3aU_K2ud236gdTnvZ7Qw2T4pv4hI9ivg,32817
17
- langfun/core/message_test.py,sha256=XE7SaYJUK8TCJBdih4XtnpnB6NhcU-vLxJoaP67WbSU,40793
16
+ langfun/core/message.py,sha256=Nx9SqEIkPMS5I1RyMQFlWUjZCsdlGamv_wTze2-3R4M,32784
17
+ langfun/core/message_test.py,sha256=dAA_ZzI5MGyFfXyejxPrB90SbR066mkIgmRtdZ5ZbL4,40803
18
18
  langfun/core/modality.py,sha256=K8pUGuMpfWcOtVcXC_OqVjro1-RhHF6ddQni61DuYzM,4166
19
19
  langfun/core/modality_test.py,sha256=0WL_yd3B4K-FviWdSpDnOwj0f9TQI0v9t6X0vWvvJbo,2415
20
20
  langfun/core/natural_language.py,sha256=3ynSnaYQnjE60LIPK5fyMgdIjubnPYZwzGq4rWPeloE,1177
@@ -61,14 +61,14 @@ langfun/core/eval/scoring.py,sha256=_DvnlgI1SdRVaOojao_AkV3pnenfCPOqyhvlg-Sw-5M,
61
61
  langfun/core/eval/scoring_test.py,sha256=UcBH0R6vAovZ0A4yM22s5cBHL1qVKASubrbu1t8dYBw,4529
62
62
  langfun/core/eval/v2/__init__.py,sha256=9lNKJwbvl0lcFblAXYT_OHI8fOubJsTOdSkxEqsP1xU,1726
63
63
  langfun/core/eval/v2/checkpointing.py,sha256=t47rBfzGZYgIqWW1N1Ak9yQnNtHd-IRbEO0cZjG2VRo,11755
64
- langfun/core/eval/v2/checkpointing_test.py,sha256=NggOSJ_6XSa4cNP6nGIu9wLsK59dUwe8SPWDiXtGGDE,9197
64
+ langfun/core/eval/v2/checkpointing_test.py,sha256=cuQ1zom5DMXIebxYW6L3N5XRyhfoEEDrs7XQcAxg8Nc,9164
65
65
  langfun/core/eval/v2/eval_test_helper.py,sha256=sKFi_wPYCNmr96WyTduuXY0KnxjFxcJyEhXey-_nGX8,3962
66
66
  langfun/core/eval/v2/evaluation.py,sha256=ihT5dljnUkHM97XS9OwE2wOnYC-oYnHYgG5KN1hmiaU,27037
67
- langfun/core/eval/v2/evaluation_test.py,sha256=QNp_HEvRTupvNuLEeYTvylykh1Ut2jpMqHQ-gCUZQ10,6919
67
+ langfun/core/eval/v2/evaluation_test.py,sha256=46bGjNZmd57NXcJSoaC17DO9B74rpVBOVTEln_4W61c,6916
68
68
  langfun/core/eval/v2/example.py,sha256=v1dIz89pccIqujt7utrk0EbqMWM9kBn-2fYGRTKe358,10890
69
69
  langfun/core/eval/v2/example_test.py,sha256=wsHQD6te7ghROmxe3Xg_NK4TU0xS2MkNfnpo-H0H8xM,3399
70
70
  langfun/core/eval/v2/experiment.py,sha256=fb3RHNOSRftV7ZTBfYVV50iEevqdPwRHCt3mgtLzuFw,33408
71
- langfun/core/eval/v2/experiment_test.py,sha256=UmCobeS6ifPcaGkTJp0WPISolXrVFbeFCBiyJeA0Lt4,13666
71
+ langfun/core/eval/v2/experiment_test.py,sha256=BYrPYfQfU2jDfAlZcHDT0KUaXOnCnyTWyUEKhDoqXfw,13645
72
72
  langfun/core/eval/v2/metric_values.py,sha256=_B905bC-jxrYPLSEcP2M8MaHZOVMz_bVrUw8YC4arCE,4660
73
73
  langfun/core/eval/v2/metric_values_test.py,sha256=ab2oF_HsIwrSy459108ggyjgefHSPn8UVILR4dRwx14,2634
74
74
  langfun/core/eval/v2/metrics.py,sha256=bl8i6u-ZHRBz4hAc3LzsZ2Dc7ZRQcuTYeUhhH-GxfF0,10628
@@ -76,17 +76,17 @@ langfun/core/eval/v2/metrics_test.py,sha256=LibZXvWEJDVRY-Mza_bQT-SbmbXCHUnFhL7Z
76
76
  langfun/core/eval/v2/progress.py,sha256=azZgssQgNdv3IgjKEaQBuGI5ucFDNbdi02P4z_nQ8GE,10292
77
77
  langfun/core/eval/v2/progress_test.py,sha256=YU7VHzmy5knPZwj9vpBN3rQQH2tukj9eKHkuBCI62h8,2540
78
78
  langfun/core/eval/v2/progress_tracking.py,sha256=zNhNPGlnJnHELEfFpbTMCSXFn8d1IJ57OOYkfFaBFfM,6097
79
- langfun/core/eval/v2/progress_tracking_test.py,sha256=fouMVJkFJqHjbhQJngGLGCmA9x3n0dU4USI2dY163mg,2291
79
+ langfun/core/eval/v2/progress_tracking_test.py,sha256=sJhlVfinGsg3Kf2wQ_hT7VMcpQfaI4ZkqyW9ujElkwA,2282
80
80
  langfun/core/eval/v2/reporting.py,sha256=yUIPCAMnp7InIzpv1DDWrcLO-75iiOUTpscj7smkfrA,8335
81
- langfun/core/eval/v2/reporting_test.py,sha256=hcPJJaMtPulqERvHYTpId83WXdqDKnnexmULtK7WKwk,5686
81
+ langfun/core/eval/v2/reporting_test.py,sha256=CMK-vwho8cNRJwlbkCqm_v5fykE7Y3V6SaIOCY0CDyA,5671
82
82
  langfun/core/eval/v2/runners.py,sha256=iqbH4jMtnNMhfuv1eHaxJmk1Vvsrz-sAJJFP8U44-tA,16758
83
- langfun/core/eval/v2/runners_test.py,sha256=DO3xV0sBNB6n65j41xx2i7gqUCJcPF37DFZLEjrmISg,11987
83
+ langfun/core/eval/v2/runners_test.py,sha256=spjkmqlls_vyERdZMdjv6dhIN9ZfxsDDvIQAWTj2kMk,11954
84
84
  langfun/core/llms/__init__.py,sha256=CtxUdXohQ8AQk1DqBT6MBy2zdAoPSggNo00SYrj9-AY,9521
85
85
  langfun/core/llms/anthropic.py,sha256=YcQ2VG8iOfXtry_tTpAukmiwXa2hK_9LkpkmXk41Nm0,26226
86
86
  langfun/core/llms/anthropic_test.py,sha256=qA9vByp_cwwXNlXzcwHpPWFnO9lfFo8NKfDi5nBNqgI,9052
87
87
  langfun/core/llms/azure_openai.py,sha256=-KkSLaR54MlsIqz_XIwv0TnsBnvNTAxnjA2Q2O2u5KM,2733
88
88
  langfun/core/llms/azure_openai_test.py,sha256=lkMZkQdJBV97fTM4C4z8qNfvr6spgiN5G4hvVUIVr0M,1735
89
- langfun/core/llms/compositional.py,sha256=csW_FLlgL-tpeyCOTVvfUQkMa_zCN5Y2I-YbSNuK27U,2872
89
+ langfun/core/llms/compositional.py,sha256=W_Fe2BdbkjwTzWW-paCWcEeG9oOR3-IcBG8oc73taSM,2878
90
90
  langfun/core/llms/compositional_test.py,sha256=4eTnOer-DncRKGaIJW2ZQQMLnt5r2R0UIx_DYOvGAQo,2027
91
91
  langfun/core/llms/deepseek.py,sha256=jvTxdXPr-vH6HNakn_Ootx1heDg8Fen2FUkUW36bpCs,5247
92
92
  langfun/core/llms/deepseek_test.py,sha256=DvROWPlDuow5E1lfoSkhyGt_ELA19JoQoDsTnRgDtTg,1847
@@ -133,13 +133,13 @@ langfun/core/structured/description.py,sha256=6BztYOiucPkF4CrTQtPLPJo1gN2dwnKmaJ
133
133
  langfun/core/structured/description_test.py,sha256=UxaXnKKP7TnyPDPUyf3U-zPE0TvLlIP6DGr8thjcePw,7365
134
134
  langfun/core/structured/function_generation.py,sha256=g7AOR_e8HxFU6n6Df750aGkgMgV1KExLZMAz0yd5Agg,8555
135
135
  langfun/core/structured/function_generation_test.py,sha256=LaXYDXf9GlqUrR6v_gtmK_H4kxzonmU7SYbn7XXMgjU,12128
136
- langfun/core/structured/mapping.py,sha256=iraHpcEeF_iuEX2eoTsLGwTHHvxqp2gNDjoMf98l0Kk,13941
136
+ langfun/core/structured/mapping.py,sha256=1YBW8PKpJKXS7DKukfzKNioL84PrKUcB4KOUudrQ20w,14374
137
137
  langfun/core/structured/mapping_test.py,sha256=OntYvfDitAf0tAnzQty3YS90vyEn6FY1Mi93r_ViEk8,9594
138
138
  langfun/core/structured/parsing.py,sha256=MGvI7ypXlwfzr5XB8_TFU9Ei0_5reYqkWkv64eAy0EA,12015
139
139
  langfun/core/structured/parsing_test.py,sha256=V8Cj1tJK4Lxv_b0YQj6-2hzXZgnYNBa2JR7rOLRBKoQ,22346
140
140
  langfun/core/structured/querying.py,sha256=vE_NOLNlIe4A0DueQfyiBEUh3AsSD8Hhx2dSDHNYpYk,37976
141
- langfun/core/structured/querying_test.py,sha256=pYWXqAnzp5eOCjU4yEPtE73iLNqxHISb3y3FaSbI7vs,49760
142
- langfun/core/structured/schema.py,sha256=r_ewdRMsALVOdnvGSeYBcz2-VJ_3_nMxY4GtzUHCYUU,29111
141
+ langfun/core/structured/querying_test.py,sha256=Q0HwmbUI9BqMaeN8vgn_EvX29CzfcomGIKVqKJ6dZyY,50212
142
+ langfun/core/structured/schema.py,sha256=xtgrr3t5tcYQ2gi_fkTKz2IgDMf84gpiykmBdfnV6Io,29486
143
143
  langfun/core/structured/schema_generation.py,sha256=pEWeTd8tQWYnEHukas6GVl4uGerLsQ2aNybtnm4Qgxc,5352
144
144
  langfun/core/structured/schema_generation_test.py,sha256=RM9s71kMNg2jTePwInkiW9fK1ACN37eyPeF8OII-0zw,2950
145
145
  langfun/core/structured/schema_test.py,sha256=H42ZZdPi8CIv7WzrnXwMwQQaPQxlmDSY31pfqQs-Xqw,26567
@@ -156,8 +156,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
156
156
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
157
157
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
158
158
  langfun/core/templates/selfplay_test.py,sha256=Ot__1P1M8oJfoTp-M9-PQ6HUXqZKyMwvZ5f7yQ3yfyM,2326
159
- langfun-0.1.2.dev202507140805.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
160
- langfun-0.1.2.dev202507140805.dist-info/METADATA,sha256=PTIyICp15xRxYrwgOSZMqa95ZM-bB3aQ5vlOc8Ggw_8,8178
161
- langfun-0.1.2.dev202507140805.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
162
- langfun-0.1.2.dev202507140805.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
163
- langfun-0.1.2.dev202507140805.dist-info/RECORD,,
159
+ langfun-0.1.2.dev202507150805.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
160
+ langfun-0.1.2.dev202507150805.dist-info/METADATA,sha256=sO0bfLTiiYrI_s4AbIEF_5v0li1JRZ1IjH492KCQrGU,8178
161
+ langfun-0.1.2.dev202507150805.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
162
+ langfun-0.1.2.dev202507150805.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
163
+ langfun-0.1.2.dev202507150805.dist-info/RECORD,,