langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. langfun/__init__.py +20 -2
  2. langfun/core/__init__.py +16 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -21
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +63 -2
  18. langfun/core/component_test.py +53 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +16 -1
  24. langfun/core/eval/base.py +622 -174
  25. langfun/core/eval/base_test.py +200 -54
  26. langfun/core/eval/matching.py +63 -76
  27. langfun/core/eval/matching_test.py +17 -8
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +26 -26
  31. langfun/core/eval/scoring_test.py +19 -2
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +4 -17
  55. langfun/core/langfunc_test.py +22 -6
  56. langfun/core/language_model.py +577 -39
  57. langfun/core/language_model_test.py +470 -56
  58. langfun/core/llms/__init__.py +87 -16
  59. langfun/core/llms/anthropic.py +312 -87
  60. langfun/core/llms/anthropic_test.py +71 -3
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +53 -2
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +11 -7
  69. langfun/core/llms/fake_test.py +14 -0
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -202
  74. langfun/core/llms/groq.py +160 -144
  75. langfun/core/llms/groq_test.py +31 -137
  76. langfun/core/llms/llama_cpp.py +15 -42
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +395 -203
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +30 -395
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -26
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +12 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +7 -6
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +60 -27
  112. langfun/core/structured/function_generation_test.py +72 -2
  113. langfun/core/structured/mapping.py +97 -47
  114. langfun/core/structured/mapping_test.py +90 -2
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +53 -9
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
  119. langfun/core/structured/schema.py +204 -97
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_test.py +130 -29
  122. langfun/core/structured/scoring.py +125 -19
  123. langfun/core/structured/scoring_test.py +30 -0
  124. langfun/core/structured/tokenization.py +64 -0
  125. langfun/core/structured/tokenization_test.py +48 -0
  126. langfun/core/template.py +115 -1
  127. langfun/core/template_test.py +71 -1
  128. langfun/core/templates/conversation.py +9 -0
  129. langfun/core/templates/conversation_test.py +4 -3
  130. langfun/core/templates/selfplay_test.py +10 -2
  131. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  132. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  133. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  134. langfun/core/coding/python/errors.py +0 -108
  135. langfun/core/coding/python/errors_test.py +0 -99
  136. langfun/core/coding/python/permissions.py +0 -90
  137. langfun/core/coding/python/permissions_test.py +0 -86
  138. langfun/core/structured/prompting.py +0 -238
  139. langfun/core/text_formatting.py +0 -162
  140. langfun/core/text_formatting_test.py +0 -47
  141. langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
  142. langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
  143. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  144. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
langfun/core/language_model_test.py
@@ -27,48 +27,53 @@ import pyglove as pg
 @pg.use_init_args(['failures_before_attempt'])
 class MockModel(lm_lib.LanguageModel):
   """A mock model that echo back user prompts."""
-
   failures_before_attempt: int = 0
+  name: str = 'MockModel'

   def _sample(self,
               prompts: list[message_lib.Message]
               ) -> list[lm_lib.LMSamplingResult]:
     context = pg.Dict(attempt=0)

-    def fake_sample(prompts):
+    def fake_sample(prompt):
       if context.attempt >= self.failures_before_attempt:
-        return [
-            lm_lib.LMSamplingResult(
-                [
-                    lm_lib.LMSample(  # pylint: disable=g-complex-comprehension
-                        response=prompt.text * self.sampling_options.top_k,
-                        score=self.sampling_options.temperature or -1.0,
-                    )
-                ],
-                usage=lm_lib.LMSamplingUsage(
-                    prompt_tokens=100,
-                    completion_tokens=100,
-                    total_tokens=200,
-                ),
-            )
-            for prompt in prompts
-        ]
-      context.attempt += 1
+        return lm_lib.LMSamplingResult(
+            [
+                lm_lib.LMSample(  # pylint: disable=g-complex-comprehension
+                    response=prompt.text * self.sampling_options.top_k,
+                    score=self.sampling_options.temperature or -1.0,
+                )
+            ],
+            usage=lm_lib.LMSamplingUsage(
+                prompt_tokens=100,
+                completion_tokens=100,
+                total_tokens=200,
+                estimated_cost=1.0,
+            ),
+        )
+      else:
+        context.attempt += 1
       raise ValueError('Failed to sample prompts.')

-    return concurrent.with_retry(
-        fake_sample,
-        retry_on_errors=ValueError,
-        max_attempts=self.max_attempts,
-        retry_interval=1,
-    )(prompts)
+    results = self._parallel_execute_with_currency_control(
+        fake_sample, prompts, retry_on_errors=ValueError
+    )
+    for result in results:
+      result.usage.retry_stats.rebind(
+          total_call_interval=0, skip_notification=True
+      )
+    return results
+
+  @property
+  def model_id(self) -> str:
+    return self.name


 class MockScoringModel(MockModel):

   def _score(
       self,
-      prompt: message_lib.Message,
+      prompt: message_lib.Message | list[message_lib.Message],
       completions: list[message_lib.Message],
       **kwargs
   ) -> list[lm_lib.LMScoringResult]:
@@ -77,6 +82,13 @@ class MockScoringModel(MockModel):
     ]


+class MockTokenizeModel(MockModel):
+
+  def _tokenize(
+      self, prompt: message_lib.Message) -> list[tuple[str | bytes, int]]:
+    return [(w, i) for i, w in enumerate(prompt.text.split(' '))]
+
+
 class LMSamplingOptionsTest(unittest.TestCase):
   """Tests for LMSamplingOptions."""

@@ -105,6 +117,21 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(lm.sampling_options.top_k, 2)
     self.assertEqual(lm.max_attempts, 2)

+  def test_subclassing(self):
+
+    class ChildModel(lm_lib.LanguageModel):
+
+      sampling_options = lm_lib.LMSamplingOptions(
+          temperature=0.5, top_k=20
+      )
+
+      def _sample(self, *args, **kwargs):
+        pass
+
+    lm = ChildModel(top_k=10)
+    self.assertEqual(lm.sampling_options.temperature, 0.5)
+    self.assertEqual(lm.sampling_options.top_k, 10)
+
   def test_sample(self):
     lm = MockModel(top_k=1)
     self.assertEqual(
@@ -117,14 +144,15 @@ class LanguageModelTest(unittest.TestCase):
                             'foo',
                             score=-1.0,
                             logprobs=None,
-                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                             tags=[message_lib.Message.TAG_LM_RESPONSE],
                         ),
                         score=-1.0,
                         logprobs=None,
                     )
                 ],
-                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
             ),
             lm_lib.LMSamplingResult(
                 [
@@ -133,14 +161,15 @@ class LanguageModelTest(unittest.TestCase):
                             'bar',
                             score=-1.0,
                             logprobs=None,
-                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                             tags=[message_lib.Message.TAG_LM_RESPONSE],
                         ),
                         score=-1.0,
                         logprobs=None,
                     )
                 ],
-                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
             ),
         ],
     )
@@ -158,14 +187,15 @@ class LanguageModelTest(unittest.TestCase):
                             'foo' * 2,
                             score=0.5,
                             logprobs=None,
-                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                             tags=[message_lib.Message.TAG_LM_RESPONSE],
                         ),
                         score=0.5,
                         logprobs=None,
                     ),
                 ],
-                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
             ),
             lm_lib.LMSamplingResult(
                 [
@@ -174,7 +204,8 @@ class LanguageModelTest(unittest.TestCase):
                             'bar' * 2,
                             score=0.5,
                             logprobs=None,
-                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                             tags=[message_lib.Message.TAG_LM_RESPONSE],
                         ),
                         score=0.5,
@@ -182,7 +213,8 @@ class LanguageModelTest(unittest.TestCase):
                     ),
                 ],
                 usage=lm_lib.LMSamplingUsage(
-                    prompt_tokens=100, completion_tokens=100, total_tokens=200
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200,
+                    num_requests=1, estimated_cost=1.0,
                 ),
             ),
         ]
@@ -198,14 +230,15 @@ class LanguageModelTest(unittest.TestCase):
                             'foo',
                             score=1.0,
                             logprobs=None,
-                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                             tags=[message_lib.Message.TAG_LM_RESPONSE],
                         ),
                         score=1.0,
                         logprobs=None,
                     ),
                 ],
-                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
             ),
             lm_lib.LMSamplingResult(
                 [
@@ -214,7 +247,8 @@ class LanguageModelTest(unittest.TestCase):
                             'bar',
                             score=1.0,
                             logprobs=None,
-                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                             tags=[message_lib.Message.TAG_LM_RESPONSE],
                         ),
                         score=1.0,
@@ -222,7 +256,8 @@ class LanguageModelTest(unittest.TestCase):
                     ),
                 ],
                 usage=lm_lib.LMSamplingUsage(
-                    prompt_tokens=100, completion_tokens=100, total_tokens=200
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200,
+                    num_requests=1, estimated_cost=1.0,
                 ),
             ),
         ]
@@ -237,14 +272,15 @@ class LanguageModelTest(unittest.TestCase):
                             'foo' * 2,
                             score=0.7,
                             logprobs=None,
-                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                             tags=[message_lib.Message.TAG_LM_RESPONSE],
                         ),
                         score=0.7,
                         logprobs=None,
                     ),
                 ],
-                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
             ),
             lm_lib.LMSamplingResult(
                 [
@@ -253,7 +289,8 @@ class LanguageModelTest(unittest.TestCase):
                             'bar' * 2,
                             score=0.7,
                             logprobs=None,
-                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                             tags=[message_lib.Message.TAG_LM_RESPONSE],
                         ),
                         score=0.7,
@@ -261,7 +298,8 @@ class LanguageModelTest(unittest.TestCase):
                     ),
                 ],
                 usage=lm_lib.LMSamplingUsage(
-                    prompt_tokens=100, completion_tokens=100, total_tokens=200
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200,
+                    num_requests=1, estimated_cost=1.0,
                 ),
             ),
         ]
@@ -273,7 +311,9 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(response.text, 'foo')
     self.assertEqual(response.score, -1.0)
     self.assertIsNone(response.logprobs)
-    self.assertEqual(response.usage, lm_lib.LMSamplingUsage(100, 100, 200))
+    self.assertEqual(
+        response.usage, lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0)
+    )

     # Test override sampling_options.
     self.assertEqual(
@@ -296,14 +336,17 @@ class LanguageModelTest(unittest.TestCase):
                             cache_seed=0,
                             score=-1.0,
                             logprobs=None,
-                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(
+                                100, 100, 200, 1, 1.0
+                            ),
                             tags=[message_lib.Message.TAG_LM_RESPONSE],
                         ),
                         score=-1.0,
                         logprobs=None,
                     )
                 ],
-                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
             ),
             lm_lib.LMSamplingResult(
                 [
@@ -313,14 +356,15 @@ class LanguageModelTest(unittest.TestCase):
                             cache_seed=0,
                             score=-1.0,
                             logprobs=None,
-                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                             tags=[message_lib.Message.TAG_LM_RESPONSE],
                         ),
                         score=-1.0,
                         logprobs=None,
                     )
                 ],
-                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
             ),
         ],
     )
@@ -328,7 +372,9 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(cache.stats.num_hits, 0)
     self.assertEqual(cache.stats.num_updates, 2)

-    self.assertEqual(lm('foo'), 'foo')
+    result = lm('foo')
+    self.assertEqual(result, 'foo')
+    self.assertTrue(result.metadata.is_cached)
     self.assertEqual(lm('bar'), 'bar')
     self.assertEqual(cache.stats.num_queries, 4)
     self.assertEqual(cache.stats.num_hits, 2)
@@ -350,14 +396,15 @@ class LanguageModelTest(unittest.TestCase):
                             cache_seed=0,
                             score=1.0,
                             logprobs=None,
-                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                             tags=[message_lib.Message.TAG_LM_RESPONSE],
                         ),
                         score=1.0,
                         logprobs=None,
                     )
                 ],
-                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
             ),
             lm_lib.LMSamplingResult(
                 [
@@ -367,14 +414,15 @@ class LanguageModelTest(unittest.TestCase):
                             cache_seed=0,
                             score=1.0,
                             logprobs=None,
-                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            is_cached=False,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                             tags=[message_lib.Message.TAG_LM_RESPONSE],
                         ),
                         score=1.0,
                         logprobs=None,
                     )
                 ],
-                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
             ),
         ],
     )
@@ -400,13 +448,50 @@ class LanguageModelTest(unittest.TestCase):

   def test_retry(self):
     lm = MockModel(
-        failures_before_attempt=1, top_k=1,
+        failures_before_attempt=1, top_k=1, max_attempts=2, retry_interval=1
     )
     with self.assertRaisesRegex(
         concurrent.RetryError, 'Calling .* failed after 1 attempts'
     ):
       lm('foo', max_attempts=1)
-    self.assertEqual(lm('foo', max_attempts=2), 'foo')
+
+    usage = lm_lib.LMSamplingUsage(
+        prompt_tokens=100,
+        completion_tokens=100,
+        total_tokens=200,
+        num_requests=1,
+        estimated_cost=1.0,
+        retry_stats=lm_lib.RetryStats(
+            num_occurences=1,
+            total_wait_interval=1,
+            errors={'ValueError': 1},
+        ),
+    )
+    out = lm.sample(['foo'])
+    self.assertEqual(
+        # lm.sample(['foo'], max_attempts=2),
+        out,
+        [
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            score=-1.0,
+                            logprobs=None,
+                            is_cached=False,
+                            usage=usage,
+                            tags=['lm-response'],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=usage,
+                is_cached=False,
+            )
+        ],
+    )

   def test_debug(self):
     class Image(modality.Modality):
@@ -418,8 +503,9 @@ class LanguageModelTest(unittest.TestCase):
     with contextlib.redirect_stdout(string_io):
       self.assertEqual(
           lm(message_lib.UserMessage(
-              'hi {{image}}', image=Image()), debug=True),
-          'hi {{image}}')
+              'hi <<[[image]]>>', image=Image()), debug=True),
+          'hi <<[[image]]>>'
+      )

     debug_info = string_io.getvalue()
     self.assertIn('[0] LM INFO', debug_info)
@@ -508,6 +594,17 @@ class LanguageModelTest(unittest.TestCase):
             ],
         )

+        self.assertEqual(
+            lm.score(
+                [message_lib.UserMessage('hi {{image}}', image=Image()),
+                 message_lib.UserMessage('hi {{image}}', image=Image())],
+                ['1', '2'], debug=debug_mode),
+            [
+                lm_lib.LMScoringResult(score=-0.0),
+                lm_lib.LMScoringResult(score=-1.0),
+            ],
+        )
+
       debug_info = string_io.getvalue()
       expected_included = [
           debug_prints[f]
@@ -528,10 +625,73 @@ class LanguageModelTest(unittest.TestCase):
       if debug_mode & lm_lib.LMDebugMode.PROMPT:
         self.assertIn('[0] MODALITY OBJECTS SENT TO LM', debug_info)

+  def test_score_with_unmatched_prompt_and_completions(self):
+    with self.assertRaises(ValueError):
+      MockScoringModel().score(['hi',], ['1', '2', '3'])
+
   def test_score_with_unsupported_model(self):
     with self.assertRaises(NotImplementedError):
       MockModel().score('hi', ['1', '2'])

+  def test_tokenize(self):
+    info_flag = lm_lib.LMDebugMode.INFO
+    prompt_flag = lm_lib.LMDebugMode.PROMPT
+    response_flag = lm_lib.LMDebugMode.RESPONSE
+    debug_prints = {
+        info_flag: 'LM INFO',
+        prompt_flag: 'PROMPT TO TOKENIZE',
+        response_flag: 'TOKENS RETURNED',
+    }
+    debug_modes = [
+        info_flag,
+        prompt_flag,
+        response_flag,
+        info_flag | prompt_flag,
+        info_flag | response_flag,
+        prompt_flag | response_flag,
+        info_flag | prompt_flag | response_flag,
+    ]
+
+    class Image(modality.Modality):
+      def to_bytes(self):
+        return b'fake_image'
+
+    for debug_mode in debug_modes:
+      string_io = io.StringIO()
+      lm = MockTokenizeModel()
+
+      with contextlib.redirect_stdout(string_io):
+        self.assertEqual(
+            lm.tokenize(
+                message_lib.UserMessage('hi <<[[image]]>>', image=Image()),
+                debug=debug_mode),
+            [('hi', 0), ('<<[[image]]>>', 1)],
+        )
+
+      debug_info = string_io.getvalue()
+      expected_included = [
+          debug_prints[f]
+          for f in lm_lib.LMDebugMode
+          if f != lm_lib.LMDebugMode.NONE and f in debug_mode
+      ]
+      expected_excluded = [
+          debug_prints[f]
+          for f in lm_lib.LMDebugMode
+          if f != lm_lib.LMDebugMode.NONE and f not in debug_mode
+      ]
+
+      for expected_include in expected_included:
+        self.assertIn(expected_include, debug_info)
+      for expected_exclude in expected_excluded:
+        self.assertNotIn(expected_exclude, debug_info)
+
+      if debug_mode & lm_lib.LMDebugMode.PROMPT:
+        self.assertIn('[0] MODALITY OBJECTS SENT TO LM', debug_info)
+
+  def test_tokenize_with_unsupported_model(self):
+    with self.assertRaises(NotImplementedError):
+      MockModel().tokenize('hi')
+
   def test_rate_to_max_concurrency_no_rpm_no_tpm(self) -> None:
     lm = MockModel()
     self.assertEqual(
@@ -564,6 +724,260 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(lm.rate_to_max_concurrency(requests_per_min=1), 1)
     self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)

+  def test_track_usages(self):
+    lm = MockModel(name='model1')
+    lm2 = MockModel(name='model2')
+    with lm_lib.track_usages() as usages1:
+      _ = lm('hi')
+      with lm_lib.track_usages(lm2) as usages2:
+        with lm_lib.track_usages('model1') as usages3:
+          with lm_lib.track_usages('model1', lm2) as usages4:
+            def call_lm(prompt):
+              _ = lm.sample([prompt] * 2)
+            lm2('hi')
+            list(concurrent.concurrent_map(call_lm, ['hi', 'hello']))
+
+    self.assertEqual(usages2.uncached.breakdown, {
+        'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+    })
+    self.assertFalse(usages2.cached)
+    self.assertEqual(usages3.uncached.breakdown, {
+        'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4, 4, 4.0),
+    })
+    self.assertFalse(usages3.cached)
+    self.assertEqual(usages4.uncached.breakdown, {
+        'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4, 4, 4.0),
+        'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+    })
+    self.assertFalse(usages4.cached)
+    self.assertEqual(usages1.uncached.breakdown, {
+        'model1': lm_lib.LMSamplingUsage(100 * 5, 100 * 5, 200 * 5, 5, 5.0),
+        'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+    })
+    self.assertFalse(usages1.cached)
+    self.assertEqual(
+        usages1.total,
+        lm_lib.LMSamplingUsage(100 * 6, 100 * 6, 200 * 6, 6, 6.0),
+    )
+
+    cache = in_memory.InMemory()
+    lm = MockModel(cache=cache, name='model1')
+    with lm_lib.track_usages() as usages1:
+      _ = lm('hi')
+    self.assertEqual(usages1.uncached.breakdown, {
+        'model1': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+    })
+    self.assertFalse(usages1.cached)
+    with lm_lib.track_usages() as usages2:
+      _ = lm('hi')
+    self.assertEqual(usages2.cached.breakdown, {
+        'model1': lm_lib.LMSamplingUsage(100, 100, 200, 1, 0.0),
+    })
+    self.assertFalse(usages2.uncached)
+
+
+class LMSamplingUsageTest(unittest.TestCase):
+
+  def test_basics(self):
+    usage = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
+    self.assertEqual(usage.num_requests, 4)
+    self.assertEqual(usage.prompt_tokens, 100)
+    self.assertEqual(usage.completion_tokens, 200)
+    self.assertEqual(usage.total_tokens, 300)
+    self.assertEqual(usage.estimated_cost, 5.0)
+    self.assertEqual(usage.average_prompt_tokens, 25)
+    self.assertEqual(usage.average_completion_tokens, 50)
+    self.assertEqual(usage.average_total_tokens, 75)
+    self.assertEqual(usage.average_estimated_cost, 1.25)
+
+  def test_add(self):
+    usage1 = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
+    usage1.rebind(retry_stats=lm_lib.RetryStats(1, 3, 4, {'e1': 1}))
+    usage2 = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
+    self.assertEqual(usage1 + usage2, usage1 + usage2)
+    self.assertIs(usage1 + None, usage1)
+    self.assertIs(None + usage1, usage1)
+    usage3 = lm_lib.LMSamplingUsage(100, 200, 300, 4, None)
+    usage3.rebind(retry_stats=lm_lib.RetryStats(2, 4, 5, {'e1': 2, 'e2': 3}))
+    self.assertEqual(
+        usage1 + usage3,
+        lm_lib.LMSamplingUsage(
+            200,
+            400,
+            600,
+            8,
+            5.0,
+            retry_stats=lm_lib.RetryStats(3, 7, 9, {'e1': 3, 'e2': 3}),
+        ),
+    )
+    self.assertEqual(
+        usage3 + usage1,
+        lm_lib.LMSamplingUsage(
+            200,
+            400,
+            600,
+            8,
+            5.0,
+            retry_stats=lm_lib.RetryStats(3, 7, 9, {'e1': 3, 'e2': 3}),
+        ),
+    )
+
+  def test_usage_not_available(self):
+    usage_not_available = lm_lib.UsageNotAvailable()
+    self.assertEqual(usage_not_available.prompt_tokens, 0)
+    self.assertEqual(usage_not_available.completion_tokens, 0)
+    self.assertEqual(usage_not_available.total_tokens, 0)
+    self.assertEqual(usage_not_available.average_prompt_tokens, 0)
+    self.assertEqual(usage_not_available.average_completion_tokens, 0)
+    self.assertEqual(usage_not_available.average_total_tokens, 0)
+    self.assertIsNone(usage_not_available.average_estimated_cost)
+    self.assertTrue(usage_not_available)
+    self.assertEqual(
+        usage_not_available + lm_lib.LMSamplingUsage(1, 2, 3, 4, 5.0),
+        lm_lib.UsageNotAvailable(num_requests=5)
+    )
+    self.assertEqual(
+        lm_lib.LMSamplingUsage(1, 2, 3, 4, 5.0) + usage_not_available,
+        lm_lib.UsageNotAvailable(num_requests=5)
+    )
+    self.assertIs(None + usage_not_available, usage_not_available)
+    self.assertIs(usage_not_available + None, usage_not_available)
+
+
+class UsageSummaryTest(unittest.TestCase):
+
+  def test_basics(self):
+    usage_summary = lm_lib.UsageSummary()
+    self.assertFalse(usage_summary.total)
+    self.assertFalse(usage_summary.cached)
+    self.assertFalse(usage_summary.uncached)
+
+    # Add uncached.
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    self.assertEqual(
+        usage_summary.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0)
+    )
+    self.assertEqual(
+        usage_summary.uncached.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0)
+    )
+    # Add cached.
+    self.assertFalse(usage_summary.cached)
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), True
+    )
+    self.assertEqual(
+        usage_summary.total, lm_lib.LMSamplingUsage(2, 4, 6, 2, 5.0)
+    )
+    self.assertEqual(
+        usage_summary.cached.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 0.0)
+    )
+    # Add UsageNotAvailable.
+    usage_summary.add(
+        'model1', lm_lib.UsageNotAvailable(num_requests=1), False
+    )
+    self.assertEqual(
+        usage_summary.total, lm_lib.UsageNotAvailable(num_requests=3)
+    )
+    self.assertEqual(
+        usage_summary.uncached.total, lm_lib.UsageNotAvailable(num_requests=2)
+    )
+
+  def test_merge(self):
+    usage_summary = lm_lib.UsageSummary()
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    usage_summary.add(
+        'model2', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    usage_summary2 = lm_lib.UsageSummary()
+    usage_summary2.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    usage_summary2.add(
+        'model3', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    usage_summary2.merge(usage_summary)
+    self.assertEqual(
+        usage_summary2,
+        lm_lib.UsageSummary(
+            cached=lm_lib.UsageSummary.AggregatedUsage(
+                total=lm_lib.LMSamplingUsage(
+                    prompt_tokens=0,
+                    completion_tokens=0,
+                    total_tokens=0,
+                    num_requests=0,
+                    estimated_cost=0.0,
+                ),
+                breakdown={}
+            ),
+            uncached=lm_lib.UsageSummary.AggregatedUsage(
+                total=lm_lib.LMSamplingUsage(
+                    prompt_tokens=5,
+                    completion_tokens=10,
+                    total_tokens=15,
+                    num_requests=5,
+                    estimated_cost=25.0
+                ),
+                breakdown=dict(
+                    model1=lm_lib.LMSamplingUsage(
+                        prompt_tokens=3,
+                        completion_tokens=6,
+                        total_tokens=9,
+                        num_requests=3,
+                        estimated_cost=15.0
+                    ),
+                    model3=lm_lib.LMSamplingUsage(
+                        prompt_tokens=1,
+                        completion_tokens=2,
+                        total_tokens=3,
+                        num_requests=1,
+                        estimated_cost=5.0
+                    ),
+                    model2=lm_lib.LMSamplingUsage(
+                        prompt_tokens=1,
+                        completion_tokens=2,
+                        total_tokens=3,
+                        num_requests=1,
+                        estimated_cost=5.0
+                    )
+                )
+            )
+        )
+    )
+
+  def test_html_view(self):
+    usage_summary = lm_lib.UsageSummary()
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    self.assertIn(
+        '5.000',
+        usage_summary.to_html(extra_flags=dict(as_badge=True)).content
+    )
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    self.assertIn(
+        '10.000',
+        usage_summary.to_html(
+            extra_flags=dict(as_badge=True, interactive=True)
+        ).content
+    )
+    self.assertTrue(
+        usage_summary.to_html().content.startswith('<details open')
+    )
+    with pg.views.html.controls.HtmlControl.track_scripts() as scripts:
+      usage_summary.add(
+          'model2', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+      )
+    self.assertEqual(len(scripts), 4)
+

 if __name__ == '__main__':
   unittest.main()
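
Note: the `test_track_usages` cases added above exercise a usage-tracking context manager that this release introduces in `langfun/core/language_model.py`. Below is a minimal sketch of calling it outside the test suite; it assumes `fake.StaticResponse` (see the `langfun/core/llms/fake.py` entry in the file list) as a stand-in model, and the names and result shapes follow the assertions in the diff rather than separate documentation:

    from langfun.core import language_model as lm_lib
    from langfun.core.llms import fake

    lm = fake.StaticResponse('hello')  # Canned model: always replies 'hello'.

    # Aggregate token counts, request counts and estimated cost for all
    # sampling calls made to `lm` while the context is active.
    with lm_lib.track_usages(lm) as usages:
      _ = lm('hi')

    print(usages.total)               # Combined LMSamplingUsage.
    print(usages.uncached.breakdown)  # Dict keyed by model_id, per the tests.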