langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240511__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic. Click here for more details.
- langfun/__init__.py +7 -0
- langfun/core/__init__.py +1 -0
- langfun/core/coding/python/correction.py +0 -7
- langfun/core/component.py +6 -0
- langfun/core/component_test.py +1 -0
- langfun/core/eval/__init__.py +15 -0
- langfun/core/eval/base.py +665 -95
- langfun/core/eval/base_test.py +224 -53
- langfun/core/eval/matching.py +48 -30
- langfun/core/eval/matching_test.py +25 -3
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +19 -10
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/langfunc.py +1 -22
- langfun/core/langfunc_test.py +10 -4
- langfun/core/language_model.py +130 -24
- langfun/core/language_model_test.py +249 -26
- langfun/core/llms/__init__.py +27 -2
- langfun/core/llms/anthropic.py +263 -0
- langfun/core/llms/anthropic_test.py +167 -0
- langfun/core/llms/cache/in_memory_test.py +37 -28
- langfun/core/llms/fake.py +34 -25
- langfun/core/llms/fake_test.py +122 -11
- langfun/core/llms/google_genai.py +8 -0
- langfun/core/llms/google_genai_test.py +8 -3
- langfun/core/llms/groq.py +260 -0
- langfun/core/llms/groq_test.py +170 -0
- langfun/core/llms/llama_cpp.py +3 -1
- langfun/core/llms/openai.py +100 -81
- langfun/core/llms/openai_test.py +287 -60
- langfun/core/llms/vertexai.py +291 -0
- langfun/core/llms/vertexai_test.py +233 -0
- langfun/core/modalities/image.py +1 -3
- langfun/core/modalities/mime.py +6 -0
- langfun/core/modalities/video.py +6 -5
- langfun/core/structured/__init__.py +5 -0
- langfun/core/structured/completion_test.py +2 -2
- langfun/core/structured/function_generation.py +245 -0
- langfun/core/structured/function_generation_test.py +329 -0
- langfun/core/structured/mapping.py +61 -3
- langfun/core/structured/mapping_test.py +17 -0
- langfun/core/structured/parsing_test.py +18 -13
- langfun/core/structured/prompting.py +61 -12
- langfun/core/structured/prompting_test.py +122 -12
- langfun/core/structured/schema.py +38 -6
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +36 -7
- langfun/core/structured/scoring.py +4 -1
- langfun/core/structured/scoring_test.py +6 -0
- langfun/core/template.py +147 -11
- langfun/core/template_test.py +75 -0
- langfun/core/templates/selfplay_test.py +6 -2
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/METADATA +3 -2
- langfun-0.0.2.dev20240511.dist-info/RECORD +112 -0
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/top_level.txt +0 -0
@@ -17,11 +17,9 @@ import inspect
|
|
17
17
|
import unittest
|
18
18
|
|
19
19
|
import langfun.core as lf
|
20
|
-
from langfun.core import coding
|
21
20
|
from langfun.core.llms import fake
|
22
21
|
from langfun.core.structured import mapping
|
23
22
|
from langfun.core.structured import parsing
|
24
|
-
from langfun.core.structured import schema as schema_lib
|
25
23
|
import pyglove as pg
|
26
24
|
|
27
25
|
|
@@ -255,7 +253,7 @@ class ParseStructurePythonTest(unittest.TestCase):
|
|
255
253
|
override_attrs=True,
|
256
254
|
):
|
257
255
|
with self.assertRaisesRegex(
|
258
|
-
|
256
|
+
mapping.MappingError,
|
259
257
|
'name .* is not defined',
|
260
258
|
):
|
261
259
|
parsing.parse('three', int)
|
@@ -280,13 +278,15 @@ class ParseStructurePythonTest(unittest.TestCase):
|
|
280
278
|
),
|
281
279
|
1,
|
282
280
|
)
|
281
|
+
r = parsing.parse(
|
282
|
+
'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm,
|
283
|
+
returns_message=True
|
284
|
+
)
|
283
285
|
self.assertEqual(
|
284
|
-
|
285
|
-
'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm,
|
286
|
-
returns_message=True
|
287
|
-
),
|
286
|
+
r,
|
288
287
|
lf.AIMessage(
|
289
288
|
'1', score=1.0, result=1, logprobs=None,
|
289
|
+
usage=lf.LMSamplingUsage(652, 1, 653),
|
290
290
|
tags=['lm-response', 'lm-output', 'transformed']
|
291
291
|
),
|
292
292
|
)
|
@@ -544,7 +544,7 @@ class ParseStructureJsonTest(unittest.TestCase):
|
|
544
544
|
override_attrs=True,
|
545
545
|
):
|
546
546
|
with self.assertRaisesRegex(
|
547
|
-
|
547
|
+
mapping.MappingError,
|
548
548
|
'No JSON dict in the output',
|
549
549
|
):
|
550
550
|
parsing.parse('three', int, protocol='json')
|
@@ -634,13 +634,18 @@ class CallTest(unittest.TestCase):
|
|
634
634
|
)
|
635
635
|
|
636
636
|
def test_call_with_returning_message(self):
|
637
|
+
r = parsing.call(
|
638
|
+
'Compute 1 + 2', int, lm=fake.StaticSequence(['three', '3']),
|
639
|
+
returns_message=True
|
640
|
+
)
|
637
641
|
self.assertEqual(
|
638
|
-
|
639
|
-
'Compute 1 + 2', int, lm=fake.StaticSequence(['three', '3']),
|
640
|
-
returns_message=True
|
641
|
-
),
|
642
|
+
r,
|
642
643
|
lf.AIMessage(
|
643
|
-
'3',
|
644
|
+
'3',
|
645
|
+
result=3,
|
646
|
+
score=1.0,
|
647
|
+
logprobs=None,
|
648
|
+
usage=lf.LMSamplingUsage(315, 1, 316),
|
644
649
|
tags=['lm-response', 'lm-output', 'transformed']
|
645
650
|
),
|
646
651
|
)
|
@@ -13,9 +13,10 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Symbolic query."""
|
15
15
|
|
16
|
-
from typing import Any, Type, Union
|
16
|
+
from typing import Any, Callable, Type, Union
|
17
17
|
|
18
18
|
import langfun.core as lf
|
19
|
+
from langfun.core.llms import fake
|
19
20
|
from langfun.core.structured import mapping
|
20
21
|
from langfun.core.structured import schema as schema_lib
|
21
22
|
import pyglove as pg
|
@@ -78,7 +79,9 @@ class QueryStructurePython(QueryStructure):
|
|
78
79
|
|
79
80
|
{{ output_title }}:
|
80
81
|
```python
|
81
|
-
Answer(
|
82
|
+
Answer(
|
83
|
+
final_answer=2
|
84
|
+
)
|
82
85
|
```
|
83
86
|
"""
|
84
87
|
protocol = 'python'
|
@@ -107,9 +110,11 @@ def query(
|
|
107
110
|
lm: lf.LanguageModel | None = None,
|
108
111
|
examples: list[mapping.MappingExample] | None = None,
|
109
112
|
cache_seed: int | None = 0,
|
113
|
+
response_postprocess: Callable[[str], str] | None = None,
|
110
114
|
autofix: int = 0,
|
111
115
|
autofix_lm: lf.LanguageModel | None = None,
|
112
116
|
protocol: schema_lib.SchemaProtocol = 'python',
|
117
|
+
include_methods: bool = False,
|
113
118
|
returns_message: bool = False,
|
114
119
|
skip_lm: bool = False,
|
115
120
|
**kwargs,
|
@@ -157,8 +162,11 @@ def query(
|
|
157
162
|
examples: An optional list of fewshot examples for helping parsing. If None,
|
158
163
|
the default one-shot example will be added.
|
159
164
|
cache_seed: Seed for computing cache key. The cache key is determined by a
|
160
|
-
tuple of (lm, prompt, cache seed). If None, cache will be disabled for
|
161
|
-
|
165
|
+
tuple of (lm, prompt, cache seed). If None, cache will be disabled for the
|
166
|
+
query even cache is configured by the LM.
|
167
|
+
response_postprocess: An optional callable object to process the raw LM
|
168
|
+
response before parsing it into the final output object. If None, the raw
|
169
|
+
LM response will not be processed.
|
162
170
|
autofix: Number of attempts to auto fix the generated code. If 0, autofix is
|
163
171
|
disabled. Auto-fix is not supported for 'json' protocol.
|
164
172
|
autofix_lm: The language model to use for autofix. If not specified, the
|
@@ -166,12 +174,17 @@ def query(
|
|
166
174
|
will use `lm`.
|
167
175
|
protocol: The protocol for schema/value representation. Applicable values
|
168
176
|
are 'json' and 'python'. By default `python` will be used.
|
177
|
+
include_methods: If True, include method definitions in the output type
|
178
|
+
during prompting.
|
169
179
|
returns_message: If True, returns `lf.Message` as the output, instead of
|
170
180
|
returning the structured `message.result`.
|
171
181
|
skip_lm: If True, returns the rendered prompt as a UserMessage object.
|
172
182
|
otherwise return the LLM response based on the rendered prompt.
|
173
|
-
**kwargs: Keyword arguments passed to the
|
174
|
-
`lf.structured.
|
183
|
+
**kwargs: Keyword arguments passed to render the prompt or configure the
|
184
|
+
`lf.structured.Mapping` class. Notable kwargs are:
|
185
|
+
- template_str: Change the root template for query.
|
186
|
+
- preamble: Change the preamble for query.
|
187
|
+
- mapping_template: Change the template for each mapping examle.
|
175
188
|
|
176
189
|
Returns:
|
177
190
|
The result based on the schema.
|
@@ -188,16 +201,22 @@ def query(
|
|
188
201
|
output = lf.LangFunc.from_value(prompt, **kwargs)(
|
189
202
|
lm=lm, cache_seed=cache_seed, skip_lm=skip_lm
|
190
203
|
)
|
204
|
+
if response_postprocess:
|
205
|
+
processed_text = response_postprocess(output.text)
|
206
|
+
if processed_text != output.text:
|
207
|
+
output = lf.AIMessage(processed_text, source=output)
|
191
208
|
return output if returns_message else output.text
|
192
209
|
|
193
210
|
# Query with structured output.
|
194
|
-
|
195
|
-
prompt = lf.Template(prompt, **kwargs)
|
196
|
-
elif isinstance(prompt, lf.Template):
|
197
|
-
prompt = prompt.rebind(**kwargs)
|
211
|
+
prompt_kwargs = kwargs.copy()
|
198
212
|
|
199
|
-
|
200
|
-
|
213
|
+
# NOTE(daiyip): when `template_str` is passed in, it's intended to modify the
|
214
|
+
# QueryStructure template string. Therefore, we pop out the argument for
|
215
|
+
# prompt rendering.
|
216
|
+
prompt_kwargs.pop('template_str', None)
|
217
|
+
|
218
|
+
if isinstance(prompt, (str, lf.Message, lf.Template)):
|
219
|
+
prompt = lf.Template.from_value(prompt, **prompt_kwargs).render(lm=lm)
|
201
220
|
else:
|
202
221
|
prompt = schema_lib.mark_missing(prompt)
|
203
222
|
|
@@ -206,6 +225,8 @@ def query(
|
|
206
225
|
schema=schema,
|
207
226
|
default=default,
|
208
227
|
examples=examples,
|
228
|
+
include_methods=include_methods,
|
229
|
+
response_postprocess=response_postprocess,
|
209
230
|
autofix=autofix if protocol == 'python' else 0,
|
210
231
|
**kwargs,
|
211
232
|
)(
|
@@ -215,3 +236,31 @@ def query(
|
|
215
236
|
skip_lm=skip_lm,
|
216
237
|
)
|
217
238
|
return output if returns_message else output.result
|
239
|
+
|
240
|
+
|
241
|
+
def query_prompt(
    prompt: Union[str, pg.Symbolic],
    schema: Union[
        schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
    ] = None,
    **kwargs,
) -> lf.Message:
  """Returns the final prompt sent to LLM for `lf.query`.

  The prompt is produced by running `query` with `skip_lm=True`, which renders
  the prompt without calling the language model, and `returns_message=True`,
  so the rendered message itself is returned.

  Args:
    prompt: The user prompt, either a string or a symbolic object.
    schema: The output schema, forwarded to `query` unchanged.
    **kwargs: Any other keyword arguments accepted by `query`. Caller-supplied
      `returns_message` and `skip_lm` values are discarded because this
      function always fixes both.

  Returns:
    The rendered prompt message.
  """
  # Both flags are set explicitly in the call below; dropping any values the
  # caller passed avoids a "multiple values for keyword argument" error.
  for forced_arg in ('returns_message', 'skip_lm'):
    kwargs.pop(forced_arg, None)
  return query(
      prompt,
      schema,
      skip_lm=True,
      returns_message=True,
      **kwargs,
  )
|
252
|
+
|
253
|
+
|
254
|
+
def query_output(
    response: Union[str, lf.Message],
    schema: Union[
        schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
    ],
    **kwargs,
) -> Any:
  """Returns the final output of `lf.query` from a provided LLM response.

  Rather than calling a real model, this runs `query` against
  `fake.StaticResponse(response)`, so the given response goes through the same
  post-processing and parsing path that `query` applies to LM output.

  Args:
    response: The LLM response to parse, as text or an `lf.Message`.
    schema: The output schema, forwarded to `query`. If None, `query` returns
      the response text instead of a structured result.
    **kwargs: Any other keyword arguments accepted by `query`. Caller-supplied
      `prompt` and `lm` values are discarded because this function supplies
      its own.

  Returns:
    Whatever `query` returns for the canned response: the structured result by
    default, or the full message when `returns_message=True` is passed.
  """
  # A placeholder prompt and a canned-response LM are passed explicitly below,
  # so any caller-provided values for those would collide; drop them.
  # NOTE(review): assumes fake.StaticResponse answers independently of the
  # prompt, which is why a placeholder prompt suffices.
  for replaced_arg in ('prompt', 'lm'):
    kwargs.pop(replaced_arg, None)
  canned_lm = fake.StaticResponse(response)
  return query('Unused prompt', schema, lm=canned_lm, **kwargs)
|
@@ -17,12 +17,10 @@ import inspect
|
|
17
17
|
import unittest
|
18
18
|
|
19
19
|
import langfun.core as lf
|
20
|
-
from langfun.core import coding
|
21
20
|
from langfun.core import modalities
|
22
21
|
from langfun.core.llms import fake
|
23
22
|
from langfun.core.structured import mapping
|
24
23
|
from langfun.core.structured import prompting
|
25
|
-
from langfun.core.structured import schema as schema_lib
|
26
24
|
import pyglove as pg
|
27
25
|
|
28
26
|
|
@@ -77,6 +75,7 @@ class QueryTest(unittest.TestCase):
|
|
77
75
|
result=1,
|
78
76
|
score=1.0,
|
79
77
|
logprobs=None,
|
78
|
+
usage=lf.LMSamplingUsage(323, 1, 324),
|
80
79
|
tags=['lm-response', 'lm-output', 'transformed'],
|
81
80
|
),
|
82
81
|
)
|
@@ -116,12 +115,59 @@ class QueryTest(unittest.TestCase):
|
|
116
115
|
y=2,
|
117
116
|
lm=lm.clone(),
|
118
117
|
expected_snippet=(
|
119
|
-
'Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT'
|
120
|
-
'
|
121
|
-
'
|
122
|
-
'
|
123
|
-
'
|
124
|
-
'
|
118
|
+
'Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT '
|
119
|
+
'according to OUTPUT_TYPE.\n\n'
|
120
|
+
'INPUT_OBJECT:\n 1 + 1 =\n\n'
|
121
|
+
'OUTPUT_TYPE:\n'
|
122
|
+
' Answer\n\n'
|
123
|
+
' ```python\n'
|
124
|
+
' class Answer:\n'
|
125
|
+
' final_answer: int\n'
|
126
|
+
' ```\n\n'
|
127
|
+
'OUTPUT_OBJECT:\n'
|
128
|
+
' ```python\n'
|
129
|
+
' Answer(\n'
|
130
|
+
' final_answer=2\n'
|
131
|
+
' )\n'
|
132
|
+
' ```\n\n'
|
133
|
+
'INPUT_OBJECT:\n'
|
134
|
+
' What is 1 + 2?\n\n'
|
135
|
+
'OUTPUT_TYPE:\n'
|
136
|
+
' int\n\n'
|
137
|
+
'OUTPUT_OBJECT:'
|
138
|
+
),
|
139
|
+
)
|
140
|
+
|
141
|
+
def test_str_to_structure_render_custom_template(self):
|
142
|
+
lm = fake.StaticResponse('1')
|
143
|
+
self.assert_render(
|
144
|
+
'What is {{x}} + {{y}}?',
|
145
|
+
int,
|
146
|
+
x=1,
|
147
|
+
y=2,
|
148
|
+
lm=lm.clone(),
|
149
|
+
template_str='!!{{ DEFAULT }}!!',
|
150
|
+
expected_snippet=(
|
151
|
+
'!!Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT '
|
152
|
+
'according to OUTPUT_TYPE.\n\n'
|
153
|
+
'INPUT_OBJECT:\n 1 + 1 =\n\n'
|
154
|
+
'OUTPUT_TYPE:\n'
|
155
|
+
' Answer\n\n'
|
156
|
+
' ```python\n'
|
157
|
+
' class Answer:\n'
|
158
|
+
' final_answer: int\n'
|
159
|
+
' ```\n\n'
|
160
|
+
'OUTPUT_OBJECT:\n'
|
161
|
+
' ```python\n'
|
162
|
+
' Answer(\n'
|
163
|
+
' final_answer=2\n'
|
164
|
+
' )\n'
|
165
|
+
' ```\n\n'
|
166
|
+
'INPUT_OBJECT:\n'
|
167
|
+
' What is 1 + 2?\n\n'
|
168
|
+
'OUTPUT_TYPE:\n'
|
169
|
+
' int\n\n'
|
170
|
+
'OUTPUT_OBJECT:!!'
|
125
171
|
),
|
126
172
|
)
|
127
173
|
|
@@ -239,6 +285,49 @@ class QueryTest(unittest.TestCase):
|
|
239
285
|
with self.assertRaisesRegex(ValueError, 'Unknown protocol'):
|
240
286
|
prompting.query('what is 1 + 1', int, protocol='text')
|
241
287
|
|
288
|
+
def test_query_prompt(self):
|
289
|
+
self.assertEqual(
|
290
|
+
prompting.query_prompt('what is this?', int),
|
291
|
+
inspect.cleandoc("""
|
292
|
+
Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT according to OUTPUT_TYPE.
|
293
|
+
|
294
|
+
INPUT_OBJECT:
|
295
|
+
1 + 1 =
|
296
|
+
|
297
|
+
OUTPUT_TYPE:
|
298
|
+
Answer
|
299
|
+
|
300
|
+
```python
|
301
|
+
class Answer:
|
302
|
+
final_answer: int
|
303
|
+
```
|
304
|
+
|
305
|
+
OUTPUT_OBJECT:
|
306
|
+
```python
|
307
|
+
Answer(
|
308
|
+
final_answer=2
|
309
|
+
)
|
310
|
+
```
|
311
|
+
|
312
|
+
INPUT_OBJECT:
|
313
|
+
what is this?
|
314
|
+
|
315
|
+
OUTPUT_TYPE:
|
316
|
+
int
|
317
|
+
|
318
|
+
OUTPUT_OBJECT:
|
319
|
+
"""),
|
320
|
+
)
|
321
|
+
|
322
|
+
def test_query_output(self):
|
323
|
+
self.assertEqual(
|
324
|
+
prompting.query_output(
|
325
|
+
lf.AIMessage('1'),
|
326
|
+
int,
|
327
|
+
),
|
328
|
+
1,
|
329
|
+
)
|
330
|
+
|
242
331
|
|
243
332
|
class QueryStructurePythonTest(unittest.TestCase):
|
244
333
|
|
@@ -264,7 +353,9 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
264
353
|
|
265
354
|
OUTPUT_OBJECT:
|
266
355
|
```python
|
267
|
-
Answer(
|
356
|
+
Answer(
|
357
|
+
final_answer=2
|
358
|
+
)
|
268
359
|
```
|
269
360
|
|
270
361
|
INPUT_OBJECT:
|
@@ -308,7 +399,9 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
308
399
|
|
309
400
|
OUTPUT_OBJECT:
|
310
401
|
```python
|
311
|
-
Answer(
|
402
|
+
Answer(
|
403
|
+
final_answer=2
|
404
|
+
)
|
312
405
|
```
|
313
406
|
|
314
407
|
INPUT_OBJECT:
|
@@ -420,7 +513,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
420
513
|
override_attrs=True,
|
421
514
|
):
|
422
515
|
with self.assertRaisesRegex(
|
423
|
-
|
516
|
+
mapping.MappingError,
|
424
517
|
'name .* is not defined',
|
425
518
|
):
|
426
519
|
prompting.query('Compute 1 + 2', int)
|
@@ -436,6 +529,23 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
436
529
|
])
|
437
530
|
self.assertEqual(prompting.query('what is 1 + 0', int, lm=lm, autofix=3), 1)
|
438
531
|
|
532
|
+
def test_response_postprocess(self):
|
533
|
+
with lf.context(
|
534
|
+
lm=fake.StaticResponse('<!-- some comment-->\n3'),
|
535
|
+
override_attrs=True,
|
536
|
+
):
|
537
|
+
self.assertEqual(
|
538
|
+
prompting.query(
|
539
|
+
'Compute 1 + 2', response_postprocess=lambda x: x.split('\n')[1]),
|
540
|
+
'3'
|
541
|
+
)
|
542
|
+
self.assertEqual(
|
543
|
+
prompting.query(
|
544
|
+
'Compute 1 + 2', int,
|
545
|
+
response_postprocess=lambda x: x.split('\n')[1]),
|
546
|
+
3
|
547
|
+
)
|
548
|
+
|
439
549
|
|
440
550
|
class QueryStructureJsonTest(unittest.TestCase):
|
441
551
|
|
@@ -641,7 +751,7 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
641
751
|
override_attrs=True,
|
642
752
|
):
|
643
753
|
with self.assertRaisesRegex(
|
644
|
-
|
754
|
+
mapping.MappingError,
|
645
755
|
'No JSON dict in the output',
|
646
756
|
):
|
647
757
|
prompting.query('Compute 1 + 2', int, protocol='json')
|
@@ -301,6 +301,7 @@ class SchemaPythonRepr(SchemaRepr):
|
|
301
301
|
schema: Schema,
|
302
302
|
*,
|
303
303
|
include_result_definition: bool = True,
|
304
|
+
include_methods: bool = False,
|
304
305
|
markdown: bool = True,
|
305
306
|
**kwargs,
|
306
307
|
) -> str:
|
@@ -308,7 +309,7 @@ class SchemaPythonRepr(SchemaRepr):
|
|
308
309
|
if include_result_definition:
|
309
310
|
ret += self.result_definition(schema)
|
310
311
|
class_definition_str = self.class_definitions(
|
311
|
-
schema, markdown=markdown, **kwargs
|
312
|
+
schema, markdown=markdown, include_methods=include_methods, **kwargs
|
312
313
|
)
|
313
314
|
if class_definition_str:
|
314
315
|
ret += f'\n\n{class_definition_str}'
|
@@ -331,6 +332,7 @@ def class_definitions(
|
|
331
332
|
classes: Sequence[Type[Any]],
|
332
333
|
*,
|
333
334
|
include_pg_object_as_base: bool = False,
|
335
|
+
include_methods: bool = False,
|
334
336
|
strict: bool = False,
|
335
337
|
markdown: bool = False,
|
336
338
|
) -> str | None:
|
@@ -346,6 +348,7 @@ def class_definitions(
|
|
346
348
|
cls,
|
347
349
|
strict=strict,
|
348
350
|
include_pg_object_as_base=include_pg_object_as_base,
|
351
|
+
include_methods=include_methods,
|
349
352
|
)
|
350
353
|
)
|
351
354
|
ret = def_str.getvalue()
|
@@ -355,7 +358,10 @@ def class_definitions(
|
|
355
358
|
|
356
359
|
|
357
360
|
def class_definition(
|
358
|
-
cls,
|
361
|
+
cls,
|
362
|
+
strict: bool = False,
|
363
|
+
include_pg_object_as_base: bool = False,
|
364
|
+
include_methods: bool = False,
|
359
365
|
) -> str:
|
360
366
|
"""Returns the Python class definition."""
|
361
367
|
out = io.StringIO()
|
@@ -383,13 +389,16 @@ def class_definition(
|
|
383
389
|
out.write('\n')
|
384
390
|
out.write(' """\n')
|
385
391
|
|
392
|
+
empty_class = True
|
386
393
|
if schema.fields:
|
387
394
|
for key, field in schema.items():
|
388
395
|
if not isinstance(key, pg.typing.ConstStrKey):
|
389
|
-
|
396
|
+
pg.logging.warning(
|
390
397
|
'Variable-length keyword arguments is not supported in '
|
391
|
-
f'structured parsing or query. Encountered: {
|
398
|
+
f'structured parsing or query. Encountered: {cls}, Schema: {schema}'
|
392
399
|
)
|
400
|
+
continue
|
401
|
+
|
393
402
|
# Write field doc string as comments before the field definition.
|
394
403
|
if field.description:
|
395
404
|
for line in field.description.split('\n'):
|
@@ -399,11 +408,33 @@ def class_definition(
|
|
399
408
|
out.write('\n')
|
400
409
|
out.write(f' {field.key}: {annotation(field.value, strict=strict)}')
|
401
410
|
out.write('\n')
|
402
|
-
|
411
|
+
empty_class = False
|
412
|
+
|
413
|
+
if include_methods:
|
414
|
+
for method in _iter_newly_defined_methods(cls):
|
415
|
+
out.write('\n')
|
416
|
+
out.write(
|
417
|
+
textwrap.indent(
|
418
|
+
inspect.cleandoc('\n' + inspect.getsource(method)), ' ' * 2)
|
419
|
+
)
|
420
|
+
out.write('\n')
|
421
|
+
empty_class = False
|
422
|
+
|
423
|
+
if empty_class:
|
403
424
|
out.write(' pass\n')
|
404
425
|
return out.getvalue()
|
405
426
|
|
406
427
|
|
428
|
+
def _iter_newly_defined_methods(cls):
|
429
|
+
names = set(dir(cls))
|
430
|
+
for base in cls.__bases__:
|
431
|
+
names -= set(dir(base))
|
432
|
+
for name in names:
|
433
|
+
attr = getattr(cls, name)
|
434
|
+
if callable(attr):
|
435
|
+
yield attr
|
436
|
+
|
437
|
+
|
407
438
|
def annotation(
|
408
439
|
vs: pg.typing.ValueSpec,
|
409
440
|
annotate_optional: bool = True,
|
@@ -491,7 +522,8 @@ def annotation(
|
|
491
522
|
class SchemaJsonRepr(SchemaRepr):
|
492
523
|
"""JSON-representation for a schema."""
|
493
524
|
|
494
|
-
def repr(self, schema: Schema) -> str:
|
525
|
+
def repr(self, schema: Schema, **kwargs) -> str:
|
526
|
+
del kwargs
|
495
527
|
out = io.StringIO()
|
496
528
|
def _visit(node: Any) -> None:
|
497
529
|
if isinstance(node, str):
|
@@ -14,8 +14,8 @@
|
|
14
14
|
import inspect
|
15
15
|
import unittest
|
16
16
|
|
17
|
-
import langfun.core.coding as lf_coding
|
18
17
|
from langfun.core.llms import fake
|
18
|
+
from langfun.core.structured import mapping
|
19
19
|
from langfun.core.structured import schema_generation
|
20
20
|
|
21
21
|
|
@@ -92,7 +92,7 @@ class GenerateClassTest(unittest.TestCase):
|
|
92
92
|
)
|
93
93
|
self.assertIs(cls.__name__, 'B')
|
94
94
|
|
95
|
-
with self.assertRaises(
|
95
|
+
with self.assertRaises(mapping.MappingError):
|
96
96
|
schema_generation.generate_class(
|
97
97
|
'Foo',
|
98
98
|
'Generate a Foo class with a field pointing to another class A',
|
@@ -192,9 +192,9 @@ class SchemaTest(unittest.TestCase):
|
|
192
192
|
self.assertEqual(schema.parse('{"result": 1}'), 1)
|
193
193
|
schema = schema_lib.Schema(dict[str, int])
|
194
194
|
self.assertEqual(
|
195
|
-
schema.parse(
|
196
|
-
|
197
|
-
|
195
|
+
schema.parse('{"result": {"x": 1}}}'),
|
196
|
+
dict(x=1)
|
197
|
+
)
|
198
198
|
with self.assertRaisesRegex(
|
199
199
|
schema_lib.SchemaError, 'Expect .* but encountered .*'):
|
200
200
|
schema.parse('{"result": "def"}')
|
@@ -459,9 +459,24 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
459
459
|
x: str
|
460
460
|
__kwargs__: typing.Any
|
461
461
|
|
462
|
-
|
463
|
-
|
464
|
-
|
462
|
+
self.assertEqual(schema_lib.class_definition(C), 'class C:\n x: str\n')
|
463
|
+
|
464
|
+
class D(pg.Object):
|
465
|
+
x: str
|
466
|
+
def __call__(self, y: int) -> int:
|
467
|
+
return len(self.x) + y
|
468
|
+
|
469
|
+
self.assertEqual(
|
470
|
+
schema_lib.class_definition(D, include_methods=True),
|
471
|
+
inspect.cleandoc(
|
472
|
+
"""
|
473
|
+
class D:
|
474
|
+
x: str
|
475
|
+
|
476
|
+
def __call__(self, y: int) -> int:
|
477
|
+
return len(self.x) + y
|
478
|
+
""") + '\n'
|
479
|
+
)
|
465
480
|
|
466
481
|
def test_repr(self):
|
467
482
|
class Foo(pg.Object):
|
@@ -479,13 +494,21 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
479
494
|
class A(pg.Object):
|
480
495
|
foo: Foo
|
481
496
|
|
497
|
+
def foo_value(self) -> int:
|
498
|
+
return self.foo.x
|
499
|
+
|
482
500
|
class B(A):
|
483
501
|
bar: Bar
|
484
502
|
foo2: Foo
|
485
503
|
|
504
|
+
def bar_value(self) -> str:
|
505
|
+
return self.bar.y
|
506
|
+
|
486
507
|
schema = schema_lib.Schema([B])
|
487
508
|
self.assertEqual(
|
488
|
-
schema_lib.SchemaPythonRepr().class_definitions(
|
509
|
+
schema_lib.SchemaPythonRepr().class_definitions(
|
510
|
+
schema, include_methods=True
|
511
|
+
),
|
489
512
|
inspect.cleandoc('''
|
490
513
|
class Foo:
|
491
514
|
x: int
|
@@ -493,6 +516,9 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
493
516
|
class A:
|
494
517
|
foo: Foo
|
495
518
|
|
519
|
+
def foo_value(self) -> int:
|
520
|
+
return self.foo.x
|
521
|
+
|
496
522
|
class Bar:
|
497
523
|
"""Class Bar."""
|
498
524
|
y: str
|
@@ -505,6 +531,9 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
505
531
|
foo: Foo
|
506
532
|
bar: Bar
|
507
533
|
foo2: Foo
|
534
|
+
|
535
|
+
def bar_value(self) -> str:
|
536
|
+
return self.bar.y
|
508
537
|
''') + '\n',
|
509
538
|
)
|
510
539
|
|
@@ -32,8 +32,9 @@ def score(
|
|
32
32
|
lm: lf.LanguageModel | None = None,
|
33
33
|
examples: list[mapping.MappingExample] | None = None,
|
34
34
|
protocol: schema_lib.SchemaProtocol = 'python',
|
35
|
+
return_scoring_results: bool = False,
|
35
36
|
**kwargs,
|
36
|
-
) -> list[float]:
|
37
|
+
) -> list[float] | list[lf.LMScoringResult]:
|
37
38
|
"""Scores the outputs based on the prompt."""
|
38
39
|
if not completions:
|
39
40
|
raise ValueError('`completions` must not be empty.')
|
@@ -72,4 +73,6 @@ def score(
|
|
72
73
|
for c in completions
|
73
74
|
],
|
74
75
|
)
|
76
|
+
if return_scoring_results:
|
77
|
+
return results
|
75
78
|
return [r.score for r in results]
|
@@ -35,6 +35,12 @@ class ScoringTest(unittest.TestCase):
|
|
35
35
|
def test_score(self):
|
36
36
|
self.assertEqual(scoring.score('hi', [1, 2], lm=fake.Echo()), [0.0, -1.0])
|
37
37
|
|
38
|
+
def test_score_returning_scoring_results(self):
|
39
|
+
self.assertEqual(scoring.score(
|
40
|
+
'hi', [1, 2], lm=fake.Echo(), return_scoring_results=True),
|
41
|
+
[lf.LMScoringResult(score=0.0, gradients=None),
|
42
|
+
lf.LMScoringResult(score=-1.0, gradients=None)])
|
43
|
+
|
38
44
|
def test_scope_with_lm_from_the_context(self):
|
39
45
|
with lf.context(lm=fake.Echo()):
|
40
46
|
self.assertEqual(scoring.score('hi', [1, 2]), [0.0, -1.0])
|