chatlas 0.12.0__py3-none-any.whl → 0.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chatlas might be problematic; see the registry's page for this release for more details.

chatlas/_chat.py CHANGED
@@ -78,6 +78,7 @@ CompletionT = TypeVar("CompletionT")
78
78
  EchoOptions = Literal["output", "all", "none", "text"]
79
79
 
80
80
  T = TypeVar("T")
81
+ BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
81
82
 
82
83
 
83
84
  def is_present(value: T | None | MISSING_TYPE) -> TypeGuard[T]:
@@ -209,6 +210,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
209
210
  self,
210
211
  *,
211
212
  include_system_prompt: bool = False,
213
+ tool_result_role: Literal["assistant", "user"] = "user",
212
214
  ) -> list[Turn[CompletionT]]:
213
215
  """
214
216
  Get all the turns (i.e., message contents) in the chat.
@@ -217,14 +219,50 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
217
219
  ----------
218
220
  include_system_prompt
219
221
  Whether to include the system prompt in the turns.
222
+ tool_result_role
223
+ The role to assign to turns containing tool results. By default,
224
+ tool results are assigned a role of "user" since they represent
225
+ information provided to the assistant. If set to "assistant" tool
226
+ result content (plus the surrounding assistant turn contents) is
227
+ collected into a single assistant turn. This is convenient for
228
+ display purposes and more generally if you want the tool calling
229
+ loop to be contained in a single turn.
220
230
  """
221
231
 
222
232
  if not self._turns:
223
233
  return self._turns
224
234
 
225
235
  if not include_system_prompt and self._turns[0].role == "system":
226
- return self._turns[1:]
227
- return self._turns
236
+ turns = self._turns[1:]
237
+ else:
238
+ turns = self._turns
239
+
240
+ if tool_result_role == "user":
241
+ return turns
242
+
243
+ if tool_result_role != "assistant":
244
+ raise ValueError(
245
+ f"Expected `tool_result_role` to be one of 'user' or 'assistant', not '{tool_result_role}'"
246
+ )
247
+
248
+ # If a turn is purely a tool result, change its role
249
+ turns2 = copy.deepcopy(turns)
250
+ for turn in turns2:
251
+ if all(isinstance(c, ContentToolResult) for c in turn.contents):
252
+ turn.role = tool_result_role
253
+
254
+ # If two consecutive turns have the same role (i.e., assistant), collapse them into one
255
+ final_turns: list[Turn[CompletionT]] = []
256
+ for x in turns2:
257
+ if not final_turns:
258
+ final_turns.append(x)
259
+ continue
260
+ if x.role != final_turns[-1].role:
261
+ final_turns.append(x)
262
+ else:
263
+ final_turns[-1].contents.extend(x.contents)
264
+
265
+ return final_turns
228
266
 
229
267
  def get_last_turn(
230
268
  self,
@@ -531,7 +569,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
531
569
  args
532
570
  The input to get a token count for.
533
571
  data_model
534
- If the input is meant for data extraction (i.e., `.extract_data()`), then
572
+ If the input is meant for data extraction (i.e., `.chat_structured()`), then
535
573
  this should be the Pydantic model that describes the structure of the data to
536
574
  extract.
537
575
 
@@ -585,7 +623,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
585
623
  args
586
624
  The input to get a token count for.
587
625
  data_model
588
- If this input is meant for data extraction (i.e., `.extract_data_async()`),
626
+ If this input is meant for data extraction (i.e., `.chat_structured_async()`),
589
627
  then this should be the Pydantic model that describes the structure of the data
590
628
  to extract.
591
629
 
@@ -608,6 +646,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
608
646
  port: int = 0,
609
647
  host: str = "127.0.0.1",
610
648
  launch_browser: bool = True,
649
+ bookmark_store: Literal["url", "server", "disable"] = "url",
611
650
  bg_thread: Optional[bool] = None,
612
651
  echo: Optional[EchoOptions] = None,
613
652
  content: Literal["text", "all"] = "all",
@@ -626,6 +665,12 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
626
665
  The host to run the app on (the default is "127.0.0.1").
627
666
  launch_browser
628
667
  Whether to launch a browser window.
668
+ bookmark_store
669
+ One of the following (default is "url"):
670
+ - `"url"`: Store bookmarks in the URL (default).
671
+ - `"server"`: Store bookmarks on the server (requires a server-side
672
+ storage backend).
673
+ - `"disable"`: Disable bookmarking.
629
674
  bg_thread
630
675
  Whether to run the app in a background thread. If `None`, the app will
631
676
  run in a background thread if the current environment is a notebook.
@@ -647,24 +692,37 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
647
692
  from shiny import App, run_app, ui
648
693
  except ImportError:
649
694
  raise ImportError(
650
- "The `shiny` package is required for the `browser` method. "
695
+ "The `shiny` package is required for the `app()` method. "
651
696
  "Install it with `pip install shiny`."
652
697
  )
653
698
 
654
- app_ui = ui.page_fillable(
655
- ui.chat_ui("chat"),
656
- fillable_mobile=True,
657
- )
699
+ try:
700
+ from shinychat import (
701
+ Chat,
702
+ chat_ui,
703
+ message_content, # pyright: ignore[reportAttributeAccessIssue]
704
+ )
705
+ except ImportError:
706
+ raise ImportError(
707
+ "The `shinychat` package is required for the `app()` method. "
708
+ "Install it with `pip install shinychat`."
709
+ )
658
710
 
659
- def server(input): # noqa: A002
660
- chat = ui.Chat(
661
- "chat",
662
- messages=[
663
- {"role": turn.role, "content": turn.text}
664
- for turn in self.get_turns()
665
- ],
711
+ messages = [
712
+ message_content(x) for x in self.get_turns(tool_result_role="assistant")
713
+ ]
714
+
715
+ def app_ui(x):
716
+ return ui.page_fillable(
717
+ chat_ui("chat", messages=messages),
718
+ fillable_mobile=True,
666
719
  )
667
720
 
721
+ def server(input): # noqa: A002
722
+ chat = Chat("chat")
723
+
724
+ chat.enable_bookmarking(self)
725
+
668
726
  @chat.on_user_submit
669
727
  async def _(user_input: str):
670
728
  if stream:
@@ -688,7 +746,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
688
746
  )
689
747
  )
690
748
 
691
- app = App(app_ui, server)
749
+ app = App(app_ui, server, bookmark_store=bookmark_store)
692
750
 
693
751
  def _run_app():
694
752
  run_app(app, launch_browser=launch_browser, port=port, host=host)
@@ -997,20 +1055,22 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
997
1055
 
998
1056
  return wrapper()
999
1057
 
1000
- def extract_data(
1058
+ def chat_structured(
1001
1059
  self,
1002
1060
  *args: Content | str,
1003
- data_model: type[BaseModel],
1061
+ data_model: type[BaseModelT],
1004
1062
  echo: EchoOptions = "none",
1005
1063
  stream: bool = False,
1006
- ) -> dict[str, Any]:
1064
+ ) -> BaseModelT:
1007
1065
  """
1008
- Extract structured data from the given input.
1066
+ Extract structured data.
1009
1067
 
1010
1068
  Parameters
1011
1069
  ----------
1012
1070
  args
1013
- The input to extract data from.
1071
+ The input to send to the chatbot. This is typically the text you
1072
+ want to extract data from, but it can be omitted if the data is
1073
+ obvious from the existing conversation.
1014
1074
  data_model
1015
1075
  A Pydantic model describing the structure of the data to extract.
1016
1076
  echo
@@ -1024,10 +1084,47 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
1024
1084
 
1025
1085
  Returns
1026
1086
  -------
1027
- dict[str, Any]
1028
- The extracted data.
1087
+ BaseModelT
1088
+ An instance of the provided `data_model` containing the extracted data.
1029
1089
  """
1090
+ dat = self._submit_and_extract_data(
1091
+ *args,
1092
+ data_model=data_model,
1093
+ echo=echo,
1094
+ stream=stream,
1095
+ )
1096
+ return data_model.model_validate(dat)
1030
1097
 
1098
+ def extract_data(
1099
+ self,
1100
+ *args: Content | str,
1101
+ data_model: type[BaseModel],
1102
+ echo: EchoOptions = "none",
1103
+ stream: bool = False,
1104
+ ) -> dict[str, Any]:
1105
+ """
1106
+ Deprecated: use `.chat_structured()` instead.
1107
+ """
1108
+ warnings.warn(
1109
+ "The `extract_data()` method is deprecated and will be removed in a future release. "
1110
+ "Use the `chat_structured()` method instead.",
1111
+ DeprecationWarning,
1112
+ stacklevel=2,
1113
+ )
1114
+ return self._submit_and_extract_data(
1115
+ *args,
1116
+ data_model=data_model,
1117
+ echo=echo,
1118
+ stream=stream,
1119
+ )
1120
+
1121
+ def _submit_and_extract_data(
1122
+ self,
1123
+ *args: Content | str,
1124
+ data_model: type[BaseModel],
1125
+ echo: EchoOptions = "none",
1126
+ stream: bool = False,
1127
+ ) -> dict[str, Any]:
1031
1128
  display = self._markdown_display(echo=echo)
1032
1129
 
1033
1130
  response = ChatResponse(
@@ -1046,33 +1143,24 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
1046
1143
  turn = self.get_last_turn()
1047
1144
  assert turn is not None
1048
1145
 
1049
- res: list[ContentJson] = []
1050
- for x in turn.contents:
1051
- if isinstance(x, ContentJson):
1052
- res.append(x)
1053
-
1054
- if len(res) != 1:
1055
- raise ValueError(
1056
- f"Data extraction failed: {len(res)} data results received."
1057
- )
1058
-
1059
- json = res[0]
1060
- return json.value
1146
+ return Chat._extract_turn_json(turn)
1061
1147
 
1062
- async def extract_data_async(
1148
+ async def chat_structured_async(
1063
1149
  self,
1064
1150
  *args: Content | str,
1065
- data_model: type[BaseModel],
1151
+ data_model: type[BaseModelT],
1066
1152
  echo: EchoOptions = "none",
1067
1153
  stream: bool = False,
1068
- ) -> dict[str, Any]:
1154
+ ) -> BaseModelT:
1069
1155
  """
1070
1156
  Extract structured data from the given input asynchronously.
1071
1157
 
1072
1158
  Parameters
1073
1159
  ----------
1074
1160
  args
1075
- The input to extract data from.
1161
+ The input to send to the chatbot. This is typically the text you
1162
+ want to extract data from, but it can be omitted if the data is
1163
+ obvious from the existing conversation.
1076
1164
  data_model
1077
1165
  A Pydantic model describing the structure of the data to extract.
1078
1166
  echo
@@ -1087,10 +1175,47 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
1087
1175
 
1088
1176
  Returns
1089
1177
  -------
1090
- dict[str, Any]
1091
- The extracted data.
1178
+ BaseModelT
1179
+ An instance of the provided `data_model` containing the extracted data.
1092
1180
  """
1181
+ dat = await self._submit_and_extract_data_async(
1182
+ *args,
1183
+ data_model=data_model,
1184
+ echo=echo,
1185
+ stream=stream,
1186
+ )
1187
+ return data_model.model_validate(dat)
1093
1188
 
1189
+ async def extract_data_async(
1190
+ self,
1191
+ *args: Content | str,
1192
+ data_model: type[BaseModel],
1193
+ echo: EchoOptions = "none",
1194
+ stream: bool = False,
1195
+ ) -> dict[str, Any]:
1196
+ """
1197
+ Deprecated: use `.chat_structured_async()` instead.
1198
+ """
1199
+ warnings.warn(
1200
+ "The `extract_data_async()` method is deprecated and will be removed in a future release. "
1201
+ "Use the `chat_structured_async()` method instead.",
1202
+ DeprecationWarning,
1203
+ stacklevel=2,
1204
+ )
1205
+ return await self._submit_and_extract_data_async(
1206
+ *args,
1207
+ data_model=data_model,
1208
+ echo=echo,
1209
+ stream=stream,
1210
+ )
1211
+
1212
+ async def _submit_and_extract_data_async(
1213
+ self,
1214
+ *args: Content | str,
1215
+ data_model: type[BaseModel],
1216
+ echo: EchoOptions = "none",
1217
+ stream: bool = False,
1218
+ ) -> dict[str, Any]:
1094
1219
  display = self._markdown_display(echo=echo)
1095
1220
 
1096
1221
  response = ChatResponseAsync(
@@ -1109,6 +1234,10 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
1109
1234
  turn = self.get_last_turn()
1110
1235
  assert turn is not None
1111
1236
 
1237
+ return Chat._extract_turn_json(turn)
1238
+
1239
+ @staticmethod
1240
+ def _extract_turn_json(turn: Turn) -> dict[str, Any]:
1112
1241
  res: list[ContentJson] = []
1113
1242
  for x in turn.contents:
1114
1243
  if isinstance(x, ContentJson):
chatlas/_content.py CHANGED
@@ -603,7 +603,7 @@ class ContentJson(Content):
603
603
  JSON content
604
604
 
605
605
  This content type primarily exists to signal structured data extraction
606
- (i.e., data extracted via [](`~chatlas.Chat`)'s `.extract_data()` method)
606
+ (i.e., data extracted via [](`~chatlas.Chat`)'s `.chat_structured()` method)
607
607
 
608
608
  Parameters
609
609
  ----------
@@ -630,7 +630,7 @@ class ContentPDF(Content):
630
630
  PDF content
631
631
 
632
632
  This content type primarily exists to signal PDF data extraction
633
- (i.e., data extracted via [](`~chatlas.Chat`)'s `.extract_data()` method)
633
+ (i.e., data extracted via [](`~chatlas.Chat`)'s `.chat_structured()` method)
634
634
 
635
635
  Parameters
636
636
  ----------
chatlas/_provider.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  from abc import ABC, abstractmethod
4
4
  from datetime import date
5
5
  from typing import (
6
+ Any,
6
7
  AsyncIterable,
7
8
  Generic,
8
9
  Iterable,
@@ -100,6 +101,16 @@ StandardModelParamNames = Literal[
100
101
  ]
101
102
 
102
103
 
104
+ # Provider-agnostic batch status info
105
+ class BatchStatus(BaseModel):
106
+ """Status information for a batch job."""
107
+
108
+ working: bool
109
+ n_processing: int
110
+ n_succeeded: int
111
+ n_failed: int
112
+
113
+
103
114
  class Provider(
104
115
  ABC,
105
116
  Generic[
@@ -261,3 +272,80 @@ class Provider(
261
272
 
262
273
  @abstractmethod
263
274
  def supported_model_params(self) -> set[StandardModelParamNames]: ...
275
+
276
+ def has_batch_support(self) -> bool:
277
+ """
278
+ Returns whether this provider supports batch processing.
279
+ Override this method to return True for providers that implement batch methods.
280
+ """
281
+ return False
282
+
283
+ def batch_submit(
284
+ self,
285
+ conversations: list[list[Turn]],
286
+ data_model: Optional[type[BaseModel]] = None,
287
+ ) -> dict[str, Any]:
288
+ """
289
+ Submit a batch of conversations for processing.
290
+
291
+ Args:
292
+ conversations: List of conversation histories (each is a list of Turns)
293
+ data_model: Optional structured data model for responses
294
+
295
+ Returns:
296
+ BatchInfo containing batch job information
297
+ """
298
+ raise NotImplementedError("This provider does not support batch processing")
299
+
300
+ def batch_poll(self, batch: dict[str, Any]) -> dict[str, Any]:
301
+ """
302
+ Poll the status of a submitted batch.
303
+
304
+ Args:
305
+ batch: Batch information returned from batch_submit
306
+
307
+ Returns:
308
+ Updated batch information
309
+ """
310
+ raise NotImplementedError("This provider does not support batch processing")
311
+
312
+ def batch_status(self, batch: dict[str, Any]) -> BatchStatus:
313
+ """
314
+ Get the status of a batch.
315
+
316
+ Args:
317
+ batch: Batch information
318
+
319
+ Returns:
320
+ BatchStatus with processing status information
321
+ """
322
+ raise NotImplementedError("This provider does not support batch processing")
323
+
324
+ def batch_retrieve(self, batch: dict[str, Any]) -> list[dict[str, Any]]:
325
+ """
326
+ Retrieve results from a completed batch.
327
+
328
+ Args:
329
+ batch: Batch information
330
+
331
+ Returns:
332
+ List of BatchResult objects, one for each request in the batch
333
+ """
334
+ raise NotImplementedError("This provider does not support batch processing")
335
+
336
+ def batch_result_turn(
337
+ self,
338
+ result: dict[str, Any],
339
+ has_data_model: bool = False,
340
+ ) -> Turn | None:
341
+ """
342
+ Convert a batch result to a Turn.
343
+
344
+ Args:
345
+ result: Individual BatchResult from batch_retrieve
346
+ has_data_model: Whether the request used a structured data model
347
+
348
+ Returns:
349
+ Turn object or None if the result was an error
350
+ """
351
+ raise NotImplementedError("This provider does not support batch processing")
@@ -1,10 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import base64
4
+ import re
4
5
  import warnings
5
6
  from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast, overload
6
7
 
7
8
  import orjson
9
+ from openai.types.chat import ChatCompletionToolParam
8
10
  from pydantic import BaseModel
9
11
 
10
12
  from ._chat import Chat
@@ -21,7 +23,13 @@ from ._content import (
21
23
  ContentToolResultResource,
22
24
  )
23
25
  from ._logging import log_model_default
24
- from ._provider import ModelInfo, Provider, StandardModelParamNames, StandardModelParams
26
+ from ._provider import (
27
+ BatchStatus,
28
+ ModelInfo,
29
+ Provider,
30
+ StandardModelParamNames,
31
+ StandardModelParams,
32
+ )
25
33
  from ._tokens import get_token_pricing, tokens_log
26
34
  from ._tools import Tool, basemodel_to_param_schema
27
35
  from ._turn import Turn, user_turn
@@ -38,11 +46,12 @@ if TYPE_CHECKING:
38
46
  )
39
47
  from anthropic.types.document_block_param import DocumentBlockParam
40
48
  from anthropic.types.image_block_param import ImageBlockParam
49
+ from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
50
+ from anthropic.types.messages.batch_create_params import Request as BatchRequest
41
51
  from anthropic.types.model_param import ModelParam
42
52
  from anthropic.types.text_block_param import TextBlockParam
43
53
  from anthropic.types.tool_result_block_param import ToolResultBlockParam
44
54
  from anthropic.types.tool_use_block_param import ToolUseBlockParam
45
- from openai.types.chat import ChatCompletionToolParam
46
55
 
47
56
  from .types.anthropic import ChatBedrockClientArgs, ChatClientArgs, SubmitInputArgs
48
57
 
@@ -631,6 +640,101 @@ class AnthropicProvider(
631
640
  completion=completion,
632
641
  )
633
642
 
643
+ def has_batch_support(self) -> bool:
644
+ return True
645
+
646
+ def batch_submit(
647
+ self,
648
+ conversations: list[list[Turn]],
649
+ data_model: Optional[type[BaseModel]] = None,
650
+ ):
651
+ from anthropic import NotGiven
652
+
653
+ requests: list["BatchRequest"] = []
654
+
655
+ for i, turns in enumerate(conversations):
656
+ kwargs = self._chat_perform_args(
657
+ stream=False,
658
+ turns=turns,
659
+ tools={},
660
+ data_model=data_model,
661
+ )
662
+
663
+ params: "MessageCreateParamsNonStreaming" = {
664
+ "messages": kwargs.get("messages", {}),
665
+ "model": self.model,
666
+ "max_tokens": kwargs.get("max_tokens", 4096),
667
+ }
668
+
669
+ # If data_model, tools/tool_choice should be present
670
+ tools = kwargs.get("tools")
671
+ tool_choice = kwargs.get("tool_choice")
672
+ if tools and not isinstance(tools, NotGiven):
673
+ params["tools"] = tools
674
+ if tool_choice and not isinstance(tool_choice, NotGiven):
675
+ params["tool_choice"] = tool_choice
676
+
677
+ requests.append({"custom_id": f"request-{i}", "params": params})
678
+
679
+ batch = self._client.messages.batches.create(requests=requests)
680
+ return batch.model_dump()
681
+
682
+ def batch_poll(self, batch):
683
+ from anthropic.types.messages import MessageBatch
684
+
685
+ batch = MessageBatch.model_validate(batch)
686
+ b = self._client.messages.batches.retrieve(batch.id)
687
+ return b.model_dump()
688
+
689
+ def batch_status(self, batch) -> "BatchStatus":
690
+ from anthropic.types.messages import MessageBatch
691
+
692
+ batch = MessageBatch.model_validate(batch)
693
+ status = batch.processing_status
694
+ counts = batch.request_counts
695
+
696
+ return BatchStatus(
697
+ working=status != "ended",
698
+ n_processing=counts.processing,
699
+ n_succeeded=counts.succeeded,
700
+ n_failed=counts.errored + counts.canceled + counts.expired,
701
+ )
702
+
703
+ # https://docs.anthropic.com/en/api/retrieving-message-batch-results
704
+ def batch_retrieve(self, batch):
705
+ from anthropic.types.messages import MessageBatch
706
+
707
+ batch = MessageBatch.model_validate(batch)
708
+ if batch.results_url is None:
709
+ raise ValueError("Batch has no results URL")
710
+
711
+ results: list[dict[str, Any]] = []
712
+ for res in self._client.messages.batches.results(batch.id):
713
+ results.append(res.model_dump())
714
+
715
+ # Sort by custom_id to maintain order
716
+ def extract_id(x: str):
717
+ match = re.search(r"-(\d+)$", x)
718
+ return int(match.group(1)) if match else 0
719
+
720
+ results.sort(key=lambda x: extract_id(x.get("custom_id", "")))
721
+
722
+ return results
723
+
724
+ def batch_result_turn(self, result, has_data_model: bool = False) -> Turn | None:
725
+ from anthropic.types.messages.message_batch_individual_response import (
726
+ MessageBatchIndividualResponse,
727
+ )
728
+
729
+ result = MessageBatchIndividualResponse.model_validate(result)
730
+ if result.result.type != "succeeded":
731
+ # TODO: offer advice on what to do?
732
+ warnings.warn(f"Batch request didn't succeed: {result.result}")
733
+ return None
734
+
735
+ message = result.result.message
736
+ return self._as_turn(message, has_data_model)
737
+
634
738
 
635
739
  def ChatBedrockAnthropic(
636
740
  *,
@@ -141,7 +141,7 @@ def ChatGithub(
141
141
 
142
142
  class GitHubProvider(OpenAIProvider):
143
143
  def __init__(self, base_url: str, **kwargs):
144
- super().__init__(**kwargs)
144
+ super().__init__(base_url=base_url, **kwargs)
145
145
  self._base_url = base_url
146
146
 
147
147
  def list_models(self) -> list[ModelInfo]:
@@ -190,7 +190,7 @@ def list_models_gh_azure(base_url: str = "https://models.inference.ai.azure.com"
190
190
  for m in models:
191
191
  info: ModelInfo = {
192
192
  "id": m["name"],
193
- "provider": m["publisher"]
193
+ "provider": m["publisher"],
194
194
  }
195
195
  res.append(info)
196
196