langfun 0.0.2.dev20240414__py3-none-any.whl → 0.0.2.dev20240415__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langfun/core/__init__.py CHANGED
@@ -99,6 +99,7 @@ from langfun.core.modality import ModalityRef
99
99
  from langfun.core.language_model import LanguageModel
100
100
  from langfun.core.language_model import LMSample
101
101
  from langfun.core.language_model import LMSamplingOptions
102
+ from langfun.core.language_model import LMSamplingUsage
102
103
  from langfun.core.language_model import LMSamplingResult
103
104
  from langfun.core.language_model import LMScoringResult
104
105
  from langfun.core.language_model import LMCache
langfun/core/eval/base.py CHANGED
@@ -1565,6 +1565,7 @@ class Summary(pg.Object):
1565
1565
  results.append(
1566
1566
  pg.Dict(
1567
1567
  experiment=entry,
1568
+ dir=entry.dir,
1568
1569
  metrics=entry.result.metrics if entry.result else None,
1569
1570
  )
1570
1571
  )
@@ -194,6 +194,7 @@ class EvaluationTest(unittest.TestCase):
194
194
  cache_seed=0,
195
195
  score=1.0,
196
196
  logprobs=None,
197
+ usage=lf.LMSamplingUsage(387, 24, 411),
197
198
  tags=['lm-response', 'lm-output', 'transformed'],
198
199
  ),
199
200
  )
@@ -82,7 +82,9 @@ class LangFuncCallTest(unittest.TestCase):
82
82
  self.assertEqual(i.tags, ['rendered'])
83
83
 
84
84
  r = l()
85
- self.assertEqual(r, message.AIMessage('Hello!!!', score=0.0, logprobs=None))
85
+ self.assertEqual(
86
+ r, message.AIMessage('Hello!!!', score=0.0, logprobs=None, usage=None)
87
+ )
86
88
  self.assertEqual(r.tags, ['lm-response', 'lm-output'])
87
89
  self.assertEqual(r.source, message.UserMessage('Hello'))
88
90
  self.assertEqual(r.source.tags, ['rendered', 'lm-input'])
@@ -106,7 +108,7 @@ class LangFuncCallTest(unittest.TestCase):
106
108
  self.assertEqual(l.render(), 'Hello')
107
109
  r = l()
108
110
  self.assertEqual(
109
- r, message.AIMessage('Hello!!!', score=0.0, logprobs=None)
111
+ r, message.AIMessage('Hello!!!', score=0.0, logprobs=None, usage=None)
110
112
  )
111
113
  self.assertEqual(r.tags, ['lm-response', 'lm-output'])
112
114
 
@@ -47,6 +47,14 @@ class LMSample(pg.Object):
47
47
  ] = None
48
48
 
49
49
 
50
+ class LMSamplingUsage(pg.Object):
51
+ """Usage information per completion."""
52
+
53
+ prompt_tokens: int
54
+ completion_tokens: int
55
+ total_tokens: int
56
+
57
+
50
58
  class LMSamplingResult(pg.Object):
51
59
  """Language model response."""
52
60
 
@@ -58,6 +66,11 @@ class LMSamplingResult(pg.Object):
58
66
  ),
59
67
  ] = []
60
68
 
69
+ usage: Annotated[
70
+ LMSamplingUsage | None,
71
+ 'Usage information. Currently only OpenAI models are supported.',
72
+ ] = None
73
+
61
74
 
62
75
  class LMSamplingOptions(component.Component):
63
76
  """Language model sampling options."""
@@ -424,6 +437,8 @@ class LanguageModel(component.Component):
424
437
  logprobs = result.samples[0].logprobs
425
438
  response.set('score', result.samples[0].score)
426
439
  response.metadata.logprobs = logprobs
440
+ response.metadata.usage = result.usage
441
+
427
442
  elapse = time.time() - request_start
428
443
  self._debug(prompt, response, call_counter, elapse)
429
444
  return response
@@ -38,9 +38,19 @@ class MockModel(lm_lib.LanguageModel):
38
38
  def fake_sample(prompts):
39
39
  if context.attempt >= self.failures_before_attempt:
40
40
  return [
41
- lm_lib.LMSamplingResult([lm_lib.LMSample( # pylint: disable=g-complex-comprehension
42
- response=prompt.text * self.sampling_options.top_k,
43
- score=self.sampling_options.temperature or -1.0)])
41
+ lm_lib.LMSamplingResult(
42
+ [
43
+ lm_lib.LMSample( # pylint: disable=g-complex-comprehension
44
+ response=prompt.text * self.sampling_options.top_k,
45
+ score=self.sampling_options.temperature or -1.0,
46
+ )
47
+ ],
48
+ usage=lm_lib.LMSamplingUsage(
49
+ prompt_tokens=100,
50
+ completion_tokens=100,
51
+ total_tokens=200,
52
+ ),
53
+ )
44
54
  for prompt in prompts
45
55
  ]
46
56
  context.attempt += 1
@@ -100,8 +110,14 @@ class LanguageModelTest(unittest.TestCase):
100
110
  self.assertEqual(
101
111
  lm.sample(prompts=['foo', 'bar']),
102
112
  [
103
- lm_lib.LMSamplingResult([lm_lib.LMSample('foo', score=-1.0)]),
104
- lm_lib.LMSamplingResult([lm_lib.LMSample('bar', score=-1.0)]),
113
+ lm_lib.LMSamplingResult(
114
+ [lm_lib.LMSample('foo', score=-1.0)],
115
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
116
+ ),
117
+ lm_lib.LMSamplingResult(
118
+ [lm_lib.LMSample('bar', score=-1.0)],
119
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
120
+ ),
105
121
  ],
106
122
  )
107
123
  # Test override sampling_options.
@@ -112,10 +128,12 @@ class LanguageModelTest(unittest.TestCase):
112
128
  ),
113
129
  [
114
130
  lm_lib.LMSamplingResult(
115
- [lm_lib.LMSample('foo' * 2, score=0.5)]
131
+ [lm_lib.LMSample('foo' * 2, score=0.5)],
132
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
116
133
  ),
117
134
  lm_lib.LMSamplingResult(
118
- [lm_lib.LMSample('bar' * 2, score=0.5)]
135
+ [lm_lib.LMSample('bar' * 2, score=0.5)],
136
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
119
137
  ),
120
138
  ],
121
139
  )
@@ -123,18 +141,26 @@ class LanguageModelTest(unittest.TestCase):
123
141
  self.assertEqual(
124
142
  lm.sample(prompts=['foo', 'bar'], temperature=1.0),
125
143
  [
126
- lm_lib.LMSamplingResult([lm_lib.LMSample('foo', score=1.0)]),
127
- lm_lib.LMSamplingResult([lm_lib.LMSample('bar', score=1.0)]),
144
+ lm_lib.LMSamplingResult(
145
+ [lm_lib.LMSample('foo', score=1.0)],
146
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
147
+ ),
148
+ lm_lib.LMSamplingResult(
149
+ [lm_lib.LMSample('bar', score=1.0)],
150
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
151
+ ),
128
152
  ],
129
153
  )
130
154
  self.assertEqual(
131
155
  lm.sample(prompts=['foo', 'bar'], top_k=2, temperature=0.7),
132
156
  [
133
157
  lm_lib.LMSamplingResult(
134
- [lm_lib.LMSample('foo' * 2, score=0.7)]
158
+ [lm_lib.LMSample('foo' * 2, score=0.7)],
159
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
135
160
  ),
136
161
  lm_lib.LMSamplingResult(
137
- [lm_lib.LMSample('bar' * 2, score=0.7)]
162
+ [lm_lib.LMSample('bar' * 2, score=0.7)],
163
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
138
164
  ),
139
165
  ],
140
166
  )
@@ -144,6 +170,8 @@ class LanguageModelTest(unittest.TestCase):
144
170
  response = lm(prompt='foo')
145
171
  self.assertEqual(response.text, 'foo')
146
172
  self.assertEqual(response.score, -1.0)
173
+ self.assertIsNone(response.logprobs)
174
+ self.assertEqual(response.usage, lm_lib.LMSamplingUsage(100, 100, 200))
147
175
 
148
176
  # Test override sampling_options.
149
177
  self.assertEqual(
@@ -158,11 +186,24 @@ class LanguageModelTest(unittest.TestCase):
158
186
  self.assertEqual(
159
187
  lm.sample(prompts=['foo', 'bar']),
160
188
  [
161
- lm_lib.LMSamplingResult([lm_lib.LMSample(
162
- message_lib.AIMessage('foo', cache_seed=0), score=-1.0)]),
163
- lm_lib.LMSamplingResult([lm_lib.LMSample(
164
- message_lib.AIMessage('bar', cache_seed=0), score=-1.0)]),
165
- ])
189
+ lm_lib.LMSamplingResult(
190
+ [
191
+ lm_lib.LMSample(
192
+ message_lib.AIMessage('foo', cache_seed=0), score=-1.0
193
+ )
194
+ ],
195
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
196
+ ),
197
+ lm_lib.LMSamplingResult(
198
+ [
199
+ lm_lib.LMSample(
200
+ message_lib.AIMessage('bar', cache_seed=0), score=-1.0
201
+ )
202
+ ],
203
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
204
+ ),
205
+ ],
206
+ )
166
207
  self.assertEqual(cache.stats.num_queries, 2)
167
208
  self.assertEqual(cache.stats.num_hits, 0)
168
209
  self.assertEqual(cache.stats.num_updates, 2)
@@ -181,10 +222,22 @@ class LanguageModelTest(unittest.TestCase):
181
222
  self.assertEqual(
182
223
  lm.sample(prompts=['foo', 'baz'], temperature=1.0),
183
224
  [
184
- lm_lib.LMSamplingResult([lm_lib.LMSample(
185
- message_lib.AIMessage('foo', cache_seed=0), score=1.0)]),
186
- lm_lib.LMSamplingResult([lm_lib.LMSample(
187
- message_lib.AIMessage('baz', cache_seed=0), score=1.0)]),
225
+ lm_lib.LMSamplingResult(
226
+ [
227
+ lm_lib.LMSample(
228
+ message_lib.AIMessage('foo', cache_seed=0), score=1.0
229
+ )
230
+ ],
231
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
232
+ ),
233
+ lm_lib.LMSamplingResult(
234
+ [
235
+ lm_lib.LMSample(
236
+ message_lib.AIMessage('baz', cache_seed=0), score=1.0
237
+ )
238
+ ],
239
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
240
+ ),
188
241
  ],
189
242
  )
190
243
  self.assertEqual(cache.stats.num_queries, 6)
@@ -62,10 +62,19 @@ class InMemoryLMCacheTest(unittest.TestCase):
62
62
 
63
63
  def cache_entry(response_text, cache_seed=0):
64
64
  return base.LMCacheEntry(
65
- lf.LMSamplingResult([
66
- lf.LMSample(
67
- lf.AIMessage(response_text, cache_seed=cache_seed), score=1.0)
68
- ])
65
+ lf.LMSamplingResult(
66
+ [
67
+ lf.LMSample(
68
+ lf.AIMessage(response_text, cache_seed=cache_seed),
69
+ score=1.0
70
+ )
71
+ ],
72
+ usage=lf.LMSamplingUsage(
73
+ 1,
74
+ len(response_text),
75
+ len(response_text) + 1,
76
+ )
77
+ )
69
78
  )
70
79
 
71
80
  self.assertEqual(
langfun/core/llms/fake.py CHANGED
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  """Fake LMs for testing."""
15
15
 
16
+ import abc
16
17
  from typing import Annotated
17
18
  import langfun.core as lf
18
19
 
@@ -23,15 +24,32 @@ class Fake(lf.LanguageModel):
23
24
  def _score(self, prompt: lf.Message, completions: list[lf.Message]):
24
25
  return [lf.LMScoringResult(score=-i * 1.0) for i in range(len(completions))]
25
26
 
27
+ def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
28
+ results = []
29
+ for prompt in prompts:
30
+ response = self._response_from(prompt)
31
+ results.append(
32
+ lf.LMSamplingResult(
33
+ [lf.LMSample(response, 1.0)],
34
+ usage=lf.LMSamplingUsage(
35
+ prompt_tokens=len(prompt.text),
36
+ completion_tokens=len(response.text),
37
+ total_tokens=len(prompt.text) + len(response.text),
38
+ )
39
+ )
40
+ )
41
+ return results
42
+
43
+ @abc.abstractmethod
44
+ def _response_from(self, prompt: lf.Message) -> lf.Message:
45
+ """Returns the response for the given prompt."""
46
+
26
47
 
27
48
  class Echo(Fake):
28
49
  """A simple echo language model for testing."""
29
50
 
30
- def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
31
- return [
32
- lf.LMSamplingResult([lf.LMSample(prompt.text, 1.0)])
33
- for prompt in prompts
34
- ]
51
+ def _response_from(self, prompt: lf.Message) -> lf.Message:
52
+ return lf.AIMessage(prompt.text)
35
53
 
36
54
 
37
55
  @lf.use_init_args(['response'])
@@ -43,11 +61,8 @@ class StaticResponse(Fake):
43
61
  'A canned response that will be returned regardless of the prompt.'
44
62
  ]
45
63
 
46
- def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
47
- return [
48
- lf.LMSamplingResult([lf.LMSample(self.response, 1.0)])
49
- for _ in prompts
50
- ]
64
+ def _response_from(self, prompt: lf.Message) -> lf.Message:
65
+ return lf.AIMessage(self.response)
51
66
 
52
67
 
53
68
  @lf.use_init_args(['mapping'])
@@ -59,11 +74,8 @@ class StaticMapping(Fake):
59
74
  'A mapping from prompt to response.'
60
75
  ]
61
76
 
62
- def _sample(self, prompts: list[str]) -> list[lf.LMSamplingResult]:
63
- return [
64
- lf.LMSamplingResult([lf.LMSample(self.mapping[prompt], 1.0)])
65
- for prompt in prompts
66
- ]
77
+ def _response_from(self, prompt: lf.Message) -> lf.Message:
78
+ return lf.AIMessage(self.mapping[prompt])
67
79
 
68
80
 
69
81
  @lf.use_init_args(['sequence'])
@@ -79,10 +91,7 @@ class StaticSequence(Fake):
79
91
  super()._on_bound()
80
92
  self._pos = 0
81
93
 
82
- def _sample(self, prompts: list[str]) -> list[lf.LMSamplingResult]:
83
- results = []
84
- for _ in prompts:
85
- results.append(lf.LMSamplingResult(
86
- [lf.LMSample(self.sequence[self._pos], 1.0)]))
87
- self._pos += 1
88
- return results
94
+ def _response_from(self, prompt: lf.Message) -> lf.Message:
95
+ r = lf.AIMessage(self.sequence[self._pos])
96
+ self._pos += 1
97
+ return r
@@ -25,7 +25,12 @@ class EchoTest(unittest.TestCase):
25
25
  def test_sample(self):
26
26
  lm = fakelm.Echo()
27
27
  self.assertEqual(
28
- lm.sample(['hi']), [lf.LMSamplingResult([lf.LMSample('hi', 1.0)])]
28
+ lm.sample(['hi']),
29
+ [
30
+ lf.LMSamplingResult(
31
+ [lf.LMSample('hi', 1.0)],
32
+ lf.LMSamplingUsage(2, 2, 4))
33
+ ]
29
34
  )
30
35
 
31
36
  def test_call(self):
@@ -53,11 +58,21 @@ class StaticResponseTest(unittest.TestCase):
53
58
  lm = fakelm.StaticResponse(canned_response)
54
59
  self.assertEqual(
55
60
  lm.sample(['hi']),
56
- [lf.LMSamplingResult([lf.LMSample(canned_response, 1.0)])],
61
+ [
62
+ lf.LMSamplingResult(
63
+ [lf.LMSample(canned_response, 1.0)],
64
+ usage=lf.LMSamplingUsage(2, 38, 40)
65
+ )
66
+ ],
57
67
  )
58
68
  self.assertEqual(
59
69
  lm.sample(['Tell me a joke.']),
60
- [lf.LMSamplingResult([lf.LMSample(canned_response, 1.0)])],
70
+ [
71
+ lf.LMSamplingResult(
72
+ [lf.LMSample(canned_response, 1.0)],
73
+ usage=lf.LMSamplingUsage(15, 38, 53)
74
+ )
75
+ ],
61
76
  )
62
77
 
63
78
  def test_call(self):
@@ -85,8 +100,14 @@ class StaticMappingTest(unittest.TestCase):
85
100
  self.assertEqual(
86
101
  lm.sample(['Hi', 'How are you?']),
87
102
  [
88
- lf.LMSamplingResult([lf.LMSample('Hello', 1.0)]),
89
- lf.LMSamplingResult([lf.LMSample('I am fine, how about you?', 1.0)])
103
+ lf.LMSamplingResult(
104
+ [lf.LMSample('Hello', 1.0)],
105
+ usage=lf.LMSamplingUsage(2, 5, 7)
106
+ ),
107
+ lf.LMSamplingResult(
108
+ [lf.LMSample('I am fine, how about you?', 1.0)],
109
+ usage=lf.LMSamplingUsage(12, 25, 37)
110
+ )
90
111
  ]
91
112
  )
92
113
  with self.assertRaises(KeyError):
@@ -104,8 +125,14 @@ class StaticSequenceTest(unittest.TestCase):
104
125
  self.assertEqual(
105
126
  lm.sample(['Hi', 'How are you?']),
106
127
  [
107
- lf.LMSamplingResult([lf.LMSample('Hello', 1.0)]),
108
- lf.LMSamplingResult([lf.LMSample('I am fine, how about you?', 1.0)])
128
+ lf.LMSamplingResult(
129
+ [lf.LMSample('Hello', 1.0)],
130
+ usage=lf.LMSamplingUsage(2, 5, 7)
131
+ ),
132
+ lf.LMSamplingResult(
133
+ [lf.LMSample('I am fine, how about you?', 1.0)],
134
+ usage=lf.LMSamplingUsage(12, 25, 37)
135
+ )
109
136
  ]
110
137
  )
111
138
  with self.assertRaises(IndexError):
@@ -26,20 +26,6 @@ from openai import openai_object
26
26
  import pyglove as pg
27
27
 
28
28
 
29
- class Usage(pg.Object):
30
- """Usage information per completion."""
31
-
32
- prompt_tokens: int
33
- completion_tokens: int
34
- total_tokens: int
35
-
36
-
37
- class LMSamplingResult(lf.LMSamplingResult):
38
- """LMSamplingResult with usage information."""
39
-
40
- usage: Usage | None = None
41
-
42
-
43
29
  SUPPORTED_MODELS_AND_SETTINGS = [
44
30
  # Model name, max concurrent requests.
45
31
  # The concurrent requests is estimated by TPM/RPM from
@@ -181,7 +167,7 @@ class OpenAI(lf.LanguageModel):
181
167
  args['stop'] = options.stop
182
168
  return args
183
169
 
184
- def _sample(self, prompts: list[lf.Message]) -> list[LMSamplingResult]:
170
+ def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
185
171
  assert self._api_initialized
186
172
  if self.is_chat_model:
187
173
  return self._chat_complete_batch(prompts)
@@ -189,7 +175,8 @@ class OpenAI(lf.LanguageModel):
189
175
  return self._complete_batch(prompts)
190
176
 
191
177
  def _complete_batch(
192
- self, prompts: list[lf.Message]) -> list[LMSamplingResult]:
178
+ self, prompts: list[lf.Message]
179
+ ) -> list[lf.LMSamplingResult]:
193
180
 
194
181
  def _open_ai_completion(prompts):
195
182
  response = openai.Completion.create(
@@ -204,13 +191,13 @@ class OpenAI(lf.LanguageModel):
204
191
  lf.LMSample(choice.text.strip(), score=choice.logprobs or 0.0)
205
192
  )
206
193
 
207
- usage = Usage(
194
+ usage = lf.LMSamplingUsage(
208
195
  prompt_tokens=response.usage.prompt_tokens,
209
196
  completion_tokens=response.usage.completion_tokens,
210
197
  total_tokens=response.usage.total_tokens,
211
198
  )
212
199
  return [
213
- LMSamplingResult(
200
+ lf.LMSamplingResult(
214
201
  samples_by_index[index], usage=usage if index == 0 else None
215
202
  )
216
203
  for index in sorted(samples_by_index.keys())
@@ -231,7 +218,7 @@ class OpenAI(lf.LanguageModel):
231
218
 
232
219
  def _chat_complete_batch(
233
220
  self, prompts: list[lf.Message]
234
- ) -> list[LMSamplingResult]:
221
+ ) -> list[lf.LMSamplingResult]:
235
222
  def _open_ai_chat_completion(prompt: lf.Message):
236
223
  if self.multimodal:
237
224
  content = []
@@ -272,9 +259,9 @@ class OpenAI(lf.LanguageModel):
272
259
  )
273
260
  )
274
261
 
275
- return LMSamplingResult(
262
+ return lf.LMSamplingResult(
276
263
  samples=samples,
277
- usage=Usage(
264
+ usage=lf.LMSamplingUsage(
278
265
  prompt_tokens=response.usage.prompt_tokens,
279
266
  completion_tokens=response.usage.completion_tokens,
280
267
  total_tokens=response.usage.total_tokens,
@@ -32,11 +32,14 @@ def mock_completion_query(prompt, *, n=1, **kwargs):
32
32
  text=f'Sample {k} for prompt {i}.',
33
33
  logprobs=k / 10,
34
34
  ))
35
- return pg.Dict(choices=choices, usage=openai.Usage(
36
- prompt_tokens=100,
37
- completion_tokens=100,
38
- total_tokens=200,
39
- ))
35
+ return pg.Dict(
36
+ choices=choices,
37
+ usage=lf.LMSamplingUsage(
38
+ prompt_tokens=100,
39
+ completion_tokens=100,
40
+ total_tokens=200,
41
+ ),
42
+ )
40
43
 
41
44
 
42
45
  def mock_chat_completion_query(messages, *, n=1, **kwargs):
@@ -49,11 +52,14 @@ def mock_chat_completion_query(messages, *, n=1, **kwargs):
49
52
  ),
50
53
  logprobs=None,
51
54
  ))
52
- return pg.Dict(choices=choices, usage=openai.Usage(
53
- prompt_tokens=100,
54
- completion_tokens=100,
55
- total_tokens=200,
56
- ))
55
+ return pg.Dict(
56
+ choices=choices,
57
+ usage=lf.LMSamplingUsage(
58
+ prompt_tokens=100,
59
+ completion_tokens=100,
60
+ total_tokens=200,
61
+ ),
62
+ )
57
63
 
58
64
 
59
65
  def mock_chat_completion_query_vision(messages, *, n=1, **kwargs):
@@ -69,11 +75,14 @@ def mock_chat_completion_query_vision(messages, *, n=1, **kwargs):
69
75
  ),
70
76
  logprobs=None,
71
77
  ))
72
- return pg.Dict(choices=choices, usage=openai.Usage(
73
- prompt_tokens=100,
74
- completion_tokens=100,
75
- total_tokens=200,
76
- ))
78
+ return pg.Dict(
79
+ choices=choices,
80
+ usage=lf.LMSamplingUsage(
81
+ prompt_tokens=100,
82
+ completion_tokens=100,
83
+ total_tokens=200,
84
+ ),
85
+ )
77
86
 
78
87
 
79
88
  class OpenaiTest(unittest.TestCase):
@@ -169,18 +178,28 @@ class OpenaiTest(unittest.TestCase):
169
178
  )
170
179
 
171
180
  self.assertEqual(len(results), 2)
172
- self.assertEqual(results[0], openai.LMSamplingResult([
173
- lf.LMSample('Sample 0 for prompt 0.', score=0.0),
174
- lf.LMSample('Sample 1 for prompt 0.', score=0.1),
175
- lf.LMSample('Sample 2 for prompt 0.', score=0.2),
176
- ], usage=openai.Usage(
177
- prompt_tokens=100, completion_tokens=100, total_tokens=200)))
178
-
179
- self.assertEqual(results[1], openai.LMSamplingResult([
180
- lf.LMSample('Sample 0 for prompt 1.', score=0.0),
181
- lf.LMSample('Sample 1 for prompt 1.', score=0.1),
182
- lf.LMSample('Sample 2 for prompt 1.', score=0.2),
183
- ]))
181
+ self.assertEqual(
182
+ results[0],
183
+ lf.LMSamplingResult(
184
+ [
185
+ lf.LMSample('Sample 0 for prompt 0.', score=0.0),
186
+ lf.LMSample('Sample 1 for prompt 0.', score=0.1),
187
+ lf.LMSample('Sample 2 for prompt 0.', score=0.2),
188
+ ],
189
+ usage=lf.LMSamplingUsage(
190
+ prompt_tokens=100, completion_tokens=100, total_tokens=200
191
+ ),
192
+ ),
193
+ )
194
+
195
+ self.assertEqual(
196
+ results[1],
197
+ lf.LMSamplingResult([
198
+ lf.LMSample('Sample 0 for prompt 1.', score=0.0),
199
+ lf.LMSample('Sample 1 for prompt 1.', score=0.1),
200
+ lf.LMSample('Sample 2 for prompt 1.', score=0.2),
201
+ ]),
202
+ )
184
203
 
185
204
  def test_sample_chat_completion(self):
186
205
  with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
@@ -191,18 +210,32 @@ class OpenaiTest(unittest.TestCase):
191
210
  )
192
211
 
193
212
  self.assertEqual(len(results), 2)
194
- self.assertEqual(results[0], openai.LMSamplingResult([
195
- lf.LMSample('Sample 0 for message.', score=0.0),
196
- lf.LMSample('Sample 1 for message.', score=0.0),
197
- lf.LMSample('Sample 2 for message.', score=0.0),
198
- ], usage=openai.Usage(
199
- prompt_tokens=100, completion_tokens=100, total_tokens=200)))
200
- self.assertEqual(results[1], openai.LMSamplingResult([
201
- lf.LMSample('Sample 0 for message.', score=0.0),
202
- lf.LMSample('Sample 1 for message.', score=0.0),
203
- lf.LMSample('Sample 2 for message.', score=0.0),
204
- ], usage=openai.Usage(
205
- prompt_tokens=100, completion_tokens=100, total_tokens=200)))
213
+ self.assertEqual(
214
+ results[0],
215
+ lf.LMSamplingResult(
216
+ [
217
+ lf.LMSample('Sample 0 for message.', score=0.0),
218
+ lf.LMSample('Sample 1 for message.', score=0.0),
219
+ lf.LMSample('Sample 2 for message.', score=0.0),
220
+ ],
221
+ usage=lf.LMSamplingUsage(
222
+ prompt_tokens=100, completion_tokens=100, total_tokens=200
223
+ ),
224
+ ),
225
+ )
226
+ self.assertEqual(
227
+ results[1],
228
+ lf.LMSamplingResult(
229
+ [
230
+ lf.LMSample('Sample 0 for message.', score=0.0),
231
+ lf.LMSample('Sample 1 for message.', score=0.0),
232
+ lf.LMSample('Sample 2 for message.', score=0.0),
233
+ ],
234
+ usage=lf.LMSamplingUsage(
235
+ prompt_tokens=100, completion_tokens=100, total_tokens=200
236
+ ),
237
+ ),
238
+ )
206
239
 
207
240
  def test_sample_with_contextual_options(self):
208
241
  with mock.patch('openai.Completion.create') as mock_completion:
@@ -212,11 +245,18 @@ class OpenaiTest(unittest.TestCase):
212
245
  results = lm.sample(['hello'])
213
246
 
214
247
  self.assertEqual(len(results), 1)
215
- self.assertEqual(results[0], openai.LMSamplingResult([
216
- lf.LMSample('Sample 0 for prompt 0.', score=0.0),
217
- lf.LMSample('Sample 1 for prompt 0.', score=0.1),
218
- ], usage=openai.Usage(
219
- prompt_tokens=100, completion_tokens=100, total_tokens=200)))
248
+ self.assertEqual(
249
+ results[0],
250
+ lf.LMSamplingResult(
251
+ [
252
+ lf.LMSample('Sample 0 for prompt 0.', score=0.0),
253
+ lf.LMSample('Sample 1 for prompt 0.', score=0.1),
254
+ ],
255
+ usage=lf.LMSamplingUsage(
256
+ prompt_tokens=100, completion_tokens=100, total_tokens=200
257
+ ),
258
+ ),
259
+ )
220
260
 
221
261
 
222
262
  if __name__ == '__main__':
@@ -583,6 +583,7 @@ class CompleteStructureTest(unittest.TestCase):
583
583
  result=Activity(description='foo'),
584
584
  score=1.0,
585
585
  logprobs=None,
586
+ usage=lf.LMSamplingUsage(553, 27, 580),
586
587
  tags=['lm-response', 'lm-output', 'transformed']
587
588
  )
588
589
  )
@@ -280,13 +280,15 @@ class ParseStructurePythonTest(unittest.TestCase):
280
280
  ),
281
281
  1,
282
282
  )
283
+ r = parsing.parse(
284
+ 'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm,
285
+ returns_message=True
286
+ )
283
287
  self.assertEqual(
284
- parsing.parse(
285
- 'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm,
286
- returns_message=True
287
- ),
288
+ r,
288
289
  lf.AIMessage(
289
290
  '1', score=1.0, result=1, logprobs=None,
291
+ usage=lf.LMSamplingUsage(652, 1, 653),
290
292
  tags=['lm-response', 'lm-output', 'transformed']
291
293
  ),
292
294
  )
@@ -634,13 +636,18 @@ class CallTest(unittest.TestCase):
634
636
  )
635
637
 
636
638
  def test_call_with_returning_message(self):
639
+ r = parsing.call(
640
+ 'Compute 1 + 2', int, lm=fake.StaticSequence(['three', '3']),
641
+ returns_message=True
642
+ )
637
643
  self.assertEqual(
638
- parsing.call(
639
- 'Compute 1 + 2', int, lm=fake.StaticSequence(['three', '3']),
640
- returns_message=True
641
- ),
644
+ r,
642
645
  lf.AIMessage(
643
- '3', result=3, score=1.0, logprobs=None,
646
+ '3',
647
+ result=3,
648
+ score=1.0,
649
+ logprobs=None,
650
+ usage=lf.LMSamplingUsage(315, 1, 316),
644
651
  tags=['lm-response', 'lm-output', 'transformed']
645
652
  ),
646
653
  )
@@ -77,6 +77,7 @@ class QueryTest(unittest.TestCase):
77
77
  result=1,
78
78
  score=1.0,
79
79
  logprobs=None,
80
+ usage=lf.LMSamplingUsage(323, 1, 324),
80
81
  tags=['lm-response', 'lm-output', 'transformed'],
81
82
  ),
82
83
  )
@@ -56,7 +56,9 @@ class SelfPlayTest(unittest.TestCase):
56
56
  g = NumberGuess(target_num=10)
57
57
 
58
58
  with lf.context(lm=NumberGuesser(guesses=[50, 20, 5, 10])):
59
- self.assertEqual(g(), lf.AIMessage('10', score=0.0, logprobs=None))
59
+ self.assertEqual(
60
+ g(), lf.AIMessage('10', score=0.0, logprobs=None, usage=None)
61
+ )
60
62
 
61
63
  self.assertEqual(g.num_turns, 4)
62
64
 
@@ -64,7 +66,9 @@ class SelfPlayTest(unittest.TestCase):
64
66
  g = NumberGuess(target_num=10, max_turns=10)
65
67
 
66
68
  with lf.context(lm=NumberGuesser(guesses=[50, 20, 5, 2, 5, 4])):
67
- self.assertEqual(g(), lf.AIMessage('2', score=0.0, logprobs=None))
69
+ self.assertEqual(
70
+ g(), lf.AIMessage('2', score=0.0, logprobs=None, usage=None)
71
+ )
68
72
 
69
73
  self.assertEqual(g.num_turns, 10)
70
74
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.0.2.dev20240414
3
+ Version: 0.0.2.dev20240415
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -1,5 +1,5 @@
1
1
  langfun/__init__.py,sha256=PqX3u18BC0szYIMu00j-RKxvwkNPwXtAFZ-96oxrQ0M,1841
2
- langfun/core/__init__.py,sha256=sVcPl89lWYHQ1cUoaLaM8dErCovugJo5e2F3A_94Q3Y,4192
2
+ langfun/core/__init__.py,sha256=6QEuXOZ9BXxm6TjpaMXuLwUBTYO3pkFDqn9QVBXyyPQ,4248
3
3
  langfun/core/component.py,sha256=VRPfDB_2jEnxcB3-HoiVjG4ID-SMenNPIsytb0uXMPg,9674
4
4
  langfun/core/component_test.py,sha256=VAPd6V_-odAe8rBvesW3ogYDd6OSqRq4FaPhfgOM4Zg,7949
5
5
  langfun/core/concurrent.py,sha256=TRc49pJ3HQro2kb5FtcWkHjhBm8UcgE8RJybU5cU3-0,24537
@@ -7,9 +7,9 @@ langfun/core/concurrent_test.py,sha256=mwFMZhDUdppnDr7vDSTwcbMHwrdsIoKJwRYNtl4ZW
7
7
  langfun/core/console.py,sha256=bk5rNPNm9rMGW5YT2HixxU04p2umnoabn5SDz6Dqe88,2317
8
8
  langfun/core/console_test.py,sha256=5SYJdxpJGLgdSSQqqMPoA1X6jpsLD8rgcyk-EgI65oE,1077
9
9
  langfun/core/langfunc.py,sha256=WXdTc3QsmGD_n80KD9dFRr5MHpGZ9E_y_Rhtk4t9-3w,11852
10
- langfun/core/langfunc_test.py,sha256=rRxz2OOka5qagTSS1IcJ1Ij3mjjWawPFe1n9zYtGST8,8340
11
- langfun/core/language_model.py,sha256=D3aU7ep1MFnyMWYCfvbA3ZK9DgP_wk0PogXo1Kmvk4Q,17185
12
- langfun/core/language_model_test.py,sha256=bTyQVsH5JAxEzzzuq8VO8bVa9kiAMeiahzrxLxnOuQs,11380
10
+ langfun/core/langfunc_test.py,sha256=sQaKuZpGGmG80GRifhbxkj7nfzQLJKj4Vuw5y1s1K3U,8378
11
+ langfun/core/language_model.py,sha256=Tzswu0hyXOQOZ3fZ_Mz_Cc0ei7tVj8rTay9jJEgM6mI,17510
12
+ langfun/core/language_model_test.py,sha256=KvXXOr64TsSs3WkEALCLLZSlz09i7hBiHDOZ_8Eq8_o,13047
13
13
  langfun/core/memory.py,sha256=f-asN1F7Vehgdn_fK84v73GrEUOxRtaW934keutTKjk,2416
14
14
  langfun/core/message.py,sha256=QhvV9t5qaryPcruyxxcXi3gm9QDInkSldwTtK6sVJ3c,15734
15
15
  langfun/core/message_test.py,sha256=Z23pUM5vPnDrYkIIibe2KL73D5HKur_awI0ut_EQFQA,9501
@@ -40,25 +40,25 @@ langfun/core/coding/python/parsing_test.py,sha256=9vAWF484kWIm6JZq8NFiMgKUDhXV-d
40
40
  langfun/core/coding/python/permissions.py,sha256=1QWGHvzL8MM0Ok_auQ9tURqZHtdOfJaDpBzZ29GUE-c,2544
41
41
  langfun/core/coding/python/permissions_test.py,sha256=w5EDb8QxpxgJyZkojyzVWQvDfg366zn99-g__6TbPQ0,2699
42
42
  langfun/core/eval/__init__.py,sha256=iDA2OcJ3kR6ixZizXIY3N9LsjkaVrfTbSClTiSP8ekY,1291
43
- langfun/core/eval/base.py,sha256=Op-DO-YV8sL8mQvCfbzLfDDL6bDMuTtNYeyp5_QCBsQ,55328
44
- langfun/core/eval/base_test.py,sha256=mjdQ3ukxc7BhsVJkFJvqtz9EVhSR0OGL9j1zf_AfXR4,21540
43
+ langfun/core/eval/base.py,sha256=TZAmcdRBtzwMG1V3e_NgyJXg7J6dWMdMBrHvBnFuFho,55359
44
+ langfun/core/eval/base_test.py,sha256=OuuXFW_lX9bGhyd__kvlDSNJVne-5cSlnm-qDhyvOcc,21592
45
45
  langfun/core/eval/matching.py,sha256=aqNlYrlav7YmsB7rUlsdfoi1RLA5CYqn2RGPxRlPc78,9599
46
46
  langfun/core/eval/matching_test.py,sha256=FFHYD7IDuKe5RMjkx74ksukiwUhO5a_SS340JaIPMws,4898
47
47
  langfun/core/eval/scoring.py,sha256=aKeanBJf1yO3Q9JEtgPWoiZk_3M_GiqwXVXX7x_g22w,6172
48
48
  langfun/core/eval/scoring_test.py,sha256=YH1cIxBWtfdKcAV9Fh10vLkV5J-gxk8b6nxW4Z2u5pk,4024
49
49
  langfun/core/llms/__init__.py,sha256=gROJ8AjMq_ebXFcEfsyzYGCS6NsGfzf9d43nLu_TIdw,2504
50
- langfun/core/llms/fake.py,sha256=dVzOrW27RZ1p3DdQoRCRZs_vfoQcTcNrlWxia7oqmvw,2499
51
- langfun/core/llms/fake_test.py,sha256=Qk_Yoi4Z7P9o6f8Q_BZkaSlvxH89ZVsDxnVIbSBRBXk,3555
50
+ langfun/core/llms/fake.py,sha256=b-Xk5IPTbUt-elsyzd_i3n1tqzc_kgETXrEvgJruSMk,2824
51
+ langfun/core/llms/fake_test.py,sha256=AThvNyhZbkpsn-YO798uLgqB6TSw5XP2SKpKvcXEytw,4188
52
52
  langfun/core/llms/google_genai.py,sha256=n8zyJwh9UCTgb6-8LyvmjVNFGZQ4-zfzZ0ulkhHAnR8,8624
53
53
  langfun/core/llms/google_genai_test.py,sha256=_UcGTfl16-aDUlEWFC2W2F8y9jPUs53RBYA6MOCpGXw,7525
54
54
  langfun/core/llms/llama_cpp.py,sha256=Y_KkMUf3Xfac49koMUtUslKl3h-HWp3-ntq7Jaa3bdo,2385
55
55
  langfun/core/llms/llama_cpp_test.py,sha256=ZxC6defGd_HX9SFRU9U4cJiQnBKundbOrchbXuC1Z2M,1683
56
- langfun/core/llms/openai.py,sha256=uOJDflucpKZv3TPZwaeDSp9QMs2oDFuzh5Jm5j4dlm4,11680
57
- langfun/core/llms/openai_test.py,sha256=ulzp5uzEmEvnqZ21D0FP6eaiH1xMQ59FaLHoqA0lTgc,7570
56
+ langfun/core/llms/openai.py,sha256=1EUd8WTI6EpcU_fzD90-4M11RdL9Mj4S9zfrzUZIyGM,11463
57
+ langfun/core/llms/openai_test.py,sha256=hiByS95g3pXtjB2XfIdVCKiAZDb_-Qirb2_LsSyskpY,8166
58
58
  langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
59
59
  langfun/core/llms/cache/base.py,sha256=cFfYvOIUae842pncqCAsRvqXCk2AnAsRYVx0mcIoAeY,3338
60
60
  langfun/core/llms/cache/in_memory.py,sha256=YfFyJEhLs73cUiB0ZfhMxYpdE8Iuxxw-dvMFwGHTSHw,4742
61
- langfun/core/llms/cache/in_memory_test.py,sha256=guHHjislh1Mj3-GBARICMh-qq5gh4fwZQ7SI5kQEAeQ,8510
61
+ langfun/core/llms/cache/in_memory_test.py,sha256=D-n26h__rVXQO51WRFhRfq5sw1oifRLx2SvCQWuNEm8,8747
62
62
  langfun/core/memories/__init__.py,sha256=HpghfZ-w1NQqzJXBx8Lz0daRhB2rcy2r9Xm491SBhC4,773
63
63
  langfun/core/memories/conversation_history.py,sha256=c9amD8hCxGFiZuVAzkP0dOMWSp8L90uvwkOejjuBqO0,1835
64
64
  langfun/core/memories/conversation_history_test.py,sha256=AaW8aNoFjxNusanwJDV0r3384Mg0eAweGmPx5DIkM0Y,2052
@@ -71,15 +71,15 @@ langfun/core/modalities/video.py,sha256=25M4XsNG5XEWRy57LYT_a6_aMURMPAgC41B3weEX
71
71
  langfun/core/modalities/video_test.py,sha256=jYuI2m8S8zDCAVBPEUbbpP205dXAht90A2_PHWo4-r8,2039
72
72
  langfun/core/structured/__init__.py,sha256=SpObW-HKpyKvkLlX8FV5ixz7CRm098j2aGfOguM3AUI,3462
73
73
  langfun/core/structured/completion.py,sha256=skBxt6V_fv2TBUKnzFgnPMbVY8HSYn8sY04MLok2yvs,7299
74
- langfun/core/structured/completion_test.py,sha256=98UCgA4gzfp6H6HgP2s2kcKs25YH3k4Nxj1rgAvmVBw,19249
74
+ langfun/core/structured/completion_test.py,sha256=0FJreSmz0Umsj47dIlOyCjBXUa7janIplXhg1CbLT4U,19301
75
75
  langfun/core/structured/description.py,sha256=SXW4MJvshFjbR-0gw6rE21o6WXq12UlRXawvDBXMZFA,5211
76
76
  langfun/core/structured/description_test.py,sha256=UtZGjSFUaQ6130t1E5tcL7ODu0xIefkapb53TbnqsK8,7362
77
77
  langfun/core/structured/mapping.py,sha256=7JInwZLmQdu7asHhC0vFLJNOCBnY-hrD6v5RQgf-xKk,11020
78
78
  langfun/core/structured/mapping_test.py,sha256=07DDCGbwytQHSMm7fCi5-Ly-JNgdV4ubHZq0wthX4A4,3338
79
79
  langfun/core/structured/parsing.py,sha256=keoVqEfzAbdULh6GawWFsTQzU91MzJXYFZjXGXLaD8g,11492
80
- langfun/core/structured/parsing_test.py,sha256=2_Uf3LYNRON1-5ysEr75xiG_cAxR3ZiixSfvUQu6mOQ,20846
80
+ langfun/core/structured/parsing_test.py,sha256=9rUe7ipRhltQv7y8NXgR98lBXhSVKnfRM9TSAyVdxbs,20980
81
81
  langfun/core/structured/prompting.py,sha256=mOmCWNVMnBk4rI7KBlEm5kmusPXoAKiWcohhzaw-s2o,7427
82
- langfun/core/structured/prompting_test.py,sha256=luJoJ16h0CkKmZv0-elOD2xLhqa7exZwHUTa9J15wqs,19894
82
+ langfun/core/structured/prompting_test.py,sha256=csOzqHRp6T3KGp7Dsm0vS-BkZdQ4ALRt09iiFNz_YmA,19945
83
83
  langfun/core/structured/schema.py,sha256=mJXirgqx3N7SA9zBO_ISHrzcV-ZRshLhnMJyCcSjGjY,25057
84
84
  langfun/core/structured/schema_generation.py,sha256=U3nRQsqmMZg_qIVDh2fiY3K4JLfsAL1LcKzIFP1iXFg,5316
85
85
  langfun/core/structured/schema_generation_test.py,sha256=cfZyP0gHno2fXy_c9vsVdvHmqKQSfuyUsCtfO3JFmYQ,2945
@@ -94,9 +94,9 @@ langfun/core/templates/conversation_test.py,sha256=RryYyIhfc34dLWOs6GfPQ8HU8mXpK
94
94
  langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fikKhwhzwhpKI,1460
95
95
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
96
96
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
97
- langfun/core/templates/selfplay_test.py,sha256=IB5rWbjK_9CTkqEo1BclQPzFAKcIiusJckH8J19HFgI,2096
98
- langfun-0.0.2.dev20240414.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
99
- langfun-0.0.2.dev20240414.dist-info/METADATA,sha256=DhUU4VmRzC-JUs0fid1_V7miqyIwuSrpWs7NJII2yF8,3405
100
- langfun-0.0.2.dev20240414.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
101
- langfun-0.0.2.dev20240414.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
102
- langfun-0.0.2.dev20240414.dist-info/RECORD,,
97
+ langfun/core/templates/selfplay_test.py,sha256=DYVrkk7uNKCqJGEHH31HssU2BPuMItU1vJLzfcXIlYg,2156
98
+ langfun-0.0.2.dev20240415.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
99
+ langfun-0.0.2.dev20240415.dist-info/METADATA,sha256=V_zKk0hFMrBR6jMyr0C0v71Y4RJ9GL9b0uAkBerHIIw,3405
100
+ langfun-0.0.2.dev20240415.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
101
+ langfun-0.0.2.dev20240415.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
102
+ langfun-0.0.2.dev20240415.dist-info/RECORD,,