langfun 0.0.2.dev20240414__py3-none-any.whl → 0.0.2.dev20240418__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only.
langfun/__init__.py CHANGED
@@ -34,6 +34,7 @@ score = structured.score
 generate_class = structured.generate_class

 source_form = structured.source_form
+function_gen = structured.function_gen

 from langfun.core import eval  # pylint: disable=redefined-builtin
 from langfun.core import templates
langfun/core/__init__.py CHANGED
@@ -99,6 +99,7 @@ from langfun.core.modality import ModalityRef
 from langfun.core.language_model import LanguageModel
 from langfun.core.language_model import LMSample
 from langfun.core.language_model import LMSamplingOptions
+from langfun.core.language_model import LMSamplingUsage
 from langfun.core.language_model import LMSamplingResult
 from langfun.core.language_model import LMScoringResult
 from langfun.core.language_model import LMCache
langfun/core/eval/base.py CHANGED
@@ -1565,6 +1565,7 @@ class Summary(pg.Object):
         results.append(
             pg.Dict(
                 experiment=entry,
+                dir=entry.dir,
                 metrics=entry.result.metrics if entry.result else None,
             )
         )
@@ -194,6 +194,7 @@ class EvaluationTest(unittest.TestCase):
             cache_seed=0,
             score=1.0,
             logprobs=None,
+            usage=lf.LMSamplingUsage(387, 24, 411),
             tags=['lm-response', 'lm-output', 'transformed'],
         ),
     )
@@ -82,7 +82,9 @@ class LangFuncCallTest(unittest.TestCase):
     self.assertEqual(i.tags, ['rendered'])

     r = l()
-    self.assertEqual(r, message.AIMessage('Hello!!!', score=0.0, logprobs=None))
+    self.assertEqual(
+        r, message.AIMessage('Hello!!!', score=0.0, logprobs=None, usage=None)
+    )
     self.assertEqual(r.tags, ['lm-response', 'lm-output'])
     self.assertEqual(r.source, message.UserMessage('Hello'))
     self.assertEqual(r.source.tags, ['rendered', 'lm-input'])
@@ -106,7 +108,7 @@ class LangFuncCallTest(unittest.TestCase):
     self.assertEqual(l.render(), 'Hello')
     r = l()
     self.assertEqual(
-        r, message.AIMessage('Hello!!!', score=0.0, logprobs=None)
+        r, message.AIMessage('Hello!!!', score=0.0, logprobs=None, usage=None)
     )
     self.assertEqual(r.tags, ['lm-response', 'lm-output'])

@@ -47,6 +47,14 @@ class LMSample(pg.Object):
   ] = None


+class LMSamplingUsage(pg.Object):
+  """Usage information per completion."""
+
+  prompt_tokens: int
+  completion_tokens: int
+  total_tokens: int
+
+
 class LMSamplingResult(pg.Object):
   """Language model response."""

@@ -58,6 +66,11 @@ class LMSamplingResult(pg.Object):
       ),
   ] = []

+  usage: Annotated[
+      LMSamplingUsage | None,
+      'Usage information. Currently only OpenAI models are supported.',
+  ] = None
+

 class LMSamplingOptions(component.Component):
   """Language model sampling options."""
@@ -424,6 +437,8 @@ class LanguageModel(component.Component):
       logprobs = result.samples[0].logprobs
       response.set('score', result.samples[0].score)
       response.metadata.logprobs = logprobs
+      response.metadata.usage = result.usage
+
       elapse = time.time() - request_start
       self._debug(prompt, response, call_counter, elapse)
       return response
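Note: with this change, the sampled usage is copied onto the returned message's metadata, so callers can read token counts directly off the response (populated by OpenAI-backed and fake models in this release; other backends leave it as None). A minimal sketch, using the fake Echo model from this package (import path as in this diff; the fakes report character counts rather than real token counts):

from langfun.core.llms import fake as fakelm

lm = fakelm.Echo()
response = lm(prompt='Hello')

print(response.text)   # 'Hello'
print(response.usage)  # LMSamplingUsage with prompt_tokens=5, completion_tokens=5, total_tokens=10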
@@ -38,9 +38,19 @@ class MockModel(lm_lib.LanguageModel):
     def fake_sample(prompts):
       if context.attempt >= self.failures_before_attempt:
         return [
-            lm_lib.LMSamplingResult([lm_lib.LMSample(  # pylint: disable=g-complex-comprehension
-                response=prompt.text * self.sampling_options.top_k,
-                score=self.sampling_options.temperature or -1.0)])
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(  # pylint: disable=g-complex-comprehension
+                        response=prompt.text * self.sampling_options.top_k,
+                        score=self.sampling_options.temperature or -1.0,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100,
+                    completion_tokens=100,
+                    total_tokens=200,
+                ),
+            )
             for prompt in prompts
         ]
       context.attempt += 1
@@ -100,8 +110,14 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar']),
         [
-            lm_lib.LMSamplingResult([lm_lib.LMSample('foo', score=-1.0)]),
-            lm_lib.LMSamplingResult([lm_lib.LMSample('bar', score=-1.0)]),
+            lm_lib.LMSamplingResult(
+                [lm_lib.LMSample('foo', score=-1.0)],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+            lm_lib.LMSamplingResult(
+                [lm_lib.LMSample('bar', score=-1.0)],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
         ],
     )
     # Test override sampling_options.
@@ -112,10 +128,12 @@ class LanguageModelTest(unittest.TestCase):
         ),
         [
             lm_lib.LMSamplingResult(
-                [lm_lib.LMSample('foo' * 2, score=0.5)]
+                [lm_lib.LMSample('foo' * 2, score=0.5)],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
             ),
             lm_lib.LMSamplingResult(
-                [lm_lib.LMSample('bar' * 2, score=0.5)]
+                [lm_lib.LMSample('bar' * 2, score=0.5)],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
             ),
         ],
     )
@@ -123,18 +141,26 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar'], temperature=1.0),
         [
-            lm_lib.LMSamplingResult([lm_lib.LMSample('foo', score=1.0)]),
-            lm_lib.LMSamplingResult([lm_lib.LMSample('bar', score=1.0)]),
+            lm_lib.LMSamplingResult(
+                [lm_lib.LMSample('foo', score=1.0)],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+            lm_lib.LMSamplingResult(
+                [lm_lib.LMSample('bar', score=1.0)],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
         ],
     )
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar'], top_k=2, temperature=0.7),
         [
             lm_lib.LMSamplingResult(
-                [lm_lib.LMSample('foo' * 2, score=0.7)]
+                [lm_lib.LMSample('foo' * 2, score=0.7)],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
             ),
             lm_lib.LMSamplingResult(
-                [lm_lib.LMSample('bar' * 2, score=0.7)]
+                [lm_lib.LMSample('bar' * 2, score=0.7)],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
             ),
         ],
     )
@@ -144,6 +170,8 @@ class LanguageModelTest(unittest.TestCase):
     response = lm(prompt='foo')
     self.assertEqual(response.text, 'foo')
     self.assertEqual(response.score, -1.0)
+    self.assertIsNone(response.logprobs)
+    self.assertEqual(response.usage, lm_lib.LMSamplingUsage(100, 100, 200))

     # Test override sampling_options.
     self.assertEqual(
@@ -158,11 +186,24 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar']),
         [
-            lm_lib.LMSamplingResult([lm_lib.LMSample(
-                message_lib.AIMessage('foo', cache_seed=0), score=-1.0)]),
-            lm_lib.LMSamplingResult([lm_lib.LMSample(
-                message_lib.AIMessage('bar', cache_seed=0), score=-1.0)]),
-        ])
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage('foo', cache_seed=0), score=-1.0
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage('bar', cache_seed=0), score=-1.0
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+        ],
+    )
     self.assertEqual(cache.stats.num_queries, 2)
     self.assertEqual(cache.stats.num_hits, 0)
     self.assertEqual(cache.stats.num_updates, 2)
@@ -181,10 +222,22 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(prompts=['foo', 'baz'], temperature=1.0),
         [
-            lm_lib.LMSamplingResult([lm_lib.LMSample(
-                message_lib.AIMessage('foo', cache_seed=0), score=1.0)]),
-            lm_lib.LMSamplingResult([lm_lib.LMSample(
-                message_lib.AIMessage('baz', cache_seed=0), score=1.0)]),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage('foo', cache_seed=0), score=1.0
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage('baz', cache_seed=0), score=1.0
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
         ],
     )
     self.assertEqual(cache.stats.num_queries, 6)
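Note: as the expectations above show, LanguageModel.sample now returns one LMSamplingResult per prompt, each carrying its samples plus an optional LMSamplingUsage. A minimal sketch of reading both, using the fake StaticResponse model (illustrative only; the fakes report character counts as usage):

from langfun.core.llms import fake as fakelm

lm = fakelm.StaticResponse('I do not know.')
results = lm.sample(prompts=['What is 1 + 1?', 'Why?'])
for result in results:
  top = result.samples[0]
  # result.usage may be None for backends that do not report usage.
  print(top.response.text, top.score, result.usage and result.usage.total_tokens)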
@@ -62,10 +62,19 @@ class InMemoryLMCacheTest(unittest.TestCase):

    def cache_entry(response_text, cache_seed=0):
      return base.LMCacheEntry(
-          lf.LMSamplingResult([
-              lf.LMSample(
-                  lf.AIMessage(response_text, cache_seed=cache_seed), score=1.0)
-          ])
+          lf.LMSamplingResult(
+              [
+                  lf.LMSample(
+                      lf.AIMessage(response_text, cache_seed=cache_seed),
+                      score=1.0
+                  )
+              ],
+              usage=lf.LMSamplingUsage(
+                  1,
+                  len(response_text),
+                  len(response_text) + 1,
+              )
+          )
      )

    self.assertEqual(
langfun/core/llms/fake.py CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Fake LMs for testing."""

+import abc
 from typing import Annotated
 import langfun.core as lf

@@ -23,15 +24,32 @@ class Fake(lf.LanguageModel):
   def _score(self, prompt: lf.Message, completions: list[lf.Message]):
     return [lf.LMScoringResult(score=-i * 1.0) for i in range(len(completions))]

+  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
+    results = []
+    for prompt in prompts:
+      response = self._response_from(prompt)
+      results.append(
+          lf.LMSamplingResult(
+              [lf.LMSample(response, 1.0)],
+              usage=lf.LMSamplingUsage(
+                  prompt_tokens=len(prompt.text),
+                  completion_tokens=len(response.text),
+                  total_tokens=len(prompt.text) + len(response.text),
+              )
+          )
+      )
+    return results
+
+  @abc.abstractmethod
+  def _response_from(self, prompt: lf.Message) -> lf.Message:
+    """Returns the response for the given prompt."""
+

 class Echo(Fake):
   """A simple echo language model for testing."""

-  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-    return [
-        lf.LMSamplingResult([lf.LMSample(prompt.text, 1.0)])
-        for prompt in prompts
-    ]
+  def _response_from(self, prompt: lf.Message) -> lf.Message:
+    return lf.AIMessage(prompt.text)


 @lf.use_init_args(['response'])
@@ -43,11 +61,8 @@ class StaticResponse(Fake):
       'A canned response that will be returned regardless of the prompt.'
   ]

-  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-    return [
-        lf.LMSamplingResult([lf.LMSample(self.response, 1.0)])
-        for _ in prompts
-    ]
+  def _response_from(self, prompt: lf.Message) -> lf.Message:
+    return lf.AIMessage(self.response)


 @lf.use_init_args(['mapping'])
@@ -59,11 +74,8 @@ class StaticMapping(Fake):
       'A mapping from prompt to response.'
   ]

-  def _sample(self, prompts: list[str]) -> list[lf.LMSamplingResult]:
-    return [
-        lf.LMSamplingResult([lf.LMSample(self.mapping[prompt], 1.0)])
-        for prompt in prompts
-    ]
+  def _response_from(self, prompt: lf.Message) -> lf.Message:
+    return lf.AIMessage(self.mapping[prompt])


 @lf.use_init_args(['sequence'])
@@ -79,10 +91,7 @@ class StaticSequence(Fake):
     super()._on_bound()
     self._pos = 0

-  def _sample(self, prompts: list[str]) -> list[lf.LMSamplingResult]:
-    results = []
-    for _ in prompts:
-      results.append(lf.LMSamplingResult(
-          [lf.LMSample(self.sequence[self._pos], 1.0)]))
-      self._pos += 1
-    return results
+  def _response_from(self, prompt: lf.Message) -> lf.Message:
+    r = lf.AIMessage(self.sequence[self._pos])
+    self._pos += 1
+    return r
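Note: the refactor above moves sampling and usage accounting into the Fake base class, so concrete fakes only implement _response_from. A minimal sketch of a custom fake under that contract (the class name and reply are hypothetical, purely for illustration):

import langfun.core as lf
from langfun.core.llms import fake

class UppercaseEcho(fake.Fake):
  """Hypothetical fake that replies with the prompt text upper-cased."""

  def _response_from(self, prompt: lf.Message) -> lf.Message:
    # Fake._sample wraps this into an LMSamplingResult and fills in
    # LMSamplingUsage from prompt/response character counts.
    return lf.AIMessage(prompt.text.upper())

lm = UppercaseEcho()
print(lm(prompt='hello').text)  # 'HELLO'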
@@ -25,7 +25,12 @@ class EchoTest(unittest.TestCase):
   def test_sample(self):
     lm = fakelm.Echo()
     self.assertEqual(
-        lm.sample(['hi']), [lf.LMSamplingResult([lf.LMSample('hi', 1.0)])]
+        lm.sample(['hi']),
+        [
+            lf.LMSamplingResult(
+                [lf.LMSample('hi', 1.0)],
+                lf.LMSamplingUsage(2, 2, 4))
+        ]
     )

   def test_call(self):
@@ -53,11 +58,21 @@ class StaticResponseTest(unittest.TestCase):
     lm = fakelm.StaticResponse(canned_response)
     self.assertEqual(
         lm.sample(['hi']),
-        [lf.LMSamplingResult([lf.LMSample(canned_response, 1.0)])],
+        [
+            lf.LMSamplingResult(
+                [lf.LMSample(canned_response, 1.0)],
+                usage=lf.LMSamplingUsage(2, 38, 40)
+            )
+        ],
     )
     self.assertEqual(
         lm.sample(['Tell me a joke.']),
-        [lf.LMSamplingResult([lf.LMSample(canned_response, 1.0)])],
+        [
+            lf.LMSamplingResult(
+                [lf.LMSample(canned_response, 1.0)],
+                usage=lf.LMSamplingUsage(15, 38, 53)
+            )
+        ],
     )

   def test_call(self):
@@ -85,8 +100,14 @@ class StaticMappingTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(['Hi', 'How are you?']),
         [
-            lf.LMSamplingResult([lf.LMSample('Hello', 1.0)]),
-            lf.LMSamplingResult([lf.LMSample('I am fine, how about you?', 1.0)])
+            lf.LMSamplingResult(
+                [lf.LMSample('Hello', 1.0)],
+                usage=lf.LMSamplingUsage(2, 5, 7)
+            ),
+            lf.LMSamplingResult(
+                [lf.LMSample('I am fine, how about you?', 1.0)],
+                usage=lf.LMSamplingUsage(12, 25, 37)
+            )
         ]
     )
     with self.assertRaises(KeyError):
@@ -104,8 +125,14 @@ class StaticSequenceTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(['Hi', 'How are you?']),
         [
-            lf.LMSamplingResult([lf.LMSample('Hello', 1.0)]),
-            lf.LMSamplingResult([lf.LMSample('I am fine, how about you?', 1.0)])
+            lf.LMSamplingResult(
+                [lf.LMSample('Hello', 1.0)],
+                usage=lf.LMSamplingUsage(2, 5, 7)
+            ),
+            lf.LMSamplingResult(
+                [lf.LMSample('I am fine, how about you?', 1.0)],
+                usage=lf.LMSamplingUsage(12, 25, 37)
+            )
         ]
     )
     with self.assertRaises(IndexError):
@@ -26,20 +26,6 @@ from openai import openai_object
 import pyglove as pg


-class Usage(pg.Object):
-  """Usage information per completion."""
-
-  prompt_tokens: int
-  completion_tokens: int
-  total_tokens: int
-
-
-class LMSamplingResult(lf.LMSamplingResult):
-  """LMSamplingResult with usage information."""
-
-  usage: Usage | None = None
-
-
 SUPPORTED_MODELS_AND_SETTINGS = [
     # Model name, max concurrent requests.
     # The concurrent requests is estimated by TPM/RPM from
@@ -181,7 +167,7 @@ class OpenAI(lf.LanguageModel):
       args['stop'] = options.stop
     return args

-  def _sample(self, prompts: list[lf.Message]) -> list[LMSamplingResult]:
+  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
     assert self._api_initialized
     if self.is_chat_model:
       return self._chat_complete_batch(prompts)
@@ -189,7 +175,8 @@ class OpenAI(lf.LanguageModel):
     return self._complete_batch(prompts)

   def _complete_batch(
-      self, prompts: list[lf.Message]) -> list[LMSamplingResult]:
+      self, prompts: list[lf.Message]
+  ) -> list[lf.LMSamplingResult]:

     def _open_ai_completion(prompts):
       response = openai.Completion.create(
@@ -204,13 +191,13 @@ class OpenAI(lf.LanguageModel):
            lf.LMSample(choice.text.strip(), score=choice.logprobs or 0.0)
        )

-      usage = Usage(
+      usage = lf.LMSamplingUsage(
          prompt_tokens=response.usage.prompt_tokens,
          completion_tokens=response.usage.completion_tokens,
          total_tokens=response.usage.total_tokens,
      )
      return [
-          LMSamplingResult(
+          lf.LMSamplingResult(
              samples_by_index[index], usage=usage if index == 0 else None
          )
          for index in sorted(samples_by_index.keys())
@@ -231,7 +218,7 @@ class OpenAI(lf.LanguageModel):

   def _chat_complete_batch(
       self, prompts: list[lf.Message]
-  ) -> list[LMSamplingResult]:
+  ) -> list[lf.LMSamplingResult]:
     def _open_ai_chat_completion(prompt: lf.Message):
       if self.multimodal:
         content = []
@@ -272,9 +259,9 @@ class OpenAI(lf.LanguageModel):
            )
        )

-      return LMSamplingResult(
+      return lf.LMSamplingResult(
          samples=samples,
-          usage=Usage(
+          usage=lf.LMSamplingUsage(
              prompt_tokens=response.usage.prompt_tokens,
              completion_tokens=response.usage.completion_tokens,
              total_tokens=response.usage.total_tokens,
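Note: with the module-local Usage and LMSamplingResult subclasses removed, OpenAI usage is now reported through the shared lf.LMSamplingUsage / lf.LMSamplingResult types, so downstream code can aggregate usage without OpenAI-specific imports. A small hypothetical helper as a sketch (usage can be None; the completion path above attaches it only to the first result of a batch):

import langfun.core as lf

def total_tokens(results: list[lf.LMSamplingResult]) -> int:
  # Sum total_tokens across results, skipping those without usage info.
  return sum(r.usage.total_tokens for r in results if r.usage is not None)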
@@ -32,11 +32,14 @@ def mock_completion_query(prompt, *, n=1, **kwargs):
        text=f'Sample {k} for prompt {i}.',
        logprobs=k / 10,
    ))
-  return pg.Dict(choices=choices, usage=openai.Usage(
-      prompt_tokens=100,
-      completion_tokens=100,
-      total_tokens=200,
-  ))
+  return pg.Dict(
+      choices=choices,
+      usage=lf.LMSamplingUsage(
+          prompt_tokens=100,
+          completion_tokens=100,
+          total_tokens=200,
+      ),
+  )


 def mock_chat_completion_query(messages, *, n=1, **kwargs):
@@ -49,11 +52,14 @@ def mock_chat_completion_query(messages, *, n=1, **kwargs):
        ),
        logprobs=None,
    ))
-  return pg.Dict(choices=choices, usage=openai.Usage(
-      prompt_tokens=100,
-      completion_tokens=100,
-      total_tokens=200,
-  ))
+  return pg.Dict(
+      choices=choices,
+      usage=lf.LMSamplingUsage(
+          prompt_tokens=100,
+          completion_tokens=100,
+          total_tokens=200,
+      ),
+  )


 def mock_chat_completion_query_vision(messages, *, n=1, **kwargs):
@@ -69,11 +75,14 @@ def mock_chat_completion_query_vision(messages, *, n=1, **kwargs):
        ),
        logprobs=None,
    ))
-  return pg.Dict(choices=choices, usage=openai.Usage(
-      prompt_tokens=100,
-      completion_tokens=100,
-      total_tokens=200,
-  ))
+  return pg.Dict(
+      choices=choices,
+      usage=lf.LMSamplingUsage(
+          prompt_tokens=100,
+          completion_tokens=100,
+          total_tokens=200,
+      ),
+  )


 class OpenaiTest(unittest.TestCase):
@@ -169,18 +178,28 @@ class OpenaiTest(unittest.TestCase):
      )

      self.assertEqual(len(results), 2)
-      self.assertEqual(results[0], openai.LMSamplingResult([
-          lf.LMSample('Sample 0 for prompt 0.', score=0.0),
-          lf.LMSample('Sample 1 for prompt 0.', score=0.1),
-          lf.LMSample('Sample 2 for prompt 0.', score=0.2),
-      ], usage=openai.Usage(
-          prompt_tokens=100, completion_tokens=100, total_tokens=200)))
-
-      self.assertEqual(results[1], openai.LMSamplingResult([
-          lf.LMSample('Sample 0 for prompt 1.', score=0.0),
-          lf.LMSample('Sample 1 for prompt 1.', score=0.1),
-          lf.LMSample('Sample 2 for prompt 1.', score=0.2),
-      ]))
+      self.assertEqual(
+          results[0],
+          lf.LMSamplingResult(
+              [
+                  lf.LMSample('Sample 0 for prompt 0.', score=0.0),
+                  lf.LMSample('Sample 1 for prompt 0.', score=0.1),
+                  lf.LMSample('Sample 2 for prompt 0.', score=0.2),
+              ],
+              usage=lf.LMSamplingUsage(
+                  prompt_tokens=100, completion_tokens=100, total_tokens=200
+              ),
+          ),
+      )
+
+      self.assertEqual(
+          results[1],
+          lf.LMSamplingResult([
+              lf.LMSample('Sample 0 for prompt 1.', score=0.0),
+              lf.LMSample('Sample 1 for prompt 1.', score=0.1),
+              lf.LMSample('Sample 2 for prompt 1.', score=0.2),
+          ]),
+      )

   def test_sample_chat_completion(self):
     with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
@@ -191,18 +210,32 @@ class OpenaiTest(unittest.TestCase):
      )

      self.assertEqual(len(results), 2)
-      self.assertEqual(results[0], openai.LMSamplingResult([
-          lf.LMSample('Sample 0 for message.', score=0.0),
-          lf.LMSample('Sample 1 for message.', score=0.0),
-          lf.LMSample('Sample 2 for message.', score=0.0),
-      ], usage=openai.Usage(
-          prompt_tokens=100, completion_tokens=100, total_tokens=200)))
-      self.assertEqual(results[1], openai.LMSamplingResult([
-          lf.LMSample('Sample 0 for message.', score=0.0),
-          lf.LMSample('Sample 1 for message.', score=0.0),
-          lf.LMSample('Sample 2 for message.', score=0.0),
-      ], usage=openai.Usage(
-          prompt_tokens=100, completion_tokens=100, total_tokens=200)))
+      self.assertEqual(
+          results[0],
+          lf.LMSamplingResult(
+              [
+                  lf.LMSample('Sample 0 for message.', score=0.0),
+                  lf.LMSample('Sample 1 for message.', score=0.0),
+                  lf.LMSample('Sample 2 for message.', score=0.0),
+              ],
+              usage=lf.LMSamplingUsage(
+                  prompt_tokens=100, completion_tokens=100, total_tokens=200
+              ),
+          ),
+      )
+      self.assertEqual(
+          results[1],
+          lf.LMSamplingResult(
+              [
+                  lf.LMSample('Sample 0 for message.', score=0.0),
+                  lf.LMSample('Sample 1 for message.', score=0.0),
+                  lf.LMSample('Sample 2 for message.', score=0.0),
+              ],
+              usage=lf.LMSamplingUsage(
+                  prompt_tokens=100, completion_tokens=100, total_tokens=200
+              ),
+          ),
+      )

   def test_sample_with_contextual_options(self):
     with mock.patch('openai.Completion.create') as mock_completion:
@@ -212,11 +245,18 @@ class OpenaiTest(unittest.TestCase):
        results = lm.sample(['hello'])

      self.assertEqual(len(results), 1)
-      self.assertEqual(results[0], openai.LMSamplingResult([
-          lf.LMSample('Sample 0 for prompt 0.', score=0.0),
-          lf.LMSample('Sample 1 for prompt 0.', score=0.1),
-      ], usage=openai.Usage(
-          prompt_tokens=100, completion_tokens=100, total_tokens=200)))
+      self.assertEqual(
+          results[0],
+          lf.LMSamplingResult(
+              [
+                  lf.LMSample('Sample 0 for prompt 0.', score=0.0),
+                  lf.LMSample('Sample 1 for prompt 0.', score=0.1),
+              ],
+              usage=lf.LMSamplingUsage(
+                  prompt_tokens=100, completion_tokens=100, total_tokens=200
+              ),
+          ),
+      )


 if __name__ == '__main__':