lm-deluge 0.0.56__py3-none-any.whl → 0.0.69__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- lm_deluge/__init__.py +12 -1
- lm_deluge/api_requests/anthropic.py +12 -1
- lm_deluge/api_requests/base.py +87 -5
- lm_deluge/api_requests/bedrock.py +3 -4
- lm_deluge/api_requests/chat_reasoning.py +4 -0
- lm_deluge/api_requests/gemini.py +7 -6
- lm_deluge/api_requests/mistral.py +8 -9
- lm_deluge/api_requests/openai.py +179 -124
- lm_deluge/batches.py +25 -9
- lm_deluge/client.py +280 -67
- lm_deluge/config.py +1 -1
- lm_deluge/file.py +382 -13
- lm_deluge/mock_openai.py +482 -0
- lm_deluge/models/__init__.py +12 -8
- lm_deluge/models/anthropic.py +12 -20
- lm_deluge/models/bedrock.py +0 -14
- lm_deluge/models/cohere.py +0 -16
- lm_deluge/models/google.py +0 -20
- lm_deluge/models/grok.py +48 -4
- lm_deluge/models/groq.py +2 -2
- lm_deluge/models/kimi.py +34 -0
- lm_deluge/models/meta.py +0 -8
- lm_deluge/models/minimax.py +10 -0
- lm_deluge/models/openai.py +28 -34
- lm_deluge/models/openrouter.py +64 -1
- lm_deluge/models/together.py +0 -16
- lm_deluge/prompt.py +138 -29
- lm_deluge/request_context.py +9 -11
- lm_deluge/tool.py +395 -19
- lm_deluge/tracker.py +11 -5
- lm_deluge/warnings.py +46 -0
- {lm_deluge-0.0.56.dist-info → lm_deluge-0.0.69.dist-info}/METADATA +3 -1
- {lm_deluge-0.0.56.dist-info → lm_deluge-0.0.69.dist-info}/RECORD +36 -33
- lm_deluge/agent.py +0 -0
- lm_deluge/gemini_limits.py +0 -65
- {lm_deluge-0.0.56.dist-info → lm_deluge-0.0.69.dist-info}/WHEEL +0 -0
- {lm_deluge-0.0.56.dist-info → lm_deluge-0.0.69.dist-info}/licenses/LICENSE +0 -0
- {lm_deluge-0.0.56.dist-info → lm_deluge-0.0.69.dist-info}/top_level.txt +0 -0
lm_deluge/api_requests/openai.py
CHANGED
@@ -1,7 +1,6 @@
 import json
 import os
 import traceback as tb
-import warnings
 from types import SimpleNamespace

 import aiohttp
@@ -9,6 +8,7 @@ from aiohttp import ClientResponse

 from lm_deluge.request_context import RequestContext
 from lm_deluge.tool import MCPServer, Tool
+from lm_deluge.warnings import maybe_warn

 from ..config import SamplingParams
 from ..models import APIModel
@@ -30,6 +30,26 @@ async def _build_oa_chat_request(
         "temperature": sampling_params.temperature,
         "top_p": sampling_params.top_p,
     }
+    if context.service_tier:
+        assert context.service_tier in [
+            "auto",
+            "default",
+            "flex",
+            "priority",
+        ], f"Invalid service tier: {context.service_tier}"
+        # flex is only supported for o3, o4-mini, gpt-5 models
+        if context.service_tier == "flex":
+            model_supports_flex = any(x in model.id for x in ["o3", "o4-mini", "gpt-5"])
+            if not model_supports_flex:
+                print(
+                    f"WARNING: service_tier='flex' only supported for o3, o4-mini, gpt-5. "
+                    f"Using 'auto' instead for model {model.id}."
+                )
+                request_json["service_tier"] = "auto"
+            else:
+                request_json["service_tier"] = context.service_tier
+        else:
+            request_json["service_tier"] = context.service_tier
     # set max_tokens or max_completion_tokens dep. on provider
     if "cohere" in model.api_base:
         request_json["max_tokens"] = sampling_params.max_new_tokens
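The new `service_tier` handling validates the requested tier and silently falls back to `"auto"` whenever `"flex"` is requested for a model outside the o3 / o4-mini / gpt-5 families. A minimal standalone distillation of that fallback logic follows; `resolve_service_tier` is an illustrative helper name, not part of the lm_deluge API.

```python
# Illustrative sketch of the service_tier fallback added above.
# resolve_service_tier is hypothetical; the package inlines this logic.
FLEX_CAPABLE_SUBSTRINGS = ("o3", "o4-mini", "gpt-5")


def resolve_service_tier(model_id: str, service_tier: str) -> str:
    assert service_tier in ("auto", "default", "flex", "priority"), (
        f"Invalid service tier: {service_tier}"
    )
    # flex is only accepted for o3, o4-mini, and gpt-5 family models
    if service_tier == "flex" and not any(s in model_id for s in FLEX_CAPABLE_SUBSTRINGS):
        return "auto"
    return service_tier


print(resolve_service_tier("gpt-4.1-mini", "flex"))  # -> "auto"
print(resolve_service_tier("o4-mini", "flex"))       # -> "flex"
```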
@@ -55,9 +75,8 @@ async def _build_oa_chat_request(
         request_json["reasoning_effort"] = effort
     else:
         if sampling_params.reasoning_effort:
-            …
-            )
+            maybe_warn("WARN_REASONING_UNSUPPORTED", model_name=context.model_name)
+
     if sampling_params.logprobs:
         request_json["logprobs"] = True
         if sampling_params.top_logprobs is not None:
@@ -85,8 +104,10 @@ class OpenAIRequest(APIRequestBase):

         # Warn if cache is specified for non-Anthropic model
         if self.context.cache is not None:
-            …
+            maybe_warn(
+                "WARN_CACHING_UNSUPPORTED",
+                model_name=self.context.model_name,
+                cache_param=self.context.cache,
             )
         self.model = APIModel.from_registry(self.context.model_name)

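These hunks replace the previous ad-hoc `warnings.warn` calls with `maybe_warn(...)` keyed by a warning code, backed by the new `lm_deluge/warnings.py` module (+46 lines, not included in this excerpt). A plausible shape for such a warn-once helper is sketched below purely as an assumption about that module; the real implementation may differ.

```python
# Hypothetical sketch of a warn-once helper like lm_deluge.warnings.maybe_warn;
# the actual module body is not shown in this diff excerpt.
import warnings

_MESSAGES = {
    "WARN_REASONING_UNSUPPORTED": "Ignoring reasoning_effort for non-reasoning model: {model_name}",
    "WARN_CACHING_UNSUPPORTED": "cache={cache_param!r} has no effect for {model_name}; prompt caching is Anthropic-only.",
}
_seen: set[str] = set()


def maybe_warn(code: str, **kwargs) -> None:
    """Emit each warning code at most once per process."""
    if code in _seen:
        return
    _seen.add(code)
    warnings.warn(_MESSAGES.get(code, code).format(**kwargs), stacklevel=2)
```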
@@ -163,7 +184,8 @@ class OpenAIRequest(APIRequestBase):

                 content = Message("assistant", parts)

-                usage …
+                if "usage" in data and data["usage"] is not None:
+                    usage = Usage.from_openai_usage(data["usage"])
                 if (
                     self.context.sampling_params.logprobs
                     and "logprobs" in data["choices"][0]
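The usage parse is now guarded so that a missing or null `usage` field no longer raises while handling an otherwise valid response. The pattern is illustrated below with the standard Chat Completions usage field names; `SimpleUsage` is a stand-in for illustration, not the package's `Usage` class.

```python
# Illustrative: why the guard matters. SimpleUsage is a stand-in, not lm_deluge.Usage.
from dataclasses import dataclass


@dataclass
class SimpleUsage:
    input_tokens: int
    output_tokens: int


def parse_usage(data: dict) -> SimpleUsage | None:
    if "usage" in data and data["usage"] is not None:
        u = data["usage"]
        return SimpleUsage(u.get("prompt_tokens", 0), u.get("completion_tokens", 0))
    return None  # some responses omit usage; returning None beats a TypeError


print(parse_usage({"usage": {"prompt_tokens": 12, "completion_tokens": 3}}))
print(parse_usage({"usage": None}))  # -> None
```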
@@ -213,9 +235,6 @@ class OpenAIRequest(APIRequestBase):
 async def _build_oa_responses_request(
     model: APIModel,
     context: RequestContext,
-    # prompt: Conversation,
-    # tools: list[Tool] | None,
-    # sampling_params: SamplingParams,
 ):
     prompt = context.prompt
     sampling_params = context.sampling_params
@@ -226,7 +245,28 @@ async def _build_oa_responses_request(
         "input": openai_responses_format["input"],
         "temperature": sampling_params.temperature,
         "top_p": sampling_params.top_p,
+        "background": context.background or False,
     }
+    if context.service_tier:
+        assert context.service_tier in [
+            "auto",
+            "default",
+            "flex",
+            "priority",
+        ], f"Invalid service tier: {context.service_tier}"
+        # flex is only supported for o3, o4-mini, gpt-5 models
+        if context.service_tier == "flex":
+            model_supports_flex = any(x in model.id for x in ["o3", "o4-mini", "gpt-5"])
+            if not model_supports_flex:
+                print(
+                    f"WARNING: service_tier='flex' only supported for o3, o4-mini, gpt-5. "
+                    f"Model {model.id} doesn't support flex. Using 'auto' instead."
+                )
+                request_json["service_tier"] = "auto"
+            else:
+                request_json["service_tier"] = context.service_tier
+        else:
+            request_json["service_tier"] = context.service_tier
     if sampling_params.max_new_tokens:
         request_json["max_output_tokens"] = sampling_params.max_new_tokens

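With this change the Responses API body always carries a `background` flag and, when the caller requests one, a validated `service_tier`. An illustrative partial shape of the resulting payload is shown below; the concrete `input` value and other fields come from `openai_responses_format` and `sampling_params`, so the values here are made up.

```python
# Illustrative partial shape of the Responses API body after this change.
request_json = {
    "input": [{"role": "user", "content": [{"type": "input_text", "text": "hi"}]}],
    "temperature": 1.0,
    "top_p": 1.0,
    "background": False,       # new: always present, defaults to False
    "max_output_tokens": 512,
}
request_json["service_tier"] = "flex"  # new: only set when the caller asked for a tier
```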
@@ -245,9 +285,7 @@ async def _build_oa_responses_request(
         }
     else:
         if sampling_params.reasoning_effort:
-            …
-                f"Ignoring reasoning_effort for non-reasoning model: {model.id}"
-            )
+            maybe_warn("WARN_REASONING_UNSUPPORTED", model_name=context.model_name)

     if sampling_params.json_mode and model.supports_json:
         request_json["text"] = {"format": {"type": "json_object"}}
@@ -284,8 +322,10 @@ class OpenAIResponsesRequest(APIRequestBase):
         super().__init__(context)
         # Warn if cache is specified for non-Anthropic model
         if self.context.cache is not None:
-            …
+            maybe_warn(
+                "WARN_CACHING_UNSUPPORTED",
+                model_name=self.context.model_name,
+                cache_param=self.context.cache,
             )
         self.model = APIModel.from_registry(self.context.model_name)

@@ -310,7 +350,8 @@ class OpenAIResponsesRequest(APIRequestBase):
         assert self.context.status_tracker

         if status_code == 500:
-            …
+            res_text = await http_response.text()
+            print("Internal Server Error: ", res_text)

         if status_code >= 200 and status_code < 300:
             try:
@@ -322,126 +363,138 @@ class OpenAIResponsesRequest(APIRequestBase):
                 )
                 if not is_error:
                     assert data is not None, "data is None"
-                    try:
-                        # Parse Responses API format
-                        parts = []

-                    …
+                    # Check if response is incomplete
+                    if data.get("status") == "incomplete":
+                        is_error = True
+                        incomplete_reason = data.get("incomplete_details", {}).get(
+                            "reason", "unknown"
+                        )
+                        error_message = f"Response incomplete: {incomplete_reason}"
+
+                if not is_error:
+                    try:
+                        # Parse Responses API format
+                        parts = []
+
+                        # Get the output array from the response
+                        output = data.get("output", [])
+                        if not output:
+                            is_error = True
+                            error_message = f"No output in response. Status: {data.get('status')}, error: {data.get('error')}, incomplete details: {data.get('incomplete_details')}"
+                        else:
+                            # Process each output item
+                            for item in output:
+                                if item.get("type") == "message":
+                                    message_content = item.get("content", [])
+                                    for content_item in message_content:
+                                        if content_item.get("type") == "output_text":
+                                            parts.append(Text(content_item["text"]))
+                                        elif content_item.get("type") == "refusal":
+                                            parts.append(Text(content_item["refusal"]))
+                                elif item.get("type") == "reasoning":
+                                    summary = item["summary"]
+                                    if not summary:
+                                        continue
+                                    if isinstance(summary, list) and len(summary) > 0:
+                                        summary = summary[0]
+                                    assert isinstance(
+                                        summary, dict
+                                    ), "summary isn't a dict"
+                                    parts.append(Thinking(summary["text"]))
+                                elif item.get("type") == "function_call":
+                                    parts.append(
+                                        ToolCall(
+                                            id=item["call_id"],
+                                            name=item["name"],
+                                            arguments=json.loads(item["arguments"]),
+                                        )
                                     )
-                                    )
-                                    …
+                                elif item.get("type") == "mcp_call":
+                                    parts.append(
+                                        ToolCall(
+                                            id=item["id"],
+                                            name=item["name"],
+                                            arguments=json.loads(item["arguments"]),
+                                            built_in=True,
+                                            built_in_type="mcp_call",
+                                            extra_body={
+                                                "server_label": item["server_label"],
+                                                "error": item.get("error"),
+                                                "output": item.get("output"),
+                                            },
+                                        )
                                     )
-                                    )

-                                    …
+                                elif item.get("type") == "computer_call":
+                                    parts.append(
+                                        ToolCall(
+                                            id=item["call_id"],
+                                            name="computer_call",
+                                            arguments=item.get("action"),
+                                            built_in=True,
+                                            built_in_type="computer_call",
+                                        )
                                     )
-                                    )

-                                    …
+                                elif item.get("type") == "web_search_call":
+                                    parts.append(
+                                        ToolCall(
+                                            id=item["id"],
+                                            name="web_search_call",
+                                            arguments={},
+                                            built_in=True,
+                                            built_in_type="web_search_call",
+                                            extra_body={"status": item["status"]},
+                                        )
                                     )
-                                    )

-                                    …
+                                elif item.get("type") == "file_search_call":
+                                    parts.append(
+                                        ToolCall(
+                                            id=item["id"],
+                                            name="file_search_call",
+                                            arguments={"queries": item["queries"]},
+                                            built_in=True,
+                                            built_in_type="file_search_call",
+                                            extra_body={
+                                                "status": item["status"],
+                                                "results": item["results"],
+                                            },
+                                        )
                                     )
-                                    )
-                                    …
+                                elif item.get("type") == "image_generation_call":
+                                    parts.append(
+                                        ToolCall(
+                                            id=item["id"],
+                                            name="image_generation_call",
+                                            arguments={},
+                                            built_in=True,
+                                            built_in_type="image_generation_call",
+                                            extra_body={
+                                                "status": item["status"],
+                                                "result": item["result"],
+                                            },
+                                        )
                                     )
-                                    )

-                        …
+                        # Handle reasoning if present
+                        if "reasoning" in data and data["reasoning"].get("summary"):
+                            thinking = data["reasoning"]["summary"]
+                            parts.append(Thinking(thinking))

-                        …
+                        content = Message("assistant", parts)

-                        …
+                        # Extract usage information
+                        if "usage" in data and data["usage"] is not None:
+                            usage = Usage.from_openai_usage(data["usage"])

-                    …
+                    except Exception as e:
+                        is_error = True
+                        error_message = f"Error parsing {self.model.name} responses API response: {str(e)}"
+                        print("got data:", data)
+                        traceback = tb.format_exc()
+                        print(f"Error details:\n{traceback}")

             elif mimetype and "json" in mimetype.lower():
                 print("is_error True, json response")
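The new parser walks the Responses API `output` array and maps each item type (`message`, `reasoning`, `function_call`, and the built-in tool calls) onto message parts. A trimmed illustration of that mapping follows; the payload is invented, and plain strings and tuples stand in for the `Text`, `Thinking`, and `ToolCall` classes used above.

```python
import json

# Invented sample of a Responses API body, shaped like what the parser above expects.
data = {
    "status": "completed",
    "output": [
        {"type": "reasoning", "summary": [{"type": "summary_text", "text": "thinking…"}]},
        {"type": "message", "content": [{"type": "output_text", "text": "Hello!"}]},
        {
            "type": "function_call",
            "call_id": "call_1",
            "name": "get_weather",
            "arguments": '{"city": "Paris"}',
        },
    ],
}

parts = []
for item in data["output"]:
    if item["type"] == "message":
        parts += [c["text"] for c in item["content"] if c["type"] == "output_text"]
    elif item["type"] == "reasoning" and item["summary"]:
        parts.append(item["summary"][0]["text"])
    elif item["type"] == "function_call":
        parts.append((item["call_id"], item["name"], json.loads(item["arguments"])))

print(parts)  # ['thinking…', 'Hello!', ('call_1', 'get_weather', {'city': 'Paris'})]
```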
@@ -488,8 +541,10 @@ async def stream_chat(
     extra_headers: dict[str, str] | None = None,
 ):
     if cache is not None:
-        …
+        maybe_warn(
+            "WARN_CACHING_UNSUPPORTED",
+            model_name=model_name,
+            cache_param=cache,
         )

     model = APIModel.from_registry(model_name)
lm_deluge/batches.py
CHANGED
@@ -3,7 +3,7 @@ import json
 import os
 import tempfile
 import time
-from typing import Literal, Sequence
+from typing import Literal, Sequence, cast

 import aiohttp
 from rich.console import Console
@@ -16,7 +16,12 @@ from lm_deluge.api_requests.anthropic import _build_anthropic_request
 from lm_deluge.api_requests.openai import _build_oa_chat_request
 from lm_deluge.config import SamplingParams
 from lm_deluge.models import APIModel, registry
-from lm_deluge.prompt import …
+from lm_deluge.prompt import (
+    CachePattern,
+    Conversation,
+    Prompt,
+    prompts_to_conversations,
+)
 from lm_deluge.request_context import RequestContext


@@ -166,14 +171,18 @@ async def _submit_anthropic_batch(file_path: str, headers: dict, model: str):
 async def create_batch_files_oa(
     model: str,
     sampling_params: SamplingParams,
-    prompts: …
+    prompts: Prompt | Sequence[Prompt],
     batch_size: int = 50_000,
     destination: str | None = None,  # if none provided, temp files
 ):
     MAX_BATCH_SIZE_BYTES = 200 * 1024 * 1024  # 200MB
     MAX_BATCH_SIZE_ITEMS = batch_size

-    …
+    if not isinstance(prompts, list):
+        prompts = cast(Sequence[Prompt], [prompts])
+
+    prompts = prompts_to_conversations(cast(Sequence[Prompt], prompts))
+    assert isinstance(prompts, Sequence)
     if any(p is None for p in prompts):
         raise ValueError("All prompts must be valid.")

@@ -251,14 +260,18 @@ async def create_batch_files_oa(
 async def submit_batches_oa(
     model: str,
     sampling_params: SamplingParams,
-    prompts: …
+    prompts: Prompt | Sequence[Prompt],
     batch_size: int = 50_000,
 ):
     """Write OpenAI batch requests to a file and submit."""
     MAX_BATCH_SIZE_BYTES = 200 * 1024 * 1024  # 200MB
     MAX_BATCH_SIZE_ITEMS = batch_size

-    …
+    if not isinstance(prompts, list):
+        prompts = prompts = cast(Sequence[Prompt], [prompts])
+
+    prompts = prompts_to_conversations(cast(Sequence[Prompt], prompts))
+    assert isinstance(prompts, Sequence)
     if any(p is None for p in prompts):
         raise ValueError("All prompts must be valid.")

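All of the batch entry points now accept either a single `Prompt` or a sequence of prompts and normalize through `prompts_to_conversations` before any requests are built. A simplified sketch of that normalization pattern is shown below; the `Conversation` class and `normalize_prompts` helper here are stand-ins for illustration, not the lm_deluge implementations, which may differ in detail.

```python
# Simplified stand-in for the normalization the batch helpers now perform;
# lm_deluge.prompt.prompts_to_conversations is the real function.
from typing import Sequence, Union


class Conversation:
    def __init__(self, text: str):
        self.text = text

    @classmethod
    def user(cls, text: str) -> "Conversation":
        return cls(text)


Prompt = Union[str, Conversation]  # lm_deluge's Prompt alias also covers richer types


def normalize_prompts(prompts: Union[Prompt, Sequence[Prompt]]) -> list[Conversation]:
    if not isinstance(prompts, (list, tuple)):  # single prompt -> one-element batch
        prompts = [prompts]
    return [p if isinstance(p, Conversation) else Conversation.user(p) for p in prompts]


batch = normalize_prompts("translate 'bonjour'")
print(len(batch), isinstance(batch[0], Conversation))  # 1 True
```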
@@ -342,7 +355,7 @@ async def submit_batches_oa(
 async def submit_batches_anthropic(
     model: str,
     sampling_params: SamplingParams,
-    prompts: …
+    prompts: Prompt | Sequence[Prompt],
     *,
     cache: CachePattern | None = None,
     batch_size=100_000,
@@ -362,13 +375,16 @@ async def submit_batches_anthropic(
     MAX_BATCH_SIZE_ITEMS = batch_size

     # Convert prompts to Conversations
-    …
+    if not isinstance(prompts, list):
+        prompts = prompts = cast(Sequence[Prompt], [prompts])
+
+    prompts = prompts_to_conversations(cast(Sequence[Prompt], prompts))

     request_headers = None
     batch_tasks = []
     current_batch = []
     current_batch_size = 0
-    …
+    assert isinstance(prompts, Sequence)
     for idx, prompt in enumerate(prompts):
         assert isinstance(prompt, Conversation)
         context = RequestContext(