langfun 0.1.2.dev202410110804__py3-none-any.whl → 0.1.2.dev202410120803__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langfun/core/__init__.py CHANGED
@@ -103,6 +103,7 @@ from langfun.core.language_model import LMSample
 from langfun.core.language_model import LMSamplingOptions
 from langfun.core.language_model import LMSamplingUsage
 from langfun.core.language_model import UsageNotAvailable
+from langfun.core.language_model import UsageSummary
 from langfun.core.language_model import LMSamplingResult
 from langfun.core.language_model import LMScoringResult
 from langfun.core.language_model import LMCache
@@ -194,6 +194,7 @@ class EvaluationTest(unittest.TestCase):
             cache_seed=0,
             score=1.0,
             logprobs=None,
+            is_cached=False,
             usage=lf.LMSamplingUsage(387, 24, 411),
             tags=['lm-response', 'lm-output', 'transformed'],
         ),
@@ -89,7 +89,7 @@ class LangFuncCallTest(unittest.TestCase):
     self.assertEqual(
         r,
         message.AIMessage(
-            'Hello!!!', score=0.0, logprobs=None,
+            'Hello!!!', score=0.0, logprobs=None, is_cached=False,
            usage=language_model.UsageNotAvailable()
        )
    )
@@ -120,7 +120,7 @@ class LangFuncCallTest(unittest.TestCase):
     self.assertEqual(
         r,
         message.AIMessage(
-            'Hello!!!', score=0.0, logprobs=None,
+            'Hello!!!', score=0.0, logprobs=None, is_cached=False,
            usage=language_model.UsageNotAvailable()
        )
    )
@@ -19,7 +19,7 @@ import dataclasses
 import enum
 import threading
 import time
-from typing import Annotated, Any, Callable, Iterator, Sequence, Tuple, Type, Union
+from typing import Annotated, Any, Callable, Iterator, Optional, Sequence, Tuple, Type, Union
 from langfun.core import component
 from langfun.core import concurrent
 from langfun.core import console
@@ -86,25 +86,75 @@ class LMSamplingUsage(pg.Object):
   completion_tokens: int
   total_tokens: int
   num_requests: int = 1
+  estimated_cost: Annotated[
+      float | None,
+      (
+          'Estimated cost in US dollars. If None, cost estimating is not '
+          'supported on the model being queried.'
+      )
+  ] = None
+
+  def __bool__(self) -> bool:
+    return self.num_requests > 0
+
+  @property
+  def average_prompt_tokens(self) -> int:
+    """Returns the average prompt tokens per request."""
+    return self.prompt_tokens // self.num_requests
+
+  @property
+  def average_completion_tokens(self) -> int:
+    """Returns the average completion tokens per request."""
+    return self.completion_tokens // self.num_requests
+
+  @property
+  def average_total_tokens(self) -> int:
+    """Returns the average total tokens per request."""
+    return self.total_tokens // self.num_requests
 
-  def __add__(self, other: 'LMSamplingUsage') -> 'LMSamplingUsage':
+  @property
+  def average_estimated_cost(self) -> float | None:
+    """Returns the average estimated cost per request."""
+    if self.estimated_cost is None:
+      return None
+    return self.estimated_cost / self.num_requests
+
+  def __add__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage':
+    if other is None:
+      return self
     return LMSamplingUsage(
         prompt_tokens=self.prompt_tokens + other.prompt_tokens,
         completion_tokens=self.completion_tokens + other.completion_tokens,
         total_tokens=self.total_tokens + other.total_tokens,
         num_requests=self.num_requests + other.num_requests,
+        estimated_cost=(
+            self.estimated_cost + other.estimated_cost  # pylint: disable=g-long-ternary
+            if (self.estimated_cost is not None
+                and other.estimated_cost is not None)
+            else None
+        )
     )
 
+  def __radd__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage':
+    return self + other
+
 
 class UsageNotAvailable(LMSamplingUsage):
   """Usage information not available."""
   prompt_tokens: pg.typing.Int(0).freeze()  # pytype: disable=invalid-annotation
   completion_tokens: pg.typing.Int(0).freeze()  # pytype: disable=invalid-annotation
   total_tokens: pg.typing.Int(0).freeze()  # pytype: disable=invalid-annotation
-  num_requests: pg.typing.Int(1).freeze()  # pytype: disable=invalid-annotation
+  estimated_cost: pg.typing.Float(default=None, is_noneable=True).freeze()  # pytype: disable=invalid-annotation
 
-  def __bool__(self) -> bool:
-    return False
+  def __add__(self, other: Optional['LMSamplingUsage']) -> 'UsageNotAvailable':
+    if other is None:
+      return self
+    return UsageNotAvailable(
+        num_requests=self.num_requests + other.num_requests
+    )
+
+  def __radd__(self, other: Optional['LMSamplingUsage']) -> 'UsageNotAvailable':
+    return self + other
 
 
 class LMSamplingResult(pg.Object):
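
Editorial note: a minimal sketch (not part of the diff) of the arithmetic enabled by the new estimated_cost field and the None-tolerant __add__/__radd__ above, using only names introduced in this hunk.

    # Editorial sketch, assuming the definitions in the hunk above.
    import langfun as lf

    usage = lf.LMSamplingUsage(
        prompt_tokens=100, completion_tokens=200, total_tokens=300,
        num_requests=4, estimated_cost=5.0)
    assert usage.average_prompt_tokens == 25     # 100 // 4
    assert usage.average_estimated_cost == 1.25  # 5.0 / 4

    # Addition now tolerates None on either side, so usages can be summed
    # without a zero-valued start element:
    total = sum([usage, usage], None)
    assert total.num_requests == 8 and total.estimated_cost == 10.0

    # If either side has no cost estimate, the summed cost degrades to None:
    no_cost = lf.LMSamplingUsage(1, 1, 2)  # estimated_cost defaults to None
    assert (usage + no_cost).estimated_cost is None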
@@ -123,6 +173,11 @@ class LMSamplingResult(pg.Object):
       'Usage information. Currently only OpenAI models are supported.',
   ] = UsageNotAvailable()
 
+  is_cached: Annotated[
+      bool,
+      'Whether the result is from cache or not.'
+  ] = False
+
 
 class LMSamplingOptions(component.Component):
   """Language model sampling options."""
@@ -425,12 +480,13 @@ class LanguageModel(component.Component):
       response = sample.response
       response.metadata.score = sample.score
       response.metadata.logprobs = sample.logprobs
+      response.metadata.is_cached = result.is_cached
 
       # NOTE(daiyip): Current usage is computed at per-result level,
       # which is accurate when n=1. For n > 1, we average the usage across
       # multiple samples.
       usage = result.usage
-      if len(result.samples) == 1 or not usage:
+      if len(result.samples) == 1 or isinstance(usage, UsageNotAvailable):
         response.metadata.usage = usage
       else:
         n = len(result.samples)
@@ -438,6 +494,9 @@ class LanguageModel(component.Component):
             prompt_tokens=usage.prompt_tokens // n,
             completion_tokens=usage.completion_tokens // n,
             total_tokens=usage.total_tokens // n,
+            estimated_cost=(
+                usage.estimated_cost / n if usage.estimated_cost else None
+            )
         )
 
       # Track usage.
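
Editorial note: a worked example (not part of the diff) of the per-sample averaging above; a result carrying usage (100, 100, 200) at an estimated cost of 1.0 with n = 2 samples is split evenly across the responses.

    # Editorial sketch, mirroring the arithmetic in the hunk above.
    usage = lf.LMSamplingUsage(100, 100, 200, num_requests=1, estimated_cost=1.0)
    n = 2  # two samples in one result
    per_response = lf.LMSamplingUsage(
        prompt_tokens=usage.prompt_tokens // n,          # 50
        completion_tokens=usage.completion_tokens // n,  # 50
        total_tokens=usage.total_tokens // n,            # 100
        estimated_cost=usage.estimated_cost / n,         # 0.5
    )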
@@ -445,7 +504,7 @@ class LanguageModel(component.Component):
       if trackers:
         model_id = self.model_id
         for tracker in trackers:
-          tracker.track(model_id, usage)
+          tracker.track(model_id, usage, result.is_cached)
 
       # Track the prompt for corresponding response.
       response.source = prompt
@@ -474,7 +533,9 @@ class LanguageModel(component.Component):
         request_to_result_index[len(requests)] = i
         requests.append(prompt)
       else:
-        results[i] = r.clone()
+        result = r.clone()
+        assert result.is_cached, result
+        results[i] = result
 
     # Sample non-cache-hit prompts.
     if requests:
@@ -491,8 +552,12 @@ class LanguageModel(component.Component):
           sample.response.set('cache_seed', cache_seed)
 
         if cache_seed is not None:
-          self.cache.put(self, prompt, result.clone(), seed=cache_seed)
-
+          self.cache.put(
+              self,
+              prompt,
+              result.clone(override=dict(is_cached=True)),
+              seed=cache_seed
+          )
     return results  # pytype: disable=bad-return-type
 
   @abc.abstractmethod
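
Editorial note: a hedged sketch (not part of the diff) of the cache round trip these two hunks implement: results written to the cache are cloned with is_cached=True, and the read-path assertion checks that anything served from cache carries the flag. MyModel is a hypothetical LanguageModel subclass; the in-memory cache setup follows the tests later in this diff.

    # Editorial sketch; concrete names are assumptions.
    from langfun.core.llms.cache import in_memory

    lm = MyModel(cache=in_memory.InMemory())
    r1 = lm('hi')  # Cache miss: sampled, then stored with is_cached=True.
    r2 = lm('hi')  # Cache hit: a clone of the stored result.
    assert not r1.metadata.is_cached
    assert r2.metadata.is_cached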
@@ -800,30 +865,81 @@ class LanguageModel(component.Component):
     return DEFAULT_MAX_CONCURRENCY  # Default of 1
 
 
+class UsageSummary(pg.Object):
+  """Usage summary."""
+
+  class AggregatedUsage(pg.Object):
+    """Aggregated usage."""
+
+    total: LMSamplingUsage = LMSamplingUsage(0, 0, 0, 0, 0.0)
+    breakdown: dict[str, LMSamplingUsage] = {}
+
+    def __bool__(self) -> bool:
+      """Returns True if the usage is non-empty."""
+      return bool(self.breakdown)
+
+    def add(
+        self,
+        model_id: str,
+        usage: LMSamplingUsage,
+    ) -> None:
+      """Adds an entry to the breakdown."""
+      aggregated = self.breakdown.get(model_id, None)
+      with pg.notify_on_change(False):
+        self.breakdown[model_id] = usage + aggregated
+      self.rebind(total=self.total + usage, skip_notification=True)
+
+  @property
+  def total(self) -> LMSamplingUsage:
+    return self.cached.total + self.uncached.total
+
+  def update(self, model_id: str, usage: LMSamplingUsage, is_cached: bool):
+    """Updates the usage summary."""
+    if is_cached:
+      usage.rebind(estimated_cost=0.0, skip_notification=True)
+      self.cached.add(model_id, usage)
+    else:
+      self.uncached.add(model_id, usage)
+
+
+pg.members(
+    dict(
+        cached=(
+            pg.typing.Object(
+                UsageSummary.AggregatedUsage,
+                default=UsageSummary.AggregatedUsage()
+            ),
+            'Aggregated usages for cached LLM calls.'
+        ),
+        uncached=(
+            pg.typing.Object(
+                UsageSummary.AggregatedUsage,
+                default=UsageSummary.AggregatedUsage()
+            ),
+            'Aggregated usages for uncached LLM calls.'
+        ),
+    )
+)(UsageSummary)
+
+
 class _UsageTracker:
   """Usage tracker."""
 
   def __init__(self, model_ids: set[str] | None):
     self.model_ids = model_ids
+    self.usage_summary = UsageSummary()
     self._lock = threading.Lock()
-    self.usages = {
-        m: LMSamplingUsage(0, 0, 0, 0) for m in model_ids
-    } if model_ids else {}
-
-  def track(self, model_id: str, usage: LMSamplingUsage):
-    if self.model_ids is not None and model_id not in self.model_ids:
-      return
-    with self._lock:
-      if not isinstance(usage, UsageNotAvailable) and model_id in self.usages:
-        self.usages[model_id] += usage
-      else:
-        self.usages[model_id] = usage
+
+  def track(self, model_id: str, usage: LMSamplingUsage, is_cached: bool):
+    if self.model_ids is None or model_id in self.model_ids:
+      with self._lock:
+        self.usage_summary.update(model_id, usage, is_cached)
 
 
 @contextlib.contextmanager
 def track_usages(
     *lm: Union[str, LanguageModel]
-) -> Iterator[dict[str, LMSamplingUsage]]:
+) -> Iterator[UsageSummary]:
   """Context manager to track the usages of all language models in scope.
 
   `lf.track_usages` works with threads spawned by `lf.concurrent_map` and
@@ -854,6 +970,6 @@ def track_usages(
   tracker = _UsageTracker(set(model_ids) if model_ids else None)
   with component.context(__usage_trackers__=trackers + [tracker]):
     try:
-      yield tracker.usages
+      yield tracker.usage_summary
     finally:
       pass
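
Editorial note: with this change, lf.track_usages yields a UsageSummary rather than a dict keyed by model id. A minimal consumption sketch (not part of the diff), based on the tests below:

    import langfun as lf

    with lf.track_usages() as summary:
      ...  # Invoke one or more language models here.

    # Billed (uncached) calls, broken down per model id:
    for model_id, usage in summary.uncached.breakdown.items():
      print(model_id, usage.total_tokens, usage.estimated_cost)

    # Cached calls are aggregated separately, with estimated_cost pinned to 0.0:
    if summary.cached:
      print(summary.cached.total.num_requests)

    # Grand total across cached and uncached calls:
    print(summary.total.total_tokens)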
@@ -49,6 +49,7 @@ class MockModel(lm_lib.LanguageModel):
                 prompt_tokens=100,
                 completion_tokens=100,
                 total_tokens=200,
+                estimated_cost=1.0,
             ),
         )
         for prompt in prompts
@@ -128,14 +129,15 @@ class LanguageModelTest(unittest.TestCase):
                     'foo',
                     score=-1.0,
                     logprobs=None,
-                    usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=-1.0,
                 logprobs=None,
             )
         ],
-        usage=lm_lib.LMSamplingUsage(100, 100, 200),
+        usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     ),
     lm_lib.LMSamplingResult(
         [
@@ -144,14 +146,15 @@ class LanguageModelTest(unittest.TestCase):
                     'bar',
                     score=-1.0,
                     logprobs=None,
-                    usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=-1.0,
                 logprobs=None,
             )
         ],
-        usage=lm_lib.LMSamplingUsage(100, 100, 200),
+        usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     ),
   ],
 )
@@ -169,14 +172,15 @@ class LanguageModelTest(unittest.TestCase):
                     'foo' * 2,
                     score=0.5,
                     logprobs=None,
-                    usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=0.5,
                 logprobs=None,
             ),
         ],
-        usage=lm_lib.LMSamplingUsage(100, 100, 200),
+        usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     ),
     lm_lib.LMSamplingResult(
         [
@@ -185,7 +189,8 @@ class LanguageModelTest(unittest.TestCase):
                     'bar' * 2,
                     score=0.5,
                     logprobs=None,
-                    usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=0.5,
@@ -193,7 +198,8 @@ class LanguageModelTest(unittest.TestCase):
             ),
         ],
         usage=lm_lib.LMSamplingUsage(
-            prompt_tokens=100, completion_tokens=100, total_tokens=200
+            prompt_tokens=100, completion_tokens=100, total_tokens=200,
+            num_requests=1, estimated_cost=1.0,
        ),
    ),
 ]
@@ -209,14 +215,15 @@ class LanguageModelTest(unittest.TestCase):
                     'foo',
                     score=1.0,
                     logprobs=None,
-                    usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=1.0,
                 logprobs=None,
             ),
         ],
-        usage=lm_lib.LMSamplingUsage(100, 100, 200),
+        usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     ),
     lm_lib.LMSamplingResult(
         [
@@ -225,7 +232,8 @@ class LanguageModelTest(unittest.TestCase):
                     'bar',
                     score=1.0,
                     logprobs=None,
-                    usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=1.0,
@@ -233,7 +241,8 @@ class LanguageModelTest(unittest.TestCase):
             ),
         ],
         usage=lm_lib.LMSamplingUsage(
-            prompt_tokens=100, completion_tokens=100, total_tokens=200
+            prompt_tokens=100, completion_tokens=100, total_tokens=200,
+            num_requests=1, estimated_cost=1.0,
        ),
    ),
 ]
@@ -248,14 +257,15 @@ class LanguageModelTest(unittest.TestCase):
                     'foo' * 2,
                     score=0.7,
                     logprobs=None,
-                    usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=0.7,
                 logprobs=None,
             ),
         ],
-        usage=lm_lib.LMSamplingUsage(100, 100, 200),
+        usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     ),
     lm_lib.LMSamplingResult(
         [
@@ -264,7 +274,8 @@ class LanguageModelTest(unittest.TestCase):
                     'bar' * 2,
                     score=0.7,
                     logprobs=None,
-                    usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=0.7,
@@ -272,7 +283,8 @@ class LanguageModelTest(unittest.TestCase):
             ),
         ],
         usage=lm_lib.LMSamplingUsage(
-            prompt_tokens=100, completion_tokens=100, total_tokens=200
+            prompt_tokens=100, completion_tokens=100, total_tokens=200,
+            num_requests=1, estimated_cost=1.0,
        ),
    ),
 ]
@@ -284,7 +296,9 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(response.text, 'foo')
     self.assertEqual(response.score, -1.0)
     self.assertIsNone(response.logprobs)
-    self.assertEqual(response.usage, lm_lib.LMSamplingUsage(100, 100, 200))
+    self.assertEqual(
+        response.usage, lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0)
+    )
 
     # Test override sampling_options.
     self.assertEqual(
@@ -307,14 +321,17 @@ class LanguageModelTest(unittest.TestCase):
                     cache_seed=0,
                     score=-1.0,
                     logprobs=None,
-                    usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(
+                        100, 100, 200, 1, 1.0
+                    ),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=-1.0,
                 logprobs=None,
             )
         ],
-        usage=lm_lib.LMSamplingUsage(100, 100, 200),
+        usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     ),
     lm_lib.LMSamplingResult(
         [
@@ -324,14 +341,15 @@ class LanguageModelTest(unittest.TestCase):
                     cache_seed=0,
                     score=-1.0,
                     logprobs=None,
-                    usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=-1.0,
                 logprobs=None,
             )
         ],
-        usage=lm_lib.LMSamplingUsage(100, 100, 200),
+        usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     ),
   ],
 )
@@ -339,7 +357,9 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(cache.stats.num_hits, 0)
     self.assertEqual(cache.stats.num_updates, 2)
 
-    self.assertEqual(lm('foo'), 'foo')
+    result = lm('foo')
+    self.assertEqual(result, 'foo')
+    self.assertTrue(result.metadata.is_cached)
     self.assertEqual(lm('bar'), 'bar')
     self.assertEqual(cache.stats.num_queries, 4)
     self.assertEqual(cache.stats.num_hits, 2)
@@ -361,14 +381,15 @@ class LanguageModelTest(unittest.TestCase):
                     cache_seed=0,
                     score=1.0,
                     logprobs=None,
-                    usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=1.0,
                 logprobs=None,
             )
         ],
-        usage=lm_lib.LMSamplingUsage(100, 100, 200),
+        usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     ),
     lm_lib.LMSamplingResult(
         [
@@ -378,14 +399,15 @@ class LanguageModelTest(unittest.TestCase):
                     cache_seed=0,
                     score=1.0,
                     logprobs=None,
-                    usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                    is_cached=False,
+                    usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
                     tags=[message_lib.Message.TAG_LM_RESPONSE],
                 ),
                 score=1.0,
                 logprobs=None,
             )
         ],
-        usage=lm_lib.LMSamplingUsage(100, 100, 200),
+        usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     ),
   ],
 )
@@ -663,20 +685,128 @@ class LanguageModelTest(unittest.TestCase):
           lm2('hi')
     list(concurrent.concurrent_map(call_lm, ['hi', 'hello']))
 
-    self.assertEqual(usages2, {
-        'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1),
+    print(usages2)
+    self.assertEqual(usages2.uncached.breakdown, {
+        'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
+    })
+    self.assertFalse(usages2.cached)
+    self.assertEqual(usages3.uncached.breakdown, {
+        'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4, 4, 4.0),
     })
-    self.assertEqual(usages3, {
-        'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4, 4),
+    self.assertFalse(usages3.cached)
+    self.assertEqual(usages4.uncached.breakdown, {
+        'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4, 4, 4.0),
+        'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     })
-    self.assertEqual(usages4, {
-        'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4, 4),
-        'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1),
+    self.assertFalse(usages4.cached)
+    self.assertEqual(usages1.uncached.breakdown, {
+        'model1': lm_lib.LMSamplingUsage(100 * 5, 100 * 5, 200 * 5, 5, 5.0),
+        'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     })
-    self.assertEqual(usages1, {
-        'model1': lm_lib.LMSamplingUsage(100 * 5, 100 * 5, 200 * 5, 5),
-        'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1),
+    self.assertFalse(usages1.cached)
+    self.assertEqual(
+        usages1.total,
+        lm_lib.LMSamplingUsage(100 * 6, 100 * 6, 200 * 6, 6, 6.0),
+    )
+
+    cache = in_memory.InMemory()
+    lm = MockModel(cache=cache, name='model1')
+    with lm_lib.track_usages() as usages1:
+      _ = lm('hi')
+    self.assertEqual(usages1.uncached.breakdown, {
+        'model1': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     })
+    self.assertFalse(usages1.cached)
+    with lm_lib.track_usages() as usages2:
+      _ = lm('hi')
+    self.assertEqual(usages2.cached.breakdown, {
+        'model1': lm_lib.LMSamplingUsage(100, 100, 200, 1, 0.0),
+    })
+    self.assertFalse(usages2.uncached)
+
+
+class LMSamplingUsageTest(unittest.TestCase):
+
+  def test_basics(self):
+    usage = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
+    self.assertEqual(usage.num_requests, 4)
+    self.assertEqual(usage.prompt_tokens, 100)
+    self.assertEqual(usage.completion_tokens, 200)
+    self.assertEqual(usage.total_tokens, 300)
+    self.assertEqual(usage.estimated_cost, 5.0)
+    self.assertEqual(usage.average_prompt_tokens, 25)
+    self.assertEqual(usage.average_completion_tokens, 50)
+    self.assertEqual(usage.average_total_tokens, 75)
+    self.assertEqual(usage.average_estimated_cost, 1.25)
+
+  def test_add(self):
+    usage1 = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
+    usage2 = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
+    self.assertEqual(usage1 + usage2, usage1 + usage2)
+    self.assertIs(usage1 + None, usage1)
+    self.assertIs(None + usage1, usage1)
+
+  def test_usage_not_available(self):
+    usage_not_available = lm_lib.UsageNotAvailable()
+    self.assertEqual(usage_not_available.prompt_tokens, 0)
+    self.assertEqual(usage_not_available.completion_tokens, 0)
+    self.assertEqual(usage_not_available.total_tokens, 0)
+    self.assertEqual(usage_not_available.average_prompt_tokens, 0)
+    self.assertEqual(usage_not_available.average_completion_tokens, 0)
+    self.assertEqual(usage_not_available.average_total_tokens, 0)
+    self.assertIsNone(usage_not_available.average_estimated_cost)
+    self.assertTrue(usage_not_available)
+    self.assertEqual(
+        usage_not_available + lm_lib.LMSamplingUsage(1, 2, 3, 4, 5.0),
+        lm_lib.UsageNotAvailable(num_requests=5)
+    )
+    self.assertEqual(
+        lm_lib.LMSamplingUsage(1, 2, 3, 4, 5.0) + usage_not_available,
+        lm_lib.UsageNotAvailable(num_requests=5)
+    )
+    self.assertIs(None + usage_not_available, usage_not_available)
+    self.assertIs(usage_not_available + None, usage_not_available)
+
+
+class UsageSummaryTest(unittest.TestCase):
+
+  def test_basics(self):
+    usage_summary = lm_lib.UsageSummary()
+    self.assertFalse(usage_summary.total)
+    self.assertFalse(usage_summary.cached)
+    self.assertFalse(usage_summary.uncached)
+
+    # Add uncached.
+    usage_summary.update(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    self.assertEqual(
+        usage_summary.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0)
+    )
+    self.assertEqual(
+        usage_summary.uncached.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0)
+    )
+    # Add cached.
+    self.assertFalse(usage_summary.cached)
+    usage_summary.update(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), True
+    )
+    self.assertEqual(
+        usage_summary.total, lm_lib.LMSamplingUsage(2, 4, 6, 2, 5.0)
+    )
+    self.assertEqual(
+        usage_summary.cached.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 0.0)
+    )
+    # Add UsageNotAvailable.
+    usage_summary.update(
+        'model1', lm_lib.UsageNotAvailable(num_requests=1), False
+    )
+    self.assertEqual(
+        usage_summary.total, lm_lib.UsageNotAvailable(num_requests=3)
+    )
+    self.assertEqual(
+        usage_summary.uncached.total, lm_lib.UsageNotAvailable(num_requests=2)
+    )
 
 
 if __name__ == '__main__':