langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
- langfun/__init__.py +20 -2
- langfun/core/__init__.py +16 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -21
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +63 -2
- langfun/core/component_test.py +53 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +16 -1
- langfun/core/eval/base.py +622 -174
- langfun/core/eval/base_test.py +200 -54
- langfun/core/eval/matching.py +63 -76
- langfun/core/eval/matching_test.py +17 -8
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +26 -26
- langfun/core/eval/scoring_test.py +19 -2
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +4 -17
- langfun/core/langfunc_test.py +22 -6
- langfun/core/language_model.py +577 -39
- langfun/core/language_model_test.py +470 -56
- langfun/core/llms/__init__.py +87 -16
- langfun/core/llms/anthropic.py +312 -87
- langfun/core/llms/anthropic_test.py +71 -3
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +53 -2
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +11 -7
- langfun/core/llms/fake_test.py +14 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -202
- langfun/core/llms/groq.py +160 -144
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +15 -42
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +395 -203
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +30 -395
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -26
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +12 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +7 -6
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +60 -27
- langfun/core/structured/function_generation_test.py +72 -2
- langfun/core/structured/mapping.py +97 -47
- langfun/core/structured/mapping_test.py +90 -2
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +53 -9
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
- langfun/core/structured/schema.py +204 -97
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_test.py +130 -29
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +115 -1
- langfun/core/template_test.py +71 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +10 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -238
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
- langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
langfun/core/language_model_test.py (+470 -56)

(Deleted lines whose text the diff view did not preserve appear below as bare `-` markers.)

```diff
@@ -27,48 +27,53 @@ import pyglove as pg
 @pg.use_init_args(['failures_before_attempt'])
 class MockModel(lm_lib.LanguageModel):
   """A mock model that echo back user prompts."""
-
   failures_before_attempt: int = 0
+  name: str = 'MockModel'
 
   def _sample(self,
               prompts: list[message_lib.Message]
               ) -> list[lm_lib.LMSamplingResult]:
     context = pg.Dict(attempt=0)
 
-    def fake_sample(
+    def fake_sample(prompt):
       if context.attempt >= self.failures_before_attempt:
-        return
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-      context.attempt += 1
+        return lm_lib.LMSamplingResult(
+            [
+                lm_lib.LMSample(  # pylint: disable=g-complex-comprehension
+                    response=prompt.text * self.sampling_options.top_k,
+                    score=self.sampling_options.temperature or -1.0,
+                )
+            ],
+            usage=lm_lib.LMSamplingUsage(
+                prompt_tokens=100,
+                completion_tokens=100,
+                total_tokens=200,
+                estimated_cost=1.0,
+            ),
+        )
+      else:
+        context.attempt += 1
       raise ValueError('Failed to sample prompts.')
 
-
-      fake_sample,
-
-
-
-
+    results = self._parallel_execute_with_currency_control(
+        fake_sample, prompts, retry_on_errors=ValueError
+    )
+    for result in results:
+      result.usage.retry_stats.rebind(
+          total_call_interval=0, skip_notification=True
+      )
+    return results
+
+  @property
+  def model_id(self) -> str:
+    return self.name
 
 
 class MockScoringModel(MockModel):
 
   def _score(
       self,
-      prompt: message_lib.Message,
+      prompt: message_lib.Message | list[message_lib.Message],
       completions: list[message_lib.Message],
       **kwargs
   ) -> list[lm_lib.LMScoringResult]:
```
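Read together, the hunk above establishes the new `_sample` contract for `LanguageModel` subclasses: return one `LMSamplingResult` per prompt, report per-request usage (now including `estimated_cost`), and delegate fan-out plus retry handling to the base-class helper. A minimal sketch of a subclass under that contract, inferred from the `MockModel` changes above rather than from langfun documentation (`lm_lib` and `message_lib` are this test file's import aliases):

```python
# Sketch only: the contract is inferred from the MockModel diff above.
class EchoModel(lm_lib.LanguageModel):
  """Returns each prompt verbatim as a single sample."""

  def _sample(
      self, prompts: list[message_lib.Message]
  ) -> list[lm_lib.LMSamplingResult]:

    def sample_one(prompt: message_lib.Message) -> lm_lib.LMSamplingResult:
      return lm_lib.LMSamplingResult(
          [lm_lib.LMSample(response=prompt.text, score=0.0)],
          usage=lm_lib.LMSamplingUsage(
              prompt_tokens=0,
              completion_tokens=0,
              total_tokens=0,
              estimated_cost=0.0,
          ),
      )

    # The base-class helper (as used by MockModel above) fans out one call
    # per prompt and applies the model's retry policy.
    return self._parallel_execute_with_currency_control(sample_one, prompts)
```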
```diff
@@ -77,6 +82,13 @@ class MockScoringModel(MockModel):
     ]
 
 
+class MockTokenizeModel(MockModel):
+
+  def _tokenize(
+      self, prompt: message_lib.Message) -> list[tuple[str | bytes, int]]:
+    return [(w, i) for i, w in enumerate(prompt.text.split(' '))]
+
+
 class LMSamplingOptionsTest(unittest.TestCase):
   """Tests for LMSamplingOptions."""
 
```
```diff
@@ -105,6 +117,21 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(lm.sampling_options.top_k, 2)
     self.assertEqual(lm.max_attempts, 2)
 
+  def test_subclassing(self):
+
+    class ChildModel(lm_lib.LanguageModel):
+
+      sampling_options = lm_lib.LMSamplingOptions(
+          temperature=0.5, top_k=20
+      )
+
+      def _sample(self, *args, **kwargs):
+        pass
+
+    lm = ChildModel(top_k=10)
+    self.assertEqual(lm.sampling_options.temperature, 0.5)
+    self.assertEqual(lm.sampling_options.top_k, 10)
+
   def test_sample(self):
     lm = MockModel(top_k=1)
     self.assertEqual(
```
```diff
@@ -117,14 +144,15 @@ class LanguageModelTest(unittest.TestCase):
                     'foo',
                     score=-1.0,
                     logprobs=None,
-
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=-1.0,
                 logprobs=None,
             )
         ],
-        usage=lm_lib.LMSamplingUsage(100, 100, 200),
+        usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     ),
     lm_lib.LMSamplingResult(
         [
@@ -133,14 +161,15 @@ class LanguageModelTest(unittest.TestCase):
                     'bar',
                     score=-1.0,
                     logprobs=None,
-
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=-1.0,
                 logprobs=None,
             )
         ],
-        usage=lm_lib.LMSamplingUsage(100, 100, 200),
+        usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     ),
     ],
 )
@@ -158,14 +187,15 @@ class LanguageModelTest(unittest.TestCase):
                     'foo' * 2,
                     score=0.5,
                     logprobs=None,
-
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=0.5,
                 logprobs=None,
             ),
         ],
-        usage=lm_lib.LMSamplingUsage(100, 100, 200),
+        usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     ),
     lm_lib.LMSamplingResult(
         [
@@ -174,7 +204,8 @@ class LanguageModelTest(unittest.TestCase):
                     'bar' * 2,
                     score=0.5,
                     logprobs=None,
-
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=0.5,
@@ -182,7 +213,8 @@ class LanguageModelTest(unittest.TestCase):
             ),
         ],
         usage=lm_lib.LMSamplingUsage(
-            prompt_tokens=100, completion_tokens=100, total_tokens=200
+            prompt_tokens=100, completion_tokens=100, total_tokens=200,
+            num_requests=1, estimated_cost=1.0,
         ),
     ),
 ]
@@ -198,14 +230,15 @@ class LanguageModelTest(unittest.TestCase):
                     'foo',
                     score=1.0,
                     logprobs=None,
-
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=1.0,
                 logprobs=None,
             ),
         ],
-        usage=lm_lib.LMSamplingUsage(100, 100, 200),
+        usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     ),
     lm_lib.LMSamplingResult(
         [
@@ -214,7 +247,8 @@ class LanguageModelTest(unittest.TestCase):
                     'bar',
                     score=1.0,
                     logprobs=None,
-
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=1.0,
@@ -222,7 +256,8 @@ class LanguageModelTest(unittest.TestCase):
             ),
         ],
         usage=lm_lib.LMSamplingUsage(
-            prompt_tokens=100, completion_tokens=100, total_tokens=200
+            prompt_tokens=100, completion_tokens=100, total_tokens=200,
+            num_requests=1, estimated_cost=1.0,
         ),
     ),
 ]
@@ -237,14 +272,15 @@ class LanguageModelTest(unittest.TestCase):
                     'foo' * 2,
                     score=0.7,
                     logprobs=None,
-
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=0.7,
                 logprobs=None,
             ),
         ],
-        usage=lm_lib.LMSamplingUsage(100, 100, 200),
+        usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     ),
     lm_lib.LMSamplingResult(
         [
@@ -253,7 +289,8 @@ class LanguageModelTest(unittest.TestCase):
                     'bar' * 2,
                     score=0.7,
                     logprobs=None,
-
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=0.7,
@@ -261,7 +298,8 @@ class LanguageModelTest(unittest.TestCase):
             ),
         ],
         usage=lm_lib.LMSamplingUsage(
-            prompt_tokens=100, completion_tokens=100, total_tokens=200
+            prompt_tokens=100, completion_tokens=100, total_tokens=200,
+            num_requests=1, estimated_cost=1.0,
         ),
     ),
 ]
@@ -273,7 +311,9 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(response.text, 'foo')
     self.assertEqual(response.score, -1.0)
     self.assertIsNone(response.logprobs)
-    self.assertEqual(
+    self.assertEqual(
+        response.usage, lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0)
+    )
 
     # Test override sampling_options.
     self.assertEqual(
@@ -296,14 +336,17 @@ class LanguageModelTest(unittest.TestCase):
                     cache_seed=0,
                     score=-1.0,
                     logprobs=None,
-
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(
+                        100, 100, 200, 1, 1.0
+                    ),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=-1.0,
                 logprobs=None,
             )
         ],
-        usage=lm_lib.LMSamplingUsage(100, 100, 200),
+        usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     ),
     lm_lib.LMSamplingResult(
         [
@@ -313,14 +356,15 @@ class LanguageModelTest(unittest.TestCase):
                     cache_seed=0,
                     score=-1.0,
                     logprobs=None,
-
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=-1.0,
                 logprobs=None,
             )
         ],
-        usage=lm_lib.LMSamplingUsage(100, 100, 200),
+        usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     ),
     ],
 )
@@ -328,7 +372,9 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(cache.stats.num_hits, 0)
     self.assertEqual(cache.stats.num_updates, 2)
 
-
+    result = lm('foo')
+    self.assertEqual(result, 'foo')
+    self.assertTrue(result.metadata.is_cached)
     self.assertEqual(lm('bar'), 'bar')
     self.assertEqual(cache.stats.num_queries, 4)
     self.assertEqual(cache.stats.num_hits, 2)
@@ -350,14 +396,15 @@ class LanguageModelTest(unittest.TestCase):
                     cache_seed=0,
                     score=1.0,
                     logprobs=None,
-
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=1.0,
                 logprobs=None,
             )
         ],
-        usage=lm_lib.LMSamplingUsage(100, 100, 200),
+        usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     ),
     lm_lib.LMSamplingResult(
         [
@@ -367,14 +414,15 @@ class LanguageModelTest(unittest.TestCase):
                     cache_seed=0,
                     score=1.0,
                     logprobs=None,
-
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=1.0,
                 logprobs=None,
             )
         ],
-        usage=lm_lib.LMSamplingUsage(100, 100, 200),
+        usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     ),
     ],
 )
```
```diff
@@ -400,13 +448,50 @@ class LanguageModelTest(unittest.TestCase):
 
   def test_retry(self):
     lm = MockModel(
-        failures_before_attempt=1, top_k=1,
+        failures_before_attempt=1, top_k=1, max_attempts=2, retry_interval=1
     )
     with self.assertRaisesRegex(
         concurrent.RetryError, 'Calling .* failed after 1 attempts'
     ):
       lm('foo', max_attempts=1)
-
+
+    usage = lm_lib.LMSamplingUsage(
+        prompt_tokens=100,
+        completion_tokens=100,
+        total_tokens=200,
+        num_requests=1,
+        estimated_cost=1.0,
+        retry_stats=lm_lib.RetryStats(
+            num_occurences=1,
+            total_wait_interval=1,
+            errors={'ValueError': 1},
+        ),
+    )
+    out = lm.sample(['foo'])
+    self.assertEqual(
+        # lm.sample(['foo'], max_attempts=2),
+        out,
+        [
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            score=-1.0,
+                            logprobs=None,
+                            is_cached=False,
+                            usage=usage,
+                            tags=['lm-response'],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=usage,
+                is_cached=False,
+            )
+        ],
+    )
 
   def test_debug(self):
     class Image(modality.Modality):
```
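`test_retry` above also pins down where retry accounting now lives: each `LMSamplingUsage` carries a `RetryStats` value with an occurrence count, accumulated wait time, and a per-error histogram. A small sketch of reading it back, using only fields that appear in the diff (`num_occurences` is spelled as in the source):

```python
# Sketch grounded in test_retry above: first attempt fails, second succeeds.
lm = MockModel(
    failures_before_attempt=1, top_k=1, max_attempts=2, retry_interval=1
)
out = lm.sample(['foo'])

stats = out[0].usage.retry_stats
print(stats.num_occurences)  # 1 retried error
print(stats.errors)          # {'ValueError': 1}
```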
```diff
@@ -418,8 +503,9 @@ class LanguageModelTest(unittest.TestCase):
     with contextlib.redirect_stdout(string_io):
       self.assertEqual(
           lm(message_lib.UserMessage(
-              'hi
-          'hi
+              'hi <<[[image]]>>', image=Image()), debug=True),
+          'hi <<[[image]]>>'
+      )
 
     debug_info = string_io.getvalue()
     self.assertIn('[0] LM INFO', debug_info)
```
```diff
@@ -508,6 +594,17 @@ class LanguageModelTest(unittest.TestCase):
           ],
       )
 
+      self.assertEqual(
+          lm.score(
+              [message_lib.UserMessage('hi {{image}}', image=Image()),
+               message_lib.UserMessage('hi {{image}}', image=Image())],
+              ['1', '2'], debug=debug_mode),
+          [
+              lm_lib.LMScoringResult(score=-0.0),
+              lm_lib.LMScoringResult(score=-1.0),
+          ],
+      )
+
       debug_info = string_io.getvalue()
       expected_included = [
           debug_prints[f]
```
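Combined with the `_score` signature change in the first hunk (`prompt: message_lib.Message | list[message_lib.Message]`), this addition shows that `score()` now accepts a batch of prompts paired one-to-one with completions; the `test_score_with_unmatched_prompt_and_completions` case in the next hunk confirms that mismatched lengths raise `ValueError`. A sketch of the call shape, grounded in these tests:

```python
# Batched scoring: one prompt per completion (per the tests above, the mock
# scorer returns -0.0, -1.0, ... by completion index).
results = MockScoringModel().score(
    ['hi', 'hello'],
    ['1', '2'],
)
assert [r.score for r in results] == [-0.0, -1.0]

# Mismatched lengths are rejected:
# MockScoringModel().score(['hi'], ['1', '2', '3'])  -> ValueError
```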
```diff
@@ -528,10 +625,73 @@ class LanguageModelTest(unittest.TestCase):
       if debug_mode & lm_lib.LMDebugMode.PROMPT:
         self.assertIn('[0] MODALITY OBJECTS SENT TO LM', debug_info)
 
+  def test_score_with_unmatched_prompt_and_completions(self):
+    with self.assertRaises(ValueError):
+      MockScoringModel().score(['hi',], ['1', '2', '3'])
+
   def test_score_with_unsupported_model(self):
     with self.assertRaises(NotImplementedError):
       MockModel().score('hi', ['1', '2'])
 
+  def test_tokenize(self):
+    info_flag = lm_lib.LMDebugMode.INFO
+    prompt_flag = lm_lib.LMDebugMode.PROMPT
+    response_flag = lm_lib.LMDebugMode.RESPONSE
+    debug_prints = {
+        info_flag: 'LM INFO',
+        prompt_flag: 'PROMPT TO TOKENIZE',
+        response_flag: 'TOKENS RETURNED',
+    }
+    debug_modes = [
+        info_flag,
+        prompt_flag,
+        response_flag,
+        info_flag | prompt_flag,
+        info_flag | response_flag,
+        prompt_flag | response_flag,
+        info_flag | prompt_flag | response_flag,
+    ]
+
+    class Image(modality.Modality):
+      def to_bytes(self):
+        return b'fake_image'
+
+    for debug_mode in debug_modes:
+      string_io = io.StringIO()
+      lm = MockTokenizeModel()
+
+      with contextlib.redirect_stdout(string_io):
+        self.assertEqual(
+            lm.tokenize(
+                message_lib.UserMessage('hi <<[[image]]>>', image=Image()),
+                debug=debug_mode),
+            [('hi', 0), ('<<[[image]]>>', 1)],
+        )
+
+      debug_info = string_io.getvalue()
+      expected_included = [
+          debug_prints[f]
+          for f in lm_lib.LMDebugMode
+          if f != lm_lib.LMDebugMode.NONE and f in debug_mode
+      ]
+      expected_excluded = [
+          debug_prints[f]
+          for f in lm_lib.LMDebugMode
+          if f != lm_lib.LMDebugMode.NONE and f not in debug_mode
+      ]
+
+      for expected_include in expected_included:
+        self.assertIn(expected_include, debug_info)
+      for expected_exclude in expected_excluded:
+        self.assertNotIn(expected_exclude, debug_info)
+
+      if debug_mode & lm_lib.LMDebugMode.PROMPT:
+        self.assertIn('[0] MODALITY OBJECTS SENT TO LM', debug_info)
+
+  def test_tokenize_with_unsupported_model(self):
+    with self.assertRaises(NotImplementedError):
+      MockModel().tokenize('hi')
+
   def test_rate_to_max_concurrency_no_rpm_no_tpm(self) -> None:
     lm = MockModel()
     self.assertEqual(
```
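The tokenization tests added above define the new public surface: `LanguageModel.tokenize()` returns `(token, id)` pairs and raises `NotImplementedError` when the subclass does not implement `_tokenize`. A quick usage sketch, reusing the `MockTokenizeModel` defined earlier in this diff:

```python
lm = MockTokenizeModel()
tokens = lm.tokenize(message_lib.UserMessage('hello world'))
assert tokens == [('hello', 0), ('world', 1)]  # whitespace split, per _tokenize

# Models without a _tokenize override reject the call.
try:
  MockModel().tokenize('hi')
except NotImplementedError:
  pass
```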
```diff
@@ -564,6 +724,260 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(lm.rate_to_max_concurrency(requests_per_min=1), 1)
     self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)
 
+  def test_track_usages(self):
+    lm = MockModel(name='model1')
+    lm2 = MockModel(name='model2')
+    with lm_lib.track_usages() as usages1:
+      _ = lm('hi')
+      with lm_lib.track_usages(lm2) as usages2:
+        with lm_lib.track_usages('model1') as usages3:
+          with lm_lib.track_usages('model1', lm2) as usages4:
+            def call_lm(prompt):
+              _ = lm.sample([prompt] * 2)
+            lm2('hi')
+            list(concurrent.concurrent_map(call_lm, ['hi', 'hello']))
+
+    self.assertEqual(usages2.uncached.breakdown, {
+        'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+    })
+    self.assertFalse(usages2.cached)
+    self.assertEqual(usages3.uncached.breakdown, {
+        'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4, 4, 4.0),
+    })
+    self.assertFalse(usages3.cached)
+    self.assertEqual(usages4.uncached.breakdown, {
+        'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4, 4, 4.0),
+        'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+    })
+    self.assertFalse(usages4.cached)
+    self.assertEqual(usages1.uncached.breakdown, {
+        'model1': lm_lib.LMSamplingUsage(100 * 5, 100 * 5, 200 * 5, 5, 5.0),
+        'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+    })
+    self.assertFalse(usages1.cached)
+    self.assertEqual(
+        usages1.total,
+        lm_lib.LMSamplingUsage(100 * 6, 100 * 6, 200 * 6, 6, 6.0),
+    )
+
+    cache = in_memory.InMemory()
+    lm = MockModel(cache=cache, name='model1')
+    with lm_lib.track_usages() as usages1:
+      _ = lm('hi')
+    self.assertEqual(usages1.uncached.breakdown, {
+        'model1': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+    })
+    self.assertFalse(usages1.cached)
+    with lm_lib.track_usages() as usages2:
+      _ = lm('hi')
+    self.assertEqual(usages2.cached.breakdown, {
+        'model1': lm_lib.LMSamplingUsage(100, 100, 200, 1, 0.0),
+    })
+    self.assertFalse(usages2.uncached)
+
+
+class LMSamplingUsageTest(unittest.TestCase):
+
+  def test_basics(self):
+    usage = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
+    self.assertEqual(usage.num_requests, 4)
+    self.assertEqual(usage.prompt_tokens, 100)
+    self.assertEqual(usage.completion_tokens, 200)
+    self.assertEqual(usage.total_tokens, 300)
+    self.assertEqual(usage.estimated_cost, 5.0)
+    self.assertEqual(usage.average_prompt_tokens, 25)
+    self.assertEqual(usage.average_completion_tokens, 50)
+    self.assertEqual(usage.average_total_tokens, 75)
+    self.assertEqual(usage.average_estimated_cost, 1.25)
+
+  def test_add(self):
+    usage1 = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
+    usage1.rebind(retry_stats=lm_lib.RetryStats(1, 3, 4, {'e1': 1}))
+    usage2 = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
+    self.assertEqual(usage1 + usage2, usage1 + usage2)
+    self.assertIs(usage1 + None, usage1)
+    self.assertIs(None + usage1, usage1)
+    usage3 = lm_lib.LMSamplingUsage(100, 200, 300, 4, None)
+    usage3.rebind(retry_stats=lm_lib.RetryStats(2, 4, 5, {'e1': 2, 'e2': 3}))
+    self.assertEqual(
+        usage1 + usage3,
+        lm_lib.LMSamplingUsage(
+            200,
+            400,
+            600,
+            8,
+            5.0,
+            retry_stats=lm_lib.RetryStats(3, 7, 9, {'e1': 3, 'e2': 3}),
+        ),
+    )
+    self.assertEqual(
+        usage3 + usage1,
+        lm_lib.LMSamplingUsage(
+            200,
+            400,
+            600,
+            8,
+            5.0,
+            retry_stats=lm_lib.RetryStats(3, 7, 9, {'e1': 3, 'e2': 3}),
+        ),
+    )
+
+  def test_usage_not_available(self):
+    usage_not_available = lm_lib.UsageNotAvailable()
+    self.assertEqual(usage_not_available.prompt_tokens, 0)
+    self.assertEqual(usage_not_available.completion_tokens, 0)
+    self.assertEqual(usage_not_available.total_tokens, 0)
+    self.assertEqual(usage_not_available.average_prompt_tokens, 0)
+    self.assertEqual(usage_not_available.average_completion_tokens, 0)
+    self.assertEqual(usage_not_available.average_total_tokens, 0)
+    self.assertIsNone(usage_not_available.average_estimated_cost)
+    self.assertTrue(usage_not_available)
+    self.assertEqual(
+        usage_not_available + lm_lib.LMSamplingUsage(1, 2, 3, 4, 5.0),
+        lm_lib.UsageNotAvailable(num_requests=5)
+    )
+    self.assertEqual(
+        lm_lib.LMSamplingUsage(1, 2, 3, 4, 5.0) + usage_not_available,
+        lm_lib.UsageNotAvailable(num_requests=5)
+    )
+    self.assertIs(None + usage_not_available, usage_not_available)
+    self.assertIs(usage_not_available + None, usage_not_available)
+
+
+class UsageSummaryTest(unittest.TestCase):
+
+  def test_basics(self):
+    usage_summary = lm_lib.UsageSummary()
+    self.assertFalse(usage_summary.total)
+    self.assertFalse(usage_summary.cached)
+    self.assertFalse(usage_summary.uncached)
+
+    # Add uncached.
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    self.assertEqual(
+        usage_summary.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0)
+    )
+    self.assertEqual(
+        usage_summary.uncached.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0)
+    )
+    # Add cached.
+    self.assertFalse(usage_summary.cached)
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), True
+    )
+    self.assertEqual(
+        usage_summary.total, lm_lib.LMSamplingUsage(2, 4, 6, 2, 5.0)
+    )
+    self.assertEqual(
+        usage_summary.cached.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 0.0)
+    )
+    # Add UsageNotAvailable.
+    usage_summary.add(
+        'model1', lm_lib.UsageNotAvailable(num_requests=1), False
+    )
+    self.assertEqual(
+        usage_summary.total, lm_lib.UsageNotAvailable(num_requests=3)
+    )
+    self.assertEqual(
+        usage_summary.uncached.total, lm_lib.UsageNotAvailable(num_requests=2)
+    )
+
+  def test_merge(self):
+    usage_summary = lm_lib.UsageSummary()
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    usage_summary.add(
+        'model2', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    usage_summary2 = lm_lib.UsageSummary()
+    usage_summary2.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    usage_summary2.add(
+        'model3', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    usage_summary2.merge(usage_summary)
+    self.assertEqual(
+        usage_summary2,
+        lm_lib.UsageSummary(
+            cached=lm_lib.UsageSummary.AggregatedUsage(
+                total=lm_lib.LMSamplingUsage(
+                    prompt_tokens=0,
+                    completion_tokens=0,
+                    total_tokens=0,
+                    num_requests=0,
+                    estimated_cost=0.0,
+                ),
+                breakdown={}
+            ),
+            uncached=lm_lib.UsageSummary.AggregatedUsage(
+                total=lm_lib.LMSamplingUsage(
+                    prompt_tokens=5,
+                    completion_tokens=10,
+                    total_tokens=15,
+                    num_requests=5,
+                    estimated_cost=25.0
+                ),
+                breakdown=dict(
+                    model1=lm_lib.LMSamplingUsage(
+                        prompt_tokens=3,
+                        completion_tokens=6,
+                        total_tokens=9,
+                        num_requests=3,
+                        estimated_cost=15.0
+                    ),
+                    model3=lm_lib.LMSamplingUsage(
+                        prompt_tokens=1,
+                        completion_tokens=2,
+                        total_tokens=3,
+                        num_requests=1,
+                        estimated_cost=5.0
+                    ),
+                    model2=lm_lib.LMSamplingUsage(
+                        prompt_tokens=1,
+                        completion_tokens=2,
+                        total_tokens=3,
+                        num_requests=1,
+                        estimated_cost=5.0
+                    )
+                )
+            )
+        )
+    )
+
+  def test_html_view(self):
+    usage_summary = lm_lib.UsageSummary()
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    self.assertIn(
+        '5.000',
+        usage_summary.to_html(extra_flags=dict(as_badge=True)).content
+    )
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    self.assertIn(
+        '10.000',
+        usage_summary.to_html(
+            extra_flags=dict(as_badge=True, interactive=True)
+        ).content
+    )
+    self.assertTrue(
+        usage_summary.to_html().content.startswith('<details open')
+    )
+    with pg.views.html.controls.HtmlControl.track_scripts() as scripts:
+      usage_summary.add(
+          'model2', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+      )
+    self.assertEqual(len(scripts), 4)
+
 
 if __name__ == '__main__':
   unittest.main()
```
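For readers skimming this last hunk, the new usage-tracking API exercised by `test_track_usages` reduces to a small pattern: trackers can be opened for all models, for specific model ids, or for model instances, and they nest; results are split into `cached` and `uncached` aggregates with per-model breakdowns. A minimal sketch using the same `lm_lib` alias and the `MockModel` from this file:

```python
lm = MockModel(name='model1')

with lm_lib.track_usages() as all_usage:           # track every model
  with lm_lib.track_usages('model1') as m1_usage:  # or filter by model id
    lm('hi')

# Per-model breakdown of uncached calls (MockModel reports 100/100/200
# tokens and an estimated cost of 1.0 per request).
print(m1_usage.uncached.breakdown['model1'].total_tokens)  # 200
print(all_usage.total.estimated_cost)                      # 1.0
```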