langfun 0.1.1.dev20240820__py3-none-any.whl → 0.1.1.dev20240821__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/repr_utils.py +2 -1
- langfun/core/repr_utils_test.py +4 -1
- langfun/core/structured/scoring.py +89 -3
- langfun/core/structured/scoring_test.py +24 -0
- {langfun-0.1.1.dev20240820.dist-info → langfun-0.1.1.dev20240821.dist-info}/METADATA +1 -1
- {langfun-0.1.1.dev20240820.dist-info → langfun-0.1.1.dev20240821.dist-info}/RECORD +9 -9
- {langfun-0.1.1.dev20240820.dist-info → langfun-0.1.1.dev20240821.dist-info}/WHEEL +1 -1
- {langfun-0.1.1.dev20240820.dist-info → langfun-0.1.1.dev20240821.dist-info}/LICENSE +0 -0
- {langfun-0.1.1.dev20240820.dist-info → langfun-0.1.1.dev20240821.dist-info}/top_level.txt +0 -0
langfun/core/repr_utils.py
CHANGED
@@ -15,6 +15,7 @@
|
|
15
15
|
|
16
16
|
import collections
|
17
17
|
import contextlib
|
18
|
+
import html
|
18
19
|
import io
|
19
20
|
from typing import Any, Callable, Iterator
|
20
21
|
|
@@ -126,7 +127,7 @@ def html_repr(
|
|
126
127
|
if hasattr(v, '_repr_html_'):
|
127
128
|
cs = v._repr_html_() # pylint: disable=protected-access
|
128
129
|
else:
|
129
|
-
cs = f'<span style="white-space: pre-wrap">{str(v)}</span>'
|
130
|
+
cs = f'<span style="white-space: pre-wrap">{html.escape(str(v))}</span>'
|
130
131
|
|
131
132
|
key_color, key_bg_color, value_color, value_bg_color = item_color(k, v)
|
132
133
|
key_span = html_round_text(
|
langfun/core/repr_utils_test.py
CHANGED
@@ -63,9 +63,12 @@ class SharingContentTest(unittest.TestCase):
|
|
63
63
|
class Foo(pg.Object):
|
64
64
|
x: int
|
65
65
|
|
66
|
-
html = repr_utils.html_repr(
|
66
|
+
html = repr_utils.html_repr(
|
67
|
+
{'foo': pg.Ref(Foo(1)), 'bar': '<lf_image>'}
|
68
|
+
)
|
67
69
|
self.assertIn('foo</span>', html)
|
68
70
|
self.assertNotIn('Ref', html)
|
71
|
+
self.assertIn('<lf_image>', html)
|
69
72
|
|
70
73
|
|
71
74
|
if __name__ == '__main__':
|
@@ -35,7 +35,55 @@ def score(
|
|
35
35
|
return_scoring_results: bool = False,
|
36
36
|
**kwargs,
|
37
37
|
) -> list[float] | list[lf.LMScoringResult]:
|
38
|
-
"""Scores the outputs based on the prompt.
|
38
|
+
"""Scores the outputs based on the prompt.
|
39
|
+
|
40
|
+
Examples:
|
41
|
+
```
|
42
|
+
# Example 1: Scoring text output based on the user prompt.
|
43
|
+
scores = lf.score('{{x}} + {{y}} =', ['1', '2', '3'], lm=lm, x=1, y=2)
|
44
|
+
assert len(scores) == 3
|
45
|
+
|
46
|
+
# Example 2: Scoring int output based on the formulated OOP prompt.
|
47
|
+
scores = lf.score('1 + 1 =', [1, 2, 3], lm=lm)
|
48
|
+
assert len(scores) == 3
|
49
|
+
|
50
|
+
class Answer(pg.Object):
|
51
|
+
result: int
|
52
|
+
|
53
|
+
# Example 3: Scoring object output based on the formulated OOP prompt.
|
54
|
+
scores = lf.score('1 + 1 =', [Answer(1), Answer(2), Answer(3)], lm=lm)
|
55
|
+
assert len(scores) == 3
|
56
|
+
|
57
|
+
# Example 4: Scoring object field value based on the formulated OOP prompt
|
58
|
+
# and the generated tokens before the first `pg.oneof`.
|
59
|
+
scores = lf.score('1 + 1 =', [Answer(pg.oneof([1, 2, 3]))], lm=lm)
|
60
|
+
assert len(scores) == 3
|
61
|
+
|
62
|
+
# Example 5: Scoring multiple prompt/completion pairs.
|
63
|
+
scores = lf.score(
|
64
|
+
['1 + 1=', '2 + 3='],
|
65
|
+
['2', '4'],
|
66
|
+
lm=lm
|
67
|
+
)
|
68
|
+
assert len(scores) == 2
|
69
|
+
```
|
70
|
+
|
71
|
+
Args:
|
72
|
+
prompt: The prompt(s) based on which each completion will be scored.
|
73
|
+
completions: A list of strings or symbolic objects as the output.
|
74
|
+
schema: The schema as the output type. If None, it will be inferred from
|
75
|
+
the completions.
|
76
|
+
lm: The language model used for scoring.
|
77
|
+
examples: Fewshot exemplars used together with the prompt in getting the
|
78
|
+
completions.
|
79
|
+
protocol: The protocol for formulating the prompt based on objects.
|
80
|
+
return_scoring_results: If True, returns a list of `lf.LMScoringResult`,
|
81
|
+
otherwise returns a list of floats as the scores of each completion.
|
82
|
+
**kwargs: Keyword arguments that are referred by the prompt.
|
83
|
+
|
84
|
+
Returns:
|
85
|
+
A list of floats or `lf.LMScoringResult` as the score of each completion.
|
86
|
+
"""
|
39
87
|
if not completions:
|
40
88
|
raise ValueError('`completions` must not be empty.')
|
41
89
|
|
@@ -79,12 +127,36 @@ def score(
|
|
79
127
|
completion_reprs = []
|
80
128
|
for c in completions:
|
81
129
|
if isinstance(c, mapping.MappingError):
|
82
|
-
|
130
|
+
completion_reprs.append(c.lm_response)
|
83
131
|
else:
|
84
132
|
rep = mapping.MappingExample.value_repr(
|
85
133
|
c, protocol=protocol, compact=False, verbose=False
|
86
134
|
)
|
87
|
-
|
135
|
+
|
136
|
+
# NOTE(daiyip): supporting scenario of scoring object field with
|
137
|
+
# `pg.oneof`.
|
138
|
+
oneof_pos = rep.find('OneOf(')
|
139
|
+
if oneof_pos == -1:
|
140
|
+
completion_reprs.append(rep)
|
141
|
+
else:
|
142
|
+
assert protocol == 'python', protocol
|
143
|
+
if isinstance(input_message, list):
|
144
|
+
raise ValueError(
|
145
|
+
'Scoring on object fields using `pg.oneof` must share the '
|
146
|
+
f'same prompt. Encountered: {prompt}'
|
147
|
+
)
|
148
|
+
input_message.text += '\n' + rep[:oneof_pos]
|
149
|
+
oneof = _get_first_oneof(c)
|
150
|
+
for v in oneof.candidates:
|
151
|
+
completion_reprs.append(
|
152
|
+
pg.format(
|
153
|
+
v,
|
154
|
+
python_format=True,
|
155
|
+
compact=False,
|
156
|
+
verbose=False,
|
157
|
+
root_indent=oneof.sym_path.depth
|
158
|
+
)
|
159
|
+
)
|
88
160
|
|
89
161
|
results = lm.score(
|
90
162
|
input_message,
|
@@ -93,3 +165,17 @@ def score(
|
|
93
165
|
if return_scoring_results:
|
94
166
|
return results
|
95
167
|
return [r.score for r in results]
|
168
|
+
|
169
|
+
|
170
|
+
def _get_first_oneof(value: Any) -> pg.hyper.OneOf:
|
171
|
+
"""Gets the first pg.oneof from a symbolic object."""
|
172
|
+
oneofs = []
|
173
|
+
def select_oneofs(k, v, p):
|
174
|
+
del k, p
|
175
|
+
if isinstance(v, pg.hyper.OneOf):
|
176
|
+
oneofs.append(v)
|
177
|
+
return pg.TraverseAction.CONTINUE
|
178
|
+
return pg.TraverseAction.ENTER
|
179
|
+
pg.traverse(value, select_oneofs)
|
180
|
+
assert oneofs
|
181
|
+
return oneofs[0]
|
@@ -16,6 +16,11 @@ import unittest
|
|
16
16
|
import langfun.core as lf
|
17
17
|
from langfun.core.llms import fake
|
18
18
|
from langfun.core.structured import scoring
|
19
|
+
import pyglove as pg
|
20
|
+
|
21
|
+
|
22
|
+
class Answer(pg.Object):
|
23
|
+
result: int
|
19
24
|
|
20
25
|
|
21
26
|
class ScoringTest(unittest.TestCase):
|
@@ -32,9 +37,28 @@ class ScoringTest(unittest.TestCase):
|
|
32
37
|
with self.assertRaisesRegex(ValueError, '`lm` must be specified'):
|
33
38
|
scoring.score('hi', [1, 2])
|
34
39
|
|
40
|
+
with self.assertRaisesRegex(
|
41
|
+
ValueError,
|
42
|
+
'Scoring on object fields using `pg.oneof` must share the same prompt',
|
43
|
+
):
|
44
|
+
scoring.score(
|
45
|
+
['1 + 1=', '2 + 3='],
|
46
|
+
[Answer(pg.oneof([1, 2, 3]))],
|
47
|
+
lm=fake.Echo(),
|
48
|
+
)
|
49
|
+
|
35
50
|
def test_score(self):
|
36
51
|
self.assertEqual(scoring.score('hi', [1, 2], lm=fake.Echo()), [0.0, -1.0])
|
37
52
|
|
53
|
+
def test_score_on_field_values(self):
|
54
|
+
self.assertEqual(
|
55
|
+
scoring.score(
|
56
|
+
'1 + 1=',
|
57
|
+
[Answer(pg.oneof([1, 2, 3]))], lm=fake.Echo()
|
58
|
+
),
|
59
|
+
[0.0, -1.0, -2.0]
|
60
|
+
)
|
61
|
+
|
38
62
|
def test_score_returning_scoring_results(self):
|
39
63
|
self.assertEqual(scoring.score(
|
40
64
|
'hi', [1, 2], lm=fake.Echo(), return_scoring_results=True),
|
@@ -19,8 +19,8 @@ langfun/core/modality.py,sha256=Tla4t86DUYHpbZ2G7dy1r19fTj_Ga5XOvlYp6lbWa-Q,3512
|
|
19
19
|
langfun/core/modality_test.py,sha256=HyZ5xONKQ0Fw18SzoWAq-Ob9njOXIIjBo1hNtw-rudw,2400
|
20
20
|
langfun/core/natural_language.py,sha256=3ynSnaYQnjE60LIPK5fyMgdIjubnPYZwzGq4rWPeloE,1177
|
21
21
|
langfun/core/natural_language_test.py,sha256=LHGU_1ytbkGuSZQFIFP7vP3dBlcY4-A12fT6dbjUA0E,1424
|
22
|
-
langfun/core/repr_utils.py,sha256=
|
23
|
-
langfun/core/repr_utils_test.py,sha256=
|
22
|
+
langfun/core/repr_utils.py,sha256=Y6ccoQUMpRxDv_jUy2QtnP9cdz3QBjJtTIgxGIU-kfM,5537
|
23
|
+
langfun/core/repr_utils_test.py,sha256=_VhWpDbtlWaGadXL0gpmwQVmACenvzmLUng_AqR6zaE,2685
|
24
24
|
langfun/core/sampling.py,sha256=vygWvgC8MFw0_AKNSmz-ywMXJYWf8cl0tI8QycvAmyI,5795
|
25
25
|
langfun/core/sampling_test.py,sha256=U7PANpMsl9E_pa4_Y4FzesSjcwg-u-LKHGCWSgv-8FY,3663
|
26
26
|
langfun/core/subscription.py,sha256=euawEuSZP-BHydaT-AQpfYFL0m5pWPGcW0upFhrojqc,10930
|
@@ -106,8 +106,8 @@ langfun/core/structured/schema.py,sha256=oiT4P4Q9pG-QOnFzxETN2EQZqNln8nG4zAJHxcm
|
|
106
106
|
langfun/core/structured/schema_generation.py,sha256=U3nRQsqmMZg_qIVDh2fiY3K4JLfsAL1LcKzIFP1iXFg,5316
|
107
107
|
langfun/core/structured/schema_generation_test.py,sha256=RM9s71kMNg2jTePwInkiW9fK1ACN37eyPeF8OII-0zw,2950
|
108
108
|
langfun/core/structured/schema_test.py,sha256=RjYhwTgktQgyqAjzLvo967nTiIK9KWgP-aNGg4e7ihE,25258
|
109
|
-
langfun/core/structured/scoring.py,sha256=
|
110
|
-
langfun/core/structured/scoring_test.py,sha256=
|
109
|
+
langfun/core/structured/scoring.py,sha256=ae6SjLqoqsKFmcPnaJbsFmH4XFGKOQaJRjYZ1wm1Ywo,5860
|
110
|
+
langfun/core/structured/scoring_test.py,sha256=QvlwDAzwuamKL5tCotm1L3Sx0cs3idoNK4aIEhaO4Yk,2272
|
111
111
|
langfun/core/templates/__init__.py,sha256=bO0eMsVJbi7sxEB2YlInKRQ2EVP-RyyKUwcD-8msuN4,927
|
112
112
|
langfun/core/templates/completion.py,sha256=mUqZHOEV3ag6-A08XghpeEltcrBvCDxXP004eDDfeag,1931
|
113
113
|
langfun/core/templates/completion_test.py,sha256=vGnjnM38UHyVDUyaUYtmp20s9KBGOdbPVsX-H-ET11E,1636
|
@@ -117,8 +117,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
|
|
117
117
|
langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
|
118
118
|
langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
|
119
119
|
langfun/core/templates/selfplay_test.py,sha256=rBW2Qr8yi-aWYwoTwRR-n1peKyMX9QXPZXURjLgoiRs,2264
|
120
|
-
langfun-0.1.1.
|
121
|
-
langfun-0.1.1.
|
122
|
-
langfun-0.1.1.
|
123
|
-
langfun-0.1.1.
|
124
|
-
langfun-0.1.1.
|
120
|
+
langfun-0.1.1.dev20240821.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
121
|
+
langfun-0.1.1.dev20240821.dist-info/METADATA,sha256=469KPCsIx2U_ZtMDN0qA4UTOnbcVayQyduyUs65ccVE,5234
|
122
|
+
langfun-0.1.1.dev20240821.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
|
123
|
+
langfun-0.1.1.dev20240821.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
124
|
+
langfun-0.1.1.dev20240821.dist-info/RECORD,,
|
File without changes
|
File without changes
|