lm-deluge 0.0.15__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -1,78 +1,46 @@
-import warnings
-from aiohttp import ClientResponse
 import json
 import os
-from typing import Callable
+import warnings
+
+from aiohttp import ClientResponse

-from .base import APIRequestBase, APIResponse
-from ..prompt import Conversation, Message, CachePattern
-from ..usage import Usage
-from ..tracker import StatusTracker
-from ..config import SamplingParams
 from ..models import APIModel
+from ..prompt import Message
+from ..request_context import RequestContext
+from ..usage import Usage
+from .base import APIRequestBase, APIResponse


 class MistralRequest(APIRequestBase):
-    def __init__(
-        self,
-        task_id: int,
-        # should always be 'role', 'content' keys.
-        # internal logic should handle translating to specific API format
-        model_name: str,  # must correspond to registry
-        prompt: Conversation,
-        attempts_left: int,
-        status_tracker: StatusTracker,
-        results_arr: list,
-        request_timeout: int = 30,
-        sampling_params: SamplingParams = SamplingParams(),
-        callback: Callable | None = None,
-        all_model_names: list[str] | None = None,
-        all_sampling_params: list[SamplingParams] | None = None,
-        tools: list | None = None,
-        cache: CachePattern | None = None,
-    ):
-        super().__init__(
-            task_id=task_id,
-            model_name=model_name,
-            prompt=prompt,
-            attempts_left=attempts_left,
-            status_tracker=status_tracker,
-            results_arr=results_arr,
-            request_timeout=request_timeout,
-            sampling_params=sampling_params,
-            callback=callback,
-            all_model_names=all_model_names,
-            all_sampling_params=all_sampling_params,
-            tools=tools,
-            cache=cache,
-        )
+    def __init__(self, context: RequestContext):
+        super().__init__(context=context)

         # Warn if cache is specified for non-Anthropic model
-        if cache is not None:
+        if self.context.cache is not None:
             warnings.warn(
-                f"Cache parameter '{cache}' is only supported for Anthropic models, ignoring for {model_name}"
+                f"Cache parameter '{self.context.cache}' is only supported for Anthropic models, ignoring for {self.context.model_name}"
             )
-        self.model = APIModel.from_registry(model_name)
+        self.model = APIModel.from_registry(self.context.model_name)
         self.url = f"{self.model.api_base}/chat/completions"
         self.request_header = {
             "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
         }
         self.request_json = {
             "model": self.model.name,
-            "messages": prompt.to_mistral(),
-            "temperature": sampling_params.temperature,
-            "top_p": sampling_params.top_p,
-            "max_tokens": sampling_params.max_new_tokens,
+            "messages": self.context.prompt.to_mistral(),
+            "temperature": self.context.sampling_params.temperature,
+            "top_p": self.context.sampling_params.top_p,
+            "max_tokens": self.context.sampling_params.max_new_tokens,
         }
-        if sampling_params.reasoning_effort:
+        if self.context.sampling_params.reasoning_effort:
             warnings.warn(
-                f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
+                f"Ignoring reasoning_effort param for non-reasoning model: {self.context.model_name}"
             )
-        if sampling_params.logprobs:
+        if self.context.sampling_params.logprobs:
             warnings.warn(
-                f"Ignoring logprobs param for non-logprobs model: {model_name}"
+                f"Ignoring logprobs param for non-logprobs model: {self.context.model_name}"
             )
-        if sampling_params.json_mode and self.model.supports_json:
+        if self.context.sampling_params.json_mode and self.model.supports_json:
             self.request_json["response_format"] = {"type": "json_object"}

     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
@@ -84,6 +52,8 @@ class MistralRequest(APIRequestBase):
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
         data = None
+        assert self.context.status_tracker
+
         if status_code >= 200 and status_code < 300:
             try:
                 data = await http_response.json()
@@ -98,7 +68,7 @@ class MistralRequest(APIRequestBase):
                 completion = data["choices"][0]["message"]["content"]
                 usage = Usage.from_mistral_usage(data["usage"])
                 if (
-                    self.sampling_params.logprobs
+                    self.context.sampling_params.logprobs
                     and "logprobs" in data["choices"][0]
                 ):
                     logprobs = data["choices"][0]["logprobs"]["content"]
@@ -118,20 +88,20 @@ class MistralRequest(APIRequestBase):
         if is_error and error_message is not None:
             if "rate limit" in error_message.lower() or status_code == 429:
                 error_message += " (Rate limit error, triggering cooldown.)"
-                self.status_tracker.rate_limit_exceeded()
+                self.context.status_tracker.rate_limit_exceeded()
             if "context length" in error_message:
                 error_message += " (Context length exceeded, set retries to 0.)"
-                self.attempts_left = 0
+                self.context.attempts_left = 0

         return APIResponse(
-            id=self.task_id,
+            id=self.context.task_id,
             status_code=status_code,
             is_error=is_error,
             error_message=error_message,
-            prompt=self.prompt,
+            prompt=self.context.prompt,
             logprobs=logprobs,
             content=Message.ai(completion),
-            model_internal=self.model_name,
-            sampling_params=self.sampling_params,
+            model_internal=self.context.model_name,
+            sampling_params=self.context.sampling_params,
             usage=usage,
         )
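
Note on the hunks above: MistralRequest's long keyword constructor is replaced by a single RequestContext argument, and every per-request field is now read from self.context. The RequestContext class itself does not appear in this diff, so the following is only an assumed usage sketch, inferred from the attributes the new code reads (task_id, model_name, prompt, sampling_params, attempts_left, status_tracker, tools, cache); the real constructor signature may differ.

    # Assumed usage sketch -- not taken from the package source.
    from lm_deluge.config import SamplingParams
    from lm_deluge.request_context import RequestContext

    context = RequestContext(
        task_id=0,
        model_name="mistral-small",        # assumed registry key
        prompt=conversation,               # an existing Conversation object
        sampling_params=SamplingParams(),
        attempts_left=3,
        status_tracker=tracker,            # an existing StatusTracker (not shown in this diff)
    )
    request = MistralRequest(context)      # builds url, headers, and request_json from the context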
@@ -1,17 +1,16 @@
 import json
 import os
 import warnings
-from typing import Callable

 import aiohttp
 from aiohttp import ClientResponse

-from lm_deluge.tool import Tool
+from lm_deluge.request_context import RequestContext
+from lm_deluge.tool import MCPServer, Tool

 from ..config import SamplingParams
 from ..models import APIModel
 from ..prompt import CachePattern, Conversation, Message, Text, Thinking, ToolCall
-from ..tracker import StatusTracker
 from ..usage import Usage
 from .base import APIRequestBase, APIResponse

@@ -53,54 +52,36 @@ def _build_oa_chat_request(
     return request_json


+def _build_oa_responses_request(
+    model: APIModel,
+    prompt: Conversation,
+    tools: list[Tool] | None,
+    sampling_params: SamplingParams,
+):
+    pass  # TODO: implement
+
+
 class OpenAIRequest(APIRequestBase):
-    def __init__(
-        self,
-        task_id: int,
-        # should always be 'role', 'content' keys.
-        # internal logic should handle translating to specific API format
-        model_name: str,  # must correspond to registry
-        prompt: Conversation,
-        attempts_left: int,
-        status_tracker: StatusTracker,
-        results_arr: list,
-        request_timeout: int = 30,
-        sampling_params: SamplingParams = SamplingParams(),
-        callback: Callable | None = None,
-        all_model_names: list[str] | None = None,
-        all_sampling_params: list[SamplingParams] | None = None,
-        tools: list | None = None,
-        cache: CachePattern | None = None,
-    ):
-        super().__init__(
-            task_id=task_id,
-            model_name=model_name,
-            prompt=prompt,
-            attempts_left=attempts_left,
-            status_tracker=status_tracker,
-            results_arr=results_arr,
-            request_timeout=request_timeout,
-            sampling_params=sampling_params,
-            callback=callback,
-            all_model_names=all_model_names,
-            all_sampling_params=all_sampling_params,
-            tools=tools,
-            cache=cache,
-        )
+    def __init__(self, context: RequestContext):
+        # Pass context to parent, which will handle backwards compatibility
+        super().__init__(context=context)

         # Warn if cache is specified for non-Anthropic model
-        if cache is not None:
+        if self.context.cache is not None:
             warnings.warn(
-                f"Cache parameter '{cache}' is only supported for Anthropic models, ignoring for {model_name}"
+                f"Cache parameter '{self.context.cache}' is only supported for Anthropic models, ignoring for {self.context.model_name}"
             )
-        self.model = APIModel.from_registry(model_name)
+        self.model = APIModel.from_registry(self.context.model_name)
         self.url = f"{self.model.api_base}/chat/completions"
         self.request_header = {
             "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
         }

         self.request_json = _build_oa_chat_request(
-            self.model, prompt, tools, sampling_params
+            self.model,
+            self.context.prompt,
+            self.context.tools,
+            self.context.sampling_params,
         )

     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
@@ -114,6 +95,8 @@ class OpenAIRequest(APIRequestBase):
         mimetype = http_response.headers.get("Content-Type", None)
         data = None
         finish_reason = None
+        assert self.context.status_tracker
+
         if status_code >= 200 and status_code < 300:
             try:
                 data = await http_response.json()
@@ -156,7 +139,7 @@

                 usage = Usage.from_openai_usage(data["usage"])
                 if (
-                    self.sampling_params.logprobs
+                    self.context.sampling_params.logprobs
                     and "logprobs" in data["choices"][0]
                 ):
                     logprobs = data["choices"][0]["logprobs"]["content"]
@@ -176,22 +159,22 @@
         if is_error and error_message is not None:
             if "rate limit" in error_message.lower() or status_code == 429:
                 error_message += " (Rate limit error, triggering cooldown.)"
-                self.status_tracker.rate_limit_exceeded()
+                self.context.status_tracker.rate_limit_exceeded()
             if "context length" in error_message:
                 error_message += " (Context length exceeded, set retries to 0.)"
-                self.attempts_left = 0
+                self.context.attempts_left = 0

         return APIResponse(
-            id=self.task_id,
+            id=self.context.task_id,
             status_code=status_code,
             is_error=is_error,
             error_message=error_message,
-            prompt=self.prompt,
+            prompt=self.context.prompt,
             logprobs=logprobs,
             thinking=thinking,
             content=content,
-            model_internal=self.model_name,
-            sampling_params=self.sampling_params,
+            model_internal=self.context.model_name,
+            sampling_params=self.context.sampling_params,
             usage=usage,
             raw_response=data,
             finish_reason=finish_reason,
@@ -199,117 +182,78 @@


 class OpenAIResponsesRequest(APIRequestBase):
-    def __init__(
-        self,
-        task_id: int,
-        model_name: str,
-        prompt: Conversation,
-        attempts_left: int,
-        status_tracker: StatusTracker,
-        results_arr: list,
-        request_timeout: int = 30,
-        sampling_params: SamplingParams = SamplingParams(),
-        callback: Callable | None = None,
-        all_model_names: list[str] | None = None,
-        all_sampling_params: list[SamplingParams] | None = None,
-        tools: list | None = None,
-        cache: CachePattern | None = None,
-        computer_use: bool = False,
-        display_width: int = 1024,
-        display_height: int = 768,
-    ):
-        super().__init__(
-            task_id=task_id,
-            model_name=model_name,
-            prompt=prompt,
-            attempts_left=attempts_left,
-            status_tracker=status_tracker,
-            results_arr=results_arr,
-            request_timeout=request_timeout,
-            sampling_params=sampling_params,
-            callback=callback,
-            all_model_names=all_model_names,
-            all_sampling_params=all_sampling_params,
-            tools=tools,
-            cache=cache,
-        )
-
-        # Store computer use parameters
-        self.computer_use = computer_use
-        self.display_width = display_width
-        self.display_height = display_height
-
-        # Validate computer use requirements
-        if computer_use and model_name != "openai-computer-use-preview":
-            raise ValueError(
-                f"Computer use is only supported with openai-computer-use-preview model, got {model_name}"
-            )
-
+    def __init__(self, context: RequestContext):
+        super().__init__(context)
         # Warn if cache is specified for non-Anthropic model
-        if cache is not None:
+        if self.context.cache is not None:
             warnings.warn(
-                f"Cache parameter '{cache}' is only supported for Anthropic models, ignoring for {model_name}"
+                f"Cache parameter '{self.context.cache}' is only supported for Anthropic models, ignoring for {self.context.model_name}"
             )
-        self.model = APIModel.from_registry(model_name)
+        self.model = APIModel.from_registry(self.context.model_name)
         self.url = f"{self.model.api_base}/responses"
         self.request_header = {
             "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
         }

         # Convert conversation to input format for Responses API
-        openai_responses_format = prompt.to_openai_responses()
+        openai_responses_format = self.context.prompt.to_openai_responses()

         self.request_json = {
             "model": self.model.name,
             "input": openai_responses_format["input"],
-            "temperature": sampling_params.temperature,
-            "top_p": sampling_params.top_p,
+            "temperature": self.context.sampling_params.temperature,
+            "top_p": self.context.sampling_params.top_p,
         }

         # Add max_output_tokens for responses API
-        if sampling_params.max_new_tokens:
-            self.request_json["max_output_tokens"] = sampling_params.max_new_tokens
+        if self.context.sampling_params.max_new_tokens:
+            self.request_json["max_output_tokens"] = (
+                self.context.sampling_params.max_new_tokens
+            )

         if self.model.reasoning_model:
-            if sampling_params.reasoning_effort in [None, "none"]:
+            if self.context.sampling_params.reasoning_effort in [None, "none"]:
                 # gemini models can switch reasoning off
                 if "gemini" in self.model.id:
-                    self.sampling_params.reasoning_effort = "none"  # expects string
+                    self.context.sampling_params.reasoning_effort = (
+                        "none"  # expects string
+                    )
                 # openai models can only go down to "low"
                 else:
-                    self.sampling_params.reasoning_effort = "low"
+                    self.context.sampling_params.reasoning_effort = "low"
             self.request_json["temperature"] = 1.0
             self.request_json["top_p"] = 1.0
             self.request_json["reasoning"] = {
-                "effort": sampling_params.reasoning_effort
+                "effort": self.context.sampling_params.reasoning_effort
             }
         else:
-            if sampling_params.reasoning_effort:
+            if self.context.sampling_params.reasoning_effort:
                 warnings.warn(
-                    f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
+                    f"Ignoring reasoning_effort param for non-reasoning model: {self.context.model_name}"
                 )

-        if sampling_params.json_mode and self.model.supports_json:
+        if self.context.sampling_params.json_mode and self.model.supports_json:
             self.request_json["text"] = {"format": {"type": "json_object"}}

         # Handle tools
         request_tools = []
-        if computer_use:
-            # Add computer use tool
-            request_tools.append(
-                {
-                    "type": "computer_use_preview",
-                    "display_width": display_width,
-                    "display_height": display_height,
-                    "environment": "browser",  # Default to browser, could be configurable
-                }
-            )
-            # Set truncation to auto as required for computer use
-            self.request_json["truncation"] = "auto"
-
-        if tools:
+        if self.context.tools:
             # Add regular function tools
-            request_tools.extend([tool.dump_for("openai-responses") for tool in tools])
+            for tool in self.context.tools:
+                if isinstance(tool, Tool):
+                    request_tools.append(tool.dump_for("openai-responses"))
+                elif isinstance(tool, dict):
+                    # if computer use, make sure model supports it
+                    if tool["type"] == "computer_use_preview":
+                        if self.context.model_name != "openai-computer-use-preview":
+                            raise ValueError(
+                                f"model {self.context.model_name} does not support computer use"
+                            )
+                        # have to use truncation
+                        self.request_json["truncation"] = "auto"
+                    request_tools.append(tool)  # allow passing dict
+                elif isinstance(tool, MCPServer):
+                    request_tools.append(tool.for_openai_responses())

         if request_tools:
             self.request_json["tools"] = request_tools
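
Note on the constructor above: the computer_use, display_width, and display_height parameters are gone; context.tools may now mix Tool objects, raw dicts (passed through as-is), and MCPServer objects. A computer-use request is therefore expressed as a plain dict supplied by the caller. The dict shape below mirrors the one the removed code used to build; the RequestContext keyword arguments are assumed, as in the earlier sketch.

    # Assumed usage sketch -- the tool dict mirrors what the old constructor built.
    computer_tool = {
        "type": "computer_use_preview",
        "display_width": 1024,
        "display_height": 768,
        "environment": "browser",
    }
    context = RequestContext(
        task_id=1,
        model_name="openai-computer-use-preview",  # any other model raises ValueError
        prompt=conversation,                       # an existing Conversation object
        sampling_params=SamplingParams(),
        tools=[computer_tool],                     # Tool, dict, and MCPServer entries may be mixed
    )
    request = OpenAIResponsesRequest(context)      # also sets request_json["truncation"] = "auto"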
@@ -324,6 +268,7 @@ class OpenAIResponsesRequest(APIRequestBase):
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
         data = None
+        assert self.context.status_tracker

         if status_code >= 200 and status_code < 300:
             try:
@@ -352,26 +297,83 @@
                         for content_item in message_content:
                             if content_item.get("type") == "output_text":
                                 parts.append(Text(content_item["text"]))
-                            # Handle tool calls if present
-                            elif content_item.get("type") == "tool_call":
-                                tool_call = content_item["tool_call"]
-                                parts.append(
-                                    ToolCall(
-                                        id=tool_call["id"],
-                                        name=tool_call["function"]["name"],
-                                        arguments=json.loads(
-                                            tool_call["function"]["arguments"]
-                                        ),
-                                    )
-                                )
+                            elif content_item.get("type") == "refusal":
+                                parts.append(Text(content_item["refusal"]))
+                    elif item.get("type") == "reasoning":
+                        parts.append(Thinking(item["summary"]["text"]))
+                    elif item.get("type") == "function_call":
+                        parts.append(
+                            ToolCall(
+                                id=item["call_id"],
+                                name=item["name"],
+                                arguments=json.loads(item["arguments"]),
+                            )
+                        )
+                    elif item.get("type") == "mcp_call":
+                        parts.append(
+                            ToolCall(
+                                id=item["id"],
+                                name=item["name"],
+                                arguments=json.loads(item["arguments"]),
+                                built_in=True,
+                                built_in_type="mcp_call",
+                                extra_body={
+                                    "server_label": item["server_label"],
+                                    "error": item.get("error"),
+                                    "output": item.get("output"),
+                                },
+                            )
+                        )
+
                     elif item.get("type") == "computer_call":
-                        # Handle computer use actions
-                        action = item.get("action", {})
                         parts.append(
                             ToolCall(
                                 id=item["call_id"],
-                                name=f"_computer_{action.get('type', 'action')}",
-                                arguments=action,
+                                name="computer_call",
+                                arguments=item.get("action"),
+                                built_in=True,
+                                built_in_type="computer_call",
+                            )
+                        )
+
+                    elif item.get("type") == "web_search_call":
+                        parts.append(
+                            ToolCall(
+                                id=item["id"],
+                                name="web_search_call",
+                                arguments={},
+                                built_in=True,
+                                built_in_type="web_search_call",
+                                extra_body={"status": item["status"]},
+                            )
+                        )
+
+                    elif item.get("type") == "file_search_call":
+                        parts.append(
+                            ToolCall(
+                                id=item["id"],
+                                name="file_search_call",
+                                arguments={"queries": item["queries"]},
+                                built_in=True,
+                                built_in_type="file_search_call",
+                                extra_body={
+                                    "status": item["status"],
+                                    "results": item["results"],
+                                },
+                            )
+                        )
+                    elif item.get("type") == "image_generation_call":
+                        parts.append(
+                            ToolCall(
+                                id=item["id"],
+                                name="image_generation_call",
+                                arguments={},
+                                built_in=True,
+                                built_in_type="image_generation_call",
+                                extra_body={
+                                    "status": item["status"],
+                                    "result": item["result"],
+                                },
                             )
                         )

@@ -386,9 +388,6 @@
                 if "usage" in data:
                     usage = Usage.from_openai_usage(data["usage"])

-                # Extract response_id for computer use continuation
-                # response_id = data.get("id")
-
             except Exception as e:
                 is_error = True
                 error_message = f"Error parsing {self.model.name} responses API response: {str(e)}"
@@ -406,22 +405,22 @@
         if is_error and error_message is not None:
             if "rate limit" in error_message.lower() or status_code == 429:
                 error_message += " (Rate limit error, triggering cooldown.)"
-                self.status_tracker.rate_limit_exceeded()
+                self.context.status_tracker.rate_limit_exceeded()
             if "context length" in error_message:
                 error_message += " (Context length exceeded, set retries to 0.)"
-                self.attempts_left = 0
+                self.context.attempts_left = 0

         return APIResponse(
-            id=self.task_id,
+            id=self.context.task_id,
             status_code=status_code,
             is_error=is_error,
             error_message=error_message,
-            prompt=self.prompt,
+            prompt=self.context.prompt,
             logprobs=logprobs,
             thinking=thinking,
             content=content,
-            model_internal=self.model_name,
-            sampling_params=self.sampling_params,
+            model_internal=self.context.model_name,
+            sampling_params=self.context.sampling_params,
             usage=usage,
             raw_response=data,
         )
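
Note on the handle_response changes above: built-in Responses API output items (reasoning, MCP calls, computer use, web search, file search, image generation) are now surfaced as ToolCall parts flagged with built_in=True instead of being renamed or dropped. Downstream code could separate them from ordinary function calls roughly as follows; the attribute names on the returned message are assumed, not shown in this diff.

    # Sketch under assumed attribute names; APIResponse.content is the assistant Message.
    from lm_deluge.prompt import ToolCall

    for part in response.content.parts:                      # "parts" is assumed
        if isinstance(part, ToolCall) and part.built_in:
            print(part.built_in_type, part.extra_body)       # e.g. "web_search_call", {"status": ...}
        elif isinstance(part, ToolCall):
            handle_function_call(part.name, part.arguments)  # hypothetical handler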
@@ -35,7 +35,8 @@ class APIResponse:
     logprobs: list | None = None
     finish_reason: str | None = None  # make required later
     cost: float | None = None  # calculated automatically
-    cache_hit: bool = False  # manually set if true
+    cache_hit: bool = False  # manually set if true (provider-side caching)
+    local_cache_hit: bool = False  # set if hit our local dynamic cache
     # set to true if is_error and should be retried with a different model
     retry_with_different_model: bool | None = False
     # set to true if should NOT retry with the same model (unrecoverable error)
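
The new local_cache_hit flag lets callers tell responses served from the library's local dynamic cache apart from provider-side prompt-cache hits, which continue to use cache_hit. A minimal check might look like this:

    # Sketch: the two flags are independent and both default to False.
    if response.local_cache_hit:
        ...  # served from the local dynamic cache; no API call was made
    elif response.cache_hit:
        ...  # the provider reported a prompt-cache hit for this API call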
lm_deluge/batches.py CHANGED
@@ -218,7 +218,7 @@ async def submit_batches_anthropic(
     batch_tasks = []
     async with aiohttp.ClientSession() as session:
         for batch in batches:
-            url = f"{registry[model]['api_base']}/messages/batches"
+            url = f"{registry[model].api_base}/messages/batches"
             data = {"requests": batch}

             async def submit_batch(data, url, headers):
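
The batches.py change switches registry[model] from dict-style lookup to attribute access, which suggests registry entries are now objects exposing api_base as an attribute rather than plain dicts. A hedged illustration (the registry key is hypothetical):

    entry = registry["claude-3.5-sonnet"]        # hypothetical key
    url = f"{entry.api_base}/messages/batches"   # was: entry["api_base"]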