ommlds 0.0.0.dev487__py3-none-any.whl → 0.0.0.dev488__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ommlds/.omlish-manifests.json +54 -0
- ommlds/__about__.py +1 -1
- ommlds/backends/cerebras/__init__.py +7 -0
- ommlds/backends/cerebras/_dataclasses.py +4254 -0
- ommlds/backends/cerebras/_marshal.py +24 -0
- ommlds/backends/cerebras/protocol.py +312 -0
- ommlds/cli/_dataclasses.py +272 -88
- ommlds/cli/main.py +29 -9
- ommlds/cli/secrets.py +1 -0
- ommlds/cli/sessions/chat/chat/user/configs.py +0 -1
- ommlds/cli/sessions/chat/chat/user/inject.py +0 -7
- ommlds/cli/sessions/chat/configs.py +2 -0
- ommlds/cli/sessions/chat/inject.py +3 -0
- ommlds/cli/sessions/chat/interface/__init__.py +0 -0
- ommlds/cli/sessions/chat/interface/bare/__init__.py +0 -0
- ommlds/cli/sessions/chat/interface/bare/inject.py +28 -0
- ommlds/cli/sessions/chat/interface/bare/interface.py +19 -0
- ommlds/cli/sessions/chat/interface/base.py +13 -0
- ommlds/cli/sessions/chat/interface/configs.py +15 -0
- ommlds/cli/sessions/chat/interface/inject.py +24 -0
- ommlds/cli/sessions/chat/interface/textual/__init__.py +0 -0
- ommlds/cli/sessions/chat/interface/textual/app.py +14 -0
- ommlds/cli/sessions/chat/interface/textual/inject.py +20 -0
- ommlds/cli/sessions/chat/interface/textual/interface.py +22 -0
- ommlds/cli/sessions/chat/session.py +12 -4
- ommlds/minichain/backends/impls/cerebras/__init__.py +0 -0
- ommlds/minichain/backends/impls/cerebras/chat.py +80 -0
- ommlds/minichain/backends/impls/cerebras/names.py +30 -0
- ommlds/minichain/backends/impls/cerebras/protocol.py +143 -0
- ommlds/minichain/backends/impls/cerebras/stream.py +125 -0
- {ommlds-0.0.0.dev487.dist-info → ommlds-0.0.0.dev488.dist-info}/METADATA +6 -6
- {ommlds-0.0.0.dev487.dist-info → ommlds-0.0.0.dev488.dist-info}/RECORD +36 -16
- {ommlds-0.0.0.dev487.dist-info → ommlds-0.0.0.dev488.dist-info}/WHEEL +0 -0
- {ommlds-0.0.0.dev487.dist-info → ommlds-0.0.0.dev488.dist-info}/entry_points.txt +0 -0
- {ommlds-0.0.0.dev487.dist-info → ommlds-0.0.0.dev488.dist-info}/licenses/LICENSE +0 -0
- {ommlds-0.0.0.dev487.dist-info → ommlds-0.0.0.dev488.dist-info}/top_level.txt +0 -0
ommlds/cli/main.py
CHANGED
|
@@ -77,31 +77,47 @@ class ChatProfile(Profile):
|
|
|
77
77
|
|
|
78
78
|
#
|
|
79
79
|
|
|
80
|
-
|
|
81
|
-
ap.arg('
|
|
82
|
-
ap.arg('-
|
|
83
|
-
ap.arg('-e', '--editor', action='store_true', group='
|
|
80
|
+
INTERFACE_ARGS: ta.ClassVar[ta.Sequence[ap.Arg]] = [
|
|
81
|
+
ap.arg('-i', '--interactive', action='store_true', group='interface'),
|
|
82
|
+
ap.arg('-T', '--textual', action='store_true', group='interface'),
|
|
83
|
+
ap.arg('-e', '--editor', action='store_true', group='interface'),
|
|
84
84
|
]
|
|
85
85
|
|
|
86
|
-
def
|
|
86
|
+
def configure_interface(self, cfg: ChatConfig) -> ChatConfig:
|
|
87
87
|
if self._args.editor:
|
|
88
88
|
check.arg(not self._args.interactive)
|
|
89
89
|
check.arg(not self._args.message)
|
|
90
90
|
raise NotImplementedError
|
|
91
91
|
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
return dc.replace(
|
|
92
|
+
if self._args.interactive:
|
|
93
|
+
cfg = dc.replace(
|
|
95
94
|
cfg,
|
|
95
|
+
interface=dc.replace(
|
|
96
|
+
cfg.interface,
|
|
97
|
+
interactive=True,
|
|
98
|
+
use_textual=self._args.textual,
|
|
99
|
+
),
|
|
96
100
|
user=dc.replace(
|
|
97
101
|
cfg.user,
|
|
98
102
|
interactive=True,
|
|
99
103
|
),
|
|
100
104
|
)
|
|
101
105
|
|
|
106
|
+
return cfg
|
|
107
|
+
|
|
108
|
+
#
|
|
109
|
+
|
|
110
|
+
INPUT_ARGS: ta.ClassVar[ta.Sequence[ap.Arg]] = [
|
|
111
|
+
ap.arg('message', nargs='*', group='input'),
|
|
112
|
+
]
|
|
113
|
+
|
|
114
|
+
def configure_input(self, cfg: ChatConfig) -> ChatConfig:
|
|
115
|
+
if self._args.interactive:
|
|
116
|
+
check.arg(not self._args.message)
|
|
117
|
+
|
|
102
118
|
elif self._args.message:
|
|
103
119
|
# TODO: '-' -> stdin
|
|
104
|
-
|
|
120
|
+
cfg = dc.replace(
|
|
105
121
|
cfg,
|
|
106
122
|
user=dc.replace(
|
|
107
123
|
cfg.user,
|
|
@@ -112,6 +128,8 @@ class ChatProfile(Profile):
|
|
|
112
128
|
else:
|
|
113
129
|
raise ValueError('Must specify input')
|
|
114
130
|
|
|
131
|
+
return cfg
|
|
132
|
+
|
|
115
133
|
#
|
|
116
134
|
|
|
117
135
|
STATE_ARGS: ta.ClassVar[ta.Sequence[ap.Arg]] = [
|
|
@@ -220,6 +238,7 @@ class ChatProfile(Profile):
|
|
|
220
238
|
|
|
221
239
|
for grp_name, grp_args in [
|
|
222
240
|
('backend', self.BACKEND_ARGS),
|
|
241
|
+
('interface', self.INTERFACE_ARGS),
|
|
223
242
|
('input', self.INPUT_ARGS),
|
|
224
243
|
('state', self.STATE_ARGS),
|
|
225
244
|
('output', self.OUTPUT_ARGS),
|
|
@@ -234,6 +253,7 @@ class ChatProfile(Profile):
|
|
|
234
253
|
|
|
235
254
|
cfg = ChatConfig()
|
|
236
255
|
cfg = self.configure_backend(cfg)
|
|
256
|
+
cfg = self.configure_interface(cfg)
|
|
237
257
|
cfg = self.configure_input(cfg)
|
|
238
258
|
cfg = self.configure_state(cfg)
|
|
239
259
|
cfg = self.configure_output(cfg)
|
ommlds/cli/secrets.py
CHANGED
|
@@ -9,8 +9,6 @@ from .configs import UserConfig
|
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
with lang.auto_proxy_import(globals()):
|
|
12
|
-
from .....inputs import asyncs as _inputs_asyncs
|
|
13
|
-
from .....inputs import sync as _inputs_sync
|
|
14
12
|
from ..state import types as _state
|
|
15
13
|
from . import interactive as _interactive
|
|
16
14
|
from . import oneshot as _oneshot
|
|
@@ -45,11 +43,6 @@ def bind_user(cfg: UserConfig = UserConfig()) -> inj.Elements:
|
|
|
45
43
|
|
|
46
44
|
els.append(inj.bind(_types.UserChatInput, to_ctor=_interactive.InteractiveUserChatInput, singleton=True))
|
|
47
45
|
|
|
48
|
-
els.extend([
|
|
49
|
-
inj.bind(_inputs_sync.SyncStringInput, to_const=_inputs_sync.InputSyncStringInput(use_readline=cfg.use_readline)), # noqa
|
|
50
|
-
inj.bind(_inputs_asyncs.AsyncStringInput, to_ctor=_inputs_asyncs.ThreadAsyncStringInput, singleton=True),
|
|
51
|
-
])
|
|
52
|
-
|
|
53
46
|
else:
|
|
54
47
|
if cfg.initial_user_content is None:
|
|
55
48
|
raise ValueError('Initial user content is required for non-interactive chat')
|
|
@@ -5,6 +5,7 @@ from ...rendering.configs import RenderingConfig
|
|
|
5
5
|
from .chat.ai.configs import AiConfig
|
|
6
6
|
from .chat.state.configs import StateConfig
|
|
7
7
|
from .chat.user.configs import UserConfig
|
|
8
|
+
from .interface.configs import InterfaceConfig
|
|
8
9
|
from .tools.configs import ToolsConfig
|
|
9
10
|
|
|
10
11
|
|
|
@@ -24,4 +25,5 @@ class ChatConfig:
|
|
|
24
25
|
state: StateConfig = StateConfig()
|
|
25
26
|
user: UserConfig = UserConfig()
|
|
26
27
|
rendering: RenderingConfig = RenderingConfig()
|
|
28
|
+
interface: InterfaceConfig = InterfaceConfig()
|
|
27
29
|
tools: ToolsConfig = ToolsConfig()
|
|
@@ -16,6 +16,7 @@ with lang.auto_proxy_import(globals()):
|
|
|
16
16
|
from .chat.ai import inject as _chat_ai
|
|
17
17
|
from .chat.state import inject as _chat_state
|
|
18
18
|
from .chat.user import inject as _chat_user
|
|
19
|
+
from .interface import inject as _interface
|
|
19
20
|
from .phases import inject as _phases
|
|
20
21
|
from .tools import inject as _tools
|
|
21
22
|
|
|
@@ -37,6 +38,8 @@ def bind_chat(cfg: ChatConfig) -> inj.Elements:
|
|
|
37
38
|
|
|
38
39
|
_chat_state.bind_state(cfg.state),
|
|
39
40
|
|
|
41
|
+
_interface.bind_interface(cfg.interface),
|
|
42
|
+
|
|
40
43
|
_phases.bind_phases(),
|
|
41
44
|
|
|
42
45
|
_rendering.bind_rendering(cfg.rendering),
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from omlish import inject as inj
|
|
2
|
+
from omlish import lang
|
|
3
|
+
|
|
4
|
+
from ..base import ChatInterface
|
|
5
|
+
from ..configs import InterfaceConfig
|
|
6
|
+
from .interface import BareChatInterface
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
with lang.auto_proxy_import(globals()):
|
|
10
|
+
from .....inputs import asyncs as _inputs_asyncs
|
|
11
|
+
from .....inputs import sync as _inputs_sync
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
##
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def bind_bare(cfg: InterfaceConfig = InterfaceConfig()) -> inj.Elements:
    """
    Build the injector elements for the bare chat interface.

    `ChatInterface` is always bound to a singleton `BareChatInterface`. When `cfg.interactive` is set, bindings for the
    sync and async string inputs are added as well; the sync input is constructed eagerly with `cfg.use_readline`.
    """

    bindings: list[inj.Elemental] = [
        inj.bind(ChatInterface, to_ctor=BareChatInterface, singleton=True),
    ]

    if cfg.interactive:
        sync_input = _inputs_sync.InputSyncStringInput(use_readline=cfg.use_readline)

        bindings += [
            inj.bind(_inputs_sync.SyncStringInput, to_const=sync_input),
            inj.bind(_inputs_asyncs.AsyncStringInput, to_ctor=_inputs_asyncs.ThreadAsyncStringInput, singleton=True),
        ]

    return inj.as_elements(*bindings)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from ...driver import ChatDriver
|
|
2
|
+
from ..base import ChatInterface
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
##
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BareChatInterface(ChatInterface):
    """
    Chat interface with no UI of its own: running it just runs the injected chat driver.
    """

    def __init__(self, *, driver: ChatDriver) -> None:
        super().__init__()

        self._driver = driver

    async def run(self) -> None:
        # All work is delegated to the driver.
        await self._driver.run()
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import typing as ta
|
|
2
|
+
|
|
3
|
+
from omlish import dataclasses as dc
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
##
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dc.dataclass(frozen=True, kw_only=True)
class InterfaceConfig:
    """
    Immutable, keyword-only options for the chat interface layer.
    """

    # Interactive mode flag; off by default.
    interactive: bool = False

    # Request the textual interface rather than the bare one; off by default.
    use_textual: bool = False

    # Readline setting for interactive string input: an explicit bool, or 'auto'.
    use_readline: bool | ta.Literal['auto'] = 'auto'
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from omlish import inject as inj
|
|
2
|
+
from omlish import lang
|
|
3
|
+
|
|
4
|
+
from .configs import InterfaceConfig
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
with lang.auto_proxy_import(globals()):
|
|
8
|
+
from .bare import inject as _bare
|
|
9
|
+
from .textual import inject as _textual
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
##
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def bind_interface(cfg: InterfaceConfig = InterfaceConfig()) -> inj.Elements:
    """
    Pick the interface bindings for the given config.

    Returns the textual interface's elements when `cfg.use_textual` is set, otherwise the bare interface's elements
    (which receive `cfg`).
    """

    selected = _textual.bind_textual() if cfg.use_textual else _bare.bind_bare(cfg)

    return inj.as_elements(selected)
|
|
File without changes
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from omlish import inject as inj
|
|
2
|
+
|
|
3
|
+
from ..base import ChatInterface
|
|
4
|
+
from .app import ChatApp
|
|
5
|
+
from .interface import TextualChatInterface
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
##
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def bind_textual() -> inj.Elements:
    """
    Build the injector elements for the textual chat interface.

    Binds `ChatInterface` to a singleton `TextualChatInterface` and registers `ChatApp` as a singleton.
    """

    return inj.as_elements(
        inj.bind(ChatInterface, to_ctor=TextualChatInterface, singleton=True),
        inj.bind(ChatApp, singleton=True),
    )
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from ...driver import ChatDriver
|
|
2
|
+
from ..base import ChatInterface
|
|
3
|
+
from .app import ChatApp
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
##
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TextualChatInterface(ChatInterface):
    """
    Chat interface backed by a `ChatApp`.

    The chat driver is injected and retained, but `run` as written here only awaits the app's `run_async`.
    """

    def __init__(self, *, driver: ChatDriver, app: ChatApp) -> None:
        super().__init__()

        self._driver = driver
        self._app = app

    async def run(self) -> None:
        await self._app.run_async()
|
|
@@ -1,14 +1,22 @@
|
|
|
1
|
+
import typing as ta
|
|
2
|
+
|
|
1
3
|
from omlish import dataclasses as dc
|
|
2
4
|
|
|
3
5
|
from ..base import Session
|
|
4
6
|
from .configs import ChatConfig
|
|
5
|
-
from .
|
|
7
|
+
from .interface.base import ChatInterface
|
|
6
8
|
|
|
7
9
|
|
|
8
10
|
##
|
|
9
11
|
|
|
10
12
|
|
|
13
|
+
@ta.final
|
|
11
14
|
class ChatSession(Session['ChatSession.Config']):
|
|
15
|
+
"""
|
|
16
|
+
An adapter to the lower level, dumber, non-chat-specific cli 'session' layer. Nothing else takes the kitchen-sink
|
|
17
|
+
'ChatConfig' object, it's only here for type dispatch in lower layers.
|
|
18
|
+
"""
|
|
19
|
+
|
|
12
20
|
@dc.dataclass(frozen=True)
|
|
13
21
|
class Config(Session.Config, ChatConfig):
|
|
14
22
|
pass
|
|
@@ -17,11 +25,11 @@ class ChatSession(Session['ChatSession.Config']):
|
|
|
17
25
|
self,
|
|
18
26
|
config: Config,
|
|
19
27
|
*,
|
|
20
|
-
|
|
28
|
+
interface: ChatInterface,
|
|
21
29
|
) -> None:
|
|
22
30
|
super().__init__(config)
|
|
23
31
|
|
|
24
|
-
self.
|
|
32
|
+
self._interface = interface
|
|
25
33
|
|
|
26
34
|
async def run(self) -> None:
|
|
27
|
-
await self.
|
|
35
|
+
await self._interface.run()
|
|
File without changes
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
import typing as ta
|
|
2
|
+
|
|
3
|
+
from omlish import check
|
|
4
|
+
from omlish import marshal as msh
|
|
5
|
+
from omlish import typedvalues as tv
|
|
6
|
+
from omlish.formats import json
|
|
7
|
+
from omlish.http import all as http
|
|
8
|
+
|
|
9
|
+
from .....backends.cerebras import protocol as pt
|
|
10
|
+
from ....chat.choices.services import ChatChoicesRequest
|
|
11
|
+
from ....chat.choices.services import ChatChoicesResponse
|
|
12
|
+
from ....chat.choices.services import static_check_is_chat_choices_service
|
|
13
|
+
from ....chat.tools.types import Tool
|
|
14
|
+
from ....models.configs import ModelName
|
|
15
|
+
from ....standard import ApiKey
|
|
16
|
+
from ....standard import DefaultOptions
|
|
17
|
+
from .names import MODEL_NAMES
|
|
18
|
+
from .protocol import build_cer_request_messages
|
|
19
|
+
from .protocol import build_cer_request_tool
|
|
20
|
+
from .protocol import build_mc_choices_response
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
##
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# @omlish-manifest $.minichain.registries.manifests.RegistryManifest(
|
|
27
|
+
# name='cerebras',
|
|
28
|
+
# type='ChatChoicesService',
|
|
29
|
+
# )
|
|
30
|
+
@static_check_is_chat_choices_service
class CerebrasChatChoicesService:
    """
    Chat-choices service that sends a single (non-streaming) chat completion request to the Cerebras HTTP API.
    """

    DEFAULT_MODEL_NAME: ta.ClassVar[ModelName] = ModelName(check.not_none(MODEL_NAMES.default))

    def __init__(
            self,
            *configs: ApiKey | ModelName | DefaultOptions,
            http_client: http.AsyncHttpClient | None = None,
    ) -> None:
        """
        Args:
            configs: Typed config values; a model name, an api key and default options are popped from them.
            http_client: Optional async http client passed through to the request call.
        """

        super().__init__()

        self._http_client = http_client

        with tv.consume(*configs) as cc:
            # Model name falls back to the class default.
            self._model_name = cc.pop(self.DEFAULT_MODEL_NAME)
            # The api key is popped with 'CEREBRAS_API_KEY' given as its env name.
            self._api_key = ApiKey.pop_secret(cc, env='CEREBRAS_API_KEY')
            # Stored here; not read by `invoke` in this class.
            self._default_options: tv.TypedValues = DefaultOptions.pop(cc)

    async def invoke(self, request: ChatChoicesRequest) -> ChatChoicesResponse:
        """
        Convert `request` to the Cerebras wire format, POST it, and convert the JSON reply back to a choices response.
        """

        # Only `Tool` options are consumed from the request.
        with tv.TypedValues(*request.options).consume() as oc:
            cer_tools = [build_cer_request_tool(tool) for tool in oc.pop(Tool, [])]

        cer_request = pt.ChatCompletionRequest(
            messages=build_cer_request_messages(request.v),
            model=MODEL_NAMES.resolve(self._model_name.v),
            tools=cer_tools or None,  # an empty tool list is sent as None
        )

        raw_request = msh.marshal(cer_request)

        # TODO: headers:
        #  - CF-RAY
        #  - X-Amz-Cf-Id
        #  - X-delay-time

        request_headers = {
            http.consts.HEADER_CONTENT_TYPE: http.consts.CONTENT_TYPE_JSON,
            http.consts.HEADER_AUTH: http.consts.format_bearer_auth_header(check.not_none(self._api_key).reveal()),
        }
        request_body = json.dumps(raw_request).encode('utf-8')

        http_response = await http.async_request(
            'https://api.cerebras.ai/v1/chat/completions',
            headers=request_headers,
            data=request_body,
            client=self._http_client,
        )

        response_text = check.not_none(http_response.data).decode('utf-8')
        cer_response = msh.unmarshal(json.loads(response_text), pt.ChatCompletionResponse)

        return build_mc_choices_response(cer_response)
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""
|
|
2
|
+
https://inference-docs.cerebras.ai/models/overview
|
|
3
|
+
"""
|
|
4
|
+
from ....models.names import ModelNameCollection
|
|
5
|
+
from ...strings.manifests import BackendStringsManifest
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
##
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Model names known to the cerebras backend. Every alias maps to None here - presumably meaning the name is used
# as-is (TODO confirm against ModelNameCollection).
MODEL_NAMES = ModelNameCollection(
    default='gpt-oss-120b',
    aliases={
        'llama3.1-8b': None,
        'llama-3.3-70b': None,
        'gpt-oss-120b': None,
        'qwen-3-32b': None,
    },
)


# Registers the 'cerebras' backend string for both the plain and streaming chat-choices service types.
# @omlish-manifest
_BACKEND_STRINGS_MANIFEST = BackendStringsManifest(
    ['ChatChoicesService', 'ChatChoicesStreamService'],
    'cerebras',
    model_names=MODEL_NAMES,
)
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
|
|
3
|
+
from omlish import check
|
|
4
|
+
from omlish.formats import json
|
|
5
|
+
|
|
6
|
+
from .....backends.cerebras import protocol as pt
|
|
7
|
+
from ....chat.choices.services import ChatChoicesResponse
|
|
8
|
+
from ....chat.choices.stream.types import AiChoiceDeltas
|
|
9
|
+
from ....chat.choices.types import AiChoice
|
|
10
|
+
from ....chat.messages import AiMessage
|
|
11
|
+
from ....chat.messages import AnyAiMessage
|
|
12
|
+
from ....chat.messages import Chat
|
|
13
|
+
from ....chat.messages import SystemMessage
|
|
14
|
+
from ....chat.messages import ToolUseMessage
|
|
15
|
+
from ....chat.messages import ToolUseResultMessage
|
|
16
|
+
from ....chat.messages import UserMessage
|
|
17
|
+
from ....chat.stream.types import AiDelta
|
|
18
|
+
from ....chat.stream.types import ContentAiDelta
|
|
19
|
+
from ....chat.stream.types import ToolUseAiDelta
|
|
20
|
+
from ....chat.tools.types import Tool
|
|
21
|
+
from ....content.prepare import prepare_content_str
|
|
22
|
+
from ....tools.jsonschema import build_tool_spec_params_json_schema
|
|
23
|
+
from ....tools.types import ToolUse
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
##
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def build_cer_request_messages(chat: Chat) -> list[pt.ChatCompletionRequest.Message]:
|
|
30
|
+
cer_msgs: list[pt.ChatCompletionRequest.Message] = []
|
|
31
|
+
|
|
32
|
+
for _, g in itertools.groupby(chat, lambda mc_m: isinstance(mc_m, AnyAiMessage)):
|
|
33
|
+
mc_msgs = list(g)
|
|
34
|
+
|
|
35
|
+
if isinstance(mc_msgs[0], AnyAiMessage):
|
|
36
|
+
tups: list[tuple[AiMessage | None, list[ToolUseMessage]]] = []
|
|
37
|
+
for mc_msg in mc_msgs:
|
|
38
|
+
if isinstance(mc_msg, AiMessage):
|
|
39
|
+
tups.append((mc_msg, []))
|
|
40
|
+
|
|
41
|
+
elif isinstance(mc_msg, ToolUseMessage):
|
|
42
|
+
if not tups:
|
|
43
|
+
tups.append((None, []))
|
|
44
|
+
tups[-1][1].append(mc_msg)
|
|
45
|
+
|
|
46
|
+
else:
|
|
47
|
+
raise TypeError(mc_msg)
|
|
48
|
+
|
|
49
|
+
for mc_ai_msg, mc_tu_msgs in tups:
|
|
50
|
+
cer_msgs.append(pt.ChatCompletionRequest.AssistantMessage(
|
|
51
|
+
content=check.isinstance(mc_ai_msg.c, str) if mc_ai_msg is not None else None,
|
|
52
|
+
tool_calls=[
|
|
53
|
+
pt.ChatCompletionRequest.AssistantMessage.ToolCall(
|
|
54
|
+
function=pt.ChatCompletionRequest.AssistantMessage.ToolCall.Function(
|
|
55
|
+
name=mc_tu_msg.tu.name,
|
|
56
|
+
arguments=check.not_none(mc_tu_msg.tu.raw_args),
|
|
57
|
+
),
|
|
58
|
+
id=check.not_none(mc_tu_msg.tu.id),
|
|
59
|
+
)
|
|
60
|
+
for mc_tu_msg in mc_tu_msgs
|
|
61
|
+
] if mc_tu_msgs else None,
|
|
62
|
+
))
|
|
63
|
+
|
|
64
|
+
else:
|
|
65
|
+
for mc_msg in mc_msgs:
|
|
66
|
+
if isinstance(mc_msg, SystemMessage):
|
|
67
|
+
cer_msgs.append(pt.ChatCompletionRequest.SystemMessage(
|
|
68
|
+
content=check.isinstance(mc_msg.c, str),
|
|
69
|
+
))
|
|
70
|
+
|
|
71
|
+
elif isinstance(mc_msg, UserMessage):
|
|
72
|
+
cer_msgs.append(pt.ChatCompletionRequest.UserMessage(
|
|
73
|
+
content=check.isinstance(mc_msg.c, str),
|
|
74
|
+
))
|
|
75
|
+
|
|
76
|
+
elif isinstance(mc_msg, ToolUseResultMessage):
|
|
77
|
+
cer_msgs.append(pt.ChatCompletionRequest.ToolMessage(
|
|
78
|
+
tool_call_id=check.not_none(mc_msg.tur.id),
|
|
79
|
+
content=check.isinstance(mc_msg.tur.c, str),
|
|
80
|
+
))
|
|
81
|
+
|
|
82
|
+
else:
|
|
83
|
+
raise TypeError(mc_msg)
|
|
84
|
+
|
|
85
|
+
return cer_msgs
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def build_cer_request_tool(t: Tool) -> pt.ChatCompletionRequest.Tool:
    """
    Convert a tool into a Cerebras request tool: name (required), prepared description, and JSON-schema parameters,
    all taken from the tool's spec.
    """

    spec = t.spec

    function = pt.ChatCompletionRequest.Tool.Function(
        name=check.not_none(spec.name),
        description=prepare_content_str(spec.desc),
        parameters=build_tool_spec_params_json_schema(spec),
    )

    return pt.ChatCompletionRequest.Tool(function=function)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def build_mc_choices_response(cer_resp: pt.ChatCompletionResponse) -> ChatChoicesResponse:
    """
    Convert a Cerebras chat completion response into a choices response, one `AiChoice` per response choice.

    Within a choice, non-None content becomes an `AiMessage`, followed by one `ToolUseMessage` per tool call.
    """

    def build_choice(cer_choice: pt.ChatCompletionResponse.Choice) -> AiChoice:
        cer_msg = cer_choice.message

        msgs: list[AnyAiMessage] = []

        if (content := cer_msg.content) is not None:
            msgs.append(AiMessage(check.isinstance(content, str)))

        msgs.extend(
            ToolUseMessage(ToolUse(
                id=tool_call.id,
                name=tool_call.function.name,
                # Missing/empty arguments parse as an empty object; the raw string is kept alongside.
                args=json.loads(tool_call.function.arguments or '{}'),
                raw_args=tool_call.function.arguments,
            ))
            for tool_call in cer_msg.tool_calls or []
        )

        return AiChoice(msgs)

    return ChatChoicesResponse([build_choice(choice) for choice in cer_resp.choices])
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def build_mc_ai_choice_deltas(delta: pt.ChatCompletionChunk.Choice.Delta) -> AiChoiceDeltas:
    """
    Convert one streamed Cerebras chunk delta into AI choice deltas.

    A delta whose role is None or 'assistant' yields its content (if not None) and then its tool calls. Any other role
    yields no deltas when the channel is 'analysis' or 'commentary'.

    Raises:
        ValueError: For any other role/channel combination.
    """

    if delta.role not in (None, 'assistant'):
        if delta.channel in ('analysis', 'commentary'):
            return AiChoiceDeltas([])

        raise ValueError(delta)

    deltas: list[AiDelta] = []

    if delta.content is not None:
        deltas.append(ContentAiDelta(delta.content))

    for tool_call in delta.tool_calls or []:
        fn = check.not_none(tool_call.function)
        deltas.append(ToolUseAiDelta(
            id=tool_call.id,
            name=check.not_none(fn.name),
            # Missing/empty arguments parse as an empty object.
            args=json.loads(fn.arguments or '{}'),
        ))

    return AiChoiceDeltas(deltas)
|