langfun-0.1.2.dev202509120804-py3-none-any.whl → langfun-0.1.2.dev202512150805-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +3 -0
  29. langfun/core/eval/v2/checkpointing.py +148 -46
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +102 -19
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +95 -20
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +88 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +73 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +350 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +213 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +14 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +90 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +52 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +78 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +78 -4
  104. langfun/core/modalities/mime_test.py +59 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512150805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/top_level.txt +0 -0
@@ -24,7 +24,14 @@ import pyglove as pg
24
24
 
25
25
 
26
26
  class ActionEval(lf.eval.v2.Evaluation):
27
- """Agent evaluation."""
27
+ """Evaluation for agentic actions.
28
+
29
+ `ActionEval` is a specialized evaluation class for executing and evaluating
30
+ agentic actions based on provided inputs. Each input example is expected to
31
+ contain an `action` attribute. The `process` method executes the action
32
+ within a dedicated `Session`, captures the final result, and returns it
33
+ along with the session details in the metadata.
34
+ """
28
35
 
29
36
  action_args: Annotated[
30
37
  dict[str, Any],
@@ -68,7 +75,7 @@ class ExampleView(pg.Object):
68
75
  class ActionEvalV1(lf_eval.Matching):
69
76
  """Base class for action evaluations.
70
77
 
71
- The input function should returns a list of pg.Dict, with `action` and
78
+ The input function should return a list of pg.Dict, with `action` and
72
79
  `groundtruth` fields.
73
80
  """
74
81
  # We override the schema and prompt to dummy values since they are not used.
@@ -34,6 +34,7 @@ class Bar(action_lib.Action):
34
34
  time.sleep(self.simulate_execution_time)
35
35
  session.query('bar', lm=lm)
36
36
  session.add_metadata(note='bar')
37
+ session.update_progress('Query completed')
37
38
  if self.simulate_action_error:
38
39
  raise ValueError('Bar error')
39
40
  return 2 + pg.contextual_value('baz', 0)
@@ -51,6 +52,7 @@ class Foo(action_lib.Action):
51
52
  with session.track_phase('prepare'):
52
53
  session.info('Begin Foo', x=1)
53
54
  time.sleep(self.simulate_execution_time[0])
55
+ Bar()(session, lm=lm)
54
56
  session.query(
55
57
  'foo',
56
58
  schema=int if self.simulate_query_error else None,
@@ -64,14 +66,21 @@ class Foo(action_lib.Action):
64
66
  def _sub_task(i):
65
67
  session.add_metadata(**{f'subtask_{i}': i})
66
68
  time.sleep(self.simulate_execution_time[2])
69
+ Bar()(session, lm=lm)
67
70
  return lf_structured.query(f'subtask_{i}', lm=lm)
68
71
 
72
+ self._state = []
69
73
  for i, output, error in session.concurrent_map(
70
- _sub_task, range(3), max_workers=2, silence_on_errors=None,
74
+ _sub_task,
75
+ range(3),
76
+ max_workers=2,
77
+ ordered=True,
78
+ silence_on_errors=None,
71
79
  ):
72
80
  assert isinstance(i, int), i
73
81
  assert isinstance(output, str), output
74
82
  assert error is None, error
83
+ self._state.append(i)
75
84
  return self.x + Bar(
76
85
  simulate_action_error=self.simulate_action_error,
77
86
  simulate_execution_time=self.simulate_execution_time[3]
@@ -81,6 +90,50 @@ class Foo(action_lib.Action):
81
90
  lf_structured.query('additional query', lm=lm)
82
91
 
83
92
 
93
+ class ExecutionUnitPositionTest(unittest.TestCase):
94
+
95
+ def test_basics(self):
96
+ pos1 = action_lib.ExecutionUnit.Position(None, 0)
97
+ self.assertEqual(repr(pos1), 'Position(0)')
98
+ self.assertEqual(str(pos1), '')
99
+ self.assertIsNone(pos1.parent)
100
+ self.assertEqual(pos1.index, 0)
101
+ self.assertEqual(pos1.indices(), (0,))
102
+ self.assertEqual(pos1, (0,))
103
+ self.assertEqual(pos1, '')
104
+ self.assertEqual(pos1, action_lib.ExecutionUnit.Position(None, 0))
105
+ self.assertNotEqual(pos1, 1)
106
+ self.assertNotEqual(pos1, (1,))
107
+ self.assertNotEqual(pos1, action_lib.ExecutionUnit.Position(None, 1))
108
+
109
+ pos2 = action_lib.ExecutionUnit.Position(pos1, 0)
110
+ self.assertEqual(repr(pos2), 'Position(0, 0)')
111
+ self.assertEqual(str(pos2), '1')
112
+ self.assertEqual(pos2, '1')
113
+ self.assertEqual(pos2.parent, pos1)
114
+ self.assertEqual(pos2.index, 0)
115
+ self.assertEqual(pos2.indices(), (0, 0))
116
+ self.assertNotEqual(pos1, pos2)
117
+ self.assertLess(pos1, pos2)
118
+ self.assertGreater(pos2, pos1)
119
+ self.assertEqual(
120
+ hash(pos2),
121
+ hash(
122
+ action_lib.ExecutionUnit.Position(
123
+ action_lib.ExecutionUnit.Position(None, 0), 0
124
+ )
125
+ )
126
+ )
127
+
128
+ pos3 = action_lib.ExecutionUnit.Position(pos2, 0)
129
+ self.assertEqual(str(pos3), '1.1')
130
+ self.assertEqual(pos3, '1.1')
131
+ self.assertEqual(pos3.parent, pos2)
132
+ self.assertEqual(pos3.index, 0)
133
+ self.assertEqual(pos3.indices(), (0, 0, 0))
134
+ self.assertEqual(pos3.to_str(separator='>'), '1>1')
135
+
136
+
84
137
  class ActionInvocationTest(unittest.TestCase):
85
138
 
86
139
  def test_basics(self):
@@ -101,9 +154,7 @@ class ExecutionTraceTest(unittest.TestCase):
101
154
  self.assertEqual(execution.id, '')
102
155
 
103
156
  root = action_lib.ActionInvocation(action=action_lib.RootAction())
104
- action_invocation = action_lib.ActionInvocation(
105
- action=Foo(1)
106
- )
157
+ action_invocation = action_lib.ActionInvocation(action=Foo(1))
107
158
  root.execution.append(action_invocation)
108
159
  self.assertEqual(action_invocation.execution.id, '/a1')
109
160
 
@@ -118,10 +169,11 @@ class SessionTest(unittest.TestCase):
118
169
  foo = Foo(1)
119
170
  self.assertIsNone(foo.session)
120
171
  self.assertIsNone(foo.invocation)
172
+ self.assertIsNone(foo.state)
121
173
  self.assertIsNone(foo.result)
122
174
  self.assertIsNone(foo.metadata)
123
175
 
124
- session = action_lib.Session(id='agent@1')
176
+ session = action_lib.Session(id='agent@1', verbose=True)
125
177
  self.assertEqual(session.id, 'agent@1')
126
178
  self.assertFalse(session.has_started)
127
179
  self.assertFalse(session.has_stopped)
@@ -130,12 +182,14 @@ class SessionTest(unittest.TestCase):
130
182
  _ = session.to_html()
131
183
 
132
184
  with session:
133
- result = foo(session, lm=lm, verbose=True)
185
+ result = foo(session, lm=lm)
134
186
 
135
187
  self.assertTrue(session.has_started)
136
188
  self.assertTrue(session.has_stopped)
137
189
  self.assertEqual(result, 3)
138
190
  self.assertIsNone(foo.session)
191
+ self.assertEqual(foo.state, [0, 1, 2])
192
+ self.assertIs(foo.invocation.state, foo.state)
139
193
  self.assertEqual(foo.result, 3)
140
194
  self.assertEqual(
141
195
  foo.metadata, dict(note='foo', subtask_0=0, subtask_1=1, subtask_2=2)
@@ -143,6 +197,7 @@ class SessionTest(unittest.TestCase):
143
197
 
144
198
  self.assertIsInstance(session.root.action, action_lib.RootAction)
145
199
  self.assertIs(session.current_action, session.root)
200
+ self.assertIs(session.metadata, session.root.metadata)
146
201
 
147
202
  #
148
203
  # Inspecting the root invocation.
@@ -165,20 +220,25 @@ class SessionTest(unittest.TestCase):
165
220
  )
166
221
 
167
222
  # The root space should have one action (foo), no queries, and no logs.
223
+ self.assertEqual(len(root.execution_units), 1)
168
224
  self.assertEqual(len(root.actions), 1)
169
225
  self.assertEqual(len(root.queries), 0)
170
226
  self.assertEqual(len(root.logs), 0)
171
- # 1 query from Bar, 2 from Foo and 3 from parallel executions.
172
- self.assertEqual(len(session.all_queries), 6)
173
- self.assertEqual(len(root.all_queries), 6)
174
- # 2 actions: Foo and Bar.
175
- self.assertEqual(len(session.all_actions), 2)
176
- self.assertEqual(len(root.all_actions), 2)
177
- # 1 log from Bar and 1 from Foo.
178
- self.assertEqual(len(session.all_logs), 2)
179
- self.assertEqual(len(root.all_logs), 2)
227
+ # 2 query from Bar, 2 from Foo and 2 * 3 from parallel executions.
228
+ self.assertEqual(len(session.all_queries), 10)
229
+ self.assertEqual(len(root.all_queries), 10)
230
+ # 6 actions: Foo and 2 Bar, and 3 Bar from parallel executions.
231
+ self.assertEqual(len(session.all_actions), 6)
232
+ self.assertEqual(
233
+ [str(a.position) for a in session.all_actions],
234
+ ['1', '1.1', '1.2.1.1', '1.2.2.1', '1.2.3.1', '1.3']
235
+ )
236
+ self.assertEqual(len(root.all_actions), 6)
237
+ # 1 log from Bar and 1 from Foo and 3 from Bar in parallel executions.
238
+ self.assertEqual(len(session.all_logs), 6)
239
+ self.assertEqual(len(root.all_logs), 6)
180
240
  self.assertIs(session.usage_summary, root.usage_summary)
181
- self.assertEqual(root.usage_summary.total.num_requests, 6)
241
+ self.assertEqual(root.usage_summary.total.num_requests, 10)
182
242
 
183
243
  # Inspecting the top-level action (Foo)
184
244
  foo_invocation = root.execution[0]
@@ -190,15 +250,19 @@ class SessionTest(unittest.TestCase):
190
250
 
191
251
  # Prepare phase.
192
252
  prepare_phase = foo_invocation.execution[0]
253
+ self.assertIsNone(prepare_phase.position)
193
254
  self.assertIsInstance(prepare_phase, action_lib.ExecutionTrace)
194
255
  self.assertEqual(prepare_phase.id, 'agent@1:/a1/prepare')
195
- self.assertEqual(len(prepare_phase.items), 2)
256
+ self.assertEqual(len(prepare_phase.items), 3)
196
257
  self.assertTrue(prepare_phase.has_started)
197
258
  self.assertTrue(prepare_phase.has_stopped)
198
- self.assertEqual(prepare_phase.usage_summary.total.num_requests, 1)
259
+ self.assertEqual(prepare_phase.usage_summary.total.num_requests, 2)
199
260
  self.assertIsInstance(prepare_phase.items[0], lf.logging.LogEntry)
200
- self.assertIsInstance(prepare_phase.items[1], lf_structured.QueryInvocation)
201
- self.assertEqual(prepare_phase.items[1].id, 'agent@1:/a1/prepare/q1')
261
+ self.assertIsInstance(prepare_phase.items[1], action_lib.ActionInvocation)
262
+ self.assertIs(prepare_phase.items[1].parent_execution_unit, foo_invocation)
263
+ self.assertEqual(prepare_phase.items[1].id, 'agent@1:/a1/prepare/a1')
264
+ self.assertIsInstance(prepare_phase.items[2], lf_structured.QueryInvocation)
265
+ self.assertEqual(prepare_phase.items[2].id, 'agent@1:/a1/prepare/q1')
202
266
 
203
267
  # Tracked queries.
204
268
  query_invocation = foo_invocation.execution[1]
@@ -220,20 +284,44 @@ class SessionTest(unittest.TestCase):
220
284
 
221
285
  # Tracked parallel executions.
222
286
  parallel_executions = foo_invocation.execution[2]
287
+ # root (0) > foo (0) > parallel executions (1)
288
+ self.assertEqual(parallel_executions.position, (0, 0, 1))
223
289
  self.assertEqual(parallel_executions.id, 'agent@1:/a1/p1')
224
290
  self.assertIsInstance(parallel_executions, action_lib.ParallelExecutions)
291
+ self.assertIs(
292
+ parallel_executions.all_actions[0].parent_execution_unit,
293
+ parallel_executions
294
+ )
295
+ self.assertIs(
296
+ parallel_executions.all_actions[0].parent_action,
297
+ foo_invocation
298
+ )
225
299
  self.assertEqual(len(parallel_executions), 3)
226
300
  self.assertEqual(parallel_executions[0].id, 'agent@1:/a1/p1/b1')
227
301
  self.assertEqual(parallel_executions[1].id, 'agent@1:/a1/p1/b2')
228
302
  self.assertEqual(parallel_executions[2].id, 'agent@1:/a1/p1/b3')
303
+ self.assertEqual(len(parallel_executions[0].execution_units), 1)
304
+ self.assertEqual(len(parallel_executions[1].execution_units), 1)
305
+ self.assertEqual(len(parallel_executions[2].execution_units), 1)
229
306
  self.assertEqual(len(parallel_executions[0].queries), 1)
307
+ self.assertEqual(len(parallel_executions[0].all_queries), 2)
230
308
  self.assertEqual(len(parallel_executions[1].queries), 1)
309
+ self.assertEqual(len(parallel_executions[1].all_queries), 2)
231
310
  self.assertEqual(len(parallel_executions[2].queries), 1)
311
+ self.assertEqual(len(parallel_executions[2].all_queries), 2)
312
+ self.assertEqual(len(parallel_executions.execution_units), 0)
313
+ self.assertEqual(len(parallel_executions.actions), 0)
314
+ self.assertEqual(len(parallel_executions.queries), 0)
315
+ self.assertEqual(len(parallel_executions.logs), 0)
316
+ self.assertEqual(len(parallel_executions.all_actions), 3)
317
+ self.assertEqual(len(parallel_executions.all_queries), 6)
318
+ self.assertEqual(len(parallel_executions.all_logs), 3)
232
319
 
233
320
  # Invocation to Bar.
234
321
  bar_invocation = foo_invocation.execution[3]
235
322
  self.assertIs(bar_invocation.parent_action, foo_invocation)
236
- self.assertEqual(bar_invocation.id, 'agent@1:/a1/a1')
323
+ self.assertIs(bar_invocation.parent_execution_unit, foo_invocation)
324
+ self.assertEqual(bar_invocation.id, 'agent@1:/a1/a5')
237
325
  self.assertIsInstance(bar_invocation, action_lib.ActionInvocation)
238
326
  self.assertIsInstance(bar_invocation.action, Bar)
239
327
  self.assertEqual(bar_invocation.result, 2)
@@ -366,7 +454,7 @@ class SessionTest(unittest.TestCase):
366
454
  self.assertFalse(session.has_stopped)
367
455
 
368
456
  session.start()
369
- result = foo(session, lm=lm, verbose=True)
457
+ result = foo(session, lm=lm)
370
458
  session.end(result)
371
459
 
372
460
  self.assertTrue(session.has_started)
@@ -386,7 +474,7 @@ class SessionTest(unittest.TestCase):
386
474
  session = action_lib.Session(id='agent@1')
387
475
  with self.assertRaisesRegex(ValueError, 'Bar error'):
388
476
  with session:
389
- foo(session, lm=lm, verbose=True)
477
+ foo(session, lm=lm)
390
478
  self.assertTrue(session.has_started)
391
479
  self.assertTrue(session.has_stopped)
392
480
  self.assertTrue(session.has_error)
@@ -399,7 +487,7 @@ class SessionTest(unittest.TestCase):
399
487
  foo = Foo(1, simulate_action_error=True)
400
488
  session = action_lib.Session(id='agent@1')
401
489
  with self.assertRaisesRegex(ValueError, 'Please call `Session.start'):
402
- foo(session, lm=lm, verbose=True)
490
+ foo(session, lm=lm)
403
491
 
404
492
  def test_succeed_with_multiple_actions(self):
405
493
  lm = fake.StaticResponse('lm response')
@@ -480,6 +568,58 @@ class SessionTest(unittest.TestCase):
480
568
  ):
481
569
  foo(lm=lm, max_execution_time=1.0)
482
570
 
571
+ def test_event_handler(self):
572
+
573
+ class MyActionHandler(pg.Object, action_lib.SessionEventHandler):
574
+ def _on_bound(self):
575
+ super()._on_bound()
576
+ self.progresses = []
577
+
578
+ def on_session_start(self, session):
579
+ session.add_metadata(progresses=pg.Ref(self.progresses))
580
+
581
+ def on_action_progress(self, session, action, title, **kwargs):
582
+ self.progresses.append((action.id, title))
583
+
584
+ handler = MyActionHandler()
585
+ self.assertIs(handler.get(MyActionHandler), handler)
586
+ self.assertIsNone(handler.get(action_lib.SessionLogging))
587
+
588
+ handler_chain = action_lib.SessionEventHandlerChain(
589
+ handlers=[handler, action_lib.SessionLogging()]
590
+ )
591
+ self.assertIs(handler_chain.get(MyActionHandler), handler)
592
+ self.assertIs(
593
+ handler_chain.get(action_lib.SessionLogging),
594
+ handler_chain.handlers[1]
595
+ )
596
+
597
+ session = action_lib.Session(
598
+ id='agent@1',
599
+ event_handler=handler_chain
600
+ )
601
+ bar = Bar()
602
+ with session:
603
+ bar(session, lm=fake.StaticResponse('lm response'))
604
+ session.update_progress('Trajectory completed')
605
+
606
+ self.assertIs(session.metadata['progresses'], handler.progresses)
607
+ self.assertEqual(handler.progresses, [
608
+ ('agent@1:/a1', 'Query completed'),
609
+ ('agent@1:', 'Trajectory completed'),
610
+ ])
611
+
612
+ def test_clone(self):
613
+ event_handler = action_lib.SessionLogging()
614
+ session = action_lib.Session(event_handler=event_handler)
615
+ other = session.clone()
616
+ self.assertIsNot(session, other)
617
+ self.assertIs(other.event_handler, event_handler)
618
+
619
+ other = session.clone(deep=True)
620
+ self.assertIsNot(session, other)
621
+ self.assertIsNot(other.event_handler, session.event_handler)
622
+
483
623
  def test_log(self):
484
624
  session = action_lib.Session()
485
625
  session.debug('hi', x=1, y=2)
@@ -493,6 +633,31 @@ class SessionTest(unittest.TestCase):
493
633
  self.assertIn('agent@', session.id)
494
634
  self.assertIsInstance(session.as_message(), lf.AIMessage)
495
635
 
636
+ def test_query_with_track_if(self):
637
+ lm = fake.StaticResponse('lm response')
638
+ session = action_lib.Session()
639
+
640
+ # Render session to trigger javascript updates to the HTML when
641
+ # operating on the session.
642
+ _ = session.to_html()
643
+ with session:
644
+ # This query will succeed.
645
+ session.query(
646
+ 'prompt1',
647
+ schema=None,
648
+ lm=lm,
649
+ track_if=lambda q: not q.has_error,
650
+ default=None)
651
+ # This query will fail during parsing.
652
+ session.query(
653
+ 'prompt2',
654
+ schema=int,
655
+ lm=lm,
656
+ track_if=lambda q: not q.has_error,
657
+ default=None)
658
+ self.assertEqual(len(session.root.queries), 1)
659
+ self.assertIsNone(session.root.queries[0].error)
660
+
496
661
 
497
662
  if __name__ == '__main__':
498
663
  unittest.main()
@@ -11,18 +11,117 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- """Utility for async IO in Langfun."""
14
+ """Utilities for asynchronous programming in Langfun."""
15
15
 
16
16
  import asyncio
17
- from typing import Any, Callable
17
+ import contextlib
18
+ from typing import Any, Awaitable, Callable, Iterator
19
+ import anyio
18
20
  import pyglove as pg
19
21
 
20
22
 
21
23
  async def invoke_async(
22
- callable_object: Callable[..., Any], *args, **kwargs
24
+ sync_callable: Callable[..., Any], *args, **kwargs
23
25
  ) -> Any:
24
- """Invokes a callable asynchronously with `lf.context` manager enabled."""
26
+ """Invokes a sync callable asynchronously in a separate thread.
27
+
28
+ This is useful for wrapping a sync function into an async function,
29
+ allowing multiple calls of the sync function to run concurrently.
30
+ `lf.context` will be propagated to the thread that runs the sync callable.
31
+
32
+ Args:
33
+ sync_callable: The sync callable to invoke.
34
+ *args: Positional arguments to pass to the callable.
35
+ **kwargs: Keyword arguments to pass to the callable.
36
+
37
+ Returns:
38
+ An awaitable that resolves to the return value of the sync_callable.
39
+ """
25
40
  return await asyncio.to_thread(
26
41
  # Enable `lf.context` manager for async calls.
27
- pg.with_contextual_override(callable_object), *args, **kwargs
42
+ pg.with_contextual_override(sync_callable), *args, **kwargs
28
43
  )
44
+
45
+
46
+ def invoke_sync(
47
+ async_callable: Callable[..., Awaitable[Any]],
48
+ *args,
49
+ **kwargs
50
+ ) -> Any:
51
+ """Invokes an async callable synchronously.
52
+
53
+ This is useful for calling an async function from a sync context.
54
+ If there is an existing async event loop in current thread managed by
55
+ `lf.sync_context_manager`, it will be used for running the async callable.
56
+ Otherwise, `anyio.run` will be used to run the async callable in a new
57
+ event loop.
58
+ `lf.context` will be propagated to the async callable.
59
+
60
+ Args:
61
+ async_callable: The async callable to invoke.
62
+ *args: Positional arguments to pass to the callable.
63
+ **kwargs: Keyword arguments to pass to the callable.
64
+
65
+ Returns:
66
+ The return value of the async_callable.
67
+ """
68
+ async def _invoke():
69
+ return await async_callable(*args, **kwargs)
70
+ invoke_fn = pg.with_contextual_override(_invoke)
71
+ blocking_portal = pg.utils.thread_local_get('__blocking_portal__', None)
72
+ if blocking_portal is None:
73
+ return anyio.run(invoke_fn)
74
+ return blocking_portal.call(invoke_fn)
75
+
76
+
77
+ @contextlib.contextmanager
78
+ def sync_context_manager(
79
+ async_context_manager: contextlib.AbstractAsyncContextManager[Any]
80
+ ) -> Iterator[Any]:
81
+ """Adapts an async context manager to a sync context manager.
82
+
83
+ sync_context_manager installs a blocking portal in current thread to run the
84
+ async context manager in a blocking way. It's useful for running async code in
85
+ sync context managers, e.g. `sync_context_manager` can be nested and share the
86
+ same event loop.
87
+
88
+ Example:
89
+
90
+ ```python
91
+ @contextlib.asynccontextmanager
92
+ async def foo(x):
93
+ try:
94
+ yield x
95
+ finally:
96
+ pass
97
+
98
+ with lf.sync_context_manager(foo(x)) as x
99
+ with lf.sync_context_manager(foo(y)) as y:
100
+ ...
101
+ ```
102
+
103
+ Args:
104
+ async_context_manager: The async context manager to adapt.
105
+
106
+ Yields:
107
+ The value yielded by the async context manager.
108
+ """
109
+ blocking_portal = pg.utils.thread_local_get('__blocking_portal__', None)
110
+ portal_exit_stack = None
111
+
112
+ try:
113
+ if blocking_portal is None:
114
+ portal_exit_stack = contextlib.ExitStack()
115
+ blocking_portal = portal_exit_stack.enter_context(
116
+ anyio.from_thread.start_blocking_portal()
117
+ )
118
+ pg.utils.thread_local_set('__blocking_portal__', blocking_portal)
119
+ context_manager = blocking_portal.wrap_async_context_manager(
120
+ async_context_manager
121
+ )
122
+ with context_manager as value:
123
+ yield value
124
+ finally:
125
+ if portal_exit_stack is not None:
126
+ portal_exit_stack.close()
127
+ pg.utils.thread_local_del('__blocking_portal__')
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import asyncio
16
+ import contextlib
16
17
  import time
17
18
  import unittest
18
19
 
@@ -34,6 +35,28 @@ class AsyncSupportTest(unittest.TestCase):
34
35
  with pg.contextual_override(z=3):
35
36
  self.assertEqual(asyncio.run(r), 6)
36
37
 
38
+ def test_invoke_sync(self):
39
+ @contextlib.asynccontextmanager
40
+ async def bar(x):
41
+ try:
42
+ yield x
43
+ finally:
44
+ pass
45
+
46
+ async def foo(x, *, y):
47
+ time.sleep(2)
48
+ return x + y + pg.contextual_value('z', 0)
49
+
50
+ with pg.contextual_override(z=3):
51
+ with async_support.sync_context_manager(bar(1)) as x:
52
+ self.assertEqual(x, 1)
53
+ with async_support.sync_context_manager(bar(2)) as y:
54
+ self.assertEqual(y, 2)
55
+ self.assertEqual(async_support.invoke_sync(foo, 1, y=2), 6)
56
+
57
+ with pg.contextual_override(z=2):
58
+ self.assertEqual(async_support.invoke_sync(foo, 1, y=2), 5)
59
+
37
60
 
38
61
  if __name__ == '__main__':
39
62
  unittest.main()
@@ -19,13 +19,23 @@ import pyglove as pg
19
19
 
20
20
 
21
21
  class CodeWithError(pg.Object):
22
- """Python code with error."""
22
+ """A structure representing Python code along with an execution error.
23
+
24
+ This is used as input to a language model for error correction, providing
25
+ the model with the code that failed and the error message it produced.
26
+ """
23
27
 
24
28
  code: str
25
29
  error: str
26
30
 
27
31
 
28
32
  class CorrectedCode(pg.Object):
33
+ """A structure containing corrected Python code.
34
+
35
+ This is used as the output schema when asking a language model to correct
36
+ code, expecting the model to return the fixed code in the `corrected_code`
37
+ field.
38
+ """
29
39
  corrected_code: str
30
40
 
31
41
 
@@ -49,7 +59,7 @@ def run_with_correction(
49
59
  code: The source code that may or may not be problematic.
50
60
  error: An optional initial error for `code` when it's problematic, usually
51
61
  caught from elsewhere when it ran. If None, code will be executed once to
52
- verify if its good and obtain a feedback error message.
62
+ verify if it's good and obtain a feedback error message.
53
63
  global_vars: A dict of str to value as the global variables that could be
54
64
  accessed within the corrected code.
55
65
  lm: Language model to be used. If not specified, it will try to use the `lm`
@@ -57,15 +67,15 @@ def run_with_correction(
57
67
  max_attempts: Max number of attempts for the correction.
58
68
  sandbox: If True, run code in sandbox; If False, run code in current
59
69
  process. If None, run in sandbox first, if the output could not be
60
- serialized and pass to current process, run the code again in current
70
+ serialized and passed to current process, run the code again in current
61
71
  process.
62
72
  permission: The permission to run the code.
63
73
  timeout: The timeout for running the corrected code. If None, there is no
64
74
  timeout. Applicable only when sandbox is set to True.
65
75
  returns_code: If True, the return value is a tuple of (result, final code).
66
76
  Otherwise the return value is the result only.
67
- returns_stdout: If True, the stdout (a str) will be returned.
68
- outputs_intermediate: If True, intermediate output will be outputted as a
77
+ returns_stdout: If True, the stdout (a string) will be returned.
78
+ outputs_intermediate: If True, intermediate output will be output as a
69
79
  dict, with the last line's value accessible by key '__result__'. Otherwise
70
80
  the value of the last line will be returned.
71
81
 
@@ -161,7 +171,7 @@ def correct(
161
171
  code: The source code that may or may not be problematic.
162
172
  error: An optional initial error for `code` when it's problematic, usually
163
173
  caught from elsewhere when it ran. If None, code will be executed once to
164
- verify if its good and obtain a feedback error message.
174
+ verify if it's good and obtain a feedback error message.
165
175
  global_vars: A dict of str to value as the global variables that could be
166
176
  accessed within the corrected code.
167
177
  lm: Language model to be used. If not specified, it will try to use the `lm`
@@ -169,7 +179,7 @@ def correct(
169
179
  max_attempts: Max number of attempts for the correction.
170
180
  sandbox: If True, run code in sandbox; If False, run code in current
171
181
  process. If None, run in sandbox first, if the output could not be
172
- serialized and pass to current process, run the code again in current
182
+ serialized and passed to current process, run the code again in current
173
183
  process.
174
184
  timeout: The timeout for running the corrected code. If None, there is no
175
185
  timeout. Applicable only when sandbox is set to True.
@@ -193,7 +203,7 @@ def correct(
193
203
 
194
204
 
195
205
  def _error_feedback_str(error: Exception) -> str:
196
- """Returns the error str for feedback."""
206
+ """Returns the error string for feedback."""
197
207
  if isinstance(error, pg.coding.CodeError):
198
208
  return pg.decolor(error.format(include_complete_code=False))
199
209
  else:
@@ -201,7 +211,7 @@ def _error_feedback_str(error: Exception) -> str:
201
211
 
202
212
 
203
213
  def _maybe_custom_validate(result: Any) -> Any:
204
- """Apply custom validation through __validate_generation__ method."""
214
+ """Applies custom validation through __validate__ method."""
205
215
  if isinstance(result, dict) and "__result__" in result:
206
216
  r = result["__result__"]
207
217
  else: