aidial-adapter-anthropic 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39):
  1. aidial_adapter_anthropic/_utils/json.py +116 -0
  2. aidial_adapter_anthropic/_utils/list.py +84 -0
  3. aidial_adapter_anthropic/_utils/pydantic.py +6 -0
  4. aidial_adapter_anthropic/_utils/resource.py +54 -0
  5. aidial_adapter_anthropic/_utils/text.py +4 -0
  6. aidial_adapter_anthropic/adapter/__init__.py +4 -0
  7. aidial_adapter_anthropic/adapter/_base.py +95 -0
  8. aidial_adapter_anthropic/adapter/_claude/adapter.py +549 -0
  9. aidial_adapter_anthropic/adapter/_claude/blocks.py +128 -0
  10. aidial_adapter_anthropic/adapter/_claude/citations.py +63 -0
  11. aidial_adapter_anthropic/adapter/_claude/config.py +39 -0
  12. aidial_adapter_anthropic/adapter/_claude/converters.py +303 -0
  13. aidial_adapter_anthropic/adapter/_claude/params.py +25 -0
  14. aidial_adapter_anthropic/adapter/_claude/state.py +45 -0
  15. aidial_adapter_anthropic/adapter/_claude/tokenizer/__init__.py +10 -0
  16. aidial_adapter_anthropic/adapter/_claude/tokenizer/anthropic.py +57 -0
  17. aidial_adapter_anthropic/adapter/_claude/tokenizer/approximate.py +260 -0
  18. aidial_adapter_anthropic/adapter/_claude/tokenizer/base.py +26 -0
  19. aidial_adapter_anthropic/adapter/_claude/tools.py +98 -0
  20. aidial_adapter_anthropic/adapter/_decorator/base.py +53 -0
  21. aidial_adapter_anthropic/adapter/_decorator/preprocess.py +63 -0
  22. aidial_adapter_anthropic/adapter/_decorator/replicator.py +32 -0
  23. aidial_adapter_anthropic/adapter/_errors.py +71 -0
  24. aidial_adapter_anthropic/adapter/_tokenize.py +12 -0
  25. aidial_adapter_anthropic/adapter/_truncate_prompt.py +168 -0
  26. aidial_adapter_anthropic/adapter/claude.py +17 -0
  27. aidial_adapter_anthropic/dial/_attachments.py +238 -0
  28. aidial_adapter_anthropic/dial/_lazy_stage.py +40 -0
  29. aidial_adapter_anthropic/dial/_message.py +341 -0
  30. aidial_adapter_anthropic/dial/consumer.py +235 -0
  31. aidial_adapter_anthropic/dial/request.py +170 -0
  32. aidial_adapter_anthropic/dial/resource.py +189 -0
  33. aidial_adapter_anthropic/dial/storage.py +138 -0
  34. aidial_adapter_anthropic/dial/token_usage.py +19 -0
  35. aidial_adapter_anthropic/dial/tools.py +180 -0
  36. aidial_adapter_anthropic-0.1.0.dist-info/LICENSE +202 -0
  37. aidial_adapter_anthropic-0.1.0.dist-info/METADATA +121 -0
  38. aidial_adapter_anthropic-0.1.0.dist-info/RECORD +39 -0
  39. aidial_adapter_anthropic-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,341 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import List, Optional, Self, Union
3
+
4
+ from aidial_sdk.chat_completion import (
5
+ Attachment,
6
+ CacheBreakpoint,
7
+ CustomContent,
8
+ FunctionCall,
9
+ )
10
+ from aidial_sdk.chat_completion import Message as DialMessage
11
+ from aidial_sdk.chat_completion import (
12
+ MessageContentPart,
13
+ MessageContentTextPart,
14
+ MessageCustomFields,
15
+ Role,
16
+ ToolCall,
17
+ )
18
+ from pydantic import BaseModel
19
+
20
+ from aidial_adapter_anthropic.adapter._errors import ValidationError
21
+ from aidial_adapter_anthropic.dial.request import (
22
+ collect_text_content,
23
+ is_plain_text_content,
24
+ is_system_role,
25
+ is_text_content,
26
+ to_message_content,
27
+ )
28
+
29
+
30
class MessageABC(ABC, BaseModel):
    """Abstract base for typed views over DIAL chat messages.

    Subclasses recognize and parse a `DialMessage` in `from_message` and
    convert themselves back in `to_message`. An optional cache breakpoint is
    kept on every message and re-emitted through `custom_fields`.
    """

    cache_breakpoint: CacheBreakpoint | None = None

    @property
    def custom_fields(self) -> MessageCustomFields | None:
        """Custom fields to attach when converting back to a DIAL message.

        Returns None when the message has no cache breakpoint.
        """
        # Explicit None check instead of relying on the truthiness of the
        # breakpoint model (PEP 8: "beware of writing `if x` when you really
        # mean `if x is not None`").
        if self.cache_breakpoint is not None:
            return MessageCustomFields(cache_breakpoint=self.cache_breakpoint)
        return None

    @abstractmethod
    def to_message(self) -> DialMessage:
        """Convert this model back into a DIAL message."""

    @classmethod
    @abstractmethod
    def from_message(cls, message: DialMessage) -> Self | None:
        """Parse *message* into this model.

        Implementations return None when the message is not of their kind.
        """
45
+
46
+
47
class BaseMessageABC(MessageABC):
    """Message kind whose content can be rendered as plain text."""

    @property
    @abstractmethod
    def text_content(self) -> str:
        """Text extracted from the message content (defined by subclasses)."""
51
+
52
+
53
def _get_cache_breakpoint(message: DialMessage) -> CacheBreakpoint | None:
    """Return the cache breakpoint stored in the message's custom fields.

    Returns None when the message has no custom fields at all.
    """
    fields = message.custom_fields
    return None if fields is None else fields.cache_breakpoint
57
+
58
+
59
class SystemMessage(BaseMessageABC):
    """System or developer message; content is restricted to text."""

    content: str | List[MessageContentTextPart]
    # Remembers whether the source message used `Role.DEVELOPER` rather than
    # `Role.SYSTEM`, so `to_message` restores the same role.
    is_developer: bool = False

    def to_message(self) -> DialMessage:
        """Convert back to a DIAL message, restoring the original role."""
        role = Role.DEVELOPER if self.is_developer else Role.SYSTEM
        return DialMessage(
            role=role,
            content=to_message_content(self.content),
            custom_fields=self.custom_fields,
        )

    @classmethod
    def from_message(cls, message: DialMessage) -> Self | None:
        """Parse a message whose role is accepted by `is_system_role`.

        Returns None for any other role.

        Raises:
            ValidationError: if the content is not a string or a list of
                text content parts (as judged by `is_text_content`).
        """
        if not is_system_role(message.role):
            return None

        body = message.content
        if is_text_content(body):
            return cls(
                is_developer=(message.role == Role.DEVELOPER),
                cache_breakpoint=_get_cache_breakpoint(message),
                content=body,
            )

        raise ValidationError(
            "System message is expected to be a string or a list of text content parts"
        )

    @property
    def text_content(self) -> str:
        """Text of the content as produced by `collect_text_content`."""
        return collect_text_content(self.content)
91
+
92
+
93
class HumanRegularMessage(BaseMessageABC):
    """User message with text or content parts and optional custom content."""

    content: str | List[MessageContentPart]
    custom_content: CustomContent | None = None

    def to_message(self) -> DialMessage:
        """Convert back into a `user` role DIAL message."""
        return DialMessage(
            role=Role.USER,
            content=self.content,
            custom_content=self.custom_content,
            custom_fields=self.custom_fields,
        )

    @classmethod
    def from_message(cls, message: DialMessage) -> Self | None:
        """Parse a `user` role message; returns None for any other role.

        Raises:
            ValidationError: if the user message has no content.
        """
        if message.role == Role.USER:
            body = message.content
            if body is not None:
                return cls(
                    content=body,
                    custom_content=message.custom_content,
                    cache_breakpoint=_get_cache_breakpoint(message),
                )
            raise ValidationError(
                "User message is expected to have content field"
            )
        return None

    @property
    def text_content(self) -> str:
        """Text of the content as produced by `collect_text_content`."""
        return collect_text_content(self.content)

    @property
    def attachments(self) -> List[Attachment]:
        """Attachments from the custom content, or an empty list if none."""
        if not self.custom_content:
            return []
        return self.custom_content.attachments or []
131
+
132
+
133
class HumanToolResultMessage(MessageABC):
    """Tool call result sent back by the client in a `tool` role message."""

    id: str  # the `tool_call_id` this result answers
    content: str

    def to_message(self) -> DialMessage:
        """Convert back into a `tool` role DIAL message."""
        return DialMessage(
            role=Role.TOOL,
            content=self.content,
            tool_call_id=self.id,
            custom_fields=self.custom_fields,
        )

    @classmethod
    def from_message(cls, message: DialMessage) -> Self | None:
        """Parse a `tool` role message; returns None for any other role.

        Raises:
            ValidationError: if the content is not plain text, or if the
                content or `tool_call_id` is missing.
        """
        if message.role != Role.TOOL:
            return None

        body, call_id = message.content, message.tool_call_id

        if not is_plain_text_content(body):
            raise ValidationError(
                "The tool message shouldn't contain content parts"
            )

        if body is None or call_id is None:
            raise ValidationError(
                "The tool message is expected to have content and tool_call_id fields"
            )

        return cls(
            id=call_id,
            content=body,
            cache_breakpoint=_get_cache_breakpoint(message),
        )
165
+
166
+
167
class HumanFunctionResultMessage(MessageABC):
    """Function call result sent back by the client in a `function` role message."""

    name: str  # name of the function that produced this result
    content: str

    def to_message(self) -> DialMessage:
        """Convert back into a `function` role DIAL message."""
        return DialMessage(
            role=Role.FUNCTION,
            content=self.content,
            name=self.name,
            custom_fields=self.custom_fields,
        )

    @classmethod
    def from_message(cls, message: DialMessage) -> Self | None:
        """Parse a `function` role message; returns None for any other role.

        Raises:
            ValidationError: if the content is not plain text, or if the
                content or name is missing.
        """
        if message.role != Role.FUNCTION:
            return None

        body, func_name = message.content, message.name

        if not is_plain_text_content(body):
            raise ValidationError(
                "The function message shouldn't contain content parts"
            )

        if body is None or func_name is None:
            raise ValidationError(
                "The function message is expected to have content and name fields"
            )

        return cls(
            name=func_name,
            content=body,
            cache_breakpoint=_get_cache_breakpoint(message),
        )
199
+
200
+
201
class AIRegularMessage(BaseMessageABC):
    """Assistant message that has content and no tool or function calls."""

    # According to the Azure OpenAI API, an assistant message can only have
    # textual content. Content parts are still accepted here, in case models
    # start accepting images in assistant messages.
    content: str | List[MessageContentPart]
    custom_content: Optional[CustomContent] = None

    def to_message(self) -> DialMessage:
        """Convert back into an `assistant` role DIAL message."""
        return DialMessage(
            role=Role.ASSISTANT,
            content=self.content,  # type: ignore
            custom_content=self.custom_content,
            custom_fields=self.custom_fields,
        )

    @classmethod
    def from_message(cls, message: DialMessage) -> Self | None:
        """Parse an assistant message that carries no tool/function calls.

        Returns None for other roles, and for assistant messages that carry
        `function_call` or `tool_calls`.

        Raises:
            ValidationError: if the assistant message has no content.
        """
        is_plain_reply = message.role == Role.ASSISTANT and (
            message.function_call is None and message.tool_calls is None
        )
        if not is_plain_reply:
            return None

        body = message.content
        if body is None:
            raise ValidationError(
                "Assistant message is expected to have content field"
            )

        return cls(
            content=body,
            custom_content=message.custom_content,
            cache_breakpoint=_get_cache_breakpoint(message),
        )

    @property
    def text_content(self) -> str:
        """Text of the content as produced by `collect_text_content`."""
        return collect_text_content(self.content)

    @property
    def attachments(self) -> List[Attachment]:
        """Attachments from the custom content, or an empty list if none."""
        if not self.custom_content:
            return []
        return self.custom_content.attachments or []
248
+
249
+
250
class AIToolCallMessage(MessageABC):
    """Assistant message that carries `tool_calls` and optional plain-text content."""

    calls: List[ToolCall]
    content: Optional[str] = None
    custom_content: Optional[CustomContent] = None

    def to_message(self) -> DialMessage:
        """Convert back into an assistant DIAL message with `tool_calls`."""
        return DialMessage(
            role=Role.ASSISTANT,
            content=self.content,
            tool_calls=self.calls,
            custom_content=self.custom_content,
            custom_fields=self.custom_fields,
        )

    @classmethod
    def from_message(cls, message: DialMessage) -> Self | None:
        """Parse an assistant message that has `tool_calls` and no `function_call`.

        Returns None for anything else.

        Raises:
            ValidationError: if the content is not plain text.
        """
        if message.role != Role.ASSISTANT:
            return None

        tool_calls = message.tool_calls
        if tool_calls is None:
            return None
        if message.function_call is not None:
            return None

        if not is_plain_text_content(message.content):
            raise ValidationError(
                "The assistant message with tool calls shouldn't contain content parts"
            )

        return cls(
            calls=tool_calls,
            content=message.content,
            custom_content=message.custom_content,
            cache_breakpoint=_get_cache_breakpoint(message),
        )
283
+
284
+
285
class AIFunctionCallMessage(MessageABC):
    """Assistant message that carries a `function_call` and optional plain-text content."""

    call: FunctionCall
    content: Optional[str] = None

    def to_message(self) -> DialMessage:
        """Convert back into an assistant DIAL message with `function_call`."""
        return DialMessage(
            role=Role.ASSISTANT,
            content=self.content,
            function_call=self.call,
            custom_fields=self.custom_fields,
        )

    @classmethod
    def from_message(cls, message: DialMessage) -> Self | None:
        """Parse an assistant message that has `function_call` and no `tool_calls`.

        Returns None for anything else.

        Raises:
            ValidationError: if the content is not plain text.
        """
        if message.role != Role.ASSISTANT:
            return None

        function_call = message.function_call
        if function_call is None:
            return None
        if message.tool_calls is not None:
            return None

        if not is_plain_text_content(message.content):
            raise ValidationError(
                "The assistant message with function call shouldn't contain content parts"
            )

        return cls(
            call=function_call,
            content=message.content,
            cache_breakpoint=_get_cache_breakpoint(message),
        )
315
+
316
+
317
# Messages without tool/function call semantics.
BaseMessage = Union[
    SystemMessage,
    HumanRegularMessage,
    AIRegularMessage,
]

# Messages that request a tool/function call or return its result.
ToolMessage = Union[
    HumanToolResultMessage,
    HumanFunctionResultMessage,
    AIToolCallMessage,
    AIFunctionCallMessage,
]
325
+
326
+
327
def parse_dial_message(msg: DialMessage) -> BaseMessage | ToolMessage:
    """Convert a DIAL message into the matching typed message model.

    Candidate models are tried in order, and the first `from_message` that
    returns a model wins. A candidate may itself raise `ValidationError` when
    the message matches its role but is malformed.

    Raises:
        ValidationError: if no candidate recognizes the message.
    """
    parsed = (
        SystemMessage.from_message(msg)
        or HumanRegularMessage.from_message(msg)
        or HumanToolResultMessage.from_message(msg)
        or HumanFunctionResultMessage.from_message(msg)
        or AIRegularMessage.from_message(msg)
        or AIToolCallMessage.from_message(msg)
        or AIFunctionCallMessage.from_message(msg)
    )

    if parsed is not None:
        return parsed

    raise ValidationError("Unknown message type or invalid message")
@@ -0,0 +1,235 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ from abc import ABC, abstractmethod
5
+ from types import TracebackType
6
+ from typing import ContextManager, List, Optional, Protocol, Self, Tuple
7
+
8
+ from aidial_sdk.chat_completion import (
9
+ Attachment,
10
+ Choice,
11
+ FinishReason,
12
+ FunctionCall,
13
+ Response,
14
+ ToolCall,
15
+ )
16
+
17
+ from aidial_adapter_anthropic.adapter._truncate_prompt import DiscardedMessages
18
+ from aidial_adapter_anthropic.dial._lazy_stage import LazyStage
19
+ from aidial_adapter_anthropic.dial.token_usage import TokenUsage
20
+
21
+
22
+ class _ArgumentConsumer(Protocol):
23
+ def append_arguments(self, arguments: str) -> Self: ...
24
+
25
+
26
+ @dataclasses.dataclass
27
+ class ToolUseMessage:
28
+ call: _ArgumentConsumer
29
+ snapshot: str
30
+
31
+ def append_arguments(self, arguments: str) -> Self:
32
+ self.call.append_arguments(arguments)
33
+ self.snapshot += arguments
34
+ return self
35
+
36
+ def close(self) -> Self:
37
+ if not self.snapshot.strip():
38
+ self.append_arguments("{}")
39
+ return self
40
+
41
+
42
class Consumer(ContextManager, ABC):
    """Sink for a model's streamed output: content, attachments, calls, usage."""

    @abstractmethod
    def fork(self) -> Consumer:
        """Create a sibling consumer for an additional output stream."""

    @property
    @abstractmethod
    def choice(self) -> Choice:
        """The choice that output is written to."""

    @abstractmethod
    def close_content(self, finish_reason: FinishReason | None = None):
        """Mark the content as finished, optionally with a finish reason."""

    @abstractmethod
    def append_content(self, content: str):
        """Append a chunk of text content."""

    @abstractmethod
    def add_attachment(self, attachment: Attachment):
        """Attach *attachment* to the output."""

    @abstractmethod
    def add_citation_attachment(
        self, document_id: int, document: Attachment | None
    ) -> int:
        """Register a cited document and return its display index."""

    @abstractmethod
    def add_usage(self, usage: TokenUsage):
        """Account for token *usage*."""

    @abstractmethod
    def set_discarded_messages(
        self, discarded_messages: Optional[DiscardedMessages]
    ):
        """Record which prompt messages were discarded."""

    @abstractmethod
    def get_discarded_messages(self) -> Optional[DiscardedMessages]:
        """Return the recorded discarded messages, if any."""

    @abstractmethod
    def create_function_tool_call(self, call: ToolCall) -> ToolUseMessage:
        """Start streaming a tool call seeded from *call*."""

    @abstractmethod
    def create_function_call(self, call: FunctionCall) -> ToolUseMessage:
        """Start streaming a function call seeded from *call*."""

    @property
    @abstractmethod
    def has_function_call(self) -> bool:
        """Whether a function/tool call has been emitted."""

    def create_stage(self, title: str) -> LazyStage:
        """Return a stage titled *title* that is only created when needed.

        The factory looks up `self.choice` only when it is invoked. Binding
        `self.choice.create_stage` directly would access `self.choice` right
        away, but the choice must be created lazily.
        """
        return LazyStage(
            lambda content: self.choice.create_stage(content), title
        )
92
+
93
+
94
class ChoiceConsumer(Consumer):
    """`Consumer` that streams into a single choice of a DIAL `Response`.

    Each consumer lazily creates its own choice on the shared response. Forks
    (see `fork`) share the root consumer: usage and discarded messages are
    forwarded to the root, which reports them on the response in `__exit__`.
    """

    response: Response

    # Populated only on the root consumer; forks delegate to `_root`.
    usage: Optional[TokenUsage]
    discarded_messages: Optional[DiscardedMessages]

    # None for the root consumer; the root consumer for forks.
    _root: Optional[Consumer]
    # Created and opened on first access of the `choice` property.
    _choice: Optional[Choice]
    # Calls opened by this consumer; finalized in `__exit__`.
    _tool_calls: List[ToolUseMessage]
    # document_id -> (display index, document or None)
    _citations: dict[int, Tuple[int, Attachment | None]]

    def __init__(self, response: Response, root: Optional[Consumer] = None):
        self.response = response

        self.usage = None
        self.discarded_messages = None

        self._choice = None
        self._root = root
        self._tool_calls = []
        self._citations = {}

    def fork(self) -> Consumer:
        """Create a consumer that writes to a new choice of the same response.

        The fork is attached to this consumer's root (or to this consumer if
        it is the root), so accounting stays in one place.
        """
        root = self._root if self._root is not None else self
        return ChoiceConsumer(self.response, root)

    @property
    def choice(self) -> Choice:
        """The underlying choice, created and opened on first access."""
        if self._choice is None:
            choice = self._choice = self.response.create_choice()
            # Delay opening a choice to the very last moment
            # so as to give opportunity for exceptions to bubble up to
            # the level of HTTP response (instead of error objects in a stream).
            choice.open()
            return choice
        return self._choice

    def __enter__(self) -> ChoiceConsumer:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        """Finalize calls and the choice; the root also reports accounting.

        Every opened call is closed, which emits `{}` if its arguments were
        blank. The choice is closed only on a clean exit, and only if it was
        created. The root consumer then sets usage and discarded messages on
        the response. Exceptions are never suppressed.
        """
        for tool_call in self._tool_calls:
            tool_call.close()

        if exc is None and self._choice is not None:
            self._choice.close()

        if self._root is None:
            if self.usage is not None:
                self.response.set_usage(
                    prompt_tokens=self.usage.prompt_tokens,
                    completion_tokens=self.usage.completion_tokens,
                    prompt_tokens_details={
                        "cached_tokens": self.usage.cache_read_input_tokens
                    },
                )

            if self.discarded_messages is not None:
                self.response.set_discarded_messages(self.discarded_messages)

        return False

    def close_content(self, finish_reason: FinishReason | None = None):
        """Record *finish_reason* to be used when the choice is closed."""
        # Choice.close(finish_reason: Optional[FinishReason]) can be called only once
        # Currently, there's no other way to explicitly set the finish reason
        self.choice._last_finish_reason = finish_reason

    def append_content(self, content: str):
        self.choice.append_content(content)

    def add_attachment(self, attachment: Attachment):
        self.choice.add_attachment(attachment)

    def add_citation_attachment(
        self, document_id: int, document: Attachment | None
    ) -> int:
        """Register a cited document and return its 1-based display index.

        A repeated *document_id* reuses the index assigned the first time.
        On first registration, a non-None *document* is attached as a copy.
        The copy's title is prefixed with "[index]", and its reference
        type/URL fall back to the document's own type/URL.
        """
        if document_id in self._citations:
            return self._citations[document_id][0]

        display_index = len(self._citations) + 1
        self._citations[document_id] = (display_index, document)

        if document:
            # Copy so the caller's attachment object is not mutated.
            document = document.copy()
            document.title = f"[{display_index}] {document.title or ''}".strip()
            document.reference_type = document.reference_type or document.type
            document.reference_url = document.reference_url or document.url
            self.add_attachment(document)

        return display_index

    def add_usage(self, usage: TokenUsage):
        """Accumulate *usage* on the root consumer."""
        if self._root is not None:
            self._root.add_usage(usage)
        else:
            self.usage = (self.usage or TokenUsage()).accumulate(usage)

    def set_discarded_messages(
        self, discarded_messages: Optional[DiscardedMessages]
    ):
        """Store *discarded_messages* on the root consumer."""
        if self._root is not None:
            self._root.set_discarded_messages(discarded_messages)
        else:
            self.discarded_messages = discarded_messages

    def get_discarded_messages(self) -> Optional[DiscardedMessages]:
        """Return the discarded messages stored on the root consumer."""
        if self._root is not None:
            return self._root.get_discarded_messages()
        return self.discarded_messages

    def create_function_tool_call(self, call: ToolCall) -> ToolUseMessage:
        """Open a tool call on the choice, seeded with *call*'s arguments."""
        tool_call = ToolUseMessage(
            call=self.choice.create_function_tool_call(
                id=call.id,
                name=call.function.name,
                arguments=call.function.arguments,
            ),
            snapshot=call.function.arguments,
        )
        self._tool_calls.append(tool_call)
        return tool_call

    def create_function_call(self, call: FunctionCall) -> ToolUseMessage:
        """Open a function call on the choice, seeded with *call*'s arguments."""
        tool_call = ToolUseMessage(
            call=self.choice.create_function_call(
                name=call.name,
                arguments=call.arguments,
            ),
            snapshot=call.arguments,
        )
        self._tool_calls.append(tool_call)
        return tool_call

    @property
    def has_function_call(self) -> bool:
        """Whether this consumer's choice exists and has a function call."""
        return self._choice is not None and self._choice.has_function_call