ommlds 0.0.0.dev466__py3-none-any.whl → 0.0.0.dev468__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ommlds has been flagged as potentially problematic; see the registry's release page for details.
- ommlds/.omlish-manifests.json +129 -6
- ommlds/__about__.py +2 -2
- ommlds/backends/ollama/__init__.py +0 -0
- ommlds/backends/ollama/protocol.py +170 -0
- ommlds/backends/transformers/__init__.py +0 -0
- ommlds/backends/transformers/streamers.py +73 -0
- ommlds/cli/sessions/chat/backends/catalog.py +1 -1
- ommlds/minichain/__init__.py +4 -0
- ommlds/minichain/backends/impls/llamacpp/chat.py +9 -0
- ommlds/minichain/backends/impls/llamacpp/stream.py +26 -10
- ommlds/minichain/backends/impls/mlx/chat.py +95 -21
- ommlds/minichain/backends/impls/ollama/__init__.py +0 -0
- ommlds/minichain/backends/impls/ollama/chat.py +196 -0
- ommlds/minichain/backends/impls/openai/chat.py +2 -2
- ommlds/minichain/backends/impls/openai/format.py +106 -107
- ommlds/minichain/backends/impls/openai/stream.py +14 -13
- ommlds/minichain/backends/impls/transformers/transformers.py +93 -14
- ommlds/minichain/chat/stream/types.py +3 -0
- ommlds/minichain/standard.py +7 -0
- {ommlds-0.0.0.dev466.dist-info → ommlds-0.0.0.dev468.dist-info}/METADATA +7 -7
- {ommlds-0.0.0.dev466.dist-info → ommlds-0.0.0.dev468.dist-info}/RECORD +25 -20
- ommlds/minichain/backends/impls/openai/format2.py +0 -210
- {ommlds-0.0.0.dev466.dist-info → ommlds-0.0.0.dev468.dist-info}/WHEEL +0 -0
- {ommlds-0.0.0.dev466.dist-info → ommlds-0.0.0.dev468.dist-info}/entry_points.txt +0 -0
- {ommlds-0.0.0.dev466.dist-info → ommlds-0.0.0.dev468.dist-info}/licenses/LICENSE +0 -0
- {ommlds-0.0.0.dev466.dist-info → ommlds-0.0.0.dev468.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import contextlib
|
|
1
2
|
import typing as ta
|
|
2
3
|
|
|
3
4
|
from omlish import check
|
|
@@ -5,6 +6,7 @@ from omlish import lang
|
|
|
5
6
|
from omlish import typedvalues as tv
|
|
6
7
|
|
|
7
8
|
from .....backends import mlx as mlxu
|
|
9
|
+
from ....chat.choices.services import ChatChoicesOutputs
|
|
8
10
|
from ....chat.choices.services import ChatChoicesRequest
|
|
9
11
|
from ....chat.choices.services import ChatChoicesResponse
|
|
10
12
|
from ....chat.choices.services import static_check_is_chat_choices_service
|
|
@@ -14,19 +16,28 @@ from ....chat.messages import AiMessage
|
|
|
14
16
|
from ....chat.messages import Message
|
|
15
17
|
from ....chat.messages import SystemMessage
|
|
16
18
|
from ....chat.messages import UserMessage
|
|
19
|
+
from ....chat.stream.services import ChatChoicesStreamRequest
|
|
20
|
+
from ....chat.stream.services import ChatChoicesStreamResponse
|
|
21
|
+
from ....chat.stream.services import static_check_is_chat_choices_stream_service
|
|
22
|
+
from ....chat.stream.types import AiChoiceDeltas
|
|
23
|
+
from ....chat.stream.types import AiChoicesDeltas
|
|
24
|
+
from ....chat.stream.types import ContentAiChoiceDelta
|
|
17
25
|
from ....configs import Config
|
|
18
26
|
from ....llms.types import MaxTokens
|
|
19
27
|
from ....models.configs import ModelPath
|
|
20
28
|
from ....models.configs import ModelRepo
|
|
21
29
|
from ....models.configs import ModelSpecifier
|
|
30
|
+
from ....resources import UseResources
|
|
22
31
|
from ....standard import DefaultOptions
|
|
32
|
+
from ....stream.services import StreamResponseSink
|
|
33
|
+
from ....stream.services import new_stream_response
|
|
23
34
|
|
|
24
35
|
|
|
25
36
|
##
|
|
26
37
|
|
|
27
38
|
|
|
28
39
|
# @omlish-manifest $.minichain.backends.strings.manifests.BackendStringsManifest(
|
|
29
|
-
# ['ChatChoicesService'],
|
|
40
|
+
# ['ChatChoicesService', 'ChatChoicesStreamService'],
|
|
30
41
|
# 'mlx',
|
|
31
42
|
# )
|
|
32
43
|
|
|
@@ -34,12 +45,7 @@ from ....standard import DefaultOptions
|
|
|
34
45
|
##
|
|
35
46
|
|
|
36
47
|
|
|
37
|
-
|
|
38
|
-
# name='mlx',
|
|
39
|
-
# type='ChatChoicesService',
|
|
40
|
-
# )
|
|
41
|
-
@static_check_is_chat_choices_service
|
|
42
|
-
class MlxChatChoicesService(lang.ExitStacked):
|
|
48
|
+
class BaseMlxChatChoicesService(lang.ExitStacked):
|
|
43
49
|
DEFAULT_MODEL: ta.ClassVar[ModelSpecifier] = (
|
|
44
50
|
# 'mlx-community/DeepSeek-Coder-V2-Lite-Instruct-8bit'
|
|
45
51
|
# 'mlx-community/Llama-3.3-70B-Instruct-4bit'
|
|
@@ -52,8 +58,8 @@ class MlxChatChoicesService(lang.ExitStacked):
|
|
|
52
58
|
# 'mlx-community/Qwen2.5-0.5B-4bit'
|
|
53
59
|
# 'mlx-community/Qwen2.5-32B-Instruct-8bit'
|
|
54
60
|
# 'mlx-community/Qwen2.5-Coder-32B-Instruct-8bit'
|
|
55
|
-
# 'mlx-community/mamba-2.8b-hf-f16'
|
|
56
61
|
# 'mlx-community/Qwen3-30B-A3B-6bit'
|
|
62
|
+
# 'mlx-community/mamba-2.8b-hf-f16'
|
|
57
63
|
)
|
|
58
64
|
|
|
59
65
|
def __init__(self, *configs: Config) -> None:
|
|
@@ -70,10 +76,7 @@ class MlxChatChoicesService(lang.ExitStacked):
|
|
|
70
76
|
}
|
|
71
77
|
|
|
72
78
|
def _get_msg_content(self, m: Message) -> str | None:
|
|
73
|
-
if isinstance(m, AiMessage):
|
|
74
|
-
return check.isinstance(m.c, str)
|
|
75
|
-
|
|
76
|
-
elif isinstance(m, (SystemMessage, UserMessage)):
|
|
79
|
+
if isinstance(m, (AiMessage, SystemMessage, UserMessage)):
|
|
77
80
|
return check.isinstance(m.c, str)
|
|
78
81
|
|
|
79
82
|
else:
|
|
@@ -96,10 +99,9 @@ class MlxChatChoicesService(lang.ExitStacked):
|
|
|
96
99
|
max_tokens=MaxTokens,
|
|
97
100
|
)
|
|
98
101
|
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
tokenizer = loaded_model.tokenization.tokenizer
|
|
102
|
+
@lang.cached_function(transient=True)
|
|
103
|
+
def _get_tokenizer(self) -> mlxu.tokenization.Tokenizer:
|
|
104
|
+
tokenizer = self._load_model().tokenization.tokenizer
|
|
103
105
|
|
|
104
106
|
if not (
|
|
105
107
|
hasattr(tokenizer, 'apply_chat_template') and
|
|
@@ -107,26 +109,44 @@ class MlxChatChoicesService(lang.ExitStacked):
|
|
|
107
109
|
):
|
|
108
110
|
raise RuntimeError(tokenizer)
|
|
109
111
|
|
|
110
|
-
|
|
112
|
+
return tokenizer
|
|
113
|
+
|
|
114
|
+
def _build_prompt(self, messages: ta.Sequence[Message]) -> str:
|
|
115
|
+
return check.isinstance(self._get_tokenizer().apply_chat_template(
|
|
111
116
|
[ # type: ignore[arg-type]
|
|
112
117
|
dict(
|
|
113
118
|
role=self.ROLES_MAP[type(m)],
|
|
114
119
|
content=self._get_msg_content(m),
|
|
115
120
|
)
|
|
116
|
-
for m in
|
|
121
|
+
for m in messages
|
|
117
122
|
],
|
|
118
123
|
tokenize=False,
|
|
119
124
|
add_generation_prompt=True,
|
|
120
|
-
)
|
|
125
|
+
), str)
|
|
121
126
|
|
|
122
|
-
|
|
127
|
+
def _build_kwargs(self, oc: tv.TypedValuesConsumer) -> dict[str, ta.Any]:
|
|
128
|
+
kwargs: dict[str, ta.Any] = {}
|
|
129
|
+
kwargs.update(oc.pop_scalar_kwargs(**self._OPTION_KWARG_NAMES_MAP))
|
|
130
|
+
return kwargs
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
# @omlish-manifest $.minichain.registries.manifests.RegistryManifest(
|
|
134
|
+
# name='mlx',
|
|
135
|
+
# type='ChatChoicesService',
|
|
136
|
+
# )
|
|
137
|
+
@static_check_is_chat_choices_service
|
|
138
|
+
class MlxChatChoicesService(BaseMlxChatChoicesService):
|
|
139
|
+
async def invoke(self, request: ChatChoicesRequest) -> ChatChoicesResponse:
|
|
140
|
+
loaded_model = self._load_model()
|
|
141
|
+
|
|
142
|
+
prompt = self._build_prompt(request.v)
|
|
123
143
|
|
|
124
144
|
with tv.consume(
|
|
125
145
|
*self._default_options,
|
|
126
146
|
*request.options,
|
|
127
147
|
override=True,
|
|
128
148
|
) as oc:
|
|
129
|
-
kwargs.
|
|
149
|
+
kwargs = self._build_kwargs(oc)
|
|
130
150
|
|
|
131
151
|
response = mlxu.generate(
|
|
132
152
|
loaded_model.model,
|
|
@@ -139,3 +159,57 @@ class MlxChatChoicesService(lang.ExitStacked):
|
|
|
139
159
|
return ChatChoicesResponse([
|
|
140
160
|
AiChoice([AiMessage(response)]) # noqa
|
|
141
161
|
])
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
# @omlish-manifest $.minichain.registries.manifests.RegistryManifest(
|
|
165
|
+
# name='mlx',
|
|
166
|
+
# type='ChatChoicesStreamService',
|
|
167
|
+
# )
|
|
168
|
+
@static_check_is_chat_choices_stream_service
class MlxChatChoicesStreamService(BaseMlxChatChoicesService):
    """
    Streaming MLX chat service.

    Renders the request chat with the tokenizer's chat template, runs ``mlxu.stream_generate`` and emits every
    non-empty chunk of generated text as a single-choice content delta.
    """

    def __init__(self, *configs: Config) -> None:
        # Delegate config handling (model, DefaultOptions) to the shared base class so this service is configured
        # exactly like MlxChatChoicesService, which inherits the same __init__, rather than re-consuming the configs
        # here against a sibling class's DEFAULT_MODEL.
        super().__init__(*configs)

    # NOTE(review): not referenced by this class - looks carried over from the HTTP-based streaming services.
    READ_CHUNK_SIZE = 64 * 1024

    async def invoke(
            self,
            request: ChatChoicesStreamRequest,
            *,
            max_tokens: int = 4096,  # FIXME: ChatOption
    ) -> ChatChoicesStreamResponse:
        """
        Start generation for ``request`` and return a stream response that emits AiChoicesDeltas.

        ``max_tokens`` is currently unused; generation parameters are built from the default options and the request
        options via ``_build_kwargs``.
        """

        loaded_model = self._load_model()

        prompt = self._build_prompt(request.v)

        with tv.consume(
                *self._default_options,
                *request.options,
                override=True,
        ) as oc:
            # UseResources is consumed by the resource scope below, not by the generator.
            oc.pop(UseResources, None)
            kwargs = self._build_kwargs(oc)

        async with UseResources.or_new(request.options) as rs:
            # Wrapping in closing() registers the generator with ``rs`` so it is closed when that context exits.
            gen: ta.Iterator[mlxu.GenerationOutput] = rs.enter_context(contextlib.closing(mlxu.stream_generate(
                loaded_model.model,
                loaded_model.tokenization,
                check.isinstance(prompt, str),
                mlxu.GenerationParams(**kwargs),
                # verbose=True,
            )))

            async def inner(sink: StreamResponseSink[AiChoicesDeltas]) -> ta.Sequence[ChatChoicesOutputs]:
                # NOTE(review): ``gen`` is a plain synchronous iterator, so each generation step runs on the event
                # loop thread.
                for go in gen:
                    if go.text:
                        await sink.emit(AiChoicesDeltas([AiChoiceDeltas([
                            ContentAiChoiceDelta(go.text),
                        ])]))

                return []

            return await new_stream_response(rs, inner)
|
|
File without changes
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
import typing as ta
|
|
2
|
+
|
|
3
|
+
from omlish import check
|
|
4
|
+
from omlish import lang
|
|
5
|
+
from omlish import marshal as msh
|
|
6
|
+
from omlish import typedvalues as tv
|
|
7
|
+
from omlish.formats import json
|
|
8
|
+
from omlish.http import all as http
|
|
9
|
+
from omlish.io.buffers import DelimitingBuffer
|
|
10
|
+
|
|
11
|
+
from .....backends.ollama import protocol as pt
|
|
12
|
+
from ....chat.choices.services import ChatChoicesOutputs
|
|
13
|
+
from ....chat.choices.services import ChatChoicesRequest
|
|
14
|
+
from ....chat.choices.services import ChatChoicesResponse
|
|
15
|
+
from ....chat.choices.services import static_check_is_chat_choices_service
|
|
16
|
+
from ....chat.choices.types import AiChoice
|
|
17
|
+
from ....chat.messages import AiMessage
|
|
18
|
+
from ....chat.messages import AnyAiMessage
|
|
19
|
+
from ....chat.messages import Message
|
|
20
|
+
from ....chat.messages import SystemMessage
|
|
21
|
+
from ....chat.messages import UserMessage
|
|
22
|
+
from ....chat.stream.services import ChatChoicesStreamRequest
|
|
23
|
+
from ....chat.stream.services import ChatChoicesStreamResponse
|
|
24
|
+
from ....chat.stream.services import static_check_is_chat_choices_stream_service
|
|
25
|
+
from ....chat.stream.types import AiChoiceDeltas
|
|
26
|
+
from ....chat.stream.types import AiChoicesDeltas
|
|
27
|
+
from ....chat.stream.types import ContentAiChoiceDelta
|
|
28
|
+
from ....models.configs import ModelName
|
|
29
|
+
from ....resources import UseResources
|
|
30
|
+
from ....standard import ApiUrl
|
|
31
|
+
from ....stream.services import StreamResponseSink
|
|
32
|
+
from ....stream.services import new_stream_response
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
##
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# @omlish-manifest $.minichain.backends.strings.manifests.BackendStringsManifest(
|
|
39
|
+
# [
|
|
40
|
+
# 'ChatChoicesService',
|
|
41
|
+
# 'ChatChoicesStreamService',
|
|
42
|
+
# ],
|
|
43
|
+
# 'ollama',
|
|
44
|
+
# )
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
##
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class BaseOllamaChatChoicesService(lang.Abstract):
    """Shared configuration and message translation for the Ollama chat services."""

    DEFAULT_API_URL: ta.ClassVar[ApiUrl] = ApiUrl('http://localhost:11434/api')
    DEFAULT_MODEL_NAME: ta.ClassVar[ModelName] = ModelName('llama3.2')

    def __init__(
            self,
            *configs: ApiUrl | ModelName,
    ) -> None:
        super().__init__()

        # cc.pop(default) presumably yields the supplied config of that type, falling back to the class default.
        with tv.consume(*configs) as cc:
            self._api_url = cc.pop(self.DEFAULT_API_URL)
            self._model_name = cc.pop(self.DEFAULT_MODEL_NAME)

    #

    # Minichain message type -> Ollama chat role. Looked up by exact ``type(m)``, so subclasses are not matched.
    ROLE_MAP: ta.ClassVar[ta.Mapping[type[Message], pt.Role]] = {
        SystemMessage: 'system',
        UserMessage: 'user',
        AiMessage: 'assistant',
    }

    @classmethod
    def _get_message_content(cls, m: Message) -> str | None:
        """Return the text content of a system/user/ai message; raise TypeError for any other message kind."""

        if not isinstance(m, (AiMessage, UserMessage, SystemMessage)):
            raise TypeError(m)

        # Only plain string content is supported here.
        return check.isinstance(m.c, str)

    @classmethod
    def _build_request_messages(cls, mc_msgs: ta.Iterable[Message]) -> ta.Sequence[pt.Message]:
        """Translate minichain messages into Ollama protocol messages, preserving order."""

        return [
            pt.Message(
                role=cls.ROLE_MAP[type(msg)],
                content=cls._get_message_content(msg),
            )
            for msg in mc_msgs
        ]
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
##
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
# @omlish-manifest $.minichain.registries.manifests.RegistryManifest(
|
|
94
|
+
# name='ollama',
|
|
95
|
+
# type='ChatChoicesService',
|
|
96
|
+
# )
|
|
97
|
+
@static_check_is_chat_choices_service
class OllamaChatChoicesService(BaseOllamaChatChoicesService):
    """
    Non-streaming Ollama chat service.

    POSTs a single ``stream=False`` chat request to ``<api_url>/chat`` and wraps the assistant reply as one choice.
    """

    async def invoke(
            self,
            request: ChatChoicesRequest,
    ) -> ChatChoicesResponse:
        ollama_msgs = self._build_request_messages(request.v)

        ollama_req = pt.ChatRequest(
            model=self._model_name.v,
            messages=ollama_msgs,
            # tools=tools or None,
            stream=False,
        )

        req_body = json.dumps(msh.marshal(ollama_req)).encode('utf-8')

        # NOTE(review): http.request is called without ``await``, so the whole round trip blocks the event loop.
        http_resp = http.request(
            self._api_url.v.removesuffix('/') + '/chat',
            data=req_body,
        )

        ollama_resp = msh.unmarshal(
            json.loads(check.not_none(http_resp.data).decode('utf-8')),
            pt.ChatResponse,
        )

        reply = ollama_resp.message
        if reply.role != 'assistant':
            raise TypeError(reply.role)

        ai_msgs: list[AnyAiMessage] = [
            AiMessage(check.not_none(reply.content)),
        ]

        return ChatChoicesResponse([
            AiChoice(ai_msgs),
        ])
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
##
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
# @omlish-manifest $.minichain.registries.manifests.RegistryManifest(
|
|
140
|
+
# name='ollama',
|
|
141
|
+
# type='ChatChoicesStreamService',
|
|
142
|
+
# )
|
|
143
|
+
@static_check_is_chat_choices_stream_service
class OllamaChatChoicesStreamService(BaseOllamaChatChoicesService):
    """
    Streaming Ollama chat service.

    POSTs a ``stream=True`` chat request and emits a content delta for every non-empty ``message.content`` found in
    the line-delimited JSON response body.
    """

    READ_CHUNK_SIZE = 64 * 1024

    async def invoke(
            self,
            request: ChatChoicesStreamRequest,
    ) -> ChatChoicesStreamResponse:
        messages = self._build_request_messages(request.v)

        a_req = pt.ChatRequest(
            model=self._model_name.v,
            messages=messages,
            # tools=tools or None,
            stream=True,
        )

        raw_request = msh.marshal(a_req)

        http_request = http.HttpRequest(
            self._api_url.v.removesuffix('/') + '/chat',
            data=json.dumps(raw_request).encode('utf-8'),
        )

        async with UseResources.or_new(request.options) as rs:
            http_client = rs.enter_context(http.client())
            http_response = rs.enter_context(http_client.stream_request(http_request))

            async def inner(sink: StreamResponseSink[AiChoicesDeltas]) -> ta.Sequence[ChatChoicesOutputs] | None:
                db = DelimitingBuffer([b'\r', b'\n', b'\r\n'])
                while True:
                    # FIXME: read1 not on response stream protocol
                    b = http_response.stream.read1(self.READ_CHUNK_SIZE)  # type: ignore[attr-defined]
                    for l in db.feed(b):
                        if isinstance(l, DelimitingBuffer.Incomplete):
                            # FIXME: handle
                            return []

                        # Blank lines carry no JSON payload (e.g. an empty segment if a CRLF gets split on both the
                        # '\r' and '\n' delimiters) - skip them instead of failing in json.loads.
                        if not l.strip():
                            continue

                        lj = json.loads(l.decode('utf-8'))
                        lp: pt.ChatResponse = msh.unmarshal(lj, pt.ChatResponse)

                        # Tool calling is not requested above, so only plain assistant text is expected.
                        check.state(lp.message.role == 'assistant')
                        check.none(lp.message.tool_name)
                        check.state(not lp.message.tool_calls)

                        if (c := lp.message.content):
                            await sink.emit(AiChoicesDeltas([AiChoiceDeltas([ContentAiChoiceDelta(
                                c,
                            )])]))

                    # An empty read means end of stream; the final feed(b'') above flushed any buffered data.
                    if not b:
                        return []

            return await new_stream_response(rs, inner)
|
|
@@ -26,8 +26,8 @@ from ....chat.choices.services import static_check_is_chat_choices_service
|
|
|
26
26
|
from ....models.configs import ModelName
|
|
27
27
|
from ....standard import ApiKey
|
|
28
28
|
from ....standard import DefaultOptions
|
|
29
|
-
from .
|
|
30
|
-
from .
|
|
29
|
+
from .format import OpenaiChatRequestHandler
|
|
30
|
+
from .format import build_mc_choices_response
|
|
31
31
|
from .names import MODEL_NAMES
|
|
32
32
|
|
|
33
33
|
|
|
@@ -2,18 +2,17 @@ import typing as ta
|
|
|
2
2
|
|
|
3
3
|
from omlish import cached
|
|
4
4
|
from omlish import check
|
|
5
|
-
from omlish import lang
|
|
6
5
|
from omlish import typedvalues as tv
|
|
7
6
|
from omlish.formats import json
|
|
8
7
|
|
|
8
|
+
from .....backends.openai import protocol as pt
|
|
9
9
|
from ....chat.choices.services import ChatChoicesResponse
|
|
10
10
|
from ....chat.choices.types import AiChoice
|
|
11
|
+
from ....chat.choices.types import AiChoices
|
|
11
12
|
from ....chat.choices.types import ChatChoicesOptions
|
|
12
|
-
from ....chat.messages import AiChat
|
|
13
13
|
from ....chat.messages import AiMessage
|
|
14
14
|
from ....chat.messages import AnyAiMessage
|
|
15
15
|
from ....chat.messages import Chat
|
|
16
|
-
from ....chat.messages import Message
|
|
17
16
|
from ....chat.messages import SystemMessage
|
|
18
17
|
from ....chat.messages import ToolUseMessage
|
|
19
18
|
from ....chat.messages import ToolUseResultMessage
|
|
@@ -28,7 +27,7 @@ from ....llms.types import MaxTokens
|
|
|
28
27
|
from ....llms.types import Temperature
|
|
29
28
|
from ....llms.types import TokenUsage
|
|
30
29
|
from ....llms.types import TokenUsageOutput
|
|
31
|
-
from ....tools.jsonschema import
|
|
30
|
+
from ....tools.jsonschema import build_tool_spec_params_json_schema
|
|
32
31
|
from ....tools.types import ToolSpec
|
|
33
32
|
from ....tools.types import ToolUse
|
|
34
33
|
from ....types import Option
|
|
@@ -37,61 +36,115 @@ from ....types import Option
|
|
|
37
36
|
##
|
|
38
37
|
|
|
39
38
|
|
|
40
|
-
def
|
|
41
|
-
|
|
39
|
+
def build_oai_request_msgs(mc_chat: Chat) -> ta.Sequence[pt.ChatCompletionMessage]:
    """
    Translate a minichain chat into OpenAI chat-completion request messages.

    Tool uses are folded into the assistant message immediately preceding them (an AiMessage or an earlier tool use),
    so a run of parallel tool calls becomes ONE assistant message carrying every ``tool_calls`` entry. Emitting one
    assistant message per tool call would leave all but the last call without directly-following ``tool`` responses,
    which the OpenAI API rejects.

    Raises TypeError for unsupported message kinds or tool-result content types.
    """

    oai_msgs: list[pt.ChatCompletionMessage] = []

    # Index in oai_msgs of the assistant message that a following ToolUseMessage should be folded into, or None when
    # the most recently emitted message is not an assistant message.
    asst_idx: int | None = None

    for mc_msg in mc_chat:
        if isinstance(mc_msg, SystemMessage):
            oai_msgs.append(pt.SystemChatCompletionMessage(
                content=check.isinstance(mc_msg.c, str),
            ))
            asst_idx = None

        elif isinstance(mc_msg, AiMessage):
            oai_msgs.append(pt.AssistantChatCompletionMessage(
                content=check.isinstance(mc_msg.c, (str, None)),
            ))
            asst_idx = len(oai_msgs) - 1

        elif isinstance(mc_msg, ToolUseMessage):
            oai_tc = pt.AssistantChatCompletionMessage.ToolCall(
                id=check.not_none(mc_msg.tu.id),
                function=pt.AssistantChatCompletionMessage.ToolCall.Function(
                    arguments=check.not_none(mc_msg.tu.raw_args),
                    name=mc_msg.tu.name,
                ),
            )

            if asst_idx is None:
                oai_msgs.append(pt.AssistantChatCompletionMessage(
                    tool_calls=[oai_tc],
                ))
                asst_idx = len(oai_msgs) - 1

            else:
                # Rebuild the preceding assistant message with this call appended, keeping its text content.
                prev = check.isinstance(oai_msgs[asst_idx], pt.AssistantChatCompletionMessage)
                oai_msgs[asst_idx] = pt.AssistantChatCompletionMessage(
                    content=prev.content,
                    tool_calls=[*(prev.tool_calls or []), oai_tc],
                )

        elif isinstance(mc_msg, UserMessage):
            oai_msgs.append(pt.UserChatCompletionMessage(
                content=prepare_content_str(mc_msg.c),
            ))
            asst_idx = None

        elif isinstance(mc_msg, ToolUseResultMessage):
            tc: str
            if isinstance(mc_msg.tur.c, str):
                tc = mc_msg.tur.c
            elif isinstance(mc_msg.tur.c, JsonContent):
                tc = json.dumps_compact(mc_msg.tur.c)
            else:
                raise TypeError(mc_msg.tur.c)
            oai_msgs.append(pt.ToolChatCompletionMessage(
                tool_call_id=check.not_none(mc_msg.tur.id),
                content=tc,
            ))
            asst_idx = None

        else:
            raise TypeError(mc_msg)

    return oai_msgs
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
#
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def build_mc_ai_choice(oai_choice: pt.ChatCompletionResponseChoice) -> AiChoice:
    """
    Convert one OpenAI response choice into a minichain AiChoice.

    Text content, when present, becomes a leading AiMessage; each tool call then becomes a ToolUseMessage, in order.
    A missing or empty arguments string is parsed as an empty JSON object.
    """

    resp_msg = oai_choice.message
    msgs: list[AnyAiMessage] = []

    text = resp_msg.content
    if text is not None:
        msgs.append(AiMessage(check.isinstance(text, str)))

    msgs.extend(
        ToolUseMessage(ToolUse(
            id=call.id,
            name=call.function.name,
            args=json.loads(call.function.arguments or '{}'),
            raw_args=call.function.arguments,
        ))
        for call in (resp_msg.tool_calls or [])
    )

    return AiChoice(msgs)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def build_mc_ai_choices(oai_resp: pt.ChatCompletionResponse) -> AiChoices:
    """Convert every choice of an OpenAI chat-completion response, preserving order."""

    return list(map(build_mc_ai_choice, oai_resp.choices))
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def build_mc_choices_response(oai_resp: pt.ChatCompletionResponse) -> ChatChoicesResponse:
    """
    Build a ChatChoicesResponse from an OpenAI chat-completion response.

    A TokenUsageOutput is attached only when the response reports ``usage``.
    """

    choices = build_mc_ai_choices(oai_resp)

    outputs: list[TokenUsageOutput] = []
    usage = oai_resp.usage
    if usage is not None:
        outputs.append(TokenUsageOutput(TokenUsage(
            input=usage.prompt_tokens,
            output=usage.completion_tokens,
            total=usage.total_tokens,
        )))

    return ChatChoicesResponse(
        choices,
        tv.TypedValues(*outputs),
    )
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def build_mc_ai_choice_delta(delta: pt.ChatCompletionChunkChoiceDelta) -> AiChoiceDelta:
    """
    Convert one streamed OpenAI chunk delta into a minichain choice delta.

    A delta carrying text becomes a ContentAiChoiceDelta (and must not also carry tool calls). Otherwise it must carry
    exactly one tool call, which becomes a PartialToolUseAiChoiceDelta. A delta with neither raises ValueError.
    """

    text = delta.content
    if text is not None:
        check.state(not delta.tool_calls)
        return ContentAiChoiceDelta(text)

    calls = delta.tool_calls
    if calls is None:
        raise ValueError(delta)

    call = check.single(calls)
    fn = check.not_none(call.function)
    return PartialToolUseAiChoiceDelta(
        id=call.id,
        name=fn.name,
        raw_args=fn.arguments,
    )
|
|
95
148
|
|
|
96
149
|
|
|
97
150
|
##
|
|
@@ -112,14 +165,6 @@ class OpenaiChatRequestHandler:
|
|
|
112
165
|
self._model = model
|
|
113
166
|
self._mandatory_kwargs = mandatory_kwargs
|
|
114
167
|
|
|
115
|
-
ROLES_MAP: ta.ClassVar[ta.Mapping[type[Message], str]] = {
|
|
116
|
-
SystemMessage: 'system',
|
|
117
|
-
UserMessage: 'user',
|
|
118
|
-
AiMessage: 'assistant',
|
|
119
|
-
ToolUseMessage: 'assistant',
|
|
120
|
-
ToolUseResultMessage: 'tool',
|
|
121
|
-
}
|
|
122
|
-
|
|
123
168
|
DEFAULT_OPTIONS: ta.ClassVar[tv.TypedValues[Option]] = tv.TypedValues[Option](
|
|
124
169
|
Temperature(0.),
|
|
125
170
|
MaxTokens(1024),
|
|
@@ -162,72 +207,26 @@ class OpenaiChatRequestHandler:
|
|
|
162
207
|
)
|
|
163
208
|
|
|
164
209
|
@cached.function
def oai_request(self) -> pt.ChatCompletionRequest:
    """
    Build the OpenAI chat-completion request for this handler's chat and options (memoized via ``cached.function``).
    """

    po = self._process_options()

    tools: list[pt.ChatCompletionRequestTool] = []
    for spec in po.tools_by_name.values():
        tools.append(pt.ChatCompletionRequestTool(
            function=pt.ChatCompletionRequestTool.Function(
                name=check.not_none(spec.name),
                description=prepare_content_str(spec.desc),
                parameters=build_tool_spec_params_json_schema(spec),
            ),
        ))

    return pt.ChatCompletionRequest(
        model=self._model,
        messages=build_oai_request_msgs(self._chat),
        top_p=1,
        # Pass None rather than an empty list when no tools are configured.
        tools=tools or None,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        **po.kwargs,
    )
|
|
185
|
-
|
|
186
|
-
def build_ai_chat(self, message: ta.Mapping[str, ta.Any]) -> AiChat:
|
|
187
|
-
out: list[AnyAiMessage] = []
|
|
188
|
-
if (c := message.get('content')) is not None:
|
|
189
|
-
out.append(AiMessage(c))
|
|
190
|
-
for tc in message.get('tool_calls', []):
|
|
191
|
-
out.append(ToolUseMessage(
|
|
192
|
-
ToolUse(
|
|
193
|
-
id=tc['id'],
|
|
194
|
-
name=tc['function']['name'],
|
|
195
|
-
args=json.loads(tc['function']['arguments'] or '{}'),
|
|
196
|
-
raw_args=tc['function']['arguments'],
|
|
197
|
-
),
|
|
198
|
-
))
|
|
199
|
-
return out
|
|
200
|
-
|
|
201
|
-
def build_response(self, raw_response: ta.Mapping[str, ta.Any]) -> ChatChoicesResponse:
|
|
202
|
-
return ChatChoicesResponse(
|
|
203
|
-
[
|
|
204
|
-
AiChoice(self.build_ai_chat(choice['message']))
|
|
205
|
-
for choice in raw_response['choices']
|
|
206
|
-
],
|
|
207
|
-
|
|
208
|
-
tv.TypedValues(
|
|
209
|
-
*([TokenUsageOutput(TokenUsage(
|
|
210
|
-
input=tu['prompt_tokens'],
|
|
211
|
-
output=tu['completion_tokens'],
|
|
212
|
-
total=tu['total_tokens'],
|
|
213
|
-
))] if (tu := raw_response.get('usage')) is not None else []),
|
|
214
|
-
),
|
|
215
|
-
)
|
|
216
|
-
|
|
217
|
-
def build_ai_choice_delta(self, delta: ta.Mapping[str, ta.Any]) -> AiChoiceDelta:
|
|
218
|
-
if (c := delta.get('content')) is not None:
|
|
219
|
-
check.state(not delta.get('tool_calls'))
|
|
220
|
-
return ContentAiChoiceDelta(c)
|
|
221
|
-
|
|
222
|
-
elif (tcs := delta.get('tool_calls')) is not None: # noqa
|
|
223
|
-
check.state(delta.get('content') is None)
|
|
224
|
-
tc = check.single(tcs)
|
|
225
|
-
tc_fn = tc['function']
|
|
226
|
-
return PartialToolUseAiChoiceDelta(
|
|
227
|
-
id=tc.get('id'),
|
|
228
|
-
name=tc_fn.get('name'),
|
|
229
|
-
raw_args=tc_fn.get('arguments'),
|
|
230
|
-
)
|
|
231
|
-
|
|
232
|
-
else:
|
|
233
|
-
raise ValueError(delta)
|