langfun-0.0.2.dev20240330-py3-none-any.whl → langfun-0.1.2.dev202501140804-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
- langfun/__init__.py +22 -2
- langfun/core/__init__.py +17 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -28
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +69 -2
- langfun/core/component_test.py +54 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +17 -0
- langfun/core/eval/base.py +767 -140
- langfun/core/eval/base_test.py +238 -53
- langfun/core/eval/matching.py +80 -76
- langfun/core/eval/matching_test.py +19 -9
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +37 -28
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +3 -21
- langfun/core/langfunc_test.py +26 -8
- langfun/core/language_model.py +686 -48
- langfun/core/language_model_test.py +681 -44
- langfun/core/llms/__init__.py +100 -12
- langfun/core/llms/anthropic.py +488 -0
- langfun/core/llms/anthropic_test.py +235 -0
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +88 -28
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +39 -26
- langfun/core/llms/fake_test.py +136 -11
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -197
- langfun/core/llms/groq.py +276 -0
- langfun/core/llms/groq_test.py +64 -0
- langfun/core/llms/llama_cpp.py +15 -40
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +436 -226
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +35 -174
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -23
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +15 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +9 -8
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +278 -0
- langfun/core/structured/function_generation_test.py +399 -0
- langfun/core/structured/mapping.py +150 -46
- langfun/core/structured/mapping_test.py +105 -0
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +71 -22
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
- langfun/core/structured/schema.py +208 -99
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +133 -34
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +240 -11
- langfun/core/template_test.py +146 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +14 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -217
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
--- langfun/core/language_model_test.py
+++ langfun/core/language_model_test.py
@@ -27,38 +27,53 @@ import pyglove as pg
 @pg.use_init_args(['failures_before_attempt'])
 class MockModel(lm_lib.LanguageModel):
   """A mock model that echo back user prompts."""
-
   failures_before_attempt: int = 0
+  name: str = 'MockModel'

   def _sample(self,
               prompts: list[message_lib.Message]
               ) -> list[lm_lib.LMSamplingResult]:
     context = pg.Dict(attempt=0)

-    def fake_sample(
+    def fake_sample(prompt):
       if context.attempt >= self.failures_before_attempt:
-        return
-
-
-
-
-
-
+        return lm_lib.LMSamplingResult(
+            [
+                lm_lib.LMSample(  # pylint: disable=g-complex-comprehension
+                    response=prompt.text * self.sampling_options.top_k,
+                    score=self.sampling_options.temperature or -1.0,
+                )
+            ],
+            usage=lm_lib.LMSamplingUsage(
+                prompt_tokens=100,
+                completion_tokens=100,
+                total_tokens=200,
+                estimated_cost=1.0,
+            ),
+        )
+      else:
+        context.attempt += 1
       raise ValueError('Failed to sample prompts.')

-
-        fake_sample,
-
-
-
-
+    results = self._parallel_execute_with_currency_control(
+        fake_sample, prompts, retry_on_errors=ValueError
+    )
+    for result in results:
+      result.usage.retry_stats.rebind(
+          total_call_interval=0, skip_notification=True
+      )
+    return results
+
+  @property
+  def model_id(self) -> str:
+    return self.name


 class MockScoringModel(MockModel):

   def _score(
       self,
-      prompt: message_lib.Message,
+      prompt: message_lib.Message | list[message_lib.Message],
       completions: list[message_lib.Message],
       **kwargs
   ) -> list[lm_lib.LMScoringResult]:
@@ -67,19 +82,26 @@ class MockScoringModel(MockModel):
     ]


+class MockTokenizeModel(MockModel):
+
+  def _tokenize(
+      self, prompt: message_lib.Message) -> list[tuple[str | bytes, int]]:
+    return [(w, i) for i, w in enumerate(prompt.text.split(' '))]
+
+
 class LMSamplingOptionsTest(unittest.TestCase):
   """Tests for LMSamplingOptions."""

   def test_cache_key(self):
     options = lm_lib.LMSamplingOptions()
     key1 = options.cache_key()
-    self.assertEqual(key1, (
+    self.assertEqual(key1, (None, None, 1, 40, None, None))
     with options.override(temperature=1.0, max_tokens=256):
       key2 = options.cache_key()
       self.assertEqual(key2, (1.0, 256, 1, 40, None, None))

     # Make sure key1 does not change upon override.
-    self.assertEqual(key1, (
+    self.assertEqual(key1, (None, None, 1, 40, None, None))


 class LanguageModelTest(unittest.TestCase):
@@ -95,13 +117,60 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(lm.sampling_options.top_k, 2)
     self.assertEqual(lm.max_attempts, 2)

+  def test_subclassing(self):
+
+    class ChildModel(lm_lib.LanguageModel):
+
+      sampling_options = lm_lib.LMSamplingOptions(
+          temperature=0.5, top_k=20
+      )
+
+      def _sample(self, *args, **kwargs):
+        pass
+
+    lm = ChildModel(top_k=10)
+    self.assertEqual(lm.sampling_options.temperature, 0.5)
+    self.assertEqual(lm.sampling_options.top_k, 10)
+
   def test_sample(self):
     lm = MockModel(top_k=1)
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar']),
         [
-            lm_lib.LMSamplingResult(
-
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            score=-1.0,
+                            logprobs=None,
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar',
+                            score=-1.0,
+                            logprobs=None,
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+            ),
         ],
     )
     # Test override sampling_options.
@@ -112,38 +181,139 @@ class LanguageModelTest(unittest.TestCase):
         ),
         [
             lm_lib.LMSamplingResult(
-                [
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo' * 2,
+                            score=0.5,
+                            logprobs=None,
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.5,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
            ),
            lm_lib.LMSamplingResult(
-                [
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar' * 2,
+                            score=0.5,
+                            logprobs=None,
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.5,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200,
+                    num_requests=1, estimated_cost=1.0,
+                ),
            ),
-        ]
+        ]
     )
     # Test override individual flags within sampling_options.
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar'], temperature=1.0),
         [
-            lm_lib.LMSamplingResult(
-
-
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            score=1.0,
+                            logprobs=None,
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar',
+                            score=1.0,
+                            logprobs=None,
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200,
+                    num_requests=1, estimated_cost=1.0,
+                ),
+            ),
+        ]
     )
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar'], top_k=2, temperature=0.7),
         [
            lm_lib.LMSamplingResult(
-                [
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo' * 2,
+                            score=0.7,
+                            logprobs=None,
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.7,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
            ),
            lm_lib.LMSamplingResult(
-                [
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar' * 2,
+                            score=0.7,
+                            logprobs=None,
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.7,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200,
+                    num_requests=1, estimated_cost=1.0,
+                ),
            ),
-        ]
+        ]
     )

   def test_call(self):
     lm = MockModel(sampling_options=lm_lib.LMSamplingOptions(top_k=1))
     response = lm(prompt='foo')
     self.assertEqual(response.text, 'foo')
-    self.assertEqual(response.score,
+    self.assertEqual(response.score, -1.0)
+    self.assertIsNone(response.logprobs)
+    self.assertEqual(
+        response.usage, lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0)
+    )

     # Test override sampling_options.
     self.assertEqual(
@@ -158,16 +328,53 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar']),
         [
-            lm_lib.LMSamplingResult(
-
-
-
-
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            cache_seed=0,
+                            score=-1.0,
+                            logprobs=None,
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(
+                                100, 100, 200, 1, 1.0
+                            ),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar',
+                            cache_seed=0,
+                            score=-1.0,
+                            logprobs=None,
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+            ),
+        ],
+    )
     self.assertEqual(cache.stats.num_queries, 2)
     self.assertEqual(cache.stats.num_hits, 0)
     self.assertEqual(cache.stats.num_updates, 2)

-
+    result = lm('foo')
+    self.assertEqual(result, 'foo')
+    self.assertTrue(result.metadata.is_cached)
     self.assertEqual(lm('bar'), 'bar')
     self.assertEqual(cache.stats.num_queries, 4)
     self.assertEqual(cache.stats.num_hits, 2)
@@ -181,10 +388,42 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(prompts=['foo', 'baz'], temperature=1.0),
         [
-            lm_lib.LMSamplingResult(
-
-
-
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            cache_seed=0,
+                            score=1.0,
+                            logprobs=None,
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'baz',
+                            cache_seed=0,
+                            score=1.0,
+                            logprobs=None,
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+            ),
         ],
     )
     self.assertEqual(cache.stats.num_queries, 6)
@@ -209,13 +448,50 @@ class LanguageModelTest(unittest.TestCase):

   def test_retry(self):
     lm = MockModel(
-        failures_before_attempt=1, top_k=1,
+        failures_before_attempt=1, top_k=1, max_attempts=2, retry_interval=1
     )
     with self.assertRaisesRegex(
         concurrent.RetryError, 'Calling .* failed after 1 attempts'
     ):
       lm('foo', max_attempts=1)
-
+
+    usage = lm_lib.LMSamplingUsage(
+        prompt_tokens=100,
+        completion_tokens=100,
+        total_tokens=200,
+        num_requests=1,
+        estimated_cost=1.0,
+        retry_stats=lm_lib.RetryStats(
+            num_occurences=1,
+            total_wait_interval=1,
+            errors={'ValueError': 1},
+        ),
+    )
+    out = lm.sample(['foo'])
+    self.assertEqual(
+        # lm.sample(['foo'], max_attempts=2),
+        out,
+        [
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            score=-1.0,
+                            logprobs=None,
+                            is_cached=False,
+                            usage=usage,
+                            tags=['lm-response'],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=usage,
+                is_cached=False,
+            )
+        ],
+    )

   def test_debug(self):
     class Image(modality.Modality):
@@ -227,8 +503,9 @@ class LanguageModelTest(unittest.TestCase):
     with contextlib.redirect_stdout(string_io):
       self.assertEqual(
           lm(message_lib.UserMessage(
-              'hi
-              'hi
+              'hi <<[[image]]>>', image=Image()), debug=True),
+          'hi <<[[image]]>>'
+      )

     debug_info = string_io.getvalue()
     self.assertIn('[0] LM INFO', debug_info)
@@ -317,6 +594,17 @@ class LanguageModelTest(unittest.TestCase):
           ],
       )

+      self.assertEqual(
+          lm.score(
+              [message_lib.UserMessage('hi {{image}}', image=Image()),
+               message_lib.UserMessage('hi {{image}}', image=Image())],
+              ['1', '2'], debug=debug_mode),
+          [
+              lm_lib.LMScoringResult(score=-0.0),
+              lm_lib.LMScoringResult(score=-1.0),
+          ],
+      )
+
       debug_info = string_io.getvalue()
       expected_included = [
           debug_prints[f]
@@ -337,10 +625,359 @@ class LanguageModelTest(unittest.TestCase):
       if debug_mode & lm_lib.LMDebugMode.PROMPT:
         self.assertIn('[0] MODALITY OBJECTS SENT TO LM', debug_info)

+  def test_score_with_unmatched_prompt_and_completions(self):
+    with self.assertRaises(ValueError):
+      MockScoringModel().score(['hi',], ['1', '2', '3'])
+
   def test_score_with_unsupported_model(self):
     with self.assertRaises(NotImplementedError):
       MockModel().score('hi', ['1', '2'])

+  def test_tokenize(self):
+    info_flag = lm_lib.LMDebugMode.INFO
+    prompt_flag = lm_lib.LMDebugMode.PROMPT
+    response_flag = lm_lib.LMDebugMode.RESPONSE
+    debug_prints = {
+        info_flag: 'LM INFO',
+        prompt_flag: 'PROMPT TO TOKENIZE',
+        response_flag: 'TOKENS RETURNED',
+    }
+    debug_modes = [
+        info_flag,
+        prompt_flag,
+        response_flag,
+        info_flag | prompt_flag,
+        info_flag | response_flag,
+        prompt_flag | response_flag,
+        info_flag | prompt_flag | response_flag,
+    ]
+
+    class Image(modality.Modality):
+      def to_bytes(self):
+        return b'fake_image'
+
+    for debug_mode in debug_modes:
+      string_io = io.StringIO()
+      lm = MockTokenizeModel()
+
+      with contextlib.redirect_stdout(string_io):
+        self.assertEqual(
+            lm.tokenize(
+                message_lib.UserMessage('hi <<[[image]]>>', image=Image()),
+                debug=debug_mode),
+            [('hi', 0), ('<<[[image]]>>', 1)],
+        )
+
+      debug_info = string_io.getvalue()
+      expected_included = [
+          debug_prints[f]
+          for f in lm_lib.LMDebugMode
+          if f != lm_lib.LMDebugMode.NONE and f in debug_mode
+      ]
+      expected_excluded = [
+          debug_prints[f]
+          for f in lm_lib.LMDebugMode
+          if f != lm_lib.LMDebugMode.NONE and f not in debug_mode
+      ]
+
+      for expected_include in expected_included:
+        self.assertIn(expected_include, debug_info)
+      for expected_exclude in expected_excluded:
+        self.assertNotIn(expected_exclude, debug_info)
+
+      if debug_mode & lm_lib.LMDebugMode.PROMPT:
+        self.assertIn('[0] MODALITY OBJECTS SENT TO LM', debug_info)
+
+  def test_tokenize_with_unsupported_model(self):
+    with self.assertRaises(NotImplementedError):
+      MockModel().tokenize('hi')
+
+  def test_rate_to_max_concurrency_no_rpm_no_tpm(self) -> None:
+    lm = MockModel()
+    self.assertEqual(
+        lm_lib.DEFAULT_MAX_CONCURRENCY,
+        lm.rate_to_max_concurrency(requests_per_min=0, tokens_per_min=0),
+    )
+    self.assertEqual(
+        lm_lib.DEFAULT_MAX_CONCURRENCY,
+        lm.rate_to_max_concurrency(requests_per_min=-1, tokens_per_min=-1),
+    )
+
+  def test_rate_to_max_concurrency_only_rpm_specified_uses_rpm(self) -> None:
+    lm = MockModel()
+    test_rpm = 1e4
+    self.assertEqual(
+        lm.rate_to_max_concurrency(requests_per_min=test_rpm),
+        int(test_rpm / 60)
+    )
+
+  def test_rate_to_max_concurrency_tpm_specified_uses_tpm(self) -> None:
+    lm = MockModel()
+    test_tpm = 1e7
+    self.assertEqual(
+        lm.rate_to_max_concurrency(requests_per_min=1, tokens_per_min=test_tpm),
+        int(test_tpm / lm_lib.TOKENS_PER_REQUEST / 60),
+    )
+
+  def test_rate_to_max_concurrency_small_rate_returns_one(self) -> None:
+    lm = MockModel()
+    self.assertEqual(lm.rate_to_max_concurrency(requests_per_min=1), 1)
+    self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)
+
+  def test_track_usages(self):
+    lm = MockModel(name='model1')
+    lm2 = MockModel(name='model2')
+    with lm_lib.track_usages() as usages1:
+      _ = lm('hi')
+      with lm_lib.track_usages(lm2) as usages2:
+        with lm_lib.track_usages('model1') as usages3:
+          with lm_lib.track_usages('model1', lm2) as usages4:
+            def call_lm(prompt):
+              _ = lm.sample([prompt] * 2)
+            lm2('hi')
+            list(concurrent.concurrent_map(call_lm, ['hi', 'hello']))
+
+    self.assertEqual(usages2.uncached.breakdown, {
+        'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+    })
+    self.assertFalse(usages2.cached)
+    self.assertEqual(usages3.uncached.breakdown, {
+        'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4, 4, 4.0),
+    })
+    self.assertFalse(usages3.cached)
+    self.assertEqual(usages4.uncached.breakdown, {
+        'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4, 4, 4.0),
+        'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+    })
+    self.assertFalse(usages4.cached)
+    self.assertEqual(usages1.uncached.breakdown, {
+        'model1': lm_lib.LMSamplingUsage(100 * 5, 100 * 5, 200 * 5, 5, 5.0),
+        'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+    })
+    self.assertFalse(usages1.cached)
+    self.assertEqual(
+        usages1.total,
+        lm_lib.LMSamplingUsage(100 * 6, 100 * 6, 200 * 6, 6, 6.0),
+    )
+
+    cache = in_memory.InMemory()
+    lm = MockModel(cache=cache, name='model1')
+    with lm_lib.track_usages() as usages1:
+      _ = lm('hi')
+    self.assertEqual(usages1.uncached.breakdown, {
+        'model1': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+    })
+    self.assertFalse(usages1.cached)
+    with lm_lib.track_usages() as usages2:
+      _ = lm('hi')
+    self.assertEqual(usages2.cached.breakdown, {
+        'model1': lm_lib.LMSamplingUsage(100, 100, 200, 1, 0.0),
+    })
+    self.assertFalse(usages2.uncached)
+
+
+class LMSamplingUsageTest(unittest.TestCase):
+
+  def test_basics(self):
+    usage = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
+    self.assertEqual(usage.num_requests, 4)
+    self.assertEqual(usage.prompt_tokens, 100)
+    self.assertEqual(usage.completion_tokens, 200)
+    self.assertEqual(usage.total_tokens, 300)
+    self.assertEqual(usage.estimated_cost, 5.0)
+    self.assertEqual(usage.average_prompt_tokens, 25)
+    self.assertEqual(usage.average_completion_tokens, 50)
+    self.assertEqual(usage.average_total_tokens, 75)
+    self.assertEqual(usage.average_estimated_cost, 1.25)
+
+  def test_add(self):
+    usage1 = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
+    usage1.rebind(retry_stats=lm_lib.RetryStats(1, 3, 4, {'e1': 1}))
+    usage2 = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
+    self.assertEqual(usage1 + usage2, usage1 + usage2)
+    self.assertIs(usage1 + None, usage1)
+    self.assertIs(None + usage1, usage1)
+    usage3 = lm_lib.LMSamplingUsage(100, 200, 300, 4, None)
+    usage3.rebind(retry_stats=lm_lib.RetryStats(2, 4, 5, {'e1': 2, 'e2': 3}))
+    self.assertEqual(
+        usage1 + usage3,
+        lm_lib.LMSamplingUsage(
+            200,
+            400,
+            600,
+            8,
+            5.0,
+            retry_stats=lm_lib.RetryStats(3, 7, 9, {'e1': 3, 'e2': 3}),
+        ),
+    )
+    self.assertEqual(
+        usage3 + usage1,
+        lm_lib.LMSamplingUsage(
+            200,
+            400,
+            600,
+            8,
+            5.0,
+            retry_stats=lm_lib.RetryStats(3, 7, 9, {'e1': 3, 'e2': 3}),
+        ),
+    )
+
+  def test_usage_not_available(self):
+    usage_not_available = lm_lib.UsageNotAvailable()
+    self.assertEqual(usage_not_available.prompt_tokens, 0)
+    self.assertEqual(usage_not_available.completion_tokens, 0)
+    self.assertEqual(usage_not_available.total_tokens, 0)
+    self.assertEqual(usage_not_available.average_prompt_tokens, 0)
+    self.assertEqual(usage_not_available.average_completion_tokens, 0)
+    self.assertEqual(usage_not_available.average_total_tokens, 0)
+    self.assertIsNone(usage_not_available.average_estimated_cost)
+    self.assertTrue(usage_not_available)
+    self.assertEqual(
+        usage_not_available + lm_lib.LMSamplingUsage(1, 2, 3, 4, 5.0),
+        lm_lib.UsageNotAvailable(num_requests=5)
+    )
+    self.assertEqual(
+        lm_lib.LMSamplingUsage(1, 2, 3, 4, 5.0) + usage_not_available,
+        lm_lib.UsageNotAvailable(num_requests=5)
+    )
+    self.assertIs(None + usage_not_available, usage_not_available)
+    self.assertIs(usage_not_available + None, usage_not_available)
+
+
+class UsageSummaryTest(unittest.TestCase):
+
+  def test_basics(self):
+    usage_summary = lm_lib.UsageSummary()
+    self.assertFalse(usage_summary.total)
+    self.assertFalse(usage_summary.cached)
+    self.assertFalse(usage_summary.uncached)
+
+    # Add uncached.
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    self.assertEqual(
+        usage_summary.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0)
+    )
+    self.assertEqual(
+        usage_summary.uncached.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0)
+    )
+    # Add cached.
+    self.assertFalse(usage_summary.cached)
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), True
+    )
+    self.assertEqual(
+        usage_summary.total, lm_lib.LMSamplingUsage(2, 4, 6, 2, 5.0)
+    )
+    self.assertEqual(
+        usage_summary.cached.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 0.0)
+    )
+    # Add UsageNotAvailable.
+    usage_summary.add(
+        'model1', lm_lib.UsageNotAvailable(num_requests=1), False
+    )
+    self.assertEqual(
+        usage_summary.total, lm_lib.UsageNotAvailable(num_requests=3)
+    )
+    self.assertEqual(
+        usage_summary.uncached.total, lm_lib.UsageNotAvailable(num_requests=2)
+    )
+
+  def test_merge(self):
+    usage_summary = lm_lib.UsageSummary()
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    usage_summary.add(
+        'model2', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    usage_summary2 = lm_lib.UsageSummary()
+    usage_summary2.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    usage_summary2.add(
+        'model3', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    usage_summary2.merge(usage_summary)
+    self.assertEqual(
+        usage_summary2,
+        lm_lib.UsageSummary(
+            cached=lm_lib.UsageSummary.AggregatedUsage(
+                total=lm_lib.LMSamplingUsage(
+                    prompt_tokens=0,
+                    completion_tokens=0,
+                    total_tokens=0,
+                    num_requests=0,
+                    estimated_cost=0.0,
+                ),
+                breakdown={}
+            ),
+            uncached=lm_lib.UsageSummary.AggregatedUsage(
+                total=lm_lib.LMSamplingUsage(
+                    prompt_tokens=5,
+                    completion_tokens=10,
+                    total_tokens=15,
+                    num_requests=5,
+                    estimated_cost=25.0
+                ),
+                breakdown=dict(
+                    model1=lm_lib.LMSamplingUsage(
+                        prompt_tokens=3,
+                        completion_tokens=6,
+                        total_tokens=9,
+                        num_requests=3,
+                        estimated_cost=15.0
+                    ),
+                    model3=lm_lib.LMSamplingUsage(
+                        prompt_tokens=1,
+                        completion_tokens=2,
+                        total_tokens=3,
+                        num_requests=1,
+                        estimated_cost=5.0
+                    ),
+                    model2=lm_lib.LMSamplingUsage(
+                        prompt_tokens=1,
+                        completion_tokens=2,
+                        total_tokens=3,
+                        num_requests=1,
+                        estimated_cost=5.0
+                    )
+                )
+            )
+        )
+    )
+
+  def test_html_view(self):
+    usage_summary = lm_lib.UsageSummary()
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    self.assertIn(
+        '5.000',
+        usage_summary.to_html(extra_flags=dict(as_badge=True)).content
+    )
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    self.assertIn(
+        '10.000',
+        usage_summary.to_html(
+            extra_flags=dict(as_badge=True, interactive=True)
+        ).content
+    )
+    self.assertTrue(
+        usage_summary.to_html().content.startswith('<details open')
+    )
+    with pg.views.html.controls.HtmlControl.track_scripts() as scripts:
+      usage_summary.add(
+          'model2', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+      )
+    self.assertEqual(len(scripts), 4)
+

 if __name__ == '__main__':
   unittest.main()