langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +22 -2
- langfun/core/__init__.py +17 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -28
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +69 -2
- langfun/core/component_test.py +54 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +17 -0
- langfun/core/eval/base.py +767 -140
- langfun/core/eval/base_test.py +238 -53
- langfun/core/eval/matching.py +80 -76
- langfun/core/eval/matching_test.py +19 -9
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +37 -28
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +3 -21
- langfun/core/langfunc_test.py +26 -8
- langfun/core/language_model.py +686 -48
- langfun/core/language_model_test.py +681 -44
- langfun/core/llms/__init__.py +100 -12
- langfun/core/llms/anthropic.py +488 -0
- langfun/core/llms/anthropic_test.py +235 -0
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +88 -28
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +39 -26
- langfun/core/llms/fake_test.py +136 -11
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -197
- langfun/core/llms/groq.py +276 -0
- langfun/core/llms/groq_test.py +64 -0
- langfun/core/llms/llama_cpp.py +15 -40
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +436 -226
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +35 -174
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -23
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +15 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +9 -8
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +278 -0
- langfun/core/structured/function_generation_test.py +399 -0
- langfun/core/structured/mapping.py +150 -46
- langfun/core/structured/mapping_test.py +105 -0
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +71 -22
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
- langfun/core/structured/schema.py +208 -99
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +133 -34
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +240 -11
- langfun/core/template_test.py +146 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +14 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -217
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
langfun/core/language_model.py
CHANGED
@@ -14,16 +14,50 @@
 """Interface for language model."""

 import abc
+import contextlib
 import dataclasses
 import enum
+import functools
+import math
+import threading
 import time
-from typing import Annotated, Any, Callable, Sequence, Tuple, Type, Union
+from typing import Annotated, Any, Callable, Iterator, Optional, Sequence, Tuple, Type, Union
 from langfun.core import component
 from langfun.core import concurrent
 from langfun.core import console
 from langfun.core import message as message_lib
+
 import pyglove as pg

+TOKENS_PER_REQUEST = 250  # Estimated num tokens for a single request
+DEFAULT_MAX_CONCURRENCY = 1  # Use this as max concurrency if no RPM or TPM data
+
+
+#
+# Common errors during calling language models.
+#
+
+
+class LMError(RuntimeError):
+  """Base class for language model errors."""
+
+
+class RetryableLMError(LMError):
+  """Base class for LLM errors that can be solved by retrying."""
+
+
+class RateLimitError(RetryableLMError):
+  """Error for rate limit reached."""
+
+
+class TemporaryLMError(RetryableLMError):
+  """Error for temporary service issues that can be retried."""
+
+
+#
+# Language model input/output interfaces.
+#
+

 class LMSample(pg.Object):
   """Response candidate."""
@@ -47,6 +81,142 @@ class LMSample(pg.Object):
   ] = None


+class RetryStats(pg.Object):
+  """Retry stats, which is aggregated across multiple retry entries."""
+
+  num_occurences: Annotated[
+      int,
+      'Total number of retry attempts on LLM (excluding the first attempt).',
+  ] = 0
+  total_wait_interval: Annotated[
+      float, 'Total wait interval in seconds due to retry.'
+  ] = 0
+  total_call_interval: Annotated[
+      float, 'Total LLM call interval in seconds.'
+  ] = 0
+  errors: Annotated[
+      dict[str, int],
+      'A Counter of error types encountered during the retry attempts.',
+  ] = {}
+
+  @classmethod
+  def from_retry_entries(
+      cls, retry_entries: Sequence[concurrent.RetryEntry]
+  ) -> 'RetryStats':
+    """Creates a RetryStats from a sequence of RetryEntry."""
+    if not retry_entries:
+      return RetryStats()
+    errors = {}
+    for retry in retry_entries:
+      if retry.error is not None:
+        errors[retry.error.__class__.__name__] = (
+            errors.get(retry.error.__class__.__name__, 0) + 1
+        )
+    return RetryStats(
+        num_occurences=len(retry_entries) - 1,
+        total_wait_interval=sum(e.wait_interval for e in retry_entries),
+        total_call_interval=sum(e.call_interval for e in retry_entries),
+        errors=errors,
+    )
+
+  def __add__(self, other: 'RetryStats') -> 'RetryStats':
+    errors = self.errors.copy()
+    for error, count in other.errors.items():
+      errors[error] = errors.get(error, 0) + count
+    return RetryStats(
+        num_occurences=self.num_occurences + other.num_occurences,
+        total_wait_interval=self.total_wait_interval
+        + other.total_wait_interval,
+        total_call_interval=self.total_call_interval
+        + other.total_call_interval,
+        errors=errors,
+    )
+
+  def __radd__(self, other: 'RetryStats') -> 'RetryStats':
+    return self + other
+
+
+class LMSamplingUsage(pg.Object):
+  """Usage information per completion."""
+
+  prompt_tokens: int
+  completion_tokens: int
+  total_tokens: int
+  num_requests: int = 1
+  estimated_cost: Annotated[
+      float | None,
+      (
+          'Estimated cost in US dollars. If None, cost estimating is not '
+          'suppported on the model being queried.'
+      ),
+  ] = None
+  retry_stats: RetryStats = RetryStats()
+
+  def __bool__(self) -> bool:
+    return self.num_requests > 0
+
+  @property
+  def average_prompt_tokens(self) -> int:
+    """Returns the average prompt tokens per request."""
+    return self.prompt_tokens // self.num_requests
+
+  @property
+  def average_completion_tokens(self) -> int:
+    """Returns the average completion tokens per request."""
+    return self.completion_tokens // self.num_requests
+
+  @property
+  def average_total_tokens(self) -> int:
+    """Returns the average total tokens per request."""
+    return self.total_tokens // self.num_requests
+
+  @property
+  def average_estimated_cost(self) -> float | None:
+    """Returns the average estimated cost per request."""
+    if self.estimated_cost is None:
+      return None
+    return self.estimated_cost / self.num_requests
+
+  def __add__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage':
+    if other is None:
+      return self
+    if self.estimated_cost is None:
+      estimated_cost = other.estimated_cost
+    elif other.estimated_cost is None:
+      estimated_cost = self.estimated_cost
+    else:
+      estimated_cost = self.estimated_cost + other.estimated_cost
+    return LMSamplingUsage(
+        prompt_tokens=self.prompt_tokens + other.prompt_tokens,
+        completion_tokens=self.completion_tokens + other.completion_tokens,
+        total_tokens=self.total_tokens + other.total_tokens,
+        num_requests=self.num_requests + other.num_requests,
+        estimated_cost=estimated_cost,
+        retry_stats=self.retry_stats + other.retry_stats,
+    )
+
+  def __radd__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage':
+    return self + other
+
+
+class UsageNotAvailable(LMSamplingUsage):
+  """Usage information not available."""
+  prompt_tokens: pg.typing.Int(0).freeze()  # pytype: disable=invalid-annotation
+  completion_tokens: pg.typing.Int(0).freeze()  # pytype: disable=invalid-annotation
+  total_tokens: pg.typing.Int(0).freeze()  # pytype: disable=invalid-annotation
+  estimated_cost: pg.typing.Float(default=None, is_noneable=True).freeze()  # pytype: disable=invalid-annotation
+
+  def __add__(self, other: Optional['LMSamplingUsage']) -> 'UsageNotAvailable':
+    if other is None:
+      return self
+    return UsageNotAvailable(
+        num_requests=self.num_requests + other.num_requests
+    )
+
+  def __radd__(self, other: Optional['LMSamplingUsage']) -> 'UsageNotAvailable':
+    return self + other
+
+
 class LMSamplingResult(pg.Object):
   """Language model response."""

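Both `RetryStats` and `LMSamplingUsage` in the hunk above define `__add__`/`__radd__`, so usage records can be summed across requests and then averaged per request. A minimal sketch of those semantics (token counts and costs are made up; the module is imported directly since the symbols live in `langfun/core/language_model.py`):

```python
from langfun.core import language_model as lm_lib

a = lm_lib.LMSamplingUsage(
    prompt_tokens=100, completion_tokens=20, total_tokens=120,
    num_requests=1, estimated_cost=0.002)
b = lm_lib.LMSamplingUsage(
    prompt_tokens=300, completion_tokens=60, total_tokens=360,
    num_requests=3, estimated_cost=None)

total = a + b                              # token and request counts are summed
assert total.total_tokens == 480
assert total.num_requests == 4
assert total.estimated_cost == 0.002       # a None cost on either side is skipped
assert total.average_total_tokens == 120   # 480 // 4

# UsageNotAvailable keeps its token fields frozen at 0 but still counts requests.
na = lm_lib.UsageNotAvailable() + lm_lib.LMSamplingUsage(10, 2, 12)
assert isinstance(na, lm_lib.UsageNotAvailable)
assert na.num_requests == 2
```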
@@ -58,19 +228,39 @@ class LMSamplingResult(pg.Object):
       ),
   ] = []

+  usage: Annotated[
+      LMSamplingUsage,
+      'Usage information. Currently only OpenAI models are supported.',
+  ] = UsageNotAvailable()
+
+  is_cached: Annotated[
+      bool,
+      'Whether the result is from cache or not.'
+  ] = False
+

 class LMSamplingOptions(component.Component):
   """Language model sampling options."""

   temperature: Annotated[
-      float,
+      float | None,
       (
           'Model temperature, which is usually between 0 and 1.0. '
-          'OpenAI models have temperature range from 0.0 to 2.0.'
+          'OpenAI models have temperature range from 0.0 to 2.0. '
+          'If None (default), honor the model\'s default behavior. '
       )
-  ] = 0.0
-  max_tokens: Annotated[int, 'Per example max tokens to generate.'] = 1024
+  ] = None
+
+  max_tokens: Annotated[
+      int | None,
+      (
+          'Per example max tokens to generate. '
+          'If None, use the model default.'
+      )
+  ] = None
+
   n: Annotated[int | None, 'Max number of samples to return.'] = 1
+
   top_k: Annotated[
       int | None,
       (
@@ -78,6 +268,7 @@ class LMSamplingOptions(component.Component):
           'Not applicable to OpenAI models.'
       )
   ] = 40
+
   top_p: Annotated[
       float | None,
       (
@@ -86,6 +277,7 @@ class LMSamplingOptions(component.Component):
           '`top_p` but not both.'
       ),
   ] = None
+
   stop: Annotated[
       list[str] | None,
       (
@@ -95,9 +287,11 @@ class LMSamplingOptions(component.Component):
           '`Model:` is reached.'
       ),
   ] = None
+
   random_seed: Annotated[
       int | None, 'A fixed random seed used during model inference.'
   ] = None
+
   logprobs: Annotated[
       bool,
       (
@@ -106,6 +300,7 @@ class LMSamplingOptions(component.Component):
           'in the content of message.'
       ),
   ] = False
+
   top_logprobs: Annotated[
       int | None,
       (
@@ -135,6 +330,11 @@ class LMScoringResult(pg.Object):
       float,
       'The log likelyhood of the requested completion towards the prompt.',
   ]
+  gradients: Annotated[
+      Any | None,
+      '(Optional) gradients from the score method, w.r.t.' +
+      ' prompt.metadata.weights.',
+  ] = None


 class LMCache(pg.Object):
@@ -149,6 +349,7 @@ class LMCache(pg.Object):
     num_hit_expires: int = 0
     num_misses: int = 0
     num_updates: int = 0
+    num_deletes: int = 0

   @abc.abstractmethod
   def get(
@@ -166,6 +367,15 @@ class LMCache(pg.Object):
   ) -> None:
     """Puts the result of a prompt generated by a language model in cache."""

+  @abc.abstractmethod
+  def delete(
+      self,
+      lm: 'LanguageModel',
+      prompt: message_lib.Message,
+      seed: int,
+  ) -> bool:
+    """Deletes the result of a prompt generated by a language model in cache."""
+
   @property
   @abc.abstractmethod
   def stats(self) -> Stats:
@@ -259,6 +469,15 @@ class LanguageModel(component.Component):
       )
   ] = True

+  max_retry_interval: Annotated[
+      int,
+      (
+          'The max retry interval in seconds. This is useful when the retry '
+          'interval is exponential, to avoid the wait time to grow '
+          'exponentially.'
+      )
+  ] = 300
+
   debug: Annotated[
       bool | LMDebugMode,
       (
@@ -272,7 +491,10 @@ class LanguageModel(component.Component):
   def __init__(self, *args, **kwargs) -> None:
     """Overrides __init__ to pass through **kwargs to sampling options."""

-    sampling_options = kwargs.pop('sampling_options', LMSamplingOptions())
+    sampling_options = kwargs.pop(
+        'sampling_options',
+        pg.clone(self.__schema__.fields['sampling_options'].default_value)
+    )
     sampling_options_delta = {}

     for k, v in kwargs.items():
@@ -315,9 +537,64 @@ class LanguageModel(component.Component):

     with component.context(override_attrs=True, **kwargs):
       if self.cache is None:
-        return self._sample(prompts)
+        results = self._sample(prompts)
       else:
-        return self._sample_with_cache_lookup(prompts, cache_seed)
+        results = self._sample_with_cache_lookup(prompts, cache_seed)
+
+      for prompt, result in zip(prompts, results):
+
+        # Tag LM input.
+        prompt.tag(message_lib.Message.TAG_LM_INPUT)
+
+        for sample in result.samples:
+          # Update metadata for response message.
+
+          response = sample.response
+          response.metadata.score = sample.score
+          response.metadata.logprobs = sample.logprobs
+          response.metadata.is_cached = result.is_cached
+
+          # NOTE(daiyip): Current usage is computed at per-result level,
+          # which is accurate when n=1. For n > 1, we average the usage across
+          # multiple samples.
+          usage = result.usage
+          if len(result.samples) == 1 or isinstance(usage, UsageNotAvailable):
+            response.metadata.usage = usage
+          else:
+            n = len(result.samples)
+            response.metadata.usage = LMSamplingUsage(
+                prompt_tokens=usage.prompt_tokens // n,
+                completion_tokens=usage.completion_tokens // n,
+                total_tokens=usage.total_tokens // n,
+                estimated_cost=(
+                    usage.estimated_cost / n if usage.estimated_cost else None
+                ),
+                retry_stats=RetryStats(
+                    num_occurences=usage.retry_stats.num_occurences // n,
+                    total_wait_interval=usage.retry_stats.total_wait_interval
+                    / n,
+                    total_call_interval=usage.retry_stats.total_call_interval
+                    / n,
+                    errors={
+                        error: count // n
+                        for error, count in usage.retry_stats.errors.items()
+                    },
+                ),
+            )
+
+          # Track usage.
+          trackers = component.context_value('__usage_trackers__', [])
+          if trackers:
+            model_id = self.model_id
+            for tracker in trackers:
+              tracker.track(model_id, usage, result.is_cached)
+
+          # Track the prompt for corresponding response.
+          response.source = prompt
+
+          # Tag LM response.
+          response.tag(message_lib.Message.TAG_LM_RESPONSE)
+      return results

   def _sample_with_cache_lookup(
       self, prompts: list[str | message_lib.Message], cache_seed: int
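The reworked `sample()` above tags each prompt as LM input and attaches score, logprobs, cache status and (per-sample averaged) usage to every response's metadata. A rough sketch of what a caller sees, assuming the `Echo` fake model from `langfun/core/llms/fake.py` (any `LanguageModel` subclass behaves the same; field names follow this hunk):

```python
import langfun as lf

lm = lf.llms.Echo()                        # fake model: echoes the prompt back
[result] = lm.sample(['What is 1 + 1?'])   # one LMSamplingResult per prompt
response = result.samples[0].response

# Metadata populated by the loop above.
print(response.metadata.score)             # per-sample score
print(response.metadata.logprobs)          # None unless logprobs were requested
print(response.metadata.is_cached)         # False unless served from an LMCache
print(response.metadata.usage)             # LMSamplingUsage or UsageNotAvailable
print(response.source)                     # the originating prompt message
```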
@@ -339,7 +616,9 @@ class LanguageModel(component.Component):
         request_to_result_index[len(requests)] = i
         requests.append(prompt)
       else:
-        results[i] = r.clone()
+        result = r.clone()
+        assert result.is_cached, result
+        results[i] = result

     # Sample non-cache-hit prompts.
     if requests:
@@ -356,8 +635,12 @@ class LanguageModel(component.Component):
           sample.response.set('cache_seed', cache_seed)

         if cache_seed is not None:
-          self.cache.put(
-              self, prompt, result.clone(), seed=cache_seed)
+          self.cache.put(
+              self,
+              prompt,
+              result.clone(override=dict(is_cached=True)),
+              seed=cache_seed
+          )
     return results  # pytype: disable=bad-return-type

   @abc.abstractmethod
@@ -369,16 +652,16 @@ class LanguageModel(component.Component):

   def _parallel_execute_with_currency_control(
       self,
-      action: Callable[...,
+      action: Callable[..., LMSamplingResult],
       inputs: Sequence[Any],
       retry_on_errors: Union[
           None,
-          Union[Type[
-          Sequence[Union[Type[
-      ] =
-  ) -> Any:
+          Union[Type[BaseException], Tuple[Type[BaseException], str]],
+          Sequence[Union[Type[BaseException], Tuple[Type[BaseException], str]]],
+      ] = RetryableLMError,
+  ) -> list[Any]:
     """Helper method for subclasses for implementing _sample."""
-    return concurrent.concurrent_execute(
+    executed_jobs = concurrent.concurrent_execute(
         action,
         inputs,
         executor=self.resource_id if self.max_concurrency else None,
@@ -387,7 +670,16 @@ class LanguageModel(component.Component):
         max_attempts=self.max_attempts,
         retry_interval=self.retry_interval,
         exponential_backoff=self.exponential_backoff,
+        max_retry_interval=self.max_retry_interval,
+        return_jobs=True,
     )
+    for job in executed_jobs:
+      if isinstance(job.result, LMSamplingResult):
+        job.result.usage.rebind(
+            retry_stats=RetryStats.from_retry_entries(job.retry_entries),
+            skip_notification=True,
+        )
+    return [job.result for job in executed_jobs]

   def __call__(
       self, prompt: message_lib.Message, *, cache_seed: int = 0, **kwargs
@@ -405,12 +697,9 @@ class LanguageModel(component.Component):
       result = self.sample(
           [prompt], sampling_options=sampling_options, cache_seed=cache_seed
       )[0]
-      response = result.samples[0].response
-      logprobs = result.samples[0].logprobs
-      response.set('score', result.samples[0].score)
-      response.metadata.logprobs = logprobs
       elapse = time.time() - request_start
-      self._debug(prompt, response, call_counter, elapse)
+      response = result.samples[0].response
+      self._debug(prompt, response, call_counter, result.usage, elapse)
       return response

   def _debug(
@@ -418,35 +707,54 @@ class LanguageModel(component.Component):
       prompt: message_lib.Message,
       response: message_lib.Message,
       call_counter: int,
+      usage: LMSamplingUsage,
       elapse: float,
-  ):
+  ) -> None:
     """Outputs debugging information."""
     debug = self.debug
     if isinstance(debug, bool):
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE

     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter)
+      self._debug_model_info(call_counter, usage)

     if debug & LMDebugMode.PROMPT:
-      self._debug_prompt(prompt, call_counter)
+      self._debug_prompt(prompt, call_counter, usage)

     if debug & LMDebugMode.RESPONSE:
-      self._debug_response(response, call_counter, elapse)
+      self._debug_response(response, call_counter, usage, elapse)

-  def _debug_model_info(self, call_counter: int):
+  def _debug_model_info(
+      self, call_counter: int, usage: LMSamplingUsage) -> None:
     """Outputs debugging information about the model."""
+    title_suffix = ''
+    if usage.total_tokens != 0:
+      title_suffix = pg.colored(
+          f' (total {usage.total_tokens} tokens)', 'red'
+      )
+
     console.write(
         self.format(compact=True, use_inferred=True),
-        title=f'[{call_counter}] LM INFO:',
+        title=f'[{call_counter}] LM INFO{title_suffix}:',
         color='magenta',
     )

-  def _debug_prompt(self, prompt: message_lib.Message, call_counter: int):
+  def _debug_prompt(
+      self,
+      prompt: message_lib.Message,
+      call_counter: int,
+      usage: LMSamplingUsage,
+  ) -> None:
     """Outputs debugging information about the prompt."""
+    title_suffix = ''
+    if usage.prompt_tokens != 0:
+      title_suffix = pg.colored(f' ({usage.prompt_tokens} tokens)', 'red')
+
     console.write(
-        prompt,
-        title=f'\n[{call_counter}] PROMPT SENT TO LM:',
+        # We use metadata 'formatted_text' for scenarios where the prompt text
+        # is formatted by the LM.
+        prompt.get('formatted_text', prompt.text),
+        title=f'\n[{call_counter}] PROMPT SENT TO LM{title_suffix}:',
         color='green',
     )
     referred_modalities = prompt.referred_modalities()
@@ -460,23 +768,40 @@ class LanguageModel(component.Component):
         )

   def _debug_response(
-      self, response: message_lib.Message, call_counter: int, elapse: float
-  ):
+      self,
+      response: message_lib.Message,
+      call_counter: int,
+      usage: LMSamplingUsage,
+      elapse: float
+  ) -> None:
     """Outputs debugging information about the response."""
+    title_suffix = ' ('
+    if usage.completion_tokens != 0:
+      title_suffix += f'{usage.completion_tokens} tokens '
+    title_suffix += f'in {elapse:.2f} seconds)'
+    title_suffix = pg.colored(title_suffix, 'red')
+
     console.write(
         str(response) + '\n',
-        title=f'\n[{call_counter}] LM RESPONSE (in {elapse:.2f} seconds):',
+        title=f'\n[{call_counter}] LM RESPONSE{title_suffix}:',
        color='blue',
     )

   def score(
       self,
-      prompt: str | message_lib.Message,
+      prompt: str | message_lib.Message | list[message_lib.Message],
       completions: list[str | message_lib.Message],
       **kwargs,
   ) -> list[LMScoringResult]:
     """Scores the given prompt."""
-    prompt = message_lib.UserMessage.from_value(prompt)
+    if isinstance(prompt, list):
+      if len(prompt) != len(completions):
+        raise ValueError(
+            'prompt and completions must have the same length.'
+        )
+      prompt = [message_lib.UserMessage.from_value(p) for p in prompt]
+    else:
+      prompt = message_lib.UserMessage.from_value(prompt)
     completions = [message_lib.UserMessage.from_value(c) for c in completions]

     call_counter = self._call_counter
@@ -492,7 +817,8 @@ class LanguageModel(component.Component):
     return scoring_results

   def _score(
-      self, prompt: message_lib.Message, completions: list[message_lib.Message]
+      self, prompt: message_lib.Message | list[message_lib.Message],
+      completions: list[message_lib.Message]
   ) -> list[LMScoringResult]:
     """Subclass to implement."""
     raise NotImplementedError(
@@ -501,7 +827,7 @@ class LanguageModel(component.Component):

   def _debug_score(
       self,
-      prompt: message_lib.Message,
+      prompt: message_lib.Message | list[message_lib.Message],
       completions: list[message_lib.Message],
       scoring_results: list[LMScoringResult],
       call_counter: int,
@@ -512,7 +838,7 @@ class LanguageModel(component.Component):
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE

     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter)
+      self._debug_model_info(call_counter, UsageNotAvailable())

     if debug & LMDebugMode.PROMPT:
       console.write(
@@ -520,15 +846,19 @@ class LanguageModel(component.Component):
           title=f'\n[{call_counter}] SCORING LM WITH PROMPT:',
           color='green',
       )
-      referred_modalities = prompt.referred_modalities()
-      if referred_modalities:
-        console.write(
-            pg.object_utils.kvlist_str(
-                [(k, repr(v), None) for k, v in referred_modalities.items()]
-            ),
-            title=f'\n[{call_counter}] MODALITY OBJECTS SENT TO LM:',
-            color='green',
-        )
+      if isinstance(prompt, list):
+        referred_modalities_lst = [p.referred_modalities() for p in prompt]
+      else:
+        referred_modalities_lst = [prompt.referred_modalities(),]
+      if referred_modalities_lst:
+        for referred_modalities in referred_modalities_lst:
+          console.write(
+              pg.object_utils.kvlist_str(
+                  [(k, repr(v), None) for k, v in referred_modalities.items()]
+              ),
+              title=f'\n[{call_counter}] MODALITY OBJECTS SENT TO LM:',
+              color='green',
+          )

     if debug & LMDebugMode.RESPONSE:
       console.write(
@@ -548,3 +878,311 @@ class LanguageModel(component.Component):
             f'score: {r.score}',
             color='blue',
         )
+
+  def tokenize(
+      self,
+      prompt: str | message_lib.Message,
+      **kwargs,
+  ) -> list[tuple[str | bytes, int]]:
+    """Tokenizes the given prompt."""
+    prompt = message_lib.UserMessage.from_value(prompt)
+    call_counter = self._call_counter
+    self._call_counter += 1
+
+    with component.context(override_attrs=True, **kwargs):
+      request_start = time.time()
+      tokens = self._tokenize(prompt)
+      elapse = time.time() - request_start
+      self._debug_tokenize(prompt, tokens, call_counter, elapse)
+    return tokens
+
+  def _tokenize(
+      self, prompt: message_lib.Message
+  ) -> list[tuple[str | bytes, int]]:
+    """Subclass to implement."""
+    raise NotImplementedError(
+        f'{self.__class__.__name__} does not support tokenization.'
+    )
+
+  def _debug_tokenize(
+      self,
+      prompt: message_lib.Message,
+      tokens: list[tuple[str | bytes, int]],
+      call_counter: int,
+      elapse: float,
+  ):
+    debug = self.debug
+    if isinstance(debug, bool):
+      debug = LMDebugMode.ALL if debug else LMDebugMode.NONE
+
+    if debug & LMDebugMode.INFO:
+      self._debug_model_info(call_counter, UsageNotAvailable())
+
+    if debug & LMDebugMode.PROMPT:
+      console.write(
+          prompt,
+          title=f'\n[{call_counter}] PROMPT TO TOKENIZE:',
+          color='green',
+      )
+      referred_modalities_lst = [prompt.referred_modalities(),]
+      if referred_modalities_lst:
+        for referred_modalities in referred_modalities_lst:
+          console.write(
+              pg.object_utils.kvlist_str(
+                  [(k, repr(v), None) for k, v in referred_modalities.items()]
+              ),
+              title=f'\n[{call_counter}] MODALITY OBJECTS SENT TO LM:',
+              color='green',
+          )
+
+    if debug & LMDebugMode.RESPONSE:
+      console.write(
+          tokens,
+          title=(
+              f'\n[{call_counter}] {len(tokens)} TOKENS RETURNED '
+              f'(in {elapse:.2f} seconds):'
+          ),
+          color='blue',
+      )
+
+  def rate_to_max_concurrency(
+      self, requests_per_min: float = 0, tokens_per_min: float = 0
+  ) -> int:
+    """Converts a rate to a max concurrency."""
+    if tokens_per_min > 0:
+      return max(int(tokens_per_min / TOKENS_PER_REQUEST / 60), 1)
+    elif requests_per_min > 0:
+      return max(int(requests_per_min / 60), 1)  # Max concurrency can't be zero
+    else:
+      return DEFAULT_MAX_CONCURRENCY  # Default of 1
+
+
+class UsageSummary(pg.Object, pg.views.HtmlTreeView.Extension):
+  """Usage sumary."""
+
+  class AggregatedUsage(pg.Object):
+    """Aggregated usage."""
+
+    total: LMSamplingUsage = LMSamplingUsage(0, 0, 0, 0, 0.0)
+    breakdown: dict[str, LMSamplingUsage] = {}
+
+    def __bool__(self) -> bool:
+      """Returns True if the usage is non-empty."""
+      return bool(self.breakdown)
+
+    def add(
+        self,
+        model_id: str,
+        usage: LMSamplingUsage,
+    ) -> None:
+      """Adds an entry to the breakdown."""
+      aggregated = self.breakdown.get(model_id, None)
+      with pg.notify_on_change(False):
+        self.breakdown[model_id] = usage + aggregated
+        self.rebind(
+            total=self.total + usage,
+            raise_on_no_change=False
+        )
+
+    def merge(self, other: 'UsageSummary.AggregatedUsage') -> None:
+      """Merges the usage summary."""
+      with pg.notify_on_change(False):
+        for model_id, usage in other.breakdown.items():
+          self.add(model_id, usage)
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._usage_badge = None
+    self._lock = threading.Lock()
+
+  @property
+  def total(self) -> LMSamplingUsage:
+    return self.cached.total + self.uncached.total
+
+  def add(self, model_id: str, usage: LMSamplingUsage, is_cached: bool):
+    """Updates the usage summary."""
+    with self._lock:
+      if is_cached:
+        usage.rebind(estimated_cost=0.0, skip_notification=True)
+        self.cached.add(model_id, usage)
+      else:
+        self.uncached.add(model_id, usage)
+      self._update_view()
+
+  def merge(self, other: 'UsageSummary', as_cached: bool = False) -> None:
+    """Aggregates the usage summary.
+
+    Args:
+      other: The usage summary to merge.
+      as_cached: Whether to merge the usage summary as cached.
+    """
+    with self._lock:
+      self.cached.merge(other.cached)
+      if as_cached:
+        self.cached.merge(other.uncached)
+      else:
+        self.uncached.merge(other.uncached)
+      self._update_view()
+
+  def _sym_nondefault(self) -> dict[str, Any]:
+    """Overrides nondefault values so volatile values are not included."""
+    return dict()
+
+  #
+  # Html views for the usage summary.
+  #
+
+  def _update_view(self):
+    if self._usage_badge is not None:
+      self._usage_badge.update(
+          self._badge_text(),
+          tooltip=pg.format(
+              self, verbose=False, custom_format=self._tooltip_format
+          ),
+          styles=dict(color=self._badge_color()),
+      )
+
+  def _badge_text(self) -> str:
+    if self.total.estimated_cost is not None:
+      return f'{self.total.estimated_cost:.3f}'
+    return '0.000'
+
+  def _badge_color(self) -> str | None:
+    if self.total.estimated_cost is None or self.total.estimated_cost < 1.0:
+      return None
+
+    # Step 1: The normal cost range is around 1e-3 to 1e5.
+    # Therefore we normalize the log10 value from [-3, 5] to [0, 1].
+    normalized_value = (math.log10(self.total.estimated_cost) + 3) / (5 + 3)
+
+    # Step 2: Interpolate between green and red
+    red = int(255 * normalized_value)
+    green = int(255 * (1 - normalized_value))
+    return f'rgb({red}, {green}, 0)'
+
+  def _tooltip_format(self, v, root_indent):
+    del root_indent
+    if isinstance(v, int):
+      return f'{v:,}'
+    if isinstance(v, float):
+      return f'{v:,.3f}'
+    return None
+
+  def _html_tree_view(
+      self,
+      *,
+      view: pg.views.HtmlTreeView,
+      extra_flags: dict[str, Any] | None = None,
+      **kwargs
+  ) -> pg.Html:
+    extra_flags = extra_flags or {}
+    as_badge = extra_flags.pop('as_badge', False)
+    interactive = extra_flags.get('interactive', True)
+    if as_badge:
+      usage_badge = self._usage_badge
+      if usage_badge is None:
+        usage_badge = pg.views.html.controls.Badge(
+            self._badge_text(),
+            tooltip=pg.format(
+                self, custom_format=self._tooltip_format, verbose=False
+            ),
+            css_classes=['usage-summary'],
+            styles=dict(color=self._badge_color()),
+            interactive=True,
+        )
+        if interactive:
+          self._usage_badge = usage_badge
+      return usage_badge.to_html()
+    return super()._html_tree_view(
+        view=view,
+        extra_flags=extra_flags,
+        **kwargs
+    )
+
+  @classmethod
+  @functools.cache
+  def _html_tree_view_css_styles(cls) -> list[str]:
+    return super()._html_tree_view_css_styles() + [
+        """
+        .usage-summary.label {
+          display: inline-flex;
+          border-radius: 5px;
+          padding: 5px;
+          background-color: #f1f1f1;
+          color: #CCC;
+        }
+        .usage-summary.label::before {
+          content: '$';
+        }
+        """
+    ]
+
+pg.members(
+    dict(
+        cached=(
+            pg.typing.Object(
+                UsageSummary.AggregatedUsage,
+                default=UsageSummary.AggregatedUsage()
+            ),
+            'Aggregated usages for cached LLM calls.'
+        ),
+        uncached=(
+            pg.typing.Object(
+                UsageSummary.AggregatedUsage,
+                default=UsageSummary.AggregatedUsage()
+            ),
+            'Aggregated usages for uncached LLM calls.'
+        ),
+    )
+)(UsageSummary)
+
+
+class _UsageTracker:
+  """Usage tracker."""
+
+  def __init__(self, model_ids: set[str] | None):
+    self.model_ids = model_ids
+    self.usage_summary = UsageSummary()
+
+  def track(self, model_id: str, usage: LMSamplingUsage, is_cached: bool):
+    if self.model_ids is None or model_id in self.model_ids:
+      self.usage_summary.add(model_id, usage, is_cached)
+
+
+@contextlib.contextmanager
+def track_usages(
+    *lm: Union[str, LanguageModel]
+) -> Iterator[UsageSummary]:
+  """Context manager to track the usages of all language models in scope.
+
+  `lf.track_usages` works with threads spawned by `lf.concurrent_map` and
+  `lf.concurrent_execute`.
+
+  Example:
+    ```
+    lm = lf.llms.GeminiPro1()
+    with lf.track_usages() as usages:
+      # invoke any code that will call LLMs.
+
+    print(usages[lm.model_id])
+    ```
+
+  Args:
+    *lm: The language model(s) to track. If None, track all models in scope.
+
+  Yields:
+    A dictionary of model ID to usage. If a model does not supports usage
+    counting, the dict entry will be None.
+  """
+  if not lm:
+    model_ids = None
+  else:
+    model_ids = [m.model_id if isinstance(m, LanguageModel) else m for m in lm]
+
+  trackers = component.context_value('__usage_trackers__', [])
+  tracker = _UsageTracker(set(model_ids) if model_ids else None)
+  with component.context(__usage_trackers__=trackers + [tracker]):
+    try:
+      yield tracker.usage_summary
+    finally:
+      pass
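The new `track_usages` context manager at the end of this file exposes a `UsageSummary` for all LLM calls made within its scope. A minimal sketch following the docstring above; the `Echo` fake model is an assumption for an offline example, while the attribute paths (`total`, `uncached.breakdown`) come from the `UsageSummary` definition in this diff:

```python
import langfun as lf

lm = lf.llms.Echo()
with lf.track_usages(lm) as usage_summary:   # pass models to filter; no args tracks everything
  lm('Hello, world!')                        # any LLM call inside the scope is tracked

print(usage_summary.total.num_requests)                   # cached + uncached totals
print(usage_summary.uncached.breakdown.get(lm.model_id))  # per-model LMSamplingUsage
```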