aidial-adapter-anthropic 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39):
  1. aidial_adapter_anthropic/_utils/json.py +116 -0
  2. aidial_adapter_anthropic/_utils/list.py +84 -0
  3. aidial_adapter_anthropic/_utils/pydantic.py +6 -0
  4. aidial_adapter_anthropic/_utils/resource.py +54 -0
  5. aidial_adapter_anthropic/_utils/text.py +4 -0
  6. aidial_adapter_anthropic/adapter/__init__.py +4 -0
  7. aidial_adapter_anthropic/adapter/_base.py +95 -0
  8. aidial_adapter_anthropic/adapter/_claude/adapter.py +549 -0
  9. aidial_adapter_anthropic/adapter/_claude/blocks.py +128 -0
  10. aidial_adapter_anthropic/adapter/_claude/citations.py +63 -0
  11. aidial_adapter_anthropic/adapter/_claude/config.py +39 -0
  12. aidial_adapter_anthropic/adapter/_claude/converters.py +303 -0
  13. aidial_adapter_anthropic/adapter/_claude/params.py +25 -0
  14. aidial_adapter_anthropic/adapter/_claude/state.py +45 -0
  15. aidial_adapter_anthropic/adapter/_claude/tokenizer/__init__.py +10 -0
  16. aidial_adapter_anthropic/adapter/_claude/tokenizer/anthropic.py +57 -0
  17. aidial_adapter_anthropic/adapter/_claude/tokenizer/approximate.py +260 -0
  18. aidial_adapter_anthropic/adapter/_claude/tokenizer/base.py +26 -0
  19. aidial_adapter_anthropic/adapter/_claude/tools.py +98 -0
  20. aidial_adapter_anthropic/adapter/_decorator/base.py +53 -0
  21. aidial_adapter_anthropic/adapter/_decorator/preprocess.py +63 -0
  22. aidial_adapter_anthropic/adapter/_decorator/replicator.py +32 -0
  23. aidial_adapter_anthropic/adapter/_errors.py +71 -0
  24. aidial_adapter_anthropic/adapter/_tokenize.py +12 -0
  25. aidial_adapter_anthropic/adapter/_truncate_prompt.py +168 -0
  26. aidial_adapter_anthropic/adapter/claude.py +17 -0
  27. aidial_adapter_anthropic/dial/_attachments.py +238 -0
  28. aidial_adapter_anthropic/dial/_lazy_stage.py +40 -0
  29. aidial_adapter_anthropic/dial/_message.py +341 -0
  30. aidial_adapter_anthropic/dial/consumer.py +235 -0
  31. aidial_adapter_anthropic/dial/request.py +170 -0
  32. aidial_adapter_anthropic/dial/resource.py +189 -0
  33. aidial_adapter_anthropic/dial/storage.py +138 -0
  34. aidial_adapter_anthropic/dial/token_usage.py +19 -0
  35. aidial_adapter_anthropic/dial/tools.py +180 -0
  36. aidial_adapter_anthropic-0.1.0.dist-info/LICENSE +202 -0
  37. aidial_adapter_anthropic-0.1.0.dist-info/METADATA +121 -0
  38. aidial_adapter_anthropic-0.1.0.dist-info/RECORD +39 -0
  39. aidial_adapter_anthropic-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,116 @@
1
+ """
2
+ Utilities for pretty-printing JSON in debug logs.
3
+ These functions are useful for dumping large data structures,
4
+ with options to trim long strings and lists to specified limits.
5
+ """
6
+
7
+ import json
8
+ from dataclasses import asdict, is_dataclass
9
+ from enum import Enum
10
+ from typing import Any
11
+
12
+ from anthropic import Omit
13
+ from pydantic import BaseModel
14
+
15
+
16
def json_dumps_short(
    obj: Any, *, string_limit: int = 100, list_len_limit: int = 10, **kwargs
) -> str:
    """
    Serialize ``obj`` to a compact JSON string intended for debug logs.

    The value is first normalized with ``_to_dict`` (``kwargs`` such as
    ``excluded_keys`` are forwarded there). Then strings longer than
    ``string_limit`` characters and lists longer than ``list_len_limit``
    elements are shortened. Values that ``json`` still cannot serialize are
    rendered with ``str()`` and truncated to ``string_limit`` as well.
    """

    def _fallback(value) -> str:
        return _truncate_strings(str(value), string_limit)

    normalized = _to_dict(obj, **kwargs)
    short_strings = _truncate_strings(normalized, string_limit)
    shortened = _truncate_lists(short_strings, list_len_limit)
    return json.dumps(shortened, default=_fallback)
29
+
30
+
31
def json_dumps(obj: Any, **kwargs) -> str:
    """
    Serialize ``obj`` to JSON after normalizing it with ``_to_dict``.

    Nothing is truncated. ``kwargs`` (e.g. ``excluded_keys``) are forwarded
    to ``_to_dict``.
    """
    normalized = _to_dict(obj, **kwargs)
    return json.dumps(normalized)
33
+
34
+
35
def _to_dict(obj: Any, **kwargs) -> Any:
    """
    Recursively convert ``obj`` into plain, JSON-friendly Python values.

    Rules are checked in this order:
      * ``bytes`` -> ``"<bytes>(N B)"`` placeholder (the payload is dropped);
      * ``Enum`` -> its ``value`` (returned as-is, not converted further);
      * ``dict`` -> dict with converted values; the value of any key listed
        in ``kwargs["excluded_keys"]`` is replaced by ``"<excluded>"``;
      * ``list`` / ``tuple`` -> same container type, elements converted;
      * pydantic ``BaseModel`` -> ``obj.dict()``, converted;
      * objects with a ``to_dict`` attribute -> ``obj.to_dict()``, converted;
      * dataclass instances -> ``asdict(obj)``, converted;
      * anthropic ``Omit`` -> the string ``"omit"``;
      * anything else is returned unchanged.
    """
    excluded_keys = kwargs.get("excluded_keys", [])

    def convert(value: Any) -> Any:
        return _to_dict(value, **kwargs)

    if isinstance(obj, bytes):
        return f"<bytes>({len(obj):_} B)"
    if isinstance(obj, Enum):
        return obj.value
    if isinstance(obj, dict):
        return {
            key: convert("<excluded>" if key in excluded_keys else value)
            for key, value in obj.items()
        }
    if isinstance(obj, list):
        return [convert(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(convert(item) for item in obj)
    if isinstance(obj, BaseModel):
        return convert(obj.dict())
    if hasattr(obj, "to_dict"):
        return convert(obj.to_dict())
    # Checking the type (not the object) excludes dataclass *classes*.
    if is_dataclass(type(obj)):
        return convert(asdict(obj))
    if isinstance(obj, Omit):
        return "omit"
    return obj
72
+
73
+
74
def _truncate_strings(obj: Any, limit: int) -> Any:
    """
    Recursively shorten strings longer than ``limit`` characters.

    A long string keeps its first ``limit // 2`` and last
    ``limit - limit // 2`` characters, joined by a ``"...(N skipped)..."``
    marker, where N is the number of characters removed. Dicts (values only),
    lists and tuples are traversed; every other value is returned unchanged.
    A non-positive ``limit`` keeps no characters, only the marker.
    """

    def rec(val: Any) -> Any:
        return _truncate_strings(val, limit)

    if isinstance(obj, dict):
        return {key: rec(value) for key, value in obj.items()}

    if isinstance(obj, list):
        return [rec(element) for element in obj]

    if isinstance(obj, tuple):
        return tuple(rec(element) for element in obj)

    if isinstance(obj, str):
        keep = max(limit, 0)
        if len(obj) > keep:
            head = keep // 2
            # Slice the tail by absolute position: a relative `obj[-0:]`
            # would yield the whole string when nothing should be kept.
            tail_start = len(obj) - (keep - head)
            skip = len(obj) - keep
            return obj[:head] + f"...({skip:_} skipped)..." + obj[tail_start:]

    return obj
94
+
95
+
96
def _truncate_lists(obj: Any, limit: int) -> Any:
    """
    Recursively shorten lists longer than ``limit`` elements.

    A long list keeps its first ``limit // 2`` and last ``limit - limit // 2``
    elements, with a ``"...(N skipped)..."`` marker string in between, where
    N is the number of elements removed. Dicts (values only), lists and
    tuples are traversed; tuples themselves are never shortened. A
    non-positive ``limit`` keeps no elements, only the marker.
    """

    def rec(val: Any) -> Any:
        return _truncate_lists(val, limit)

    if isinstance(obj, dict):
        return {key: rec(value) for key, value in obj.items()}

    if isinstance(obj, list):
        keep = max(limit, 0)
        if len(obj) > keep:
            head = keep // 2
            # Slice the tail by absolute position: a relative `obj[-0:]`
            # would yield the whole list when nothing should be kept.
            tail_start = len(obj) - (keep - head)
            skip = len(obj) - keep
            obj = obj[:head] + [f"...({skip:_} skipped)..."] + obj[tail_start:]
        return [rec(element) for element in obj]

    if isinstance(obj, tuple):
        return tuple(rec(element) for element in obj)

    return obj
@@ -0,0 +1,84 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import (
3
+ Any,
4
+ AsyncIterator,
5
+ Callable,
6
+ Container,
7
+ Generic,
8
+ Iterable,
9
+ List,
10
+ Self,
11
+ Set,
12
+ Tuple,
13
+ TypeVar,
14
+ )
15
+
16
# Generic element type of the lists handled by the helpers below.
_T = TypeVar("_T")
# Generic type of the per-group accumulator produced by `group_by`.
_V = TypeVar("_V")
18
+
19
+
20
def select_by_indices(lst: List[_T], indices: Container[int]) -> List[_T]:
    """Return the elements of ``lst`` whose positions are in ``indices``, preserving order."""
    return [item for position, item in enumerate(lst) if position in indices]
22
+
23
+
24
def omit_by_indices(lst: List[_T], indices: Container[int]) -> List[_T]:
    """Return ``lst`` without the elements whose positions are in ``indices``, preserving order."""
    return [item for position, item in enumerate(lst) if position not in indices]
26
+
27
+
28
def group_by(
    lst: List[_T],
    key: Callable[[_T], Any],
    init: Callable[[_T], _V],
    merge: Callable[[_V, _T], _V],
) -> List[_V]:
    """
    Fold runs of *adjacent* elements that share the same key.

    Only consecutive elements are grouped, so equal keys separated by a
    different key produce separate groups.

    Args:
        lst: elements to group.
        key: computes the grouping key. It is called exactly once per element.
        init: builds a group's accumulator from the group's first element.
        merge: folds a further element of the same run into the accumulator.

    Returns:
        One accumulator per run, in input order; ``[]`` for an empty ``lst``.
    """
    groups: List[_V] = []
    if not lst:
        return groups

    current_val = init(lst[0])
    current_key = key(lst[0])

    for elem in lst[1:]:
        # Compute the key once; the old code re-ran `key` on every group start.
        elem_key = key(elem)
        if current_key == elem_key:
            current_val = merge(current_val, elem)
        else:
            groups.append(current_val)
            current_val = init(elem)
            current_key = elem_key

    groups.append(current_val)
    return groups
53
+
54
+
55
+ @dataclass
56
+ class ListProjection(Generic[_T]):
57
+ """
58
+ The class represents a transformation of the original list which may
59
+ include merge, removal and addition of the original list elements.
60
+
61
+ Each derivative element is mapped onto a subset of original elements.
62
+ The subsets must be disjoint.
63
+ """
64
+
65
+ list: List[Tuple[_T, Set[int]]] = field(default_factory=list)
66
+
67
+ @property
68
+ def raw_list(self) -> List[_T]:
69
+ return [msg for msg, _ in self.list]
70
+
71
+ def to_original_indices(self, indices: Iterable[int]) -> Set[int]:
72
+ return {
73
+ orig_index
74
+ for index in indices
75
+ for orig_index in self.list[index][1]
76
+ }
77
+
78
+ def append(self, elem: _T, idx: int) -> Self:
79
+ self.list.append((elem, {idx}))
80
+ return self
81
+
82
+
83
async def aiter_to_list(iterator: AsyncIterator[_T]) -> List[_T]:
    """Drain ``iterator`` and return every item it yielded, in order."""
    collected: List[_T] = []
    async for item in iterator:
        collected.append(item)
    return collected
@@ -0,0 +1,6 @@
1
+ from pydantic import BaseModel
2
+
3
+
4
# Base pydantic model for strict schemas: subclasses inherit the
# configuration below.
class ExtraForbidModel(BaseModel):
    class Config:
        # Input keys that do not correspond to a declared field are rejected
        # with a validation error rather than ignored.
        extra = "forbid"
@@ -0,0 +1,54 @@
1
+ import base64
2
+ import re
3
+ from typing import Optional
4
+
5
+ from pydantic import BaseModel
6
+
7
+
8
class Resource(BaseModel):
    """
    A binary payload together with its content (MIME) type.

    Instances can be built from raw base64 or from a base64 ``data:`` URL
    (``data:<content-type>;base64,<payload>``), and rendered back to a
    data URL.
    """

    type: str
    data: bytes

    @classmethod
    def from_base64(cls, type: str, data_base64: str) -> "Resource":
        """
        Build a resource from a base64-encoded payload.

        Decoding is strict (``validate=True``): characters outside the
        base64 alphabet are rejected instead of being silently dropped.

        Raises:
            ValueError: if ``data_base64`` is not valid base64.
        """
        try:
            data = base64.b64decode(data_base64, validate=True)
        except (ValueError, TypeError) as e:
            # binascii.Error is a ValueError subclass; TypeError covers
            # non-str/bytes input. Chain so the decoder's reason is kept.
            raise ValueError("Invalid base64 data") from e

        return cls(type=type, data=data)

    @classmethod
    def from_data_url(cls, data_url: str) -> Optional["Resource"]:
        """
        Parse a resource encoded as a base64 data URL.

        See https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/Data_URLs for reference.

        Returns ``None`` if ``data_url`` does not start with
        ``data:<type>;base64,``. URLs carrying extra parameters between the
        type and ``;base64`` (e.g. ``;charset=...``) do not match.

        Raises:
            ValueError: if the payload is not valid base64.
        """
        type = cls.parse_data_url_content_type(data_url)
        if type is None:
            return None

        data_base64 = data_url.removeprefix(cls._to_data_url_prefix(type))

        return cls.from_base64(type, data_base64)

    @property
    def data_base64(self) -> str:
        """The payload encoded as a base64 string."""
        return base64.b64encode(self.data).decode()

    def to_data_url(self) -> str:
        """Render the resource as ``data:<type>;base64,<payload>``."""
        return f"{self._to_data_url_prefix(self.type)}{self.data_base64}"

    @staticmethod
    def parse_data_url_content_type(data_url: str) -> Optional[str]:
        """Return the content type of a base64 data URL, or ``None`` if it does not match."""
        pattern = r"^data:([^;]+);base64,"
        match = re.match(pattern, data_url)
        return None if match is None else match.group(1)

    @staticmethod
    def _to_data_url_prefix(content_type: str) -> str:
        return f"data:{content_type};base64,"

    def __str__(self) -> str:
        """A 100-character preview of the data URL followed by ``"..."``."""
        limit = 100
        prefix = self._to_data_url_prefix(self.type)
        # Encode only the bytes the preview needs rather than the whole
        # (possibly large) payload. Every 3 input bytes become 4 base64
        # characters, so encoding a multiple-of-3 byte prefix reproduces the
        # full encoding's leading characters exactly.
        needed_chars = max(limit - len(prefix), 0)
        needed_bytes = (needed_chars + 3) // 4 * 3
        preview = base64.b64encode(self.data[:needed_bytes]).decode()
        return (prefix + preview)[:limit] + "..."
@@ -0,0 +1,4 @@
1
def truncate_string(s: str, n: int) -> str:
    """
    Return ``s`` unchanged if it has at most ``n`` characters.

    Otherwise return its first ``n`` characters followed by ``"..."``. The
    result may therefore be up to 3 characters longer than ``n``.
    """
    return s if len(s) <= n else f"{s[:n]}..."
@@ -0,0 +1,4 @@
1
+ from aidial_adapter_anthropic.adapter._base import ChatCompletionAdapter
2
+ from aidial_adapter_anthropic.adapter._errors import UserError, ValidationError
3
+
4
# Names re-exported as the public API of the `adapter` package.
__all__ = ["ChatCompletionAdapter", "UserError", "ValidationError"]
@@ -0,0 +1,95 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, List, Set, Tuple, Type
3
+
4
+ from aidial_sdk.chat_completion import Message
5
+ from pydantic import BaseModel
6
+
7
+ from aidial_adapter_anthropic._utils.list import ListProjection
8
+ from aidial_adapter_anthropic.adapter._errors import ValidationError
9
+ from aidial_adapter_anthropic.adapter._truncate_prompt import DiscardedMessages
10
+ from aidial_adapter_anthropic.dial.consumer import Consumer
11
+ from aidial_adapter_anthropic.dial.request import (
12
+ ModelParameters,
13
+ collect_text_content,
14
+ is_system_role,
15
+ )
16
+
17
+
18
# Abstract base class for chat-completion adapters. Subclasses must implement
# `chat`; every other hook is optional and raises NotImplementedError unless
# overridden.
class ChatCompletionAdapter(ABC, BaseModel):
    class Config:
        # Let pydantic accept field types it cannot validate natively.
        arbitrary_types_allowed = True

    @abstractmethod
    async def chat(
        self,
        consumer: Consumer,
        params: ModelParameters,
        messages: List[Message],
    ) -> None:
        """
        Handle a chat-completion request for ``messages`` under ``params``.

        Results are presumably reported through ``consumer``; see the
        concrete implementations.
        """

    async def configuration(self) -> Type[BaseModel]:
        """Optional hook returning a pydantic model type (presumably the adapter's configuration schema)."""
        raise NotImplementedError

    async def count_prompt_tokens(
        self, params: ModelParameters, messages: List[Message]
    ) -> int:
        """Optional hook: number of tokens in the prompt built from ``messages``."""
        raise NotImplementedError

    async def count_completion_tokens(self, string: str) -> int:
        """Optional hook: number of tokens in a completion text ``string``."""
        raise NotImplementedError

    async def compute_discarded_messages(
        self, params: ModelParameters, messages: List[Message]
    ) -> DiscardedMessages | None:
        """
        Fit ``messages`` into the token budget ``params.max_prompt_tokens``.

        Returns ``None`` when no limit is set. Otherwise returns the indices
        of the messages that must be dropped so that the remaining ones fit
        within the limit.
        """
        raise NotImplementedError
54
+
55
+
56
def default_preprocess_messages(
    messages: List[Message],
) -> ListProjection[Message]:
    """
    Drop system messages whose text content is blank (whitespace only).

    Each kept message is paired with its own index plus the indices of any
    blank system messages immediately preceding it, so the projection can be
    mapped back to the original list. Blank system messages after the last
    kept message are not attributed to any projected element.

    Raises:
        ValidationError: if no message remains.
    """
    projection: List[Tuple[Message, Set[int]]] = []
    pending: Set[int] = set()

    for index, message in enumerate(messages):
        pending.add(index)
        blank_system = (
            is_system_role(message.role)
            and collect_text_content(message.content).strip() == ""
        )
        if not blank_system:
            projection.append((message, pending))
            pending = set()

    if not projection:
        raise ValidationError("List of messages must not be empty")

    return ListProjection(projection)
79
+
80
+
81
def keep_last(messages: List[Any], idx: int) -> bool:
    """Predicate that is ``True`` only for the final position of ``messages``."""
    last_index = len(messages) - 1
    return idx == last_index
83
+
84
+
85
def keep_last_and_system_messages(messages: List[Message], idx: int) -> bool:
    """Predicate that is truthy for system-role messages and for the final message."""
    system_role = is_system_role(messages[idx].role)
    if system_role:
        return system_role
    return keep_last(messages, idx)
87
+
88
+
89
def trivial_partitioner(messages: List[Any]) -> List[int]:
    """Group sizes that put every message in its own group: ``1`` per message."""
    return [1 for _ in messages]
91
+
92
+
93
def turn_based_partitioner(messages: List[Any]) -> List[int]:
    """
    Group sizes pairing consecutive messages from the start.

    Returns a ``2`` for every full pair, plus a trailing ``1`` when the
    message count is odd. The sizes always sum to ``len(messages)``.
    """
    pairs, leftover = divmod(len(messages), 2)
    return [2] * pairs + [1] * leftover
+ return [2] * (n // 2) + [1] * (n % 2)