langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
- langfun/__init__.py +20 -2
- langfun/core/__init__.py +16 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -21
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +63 -2
- langfun/core/component_test.py +53 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +16 -1
- langfun/core/eval/base.py +622 -174
- langfun/core/eval/base_test.py +200 -54
- langfun/core/eval/matching.py +63 -76
- langfun/core/eval/matching_test.py +17 -8
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +26 -26
- langfun/core/eval/scoring_test.py +19 -2
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +4 -17
- langfun/core/langfunc_test.py +22 -6
- langfun/core/language_model.py +577 -39
- langfun/core/language_model_test.py +470 -56
- langfun/core/llms/__init__.py +87 -16
- langfun/core/llms/anthropic.py +312 -87
- langfun/core/llms/anthropic_test.py +71 -3
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +53 -2
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +11 -7
- langfun/core/llms/fake_test.py +14 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -202
- langfun/core/llms/groq.py +160 -144
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +15 -42
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +395 -203
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +30 -395
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -26
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +12 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +7 -6
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +60 -27
- langfun/core/structured/function_generation_test.py +72 -2
- langfun/core/structured/mapping.py +97 -47
- langfun/core/structured/mapping_test.py +90 -2
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +53 -9
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
- langfun/core/structured/schema.py +204 -97
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_test.py +130 -29
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +115 -1
- langfun/core/template_test.py +71 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +10 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -238
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
- langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,746 @@
+# Copyright 2023 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Query LLM for structured output."""
+
+import contextlib
+import functools
+import time
+from typing import Annotated, Any, Callable, Iterator, Type, Union
+
+import langfun.core as lf
+from langfun.core.llms import fake
+from langfun.core.structured import mapping
+from langfun.core.structured import schema as schema_lib
+import pyglove as pg
+
+
+@lf.use_init_args(['schema', 'default', 'examples'])
+class _QueryStructure(mapping.Mapping):
+  """Query an object out of a natural language text."""
+
+  context_title = 'CONTEXT'
+  input_title = 'INPUT_OBJECT'
+
+  # Mark schema as required.
+  schema: pg.typing.Annotated[
+      schema_lib.schema_spec(), 'Required schema for parsing.'
+  ]
+
+
+class _QueryStructureJson(_QueryStructure):
+  """Query a structured value using JSON as the protocol."""
+
+  preamble = """
+  Please respond to the last {{ input_title }} with {{ output_title }} according to {{ schema_title }}:
+
+  INSTRUCTIONS:
+    1. If the schema has `_type`, carry it over to the JSON output.
+    2. If a field from the schema cannot be extracted from the response, use null as the JSON value.
+
+  {{ input_title }}:
+    1 + 1 =
+
+  {{ schema_title }}:
+    {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": int}}
+
+  {{ output_title }}:
+    {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}}
+  """
+
+  protocol = 'json'
+  schema_title = 'SCHEMA'
+  output_title = 'JSON'
+
+
+class _QueryStructurePython(_QueryStructure):
+  """Query a structured value using Python as the protocol."""
+
+  preamble = """
+  Please respond to the last {{ input_title }} with {{ output_title }} according to {{ schema_title }}.
+
+  {{ input_title }}:
+    1 + 1 =
+
+  {{ schema_title }}:
+    Answer
+
+    ```python
+    class Answer:
+      final_answer: int
+    ```
+
+  {{ output_title }}:
+    ```python
+    Answer(
+        final_answer=2
+    )
+    ```
+  """
+  protocol = 'python'
+  schema_title = 'OUTPUT_TYPE'
+  output_title = 'OUTPUT_OBJECT'
+
+
+def _query_structure_cls(
+    protocol: schema_lib.SchemaProtocol,
+) -> Type[_QueryStructure]:
+  if protocol == 'json':
+    return _QueryStructureJson
+  elif protocol == 'python':
+    return _QueryStructurePython
+  else:
+    raise ValueError(f'Unknown protocol: {protocol!r}.')
+
+
+def query(
+    prompt: Union[str, lf.Template, Any],
+    schema: schema_lib.SchemaType | None = None,
+    default: Any = lf.RAISE_IF_HAS_ERROR,
+    *,
+    lm: lf.LanguageModel | list[lf.LanguageModel] | None = None,
+    num_samples: int | list[int] = 1,
+    examples: list[mapping.MappingExample] | None = None,
+    cache_seed: int | None = 0,
+    response_postprocess: Callable[[str], str] | None = None,
+    autofix: int = 0,
+    autofix_lm: lf.LanguageModel | None = None,
+    protocol: schema_lib.SchemaProtocol = 'python',
+    returns_message: bool = False,
+    skip_lm: bool = False,
+    **kwargs,
+) -> Any:
+  """Query one or more language models for structured or unstructured outputs.
+
+  This is the primary API in Langfun for interacting with language models,
+  supporting natural language prompts, structured inputs, and multiple advanced
+  features.
+
+  Key Features:
+
+  - **Input**: Accepts natural language strings, structured inputs (e.g.,
+    `pg.Object`), and templates (`lf.Template`) with modality objects.
+
+  - **Output**: Returns structured outputs when `schema` is specified;
+    otherwise, outputs raw natural language (as a string).
+
+  - **Few-shot examples**: Supports structured few-shot examples with the
+    `examples` argument.
+
+  - **Multi-LM fan-out**: Sends queries to multiple language models with
+    multiple samples in parallel, returning a list of outputs.
+
+  Examples:
+
+    Case 1: Regular natural language-based LLM query:
+
+    ```
+    lf.query('1 + 1 = ?', lm=lf.llms.Gpt4Turbo())
+
+    # Output: '2'
+    ```
+
+    Case 2: Query with structured output.
+
+    ```
+    lf.query('1 + 1 = ?', int, lm=lf.llms.Gpt4Turbo())
+
+    # Output: 2
+    ```
+
+    Case 3: Query with structured input.
+
+    ```
+    class Sum(pg.Object):
+      a: int
+      b: int
+
+    lf.query(Sum(1, 1), int, lm=lf.llms.Gpt4Turbo())
+
+    # Output: 2
+    ```
+
+    Case 4: Query with input of mixed modalities.
+
+    ```
+    class Animal(pg.Object):
+      pass
+
+    class Dog(Animal):
+      pass
+
+    class Entity(pg.Object):
+      name: str
+
+    lf.query(
+        'What is in this {{image}} and {{objects}}?',
+        list[Entity],
+        lm=lf.llms.Gpt4Turbo(),
+        image=lf.Image(path='/path/to/a/airplane.png'),
+        objects=[Dog()],
+    )
+
+    # Output: [Entity(name='airplane'), Entity(name='dog')]
+    ```
+
+    Case 5: Query with structured few-shot examples.
+    ```
+    lf.query(
+        'What is in this {{image}} and {{objects}}?',
+        list[Entity],
+        lm=lf.llms.Gpt4Turbo(),
+        image=lf.Image(path='/path/to/a/dinosaur.png'),
+        objects=[Dog()],
+        examples=[
+            lf.MappingExample(
+                input=lf.Template(
+                    'What is the object near the house in this {{image}}?',
+                    image=lf.Image(path='/path/to/image.png'),
+                ),
+                schema=Entity,
+                output=Entity('cat'),
+            ),
+        ],
+    )
+
+    # Output: [Entity(name='dinosaur'), Entity(name='dog')]
+    ```
+
+    Case 6: Multiple queries to multiple models.
+    ```
+    lf.query(
+        '1 + 1 = ?',
+        int,
+        lm=[
+            lf.llms.Gpt4Turbo(),
+            lf.llms.Gemini1_5Pro(),
+        ],
+        num_samples=[1, 2],
+    )
+    # Output: [2, 2, 2]
+    ```
+
+  Args:
+    prompt: The input query. Can be:
+      - A natural language string (supports templating with `{{}}`),
+      - A `pg.Object` object for structured input,
+      - An `lf.Template` for mixed or template-based inputs.
+    schema: Type annotation or `lf.Schema` object for the expected output.
+      If `None` (default), the response will be a natural language string.
+    default: Default value to return if parsing fails. If not specified, an
+      error will be raised.
+    lm: The language model(s) to query. Can be:
+      - A single `LanguageModel`,
+      - A list of `LanguageModel`s for multi-model fan-out.
+      If `None`, the LM from `lf.context` will be used.
+    num_samples: Number of samples to generate. If a list is provided, its
+      length must match the number of models in `lm`.
+    examples: Few-shot examples to guide the model output. Defaults to `None`.
+    cache_seed: Seed for caching the query. Queries with the same
+      `(lm, prompt, cache_seed)` will use cached responses. If `None`,
+      caching is disabled.
+    response_postprocess: A post-processing function for the raw LM response.
+      If `None`, no post-processing occurs.
+    autofix: Number of attempts for auto-fixing code errors. Set to `0` to
+      disable auto-fixing. Not supported with the `'json'` protocol.
+    autofix_lm: The LM to use for auto-fixing. Defaults to the `autofix_lm`
+      from `lf.context` or the main `lm`.
+    protocol: Format for schema representation. Choices are `'json'` or
+      `'python'`. Default is `'python'`.
+    returns_message: If `True`, returns an `lf.Message` object instead of
+      the final parsed result.
+    skip_lm: If `True`, skips the LLM call and returns the rendered
+      prompt as a `UserMessage` object.
+    **kwargs: Additional keyword arguments for:
+      - Rendering templates (e.g., `template_str`, `preamble`),
+      - Configuring `lf.structured.Mapping`.
+
+  Returns:
+    The result of the query:
+    - A single output or a list of outputs if multiple models/samples are used.
+    - Each output is a parsed object matching `schema`, an `lf.Message` (if
+      `returns_message=True`), or a natural language string (default).
+  """
+  # Internal usage logging.
+
+  # Multiple queries will be issued when `lm` is a list or `num_samples` is
+  # greater than 1.
+  if isinstance(lm, list) or num_samples != 1:
+    def _single_query(inputs):
+      lm, example_i = inputs
+      return query(
+          prompt,
+          schema,
+          default=default,
+          lm=lm,
+          examples=examples,
+          # Usually num_examples should not be large, so we multiply the user
+          # provided cache seed by 100 to avoid collision.
+          cache_seed=(
+              None if cache_seed is None else cache_seed * 100 + example_i
+          ),
+          response_postprocess=response_postprocess,
+          autofix=autofix,
+          autofix_lm=autofix_lm,
+          protocol=protocol,
+          returns_message=returns_message,
+          skip_lm=skip_lm,
+          **kwargs,
+      )
+    lm_list = lm if isinstance(lm, list) else [lm]
+    num_samples_list = (
+        num_samples if isinstance(num_samples, list)
+        else [num_samples] * len(lm_list)
+    )
+    assert len(lm_list) == len(num_samples_list), (
+        'Expect the length of `num_samples` to be the same as '
+        f'the length of `lm`. Got {num_samples} and {lm_list}.'
+    )
+    query_inputs = []
+    total_queries = 0
+    for lm, num_samples in zip(lm_list, num_samples_list):
+      query_inputs.extend([(lm, i) for i in range(num_samples)])
+      total_queries += num_samples
+
+    samples = []
+    for _, output, error in lf.concurrent_map(
+        _single_query, query_inputs, max_workers=max(64, total_queries),
+        ordered=True,
+    ):
+      if error is None:
+        samples.append(output)
+    return samples
+
+  # Normalize query schema.
+  # When `lf.query` is used for symbolic completion, schema is automatically
+  # inferred when it is None.
+  if isinstance(prompt, pg.Symbolic) and prompt.sym_partial and schema is None:
+    schema = prompt.__class__
+
+  # Normalize query input.
+  if isinstance(prompt, (lf.Message, str)):
+    # Query with structured output.
+    prompt_kwargs = kwargs.copy()
+    prompt_kwargs.pop('template_str', None)
+    query_input = lf.Template.from_value(prompt, **prompt_kwargs)
+  elif isinstance(prompt, lf.Template):
+    # Create a copy of the prompt if it has a parent object, so all child
+    # modality objects can be referred to by path relative to the prompt.
+    query_input = prompt.clone() if prompt.sym_parent is not None else prompt
+
+    # Attach template metadata from kwargs. This is used to pass through fields
+    # from kwargs to the rendered message.
+    template_metadata = {
+        k: v for k, v in kwargs.items() if k.startswith('metadata_')
+    }
+    query_input.rebind(
+        template_metadata, skip_notification=True, raise_on_no_change=False
+    )
+  elif pg.MISSING_VALUE == prompt:
+    query_input = lf.UserMessage('')
+  else:
+    query_input = schema_lib.mark_missing(prompt)
+
+  with lf.track_usages() as usage_summary:
+    start_time = time.time()
+    if schema in (None, str):
+      # Query with natural language output.
+      output_message = lf.LangFunc.from_value(query_input, **kwargs)(
+          lm=lm, cache_seed=cache_seed, skip_lm=skip_lm
+      )
+      if response_postprocess:
+        processed_text = response_postprocess(output_message.text)
+        if processed_text != output_message.text:
+          output_message = lf.AIMessage(processed_text, source=output_message)
+    else:
+      # Query with structured output.
+      output_message = _query_structure_cls(protocol)(
+          input=(
+              query_input.render(lm=lm)
+              if isinstance(query_input, lf.Template)
+              else query_input
+          ),
+          schema=schema,
+          default=default,
+          examples=examples,
+          response_postprocess=response_postprocess,
+          autofix=autofix if protocol == 'python' else 0,
+          **kwargs,
+      )(
+          lm=lm,
+          autofix_lm=autofix_lm or lm,
+          cache_seed=cache_seed,
+          skip_lm=skip_lm,
+      )
+    end_time = time.time()
+
+  def _result(message: lf.Message):
+    return message.text if schema in (None, str) else message.result
+
+  # Track the query invocations.
+  if pg.MISSING_VALUE != prompt and not skip_lm:
+    trackers = lf.context_value('__query_trackers__', [])
+    if trackers:
+      invocation = QueryInvocation(
+          input=pg.Ref(query_input),
+          schema=(
+              schema_lib.Schema.from_value(schema)
+              if schema not in (None, str) else None
+          ),
+          lm=pg.Ref(lm),
+          examples=pg.Ref(examples) if examples else [],
+          lm_response=lf.AIMessage(output_message.text),
+          usage_summary=usage_summary,
+          start_time=start_time,
+          end_time=end_time,
+      )
+      for i, (tracker, include_child_scopes) in enumerate(trackers):
+        if i == 0 or include_child_scopes:
+          tracker.append(invocation)
+  return output_message if returns_message else _result(output_message)
+
+
+#
+# Helper function for map-reduce style querying.
+#
+
+
+def query_and_reduce(
+    prompt: Union[str, lf.Template, Any],
+    schema: schema_lib.SchemaType | None = None,
+    *,
+    reduce: Callable[[list[Any]], Any],
+    lm: lf.LanguageModel | list[lf.LanguageModel] | None = None,
+    num_samples: int | list[int] = 1,
+    **kwargs,
+) -> Any:
+  """Issues multiple `lf.query` calls in parallel and reduces the outputs.
+
+  Args:
+    prompt: A str (may contain {{}} as template) as natural language input, or
+      a `pg.Symbolic` object as structured input, used as the prompt to LLM.
+    schema: A type annotation as the schema for the output object. If `None`
+      (default), the response will be a str in natural language.
+    reduce: A function to reduce the outputs of multiple `lf.query` calls. It
+      takes a list of outputs and returns the final object.
+    lm: The language model to use. If not specified, the language model from
+      the `lf.context` context manager will be used.
+    num_samples: The number of samples to obtain from each language model being
+      requested. If a list is provided, it should have the same length as `lm`.
+    **kwargs: Additional arguments to pass to `lf.query`.
+
+  Returns:
+    The reduced output from multiple `lf.query` calls.
+  """
+  results = query(prompt, schema, lm=lm, num_samples=num_samples, **kwargs)
+  if isinstance(results, list):
+    results = reduce(results)
+  return results
+
+
+#
+# Functions for decomposing `lf.query` into pre-LLM and post-LLM operations.
+#
+
+
+def query_prompt(
+    prompt: Union[str, lf.Template, Any],
+    schema: schema_lib.SchemaType | None = None,
+    **kwargs,
+) -> lf.Message:
+  """Returns the final prompt sent to LLM for `lf.query`."""
+  kwargs.pop('returns_message', None)
+  kwargs.pop('skip_lm', None)
+  return query(prompt, schema, skip_lm=True, returns_message=True, **kwargs)
+
+
+def query_output(
+    response: Union[str, lf.Message],
+    schema: schema_lib.SchemaType | None = None,
+    **kwargs,
+) -> Any:
+  """Returns the final output of `lf.query` from a provided LLM response."""
+  kwargs.pop('prompt', None)
+  kwargs.pop('lm', None)
+  return query(
+      pg.MISSING_VALUE, schema, lm=fake.StaticResponse(response), **kwargs
+  )
+
+
+#
+# Functions for computing reward of an LLM response based on a mapping example.
+#
+
+
+def query_reward(
+    mapping_example: Union[str, mapping.MappingExample],
+    response: Union[str, lf.Message],
+) -> float | None:
+  """Returns the reward of an LLM response based on a mapping example."""
+  if isinstance(mapping_example, str):
+    mapping_example = pg.from_json_str(mapping_example)
+    assert isinstance(mapping_example, mapping.MappingExample), mapping_example
+  schema = mapping_example.schema
+
+  if schema and isinstance(schema.spec, pg.typing.Object):
+    output_cls = schema.spec.cls
+  elif schema is None and isinstance(mapping_example.output, pg.Object):
+    output_cls = mapping_example.output.__class__
+  else:
+    output_cls = None
+
+  reward_fn = _reward_fn(output_cls)
+  if reward_fn is None:
+    return None
+
+  return reward_fn(
+      query_output(response, output_cls),
+      mapping_example.input,
+      mapping_example.output,
+      mapping_example.metadata,
+  )
+
+
+@functools.cache
+def _reward_fn(cls) -> Callable[
+    [
+        pg.Object,  # Actual output object.
+        Any,        # Input object.
+        pg.Object,  # Expected output object.
+        pg.Dict     # User metadata.
+    ], float] | None:
+  """Returns the reward function for a class that is being queried."""
+  if not callable(getattr(cls, '__reward__', None)):
+    return None
+
+  signature = pg.typing.signature(cls.__reward__)
+  num_args = len(signature.args)
+  if num_args < 2 or num_args > 4:
+    raise TypeError(
+        f'`{cls.__type_name__}.__reward__` should have signature: '
+        '`__reward__(self, input, [expected_output], [expected_metadata])`.'
+    )
+  def _reward(self, input, expected_output, metadata):  # pylint: disable=redefined-builtin
+    args = [self, input, expected_output, metadata]
+    return cls.__reward__(*args[:num_args])
+  return _reward
+
+
+#
+# Functions for tracking `lf.query` invocations.
+#
+
+
+class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
+  """A class to represent the invocation of `lf.query`."""
+
+  input: Annotated[
+      Union[lf.Template, pg.Symbolic],
+      'Mapping input of `lf.query`.'
+  ]
+  schema: pg.typing.Annotated[
+      schema_lib.schema_spec(noneable=True),
+      'Schema of `lf.query`.'
+  ]
+  lm_response: Annotated[
+      lf.Message,
+      'Raw LM response.'
+  ]
+  lm: Annotated[
+      lf.LanguageModel,
+      'Language model used for `lf.query`.'
+  ]
+  examples: Annotated[
+      list[mapping.MappingExample],
+      'Few-shot exemplars for `lf.query`.'
+  ]
+  usage_summary: Annotated[
+      lf.UsageSummary,
+      'Usage summary for `lf.query`.'
+  ]
+  start_time: Annotated[
+      float,
+      'Start time of query.'
+  ]
+  end_time: Annotated[
+      float,
+      'End time of query.'
+  ]
+
+  @functools.cached_property
+  def lm_request(self) -> lf.Message:
+    return query_prompt(self.input, self.schema)
+
+  @functools.cached_property
+  def output(self) -> Any:
+    """The output of `lf.query`. If it failed, returns the `MappingError`."""
+    try:
+      return query_output(self.lm_response, self.schema)
+    except mapping.MappingError as e:
+      return e
+
+  @property
+  def has_error(self) -> bool:
+    """Returns True if the query failed to generate a valid output."""
+    return isinstance(self.output, BaseException)
+
+  @property
+  def elapse(self) -> float:
+    """Returns the elapsed time of the query in seconds."""
+    return self.end_time - self.start_time
+
+  def _on_bound(self):
+    super()._on_bound()
+    self.__dict__.pop('lm_request', None)
+    self.__dict__.pop('output', None)
+
+  def _html_tree_view_summary(
+      self,
+      *,
+      view: pg.views.HtmlTreeView,
+      **kwargs: Any
+  ) -> pg.Html | None:
+    kwargs.pop('title', None)
+    kwargs.pop('enable_summary_tooltip', None)
+    return view.summary(
+        value=self,
+        title=pg.Html.element(
+            'div',
+            [
+                pg.views.html.controls.Label(
+                    'lf.query',
+                    css_classes=['query-invocation-type-name']
+                ),
+                pg.views.html.controls.Badge(
+                    f'lm={self.lm.model_id}',
+                    pg.format(
+                        self.lm,
+                        verbose=False,
+                        python_format=True,
+                        hide_default_values=True
+                    ),
+                    css_classes=['query-invocation-lm']
+                ),
+                pg.views.html.controls.Badge(
+                    f'{int(self.elapse)} seconds',
+                    css_classes=['query-invocation-time']
+                ),
+                self.usage_summary.to_html(extra_flags=dict(as_badge=True))
+            ],
+            css_classes=['query-invocation-title']
+        ),
+        enable_summary_tooltip=False,
+        **kwargs
+    )
+
+  def _html_tree_view_content(
+      self,
+      *,
+      view: pg.views.HtmlTreeView,
+      **kwargs: Any
+  ) -> pg.Html:
+    return pg.views.html.controls.TabControl([
+        pg.views.html.controls.Tab(
+            'input',
+            pg.view(self.input, collapse_level=None),
+        ),
+        pg.views.html.controls.Tab(
+            'output',
+            pg.view(self.output, collapse_level=None),
+        ),
+        pg.views.html.controls.Tab(
+            'schema',
+            pg.view(self.schema),
+        ),
+        pg.views.html.controls.Tab(
+            'lm_request',
+            pg.view(
+                self.lm_request,
+                extra_flags=dict(include_message_metadata=False),
+            ),
+        ),
+        pg.views.html.controls.Tab(
+            'lm_response',
+            pg.view(
+                self.lm_response,
+                extra_flags=dict(include_message_metadata=False)
+            ),
+        ),
+    ], tab_position='top', selected=1).to_html()
+
+  @classmethod
+  def _html_tree_view_css_styles(cls) -> list[str]:
+    return super()._html_tree_view_css_styles() + [
+        """
+        .query-invocation-title {
+          display: inline-block;
+          font-weight: normal;
+        }
+        .query-invocation-type-name {
+          color: #888;
+        }
+        .query-invocation-lm.badge {
+          margin-left: 5px;
+          margin-right: 5px;
+          color: white;
+          background-color: mediumslateblue;
+        }
+        .query-invocation-time.badge {
+          margin-left: 5px;
+          border-radius: 0px;
+          font-weight: bold;
+          background-color: aliceblue;
+        }
+        .query-invocation-title .usage-summary.label {
+          border-radius: 0px;
+          color: #AAA;
+        }
+        """
+    ]
+
+
+@contextlib.contextmanager
+def track_queries(
+    include_child_scopes: bool = True
+) -> Iterator[list[QueryInvocation]]:
+  """Tracks all queries made during the context.
+
+  Example:
+
+    ```
+    with lf.track_queries() as queries:
+      lf.query('hi', lm=lm)
+      lf.query('What is this {{image}}?', lm=lm, image=image)
+
+    print(queries)
+    ```
+
+  Args:
+    include_child_scopes: If True, the queries made in child scopes will be
+      included in the returned list. Otherwise, only the queries made in the
+      current scope will be included.
+
+  Yields:
+    A list of `QueryInvocation` objects representing the queries made during
+    the context.
+  """
+  trackers = lf.context_value('__query_trackers__', [])
+  tracker = []
+
+  with lf.context(
+      __query_trackers__=[(tracker, include_child_scopes)] + trackers
+  ):
+    try:
+      yield tracker
+    finally:
+      pass
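
For orientation, below is a minimal usage sketch (not part of the diff) of the `lf.query` and `track_queries` API added above. It substitutes `lf.llms.StaticResponse`, the same fake LM that `query_output` itself uses, so the snippet runs without network access; any real `LanguageModel` could be passed instead.

```python
import langfun as lf
import pyglove as pg


class Answer(pg.Object):
  final_answer: int


# A canned response stands in for a real model so the sketch runs offline.
lm = lf.llms.StaticResponse('Answer(final_answer=2)')

with lf.track_queries() as queries:
  # Structured query: the LM response is parsed into an `Answer` instance.
  answer = lf.query('1 + 1 = ?', Answer, lm=lm)

print(answer.final_answer)  # 2
for invocation in queries:  # Each entry is a `QueryInvocation`.
  print(invocation.lm.model_id, invocation.elapse, invocation.has_error)
```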
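
`query_prompt` and `query_output` split `lf.query` into its pre-LLM and post-LLM halves, which is useful when inference runs out of band (e.g., on a batch service). A sketch, assuming both helpers are re-exported at the top level the way `lf.query` is:

```python
import langfun as lf

# Pre-LLM: render the exact prompt `lf.query` would send, without an LM call.
request = lf.query_prompt('1 + 1 = ?', int)
print(request.text)  # The rendered INPUT_OBJECT / OUTPUT_TYPE prompt.

# ... run inference elsewhere to obtain the raw model text ...
raw_response = '2'

# Post-LLM: parse the raw response exactly as `lf.query` would have.
result = lf.query_output(raw_response, int)
assert result == 2
```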
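
`query_reward` scores a response against a `MappingExample` by delegating to a `__reward__` method on the schema class, which per the `_reward_fn` check must take 2 to 4 positional arguments. A sketch, assuming `query_reward` and `MappingExample` are reachable under `lf.structured`:

```python
import langfun as lf
import pyglove as pg


class Solution(pg.Object):
  final_answer: int

  # 2 to 4 positional args, per the signature check in `_reward_fn` above.
  def __reward__(self, input, expected_output):  # pylint: disable=redefined-builtin
    return 1.0 if self.final_answer == expected_output.final_answer else -1.0


example = lf.structured.MappingExample(
    input='1 + 1 = ?',
    schema=Solution,
    output=Solution(final_answer=2),
)
# The response text is parsed into `Solution`, then scored by `__reward__`.
print(lf.structured.query_reward(example, 'Solution(final_answer=2)'))  # 1.0
```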