lm-deluge 0.0.8.tar.gz → 0.0.9.tar.gz

This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of lm-deluge has been flagged as possibly problematic.

Files changed (54):
  1. {lm_deluge-0.0.8/src/lm_deluge.egg-info → lm_deluge-0.0.9}/PKG-INFO +1 -1
  2. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/pyproject.toml +1 -1
  3. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/anthropic.py +23 -7
  4. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/base.py +38 -11
  5. lm_deluge-0.0.9/src/lm_deluge/api_requests/bedrock.py +283 -0
  6. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/common.py +2 -0
  7. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/mistral.py +2 -2
  8. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/openai.py +37 -6
  9. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/client.py +18 -3
  10. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/models.py +89 -24
  11. lm_deluge-0.0.9/src/lm_deluge/prompt.py +693 -0
  12. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/tool.py +16 -35
  13. {lm_deluge-0.0.8 → lm_deluge-0.0.9/src/lm_deluge.egg-info}/PKG-INFO +1 -1
  14. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge.egg-info/SOURCES.txt +3 -0
  15. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/tests/test_all_models.py +4 -0
  16. lm_deluge-0.0.9/tests/test_bedrock_models.py +252 -0
  17. lm_deluge-0.0.9/tests/test_tool_calls.py +401 -0
  18. lm_deluge-0.0.8/src/lm_deluge/prompt.py +0 -357
  19. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/LICENSE +0 -0
  20. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/README.md +0 -0
  21. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/setup.cfg +0 -0
  22. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/__init__.py +0 -0
  23. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/__init__.py +0 -0
  24. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/deprecated/bedrock.py +0 -0
  25. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/deprecated/cohere.py +0 -0
  26. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/deprecated/deepseek.py +0 -0
  27. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/deprecated/mistral.py +0 -0
  28. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/api_requests/deprecated/vertex.py +0 -0
  29. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/cache.py +0 -0
  30. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/embed.py +0 -0
  31. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/errors.py +0 -0
  32. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/gemini_limits.py +0 -0
  33. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/image.py +0 -0
  34. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/llm_tools/__init__.py +0 -0
  35. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/llm_tools/extract.py +0 -0
  36. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/llm_tools/score.py +0 -0
  37. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/llm_tools/translate.py +0 -0
  38. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/rerank.py +0 -0
  39. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/sampling_params.py +0 -0
  40. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/tracker.py +0 -0
  41. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/util/json.py +0 -0
  42. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/util/logprobs.py +0 -0
  43. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/util/validation.py +0 -0
  44. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge/util/xml.py +0 -0
  45. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge.egg-info/dependency_links.txt +0 -0
  46. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge.egg-info/requires.txt +0 -0
  47. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/src/lm_deluge.egg-info/top_level.txt +0 -0
  48. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/tests/test_cache.py +0 -0
  49. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/tests/test_image_models.py +0 -0
  50. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/tests/test_image_utils.py +0 -0
  51. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/tests/test_json_utils.py +0 -0
  52. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/tests/test_sampling_params.py +0 -0
  53. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/tests/test_translate.py +0 -0
  54. {lm_deluge-0.0.8 → lm_deluge-0.0.9}/tests/test_xml_utils.py +0 -0
--- lm_deluge-0.0.8/src/lm_deluge.egg-info/PKG-INFO
+++ lm_deluge-0.0.9/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lm_deluge
-Version: 0.0.8
+Version: 0.0.9
 Summary: Python utility for using LLM API models.
 Author-email: Benjamin Anderson <ben@trytaylor.ai>
 Requires-Python: >=3.10
--- lm_deluge-0.0.8/pyproject.toml
+++ lm_deluge-0.0.9/pyproject.toml
@@ -3,7 +3,7 @@ requires = ["setuptools", "wheel"]
 
 [project]
 name = "lm_deluge"
-version = "0.0.8"
+version = "0.0.9"
 authors = [{ name = "Benjamin Anderson", email = "ben@trytaylor.ai" }]
 description = "Python utility for using LLM API models."
 readme = "README.md"
--- lm_deluge-0.0.8/src/lm_deluge/api_requests/anthropic.py
+++ lm_deluge-0.0.9/src/lm_deluge/api_requests/anthropic.py
@@ -6,7 +6,7 @@ import warnings
 from tqdm import tqdm
 from typing import Callable
 
-from lm_deluge.prompt import Conversation
+from lm_deluge.prompt import Conversation, Message, Text, ToolCall, Thinking
 from .base import APIRequestBase, APIResponse
 
 from ..tracker import StatusTracker
@@ -34,6 +34,7 @@ class AnthropicRequest(APIRequestBase):
         # for retries
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,
+        tools: list | None = None,
     ):
         super().__init__(
             task_id=task_id,
@@ -50,6 +51,7 @@ class AnthropicRequest(APIRequestBase):
             debug=debug,
             all_model_names=all_model_names,
             all_sampling_params=all_sampling_params,
+            tools=tools,
         )
         self.model = APIModel.from_registry(model_name)
         self.url = f"{self.model.api_base}/messages"
@@ -94,12 +96,14 @@ class AnthropicRequest(APIRequestBase):
         )
         if self.system_message is not None:
             self.request_json["system"] = self.system_message
+        if tools:
+            self.request_json["tools"] = [tool.dump_for("anthropic") for tool in tools]
 
     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
         is_error = False
         error_message = None
         thinking = None
-        completion = None
+        content = None
         input_tokens = None
         output_tokens = None
         status_code = http_response.status
@@ -119,14 +123,26 @@ class AnthropicRequest(APIRequestBase):
         if status_code >= 200 and status_code < 300:
             try:
                 data = await http_response.json()
-                content = data["content"]  # [0]["text"]
-                for item in content:
+                response_content = data["content"]
+
+                # Parse response into Message with parts
+                parts = []
+                for item in response_content:
                     if item["type"] == "text":
-                        completion = item["text"]
+                        parts.append(Text(item["text"]))
                     elif item["type"] == "thinking":
                         thinking = item["thinking"]
+                        parts.append(Thinking(item["thinking"]))
                     elif item["type"] == "tool_use":
-                        continue  # TODO: implement and report tool use
+                        parts.append(
+                            ToolCall(
+                                id=item["id"],
+                                name=item["name"],
+                                arguments=item["input"],
+                            )
+                        )
+
+                content = Message("assistant", parts)
                 input_tokens = data["usage"]["input_tokens"]
                 output_tokens = data["usage"]["output_tokens"]
             except Exception as e:
@@ -162,7 +178,7 @@ class AnthropicRequest(APIRequestBase):
             is_error=is_error,
             error_message=error_message,
             prompt=self.prompt,
-            completion=completion,
+            content=content,
             thinking=thinking,
             model_internal=self.model_name,
             sampling_params=self.sampling_params,
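
The net effect of the anthropic.py changes: tool definitions are serialized into the request via tool.dump_for("anthropic"), and the response is parsed into a structured assistant Message instead of a bare completion string. Below is a standalone sketch of that parsing loop, run against a made-up Anthropic-style content array (the payload values are illustrative, not from the diff):

    from lm_deluge.prompt import Message, Text, ToolCall, Thinking

    # Hypothetical response body in Anthropic's content-block format.
    blocks = [
        {"type": "thinking", "thinking": "User wants the weather in Paris."},
        {"type": "text", "text": "Let me look that up."},
        {"type": "tool_use", "id": "toolu_01", "name": "get_weather",
         "input": {"city": "Paris"}},
    ]

    # Same branch structure as handle_response above.
    parts = []
    for item in blocks:
        if item["type"] == "text":
            parts.append(Text(item["text"]))
        elif item["type"] == "thinking":
            parts.append(Thinking(item["thinking"]))
        elif item["type"] == "tool_use":
            parts.append(ToolCall(id=item["id"], name=item["name"],
                                  arguments=item["input"]))

    message = Message("assistant", parts)  # what APIResponse.content now holds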
--- lm_deluge-0.0.8/src/lm_deluge/api_requests/base.py
+++ lm_deluge-0.0.9/src/lm_deluge/api_requests/base.py
@@ -7,7 +7,7 @@ from dataclasses import dataclass
 from abc import ABC, abstractmethod
 from typing import Callable
 
-from lm_deluge.prompt import Conversation
+from lm_deluge.prompt import Conversation, Message
 
 from ..tracker import StatusTracker
 from ..sampling_params import SamplingParams
@@ -30,10 +30,12 @@ class APIResponse:
     error_message: str | None
 
     # completion information
-    completion: str | None
     input_tokens: int | None
     output_tokens: int | None
 
+    # response content - structured format
+    content: Message | None = None
+
     # optional or calculated automatically
     thinking: str | None = None  # if model shows thinking tokens
     model_external: str | None = None  # the model tag used by the API
@@ -47,6 +49,13 @@ class APIResponse:
     # set to true if should NOT retry with the same model (unrecoverable error)
    give_up_if_no_other_models: bool | None = False
 
+    @property
+    def completion(self) -> str | None:
+        """Backward compatibility: extract text from content Message."""
+        if self.content is not None:
+            return self.content.completion
+        return None
+
     def __post_init__(self):
         # calculate cost & get external model name
         self.id = int(self.id)
@@ -63,7 +72,7 @@ class APIResponse:
                 self.input_tokens * api_model.input_cost / 1e6
                 + self.output_tokens * api_model.output_cost / 1e6
             )
-        elif self.completion is not None:
+        elif self.content is not None and self.completion is not None:
             print(
                 f"Warning: Completion provided without token counts for model {self.model_internal}."
             )
@@ -79,7 +88,8 @@ class APIResponse:
             "status_code": self.status_code,
             "is_error": self.is_error,
             "error_message": self.error_message,
-            "completion": self.completion,
+            "completion": self.completion,  # computed property
+            "content": self.content.to_log() if self.content else None,
             "input_tokens": self.input_tokens,
             "output_tokens": self.output_tokens,
             "finish_reason": self.finish_reason,
@@ -88,11 +98,18 @@ class APIResponse:
 
     @classmethod
     def from_dict(cls, data: dict):
+        # Handle backward compatibility for content/completion
+        content = None
+        if "content" in data and data["content"] is not None:
+            # Reconstruct message from log format
+            content = Message.from_log(data["content"])
+        elif "completion" in data and data["completion"] is not None:
+            # Backward compatibility: create a Message with just text
+            content = Message.ai(data["completion"])
+
         return cls(
             id=data.get("id", random.randint(0, 1_000_000_000)),
             model_internal=data["model_internal"],
-            model_external=data["model_external"],
-            region=data["region"],
             prompt=Conversation.from_log(data["prompt"]),
             sampling_params=SamplingParams(**data["sampling_params"]),
             status_code=data["status_code"],
@@ -100,9 +117,14 @@ class APIResponse:
             error_message=data["error_message"],
             input_tokens=data["input_tokens"],
             output_tokens=data["output_tokens"],
-            completion=data["completion"],
-            finish_reason=data["finish_reason"],
-            cost=data["cost"],
+            content=content,
+            thinking=data.get("thinking"),
+            model_external=data.get("model_external"),
+            region=data.get("region"),
+            logprobs=data.get("logprobs"),
+            finish_reason=data.get("finish_reason"),
+            cost=data.get("cost"),
+            cache_hit=data.get("cache_hit", False),
         )
 
     def write_to_file(self, filename):
@@ -145,6 +167,7 @@ class APIRequestBase(ABC):
         debug: bool = False,
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,
+        tools: list | None = None,
     ):
         if all_model_names is None:
             raise ValueError("all_model_names must be provided.")
@@ -166,6 +189,7 @@ class APIRequestBase(ABC):
         self.debug = debug
         self.all_model_names = all_model_names
         self.all_sampling_params = all_sampling_params
+        self.tools = tools
        self.result = []  # list of APIResponse objects from each attempt
 
         # these should be set in the __init__ of the subclass
@@ -255,6 +279,7 @@ class APIRequestBase(ABC):
                 callback=self.callback,
                 all_model_names=self.all_model_names,
                 all_sampling_params=self.all_sampling_params,
+                tools=self.tools,
             )
             # PROBLEM: new request is never put into results array, so we can't get the result.
             self.retry_queue.put_nowait(new_request)
@@ -297,7 +322,7 @@ class APIRequestBase(ABC):
                     status_code=None,
                     is_error=True,
                     error_message="Request timed out (terminated by client).",
-                    completion=None,
+                    content=None,
                     input_tokens=None,
                     output_tokens=None,
                 )
@@ -315,7 +340,7 @@ class APIRequestBase(ABC):
                     status_code=None,
                     is_error=True,
                     error_message=f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}",
-                    completion=None,
+                    content=None,
                     input_tokens=None,
                     output_tokens=None,
                 )
@@ -344,6 +369,7 @@ def create_api_request(
     callback: Callable | None = None,
     all_model_names: list[str] | None = None,
     all_sampling_params: list[SamplingParams] | None = None,
+    tools: list | None = None,
 ) -> APIRequestBase:
     from .common import CLASSES  # circular import so made it lazy, does this work?
 
@@ -368,5 +394,6 @@ def create_api_request(
         callback=callback,
         all_model_names=all_model_names,
         all_sampling_params=all_sampling_params,
+        tools=tools,
         **kwargs,
     )
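
The APIResponse changes are designed so old call sites and old log files keep working: completion becomes a computed property over content, and from_dict upgrades legacy logs that only stored a completion string. A minimal sketch of that compat branch in isolation (assuming, as the property suggests, that Message.completion returns the message's text):

    from lm_deluge.prompt import Message

    def content_from_log(data: dict) -> Message | None:
        # Mirrors the branch order added to APIResponse.from_dict.
        if data.get("content") is not None:
            return Message.from_log(data["content"])  # new structured logs
        if data.get("completion") is not None:
            return Message.ai(data["completion"])  # legacy text-only logs
        return None

    msg = content_from_log({"content": None, "completion": "Hello!"})
    assert msg is not None and msg.completion == "Hello!"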
--- /dev/null
+++ lm_deluge-0.0.9/src/lm_deluge/api_requests/bedrock.py
@@ -0,0 +1,283 @@
+import asyncio
+import json
+import os
+from aiohttp import ClientResponse
+from tqdm import tqdm
+from typing import Callable
+
+try:
+    from requests_aws4auth import AWS4Auth
+except ImportError:
+    raise ImportError(
+        "aws4auth is required for bedrock support. Install with: pip install requests-aws4auth"
+    )
+
+from lm_deluge.prompt import Conversation, Message, Text, ToolCall, Thinking
+from .base import APIRequestBase, APIResponse
+
+from ..tracker import StatusTracker
+from ..sampling_params import SamplingParams
+from ..models import APIModel
+
+
+class BedrockRequest(APIRequestBase):
+    def __init__(
+        self,
+        task_id: int,
+        model_name: str,
+        prompt: Conversation,
+        attempts_left: int,
+        status_tracker: StatusTracker,
+        retry_queue: asyncio.Queue,
+        results_arr: list,
+        request_timeout: int = 30,
+        sampling_params: SamplingParams = SamplingParams(),
+        pbar: tqdm | None = None,
+        callback: Callable | None = None,
+        debug: bool = False,
+        all_model_names: list[str] | None = None,
+        all_sampling_params: list[SamplingParams] | None = None,
+        tools: list | None = None,
+    ):
+        super().__init__(
+            task_id=task_id,
+            model_name=model_name,
+            prompt=prompt,
+            attempts_left=attempts_left,
+            status_tracker=status_tracker,
+            retry_queue=retry_queue,
+            results_arr=results_arr,
+            request_timeout=request_timeout,
+            sampling_params=sampling_params,
+            pbar=pbar,
+            callback=callback,
+            debug=debug,
+            all_model_names=all_model_names,
+            all_sampling_params=all_sampling_params,
+            tools=tools,
+        )
+
+        self.model = APIModel.from_registry(model_name)
+
+        # Get AWS credentials from environment
+        self.access_key = os.getenv("AWS_ACCESS_KEY_ID")
+        self.secret_key = os.getenv("AWS_SECRET_ACCESS_KEY")
+        self.session_token = os.getenv("AWS_SESSION_TOKEN")
+
+        if not self.access_key or not self.secret_key:
+            raise ValueError(
+                "AWS credentials not found. Please set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables."
+            )
+
+        # Determine region - use us-west-2 for cross-region inference models
+        if self.model.name.startswith("us.anthropic."):
+            # Cross-region inference profiles should use us-west-2
+            self.region = "us-west-2"
+        else:
+            # Direct model IDs can use default region
+            self.region = getattr(self.model, "region", "us-east-1")
+            if hasattr(self.model, "regions") and self.model.regions:
+                if isinstance(self.model.regions, list):
+                    self.region = self.model.regions[0]
+                elif isinstance(self.model.regions, dict):
+                    self.region = list(self.model.regions.keys())[0]
+
+        # Construct the endpoint URL
+        self.service = "bedrock"  # Service name for signing is 'bedrock' even though endpoint is bedrock-runtime
+        self.url = f"https://bedrock-runtime.{self.region}.amazonaws.com/model/{self.model.name}/invoke"
+
+        # Convert prompt to Anthropic format for bedrock
+        self.system_message, messages = prompt.to_anthropic()
+
+        # Prepare request body in Anthropic's bedrock format
+        self.request_json = {
+            "anthropic_version": "bedrock-2023-05-31",
+            "max_tokens": sampling_params.max_new_tokens,
+            "temperature": sampling_params.temperature,
+            "top_p": sampling_params.top_p,
+            "messages": messages,
+        }
+
+        if self.system_message is not None:
+            self.request_json["system"] = self.system_message
+
+        if tools:
+            self.request_json["tools"] = [tool.dump_for("anthropic") for tool in tools]
+
+        # Setup AWS4Auth for signing
+        self.auth = AWS4Auth(
+            self.access_key,
+            self.secret_key,
+            self.region,
+            self.service,
+            session_token=self.session_token,
+        )
+
+        # Setup basic headers (AWS4Auth will add the Authorization header)
+        self.request_header = {
+            "Content-Type": "application/json",
+        }
+
+    async def call_api(self):
+        """Override call_api to handle AWS4Auth signing."""
+        try:
+            import aiohttp
+
+            self.status_tracker.total_requests += 1
+            timeout = aiohttp.ClientTimeout(total=self.request_timeout)
+
+            # Prepare the request data
+            payload = json.dumps(self.request_json, separators=(",", ":")).encode(
+                "utf-8"
+            )
+
+            # Create a fake requests.PreparedRequest object for AWS4Auth to sign
+            import requests
+
+            fake_request = requests.Request(
+                method="POST",
+                url=self.url,
+                data=payload,
+                headers=self.request_header.copy(),
+            )
+
+            # Prepare the request so AWS4Auth can sign it properly
+            prepared_request = fake_request.prepare()
+
+            # Let AWS4Auth sign the prepared request
+            signed_request = self.auth(prepared_request)
+
+            # Extract the signed headers
+            signed_headers = dict(signed_request.headers)
+
+            async with aiohttp.ClientSession(timeout=timeout) as session:
+                async with session.post(
+                    url=self.url,
+                    headers=signed_headers,
+                    data=payload,
+                ) as http_response:
+                    response: APIResponse = await self.handle_response(http_response)
+
+            self.result.append(response)
+            if response.is_error:
+                self.handle_error(
+                    create_new_request=response.retry_with_different_model or False,
+                    give_up_if_no_other_models=response.give_up_if_no_other_models
+                    or False,
+                )
+            else:
+                self.handle_success(response)
+
+        except asyncio.TimeoutError:
+            self.result.append(
+                APIResponse(
+                    id=self.task_id,
+                    model_internal=self.model_name,
+                    prompt=self.prompt,
+                    sampling_params=self.sampling_params,
+                    status_code=None,
+                    is_error=True,
+                    error_message="Request timed out (terminated by client).",
+                    content=None,
+                    input_tokens=None,
+                    output_tokens=None,
+                )
+            )
+            self.handle_error(create_new_request=False)
+
+        except Exception as e:
+            from ..errors import raise_if_modal_exception
+
+            raise_if_modal_exception(e)
+            self.result.append(
+                APIResponse(
+                    id=self.task_id,
+                    model_internal=self.model_name,
+                    prompt=self.prompt,
+                    sampling_params=self.sampling_params,
+                    status_code=None,
+                    is_error=True,
+                    error_message=f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}",
+                    content=None,
+                    input_tokens=None,
+                    output_tokens=None,
+                )
+            )
+            self.handle_error(create_new_request=False)
+
+    async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+        is_error = False
+        error_message = None
+        thinking = None
+        content = None
+        input_tokens = None
+        output_tokens = None
+        status_code = http_response.status
+        mimetype = http_response.headers.get("Content-Type", None)
+
+        if status_code >= 200 and status_code < 300:
+            try:
+                data = await http_response.json()
+                response_content = data["content"]
+
+                # Parse response into Message with parts
+                parts = []
+                for item in response_content:
+                    if item["type"] == "text":
+                        parts.append(Text(item["text"]))
+                    elif item["type"] == "thinking":
+                        thinking = item["thinking"]
+                        parts.append(Thinking(item["thinking"]))
+                    elif item["type"] == "tool_use":
+                        parts.append(
+                            ToolCall(
+                                id=item["id"],
+                                name=item["name"],
+                                arguments=item["input"],
+                            )
+                        )
+
+                content = Message("assistant", parts)
+                input_tokens = data["usage"]["input_tokens"]
+                output_tokens = data["usage"]["output_tokens"]
+            except Exception as e:
+                is_error = True
+                error_message = (
+                    f"Error calling .json() on response w/ status {status_code}: {e}"
+                )
+        elif mimetype and "json" in mimetype.lower():
+            is_error = True
+            data = await http_response.json()
+            error_message = json.dumps(data)
+        else:
+            is_error = True
+            text = await http_response.text()
+            error_message = text
+
+        # Handle special kinds of errors
+        if is_error and error_message is not None:
+            if (
+                "rate limit" in error_message.lower()
+                or "throttling" in error_message.lower()
+                or status_code == 429
+            ):
+                error_message += " (Rate limit error, triggering cooldown.)"
+                self.status_tracker.rate_limit_exceeded()
+            if "context length" in error_message or "too long" in error_message:
+                error_message += " (Context length exceeded, set retries to 0.)"
+                self.attempts_left = 0
+
+        return APIResponse(
+            id=self.task_id,
+            status_code=status_code,
+            is_error=is_error,
+            error_message=error_message,
+            prompt=self.prompt,
+            content=content,
+            thinking=thinking,
+            model_internal=self.model_name,
+            region=self.region,
+            sampling_params=self.sampling_params,
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+        )
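
The interesting trick in bedrock.py is the signing dance: aiohttp has no native SigV4 support, so the request is staged as a requests.PreparedRequest, signed by AWS4Auth, and the resulting headers are replayed on the async POST. A trimmed-down sketch of just that mechanism (model ID and region are example values):

    import json
    import os

    import aiohttp
    import requests
    from requests_aws4auth import AWS4Auth

    async def invoke_bedrock(body: dict, region: str = "us-west-2",
                             model_id: str = "us.anthropic.claude-3-5-sonnet-20241022-v2:0"):
        url = f"https://bedrock-runtime.{region}.amazonaws.com/model/{model_id}/invoke"
        auth = AWS4Auth(
            os.environ["AWS_ACCESS_KEY_ID"],
            os.environ["AWS_SECRET_ACCESS_KEY"],
            region,
            "bedrock",  # signing service name, even though the host is bedrock-runtime
            session_token=os.getenv("AWS_SESSION_TOKEN"),
        )
        payload = json.dumps(body, separators=(",", ":")).encode("utf-8")
        # Sign a throwaway synchronous request, then reuse its headers.
        prepared = requests.Request(
            "POST", url, data=payload, headers={"Content-Type": "application/json"}
        ).prepare()
        signed_headers = dict(auth(prepared).headers)
        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=signed_headers, data=payload) as resp:
                return resp.status, await resp.json()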
--- lm_deluge-0.0.8/src/lm_deluge/api_requests/common.py
+++ lm_deluge-0.0.9/src/lm_deluge/api_requests/common.py
@@ -1,9 +1,11 @@
 from .openai import OpenAIRequest
 from .anthropic import AnthropicRequest
 from .mistral import MistralRequest
+from .bedrock import BedrockRequest
 
 CLASSES = {
     "openai": OpenAIRequest,
     "anthropic": AnthropicRequest,
     "mistral": MistralRequest,
+    "bedrock": BedrockRequest,
 }
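
This registry is what create_api_request in base.py dispatches through (presumably keyed by the model's API spec), so registering BedrockRequest under "bedrock" is the only wiring a new provider needs beyond its request class. Illustrative lookup:

    from lm_deluge.api_requests.common import CLASSES

    request_cls = CLASSES["bedrock"]  # -> BedrockRequest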
--- lm_deluge-0.0.8/src/lm_deluge/api_requests/mistral.py
+++ lm_deluge-0.0.9/src/lm_deluge/api_requests/mistral.py
@@ -7,7 +7,7 @@ from tqdm.auto import tqdm
 from typing import Callable
 
 from .base import APIRequestBase, APIResponse
-from ..prompt import Conversation
+from ..prompt import Conversation, Message
 from ..tracker import StatusTracker
 from ..sampling_params import SamplingParams
 from ..models import APIModel
@@ -130,7 +130,7 @@ class MistralRequest(APIRequestBase):
             error_message=error_message,
             prompt=self.prompt,
             logprobs=logprobs,
-            completion=completion,
+            content=Message.ai(completion),
             model_internal=self.model_name,
             sampling_params=self.sampling_params,
             input_tokens=input_tokens,
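
Mistral responses are still parsed to a plain completion string; Message.ai just wraps it in the new structured format as a text-only assistant message (the same helper from_dict uses for legacy logs). A sketch of the intended equivalence, assuming Message.completion mirrors the wrapped text:

    from lm_deluge.prompt import Message, Text

    wrapped = Message.ai("Bonjour!")
    manual = Message("assistant", [Text("Bonjour!")])
    assert wrapped.completion == manual.completion == "Bonjour!"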
--- lm_deluge-0.0.8/src/lm_deluge/api_requests/openai.py
+++ lm_deluge-0.0.9/src/lm_deluge/api_requests/openai.py
@@ -7,7 +7,7 @@ from tqdm.auto import tqdm
 from typing import Callable
 
 from .base import APIRequestBase, APIResponse
-from ..prompt import Conversation
+from ..prompt import Conversation, Message, Text, ToolCall, Thinking
 from ..tracker import StatusTracker
 from ..sampling_params import SamplingParams
 from ..models import APIModel
@@ -34,6 +34,7 @@ class OpenAIRequest(APIRequestBase):
         debug: bool = False,
         all_model_names: list[str] | None = None,
         all_sampling_params: list[SamplingParams] | None = None,
+        tools: list | None = None,
     ):
         super().__init__(
             task_id=task_id,
@@ -52,6 +53,7 @@ class OpenAIRequest(APIRequestBase):
             debug=debug,
             all_model_names=all_model_names,
             all_sampling_params=all_sampling_params,
+            tools=tools,
         )
         self.model = APIModel.from_registry(model_name)
         self.url = f"{self.model.api_base}/chat/completions"
@@ -85,12 +87,16 @@ class OpenAIRequest(APIRequestBase):
             self.request_json["top_logprobs"] = top_logprobs
         if sampling_params.json_mode and self.model.supports_json:
             self.request_json["response_format"] = {"type": "json_object"}
+        if tools:
+            self.request_json["tools"] = [
+                tool.dump_for("openai-completions") for tool in tools
+            ]
 
     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
         is_error = False
         error_message = None
         thinking = None
-        completion = None
+        content = None
         input_tokens = None
         output_tokens = None
         logprobs = None
@@ -108,9 +114,34 @@ class OpenAIRequest(APIRequestBase):
         if not is_error:
             assert data is not None, "data is None"
             try:
-                completion = data["choices"][0]["message"]["content"]
-                if "reasoning_content" in data["choices"][0]["message"]:
-                    thinking = data["choices"][0]["message"]["reasoning_content"]
+                # Parse response into Message with parts
+                parts = []
+                message = data["choices"][0]["message"]
+
+                # Add text content if present
+                if message.get("content"):
+                    parts.append(Text(message["content"]))
+
+                # Add thinking content if present (reasoning models)
+                if "reasoning_content" in message:
+                    thinking = message["reasoning_content"]
+                    parts.append(Thinking(thinking))
+
+                # Add tool calls if present
+                if "tool_calls" in message:
+                    for tool_call in message["tool_calls"]:
+                        parts.append(
+                            ToolCall(
+                                id=tool_call["id"],
+                                name=tool_call["function"]["name"],
+                                arguments=json.loads(
+                                    tool_call["function"]["arguments"]
+                                ),
+                            )
+                        )
+
+                content = Message("assistant", parts)
+
                 input_tokens = data["usage"]["prompt_tokens"]
                 output_tokens = data["usage"]["completion_tokens"]
                 if self.logprobs and "logprobs" in data["choices"][0]:
@@ -144,7 +175,7 @@ class OpenAIRequest(APIRequestBase):
             prompt=self.prompt,
             logprobs=logprobs,
             thinking=thinking,
-            completion=completion,
+            content=content,
             model_internal=self.model_name,
             sampling_params=self.sampling_params,
             input_tokens=input_tokens,
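
One provider quirk worth noting in the OpenAI parsing above: function arguments arrive as a JSON-encoded string rather than a dict (Anthropic sends a dict in "input"), hence the json.loads. A standalone sketch with a made-up chat.completions message:

    import json

    from lm_deluge.prompt import Message, Text, ToolCall

    message = {  # hypothetical data["choices"][0]["message"]
        "content": "Checking the weather now.",
        "tool_calls": [{
            "id": "call_abc123",
            "function": {"name": "get_weather",
                         "arguments": "{\"city\": \"Paris\"}"},
        }],
    }

    parts = []
    if message.get("content"):
        parts.append(Text(message["content"]))
    for tc in message.get("tool_calls", []):
        parts.append(ToolCall(id=tc["id"], name=tc["function"]["name"],
                              arguments=json.loads(tc["function"]["arguments"])))

    content = Message("assistant", parts)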