langfun 0.1.2.dev202412080804__py3-none-any.whl → 0.1.2.dev202412100804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,7 +15,11 @@
15
15
 
16
16
  import abc
17
17
  import contextlib
18
- from typing import Annotated, Any, Iterable, Iterator, Optional, Type, Union
18
+ import datetime
19
+ import time
20
+
21
+ import typing
22
+ from typing import Annotated, Any, ContextManager, Iterable, Iterator, Optional, Type, Union
19
23
  import langfun.core as lf
20
24
  from langfun.core import structured as lf_structured
21
25
  import pyglove as pg
@@ -26,24 +30,420 @@ class Action(pg.Object):
26
30
 
27
31
  def _on_bound(self):
28
32
  super()._on_bound()
33
+ self._session = None
29
34
  self._result = None
35
+ self._result_metadata = {}
36
+
37
+ @property
38
+ def session(self) -> Optional['Session']:
39
+ """Returns the session started by this action."""
40
+ return self._session
30
41
 
31
42
  @property
32
43
  def result(self) -> Any:
33
44
  """Returns the result of the action."""
34
45
  return self._result
35
46
 
47
+ @property
48
+ def result_metadata(self) -> dict[str, Any] | None:
49
+ """Returns the metadata associated with the result from previous call."""
50
+ return self._result_metadata
51
+
36
52
  def __call__(
37
- self, session: Optional['Session'] = None, **kwargs) -> Any:
53
+ self,
54
+ session: Optional['Session'] = None,
55
+ *,
56
+ show_progress: bool = True,
57
+ **kwargs) -> Any:
38
58
  """Executes the action."""
39
- session = session or Session()
40
- with session.track(self):
41
- self._result = self.call(session=session, **kwargs)
59
+ new_session = session is None
60
+ if new_session:
61
+ session = Session()
62
+ if show_progress:
63
+ lf.console.display(pg.view(session, name='agent_session'))
64
+
65
+ with session.track_action(self):
66
+ result = self.call(session=session, **kwargs)
67
+ metadata = dict()
68
+ if (isinstance(result, tuple)
69
+ and len(result) == 2 and isinstance(result[1], dict)):
70
+ result, metadata = result
71
+
72
+ # For the top-level action, we store the session in the metadata.
73
+ if new_session:
74
+ self._session = session
75
+ self._result, self._result_metadata = result, metadata
42
76
  return self._result
43
77
 
44
78
  @abc.abstractmethod
45
- def call(self, session: 'Session', **kwargs) -> Any:
46
- """Subclasses to implement."""
79
+ def call(
80
+ self,
81
+ session: 'Session',
82
+ **kwargs
83
+ ) -> Union[Any, tuple[Any, dict[str, Any]]]:
84
+ """Calls the action.
85
+
86
+ Args:
87
+ session: The session to use for the action.
88
+ **kwargs: Additional keyword arguments to pass to the action.
89
+
90
+ Returns:
91
+ The result of the action or a tuple of (result, result_metadata).
92
+ """
93
+
94
+
95
# Type of the entries an `ExecutionTrace` can hold: LLM query invocations,
# child action invocations, nested traces (phases) and log entries.
# `ActionInvocation` and `ExecutionTrace` are forward references because both
# classes are defined after this alias.
TracedItem = Union[
    lf_structured.QueryInvocation,
    'ActionInvocation',
    'ExecutionTrace',
    # NOTE(daiyip): Consider removing log entries once existing agents are
    # migrated.
    lf.logging.LogEntry,
]
103
+
104
+
105
class ExecutionTrace(pg.Object, pg.views.html.HtmlTreeView.Extension):
  """Trace of the execution of an action.

  A trace keeps its start/end timestamps and an ordered list of traced items
  (queries, child action invocations, log entries and nested traces). It
  accumulates the usage summaries of the queries appended to it and renders
  itself in HTML as a tab control plus a time badge that is updated on
  `start`, `append` and `stop`.
  """

  name: Annotated[
      str | None,
      (
          'The name of the execution trace. If None, the trace is unnamed, '
          'which is the case for the top-level trace of an action. An '
          'execution trace could have sub-traces, called phases, which are '
          'created and named by `session.phase()` context manager.'
      )
  ] = None

  start_time: Annotated[
      float | None,
      'The start time of the execution. If None, the execution is not started.'
  ] = None

  end_time: Annotated[
      float | None,
      'The end time of the execution. If None, the execution is not ended.'
  ] = None

  items: Annotated[
      list[TracedItem],
      'All tracked execution items in the sequence.'
  ] = []

  def _on_bound(self):
    """Rebuilds the usage summary from existing items and resets UI handles."""
    super()._on_bound()
    self._usage_summary = lf.UsageSummary()
    for item in self.items:
      # Log entries have no usage summary; only merge from items that do.
      if hasattr(item, 'usage_summary'):
        self._usage_summary.merge(item.usage_summary)

    # Set lazily when the trace is rendered; None means nothing to refresh.
    self._tab_control = None
    self._time_badge = None

  def start(self) -> None:
    """Records the start time and switches the time badge to 'Starting'."""
    assert self.start_time is None, 'Execution already started.'
    self.rebind(start_time=time.time(), skip_notification=True)
    if self._time_badge is not None:
      self._time_badge.update(
          'Starting',
          add_class=['running'],
          remove_class=['not-started'],
      )

  def stop(self) -> None:
    """Records the end time and shows the elapsed seconds on the time badge."""
    assert self.end_time is None, 'Execution already stopped.'
    self.rebind(end_time=time.time(), skip_notification=True)
    if self._time_badge is not None:
      self._time_badge.update(
          f'{int(self.elapse)} seconds',
          add_class=['finished'],
          remove_class=['running'],
      )

  @property
  def has_started(self) -> bool:
    """True once `start_time` has been set."""
    return self.start_time is not None

  @property
  def has_stopped(self) -> bool:
    """True once `end_time` has been set."""
    return self.end_time is not None

  @property
  def elapse(self) -> float:
    """Returns the elapsed time of the execution.

    0.0 before the trace starts, time since start while it is running, and
    `end_time - start_time` once it has stopped.
    """
    if self.start_time is None:
      return 0.0
    if self.end_time is None:
      return time.time() - self.start_time
    return self.end_time - self.start_time

  @property
  def queries(self) -> list[lf_structured.QueryInvocation]:
    """Returns queries from the sequence."""
    return self._child_items(lf_structured.QueryInvocation)

  @property
  def actions(self) -> list['ActionInvocation']:
    """Returns action invocations from the sequence."""
    return self._child_items(ActionInvocation)

  @property
  def logs(self) -> list[lf.logging.LogEntry]:
    """Returns logs from the sequence."""
    return self._child_items(lf.logging.LogEntry)

  @property
  def all_queries(self) -> list[lf_structured.QueryInvocation]:
    """Returns all queries from current trace and its child execution items."""
    return self._all_child_items(lf_structured.QueryInvocation)

  @property
  def all_logs(self) -> list[lf.logging.LogEntry]:
    """Returns all logs from current trace and its child execution items."""
    return self._all_child_items(lf.logging.LogEntry)

  def _child_items(self, item_cls: Type[Any]) -> list[Any]:
    """Collects `item_cls` items from this trace and its nested traces.

    Nested `ExecutionTrace` items (phases) are searched recursively; the
    executions of child `ActionInvocation` items are not.
    """
    child_items = []
    for item in self.items:
      if isinstance(item, item_cls):
        child_items.append(item)
      elif isinstance(item, ExecutionTrace):
        child_items.extend(item._child_items(item_cls))  # pylint: disable=protected-access
    return child_items

  def _all_child_items(self, item_cls: Type[Any]) -> list[Any]:
    """Collects `item_cls` items from this trace and everything beneath it.

    Unlike `_child_items`, this also descends into the execution of child
    `ActionInvocation` items.
    """
    child_items = []
    for item in self.items:
      if isinstance(item, item_cls):
        child_items.append(item)
      elif isinstance(item, ActionInvocation):
        child_items.extend(item.execution._all_child_items(item_cls))  # pylint: disable=protected-access
      elif isinstance(item, ExecutionTrace):
        child_items.extend(item._all_child_items(item_cls))  # pylint: disable=protected-access
    return child_items

  def append(self, item: TracedItem) -> None:
    """Appends an item to the sequence.

    Besides storing the item, a query's usage is merged into this trace and
    into every `ExecutionTrace` ancestor, and any rendered tab control / time
    badge is refreshed.

    Args:
      item: The traced item to add.
    """
    with pg.notify_on_change(False):
      self.items.append(item)

    if isinstance(item, lf_structured.QueryInvocation):
      # Walk up the symbolic tree so every enclosing trace sees the usage.
      current_invocation = self
      while current_invocation is not None:
        current_invocation.usage_summary.merge(item.usage_summary)
        current_invocation = typing.cast(
            ExecutionTrace,
            current_invocation.sym_ancestor(
                lambda x: isinstance(x, ExecutionTrace)
            )
        )

    if self._tab_control is not None:
      self._tab_control.append(self._execution_item_tab(item))

    # Log entries do not change the "Running <task>" badge.
    if (self._time_badge is not None
        and not isinstance(item, lf.logging.LogEntry)):
      sub_task_label = self._execution_item_label(item)
      self._time_badge.update(
          pg.Html.element(
              'span',
              [
                  'Running',
                  pg.views.html.controls.Badge(
                      sub_task_label.text,
                      tooltip=sub_task_label.tooltip,
                      css_classes=['task-in-progress']
                  )
              ]
          ),
          add_class=['running'],
          remove_class=['not-started'],
      )

  def extend(self, items: Iterable[TracedItem]) -> None:
    """Extends the sequence with a list of items."""
    for item in items:
      self.append(item)

  @property
  def usage_summary(self) -> lf.UsageSummary:
    """Returns the usage summary of the action."""
    return self._usage_summary

  #
  # HTML views.
  #

  def _html_tree_view_summary(
      self,
      *,
      name: str | None = None,
      extra_flags: dict[str, Any] | None = None,
      view: pg.views.html.HtmlTreeView, **kwargs
  ):
    """Renders the summary line: the 'ExecutionTrace' title plus a time badge.

    When the `interactive` extra flag is on (the default), the badge is kept
    on the instance so `start`, `append` and `stop` can update it later.
    """
    extra_flags = extra_flags or {}
    interactive = extra_flags.get('interactive', True)

    def _make_time_badge():
      if not self.has_started:
        label = '(Not started)'
        css_class = 'not-started'
      elif not self.has_stopped:
        label = 'Starting'
        css_class = 'running'
      else:
        label = f'{int(self.elapse)} seconds'
        css_class = 'finished'
      return pg.views.html.controls.Badge(
          label,
          css_classes=['execution-time', css_class],
          interactive=interactive,
      )

    time_badge = _make_time_badge()
    if interactive:
      self._time_badge = time_badge
    title = pg.Html.element(
        'div',
        [
            'ExecutionTrace',
            time_badge,
        ],
        css_classes=['execution-trace-title'],
    )
    # The title is supplied here, so drop any caller-provided one.
    kwargs.pop('title', None)
    kwargs['enable_summary_tooltip'] = False
    kwargs['enable_key_tooltip'] = False
    return view.summary(
        self,
        name=name,
        title=title,
        extra_flags=extra_flags,
        **kwargs
    )

  def _html_tree_view_content(self, **kwargs):
    """Renders the items as a left-positioned tab control, one tab per item."""
    del kwargs
    # Kept on the instance so `append` can add tabs after rendering.
    self._tab_control = pg.views.html.controls.TabControl(
        [self._execution_item_tab(item) for item in self.items],
        tab_position='left'
    )
    return pg.Html.element(
        'div',
        [
            self._tab_control
        ]
    )

  def _execution_item_tab(self, item: TracedItem) -> pg.views.html.controls.Tab:
    """Builds the tab for a traced item, tagged with a per-type CSS class.

    Raises:
      ValueError: If `item` is not one of the supported traced item types.
    """
    if isinstance(item, ActionInvocation):
      css_class = 'action'
    elif isinstance(item, lf_structured.QueryInvocation):
      css_class = 'query'
    elif isinstance(item, lf.logging.LogEntry):
      css_class = 'log'
    elif isinstance(item, ExecutionTrace):
      css_class = 'phase'
    else:
      raise ValueError(f'Unsupported item type: {type(item)}')

    return pg.views.html.controls.Tab(
        label=self._execution_item_label(item),
        content=pg.view(item),
        css_classes=[css_class]
    )

  def _execution_item_label(
      self, item: TracedItem
  ) -> pg.views.html.controls.Label:
    """Builds the tab label (and tooltip, where available) for a traced item.

    Raises:
      ValueError: If `item` is not one of the supported traced item types.
    """
    if isinstance(item, ActionInvocation):
      return pg.views.html.controls.Label(
          item.action.__class__.__name__,
          tooltip=pg.format(
              item.action,
              verbose=False,
              hide_default_values=True,
              max_str_len=80,
              max_bytes_len=20,
          ),
      )
    elif isinstance(item, lf_structured.QueryInvocation):
      # Queries without a schema are labeled 'str'.
      schema_title = 'str'
      if item.schema:
        schema_title = lf_structured.annotation(item.schema.spec)
      return pg.views.html.controls.Label(
          schema_title,
          tooltip=(
              pg.format(
                  item.input,
                  verbose=False,
                  hide_default_values=True,
                  max_str_len=80,
                  max_bytes_len=20,
              )
          ),
      )
    elif isinstance(item, lf.logging.LogEntry):
      return pg.views.html.controls.Label(
          'Log',
          tooltip=item.message,
      )
    elif isinstance(item, ExecutionTrace):
      return pg.views.html.controls.Label(
          item.name or 'Phase'
      )
    else:
      raise ValueError(f'Unsupported item type: {type(item)}')

  # Declared as a classmethod, like the `ActionInvocation` override of the
  # same hook; it remains callable on instances as before.
  @classmethod
  def _html_tree_view_css_styles(cls) -> list[str]:
    """Returns the inherited CSS styles plus the styles for trace views."""
    return super()._html_tree_view_css_styles() + [
        """
        .tab-button.action > ::before {
          content: "A";
          font-weight: bold;
          color: red;
          padding: 10px;
        }
        .tab-button.phase > ::before {
          content: "P";
          font-weight: bold;
          color: purple;
          padding: 10px;
        }
        .tab-button.query > ::before {
          content: "Q";
          font-weight: bold;
          color: orange;
          padding: 10px;
        }
        .tab-button.log > ::before {
          content: "L";
          font-weight: bold;
          color: green;
          padding: 10px;
        }
        .details.execution-trace, .details.action-invocation {
          border: 1px solid #eee;
        }
        .execution-trace-title {
          display: inline-block;
        }
        .badge.execution-time {
          margin-left: 5px;
        }
        .execution-time.running {
          background-color: lavender;
          font-weight: normal;
        }
        .execution-time.finished {
          background-color: aliceblue;
          font-weight: bold;
        }
        .badge.task-in-progress {
          margin-left: 5px;
          background-color: azure;
          font-weight: bold;
        }
        """
    ]
47
447
 
48
448
 
49
449
  class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
@@ -55,131 +455,246 @@ class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
55
455
  'The result of the action.'
56
456
  ] = None
57
457
 
458
  # Metadata dict stored together with the result when `end()` is called.
  result_metadata: Annotated[
      dict[str, Any],
      'The metadata returned by the action.'
  ] = {}

  # Top-level trace of this invocation; `_on_bound` uses it as the initial
  # current phase.
  execution: Annotated[
      ExecutionTrace,
      'The execution sequence of the action.'
  ] = ExecutionTrace()

  # Allow symbolic assignment without `rebind`.
  allow_symbolic_assignment = True
71
470
 
471
+ def _on_bound(self):
472
+ super()._on_bound()
473
+ self._current_phase = self.execution
474
+ self._result_badge = None
475
+ self._result_metadata_badge = None
476
+
477
+ @property
478
+ def current_phase(self) -> ExecutionTrace:
479
+ """Returns the current execution phase."""
480
+ return self._current_phase
481
+
482
+ @contextlib.contextmanager
483
+ def phase(self, name: str) -> Iterator[ExecutionTrace]:
484
+ """Context manager for starting a new execution phase."""
485
+ phase = ExecutionTrace(name=name)
486
+ phase.start()
487
+ parent_phase = self._current_phase
488
+ self._current_phase.append(phase)
489
+ self._current_phase = phase
490
+ try:
491
+ yield phase
492
+ finally:
493
+ phase.stop()
494
+ self._current_phase = parent_phase
495
+
72
496
  @property
73
497
  def logs(self) -> list[lf.logging.LogEntry]:
74
- """Returns logs from execution sequence."""
75
- return [v for v in self.execution if isinstance(v, lf.logging.LogEntry)]
498
+ """Returns immediate child logs from execution sequence."""
499
+ return self.execution.logs
76
500
 
77
501
  @property
78
- def child_invocations(self) -> list['ActionInvocation']:
79
- """Returns child action invocations."""
80
- return [v for v in self.execution if isinstance(v, ActionInvocation)]
502
+ def actions(self) -> list['ActionInvocation']:
503
+ """Returns immediate child action invocations."""
504
+ return self.execution.actions
81
505
 
82
- def queries(
83
- self,
84
- include_children: bool = False
85
- ) -> Iterable[lf_structured.QueryInvocation]:
86
- """Iterates over queries from the current invocation."""
87
- for v in self.execution:
88
- if isinstance(v, lf_structured.QueryInvocation):
89
- yield v
90
- elif isinstance(v, ActionInvocation):
91
- if include_children:
92
- yield from v.queries(include_children=True)
506
+ @property
507
+ def queries(self) -> list[lf_structured.QueryInvocation]:
508
+ """Returns immediate queries made by the action."""
509
+ return self.execution.queries
510
+
511
+ @property
512
+ def all_queries(self) -> list[lf_structured.QueryInvocation]:
513
+ """Returns all queries made by the action and its child execution items."""
514
+ return self.execution.all_queries
515
+
516
+ @property
517
+ def all_logs(self) -> list[lf.logging.LogEntry]:
518
+ """Returns all logs made by the action and its child execution items."""
519
+ return self.execution.all_logs
520
+
521
+ @property
522
+ def usage_summary(self) -> lf.UsageSummary:
523
+ """Returns the usage summary of the action."""
524
+ return self.execution.usage_summary
525
+
526
+ def start(self) -> None:
527
+ """Starts the execution of the action."""
528
+ self.execution.start()
529
+
530
+ def end(self, result: Any, result_metadata: dict[str, Any]) -> None:
531
+ """Ends the execution of the action with result and metadata."""
532
+ self.execution.stop()
533
+ self.rebind(
534
+ result=result,
535
+ result_metadata=result_metadata,
536
+ skip_notification=True,
537
+ raise_on_no_change=False
538
+ )
539
+ if self._result_badge is not None:
540
+ self._result_badge.update(
541
+ self._result_badge_label(result),
542
+ tooltip=self._result_badge_tooltip(result),
543
+ add_class=['ready'],
544
+ remove_class=['not-ready'],
545
+ )
546
+ if self._result_metadata_badge is not None:
547
+ result_metadata = dict(result_metadata)
548
+ result_metadata.pop('session', None)
549
+ self._result_metadata_badge.update(
550
+ '{...}',
551
+ tooltip=self._result_metadata_badge_tooltip(result_metadata),
552
+ add_class=['ready'],
553
+ remove_class=['not-ready'],
554
+ )
555
+
556
+ #
557
+ # HTML views.
558
+ #
93
559
 
94
560
  def _html_tree_view_summary(
95
561
  self, *, view: pg.views.html.HtmlTreeView, **kwargs
96
562
  ):
97
- if isinstance(self.action, RootAction):
98
- return None
99
- kwargs.pop('title')
100
- return view.summary(
101
- self,
102
- title=view.render(
103
- self.action, name='action', collapse_level=0,
104
- css_classes='invocation-title',
105
- ),
106
- **kwargs
107
- )
563
+ return None
108
564
 
109
565
  def _html_tree_view_content(
110
566
  self,
111
567
  *,
112
- root_path: pg.KeyPath | None = None,
113
- collapse_level: int | None = None,
114
568
  view: pg.views.html.HtmlTreeView,
569
+ extra_flags: dict[str, Any] | None = None,
115
570
  **kwargs
116
571
  ):
117
- prepare_phase = []
118
- current_phase = prepare_phase
119
- action_phases = []
120
- for item in self.execution:
121
- if isinstance(item, ActionInvocation):
122
- current_phase = []
123
- action_phases.append(current_phase)
124
- current_phase.append(item)
125
-
126
- def _render_phase(
127
- phase: list[ActionInvocation | lf.logging.LogEntry]
128
- ) -> pg.Html.WritableTypes:
129
- return pg.Html.element(
130
- 'div',
131
- [
132
- view.render(item) for item in phase
133
- ]
572
+ extra_flags = extra_flags or {}
573
+ interactive = extra_flags.get('interactive', True)
574
+ if (isinstance(self.action, RootAction)
575
+ and self.execution.has_stopped
576
+ and len(self.execution.items) == 1):
577
+ return view.content(self.execution.items[0], extra_flags=extra_flags)
578
+
579
+ def _result_badge():
580
+ if not self.execution.has_stopped:
581
+ label = '(n/a)'
582
+ tooltip = 'Result is not available yet.'
583
+ css_class = 'not-ready'
584
+ else:
585
+ label = self._result_badge_label(self.result)
586
+ tooltip = self._result_badge_tooltip(self.result)
587
+ css_class = 'ready'
588
+ return pg.views.html.controls.Badge(
589
+ label,
590
+ tooltip=tooltip,
591
+ css_classes=['invocation-result', css_class],
592
+ interactive=interactive,
134
593
  )
135
594
 
136
- def _render_action_phases(
137
- phases: list[list[ActionInvocation | lf.logging.LogEntry]]
138
- ) -> pg.Html.WritableTypes:
139
- if len(phases) == 1:
140
- return _render_phase(phases[0])
141
- return pg.views.html.controls.TabControl(
142
- [
143
- pg.views.html.controls.Tab(
144
- label=f'Step {i + 1}',
145
- content=_render_phase(phase),
146
- )
147
- for i, phase in enumerate(phases)
148
- ],
595
+ def _result_metadata_badge():
596
+ if not self.execution.has_stopped:
597
+ label = '(n/a)'
598
+ tooltip = 'Result metadata is not available yet.'
599
+ css_class = 'not-ready'
600
+ else:
601
+ label = '{...}' if self.result_metadata else '(empty)'
602
+ tooltip = self._result_metadata_badge_tooltip(self.result_metadata)
603
+ css_class = 'ready'
604
+ return pg.views.html.controls.Badge(
605
+ label,
606
+ tooltip=tooltip,
607
+ css_classes=['invocation-result-metadata', css_class],
608
+ interactive=interactive,
149
609
  )
150
610
 
151
- result_name = 'final_result' if isinstance(
152
- self.action, RootAction) else 'result'
611
+ result_badge = _result_badge()
612
+ result_metadata_badge = _result_metadata_badge()
613
+ if interactive:
614
+ self._result_badge = result_badge
615
+ self._result_metadata_badge = result_metadata_badge
616
+
153
617
  return pg.Html.element(
154
618
  'div',
155
619
  [
156
- view.render(
157
- self.result,
158
- name=result_name,
159
- css_classes=[
160
- f'invocation-{result_name}'.replace('_', '-')
161
- ]
620
+ pg.Html.element(
621
+ 'div',
622
+ [
623
+ view.render(
624
+ self.usage_summary, extra_flags=dict(as_badge=True)
625
+ ),
626
+ result_badge,
627
+ result_metadata_badge,
628
+ ],
629
+ css_classes=['invocation-badge-container'],
162
630
  ),
163
- _render_phase(prepare_phase) if prepare_phase else None,
164
- _render_action_phases(action_phases)
631
+ view.render( # pylint: disable=g-long-ternary
632
+ self.action,
633
+ name='action',
634
+ collapse_level=None,
635
+ root_path=self.action.sym_path,
636
+ css_classes='invocation-title',
637
+ enable_summary_tooltip=False,
638
+ ) if not isinstance(self.action, RootAction) else None,
639
+ view.render(self.execution, name='execution'),
165
640
  ]
166
641
  )
167
642
 
643
+ def _result_badge_label(self, result: Any) -> str:
644
+ label = pg.format(
645
+ result, python_format=True, verbose=False
646
+ )
647
+ if len(label) > 40:
648
+ if isinstance(result, str):
649
+ label = label[:40] + '...'
650
+ else:
651
+ label = f'{result.__class__.__name__}(...)'
652
+ return label
653
+
654
+ def _result_badge_tooltip(self, result: Any) -> pg.Html:
655
+ return typing.cast(
656
+ pg.Html,
657
+ pg.view(
658
+ result, name='result',
659
+ collapse_level=None,
660
+ enable_summary_tooltip=False,
661
+ enable_key_tooltip=False,
662
+ )
663
+ )
664
+
665
+ def _result_metadata_badge_tooltip(
666
+ self, result_metadata: dict[str, Any]
667
+ ) -> pg.Html:
668
+ return typing.cast(
669
+ pg.Html,
670
+ pg.view(
671
+ result_metadata,
672
+ name='result_metadata',
673
+ collapse_level=None,
674
+ enable_summary_tooltip=False,
675
+ )
676
+ )
677
+
168
678
  @classmethod
  def _html_tree_view_css_styles(cls) -> list[str]:
    """Returns the inherited CSS styles plus the styles for invocation views.

    The added rules lay out the badge row as a flex container, color the
    result and result-metadata badges once they carry the `ready` class, and
    style the rendered action title.
    """
    return super()._html_tree_view_css_styles() + [
        """
        .invocation-badge-container {
          display: flex;
          padding-bottom: 5px;
        }
        .invocation-badge-container > .label-container {
          margin-right: 3px;
        }
        .invocation-result.ready {
          background-color: lightcyan;
        }
        .invocation-result-metadata.ready {
          background-color: lightyellow;
        }
        details.pyglove.invocation-title {
          background-color: aliceblue;
          border: 0px solid white;
        }
        """
    ]
@@ -192,52 +707,75 @@ class RootAction(Action):
192
707
  raise NotImplementedError('Shall not be called.')
193
708
 
194
709
 
195
- class Session(pg.Object):
710
+ class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
196
711
  """Session for performing an agentic task."""
197
712
 
198
- root_invocation: ActionInvocation = ActionInvocation(RootAction())
713
  # Invocation of a placeholder `RootAction`; it is the initial current action,
  # so top-level actions tracked by the session are recorded beneath it.
  root: ActionInvocation = ActionInvocation(RootAction())
199
714
 
200
715
  def _on_bound(self):
201
716
  super()._on_bound()
202
- self._invocation_stack = [self.root_invocation]
717
+ self._current_action = self.root
203
718
 
204
719
  @property
205
720
  def final_result(self) -> Any:
206
721
  """Returns the final result of the session."""
207
- return self.root_invocation.result
722
+ return self.root.result
208
723
 
209
724
  @property
210
- def current_invocation(self) -> ActionInvocation:
725
+ def current_action(self) -> ActionInvocation:
211
726
  """Returns the current invocation."""
212
- assert self._invocation_stack
213
- return self._invocation_stack[-1]
727
+ return self._current_action
728
+
729
+ def phase(self, name: str) -> ContextManager[ExecutionTrace]:
730
+ """Context manager for starting a new execution phase."""
731
+ return self.current_action.phase(name)
214
732
 
215
733
  @contextlib.contextmanager
216
- def track(self, action: Action) -> Iterator[ActionInvocation]:
734
+ def track_action(self, action: Action) -> Iterator[ActionInvocation]:
217
735
  """Track the execution of an action."""
218
- new_invocation = ActionInvocation(pg.maybe_ref(action))
219
- with pg.notify_on_change(False):
220
- self.current_invocation.execution.append(new_invocation)
221
- self._invocation_stack.append(new_invocation)
736
+ if not self.root.execution.has_started:
737
+ self.root.start()
738
+
739
+ invocation = ActionInvocation(pg.maybe_ref(action))
740
+ parent_action = self._current_action
741
+ parent_action.current_phase.append(invocation)
222
742
 
223
743
  try:
224
- yield new_invocation
744
+ self._current_action = invocation
745
+ # Start the execution of the current action.
746
+ self._current_action.start()
747
+ yield invocation
225
748
  finally:
226
- assert self._invocation_stack
227
- invocation = self._invocation_stack.pop(-1)
228
- invocation.rebind(
229
- result=action.result, skip_notification=True, raise_on_no_change=False
230
- )
231
- assert invocation.action is action, (invocation.action, action)
232
- assert self._invocation_stack, self._invocation_stack
233
-
234
- if len(self._invocation_stack) == 1:
235
- self.root_invocation.rebind(
236
- result=invocation.result,
237
- skip_notification=True,
238
- raise_on_no_change=False
749
+ # Stop the execution of the current action.
750
+ self._current_action.end(action.result, action.result_metadata)
751
+ self._current_action = parent_action
752
+ if parent_action is self.root:
753
+ parent_action.end(
754
+ result=action.result, result_metadata=action.result_metadata,
239
755
  )
240
756
 
757
+ @contextlib.contextmanager
758
+ def track_queries(
759
+ self,
760
+ phase: str | None = None
761
+ ) -> Iterator[list[lf_structured.QueryInvocation]]:
762
+ """Tracks `lf.query` made within the context.
763
+
764
+ Args:
765
+ phase: The name of a new phase to track the queries in. If not provided,
766
+ the queries will be tracked in the parent phase.
767
+
768
+ Yields:
769
+ A list of `lf.QueryInvocation` objects, each for a single `lf.query`
770
+ call.
771
+ """
772
+ with self.phase(phase) if phase else contextlib.nullcontext():
773
+ with lf_structured.track_queries(include_child_scopes=False) as queries:
774
+ try:
775
+ yield queries
776
+ finally:
777
+ self._current_action.current_phase.extend(queries)
778
+
241
779
  def query(
242
780
  self,
243
781
  prompt: Union[str, lf.Template, Any],
@@ -250,9 +788,37 @@ class Session(pg.Object):
250
788
  examples: list[lf_structured.MappingExample] | None = None,
251
789
  **kwargs
252
790
  ) -> Any:
253
- """Calls `lf.query` and associates it with the current invocation."""
254
- with lf_structured.track_queries() as queries:
255
- output = lf_structured.query(
791
+ """Calls `lf.query` and associates it with the current invocation.
792
+
793
+ The following code are equivalent:
794
+
795
+ Code 1:
796
+ ```
797
+ session.query(...)
798
+ ```
799
+
800
+ Code 2:
801
+ ```
802
+ with session.track_queries() as queries:
803
+ output = lf.query(...)
804
+ ```
805
+ The former is preferred when `lf.query` is directly called by the action.
806
+ If `lf.query` is called by a function that does not have access to the
807
+ session, the latter should be used.
808
+
809
+ Args:
810
+ prompt: The prompt to query.
811
+ schema: The schema to use for the query.
812
+ default: The default value to return if the query fails.
813
+ lm: The language model to use for the query.
814
+ examples: The examples to use for the query.
815
+ **kwargs: Additional keyword arguments to pass to `lf.query`.
816
+
817
+ Returns:
818
+ The result of the query.
819
+ """
820
+ with self.track_queries():
821
+ return lf_structured.query(
256
822
  prompt,
257
823
  schema=schema,
258
824
  default=default,
@@ -260,17 +826,16 @@ class Session(pg.Object):
260
826
  examples=examples,
261
827
  **kwargs
262
828
  )
263
- with pg.notify_on_change(False):
264
- self.current_invocation.execution.extend(queries)
265
- return output
266
829
 
267
830
  def _log(self, level: lf.logging.LogLevel, message: str, **kwargs):
268
- with pg.notify_on_change(False):
269
- self.current_invocation.execution.append(
270
- lf.logging.log(
271
- level, message, indent=len(self._invocation_stack) - 1, **kwargs
272
- )
273
- )
831
+ self._current_action.current_phase.append(
832
+ lf.logging.LogEntry(
833
+ level=level,
834
+ time=datetime.datetime.now(),
835
+ message=message,
836
+ metadata=kwargs,
837
+ )
838
+ )
274
839
 
275
840
  def debug(self, message: str, **kwargs):
276
841
  """Logs a debug message to the session."""
@@ -296,5 +861,26 @@ class Session(pg.Object):
296
861
  """Returns the session as a message."""
297
862
  return lf.AIMessage(
298
863
  'Agentic task session.',
299
- result=self.root_invocation
864
+ result=self.root
865
+ )
866
+
867
+ #
868
+ # HTML views.
869
+ #
870
+
871
+ def _html_tree_view_content(
872
+ self,
873
+ *,
874
+ view: pg.views.html.HtmlTreeView,
875
+ **kwargs
876
+ ):
877
+ return view.content(self.root, **kwargs)
878
+
879
+ @classmethod
880
+ def _html_tree_view_config(cls):
881
+ config = super()._html_tree_view_config()
882
+ config.update(
883
+ enable_summary_tooltip=False,
884
+ enable_key_tooltip=False,
300
885
  )
886
+ return config