langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (145)
  1. langfun/__init__.py +22 -2
  2. langfun/core/__init__.py +17 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -28
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +69 -2
  18. langfun/core/component_test.py +54 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +17 -0
  24. langfun/core/eval/base.py +767 -140
  25. langfun/core/eval/base_test.py +238 -53
  26. langfun/core/eval/matching.py +80 -76
  27. langfun/core/eval/matching_test.py +19 -9
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +37 -28
  31. langfun/core/eval/scoring_test.py +21 -3
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +3 -21
  55. langfun/core/langfunc_test.py +26 -8
  56. langfun/core/language_model.py +686 -48
  57. langfun/core/language_model_test.py +681 -44
  58. langfun/core/llms/__init__.py +100 -12
  59. langfun/core/llms/anthropic.py +488 -0
  60. langfun/core/llms/anthropic_test.py +235 -0
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +88 -28
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +39 -26
  69. langfun/core/llms/fake_test.py +136 -11
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -197
  74. langfun/core/llms/groq.py +276 -0
  75. langfun/core/llms/groq_test.py +64 -0
  76. langfun/core/llms/llama_cpp.py +15 -40
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +436 -226
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +35 -174
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -23
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +15 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +9 -8
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +278 -0
  112. langfun/core/structured/function_generation_test.py +399 -0
  113. langfun/core/structured/mapping.py +150 -46
  114. langfun/core/structured/mapping_test.py +105 -0
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +71 -22
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
  119. langfun/core/structured/schema.py +208 -99
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_generation_test.py +2 -2
  122. langfun/core/structured/schema_test.py +133 -34
  123. langfun/core/structured/scoring.py +125 -19
  124. langfun/core/structured/scoring_test.py +30 -0
  125. langfun/core/structured/tokenization.py +64 -0
  126. langfun/core/structured/tokenization_test.py +48 -0
  127. langfun/core/template.py +240 -11
  128. langfun/core/template_test.py +146 -1
  129. langfun/core/templates/conversation.py +9 -0
  130. langfun/core/templates/conversation_test.py +4 -3
  131. langfun/core/templates/selfplay_test.py +14 -2
  132. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  133. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  134. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  135. langfun/core/coding/python/errors.py +0 -108
  136. langfun/core/coding/python/errors_test.py +0 -99
  137. langfun/core/coding/python/permissions.py +0 -90
  138. langfun/core/coding/python/permissions_test.py +0 -86
  139. langfun/core/structured/prompting.py +0 -217
  140. langfun/core/text_formatting.py +0 -162
  141. langfun/core/text_formatting_test.py +0 -47
  142. langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
  143. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  144. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  145. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -27,38 +27,53 @@ import pyglove as pg
  @pg.use_init_args(['failures_before_attempt'])
  class MockModel(lm_lib.LanguageModel):
  """A mock model that echo back user prompts."""
-
  failures_before_attempt: int = 0
+ name: str = 'MockModel'

  def _sample(self,
  prompts: list[message_lib.Message]
  ) -> list[lm_lib.LMSamplingResult]:
  context = pg.Dict(attempt=0)

- def fake_sample(prompts):
+ def fake_sample(prompt):
  if context.attempt >= self.failures_before_attempt:
- return [
- lm_lib.LMSamplingResult([lm_lib.LMSample( # pylint: disable=g-complex-comprehension
- response=prompt.text * self.sampling_options.top_k,
- score=self.sampling_options.temperature)])
- for prompt in prompts
- ]
- context.attempt += 1
+ return lm_lib.LMSamplingResult(
+ [
+ lm_lib.LMSample( # pylint: disable=g-complex-comprehension
+ response=prompt.text * self.sampling_options.top_k,
+ score=self.sampling_options.temperature or -1.0,
+ )
+ ],
+ usage=lm_lib.LMSamplingUsage(
+ prompt_tokens=100,
+ completion_tokens=100,
+ total_tokens=200,
+ estimated_cost=1.0,
+ ),
+ )
+ else:
+ context.attempt += 1
  raise ValueError('Failed to sample prompts.')

- return concurrent.with_retry(
- fake_sample,
- retry_on_errors=ValueError,
- max_attempts=self.max_attempts,
- retry_interval=1,
- )(prompts)
+ results = self._parallel_execute_with_currency_control(
+ fake_sample, prompts, retry_on_errors=ValueError
+ )
+ for result in results:
+ result.usage.retry_stats.rebind(
+ total_call_interval=0, skip_notification=True
+ )
+ return results
+
+ @property
+ def model_id(self) -> str:
+ return self.name


  class MockScoringModel(MockModel):

  def _score(
  self,
- prompt: message_lib.Message,
+ prompt: message_lib.Message | list[message_lib.Message],
  completions: list[message_lib.Message],
  **kwargs
  ) -> list[lm_lib.LMScoringResult]:
@@ -67,19 +82,26 @@ class MockScoringModel(MockModel):
  ]


+ class MockTokenizeModel(MockModel):
+
+ def _tokenize(
+ self, prompt: message_lib.Message) -> list[tuple[str | bytes, int]]:
+ return [(w, i) for i, w in enumerate(prompt.text.split(' '))]
+
+
  class LMSamplingOptionsTest(unittest.TestCase):
  """Tests for LMSamplingOptions."""

  def test_cache_key(self):
  options = lm_lib.LMSamplingOptions()
  key1 = options.cache_key()
- self.assertEqual(key1, (0.0, 1024, 1, 40, None, None))
+ self.assertEqual(key1, (None, None, 1, 40, None, None))
  with options.override(temperature=1.0, max_tokens=256):
  key2 = options.cache_key()
  self.assertEqual(key2, (1.0, 256, 1, 40, None, None))

  # Make sure key1 does not change upon override.
- self.assertEqual(key1, (0.0, 1024, 1, 40, None, None))
+ self.assertEqual(key1, (None, None, 1, 40, None, None))


  class LanguageModelTest(unittest.TestCase):
@@ -95,13 +117,60 @@ class LanguageModelTest(unittest.TestCase):
  self.assertEqual(lm.sampling_options.top_k, 2)
  self.assertEqual(lm.max_attempts, 2)

+ def test_subclassing(self):
+
+ class ChildModel(lm_lib.LanguageModel):
+
+ sampling_options = lm_lib.LMSamplingOptions(
+ temperature=0.5, top_k=20
+ )
+
+ def _sample(self, *args, **kwargs):
+ pass
+
+ lm = ChildModel(top_k=10)
+ self.assertEqual(lm.sampling_options.temperature, 0.5)
+ self.assertEqual(lm.sampling_options.top_k, 10)
+
  def test_sample(self):
  lm = MockModel(top_k=1)
  self.assertEqual(
  lm.sample(prompts=['foo', 'bar']),
  [
- lm_lib.LMSamplingResult([lm_lib.LMSample('foo', score=0.0)]),
- lm_lib.LMSamplingResult([lm_lib.LMSample('bar', score=0.0)]),
+ lm_lib.LMSamplingResult(
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'foo',
+ score=-1.0,
+ logprobs=None,
+ is_cached=False,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=-1.0,
+ logprobs=None,
+ )
+ ],
+ usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ ),
+ lm_lib.LMSamplingResult(
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'bar',
+ score=-1.0,
+ logprobs=None,
+ is_cached=False,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=-1.0,
+ logprobs=None,
+ )
+ ],
+ usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ ),
  ],
  )
  # Test override sampling_options.
@@ -112,38 +181,139 @@ class LanguageModelTest(unittest.TestCase):
  ),
  [
  lm_lib.LMSamplingResult(
- [lm_lib.LMSample('foo' * 2, score=0.5)]
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'foo' * 2,
+ score=0.5,
+ logprobs=None,
+ is_cached=False,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=0.5,
+ logprobs=None,
+ ),
+ ],
+ usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
  ),
  lm_lib.LMSamplingResult(
- [lm_lib.LMSample('bar' * 2, score=0.5)]
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'bar' * 2,
+ score=0.5,
+ logprobs=None,
+ is_cached=False,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=0.5,
+ logprobs=None,
+ ),
+ ],
+ usage=lm_lib.LMSamplingUsage(
+ prompt_tokens=100, completion_tokens=100, total_tokens=200,
+ num_requests=1, estimated_cost=1.0,
+ ),
  ),
- ],
+ ]
  )
  # Test override individual flags within sampling_options.
  self.assertEqual(
  lm.sample(prompts=['foo', 'bar'], temperature=1.0),
  [
- lm_lib.LMSamplingResult([lm_lib.LMSample('foo', score=1.0)]),
- lm_lib.LMSamplingResult([lm_lib.LMSample('bar', score=1.0)]),
- ],
+ lm_lib.LMSamplingResult(
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'foo',
+ score=1.0,
+ logprobs=None,
+ is_cached=False,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=1.0,
+ logprobs=None,
+ ),
+ ],
+ usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ ),
+ lm_lib.LMSamplingResult(
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'bar',
+ score=1.0,
+ logprobs=None,
+ is_cached=False,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=1.0,
+ logprobs=None,
+ ),
+ ],
+ usage=lm_lib.LMSamplingUsage(
+ prompt_tokens=100, completion_tokens=100, total_tokens=200,
+ num_requests=1, estimated_cost=1.0,
+ ),
+ ),
+ ]
  )
  self.assertEqual(
  lm.sample(prompts=['foo', 'bar'], top_k=2, temperature=0.7),
  [
  lm_lib.LMSamplingResult(
- [lm_lib.LMSample('foo' * 2, score=0.7)]
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'foo' * 2,
+ score=0.7,
+ logprobs=None,
+ is_cached=False,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=0.7,
+ logprobs=None,
+ ),
+ ],
+ usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
  ),
  lm_lib.LMSamplingResult(
- [lm_lib.LMSample('bar' * 2, score=0.7)]
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'bar' * 2,
+ score=0.7,
+ logprobs=None,
+ is_cached=False,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=0.7,
+ logprobs=None,
+ ),
+ ],
+ usage=lm_lib.LMSamplingUsage(
+ prompt_tokens=100, completion_tokens=100, total_tokens=200,
+ num_requests=1, estimated_cost=1.0,
+ ),
  ),
- ],
+ ]
  )

  def test_call(self):
  lm = MockModel(sampling_options=lm_lib.LMSamplingOptions(top_k=1))
  response = lm(prompt='foo')
  self.assertEqual(response.text, 'foo')
- self.assertEqual(response.score, 0.0)
+ self.assertEqual(response.score, -1.0)
+ self.assertIsNone(response.logprobs)
+ self.assertEqual(
+ response.usage, lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0)
+ )

  # Test override sampling_options.
  self.assertEqual(
@@ -158,16 +328,53 @@ class LanguageModelTest(unittest.TestCase):
  self.assertEqual(
  lm.sample(prompts=['foo', 'bar']),
  [
- lm_lib.LMSamplingResult([lm_lib.LMSample(
- message_lib.AIMessage('foo', cache_seed=0), score=0.0)]),
- lm_lib.LMSamplingResult([lm_lib.LMSample(
- message_lib.AIMessage('bar', cache_seed=0), score=0.0)]),
- ])
+ lm_lib.LMSamplingResult(
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'foo',
+ cache_seed=0,
+ score=-1.0,
+ logprobs=None,
+ is_cached=False,
+ usage=lm_lib.LMSamplingUsage(
+ 100, 100, 200, 1, 1.0
+ ),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=-1.0,
+ logprobs=None,
+ )
+ ],
+ usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ ),
+ lm_lib.LMSamplingResult(
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'bar',
+ cache_seed=0,
+ score=-1.0,
+ logprobs=None,
+ is_cached=False,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=-1.0,
+ logprobs=None,
+ )
+ ],
+ usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ ),
+ ],
+ )
  self.assertEqual(cache.stats.num_queries, 2)
  self.assertEqual(cache.stats.num_hits, 0)
  self.assertEqual(cache.stats.num_updates, 2)

- self.assertEqual(lm('foo'), 'foo')
+ result = lm('foo')
+ self.assertEqual(result, 'foo')
+ self.assertTrue(result.metadata.is_cached)
  self.assertEqual(lm('bar'), 'bar')
  self.assertEqual(cache.stats.num_queries, 4)
  self.assertEqual(cache.stats.num_hits, 2)
@@ -181,10 +388,42 @@ class LanguageModelTest(unittest.TestCase):
  self.assertEqual(
  lm.sample(prompts=['foo', 'baz'], temperature=1.0),
  [
- lm_lib.LMSamplingResult([lm_lib.LMSample(
- message_lib.AIMessage('foo', cache_seed=0), score=1.0)]),
- lm_lib.LMSamplingResult([lm_lib.LMSample(
- message_lib.AIMessage('baz', cache_seed=0), score=1.0)]),
+ lm_lib.LMSamplingResult(
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'foo',
+ cache_seed=0,
+ score=1.0,
+ logprobs=None,
+ is_cached=False,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=1.0,
+ logprobs=None,
+ )
+ ],
+ usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ ),
+ lm_lib.LMSamplingResult(
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'baz',
+ cache_seed=0,
+ score=1.0,
+ logprobs=None,
+ is_cached=False,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=1.0,
+ logprobs=None,
+ )
+ ],
+ usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ ),
  ],
  )
  self.assertEqual(cache.stats.num_queries, 6)
@@ -209,13 +448,50 @@ class LanguageModelTest(unittest.TestCase):

  def test_retry(self):
  lm = MockModel(
- failures_before_attempt=1, top_k=1,
+ failures_before_attempt=1, top_k=1, max_attempts=2, retry_interval=1
  )
  with self.assertRaisesRegex(
  concurrent.RetryError, 'Calling .* failed after 1 attempts'
  ):
  lm('foo', max_attempts=1)
- self.assertEqual(lm('foo', max_attempts=2), 'foo')
+
+ usage = lm_lib.LMSamplingUsage(
+ prompt_tokens=100,
+ completion_tokens=100,
+ total_tokens=200,
+ num_requests=1,
+ estimated_cost=1.0,
+ retry_stats=lm_lib.RetryStats(
+ num_occurences=1,
+ total_wait_interval=1,
+ errors={'ValueError': 1},
+ ),
+ )
+ out = lm.sample(['foo'])
+ self.assertEqual(
+ # lm.sample(['foo'], max_attempts=2),
+ out,
+ [
+ lm_lib.LMSamplingResult(
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'foo',
+ score=-1.0,
+ logprobs=None,
+ is_cached=False,
+ usage=usage,
+ tags=['lm-response'],
+ ),
+ score=-1.0,
+ logprobs=None,
+ )
+ ],
+ usage=usage,
+ is_cached=False,
+ )
+ ],
+ )

  def test_debug(self):
  class Image(modality.Modality):
@@ -227,8 +503,9 @@ class LanguageModelTest(unittest.TestCase):
  with contextlib.redirect_stdout(string_io):
  self.assertEqual(
  lm(message_lib.UserMessage(
- 'hi {{image}}', image=Image()), debug=True),
- 'hi {{image}}')
+ 'hi <<[[image]]>>', image=Image()), debug=True),
+ 'hi <<[[image]]>>'
+ )

  debug_info = string_io.getvalue()
  self.assertIn('[0] LM INFO', debug_info)
@@ -317,6 +594,17 @@ class LanguageModelTest(unittest.TestCase):
  ],
  )

+ self.assertEqual(
+ lm.score(
+ [message_lib.UserMessage('hi {{image}}', image=Image()),
+ message_lib.UserMessage('hi {{image}}', image=Image())],
+ ['1', '2'], debug=debug_mode),
+ [
+ lm_lib.LMScoringResult(score=-0.0),
+ lm_lib.LMScoringResult(score=-1.0),
+ ],
+ )
+
  debug_info = string_io.getvalue()
  expected_included = [
  debug_prints[f]
@@ -337,10 +625,359 @@ class LanguageModelTest(unittest.TestCase):
  if debug_mode & lm_lib.LMDebugMode.PROMPT:
  self.assertIn('[0] MODALITY OBJECTS SENT TO LM', debug_info)

+ def test_score_with_unmatched_prompt_and_completions(self):
+ with self.assertRaises(ValueError):
+ MockScoringModel().score(['hi',], ['1', '2', '3'])
+

  def test_score_with_unsupported_model(self):
  with self.assertRaises(NotImplementedError):
  MockModel().score('hi', ['1', '2'])

+ def test_tokenize(self):
+ info_flag = lm_lib.LMDebugMode.INFO
+ prompt_flag = lm_lib.LMDebugMode.PROMPT
+ response_flag = lm_lib.LMDebugMode.RESPONSE
+ debug_prints = {
+ info_flag: 'LM INFO',
+ prompt_flag: 'PROMPT TO TOKENIZE',
+ response_flag: 'TOKENS RETURNED',
+ }
+ debug_modes = [
+ info_flag,
+ prompt_flag,
+ response_flag,
+ info_flag | prompt_flag,
+ info_flag | response_flag,
+ prompt_flag | response_flag,
+ info_flag | prompt_flag | response_flag,
+ ]
+
+ class Image(modality.Modality):
+ def to_bytes(self):
+ return b'fake_image'
+
+ for debug_mode in debug_modes:
+ string_io = io.StringIO()
+ lm = MockTokenizeModel()
+
+ with contextlib.redirect_stdout(string_io):
+ self.assertEqual(
+ lm.tokenize(
+ message_lib.UserMessage('hi <<[[image]]>>', image=Image()),
+ debug=debug_mode),
+ [('hi', 0), ('<<[[image]]>>', 1)],
+ )
+
+ debug_info = string_io.getvalue()
+ expected_included = [
+ debug_prints[f]
+ for f in lm_lib.LMDebugMode
+ if f != lm_lib.LMDebugMode.NONE and f in debug_mode
+ ]
+ expected_excluded = [
+ debug_prints[f]
+ for f in lm_lib.LMDebugMode
+ if f != lm_lib.LMDebugMode.NONE and f not in debug_mode
+ ]
+
+ for expected_include in expected_included:
+ self.assertIn(expected_include, debug_info)
+ for expected_exclude in expected_excluded:
+ self.assertNotIn(expected_exclude, debug_info)
+
+ if debug_mode & lm_lib.LMDebugMode.PROMPT:
+ self.assertIn('[0] MODALITY OBJECTS SENT TO LM', debug_info)
+
+ def test_tokenize_with_unsupported_model(self):
+ with self.assertRaises(NotImplementedError):
+ MockModel().tokenize('hi')
+
+ def test_rate_to_max_concurrency_no_rpm_no_tpm(self) -> None:
+ lm = MockModel()
+ self.assertEqual(
+ lm_lib.DEFAULT_MAX_CONCURRENCY,
+ lm.rate_to_max_concurrency(requests_per_min=0, tokens_per_min=0),
+ )
+ self.assertEqual(
+ lm_lib.DEFAULT_MAX_CONCURRENCY,
+ lm.rate_to_max_concurrency(requests_per_min=-1, tokens_per_min=-1),
+ )
+
+ def test_rate_to_max_concurrency_only_rpm_specified_uses_rpm(self) -> None:
+ lm = MockModel()
+ test_rpm = 1e4
+ self.assertEqual(
+ lm.rate_to_max_concurrency(requests_per_min=test_rpm),
+ int(test_rpm / 60)
+ )
+
+ def test_rate_to_max_concurrency_tpm_specified_uses_tpm(self) -> None:
+ lm = MockModel()
+ test_tpm = 1e7
+ self.assertEqual(
+ lm.rate_to_max_concurrency(requests_per_min=1, tokens_per_min=test_tpm),
+ int(test_tpm / lm_lib.TOKENS_PER_REQUEST / 60),
+ )
+
+ def test_rate_to_max_concurrency_small_rate_returns_one(self) -> None:
+ lm = MockModel()
+ self.assertEqual(lm.rate_to_max_concurrency(requests_per_min=1), 1)
+ self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)
+
+ def test_track_usages(self):
+ lm = MockModel(name='model1')
+ lm2 = MockModel(name='model2')
+ with lm_lib.track_usages() as usages1:
+ _ = lm('hi')
+ with lm_lib.track_usages(lm2) as usages2:
+ with lm_lib.track_usages('model1') as usages3:
+ with lm_lib.track_usages('model1', lm2) as usages4:
+ def call_lm(prompt):
+ _ = lm.sample([prompt] * 2)
+ lm2('hi')
+ list(concurrent.concurrent_map(call_lm, ['hi', 'hello']))
+
+ self.assertEqual(usages2.uncached.breakdown, {
+ 'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ })
+ self.assertFalse(usages2.cached)
+ self.assertEqual(usages3.uncached.breakdown, {
+ 'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4, 4, 4.0),
+ })
+ self.assertFalse(usages3.cached)
+ self.assertEqual(usages4.uncached.breakdown, {
+ 'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4, 4, 4.0),
+ 'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ })
+ self.assertFalse(usages4.cached)
+ self.assertEqual(usages1.uncached.breakdown, {
+ 'model1': lm_lib.LMSamplingUsage(100 * 5, 100 * 5, 200 * 5, 5, 5.0),
+ 'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ })
+ self.assertFalse(usages1.cached)
+ self.assertEqual(
+ usages1.total,
+ lm_lib.LMSamplingUsage(100 * 6, 100 * 6, 200 * 6, 6, 6.0),
+ )
+
+ cache = in_memory.InMemory()
+ lm = MockModel(cache=cache, name='model1')
+ with lm_lib.track_usages() as usages1:
+ _ = lm('hi')
+ self.assertEqual(usages1.uncached.breakdown, {
+ 'model1': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+ })
+ self.assertFalse(usages1.cached)
+ with lm_lib.track_usages() as usages2:
+ _ = lm('hi')
+ self.assertEqual(usages2.cached.breakdown, {
+ 'model1': lm_lib.LMSamplingUsage(100, 100, 200, 1, 0.0),
+ })
+ self.assertFalse(usages2.uncached)
+
+
+ class LMSamplingUsageTest(unittest.TestCase):
+
+ def test_basics(self):
+ usage = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
+ self.assertEqual(usage.num_requests, 4)
+ self.assertEqual(usage.prompt_tokens, 100)
+ self.assertEqual(usage.completion_tokens, 200)
+ self.assertEqual(usage.total_tokens, 300)
+ self.assertEqual(usage.estimated_cost, 5.0)
+ self.assertEqual(usage.average_prompt_tokens, 25)
+ self.assertEqual(usage.average_completion_tokens, 50)
+ self.assertEqual(usage.average_total_tokens, 75)
+ self.assertEqual(usage.average_estimated_cost, 1.25)
+
+ def test_add(self):
+ usage1 = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
+ usage1.rebind(retry_stats=lm_lib.RetryStats(1, 3, 4, {'e1': 1}))
+ usage2 = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
+ self.assertEqual(usage1 + usage2, usage1 + usage2)
+ self.assertIs(usage1 + None, usage1)
+ self.assertIs(None + usage1, usage1)
+ usage3 = lm_lib.LMSamplingUsage(100, 200, 300, 4, None)
+ usage3.rebind(retry_stats=lm_lib.RetryStats(2, 4, 5, {'e1': 2, 'e2': 3}))
+ self.assertEqual(
+ usage1 + usage3,
+ lm_lib.LMSamplingUsage(
+ 200,
+ 400,
+ 600,
+ 8,
+ 5.0,
+ retry_stats=lm_lib.RetryStats(3, 7, 9, {'e1': 3, 'e2': 3}),
+ ),
+ )
+ self.assertEqual(
+ usage3 + usage1,
+ lm_lib.LMSamplingUsage(
+ 200,
+ 400,
+ 600,
+ 8,
+ 5.0,
+ retry_stats=lm_lib.RetryStats(3, 7, 9, {'e1': 3, 'e2': 3}),
+ ),
+ )
+
+ def test_usage_not_available(self):
+ usage_not_available = lm_lib.UsageNotAvailable()
+ self.assertEqual(usage_not_available.prompt_tokens, 0)
+ self.assertEqual(usage_not_available.completion_tokens, 0)
+ self.assertEqual(usage_not_available.total_tokens, 0)
+ self.assertEqual(usage_not_available.average_prompt_tokens, 0)
+ self.assertEqual(usage_not_available.average_completion_tokens, 0)
+ self.assertEqual(usage_not_available.average_total_tokens, 0)
+ self.assertIsNone(usage_not_available.average_estimated_cost)
+ self.assertTrue(usage_not_available)
+ self.assertEqual(
+ usage_not_available + lm_lib.LMSamplingUsage(1, 2, 3, 4, 5.0),
+ lm_lib.UsageNotAvailable(num_requests=5)
+ )
+ self.assertEqual(
+ lm_lib.LMSamplingUsage(1, 2, 3, 4, 5.0) + usage_not_available,
+ lm_lib.UsageNotAvailable(num_requests=5)
+ )
+ self.assertIs(None + usage_not_available, usage_not_available)
+ self.assertIs(usage_not_available + None, usage_not_available)
+
+
+ class UsageSummaryTest(unittest.TestCase):
+
+ def test_basics(self):
+ usage_summary = lm_lib.UsageSummary()
+ self.assertFalse(usage_summary.total)
+ self.assertFalse(usage_summary.cached)
+ self.assertFalse(usage_summary.uncached)
+
+ # Add uncached.
+ usage_summary.add(
+ 'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+ )
+ self.assertEqual(
+ usage_summary.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0)
+ )
+ self.assertEqual(
+ usage_summary.uncached.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0)
+ )
+ # Add cached.
+ self.assertFalse(usage_summary.cached)
+ usage_summary.add(
+ 'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), True
+ )
+ self.assertEqual(
+ usage_summary.total, lm_lib.LMSamplingUsage(2, 4, 6, 2, 5.0)
+ )
+ self.assertEqual(
+ usage_summary.cached.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 0.0)
+ )
+ # Add UsageNotAvailable.
+ usage_summary.add(
+ 'model1', lm_lib.UsageNotAvailable(num_requests=1), False
+ )
+ self.assertEqual(
+ usage_summary.total, lm_lib.UsageNotAvailable(num_requests=3)
+ )
+ self.assertEqual(
+ usage_summary.uncached.total, lm_lib.UsageNotAvailable(num_requests=2)
+ )
+
+ def test_merge(self):
+ usage_summary = lm_lib.UsageSummary()
+ usage_summary.add(
+ 'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+ )
+ usage_summary.add(
+ 'model2', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+ )
+ usage_summary.add(
+ 'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+ )
+ usage_summary2 = lm_lib.UsageSummary()
+ usage_summary2.add(
+ 'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+ )
+ usage_summary2.add(
+ 'model3', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+ )
+ usage_summary2.merge(usage_summary)
+ self.assertEqual(
+ usage_summary2,
+ lm_lib.UsageSummary(
+ cached=lm_lib.UsageSummary.AggregatedUsage(
+ total=lm_lib.LMSamplingUsage(
+ prompt_tokens=0,
+ completion_tokens=0,
+ total_tokens=0,
+ num_requests=0,
+ estimated_cost=0.0,
+ ),
+ breakdown={}
+ ),
+ uncached=lm_lib.UsageSummary.AggregatedUsage(
+ total=lm_lib.LMSamplingUsage(
+ prompt_tokens=5,
+ completion_tokens=10,
+ total_tokens=15,
+ num_requests=5,
+ estimated_cost=25.0
+ ),
+ breakdown=dict(
+ model1=lm_lib.LMSamplingUsage(
+ prompt_tokens=3,
+ completion_tokens=6,
+ total_tokens=9,
+ num_requests=3,
+ estimated_cost=15.0
+ ),
+ model3=lm_lib.LMSamplingUsage(
+ prompt_tokens=1,
+ completion_tokens=2,
+ total_tokens=3,
+ num_requests=1,
+ estimated_cost=5.0
+ ),
+ model2=lm_lib.LMSamplingUsage(
+ prompt_tokens=1,
+ completion_tokens=2,
+ total_tokens=3,
+ num_requests=1,
+ estimated_cost=5.0
+ )
+ )
+ )
+ )
+ )
+
+ def test_html_view(self):
+ usage_summary = lm_lib.UsageSummary()
+ usage_summary.add(
+ 'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+ )
+ self.assertIn(
+ '5.000',
+ usage_summary.to_html(extra_flags=dict(as_badge=True)).content
+ )
+ usage_summary.add(
+ 'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+ )
+ self.assertIn(
+ '10.000',
+ usage_summary.to_html(
+ extra_flags=dict(as_badge=True, interactive=True)
+ ).content
+ )
+ self.assertTrue(
+ usage_summary.to_html().content.startswith('<details open')
+ )
+ with pg.views.html.controls.HtmlControl.track_scripts() as scripts:
+ usage_summary.add(
+ 'model2', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+ )
+ self.assertEqual(len(scripts), 4)
+

  if __name__ == '__main__':
  unittest.main()