langfun 0.1.2.dev202511160804__py3-none-any.whl → 0.1.2.dev202511270805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langfun might be problematic. More details were available via a link on the original page.

Files changed (41):
  1. langfun/core/__init__.py +1 -0
  2. langfun/core/agentic/__init__.py +4 -1
  3. langfun/core/agentic/action.py +340 -17
  4. langfun/core/agentic/action_test.py +124 -21
  5. langfun/core/eval/base_test.py +5 -5
  6. langfun/core/eval/v2/checkpointing.py +25 -1
  7. langfun/core/eval/v2/checkpointing_test.py +8 -1
  8. langfun/core/eval/v2/eval_test_helper.py +7 -2
  9. langfun/core/eval/v2/evaluation.py +4 -1
  10. langfun/core/eval/v2/example.py +5 -1
  11. langfun/core/eval/v2/example_test.py +13 -5
  12. langfun/core/eval/v2/experiment.py +23 -0
  13. langfun/core/eval/v2/experiment_test.py +19 -0
  14. langfun/core/eval/v2/progress_tracking.py +12 -3
  15. langfun/core/eval/v2/progress_tracking_test.py +3 -1
  16. langfun/core/eval/v2/reporting_test.py +4 -0
  17. langfun/core/eval/v2/runners/__init__.py +4 -0
  18. langfun/core/eval/v2/runners/base.py +40 -21
  19. langfun/core/eval/v2/runners/beam.py +341 -0
  20. langfun/core/eval/v2/runners/beam_test.py +131 -0
  21. langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
  22. langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
  23. langfun/core/eval/v2/runners/debug_test.py +1 -4
  24. langfun/core/eval/v2/runners/parallel_test.py +1 -4
  25. langfun/core/eval/v2/runners/sequential_test.py +1 -4
  26. langfun/core/langfunc_test.py +3 -3
  27. langfun/core/language_model.py +38 -5
  28. langfun/core/language_model_test.py +45 -0
  29. langfun/core/llms/__init__.py +2 -0
  30. langfun/core/llms/gemini.py +41 -8
  31. langfun/core/llms/gemini_test.py +84 -0
  32. langfun/core/llms/google_genai.py +5 -0
  33. langfun/core/llms/vertexai.py +7 -0
  34. langfun/core/modalities/mime.py +2 -0
  35. langfun/core/modalities/mime_test.py +11 -0
  36. langfun/core/structured/schema/__init__.py +1 -0
  37. {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/METADATA +1 -1
  38. {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/RECORD +41 -37
  39. {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/WHEEL +0 -0
  40. {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/licenses/LICENSE +0 -0
  41. {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/top_level.txt +0 -0
@@ -52,6 +52,7 @@ class Foo(action_lib.Action):
52
52
  with session.track_phase('prepare'):
53
53
  session.info('Begin Foo', x=1)
54
54
  time.sleep(self.simulate_execution_time[0])
55
+ Bar()(session, lm=lm)
55
56
  session.query(
56
57
  'foo',
57
58
  schema=int if self.simulate_query_error else None,
@@ -65,6 +66,7 @@ class Foo(action_lib.Action):
65
66
  def _sub_task(i):
66
67
  session.add_metadata(**{f'subtask_{i}': i})
67
68
  time.sleep(self.simulate_execution_time[2])
69
+ Bar()(session, lm=lm)
68
70
  return lf_structured.query(f'subtask_{i}', lm=lm)
69
71
 
70
72
  self._state = []
@@ -88,6 +90,50 @@ class Foo(action_lib.Action):
88
90
  lf_structured.query('additional query', lm=lm)
89
91
 
90
92
 
93
+ class ExecutionUnitPositionTest(unittest.TestCase):
94
+
95
+ def test_basics(self):
96
+ pos1 = action_lib.ExecutionUnit.Position(None, 0)
97
+ self.assertEqual(repr(pos1), 'Position(0)')
98
+ self.assertEqual(str(pos1), '')
99
+ self.assertIsNone(pos1.parent)
100
+ self.assertEqual(pos1.index, 0)
101
+ self.assertEqual(pos1.indices(), (0,))
102
+ self.assertEqual(pos1, (0,))
103
+ self.assertEqual(pos1, '')
104
+ self.assertEqual(pos1, action_lib.ExecutionUnit.Position(None, 0))
105
+ self.assertNotEqual(pos1, 1)
106
+ self.assertNotEqual(pos1, (1,))
107
+ self.assertNotEqual(pos1, action_lib.ExecutionUnit.Position(None, 1))
108
+
109
+ pos2 = action_lib.ExecutionUnit.Position(pos1, 0)
110
+ self.assertEqual(repr(pos2), 'Position(0, 0)')
111
+ self.assertEqual(str(pos2), '1')
112
+ self.assertEqual(pos2, '1')
113
+ self.assertEqual(pos2.parent, pos1)
114
+ self.assertEqual(pos2.index, 0)
115
+ self.assertEqual(pos2.indices(), (0, 0))
116
+ self.assertNotEqual(pos1, pos2)
117
+ self.assertLess(pos1, pos2)
118
+ self.assertGreater(pos2, pos1)
119
+ self.assertEqual(
120
+ hash(pos2),
121
+ hash(
122
+ action_lib.ExecutionUnit.Position(
123
+ action_lib.ExecutionUnit.Position(None, 0), 0
124
+ )
125
+ )
126
+ )
127
+
128
+ pos3 = action_lib.ExecutionUnit.Position(pos2, 0)
129
+ self.assertEqual(str(pos3), '1.1')
130
+ self.assertEqual(pos3, '1.1')
131
+ self.assertEqual(pos3.parent, pos2)
132
+ self.assertEqual(pos3.index, 0)
133
+ self.assertEqual(pos3.indices(), (0, 0, 0))
134
+ self.assertEqual(pos3.to_str(separator='>'), '1>1')
135
+
136
+
91
137
  class ActionInvocationTest(unittest.TestCase):
92
138
 
93
139
  def test_basics(self):
@@ -108,9 +154,7 @@ class ExecutionTraceTest(unittest.TestCase):
108
154
  self.assertEqual(execution.id, '')
109
155
 
110
156
  root = action_lib.ActionInvocation(action=action_lib.RootAction())
111
- action_invocation = action_lib.ActionInvocation(
112
- action=Foo(1)
113
- )
157
+ action_invocation = action_lib.ActionInvocation(action=Foo(1))
114
158
  root.execution.append(action_invocation)
115
159
  self.assertEqual(action_invocation.execution.id, '/a1')
116
160
 
@@ -153,6 +197,7 @@ class SessionTest(unittest.TestCase):
153
197
 
154
198
  self.assertIsInstance(session.root.action, action_lib.RootAction)
155
199
  self.assertIs(session.current_action, session.root)
200
+ self.assertIs(session.metadata, session.root.metadata)
156
201
 
157
202
  #
158
203
  # Inspecting the root invocation.
@@ -175,20 +220,25 @@ class SessionTest(unittest.TestCase):
175
220
  )
176
221
 
177
222
  # The root space should have one action (foo), no queries, and no logs.
223
+ self.assertEqual(len(root.execution_units), 1)
178
224
  self.assertEqual(len(root.actions), 1)
179
225
  self.assertEqual(len(root.queries), 0)
180
226
  self.assertEqual(len(root.logs), 0)
181
- # 1 query from Bar, 2 from Foo and 3 from parallel executions.
182
- self.assertEqual(len(session.all_queries), 6)
183
- self.assertEqual(len(root.all_queries), 6)
184
- # 2 actions: Foo and Bar.
185
- self.assertEqual(len(session.all_actions), 2)
186
- self.assertEqual(len(root.all_actions), 2)
187
- # 1 log from Bar and 1 from Foo.
188
- self.assertEqual(len(session.all_logs), 2)
189
- self.assertEqual(len(root.all_logs), 2)
227
+ # 2 query from Bar, 2 from Foo and 2 * 3 from parallel executions.
228
+ self.assertEqual(len(session.all_queries), 10)
229
+ self.assertEqual(len(root.all_queries), 10)
230
+ # 6 actions: Foo and 2 Bar, and 3 Bar from parallel executions.
231
+ self.assertEqual(len(session.all_actions), 6)
232
+ self.assertEqual(
233
+ [str(a.position) for a in session.all_actions],
234
+ ['1', '1.1', '1.2.1.1', '1.2.2.1', '1.2.3.1', '1.3']
235
+ )
236
+ self.assertEqual(len(root.all_actions), 6)
237
+ # 1 log from Bar and 1 from Foo and 3 from Bar in parallel executions.
238
+ self.assertEqual(len(session.all_logs), 6)
239
+ self.assertEqual(len(root.all_logs), 6)
190
240
  self.assertIs(session.usage_summary, root.usage_summary)
191
- self.assertEqual(root.usage_summary.total.num_requests, 6)
241
+ self.assertEqual(root.usage_summary.total.num_requests, 10)
192
242
 
193
243
  # Inspecting the top-level action (Foo)
194
244
  foo_invocation = root.execution[0]
@@ -200,15 +250,19 @@ class SessionTest(unittest.TestCase):
200
250
 
201
251
  # Prepare phase.
202
252
  prepare_phase = foo_invocation.execution[0]
253
+ self.assertIsNone(prepare_phase.position)
203
254
  self.assertIsInstance(prepare_phase, action_lib.ExecutionTrace)
204
255
  self.assertEqual(prepare_phase.id, 'agent@1:/a1/prepare')
205
- self.assertEqual(len(prepare_phase.items), 2)
256
+ self.assertEqual(len(prepare_phase.items), 3)
206
257
  self.assertTrue(prepare_phase.has_started)
207
258
  self.assertTrue(prepare_phase.has_stopped)
208
- self.assertEqual(prepare_phase.usage_summary.total.num_requests, 1)
259
+ self.assertEqual(prepare_phase.usage_summary.total.num_requests, 2)
209
260
  self.assertIsInstance(prepare_phase.items[0], lf.logging.LogEntry)
210
- self.assertIsInstance(prepare_phase.items[1], lf_structured.QueryInvocation)
211
- self.assertEqual(prepare_phase.items[1].id, 'agent@1:/a1/prepare/q1')
261
+ self.assertIsInstance(prepare_phase.items[1], action_lib.ActionInvocation)
262
+ self.assertIs(prepare_phase.items[1].parent_execution_unit, foo_invocation)
263
+ self.assertEqual(prepare_phase.items[1].id, 'agent@1:/a1/prepare/a1')
264
+ self.assertIsInstance(prepare_phase.items[2], lf_structured.QueryInvocation)
265
+ self.assertEqual(prepare_phase.items[2].id, 'agent@1:/a1/prepare/q1')
212
266
 
213
267
  # Tracked queries.
214
268
  query_invocation = foo_invocation.execution[1]
@@ -230,20 +284,44 @@ class SessionTest(unittest.TestCase):
230
284
 
231
285
  # Tracked parallel executions.
232
286
  parallel_executions = foo_invocation.execution[2]
287
+ # root (0) > foo (0) > parallel executions (1)
288
+ self.assertEqual(parallel_executions.position, (0, 0, 1))
233
289
  self.assertEqual(parallel_executions.id, 'agent@1:/a1/p1')
234
290
  self.assertIsInstance(parallel_executions, action_lib.ParallelExecutions)
291
+ self.assertIs(
292
+ parallel_executions.all_actions[0].parent_execution_unit,
293
+ parallel_executions
294
+ )
295
+ self.assertIs(
296
+ parallel_executions.all_actions[0].parent_action,
297
+ foo_invocation
298
+ )
235
299
  self.assertEqual(len(parallel_executions), 3)
236
300
  self.assertEqual(parallel_executions[0].id, 'agent@1:/a1/p1/b1')
237
301
  self.assertEqual(parallel_executions[1].id, 'agent@1:/a1/p1/b2')
238
302
  self.assertEqual(parallel_executions[2].id, 'agent@1:/a1/p1/b3')
303
+ self.assertEqual(len(parallel_executions[0].execution_units), 1)
304
+ self.assertEqual(len(parallel_executions[1].execution_units), 1)
305
+ self.assertEqual(len(parallel_executions[2].execution_units), 1)
239
306
  self.assertEqual(len(parallel_executions[0].queries), 1)
307
+ self.assertEqual(len(parallel_executions[0].all_queries), 2)
240
308
  self.assertEqual(len(parallel_executions[1].queries), 1)
309
+ self.assertEqual(len(parallel_executions[1].all_queries), 2)
241
310
  self.assertEqual(len(parallel_executions[2].queries), 1)
311
+ self.assertEqual(len(parallel_executions[2].all_queries), 2)
312
+ self.assertEqual(len(parallel_executions.execution_units), 0)
313
+ self.assertEqual(len(parallel_executions.actions), 0)
314
+ self.assertEqual(len(parallel_executions.queries), 0)
315
+ self.assertEqual(len(parallel_executions.logs), 0)
316
+ self.assertEqual(len(parallel_executions.all_actions), 3)
317
+ self.assertEqual(len(parallel_executions.all_queries), 6)
318
+ self.assertEqual(len(parallel_executions.all_logs), 3)
242
319
 
243
320
  # Invocation to Bar.
244
321
  bar_invocation = foo_invocation.execution[3]
245
322
  self.assertIs(bar_invocation.parent_action, foo_invocation)
246
- self.assertEqual(bar_invocation.id, 'agent@1:/a1/a1')
323
+ self.assertIs(bar_invocation.parent_execution_unit, foo_invocation)
324
+ self.assertEqual(bar_invocation.id, 'agent@1:/a1/a5')
247
325
  self.assertIsInstance(bar_invocation, action_lib.ActionInvocation)
248
326
  self.assertIsInstance(bar_invocation.action, Bar)
249
327
  self.assertEqual(bar_invocation.result, 2)
@@ -497,26 +575,51 @@ class SessionTest(unittest.TestCase):
497
575
  super()._on_bound()
498
576
  self.progresses = []
499
577
 
578
+ def on_session_start(self, session):
579
+ session.add_metadata(progresses=pg.Ref(self.progresses))
580
+
500
581
  def on_action_progress(self, session, action, title, **kwargs):
501
582
  self.progresses.append((action.id, title))
502
583
 
503
584
  handler = MyActionHandler()
585
+ self.assertIs(handler.get(MyActionHandler), handler)
586
+ self.assertIsNone(handler.get(action_lib.SessionLogging))
587
+
588
+ handler_chain = action_lib.SessionEventHandlerChain(
589
+ handlers=[handler, action_lib.SessionLogging()]
590
+ )
591
+ self.assertIs(handler_chain.get(MyActionHandler), handler)
592
+ self.assertIs(
593
+ handler_chain.get(action_lib.SessionLogging),
594
+ handler_chain.handlers[1]
595
+ )
596
+
504
597
  session = action_lib.Session(
505
598
  id='agent@1',
506
- event_handler=action_lib.SessionEventHandlerChain(
507
- handlers=[handler, action_lib.SessionLogging()]
508
- )
599
+ event_handler=handler_chain
509
600
  )
510
601
  bar = Bar()
511
602
  with session:
512
603
  bar(session, lm=fake.StaticResponse('lm response'))
513
604
  session.update_progress('Trajectory completed')
514
605
 
606
+ self.assertIs(session.metadata['progresses'], handler.progresses)
515
607
  self.assertEqual(handler.progresses, [
516
608
  ('agent@1:/a1', 'Query completed'),
517
609
  ('agent@1:', 'Trajectory completed'),
518
610
  ])
519
611
 
612
+ def test_clone(self):
613
+ event_handler = action_lib.SessionLogging()
614
+ session = action_lib.Session(event_handler=event_handler)
615
+ other = session.clone()
616
+ self.assertIsNot(session, other)
617
+ self.assertIs(other.event_handler, event_handler)
618
+
619
+ other = session.clone(deep=True)
620
+ self.assertIsNot(session, other)
621
+ self.assertIsNot(other.event_handler, session.event_handler)
622
+
520
623
  def test_log(self):
521
624
  session = action_lib.Session()
522
625
  session.debug('hi', x=1, y=2)
@@ -101,7 +101,7 @@ class EvaluationTest(unittest.TestCase):
101
101
  self.assertEqual(s.dir, os.path.join(s.root_dir, s.id))
102
102
  self.assertEqual(s.hash, s.clone().hash)
103
103
  # Test persistent hash.
104
- self.assertEqual(s.hash, 'ee958159')
104
+ self.assertEqual(s.hash, '4dfe486a')
105
105
  self.assertEqual(
106
106
  s.hash, s.clone(override={'max_workers': 2, 'lm.timeout': 20}).hash
107
107
  )
@@ -211,7 +211,7 @@ class EvaluationTest(unittest.TestCase):
211
211
  s.result,
212
212
  dict(
213
213
  experiment_setup=dict(
214
- id='Evaluation@27a702cb',
214
+ id='Evaluation@e028b6e6',
215
215
  dir=s.dir,
216
216
  model='StaticSequence',
217
217
  prompt_template='{{example.question}}',
@@ -269,7 +269,7 @@ class EvaluationTest(unittest.TestCase):
269
269
  s.root_dir, base.Evaluation.SUMMARY_HTML.replace('.html', '.json')
270
270
  )
271
271
  self.assertTrue(os.path.exists(summary_json))
272
- summary = pg.load(summary_json, auto_dict=True)
272
+ summary = pg.load(summary_json, convert_unknown=True)
273
273
  self.assertIn('Evaluation', summary)
274
274
  self.assertEqual(len(summary['Evaluation']), 1)
275
275
  self.assertIsNotNone(summary['Evaluation'][0].experiment)
@@ -376,7 +376,7 @@ class EvaluationTest(unittest.TestCase):
376
376
  s.children[0].dir, os.path.join(s.root_dir, s.children[0].id)
377
377
  )
378
378
  # Test persistent hash.
379
- self.assertEqual(s.hash, 'f47532a7')
379
+ self.assertEqual(s.hash, 'fa8f5419')
380
380
 
381
381
  summary = s.run(verbose=True)
382
382
  self.assertEqual(len(summary.evaluations), 2)
@@ -526,7 +526,7 @@ class SuiteTest(unittest.TestCase):
526
526
  lm=lm
527
527
  )
528
528
  # Test for persistent hash.
529
- self.assertEqual(s.hash, '4bd6a2f5')
529
+ self.assertEqual(s.hash, 'ec3901b8')
530
530
  s.run()
531
531
  expected = {
532
532
  s.children[0].id: dict(
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  """Checkpointing evaluation runs."""
15
15
  import abc
16
+ import datetime
16
17
  import re
17
18
  import threading
18
19
  import traceback
@@ -44,7 +45,12 @@ class Checkpointer(experiment_lib.Plugin):
44
45
  checkpoint_filename: Annotated[
45
46
  str,
46
47
  'Checkpoint file pattern.'
47
- ] = 'checkpoint.bagz'
48
+ ] = 'checkpoint.jsonl'
49
+
50
+ enable_inprogress_file: Annotated[
51
+ bool,
52
+ 'If True, write file "<example_id>.inprogress" when example gets started.'
53
+ ] = True
48
54
 
49
55
  max_ckpt_loading_threads: Annotated[
50
56
  int,
@@ -90,6 +96,24 @@ class Checkpointer(experiment_lib.Plugin):
90
96
  f'scratch. Example IDs: {example_ids_to_evaluate}.'
91
97
  )
92
98
 
99
+ def on_example_start(
100
+ self,
101
+ runner: Runner,
102
+ experiment: Experiment,
103
+ example: Example,
104
+ ) -> None:
105
+ """Saves the example to the checkpoint file."""
106
+ if self.enable_inprogress_file:
107
+ def _save_inprogress_file(example: Example):
108
+ inprogress_file = runner.current_run.output_path_for(
109
+ experiment, f'{example.id}.inprogress'
110
+ )
111
+ pg.io.writefile(
112
+ inprogress_file,
113
+ datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
114
+ )
115
+ runner.background_run(_save_inprogress_file, example)
116
+
93
117
  def on_example_complete(
94
118
  self,
95
119
  runner: Runner,
@@ -90,7 +90,10 @@ class PerExampleCheckpointerTest(CheckpointerTest):
90
90
  root_dir = os.path.join(tempfile.mkdtemp(), 'per_example_checkpointer')
91
91
  experiment = eval_test_helper.test_experiment()
92
92
  checkpoint_filename = 'checkpoint.jsonl'
93
- checkpointer = checkpointing.PerExampleCheckpointer(checkpoint_filename)
93
+ checkpointer = checkpointing.PerExampleCheckpointer(
94
+ checkpoint_filename,
95
+ enable_inprogress_file=True
96
+ )
94
97
  collector = ExampleCollector()
95
98
  run = experiment.run(
96
99
  root_dir, 'new', runner='sequential', plugins=[checkpointer, collector]
@@ -102,6 +105,10 @@ class PerExampleCheckpointerTest(CheckpointerTest):
102
105
  example = collector.examples[i + 1]
103
106
  ckpt = run.output_path_for(leaf, f'checkpoint_{example.id}.jsonl')
104
107
  self.assertTrue(pg.io.path_exists(ckpt))
108
+ inprogress_file = run.output_path_for(
109
+ leaf, f'{example.id}.inprogress'
110
+ )
111
+ self.assertTrue(pg.io.path_exists(inprogress_file))
105
112
  with pg.io.open_sequence(ckpt) as f:
106
113
  examples_from_ckpt = list(iter(f))
107
114
  # `eval_test_helper.test_experiment` has two TestEvaluation with
@@ -127,11 +127,16 @@ class TestEvaluationWithIndexHtmlGenerationError(TestEvaluation):
127
127
  raise ValueError('Cannot render HTML.')
128
128
 
129
129
 
130
+ def test_evaluation(offset: int | pg.hyper.OneOf = 0):
131
+ """Returns a test evaluation."""
132
+ return TestEvaluation(lm=TestLLM(offset=offset))
133
+
134
+
130
135
  def test_experiment():
131
136
  """Returns a test experiment."""
132
137
  return Suite([
133
- TestEvaluation(lm=TestLLM(offset=0)),
134
- TestEvaluation(lm=TestLLM(offset=pg.oneof(range(5)))),
138
+ test_evaluation(),
139
+ test_evaluation(pg.oneof(range(5))),
135
140
  ])
136
141
 
137
142
 
@@ -880,8 +880,9 @@ class EvaluationState:
880
880
  load_example_metadata: bool | Callable[
881
881
  [example_lib.Example], bool] = True,
882
882
  filter: Callable[[example_lib.Example], bool] | None = None, # pylint: disable=redefined-builtin
883
- ) -> None:
883
+ ) -> list[example_lib.Example]:
884
884
  """Loads the state from the example sequence file."""
885
+ examples = []
885
886
  for example in example_lib.Example.iter_ckpts(
886
887
  state_file,
887
888
  example_input_by_id=example_input_by_id,
@@ -891,6 +892,8 @@ class EvaluationState:
891
892
  continue
892
893
  example.newly_processed = False
893
894
  self._ckpt_examples[example.id] = example
895
+ examples.append(example)
896
+ return examples
894
897
 
895
898
  @property
896
899
  def evaluation_status(self) -> dict[int, ExampleStatus]:
@@ -155,6 +155,8 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
155
155
  ckpt_file: str | list[str],
156
156
  example_input_by_id: Callable[[int], Any] | None = None,
157
157
  load_example_metadata: bool = True,
158
+ convert_unknown: bool = True,
159
+ **kwargs
158
160
  ) -> Iterator['Example']:
159
161
  """Iterates Examples from the checkpoint files."""
160
162
  ckpt_files = [ckpt_file] if isinstance(ckpt_file, str) else ckpt_file
@@ -164,7 +166,9 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
164
166
  example = pg.from_json_str(
165
167
  record,
166
168
  example_input_by_id=example_input_by_id,
167
- load_example_metadata=load_example_metadata
169
+ load_example_metadata=load_example_metadata,
170
+ convert_unknown=convert_unknown,
171
+ **kwargs
168
172
  )
169
173
  assert isinstance(example, cls), example
170
174
  yield example
@@ -94,15 +94,23 @@ class ExampleTest(unittest.TestCase):
94
94
  pg.JSONConvertible._TYPE_REGISTRY._type_to_cls_map.pop(
95
95
  inputs[0].b.__type_name__
96
96
  )
97
- v = pg.from_json_str(json_str, auto_dict=True, load_example_metadata=True)
98
- v.output.pop('type_name')
99
- v.metadata.b.pop('type_name')
97
+ v = pg.from_json_str(
98
+ json_str,
99
+ convert_unknown=True,
100
+ load_example_metadata=True
101
+ )
100
102
  self.assertEqual(
101
103
  v,
102
104
  Example(
103
105
  id=1,
104
- output=pg.Dict(x=1),
105
- metadata=dict(b=pg.Dict(x=1, y=2)),
106
+ output=pg.symbolic.UnknownTypedObject(
107
+ inputs[0].a.__type_name__, x=1
108
+ ),
109
+ metadata=dict(
110
+ b=pg.symbolic.UnknownTypedObject(
111
+ inputs[0].b.__type_name__, x=1, y=2
112
+ )
113
+ ),
106
114
  )
107
115
  )
108
116
  # Serialize with input.
@@ -1055,6 +1055,29 @@ class Plugin(lf.Component):
1055
1055
  or result processing.
1056
1056
  """
1057
1057
 
1058
+ @classmethod
1059
+ def is_per_example(cls) -> bool:
1060
+ """Returns whether the plugin is per example only.
1061
+
1062
+ Per-example plugins can be installed on individual workers when examples
1063
+ are evaluated by multiple processes in parallel.
1064
+ """
1065
+
1066
+ def same_code(method1, method2):
1067
+ return method1.__code__ == method2.__code__
1068
+ return all(
1069
+ same_code(method1, method2)
1070
+ for method1, method2 in [
1071
+ (Plugin.on_run_start, cls.on_run_start),
1072
+ (Plugin.on_run_complete, cls.on_run_complete),
1073
+ (Plugin.on_run_abort, cls.on_run_abort),
1074
+ (Plugin.on_experiment_start, cls.on_experiment_start),
1075
+ (Plugin.on_experiment_skipped, cls.on_experiment_skipped),
1076
+ (Plugin.on_experiment_complete, cls.on_experiment_complete),
1077
+ (Plugin.on_experiment_abort, cls.on_experiment_abort),
1078
+ ]
1079
+ )
1080
+
1058
1081
  def on_run_start(
1059
1082
  self,
1060
1083
  runner: Runner,
@@ -433,5 +433,24 @@ class RunnerTest(unittest.TestCase):
433
433
  pass
434
434
 
435
435
 
436
+ class PluginTest(unittest.TestCase):
437
+
438
+ def test_per_example_only(self):
439
+
440
+ class PerExamplePlugin(experiment_lib.Plugin):
441
+
442
+ def on_example_complete(self, runner, experiment, example):
443
+ print('on_example_complete')
444
+
445
+ self.assertTrue(PerExamplePlugin.is_per_example())
446
+
447
+ class NonPerExamplePlugin(experiment_lib.Plugin):
448
+
449
+ def on_experiment_complete(self, runner, experiment):
450
+ print('on_example_complete')
451
+
452
+ self.assertFalse(NonPerExamplePlugin.is_per_example())
453
+
454
+
436
455
  if __name__ == '__main__':
437
456
  unittest.main()
@@ -14,6 +14,7 @@
14
14
  """Tracking evaluation run progress."""
15
15
 
16
16
  import os
17
+ from typing import Literal
17
18
  import langfun.core as lf
18
19
  from langfun.core.eval.v2 import example as example_lib
19
20
  from langfun.core.eval.v2 import experiment as experiment_lib
@@ -24,16 +25,24 @@ Experiment = experiment_lib.Experiment
24
25
  Example = example_lib.Example
25
26
 
26
27
 
27
- def progress_tracker(tqdm: bool = False) -> experiment_lib.Plugin:
28
+ def progress_tracker(
29
+ tracker_type: Literal['tqdm', 'html', 'auto'] = 'auto'
30
+ ) -> experiment_lib.Plugin:
28
31
  """Creates a progress tracker as a plugin.
29
32
 
30
33
  Args:
31
- tqdm: If True, force using tqdm for progress update.
34
+ tracker_type: The type of progress tracker to use.
35
+ If `tqdm`, force using tqdm for progress update.
36
+ If `html`, force using html for progress update.
37
+ If `auto`, determine it automatically based on the running
38
+ environment (console vs. notebook)
32
39
 
33
40
  Returns:
34
41
  The progress tracker plugin.
35
42
  """
36
- if tqdm or not lf.console.under_notebook():
43
+ if tracker_type == 'tqdm' or (
44
+ tracker_type == 'auto' and not lf.console.under_notebook()
45
+ ):
37
46
  return _TqdmProgressTracker()
38
47
  else:
39
48
  return _HtmlProgressTracker()
@@ -21,7 +21,7 @@ import unittest
21
21
  from langfun.core import concurrent as lf_concurrent
22
22
  from langfun.core import console as lf_console
23
23
  from langfun.core.eval.v2 import eval_test_helper
24
- from langfun.core.eval.v2 import progress_tracking # pylint: disable=unused-import
24
+ from langfun.core.eval.v2 import progress_tracking
25
25
  from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import
26
26
  import pyglove as pg
27
27
 
@@ -33,6 +33,7 @@ class HtmlProgressTrackerTest(unittest.TestCase):
33
33
  def display(x):
34
34
  result['view'] = x.to_html()
35
35
 
36
+ self.assertFalse(progress_tracking._HtmlProgressTracker.is_per_example())
36
37
  lf_console._notebook = pg.Dict(
37
38
  display=display
38
39
  )
@@ -46,6 +47,7 @@ class HtmlProgressTrackerTest(unittest.TestCase):
46
47
  class TqdmProgressTrackerTest(unittest.TestCase):
47
48
 
48
49
  def test_basic(self):
50
+ self.assertFalse(progress_tracking._TqdmProgressTracker.is_per_example())
49
51
  root_dir = os.path.join(tempfile.mkdtemp(), 'test_tqdm_progress_tracker')
50
52
  experiment = eval_test_helper.test_experiment()
51
53
  string_io = io.StringIO()
@@ -29,7 +29,11 @@ class ReportingTest(unittest.TestCase):
29
29
  experiment = eval_test_helper.test_experiment()
30
30
  checkpointer = checkpointing.BulkCheckpointer('checkpoint.jsonl')
31
31
  reporter = reporting.HtmlReporter()
32
+ self.assertFalse(reporter.is_per_example())
33
+
32
34
  example_html_generator = reporting.ExampleHtmlGenerator()
35
+ self.assertTrue(example_html_generator.is_per_example())
36
+
33
37
  run = experiment.run(
34
38
  root_dir,
35
39
  'new',
@@ -13,13 +13,17 @@
13
13
  # limitations under the License.
14
14
  """Langfun evaluation runners."""
15
15
 
16
+ # pylint: disable=g-importing-member
16
17
  from langfun.core.eval.v2.runners.base import RunnerBase
18
+ from langfun.core.eval.v2.runners.beam import BeamRunner
17
19
  from langfun.core.eval.v2.runners.debug import DebugRunner
18
20
  from langfun.core.eval.v2.runners.parallel import ParallelRunner
19
21
  from langfun.core.eval.v2.runners.sequential import SequentialRunner
22
+ # pylint: enable=g-importing-member
20
23
 
21
24
  __all__ = [
22
25
  'RunnerBase',
26
+ 'BeamRunner',
23
27
  'DebugRunner',
24
28
  'ParallelRunner',
25
29
  'SequentialRunner',