chatlas 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatlas/__init__.py +38 -0
- chatlas/_anthropic.py +643 -0
- chatlas/_chat.py +1279 -0
- chatlas/_content.py +242 -0
- chatlas/_content_image.py +272 -0
- chatlas/_display.py +139 -0
- chatlas/_github.py +147 -0
- chatlas/_google.py +456 -0
- chatlas/_groq.py +143 -0
- chatlas/_interpolate.py +133 -0
- chatlas/_logging.py +61 -0
- chatlas/_merge.py +103 -0
- chatlas/_ollama.py +125 -0
- chatlas/_openai.py +654 -0
- chatlas/_perplexity.py +148 -0
- chatlas/_provider.py +143 -0
- chatlas/_tokens.py +87 -0
- chatlas/_tokens_old.py +148 -0
- chatlas/_tools.py +134 -0
- chatlas/_turn.py +147 -0
- chatlas/_typing_extensions.py +26 -0
- chatlas/_utils.py +106 -0
- chatlas/types/__init__.py +32 -0
- chatlas/types/anthropic/__init__.py +14 -0
- chatlas/types/anthropic/_client.py +29 -0
- chatlas/types/anthropic/_client_bedrock.py +23 -0
- chatlas/types/anthropic/_submit.py +57 -0
- chatlas/types/google/__init__.py +12 -0
- chatlas/types/google/_client.py +101 -0
- chatlas/types/google/_submit.py +113 -0
- chatlas/types/openai/__init__.py +14 -0
- chatlas/types/openai/_client.py +22 -0
- chatlas/types/openai/_client_azure.py +25 -0
- chatlas/types/openai/_submit.py +135 -0
- chatlas-0.2.0.dist-info/METADATA +319 -0
- chatlas-0.2.0.dist-info/RECORD +37 -0
- chatlas-0.2.0.dist-info/WHEEL +4 -0
chatlas/_perplexity.py
ADDED
@@ -0,0 +1,148 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Optional

from ._chat import Chat
from ._logging import log_model_default
from ._openai import ChatOpenAI
from ._turn import Turn
from ._utils import MISSING, MISSING_TYPE

if TYPE_CHECKING:
    from ._openai import ChatCompletion
    from .types.openai import ChatClientArgs, SubmitInputArgs


def ChatPerplexity(
    *,
    system_prompt: Optional[str] = None,
    turns: Optional[list[Turn]] = None,
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    base_url: str = "https://api.perplexity.ai/",
    seed: Optional[int] | MISSING_TYPE = MISSING,
    kwargs: Optional["ChatClientArgs"] = None,
) -> Chat["SubmitInputArgs", ChatCompletion]:
    """
    Chat with a model hosted on perplexity.ai.

    Perplexity AI is a platform for running LLMs that are capable of
    searching the web in real-time to help them answer questions with
    information that may not have been available when the model was trained.

    Prerequisites
    -------------

    ::: {.callout-note}
    ## API key

    Sign up at <https://www.perplexity.ai> to get an API key.
    :::

    ::: {.callout-note}
    ## Python requirements

    `ChatPerplexity` requires the `openai` package (e.g., `pip install openai`).
    :::

    Examples
    --------

    ```python
    import os
    from chatlas import ChatPerplexity

    chat = ChatPerplexity(api_key=os.getenv("PERPLEXITY_API_KEY"))
    chat.chat("What is the capital of France?")
    ```

    Parameters
    ----------
    system_prompt
        A system prompt to set the behavior of the assistant.
    turns
        A list of turns to start the chat with (i.e., continuing a previous
        conversation). If not provided, the conversation begins from scratch. Do
        not provide non-`None` values for both `turns` and `system_prompt`. Each
        message in the list should be a dictionary with at least `role` (usually
        `system`, `user`, or `assistant`, but `tool` is also possible). Normally
        there is also a `content` field, which is a string.
    model
        The model to use for the chat. The default, None, will pick a reasonable
        default, and warn you about it. We strongly recommend explicitly
        choosing a model for all but the most casual use.
    api_key
        The API key to use for authentication. You generally should not supply
        this directly, but instead set the `PERPLEXITY_API_KEY` environment
        variable.
    base_url
        The base URL to the endpoint; the default uses Perplexity's API.
    seed
        Optional integer seed that the model uses to try to make output more
        reproducible.
    kwargs
        Additional arguments to pass to the `openai.OpenAI()` client
        constructor.

    Returns
    -------
    Chat
        A chat object that retains the state of the conversation.

    Note
    ----
    This function is a lightweight wrapper around [](`chatlas.ChatOpenAI`) with
    the defaults tweaked for perplexity.ai.

    Note
    ----
    Pasting an API key into a chat constructor (e.g., `ChatPerplexity(api_key="...")`)
    is the simplest way to get started, and is fine for interactive use, but is
    problematic for code that may be shared with others.

    Instead, consider using environment variables or a configuration file to manage
    your credentials. One popular way to manage credentials is to store them in a
    `.env` file and use the `python-dotenv` package to load them into your
    environment.

    ```shell
    pip install python-dotenv
    ```

    ```shell
    # .env
    PERPLEXITY_API_KEY=...
    ```

    ```python
    from chatlas import ChatPerplexity
    from dotenv import load_dotenv

    load_dotenv()
    chat = ChatPerplexity()
    chat.console()
    ```

    Another, more general, solution is to load your environment variables into the shell
    before starting Python (maybe in a `.bashrc`, `.zshrc`, etc. file):

    ```shell
    export PERPLEXITY_API_KEY=...
    ```
    """
    if model is None:
        model = log_model_default("llama-3.1-sonar-small-128k-online")
    if api_key is None:
        api_key = os.getenv("PERPLEXITY_API_KEY")

    return ChatOpenAI(
        system_prompt=system_prompt,
        turns=turns,
        model=model,
        api_key=api_key,
        base_url=base_url,
        seed=seed,
        kwargs=kwargs,
    )
chatlas/_provider.py
ADDED
@@ -0,0 +1,143 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import (
    Any,
    AsyncIterable,
    Generic,
    Iterable,
    Literal,
    Optional,
    TypeVar,
    overload,
)

from pydantic import BaseModel

from ._tools import Tool
from ._turn import Turn

ChatCompletionT = TypeVar("ChatCompletionT")
ChatCompletionChunkT = TypeVar("ChatCompletionChunkT")
# A dictionary representation of a chat completion
ChatCompletionDictT = TypeVar("ChatCompletionDictT")


class Provider(
    ABC, Generic[ChatCompletionT, ChatCompletionChunkT, ChatCompletionDictT]
):
    """
    A model provider interface for a [](`~chatlas.Chat`).

    This abstract class defines the interface a model provider must implement in
    order to be used with a [](`~chatlas.Chat`) instance. The provider is
    responsible for performing the actual chat completion, and for handling the
    streaming of the completion results.

    Note that this class is exposed for developers who wish to implement their
    own provider. In general, you should not need to interact with this class
    directly.
    """

    @overload
    @abstractmethod
    def chat_perform(
        self,
        *,
        stream: Literal[False],
        turns: list[Turn],
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]],
        kwargs: Any,
    ) -> ChatCompletionT: ...

    @overload
    @abstractmethod
    def chat_perform(
        self,
        *,
        stream: Literal[True],
        turns: list[Turn],
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]],
        kwargs: Any,
    ) -> Iterable[ChatCompletionChunkT]: ...

    @abstractmethod
    def chat_perform(
        self,
        *,
        stream: bool,
        turns: list[Turn],
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]],
        kwargs: Any,
    ) -> Iterable[ChatCompletionChunkT] | ChatCompletionT: ...

    @overload
    @abstractmethod
    async def chat_perform_async(
        self,
        *,
        stream: Literal[False],
        turns: list[Turn],
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]],
        kwargs: Any,
    ) -> ChatCompletionT: ...

    @overload
    @abstractmethod
    async def chat_perform_async(
        self,
        *,
        stream: Literal[True],
        turns: list[Turn],
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]],
        kwargs: Any,
    ) -> AsyncIterable[ChatCompletionChunkT]: ...

    @abstractmethod
    async def chat_perform_async(
        self,
        *,
        stream: bool,
        turns: list[Turn],
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]],
        kwargs: Any,
    ) -> AsyncIterable[ChatCompletionChunkT] | ChatCompletionT: ...

    @abstractmethod
    def stream_text(self, chunk: ChatCompletionChunkT) -> Optional[str]: ...

    @abstractmethod
    def stream_merge_chunks(
        self,
        completion: Optional[ChatCompletionDictT],
        chunk: ChatCompletionChunkT,
    ) -> ChatCompletionDictT: ...

    @abstractmethod
    def stream_turn(
        self,
        completion: ChatCompletionDictT,
        has_data_model: bool,
        stream: Any,
    ) -> Turn: ...

    @abstractmethod
    async def stream_turn_async(
        self,
        completion: ChatCompletionDictT,
        has_data_model: bool,
        stream: Any,
    ) -> Turn: ...

    @abstractmethod
    def value_turn(
        self,
        completion: ChatCompletionT,
        has_data_model: bool,
    ) -> Turn: ...
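To illustrate the shape of this interface, here is a skeleton of a custom provider. This is only a sketch under several assumptions: `Provider` is imported from the private `chatlas._provider` module, the `EchoCompletion`/`EchoChunk` types are invented placeholders, `Turn` is assumed to accept a role plus a string and expose a `.text` property, and the real base class may define additional abstract methods (e.g., token counting) not shown in this excerpt.

```python
from typing import Any, Optional

# Private-module imports; these paths match the files added in this release.
from chatlas._provider import Provider
from chatlas._turn import Turn


# Placeholder completion/chunk types, purely for illustration.
class EchoCompletion(dict): ...


class EchoChunk(dict): ...


class EchoProvider(Provider[EchoCompletion, EchoChunk, dict]):
    """A toy provider that echoes the last user turn back as the assistant."""

    def chat_perform(self, *, stream, turns, tools, data_model, kwargs):
        # Assumes Turn exposes a .text property with the turn's plain text.
        text = turns[-1].text if turns else ""
        return EchoCompletion(text=text)

    async def chat_perform_async(self, *, stream, turns, tools, data_model, kwargs):
        return self.chat_perform(
            stream=stream, turns=turns, tools=tools, data_model=data_model, kwargs=kwargs
        )

    def stream_text(self, chunk: EchoChunk) -> Optional[str]:
        return chunk.get("text")

    def stream_merge_chunks(self, completion: Optional[dict], chunk: EchoChunk) -> dict:
        # Accumulate streamed text into the dict representation of the completion.
        merged = dict(completion or {})
        merged["text"] = merged.get("text", "") + (chunk.get("text") or "")
        return merged

    def stream_turn(self, completion: dict, has_data_model: bool, stream: Any) -> Turn:
        return Turn("assistant", completion.get("text", ""))

    async def stream_turn_async(self, completion: dict, has_data_model: bool, stream: Any) -> Turn:
        return self.stream_turn(completion, has_data_model, stream)

    def value_turn(self, completion: EchoCompletion, has_data_model: bool) -> Turn:
        return Turn("assistant", completion.get("text", ""))
```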
chatlas/_tokens.py
ADDED
@@ -0,0 +1,87 @@
from __future__ import annotations

import copy
from threading import Lock
from typing import TYPE_CHECKING

from ._logging import logger
from ._typing_extensions import TypedDict

if TYPE_CHECKING:
    from ._provider import Provider


class TokenUsage(TypedDict):
    """
    Token usage for a given provider (name).
    """

    name: str
    input: int
    output: int


class ThreadSafeTokenCounter:
    def __init__(self):
        self._lock = Lock()
        self._tokens: dict[str, TokenUsage] = {}

    def log_tokens(self, name: str, input_tokens: int, output_tokens: int) -> None:
        logger.info(
            f"Provider '{name}' generated a response of {output_tokens} tokens "
            f"from an input of {input_tokens} tokens."
        )

        with self._lock:
            if name not in self._tokens:
                self._tokens[name] = {
                    "name": name,
                    "input": input_tokens,
                    "output": output_tokens,
                }
            else:
                self._tokens[name]["input"] += input_tokens
                self._tokens[name]["output"] += output_tokens

    def get_usage(self) -> list[TokenUsage] | None:
        with self._lock:
            if not self._tokens:
                return None
            # Create a deep copy to avoid external modifications
            return copy.deepcopy(list(self._tokens.values()))


# Global instance
_token_counter = ThreadSafeTokenCounter()


def tokens_log(provider: "Provider", tokens: tuple[int, int]) -> None:
    """
    Log token usage for a provider in a thread-safe manner.
    """
    name = provider.__class__.__name__.replace("Provider", "")
    _token_counter.log_tokens(name, tokens[0], tokens[1])


def tokens_reset() -> None:
    """
    Reset the token usage counter
    """
    global _token_counter  # noqa: PLW0603
    _token_counter = ThreadSafeTokenCounter()


def token_usage() -> list[TokenUsage] | None:
    """
    Report on token usage in the current session

    Call this function to find out the cumulative number of tokens that you
    have sent and received in the current session.

    Returns
    -------
    list[TokenUsage] | None
        A list of dictionaries with the following keys: "name", "input", and "output".
        If no tokens have been logged, then None is returned.
    """
    return _token_counter.get_usage()
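A short sketch of how these helpers fit together. It imports from the private module shown above (`token_usage` may also be re-exported from the top-level `chatlas` namespace), and `FakeProvider` is an illustrative stand-in for a real provider object:

```python
from chatlas._tokens import token_usage, tokens_log, tokens_reset


class FakeProvider:  # stand-in for a real Provider subclass (illustrative only)
    pass


# Providers call tokens_log() internally after each completion; here we
# simulate two completions. The logged name is the class name with
# "Provider" stripped, i.e. "Fake".
tokens_log(FakeProvider(), (100, 25))
tokens_log(FakeProvider(), (50, 10))

print(token_usage())
# -> [{'name': 'Fake', 'input': 150, 'output': 35}]

tokens_reset()
print(token_usage())  # -> None
```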
chatlas/_tokens_old.py
ADDED
@@ -0,0 +1,148 @@
# # A duck type for tiktoken.Encoding
# class TiktokenEncoding(Protocol):
#     name: str
#
#     def encode(
#         self,
#         text: str,
#         *,
#         allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
#         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
#     ) -> list[int]: ...
#
#
# # A duck type for tokenizers.Encoding
# @runtime_checkable
# class TokenizersEncoding(Protocol):
#     @property
#     def ids(self) -> list[int]: ...
#
#
# # A duck type for tokenizers.Tokenizer
# class TokenizersTokenizer(Protocol):
#     def encode(
#         self,
#         sequence: Any,
#         pair: Any = None,
#         is_pretokenized: bool = False,
#         add_special_tokens: bool = True,
#     ) -> TokenizersEncoding: ...
#
#
# TokenEncoding = Union[TiktokenEncoding, TokenizersTokenizer]
#
#
# def get_default_tokenizer() -> TokenizersTokenizer | None:
#     try:
#         from tokenizers import Tokenizer
#
#         return Tokenizer.from_pretrained("bert-base-cased")  # type: ignore
#     except Exception:
#         pass
#
#     return None


# def _get_token_count(
#     self,
#     content: str,
# ) -> int:
#     if self._tokenizer is None:
#         self._tokenizer = get_default_tokenizer()
#
#     if self._tokenizer is None:
#         raise ValueError(
#             "A tokenizer is required to impose `token_limits` on messages. "
#             "To get a generic default tokenizer, install the `tokenizers` "
#             "package (`pip install tokenizers`). "
#             "To get a more precise token count, provide a specific tokenizer "
#             "to the `Chat` constructor."
#         )
#
#     encoded = self._tokenizer.encode(content)
#     if isinstance(encoded, TokenizersEncoding):
#         return len(encoded.ids)
#     else:
#         return len(encoded)


# def _trim_messages(
#     self,
#     messages: tuple[TransformedMessage, ...],
#     token_limits: tuple[int, int],
# ) -> tuple[TransformedMessage, ...]:

#     n_total, n_reserve = token_limits
#     if n_total <= n_reserve:
#         raise ValueError(
#             f"Invalid token limits: {token_limits}. The 1st value must be greater "
#             "than the 2nd value."
#         )

#     # Since we don't trim system messages, first obtain their total token count
#     # (so we can determine how many non-system messages can fit)
#     n_system_tokens: int = 0
#     n_system_messages: int = 0
#     n_other_messages: int = 0
#     token_counts: list[int] = []
#     for m in messages:
#         content = (
#             m.content_server if isinstance(m, TransformedMessage) else m.content
#         )
#         count = self._get_token_count(content)
#         token_counts.append(count)
#         if m.role == "system":
#             n_system_tokens += count
#             n_system_messages += 1
#         else:
#             n_other_messages += 1

#     remaining_non_system_tokens = n_total - n_reserve - n_system_tokens

#     if remaining_non_system_tokens <= 0:
#         raise ValueError(
#             f"System messages exceed `.messages(token_limits={token_limits})`. "
#             "Consider increasing the 1st value of `token_limit` or setting it to "
#             "`token_limit=None` to disable token limits."
#         )

#     # Now, iterate through the messages in reverse order and append
#     # until we run out of tokens
#     messages2: list[TransformedMessage] = []
#     n_other_messages2: int = 0
#     token_counts.reverse()
#     for i, m in enumerate(reversed(messages)):
#         if m.role == "system":
#             messages2.append(m)
#             continue
#         remaining_non_system_tokens -= token_counts[i]
#         if remaining_non_system_tokens >= 0:
#             messages2.append(m)
#             n_other_messages2 += 1

#     messages2.reverse()

#     if len(messages2) == n_system_messages and n_other_messages2 > 0:
#         raise ValueError(
#             f"Only system messages fit within `.messages(token_limits={token_limits})`. "
#             "Consider increasing the 1st value of `token_limit` or setting it to "
#             "`token_limit=None` to disable token limits."
#         )

#     return tuple(messages2)

# def _trim_anthropic_messages(
#     self,
#     messages: tuple[TransformedMessage, ...],
# ) -> tuple[TransformedMessage, ...]:

#     if any(m.role == "system" for m in messages):
#         raise ValueError(
#             "Anthropic requires a system prompt to be specified in its `.create()` method "
#             "(not in the chat messages with `role: system`)."
#         )
#     for i, m in enumerate(messages):
#         if m.role == "user":
#             return messages[i:]

#     return ()
chatlas/_tools.py
ADDED
@@ -0,0 +1,134 @@
from __future__ import annotations

import inspect
import warnings
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional

from pydantic import BaseModel, Field, create_model

from . import _utils

__all__ = ("Tool",)

if TYPE_CHECKING:
    from openai.types.chat import ChatCompletionToolParam


class Tool:
    """
    Define a tool

    Define a Python function for use by a chatbot. The function will always be
    invoked in the current Python process.

    Parameters
    ----------
    func
        The function to be invoked when the tool is called.
    model
        A Pydantic model that describes the input parameters for the function.
        If not provided, the model will be inferred from the function's type hints.
        The primary reason to provide a model is to customize the schema sent to
        the chatbot (e.g., the parameter names and descriptions).
        Note that the name and docstring of the model take precedence over the
        name and docstring of the function.
    """

    func: Callable[..., Any] | Callable[..., Awaitable[Any]]

    def __init__(
        self,
        func: Callable[..., Any] | Callable[..., Awaitable[Any]],
        *,
        model: Optional[type[BaseModel]] = None,
    ):
        self.func = func
        self._is_async = _utils.is_async_callable(func)
        self.schema = func_to_schema(func, model)
        self.name = self.schema["function"]["name"]


def func_to_schema(
    func: Callable[..., Any] | Callable[..., Awaitable[Any]],
    model: Optional[type[BaseModel]] = None,
) -> "ChatCompletionToolParam":
    if model is None:
        model = func_to_basemodel(func)

    # Throw if there is a mismatch between the model and the function parameters
    params = inspect.signature(func).parameters
    fields = model.model_fields
    diff = set(params) ^ set(fields)
    if diff:
        raise ValueError(
            f"`model` fields must match tool function parameters exactly. "
            f"Fields found in one but not the other: {diff}"
        )

    params = basemodel_to_param_schema(model)

    return {
        "type": "function",
        "function": {
            "name": model.__name__ or func.__name__,
            "description": model.__doc__ or func.__doc__ or "",
            "parameters": params,
        },
    }


def func_to_basemodel(func: Callable) -> type[BaseModel]:
    params = inspect.signature(func).parameters
    fields = {}

    for name, param in params.items():
        annotation = param.annotation

        if annotation == inspect.Parameter.empty:
            warnings.warn(
                f"Parameter `{name}` of function `{func.__name__}` has no type hint. "
                "Using `Any` as a fallback."
            )
            annotation = Any

        if param.default != inspect.Parameter.empty:
            field = Field(default=param.default)
        else:
            field = Field()

        # Add the field to our fields dict
        fields[name] = (annotation, field)

    return create_model(func.__name__, **fields)


def basemodel_to_param_schema(model: type[BaseModel]) -> dict[str, object]:
    try:
        import openai
    except ImportError:
        raise ImportError(
            "The openai package is required for this functionality. "
            "Please install it with `pip install openai`."
        )

    # Lean on openai's ability to translate BaseModel.model_json_schema()
    # to a valid tool schema (this wouldn't be impossible to do ourselves,
    # but it's a fair amount of logic to substitute `$refs`, etc.)
    tool = openai.pydantic_function_tool(model)

    fn = tool["function"]
    if "parameters" not in fn:
        raise ValueError("Expected `parameters` in function definition.")

    params = fn["parameters"]

    # For some reason, openai (or pydantic?) wants to include a title
    # at the model and field level. I don't think we actually need or want this.
    if "title" in params:
        del params["title"]

    if "properties" in params and isinstance(params["properties"], dict):
        for prop in params["properties"].values():
            if "title" in prop:
                del prop["title"]

    return params
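To see how the pieces above fit together, here is a small sketch that builds a tool schema from a plain typed function. It assumes the `openai` package is installed, since `basemodel_to_param_schema` relies on `openai.pydantic_function_tool`; the import uses the private module shown above, though `Tool` is also part of the package's public API per `__all__`.

```python
from chatlas._tools import Tool  # Tool is also the module's only public export


def add(x: int, y: int) -> int:
    """Add two numbers together."""
    return x + y


tool = Tool(add)
print(tool.name)  # "add" (inferred from the function name)

# The generated OpenAI-style tool schema: the type hints on `x` and `y`
# become JSON-schema property types, and the model/field titles have been
# stripped as described in basemodel_to_param_schema().
print(tool.schema["function"]["parameters"])
```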