langfun 0.1.2.dev202412030804__py3-none-any.whl → 0.1.2.dev202412070804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langfun might be problematic. Click here for more details.

langfun/__init__.py CHANGED
@@ -35,6 +35,8 @@ complete = structured.complete
35
35
  score = structured.score
36
36
  generate_class = structured.generate_class
37
37
 
38
+ track_queries = structured.track_queries
39
+
38
40
  # Helper functions for input/output transformations based on
39
41
  # `lf.query` (e.g. jax-on-beam could use these for batch processing)
40
42
  query_prompt = structured.query_prompt
@@ -14,8 +14,10 @@
14
14
  """Base classes for agentic actions."""
15
15
 
16
16
  import abc
17
- from typing import Annotated, Any, Optional, Union
17
+ import contextlib
18
+ from typing import Annotated, Any, Iterable, Iterator, Optional, Type, Union
18
19
  import langfun.core as lf
20
+ from langfun.core import structured as lf_structured
19
21
  import pyglove as pg
20
22
 
21
23
 
@@ -35,12 +37,9 @@ class Action(pg.Object):
35
37
  self, session: Optional['Session'] = None, **kwargs) -> Any:
36
38
  """Executes the action."""
37
39
  session = session or Session()
38
- try:
39
- session.begin(self)
40
+ with session.track(self):
40
41
  self._result = self.call(session=session, **kwargs)
41
42
  return self._result
42
- finally:
43
- session.end(self)
44
43
 
45
44
  @abc.abstractmethod
46
45
  def call(self, session: 'Session', **kwargs) -> Any:
@@ -50,9 +49,20 @@ class Action(pg.Object):
50
49
  class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
51
50
  """A class for capturing the invocation of an action."""
52
51
  action: Action
53
- result: Any = None
52
+
53
+ result: Annotated[
54
+ Any,
55
+ 'The result of the action.'
56
+ ] = None
57
+
54
58
  execution: Annotated[
55
- list[Union['ActionInvocation', lf.logging.LogEntry]],
59
+ list[
60
+ Union[
61
+ lf_structured.QueryInvocation,
62
+ 'ActionInvocation',
63
+ lf.logging.LogEntry
64
+ ]
65
+ ],
56
66
  'Execution execution.'
57
67
  ] = []
58
68
 
@@ -69,6 +79,18 @@ class ActionInvocation(pg.Object, pg.views.html.HtmlTreeView.Extension):
69
79
  """Returns child action invocations."""
70
80
  return [v for v in self.execution if isinstance(v, ActionInvocation)]
71
81
 
82
+ def queries(
83
+ self,
84
+ include_children: bool = False
85
+ ) -> Iterable[lf_structured.QueryInvocation]:
86
+ """Iterates over queries from the current invocation."""
87
+ for v in self.execution:
88
+ if isinstance(v, lf_structured.QueryInvocation):
89
+ yield v
90
+ elif isinstance(v, ActionInvocation):
91
+ if include_children:
92
+ yield from v.queries(include_children=True)
93
+
72
94
  def _html_tree_view_summary(
73
95
  self, *, view: pg.views.html.HtmlTreeView, **kwargs
74
96
  ):
@@ -190,29 +212,57 @@ class Session(pg.Object):
190
212
  assert self._invocation_stack
191
213
  return self._invocation_stack[-1]
192
214
 
193
- def begin(self, action: Action):
194
- """Signal the beginning of the execution of an action."""
215
+ @contextlib.contextmanager
216
+ def track(self, action: Action) -> Iterator[ActionInvocation]:
217
+ """Track the execution of an action."""
195
218
  new_invocation = ActionInvocation(pg.maybe_ref(action))
196
219
  with pg.notify_on_change(False):
197
220
  self.current_invocation.execution.append(new_invocation)
198
221
  self._invocation_stack.append(new_invocation)
199
222
 
200
- def end(self, action: Action):
201
- """Signal the end of the execution of an action."""
202
- assert self._invocation_stack
203
- invocation = self._invocation_stack.pop(-1)
204
- invocation.rebind(
205
- result=action.result, skip_notification=True, raise_on_no_change=False
206
- )
207
- assert invocation.action is action, (invocation.action, action)
208
- assert self._invocation_stack, self._invocation_stack
209
-
210
- if len(self._invocation_stack) == 1:
211
- self.root_invocation.rebind(
212
- result=invocation.result,
213
- skip_notification=True,
214
- raise_on_no_change=False
223
+ try:
224
+ yield new_invocation
225
+ finally:
226
+ assert self._invocation_stack
227
+ invocation = self._invocation_stack.pop(-1)
228
+ invocation.rebind(
229
+ result=action.result, skip_notification=True, raise_on_no_change=False
215
230
  )
231
+ assert invocation.action is action, (invocation.action, action)
232
+ assert self._invocation_stack, self._invocation_stack
233
+
234
+ if len(self._invocation_stack) == 1:
235
+ self.root_invocation.rebind(
236
+ result=invocation.result,
237
+ skip_notification=True,
238
+ raise_on_no_change=False
239
+ )
240
+
241
+ def query(
242
+ self,
243
+ prompt: Union[str, lf.Template, Any],
244
+ schema: Union[
245
+ lf_structured.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
246
+ ] = None,
247
+ default: Any = lf.RAISE_IF_HAS_ERROR,
248
+ *,
249
+ lm: lf.LanguageModel | None = None,
250
+ examples: list[lf_structured.MappingExample] | None = None,
251
+ **kwargs
252
+ ) -> Any:
253
+ """Calls `lf.query` and associates it with the current invocation."""
254
+ with lf_structured.track_queries() as queries:
255
+ output = lf_structured.query(
256
+ prompt,
257
+ schema=schema,
258
+ default=default,
259
+ lm=lm,
260
+ examples=examples,
261
+ **kwargs
262
+ )
263
+ with pg.notify_on_change(False):
264
+ self.current_invocation.execution.extend(queries)
265
+ return output
216
266
 
217
267
  def _log(self, level: lf.logging.LogLevel, message: str, **kwargs):
218
268
  with pg.notify_on_change(False):
@@ -17,6 +17,7 @@ import unittest
17
17
 
18
18
  import langfun.core as lf
19
19
  from langfun.core.agentic import action as action_lib
20
+ from langfun.core.llms import fake
20
21
 
21
22
 
22
23
  class SessionTest(unittest.TestCase):
@@ -26,25 +27,35 @@ class SessionTest(unittest.TestCase):
26
27
 
27
28
  class Bar(action_lib.Action):
28
29
 
29
- def call(self, session, **kwargs):
30
+ def call(self, session, *, lm, **kwargs):
30
31
  test.assertIs(session.current_invocation.action, self)
31
32
  session.info('Begin Bar')
33
+ session.query('bar', lm=lm)
32
34
  return 2
33
35
 
34
36
  class Foo(action_lib.Action):
35
37
  x: int
36
38
 
37
- def call(self, session, **kwargs):
39
+ def call(self, session, *, lm, **kwargs):
38
40
  test.assertIs(session.current_invocation.action, self)
39
41
  session.info('Begin Foo', x=1)
40
- return self.x + Bar()(session)
42
+ session.query('foo', lm=lm)
43
+ return self.x + Bar()(session, lm=lm)
41
44
 
45
+ lm = fake.StaticResponse('lm response')
42
46
  session = action_lib.Session()
43
47
  root = session.root_invocation
44
48
  self.assertIsInstance(root.action, action_lib.RootAction)
45
49
  self.assertIs(session.current_invocation, session.root_invocation)
46
- self.assertEqual(Foo(1)(session), 3)
50
+ self.assertEqual(Foo(1)(session, lm=lm), 3)
47
51
  self.assertEqual(len(session.root_invocation.child_invocations), 1)
52
+ self.assertEqual(len(list(session.root_invocation.queries())), 0)
53
+ self.assertEqual(
54
+ len(list(session.root_invocation.queries(include_children=True))), 2
55
+ )
56
+ self.assertEqual(
57
+ len(list(session.root_invocation.child_invocations[0].queries())), 1
58
+ )
48
59
  self.assertEqual(len(session.root_invocation.child_invocations[0].logs), 1)
49
60
  self.assertEqual(
50
61
  len(session.root_invocation.child_invocations[0].child_invocations),
@@ -55,6 +66,11 @@ class SessionTest(unittest.TestCase):
55
66
  .child_invocations[0].child_invocations[0].logs),
56
67
  1
57
68
  )
69
+ self.assertEqual(
70
+ len(list(session.root_invocation
71
+ .child_invocations[0].child_invocations[0].queries())),
72
+ 1
73
+ )
58
74
  self.assertEqual(
59
75
  len(session.root_invocation
60
76
  .child_invocations[0].child_invocations[0].child_invocations),
@@ -261,6 +261,9 @@ class RunnerBase(Runner):
261
261
  if cache is not None:
262
262
  self.background_run(cache.save)
263
263
 
264
+ # Wait for the background tasks to finish.
265
+ self._io_pool.shutdown(wait=True)
266
+
264
267
  @abc.abstractmethod
265
268
  def _run(self, evaluations: list[Evaluation]) -> None:
266
269
  """Runs multiple evaluations."""
@@ -69,6 +69,8 @@ from langfun.core.structured.prompting import query
69
69
  from langfun.core.structured.prompting import query_prompt
70
70
  from langfun.core.structured.prompting import query_output
71
71
  from langfun.core.structured.prompting import query_reward
72
+ from langfun.core.structured.prompting import QueryInvocation
73
+ from langfun.core.structured.prompting import track_queries
72
74
 
73
75
  from langfun.core.structured.description import DescribeStructure
74
76
  from langfun.core.structured.description import describe
@@ -16,7 +16,7 @@
16
16
  import functools
17
17
  import inspect
18
18
  import re
19
- from typing import Any, Callable, Optional, Tuple
19
+ from typing import Any, Callable, Literal, Optional, Tuple
20
20
 
21
21
  from langfun.core import language_model
22
22
  from langfun.core import template
@@ -25,7 +25,7 @@ from langfun.core.structured import prompting
25
25
  import pyglove as pg
26
26
 
27
27
 
28
- def unittest_gen(signature, lm, num_retries=10):
28
+ def unittest_gen(signature, lm, num_retries=1):
29
29
  """Generates unit tests for a python function signature."""
30
30
 
31
31
  class UnitTest(pg.Object):
@@ -78,10 +78,13 @@ def _function_gen(
78
78
  func: Callable[..., Any],
79
79
  signature: str,
80
80
  lm: language_model.LanguageModel,
81
- num_retries: int = 10,
81
+ num_retries: int = 1,
82
82
  unittest: Optional[
83
- Callable[[Callable[..., Any]], None] | list[Tuple[Any, Any]]
83
+ Callable[[Callable[..., Any]], None]
84
+ | list[Tuple[Any, Any]]
85
+ | Literal["auto"]
84
86
  ] = None,
87
+ unittest_num_retries: int = 1,
85
88
  ):
86
89
  """Generates a python function with LLM and verify its quality with unit testing."""
87
90
 
@@ -131,9 +134,11 @@ def _function_gen(
131
134
  """
132
135
 
133
136
  unittest_examples = None
134
- if unittest is None:
135
- unittest_examples = unittest_gen(signature, lm=lm)
136
- elif not callable(unittest):
137
+ if unittest == "auto":
138
+ unittest_examples = unittest_gen(
139
+ signature, lm=lm, num_retries=unittest_num_retries
140
+ )
141
+ elif isinstance(unittest, list):
137
142
  unittest_examples = unittest
138
143
 
139
144
  for _ in range(num_retries):
@@ -145,11 +150,16 @@ def _function_gen(
145
150
 
146
151
  # Check whether the signatures are the same.
147
152
  if inspect.signature(f) != inspect.signature(func):
153
+ pg.logging.warning(
154
+ "Signature mismatch. Expected: %s, Actual: %s",
155
+ inspect.signature(func),
156
+ inspect.signature(f),
157
+ )
148
158
  continue
149
159
 
150
160
  if callable(unittest):
151
161
  unittest(f)
152
- else:
162
+ elif unittest_examples:
153
163
  unittest_with_test_cases(f, unittest_examples)
154
164
 
155
165
  return f, source_code
@@ -172,10 +182,13 @@ def _process_signature(signature):
172
182
  def function_gen(
173
183
  lm: language_model.LanguageModel,
174
184
  cache_filename: str | None = None,
175
- num_retries: int = 10,
185
+ num_retries: int = 1,
176
186
  unittest: Optional[
177
- Callable[[Callable[..., Any]], None] | list[Tuple[Any, Any]]
187
+ Callable[[Callable[..., Any]], None]
188
+ | list[Tuple[Any, Any]]
189
+ | Literal["auto"]
178
190
  ] = None,
191
+ unittest_num_retries: int = 1,
179
192
  ):
180
193
  """A decorator for automating function generation using a language model.
181
194
 
@@ -192,9 +205,12 @@ def function_gen(
192
205
  make to generate a suitable function implementation.
193
206
  unittest: This optional parameter enables the definition of custom unit
194
207
  tests. You can either provide a list of test cases as tuples of inputs
195
- and outputs, or a function that throws an error if a test fails. If left
196
- as None (the default setting), the LLM will automatically create the
197
- unit test cases.
208
+ and outputs, or a function that throws an error if a test fails, or let
209
+ LLM automatically create the unit test cases. If a generated function is
210
+ returned, it should pass all the unittests.
211
+ unittest_num_retries: If unittest is set to "auto", this parameter
212
+ specifies the number of times the LLM attempts to generate unit test
213
+ cases.
198
214
 
199
215
  Returns:
200
216
  The implemented function object.
@@ -226,7 +242,12 @@ def function_gen(
226
242
  return func.__function__(*args, **kwargs)
227
243
 
228
244
  func.__function__, func.__source_code__ = _function_gen(
229
- func, signature, lm, num_retries=num_retries, unittest=unittest
245
+ func,
246
+ signature,
247
+ lm,
248
+ num_retries=num_retries,
249
+ unittest=unittest,
250
+ unittest_num_retries=unittest_num_retries,
230
251
  )
231
252
  if func.__function__ is None:
232
253
  raise ValueError(f"Function generation failed. Signature:\n{signature}")
@@ -63,6 +63,42 @@ class FunctionGenerationTest(unittest.TestCase):
63
63
 
64
64
  lm = fake.StaticSequence([unittest_lm_response, function_gen_lm_response])
65
65
 
66
+ @function_generation.function_gen(lm=lm, unittest='auto')
67
+ def linear_search(items, target): # pylint: disable=unused-argument
68
+ """Performs a linear search on a list to find a target value.
69
+
70
+ Args:
71
+ items (list): The list to search within.
72
+ target: The value to search for.
73
+
74
+ Returns:
75
+ int: The index of the target value if found, otherwise -1.
76
+ """
77
+
78
+ self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
79
+ self.assertEqual(linear_search.source(), function_gen_lm_response)
80
+
81
+ def test_generate_function_without_unittest(self):
82
+ function_gen_lm_response = inspect.cleandoc("""
83
+ def linear_search(items, target):
84
+ \"\"\"
85
+ Performs a linear search on a list to find a target value.
86
+
87
+ Args:
88
+ items (list): The list to search within.
89
+ target: The value to search for.
90
+
91
+ Returns:
92
+ int: The index of the target value if found, otherwise -1.
93
+ \"\"\"
94
+ for i, item in enumerate(items):
95
+ if item == target:
96
+ return i
97
+ return -1
98
+ """)
99
+
100
+ lm = fake.StaticSequence([function_gen_lm_response])
101
+
66
102
  @function_generation.function_gen(lm=lm)
67
103
  def linear_search(items, target): # pylint: disable=unused-argument
68
104
  """Performs a linear search on a list to find a target value.
@@ -258,7 +294,9 @@ class FunctionGenerationTest(unittest.TestCase):
258
294
  cache_file = os.path.join(cache_file_dir, 'cache_file.json')
259
295
 
260
296
  @function_generation.function_gen(
261
- lm=lm, unittest=_unittest_fn, cache_filename=cache_file
297
+ lm=lm,
298
+ unittest=_unittest_fn,
299
+ cache_filename=cache_file,
262
300
  )
263
301
  def linear_search(items, target): # pylint: disable=unused-argument
264
302
  """Performs a linear search on a list to find a target value.
@@ -310,7 +348,9 @@ class FunctionGenerationTest(unittest.TestCase):
310
348
 
311
349
  custom_unittest = _unittest_fn
312
350
 
313
- @function_generation.function_gen(lm=lm, unittest=custom_unittest)
351
+ @function_generation.function_gen(
352
+ lm=lm, unittest=custom_unittest, num_retries=2
353
+ )
314
354
  def linear_search(items, target): # pylint: disable=unused-argument
315
355
  """Performs a linear search on a list to find a target value.
316
356
 
@@ -13,8 +13,9 @@
13
13
  # limitations under the License.
14
14
  """Symbolic query."""
15
15
 
16
+ import contextlib
16
17
  import functools
17
- from typing import Any, Callable, Type, Union
18
+ from typing import Annotated, Any, Callable, Iterator, Type, Union
18
19
 
19
20
  import langfun.core as lf
20
21
  from langfun.core.llms import fake
@@ -102,7 +103,7 @@ def _query_structure_cls(
102
103
 
103
104
 
104
105
  def query(
105
- prompt: Union[str, pg.Symbolic],
106
+ prompt: Union[str, lf.Template, Any],
106
107
  schema: Union[
107
108
  schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
108
109
  ] = None,
@@ -119,7 +120,7 @@ def query(
119
120
  skip_lm: bool = False,
120
121
  **kwargs,
121
122
  ) -> Any:
122
- """Parse a natural langugage message based on schema.
123
+ """Queries a language model for a (maybe) structured output.
123
124
 
124
125
  Examples:
125
126
 
@@ -189,59 +190,93 @@ def query(
189
190
  """
190
191
  # Internal usage logging.
191
192
 
193
+ # Normalize query schema.
192
194
  # When `lf.query` is used for symbolic completion, schema is automatically
193
195
  # inferred when it is None.
194
196
  if isinstance(prompt, pg.Symbolic) and prompt.sym_partial and schema is None:
195
197
  schema = prompt.__class__
196
198
 
197
- # Create a copy of the prompt if it has a parent object, so all child modality
198
- # objects could be referred by path relative to the prompt.
199
- if isinstance(prompt, lf.Template) and prompt.sym_parent:
200
- prompt = prompt.clone()
201
-
202
- if schema in (None, str):
203
- # Query with natural language output.
204
- output = lf.LangFunc.from_value(prompt, **kwargs)(
205
- lm=lm, cache_seed=cache_seed, skip_lm=skip_lm
199
+ # Normalize query input.
200
+ if isinstance(prompt, (lf.Message, str)):
201
+ # Normalize a string or message prompt into a template.
202
+ prompt_kwargs = kwargs.copy()
203
+ prompt_kwargs.pop('template_str', None)
204
+ query_input = lf.Template.from_value(prompt, **prompt_kwargs)
205
+ elif isinstance(prompt, lf.Template):
206
+ # Create a copy of the prompt if it has a parent object, so all child
207
+ # modality objects could be referred by path relative to the prompt.
208
+ query_input = prompt.clone() if prompt.sym_parent is not None else prompt
209
+
210
+ # Attach template metadata from kwargs. This is used to pass through fields
211
+ # from kwargs to the rendered message.
212
+ template_metadata = {
213
+ k: v for k, v in kwargs.items() if k.startswith('metadata_')
214
+ }
215
+ query_input.rebind(
216
+ template_metadata, skip_notification=True, raise_on_no_change=False
206
217
  )
207
- if response_postprocess:
208
- processed_text = response_postprocess(output.text)
209
- if processed_text != output.text:
210
- output = lf.AIMessage(processed_text, source=output)
211
- return output if returns_message else output.text
212
-
213
- # Query with structured output.
214
- prompt_kwargs = kwargs.copy()
215
-
216
- # NOTE(daiyip): when `template_str` is passed in, it's intended to modify the
217
- # QueryStructure template string. Therefore, we pop out the argument for
218
- # prompt rendering.
219
- prompt_kwargs.pop('template_str', None)
220
-
221
- if isinstance(prompt, (str, lf.Message, lf.Template)):
222
- prompt = lf.Template.from_value(prompt, **prompt_kwargs).render(lm=lm)
218
+ elif pg.MISSING_VALUE == prompt:
219
+ query_input = lf.UserMessage('')
223
220
  else:
224
- prompt = schema_lib.mark_missing(prompt)
225
-
226
- output = _query_structure_cls(protocol)(
227
- input=prompt,
228
- schema=schema,
229
- default=default,
230
- examples=examples,
231
- response_postprocess=response_postprocess,
232
- autofix=autofix if protocol == 'python' else 0,
233
- **kwargs,
234
- )(
235
- lm=lm,
236
- autofix_lm=autofix_lm or lm,
237
- cache_seed=cache_seed,
238
- skip_lm=skip_lm,
239
- )
240
- return output if returns_message else output.result
221
+ query_input = schema_lib.mark_missing(prompt)
222
+
223
+ with lf.track_usages() as usage_summary:
224
+ if schema in (None, str):
225
+ # Query with natural language output.
226
+ output_message = lf.LangFunc.from_value(query_input, **kwargs)(
227
+ lm=lm, cache_seed=cache_seed, skip_lm=skip_lm
228
+ )
229
+ if response_postprocess:
230
+ processed_text = response_postprocess(output_message.text)
231
+ if processed_text != output_message.text:
232
+ output_message = lf.AIMessage(processed_text, source=output_message)
233
+ else:
234
+ # Query with structured output.
235
+ output_message = _query_structure_cls(protocol)(
236
+ input=(
237
+ query_input.render(lm=lm)
238
+ if isinstance(query_input, lf.Template)
239
+ else query_input
240
+ ),
241
+ schema=schema,
242
+ default=default,
243
+ examples=examples,
244
+ response_postprocess=response_postprocess,
245
+ autofix=autofix if protocol == 'python' else 0,
246
+ **kwargs,
247
+ )(
248
+ lm=lm,
249
+ autofix_lm=autofix_lm or lm,
250
+ cache_seed=cache_seed,
251
+ skip_lm=skip_lm,
252
+ )
253
+
254
+ def _result(message: lf.Message):
255
+ return message.text if schema in (None, str) else message.result
256
+
257
+ # Track the query invocations.
258
+ if pg.MISSING_VALUE != prompt and not skip_lm:
259
+ trackers = lf.context_value('__query_trackers__', [])
260
+ if trackers:
261
+ invocation = QueryInvocation(
262
+ input=pg.Ref(query_input),
263
+ schema=(
264
+ schema_lib.Schema.from_value(schema)
265
+ if schema not in (None, str) else None
266
+ ),
267
+ output=pg.Ref(_result(output_message)),
268
+ lm=pg.Ref(lm),
269
+ examples=pg.Ref(examples) if examples else [],
270
+ usage_summary=usage_summary,
271
+ )
272
+ for i, (tracker, include_child_scopes) in enumerate(trackers):
273
+ if i == 0 or include_child_scopes:
274
+ tracker.append(invocation)
275
+ return output_message if returns_message else _result(output_message)
241
276
 
242
277
 
243
278
  def query_prompt(
244
- prompt: Union[str, pg.Symbolic],
279
+ prompt: Union[str, lf.Template, Any],
245
280
  schema: Union[
246
281
  schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
247
282
  ] = None,
@@ -264,7 +299,7 @@ def query_output(
264
299
  kwargs.pop('prompt', None)
265
300
  kwargs.pop('lm', None)
266
301
  return query(
267
- 'Unused prompt', schema, lm=fake.StaticResponse(response), **kwargs
302
+ pg.MISSING_VALUE, schema, lm=fake.StaticResponse(response), **kwargs
268
303
  )
269
304
 
270
305
 
@@ -320,3 +355,69 @@ def _reward_fn(cls) -> Callable[
320
355
  args = [self, input, expected_output, metadata]
321
356
  return cls.__reward__(*args[:num_args])
322
357
  return _reward
358
+
359
+
360
+ class QueryInvocation(pg.Object):
361
+ """A class to represent the invocation of `lf.query`."""
362
+
363
+ input: Annotated[
364
+ Union[lf.Template, pg.Symbolic],
365
+ 'Mapping input of `lf.query`.'
366
+ ]
367
+ schema: pg.typing.Annotated[
368
+ schema_lib.schema_spec(noneable=True),
369
+ 'Schema of `lf.query`.'
370
+ ]
371
+ output: Annotated[
372
+ Any,
373
+ 'Mapping output of `lf.query`.'
374
+ ]
375
+ lm: Annotated[
376
+ lf.LanguageModel,
377
+ 'Language model used for `lf.query`.'
378
+ ]
379
+ examples: Annotated[
380
+ list[mapping.MappingExample],
381
+ 'Fewshot exemplars for `lf.query`.'
382
+ ]
383
+ usage_summary: Annotated[
384
+ lf.UsageSummary,
385
+ 'Usage summary for `lf.query`.'
386
+ ]
387
+
388
+
389
+ @contextlib.contextmanager
390
+ def track_queries(
391
+ include_child_scopes: bool = True
392
+ ) -> Iterator[list[QueryInvocation]]:
393
+ """Track all queries made during the context.
394
+
395
+ Example:
396
+
397
+ ```
398
+ with lf.track_queries() as queries:
399
+ lf.query('hi', lm=lm)
400
+ lf.query('What is this {{image}}?', lm=lm, image=image)
401
+
402
+ print(queries)
403
+ ```
404
+
405
+ Args:
406
+ include_child_scopes: If True, the queries made in child scopes will be
407
+ included in the returned list. Otherwise, only the queries made in the
408
+ current scope will be included.
409
+
410
+ Yields:
411
+ A list of `QueryInvocation` objects representing the queries made during
412
+ the context.
413
+ """
414
+ trackers = lf.context_value('__query_trackers__', [])
415
+ tracker = []
416
+
417
+ with lf.context(
418
+ __query_trackers__=[(tracker, include_child_scopes)] + trackers
419
+ ):
420
+ try:
421
+ yield tracker
422
+ finally:
423
+ pass
@@ -89,7 +89,7 @@ class QueryTest(unittest.TestCase):
89
89
  )
90
90
  self.assertEqual(
91
91
  prompting.query(
92
- lf.Template('what is {{x}} + {{y}}'), int, x=1, y=0, lm=lm.clone()
92
+ lf.Template('what is {{x}} + {{y}}', x=1, y=0), int, lm=lm.clone()
93
93
  ),
94
94
  1,
95
95
  )
@@ -365,6 +365,23 @@ class QueryTest(unittest.TestCase):
365
365
  """),
366
366
  )
367
367
 
368
+ def test_query_prompt_with_metadata(self):
369
+ self.assertIn(
370
+ 'x',
371
+ prompting.query_prompt(
372
+ 'what is this?',
373
+ metadata_x=1
374
+ ).metadata
375
+ )
376
+ self.assertIn(
377
+ 'x',
378
+ prompting.query_prompt(
379
+ 'what is this?',
380
+ int,
381
+ metadata_x=1
382
+ ).metadata
383
+ )
384
+
368
385
  def test_query_prompt_with_unrooted_template(self):
369
386
  output = prompting.query_prompt(
370
387
  pg.Dict(
@@ -945,5 +962,71 @@ class QueryStructureJsonTest(unittest.TestCase):
945
962
  )
946
963
 
947
964
 
965
+ class TrackQueriesTest(unittest.TestCase):
966
+
967
+ def test_include_child_scopes(self):
968
+ lm = fake.StaticSequence([
969
+ 'bar',
970
+ 'Activity(description="hi")',
971
+ ])
972
+ with prompting.track_queries() as queries:
973
+ prompting.query('foo', lm=lm)
974
+ with prompting.track_queries() as child_queries:
975
+ prompting.query('give me an activity', Activity, lm=lm)
976
+
977
+ self.assertEqual(len(queries), 2)
978
+ self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
979
+ self.assertIsNone(queries[0].schema)
980
+ self.assertEqual(queries[0].output, 'bar')
981
+ self.assertIs(queries[0].lm, lm)
982
+
983
+ self.assertTrue(pg.eq(queries[1].input, lf.Template('give me an activity')))
984
+ self.assertEqual(queries[1].schema.spec.cls, Activity)
985
+ self.assertTrue(pg.eq(queries[1].output, Activity(description='hi')))
986
+ self.assertIs(queries[1].lm, lm)
987
+ self.assertGreater(queries[0].usage_summary.total.total_tokens, 0)
988
+ self.assertGreater(queries[1].usage_summary.total.total_tokens, 0)
989
+
990
+ self.assertEqual(len(child_queries), 1)
991
+ self.assertIs(child_queries[0], queries[1])
992
+
993
+ def test_exclude_child_scopes(self):
994
+ lm = fake.StaticSequence([
995
+ 'bar',
996
+ 'Activity(description="hi")',
997
+ ])
998
+ with prompting.track_queries(include_child_scopes=False) as queries:
999
+ prompting.query('foo', lm=lm)
1000
+ with prompting.track_queries(include_child_scopes=False) as child_queries:
1001
+ prompting.query('give me an activity', Activity, lm=lm)
1002
+
1003
+ self.assertEqual(len(queries), 1)
1004
+ self.assertTrue(pg.eq(queries[0].input, lf.Template('foo')))
1005
+ self.assertIsNone(queries[0].schema)
1006
+ self.assertEqual(queries[0].output, 'bar')
1007
+ self.assertIs(queries[0].lm, lm)
1008
+
1009
+ self.assertEqual(len(child_queries), 1)
1010
+ self.assertTrue(
1011
+ pg.eq(child_queries[0].input, lf.Template('give me an activity'))
1012
+ )
1013
+ self.assertEqual(child_queries[0].schema.spec.cls, Activity)
1014
+ self.assertTrue(pg.eq(child_queries[0].output, Activity(description='hi')))
1015
+ self.assertIs(child_queries[0].lm, lm)
1016
+
1017
+ def test_concurrent_map(self):
1018
+
1019
+ def make_query(prompt):
1020
+ _ = prompting.query(prompt, lm=lm)
1021
+
1022
+ lm = fake.StaticSequence([
1023
+ 'foo',
1024
+ 'bar',
1025
+ ])
1026
+ with prompting.track_queries() as queries:
1027
+ list(lf.concurrent_map(make_query, ['a', 'b']))
1028
+ self.assertEqual(len(queries), 2)
1029
+
1030
+
948
1031
  if __name__ == '__main__':
949
1032
  unittest.main()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.1.2.dev202412030804
3
+ Version: 0.1.2.dev202412070804
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -1,4 +1,4 @@
1
- langfun/__init__.py,sha256=o_HvoQggla5uqNA7uF1126aZhayHnVNP__nd_t5ElEQ,2358
1
+ langfun/__init__.py,sha256=Xkzi8VV93jI7iLSrypzyHV9FtRbxHQpmoUCIZJSGHdA,2400
2
2
  langfun/core/__init__.py,sha256=xlvFTXc7IKUTs8aCFRFhzOLTmmeuhXgk9yx2InBLNiA,4937
3
3
  langfun/core/component.py,sha256=HVrEoTL1Y01iqOHC3FYdbAOnffqfHHtGJXoK1vkdEwo,11583
4
4
  langfun/core/component_test.py,sha256=sG-T2wpvBfHqWGZE7sc4NayJj2aj5QFBzSwFiwrGEIc,10376
@@ -30,10 +30,10 @@ langfun/core/template_test.py,sha256=Qokz1hQFhRYaTZWBWGqvPJ0NXC9B9ennUpnRYHEf0hE
30
30
  langfun/core/text_formatting.py,sha256=d7t9vaY6aCn1dkfkikpNYnBy5E_i93vHbfyDWFclGZU,5284
31
31
  langfun/core/text_formatting_test.py,sha256=ck0Xzdd4YF4CtCUj7VE0GybfbAyKQ8p3xkM1FBGrqIk,2096
32
32
  langfun/core/agentic/__init__.py,sha256=ndoDX0sAYsa3eVdXuu6nB-a-BH5TaK3urW6zAaFiyVs,1110
33
- langfun/core/agentic/action.py,sha256=Am5E1EH1ZBAhzagbnDVRnR4vBzI4H6MEtQ58laSPfTg,7515
33
+ langfun/core/agentic/action.py,sha256=lsEltCSrPag5GOvAeaakf_3iil28tKZJdN-NrovqQDw,8954
34
34
  langfun/core/agentic/action_eval.py,sha256=ZtjTh34S7XPIUqandQ0YwAtzw-S7ofuZ7rRXnRbUMdQ,4424
35
35
  langfun/core/agentic/action_eval_test.py,sha256=tRUkWmOE9p0rpNOq19xAY2oDEnYsEEykjg6sUpAwJk0,2832
36
- langfun/core/agentic/action_test.py,sha256=CBsUQICD8yPCDUBBFouSkZuyLAcK_C-AWYc28Zts10E,2624
36
+ langfun/core/agentic/action_test.py,sha256=K6Ynop1zthRYMd_6Y4tpv3TFZRJAbNxEAwJf5dKlU5A,3235
37
37
  langfun/core/coding/__init__.py,sha256=5utju_fwEsImaiftx4oXKl9FAM8p281k8-Esdh_-m1w,835
38
38
  langfun/core/coding/python/__init__.py,sha256=MJ-vubliz-ebrZH3OBRKBwMi0S9-FrhGCp8YQLR6_I4,1776
39
39
  langfun/core/coding/python/correction.py,sha256=WiBdoScL-6C___iA3Tg3vizuYtJWI-_4wy9zcMfVpj8,7020
@@ -76,7 +76,7 @@ langfun/core/eval/v2/progress_tracking.py,sha256=l9fEkz4oP5McpZzf72Ua7PYm3lAWtRr
76
76
  langfun/core/eval/v2/progress_tracking_test.py,sha256=iO-DslCJWncU7-27XaMKxDeKrsGbwdk_tKfoRk3KboE,2271
77
77
  langfun/core/eval/v2/reporting.py,sha256=TGkli1IDwqfqsCJ_WslOMGk_24JDg7oRRTGXlAJlWpc,4361
78
78
  langfun/core/eval/v2/reporting_test.py,sha256=JxffbUPWInUyLjo-AQVFrllga884Mdfm05R86FtxSss,1482
79
- langfun/core/eval/v2/runners.py,sha256=zJmu-amUiYv1g0Ek4c3mXkBgp-AFvSF7WpXVZCCf7Y4,14245
79
+ langfun/core/eval/v2/runners.py,sha256=kP6ZEg9L8M7fK03tOZYGqIjTKUzoJn8Hz_LXS7btFPQ,14335
80
80
  langfun/core/eval/v2/runners_test.py,sha256=UeiUNygux_U6iGVG18rhp68ZE4hoWeoT6XsXvSjxNQg,11620
81
81
  langfun/core/eval/v2/test_helper.py,sha256=pDpZTBnWRR5xjJv3Uy3NWEzArqlL8FTMOgeR4C53F5M,2348
82
82
  langfun/core/llms/__init__.py,sha256=C4hyLflqOQT841nMfclcxcnOhdP83zR0GGW29PnA-vU,6216
@@ -118,19 +118,19 @@ langfun/core/modalities/pdf.py,sha256=mfaeCbUA4JslFVTARiJh8hW7imvL4tLVw9gUhO5bAZ
118
118
  langfun/core/modalities/pdf_test.py,sha256=ulZ0FbnlsU0wkrdckJ4ONZPTYRyMPO9Aob1UO6FXygk,1950
119
119
  langfun/core/modalities/video.py,sha256=vI9apcHIHGyp90i34Srg7S3G6IBDtDCk8qiXhwRQmkw,967
120
120
  langfun/core/modalities/video_test.py,sha256=7OXZoohKMYjt7vrJUdPb553HLyl1oBOKRgzBePFv68Q,2042
121
- langfun/core/structured/__init__.py,sha256=7EgI6pQIWnSNxhcIawBcSRDR8GPq5ytcmGxgPEjbxeA,3894
121
+ langfun/core/structured/__init__.py,sha256=YGyGN-6gcGpzo1Hh-kpPFvC-dPYayjx7NRn06tyAdXE,4016
122
122
  langfun/core/structured/completion.py,sha256=cS2PjG7sqzDu5x0xoTk8RmNcoeX55iVwH38NTefkMHg,8108
123
123
  langfun/core/structured/completion_test.py,sha256=lendf6nPsNfAmd5A7k3v_HS2At9F_jjbKBcV7OEt94o,19310
124
124
  langfun/core/structured/description.py,sha256=SXW4MJvshFjbR-0gw6rE21o6WXq12UlRXawvDBXMZFA,5211
125
125
  langfun/core/structured/description_test.py,sha256=UtZGjSFUaQ6130t1E5tcL7ODu0xIefkapb53TbnqsK8,7362
126
- langfun/core/structured/function_generation.py,sha256=pFgS3vcRAWiuFBol2x5Eeip3XqoudONsOpeJpWyjT3s,7479
127
- langfun/core/structured/function_generation_test.py,sha256=ZJI-aaGgWWszn92u7h5IZ9Pl70N2DgAGGJrIxPzsvwg,10065
126
+ langfun/core/structured/function_generation.py,sha256=gOV5B4KXzN6ng1P1QtZ8aOAEQB8eAbgwWGj57tnzWJY,8159
127
+ langfun/core/structured/function_generation_test.py,sha256=1OtstouOYyYOd_gmZtL8RRbh-FcYGEvBNju6lNrJrOA,11331
128
128
  langfun/core/structured/mapping.py,sha256=vLKH79UT-j0qkQdvqlQBO7SkXXuM-yr2Idm8_HH8qwM,13649
129
129
  langfun/core/structured/mapping_test.py,sha256=bHm2ZCXBITq_G8Lvw_olFHeUUc4s_lGXZm9v9JhoPB4,9630
130
130
  langfun/core/structured/parsing.py,sha256=D58wBWOC6r6DCJNychCDkiHPrsy1XJfBDCDDZtug00k,11765
131
131
  langfun/core/structured/parsing_test.py,sha256=i0i090FVgM8ngGqYjds0hjEm1v7q4gv18k-z1kaNr7E,21467
132
- langfun/core/structured/prompting.py,sha256=huwwh01AQQCwPBQESOMI_V1V5PZkVQ8C89Yjk67_4Uw,10677
133
- langfun/core/structured/prompting_test.py,sha256=pviyb8yTnxkWPAZodLIlQT8y2ScE6FfSHKWf1NUtV-Y,26718
132
+ langfun/core/structured/prompting.py,sha256=R9tfitDCBsQ725lzrSfaVgLi7FdtArWRStEMkzmJWQU,13698
133
+ langfun/core/structured/prompting_test.py,sha256=B0D70JmWgFYjRN9wfoSRX8zn0vdCUCJWk-igb59K0WY,29421
134
134
  langfun/core/structured/schema.py,sha256=XHA-m_ENT_J0k8Q7WCiCL51xm7oXHOqskhO8RpPIurc,28174
135
135
  langfun/core/structured/schema_generation.py,sha256=U3nRQsqmMZg_qIVDh2fiY3K4JLfsAL1LcKzIFP1iXFg,5316
136
136
  langfun/core/structured/schema_generation_test.py,sha256=RM9s71kMNg2jTePwInkiW9fK1ACN37eyPeF8OII-0zw,2950
@@ -148,8 +148,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
148
148
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
149
149
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
150
150
  langfun/core/templates/selfplay_test.py,sha256=Ot__1P1M8oJfoTp-M9-PQ6HUXqZKyMwvZ5f7yQ3yfyM,2326
151
- langfun-0.1.2.dev202412030804.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
152
- langfun-0.1.2.dev202412030804.dist-info/METADATA,sha256=BG5zoGi1sstPc97R4aTE51aiVfa02LLq-BRHRb8R_3Q,8281
153
- langfun-0.1.2.dev202412030804.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
154
- langfun-0.1.2.dev202412030804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
155
- langfun-0.1.2.dev202412030804.dist-info/RECORD,,
151
+ langfun-0.1.2.dev202412070804.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
152
+ langfun-0.1.2.dev202412070804.dist-info/METADATA,sha256=Ao9kzm8AFw7749nP7p_m3k41rVcBBzG8_QuMnLLN9_U,8281
153
+ langfun-0.1.2.dev202412070804.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
154
+ langfun-0.1.2.dev202412070804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
155
+ langfun-0.1.2.dev202412070804.dist-info/RECORD,,