langfun 0.1.2.dev202410110804__py3-none-any.whl → 0.1.2.dev202410130803__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in the public registry they were published to. It is provided for informational purposes only.
- langfun/core/__init__.py +1 -0
- langfun/core/eval/base_test.py +1 -0
- langfun/core/langfunc_test.py +2 -2
- langfun/core/language_model.py +140 -24
- langfun/core/language_model_test.py +166 -36
- langfun/core/llms/__init__.py +8 -1
- langfun/core/llms/anthropic.py +72 -7
- langfun/core/llms/cache/in_memory_test.py +3 -2
- langfun/core/llms/fake_test.py +7 -0
- langfun/core/llms/groq.py +154 -6
- langfun/core/llms/openai.py +300 -42
- langfun/core/llms/openai_test.py +35 -8
- langfun/core/llms/vertexai.py +121 -16
- langfun/core/logging.py +9 -3
- langfun/core/message.py +23 -12
- langfun/core/message_test.py +2 -2
- langfun/core/structured/completion_test.py +1 -0
- langfun/core/structured/mapping.py +1 -1
- langfun/core/structured/parsing_test.py +2 -1
- langfun/core/structured/prompting_test.py +1 -0
- langfun/core/template.py +8 -5
- langfun/core/templates/selfplay_test.py +4 -2
- {langfun-0.1.2.dev202410110804.dist-info → langfun-0.1.2.dev202410130803.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202410110804.dist-info → langfun-0.1.2.dev202410130803.dist-info}/RECORD +27 -27
- {langfun-0.1.2.dev202410110804.dist-info → langfun-0.1.2.dev202410130803.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202410110804.dist-info → langfun-0.1.2.dev202410130803.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202410110804.dist-info → langfun-0.1.2.dev202410130803.dist-info}/top_level.txt +0 -0
langfun/core/__init__.py
CHANGED
@@ -103,6 +103,7 @@ from langfun.core.language_model import LMSample
|
|
103
103
|
from langfun.core.language_model import LMSamplingOptions
|
104
104
|
from langfun.core.language_model import LMSamplingUsage
|
105
105
|
from langfun.core.language_model import UsageNotAvailable
|
106
|
+
from langfun.core.language_model import UsageSummary
|
106
107
|
from langfun.core.language_model import LMSamplingResult
|
107
108
|
from langfun.core.language_model import LMScoringResult
|
108
109
|
from langfun.core.language_model import LMCache
|
langfun/core/eval/base_test.py
CHANGED
langfun/core/langfunc_test.py
CHANGED
@@ -89,7 +89,7 @@ class LangFuncCallTest(unittest.TestCase):
|
|
89
89
|
self.assertEqual(
|
90
90
|
r,
|
91
91
|
message.AIMessage(
|
92
|
-
'Hello!!!', score=0.0, logprobs=None,
|
92
|
+
'Hello!!!', score=0.0, logprobs=None, is_cached=False,
|
93
93
|
usage=language_model.UsageNotAvailable()
|
94
94
|
)
|
95
95
|
)
|
@@ -120,7 +120,7 @@ class LangFuncCallTest(unittest.TestCase):
|
|
120
120
|
self.assertEqual(
|
121
121
|
r,
|
122
122
|
message.AIMessage(
|
123
|
-
'Hello!!!', score=0.0, logprobs=None,
|
123
|
+
'Hello!!!', score=0.0, logprobs=None, is_cached=False,
|
124
124
|
usage=language_model.UsageNotAvailable()
|
125
125
|
)
|
126
126
|
)
|
langfun/core/language_model.py
CHANGED
@@ -19,7 +19,7 @@ import dataclasses
|
|
19
19
|
import enum
|
20
20
|
import threading
|
21
21
|
import time
|
22
|
-
from typing import Annotated, Any, Callable, Iterator, Sequence, Tuple, Type, Union
|
22
|
+
from typing import Annotated, Any, Callable, Iterator, Optional, Sequence, Tuple, Type, Union
|
23
23
|
from langfun.core import component
|
24
24
|
from langfun.core import concurrent
|
25
25
|
from langfun.core import console
|
@@ -86,25 +86,75 @@ class LMSamplingUsage(pg.Object):
|
|
86
86
|
completion_tokens: int
|
87
87
|
total_tokens: int
|
88
88
|
num_requests: int = 1
|
89
|
+
estimated_cost: Annotated[
|
90
|
+
float | None,
|
91
|
+
(
|
92
|
+
'Estimated cost in US dollars. If None, cost estimating is not '
|
93
|
+
'supported on the model being queried.'
|
94
|
+
)
|
95
|
+
] = None
|
96
|
+
|
97
|
+
def __bool__(self) -> bool:
|
98
|
+
return self.num_requests > 0
|
99
|
+
|
100
|
+
@property
|
101
|
+
def average_prompt_tokens(self) -> int:
|
102
|
+
"""Returns the average prompt tokens per request."""
|
103
|
+
return self.prompt_tokens // self.num_requests
|
104
|
+
|
105
|
+
@property
|
106
|
+
def average_completion_tokens(self) -> int:
|
107
|
+
"""Returns the average completion tokens per request."""
|
108
|
+
return self.completion_tokens // self.num_requests
|
109
|
+
|
110
|
+
@property
|
111
|
+
def average_total_tokens(self) -> int:
|
112
|
+
"""Returns the average total tokens per request."""
|
113
|
+
return self.total_tokens // self.num_requests
|
89
114
|
|
90
|
-
|
115
|
+
@property
|
116
|
+
def average_estimated_cost(self) -> float | None:
|
117
|
+
"""Returns the average estimated cost per request."""
|
118
|
+
if self.estimated_cost is None:
|
119
|
+
return None
|
120
|
+
return self.estimated_cost / self.num_requests
|
121
|
+
|
122
|
+
def __add__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage':
|
123
|
+
if other is None:
|
124
|
+
return self
|
91
125
|
return LMSamplingUsage(
|
92
126
|
prompt_tokens=self.prompt_tokens + other.prompt_tokens,
|
93
127
|
completion_tokens=self.completion_tokens + other.completion_tokens,
|
94
128
|
total_tokens=self.total_tokens + other.total_tokens,
|
95
129
|
num_requests=self.num_requests + other.num_requests,
|
130
|
+
estimated_cost=(
|
131
|
+
self.estimated_cost + other.estimated_cost # pylint: disable=g-long-ternary
|
132
|
+
if (self.estimated_cost is not None
|
133
|
+
and other.estimated_cost is not None)
|
134
|
+
else None
|
135
|
+
)
|
96
136
|
)
|
97
137
|
|
138
|
+
def __radd__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage':
|
139
|
+
return self + other
|
140
|
+
|
98
141
|
|
99
142
|
class UsageNotAvailable(LMSamplingUsage):
|
100
143
|
"""Usage information not available."""
|
101
144
|
prompt_tokens: pg.typing.Int(0).freeze() # pytype: disable=invalid-annotation
|
102
145
|
completion_tokens: pg.typing.Int(0).freeze() # pytype: disable=invalid-annotation
|
103
146
|
total_tokens: pg.typing.Int(0).freeze() # pytype: disable=invalid-annotation
|
104
|
-
|
147
|
+
estimated_cost: pg.typing.Float(default=None, is_noneable=True).freeze() # pytype: disable=invalid-annotation
|
105
148
|
|
106
|
-
def
|
107
|
-
|
149
|
+
def __add__(self, other: Optional['LMSamplingUsage']) -> 'UsageNotAvailable':
|
150
|
+
if other is None:
|
151
|
+
return self
|
152
|
+
return UsageNotAvailable(
|
153
|
+
num_requests=self.num_requests + other.num_requests
|
154
|
+
)
|
155
|
+
|
156
|
+
def __radd__(self, other: Optional['LMSamplingUsage']) -> 'UsageNotAvailable':
|
157
|
+
return self + other
|
108
158
|
|
109
159
|
|
110
160
|
class LMSamplingResult(pg.Object):
|
@@ -123,6 +173,11 @@ class LMSamplingResult(pg.Object):
|
|
123
173
|
'Usage information. Currently only OpenAI models are supported.',
|
124
174
|
] = UsageNotAvailable()
|
125
175
|
|
176
|
+
is_cached: Annotated[
|
177
|
+
bool,
|
178
|
+
'Whether the result is from cache or not.'
|
179
|
+
] = False
|
180
|
+
|
126
181
|
|
127
182
|
class LMSamplingOptions(component.Component):
|
128
183
|
"""Language model sampling options."""
|
@@ -425,12 +480,13 @@ class LanguageModel(component.Component):
|
|
425
480
|
response = sample.response
|
426
481
|
response.metadata.score = sample.score
|
427
482
|
response.metadata.logprobs = sample.logprobs
|
483
|
+
response.metadata.is_cached = result.is_cached
|
428
484
|
|
429
485
|
# NOTE(daiyip): Current usage is computed at per-result level,
|
430
486
|
# which is accurate when n=1. For n > 1, we average the usage across
|
431
487
|
# multiple samples.
|
432
488
|
usage = result.usage
|
433
|
-
if len(result.samples) == 1 or
|
489
|
+
if len(result.samples) == 1 or isinstance(usage, UsageNotAvailable):
|
434
490
|
response.metadata.usage = usage
|
435
491
|
else:
|
436
492
|
n = len(result.samples)
|
@@ -438,6 +494,9 @@ class LanguageModel(component.Component):
|
|
438
494
|
prompt_tokens=usage.prompt_tokens // n,
|
439
495
|
completion_tokens=usage.completion_tokens // n,
|
440
496
|
total_tokens=usage.total_tokens // n,
|
497
|
+
estimated_cost=(
|
498
|
+
usage.estimated_cost / n if usage.estimated_cost else None
|
499
|
+
)
|
441
500
|
)
|
442
501
|
|
443
502
|
# Track usage.
|
@@ -445,7 +504,7 @@ class LanguageModel(component.Component):
|
|
445
504
|
if trackers:
|
446
505
|
model_id = self.model_id
|
447
506
|
for tracker in trackers:
|
448
|
-
tracker.track(model_id, usage)
|
507
|
+
tracker.track(model_id, usage, result.is_cached)
|
449
508
|
|
450
509
|
# Track the prompt for corresponding response.
|
451
510
|
response.source = prompt
|
@@ -474,7 +533,9 @@ class LanguageModel(component.Component):
|
|
474
533
|
request_to_result_index[len(requests)] = i
|
475
534
|
requests.append(prompt)
|
476
535
|
else:
|
477
|
-
|
536
|
+
result = r.clone()
|
537
|
+
assert result.is_cached, result
|
538
|
+
results[i] = result
|
478
539
|
|
479
540
|
# Sample non-cache-hit prompts.
|
480
541
|
if requests:
|
@@ -491,8 +552,12 @@ class LanguageModel(component.Component):
|
|
491
552
|
sample.response.set('cache_seed', cache_seed)
|
492
553
|
|
493
554
|
if cache_seed is not None:
|
494
|
-
self.cache.put(
|
495
|
-
|
555
|
+
self.cache.put(
|
556
|
+
self,
|
557
|
+
prompt,
|
558
|
+
result.clone(override=dict(is_cached=True)),
|
559
|
+
seed=cache_seed
|
560
|
+
)
|
496
561
|
return results # pytype: disable=bad-return-type
|
497
562
|
|
498
563
|
@abc.abstractmethod
|
@@ -800,30 +865,81 @@ class LanguageModel(component.Component):
|
|
800
865
|
return DEFAULT_MAX_CONCURRENCY # Default of 1
|
801
866
|
|
802
867
|
|
868
|
+
class UsageSummary(pg.Object):
|
869
|
+
"""Usage sumary."""
|
870
|
+
|
871
|
+
class AggregatedUsage(pg.Object):
|
872
|
+
"""Aggregated usage."""
|
873
|
+
|
874
|
+
total: LMSamplingUsage = LMSamplingUsage(0, 0, 0, 0, 0.0)
|
875
|
+
breakdown: dict[str, LMSamplingUsage] = {}
|
876
|
+
|
877
|
+
def __bool__(self) -> bool:
|
878
|
+
"""Returns True if the usage is non-empty."""
|
879
|
+
return bool(self.breakdown)
|
880
|
+
|
881
|
+
def add(
|
882
|
+
self,
|
883
|
+
model_id: str,
|
884
|
+
usage: LMSamplingUsage,
|
885
|
+
) -> None:
|
886
|
+
"""Adds an entry to the breakdown."""
|
887
|
+
aggregated = self.breakdown.get(model_id, None)
|
888
|
+
with pg.notify_on_change(False):
|
889
|
+
self.breakdown[model_id] = usage + aggregated
|
890
|
+
self.rebind(total=self.total + usage, skip_notification=True)
|
891
|
+
|
892
|
+
@property
|
893
|
+
def total(self) -> LMSamplingUsage:
|
894
|
+
return self.cached.total + self.uncached.total
|
895
|
+
|
896
|
+
def update(self, model_id: str, usage: LMSamplingUsage, is_cached: bool):
|
897
|
+
"""Updates the usage summary."""
|
898
|
+
if is_cached:
|
899
|
+
usage.rebind(estimated_cost=0.0, skip_notification=True)
|
900
|
+
self.cached.add(model_id, usage)
|
901
|
+
else:
|
902
|
+
self.uncached.add(model_id, usage)
|
903
|
+
|
904
|
+
|
905
|
+
pg.members(
|
906
|
+
dict(
|
907
|
+
cached=(
|
908
|
+
pg.typing.Object(
|
909
|
+
UsageSummary.AggregatedUsage,
|
910
|
+
default=UsageSummary.AggregatedUsage()
|
911
|
+
),
|
912
|
+
'Aggregated usages for cached LLM calls.'
|
913
|
+
),
|
914
|
+
uncached=(
|
915
|
+
pg.typing.Object(
|
916
|
+
UsageSummary.AggregatedUsage,
|
917
|
+
default=UsageSummary.AggregatedUsage()
|
918
|
+
),
|
919
|
+
'Aggregated usages for uncached LLM calls.'
|
920
|
+
),
|
921
|
+
)
|
922
|
+
)(UsageSummary)
|
923
|
+
|
924
|
+
|
803
925
|
class _UsageTracker:
|
804
926
|
"""Usage tracker."""
|
805
927
|
|
806
928
|
def __init__(self, model_ids: set[str] | None):
|
807
929
|
self.model_ids = model_ids
|
930
|
+
self.usage_summary = UsageSummary()
|
808
931
|
self._lock = threading.Lock()
|
809
|
-
|
810
|
-
|
811
|
-
|
812
|
-
|
813
|
-
|
814
|
-
if self.model_ids is not None and model_id not in self.model_ids:
|
815
|
-
return
|
816
|
-
with self._lock:
|
817
|
-
if not isinstance(usage, UsageNotAvailable) and model_id in self.usages:
|
818
|
-
self.usages[model_id] += usage
|
819
|
-
else:
|
820
|
-
self.usages[model_id] = usage
|
932
|
+
|
933
|
+
def track(self, model_id: str, usage: LMSamplingUsage, is_cached: bool):
|
934
|
+
if self.model_ids is None or model_id in self.model_ids:
|
935
|
+
with self._lock:
|
936
|
+
self.usage_summary.update(model_id, usage, is_cached)
|
821
937
|
|
822
938
|
|
823
939
|
@contextlib.contextmanager
|
824
940
|
def track_usages(
|
825
941
|
*lm: Union[str, LanguageModel]
|
826
|
-
) -> Iterator[
|
942
|
+
) -> Iterator[UsageSummary]:
|
827
943
|
"""Context manager to track the usages of all language models in scope.
|
828
944
|
|
829
945
|
`lf.track_usages` works with threads spawned by `lf.concurrent_map` and
|
@@ -854,6 +970,6 @@ def track_usages(
|
|
854
970
|
tracker = _UsageTracker(set(model_ids) if model_ids else None)
|
855
971
|
with component.context(__usage_trackers__=trackers + [tracker]):
|
856
972
|
try:
|
857
|
-
yield tracker.
|
973
|
+
yield tracker.usage_summary
|
858
974
|
finally:
|
859
975
|
pass
|
langfun/core/language_model_test.py
CHANGED
@@ -49,6 +49,7 @@ class MockModel(lm_lib.LanguageModel):
|
|
49
49
|
prompt_tokens=100,
|
50
50
|
completion_tokens=100,
|
51
51
|
total_tokens=200,
|
52
|
+
estimated_cost=1.0,
|
52
53
|
),
|
53
54
|
)
|
54
55
|
for prompt in prompts
|
@@ -128,14 +129,15 @@ class LanguageModelTest(unittest.TestCase):
|
|
128
129
|
'foo',
|
129
130
|
score=-1.0,
|
130
131
|
logprobs=None,
|
131
|
-
|
132
|
+
is_cached=False,
|
133
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
132
134
|
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
133
135
|
),
|
134
136
|
score=-1.0,
|
135
137
|
logprobs=None,
|
136
138
|
)
|
137
139
|
],
|
138
|
-
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
140
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
139
141
|
),
|
140
142
|
lm_lib.LMSamplingResult(
|
141
143
|
[
|
@@ -144,14 +146,15 @@ class LanguageModelTest(unittest.TestCase):
|
|
144
146
|
'bar',
|
145
147
|
score=-1.0,
|
146
148
|
logprobs=None,
|
147
|
-
|
149
|
+
is_cached=False,
|
150
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
148
151
|
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
149
152
|
),
|
150
153
|
score=-1.0,
|
151
154
|
logprobs=None,
|
152
155
|
)
|
153
156
|
],
|
154
|
-
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
157
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
155
158
|
),
|
156
159
|
],
|
157
160
|
)
|
@@ -169,14 +172,15 @@ class LanguageModelTest(unittest.TestCase):
|
|
169
172
|
'foo' * 2,
|
170
173
|
score=0.5,
|
171
174
|
logprobs=None,
|
172
|
-
|
175
|
+
is_cached=False,
|
176
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
173
177
|
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
174
178
|
),
|
175
179
|
score=0.5,
|
176
180
|
logprobs=None,
|
177
181
|
),
|
178
182
|
],
|
179
|
-
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
183
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
180
184
|
),
|
181
185
|
lm_lib.LMSamplingResult(
|
182
186
|
[
|
@@ -185,7 +189,8 @@ class LanguageModelTest(unittest.TestCase):
|
|
185
189
|
'bar' * 2,
|
186
190
|
score=0.5,
|
187
191
|
logprobs=None,
|
188
|
-
|
192
|
+
is_cached=False,
|
193
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
189
194
|
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
190
195
|
),
|
191
196
|
score=0.5,
|
@@ -193,7 +198,8 @@ class LanguageModelTest(unittest.TestCase):
|
|
193
198
|
),
|
194
199
|
],
|
195
200
|
usage=lm_lib.LMSamplingUsage(
|
196
|
-
prompt_tokens=100, completion_tokens=100, total_tokens=200
|
201
|
+
prompt_tokens=100, completion_tokens=100, total_tokens=200,
|
202
|
+
num_requests=1, estimated_cost=1.0,
|
197
203
|
),
|
198
204
|
),
|
199
205
|
]
|
@@ -209,14 +215,15 @@ class LanguageModelTest(unittest.TestCase):
|
|
209
215
|
'foo',
|
210
216
|
score=1.0,
|
211
217
|
logprobs=None,
|
212
|
-
|
218
|
+
is_cached=False,
|
219
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
213
220
|
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
214
221
|
),
|
215
222
|
score=1.0,
|
216
223
|
logprobs=None,
|
217
224
|
),
|
218
225
|
],
|
219
|
-
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
226
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
220
227
|
),
|
221
228
|
lm_lib.LMSamplingResult(
|
222
229
|
[
|
@@ -225,7 +232,8 @@ class LanguageModelTest(unittest.TestCase):
|
|
225
232
|
'bar',
|
226
233
|
score=1.0,
|
227
234
|
logprobs=None,
|
228
|
-
|
235
|
+
is_cached=False,
|
236
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
229
237
|
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
230
238
|
),
|
231
239
|
score=1.0,
|
@@ -233,7 +241,8 @@ class LanguageModelTest(unittest.TestCase):
|
|
233
241
|
),
|
234
242
|
],
|
235
243
|
usage=lm_lib.LMSamplingUsage(
|
236
|
-
prompt_tokens=100, completion_tokens=100, total_tokens=200
|
244
|
+
prompt_tokens=100, completion_tokens=100, total_tokens=200,
|
245
|
+
num_requests=1, estimated_cost=1.0,
|
237
246
|
),
|
238
247
|
),
|
239
248
|
]
|
@@ -248,14 +257,15 @@ class LanguageModelTest(unittest.TestCase):
|
|
248
257
|
'foo' * 2,
|
249
258
|
score=0.7,
|
250
259
|
logprobs=None,
|
251
|
-
|
260
|
+
is_cached=False,
|
261
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
252
262
|
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
253
263
|
),
|
254
264
|
score=0.7,
|
255
265
|
logprobs=None,
|
256
266
|
),
|
257
267
|
],
|
258
|
-
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
268
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
259
269
|
),
|
260
270
|
lm_lib.LMSamplingResult(
|
261
271
|
[
|
@@ -264,7 +274,8 @@ class LanguageModelTest(unittest.TestCase):
|
|
264
274
|
'bar' * 2,
|
265
275
|
score=0.7,
|
266
276
|
logprobs=None,
|
267
|
-
|
277
|
+
is_cached=False,
|
278
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
268
279
|
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
269
280
|
),
|
270
281
|
score=0.7,
|
@@ -272,7 +283,8 @@ class LanguageModelTest(unittest.TestCase):
|
|
272
283
|
),
|
273
284
|
],
|
274
285
|
usage=lm_lib.LMSamplingUsage(
|
275
|
-
prompt_tokens=100, completion_tokens=100, total_tokens=200
|
286
|
+
prompt_tokens=100, completion_tokens=100, total_tokens=200,
|
287
|
+
num_requests=1, estimated_cost=1.0,
|
276
288
|
),
|
277
289
|
),
|
278
290
|
]
|
@@ -284,7 +296,9 @@ class LanguageModelTest(unittest.TestCase):
|
|
284
296
|
self.assertEqual(response.text, 'foo')
|
285
297
|
self.assertEqual(response.score, -1.0)
|
286
298
|
self.assertIsNone(response.logprobs)
|
287
|
-
self.assertEqual(
|
299
|
+
self.assertEqual(
|
300
|
+
response.usage, lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0)
|
301
|
+
)
|
288
302
|
|
289
303
|
# Test override sampling_options.
|
290
304
|
self.assertEqual(
|
@@ -307,14 +321,17 @@ class LanguageModelTest(unittest.TestCase):
|
|
307
321
|
cache_seed=0,
|
308
322
|
score=-1.0,
|
309
323
|
logprobs=None,
|
310
|
-
|
324
|
+
is_cached=False,
|
325
|
+
usage=lm_lib.LMSamplingUsage(
|
326
|
+
100, 100, 200, 1, 1.0
|
327
|
+
),
|
311
328
|
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
312
329
|
),
|
313
330
|
score=-1.0,
|
314
331
|
logprobs=None,
|
315
332
|
)
|
316
333
|
],
|
317
|
-
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
334
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
318
335
|
),
|
319
336
|
lm_lib.LMSamplingResult(
|
320
337
|
[
|
@@ -324,14 +341,15 @@ class LanguageModelTest(unittest.TestCase):
|
|
324
341
|
cache_seed=0,
|
325
342
|
score=-1.0,
|
326
343
|
logprobs=None,
|
327
|
-
|
344
|
+
is_cached=False,
|
345
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
328
346
|
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
329
347
|
),
|
330
348
|
score=-1.0,
|
331
349
|
logprobs=None,
|
332
350
|
)
|
333
351
|
],
|
334
|
-
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
352
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
335
353
|
),
|
336
354
|
],
|
337
355
|
)
|
@@ -339,7 +357,9 @@ class LanguageModelTest(unittest.TestCase):
|
|
339
357
|
self.assertEqual(cache.stats.num_hits, 0)
|
340
358
|
self.assertEqual(cache.stats.num_updates, 2)
|
341
359
|
|
342
|
-
|
360
|
+
result = lm('foo')
|
361
|
+
self.assertEqual(result, 'foo')
|
362
|
+
self.assertTrue(result.metadata.is_cached)
|
343
363
|
self.assertEqual(lm('bar'), 'bar')
|
344
364
|
self.assertEqual(cache.stats.num_queries, 4)
|
345
365
|
self.assertEqual(cache.stats.num_hits, 2)
|
@@ -361,14 +381,15 @@ class LanguageModelTest(unittest.TestCase):
|
|
361
381
|
cache_seed=0,
|
362
382
|
score=1.0,
|
363
383
|
logprobs=None,
|
364
|
-
|
384
|
+
is_cached=False,
|
385
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
365
386
|
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
366
387
|
),
|
367
388
|
score=1.0,
|
368
389
|
logprobs=None,
|
369
390
|
)
|
370
391
|
],
|
371
|
-
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
392
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
372
393
|
),
|
373
394
|
lm_lib.LMSamplingResult(
|
374
395
|
[
|
@@ -378,14 +399,15 @@ class LanguageModelTest(unittest.TestCase):
|
|
378
399
|
cache_seed=0,
|
379
400
|
score=1.0,
|
380
401
|
logprobs=None,
|
381
|
-
|
402
|
+
is_cached=False,
|
403
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
382
404
|
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
383
405
|
),
|
384
406
|
score=1.0,
|
385
407
|
logprobs=None,
|
386
408
|
)
|
387
409
|
],
|
388
|
-
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
410
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
389
411
|
),
|
390
412
|
],
|
391
413
|
)
|
@@ -663,20 +685,128 @@ class LanguageModelTest(unittest.TestCase):
|
|
663
685
|
lm2('hi')
|
664
686
|
list(concurrent.concurrent_map(call_lm, ['hi', 'hello']))
|
665
687
|
|
666
|
-
|
667
|
-
|
688
|
+
print(usages2)
|
689
|
+
self.assertEqual(usages2.uncached.breakdown, {
|
690
|
+
'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
691
|
+
})
|
692
|
+
self.assertFalse(usages2.cached)
|
693
|
+
self.assertEqual(usages3.uncached.breakdown, {
|
694
|
+
'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4, 4, 4.0),
|
668
695
|
})
|
669
|
-
self.
|
670
|
-
|
696
|
+
self.assertFalse(usages3.cached)
|
697
|
+
self.assertEqual(usages4.uncached.breakdown, {
|
698
|
+
'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4, 4, 4.0),
|
699
|
+
'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
671
700
|
})
|
672
|
-
self.
|
673
|
-
|
674
|
-
'
|
701
|
+
self.assertFalse(usages4.cached)
|
702
|
+
self.assertEqual(usages1.uncached.breakdown, {
|
703
|
+
'model1': lm_lib.LMSamplingUsage(100 * 5, 100 * 5, 200 * 5, 5, 5.0),
|
704
|
+
'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
675
705
|
})
|
676
|
-
self.
|
677
|
-
|
678
|
-
|
706
|
+
self.assertFalse(usages1.cached)
|
707
|
+
self.assertEqual(
|
708
|
+
usages1.total,
|
709
|
+
lm_lib.LMSamplingUsage(100 * 6, 100 * 6, 200 * 6, 6, 6.0),
|
710
|
+
)
|
711
|
+
|
712
|
+
cache = in_memory.InMemory()
|
713
|
+
lm = MockModel(cache=cache, name='model1')
|
714
|
+
with lm_lib.track_usages() as usages1:
|
715
|
+
_ = lm('hi')
|
716
|
+
self.assertEqual(usages1.uncached.breakdown, {
|
717
|
+
'model1': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
|
679
718
|
})
|
719
|
+
self.assertFalse(usages1.cached)
|
720
|
+
with lm_lib.track_usages() as usages2:
|
721
|
+
_ = lm('hi')
|
722
|
+
self.assertEqual(usages2.cached.breakdown, {
|
723
|
+
'model1': lm_lib.LMSamplingUsage(100, 100, 200, 1, 0.0),
|
724
|
+
})
|
725
|
+
self.assertFalse(usages2.uncached)
|
726
|
+
|
727
|
+
|
728
|
+
class LMSamplingUsageTest(unittest.TestCase):
|
729
|
+
|
730
|
+
def test_basics(self):
|
731
|
+
usage = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
|
732
|
+
self.assertEqual(usage.num_requests, 4)
|
733
|
+
self.assertEqual(usage.prompt_tokens, 100)
|
734
|
+
self.assertEqual(usage.completion_tokens, 200)
|
735
|
+
self.assertEqual(usage.total_tokens, 300)
|
736
|
+
self.assertEqual(usage.estimated_cost, 5.0)
|
737
|
+
self.assertEqual(usage.average_prompt_tokens, 25)
|
738
|
+
self.assertEqual(usage.average_completion_tokens, 50)
|
739
|
+
self.assertEqual(usage.average_total_tokens, 75)
|
740
|
+
self.assertEqual(usage.average_estimated_cost, 1.25)
|
741
|
+
|
742
|
+
def test_add(self):
|
743
|
+
usage1 = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
|
744
|
+
usage2 = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
|
745
|
+
self.assertEqual(usage1 + usage2, usage1 + usage2)
|
746
|
+
self.assertIs(usage1 + None, usage1)
|
747
|
+
self.assertIs(None + usage1, usage1)
|
748
|
+
|
749
|
+
def test_usage_not_available(self):
|
750
|
+
usage_not_available = lm_lib.UsageNotAvailable()
|
751
|
+
self.assertEqual(usage_not_available.prompt_tokens, 0)
|
752
|
+
self.assertEqual(usage_not_available.completion_tokens, 0)
|
753
|
+
self.assertEqual(usage_not_available.total_tokens, 0)
|
754
|
+
self.assertEqual(usage_not_available.average_prompt_tokens, 0)
|
755
|
+
self.assertEqual(usage_not_available.average_completion_tokens, 0)
|
756
|
+
self.assertEqual(usage_not_available.average_total_tokens, 0)
|
757
|
+
self.assertIsNone(usage_not_available.average_estimated_cost)
|
758
|
+
self.assertTrue(usage_not_available)
|
759
|
+
self.assertEqual(
|
760
|
+
usage_not_available + lm_lib.LMSamplingUsage(1, 2, 3, 4, 5.0),
|
761
|
+
lm_lib.UsageNotAvailable(num_requests=5)
|
762
|
+
)
|
763
|
+
self.assertEqual(
|
764
|
+
lm_lib.LMSamplingUsage(1, 2, 3, 4, 5.0) + usage_not_available,
|
765
|
+
lm_lib.UsageNotAvailable(num_requests=5)
|
766
|
+
)
|
767
|
+
self.assertIs(None + usage_not_available, usage_not_available)
|
768
|
+
self.assertIs(usage_not_available + None, usage_not_available)
|
769
|
+
|
770
|
+
|
771
|
+
class UsageSummaryTest(unittest.TestCase):
|
772
|
+
|
773
|
+
def test_basics(self):
|
774
|
+
usage_summary = lm_lib.UsageSummary()
|
775
|
+
self.assertFalse(usage_summary.total)
|
776
|
+
self.assertFalse(usage_summary.cached)
|
777
|
+
self.assertFalse(usage_summary.uncached)
|
778
|
+
|
779
|
+
# Add uncached.
|
780
|
+
usage_summary.update(
|
781
|
+
'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
|
782
|
+
)
|
783
|
+
self.assertEqual(
|
784
|
+
usage_summary.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0)
|
785
|
+
)
|
786
|
+
self.assertEqual(
|
787
|
+
usage_summary.uncached.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0)
|
788
|
+
)
|
789
|
+
# Add cached.
|
790
|
+
self.assertFalse(usage_summary.cached)
|
791
|
+
usage_summary.update(
|
792
|
+
'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), True
|
793
|
+
)
|
794
|
+
self.assertEqual(
|
795
|
+
usage_summary.total, lm_lib.LMSamplingUsage(2, 4, 6, 2, 5.0)
|
796
|
+
)
|
797
|
+
self.assertEqual(
|
798
|
+
usage_summary.cached.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 0.0)
|
799
|
+
)
|
800
|
+
# Add UsageNotAvailable.
|
801
|
+
usage_summary.update(
|
802
|
+
'model1', lm_lib.UsageNotAvailable(num_requests=1), False
|
803
|
+
)
|
804
|
+
self.assertEqual(
|
805
|
+
usage_summary.total, lm_lib.UsageNotAvailable(num_requests=3)
|
806
|
+
)
|
807
|
+
self.assertEqual(
|
808
|
+
usage_summary.uncached.total, lm_lib.UsageNotAvailable(num_requests=2)
|
809
|
+
)
|
680
810
|
|
681
811
|
|
682
812
|
if __name__ == '__main__':
|