langfun 0.1.2.dev202504280818__py3-none-any.whl → 0.1.2.dev202504300804__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the contents of those package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langfun might be problematic. Click here for more details.

@@ -22,54 +22,108 @@ import langfun.core.structured as lf_structured
22
22
  import pyglove as pg
23
23
 
24
24
 
25
- class SessionTest(unittest.TestCase):
25
+ class Bar(action_lib.Action):
26
+ simulate_action_error: bool = False
27
+
28
+ def call(self, session, *, lm, **kwargs):
29
+ assert session.current_action.action is self
30
+ session.info('Begin Bar')
31
+ session.query('bar', lm=lm)
32
+ session.add_metadata(note='bar')
33
+ if self.simulate_action_error:
34
+ raise ValueError('Bar error')
35
+ return 2
36
+
37
+
38
+ class Foo(action_lib.Action):
39
+ x: int
40
+ simulate_action_error: bool = False
41
+ simulate_query_error: bool = False
42
+
43
+ def call(self, session, *, lm, **kwargs):
44
+ assert session.current_action.action is self
45
+ with session.track_phase('prepare'):
46
+ session.info('Begin Foo', x=1)
47
+ session.query(
48
+ 'foo',
49
+ schema=int if self.simulate_query_error else None,
50
+ lm=lm
51
+ )
52
+ with session.track_queries():
53
+ self.make_additional_query(lm)
54
+ session.add_metadata(note='foo')
55
+
56
+ def _sub_task(i):
57
+ session.add_metadata(**{f'subtask_{i}': i})
58
+ return lf_structured.query(f'subtask_{i}', lm=lm)
59
+
60
+ for i, output, error in session.concurrent_map(
61
+ _sub_task, range(3), max_workers=2, silence_on_errors=None,
62
+ ):
63
+ assert isinstance(i, int), i
64
+ assert isinstance(output, str), output
65
+ assert error is None, error
66
+ return self.x + Bar(
67
+ simulate_action_error=self.simulate_action_error
68
+ )(session, lm=lm)
69
+
70
+ def make_additional_query(self, lm):
71
+ lf_structured.query('additional query', lm=lm)
72
+
73
+
74
+ class ActionInvocationTest(unittest.TestCase):
75
+
76
+ def test_basics(self):
77
+ action_invocation = action_lib.ActionInvocation(
78
+ action=Foo(1)
79
+ )
80
+ self.assertEqual(action_invocation.id, '')
81
+ root = action_lib.ActionInvocation(action=action_lib.RootAction())
82
+ root.execution.append(action_invocation)
83
+ self.assertIs(action_invocation.parent_action, root)
84
+ self.assertEqual(action_invocation.id, '/a1')
85
+
86
+
87
+ class ExecutionTraceTest(unittest.TestCase):
26
88
 
27
89
  def test_basics(self):
28
- test = self
29
-
30
- class Bar(action_lib.Action):
31
-
32
- def call(self, session, *, lm, **kwargs):
33
- test.assertIs(session.current_action.action, self)
34
- session.info('Begin Bar')
35
- session.query('bar', lm=lm)
36
- session.add_metadata(note='bar')
37
- return 2
38
-
39
- class Foo(action_lib.Action):
40
- x: int
41
-
42
- def call(self, session, *, lm, **kwargs):
43
- test.assertIs(session.current_action.action, self)
44
- with session.track_phase('prepare'):
45
- session.info('Begin Foo', x=1)
46
- session.query('foo', lm=lm)
47
- with session.track_queries():
48
- self.make_additional_query(lm)
49
- session.add_metadata(note='foo')
50
-
51
- def _sub_task(i):
52
- session.add_metadata(**{f'subtask_{i}': i})
53
- return lf_structured.query(f'subtask_{i}', lm=lm)
54
-
55
- for i, output, error in session.concurrent_map(
56
- _sub_task, range(3), max_workers=2, silence_on_errors=None,
57
- ):
58
- assert isinstance(i, int), i
59
- assert isinstance(output, str), output
60
- assert error is None, error
61
- return self.x + Bar()(session, lm=lm)
62
-
63
- def make_additional_query(self, lm):
64
- lf_structured.query('additional query', lm=lm)
90
+ execution = action_lib.ExecutionTrace()
91
+ self.assertEqual(execution.id, '')
65
92
 
93
+ root = action_lib.ActionInvocation(action=action_lib.RootAction())
94
+ action_invocation = action_lib.ActionInvocation(
95
+ action=Foo(1)
96
+ )
97
+ root.execution.append(action_invocation)
98
+ self.assertEqual(action_invocation.execution.id, '/a1')
99
+
100
+ root.execution.reset()
101
+ self.assertEqual(len(root.execution), 0)
102
+
103
+
104
+ class SessionTest(unittest.TestCase):
105
+
106
+ def test_succeeded_trajectory(self):
66
107
  lm = fake.StaticResponse('lm response')
67
108
  foo = Foo(1)
68
- self.assertEqual(foo(lm=lm, verbose=True), 3)
109
+ self.assertIsNone(foo.session)
110
+ self.assertIsNone(foo.result)
111
+ self.assertIsNone(foo.metadata)
112
+
113
+ session = action_lib.Session(id='agent@1')
114
+ self.assertEqual(session.id, 'agent@1')
115
+
116
+ # Render HTML view to trigger dynamic update during execution.
117
+ _ = session.to_html()
118
+
119
+ self.assertEqual(foo(session, lm=lm, verbose=True), 3)
120
+
121
+ self.assertIsNone(foo.session)
122
+ self.assertEqual(foo.result, 3)
123
+ self.assertEqual(
124
+ foo.metadata, dict(note='foo', subtask_0=0, subtask_1=1, subtask_2=2)
125
+ )
69
126
 
70
- session = foo.session
71
- self.assertIn('session@', session.id)
72
- self.assertIsNotNone(session)
73
127
  self.assertIsInstance(session.root.action, action_lib.RootAction)
74
128
  self.assertIs(session.current_action, session.root)
75
129
 
@@ -78,6 +132,9 @@ class SessionTest(unittest.TestCase):
78
132
  #
79
133
 
80
134
  root = session.root
135
+ self.assertIsNone(root.parent_action)
136
+ self.assertEqual(root.id, 'agent@1:')
137
+ self.assertEqual(root.execution.id, 'agent@1:')
81
138
  self.assertEqual(len(root.execution.items), 1)
82
139
  self.assertIs(root.execution.items[0].action, foo)
83
140
 
@@ -104,33 +161,57 @@ class SessionTest(unittest.TestCase):
104
161
 
105
162
  # Inspecting the top-level action (Foo)
106
163
  foo_invocation = root.execution.items[0]
164
+ self.assertIs(foo_invocation.parent_action, root)
165
+ self.assertEqual(foo_invocation.id, 'agent@1:/a1')
166
+ self.assertEqual(foo_invocation.execution.id, 'agent@1:/a1')
107
167
  self.assertEqual(len(foo_invocation.execution.items), 4)
108
168
 
109
169
  # Prepare phase.
110
170
  prepare_phase = foo_invocation.execution.items[0]
111
- self.assertIsInstance(
112
- prepare_phase, action_lib.ExecutionTrace
113
- )
171
+ self.assertIsInstance(prepare_phase, action_lib.ExecutionTrace)
172
+ self.assertEqual(prepare_phase.id, 'agent@1:/a1/prepare')
114
173
  self.assertEqual(len(prepare_phase.items), 2)
115
174
  self.assertTrue(prepare_phase.has_started)
116
175
  self.assertTrue(prepare_phase.has_stopped)
117
176
  self.assertEqual(prepare_phase.usage_summary.total.num_requests, 1)
177
+ self.assertIsInstance(prepare_phase.items[0], lf.logging.LogEntry)
178
+ self.assertIsInstance(prepare_phase.items[1], lf_structured.QueryInvocation)
179
+ self.assertEqual(prepare_phase.items[1].id, 'agent@1:/a1/prepare/q1')
118
180
 
119
181
  # Tracked queries.
120
182
  query_invocation = foo_invocation.execution.items[1]
121
183
  self.assertIsInstance(query_invocation, lf_structured.QueryInvocation)
184
+ self.assertEqual(query_invocation.id, 'agent@1:/a1/q2')
122
185
  self.assertIs(query_invocation.lm, lm)
186
+ self.assertEqual(
187
+ foo_invocation.execution.indexof(
188
+ query_invocation, lf_structured.QueryInvocation
189
+ ),
190
+ 1
191
+ )
192
+ self.assertEqual(
193
+ root.execution.indexof(
194
+ query_invocation, lf_structured.QueryInvocation
195
+ ),
196
+ -1
197
+ )
123
198
 
124
199
  # Tracked parallel executions.
125
200
  parallel_executions = foo_invocation.execution.items[2]
201
+ self.assertEqual(parallel_executions.id, 'agent@1:/a1/p1')
126
202
  self.assertIsInstance(parallel_executions, action_lib.ParallelExecutions)
127
203
  self.assertEqual(len(parallel_executions), 3)
204
+ self.assertEqual(parallel_executions[0].id, 'agent@1:/a1/p1/b1')
205
+ self.assertEqual(parallel_executions[1].id, 'agent@1:/a1/p1/b2')
206
+ self.assertEqual(parallel_executions[2].id, 'agent@1:/a1/p1/b3')
128
207
  self.assertEqual(len(parallel_executions[0].queries), 1)
129
208
  self.assertEqual(len(parallel_executions[1].queries), 1)
130
209
  self.assertEqual(len(parallel_executions[2].queries), 1)
131
210
 
132
211
  # Invocation to Bar.
133
212
  bar_invocation = foo_invocation.execution.items[3]
213
+ self.assertIs(bar_invocation.parent_action, foo_invocation)
214
+ self.assertEqual(bar_invocation.id, 'agent@1:/a1/a1')
134
215
  self.assertIsInstance(bar_invocation, action_lib.ActionInvocation)
135
216
  self.assertIsInstance(bar_invocation.action, Bar)
136
217
  self.assertEqual(bar_invocation.result, 2)
@@ -144,6 +225,51 @@ class SessionTest(unittest.TestCase):
144
225
  json_str = session.to_json_str(save_ref_value=True)
145
226
  self.assertIsInstance(pg.from_json_str(json_str), action_lib.Session)
146
227
 
228
+ def test_failed_action(self):
229
+ lm = fake.StaticResponse('lm response')
230
+ foo = Foo(1, simulate_action_error=True)
231
+ with self.assertRaisesRegex(ValueError, 'Bar error'):
232
+ foo(lm=lm)
233
+
234
+ session = foo.session
235
+ self.assertIsNotNone(session)
236
+ self.assertIsInstance(session.root.action, action_lib.RootAction)
237
+ self.assertIs(session.current_action, session.root)
238
+
239
+ # Inspecting the root invocation.
240
+ root = session.root
241
+ self.assertRegex(root.id, 'agent@.*:')
242
+ self.assertTrue(root.has_error)
243
+ foo_invocation = root.execution.items[0]
244
+ self.assertIsInstance(foo_invocation, action_lib.ActionInvocation)
245
+ self.assertTrue(foo_invocation.has_error)
246
+ bar_invocation = foo_invocation.execution.items[3]
247
+ self.assertIsInstance(bar_invocation, action_lib.ActionInvocation)
248
+ self.assertTrue(bar_invocation.has_error)
249
+
250
+ # Save to HTML
251
+ self.assertIn('error', session.to_html().content)
252
+
253
+ def test_failed_query(self):
254
+ lm = fake.StaticResponse('lm response')
255
+ foo = Foo(1, simulate_query_error=True)
256
+ with self.assertRaisesRegex(lf_structured.MappingError, 'SyntaxError'):
257
+ foo(lm=lm)
258
+
259
+ session = foo.session
260
+ self.assertIsNotNone(session)
261
+ self.assertIsInstance(session.root.action, action_lib.RootAction)
262
+ self.assertIs(session.current_action, session.root)
263
+
264
+ # Inspecting the root invocation.
265
+ root = session.root
266
+ self.assertRegex(root.id, 'agent@.*:')
267
+ self.assertTrue(root.has_error)
268
+ foo_invocation = root.execution.items[0]
269
+ self.assertIsInstance(foo_invocation, action_lib.ActionInvocation)
270
+ self.assertTrue(foo_invocation.has_error)
271
+ self.assertEqual(len(foo_invocation.execution.items), 2)
272
+
147
273
  def test_log(self):
148
274
  session = action_lib.Session()
149
275
  session.debug('hi', x=1, y=2)
@@ -153,8 +279,8 @@ class SessionTest(unittest.TestCase):
153
279
  session.fatal('hi', x=1, y=2)
154
280
 
155
281
  def test_as_message(self):
156
- session = action_lib.Session(id='abc')
157
- self.assertEqual(session.id, 'abc')
282
+ session = action_lib.Session()
283
+ self.assertIn('agent@', session.id)
158
284
  self.assertIsInstance(session.as_message(), lf.AIMessage)
159
285
 
160
286
 
@@ -195,7 +195,7 @@ class EvaluationTest(unittest.TestCase):
195
195
  score=1.0,
196
196
  logprobs=None,
197
197
  is_cached=False,
198
- usage=lf.LMSamplingUsage(387, 24, 411),
198
+ usage=lf.LMSamplingUsage(428, 24, 452),
199
199
  tags=['lm-response', 'lm-output', 'transformed'],
200
200
  ),
201
201
  )
@@ -234,12 +234,12 @@ class EvaluationTest(unittest.TestCase):
234
234
  }
235
235
  ),
236
236
  usage=dict(
237
- total_prompt_tokens=774,
237
+ total_prompt_tokens=856,
238
238
  total_completion_tokens=25,
239
239
  num_usages=2,
240
- average_prompt_tokens=387,
240
+ average_prompt_tokens=428,
241
241
  average_completion_tokens=12,
242
- average_total_tokens=399,
242
+ average_total_tokens=440,
243
243
  ),
244
244
  ),
245
245
  )
@@ -167,6 +167,8 @@ class Evaluation(experiment_lib.Experiment):
167
167
  example.input = self.example_input_by_id(example.id)
168
168
 
169
169
  checkpointed = self._state.ckpt_example(example.id)
170
+ self._state.update(example, in_progress=True)
171
+
170
172
  with pg.timeit('evaluate') as timeit, lf.track_usages() as usage_summary:
171
173
  if checkpointed is None or checkpointed.has_error:
172
174
  if checkpointed is None:
@@ -221,7 +223,7 @@ class Evaluation(experiment_lib.Experiment):
221
223
  if example.newly_processed:
222
224
  example.end_time = time.time()
223
225
 
224
- self._state.update(example)
226
+ self._state.update(example, in_progress=False)
225
227
  return example
226
228
 
227
229
  def _process(
@@ -501,6 +503,21 @@ class Evaluation(experiment_lib.Experiment):
501
503
  )
502
504
  )
503
505
 
506
+ def _in_progress_tab() -> pg.views.html.controls.Tab | None:
507
+ """Renders a tab for the in progress examples."""
508
+ if not self.state.in_progress_examples:
509
+ return None
510
+ return pg.views.html.controls.Tab(
511
+ label='In Progress',
512
+ content=pg.Html.element(
513
+ 'div', [
514
+ self._in_progress_view(
515
+ list(self.state.in_progress_examples.values())
516
+ )
517
+ ]
518
+ )
519
+ )
520
+
504
521
  def _metric_tab(metric: metrics_lib.Metric) -> pg.views.html.controls.Tab:
505
522
  """Renders a tab for a metric (group)."""
506
523
  return pg.views.html.controls.Tab(
@@ -571,10 +588,9 @@ class Evaluation(experiment_lib.Experiment):
571
588
  pg.views.html.controls.TabControl(
572
589
  [
573
590
  _definition_tab(),
574
- ] + [
575
- _metric_tab(m) for m in self.metrics
576
- ] + [
577
- _logs_tab()
591
+ [_metric_tab(m) for m in self.metrics],
592
+ _in_progress_tab(),
593
+ _logs_tab(),
578
594
  ],
579
595
  selected=1,
580
596
  )
@@ -598,6 +614,27 @@ class Evaluation(experiment_lib.Experiment):
598
614
  css_classes=['eval-details'],
599
615
  )
600
616
 
617
+ def _in_progress_view(
618
+ self, in_progress_examples: list[example_lib.Example]
619
+ ) -> pg.Html:
620
+ """Renders a HTML view for the in-progress examples."""
621
+ current_time = time.time()
622
+ logs = [f'(Total {len(in_progress_examples)} examples in progress)']
623
+ for example in in_progress_examples:
624
+ if example.newly_processed:
625
+ logs.append(
626
+ f'Example {example.id}: In progress for '
627
+ f'{current_time - example.start_time:.2f} seconds.'
628
+ )
629
+ else:
630
+ logs.append(f'Example {example.id}: Recomputing metrics...')
631
+ return pg.Html.element(
632
+ 'textarea',
633
+ [pg.Html.escape('\n'.join(logs))],
634
+ readonly=True,
635
+ css_classes=['logs-textarea'],
636
+ )
637
+
601
638
  def _html_tree_view_config(self) -> dict[str, Any]:
602
639
  return dict(
603
640
  css_classes=['eval-card'] if self.is_leaf else None
@@ -716,14 +753,27 @@ class EvaluationState:
716
753
  'Whether the example is evaluated.'
717
754
  ] = False
718
755
 
756
+ in_progress: Annotated[
757
+ bool,
758
+ (
759
+ 'Whether the example is in progress. '
760
+ )
761
+ ] = False
762
+
719
763
  newly_processed: Annotated[
720
764
  bool,
721
- 'Whether the example is newly processed.'
765
+ (
766
+ 'Whether the example is newly processed. '
767
+ 'Applicable only when evaluated is True.'
768
+ )
722
769
  ] = False
723
770
 
724
771
  has_error: Annotated[
725
772
  bool,
726
- 'Whether the example has error.'
773
+ (
774
+ 'Whether the example has error. '
775
+ 'Applicable only when evaluated is True.'
776
+ )
727
777
  ] = False
728
778
 
729
779
  def __init__(self):
@@ -732,6 +782,7 @@ class EvaluationState:
732
782
  self._evaluation_status: dict[
733
783
  int, EvaluationState.ExampleStatus
734
784
  ] = {}
785
+ self._in_progress_examples: dict[int, example_lib.Example] = {}
735
786
 
736
787
  def load(
737
788
  self,
@@ -758,6 +809,11 @@ class EvaluationState:
758
809
  """Returns the evaluation status of the examples."""
759
810
  return self._evaluation_status
760
811
 
812
+ @property
813
+ def in_progress_examples(self) -> dict[int, example_lib.Example]:
814
+ """Returns the in-progress examples."""
815
+ return self._in_progress_examples
816
+
761
817
  @property
762
818
  def ckpt_examples(self) -> dict[int, example_lib.Example]:
763
819
  """Returns the unevaluated examples from checkpoints."""
@@ -773,17 +829,27 @@ class EvaluationState:
773
829
  example_id, EvaluationState.ExampleStatus()
774
830
  )
775
831
 
776
- def update(self, example: example_lib.Example) -> None:
832
+ def update(self, example: example_lib.Example, in_progress: bool) -> None:
777
833
  """Updates the state with the given example."""
778
- self._update_status(example)
779
- # Processed examples will be removed once it's done.
780
- self._ckpt_examples.pop(example.id, None)
834
+ self._update_status(example, in_progress)
835
+
836
+ if in_progress:
837
+ self._in_progress_examples[example.id] = example
838
+ else:
839
+ self._in_progress_examples.pop(example.id, None)
840
+ # Processed examples will be removed once it's done.
841
+ self._ckpt_examples.pop(example.id, None)
781
842
 
782
- def _update_status(self, example: example_lib.Example) -> None:
843
+ def _update_status(
844
+ self,
845
+ example: example_lib.Example,
846
+ in_progress: bool
847
+ ) -> None:
783
848
  """Updates the evaluation status of the example."""
784
849
  self._evaluation_status[example.id] = (
785
850
  EvaluationState.ExampleStatus(
786
851
  evaluated=example.output != pg.MISSING_VALUE,
852
+ in_progress=in_progress,
787
853
  newly_processed=example.newly_processed,
788
854
  has_error=example.has_error,
789
855
  )
@@ -79,8 +79,10 @@ class EvaluationTest(unittest.TestCase):
79
79
  exp = eval_test_helper.TestEvaluation()
80
80
  example = exp.evaluate(Example(id=3))
81
81
  self.assertTrue(exp.state.get_status(3).evaluated)
82
+ self.assertFalse(exp.state.get_status(3).in_progress)
82
83
  self.assertTrue(exp.state.get_status(3).newly_processed)
83
84
  self.assertFalse(exp.state.get_status(3).has_error)
85
+ self.assertEqual(exp.state.in_progress_examples, {})
84
86
  self.assertTrue(example.newly_processed)
85
87
  self.assertEqual(example.input, pg.Dict(x=2, y=4, groundtruth=6))
86
88
  self.assertEqual(example.output, 6)
@@ -55,8 +55,11 @@ from langfun.core.structured.parsing import call
55
55
 
56
56
  from langfun.core.structured.querying import track_queries
57
57
  from langfun.core.structured.querying import QueryInvocation
58
+
59
+ from langfun.core.structured.querying import LfQuery
58
60
  from langfun.core.structured.querying import query
59
61
  from langfun.core.structured.querying import query_and_reduce
62
+ from langfun.core.structured.querying import query_protocol
60
63
 
61
64
  from langfun.core.structured.querying import query_prompt
62
65
  from langfun.core.structured.querying import query_output
@@ -340,8 +340,11 @@ class Mapping(lf.LangFunc):
340
340
  schema_title: Annotated[str, 'The section title for schema.'] = 'SCHEMA'
341
341
 
342
342
  protocol: Annotated[
343
- schema_lib.SchemaProtocol,
344
- 'The protocol for representing the schema and value.',
343
+ str,
344
+ (
345
+ 'A string representing the protocol for formatting the prompt. '
346
+ 'Built-in Langfun protocols are: `python` and `json`.'
347
+ ),
345
348
  ] = 'python'
346
349
 
347
350
  #
@@ -646,7 +646,7 @@ class CallTest(unittest.TestCase):
646
646
  score=1.0,
647
647
  logprobs=None,
648
648
  is_cached=False,
649
- usage=lf.LMSamplingUsage(315, 1, 316),
649
+ usage=lf.LMSamplingUsage(356, 1, 357),
650
650
  tags=['lm-response', 'lm-output', 'transformed']
651
651
  ),
652
652
  )