langfun 0.1.2.dev202412030804__py3-none-any.whl → 0.1.2.dev202412040804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +2 -0
- langfun/core/agentic/action.py +74 -24
- langfun/core/agentic/action_test.py +20 -4
- langfun/core/structured/__init__.py +2 -0
- langfun/core/structured/prompting.py +148 -47
- langfun/core/structured/prompting_test.py +84 -1
- {langfun-0.1.2.dev202412030804.dist-info → langfun-0.1.2.dev202412040804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202412030804.dist-info → langfun-0.1.2.dev202412040804.dist-info}/RECORD +11 -11
- {langfun-0.1.2.dev202412030804.dist-info → langfun-0.1.2.dev202412040804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202412030804.dist-info → langfun-0.1.2.dev202412040804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202412030804.dist-info → langfun-0.1.2.dev202412040804.dist-info}/top_level.txt +0 -0
langfun/__init__.py
CHANGED
@@ -35,6 +35,8 @@ complete = structured.complete
|
|
35
35
|
score = structured.score
|
36
36
|
generate_class = structured.generate_class
|
37
37
|
|
38
|
+
track_queries = structured.track_queries
|
39
|
+
|
38
40
|
# Helper functions for input/output transformations based on
|
39
41
|
# `lf.query` (e.g. jax-on-beam could use these for batch processing)
|
40
42
|
query_prompt = structured.query_prompt
|
langfun/core/agentic/action.py
CHANGED
@@ -14,8 +14,10 @@
|
|
14
14
|
"""Base classes for agentic actions."""
|
15
15
|
|
16
16
|
import abc
|
17
|
-
|
17
|
+
import contextlib
|
18
|
+
from typing import Annotated, Any, Iterable, Iterator, Optional, Type, Union
|
18
19
|
import langfun.core as lf
|
20
|
+
from langfun.core import structured as lf_structured
|
19
21
|
import pyglove as pg
|
20
22
|
|
21
23
|
|
@@ -35,12 +37,9 @@ class Action(pg.Object):
|
|
35
37
|
self, session: Optional['Session'] = None, **kwargs) -> Any:
|
36
38
|
"""Executes the action."""
|
37
39
|
session = session or Session()
|
38
|
-
|
39
|
-
session.begin(self)
|
40
|
+
with session.track(self):
|
40
41
|
self._result = self.call(session=session, **kwargs)
|
41
42
|
return self._result
|
42
|
-
finally:
|
43
|
-
session.end(self)
|
44
43
|
|
45
44
|
@abc.abstractmethod
|
46
45
|
def call(self, session: 'Session', **kwargs) -> Any:
|
@@ -50,9 +49,20 @@ class Action(pg.Object):
|
|
50
49
|
class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
51
50
|
"""A class for capturing the invocation of an action."""
|
52
51
|
action: Action
|
53
|
-
|
52
|
+
|
53
|
+
result: Annotated[
|
54
|
+
Any,
|
55
|
+
'The result of the action.'
|
56
|
+
] = None
|
57
|
+
|
54
58
|
execution: Annotated[
|
55
|
-
list[
|
59
|
+
list[
|
60
|
+
Union[
|
61
|
+
lf_structured.QueryInvocation,
|
62
|
+
'ActionInvocation',
|
63
|
+
lf.logging.LogEntry
|
64
|
+
]
|
65
|
+
],
|
56
66
|
'Execution execution.'
|
57
67
|
] = []
|
58
68
|
|
@@ -69,6 +79,18 @@ class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
|
|
69
79
|
"""Returns child action invocations."""
|
70
80
|
return [v for v in self.execution if isinstance(v, ActionInvocation)]
|
71
81
|
|
82
|
+
def queries(
|
83
|
+
self,
|
84
|
+
include_children: bool = False
|
85
|
+
) -> Iterable[lf_structured.QueryInvocation]:
|
86
|
+
"""Iterates over queries from the current invocation."""
|
87
|
+
for v in self.execution:
|
88
|
+
if isinstance(v, lf_structured.QueryInvocation):
|
89
|
+
yield v
|
90
|
+
elif isinstance(v, ActionInvocation):
|
91
|
+
if include_children:
|
92
|
+
yield from v.queries(include_children=True)
|
93
|
+
|
72
94
|
def _html_tree_view_summary(
|
73
95
|
self, *, view: pg.views.html.HtmlTreeView, **kwargs
|
74
96
|
):
|
@@ -190,29 +212,57 @@ class Session(pg.Object):
|
|
190
212
|
assert self._invocation_stack
|
191
213
|
return self._invocation_stack[-1]
|
192
214
|
|
193
|
-
|
194
|
-
|
215
|
+
@contextlib.contextmanager
|
216
|
+
def track(self, action: Action) -> Iterator[ActionInvocation]:
|
217
|
+
"""Track the execution of an action."""
|
195
218
|
new_invocation = ActionInvocation(pg.maybe_ref(action))
|
196
219
|
with pg.notify_on_change(False):
|
197
220
|
self.current_invocation.execution.append(new_invocation)
|
198
221
|
self._invocation_stack.append(new_invocation)
|
199
222
|
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
assert invocation.action is action, (invocation.action, action)
|
208
|
-
assert self._invocation_stack, self._invocation_stack
|
209
|
-
|
210
|
-
if len(self._invocation_stack) == 1:
|
211
|
-
self.root_invocation.rebind(
|
212
|
-
result=invocation.result,
|
213
|
-
skip_notification=True,
|
214
|
-
raise_on_no_change=False
|
223
|
+
try:
|
224
|
+
yield new_invocation
|
225
|
+
finally:
|
226
|
+
assert self._invocation_stack
|
227
|
+
invocation = self._invocation_stack.pop(-1)
|
228
|
+
invocation.rebind(
|
229
|
+
result=action.result, skip_notification=True, raise_on_no_change=False
|
215
230
|
)
|
231
|
+
assert invocation.action is action, (invocation.action, action)
|
232
|
+
assert self._invocation_stack, self._invocation_stack
|
233
|
+
|
234
|
+
if len(self._invocation_stack) == 1:
|
235
|
+
self.root_invocation.rebind(
|
236
|
+
result=invocation.result,
|
237
|
+
skip_notification=True,
|
238
|
+
raise_on_no_change=False
|
239
|
+
)
|
240
|
+
|
241
|
+
def query(
|
242
|
+
self,
|
243
|
+
prompt: Union[str, lf.Template, Any],
|
244
|
+
schema: Union[
|
245
|
+
lf_structured.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
|
246
|
+
] = None,
|
247
|
+
default: Any = lf.RAISE_IF_HAS_ERROR,
|
248
|
+
*,
|
249
|
+
lm: lf.LanguageModel | None = None,
|
250
|
+
examples: list[lf_structured.MappingExample] | None = None,
|
251
|
+
**kwargs
|
252
|
+
) -> Any:
|
253
|
+
"""Calls `lf.query` and associates it with the current invocation."""
|
254
|
+
with lf_structured.track_queries() as queries:
|
255
|
+
output = lf_structured.query(
|
256
|
+
prompt,
|
257
|
+
schema=schema,
|
258
|
+
default=default,
|
259
|
+
lm=lm,
|
260
|
+
examples=examples,
|
261
|
+
**kwargs
|
262
|
+
)
|
263
|
+
with pg.notify_on_change(False):
|
264
|
+
self.current_invocation.execution.extend(queries)
|
265
|
+
return output
|
216
266
|
|
217
267
|
def _log(self, level: lf.logging.LogLevel, message: str, **kwargs):
|
218
268
|
with pg.notify_on_change(False):
|
@@ -17,6 +17,7 @@ import unittest
|
|
17
17
|
|
18
18
|
import langfun.core as lf
|
19
19
|
from langfun.core.agentic import action as action_lib
|
20
|
+
from langfun.core.llms import fake
|
20
21
|
|
21
22
|
|
22
23
|
class SessionTest(unittest.TestCase):
|
@@ -26,25 +27,35 @@ class SessionTest(unittest.TestCase):
|
|
26
27
|
|
27
28
|
class Bar(action_lib.Action):
|
28
29
|
|
29
|
-
def call(self, session, **kwargs):
|
30
|
+
def call(self, session, *, lm, **kwargs):
|
30
31
|
test.assertIs(session.current_invocation.action, self)
|
31
32
|
session.info('Begin Bar')
|
33
|
+
session.query('bar', lm=lm)
|
32
34
|
return 2
|
33
35
|
|
34
36
|
class Foo(action_lib.Action):
|
35
37
|
x: int
|
36
38
|
|
37
|
-
def call(self, session, **kwargs):
|
39
|
+
def call(self, session, *, lm, **kwargs):
|
38
40
|
test.assertIs(session.current_invocation.action, self)
|
39
41
|
session.info('Begin Foo', x=1)
|
40
|
-
|
42
|
+
session.query('foo', lm=lm)
|
43
|
+
return self.x + Bar()(session, lm=lm)
|
41
44
|
|
45
|
+
lm = fake.StaticResponse('lm response')
|
42
46
|
session = action_lib.Session()
|
43
47
|
root = session.root_invocation
|
44
48
|
self.assertIsInstance(root.action, action_lib.RootAction)
|
45
49
|
self.assertIs(session.current_invocation, session.root_invocation)
|
46
|
-
self.assertEqual(Foo(1)(session), 3)
|
50
|
+
self.assertEqual(Foo(1)(session, lm=lm), 3)
|
47
51
|
self.assertEqual(len(session.root_invocation.child_invocations), 1)
|
52
|
+
self.assertEqual(len(list(session.root_invocation.queries())), 0)
|
53
|
+
self.assertEqual(
|
54
|
+
len(list(session.root_invocation.queries(include_children=True))), 2
|
55
|
+
)
|
56
|
+
self.assertEqual(
|
57
|
+
len(list(session.root_invocation.child_invocations[0].queries())), 1
|
58
|
+
)
|
48
59
|
self.assertEqual(len(session.root_invocation.child_invocations[0].logs), 1)
|
49
60
|
self.assertEqual(
|
50
61
|
len(session.root_invocation.child_invocations[0].child_invocations),
|
@@ -55,6 +66,11 @@ class SessionTest(unittest.TestCase):
|
|
55
66
|
.child_invocations[0].child_invocations[0].logs),
|
56
67
|
1
|
57
68
|
)
|
69
|
+
self.assertEqual(
|
70
|
+
len(list(session.root_invocation
|
71
|
+
.child_invocations[0].child_invocations[0].queries())),
|
72
|
+
1
|
73
|
+
)
|
58
74
|
self.assertEqual(
|
59
75
|
len(session.root_invocation
|
60
76
|
.child_invocations[0].child_invocations[0].child_invocations),
|
@@ -69,6 +69,8 @@ from langfun.core.structured.prompting import query
|
|
69
69
|
from langfun.core.structured.prompting import query_prompt
|
70
70
|
from langfun.core.structured.prompting import query_output
|
71
71
|
from langfun.core.structured.prompting import query_reward
|
72
|
+
from langfun.core.structured.prompting import QueryInvocation
|
73
|
+
from langfun.core.structured.prompting import track_queries
|
72
74
|
|
73
75
|
from langfun.core.structured.description import DescribeStructure
|
74
76
|
from langfun.core.structured.description import describe
|
@@ -13,8 +13,9 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Symbolic query."""
|
15
15
|
|
16
|
+
import contextlib
|
16
17
|
import functools
|
17
|
-
from typing import Any, Callable, Type, Union
|
18
|
+
from typing import Annotated, Any, Callable, Iterator, Type, Union
|
18
19
|
|
19
20
|
import langfun.core as lf
|
20
21
|
from langfun.core.llms import fake
|
@@ -102,7 +103,7 @@ def _query_structure_cls(
|
|
102
103
|
|
103
104
|
|
104
105
|
def query(
|
105
|
-
prompt: Union[str,
|
106
|
+
prompt: Union[str, lf.Template, Any],
|
106
107
|
schema: Union[
|
107
108
|
schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
|
108
109
|
] = None,
|
@@ -119,7 +120,7 @@ def query(
|
|
119
120
|
skip_lm: bool = False,
|
120
121
|
**kwargs,
|
121
122
|
) -> Any:
|
122
|
-
"""
|
123
|
+
"""Queries an language model for a (maybe) structured output.
|
123
124
|
|
124
125
|
Examples:
|
125
126
|
|
@@ -189,59 +190,93 @@ def query(
|
|
189
190
|
"""
|
190
191
|
# Internal usage logging.
|
191
192
|
|
193
|
+
# Normalize query schema.
|
192
194
|
# When `lf.query` is used for symbolic completion, schema is automatically
|
193
195
|
# inferred when it is None.
|
194
196
|
if isinstance(prompt, pg.Symbolic) and prompt.sym_partial and schema is None:
|
195
197
|
schema = prompt.__class__
|
196
198
|
|
197
|
-
#
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
199
|
+
# Normalize query input.
|
200
|
+
if isinstance(prompt, (lf.Message, str)):
|
201
|
+
# Query with structured output.
|
202
|
+
prompt_kwargs = kwargs.copy()
|
203
|
+
prompt_kwargs.pop('template_str', None)
|
204
|
+
query_input = lf.Template.from_value(prompt, **prompt_kwargs)
|
205
|
+
elif isinstance(prompt, lf.Template):
|
206
|
+
# Create a copy of the prompt if it has a parent object, so all child
|
207
|
+
# modality objects could be referred by path relative to the prompt.
|
208
|
+
query_input = prompt.clone() if prompt.sym_parent is not None else prompt
|
209
|
+
|
210
|
+
# Attach template metadata from kwargs. This is used to pass through fields
|
211
|
+
# from kwargs to the rendered message.
|
212
|
+
template_metadata = {
|
213
|
+
k: v for k, v in kwargs.items() if k.startswith('metadata_')
|
214
|
+
}
|
215
|
+
query_input.rebind(
|
216
|
+
template_metadata, skip_notification=True, raise_on_no_change=False
|
206
217
|
)
|
207
|
-
|
208
|
-
|
209
|
-
if processed_text != output.text:
|
210
|
-
output = lf.AIMessage(processed_text, source=output)
|
211
|
-
return output if returns_message else output.text
|
212
|
-
|
213
|
-
# Query with structured output.
|
214
|
-
prompt_kwargs = kwargs.copy()
|
215
|
-
|
216
|
-
# NOTE(daiyip): when `template_str` is passed in, it's intended to modify the
|
217
|
-
# QueryStructure template string. Therefore, we pop out the argument for
|
218
|
-
# prompt rendering.
|
219
|
-
prompt_kwargs.pop('template_str', None)
|
220
|
-
|
221
|
-
if isinstance(prompt, (str, lf.Message, lf.Template)):
|
222
|
-
prompt = lf.Template.from_value(prompt, **prompt_kwargs).render(lm=lm)
|
218
|
+
elif pg.MISSING_VALUE == prompt:
|
219
|
+
query_input = lf.UserMessage('')
|
223
220
|
else:
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
221
|
+
query_input = schema_lib.mark_missing(prompt)
|
222
|
+
|
223
|
+
with lf.track_usages() as usage_summary:
|
224
|
+
if schema in (None, str):
|
225
|
+
# Query with natural language output.
|
226
|
+
output_message = lf.LangFunc.from_value(query_input, **kwargs)(
|
227
|
+
lm=lm, cache_seed=cache_seed, skip_lm=skip_lm
|
228
|
+
)
|
229
|
+
if response_postprocess:
|
230
|
+
processed_text = response_postprocess(output_message.text)
|
231
|
+
if processed_text != output_message.text:
|
232
|
+
output_message = lf.AIMessage(processed_text, source=output_message)
|
233
|
+
else:
|
234
|
+
# Query with structured output.
|
235
|
+
output_message = _query_structure_cls(protocol)(
|
236
|
+
input=(
|
237
|
+
query_input.render(lm=lm)
|
238
|
+
if isinstance(query_input, lf.Template)
|
239
|
+
else query_input
|
240
|
+
),
|
241
|
+
schema=schema,
|
242
|
+
default=default,
|
243
|
+
examples=examples,
|
244
|
+
response_postprocess=response_postprocess,
|
245
|
+
autofix=autofix if protocol == 'python' else 0,
|
246
|
+
**kwargs,
|
247
|
+
)(
|
248
|
+
lm=lm,
|
249
|
+
autofix_lm=autofix_lm or lm,
|
250
|
+
cache_seed=cache_seed,
|
251
|
+
skip_lm=skip_lm,
|
252
|
+
)
|
253
|
+
|
254
|
+
def _result(message: lf.Message):
|
255
|
+
return message.text if schema in (None, str) else message.result
|
256
|
+
|
257
|
+
# Track the query invocations.
|
258
|
+
if pg.MISSING_VALUE != prompt and not skip_lm:
|
259
|
+
trackers = lf.context_value('__query_trackers__', [])
|
260
|
+
if trackers:
|
261
|
+
invocation = QueryInvocation(
|
262
|
+
input=pg.Ref(query_input),
|
263
|
+
schema=(
|
264
|
+
schema_lib.Schema.from_value(schema)
|
265
|
+
if schema not in (None, str) else None
|
266
|
+
),
|
267
|
+
output=pg.Ref(_result(output_message)),
|
268
|
+
lm=pg.Ref(lm),
|
269
|
+
examples=pg.Ref(examples) if examples else [],
|
270
|
+
usage_summary=usage_summary,
|
271
|
+
)
|
272
|
+
for i, (tracker, include_child_scopes) in enumerate(trackers):
|
273
|
+
if i == 0 or include_child_scopes:
|
274
|
+
tracker.append(invocation)
|
275
|
+
return output_message if returns_message else _result(output_message)
|
241
276
|
|
242
277
|
|
243
278
|
def query_prompt(
|
244
|
-
prompt: Union[str,
|
279
|
+
prompt: Union[str, lf.Template, Any],
|
245
280
|
schema: Union[
|
246
281
|
schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
|
247
282
|
] = None,
|
@@ -264,7 +299,7 @@ def query_output(
|
|
264
299
|
kwargs.pop('prompt', None)
|
265
300
|
kwargs.pop('lm', None)
|
266
301
|
return query(
|
267
|
-
|
302
|
+
pg.MISSING_VALUE, schema, lm=fake.StaticResponse(response), **kwargs
|
268
303
|
)
|
269
304
|
|
270
305
|
|
@@ -320,3 +355,69 @@ def _reward_fn(cls) -> Callable[
|
|
320
355
|
args = [self, input, expected_output, metadata]
|
321
356
|
return cls.__reward__(*args[:num_args])
|
322
357
|
return _reward
|
358
|
+
|
359
|
+
|
360
|
+
class QueryInvocation(pg.Object):
|
361
|
+
"""A class to represent the invocation of `lf.query`."""
|
362
|
+
|
363
|
+
input: Annotated[
|
364
|
+
Union[lf.Template, pg.Symbolic],
|
365
|
+
'Mapping input of `lf.query`.'
|
366
|
+
]
|
367
|
+
schema: pg.typing.Annotated[
|
368
|
+
schema_lib.schema_spec(noneable=True),
|
369
|
+
'Schema of `lf.query`.'
|
370
|
+
]
|
371
|
+
output: Annotated[
|
372
|
+
Any,
|
373
|
+
'Mapping output of `lf.query`.'
|
374
|
+
]
|
375
|
+
lm: Annotated[
|
376
|
+
lf.LanguageModel,
|
377
|
+
'Language model used for `lf.query`.'
|
378
|
+
]
|
379
|
+
examples: Annotated[
|
380
|
+
list[mapping.MappingExample],
|
381
|
+
'Fewshot exemplars for `lf.query`.'
|
382
|
+
]
|
383
|
+
usage_summary: Annotated[
|
384
|
+
lf.UsageSummary,
|
385
|
+
'Usage summary for `lf.query`.'
|
386
|
+
]
|
387
|
+
|
388
|
+
|
389
|
+
@contextlib.contextmanager
|
390
|
+
def track_queries(
|
391
|
+
include_child_scopes: bool = True
|
392
|
+
) -> Iterator[list[QueryInvocation]]:
|
393
|
+
"""Track all queries made during the context.
|
394
|
+
|
395
|
+
Example:
|
396
|
+
|
397
|
+
```
|
398
|
+
with lf.track_queries() as queries:
|
399
|
+
lf.query('hi', lm=lm)
|
400
|
+
lf.query('What is this {{image}}?', lm=lm, image=image)
|
401
|
+
|
402
|
+
print(queries)
|
403
|
+
```
|
404
|
+
|
405
|
+
Args:
|
406
|
+
include_child_scopes: If True, the queries made in child scopes will be
|
407
|
+
included in the returned list. Otherwise, only the queries made in the
|
408
|
+
current scope will be included.
|
409
|
+
|
410
|
+
Yields:
|
411
|
+
A list of `QueryInvocation` objects representing the queries made during
|
412
|
+
the context.
|
413
|
+
"""
|
414
|
+
trackers = lf.context_value('__query_trackers__', [])
|
415
|
+
tracker = []
|
416
|
+
|
417
|
+
with lf.context(
|
418
|
+
__query_trackers__=[(tracker, include_child_scopes)] + trackers
|
419
|
+
):
|
420
|
+
try:
|
421
|
+
yield tracker
|
422
|
+
finally:
|
423
|
+
pass
|
@@ -89,7 +89,7 @@ class QueryTest(unittest.TestCase):
|
|
89
89
|
)
|
90
90
|
self.assertEqual(
|
91
91
|
prompting.query(
|
92
|
-
lf.Template('what is {{x}} + {{y}}'
|
92
|
+
lf.Template('what is {{x}} + {{y}}', x=1, y=0), int, lm=lm.clone()
|
93
93
|
),
|
94
94
|
1,
|
95
95
|
)
|
@@ -365,6 +365,23 @@ class QueryTest(unittest.TestCase):
|
|
365
365
|
"""),
|
366
366
|
)
|
367
367
|
|
368
|
+
def test_query_prompt_with_metadata(self):
|
369
|
+
self.assertIn(
|
370
|
+
'x',
|
371
|
+
prompting.query_prompt(
|
372
|
+
'what is this?',
|
373
|
+
metadata_x=1
|
374
|
+
).metadata
|
375
|
+
)
|
376
|
+
self.assertIn(
|
377
|
+
'x',
|
378
|
+
prompting.query_prompt(
|
379
|
+
'what is this?',
|
380
|
+
int,
|
381
|
+
metadata_x=1
|
382
|
+
).metadata
|
383
|
+
)
|
384
|
+
|
368
385
|
def test_query_prompt_with_unrooted_template(self):
|
369
386
|
output = prompting.query_prompt(
|
370
387
|
pg.Dict(
|
@@ -945,5 +962,71 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
945
962
|
)
|
946
963
|
|
947
964
|
|
965
|
+
class TrackQueriesTest(unittest.TestCase):
|
966
|
+
|
967
|
+
def test_include_child_scopes(self):
|
968
|
+
lm = fake.StaticSequence([
|
969
|
+
'bar',
|
970
|
+
'Activity(description="hi")',
|
971
|
+
])
|
972
|
+
with prompting.track_queries() as queries:
|
973
|
+
prompting.query('foo', lm=lm)
|
974
|
+
with prompting.track_queries() as child_queries:
|
975
|
+
prompting.query('give me an activity', Activity, lm=lm)
|
976
|
+
|
977
|
+
self.assertEqual(len(queries), 2)
|
978
|
+
self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
|
979
|
+
self.assertIsNone(queries[0].schema)
|
980
|
+
self.assertEqual(queries[0].output, 'bar')
|
981
|
+
self.assertIs(queries[0].lm, lm)
|
982
|
+
|
983
|
+
self.assertTrue(pg.eq(queries[1].input, lf.Template('give me an activity')))
|
984
|
+
self.assertEqual(queries[1].schema.spec.cls, Activity)
|
985
|
+
self.assertTrue(pg.eq(queries[1].output, Activity(description='hi')))
|
986
|
+
self.assertIs(queries[1].lm, lm)
|
987
|
+
self.assertGreater(queries[0].usage_summary.total.total_tokens, 0)
|
988
|
+
self.assertGreater(queries[1].usage_summary.total.total_tokens, 0)
|
989
|
+
|
990
|
+
self.assertEqual(len(child_queries), 1)
|
991
|
+
self.assertIs(child_queries[0], queries[1])
|
992
|
+
|
993
|
+
def test_exclude_child_scopes(self):
|
994
|
+
lm = fake.StaticSequence([
|
995
|
+
'bar',
|
996
|
+
'Activity(description="hi")',
|
997
|
+
])
|
998
|
+
with prompting.track_queries(include_child_scopes=False) as queries:
|
999
|
+
prompting.query('foo', lm=lm)
|
1000
|
+
with prompting.track_queries(include_child_scopes=False) as child_queries:
|
1001
|
+
prompting.query('give me an activity', Activity, lm=lm)
|
1002
|
+
|
1003
|
+
self.assertEqual(len(queries), 1)
|
1004
|
+
self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
|
1005
|
+
self.assertIsNone(queries[0].schema)
|
1006
|
+
self.assertEqual(queries[0].output, 'bar')
|
1007
|
+
self.assertIs(queries[0].lm, lm)
|
1008
|
+
|
1009
|
+
self.assertEqual(len(child_queries), 1)
|
1010
|
+
self.assertTrue(
|
1011
|
+
pg.eq(child_queries[0].input, lf.Template('give me an activity'))
|
1012
|
+
)
|
1013
|
+
self.assertEqual(child_queries[0].schema.spec.cls, Activity)
|
1014
|
+
self.assertTrue(pg.eq(child_queries[0].output, Activity(description='hi')))
|
1015
|
+
self.assertIs(child_queries[0].lm, lm)
|
1016
|
+
|
1017
|
+
def test_concurrent_map(self):
|
1018
|
+
|
1019
|
+
def make_query(prompt):
|
1020
|
+
_ = prompting.query(prompt, lm=lm)
|
1021
|
+
|
1022
|
+
lm = fake.StaticSequence([
|
1023
|
+
'foo',
|
1024
|
+
'bar',
|
1025
|
+
])
|
1026
|
+
with prompting.track_queries() as queries:
|
1027
|
+
list(lf.concurrent_map(make_query, ['a', 'b']))
|
1028
|
+
self.assertEqual(len(queries), 2)
|
1029
|
+
|
1030
|
+
|
948
1031
|
if __name__ == '__main__':
|
949
1032
|
unittest.main()
|
@@ -1,4 +1,4 @@
|
|
1
|
-
langfun/__init__.py,sha256=
|
1
|
+
langfun/__init__.py,sha256=Xkzi8VV93jI7iLSrypzyHV9FtRbxHQpmoUCIZJSGHdA,2400
|
2
2
|
langfun/core/__init__.py,sha256=xlvFTXc7IKUTs8aCFRFhzOLTmmeuhXgk9yx2InBLNiA,4937
|
3
3
|
langfun/core/component.py,sha256=HVrEoTL1Y01iqOHC3FYdbAOnffqfHHtGJXoK1vkdEwo,11583
|
4
4
|
langfun/core/component_test.py,sha256=sG-T2wpvBfHqWGZE7sc4NayJj2aj5QFBzSwFiwrGEIc,10376
|
@@ -30,10 +30,10 @@ langfun/core/template_test.py,sha256=Qokz1hQFhRYaTZWBWGqvPJ0NXC9B9ennUpnRYHEf0hE
|
|
30
30
|
langfun/core/text_formatting.py,sha256=d7t9vaY6aCn1dkfkikpNYnBy5E_i93vHbfyDWFclGZU,5284
|
31
31
|
langfun/core/text_formatting_test.py,sha256=ck0Xzdd4YF4CtCUj7VE0GybfbAyKQ8p3xkM1FBGrqIk,2096
|
32
32
|
langfun/core/agentic/__init__.py,sha256=ndoDX0sAYsa3eVdXuu6nB-a-BH5TaK3urW6zAaFiyVs,1110
|
33
|
-
langfun/core/agentic/action.py,sha256=
|
33
|
+
langfun/core/agentic/action.py,sha256=lsEltCSrPag5GOvAeaakf_3iil28tKZJdN-NrovqQDw,8954
|
34
34
|
langfun/core/agentic/action_eval.py,sha256=ZtjTh34S7XPIUqandQ0YwAtzw-S7ofuZ7rRXnRbUMdQ,4424
|
35
35
|
langfun/core/agentic/action_eval_test.py,sha256=tRUkWmOE9p0rpNOq19xAY2oDEnYsEEykjg6sUpAwJk0,2832
|
36
|
-
langfun/core/agentic/action_test.py,sha256=
|
36
|
+
langfun/core/agentic/action_test.py,sha256=K6Ynop1zthRYMd_6Y4tpv3TFZRJAbNxEAwJf5dKlU5A,3235
|
37
37
|
langfun/core/coding/__init__.py,sha256=5utju_fwEsImaiftx4oXKl9FAM8p281k8-Esdh_-m1w,835
|
38
38
|
langfun/core/coding/python/__init__.py,sha256=MJ-vubliz-ebrZH3OBRKBwMi0S9-FrhGCp8YQLR6_I4,1776
|
39
39
|
langfun/core/coding/python/correction.py,sha256=WiBdoScL-6C___iA3Tg3vizuYtJWI-_4wy9zcMfVpj8,7020
|
@@ -118,7 +118,7 @@ langfun/core/modalities/pdf.py,sha256=mfaeCbUA4JslFVTARiJh8hW7imvL4tLVw9gUhO5bAZ
|
|
118
118
|
langfun/core/modalities/pdf_test.py,sha256=ulZ0FbnlsU0wkrdckJ4ONZPTYRyMPO9Aob1UO6FXygk,1950
|
119
119
|
langfun/core/modalities/video.py,sha256=vI9apcHIHGyp90i34Srg7S3G6IBDtDCk8qiXhwRQmkw,967
|
120
120
|
langfun/core/modalities/video_test.py,sha256=7OXZoohKMYjt7vrJUdPb553HLyl1oBOKRgzBePFv68Q,2042
|
121
|
-
langfun/core/structured/__init__.py,sha256=
|
121
|
+
langfun/core/structured/__init__.py,sha256=YGyGN-6gcGpzo1Hh-kpPFvC-dPYayjx7NRn06tyAdXE,4016
|
122
122
|
langfun/core/structured/completion.py,sha256=cS2PjG7sqzDu5x0xoTk8RmNcoeX55iVwH38NTefkMHg,8108
|
123
123
|
langfun/core/structured/completion_test.py,sha256=lendf6nPsNfAmd5A7k3v_HS2At9F_jjbKBcV7OEt94o,19310
|
124
124
|
langfun/core/structured/description.py,sha256=SXW4MJvshFjbR-0gw6rE21o6WXq12UlRXawvDBXMZFA,5211
|
@@ -129,8 +129,8 @@ langfun/core/structured/mapping.py,sha256=vLKH79UT-j0qkQdvqlQBO7SkXXuM-yr2Idm8_H
|
|
129
129
|
langfun/core/structured/mapping_test.py,sha256=bHm2ZCXBITq_G8Lvw_olFHeUUc4s_lGXZm9v9JhoPB4,9630
|
130
130
|
langfun/core/structured/parsing.py,sha256=D58wBWOC6r6DCJNychCDkiHPrsy1XJfBDCDDZtug00k,11765
|
131
131
|
langfun/core/structured/parsing_test.py,sha256=i0i090FVgM8ngGqYjds0hjEm1v7q4gv18k-z1kaNr7E,21467
|
132
|
-
langfun/core/structured/prompting.py,sha256=
|
133
|
-
langfun/core/structured/prompting_test.py,sha256=
|
132
|
+
langfun/core/structured/prompting.py,sha256=R9tfitDCBsQ725lzrSfaVgLi7FdtArWRStEMkzmJWQU,13698
|
133
|
+
langfun/core/structured/prompting_test.py,sha256=B0D70JmWgFYjRN9wfoSRX8zn0vdCUCJWk-igb59K0WY,29421
|
134
134
|
langfun/core/structured/schema.py,sha256=XHA-m_ENT_J0k8Q7WCiCL51xm7oXHOqskhO8RpPIurc,28174
|
135
135
|
langfun/core/structured/schema_generation.py,sha256=U3nRQsqmMZg_qIVDh2fiY3K4JLfsAL1LcKzIFP1iXFg,5316
|
136
136
|
langfun/core/structured/schema_generation_test.py,sha256=RM9s71kMNg2jTePwInkiW9fK1ACN37eyPeF8OII-0zw,2950
|
@@ -148,8 +148,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
|
|
148
148
|
langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
|
149
149
|
langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
|
150
150
|
langfun/core/templates/selfplay_test.py,sha256=Ot__1P1M8oJfoTp-M9-PQ6HUXqZKyMwvZ5f7yQ3yfyM,2326
|
151
|
-
langfun-0.1.2.
|
152
|
-
langfun-0.1.2.
|
153
|
-
langfun-0.1.2.
|
154
|
-
langfun-0.1.2.
|
155
|
-
langfun-0.1.2.
|
151
|
+
langfun-0.1.2.dev202412040804.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
152
|
+
langfun-0.1.2.dev202412040804.dist-info/METADATA,sha256=0IziLhpTHxU_V2Sb2WLt0oiaBRLnrUj4QjwY9WnFW3g,8281
|
153
|
+
langfun-0.1.2.dev202412040804.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
154
|
+
langfun-0.1.2.dev202412040804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
155
|
+
langfun-0.1.2.dev202412040804.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{langfun-0.1.2.dev202412030804.dist-info → langfun-0.1.2.dev202412040804.dist-info}/top_level.txt
RENAMED
File without changes
|