langfun 0.1.2.dev202412020805__py3-none-any.whl → 0.1.2.dev202412050804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +2 -0
- langfun/core/agentic/action.py +74 -24
- langfun/core/agentic/action_test.py +20 -4
- langfun/core/eval/v2/runners.py +3 -0
- langfun/core/llms/__init__.py +1 -7
- langfun/core/llms/openai.py +142 -207
- langfun/core/llms/openai_test.py +160 -224
- langfun/core/llms/vertexai.py +23 -422
- langfun/core/llms/vertexai_test.py +21 -335
- langfun/core/structured/__init__.py +2 -0
- langfun/core/structured/prompting.py +148 -47
- langfun/core/structured/prompting_test.py +84 -1
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412050804.dist-info}/METADATA +1 -12
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412050804.dist-info}/RECORD +17 -17
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412050804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412050804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412050804.dist-info}/top_level.txt +0 -0

langfun/core/llms/vertexai_test.py
@@ -13,13 +13,12 @@
 # limitations under the License.
 """Tests for Gemini models."""
 
+import base64
 import os
 from typing import Any
 import unittest
 from unittest import mock
 
-from google.cloud.aiplatform import models as aiplatform_models
-from vertexai import generative_models
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
 from langfun.core.llms import vertexai
@@ -39,33 +38,6 @@ example_image = (
 )
 
 
-def mock_generate_content(content, generation_config, **kwargs):
-  del kwargs
-  c = pg.Dict(generation_config.to_dict())
-  return generative_models.GenerationResponse.from_dict({
-      'candidates': [
-          {
-              'index': 0,
-              'content': {
-                  'role': 'model',
-                  'parts': [
-                      {
-                          'text': (
-                              f'This is a response to {content[0]} with '
-                              f'temperature={c.temperature}, '
-                              f'top_p={c.top_p}, '
-                              f'top_k={c.top_k}, '
-                              f'max_tokens={c.max_output_tokens}, '
-                              f'stop={"".join(c.stop_sequences)}.'
-                          )
-                      },
-                  ],
-              },
-          },
-      ]
-  })
-
-
 def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
   del url, kwargs
   c = pg.Dict(json['generationConfig'])
@@ -100,273 +72,7 @@ def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
   return response
 
 
-def mock_endpoint_predict(instances, **kwargs):
-  del kwargs
-  assert len(instances) == 1
-  return aiplatform_models.Prediction(
-      predictions=[
-          f"This is a response to {instances[0]['prompt']} with"
-          f" temperature={instances[0]['temperature']},"
-          f" top_p={instances[0]['top_p']}, top_k={instances[0]['top_k']},"
-          f" max_tokens={instances[0]['max_tokens']}."
-      ],
-      deployed_model_id='',
-  )
-
-
 class VertexAITest(unittest.TestCase):
-  """Tests for Vertex model."""
-
-  def test_content_from_message_text_only(self):
-    text = 'This is a beautiful day'
-    model = vertexai.VertexAIGeminiPro1Vision()
-    chunks = model._content_from_message(lf.UserMessage(text))
-    self.assertEqual(chunks, [text])
-
-  def test_content_from_message_mm(self):
-    message = lf.UserMessage(
-        'This is an <<[[image]]>>, what is it?',
-        image=lf_modalities.Image.from_bytes(example_image),
-    )
-
-    # Non-multimodal model.
-    with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
-      vertexai.VertexAIPalm2()._content_from_message(message)
-
-    model = vertexai.VertexAIGeminiPro1Vision()
-    chunks = model._content_from_message(message)
-    self.maxDiff = None
-    self.assertEqual([chunks[0], chunks[2]], ['This is an', ', what is it?'])
-    self.assertIsInstance(chunks[1], generative_models.Part)
-
-  def test_generation_response_to_message_text_only(self):
-    response = generative_models.GenerationResponse.from_dict({
-        'candidates': [
-            {
-                'index': 0,
-                'content': {
-                    'role': 'model',
-                    'parts': [
-                        {
-                            'text': 'hello world',
-                        },
-                    ],
-                },
-            },
-        ],
-    })
-    model = vertexai.VertexAIGeminiPro1Vision()
-    message = model._generation_response_to_message(response)
-    self.assertEqual(message, lf.AIMessage('hello world'))
-
-  def test_model_hub(self):
-    with mock.patch(
-        'vertexai.generative_models.'
-        'GenerativeModel.__init__'
-    ) as mock_model_init:
-      mock_model_init.side_effect = lambda *args, **kwargs: None
-      model = vertexai._VERTEXAI_MODEL_HUB.get_generative_model(
-          'gemini-1.0-pro'
-      )
-      self.assertIsNotNone(model)
-      self.assertIs(
-          vertexai._VERTEXAI_MODEL_HUB.get_generative_model('gemini-1.0-pro'),
-          model,
-      )
-
-    with mock.patch(
-        'vertexai.language_models.'
-        'TextGenerationModel.from_pretrained'
-    ) as mock_model_init:
-
-      class TextGenerationModel:
-        pass
-
-      mock_model_init.side_effect = lambda *args, **kw: TextGenerationModel()
-      model = vertexai._VERTEXAI_MODEL_HUB.get_text_generation_model(
-          'text-bison'
-      )
-      self.assertIsNotNone(model)
-      self.assertIs(
-          vertexai._VERTEXAI_MODEL_HUB.get_text_generation_model('text-bison'),
-          model,
-      )
-
-  def test_project_and_location_check(self):
-    with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
-      _ = vertexai.VertexAIGeminiPro1Vision()._api_initialized
-
-    with self.assertRaisesRegex(ValueError, 'Please specify `location`'):
-      _ = vertexai.VertexAIGeminiPro1Vision(project='abc')._api_initialized
-
-    self.assertTrue(
-        vertexai.VertexAIGeminiPro1Vision(
-            project='abc', location='us-central1'
-        )._api_initialized
-    )
-
-    os.environ['VERTEXAI_PROJECT'] = 'abc'
-    os.environ['VERTEXAI_LOCATION'] = 'us-central1'
-    self.assertTrue(vertexai.VertexAIGeminiPro1Vision()._api_initialized)
-    del os.environ['VERTEXAI_PROJECT']
-    del os.environ['VERTEXAI_LOCATION']
-
-  def test_generation_config(self):
-    model = vertexai.VertexAIGeminiPro1Vision()
-    json_schema = {
-        'type': 'object',
-        'properties': {
-            'name': {'type': 'string'},
-        },
-        'required': ['name'],
-        'title': 'Person',
-    }
-    config = model._generation_config(
-        lf.UserMessage('hi', json_schema=json_schema),
-        lf.LMSamplingOptions(
-            temperature=2.0,
-            top_p=1.0,
-            top_k=20,
-            max_tokens=1024,
-            stop=['\n'],
-        ),
-    )
-    actual = config.to_dict()
-    # There is a discrepancy between the `property_ordering` in the
-    # Google-internal version and the open-source version.
-    actual['response_schema'].pop('property_ordering', None)
-    if pg.KeyPath.parse('response_schema.type_').get(actual):
-      actual['response_schema']['type'] = actual['response_schema'].pop('type_')
-    if pg.KeyPath.parse('response_schema.properties.name.type_').get(actual):
-      actual['response_schema']['properties']['name']['type'] = actual[
-          'response_schema']['properties']['name'].pop('type_')
-
-    self.assertEqual(
-        actual,
-        dict(
-            temperature=2.0,
-            top_p=1.0,
-            top_k=20.0,
-            max_output_tokens=1024,
-            stop_sequences=['\n'],
-            response_mime_type='application/json',
-            response_schema={
-                'type': 'OBJECT',
-                'properties': {
-                    'name': {'type': 'STRING'}
-                },
-                'required': ['name'],
-                'title': 'Person',
-            }
-        ),
-    )
-    with self.assertRaisesRegex(
-        ValueError, '`json_schema` must be a dict, got'
-    ):
-      model._generation_config(
-          lf.UserMessage('hi', json_schema='not a dict'),
-          lf.LMSamplingOptions(),
-      )
-
-  def test_call_generative_model(self):
-    with mock.patch(
-        'vertexai.generative_models.'
-        'GenerativeModel.__init__'
-    ) as mock_model_init:
-      mock_model_init.side_effect = lambda *args, **kwargs: None
-
-      with mock.patch(
-          'vertexai.generative_models.'
-          'GenerativeModel.generate_content'
-      ) as mock_generate:
-        mock_generate.side_effect = mock_generate_content
-
-        lm = vertexai.VertexAIGeminiPro1Vision(
-            project='abc', location='us-central1'
-        )
-        self.assertEqual(
-            lm(
-                'hello',
-                temperature=2.0,
-                top_p=1.0,
-                top_k=20,
-                max_tokens=1024,
-                stop='\n',
-            ).text,
-            (
-                'This is a response to hello with temperature=2.0, '
-                'top_p=1.0, top_k=20.0, max_tokens=1024, stop=\n.'
-            ),
-        )
-
-  def test_call_text_generation_model(self):
-    with mock.patch(
-        'vertexai.language_models.'
-        'TextGenerationModel.from_pretrained'
-    ) as mock_model_init:
-
-      class TextGenerationModel:
-
-        def predict(self, prompt, **kwargs):
-          c = pg.Dict(kwargs)
-          return pg.Dict(
-              text=(
-                  f'This is a response to {prompt} with '
-                  f'temperature={c.temperature}, '
-                  f'top_p={c.top_p}, '
-                  f'top_k={c.top_k}, '
-                  f'max_tokens={c.max_output_tokens}, '
-                  f'stop={"".join(c.stop_sequences)}.'
-              )
-          )
-
-      mock_model_init.side_effect = lambda *args, **kw: TextGenerationModel()
-      lm = vertexai.VertexAIPalm2(project='abc', location='us-central1')
-      self.assertEqual(
-          lm(
-              'hello',
-              temperature=2.0,
-              top_p=1.0,
-              top_k=20,
-              max_tokens=1024,
-              stop='\n',
-          ).text,
-          (
-              'This is a response to hello with temperature=2.0, '
-              'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.'
-          ),
-      )
-
-  def test_call_endpoint_model(self):
-    with mock.patch(
-        'google.cloud.aiplatform.models.Endpoint.__init__'
-    ) as mock_model_init:
-      mock_model_init.side_effect = lambda *args, **kwargs: None
-      with mock.patch(
-          'google.cloud.aiplatform.models.Endpoint.predict'
-      ) as mock_model_predict:
-
-        mock_model_predict.side_effect = mock_endpoint_predict
-        lm = vertexai.VertexAI(
-            'custom',
-            endpoint_name='123',
-            project='abc',
-            location='us-central1',
-        )
-        self.assertEqual(
-            lm(
-                'hello',
-                temperature=2.0,
-                top_p=1.0,
-                top_k=20,
-                max_tokens=50,
-            ),
-            'This is a response to hello with temperature=2.0, top_p=1.0,'
-            ' top_k=20, max_tokens=50.',
-        )
-
-
-class VertexRestfulAITest(unittest.TestCase):
   """Tests for Vertex model with REST API."""
 
   def test_content_from_message_text_only(self):
@@ -376,9 +82,9 @@ class VertexRestfulAITest(unittest.TestCase):
     self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]})
 
   def test_content_from_message_mm(self):
+    image = lf_modalities.Image.from_bytes(example_image)
     message = lf.UserMessage(
-        'This is an <<[[image]]>>, what is it?',
-        image=lf_modalities.Image.from_bytes(example_image),
+        'This is an <<[[image]]>>, what is it?', image=image
     )
 
     # Non-multimodal model.
@@ -386,46 +92,25 @@ class VertexRestfulAITest(unittest.TestCase):
       vertexai.VertexAIGeminiPro1()._content_from_message(message)
 
     model = vertexai.VertexAIGeminiPro1Vision()
-
-    self.
-
-
-
-
-
-
-
-
-
-
-                    'parts': [
-                        {
-                            'text': 'hello world',
-                        },
-                    ],
+    content = model._content_from_message(message)
+    self.assertEqual(
+        content,
+        {
+            'role': 'user',
+            'parts': [
+                {'text': 'This is an'},
+                {
+                    'inlineData': {
+                        'data': base64.b64encode(example_image).decode(),
+                        'mimeType': 'image/png',
+                    }
                 },
-
-
-
-
-    message = model._generation_response_to_message(response)
-    self.assertEqual(message, lf.AIMessage('hello world'))
-
-  def test_model_hub(self):
-    with mock.patch(
-        'vertexai.generative_models.'
-        'GenerativeModel.__init__'
-    ) as mock_model_init:
-      mock_model_init.side_effect = lambda *args, **kwargs: None
-      model = vertexai._VERTEXAI_MODEL_HUB.get_generative_model(
-          'gemini-1.0-pro'
-      )
-      self.assertIsNotNone(model)
-      self.assertIs(
-          vertexai._VERTEXAI_MODEL_HUB.get_generative_model('gemini-1.0-pro'),
-          model,
-      )
+                {'text': ', what is it?'},
+            ],
+        },
+    )
 
+  @mock.patch.object(vertexai.VertexAI, 'credentials', new=True)
   def test_project_and_location_check(self):
     with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
       _ = vertexai.VertexAIGeminiPro1()._api_initialized
@@ -496,6 +181,7 @@ class VertexRestfulAITest(unittest.TestCase):
           lf.LMSamplingOptions(),
       )
 
+  @mock.patch.object(vertexai.VertexAI, 'credentials', new=True)
   def test_call_model(self):
     with mock.patch('requests.Session.post') as mock_generate:
       mock_generate.side_effect = mock_requests_post

langfun/core/structured/__init__.py
@@ -69,6 +69,8 @@ from langfun.core.structured.prompting import query
 from langfun.core.structured.prompting import query_prompt
 from langfun.core.structured.prompting import query_output
 from langfun.core.structured.prompting import query_reward
+from langfun.core.structured.prompting import QueryInvocation
+from langfun.core.structured.prompting import track_queries
 
 from langfun.core.structured.description import DescribeStructure
 from langfun.core.structured.description import describe

langfun/core/structured/prompting.py
@@ -13,8 +13,9 @@
 # limitations under the License.
 """Symbolic query."""
 
+import contextlib
 import functools
-from typing import Any, Callable, Type, Union
+from typing import Annotated, Any, Callable, Iterator, Type, Union
 
 import langfun.core as lf
 from langfun.core.llms import fake
@@ -102,7 +103,7 @@ def _query_structure_cls(
 
 
 def query(
-    prompt: Union[str,
+    prompt: Union[str, lf.Template, Any],
     schema: Union[
         schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
     ] = None,
@@ -119,7 +120,7 @@ def query(
     skip_lm: bool = False,
     **kwargs,
 ) -> Any:
-  """
+  """Queries an language model for a (maybe) structured output.
 
   Examples:
 
@@ -189,59 +190,93 @@ def query(
   """
   # Internal usage logging.
 
+  # Normalize query schema.
   # When `lf.query` is used for symbolic completion, schema is automatically
   # inferred when it is None.
   if isinstance(prompt, pg.Symbolic) and prompt.sym_partial and schema is None:
     schema = prompt.__class__
 
-  #
-
-
-
-
-
-
-
-
+  # Normalize query input.
+  if isinstance(prompt, (lf.Message, str)):
+    # Query with structured output.
+    prompt_kwargs = kwargs.copy()
+    prompt_kwargs.pop('template_str', None)
+    query_input = lf.Template.from_value(prompt, **prompt_kwargs)
+  elif isinstance(prompt, lf.Template):
+    # Create a copy of the prompt if it has a parent object, so all child
+    # modality objects could be referred by path relative to the prompt.
+    query_input = prompt.clone() if prompt.sym_parent is not None else prompt
+
+    # Attach template metadata from kwargs. This is used to pass through fields
+    # from kwargs to the rendered message.
+    template_metadata = {
+        k: v for k, v in kwargs.items() if k.startswith('metadata_')
+    }
+    query_input.rebind(
+        template_metadata, skip_notification=True, raise_on_no_change=False
     )
-
-
-    if processed_text != output.text:
-      output = lf.AIMessage(processed_text, source=output)
-    return output if returns_message else output.text
-
-  # Query with structured output.
-  prompt_kwargs = kwargs.copy()
-
-  # NOTE(daiyip): when `template_str` is passed in, it's intended to modify the
-  # QueryStructure template string. Therefore, we pop out the argument for
-  # prompt rendering.
-  prompt_kwargs.pop('template_str', None)
-
-  if isinstance(prompt, (str, lf.Message, lf.Template)):
-    prompt = lf.Template.from_value(prompt, **prompt_kwargs).render(lm=lm)
+  elif pg.MISSING_VALUE == prompt:
+    query_input = lf.UserMessage('')
   else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    query_input = schema_lib.mark_missing(prompt)
+
+  with lf.track_usages() as usage_summary:
+    if schema in (None, str):
+      # Query with natural language output.
+      output_message = lf.LangFunc.from_value(query_input, **kwargs)(
+          lm=lm, cache_seed=cache_seed, skip_lm=skip_lm
+      )
+      if response_postprocess:
+        processed_text = response_postprocess(output_message.text)
+        if processed_text != output_message.text:
+          output_message = lf.AIMessage(processed_text, source=output_message)
+    else:
+      # Query with structured output.
+      output_message = _query_structure_cls(protocol)(
+          input=(
+              query_input.render(lm=lm)
+              if isinstance(query_input, lf.Template)
+              else query_input
+          ),
+          schema=schema,
+          default=default,
+          examples=examples,
+          response_postprocess=response_postprocess,
+          autofix=autofix if protocol == 'python' else 0,
+          **kwargs,
+      )(
+          lm=lm,
+          autofix_lm=autofix_lm or lm,
+          cache_seed=cache_seed,
+          skip_lm=skip_lm,
+      )
+
+  def _result(message: lf.Message):
+    return message.text if schema in (None, str) else message.result
+
+  # Track the query invocations.
+  if pg.MISSING_VALUE != prompt and not skip_lm:
+    trackers = lf.context_value('__query_trackers__', [])
+    if trackers:
+      invocation = QueryInvocation(
+          input=pg.Ref(query_input),
+          schema=(
+              schema_lib.Schema.from_value(schema)
+              if schema not in (None, str) else None
+          ),
+          output=pg.Ref(_result(output_message)),
+          lm=pg.Ref(lm),
+          examples=pg.Ref(examples) if examples else [],
+          usage_summary=usage_summary,
+      )
+      for i, (tracker, include_child_scopes) in enumerate(trackers):
+        if i == 0 or include_child_scopes:
+          tracker.append(invocation)
+  return output_message if returns_message else _result(output_message)
 
 
 def query_prompt(
-    prompt: Union[str,
+    prompt: Union[str, lf.Template, Any],
     schema: Union[
         schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
     ] = None,
@@ -264,7 +299,7 @@ def query_output(
   kwargs.pop('prompt', None)
   kwargs.pop('lm', None)
   return query(
-
+      pg.MISSING_VALUE, schema, lm=fake.StaticResponse(response), **kwargs
   )
 
 
@@ -320,3 +355,69 @@ def _reward_fn(cls) -> Callable[
     args = [self, input, expected_output, metadata]
     return cls.__reward__(*args[:num_args])
   return _reward
+
+
+class QueryInvocation(pg.Object):
+  """A class to represent the invocation of `lf.query`."""
+
+  input: Annotated[
+      Union[lf.Template, pg.Symbolic],
+      'Mapping input of `lf.query`.'
+  ]
+  schema: pg.typing.Annotated[
+      schema_lib.schema_spec(noneable=True),
+      'Schema of `lf.query`.'
+  ]
+  output: Annotated[
+      Any,
+      'Mapping output of `lf.query`.'
+  ]
+  lm: Annotated[
+      lf.LanguageModel,
+      'Language model used for `lf.query`.'
+  ]
+  examples: Annotated[
+      list[mapping.MappingExample],
+      'Fewshot exemplars for `lf.query`.'
+  ]
+  usage_summary: Annotated[
+      lf.UsageSummary,
+      'Usage summary for `lf.query`.'
+  ]
+
+
+@contextlib.contextmanager
+def track_queries(
+    include_child_scopes: bool = True
+) -> Iterator[list[QueryInvocation]]:
+  """Track all queries made during the context.
+
+  Example:
+
+    ```
+    with lf.track_queries() as queries:
+      lf.query('hi', lm=lm)
+      lf.query('What is this {{image}}?', lm=lm, image=image)
+
+    print(queries)
+    ```
+
+  Args:
+    include_child_scopes: If True, the queries made in child scopes will be
+      included in the returned list. Otherwise, only the queries made in the
+      current scope will be included.
+
+  Yields:
+    A list of `QueryInvocation` objects representing the queries made during
+      the context.
+  """
+  trackers = lf.context_value('__query_trackers__', [])
+  tracker = []
+
+  with lf.context(
+      __query_trackers__=[(tracker, include_child_scopes)] + trackers
+  ):
+    try:
+      yield tracker
+    finally:
+      pass