chatlas 0.12.0__py3-none-any.whl → 0.13.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatlas/__init__.py +10 -0
- chatlas/_batch_chat.py +211 -0
- chatlas/_batch_job.py +234 -0
- chatlas/_chat.py +171 -42
- chatlas/_content.py +2 -2
- chatlas/_provider.py +88 -0
- chatlas/_provider_anthropic.py +106 -2
- chatlas/_provider_github.py +2 -2
- chatlas/_provider_openai.py +143 -12
- chatlas/_version.py +2 -2
- {chatlas-0.12.0.dist-info → chatlas-0.13.1.dist-info}/METADATA +2 -1
- {chatlas-0.12.0.dist-info → chatlas-0.13.1.dist-info}/RECORD +14 -12
- {chatlas-0.12.0.dist-info → chatlas-0.13.1.dist-info}/WHEEL +0 -0
- {chatlas-0.12.0.dist-info → chatlas-0.13.1.dist-info}/licenses/LICENSE +0 -0
chatlas/_chat.py
CHANGED
@@ -78,6 +78,7 @@ CompletionT = TypeVar("CompletionT")
 EchoOptions = Literal["output", "all", "none", "text"]
 
 T = TypeVar("T")
+BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
 
 
 def is_present(value: T | None | MISSING_TYPE) -> TypeGuard[T]:
@@ -209,6 +210,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         self,
         *,
         include_system_prompt: bool = False,
+        tool_result_role: Literal["assistant", "user"] = "user",
     ) -> list[Turn[CompletionT]]:
         """
         Get all the turns (i.e., message contents) in the chat.
@@ -217,14 +219,50 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         ----------
         include_system_prompt
             Whether to include the system prompt in the turns.
+        tool_result_role
+            The role to assign to turns containing tool results. By default,
+            tool results are assigned a role of "user" since they represent
+            information provided to the assistant. If set to "assistant", tool
+            result content (plus the surrounding assistant turn contents) is
+            collected into a single assistant turn. This is convenient for
+            display purposes and more generally if you want the tool calling
+            loop to be contained in a single turn.
         """
 
         if not self._turns:
             return self._turns
 
         if not include_system_prompt and self._turns[0].role == "system":
-            return self._turns[1:]
-        return self._turns
+            turns = self._turns[1:]
+        else:
+            turns = self._turns
+
+        if tool_result_role == "user":
+            return turns
+
+        if tool_result_role != "assistant":
+            raise ValueError(
+                f"Expected `tool_result_role` to be one of 'user' or 'assistant', not '{tool_result_role}'"
+            )
+
+        # If a turn is purely a tool result, change its role
+        turns2 = copy.deepcopy(turns)
+        for turn in turns2:
+            if all(isinstance(c, ContentToolResult) for c in turn.contents):
+                turn.role = tool_result_role
+
+        # If two consecutive turns have the same role (i.e., assistant), collapse them into one
+        final_turns: list[Turn[CompletionT]] = []
+        for x in turns2:
+            if not final_turns:
+                final_turns.append(x)
+                continue
+            if x.role != final_turns[-1].role:
+                final_turns.append(x)
+            else:
+                final_turns[-1].contents.extend(x.contents)
+
+        return final_turns
 
     def get_last_turn(
         self,
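For orientation, a minimal sketch of the new `tool_result_role` option; the tool and prompt are hypothetical, while `ChatOpenAI`, `register_tool()`, `chat()`, and `get_turns()` are existing chatlas APIs:

    from chatlas import ChatOpenAI

    def get_weather(city: str) -> str:
        """Hypothetical tool: report the weather for a city."""
        return f"Sunny in {city}"

    chat = ChatOpenAI()  # assumes an OpenAI API key is configured
    chat.register_tool(get_weather)
    chat.chat("What's the weather in Oslo?")

    # Default: turns containing only tool results keep the "user" role
    turns = chat.get_turns()

    # New in 0.13: fold tool results (plus surrounding assistant content)
    # into a single assistant turn, e.g. for display
    display_turns = chat.get_turns(tool_result_role="assistant")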
@@ -531,7 +569,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         args
             The input to get a token count for.
         data_model
-            If the input is meant for data extraction (i.e., `.extract_data()`), then
+            If the input is meant for data extraction (i.e., `.chat_structured()`), then
             this should be the Pydantic model that describes the structure of the data to
             extract.
 
@@ -585,7 +623,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         args
             The input to get a token count for.
         data_model
-            If this input is meant for data extraction (i.e., `.extract_data_async()`),
+            If this input is meant for data extraction (i.e., `.chat_structured_async()`),
             then this should be the Pydantic model that describes the structure of the data
             to extract.
 
@@ -608,6 +646,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         port: int = 0,
         host: str = "127.0.0.1",
         launch_browser: bool = True,
+        bookmark_store: Literal["url", "server", "disable"] = "url",
         bg_thread: Optional[bool] = None,
         echo: Optional[EchoOptions] = None,
         content: Literal["text", "all"] = "all",
@@ -626,6 +665,12 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
            The host to run the app on (the default is "127.0.0.1").
        launch_browser
            Whether to launch a browser window.
+       bookmark_store
+           One of the following (default is "url"):
+           - `"url"`: Store bookmarks in the URL (default).
+           - `"server"`: Store bookmarks on the server (requires a server-side
+             storage backend).
+           - `"disable"`: Disable bookmarking.
        bg_thread
            Whether to run the app in a background thread. If `None`, the app will
            run in a background thread if the current environment is a notebook.
@@ -647,24 +692,37 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
             from shiny import App, run_app, ui
         except ImportError:
             raise ImportError(
-                "The `shiny` package is required for the `app` method. "
+                "The `shiny` package is required for the `app()` method. "
                 "Install it with `pip install shiny`."
             )
 
-
-
-
-
+        try:
+            from shinychat import (
+                Chat,
+                chat_ui,
+                message_content,  # pyright: ignore[reportAttributeAccessIssue]
+            )
+        except ImportError:
+            raise ImportError(
+                "The `shinychat` package is required for the `app()` method. "
+                "Install it with `pip install shinychat`."
+            )
 
-
-
-
-
-
-
-
+        messages = [
+            message_content(x) for x in self.get_turns(tool_result_role="assistant")
+        ]
+
+        def app_ui(x):
+            return ui.page_fillable(
+                chat_ui("chat", messages=messages),
+                fillable_mobile=True,
             )
 
+        def server(input):  # noqa: A002
+            chat = Chat("chat")
+
+            chat.enable_bookmarking(self)
+
             @chat.on_user_submit
             async def _(user_input: str):
                 if stream:
@@ -688,7 +746,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
                     )
                 )
 
-        app = App(app_ui, server)
+        app = App(app_ui, server, bookmark_store=bookmark_store)
 
         def _run_app():
             run_app(app, launch_browser=launch_browser, port=port, host=host)
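A minimal sketch of the new bookmarking knob, assuming `shiny` and `shinychat` are installed; `ChatOpenAI` stands in for any provider:

    from chatlas import ChatOpenAI

    chat = ChatOpenAI()
    # "url" (the default) encodes chat state in the URL so a session can be
    # restored from a link; "server" needs server-side storage; "disable" opts out.
    chat.app(bookmark_store="url", port=8080)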
@@ -997,20 +1055,22 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
 
         return wrapper()
 
-    def extract_data(
+    def chat_structured(
         self,
         *args: Content | str,
-        data_model: type[BaseModel],
+        data_model: type[BaseModelT],
         echo: EchoOptions = "none",
         stream: bool = False,
-    ) -> dict[str, Any]:
+    ) -> BaseModelT:
         """
-        Extract structured data from the given input.
+        Extract structured data.
 
         Parameters
         ----------
         args
-            The input to extract data from.
+            The input to send to the chatbot. This is typically the text you
+            want to extract data from, but it can be omitted if the data is
+            obvious from the existing conversation.
         data_model
             A Pydantic model describing the structure of the data to extract.
         echo
@@ -1024,10 +1084,47 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
 
         Returns
         -------
-
-
+        BaseModelT
+            An instance of the provided `data_model` containing the extracted data.
         """
+        dat = self._submit_and_extract_data(
+            *args,
+            data_model=data_model,
+            echo=echo,
+            stream=stream,
+        )
+        return data_model.model_validate(dat)
 
+    def extract_data(
+        self,
+        *args: Content | str,
+        data_model: type[BaseModel],
+        echo: EchoOptions = "none",
+        stream: bool = False,
+    ) -> dict[str, Any]:
+        """
+        Deprecated: use `.chat_structured()` instead.
+        """
+        warnings.warn(
+            "The `extract_data()` method is deprecated and will be removed in a future release. "
+            "Use the `chat_structured()` method instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self._submit_and_extract_data(
+            *args,
+            data_model=data_model,
+            echo=echo,
+            stream=stream,
+        )
+
+    def _submit_and_extract_data(
+        self,
+        *args: Content | str,
+        data_model: type[BaseModel],
+        echo: EchoOptions = "none",
+        stream: bool = False,
+    ) -> dict[str, Any]:
         display = self._markdown_display(echo=echo)
 
         response = ChatResponse(
@@ -1046,33 +1143,24 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         turn = self.get_last_turn()
         assert turn is not None
 
-        res: list[ContentJson] = []
-        for x in turn.contents:
-            if isinstance(x, ContentJson):
-                res.append(x)
-
-        if len(res) != 1:
-            raise ValueError(
-                f"Data extraction failed: {len(res)} data results received."
-            )
-
-        json = res[0]
-        return json.value
+        return Chat._extract_turn_json(turn)
 
-    async def extract_data_async(
+    async def chat_structured_async(
         self,
         *args: Content | str,
-        data_model: type[BaseModel],
+        data_model: type[BaseModelT],
         echo: EchoOptions = "none",
         stream: bool = False,
-    ) -> dict[str, Any]:
+    ) -> BaseModelT:
         """
         Extract structured data from the given input asynchronously.
 
         Parameters
         ----------
         args
-            The input to extract data from.
+            The input to send to the chatbot. This is typically the text you
+            want to extract data from, but it can be omitted if the data is
+            obvious from the existing conversation.
         data_model
             A Pydantic model describing the structure of the data to extract.
         echo
@@ -1087,10 +1175,47 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
 
         Returns
         -------
-
-
+        BaseModelT
+            An instance of the provided `data_model` containing the extracted data.
         """
+        dat = await self._submit_and_extract_data_async(
+            *args,
+            data_model=data_model,
+            echo=echo,
+            stream=stream,
+        )
+        return data_model.model_validate(dat)
 
+    async def extract_data_async(
+        self,
+        *args: Content | str,
+        data_model: type[BaseModel],
+        echo: EchoOptions = "none",
+        stream: bool = False,
+    ) -> dict[str, Any]:
+        """
+        Deprecated: use `.chat_structured_async()` instead.
+        """
+        warnings.warn(
+            "The `extract_data_async()` method is deprecated and will be removed in a future release. "
+            "Use the `chat_structured_async()` method instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return await self._submit_and_extract_data_async(
+            *args,
+            data_model=data_model,
+            echo=echo,
+            stream=stream,
+        )
+
+    async def _submit_and_extract_data_async(
+        self,
+        *args: Content | str,
+        data_model: type[BaseModel],
+        echo: EchoOptions = "none",
+        stream: bool = False,
+    ) -> dict[str, Any]:
         display = self._markdown_display(echo=echo)
 
         response = ChatResponseAsync(
@@ -1109,6 +1234,10 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         turn = self.get_last_turn()
         assert turn is not None
 
+        return Chat._extract_turn_json(turn)
+
+    @staticmethod
+    def _extract_turn_json(turn: Turn) -> dict[str, Any]:
         res: list[ContentJson] = []
         for x in turn.contents:
             if isinstance(x, ContentJson):
chatlas/_content.py
CHANGED
@@ -603,7 +603,7 @@ class ContentJson(Content):
     JSON content
 
     This content type primarily exists to signal structured data extraction
-    (i.e., data extracted via [](`~chatlas.Chat`)'s `.extract_data()` method)
+    (i.e., data extracted via [](`~chatlas.Chat`)'s `.chat_structured()` method)
 
     Parameters
     ----------
@@ -630,7 +630,7 @@ class ContentPDF(Content):
     PDF content
 
     This content type primarily exists to signal PDF data extraction
-    (i.e., data extracted via [](`~chatlas.Chat`)'s `.extract_data()` method)
+    (i.e., data extracted via [](`~chatlas.Chat`)'s `.chat_structured()` method)
 
     Parameters
     ----------
chatlas/_provider.py
CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from datetime import date
 from typing import (
+    Any,
     AsyncIterable,
     Generic,
     Iterable,
@@ -100,6 +101,16 @@ StandardModelParamNames = Literal[
 ]
 
 
+# Provider-agnostic batch status info
+class BatchStatus(BaseModel):
+    """Status information for a batch job."""
+
+    working: bool
+    n_processing: int
+    n_succeeded: int
+    n_failed: int
+
+
 class Provider(
     ABC,
     Generic[
@@ -261,3 +272,80 @@ class Provider(
 
     @abstractmethod
     def supported_model_params(self) -> set[StandardModelParamNames]: ...
+
+    def has_batch_support(self) -> bool:
+        """
+        Returns whether this provider supports batch processing.
+        Override this method to return True for providers that implement batch methods.
+        """
+        return False
+
+    def batch_submit(
+        self,
+        conversations: list[list[Turn]],
+        data_model: Optional[type[BaseModel]] = None,
+    ) -> dict[str, Any]:
+        """
+        Submit a batch of conversations for processing.
+
+        Args:
+            conversations: List of conversation histories (each is a list of Turns)
+            data_model: Optional structured data model for responses
+
+        Returns:
+            BatchInfo containing batch job information
+        """
+        raise NotImplementedError("This provider does not support batch processing")
+
+    def batch_poll(self, batch: dict[str, Any]) -> dict[str, Any]:
+        """
+        Poll the status of a submitted batch.
+
+        Args:
+            batch: Batch information returned from batch_submit
+
+        Returns:
+            Updated batch information
+        """
+        raise NotImplementedError("This provider does not support batch processing")
+
+    def batch_status(self, batch: dict[str, Any]) -> BatchStatus:
+        """
+        Get the status of a batch.
+
+        Args:
+            batch: Batch information
+
+        Returns:
+            BatchStatus with processing status information
+        """
+        raise NotImplementedError("This provider does not support batch processing")
+
+    def batch_retrieve(self, batch: dict[str, Any]) -> list[dict[str, Any]]:
+        """
+        Retrieve results from a completed batch.
+
+        Args:
+            batch: Batch information
+
+        Returns:
+            List of BatchResult objects, one for each request in the batch
+        """
+        raise NotImplementedError("This provider does not support batch processing")
+
+    def batch_result_turn(
+        self,
+        result: dict[str, Any],
+        has_data_model: bool = False,
+    ) -> Turn | None:
+        """
+        Convert a batch result to a Turn.
+
+        Args:
+            result: Individual BatchResult from batch_retrieve
+            has_data_model: Whether the request used a structured data model
+
+        Returns:
+            Turn object or None if the result was an error
+        """
+        raise NotImplementedError("This provider does not support batch processing")
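These defaults raise `NotImplementedError`, so providers opt in by overriding them. The hooks imply a provider-agnostic loop along these lines (a sketch; `provider` is any `Provider` subclass with batch support and `conversations` is a list of `Turn` lists):

    import time

    if provider.has_batch_support():
        batch = provider.batch_submit(conversations)
        while provider.batch_status(batch).working:
            time.sleep(10)                      # poll interval is arbitrary here
            batch = provider.batch_poll(batch)  # refresh the stored batch state
        results = provider.batch_retrieve(batch)
        turns = [provider.batch_result_turn(r) for r in results]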
chatlas/_provider_anthropic.py
CHANGED
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
 import base64
+import re
 import warnings
 from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast, overload
 
 import orjson
+from openai.types.chat import ChatCompletionToolParam
 from pydantic import BaseModel
 
 from ._chat import Chat
@@ -21,7 +23,13 @@ from ._content import (
     ContentToolResultResource,
 )
 from ._logging import log_model_default
-from ._provider import ModelInfo, Provider, StandardModelParamNames, StandardModelParams
+from ._provider import (
+    BatchStatus,
+    ModelInfo,
+    Provider,
+    StandardModelParamNames,
+    StandardModelParams,
+)
 from ._tokens import get_token_pricing, tokens_log
 from ._tools import Tool, basemodel_to_param_schema
 from ._turn import Turn, user_turn
@@ -38,11 +46,12 @@ if TYPE_CHECKING:
     )
     from anthropic.types.document_block_param import DocumentBlockParam
     from anthropic.types.image_block_param import ImageBlockParam
+    from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
+    from anthropic.types.messages.batch_create_params import Request as BatchRequest
     from anthropic.types.model_param import ModelParam
     from anthropic.types.text_block_param import TextBlockParam
     from anthropic.types.tool_result_block_param import ToolResultBlockParam
     from anthropic.types.tool_use_block_param import ToolUseBlockParam
-    from openai.types.chat import ChatCompletionToolParam
 
     from .types.anthropic import ChatBedrockClientArgs, ChatClientArgs, SubmitInputArgs
 
@@ -631,6 +640,101 @@ class AnthropicProvider(
             completion=completion,
         )
 
+    def has_batch_support(self) -> bool:
+        return True
+
+    def batch_submit(
+        self,
+        conversations: list[list[Turn]],
+        data_model: Optional[type[BaseModel]] = None,
+    ):
+        from anthropic import NotGiven
+
+        requests: list["BatchRequest"] = []
+
+        for i, turns in enumerate(conversations):
+            kwargs = self._chat_perform_args(
+                stream=False,
+                turns=turns,
+                tools={},
+                data_model=data_model,
+            )
+
+            params: "MessageCreateParamsNonStreaming" = {
+                "messages": kwargs.get("messages", {}),
+                "model": self.model,
+                "max_tokens": kwargs.get("max_tokens", 4096),
+            }
+
+            # If data_model, tools/tool_choice should be present
+            tools = kwargs.get("tools")
+            tool_choice = kwargs.get("tool_choice")
+            if tools and not isinstance(tools, NotGiven):
+                params["tools"] = tools
+            if tool_choice and not isinstance(tool_choice, NotGiven):
+                params["tool_choice"] = tool_choice
+
+            requests.append({"custom_id": f"request-{i}", "params": params})
+
+        batch = self._client.messages.batches.create(requests=requests)
+        return batch.model_dump()
+
+    def batch_poll(self, batch):
+        from anthropic.types.messages import MessageBatch
+
+        batch = MessageBatch.model_validate(batch)
+        b = self._client.messages.batches.retrieve(batch.id)
+        return b.model_dump()
+
+    def batch_status(self, batch) -> "BatchStatus":
+        from anthropic.types.messages import MessageBatch
+
+        batch = MessageBatch.model_validate(batch)
+        status = batch.processing_status
+        counts = batch.request_counts
+
+        return BatchStatus(
+            working=status != "ended",
+            n_processing=counts.processing,
+            n_succeeded=counts.succeeded,
+            n_failed=counts.errored + counts.canceled + counts.expired,
+        )
+
+    # https://docs.anthropic.com/en/api/retrieving-message-batch-results
+    def batch_retrieve(self, batch):
+        from anthropic.types.messages import MessageBatch
+
+        batch = MessageBatch.model_validate(batch)
+        if batch.results_url is None:
+            raise ValueError("Batch has no results URL")
+
+        results: list[dict[str, Any]] = []
+        for res in self._client.messages.batches.results(batch.id):
+            results.append(res.model_dump())
+
+        # Sort by custom_id to maintain order
+        def extract_id(x: str):
+            match = re.search(r"-(\d+)$", x)
+            return int(match.group(1)) if match else 0
+
+        results.sort(key=lambda x: extract_id(x.get("custom_id", "")))
+
+        return results
+
+    def batch_result_turn(self, result, has_data_model: bool = False) -> Turn | None:
+        from anthropic.types.messages.message_batch_individual_response import (
+            MessageBatchIndividualResponse,
+        )
+
+        result = MessageBatchIndividualResponse.model_validate(result)
+        if result.result.type != "succeeded":
+            # TODO: offer advice on what to do?
+            warnings.warn(f"Batch request didn't succeed: {result.result}")
+            return None
+
+        message = result.result.message
+        return self._as_turn(message, has_data_model)
+
 
 def ChatBedrockAnthropic(
     *,
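Two design points are worth noting here: batch state crosses the `Provider` boundary as plain dicts (`model_dump()` going out, `model_validate()` coming back), which keeps it JSON-serializable so a pending job can presumably be persisted and resumed by the new `_batch_job.py` machinery; and retrieved results are re-sorted by the numeric suffix of `custom_id`, since Anthropic does not guarantee that batch results come back in submission order.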
chatlas/_provider_github.py
CHANGED
@@ -141,7 +141,7 @@ def ChatGithub(
 
 class GitHubProvider(OpenAIProvider):
     def __init__(self, base_url: str, **kwargs):
-        super().__init__(**kwargs)
+        super().__init__(base_url=base_url, **kwargs)
         self._base_url = base_url
 
     def list_models(self) -> list[ModelInfo]:
@@ -190,7 +190,7 @@ def list_models_gh_azure(base_url: str = "https://models.inference.ai.azure.com"
     for m in models:
         info: ModelInfo = {
             "id": m["name"],
-            "provider": m["publisher"]
+            "provider": m["publisher"],
         }
         res.append(info)
 
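The `__init__` change looks like a genuine fix: `GitHubProvider` previously stored `base_url` on itself but never forwarded it to `OpenAIProvider`, so the underlying client would presumably be constructed against the default OpenAI endpoint rather than the GitHub Models one. The `list_models_gh_azure` change only adds a trailing comma and is stylistic.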