aidial-adapter-anthropic 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aidial_adapter_anthropic/_utils/json.py +116 -0
- aidial_adapter_anthropic/_utils/list.py +84 -0
- aidial_adapter_anthropic/_utils/pydantic.py +6 -0
- aidial_adapter_anthropic/_utils/resource.py +54 -0
- aidial_adapter_anthropic/_utils/text.py +4 -0
- aidial_adapter_anthropic/adapter/__init__.py +4 -0
- aidial_adapter_anthropic/adapter/_base.py +95 -0
- aidial_adapter_anthropic/adapter/_claude/adapter.py +549 -0
- aidial_adapter_anthropic/adapter/_claude/blocks.py +128 -0
- aidial_adapter_anthropic/adapter/_claude/citations.py +63 -0
- aidial_adapter_anthropic/adapter/_claude/config.py +39 -0
- aidial_adapter_anthropic/adapter/_claude/converters.py +303 -0
- aidial_adapter_anthropic/adapter/_claude/params.py +25 -0
- aidial_adapter_anthropic/adapter/_claude/state.py +45 -0
- aidial_adapter_anthropic/adapter/_claude/tokenizer/__init__.py +10 -0
- aidial_adapter_anthropic/adapter/_claude/tokenizer/anthropic.py +57 -0
- aidial_adapter_anthropic/adapter/_claude/tokenizer/approximate.py +260 -0
- aidial_adapter_anthropic/adapter/_claude/tokenizer/base.py +26 -0
- aidial_adapter_anthropic/adapter/_claude/tools.py +98 -0
- aidial_adapter_anthropic/adapter/_decorator/base.py +53 -0
- aidial_adapter_anthropic/adapter/_decorator/preprocess.py +63 -0
- aidial_adapter_anthropic/adapter/_decorator/replicator.py +32 -0
- aidial_adapter_anthropic/adapter/_errors.py +71 -0
- aidial_adapter_anthropic/adapter/_tokenize.py +12 -0
- aidial_adapter_anthropic/adapter/_truncate_prompt.py +168 -0
- aidial_adapter_anthropic/adapter/claude.py +17 -0
- aidial_adapter_anthropic/dial/_attachments.py +238 -0
- aidial_adapter_anthropic/dial/_lazy_stage.py +40 -0
- aidial_adapter_anthropic/dial/_message.py +341 -0
- aidial_adapter_anthropic/dial/consumer.py +235 -0
- aidial_adapter_anthropic/dial/request.py +170 -0
- aidial_adapter_anthropic/dial/resource.py +189 -0
- aidial_adapter_anthropic/dial/storage.py +138 -0
- aidial_adapter_anthropic/dial/token_usage.py +19 -0
- aidial_adapter_anthropic/dial/tools.py +180 -0
- aidial_adapter_anthropic-0.1.0.dist-info/LICENSE +202 -0
- aidial_adapter_anthropic-0.1.0.dist-info/METADATA +121 -0
- aidial_adapter_anthropic-0.1.0.dist-info/RECORD +39 -0
- aidial_adapter_anthropic-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utilities for pretty-printing JSON in debug logs.
|
|
3
|
+
These functions are useful for dumping large data structures,
|
|
4
|
+
with options to trim long strings and lists to specified limits.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
from dataclasses import asdict, is_dataclass
|
|
9
|
+
from enum import Enum
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from anthropic import Omit
|
|
13
|
+
from pydantic import BaseModel
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def json_dumps_short(
    obj: Any, *, string_limit: int = 100, list_len_limit: int = 10, **kwargs
) -> str:
    """Serialize ``obj`` to JSON with long strings and long lists shortened.

    The object is first converted with ``_to_dict`` (which receives
    ``kwargs``), then strings longer than ``string_limit`` are shortened,
    then lists longer than ``list_len_limit`` are shortened.

    Values that ``json.dumps`` cannot encode are rendered via ``str`` and
    shortened to ``string_limit`` as well.
    """

    def _stringify_unknown(value) -> str:
        # Fallback encoder for objects the json module doesn't know about.
        return _truncate_strings(str(value), string_limit)

    plain = _to_dict(obj, **kwargs)
    with_short_strings = _truncate_strings(plain, string_limit)
    with_short_lists = _truncate_lists(with_short_strings, list_len_limit)

    return json.dumps(with_short_lists, default=_stringify_unknown)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def json_dumps(obj: Any, **kwargs) -> str:
    """Serialize ``obj`` to JSON after converting it with ``_to_dict``.

    ``kwargs`` are forwarded to ``_to_dict`` unchanged; nothing is truncated.
    """
    plain = _to_dict(obj, **kwargs)
    return json.dumps(plain)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _to_dict(obj: Any, **kwargs) -> Any:
|
|
36
|
+
def rec(val):
|
|
37
|
+
return _to_dict(val, **kwargs)
|
|
38
|
+
|
|
39
|
+
def dict_field(key: str, val: Any) -> Any:
|
|
40
|
+
if key in kwargs.get("excluded_keys", []):
|
|
41
|
+
return "<excluded>"
|
|
42
|
+
return val
|
|
43
|
+
|
|
44
|
+
if isinstance(obj, bytes):
|
|
45
|
+
return f"<bytes>({len(obj):_} B)"
|
|
46
|
+
|
|
47
|
+
if isinstance(obj, Enum):
|
|
48
|
+
return obj.value
|
|
49
|
+
|
|
50
|
+
if isinstance(obj, dict):
|
|
51
|
+
return {key: rec(dict_field(key, value)) for key, value in obj.items()}
|
|
52
|
+
|
|
53
|
+
if isinstance(obj, list):
|
|
54
|
+
return [rec(element) for element in obj]
|
|
55
|
+
|
|
56
|
+
if isinstance(obj, tuple):
|
|
57
|
+
return tuple(rec(element) for element in obj)
|
|
58
|
+
|
|
59
|
+
if isinstance(obj, BaseModel):
|
|
60
|
+
return rec(obj.dict())
|
|
61
|
+
|
|
62
|
+
if hasattr(obj, "to_dict"):
|
|
63
|
+
return rec(obj.to_dict())
|
|
64
|
+
|
|
65
|
+
if is_dataclass(type(obj)):
|
|
66
|
+
return rec(asdict(obj))
|
|
67
|
+
|
|
68
|
+
if isinstance(obj, Omit):
|
|
69
|
+
return "omit"
|
|
70
|
+
|
|
71
|
+
return obj
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _truncate_strings(obj: Any, limit: int) -> Any:
|
|
75
|
+
def rec(val):
|
|
76
|
+
return _truncate_strings(val, limit)
|
|
77
|
+
|
|
78
|
+
if isinstance(obj, dict):
|
|
79
|
+
return {key: rec(value) for key, value in obj.items()}
|
|
80
|
+
|
|
81
|
+
if isinstance(obj, list):
|
|
82
|
+
return [rec(element) for element in obj]
|
|
83
|
+
|
|
84
|
+
if isinstance(obj, tuple):
|
|
85
|
+
return tuple(rec(element) for element in obj)
|
|
86
|
+
|
|
87
|
+
if isinstance(obj, str) and len(obj) > limit:
|
|
88
|
+
skip = len(obj) - limit
|
|
89
|
+
return (
|
|
90
|
+
obj[: limit // 2] + f"...({skip:_} skipped)..." + obj[-limit // 2 :]
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
return obj
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _truncate_lists(obj: Any, limit: int) -> Any:
|
|
97
|
+
def rec(val):
|
|
98
|
+
return _truncate_lists(val, limit)
|
|
99
|
+
|
|
100
|
+
if isinstance(obj, dict):
|
|
101
|
+
return {key: rec(value) for key, value in obj.items()}
|
|
102
|
+
|
|
103
|
+
if isinstance(obj, list):
|
|
104
|
+
if len(obj) > limit:
|
|
105
|
+
skip = len(obj) - limit
|
|
106
|
+
obj = (
|
|
107
|
+
obj[: limit // 2]
|
|
108
|
+
+ [f"...({skip:_} skipped)..."]
|
|
109
|
+
+ obj[-limit // 2 :]
|
|
110
|
+
)
|
|
111
|
+
return [rec(element) for element in obj]
|
|
112
|
+
|
|
113
|
+
if isinstance(obj, tuple):
|
|
114
|
+
return tuple(rec(element) for element in obj)
|
|
115
|
+
|
|
116
|
+
return obj
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import (
|
|
3
|
+
Any,
|
|
4
|
+
AsyncIterator,
|
|
5
|
+
Callable,
|
|
6
|
+
Container,
|
|
7
|
+
Generic,
|
|
8
|
+
Iterable,
|
|
9
|
+
List,
|
|
10
|
+
Self,
|
|
11
|
+
Set,
|
|
12
|
+
Tuple,
|
|
13
|
+
TypeVar,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
_T = TypeVar("_T")
|
|
17
|
+
_V = TypeVar("_V")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def select_by_indices(lst: List[_T], indices: Container[int]) -> List[_T]:
    """Return the elements of ``lst`` whose positions are in ``indices``.

    The original order of ``lst`` is preserved.
    """
    picked: List[_T] = []
    for position, item in enumerate(lst):
        if position in indices:
            picked.append(item)
    return picked
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def omit_by_indices(lst: List[_T], indices: Container[int]) -> List[_T]:
    """Return the elements of ``lst`` whose positions are NOT in ``indices``.

    The original order of ``lst`` is preserved.
    """
    remaining: List[_T] = []
    for position, item in enumerate(lst):
        if position in indices:
            continue
        remaining.append(item)
    return remaining
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def group_by(
    lst: List[_T],
    key: Callable[[_T], Any],
    init: Callable[[_T], _V],
    merge: Callable[[_V, _T], _V],
) -> List[_V]:
    """Fold runs of adjacent elements sharing the same key into single values.

    The first element of each run is turned into an accumulator with
    ``init``; every following element with an equal ``key`` is folded in
    with ``merge``. Only neighbouring elements are grouped: equal keys
    separated by a different key produce separate groups.

    Returns one accumulated value per run, in order; an empty input yields
    an empty list.
    """
    groups: List[_V] = []
    if not lst:
        return groups

    first = lst[0]
    current = init(first)
    current_key = key(first)

    for elem in lst[1:]:
        if current_key == key(elem):
            current = merge(current, elem)
            continue

        # Key changed: close the running group and start a new one.
        groups.append(current)
        current = init(elem)
        current_key = key(elem)

    groups.append(current)
    return groups
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@dataclass
class ListProjection(Generic[_T]):
    """
    A derived list whose elements remember where they came from.

    The projection describes a transformation of some original list that may
    merge, remove or add elements. Every derived element is stored together
    with the set of original indices it corresponds to.

    The index sets of different elements must be disjoint.
    """

    # Pairs of (derived element, indices of the original elements behind it).
    list: List[Tuple[_T, Set[int]]] = field(default_factory=list)

    @property
    def raw_list(self) -> List[_T]:
        """The derived elements alone, without their index sets."""
        elements: List[_T] = []
        for element, _ in self.list:
            elements.append(element)
        return elements

    def to_original_indices(self, indices: Iterable[int]) -> Set[int]:
        """Map positions in the projection to the union of original indices."""
        originals: Set[int] = set()
        for index in indices:
            originals.update(self.list[index][1])
        return originals

    def append(self, elem: _T, idx: int) -> Self:
        """Add ``elem`` mapped onto the single original index ``idx``.

        Returns the projection itself so that calls can be chained.
        """
        entry = (elem, {idx})
        self.list.append(entry)
        return self
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
async def aiter_to_list(iterator: AsyncIterator[_T]) -> List[_T]:
    """Exhaust an async iterator and return its items as a list."""
    collected: List[_T] = []
    async for item in iterator:
        collected.append(item)
    return collected
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import re
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Resource(BaseModel):
    """Binary payload together with its content type.

    Can be built from and rendered to a base64 data URL of the form
    ``data:<type>;base64,<payload>``.
    """

    # Content type placed between "data:" and ";base64," in the data URL.
    type: str
    # Raw (decoded) payload.
    data: bytes

    @classmethod
    def from_base64(cls, type: str, data_base64: str) -> "Resource":
        """Create a resource from a base64-encoded payload.

        Raises:
            ValueError: if ``data_base64`` is not valid base64. The original
                decoding error is attached as the cause.
        """
        try:
            data = base64.b64decode(data_base64, validate=True)
        except Exception as exc:
            # Keep the single ValueError contract for callers, but preserve
            # the underlying error for debugging.
            raise ValueError("Invalid base64 data") from exc

        return cls(type=type, data=data)

    @classmethod
    def from_data_url(cls, data_url: str) -> Optional["Resource"]:
        """
        Parsing a resource encoded as a data URL.
        See https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/Data_URLs for reference.

        Returns None when ``data_url`` doesn't start with
        ``data:<type>;base64,``. Raises ValueError when the prefix matches
        but the payload is not valid base64.
        """

        type = cls.parse_data_url_content_type(data_url)
        if type is None:
            return None

        data_base64 = data_url.removeprefix(cls._to_data_url_prefix(type))

        return cls.from_base64(type, data_base64)

    @property
    def data_base64(self) -> str:
        """The payload encoded as a base64 string."""
        return base64.b64encode(self.data).decode()

    def to_data_url(self) -> str:
        """Render the resource as a ``data:<type>;base64,<payload>`` URL."""
        return f"{self._to_data_url_prefix(self.type)}{self.data_base64}"

    @staticmethod
    def parse_data_url_content_type(data_url: str) -> Optional[str]:
        """Extract the content type from a base64 data URL.

        Returns None when the string doesn't start with
        ``data:<type>;base64,`` where ``<type>`` contains no semicolon.
        """
        pattern = r"^data:([^;]+);base64,"
        match = re.match(pattern, data_url)
        return None if match is None else match.group(1)

    @staticmethod
    def _to_data_url_prefix(content_type: str) -> str:
        """Build the part of a data URL that precedes the base64 payload."""
        return f"data:{content_type};base64,"

    def __str__(self) -> str:
        """Return the first 100 characters of the data URL followed by "..."."""
        preview_chars = 100
        # 75 bytes encode to exactly 100 base64 characters without padding,
        # so encoding only this prefix yields the same leading characters as
        # encoding the whole (possibly very large) payload.
        preview_bytes = 75

        prefix = self._to_data_url_prefix(self.type)
        encoded_head = base64.b64encode(self.data[:preview_bytes]).decode()
        return (prefix + encoded_head)[:preview_chars] + "..."
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any, List, Set, Tuple, Type
|
|
3
|
+
|
|
4
|
+
from aidial_sdk.chat_completion import Message
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
from aidial_adapter_anthropic._utils.list import ListProjection
|
|
8
|
+
from aidial_adapter_anthropic.adapter._errors import ValidationError
|
|
9
|
+
from aidial_adapter_anthropic.adapter._truncate_prompt import DiscardedMessages
|
|
10
|
+
from aidial_adapter_anthropic.dial.consumer import Consumer
|
|
11
|
+
from aidial_adapter_anthropic.dial.request import (
|
|
12
|
+
ModelParameters,
|
|
13
|
+
collect_text_content,
|
|
14
|
+
is_system_role,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ChatCompletionAdapter(ABC, BaseModel):
    """Abstract base for chat completion adapters.

    Only ``chat`` is abstract. The remaining coroutines raise
    ``NotImplementedError`` unless a subclass overrides them.
    """

    class Config:
        arbitrary_types_allowed = True

    @abstractmethod
    async def chat(
        self,
        consumer: Consumer,
        params: ModelParameters,
        messages: List[Message],
    ) -> None:
        """Abstract method: must be implemented by every concrete adapter."""

    async def configuration(self) -> Type[BaseModel]:
        """Not implemented in the base class."""
        raise NotImplementedError()

    async def count_prompt_tokens(
        self, params: ModelParameters, messages: List[Message]
    ) -> int:
        """Not implemented in the base class."""
        raise NotImplementedError()

    async def count_completion_tokens(self, string: str) -> int:
        """Not implemented in the base class."""
        raise NotImplementedError()

    async def compute_discarded_messages(
        self, params: ModelParameters, messages: List[Message]
    ) -> DiscardedMessages | None:
        """
        Truncate the list of messages so that it fits into the token limit
        given by `params.max_prompt_tokens`.

        Returns None if the limit isn't provided. Otherwise, returns the
        indices of the _discarded_ messages, i.e. the ones that should be
        removed from the list to make the rest fit into the token limit.

        Not implemented in the base class.
        """
        raise NotImplementedError()
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def default_preprocess_messages(
    messages: List[Message],
) -> ListProjection[Message]:
    """Drop system messages with blank text and project the remaining ones.

    A system message is dropped when its collected text content is empty
    after stripping whitespace. The index of each dropped message is added
    to the index set of the next kept message, so every kept message maps to
    its own index plus the indices of the dropped messages right before it.
    Indices of dropped messages that follow the last kept message are not
    attached to any element.

    Raises:
        ValidationError: if no message remains after the filtering.
    """
    projected: List[Tuple[Message, Set[int]]] = []
    pending: Set[int] = set()

    for position, message in enumerate(messages):
        pending.add(position)

        is_blank_system = (
            is_system_role(message.role)
            and collect_text_content(message.content).strip() == ""
        )
        if not is_blank_system:
            projected.append((message, pending))
            # Start a fresh set: the previous one now belongs to the entry.
            pending = set()

    if not projected:
        raise ValidationError("List of messages must not be empty")

    return ListProjection(projected)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def keep_last(messages: List[Any], idx: int) -> bool:
    """Tell whether ``idx`` is the index of the last element of ``messages``."""
    last_index = len(messages) - 1
    return idx == last_index
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def keep_last_and_system_messages(messages: List[Message], idx: int) -> bool:
    """Tell whether the message at ``idx`` is a system message or the last one.

    The role check runs first; ``keep_last`` is consulted only when the
    message doesn't have a system role.
    """
    role = messages[idx].role
    return is_system_role(role) or keep_last(messages, idx)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def trivial_partitioner(messages: List[Any]) -> List[int]:
    """Return a list with one ``1`` per message."""
    return [1 for _ in messages]
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def turn_based_partitioner(messages: List[Any]) -> List[int]:
    """Return partition sizes of two, plus a final ``1`` for an odd count.

    The sizes always sum up to ``len(messages)``.
    """
    pairs, leftover = divmod(len(messages), 2)
    return [2] * pairs + [1] * leftover
|