langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240511__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langfun might be problematic; see the package's registry page for more details.

Files changed (59):
  1. langfun/__init__.py +7 -0
  2. langfun/core/__init__.py +1 -0
  3. langfun/core/coding/python/correction.py +0 -7
  4. langfun/core/component.py +6 -0
  5. langfun/core/component_test.py +1 -0
  6. langfun/core/eval/__init__.py +15 -0
  7. langfun/core/eval/base.py +665 -95
  8. langfun/core/eval/base_test.py +224 -53
  9. langfun/core/eval/matching.py +48 -30
  10. langfun/core/eval/matching_test.py +25 -3
  11. langfun/core/eval/patching.py +130 -0
  12. langfun/core/eval/patching_test.py +170 -0
  13. langfun/core/eval/scoring.py +19 -10
  14. langfun/core/eval/scoring_test.py +21 -3
  15. langfun/core/langfunc.py +1 -22
  16. langfun/core/langfunc_test.py +10 -4
  17. langfun/core/language_model.py +130 -24
  18. langfun/core/language_model_test.py +249 -26
  19. langfun/core/llms/__init__.py +27 -2
  20. langfun/core/llms/anthropic.py +263 -0
  21. langfun/core/llms/anthropic_test.py +167 -0
  22. langfun/core/llms/cache/in_memory_test.py +37 -28
  23. langfun/core/llms/fake.py +34 -25
  24. langfun/core/llms/fake_test.py +122 -11
  25. langfun/core/llms/google_genai.py +8 -0
  26. langfun/core/llms/google_genai_test.py +8 -3
  27. langfun/core/llms/groq.py +260 -0
  28. langfun/core/llms/groq_test.py +170 -0
  29. langfun/core/llms/llama_cpp.py +3 -1
  30. langfun/core/llms/openai.py +100 -81
  31. langfun/core/llms/openai_test.py +287 -60
  32. langfun/core/llms/vertexai.py +291 -0
  33. langfun/core/llms/vertexai_test.py +233 -0
  34. langfun/core/modalities/image.py +1 -3
  35. langfun/core/modalities/mime.py +6 -0
  36. langfun/core/modalities/video.py +6 -5
  37. langfun/core/structured/__init__.py +5 -0
  38. langfun/core/structured/completion_test.py +2 -2
  39. langfun/core/structured/function_generation.py +245 -0
  40. langfun/core/structured/function_generation_test.py +329 -0
  41. langfun/core/structured/mapping.py +61 -3
  42. langfun/core/structured/mapping_test.py +17 -0
  43. langfun/core/structured/parsing_test.py +18 -13
  44. langfun/core/structured/prompting.py +61 -12
  45. langfun/core/structured/prompting_test.py +122 -12
  46. langfun/core/structured/schema.py +38 -6
  47. langfun/core/structured/schema_generation_test.py +2 -2
  48. langfun/core/structured/schema_test.py +36 -7
  49. langfun/core/structured/scoring.py +4 -1
  50. langfun/core/structured/scoring_test.py +6 -0
  51. langfun/core/template.py +147 -11
  52. langfun/core/template_test.py +75 -0
  53. langfun/core/templates/selfplay_test.py +6 -2
  54. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/METADATA +3 -2
  55. langfun-0.0.2.dev20240511.dist-info/RECORD +112 -0
  56. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  57. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/LICENSE +0 -0
  58. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/WHEEL +0 -0
  59. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/top_level.txt +0 -0
@@ -17,11 +17,9 @@ import inspect
17
17
  import unittest
18
18
 
19
19
  import langfun.core as lf
20
- from langfun.core import coding
21
20
  from langfun.core.llms import fake
22
21
  from langfun.core.structured import mapping
23
22
  from langfun.core.structured import parsing
24
- from langfun.core.structured import schema as schema_lib
25
23
  import pyglove as pg
26
24
 
27
25
 
@@ -255,7 +253,7 @@ class ParseStructurePythonTest(unittest.TestCase):
255
253
  override_attrs=True,
256
254
  ):
257
255
  with self.assertRaisesRegex(
258
- coding.CodeError,
256
+ mapping.MappingError,
259
257
  'name .* is not defined',
260
258
  ):
261
259
  parsing.parse('three', int)
@@ -280,13 +278,15 @@ class ParseStructurePythonTest(unittest.TestCase):
280
278
  ),
281
279
  1,
282
280
  )
281
+ r = parsing.parse(
282
+ 'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm,
283
+ returns_message=True
284
+ )
283
285
  self.assertEqual(
284
- parsing.parse(
285
- 'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm,
286
- returns_message=True
287
- ),
286
+ r,
288
287
  lf.AIMessage(
289
288
  '1', score=1.0, result=1, logprobs=None,
289
+ usage=lf.LMSamplingUsage(652, 1, 653),
290
290
  tags=['lm-response', 'lm-output', 'transformed']
291
291
  ),
292
292
  )
@@ -544,7 +544,7 @@ class ParseStructureJsonTest(unittest.TestCase):
544
544
  override_attrs=True,
545
545
  ):
546
546
  with self.assertRaisesRegex(
547
- schema_lib.JsonError,
547
+ mapping.MappingError,
548
548
  'No JSON dict in the output',
549
549
  ):
550
550
  parsing.parse('three', int, protocol='json')
@@ -634,13 +634,18 @@ class CallTest(unittest.TestCase):
634
634
  )
635
635
 
636
636
  def test_call_with_returning_message(self):
637
+ r = parsing.call(
638
+ 'Compute 1 + 2', int, lm=fake.StaticSequence(['three', '3']),
639
+ returns_message=True
640
+ )
637
641
  self.assertEqual(
638
- parsing.call(
639
- 'Compute 1 + 2', int, lm=fake.StaticSequence(['three', '3']),
640
- returns_message=True
641
- ),
642
+ r,
642
643
  lf.AIMessage(
643
- '3', result=3, score=1.0, logprobs=None,
644
+ '3',
645
+ result=3,
646
+ score=1.0,
647
+ logprobs=None,
648
+ usage=lf.LMSamplingUsage(315, 1, 316),
644
649
  tags=['lm-response', 'lm-output', 'transformed']
645
650
  ),
646
651
  )
@@ -13,9 +13,10 @@
13
13
  # limitations under the License.
14
14
  """Symbolic query."""
15
15
 
16
- from typing import Any, Type, Union
16
+ from typing import Any, Callable, Type, Union
17
17
 
18
18
  import langfun.core as lf
19
+ from langfun.core.llms import fake
19
20
  from langfun.core.structured import mapping
20
21
  from langfun.core.structured import schema as schema_lib
21
22
  import pyglove as pg
@@ -78,7 +79,9 @@ class QueryStructurePython(QueryStructure):
78
79
 
79
80
  {{ output_title }}:
80
81
  ```python
81
- Answer(final_answer=2)
82
+ Answer(
83
+ final_answer=2
84
+ )
82
85
  ```
83
86
  """
84
87
  protocol = 'python'
@@ -107,9 +110,11 @@ def query(
107
110
  lm: lf.LanguageModel | None = None,
108
111
  examples: list[mapping.MappingExample] | None = None,
109
112
  cache_seed: int | None = 0,
113
+ response_postprocess: Callable[[str], str] | None = None,
110
114
  autofix: int = 0,
111
115
  autofix_lm: lf.LanguageModel | None = None,
112
116
  protocol: schema_lib.SchemaProtocol = 'python',
117
+ include_methods: bool = False,
113
118
  returns_message: bool = False,
114
119
  skip_lm: bool = False,
115
120
  **kwargs,
@@ -157,8 +162,11 @@ def query(
157
162
  examples: An optional list of fewshot examples for helping parsing. If None,
158
163
  the default one-shot example will be added.
159
164
  cache_seed: Seed for computing cache key. The cache key is determined by a
160
- tuple of (lm, prompt, cache seed). If None, cache will be disabled for
161
- the query even cache is configured by the LM.
165
+ tuple of (lm, prompt, cache seed). If None, cache will be disabled for the
166
+ query even cache is configured by the LM.
167
+ response_postprocess: An optional callable object to process the raw LM
168
+ response before parsing it into the final output object. If None, the raw
169
+ LM response will not be processed.
162
170
  autofix: Number of attempts to auto fix the generated code. If 0, autofix is
163
171
  disabled. Auto-fix is not supported for 'json' protocol.
164
172
  autofix_lm: The language model to use for autofix. If not specified, the
@@ -166,12 +174,17 @@ def query(
166
174
  will use `lm`.
167
175
  protocol: The protocol for schema/value representation. Applicable values
168
176
  are 'json' and 'python'. By default `python` will be used.
177
+ include_methods: If True, include method definitions in the output type
178
+ during prompting.
169
179
  returns_message: If True, returns `lf.Message` as the output, instead of
170
180
  returning the structured `message.result`.
171
181
  skip_lm: If True, returns the rendered prompt as a UserMessage object.
172
182
  otherwise return the LLM response based on the rendered prompt.
173
- **kwargs: Keyword arguments passed to the
174
- `lf.structured.NaturalLanguageToStructureed` transform.
183
+ **kwargs: Keyword arguments passed to render the prompt or configure the
184
+ `lf.structured.Mapping` class. Notable kwargs are:
185
+ - template_str: Change the root template for query.
186
+ - preamble: Change the preamble for query.
187
+ - mapping_template: Change the template for each mapping examle.
175
188
 
176
189
  Returns:
177
190
  The result based on the schema.
@@ -188,16 +201,22 @@ def query(
188
201
  output = lf.LangFunc.from_value(prompt, **kwargs)(
189
202
  lm=lm, cache_seed=cache_seed, skip_lm=skip_lm
190
203
  )
204
+ if response_postprocess:
205
+ processed_text = response_postprocess(output.text)
206
+ if processed_text != output.text:
207
+ output = lf.AIMessage(processed_text, source=output)
191
208
  return output if returns_message else output.text
192
209
 
193
210
  # Query with structured output.
194
- if isinstance(prompt, str):
195
- prompt = lf.Template(prompt, **kwargs)
196
- elif isinstance(prompt, lf.Template):
197
- prompt = prompt.rebind(**kwargs)
211
+ prompt_kwargs = kwargs.copy()
198
212
 
199
- if isinstance(prompt, lf.Template):
200
- prompt = prompt.render(lm=lm)
213
+ # NOTE(daiyip): when `template_str` is passed in, it's intended to modify the
214
+ # QueryStructure template string. Therefore, we pop out the argument for
215
+ # prompt rendering.
216
+ prompt_kwargs.pop('template_str', None)
217
+
218
+ if isinstance(prompt, (str, lf.Message, lf.Template)):
219
+ prompt = lf.Template.from_value(prompt, **prompt_kwargs).render(lm=lm)
201
220
  else:
202
221
  prompt = schema_lib.mark_missing(prompt)
203
222
 
@@ -206,6 +225,8 @@ def query(
206
225
  schema=schema,
207
226
  default=default,
208
227
  examples=examples,
228
+ include_methods=include_methods,
229
+ response_postprocess=response_postprocess,
209
230
  autofix=autofix if protocol == 'python' else 0,
210
231
  **kwargs,
211
232
  )(
@@ -215,3 +236,31 @@ def query(
215
236
  skip_lm=skip_lm,
216
237
  )
217
238
  return output if returns_message else output.result
239
+
240
+
241
+ def query_prompt(
242
+ prompt: Union[str, pg.Symbolic],
243
+ schema: Union[
244
+ schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
245
+ ] = None,
246
+ **kwargs,
247
+ ) -> lf.Message:
248
+ """Returns the final prompt sent to LLM for `lf.query`."""
249
+ kwargs.pop('returns_message', None)
250
+ kwargs.pop('skip_lm', None)
251
+ return query(prompt, schema, skip_lm=True, returns_message=True, **kwargs)
252
+
253
+
254
+ def query_output(
255
+ response: Union[str, lf.Message],
256
+ schema: Union[
257
+ schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
258
+ ],
259
+ **kwargs,
260
+ ) -> Any:
261
+ """Returns the final output of `lf.query` from a provided LLM response."""
262
+ kwargs.pop('prompt', None)
263
+ kwargs.pop('lm', None)
264
+ return query(
265
+ 'Unused prompt', schema, lm=fake.StaticResponse(response), **kwargs
266
+ )
@@ -17,12 +17,10 @@ import inspect
17
17
  import unittest
18
18
 
19
19
  import langfun.core as lf
20
- from langfun.core import coding
21
20
  from langfun.core import modalities
22
21
  from langfun.core.llms import fake
23
22
  from langfun.core.structured import mapping
24
23
  from langfun.core.structured import prompting
25
- from langfun.core.structured import schema as schema_lib
26
24
  import pyglove as pg
27
25
 
28
26
 
@@ -77,6 +75,7 @@ class QueryTest(unittest.TestCase):
77
75
  result=1,
78
76
  score=1.0,
79
77
  logprobs=None,
78
+ usage=lf.LMSamplingUsage(323, 1, 324),
80
79
  tags=['lm-response', 'lm-output', 'transformed'],
81
80
  ),
82
81
  )
@@ -116,12 +115,59 @@ class QueryTest(unittest.TestCase):
116
115
  y=2,
117
116
  lm=lm.clone(),
118
117
  expected_snippet=(
119
- 'Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT'
120
- ' according to OUTPUT_TYPE.\n\nINPUT_OBJECT:\n 1 + 1'
121
- ' =\n\nOUTPUT_TYPE:\n Answer\n\n ```python\n class Answer:\n '
122
- ' final_answer: int\n ```\n\nOUTPUT_OBJECT:\n ```python\n '
123
- ' Answer(final_answer=2)\n ```\n\nINPUT_OBJECT:\n What is 1 +'
124
- ' 2?\n\nOUTPUT_TYPE:\n int\n\nOUTPUT_OBJECT:'
118
+ 'Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT '
119
+ 'according to OUTPUT_TYPE.\n\n'
120
+ 'INPUT_OBJECT:\n 1 + 1 =\n\n'
121
+ 'OUTPUT_TYPE:\n'
122
+ ' Answer\n\n'
123
+ ' ```python\n'
124
+ ' class Answer:\n'
125
+ ' final_answer: int\n'
126
+ ' ```\n\n'
127
+ 'OUTPUT_OBJECT:\n'
128
+ ' ```python\n'
129
+ ' Answer(\n'
130
+ ' final_answer=2\n'
131
+ ' )\n'
132
+ ' ```\n\n'
133
+ 'INPUT_OBJECT:\n'
134
+ ' What is 1 + 2?\n\n'
135
+ 'OUTPUT_TYPE:\n'
136
+ ' int\n\n'
137
+ 'OUTPUT_OBJECT:'
138
+ ),
139
+ )
140
+
141
+ def test_str_to_structure_render_custom_template(self):
142
+ lm = fake.StaticResponse('1')
143
+ self.assert_render(
144
+ 'What is {{x}} + {{y}}?',
145
+ int,
146
+ x=1,
147
+ y=2,
148
+ lm=lm.clone(),
149
+ template_str='!!{{ DEFAULT }}!!',
150
+ expected_snippet=(
151
+ '!!Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT '
152
+ 'according to OUTPUT_TYPE.\n\n'
153
+ 'INPUT_OBJECT:\n 1 + 1 =\n\n'
154
+ 'OUTPUT_TYPE:\n'
155
+ ' Answer\n\n'
156
+ ' ```python\n'
157
+ ' class Answer:\n'
158
+ ' final_answer: int\n'
159
+ ' ```\n\n'
160
+ 'OUTPUT_OBJECT:\n'
161
+ ' ```python\n'
162
+ ' Answer(\n'
163
+ ' final_answer=2\n'
164
+ ' )\n'
165
+ ' ```\n\n'
166
+ 'INPUT_OBJECT:\n'
167
+ ' What is 1 + 2?\n\n'
168
+ 'OUTPUT_TYPE:\n'
169
+ ' int\n\n'
170
+ 'OUTPUT_OBJECT:!!'
125
171
  ),
126
172
  )
127
173
 
@@ -239,6 +285,49 @@ class QueryTest(unittest.TestCase):
239
285
  with self.assertRaisesRegex(ValueError, 'Unknown protocol'):
240
286
  prompting.query('what is 1 + 1', int, protocol='text')
241
287
 
288
+ def test_query_prompt(self):
289
+ self.assertEqual(
290
+ prompting.query_prompt('what is this?', int),
291
+ inspect.cleandoc("""
292
+ Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT according to OUTPUT_TYPE.
293
+
294
+ INPUT_OBJECT:
295
+ 1 + 1 =
296
+
297
+ OUTPUT_TYPE:
298
+ Answer
299
+
300
+ ```python
301
+ class Answer:
302
+ final_answer: int
303
+ ```
304
+
305
+ OUTPUT_OBJECT:
306
+ ```python
307
+ Answer(
308
+ final_answer=2
309
+ )
310
+ ```
311
+
312
+ INPUT_OBJECT:
313
+ what is this?
314
+
315
+ OUTPUT_TYPE:
316
+ int
317
+
318
+ OUTPUT_OBJECT:
319
+ """),
320
+ )
321
+
322
+ def test_query_output(self):
323
+ self.assertEqual(
324
+ prompting.query_output(
325
+ lf.AIMessage('1'),
326
+ int,
327
+ ),
328
+ 1,
329
+ )
330
+
242
331
 
243
332
  class QueryStructurePythonTest(unittest.TestCase):
244
333
 
@@ -264,7 +353,9 @@ class QueryStructurePythonTest(unittest.TestCase):
264
353
 
265
354
  OUTPUT_OBJECT:
266
355
  ```python
267
- Answer(final_answer=2)
356
+ Answer(
357
+ final_answer=2
358
+ )
268
359
  ```
269
360
 
270
361
  INPUT_OBJECT:
@@ -308,7 +399,9 @@ class QueryStructurePythonTest(unittest.TestCase):
308
399
 
309
400
  OUTPUT_OBJECT:
310
401
  ```python
311
- Answer(final_answer=2)
402
+ Answer(
403
+ final_answer=2
404
+ )
312
405
  ```
313
406
 
314
407
  INPUT_OBJECT:
@@ -420,7 +513,7 @@ class QueryStructurePythonTest(unittest.TestCase):
420
513
  override_attrs=True,
421
514
  ):
422
515
  with self.assertRaisesRegex(
423
- coding.CodeError,
516
+ mapping.MappingError,
424
517
  'name .* is not defined',
425
518
  ):
426
519
  prompting.query('Compute 1 + 2', int)
@@ -436,6 +529,23 @@ class QueryStructurePythonTest(unittest.TestCase):
436
529
  ])
437
530
  self.assertEqual(prompting.query('what is 1 + 0', int, lm=lm, autofix=3), 1)
438
531
 
532
+ def test_response_postprocess(self):
533
+ with lf.context(
534
+ lm=fake.StaticResponse('<!-- some comment-->\n3'),
535
+ override_attrs=True,
536
+ ):
537
+ self.assertEqual(
538
+ prompting.query(
539
+ 'Compute 1 + 2', response_postprocess=lambda x: x.split('\n')[1]),
540
+ '3'
541
+ )
542
+ self.assertEqual(
543
+ prompting.query(
544
+ 'Compute 1 + 2', int,
545
+ response_postprocess=lambda x: x.split('\n')[1]),
546
+ 3
547
+ )
548
+
439
549
 
440
550
  class QueryStructureJsonTest(unittest.TestCase):
441
551
 
@@ -641,7 +751,7 @@ class QueryStructureJsonTest(unittest.TestCase):
641
751
  override_attrs=True,
642
752
  ):
643
753
  with self.assertRaisesRegex(
644
- schema_lib.JsonError,
754
+ mapping.MappingError,
645
755
  'No JSON dict in the output',
646
756
  ):
647
757
  prompting.query('Compute 1 + 2', int, protocol='json')
@@ -301,6 +301,7 @@ class SchemaPythonRepr(SchemaRepr):
301
301
  schema: Schema,
302
302
  *,
303
303
  include_result_definition: bool = True,
304
+ include_methods: bool = False,
304
305
  markdown: bool = True,
305
306
  **kwargs,
306
307
  ) -> str:
@@ -308,7 +309,7 @@ class SchemaPythonRepr(SchemaRepr):
308
309
  if include_result_definition:
309
310
  ret += self.result_definition(schema)
310
311
  class_definition_str = self.class_definitions(
311
- schema, markdown=markdown, **kwargs
312
+ schema, markdown=markdown, include_methods=include_methods, **kwargs
312
313
  )
313
314
  if class_definition_str:
314
315
  ret += f'\n\n{class_definition_str}'
@@ -331,6 +332,7 @@ def class_definitions(
331
332
  classes: Sequence[Type[Any]],
332
333
  *,
333
334
  include_pg_object_as_base: bool = False,
335
+ include_methods: bool = False,
334
336
  strict: bool = False,
335
337
  markdown: bool = False,
336
338
  ) -> str | None:
@@ -346,6 +348,7 @@ def class_definitions(
346
348
  cls,
347
349
  strict=strict,
348
350
  include_pg_object_as_base=include_pg_object_as_base,
351
+ include_methods=include_methods,
349
352
  )
350
353
  )
351
354
  ret = def_str.getvalue()
@@ -355,7 +358,10 @@ def class_definitions(
355
358
 
356
359
 
357
360
  def class_definition(
358
- cls, strict: bool = False, include_pg_object_as_base: bool = False
361
+ cls,
362
+ strict: bool = False,
363
+ include_pg_object_as_base: bool = False,
364
+ include_methods: bool = False,
359
365
  ) -> str:
360
366
  """Returns the Python class definition."""
361
367
  out = io.StringIO()
@@ -383,13 +389,16 @@ def class_definition(
383
389
  out.write('\n')
384
390
  out.write(' """\n')
385
391
 
392
+ empty_class = True
386
393
  if schema.fields:
387
394
  for key, field in schema.items():
388
395
  if not isinstance(key, pg.typing.ConstStrKey):
389
- raise TypeError(
396
+ pg.logging.warning(
390
397
  'Variable-length keyword arguments is not supported in '
391
- f'structured parsing or query. Encountered: {field}'
398
+ f'structured parsing or query. Encountered: {cls}, Schema: {schema}'
392
399
  )
400
+ continue
401
+
393
402
  # Write field doc string as comments before the field definition.
394
403
  if field.description:
395
404
  for line in field.description.split('\n'):
@@ -399,11 +408,33 @@ def class_definition(
399
408
  out.write('\n')
400
409
  out.write(f' {field.key}: {annotation(field.value, strict=strict)}')
401
410
  out.write('\n')
402
- else:
411
+ empty_class = False
412
+
413
+ if include_methods:
414
+ for method in _iter_newly_defined_methods(cls):
415
+ out.write('\n')
416
+ out.write(
417
+ textwrap.indent(
418
+ inspect.cleandoc('\n' + inspect.getsource(method)), ' ' * 2)
419
+ )
420
+ out.write('\n')
421
+ empty_class = False
422
+
423
+ if empty_class:
403
424
  out.write(' pass\n')
404
425
  return out.getvalue()
405
426
 
406
427
 
428
+ def _iter_newly_defined_methods(cls):
429
+ names = set(dir(cls))
430
+ for base in cls.__bases__:
431
+ names -= set(dir(base))
432
+ for name in names:
433
+ attr = getattr(cls, name)
434
+ if callable(attr):
435
+ yield attr
436
+
437
+
407
438
  def annotation(
408
439
  vs: pg.typing.ValueSpec,
409
440
  annotate_optional: bool = True,
@@ -491,7 +522,8 @@ def annotation(
491
522
  class SchemaJsonRepr(SchemaRepr):
492
523
  """JSON-representation for a schema."""
493
524
 
494
- def repr(self, schema: Schema) -> str:
525
+ def repr(self, schema: Schema, **kwargs) -> str:
526
+ del kwargs
495
527
  out = io.StringIO()
496
528
  def _visit(node: Any) -> None:
497
529
  if isinstance(node, str):
@@ -14,8 +14,8 @@
14
14
  import inspect
15
15
  import unittest
16
16
 
17
- import langfun.core.coding as lf_coding
18
17
  from langfun.core.llms import fake
18
+ from langfun.core.structured import mapping
19
19
  from langfun.core.structured import schema_generation
20
20
 
21
21
 
@@ -92,7 +92,7 @@ class GenerateClassTest(unittest.TestCase):
92
92
  )
93
93
  self.assertIs(cls.__name__, 'B')
94
94
 
95
- with self.assertRaises(lf_coding.CodeError):
95
+ with self.assertRaises(mapping.MappingError):
96
96
  schema_generation.generate_class(
97
97
  'Foo',
98
98
  'Generate a Foo class with a field pointing to another class A',
@@ -192,9 +192,9 @@ class SchemaTest(unittest.TestCase):
192
192
  self.assertEqual(schema.parse('{"result": 1}'), 1)
193
193
  schema = schema_lib.Schema(dict[str, int])
194
194
  self.assertEqual(
195
- schema.parse(
196
- '{"result": {"_type": "Unknown", "x": 1}}}', force_dict=True),
197
- dict(x=1))
195
+ schema.parse('{"result": {"x": 1}}}'),
196
+ dict(x=1)
197
+ )
198
198
  with self.assertRaisesRegex(
199
199
  schema_lib.SchemaError, 'Expect .* but encountered .*'):
200
200
  schema.parse('{"result": "def"}')
@@ -459,9 +459,24 @@ class SchemaPythonReprTest(unittest.TestCase):
459
459
  x: str
460
460
  __kwargs__: typing.Any
461
461
 
462
- with self.assertRaisesRegex(
463
- TypeError, 'Variable-length keyword arguments is not supported'):
464
- schema_lib.class_definition(C)
462
+ self.assertEqual(schema_lib.class_definition(C), 'class C:\n x: str\n')
463
+
464
+ class D(pg.Object):
465
+ x: str
466
+ def __call__(self, y: int) -> int:
467
+ return len(self.x) + y
468
+
469
+ self.assertEqual(
470
+ schema_lib.class_definition(D, include_methods=True),
471
+ inspect.cleandoc(
472
+ """
473
+ class D:
474
+ x: str
475
+
476
+ def __call__(self, y: int) -> int:
477
+ return len(self.x) + y
478
+ """) + '\n'
479
+ )
465
480
 
466
481
  def test_repr(self):
467
482
  class Foo(pg.Object):
@@ -479,13 +494,21 @@ class SchemaPythonReprTest(unittest.TestCase):
479
494
  class A(pg.Object):
480
495
  foo: Foo
481
496
 
497
+ def foo_value(self) -> int:
498
+ return self.foo.x
499
+
482
500
  class B(A):
483
501
  bar: Bar
484
502
  foo2: Foo
485
503
 
504
+ def bar_value(self) -> str:
505
+ return self.bar.y
506
+
486
507
  schema = schema_lib.Schema([B])
487
508
  self.assertEqual(
488
- schema_lib.SchemaPythonRepr().class_definitions(schema),
509
+ schema_lib.SchemaPythonRepr().class_definitions(
510
+ schema, include_methods=True
511
+ ),
489
512
  inspect.cleandoc('''
490
513
  class Foo:
491
514
  x: int
@@ -493,6 +516,9 @@ class SchemaPythonReprTest(unittest.TestCase):
493
516
  class A:
494
517
  foo: Foo
495
518
 
519
+ def foo_value(self) -> int:
520
+ return self.foo.x
521
+
496
522
  class Bar:
497
523
  """Class Bar."""
498
524
  y: str
@@ -505,6 +531,9 @@ class SchemaPythonReprTest(unittest.TestCase):
505
531
  foo: Foo
506
532
  bar: Bar
507
533
  foo2: Foo
534
+
535
+ def bar_value(self) -> str:
536
+ return self.bar.y
508
537
  ''') + '\n',
509
538
  )
510
539
 
@@ -32,8 +32,9 @@ def score(
32
32
  lm: lf.LanguageModel | None = None,
33
33
  examples: list[mapping.MappingExample] | None = None,
34
34
  protocol: schema_lib.SchemaProtocol = 'python',
35
+ return_scoring_results: bool = False,
35
36
  **kwargs,
36
- ) -> list[float]:
37
+ ) -> list[float] | list[lf.LMScoringResult]:
37
38
  """Scores the outputs based on the prompt."""
38
39
  if not completions:
39
40
  raise ValueError('`completions` must not be empty.')
@@ -72,4 +73,6 @@ def score(
72
73
  for c in completions
73
74
  ],
74
75
  )
76
+ if return_scoring_results:
77
+ return results
75
78
  return [r.score for r in results]
@@ -35,6 +35,12 @@ class ScoringTest(unittest.TestCase):
35
35
  def test_score(self):
36
36
  self.assertEqual(scoring.score('hi', [1, 2], lm=fake.Echo()), [0.0, -1.0])
37
37
 
38
+ def test_score_returning_scoring_results(self):
39
+ self.assertEqual(scoring.score(
40
+ 'hi', [1, 2], lm=fake.Echo(), return_scoring_results=True),
41
+ [lf.LMScoringResult(score=0.0, gradients=None),
42
+ lf.LMScoringResult(score=-1.0, gradients=None)])
43
+
38
44
  def test_scope_with_lm_from_the_context(self):
39
45
  with lf.context(lm=fake.Echo()):
40
46
  self.assertEqual(scoring.score('hi', [1, 2]), [0.0, -1.0])