langfun 0.1.2.dev202502110804__py3-none-any.whl → 0.1.2.dev202502130804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/__init__.py +6 -2
- langfun/core/language_model.py +365 -22
- langfun/core/language_model_test.py +123 -35
- langfun/core/llms/__init__.py +50 -57
- langfun/core/llms/anthropic.py +434 -163
- langfun/core/llms/anthropic_test.py +20 -1
- langfun/core/llms/deepseek.py +90 -51
- langfun/core/llms/deepseek_test.py +15 -16
- langfun/core/llms/fake.py +6 -0
- langfun/core/llms/gemini.py +480 -390
- langfun/core/llms/gemini_test.py +27 -7
- langfun/core/llms/google_genai.py +80 -50
- langfun/core/llms/google_genai_test.py +11 -4
- langfun/core/llms/groq.py +268 -167
- langfun/core/llms/groq_test.py +9 -3
- langfun/core/llms/openai.py +839 -328
- langfun/core/llms/openai_compatible.py +3 -18
- langfun/core/llms/openai_compatible_test.py +20 -5
- langfun/core/llms/openai_test.py +14 -4
- langfun/core/llms/rest.py +11 -6
- langfun/core/llms/vertexai.py +238 -240
- langfun/core/llms/vertexai_test.py +35 -8
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502130804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502130804.dist-info}/RECORD +27 -27
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502130804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502130804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502130804.dist-info}/top_level.txt +0 -0
langfun/core/__init__.py
CHANGED
@@ -93,14 +93,18 @@ from langfun.core.modality import ModalityRef
|
|
93
93
|
from langfun.core.modality import ModalityError
|
94
94
|
|
95
95
|
# Interfaces for languge models.
|
96
|
+
from langfun.core.language_model import ModelInfo
|
96
97
|
from langfun.core.language_model import LanguageModel
|
98
|
+
|
97
99
|
from langfun.core.language_model import LMSample
|
98
100
|
from langfun.core.language_model import LMSamplingOptions
|
101
|
+
from langfun.core.language_model import LMSamplingResult
|
102
|
+
from langfun.core.language_model import LMScoringResult
|
103
|
+
|
99
104
|
from langfun.core.language_model import LMSamplingUsage
|
100
105
|
from langfun.core.language_model import UsageNotAvailable
|
101
106
|
from langfun.core.language_model import UsageSummary
|
102
|
-
|
103
|
-
from langfun.core.language_model import LMScoringResult
|
107
|
+
|
104
108
|
from langfun.core.language_model import LMCache
|
105
109
|
from langfun.core.language_model import LMDebugMode
|
106
110
|
|
langfun/core/language_model.py
CHANGED
@@ -16,12 +16,14 @@
|
|
16
16
|
import abc
|
17
17
|
import contextlib
|
18
18
|
import dataclasses
|
19
|
+
import datetime
|
19
20
|
import enum
|
20
21
|
import functools
|
21
22
|
import math
|
23
|
+
import re
|
22
24
|
import threading
|
23
25
|
import time
|
24
|
-
from typing import Annotated, Any, Callable, Iterator, Optional, Sequence, Tuple, Type, Union
|
26
|
+
from typing import Annotated, Any, Callable, ClassVar, Iterator, Literal, Optional, Sequence, Tuple, Type, Union, final
|
25
27
|
from langfun.core import component
|
26
28
|
from langfun.core import concurrent
|
27
29
|
from langfun.core import console
|
@@ -29,10 +31,6 @@ from langfun.core import message as message_lib
|
|
29
31
|
|
30
32
|
import pyglove as pg
|
31
33
|
|
32
|
-
TOKENS_PER_REQUEST = 250 # Estimated num tokens for a single request
|
33
|
-
DEFAULT_MAX_CONCURRENCY = 1 # Use this as max concurrency if no RPM or TPM data
|
34
|
-
|
35
|
-
|
36
34
|
#
|
37
35
|
# Common errors during calling language models.
|
38
36
|
#
|
@@ -54,6 +52,254 @@ class TemporaryLMError(RetryableLMError):
|
|
54
52
|
"""Error for temporary service issues that can be retried."""
|
55
53
|
|
56
54
|
|
55
|
+
#
|
56
|
+
# Language model information.
|
57
|
+
#
|
58
|
+
|
59
|
+
|
60
|
+
class ModelInfo(pg.Object):
|
61
|
+
"""Common information for a language model."""
|
62
|
+
|
63
|
+
# Constant for modalities.
|
64
|
+
TEXT_INPUT_ONLY = []
|
65
|
+
|
66
|
+
model_id: Annotated[
|
67
|
+
str,
|
68
|
+
'A global unique identifier of the language model. ',
|
69
|
+
]
|
70
|
+
|
71
|
+
alias_for: Annotated[
|
72
|
+
str | None,
|
73
|
+
'The fixed-version model ID that this model is aliased for.',
|
74
|
+
] = None
|
75
|
+
|
76
|
+
#
|
77
|
+
# Basic information.
|
78
|
+
#
|
79
|
+
|
80
|
+
model_type: Annotated[
|
81
|
+
Literal['unknown', 'pretrained', 'instruction-tuned', 'thinking'],
|
82
|
+
'The type of the model.'
|
83
|
+
] = 'unknown'
|
84
|
+
|
85
|
+
provider: Annotated[
|
86
|
+
str | None,
|
87
|
+
(
|
88
|
+
'The service provider (host) of the LLM. E.g. VertexAI, Microsoft, '
|
89
|
+
'etc.'
|
90
|
+
)
|
91
|
+
] = None
|
92
|
+
|
93
|
+
description: Annotated[
|
94
|
+
str | None,
|
95
|
+
'An optional description of the model family.'
|
96
|
+
] = None
|
97
|
+
|
98
|
+
url: Annotated[
|
99
|
+
str | None,
|
100
|
+
'The URL of the model.'
|
101
|
+
] = None
|
102
|
+
|
103
|
+
release_date: Annotated[
|
104
|
+
datetime.date | None,
|
105
|
+
'The release date of the model. '
|
106
|
+
] = None
|
107
|
+
|
108
|
+
in_service: Annotated[
|
109
|
+
bool,
|
110
|
+
'If True, the model is in service.'
|
111
|
+
] = True
|
112
|
+
|
113
|
+
#
|
114
|
+
# LLM capabilities.
|
115
|
+
#
|
116
|
+
|
117
|
+
input_modalities: Annotated[
|
118
|
+
list[str] | None,
|
119
|
+
(
|
120
|
+
'Supported MIME types as model inputs. '
|
121
|
+
'If None, this information is unknown, so Langfun allows all '
|
122
|
+
'modalities from the input to be passed to the model.'
|
123
|
+
)
|
124
|
+
] = None
|
125
|
+
|
126
|
+
class ContextLength(pg.Object):
|
127
|
+
"""Context length information."""
|
128
|
+
|
129
|
+
max_input_tokens: Annotated[
|
130
|
+
int | None,
|
131
|
+
(
|
132
|
+
'The maximum number of input tokens of the language model. '
|
133
|
+
'If None, there is no limit or this information is unknown.'
|
134
|
+
)
|
135
|
+
] = None
|
136
|
+
|
137
|
+
max_output_tokens: Annotated[
|
138
|
+
int | None,
|
139
|
+
(
|
140
|
+
'The maximum number of output tokens of the language model. '
|
141
|
+
'If None, there is no limit or this information is unknown.'
|
142
|
+
)
|
143
|
+
] = None
|
144
|
+
|
145
|
+
max_cot_tokens: Annotated[
|
146
|
+
int | None,
|
147
|
+
(
|
148
|
+
'The maximum number of Chain-of-Thought tokens to generate. '
|
149
|
+
'If None, there is not limit or not applicable.'
|
150
|
+
)
|
151
|
+
] = None
|
152
|
+
|
153
|
+
context_length: Annotated[
|
154
|
+
ContextLength | None,
|
155
|
+
(
|
156
|
+
'Context length information of the model. '
|
157
|
+
'If None, this information is unknown.'
|
158
|
+
)
|
159
|
+
] = None
|
160
|
+
|
161
|
+
#
|
162
|
+
# Common pricing information.
|
163
|
+
#
|
164
|
+
|
165
|
+
class Pricing(pg.Object):
|
166
|
+
"""Pricing information."""
|
167
|
+
|
168
|
+
cost_per_1m_cached_input_tokens: Annotated[
|
169
|
+
float | None,
|
170
|
+
(
|
171
|
+
'The cost per 1M cached input tokens in US dollars. '
|
172
|
+
'If None, this information is unknown.'
|
173
|
+
)
|
174
|
+
] = None
|
175
|
+
|
176
|
+
cost_per_1m_input_tokens: Annotated[
|
177
|
+
float | None,
|
178
|
+
(
|
179
|
+
'The cost per 1M input tokens in US dollars. '
|
180
|
+
'If None, this information is unknown.'
|
181
|
+
)
|
182
|
+
] = None
|
183
|
+
|
184
|
+
cost_per_1m_output_tokens: Annotated[
|
185
|
+
float | None,
|
186
|
+
(
|
187
|
+
'The cost per 1M output tokens in US dollars. '
|
188
|
+
'If None, this information is unknown.'
|
189
|
+
)
|
190
|
+
] = None
|
191
|
+
|
192
|
+
def estimate_cost(self, usage: 'LMSamplingUsage') -> float | None:
|
193
|
+
"""Estimates the cost of using the model. Subclass could override.
|
194
|
+
|
195
|
+
Args:
|
196
|
+
usage: The usage information of the model.
|
197
|
+
|
198
|
+
Returns:
|
199
|
+
The estimated cost in US dollars. If None, cost estimating is not
|
200
|
+
supported on the model.
|
201
|
+
"""
|
202
|
+
# NOTE(daiyip): supported cached tokens accounting in future.
|
203
|
+
if (self.cost_per_1m_input_tokens is None
|
204
|
+
or self.cost_per_1m_output_tokens is None):
|
205
|
+
return None
|
206
|
+
return (
|
207
|
+
self.cost_per_1m_input_tokens * usage.prompt_tokens
|
208
|
+
+ self.cost_per_1m_output_tokens * usage.completion_tokens
|
209
|
+
) / 1000_000
|
210
|
+
|
211
|
+
pricing: Annotated[
|
212
|
+
Pricing | None,
|
213
|
+
(
|
214
|
+
'Pricing information. If None, this information is unknown.'
|
215
|
+
)
|
216
|
+
] = None
|
217
|
+
|
218
|
+
#
|
219
|
+
# Rate limits.
|
220
|
+
#
|
221
|
+
|
222
|
+
class RateLimits(pg.Object):
|
223
|
+
"""Preset rate limits."""
|
224
|
+
|
225
|
+
max_requests_per_minute: Annotated[
|
226
|
+
int | None,
|
227
|
+
(
|
228
|
+
'The max number of requests per minute.'
|
229
|
+
'If None, there is no limit.'
|
230
|
+
)
|
231
|
+
] = None
|
232
|
+
|
233
|
+
max_tokens_per_minute: Annotated[
|
234
|
+
int | None,
|
235
|
+
(
|
236
|
+
'The max number of tokens per minute.'
|
237
|
+
'If None, there is no limit.'
|
238
|
+
)
|
239
|
+
] = None
|
240
|
+
|
241
|
+
rate_limits: Annotated[
|
242
|
+
RateLimits | None,
|
243
|
+
(
|
244
|
+
'Rate limits. If None, this information is unknown.'
|
245
|
+
)
|
246
|
+
] = None
|
247
|
+
|
248
|
+
#
|
249
|
+
# Additional information.
|
250
|
+
#
|
251
|
+
|
252
|
+
metadata: Annotated[
|
253
|
+
dict[str, Any],
|
254
|
+
(
|
255
|
+
'Model metadata. This could be used to store model-specific '
|
256
|
+
'information, which could be consumed by modules that need to '
|
257
|
+
'apply model-specific logic.'
|
258
|
+
),
|
259
|
+
] = {}
|
260
|
+
|
261
|
+
#
|
262
|
+
# Common model protocols.
|
263
|
+
#
|
264
|
+
|
265
|
+
@final
|
266
|
+
def estimate_cost(self, usage: 'LMSamplingUsage') -> float | None:
|
267
|
+
"""Estimates the cost of using the model."""
|
268
|
+
if self.pricing is None:
|
269
|
+
return None
|
270
|
+
return self.pricing.estimate_cost(usage)
|
271
|
+
|
272
|
+
def supports_input(self, mime_type: str) -> bool:
|
273
|
+
"""Returns True if an input MIME type is supported.
|
274
|
+
|
275
|
+
Subclass could override.
|
276
|
+
|
277
|
+
Args:
|
278
|
+
mime_type: The MIME type of the input.
|
279
|
+
|
280
|
+
Returns:
|
281
|
+
True if the input MIME type is supported.
|
282
|
+
"""
|
283
|
+
if self._input_modalities is None:
|
284
|
+
return True
|
285
|
+
return mime_type.lower() in self._input_modalities
|
286
|
+
|
287
|
+
@property
|
288
|
+
def resource_id(self) -> str:
|
289
|
+
"""Returns the resource ID of the LLM. Subclass could override."""
|
290
|
+
canonical_model_id = self.alias_for or self.model_id
|
291
|
+
if self.provider is None or not pg.is_deterministic(self.provider):
|
292
|
+
return canonical_model_id
|
293
|
+
provider = self.provider.lower().replace(' ', '_')
|
294
|
+
return f'{provider}://{canonical_model_id}'
|
295
|
+
|
296
|
+
def _on_bound(self):
|
297
|
+
super()._on_bound()
|
298
|
+
self._input_modalities = set(
|
299
|
+
[mime_type.lower() for mime_type in self.input_modalities]
|
300
|
+
) if self.input_modalities is not None else None
|
301
|
+
|
302
|
+
|
57
303
|
#
|
58
304
|
# Language model input/output interfaces.
|
59
305
|
#
|
@@ -88,12 +334,15 @@ class RetryStats(pg.Object):
|
|
88
334
|
int,
|
89
335
|
'Total number of retry attempts on LLM (excluding the first attempt).',
|
90
336
|
] = 0
|
337
|
+
|
91
338
|
total_wait_interval: Annotated[
|
92
339
|
float, 'Total wait interval in seconds due to retry.'
|
93
340
|
] = 0
|
341
|
+
|
94
342
|
total_call_interval: Annotated[
|
95
343
|
float, 'Total LLM call interval in seconds.'
|
96
344
|
] = 0
|
345
|
+
|
97
346
|
errors: Annotated[
|
98
347
|
dict[str, int],
|
99
348
|
'A Counter of error types encountered during the retry attempts.',
|
@@ -426,13 +675,12 @@ class LanguageModel(component.Component):
|
|
426
675
|
'Max concurrent requests being sent to the server. '
|
427
676
|
'If None, there is no limit. '
|
428
677
|
'Please note that the concurrency control is based on the '
|
429
|
-
'`resource_id` property, meaning that model instances shared '
|
678
|
+
'`info.resource_id` property, meaning that model instances shared '
|
430
679
|
'the same resource ID will be accounted under the same concurrency '
|
431
680
|
'control key. This allows a process-level concurrency control '
|
432
681
|
'for specific models regardless the number of LM (client) instances '
|
433
|
-
'created by the program. Subclasses could override
|
434
|
-
'
|
435
|
-
'concurrency control.'
|
682
|
+
'created by the program. Subclasses could override the '
|
683
|
+
'`max_concurrency` property to allow dynamic concurrency control.'
|
436
684
|
),
|
437
685
|
] = None
|
438
686
|
|
@@ -487,6 +735,40 @@ class LanguageModel(component.Component):
|
|
487
735
|
),
|
488
736
|
] = False
|
489
737
|
|
738
|
+
_MODEL_FACTORY: ClassVar[dict[str, Callable[..., 'LanguageModel']]] = {}
|
739
|
+
|
740
|
+
@classmethod
|
741
|
+
def register(
|
742
|
+
cls,
|
743
|
+
model_id_or_prefix: str, factory: Callable[..., 'LanguageModel']
|
744
|
+
) -> None:
|
745
|
+
"""Registers a factory function for a model ID."""
|
746
|
+
cls._MODEL_FACTORY[model_id_or_prefix] = factory
|
747
|
+
|
748
|
+
@classmethod
|
749
|
+
def get(cls, model_id: str, *args, **kwargs):
|
750
|
+
"""Creates a language model instance from a model ID."""
|
751
|
+
factory = cls._MODEL_FACTORY.get(model_id)
|
752
|
+
if factory is None:
|
753
|
+
factories = []
|
754
|
+
for k, v in cls._MODEL_FACTORY.items():
|
755
|
+
if re.match(k, model_id):
|
756
|
+
factories.append((k, v))
|
757
|
+
if not factories:
|
758
|
+
raise ValueError(f'Model not found: {model_id!r}.')
|
759
|
+
elif len(factories) > 1:
|
760
|
+
raise ValueError(
|
761
|
+
f'Multiple models found for {model_id!r}: '
|
762
|
+
f'{[x[0] for x in factories]}. '
|
763
|
+
'Please specify a more specific model ID.'
|
764
|
+
)
|
765
|
+
factory = factories[0][1]
|
766
|
+
return factory(model_id, *args, **kwargs)
|
767
|
+
|
768
|
+
@classmethod
|
769
|
+
def dir(cls):
|
770
|
+
return sorted(list(LanguageModel._MODEL_FACTORY.keys()))
|
771
|
+
|
490
772
|
@pg.explicit_method_override
|
491
773
|
def __init__(self, *args, **kwargs) -> None:
|
492
774
|
"""Overrides __init__ to pass through **kwargs to sampling options."""
|
@@ -512,16 +794,62 @@ class LanguageModel(component.Component):
|
|
512
794
|
def _on_bound(self):
|
513
795
|
super()._on_bound()
|
514
796
|
self._call_counter = 0
|
797
|
+
self.__dict__.pop('model_info', None)
|
798
|
+
|
799
|
+
@functools.cached_property
|
800
|
+
def model_info(self) -> ModelInfo:
|
801
|
+
"""Returns the specification of the model."""
|
802
|
+
return ModelInfo(model_id='unknown')
|
803
|
+
|
804
|
+
#
|
805
|
+
# Shortcut properties/methods from `model_info`.
|
806
|
+
# If these behaviors need to be changed, please override the corresponding
|
807
|
+
# methods in the ModelInfo subclasses instead of these properties/methods.
|
808
|
+
#
|
515
809
|
|
810
|
+
@final
|
516
811
|
@property
|
517
812
|
def model_id(self) -> str:
|
518
813
|
"""Returns a string to identify the model."""
|
519
|
-
return self.
|
814
|
+
return self.model_info.alias_for or self.model_info.model_id
|
520
815
|
|
816
|
+
@final
|
521
817
|
@property
|
522
818
|
def resource_id(self) -> str:
|
523
819
|
"""Resource ID for performing request parallism control."""
|
524
|
-
return self.
|
820
|
+
return self.model_info.resource_id
|
821
|
+
|
822
|
+
@final
|
823
|
+
@property
|
824
|
+
def context_length(self) -> ModelInfo.ContextLength | None:
|
825
|
+
"""Returns the context length of the model."""
|
826
|
+
return self.model_info.context_length
|
827
|
+
|
828
|
+
@final
|
829
|
+
@property
|
830
|
+
def pricing(self) -> ModelInfo.Pricing | None:
|
831
|
+
"""Returns the pricing information of the model."""
|
832
|
+
return self.model_info.pricing
|
833
|
+
|
834
|
+
@final
|
835
|
+
@property
|
836
|
+
def rate_limits(self) -> ModelInfo.RateLimits | None:
|
837
|
+
"""Returns the rate limits to the model."""
|
838
|
+
return self.model_info.rate_limits
|
839
|
+
|
840
|
+
@final
|
841
|
+
def supports_input(self, mime_type: str):
|
842
|
+
"""Returns True if an input type is supported. Subclasses can override."""
|
843
|
+
return self.model_info.supports_input(mime_type)
|
844
|
+
|
845
|
+
@final
|
846
|
+
def estimate_cost(self, usage: LMSamplingUsage) -> float | None:
|
847
|
+
"""Returns the estimated cost of a usage. Subclasses can override."""
|
848
|
+
return self.model_info.estimate_cost(usage)
|
849
|
+
|
850
|
+
#
|
851
|
+
# Language model operations.
|
852
|
+
#
|
525
853
|
|
526
854
|
def sample(
|
527
855
|
self,
|
@@ -554,10 +882,17 @@ class LanguageModel(component.Component):
|
|
554
882
|
response.metadata.logprobs = sample.logprobs
|
555
883
|
response.metadata.is_cached = result.is_cached
|
556
884
|
|
885
|
+
# Update estimated cost.
|
886
|
+
usage = result.usage
|
887
|
+
estimated_cost = self.estimate_cost(usage)
|
888
|
+
if estimated_cost is not None:
|
889
|
+
usage.rebind(
|
890
|
+
estimated_cost=estimated_cost, skip_notification=True
|
891
|
+
)
|
892
|
+
|
557
893
|
# NOTE(daiyip): Current usage is computed at per-result level,
|
558
894
|
# which is accurate when n=1. For n > 1, we average the usage across
|
559
895
|
# multiple samples.
|
560
|
-
usage = result.usage
|
561
896
|
if len(result.samples) == 1 or isinstance(usage, UsageNotAvailable):
|
562
897
|
response.metadata.usage = usage
|
563
898
|
else:
|
@@ -945,16 +1280,24 @@ class LanguageModel(component.Component):
|
|
945
1280
|
color='blue',
|
946
1281
|
)
|
947
1282
|
|
948
|
-
|
949
|
-
|
950
|
-
|
951
|
-
|
952
|
-
|
953
|
-
|
954
|
-
|
955
|
-
|
956
|
-
|
957
|
-
|
1283
|
+
@classmethod
|
1284
|
+
def estimate_max_concurrency(
|
1285
|
+
cls,
|
1286
|
+
max_tokens_per_minute: int | None,
|
1287
|
+
max_requests_per_minute: int | None,
|
1288
|
+
average_tokens_per_request: int = 250
|
1289
|
+
) -> int | None:
|
1290
|
+
"""Estimates max concurrency concurrency based on the rate limits."""
|
1291
|
+
# NOTE(daiyip): max concurrency is estimated based on the rate limit.
|
1292
|
+
# We assume each request has approximately 250 tokens, and each request
|
1293
|
+
# takes 1 second to complete. This might not be accurate for all models.
|
1294
|
+
if max_tokens_per_minute is not None:
|
1295
|
+
return max(
|
1296
|
+
int(max_tokens_per_minute / average_tokens_per_request / 60), 1
|
1297
|
+
)
|
1298
|
+
elif max_requests_per_minute is not None:
|
1299
|
+
return max(int(max_requests_per_minute / 60), 1)
|
1300
|
+
return None
|
958
1301
|
|
959
1302
|
|
960
1303
|
class UsageSummary(pg.Object, pg.views.HtmlTreeView.Extension):
|
@@ -9,7 +9,7 @@
|
|
9
9
|
# Unless required by applicable law or agreed to in writing, software
|
10
10
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the
|
12
|
+
# See the License for the infoific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
"""Tests for language model."""
|
15
15
|
|
@@ -30,6 +30,14 @@ class MockModel(lm_lib.LanguageModel):
|
|
30
30
|
failures_before_attempt: int = 0
|
31
31
|
name: str = 'MockModel'
|
32
32
|
|
33
|
+
class ModelInfo(lm_lib.ModelInfo):
|
34
|
+
def estimate_cost(self, usage: lm_lib.LMSamplingUsage) -> float | None:
|
35
|
+
return 1.0
|
36
|
+
|
37
|
+
@property
|
38
|
+
def model_info(self) -> lm_lib.ModelInfo:
|
39
|
+
return MockModel.ModelInfo(model_id=self.name)
|
40
|
+
|
33
41
|
def _sample(self,
|
34
42
|
prompts: list[message_lib.Message]
|
35
43
|
) -> list[lm_lib.LMSamplingResult]:
|
@@ -89,6 +97,73 @@ class MockTokenizeModel(MockModel):
|
|
89
97
|
return [(w, i) for i, w in enumerate(prompt.text.split(' '))]
|
90
98
|
|
91
99
|
|
100
|
+
class ModelInfoTest(unittest.TestCase):
|
101
|
+
"""Tests for ModelInfo."""
|
102
|
+
|
103
|
+
def test_basics(self):
|
104
|
+
info = lm_lib.ModelInfo(
|
105
|
+
model_id='model1_alias',
|
106
|
+
provider='Test Provider',
|
107
|
+
alias_for='model1'
|
108
|
+
)
|
109
|
+
self.assertEqual(info.model_id, 'model1_alias')
|
110
|
+
self.assertEqual(info.alias_for, 'model1')
|
111
|
+
self.assertEqual(info.model_type, 'unknown')
|
112
|
+
self.assertEqual(info.provider, 'Test Provider')
|
113
|
+
self.assertIsNone(info.description)
|
114
|
+
self.assertIsNone(info.url)
|
115
|
+
self.assertIsNone(info.release_date)
|
116
|
+
self.assertTrue(info.in_service)
|
117
|
+
self.assertIsNone(info.context_length)
|
118
|
+
self.assertIsNone(info.pricing)
|
119
|
+
self.assertIsNone(info.rate_limits)
|
120
|
+
self.assertIsNone(info.input_modalities)
|
121
|
+
|
122
|
+
def test_resource_id(self):
|
123
|
+
info = lm_lib.ModelInfo(
|
124
|
+
model_id='model1',
|
125
|
+
)
|
126
|
+
self.assertEqual(info.resource_id, 'model1')
|
127
|
+
info = lm_lib.ModelInfo(
|
128
|
+
model_id='model1_alias',
|
129
|
+
provider='Test Provider',
|
130
|
+
alias_for='model1'
|
131
|
+
)
|
132
|
+
self.assertEqual(info.resource_id, 'test_provider://model1')
|
133
|
+
info = lm_lib.ModelInfo(
|
134
|
+
model_id='model1_alias',
|
135
|
+
provider=pg.oneof(['Provider1', 'Provider2']),
|
136
|
+
alias_for='model1'
|
137
|
+
)
|
138
|
+
self.assertEqual(info.resource_id, 'model1')
|
139
|
+
|
140
|
+
def test_estimate_cost(self):
|
141
|
+
self.assertIsNone(
|
142
|
+
lm_lib.ModelInfo('unknown').estimate_cost(
|
143
|
+
lm_lib.LMSamplingUsage(100, 100, 200, 1)
|
144
|
+
)
|
145
|
+
)
|
146
|
+
self.assertIsNone(
|
147
|
+
lm_lib.ModelInfo(
|
148
|
+
'unknown', pricing=lm_lib.ModelInfo.Pricing()
|
149
|
+
).estimate_cost(
|
150
|
+
lm_lib.LMSamplingUsage(100, 100, 200, 1)
|
151
|
+
)
|
152
|
+
)
|
153
|
+
self.assertEqual(
|
154
|
+
lm_lib.ModelInfo(
|
155
|
+
'unknown', pricing=lm_lib.ModelInfo.Pricing(
|
156
|
+
cost_per_1m_input_tokens=1.0,
|
157
|
+
cost_per_1m_output_tokens=2.0,
|
158
|
+
cost_per_1m_cached_input_tokens=1.0,
|
159
|
+
)
|
160
|
+
).estimate_cost(
|
161
|
+
lm_lib.LMSamplingUsage(100, 100, 200, 1)
|
162
|
+
),
|
163
|
+
0.0003
|
164
|
+
)
|
165
|
+
|
166
|
+
|
92
167
|
class LMSamplingOptionsTest(unittest.TestCase):
|
93
168
|
"""Tests for LMSamplingOptions."""
|
94
169
|
|
@@ -107,15 +182,49 @@ class LMSamplingOptionsTest(unittest.TestCase):
|
|
107
182
|
class LanguageModelTest(unittest.TestCase):
|
108
183
|
"""Tests for LanguageModel."""
|
109
184
|
|
110
|
-
def
|
111
|
-
|
185
|
+
def test_register_and_get(self):
|
186
|
+
def mock_model(model_id: str, *args, **kwargs):
|
187
|
+
del model_id
|
188
|
+
return MockModel(*args, **kwargs)
|
189
|
+
|
190
|
+
lm_lib.LanguageModel.register('MockModel', mock_model)
|
191
|
+
lm = lm_lib.LanguageModel.get(
|
192
|
+
'MockModel', 1, temperature=0.2
|
193
|
+
)
|
112
194
|
self.assertEqual(lm.model_id, 'MockModel')
|
113
195
|
self.assertEqual(lm.resource_id, 'MockModel')
|
196
|
+
self.assertEqual(lm.failures_before_attempt, 1)
|
197
|
+
self.assertEqual(lm.sampling_options.temperature, 0.2)
|
198
|
+
self.assertIn('MockModel', lm_lib.LanguageModel.dir())
|
199
|
+
|
200
|
+
lm_lib.LanguageModel.register('mock://.*', mock_model)
|
201
|
+
lm_lib.LanguageModel.register('mock.*', mock_model)
|
202
|
+
self.assertIsInstance(lm_lib.LanguageModel.get('mock'), MockModel)
|
203
|
+
with self.assertRaisesRegex(ValueError, 'Multiple models found'):
|
204
|
+
lm_lib.LanguageModel.get('mock://test2')
|
205
|
+
|
206
|
+
with self.assertRaisesRegex(ValueError, 'Model not found'):
|
207
|
+
lm_lib.LanguageModel.get('non-existent://test2')
|
208
|
+
|
209
|
+
def test_basics(self):
|
210
|
+
lm = MockModel(1, temperature=0.5, top_k=2, max_attempts=2)
|
211
|
+
self.assertEqual(
|
212
|
+
lm.model_info, MockModel.ModelInfo(model_id='MockModel')
|
213
|
+
)
|
214
|
+
self.assertEqual(lm.model_id, 'MockModel')
|
114
215
|
self.assertIsNone(lm.max_concurrency)
|
115
216
|
self.assertEqual(lm.failures_before_attempt, 1)
|
116
217
|
self.assertEqual(lm.sampling_options.temperature, 0.5)
|
117
218
|
self.assertEqual(lm.sampling_options.top_k, 2)
|
118
219
|
self.assertEqual(lm.max_attempts, 2)
|
220
|
+
self.assertIsNone(lm.context_length)
|
221
|
+
self.assertIsNone(lm.pricing)
|
222
|
+
self.assertIsNone(lm.rate_limits)
|
223
|
+
self.assertTrue(lm.supports_input('image/png'))
|
224
|
+
self.assertEqual(
|
225
|
+
lm.estimate_cost(lm_lib.LMSamplingUsage(100, 100, 200, 1)),
|
226
|
+
1.0
|
227
|
+
)
|
119
228
|
|
120
229
|
def test_subclassing(self):
|
121
230
|
|
@@ -446,6 +555,17 @@ class LanguageModelTest(unittest.TestCase):
|
|
446
555
|
lm = MockModel(cache=cache, top_k=1)
|
447
556
|
self.assertEqual(lm('a'), 'a')
|
448
557
|
|
558
|
+
def test_estimate_max_concurrency(self):
|
559
|
+
self.assertIsNone(lm_lib.LanguageModel.estimate_max_concurrency(None, None))
|
560
|
+
self.assertEqual(
|
561
|
+
lm_lib.LanguageModel.estimate_max_concurrency(250 * 60 * 10, None),
|
562
|
+
10
|
563
|
+
)
|
564
|
+
self.assertEqual(
|
565
|
+
lm_lib.LanguageModel.estimate_max_concurrency(None, 60 * 10),
|
566
|
+
10
|
567
|
+
)
|
568
|
+
|
449
569
|
def test_retry(self):
|
450
570
|
lm = MockModel(
|
451
571
|
failures_before_attempt=1, top_k=1, max_attempts=2, retry_interval=1
|
@@ -692,38 +812,6 @@ class LanguageModelTest(unittest.TestCase):
|
|
692
812
|
with self.assertRaises(NotImplementedError):
|
693
813
|
MockModel().tokenize('hi')
|
694
814
|
|
695
|
-
def test_rate_to_max_concurrency_no_rpm_no_tpm(self) -> None:
|
696
|
-
lm = MockModel()
|
697
|
-
self.assertEqual(
|
698
|
-
lm_lib.DEFAULT_MAX_CONCURRENCY,
|
699
|
-
lm.rate_to_max_concurrency(requests_per_min=0, tokens_per_min=0),
|
700
|
-
)
|
701
|
-
self.assertEqual(
|
702
|
-
lm_lib.DEFAULT_MAX_CONCURRENCY,
|
703
|
-
lm.rate_to_max_concurrency(requests_per_min=-1, tokens_per_min=-1),
|
704
|
-
)
|
705
|
-
|
706
|
-
def test_rate_to_max_concurrency_only_rpm_specified_uses_rpm(self) -> None:
|
707
|
-
lm = MockModel()
|
708
|
-
test_rpm = 1e4
|
709
|
-
self.assertEqual(
|
710
|
-
lm.rate_to_max_concurrency(requests_per_min=test_rpm),
|
711
|
-
int(test_rpm / 60)
|
712
|
-
)
|
713
|
-
|
714
|
-
def test_rate_to_max_concurrency_tpm_specified_uses_tpm(self) -> None:
|
715
|
-
lm = MockModel()
|
716
|
-
test_tpm = 1e7
|
717
|
-
self.assertEqual(
|
718
|
-
lm.rate_to_max_concurrency(requests_per_min=1, tokens_per_min=test_tpm),
|
719
|
-
int(test_tpm / lm_lib.TOKENS_PER_REQUEST / 60),
|
720
|
-
)
|
721
|
-
|
722
|
-
def test_rate_to_max_concurrency_small_rate_returns_one(self) -> None:
|
723
|
-
lm = MockModel()
|
724
|
-
self.assertEqual(lm.rate_to_max_concurrency(requests_per_min=1), 1)
|
725
|
-
self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)
|
726
|
-
|
727
815
|
def test_track_usages(self):
|
728
816
|
lm = MockModel(name='model1')
|
729
817
|
lm2 = MockModel(name='model2')
|