ommlds 0.0.0.dev487__py3-none-any.whl → 0.0.0.dev489__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ommlds/.omlish-manifests.json +54 -0
- ommlds/__about__.py +1 -1
- ommlds/backends/cerebras/__init__.py +7 -0
- ommlds/backends/cerebras/_dataclasses.py +4254 -0
- ommlds/backends/cerebras/_marshal.py +24 -0
- ommlds/backends/cerebras/protocol.py +312 -0
- ommlds/cli/_dataclasses.py +272 -88
- ommlds/cli/main.py +29 -9
- ommlds/cli/secrets.py +1 -0
- ommlds/cli/sessions/chat/chat/user/configs.py +0 -1
- ommlds/cli/sessions/chat/chat/user/inject.py +0 -10
- ommlds/cli/sessions/chat/configs.py +2 -0
- ommlds/cli/sessions/chat/inject.py +3 -0
- ommlds/cli/sessions/chat/interface/__init__.py +0 -0
- ommlds/cli/sessions/chat/interface/bare/__init__.py +0 -0
- ommlds/cli/sessions/chat/interface/bare/inject.py +32 -0
- ommlds/cli/sessions/chat/interface/bare/interface.py +19 -0
- ommlds/cli/sessions/chat/{chat/user/interactive.py → interface/bare/user.py} +1 -1
- ommlds/cli/sessions/chat/interface/base.py +13 -0
- ommlds/cli/sessions/chat/interface/configs.py +15 -0
- ommlds/cli/sessions/chat/interface/inject.py +24 -0
- ommlds/cli/sessions/chat/interface/textual/__init__.py +0 -0
- ommlds/cli/sessions/chat/interface/textual/app.py +191 -0
- ommlds/cli/sessions/chat/interface/textual/inject.py +27 -0
- ommlds/cli/sessions/chat/interface/textual/interface.py +22 -0
- ommlds/cli/sessions/chat/interface/textual/user.py +20 -0
- ommlds/cli/sessions/chat/session.py +12 -4
- ommlds/minichain/backends/impls/cerebras/__init__.py +0 -0
- ommlds/minichain/backends/impls/cerebras/chat.py +80 -0
- ommlds/minichain/backends/impls/cerebras/names.py +30 -0
- ommlds/minichain/backends/impls/cerebras/protocol.py +143 -0
- ommlds/minichain/backends/impls/cerebras/stream.py +125 -0
- {ommlds-0.0.0.dev487.dist-info → ommlds-0.0.0.dev489.dist-info}/METADATA +6 -6
- {ommlds-0.0.0.dev487.dist-info → ommlds-0.0.0.dev489.dist-info}/RECORD +38 -17
- {ommlds-0.0.0.dev487.dist-info → ommlds-0.0.0.dev489.dist-info}/WHEEL +0 -0
- {ommlds-0.0.0.dev487.dist-info → ommlds-0.0.0.dev489.dist-info}/entry_points.txt +0 -0
- {ommlds-0.0.0.dev487.dist-info → ommlds-0.0.0.dev489.dist-info}/licenses/LICENSE +0 -0
- {ommlds-0.0.0.dev487.dist-info → ommlds-0.0.0.dev489.dist-info}/top_level.txt +0 -0
ommlds/cli/main.py
CHANGED
|
@@ -77,31 +77,47 @@ class ChatProfile(Profile):
|
|
|
77
77
|
|
|
78
78
|
#
|
|
79
79
|
|
|
80
|
-
|
|
81
|
-
ap.arg('
|
|
82
|
-
ap.arg('-
|
|
83
|
-
ap.arg('-e', '--editor', action='store_true', group='
|
|
80
|
+
INTERFACE_ARGS: ta.ClassVar[ta.Sequence[ap.Arg]] = [
|
|
81
|
+
ap.arg('-i', '--interactive', action='store_true', group='interface'),
|
|
82
|
+
ap.arg('-T', '--textual', action='store_true', group='interface'),
|
|
83
|
+
ap.arg('-e', '--editor', action='store_true', group='interface'),
|
|
84
84
|
]
|
|
85
85
|
|
|
86
|
-
def
|
|
86
|
+
def configure_interface(self, cfg: ChatConfig) -> ChatConfig:
|
|
87
87
|
if self._args.editor:
|
|
88
88
|
check.arg(not self._args.interactive)
|
|
89
89
|
check.arg(not self._args.message)
|
|
90
90
|
raise NotImplementedError
|
|
91
91
|
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
return dc.replace(
|
|
92
|
+
if self._args.interactive:
|
|
93
|
+
cfg = dc.replace(
|
|
95
94
|
cfg,
|
|
95
|
+
interface=dc.replace(
|
|
96
|
+
cfg.interface,
|
|
97
|
+
interactive=True,
|
|
98
|
+
use_textual=self._args.textual,
|
|
99
|
+
),
|
|
96
100
|
user=dc.replace(
|
|
97
101
|
cfg.user,
|
|
98
102
|
interactive=True,
|
|
99
103
|
),
|
|
100
104
|
)
|
|
101
105
|
|
|
106
|
+
return cfg
|
|
107
|
+
|
|
108
|
+
#
|
|
109
|
+
|
|
110
|
+
INPUT_ARGS: ta.ClassVar[ta.Sequence[ap.Arg]] = [
|
|
111
|
+
ap.arg('message', nargs='*', group='input'),
|
|
112
|
+
]
|
|
113
|
+
|
|
114
|
+
def configure_input(self, cfg: ChatConfig) -> ChatConfig:
|
|
115
|
+
if self._args.interactive:
|
|
116
|
+
check.arg(not self._args.message)
|
|
117
|
+
|
|
102
118
|
elif self._args.message:
|
|
103
119
|
# TODO: '-' -> stdin
|
|
104
|
-
|
|
120
|
+
cfg = dc.replace(
|
|
105
121
|
cfg,
|
|
106
122
|
user=dc.replace(
|
|
107
123
|
cfg.user,
|
|
@@ -112,6 +128,8 @@ class ChatProfile(Profile):
|
|
|
112
128
|
else:
|
|
113
129
|
raise ValueError('Must specify input')
|
|
114
130
|
|
|
131
|
+
return cfg
|
|
132
|
+
|
|
115
133
|
#
|
|
116
134
|
|
|
117
135
|
STATE_ARGS: ta.ClassVar[ta.Sequence[ap.Arg]] = [
|
|
@@ -220,6 +238,7 @@ class ChatProfile(Profile):
|
|
|
220
238
|
|
|
221
239
|
for grp_name, grp_args in [
|
|
222
240
|
('backend', self.BACKEND_ARGS),
|
|
241
|
+
('interface', self.INTERFACE_ARGS),
|
|
223
242
|
('input', self.INPUT_ARGS),
|
|
224
243
|
('state', self.STATE_ARGS),
|
|
225
244
|
('output', self.OUTPUT_ARGS),
|
|
@@ -234,6 +253,7 @@ class ChatProfile(Profile):
|
|
|
234
253
|
|
|
235
254
|
cfg = ChatConfig()
|
|
236
255
|
cfg = self.configure_backend(cfg)
|
|
256
|
+
cfg = self.configure_interface(cfg)
|
|
237
257
|
cfg = self.configure_input(cfg)
|
|
238
258
|
cfg = self.configure_state(cfg)
|
|
239
259
|
cfg = self.configure_output(cfg)
|
ommlds/cli/secrets.py
CHANGED
|
@@ -9,10 +9,7 @@ from .configs import UserConfig
|
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
with lang.auto_proxy_import(globals()):
|
|
12
|
-
from .....inputs import asyncs as _inputs_asyncs
|
|
13
|
-
from .....inputs import sync as _inputs_sync
|
|
14
12
|
from ..state import types as _state
|
|
15
|
-
from . import interactive as _interactive
|
|
16
13
|
from . import oneshot as _oneshot
|
|
17
14
|
from . import types as _types
|
|
18
15
|
|
|
@@ -43,13 +40,6 @@ def bind_user(cfg: UserConfig = UserConfig()) -> inj.Elements:
|
|
|
43
40
|
|
|
44
41
|
raise NotImplementedError
|
|
45
42
|
|
|
46
|
-
els.append(inj.bind(_types.UserChatInput, to_ctor=_interactive.InteractiveUserChatInput, singleton=True))
|
|
47
|
-
|
|
48
|
-
els.extend([
|
|
49
|
-
inj.bind(_inputs_sync.SyncStringInput, to_const=_inputs_sync.InputSyncStringInput(use_readline=cfg.use_readline)), # noqa
|
|
50
|
-
inj.bind(_inputs_asyncs.AsyncStringInput, to_ctor=_inputs_asyncs.ThreadAsyncStringInput, singleton=True),
|
|
51
|
-
])
|
|
52
|
-
|
|
53
43
|
else:
|
|
54
44
|
if cfg.initial_user_content is None:
|
|
55
45
|
raise ValueError('Initial user content is required for non-interactive chat')
|
|
@@ -5,6 +5,7 @@ from ...rendering.configs import RenderingConfig
|
|
|
5
5
|
from .chat.ai.configs import AiConfig
|
|
6
6
|
from .chat.state.configs import StateConfig
|
|
7
7
|
from .chat.user.configs import UserConfig
|
|
8
|
+
from .interface.configs import InterfaceConfig
|
|
8
9
|
from .tools.configs import ToolsConfig
|
|
9
10
|
|
|
10
11
|
|
|
@@ -24,4 +25,5 @@ class ChatConfig:
|
|
|
24
25
|
state: StateConfig = StateConfig()
|
|
25
26
|
user: UserConfig = UserConfig()
|
|
26
27
|
rendering: RenderingConfig = RenderingConfig()
|
|
28
|
+
interface: InterfaceConfig = InterfaceConfig()
|
|
27
29
|
tools: ToolsConfig = ToolsConfig()
|
|
@@ -16,6 +16,7 @@ with lang.auto_proxy_import(globals()):
|
|
|
16
16
|
from .chat.ai import inject as _chat_ai
|
|
17
17
|
from .chat.state import inject as _chat_state
|
|
18
18
|
from .chat.user import inject as _chat_user
|
|
19
|
+
from .interface import inject as _interface
|
|
19
20
|
from .phases import inject as _phases
|
|
20
21
|
from .tools import inject as _tools
|
|
21
22
|
|
|
@@ -37,6 +38,8 @@ def bind_chat(cfg: ChatConfig) -> inj.Elements:
|
|
|
37
38
|
|
|
38
39
|
_chat_state.bind_state(cfg.state),
|
|
39
40
|
|
|
41
|
+
_interface.bind_interface(cfg.interface),
|
|
42
|
+
|
|
40
43
|
_phases.bind_phases(),
|
|
41
44
|
|
|
42
45
|
_rendering.bind_rendering(cfg.rendering),
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from omlish import inject as inj
|
|
2
|
+
from omlish import lang
|
|
3
|
+
|
|
4
|
+
from ..base import ChatInterface
|
|
5
|
+
from ..configs import InterfaceConfig
|
|
6
|
+
from .interface import BareChatInterface
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
with lang.auto_proxy_import(globals()):
|
|
10
|
+
from .....inputs import asyncs as _inputs_asyncs
|
|
11
|
+
from .....inputs import sync as _inputs_sync
|
|
12
|
+
from ...chat.user import types as _user_types
|
|
13
|
+
from . import user as _user
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
##
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def bind_bare(cfg: InterfaceConfig = InterfaceConfig()) -> inj.Elements:
|
|
20
|
+
els: list[inj.Elemental] = [
|
|
21
|
+
inj.bind(ChatInterface, to_ctor=BareChatInterface, singleton=True),
|
|
22
|
+
]
|
|
23
|
+
|
|
24
|
+
if cfg.interactive:
|
|
25
|
+
els.append(inj.bind(_user_types.UserChatInput, to_ctor=_user.InteractiveUserChatInput, singleton=True))
|
|
26
|
+
|
|
27
|
+
els.extend([
|
|
28
|
+
inj.bind(_inputs_sync.SyncStringInput, to_const=_inputs_sync.InputSyncStringInput(use_readline=cfg.use_readline)), # noqa
|
|
29
|
+
inj.bind(_inputs_asyncs.AsyncStringInput, to_ctor=_inputs_asyncs.ThreadAsyncStringInput, singleton=True),
|
|
30
|
+
])
|
|
31
|
+
|
|
32
|
+
return inj.as_elements(*els)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from ...driver import ChatDriver
|
|
2
|
+
from ..base import ChatInterface
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
##
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BareChatInterface(ChatInterface):
|
|
9
|
+
def __init__(
|
|
10
|
+
self,
|
|
11
|
+
*,
|
|
12
|
+
driver: ChatDriver,
|
|
13
|
+
) -> None:
|
|
14
|
+
super().__init__()
|
|
15
|
+
|
|
16
|
+
self._driver = driver
|
|
17
|
+
|
|
18
|
+
async def run(self) -> None:
|
|
19
|
+
await self._driver.run()
|
|
@@ -4,7 +4,7 @@ from ...... import minichain as mc
|
|
|
4
4
|
from .....inputs.asyncs import AsyncStringInput
|
|
5
5
|
from .....inputs.asyncs import SyncAsyncStringInput
|
|
6
6
|
from .....inputs.sync import InputSyncStringInput
|
|
7
|
-
from .types import UserChatInput
|
|
7
|
+
from ...chat.user.types import UserChatInput
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
##
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import typing as ta
|
|
2
|
+
|
|
3
|
+
from omlish import dataclasses as dc
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
##
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dc.dataclass(frozen=True, kw_only=True)
|
|
10
|
+
class InterfaceConfig:
|
|
11
|
+
interactive: bool = False
|
|
12
|
+
|
|
13
|
+
use_textual: bool = False
|
|
14
|
+
|
|
15
|
+
use_readline: bool | ta.Literal['auto'] = 'auto'
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from omlish import inject as inj
|
|
2
|
+
from omlish import lang
|
|
3
|
+
|
|
4
|
+
from .configs import InterfaceConfig
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
with lang.auto_proxy_import(globals()):
|
|
8
|
+
from .bare import inject as _bare
|
|
9
|
+
from .textual import inject as _textual
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
##
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def bind_interface(cfg: InterfaceConfig = InterfaceConfig()) -> inj.Elements:
|
|
16
|
+
els: list[inj.Elemental] = []
|
|
17
|
+
|
|
18
|
+
if cfg.use_textual:
|
|
19
|
+
els.append(_textual.bind_textual())
|
|
20
|
+
|
|
21
|
+
else:
|
|
22
|
+
els.append(_bare.bind_bare(cfg))
|
|
23
|
+
|
|
24
|
+
return inj.as_elements(*els)
|
|
File without changes
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import dataclasses as dc
|
|
3
|
+
import typing as ta
|
|
4
|
+
|
|
5
|
+
from omdev.tui import textual as tx
|
|
6
|
+
from omlish import check
|
|
7
|
+
|
|
8
|
+
from ...... import minichain as mc
|
|
9
|
+
from ...driver import ChatDriver
|
|
10
|
+
from .user import QueueUserChatInput
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
##
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class UserMessage(tx.Static):
|
|
17
|
+
pass
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class AiMessage(tx.Static):
|
|
21
|
+
pass
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
##
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class InputTextArea(tx.TextArea):
|
|
28
|
+
@dc.dataclass()
|
|
29
|
+
class Submitted(tx.Message):
|
|
30
|
+
text: str
|
|
31
|
+
|
|
32
|
+
def __init__(self, **kwargs: ta.Any) -> None:
|
|
33
|
+
super().__init__(**kwargs)
|
|
34
|
+
|
|
35
|
+
async def _on_key(self, event: tx.Key) -> None:
|
|
36
|
+
if event.key == 'enter':
|
|
37
|
+
event.prevent_default()
|
|
38
|
+
event.stop()
|
|
39
|
+
|
|
40
|
+
if text := self.text.strip():
|
|
41
|
+
self.post_message(self.Submitted(text))
|
|
42
|
+
|
|
43
|
+
else:
|
|
44
|
+
await super()._on_key(event)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
##
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class ChatApp(tx.App):
|
|
51
|
+
def __init__(
|
|
52
|
+
self,
|
|
53
|
+
*,
|
|
54
|
+
chat_driver: ChatDriver,
|
|
55
|
+
queue_user_chat_input: QueueUserChatInput,
|
|
56
|
+
) -> None:
|
|
57
|
+
super().__init__()
|
|
58
|
+
|
|
59
|
+
self._chat_driver = chat_driver
|
|
60
|
+
self._queue_user_chat_input = queue_user_chat_input
|
|
61
|
+
|
|
62
|
+
CSS: ta.ClassVar[str] = """
|
|
63
|
+
#messages-scroll {
|
|
64
|
+
width: 100%;
|
|
65
|
+
height: 1fr;
|
|
66
|
+
|
|
67
|
+
padding: 0 2 0 2;
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
#messages-container {
|
|
71
|
+
height: auto;
|
|
72
|
+
width: 100%;
|
|
73
|
+
|
|
74
|
+
margin-top: 1;
|
|
75
|
+
margin-bottom: 0;
|
|
76
|
+
|
|
77
|
+
layout: stream;
|
|
78
|
+
text-align: left;
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
#input-outer {
|
|
82
|
+
width: 100%;
|
|
83
|
+
height: auto;
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
#input-vertical {
|
|
87
|
+
width: 100%;
|
|
88
|
+
height: auto;
|
|
89
|
+
|
|
90
|
+
margin: 0 2 1 2;
|
|
91
|
+
|
|
92
|
+
padding: 0;
|
|
93
|
+
}
|
|
94
|
+
|
|
95
|
+
#input-vertical2 {
|
|
96
|
+
width: 100%;
|
|
97
|
+
height: auto;
|
|
98
|
+
|
|
99
|
+
border: round $foreground-muted;
|
|
100
|
+
|
|
101
|
+
padding: 0 1;
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
#input-horizontal {
|
|
105
|
+
width: 100%;
|
|
106
|
+
height: auto;
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
#input-glyph {
|
|
110
|
+
width: auto;
|
|
111
|
+
|
|
112
|
+
padding: 0 1 0 0;
|
|
113
|
+
|
|
114
|
+
background: transparent;
|
|
115
|
+
color: $primary;
|
|
116
|
+
|
|
117
|
+
text-style: bold;
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
#input {
|
|
121
|
+
width: 1fr;
|
|
122
|
+
height: auto;
|
|
123
|
+
max-height: 16;
|
|
124
|
+
|
|
125
|
+
border: none;
|
|
126
|
+
|
|
127
|
+
padding: 0;
|
|
128
|
+
|
|
129
|
+
background: transparent;
|
|
130
|
+
color: $text;
|
|
131
|
+
}
|
|
132
|
+
"""
|
|
133
|
+
|
|
134
|
+
ENABLE_COMMAND_PALETTE: ta.ClassVar[bool] = False
|
|
135
|
+
|
|
136
|
+
#
|
|
137
|
+
|
|
138
|
+
def compose(self) -> tx.ComposeResult:
|
|
139
|
+
with tx.VerticalScroll(id='messages-scroll'):
|
|
140
|
+
yield tx.Static(id='messages-container')
|
|
141
|
+
|
|
142
|
+
with tx.Static(id='input-outer'):
|
|
143
|
+
with tx.Vertical(id='input-vertical'):
|
|
144
|
+
with tx.Vertical(id='input-vertical2'):
|
|
145
|
+
with tx.Horizontal(id='input-horizontal'):
|
|
146
|
+
yield tx.Static('>', id='input-glyph')
|
|
147
|
+
yield InputTextArea(placeholder='...', id='input')
|
|
148
|
+
|
|
149
|
+
#
|
|
150
|
+
|
|
151
|
+
def _get_input_text_area(self) -> InputTextArea:
|
|
152
|
+
return self.query_one('#input', InputTextArea)
|
|
153
|
+
|
|
154
|
+
def _get_messages_container(self) -> tx.Static:
|
|
155
|
+
return self.query_one('#messages-container', tx.Static)
|
|
156
|
+
|
|
157
|
+
#
|
|
158
|
+
|
|
159
|
+
async def _mount_message(self, *messages: tx.Widget) -> None:
|
|
160
|
+
msg_ctr = self._get_messages_container()
|
|
161
|
+
|
|
162
|
+
for msg in messages:
|
|
163
|
+
await msg_ctr.mount(msg)
|
|
164
|
+
|
|
165
|
+
self.call_after_refresh(lambda: msg_ctr.scroll_end(animate=False))
|
|
166
|
+
|
|
167
|
+
#
|
|
168
|
+
|
|
169
|
+
_chat_driver_task: asyncio.Task | None = None
|
|
170
|
+
|
|
171
|
+
async def on_mount(self) -> None:
|
|
172
|
+
check.none(self._chat_driver_task)
|
|
173
|
+
self._chat_driver_task = asyncio.create_task(self._chat_driver.run())
|
|
174
|
+
|
|
175
|
+
self._get_input_text_area().focus()
|
|
176
|
+
|
|
177
|
+
await self._mount_message(UserMessage('Hello!'))
|
|
178
|
+
|
|
179
|
+
async def on_unmount(self) -> None:
|
|
180
|
+
await self._queue_user_chat_input.push_next_user_messages([])
|
|
181
|
+
await check.not_none(self._chat_driver_task)
|
|
182
|
+
|
|
183
|
+
async def on_input_text_area_submitted(self, event: InputTextArea.Submitted) -> None:
|
|
184
|
+
self._get_input_text_area().clear()
|
|
185
|
+
|
|
186
|
+
await self._mount_message(
|
|
187
|
+
UserMessage(event.text),
|
|
188
|
+
# AiMessage(f'You said: {event.text}!'),
|
|
189
|
+
)
|
|
190
|
+
|
|
191
|
+
await self._queue_user_chat_input.push_next_user_messages([mc.UserMessage(event.text)])
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from omlish import inject as inj
|
|
2
|
+
|
|
3
|
+
from ...chat.user.types import UserChatInput
|
|
4
|
+
from ..base import ChatInterface
|
|
5
|
+
from .app import ChatApp
|
|
6
|
+
from .interface import TextualChatInterface
|
|
7
|
+
from .user import QueueUserChatInput
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
##
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def bind_textual() -> inj.Elements:
|
|
14
|
+
els: list[inj.Elemental] = [
|
|
15
|
+
inj.bind(ChatInterface, to_ctor=TextualChatInterface, singleton=True),
|
|
16
|
+
]
|
|
17
|
+
|
|
18
|
+
els.extend([
|
|
19
|
+
inj.bind(ChatApp, singleton=True),
|
|
20
|
+
])
|
|
21
|
+
|
|
22
|
+
els.extend([
|
|
23
|
+
inj.bind(QueueUserChatInput, singleton=True),
|
|
24
|
+
inj.bind(UserChatInput, to_key=QueueUserChatInput),
|
|
25
|
+
])
|
|
26
|
+
|
|
27
|
+
return inj.as_elements(*els)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from ...driver import ChatDriver
|
|
2
|
+
from ..base import ChatInterface
|
|
3
|
+
from .app import ChatApp
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
##
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TextualChatInterface(ChatInterface):
|
|
10
|
+
def __init__(
|
|
11
|
+
self,
|
|
12
|
+
*,
|
|
13
|
+
driver: ChatDriver,
|
|
14
|
+
app: ChatApp,
|
|
15
|
+
) -> None:
|
|
16
|
+
super().__init__()
|
|
17
|
+
|
|
18
|
+
self._driver = driver
|
|
19
|
+
self._app = app
|
|
20
|
+
|
|
21
|
+
async def run(self) -> None:
|
|
22
|
+
await self._app.run_async()
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
|
|
3
|
+
from ...... import minichain as mc
|
|
4
|
+
from ...chat.user.types import UserChatInput
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
##
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class QueueUserChatInput(UserChatInput):
|
|
11
|
+
def __init__(self) -> None:
|
|
12
|
+
super().__init__()
|
|
13
|
+
|
|
14
|
+
self._queue: asyncio.Queue[mc.UserChat] = asyncio.Queue()
|
|
15
|
+
|
|
16
|
+
async def push_next_user_messages(self, chat: 'mc.UserChat') -> None:
|
|
17
|
+
await self._queue.put(chat)
|
|
18
|
+
|
|
19
|
+
async def get_next_user_messages(self) -> 'mc.UserChat':
|
|
20
|
+
return await self._queue.get()
|
|
@@ -1,14 +1,22 @@
|
|
|
1
|
+
import typing as ta
|
|
2
|
+
|
|
1
3
|
from omlish import dataclasses as dc
|
|
2
4
|
|
|
3
5
|
from ..base import Session
|
|
4
6
|
from .configs import ChatConfig
|
|
5
|
-
from .
|
|
7
|
+
from .interface.base import ChatInterface
|
|
6
8
|
|
|
7
9
|
|
|
8
10
|
##
|
|
9
11
|
|
|
10
12
|
|
|
13
|
+
@ta.final
|
|
11
14
|
class ChatSession(Session['ChatSession.Config']):
|
|
15
|
+
"""
|
|
16
|
+
An adapter to the lower level, dumber, non-chat-specific cli 'session' layer. Nothing else takes the kitchen-sink
|
|
17
|
+
'ChatConfig' object, it's only here for type dispatch in lower layers.
|
|
18
|
+
"""
|
|
19
|
+
|
|
12
20
|
@dc.dataclass(frozen=True)
|
|
13
21
|
class Config(Session.Config, ChatConfig):
|
|
14
22
|
pass
|
|
@@ -17,11 +25,11 @@ class ChatSession(Session['ChatSession.Config']):
|
|
|
17
25
|
self,
|
|
18
26
|
config: Config,
|
|
19
27
|
*,
|
|
20
|
-
|
|
28
|
+
interface: ChatInterface,
|
|
21
29
|
) -> None:
|
|
22
30
|
super().__init__(config)
|
|
23
31
|
|
|
24
|
-
self.
|
|
32
|
+
self._interface = interface
|
|
25
33
|
|
|
26
34
|
async def run(self) -> None:
|
|
27
|
-
await self.
|
|
35
|
+
await self._interface.run()
|
|
File without changes
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
import typing as ta
|
|
2
|
+
|
|
3
|
+
from omlish import check
|
|
4
|
+
from omlish import marshal as msh
|
|
5
|
+
from omlish import typedvalues as tv
|
|
6
|
+
from omlish.formats import json
|
|
7
|
+
from omlish.http import all as http
|
|
8
|
+
|
|
9
|
+
from .....backends.cerebras import protocol as pt
|
|
10
|
+
from ....chat.choices.services import ChatChoicesRequest
|
|
11
|
+
from ....chat.choices.services import ChatChoicesResponse
|
|
12
|
+
from ....chat.choices.services import static_check_is_chat_choices_service
|
|
13
|
+
from ....chat.tools.types import Tool
|
|
14
|
+
from ....models.configs import ModelName
|
|
15
|
+
from ....standard import ApiKey
|
|
16
|
+
from ....standard import DefaultOptions
|
|
17
|
+
from .names import MODEL_NAMES
|
|
18
|
+
from .protocol import build_cer_request_messages
|
|
19
|
+
from .protocol import build_cer_request_tool
|
|
20
|
+
from .protocol import build_mc_choices_response
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
##
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# @omlish-manifest $.minichain.registries.manifests.RegistryManifest(
|
|
27
|
+
# name='cerebras',
|
|
28
|
+
# type='ChatChoicesService',
|
|
29
|
+
# )
|
|
30
|
+
@static_check_is_chat_choices_service
|
|
31
|
+
class CerebrasChatChoicesService:
|
|
32
|
+
DEFAULT_MODEL_NAME: ta.ClassVar[ModelName] = ModelName(check.not_none(MODEL_NAMES.default))
|
|
33
|
+
|
|
34
|
+
def __init__(
|
|
35
|
+
self,
|
|
36
|
+
*configs: ApiKey | ModelName | DefaultOptions,
|
|
37
|
+
http_client: http.AsyncHttpClient | None = None,
|
|
38
|
+
) -> None:
|
|
39
|
+
super().__init__()
|
|
40
|
+
|
|
41
|
+
self._http_client = http_client
|
|
42
|
+
|
|
43
|
+
with tv.consume(*configs) as cc:
|
|
44
|
+
self._model_name = cc.pop(self.DEFAULT_MODEL_NAME)
|
|
45
|
+
self._api_key = ApiKey.pop_secret(cc, env='CEREBRAS_API_KEY')
|
|
46
|
+
self._default_options: tv.TypedValues = DefaultOptions.pop(cc)
|
|
47
|
+
|
|
48
|
+
async def invoke(self, request: ChatChoicesRequest) -> ChatChoicesResponse:
|
|
49
|
+
tools: list[pt.ChatCompletionRequest.Tool] = []
|
|
50
|
+
with tv.TypedValues(*request.options).consume() as oc:
|
|
51
|
+
t: Tool
|
|
52
|
+
for t in oc.pop(Tool, []):
|
|
53
|
+
tools.append(build_cer_request_tool(t))
|
|
54
|
+
|
|
55
|
+
cer_request = pt.ChatCompletionRequest(
|
|
56
|
+
messages=build_cer_request_messages(request.v),
|
|
57
|
+
model=MODEL_NAMES.resolve(self._model_name.v),
|
|
58
|
+
tools=tools or None,
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
raw_request = msh.marshal(cer_request)
|
|
62
|
+
|
|
63
|
+
# TODO: headers:
|
|
64
|
+
# - CF-RAY
|
|
65
|
+
# - X-Amz-Cf-Id
|
|
66
|
+
# - X-delay-time
|
|
67
|
+
|
|
68
|
+
http_response = await http.async_request(
|
|
69
|
+
'https://api.cerebras.ai/v1/chat/completions',
|
|
70
|
+
headers={
|
|
71
|
+
http.consts.HEADER_CONTENT_TYPE: http.consts.CONTENT_TYPE_JSON,
|
|
72
|
+
http.consts.HEADER_AUTH: http.consts.format_bearer_auth_header(check.not_none(self._api_key).reveal()),
|
|
73
|
+
},
|
|
74
|
+
data=json.dumps(raw_request).encode('utf-8'),
|
|
75
|
+
client=self._http_client,
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
raw_response = json.loads(check.not_none(http_response.data).decode('utf-8'))
|
|
79
|
+
|
|
80
|
+
return build_mc_choices_response(msh.unmarshal(raw_response, pt.ChatCompletionResponse))
|