langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512040805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +2 -0
  29. langfun/core/eval/v2/checkpointing.py +76 -7
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +92 -17
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +84 -15
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +90 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +72 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +12 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +64 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +34 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +58 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +64 -3
  104. langfun/core/modalities/mime_test.py +11 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512040805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/top_level.txt +0 -0
@@ -14,8 +14,11 @@
14
14
  """Base classes for agentic actions."""
15
15
 
16
16
  import abc
17
+ import collections
17
18
  import contextlib
19
+ import dataclasses
18
20
  import functools
21
+ import itertools
19
22
  import threading
20
23
  import time
21
24
  import typing
@@ -35,7 +38,12 @@ class ActionTimeoutError(ActionError):
35
38
 
36
39
 
37
40
  class Action(pg.Object):
38
- """Base class for Langfun's agentic actions.
41
+ """Base class for agentic actions.
42
+
43
+ An `Action` represents a single, executable step or task that an agent can
44
+ perform, such as calling a tool, querying a language model, or returning a
45
+ final answer. Actions are designed to be composable and trackable within a
46
+ `Session`.
39
47
 
40
48
  # Developing Actions
41
49
 
@@ -148,7 +156,7 @@ class Action(pg.Object):
148
156
 
149
157
  # Explicitly create and pass a session.
150
158
  with lf.Session(id='my_agent_session') as session:
151
- result = calc(session=session) # Pass the session explicitly
159
+ result = calc(session=session) # Pass the session explicitly
152
160
  print(result)
153
161
  ```
154
162
 
@@ -187,11 +195,23 @@ class Action(pg.Object):
187
195
  self._session = None
188
196
  self._invocation: ActionInvocation | None = None
189
197
 
198
+ # NOTE(daiyip): Users could use `self._state` to keep track of the state
199
+ # during the execution of the action.
200
+ # Strictly speaking, action state would fit better in
201
+ # ActionInvocation, but we make it a property of Action as it's usually
202
+ # initialized before `__call__` is called.
203
+ self._state: Any = None
204
+
190
205
  @property
191
206
  def session(self) -> Optional['Session']:
192
207
  """Returns the session started by this action."""
193
208
  return self._session
194
209
 
210
+ @property
211
+ def state(self) -> Any:
212
+ """Returns the state of the action."""
213
+ return self._state
214
+
195
215
  @property
196
216
  def result(self) -> Any:
197
217
  """Returns the result of the action."""
@@ -295,19 +315,190 @@ class Action(pg.Object):
295
315
  """
296
316
 
297
317
 
318
+ #
319
+ # Execution tracking.
320
+ #
321
+
322
+
323
+ class ExecutionUnit(pg.Object):
324
+ """Base class for execution units in an agentic trajectory.
325
+
326
+ An `ExecutionUnit` represents a logical step or container in the agent's
327
+ execution flow. It serves as the common interface for top-level executable
328
+ items.
329
+
330
+ The concrete subclasses of `ExecutionUnit` are typically:
331
+ * **`ActionInvocation`**: Represents a single, specific action executed by
332
+ the agent.
333
+ * **`ParallelExecutions`**: Represents a container for a group of
334
+ `ExecutionUnits` that were executed concurrently.
335
+
336
+ Users can retrieve the immediate child execution units from an `ExecutionUnit`
337
+ object. To access the leaf nodes of the execution tree, use `all_actions`,
338
+ `all_queries`, and `all_logs` instead.
339
+
340
+ Each unit exposes a **`position`** property to reveal its specific location
341
+ within the execution hierarchy (e.g., '1.2.3').
342
+
343
+ Users could use **`parent_execution_unit`** to get the parent execution unit
344
+ of the current execution unit.
345
+ """
346
+
347
+ @dataclasses.dataclass
348
+ class Position:
349
+ """The position of an executed unit under current session."""
350
+
351
+ parent: Optional['ExecutionUnit.Position'] = None
352
+ index: int = 0
353
+
354
+ def indices(self) -> tuple[int, ...]:
355
+ """Returns the indices from root to current execution unit."""
356
+ # A deque is efficient for adding items to the front.
357
+ path = collections.deque()
358
+ current_pos = self
359
+
360
+ # Traverse up from the current position to the root.
361
+ while current_pos.parent is not None:
362
+ path.appendleft(current_pos.index)
363
+ current_pos = current_pos.parent
364
+
365
+ path.appendleft(current_pos.index)
366
+ return tuple(path)
367
+
368
+ def to_str(
369
+ self,
370
+ *,
371
+ index_base: int = 1,
372
+ separator: str = '.',
373
+ **kwargs
374
+ ) -> str:
375
+ """Returns a string description of the position."""
376
+ # For the root action, we return an empty string as its position descriptor.
377
+ if self.parent is None:
378
+ return ''
379
+ parent_descriptor = self.parent.to_str(
380
+ index_base=index_base,
381
+ separator=separator,
382
+ **kwargs
383
+ )
384
+ if not parent_descriptor:
385
+ return f'{self.index + index_base}'
386
+ return parent_descriptor + separator + f'{self.index + index_base}'
387
+
388
+ def __repr__(self) -> str:
389
+ return f'Position({", ".join(str(x) for x in self.indices())})'
390
+
391
+ def __str__(self) -> str:
392
+ return self.to_str()
393
+
394
+ def __eq__(self, other: 'ExecutionUnit.Position') -> bool:
395
+ if isinstance(other, ExecutionUnit.Position):
396
+ return self.indices() == other.indices()
397
+ if isinstance(other, tuple):
398
+ return self.indices() == other
399
+ if isinstance(other, str):
400
+ return str(self) == other
401
+ return False
402
+
403
+ def __ne__(self, other: 'ExecutionUnit.Position') -> bool:
404
+ return not self == other
405
+
406
+ def __hash__(self) -> int:
407
+ return hash(self.indices())
408
+
409
+ def __lt__(self, other: 'ExecutionUnit.Position') -> bool:
410
+ return self.indices() < other.indices()
411
+
412
+ def __gt__(self, other: 'ExecutionUnit.Position') -> bool:
413
+ return self.indices() > other.indices()
414
+
415
+ def _on_parent_change(self, *args, **kwargs):
416
+ super()._on_parent_change(*args, **kwargs)
417
+ self.__dict__.pop('parent_execution_unit', None)
418
+ self.__dict__.pop('position', None)
419
+
420
+ @functools.cached_property
421
+ def parent_execution_unit(self) -> Optional['ExecutionUnit']:
422
+ """Returns the parent execution unit of the current execution unit."""
423
+ parent_trace = self.sym_ancestor(lambda x: isinstance(x, ExecutionTrace))
424
+ assert isinstance(parent_trace, ExecutionTrace), (
425
+ 'Execution unit is not associated with any `ExecutionTrace`: '
426
+ f'{self}'
427
+ )
428
+ return parent_trace.parent_execution_unit
429
+
430
+ @functools.cached_property
431
+ def position(self) -> Position:
432
+ """Returns the execution position of the action."""
433
+ parent_trace = self.sym_ancestor(lambda x: isinstance(x, ExecutionTrace))
434
+ while parent_trace is not None:
435
+ parent_position = parent_trace.position
436
+ if parent_position is not None:
437
+ return ExecutionUnit.Position(
438
+ parent_position, parent_trace.indexof(self, ExecutionUnit)
439
+ )
440
+ parent_trace = parent_trace.sym_ancestor(
441
+ lambda x: isinstance(x, ExecutionTrace)
442
+ )
443
+ return ExecutionUnit.Position(None, 0)
444
+
445
+ @property
446
+ @abc.abstractmethod
447
+ def execution_units(
448
+ self,
449
+ ) -> list['ExecutionUnit']:
450
+ """Returns immediate child execution items."""
451
+
452
+ @property
453
+ @abc.abstractmethod
454
+ def queries(self) -> list[lf_structured.QueryInvocation]:
455
+ """Returns queries issued by the execution item."""
456
+
457
+ @property
458
+ @abc.abstractmethod
459
+ def actions(self) -> list['ActionInvocation']:
460
+ """Returns immediate child action invocations."""
461
+
462
+ @property
463
+ @abc.abstractmethod
464
+ def logs(self) -> list[lf.logging.LogEntry]:
465
+ """Returns immediate logs under current execution item."""
466
+
467
+ @property
468
+ @abc.abstractmethod
469
+ def all_queries(self) -> list[lf_structured.QueryInvocation]:
470
+ """Returns all queries from the subtree."""
471
+
472
+ @property
473
+ @abc.abstractmethod
474
+ def all_actions(self) -> list['ActionInvocation']:
475
+ """Returns all action invocations from the subtree."""
476
+
477
+ @property
478
+ @abc.abstractmethod
479
+ def all_logs(self) -> list[lf.logging.LogEntry]:
480
+ """Returns all logs from the subtree."""
481
+
482
+
298
483
  # Type definition for traced item during execution.
299
484
  TracedItem = Union[
485
+ ExecutionUnit,
300
486
  lf_structured.QueryInvocation,
301
- 'ActionInvocation',
302
487
  'ExecutionTrace',
303
- 'ParallelExecutions',
304
488
  # NOTE(daiyip): Consider removing log entry once we migrate existing agents.
305
489
  lf.logging.LogEntry,
306
490
  ]
307
491
 
308
492
 
309
493
  class ExecutionTrace(pg.Object, pg.views.html.HtmlTreeView.Extension):
310
- """Trace of the execution of an action."""
494
+ """Trace of an execution, containing queries, logs, and sub-actions.
495
+
496
+ `ExecutionTrace` records the sequence of operations performed during an
497
+ action's execution or within a specific phase of execution (demarcated by
498
+ `session.track_phase`). It captures `lf.query` calls, log entries, and
499
+ nested `ActionInvocation` objects in the order they occurred. It also
500
+ aggregates LLM usage summaries from its child items.
501
+ """
311
502
 
312
503
  name: Annotated[
313
504
  str | None,
@@ -315,7 +506,7 @@ class ExecutionTrace(pg.Object, pg.views.html.HtmlTreeView.Extension):
315
506
  'The name of the execution trace. If None, the trace is unnamed, '
316
507
  'which is the case for the top-level trace of an action. An '
317
508
  'execution trace could have sub-traces, called phases, which are '
318
- 'created and named by `session.phase()` context manager.'
509
+ 'created and named by `session.track_phase()` context manager.'
319
510
  )
320
511
  ] = None
321
512
 
@@ -347,9 +538,15 @@ class ExecutionTrace(pg.Object, pg.views.html.HtmlTreeView.Extension):
347
538
  def _on_parent_change(self, *args, **kwargs):
348
539
  super()._on_parent_change(*args, **kwargs)
349
540
  self.__dict__.pop('id', None)
541
+ self.__dict__.pop('parent_execution_unit', None)
542
+ self.__dict__.pop('position', None)
350
543
 
351
- def indexof(self, item: TracedItem, count_item_cls: Type[Any]) -> int:
352
- """Returns the index of the child items of given type."""
544
+ def indexof(
545
+ self,
546
+ item: TracedItem,
547
+ count_item_cls: Type[Any] | tuple[Type[Any], ...]
548
+ ) -> int:
549
+ """Returns the index of the child item of given type."""
353
550
  pos = 0
354
551
  for x in self._iter_children(count_item_cls):
355
552
  if x is item:
@@ -357,6 +554,52 @@ class ExecutionTrace(pg.Object, pg.views.html.HtmlTreeView.Extension):
357
554
  pos += 1
358
555
  return -1
359
556
 
557
+ @functools.cached_property
558
+ def parent_execution_unit(self) -> ExecutionUnit:
559
+ """Returns the parent execution unit of the current execution trace."""
560
+ parent = self.sym_parent
561
+ if isinstance(parent, ActionInvocation):
562
+ # Current execution trace is the body of an action.
563
+ return parent
564
+ elif isinstance(parent, pg.List):
565
+ container = parent.sym_parent
566
+ if isinstance(container, ParallelExecutions):
567
+ return container
568
+ elif isinstance(container, ExecutionTrace):
569
+ return container.parent_execution_unit
570
+ assert False, (
571
+ 'Execution trace is not associated with any `ActionInvocation` or '
572
+ f'`ParallelExecutions`: {self}'
573
+ )
574
+
575
+ @functools.cached_property
576
+ def position(self) -> ExecutionUnit.Position | None:
577
+ """Returns the execution position of the execution trace.
578
+
579
+ Returns:
580
+ The execution position of the execution trace, or None if the execution
581
+ trace is either not associated with an execution unit or the execution
582
+ trace is a phase under another execution trace.
583
+ """
584
+ parent = self.sym_parent
585
+ if isinstance(parent, ActionInvocation):
586
+ # Current execution trace is the body of an action.
587
+ return parent.position
588
+ elif isinstance(parent, pg.List):
589
+ container = parent.sym_parent
590
+ if isinstance(container, ParallelExecutions):
591
+ return ExecutionUnit.Position(
592
+ container.position, self.sym_path.key
593
+ )
594
+ elif isinstance(container, ExecutionTrace):
595
+ # When execution trace is a phase under another execution trace,
596
+ # we return None as the position.
597
+ return None
598
+ assert False, (
599
+ 'Execution trace is not associated with any `ActionInvocation` or '
600
+ f'`ParallelExecutions`: {self}'
601
+ )
602
+
360
603
  @functools.cached_property
361
604
  def id(self) -> str:
362
605
  parent = self.sym_parent
@@ -428,6 +671,11 @@ class ExecutionTrace(pg.Object, pg.views.html.HtmlTreeView.Extension):
428
671
  """Returns action invocations from the sequence."""
429
672
  return list(self._iter_children(ActionInvocation))
430
673
 
674
+ @property
675
+ def execution_units(self) -> list[ExecutionUnit]:
676
+ """Returns immediate child execution units from the sequence."""
677
+ return list(self._iter_children(ExecutionUnit))
678
+
431
679
  @property
432
680
  def logs(self) -> list[lf.logging.LogEntry]:
433
681
  """Returns logs from the sequence."""
@@ -448,7 +696,9 @@ class ExecutionTrace(pg.Object, pg.views.html.HtmlTreeView.Extension):
448
696
  """Returns all logs from current trace and its child execution items."""
449
697
  return list(self._iter_subtree(lf.logging.LogEntry))
450
698
 
451
- def _iter_children(self, item_cls: Type[Any]) -> Iterator[TracedItem]:
699
+ def _iter_children(
700
+ self, item_cls: Type[Any] | tuple[Type[Any], ...]
701
+ ) -> Iterator[TracedItem]:
452
702
  for item in self.items:
453
703
  if isinstance(item, item_cls):
454
704
  yield item
@@ -460,7 +710,10 @@ class ExecutionTrace(pg.Object, pg.views.html.HtmlTreeView.Extension):
460
710
  for x in branch._iter_children(item_cls): # pylint: disable=protected-access
461
711
  yield x
462
712
 
463
- def _iter_subtree(self, item_cls: Type[Any]) -> Iterator[TracedItem]:
713
+ def _iter_subtree(
714
+ self,
715
+ item_cls: Type[Any] | tuple[Type[Any], ...]
716
+ ) -> Iterator[TracedItem]:
464
717
  for item in self.items:
465
718
  if isinstance(item, item_cls):
466
719
  yield item
@@ -525,6 +778,18 @@ class ExecutionTrace(pg.Object, pg.views.html.HtmlTreeView.Extension):
525
778
  remove_class=['not-started'],
526
779
  )
527
780
 
781
+ def remove(self, item: TracedItem) -> None:
782
+ """Removes an item from the sequence."""
783
+ index = self.items.index(item)
784
+ if index == -1:
785
+ raise ValueError(f'Item not found in execution trace: {item!r}')
786
+
787
+ with pg.notify_on_change(False):
788
+ self.items.pop(index)
789
+
790
+ if self._tab_control is not None:
791
+ self._tab_control.remove(index)
792
+
528
793
  def extend(self, items: Iterable[TracedItem]) -> None:
529
794
  """Extends the sequence with a list of items."""
530
795
  for item in items:
@@ -761,8 +1026,13 @@ class ExecutionTrace(pg.Object, pg.views.html.HtmlTreeView.Extension):
761
1026
  ]
762
1027
 
763
1028
 
764
- class ParallelExecutions(pg.Object, pg.views.html.HtmlTreeView.Extension):
765
- """A class for encapsulating parallel execution traces."""
1029
+ class ParallelExecutions(ExecutionUnit, pg.views.html.HtmlTreeView.Extension):
1030
+ """A container for multiple parallel execution traces.
1031
+
1032
+ When `session.concurrent_map` is used, it creates a `ParallelExecutions`
1033
+ object to hold an `ExecutionTrace` for each parallel branch of execution,
1034
+ allowing inspection of parallel workflows.
1035
+ """
766
1036
 
767
1037
  name: Annotated[
768
1038
  str | None,
@@ -808,8 +1078,64 @@ class ParallelExecutions(pg.Object, pg.views.html.HtmlTreeView.Extension):
808
1078
  self.branches.append(branch)
809
1079
  if self._tab_control is not None:
810
1080
  self._tab_control.append(self._branch_tab(branch))
1081
+
1082
+ # Invalidate cached properties.
1083
+ self.__dict__.pop('all_queries', None)
1084
+ self.__dict__.pop('all_actions', None)
1085
+ self.__dict__.pop('all_logs', None)
811
1086
  return branch
812
1087
 
1088
+ #
1089
+ # ExecutionUnit interface.
1090
+ #
1091
+
1092
+ @property
1093
+ def execution_units(self) -> list[ExecutionUnit]:
1094
+ """Returns immediate child execution items from execution sequence."""
1095
+ return []
1096
+
1097
+ @property
1098
+ def queries(self) -> list[lf_structured.QueryInvocation]:
1099
+ """Returns immediate queries made by the parallel execution."""
1100
+ return []
1101
+
1102
+ @property
1103
+ def actions(self) -> list['ActionInvocation']:
1104
+ """Returns immediate child action invocations."""
1105
+ return []
1106
+
1107
+ @property
1108
+ def logs(self) -> list[lf.logging.LogEntry]:
1109
+ """Returns immediate child logs from execution sequence."""
1110
+ return []
1111
+
1112
+ @functools.cached_property
1113
+ def all_queries(self) -> list[lf_structured.QueryInvocation]:
1114
+ """Returns all queries made by the action and its child execution items."""
1115
+ return list(
1116
+ itertools.chain.from_iterable(
1117
+ branch.all_queries for branch in self.branches
1118
+ )
1119
+ )
1120
+
1121
+ @functools.cached_property
1122
+ def all_actions(self) -> list['ActionInvocation']:
1123
+ """Returns all actions made by the action and its child execution items."""
1124
+ return list(
1125
+ itertools.chain.from_iterable(
1126
+ branch.all_actions for branch in self.branches
1127
+ )
1128
+ )
1129
+
1130
+ @property
1131
+ def all_logs(self) -> list[lf.logging.LogEntry]:
1132
+ """Returns all logs made by the action and its child execution items."""
1133
+ return list(
1134
+ itertools.chain.from_iterable(
1135
+ branch.all_logs for branch in self.branches
1136
+ )
1137
+ )
1138
+
813
1139
  #
814
1140
  # HTML views.
815
1141
  #
@@ -850,9 +1176,20 @@ class ParallelExecutions(pg.Object, pg.views.html.HtmlTreeView.Extension):
850
1176
  )
851
1177
 
852
1178
 
853
- class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
854
- """A class for capturing the invocation of an action."""
855
- action: Action
1179
+ class ActionInvocation(ExecutionUnit, pg.views.html.HtmlTreeView.Extension):
1180
+ """An invocation of an action, capturing its execution and result.
1181
+
1182
+ `ActionInvocation` represents a single call to an `Action`. It contains
1183
+ the `Action` object itself, its result or error, associated metadata,
1184
+ and an `ExecutionTrace` detailing the steps taken during its execution
1185
+ (queries, logs, sub-actions). Invocations form a tree structure within a
1186
+ `Session`, reflecting the hierarchy of agentic operations.
1187
+ """
1188
+
1189
+ action: Annotated[
1190
+ Action,
1191
+ 'The action being invoked.'
1192
+ ]
856
1193
 
857
1194
  result: Annotated[
858
1195
  Any,
@@ -900,6 +1237,7 @@ class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
900
1237
  def _on_parent_change(self, *args, **kwargs):
901
1238
  super()._on_parent_change(*args, **kwargs)
902
1239
  self.__dict__.pop('id', None)
1240
+ self.__dict__.pop('position', None)
903
1241
 
904
1242
  @property
905
1243
  def parent_action(self) -> Optional['ActionInvocation']:
@@ -931,6 +1269,20 @@ class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
931
1269
  """Returns True if the action invocation has an error."""
932
1270
  return self.error is not None
933
1271
 
1272
+ @property
1273
+ def state(self) -> Any:
1274
+ """Returns the state of the action."""
1275
+ return self.action.state
1276
+
1277
+ #
1278
+ # Implement `ExecutionUnit` interface.
1279
+ #
1280
+
1281
+ @property
1282
+ def execution_units(self) -> list[ExecutionUnit]:
1283
+ """Returns immediate child execution items from execution sequence."""
1284
+ return self.execution.execution_units
1285
+
934
1286
  @property
935
1287
  def logs(self) -> list[lf.logging.LogEntry]:
936
1288
  """Returns immediate child logs from execution sequence."""
@@ -982,10 +1334,12 @@ class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
982
1334
  metadata: dict[str, Any] | None = None,
983
1335
  ) -> None:
984
1336
  """Ends the execution of the action with result and metadata."""
985
- rebind_dict = dict(result=result, error=error)
986
- if metadata is not None:
987
- rebind_dict['metadata'] = metadata
988
- self.rebind(**rebind_dict, skip_notification=True, raise_on_no_change=False)
1337
+ with pg.notify_on_change(False):
1338
+ self.result = result
1339
+ self.error = error
1340
+ if metadata:
1341
+ self.metadata.update(metadata)
1342
+
989
1343
  self.execution.stop()
990
1344
  if self._tab_control is not None:
991
1345
  if self.metadata:
@@ -1140,31 +1494,346 @@ class RootAction(Action):
1140
1494
  raise NotImplementedError('Shall not be called.')
1141
1495
 
1142
1496
 
1497
class SessionEventHandler:
  """Interface for handling session events.

  Every `on_*` hook is a no-op here, so subclasses only need to override the
  events they care about.
  """

  def get(
      self,
      session_cls: type['SessionEventHandler']
  ) -> Optional['SessionEventHandler']:
    """Returns this or a child event handler for the given handler class.

    Args:
      session_cls: The event handler class to look for. Despite the name, this
        is matched against event handlers, not against sessions.

    Returns:
      This handler if it is an instance of `session_cls`, otherwise None.
      Composite handlers override this method to also search their children.
    """
    return self if isinstance(self, session_cls) else None

  def on_session_start(
      self,
      session: 'Session'
  ) -> None:
    """Called when a session starts."""

  def on_session_end(
      self,
      session: 'Session'
  ) -> None:
    """Called when a session ends."""

  def on_action_start(
      self,
      session: 'Session',
      action: 'ActionInvocation'
  ) -> None:
    """Called when an action starts."""

  def on_action_end(
      self,
      session: 'Session',
      action: 'ActionInvocation'
  ) -> None:
    """Called when an action ends."""

  def on_action_progress(
      self,
      session: 'Session',
      action: 'ActionInvocation',
      title: str,
      **kwargs
  ) -> None:
    """Called when an action progress is updated."""

  def on_query_start(
      self,
      session: 'Session',
      action: 'ActionInvocation',
      query: 'lf_structured.QueryInvocation',
  ) -> None:
    """Called when a query starts."""

  def on_query_end(
      self,
      session: 'Session',
      action: 'ActionInvocation',
      query: 'lf_structured.QueryInvocation',
  ) -> None:
    """Called when a query ends."""


@dataclasses.dataclass
class SessionEventHandlerChain(SessionEventHandler):
  """A session event handler that chains multiple event handlers.

  Each event is forwarded to every handler in `handlers`, in list order.
  """

  handlers: list[SessionEventHandler]

  def get(
      self,
      session_cls: type['SessionEventHandler']
  ) -> Optional['SessionEventHandler']:
    """Returns the chain itself or the first child matching the handler class.

    Args:
      session_cls: The event handler class to look for.

    Returns:
      The chain if it is an instance of `session_cls`; otherwise the first
      match found by asking each child in order (nested chains are searched
      through their own `get`); None if nothing matches.
    """
    if isinstance(self, session_cls):
      return self
    for handler in self.handlers:
      found = handler.get(session_cls)
      # Compare against None explicitly: a matching handler may be falsy
      # (e.g. it defines `__bool__` or `__len__`) and must still be returned.
      if found is not None:
        return found
    return None

  def on_session_start(self, session: 'Session') -> None:
    """Called when a session starts."""
    for handler in self.handlers:
      handler.on_session_start(session)

  def on_session_end(self, session: 'Session') -> None:
    """Called when a session ends."""
    for handler in self.handlers:
      handler.on_session_end(session)

  def on_action_start(
      self,
      session: 'Session',
      action: 'ActionInvocation') -> None:
    """Called when an action starts."""
    for handler in self.handlers:
      handler.on_action_start(session, action)

  def on_action_end(
      self,
      session: 'Session',
      action: 'ActionInvocation') -> None:
    """Called when an action ends."""
    for handler in self.handlers:
      handler.on_action_end(session, action)

  def on_action_progress(
      self,
      session: 'Session',
      action: 'ActionInvocation',
      title: str,
      **kwargs
  ) -> None:
    """Called when an action progress is updated."""
    for handler in self.handlers:
      handler.on_action_progress(session, action, title, **kwargs)

  def on_query_start(
      self,
      session: 'Session',
      action: 'ActionInvocation',
      query: 'lf_structured.QueryInvocation',
  ) -> None:
    """Called when a query starts."""
    for handler in self.handlers:
      handler.on_query_start(session, action, query)

  def on_query_end(
      self,
      session: 'Session',
      action: 'ActionInvocation',
      query: 'lf_structured.QueryInvocation',
  ) -> None:
    """Called when a query ends."""
    for handler in self.handlers:
      handler.on_query_end(session, action, query)
1627
+
1628
+
1629
@dataclasses.dataclass
class SessionLogging(SessionEventHandler):
  """An event handler that logs Session events.

  Failures are always logged with `keep=True`. Start and success events are
  logged with `keep=False`, and only when `verbose` is True.
  """

  verbose: bool = False

  @staticmethod
  def _output_type(query):
    """Returns the annotation of the query's schema spec, or None if no schema."""
    if query.schema is None:
      return None
    return lf_structured.annotation(query.schema.spec)

  def on_session_end(self, session: 'Session'):
    """Logs the outcome of the whole trajectory."""
    failed = session.has_error
    if not (failed or self.verbose):
      return
    duration = f'{session.elapse:.2f}'
    if failed:
      session.error(
          f'Trajectory failed in {duration} seconds.',
          error=session.final_error,
          metadata=session.root.metadata,
          keep=True,
      )
    else:
      session.info(
          f'Trajectory succeeded in {duration} seconds.',
          result=session.final_result,
          metadata=session.root.metadata,
          keep=False,
      )

  def on_action_start(
      self,
      session: 'Session',
      action: ActionInvocation
  ) -> None:
    """Logs the start of an action (verbose mode only)."""
    if not self.verbose:
      return
    session.info(
        'Action execution started.',
        action=action.action,
        keep=False,
    )

  def on_action_end(
      self,
      session: 'Session',
      action: ActionInvocation
  ) -> None:
    """Logs a failed action, or a succeeded one in verbose mode."""
    failed = action.has_error
    if not (failed or self.verbose):
      return
    duration = f'{action.execution.elapse:.2f}'
    if failed:
      session.warning(
          f'Action execution failed in {duration} seconds.',
          action=action.action,
          error=action.error,
          keep=True,
      )
    else:
      session.info(
          f'Action execution succeeded in {duration} seconds.',
          action=action.action,
          result=action.result,
          keep=False,
      )

  def on_query_start(
      self,
      session: 'Session',
      action: ActionInvocation,
      query: lf_structured.QueryInvocation,
  ) -> None:
    """Logs the start of an LLM query (verbose mode only)."""
    if not self.verbose:
      return
    session.info(
        'Querying LLM started.',
        lm=query.lm.model_id,
        output_type=self._output_type(query),
        keep=False,
    )

  def on_query_end(
      self,
      session: 'Session',
      action: ActionInvocation,
      query: lf_structured.QueryInvocation,
  ) -> None:
    """Logs a failed LLM query, or a succeeded one in verbose mode."""
    failed = query.has_error
    if not (failed or self.verbose):
      return
    # Elapsed time is measured from the query's recorded start to now.
    duration = f'{time.time() - query.start_time:.2f}'
    model_id = query.lm.model_id
    output_type = self._output_type(query)
    if failed:
      session.warning(
          f'Querying LLM failed in {duration} seconds.',
          lm=model_id,
          output_type=output_type,
          error=query.error,
          keep=True,
      )
    else:
      session.info(
          f'Querying LLM succeeded in {duration} seconds.',
          lm=model_id,
          output_type=output_type,
          keep=False,
      )
1739
+
1740
+
1143
1741
  class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
1144
- """Session for performing an agentic task."""
1742
+ """Manages the execution trajectory of agentic actions.
1743
+
1744
+ A `Session` tracks the execution of a root `Action` and all its
1745
+ sub-actions, including LLM queries (`lf.query`), logging messages,
1746
+ and nested actions. It provides a complete, hierarchical trace of an
1747
+ agent's workflow, which is important for debugging, analysis, and
1748
+ visualization.
1749
+
1750
+ Sessions can be created implicitly when an action is called without an
1751
+ active session, or explicitly for more control.
1752
+
1753
+ **1. Implicit Session:**
1754
+ When an action is called without a session, Langfun creates one automatically.
1755
+
1756
+ ```python
1757
+ action = MyAction()
1758
+ action()
1759
+ session = action.session # Access the implicit session
1760
+ ```
1761
+
1762
+ **2. Explicit Session:**
1763
+ Use a `with` statement to manage a session explicitly. This is useful for
1764
+ setting session IDs or capturing the trajectory of multiple top-level actions.
1765
+
1766
+ ```python
1767
+ with lf.Session(id='my-session') as session:
1768
+ action1()
1769
+ action2()
1770
+ ```
1771
+
1772
+ **3. Accessing Trajectory:**
1773
+ The `session.root` attribute provides access to the `ActionInvocation` tree.
1774
+
1775
+ ```python
1776
+ with lf.Session() as session:
1777
+ my_action()
1778
+
1779
+ # Get all queries in the session
1780
+ print(session.all_queries)
1781
+
1782
+ # Get all top-level action calls in the session
1783
+ print(session.root.actions)
1784
+ ```
1785
+ """
1145
1786
 
1146
1787
  root: Annotated[
1147
1788
  ActionInvocation,
1148
1789
  'The root action invocation of the session.'
1149
- ] = ActionInvocation(RootAction())
1790
+ ]
1150
1791
 
1151
1792
  id: Annotated[
1152
1793
  str | None,
1153
- 'An optional identifier for the sessin, which will be used for logging.'
1154
- ] = None
1794
+ 'An optional identifier for the session, which will be used for logging.'
1795
+ ]
1155
1796
 
1156
- verbose: Annotated[
1157
- bool,
1158
- (
1159
- 'If True, the session will be logged with verbose action and query '
1160
- 'activities.'
1161
- )
1162
- ] = False
1797
+ @pg.explicit_method_override
1798
+ def __init__(
1799
+ self,
1800
+ id: str | None = None, # pylint: disable=redefined-builtin
1801
+ *,
1802
+ verbose: bool = False,
1803
+ event_handler: SessionEventHandler | None = None,
1804
+ root: ActionInvocation | None = None,
1805
+ **kwargs
1806
+ ):
1807
+ super().__init__(
1808
+ id=id,
1809
+ root=root or ActionInvocation(RootAction()),
1810
+ **kwargs
1811
+ )
1812
+ self._event_handler = event_handler or SessionLogging(verbose=verbose)
1813
+
1814
+ @property
1815
+ def event_handler(self) -> SessionEventHandler:
1816
+ """Returns the event handler for the session."""
1817
+ return self._event_handler
1818
+
1819
+ def _sym_clone(self, deep: bool, memo: Any = None) -> 'Session':
1820
+ other = super()._sym_clone(deep=deep, memo=memo)
1821
+ if deep:
1822
+ event_handler = pg.clone(self.event_handler, deep=deep, memo=memo)
1823
+ else:
1824
+ event_handler = self.event_handler
1825
+ other._event_handler = event_handler # pylint: disable=protected-access
1826
+ return other
1163
1827
 
1164
1828
  #
1165
1829
  # Shortcut methods for accessing the root action invocation.
1166
1830
  #
1167
1831
 
1832
+ @property
1833
+ def metadata(self) -> dict[str, Any]:
1834
+ """Returns metadata associated with the root of the session."""
1835
+ return self.root.metadata
1836
+
1168
1837
  @property
1169
1838
  def all_queries(self) -> list[lf_structured.QueryInvocation]:
1170
1839
  """Returns all queries made by the session."""
@@ -1250,6 +1919,7 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
1250
1919
  def start(self) -> None:
1251
1920
  """Starts the session."""
1252
1921
  self.root.execution.start()
1922
+ self.event_handler.on_session_start(self)
1253
1923
 
1254
1924
  def end(
1255
1925
  self,
@@ -1258,21 +1928,8 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
1258
1928
  metadata: dict[str, Any] | None = None,
1259
1929
  ) -> None:
1260
1930
  """Ends the session."""
1261
- if error is not None:
1262
- self.error(
1263
- f'Trajectory failed in {self.elapse:.2f} seconds.',
1264
- error=error,
1265
- metadata=metadata,
1266
- keep=True,
1267
- )
1268
- elif self.verbose:
1269
- self.info(
1270
- f'Trajectory succeeded in {self.elapse:.2f} seconds.',
1271
- result=result,
1272
- metadata=metadata,
1273
- keep=False,
1274
- )
1275
1931
  self.root.end(result, error, metadata)
1932
+ self.event_handler.on_session_end(self)
1276
1933
 
1277
1934
  def check_execution_time(self) -> None:
1278
1935
  """Checks the execution time of the current action."""
@@ -1291,6 +1948,20 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
1291
1948
  'seconds.'
1292
1949
  )
1293
1950
 
1951
+ def update_progress(self, title: str, **kwargs: Any) -> None:
1952
+ """Updates the progress of current action's execution.
1953
+
1954
+ Args:
1955
+ title: The title of the progress update.
1956
+ **kwargs: Additional keyword arguments to pass to the event handler.
1957
+ """
1958
+ self.event_handler.on_action_progress(
1959
+ self,
1960
+ self._current_action,
1961
+ title,
1962
+ **kwargs
1963
+ )
1964
+
1294
1965
  def __enter__(self):
1295
1966
  """Enters the session."""
1296
1967
  self.start()
@@ -1350,34 +2021,10 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
1350
2021
  self._current_execution = invocation.execution
1351
2022
  # Start the execution of the current action.
1352
2023
  self._current_action.start()
1353
- if self.verbose:
1354
- self.info(
1355
- 'Action execution started.',
1356
- action=invocation.action,
1357
- keep=False,
1358
- )
2024
+ self.event_handler.on_action_start(self, self._current_action)
1359
2025
  yield invocation
1360
2026
  finally:
1361
- if invocation.has_error:
1362
- self.warning(
1363
- (
1364
- f'Action execution failed in '
1365
- f'{invocation.execution.elapse:.2f} seconds.'
1366
- ),
1367
- action=invocation.action,
1368
- error=invocation.error,
1369
- keep=True,
1370
- )
1371
- elif self.verbose:
1372
- self.info(
1373
- (
1374
- f'Action execution succeeded in '
1375
- f'{invocation.execution.elapse:.2f} seconds.'
1376
- ),
1377
- action=invocation.action,
1378
- result=invocation.result,
1379
- keep=False,
1380
- )
2027
+ self.event_handler.on_action_end(self, self._current_action)
1381
2028
  self._current_execution = parent_execution
1382
2029
  self._current_action = parent_action
1383
2030
 
@@ -1403,13 +2050,20 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
1403
2050
  @contextlib.contextmanager
1404
2051
  def track_queries(
1405
2052
  self,
1406
- phase: str | None = None
2053
+ phase: str | None = None,
2054
+ track_if: Callable[
2055
+ [lf_structured.QueryInvocation],
2056
+ bool
2057
+ ] | None = None,
1407
2058
  ) -> Iterator[list[lf_structured.QueryInvocation]]:
1408
2059
  """Tracks `lf.query` made within the context.
1409
2060
 
1410
2061
  Args:
1411
2062
  phase: The name of a new phase to track the queries in. If not provided,
1412
2063
  the queries will be tracked in the parent phase.
2064
+ track_if: A function that takes a `lf_structured.QueryInvocation` and
2065
+ returns True if the query should be included in the result. If None,
2066
+ all queries (including failed queries) will be included.
1413
2067
 
1414
2068
  Yields:
1415
2069
  A list of `lf.QueryInvocation` objects, each for a single `lf.query`
@@ -1425,51 +2079,21 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
1425
2079
  skip_notification=False, raise_on_no_change=False
1426
2080
  )
1427
2081
  execution.append(invocation)
1428
- if self.verbose:
1429
- self.info(
1430
- 'Querying LLM started.',
1431
- lm=invocation.lm.model_id,
1432
- output_type=(
1433
- lf_structured.annotation(invocation.schema.spec)
1434
- if invocation.schema is not None else None
1435
- ),
1436
- keep=False,
1437
- )
2082
+ self.event_handler.on_query_start(self, self._current_action, invocation)
1438
2083
 
1439
2084
  def _query_end(invocation: lf_structured.QueryInvocation):
2085
+ if track_if is not None and not track_if(invocation):
2086
+ self._current_execution.remove(invocation)
2087
+ # Even if the query is not included in the execution trace, we still
2088
+ # count the usage summary to the current execution and trigger the
2089
+ # event handler to log the query.
1440
2090
  self._current_execution.merge_usage_summary(invocation.usage_summary)
1441
- if invocation.has_error:
1442
- self.warning(
1443
- (
1444
- f'Querying LLM failed in '
1445
- f'{time.time() - invocation.start_time:.2f} seconds.'
1446
- ),
1447
- lm=invocation.lm.model_id,
1448
- output_type=(
1449
- lf_structured.annotation(invocation.schema.spec)
1450
- if invocation.schema is not None else None
1451
- ),
1452
- error=invocation.error,
1453
- keep=True,
1454
- )
1455
- elif self.verbose:
1456
- self.info(
1457
- (
1458
- f'Querying LLM succeeded in '
1459
- f'{time.time() - invocation.start_time:.2f} seconds.'
1460
- ),
1461
- lm=invocation.lm.model_id,
1462
- output_type=(
1463
- lf_structured.annotation(invocation.schema.spec)
1464
- if invocation.schema is not None else None
1465
- ),
1466
- keep=False,
1467
- )
2091
+ self.event_handler.on_query_end(self, self._current_action, invocation)
1468
2092
 
1469
2093
  with self.track_phase(phase), lf_structured.track_queries(
1470
2094
  include_child_scopes=False,
1471
- start_callabck=_query_start,
1472
- end_callabck=_query_end,
2095
+ start_callback=_query_start,
2096
+ end_callback=_query_end,
1473
2097
  ) as queries:
1474
2098
  try:
1475
2099
  yield queries
@@ -1495,8 +2119,9 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
1495
2119
  *,
1496
2120
  lm: lf.LanguageModel,
1497
2121
  examples: list[lf_structured.MappingExample] | None = None,
2122
+ track_if: Callable[[lf_structured.QueryInvocation], bool] | None = None,
1498
2123
  **kwargs
1499
- ) -> Any:
2124
+ ) -> Any:
1500
2125
  """Calls `lf.query` and associates it with the current invocation.
1501
2126
 
1502
2127
  The following code are equivalent:
@@ -1521,12 +2146,15 @@ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
1521
2146
  default: The default value to return if the query fails.
1522
2147
  lm: The language model to use for the query.
1523
2148
  examples: The examples to use for the query.
2149
+ track_if: A function that takes a `lf_structured.QueryInvocation`
2150
+ and returns True if the query should be tracked.
2151
+ If None, all queries (including failed queries) will be tracked.
1524
2152
  **kwargs: Additional keyword arguments to pass to `lf.query`.
1525
2153
 
1526
2154
  Returns:
1527
2155
  The result of the query.
1528
2156
  """
1529
- with self.track_queries():
2157
+ with self.track_queries(track_if=track_if):
1530
2158
  return lf_structured.query(
1531
2159
  prompt,
1532
2160
  schema=schema,