aidial-adapter-anthropic 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aidial_adapter_anthropic/_utils/json.py +116 -0
- aidial_adapter_anthropic/_utils/list.py +84 -0
- aidial_adapter_anthropic/_utils/pydantic.py +6 -0
- aidial_adapter_anthropic/_utils/resource.py +54 -0
- aidial_adapter_anthropic/_utils/text.py +4 -0
- aidial_adapter_anthropic/adapter/__init__.py +4 -0
- aidial_adapter_anthropic/adapter/_base.py +95 -0
- aidial_adapter_anthropic/adapter/_claude/adapter.py +549 -0
- aidial_adapter_anthropic/adapter/_claude/blocks.py +128 -0
- aidial_adapter_anthropic/adapter/_claude/citations.py +63 -0
- aidial_adapter_anthropic/adapter/_claude/config.py +39 -0
- aidial_adapter_anthropic/adapter/_claude/converters.py +303 -0
- aidial_adapter_anthropic/adapter/_claude/params.py +25 -0
- aidial_adapter_anthropic/adapter/_claude/state.py +45 -0
- aidial_adapter_anthropic/adapter/_claude/tokenizer/__init__.py +10 -0
- aidial_adapter_anthropic/adapter/_claude/tokenizer/anthropic.py +57 -0
- aidial_adapter_anthropic/adapter/_claude/tokenizer/approximate.py +260 -0
- aidial_adapter_anthropic/adapter/_claude/tokenizer/base.py +26 -0
- aidial_adapter_anthropic/adapter/_claude/tools.py +98 -0
- aidial_adapter_anthropic/adapter/_decorator/base.py +53 -0
- aidial_adapter_anthropic/adapter/_decorator/preprocess.py +63 -0
- aidial_adapter_anthropic/adapter/_decorator/replicator.py +32 -0
- aidial_adapter_anthropic/adapter/_errors.py +71 -0
- aidial_adapter_anthropic/adapter/_tokenize.py +12 -0
- aidial_adapter_anthropic/adapter/_truncate_prompt.py +168 -0
- aidial_adapter_anthropic/adapter/claude.py +17 -0
- aidial_adapter_anthropic/dial/_attachments.py +238 -0
- aidial_adapter_anthropic/dial/_lazy_stage.py +40 -0
- aidial_adapter_anthropic/dial/_message.py +341 -0
- aidial_adapter_anthropic/dial/consumer.py +235 -0
- aidial_adapter_anthropic/dial/request.py +170 -0
- aidial_adapter_anthropic/dial/resource.py +189 -0
- aidial_adapter_anthropic/dial/storage.py +138 -0
- aidial_adapter_anthropic/dial/token_usage.py +19 -0
- aidial_adapter_anthropic/dial/tools.py +180 -0
- aidial_adapter_anthropic-0.1.0.dist-info/LICENSE +202 -0
- aidial_adapter_anthropic-0.1.0.dist-info/METADATA +121 -0
- aidial_adapter_anthropic-0.1.0.dist-info/RECORD +39 -0
- aidial_adapter_anthropic-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,341 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import List, Optional, Self, Union
|
|
3
|
+
|
|
4
|
+
from aidial_sdk.chat_completion import (
|
|
5
|
+
Attachment,
|
|
6
|
+
CacheBreakpoint,
|
|
7
|
+
CustomContent,
|
|
8
|
+
FunctionCall,
|
|
9
|
+
)
|
|
10
|
+
from aidial_sdk.chat_completion import Message as DialMessage
|
|
11
|
+
from aidial_sdk.chat_completion import (
|
|
12
|
+
MessageContentPart,
|
|
13
|
+
MessageContentTextPart,
|
|
14
|
+
MessageCustomFields,
|
|
15
|
+
Role,
|
|
16
|
+
ToolCall,
|
|
17
|
+
)
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
|
|
20
|
+
from aidial_adapter_anthropic.adapter._errors import ValidationError
|
|
21
|
+
from aidial_adapter_anthropic.dial.request import (
|
|
22
|
+
collect_text_content,
|
|
23
|
+
is_plain_text_content,
|
|
24
|
+
is_system_role,
|
|
25
|
+
is_text_content,
|
|
26
|
+
to_message_content,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class MessageABC(ABC, BaseModel):
|
|
31
|
+
cache_breakpoint: CacheBreakpoint | None = None
|
|
32
|
+
|
|
33
|
+
@property
|
|
34
|
+
def custom_fields(self) -> MessageCustomFields | None:
|
|
35
|
+
if self.cache_breakpoint:
|
|
36
|
+
return MessageCustomFields(cache_breakpoint=self.cache_breakpoint)
|
|
37
|
+
return None
|
|
38
|
+
|
|
39
|
+
@abstractmethod
|
|
40
|
+
def to_message(self) -> DialMessage: ...
|
|
41
|
+
|
|
42
|
+
@classmethod
|
|
43
|
+
@abstractmethod
|
|
44
|
+
def from_message(cls, message: DialMessage) -> Self | None: ...
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class BaseMessageABC(MessageABC):
|
|
48
|
+
@property
|
|
49
|
+
@abstractmethod
|
|
50
|
+
def text_content(self) -> str: ...
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _get_cache_breakpoint(message: DialMessage) -> CacheBreakpoint | None:
|
|
54
|
+
if message.custom_fields is None:
|
|
55
|
+
return None
|
|
56
|
+
return message.custom_fields.cache_breakpoint
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class SystemMessage(BaseMessageABC):
|
|
60
|
+
content: str | List[MessageContentTextPart]
|
|
61
|
+
is_developer: bool = False
|
|
62
|
+
|
|
63
|
+
def to_message(self) -> DialMessage:
|
|
64
|
+
return DialMessage(
|
|
65
|
+
role=Role.DEVELOPER if self.is_developer else Role.SYSTEM,
|
|
66
|
+
content=to_message_content(self.content),
|
|
67
|
+
custom_fields=self.custom_fields,
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
@classmethod
|
|
71
|
+
def from_message(cls, message: DialMessage) -> Self | None:
|
|
72
|
+
if not is_system_role(message.role):
|
|
73
|
+
return None
|
|
74
|
+
|
|
75
|
+
content = message.content
|
|
76
|
+
|
|
77
|
+
if not is_text_content(content):
|
|
78
|
+
raise ValidationError(
|
|
79
|
+
"System message is expected to be a string or a list of text content parts"
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
return cls(
|
|
83
|
+
is_developer=message.role == Role.DEVELOPER,
|
|
84
|
+
cache_breakpoint=_get_cache_breakpoint(message),
|
|
85
|
+
content=content,
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
@property
|
|
89
|
+
def text_content(self) -> str:
|
|
90
|
+
return collect_text_content(self.content)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class HumanRegularMessage(BaseMessageABC):
|
|
94
|
+
content: str | List[MessageContentPart]
|
|
95
|
+
custom_content: CustomContent | None = None
|
|
96
|
+
|
|
97
|
+
def to_message(self) -> DialMessage:
|
|
98
|
+
return DialMessage(
|
|
99
|
+
role=Role.USER,
|
|
100
|
+
content=self.content,
|
|
101
|
+
custom_content=self.custom_content,
|
|
102
|
+
custom_fields=self.custom_fields,
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
@classmethod
|
|
106
|
+
def from_message(cls, message: DialMessage) -> Self | None:
|
|
107
|
+
if message.role != Role.USER:
|
|
108
|
+
return None
|
|
109
|
+
|
|
110
|
+
content = message.content
|
|
111
|
+
if content is None:
|
|
112
|
+
raise ValidationError(
|
|
113
|
+
"User message is expected to have content field"
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
return cls(
|
|
117
|
+
content=content,
|
|
118
|
+
custom_content=message.custom_content,
|
|
119
|
+
cache_breakpoint=_get_cache_breakpoint(message),
|
|
120
|
+
)
|
|
121
|
+
|
|
122
|
+
@property
|
|
123
|
+
def text_content(self) -> str:
|
|
124
|
+
return collect_text_content(self.content)
|
|
125
|
+
|
|
126
|
+
@property
|
|
127
|
+
def attachments(self) -> List[Attachment]:
|
|
128
|
+
return (
|
|
129
|
+
self.custom_content.attachments or [] if self.custom_content else []
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class HumanToolResultMessage(MessageABC):
|
|
134
|
+
id: str
|
|
135
|
+
content: str
|
|
136
|
+
|
|
137
|
+
def to_message(self) -> DialMessage:
|
|
138
|
+
return DialMessage(
|
|
139
|
+
role=Role.TOOL,
|
|
140
|
+
tool_call_id=self.id,
|
|
141
|
+
content=self.content,
|
|
142
|
+
custom_fields=self.custom_fields,
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
@classmethod
|
|
146
|
+
def from_message(cls, message: DialMessage) -> Self | None:
|
|
147
|
+
if message.role != Role.TOOL:
|
|
148
|
+
return None
|
|
149
|
+
|
|
150
|
+
if not is_plain_text_content(message.content):
|
|
151
|
+
raise ValidationError(
|
|
152
|
+
"The tool message shouldn't contain content parts"
|
|
153
|
+
)
|
|
154
|
+
|
|
155
|
+
if message.content is None or message.tool_call_id is None:
|
|
156
|
+
raise ValidationError(
|
|
157
|
+
"The tool message is expected to have content and tool_call_id fields"
|
|
158
|
+
)
|
|
159
|
+
|
|
160
|
+
return cls(
|
|
161
|
+
id=message.tool_call_id,
|
|
162
|
+
content=message.content,
|
|
163
|
+
cache_breakpoint=_get_cache_breakpoint(message),
|
|
164
|
+
)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
class HumanFunctionResultMessage(MessageABC):
|
|
168
|
+
name: str
|
|
169
|
+
content: str
|
|
170
|
+
|
|
171
|
+
def to_message(self) -> DialMessage:
|
|
172
|
+
return DialMessage(
|
|
173
|
+
role=Role.FUNCTION,
|
|
174
|
+
name=self.name,
|
|
175
|
+
content=self.content,
|
|
176
|
+
custom_fields=self.custom_fields,
|
|
177
|
+
)
|
|
178
|
+
|
|
179
|
+
@classmethod
|
|
180
|
+
def from_message(cls, message: DialMessage) -> Self | None:
|
|
181
|
+
if message.role != Role.FUNCTION:
|
|
182
|
+
return None
|
|
183
|
+
|
|
184
|
+
if not is_plain_text_content(message.content):
|
|
185
|
+
raise ValidationError(
|
|
186
|
+
"The function message shouldn't contain content parts"
|
|
187
|
+
)
|
|
188
|
+
|
|
189
|
+
if message.content is None or message.name is None:
|
|
190
|
+
raise ValidationError(
|
|
191
|
+
"The function message is expected to have content and name fields"
|
|
192
|
+
)
|
|
193
|
+
|
|
194
|
+
return cls(
|
|
195
|
+
name=message.name,
|
|
196
|
+
content=message.content,
|
|
197
|
+
cache_breakpoint=_get_cache_breakpoint(message),
|
|
198
|
+
)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
class AIRegularMessage(BaseMessageABC):
|
|
202
|
+
content: str | List[MessageContentPart]
|
|
203
|
+
"""
|
|
204
|
+
According to Azure OpenAI API, the assistant message could only have textual content.
|
|
205
|
+
However, we leave a loophole to provide image content parts just in case
|
|
206
|
+
one day multi-modal Bedrock models will be able to accept images in assistant messages.
|
|
207
|
+
"""
|
|
208
|
+
|
|
209
|
+
custom_content: Optional[CustomContent] = None
|
|
210
|
+
|
|
211
|
+
def to_message(self) -> DialMessage:
|
|
212
|
+
return DialMessage(
|
|
213
|
+
role=Role.ASSISTANT,
|
|
214
|
+
content=self.content, # type: ignore
|
|
215
|
+
custom_content=self.custom_content,
|
|
216
|
+
custom_fields=self.custom_fields,
|
|
217
|
+
)
|
|
218
|
+
|
|
219
|
+
@classmethod
|
|
220
|
+
def from_message(cls, message: DialMessage) -> Self | None:
|
|
221
|
+
if message.role != Role.ASSISTANT:
|
|
222
|
+
return None
|
|
223
|
+
|
|
224
|
+
if message.function_call is not None or message.tool_calls is not None:
|
|
225
|
+
return None
|
|
226
|
+
|
|
227
|
+
content = message.content
|
|
228
|
+
if content is None:
|
|
229
|
+
raise ValidationError(
|
|
230
|
+
"Assistant message is expected to have content field"
|
|
231
|
+
)
|
|
232
|
+
|
|
233
|
+
return cls(
|
|
234
|
+
content=content,
|
|
235
|
+
custom_content=message.custom_content,
|
|
236
|
+
cache_breakpoint=_get_cache_breakpoint(message),
|
|
237
|
+
)
|
|
238
|
+
|
|
239
|
+
@property
|
|
240
|
+
def text_content(self) -> str:
|
|
241
|
+
return collect_text_content(self.content)
|
|
242
|
+
|
|
243
|
+
@property
|
|
244
|
+
def attachments(self) -> List[Attachment]:
|
|
245
|
+
return (
|
|
246
|
+
self.custom_content.attachments or [] if self.custom_content else []
|
|
247
|
+
)
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
class AIToolCallMessage(MessageABC):
|
|
251
|
+
calls: List[ToolCall]
|
|
252
|
+
content: Optional[str] = None
|
|
253
|
+
custom_content: Optional[CustomContent] = None
|
|
254
|
+
|
|
255
|
+
def to_message(self) -> DialMessage:
|
|
256
|
+
return DialMessage(
|
|
257
|
+
role=Role.ASSISTANT,
|
|
258
|
+
content=self.content,
|
|
259
|
+
tool_calls=self.calls,
|
|
260
|
+
custom_content=self.custom_content,
|
|
261
|
+
custom_fields=self.custom_fields,
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
@classmethod
|
|
265
|
+
def from_message(cls, message: DialMessage) -> Self | None:
|
|
266
|
+
if message.role != Role.ASSISTANT:
|
|
267
|
+
return None
|
|
268
|
+
|
|
269
|
+
if message.tool_calls is None or message.function_call is not None:
|
|
270
|
+
return None
|
|
271
|
+
|
|
272
|
+
if not is_plain_text_content(message.content):
|
|
273
|
+
raise ValidationError(
|
|
274
|
+
"The assistant message with tool calls shouldn't contain content parts"
|
|
275
|
+
)
|
|
276
|
+
|
|
277
|
+
return cls(
|
|
278
|
+
calls=message.tool_calls,
|
|
279
|
+
content=message.content,
|
|
280
|
+
custom_content=message.custom_content,
|
|
281
|
+
cache_breakpoint=_get_cache_breakpoint(message),
|
|
282
|
+
)
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
class AIFunctionCallMessage(MessageABC):
|
|
286
|
+
call: FunctionCall
|
|
287
|
+
content: Optional[str] = None
|
|
288
|
+
|
|
289
|
+
def to_message(self) -> DialMessage:
|
|
290
|
+
return DialMessage(
|
|
291
|
+
role=Role.ASSISTANT,
|
|
292
|
+
content=self.content,
|
|
293
|
+
function_call=self.call,
|
|
294
|
+
custom_fields=self.custom_fields,
|
|
295
|
+
)
|
|
296
|
+
|
|
297
|
+
@classmethod
|
|
298
|
+
def from_message(cls, message: DialMessage) -> Self | None:
|
|
299
|
+
if message.role != Role.ASSISTANT:
|
|
300
|
+
return None
|
|
301
|
+
|
|
302
|
+
if message.function_call is None or message.tool_calls is not None:
|
|
303
|
+
return None
|
|
304
|
+
|
|
305
|
+
if not is_plain_text_content(message.content):
|
|
306
|
+
raise ValidationError(
|
|
307
|
+
"The assistant message with function call shouldn't contain content parts"
|
|
308
|
+
)
|
|
309
|
+
|
|
310
|
+
return cls(
|
|
311
|
+
call=message.function_call,
|
|
312
|
+
content=message.content,
|
|
313
|
+
cache_breakpoint=_get_cache_breakpoint(message),
|
|
314
|
+
)
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
BaseMessage = Union[SystemMessage, HumanRegularMessage, AIRegularMessage]
|
|
318
|
+
|
|
319
|
+
ToolMessage = Union[
|
|
320
|
+
HumanToolResultMessage,
|
|
321
|
+
HumanFunctionResultMessage,
|
|
322
|
+
AIToolCallMessage,
|
|
323
|
+
AIFunctionCallMessage,
|
|
324
|
+
]
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def parse_dial_message(msg: DialMessage) -> BaseMessage | ToolMessage:
|
|
328
|
+
message = (
|
|
329
|
+
SystemMessage.from_message(msg)
|
|
330
|
+
or HumanRegularMessage.from_message(msg)
|
|
331
|
+
or HumanToolResultMessage.from_message(msg)
|
|
332
|
+
or HumanFunctionResultMessage.from_message(msg)
|
|
333
|
+
or AIRegularMessage.from_message(msg)
|
|
334
|
+
or AIToolCallMessage.from_message(msg)
|
|
335
|
+
or AIFunctionCallMessage.from_message(msg)
|
|
336
|
+
)
|
|
337
|
+
|
|
338
|
+
if message is None:
|
|
339
|
+
raise ValidationError("Unknown message type or invalid message")
|
|
340
|
+
|
|
341
|
+
return message
|
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from types import TracebackType
|
|
6
|
+
from typing import ContextManager, List, Optional, Protocol, Self, Tuple
|
|
7
|
+
|
|
8
|
+
from aidial_sdk.chat_completion import (
|
|
9
|
+
Attachment,
|
|
10
|
+
Choice,
|
|
11
|
+
FinishReason,
|
|
12
|
+
FunctionCall,
|
|
13
|
+
Response,
|
|
14
|
+
ToolCall,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
from aidial_adapter_anthropic.adapter._truncate_prompt import DiscardedMessages
|
|
18
|
+
from aidial_adapter_anthropic.dial._lazy_stage import LazyStage
|
|
19
|
+
from aidial_adapter_anthropic.dial.token_usage import TokenUsage
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class _ArgumentConsumer(Protocol):
|
|
23
|
+
def append_arguments(self, arguments: str) -> Self: ...
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclasses.dataclass
|
|
27
|
+
class ToolUseMessage:
|
|
28
|
+
call: _ArgumentConsumer
|
|
29
|
+
snapshot: str
|
|
30
|
+
|
|
31
|
+
def append_arguments(self, arguments: str) -> Self:
|
|
32
|
+
self.call.append_arguments(arguments)
|
|
33
|
+
self.snapshot += arguments
|
|
34
|
+
return self
|
|
35
|
+
|
|
36
|
+
def close(self) -> Self:
|
|
37
|
+
if not self.snapshot.strip():
|
|
38
|
+
self.append_arguments("{}")
|
|
39
|
+
return self
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class Consumer(ContextManager, ABC):
|
|
43
|
+
@abstractmethod
|
|
44
|
+
def fork(self) -> Consumer: ...
|
|
45
|
+
|
|
46
|
+
@property
|
|
47
|
+
@abstractmethod
|
|
48
|
+
def choice(self) -> Choice: ...
|
|
49
|
+
|
|
50
|
+
@abstractmethod
|
|
51
|
+
def close_content(self, finish_reason: FinishReason | None = None): ...
|
|
52
|
+
|
|
53
|
+
@abstractmethod
|
|
54
|
+
def append_content(self, content: str): ...
|
|
55
|
+
|
|
56
|
+
@abstractmethod
|
|
57
|
+
def add_attachment(self, attachment: Attachment): ...
|
|
58
|
+
|
|
59
|
+
@abstractmethod
|
|
60
|
+
def add_citation_attachment(
|
|
61
|
+
self, document_id: int, document: Attachment | None
|
|
62
|
+
) -> int: ...
|
|
63
|
+
|
|
64
|
+
@abstractmethod
|
|
65
|
+
def add_usage(self, usage: TokenUsage): ...
|
|
66
|
+
|
|
67
|
+
@abstractmethod
|
|
68
|
+
def set_discarded_messages(
|
|
69
|
+
self, discarded_messages: Optional[DiscardedMessages]
|
|
70
|
+
): ...
|
|
71
|
+
|
|
72
|
+
@abstractmethod
|
|
73
|
+
def get_discarded_messages(self) -> Optional[DiscardedMessages]: ...
|
|
74
|
+
|
|
75
|
+
@abstractmethod
|
|
76
|
+
def create_function_tool_call(self, call: ToolCall) -> ToolUseMessage: ...
|
|
77
|
+
|
|
78
|
+
@abstractmethod
|
|
79
|
+
def create_function_call(self, call: FunctionCall) -> ToolUseMessage: ...
|
|
80
|
+
|
|
81
|
+
@property
|
|
82
|
+
@abstractmethod
|
|
83
|
+
def has_function_call(self) -> bool: ...
|
|
84
|
+
|
|
85
|
+
def create_stage(self, title: str) -> LazyStage:
|
|
86
|
+
# NOTE: eta conversion to `factory = self.choice.create_stage`
|
|
87
|
+
# is invalid, since `self.choice` must be created lazily.
|
|
88
|
+
def factory(content: str):
|
|
89
|
+
return self.choice.create_stage(content)
|
|
90
|
+
|
|
91
|
+
return LazyStage(factory, title)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class ChoiceConsumer(Consumer):
|
|
95
|
+
response: Response
|
|
96
|
+
|
|
97
|
+
usage: Optional[TokenUsage]
|
|
98
|
+
discarded_messages: Optional[DiscardedMessages]
|
|
99
|
+
|
|
100
|
+
_root: Optional[Consumer]
|
|
101
|
+
_choice: Optional[Choice]
|
|
102
|
+
_tool_calls: List[ToolUseMessage]
|
|
103
|
+
_citations: dict[int, Tuple[int, Attachment | None]]
|
|
104
|
+
|
|
105
|
+
def __init__(self, response: Response, root: Optional[Consumer] = None):
|
|
106
|
+
self.response = response
|
|
107
|
+
|
|
108
|
+
self.usage = None
|
|
109
|
+
self.discarded_messages = None
|
|
110
|
+
|
|
111
|
+
self._choice = None
|
|
112
|
+
self._root = root
|
|
113
|
+
self._tool_calls = []
|
|
114
|
+
self._citations = {}
|
|
115
|
+
|
|
116
|
+
def fork(self) -> Consumer:
|
|
117
|
+
return ChoiceConsumer(self.response, self._root or self)
|
|
118
|
+
|
|
119
|
+
@property
|
|
120
|
+
def choice(self) -> Choice:
|
|
121
|
+
if self._choice is None:
|
|
122
|
+
choice = self._choice = self.response.create_choice()
|
|
123
|
+
# Delay opening a choice to the very last moment
|
|
124
|
+
# so as to give opportunity for exceptions to bubble up to
|
|
125
|
+
# the level of HTTP response (instead of error objects in a stream).
|
|
126
|
+
choice.open()
|
|
127
|
+
return choice
|
|
128
|
+
else:
|
|
129
|
+
return self._choice
|
|
130
|
+
|
|
131
|
+
def __enter__(self) -> ChoiceConsumer:
|
|
132
|
+
return self
|
|
133
|
+
|
|
134
|
+
def __exit__(
|
|
135
|
+
self,
|
|
136
|
+
exc_type: type[BaseException] | None,
|
|
137
|
+
exc: BaseException | None,
|
|
138
|
+
traceback: TracebackType | None,
|
|
139
|
+
) -> bool | None:
|
|
140
|
+
for tool_call in self._tool_calls:
|
|
141
|
+
tool_call.close()
|
|
142
|
+
|
|
143
|
+
if exc is None and self._choice is not None:
|
|
144
|
+
self._choice.close()
|
|
145
|
+
|
|
146
|
+
if self._root is None:
|
|
147
|
+
if self.usage is not None:
|
|
148
|
+
self.response.set_usage(
|
|
149
|
+
prompt_tokens=self.usage.prompt_tokens,
|
|
150
|
+
completion_tokens=self.usage.completion_tokens,
|
|
151
|
+
prompt_tokens_details={
|
|
152
|
+
"cached_tokens": self.usage.cache_read_input_tokens
|
|
153
|
+
},
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
if self.discarded_messages is not None:
|
|
157
|
+
self.response.set_discarded_messages(self.discarded_messages)
|
|
158
|
+
|
|
159
|
+
return False
|
|
160
|
+
|
|
161
|
+
def close_content(self, finish_reason: FinishReason | None = None):
|
|
162
|
+
# Choice.close(finish_reason: Optional[FinishReason]) can be called only once
|
|
163
|
+
# Currently, there's no other way to explicitly set the finish reason
|
|
164
|
+
self.choice._last_finish_reason = finish_reason
|
|
165
|
+
|
|
166
|
+
def append_content(self, content: str):
|
|
167
|
+
self.choice.append_content(content)
|
|
168
|
+
|
|
169
|
+
def add_attachment(self, attachment: Attachment):
|
|
170
|
+
self.choice.add_attachment(attachment)
|
|
171
|
+
|
|
172
|
+
def add_citation_attachment(
|
|
173
|
+
self, document_id: int, document: Attachment | None
|
|
174
|
+
) -> int:
|
|
175
|
+
if document_id in self._citations:
|
|
176
|
+
return self._citations[document_id][0]
|
|
177
|
+
|
|
178
|
+
display_index = len(self._citations) + 1
|
|
179
|
+
self._citations[document_id] = (display_index, document)
|
|
180
|
+
|
|
181
|
+
if document:
|
|
182
|
+
document = document.copy()
|
|
183
|
+
document.title = f"[{display_index}] {document.title or ''}".strip()
|
|
184
|
+
document.reference_type = document.reference_type or document.type
|
|
185
|
+
document.reference_url = document.reference_url or document.url
|
|
186
|
+
self.add_attachment(document)
|
|
187
|
+
|
|
188
|
+
return display_index
|
|
189
|
+
|
|
190
|
+
def add_usage(self, usage: TokenUsage):
|
|
191
|
+
if self._root:
|
|
192
|
+
self._root.add_usage(usage)
|
|
193
|
+
else:
|
|
194
|
+
self.usage = (self.usage or TokenUsage()).accumulate(usage)
|
|
195
|
+
|
|
196
|
+
def set_discarded_messages(
|
|
197
|
+
self, discarded_messages: Optional[DiscardedMessages]
|
|
198
|
+
):
|
|
199
|
+
if self._root:
|
|
200
|
+
self._root.set_discarded_messages(discarded_messages)
|
|
201
|
+
else:
|
|
202
|
+
self.discarded_messages = discarded_messages
|
|
203
|
+
|
|
204
|
+
def get_discarded_messages(self) -> Optional[DiscardedMessages]:
|
|
205
|
+
if self._root:
|
|
206
|
+
return self._root.get_discarded_messages()
|
|
207
|
+
else:
|
|
208
|
+
return self.discarded_messages
|
|
209
|
+
|
|
210
|
+
def create_function_tool_call(self, call: ToolCall) -> ToolUseMessage:
|
|
211
|
+
tool_call = ToolUseMessage(
|
|
212
|
+
call=self.choice.create_function_tool_call(
|
|
213
|
+
id=call.id,
|
|
214
|
+
name=call.function.name,
|
|
215
|
+
arguments=call.function.arguments,
|
|
216
|
+
),
|
|
217
|
+
snapshot=call.function.arguments,
|
|
218
|
+
)
|
|
219
|
+
self._tool_calls.append(tool_call)
|
|
220
|
+
return tool_call
|
|
221
|
+
|
|
222
|
+
def create_function_call(self, call: FunctionCall) -> ToolUseMessage:
|
|
223
|
+
tool_call = ToolUseMessage(
|
|
224
|
+
call=self.choice.create_function_call(
|
|
225
|
+
name=call.name,
|
|
226
|
+
arguments=call.arguments,
|
|
227
|
+
),
|
|
228
|
+
snapshot=call.arguments,
|
|
229
|
+
)
|
|
230
|
+
self._tool_calls.append(tool_call)
|
|
231
|
+
return tool_call
|
|
232
|
+
|
|
233
|
+
@property
|
|
234
|
+
def has_function_call(self) -> bool:
|
|
235
|
+
return self._choice is not None and self._choice.has_function_call
|