langfun 0.1.2.dev202412070804__py3-none-any.whl → 0.1.2.dev202412110804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/agentic/action.py +716 -132
- langfun/core/agentic/action_test.py +78 -42
- langfun/core/eval/v2/evaluation.py +3 -0
- langfun/core/llms/__init__.py +1 -0
- langfun/core/llms/anthropic.py +1 -1
- langfun/core/llms/google_genai.py +31 -41
- langfun/core/llms/vertexai.py +10 -9
- langfun/core/message.py +1 -2
- langfun/core/message_test.py +1 -2
- langfun/core/structured/function_generation.py +26 -14
- langfun/core/structured/function_generation_test.py +30 -0
- langfun/core/structured/prompting.py +105 -5
- langfun/core/structured/prompting_test.py +12 -0
- langfun/core/structured/schema.py +1 -1
- {langfun-0.1.2.dev202412070804.dist-info → langfun-0.1.2.dev202412110804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202412070804.dist-info → langfun-0.1.2.dev202412110804.dist-info}/RECORD +19 -19
- {langfun-0.1.2.dev202412070804.dist-info → langfun-0.1.2.dev202412110804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202412070804.dist-info → langfun-0.1.2.dev202412110804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202412070804.dist-info → langfun-0.1.2.dev202412110804.dist-info}/top_level.txt +0 -0
langfun/core/agentic/action.py
CHANGED
@@ -15,7 +15,11 @@
|
|
15
15
|
|
16
16
|
import abc
|
17
17
|
import contextlib
|
18
|
-
|
18
|
+
import datetime
|
19
|
+
import time
|
20
|
+
|
21
|
+
import typing
|
22
|
+
from typing import Annotated, Any, ContextManager, Iterable, Iterator, Optional, Type, Union
|
19
23
|
import langfun.core as lf
|
20
24
|
from langfun.core import structured as lf_structured
|
21
25
|
import pyglove as pg
|
@@ -26,24 +30,413 @@ class Action(pg.Object):
|
|
26
30
|
|
27
31
|
def _on_bound(self):
|
28
32
|
super()._on_bound()
|
33
|
+
self._session = None
|
29
34
|
self._result = None
|
35
|
+
self._result_metadata = {}
|
36
|
+
|
37
|
+
@property
|
38
|
+
def session(self) -> Optional['Session']:
|
39
|
+
"""Returns the session started by this action."""
|
40
|
+
return self._session
|
30
41
|
|
31
42
|
@property
|
32
43
|
def result(self) -> Any:
|
33
44
|
"""Returns the result of the action."""
|
34
45
|
return self._result
|
35
46
|
|
47
|
+
@property
|
48
|
+
def result_metadata(self) -> dict[str, Any] | None:
|
49
|
+
"""Returns the metadata associated with the result from previous call."""
|
50
|
+
return self._result_metadata
|
51
|
+
|
36
52
|
def __call__(
|
37
|
-
self,
|
53
|
+
self,
|
54
|
+
session: Optional['Session'] = None,
|
55
|
+
*,
|
56
|
+
show_progress: bool = True,
|
57
|
+
**kwargs) -> Any:
|
38
58
|
"""Executes the action."""
|
39
|
-
|
40
|
-
|
41
|
-
|
59
|
+
new_session = session is None
|
60
|
+
if new_session:
|
61
|
+
session = Session()
|
62
|
+
if show_progress:
|
63
|
+
lf.console.display(pg.view(session, name='agent_session'))
|
64
|
+
|
65
|
+
with session.track_action(self):
|
66
|
+
result = self.call(session=session, **kwargs)
|
67
|
+
|
68
|
+
# For the top-level action, we store the session in the metadata.
|
69
|
+
if new_session:
|
70
|
+
self._session = session
|
71
|
+
self._result = result
|
72
|
+
self._result_metadata = session.current_action.result_metadata
|
42
73
|
return self._result
|
43
74
|
|
44
75
|
@abc.abstractmethod
|
45
76
|
def call(self, session: 'Session', **kwargs) -> Any:
|
46
|
-
"""
|
77
|
+
"""Calls the action.
|
78
|
+
|
79
|
+
Args:
|
80
|
+
session: The session to use for the action.
|
81
|
+
**kwargs: Additional keyword arguments to pass to the action.
|
82
|
+
|
83
|
+
Returns:
|
84
|
+
The result of the action.
|
85
|
+
"""
|
86
|
+
|
87
|
+
|
88
|
+
# Type definition for traced item during execution.
|
89
|
+
TracedItem = Union[
|
90
|
+
lf_structured.QueryInvocation,
|
91
|
+
'ActionInvocation',
|
92
|
+
'ExecutionTrace',
|
93
|
+
# NOTE(daiyip): Consider remove log entry once we migrate existing agents.
|
94
|
+
lf.logging.LogEntry,
|
95
|
+
]
|
96
|
+
|
97
|
+
|
98
|
+
class ExecutionTrace(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
99
|
+
"""Trace of the execution of an action."""
|
100
|
+
|
101
|
+
name: Annotated[
|
102
|
+
str | None,
|
103
|
+
(
|
104
|
+
'The name of the execution trace. If None, the trace is unnamed, '
|
105
|
+
'which is the case for the top-level trace of an action. An '
|
106
|
+
'execution trace could have sub-traces, called phases, which are '
|
107
|
+
'created and named by `session.phase()` context manager.'
|
108
|
+
)
|
109
|
+
] = None
|
110
|
+
|
111
|
+
start_time: Annotated[
|
112
|
+
float | None,
|
113
|
+
'The start time of the execution. If None, the execution is not started.'
|
114
|
+
] = None
|
115
|
+
|
116
|
+
end_time: Annotated[
|
117
|
+
float | None,
|
118
|
+
'The end time of the execution. If None, the execution is not ended.'
|
119
|
+
] = None
|
120
|
+
|
121
|
+
items: Annotated[
|
122
|
+
list[TracedItem],
|
123
|
+
'All tracked execution items in the sequence.'
|
124
|
+
] = []
|
125
|
+
|
126
|
+
def _on_bound(self):
|
127
|
+
super()._on_bound()
|
128
|
+
self._usage_summary = lf.UsageSummary()
|
129
|
+
for item in self.items:
|
130
|
+
if hasattr(item, 'usage_summary'):
|
131
|
+
self._usage_summary.merge(item.usage_summary)
|
132
|
+
|
133
|
+
self._tab_control = None
|
134
|
+
self._time_badge = None
|
135
|
+
|
136
|
+
def start(self) -> None:
|
137
|
+
assert self.start_time is None, 'Execution already started.'
|
138
|
+
self.rebind(start_time=time.time(), skip_notification=True)
|
139
|
+
if self._time_badge is not None:
|
140
|
+
self._time_badge.update(
|
141
|
+
'Starting',
|
142
|
+
add_class=['running'],
|
143
|
+
remove_class=['not-started'],
|
144
|
+
)
|
145
|
+
|
146
|
+
def stop(self) -> None:
|
147
|
+
assert self.end_time is None, 'Execution already stopped.'
|
148
|
+
self.rebind(end_time=time.time(), skip_notification=True)
|
149
|
+
if self._time_badge is not None:
|
150
|
+
self._time_badge.update(
|
151
|
+
f'{int(self.elapse)} seconds',
|
152
|
+
add_class=['finished'],
|
153
|
+
remove_class=['running'],
|
154
|
+
)
|
155
|
+
|
156
|
+
@property
|
157
|
+
def has_started(self) -> bool:
|
158
|
+
return self.start_time is not None
|
159
|
+
|
160
|
+
@property
|
161
|
+
def has_stopped(self) -> bool:
|
162
|
+
return self.end_time is not None
|
163
|
+
|
164
|
+
@property
|
165
|
+
def elapse(self) -> float:
|
166
|
+
"""Returns the elapsed time of the execution."""
|
167
|
+
if self.start_time is None:
|
168
|
+
return 0.0
|
169
|
+
if self.end_time is None:
|
170
|
+
return time.time() - self.start_time
|
171
|
+
return self.end_time - self.start_time
|
172
|
+
|
173
|
+
@property
|
174
|
+
def queries(self) -> list[lf_structured.QueryInvocation]:
|
175
|
+
"""Returns queries from the sequence."""
|
176
|
+
return self._child_items(lf_structured.QueryInvocation)
|
177
|
+
|
178
|
+
@property
|
179
|
+
def actions(self) -> list['ActionInvocation']:
|
180
|
+
"""Returns action invocations from the sequence."""
|
181
|
+
return self._child_items(ActionInvocation)
|
182
|
+
|
183
|
+
@property
|
184
|
+
def logs(self) -> list[lf.logging.LogEntry]:
|
185
|
+
"""Returns logs from the sequence."""
|
186
|
+
return self._child_items(lf.logging.LogEntry)
|
187
|
+
|
188
|
+
@property
|
189
|
+
def all_queries(self) -> list[lf_structured.QueryInvocation]:
|
190
|
+
"""Returns all queries from current trace and its child execution items."""
|
191
|
+
return self._all_child_items(lf_structured.QueryInvocation)
|
192
|
+
|
193
|
+
@property
|
194
|
+
def all_logs(self) -> list[lf.logging.LogEntry]:
|
195
|
+
"""Returns all logs from current trace and its child execution items."""
|
196
|
+
return self._all_child_items(lf.logging.LogEntry)
|
197
|
+
|
198
|
+
def _child_items(self, item_cls: Type[Any]) -> list[Any]:
|
199
|
+
child_items = []
|
200
|
+
for item in self.items:
|
201
|
+
if isinstance(item, item_cls):
|
202
|
+
child_items.append(item)
|
203
|
+
elif isinstance(item, ExecutionTrace):
|
204
|
+
child_items.extend(item._child_items(item_cls)) # pylint: disable=protected-access
|
205
|
+
return child_items
|
206
|
+
|
207
|
+
def _all_child_items(self, item_cls: Type[Any]) -> list[Any]:
|
208
|
+
child_items = []
|
209
|
+
for item in self.items:
|
210
|
+
if isinstance(item, item_cls):
|
211
|
+
child_items.append(item)
|
212
|
+
elif isinstance(item, ActionInvocation):
|
213
|
+
child_items.extend(item.execution._all_child_items(item_cls)) # pylint: disable=protected-access
|
214
|
+
elif isinstance(item, ExecutionTrace):
|
215
|
+
child_items.extend(item._all_child_items(item_cls)) # pylint: disable=protected-access
|
216
|
+
return child_items
|
217
|
+
|
218
|
+
def append(self, item: TracedItem) -> None:
|
219
|
+
"""Appends an item to the sequence."""
|
220
|
+
with pg.notify_on_change(False):
|
221
|
+
self.items.append(item)
|
222
|
+
|
223
|
+
if isinstance(item, lf_structured.QueryInvocation):
|
224
|
+
current_invocation = self
|
225
|
+
while current_invocation is not None:
|
226
|
+
current_invocation.usage_summary.merge(item.usage_summary)
|
227
|
+
current_invocation = typing.cast(
|
228
|
+
ExecutionTrace,
|
229
|
+
current_invocation.sym_ancestor(
|
230
|
+
lambda x: isinstance(x, ExecutionTrace)
|
231
|
+
)
|
232
|
+
)
|
233
|
+
|
234
|
+
if self._tab_control is not None:
|
235
|
+
self._tab_control.append(self._execution_item_tab(item))
|
236
|
+
|
237
|
+
if (self._time_badge is not None
|
238
|
+
and not isinstance(item, lf.logging.LogEntry)):
|
239
|
+
sub_task_label = self._execution_item_label(item)
|
240
|
+
self._time_badge.update(
|
241
|
+
pg.Html.element(
|
242
|
+
'span',
|
243
|
+
[
|
244
|
+
'Running',
|
245
|
+
pg.views.html.controls.Badge(
|
246
|
+
sub_task_label.text,
|
247
|
+
tooltip=sub_task_label.tooltip,
|
248
|
+
css_classes=['task-in-progress']
|
249
|
+
)
|
250
|
+
]
|
251
|
+
),
|
252
|
+
add_class=['running'],
|
253
|
+
remove_class=['not-started'],
|
254
|
+
)
|
255
|
+
|
256
|
+
def extend(self, items: Iterable[TracedItem]) -> None:
|
257
|
+
"""Extends the sequence with a list of items."""
|
258
|
+
for item in items:
|
259
|
+
self.append(item)
|
260
|
+
|
261
|
+
@property
|
262
|
+
def usage_summary(self) -> lf.UsageSummary:
|
263
|
+
"""Returns the usage summary of the action."""
|
264
|
+
return self._usage_summary
|
265
|
+
|
266
|
+
#
|
267
|
+
# HTML views.
|
268
|
+
#
|
269
|
+
|
270
|
+
def _html_tree_view_summary(
|
271
|
+
self,
|
272
|
+
*,
|
273
|
+
name: str | None = None,
|
274
|
+
extra_flags: dict[str, Any] | None = None,
|
275
|
+
view: pg.views.html.HtmlTreeView, **kwargs
|
276
|
+
):
|
277
|
+
extra_flags = extra_flags or {}
|
278
|
+
interactive = extra_flags.get('interactive', True)
|
279
|
+
def time_badge():
|
280
|
+
if not self.has_started:
|
281
|
+
label = '(Not started)'
|
282
|
+
css_class = 'not-started'
|
283
|
+
elif not self.has_stopped:
|
284
|
+
label = 'Starting'
|
285
|
+
css_class = 'running'
|
286
|
+
else:
|
287
|
+
label = f'{int(self.elapse)} seconds'
|
288
|
+
css_class = 'finished'
|
289
|
+
return pg.views.html.controls.Badge(
|
290
|
+
label,
|
291
|
+
css_classes=['execution-time', css_class],
|
292
|
+
interactive=interactive,
|
293
|
+
)
|
294
|
+
time_badge = time_badge()
|
295
|
+
if interactive:
|
296
|
+
self._time_badge = time_badge
|
297
|
+
title = pg.Html.element(
|
298
|
+
'div',
|
299
|
+
[
|
300
|
+
'ExecutionTrace',
|
301
|
+
time_badge,
|
302
|
+
],
|
303
|
+
css_classes=['execution-trace-title'],
|
304
|
+
)
|
305
|
+
kwargs.pop('title', None)
|
306
|
+
kwargs['enable_summary_tooltip'] = False
|
307
|
+
kwargs['enable_key_tooltip'] = False
|
308
|
+
return view.summary(
|
309
|
+
self,
|
310
|
+
name=name,
|
311
|
+
title=title,
|
312
|
+
extra_flags=extra_flags,
|
313
|
+
**kwargs
|
314
|
+
)
|
315
|
+
|
316
|
+
def _html_tree_view_content(self, **kwargs):
|
317
|
+
del kwargs
|
318
|
+
self._tab_control = pg.views.html.controls.TabControl(
|
319
|
+
[self._execution_item_tab(item) for item in self.items],
|
320
|
+
tab_position='left'
|
321
|
+
)
|
322
|
+
return pg.Html.element(
|
323
|
+
'div',
|
324
|
+
[
|
325
|
+
self._tab_control
|
326
|
+
]
|
327
|
+
)
|
328
|
+
|
329
|
+
def _execution_item_tab(self, item: TracedItem) -> pg.views.html.controls.Tab:
|
330
|
+
if isinstance(item, ActionInvocation):
|
331
|
+
css_class = 'action'
|
332
|
+
elif isinstance(item, lf_structured.QueryInvocation):
|
333
|
+
css_class = 'query'
|
334
|
+
elif isinstance(item, lf.logging.LogEntry):
|
335
|
+
css_class = 'log'
|
336
|
+
elif isinstance(item, ExecutionTrace):
|
337
|
+
css_class = 'phase'
|
338
|
+
else:
|
339
|
+
raise ValueError(f'Unsupported item type: {type(item)}')
|
340
|
+
|
341
|
+
return pg.views.html.controls.Tab(
|
342
|
+
label=self._execution_item_label(item),
|
343
|
+
content=pg.view(item),
|
344
|
+
css_classes=[css_class]
|
345
|
+
)
|
346
|
+
|
347
|
+
def _execution_item_label(
|
348
|
+
self, item: TracedItem
|
349
|
+
) -> pg.views.html.controls.Label:
|
350
|
+
if isinstance(item, ActionInvocation):
|
351
|
+
return pg.views.html.controls.Label(
|
352
|
+
item.action.__class__.__name__,
|
353
|
+
tooltip=pg.format(
|
354
|
+
item.action,
|
355
|
+
verbose=False,
|
356
|
+
hide_default_values=True,
|
357
|
+
max_str_len=80,
|
358
|
+
max_bytes_len=20,
|
359
|
+
),
|
360
|
+
)
|
361
|
+
elif isinstance(item, lf_structured.QueryInvocation):
|
362
|
+
schema_title = 'str'
|
363
|
+
if item.schema:
|
364
|
+
schema_title = lf_structured.annotation(item.schema.spec)
|
365
|
+
return pg.views.html.controls.Label(
|
366
|
+
schema_title,
|
367
|
+
tooltip=(
|
368
|
+
pg.format(
|
369
|
+
item.input,
|
370
|
+
verbose=False,
|
371
|
+
hide_default_values=True,
|
372
|
+
max_str_len=80,
|
373
|
+
max_bytes_len=20,
|
374
|
+
)
|
375
|
+
),
|
376
|
+
)
|
377
|
+
elif isinstance(item, lf.logging.LogEntry):
|
378
|
+
return pg.views.html.controls.Label(
|
379
|
+
'Log',
|
380
|
+
tooltip=item.message,
|
381
|
+
)
|
382
|
+
elif isinstance(item, ExecutionTrace):
|
383
|
+
return pg.views.html.controls.Label(
|
384
|
+
item.name or 'Phase'
|
385
|
+
)
|
386
|
+
else:
|
387
|
+
raise ValueError(f'Unsupported item type: {type(item)}')
|
388
|
+
|
389
|
+
def _html_tree_view_css_styles(self) -> list[str]:
|
390
|
+
return super()._html_tree_view_css_styles() + [
|
391
|
+
"""
|
392
|
+
.tab-button.action > ::before {
|
393
|
+
content: "A";
|
394
|
+
font-weight: bold;
|
395
|
+
color: red;
|
396
|
+
padding: 10px;
|
397
|
+
}
|
398
|
+
.tab-button.phase > ::before {
|
399
|
+
content: "P";
|
400
|
+
font-weight: bold;
|
401
|
+
color: purple;
|
402
|
+
padding: 10px;
|
403
|
+
}
|
404
|
+
.tab-button.query > ::before {
|
405
|
+
content: "Q";
|
406
|
+
font-weight: bold;
|
407
|
+
color: orange;
|
408
|
+
padding: 10px;
|
409
|
+
}
|
410
|
+
.tab-button.log > ::before {
|
411
|
+
content: "L";
|
412
|
+
font-weight: bold;
|
413
|
+
color: green;
|
414
|
+
padding: 10px;
|
415
|
+
}
|
416
|
+
.details.execution-trace, .details.action-invocation {
|
417
|
+
border: 1px solid #eee;
|
418
|
+
}
|
419
|
+
.execution-trace-title {
|
420
|
+
display: inline-block;
|
421
|
+
}
|
422
|
+
.badge.execution-time {
|
423
|
+
margin-left: 5px;
|
424
|
+
}
|
425
|
+
.execution-time.running {
|
426
|
+
background-color: lavender;
|
427
|
+
font-weight: normal;
|
428
|
+
}
|
429
|
+
.execution-time.finished {
|
430
|
+
background-color: aliceblue;
|
431
|
+
font-weight: bold;
|
432
|
+
}
|
433
|
+
.badge.task-in-progress {
|
434
|
+
margin-left: 5px;
|
435
|
+
background-color: azure;
|
436
|
+
font-weight: bold;
|
437
|
+
}
|
438
|
+
"""
|
439
|
+
]
|
47
440
|
|
48
441
|
|
49
442
|
class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
@@ -55,131 +448,246 @@ class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
55
448
|
'The result of the action.'
|
56
449
|
] = None
|
57
450
|
|
451
|
+
result_metadata: Annotated[
|
452
|
+
dict[str, Any],
|
453
|
+
'The metadata returned by the action.'
|
454
|
+
] = {}
|
455
|
+
|
58
456
|
execution: Annotated[
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
'ActionInvocation',
|
63
|
-
lf.logging.LogEntry
|
64
|
-
]
|
65
|
-
],
|
66
|
-
'Execution execution.'
|
67
|
-
] = []
|
457
|
+
ExecutionTrace,
|
458
|
+
'The execution sequence of the action.'
|
459
|
+
] = ExecutionTrace()
|
68
460
|
|
69
461
|
# Allow symbolic assignment without `rebind`.
|
70
462
|
allow_symbolic_assignment = True
|
71
463
|
|
464
|
+
def _on_bound(self):
|
465
|
+
super()._on_bound()
|
466
|
+
self._current_phase = self.execution
|
467
|
+
self._result_badge = None
|
468
|
+
self._result_metadata_badge = None
|
469
|
+
|
470
|
+
@property
|
471
|
+
def current_phase(self) -> ExecutionTrace:
|
472
|
+
"""Returns the current execution phase."""
|
473
|
+
return self._current_phase
|
474
|
+
|
475
|
+
@contextlib.contextmanager
|
476
|
+
def phase(self, name: str) -> Iterator[ExecutionTrace]:
|
477
|
+
"""Context manager for starting a new execution phase."""
|
478
|
+
phase = ExecutionTrace(name=name)
|
479
|
+
phase.start()
|
480
|
+
parent_phase = self._current_phase
|
481
|
+
self._current_phase.append(phase)
|
482
|
+
self._current_phase = phase
|
483
|
+
try:
|
484
|
+
yield phase
|
485
|
+
finally:
|
486
|
+
phase.stop()
|
487
|
+
self._current_phase = parent_phase
|
488
|
+
|
72
489
|
@property
|
73
490
|
def logs(self) -> list[lf.logging.LogEntry]:
|
74
|
-
"""Returns logs from execution sequence."""
|
75
|
-
return
|
491
|
+
"""Returns immediate child logs from execution sequence."""
|
492
|
+
return self.execution.logs
|
76
493
|
|
77
494
|
@property
|
78
|
-
def
|
79
|
-
"""Returns child action invocations."""
|
80
|
-
return
|
495
|
+
def actions(self) -> list['ActionInvocation']:
|
496
|
+
"""Returns immediate child action invocations."""
|
497
|
+
return self.execution.actions
|
81
498
|
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
499
|
+
@property
|
500
|
+
def queries(self) -> list[lf_structured.QueryInvocation]:
|
501
|
+
"""Returns immediate queries made by the action."""
|
502
|
+
return self.execution.queries
|
503
|
+
|
504
|
+
@property
|
505
|
+
def all_queries(self) -> list[lf_structured.QueryInvocation]:
|
506
|
+
"""Returns all queries made by the action and its child execution items."""
|
507
|
+
return self.execution.all_queries
|
508
|
+
|
509
|
+
@property
|
510
|
+
def all_logs(self) -> list[lf.logging.LogEntry]:
|
511
|
+
"""Returns all logs made by the action and its child execution items."""
|
512
|
+
return self.execution.all_logs
|
513
|
+
|
514
|
+
@property
|
515
|
+
def usage_summary(self) -> lf.UsageSummary:
|
516
|
+
"""Returns the usage summary of the action."""
|
517
|
+
return self.execution.usage_summary
|
518
|
+
|
519
|
+
def start(self) -> None:
|
520
|
+
"""Starts the execution of the action."""
|
521
|
+
self.execution.start()
|
522
|
+
|
523
|
+
def end(self, result: Any, result_metadata: dict[str, Any]) -> None:
|
524
|
+
"""Ends the execution of the action with result and metadata."""
|
525
|
+
self.execution.stop()
|
526
|
+
self.rebind(
|
527
|
+
result=result,
|
528
|
+
result_metadata=result_metadata,
|
529
|
+
skip_notification=True,
|
530
|
+
raise_on_no_change=False
|
531
|
+
)
|
532
|
+
if self._result_badge is not None:
|
533
|
+
self._result_badge.update(
|
534
|
+
self._result_badge_label(result),
|
535
|
+
tooltip=self._result_badge_tooltip(result),
|
536
|
+
add_class=['ready'],
|
537
|
+
remove_class=['not-ready'],
|
538
|
+
)
|
539
|
+
if self._result_metadata_badge is not None:
|
540
|
+
result_metadata = dict(result_metadata)
|
541
|
+
result_metadata.pop('session', None)
|
542
|
+
self._result_metadata_badge.update(
|
543
|
+
'{...}',
|
544
|
+
tooltip=self._result_metadata_badge_tooltip(result_metadata),
|
545
|
+
add_class=['ready'],
|
546
|
+
remove_class=['not-ready'],
|
547
|
+
)
|
548
|
+
|
549
|
+
#
|
550
|
+
# HTML views.
|
551
|
+
#
|
93
552
|
|
94
553
|
def _html_tree_view_summary(
|
95
554
|
self, *, view: pg.views.html.HtmlTreeView, **kwargs
|
96
555
|
):
|
97
|
-
|
98
|
-
return None
|
99
|
-
kwargs.pop('title')
|
100
|
-
return view.summary(
|
101
|
-
self,
|
102
|
-
title=view.render(
|
103
|
-
self.action, name='action', collapse_level=0,
|
104
|
-
css_classes='invocation-title',
|
105
|
-
),
|
106
|
-
**kwargs
|
107
|
-
)
|
556
|
+
return None
|
108
557
|
|
109
558
|
def _html_tree_view_content(
|
110
559
|
self,
|
111
560
|
*,
|
112
|
-
root_path: pg.KeyPath | None = None,
|
113
|
-
collapse_level: int | None = None,
|
114
561
|
view: pg.views.html.HtmlTreeView,
|
562
|
+
extra_flags: dict[str, Any] | None = None,
|
115
563
|
**kwargs
|
116
564
|
):
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
565
|
+
extra_flags = extra_flags or {}
|
566
|
+
interactive = extra_flags.get('interactive', True)
|
567
|
+
if (isinstance(self.action, RootAction)
|
568
|
+
and self.execution.has_stopped
|
569
|
+
and len(self.execution.items) == 1):
|
570
|
+
return view.content(self.execution.items[0], extra_flags=extra_flags)
|
571
|
+
|
572
|
+
def _result_badge():
|
573
|
+
if not self.execution.has_stopped:
|
574
|
+
label = '(n/a)'
|
575
|
+
tooltip = 'Result is not available yet.'
|
576
|
+
css_class = 'not-ready'
|
577
|
+
else:
|
578
|
+
label = self._result_badge_label(self.result)
|
579
|
+
tooltip = self._result_badge_tooltip(self.result)
|
580
|
+
css_class = 'ready'
|
581
|
+
return pg.views.html.controls.Badge(
|
582
|
+
label,
|
583
|
+
tooltip=tooltip,
|
584
|
+
css_classes=['invocation-result', css_class],
|
585
|
+
interactive=interactive,
|
134
586
|
)
|
135
587
|
|
136
|
-
def
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
],
|
588
|
+
def _result_metadata_badge():
|
589
|
+
if not self.execution.has_stopped:
|
590
|
+
label = '(n/a)'
|
591
|
+
tooltip = 'Result metadata is not available yet.'
|
592
|
+
css_class = 'not-ready'
|
593
|
+
else:
|
594
|
+
label = '{...}' if self.result_metadata else '(empty)'
|
595
|
+
tooltip = self._result_metadata_badge_tooltip(self.result_metadata)
|
596
|
+
css_class = 'ready'
|
597
|
+
return pg.views.html.controls.Badge(
|
598
|
+
label,
|
599
|
+
tooltip=tooltip,
|
600
|
+
css_classes=['invocation-result-metadata', css_class],
|
601
|
+
interactive=interactive,
|
149
602
|
)
|
150
603
|
|
151
|
-
|
152
|
-
|
604
|
+
result_badge = _result_badge()
|
605
|
+
result_metadata_badge = _result_metadata_badge()
|
606
|
+
if interactive:
|
607
|
+
self._result_badge = result_badge
|
608
|
+
self._result_metadata_badge = result_metadata_badge
|
609
|
+
|
153
610
|
return pg.Html.element(
|
154
611
|
'div',
|
155
612
|
[
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
613
|
+
pg.Html.element(
|
614
|
+
'div',
|
615
|
+
[
|
616
|
+
view.render(
|
617
|
+
self.usage_summary, extra_flags=dict(as_badge=True)
|
618
|
+
),
|
619
|
+
result_badge,
|
620
|
+
result_metadata_badge,
|
621
|
+
],
|
622
|
+
css_classes=['invocation-badge-container'],
|
162
623
|
),
|
163
|
-
|
164
|
-
|
624
|
+
view.render( # pylint: disable=g-long-ternary
|
625
|
+
self.action,
|
626
|
+
name='action',
|
627
|
+
collapse_level=None,
|
628
|
+
root_path=self.action.sym_path,
|
629
|
+
css_classes='invocation-title',
|
630
|
+
enable_summary_tooltip=False,
|
631
|
+
) if not isinstance(self.action, RootAction) else None,
|
632
|
+
view.render(self.execution, name='execution'),
|
165
633
|
]
|
166
634
|
)
|
167
635
|
|
636
|
+
def _result_badge_label(self, result: Any) -> str:
|
637
|
+
label = pg.format(
|
638
|
+
result, python_format=True, verbose=False
|
639
|
+
)
|
640
|
+
if len(label) > 40:
|
641
|
+
if isinstance(result, str):
|
642
|
+
label = label[:40] + '...'
|
643
|
+
else:
|
644
|
+
label = f'{result.__class__.__name__}(...)'
|
645
|
+
return label
|
646
|
+
|
647
|
+
def _result_badge_tooltip(self, result: Any) -> pg.Html:
|
648
|
+
return typing.cast(
|
649
|
+
pg.Html,
|
650
|
+
pg.view(
|
651
|
+
result, name='result',
|
652
|
+
collapse_level=None,
|
653
|
+
enable_summary_tooltip=False,
|
654
|
+
enable_key_tooltip=False,
|
655
|
+
)
|
656
|
+
)
|
657
|
+
|
658
|
+
def _result_metadata_badge_tooltip(
|
659
|
+
self, result_metadata: dict[str, Any]
|
660
|
+
) -> pg.Html:
|
661
|
+
return typing.cast(
|
662
|
+
pg.Html,
|
663
|
+
pg.view(
|
664
|
+
result_metadata,
|
665
|
+
name='result_metadata',
|
666
|
+
collapse_level=None,
|
667
|
+
enable_summary_tooltip=False,
|
668
|
+
)
|
669
|
+
)
|
670
|
+
|
168
671
|
@classmethod
|
169
672
|
def _html_tree_view_css_styles(cls) -> list[str]:
|
170
673
|
return super()._html_tree_view_css_styles() + [
|
171
674
|
"""
|
172
|
-
|
173
|
-
display:
|
174
|
-
|
175
|
-
border: 1px solid white;
|
675
|
+
.invocation-badge-container {
|
676
|
+
display: flex;
|
677
|
+
padding-bottom: 5px;
|
176
678
|
}
|
177
|
-
|
178
|
-
|
679
|
+
.invocation-badge-container > .label-container {
|
680
|
+
margin-right: 3px;
|
179
681
|
}
|
180
|
-
|
181
|
-
|
182
|
-
|
682
|
+
.invocation-result.ready {
|
683
|
+
background-color: lightcyan;
|
684
|
+
}
|
685
|
+
.invocation-result-metadata.ready {
|
686
|
+
background-color: lightyellow;
|
687
|
+
}
|
688
|
+
details.pyglove.invocation-title {
|
689
|
+
background-color: aliceblue;
|
690
|
+
border: 0px solid white;
|
183
691
|
}
|
184
692
|
"""
|
185
693
|
]
|
@@ -192,52 +700,80 @@ class RootAction(Action):
|
|
192
700
|
raise NotImplementedError('Shall not be called.')
|
193
701
|
|
194
702
|
|
195
|
-
class Session(pg.Object):
|
703
|
+
class Session(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
196
704
|
"""Session for performing an agentic task."""
|
197
705
|
|
198
|
-
|
706
|
+
root: ActionInvocation = ActionInvocation(RootAction())
|
199
707
|
|
200
708
|
def _on_bound(self):
|
201
709
|
super()._on_bound()
|
202
|
-
self.
|
710
|
+
self._current_action = self.root
|
203
711
|
|
204
712
|
@property
|
205
713
|
def final_result(self) -> Any:
|
206
714
|
"""Returns the final result of the session."""
|
207
|
-
return self.
|
715
|
+
return self.root.result
|
208
716
|
|
209
717
|
@property
|
210
|
-
def
|
718
|
+
def current_action(self) -> ActionInvocation:
|
211
719
|
"""Returns the current invocation."""
|
212
|
-
|
213
|
-
|
720
|
+
return self._current_action
|
721
|
+
|
722
|
+
def add_metadata(self, **kwargs: Any) -> None:
|
723
|
+
"""Adds metadata to the current invocation."""
|
724
|
+
with pg.notify_on_change(False):
|
725
|
+
self._current_action.result_metadata.update(kwargs)
|
726
|
+
|
727
|
+
def phase(self, name: str) -> ContextManager[ExecutionTrace]:
|
728
|
+
"""Context manager for starting a new execution phase."""
|
729
|
+
return self.current_action.phase(name)
|
214
730
|
|
215
731
|
@contextlib.contextmanager
|
216
|
-
def
|
732
|
+
def track_action(self, action: Action) -> Iterator[ActionInvocation]:
|
217
733
|
"""Track the execution of an action."""
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
734
|
+
if not self.root.execution.has_started:
|
735
|
+
self.root.start()
|
736
|
+
|
737
|
+
invocation = ActionInvocation(pg.maybe_ref(action))
|
738
|
+
parent_action = self._current_action
|
739
|
+
parent_action.current_phase.append(invocation)
|
222
740
|
|
223
741
|
try:
|
224
|
-
|
742
|
+
self._current_action = invocation
|
743
|
+
# Start the execution of the current action.
|
744
|
+
self._current_action.start()
|
745
|
+
yield invocation
|
225
746
|
finally:
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
assert self._invocation_stack, self._invocation_stack
|
233
|
-
|
234
|
-
if len(self._invocation_stack) == 1:
|
235
|
-
self.root_invocation.rebind(
|
236
|
-
result=invocation.result,
|
237
|
-
skip_notification=True,
|
238
|
-
raise_on_no_change=False
|
747
|
+
# Stop the execution of the current action.
|
748
|
+
self._current_action.end(action.result, action.result_metadata)
|
749
|
+
self._current_action = parent_action
|
750
|
+
if parent_action is self.root:
|
751
|
+
parent_action.end(
|
752
|
+
result=action.result, result_metadata=action.result_metadata,
|
239
753
|
)
|
240
754
|
|
755
|
+
@contextlib.contextmanager
|
756
|
+
def track_queries(
|
757
|
+
self,
|
758
|
+
phase: str | None = None
|
759
|
+
) -> Iterator[list[lf_structured.QueryInvocation]]:
|
760
|
+
"""Tracks `lf.query` made within the context.
|
761
|
+
|
762
|
+
Args:
|
763
|
+
phase: The name of a new phase to track the queries in. If not provided,
|
764
|
+
the queries will be tracked in the parent phase.
|
765
|
+
|
766
|
+
Yields:
|
767
|
+
A list of `lf.QueryInvocation` objects, each for a single `lf.query`
|
768
|
+
call.
|
769
|
+
"""
|
770
|
+
with self.phase(phase) if phase else contextlib.nullcontext():
|
771
|
+
with lf_structured.track_queries(include_child_scopes=False) as queries:
|
772
|
+
try:
|
773
|
+
yield queries
|
774
|
+
finally:
|
775
|
+
self._current_action.current_phase.extend(queries)
|
776
|
+
|
241
777
|
def query(
|
242
778
|
self,
|
243
779
|
prompt: Union[str, lf.Template, Any],
|
@@ -250,9 +786,37 @@ class Session(pg.Object):
|
|
250
786
|
examples: list[lf_structured.MappingExample] | None = None,
|
251
787
|
**kwargs
|
252
788
|
) -> Any:
|
253
|
-
"""Calls `lf.query` and associates it with the current invocation.
|
254
|
-
|
255
|
-
|
789
|
+
"""Calls `lf.query` and associates it with the current invocation.
|
790
|
+
|
791
|
+
The following code are equivalent:
|
792
|
+
|
793
|
+
Code 1:
|
794
|
+
```
|
795
|
+
session.query(...)
|
796
|
+
```
|
797
|
+
|
798
|
+
Code 2:
|
799
|
+
```
|
800
|
+
with session.track_queries() as queries:
|
801
|
+
output = lf.query(...)
|
802
|
+
```
|
803
|
+
The former is preferred when `lf.query` is directly called by the action.
|
804
|
+
If `lf.query` is called by a function that does not have access to the
|
805
|
+
session, the latter should be used.
|
806
|
+
|
807
|
+
Args:
|
808
|
+
prompt: The prompt to query.
|
809
|
+
schema: The schema to use for the query.
|
810
|
+
default: The default value to return if the query fails.
|
811
|
+
lm: The language model to use for the query.
|
812
|
+
examples: The examples to use for the query.
|
813
|
+
**kwargs: Additional keyword arguments to pass to `lf.query`.
|
814
|
+
|
815
|
+
Returns:
|
816
|
+
The result of the query.
|
817
|
+
"""
|
818
|
+
with self.track_queries():
|
819
|
+
return lf_structured.query(
|
256
820
|
prompt,
|
257
821
|
schema=schema,
|
258
822
|
default=default,
|
@@ -260,17 +824,16 @@ class Session(pg.Object):
|
|
260
824
|
examples=examples,
|
261
825
|
**kwargs
|
262
826
|
)
|
263
|
-
with pg.notify_on_change(False):
|
264
|
-
self.current_invocation.execution.extend(queries)
|
265
|
-
return output
|
266
827
|
|
267
828
|
def _log(self, level: lf.logging.LogLevel, message: str, **kwargs):
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
829
|
+
self._current_action.current_phase.append(
|
830
|
+
lf.logging.LogEntry(
|
831
|
+
level=level,
|
832
|
+
time=datetime.datetime.now(),
|
833
|
+
message=message,
|
834
|
+
metadata=kwargs,
|
835
|
+
)
|
836
|
+
)
|
274
837
|
|
275
838
|
def debug(self, message: str, **kwargs):
|
276
839
|
"""Logs a debug message to the session."""
|
@@ -296,5 +859,26 @@ class Session(pg.Object):
|
|
296
859
|
"""Returns the session as a message."""
|
297
860
|
return lf.AIMessage(
|
298
861
|
'Agentic task session.',
|
299
|
-
result=self.
|
862
|
+
result=self.root
|
863
|
+
)
|
864
|
+
|
865
|
+
#
|
866
|
+
# HTML views.
|
867
|
+
#
|
868
|
+
|
869
|
+
def _html_tree_view_content(
|
870
|
+
self,
|
871
|
+
*,
|
872
|
+
view: pg.views.html.HtmlTreeView,
|
873
|
+
**kwargs
|
874
|
+
):
|
875
|
+
return view.content(self.root, **kwargs)
|
876
|
+
|
877
|
+
@classmethod
|
878
|
+
def _html_tree_view_config(cls):
|
879
|
+
config = super()._html_tree_view_config()
|
880
|
+
config.update(
|
881
|
+
enable_summary_tooltip=False,
|
882
|
+
enable_key_tooltip=False,
|
300
883
|
)
|
884
|
+
return config
|