shinychat 0.0.1a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shinychat/__init__.py +3 -0
- shinychat/__version.py +21 -0
- shinychat/_chat.py +1878 -0
- shinychat/_chat_bookmark.py +110 -0
- shinychat/_chat_normalize.py +350 -0
- shinychat/_chat_provider_types.py +127 -0
- shinychat/_chat_tokenizer.py +67 -0
- shinychat/_chat_types.py +79 -0
- shinychat/_html_deps_py_shiny.py +41 -0
- shinychat/_markdown_stream.py +374 -0
- shinychat/_typing_extensions.py +63 -0
- shinychat/_utils.py +173 -0
- shinychat/express/__init__.py +3 -0
- shinychat/playwright/__init__.py +3 -0
- shinychat/playwright/_chat.py +154 -0
- shinychat/www/GIT_VERSION +1 -0
- shinychat/www/chat/chat.css +2 -0
- shinychat/www/chat/chat.css.map +7 -0
- shinychat/www/chat/chat.js +87 -0
- shinychat/www/chat/chat.js.map +7 -0
- shinychat/www/markdown-stream/markdown-stream.css +2 -0
- shinychat/www/markdown-stream/markdown-stream.css.map +7 -0
- shinychat/www/markdown-stream/markdown-stream.js +149 -0
- shinychat/www/markdown-stream/markdown-stream.js.map +7 -0
- shinychat-0.0.1a0.dist-info/METADATA +36 -0
- shinychat-0.0.1a0.dist-info/RECORD +28 -0
- shinychat-0.0.1a0.dist-info/WHEEL +4 -0
- shinychat-0.0.1a0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,110 @@
|
|
1
|
+
import importlib.util
|
2
|
+
from typing import Any, Awaitable, Callable, Protocol, runtime_checkable
|
3
|
+
|
4
|
+
from htmltools import TagChild
|
5
|
+
from shiny.types import Jsonifiable
|
6
|
+
|
7
|
+
# Computed once at module import: True when an import spec for the optional
# `chatlas` package can be found.
chatlas_is_installed = importlib.util.find_spec("chatlas") is not None
|
8
|
+
|
9
|
+
|
10
|
+
def is_chatlas_chat_client(client: Any) -> bool:
    """
    Report whether `client` is a `chatlas.Chat` instance.

    Returns `False` without importing anything when `chatlas_is_installed`
    is false.
    """
    if chatlas_is_installed:
        # Deferred import: chatlas is optional, so it is only imported once
        # the module-level check says it is available.
        import chatlas

        return isinstance(client, chatlas.Chat)
    return False
|
16
|
+
|
17
|
+
|
18
|
+
@runtime_checkable
class ClientWithState(Protocol):
    """
    Structural type for chat clients that can export and restore their state.

    Any object providing `get_state()` and `set_state()` methods satisfies this
    protocol; because it is `runtime_checkable`, it can be used with `isinstance()`.
    """

    async def get_state(self) -> Jsonifiable:
        """
        Retrieve JSON-like representation of chat client state.

        This method is used to retrieve the state of the client object when saving a bookmark.

        Returns
        -------
        :
            A JSON-like representation of the current state of the client. It is not required to be a JSON string but something that can be serialized to JSON without further conversion.
        """
        ...

    async def set_state(self, state: Jsonifiable):
        """
        Method to set the chat client state.

        This method is used to restore the state of the client when the app is restored from
        a bookmark.

        Parameters
        ----------
        state
            The value to infer the state from. This value will be the JSON capable value
            returned by the `get_state()` method (after a round trip through JSON
            serialization and unserialization).
        """
        ...
|
48
|
+
|
49
|
+
|
50
|
+
# Type alias: a callable that takes no arguments and returns nothing.
CancelCallback = Callable[[], None]
|
51
|
+
|
52
|
+
|
53
|
+
class BookmarkCancelCallback:
    """
    Callable wrapper around a cancel function.

    Calling the instance invokes the wrapped function; `tagify()` yields an
    empty string (presumably so the object can be placed in UI without
    rendering anything -- confirm against callers).
    """

    def __init__(self, cancel: CancelCallback):
        # Stored so `__call__` can invoke it later.
        self.cancel = cancel

    def tagify(self) -> TagChild:
        """Return an empty string as this object's rendered representation."""
        empty_markup = ""
        return empty_markup

    def __call__(self):
        """Invoke the wrapped cancel function, discarding its result."""
        cancel_fn = self.cancel
        cancel_fn()
|
62
|
+
|
63
|
+
|
64
|
+
# Chatlas specific implementation
|
65
|
+
def get_chatlas_state(
    client: Any,
) -> Callable[[], Awaitable[Jsonifiable]]:
    """
    Build an async state getter for a chatlas `Chat` client.

    Parameters
    ----------
    client
        The chat client; asserted to be a `chatlas.Chat` instance.

    Returns
    -------
    :
        An async function that, when awaited, returns a dict with a `"version"`
        of 1 and `"turns"` holding each turn dumped via `model_dump(mode="json")`.
    """
    from chatlas import Chat, Turn

    assert isinstance(client, Chat)

    async def get_state() -> Jsonifiable:
        current_turns: list[Turn[Any]] = client.get_turns()
        dumped_turns = []
        for each_turn in current_turns:
            dumped_turns.append(each_turn.model_dump(mode="json"))
        snapshot = {
            "version": 1,
            "turns": dumped_turns,
        }
        return snapshot

    return get_state
|
80
|
+
|
81
|
+
|
82
|
+
def set_chatlas_state(
    client: Any,
) -> Callable[[Jsonifiable], Awaitable[None]]:
    """
    Build an async state setter for a chatlas `Chat` client.

    Parameters
    ----------
    client
        The chat client; asserted to be a `chatlas.Chat` instance.

    Returns
    -------
    :
        An async function accepting a bookmark value. It raises `ValueError`
        when the value is not a dict, its `"version"` is not 1, or its `"turns"`
        entry is not a list; otherwise each entry is validated into a `Turn` and
        the result is passed to `client.set_turns()`.
    """
    from chatlas import Chat, Turn

    assert isinstance(client, Chat)

    # TODO-future: Use pydantic model for validation
    # instead of manual validation
    async def set_state(value: Jsonifiable) -> None:
        if not isinstance(value, dict):
            raise ValueError("Chatlas bookmark value was not a dictionary")

        version = value.get("version")
        if version != 1:
            raise ValueError(f"Unsupported Chatlas bookmark version: {version}")

        raw_turns = value.get("turns")
        if not isinstance(raw_turns, list):
            raise ValueError(
                "Chatlas bookmark value was not a list of chat message information"
            )

        restored: list[Turn[Any]] = []
        for raw_turn in raw_turns:
            restored.append(Turn.model_validate(raw_turn))
        client.set_turns(restored)  # pyright: ignore[reportUnknownMemberType]

    return set_state
|
@@ -0,0 +1,350 @@
|
|
1
|
+
import sys
|
2
|
+
from abc import ABC, abstractmethod
|
3
|
+
from typing import TYPE_CHECKING, Any, Optional, cast
|
4
|
+
|
5
|
+
from htmltools import HTML, Tagifiable
|
6
|
+
|
7
|
+
from ._chat_types import ChatMessage
|
8
|
+
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from anthropic.types import Message as AnthropicMessage
|
11
|
+
from anthropic.types import MessageStreamEvent
|
12
|
+
|
13
|
+
if sys.version_info >= (3, 9):
|
14
|
+
from google.generativeai.types.generation_types import ( # pyright: ignore[reportMissingTypeStubs]
|
15
|
+
GenerateContentResponse,
|
16
|
+
)
|
17
|
+
else:
|
18
|
+
|
19
|
+
class GenerateContentResponse:
|
20
|
+
text: str
|
21
|
+
|
22
|
+
from langchain_core.messages import BaseMessage, BaseMessageChunk
|
23
|
+
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
24
|
+
|
25
|
+
|
26
|
+
class BaseMessageNormalizer(ABC):
    """
    Abstract interface for turning provider-specific messages into `ChatMessage`.

    Subclasses must implement all four methods: two predicates that say whether
    a value is supported, and two converters for complete messages and chunks.
    """

    @abstractmethod
    def normalize(self, message: Any) -> ChatMessage:
        """Convert a complete message into a `ChatMessage`."""

    @abstractmethod
    def normalize_chunk(self, chunk: Any) -> ChatMessage:
        """Convert a message chunk into a `ChatMessage`."""

    @abstractmethod
    def can_normalize(self, message: Any) -> bool:
        """Return whether `normalize()` supports `message`."""

    @abstractmethod
    def can_normalize_chunk(self, chunk: Any) -> bool:
        """Return whether `normalize_chunk()` supports `chunk`."""
|
42
|
+
|
43
|
+
|
44
|
+
class StringNormalizer(BaseMessageNormalizer):
    """Normalizer for plain strings, `HTML` values, and `None`."""

    def normalize(self, message: Any) -> ChatMessage:
        """Wrap `message` in an assistant `ChatMessage`; falsy input becomes ""."""
        text = cast(Optional[str], message)
        if not text:
            text = ""
        return ChatMessage(content=text, role="assistant")

    def normalize_chunk(self, chunk: Any) -> ChatMessage:
        """Wrap `chunk` in an assistant `ChatMessage`; falsy input becomes ""."""
        text = cast(Optional[str], chunk)
        if not text:
            text = ""
        return ChatMessage(content=text, role="assistant")

    def can_normalize(self, message: Any) -> bool:
        """Accept `None` as well as `str` and `HTML` instances."""
        if message is None:
            return True
        return isinstance(message, (str, HTML))

    def can_normalize_chunk(self, chunk: Any) -> bool:
        """Accept `None` as well as `str` and `HTML` instances."""
        if chunk is None:
            return True
        return isinstance(chunk, (str, HTML))
|
58
|
+
|
59
|
+
|
60
|
+
class DictNormalizer(BaseMessageNormalizer):
    """Normalizer for dicts that carry a `"content"` key and optional `"role"`."""

    def normalize(self, message: Any) -> ChatMessage:
        """
        Build a `ChatMessage` from the dict's `"content"` and `"role"` entries.

        The role defaults to `"assistant"`; a missing `"content"` key raises
        `ValueError`.
        """
        payload = cast("dict[str, Any]", message)
        if "content" in payload:
            return ChatMessage(
                content=payload["content"], role=payload.get("role", "assistant")
            )
        raise ValueError("Message must have 'content' key")

    def normalize_chunk(self, chunk: Any) -> ChatMessage:
        """Same as `normalize()`, applied to a chunk dict."""
        payload = cast("dict[str, Any]", chunk)
        if "content" in payload:
            return ChatMessage(
                content=payload["content"], role=payload.get("role", "assistant")
            )
        raise ValueError("Message must have 'content' key")

    def can_normalize(self, message: Any) -> bool:
        """Accept any `dict` instance (key presence is checked in `normalize()`)."""
        is_mapping = isinstance(message, dict)
        return is_mapping

    def can_normalize_chunk(self, chunk: Any) -> bool:
        """Accept any `dict` instance (key presence is checked in `normalize_chunk()`)."""
        is_mapping = isinstance(chunk, dict)
        return is_mapping
|
78
|
+
|
79
|
+
|
80
|
+
class TagifiableNormalizer(DictNormalizer):
    """Normalizer for `Tagifiable` objects, handled by wrapping them in a dict."""

    def normalize(self, message: Any) -> ChatMessage:
        """Wrap `message` as `{"content": message}` and defer to `DictNormalizer`."""
        wrapped = {"content": cast("Tagifiable", message)}
        return super().normalize(wrapped)

    def normalize_chunk(self, chunk: Any) -> ChatMessage:
        """Wrap `chunk` as `{"content": chunk}` and defer to `DictNormalizer`."""
        wrapped = {"content": cast("Tagifiable", chunk)}
        return super().normalize_chunk(wrapped)

    def can_normalize(self, message: Any) -> bool:
        """Accept values that pass `isinstance(..., Tagifiable)`."""
        supported = isinstance(message, Tagifiable)
        return supported

    def can_normalize_chunk(self, chunk: Any) -> bool:
        """Accept values that pass `isinstance(..., Tagifiable)`."""
        supported = isinstance(chunk, Tagifiable)
        return supported
|
94
|
+
|
95
|
+
|
96
|
+
class LangChainNormalizer(BaseMessageNormalizer):
    """Normalizer for LangChain `BaseMessage` and `BaseMessageChunk` objects."""

    def normalize(self, message: Any) -> ChatMessage:
        """
        Build an assistant `ChatMessage` from `message.content`.

        Raises `ValueError` when `message.content` is a list.
        """
        lc_message = cast("BaseMessage", message)
        body = lc_message.content
        if isinstance(body, list):  # type: ignore
            raise ValueError(
                "The `message.content` provided seems to represent numerous messages. "
                "Consider iterating over `message.content` and calling .append_message() on each iteration."
            )
        return ChatMessage(content=body, role="assistant")

    def normalize_chunk(self, chunk: Any) -> ChatMessage:
        """
        Build an assistant `ChatMessage` from `chunk.content`.

        Raises `ValueError` when `chunk.content` is a list.
        """
        lc_chunk = cast("BaseMessageChunk", chunk)
        body = lc_chunk.content
        if isinstance(body, list):  # type: ignore
            raise ValueError(
                "The `message.content` provided seems to represent numerous messages. "
                "Consider iterating over `message.content` and calling .append_message() on each iteration."
            )
        return ChatMessage(content=body, role="assistant")

    def can_normalize(self, message: Any) -> bool:
        """Accept `BaseMessage` instances; any failure (e.g. import) yields False."""
        try:
            from langchain_core.messages import BaseMessage

            matches = isinstance(message, BaseMessage)
        except Exception:
            matches = False
        return matches

    def can_normalize_chunk(self, chunk: Any) -> bool:
        """Accept `BaseMessageChunk` instances; any failure (e.g. import) yields False."""
        try:
            from langchain_core.messages import BaseMessageChunk

            matches = isinstance(chunk, BaseMessageChunk)
        except Exception:
            matches = False
        return matches
|
130
|
+
|
131
|
+
|
132
|
+
class OpenAINormalizer(StringNormalizer):
    """Normalizer for OpenAI `ChatCompletion` and `ChatCompletionChunk` objects."""

    def normalize(self, message: Any) -> ChatMessage:
        """Normalize the first choice's message content via `StringNormalizer`."""
        completion = cast("ChatCompletion", message)
        first_choice = completion.choices[0]
        return super().normalize(first_choice.message.content)

    def normalize_chunk(self, chunk: Any) -> ChatMessage:
        """Normalize the first choice's delta content via `StringNormalizer`."""
        completion_chunk = cast("ChatCompletionChunk", chunk)
        first_choice = completion_chunk.choices[0]
        return super().normalize_chunk(first_choice.delta.content)

    def can_normalize(self, message: Any) -> bool:
        """Accept `ChatCompletion` instances; any failure (e.g. import) yields False."""
        try:
            from openai.types.chat import ChatCompletion

            matches = isinstance(message, ChatCompletion)
        except Exception:
            matches = False
        return matches

    def can_normalize_chunk(self, chunk: Any) -> bool:
        """Accept `ChatCompletionChunk` instances; any failure (e.g. import) yields False."""
        try:
            from openai.types.chat import ChatCompletionChunk

            matches = isinstance(chunk, ChatCompletionChunk)
        except Exception:
            matches = False
        return matches
|
156
|
+
|
157
|
+
|
158
|
+
class AnthropicNormalizer(BaseMessageNormalizer):
    """Normalizer for Anthropic `Message` objects and raw stream events."""

    def normalize(self, message: Any) -> ChatMessage:
        """
        Build an assistant `ChatMessage` from the first content block's text.

        Raises `ValueError` when that block's type is not `"text"`.
        """
        anthropic_message = cast("AnthropicMessage", message)
        first_block = anthropic_message.content[0]
        if first_block.type == "text":
            return ChatMessage(content=first_block.text, role="assistant")
        raise ValueError(
            f"Anthropic message type {first_block.type} not supported. "
            "Only 'text' type is currently supported"
        )

    def normalize_chunk(self, chunk: Any) -> ChatMessage:
        """
        Build an assistant `ChatMessage` from a stream event.

        Only `content_block_delta` events contribute text; every other event
        type yields empty content. A delta whose type is not `"text_delta"`
        raises `ValueError`.
        """
        event = cast("MessageStreamEvent", chunk)
        if event.type != "content_block_delta":
            return ChatMessage(content="", role="assistant")

        if event.delta.type != "text_delta":
            raise ValueError(
                f"Anthropic message delta type {event.delta.type} not supported. "
                "Only 'text_delta' type is supported"
            )
        return ChatMessage(content=event.delta.text, role="assistant")

    def can_normalize(self, message: Any) -> bool:
        """Accept Anthropic `Message` instances; any failure (e.g. import) yields False."""
        try:
            from anthropic.types import Message as AnthropicMessage

            matches = isinstance(message, AnthropicMessage)
        except Exception:
            matches = False
        return matches

    def can_normalize_chunk(self, chunk: Any) -> bool:
        """Accept the raw Anthropic stream event classes; any failure yields False."""
        try:
            from anthropic.types import (
                RawContentBlockDeltaEvent,
                RawContentBlockStartEvent,
                RawContentBlockStopEvent,
                RawMessageDeltaEvent,
                RawMessageStartEvent,
                RawMessageStopEvent,
            )

            # The actual MessageStreamEvent is a generic, so isinstance() can't
            # be used to check the type. Instead, we manually construct the relevant
            # union of relevant classes...
            event_classes = (
                RawContentBlockDeltaEvent,
                RawContentBlockStartEvent,
                RawContentBlockStopEvent,
                RawMessageDeltaEvent,
                RawMessageStartEvent,
                RawMessageStopEvent,
            )
            return isinstance(chunk, event_classes)
        except Exception:
            return False
|
214
|
+
|
215
|
+
|
216
|
+
class GoogleNormalizer(BaseMessageNormalizer):
    """Normalizer for Google `GenerateContentResponse` objects."""

    def normalize(self, message: Any) -> ChatMessage:
        """Build an assistant `ChatMessage` from `message.text`."""
        response = cast("GenerateContentResponse", message)
        return ChatMessage(content=response.text, role="assistant")

    def normalize_chunk(self, chunk: Any) -> ChatMessage:
        """Build an assistant `ChatMessage` from `chunk.text`."""
        response = cast("GenerateContentResponse", chunk)
        return ChatMessage(content=response.text, role="assistant")

    def can_normalize(self, message: Any) -> bool:
        """Accept `GenerateContentResponse` instances; any failure (e.g. import) yields False."""
        try:
            import google.generativeai.types.generation_types as gtypes  # pyright: ignore[reportMissingTypeStubs, reportMissingImports]

            response_cls = gtypes.GenerateContentResponse  # pyright: ignore[reportUnknownMemberType]
            return isinstance(message, response_cls)
        except Exception:
            return False

    def can_normalize_chunk(self, chunk: Any) -> bool:
        """Chunks are checked the same way as complete messages."""
        return self.can_normalize(chunk)
|
238
|
+
|
239
|
+
|
240
|
+
class OllamaNormalizer(DictNormalizer):
    """
    Normalizer for Ollama chat responses.

    The response's `"message"` entry is extracted and handed to `DictNormalizer`.
    """

    def normalize(self, message: Any) -> ChatMessage:
        """Normalize `message["message"]` via `DictNormalizer.normalize()`."""
        x = cast("dict[str, Any]", message["message"])
        return super().normalize(x)

    def normalize_chunk(self, chunk: "dict[str, Any]") -> ChatMessage:
        """Normalize `chunk["message"]` via `DictNormalizer.normalize_chunk()`."""
        msg = cast("dict[str, Any]", chunk["message"])
        return super().normalize_chunk(msg)

    def can_normalize(self, message: Any) -> bool:
        """
        Return whether `message` is an Ollama chat response.

        Any failure (including `ollama` not being importable) yields False.
        """
        try:
            from ollama import ChatResponse

            # Ollama<0.4 used TypedDict (now it uses pydantic)
            # https://github.com/ollama/ollama-python/pull/276
            #
            # `ChatResponse` is a class, so `isinstance(ChatResponse, dict)` can
            # never be true; `issubclass` is the check that actually detects
            # the TypedDict flavor (TypedDict classes derive from `dict`).
            if isinstance(ChatResponse, type) and issubclass(ChatResponse, dict):
                # TypedDicts cannot be used with isinstance(), so fall back to
                # a structural check: a dict whose "message" entry is a dict.
                return (
                    isinstance(message, dict)
                    and "message" in message
                    and super().can_normalize(message["message"])
                )
            else:
                return isinstance(message, ChatResponse)
        except Exception:
            return False

    def can_normalize_chunk(self, chunk: Any) -> bool:
        """Chunks are checked the same way as complete messages."""
        return self.can_normalize(chunk)
|
266
|
+
|
267
|
+
|
268
|
+
class NormalizerRegistry:
    """Ordered collection of message normalizers, keyed by provider name."""

    def __init__(self) -> None:
        # Order of strategies matters (the 1st one that can normalize the message is used)
        # So make sure to put the most specific strategies first
        self._strategies: dict[str, BaseMessageNormalizer] = {
            "openai": OpenAINormalizer(),
            "anthropic": AnthropicNormalizer(),
            "google": GoogleNormalizer(),
            "langchain": LangChainNormalizer(),
            "ollama": OllamaNormalizer(),
            "tagify": TagifiableNormalizer(),
            "dict": DictNormalizer(),
            "string": StringNormalizer(),
        }

    def register(
        self, provider: str, strategy: BaseMessageNormalizer, force: bool = False
    ) -> None:
        """
        Add `strategy` under `provider`, placing it ahead of all existing entries.

        Raises `ValueError` when `provider` is already registered and `force`
        is false; with `force` set, the existing entry is removed first.
        """
        if provider in self._strategies:
            if not force:
                raise ValueError(f"Provider {provider} already exists in registry")
            self._strategies.pop(provider)

        # Rebuild the mapping so the new strategy is the first to be considered
        reordered = {provider: strategy}
        reordered.update(self._strategies)
        self._strategies = reordered
|
293
|
+
|
294
|
+
|
295
|
+
# Module-level registry instance consulted by `normalize_message()` and
# `normalize_message_chunk()`, and extended by `register_custom_normalizer()`.
message_normalizer_registry = NormalizerRegistry()
|
296
|
+
|
297
|
+
|
298
|
+
def register_custom_normalizer(
    provider: str, normalizer: BaseMessageNormalizer, force: bool = False
) -> None:
    """
    Add a custom normalizer to the module-level registry.

    Use this to teach the chat component how to handle a message type that none
    of the registered normalizers accept.

    Parameters
    ----------
    provider : str
        The unique name under which the normalizer is stored in the registry
    normalizer : BaseMessageNormalizer
        The normalizer instance responsible for your message type
    force : bool, optional
        Set to True to replace a normalizer already registered under the same
        provider name, by default False

    Examples
    --------
    >>> class MyCustomMessage:
    ...     def __init__(self, content):
    ...         self.content = content
    ...
    >>> class MyCustomNormalizer(StringNormalizer):
    ...     def normalize(self, message):
    ...         return ChatMessage(content=message.content, role="assistant")
    ...     def can_normalize(self, message):
    ...         return isinstance(message, MyCustomMessage)
    ...
    >>> register_custom_normalizer("my_provider", MyCustomNormalizer())
    """
    registry = message_normalizer_registry
    registry.register(provider, normalizer, force)
|
329
|
+
|
330
|
+
|
331
|
+
def normalize_message(message: Any) -> ChatMessage:
    """
    Convert `message` into a `ChatMessage`.

    The registered normalizers are tried in registry order and the first one
    whose `can_normalize()` accepts the message is used.

    Raises
    ------
    ValueError
        If no registered normalizer accepts the message.
    """
    strategies = message_normalizer_registry._strategies
    for strategy in strategies.values():
        if strategy.can_normalize(message):
            return strategy.normalize(message)
    # Point at the registration helper that actually exists in this module.
    raise ValueError(
        f"Could not find a normalizer for message of type {type(message)}: {message}. "
        "Consider registering a custom normalizer via shinychat._chat_normalize.register_custom_normalizer()"
    )
|
340
|
+
|
341
|
+
|
342
|
+
def normalize_message_chunk(chunk: Any) -> ChatMessage:
    """
    Convert a message `chunk` into a `ChatMessage`.

    The registered normalizers are tried in registry order and the first one
    whose `can_normalize_chunk()` accepts the chunk is used.

    Raises
    ------
    ValueError
        If no registered normalizer accepts the chunk.
    """
    strategies = message_normalizer_registry._strategies
    for strategy in strategies.values():
        if strategy.can_normalize_chunk(chunk):
            return strategy.normalize_chunk(chunk)
    # Name this package's module path rather than the `shiny.ui` one.
    raise ValueError(
        f"Could not find a normalizer for message chunk of type {type(chunk)}: {chunk}. "
        "Consider registering a custom normalizer via shinychat._chat_normalize.register_custom_normalizer()"
    )
|
@@ -0,0 +1,127 @@
|
|
1
|
+
import sys
|
2
|
+
from typing import TYPE_CHECKING, Literal, Union
|
3
|
+
|
4
|
+
from ._chat_types import ChatMessageDict
|
5
|
+
|
6
|
+
# Provider message type aliases. The real provider types are imported only for
# static type checking; at runtime (the `else` branch) every alias is bound to
# `object`, so none of the provider packages is imported here.
if TYPE_CHECKING:
    from anthropic.types import MessageParam as AnthropicMessage
    from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
    from ollama import Message as OllamaMessage
    from openai.types.chat import (
        ChatCompletionAssistantMessageParam,
        ChatCompletionSystemMessageParam,
        ChatCompletionUserMessageParam,
    )

    # The Google types are only imported on Python 3.9+; older versions fall
    # back to `object`.
    if sys.version_info >= (3, 9):
        import google.generativeai.types as gtypes  # pyright: ignore[reportMissingTypeStubs]

        GoogleMessage = gtypes.ContentDict
    else:
        GoogleMessage = object

    LangChainMessage = Union[AIMessage, HumanMessage, SystemMessage]
    OpenAIMessage = Union[
        ChatCompletionAssistantMessageParam,
        ChatCompletionSystemMessageParam,
        ChatCompletionUserMessageParam,
    ]

    # Union of every provider-specific message type.
    ProviderMessage = Union[
        AnthropicMessage, GoogleMessage, LangChainMessage, OpenAIMessage, OllamaMessage
    ]
else:
    AnthropicMessage = GoogleMessage = LangChainMessage = OpenAIMessage = (
        OllamaMessage
    ) = ProviderMessage = object
|
37
|
+
|
38
|
+
# The provider format names that `as_provider_message()` dispatches on.
ProviderMessageFormat = Literal[
    "anthropic",
    "google",
    "langchain",
    "openai",
    "ollama",
]
|
45
|
+
|
46
|
+
|
47
|
+
# TODO: use a strategy pattern to allow others to register
# their own message formats
def as_provider_message(
    message: ChatMessageDict, format: ProviderMessageFormat
) -> "ProviderMessage":
    """
    Convert a chat message dict into the message type of the given provider.

    Parameters
    ----------
    message
        The chat message to convert.
    format
        Which provider's message format to produce.

    Raises
    ------
    ValueError
        If `format` is not one of the known provider names.
    """
    converters = (
        ("anthropic", as_anthropic_message),
        ("google", as_google_message),
        ("langchain", as_langchain_message),
        ("openai", as_openai_message),
        ("ollama", as_ollama_message),
    )
    for provider_name, convert in converters:
        if format == provider_name:
            return convert(message)
    raise ValueError(f"Unknown format: {format}")
|
63
|
+
|
64
|
+
|
65
|
+
def as_anthropic_message(message: ChatMessageDict) -> "AnthropicMessage":
    """
    Convert a chat message dict into an Anthropic `MessageParam`.

    Raises `ValueError` for messages with the `"system"` role.
    """
    from anthropic.types import MessageParam as AnthropicMessage

    role = message["role"]
    if role == "system":
        raise ValueError(
            "Anthropic requires a system prompt to be specified in the `.create()` method"
        )
    content = message["content"]
    return AnthropicMessage(content=content, role=role)
|
73
|
+
|
74
|
+
|
75
|
+
def as_google_message(message: ChatMessageDict) -> "GoogleMessage":
    """
    Convert a chat message dict into a Google `ContentDict`.

    The `"assistant"` role is translated to `"model"`. Raises `ValueError` on
    Python older than 3.9 and for messages with the `"system"` role.
    """
    if sys.version_info < (3, 9):
        raise ValueError("Google requires Python 3.9")

    import google.generativeai.types as gtypes  # pyright: ignore[reportMissingTypeStubs]

    original_role = message["role"]
    if original_role == "system":
        raise ValueError(
            "Google requires a system prompt to be specified in the `GenerativeModel()` constructor."
        )

    google_role = "model" if original_role == "assistant" else original_role
    return gtypes.ContentDict(parts=[message["content"]], role=google_role)
|
90
|
+
|
91
|
+
|
92
|
+
def as_langchain_message(message: ChatMessageDict) -> "LangChainMessage":
    """
    Convert a chat message dict into the LangChain message class for its role.

    `"system"`, `"assistant"` and `"user"` map to `SystemMessage`, `AIMessage`
    and `HumanMessage`; any other role raises `ValueError`.
    """
    from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

    content = message["content"]
    role = message["role"]
    role_classes = (
        ("system", SystemMessage),
        ("assistant", AIMessage),
        ("user", HumanMessage),
    )
    for known_role, message_cls in role_classes:
        if role == known_role:
            return message_cls(content=content)
    raise ValueError(f"Unknown role: {message['role']}")
|
104
|
+
|
105
|
+
|
106
|
+
def as_openai_message(message: ChatMessageDict) -> "OpenAIMessage":
    """
    Convert a chat message dict into the OpenAI message param type for its role.

    `"system"`, `"assistant"` and `"user"` are supported; any other role raises
    `ValueError`.
    """
    from openai.types.chat import (
        ChatCompletionAssistantMessageParam,
        ChatCompletionSystemMessageParam,
        ChatCompletionUserMessageParam,
    )

    content = message["content"]
    role = message["role"]
    role_params = (
        ("system", ChatCompletionSystemMessageParam),
        ("assistant", ChatCompletionAssistantMessageParam),
        ("user", ChatCompletionUserMessageParam),
    )
    for known_role, param_cls in role_params:
        if role == known_role:
            return param_cls(content=content, role=role)
    raise ValueError(f"Unknown role: {role}")
|
122
|
+
|
123
|
+
|
124
|
+
def as_ollama_message(message: ChatMessageDict) -> "OllamaMessage":
    """Convert a chat message dict into an Ollama `Message`, keeping content and role."""
    from ollama import Message as OllamaMessage

    content = message["content"]
    role = message["role"]
    return OllamaMessage(content=content, role=role)
|
@@ -0,0 +1,67 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import (
|
4
|
+
AbstractSet,
|
5
|
+
Any,
|
6
|
+
Collection,
|
7
|
+
Literal,
|
8
|
+
Protocol,
|
9
|
+
Union,
|
10
|
+
runtime_checkable,
|
11
|
+
)
|
12
|
+
|
13
|
+
|
14
|
+
# A duck type for tiktoken.Encoding
|
15
|
+
class TiktokenEncoding(Protocol):
    """Structural stand-in for `tiktoken.Encoding`: a `name` plus an `encode()` method."""

    name: str

    def encode(
        self,
        text: str,
        *,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
    ) -> list[int]:
        """Encode `text` into a list of integer token ids."""
|
25
|
+
|
26
|
+
|
27
|
+
# A duck type for tokenizers.Encoding
|
28
|
+
@runtime_checkable
class TokenizersEncoding(Protocol):
    """Structural stand-in for `tokenizers.Encoding`: anything exposing `ids`."""

    @property
    def ids(self) -> list[int]:
        """The encoded token ids."""
|
32
|
+
|
33
|
+
|
34
|
+
# A duck type for tokenizers.Tokenizer
|
35
|
+
class TokenizersTokenizer(Protocol):
    """Structural stand-in for `tokenizers.Tokenizer`: an `encode()` method."""

    def encode(
        self,
        sequence: Any,
        pair: Any = None,
        is_pretokenized: bool = False,
        add_special_tokens: bool = True,
    ) -> TokenizersEncoding:
        """Encode `sequence` and return an object exposing the token `ids`."""
|
43
|
+
|
44
|
+
|
45
|
+
# Either of the two tokenizer duck types declared above.
TokenEncoding = Union[TiktokenEncoding, TokenizersTokenizer]
|
46
|
+
|
47
|
+
|
48
|
+
def get_default_tokenizer() -> TokenizersTokenizer:
    """
    Load the default tokenizer (`bert-base-cased`) via the `tokenizers` package.

    Returns
    -------
    :
        The result of `tokenizers.Tokenizer.from_pretrained("bert-base-cased")`.

    Raises
    ------
    ImportError
        If an `ImportError` occurs (e.g. `tokenizers` is not installed).
    RuntimeError
        If loading the tokenizer fails for any other reason; the original
        error is included in the message and attached as the cause.
    """
    try:
        # Deferred import: `tokenizers` is an optional dependency.
        from tokenizers import Tokenizer

        return Tokenizer.from_pretrained("bert-base-cased")  # type: ignore
    except ImportError as e:
        raise ImportError(
            "Failed to download a default tokenizer. "
            "A tokenizer is required to impose `token_limits` on `chat.messages()`. "
            "To get a generic default tokenizer, install the `tokenizers` "
            "package (`pip install tokenizers`). "
        ) from e
    except Exception as e:
        # The trailing space after the sentence keeps "Error:" from being
        # glued onto the preceding period.
        raise RuntimeError(
            "Failed to download a default tokenizer. "
            "A tokenizer is required to impose `token_limits` on `chat.messages()`. "
            "Try manually downloading a tokenizer using "
            "`tokenizers.Tokenizer.from_pretrained()` and passing it to `ui.Chat()`. "
            f"Error: {e}"
        ) from e
|