langfun 0.1.1.dev20240820__py3-none-any.whl → 0.1.1.dev20240822__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
@@ -15,6 +15,7 @@
15
15
 
16
16
  import collections
17
17
  import contextlib
18
+ import html
18
19
  import io
19
20
  from typing import Any, Callable, Iterator
20
21
 
@@ -126,7 +127,7 @@ def html_repr(
126
127
  if hasattr(v, '_repr_html_'):
127
128
  cs = v._repr_html_() # pylint: disable=protected-access
128
129
  else:
129
- cs = f'<span style="white-space: pre-wrap">{str(v)}</span>'
130
+ cs = f'<span style="white-space: pre-wrap">{html.escape(str(v))}</span>'
130
131
 
131
132
  key_color, key_bg_color, value_color, value_bg_color = item_color(k, v)
132
133
  key_span = html_round_text(
@@ -63,9 +63,12 @@ class SharingContentTest(unittest.TestCase):
63
63
  class Foo(pg.Object):
64
64
  x: int
65
65
 
66
- html = repr_utils.html_repr({'foo': pg.Ref(Foo(1))})
66
+ html = repr_utils.html_repr(
67
+ {'foo': pg.Ref(Foo(1)), 'bar': '<lf_image>'}
68
+ )
67
69
  self.assertIn('foo</span>', html)
68
70
  self.assertNotIn('Ref', html)
71
+ self.assertIn('&lt;lf_image&gt;', html)
69
72
 
70
73
 
71
74
  if __name__ == '__main__':
@@ -35,7 +35,55 @@ def score(
35
35
  return_scoring_results: bool = False,
36
36
  **kwargs,
37
37
  ) -> list[float] | list[lf.LMScoringResult]:
38
- """Scores the outputs based on the prompt."""
38
+ """Scores the outputs based on the prompt.
39
+
40
+ Examples:
41
+ ```
42
+ # Example 1: Scoring text output based on the user prompt.
43
+ scores = lf.score('{{x}} + {{y}} =', ['1', '2', '3'], lm=lm, x=1, y=2)
44
+ assert len(scores) == 3
45
+
46
+ # Example 2: Scoring int output based on the formulated OOP prompt.
47
+ scores = lf.score('1 + 1 =', [1, 2, 3], lm=lm)
48
+ assert len(scores) == 3
49
+
50
+ class Answer(pg.Object):
51
+ result: int
52
+
53
+ # Example 3: Scoring object output based on the formulated OOP prompt.
54
+ scores = lf.score('1 + 1 =', [Answer(1), Answer(2), Answer(3)], lm=lm)
55
+ assert len(scores) == 3
56
+
57
+ # Example 4: Scoring object field value based on the formulated OOP prompt
58
+ # and the generated tokens before the first `pg.oneof`.
59
+ scores = lf.score('1 + 1 =', [Answer(pg.oneof([1, 2, 3]))], lm=lm)
60
+ assert len(scores) == 3
61
+
62
+ # Example 5: Scoring multiple prompt/completion pairs.
63
+ scores = lf.score(
64
+ ['1 + 1=', '2 + 3='],
65
+ ['2', '4'],
66
+ lm=lm
67
+ )
68
+ assert len(scores) == 2
69
+ ```
70
+
71
+ Args:
72
+ prompt: The prompt(s) based on which each completion will be scored.
73
+ completions: A list of strings or symbolic objects as the output.
74
+ schema: The schema as the output type. If None, it will be inferred from
75
+ the completions.
76
+ lm: The language model used for scoring.
77
+ examples: Fewshot exemplars used together with the prompt in getting the
78
+ completions.
79
+ protocol: The protocol for formulating the prompt based on objects.
80
+ return_scoring_results: If True, returns a list of `lf.LMScoringResult`,
81
+ otherwise returns a list of floats as the scores of each completion.
82
+ **kwargs: Keyword arguments that are referred by the prompt.
83
+
84
+ Returns:
85
+ A list of floats or `lf.LMScoringResult` as the score of each completion.
86
+ """
39
87
  if not completions:
40
88
  raise ValueError('`completions` must not be empty.')
41
89
 
@@ -79,12 +127,36 @@ def score(
79
127
  completion_reprs = []
80
128
  for c in completions:
81
129
  if isinstance(c, mapping.MappingError):
82
- rep = c.lm_response
130
+ completion_reprs.append(c.lm_response)
83
131
  else:
84
132
  rep = mapping.MappingExample.value_repr(
85
133
  c, protocol=protocol, compact=False, verbose=False
86
134
  )
87
- completion_reprs.append(rep)
135
+
136
+ # NOTE(daiyip): supporting scenario of scoring object field with
137
+ # `pg.oneof`.
138
+ oneof_pos = rep.find('OneOf(')
139
+ if oneof_pos == -1:
140
+ completion_reprs.append(rep)
141
+ else:
142
+ assert protocol == 'python', protocol
143
+ if isinstance(input_message, list):
144
+ raise ValueError(
145
+ 'Scoring on object fields using `pg.oneof` must share the '
146
+ f'same prompt. Encountered: {prompt}'
147
+ )
148
+ input_message.text += '\n' + rep[:oneof_pos]
149
+ oneof = _get_first_oneof(c)
150
+ for v in oneof.candidates:
151
+ completion_reprs.append(
152
+ pg.format(
153
+ v,
154
+ python_format=True,
155
+ compact=False,
156
+ verbose=False,
157
+ root_indent=oneof.sym_path.depth
158
+ )
159
+ )
88
160
 
89
161
  results = lm.score(
90
162
  input_message,
@@ -93,3 +165,17 @@ def score(
93
165
  if return_scoring_results:
94
166
  return results
95
167
  return [r.score for r in results]
168
+
169
+
170
+ def _get_first_oneof(value: Any) -> pg.hyper.OneOf:
171
+ """Gets the first pg.oneof from a symbolic object."""
172
+ oneofs = []
173
+ def select_oneofs(k, v, p):
174
+ del k, p
175
+ if isinstance(v, pg.hyper.OneOf):
176
+ oneofs.append(v)
177
+ return pg.TraverseAction.CONTINUE
178
+ return pg.TraverseAction.ENTER
179
+ pg.traverse(value, select_oneofs)
180
+ assert oneofs
181
+ return oneofs[0]
@@ -16,6 +16,11 @@ import unittest
16
16
  import langfun.core as lf
17
17
  from langfun.core.llms import fake
18
18
  from langfun.core.structured import scoring
19
+ import pyglove as pg
20
+
21
+
22
+ class Answer(pg.Object):
23
+ result: int
19
24
 
20
25
 
21
26
  class ScoringTest(unittest.TestCase):
@@ -32,9 +37,28 @@ class ScoringTest(unittest.TestCase):
32
37
  with self.assertRaisesRegex(ValueError, '`lm` must be specified'):
33
38
  scoring.score('hi', [1, 2])
34
39
 
40
+ with self.assertRaisesRegex(
41
+ ValueError,
42
+ 'Scoring on object fields using `pg.oneof` must share the same prompt',
43
+ ):
44
+ scoring.score(
45
+ ['1 + 1=', '2 + 3='],
46
+ [Answer(pg.oneof([1, 2, 3]))],
47
+ lm=fake.Echo(),
48
+ )
49
+
35
50
  def test_score(self):
36
51
  self.assertEqual(scoring.score('hi', [1, 2], lm=fake.Echo()), [0.0, -1.0])
37
52
 
53
+ def test_score_on_field_values(self):
54
+ self.assertEqual(
55
+ scoring.score(
56
+ '1 + 1=',
57
+ [Answer(pg.oneof([1, 2, 3]))], lm=fake.Echo()
58
+ ),
59
+ [0.0, -1.0, -2.0]
60
+ )
61
+
38
62
  def test_score_returning_scoring_results(self):
39
63
  self.assertEqual(scoring.score(
40
64
  'hi', [1, 2], lm=fake.Echo(), return_scoring_results=True),
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.1.1.dev20240820
3
+ Version: 0.1.1.dev20240822
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -19,8 +19,8 @@ langfun/core/modality.py,sha256=Tla4t86DUYHpbZ2G7dy1r19fTj_Ga5XOvlYp6lbWa-Q,3512
19
19
  langfun/core/modality_test.py,sha256=HyZ5xONKQ0Fw18SzoWAq-Ob9njOXIIjBo1hNtw-rudw,2400
20
20
  langfun/core/natural_language.py,sha256=3ynSnaYQnjE60LIPK5fyMgdIjubnPYZwzGq4rWPeloE,1177
21
21
  langfun/core/natural_language_test.py,sha256=LHGU_1ytbkGuSZQFIFP7vP3dBlcY4-A12fT6dbjUA0E,1424
22
- langfun/core/repr_utils.py,sha256=nKB9U4-8NE8qjx7Zl2g1yXLCbpM6Niq38ReMSyPtfJQ,5512
23
- langfun/core/repr_utils_test.py,sha256=Z018ULMZ8cgmygAH4dNnKBEKduBC7bl1-tZClD1pv9g,2606
22
+ langfun/core/repr_utils.py,sha256=Y6ccoQUMpRxDv_jUy2QtnP9cdz3QBjJtTIgxGIU-kfM,5537
23
+ langfun/core/repr_utils_test.py,sha256=_VhWpDbtlWaGadXL0gpmwQVmACenvzmLUng_AqR6zaE,2685
24
24
  langfun/core/sampling.py,sha256=vygWvgC8MFw0_AKNSmz-ywMXJYWf8cl0tI8QycvAmyI,5795
25
25
  langfun/core/sampling_test.py,sha256=U7PANpMsl9E_pa4_Y4FzesSjcwg-u-LKHGCWSgv-8FY,3663
26
26
  langfun/core/subscription.py,sha256=euawEuSZP-BHydaT-AQpfYFL0m5pWPGcW0upFhrojqc,10930
@@ -106,8 +106,8 @@ langfun/core/structured/schema.py,sha256=oiT4P4Q9pG-QOnFzxETN2EQZqNln8nG4zAJHxcm
106
106
  langfun/core/structured/schema_generation.py,sha256=U3nRQsqmMZg_qIVDh2fiY3K4JLfsAL1LcKzIFP1iXFg,5316
107
107
  langfun/core/structured/schema_generation_test.py,sha256=RM9s71kMNg2jTePwInkiW9fK1ACN37eyPeF8OII-0zw,2950
108
108
  langfun/core/structured/schema_test.py,sha256=RjYhwTgktQgyqAjzLvo967nTiIK9KWgP-aNGg4e7ihE,25258
109
- langfun/core/structured/scoring.py,sha256=pE2ilZC7cV1qlPZANOFIFVbNB7IixSTLcnmf9pRU3tc,2883
110
- langfun/core/structured/scoring_test.py,sha256=39_dw6p_FkoqeUccO67yIqos-MccAWezoozS21i8mi0,1732
109
+ langfun/core/structured/scoring.py,sha256=ae6SjLqoqsKFmcPnaJbsFmH4XFGKOQaJRjYZ1wm1Ywo,5860
110
+ langfun/core/structured/scoring_test.py,sha256=QvlwDAzwuamKL5tCotm1L3Sx0cs3idoNK4aIEhaO4Yk,2272
111
111
  langfun/core/templates/__init__.py,sha256=bO0eMsVJbi7sxEB2YlInKRQ2EVP-RyyKUwcD-8msuN4,927
112
112
  langfun/core/templates/completion.py,sha256=mUqZHOEV3ag6-A08XghpeEltcrBvCDxXP004eDDfeag,1931
113
113
  langfun/core/templates/completion_test.py,sha256=vGnjnM38UHyVDUyaUYtmp20s9KBGOdbPVsX-H-ET11E,1636
@@ -117,8 +117,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
117
117
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
118
118
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
119
119
  langfun/core/templates/selfplay_test.py,sha256=rBW2Qr8yi-aWYwoTwRR-n1peKyMX9QXPZXURjLgoiRs,2264
120
- langfun-0.1.1.dev20240820.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
121
- langfun-0.1.1.dev20240820.dist-info/METADATA,sha256=t3TOnvt67GjpbhKOfk9K19GbEXtGlmNBMwrjqaqbYQU,5234
122
- langfun-0.1.1.dev20240820.dist-info/WHEEL,sha256=nCVcAvsfA9TDtwGwhYaRrlPhTLV9m-Ga6mdyDtuwK18,91
123
- langfun-0.1.1.dev20240820.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
124
- langfun-0.1.1.dev20240820.dist-info/RECORD,,
120
+ langfun-0.1.1.dev20240822.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
121
+ langfun-0.1.1.dev20240822.dist-info/METADATA,sha256=tldtlq7znDyRGiaq62EDI8aqpsKJSpPoSsl1cCE2OUc,5234
122
+ langfun-0.1.1.dev20240822.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
123
+ langfun-0.1.1.dev20240822.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
124
+ langfun-0.1.1.dev20240822.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (73.0.0)
2
+ Generator: setuptools (73.0.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5