lm-deluge 0.0.3__tar.gz → 0.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (44)
  1. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/PKG-INFO +2 -2
  2. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/pyproject.toml +2 -2
  3. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/api_requests/anthropic.py +4 -8
  4. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/api_requests/base.py +23 -27
  5. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/api_requests/cohere.py +4 -6
  6. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/api_requests/deprecated/bedrock.py +4 -4
  7. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/api_requests/deprecated/deepseek.py +2 -2
  8. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/api_requests/deprecated/mistral.py +2 -2
  9. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/api_requests/openai.py +5 -7
  10. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/api_requests/vertex.py +9 -13
  11. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/client.py +28 -43
  12. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/embed.py +13 -28
  13. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/llm_tools/extract.py +5 -5
  14. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/models.py +4 -5
  15. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/rerank.py +15 -29
  16. lm_deluge-0.0.4/src/lm_deluge/tracker.py +43 -0
  17. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/util/logprobs.py +2 -2
  18. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge.egg-info/PKG-INFO +2 -2
  19. lm_deluge-0.0.3/src/lm_deluge/tracker.py +0 -12
  20. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/README.md +0 -0
  21. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/setup.cfg +0 -0
  22. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/__init__.py +0 -0
  23. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/api_requests/__init__.py +0 -0
  24. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/api_requests/common.py +0 -0
  25. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/api_requests/google.py +0 -0
  26. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/cache.py +0 -0
  27. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/errors.py +0 -0
  28. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/gemini_limits.py +0 -0
  29. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/image.py +0 -0
  30. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/llm_tools/__init__.py +0 -0
  31. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/llm_tools/score.py +0 -0
  32. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/llm_tools/translate.py +0 -0
  33. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/prompt.py +0 -0
  34. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/sampling_params.py +0 -0
  35. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/tool.py +0 -0
  36. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/util/json.py +0 -0
  37. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/util/pdf.py +0 -0
  38. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/util/validation.py +0 -0
  39. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/util/xml.py +0 -0
  40. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge.egg-info/SOURCES.txt +0 -0
  41. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge.egg-info/dependency_links.txt +0 -0
  42. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge.egg-info/requires.txt +0 -0
  43. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge.egg-info/top_level.txt +0 -0
  44. {lm_deluge-0.0.3 → lm_deluge-0.0.4}/tests/test_heal_json.py +0 -0
{lm_deluge-0.0.3 → lm_deluge-0.0.4}/PKG-INFO
@@ -1,9 +1,9 @@
  Metadata-Version: 2.4
  Name: lm_deluge
- Version: 0.0.3
+ Version: 0.0.4
  Summary: Python utility for using LLM API models.
  Author-email: Benjamin Anderson <ben@trytaylor.ai>
- Requires-Python: >=3.9
+ Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  Requires-Dist: python-dotenv
  Requires-Dist: json5
{lm_deluge-0.0.3 → lm_deluge-0.0.4}/pyproject.toml
@@ -3,11 +3,11 @@ requires = ["setuptools", "wheel"]

  [project]
  name = "lm_deluge"
- version = "0.0.3"
+ version = "0.0.4"
  authors = [{ name = "Benjamin Anderson", email = "ben@trytaylor.ai" }]
  description = "Python utility for using LLM API models."
  readme = "README.md"
- requires-python = ">=3.9"
+ requires-python = ">=3.10"
  keywords = []
  license = { text = "" }
  classifiers = []
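The Requires-Python bump from >=3.9 to >=3.10 lines up with the switch, visible throughout the diffs below, from typing.Optional/typing.Union to the PEP 604 X | Y union syntax. A minimal sketch of why the floor matters; the class name here is hypothetical and not from the package:

from dataclasses import dataclass

@dataclass
class ExampleResponse:  # hypothetical, for illustration only
    # On Python 3.9 these annotations are evaluated at class-definition time and
    # raise TypeError ("unsupported operand type(s) for |") unless the module adds
    # `from __future__ import annotations`; on 3.10+ they work as written.
    completion: str | None = None   # replaces Optional[str]
    cost: float | None = None       # replaces Optional[float]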
{lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/api_requests/anthropic.py
@@ -3,9 +3,8 @@ from aiohttp import ClientResponse
  import json
  import os
  import warnings
- import time
  from tqdm import tqdm
- from typing import Optional, Callable
+ from typing import Callable

  from lm_deluge.prompt import Conversation
  from .base import APIRequestBase, APIResponse
@@ -29,8 +28,8 @@ class AnthropicRequest(APIRequestBase):
  results_arr: list,
  request_timeout: int = 30,
  sampling_params: SamplingParams = SamplingParams(),
- pbar: Optional[tqdm] = None,
- callback: Optional[Callable] = None,
+ pbar: tqdm | None = None,
+ callback: Callable | None = None,
  debug: bool = False,
  # for retries
  all_model_names: list[str] | None = None,
@@ -96,8 +95,6 @@ class AnthropicRequest(APIRequestBase):
  if self.system_message is not None:
  self.request_json["system"] = self.system_message

- # print("request data:", self.request_json)
-
  async def handle_response(self, http_response: ClientResponse) -> APIResponse:
  is_error = False
  error_message = None
@@ -156,8 +153,7 @@ class AnthropicRequest(APIRequestBase):
  or "overloaded" in error_message.lower()
  ):
  error_message += " (Rate limit error, triggering cooldown.)"
- self.status_tracker.time_of_last_rate_limit_error = time.time()
- self.status_tracker.num_rate_limit_errors += 1
+ self.status_tracker.rate_limit_exceeded()
  if "context length" in error_message:
  error_message += " (Context length exceeded, set retries to 0.)"
  self.attempts_left = 0
{lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/api_requests/base.py
@@ -5,7 +5,7 @@ import random
  from tqdm import tqdm
  from dataclasses import dataclass
  from abc import ABC, abstractmethod
- from typing import Optional, Callable
+ from typing import Callable

  from lm_deluge.prompt import Conversation

@@ -26,25 +26,25 @@ class APIResponse:

  # http response information
  status_code: int | None
- is_error: Optional[bool]
- error_message: Optional[str]
+ is_error: bool | None
+ error_message: str | None

  # completion information
- completion: Optional[str]
- input_tokens: Optional[int]
- output_tokens: Optional[int]
+ completion: str | None
+ input_tokens: int | None
+ output_tokens: int | None

  # optional or calculated automatically
- thinking: Optional[str] = None # if model shows thinking tokens
- model_external: Optional[str] = None # the model tag used by the API
- region: Optional[str] = None
- logprobs: Optional[list] = None
- finish_reason: Optional[str] = None # make required later
- cost: Optional[float] = None # calculated automatically
+ thinking: str | None = None # if model shows thinking tokens
+ model_external: str | None = None # the model tag used by the API
+ region: str | None = None
+ logprobs: list | None = None
+ finish_reason: str | None = None # make required later
+ cost: float | None = None # calculated automatically
  # set to true if is_error and should be retried with a different model
- retry_with_different_model: Optional[bool] = False
+ retry_with_different_model: bool | None = False
  # set to true if should NOT retry with the same model (unrecoverable error)
- give_up_if_no_other_models: Optional[bool] = False
+ give_up_if_no_other_models: bool | None = False

  def __post_init__(self):
  # calculate cost & get external model name
@@ -138,9 +138,9 @@ class APIRequestBase(ABC):
  request_timeout: int = 30,
  sampling_params: SamplingParams = SamplingParams(),
  logprobs: bool = False,
- top_logprobs: Optional[int] = None,
- pbar: Optional[tqdm] = None,
- callback: Optional[Callable] = None,
+ top_logprobs: int | None = None,
+ pbar: tqdm | None = None,
+ callback: Callable | None = None,
  debug: bool = False,
  all_model_names: list[str] | None = None,
  all_sampling_params: list[SamplingParams] | None = None,
@@ -185,8 +185,7 @@
  def handle_success(self, data):
  self.call_callback()
  self.increment_pbar()
- self.status_tracker.num_tasks_in_progress -= 1
- self.status_tracker.num_tasks_succeeded += 1
+ self.status_tracker.task_succeeded(self.task_id)

  def handle_error(self, create_new_request=False, give_up_if_no_other_models=False):
  """
@@ -215,8 +214,7 @@
  print(
  f"No other models to try for task {self.task_id}. Giving up."
  )
- self.status_tracker.num_tasks_in_progress -= 1
- self.status_tracker.num_tasks_failed += 1
+ self.status_tracker.task_failed(self.task_id)
  else:
  print(
  f"No other models to try for task {self.task_id}. Retrying with same model."
@@ -263,8 +261,7 @@
  self.results_arr.append(new_request)
  else:
  print(f"Task {self.task_id} out of tries.")
- self.status_tracker.num_tasks_in_progress -= 1
- self.status_tracker.num_tasks_failed += 1
+ self.status_tracker.task_failed(self.task_id)

  async def call_api(self):
  try:
@@ -308,7 +305,6 @@

  except Exception as e:
  raise_if_modal_exception(e)
- # print(f"Unexpected error {type(e).__name__}: {str(e) or 'No message.'}")
  self.result.append(
  APIResponse(
  id=self.task_id,
@@ -342,9 +338,9 @@ def create_api_request(
  request_timeout: int = 30,
  sampling_params: SamplingParams = SamplingParams(),
  logprobs: bool = False,
- top_logprobs: Optional[int] = None,
- pbar: Optional[tqdm] = None,
- callback: Optional[Callable] = None,
+ top_logprobs: int | None = None,
+ pbar: tqdm | None = None,
+ callback: Callable | None = None,
  all_model_names: list[str] | None = None,
  all_sampling_params: list[SamplingParams] | None = None,
  ) -> APIRequestBase:
{lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/api_requests/cohere.py
@@ -4,9 +4,8 @@ import asyncio
  from aiohttp import ClientResponse
  import json
  import os
- import time
  from tqdm import tqdm
- from typing import Optional, Callable
+ from typing import Callable
  from lm_deluge.prompt import Conversation
  from .base import APIRequestBase, APIResponse

@@ -29,8 +28,8 @@ class CohereRequest(APIRequestBase):
  retry_queue: asyncio.Queue,
  request_timeout: int = 30,
  sampling_params: SamplingParams = SamplingParams(),
- pbar: Optional[tqdm] = None,
- callback: Optional[Callable] = None,
+ pbar: tqdm | None = None,
+ callback: Callable | None = None,
  debug: bool = False,
  all_model_names: list[str] | None = None,
  all_sampling_params: list[SamplingParams] | None = None,
@@ -118,8 +117,7 @@
  or "overloaded" in error_message.lower()
  ):
  error_message += " (Rate limit error, triggering cooldown.)"
- self.status_tracker.time_of_last_rate_limit_error = time.time()
- self.status_tracker.num_rate_limit_errors += 1
+ self.status_tracker.rate_limit_exceeded()
  if "context length" in error_message:
  error_message += " (Context length exceeded, set retries to 0.)"
  self.attempts_left = 0
{lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/api_requests/deprecated/bedrock.py
@@ -55,8 +55,8 @@
  # retry_queue: asyncio.Queue,
  # request_timeout: int = 30,
  # sampling_params: SamplingParams = SamplingParams(),
- # pbar: Optional[tqdm] = None,
- # callback: Optional[Callable] = None,
+ # pbar: tqdm | None = None,
+ # callback: Callable | None = None,
  # debug: bool = False,
  # all_model_names: list[str] | None = None,
  # all_sampling_params: list[SamplingParams] | None = None,
@@ -175,8 +175,8 @@
  # results_arr: list,
  # request_timeout: int = 30,
  # sampling_params: SamplingParams = SamplingParams(),
- # pbar: Optional[tqdm] = None,
- # callback: Optional[Callable] = None,
+ # pbar: tqdm | None = None,
+ # callback: Callable | None = None,
  # debug: bool = False,
  # all_model_names: list[str] | None = None,
  # all_sampling_params: list[SamplingParams] | None = None,
{lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/api_requests/deprecated/deepseek.py
@@ -25,8 +25,8 @@
  # results_arr: list,
  # request_timeout: int = 30,
  # sampling_params: SamplingParams = SamplingParams(),
- # pbar: Optional[tqdm] = None,
- # callback: Optional[Callable] = None,
+ # pbar: tqdm | None = None,
+ # callback: Callable | None = None,
  # debug: bool = False,
  # all_model_names: list[str] = None,
  # all_sampling_params: list[SamplingParams] = None,
{lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/api_requests/deprecated/mistral.py
@@ -27,8 +27,8 @@
  # results_arr: list,
  # request_timeout: int = 30,
  # sampling_params: SamplingParams = SamplingParams(),
- # pbar: Optional[tqdm] = None,
- # callback: Optional[Callable] = None,
+ # pbar: tqdm | None = None,
+ # callback: Callable | None = None,
  # debug: bool = False,
  # all_model_names: list[str] = None,
  # all_sampling_params: list[SamplingParams] = None,
{lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/api_requests/openai.py
@@ -3,9 +3,8 @@ import warnings
  from aiohttp import ClientResponse
  import json
  import os
- import time
  from tqdm.auto import tqdm
- from typing import Optional, Callable
+ from typing import Callable

  from .base import APIRequestBase, APIResponse
  from ..prompt import Conversation
@@ -29,9 +28,9 @@ class OpenAIRequest(APIRequestBase):
  request_timeout: int = 30,
  sampling_params: SamplingParams = SamplingParams(),
  logprobs: bool = False,
- top_logprobs: Optional[int] = None,
- pbar: Optional[tqdm] = None,
- callback: Optional[Callable] = None,
+ top_logprobs: int | None = None,
+ pbar: tqdm | None = None,
+ callback: Callable | None = None,
  debug: bool = False,
  all_model_names: list[str] | None = None,
  all_sampling_params: list[SamplingParams] | None = None,
@@ -124,8 +123,7 @@
  if is_error and error_message is not None:
  if "rate limit" in error_message.lower() or status_code == 429:
  error_message += " (Rate limit error, triggering cooldown.)"
- self.status_tracker.time_of_last_rate_limit_error = time.time()
- self.status_tracker.num_rate_limit_errors += 1
+ self.status_tracker.rate_limit_exceeded()
  if "context length" in error_message:
  error_message += " (Context length exceeded, set retries to 0.)"
  self.attempts_left = 0
{lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/api_requests/vertex.py
@@ -5,7 +5,7 @@ import json
  import os
  import time
  from tqdm import tqdm
- from typing import Optional, Callable
+ from typing import Callable

  from lm_deluge.prompt import Conversation
  from .base import APIRequestBase, APIResponse
@@ -57,8 +57,8 @@ class VertexAnthropicRequest(APIRequestBase):
  results_arr: list,
  request_timeout: int = 30,
  sampling_params: SamplingParams = SamplingParams(),
- pbar: Optional[tqdm] = None,
- callback: Optional[Callable] = None,
+ pbar: tqdm | None = None,
+ callback: Callable | None = None,
  debug: bool = False,
  ):
  super().__init__(
@@ -141,8 +141,7 @@ class VertexAnthropicRequest(APIRequestBase):
  or status_code == 429
  ):
  error_message += " (Rate limit error, triggering cooldown.)"
- self.status_tracker.time_of_last_rate_limit_error = time.time()
- self.status_tracker.num_rate_limit_errors += 1
+ self.status_tracker.rate_limit_exceeded()
  if "context length" in error_message:
  error_message += " (Context length exceeded, set retries to 0.)"
  self.attempts_left = 0
@@ -185,8 +184,8 @@ class GeminiRequest(APIRequestBase):
  results_arr: list,
  request_timeout: int = 30,
  sampling_params: SamplingParams = SamplingParams(),
- pbar: Optional[tqdm] = None,
- callback: Optional[Callable] = None,
+ pbar: tqdm | None = None,
+ callback: Callable | None = None,
  debug: bool = False,
  all_model_names: list[str] | None = None,
  all_sampling_params: list[SamplingParams] | None = None,
@@ -302,16 +301,14 @@ class GeminiRequest(APIRequestBase):
  error_message = "Finish reason SAFETY."
  retry_with_different_model = True
  else:
- print("Actual structure of response:")
- print(data)
+ print("Actual structure of response:", data)
  is_error = True
  error_message = "No content in response."
  except Exception as e:
  is_error = True
  error_message = f"Error calling .json() on response w/ status {status_code}: {e.__class__} {e}"
  if isinstance(e, KeyError):
- print("Actual structure of response:")
- print(data)
+ print("Actual structure of response:", data)
  elif "json" in (mimetype or "").lower():
  is_error = True
  data = await http_response.json()
@@ -332,8 +329,7 @@ class GeminiRequest(APIRequestBase):
  status_code == 429
  ):
  error_message += " (Rate limit error, triggering cooldown & retrying with different model.)"
- self.status_tracker.time_of_last_rate_limit_error = time.time()
- self.status_tracker.num_rate_limit_errors += 1
+ self.status_tracker.rate_limit_exceeded()
  retry_with_different_model = (
  True # if possible, retry with a different model
  )
{lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/client.py
@@ -5,7 +5,7 @@ import numpy as np
  import time
  import yaml
  from dataclasses import dataclass
- from typing import Sequence, overload, Literal, Optional, Union, Any
+ from typing import Sequence, overload, Literal, Any
  from tqdm.auto import tqdm

  from lm_deluge.prompt import Conversation
@@ -31,11 +31,11 @@ class ClientConfig:
  max_concurrent_requests: int
  max_attempts: int
  request_timeout: int
- sampling_params: Union[SamplingParams, list[SamplingParams]]
- model_weights: Union[list[float], Literal["uniform", "rate_limit"]]
+ sampling_params: SamplingParams | list[SamplingParams]
+ model_weights: list[float] | Literal["uniform", "rate_limit"]
  logprobs: bool = False
- top_logprobs: Optional[int] = None
- cache: Optional[Any] = None
+ top_logprobs: int | None = None
+ cache: Any = None

  @classmethod
  def from_dict(cls, config_dict: dict):
@@ -82,23 +82,21 @@ class LLMClient:
  Handles models, sampling params for each model, model weights, rate limits, etc.
  """

- pass
-
  def __init__(
  self,
  model_names: list[str],
  max_requests_per_minute: int,
  max_tokens_per_minute: int,
  max_concurrent_requests: int,
- sampling_params: Union[SamplingParams, list[SamplingParams]] = SamplingParams(),
- model_weights: Union[list[float], Literal["uniform", "rate_limit"]] = "uniform",
+ sampling_params: SamplingParams | list[SamplingParams] = SamplingParams(),
+ model_weights: list[float] | Literal["uniform", "rate_limit"] = "uniform",
  max_attempts: int = 5,
  request_timeout: int = 30,
  logprobs: bool = False,
- top_logprobs: Optional[int] = None,
+ top_logprobs: int | None = None,
  use_qps: bool = False,
  debug: bool = False,
- cache: Optional[Any] = None,
+ cache: Any = None,
  ):
  self.models = model_names
  if isinstance(sampling_params, SamplingParams):
@@ -154,7 +152,7 @@ class LLMClient:
  self.cache = cache

  @classmethod
- def from_config(cls, config: ClientConfig, cache: Optional[Any] = None):
+ def from_config(cls, config: ClientConfig, cache: Any = None):
  return cls(
  model_names=config.model_names,
  max_requests_per_minute=config.max_requests_per_minute,
@@ -168,25 +166,25 @@ class LLMClient:
  )

  @classmethod
- def from_yaml(cls, file_path: str, cache: Optional[Any] = None):
+ def from_yaml(cls, file_path: str, cache: Any = None):
  return cls.from_config(ClientConfig.from_yaml(file_path), cache=cache)

  @classmethod
  def basic(
  cls,
- model: Union[str, list[str]],
+ model: str | list[str],
  max_requests_per_minute: int = 5_000,
  max_tokens_per_minute: int = 1_000_000,
  max_concurrent_requests: int = 1_000,
  temperature: float = 0.75,
  max_new_tokens: int = 1000,
  reasoning_effort: Literal[None, "low", "medium", "high"] = None,
- model_weights: Union[list[float], Literal["uniform", "rate_limit"]] = "uniform",
+ model_weights: list[float] | Literal["uniform", "rate_limit"] = "uniform",
  logprobs: bool = False,
- top_logprobs: Optional[int] = None,
+ top_logprobs: int | None = None,
  max_attempts: int = 5,
  request_timeout: int = 30,
- cache: Optional[Any] = None,
+ cache: Any = None,
  ):
  model_names = model if isinstance(model, list) else [model]
  return cls(
@@ -222,8 +220,6 @@ class LLMClient:
  top_logprobs=self.top_logprobs,
  )

- from typing import Union, Literal
-
  @overload
  async def process_prompts_async(
  self,
@@ -485,9 +481,9 @@ class LLMClient:


  def api_prompts_dry_run(
- ids: Union[np.ndarray, list[int]],
+ ids: np.ndarray | list[int],
  prompts: list[Conversation],
- models: Union[str, list[str]],
+ models: str | list[str],
  model_weights: list[float],
  sampling_params: list[SamplingParams],
  max_tokens_per_minute: int = 500_000,
@@ -543,19 +539,19 @@ def api_prompts_dry_run(


  async def process_api_prompts_async(
- ids: Union[np.ndarray, list[int]],
+ ids: np.ndarray | list[int],
  prompts: list[Conversation],
- models: Union[str, list[str]],
+ models: str | list[str],
  model_weights: list[float],
  sampling_params: list[SamplingParams],
  logprobs: bool,
- top_logprobs: Optional[int],
+ top_logprobs: int | None,
  max_attempts: int = 5,
  max_tokens_per_minute: int = 500_000,
  max_requests_per_minute: int = 1_000,
  max_concurrent_requests: int = 1_000,
  request_timeout: int = 30,
- progress_bar: Optional[tqdm] = None,
+ progress_bar: tqdm | None = None,
  use_qps: bool = False,
  verbose: bool = False,
  ):
@@ -712,28 +708,17 @@ async def process_api_prompts_async(
  await asyncio.sleep(seconds_to_sleep_each_loop)

  # if a rate limit error was hit recently, pause to cool down
- seconds_since_rate_limit_error = (
- time.time() - status_tracker.time_of_last_rate_limit_error
+ remaining_seconds_to_pause = max(
+ 0,
+ seconds_to_pause_after_rate_limit_error
+ - status_tracker.time_since_rate_limit_error,
  )
- if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error:
- remaining_seconds_to_pause = (
- seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error
- )
+ if remaining_seconds_to_pause > 0:
  await asyncio.sleep(remaining_seconds_to_pause)
- # ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago
- print(
- f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}"
- )
+ print(f"Pausing {remaining_seconds_to_pause}s to cool down.")

  # after finishing, log final status
- if status_tracker.num_tasks_failed > 0:
- print(
- f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed."
- )
- if status_tracker.num_rate_limit_errors > 0:
- print(
- f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
- )
+ status_tracker.log_final_status()
  if verbose:
  print(
  f"After processing, got {len(results)} results for {len(ids)} inputs. Removing duplicates."
{lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/embed.py
@@ -5,7 +5,7 @@ import aiohttp
  from tqdm.auto import tqdm
  import asyncio
  import time
- from typing import Any, Optional
+ from typing import Any
  from dataclasses import dataclass
  from .tracker import StatusTracker

@@ -58,7 +58,7 @@ class EmbeddingRequest:
  status_tracker: StatusTracker,
  retry_queue: asyncio.Queue,
  request_timeout: int,
- pbar: Optional[tqdm] = None,
+ pbar: tqdm | None = None,
  **kwargs, # openai or cohere specific params
  ):
  self.task_id = task_id
@@ -78,8 +78,7 @@ class EmbeddingRequest:

  def handle_success(self):
  self.increment_pbar()
- self.status_tracker.num_tasks_in_progress -= 1
- self.status_tracker.num_tasks_succeeded += 1
+ self.status_tracker.task_succeeded(self.task_id)

  def handle_error(self):
  last_result: EmbeddingResponse = self.result[-1]
@@ -94,8 +93,7 @@
  return
  else:
  print(f"Task {self.task_id} out of tries.")
- self.status_tracker.num_tasks_in_progress -= 1
- self.status_tracker.num_tasks_failed += 1
+ self.status_tracker.task_failed(self.task_id)

  async def handle_response(self, response: aiohttp.ClientResponse):
  try:
@@ -217,7 +215,7 @@ class EmbeddingResponse:
  id: int
  status_code: int | None
  is_error: bool
- error_message: Optional[str]
+ error_message: str | None
  texts: list[str]
  embeddings: list[list[float]]

@@ -282,8 +280,7 @@ async def embed_parallel_async(
  pbar=pbar,
  **kwargs,
  )
- status_tracker.num_tasks_started += 1
- status_tracker.num_tasks_in_progress += 1
+ status_tracker.start_task(batch_id)
  results.append(next_request)

  except StopIteration:
@@ -333,29 +330,17 @@ async def embed_parallel_async(
  await asyncio.sleep(seconds_to_sleep_each_loop)

  # if a rate limit error was hit recently, pause to cool down
- seconds_since_rate_limit_error = (
- time.time() - status_tracker.time_of_last_rate_limit_error
+ remaining_seconds_to_pause = max(
+ 0,
+ seconds_to_pause_after_rate_limit_error
+ - status_tracker.time_since_rate_limit_error,
  )
- if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error:
- remaining_seconds_to_pause = (
- seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error
- )
+ if remaining_seconds_to_pause > 0:
  await asyncio.sleep(remaining_seconds_to_pause)
- # ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago
- print(
- f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}"
- )
+ print(f"Pausing {remaining_seconds_to_pause}s to cool down.")

  # after finishing, log final status
- if status_tracker.num_tasks_failed > 0:
- print(
- f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed."
- )
- if status_tracker.num_rate_limit_errors > 0:
- print(
- f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
- )
-
+ status_tracker.log_final_status()
  print(
  f"After processing, got {len(results)} results for {len(ids)} inputs. Removing duplicates."
  )
{lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/llm_tools/extract.py
@@ -3,7 +3,7 @@ import json
  from ..prompt import Conversation
  import asyncio
  from ..client import LLMClient
- from typing import Optional, Any
+ from typing import Any
  from ..util.json import load_json

  try:
@@ -16,8 +16,8 @@ async def extract_async(
  inputs: list[str | Any],
  schema: Any,
  client: LLMClient,
- document_name: Optional[str] = None,
- object_name: Optional[str] = None,
+ document_name: str | None = None,
+ object_name: str | None = None,
  show_progress: bool = True,
  return_prompts: bool = False,
  ):
@@ -93,8 +93,8 @@ def extract(
  inputs: list[str | Any],
  schema: Any,
  client: LLMClient,
- document_name: Optional[str] = None,
- object_name: Optional[str] = None,
+ document_name: str | None = None,
+ object_name: str | None = None,
  show_progress: bool = True,
  return_prompts: bool = False,
  ):
{lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/models.py
@@ -1,6 +1,5 @@
  import random
  from dataclasses import dataclass, field
- from typing import Optional
  from .gemini_limits import gemini_1_5_pro_limits, gemini_flash_limits

  registry = {
@@ -928,15 +927,15 @@ class APIModel:
  api_base: str
  api_key_env_var: str
  api_spec: str
- input_cost: Optional[float] = 0 # $ per million input tokens
- output_cost: Optional[float] = 0 # $ per million output tokens
+ input_cost: float | None = 0 # $ per million input tokens
+ output_cost: float | None = 0 # $ per million output tokens
  supports_json: bool = False
  supports_logprobs: bool = False
  reasoning_model: bool = False
  regions: list[str] | dict[str, int] = field(default_factory=list)
  tokens_per_minute: int | None = None
  requests_per_minute: int | None = None
- gpus: Optional[list[str]] = None
+ gpus: list[str] | None = None

  @classmethod
  def from_registry(cls, name: str):
@@ -950,7 +949,7 @@ class APIModel:
  regions = self.regions
  weights = [1] * len(regions)
  elif isinstance(self.regions, dict):
- regions = self.regions.keys()
+ regions = list(self.regions.keys())
  weights = self.regions.values()
  else:
  raise ValueError("no regions to sample")
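Wrapping self.regions.keys() in list() matters because a dict view is not an indexable sequence, which weighted-sampling helpers generally require. A minimal sketch of the assumed usage (the sampling call itself is not shown in this diff, and the region weights here are hypothetical):

import random

regions = {"us-east-1": 3, "us-west-2": 1}  # hypothetical region weights

# random.choices indexes its population, so passing regions.keys() directly
# raises TypeError: 'dict_keys' object is not subscriptable.
picked = random.choices(list(regions.keys()), weights=list(regions.values()), k=1)[0]
print(picked)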
{lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/rerank.py
@@ -4,7 +4,6 @@ import aiohttp
  from tqdm.auto import tqdm
  import asyncio
  import time
- from typing import Optional
  from dataclasses import dataclass
  from .tracker import StatusTracker

@@ -28,7 +27,7 @@ class RerankingRequest:
  status_tracker: StatusTracker,
  retry_queue: asyncio.Queue,
  request_timeout: int,
- pbar: Optional[tqdm] = None,
+ pbar: tqdm | None = None,
  ):
  self.task_id = task_id
  self.model_name = model_name
@@ -48,8 +47,7 @@ class RerankingRequest:

  def handle_success(self):
  self.increment_pbar()
- self.status_tracker.num_tasks_in_progress -= 1
- self.status_tracker.num_tasks_succeeded += 1
+ self.status_tracker.task_succeeded(self.task_id)

  def handle_error(self):
  """
@@ -69,8 +67,7 @@ class RerankingRequest:
  return
  else:
  print(f"Task {self.task_id} out of tries.")
- self.status_tracker.num_tasks_in_progress -= 1
- self.status_tracker.num_tasks_failed += 1
+ self.status_tracker.task_failed(self.task_id)

  async def handle_response(self, response: aiohttp.ClientResponse):
  try:
@@ -127,8 +124,9 @@ class RerankingRequest:
  try:
  self.status_tracker.total_requests += 1
  async with aiohttp.ClientSession() as session:
+ timeout = aiohttp.ClientTimeout(total=self.request_timeout)
  async with session.post(
- url, headers=headers, json=data, timeout=self.request_timeout
+ url, headers=headers, json=data, timeout=timeout
  ) as response:
  # print("got response!!")
  response_obj: RerankingResponse = await self.handle_response(
@@ -176,7 +174,7 @@ class RerankingResponse:
  id: int
  status_code: int | None
  is_error: bool
- error_message: Optional[str]
+ error_message: str | None
  query: str
  documents: list[str]
  top_k_indices: list[int]
@@ -196,7 +194,7 @@
  max_requests_per_minute: int = 4_000,
  max_concurrent_requests: int = 500,
  request_timeout: int = 10,
- progress_bar: Optional[tqdm] = None,
+ progress_bar: tqdm | None = None,
  ):
  """Processes rerank requests in parallel, throttling to stay under rate limits."""
  ids = range(len(queries))
@@ -243,8 +241,7 @@ async def rerank_parallel_async(
  request_timeout=request_timeout,
  pbar=progress_bar,
  )
- status_tracker.num_tasks_started += 1
- status_tracker.num_tasks_in_progress += 1
+ status_tracker.start_task(req_id)
  results.append(next_request)

  except StopIteration:
@@ -294,28 +291,17 @@ async def rerank_parallel_async(
  await asyncio.sleep(seconds_to_sleep_each_loop)

  # if a rate limit error was hit recently, pause to cool down
- seconds_since_rate_limit_error = (
- time.time() - status_tracker.time_of_last_rate_limit_error
+ remaining_seconds_to_pause = max(
+ 0,
+ seconds_to_pause_after_rate_limit_error
+ - status_tracker.time_since_rate_limit_error,
  )
- if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error:
- remaining_seconds_to_pause = (
- seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error
- )
+ if remaining_seconds_to_pause > 0:
  await asyncio.sleep(remaining_seconds_to_pause)
- # ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago
- print(
- f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}"
- )
+ print(f"Pausing {remaining_seconds_to_pause}s to cool down.")

  # after finishing, log final status
- if status_tracker.num_tasks_failed > 0:
- print(
- f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed."
- )
- if status_tracker.num_rate_limit_errors > 0:
- print(
- f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
- )
+ status_tracker.log_final_status()

  print(
  f"After processing, got {len(results)} results for {len(ids)} inputs. Removing duplicates."
lm_deluge-0.0.4/src/lm_deluge/tracker.py (new file)
@@ -0,0 +1,43 @@
+ import time
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class StatusTracker:
+ num_tasks_started: int = 0
+ num_tasks_in_progress: int = 0
+ num_tasks_succeeded: int = 0
+ num_tasks_failed: int = 0
+ num_rate_limit_errors: int = 0
+ time_of_last_rate_limit_error: int | float = 0
+ total_requests = 0
+
+ @property
+ def time_since_rate_limit_error(self):
+ return time.time() - self.time_of_last_rate_limit_error
+
+ def start_task(self, task_id):
+ self.num_tasks_started += 1
+ self.num_tasks_in_progress += 1
+
+ def rate_limit_exceeded(self):
+ self.time_of_last_rate_limit_error = time.time()
+ self.num_rate_limit_errors += 1
+
+ def task_succeeded(self, task_id):
+ self.num_tasks_in_progress -= 1
+ self.num_tasks_succeeded += 1
+
+ def task_failed(self, task_id):
+ self.num_tasks_in_progress -= 1
+ self.num_tasks_failed += 1
+
+ def log_final_status(self):
+ if self.num_tasks_failed > 0:
+ print(
+ f"{self.num_tasks_failed} / {self.num_tasks_started} requests failed."
+ )
+ if self.num_rate_limit_errors > 0:
+ print(
+ f"{self.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
+ )
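The new StatusTracker centralizes the counter bookkeeping that the request classes previously did by hand with num_tasks_in_progress and friends. A small usage sketch based only on the methods added above; the task ids are arbitrary examples:

from lm_deluge.tracker import StatusTracker

tracker = StatusTracker()

tracker.start_task(0)           # num_tasks_started and num_tasks_in_progress += 1
tracker.rate_limit_exceeded()   # records time.time() and bumps num_rate_limit_errors
tracker.task_failed(0)          # moves the task from in-progress to failed

tracker.start_task(1)
tracker.task_succeeded(1)

print(tracker.time_since_rate_limit_error)  # seconds since the last recorded rate limit error
tracker.log_final_status()                  # prints failure / rate-limit summaries, if any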
{lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge/util/logprobs.py
@@ -1,6 +1,6 @@
  import re
  import numpy as np
- from typing import TypedDict, Optional, Callable
+ from typing import TypedDict, Callable


  class TopLogprob(TypedDict):
@@ -403,7 +403,7 @@ def extract_prob(
  normalize_top_logprobs: bool = True, # if using top_logprobs, normalize by all the present tokens so they add up to 1
  use_complement: bool = False, # if True, assume there's 2 choices, and return 1 - p if the top token doesn't match
  token_index: int = 0, # get from the first token of the completion by default
- token_match_fn: Optional[Callable[[str, str], bool]] = is_match,
+ token_match_fn: Callable[[str, str], bool] | None = is_match,
  ):
  """
  Extract the probability of the token from the logprobs object of a single
{lm_deluge-0.0.3 → lm_deluge-0.0.4}/src/lm_deluge.egg-info/PKG-INFO
@@ -1,9 +1,9 @@
  Metadata-Version: 2.4
  Name: lm_deluge
- Version: 0.0.3
+ Version: 0.0.4
  Summary: Python utility for using LLM API models.
  Author-email: Benjamin Anderson <ben@trytaylor.ai>
- Requires-Python: >=3.9
+ Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  Requires-Dist: python-dotenv
  Requires-Dist: json5
lm_deluge-0.0.3/src/lm_deluge/tracker.py (removed)
@@ -1,12 +0,0 @@
- from dataclasses import dataclass
-
-
- @dataclass
- class StatusTracker:
- num_tasks_started: int = 0
- num_tasks_in_progress: int = 0
- num_tasks_succeeded: int = 0
- num_tasks_failed: int = 0
- num_rate_limit_errors: int = 0
- time_of_last_rate_limit_error: int | float = 0
- total_requests = 0