langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. langfun/__init__.py +20 -2
  2. langfun/core/__init__.py +16 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -21
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +63 -2
  18. langfun/core/component_test.py +53 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +16 -1
  24. langfun/core/eval/base.py +622 -174
  25. langfun/core/eval/base_test.py +200 -54
  26. langfun/core/eval/matching.py +63 -76
  27. langfun/core/eval/matching_test.py +17 -8
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +26 -26
  31. langfun/core/eval/scoring_test.py +19 -2
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +4 -17
  55. langfun/core/langfunc_test.py +22 -6
  56. langfun/core/language_model.py +577 -39
  57. langfun/core/language_model_test.py +470 -56
  58. langfun/core/llms/__init__.py +87 -16
  59. langfun/core/llms/anthropic.py +312 -87
  60. langfun/core/llms/anthropic_test.py +71 -3
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +53 -2
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +11 -7
  69. langfun/core/llms/fake_test.py +14 -0
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -202
  74. langfun/core/llms/groq.py +160 -144
  75. langfun/core/llms/groq_test.py +31 -137
  76. langfun/core/llms/llama_cpp.py +15 -42
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +395 -203
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +30 -395
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -26
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +12 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +7 -6
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +60 -27
  112. langfun/core/structured/function_generation_test.py +72 -2
  113. langfun/core/structured/mapping.py +97 -47
  114. langfun/core/structured/mapping_test.py +90 -2
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +53 -9
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
  119. langfun/core/structured/schema.py +204 -97
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_test.py +130 -29
  122. langfun/core/structured/scoring.py +125 -19
  123. langfun/core/structured/scoring_test.py +30 -0
  124. langfun/core/structured/tokenization.py +64 -0
  125. langfun/core/structured/tokenization_test.py +48 -0
  126. langfun/core/template.py +115 -1
  127. langfun/core/template_test.py +71 -1
  128. langfun/core/templates/conversation.py +9 -0
  129. langfun/core/templates/conversation_test.py +4 -3
  130. langfun/core/templates/selfplay_test.py +10 -2
  131. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  132. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  133. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  134. langfun/core/coding/python/errors.py +0 -108
  135. langfun/core/coding/python/errors_test.py +0 -99
  136. langfun/core/coding/python/permissions.py +0 -90
  137. langfun/core/coding/python/permissions_test.py +0 -86
  138. langfun/core/structured/prompting.py +0 -238
  139. langfun/core/text_formatting.py +0 -162
  140. langfun/core/text_formatting_test.py +0 -47
  141. langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
  142. langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
  143. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  144. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,746 @@
1
+ # Copyright 2023 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Query LLM for structured output."""
15
+
16
+ import contextlib
17
+ import functools
18
+ import time
19
+ from typing import Annotated, Any, Callable, Iterator, Type, Union
20
+
21
+ import langfun.core as lf
22
+ from langfun.core.llms import fake
23
+ from langfun.core.structured import mapping
24
+ from langfun.core.structured import schema as schema_lib
25
+ import pyglove as pg
26
+
27
+
28
@lf.use_init_args(['schema', 'default', 'examples'])
class _QueryStructure(mapping.Mapping):
  """Query an object out from a natural language text.

  Base class for protocol-specific query mappings. Subclasses supply the
  few-shot `preamble`, the `protocol` name and the section titles used when
  rendering the mapping prompt.
  """

  # Section titles used when rendering the prompt.
  context_title = 'CONTEXT'
  input_title = 'INPUT_OBJECT'

  # Mark schema as required (the base `mapping.Mapping` allows it to be None).
  schema: pg.typing.Annotated[
      schema_lib.schema_spec(), 'Required schema for parsing.'
  ]
39
+
40
+
41
class _QueryStructureJson(_QueryStructure):
  """Query a structured value using JSON as the protocol."""

  # Few-shot preamble demonstrating the JSON protocol. The `{{ ... }}`
  # placeholders are filled from the class attributes below at render time.
  preamble = """
      Please respond to the last {{ input_title }} with {{ output_title}} according to {{ schema_title }}:

      INSTRUCTIONS:
        1. If the schema has `_type`, carry it over to the JSON output.
        2. If a field from the schema cannot be extracted from the response, use null as the JSON value.

      {{ input_title }}:
        1 + 1 =

      {{ schema_title }}:
        {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}}

      {{ output_title}}:
        {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}}
      """

  protocol = 'json'
  schema_title = 'SCHEMA'
  output_title = 'JSON'
64
+
65
+
66
class _QueryStructurePython(_QueryStructure):
  """Query a structured value using Python as the protocol."""

  # Few-shot preamble demonstrating the Python protocol. The `{{ ... }}`
  # placeholders are filled from the class attributes below at render time.
  preamble = """
      Please respond to the last {{ input_title }} with {{ output_title }} according to {{ schema_title }}.

      {{ input_title }}:
        1 + 1 =

      {{ schema_title }}:
        Answer

        ```python
        class Answer:
          final_answer: int
        ```

      {{ output_title }}:
        ```python
        Answer(
          final_answer=2
        )
        ```
      """
  protocol = 'python'
  schema_title = 'OUTPUT_TYPE'
  output_title = 'OUTPUT_OBJECT'
93
+
94
+
95
def _query_structure_cls(
    protocol: schema_lib.SchemaProtocol,
) -> Type[_QueryStructure]:
  """Returns the `_QueryStructure` subclass that implements `protocol`."""
  protocol_to_cls = {
      'json': _QueryStructureJson,
      'python': _QueryStructurePython,
  }
  cls = protocol_to_cls.get(protocol)
  if cls is None:
    raise ValueError(f'Unknown protocol: {protocol!r}.')
  return cls
104
+
105
+
106
def query(
    prompt: Union[str, lf.Template, Any],
    schema: schema_lib.SchemaType | None = None,
    default: Any = lf.RAISE_IF_HAS_ERROR,
    *,
    lm: lf.LanguageModel | list[lf.LanguageModel] | None = None,
    num_samples: int | list[int] = 1,
    examples: list[mapping.MappingExample] | None = None,
    cache_seed: int | None = 0,
    response_postprocess: Callable[[str], str] | None = None,
    autofix: int = 0,
    autofix_lm: lf.LanguageModel | None = None,
    protocol: schema_lib.SchemaProtocol = 'python',
    returns_message: bool = False,
    skip_lm: bool = False,
    **kwargs,
) -> Any:
  """Query one or more language models for structured or unstructured outputs.

  This is the primary API in Langfun for interacting with language models,
  supporting natural language prompts, structured inputs, and multiple advanced
  features.

  Key Features:

  - **Input**: Accepts natural language strings, structured inputs (e.g.,
    `pg.Object`), and templates (`lf.Template`) with modality objects.

  - **Output**: Returns structured outputs when `schema` is specified;
    otherwise, outputs raw natural language (as a string).

  - **Few-shot examples**: Supports structured few-shot examples with the
    `examples` argument.

  - **Multi-LM fan-out**: Sends queries to multiple language models with
    multiple samples in parallel, returning a list of outputs.

  Examples:

    Case 1: Regular natural language-based LLM query:

    ```
    lf.query('1 + 1 = ?', lm=lf.llms.Gpt4Turbo())

    # Output: '2'
    ```

    Case 2: Query with structured output.

    ```
    lf.query('1 + 1 = ?', int, lm=lf.llms.Gpt4Turbo())

    # Output: 2
    ```

    Case 3: Query with structured input.

    ```
    class Sum(pg.Object):
      a: int
      b: int

    lf.query(Sum(1, 1), int, lm=lf.llms.Gpt4Turbo())

    # Output: 2
    ```

    Case 4: Query with input of mixed modalities.

    ```
    class Animal(pg.Object):
      pass

    class Dog(Animal):
      pass

    class Entity(pg.Object):
      name: str

    lf.query(
        'What is in this {{image}} and {{objects}}?',
        list[Entity],
        lm=lf.llms.Gpt4Turbo(),
        image=lf.Image(path='/path/to/a/airplane.png'),
        objects=[Dog()],
    )

    # Output: [Entity(name='airplane'), Entity(name='dog')]
    ```

    Case 5: Query with structured few-shot examples.
    ```
    lf.query(
        'What is in this {{image}} and {{objects}}?',
        list[Entity],
        lm=lf.llms.Gpt4Turbo(),
        image=lf.Image(path='/path/to/a/dinosaur.png'),
        objects=[Dog()],
        examples=[
            lf.MappingExample(
                input=lf.Template(
                    'What is the object near the house in this {{image}}?',
                    image=lf.Image(path='/path/to/image.png'),
                ),
                schema=Entity,
                output=Entity('cat'),
            ),
        ],
    )

    # Output: [Entity(name='dinosaur'), Entity(name='dog')]
    ```

    Case 6: Multiple queries to multiple models.
    ```
    lf.query(
        '1 + 1 = ?',
        int,
        lm=[
            lf.llms.Gpt4Turbo(),
            lf.llms.Gemini1_5Pro(),
        ],
        num_samples=[1, 2],
    )
    # Output: [2, 2, 2]
    ```

  Args:
    prompt: The input query. Can be:
      - A natural language string (supports templating with `{{}}`),
      - A `pg.Object` object for structured input,
      - An `lf.Template` for mixed or template-based inputs.
    schema: Type annotation or `lf.Schema` object for the expected output.
      If `None` (default), the response will be a natural language string.
    default: Default value to return if parsing fails. If not specified, an
      error will be raised.
    lm: The language model(s) to query. Can be:
      - A single `LanguageModel`,
      - A list of `LanguageModel`s for multi-model fan-out.
      If `None`, the LM from `lf.context` will be used.
    num_samples: Number of samples to generate. If a list is provided, its
      length must match the number of models in `lm`.
    examples: Few-shot examples to guide the model output. Defaults to `None`.
    cache_seed: Seed for caching the query. Queries with the same
      `(lm, prompt, cache_seed)` will use cached responses. If `None`,
      caching is disabled.
    response_postprocess: A post-processing function for the raw LM response.
      If `None`, no post-processing occurs.
    autofix: Number of attempts for auto-fixing code errors. Set to `0` to
      disable auto-fixing. Not supported with the `'json'` protocol.
    autofix_lm: The LM to use for auto-fixing. Defaults to the `autofix_lm`
      from `lf.context` or the main `lm`.
    protocol: Format for schema representation. Choices are `'json'` or
      `'python'`. Default is `'python'`.
    returns_message: If `True`, returns an `lf.Message` object instead of
      the final parsed result.
    skip_lm: If `True`, skips the LLM call and returns the rendered
      prompt as a `UserMessage` object.
    **kwargs: Additional keyword arguments for:
      - Rendering templates (e.g., `template_str`, `preamble`),
      - Configuring `lf.structured.Mapping`.

  Returns:
    The result of the query:
    - A single output or a list of outputs if multiple models/samples are used.
    - Each output is a parsed object matching `schema`, an `lf.Message` (if
      `returns_message=True`), or a natural language string (default).
  """
  # Internal usage logging.

  # Multiple queries will be issued when `lm` is a list or `num_samples` is
  # greater than 1. Each (model, sample index) pair becomes an independent
  # recursive `query` call, executed concurrently below.
  if isinstance(lm, list) or num_samples != 1:
    def _single_query(inputs):
      lm, example_i = inputs
      return query(
          prompt,
          schema,
          default=default,
          lm=lm,
          examples=examples,
          # Usually num_examples should not be large, so we multiply the user
          # provided cache seed by 100 to avoid collision.
          cache_seed=(
              None if cache_seed is None else cache_seed * 100 + example_i
          ),
          response_postprocess=response_postprocess,
          autofix=autofix,
          autofix_lm=autofix_lm,
          protocol=protocol,
          returns_message=returns_message,
          skip_lm=skip_lm,
          **kwargs,
      )
    lm_list = lm if isinstance(lm, list) else [lm]
    # A scalar `num_samples` applies to every model in `lm_list`.
    num_samples_list = (
        num_samples if isinstance(num_samples, list)
        else [num_samples] * len(lm_list)
    )
    assert len(lm_list) == len(num_samples_list), (
        'Expect the length of `num_samples` to be the same as the '
        f'the length of `lm`. Got {num_samples} and {lm_list}.'
    )
    query_inputs = []
    total_queries = 0
    for lm, num_samples in zip(lm_list, num_samples_list):
      query_inputs.extend([(lm, i) for i in range(num_samples)])
      total_queries += num_samples

    # Failed sub-queries are silently dropped from the returned list, so the
    # result may have fewer entries than `total_queries`.
    samples = []
    for _, output, error in lf.concurrent_map(
        _single_query, query_inputs, max_workers=max(64, total_queries),
        ordered=True,
    ):
      if error is None:
        samples.append(output)
    return samples

  # Normalize query schema.
  # When `lf.query` is used for symbolic completion, schema is automatically
  # inferred when it is None.
  if isinstance(prompt, pg.Symbolic) and prompt.sym_partial and schema is None:
    schema = prompt.__class__

  # Normalize query input.
  if isinstance(prompt, (lf.Message, str)):
    # String/message prompt: wrap it into a template. `template_str` is
    # reserved by `lf.Template` itself, so it is excluded from the kwargs.
    prompt_kwargs = kwargs.copy()
    prompt_kwargs.pop('template_str', None)
    query_input = lf.Template.from_value(prompt, **prompt_kwargs)
  elif isinstance(prompt, lf.Template):
    # Create a copy of the prompt if it has a parent object, so all child
    # modality objects could be referred by path relative to the prompt.
    query_input = prompt.clone() if prompt.sym_parent is not None else prompt

    # Attach template metadata from kwargs. This is used to pass through fields
    # from kwargs to the rendered message.
    template_metadata = {
        k: v for k, v in kwargs.items() if k.startswith('metadata_')
    }
    query_input.rebind(
        template_metadata, skip_notification=True, raise_on_no_change=False
    )
  elif pg.MISSING_VALUE == prompt:
    # No prompt (used by `query_output` to parse a canned LM response).
    query_input = lf.UserMessage('')
  else:
    query_input = schema_lib.mark_missing(prompt)

  # Usage is tracked across the whole LM call so it can be recorded on the
  # `QueryInvocation` below.
  with lf.track_usages() as usage_summary:
    start_time = time.time()
    if schema in (None, str):
      # Query with natural language output.
      output_message = lf.LangFunc.from_value(query_input, **kwargs)(
          lm=lm, cache_seed=cache_seed, skip_lm=skip_lm
      )
      if response_postprocess:
        processed_text = response_postprocess(output_message.text)
        if processed_text != output_message.text:
          output_message = lf.AIMessage(processed_text, source=output_message)
    else:
      # Query with structured output.
      output_message = _query_structure_cls(protocol)(
          input=(
              query_input.render(lm=lm)
              if isinstance(query_input, lf.Template)
              else query_input
          ),
          schema=schema,
          default=default,
          examples=examples,
          response_postprocess=response_postprocess,
          # Autofix is only supported by the Python protocol.
          autofix=autofix if protocol == 'python' else 0,
          **kwargs,
      )(
          lm=lm,
          autofix_lm=autofix_lm or lm,
          cache_seed=cache_seed,
          skip_lm=skip_lm,
      )
    end_time = time.time()

  def _result(message: lf.Message):
    return message.text if schema in (None, str) else message.result

  # Track the query invocations.
  if pg.MISSING_VALUE != prompt and not skip_lm:
    trackers = lf.context_value('__query_trackers__', [])
    if trackers:
      invocation = QueryInvocation(
          input=pg.Ref(query_input),
          schema=(
              schema_lib.Schema.from_value(schema)
              if schema not in (None, str) else None
          ),
          lm=pg.Ref(lm),
          examples=pg.Ref(examples) if examples else [],
          lm_response=lf.AIMessage(output_message.text),
          usage_summary=usage_summary,
          start_time=start_time,
          end_time=end_time,
      )
      # The first tracker is the innermost scope; outer trackers only record
      # the invocation when they opted into child scopes.
      for i, (tracker, include_child_scopes) in enumerate(trackers):
        if i == 0 or include_child_scopes:
          tracker.append(invocation)
  return output_message if returns_message else _result(output_message)
411
+
412
+
413
+ #
414
+ # Helper function for map-reduce style querying.
415
+ #
416
+
417
+
418
def query_and_reduce(
    prompt: Union[str, lf.Template, Any],
    schema: schema_lib.SchemaType | None = None,
    *,
    reduce: Callable[[list[Any]], Any],
    lm: lf.LanguageModel | list[lf.LanguageModel] | None = None,
    num_samples: int | list[int] = 1,
    **kwargs,
) -> Any:
  """Issues multiple `lf.query` calls in parallel and reduce the outputs.

  Args:
    prompt: A str (may contain {{}} as template) as natural language input, or a
      `pg.Symbolic` object as structured input as prompt to LLM.
    schema: A type annotation as the schema for output object. If str (default),
      the response will be a str in natural language.
    reduce: A function to reduce the outputs of multiple `lf.query` calls. It
      takes a list of outputs and returns the final object.
    lm: The language model to use. If not specified, the language model from
      `lf.context` context manager will be used.
    num_samples: The number of samples to obtain from each language model being
      requested. If a list is provided, it should have the same length as `lm`.
    **kwargs: Additional arguments to pass to `lf.query`.

  Returns:
    The reduced output from multiple `lf.query` calls.
  """
  outputs = query(prompt, schema, lm=lm, num_samples=num_samples, **kwargs)
  # `query` returns a list only when it fanned out to multiple models/samples;
  # a single output is passed through without reduction.
  return reduce(outputs) if isinstance(outputs, list) else outputs
449
+
450
+
451
+ #
452
+ # Functions for decomposing `lf.query` into pre-llm and post-llm operations.
453
+ #
454
+
455
+
456
def query_prompt(
    prompt: Union[str, lf.Template, Any],
    schema: schema_lib.SchemaType | None = None,
    **kwargs,
) -> lf.Message:
  """Returns the final prompt sent to LLM for `lf.query`."""
  # These two flags are forced below, so caller-supplied values are discarded.
  for overridden in ('returns_message', 'skip_lm'):
    kwargs.pop(overridden, None)
  return query(prompt, schema, skip_lm=True, returns_message=True, **kwargs)
465
+
466
+
467
def query_output(
    response: Union[str, lf.Message],
    schema: schema_lib.SchemaType | None = None,
    **kwargs,
) -> Any:
  """Returns the final output of `lf.query` from a provided LLM response."""
  # `prompt` and `lm` are supplied internally, so caller values are discarded.
  kwargs.pop('prompt', None)
  kwargs.pop('lm', None)
  # A static fake LM that always replies with `response`, combined with an
  # empty prompt, makes `query` exercise only its output-parsing path.
  static_lm = fake.StaticResponse(response)
  return query(pg.MISSING_VALUE, schema, lm=static_lm, **kwargs)
478
+
479
+
480
+ #
481
+ # Functions for computing reward of an LLM response based on a mapping example.
482
+ #
483
+
484
+
485
def query_reward(
    mapping_example: Union[str, mapping.MappingExample],
    response: Union[str, lf.Message],
) -> float | None:
  """Returns the reward of an LLM response based on an mapping example."""
  # Accept a JSON-serialized example as well as an in-memory one.
  if isinstance(mapping_example, str):
    mapping_example = pg.from_json_str(mapping_example)
  assert isinstance(mapping_example, mapping.MappingExample), mapping_example

  # Determine the class being queried: prefer the example's schema, falling
  # back to the class of the expected output when no schema is present.
  example_schema = mapping_example.schema
  if example_schema and isinstance(example_schema.spec, pg.typing.Object):
    output_cls = example_schema.spec.cls
  elif example_schema is None and isinstance(mapping_example.output, pg.Object):
    output_cls = type(mapping_example.output)
  else:
    output_cls = None

  reward_fn = _reward_fn(output_cls)
  if reward_fn is None:
    # The queried class does not define `__reward__`.
    return None

  actual_output = query_output(response, output_cls)
  return reward_fn(
      actual_output,
      mapping_example.input,
      mapping_example.output,
      mapping_example.metadata,
  )
512
+
513
+
514
# Cached per class: the returned closure is reused across calls with the
# same `cls` (including `cls=None`).
@functools.cache
def _reward_fn(cls) -> Callable[
    [
        pg.Object,  # Actual output object.
        Any,        # Input object.
        pg.Object,  # Expected output object.
        pg.Dict     # User metadata.
    ], float] | None:
  """Returns the reward function for a class that is being queried.

  The returned callable always takes four arguments, but forwards only as
  many as `cls.__reward__` declares (2 to 4 including `self`). Returns None
  when `cls` has no callable `__reward__`.
  """
  if not callable(getattr(cls, '__reward__', None)):
    return None

  signature = pg.typing.signature(cls.__reward__)
  num_args = len(signature.args)
  if num_args < 2 or num_args > 4:
    raise TypeError(
        f'`{cls.__type_name__}.__reward__` should have signature: '
        '`__reward__(self, input, [expected_output], [expected_metadata])`.'
    )
  def _reward(self, input, expected_output, metadata):  # pylint: disable=redefined-builtin
    # Truncate the argument list to match the arity of `__reward__`.
    args = [self, input, expected_output, metadata]
    return cls.__reward__(*args[:num_args])
  return _reward
537
+
538
+
539
+ #
540
+ # Functions for tracking `lf.query` invocations.
541
+ #
542
+
543
+
544
class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
  """A class to represent the invocation of `lf.query`.

  Captures the input, schema, LM, raw response, usage and timing of a single
  `lf.query` call, and lazily recomputes the request prompt and parsed output
  on demand. Instances are recorded by `track_queries`.
  """

  input: Annotated[
      Union[lf.Template, pg.Symbolic],
      'Mapping input of `lf.query`.'
  ]
  schema: pg.typing.Annotated[
      schema_lib.schema_spec(noneable=True),
      'Schema of `lf.query`.'
  ]
  lm_response: Annotated[
      lf.Message,
      'Raw LM response.'
  ]
  lm: Annotated[
      lf.LanguageModel,
      'Language model used for `lf.query`.'
  ]
  examples: Annotated[
      list[mapping.MappingExample],
      'Fewshot exemplars for `lf.query`.'
  ]
  usage_summary: Annotated[
      lf.UsageSummary,
      'Usage summary for `lf.query`.'
  ]
  start_time: Annotated[
      float,
      'Start time of query.'
  ]
  end_time: Annotated[
      float,
      'End time of query.'
  ]

  @functools.cached_property
  def lm_request(self) -> lf.Message:
    # Re-renders the prompt that was (or would be) sent to the LM.
    return query_prompt(self.input, self.schema)

  @functools.cached_property
  def output(self) -> Any:
    """The output of `lf.query`. If it failed, returns the `MappingError`."""
    try:
      return query_output(self.lm_response, self.schema)
    except mapping.MappingError as e:
      return e

  @property
  def has_error(self) -> bool:
    """Returns True if the query failed to generate a valid output."""
    return isinstance(self.output, BaseException)

  @property
  def elapse(self) -> float:
    """Returns query elapse in seconds."""
    return self.end_time - self.start_time

  def _on_bound(self):
    # Invalidate the cached properties whenever symbolic fields are (re)bound,
    # so they are recomputed from the new field values.
    super()._on_bound()
    self.__dict__.pop('lm_request', None)
    self.__dict__.pop('output', None)

  def _html_tree_view_summary(
      self,
      *,
      view: pg.views.HtmlTreeView,
      **kwargs: Any
  ) -> pg.Html | None:
    # Custom summary line: 'lf.query' label, LM badge (with a tooltip showing
    # the LM config), elapsed-time badge and usage-summary badge.
    kwargs.pop('title', None)
    kwargs.pop('enable_summary_tooltip', None)
    return view.summary(
        value=self,
        title=pg.Html.element(
            'div',
            [
                pg.views.html.controls.Label(
                    'lf.query',
                    css_classes=['query-invocation-type-name']
                ),
                pg.views.html.controls.Badge(
                    f'lm={self.lm.model_id}',
                    pg.format(
                        self.lm,
                        verbose=False,
                        python_format=True,
                        hide_default_values=True
                    ),
                    css_classes=['query-invocation-lm']
                ),
                pg.views.html.controls.Badge(
                    f'{int(self.elapse)} seconds',
                    css_classes=['query-invocation-time']
                ),
                self.usage_summary.to_html(extra_flags=dict(as_badge=True))
            ],
            css_classes=['query-invocation-title']
        ),
        enable_summary_tooltip=False,
        **kwargs
    )

  def _html_tree_view_content(
      self,
      *,
      view: pg.views.HtmlTreeView,
      **kwargs: Any
  ) -> pg.Html:
    # Tabbed detail view; the 'output' tab (index 1) is selected by default.
    return pg.views.html.controls.TabControl([
        pg.views.html.controls.Tab(
            'input',
            pg.view(self.input, collapse_level=None),
        ),
        pg.views.html.controls.Tab(
            'output',
            pg.view(self.output, collapse_level=None),
        ),
        pg.views.html.controls.Tab(
            'schema',
            pg.view(self.schema),
        ),
        pg.views.html.controls.Tab(
            'lm_request',
            pg.view(
                self.lm_request,
                extra_flags=dict(include_message_metadata=False),
            ),
        ),
        pg.views.html.controls.Tab(
            'lm_response',
            pg.view(
                self.lm_response,
                extra_flags=dict(include_message_metadata=False)
            ),
        ),
    ], tab_position='top', selected=1).to_html()

  @classmethod
  def _html_tree_view_css_styles(cls) -> list[str]:
    return super()._html_tree_view_css_styles() + [
        """
        .query-invocation-title {
          display: inline-block;
          font-weight: normal;
        }
        .query-invocation-type-name {
          color: #888;
        }
        .query-invocation-lm.badge {
          margin-left: 5px;
          margin-right: 5px;
          color: white;
          background-color: mediumslateblue;
        }
        .query-invocation-time.badge {
          margin-left: 5px;
          border-radius: 0px;
          font-weight: bold;
          background-color: aliceblue;
        }
        .query-invocation-title .usage-summary.label {
          border-radius: 0px;
          color: #AAA;
        }
        """
    ]
710
+
711
+
712
@contextlib.contextmanager
def track_queries(
    include_child_scopes: bool = True
) -> Iterator[list[QueryInvocation]]:
  """Track all queries made during the context.

  Example:

    ```
    with lf.track_queries() as queries:
      lf.query('hi', lm=lm)
      lf.query('What is this {{image}}?', lm=lm, image=image)

    print(queries)
    ```

  Args:
    include_child_scopes: If True, the queries made in child scopes will be
      included in the returned list. Otherwise, only the queries made in the
      current scope will be included.

  Yields:
    A list of `QueryInvocation` objects representing the queries made during
    the context.
  """
  parent_trackers = lf.context_value('__query_trackers__', [])
  invocations: list[QueryInvocation] = []
  # Push this tracker to the front of the stack so `query` treats it as the
  # innermost scope; parent trackers stay visible for child-scope inclusion.
  with lf.context(
      __query_trackers__=[(invocations, include_child_scopes)] + parent_trackers
  ):
    yield invocations