chatlas 0.12.0__py3-none-any.whl → 0.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

chatlas/__init__.py CHANGED
@@ -1,5 +1,11 @@
  from . import types
  from ._auto import ChatAuto
+ from ._batch_chat import (
+     batch_chat,
+     batch_chat_completed,
+     batch_chat_structured,
+     batch_chat_text,
+ )
  from ._chat import Chat
  from ._content import (
      ContentToolRequest,
@@ -36,6 +42,10 @@ except ImportError: # pragma: no cover
      __version__ = "0.0.0"  # stub value for docs

  __all__ = (
+     "batch_chat",
+     "batch_chat_completed",
+     "batch_chat_structured",
+     "batch_chat_text",
      "ChatAnthropic",
      "ChatAuto",
      "ChatBedrockAnthropic",
chatlas/_batch_chat.py ADDED
@@ -0,0 +1,211 @@
+ """
+ Batch chat processing for submitting multiple requests simultaneously.
+
+ This module provides functionality for submitting multiple chat requests
+ in batches to providers that support it (currently OpenAI and Anthropic).
+ Batch processing can take up to 24 hours but offers significant cost
+ savings (up to 50% cheaper than regular requests).
+ """
+
+ from __future__ import annotations
+
+ import copy
+ from pathlib import Path
+ from typing import TypeVar, Union
+
+ from pydantic import BaseModel
+
+ from ._batch_job import BatchJob, ContentT
+ from ._chat import Chat
+
+ ChatT = TypeVar("ChatT", bound=Chat)
+ BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
+
+
+ def batch_chat(
+     chat: ChatT,
+     prompts: list[ContentT] | list[list[ContentT]],
+     path: Union[str, Path],
+     wait: bool = True,
+ ) -> list[ChatT | None]:
+     """
+     Submit multiple chat requests in a batch.
+
+     This function allows you to submit multiple chat requests simultaneously
+     using provider batch APIs (currently OpenAI and Anthropic). Batch
+     processing can take up to 24 hours but offers significant cost savings.
+
+     Parameters
+     ----------
+     chat
+         Chat instance to use for the batch.
+     prompts
+         List of prompts to process. Each can be a string or a list of strings.
+     path
+         Path to a file (with a .json extension) in which to store batch state.
+     wait
+         If True, wait for the batch to complete. If False, return None when
+         the batch is incomplete.
+
+     Returns
+     -------
+     A list of Chat objects (one per prompt) if complete, or None if
+     wait=False and the batch is incomplete. Individual Chat objects may be
+     None if their request failed.
+
+     Example
+     -------
+
+     ```python
+     from chatlas import ChatOpenAI
+
+     chat = ChatOpenAI()
+     prompts = [
+         "What's the capital of France?",
+         "What's the capital of Germany?",
+         "What's the capital of Italy?",
+     ]
+
+     chats = batch_chat(chat, prompts, "capitals.json")
+     for i, result_chat in enumerate(chats):
+         if result_chat:
+             print(f"Prompt {i + 1}: {result_chat.get_last_turn().text}")
+     ```
+     """
+     job = BatchJob(chat, prompts, path, wait=wait)
+     job.step_until_done()
+
+     chats = []
+     assistant_turns = job.result_turns()
+     for user, assistant in zip(job.user_turns, assistant_turns):
+         if assistant is not None:
+             new_chat = copy.deepcopy(chat)
+             new_chat.add_turn(user)
+             new_chat.add_turn(assistant)
+             chats.append(new_chat)
+         else:
+             chats.append(None)
+
+     return chats
+
+
+ def batch_chat_text(
+     chat: Chat,
+     prompts: list[ContentT] | list[list[ContentT]],
+     path: Union[str, Path],
+     wait: bool = True,
+ ) -> list[str | None]:
+     """
+     Submit multiple chat requests in a batch and return text responses.
+
+     This is a convenience function that returns just the text of the
+     responses rather than full Chat objects.
+
+     Parameters
+     ----------
+     chat
+         Chat instance to use for the batch.
+     prompts
+         List of prompts to process.
+     path
+         Path to a file (with a .json extension) in which to store batch state.
+     wait
+         If True, wait for the batch to complete.
+
+     Returns
+     -------
+     A list of text responses (or None for failed requests).
+     """
+     chats = batch_chat(chat, prompts, path, wait=wait)
+
+     texts = []
+     for x in chats:
+         if x is None:
+             texts.append(None)
+             continue
+         last_turn = x.get_last_turn()
+         if last_turn is None:
+             texts.append(None)
+             continue
+         texts.append(last_turn.text)
+
+     return texts
+
+
+ def batch_chat_structured(
+     chat: Chat,
+     prompts: list[ContentT] | list[list[ContentT]],
+     path: Union[str, Path],
+     data_model: type[BaseModelT],
+     wait: bool = True,
+ ) -> list[BaseModelT | None]:
+     """
+     Submit multiple structured data requests in a batch.
+
+     Parameters
+     ----------
+     chat
+         Chat instance to use for the batch.
+     prompts
+         List of prompts to process.
+     path
+         Path to a file (with a .json extension) in which to store batch state.
+     data_model
+         Pydantic model class for structured responses.
+     wait
+         If True, wait for the batch to complete.
+
+     Returns
+     -------
+     A list of structured data objects (or None for failed requests).
+     """
+     job = BatchJob(chat, prompts, path, data_model=data_model, wait=wait)
+     result = job.step_until_done()
+
+     if result is None:
+         return []
+
+     res: list[BaseModelT | None] = []
+     assistant_turns = job.result_turns()
+     for turn in assistant_turns:
+         if turn is None:
+             res.append(None)
+         else:
+             json = chat._extract_turn_json(turn)
+             model = data_model.model_validate(json)
+             res.append(model)
+
+     return res
+
+
+ def batch_chat_completed(
+     chat: Chat,
+     prompts: list[ContentT] | list[list[ContentT]],
+     path: Union[str, Path],
+ ) -> bool:
+     """
+     Check whether a batch job has completed, without waiting.
+
+     Parameters
+     ----------
+     chat
+         Chat instance used for the batch.
+     prompts
+         List of prompts used for the batch.
+     path
+         Path to the batch state file.
+
+     Returns
+     -------
+     True if the batch is complete, False otherwise.
+     """
+     job = BatchJob(chat, prompts, path, wait=False)
+     stage = job.stage
+
+     if stage == "submitting":
+         return False
+     elif stage == "waiting":
+         status = job._poll()
+         return not status.working
+     elif stage == "retrieving" or stage == "done":
+         return True
+     else:
+         raise ValueError(f"Unknown batch stage: {stage}")
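
`batch_chat_structured` validates each response against a caller-supplied Pydantic model before returning it. A sketch, assuming a hypothetical `Person` model (not part of this diff):

```python
from pydantic import BaseModel

from chatlas import ChatOpenAI, batch_chat_structured


class Person(BaseModel):
    # Hypothetical data model, for illustration only.
    name: str
    age: int


chat = ChatOpenAI()
people = batch_chat_structured(
    chat,
    prompts=[
        "Generate a fictional person.",
        "Generate another fictional person.",
    ],
    path="people.json",
    data_model=Person,
)
for person in people:
    if person is not None:  # None marks a failed request
        print(person.name, person.age)
```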
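Because job state is persisted at `path`, a batch can be submitted in one session and collected in another. A sketch of that submit-now/collect-later pattern; the hash check in `BatchJob._load_state` (see `_batch_job.py` below) requires the same chat, prompts, and path in both sessions:

```python
from chatlas import ChatOpenAI, batch_chat, batch_chat_completed

chat = ChatOpenAI()
prompts = ["Capital of France?", "Capital of Germany?"]

# Session 1: submit and return immediately (None while incomplete).
batch_chat(chat, prompts, "capitals.json", wait=False)

# Session 2 (possibly hours later): poll once, then collect when done.
if batch_chat_completed(chat, prompts, "capitals.json"):
    chats = batch_chat(chat, prompts, "capitals.json")
```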
chatlas/_batch_job.py ADDED
@@ -0,0 +1,234 @@
+ from __future__ import annotations
+
+ import hashlib
+ import json
+ import time
+ from datetime import timedelta
+ from pathlib import Path
+ from typing import Any, Literal, Optional, TypeVar, Union
+
+ from pydantic import BaseModel
+ from rich.console import Console
+ from rich.progress import Progress, SpinnerColumn, TextColumn
+
+ from ._chat import Chat
+ from ._content import Content
+ from ._provider import BatchStatus
+ from ._turn import Turn, user_turn
+ from ._typing_extensions import TypedDict
+
+ BatchStage = Literal["submitting", "waiting", "retrieving", "done"]
+
+
+ class BatchStateHash(TypedDict):
+     provider: str
+     model: str
+     prompts: str
+     user_turns: str
+
+
+ class BatchState(BaseModel):
+     version: int
+     stage: BatchStage
+     batch: dict[str, Any]
+     results: list[dict[str, Any]]
+     started_at: int
+     hash: BatchStateHash
+
+
+ ContentT = TypeVar("ContentT", bound=Union[str, Content])
+
+
+ class BatchJob:
+     """
+     Manages the lifecycle of a batch processing job.
+
+     A batch job goes through several stages:
+     1. "submitting" - Initial submission to the provider
+     2. "waiting" - Waiting for processing to complete
+     3. "retrieving" - Downloading results
+     4. "done" - Processing complete
+     """
+
+     def __init__(
+         self,
+         chat: Chat,
+         prompts: list[ContentT] | list[list[ContentT]],
+         path: Union[str, Path],
+         data_model: Optional[type[BaseModel]] = None,
+         wait: bool = True,
+     ):
+         if not chat.provider.has_batch_support():
+             raise ValueError("Batch requests are not supported by this provider")
+
+         self.chat = chat
+         self.prompts = prompts
+         self.path = Path(path)
+         self.data_model = data_model
+         self.should_wait = wait
+
+         # Convert prompts to user turns
+         self.user_turns: list[Turn] = []
+         for prompt in prompts:
+             if not isinstance(prompt, (str, Content)):
+                 turn = user_turn(*prompt)
+             else:
+                 turn = user_turn(prompt)
+             self.user_turns.append(turn)
+
+         # Job state management
+         self.provider = chat.provider
+         self.stage: BatchStage = "submitting"
+         self.batch: dict[str, Any] = {}
+         self.results: list[dict[str, Any]] = []
+
+         # Load existing state if the file exists and is not empty
+         if self.path.exists() and self.path.stat().st_size > 0:
+             self._load_state()
+         else:
+             self.started_at = time.time()
+
+     def _load_state(self) -> None:
+         with open(self.path, "r") as f:
+             state = BatchState.model_validate_json(f.read())
+
+         self.stage = state.stage
+         self.batch = state.batch
+         self.results = state.results
+         self.started_at = state.started_at
+
+         # Verify hash to ensure consistency
+         stored_hash = state.hash
+         current_hash = self._compute_hash()
+
+         for key, value in current_hash.items():
+             if stored_hash.get(key) != value:
+                 raise ValueError(
+                     f"Batch state mismatch: {key} doesn't match stored value. "
+                     f"Do you need to pick a different path?"
+                 )
+
+     def _save_state(self) -> None:
+         state = BatchState(
+             version=1,
+             stage=self.stage,
+             batch=self.batch,
+             results=self.results,
+             started_at=int(self.started_at) if self.started_at else 0,
+             hash=self._compute_hash(),
+         )
+
+         with open(self.path, "w") as f:
+             f.write(state.model_dump_json(indent=2))
+
+     def _compute_hash(self) -> BatchStateHash:
+         turns = self.chat.get_turns(include_system_prompt=True)
+         return {
+             "provider": self.provider.name,
+             "model": self.provider.model,
+             "prompts": self._hash([str(p) for p in self.prompts]),
+             "user_turns": self._hash([str(turn) for turn in turns]),
+         }
+
+     @staticmethod
+     def _hash(x: Any) -> str:
+         return hashlib.md5(json.dumps(x, sort_keys=True).encode()).hexdigest()
+
+     def step(self) -> bool:
+         if self.stage == "submitting":
+             return self._submit()
+         elif self.stage == "waiting":
+             return self._wait()
+         elif self.stage == "retrieving":
+             return self._retrieve()
+         else:
+             raise ValueError(f"Unknown stage: {self.stage}")
+
+     def step_until_done(self) -> Optional["BatchJob"]:
+         while self.stage != "done":
+             if not self.step():
+                 return None
+         return self
+
+     def _submit(self) -> bool:
+         existing_turns = self.chat.get_turns(include_system_prompt=True)
+
+         conversations = []
+         for turn in self.user_turns:
+             conversation = existing_turns + [turn]
+             conversations.append(conversation)
+
+         self.batch = self.provider.batch_submit(conversations, self.data_model)
+         self.stage = "waiting"
+         self._save_state()
+         return True
+
+     def _wait(self) -> bool:
+         # Always poll once, even when wait=False
+         status = self._poll()
+
+         if self.should_wait:
+             console = Console()
+
+             with Progress(
+                 SpinnerColumn(),
+                 TextColumn("Processing..."),
+                 TextColumn("[{task.fields[elapsed]}]"),
+                 TextColumn("{task.fields[n_processing]} pending |"),
+                 TextColumn("[green]{task.fields[n_succeeded]}[/green] done |"),
+                 TextColumn("[red]{task.fields[n_failed]}[/red] failed"),
+                 console=console,
+             ) as progress:
+                 task = progress.add_task(
+                     "processing",
+                     elapsed=self._elapsed(),
+                     n_processing=status.n_processing,
+                     n_succeeded=status.n_succeeded,
+                     n_failed=status.n_failed,
+                 )
+
+                 while status.working:
+                     time.sleep(0.5)
+                     status = self._poll()
+                     progress.update(
+                         task,
+                         elapsed=self._elapsed(),
+                         n_processing=status.n_processing,
+                         n_succeeded=status.n_succeeded,
+                         n_failed=status.n_failed,
+                     )
+
+         if not status.working:
+             self.stage = "retrieving"
+             self._save_state()
+             return True
+         else:
+             return False
+
+     def _poll(self) -> "BatchStatus":
+         if not self.batch:
+             raise ValueError("No batch to poll")
+         self.batch = self.provider.batch_poll(self.batch)
+         self._save_state()
+         return self.provider.batch_status(self.batch)
+
+     def _elapsed(self) -> str:
+         return str(timedelta(seconds=int(time.time()) - int(self.started_at)))
+
+     def _retrieve(self) -> bool:
+         if not self.batch:
+             raise ValueError("No batch to retrieve")
+         self.results = self.provider.batch_retrieve(self.batch)
+         self.stage = "done"
+         self._save_state()
+         return True
+
+     def result_turns(self) -> list[Turn | None]:
+         turns = []
+         for result in self.results:
+             turn = self.provider.batch_result_turn(
+                 result, has_data_model=self.data_model is not None
+             )
+             turns.append(turn)
+
+         return turns
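
`step_until_done` simply drives `step` through the stage machine, and each transition is checkpointed via `_save_state`, so an interrupted job resumes from its last completed stage. A sketch of driving a job manually instead of through `batch_chat` (chat and prompts as in the earlier examples):

```python
# Sketch: manual control over the stage machine defined above.
job = BatchJob(chat, prompts, "state.json", wait=False)

result = job.step_until_done()  # submitting -> waiting -> retrieving -> done
if result is None:
    # step() returned False in the "waiting" stage: the provider is still
    # working and wait=False. Rebuild the job later to resume from "waiting".
    pass
else:
    turns = job.result_turns()  # list[Turn | None], one entry per prompt
```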
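The stored hash guards against reusing one state file for different inputs: `_load_state` recomputes the provider, model, prompt, and turn hashes and raises on any mismatch. A sketch of that failure mode (file name and prompts are illustrative):

```python
from chatlas import ChatOpenAI, batch_chat
from chatlas._batch_job import BatchJob  # private module, per this diff

chat = ChatOpenAI()

# Submitting writes "state.json" via BatchJob._save_state().
batch_chat(chat, ["prompt A"], "state.json", wait=False)

# Rebuilding the job with different prompts changes the "prompts" hash,
# so _load_state() raises:
#   ValueError: Batch state mismatch: prompts doesn't match stored value.
#   Do you need to pick a different path?
BatchJob(chat, ["prompt B"], "state.json", wait=False)
```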