lm-deluge 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of lm-deluge might be problematic.

lm_deluge/batches.py CHANGED
@@ -3,7 +3,7 @@ import json
  import time
  import asyncio
  import aiohttp
- import pandas as pd
+ import tempfile
  from lm_deluge.prompt import CachePattern, Conversation, prompts_to_conversations
  from lm_deluge.config import SamplingParams
  from lm_deluge.models import APIModel
@@ -16,6 +16,7 @@ from rich.spinner import Spinner
  from rich.table import Table
  from rich.text import Text
  from lm_deluge.models import registry
+ from lm_deluge.request_context import RequestContext


  def _create_batch_status_display(
@@ -79,11 +80,8 @@ def _create_batch_status_display(
  return grid


- async def submit_batch_oa(batch_requests: list[dict]):
- """Submit one batch asynchronously."""
- pd.DataFrame(batch_requests).to_json(
- "requests_temp.jsonl", orient="records", lines=True
- )
+ async def submit_batch_oa(file_path: str):
+ """Upload a JSONL file and create one OpenAI batch."""

  # upload the file
  api_key = os.environ.get("OPENAI_API_KEY", None)
@@ -99,21 +97,22 @@ async def submit_batch_oa(batch_requests: list[dict]):
  url = "https://api.openai.com/v1/files"
  data = aiohttp.FormData()
  data.add_field("purpose", "batch")
- data.add_field(
- "file",
- open("requests_temp.jsonl", "rb"),
- filename="requests_temp.jsonl",
- content_type="application/json",
- )
+ with open(file_path, "rb") as f:
+ data.add_field(
+ "file",
+ f,
+ filename=os.path.basename(file_path),
+ content_type="application/json",
+ )

- async with session.post(url, data=data, headers=headers) as response:
- if response.status != 200:
- text = await response.text()
- raise ValueError(f"Error uploading file: {text}")
+ async with session.post(url, data=data, headers=headers) as response:
+ if response.status != 200:
+ text = await response.text()
+ raise ValueError(f"Error uploading file: {text}")

- print("File uploaded successfully")
- response_data = await response.json()
- file_id = response_data["id"]
+ print("File uploaded successfully")
+ response_data = await response.json()
+ file_id = response_data["id"]

  # Create batch
  url = "https://api.openai.com/v1/batches"
@@ -131,46 +130,82 @@ async def submit_batch_oa(batch_requests: list[dict]):
  response_data = await response.json()
  batch_id = response_data["id"]
  print("Batch job started successfully: id = ", batch_id)
- return batch_id
+
+ os.remove(file_path)
+ return batch_id
+
+
+ async def _submit_anthropic_batch(file_path: str, headers: dict, model: str):
+ """Upload a JSONL file and create one Anthropic batch."""
+
+ async with aiohttp.ClientSession() as session:
+ url = f"{registry[model].api_base}/messages/batches"
+ data = aiohttp.FormData()
+ with open(file_path, "rb") as f:
+ data.add_field(
+ "file",
+ f,
+ filename=os.path.basename(file_path),
+ content_type="application/json",
+ )
+
+ async with session.post(url, data=data, headers=headers) as response:
+ if response.status != 200:
+ text = await response.text()
+ raise ValueError(f"Error creating batch: {text}")
+
+ batch_data = await response.json()
+ batch_id = batch_data["id"]
+ print(f"Anthropic batch job started successfully: id = {batch_id}")
+
+ os.remove(file_path)
+ return batch_id


  async def submit_batches_oa(
  model: str,
  sampling_params: SamplingParams,
  prompts: Sequence[str | list[dict] | Conversation],
+ batch_size: int = 50_000,
  ):
- # if prompts are strings, convert them to message lists
+ """Write OpenAI batch requests to a file and submit."""
+ BATCH_SIZE = batch_size
+
  prompts = prompts_to_conversations(prompts)
  if any(p is None for p in prompts):
  raise ValueError("All prompts must be valid.")
- ids = [i for i, _ in enumerate(prompts)]

- # create file with requests to send to batch api
- batch_requests = []
  model_obj = APIModel.from_registry(model)
- for id, prompt in zip(ids, prompts):
- assert isinstance(prompt, Conversation)
- batch_requests.append(
- {
- "custom_id": str(id),
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": _build_oa_chat_request(model_obj, prompt, [], sampling_params),
- }
- )

- # since the api only accepts up to 50,000 requests per batch job, we chunk into 50k chunks
- BATCH_SIZE = 50_000
- batches = [
- batch_requests[i : i + BATCH_SIZE]
- for i in range(0, len(batch_requests), BATCH_SIZE)
- ]
  tasks = []
- for batch in batches:
- tasks.append(asyncio.create_task(submit_batch_oa(batch)))
+
+ for start in range(0, len(prompts), BATCH_SIZE):
+ batch_prompts = prompts[start : start + BATCH_SIZE]
+ with tempfile.NamedTemporaryFile(mode="w+", suffix=".jsonl", delete=False) as f:
+ for idx, prompt in enumerate(batch_prompts, start=start):
+ assert isinstance(prompt, Conversation)
+ context = RequestContext(
+ task_id=idx,
+ model_name=model,
+ prompt=prompt,
+ sampling_params=sampling_params,
+ )
+ request = {
+ "custom_id": str(idx),
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": _build_oa_chat_request(model_obj, context),
+ }
+ json.dump(request, f)
+ f.write("\n")
+
+ file_path = f.name
+
+ tasks.append(asyncio.create_task(submit_batch_oa(file_path)))
+
  batch_ids = await asyncio.gather(*tasks)

- print(f"Submitted {len(batches)} batch jobs.")
+ print(f"Submitted {len(tasks)} batch jobs.")

  return batch_ids

@@ -181,6 +216,7 @@ async def submit_batches_anthropic(
  prompts: Sequence[str | list[dict] | Conversation],
  *,
  cache: CachePattern | None = None,
+ batch_size=100_000,
  ):
  """Submit a batch job to Anthropic's Message Batches API.

@@ -196,47 +232,40 @@ async def submit_batches_anthropic(

  # Convert prompts to Conversations
  prompts = prompts_to_conversations(prompts)
- # Create batch requests
- request_headers = None
- batch_requests = []
- for i, prompt in enumerate(prompts):
- assert isinstance(prompt, Conversation)
- # Build request body
- request_body, request_headers = _build_anthropic_request(
- APIModel.from_registry(model), prompt, [], sampling_params, cache
- )

- batch_requests.append({"custom_id": str(i), "params": request_body})
-
- # Chunk into batches of 100k requests (Anthropic's limit)
- BATCH_SIZE = 100_000
- batches = [
- batch_requests[i : i + BATCH_SIZE]
- for i in range(0, len(batch_requests), BATCH_SIZE)
- ]
- batch_ids = []
+ request_headers = None
+ BATCH_SIZE = batch_size
  batch_tasks = []
- async with aiohttp.ClientSession() as session:
- for batch in batches:
- url = f"{registry[model].api_base}/messages/batches"
- data = {"requests": batch}

- async def submit_batch(data, url, headers):
- async with session.post(url, json=data, headers=headers) as response:
- if response.status != 200:
- text = await response.text()
- raise ValueError(f"Error creating batch: {text}")
+ for start in range(0, len(prompts), BATCH_SIZE):
+ batch_prompts = prompts[start : start + BATCH_SIZE]
+ with tempfile.NamedTemporaryFile(mode="w+", suffix=".jsonl", delete=False) as f:
+ for idx, prompt in enumerate(batch_prompts, start=start):
+ assert isinstance(prompt, Conversation)
+ context = RequestContext(
+ task_id=idx,
+ model_name=model,
+ prompt=prompt,
+ sampling_params=sampling_params,
+ cache=cache,
+ )
+ request_body, request_headers = _build_anthropic_request(
+ APIModel.from_registry(model), context
+ )
+ json.dump({"custom_id": str(idx), "params": request_body}, f)
+ f.write("\n")

- batch_data = await response.json()
- batch_id = batch_data["id"]
- print(f"Anthropic batch job started successfully: id = {batch_id}")
- return batch_id
+ file_path = f.name

- batch_tasks.append(submit_batch(data, url, request_headers))
+ batch_tasks.append(
+ asyncio.create_task(
+ _submit_anthropic_batch(file_path, request_headers, model)  # type: ignore
+ )
+ )

- batch_ids = await asyncio.gather(*batch_tasks)
+ batch_ids = await asyncio.gather(*batch_tasks)

- print(f"Submitted {len(batches)} batch jobs.")
+ print(f"Submitted {len(batch_tasks)} batch jobs.")
  return batch_ids

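Taken together, the batches.py changes drop the pandas dependency: requests are now streamed into a temporary JSONL file per chunk and uploaded, and each chunk becomes its own batch job submitted concurrently via asyncio.create_task for both OpenAI and Anthropic. A minimal usage sketch, assuming only the module-level signatures visible in this diff (the model name and prompts are illustrative, and SamplingParams() is assumed to be default-constructible):

    import asyncio
    from lm_deluge.batches import submit_batches_oa
    from lm_deluge.config import SamplingParams

    async def main():
        # 120,000 prompts with batch_size=50_000 should yield three JSONL temp
        # files and three OpenAI batch jobs submitted concurrently.
        prompts = [f"Summarize record {i}" for i in range(120_000)]
        batch_ids = await submit_batches_oa(
            "gpt-4o-mini",  # illustrative; must be a key in lm_deluge's model registry
            SamplingParams(),
            prompts,
            batch_size=50_000,
        )
        print(batch_ids)

    asyncio.run(main())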
lm_deluge/client.py CHANGED
@@ -22,11 +22,8 @@ from .models import APIModel, registry
  from .request_context import RequestContext
  from .tracker import StatusTracker

- # from .cache import LevelDBCache, SqliteCache
-

  # TODO: get completions as they finish, not all at once at the end.
- # relatedly, would be nice to cache them as they finish too.
  # TODO: add optional max_input_tokens to client so we can reject long prompts to prevent abuse
  class LLMClient(BaseModel):
  """
@@ -60,6 +57,7 @@ class LLMClient(BaseModel):
  reasoning_effort: Literal["low", "medium", "high", None] = None
  logprobs: bool = False
  top_logprobs: int | None = None
+ force_local_mcp: bool = False

  # NEW! Builder methods
  def with_model(self, model: str):
@@ -113,6 +111,7 @@ class LLMClient(BaseModel):
  if isinstance(self.model_names, str):
  self.model_names = [self.model_names]
  if any(m not in registry for m in self.model_names):
+ print("got model names:", self.model_names)
  raise ValueError("all model_names must be in registry")
  if isinstance(self.sampling_params, SamplingParams):
  self.sampling_params = [self.sampling_params for _ in self.model_names]
@@ -368,6 +367,7 @@ class LLMClient(BaseModel):
  cache=cache,
  use_responses_api=use_responses_api,
  extra_headers=self.extra_headers,
+ force_local_mcp=self.force_local_mcp,
  )
  except StopIteration:
  prompts_not_finished = False
@@ -389,8 +389,6 @@ class LLMClient(BaseModel):
  results[ctx.task_id] = response
  except Exception as e:
  # Create an error response for validation errors and other exceptions
- from .api_requests.response import APIResponse
-
  error_response = APIResponse(
  id=ctx.task_id,
  model_internal=ctx.model_name,
@@ -421,7 +419,8 @@ class LLMClient(BaseModel):

  # Sleep - original logic
  await asyncio.sleep(seconds_to_sleep_each_loop + tracker.seconds_to_pause)
- tracker.log_final_status()
+
+ tracker.log_final_status()

  if return_completions_only:
  return [r.completion if r is not None else None for r in results]
@@ -468,7 +467,7 @@ class LLMClient(BaseModel):
  self,
  conversation: str | Conversation,
  *,
- tools: list[Tool | dict] | None = None,
+ tools: list[Tool | dict | MCPServer] | None = None,
  max_rounds: int = 5,
  show_progress: bool = False,
  ) -> tuple[Conversation, APIResponse]:
@@ -482,6 +481,16 @@ class LLMClient(BaseModel):
  if isinstance(conversation, str):
  conversation = Conversation.user(conversation)

+ # Expand MCPServer objects to their constituent tools for tool execution
+ expanded_tools: list[Tool] = []
+ if tools:
+ for tool in tools:
+ if isinstance(tool, Tool):
+ expanded_tools.append(tool)
+ elif isinstance(tool, MCPServer):
+ mcp_tools = await tool.to_tools()
+ expanded_tools.extend(mcp_tools)
+
  last_response: APIResponse | None = None

  for _ in range(max_rounds):
@@ -504,9 +513,9 @@ class LLMClient(BaseModel):

  for call in tool_calls:
  tool_obj = None
- if tools:
- for t in tools:
- if isinstance(t, Tool) and t.name == call.name:
+ if expanded_tools:
+ for t in expanded_tools:
+ if t.name == call.name:
  tool_obj = t
  break

@@ -553,6 +562,7 @@ class LLMClient(BaseModel):
  *,
  tools: list[Tool] | None = None,
  cache: CachePattern | None = None,
+ batch_size: int = 50_000,
  ):
  """Submit a batch job asynchronously, automatically detecting the provider based on model.

@@ -572,13 +582,16 @@ class LLMClient(BaseModel):
  api_spec = registry[model].api_spec

  if api_spec == "openai":
- return await submit_batches_oa(model, self.sampling_params[0], prompts)
+ return await submit_batches_oa(
+ model, self.sampling_params[0], prompts, batch_size=batch_size
+ )
  elif api_spec == "anthropic":
  return await submit_batches_anthropic(
  model,
  self.sampling_params[0],
  prompts,
  cache=cache,
+ batch_size=batch_size,
  )
  else:
  raise ValueError(f"Batch processing not supported for API spec: {api_spec}")
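On the client side, the release adds a force_local_mcp flag, forwards a batch_size argument to the batch submitters, and lets the agent-loop method accept MCPServer entries in its tools list (they are expanded into plain Tool objects via to_tools() before dispatch). A construction sketch, assuming only the field names visible in this diff and defaults for everything not shown:

    from lm_deluge.client import LLMClient

    client = LLMClient(
        model_names=["gpt-4o-mini"],  # illustrative; every name must exist in the registry
        force_local_mcp=True,         # new field in this release
    )
    # Per the widened annotation above, the tools list passed to the agent-loop
    # method may now mix Tool instances, raw dicts, and MCPServer objects.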
lm_deluge/image.py CHANGED
@@ -10,7 +10,7 @@ from typing import Literal
  import requests
  from PIL import Image as PILImage  # type: ignore

- MediaType = Literal["image/jpeg", "image/png", "image/gif", "image/webp"]
+ MediaType = Literal["image/jpeg", "image/png", "image/gif", "image/webp"] | str


  @dataclass(slots=True)
@@ -23,6 +23,9 @@ class Image:
  _fingerprint_cache: str | None = field(init=False, default=None)
  _size_cache: tuple[int, int] | None = field(init=False, default=None)

+ def __repr__(self):
+ return f"Image(data=[{type(self.data)}], media_type={self.media_type}, detail={self.detail})"
+
  @classmethod
  def from_pdf(
  cls,
@@ -69,10 +72,11 @@ class Image:
  elif isinstance(self.data, Path) and self.data.exists():
  return Path(self.data).read_bytes()
  elif isinstance(self.data, str) and self.data.startswith("data:"):
+ # print("base64 path selected")
  header, encoded = self.data.split(",", 1)
  return base64.b64decode(encoded)
  else:
- raise ValueError("unreadable image format")
+ raise ValueError(f"unreadable image format. type: {type(self.data)}")

  def _mime(self) -> str:
  if self.media_type:
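The image.py changes are small but useful for debugging: MediaType now also admits arbitrary strings, a __repr__ avoids dumping raw image bytes, and the unreadable-format error reports the offending type. A sketch under assumptions (that Image accepts these fields as keyword arguments and that "auto" is a valid detail value; neither is confirmed by this diff):

    from lm_deluge.image import Image

    img = Image(
        data="data:image/avif;base64,AAAA",  # illustrative data URL
        media_type="image/avif",             # non-enumerated types are now allowed by MediaType
        detail="auto",                       # assumed value; accepted detail strings are not shown here
    )
    print(repr(img))  # Image(data=[<class 'str'>], media_type=image/avif, detail=auto)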
lm_deluge/models.py CHANGED
@@ -42,7 +42,7 @@ BUILTIN_MODELS = {
  "reasoning_model": False,
  },
  "llama-3.3-70b": {
- "id": "llama-3.3-70B",
+ "id": "llama-3.3-70b",
  "name": "Llama-3.3-70B-Instruct",
  "api_base": "https://api.llama.com/compat/v1",
  "api_key_env_var": "META_API_KEY",
@@ -56,7 +56,7 @@ BUILTIN_MODELS = {
  "reasoning_model": False,
  },
  "llama-3.3-8b": {
- "id": "llama-3.3-8B",
+ "id": "llama-3.3-8b",
  "name": "Llama-3.3-8B-Instruct",
  "api_base": "https://api.llama.com/compat/v1",
  "api_key_env_var": "META_API_KEY",
@@ -670,62 +670,62 @@ BUILTIN_MODELS = {
  # "requests_per_minute": 120,
  # "tokens_per_minute": None,
  # },
- "gemini-2.5-pro-vertex": {
- "id": "gemini-2.5-pro",
- "name": "gemini-2.5-pro-preview-05-06",
- "api_base": "",
- "api_key_env_var": "GOOGLE_APPLICATION_CREDENTIALS",
- "supports_json": True,
- "supports_logprobs": False,
- "api_spec": "vertex_gemini",
- "input_cost": 1.25,
- "output_cost": 10.0,
- "requests_per_minute": 20,
- "tokens_per_minute": 100_000,
- "reasoning_model": True,
- },
- "gemini-2.5-flash-vertex": {
- "id": "gemini-2.5-flash",
- "name": "gemini-2.5-flash-preview-05-20",
- "api_base": "",
- "api_key_env_var": "GOOGLE_APPLICATION_CREDENTIALS",
- "supports_json": True,
- "supports_logprobs": False,
- "api_spec": "vertex_gemini",
- "input_cost": 0.15,
- "output_cost": 0.6,
- "requests_per_minute": 20,
- "tokens_per_minute": 100_000,
- "reasoning_model": True,
- },
- "gemini-2.0-flash-vertex": {
- "id": "gemini-2.0-flash",
- "name": "gemini-2.0-flash",
- "api_base": "",
- "api_key_env_var": "GOOGLE_APPLICATION_CREDENTIALS",
- "supports_json": True,
- "supports_logprobs": False,
- "api_spec": "vertex_gemini",
- "input_cost": 0.10,
- "output_cost": 0.40,
- "requests_per_minute": 20,
- "tokens_per_minute": 100_000,
- "reasoning_model": False,
- },
- "gemini-2.0-flash-lite-vertex": {
- "id": "gemini-2.0-flash-lite",
- "name": "gemini-2.0-flash-lite",
- "api_base": "",
- "api_key_env_var": "GOOGLE_APPLICATION_CREDENTIALS",
- "supports_json": True,
- "supports_logprobs": False,
- "api_spec": "vertex_gemini",
- "input_cost": 0.075,
- "output_cost": 0.30,
- "requests_per_minute": 20,
- "tokens_per_minute": 100_000,
- "reasoning_model": False,
- },
+ # "gemini-2.5-pro-vertex": {
+ # "id": "gemini-2.5-pro",
+ # "name": "gemini-2.5-pro-preview-05-06",
+ # "api_base": "",
+ # "api_key_env_var": "GOOGLE_APPLICATION_CREDENTIALS",
+ # "supports_json": True,
+ # "supports_logprobs": False,
+ # "api_spec": "vertex_gemini",
+ # "input_cost": 1.25,
+ # "output_cost": 10.0,
+ # "requests_per_minute": 20,
+ # "tokens_per_minute": 100_000,
+ # "reasoning_model": True,
+ # },
+ # "gemini-2.5-flash-vertex": {
+ # "id": "gemini-2.5-flash",
+ # "name": "gemini-2.5-flash-preview-05-20",
+ # "api_base": "",
+ # "api_key_env_var": "GOOGLE_APPLICATION_CREDENTIALS",
+ # "supports_json": True,
+ # "supports_logprobs": False,
+ # "api_spec": "vertex_gemini",
+ # "input_cost": 0.15,
+ # "output_cost": 0.6,
+ # "requests_per_minute": 20,
+ # "tokens_per_minute": 100_000,
+ # "reasoning_model": True,
+ # },
+ # "gemini-2.0-flash-vertex": {
+ # "id": "gemini-2.0-flash",
+ # "name": "gemini-2.0-flash",
+ # "api_base": "",
+ # "api_key_env_var": "GOOGLE_APPLICATION_CREDENTIALS",
+ # "supports_json": True,
+ # "supports_logprobs": False,
+ # "api_spec": "vertex_gemini",
+ # "input_cost": 0.10,
+ # "output_cost": 0.40,
+ # "requests_per_minute": 20,
+ # "tokens_per_minute": 100_000,
+ # "reasoning_model": False,
+ # },
+ # "gemini-2.0-flash-lite-vertex": {
+ # "id": "gemini-2.0-flash-lite",
+ # "name": "gemini-2.0-flash-lite",
+ # "api_base": "",
+ # "api_key_env_var": "GOOGLE_APPLICATION_CREDENTIALS",
+ # "supports_json": True,
+ # "supports_logprobs": False,
+ # "api_spec": "vertex_gemini",
+ # "input_cost": 0.075,
+ # "output_cost": 0.30,
+ # "requests_per_minute": 20,
+ # "tokens_per_minute": 100_000,
+ # "reasoning_model": False,
+ # },
  # ███████████ █████ █████
  # ░░███░░░░░███ ░░███ ░░███
  # ░███ ░███ ██████ ███████ ████████ ██████ ██████ ░███ █████
@@ -1138,7 +1138,7 @@ BUILTIN_MODELS = {
  "output_cost": 0.7,
  },
  "mixtral-8x22b": {
- "id": "mistral-8x22b",
+ "id": "mixtral-8x22b",
  "name": "open-mixtral-8x22b",
  "api_base": "https://api.mistral.ai/v1",
  "api_key_env_var": "MISTRAL_API_KEY",
@@ -1243,3 +1243,5 @@ def register_model(**kwargs) -> APIModel:
  # Populate registry with builtin models
  for cfg in BUILTIN_MODELS.values():
  register_model(**cfg)
+
+ # print("Valid models:", registry.keys())
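The models.py changes fix two id typos (llama-3.3-70B → llama-3.3-70b, mistral-8x22b → mixtral-8x22b) and comment out the four Vertex Gemini entries, so those names no longer resolve from the built-in registry. A user who still needs one of them could re-register it with register_model, which, as the population loop above shows, accepts the same keyword arguments as a BUILTIN_MODELS entry. Sketch only; the values are copied from the commented-out block, and how the registry key is derived from them is not shown in this diff:

    from lm_deluge.models import register_model

    register_model(
        id="gemini-2.0-flash",
        name="gemini-2.0-flash",
        api_base="",  # left empty in the original entry
        api_key_env_var="GOOGLE_APPLICATION_CREDENTIALS",
        supports_json=True,
        supports_logprobs=False,
        api_spec="vertex_gemini",
        input_cost=0.10,
        output_cost=0.40,
        requests_per_minute=20,
        tokens_per_minute=100_000,
        reasoning_model=False,
    )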