langfun 0.1.2.dev202412180804__py3-none-any.whl → 0.1.2.dev202412230804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,11 +13,14 @@
 # limitations under the License.
 """Reporting evaluation results."""
 
+import threading
 import time
+import traceback
 from typing import Annotated
 
 from langfun.core.eval.v2 import example as example_lib
 from langfun.core.eval.v2 import experiment as experiment_lib
+import pyglove as pg
 
 Runner = experiment_lib.Runner
 Experiment = experiment_lib.Experiment
@@ -39,12 +42,15 @@ class HtmlReporter(experiment_lib.Plugin):
   experiment_report_interval: Annotated[
       int,
       'The interval of writing report for inidividual experiments in seconds.'
-  ] = 60
+  ] = 120
 
   def _on_bound(self):
     super()._on_bound()
     self._last_summary_time = 0
     self._last_experiment_report_time = {}
+    self._update_thread = None
+    self._stop_update = False
+    self._stop_update_experiment_ids = set()
 
   def on_run_start(
       self,
@@ -53,14 +59,41 @@ class HtmlReporter(experiment_lib.Plugin):
   ) -> None:
     self._maybe_update_summary(runner)
     self._last_experiment_report_time = {leaf.id: 0 for leaf in root.leaf_nodes}
+    self._stop_update = False
+    self._stop_update_experiment_ids = set()
+    self._update_thread = threading.Thread(
+        target=self._update_thread_func, args=(runner,)
+    )
+    self._update_thread.start()
 
   def on_run_complete(
       self,
       runner: Runner,
       root: Experiment
   ) -> None:
+    self._stop_update = True
     self._maybe_update_summary(runner, force=True)
 
+  def on_run_abort(
+      self,
+      runner: Runner,
+      root: Experiment,
+      error: BaseException
+  ) -> None:
+    self._stop_update = True
+    self._maybe_update_summary(runner, force=True)
+
+  def _update_thread_func(self, runner: Runner):
+    while not self._stop_update:
+      self._maybe_update_summary(runner, background=False)
+      for leaf in runner.current_run.experiment.leaf_nodes:
+        if leaf.id in self._stop_update_experiment_ids:
+          continue
+        self._maybe_update_experiment_html(runner, leaf, background=False)
+        if leaf.progress.is_stopped:
+          self._stop_update_experiment_ids.add(leaf.id)
+      time.sleep(5)
+
   def on_experiment_start(
       self,
       runner: Runner,
@@ -75,6 +108,16 @@ class HtmlReporter(experiment_lib.Plugin):
     if experiment.is_leaf:
       self._maybe_update_experiment_html(runner, experiment, force=True)
 
+  def on_experiment_abort(
+      self,
+      runner: Runner,
+      experiment: Experiment,
+      error: BaseException
+  ) -> None:
+    del error
+    assert experiment.is_leaf
+    self._maybe_update_experiment_html(runner, experiment, force=True)
+
   def on_example_complete(
       self, runner: Runner, experiment: Experiment, example: Example
   ):
@@ -82,7 +125,11 @@ class HtmlReporter(experiment_lib.Plugin):
     self._maybe_update_experiment_html(runner, experiment)
     self._maybe_update_summary(runner)
 
-  def _maybe_update_summary(self, runner: Runner, force: bool = False) -> None:
+  def _maybe_update_summary(
+      self,
+      runner: Runner,
+      background: bool = True,
+      force: bool = False) -> None:
     """Maybe update the summary of current run."""
     run = runner.current_run
     def _summary():
@@ -96,31 +143,52 @@ class HtmlReporter(experiment_lib.Plugin):
     )
 
     if force or (time.time() - self._last_summary_time > self.summary_interval):
-      runner.background_run(_summary)
+      if background:
+        runner.background_run(_summary)
+      else:
+        _summary()
       self._last_summary_time = time.time()
 
   def _maybe_update_experiment_html(
-      self, runner: Runner, experiment: Experiment, force: bool = False
+      self,
+      runner: Runner,
+      experiment: Experiment,
+      force: bool = False,
+      background: bool = True,
   ) -> None:
     def _save():
-      html = experiment.to_html(
-          collapse_level=None,
-          extra_flags=dict(
-              current_run=runner.current_run,
-              interactive=False,
-              card_view=False,
-          ),
+      index_html_path = runner.current_run.output_path_for(
+          experiment, _EVALULATION_DETAIL_FILE
       )
-      html.save(
-          runner.current_run.output_path_for(
-              experiment, _EVALULATION_DETAIL_FILE
+      try:
+        with pg.timeit() as t:
+          html = experiment.to_html(
+              collapse_level=None,
+              extra_flags=dict(
+                  current_run=runner.current_run,
+                  interactive=False,
+                  card_view=False,
+              ),
           )
-      )
+          html.save(index_html_path)
+          experiment.info(
+              f'Generated HTML {index_html_path!r} in {t.elapse:.2f} seconds.',
+          )
+      except BaseException as e:  # pylint: disable=broad-except
+        experiment.error(
+            f'Failed to save HTML {index_html_path!r}. '
+            f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.',
+        )
+        raise e
+
     if force or (
         time.time() - self._last_experiment_report_time[experiment.id]
         > self.experiment_report_interval
     ):
-      runner.background_run(_save)
+      if background:
+        runner.background_run(_save)
+      else:
+        _save()
       self._last_experiment_report_time[experiment.id] = time.time()
 
   def _save_example_html(
@@ -128,17 +196,24 @@
   ) -> None:
     """Saves the example."""
     def _save():
-      html = example.to_html(
-          collapse_level=None,
-          enable_summary_tooltip=False,
-          extra_flags=dict(
-              # For properly rendering the next link.
-              num_examples=getattr(experiment, 'num_examples', None)
-          ),
-      )
-      html.save(
-          runner.current_run.output_path_for(
-              experiment, f'{example.id}.html'
-          )
-      )
+      try:
+        html = example.to_html(
+            collapse_level=None,
+            enable_summary_tooltip=False,
+            extra_flags=dict(
+                # For properly rendering the next link.
+                num_examples=getattr(experiment, 'num_examples', None)
+            ),
+        )
+        html.save(
+            runner.current_run.output_path_for(
+                experiment, f'{example.id}.html'
+            )
+        )
+      except BaseException as e:  # pylint: disable=broad-except
+        experiment.error(
+            f'Failed to save HTML {example.id}.html. '
+            f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.',
+        )
+        raise e
     runner.background_run(_save)
@@ -15,9 +15,9 @@ import os
 import tempfile
 import unittest
 
+from langfun.core.eval.v2 import eval_test_helper
 from langfun.core.eval.v2 import reporting
 from langfun.core.eval.v2 import runners as runners_lib  # pylint: disable=unused-import
-from langfun.core.eval.v2 import test_helper
 import pyglove as pg
 
 
@@ -25,7 +25,7 @@ class ReportingTest(unittest.TestCase):
 
   def test_reporting(self):
     root_dir = os.path.join(tempfile.gettempdir(), 'test_reporting')
-    experiment = test_helper.test_experiment()
+    experiment = eval_test_helper.test_experiment()
     reporter = reporting.HtmlReporter()
     run = experiment.run(root_dir, 'new', plugins=[reporter])
     pg.io.path_exists(run.output_path_for(experiment, 'summary.html'))
@@ -18,6 +18,7 @@ import concurrent.futures
 import random
 import threading
 import time
+import traceback
 from typing import Any, Annotated, Callable, Iterator
 
 from langfun import core as lf
@@ -64,6 +65,7 @@ class RunnerBase(Runner):
     with pg.notify_on_change(False):
       self.plugins.append(progress_tracking.progress_tracker(self.tqdm))
 
+    self._io_pool_lock = threading.Lock()
     self._io_pool = concurrent.futures.ThreadPoolExecutor(max_workers=16)
     # TODO(daiyip): render background errors.
     self._background_last_error = None
@@ -75,7 +77,10 @@
         func(*args, **kwargs)
       except Exception as e:  # pylint: disable=broad-except
         self._background_last_error = e
-    self._io_pool.submit(_background_run, *args, **kwargs)
+
+    with self._io_pool_lock:
+      if self._io_pool is not None:
+        self._io_pool.submit(_background_run, *args, **kwargs)
 
   def _all_plugins(self, experiment: Experiment) -> Iterator[Plugin]:
     """Returns all plugins for the experiment."""
@@ -120,9 +125,14 @@
     # Start the progress of the evaluation.
     if experiment.is_leaf:
       assert isinstance(experiment, Evaluation)
-      experiment.progress.start(
-          total=(len(self.current_run.example_ids)
-                 if self.current_run.example_ids else experiment.num_examples)
+      num_examples_to_evaluate = (
+          len(self.current_run.example_ids)
+          if self.current_run.example_ids else experiment.num_examples
+      )
+      experiment.progress.start(total=num_examples_to_evaluate)
+      experiment.info(
+          'Starting evaluation %s with %d examples to evaluate.'
+          % (experiment.id, num_examples_to_evaluate)
       )
     else:
       experiment.progress.start(total=len(experiment.leaf_nodes))
@@ -144,8 +154,7 @@
 
     # Only leaf evaluations will trigger the complete notification of the
     # ancestors.
-    if experiment.is_leaf:
-      self._update_ancestor_progresses(experiment)
+    self._update_ancestor_progresses(experiment)
 
   def on_experiment_complete(self, experiment: Experiment) -> None:
     """Called when an evaluation is complete."""
@@ -160,6 +169,35 @@
     # ancestors.
     if experiment.is_leaf:
       self._update_ancestor_progresses(experiment)
+      self._log_experiment_completion(experiment)
+
+  def _log_experiment_completion(self, experiment: Experiment):
+    example_ids = (
+        self.current_run.example_ids if self.current_run.example_ids else
+        list(range(1, experiment.num_examples + 1))
+    )
+    num_from_checkpoint, num_processed = 0, 0
+    for example_id in example_ids:
+      example = experiment.state.get(example_id)
+      if example.newly_processed:
+        num_processed += 1
+      else:
+        num_from_checkpoint += 1
+    experiment.info(
+        f'{experiment.id} completed with {num_from_checkpoint + num_processed} '
+        f'examples evaluated ({num_from_checkpoint} from checkpoint, '
+        f'{num_processed} newly processed).'
+    )
+
+  def on_experiment_abort(
+      self, experiment: Experiment, error: BaseException) -> None:
+    """Called when an evaluation is complete."""
+    assert experiment.is_leaf
+    experiment.fatal(f'{error}\n\n{traceback.format_exc()}')
+
+    # Notify the plugins of the experiment abort.
+    for plugin in self._all_plugins(experiment):
+      plugin.on_experiment_abort(self, experiment, error)
 
   def _update_ancestor_progresses(self, experiment: Experiment):
     """Updates the progresses of the parent nodes of the experiment."""
@@ -262,7 +300,9 @@
       self.background_run(cache.save)
 
     # Wait for the background tasks to finish.
-    self._io_pool.shutdown(wait=True)
+    with self._io_pool_lock:
+      self._io_pool, io_pool = None, self._io_pool
+    io_pool.shutdown(wait=True)
 
   @abc.abstractmethod
   def _run(self, evaluations: list[Evaluation]) -> None:
@@ -270,31 +310,36 @@
 
   def run_evaluation(self, evaluation: Evaluation) -> None:
     """Runs the evaluation."""
-    self.on_experiment_start(evaluation)
-
-    per_evaluation_settings = {}
-    cache = None
-    if self.current_run.use_cache == 'per_dataset':
-      cache = self._load_or_create_cache(evaluation)
-      per_evaluation_settings['cache'] = cache
-
-    with lf.use_settings(**per_evaluation_settings):
-      if self.current_run.example_ids is None:
-        items = (
-            Example(id=i + 1, input=ex) for i, ex in enumerate(
-                evaluation.example_inputs)
-        )
-      else:
-        items = (
-            Example(
-                id=example_id, input=evaluation.example_input_by_id(example_id)
-            ) for example_id in self.current_run.example_ids
-        )
-      self._evaluate_items(evaluation, items)
-
-    if cache:
-      self.background_run(cache.save)
-    self.on_experiment_complete(evaluation)
+    try:
+      self.on_experiment_start(evaluation)
+
+      per_evaluation_settings = {}
+      cache = None
+      if self.current_run.use_cache == 'per_dataset':
+        cache = self._load_or_create_cache(evaluation)
+        per_evaluation_settings['cache'] = cache
+
+      with lf.use_settings(**per_evaluation_settings):
+        if self.current_run.example_ids is None:
+          items = (
+              Example(id=i + 1, input=ex) for i, ex in enumerate(
+                  evaluation.example_inputs)
+          )
+        else:
+          items = (
+              Example(
+                  id=example_id,
+                  input=evaluation.example_input_by_id(example_id)
+              ) for example_id in self.current_run.example_ids
+          )
+        self._evaluate_items(evaluation, items)
+
+      if cache:
+        self.background_run(cache.save)
+      self.on_experiment_complete(evaluation)
+    except BaseException as e:  # pylint: disable=broad-except
+      self.on_experiment_abort(evaluation, e)
+      raise e
 
   @abc.abstractmethod
   def _evaluate_items(
@@ -410,9 +455,7 @@ class ParallelRunner(RunnerBase):
         groups.values(),
         max_workers=max(64, len(groups)),
         timeout=self.timeout,
-        silence_on_errors=(
-            None if self.current_run.raise_if_has_error else BaseException
-        )
+        silence_on_errors=None,
     ):
       pass
 
@@ -437,8 +480,6 @@
         items,
         max_workers=evaluation.max_workers,
         timeout=self.timeout,
-        silence_on_errors=(
-            None if self.current_run.raise_if_has_error else BaseException
-        )
+        silence_on_errors=None,
     ):
       pass
@@ -18,10 +18,11 @@ import time
 from typing import Any
 import unittest
 
+from langfun.core.eval.v2 import eval_test_helper
 from langfun.core.eval.v2 import example as example_lib
 from langfun.core.eval.v2 import experiment as experiment_lib
 from langfun.core.eval.v2 import runners as runners_lib  # pylint: disable=unused-import
-from langfun.core.eval.v2 import test_helper
+
 import pyglove as pg
 
 
@@ -101,7 +102,7 @@ class RunnerTest(unittest.TestCase):
 
   def test_basic(self):
     plugin = TestPlugin()
-    exp = test_helper.test_experiment()
+    exp = eval_test_helper.test_experiment()
     root_dir = os.path.join(tempfile.gettempdir(), 'test_sequential_runner')
     run = exp.run(root_dir, runner='sequential', plugins=[plugin])
 
@@ -143,7 +144,7 @@
 
   def test_raise_if_has_error(self):
     root_dir = os.path.join(tempfile.gettempdir(), 'test_raise_if_has_error')
-    exp = test_helper.TestEvaluation()
+    exp = eval_test_helper.TestEvaluation()
     with self.assertRaisesRegex(ValueError, 'x should not be 5'):
       exp.run(
           root_dir, runner='sequential', plugins=[], raise_if_has_error=True
@@ -154,7 +155,7 @@
 
   def test_example_ids(self):
     root_dir = os.path.join(tempfile.gettempdir(), 'test_example_ids')
-    exp = test_helper.test_experiment()
+    exp = eval_test_helper.test_experiment()
     plugin = TestPlugin()
     _ = exp.run(
         root_dir, runner='sequential', plugins=[plugin], example_ids=[5, 7, 9]
@@ -164,7 +165,7 @@
 
   def test_filter(self):
     plugin = TestPlugin()
-    exp = test_helper.test_experiment()
+    exp = eval_test_helper.test_experiment()
     root_dir = os.path.join(tempfile.gettempdir(), 'test_filter')
 
     _ = exp.run(
@@ -193,7 +194,7 @@
           ) for i in range(num_examples)
       ]
 
-    exp = test_helper.TestEvaluation(
+    exp = eval_test_helper.TestEvaluation(
         inputs=test_inputs(num_examples=pg.oneof([2, 4]))
     )
     # Global cache.
@@ -234,7 +235,7 @@ class ParallelRunnerTest(RunnerTest):
 
   def test_parallel_runner(self):
     plugin = TestPlugin()
-    exp = test_helper.test_experiment()
+    exp = eval_test_helper.test_experiment()
     root_dir = os.path.join(tempfile.gettempdir(), 'test_parallel_runner')
     run = exp.run(root_dir, runner='parallel', plugins=[plugin])
 
@@ -274,7 +275,7 @@
 
   def test_concurrent_startup_delay(self):
     plugin = TestPlugin()
-    exp = test_helper.test_experiment()
+    exp = eval_test_helper.test_experiment()
     root_dir = os.path.join(
         tempfile.gettempdir(), 'test_concurrent_startup_delay'
     )
@@ -290,7 +291,7 @@ class DebugRunnerTest(RunnerTest):
 
   def test_debug_runner(self):
     plugin = TestPlugin()
-    exp = test_helper.test_experiment()
+    exp = eval_test_helper.test_experiment()
    root_dir = os.path.join(tempfile.gettempdir(), 'test_debug_runner')
     run = exp.run(root_dir, runner='debug', plugins=[plugin])
 
langfun/core/logging.py CHANGED
@@ -54,6 +54,25 @@ class LogEntry(pg.Object, pg.views.HtmlTreeView.Extension):
   def should_output(self, min_log_level: LogLevel) -> bool:
     return _LOG_LEVELS.index(self.level) >= _LOG_LEVELS.index(min_log_level)
 
+  def format(self,
+             compact: bool = False,
+             verbose: bool = True,
+             root_indent: int = 0,
+             *,
+             text_format: bool = True,
+             **kwargs):
+    if text_format:
+      s = f"""{self.time.strftime('%H:%M:%S')} {self.level.upper()} - {self.message}"""
+      if self.metadata:
+        s += f' (metadata: {self.metadata!r})'
+      return s
+    return super().format(
+        compact=compact,
+        verbose=verbose,
+        root_indent=root_indent,
+        **kwargs
+    )
+
   def _html_tree_view_summary(
       self,
       view: pg.views.HtmlTreeView,
@@ -61,6 +61,25 @@ class LoggingTest(unittest.TestCase):
     print(actual)
     self.assertEqual(actual, expected)
 
+  def test_format(self):
+    time = datetime.datetime(2024, 10, 10, 12, 30, 45)
+    self.assertEqual(
+        str(
+            logging.LogEntry(
+                level='info', message='hello\nworld',
+                time=time, metadata=dict(x=1),
+            )
+        ),
+        '12:30:45 INFO - hello\nworld (metadata: {x=1})',
+    )
+    self.assertIn(
+        'LogEntry(',
+        logging.LogEntry(
+            level='info', message='hello\nworld',
+            time=time, metadata=dict(x=1),
+        ).format(text_format=False),
+    )
+
   def test_html(self):
     time = datetime.datetime(2024, 10, 10, 12, 30, 45)
     self.assert_html_content(
@@ -270,24 +270,31 @@ def call(
   if schema in (str, None):
     return lm_output if returns_message else lm_output.text
 
+  def _chain_nl_output_message(parsing_message: lf.Message):
+    """Chain the source of the parsed output to the LM output."""
+    parsing_message.root.source = lm_output
+    parsing_message.tag('parsing-lm-output')
+    parsing_message.lm_input.tag('parsing-lm-input')
+
   # Call `parsing_lm` for structured parsing.
-  parsing_message = querying.query(
-      lm_output.text,
-      schema,
-      examples=parsing_examples,
-      lm=parsing_lm or lm,
-      include_context=parsing_include_context,
-      cache_seed=cache_seed,
-      autofix=autofix,
-      autofix_lm=autofix_lm or lm,
-      protocol=protocol,
-      returns_message=True,
-      **kwargs,
-  )
-  # Chain the source of the parsed output to the LM output.
-  parsing_message.root.source = lm_output
-  parsing_message.tag('parsing-lm-output')
-  parsing_message.lm_input.tag('parsing-lm-input')
+  try:
+    parsing_message = querying.query(
+        lm_output.text,
+        schema,
+        examples=parsing_examples,
+        lm=parsing_lm or lm,
+        include_context=parsing_include_context,
+        cache_seed=cache_seed,
+        autofix=autofix,
+        autofix_lm=autofix_lm or lm,
+        protocol=protocol,
+        returns_message=True,
+        **kwargs,
+    )
+    _chain_nl_output_message(parsing_message)
+  except mapping.MappingError as e:
+    _chain_nl_output_message(e.lm_response)
+    raise e
   return parsing_message if returns_message else parsing_message.result
 
 
@@ -686,6 +686,31 @@ class CallTest(unittest.TestCase):
         ],
         returns_message=True,
     )
+    self.assertIn('parsing-lm-output', output.tags)
+    self.assertIn('parsing-lm-input', output.source.tags)
+    self.assertEqual(output.root.text, 'Compute 1 + 2')
+
+  def test_call_with_parsing_message_chaining_on_parsing_error(self):
+    try:
+      output = parsing.call(
+          'Compute 1 + 2',
+          int,
+          lm=fake.StaticSequence(['three']),
+          parsing_lm=fake.StaticSequence(['abc']),
+          parsing_examples=[
+              mapping.MappingExample(
+                  context='Multiple four and five',
+                  input='twenty',
+                  schema=int,
+                  output=20,
+              )
+          ],
+          returns_message=True,
+      )
+    except mapping.MappingError as e:
+      output = e.lm_response
+    self.assertIn('parsing-lm-output', output.tags)
+    self.assertIn('parsing-lm-input', output.source.tags)
     self.assertEqual(output.root.text, 'Compute 1 + 2')
 
   def test_call_with_autofix(self):
@@ -583,7 +583,16 @@ class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
 
   @functools.cached_property
   def output(self) -> Any:
-    return query_output(self.lm_response, self.schema)
+    """The output of `lf.query`. If it failed, returns the `MappingError`."""
+    try:
+      return query_output(self.lm_response, self.schema)
+    except mapping.MappingError as e:
+      return e
+
+  @property
+  def has_error(self) -> bool:
+    """Returns True if the query failed to generate a valid output."""
+    return isinstance(self.output, BaseException)
 
   @property
   def elapse(self) -> float:
@@ -1051,6 +1051,16 @@ class QueryStructureJsonTest(unittest.TestCase):
 
 class QueryInvocationTest(unittest.TestCase):
 
+  def test_basics(self):
+    lm = fake.StaticSequence([
+        'Activity(description="hi"',
+    ])
+    with querying.track_queries() as queries:
+      querying.query('foo', Activity, default=None, lm=lm)
+
+    self.assertTrue(queries[0].has_error)
+    self.assertIsInstance(queries[0].output, mapping.MappingError)
+
   def test_to_html(self):
     lm = fake.StaticSequence([
         'Activity(description="hi")',
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: langfun
-Version: 0.1.2.dev202412180804
+Version: 0.1.2.dev202412230804
 Summary: Langfun: Language as Functions.
 Home-page: https://github.com/google/langfun
 Author: Langfun Authors