aidial-adapter-anthropic 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. aidial_adapter_anthropic/_utils/json.py +116 -0
  2. aidial_adapter_anthropic/_utils/list.py +84 -0
  3. aidial_adapter_anthropic/_utils/pydantic.py +6 -0
  4. aidial_adapter_anthropic/_utils/resource.py +54 -0
  5. aidial_adapter_anthropic/_utils/text.py +4 -0
  6. aidial_adapter_anthropic/adapter/__init__.py +4 -0
  7. aidial_adapter_anthropic/adapter/_base.py +95 -0
  8. aidial_adapter_anthropic/adapter/_claude/adapter.py +549 -0
  9. aidial_adapter_anthropic/adapter/_claude/blocks.py +128 -0
  10. aidial_adapter_anthropic/adapter/_claude/citations.py +63 -0
  11. aidial_adapter_anthropic/adapter/_claude/config.py +39 -0
  12. aidial_adapter_anthropic/adapter/_claude/converters.py +303 -0
  13. aidial_adapter_anthropic/adapter/_claude/params.py +25 -0
  14. aidial_adapter_anthropic/adapter/_claude/state.py +45 -0
  15. aidial_adapter_anthropic/adapter/_claude/tokenizer/__init__.py +10 -0
  16. aidial_adapter_anthropic/adapter/_claude/tokenizer/anthropic.py +57 -0
  17. aidial_adapter_anthropic/adapter/_claude/tokenizer/approximate.py +260 -0
  18. aidial_adapter_anthropic/adapter/_claude/tokenizer/base.py +26 -0
  19. aidial_adapter_anthropic/adapter/_claude/tools.py +98 -0
  20. aidial_adapter_anthropic/adapter/_decorator/base.py +53 -0
  21. aidial_adapter_anthropic/adapter/_decorator/preprocess.py +63 -0
  22. aidial_adapter_anthropic/adapter/_decorator/replicator.py +32 -0
  23. aidial_adapter_anthropic/adapter/_errors.py +71 -0
  24. aidial_adapter_anthropic/adapter/_tokenize.py +12 -0
  25. aidial_adapter_anthropic/adapter/_truncate_prompt.py +168 -0
  26. aidial_adapter_anthropic/adapter/claude.py +17 -0
  27. aidial_adapter_anthropic/dial/_attachments.py +238 -0
  28. aidial_adapter_anthropic/dial/_lazy_stage.py +40 -0
  29. aidial_adapter_anthropic/dial/_message.py +341 -0
  30. aidial_adapter_anthropic/dial/consumer.py +235 -0
  31. aidial_adapter_anthropic/dial/request.py +170 -0
  32. aidial_adapter_anthropic/dial/resource.py +189 -0
  33. aidial_adapter_anthropic/dial/storage.py +138 -0
  34. aidial_adapter_anthropic/dial/token_usage.py +19 -0
  35. aidial_adapter_anthropic/dial/tools.py +180 -0
  36. aidial_adapter_anthropic-0.1.0.dist-info/LICENSE +202 -0
  37. aidial_adapter_anthropic-0.1.0.dist-info/METADATA +121 -0
  38. aidial_adapter_anthropic-0.1.0.dist-info/RECORD +39 -0
  39. aidial_adapter_anthropic-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,168 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Awaitable, Callable, List, Optional, Set, Tuple, TypeVar
3
+
4
+ from aidial_sdk.exceptions import ContextLengthExceededError
5
+ from aidial_sdk.exceptions import HTTPException as DialException
6
+ from aidial_sdk.exceptions import (
7
+ InvalidRequestError,
8
+ TruncatePromptSystemAndLastUserError,
9
+ )
10
+ from pydantic import BaseModel
11
+
12
+ from aidial_adapter_anthropic._utils.list import (
13
+ omit_by_indices,
14
+ select_by_indices,
15
+ )
16
+
17
+
18
class TruncatePromptError(ABC, BaseModel):
    """Base class for failures occurring during prompt truncation."""

    @abstractmethod
    def to_dial_exception(self) -> DialException:
        """Convert this error into the DIAL exception to report."""
        pass

    def print(self) -> str:
        """Human-readable message of the corresponding DIAL exception."""
        exc = self.to_dial_exception()
        return exc.message
25
+
26
+
27
class InconsistentLimitsError(TruncatePromptError):
    """The user-requested token limit exceeds the model's context size."""

    user_limit: int
    model_limit: int

    def to_dial_exception(self) -> DialException:
        message = (
            f"The request maximum prompt tokens is {self.user_limit}. "
            f"However, the model's maximum context length is {self.model_limit} tokens."
        )
        return InvalidRequestError(message)
36
+
37
+
38
class ModelLimitOverflow(TruncatePromptError):
    """The prompt does not fit into the model's own context window."""

    model_limit: int  # the model's maximum context length in tokens
    token_count: int  # actual number of tokens in the prompt

    def to_dial_exception(self) -> DialException:
        return ContextLengthExceededError(self.model_limit, self.token_count)
44
+
45
+
46
class UserLimitOverflow(TruncatePromptError):
    """Messages that must be kept already exceed the user-requested limit."""

    user_limit: int  # maximum prompt tokens requested by the user
    token_count: int  # token count of the non-discardable messages

    def to_dial_exception(self) -> DialException:
        return TruncatePromptSystemAndLastUserError(
            self.user_limit, self.token_count
        )
54
+
55
+
56
+ def _partition_indexer(chunks: List[int]) -> Callable[[int], List[int]]:
57
+ """
58
+ Returns a function that maps an index to indices of its partition.
59
+ """
60
+ mapping: dict[int, List[int]] = {}
61
+ offset = 0
62
+ for size in chunks:
63
+ chunk = list(range(offset, offset + size))
64
+ for idx in range(size):
65
+ mapping[offset + idx] = chunk
66
+ offset += size
67
+
68
+ return mapping.__getitem__
69
+
70
+
71
# Generic message type handled by the truncation helpers.
_T = TypeVar("_T")
# Sorted indices of messages removed from the prompt during truncation.
DiscardedMessages = List[int]
73
+
74
+
75
async def truncate_prompt(
    messages: List[_T],
    tokenizer: Callable[[List[_T]], Awaitable[int]],
    keep_message: Callable[[List[_T], int], bool],
    partitioner: Callable[[List[_T]], List[int]],
    model_limit: Optional[int],
    user_limit: Optional[int],
) -> Tuple[DiscardedMessages, List[_T]]:
    """
    Truncate the prompt so that it fits into the given token limits.

    Returns the indices of discarded messages together with the list of
    preserved messages. Raises a DIAL exception when truncation fails.
    """

    outcome = await compute_discarded_messages(
        messages,
        tokenizer,
        keep_message,
        partitioner,
        model_limit,
        user_limit,
    )

    if isinstance(outcome, TruncatePromptError):
        raise outcome.to_dial_exception()

    return list(outcome), omit_by_indices(messages, outcome)
100
+
101
+
102
async def compute_discarded_messages(
    messages: List[_T],
    tokenizer: Callable[[List[_T]], Awaitable[int]],
    keep_message: Callable[[List[_T], int], bool],
    partitioner: Callable[[List[_T]], List[int]],
    model_limit: Optional[int],
    user_limit: Optional[int],
) -> DiscardedMessages | TruncatePromptError:
    """
    Compute which messages must be discarded for the rest to fit the limits.

    Messages are discarded in whole partitions (as defined by `partitioner`),
    oldest first, never touching a partition that contains a message marked
    by `keep_message`.

    Returns a sorted list of discarded message indices on success, or a
    `TruncatePromptError` value describing why truncation is impossible.
    Raises `ValueError` if `partitioner` does not cover all messages.
    """
    # A user limit above the model limit can never be honoured.
    if (
        user_limit is not None
        and model_limit is not None
        and user_limit > model_limit
    ):
        return InconsistentLimitsError(
            user_limit=user_limit, model_limit=model_limit
        )

    # Without a user limit there is nothing to truncate: only verify that
    # the full prompt fits into the model's context window (if known).
    if user_limit is None:
        if model_limit is None:
            return []

        token_count = await tokenizer(messages)
        if token_count <= model_limit:
            return []

        return ModelLimitOverflow(
            model_limit=model_limit, token_count=token_count
        )

    # Sanity check: the partition sizes must cover every message exactly.
    partition_sizes = partitioner(messages)
    if sum(partition_sizes) != len(messages):
        raise ValueError(
            "Partition sizes must add up to the number of messages."
        )

    # Token count of the selected subset of messages only.
    async def _tokenize_selected(indices: Set[int]) -> int:
        return await tokenizer(select_by_indices(messages, indices))

    get_partition_indices = _partition_indexer(partition_sizes)

    n = len(messages)
    # Partitions containing at least one must-keep message are kept whole.
    kept_indices: Set[int] = {
        j
        for i in range(n)
        for j in get_partition_indices(i)
        if keep_message(messages, i)
    }

    # If the mandatory messages alone exceed the user limit, give up.
    token_count = await _tokenize_selected(kept_indices)
    if token_count > user_limit:
        return UserLimitOverflow(user_limit=user_limit, token_count=token_count)

    # Greedily re-add partitions from newest to oldest while they still fit;
    # stop at the first partition that would overflow the user limit.
    for idx in reversed(range(n)):
        if idx in kept_indices:
            continue

        chunk_indices = get_partition_indices(idx)
        new_token_count = await _tokenize_selected(
            {*kept_indices, *chunk_indices}
        )
        if new_token_count > user_limit:
            break

        kept_indices.update(chunk_indices)

    # Everything not kept is discarded.
    all_indices = set(range(n))
    return sorted(list(all_indices - kept_indices))
@@ -0,0 +1,17 @@
1
+ from aidial_adapter_anthropic.adapter._claude.adapter import create_adapter
2
+ from aidial_adapter_anthropic.adapter._claude.state import MessageState
3
+ from aidial_adapter_anthropic.adapter._claude.tokenizer.approximate import (
4
+ ApproximateTokenizer,
5
+ )
6
+ from aidial_adapter_anthropic.adapter._claude.tokenizer.base import (
7
+ ClaudeTokenizer,
8
+ create_tokenizer,
9
+ )
10
+
11
# Public API of the Claude adapter package.
__all__ = [
    "create_adapter",
    "MessageState",
    "create_tokenizer",
    "ApproximateTokenizer",
    "ClaudeTokenizer",
]
@@ -0,0 +1,238 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from dataclasses import dataclass, field
5
+ from typing import (
6
+ AsyncIterator,
7
+ Callable,
8
+ Dict,
9
+ Generic,
10
+ List,
11
+ Protocol,
12
+ Sequence,
13
+ Set,
14
+ TypeVar,
15
+ assert_never,
16
+ runtime_checkable,
17
+ )
18
+
19
+ from aidial_sdk.chat_completion import (
20
+ MessageContentImagePart,
21
+ MessageContentRefusalPart,
22
+ MessageContentTextPart,
23
+ )
24
+ from pydantic import BaseModel
25
+
26
+ from aidial_adapter_anthropic._utils.list import aiter_to_list
27
+ from aidial_adapter_anthropic._utils.resource import Resource
28
+ from aidial_adapter_anthropic.adapter._errors import UserError, ValidationError
29
+ from aidial_adapter_anthropic.dial._message import BaseMessage, SystemMessage
30
+ from aidial_adapter_anthropic.dial.resource import (
31
+ AttachmentResource,
32
+ DialResource,
33
+ UnsupportedContentType,
34
+ URLResource,
35
+ )
36
+ from aidial_adapter_anthropic.dial.storage import FileStorage
37
+
38
# Content-block type produced from attachments (covariant: producer position).
_T = TypeVar("_T", covariant=True)
# Content-block type produced from plain text (covariant: producer position).
_Txt = TypeVar("_Txt", covariant=True)
# Handler configuration type (contravariant: consumer position).
_Config = TypeVar("_Config", bound=BaseModel, contravariant=True)
41
+
42
+
43
@runtime_checkable
class Handler(Protocol, Generic[_T]):
    """Callback converting a downloaded resource into a content block."""

    def __call__(self, resource: Resource) -> _T: ...
46
+
47
+
48
@runtime_checkable
class HandlerWithConfig(Protocol, Generic[_T, _Config]):
    """Handler variant that also receives an optional configuration."""

    def __call__(self, resource: Resource, config: _Config | None) -> _T: ...
51
+
52
+
53
class AttachmentProcessor(BaseModel, Generic[_T, _Config]):
    """Converts resources of a fixed set of MIME types into content blocks."""

    class Config:
        arbitrary_types_allowed = True

    supported_types: Dict[str, Set[str]]
    """MIME type to file extensions mapping"""

    handler: Handler[_T] | HandlerWithConfig[_T, _Config]

    def handle(self, resource: Resource, config: _Config | None) -> _T:
        """Invoke the handler, forwarding `config` only if it accepts one.

        The handler signature is inspected to decide whether the config is
        passed positionally, by keyword, or not at all.
        """
        sig = inspect.signature(self.handler)
        params = list(sig.parameters.values())

        # Handler declares a second positional parameter: pass config there.
        has_positional_config = len(params) >= 2 and params[1].kind in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        if has_positional_config:
            return self.handler(resource, config)  # type: ignore

        # BUGFIX: a `**kwargs`-only handler cannot take a second positional
        # argument — the previous positional call raised TypeError for e.g.
        # `def h(resource, **kwargs)`. Pass the config by keyword instead.
        if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
            return self.handler(resource, config=config)  # type: ignore

        # Handler takes only the resource.
        return self.handler(resource)  # type: ignore
78
+
79
+
80
@dataclass
class WithResources(Generic[_T]):
    """A payload together with the DIAL resources it was derived from."""

    payload: _T
    resources: List[DialResource] = field(default_factory=list)

    @staticmethod
    def transpose(xs: List[WithResources[_T]]) -> WithResources[List[_T]]:
        """Collect payloads into a single list, concatenating all resources."""
        payloads: List[_T] = []
        all_resources: List[DialResource] = []
        for item in xs:
            payloads.append(item.payload)
            all_resources.extend(item.resources)
        return WithResources(payload=payloads, resources=all_resources)
90
+
91
+
92
class AttachmentProcessors(BaseModel, Generic[_Txt, _T, _Config]):
    """Converts DIAL message content (plain text, attachments, image URLs)
    into content blocks via the configured attachment processors."""

    # Configuration forwarded to each processor's handler.
    config: _Config | None = None
    # One processor per group of supported MIME types.
    attachment_processors: Sequence[AttachmentProcessor[_T, _Config]]
    # Converts a plain-text fragment into a text content block.
    text_handler: Callable[[str], _Txt]
    # Storage passed to DialResource.download when fetching resources.
    file_storage: FileStorage | None

    @property
    def supported_types(self) -> Dict[str, Set[str]]:
        """Union of all processors' MIME-type -> file-extension mappings."""
        ret: Dict[str, Set[str]] = {}
        for processor in self.attachment_processors:
            for mime_type, file_exts in processor.supported_types.items():
                ret.setdefault(mime_type, set()).update(file_exts)
        return ret

    @property
    def supported_mime_types(self) -> List[str]:
        """All MIME types supported by at least one processor."""
        return list(self.supported_types)

    @property
    def supported_image_types(self) -> List[str]:
        """Supported MIME types that denote images."""
        return [t for t in self.supported_mime_types if t.startswith("image/")]

    def _text_handler(self, text: str) -> WithResources[_Txt]:
        # Plain text references no external resources.
        return WithResources(self.text_handler(text))

    async def process_system_message(
        self, message: SystemMessage
    ) -> List[_Txt]:
        """Extract non-empty text blocks from a system message.

        System messages carry only textual content; any other content part
        is rejected by exhaustiveness checking.
        """
        def _gen():
            match (content := message.content):
                case str():
                    if content:
                        yield self.text_handler(content)
                case list():
                    for part in content:
                        match part:
                            case MessageContentTextPart(text=text):
                                if text:
                                    yield self.text_handler(text)
                            case _:
                                assert_never(part)
                case _:
                    assert_never(content)

        return [x for x in _gen()]

    async def process_attachments(
        self, message: BaseMessage
    ) -> WithResources[List[_T | _Txt]]:
        """Convert a message's attachments and content parts into blocks.

        Falls back to a single empty text block when the message yields
        no blocks at all.
        """
        ret = await aiter_to_list(self._process_attachments_iter(message)) or [
            self._text_handler("")
        ]
        return WithResources.transpose(ret)

    async def _process_attachments_iter(
        self, message: BaseMessage
    ) -> AsyncIterator[WithResources[_T | _Txt]]:
        """Yield one block per attachment and per non-empty content part."""
        # System messages have no attachments; all others are processed first.
        if not isinstance(message, SystemMessage):
            for attachment in message.attachments:
                yield await self._handle_dial_resource(
                    AttachmentResource(
                        attachment=attachment,
                        entity_name="attachment",
                        supported_types=self.supported_mime_types,
                    ),
                )

        content = message.content

        match content:
            case str():
                if content:
                    yield self._text_handler(content)
            case list():
                for part in content:
                    match part:
                        case MessageContentTextPart(text=text):
                            if text:
                                yield self._text_handler(text)
                        case MessageContentImagePart(image_url=image_url):
                            # Image parts are downloaded like attachments,
                            # restricted to image MIME types.
                            yield await self._handle_dial_resource(
                                URLResource(
                                    url=image_url.url,
                                    entity_name="image url",
                                    supported_types=self.supported_image_types,
                                ),
                            )
                        case MessageContentRefusalPart():
                            raise ValidationError(
                                "Refuse content parts aren't supported"
                            )
                        case _:
                            assert_never(part)
            case _:
                assert_never(content)

    async def _download_resource(self, dial_resource: DialResource) -> Resource:
        """Download the resource, translating content-type failures into a
        user-facing error with a usage hint."""
        try:
            return await dial_resource.download(self.file_storage)
        except UnsupportedContentType as e:
            # NOTE(review): consider `raise ... from e` to keep the cause.
            raise UserError(
                f"Unsupported media type: {e.type}",
                _get_usage_message(self.get_file_exts(e.supported_types)),
            )

    async def _handle_resource(self, resource: Resource) -> _T:
        """Dispatch the resource to the first processor supporting its type."""
        for processor in self.attachment_processors:
            if resource.type in processor.supported_types:
                return processor.handle(resource, self.config)

        raise UserError(
            f"Unsupported media type: {resource.type}",
            _get_usage_message(self.get_file_exts(self.supported_mime_types)),
        )

    async def _handle_dial_resource(
        self, dial_resource: DialResource
    ) -> WithResources[_T]:
        """Download and process a resource, tracking its origin."""
        resource = await self._download_resource(dial_resource)
        message = await self._handle_resource(resource)
        return WithResources(message, resources=[dial_resource])

    def get_file_exts(self, mime_types: List[str]) -> List[str]:
        """File extensions registered for the given MIME types."""
        return [
            file_ext
            for mime_type, file_exts in self.supported_types.items()
            if mime_type in mime_types
            for file_ext in file_exts
        ]
221
+
222
+
223
+ def _get_usage_message(supported_exts: List[str]) -> str:
224
+ document_hint = ""
225
+ if "pdf" in supported_exts:
226
+ document_hint = '- "Summarize the document" for a PDF document'
227
+
228
+ return f"""
229
+ The application answers queries about attached files.
230
+ Attach file(s) and ask questions about them in the same message.
231
+
232
+ Supported attachment types: {', '.join(supported_exts)}.
233
+
234
+ Examples of queries:
235
+ - "Describe this picture" for an image
236
+ - "What are in these images? Is there any difference between them?" for multiple images
237
+ {document_hint}
238
+ """.strip()
@@ -0,0 +1,40 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Callable
4
+
5
+ from aidial_sdk.chat_completion import Stage
6
+
7
# Factory creating a Stage from its title.
_StageFactory = Callable[[str], Stage]
8
+
9
+
10
class LazyStage:
    """A stage that is created and opened lazily on the first content append.

    Avoids emitting an empty stage when no content ever arrives. Usable as
    both a sync and an async context manager; exiting closes the stage if
    it was ever opened.
    """

    title: str
    stage_factory: _StageFactory

    # The underlying stage; created on the first append_content call.
    _stage: Stage | None = None
    # Guards against closing the underlying stage more than once.
    _closed: bool = False

    def __init__(self, stage_factory: _StageFactory, title: str):
        self.stage_factory = stage_factory
        self.title = title

    def __enter__(self) -> LazyStage:
        return self

    async def __aenter__(self) -> LazyStage:
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    def append_content(self, text: str) -> None:
        """Append text to the stage, creating and opening it on first use."""
        if self._stage is None:
            self._stage = self.stage_factory(self.title)
            self._stage.open()
        self._stage.append_content(text)

    def close(self) -> None:
        """Close the stage if it was opened; safe to call multiple times.

        BUGFIX: previously an explicit close() followed by context-manager
        exit closed the underlying stage twice.
        """
        if self._stage is not None and not self._closed:
            self._closed = True
            self._stage.close()