aidial-adapter-anthropic 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. aidial_adapter_anthropic/_utils/json.py +116 -0
  2. aidial_adapter_anthropic/_utils/list.py +84 -0
  3. aidial_adapter_anthropic/_utils/pydantic.py +6 -0
  4. aidial_adapter_anthropic/_utils/resource.py +54 -0
  5. aidial_adapter_anthropic/_utils/text.py +4 -0
  6. aidial_adapter_anthropic/adapter/__init__.py +4 -0
  7. aidial_adapter_anthropic/adapter/_base.py +95 -0
  8. aidial_adapter_anthropic/adapter/_claude/adapter.py +549 -0
  9. aidial_adapter_anthropic/adapter/_claude/blocks.py +128 -0
  10. aidial_adapter_anthropic/adapter/_claude/citations.py +63 -0
  11. aidial_adapter_anthropic/adapter/_claude/config.py +39 -0
  12. aidial_adapter_anthropic/adapter/_claude/converters.py +303 -0
  13. aidial_adapter_anthropic/adapter/_claude/params.py +25 -0
  14. aidial_adapter_anthropic/adapter/_claude/state.py +45 -0
  15. aidial_adapter_anthropic/adapter/_claude/tokenizer/__init__.py +10 -0
  16. aidial_adapter_anthropic/adapter/_claude/tokenizer/anthropic.py +57 -0
  17. aidial_adapter_anthropic/adapter/_claude/tokenizer/approximate.py +260 -0
  18. aidial_adapter_anthropic/adapter/_claude/tokenizer/base.py +26 -0
  19. aidial_adapter_anthropic/adapter/_claude/tools.py +98 -0
  20. aidial_adapter_anthropic/adapter/_decorator/base.py +53 -0
  21. aidial_adapter_anthropic/adapter/_decorator/preprocess.py +63 -0
  22. aidial_adapter_anthropic/adapter/_decorator/replicator.py +32 -0
  23. aidial_adapter_anthropic/adapter/_errors.py +71 -0
  24. aidial_adapter_anthropic/adapter/_tokenize.py +12 -0
  25. aidial_adapter_anthropic/adapter/_truncate_prompt.py +168 -0
  26. aidial_adapter_anthropic/adapter/claude.py +17 -0
  27. aidial_adapter_anthropic/dial/_attachments.py +238 -0
  28. aidial_adapter_anthropic/dial/_lazy_stage.py +40 -0
  29. aidial_adapter_anthropic/dial/_message.py +341 -0
  30. aidial_adapter_anthropic/dial/consumer.py +235 -0
  31. aidial_adapter_anthropic/dial/request.py +170 -0
  32. aidial_adapter_anthropic/dial/resource.py +189 -0
  33. aidial_adapter_anthropic/dial/storage.py +138 -0
  34. aidial_adapter_anthropic/dial/token_usage.py +19 -0
  35. aidial_adapter_anthropic/dial/tools.py +180 -0
  36. aidial_adapter_anthropic-0.1.0.dist-info/LICENSE +202 -0
  37. aidial_adapter_anthropic-0.1.0.dist-info/METADATA +121 -0
  38. aidial_adapter_anthropic-0.1.0.dist-info/RECORD +39 -0
  39. aidial_adapter_anthropic-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,170 @@
1
+ from typing import (
2
+ List,
3
+ Literal,
4
+ Optional,
5
+ Type,
6
+ TypeGuard,
7
+ TypeVar,
8
+ assert_never,
9
+ )
10
+
11
+ from aidial_sdk.chat_completion import (
12
+ MessageContentImagePart,
13
+ MessageContentPart,
14
+ MessageContentTextPart,
15
+ Role,
16
+ )
17
+ from aidial_sdk.chat_completion.request import (
18
+ ChatCompletionRequest,
19
+ MessageContentRefusalPart,
20
+ )
21
+ from aidial_sdk.exceptions import RequestValidationError
22
+ from pydantic import BaseModel
23
+ from pydantic.v1 import ValidationError as PydanticValidationError
24
+
25
+ from aidial_adapter_anthropic.adapter._errors import ValidationError
26
+ from aidial_adapter_anthropic.dial.tools import (
27
+ ToolsConfig,
28
+ ToolsMode,
29
+ validate_messages,
30
+ )
31
+
32
# Content of a chat message: a plain string, a list of content parts, or None.
MessageContent = str | List[MessageContentPart] | None
# MessageContent additionally admitting lists restricted to a single part type
# (text-only or image-only lists).
MessageContentSpecialized = (
    MessageContent
    | List[MessageContentTextPart]
    | List[MessageContentImagePart]
)

# Model class parsed by ModelParameters.parse_configuration.
_Model = TypeVar("_Model", bound=BaseModel)
40
+
41
+
42
class ModelParameters(BaseModel):
    """Generation parameters extracted from a DIAL chat completion request."""

    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: int = 1
    stop: List[str] = []
    seed: Optional[int] = None
    max_tokens: Optional[int] = None
    max_prompt_tokens: Optional[int] = None
    stream: bool = False
    # Result of ToolsConfig.from_request: None when the request neither
    # declares tools/functions nor contains tool calls.
    tool_config: Optional[ToolsConfig] = None
    # Raw ``custom_fields.configuration`` of the request, if any.
    configuration: Optional[dict] = None

    @classmethod
    def create(cls, request: ChatCompletionRequest) -> "ModelParameters":
        """Build the parameters from *request*.

        A single ``stop`` string becomes a one-element list, and a missing
        or zero ``n`` falls back to 1.

        Tool/function declarations are validated by
        ``ToolsConfig.from_request``, which calls ``validate_messages``
        itself. The validation is not repeated here, so each warning about
        an incomplete request is logged once instead of twice.

        Raises:
            ValidationError: if the request declares both functions and tools.
        """
        stop: List[str] = []
        if request.stop is not None:
            stop = (
                [request.stop]
                if isinstance(request.stop, str)
                else request.stop
            )

        configuration = (
            cf.configuration
            if (cf := request.custom_fields) is not None
            else None
        )

        return cls(
            temperature=request.temperature,
            top_p=request.top_p,
            n=request.n or 1,
            stop=stop,
            seed=request.seed,
            max_tokens=request.max_tokens,
            max_prompt_tokens=request.max_prompt_tokens,
            stream=request.stream,
            # Also runs validate_messages(request).
            tool_config=ToolsConfig.from_request(request),
            configuration=configuration,
        )

    def add_stop_sequences(self, stop: List[str]) -> "ModelParameters":
        """Return a copy with *stop* appended to the existing stop sequences."""
        return self.copy(update={"stop": [*self.stop, *stop]})

    @property
    def tools_mode(self) -> ToolsMode | None:
        """Tools mode of the request, or None when no tool config is present."""
        if self.tool_config is not None:
            return self.tool_config.tools_mode
        return None

    def parse_configuration(self, cls: Type[_Model]) -> _Model:
        """Parse ``custom_fields.configuration`` into an instance of *cls*.

        A missing configuration is parsed as an empty dict, so models whose
        fields all have defaults still succeed.

        Raises:
            RequestValidationError: reporting either the missing
                configuration or the first validation error together with
                its path under ``custom_fields.configuration``.
        """
        try:
            return cls.parse_obj(self.configuration or {})
        except PydanticValidationError as e:
            if self.configuration is None:
                msg = "The configuration at path 'custom_fields.configuration' is missing."
            else:
                error = e.errors()[0]
                path = ".".join(map(str, error["loc"]))
                msg = f"Invalid request. Path: 'custom_fields.configuration.{path}', error: {error['msg']}"

            # Keep the underlying pydantic error as the cause for debugging.
            raise RequestValidationError(msg) from e
106
+
107
+
108
def collect_text_content(
    content: MessageContentSpecialized, delimiter: str = "\n\n"
) -> str:
    """Concatenate the textual content of a message.

    ``None`` yields an empty string, and a string is returned unchanged.
    The texts of a list of parts are joined with *delimiter*.

    Raises:
        ValidationError: if the list contains an image or a refusal part.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        assert_never(content)

    chunks: List[str] = []
    for item in content:
        if isinstance(item, MessageContentTextPart):
            chunks.append(item.text)
        elif isinstance(item, MessageContentImagePart):
            raise ValidationError(
                "Can't extract text from an image content part"
            )
        elif isinstance(item, MessageContentRefusalPart):
            raise ValidationError(
                "Can't extract text from a refusal content part"
            )
        else:
            assert_never(item)
    return delimiter.join(chunks)
135
+
136
+
137
def to_message_content(content: MessageContentSpecialized) -> MessageContent:
    """Widen specialized message content to the generic ``MessageContent``.

    A list is shallow-copied into a fresh list. ``None`` and strings are
    returned as-is.
    """
    if isinstance(content, list):
        return list(content)
    if content is None or isinstance(content, str):
        return content
    assert_never(content)
145
+
146
+
147
def is_text_content(
    content: MessageContent,
) -> TypeGuard[str | List[MessageContentTextPart]]:
    """Tell whether *content* consists of text only.

    Returns True for a string and for a list in which every part is a text
    part (including an empty list). Returns False for ``None``.
    """
    if content is None:
        return False
    if isinstance(content, str):
        return True
    if isinstance(content, list):
        for part in content:
            if not isinstance(part, MessageContentTextPart):
                return False
        return True
    assert_never(content)
161
+
162
+
163
def is_plain_text_content(content: MessageContent) -> TypeGuard[str | None]:
    """Tell whether *content* is a bare string or ``None`` rather than a list of parts."""
    return isinstance(content, (str, type(None)))
165
+
166
+
167
def is_system_role(
    role: Role,
) -> TypeGuard[Literal[Role.SYSTEM, Role.DEVELOPER]]:
    """Tell whether *role* is the system role or the developer role."""
    system_like_roles = (Role.SYSTEM, Role.DEVELOPER)
    return role in system_like_roles
@@ -0,0 +1,189 @@
1
+ import base64
2
+ import mimetypes
3
+ from abc import ABC, abstractmethod
4
+ from typing import List
5
+
6
+ from aidial_sdk.chat_completion import Attachment
7
+ from pydantic import BaseModel, Field, root_validator, validator
8
+
9
+ from aidial_adapter_anthropic._utils.resource import Resource
10
+ from aidial_adapter_anthropic._utils.text import truncate_string
11
+ from aidial_adapter_anthropic.dial.storage import FileStorage, download_file
12
+
13
+
14
class ValidationError(Exception):
    """Raised when a resource fails validation.

    The human-readable reason is also exposed as :attr:`message`.
    """

    message: str

    def __init__(self, message: str):
        super().__init__(message)
        self.message = message
20
+
21
+
22
class MissingContentType(ValidationError):
    """Raised when the content type of a resource cannot be derived."""
24
+
25
+
26
class UnsupportedContentType(ValidationError):
    """Raised when a resource's content type is not among the accepted ones.

    Attributes:
        type: the content type that was detected.
        supported_types: the content types that would have been accepted.
    """

    type: str
    supported_types: List[str]

    def __init__(self, *, message: str, type: str, supported_types: List[str]):
        super().__init__(message)
        self.type, self.supported_types = type, supported_types
34
+
35
+
36
class DialResource(ABC, BaseModel):
    """Abstract base for a downloadable resource referenced by a request.

    ``entity_name`` is the label used in error messages. When
    ``supported_types`` is set, it restricts the content types that
    :meth:`get_content_type` accepts.
    """

    # Defaults to None despite the ``str`` annotation. The concrete
    # subclasses in this module fill it in with a root validator.
    entity_name: str = Field(default=None)
    supported_types: List[str] | None = Field(default=None)

    @abstractmethod
    def to_attachment(self) -> Attachment: ...

    @abstractmethod
    async def download(self, storage: FileStorage | None) -> Resource: ...

    @abstractmethod
    async def guess_content_type(self) -> str | None: ...

    @abstractmethod
    async def get_resource_name(self, storage: FileStorage | None) -> str: ...

    async def get_content_type(self) -> str:
        """Return the guessed content type after checking it against ``supported_types``.

        Raises:
            MissingContentType: if no content type can be derived.
            UnsupportedContentType: if ``supported_types`` is set and does
                not contain the derived type.
        """
        mime = await self.guess_content_type()
        if not mime:
            raise MissingContentType(
                f"Can't derive content type of the {self.entity_name}"
            )

        allowed = self.supported_types
        if allowed is not None and mime not in allowed:
            raise UnsupportedContentType(
                message=f"The {self.entity_name} is not one of the supported types",
                type=mime,
                supported_types=allowed,
            )

        return mime
71
+
72
+
73
class URLResource(DialResource):
    """A resource addressed by a URL, which may also be a ``data:`` URL."""

    url: str
    content_type: str | None = None

    def to_attachment(self) -> Attachment:
        return Attachment(type=self.content_type, url=self.url)

    @root_validator
    def validator(cls, values):
        # Use the generic "URL" label when no entity name was provided.
        if not values.get("entity_name"):
            values["entity_name"] = "URL"
        return values

    async def download(self, storage: FileStorage | None) -> Resource:
        """Validate the content type, then fetch the bytes behind the URL."""
        mime = await self.get_content_type()
        payload = await _download_url(storage, self.url)
        return Resource(type=mime, data=payload)

    async def guess_content_type(self) -> str | None:
        """Guess the content type of the resource.

        Candidates are tried in order: the explicit ``content_type``, then
        the type declared in a data URL, then a guess from the URL's
        extension.
        """
        if self.content_type:
            return self.content_type
        data_url_type = Resource.parse_data_url_content_type(self.url)
        if data_url_type:
            return data_url_type
        return mimetypes.guess_type(self.url)[0]

    def is_data_url(self) -> bool:
        return Resource.parse_data_url_content_type(self.url) is not None

    async def get_resource_name(self, storage: FileStorage | None) -> str:
        """Return a short display name for the URL, truncated to 50 characters."""
        if self.is_data_url():
            return f"data URL ({await self.guess_content_type()})"

        name = (
            self.url
            if storage is None
            else await storage.get_human_readable_name(self.url)
        )
        return truncate_string(name, n=50)
109
+
110
+
111
class AttachmentResource(DialResource):
    """A resource backed by a DIAL attachment with inline base64 ``data`` or a ``url``."""

    attachment: Attachment

    def to_attachment(self) -> Attachment:
        return self.attachment

    # NOTE: this decorator must stay above the root validator named
    # ``validator`` below. That definition shadows the imported ``validator``
    # decorator for the rest of the class body.
    @validator("attachment", pre=True)
    def parse_attachment(cls, value):
        if not isinstance(value, dict):
            return value
        parsed = Attachment.parse_obj(value)
        # Working around the issue of defaulting missing type to a markdown:
        # https://github.com/epam/ai-dial-sdk/blob/2835107e950c89645a2b619fecba2518fa2d7bb1/aidial_sdk/chat_completion/request.py#L22
        if "type" not in value:
            parsed.type = None
        return parsed

    @root_validator(pre=True)
    def validator(cls, values):
        # Use the generic "attachment" label when no entity name was provided.
        if not values.get("entity_name"):
            values["entity_name"] = "attachment"
        return values

    async def download(self, storage: FileStorage | None) -> Resource:
        """Validate the content type, then decode inline data or fetch the URL.

        Raises:
            ValidationError: if the attachment has neither data nor a URL.
        """
        mime = await self.get_content_type()

        att = self.attachment
        if att.data:
            payload = base64.b64decode(att.data)
        elif att.url:
            payload = await _download_url(storage, att.url)
        else:
            raise ValidationError(f"Invalid {self.entity_name}")

        return Resource(type=mime, data=payload)

    def create_url_resource(self, url: str) -> URLResource:
        """Wrap *url* in a URLResource that inherits this attachment's label and usable type."""
        return URLResource(
            url=url,
            content_type=self.informative_content_type,
            entity_name=self.entity_name,
        )

    @property
    def informative_content_type(self) -> str | None:
        """The declared attachment type, or None if it is absent or an ``octet-stream`` type."""
        declared = self.attachment.type
        if declared is not None and "octet-stream" not in declared:
            return declared
        return None

    async def guess_content_type(self) -> str | None:
        """Guess from the URL first, falling back to the declared attachment type."""
        url = self.attachment.url
        from_url = (
            await self.create_url_resource(url).guess_content_type()
            if url
            else None
        )
        return from_url or self.attachment.type

    async def get_resource_name(self, storage: FileStorage | None) -> str:
        """Return the title, a label for inline data, or the URL's display name.

        Raises:
            ValidationError: if the attachment has no title, data or URL.
        """
        att = self.attachment
        if att.title:
            return att.title
        if att.data:
            return f"data {self.entity_name}"
        if att.url:
            return await self.create_url_resource(att.url).get_resource_name(
                storage
            )
        raise ValidationError(f"Invalid {self.entity_name}")
181
+
182
+
183
async def _download_url(file_storage: FileStorage | None, url: str) -> bytes:
    """Return the bytes behind *url*.

    ``data:`` URLs are decoded in place. Other URLs are fetched through
    *file_storage* when it is given, and with the module-level
    ``download_file`` otherwise.
    """
    inline = Resource.from_data_url(url)
    if inline is not None:
        return inline.data

    if not file_storage:
        return await download_file(url)
    return await file_storage.download_file(url)
@@ -0,0 +1,138 @@
1
+ import base64
2
+ import hashlib
3
+ import io
4
+ import logging
5
+ import mimetypes
6
+ from typing import Mapping, Optional, TypedDict
7
+ from urllib.parse import unquote, urljoin
8
+
9
+ import aiohttp
10
+ from pydantic import BaseModel
11
+
12
# Module logger; FileStorage emits its debug messages through it.
_log = logging.getLogger(__name__)
13
+
14
+
15
# Metadata returned by DIAL for an uploaded file (see FileStorage.upload).
FileMetadata = TypedDict(
    "FileMetadata",
    {
        "name": str,
        "parentPath": str,
        "bucket": str,
        "url": str,
    },
)
20
+
21
+
22
# Bucket description fetched from ``GET /v1/bucket``. FileStorage splits
# ``appdata`` on "/" and uses its first segment as the user bucket.
Bucket = TypedDict("Bucket", {"bucket": str, "appdata": str})
25
+
26
+
27
class FileStorage(BaseModel):
    """Client for the DIAL file API at ``dial_url``, authenticated with ``api_key``.

    The bucket description is fetched lazily from ``/v1/bucket`` on first use
    and cached in ``bucket``.
    """

    dial_url: str
    api_key: str
    bucket: Optional[Bucket] = None

    @property
    def auth_headers(self) -> Mapping[str, str]:
        """Headers that authenticate a request against DIAL."""
        return {"api-key": self.api_key}

    async def _get_bucket(self, session: aiohttp.ClientSession) -> Bucket:
        """Return the bucket description, fetching and caching it on first call."""
        if self.bucket is None:
            async with session.get(
                f"{self.dial_url}/v1/bucket",
                headers=self.auth_headers,
            ) as response:
                response.raise_for_status()
                self.bucket = bucket = await response.json()
                _log.debug(f"bucket: {self.bucket}")
                return bucket

        return self.bucket

    async def _get_user_bucket(self, session: aiohttp.ClientSession) -> str:
        """Return the first path segment of the bucket's ``appdata``.

        Raises:
            ValueError: if the bucket description has no ``appdata``.
        """
        bucket = await self._get_bucket(session)
        appdata = bucket.get("appdata")
        if appdata is None:
            raise ValueError(
                "Can't retrieve user bucket because appdata isn't available"
            )
        return appdata.split("/", 1)[0]

    @staticmethod
    def _to_form_data(
        filename: str, content_type: str, content: bytes
    ) -> aiohttp.FormData:
        """Wrap *content* as the multipart ``file`` field."""
        data = aiohttp.FormData()
        data.add_field(
            "file",
            io.BytesIO(content),
            filename=filename,
            content_type=content_type,
        )
        return data

    async def upload(
        self, filename: str, content_type: str, content: bytes
    ) -> FileMetadata:
        """Upload *content* under the bucket's ``appdata`` folder.

        The file extension is guessed from *content_type*.
        """
        async with aiohttp.ClientSession() as session:
            bucket = await self._get_bucket(session)

            appdata = bucket["appdata"]
            ext = mimetypes.guess_extension(content_type) or ""
            url = f"{self.dial_url}/v1/files/{appdata}/{filename}{ext}"

            data = FileStorage._to_form_data(filename, content_type, content)

            async with session.put(
                url=url,
                data=data,
                headers=self.auth_headers,
            ) as response:
                response.raise_for_status()
                meta = await response.json()
                _log.debug(f"Uploaded file: url={url}, metadata={meta}")
                return meta

    async def upload_file_as_base64(
        self, upload_dir: str, data: str, content_type: str
    ) -> FileMetadata:
        """Decode base64 *data* and upload it under *upload_dir*.

        The file name is the SHA-256 digest of the base64 string.
        """
        filename = f"{upload_dir}/{compute_hash_digest(data)}"
        content: bytes = base64.b64decode(data)
        return await self.upload(filename, content_type, content)

    def attachment_link_to_url(self, link: str) -> str:
        """Resolve *link* against ``{dial_url}/v1/``; absolute URLs are returned unchanged."""
        return urljoin(f"{self.dial_url}/v1/", link)

    def _url_to_attachment_link(self, url: str) -> str:
        return url.removeprefix(f"{self.dial_url}/v1/")

    def _is_dial_url(self, url: str) -> bool:
        """Tell whether *url* points at this DIAL deployment.

        Scheme and netloc must match ``dial_url`` exactly (ignoring case), and
        the path must lie under ``dial_url``'s path on a segment boundary. A
        plain string-prefix test would also accept URLs such as
        ``https://dial.host.evil.org/...`` or ``https://dial.host@evil.org/...``
        and would send the API key to a foreign host.
        """
        from urllib.parse import urlsplit

        target = urlsplit(url)
        base = urlsplit(self.dial_url)

        if target.scheme.lower() != base.scheme.lower():
            return False
        if target.netloc.lower() != base.netloc.lower():
            return False

        base_path = base.path.rstrip("/").lower()
        target_path = target.path.lower()
        return (
            not base_path
            or target_path == base_path
            or target_path.startswith(base_path + "/")
        )

    async def download_file(self, link: str) -> bytes:
        """Download *link*, sending the API key only to the DIAL host itself."""
        url = self.attachment_link_to_url(link)
        headers: Mapping[str, str] = (
            self.auth_headers if self._is_dial_url(url) else {}
        )
        return await download_file(url, headers)

    async def get_human_readable_name(self, link: str) -> str:
        """Return *link* relative to its bucket, as ``repr`` of the decoded form if it was percent-encoded."""
        url = self.attachment_link_to_url(link)
        link = self._url_to_attachment_link(url)

        link = link.removeprefix("files/")

        if link.startswith("public/"):
            bucket = "public"
        else:
            async with aiohttp.ClientSession() as session:
                bucket = await self._get_user_bucket(session)

        link = link.removeprefix(f"{bucket}/")
        decoded_link = unquote(link)
        return link if link == decoded_link else repr(decoded_link)
128
+
129
+
130
async def download_file(
    url: str, headers: Optional[Mapping[str, str]] = None
) -> bytes:
    """Download *url* in a one-off HTTP session and return the response body.

    Args:
        url: absolute URL to fetch.
        headers: extra request headers. None (the default) sends no extra
            headers. A ``None`` sentinel replaces the former shared
            mutable ``{}`` default.

    Raises:
        aiohttp.ClientResponseError: if the response has an error status.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(url, headers=headers or {}) as response:
            response.raise_for_status()
            return await response.read()
135
+
136
+
137
def compute_hash_digest(file_content: str) -> str:
    """Return the hex SHA-256 digest of *file_content* encoded as UTF-8."""
    digest = hashlib.sha256()
    digest.update(file_content.encode("utf-8"))
    return digest.hexdigest()
@@ -0,0 +1,19 @@
1
+ from pydantic import BaseModel
2
+
3
+
4
class TokenUsage(BaseModel):
    """Token counters for a completion; counters can be summed with :meth:`accumulate`."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    cache_read_input_tokens: int = 0
    cache_write_input_tokens: int = 0

    @property
    def total_tokens(self) -> int:
        """Sum of prompt and completion tokens; the cache counters are not added in."""
        return sum((self.prompt_tokens, self.completion_tokens))

    def accumulate(self, other: "TokenUsage") -> "TokenUsage":
        """Add every counter of *other* into this instance in place and return ``self``."""
        for counter in (
            "prompt_tokens",
            "completion_tokens",
            "cache_read_input_tokens",
            "cache_write_input_tokens",
        ):
            setattr(self, counter, getattr(self, counter) + getattr(other, counter))
        return self
@@ -0,0 +1,180 @@
1
+ import logging
2
+ from enum import Enum
3
+ from typing import Dict, List, Literal, Self
4
+
5
+ from aidial_sdk.chat_completion import (
6
+ Function,
7
+ FunctionChoice,
8
+ Message,
9
+ Role,
10
+ Tool,
11
+ ToolChoice,
12
+ )
13
+ from aidial_sdk.chat_completion.request import (
14
+ AzureChatCompletionRequest,
15
+ StaticTool,
16
+ )
17
+ from pydantic import BaseModel
18
+
19
+ from aidial_adapter_anthropic.adapter._errors import ValidationError
20
+
21
# Module logger; validate_messages reports incomplete requests through it.
_log = logging.getLogger(__name__)
22
+
23
+
24
class ToolsMode(Enum):
    """Which request API was used to declare callable tools."""

    TOOLS = "TOOLS"
    # Functions are the deprecated mechanism that came before tools.
    FUNCTIONS = "FUNCTIONS"
30
+
31
+
32
+ class ToolsConfig(BaseModel):
33
+ tools: List[Tool]
34
+ """
35
+ List of functions/tools.
36
+ """
37
+
38
+ tools_mode: ToolsMode
39
+
40
+ tool_choice: Literal["auto", "none", "required"] | ToolChoice
41
+
42
+ tool_ids: Dict[str, str]
43
+ """
44
+ Mapping from tool call IDs to corresponding tool names.
45
+ Empty when there are no tool calls in the messages.
46
+ """
47
+
48
+ def not_supported(self) -> None:
49
+ if not self.tools:
50
+ return
51
+ if self.tools_mode == ToolsMode.TOOLS:
52
+ raise ValidationError("The tools aren't supported")
53
+ raise ValidationError("The functions aren't supported")
54
+
55
+ def create_fresh_tool_call_id(self, tool_name: str) -> str:
56
+ idx = 1
57
+ while True:
58
+ tool_id = f"{tool_name}_{idx}"
59
+ if tool_id not in self.tool_ids:
60
+ self.tool_ids[tool_id] = tool_name
61
+ return tool_id
62
+ idx += 1
63
+
64
+ def get_tool_name(self, tool_call_id: str) -> str:
65
+ tool_name = self.tool_ids.get(tool_call_id)
66
+ if tool_name is None:
67
+ raise ValidationError(f"Tool call ID not found: {self.tool_ids}")
68
+ return tool_name
69
+
70
+ @staticmethod
71
+ def _function_call_to_tool_choice(
72
+ function_call: Literal["auto", "none"] | FunctionChoice | None,
73
+ ) -> Literal["auto", "none", "required"] | ToolChoice | None:
74
+ match function_call:
75
+ case FunctionChoice():
76
+ return ToolChoice(type="function", function=function_call)
77
+ case _:
78
+ return function_call
79
+
80
+ @staticmethod
81
+ def _get_tool_from_function(tool: Function | Tool | StaticTool) -> Tool:
82
+ if isinstance(tool, StaticTool):
83
+ raise ValidationError("Static tools aren't supported")
84
+ if isinstance(tool, Function):
85
+ return Tool(type="function", function=tool)
86
+ else:
87
+ return tool
88
+
89
+ @staticmethod
90
+ def _get_tools_from_functions(
91
+ tools: List[Function] | List[Tool | StaticTool],
92
+ ) -> List[Tool]:
93
+ return [ToolsConfig._get_tool_from_function(tool) for tool in tools]
94
+
95
+ @classmethod
96
+ def from_request(cls, request: AzureChatCompletionRequest) -> Self | None:
97
+ validate_messages(request)
98
+
99
+ tool_ids = _collect_tool_ids(request.messages)
100
+
101
+ if request.functions is not None:
102
+ tools_mode = ToolsMode.FUNCTIONS
103
+ tools = cls._get_tools_from_functions(request.functions)
104
+ tool_choice = cls._function_call_to_tool_choice(
105
+ request.function_call
106
+ )
107
+ elif request.tools is not None:
108
+ tools_mode = ToolsMode.TOOLS
109
+ tools = cls._get_tools_from_functions(request.tools)
110
+ tool_choice = request.tool_choice
111
+ elif tool_ids:
112
+ tools_mode = ToolsMode.TOOLS
113
+ tools = []
114
+ tool_choice = None
115
+ else:
116
+ return None
117
+
118
+ return cls(
119
+ tools=tools,
120
+ tools_mode=tools_mode,
121
+ tool_choice=tool_choice or "auto",
122
+ tool_ids=tool_ids,
123
+ )
124
+
125
+
126
def validate_messages(request: AzureChatCompletionRequest) -> None:
    """Check the request's tool/function declarations against its messages.

    Declaring both ``functions`` and ``tools`` is an error. A message that
    refers to tools or functions which the request does not declare only
    produces a logged warning; the request is not rejected.

    Raises:
        ValidationError: if both ``functions`` and ``tools`` are declared.
    """
    decl_tools = request.tools is not None
    decl_functions = request.functions is not None

    if decl_functions and decl_tools:
        raise ValidationError("Both functions and tools are not allowed")

    def warn(msg: str) -> None:
        # Lazy %-formatting: the message is only built if WARNING is enabled.
        _log.warning(
            "The request is incomplete: %s. The model may misbehave.", msg
        )

    tool_defs_are_missing = (
        "the request is missing tool definitions in the 'tools' field"
    )
    func_defs_are_missing = (
        "the request is missing function definitions in the 'functions' field"
    )

    for idx, message in enumerate(request.messages):
        if message.role == Role.ASSISTANT:
            if message.tool_calls is not None and not decl_tools:
                warn(
                    f"'messages[{idx}]' is an Assistant message with a tool call, but {tool_defs_are_missing}"
                )
            if message.function_call is not None and not decl_functions:
                warn(
                    f"'messages[{idx}]' is an Assistant message with a function call, but {func_defs_are_missing}"
                )
        if message.role == Role.FUNCTION and not decl_functions:
            warn(
                f"'messages[{idx}]' is a Function message, but {func_defs_are_missing}"
            )
        if message.role == Role.TOOL and not decl_tools:
            warn(
                f"'messages[{idx}]' is a Tool message, but {tool_defs_are_missing}"
            )
170
+
171
+
172
def _collect_tool_ids(messages: List[Message]) -> Dict[str, str]:
    """Map each tool call ID found in assistant messages to its function name.

    If an ID occurs more than once, the later occurrence wins.
    """
    return {
        call.id: call.function.name
        for msg in messages
        if msg.role == Role.ASSISTANT and msg.tool_calls is not None
        for call in msg.tool_calls
    }