langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501150804__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- langfun/__init__.py +20 -2
- langfun/core/__init__.py +16 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -21
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +63 -2
- langfun/core/component_test.py +53 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +17 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +16 -1
- langfun/core/eval/base.py +622 -174
- langfun/core/eval/base_test.py +200 -54
- langfun/core/eval/matching.py +63 -76
- langfun/core/eval/matching_test.py +17 -8
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +26 -26
- langfun/core/eval/scoring_test.py +19 -2
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +4 -17
- langfun/core/langfunc_test.py +22 -6
- langfun/core/language_model.py +577 -39
- langfun/core/language_model_test.py +470 -56
- langfun/core/llms/__init__.py +87 -16
- langfun/core/llms/anthropic.py +312 -87
- langfun/core/llms/anthropic_test.py +71 -3
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +53 -2
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +11 -7
- langfun/core/llms/fake_test.py +14 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -202
- langfun/core/llms/groq.py +160 -144
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +15 -42
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +395 -203
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +30 -395
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -26
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +12 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +7 -6
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +60 -27
- langfun/core/structured/function_generation_test.py +72 -2
- langfun/core/structured/mapping.py +97 -47
- langfun/core/structured/mapping_test.py +90 -2
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +53 -9
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
- langfun/core/structured/schema.py +204 -97
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_test.py +130 -29
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +115 -1
- langfun/core/template_test.py +71 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +10 -2
- langfun-0.1.2.dev202501150804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501150804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -238
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
- langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/top_level.txt +0 -0
langfun/core/language_model.py
CHANGED
@@ -14,20 +14,51 @@
 """Interface for language model."""
 
 import abc
+import contextlib
 import dataclasses
 import enum
+import functools
+import math
+import threading
 import time
-from typing import Annotated, Any, Callable, Sequence, Tuple, Type, Union
+from typing import Annotated, Any, Callable, Iterator, Optional, Sequence, Tuple, Type, Union
 from langfun.core import component
 from langfun.core import concurrent
 from langfun.core import console
 from langfun.core import message as message_lib
+
 import pyglove as pg
 
 TOKENS_PER_REQUEST = 250  # Estimated num tokens for a single request
 DEFAULT_MAX_CONCURRENCY = 1  # Use this as max concurrency if no RPM or TPM data
 
 
+#
+# Common errors during calling language models.
+#
+
+
+class LMError(RuntimeError):
+  """Base class for language model errors."""
+
+
+class RetryableLMError(LMError):
+  """Base class for LLM errors that can be solved by retrying."""
+
+
+class RateLimitError(RetryableLMError):
+  """Error for rate limit reached."""
+
+
+class TemporaryLMError(RetryableLMError):
+  """Error for temporary service issues that can be retried."""
+
+
+#
+# Language model input/output interfaces.
+#
+
+
 class LMSample(pg.Object):
   """Response candidate."""
 
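Note: the new exception hierarchy gives callers a single base class, `RetryableLMError`, for transient failures (rate limits, temporary service issues) as opposed to permanent ones. A minimal sketch of how client code might use it; the `call_with_retry` helper is hypothetical (langfun's own retry machinery lives in `concurrent.concurrent_execute`):

```
from langfun.core import language_model as lm_lib

def call_with_retry(lm, prompt, max_attempts=3):
  # Hypothetical helper: retry only errors the module marks as retryable.
  for attempt in range(max_attempts):
    try:
      return lm(prompt)
    except lm_lib.RetryableLMError:  # RateLimitError, TemporaryLMError, ...
      if attempt == max_attempts - 1:
        raise
```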
@@ -50,12 +81,140 @@ class LMSample(pg.Object):
   ] = None
 
 
+class RetryStats(pg.Object):
+  """Retry stats, which is aggregated across multiple retry entries."""
+
+  num_occurences: Annotated[
+      int,
+      'Total number of retry attempts on LLM (excluding the first attempt).',
+  ] = 0
+  total_wait_interval: Annotated[
+      float, 'Total wait interval in seconds due to retry.'
+  ] = 0
+  total_call_interval: Annotated[
+      float, 'Total LLM call interval in seconds.'
+  ] = 0
+  errors: Annotated[
+      dict[str, int],
+      'A Counter of error types encountered during the retry attempts.',
+  ] = {}
+
+  @classmethod
+  def from_retry_entries(
+      cls, retry_entries: Sequence[concurrent.RetryEntry]
+  ) -> 'RetryStats':
+    """Creates a RetryStats from a sequence of RetryEntry."""
+    if not retry_entries:
+      return RetryStats()
+    errors = {}
+    for retry in retry_entries:
+      if retry.error is not None:
+        errors[retry.error.__class__.__name__] = (
+            errors.get(retry.error.__class__.__name__, 0) + 1
+        )
+    return RetryStats(
+        num_occurences=len(retry_entries) - 1,
+        total_wait_interval=sum(e.wait_interval for e in retry_entries),
+        total_call_interval=sum(e.call_interval for e in retry_entries),
+        errors=errors,
+    )
+
+  def __add__(self, other: 'RetryStats') -> 'RetryStats':
+    errors = self.errors.copy()
+    for error, count in other.errors.items():
+      errors[error] = errors.get(error, 0) + count
+    return RetryStats(
+        num_occurences=self.num_occurences + other.num_occurences,
+        total_wait_interval=self.total_wait_interval
+        + other.total_wait_interval,
+        total_call_interval=self.total_call_interval
+        + other.total_call_interval,
+        errors=errors,
+    )
+
+  def __radd__(self, other: 'RetryStats') -> 'RetryStats':
+    return self + other
+
+
 class LMSamplingUsage(pg.Object):
   """Usage information per completion."""
 
   prompt_tokens: int
   completion_tokens: int
   total_tokens: int
+  num_requests: int = 1
+  estimated_cost: Annotated[
+      float | None,
+      (
+          'Estimated cost in US dollars. If None, cost estimating is not '
+          'suppported on the model being queried.'
+      ),
+  ] = None
+  retry_stats: RetryStats = RetryStats()
+
+  def __bool__(self) -> bool:
+    return self.num_requests > 0
+
+  @property
+  def average_prompt_tokens(self) -> int:
+    """Returns the average prompt tokens per request."""
+    return self.prompt_tokens // self.num_requests
+
+  @property
+  def average_completion_tokens(self) -> int:
+    """Returns the average completion tokens per request."""
+    return self.completion_tokens // self.num_requests
+
+  @property
+  def average_total_tokens(self) -> int:
+    """Returns the average total tokens per request."""
+    return self.total_tokens // self.num_requests
+
+  @property
+  def average_estimated_cost(self) -> float | None:
+    """Returns the average estimated cost per request."""
+    if self.estimated_cost is None:
+      return None
+    return self.estimated_cost / self.num_requests
+
+  def __add__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage':
+    if other is None:
+      return self
+    if self.estimated_cost is None:
+      estimated_cost = other.estimated_cost
+    elif other.estimated_cost is None:
+      estimated_cost = self.estimated_cost
+    else:
+      estimated_cost = self.estimated_cost + other.estimated_cost
+    return LMSamplingUsage(
+        prompt_tokens=self.prompt_tokens + other.prompt_tokens,
+        completion_tokens=self.completion_tokens + other.completion_tokens,
+        total_tokens=self.total_tokens + other.total_tokens,
+        num_requests=self.num_requests + other.num_requests,
+        estimated_cost=estimated_cost,
+        retry_stats=self.retry_stats + other.retry_stats,
+    )
+
+  def __radd__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage':
+    return self + other
+
+
+class UsageNotAvailable(LMSamplingUsage):
+  """Usage information not available."""
+  prompt_tokens: pg.typing.Int(0).freeze()  # pytype: disable=invalid-annotation
+  completion_tokens: pg.typing.Int(0).freeze()  # pytype: disable=invalid-annotation
+  total_tokens: pg.typing.Int(0).freeze()  # pytype: disable=invalid-annotation
+  estimated_cost: pg.typing.Float(default=None, is_noneable=True).freeze()  # pytype: disable=invalid-annotation
+
+  def __add__(self, other: Optional['LMSamplingUsage']) -> 'UsageNotAvailable':
+    if other is None:
+      return self
+    return UsageNotAvailable(
+        num_requests=self.num_requests + other.num_requests
+    )
+
+  def __radd__(self, other: Optional['LMSamplingUsage']) -> 'UsageNotAvailable':
+    return self + other
 
 
 class LMSamplingResult(pg.Object):
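Note: usage records are now additive. `LMSamplingUsage.__add__`/`__radd__` sum token counts, request counts, costs and retry stats, while `UsageNotAvailable` absorbs additions and keeps only the request count. A quick illustration with invented numbers:

```
a = LMSamplingUsage(100, 50, 150, num_requests=1, estimated_cost=0.002)
b = LMSamplingUsage(300, 70, 370, num_requests=1, estimated_cost=0.004)
c = a + b
assert c.total_tokens == 520 and c.num_requests == 2
assert c.average_total_tokens == 260           # 520 // 2
assert abs(c.estimated_cost - 0.006) < 1e-9

u = UsageNotAvailable() + a                    # token fields stay frozen at 0
assert u.num_requests == 2 and u.total_tokens == 0
```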
@@ -70,9 +229,14 @@ class LMSamplingResult(pg.Object):
   ] = []
 
   usage: Annotated[
-      LMSamplingUsage
+      LMSamplingUsage,
       'Usage information. Currently only OpenAI models are supported.',
-  ] =
+  ] = UsageNotAvailable()
+
+  is_cached: Annotated[
+      bool,
+      'Whether the result is from cache or not.'
+  ] = False
 
 
 class LMSamplingOptions(component.Component):
@@ -166,6 +330,11 @@ class LMScoringResult(pg.Object):
       float,
       'The log likelyhood of the requested completion towards the prompt.',
   ]
+  gradients: Annotated[
+      Any | None,
+      '(Optional) gradients from the score method, w.r.t.' +
+      ' prompt.metadata.weights.',
+  ] = None
 
 
 class LMCache(pg.Object):
@@ -180,6 +349,7 @@ class LMCache(pg.Object):
     num_hit_expires: int = 0
     num_misses: int = 0
     num_updates: int = 0
+    num_deletes: int = 0
 
   @abc.abstractmethod
   def get(
@@ -197,6 +367,15 @@ class LMCache(pg.Object):
   ) -> None:
     """Puts the result of a prompt generated by a language model in cache."""
 
+  @abc.abstractmethod
+  def delete(
+      self,
+      lm: 'LanguageModel',
+      prompt: message_lib.Message,
+      seed: int,
+  ) -> bool:
+    """Deletes the result of a prompt generated by a language model in cache."""
+
   @property
   @abc.abstractmethod
   def stats(self) -> Stats:
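Note: `LMCache` subclasses must now implement `delete` alongside `get`/`put`, and `Stats` grows a matching `num_deletes` counter. A toy dict-backed sketch of the expected semantics (not langfun's actual `in_memory` implementation; the key scheme is invented):

```
class _ToyCacheStore:
  def __init__(self):
    self._store = {}
    self.num_deletes = 0

  def delete(self, lm, prompt, seed) -> bool:
    key = (lm.model_id, prompt.text, seed)   # invented key scheme
    if self._store.pop(key, None) is None:
      return False                           # nothing cached for this key
    self.num_deletes += 1                    # mirrors Stats.num_deletes
    return True
```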
@@ -290,6 +469,15 @@ class LanguageModel(component.Component):
       )
   ] = True
 
+  max_retry_interval: Annotated[
+      int,
+      (
+          'The max retry interval in seconds. This is useful when the retry '
+          'interval is exponential, to avoid the wait time to grow '
+          'exponentially.'
+      )
+  ] = 300
+
   debug: Annotated[
       bool | LMDebugMode,
       (
@@ -303,7 +491,10 @@ class LanguageModel(component.Component):
   def __init__(self, *args, **kwargs) -> None:
     """Overrides __init__ to pass through **kwargs to sampling options."""
 
-    sampling_options = kwargs.pop(
+    sampling_options = kwargs.pop(
+        'sampling_options',
+        pg.clone(self.__schema__.fields['sampling_options'].default_value)
+    )
     sampling_options_delta = {}
 
     for k, v in kwargs.items():
@@ -361,12 +552,13 @@ class LanguageModel(component.Component):
       response = sample.response
       response.metadata.score = sample.score
       response.metadata.logprobs = sample.logprobs
+      response.metadata.is_cached = result.is_cached
 
       # NOTE(daiyip): Current usage is computed at per-result level,
       # which is accurate when n=1. For n > 1, we average the usage across
       # multiple samples.
       usage = result.usage
-      if len(result.samples) == 1 or usage
+      if len(result.samples) == 1 or isinstance(usage, UsageNotAvailable):
         response.metadata.usage = usage
       else:
         n = len(result.samples)
@@ -374,8 +566,29 @@ class LanguageModel(component.Component):
             prompt_tokens=usage.prompt_tokens // n,
             completion_tokens=usage.completion_tokens // n,
             total_tokens=usage.total_tokens // n,
+            estimated_cost=(
+                usage.estimated_cost / n if usage.estimated_cost else None
+            ),
+            retry_stats=RetryStats(
+                num_occurences=usage.retry_stats.num_occurences // n,
+                total_wait_interval=usage.retry_stats.total_wait_interval
+                / n,
+                total_call_interval=usage.retry_stats.total_call_interval
+                / n,
+                errors={
+                    error: count // n
+                    for error, count in usage.retry_stats.errors.items()
+                },
+            ),
         )
 
+      # Track usage.
+      trackers = component.context_value('__usage_trackers__', [])
+      if trackers:
+        model_id = self.model_id
+        for tracker in trackers:
+          tracker.track(model_id, usage, result.is_cached)
+
       # Track the prompt for corresponding response.
       response.source = prompt
 
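Note: for `n > 1` samples per request, each response's metadata carries roughly `1/n` of the request-level usage (integer division for token counts), so summing over responses approximately reconstructs the request totals. With invented numbers:

```
usage = LMSamplingUsage(600, 300, 900, num_requests=1, estimated_cost=0.01)
n = 2
per_sample_tokens = usage.total_tokens // n    # 450
per_sample_cost = usage.estimated_cost / n     # 0.005
```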
@@ -403,7 +616,9 @@ class LanguageModel(component.Component):
         request_to_result_index[len(requests)] = i
         requests.append(prompt)
       else:
-
+        result = r.clone()
+        assert result.is_cached, result
+        results[i] = result
 
     # Sample non-cache-hit prompts.
     if requests:
@@ -420,8 +635,12 @@ class LanguageModel(component.Component):
           sample.response.set('cache_seed', cache_seed)
 
         if cache_seed is not None:
-          self.cache.put(
-
+          self.cache.put(
+              self,
+              prompt,
+              result.clone(override=dict(is_cached=True)),
+              seed=cache_seed
+          )
     return results  # pytype: disable=bad-return-type
 
   @abc.abstractmethod
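Note: results are written to the cache with `is_cached=True`, so on a later hit the cloned result (and thus `response.metadata.is_cached`) reflects cache provenance. A hedged sketch, assuming `lm` was constructed with a cache attached:

```
r1 = lm('2 + 2 =', cache_seed=0)   # first call: goes to the model
r2 = lm('2 + 2 =', cache_seed=0)   # second call: served from lm.cache
assert r2.metadata.is_cached       # set via result.clone(override=...)
```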
@@ -433,16 +652,16 @@ class LanguageModel(component.Component):
 
   def _parallel_execute_with_currency_control(
       self,
-      action: Callable[...,
+      action: Callable[..., LMSamplingResult],
       inputs: Sequence[Any],
       retry_on_errors: Union[
           None,
-          Union[Type[
-          Sequence[Union[Type[
-      ] =
-  ) -> Any:
+          Union[Type[BaseException], Tuple[Type[BaseException], str]],
+          Sequence[Union[Type[BaseException], Tuple[Type[BaseException], str]]],
+      ] = RetryableLMError,
+  ) -> list[Any]:
     """Helper method for subclasses for implementing _sample."""
-
+    executed_jobs = concurrent.concurrent_execute(
         action,
         inputs,
         executor=self.resource_id if self.max_concurrency else None,
@@ -451,7 +670,16 @@ class LanguageModel(component.Component):
         max_attempts=self.max_attempts,
         retry_interval=self.retry_interval,
         exponential_backoff=self.exponential_backoff,
+        max_retry_interval=self.max_retry_interval,
+        return_jobs=True,
     )
+    for job in executed_jobs:
+      if isinstance(job.result, LMSamplingResult):
+        job.result.usage.rebind(
+            retry_stats=RetryStats.from_retry_entries(job.retry_entries),
+            skip_notification=True,
+        )
+    return [job.result for job in executed_jobs]
 
   def __call__(
       self, prompt: message_lib.Message, *, cache_seed: int = 0, **kwargs
@@ -479,7 +707,7 @@ class LanguageModel(component.Component):
       prompt: message_lib.Message,
       response: message_lib.Message,
       call_counter: int,
-      usage: LMSamplingUsage
+      usage: LMSamplingUsage,
       elapse: float,
   ) -> None:
     """Outputs debugging information."""
@@ -497,12 +725,13 @@ class LanguageModel(component.Component):
       self._debug_response(response, call_counter, usage, elapse)
 
   def _debug_model_info(
-      self, call_counter: int, usage: LMSamplingUsage
+      self, call_counter: int, usage: LMSamplingUsage) -> None:
     """Outputs debugging information about the model."""
     title_suffix = ''
-    if usage
-      title_suffix =
-          f' (total {usage.total_tokens} tokens)', 'red'
+    if usage.total_tokens != 0:
+      title_suffix = pg.colored(
+          f' (total {usage.total_tokens} tokens)', 'red'
+      )
 
     console.write(
         self.format(compact=True, use_inferred=True),
@@ -514,12 +743,12 @@ class LanguageModel(component.Component):
       self,
       prompt: message_lib.Message,
       call_counter: int,
-      usage: LMSamplingUsage
+      usage: LMSamplingUsage,
   ) -> None:
     """Outputs debugging information about the prompt."""
     title_suffix = ''
-    if usage
-      title_suffix =
+    if usage.prompt_tokens != 0:
+      title_suffix = pg.colored(f' ({usage.prompt_tokens} tokens)', 'red')
 
     console.write(
         # We use metadata 'formatted_text' for scenarios where the prompt text
@@ -542,15 +771,15 @@ class LanguageModel(component.Component):
       self,
       response: message_lib.Message,
       call_counter: int,
-      usage: LMSamplingUsage
+      usage: LMSamplingUsage,
       elapse: float
   ) -> None:
     """Outputs debugging information about the response."""
     title_suffix = ' ('
-    if usage
+    if usage.completion_tokens != 0:
       title_suffix += f'{usage.completion_tokens} tokens '
     title_suffix += f'in {elapse:.2f} seconds)'
-    title_suffix =
+    title_suffix = pg.colored(title_suffix, 'red')
 
     console.write(
         str(response) + '\n',
@@ -560,12 +789,19 @@ class LanguageModel(component.Component):
 
   def score(
       self,
-      prompt: str | message_lib.Message,
+      prompt: str | message_lib.Message | list[message_lib.Message],
       completions: list[str | message_lib.Message],
       **kwargs,
   ) -> list[LMScoringResult]:
     """Scores the given prompt."""
-
+    if isinstance(prompt, list):
+      if len(prompt) != len(completions):
+        raise ValueError(
+            'prompt and completions must have the same length.'
+        )
+      prompt = [message_lib.UserMessage.from_value(p) for p in prompt]
+    else:
+      prompt = message_lib.UserMessage.from_value(prompt)
     completions = [message_lib.UserMessage.from_value(c) for c in completions]
 
     call_counter = self._call_counter
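Note: `score` now also accepts a list of prompts that is zipped pairwise with `completions` (lengths must match, else the `ValueError` above). A usage sketch, assuming the model implements `_score`:

```
results = lm.score(
    ['1 + 1 =', '2 + 2 ='],   # prompt[i] is scored against completions[i]
    ['2', '4'],
)
for r in results:
  print(r.score)              # log likelihood per (prompt, completion) pair
```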
@@ -581,7 +817,8 @@ class LanguageModel(component.Component):
     return scoring_results
 
   def _score(
-      self, prompt: message_lib.Message
+      self, prompt: message_lib.Message | list[message_lib.Message],
+      completions: list[message_lib.Message]
   ) -> list[LMScoringResult]:
     """Subclass to implement."""
     raise NotImplementedError(
@@ -590,7 +827,7 @@ class LanguageModel(component.Component):
 
   def _debug_score(
       self,
-      prompt: message_lib.Message,
+      prompt: message_lib.Message | list[message_lib.Message],
       completions: list[message_lib.Message],
       scoring_results: list[LMScoringResult],
       call_counter: int,
@@ -601,7 +838,7 @@ class LanguageModel(component.Component):
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE
 
     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter,
+      self._debug_model_info(call_counter, UsageNotAvailable())
 
     if debug & LMDebugMode.PROMPT:
       console.write(
@@ -609,15 +846,19 @@ class LanguageModel(component.Component):
           title=f'\n[{call_counter}] SCORING LM WITH PROMPT:',
           color='green',
       )
-
-
-
-
-
-
-
-
-
+      if isinstance(prompt, list):
+        referred_modalities_lst = [p.referred_modalities() for p in prompt]
+      else:
+        referred_modalities_lst = [prompt.referred_modalities(),]
+      if referred_modalities_lst:
+        for referred_modalities in referred_modalities_lst:
+          console.write(
+              pg.object_utils.kvlist_str(
+                  [(k, repr(v), None) for k, v in referred_modalities.items()]
+              ),
+              title=f'\n[{call_counter}] MODALITY OBJECTS SENT TO LM:',
+              color='green',
+          )
 
     if debug & LMDebugMode.RESPONSE:
       console.write(
@@ -638,6 +879,72 @@ class LanguageModel(component.Component):
           color='blue',
       )
 
+  def tokenize(
+      self,
+      prompt: str | message_lib.Message,
+      **kwargs,
+  ) -> list[tuple[str | bytes, int]]:
+    """Tokenizes the given prompt."""
+    prompt = message_lib.UserMessage.from_value(prompt)
+    call_counter = self._call_counter
+    self._call_counter += 1
+
+    with component.context(override_attrs=True, **kwargs):
+      request_start = time.time()
+      tokens = self._tokenize(prompt)
+      elapse = time.time() - request_start
+      self._debug_tokenize(prompt, tokens, call_counter, elapse)
+    return tokens
+
+  def _tokenize(
+      self, prompt: message_lib.Message
+  ) -> list[tuple[str | bytes, int]]:
+    """Subclass to implement."""
+    raise NotImplementedError(
+        f'{self.__class__.__name__} does not support tokenization.'
+    )
+
+  def _debug_tokenize(
+      self,
+      prompt: message_lib.Message,
+      tokens: list[tuple[str | bytes, int]],
+      call_counter: int,
+      elapse: float,
+  ):
+    debug = self.debug
+    if isinstance(debug, bool):
+      debug = LMDebugMode.ALL if debug else LMDebugMode.NONE
+
+    if debug & LMDebugMode.INFO:
+      self._debug_model_info(call_counter, UsageNotAvailable())
+
+    if debug & LMDebugMode.PROMPT:
+      console.write(
+          prompt,
+          title=f'\n[{call_counter}] PROMPT TO TOKENIZE:',
+          color='green',
+      )
+      referred_modalities_lst = [prompt.referred_modalities(),]
+      if referred_modalities_lst:
+        for referred_modalities in referred_modalities_lst:
+          console.write(
+              pg.object_utils.kvlist_str(
+                  [(k, repr(v), None) for k, v in referred_modalities.items()]
+              ),
+              title=f'\n[{call_counter}] MODALITY OBJECTS SENT TO LM:',
+              color='green',
+          )
+
+    if debug & LMDebugMode.RESPONSE:
+      console.write(
+          tokens,
+          title=(
+              f'\n[{call_counter}] {len(tokens)} TOKENS RETURNED '
+              f'(in {elapse:.2f} seconds):'
+          ),
+          color='blue',
+      )
+
   def rate_to_max_concurrency(
       self, requests_per_min: float = 0, tokens_per_min: float = 0
   ) -> int:
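Note: the public `tokenize` entry point wraps a model-specific `_tokenize` and returns `(token, token_id)` pairs; models that do not override `_tokenize` raise `NotImplementedError`. Sketch:

```
tokens = lm.tokenize('Hello world')  # assumes lm overrides _tokenize
for text, token_id in tokens:
  print(repr(text), token_id)
```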
@@ -648,3 +955,234 @@ class LanguageModel(component.Component):
       return max(int(requests_per_min / 60), 1)  # Max concurrency can't be zero
     else:
       return DEFAULT_MAX_CONCURRENCY  # Default of 1
+
+
+class UsageSummary(pg.Object, pg.views.HtmlTreeView.Extension):
+  """Usage sumary."""
+
+  class AggregatedUsage(pg.Object):
+    """Aggregated usage."""
+
+    total: LMSamplingUsage = LMSamplingUsage(0, 0, 0, 0, 0.0)
+    breakdown: dict[str, LMSamplingUsage] = {}
+
+    def __bool__(self) -> bool:
+      """Returns True if the usage is non-empty."""
+      return bool(self.breakdown)
+
+    def add(
+        self,
+        model_id: str,
+        usage: LMSamplingUsage,
+    ) -> None:
+      """Adds an entry to the breakdown."""
+      aggregated = self.breakdown.get(model_id, None)
+      with pg.notify_on_change(False):
+        self.breakdown[model_id] = usage + aggregated
+        self.rebind(
+            total=self.total + usage,
+            raise_on_no_change=False
+        )
+
+    def merge(self, other: 'UsageSummary.AggregatedUsage') -> None:
+      """Merges the usage summary."""
+      with pg.notify_on_change(False):
+        for model_id, usage in other.breakdown.items():
+          self.add(model_id, usage)
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._usage_badge = None
+    self._lock = threading.Lock()
+
+  @property
+  def total(self) -> LMSamplingUsage:
+    return self.cached.total + self.uncached.total
+
+  def add(self, model_id: str, usage: LMSamplingUsage, is_cached: bool):
+    """Updates the usage summary."""
+    with self._lock:
+      if is_cached:
+        usage.rebind(estimated_cost=0.0, skip_notification=True)
+        self.cached.add(model_id, usage)
+      else:
+        self.uncached.add(model_id, usage)
+      self._update_view()
+
+  def merge(self, other: 'UsageSummary', as_cached: bool = False) -> None:
+    """Aggregates the usage summary.
+
+    Args:
+      other: The usage summary to merge.
+      as_cached: Whether to merge the usage summary as cached.
+    """
+    with self._lock:
+      self.cached.merge(other.cached)
+      if as_cached:
+        self.cached.merge(other.uncached)
+      else:
+        self.uncached.merge(other.uncached)
+      self._update_view()
+
+  def _sym_nondefault(self) -> dict[str, Any]:
+    """Overrides nondefault values so volatile values are not included."""
+    return dict()
+
+  #
+  # Html views for the usage summary.
+  #
+
+  def _update_view(self):
+    if self._usage_badge is not None:
+      self._usage_badge.update(
+          self._badge_text(),
+          tooltip=pg.format(
+              self, verbose=False, custom_format=self._tooltip_format
+          ),
+          styles=dict(color=self._badge_color()),
+      )
+
+  def _badge_text(self) -> str:
+    if self.total.estimated_cost is not None:
+      return f'{self.total.estimated_cost:.3f}'
+    return '0.000'
+
+  def _badge_color(self) -> str | None:
+    if self.total.estimated_cost is None or self.total.estimated_cost < 1.0:
+      return None
+
+    # Step 1: The normal cost range is around 1e-3 to 1e5.
+    # Therefore we normalize the log10 value from [-3, 5] to [0, 1].
+    normalized_value = (math.log10(self.total.estimated_cost) + 3) / (5 + 3)
+
+    # Step 2: Interpolate between green and red
+    red = int(255 * normalized_value)
+    green = int(255 * (1 - normalized_value))
+    return f'rgb({red}, {green}, 0)'
+
+  def _tooltip_format(self, v, root_indent):
+    del root_indent
+    if isinstance(v, int):
+      return f'{v:,}'
+    if isinstance(v, float):
+      return f'{v:,.3f}'
+    return None
+
+  def _html_tree_view(
+      self,
+      *,
+      view: pg.views.HtmlTreeView,
+      extra_flags: dict[str, Any] | None = None,
+      **kwargs
+  ) -> pg.Html:
+    extra_flags = extra_flags or {}
+    as_badge = extra_flags.pop('as_badge', False)
+    interactive = extra_flags.get('interactive', True)
+    if as_badge:
+      usage_badge = self._usage_badge
+      if usage_badge is None:
+        usage_badge = pg.views.html.controls.Badge(
+            self._badge_text(),
+            tooltip=pg.format(
+                self, custom_format=self._tooltip_format, verbose=False
+            ),
+            css_classes=['usage-summary'],
+            styles=dict(color=self._badge_color()),
+            interactive=True,
+        )
+        if interactive:
+          self._usage_badge = usage_badge
+      return usage_badge.to_html()
+    return super()._html_tree_view(
+        view=view,
+        extra_flags=extra_flags,
+        **kwargs
+    )
+
+  @classmethod
+  @functools.cache
+  def _html_tree_view_css_styles(cls) -> list[str]:
+    return super()._html_tree_view_css_styles() + [
+        """
+        .usage-summary.label {
+          display: inline-flex;
+          border-radius: 5px;
+          padding: 5px;
+          background-color: #f1f1f1;
+          color: #CCC;
+        }
+        .usage-summary.label::before {
+          content: '$';
+        }
+        """
+    ]
+
+pg.members(
+    dict(
+        cached=(
+            pg.typing.Object(
+                UsageSummary.AggregatedUsage,
+                default=UsageSummary.AggregatedUsage()
+            ),
+            'Aggregated usages for cached LLM calls.'
+        ),
+        uncached=(
+            pg.typing.Object(
+                UsageSummary.AggregatedUsage,
+                default=UsageSummary.AggregatedUsage()
+            ),
+            'Aggregated usages for uncached LLM calls.'
+        ),
+    )
+)(UsageSummary)
+
+
+class _UsageTracker:
+  """Usage tracker."""
+
+  def __init__(self, model_ids: set[str] | None):
+    self.model_ids = model_ids
+    self.usage_summary = UsageSummary()
+
+  def track(self, model_id: str, usage: LMSamplingUsage, is_cached: bool):
+    if self.model_ids is None or model_id in self.model_ids:
+      self.usage_summary.add(model_id, usage, is_cached)
+
+
+@contextlib.contextmanager
+def track_usages(
+    *lm: Union[str, LanguageModel]
+) -> Iterator[UsageSummary]:
+  """Context manager to track the usages of all language models in scope.
+
+  `lf.track_usages` works with threads spawned by `lf.concurrent_map` and
+  `lf.concurrent_execute`.
+
+  Example:
+    ```
+    lm = lf.llms.GeminiPro1()
+    with lf.track_usages() as usages:
+      # invoke any code that will call LLMs.
+
+    print(usages[lm.model_id])
+    ```
+
+  Args:
+    *lm: The language model(s) to track. If None, track all models in scope.
+
+  Yields:
+    A dictionary of model ID to usage. If a model does not supports usage
+    counting, the dict entry will be None.
+  """
+  if not lm:
+    model_ids = None
+  else:
+    model_ids = [m.model_id if isinstance(m, LanguageModel) else m for m in lm]
+
+  trackers = component.context_value('__usage_trackers__', [])
+  tracker = _UsageTracker(set(model_ids) if model_ids else None)
+  with component.context(__usage_trackers__=trackers + [tracker]):
+    try:
+      yield tracker.usage_summary
+    finally:
+      pass
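Note: `track_usages` threads a `_UsageTracker` through `component.context`, so every `LanguageModel` sample in scope (including nested trackers and worker threads) reports into its `UsageSummary`. A hedged usage sketch, assuming the top-level `lf` re-exports implied by the docstring and a model that reports usage (the fake `Echo` model is used here only for illustration):

```
import langfun as lf

lm = lf.llms.Echo()                    # any LanguageModel works here
with lf.track_usages(lm) as usage_summary:
  lm('Hello')
  lm('World')

print(usage_summary.total.num_requests)          # 2
print(dict(usage_summary.uncached.breakdown))    # per-model-id breakdown
```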