lm-deluge 0.0.7__tar.gz → 0.0.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (54)
  1. lm_deluge-0.0.9/LICENSE +7 -0
  2. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/PKG-INFO +6 -8
  3. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/README.md +3 -7
  4. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/pyproject.toml +1 -1
  5. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/anthropic.py +23 -7
  6. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/base.py +38 -11
  7. lm_deluge-0.0.9/src/lm_deluge/api_requests/bedrock.py +283 -0
  8. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/common.py +2 -0
  9. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/mistral.py +2 -2
  10. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/openai.py +37 -6
  11. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/client.py +18 -3
  12. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/models.py +114 -24
  13. lm_deluge-0.0.9/src/lm_deluge/prompt.py +693 -0
  14. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/tool.py +16 -35
  15. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge.egg-info/PKG-INFO +6 -8
  16. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge.egg-info/SOURCES.txt +4 -0
  17. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/tests/test_all_models.py +4 -0
  18. lm_deluge-0.0.9/tests/test_bedrock_models.py +252 -0
  19. lm_deluge-0.0.9/tests/test_tool_calls.py +401 -0
  20. lm_deluge-0.0.7/src/lm_deluge/prompt.py +0 -357
  21. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/setup.cfg +0 -0
  22. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/__init__.py +0 -0
  23. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/__init__.py +0 -0
  24. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/deprecated/bedrock.py +0 -0
  25. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/deprecated/cohere.py +0 -0
  26. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/deprecated/deepseek.py +0 -0
  27. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/deprecated/mistral.py +0 -0
  28. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/deprecated/vertex.py +0 -0
  29. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/cache.py +0 -0
  30. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/embed.py +0 -0
  31. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/errors.py +0 -0
  32. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/gemini_limits.py +0 -0
  33. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/image.py +0 -0
  34. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/llm_tools/__init__.py +0 -0
  35. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/llm_tools/extract.py +0 -0
  36. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/llm_tools/score.py +0 -0
  37. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/llm_tools/translate.py +0 -0
  38. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/rerank.py +0 -0
  39. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/sampling_params.py +0 -0
  40. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/tracker.py +0 -0
  41. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/util/json.py +0 -0
  42. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/util/logprobs.py +0 -0
  43. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/util/validation.py +0 -0
  44. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/util/xml.py +0 -0
  45. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge.egg-info/dependency_links.txt +0 -0
  46. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge.egg-info/requires.txt +0 -0
  47. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge.egg-info/top_level.txt +0 -0
  48. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/tests/test_cache.py +0 -0
  49. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/tests/test_image_models.py +0 -0
  50. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/tests/test_image_utils.py +0 -0
  51. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/tests/test_json_utils.py +0 -0
  52. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/tests/test_sampling_params.py +0 -0
  53. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/tests/test_translate.py +0 -0
  54. {lm_deluge-0.0.7 → lm_deluge-0.0.9}/tests/test_xml_utils.py +0 -0
lm_deluge-0.0.9/LICENSE
@@ -0,0 +1,7 @@
+ Copyright 2025, Taylor AI
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
{lm_deluge-0.0.7 → lm_deluge-0.0.9}/PKG-INFO
@@ -1,10 +1,11 @@
  Metadata-Version: 2.4
  Name: lm_deluge
- Version: 0.0.7
+ Version: 0.0.9
  Summary: Python utility for using LLM API models.
  Author-email: Benjamin Anderson <ben@trytaylor.ai>
  Requires-Python: >=3.10
  Description-Content-Type: text/markdown
+ License-File: LICENSE
  Requires-Dist: python-dotenv
  Requires-Dist: json5
  Requires-Dist: PyYAML
@@ -22,10 +23,11 @@ Requires-Dist: pdf2image
  Requires-Dist: pillow
  Requires-Dist: fasttext-wheel
  Requires-Dist: fasttext-langdetect
+ Dynamic: license-file
 
- # lm_deluge
+ # lm-deluge
 
- `lm_deluge` is a lightweight helper library for maxing out your rate limits with LLM providers. It provides the following:
+ `lm-deluge` is a lightweight helper library for maxing out your rate limits with LLM providers. It provides the following:
 
  - **Unified client** – Send prompts to all relevant models with a single client.
  - **Massive concurrency with throttling** – Set `max_tokens_per_minute` and `max_requests_per_minute` and let it fly. The client will process as many requests as possible while respecting rate limits and retrying failures.
@@ -77,11 +79,7 @@ print(resp[0].completion)
 
  API calls can be customized in a few ways.
 
- 1. **Sampling Parameters.** This determines things like structured outputs, maximum completion tokens, nucleus sampling, etc. Provide a custom `SamplingParams` to the `LLMClient` to set temperature, top_p, json_mode, max_new_tokens, and/or reasoning_effort.
-
- You can pass 1 `SamplingParams` to use for all models, or a list of `SamplingParams` that's the same length as the list of models. You can also pass many of these arguments directly to `LLMClient.basic` so you don't have to construct an entire `SamplingParams` object.
-
-
+ 1. **Sampling Parameters.** This determines things like structured outputs, maximum completion tokens, nucleus sampling, etc. Provide a custom `SamplingParams` to the `LLMClient` to set temperature, top_p, json_mode, max_new_tokens, and/or reasoning_effort. You can pass 1 `SamplingParams` to use for all models, or a list of `SamplingParams` that's the same length as the list of models. You can also pass many of these arguments directly to `LLMClient.basic` so you don't have to construct an entire `SamplingParams` object.
  2. **Arguments to LLMClient.** This is where you set request timeout, rate limits, model name(s), model weight(s) for distributing requests across models, retries, and caching.
  3. **Arguments to process_prompts.** Per-call, you can set verbosity, whether to display progress, and whether to return just completions (rather than the full APIResponse object).
 
{lm_deluge-0.0.7 → lm_deluge-0.0.9}/README.md
@@ -1,6 +1,6 @@
- # lm_deluge
+ # lm-deluge
 
- `lm_deluge` is a lightweight helper library for maxing out your rate limits with LLM providers. It provides the following:
+ `lm-deluge` is a lightweight helper library for maxing out your rate limits with LLM providers. It provides the following:
 
  - **Unified client** – Send prompts to all relevant models with a single client.
  - **Massive concurrency with throttling** – Set `max_tokens_per_minute` and `max_requests_per_minute` and let it fly. The client will process as many requests as possible while respecting rate limits and retrying failures.
@@ -52,11 +52,7 @@ print(resp[0].completion)
 
  API calls can be customized in a few ways.
 
- 1. **Sampling Parameters.** This determines things like structured outputs, maximum completion tokens, nucleus sampling, etc. Provide a custom `SamplingParams` to the `LLMClient` to set temperature, top_p, json_mode, max_new_tokens, and/or reasoning_effort.
-
- You can pass 1 `SamplingParams` to use for all models, or a list of `SamplingParams` that's the same length as the list of models. You can also pass many of these arguments directly to `LLMClient.basic` so you don't have to construct an entire `SamplingParams` object.
-
-
+ 1. **Sampling Parameters.** This determines things like structured outputs, maximum completion tokens, nucleus sampling, etc. Provide a custom `SamplingParams` to the `LLMClient` to set temperature, top_p, json_mode, max_new_tokens, and/or reasoning_effort. You can pass 1 `SamplingParams` to use for all models, or a list of `SamplingParams` that's the same length as the list of models. You can also pass many of these arguments directly to `LLMClient.basic` so you don't have to construct an entire `SamplingParams` object.
  2. **Arguments to LLMClient.** This is where you set request timeout, rate limits, model name(s), model weight(s) for distributing requests across models, retries, and caching.
  3. **Arguments to process_prompts.** Per-call, you can set verbosity, whether to display progress, and whether to return just completions (rather than the full APIResponse object).
 
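The three customization levers described in the README hunks above (a `SamplingParams` object, constructor arguments to `LLMClient`, and per-call arguments to `process_prompts`) can be pictured with a minimal sketch. This sketch is not part of the diff; the import path, the exact `LLMClient.basic` signature, the `process_prompts_sync` helper, and all argument values are assumptions based only on the README text.

```python
# Minimal sketch, assuming import paths and signatures; not taken from the package source.
from lm_deluge import LLMClient, SamplingParams  # import path assumed

client = LLMClient.basic(
    "gpt-4o-mini",                    # model name; the README also describes lists of models with weights
    max_requests_per_minute=1_000,    # throttling knobs named in the README
    max_tokens_per_minute=100_000,
    sampling_params=SamplingParams(temperature=0.2, top_p=0.9, max_new_tokens=512),
)

# Per-call options (verbosity, progress display, completions-only output) go to process_prompts.
resps = client.process_prompts_sync(  # hypothetical sync wrapper around process_prompts
    ["Translate 'hello, world' into French."],
    show_progress=True,
)
print(resps[0].completion)
```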
{lm_deluge-0.0.7 → lm_deluge-0.0.9}/pyproject.toml
@@ -3,7 +3,7 @@ requires = ["setuptools", "wheel"]
 
  [project]
  name = "lm_deluge"
- version = "0.0.7"
+ version = "0.0.9"
  authors = [{ name = "Benjamin Anderson", email = "ben@trytaylor.ai" }]
  description = "Python utility for using LLM API models."
  readme = "README.md"
{lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/anthropic.py
@@ -6,7 +6,7 @@ import warnings
  from tqdm import tqdm
  from typing import Callable
 
- from lm_deluge.prompt import Conversation
+ from lm_deluge.prompt import Conversation, Message, Text, ToolCall, Thinking
  from .base import APIRequestBase, APIResponse
 
  from ..tracker import StatusTracker
@@ -34,6 +34,7 @@ class AnthropicRequest(APIRequestBase):
  # for retries
  all_model_names: list[str] | None = None,
  all_sampling_params: list[SamplingParams] | None = None,
+ tools: list | None = None,
  ):
  super().__init__(
  task_id=task_id,
@@ -50,6 +51,7 @@ class AnthropicRequest(APIRequestBase):
  debug=debug,
  all_model_names=all_model_names,
  all_sampling_params=all_sampling_params,
+ tools=tools,
  )
  self.model = APIModel.from_registry(model_name)
  self.url = f"{self.model.api_base}/messages"
@@ -94,12 +96,14 @@ class AnthropicRequest(APIRequestBase):
  )
  if self.system_message is not None:
  self.request_json["system"] = self.system_message
+ if tools:
+ self.request_json["tools"] = [tool.dump_for("anthropic") for tool in tools]
 
  async def handle_response(self, http_response: ClientResponse) -> APIResponse:
  is_error = False
  error_message = None
  thinking = None
- completion = None
+ content = None
  input_tokens = None
  output_tokens = None
  status_code = http_response.status
@@ -119,14 +123,26 @@ class AnthropicRequest(APIRequestBase):
  if status_code >= 200 and status_code < 300:
  try:
  data = await http_response.json()
- content = data["content"] # [0]["text"]
- for item in content:
+ response_content = data["content"]
+
+ # Parse response into Message with parts
+ parts = []
+ for item in response_content:
  if item["type"] == "text":
- completion = item["text"]
+ parts.append(Text(item["text"]))
  elif item["type"] == "thinking":
  thinking = item["thinking"]
+ parts.append(Thinking(item["thinking"]))
  elif item["type"] == "tool_use":
- continue # TODO: implement and report tool use
+ parts.append(
+ ToolCall(
+ id=item["id"],
+ name=item["name"],
+ arguments=item["input"],
+ )
+ )
+
+ content = Message("assistant", parts)
  input_tokens = data["usage"]["input_tokens"]
  output_tokens = data["usage"]["output_tokens"]
  except Exception as e:
@@ -162,7 +178,7 @@ class AnthropicRequest(APIRequestBase):
  is_error=is_error,
  error_message=error_message,
  prompt=self.prompt,
- completion=completion,
+ content=content,
  thinking=thinking,
  model_internal=self.model_name,
  sampling_params=self.sampling_params,
{lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/base.py
@@ -7,7 +7,7 @@ from dataclasses import dataclass
  from abc import ABC, abstractmethod
  from typing import Callable
 
- from lm_deluge.prompt import Conversation
+ from lm_deluge.prompt import Conversation, Message
 
  from ..tracker import StatusTracker
  from ..sampling_params import SamplingParams
@@ -30,10 +30,12 @@ class APIResponse:
  error_message: str | None
 
  # completion information
- completion: str | None
  input_tokens: int | None
  output_tokens: int | None
 
+ # response content - structured format
+ content: Message | None = None
+
  # optional or calculated automatically
  thinking: str | None = None # if model shows thinking tokens
  model_external: str | None = None # the model tag used by the API
@@ -47,6 +49,13 @@ class APIResponse:
  # set to true if should NOT retry with the same model (unrecoverable error)
  give_up_if_no_other_models: bool | None = False
 
+ @property
+ def completion(self) -> str | None:
+ """Backward compatibility: extract text from content Message."""
+ if self.content is not None:
+ return self.content.completion
+ return None
+
  def __post_init__(self):
  # calculate cost & get external model name
  self.id = int(self.id)
@@ -63,7 +72,7 @@ class APIResponse:
  self.input_tokens * api_model.input_cost / 1e6
  + self.output_tokens * api_model.output_cost / 1e6
  )
- elif self.completion is not None:
+ elif self.content is not None and self.completion is not None:
  print(
  f"Warning: Completion provided without token counts for model {self.model_internal}."
  )
@@ -79,7 +88,8 @@ class APIResponse:
  "status_code": self.status_code,
  "is_error": self.is_error,
  "error_message": self.error_message,
- "completion": self.completion,
+ "completion": self.completion, # computed property
+ "content": self.content.to_log() if self.content else None,
  "input_tokens": self.input_tokens,
  "output_tokens": self.output_tokens,
  "finish_reason": self.finish_reason,
@@ -88,11 +98,18 @@ class APIResponse:
 
  @classmethod
  def from_dict(cls, data: dict):
+ # Handle backward compatibility for content/completion
+ content = None
+ if "content" in data and data["content"] is not None:
+ # Reconstruct message from log format
+ content = Message.from_log(data["content"])
+ elif "completion" in data and data["completion"] is not None:
+ # Backward compatibility: create a Message with just text
+ content = Message.ai(data["completion"])
+
  return cls(
  id=data.get("id", random.randint(0, 1_000_000_000)),
  model_internal=data["model_internal"],
- model_external=data["model_external"],
- region=data["region"],
  prompt=Conversation.from_log(data["prompt"]),
  sampling_params=SamplingParams(**data["sampling_params"]),
  status_code=data["status_code"],
@@ -100,9 +117,14 @@ class APIResponse:
  error_message=data["error_message"],
  input_tokens=data["input_tokens"],
  output_tokens=data["output_tokens"],
- completion=data["completion"],
- finish_reason=data["finish_reason"],
- cost=data["cost"],
+ content=content,
+ thinking=data.get("thinking"),
+ model_external=data.get("model_external"),
+ region=data.get("region"),
+ logprobs=data.get("logprobs"),
+ finish_reason=data.get("finish_reason"),
+ cost=data.get("cost"),
+ cache_hit=data.get("cache_hit", False),
  )
 
  def write_to_file(self, filename):
@@ -145,6 +167,7 @@ class APIRequestBase(ABC):
  debug: bool = False,
  all_model_names: list[str] | None = None,
  all_sampling_params: list[SamplingParams] | None = None,
+ tools: list | None = None,
  ):
  if all_model_names is None:
  raise ValueError("all_model_names must be provided.")
@@ -166,6 +189,7 @@ class APIRequestBase(ABC):
  self.debug = debug
  self.all_model_names = all_model_names
  self.all_sampling_params = all_sampling_params
+ self.tools = tools
  self.result = [] # list of APIResponse objects from each attempt
 
  # these should be set in the __init__ of the subclass
@@ -255,6 +279,7 @@ class APIRequestBase(ABC):
  callback=self.callback,
  all_model_names=self.all_model_names,
  all_sampling_params=self.all_sampling_params,
+ tools=self.tools,
  )
  # PROBLEM: new request is never put into results array, so we can't get the result.
  self.retry_queue.put_nowait(new_request)
@@ -297,7 +322,7 @@ class APIRequestBase(ABC):
  status_code=None,
  is_error=True,
  error_message="Request timed out (terminated by client).",
- completion=None,
+ content=None,
  input_tokens=None,
  output_tokens=None,
  )
@@ -315,7 +340,7 @@ class APIRequestBase(ABC):
  status_code=None,
  is_error=True,
  error_message=f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}",
- completion=None,
+ content=None,
  input_tokens=None,
  output_tokens=None,
  )
@@ -344,6 +369,7 @@ def create_api_request(
  callback: Callable | None = None,
  all_model_names: list[str] | None = None,
  all_sampling_params: list[SamplingParams] | None = None,
+ tools: list | None = None,
  ) -> APIRequestBase:
  from .common import CLASSES # circular import so made it lazy, does this work?
 
@@ -368,5 +394,6 @@ def create_api_request(
  callback=callback,
  all_model_names=all_model_names,
  all_sampling_params=all_sampling_params,
+ tools=tools,
  **kwargs,
  )
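The `content` field and the computed `completion` property introduced above change how responses are represented: an assistant turn is now a `Message` holding `Text`, `Thinking`, and `ToolCall` parts, and `.completion` reads the text back out for old callers. A rough sketch of that relationship follows, assuming `Message.completion` simply returns the text part, which the diff implies but does not show.

```python
# Sketch only: Message("assistant", parts), Text(...), and ToolCall(...) appear in the
# diff above, but the exact behavior of Message.completion is an assumption.
from lm_deluge.prompt import Message, Text, ToolCall

msg = Message(
    "assistant",
    [
        Text("The weather in Paris is sunny."),
        ToolCall(id="call_1", name="get_weather", arguments={"city": "Paris"}),
    ],
)

# APIResponse.content stores a Message like this; APIResponse.completion is a computed
# property that returns the text portion, so code written against 0.0.7 that only reads
# .completion keeps working.
print(msg.completion)  # assumed to yield "The weather in Paris is sunny."
```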
lm_deluge-0.0.9/src/lm_deluge/api_requests/bedrock.py
@@ -0,0 +1,283 @@
+ import asyncio
+ import json
+ import os
+ from aiohttp import ClientResponse
+ from tqdm import tqdm
+ from typing import Callable
+
+ try:
+ from requests_aws4auth import AWS4Auth
+ except ImportError:
+ raise ImportError(
+ "aws4auth is required for bedrock support. Install with: pip install requests-aws4auth"
+ )
+
+ from lm_deluge.prompt import Conversation, Message, Text, ToolCall, Thinking
+ from .base import APIRequestBase, APIResponse
+
+ from ..tracker import StatusTracker
+ from ..sampling_params import SamplingParams
+ from ..models import APIModel
+
+
+ class BedrockRequest(APIRequestBase):
+ def __init__(
+ self,
+ task_id: int,
+ model_name: str,
+ prompt: Conversation,
+ attempts_left: int,
+ status_tracker: StatusTracker,
+ retry_queue: asyncio.Queue,
+ results_arr: list,
+ request_timeout: int = 30,
+ sampling_params: SamplingParams = SamplingParams(),
+ pbar: tqdm | None = None,
+ callback: Callable | None = None,
+ debug: bool = False,
+ all_model_names: list[str] | None = None,
+ all_sampling_params: list[SamplingParams] | None = None,
+ tools: list | None = None,
+ ):
+ super().__init__(
+ task_id=task_id,
+ model_name=model_name,
+ prompt=prompt,
+ attempts_left=attempts_left,
+ status_tracker=status_tracker,
+ retry_queue=retry_queue,
+ results_arr=results_arr,
+ request_timeout=request_timeout,
+ sampling_params=sampling_params,
+ pbar=pbar,
+ callback=callback,
+ debug=debug,
+ all_model_names=all_model_names,
+ all_sampling_params=all_sampling_params,
+ tools=tools,
+ )
+
+ self.model = APIModel.from_registry(model_name)
+
+ # Get AWS credentials from environment
+ self.access_key = os.getenv("AWS_ACCESS_KEY_ID")
+ self.secret_key = os.getenv("AWS_SECRET_ACCESS_KEY")
+ self.session_token = os.getenv("AWS_SESSION_TOKEN")
+
+ if not self.access_key or not self.secret_key:
+ raise ValueError(
+ "AWS credentials not found. Please set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables."
+ )
+
+ # Determine region - use us-west-2 for cross-region inference models
+ if self.model.name.startswith("us.anthropic."):
+ # Cross-region inference profiles should use us-west-2
+ self.region = "us-west-2"
+ else:
+ # Direct model IDs can use default region
+ self.region = getattr(self.model, "region", "us-east-1")
+ if hasattr(self.model, "regions") and self.model.regions:
+ if isinstance(self.model.regions, list):
+ self.region = self.model.regions[0]
+ elif isinstance(self.model.regions, dict):
+ self.region = list(self.model.regions.keys())[0]
+
+ # Construct the endpoint URL
+ self.service = "bedrock" # Service name for signing is 'bedrock' even though endpoint is bedrock-runtime
+ self.url = f"https://bedrock-runtime.{self.region}.amazonaws.com/model/{self.model.name}/invoke"
+
+ # Convert prompt to Anthropic format for bedrock
+ self.system_message, messages = prompt.to_anthropic()
+
+ # Prepare request body in Anthropic's bedrock format
+ self.request_json = {
+ "anthropic_version": "bedrock-2023-05-31",
+ "max_tokens": sampling_params.max_new_tokens,
+ "temperature": sampling_params.temperature,
+ "top_p": sampling_params.top_p,
+ "messages": messages,
+ }
+
+ if self.system_message is not None:
+ self.request_json["system"] = self.system_message
+
+ if tools:
+ self.request_json["tools"] = [tool.dump_for("anthropic") for tool in tools]
+
+ # Setup AWS4Auth for signing
+ self.auth = AWS4Auth(
+ self.access_key,
+ self.secret_key,
+ self.region,
+ self.service,
+ session_token=self.session_token,
+ )
+
+ # Setup basic headers (AWS4Auth will add the Authorization header)
+ self.request_header = {
+ "Content-Type": "application/json",
+ }
+
+ async def call_api(self):
+ """Override call_api to handle AWS4Auth signing."""
+ try:
+ import aiohttp
+
+ self.status_tracker.total_requests += 1
+ timeout = aiohttp.ClientTimeout(total=self.request_timeout)
+
+ # Prepare the request data
+ payload = json.dumps(self.request_json, separators=(",", ":")).encode(
+ "utf-8"
+ )
+
+ # Create a fake requests.PreparedRequest object for AWS4Auth to sign
+ import requests
+
+ fake_request = requests.Request(
+ method="POST",
+ url=self.url,
+ data=payload,
+ headers=self.request_header.copy(),
+ )
+
+ # Prepare the request so AWS4Auth can sign it properly
+ prepared_request = fake_request.prepare()
+
+ # Let AWS4Auth sign the prepared request
+ signed_request = self.auth(prepared_request)
+
+ # Extract the signed headers
+ signed_headers = dict(signed_request.headers)
+
+ async with aiohttp.ClientSession(timeout=timeout) as session:
+ async with session.post(
+ url=self.url,
+ headers=signed_headers,
+ data=payload,
+ ) as http_response:
+ response: APIResponse = await self.handle_response(http_response)
+
+ self.result.append(response)
+ if response.is_error:
+ self.handle_error(
+ create_new_request=response.retry_with_different_model or False,
+ give_up_if_no_other_models=response.give_up_if_no_other_models
+ or False,
+ )
+ else:
+ self.handle_success(response)
+
+ except asyncio.TimeoutError:
+ self.result.append(
+ APIResponse(
+ id=self.task_id,
+ model_internal=self.model_name,
+ prompt=self.prompt,
+ sampling_params=self.sampling_params,
+ status_code=None,
+ is_error=True,
+ error_message="Request timed out (terminated by client).",
+ content=None,
+ input_tokens=None,
+ output_tokens=None,
+ )
+ )
+ self.handle_error(create_new_request=False)
+
+ except Exception as e:
+ from ..errors import raise_if_modal_exception
+
+ raise_if_modal_exception(e)
+ self.result.append(
+ APIResponse(
+ id=self.task_id,
+ model_internal=self.model_name,
+ prompt=self.prompt,
+ sampling_params=self.sampling_params,
+ status_code=None,
+ is_error=True,
+ error_message=f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}",
+ content=None,
+ input_tokens=None,
+ output_tokens=None,
+ )
+ )
+ self.handle_error(create_new_request=False)
+
+ async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+ is_error = False
+ error_message = None
+ thinking = None
+ content = None
+ input_tokens = None
+ output_tokens = None
+ status_code = http_response.status
+ mimetype = http_response.headers.get("Content-Type", None)
+
+ if status_code >= 200 and status_code < 300:
+ try:
+ data = await http_response.json()
+ response_content = data["content"]
+
+ # Parse response into Message with parts
+ parts = []
+ for item in response_content:
+ if item["type"] == "text":
+ parts.append(Text(item["text"]))
+ elif item["type"] == "thinking":
+ thinking = item["thinking"]
+ parts.append(Thinking(item["thinking"]))
+ elif item["type"] == "tool_use":
+ parts.append(
+ ToolCall(
+ id=item["id"],
+ name=item["name"],
+ arguments=item["input"],
+ )
+ )
+
+ content = Message("assistant", parts)
+ input_tokens = data["usage"]["input_tokens"]
+ output_tokens = data["usage"]["output_tokens"]
+ except Exception as e:
+ is_error = True
+ error_message = (
+ f"Error calling .json() on response w/ status {status_code}: {e}"
+ )
+ elif mimetype and "json" in mimetype.lower():
+ is_error = True
+ data = await http_response.json()
+ error_message = json.dumps(data)
+ else:
+ is_error = True
+ text = await http_response.text()
+ error_message = text
+
+ # Handle special kinds of errors
+ if is_error and error_message is not None:
+ if (
+ "rate limit" in error_message.lower()
+ or "throttling" in error_message.lower()
+ or status_code == 429
+ ):
+ error_message += " (Rate limit error, triggering cooldown.)"
+ self.status_tracker.rate_limit_exceeded()
+ if "context length" in error_message or "too long" in error_message:
+ error_message += " (Context length exceeded, set retries to 0.)"
+ self.attempts_left = 0
+
+ return APIResponse(
+ id=self.task_id,
+ status_code=status_code,
+ is_error=is_error,
+ error_message=error_message,
+ prompt=self.prompt,
+ content=content,
+ thinking=thinking,
+ model_internal=self.model_name,
+ region=self.region,
+ sampling_params=self.sampling_params,
+ input_tokens=input_tokens,
+ output_tokens=output_tokens,
+ )
{lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/common.py
@@ -1,9 +1,11 @@
  from .openai import OpenAIRequest
  from .anthropic import AnthropicRequest
  from .mistral import MistralRequest
+ from .bedrock import BedrockRequest
 
  CLASSES = {
  "openai": OpenAIRequest,
  "anthropic": AnthropicRequest,
  "mistral": MistralRequest,
+ "bedrock": BedrockRequest,
  }
{lm_deluge-0.0.7 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/mistral.py
@@ -7,7 +7,7 @@ from tqdm.auto import tqdm
  from typing import Callable
 
  from .base import APIRequestBase, APIResponse
- from ..prompt import Conversation
+ from ..prompt import Conversation, Message
  from ..tracker import StatusTracker
  from ..sampling_params import SamplingParams
  from ..models import APIModel
@@ -130,7 +130,7 @@ class MistralRequest(APIRequestBase):
  error_message=error_message,
  prompt=self.prompt,
  logprobs=logprobs,
- completion=completion,
+ content=Message.ai(completion),
  model_internal=self.model_name,
  sampling_params=self.sampling_params,
  input_tokens=input_tokens,