langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501150804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. langfun/__init__.py +20 -2
  2. langfun/core/__init__.py +16 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -21
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +63 -2
  18. langfun/core/component_test.py +53 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +17 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +16 -1
  24. langfun/core/eval/base.py +622 -174
  25. langfun/core/eval/base_test.py +200 -54
  26. langfun/core/eval/matching.py +63 -76
  27. langfun/core/eval/matching_test.py +17 -8
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +26 -26
  31. langfun/core/eval/scoring_test.py +19 -2
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +4 -17
  55. langfun/core/langfunc_test.py +22 -6
  56. langfun/core/language_model.py +577 -39
  57. langfun/core/language_model_test.py +470 -56
  58. langfun/core/llms/__init__.py +87 -16
  59. langfun/core/llms/anthropic.py +312 -87
  60. langfun/core/llms/anthropic_test.py +71 -3
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +53 -2
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +11 -7
  69. langfun/core/llms/fake_test.py +14 -0
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -202
  74. langfun/core/llms/groq.py +160 -144
  75. langfun/core/llms/groq_test.py +31 -137
  76. langfun/core/llms/llama_cpp.py +15 -42
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +395 -203
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +30 -395
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -26
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +12 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +7 -6
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +60 -27
  112. langfun/core/structured/function_generation_test.py +72 -2
  113. langfun/core/structured/mapping.py +97 -47
  114. langfun/core/structured/mapping_test.py +90 -2
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +53 -9
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
  119. langfun/core/structured/schema.py +204 -97
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_test.py +130 -29
  122. langfun/core/structured/scoring.py +125 -19
  123. langfun/core/structured/scoring_test.py +30 -0
  124. langfun/core/structured/tokenization.py +64 -0
  125. langfun/core/structured/tokenization_test.py +48 -0
  126. langfun/core/template.py +115 -1
  127. langfun/core/template_test.py +71 -1
  128. langfun/core/templates/conversation.py +9 -0
  129. langfun/core/templates/conversation_test.py +4 -3
  130. langfun/core/templates/selfplay_test.py +10 -2
  131. langfun-0.1.2.dev202501150804.dist-info/METADATA +225 -0
  132. langfun-0.1.2.dev202501150804.dist-info/RECORD +153 -0
  133. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/WHEEL +1 -1
  134. langfun/core/coding/python/errors.py +0 -108
  135. langfun/core/coding/python/errors_test.py +0 -99
  136. langfun/core/coding/python/permissions.py +0 -90
  137. langfun/core/coding/python/permissions_test.py +0 -86
  138. langfun/core/structured/prompting.py +0 -238
  139. langfun/core/text_formatting.py +0 -162
  140. langfun/core/text_formatting_test.py +0 -47
  141. langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
  142. langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
  143. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/LICENSE +0 -0
  144. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,854 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Base classes for agentic actions."""
15
+
16
+ import abc
17
+ import contextlib
18
+ import datetime
19
+ import time
20
+
21
+ import typing
22
+ from typing import Annotated, Any, ContextManager, Iterable, Iterator, Optional, Type, Union
23
+ import langfun.core as lf
24
+ from langfun.core import structured as lf_structured
25
+ import pyglove as pg
26
+
27
+
28
class Action(pg.Object):
  """Base class for agent actions.

  Subclasses implement `call`. Calling an action instance runs `call` inside
  `Session.track_action`, so the invocation (with its queries, logs, phases
  and sub-actions) is recorded on the session. After a call, `result` and
  `metadata` hold the values from the most recent invocation.
  """

  def _on_bound(self):
    super()._on_bound()
    self._session = None
    self._result = None
    self._metadata = {}

  @property
  def session(self) -> Optional['Session']:
    """Returns the session created by this action when called without one."""
    return self._session

  @property
  def result(self) -> Any:
    """Returns the result of the action."""
    return self._result

  @property
  def metadata(self) -> dict[str, Any] | None:
    """Returns the metadata associated with the result from previous call."""
    return self._metadata

  def __call__(
      self,
      session: Optional['Session'] = None,
      *,
      show_progress: bool = True,
      **kwargs) -> Any:
    """Executes the action.

    Args:
      session: The session to execute the action in. If None, a new session
        is created and this action becomes its top-level action.
      show_progress: If True and a new session is created, the session view
        is displayed via `lf.console.display`.
      **kwargs: Keyword arguments passed through to `call`.

    Returns:
      The value returned by `call`.
    """
    new_session = session is None
    if new_session:
      session = Session()
      if show_progress:
        lf.console.display(pg.view(session, name='agent_session'))

    with session.track_action(self):
      result = self.call(session=session, **kwargs)

      # Record result and metadata while still inside `track_action`:
      # `session.current_action` is this action's invocation only within the
      # block, and `track_action` reads `self.result` / `self.metadata` when
      # it ends the invocation on exit.
      if new_session:
        # For the top-level action, keep the session it created.
        self._session = session
      self._result = result
      self._metadata = session.current_action.metadata
    return self._result

  @abc.abstractmethod
  def call(self, session: 'Session', **kwargs) -> Any:
    """Calls the action.

    Args:
      session: The session to use for the action.
      **kwargs: Additional keyword arguments to pass to the action.

    Returns:
      The result of the action.
    """
86
+
87
+
88
# Everything an `ExecutionTrace` may record: query invocations, action
# invocations, nested traces (phases) and log entries. `ActionInvocation` and
# `ExecutionTrace` are written as strings because those classes are defined
# later in this module.
TracedItem = typing.Union[
    lf_structured.QueryInvocation,
    'ActionInvocation',
    'ExecutionTrace',
    # NOTE(daiyip): Consider removing log entries once we migrate existing
    # agents.
    lf.logging.LogEntry,
]
96
+
97
+
98
class ExecutionTrace(pg.Object, pg.views.html.HtmlTreeView.Extension):
  """Trace of the execution of an action.

  A trace holds start/end timestamps and an ordered list of traced items
  (query invocations, action invocations, log entries and nested traces,
  a.k.a. phases). It aggregates the usage summary of its items. Once rendered
  as an interactive HTML view, it keeps handles to its tab control and time
  badge so they are updated in place as the execution progresses.
  """

  name: Annotated[
      str | None,
      (
          'The name of the execution trace. If None, the trace is unnamed, '
          'which is the case for the top-level trace of an action. An '
          'execution trace could have sub-traces, called phases, which are '
          'created and named by `session.phase()` context manager.'
      )
  ] = None

  start_time: Annotated[
      float | None,
      'The start time of the execution. If None, the execution is not started.'
  ] = None

  end_time: Annotated[
      float | None,
      'The end time of the execution. If None, the execution is not ended.'
  ] = None

  items: Annotated[
      list[TracedItem],
      'All tracked execution items in the sequence.'
  ] = []

  def _on_bound(self):
    super()._on_bound()
    # Seed the usage summary from items that carry one (queries, action
    # invocations and nested phases all expose `usage_summary`).
    self._usage_summary = lf.UsageSummary()
    for item in self.items:
      if hasattr(item, 'usage_summary'):
        self._usage_summary.merge(item.usage_summary)

    # Handles to live HTML controls; only set by interactive rendering.
    self._tab_control = None
    self._time_badge = None

  def start(self) -> None:
    """Marks the execution as started and updates the live time badge."""
    assert self.start_time is None, 'Execution already started.'
    self.rebind(start_time=time.time(), skip_notification=True)
    if self._time_badge is not None:
      self._time_badge.update(
          'Starting',
          add_class=['starting'],
          remove_class=['not-started'],
      )

  def stop(self) -> None:
    """Marks the execution as stopped and updates the live time badge."""
    assert self.end_time is None, 'Execution already stopped.'
    self.rebind(end_time=time.time(), skip_notification=True)
    if self._time_badge is not None:
      self._time_badge.update(
          f'{int(self.elapse)} seconds',
          tooltip=pg.format(self.execution_summary(), verbose=False),
          add_class=['finished'],
          # The badge may still be in the 'starting' state (set by `start`)
          # or in the 'running' state (set by `append`); clear both.
          remove_class=['starting', 'running'],
      )

  def __len__(self) -> int:
    return len(self.items)

  @property
  def has_started(self) -> bool:
    return self.start_time is not None

  @property
  def has_stopped(self) -> bool:
    return self.end_time is not None

  @property
  def elapse(self) -> float:
    """Returns the elapsed time of the execution.

    Returns 0.0 if not started, and the time since start if still running.
    """
    if self.start_time is None:
      return 0.0
    if self.end_time is None:
      return time.time() - self.start_time
    return self.end_time - self.start_time

  @property
  def queries(self) -> list[lf_structured.QueryInvocation]:
    """Returns queries in this trace and its phases, excluding child actions."""
    return self._child_items(lf_structured.QueryInvocation)

  @property
  def actions(self) -> list['ActionInvocation']:
    """Returns action invocations in this trace and its phases."""
    return self._child_items(ActionInvocation)

  @property
  def logs(self) -> list[lf.logging.LogEntry]:
    """Returns logs in this trace and its phases, excluding child actions."""
    return self._child_items(lf.logging.LogEntry)

  @property
  def all_queries(self) -> list[lf_structured.QueryInvocation]:
    """Returns all queries from current trace and its child execution items."""
    return self._all_child_items(lf_structured.QueryInvocation)

  @property
  def all_logs(self) -> list[lf.logging.LogEntry]:
    """Returns all logs from current trace and its child execution items."""
    return self._all_child_items(lf.logging.LogEntry)

  def _child_items(self, item_cls: Type[Any]) -> list[Any]:
    """Collects `item_cls` items, recursing into phases but not into actions."""
    child_items = []
    for item in self.items:
      if isinstance(item, item_cls):
        child_items.append(item)
      elif isinstance(item, ExecutionTrace):
        child_items.extend(item._child_items(item_cls))  # pylint: disable=protected-access
    return child_items

  def _all_child_items(self, item_cls: Type[Any]) -> list[Any]:
    """Collects `item_cls` items, recursing into both phases and actions."""
    child_items = []
    for item in self.items:
      if isinstance(item, item_cls):
        child_items.append(item)
      elif isinstance(item, ActionInvocation):
        child_items.extend(item.execution._all_child_items(item_cls))  # pylint: disable=protected-access
      elif isinstance(item, ExecutionTrace):
        child_items.extend(item._all_child_items(item_cls))  # pylint: disable=protected-access
    return child_items

  def append(self, item: TracedItem) -> None:
    """Appends an item to the sequence.

    A query's usage is merged into this trace and into every enclosing
    `ExecutionTrace` ancestor. Live HTML controls, if any, are updated.

    Args:
      item: The item to append.
    """
    with pg.notify_on_change(False):
      self.items.append(item)

    if isinstance(item, lf_structured.QueryInvocation):
      trace = self
      while trace is not None:
        trace.usage_summary.merge(item.usage_summary)
        trace = typing.cast(
            ExecutionTrace,
            trace.sym_ancestor(lambda x: isinstance(x, ExecutionTrace))
        )

    if self._tab_control is not None:
      self._tab_control.append(self._execution_item_tab(item))

    # While running, the time badge shows the latest non-log sub-task.
    if (self._time_badge is not None
        and not isinstance(item, lf.logging.LogEntry)):
      sub_task_label = self._execution_item_label(item)
      self._time_badge.update(
          text=sub_task_label.text,
          tooltip=sub_task_label.tooltip.content,
          add_class=['running'],
          remove_class=['not-started', 'starting'],
      )

  def extend(self, items: Iterable[TracedItem]) -> None:
    """Extends the sequence with a list of items."""
    for item in items:
      self.append(item)

  @property
  def usage_summary(self) -> lf.UsageSummary:
    """Returns the usage summary of the action."""
    return self._usage_summary

  def execution_summary(self) -> dict[str, Any]:
    """Returns a summary of the execution.

    The summary holds the number of queries (as counted by `queries`). For
    each child action invocation, it also holds the class name, total usage
    and elapsed time.
    """
    return pg.Dict(
        num_queries=len(self.queries),
        execution_breakdown=[
            dict(
                action=action.action.__class__.__name__,
                usage=action.usage_summary.total,
                execution_time=action.execution.elapse,
            )
            for action in self.actions
        ]
    )

  #
  # HTML views.
  #

  def _html_tree_view_summary(
      self,
      *,
      name: str | None = None,
      extra_flags: dict[str, Any] | None = None,
      view: pg.views.html.HtmlTreeView, **kwargs
  ):
    # No summary: the trace is presented through its content only.
    return None

  def _execution_badge(self, interactive: bool = True):
    """Creates the time badge; keeps a handle to it when interactive."""
    if not self.has_started:
      label = '(Not started)'
      tooltip = 'Execution not started.'
      css_class = 'not-started'
    elif not self.has_stopped:
      label = 'Starting'
      tooltip = 'Execution starting.'
      css_class = 'running'
    else:
      label = f'{int(self.elapse)} seconds'
      tooltip = pg.format(self.execution_summary(), verbose=False)
      css_class = 'finished'
    time_badge = pg.views.html.controls.Badge(
        label,
        tooltip=tooltip,
        css_classes=['execution-time', css_class],
        interactive=interactive,
    )
    if interactive:
      self._time_badge = time_badge
    return time_badge

  def _html_tree_view_content(
      self,
      *,
      extra_flags: dict[str, Any] | None = None,
      **kwargs
  ):
    del kwargs
    extra_flags = extra_flags or {}
    interactive = extra_flags.get('interactive', True)
    if not interactive and not self.items:
      return '(no tracked items)'

    tab_control = pg.views.html.controls.TabControl(
        [self._execution_item_tab(item) for item in self.items],
        tab_position='left'
    )
    # Keep a handle only for interactive renders (as `_execution_badge` does).
    # Otherwise a static render would replace the live control, and `append`
    # would stop updating the live view.
    if interactive:
      self._tab_control = tab_control
    return tab_control.to_html()

  def _execution_item_tab(self, item: TracedItem) -> pg.views.html.controls.Tab:
    """Creates the tab showing `item`, styled by its kind."""
    if isinstance(item, ActionInvocation):
      css_class = 'action'
    elif isinstance(item, lf_structured.QueryInvocation):
      css_class = 'query'
    elif isinstance(item, lf.logging.LogEntry):
      css_class = 'log'
    elif isinstance(item, ExecutionTrace):
      css_class = 'phase'
    else:
      raise ValueError(f'Unsupported item type: {type(item)}')

    return pg.views.html.controls.Tab(
        label=self._execution_item_label(item),
        content=pg.view(item),
        css_classes=[css_class]
    )

  def _execution_item_label(
      self, item: TracedItem
  ) -> pg.views.html.controls.Label:
    """Creates the label (with tooltip) used for `item` in tabs and badges."""
    if isinstance(item, ActionInvocation):
      return pg.views.html.controls.Label(
          item.action.__class__.__name__,
          tooltip=pg.format(
              item.action,
              verbose=False,
              hide_default_values=True,
              max_str_len=80,
              max_bytes_len=20,
          ),
      )
    elif isinstance(item, lf_structured.QueryInvocation):
      schema_title = 'str'
      if item.schema:
        schema_title = lf_structured.annotation(item.schema.spec)
      return pg.views.html.controls.Label(
          schema_title,
          tooltip=(
              pg.format(
                  item.input,
                  verbose=False,
                  hide_default_values=True,
                  max_str_len=80,
                  max_bytes_len=20,
              )
          ),
      )
    elif isinstance(item, lf.logging.LogEntry):
      return pg.views.html.controls.Label(
          'Log',
          tooltip=item.message,
      )
    elif isinstance(item, ExecutionTrace):
      return pg.views.html.controls.Label(
          item.name or 'Phase',
          tooltip=f'Execution phase {item.name!r}.'
      )
    else:
      raise ValueError(f'Unsupported item type: {type(item)}')

  def _html_tree_view_css_styles(self) -> list[str]:
    return super()._html_tree_view_css_styles() + [
        """
        .tab-button.action > ::before {
          content: "A";
          font-weight: bold;
          color: red;
          padding: 10px;
        }
        .tab-button.phase > ::before {
          content: "P";
          font-weight: bold;
          color: purple;
          padding: 10px;
        }
        .tab-button.query > ::before {
          content: "Q";
          font-weight: bold;
          color: orange;
          padding: 10px;
        }
        .tab-button.log > ::before {
          content: "L";
          font-weight: bold;
          color: green;
          padding: 10px;
        }
        .details.execution-trace, .details.action-invocation {
          border: 1px solid #eee;
        }
        .execution-trace-title {
          display: inline-block;
        }
        .badge.execution-time {
          margin-left: 4px;
          border-radius: 0px;
        }
        .execution-time.starting {
          background-color: ghostwhite;
          font-weight: normal;
        }
        .execution-time.running {
          background-color: ghostwhite;
          font-weight: normal;
        }
        .execution-time.finished {
          background-color: aliceblue;
          font-weight: bold;
        }
        """
    ]
440
+
441
+
442
class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
  """A class for capturing the invocation of an action.

  It records the invoked action, its result and metadata, and an
  `ExecutionTrace` of everything that happened during the invocation. Phases
  opened via `phase()` become nested traces inside `execution`.
  """
  action: Action

  result: Annotated[
      Any,
      'The result of the action.'
  ] = None

  metadata: Annotated[
      dict[str, Any],
      'The metadata returned by the action.'
  ] = {}

  execution: Annotated[
      ExecutionTrace,
      'The execution sequence of the action.'
  ] = ExecutionTrace()

  # Allow symbolic assignment without `rebind`.
  allow_symbolic_assignment = True

  def _on_bound(self):
    super()._on_bound()
    self._current_phase = self.execution
    # Live tab control; only set by interactive rendering.
    self._tab_control = None

  @property
  def current_phase(self) -> ExecutionTrace:
    """Returns the current execution phase."""
    return self._current_phase

  @contextlib.contextmanager
  def phase(self, name: str) -> Iterator[ExecutionTrace]:
    """Context manager for starting a new execution phase.

    The new phase is started, appended to the current phase, and made current
    for the duration of the context. On exit (even on error) it is stopped and
    the previous phase is restored.

    Args:
      name: Name of the phase.

    Yields:
      The `ExecutionTrace` of the new phase.
    """
    phase = ExecutionTrace(name=name)
    phase.start()
    parent_phase = self._current_phase
    self._current_phase.append(phase)
    self._current_phase = phase
    try:
      yield phase
    finally:
      phase.stop()
      self._current_phase = parent_phase

  @property
  def logs(self) -> list[lf.logging.LogEntry]:
    """Returns immediate child logs from execution sequence."""
    return self.execution.logs

  @property
  def actions(self) -> list['ActionInvocation']:
    """Returns immediate child action invocations."""
    return self.execution.actions

  @property
  def queries(self) -> list[lf_structured.QueryInvocation]:
    """Returns immediate queries made by the action."""
    return self.execution.queries

  @property
  def all_queries(self) -> list[lf_structured.QueryInvocation]:
    """Returns all queries made by the action and its child execution items."""
    return self.execution.all_queries

  @property
  def all_logs(self) -> list[lf.logging.LogEntry]:
    """Returns all logs made by the action and its child execution items."""
    return self.execution.all_logs

  @property
  def usage_summary(self) -> lf.UsageSummary:
    """Returns the usage summary of the action."""
    return self.execution.usage_summary

  def start(self) -> None:
    """Starts the execution of the action."""
    self.execution.start()

  def end(self, result: Any, metadata: dict[str, Any]) -> None:
    """Ends the execution of the action with result and metadata.

    Args:
      result: The result of the action.
      metadata: The metadata produced by the action.
    """
    self.rebind(
        result=result,
        metadata=metadata,
        skip_notification=True,
        raise_on_no_change=False
    )
    self.execution.stop()
    if self._tab_control is not None:
      # Match the tab order built by `_html_tree_view_content`:
      # [action], result, metadata, execution. The root action has no
      # 'action' tab, so its result/metadata tabs must go first.
      insert_pos = 0 if isinstance(self.action, RootAction) else 1
      if self.metadata:
        self._tab_control.insert(
            insert_pos,
            pg.views.html.controls.Tab(
                'metadata',
                pg.view(
                    self.metadata,
                    collapse_level=None,
                    enable_summary_tooltip=False
                ),
                name='metadata',
            )
        )
      self._tab_control.insert(
          insert_pos,
          pg.views.html.controls.Tab(
              'result',
              pg.view(
                  self.result,
                  collapse_level=None,
                  enable_summary_tooltip=False
              ),
              name='result',
          ),
      )
      self._tab_control.select(['metadata', 'result'])

  #
  # HTML views.
  #

  def _html_tree_view_summary(
      self, *, view: pg.views.html.HtmlTreeView, **kwargs
  ):
    # No summary: the invocation is presented through its content only.
    return None

  def _html_tree_view_content(
      self,
      *,
      view: pg.views.html.HtmlTreeView,
      extra_flags: dict[str, Any] | None = None,
      **kwargs
  ):
    extra_flags = extra_flags or {}
    interactive = extra_flags.get('interactive', True)
    # A finished root invocation with a single item is shown as that item.
    if (isinstance(self.action, RootAction)
        and self.execution.has_stopped
        and len(self.execution) == 1):
      return view.content(self.execution.items[0], extra_flags=extra_flags)

    tabs = []
    if not isinstance(self.action, RootAction):
      tabs.append(
          pg.views.html.controls.Tab(
              'action',
              view.render(  # pylint: disable=g-long-ternary
                  self.action,
                  collapse_level=None,
                  root_path=self.action.sym_path,
                  enable_summary_tooltip=False,
              ),
              name='action',
          )
      )
    if self.execution.has_stopped:
      tabs.append(
          pg.views.html.controls.Tab(
              'result',
              view.render(
                  self.result,
                  collapse_level=None,
                  enable_summary_tooltip=False
              ),
              name='result'
          )
      )
      if self.metadata:
        tabs.append(
            pg.views.html.controls.Tab(
                'metadata',
                view.render(
                    self.metadata,
                    collapse_level=None,
                    enable_summary_tooltip=False
                ),
                name='metadata'
            )
        )

    tabs.append(
        pg.views.html.controls.Tab(
            pg.Html.element(
                'span',
                [
                    'execution',
                    self.execution._execution_badge(interactive),  # pylint: disable=protected-access
                    (
                        self.usage_summary.to_html(  # pylint: disable=g-long-ternary
                            extra_flags=dict(as_badge=True)
                        )
                        if (interactive
                            or self.usage_summary.total.num_requests > 0)
                        else None
                    ),
                ],
                css_classes=['execution-tab-title']
            ),
            view.render(self.execution, extra_flags=extra_flags),
            name='execution',
        )
    )
    tab_control = pg.views.html.controls.TabControl(tabs)
    # Select the tab following a priority: metadata, result, action, execution.
    tab_control.select(['metadata', 'result', 'action', 'execution'])
    if interactive:
      self._tab_control = tab_control
    return tab_control

  @classmethod
  def _html_tree_view_css_styles(cls) -> list[str]:
    return super()._html_tree_view_css_styles() + [
        """
        .execution-tab-title {
          text-align: left;
        }
        .execution-tab-title .usage-summary.label {
          border-radius: 0px;
          font-weight: normal;
          color: #AAA;
        }
        """
    ]
664
+
665
+
666
class RootAction(Action):
  """Placeholder action wrapped by the root invocation of a `Session`.

  It only anchors the top of the invocation tree and must never be executed.
  """

  def call(self, session: 'Session', **kwargs) -> Any:
    del session, kwargs  # Unused: this action is never meant to run.
    raise NotImplementedError('Shall not be called.')
671
+
672
+
673
class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
  """Session for performing an agentic task.

  The session owns a root `ActionInvocation` wrapping a `RootAction` and keeps
  a pointer to the invocation that is currently executing. Actions, queries,
  phases and log entries issued through the session are recorded on the
  current phase of that invocation.
  """

  root: ActionInvocation = ActionInvocation(RootAction())

  def _on_bound(self):
    super()._on_bound()
    # `track_action` moves this pointer down the invocation tree and back.
    self._current_action = self.root

  @property
  def final_result(self) -> Any:
    """The result recorded on the root invocation."""
    return self.root.result

  @property
  def current_action(self) -> ActionInvocation:
    """The invocation that is currently executing."""
    return self._current_action

  def add_metadata(self, **kwargs: Any) -> None:
    """Merges `kwargs` into the current invocation's metadata."""
    target_metadata = self._current_action.metadata
    with pg.notify_on_change(False):
      target_metadata.update(kwargs)

  def phase(self, name: str) -> ContextManager[ExecutionTrace]:
    """Returns a context manager opening phase `name` on the current action."""
    active_invocation = self.current_action
    return active_invocation.phase(name)

  @contextlib.contextmanager
  def track_action(self, action: Action) -> Iterator[ActionInvocation]:
    """Tracks the execution of `action` as a child of the current invocation.

    The root invocation is started on first use. A new `ActionInvocation` is
    appended to the current phase and made current for the duration of the
    context. On exit (even on error) it is ended with `action.result` and
    `action.metadata`, and the previous invocation is restored. When the
    caller is the root, the root is ended with the same values.

    NOTE(review): ending the root a second time would trip the
    'Execution already stopped.' assertion in `ExecutionTrace.stop`, so a
    session appears to support only one top-level action. Confirm whether
    this is intended.

    Args:
      action: The action being executed.

    Yields:
      The `ActionInvocation` recording this execution.
    """
    if not self.root.execution.has_started:
      self.root.start()

    caller = self._current_action
    invocation = ActionInvocation(pg.maybe_ref(action))
    caller.current_phase.append(invocation)

    self._current_action = invocation
    try:
      invocation.start()
      yield invocation
    finally:
      self._current_action.end(action.result, action.metadata)
      self._current_action = caller
      if caller is self.root:
        caller.end(result=action.result, metadata=action.metadata)

  @contextlib.contextmanager
  def track_queries(
      self,
      phase: str | None = None
  ) -> Iterator[list[lf_structured.QueryInvocation]]:
    """Tracks `lf.query` made within the context.

    Args:
      phase: The name of a new phase to track the queries in. If not provided,
        the queries will be tracked in the parent phase.

    Yields:
      A list of `lf.QueryInvocation` objects, each for a single `lf.query`
      call.
    """
    phase_scope = self.phase(phase) if phase else contextlib.nullcontext()
    with phase_scope, lf_structured.track_queries(
        include_child_scopes=False
    ) as queries:
      try:
        yield queries
      finally:
        # Attach whatever was captured, even if the body raised.
        self._current_action.current_phase.extend(queries)

  def query(
      self,
      prompt: Union[str, lf.Template, Any],
      schema: Union[
          lf_structured.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
      ] = None,
      default: Any = lf.RAISE_IF_HAS_ERROR,
      *,
      lm: lf.LanguageModel | None = None,
      examples: list[lf_structured.MappingExample] | None = None,
      **kwargs
  ) -> Any:
    """Calls `lf.query` and associates it with the current invocation.

    The following two snippets are equivalent:

    Code 1:
    ```
    session.query(...)
    ```

    Code 2:
    ```
    with session.track_queries() as queries:
      output = lf.query(...)
    ```
    Prefer the former when the action calls `lf.query` directly. Use the
    latter when `lf.query` is called by a function that has no access to the
    session.

    Args:
      prompt: The prompt to query.
      schema: The schema to use for the query.
      default: The default value to return if the query fails.
      lm: The language model to use for the query.
      examples: The examples to use for the query.
      **kwargs: Additional keyword arguments to pass to `lf.query`.

    Returns:
      The result of the query.
    """
    query_options = dict(
        schema=schema, default=default, lm=lm, examples=examples, **kwargs
    )
    with self.track_queries():
      return lf_structured.query(prompt, **query_options)

  def _log(self, level: lf.logging.LogLevel, message: str, **kwargs):
    """Appends a log entry (with `kwargs` as metadata) to the current phase."""
    entry = lf.logging.LogEntry(
        level=level,
        time=datetime.datetime.now(),
        message=message,
        metadata=kwargs,
    )
    self._current_action.current_phase.append(entry)

  def debug(self, message: str, **kwargs):
    """Records a 'debug' log entry in the session."""
    self._log('debug', message, **kwargs)

  def info(self, message: str, **kwargs):
    """Records an 'info' log entry in the session."""
    self._log('info', message, **kwargs)

  def warning(self, message: str, **kwargs):
    """Records a 'warning' log entry in the session."""
    self._log('warning', message, **kwargs)

  def error(self, message: str, **kwargs):
    """Records an 'error' log entry in the session."""
    self._log('error', message, **kwargs)

  def fatal(self, message: str, **kwargs):
    """Records a 'fatal' log entry in the session."""
    self._log('fatal', message, **kwargs)

  def as_message(self) -> lf.AIMessage:
    """Wraps the session in an `lf.AIMessage` whose `result` is the root."""
    return lf.AIMessage('Agentic task session.', result=self.root)

  #
  # HTML views.
  #

  def _html_tree_view_content(
      self,
      *,
      view: pg.views.html.HtmlTreeView,
      **kwargs
  ):
    # The session is displayed as its root invocation.
    return view.content(self.root, **kwargs)

  @classmethod
  def _html_tree_view_config(cls):
    config = super()._html_tree_view_config()
    config.update(enable_summary_tooltip=False, enable_key_tooltip=False)
    return config