chatlas 0.11.1__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of chatlas might be problematic.
- chatlas/__init__.py +10 -0
- chatlas/_auto.py +173 -61
- chatlas/_batch_chat.py +211 -0
- chatlas/_batch_job.py +234 -0
- chatlas/_chat.py +181 -43
- chatlas/_content.py +13 -8
- chatlas/_provider.py +88 -0
- chatlas/_provider_anthropic.py +106 -2
- chatlas/_provider_openai.py +143 -12
- chatlas/_tools.py +11 -3
- chatlas/_version.py +2 -2
- chatlas/types/anthropic/_submit.py +2 -2
- chatlas/types/openai/_client.py +2 -2
- chatlas/types/openai/_client_azure.py +2 -2
- chatlas/types/openai/_submit.py +2 -2
- {chatlas-0.11.1.dist-info → chatlas-0.13.0.dist-info}/METADATA +2 -1
- {chatlas-0.11.1.dist-info → chatlas-0.13.0.dist-info}/RECORD +19 -17
- {chatlas-0.11.1.dist-info → chatlas-0.13.0.dist-info}/WHEEL +0 -0
- {chatlas-0.11.1.dist-info → chatlas-0.13.0.dist-info}/licenses/LICENSE +0 -0
chatlas/_batch_job.py
ADDED
@@ -0,0 +1,234 @@
+from __future__ import annotations
+
+import hashlib
+import json
+import time
+from datetime import timedelta
+from pathlib import Path
+from typing import Any, Literal, Optional, TypeVar, Union
+
+from pydantic import BaseModel
+from rich.console import Console
+from rich.progress import Progress, SpinnerColumn, TextColumn
+
+from ._chat import Chat
+from ._content import Content
+from ._provider import BatchStatus
+from ._turn import Turn, user_turn
+from ._typing_extensions import TypedDict
+
+BatchStage = Literal["submitting", "waiting", "retrieving", "done"]
+
+
+class BatchStateHash(TypedDict):
+    provider: str
+    model: str
+    prompts: str
+    user_turns: str
+
+
+class BatchState(BaseModel):
+    version: int
+    stage: BatchStage
+    batch: dict[str, Any]
+    results: list[dict[str, Any]]
+    started_at: int
+    hash: BatchStateHash
+
+
+ContentT = TypeVar("ContentT", bound=Union[str, Content])
+
+
+class BatchJob:
+    """
+    Manages the lifecycle of a batch processing job.
+
+    A batch job goes through several stages:
+    1. "submitting" - Initial submission to the provider
+    2. "waiting" - Waiting for processing to complete
+    3. "retrieving" - Downloading results
+    4. "done" - Processing complete
+    """
+
+    def __init__(
+        self,
+        chat: Chat,
+        prompts: list[ContentT] | list[list[ContentT]],
+        path: Union[str, Path],
+        data_model: Optional[type[BaseModel]] = None,
+        wait: bool = True,
+    ):
+        if not chat.provider.has_batch_support():
+            raise ValueError("Batch requests are not supported by this provider")
+
+        self.chat = chat
+        self.prompts = prompts
+        self.path = Path(path)
+        self.data_model = data_model
+        self.should_wait = wait
+
+        # Convert prompts to user turns
+        self.user_turns: list[Turn] = []
+        for prompt in prompts:
+            if not isinstance(prompt, (str, Content)):
+                turn = user_turn(*prompt)
+            else:
+                turn = user_turn(prompt)
+            self.user_turns.append(turn)
+
+        # Job state management
+        self.provider = chat.provider
+        self.stage: BatchStage = "submitting"
+        self.batch: dict[str, Any] = {}
+        self.results: list[dict[str, Any]] = []
+
+        # Load existing state if file exists and is not empty
+        if self.path.exists() and self.path.stat().st_size > 0:
+            self._load_state()
+        else:
+            self.started_at = time.time()
+
+    def _load_state(self) -> None:
+        with open(self.path, "r") as f:
+            state = BatchState.model_validate_json(f.read())
+
+        self.stage = state.stage
+        self.batch = state.batch
+        self.results = state.results
+        self.started_at = state.started_at
+
+        # Verify hash to ensure consistency
+        stored_hash = state.hash
+        current_hash = self._compute_hash()
+
+        for key, value in current_hash.items():
+            if stored_hash.get(key) != value:
+                raise ValueError(
+                    f"Batch state mismatch: {key} doesn't match stored value. "
+                    f"Do you need to pick a different path?"
+                )
+
+    def _save_state(self) -> None:
+        state = BatchState(
+            version=1,
+            stage=self.stage,
+            batch=self.batch,
+            results=self.results,
+            started_at=int(self.started_at) if self.started_at else 0,
+            hash=self._compute_hash(),
+        )
+
+        with open(self.path, "w") as f:
+            f.write(state.model_dump_json(indent=2))
+
+    def _compute_hash(self) -> BatchStateHash:
+        turns = self.chat.get_turns(include_system_prompt=True)
+        return {
+            "provider": self.provider.name,
+            "model": self.provider.model,
+            "prompts": self._hash([str(p) for p in self.prompts]),
+            "user_turns": self._hash([str(turn) for turn in turns]),
+        }
+
+    @staticmethod
+    def _hash(x: Any) -> str:
+        return hashlib.md5(json.dumps(x, sort_keys=True).encode()).hexdigest()
+
+    def step(self) -> bool:
+        if self.stage == "submitting":
+            return self._submit()
+        elif self.stage == "waiting":
+            return self._wait()
+        elif self.stage == "retrieving":
+            return self._retrieve()
+        else:
+            raise ValueError(f"Unknown stage: {self.stage}")
+
+    def step_until_done(self) -> Optional["BatchJob"]:
+        while self.stage != "done":
+            if not self.step():
+                return None
+        return self
+
+    def _submit(self) -> bool:
+        existing_turns = self.chat.get_turns(include_system_prompt=True)
+
+        conversations = []
+        for turn in self.user_turns:
+            conversation = existing_turns + [turn]
+            conversations.append(conversation)
+
+        self.batch = self.provider.batch_submit(conversations, self.data_model)
+        self.stage = "waiting"
+        self._save_state()
+        return True
+
+    def _wait(self) -> bool:
+        # Always poll once, even when wait=False
+        status = self._poll()
+
+        if self.should_wait:
+            console = Console()
+
+            with Progress(
+                SpinnerColumn(),
+                TextColumn("Processing..."),
+                TextColumn("[{task.fields[elapsed]}]"),
+                TextColumn("{task.fields[n_processing]} pending |"),
+                TextColumn("[green]{task.fields[n_succeeded]}[/green] done |"),
+                TextColumn("[red]{task.fields[n_failed]}[/red] failed"),
+                console=console,
+            ) as progress:
+                task = progress.add_task(
+                    "processing",
+                    elapsed=self._elapsed(),
+                    n_processing=status.n_processing,
+                    n_succeeded=status.n_succeeded,
+                    n_failed=status.n_failed,
+                )
+
+                while status.working:
+                    time.sleep(0.5)
+                    status = self._poll()
+                    progress.update(
+                        task,
+                        elapsed=self._elapsed(),
+                        n_processing=status.n_processing,
+                        n_succeeded=status.n_succeeded,
+                        n_failed=status.n_failed,
+                    )
+
+        if not status.working:
+            self.stage = "retrieving"
+            self._save_state()
+            return True
+        else:
+            return False
+
+    def _poll(self) -> "BatchStatus":
+        if not self.batch:
+            raise ValueError("No batch to poll")
+        self.batch = self.provider.batch_poll(self.batch)
+        self._save_state()
+        return self.provider.batch_status(self.batch)
+
+    def _elapsed(self) -> str:
+        return str(timedelta(seconds=int(time.time()) - int(self.started_at)))
+
+    def _retrieve(self) -> bool:
+        if not self.batch:
+            raise ValueError("No batch to retrieve")
+        self.results = self.provider.batch_retrieve(self.batch)
+        self.stage = "done"
+        self._save_state()
+        return True
+
+    def result_turns(self) -> list[Turn | None]:
+        turns = []
+        for result in self.results:
+            turn = self.provider.batch_result_turn(
+                result, has_data_model=self.data_model is not None
+            )
+            turns.append(turn)
+
+        return turns
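
The class above is essentially a resumable state machine around the provider's batch API. A rough usage sketch follows (hedged: `_batch_job` is an internal module, and the new `chatlas/_batch_chat.py` in the file list suggests a higher-level public entry point whose API is not shown in this section; the provider, model name, and prompts are placeholders):

    from chatlas import ChatOpenAI

    from chatlas._batch_job import BatchJob  # internal module, shown in full above

    chat = ChatOpenAI(model="gpt-4.1-nano")  # placeholder model

    # Each prompt becomes its own conversation layered on the chat's existing
    # turns. Progress is checkpointed to the JSON file, so re-running the same
    # code resumes an interrupted job instead of resubmitting it.
    job = BatchJob(
        chat,
        prompts=["What is 1 + 1?", "Name a planet in our solar system."],
        path="answers_batch.json",
    )

    if job.step_until_done() is not None:
        for turn in job.result_turns():
            print(turn)  # None for requests that failed

Note that `_compute_hash()` folds the provider name, model, prompts, and existing turns into the saved state, so pointing a different job at the same path raises a "Batch state mismatch" error rather than silently mixing results.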
chatlas/_chat.py
CHANGED
@@ -78,6 +78,7 @@ CompletionT = TypeVar("CompletionT")
 EchoOptions = Literal["output", "all", "none", "text"]
 
 T = TypeVar("T")
+BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
 
 
 def is_present(value: T | None | MISSING_TYPE) -> TypeGuard[T]:
@@ -209,6 +210,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         self,
         *,
         include_system_prompt: bool = False,
+        tool_result_role: Literal["assistant", "user"] = "user",
     ) -> list[Turn[CompletionT]]:
         """
         Get all the turns (i.e., message contents) in the chat.
@@ -217,14 +219,50 @@
         ----------
         include_system_prompt
             Whether to include the system prompt in the turns.
+        tool_result_role
+            The role to assign to turns containing tool results. By default,
+            tool results are assigned a role of "user" since they represent
+            information provided to the assistant. If set to "assistant" tool
+            result content (plus the surrounding assistant turn contents) is
+            collected into a single assistant turn. This is convenient for
+            display purposes and more generally if you want the tool calling
+            loop to be contained in a single turn.
         """
 
         if not self._turns:
             return self._turns
 
         if not include_system_prompt and self._turns[0].role == "system":
-
-
+            turns = self._turns[1:]
+        else:
+            turns = self._turns
+
+        if tool_result_role == "user":
+            return turns
+
+        if tool_result_role != "assistant":
+            raise ValueError(
+                f"Expected `tool_result_role` to be one of 'user' or 'assistant', not '{tool_result_role}'"
+            )
+
+        # If a turn is purely a tool result, change its role
+        turns2 = copy.deepcopy(turns)
+        for turn in turns2:
+            if all(isinstance(c, ContentToolResult) for c in turn.contents):
+                turn.role = tool_result_role
+
+        # If two consecutive turns have the same role (i.e., assistant), collapse them into one
+        final_turns: list[Turn[CompletionT]] = []
+        for x in turns2:
+            if not final_turns:
+                final_turns.append(x)
+                continue
+            if x.role != final_turns[-1].role:
+                final_turns.append(x)
+            else:
+                final_turns[-1].contents.extend(x.contents)
+
+        return final_turns
 
     def get_last_turn(
         self,
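
Aside: a quick sketch of the new `tool_result_role` parameter in use (the provider, tool, and prompt are illustrative):

    from chatlas import ChatOpenAI

    chat = ChatOpenAI()  # any provider


    def get_answer() -> str:
        """Return the answer to everything."""
        return "42"


    chat.register_tool(get_answer)
    chat.chat("Call get_answer and report what it says.")

    # Default: tool results stay in their own "user" turns
    turns = chat.get_turns()

    # New: fold each tool-calling loop into a single assistant turn, e.g. for
    # display (the reworked app() method below relies on this)
    display_turns = chat.get_turns(tool_result_role="assistant")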
@@ -531,7 +569,7 @@
         args
             The input to get a token count for.
         data_model
-            If the input is meant for data extraction (i.e., `.
+            If the input is meant for data extraction (i.e., `.chat_structured()`), then
             this should be the Pydantic model that describes the structure of the data to
             extract.
 
@@ -585,7 +623,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         args
             The input to get a token count for.
         data_model
-            If this input is meant for data extraction (i.e., `.
+            If this input is meant for data extraction (i.e., `.chat_structured_async()`),
             then this should be the Pydantic model that describes the structure of the data
             to extract.
 
@@ -608,6 +646,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         port: int = 0,
         host: str = "127.0.0.1",
         launch_browser: bool = True,
+        bookmark_store: Literal["url", "server", "disable"] = "url",
         bg_thread: Optional[bool] = None,
         echo: Optional[EchoOptions] = None,
         content: Literal["text", "all"] = "all",
@@ -626,6 +665,12 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
             The host to run the app on (the default is "127.0.0.1").
         launch_browser
             Whether to launch a browser window.
+        bookmark_store
+            One of the following (default is "url"):
+            - `"url"`: Store bookmarks in the URL (default).
+            - `"server"`: Store bookmarks on the server (requires a server-side
+              storage backend).
+            - `"disable"`: Disable bookmarking.
         bg_thread
             Whether to run the app in a background thread. If `None`, the app will
             run in a background thread if the current environment is a notebook.
@@ -647,24 +692,37 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
             from shiny import App, run_app, ui
         except ImportError:
             raise ImportError(
-                "The `shiny` package is required for the `
+                "The `shiny` package is required for the `app()` method. "
                 "Install it with `pip install shiny`."
             )
 
-
-
-
-
+        try:
+            from shinychat import (
+                Chat,
+                chat_ui,
+                message_content,  # pyright: ignore[reportAttributeAccessIssue]
+            )
+        except ImportError:
+            raise ImportError(
+                "The `shinychat` package is required for the `app()` method. "
+                "Install it with `pip install shinychat`."
+            )
 
-
-
-
-
-
-
+        messages = [
+            message_content(x) for x in self.get_turns(tool_result_role="assistant")
+        ]
+
+        def app_ui(x):
+            return ui.page_fillable(
+                chat_ui("chat", messages=messages),
+                fillable_mobile=True,
             )
 
+        def server(input):  # noqa: A002
+            chat = Chat("chat")
+
+            chat.enable_bookmarking(self)
+
             @chat.on_user_submit
             async def _(user_input: str):
                 if stream:
@@ -688,7 +746,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
                 )
             )
 
-        app = App(app_ui, server)
+        app = App(app_ui, server, bookmark_store=bookmark_store)
 
         def _run_app():
             run_app(app, launch_browser=launch_browser, port=port, host=host)
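
The new `bookmark_store` argument is passed straight through to Shiny's `App`, and the UI is now rebuilt from `get_turns(tool_result_role="assistant")` via the new `shinychat` dependency. A one-line sketch (assumes `shiny` and `shinychat` are installed and `chat` is an existing `Chat` instance):

    # Keep bookmark state on the Shiny server instead of in the URL;
    # "url" is the default and "disable" turns bookmarking off
    chat.app(bookmark_store="server")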
@@ -997,20 +1055,22 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
 
         return wrapper()
 
-    def
+    def chat_structured(
         self,
         *args: Content | str,
-        data_model: type[
+        data_model: type[BaseModelT],
         echo: EchoOptions = "none",
         stream: bool = False,
-    ) ->
+    ) -> BaseModelT:
         """
-        Extract structured data
+        Extract structured data.
 
         Parameters
         ----------
         args
-            The input to
+            The input to send to the chatbot. This is typically the text you
+            want to extract data from, but it can be omitted if the data is
+            obvious from the existing conversation.
         data_model
             A Pydantic model describing the structure of the data to extract.
         echo
@@ -1024,10 +1084,47 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
 
         Returns
         -------
-
-
+        BaseModelT
+            An instance of the provided `data_model` containing the extracted data.
+        """
+        dat = self._submit_and_extract_data(
+            *args,
+            data_model=data_model,
+            echo=echo,
+            stream=stream,
+        )
+        return data_model.model_validate(dat)
+
+    def extract_data(
+        self,
+        *args: Content | str,
+        data_model: type[BaseModel],
+        echo: EchoOptions = "none",
+        stream: bool = False,
+    ) -> dict[str, Any]:
         """
+        Deprecated: use `.chat_structured()` instead.
+        """
+        warnings.warn(
+            "The `extract_data()` method is deprecated and will be removed in a future release. "
+            "Use the `chat_structured()` method instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self._submit_and_extract_data(
+            *args,
+            data_model=data_model,
+            echo=echo,
+            stream=stream,
+        )
 
+    def _submit_and_extract_data(
+        self,
+        *args: Content | str,
+        data_model: type[BaseModel],
+        echo: EchoOptions = "none",
+        stream: bool = False,
+    ) -> dict[str, Any]:
         display = self._markdown_display(echo=echo)
 
         response = ChatResponse(
@@ -1046,33 +1143,24 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         turn = self.get_last_turn()
         assert turn is not None
 
-
-        for x in turn.contents:
-            if isinstance(x, ContentJson):
-                res.append(x)
-
-        if len(res) != 1:
-            raise ValueError(
-                f"Data extraction failed: {len(res)} data results received."
-            )
-
-        json = res[0]
-        return json.value
+        return Chat._extract_turn_json(turn)
 
-    async def
+    async def chat_structured_async(
         self,
         *args: Content | str,
-        data_model: type[
+        data_model: type[BaseModelT],
         echo: EchoOptions = "none",
         stream: bool = False,
-    ) ->
+    ) -> BaseModelT:
         """
         Extract structured data from the given input asynchronously.
 
         Parameters
         ----------
         args
-            The input to
+            The input to send to the chatbot. This is typically the text you
+            want to extract data from, but it can be omitted if the data is
+            obvious from the existing conversation.
         data_model
             A Pydantic model describing the structure of the data to extract.
         echo
@@ -1087,10 +1175,47 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
 
         Returns
         -------
-
-
+        BaseModelT
+            An instance of the provided `data_model` containing the extracted data.
+        """
+        dat = await self._submit_and_extract_data_async(
+            *args,
+            data_model=data_model,
+            echo=echo,
+            stream=stream,
+        )
+        return data_model.model_validate(dat)
+
+    async def extract_data_async(
+        self,
+        *args: Content | str,
+        data_model: type[BaseModel],
+        echo: EchoOptions = "none",
+        stream: bool = False,
+    ) -> dict[str, Any]:
         """
+        Deprecated: use `.chat_structured_async()` instead.
+        """
+        warnings.warn(
+            "The `extract_data_async()` method is deprecated and will be removed in a future release. "
+            "Use the `chat_structured_async()` method instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return await self._submit_and_extract_data_async(
+            *args,
+            data_model=data_model,
+            echo=echo,
+            stream=stream,
+        )
 
+    async def _submit_and_extract_data_async(
+        self,
+        *args: Content | str,
+        data_model: type[BaseModel],
+        echo: EchoOptions = "none",
+        stream: bool = False,
+    ) -> dict[str, Any]:
         display = self._markdown_display(echo=echo)
 
         response = ChatResponseAsync(
@@ -1109,6 +1234,10 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         turn = self.get_last_turn()
         assert turn is not None
 
+        return Chat._extract_turn_json(turn)
+
+    @staticmethod
+    def _extract_turn_json(turn: Turn) -> dict[str, Any]:
         res: list[ContentJson] = []
         for x in turn.contents:
             if isinstance(x, ContentJson):
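
A minimal sketch of the renamed API (the provider, model, and prompt are illustrative). The practical difference is the return type: `chat_structured()` returns a validated instance of the `data_model` (typed via the new `BaseModelT` TypeVar), while the deprecated `extract_data()` keeps returning a plain dict:

    from pydantic import BaseModel

    from chatlas import ChatOpenAI


    class Person(BaseModel):
        name: str
        age: int


    chat = ChatOpenAI()

    # New: returns a validated Person instance
    person = chat.chat_structured("John, age 15, won first prize", data_model=Person)
    print(person.name, person.age)

    # Old: still works, but emits a DeprecationWarning and returns a dict
    data = chat.extract_data("John, age 15, won first prize", data_model=Person)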
@@ -1535,7 +1664,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
 
     def register_tool(
         self,
-        func: Callable[..., Any] | Callable[..., Awaitable[Any]],
+        func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool,
         *,
         force: bool = False,
         name: Optional[str] = None,
@@ -1629,6 +1758,15 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]):
         ValueError
             If a tool with the same name already exists and `force` is `False`.
         """
+        if isinstance(func, Tool):
+            name = name or func.name
+            annotations = annotations or func.annotations
+            if model is not None:
+                func = Tool.from_func(
+                    func.func, name=name, model=model, annotations=annotations
+                )
+            func = func.func
+
         tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
         if tool.name in self._tools and not force:
             raise ValueError(
chatlas/_content.py
CHANGED
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union
 import orjson
 from pydantic import BaseModel, ConfigDict
 
-from ._typing_extensions import
+from ._typing_extensions import TypedDict
 
 if TYPE_CHECKING:
     from ._tools import Tool
@@ -24,16 +24,16 @@ class ToolAnnotations(TypedDict, total=False):
     received from untrusted servers.
     """
 
-    title:
+    title: str
     """A human-readable title for the tool."""
 
-    readOnlyHint:
+    readOnlyHint: bool
     """
     If true, the tool does not modify its environment.
     Default: false
     """
 
-    destructiveHint:
+    destructiveHint: bool
     """
     If true, the tool may perform destructive updates to its environment.
     If false, the tool performs only additive updates.
@@ -41,7 +41,7 @@ class ToolAnnotations(TypedDict, total=False):
     Default: true
     """
 
-    idempotentHint:
+    idempotentHint: bool
     """
     If true, calling the tool repeatedly with the same arguments
     will have no additional effect on the its environment.
@@ -49,7 +49,7 @@ class ToolAnnotations(TypedDict, total=False):
     Default: false
     """
 
-    openWorldHint:
+    openWorldHint: bool
    """
     If true, this tool may interact with an "open world" of external
     entities. If false, the tool's domain of interaction is closed.
@@ -58,6 +58,11 @@ class ToolAnnotations(TypedDict, total=False):
     Default: true
     """
 
+    extra: dict[str, Any]
+    """
+    Additional metadata about the tool.
+    """
+
 
 ImageContentTypes = Literal[
     "image/png",
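
Since `ToolAnnotations` is declared with `total=False`, every key, including the new `extra` field, is optional. A hedged sketch reusing the `current_time` tool from the earlier example (the import path is the internal module shown here; it may also be re-exported elsewhere):

    from chatlas._content import ToolAnnotations

    annotations: ToolAnnotations = {
        "title": "Current time",
        "readOnlyHint": True,  # does not modify its environment
        "idempotentHint": True,  # repeated calls add no extra effect
        "openWorldHint": False,  # no external entities involved
        "extra": {"owner": "platform-team"},  # new in this release
    }

    chat.register_tool(current_time, annotations=annotations)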
@@ -598,7 +603,7 @@ class ContentJson(Content):
     JSON content
 
     This content type primarily exists to signal structured data extraction
-    (i.e., data extracted via [](`~chatlas.Chat`)'s `.
+    (i.e., data extracted via [](`~chatlas.Chat`)'s `.chat_structured()` method)
 
     Parameters
     ----------
@@ -625,7 +630,7 @@ class ContentPDF(Content):
     PDF content
 
     This content type primarily exists to signal PDF data extraction
-    (i.e., data extracted via [](`~chatlas.Chat`)'s `.
+    (i.e., data extracted via [](`~chatlas.Chat`)'s `.chat_structured()` method)
 
     Parameters
     ----------