langfun 0.1.2.dev202412170805__py3-none-any.whl → 0.1.2.dev202412190804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langfun/__init__.py CHANGED
@@ -37,6 +37,9 @@ generate_class = structured.generate_class
37
37
 
38
38
  track_queries = structured.track_queries
39
39
 
40
+ # Helper function for map-reduce style querying.
41
+ query_and_reduce = structured.query_and_reduce
42
+
40
43
  # Helper functions for input/output transformations based on
41
44
  # `lf.query` (e.g. jax-on-beam could use these for batch processing)
42
45
  query_prompt = structured.query_prompt
@@ -92,7 +92,7 @@ class PerExampleCheckpointer(Checkpointer):
92
92
  )
93
93
  )
94
94
  writer.add(example)
95
- del writer
95
+ writer.close()
96
96
  runner.background_run(save_state, example)
97
97
 
98
98
  def _file_prefix_and_ext(self, filename: str) -> tuple[str, str]:
@@ -128,6 +128,8 @@ class BulkCheckpointer(Checkpointer):
128
128
  ) -> None:
129
129
  with self._lock:
130
130
  if self._sequence_writer is not None:
131
+ for writer in self._sequence_writer.values():
132
+ writer.close()
131
133
  self._sequence_writer.clear()
132
134
 
133
135
  def on_run_complete(
@@ -174,6 +176,9 @@ class BulkCheckpointer(Checkpointer):
174
176
  assert experiment.id in self._sequence_writer
175
177
  with self._lock:
176
178
  if self._sequence_writer is not None:
179
+ # Make sure the writer is closed without delay so the file will be
180
+ # available immediately.
181
+ self._sequence_writer[experiment.id].close()
177
182
  del self._sequence_writer[experiment.id]
178
183
 
179
184
  def on_example_complete(
@@ -207,9 +212,13 @@ class SequenceWriter:
207
212
  return
208
213
  self._sequence_writer.add(example_blob)
209
214
 
210
- def __del__(self):
215
+ def close(self):
211
216
  # Make sure there is no write in progress.
212
217
  with self._lock:
213
- assert self._sequence_writer is not None
218
+ if self._sequence_writer is None:
219
+ return
214
220
  self._sequence_writer.close()
215
221
  self._sequence_writer = None
222
+
223
+ def __del__(self):
224
+ self.close()
@@ -14,7 +14,9 @@
14
14
  """Base class for Langfun evaluation tasks."""
15
15
 
16
16
  import abc
17
+ import datetime
17
18
  import functools
19
+ import threading
18
20
  import time
19
21
 
20
22
  from typing import Annotated, Any, Callable, Iterable
@@ -63,6 +65,8 @@ class Evaluation(experiment_lib.Experiment):
63
65
  self.__dict__.pop('is_leaf', None)
64
66
  self.__dict__.pop('children', None)
65
67
  super()._on_bound()
68
+ self._log_entries = []
69
+ self._log_lock = threading.Lock()
66
70
 
67
71
  #
68
72
  # Handling evaluation hierarchy (materialized vs. hyper evaluations).
@@ -277,6 +281,41 @@ class Evaluation(experiment_lib.Experiment):
277
281
  for metric in self.metrics:
278
282
  metric.reset()
279
283
 
284
+ #
285
+ # Evaluation-level logging.
286
+ #
287
+
288
+ def _log(self, level: lf.logging.LogLevel, message: str, **kwargs):
289
+ with self._log_lock:
290
+ self._log_entries.append(
291
+ lf.logging.LogEntry(
292
+ level=level,
293
+ time=datetime.datetime.now(),
294
+ message=message,
295
+ metadata=kwargs,
296
+ )
297
+ )
298
+
299
+ def debug(self, message: str, **kwargs):
300
+ """Logs a debug message to the session."""
301
+ self._log('debug', message, **kwargs)
302
+
303
+ def info(self, message: str, **kwargs):
304
+ """Logs an info message to the session."""
305
+ self._log('info', message, **kwargs)
306
+
307
+ def warning(self, message: str, **kwargs):
308
+ """Logs a warning message to the session."""
309
+ self._log('warning', message, **kwargs)
310
+
311
+ def error(self, message: str, **kwargs):
312
+ """Logs an error message to the session."""
313
+ self._log('error', message, **kwargs)
314
+
315
+ def fatal(self, message: str, **kwargs):
316
+ """Logs a fatal message to the session."""
317
+ self._log('fatal', message, **kwargs)
318
+
280
319
  #
281
320
  # HTML views.
282
321
  #
@@ -465,6 +504,25 @@ class Evaluation(experiment_lib.Experiment):
465
504
  )
466
505
  )
467
506
 
507
+ def _logs_tab() -> pg.views.html.controls.Tab:
508
+ """Renders a tab for the logs of the evaluation."""
509
+ with self._log_lock:
510
+ log_history = '\n'.join(str(l) for l in self._log_entries)
511
+ return pg.views.html.controls.Tab(
512
+ label='Logs',
513
+ content=pg.Html.element(
514
+ 'div',
515
+ [
516
+ pg.Html.element(
517
+ 'textarea',
518
+ [pg.Html.escape(log_history)],
519
+ readonly=True,
520
+ css_classes=['logs-textarea'],
521
+ )
522
+ ]
523
+ )
524
+ )
525
+
468
526
  def _main_tabs() -> pg.Html:
469
527
  return pg.Html.element(
470
528
  'div',
@@ -474,6 +532,8 @@ class Evaluation(experiment_lib.Experiment):
474
532
  _definition_tab(),
475
533
  ] + [
476
534
  _metric_tab(m) for m in self.metrics
535
+ ] + [
536
+ _logs_tab()
477
537
  ],
478
538
  selected=1,
479
539
  )
@@ -593,6 +653,14 @@ class Evaluation(experiment_lib.Experiment):
593
653
  width:100%;
594
654
  height:100%;
595
655
  }
656
+ .logs-textarea {
657
+ width: 100%;
658
+ height: 500px;
659
+ padding: 5px;
660
+ border: 1px solid #DDD;
661
+ background-color: #EEE;
662
+ resize: vertical;
663
+ }
596
664
  """
597
665
  ]
598
666
 
@@ -615,6 +683,11 @@ class EvaluationState:
615
683
  assert isinstance(example, example_lib.Example), example
616
684
  self._evaluated_examples[example.id] = example
617
685
 
686
+ @property
687
+ def evaluated_examples(self) -> dict[int, example_lib.Example]:
688
+ """Returns the examples in the state."""
689
+ return self._evaluated_examples
690
+
618
691
  def get(self, example_id: int) -> example_lib.Example | None:
619
692
  """Returns the example with the given ID."""
620
693
  return self._evaluated_examples.get(example_id)
@@ -622,9 +695,3 @@ class EvaluationState:
622
695
  def update(self, example: example_lib.Example) -> None:
623
696
  """Updates the state with the given example."""
624
697
  self._evaluated_examples[example.id] = example
625
-
626
- @property
627
- def evaluated_examples(self) -> dict[int, example_lib.Example]:
628
- """Returns the examples in the state."""
629
- return self._evaluated_examples
630
-
@@ -133,6 +133,12 @@ class EvaluationTest(unittest.TestCase):
133
133
 
134
134
  def test_html_view(self):
135
135
  exp = test_helper.TestEvaluation()
136
+ exp.debug('debug message')
137
+ exp.info('info message')
138
+ exp.warning('warning message', x=1)
139
+ exp.error('error message', x=1)
140
+ exp.fatal('fatal message')
141
+
136
142
  self.assertIn(
137
143
  exp.id,
138
144
  exp.to_html(extra_flags=dict(card_view=True, current_run=None)).content
@@ -409,7 +409,7 @@ class VertexAIGemini2_0(VertexAI): # pylint: disable=invalid-name
409
409
  )
410
410
 
411
411
 
412
- class VertexAIGeminiFlash2_0Exp(VertexAI): # pylint: disable=invalid-name
412
+ class VertexAIGeminiFlash2_0Exp(VertexAIGemini2_0): # pylint: disable=invalid-name
413
413
  """Vertex AI Gemini 2.0 Flash model."""
414
414
 
415
415
  model = 'gemini-2.0-flash-exp'
langfun/core/logging.py CHANGED
@@ -54,6 +54,25 @@ class LogEntry(pg.Object, pg.views.HtmlTreeView.Extension):
54
54
  def should_output(self, min_log_level: LogLevel) -> bool:
55
55
  return _LOG_LEVELS.index(self.level) >= _LOG_LEVELS.index(min_log_level)
56
56
 
57
+ def format(self,
58
+ compact: bool = False,
59
+ verbose: bool = True,
60
+ root_indent: int = 0,
61
+ *,
62
+ text_format: bool = True,
63
+ **kwargs):
64
+ if text_format:
65
+ s = f"""{self.time.strftime('%H:%M:%S')} {self.level.upper()} - {self.message}"""
66
+ if self.metadata:
67
+ s += f' (metadata: {self.metadata!r})'
68
+ return s
69
+ return super().format(
70
+ compact=compact,
71
+ verbose=verbose,
72
+ root_indent=root_indent,
73
+ **kwargs
74
+ )
75
+
57
76
  def _html_tree_view_summary(
58
77
  self,
59
78
  view: pg.views.HtmlTreeView,
@@ -61,6 +61,25 @@ class LoggingTest(unittest.TestCase):
61
61
  print(actual)
62
62
  self.assertEqual(actual, expected)
63
63
 
64
+ def test_format(self):
65
+ time = datetime.datetime(2024, 10, 10, 12, 30, 45)
66
+ self.assertEqual(
67
+ str(
68
+ logging.LogEntry(
69
+ level='info', message='hello\nworld',
70
+ time=time, metadata=dict(x=1),
71
+ )
72
+ ),
73
+ '12:30:45 INFO - hello\nworld (metadata: {x=1})',
74
+ )
75
+ self.assertIn(
76
+ 'LogEntry(',
77
+ logging.LogEntry(
78
+ level='info', message='hello\nworld',
79
+ time=time, metadata=dict(x=1),
80
+ ).format(text_format=False),
81
+ )
82
+
64
83
  def test_html(self):
65
84
  time = datetime.datetime(2024, 10, 10, 12, 30, 45)
66
85
  self.assert_html_content(
@@ -56,6 +56,8 @@ from langfun.core.structured.parsing import call
56
56
  from langfun.core.structured.querying import track_queries
57
57
  from langfun.core.structured.querying import QueryInvocation
58
58
  from langfun.core.structured.querying import query
59
+ from langfun.core.structured.querying import query_and_reduce
60
+
59
61
  from langfun.core.structured.querying import query_prompt
60
62
  from langfun.core.structured.querying import query_output
61
63
  from langfun.core.structured.querying import query_reward
@@ -270,24 +270,31 @@ def call(
270
270
  if schema in (str, None):
271
271
  return lm_output if returns_message else lm_output.text
272
272
 
273
+ def _chain_nl_output_message(parsing_message: lf.Message):
274
+ """Chain the source of the parsed output to the LM output."""
275
+ parsing_message.root.source = lm_output
276
+ parsing_message.tag('parsing-lm-output')
277
+ parsing_message.lm_input.tag('parsing-lm-input')
278
+
273
279
  # Call `parsing_lm` for structured parsing.
274
- parsing_message = querying.query(
275
- lm_output.text,
276
- schema,
277
- examples=parsing_examples,
278
- lm=parsing_lm or lm,
279
- include_context=parsing_include_context,
280
- cache_seed=cache_seed,
281
- autofix=autofix,
282
- autofix_lm=autofix_lm or lm,
283
- protocol=protocol,
284
- returns_message=True,
285
- **kwargs,
286
- )
287
- # Chain the source of the parsed output to the LM output.
288
- parsing_message.root.source = lm_output
289
- parsing_message.tag('parsing-lm-output')
290
- parsing_message.lm_input.tag('parsing-lm-input')
280
+ try:
281
+ parsing_message = querying.query(
282
+ lm_output.text,
283
+ schema,
284
+ examples=parsing_examples,
285
+ lm=parsing_lm or lm,
286
+ include_context=parsing_include_context,
287
+ cache_seed=cache_seed,
288
+ autofix=autofix,
289
+ autofix_lm=autofix_lm or lm,
290
+ protocol=protocol,
291
+ returns_message=True,
292
+ **kwargs,
293
+ )
294
+ _chain_nl_output_message(parsing_message)
295
+ except mapping.MappingError as e:
296
+ _chain_nl_output_message(e.lm_response)
297
+ raise e
291
298
  return parsing_message if returns_message else parsing_message.result
292
299
 
293
300
 
@@ -686,6 +686,31 @@ class CallTest(unittest.TestCase):
686
686
  ],
687
687
  returns_message=True,
688
688
  )
689
+ self.assertIn('parsing-lm-output', output.tags)
690
+ self.assertIn('parsing-lm-input', output.source.tags)
691
+ self.assertEqual(output.root.text, 'Compute 1 + 2')
692
+
693
+ def test_call_with_parsing_message_chaining_on_parsing_error(self):
694
+ try:
695
+ output = parsing.call(
696
+ 'Compute 1 + 2',
697
+ int,
698
+ lm=fake.StaticSequence(['three']),
699
+ parsing_lm=fake.StaticSequence(['abc']),
700
+ parsing_examples=[
701
+ mapping.MappingExample(
702
+ context='Multiple four and five',
703
+ input='twenty',
704
+ schema=int,
705
+ output=20,
706
+ )
707
+ ],
708
+ returns_message=True,
709
+ )
710
+ except mapping.MappingError as e:
711
+ output = e.lm_response
712
+ self.assertIn('parsing-lm-output', output.tags)
713
+ self.assertIn('parsing-lm-input', output.source.tags)
689
714
  self.assertEqual(output.root.text, 'Compute 1 + 2')
690
715
 
691
716
  def test_call_with_autofix(self):
@@ -105,12 +105,11 @@ def _query_structure_cls(
105
105
 
106
106
  def query(
107
107
  prompt: Union[str, lf.Template, Any],
108
- schema: Union[
109
- schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
110
- ] = None,
108
+ schema: schema_lib.SchemaType | None = None,
111
109
  default: Any = lf.RAISE_IF_HAS_ERROR,
112
110
  *,
113
- lm: lf.LanguageModel | None = None,
111
+ lm: lf.LanguageModel | list[lf.LanguageModel] | None = None,
112
+ num_samples: int | list[int] = 1,
114
113
  examples: list[mapping.MappingExample] | None = None,
115
114
  cache_seed: int | None = 0,
116
115
  response_postprocess: Callable[[str], str] | None = None,
@@ -121,76 +120,207 @@ def query(
121
120
  skip_lm: bool = False,
122
121
  **kwargs,
123
122
  ) -> Any:
124
- """Queries an language model for a (maybe) structured output.
123
+ """Query one or more language models for structured or unstructured outputs.
124
+
125
+ This is the primary API in Langfun for interacting with language models,
126
+ supporting natural language prompts, structured inputs, and multiple advanced
127
+ features.
128
+
129
+ Key Features:
130
+
131
+ - **Input**: Accepts natural language strings, structured inputs (e.g.,
132
+ `pg.Object`), and templates (`lf.Template`) with modality objects.
133
+
134
+ - **Output**: Returns structured outputs when `schema` is specified;
135
+ otherwise, outputs raw natural language (as a string).
136
+
137
+ - **Few-shot examples**: Supports structured few-shot examples with the
138
+ `examples` argument.
139
+
140
+ - **Multi-LM fan-out**: Sends queries to multiple language models with
141
+ multiple samples in parallel, returning a list of outputs.
125
142
 
126
143
  Examples:
127
144
 
145
+ Case 1: Regular natural language-based LLM query:
146
+
147
+ ```
148
+ lf.query('1 + 1 = ?', lm=lf.llms.Gpt4Turbo())
149
+
150
+ # Output: '2'
151
+ ```
152
+
153
+ Case 2: Query with structured output.
154
+
155
+ ```
156
+ lf.query('1 + 1 = ?', int, lm=lf.llms.Gpt4Turbo())
157
+
158
+ # Output: 2
159
+ ```
160
+
161
+ Case 3: Query with structured input.
162
+
163
+ ```
164
+ class Sum(pg.Object):
165
+ a: int
166
+ b: int
167
+
168
+ lf.query(Sum(1, 1), int, lm=lf.llms.Gpt4Turbo())
169
+
170
+ # Output: 2
171
+ ```
172
+
173
+ Case 4: Query with input of mixed modalities.
174
+
128
175
  ```
129
- class FlightDuration:
130
- hours: int
131
- minutes: int
132
-
133
- class Flight(pg.Object):
134
- airline: str
135
- flight_number: str
136
- departure_airport_code: str
137
- arrival_airport_code: str
138
- departure_time: str
139
- arrival_time: str
140
- duration: FlightDuration
141
- stops: int
142
- price: float
143
-
144
- prompt = '''
145
- Information about flight UA2631.
146
- '''
147
-
148
- r = lf.query(prompt, Flight)
149
- assert isinstance(r, Flight)
150
- assert r.airline == 'United Airlines'
151
- assert r.departure_airport_code == 'SFO'
152
- assert r.duration.hour = 7
176
+ class Animal(pg.Object):
177
+ pass
178
+
179
+ class Dog(Animal):
180
+ pass
181
+
182
+ class Entity(pg.Object):
183
+ name: str
184
+
185
+ lf.query(
186
+ 'What is in this {{image}} and {{objects}}?',
187
+ list[Entity],
188
+ lm=lf.llms.Gpt4Turbo(),
189
+ image=lf.Image(path='/path/to/a/airplane.png'),
190
+ objects=[Dog()],
191
+ )
192
+
193
+ # Output: [Entity(name='airplane'), Entity(name='dog')]
194
+ ```
195
+
196
+ Case 5: Query with structured few-shot examples.
197
+ ```
198
+ lf.query(
199
+ 'What is in this {{image}} and {{objects}}?',
200
+ list[Entity],
201
+ lm=lf.llms.Gpt4Turbo(),
202
+ image=lf.Image(path='/path/to/a/dinasaur.png'),
203
+ objects=[Dog()],
204
+ examples=[
205
+ lf.MappingExample(
206
+ input=lf.Template(
207
+ 'What is the object near the house in this {{image}}?',
208
+ image=lf.Image(path='/path/to/image.png'),
209
+ ),
210
+ schema=Entity,
211
+ output=Entity('cat'),
212
+ ),
213
+ ],
214
+ )
215
+
216
+ # Output: [Entity(name='dinasaur'), Entity(name='dog')]
217
+ ```
218
+
219
+ Case 6: Multiple queries to multiple models.
220
+ ```
221
+ lf.query(
222
+ '1 + 1 = ?',
223
+ int,
224
+ lm=[
225
+ lf.llms.Gpt4Turbo(),
226
+ lf.llms.Gemini1_5Pro(),
227
+ ],
228
+ num_samples=[1, 2],
229
+ )
230
+ # Output: [2, 2, 2]
153
231
  ```
154
232
 
155
233
  Args:
156
- prompt: A str (may contain {{}} as template) as natural language input, or a
157
- `pg.Symbolic` object as structured input as prompt to LLM.
158
- schema: A type annotation as the schema for output object. If str (default),
159
- the response will be a str in natural language.
160
- default: The default value if parsing failed. If not specified, error will
161
- be raised.
162
- lm: The language model to use. If not specified, the language model from
163
- `lf.context` context manager will be used.
164
- examples: An optional list of fewshot examples for helping parsing. If None,
165
- the default one-shot example will be added.
166
- cache_seed: Seed for computing cache key. The cache key is determined by a
167
- tuple of (lm, prompt, cache seed). If None, cache will be disabled for the
168
- query even cache is configured by the LM.
169
- response_postprocess: An optional callable object to process the raw LM
170
- response before parsing it into the final output object. If None, the raw
171
- LM response will not be processed.
172
- autofix: Number of attempts to auto fix the generated code. If 0, autofix is
173
- disabled. Auto-fix is not supported for 'json' protocol.
174
- autofix_lm: The language model to use for autofix. If not specified, the
175
- `autofix_lm` from `lf.context` context manager will be used. Otherwise it
176
- will use `lm`.
177
- protocol: The protocol for schema/value representation. Applicable values
178
- are 'json' and 'python'. By default `python` will be used.
179
- returns_message: If True, returns `lf.Message` as the output, instead of
180
- returning the structured `message.result`.
181
- skip_lm: If True, returns the rendered prompt as a UserMessage object.
182
- otherwise return the LLM response based on the rendered prompt.
183
- **kwargs: Keyword arguments passed to render the prompt or configure the
184
- `lf.structured.Mapping` class. Notable kwargs are:
185
- - template_str: Change the root template for query.
186
- - preamble: Change the preamble for query.
187
- - mapping_template: Change the template for each mapping examle.
234
+ prompt: The input query. Can be:
235
+ - A natural language string (supports templating with `{{}}`),
236
+ - A `pg.Object` object for structured input,
237
+ - An `lf.Template` for mixed or template-based inputs.
238
+ schema: Type annotation or `lf.Schema` object for the expected output.
239
+ If `None` (default), the response will be a natural language string.
240
+ default: Default value to return if parsing fails. If not specified, an
241
+ error will be raised.
242
+ lm: The language model(s) to query. Can be:
243
+ - A single `LanguageModel`,
244
+ - A list of `LanguageModel`s for multi-model fan-out.
245
+ If `None`, the LM from `lf.context` will be used.
246
+ num_samples: Number of samples to generate. If a list is provided, its
247
+ length must match the number of models in `lm`.
248
+ examples: Few-shot examples to guide the model output. Defaults to `None`.
249
+ cache_seed: Seed for caching the query. Queries with the same
250
+ `(lm, prompt, cache_seed)` will use cached responses. If `None`,
251
+ caching is disabled.
252
+ response_postprocess: A post-processing function for the raw LM response.
253
+ If `None`, no post-processing occurs.
254
+ autofix: Number of attempts for auto-fixing code errors. Set to `0` to
255
+ disable auto-fixing. Not supported with the `'json'` protocol.
256
+ autofix_lm: The LM to use for auto-fixing. Defaults to the `autofix_lm`
257
+ from `lf.context` or the main `lm`.
258
+ protocol: Format for schema representation. Choices are `'json'` or
259
+ `'python'`. Default is `'python'`.
260
+ returns_message: If `True`, returns an `lf.Message` object instead of
261
+ the final parsed result.
262
+ skip_lm: If `True`, skips the LLM call and returns the rendered
263
+ prompt as a `UserMessage` object.
264
+ **kwargs: Additional keyword arguments for:
265
+ - Rendering templates (e.g., `template_str`, `preamble`),
266
+ - Configuring `lf.structured.Mapping`.
188
267
 
189
268
  Returns:
190
- The result based on the schema.
269
+ The result of the query:
270
+ - A single output or a list of outputs if multiple models/samples are used.
271
+ - Each output is a parsed object matching `schema`, an `lf.Message` (if
272
+ `returns_message=True`), or a natural language string (default).
191
273
  """
192
274
  # Internal usage logging.
193
275
 
276
+ # Multiple queries will be issued when `lm` is a list or `num_samples` is
277
+ # greater than 1.
278
+ if isinstance(lm, list) or num_samples != 1:
279
+ def _single_query(inputs):
280
+ lm, example_i = inputs
281
+ return query(
282
+ prompt,
283
+ schema,
284
+ default=default,
285
+ lm=lm,
286
+ examples=examples,
287
+ # Usually `num_samples` should not be large, so we multiply the user
288
+ # provided cache seed by 100 to avoid collision.
289
+ cache_seed=(
290
+ None if cache_seed is None else cache_seed * 100 + example_i
291
+ ),
292
+ response_postprocess=response_postprocess,
293
+ autofix=autofix,
294
+ autofix_lm=autofix_lm,
295
+ protocol=protocol,
296
+ returns_message=returns_message,
297
+ skip_lm=skip_lm,
298
+ **kwargs,
299
+ )
300
+ lm_list = lm if isinstance(lm, list) else [lm]
301
+ num_samples_list = (
302
+ num_samples if isinstance(num_samples, list)
303
+ else [num_samples] * len(lm_list)
304
+ )
305
+ assert len(lm_list) == len(num_samples_list), (
306
+ 'Expect the length of `num_samples` to be the same as the '
307
+ f'the length of `lm`. Got {num_samples} and {lm_list}.'
308
+ )
309
+ query_inputs = []
310
+ total_queries = 0
311
+ for lm, num_samples in zip(lm_list, num_samples_list):
312
+ query_inputs.extend([(lm, i) for i in range(num_samples)])
313
+ total_queries += num_samples
314
+
315
+ samples = []
316
+ for _, output, error in lf.concurrent_map(
317
+ _single_query, query_inputs, max_workers=max(64, total_queries),
318
+ ordered=True,
319
+ ):
320
+ if error is None:
321
+ samples.append(output)
322
+ return samples
323
+
194
324
  # Normalize query schema.
195
325
  # When `lf.query` is used for symbolic completion, schema is automatically
196
326
  # inferred when it is None.
@@ -280,11 +410,52 @@ def query(
280
410
  return output_message if returns_message else _result(output_message)
281
411
 
282
412
 
413
+ #
414
+ # Helper function for map-reduce style querying.
415
+ #
416
+
417
+
418
+ def query_and_reduce(
419
+ prompt: Union[str, lf.Template, Any],
420
+ schema: schema_lib.SchemaType | None = None,
421
+ *,
422
+ reduce: Callable[[list[Any]], Any],
423
+ lm: lf.LanguageModel | list[lf.LanguageModel] | None = None,
424
+ num_samples: int | list[int] = 1,
425
+ **kwargs,
426
+ ) -> Any:
427
+ """Issues multiple `lf.query` calls in parallel and reduces the outputs.
428
+
429
+ Args:
430
+ prompt: A str (may contain {{}} as template) as natural language input, or a
431
+ `pg.Symbolic` object as structured input as prompt to LLM.
432
+ schema: A type annotation as the schema for output object. If str (default),
433
+ the response will be a str in natural language.
434
+ reduce: A function to reduce the outputs of multiple `lf.query` calls. It
435
+ takes a list of outputs and returns the final object.
436
+ lm: The language model to use. If not specified, the language model from
437
+ `lf.context` context manager will be used.
438
+ num_samples: The number of samples to obtain from each language model being
439
+ requested. If a list is provided, it should have the same length as `lm`.
440
+ **kwargs: Additional arguments to pass to `lf.query`.
441
+
442
+ Returns:
443
+ The reduced output from multiple `lf.query` calls.
444
+ """
445
+ results = query(prompt, schema, lm=lm, num_samples=num_samples, **kwargs)
446
+ if isinstance(results, list):
447
+ results = reduce(results)
448
+ return results
449
+
450
+
451
+ #
452
+ # Functions for decomposing `lf.query` into pre-llm and post-llm operations.
453
+ #
454
+
455
+
283
456
  def query_prompt(
284
457
  prompt: Union[str, lf.Template, Any],
285
- schema: Union[
286
- schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
287
- ] = None,
458
+ schema: schema_lib.SchemaType | None = None,
288
459
  **kwargs,
289
460
  ) -> lf.Message:
290
461
  """Returns the final prompt sent to LLM for `lf.query`."""
@@ -295,9 +466,7 @@ def query_prompt(
295
466
 
296
467
  def query_output(
297
468
  response: Union[str, lf.Message],
298
- schema: Union[
299
- schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
300
- ],
469
+ schema: schema_lib.SchemaType | None = None,
301
470
  **kwargs,
302
471
  ) -> Any:
303
472
  """Returns the final output of `lf.query` from a provided LLM response."""
@@ -308,6 +477,11 @@ def query_output(
308
477
  )
309
478
 
310
479
 
480
+ #
481
+ # Functions for computing reward of an LLM response based on a mapping example.
482
+ #
483
+
484
+
311
485
  def query_reward(
312
486
  mapping_example: Union[str, mapping.MappingExample],
313
487
  response: Union[str, lf.Message],
@@ -362,6 +536,11 @@ def _reward_fn(cls) -> Callable[
362
536
  return _reward
363
537
 
364
538
 
539
+ #
540
+ # Functions for tracking `lf.query` invocations.
541
+ #
542
+
543
+
365
544
  class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
366
545
  """A class to represent the invocation of `lf.query`."""
367
546
 
@@ -327,6 +327,69 @@ class QueryTest(unittest.TestCase):
327
327
  expected_modalities=3,
328
328
  )
329
329
 
330
+ def test_multiple_queries(self):
331
+ self.assertEqual(
332
+ querying.query(
333
+ 'Compute 1 + 2',
334
+ int,
335
+ lm=[
336
+ fake.StaticResponse('1'),
337
+ fake.StaticResponse('2'),
338
+ ],
339
+ num_samples=[1, 2],
340
+ ),
341
+ [1, 2, 2]
342
+ )
343
+ self.assertEqual(
344
+ querying.query(
345
+ 'Compute 1 + 2',
346
+ int,
347
+ lm=[
348
+ fake.StaticResponse('1'),
349
+ fake.StaticResponse('2'),
350
+ ],
351
+ num_samples=2,
352
+ ),
353
+ [1, 1, 2, 2]
354
+ )
355
+ self.assertEqual(
356
+ querying.query(
357
+ 'Compute 1 + 2',
358
+ int,
359
+ lm=[
360
+ fake.StaticResponse('1'),
361
+ fake.StaticResponse('abc'),
362
+ ],
363
+ num_samples=[1, 2],
364
+ ),
365
+ [1]
366
+ )
367
+ self.assertEqual(
368
+ querying.query(
369
+ 'Compute 1 + 2',
370
+ int,
371
+ default=0,
372
+ lm=[
373
+ fake.StaticResponse('1'),
374
+ fake.StaticResponse('abc'),
375
+ ],
376
+ num_samples=[1, 2],
377
+ ),
378
+ [1, 0, 0]
379
+ )
380
+ results = querying.query(
381
+ 'Compute 1 + 2',
382
+ int,
383
+ default=0,
384
+ lm=[
385
+ fake.StaticResponse('1'),
386
+ fake.StaticResponse('abc'),
387
+ ],
388
+ returns_message=True,
389
+ )
390
+ self.assertEqual([r.text for r in results], ['1', 'abc'])
391
+ self.assertEqual([r.result for r in results], [1, 0])
392
+
330
393
  def test_bad_protocol(self):
331
394
  with self.assertRaisesRegex(ValueError, 'Unknown protocol'):
332
395
  querying.query('what is 1 + 1', int, protocol='text')
@@ -393,6 +456,30 @@ class QueryTest(unittest.TestCase):
393
456
  )
394
457
  self.assertIsNotNone(output.get_modality('image'))
395
458
 
459
+ def test_query_and_reduce(self):
460
+ self.assertEqual(
461
+ querying.query_and_reduce(
462
+ 'Compute 1 + 1',
463
+ int,
464
+ reduce=sum,
465
+ lm=[
466
+ fake.StaticResponse('1'),
467
+ fake.StaticResponse('2'),
468
+ ],
469
+ num_samples=[1, 2],
470
+ ),
471
+ 5
472
+ )
473
+ self.assertEqual(
474
+ querying.query_and_reduce(
475
+ 'Compute 1 + 1',
476
+ int,
477
+ reduce=sum,
478
+ lm=fake.StaticResponse('2'),
479
+ ),
480
+ 2
481
+ )
482
+
396
483
  def test_query_output(self):
397
484
  self.assertEqual(
398
485
  querying.query_output(
@@ -213,18 +213,8 @@ class Schema(
213
213
  """
214
214
  )
215
215
 
216
- def _html_tree_view_tooltip(
217
- self,
218
- *,
219
- view: pg.views.HtmlTreeView,
220
- content: pg.Html | str | None = None,
221
- **kwargs,
222
- ):
223
- return view.tooltip(
224
- self,
225
- content=content or pg.Html.escape(self.schema_str(protocol='python')),
226
- **kwargs
227
- )
216
+
217
+ SchemaType = Union[Schema, Type[Any], list[Type[Any]], dict[str, Any]]
228
218
 
229
219
 
230
220
  def _top_level_object_specs_from_value(value: pg.Symbolic) -> list[Type[Any]]:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.1.2.dev202412170805
3
+ Version: 0.1.2.dev202412190804
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -1,4 +1,4 @@
1
- langfun/__init__.py,sha256=Xkzi8VV93jI7iLSrypzyHV9FtRbxHQpmoUCIZJSGHdA,2400
1
+ langfun/__init__.py,sha256=fhfPXpHN7GoGqixpFfqhQkYxFs_siP_LhbjZhd3lhio,2497
2
2
  langfun/core/__init__.py,sha256=xlvFTXc7IKUTs8aCFRFhzOLTmmeuhXgk9yx2InBLNiA,4937
3
3
  langfun/core/component.py,sha256=HVrEoTL1Y01iqOHC3FYdbAOnffqfHHtGJXoK1vkdEwo,11583
4
4
  langfun/core/component_test.py,sha256=sG-T2wpvBfHqWGZE7sc4NayJj2aj5QFBzSwFiwrGEIc,10376
@@ -10,8 +10,8 @@ langfun/core/langfunc.py,sha256=G50YgoVZ0y1GFw2ev41MlOqr6qa8YakbvNC0h_E0PiA,1114
10
10
  langfun/core/langfunc_test.py,sha256=fKIAqcSNI_7M6nwoZW77HEam8Oa6vcWhsCNgVJanzb4,8822
11
11
  langfun/core/language_model.py,sha256=b15MZ_qbydnz5vQ09t7sf9tc3C7qWvMSxUrGfT0p99I,33827
12
12
  langfun/core/language_model_test.py,sha256=hnYhtw7GM_TbhgsJzHNYTaoDewUlPHpOVlI7xEkCFuI,31783
13
- langfun/core/logging.py,sha256=uslllP0RTGN223oro1m4nZZ0bFppcL07OwbFKm2iG6k,7519
14
- langfun/core/logging_test.py,sha256=b5bPTSUoYeICATaO6I8dOVumodwRbxSp1Oz96Sf3KcE,6104
13
+ langfun/core/logging.py,sha256=W3mLEMXdo210Q5OX3a1ZTc4nU-xMy73-IfNKnsA-RFo,8051
14
+ langfun/core/logging_test.py,sha256=N7-YvSXC8zvnr2SNwWHOykn1CFmqvIuTLDgn41Ku9JU,6642
15
15
  langfun/core/memory.py,sha256=f-asN1F7Vehgdn_fK84v73GrEUOxRtaW934keutTKjk,2416
16
16
  langfun/core/message.py,sha256=16oiMpg9O9VKrgpfrvJrfvga3n3FzUuD_zdWb9nvSWA,25686
17
17
  langfun/core/message_test.py,sha256=jtZoNBNbA99i2fjoKg5vTRgoUe84J4MH8ZMGakGmTHs,32577
@@ -58,10 +58,10 @@ langfun/core/eval/patching_test.py,sha256=8kCd54Egjju22FMgtJuxEsrXkW8ifs-UUBHtrC
58
58
  langfun/core/eval/scoring.py,sha256=B69IsIxiPs1xZcOBFIhZF70YmDue2Siik-CPL2bh33s,6254
59
59
  langfun/core/eval/scoring_test.py,sha256=O8olHbrUEg60gMxwOkWzKBJZpZoUlmVnBANX5Se2SXM,4546
60
60
  langfun/core/eval/v2/__init__.py,sha256=qoa6zKdFXOFyCX6vay6OdgPf1eUhYGoHYAxe35qECGk,1628
61
- langfun/core/eval/v2/checkpointing.py,sha256=ks0az5IJCceZG8V8pcqpkaaXcO9IHnbLpjklmBM1_uQ,6257
61
+ langfun/core/eval/v2/checkpointing.py,sha256=8vxH3AfIBS8dxA0IiOZBUxAHXIx5m2tSWSSumDLpzp8,6546
62
62
  langfun/core/eval/v2/checkpointing_test.py,sha256=dAERKQTW_PM1B0oUauB0YVQkMEI-cgJq0q-wAVlGYpU,4383
63
- langfun/core/eval/v2/evaluation.py,sha256=h_AWRUSKhEs-bHLBgqo-GeBYXluD5bPbAqypRW0ajfA,19441
64
- langfun/core/eval/v2/evaluation_test.py,sha256=hh6L2HhQPQ6NBv1pXKcNkYraNcV9MLuJ--69t9jbmaI,5846
63
+ langfun/core/eval/v2/evaluation.py,sha256=7PC-npbEQjwwv0pWbv8vGi_OkzZ7QpJrEpYoixFBlno,21429
64
+ langfun/core/eval/v2/evaluation_test.py,sha256=ld8oBOjsfN-LNLL2eViSTu17wAq90GcsfURXX6oVlFo,6014
65
65
  langfun/core/eval/v2/example.py,sha256=fURrvdNmMsVMqoEErcsmLmC6Xq3ny16dYsnLH8HVlcY,9626
66
66
  langfun/core/eval/v2/example_test.py,sha256=WcJmU7IQQXvjFia63mokySC4CqxzVL9Wso1sC5F0YK8,3032
67
67
  langfun/core/eval/v2/experiment.py,sha256=0JBGckJ93aqSdffpJPDVPy_I5T2BXscghTxiglHzJWo,29556
@@ -96,7 +96,7 @@ langfun/core/llms/openai.py,sha256=l49v6RubfInvV0iG114AymTKNogTX4u4N-UFCeSgIxw,2
96
96
  langfun/core/llms/openai_test.py,sha256=kOWa1nf-nJvtYY10REUw5wojh3ZgfU8tRaCZ8wUgJbA,16623
97
97
  langfun/core/llms/rest.py,sha256=sWbYUV8S3SuOg9giq7xwD-xDRfaF7NP_ig7bI52-Rj4,3442
98
98
  langfun/core/llms/rest_test.py,sha256=NZ3Nf0XQVpT9kLP5cBVo_yBHLI7vWTYhWQxYEJVMGs4,3472
99
- langfun/core/llms/vertexai.py,sha256=adUTByiuiTHBQ31tM_EXPUWIyUwo3zqyYIe9UILAFDE,14981
99
+ langfun/core/llms/vertexai.py,sha256=oEd665IBwzCTlHuLEMrCdwgQzrFB5ERcnxw6nrYNSyk,14990
100
100
  langfun/core/llms/vertexai_test.py,sha256=ffcA5yPecnQy_rhkuYAw_6o1iLW8AR8FgswmHt6aAys,6725
101
101
  langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
102
102
  langfun/core/llms/cache/base.py,sha256=rt3zwmyw0y9jsSGW-ZbV1vAfLxQ7_3AVk0l2EySlse4,3918
@@ -118,7 +118,7 @@ langfun/core/modalities/pdf.py,sha256=mfaeCbUA4JslFVTARiJh8hW7imvL4tLVw9gUhO5bAZ
118
118
  langfun/core/modalities/pdf_test.py,sha256=ulZ0FbnlsU0wkrdckJ4ONZPTYRyMPO9Aob1UO6FXygk,1950
119
119
  langfun/core/modalities/video.py,sha256=vI9apcHIHGyp90i34Srg7S3G6IBDtDCk8qiXhwRQmkw,967
120
120
  langfun/core/modalities/video_test.py,sha256=7OXZoohKMYjt7vrJUdPb553HLyl1oBOKRgzBePFv68Q,2042
121
- langfun/core/structured/__init__.py,sha256=i5Plug6041OtEltsDSQxCPFM-4XD4HKZrhw9uLQXFXk,3157
121
+ langfun/core/structured/__init__.py,sha256=3Wb4ks14D5E5z1nQ5trXSXiLqFN5E_3HvcEJQAfFRl0,3220
122
122
  langfun/core/structured/completion.py,sha256=yW95Yd4wbt964I5wIyPUtIVeeqeZbA6HidLgN0ZpWm0,8110
123
123
  langfun/core/structured/completion_test.py,sha256=VtYfI3ciVSSWbi8x3l1WwpWK-Ofn2DMHYEEqm2uTzhw,19314
124
124
  langfun/core/structured/description.py,sha256=6BztYOiucPkF4CrTQtPLPJo1gN2dwnKmaJW83GBf4H0,5213
@@ -127,11 +127,11 @@ langfun/core/structured/function_generation.py,sha256=g7AOR_e8HxFU6n6Df750aGkgMg
127
127
  langfun/core/structured/function_generation_test.py,sha256=LaXYDXf9GlqUrR6v_gtmK_H4kxzonmU7SYbn7XXMgjU,12128
128
128
  langfun/core/structured/mapping.py,sha256=vLKH79UT-j0qkQdvqlQBO7SkXXuM-yr2Idm8_HH8qwM,13649
129
129
  langfun/core/structured/mapping_test.py,sha256=bHm2ZCXBITq_G8Lvw_olFHeUUc4s_lGXZm9v9JhoPB4,9630
130
- langfun/core/structured/parsing.py,sha256=lhEkdnvxKzkYwHsTvBdE2j6uLWl-J8uQu6c-3xcsBXM,11770
131
- langfun/core/structured/parsing_test.py,sha256=-uPiLi0cRBkf0ZycZsgLPIfRLLdwYhRbm2LHHp_pVGE,21475
132
- langfun/core/structured/querying.py,sha256=3RvK20G39TFrW6yV4lR9s7SbxlIpKRZ43Hey0obtJ3M,17539
133
- langfun/core/structured/querying_test.py,sha256=CuoClQLtEwW3Pd2otGwAoVzR0KBjT-Keppu9BOpz4mA,29705
134
- langfun/core/structured/schema.py,sha256=XnIKBDpPU6YmWJ0ncmoW23ha6VLyLzXf7hdY2-Urn-w,28133
130
+ langfun/core/structured/parsing.py,sha256=MGvI7ypXlwfzr5XB8_TFU9Ei0_5reYqkWkv64eAy0EA,12015
131
+ langfun/core/structured/parsing_test.py,sha256=kNPrhpdPY3iWhUld0TFYU-Zgn44wC0d6YuQ9XdVbQ8o,22346
132
+ langfun/core/structured/querying.py,sha256=sXGhYtiEBac8iOkYOErGXyX8SAHSB1gg69WePhOyGxE,22759
133
+ langfun/core/structured/querying_test.py,sha256=M9Apg83KjQUjT42K9LheBEr74DX3Inwd0YmCanA71kc,31738
134
+ langfun/core/structured/schema.py,sha256=0VUPSfX1JEQ0xu8WvEymCKK_WSGwBNA-rQD2hATErmU,27912
135
135
  langfun/core/structured/schema_generation.py,sha256=U3nRQsqmMZg_qIVDh2fiY3K4JLfsAL1LcKzIFP1iXFg,5316
136
136
  langfun/core/structured/schema_generation_test.py,sha256=RM9s71kMNg2jTePwInkiW9fK1ACN37eyPeF8OII-0zw,2950
137
137
  langfun/core/structured/schema_test.py,sha256=RjYhwTgktQgyqAjzLvo967nTiIK9KWgP-aNGg4e7ihE,25258
@@ -148,8 +148,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
148
148
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
149
149
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
150
150
  langfun/core/templates/selfplay_test.py,sha256=Ot__1P1M8oJfoTp-M9-PQ6HUXqZKyMwvZ5f7yQ3yfyM,2326
151
- langfun-0.1.2.dev202412170805.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
152
- langfun-0.1.2.dev202412170805.dist-info/METADATA,sha256=b9Ans5czArE6iUCD8pENhPZzGEgwjLK5oPSF31NPj0s,8281
153
- langfun-0.1.2.dev202412170805.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
154
- langfun-0.1.2.dev202412170805.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
155
- langfun-0.1.2.dev202412170805.dist-info/RECORD,,
151
+ langfun-0.1.2.dev202412190804.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
152
+ langfun-0.1.2.dev202412190804.dist-info/METADATA,sha256=Zr8TfOnhdo83h3aGRNRWXTrJ54h7Sh7E-7Lj95iJVDw,8281
153
+ langfun-0.1.2.dev202412190804.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
154
+ langfun-0.1.2.dev202412190804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
155
+ langfun-0.1.2.dev202412190804.dist-info/RECORD,,