lm-deluge 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lm_deluge/api_requests/__init__.py +0 -2
- lm_deluge/api_requests/anthropic.py +58 -84
- lm_deluge/api_requests/base.py +43 -229
- lm_deluge/api_requests/bedrock.py +173 -195
- lm_deluge/api_requests/common.py +2 -0
- lm_deluge/api_requests/gemini.py +196 -0
- lm_deluge/api_requests/mistral.py +30 -60
- lm_deluge/api_requests/openai.py +147 -148
- lm_deluge/api_requests/response.py +2 -1
- lm_deluge/batches.py +1 -1
- lm_deluge/{computer_use/anthropic_tools.py → built_in_tools/anthropic.py} +56 -5
- lm_deluge/built_in_tools/openai.py +28 -0
- lm_deluge/client.py +221 -150
- lm_deluge/file.py +7 -2
- lm_deluge/image.py +13 -8
- lm_deluge/llm_tools/extract.py +23 -4
- lm_deluge/llm_tools/ocr.py +1 -0
- lm_deluge/models.py +96 -2
- lm_deluge/prompt.py +43 -27
- lm_deluge/request_context.py +75 -0
- lm_deluge/tool.py +93 -15
- lm_deluge/tracker.py +1 -0
- lm_deluge/usage.py +10 -0
- {lm_deluge-0.0.14.dist-info → lm_deluge-0.0.16.dist-info}/METADATA +25 -1
- lm_deluge-0.0.16.dist-info/RECORD +48 -0
- lm_deluge-0.0.14.dist-info/RECORD +0 -44
- {lm_deluge-0.0.14.dist-info → lm_deluge-0.0.16.dist-info}/WHEEL +0 -0
- {lm_deluge-0.0.14.dist-info → lm_deluge-0.0.16.dist-info}/licenses/LICENSE +0 -0
- {lm_deluge-0.0.14.dist-info → lm_deluge-0.0.16.dist-info}/top_level.txt +0 -0
lm_deluge/{computer_use/anthropic_tools.py → built_in_tools/anthropic.py} RENAMED
@@ -11,14 +11,16 @@ def model_to_version(model: str) -> ToolVersion:
         return "2025-04-29"
     elif "3.7" in model:
         return "2025-01-24"
-
+    elif "3.6" in model:
         return "2024-10-22"
+    else:
+        raise ValueError("unsupported model for anthropic CUA")


 def get_anthropic_cu_tools(
     model: str,
-    display_width: int,
-    display_height: int,
+    display_width: int = 1024,
+    display_height: int = 768,
     exclude_tools: list[ToolType] | None = None,
 ):
     version = model_to_version(model)
@@ -31,8 +33,8 @@ def get_anthropic_cu_tools(
                 "display_height_px": display_height,
                 "display_number": None,
             },
-            {"name": "str_replace_editor", "type": "
-            {"
+            {"name": "str_replace_editor", "type": "text_editor_20241022"},
+            {"name": "bash", "type": "bash_20241022"},
         ]
     elif version == "2025-01-24":
         result = [
@@ -73,3 +75,52 @@ def get_anthropic_cu_tools(
     if "computer" in exclude_tools:
         result = [x for x in result if "computer" not in x["name"]]
     return result
+
+
+def bash_tool(model: str = "claude-4-sonnet"):
+    # Claude Sonnet 3.5 requires the computer-use-2024-10-22 beta header when using the bash tool.
+    # The bash tool is generally available in Claude 4 and Sonnet 3.7.
+    if "claude-4" in model:
+        return {"type": "text_editor_20250429", "name": "str_replace_based_edit_tool"}
+    elif "3.7" in model:
+        return {"type": "text_editor_20250124", "name": "str_replace_editor"}
+    else:
+        return {"type": "text_editor_20241022", "name": "str_replace_editor"}
+
+
+def text_editor_tool(model: str = "claude-4-sonnet"):
+    if "claude-4" in model:
+        return {"type": "bash_20250124", "name": "bash"}
+    elif "3.7" in model:
+        return {"type": "bash_20250124", "name": "bash"}
+    else:
+        return {"type": "bash_20241022", "name": "bash"}
+
+
+def web_search_tool(max_uses: int = 5):
+    res = {
+        "type": "web_search_20250305",
+        "name": "web_search",
+        # Optional: Limit the number of searches per request
+        "max_uses": 5,
+        # You can use either allowed_domains or blocked_domains, but not both in the same request.
+        # Optional: Only include results from these domains
+        # "allowed_domains": ["example.com", "trusteddomain.org"],
+        # Optional: Never include results from these domains
+        # "blocked_domains": ["untrustedsource.com"],
+        # Optional: Localize search results
+        # "user_location": {
+        #     "type": "approximate",
+        #     "city": "San Francisco",
+        #     "region": "California",
+        #     "country": "US",
+        #     "timezone": "America/Los_Angeles"
+        # }
+    }
+    return res
+
+
+def code_execution_tool():
+    # The code execution tool is currently in beta.
+    # This feature requires the beta header: "anthropic-beta": "code-execution-2025-05-22"
+    return {"type": "code_execution_20250522", "name": "code_execution"}
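The helpers added above return plain dicts in Anthropic's built-in tool format. A minimal usage sketch, assuming the dicts are passed wherever the client accepts raw tool definitions (the import path follows the new module location; the model string is illustrative):

# Sketch only: compose Anthropic built-in tool definitions from the new helpers.
from lm_deluge.built_in_tools.anthropic import (
    code_execution_tool,
    get_anthropic_cu_tools,
    web_search_tool,
)

# "3.7" in the name selects the 2025-01-24 computer-use tool version;
# passing exclude_tools=[] explicitly keeps nothing filtered out.
tools = get_anthropic_cu_tools("claude-3.7-sonnet", exclude_tools=[])
tools.append(web_search_tool())       # server-side web search (web_search_20250305)
tools.append(code_execution_tool())   # beta: needs "anthropic-beta": "code-execution-2025-05-22"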
lm_deluge/built_in_tools/openai.py ADDED
@@ -0,0 +1,28 @@
+def image_generation_openai():
+    # TODO: handle result properly
+    return {"type": "image_generation"}
+
+
+def code_interpreter_openai(container: dict | None = None):
+    if container is None:
+        container = {"type": "auto"}
+    return {"type": "code_interpreter", "container": container}
+
+
+def local_shell_openai():
+    return {"type": "local_shell"}
+
+
+def web_search_openai():
+    return {"type": "web_search_preview"}
+
+
+def computer_use_openai(
+    display_width: int = 1024, display_height: int = 768, environment: str = "browser"
+):
+    return {
+        "type": "computer_use_preview",
+        "display_width": display_width,
+        "display_height": display_height,
+        "environment": environment,
+    }
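The new OpenAI module mirrors this with dict builders for OpenAI's built-in tools. A short sketch of composing them (assumed usage; only the function names and return shapes come from the diff):

from lm_deluge.built_in_tools.openai import (
    code_interpreter_openai,
    computer_use_openai,
    web_search_openai,
)

tools = [
    web_search_openai(),                         # {"type": "web_search_preview"}
    code_interpreter_openai(),                   # container defaults to {"type": "auto"}
    computer_use_openai(environment="browser"),  # 1024x768 display by default
]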
lm_deluge/client.py CHANGED
@@ -1,4 +1,5 @@
 import asyncio
+import random
 from typing import Any, Literal, Self, Sequence, overload

 import numpy as np
@@ -13,12 +14,12 @@ from lm_deluge.batches import (
     wait_for_batch_completion_async,
 )
 from lm_deluge.prompt import CachePattern, Conversation, prompts_to_conversations
-from lm_deluge.tool import Tool
+from lm_deluge.tool import MCPServer, Tool

-from .api_requests import
-from .api_requests.base import APIRequestBase, APIResponse, deduplicate_responses
+from .api_requests.base import APIResponse
 from .config import SamplingParams
-from .models import registry
+from .models import APIModel, registry
+from .request_context import RequestContext
 from .tracker import StatusTracker

 # from .cache import LevelDBCache, SqliteCache
@@ -135,9 +136,7 @@ class LLMClient(BaseModel):
             print(
                 "WARNING: using top_logprobs can result in very large outputs. consider limiting max_new_tokens."
             )
-            if not all(
-                registry[model].get("supports_logprobs") for model in self.models
-            ):
+            if not all(registry[model].supports_logprobs for model in self.models):
                 raise ValueError(
                     "logprobs can only be enabled if all models support it."
                 )
@@ -174,6 +173,110 @@ class LLMClient(BaseModel):
         model_idx = np.random.choice(range(len(self.models)), p=self.model_weights)
         return self.models[model_idx], self.sampling_params[model_idx]

+    def _select_different_model(self, current_model: str):
+        """Select a model different from the provided one."""
+        other_models = [m for m in self.models if m != current_model]
+        if not other_models:
+            # No other models available, return current
+            return current_model, self.sampling_params[self.models.index(current_model)]
+
+        # Get weights for other models
+        other_indices = [self.models.index(m) for m in other_models]
+        weights = [self.model_weights[idx] for idx in other_indices]
+        weights = [w / sum(weights) for w in weights]  # type: ignore
+
+        model_idx = np.random.choice(range(len(other_models)), p=weights)
+        chosen_model = other_models[model_idx]
+        chosen_sp = self.sampling_params[self.models.index(chosen_model)]
+        return chosen_model, chosen_sp
+
+    async def _wait_for_capacity(self, num_tokens: int, tracker: StatusTracker):
+        while True:
+            if tracker.check_capacity(num_tokens):
+                tracker.set_limiting_factor(None)
+                return
+
+            if tracker.seconds_to_pause > 0:
+                await asyncio.sleep(tracker.seconds_to_pause)
+            else:
+                await asyncio.sleep(random.random())
+
+    async def _execute_request(self, context: RequestContext) -> APIResponse:
+        """Create and send a single API request using the provided context."""
+        model_obj = APIModel.from_registry(context.model_name)
+        request = model_obj.make_request(context)
+        response = await request.execute_once()
+        return response
+
+    async def process_single_request(
+        self, context: RequestContext, retry_queue: asyncio.Queue | None = None
+    ) -> APIResponse:
+        """Handle caching and single HTTP call for a request. Failed requests go to retry queue."""
+        # Check cache first
+        if self.cache:
+            cached = self.cache.get(context.prompt)
+            if cached:
+                cached.local_cache_hit = True
+                if context.status_tracker:
+                    context.status_tracker.task_succeeded(context.task_id)
+                return cached
+
+        # Execute single request
+        assert context.status_tracker
+        context.status_tracker.update_pbar()
+        response = await self._execute_request(context)
+
+        # Handle successful response
+        if not response.is_error:
+            context.status_tracker.task_succeeded(context.task_id)
+            # Cache successful responses immediately
+            if self.cache and response.completion:
+                self.cache.put(context.prompt, response)
+            # Call callback if provided
+            context.maybe_callback(response, context.status_tracker)
+            return response
+
+        # Handle error response - add to retry queue if available
+        if retry_queue and context.attempts_left > 1:
+            # Decide whether to retry with a different model
+            if response.retry_with_different_model and len(self.models) > 1:
+                # Switch to different model for retry
+                new_model, new_sp = self._select_different_model(context.model_name)
+                retry_context = context.copy(
+                    model_name=new_model,
+                    sampling_params=new_sp,
+                    attempts_left=context.attempts_left - 1,
+                )
+            else:
+                # Retry with same model
+                retry_context = context.copy(attempts_left=context.attempts_left - 1)
+
+            # Print error message for debugging
+            error_msg = (
+                f"Error task {context.task_id}. Model: {response.model_internal}"
+            )
+            if response.status_code:
+                error_msg += f" Code: {response.status_code},"
+            error_msg += f" Message: {response.error_message}. Retrying..."
+            print(error_msg)
+
+            # Add to retry queue for later processing
+            await retry_queue.put(retry_context)
+            return response  # Return the error response for now
+
+        # No retries left or no retry queue - final failure
+        context.status_tracker.task_failed(context.task_id)
+        context.maybe_callback(response, context.status_tracker)
+
+        # Print final error message
+        error_msg = f"Error task {context.task_id}. Model: {response.model_internal}"
+        if response.status_code:
+            error_msg += f" Code: {response.status_code},"
+        error_msg += f" Message: {response.error_message}. Giving up."
+        print(error_msg)
+
+        return response
+
     @overload
     async def process_prompts_async(
         self,
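The new `_execute_request` helper above goes through `APIModel.from_registry`, which reflects a broader change also visible in the `supports_logprobs` check earlier: registry entries are now objects with attributes rather than dicts accessed with `.get(...)`. A hedged sketch of the attribute-style access (the model key is illustrative, not taken from the registry):

from lm_deluge.models import APIModel, registry

entry = registry["gpt-4o-mini"]          # key is an assumption for illustration
print(entry.supports_logprobs, entry.api_spec)

model_obj = APIModel.from_registry("gpt-4o-mini")  # same lookup used by _execute_request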
@@ -181,11 +284,8 @@ class LLMClient(BaseModel):
         *,
         return_completions_only: Literal[True],
         show_progress: bool = ...,
-        tools: list[Tool] | None = ...,
+        tools: list[Tool | dict | MCPServer] | None = ...,
         cache: CachePattern | None = ...,
-        computer_use: bool = ...,
-        display_width: int = ...,
-        display_height: int = ...,
         use_responses_api: bool = ...,
     ) -> list[str | None]: ...

@@ -196,11 +296,8 @@ class LLMClient(BaseModel):
         *,
         return_completions_only: Literal[False] = ...,
         show_progress: bool = ...,
-        tools: list[Tool] | None = ...,
+        tools: list[Tool | dict | MCPServer] | None = ...,
         cache: CachePattern | None = ...,
-        computer_use: bool = ...,
-        display_width: int = ...,
-        display_height: int = ...,
         use_responses_api: bool = ...,
     ) -> list[APIResponse | None]: ...

@@ -210,147 +307,117 @@ class LLMClient(BaseModel):
         *,
         return_completions_only: bool = False,
         show_progress: bool = True,
-        tools: list[Tool] | None = None,
+        tools: list[Tool | dict | MCPServer] | None = None,
         cache: CachePattern | None = None,
-        computer_use: bool = False,
-        display_width: int = 1024,
-        display_height: int = 768,
         use_responses_api: bool = False,
     ) -> list[APIResponse | None] | list[str | None] | dict[str, int]:
-        #
+        # Convert prompts to Conversations - no upfront cache checking for dynamic caching!
         prompts = prompts_to_conversations(prompts)
-        ids =
-
-        # if using cache, check for cached completions
-        if self.cache:
-            cached_results = [self.cache.get(prompt) for prompt in prompts]
-            cache_hit_ids = [
-                id for id, res in zip(ids, cached_results) if res is not None
-            ]
-            cache_hit_results = [res for res in cached_results if res is not None]
-            assert len(cache_hit_ids) == len(
-                cache_hit_results
-            ), "Cache hit ids and results must be the same length."
-            remaining_ids = np.array([i for i in ids if i not in cache_hit_ids])
-            remaining_prompts = [prompts[i] for i in remaining_ids]
-            print(
-                f"{len(cache_hit_ids)} cache hits; {len(remaining_ids)} prompts remaining."
-            )
-
-        else:
-            cache_hit_ids = []
-            cache_hit_results = []
-            remaining_prompts = prompts
-            remaining_ids = ids
-
+        ids = list(range(len(prompts)))
         results: list[APIResponse | None] = [None for _ in range(len(prompts))]
-        if len(remaining_prompts) > 0:
-            # Create StatusTracker with integrated progress bar
-            tracker = StatusTracker(
-                max_requests_per_minute=self.max_requests_per_minute,
-                max_tokens_per_minute=self.max_tokens_per_minute,
-                max_concurrent_requests=self.max_concurrent_requests,
-                use_progress_bar=show_progress,
-                progress_bar_total=len(prompts),
-                progress_bar_disable=not show_progress,
-                use_rich=show_progress,  # Disable Rich if progress is disabled
-            )

-        (… 24 removed lines not shown in this diff view …)
+        # Create StatusTracker
+        tracker = StatusTracker(
+            max_requests_per_minute=self.max_requests_per_minute,
+            max_tokens_per_minute=self.max_tokens_per_minute,
+            max_concurrent_requests=self.max_concurrent_requests,
+            use_progress_bar=show_progress,
+            progress_bar_total=len(prompts),
+            progress_bar_disable=not show_progress,
+            use_rich=show_progress,
+        )
+
+        tracker.init_progress_bar()
+
+        # Create retry queue for failed requests
+        retry_queue: asyncio.Queue[RequestContext] = asyncio.Queue()
+
+        # Calculate sleep time for rate limiting
+        seconds_to_sleep_each_loop = (60.0 * 0.9) / tracker.max_requests_per_minute
+
+        # Main dispatch loop - using original pattern but with all prompts
+        next_context = None  # Persist across iterations like original
+        prompts_not_finished = True
+        prompts_iter = iter(zip(ids, prompts))
+
+        while True:
+            # Get next context (retry or new) - only if we don't already have one waiting
+            retry_request = False
+            if next_context is None:
+                if not retry_queue.empty():
+                    next_context = retry_queue.get_nowait()
+                    retry_request = True
+                    print(f"Retrying request {next_context.task_id}.")
+                elif prompts_not_finished:
+                    try:
+                        task_id, prompt = next(prompts_iter)
+                        model, sampling_params = self._select_model()
+                        assert isinstance(prompt, Conversation)
+                        next_context = RequestContext(
+                            task_id=task_id,
+                            model_name=model,
+                            prompt=prompt,
+                            sampling_params=sampling_params,
+                            attempts_left=self.max_attempts,
+                            request_timeout=self.request_timeout,
+                            status_tracker=tracker,
+                            tools=tools,
+                            cache=cache,
+                            use_responses_api=use_responses_api,
+                        )
+                    except StopIteration:
+                        prompts_not_finished = False
+
+            # Update capacity - original logic
+            tracker.update_capacity()
+
+            # Dispatch if capacity available - original logic
+            if next_context:
+                if tracker.check_capacity(next_context.num_tokens, retry=retry_request):
+                    tracker.set_limiting_factor(None)
+
+                    # Launch simplified request processing
+                    async def process_and_store(ctx: RequestContext):
                         try:
-
-
-                            # select model
-                            model, sampling_params = self._select_model()
-
-                            next_request = create_api_request(
-                                task_id=id,
-                                model_name=model,
-                                prompt=prompt,  # type: ignore
-                                request_timeout=self.request_timeout,
-                                attempts_left=self.max_attempts,
-                                status_tracker=tracker,
-                                results_arr=requests,
-                                sampling_params=sampling_params,
-                                all_model_names=self.models,
-                                all_sampling_params=self.sampling_params,
-                                tools=tools,
-                                cache=cache,
-                                computer_use=computer_use,
-                                display_width=display_width,
-                                display_height=display_height,
-                                use_responses_api=use_responses_api,
+                            response = await self.process_single_request(
+                                ctx, retry_queue
                             )
-        (… 32 removed lines not shown in this diff view …)
-            # after finishing, log final status
-            tracker.log_final_status()
-
-            # deduplicate results by id
-            api_results = deduplicate_responses(requests)
-            for res in api_results:
-                results[res.id] = res
-                # set to cache if result has a completion
-                if self.cache and res.completion:
-                    self.cache.put(prompts[res.id], res)
+                            results[ctx.task_id] = response
+                        except Exception as e:
+                            # Create an error response for validation errors and other exceptions
+                            from .api_requests.response import APIResponse
+
+                            error_response = APIResponse(
+                                id=ctx.task_id,
+                                model_internal=ctx.model_name,
+                                prompt=ctx.prompt,
+                                sampling_params=ctx.sampling_params,
+                                status_code=None,
+                                is_error=True,
+                                error_message=str(e),
+                            )
+                            results[ctx.task_id] = error_response
+                            # Mark task as completed so the main loop can finish
+                            if ctx.status_tracker:
+                                ctx.status_tracker.task_failed(ctx.task_id)
+
+                    asyncio.create_task(process_and_store(next_context))
+                    next_context = None  # Reset after successful dispatch
+
+            # Update progress - original logic
+            tracker.update_pbar()
+
+            # Check completion - original logic
+            if (
+                tracker.num_tasks_in_progress == 0
+                and not prompts_not_finished
+                and retry_queue.empty()
+            ):
+                break

-
-
-
-                results[id] = res
+            # Sleep - original logic
+            await asyncio.sleep(seconds_to_sleep_each_loop + tracker.seconds_to_pause)
+        tracker.log_final_status()

         if return_completions_only:
             return [r.completion if r is not None else None for r in results]
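Taken together, the rewritten loop above keeps the public surface of `process_prompts_async` while widening `tools` to accept `Tool` objects, raw dicts, and `MCPServer` instances, and delegating caching and retries to `process_single_request`. A hedged calling sketch (the constructor fields and model name are assumptions; only the keyword arguments come from the new signature):

import asyncio
from lm_deluge.built_in_tools.openai import web_search_openai
from lm_deluge.client import LLMClient

async def main():
    client = LLMClient(models=["gpt-4o-mini"])  # constructor fields assumed
    completions = await client.process_prompts_async(
        ["Summarize the changes in lm-deluge 0.0.16."],
        tools=[web_search_openai()],   # raw dict tools are now accepted
        return_completions_only=True,
        show_progress=False,
    )
    print(completions[0])

asyncio.run(main())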
@@ -363,7 +430,7 @@ class LLMClient(BaseModel):
         *,
         return_completions_only: bool = False,
         show_progress=True,
-        tools: list[Tool] | None = None,
+        tools: list[Tool | dict | MCPServer] | None = None,
         cache: CachePattern | None = None,
     ):
         return asyncio.run(
@@ -376,7 +443,11 @@ class LLMClient(BaseModel):
             )
         )

-    async def stream(
+    async def stream(
+        self,
+        prompt: str | Conversation,
+        tools: list[Tool | dict | MCPServer] | None = None,
+    ):
         model, sampling_params = self._select_model()
         if isinstance(prompt, str):
             prompt = Conversation.user(prompt)
@@ -409,7 +480,7 @@ class LLMClient(BaseModel):
         if len(self.models) != 1:
             raise ValueError("Batch jobs can only be submitted with a single model.")
         model = self.models[0]
-        api_spec = registry[model].
+        api_spec = registry[model].api_spec

         if api_spec == "openai":
             return await submit_batches_oa(model, self.sampling_params[0], prompts)
lm_deluge/file.py CHANGED
@@ -141,8 +141,13 @@ class File:
         return filename, content, media_type

     def gemini(self) -> dict:
-        """For Gemini API -
-
+        """For Gemini API - files are provided as inline data."""
+        return {
+            "inlineData": {
+                "mimeType": self._mime(),
+                "data": self._base64(include_header=False),
+            }
+        }

     def mistral(self) -> dict:
         """For Mistral API - not yet supported."""
lm_deluge/image.py CHANGED
@@ -1,6 +1,5 @@
 import os
 from contextlib import contextmanager
-from functools import cached_property
 import io
 import requests
 from PIL import Image as PILImage  # type: ignore
@@ -18,6 +17,8 @@ class Image:
     media_type: str | None = None  # inferred if None
     detail: Literal["low", "high", "auto"] = "auto"
     type: str = field(init=False, default="image")
+    _fingerprint_cache: str | None = field(init=False, default=None)
+    _size_cache: tuple[int, int] | None = field(init=False, default=None)

     @classmethod
     def from_pdf(
@@ -95,12 +96,14 @@ class Image:
         if img:
             img.close()

-    @cached_property
+    @property
     def size(self) -> tuple[int, int]:
-
-
+        if self._size_cache is None:
+            with self._image() as img:
+                self._size_cache = img.size
+        return self._size_cache

-    @cached_property
+    @property
     def num_pixels(self) -> int:
         return self.size[0] * self.size[1]

@@ -143,11 +146,13 @@ class Image:
         new_width = int(new_height / height * width)
         return self._resize((new_width, new_height))

-    @cached_property
+    @property
     def fingerprint(self) -> str:
         # return base64 of a very small version of the image
-
-
+        if self._fingerprint_cache is None:
+            small_image = self._resize_longer(max_size=48)  # longer side = 48px
+            self._fingerprint_cache = base64.b64encode(small_image).decode("utf-8")
+        return self._fingerprint_cache

     def resize(self, max_size: int) -> None:
         """