langfun 0.0.2.dev20240413__py3-none-any.whl → 0.0.2.dev20240415__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/__init__.py +1 -0
- langfun/core/eval/base.py +1 -0
- langfun/core/eval/base_test.py +1 -0
- langfun/core/langfunc_test.py +4 -2
- langfun/core/language_model.py +15 -0
- langfun/core/language_model_test.py +73 -20
- langfun/core/llms/cache/in_memory_test.py +13 -4
- langfun/core/llms/fake.py +31 -22
- langfun/core/llms/fake_test.py +34 -7
- langfun/core/llms/openai.py +8 -21
- langfun/core/llms/openai_test.py +84 -44
- langfun/core/structured/completion_test.py +1 -0
- langfun/core/structured/parsing_test.py +16 -9
- langfun/core/structured/prompting_test.py +1 -0
- langfun/core/templates/selfplay_test.py +6 -2
- {langfun-0.0.2.dev20240413.dist-info → langfun-0.0.2.dev20240415.dist-info}/METADATA +1 -1
- {langfun-0.0.2.dev20240413.dist-info → langfun-0.0.2.dev20240415.dist-info}/RECORD +20 -20
- {langfun-0.0.2.dev20240413.dist-info → langfun-0.0.2.dev20240415.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240413.dist-info → langfun-0.0.2.dev20240415.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240413.dist-info → langfun-0.0.2.dev20240415.dist-info}/top_level.txt +0 -0
langfun/core/__init__.py
CHANGED
@@ -99,6 +99,7 @@ from langfun.core.modality import ModalityRef
|
|
99
99
|
from langfun.core.language_model import LanguageModel
|
100
100
|
from langfun.core.language_model import LMSample
|
101
101
|
from langfun.core.language_model import LMSamplingOptions
|
102
|
+
from langfun.core.language_model import LMSamplingUsage
|
102
103
|
from langfun.core.language_model import LMSamplingResult
|
103
104
|
from langfun.core.language_model import LMScoringResult
|
104
105
|
from langfun.core.language_model import LMCache
|
langfun/core/eval/base.py
CHANGED
langfun/core/eval/base_test.py
CHANGED
langfun/core/langfunc_test.py
CHANGED
@@ -82,7 +82,9 @@ class LangFuncCallTest(unittest.TestCase):
|
|
82
82
|
self.assertEqual(i.tags, ['rendered'])
|
83
83
|
|
84
84
|
r = l()
|
85
|
-
self.assertEqual(
|
85
|
+
self.assertEqual(
|
86
|
+
r, message.AIMessage('Hello!!!', score=0.0, logprobs=None, usage=None)
|
87
|
+
)
|
86
88
|
self.assertEqual(r.tags, ['lm-response', 'lm-output'])
|
87
89
|
self.assertEqual(r.source, message.UserMessage('Hello'))
|
88
90
|
self.assertEqual(r.source.tags, ['rendered', 'lm-input'])
|
@@ -106,7 +108,7 @@ class LangFuncCallTest(unittest.TestCase):
|
|
106
108
|
self.assertEqual(l.render(), 'Hello')
|
107
109
|
r = l()
|
108
110
|
self.assertEqual(
|
109
|
-
r, message.AIMessage('Hello!!!', score=0.0, logprobs=None)
|
111
|
+
r, message.AIMessage('Hello!!!', score=0.0, logprobs=None, usage=None)
|
110
112
|
)
|
111
113
|
self.assertEqual(r.tags, ['lm-response', 'lm-output'])
|
112
114
|
|
langfun/core/language_model.py
CHANGED
@@ -47,6 +47,14 @@ class LMSample(pg.Object):
|
|
47
47
|
] = None
|
48
48
|
|
49
49
|
|
50
|
+
class LMSamplingUsage(pg.Object):
|
51
|
+
"""Usage information per completion."""
|
52
|
+
|
53
|
+
prompt_tokens: int
|
54
|
+
completion_tokens: int
|
55
|
+
total_tokens: int
|
56
|
+
|
57
|
+
|
50
58
|
class LMSamplingResult(pg.Object):
|
51
59
|
"""Language model response."""
|
52
60
|
|
@@ -58,6 +66,11 @@ class LMSamplingResult(pg.Object):
|
|
58
66
|
),
|
59
67
|
] = []
|
60
68
|
|
69
|
+
usage: Annotated[
|
70
|
+
LMSamplingUsage | None,
|
71
|
+
'Usage information. Currently only OpenAI models are supported.',
|
72
|
+
] = None
|
73
|
+
|
61
74
|
|
62
75
|
class LMSamplingOptions(component.Component):
|
63
76
|
"""Language model sampling options."""
|
@@ -424,6 +437,8 @@ class LanguageModel(component.Component):
|
|
424
437
|
logprobs = result.samples[0].logprobs
|
425
438
|
response.set('score', result.samples[0].score)
|
426
439
|
response.metadata.logprobs = logprobs
|
440
|
+
response.metadata.usage = result.usage
|
441
|
+
|
427
442
|
elapse = time.time() - request_start
|
428
443
|
self._debug(prompt, response, call_counter, elapse)
|
429
444
|
return response
|
@@ -38,9 +38,19 @@ class MockModel(lm_lib.LanguageModel):
|
|
38
38
|
def fake_sample(prompts):
|
39
39
|
if context.attempt >= self.failures_before_attempt:
|
40
40
|
return [
|
41
|
-
lm_lib.LMSamplingResult(
|
42
|
-
|
43
|
-
|
41
|
+
lm_lib.LMSamplingResult(
|
42
|
+
[
|
43
|
+
lm_lib.LMSample( # pylint: disable=g-complex-comprehension
|
44
|
+
response=prompt.text * self.sampling_options.top_k,
|
45
|
+
score=self.sampling_options.temperature or -1.0,
|
46
|
+
)
|
47
|
+
],
|
48
|
+
usage=lm_lib.LMSamplingUsage(
|
49
|
+
prompt_tokens=100,
|
50
|
+
completion_tokens=100,
|
51
|
+
total_tokens=200,
|
52
|
+
),
|
53
|
+
)
|
44
54
|
for prompt in prompts
|
45
55
|
]
|
46
56
|
context.attempt += 1
|
@@ -100,8 +110,14 @@ class LanguageModelTest(unittest.TestCase):
|
|
100
110
|
self.assertEqual(
|
101
111
|
lm.sample(prompts=['foo', 'bar']),
|
102
112
|
[
|
103
|
-
lm_lib.LMSamplingResult(
|
104
|
-
|
113
|
+
lm_lib.LMSamplingResult(
|
114
|
+
[lm_lib.LMSample('foo', score=-1.0)],
|
115
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
116
|
+
),
|
117
|
+
lm_lib.LMSamplingResult(
|
118
|
+
[lm_lib.LMSample('bar', score=-1.0)],
|
119
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
120
|
+
),
|
105
121
|
],
|
106
122
|
)
|
107
123
|
# Test override sampling_options.
|
@@ -112,10 +128,12 @@ class LanguageModelTest(unittest.TestCase):
|
|
112
128
|
),
|
113
129
|
[
|
114
130
|
lm_lib.LMSamplingResult(
|
115
|
-
[lm_lib.LMSample('foo' * 2, score=0.5)]
|
131
|
+
[lm_lib.LMSample('foo' * 2, score=0.5)],
|
132
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
116
133
|
),
|
117
134
|
lm_lib.LMSamplingResult(
|
118
|
-
[lm_lib.LMSample('bar' * 2, score=0.5)]
|
135
|
+
[lm_lib.LMSample('bar' * 2, score=0.5)],
|
136
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
119
137
|
),
|
120
138
|
],
|
121
139
|
)
|
@@ -123,18 +141,26 @@ class LanguageModelTest(unittest.TestCase):
|
|
123
141
|
self.assertEqual(
|
124
142
|
lm.sample(prompts=['foo', 'bar'], temperature=1.0),
|
125
143
|
[
|
126
|
-
lm_lib.LMSamplingResult(
|
127
|
-
|
144
|
+
lm_lib.LMSamplingResult(
|
145
|
+
[lm_lib.LMSample('foo', score=1.0)],
|
146
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
147
|
+
),
|
148
|
+
lm_lib.LMSamplingResult(
|
149
|
+
[lm_lib.LMSample('bar', score=1.0)],
|
150
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
151
|
+
),
|
128
152
|
],
|
129
153
|
)
|
130
154
|
self.assertEqual(
|
131
155
|
lm.sample(prompts=['foo', 'bar'], top_k=2, temperature=0.7),
|
132
156
|
[
|
133
157
|
lm_lib.LMSamplingResult(
|
134
|
-
[lm_lib.LMSample('foo' * 2, score=0.7)]
|
158
|
+
[lm_lib.LMSample('foo' * 2, score=0.7)],
|
159
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
135
160
|
),
|
136
161
|
lm_lib.LMSamplingResult(
|
137
|
-
[lm_lib.LMSample('bar' * 2, score=0.7)]
|
162
|
+
[lm_lib.LMSample('bar' * 2, score=0.7)],
|
163
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
138
164
|
),
|
139
165
|
],
|
140
166
|
)
|
@@ -144,6 +170,8 @@ class LanguageModelTest(unittest.TestCase):
|
|
144
170
|
response = lm(prompt='foo')
|
145
171
|
self.assertEqual(response.text, 'foo')
|
146
172
|
self.assertEqual(response.score, -1.0)
|
173
|
+
self.assertIsNone(response.logprobs)
|
174
|
+
self.assertEqual(response.usage, lm_lib.LMSamplingUsage(100, 100, 200))
|
147
175
|
|
148
176
|
# Test override sampling_options.
|
149
177
|
self.assertEqual(
|
@@ -158,11 +186,24 @@ class LanguageModelTest(unittest.TestCase):
|
|
158
186
|
self.assertEqual(
|
159
187
|
lm.sample(prompts=['foo', 'bar']),
|
160
188
|
[
|
161
|
-
lm_lib.LMSamplingResult(
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
189
|
+
lm_lib.LMSamplingResult(
|
190
|
+
[
|
191
|
+
lm_lib.LMSample(
|
192
|
+
message_lib.AIMessage('foo', cache_seed=0), score=-1.0
|
193
|
+
)
|
194
|
+
],
|
195
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
196
|
+
),
|
197
|
+
lm_lib.LMSamplingResult(
|
198
|
+
[
|
199
|
+
lm_lib.LMSample(
|
200
|
+
message_lib.AIMessage('bar', cache_seed=0), score=-1.0
|
201
|
+
)
|
202
|
+
],
|
203
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
204
|
+
),
|
205
|
+
],
|
206
|
+
)
|
166
207
|
self.assertEqual(cache.stats.num_queries, 2)
|
167
208
|
self.assertEqual(cache.stats.num_hits, 0)
|
168
209
|
self.assertEqual(cache.stats.num_updates, 2)
|
@@ -181,10 +222,22 @@ class LanguageModelTest(unittest.TestCase):
|
|
181
222
|
self.assertEqual(
|
182
223
|
lm.sample(prompts=['foo', 'baz'], temperature=1.0),
|
183
224
|
[
|
184
|
-
lm_lib.LMSamplingResult(
|
185
|
-
|
186
|
-
|
187
|
-
|
225
|
+
lm_lib.LMSamplingResult(
|
226
|
+
[
|
227
|
+
lm_lib.LMSample(
|
228
|
+
message_lib.AIMessage('foo', cache_seed=0), score=1.0
|
229
|
+
)
|
230
|
+
],
|
231
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
232
|
+
),
|
233
|
+
lm_lib.LMSamplingResult(
|
234
|
+
[
|
235
|
+
lm_lib.LMSample(
|
236
|
+
message_lib.AIMessage('baz', cache_seed=0), score=1.0
|
237
|
+
)
|
238
|
+
],
|
239
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
240
|
+
),
|
188
241
|
],
|
189
242
|
)
|
190
243
|
self.assertEqual(cache.stats.num_queries, 6)
|
@@ -62,10 +62,19 @@ class InMemoryLMCacheTest(unittest.TestCase):
|
|
62
62
|
|
63
63
|
def cache_entry(response_text, cache_seed=0):
|
64
64
|
return base.LMCacheEntry(
|
65
|
-
lf.LMSamplingResult(
|
66
|
-
|
67
|
-
lf.
|
68
|
-
|
65
|
+
lf.LMSamplingResult(
|
66
|
+
[
|
67
|
+
lf.LMSample(
|
68
|
+
lf.AIMessage(response_text, cache_seed=cache_seed),
|
69
|
+
score=1.0
|
70
|
+
)
|
71
|
+
],
|
72
|
+
usage=lf.LMSamplingUsage(
|
73
|
+
1,
|
74
|
+
len(response_text),
|
75
|
+
len(response_text) + 1,
|
76
|
+
)
|
77
|
+
)
|
69
78
|
)
|
70
79
|
|
71
80
|
self.assertEqual(
|
langfun/core/llms/fake.py
CHANGED
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Fake LMs for testing."""
|
15
15
|
|
16
|
+
import abc
|
16
17
|
from typing import Annotated
|
17
18
|
import langfun.core as lf
|
18
19
|
|
@@ -23,15 +24,32 @@ class Fake(lf.LanguageModel):
|
|
23
24
|
def _score(self, prompt: lf.Message, completions: list[lf.Message]):
|
24
25
|
return [lf.LMScoringResult(score=-i * 1.0) for i in range(len(completions))]
|
25
26
|
|
27
|
+
def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
|
28
|
+
results = []
|
29
|
+
for prompt in prompts:
|
30
|
+
response = self._response_from(prompt)
|
31
|
+
results.append(
|
32
|
+
lf.LMSamplingResult(
|
33
|
+
[lf.LMSample(response, 1.0)],
|
34
|
+
usage=lf.LMSamplingUsage(
|
35
|
+
prompt_tokens=len(prompt.text),
|
36
|
+
completion_tokens=len(response.text),
|
37
|
+
total_tokens=len(prompt.text) + len(response.text),
|
38
|
+
)
|
39
|
+
)
|
40
|
+
)
|
41
|
+
return results
|
42
|
+
|
43
|
+
@abc.abstractmethod
|
44
|
+
def _response_from(self, prompt: lf.Message) -> lf.Message:
|
45
|
+
"""Returns the response for the given prompt."""
|
46
|
+
|
26
47
|
|
27
48
|
class Echo(Fake):
|
28
49
|
"""A simple echo language model for testing."""
|
29
50
|
|
30
|
-
def
|
31
|
-
return
|
32
|
-
lf.LMSamplingResult([lf.LMSample(prompt.text, 1.0)])
|
33
|
-
for prompt in prompts
|
34
|
-
]
|
51
|
+
def _response_from(self, prompt: lf.Message) -> lf.Message:
|
52
|
+
return lf.AIMessage(prompt.text)
|
35
53
|
|
36
54
|
|
37
55
|
@lf.use_init_args(['response'])
|
@@ -43,11 +61,8 @@ class StaticResponse(Fake):
|
|
43
61
|
'A canned response that will be returned regardless of the prompt.'
|
44
62
|
]
|
45
63
|
|
46
|
-
def
|
47
|
-
return
|
48
|
-
lf.LMSamplingResult([lf.LMSample(self.response, 1.0)])
|
49
|
-
for _ in prompts
|
50
|
-
]
|
64
|
+
def _response_from(self, prompt: lf.Message) -> lf.Message:
|
65
|
+
return lf.AIMessage(self.response)
|
51
66
|
|
52
67
|
|
53
68
|
@lf.use_init_args(['mapping'])
|
@@ -59,11 +74,8 @@ class StaticMapping(Fake):
|
|
59
74
|
'A mapping from prompt to response.'
|
60
75
|
]
|
61
76
|
|
62
|
-
def
|
63
|
-
return [
|
64
|
-
lf.LMSamplingResult([lf.LMSample(self.mapping[prompt], 1.0)])
|
65
|
-
for prompt in prompts
|
66
|
-
]
|
77
|
+
def _response_from(self, prompt: lf.Message) -> lf.Message:
|
78
|
+
return lf.AIMessage(self.mapping[prompt])
|
67
79
|
|
68
80
|
|
69
81
|
@lf.use_init_args(['sequence'])
|
@@ -79,10 +91,7 @@ class StaticSequence(Fake):
|
|
79
91
|
super()._on_bound()
|
80
92
|
self._pos = 0
|
81
93
|
|
82
|
-
def
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
[lf.LMSample(self.sequence[self._pos], 1.0)]))
|
87
|
-
self._pos += 1
|
88
|
-
return results
|
94
|
+
def _response_from(self, prompt: lf.Message) -> lf.Message:
|
95
|
+
r = lf.AIMessage(self.sequence[self._pos])
|
96
|
+
self._pos += 1
|
97
|
+
return r
|
langfun/core/llms/fake_test.py
CHANGED
@@ -25,7 +25,12 @@ class EchoTest(unittest.TestCase):
|
|
25
25
|
def test_sample(self):
|
26
26
|
lm = fakelm.Echo()
|
27
27
|
self.assertEqual(
|
28
|
-
lm.sample(['hi']),
|
28
|
+
lm.sample(['hi']),
|
29
|
+
[
|
30
|
+
lf.LMSamplingResult(
|
31
|
+
[lf.LMSample('hi', 1.0)],
|
32
|
+
lf.LMSamplingUsage(2, 2, 4))
|
33
|
+
]
|
29
34
|
)
|
30
35
|
|
31
36
|
def test_call(self):
|
@@ -53,11 +58,21 @@ class StaticResponseTest(unittest.TestCase):
|
|
53
58
|
lm = fakelm.StaticResponse(canned_response)
|
54
59
|
self.assertEqual(
|
55
60
|
lm.sample(['hi']),
|
56
|
-
[
|
61
|
+
[
|
62
|
+
lf.LMSamplingResult(
|
63
|
+
[lf.LMSample(canned_response, 1.0)],
|
64
|
+
usage=lf.LMSamplingUsage(2, 38, 40)
|
65
|
+
)
|
66
|
+
],
|
57
67
|
)
|
58
68
|
self.assertEqual(
|
59
69
|
lm.sample(['Tell me a joke.']),
|
60
|
-
[
|
70
|
+
[
|
71
|
+
lf.LMSamplingResult(
|
72
|
+
[lf.LMSample(canned_response, 1.0)],
|
73
|
+
usage=lf.LMSamplingUsage(15, 38, 53)
|
74
|
+
)
|
75
|
+
],
|
61
76
|
)
|
62
77
|
|
63
78
|
def test_call(self):
|
@@ -85,8 +100,14 @@ class StaticMappingTest(unittest.TestCase):
|
|
85
100
|
self.assertEqual(
|
86
101
|
lm.sample(['Hi', 'How are you?']),
|
87
102
|
[
|
88
|
-
lf.LMSamplingResult(
|
89
|
-
|
103
|
+
lf.LMSamplingResult(
|
104
|
+
[lf.LMSample('Hello', 1.0)],
|
105
|
+
usage=lf.LMSamplingUsage(2, 5, 7)
|
106
|
+
),
|
107
|
+
lf.LMSamplingResult(
|
108
|
+
[lf.LMSample('I am fine, how about you?', 1.0)],
|
109
|
+
usage=lf.LMSamplingUsage(12, 25, 37)
|
110
|
+
)
|
90
111
|
]
|
91
112
|
)
|
92
113
|
with self.assertRaises(KeyError):
|
@@ -104,8 +125,14 @@ class StaticSequenceTest(unittest.TestCase):
|
|
104
125
|
self.assertEqual(
|
105
126
|
lm.sample(['Hi', 'How are you?']),
|
106
127
|
[
|
107
|
-
lf.LMSamplingResult(
|
108
|
-
|
128
|
+
lf.LMSamplingResult(
|
129
|
+
[lf.LMSample('Hello', 1.0)],
|
130
|
+
usage=lf.LMSamplingUsage(2, 5, 7)
|
131
|
+
),
|
132
|
+
lf.LMSamplingResult(
|
133
|
+
[lf.LMSample('I am fine, how about you?', 1.0)],
|
134
|
+
usage=lf.LMSamplingUsage(12, 25, 37)
|
135
|
+
)
|
109
136
|
]
|
110
137
|
)
|
111
138
|
with self.assertRaises(IndexError):
|
langfun/core/llms/openai.py
CHANGED
@@ -26,20 +26,6 @@ from openai import openai_object
|
|
26
26
|
import pyglove as pg
|
27
27
|
|
28
28
|
|
29
|
-
class Usage(pg.Object):
|
30
|
-
"""Usage information per completion."""
|
31
|
-
|
32
|
-
prompt_tokens: int
|
33
|
-
completion_tokens: int
|
34
|
-
total_tokens: int
|
35
|
-
|
36
|
-
|
37
|
-
class LMSamplingResult(lf.LMSamplingResult):
|
38
|
-
"""LMSamplingResult with usage information."""
|
39
|
-
|
40
|
-
usage: Usage | None = None
|
41
|
-
|
42
|
-
|
43
29
|
SUPPORTED_MODELS_AND_SETTINGS = [
|
44
30
|
# Model name, max concurrent requests.
|
45
31
|
# The concurrent requests is estimated by TPM/RPM from
|
@@ -181,7 +167,7 @@ class OpenAI(lf.LanguageModel):
|
|
181
167
|
args['stop'] = options.stop
|
182
168
|
return args
|
183
169
|
|
184
|
-
def _sample(self, prompts: list[lf.Message]) -> list[LMSamplingResult]:
|
170
|
+
def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
|
185
171
|
assert self._api_initialized
|
186
172
|
if self.is_chat_model:
|
187
173
|
return self._chat_complete_batch(prompts)
|
@@ -189,7 +175,8 @@ class OpenAI(lf.LanguageModel):
|
|
189
175
|
return self._complete_batch(prompts)
|
190
176
|
|
191
177
|
def _complete_batch(
|
192
|
-
self, prompts: list[lf.Message]
|
178
|
+
self, prompts: list[lf.Message]
|
179
|
+
) -> list[lf.LMSamplingResult]:
|
193
180
|
|
194
181
|
def _open_ai_completion(prompts):
|
195
182
|
response = openai.Completion.create(
|
@@ -204,13 +191,13 @@ class OpenAI(lf.LanguageModel):
|
|
204
191
|
lf.LMSample(choice.text.strip(), score=choice.logprobs or 0.0)
|
205
192
|
)
|
206
193
|
|
207
|
-
usage =
|
194
|
+
usage = lf.LMSamplingUsage(
|
208
195
|
prompt_tokens=response.usage.prompt_tokens,
|
209
196
|
completion_tokens=response.usage.completion_tokens,
|
210
197
|
total_tokens=response.usage.total_tokens,
|
211
198
|
)
|
212
199
|
return [
|
213
|
-
LMSamplingResult(
|
200
|
+
lf.LMSamplingResult(
|
214
201
|
samples_by_index[index], usage=usage if index == 0 else None
|
215
202
|
)
|
216
203
|
for index in sorted(samples_by_index.keys())
|
@@ -231,7 +218,7 @@ class OpenAI(lf.LanguageModel):
|
|
231
218
|
|
232
219
|
def _chat_complete_batch(
|
233
220
|
self, prompts: list[lf.Message]
|
234
|
-
) -> list[LMSamplingResult]:
|
221
|
+
) -> list[lf.LMSamplingResult]:
|
235
222
|
def _open_ai_chat_completion(prompt: lf.Message):
|
236
223
|
if self.multimodal:
|
237
224
|
content = []
|
@@ -272,9 +259,9 @@ class OpenAI(lf.LanguageModel):
|
|
272
259
|
)
|
273
260
|
)
|
274
261
|
|
275
|
-
return LMSamplingResult(
|
262
|
+
return lf.LMSamplingResult(
|
276
263
|
samples=samples,
|
277
|
-
usage=
|
264
|
+
usage=lf.LMSamplingUsage(
|
278
265
|
prompt_tokens=response.usage.prompt_tokens,
|
279
266
|
completion_tokens=response.usage.completion_tokens,
|
280
267
|
total_tokens=response.usage.total_tokens,
|
langfun/core/llms/openai_test.py
CHANGED
@@ -32,11 +32,14 @@ def mock_completion_query(prompt, *, n=1, **kwargs):
|
|
32
32
|
text=f'Sample {k} for prompt {i}.',
|
33
33
|
logprobs=k / 10,
|
34
34
|
))
|
35
|
-
return pg.Dict(
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
35
|
+
return pg.Dict(
|
36
|
+
choices=choices,
|
37
|
+
usage=lf.LMSamplingUsage(
|
38
|
+
prompt_tokens=100,
|
39
|
+
completion_tokens=100,
|
40
|
+
total_tokens=200,
|
41
|
+
),
|
42
|
+
)
|
40
43
|
|
41
44
|
|
42
45
|
def mock_chat_completion_query(messages, *, n=1, **kwargs):
|
@@ -49,11 +52,14 @@ def mock_chat_completion_query(messages, *, n=1, **kwargs):
|
|
49
52
|
),
|
50
53
|
logprobs=None,
|
51
54
|
))
|
52
|
-
return pg.Dict(
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
55
|
+
return pg.Dict(
|
56
|
+
choices=choices,
|
57
|
+
usage=lf.LMSamplingUsage(
|
58
|
+
prompt_tokens=100,
|
59
|
+
completion_tokens=100,
|
60
|
+
total_tokens=200,
|
61
|
+
),
|
62
|
+
)
|
57
63
|
|
58
64
|
|
59
65
|
def mock_chat_completion_query_vision(messages, *, n=1, **kwargs):
|
@@ -69,11 +75,14 @@ def mock_chat_completion_query_vision(messages, *, n=1, **kwargs):
|
|
69
75
|
),
|
70
76
|
logprobs=None,
|
71
77
|
))
|
72
|
-
return pg.Dict(
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
78
|
+
return pg.Dict(
|
79
|
+
choices=choices,
|
80
|
+
usage=lf.LMSamplingUsage(
|
81
|
+
prompt_tokens=100,
|
82
|
+
completion_tokens=100,
|
83
|
+
total_tokens=200,
|
84
|
+
),
|
85
|
+
)
|
77
86
|
|
78
87
|
|
79
88
|
class OpenaiTest(unittest.TestCase):
|
@@ -169,18 +178,28 @@ class OpenaiTest(unittest.TestCase):
|
|
169
178
|
)
|
170
179
|
|
171
180
|
self.assertEqual(len(results), 2)
|
172
|
-
self.assertEqual(
|
173
|
-
|
174
|
-
lf.
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
181
|
+
self.assertEqual(
|
182
|
+
results[0],
|
183
|
+
lf.LMSamplingResult(
|
184
|
+
[
|
185
|
+
lf.LMSample('Sample 0 for prompt 0.', score=0.0),
|
186
|
+
lf.LMSample('Sample 1 for prompt 0.', score=0.1),
|
187
|
+
lf.LMSample('Sample 2 for prompt 0.', score=0.2),
|
188
|
+
],
|
189
|
+
usage=lf.LMSamplingUsage(
|
190
|
+
prompt_tokens=100, completion_tokens=100, total_tokens=200
|
191
|
+
),
|
192
|
+
),
|
193
|
+
)
|
194
|
+
|
195
|
+
self.assertEqual(
|
196
|
+
results[1],
|
197
|
+
lf.LMSamplingResult([
|
198
|
+
lf.LMSample('Sample 0 for prompt 1.', score=0.0),
|
199
|
+
lf.LMSample('Sample 1 for prompt 1.', score=0.1),
|
200
|
+
lf.LMSample('Sample 2 for prompt 1.', score=0.2),
|
201
|
+
]),
|
202
|
+
)
|
184
203
|
|
185
204
|
def test_sample_chat_completion(self):
|
186
205
|
with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
|
@@ -191,18 +210,32 @@ class OpenaiTest(unittest.TestCase):
|
|
191
210
|
)
|
192
211
|
|
193
212
|
self.assertEqual(len(results), 2)
|
194
|
-
self.assertEqual(
|
195
|
-
|
196
|
-
lf.
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
213
|
+
self.assertEqual(
|
214
|
+
results[0],
|
215
|
+
lf.LMSamplingResult(
|
216
|
+
[
|
217
|
+
lf.LMSample('Sample 0 for message.', score=0.0),
|
218
|
+
lf.LMSample('Sample 1 for message.', score=0.0),
|
219
|
+
lf.LMSample('Sample 2 for message.', score=0.0),
|
220
|
+
],
|
221
|
+
usage=lf.LMSamplingUsage(
|
222
|
+
prompt_tokens=100, completion_tokens=100, total_tokens=200
|
223
|
+
),
|
224
|
+
),
|
225
|
+
)
|
226
|
+
self.assertEqual(
|
227
|
+
results[1],
|
228
|
+
lf.LMSamplingResult(
|
229
|
+
[
|
230
|
+
lf.LMSample('Sample 0 for message.', score=0.0),
|
231
|
+
lf.LMSample('Sample 1 for message.', score=0.0),
|
232
|
+
lf.LMSample('Sample 2 for message.', score=0.0),
|
233
|
+
],
|
234
|
+
usage=lf.LMSamplingUsage(
|
235
|
+
prompt_tokens=100, completion_tokens=100, total_tokens=200
|
236
|
+
),
|
237
|
+
),
|
238
|
+
)
|
206
239
|
|
207
240
|
def test_sample_with_contextual_options(self):
|
208
241
|
with mock.patch('openai.Completion.create') as mock_completion:
|
@@ -212,11 +245,18 @@ class OpenaiTest(unittest.TestCase):
|
|
212
245
|
results = lm.sample(['hello'])
|
213
246
|
|
214
247
|
self.assertEqual(len(results), 1)
|
215
|
-
self.assertEqual(
|
216
|
-
|
217
|
-
lf.
|
218
|
-
|
219
|
-
|
248
|
+
self.assertEqual(
|
249
|
+
results[0],
|
250
|
+
lf.LMSamplingResult(
|
251
|
+
[
|
252
|
+
lf.LMSample('Sample 0 for prompt 0.', score=0.0),
|
253
|
+
lf.LMSample('Sample 1 for prompt 0.', score=0.1),
|
254
|
+
],
|
255
|
+
usage=lf.LMSamplingUsage(
|
256
|
+
prompt_tokens=100, completion_tokens=100, total_tokens=200
|
257
|
+
),
|
258
|
+
),
|
259
|
+
)
|
220
260
|
|
221
261
|
|
222
262
|
if __name__ == '__main__':
|
@@ -280,13 +280,15 @@ class ParseStructurePythonTest(unittest.TestCase):
|
|
280
280
|
),
|
281
281
|
1,
|
282
282
|
)
|
283
|
+
r = parsing.parse(
|
284
|
+
'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm,
|
285
|
+
returns_message=True
|
286
|
+
)
|
283
287
|
self.assertEqual(
|
284
|
-
|
285
|
-
'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm,
|
286
|
-
returns_message=True
|
287
|
-
),
|
288
|
+
r,
|
288
289
|
lf.AIMessage(
|
289
290
|
'1', score=1.0, result=1, logprobs=None,
|
291
|
+
usage=lf.LMSamplingUsage(652, 1, 653),
|
290
292
|
tags=['lm-response', 'lm-output', 'transformed']
|
291
293
|
),
|
292
294
|
)
|
@@ -634,13 +636,18 @@ class CallTest(unittest.TestCase):
|
|
634
636
|
)
|
635
637
|
|
636
638
|
def test_call_with_returning_message(self):
|
639
|
+
r = parsing.call(
|
640
|
+
'Compute 1 + 2', int, lm=fake.StaticSequence(['three', '3']),
|
641
|
+
returns_message=True
|
642
|
+
)
|
637
643
|
self.assertEqual(
|
638
|
-
|
639
|
-
'Compute 1 + 2', int, lm=fake.StaticSequence(['three', '3']),
|
640
|
-
returns_message=True
|
641
|
-
),
|
644
|
+
r,
|
642
645
|
lf.AIMessage(
|
643
|
-
'3',
|
646
|
+
'3',
|
647
|
+
result=3,
|
648
|
+
score=1.0,
|
649
|
+
logprobs=None,
|
650
|
+
usage=lf.LMSamplingUsage(315, 1, 316),
|
644
651
|
tags=['lm-response', 'lm-output', 'transformed']
|
645
652
|
),
|
646
653
|
)
|
@@ -56,7 +56,9 @@ class SelfPlayTest(unittest.TestCase):
|
|
56
56
|
g = NumberGuess(target_num=10)
|
57
57
|
|
58
58
|
with lf.context(lm=NumberGuesser(guesses=[50, 20, 5, 10])):
|
59
|
-
self.assertEqual(
|
59
|
+
self.assertEqual(
|
60
|
+
g(), lf.AIMessage('10', score=0.0, logprobs=None, usage=None)
|
61
|
+
)
|
60
62
|
|
61
63
|
self.assertEqual(g.num_turns, 4)
|
62
64
|
|
@@ -64,7 +66,9 @@ class SelfPlayTest(unittest.TestCase):
|
|
64
66
|
g = NumberGuess(target_num=10, max_turns=10)
|
65
67
|
|
66
68
|
with lf.context(lm=NumberGuesser(guesses=[50, 20, 5, 2, 5, 4])):
|
67
|
-
self.assertEqual(
|
69
|
+
self.assertEqual(
|
70
|
+
g(), lf.AIMessage('2', score=0.0, logprobs=None, usage=None)
|
71
|
+
)
|
68
72
|
|
69
73
|
self.assertEqual(g.num_turns, 10)
|
70
74
|
|
@@ -1,5 +1,5 @@
|
|
1
1
|
langfun/__init__.py,sha256=PqX3u18BC0szYIMu00j-RKxvwkNPwXtAFZ-96oxrQ0M,1841
|
2
|
-
langfun/core/__init__.py,sha256=
|
2
|
+
langfun/core/__init__.py,sha256=6QEuXOZ9BXxm6TjpaMXuLwUBTYO3pkFDqn9QVBXyyPQ,4248
|
3
3
|
langfun/core/component.py,sha256=VRPfDB_2jEnxcB3-HoiVjG4ID-SMenNPIsytb0uXMPg,9674
|
4
4
|
langfun/core/component_test.py,sha256=VAPd6V_-odAe8rBvesW3ogYDd6OSqRq4FaPhfgOM4Zg,7949
|
5
5
|
langfun/core/concurrent.py,sha256=TRc49pJ3HQro2kb5FtcWkHjhBm8UcgE8RJybU5cU3-0,24537
|
@@ -7,9 +7,9 @@ langfun/core/concurrent_test.py,sha256=mwFMZhDUdppnDr7vDSTwcbMHwrdsIoKJwRYNtl4ZW
|
|
7
7
|
langfun/core/console.py,sha256=bk5rNPNm9rMGW5YT2HixxU04p2umnoabn5SDz6Dqe88,2317
|
8
8
|
langfun/core/console_test.py,sha256=5SYJdxpJGLgdSSQqqMPoA1X6jpsLD8rgcyk-EgI65oE,1077
|
9
9
|
langfun/core/langfunc.py,sha256=WXdTc3QsmGD_n80KD9dFRr5MHpGZ9E_y_Rhtk4t9-3w,11852
|
10
|
-
langfun/core/langfunc_test.py,sha256=
|
11
|
-
langfun/core/language_model.py,sha256=
|
12
|
-
langfun/core/language_model_test.py,sha256=
|
10
|
+
langfun/core/langfunc_test.py,sha256=sQaKuZpGGmG80GRifhbxkj7nfzQLJKj4Vuw5y1s1K3U,8378
|
11
|
+
langfun/core/language_model.py,sha256=Tzswu0hyXOQOZ3fZ_Mz_Cc0ei7tVj8rTay9jJEgM6mI,17510
|
12
|
+
langfun/core/language_model_test.py,sha256=KvXXOr64TsSs3WkEALCLLZSlz09i7hBiHDOZ_8Eq8_o,13047
|
13
13
|
langfun/core/memory.py,sha256=f-asN1F7Vehgdn_fK84v73GrEUOxRtaW934keutTKjk,2416
|
14
14
|
langfun/core/message.py,sha256=QhvV9t5qaryPcruyxxcXi3gm9QDInkSldwTtK6sVJ3c,15734
|
15
15
|
langfun/core/message_test.py,sha256=Z23pUM5vPnDrYkIIibe2KL73D5HKur_awI0ut_EQFQA,9501
|
@@ -40,25 +40,25 @@ langfun/core/coding/python/parsing_test.py,sha256=9vAWF484kWIm6JZq8NFiMgKUDhXV-d
|
|
40
40
|
langfun/core/coding/python/permissions.py,sha256=1QWGHvzL8MM0Ok_auQ9tURqZHtdOfJaDpBzZ29GUE-c,2544
|
41
41
|
langfun/core/coding/python/permissions_test.py,sha256=w5EDb8QxpxgJyZkojyzVWQvDfg366zn99-g__6TbPQ0,2699
|
42
42
|
langfun/core/eval/__init__.py,sha256=iDA2OcJ3kR6ixZizXIY3N9LsjkaVrfTbSClTiSP8ekY,1291
|
43
|
-
langfun/core/eval/base.py,sha256=
|
44
|
-
langfun/core/eval/base_test.py,sha256=
|
43
|
+
langfun/core/eval/base.py,sha256=TZAmcdRBtzwMG1V3e_NgyJXg7J6dWMdMBrHvBnFuFho,55359
|
44
|
+
langfun/core/eval/base_test.py,sha256=OuuXFW_lX9bGhyd__kvlDSNJVne-5cSlnm-qDhyvOcc,21592
|
45
45
|
langfun/core/eval/matching.py,sha256=aqNlYrlav7YmsB7rUlsdfoi1RLA5CYqn2RGPxRlPc78,9599
|
46
46
|
langfun/core/eval/matching_test.py,sha256=FFHYD7IDuKe5RMjkx74ksukiwUhO5a_SS340JaIPMws,4898
|
47
47
|
langfun/core/eval/scoring.py,sha256=aKeanBJf1yO3Q9JEtgPWoiZk_3M_GiqwXVXX7x_g22w,6172
|
48
48
|
langfun/core/eval/scoring_test.py,sha256=YH1cIxBWtfdKcAV9Fh10vLkV5J-gxk8b6nxW4Z2u5pk,4024
|
49
49
|
langfun/core/llms/__init__.py,sha256=gROJ8AjMq_ebXFcEfsyzYGCS6NsGfzf9d43nLu_TIdw,2504
|
50
|
-
langfun/core/llms/fake.py,sha256=
|
51
|
-
langfun/core/llms/fake_test.py,sha256=
|
50
|
+
langfun/core/llms/fake.py,sha256=b-Xk5IPTbUt-elsyzd_i3n1tqzc_kgETXrEvgJruSMk,2824
|
51
|
+
langfun/core/llms/fake_test.py,sha256=AThvNyhZbkpsn-YO798uLgqB6TSw5XP2SKpKvcXEytw,4188
|
52
52
|
langfun/core/llms/google_genai.py,sha256=n8zyJwh9UCTgb6-8LyvmjVNFGZQ4-zfzZ0ulkhHAnR8,8624
|
53
53
|
langfun/core/llms/google_genai_test.py,sha256=_UcGTfl16-aDUlEWFC2W2F8y9jPUs53RBYA6MOCpGXw,7525
|
54
54
|
langfun/core/llms/llama_cpp.py,sha256=Y_KkMUf3Xfac49koMUtUslKl3h-HWp3-ntq7Jaa3bdo,2385
|
55
55
|
langfun/core/llms/llama_cpp_test.py,sha256=ZxC6defGd_HX9SFRU9U4cJiQnBKundbOrchbXuC1Z2M,1683
|
56
|
-
langfun/core/llms/openai.py,sha256=
|
57
|
-
langfun/core/llms/openai_test.py,sha256=
|
56
|
+
langfun/core/llms/openai.py,sha256=1EUd8WTI6EpcU_fzD90-4M11RdL9Mj4S9zfrzUZIyGM,11463
|
57
|
+
langfun/core/llms/openai_test.py,sha256=hiByS95g3pXtjB2XfIdVCKiAZDb_-Qirb2_LsSyskpY,8166
|
58
58
|
langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
|
59
59
|
langfun/core/llms/cache/base.py,sha256=cFfYvOIUae842pncqCAsRvqXCk2AnAsRYVx0mcIoAeY,3338
|
60
60
|
langfun/core/llms/cache/in_memory.py,sha256=YfFyJEhLs73cUiB0ZfhMxYpdE8Iuxxw-dvMFwGHTSHw,4742
|
61
|
-
langfun/core/llms/cache/in_memory_test.py,sha256=
|
61
|
+
langfun/core/llms/cache/in_memory_test.py,sha256=D-n26h__rVXQO51WRFhRfq5sw1oifRLx2SvCQWuNEm8,8747
|
62
62
|
langfun/core/memories/__init__.py,sha256=HpghfZ-w1NQqzJXBx8Lz0daRhB2rcy2r9Xm491SBhC4,773
|
63
63
|
langfun/core/memories/conversation_history.py,sha256=c9amD8hCxGFiZuVAzkP0dOMWSp8L90uvwkOejjuBqO0,1835
|
64
64
|
langfun/core/memories/conversation_history_test.py,sha256=AaW8aNoFjxNusanwJDV0r3384Mg0eAweGmPx5DIkM0Y,2052
|
@@ -71,15 +71,15 @@ langfun/core/modalities/video.py,sha256=25M4XsNG5XEWRy57LYT_a6_aMURMPAgC41B3weEX
|
|
71
71
|
langfun/core/modalities/video_test.py,sha256=jYuI2m8S8zDCAVBPEUbbpP205dXAht90A2_PHWo4-r8,2039
|
72
72
|
langfun/core/structured/__init__.py,sha256=SpObW-HKpyKvkLlX8FV5ixz7CRm098j2aGfOguM3AUI,3462
|
73
73
|
langfun/core/structured/completion.py,sha256=skBxt6V_fv2TBUKnzFgnPMbVY8HSYn8sY04MLok2yvs,7299
|
74
|
-
langfun/core/structured/completion_test.py,sha256=
|
74
|
+
langfun/core/structured/completion_test.py,sha256=0FJreSmz0Umsj47dIlOyCjBXUa7janIplXhg1CbLT4U,19301
|
75
75
|
langfun/core/structured/description.py,sha256=SXW4MJvshFjbR-0gw6rE21o6WXq12UlRXawvDBXMZFA,5211
|
76
76
|
langfun/core/structured/description_test.py,sha256=UtZGjSFUaQ6130t1E5tcL7ODu0xIefkapb53TbnqsK8,7362
|
77
77
|
langfun/core/structured/mapping.py,sha256=7JInwZLmQdu7asHhC0vFLJNOCBnY-hrD6v5RQgf-xKk,11020
|
78
78
|
langfun/core/structured/mapping_test.py,sha256=07DDCGbwytQHSMm7fCi5-Ly-JNgdV4ubHZq0wthX4A4,3338
|
79
79
|
langfun/core/structured/parsing.py,sha256=keoVqEfzAbdULh6GawWFsTQzU91MzJXYFZjXGXLaD8g,11492
|
80
|
-
langfun/core/structured/parsing_test.py,sha256=
|
80
|
+
langfun/core/structured/parsing_test.py,sha256=9rUe7ipRhltQv7y8NXgR98lBXhSVKnfRM9TSAyVdxbs,20980
|
81
81
|
langfun/core/structured/prompting.py,sha256=mOmCWNVMnBk4rI7KBlEm5kmusPXoAKiWcohhzaw-s2o,7427
|
82
|
-
langfun/core/structured/prompting_test.py,sha256=
|
82
|
+
langfun/core/structured/prompting_test.py,sha256=csOzqHRp6T3KGp7Dsm0vS-BkZdQ4ALRt09iiFNz_YmA,19945
|
83
83
|
langfun/core/structured/schema.py,sha256=mJXirgqx3N7SA9zBO_ISHrzcV-ZRshLhnMJyCcSjGjY,25057
|
84
84
|
langfun/core/structured/schema_generation.py,sha256=U3nRQsqmMZg_qIVDh2fiY3K4JLfsAL1LcKzIFP1iXFg,5316
|
85
85
|
langfun/core/structured/schema_generation_test.py,sha256=cfZyP0gHno2fXy_c9vsVdvHmqKQSfuyUsCtfO3JFmYQ,2945
|
@@ -94,9 +94,9 @@ langfun/core/templates/conversation_test.py,sha256=RryYyIhfc34dLWOs6GfPQ8HU8mXpK
|
|
94
94
|
langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fikKhwhzwhpKI,1460
|
95
95
|
langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
|
96
96
|
langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
|
97
|
-
langfun/core/templates/selfplay_test.py,sha256=
|
98
|
-
langfun-0.0.2.
|
99
|
-
langfun-0.0.2.
|
100
|
-
langfun-0.0.2.
|
101
|
-
langfun-0.0.2.
|
102
|
-
langfun-0.0.2.
|
97
|
+
langfun/core/templates/selfplay_test.py,sha256=DYVrkk7uNKCqJGEHH31HssU2BPuMItU1vJLzfcXIlYg,2156
|
98
|
+
langfun-0.0.2.dev20240415.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
99
|
+
langfun-0.0.2.dev20240415.dist-info/METADATA,sha256=V_zKk0hFMrBR6jMyr0C0v71Y4RJ9GL9b0uAkBerHIIw,3405
|
100
|
+
langfun-0.0.2.dev20240415.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
101
|
+
langfun-0.0.2.dev20240415.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
102
|
+
langfun-0.0.2.dev20240415.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|