aidial-adapter-anthropic 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aidial_adapter_anthropic/_utils/json.py +116 -0
- aidial_adapter_anthropic/_utils/list.py +84 -0
- aidial_adapter_anthropic/_utils/pydantic.py +6 -0
- aidial_adapter_anthropic/_utils/resource.py +54 -0
- aidial_adapter_anthropic/_utils/text.py +4 -0
- aidial_adapter_anthropic/adapter/__init__.py +4 -0
- aidial_adapter_anthropic/adapter/_base.py +95 -0
- aidial_adapter_anthropic/adapter/_claude/adapter.py +549 -0
- aidial_adapter_anthropic/adapter/_claude/blocks.py +128 -0
- aidial_adapter_anthropic/adapter/_claude/citations.py +63 -0
- aidial_adapter_anthropic/adapter/_claude/config.py +39 -0
- aidial_adapter_anthropic/adapter/_claude/converters.py +303 -0
- aidial_adapter_anthropic/adapter/_claude/params.py +25 -0
- aidial_adapter_anthropic/adapter/_claude/state.py +45 -0
- aidial_adapter_anthropic/adapter/_claude/tokenizer/__init__.py +10 -0
- aidial_adapter_anthropic/adapter/_claude/tokenizer/anthropic.py +57 -0
- aidial_adapter_anthropic/adapter/_claude/tokenizer/approximate.py +260 -0
- aidial_adapter_anthropic/adapter/_claude/tokenizer/base.py +26 -0
- aidial_adapter_anthropic/adapter/_claude/tools.py +98 -0
- aidial_adapter_anthropic/adapter/_decorator/base.py +53 -0
- aidial_adapter_anthropic/adapter/_decorator/preprocess.py +63 -0
- aidial_adapter_anthropic/adapter/_decorator/replicator.py +32 -0
- aidial_adapter_anthropic/adapter/_errors.py +71 -0
- aidial_adapter_anthropic/adapter/_tokenize.py +12 -0
- aidial_adapter_anthropic/adapter/_truncate_prompt.py +168 -0
- aidial_adapter_anthropic/adapter/claude.py +17 -0
- aidial_adapter_anthropic/dial/_attachments.py +238 -0
- aidial_adapter_anthropic/dial/_lazy_stage.py +40 -0
- aidial_adapter_anthropic/dial/_message.py +341 -0
- aidial_adapter_anthropic/dial/consumer.py +235 -0
- aidial_adapter_anthropic/dial/request.py +170 -0
- aidial_adapter_anthropic/dial/resource.py +189 -0
- aidial_adapter_anthropic/dial/storage.py +138 -0
- aidial_adapter_anthropic/dial/token_usage.py +19 -0
- aidial_adapter_anthropic/dial/tools.py +180 -0
- aidial_adapter_anthropic-0.1.0.dist-info/LICENSE +202 -0
- aidial_adapter_anthropic-0.1.0.dist-info/METADATA +121 -0
- aidial_adapter_anthropic-0.1.0.dist-info/RECORD +39 -0
- aidial_adapter_anthropic-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
from typing import (
|
|
2
|
+
List,
|
|
3
|
+
Literal,
|
|
4
|
+
Optional,
|
|
5
|
+
Type,
|
|
6
|
+
TypeGuard,
|
|
7
|
+
TypeVar,
|
|
8
|
+
assert_never,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
from aidial_sdk.chat_completion import (
|
|
12
|
+
MessageContentImagePart,
|
|
13
|
+
MessageContentPart,
|
|
14
|
+
MessageContentTextPart,
|
|
15
|
+
Role,
|
|
16
|
+
)
|
|
17
|
+
from aidial_sdk.chat_completion.request import (
|
|
18
|
+
ChatCompletionRequest,
|
|
19
|
+
MessageContentRefusalPart,
|
|
20
|
+
)
|
|
21
|
+
from aidial_sdk.exceptions import RequestValidationError
|
|
22
|
+
from pydantic import BaseModel
|
|
23
|
+
from pydantic.v1 import ValidationError as PydanticValidationError
|
|
24
|
+
|
|
25
|
+
from aidial_adapter_anthropic.adapter._errors import ValidationError
|
|
26
|
+
from aidial_adapter_anthropic.dial.tools import (
|
|
27
|
+
ToolsConfig,
|
|
28
|
+
ToolsMode,
|
|
29
|
+
validate_messages,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
# Content of a DIAL chat message: plain text, a list of content parts, or
# nothing at all.
MessageContent = str | List[MessageContentPart] | None

# `MessageContent` widened with lists narrowed to a single part type. `List`
# is invariant, so e.g. `List[MessageContentTextPart]` is not a
# `List[MessageContentPart]` and has to be spelled out separately.
MessageContentSpecialized = (
    MessageContent
    | List[MessageContentTextPart]
    | List[MessageContentImagePart]
)

# Any pydantic model class; see ModelParameters.parse_configuration.
_Model = TypeVar("_Model", bound=BaseModel)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ModelParameters(BaseModel):
    """Generation settings normalized from a DIAL chat completion request."""

    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: int = 1
    stop: List[str] = []
    seed: Optional[int] = None
    max_tokens: Optional[int] = None
    max_prompt_tokens: Optional[int] = None
    stream: bool = False
    tool_config: Optional[ToolsConfig] = None
    # Raw value of `custom_fields.configuration` from the request, if any.
    configuration: Optional[dict] = None

    @classmethod
    def create(cls, request: ChatCompletionRequest) -> "ModelParameters":
        """Extract the parameters from `request`.

        A single `stop` string is normalized to a one-element list, a falsy
        `n` becomes 1, and `configuration` is read from
        `custom_fields.configuration`.

        Raises:
            ValidationError: if the request declares both functions and
                tools (raised while building the tool config).
        """
        stop: List[str] = []
        if request.stop is not None:
            stop = (
                [request.stop]
                if isinstance(request.stop, str)
                else request.stop
            )

        configuration = (
            cf.configuration
            if (cf := request.custom_fields) is not None
            else None
        )

        return cls(
            temperature=request.temperature,
            top_p=request.top_p,
            n=request.n or 1,
            stop=stop,
            seed=request.seed,
            max_tokens=request.max_tokens,
            max_prompt_tokens=request.max_prompt_tokens,
            stream=request.stream,
            # ToolsConfig.from_request runs validate_messages() itself, so
            # validating here as well would log every warning twice.
            tool_config=ToolsConfig.from_request(request),
            configuration=configuration,
        )

    def add_stop_sequences(self, stop: List[str]) -> "ModelParameters":
        """Return a copy with `stop` appended to the existing stop sequences."""
        return self.copy(update={"stop": [*self.stop, *stop]})

    @property
    def tools_mode(self) -> ToolsMode | None:
        """Tools mode of the request, or None when no tool config exists."""
        if self.tool_config is not None:
            return self.tool_config.tools_mode
        return None

    def parse_configuration(self, cls: Type[_Model]) -> _Model:
        """Validate `custom_fields.configuration` against the model `cls`.

        A missing configuration is validated as an empty dict, so models
        whose fields all have defaults still parse.

        Raises:
            RequestValidationError: describing the first validation error
                (chained to the underlying pydantic error).
        """
        try:
            return cls.parse_obj(self.configuration or {})
        except PydanticValidationError as e:
            if self.configuration is None:
                msg = "The configuration at path 'custom_fields.configuration' is missing."
            else:
                error = e.errors()[0]
                path = ".".join(map(str, error["loc"]))
                msg = f"Invalid request. Path: 'custom_fields.configuration.{path}', error: {error['msg']}"

            raise RequestValidationError(msg) from e
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def collect_text_content(
    content: MessageContentSpecialized, delimiter: str = "\n\n"
) -> str:
    """Concatenate the textual parts of a message content.

    `None` yields an empty string and a plain string is returned as is.
    For a list of parts, the texts of the text parts are joined with
    `delimiter`.

    Raises:
        ValidationError: if the list contains an image or a refusal part.
    """
    if content is None:
        return ""

    if isinstance(content, str):
        return content

    if isinstance(content, list):
        collected: List[str] = []
        for item in content:
            if isinstance(item, MessageContentTextPart):
                collected.append(item.text)
            elif isinstance(item, MessageContentImagePart):
                raise ValidationError(
                    "Can't extract text from an image content part"
                )
            elif isinstance(item, MessageContentRefusalPart):
                raise ValidationError(
                    "Can't extract text from a refusal content part"
                )
            else:
                assert_never(item)
        return delimiter.join(collected)

    assert_never(content)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def to_message_content(content: MessageContentSpecialized) -> MessageContent:
    """Widen a specialized message content to the generic `MessageContent`.

    `None` and strings pass through unchanged. Lists are shallow-copied
    into a fresh list of generic content parts.
    """
    if content is None or isinstance(content, str):
        return content

    if isinstance(content, list):
        return list(content)

    assert_never(content)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def is_text_content(
    content: MessageContent,
) -> TypeGuard[str | List[MessageContentTextPart]]:
    """Tell whether `content` consists of text only.

    A string is text. A list is text when none of its parts is a non-text
    part, so an empty list counts as text. `None` is not text.
    """
    if content is None:
        return False

    if isinstance(content, str):
        return True

    if isinstance(content, list):
        return not any(
            not isinstance(part, MessageContentTextPart) for part in content
        )

    assert_never(content)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def is_plain_text_content(content: MessageContent) -> TypeGuard[str | None]:
    """Tell whether `content` is a bare string or missing (not a part list)."""
    if content is None:
        return True
    return isinstance(content, str)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def is_system_role(
    role: Role,
) -> TypeGuard[Literal[Role.SYSTEM, Role.DEVELOPER]]:
    """Tell whether `role` carries system-level instructions.

    Both the `system` and the `developer` roles qualify.
    """
    return role == Role.SYSTEM or role == Role.DEVELOPER
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import mimetypes
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import List
|
|
5
|
+
|
|
6
|
+
from aidial_sdk.chat_completion import Attachment
|
|
7
|
+
from pydantic import BaseModel, Field, root_validator, validator
|
|
8
|
+
|
|
9
|
+
from aidial_adapter_anthropic._utils.resource import Resource
|
|
10
|
+
from aidial_adapter_anthropic._utils.text import truncate_string
|
|
11
|
+
from aidial_adapter_anthropic.dial.storage import FileStorage, download_file
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ValidationError(Exception):
    """Raised when a DIAL resource (URL or attachment) fails validation.

    The explanation is stored in `message` and also passed to `Exception`,
    so `str(error)` returns it.
    """

    message: str

    def __init__(self, message: str):
        super().__init__(message)
        self.message = message
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class MissingContentType(ValidationError):
    """Raised when the content type of a resource can't be determined."""
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class UnsupportedContentType(ValidationError):
    """Raised when a resource's content type is not in the allowed list.

    Attributes:
        type: the offending content type.
        supported_types: the content types that would have been accepted.
    """

    type: str
    supported_types: List[str]

    def __init__(self, *, message: str, type: str, supported_types: List[str]):
        super().__init__(message)
        self.type = type
        self.supported_types = supported_types
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class DialResource(ABC, BaseModel):
    """Base class for resources referenced from DIAL messages.

    Subclasses say where the bytes live and how to name them. This class
    adds content-type detection and validation on top.
    """

    # Noun used in error messages (subclasses in this module default it to
    # "URL" or "attachment" via root validators).
    entity_name: str = Field(default=None)
    # When set, get_content_type() rejects any other MIME type.
    supported_types: List[str] | None = Field(default=None)

    @abstractmethod
    def to_attachment(self) -> Attachment:
        """Represent the resource as a DIAL attachment."""

    @abstractmethod
    async def download(self, storage: FileStorage | None) -> Resource:
        """Fetch the resource's bytes together with its content type."""

    @abstractmethod
    async def guess_content_type(self) -> str | None:
        """Best-effort MIME type, or None when it can't be guessed."""

    @abstractmethod
    async def get_resource_name(self, storage: FileStorage | None) -> str:
        """Short human-readable name of the resource."""

    async def get_content_type(self) -> str:
        """Return the guessed content type after validating it.

        Raises:
            MissingContentType: if no content type could be guessed.
            UnsupportedContentType: if `supported_types` is set and the
                guessed type is not one of them.
        """
        content_type = await self.guess_content_type()

        if not content_type:
            raise MissingContentType(
                f"Can't derive content type of the {self.entity_name}"
            )

        allowed = self.supported_types
        if allowed is not None and content_type not in allowed:
            raise UnsupportedContentType(
                message=f"The {self.entity_name} is not one of the supported types",
                type=content_type,
                supported_types=allowed,
            )

        return content_type
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class URLResource(DialResource):
    """A resource identified by a URL: a data URL or a downloadable link."""

    url: str
    # Declared content type; when unset it is guessed from the URL.
    content_type: str | None = None

    def to_attachment(self) -> Attachment:
        """Represent the URL as a DIAL attachment."""
        return Attachment(type=self.content_type, url=self.url)

    @root_validator
    def validator(cls, values):
        # Default the noun used in error messages.
        values["entity_name"] = values.get("entity_name") or "URL"
        return values

    async def download(self, storage: FileStorage | None) -> Resource:
        """Validate the content type, then fetch the bytes behind the URL.

        Raises:
            MissingContentType, UnsupportedContentType: from
                `get_content_type`.
        """
        content_type = await self.get_content_type()
        data = await _download_url(storage, self.url)
        return Resource(type=content_type, data=data)

    async def guess_content_type(self) -> str | None:
        """Guess the MIME type of the URL.

        The type is taken from the first of these that gives a result:
        the declared `content_type`, the data URL header, the file
        extension of the URL as is, and the file extension of the URL with
        any query string and fragment removed.
        """
        # mimetypes.guess_type() doesn't strip "?..."/"#...", so e.g.
        # ".../image.png?sig=abc" would otherwise yield no type at all.
        without_query = self.url.partition("#")[0].partition("?")[0]
        return (
            self.content_type
            or Resource.parse_data_url_content_type(self.url)
            or mimetypes.guess_type(self.url)[0]
            or mimetypes.guess_type(without_query)[0]
        )

    def is_data_url(self) -> bool:
        """Tell whether the URL is a data URL with a parsable content type."""
        return Resource.parse_data_url_content_type(self.url) is not None

    async def get_resource_name(self, storage: FileStorage | None) -> str:
        """Return a short, human-readable name for the URL.

        Data URLs are described by their content type. Other URLs are
        resolved to a readable name through `storage` when it is given, and
        the result is truncated to 50 characters via `truncate_string`.
        """
        if self.is_data_url():
            return f"data URL ({await self.guess_content_type()})"

        name = self.url
        if storage is not None:
            name = await storage.get_human_readable_name(self.url)

        return truncate_string(name, n=50)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class AttachmentResource(DialResource):
    """A resource delivered as a DIAL attachment (inline base64 data or URL)."""

    attachment: Attachment

    def to_attachment(self) -> Attachment:
        """Return the wrapped attachment unchanged."""
        return self.attachment

    # Keep this validator above the `validator` method below: once that
    # method is defined, its name shadows pydantic's `validator` decorator
    # for the rest of the class body.
    @validator("attachment", pre=True)
    def parse_attachment(cls, value):
        """Parse a raw dict into an `Attachment`, keeping a missing type None."""
        if not isinstance(value, dict):
            return value

        parsed = Attachment.parse_obj(value)
        # Working around the issue of defaulting missing type to a markdown:
        # https://github.com/epam/ai-dial-sdk/blob/2835107e950c89645a2b619fecba2518fa2d7bb1/aidial_sdk/chat_completion/request.py#L22
        if "type" not in value:
            parsed.type = None
        return parsed

    @root_validator(pre=True)
    def validator(cls, values):
        # Default the noun used in error messages.
        values["entity_name"] = values.get("entity_name") or "attachment"
        return values

    async def download(self, storage: FileStorage | None) -> Resource:
        """Validate the content type and obtain the attachment's bytes.

        Inline base64 `data` wins over `url`.

        Raises:
            ValidationError: if the attachment has neither data nor URL.
        """
        content_type = await self.get_content_type()

        payload = self.attachment.data
        link = self.attachment.url
        if payload:
            data = base64.b64decode(payload)
        elif link:
            data = await _download_url(storage, link)
        else:
            raise ValidationError(f"Invalid {self.entity_name}")

        return Resource(type=content_type, data=data)

    def create_url_resource(self, url: str) -> URLResource:
        """Wrap `url` in a URLResource that inherits this attachment's
        informative content type and entity name."""
        return URLResource(
            url=url,
            content_type=self.informative_content_type,
            entity_name=self.entity_name,
        )

    @property
    def informative_content_type(self) -> str | None:
        """The declared type, unless it is missing or an octet-stream."""
        declared = self.attachment.type
        if declared is not None and "octet-stream" not in declared:
            return declared
        return None

    async def guess_content_type(self) -> str | None:
        """Guess via the URL first, falling back to the declared type."""
        url = self.attachment.url
        if url:
            guessed = await self.create_url_resource(url).guess_content_type()
            if guessed:
                return guessed

        return self.attachment.type

    async def get_resource_name(self, storage: FileStorage | None) -> str:
        """Return the title, a generic label for inline data, or the URL name.

        Raises:
            ValidationError: if the attachment has no title, data or URL.
        """
        title = self.attachment.title
        if title:
            return title

        if self.attachment.data:
            return f"data {self.entity_name}"

        url = self.attachment.url
        if url:
            return await self.create_url_resource(url).get_resource_name(
                storage
            )

        raise ValidationError(f"Invalid {self.entity_name}")
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
async def _download_url(file_storage: FileStorage | None, url: str) -> bytes:
    """Return the bytes behind `url`.

    Data URLs are decoded in place. Other URLs are fetched through
    `file_storage` when one is given, or with the module-level
    `download_file` otherwise.
    """
    inline = Resource.from_data_url(url)
    if inline is not None:
        return inline.data

    if not file_storage:
        return await download_file(url)

    return await file_storage.download_file(url)
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import hashlib
|
|
3
|
+
import io
|
|
4
|
+
import logging
|
|
5
|
+
import mimetypes
|
|
6
|
+
from typing import Mapping, Optional, TypedDict
|
|
7
|
+
from urllib.parse import unquote, urljoin
|
|
8
|
+
|
|
9
|
+
import aiohttp
|
|
10
|
+
from pydantic import BaseModel
|
|
11
|
+
|
|
12
|
+
_log: logging.Logger = logging.getLogger(__name__)


class FileMetadata(TypedDict):
    """Metadata returned by the DIAL file API for an uploaded file."""

    name: str
    # Kept in camelCase to match the JSON payload key.
    parentPath: str
    bucket: str
    url: str


class Bucket(TypedDict):
    """JSON body returned by the DIAL `/v1/bucket` endpoint."""

    bucket: str
    appdata: str
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class FileStorage(BaseModel):
    """Async client for the DIAL file API.

    `api_key` is sent as the `api-key` header on requests to `dial_url`.
    `bucket` caches the `/v1/bucket` response once it has been fetched.
    """

    dial_url: str
    api_key: str
    bucket: Optional[Bucket] = None

    @property
    def auth_headers(self) -> Mapping[str, str]:
        """Headers that authenticate a request against DIAL."""
        return {"api-key": self.api_key}

    async def _get_bucket(self, session: aiohttp.ClientSession) -> Bucket:
        """Return the bucket info, fetching it on first use and caching it."""
        if self.bucket is None:
            async with session.get(
                f"{self.dial_url}/v1/bucket",
                headers=self.auth_headers,
            ) as response:
                response.raise_for_status()
                self.bucket = bucket = await response.json()
                _log.debug(f"bucket: {self.bucket}")
                return bucket

        return self.bucket

    async def _get_user_bucket(self, session: aiohttp.ClientSession) -> str:
        """Return the user bucket: the first path segment of `appdata`.

        Raises:
            ValueError: if the bucket info carries no `appdata`.
        """
        bucket = await self._get_bucket(session)
        appdata = bucket.get("appdata")
        if appdata is None:
            raise ValueError(
                "Can't retrieve user bucket because appdata isn't available"
            )
        return appdata.split("/", 1)[0]

    @staticmethod
    def _to_form_data(
        filename: str, content_type: str, content: bytes
    ) -> aiohttp.FormData:
        """Wrap `content` into a multipart form with a single `file` field."""
        data = aiohttp.FormData()
        data.add_field(
            "file",
            io.BytesIO(content),
            filename=filename,
            content_type=content_type,
        )
        return data

    async def upload(
        self, filename: str, content_type: str, content: bytes
    ) -> FileMetadata:
        """Upload `content` into the app-data folder and return its metadata.

        An extension guessed from `content_type` (if any) is appended to
        `filename` in the target URL.
        """
        async with aiohttp.ClientSession() as session:
            bucket = await self._get_bucket(session)

            appdata = bucket["appdata"]
            ext = mimetypes.guess_extension(content_type) or ""
            url = f"{self.dial_url}/v1/files/{appdata}/{filename}{ext}"

            data = FileStorage._to_form_data(filename, content_type, content)

            async with session.put(
                url=url,
                data=data,
                headers=self.auth_headers,
            ) as response:
                response.raise_for_status()
                meta = await response.json()
                _log.debug(f"Uploaded file: url={url}, metadata={meta}")
                return meta

    async def upload_file_as_base64(
        self, upload_dir: str, data: str, content_type: str
    ) -> FileMetadata:
        """Decode base64 `data` and upload it under a content-derived name.

        The file name is the SHA-256 hex digest of the base64 text, placed
        inside `upload_dir`.
        """
        filename = f"{upload_dir}/{compute_hash_digest(data)}"
        content: bytes = base64.b64decode(data)
        return await self.upload(filename, content_type, content)

    def attachment_link_to_url(self, link: str) -> str:
        """Resolve a (possibly relative) attachment link against `<dial_url>/v1/`."""
        return urljoin(f"{self.dial_url}/v1/", link)

    def _url_to_attachment_link(self, url: str) -> str:
        """Strip the `<dial_url>/v1/` prefix from `url`, if present."""
        return url.removeprefix(f"{self.dial_url}/v1/")

    async def download_file(self, link: str) -> bytes:
        """Download an attachment link.

        The API key is attached only when the resolved URL points at the
        DIAL server itself.
        """
        url = self.attachment_link_to_url(link)

        # A bare prefix test (url.startswith(dial_url)) also matches foreign
        # hosts such as "https://dial.example.com.evil.net/..." and would send
        # them the API key; links come from the request, so require a path
        # boundary right after the DIAL base URL.
        base = self.dial_url.rstrip("/").lower()
        normalized = url.lower()
        headers: Mapping[str, str] = {}
        if normalized == base or normalized.startswith(base + "/"):
            headers = self.auth_headers
        return await download_file(url, headers)

    async def get_human_readable_name(self, link: str) -> str:
        """Turn a DIAL file link into a short name suitable for messages.

        Strips `<dial_url>/v1/`, `files/` and the bucket prefix (`public`
        or the user's own bucket). A percent-encoded remainder is shown
        decoded, wrapped in `repr` quotes.
        """
        url = self.attachment_link_to_url(link)
        link = self._url_to_attachment_link(url)

        link = link.removeprefix("files/")

        if link.startswith("public/"):
            bucket = "public"
        else:
            async with aiohttp.ClientSession() as session:
                bucket = await self._get_user_bucket(session)

        link = link.removeprefix(f"{bucket}/")
        decoded_link = unquote(link)
        return link if link == decoded_link else repr(decoded_link)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
async def download_file(
    url: str, headers: Optional[Mapping[str, str]] = None
) -> bytes:
    """Download `url` with a one-off aiohttp session and return the body.

    Args:
        url: absolute URL to fetch.
        headers: extra request headers (e.g. DIAL auth). None means no
            extra headers. This replaces a shared mutable `{}` default.

    Raises:
        aiohttp.ClientResponseError: on an error HTTP status.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(url, headers=headers or {}) as response:
            response.raise_for_status()
            return await response.read()
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def compute_hash_digest(file_content: str) -> str:
    """Return the hex SHA-256 digest of `file_content` encoded as UTF-8."""
    hasher = hashlib.sha256()
    hasher.update(file_content.encode("utf-8"))
    return hasher.hexdigest()
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from pydantic import BaseModel
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class TokenUsage(BaseModel):
    """Token counters for a completion, including prompt-cache reads/writes."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    cache_read_input_tokens: int = 0
    cache_write_input_tokens: int = 0

    @property
    def total_tokens(self) -> int:
        """Prompt plus completion tokens; the cache counters are not added."""
        return sum((self.prompt_tokens, self.completion_tokens))

    def accumulate(self, other: "TokenUsage") -> "TokenUsage":
        """Add every counter of `other` to this instance in place.

        Returns `self`, so calls can be chained.
        """
        for counter in (
            "prompt_tokens",
            "completion_tokens",
            "cache_read_input_tokens",
            "cache_write_input_tokens",
        ):
            setattr(self, counter, getattr(self, counter) + getattr(other, counter))
        return self
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Dict, List, Literal, Self
|
|
4
|
+
|
|
5
|
+
from aidial_sdk.chat_completion import (
|
|
6
|
+
Function,
|
|
7
|
+
FunctionChoice,
|
|
8
|
+
Message,
|
|
9
|
+
Role,
|
|
10
|
+
Tool,
|
|
11
|
+
ToolChoice,
|
|
12
|
+
)
|
|
13
|
+
from aidial_sdk.chat_completion.request import (
|
|
14
|
+
AzureChatCompletionRequest,
|
|
15
|
+
StaticTool,
|
|
16
|
+
)
|
|
17
|
+
from pydantic import BaseModel
|
|
18
|
+
|
|
19
|
+
from aidial_adapter_anthropic.adapter._errors import ValidationError
|
|
20
|
+
|
|
21
|
+
_log: logging.Logger = logging.getLogger(__name__)


class ToolsMode(Enum):
    """Which API a request uses for callable tools."""

    # The `tools` / `tool_calls` API.
    TOOLS = "TOOLS"
    # The legacy `functions` / `function_call` API: functions are the
    # deprecated instrument that came before tools.
    FUNCTIONS = "FUNCTIONS"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ToolsConfig(BaseModel):
|
|
33
|
+
tools: List[Tool]
|
|
34
|
+
"""
|
|
35
|
+
List of functions/tools.
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
tools_mode: ToolsMode
|
|
39
|
+
|
|
40
|
+
tool_choice: Literal["auto", "none", "required"] | ToolChoice
|
|
41
|
+
|
|
42
|
+
tool_ids: Dict[str, str]
|
|
43
|
+
"""
|
|
44
|
+
Mapping from tool call IDs to corresponding tool names.
|
|
45
|
+
Empty when there are no tool calls in the messages.
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
def not_supported(self) -> None:
|
|
49
|
+
if not self.tools:
|
|
50
|
+
return
|
|
51
|
+
if self.tools_mode == ToolsMode.TOOLS:
|
|
52
|
+
raise ValidationError("The tools aren't supported")
|
|
53
|
+
raise ValidationError("The functions aren't supported")
|
|
54
|
+
|
|
55
|
+
def create_fresh_tool_call_id(self, tool_name: str) -> str:
|
|
56
|
+
idx = 1
|
|
57
|
+
while True:
|
|
58
|
+
tool_id = f"{tool_name}_{idx}"
|
|
59
|
+
if tool_id not in self.tool_ids:
|
|
60
|
+
self.tool_ids[tool_id] = tool_name
|
|
61
|
+
return tool_id
|
|
62
|
+
idx += 1
|
|
63
|
+
|
|
64
|
+
def get_tool_name(self, tool_call_id: str) -> str:
|
|
65
|
+
tool_name = self.tool_ids.get(tool_call_id)
|
|
66
|
+
if tool_name is None:
|
|
67
|
+
raise ValidationError(f"Tool call ID not found: {self.tool_ids}")
|
|
68
|
+
return tool_name
|
|
69
|
+
|
|
70
|
+
@staticmethod
|
|
71
|
+
def _function_call_to_tool_choice(
|
|
72
|
+
function_call: Literal["auto", "none"] | FunctionChoice | None,
|
|
73
|
+
) -> Literal["auto", "none", "required"] | ToolChoice | None:
|
|
74
|
+
match function_call:
|
|
75
|
+
case FunctionChoice():
|
|
76
|
+
return ToolChoice(type="function", function=function_call)
|
|
77
|
+
case _:
|
|
78
|
+
return function_call
|
|
79
|
+
|
|
80
|
+
@staticmethod
|
|
81
|
+
def _get_tool_from_function(tool: Function | Tool | StaticTool) -> Tool:
|
|
82
|
+
if isinstance(tool, StaticTool):
|
|
83
|
+
raise ValidationError("Static tools aren't supported")
|
|
84
|
+
if isinstance(tool, Function):
|
|
85
|
+
return Tool(type="function", function=tool)
|
|
86
|
+
else:
|
|
87
|
+
return tool
|
|
88
|
+
|
|
89
|
+
@staticmethod
|
|
90
|
+
def _get_tools_from_functions(
|
|
91
|
+
tools: List[Function] | List[Tool | StaticTool],
|
|
92
|
+
) -> List[Tool]:
|
|
93
|
+
return [ToolsConfig._get_tool_from_function(tool) for tool in tools]
|
|
94
|
+
|
|
95
|
+
@classmethod
|
|
96
|
+
def from_request(cls, request: AzureChatCompletionRequest) -> Self | None:
|
|
97
|
+
validate_messages(request)
|
|
98
|
+
|
|
99
|
+
tool_ids = _collect_tool_ids(request.messages)
|
|
100
|
+
|
|
101
|
+
if request.functions is not None:
|
|
102
|
+
tools_mode = ToolsMode.FUNCTIONS
|
|
103
|
+
tools = cls._get_tools_from_functions(request.functions)
|
|
104
|
+
tool_choice = cls._function_call_to_tool_choice(
|
|
105
|
+
request.function_call
|
|
106
|
+
)
|
|
107
|
+
elif request.tools is not None:
|
|
108
|
+
tools_mode = ToolsMode.TOOLS
|
|
109
|
+
tools = cls._get_tools_from_functions(request.tools)
|
|
110
|
+
tool_choice = request.tool_choice
|
|
111
|
+
elif tool_ids:
|
|
112
|
+
tools_mode = ToolsMode.TOOLS
|
|
113
|
+
tools = []
|
|
114
|
+
tool_choice = None
|
|
115
|
+
else:
|
|
116
|
+
return None
|
|
117
|
+
|
|
118
|
+
return cls(
|
|
119
|
+
tools=tools,
|
|
120
|
+
tools_mode=tools_mode,
|
|
121
|
+
tool_choice=tool_choice or "auto",
|
|
122
|
+
tool_ids=tool_ids,
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def validate_messages(request: AzureChatCompletionRequest) -> None:
    """Check that tool/function usage in the messages matches the declarations.

    Declaring both `functions` and `tools` is an error. A message that uses
    tools or functions without the matching declaration is only logged as
    a warning.

    Raises:
        ValidationError: if both `functions` and `tools` are declared.
    """
    decl_tools = request.tools is not None
    decl_functions = request.functions is not None

    if decl_functions and decl_tools:
        raise ValidationError("Both functions and tools are not allowed")

    def warn(msg: str):
        _log.warning(
            "The request is incomplete: %s. The model may misbehave.", msg
        )

    tool_defs_are_missing = (
        "the request is missing tool definitions in the 'tools' field"
    )
    func_defs_are_missing = (
        "the request is missing function definitions in the 'functions' field"
    )

    for idx, message in enumerate(request.messages):
        if (
            message.role == Role.ASSISTANT
            and message.tool_calls is not None
            and not decl_tools
        ):
            warn(
                f"'messages[{idx}]' is an Assistant message with a tool call, but {tool_defs_are_missing}"
            )
        if (
            message.role == Role.ASSISTANT
            and message.function_call is not None
            and not decl_functions
        ):
            warn(
                f"'messages[{idx}]' is an Assistant message with a function call, but {func_defs_are_missing}"
            )
        if message.role == Role.FUNCTION and not decl_functions:
            warn(
                f"'messages[{idx}]' is a Function message, but {func_defs_are_missing}"
            )
        if message.role == Role.TOOL and not decl_tools:
            warn(
                f"'messages[{idx}]' is a Tool message, but {tool_defs_are_missing}"
            )
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _collect_tool_ids(messages: List[Message]) -> Dict[str, str]:
    """Map each tool call ID in assistant messages to its function name.

    Messages are scanned in order, so a repeated ID keeps the name from its
    last occurrence.
    """
    return {
        call.id: call.function.name
        for message in messages
        if message.role == Role.ASSISTANT and message.tool_calls is not None
        for call in message.tool_calls
    }
|