lm-deluge 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of lm-deluge might be problematic.

@@ -0,0 +1,75 @@
+ from typing import Literal
+
+ ToolVersion = Literal["2024-10-22", "2025-01-24", "2025-04-29"]
+ ToolType = Literal["bash", "computer", "editor"]
+
+
+ def model_to_version(model: str) -> ToolVersion:
+     if "opus" not in model and "sonnet" not in model:
+         raise ValueError("cannot use computer tools with incompatible model")
+     if "claude-4" in model:
+         return "2025-04-29"
+     elif "3.7" in model:
+         return "2025-01-24"
+     else:
+         return "2024-10-22"
+
+
+ def get_anthropic_cu_tools(
+     model: str,
+     display_width: int,
+     display_height: int,
+     exclude_tools: list[ToolType] | None = None,
+ ):
+     version = model_to_version(model)
+     if version == "2024-10-22":
+         result = [
+             {
+                 "name": "computer",
+                 "type": "computer_20241022",
+                 "display_width_px": display_width,
+                 "display_height_px": display_height,
+                 "display_number": None,
+             },
+             {"name": "str_replace_editor", "type": "text_editor_20250429"},
+             {"type": "bash_20250124", "name": "bash"},
+         ]
+     elif version == "2025-01-24":
+         result = [
+             {
+                 "name": "computer",
+                 "type": "computer_20250124",
+                 "display_width_px": display_width,
+                 "display_height_px": display_height,
+                 "display_number": None,
+             },
+             {"name": "str_replace_editor", "type": "text_editor_20250124"},
+             {"type": "bash_20250124", "name": "bash"},
+         ]
+     elif version == "2025-04-29":
+         result = [
+             {
+                 "name": "computer",
+                 "type": "computer_20250124",
+                 "display_width_px": display_width,
+                 "display_height_px": display_height,
+                 "display_number": None,
+             },
+             {"name": "str_replace_based_edit_tool", "type": "text_editor_20250429"},
+             {
+                 "name": "bash",
+                 "type": "bash_20250124",
+             },
+         ]
+     else:
+         raise ValueError("invalid tool version")
+
+     if exclude_tools is None:
+         return result
+     if "bash" in exclude_tools:
+         result = [x for x in result if x["name"] != "bash"]
+     if "editor" in exclude_tools:
+         result = [x for x in result if "edit" not in x["name"]]
+     if "computer" in exclude_tools:
+         result = [x for x in result if "computer" not in x["name"]]
+     return result
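
A minimal usage sketch of the helper added above, for orientation. The import path and the registry-style model name are assumptions; only get_anthropic_cu_tools, model_to_version, and their arguments come from the diff itself.

    # Assumed import path -- the diff does not show which module this file is.
    from lm_deluge.computer_use_tools import get_anthropic_cu_tools  # hypothetical path

    # "claude-3.7-sonnet" is an assumed registry-style name; per model_to_version,
    # the "3.7" substring selects the "2025-01-24" tool version.
    tools = get_anthropic_cu_tools(
        model="claude-3.7-sonnet",
        display_width=1280,
        display_height=800,
        exclude_tools=["bash"],  # drop the bash tool from the returned list
    )
    # `tools` is a list of dicts shaped for the `tools` field of an Anthropic Messages request.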
@@ -1,14 +1,15 @@
- from dataclasses import dataclass
+ from pydantic import BaseModel
  from typing import Literal


- @dataclass
- class SamplingParams:
+ class SamplingParams(BaseModel):
      temperature: float = 0.0
      top_p: float = 1.0
      json_mode: bool = False
      max_new_tokens: int = 512
      reasoning_effort: Literal["low", "medium", "high", None] = None
+     logprobs: bool = False
+     top_logprobs: int | None = None

      def to_vllm(self):
          try:
@@ -23,3 +24,9 @@ class SamplingParams:
              top_p=self.top_p,
              max_tokens=self.max_new_tokens,
          )
+
+
+ class ComputerUseParams(BaseModel):
+     enabled: bool = False
+     display_width: int = 1024
+     display_height: int = 768
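
A short sketch of constructing the new pydantic-based config objects, assuming they are importable from lm_deluge (the exact module path is not shown in this diff):

    from lm_deluge.config import SamplingParams, ComputerUseParams  # path assumed

    params = SamplingParams(temperature=0.2, max_new_tokens=256, logprobs=True, top_logprobs=5)
    computer = ComputerUseParams(enabled=True, display_width=1280, display_height=800)

    # Because these are now BaseModel subclasses rather than plain dataclasses,
    # field types are validated at construction time: for example,
    # SamplingParams(temperature="hot") raises a pydantic ValidationError
    # instead of silently storing a string.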
lm_deluge/embed.py CHANGED
@@ -1,12 +1,14 @@
  ### specific utility for cohere rerank api
- import os
- import numpy as np
- import aiohttp
- from tqdm.auto import tqdm
  import asyncio
+ import os
  import time
- from typing import Any
  from dataclasses import dataclass
+ from typing import Any
+
+ import aiohttp
+ import numpy as np
+ from tqdm.auto import tqdm
+
  from .tracker import StatusTracker

  registry = {
@@ -56,7 +58,6 @@ class EmbeddingRequest:
          texts: list[str],
          attempts_left: int,
          status_tracker: StatusTracker,
-         retry_queue: asyncio.Queue,
          request_timeout: int,
          pbar: tqdm | None = None,
          **kwargs, # openai or cohere specific params
@@ -66,7 +67,6 @@ class EmbeddingRequest:
          self.texts = texts
          self.attempts_left = attempts_left
          self.status_tracker = status_tracker
-         self.retry_queue = retry_queue
          self.request_timeout = request_timeout
          self.pbar = pbar
          self.result = []
@@ -89,7 +89,8 @@ class EmbeddingRequest:
          print(error_to_print)
          if self.attempts_left > 0:
              self.attempts_left -= 1
-             self.retry_queue.put_nowait(self)
+             assert self.status_tracker.retry_queue
+             self.status_tracker.retry_queue.put_nowait(self)
              return
          else:
              print(f"Task {self.task_id} out of tries.")
@@ -243,7 +244,11 @@ async def embed_parallel_async(

      # initialize trackers
      retry_queue = asyncio.Queue()
-     status_tracker = StatusTracker()
+     status_tracker = StatusTracker(
+         max_tokens_per_minute=10_000_000,
+         max_requests_per_minute=max_requests_per_minute,
+         max_concurrent_requests=1_000,
+     )
      next_request = None # variable to hold the next request to call

      # initialize available capacity counts
@@ -262,7 +267,8 @@
      while True:
          # get next request (if one is not already waiting for capacity)
          if next_request is None:
-             if not retry_queue.empty():
+             assert status_tracker.retry_queue
+             if not status_tracker.retry_queue.empty():
                  next_request = retry_queue.get_nowait()
                  print(f"Retrying request {next_request.task_id}.")
              elif prompts_not_finished:
@@ -285,7 +291,7 @@

              except StopIteration:
                  prompts_not_finished = False
-                 print("API requests finished, only retries remain.")
+                 # print("API requests finished, only retries remain.")

          # update available capacity
          current_time = time.time()
lm_deluge/models.py CHANGED
@@ -178,6 +178,21 @@ registry = {
      # ░███
      # █████
      # ░░░░░
+     "openai-computer-use-preview": {
+         "id": "openai-computer-use-preview",
+         "name": "computer-use-preview",
+         "api_base": "https://api.openai.com/v1",
+         "api_key_env_var": "OPENAI_API_KEY",
+         "supports_json": True,
+         "supports_logprobs": False,
+         "supports_responses": True,
+         "api_spec": "openai-responses",
+         "input_cost": 2.0,
+         "output_cost": 8.0,
+         "requests_per_minute": 20,
+         "tokens_per_minute": 100_000,
+         "reasoning_model": False,
+     },
      "o3": {
          "id": "o3",
          "name": "o3-2025-04-16",
@@ -185,6 +200,7 @@ registry = {
          "api_key_env_var": "OPENAI_API_KEY",
          "supports_json": False,
          "supports_logprobs": True,
+         "supports_responses": True,
          "api_spec": "openai",
          "input_cost": 10.0,
          "output_cost": 40.0,
@@ -199,6 +215,7 @@ registry = {
          "api_key_env_var": "OPENAI_API_KEY",
          "supports_json": False,
          "supports_logprobs": True,
+         "supports_responses": True,
          "api_spec": "openai",
          "input_cost": 1.1,
          "output_cost": 4.4,
@@ -213,6 +230,7 @@ registry = {
          "api_key_env_var": "OPENAI_API_KEY",
          "supports_json": True,
          "supports_logprobs": True,
+         "supports_responses": True,
          "api_spec": "openai",
          "input_cost": 2.0,
          "output_cost": 8.0,
@@ -227,6 +245,7 @@ registry = {
          "api_key_env_var": "OPENAI_API_KEY",
          "supports_json": True,
          "supports_logprobs": True,
+         "supports_responses": True,
          "api_spec": "openai",
          "input_cost": 0.4,
          "output_cost": 1.6,
@@ -241,6 +260,7 @@ registry = {
          "api_key_env_var": "OPENAI_API_KEY",
          "supports_json": True,
          "supports_logprobs": True,
+         "supports_responses": True,
          "api_spec": "openai",
          "input_cost": 0.1,
          "output_cost": 0.4,
@@ -255,6 +275,7 @@ registry = {
          "api_key_env_var": "OPENAI_API_KEY",
          "supports_json": False,
          "supports_logprobs": True,
+         "supports_responses": True,
          "api_spec": "openai",
          "input_cost": 75.0,
          "output_cost": 150.0,
@@ -269,6 +290,7 @@ registry = {
          "api_key_env_var": "OPENAI_API_KEY",
          "supports_json": False,
          "supports_logprobs": True,
+         "supports_responses": True,
          "api_spec": "openai",
          "input_cost": 1.1,
          "output_cost": 4.4,
@@ -283,6 +305,7 @@ registry = {
          "api_key_env_var": "OPENAI_API_KEY",
          "supports_json": False,
          "supports_logprobs": True,
+         "supports_responses": True,
          "api_spec": "openai",
          "input_cost": 15.0,
          "output_cost": 60.0,
@@ -297,6 +320,7 @@ registry = {
          "api_key_env_var": "OPENAI_API_KEY",
          "supports_json": False,
          "supports_logprobs": True,
+         "supports_responses": True,
          "api_spec": "openai",
          "input_cost": 15.0,
          "output_cost": 60.0,
@@ -311,6 +335,7 @@ registry = {
          "api_key_env_var": "OPENAI_API_KEY",
          "supports_json": False,
          "supports_logprobs": True,
+         "supports_responses": True,
          "api_spec": "openai",
          "input_cost": 3.0,
          "output_cost": 15.0,
@@ -325,6 +350,7 @@ registry = {
          "api_key_env_var": "OPENAI_API_KEY",
          "supports_json": True,
          "supports_logprobs": True,
+         "supports_responses": True,
          "api_spec": "openai",
          "input_cost": 5.0,
          "output_cost": 15.0,
@@ -338,6 +364,7 @@ registry = {
          "api_key_env_var": "OPENAI_API_KEY",
          "supports_json": True,
          "supports_logprobs": True,
+         "supports_responses": True,
          "api_spec": "openai",
          "input_cost": 0.15,
          "output_cost": 0.6,
@@ -351,6 +378,7 @@ registry = {
          "api_key_env_var": "OPENAI_API_KEY",
          "supports_json": True,
          "supports_logprobs": True,
+         "supports_responses": True,
          "api_spec": "openai",
          "input_cost": 0.0,
          "output_cost": 0.0,
@@ -364,6 +392,7 @@ registry = {
          "api_key_env_var": "OPENAI_API_KEY",
          "supports_json": True,
          "supports_logprobs": True,
+         "supports_responses": True,
          "api_spec": "openai",
          "input_cost": 0.5,
          "output_cost": 1.5,
@@ -377,6 +406,7 @@ registry = {
          "api_key_env_var": "OPENAI_API_KEY",
          "supports_json": True,
          "supports_logprobs": True,
+         "supports_responses": True,
          "api_spec": "openai",
          "input_cost": 10.0,
          "output_cost": 30.0,
@@ -390,6 +420,7 @@ registry = {
          "api_key_env_var": "OPENAI_API_KEY",
          "supports_json": False,
          "supports_logprobs": False,
+         "supports_responses": True,
          "api_spec": "openai",
          "input_cost": 30.0,
          "output_cost": 60.0,
@@ -403,6 +434,7 @@ registry = {
          "api_key_env_var": "OPENAI_API_KEY",
          "supports_json": False,
          "supports_logprobs": False,
+         "supports_responses": True,
          "api_spec": "openai",
          "input_cost": 60.0,
          "output_cost": 120.0,
@@ -1093,6 +1125,7 @@ class APIModel:
      output_cost: float | None = 0 # $ per million output tokens
      supports_json: bool = False
      supports_logprobs: bool = False
+     supports_responses: bool = False
      reasoning_model: bool = False
      regions: list[str] | dict[str, int] = field(default_factory=list)
      tokens_per_minute: int | None = None
lm_deluge/prompt.py CHANGED
@@ -4,10 +4,18 @@ import tiktoken
  import xxhash
  from dataclasses import dataclass, field
  from pathlib import Path
- from typing import Literal
+ from typing import Literal, Sequence
  from lm_deluge.models import APIModel
  from lm_deluge.image import Image

+ CachePattern = Literal[
+     "tools_only",
+     "system_and_tools",
+     "last_user_message",
+     "last_2_user_messages",
+     "last_3_user_messages",
+ ]
+
  ###############################################################################
  # 1. Low-level content blocks – either text or an image #
  ###############################################################################
@@ -91,24 +99,58 @@ class ToolCall:
  @dataclass(slots=True)
  class ToolResult:
      tool_call_id: str # references the ToolCall.id
-     result: str # tool execution result
+     result: (
+         str | dict | list[dict]
+     ) # tool execution result - can be string or list for images
      type: str = field(init=False, default="tool_result")

      @property
      def fingerprint(self) -> str:
-         return xxhash.xxh64(f"{self.tool_call_id}:{self.result}".encode()).hexdigest()
+         result_str = (
+             json.dumps(self.result, sort_keys=True)
+             if isinstance(self.result, list) or isinstance(self.result, dict)
+             else str(self.result)
+         )
+         return xxhash.xxh64(f"{self.tool_call_id}:{result_str}".encode()).hexdigest()

      # ── provider-specific emission ────────────────────────────────────────────
      def oa_chat(
          self,
      ) -> dict: # OpenAI Chat Completions - tool results are separate messages
-         return {"tool_call_id": self.tool_call_id, "content": self.result}
+         content = (
+             json.dumps(self.result) if isinstance(self.result, list) else self.result
+         )
+         return {"tool_call_id": self.tool_call_id, "content": content}

      def oa_resp(self) -> dict: # OpenAI Responses
+         # Check if this is a computer use output (special case)
+         if isinstance(self.result, dict) and self.result.get("_computer_use_output"):
+             # This is a computer use output, emit it properly
+             output_data = self.result.copy()
+             output_data.pop("_computer_use_output") # Remove marker
+
+             result = {
+                 "type": "computer_call_output",
+                 "call_id": self.tool_call_id,
+                 "output": output_data.get("output", {}),
+             }
+
+             # Add acknowledged safety checks if present
+             if "acknowledged_safety_checks" in output_data:
+                 result["acknowledged_safety_checks"] = output_data[
+                     "acknowledged_safety_checks"
+                 ]
+
+             return result
+
+         # Regular function result
+         result = (
+             json.dumps(self.result) if isinstance(self.result, list) else self.result
+         )
          return {
              "type": "function_result",
              "call_id": self.tool_call_id,
-             "result": self.result,
+             "result": result,
          }

      def anthropic(self) -> dict: # Anthropic Messages
@@ -420,6 +462,14 @@ class Message:

      def oa_resp(self) -> dict:
          content = [p.oa_resp() for p in self.parts]
+         # For OpenAI Responses API, handle tool results specially
+         if self.role == "tool" or (
+             self.role == "user" and any(isinstance(p, ToolResult) for p in self.parts)
+         ):
+             # Tool results are returned directly, not wrapped in a message
+             # This handles computer_call_output when stored as ToolResult
+             if len(self.parts) == 1 and isinstance(self.parts[0], ToolResult):
+                 return self.parts[0].oa_resp()
          return {"role": self.role, "content": content}

      def anthropic(self) -> dict:
@@ -514,9 +564,41 @@ class Conversation:

      def to_openai_responses(self) -> dict:
          # OpenAI Responses = single “input” array, role must be user/assistant
-         return {"input": [m.oa_resp() for m in self.messages if m.role != "system"]}
+         input_items = []
+
+         for m in self.messages:
+             if m.role == "system":
+                 continue
+             elif m.role == "assistant":
+                 # For assistant messages, extract computer calls as separate items
+                 text_parts = []
+                 for p in m.parts:
+                     if isinstance(p, ToolCall) and p.name.startswith("_computer_"):
+                         # Computer calls become separate items in the input array
+                         action_type = p.name.replace("_computer_", "")
+                         input_items.append(
+                             {
+                                 "type": "computer_call",
+                                 "call_id": p.id,
+                                 "action": {"type": action_type, **p.arguments},
+                             }
+                         )
+                     elif isinstance(p, Text):
+                         text_parts.append({"type": "output_text", "text": p.text})
+                     # TODO: Handle other part types as needed
+
+                 # Add message if it has text content
+                 if text_parts:
+                     input_items.append({"role": m.role, "content": text_parts})
+             else:
+                 # User and tool messages use normal format
+                 input_items.append(m.oa_resp())

-     def to_anthropic(self) -> tuple[str | None, list[dict]]:
+         return {"input": input_items}
+
+     def to_anthropic(
+         self, cache_pattern: CachePattern | None = None
+     ) -> tuple[str | list[dict] | None, list[dict]]:
          system_msg = next(
              (
                  m.parts[0].text
@@ -535,8 +617,84 @@ class Conversation:
                  other.append(user_msg.anthropic())
              else:
                  other.append(m.anthropic())
+
+         # Apply cache control if specified
+         if cache_pattern is not None:
+             system_msg, other = self._apply_cache_control(
+                 system_msg, other, cache_pattern
+             )
+
          return system_msg, other

+     def _apply_cache_control(
+         self,
+         system_msg: str | None | list[dict],
+         messages: list[dict],
+         cache_pattern: CachePattern,
+     ) -> tuple[str | list[dict] | None, list[dict]]:
+         """Apply cache control to system message and/or messages based on the pattern."""
+
+         if cache_pattern == "system_and_tools" and system_msg is not None:
+             # Convert system message to structured format with cache control
+             # This caches tools+system prefix (since system comes after tools)
+             system_msg = [
+                 {
+                     "type": "text",
+                     "text": system_msg,
+                     "cache_control": {"type": "ephemeral"},
+                 }
+             ]
+
+         if cache_pattern == "last_user_message":
+             # Cache the last user message
+             user_messages = [i for i, m in enumerate(messages) if m["role"] == "user"]
+             if user_messages:
+                 last_user_idx = user_messages[-1]
+                 self._add_cache_control_to_message(messages[last_user_idx])
+
+         elif cache_pattern == "last_2_user_messages":
+             # Cache the last 2 user messages
+             user_messages = [i for i, m in enumerate(messages) if m["role"] == "user"]
+             for idx in user_messages[-2:]:
+                 self._add_cache_control_to_message(messages[idx])
+
+         elif cache_pattern == "last_3_user_messages":
+             # Cache the last 3 user messages
+             user_messages = [i for i, m in enumerate(messages) if m["role"] == "user"]
+             for idx in user_messages[-3:]:
+                 self._add_cache_control_to_message(messages[idx])
+
+         return system_msg, messages
+
+     def lock_images_as_bytes(self) -> "Conversation":
+         """
+         Convert all images to bytes format to ensure they remain unchanged for caching.
+         This should be called when caching is enabled to prevent cache invalidation
+         from image reference changes.
+         """
+         for message in self.messages:
+             for part in message.parts:
+                 if isinstance(part, Image):
+                     # Force conversion to bytes if not already
+                     part.data = part._bytes()
+         return self
+
+     def _add_cache_control_to_message(self, message: dict) -> None:
+         """Add cache control to a message's content."""
+         content = message["content"]
+         if isinstance(content, str):
+             # Convert string content to structured format with cache control
+             message["content"] = [
+                 {
+                     "type": "text",
+                     "text": content,
+                     "cache_control": {"type": "ephemeral"},
+                 }
+             ]
+         elif isinstance(content, list) and content:
+             # Add cache control to the last content block
+             content[-1]["cache_control"] = {"type": "ephemeral"}
+
      def to_gemini(self) -> tuple[str | None, list[dict]]:
          system_msg = next(
              (
@@ -664,6 +822,14 @@ class Conversation:
          return cls(msgs)


+ def prompts_to_conversations(prompts: Sequence[str | list[dict] | Conversation]):
+     if any(isinstance(x, list) for x in prompts):
+         raise ValueError("can't convert list[dict] to conversation yet")
+     return [ # type: ignore
+         Conversation.user(p) if isinstance(p, str) else p for p in prompts
+     ]
+
+
  ###############################################################################
  # --------------------------------------------------------------------------- #
  # Basic usage examples #
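
A brief sketch of the new prompt-caching hook on Conversation.to_anthropic, for orientation. Conversation.user appears in this diff; the rest of the snippet is a hedged illustration rather than an excerpt from the package.

    conv = Conversation.user("Summarize this repository.")
    system_msg, messages = conv.to_anthropic(cache_pattern="last_user_message")
    # With a cache_pattern set, _apply_cache_control marks the final content block of the
    # targeted user message(s) with {"cache_control": {"type": "ephemeral"}}, the marker
    # Anthropic prompt caching expects; passing cache_pattern=None leaves the payload unchanged.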
lm_deluge/rerank.py CHANGED
@@ -1,10 +1,12 @@
  ### specific utility for cohere rerank api
- import os
- import aiohttp
- from tqdm.auto import tqdm
  import asyncio
+ import os
  import time
  from dataclasses import dataclass
+
+ import aiohttp
+ from tqdm.auto import tqdm
+
  from .tracker import StatusTracker

  registry = [
@@ -25,7 +27,6 @@ class RerankingRequest:
          top_k: int,
          attempts_left: int,
          status_tracker: StatusTracker,
-         retry_queue: asyncio.Queue,
          request_timeout: int,
          pbar: tqdm | None = None,
      ):
@@ -36,7 +37,6 @@ class RerankingRequest:
          self.top_k = top_k
          self.attempts_left = attempts_left
          self.status_tracker = status_tracker
-         self.retry_queue = retry_queue
          self.request_timeout = request_timeout
          self.pbar = pbar
          self.result = []
@@ -63,7 +63,8 @@ class RerankingRequest:
          print(error_to_print)
          if self.attempts_left > 0:
              self.attempts_left -= 1
-             self.retry_queue.put_nowait(self)
+             assert self.status_tracker.retry_queue
+             self.status_tracker.retry_queue.put_nowait(self)
              return
          else:
              print(f"Task {self.task_id} out of tries.")
@@ -203,8 +204,12 @@ async def rerank_parallel_async(
      seconds_to_sleep_each_loop = 0.003 # so concurrent tasks can run

      # initialize trackers
-     retry_queue = asyncio.Queue()
-     status_tracker = StatusTracker()
+     # retry_queue = asyncio.Queue()
+     status_tracker = StatusTracker(
+         max_tokens_per_minute=10_000_000,
+         max_requests_per_minute=max_requests_per_minute,
+         max_concurrent_requests=1_000,
+     )
      next_request = None # variable to hold the next request to call

      # initialize available capacity counts
@@ -222,8 +227,10 @@
      while True:
          # get next request (if one is not already waiting for capacity)
          if next_request is None:
-             if not retry_queue.empty():
-                 next_request = retry_queue.get_nowait()
+             assert status_tracker.retry_queue
+
+             if not status_tracker.retry_queue.empty():
+                 next_request = status_tracker.retry_queue.get_nowait()
                  print(f"Retrying request {next_request.task_id}.")
              elif prompts_not_finished:
                  try:
@@ -237,7 +244,6 @@
                          top_k=top_k,
                          attempts_left=max_attempts,
                          status_tracker=status_tracker,
-                         retry_queue=retry_queue,
                          request_timeout=request_timeout,
                          pbar=progress_bar,
                      )
@@ -246,7 +252,7 @@

              except StopIteration:
                  prompts_not_finished = False
-                 print("API requests finished, only retries remain.")
+                 # print("API requests finished, only retries remain.")

          # update available capacity
          current_time = time.time()
lm_deluge/tool.py CHANGED
@@ -4,7 +4,7 @@ import asyncio

  from fastmcp import Client # pip install fastmcp >= 2.0
  from mcp.types import Tool as MCPTool
- from pydantic import BaseModel, Field
+ from pydantic import BaseModel, Field, field_validator


  async def _load_all_mcp_tools(client: Client) -> list["Tool"]:
@@ -46,6 +46,16 @@ class Tool(BaseModel):
      # if desired, can provide a callable to run the tool
      run: Callable | None = None

+     @field_validator("name")
+     @classmethod
+     def validate_name(cls, v: str) -> str:
+         if v.startswith("_computer_"):
+             raise ValueError(
+                 f"Tool name '{v}' uses reserved prefix '_computer_'. "
+                 "This prefix is reserved for computer use actions."
+             )
+         return v
+
      def _is_async(self) -> bool:
          return inspect.iscoroutinefunction(self.run)
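
To illustrate the effect of the new validator, a hedged sketch follows. Tool's other required fields are not shown in this diff, so the construction below is deliberately incomplete and will also report missing-field errors.

    from pydantic import ValidationError
    from lm_deluge.tool import Tool

    try:
        Tool(name="_computer_screenshot")  # other fields omitted / hypothetical
    except ValidationError as exc:
        # The aggregated errors include the reserved-prefix rejection raised by
        # validate_name, alongside any missing-field errors.
        print(exc)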
61