langfun 0.1.2.dev202412070804__py3-none-any.whl → 0.1.2.dev202412110804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/agentic/action.py +716 -132
- langfun/core/agentic/action_test.py +78 -42
- langfun/core/eval/v2/evaluation.py +3 -0
- langfun/core/llms/__init__.py +1 -0
- langfun/core/llms/anthropic.py +1 -1
- langfun/core/llms/google_genai.py +31 -41
- langfun/core/llms/vertexai.py +10 -9
- langfun/core/message.py +1 -2
- langfun/core/message_test.py +1 -2
- langfun/core/structured/function_generation.py +26 -14
- langfun/core/structured/function_generation_test.py +30 -0
- langfun/core/structured/prompting.py +105 -5
- langfun/core/structured/prompting_test.py +12 -0
- langfun/core/structured/schema.py +1 -1
- {langfun-0.1.2.dev202412070804.dist-info → langfun-0.1.2.dev202412110804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202412070804.dist-info → langfun-0.1.2.dev202412110804.dist-info}/RECORD +19 -19
- {langfun-0.1.2.dev202412070804.dist-info → langfun-0.1.2.dev202412110804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202412070804.dist-info → langfun-0.1.2.dev202412110804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202412070804.dist-info → langfun-0.1.2.dev202412110804.dist-info}/top_level.txt +0 -0
@@ -18,6 +18,8 @@ import unittest
|
|
18
18
|
import langfun.core as lf
|
19
19
|
from langfun.core.agentic import action as action_lib
|
20
20
|
from langfun.core.llms import fake
|
21
|
+
import langfun.core.structured as lf_structured
|
22
|
+
import pyglove as pg
|
21
23
|
|
22
24
|
|
23
25
|
class SessionTest(unittest.TestCase):
|
@@ -28,60 +30,94 @@ class SessionTest(unittest.TestCase):
|
|
28
30
|
class Bar(action_lib.Action):
|
29
31
|
|
30
32
|
def call(self, session, *, lm, **kwargs):
|
31
|
-
test.assertIs(session.
|
33
|
+
test.assertIs(session.current_action.action, self)
|
32
34
|
session.info('Begin Bar')
|
33
35
|
session.query('bar', lm=lm)
|
36
|
+
session.add_metadata(note='bar')
|
34
37
|
return 2
|
35
38
|
|
36
39
|
class Foo(action_lib.Action):
|
37
40
|
x: int
|
38
41
|
|
39
42
|
def call(self, session, *, lm, **kwargs):
|
40
|
-
test.assertIs(session.
|
41
|
-
session.
|
42
|
-
|
43
|
+
test.assertIs(session.current_action.action, self)
|
44
|
+
with session.phase('prepare'):
|
45
|
+
session.info('Begin Foo', x=1)
|
46
|
+
session.query('foo', lm=lm)
|
47
|
+
with session.track_queries():
|
48
|
+
self.make_additional_query(lm)
|
49
|
+
session.add_metadata(note='foo')
|
43
50
|
return self.x + Bar()(session, lm=lm)
|
44
51
|
|
52
|
+
def make_additional_query(self, lm):
|
53
|
+
lf_structured.query('additional query', lm=lm)
|
54
|
+
|
45
55
|
lm = fake.StaticResponse('lm response')
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
self.
|
51
|
-
self.
|
52
|
-
self.
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
self.assertEqual(len(
|
60
|
-
self.
|
61
|
-
|
62
|
-
|
63
|
-
)
|
64
|
-
self.
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
)
|
69
|
-
self.assertEqual(
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
)
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
)
|
79
|
-
|
80
|
-
self.
|
81
|
-
|
82
|
-
|
83
|
-
|
56
|
+
foo = Foo(1)
|
57
|
+
self.assertEqual(foo(lm=lm), 3)
|
58
|
+
|
59
|
+
session = foo.session
|
60
|
+
self.assertIsNotNone(session)
|
61
|
+
self.assertIsInstance(session.root.action, action_lib.RootAction)
|
62
|
+
self.assertIs(session.current_action, session.root)
|
63
|
+
|
64
|
+
#
|
65
|
+
# Inspecting the root invocation.
|
66
|
+
#
|
67
|
+
|
68
|
+
root = session.root
|
69
|
+
self.assertEqual(len(root.execution.items), 1)
|
70
|
+
self.assertIs(root.execution.items[0].action, foo)
|
71
|
+
|
72
|
+
self.assertTrue(root.execution.has_started)
|
73
|
+
self.assertTrue(root.execution.has_stopped)
|
74
|
+
self.assertGreater(root.execution.elapse, 0)
|
75
|
+
self.assertEqual(root.result, 3)
|
76
|
+
self.assertEqual(root.result_metadata, dict(note='foo'))
|
77
|
+
|
78
|
+
# The root space should have one action (foo), no queries, and no logs.
|
79
|
+
self.assertEqual(len(list(root.actions)), 1)
|
80
|
+
self.assertEqual(len(list(root.queries)), 0)
|
81
|
+
self.assertEqual(len(list(root.logs)), 0)
|
82
|
+
# 1 query from Bar and 2 from Foo.
|
83
|
+
self.assertEqual(len(list(root.all_queries)), 3)
|
84
|
+
# 1 log from Bar and 1 from Foo.
|
85
|
+
self.assertEqual(len(list(root.all_logs)), 2)
|
86
|
+
self.assertEqual(root.usage_summary.total.num_requests, 3)
|
87
|
+
|
88
|
+
# Inspecting the top-level action (Foo)
|
89
|
+
foo_invocation = root.execution.items[0]
|
90
|
+
self.assertEqual(len(foo_invocation.execution.items), 3)
|
91
|
+
|
92
|
+
# Prepare phase.
|
93
|
+
prepare_phase = foo_invocation.execution.items[0]
|
94
|
+
self.assertIsInstance(
|
95
|
+
prepare_phase, action_lib.ExecutionTrace
|
84
96
|
)
|
97
|
+
self.assertEqual(len(prepare_phase.items), 2)
|
98
|
+
self.assertTrue(prepare_phase.has_started)
|
99
|
+
self.assertTrue(prepare_phase.has_stopped)
|
100
|
+
self.assertEqual(prepare_phase.usage_summary.total.num_requests, 1)
|
101
|
+
|
102
|
+
# Tracked queries.
|
103
|
+
query_invocation = foo_invocation.execution.items[1]
|
104
|
+
self.assertIsInstance(query_invocation, lf_structured.QueryInvocation)
|
105
|
+
self.assertIs(query_invocation.lm, lm)
|
106
|
+
|
107
|
+
# Invocation to Bar.
|
108
|
+
bar_invocation = foo_invocation.execution.items[2]
|
109
|
+
self.assertIsInstance(bar_invocation, action_lib.ActionInvocation)
|
110
|
+
self.assertIsInstance(bar_invocation.action, Bar)
|
111
|
+
self.assertEqual(bar_invocation.result, 2)
|
112
|
+
self.assertEqual(bar_invocation.result_metadata, dict(note='bar'))
|
113
|
+
self.assertEqual(len(bar_invocation.execution.items), 2)
|
114
|
+
|
115
|
+
# Save to HTML
|
116
|
+
self.assertIn('invocation-result', session.to_html().content)
|
117
|
+
|
118
|
+
# Save session to JSON
|
119
|
+
json_str = session.to_json_str(save_ref_value=True)
|
120
|
+
self.assertIsInstance(pg.from_json_str(json_str), action_lib.Session)
|
85
121
|
|
86
122
|
def test_log(self):
|
87
123
|
session = action_lib.Session()
|
langfun/core/llms/__init__.py
CHANGED
@@ -33,6 +33,7 @@ from langfun.core.llms.rest import REST
|
|
33
33
|
# Gemini models.
|
34
34
|
from langfun.core.llms.google_genai import GenAI
|
35
35
|
from langfun.core.llms.google_genai import GeminiExp_20241114
|
36
|
+
from langfun.core.llms.google_genai import GeminiExp_20241206
|
36
37
|
from langfun.core.llms.google_genai import GeminiFlash1_5
|
37
38
|
from langfun.core.llms.google_genai import GeminiPro
|
38
39
|
from langfun.core.llms.google_genai import GeminiPro1_5
|
langfun/core/llms/anthropic.py
CHANGED
@@ -20,6 +20,7 @@ from typing import Annotated, Any, Literal
|
|
20
20
|
|
21
21
|
import langfun.core as lf
|
22
22
|
from langfun.core import modalities as lf_modalities
|
23
|
+
from langfun.core.llms import vertexai
|
23
24
|
import pyglove as pg
|
24
25
|
|
25
26
|
|
@@ -54,6 +55,7 @@ class GenAI(lf.LanguageModel):
|
|
54
55
|
'gemini-1.5-pro-latest',
|
55
56
|
'gemini-1.5-flash-latest',
|
56
57
|
'gemini-exp-1114',
|
58
|
+
'gemini-exp-1206',
|
57
59
|
],
|
58
60
|
'Model name.',
|
59
61
|
]
|
@@ -306,64 +308,52 @@ _GOOGLE_GENAI_MODEL_HUB = _ModelHub()
|
|
306
308
|
#
|
307
309
|
|
308
310
|
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
'audio/flac',
|
320
|
-
'audio/mp3',
|
321
|
-
'audio/m4a',
|
322
|
-
'audio/mpeg',
|
323
|
-
'audio/mpga',
|
324
|
-
'audio/mp4',
|
325
|
-
'audio/opus',
|
326
|
-
'audio/pcm',
|
327
|
-
'audio/wav',
|
328
|
-
'audio/webm'
|
329
|
-
]
|
330
|
-
|
331
|
-
_VIDEO_TYPES = [
|
332
|
-
'video/mov',
|
333
|
-
'video/mpeg',
|
334
|
-
'video/mpegps',
|
335
|
-
'video/mpg',
|
336
|
-
'video/mp4',
|
337
|
-
'video/webm',
|
338
|
-
'video/wmv',
|
339
|
-
'video/x-flv',
|
340
|
-
'video/3gpp',
|
341
|
-
]
|
342
|
-
|
343
|
-
_PDF = [
|
344
|
-
'application/pdf',
|
345
|
-
]
|
311
|
+
class GeminiExp_20241206(GenAI): # pylint: disable=invalid-name
|
312
|
+
"""Gemini Experimental model launched on 12/06/2024."""
|
313
|
+
|
314
|
+
model = 'gemini-exp-1206'
|
315
|
+
supported_modalities = (
|
316
|
+
vertexai.DOCUMENT_TYPES
|
317
|
+
+ vertexai.IMAGE_TYPES
|
318
|
+
+ vertexai.AUDIO_TYPES
|
319
|
+
+ vertexai.VIDEO_TYPES
|
320
|
+
)
|
346
321
|
|
347
322
|
|
348
323
|
class GeminiExp_20241114(GenAI): # pylint: disable=invalid-name
|
349
324
|
"""Gemini Experimental model launched on 11/14/2024."""
|
350
325
|
|
351
326
|
model = 'gemini-exp-1114'
|
352
|
-
supported_modalities =
|
327
|
+
supported_modalities = (
|
328
|
+
vertexai.DOCUMENT_TYPES
|
329
|
+
+ vertexai.IMAGE_TYPES
|
330
|
+
+ vertexai.AUDIO_TYPES
|
331
|
+
+ vertexai.VIDEO_TYPES
|
332
|
+
)
|
353
333
|
|
354
334
|
|
355
335
|
class GeminiPro1_5(GenAI): # pylint: disable=invalid-name
|
356
336
|
"""Gemini Pro latest model."""
|
357
337
|
|
358
338
|
model = 'gemini-1.5-pro-latest'
|
359
|
-
supported_modalities =
|
339
|
+
supported_modalities = (
|
340
|
+
vertexai.DOCUMENT_TYPES
|
341
|
+
+ vertexai.IMAGE_TYPES
|
342
|
+
+ vertexai.AUDIO_TYPES
|
343
|
+
+ vertexai.VIDEO_TYPES
|
344
|
+
)
|
360
345
|
|
361
346
|
|
362
347
|
class GeminiFlash1_5(GenAI): # pylint: disable=invalid-name
|
363
348
|
"""Gemini Flash latest model."""
|
364
349
|
|
365
350
|
model = 'gemini-1.5-flash-latest'
|
366
|
-
supported_modalities =
|
351
|
+
supported_modalities = (
|
352
|
+
vertexai.DOCUMENT_TYPES
|
353
|
+
+ vertexai.IMAGE_TYPES
|
354
|
+
+ vertexai.AUDIO_TYPES
|
355
|
+
+ vertexai.VIDEO_TYPES
|
356
|
+
)
|
367
357
|
|
368
358
|
|
369
359
|
class GeminiPro(GenAI):
|
@@ -376,7 +366,7 @@ class GeminiProVision(GenAI):
|
|
376
366
|
"""Gemini Pro vision model."""
|
377
367
|
|
378
368
|
model = 'gemini-pro-vision'
|
379
|
-
supported_modalities =
|
369
|
+
supported_modalities = vertexai.IMAGE_TYPES + vertexai.VIDEO_TYPES
|
380
370
|
|
381
371
|
|
382
372
|
class Palm2(GenAI):
|
langfun/core/llms/vertexai.py
CHANGED
@@ -343,7 +343,7 @@ class VertexAI(rest.REST):
|
|
343
343
|
return lf.AIMessage.from_chunks(chunks)
|
344
344
|
|
345
345
|
|
346
|
-
|
346
|
+
IMAGE_TYPES = [
|
347
347
|
'image/png',
|
348
348
|
'image/jpeg',
|
349
349
|
'image/webp',
|
@@ -351,7 +351,7 @@ _IMAGE_TYPES = [
|
|
351
351
|
'image/heif',
|
352
352
|
]
|
353
353
|
|
354
|
-
|
354
|
+
AUDIO_TYPES = [
|
355
355
|
'audio/aac',
|
356
356
|
'audio/flac',
|
357
357
|
'audio/mp3',
|
@@ -362,10 +362,10 @@ _AUDIO_TYPES = [
|
|
362
362
|
'audio/opus',
|
363
363
|
'audio/pcm',
|
364
364
|
'audio/wav',
|
365
|
-
'audio/webm'
|
365
|
+
'audio/webm',
|
366
366
|
]
|
367
367
|
|
368
|
-
|
368
|
+
VIDEO_TYPES = [
|
369
369
|
'video/mov',
|
370
370
|
'video/mpeg',
|
371
371
|
'video/mpegps',
|
@@ -375,9 +375,10 @@ _VIDEO_TYPES = [
|
|
375
375
|
'video/wmv',
|
376
376
|
'video/x-flv',
|
377
377
|
'video/3gpp',
|
378
|
+
'video/quicktime',
|
378
379
|
]
|
379
380
|
|
380
|
-
|
381
|
+
DOCUMENT_TYPES = [
|
381
382
|
'application/pdf',
|
382
383
|
'text/plain',
|
383
384
|
'text/csv',
|
@@ -391,8 +392,8 @@ _DOCUMENT_TYPES = [
|
|
391
392
|
class VertexAIGemini1_5(VertexAI): # pylint: disable=invalid-name
|
392
393
|
"""Vertex AI Gemini 1.5 model."""
|
393
394
|
|
394
|
-
supported_modalities: pg.typing.List(str).freeze(
|
395
|
-
|
395
|
+
supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation
|
396
|
+
DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES
|
396
397
|
)
|
397
398
|
|
398
399
|
|
@@ -460,8 +461,8 @@ class VertexAIGeminiPro1Vision(VertexAI): # pylint: disable=invalid-name
|
|
460
461
|
"""Vertex AI Gemini 1.0 Pro Vision model."""
|
461
462
|
|
462
463
|
model = 'gemini-1.0-pro-vision'
|
463
|
-
supported_modalities: pg.typing.List(str).freeze(
|
464
|
-
|
464
|
+
supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation
|
465
|
+
IMAGE_TYPES + VIDEO_TYPES
|
465
466
|
)
|
466
467
|
|
467
468
|
|
langfun/core/message.py
CHANGED
@@ -769,7 +769,6 @@ class Message(
|
|
769
769
|
padding: 20px;
|
770
770
|
margin: 10px 5px 10px 5px;
|
771
771
|
font-style: italic;
|
772
|
-
font-size: 1.1em;
|
773
772
|
white-space: pre-wrap;
|
774
773
|
border: 1px solid #EEE;
|
775
774
|
border-radius: 5px;
|
@@ -778,7 +777,7 @@ class Message(
|
|
778
777
|
.modality-in-text {
|
779
778
|
display: inline-block;
|
780
779
|
}
|
781
|
-
.modality-in-text > details {
|
780
|
+
.modality-in-text > details.pyglove {
|
782
781
|
display: inline-block;
|
783
782
|
font-size: 0.8em;
|
784
783
|
border: 0;
|
langfun/core/message_test.py
CHANGED
@@ -380,7 +380,6 @@ class MessageTest(unittest.TestCase):
|
|
380
380
|
padding: 20px;
|
381
381
|
margin: 10px 5px 10px 5px;
|
382
382
|
font-style: italic;
|
383
|
-
font-size: 1.1em;
|
384
383
|
white-space: pre-wrap;
|
385
384
|
border: 1px solid #EEE;
|
386
385
|
border-radius: 5px;
|
@@ -389,7 +388,7 @@ class MessageTest(unittest.TestCase):
|
|
389
388
|
.modality-in-text {
|
390
389
|
display: inline-block;
|
391
390
|
}
|
392
|
-
.modality-in-text > details {
|
391
|
+
.modality-in-text > details.pyglove {
|
393
392
|
display: inline-block;
|
394
393
|
font-size: 0.8em;
|
395
394
|
border: 0;
|
@@ -76,6 +76,7 @@ def unittest_with_test_cases(f, unittests):
|
|
76
76
|
|
77
77
|
def _function_gen(
|
78
78
|
func: Callable[..., Any],
|
79
|
+
context: dict[str, Any],
|
79
80
|
signature: str,
|
80
81
|
lm: language_model.LanguageModel,
|
81
82
|
num_retries: int = 1,
|
@@ -141,21 +142,23 @@ def _function_gen(
|
|
141
142
|
elif isinstance(unittest, list):
|
142
143
|
unittest_examples = unittest
|
143
144
|
|
145
|
+
last_error = None
|
144
146
|
for _ in range(num_retries):
|
145
147
|
try:
|
146
148
|
source_code = prompting.query(
|
147
149
|
PythonFunctionPrompt(signature=signature), lm=lm
|
148
150
|
)
|
149
|
-
f = python.evaluate(source_code)
|
151
|
+
f = python.evaluate(source_code, global_vars=context)
|
150
152
|
|
151
153
|
# Check whether the sigantures are the same.
|
152
154
|
if inspect.signature(f) != inspect.signature(func):
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
155
|
+
raise python.CodeError(
|
156
|
+
code=source_code,
|
157
|
+
cause=TypeError(
|
158
|
+
f"Signature mismatch: Expected: {inspect.signature(func)}, "
|
159
|
+
f"Actual: {inspect.signature(f)}.",
|
160
|
+
),
|
157
161
|
)
|
158
|
-
continue
|
159
162
|
|
160
163
|
if callable(unittest):
|
161
164
|
unittest(f)
|
@@ -163,10 +166,12 @@ def _function_gen(
|
|
163
166
|
unittest_with_test_cases(f, unittest_examples)
|
164
167
|
|
165
168
|
return f, source_code
|
166
|
-
except
|
167
|
-
|
168
|
-
|
169
|
-
|
169
|
+
except python.CodeError as e:
|
170
|
+
last_error = e
|
171
|
+
pg.logging.warning(
|
172
|
+
f"Bad code generated: {e}",
|
173
|
+
)
|
174
|
+
raise last_error
|
170
175
|
|
171
176
|
|
172
177
|
def _process_signature(signature):
|
@@ -220,6 +225,13 @@ def function_gen(
|
|
220
225
|
setattr(func, "__function__", None)
|
221
226
|
setattr(func, "__source_code__", None)
|
222
227
|
|
228
|
+
# Prepare the globals/locals for the generated code to be evaluated against.
|
229
|
+
callstack = inspect.stack()
|
230
|
+
assert len(callstack) > 1
|
231
|
+
context = dict(callstack[1][0].f_globals)
|
232
|
+
context.update(callstack[1][0].f_locals)
|
233
|
+
context.pop(func.__name__, None)
|
234
|
+
|
223
235
|
@functools.wraps(func)
|
224
236
|
def lm_generated_func(*args, **kwargs):
|
225
237
|
if func.__function__ is not None:
|
@@ -238,20 +250,20 @@ def function_gen(
|
|
238
250
|
|
239
251
|
if signature in cache:
|
240
252
|
func.__source_code__ = cache[signature]
|
241
|
-
func.__function__ = python.evaluate(
|
253
|
+
func.__function__ = python.evaluate(
|
254
|
+
func.__source_code__, global_vars=context
|
255
|
+
)
|
242
256
|
return func.__function__(*args, **kwargs)
|
243
257
|
|
244
258
|
func.__function__, func.__source_code__ = _function_gen(
|
245
259
|
func,
|
260
|
+
context,
|
246
261
|
signature,
|
247
262
|
lm,
|
248
263
|
num_retries=num_retries,
|
249
264
|
unittest=unittest,
|
250
265
|
unittest_num_retries=unittest_num_retries,
|
251
266
|
)
|
252
|
-
if func.__function__ is None:
|
253
|
-
raise ValueError(f"Function generation failed. Signature:\n{signature}")
|
254
|
-
|
255
267
|
if cache_filename is not None:
|
256
268
|
cache[signature] = func.__source_code__
|
257
269
|
cache.save(cache_filename)
|
@@ -311,6 +311,36 @@ class FunctionGenerationTest(unittest.TestCase):
|
|
311
311
|
|
312
312
|
self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
|
313
313
|
|
314
|
+
def test_context_passthrough(self):
|
315
|
+
|
316
|
+
class Number(pg.Object):
|
317
|
+
value: int
|
318
|
+
|
319
|
+
function_gen_lm_response = inspect.cleandoc("""
|
320
|
+
```python
|
321
|
+
def add(a: Number, b: Number) -> Number:
|
322
|
+
\"\"\"Adds two numbers together.\"\"\"
|
323
|
+
return Number(a.value + b.value)
|
324
|
+
```
|
325
|
+
""")
|
326
|
+
|
327
|
+
lm = fake.StaticSequence(
|
328
|
+
[function_gen_lm_response]
|
329
|
+
)
|
330
|
+
|
331
|
+
def _unittest_fn(func):
|
332
|
+
assert func(Number(1), Number(2)) == Number(3)
|
333
|
+
|
334
|
+
custom_unittest = _unittest_fn
|
335
|
+
|
336
|
+
@function_generation.function_gen(
|
337
|
+
lm=lm, unittest=custom_unittest, num_retries=1
|
338
|
+
)
|
339
|
+
def add(a: Number, b: Number) -> Number: # pylint: disable=unused-argument
|
340
|
+
"""Adds two numbers together."""
|
341
|
+
|
342
|
+
self.assertEqual(add(Number(2), Number(3)), Number(5))
|
343
|
+
|
314
344
|
def test_siganture_check(self):
|
315
345
|
incorrect_signature_lm_response = inspect.cleandoc("""
|
316
346
|
```python
|
@@ -264,9 +264,9 @@ def query(
|
|
264
264
|
schema_lib.Schema.from_value(schema)
|
265
265
|
if schema not in (None, str) else None
|
266
266
|
),
|
267
|
-
output=pg.Ref(_result(output_message)),
|
268
267
|
lm=pg.Ref(lm),
|
269
268
|
examples=pg.Ref(examples) if examples else [],
|
269
|
+
lm_response=lf.AIMessage(output_message.text),
|
270
270
|
usage_summary=usage_summary,
|
271
271
|
)
|
272
272
|
for i, (tracker, include_child_scopes) in enumerate(trackers):
|
@@ -357,7 +357,7 @@ def _reward_fn(cls) -> Callable[
|
|
357
357
|
return _reward
|
358
358
|
|
359
359
|
|
360
|
-
class QueryInvocation(pg.Object):
|
360
|
+
class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
|
361
361
|
"""A class to represent the invocation of `lf.query`."""
|
362
362
|
|
363
363
|
input: Annotated[
|
@@ -368,9 +368,9 @@ class QueryInvocation(pg.Object):
|
|
368
368
|
schema_lib.schema_spec(noneable=True),
|
369
369
|
'Schema of `lf.query`.'
|
370
370
|
]
|
371
|
-
|
372
|
-
|
373
|
-
'
|
371
|
+
lm_response: Annotated[
|
372
|
+
lf.Message,
|
373
|
+
'Raw LM response.'
|
374
374
|
]
|
375
375
|
lm: Annotated[
|
376
376
|
lf.LanguageModel,
|
@@ -385,6 +385,106 @@ class QueryInvocation(pg.Object):
|
|
385
385
|
'Usage summary for `lf.query`.'
|
386
386
|
]
|
387
387
|
|
388
|
+
@functools.cached_property
|
389
|
+
def lm_request(self) -> lf.Message:
|
390
|
+
return query_prompt(self.input, self.schema)
|
391
|
+
|
392
|
+
@functools.cached_property
|
393
|
+
def output(self) -> Any:
|
394
|
+
return query_output(self.lm_response, self.schema)
|
395
|
+
|
396
|
+
def _on_bound(self):
|
397
|
+
super()._on_bound()
|
398
|
+
self.__dict__.pop('lm_request', None)
|
399
|
+
self.__dict__.pop('output', None)
|
400
|
+
|
401
|
+
def _html_tree_view_summary(
|
402
|
+
self,
|
403
|
+
*,
|
404
|
+
view: pg.views.HtmlTreeView,
|
405
|
+
**kwargs: Any
|
406
|
+
) -> pg.Html | None:
|
407
|
+
return view.summary(
|
408
|
+
value=self,
|
409
|
+
title=pg.Html.element(
|
410
|
+
'div',
|
411
|
+
[
|
412
|
+
pg.views.html.controls.Label(
|
413
|
+
'lf.query',
|
414
|
+
css_classes=['query-invocation-type-name']
|
415
|
+
),
|
416
|
+
pg.views.html.controls.Badge(
|
417
|
+
f'lm={self.lm.model_id}',
|
418
|
+
pg.format(
|
419
|
+
self.lm,
|
420
|
+
verbose=False,
|
421
|
+
python_format=True,
|
422
|
+
hide_default_values=True
|
423
|
+
),
|
424
|
+
css_classes=['query-invocation-lm']
|
425
|
+
),
|
426
|
+
self.usage_summary.to_html(extra_flags=dict(as_badge=True))
|
427
|
+
],
|
428
|
+
css_classes=['query-invocation-title']
|
429
|
+
),
|
430
|
+
enable_summary_tooltip=False
|
431
|
+
)
|
432
|
+
|
433
|
+
def _html_tree_view_content(
|
434
|
+
self,
|
435
|
+
*,
|
436
|
+
view: pg.views.HtmlTreeView,
|
437
|
+
**kwargs: Any
|
438
|
+
) -> pg.Html:
|
439
|
+
return pg.views.html.controls.TabControl([
|
440
|
+
pg.views.html.controls.Tab(
|
441
|
+
'input',
|
442
|
+
pg.view(self.input, collapse_level=None),
|
443
|
+
),
|
444
|
+
pg.views.html.controls.Tab(
|
445
|
+
'schema',
|
446
|
+
pg.view(self.schema),
|
447
|
+
),
|
448
|
+
pg.views.html.controls.Tab(
|
449
|
+
'output',
|
450
|
+
pg.view(self.output, collapse_level=None),
|
451
|
+
),
|
452
|
+
pg.views.html.controls.Tab(
|
453
|
+
'lm_request',
|
454
|
+
pg.view(
|
455
|
+
self.lm_request,
|
456
|
+
extra_flags=dict(include_message_metadata=False),
|
457
|
+
),
|
458
|
+
),
|
459
|
+
pg.views.html.controls.Tab(
|
460
|
+
'lm_response',
|
461
|
+
pg.view(
|
462
|
+
self.lm_response,
|
463
|
+
extra_flags=dict(include_message_metadata=False)
|
464
|
+
),
|
465
|
+
),
|
466
|
+
], tab_position='top').to_html()
|
467
|
+
|
468
|
+
@classmethod
|
469
|
+
def _html_tree_view_css_styles(cls) -> list[str]:
|
470
|
+
return super()._html_tree_view_css_styles() + [
|
471
|
+
"""
|
472
|
+
.query-invocation-title {
|
473
|
+
display: inline-block;
|
474
|
+
font-weight: normal;
|
475
|
+
}
|
476
|
+
.query-invocation-type-name {
|
477
|
+
font-style: italic;
|
478
|
+
color: #888;
|
479
|
+
}
|
480
|
+
.query-invocation-lm.badge {
|
481
|
+
margin-left: 5px;
|
482
|
+
margin-right: 5px;
|
483
|
+
background-color: #fff0d6;
|
484
|
+
}
|
485
|
+
"""
|
486
|
+
]
|
487
|
+
|
388
488
|
|
389
489
|
@contextlib.contextmanager
|
390
490
|
def track_queries(
|
@@ -962,6 +962,18 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
962
962
|
)
|
963
963
|
|
964
964
|
|
965
|
+
class QueryInvocationTest(unittest.TestCase):
|
966
|
+
|
967
|
+
def test_to_html(self):
|
968
|
+
lm = fake.StaticSequence([
|
969
|
+
'Activity(description="hi")',
|
970
|
+
])
|
971
|
+
with prompting.track_queries() as queries:
|
972
|
+
prompting.query('foo', Activity, lm=lm)
|
973
|
+
|
974
|
+
self.assertIn('schema', queries[0].to_html_str())
|
975
|
+
|
976
|
+
|
965
977
|
class TrackQueriesTest(unittest.TestCase):
|
966
978
|
|
967
979
|
def test_include_child_scopes(self):
|