langfun 0.1.1.dev20240819__py3-none-any.whl → 0.1.1.dev20240821__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langfun/core/eval/base.py CHANGED
@@ -1061,6 +1061,8 @@ class Evaluation(Evaluable):
1061
1061
  try:
1062
1062
  with lf.use_settings(debug=debug):
1063
1063
  output_message = copy.process(example, **(self.additional_args or {}))
1064
+ self.process_output(example, output_message)
1065
+
1064
1066
  if self.schema is None:
1065
1067
  output = output_message.text
1066
1068
  else:
@@ -1123,7 +1125,9 @@ class Evaluation(Evaluable):
1123
1125
  # generated code with calls to `input` will raise an error, thus not
1124
1126
  # blocking the evaluation.
1125
1127
  with lf_coding.context(input=None):
1126
- return self.process(example, **(self.additional_args or {}))
1128
+ output_message = self.process(example, **(self.additional_args or {}))
1129
+ self.process_output(example, output_message)
1130
+ return output_message
1127
1131
 
1128
1132
  try:
1129
1133
  for example, message, error in lf.concurrent_map(
@@ -1201,6 +1205,29 @@ class Evaluation(Evaluable):
1201
1205
  **kwargs,
1202
1206
  )
1203
1207
 
1208
+ def process_output(self, example: Any, output: lf.Message) -> None:
1209
+ """Process the output for an example.
1210
+
1211
+ Subclasses can override this method to generate and attach additional
1212
+ metadata for debugging purpose. For example, draw bounding boxes on the
1213
+ input image based on LLM predicted boxes and attach to output_message's
1214
+ metadata.
1215
+
1216
+ Example:
1217
+
1218
+ class BoundingBoxEval(lf.eval.Matching):
1219
+ ...
1220
+ def process_output(example, output):
1221
+ output.metadata.image_with_bbox = draw_bboxes(
1222
+ example.image, output.result)
1223
+
1224
+ Args:
1225
+ example: User input.
1226
+ output: LLM's output message. Users could attach additional
1227
+ information to the message, which will be shown in debugging
1228
+ """
1229
+ del example, output
1230
+
1204
1231
  def _status(self, progress: lf.concurrent.Progress) -> dict[str, Any]:
1205
1232
  return {
1206
1233
  'Model': self.lm.model_id,
@@ -15,6 +15,7 @@
15
15
 
16
16
  import collections
17
17
  import contextlib
18
+ import html
18
19
  import io
19
20
  from typing import Any, Callable, Iterator
20
21
 
@@ -126,7 +127,7 @@ def html_repr(
126
127
  if hasattr(v, '_repr_html_'):
127
128
  cs = v._repr_html_() # pylint: disable=protected-access
128
129
  else:
129
- cs = f'<span style="white-space: pre-wrap">{str(v)}</span>'
130
+ cs = f'<span style="white-space: pre-wrap">{html.escape(str(v))}</span>'
130
131
 
131
132
  key_color, key_bg_color, value_color, value_bg_color = item_color(k, v)
132
133
  key_span = html_round_text(
@@ -63,9 +63,12 @@ class SharingContentTest(unittest.TestCase):
63
63
  class Foo(pg.Object):
64
64
  x: int
65
65
 
66
- html = repr_utils.html_repr({'foo': pg.Ref(Foo(1))})
66
+ html = repr_utils.html_repr(
67
+ {'foo': pg.Ref(Foo(1)), 'bar': '<lf_image>'}
68
+ )
67
69
  self.assertIn('foo</span>', html)
68
70
  self.assertNotIn('Ref', html)
71
+ self.assertIn('&lt;lf_image&gt;', html)
69
72
 
70
73
 
71
74
  if __name__ == '__main__':
@@ -92,6 +92,15 @@ class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
92
92
  'The natural language context for this mapping. ',
93
93
  ] = None
94
94
 
95
+ metadata: Annotated[
96
+ dict[str, Any],
97
+ (
98
+ 'The metadata associated with the mapping example, '
99
+ 'which chould carry structured data, such as tool function input. '
100
+ 'It is a `pg.Dict` object whose keys can be accessed by attributes.'
101
+ ),
102
+ ] = pg.Dict()
103
+
95
104
  def schema_repr(
96
105
  self, protocol: schema_lib.SchemaProtocol = 'python', **kwargs
97
106
  ) -> str:
@@ -157,16 +166,21 @@ class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
157
166
 
158
167
  result.write(lf.colored('[INPUT]\n', styles=['bold']))
159
168
  result.write(lf.colored(self.input_repr(), color='green'))
160
- result.write('\n\n')
161
169
 
162
170
  if self.schema is not None:
171
+ result.write('\n\n')
163
172
  result.write(lf.colored('[SCHEMA]\n', styles=['bold']))
164
173
  result.write(lf.colored(self.schema_repr(), color='red'))
165
- result.write('\n\n')
166
174
 
167
175
  if schema_lib.MISSING != self.output:
176
+ result.write('\n\n')
168
177
  result.write(lf.colored('[OUTPUT]\n', styles=['bold']))
169
178
  result.write(lf.colored(self.output_repr(), color='blue'))
179
+
180
+ if self.metadata:
181
+ result.write('\n\n')
182
+ result.write(lf.colored('[METADATA]\n', styles=['bold']))
183
+ result.write(lf.colored(str(self.metadata), color='cyan'))
170
184
  return result.getvalue().strip()
171
185
 
172
186
 
@@ -129,6 +129,33 @@ class MappingExampleTest(unittest.TestCase):
129
129
  """),
130
130
  )
131
131
 
132
+ def test_str_with_metadata(self):
133
+ self.assertEqual(
134
+ str(
135
+ mapping.MappingExample(
136
+ '1 + 1 = 2',
137
+ schema=int,
138
+ context='Give the answer.',
139
+ metadata={'foo': 'bar'},
140
+ )
141
+ ),
142
+ inspect.cleandoc("""
143
+ \x1b[1m[CONTEXT]
144
+ \x1b[0m\x1b[35mGive the answer.\x1b[0m
145
+
146
+ \x1b[1m[INPUT]
147
+ \x1b[0m\x1b[32m1 + 1 = 2\x1b[0m
148
+
149
+ \x1b[1m[SCHEMA]
150
+ \x1b[0m\x1b[31mint\x1b[0m
151
+
152
+ \x1b[1m[METADATA]
153
+ \x1b[0m\x1b[36m{
154
+ foo = 'bar'
155
+ }\x1b[0m
156
+ """),
157
+ )
158
+
132
159
  def test_serialization(self):
133
160
  example = mapping.MappingExample(
134
161
  'the answer is 2', 2, int, context='compute 1 + 1'
@@ -35,7 +35,55 @@ def score(
35
35
  return_scoring_results: bool = False,
36
36
  **kwargs,
37
37
  ) -> list[float] | list[lf.LMScoringResult]:
38
- """Scores the outputs based on the prompt."""
38
+ """Scores the outputs based on the prompt.
39
+
40
+ Examples:
41
+ ```
42
+ # Example 1: Scoring text output based on the user prompt.
43
+ scores = lf.score('{{x}} + {{y}} =', ['1', '2', '3'], lm=lm, x=1, y=2)
44
+ assert len(scores) == 3
45
+
46
+ # Example 2: Scoring int output based on the formulated OOP prompt.
47
+ scores = lf.score('1 + 1 =', [1, 2, 3], lm=lm)
48
+ assert len(scores) == 3
49
+
50
+ class Answer(pg.Object):
51
+ result: int
52
+
53
+ # Example 3: Scoring object output based on the formulated OOP prompt.
54
+ scores = lf.score('1 + 1 =', [Answer(1), Answer(2), Answer(3)], lm=lm)
55
+ assert len(scores) == 3
56
+
57
+ # Example 4: Scoring object field value based on the formulated OOP prompt
58
+ # and the generated tokens before the first `pg.oneof`.
59
+ scores = lf.score('1 + 1 =', [Answer(pg.oneof([1, 2, 3]))], lm=lm)
60
+ assert len(scores) == 3
61
+
62
+ # Example 5: Scoring multiple prompt/completion pairs.
63
+ scores = lf.score(
64
+ ['1 + 1=', '2 + 3='],
65
+ ['2', '4'],
66
+ lm=lm
67
+ )
68
+ assert len(scores) == 2
69
+ ```
70
+
71
+ Args:
72
+ prompt: The prompt(s) based on which each completion will be scored.
73
+ completions: A list of strings or symbolic objects as the output.
74
+ schema: The schema as the output type. If None, it will be inferred from
75
+ the completions.
76
+ lm: The language model used for scoring.
77
+ examples: Fewshot exemplars used together with the prompt in getting the
78
+ completions.
79
+ protocol: The protocol for formulating the prompt based on objects.
80
+ return_scoring_results: If True, returns a list of `lf.LMScoringResult`,
81
+ otherwise returns a list of floats as the scores of each completion.
82
+ **kwargs: Keyword arguments that are referred by the prompt.
83
+
84
+ Returns:
85
+ A list of floats or `lf.LMScoringResult` as the score of each completion.
86
+ """
39
87
  if not completions:
40
88
  raise ValueError('`completions` must not be empty.')
41
89
 
@@ -79,12 +127,36 @@ def score(
79
127
  completion_reprs = []
80
128
  for c in completions:
81
129
  if isinstance(c, mapping.MappingError):
82
- rep = c.lm_response
130
+ completion_reprs.append(c.lm_response)
83
131
  else:
84
132
  rep = mapping.MappingExample.value_repr(
85
133
  c, protocol=protocol, compact=False, verbose=False
86
134
  )
87
- completion_reprs.append(rep)
135
+
136
+ # NOTE(daiyip): supporting scenario of scoring object field with
137
+ # `pg.oneof`.
138
+ oneof_pos = rep.find('OneOf(')
139
+ if oneof_pos == -1:
140
+ completion_reprs.append(rep)
141
+ else:
142
+ assert protocol == 'python', protocol
143
+ if isinstance(input_message, list):
144
+ raise ValueError(
145
+ 'Scoring on object fields using `pg.oneof` must share the '
146
+ f'same prompt. Encountered: {prompt}'
147
+ )
148
+ input_message.text += '\n' + rep[:oneof_pos]
149
+ oneof = _get_first_oneof(c)
150
+ for v in oneof.candidates:
151
+ completion_reprs.append(
152
+ pg.format(
153
+ v,
154
+ python_format=True,
155
+ compact=False,
156
+ verbose=False,
157
+ root_indent=oneof.sym_path.depth
158
+ )
159
+ )
88
160
 
89
161
  results = lm.score(
90
162
  input_message,
@@ -93,3 +165,17 @@ def score(
93
165
  if return_scoring_results:
94
166
  return results
95
167
  return [r.score for r in results]
168
+
169
+
170
+ def _get_first_oneof(value: Any) -> pg.hyper.OneOf:
171
+ """Gets the first pg.oneof from a symbolic object."""
172
+ oneofs = []
173
+ def select_oneofs(k, v, p):
174
+ del k, p
175
+ if isinstance(v, pg.hyper.OneOf):
176
+ oneofs.append(v)
177
+ return pg.TraverseAction.CONTINUE
178
+ return pg.TraverseAction.ENTER
179
+ pg.traverse(value, select_oneofs)
180
+ assert oneofs
181
+ return oneofs[0]
@@ -16,6 +16,11 @@ import unittest
16
16
  import langfun.core as lf
17
17
  from langfun.core.llms import fake
18
18
  from langfun.core.structured import scoring
19
+ import pyglove as pg
20
+
21
+
22
+ class Answer(pg.Object):
23
+ result: int
19
24
 
20
25
 
21
26
  class ScoringTest(unittest.TestCase):
@@ -32,9 +37,28 @@ class ScoringTest(unittest.TestCase):
32
37
  with self.assertRaisesRegex(ValueError, '`lm` must be specified'):
33
38
  scoring.score('hi', [1, 2])
34
39
 
40
+ with self.assertRaisesRegex(
41
+ ValueError,
42
+ 'Scoring on object fields using `pg.oneof` must share the same prompt',
43
+ ):
44
+ scoring.score(
45
+ ['1 + 1=', '2 + 3='],
46
+ [Answer(pg.oneof([1, 2, 3]))],
47
+ lm=fake.Echo(),
48
+ )
49
+
35
50
  def test_score(self):
36
51
  self.assertEqual(scoring.score('hi', [1, 2], lm=fake.Echo()), [0.0, -1.0])
37
52
 
53
+ def test_score_on_field_values(self):
54
+ self.assertEqual(
55
+ scoring.score(
56
+ '1 + 1=',
57
+ [Answer(pg.oneof([1, 2, 3]))], lm=fake.Echo()
58
+ ),
59
+ [0.0, -1.0, -2.0]
60
+ )
61
+
38
62
  def test_score_returning_scoring_results(self):
39
63
  self.assertEqual(scoring.score(
40
64
  'hi', [1, 2], lm=fake.Echo(), return_scoring_results=True),
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.1.1.dev20240819
3
+ Version: 0.1.1.dev20240821
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -19,8 +19,8 @@ langfun/core/modality.py,sha256=Tla4t86DUYHpbZ2G7dy1r19fTj_Ga5XOvlYp6lbWa-Q,3512
19
19
  langfun/core/modality_test.py,sha256=HyZ5xONKQ0Fw18SzoWAq-Ob9njOXIIjBo1hNtw-rudw,2400
20
20
  langfun/core/natural_language.py,sha256=3ynSnaYQnjE60LIPK5fyMgdIjubnPYZwzGq4rWPeloE,1177
21
21
  langfun/core/natural_language_test.py,sha256=LHGU_1ytbkGuSZQFIFP7vP3dBlcY4-A12fT6dbjUA0E,1424
22
- langfun/core/repr_utils.py,sha256=nKB9U4-8NE8qjx7Zl2g1yXLCbpM6Niq38ReMSyPtfJQ,5512
23
- langfun/core/repr_utils_test.py,sha256=Z018ULMZ8cgmygAH4dNnKBEKduBC7bl1-tZClD1pv9g,2606
22
+ langfun/core/repr_utils.py,sha256=Y6ccoQUMpRxDv_jUy2QtnP9cdz3QBjJtTIgxGIU-kfM,5537
23
+ langfun/core/repr_utils_test.py,sha256=_VhWpDbtlWaGadXL0gpmwQVmACenvzmLUng_AqR6zaE,2685
24
24
  langfun/core/sampling.py,sha256=vygWvgC8MFw0_AKNSmz-ywMXJYWf8cl0tI8QycvAmyI,5795
25
25
  langfun/core/sampling_test.py,sha256=U7PANpMsl9E_pa4_Y4FzesSjcwg-u-LKHGCWSgv-8FY,3663
26
26
  langfun/core/subscription.py,sha256=euawEuSZP-BHydaT-AQpfYFL0m5pWPGcW0upFhrojqc,10930
@@ -44,7 +44,7 @@ langfun/core/coding/python/parsing_test.py,sha256=9vAWF484kWIm6JZq8NFiMgKUDhXV-d
44
44
  langfun/core/coding/python/permissions.py,sha256=1QWGHvzL8MM0Ok_auQ9tURqZHtdOfJaDpBzZ29GUE-c,2544
45
45
  langfun/core/coding/python/permissions_test.py,sha256=w5EDb8QxpxgJyZkojyzVWQvDfg366zn99-g__6TbPQ0,2699
46
46
  langfun/core/eval/__init__.py,sha256=Ogdr9OtTywhhLPHi3AZzOD2mXX2oyaHWflrSTMm96uA,1899
47
- langfun/core/eval/base.py,sha256=0_iaKuQhS49PlbWqCQ5EABUMKavr2R4ltcJZWCVoZZg,73816
47
+ langfun/core/eval/base.py,sha256=BiWColibVo9-4P27Z0hIWXe8_UPocJTSTUdKeOPVwxI,74746
48
48
  langfun/core/eval/base_test.py,sha256=p1EfqviHMz_ppQY8FU67h5OCgL0tzhLvXzGIsq0sVyI,26930
49
49
  langfun/core/eval/matching.py,sha256=9GX8HfO9jKxgNLAivgy5K88Xhoh6Z7Pptq65pe7vht8,9762
50
50
  langfun/core/eval/matching_test.py,sha256=f7iVyXH5KGJBWt4Wp14Bt9J3X59A6Ayfog9MbuFvPew,5532
@@ -96,8 +96,8 @@ langfun/core/structured/description.py,sha256=SXW4MJvshFjbR-0gw6rE21o6WXq12UlRXa
96
96
  langfun/core/structured/description_test.py,sha256=UtZGjSFUaQ6130t1E5tcL7ODu0xIefkapb53TbnqsK8,7362
97
97
  langfun/core/structured/function_generation.py,sha256=pFgS3vcRAWiuFBol2x5Eeip3XqoudONsOpeJpWyjT3s,7479
98
98
  langfun/core/structured/function_generation_test.py,sha256=ZJI-aaGgWWszn92u7h5IZ9Pl70N2DgAGGJrIxPzsvwg,10065
99
- langfun/core/structured/mapping.py,sha256=QKbSnvOgut-sx2mZPjHJcdlDLxR8b3ZC16ZLWociwog,11298
100
- langfun/core/structured/mapping_test.py,sha256=PiXklMeIa8L6KtMi3ju7J9Y39gZy0hIGz-Oeq4A_7XE,3835
99
+ langfun/core/structured/mapping.py,sha256=CsflMwm5cKJYZ2ag-neroA4CQlhu2wjFRSxKpd_qQDQ,11778
100
+ langfun/core/structured/mapping_test.py,sha256=zQoVx3kAD5oSm_OJAQA6q41NXLLyn8qs6CIVJgAoP_w,4489
101
101
  langfun/core/structured/parsing.py,sha256=keoVqEfzAbdULh6GawWFsTQzU91MzJXYFZjXGXLaD8g,11492
102
102
  langfun/core/structured/parsing_test.py,sha256=34wDrXaQ-EYhJLfDL8mX9K53oQMSzh5pVYdKjnESmK8,20895
103
103
  langfun/core/structured/prompting.py,sha256=_U6Z65AwXvVvfaQFCY9GawB_QV9S3u7P7BOU2URABmw,8873
@@ -106,8 +106,8 @@ langfun/core/structured/schema.py,sha256=oiT4P4Q9pG-QOnFzxETN2EQZqNln8nG4zAJHxcm
106
106
  langfun/core/structured/schema_generation.py,sha256=U3nRQsqmMZg_qIVDh2fiY3K4JLfsAL1LcKzIFP1iXFg,5316
107
107
  langfun/core/structured/schema_generation_test.py,sha256=RM9s71kMNg2jTePwInkiW9fK1ACN37eyPeF8OII-0zw,2950
108
108
  langfun/core/structured/schema_test.py,sha256=RjYhwTgktQgyqAjzLvo967nTiIK9KWgP-aNGg4e7ihE,25258
109
- langfun/core/structured/scoring.py,sha256=pE2ilZC7cV1qlPZANOFIFVbNB7IixSTLcnmf9pRU3tc,2883
110
- langfun/core/structured/scoring_test.py,sha256=39_dw6p_FkoqeUccO67yIqos-MccAWezoozS21i8mi0,1732
109
+ langfun/core/structured/scoring.py,sha256=ae6SjLqoqsKFmcPnaJbsFmH4XFGKOQaJRjYZ1wm1Ywo,5860
110
+ langfun/core/structured/scoring_test.py,sha256=QvlwDAzwuamKL5tCotm1L3Sx0cs3idoNK4aIEhaO4Yk,2272
111
111
  langfun/core/templates/__init__.py,sha256=bO0eMsVJbi7sxEB2YlInKRQ2EVP-RyyKUwcD-8msuN4,927
112
112
  langfun/core/templates/completion.py,sha256=mUqZHOEV3ag6-A08XghpeEltcrBvCDxXP004eDDfeag,1931
113
113
  langfun/core/templates/completion_test.py,sha256=vGnjnM38UHyVDUyaUYtmp20s9KBGOdbPVsX-H-ET11E,1636
@@ -117,8 +117,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
117
117
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
118
118
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
119
119
  langfun/core/templates/selfplay_test.py,sha256=rBW2Qr8yi-aWYwoTwRR-n1peKyMX9QXPZXURjLgoiRs,2264
120
- langfun-0.1.1.dev20240819.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
121
- langfun-0.1.1.dev20240819.dist-info/METADATA,sha256=XaHWEVmO67aqEbYT0Aa0wqV81wHGVkcZ3SgAiI5jOyM,5234
122
- langfun-0.1.1.dev20240819.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
123
- langfun-0.1.1.dev20240819.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
124
- langfun-0.1.1.dev20240819.dist-info/RECORD,,
120
+ langfun-0.1.1.dev20240821.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
121
+ langfun-0.1.1.dev20240821.dist-info/METADATA,sha256=469KPCsIx2U_ZtMDN0qA4UTOnbcVayQyduyUs65ccVE,5234
122
+ langfun-0.1.1.dev20240821.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
123
+ langfun-0.1.1.dev20240821.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
124
+ langfun-0.1.1.dev20240821.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (72.2.0)
2
+ Generator: setuptools (73.0.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5