batch-bridge 0.0.1rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- batch_bridge/__init__.py +10 -0
- batch_bridge/_base.py +310 -0
- batch_bridge/_openai.py +187 -0
- batch_bridge/errors.py +6 -0
- batch_bridge/types.py +28 -0
- batch_bridge-0.0.1rc0.dist-info/METADATA +122 -0
- batch_bridge-0.0.1rc0.dist-info/RECORD +9 -0
- batch_bridge-0.0.1rc0.dist-info/WHEEL +4 -0
- batch_bridge-0.0.1rc0.dist-info/licenses/LICENSE +21 -0
batch_bridge/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
1
|
+
"""BatchBridge: A library for batch processing with LangGraph.
|
2
|
+
|
3
|
+
This library provides functionality for batching items and processing them in bulk,
|
4
|
+
integrating with LangGraph's interrupt semantics for efficient batch handling.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from batch_bridge._base import Bridge, wait
|
8
|
+
from batch_bridge._openai import patch_openai
|
9
|
+
|
10
|
+
# Names re-exported from the private submodules as the package's public API.
__all__ = ["Bridge", "wait", "patch_openai"]
|
batch_bridge/_base.py
ADDED
@@ -0,0 +1,310 @@
|
|
1
|
+
import asyncio
|
2
|
+
import datetime
|
3
|
+
import os
|
4
|
+
import typing
|
5
|
+
import uuid
|
6
|
+
from collections import defaultdict
|
7
|
+
|
8
|
+
from langgraph.config import get_config
|
9
|
+
from langgraph.constants import CONF, CONFIG_KEY_TASK_ID
|
10
|
+
from langgraph.errors import GraphInterrupt
|
11
|
+
from langgraph.graph import StateGraph
|
12
|
+
from langgraph.graph.state import CompiledStateGraph
|
13
|
+
from langgraph.types import Command, Send, interrupt
|
14
|
+
from langgraph_api.graph import register_graph_sync
|
15
|
+
from langgraph_sdk import get_client
|
16
|
+
from typing_extensions import Annotated, TypedDict
|
17
|
+
|
18
|
+
from batch_bridge.errors import BatchIngestException
|
19
|
+
from batch_bridge.types import InFlightBatch, QueueItem, RemoveBatch, T, U, V
|
20
|
+
|
21
|
+
# Thread on which bridge "submit" runs are enqueued when callers don't choose one.
DEFAULT_THREAD_ID = "b0be531d-55f6-4b87-a309-23d2ed28d9da"

# Polling is scheduled with crons everywhere except the local dev server
# (LANGSMITH_LANGGRAPH_API_VARIANT == "local_dev"), where delayed runs are used.
USE_CRONS = os.getenv("LANGSMITH_LANGGRAPH_API_VARIANT") != "local_dev"

# Module-wide LangGraph SDK client, created once at import time and shared by
# every bridge node that creates runs or crons.
_langgraph_client = get_client()
|
29
|
+
|
30
|
+
|
31
|
+
class CompiledBridge(CompiledStateGraph, typing.Generic[T, U]):
    """A compiled bridge graph that also offers a convenience ``wait`` method.

    ``Bridge`` builds instances from a compiled ``StateGraph``'s attribute
    dict. The extra keyword arguments record which deployed graph items are
    routed to and how raw results are post-processed.
    """

    def __init__(
        self,
        *args: typing.Any,
        graph_id: str = "BatchBridge",
        __output_coercer__: typing.Callable[[U], typing.Any] | None = None,
        **kwargs: typing.Any,
    ) -> None:
        # An explicit, non-None "__graph_id__" keyword wins over ``graph_id``.
        # It is removed before the remaining kwargs reach the base class.
        override = kwargs.pop("__graph_id__", None)
        super().__init__(*args, **kwargs)
        self.__graph_id__ = graph_id if override is None else override
        self.__output_coercer__ = __output_coercer__

    async def wait(self, item: T, *, thread_id: str = DEFAULT_THREAD_ID) -> U:
        """Queue ``item`` on this bridge and return its (optionally coerced) result."""
        raw = await wait(item, bridge_id=self.__graph_id__, thread_id=thread_id)
        coerce = self.__output_coercer__
        return raw if coerce is None else coerce(raw)
|
52
|
+
|
53
|
+
|
54
|
+
class Bridge:
    """Factory for batch-processing bridge graphs.

    Calling ``Bridge(submit, poll, ...)`` does not produce a ``Bridge``
    instance. ``__new__`` builds a graph, registers it with
    ``register_graph_sync`` and returns a :class:`CompiledBridge`. Because
    that object is not a ``Bridge``, Python never calls ``__init__``; it only
    documents the constructor signature.

    The compiled graph reacts to two events:

    * ``"submit"``: an item was appended to ``tasks``. ``should_submit``
      decides whether to flush the queue into a new batch.
    * ``"poll"``: every in-flight batch is polled. When one has finished, each
      originating thread is resumed with its result.
    """

    def __new__(
        cls,
        submit: typing.Callable[[list[T]], typing.Awaitable[U]],
        poll: typing.Callable[[U], typing.Awaitable[V]],
        *,
        should_submit: typing.Optional[
            typing.Callable[
                [list[T], typing.Optional[datetime.datetime]],
                typing.Union[bool, typing.Awaitable[bool]],
            ]
        ] = None,
        graph_id: str = "BatchBridge",
        __output_coercer__: typing.Callable[[U], V] | None = None,
    ) -> CompiledBridge:
        import inspect

        if not asyncio.iscoroutinefunction(submit):
            raise ValueError("submit must be a coroutine function")
        if not asyncio.iscoroutinefunction(poll):
            raise ValueError("poll must be a coroutine function")
        if should_submit is None:
            should_submit = _submit_after_minute

        class InputSchema(TypedDict):
            tasks: typing.Annotated[list[QueueItem[T]], _reduce_batch]
            event: typing.Optional[typing.Literal["poll", "submit"]]
            in_flight: Annotated[list[InFlightBatch[T]], _reduce_in_flight]

        class State(InputSchema):
            last_submit_time: datetime.datetime

        async def _schedule_poll(assistant_id: str, batch: InFlightBatch) -> None:
            """Arrange for ``batch`` to be polled by a later "poll" run.

            With crons, a recurring cron is created; ``poll_batch`` deletes it
            once the batch has finished. Without crons (local dev server), a
            single delayed stateless run is created, and ``poll_batch``
            schedules the next one while the batch is still pending.
            """
            poll_input = {"event": "poll", "in_flight": batch}
            if USE_CRONS:
                await _langgraph_client.crons.create(
                    assistant_id=assistant_id,
                    schedule="* * * * * *",
                    input=poll_input,
                    multitask_strategy="reject",
                )
            else:
                # thread_id is a required argument of runs.create; None
                # requests a stateless (thread-less) run.
                await _langgraph_client.runs.create(
                    thread_id=None,
                    assistant_id=assistant_id,
                    input=poll_input,
                    multitask_strategy="reject",
                    after_seconds=30,
                )

        async def route_entry(
            state: State,
        ) -> typing.Union[
            typing.Literal["check_should_submit"],
            typing.Sequence[typing.Union[Send, typing.Literal["check_should_submit"]]],
        ]:
            # The bridge is triggered when:
            # 1. a new task is enqueued ("submit"), or
            # 2. a scheduled poll fires ("poll").
            event = state.get("event")
            if event == "submit":
                return "check_should_submit"
            if event == "poll":
                return [
                    *(
                        Send("poll_batch", batch_val)
                        for batch_val in state.get("in_flight") or []
                    ),
                    "check_should_submit",
                ]
            raise ValueError("Invalid event")

        async def create_batch(state: State) -> dict:
            tasks = state.get("tasks", [])
            batch_payload = await submit([task["task"] for task in tasks])
            batch = InFlightBatch(
                batch_id=str(uuid.uuid4()),
                batch_payload=batch_payload,
                origins=[task["origin"] for task in tasks],
            )
            # Start the poller for the new batch.
            await _schedule_poll(get_config()[CONF]["assistant_id"], batch)
            return {
                "last_submit_time": datetime.datetime.now(datetime.timezone.utc),
                "tasks": "__clear__",
                # [] is a no-op for _reduce_in_flight.
                "in_flight": [],
            }

        async def check_should_submit(state: State):
            last_submit_time = state.get("last_submit_time") or datetime.datetime.now(
                datetime.timezone.utc
            )
            goto = "__end__"
            if state.get("tasks"):
                decision = should_submit(state["tasks"], last_submit_time)
                # Accept both plain and async predicates.
                if inspect.isawaitable(decision):
                    decision = await decision
                if decision:
                    goto = "create_batch"
            return Command(update={"last_submit_time": last_submit_time}, goto=goto)

        async def poll_batch(state: InFlightBatch) -> typing.Optional[dict]:
            try:
                poll_result = await poll(state["batch_payload"])
            except BatchIngestException as e:
                poll_result = e

            if poll_result is None:
                # The batch is still running. Crons fire again on their own,
                # but a delayed run fires once, so queue the next poll here.
                if not USE_CRONS:
                    await _schedule_poll(get_config()[CONF]["assistant_id"], state)
                return None

            # Re-trigger ALL the original tasks, grouped per (assistant, thread).
            to_resume: defaultdict = defaultdict(dict)
            if isinstance(poll_result, Exception):
                detail = (
                    poll_result.detail
                    if isinstance(poll_result, BatchIngestException)
                    else str(poll_result)
                )
                response = {"__batch_bridge__": {"kind": "exception", "detail": detail}}
                per_origin = ((origin, response) for origin in state["origins"])
            else:
                # poll() returns one result per submitted item, in submission order.
                per_origin = zip(state["origins"], poll_result)
            for origin, result in per_origin:
                to_resume[(origin["assistant_id"], origin["thread_id"])][
                    origin["task_id"]
                ] = result

            await asyncio.gather(
                *(
                    _langgraph_client.runs.create(
                        thread_id=thread_id,
                        assistant_id=assistant_id,
                        command={"resume": task_resumes},
                    )
                    for (assistant_id, thread_id), task_resumes in to_resume.items()
                ),
                # Best effort: a failed resume must not block cleanup of this batch.
                return_exceptions=True,
            )
            # Stop polling for this batch in particular.
            if USE_CRONS:
                cron_id = get_config()[CONF].get("cron_id")
                if cron_id:
                    await _langgraph_client.crons.delete(cron_id)
            return {"in_flight": RemoveBatch(state["batch_id"])}

        compiled: CompiledBridge[typing.Any, typing.Any] = CompiledBridge(
            **(
                StateGraph(State, input=InputSchema)
                .add_node(poll_batch)
                .add_node(create_batch)
                .add_node(check_should_submit)
                .add_conditional_edges(
                    "__start__", route_entry, ["poll_batch", "check_should_submit"]
                )
                .compile(name=graph_id)
                .__dict__
            ),
            graph_id=graph_id,
            __output_coercer__=__output_coercer__,
        )

        register_graph_sync(graph_id, compiled)

        return compiled

    def __init__(
        self,
        submit: typing.Callable[[list[T]], typing.Awaitable[U]],
        poll: typing.Callable[[U], typing.Awaitable[V]],
        *,
        should_submit: typing.Optional[
            typing.Callable[
                [list[T], typing.Optional[datetime.datetime]],
                typing.Union[bool, typing.Awaitable[bool]],
            ]
        ] = None,
        graph_id: str = "BatchBridge",
        __output_coercer__: typing.Callable[[U], V] | None = None,
    ) -> None:
        """Document the constructor signature for type checkers and IDEs.

        Never executed: ``__new__`` returns a ``CompiledBridge``, which is not a
        ``Bridge`` instance, so Python skips ``__init__``.

        Args:
            submit: Coroutine function that receives the queued items and
                returns a payload identifying the submitted batch.
            poll: Coroutine function that receives that payload. It returns
                None while the batch is pending, or a sequence with one result
                per submitted item (in submission order) once it is done. It
                may raise ``BatchIngestException`` if the batch failed.
            should_submit: Plain or async predicate that receives the queued
                items and the last submit time, and decides whether to flush
                them into a new batch. Defaults to submitting once more than
                60 seconds have passed since the last submission.
            graph_id: Graph ID under which the bridge graph is registered.
            __output_coercer__: Optional function that
                ``CompiledBridge.wait`` applies to each raw result.
        """
        ...
|
243
|
+
|
244
|
+
|
245
|
+
async def wait(
    item: T,
    *,
    bridge_id: str,  # Graph ID (~ assistant_id) of the graph in your deployment.
    thread_id: str = DEFAULT_THREAD_ID,
) -> typing.Any:
    """Queue ``item`` on a bridge graph and suspend until its result is delivered.

    Must be called from inside a running LangGraph node. It reads the current
    task ID, assistant ID and thread ID from the run config.

    While no resume value is available, ``interrupt`` raises
    ``GraphInterrupt``. This function then enqueues a "submit" run on the
    bridge's ``thread_id`` and re-raises. When the bridge resumes the origin
    thread (``poll_batch`` sends ``{"resume": {task_id: result}}``),
    ``interrupt`` returns the delivered value, which is returned here.

    Args:
        item: Payload handed to the bridge's ``submit`` function.
        bridge_id: Graph/assistant ID of the bridge graph.
        thread_id: Thread on which the bridge's "submit" run is enqueued.

    Returns:
        The raw result delivered by the bridge for this item.

    Raises:
        BatchIngestException: if the bridge reported that the batch failed.
        NotImplementedError: if the bridge sent an unknown out-of-band message.
    """
    configurable = get_config()[CONF]
    task = {
        "task": item,
        "origin": {
            "assistant_id": configurable["assistant_id"],
            "thread_id": configurable["thread_id"],
            "task_id": configurable[CONFIG_KEY_TASK_ID],
        },
    }
    try:
        result = interrupt(task)
    except GraphInterrupt:
        # Hand the item to the bridge, then let the interrupt suspend this node.
        await _langgraph_client.runs.create(
            thread_id=thread_id,
            assistant_id=bridge_id,
            if_not_exists="create",
            multitask_strategy="enqueue",
            input={
                "event": "submit",
                "tasks": task,
            },
        )
        raise

    # Errors are delivered as a tagged dict rather than a raw result.
    out_of_band = result.get("__batch_bridge__") if isinstance(result, dict) else None
    if out_of_band:
        if out_of_band["kind"] == "exception":
            raise BatchIngestException(out_of_band["detail"])
        raise NotImplementedError(f"Unknown out of band type: {out_of_band}")
    return result
|
282
|
+
|
283
|
+
|
284
|
+
def _reduce_batch(
|
285
|
+
existing: list[T] | None, new: T | typing.Literal["__clear__"]
|
286
|
+
) -> list[T]:
|
287
|
+
if new == "__clear__":
|
288
|
+
return []
|
289
|
+
existing = existing if existing is not None else []
|
290
|
+
return [*existing, new]
|
291
|
+
|
292
|
+
|
293
|
+
def _reduce_in_flight(
    existing: list[InFlightBatch[T]] | None, new: InFlightBatch[T] | RemoveBatch
) -> list[InFlightBatch[T]]:
    """Reducer for the ``in_flight`` channel.

    - ``RemoveBatch(batch_id)``: drop every batch with that ``batch_id``.
    - A single batch dict (has a ``"batch_id"`` key): append it.
    - Any other iterable of batches (e.g. ``[]`` as a no-op update): append
      all of its elements.

    A missing (None) ``existing`` value is treated as an empty list in every
    case, including removals.
    """
    current = existing if existing is not None else []
    if isinstance(new, RemoveBatch):
        return [batch for batch in current if batch["batch_id"] != new.batch_id]
    if isinstance(new, dict) and "batch_id" in new:
        return [*current, new]
    return [*current, *new]
|
302
|
+
|
303
|
+
|
304
|
+
async def _submit_after_minute(tasks: list, last_submit: datetime.datetime) -> bool:
|
305
|
+
if not tasks:
|
306
|
+
return False
|
307
|
+
# Main goal here is to avoid getting rate limited
|
308
|
+
return (
|
309
|
+
datetime.datetime.now(datetime.timezone.utc) - last_submit
|
310
|
+
).total_seconds() > 60
|
batch_bridge/_openai.py
ADDED
@@ -0,0 +1,187 @@
|
|
1
|
+
import functools
import io
import json
import typing

from openai import AsyncOpenAI
from typing_extensions import TypedDict

from batch_bridge._base import Bridge, CompiledBridge
from batch_bridge.errors import BatchIngestException
|
11
|
+
|
12
|
+
if typing.TYPE_CHECKING:
|
13
|
+
from openai import AsyncOpenAI
|
14
|
+
from openai.types import ChatModel
|
15
|
+
from openai.types.chat import (
|
16
|
+
ChatCompletion,
|
17
|
+
ChatCompletionAudioParam,
|
18
|
+
ChatCompletionMessageParam,
|
19
|
+
ChatCompletionModality,
|
20
|
+
ChatCompletionPredictionContentParam,
|
21
|
+
ChatCompletionReasoningEffort,
|
22
|
+
ChatCompletionToolChoiceOptionParam,
|
23
|
+
ChatCompletionToolParam,
|
24
|
+
)
|
25
|
+
from openai.types.chat.completion_create_params import (
|
26
|
+
ResponseFormat,
|
27
|
+
)
|
28
|
+
|
29
|
+
CompletionOutputs = ChatCompletion
|
30
|
+
else:
|
31
|
+
CompletionOutputs = typing.Any
|
32
|
+
|
33
|
+
|
34
|
+
class CompletionInputs(TypedDict, total=False):
|
35
|
+
messages: typing.Required[typing.Sequence["ChatCompletionMessageParam"]]
|
36
|
+
model: typing.Required[typing.Union[str, "ChatModel"]]
|
37
|
+
audio: typing.Optional["ChatCompletionAudioParam"]
|
38
|
+
frequency_penalty: typing.Optional[float]
|
39
|
+
logit_bias: typing.Optional[dict[str, int]]
|
40
|
+
logprobs: typing.Optional[bool]
|
41
|
+
max_completion_tokens: typing.Optional[int]
|
42
|
+
max_tokens: typing.Optional[int]
|
43
|
+
metadata: typing.Optional[dict[str, str]]
|
44
|
+
modalities: typing.Optional[list["ChatCompletionModality"]]
|
45
|
+
n: typing.Optional[int]
|
46
|
+
parallel_tool_calls: bool
|
47
|
+
prediction: typing.Optional["ChatCompletionPredictionContentParam"]
|
48
|
+
presence_penalty: typing.Optional[float]
|
49
|
+
reasoning_effort: "ChatCompletionReasoningEffort"
|
50
|
+
response_format: "ResponseFormat"
|
51
|
+
seed: typing.Optional[int]
|
52
|
+
service_tier: typing.Optional[typing.Literal["auto", "default"]]
|
53
|
+
stop: typing.Union[typing.Optional[str], list[str]]
|
54
|
+
store: typing.Optional[bool]
|
55
|
+
temperature: typing.Optional[float]
|
56
|
+
tool_choice: "ChatCompletionToolChoiceOptionParam"
|
57
|
+
tools: typing.Sequence["ChatCompletionToolParam"]
|
58
|
+
top_logprobs: typing.Optional[int]
|
59
|
+
top_p: typing.Optional[float]
|
60
|
+
user: str
|
61
|
+
|
62
|
+
|
63
|
+
class Task(TypedDict, total=False):
|
64
|
+
custom_id: str
|
65
|
+
method: str
|
66
|
+
url: str
|
67
|
+
body: CompletionInputs
|
68
|
+
|
69
|
+
|
70
|
+
def patch_openai(
    client: typing.Optional["AsyncOpenAI"] = None,
    *,
    graph_id: str = "BatchBridge",
) -> AsyncOpenAI:
    """Route ``client.chat.completions.create`` through a BatchBridge.

    The client's ``chat.completions.create`` is replaced in place by a
    coroutine that queues the request on the bridge registered as
    ``graph_id`` and returns the coerced batch result. Because it goes
    through ``wait``, it must be awaited from inside a LangGraph node.

    Args:
        client: Client to patch. A new ``AsyncOpenAI()`` is created if omitted.
        graph_id: Graph ID under which the backing bridge graph is registered.

    Returns:
        The same client, now patched.
    """
    from openai import AsyncOpenAI
    from openai._types import NotGiven

    if client is None:
        client = AsyncOpenAI()
    bridge = OpenAIBridge(client, graph_id=graph_id)

    # These kwargs configure the HTTP call itself rather than the request
    # body, so they must not be serialized into a batch line.
    transport_only = ("extra_headers", "extra_query", "timeout")

    @functools.wraps(client.chat.completions.create)
    async def async_openai_completions(
        *args: typing.Any,
        **kwargs: typing.Any,
    ) -> typing.Any:
        if args:
            # The wrapped SDK method is keyword-only; don't silently drop args.
            raise TypeError(
                "chat.completions.create() accepts keyword arguments only"
            )
        extra_body = kwargs.pop("extra_body", None)
        for option in transport_only:
            kwargs.pop(option, None)
        to_send = {k: v for k, v in kwargs.items() if not isinstance(v, NotGiven)}
        if isinstance(extra_body, dict):
            # Extra body fields belong in the request payload itself.
            to_send.update(extra_body)
        return await bridge.wait(CompletionInputs(**to_send))

    client.chat.completions.create = async_openai_completions

    return client
|
89
|
+
|
90
|
+
|
91
|
+
def OpenAIBridge(
    client: typing.Optional["AsyncOpenAI"] = None,
    *,
    graph_id: str = "BatchBridge",
) -> CompiledBridge[CompletionInputs, CompletionOutputs]:
    """Build a bridge graph backed by OpenAI's Batch API.

    Results are passed through ``OpenAIHandler.coerce_response`` when awaited
    via ``CompiledBridge.wait``.
    """
    openai_handler = OpenAIHandler(client)
    bridge_options = {
        "graph_id": graph_id,
        "__output_coercer__": openai_handler.coerce_response,
    }
    return Bridge(openai_handler.submit, openai_handler.poll, **bridge_options)
|
103
|
+
|
104
|
+
|
105
|
+
class OpenAIHandler:
    """Adapter exposing OpenAI's Batch API as ``submit``/``poll`` coroutines for a Bridge."""

    # Terminal batch statuses for which the waiting tasks are failed rather
    # than polled forever.
    _FAILED_STATUSES = ("failed", "expired", "cancelled")

    def __init__(self, client: typing.Optional["AsyncOpenAI"] = None):
        from openai import AsyncOpenAI
        from openai.types.chat import ChatCompletion

        if client is None:
            client = AsyncOpenAI()
        self.client = client
        self.coerce_response = lambda response: ChatCompletion(**response)

    async def submit(self, tasks: list["CompletionInputs"]) -> str:
        """Submit tasks to OpenAI as an in-memory JSONL batch file.

        Args:
            tasks: Chat-completion request bodies, one per queued item. Each
                line's ``custom_id`` is the task's index, which ``poll`` uses
                to restore submission order.

        Returns:
            The ID of the created batch job.
        """
        jsonl_str = "\n".join(
            json.dumps(
                Task(
                    custom_id=str(i),
                    method="POST",
                    url="/v1/chat/completions",
                    body=task,
                )
            )
            for i, task in enumerate(tasks)
        )
        in_memory_file = io.BytesIO(jsonl_str.encode("utf-8"))

        # Upload the in-memory file for batch processing.
        batch_input_file = await self.client.files.create(
            file=in_memory_file, purpose="batch"
        )

        # Create the batch job with a 24h completion window.
        batch_object = await self.client.batches.create(
            input_file_id=batch_input_file.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
            metadata={"description": "nightly eval job"},
        )

        return batch_object.id

    async def poll(self, batch_id: str) -> typing.Optional[list[dict[str, typing.Any]]]:
        """Poll a batch job and, once it has completed, return its parsed results.

        Args:
            batch_id: The ID of the batch job to poll.

        Returns:
            None while the batch is still running. Otherwise, one entry per
            submitted request, ordered by ``custom_id`` (submission order).
            Each entry is the response body when present, or the raw result
            line otherwise (e.g. failed requests from the error file).

        Raises:
            BatchIngestException: if the batch failed, expired, or was cancelled.
        """
        batch_obj = await self.client.batches.retrieve(batch_id)
        if batch_obj.status in self._FAILED_STATUSES:
            raise BatchIngestException(f"Batch {batch_id} {batch_obj.status}")
        if batch_obj.status != "completed":
            return None

        # Successful requests are written to the output file and failed ones
        # to the error file. Either file may be absent.
        entries: list[dict[str, typing.Any]] = []
        for file_id in (batch_obj.output_file_id, batch_obj.error_file_id):
            if file_id:
                entries.extend(await self._read_jsonl(file_id))

        # Result lines are not guaranteed to follow input order. The bridge
        # zips results with origins positionally, so restore submission order
        # using the index stored in custom_id by submit().
        entries.sort(key=lambda entry: int(entry["custom_id"]))

        results = []
        for data in entries:
            response = data.get("response")
            body = response.get("body") if response else None
            results.append(body if body else data)
        return results

    async def _read_jsonl(self, file_id: str) -> list[dict[str, typing.Any]]:
        """Download a batch result file and parse each non-blank JSONL line."""
        content = (await self.client.files.content(file_id)).content
        return [
            json.loads(line)
            for line in content.decode("utf-8").splitlines()
            if line.strip()
        ]
|
batch_bridge/errors.py
ADDED
batch_bridge/types.py
ADDED
@@ -0,0 +1,28 @@
|
|
1
|
+
import typing
|
2
|
+
|
3
|
+
from typing_extensions import TypedDict
|
4
|
+
|
5
|
+
T = typing.TypeVar("T")
|
6
|
+
U = typing.TypeVar("U")
|
7
|
+
V = typing.TypeVar("V")
|
8
|
+
|
9
|
+
|
10
|
+
class Origin(TypedDict):
|
11
|
+
assistant_id: str
|
12
|
+
thread_id: str
|
13
|
+
task_id: str # UUID
|
14
|
+
|
15
|
+
|
16
|
+
class InFlightBatch(TypedDict, typing.Generic[T]):
|
17
|
+
batch_id: str
|
18
|
+
batch_payload: T
|
19
|
+
origins: typing.Sequence[Origin]
|
20
|
+
|
21
|
+
|
22
|
+
class QueueItem(TypedDict, typing.Generic[T]):
|
23
|
+
origin: Origin
|
24
|
+
task: T
|
25
|
+
|
26
|
+
|
27
|
+
class RemoveBatch(typing.NamedTuple):
|
28
|
+
batch_id: str
|
@@ -0,0 +1,122 @@
|
|
1
|
+
Metadata-Version: 2.4
|
2
|
+
Name: batch-bridge
|
3
|
+
Version: 0.0.1rc0
|
4
|
+
Summary: Use LLM batch APIs from LangGraph agents via interrupt-based batching.
|
5
|
+
License: MIT License
|
6
|
+
|
7
|
+
Copyright (c) 2025 LangChain
|
8
|
+
|
9
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
10
|
+
of this software and associated documentation files (the "Software"), to deal
|
11
|
+
in the Software without restriction, including without limitation the rights
|
12
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
13
|
+
copies of the Software, and to permit persons to whom the Software is
|
14
|
+
furnished to do so, subject to the following conditions:
|
15
|
+
|
16
|
+
The above copyright notice and this permission notice shall be included in all
|
17
|
+
copies or substantial portions of the Software.
|
18
|
+
|
19
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
20
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
21
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
22
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
23
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
24
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
25
|
+
SOFTWARE.
|
26
|
+
License-File: LICENSE
|
27
|
+
Requires-Python: >=3.11
|
28
|
+
Requires-Dist: langgraph-api>=0.0.33
|
29
|
+
Requires-Dist: langgraph>=0.2.66
|
30
|
+
Description-Content-Type: text/markdown
|
31
|
+
|
32
|
+
# BatchBridge
|
33
|
+
|
34
|
+
BatchBridge is a library for efficient batch processing with LangGraph. It provides a mechanism to collect items, process them in batches, and handle results asynchronously using LangGraph's interrupt semantics.
|
35
|
+
|
36
|
+
## What is BatchBridge?
|
37
|
+
|
38
|
+
Batch APIs can cut AI inference costs by 50% or more, but they're difficult to use in agent workflows. They force you to manually aggregate requests and design your entire agent loop around batch processing rather than focusing on individual tasks. This makes your code more complex, harder to maintain, and less shareable.
|
39
|
+
|
40
|
+
BatchBridge solves this by making batch APIs work like standard completion APIs in LangGraph. Your code makes normal API calls while BatchBridge handles batching, submission, polling, and resumption behind the scenes. This lets you design and improve an agent using single completions, then make a one-line change to let it economically scale.
|
41
|
+
|
42
|
+
We aim to give you significant cost savings with minimal code complexity.
|
43
|
+
|
44
|
+
## Installation
|
45
|
+
|
46
|
+
```bash
|
47
|
+
pip install batch-bridge
|
48
|
+
```
|
49
|
+
|
50
|
+
Since BatchBridge relies on LangGraph's durable execution and cron functionality, it must be
|
51
|
+
run on the LangGraph platform.
|
52
|
+
|
53
|
+
|
54
|
+
## Example with OpenAI's Batch API
|
55
|
+
|
56
|
+
BatchBridge has a native integration with OpenAI's Batch API:
|
57
|
+
|
58
|
+
```python
|
59
|
+
from batch_bridge import patch_openai
|
60
|
+
from openai import AsyncOpenAI
|
61
|
+
from langgraph.graph import StateGraph
|
62
|
+
from typing_extensions import Annotated, TypedDict
|
63
|
+
|
64
|
+
# Patch the client at the global level
|
65
|
+
client = patch_openai(AsyncOpenAI())
|
66
|
+
|
67
|
+
|
68
|
+
class State(TypedDict):
|
69
|
+
messages: Annotated[list[dict], lambda x, y: x + y]
|
70
|
+
|
71
|
+
|
72
|
+
async def my_model(state: State):
|
73
|
+
# This will:
|
74
|
+
# 1. submit the message to our bridge graph
|
75
|
+
# 2. Interrupt this agent graph.
|
76
|
+
# 3. resume once the bridge graph detects that the batch is complete
|
77
|
+
result = await client.chat.completions.create(
|
78
|
+
model="gpt-4o-mini", messages=state["messages"]
|
79
|
+
)
|
80
|
+
return {"messages": [result]}
|
81
|
+
|
82
|
+
|
83
|
+
graph = StateGraph(State).add_node(my_model).add_edge("__start__", "my_model").compile()
|
84
|
+
```
|
85
|
+
|
86
|
+
## Basic Usage
|
87
|
+
|
88
|
+
Under the hood, BatchBridge relies on two basic functions:
|
89
|
+
a submit() function and a poll() function.
|
90
|
+
|
91
|
+
Here's a simple example of how to use BatchBridge:
|
92
|
+
|
93
|
+
```python
|
94
|
+
from batch_bridge import Bridge


# Define async functions for batch processing (Bridge requires coroutine functions)
async def submit_batch(items):
    """Submit a batch of items for processing."""
    # In a real implementation, this would submit to an external API
    # and return a batch ID
    print(f"Submitting batch of {len(items)} items")
    return "batch_123"


async def poll_batch(batch_id):
    """Poll for the results of a batch."""
    # In a real implementation, this would check the status of the batch.
    # Return None while it is still running, or one result per submitted
    # item (in submission order) once it is complete.
    return [f"Processed: {item}" for item in ["item1", "item2"]]


# Create a bridge graph with the default submission criteria
bridge = Bridge(submit_batch, poll_batch)
|
118
|
+
```
|
119
|
+
|
120
|
+
## License
|
121
|
+
|
122
|
+
MIT
|
@@ -0,0 +1,9 @@
|
|
1
|
+
batch_bridge/__init__.py,sha256=Iw5iz6Z5ev599n9NEppOEsbsJTZW4ZTvHTbsWM91U5o,368
|
2
|
+
batch_bridge/_base.py,sha256=NUpFnguVA-1naZ-aLRaHhLzdsohh6OyYuNJydzpO6Ic,11580
|
3
|
+
batch_bridge/_openai.py,sha256=aRbKB3OZJcWRI3eOhwsp7B56JXngkDy767w5WjoSqI4,6215
|
4
|
+
batch_bridge/errors.py,sha256=jVMWhYIwltJbG0wawGk3EO0C_mCsKXVeusRrP98dDmo,132
|
5
|
+
batch_bridge/types.py,sha256=G1SMbyc0FZI7eWV_YbYxCR920JuuTCeaFP1aec75g8g,488
|
6
|
+
batch_bridge-0.0.1rc0.dist-info/METADATA,sha256=HKtOffXaOb1znRqoqiXJ4MrOYaGpqMaxWAD-uLkzywM,4567
|
7
|
+
batch_bridge-0.0.1rc0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
batch_bridge-0.0.1rc0.dist-info/licenses/LICENSE,sha256=mK8TUeqFbgCMg1vImjEpBZYKMYBy-VBzK_NGx0ECfH0,1066
|
9
|
+
batch_bridge-0.0.1rc0.dist-info/RECORD,,
|
@@ -0,0 +1,21 @@
|
|
1
|
+
MIT License
|
2
|
+
|
3
|
+
Copyright (c) 2025 LangChain
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
7
|
+
in the Software without restriction, including without limitation the rights
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
10
|
+
furnished to do so, subject to the following conditions:
|
11
|
+
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
13
|
+
copies or substantial portions of the Software.
|
14
|
+
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
|
+
SOFTWARE.
|