langfun 0.1.2.dev202502110804__py3-none-any.whl → 0.1.2.dev202502130804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langfun/core/__init__.py CHANGED
@@ -93,14 +93,18 @@ from langfun.core.modality import ModalityRef
93
93
  from langfun.core.modality import ModalityError
94
94
 
95
95
  # Interfaces for languge models.
96
+ from langfun.core.language_model import ModelInfo
96
97
  from langfun.core.language_model import LanguageModel
98
+
97
99
  from langfun.core.language_model import LMSample
98
100
  from langfun.core.language_model import LMSamplingOptions
101
+ from langfun.core.language_model import LMSamplingResult
102
+ from langfun.core.language_model import LMScoringResult
103
+
99
104
  from langfun.core.language_model import LMSamplingUsage
100
105
  from langfun.core.language_model import UsageNotAvailable
101
106
  from langfun.core.language_model import UsageSummary
102
- from langfun.core.language_model import LMSamplingResult
103
- from langfun.core.language_model import LMScoringResult
107
+
104
108
  from langfun.core.language_model import LMCache
105
109
  from langfun.core.language_model import LMDebugMode
106
110
 
@@ -16,12 +16,14 @@
16
16
  import abc
17
17
  import contextlib
18
18
  import dataclasses
19
+ import datetime
19
20
  import enum
20
21
  import functools
21
22
  import math
23
+ import re
22
24
  import threading
23
25
  import time
24
- from typing import Annotated, Any, Callable, Iterator, Optional, Sequence, Tuple, Type, Union
26
+ from typing import Annotated, Any, Callable, ClassVar, Iterator, Literal, Optional, Sequence, Tuple, Type, Union, final
25
27
  from langfun.core import component
26
28
  from langfun.core import concurrent
27
29
  from langfun.core import console
@@ -29,10 +31,6 @@ from langfun.core import message as message_lib
29
31
 
30
32
  import pyglove as pg
31
33
 
32
- TOKENS_PER_REQUEST = 250 # Estimated num tokens for a single request
33
- DEFAULT_MAX_CONCURRENCY = 1 # Use this as max concurrency if no RPM or TPM data
34
-
35
-
36
34
  #
37
35
  # Common errors during calling language models.
38
36
  #
@@ -54,6 +52,254 @@ class TemporaryLMError(RetryableLMError):
54
52
  """Error for temporary service issues that can be retried."""
55
53
 
56
54
 
55
+ #
56
+ # Language model information.
57
+ #
58
+
59
+
60
+ class ModelInfo(pg.Object):
61
+ """Common information for a language model."""
62
+
63
+ # Constant for modalities.
64
+ TEXT_INPUT_ONLY = []
65
+
66
+ model_id: Annotated[
67
+ str,
68
+ 'A global unique identifier of the language model. ',
69
+ ]
70
+
71
+ alias_for: Annotated[
72
+ str | None,
73
+ 'The fixed-version model ID that this model is aliased for.',
74
+ ] = None
75
+
76
+ #
77
+ # Basic information.
78
+ #
79
+
80
+ model_type: Annotated[
81
+ Literal['unknown', 'pretrained', 'instruction-tuned', 'thinking'],
82
+ 'The type of the model.'
83
+ ] = 'unknown'
84
+
85
+ provider: Annotated[
86
+ str | None,
87
+ (
88
+ 'The service provider (host) of the LLM. E.g. VertexAI, Microsoft, '
89
+ 'etc.'
90
+ )
91
+ ] = None
92
+
93
+ description: Annotated[
94
+ str | None,
95
+ 'An optional description of the model family.'
96
+ ] = None
97
+
98
+ url: Annotated[
99
+ str | None,
100
+ 'The URL of the model.'
101
+ ] = None
102
+
103
+ release_date: Annotated[
104
+ datetime.date | None,
105
+ 'The release date of the model. '
106
+ ] = None
107
+
108
+ in_service: Annotated[
109
+ bool,
110
+ 'If True, the model is in service.'
111
+ ] = True
112
+
113
+ #
114
+ # LLM capabilities.
115
+ #
116
+
117
+ input_modalities: Annotated[
118
+ list[str] | None,
119
+ (
120
+ 'Supported MIME types as model inputs. '
121
+ 'If None, this information is unknown, so Langfun allows all '
122
+ 'modalities from the input to be passed to the model.'
123
+ )
124
+ ] = None
125
+
126
+ class ContextLength(pg.Object):
127
+ """Context length information."""
128
+
129
+ max_input_tokens: Annotated[
130
+ int | None,
131
+ (
132
+ 'The maximum number of input tokens of the language model. '
133
+ 'If None, there is no limit or this information is unknown.'
134
+ )
135
+ ] = None
136
+
137
+ max_output_tokens: Annotated[
138
+ int | None,
139
+ (
140
+ 'The maximum number of output tokens of the language model. '
141
+ 'If None, there is no limit or this information is unknown.'
142
+ )
143
+ ] = None
144
+
145
+ max_cot_tokens: Annotated[
146
+ int | None,
147
+ (
148
+ 'The maximum number of Chain-of-Thought tokens to generate. '
149
+ 'If None, there is not limit or not applicable.'
150
+ )
151
+ ] = None
152
+
153
+ context_length: Annotated[
154
+ ContextLength | None,
155
+ (
156
+ 'Context length information of the model. '
157
+ 'If None, this information is unknown.'
158
+ )
159
+ ] = None
160
+
161
+ #
162
+ # Common pricing information.
163
+ #
164
+
165
+ class Pricing(pg.Object):
166
+ """Pricing information."""
167
+
168
+ cost_per_1m_cached_input_tokens: Annotated[
169
+ float | None,
170
+ (
171
+ 'The cost per 1M cached input tokens in US dollars. '
172
+ 'If None, this information is unknown.'
173
+ )
174
+ ] = None
175
+
176
+ cost_per_1m_input_tokens: Annotated[
177
+ float | None,
178
+ (
179
+ 'The cost per 1M input tokens in US dollars. '
180
+ 'If None, this information is unknown.'
181
+ )
182
+ ] = None
183
+
184
+ cost_per_1m_output_tokens: Annotated[
185
+ float | None,
186
+ (
187
+ 'The cost per 1M output tokens in US dollars. '
188
+ 'If None, this information is unknown.'
189
+ )
190
+ ] = None
191
+
192
+ def estimate_cost(self, usage: 'LMSamplingUsage') -> float | None:
193
+ """Estimates the cost of using the model. Subclass could override.
194
+
195
+ Args:
196
+ usage: The usage information of the model.
197
+
198
+ Returns:
199
+ The estimated cost in US dollars. If None, cost estimating is not
200
+ supported on the model.
201
+ """
202
+ # NOTE(daiyip): supported cached tokens accounting in future.
203
+ if (self.cost_per_1m_input_tokens is None
204
+ or self.cost_per_1m_output_tokens is None):
205
+ return None
206
+ return (
207
+ self.cost_per_1m_input_tokens * usage.prompt_tokens
208
+ + self.cost_per_1m_output_tokens * usage.completion_tokens
209
+ ) / 1000_000
210
+
211
+ pricing: Annotated[
212
+ Pricing | None,
213
+ (
214
+ 'Pricing information. If None, this information is unknown.'
215
+ )
216
+ ] = None
217
+
218
+ #
219
+ # Rate limits.
220
+ #
221
+
222
+ class RateLimits(pg.Object):
223
+ """Preset rate limits."""
224
+
225
+ max_requests_per_minute: Annotated[
226
+ int | None,
227
+ (
228
+ 'The max number of requests per minute.'
229
+ 'If None, there is no limit.'
230
+ )
231
+ ] = None
232
+
233
+ max_tokens_per_minute: Annotated[
234
+ int | None,
235
+ (
236
+ 'The max number of tokens per minute.'
237
+ 'If None, there is no limit.'
238
+ )
239
+ ] = None
240
+
241
+ rate_limits: Annotated[
242
+ RateLimits | None,
243
+ (
244
+ 'Rate limits. If None, this information is unknown.'
245
+ )
246
+ ] = None
247
+
248
+ #
249
+ # Additional information.
250
+ #
251
+
252
+ metadata: Annotated[
253
+ dict[str, Any],
254
+ (
255
+ 'Model metadata. This could be used to store model-specific '
256
+ 'information, which could be consumed by modules that need to '
257
+ 'apply model-specific logic.'
258
+ ),
259
+ ] = {}
260
+
261
+ #
262
+ # Common model protocols.
263
+ #
264
+
265
+ @final
266
+ def estimate_cost(self, usage: 'LMSamplingUsage') -> float | None:
267
+ """Estimates the cost of using the model."""
268
+ if self.pricing is None:
269
+ return None
270
+ return self.pricing.estimate_cost(usage)
271
+
272
+ def supports_input(self, mime_type: str) -> bool:
273
+ """Returns True if an input MIME type is supported.
274
+
275
+ Subclass could override.
276
+
277
+ Args:
278
+ mime_type: The MIME type of the input.
279
+
280
+ Returns:
281
+ True if the input MIME type is supported.
282
+ """
283
+ if self._input_modalities is None:
284
+ return True
285
+ return mime_type.lower() in self._input_modalities
286
+
287
+ @property
288
+ def resource_id(self) -> str:
289
+ """Returns the resource ID of the LLM. Subclass could override."""
290
+ canonical_model_id = self.alias_for or self.model_id
291
+ if self.provider is None or not pg.is_deterministic(self.provider):
292
+ return canonical_model_id
293
+ provider = self.provider.lower().replace(' ', '_')
294
+ return f'{provider}://{canonical_model_id}'
295
+
296
+ def _on_bound(self):
297
+ super()._on_bound()
298
+ self._input_modalities = set(
299
+ [mime_type.lower() for mime_type in self.input_modalities]
300
+ ) if self.input_modalities is not None else None
301
+
302
+
57
303
  #
58
304
  # Language model input/output interfaces.
59
305
  #
@@ -88,12 +334,15 @@ class RetryStats(pg.Object):
88
334
  int,
89
335
  'Total number of retry attempts on LLM (excluding the first attempt).',
90
336
  ] = 0
337
+
91
338
  total_wait_interval: Annotated[
92
339
  float, 'Total wait interval in seconds due to retry.'
93
340
  ] = 0
341
+
94
342
  total_call_interval: Annotated[
95
343
  float, 'Total LLM call interval in seconds.'
96
344
  ] = 0
345
+
97
346
  errors: Annotated[
98
347
  dict[str, int],
99
348
  'A Counter of error types encountered during the retry attempts.',
@@ -426,13 +675,12 @@ class LanguageModel(component.Component):
426
675
  'Max concurrent requests being sent to the server. '
427
676
  'If None, there is no limit. '
428
677
  'Please note that the concurrency control is based on the '
429
- '`resource_id` property, meaning that model instances shared '
678
+ '`info.resource_id` property, meaning that model instances shared '
430
679
  'the same resource ID will be accounted under the same concurrency '
431
680
  'control key. This allows a process-level concurrency control '
432
681
  'for specific models regardless the number of LM (client) instances '
433
- 'created by the program. Subclasses could override this number or '
434
- 'replace it with a `max_concurrency` property to allow dynamic '
435
- 'concurrency control.'
682
+ 'created by the program. Subclasses could override the '
683
+ '`max_concurrency` property to allow dynamic concurrency control.'
436
684
  ),
437
685
  ] = None
438
686
 
@@ -487,6 +735,40 @@ class LanguageModel(component.Component):
487
735
  ),
488
736
  ] = False
489
737
 
738
+ _MODEL_FACTORY: ClassVar[dict[str, Callable[..., 'LanguageModel']]] = {}
739
+
740
+ @classmethod
741
+ def register(
742
+ cls,
743
+ model_id_or_prefix: str, factory: Callable[..., 'LanguageModel']
744
+ ) -> None:
745
+ """Registers a factory function for a model ID."""
746
+ cls._MODEL_FACTORY[model_id_or_prefix] = factory
747
+
748
+ @classmethod
749
+ def get(cls, model_id: str, *args, **kwargs):
750
+ """Creates a language model instance from a model ID."""
751
+ factory = cls._MODEL_FACTORY.get(model_id)
752
+ if factory is None:
753
+ factories = []
754
+ for k, v in cls._MODEL_FACTORY.items():
755
+ if re.match(k, model_id):
756
+ factories.append((k, v))
757
+ if not factories:
758
+ raise ValueError(f'Model not found: {model_id!r}.')
759
+ elif len(factories) > 1:
760
+ raise ValueError(
761
+ f'Multiple models found for {model_id!r}: '
762
+ f'{[x[0] for x in factories]}. '
763
+ 'Please specify a more specific model ID.'
764
+ )
765
+ factory = factories[0][1]
766
+ return factory(model_id, *args, **kwargs)
767
+
768
+ @classmethod
769
+ def dir(cls):
770
+ return sorted(list(LanguageModel._MODEL_FACTORY.keys()))
771
+
490
772
  @pg.explicit_method_override
491
773
  def __init__(self, *args, **kwargs) -> None:
492
774
  """Overrides __init__ to pass through **kwargs to sampling options."""
@@ -512,16 +794,62 @@ class LanguageModel(component.Component):
512
794
  def _on_bound(self):
513
795
  super()._on_bound()
514
796
  self._call_counter = 0
797
+ self.__dict__.pop('model_info', None)
798
+
799
+ @functools.cached_property
800
+ def model_info(self) -> ModelInfo:
801
+ """Returns the specification of the model."""
802
+ return ModelInfo(model_id='unknown')
803
+
804
+ #
805
+ # Shortcut properties/methods from `model_info`.
806
+ # If these behaviors need to be changed, please override the corresponding
807
+ # methods in the ModelInfo subclasses instead of these properties/methods.
808
+ #
515
809
 
810
+ @final
516
811
  @property
517
812
  def model_id(self) -> str:
518
813
  """Returns a string to identify the model."""
519
- return self.__class__.__name__
814
+ return self.model_info.alias_for or self.model_info.model_id
520
815
 
816
+ @final
521
817
  @property
522
818
  def resource_id(self) -> str:
523
819
  """Resource ID for performing request parallism control."""
524
- return self.model_id
820
+ return self.model_info.resource_id
821
+
822
+ @final
823
+ @property
824
+ def context_length(self) -> ModelInfo.ContextLength | None:
825
+ """Returns the context length of the model."""
826
+ return self.model_info.context_length
827
+
828
+ @final
829
+ @property
830
+ def pricing(self) -> ModelInfo.Pricing | None:
831
+ """Returns the pricing information of the model."""
832
+ return self.model_info.pricing
833
+
834
+ @final
835
+ @property
836
+ def rate_limits(self) -> ModelInfo.RateLimits | None:
837
+ """Returns the rate limits to the model."""
838
+ return self.model_info.rate_limits
839
+
840
+ @final
841
+ def supports_input(self, mime_type: str):
842
+ """Returns True if an input type is supported. Subclasses can override."""
843
+ return self.model_info.supports_input(mime_type)
844
+
845
+ @final
846
+ def estimate_cost(self, usage: LMSamplingUsage) -> float | None:
847
+ """Returns the estimated cost of a usage. Subclasses can override."""
848
+ return self.model_info.estimate_cost(usage)
849
+
850
+ #
851
+ # Language model operations.
852
+ #
525
853
 
526
854
  def sample(
527
855
  self,
@@ -554,10 +882,17 @@ class LanguageModel(component.Component):
554
882
  response.metadata.logprobs = sample.logprobs
555
883
  response.metadata.is_cached = result.is_cached
556
884
 
885
+ # Update estimated cost.
886
+ usage = result.usage
887
+ estimated_cost = self.estimate_cost(usage)
888
+ if estimated_cost is not None:
889
+ usage.rebind(
890
+ estimated_cost=estimated_cost, skip_notification=True
891
+ )
892
+
557
893
  # NOTE(daiyip): Current usage is computed at per-result level,
558
894
  # which is accurate when n=1. For n > 1, we average the usage across
559
895
  # multiple samples.
560
- usage = result.usage
561
896
  if len(result.samples) == 1 or isinstance(usage, UsageNotAvailable):
562
897
  response.metadata.usage = usage
563
898
  else:
@@ -945,16 +1280,24 @@ class LanguageModel(component.Component):
945
1280
  color='blue',
946
1281
  )
947
1282
 
948
- def rate_to_max_concurrency(
949
- self, requests_per_min: float = 0, tokens_per_min: float = 0
950
- ) -> int:
951
- """Converts a rate to a max concurrency."""
952
- if tokens_per_min > 0:
953
- return max(int(tokens_per_min / TOKENS_PER_REQUEST / 60), 1)
954
- elif requests_per_min > 0:
955
- return max(int(requests_per_min / 60), 1) # Max concurrency can't be zero
956
- else:
957
- return DEFAULT_MAX_CONCURRENCY # Default of 1
1283
+ @classmethod
1284
+ def estimate_max_concurrency(
1285
+ cls,
1286
+ max_tokens_per_minute: int | None,
1287
+ max_requests_per_minute: int | None,
1288
+ average_tokens_per_request: int = 250
1289
+ ) -> int | None:
1290
+ """Estimates max concurrency concurrency based on the rate limits."""
1291
+ # NOTE(daiyip): max concurrency is estimated based on the rate limit.
1292
+ # We assume each request has approximately 250 tokens, and each request
1293
+ # takes 1 second to complete. This might not be accurate for all models.
1294
+ if max_tokens_per_minute is not None:
1295
+ return max(
1296
+ int(max_tokens_per_minute / average_tokens_per_request / 60), 1
1297
+ )
1298
+ elif max_requests_per_minute is not None:
1299
+ return max(int(max_requests_per_minute / 60), 1)
1300
+ return None
958
1301
 
959
1302
 
960
1303
  class UsageSummary(pg.Object, pg.views.HtmlTreeView.Extension):
@@ -9,7 +9,7 @@
9
9
  # Unless required by applicable law or agreed to in writing, software
10
10
  # distributed under the License is distributed on an "AS IS" BASIS,
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
12
+ # See the License for the infoific language governing permissions and
13
13
  # limitations under the License.
14
14
  """Tests for language model."""
15
15
 
@@ -30,6 +30,14 @@ class MockModel(lm_lib.LanguageModel):
30
30
  failures_before_attempt: int = 0
31
31
  name: str = 'MockModel'
32
32
 
33
+ class ModelInfo(lm_lib.ModelInfo):
34
+ def estimate_cost(self, usage: lm_lib.LMSamplingUsage) -> float | None:
35
+ return 1.0
36
+
37
+ @property
38
+ def model_info(self) -> lm_lib.ModelInfo:
39
+ return MockModel.ModelInfo(model_id=self.name)
40
+
33
41
  def _sample(self,
34
42
  prompts: list[message_lib.Message]
35
43
  ) -> list[lm_lib.LMSamplingResult]:
@@ -89,6 +97,73 @@ class MockTokenizeModel(MockModel):
89
97
  return [(w, i) for i, w in enumerate(prompt.text.split(' '))]
90
98
 
91
99
 
100
+ class ModelInfoTest(unittest.TestCase):
101
+ """Tests for ModelInfo."""
102
+
103
+ def test_basics(self):
104
+ info = lm_lib.ModelInfo(
105
+ model_id='model1_alias',
106
+ provider='Test Provider',
107
+ alias_for='model1'
108
+ )
109
+ self.assertEqual(info.model_id, 'model1_alias')
110
+ self.assertEqual(info.alias_for, 'model1')
111
+ self.assertEqual(info.model_type, 'unknown')
112
+ self.assertEqual(info.provider, 'Test Provider')
113
+ self.assertIsNone(info.description)
114
+ self.assertIsNone(info.url)
115
+ self.assertIsNone(info.release_date)
116
+ self.assertTrue(info.in_service)
117
+ self.assertIsNone(info.context_length)
118
+ self.assertIsNone(info.pricing)
119
+ self.assertIsNone(info.rate_limits)
120
+ self.assertIsNone(info.input_modalities)
121
+
122
+ def test_resource_id(self):
123
+ info = lm_lib.ModelInfo(
124
+ model_id='model1',
125
+ )
126
+ self.assertEqual(info.resource_id, 'model1')
127
+ info = lm_lib.ModelInfo(
128
+ model_id='model1_alias',
129
+ provider='Test Provider',
130
+ alias_for='model1'
131
+ )
132
+ self.assertEqual(info.resource_id, 'test_provider://model1')
133
+ info = lm_lib.ModelInfo(
134
+ model_id='model1_alias',
135
+ provider=pg.oneof(['Provider1', 'Provider2']),
136
+ alias_for='model1'
137
+ )
138
+ self.assertEqual(info.resource_id, 'model1')
139
+
140
+ def test_estimate_cost(self):
141
+ self.assertIsNone(
142
+ lm_lib.ModelInfo('unknown').estimate_cost(
143
+ lm_lib.LMSamplingUsage(100, 100, 200, 1)
144
+ )
145
+ )
146
+ self.assertIsNone(
147
+ lm_lib.ModelInfo(
148
+ 'unknown', pricing=lm_lib.ModelInfo.Pricing()
149
+ ).estimate_cost(
150
+ lm_lib.LMSamplingUsage(100, 100, 200, 1)
151
+ )
152
+ )
153
+ self.assertEqual(
154
+ lm_lib.ModelInfo(
155
+ 'unknown', pricing=lm_lib.ModelInfo.Pricing(
156
+ cost_per_1m_input_tokens=1.0,
157
+ cost_per_1m_output_tokens=2.0,
158
+ cost_per_1m_cached_input_tokens=1.0,
159
+ )
160
+ ).estimate_cost(
161
+ lm_lib.LMSamplingUsage(100, 100, 200, 1)
162
+ ),
163
+ 0.0003
164
+ )
165
+
166
+
92
167
  class LMSamplingOptionsTest(unittest.TestCase):
93
168
  """Tests for LMSamplingOptions."""
94
169
 
@@ -107,15 +182,49 @@ class LMSamplingOptionsTest(unittest.TestCase):
107
182
  class LanguageModelTest(unittest.TestCase):
108
183
  """Tests for LanguageModel."""
109
184
 
110
- def test_init(self):
111
- lm = MockModel(1, temperature=0.5, top_k=2, max_attempts=2)
185
+ def test_register_and_get(self):
186
+ def mock_model(model_id: str, *args, **kwargs):
187
+ del model_id
188
+ return MockModel(*args, **kwargs)
189
+
190
+ lm_lib.LanguageModel.register('MockModel', mock_model)
191
+ lm = lm_lib.LanguageModel.get(
192
+ 'MockModel', 1, temperature=0.2
193
+ )
112
194
  self.assertEqual(lm.model_id, 'MockModel')
113
195
  self.assertEqual(lm.resource_id, 'MockModel')
196
+ self.assertEqual(lm.failures_before_attempt, 1)
197
+ self.assertEqual(lm.sampling_options.temperature, 0.2)
198
+ self.assertIn('MockModel', lm_lib.LanguageModel.dir())
199
+
200
+ lm_lib.LanguageModel.register('mock://.*', mock_model)
201
+ lm_lib.LanguageModel.register('mock.*', mock_model)
202
+ self.assertIsInstance(lm_lib.LanguageModel.get('mock'), MockModel)
203
+ with self.assertRaisesRegex(ValueError, 'Multiple models found'):
204
+ lm_lib.LanguageModel.get('mock://test2')
205
+
206
+ with self.assertRaisesRegex(ValueError, 'Model not found'):
207
+ lm_lib.LanguageModel.get('non-existent://test2')
208
+
209
+ def test_basics(self):
210
+ lm = MockModel(1, temperature=0.5, top_k=2, max_attempts=2)
211
+ self.assertEqual(
212
+ lm.model_info, MockModel.ModelInfo(model_id='MockModel')
213
+ )
214
+ self.assertEqual(lm.model_id, 'MockModel')
114
215
  self.assertIsNone(lm.max_concurrency)
115
216
  self.assertEqual(lm.failures_before_attempt, 1)
116
217
  self.assertEqual(lm.sampling_options.temperature, 0.5)
117
218
  self.assertEqual(lm.sampling_options.top_k, 2)
118
219
  self.assertEqual(lm.max_attempts, 2)
220
+ self.assertIsNone(lm.context_length)
221
+ self.assertIsNone(lm.pricing)
222
+ self.assertIsNone(lm.rate_limits)
223
+ self.assertTrue(lm.supports_input('image/png'))
224
+ self.assertEqual(
225
+ lm.estimate_cost(lm_lib.LMSamplingUsage(100, 100, 200, 1)),
226
+ 1.0
227
+ )
119
228
 
120
229
  def test_subclassing(self):
121
230
 
@@ -446,6 +555,17 @@ class LanguageModelTest(unittest.TestCase):
446
555
  lm = MockModel(cache=cache, top_k=1)
447
556
  self.assertEqual(lm('a'), 'a')
448
557
 
558
+ def test_estimate_max_concurrency(self):
559
+ self.assertIsNone(lm_lib.LanguageModel.estimate_max_concurrency(None, None))
560
+ self.assertEqual(
561
+ lm_lib.LanguageModel.estimate_max_concurrency(250 * 60 * 10, None),
562
+ 10
563
+ )
564
+ self.assertEqual(
565
+ lm_lib.LanguageModel.estimate_max_concurrency(None, 60 * 10),
566
+ 10
567
+ )
568
+
449
569
  def test_retry(self):
450
570
  lm = MockModel(
451
571
  failures_before_attempt=1, top_k=1, max_attempts=2, retry_interval=1
@@ -692,38 +812,6 @@ class LanguageModelTest(unittest.TestCase):
692
812
  with self.assertRaises(NotImplementedError):
693
813
  MockModel().tokenize('hi')
694
814
 
695
- def test_rate_to_max_concurrency_no_rpm_no_tpm(self) -> None:
696
- lm = MockModel()
697
- self.assertEqual(
698
- lm_lib.DEFAULT_MAX_CONCURRENCY,
699
- lm.rate_to_max_concurrency(requests_per_min=0, tokens_per_min=0),
700
- )
701
- self.assertEqual(
702
- lm_lib.DEFAULT_MAX_CONCURRENCY,
703
- lm.rate_to_max_concurrency(requests_per_min=-1, tokens_per_min=-1),
704
- )
705
-
706
- def test_rate_to_max_concurrency_only_rpm_specified_uses_rpm(self) -> None:
707
- lm = MockModel()
708
- test_rpm = 1e4
709
- self.assertEqual(
710
- lm.rate_to_max_concurrency(requests_per_min=test_rpm),
711
- int(test_rpm / 60)
712
- )
713
-
714
- def test_rate_to_max_concurrency_tpm_specified_uses_tpm(self) -> None:
715
- lm = MockModel()
716
- test_tpm = 1e7
717
- self.assertEqual(
718
- lm.rate_to_max_concurrency(requests_per_min=1, tokens_per_min=test_tpm),
719
- int(test_tpm / lm_lib.TOKENS_PER_REQUEST / 60),
720
- )
721
-
722
- def test_rate_to_max_concurrency_small_rate_returns_one(self) -> None:
723
- lm = MockModel()
724
- self.assertEqual(lm.rate_to_max_concurrency(requests_per_min=1), 1)
725
- self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)
726
-
727
815
  def test_track_usages(self):
728
816
  lm = MockModel(name='model1')
729
817
  lm2 = MockModel(name='model2')