ommlds 0.0.0.dev462__py3-none-any.whl → 0.0.0.dev463__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of ommlds might be problematic.
- ommlds/backends/mlx/loading.py +58 -1
- ommlds/cli/main.py +18 -2
- ommlds/cli/sessions/chat/state.py +3 -3
- ommlds/cli/sessions/chat2/__init__.py +0 -0
- ommlds/cli/sessions/chat2/_inject.py +105 -0
- ommlds/cli/sessions/chat2/backends/__init__.py +0 -0
- ommlds/cli/sessions/chat2/backends/catalog.py +56 -0
- ommlds/cli/sessions/chat2/backends/types.py +36 -0
- ommlds/cli/sessions/chat2/chat/__init__.py +0 -0
- ommlds/cli/sessions/chat2/chat/ai/__init__.py +0 -0
- ommlds/cli/sessions/chat2/chat/ai/rendering.py +67 -0
- ommlds/cli/sessions/chat2/chat/ai/services.py +70 -0
- ommlds/cli/sessions/chat2/chat/ai/types.py +28 -0
- ommlds/cli/sessions/chat2/chat/state/__init__.py +0 -0
- ommlds/cli/sessions/chat2/chat/state/inmemory.py +34 -0
- ommlds/cli/sessions/chat2/chat/state/storage.py +53 -0
- ommlds/cli/sessions/chat2/chat/state/types.py +38 -0
- ommlds/cli/sessions/chat2/chat/user/__init__.py +0 -0
- ommlds/cli/sessions/chat2/chat/user/interactive.py +29 -0
- ommlds/cli/sessions/chat2/chat/user/oneshot.py +25 -0
- ommlds/cli/sessions/chat2/chat/user/types.py +15 -0
- ommlds/cli/sessions/chat2/configs.py +33 -0
- ommlds/cli/sessions/chat2/content/__init__.py +0 -0
- ommlds/cli/sessions/chat2/content/messages.py +30 -0
- ommlds/cli/sessions/chat2/content/strings.py +42 -0
- ommlds/cli/sessions/chat2/driver.py +43 -0
- ommlds/cli/sessions/chat2/inject.py +143 -0
- ommlds/cli/sessions/chat2/phases.py +55 -0
- ommlds/cli/sessions/chat2/rendering/__init__.py +0 -0
- ommlds/cli/sessions/chat2/rendering/markdown.py +52 -0
- ommlds/cli/sessions/chat2/rendering/raw.py +73 -0
- ommlds/cli/sessions/chat2/rendering/types.py +21 -0
- ommlds/cli/sessions/chat2/session.py +27 -0
- ommlds/cli/sessions/chat2/tools/__init__.py +0 -0
- ommlds/cli/sessions/chat2/tools/confirmation.py +46 -0
- ommlds/cli/sessions/chat2/tools/execution.py +53 -0
- ommlds/cli/sessions/inject.py +6 -1
- ommlds/cli/state.py +40 -23
- {ommlds-0.0.0.dev462.dist-info → ommlds-0.0.0.dev463.dist-info}/METADATA +3 -3
- {ommlds-0.0.0.dev462.dist-info → ommlds-0.0.0.dev463.dist-info}/RECORD +44 -11
- {ommlds-0.0.0.dev462.dist-info → ommlds-0.0.0.dev463.dist-info}/WHEEL +0 -0
- {ommlds-0.0.0.dev462.dist-info → ommlds-0.0.0.dev463.dist-info}/entry_points.txt +0 -0
- {ommlds-0.0.0.dev462.dist-info → ommlds-0.0.0.dev463.dist-info}/licenses/LICENSE +0 -0
- {ommlds-0.0.0.dev462.dist-info → ommlds-0.0.0.dev463.dist-info}/top_level.txt +0 -0
ommlds/backends/mlx/loading.py
CHANGED

@@ -15,6 +15,63 @@ from .tokenization import load_tokenization
 ##


+def get_model_path(
+        path_or_hf_repo: str,
+        revision: str | None = None,
+) -> tuple[pathlib.Path, str | None]:
+    """
+    Ensures the model is available locally. If the path does not exist locally,
+    it is downloaded from the Hugging Face Hub.
+
+    Args:
+        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
+        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
+
+    Returns:
+        Tuple[Path, str]: A tuple containing the local file path and the Hugging Face repo ID.
+    """
+
+    model_path = pathlib.Path(path_or_hf_repo)
+
+    if not model_path.exists():
+        from huggingface_hub import snapshot_download
+        hf_path = path_or_hf_repo
+        model_path = pathlib.Path(
+            snapshot_download(
+                path_or_hf_repo,
+                revision=revision,
+                allow_patterns=[
+                    '*.jinja',
+                    '*.json',
+                    '*.jsonl',
+                    '*.py',
+                    '*.txt',
+
+                    'model*.safetensors',
+
+                    '*.tiktoken',
+                    'tiktoken.model',
+                    'tokenizer.model',
+                ],
+            ),
+        )
+
+    else:
+        from huggingface_hub import ModelCard
+
+        card_path = model_path / 'README.md'
+        if card_path.is_file():
+            card = ModelCard.load(card_path)
+            hf_path = card.data.base_model
+        else:
+            hf_path = None
+
+    return model_path, hf_path
+
+
+##
+
+
 @dc.dataclass(frozen=True, kw_only=True)
 class LoadedModel:
     path: pathlib.Path

@@ -46,7 +103,7 @@ def load_model(
 ) -> LoadedModel:
     # FIXME: get_model_path return annotation is wrong:
     # https://github.com/ml-explore/mlx-lm/blob/9ee2b7358f5e258af7b31a8561acfbbe56ad5085/mlx_lm/utils.py#L82
-    model_path_res = ta.cast(ta.Any,
+    model_path_res = ta.cast(ta.Any, get_model_path(path_or_hf_repo))
     if isinstance(model_path_res, tuple):
         model_path = check.isinstance(model_path_res[0], pathlib.Path)
     else:

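For orientation, a minimal usage sketch of the new helper; the repo ID below is hypothetical, and the download path requires huggingface_hub and network access:

    import pathlib

    from ommlds.backends.mlx.loading import get_model_path

    # Resolves a local directory, downloading from the Hugging Face Hub
    # when the given path does not already exist locally.
    local_path, hf_repo = get_model_path('mlx-community/some-model-4bit')
    assert isinstance(local_path, pathlib.Path)
    print(local_path, hf_repo)
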
ommlds/cli/main.py
CHANGED

@@ -27,6 +27,7 @@ from .sessions.base import Session
 from .sessions.chat.code import CodeChatSession
 from .sessions.chat.interactive import InteractiveChatSession
 from .sessions.chat.prompt import PromptChatSession
+from .sessions.chat2.session import Chat2Session
 from .sessions.completion.completion import CompletionSession
 from .sessions.embedding.embedding import EmbeddingSession
 from .tools.config import ToolsConfig

@@ -66,6 +67,8 @@ async def _a_main(args: ta.Any = None) -> None:
     parser.add_argument('-E', '--embed', action='store_true')
     parser.add_argument('-j', '--image', action='store_true')

+    parser.add_argument('-2', '--two', action='store_true')
+
     parser.add_argument('--enable-fs-tools', action='store_true')
     parser.add_argument('--enable-todo-tools', action='store_true')
     parser.add_argument('--enable-unsafe-tools-do-not-use-lol', action='store_true')

@@ -128,7 +131,16 @@ async def _a_main(args: ta.Any = None) -> None:

     session_cfg: Session.Config

-    if args.interactive:
+    if args.two:
+        session_cfg = Chat2Session.Config(
+            backend=args.backend,
+            model_name=args.model_name,
+            state='ephemeral',
+            initial_content=content,  # noqa
+            # dangerous_no_tool_confirmation=bool(args.dangerous_no_tool_confirmation),
+        )
+
+    elif args.interactive:
         session_cfg = InteractiveChatSession.Config(
             backend=args.backend,
             model_name=args.model_name,

@@ -183,7 +195,11 @@ async def _a_main(args: ta.Any = None) -> None:
     with inj.create_managed_injector(bind_main(
         session_cfg=session_cfg,
         tools_config=tools_config,
-        enable_backend_strings=isinstance(session_cfg, (
+        enable_backend_strings=isinstance(session_cfg, (
+            Chat2Session.Config,
+            CodeChatSession.Config,
+            PromptChatSession.Config,
+        )),
     )) as injector:
         await injector[Session].run()

ommlds/cli/sessions/chat/state.py
CHANGED

@@ -86,7 +86,7 @@ class StateStorageChatStateManager(ChatStateManager):
     def get_state(self) -> ChatState:
         if self._state is not None:
             return self._state
-        state: ChatState | None = self._storage.load_state(self._key, ChatState)
+        state: ChatState | None = lang.sync_await(self._storage.load_state(self._key, ChatState))
         if state is None:
             state = ChatState()
         self._state = state

@@ -94,7 +94,7 @@ class StateStorageChatStateManager(ChatStateManager):

     def clear_state(self) -> ChatState:
         state = ChatState()
-        self._storage.save_state(self._key, state, ChatState)
+        lang.sync_await(self._storage.save_state(self._key, state, ChatState))
         self._state = state
         return state

@@ -105,6 +105,6 @@ class StateStorageChatStateManager(ChatStateManager):
             chat=[*state.chat, *chat_additions],
             updated_at=lang.utcnow(),
         )
-        self._storage.save_state(self._key, state, ChatState)
+        lang.sync_await(self._storage.save_state(self._key, state, ChatState))
         self._state = state
         return state

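The three changes above adapt the synchronous ChatStateManager methods to a StateStorage whose load_state/save_state have become awaitable, bridging with omlish's lang.sync_await. A minimal sketch of that bridge, assuming sync_await simply drives an awaitable to completion from non-async code:

    from omlish import lang


    async def load() -> str:
        return 'loaded'


    # Runs the coroutine to completion from a synchronous call site, as
    # StateStorageChatStateManager.get_state() now does.
    print(lang.sync_await(load()))  # 'loaded'
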
ommlds/cli/sessions/chat2/__init__.py
ADDED

File without changes

ommlds/cli/sessions/chat2/_inject.py
ADDED

@@ -0,0 +1,105 @@
+from omlish import lang
+
+
+with lang.auto_proxy_init(globals()):
+    from .backends.catalog import (  # noqa
+        CatalogChatChoicesServiceBackendProvider,
+        CatalogChatChoicesStreamServiceBackendProvider,
+    )
+
+    from .backends.types import (  # noqa
+        BackendName,
+        BackendConfigs,
+        BackendProvider,
+        ChatChoicesServiceBackendProvider,
+        ChatChoicesStreamServiceBackendProvider,
+    )
+
+    from .chat.ai.rendering import (  # noqa
+        RenderingAiChatGenerator,
+        RenderingStreamAiChatGenerator,
+    )
+
+    from .chat.ai.services import (  # noqa
+        ChatChoicesServiceOptions,
+        ChatChoicesServiceAiChatGenerator,
+        ChatChoicesStreamServiceStreamAiChatGenerator,
+    )
+
+    from .chat.ai.types import (  # noqa
+        AiChatGenerator,
+        StreamAiChatGenerator,
+    )
+
+    from .chat.state.inmemory import (  # noqa
+        InMemoryChatStateManager,
+    )
+
+    from .chat.state.storage import (  # noqa
+        StateStorageChatStateManager,
+    )
+
+    from .chat.state.types import (  # noqa
+        ChatState,
+        ChatStateManager,
+    )
+
+    from .chat.user.interactive import (  # noqa
+        InteractiveUserChatInput,
+    )
+
+    from .chat.user.oneshot import (  # noqa
+        OneshotUserChatInputInitialChat,
+        OneshotUserChatInput,
+    )
+
+    from .chat.user.types import (  # noqa
+        UserChatInput,
+    )
+
+    from .content.messages import (  # noqa
+        MessageContentExtractor,
+        MessageContentExtractorImpl,
+    )
+
+    from .content.strings import (  # noqa
+        ContentStringifier,
+        ContentStringifierImpl,
+    )
+
+    from .rendering.markdown import (  # noqa
+        MarkdownContentRendering,
+        MarkdownStreamContentRendering,
+    )
+
+    from .rendering.raw import (  # noqa
+        RawContentRendering,
+        RawContentStreamRendering,
+    )
+
+    from .rendering.types import (  # noqa
+        ContentRendering,
+        StreamContentRendering,
+    )
+
+    from .tools.confirmation import (  # noqa
+        ToolExecutionRequestDeniedError,
+        ToolExecutionConfirmation,
+        InteractiveToolExecutionConfirmation,
+    )
+
+    from .tools.execution import (  # noqa
+        ToolUseExecutor,
+        ToolUseExecutorImpl,
+    )
+
+    from .driver import (  # noqa
+        ChatDriver,
+    )
+
+    from .phases import (  # noqa
+        ChatPhase,
+        ChatPhaseCallback,
+        ChatPhaseCallbacks,
+        ChatPhaseManager,
+    )

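_inject.py wraps all of its imports in omlish's lang.auto_proxy_init. A sketch of the apparent pattern; the deferred-import behavior described in the comments is an assumption, not documented here:

    # In some package __init__-style module:
    from omlish import lang

    with lang.auto_proxy_init(globals()):
        from .driver import (  # noqa
            ChatDriver,
        )

    # Presumably nothing under .driver has actually been imported yet;
    # the first real use of ChatDriver triggers the underlying import.
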
ommlds/cli/sessions/chat2/backends/__init__.py
ADDED

File without changes

ommlds/cli/sessions/chat2/backends/catalog.py
ADDED

@@ -0,0 +1,56 @@
+import contextlib
+import typing as ta
+
+from omlish import lang
+
+from ..... import minichain as mc
+from .types import BackendConfigs
+from .types import BackendName
+from .types import BackendProvider
+from .types import ChatChoicesServiceBackendProvider
+from .types import ChatChoicesStreamServiceBackendProvider
+from .types import ServiceT
+
+
+##
+
+
+class _CatalogBackendProvider(BackendProvider[ServiceT], lang.Abstract):
+    def __init__(
+            self,
+            *,
+            name: BackendName,
+            catalog: mc.BackendCatalog,
+            configs: BackendConfigs | None = None,
+    ) -> None:
+        super().__init__()
+
+        self._name = name
+        self._catalog = catalog
+        self._configs = configs
+
+    @contextlib.asynccontextmanager
+    async def _provide_backend(self, cls: type[ServiceT]) -> ta.AsyncIterator[ServiceT]:
+        service: ServiceT
+        async with lang.async_maybe_managing(self._catalog.get_backend(
+                cls,
+                self._name,
+                *(self._configs or []),
+        )) as service:
+            yield service
+
+
+class CatalogChatChoicesServiceBackendProvider(
+    _CatalogBackendProvider[mc.ChatChoicesService],
+    ChatChoicesServiceBackendProvider,
+):
+    def provide_backend(self) -> ta.AsyncContextManager[mc.ChatChoicesService]:
+        return self._provide_backend(mc.ChatChoicesService)  # type: ignore[type-abstract]
+
+
+class CatalogChatChoicesStreamServiceBackendProvider(
+    _CatalogBackendProvider[mc.ChatChoicesStreamService],
+    ChatChoicesStreamServiceBackendProvider,
+):
+    def provide_backend(self) -> ta.AsyncContextManager[mc.ChatChoicesStreamService]:
+        return self._provide_backend(mc.ChatChoicesStreamService)  # type: ignore[type-abstract]

ommlds/cli/sessions/chat2/backends/types.py
ADDED

@@ -0,0 +1,36 @@
+import abc
+import typing as ta
+
+from omlish import lang
+
+from ..... import minichain as mc
+
+
+ServiceT = ta.TypeVar('ServiceT', bound=mc.Service)
+
+
+##
+
+
+BackendName = ta.NewType('BackendName', str)
+BackendConfigs = ta.NewType('BackendConfigs', ta.Sequence[mc.Config])
+
+
+##
+
+
+class BackendProvider(lang.Abstract, ta.Generic[ServiceT]):
+    @abc.abstractmethod
+    def provide_backend(self) -> ta.AsyncContextManager[ServiceT]:
+        raise NotImplementedError
+
+
+##
+
+
+class ChatChoicesServiceBackendProvider(BackendProvider[mc.ChatChoicesService], lang.Abstract):
+    pass
+
+
+class ChatChoicesStreamServiceBackendProvider(BackendProvider[mc.ChatChoicesStreamService], lang.Abstract):
+    pass

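BackendProvider is a small abstract surface: provide_backend() returns an async context manager yielding a service. A hypothetical provider (not part of ommlds) that always yields one pre-built service could look like:

    import contextlib
    import typing as ta

    from ommlds import minichain as mc
    from ommlds.cli.sessions.chat2.backends.types import ChatChoicesServiceBackendProvider


    class FixedChatChoicesServiceBackendProvider(ChatChoicesServiceBackendProvider):
        def __init__(self, service: mc.ChatChoicesService) -> None:
            super().__init__()
            self._service = service

        @contextlib.asynccontextmanager
        async def _provide(self) -> ta.AsyncIterator[mc.ChatChoicesService]:
            # No acquisition or teardown: the same instance is reused.
            yield self._service

        def provide_backend(self) -> ta.AsyncContextManager[mc.ChatChoicesService]:
            return self._provide()
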
ommlds/cli/sessions/chat2/chat/__init__.py
ADDED

File without changes

ommlds/cli/sessions/chat2/chat/ai/__init__.py
ADDED

File without changes

ommlds/cli/sessions/chat2/chat/ai/rendering.py
ADDED

@@ -0,0 +1,67 @@
+import typing as ta
+
+from ...... import minichain as mc
+from ...content.messages import MessageContentExtractor
+from ...content.messages import MessageContentExtractorImpl
+from ...rendering.types import ContentRendering
+from .types import AiChatGenerator
+from .types import StreamAiChatGenerator
+
+
+##
+
+
+class RenderingAiChatGenerator(AiChatGenerator):
+    def __init__(
+            self,
+            *,
+            wrapped: AiChatGenerator,
+            extractor: MessageContentExtractor | None = None,
+            renderer: ContentRendering,
+    ) -> None:
+        super().__init__()
+
+        self._wrapped = wrapped
+        if extractor is None:
+            extractor = MessageContentExtractorImpl()
+        self._extractor = extractor
+        self._renderer = renderer
+
+    async def get_next_ai_messages(self, chat: mc.Chat) -> mc.AiChat:
+        out = await self._wrapped.get_next_ai_messages(chat)
+
+        for msg in out:
+            if (c := self._extractor.extract_message_content(msg)) is not None:
+                await self._renderer.render_content(c)
+
+        return out
+
+
+class RenderingStreamAiChatGenerator(StreamAiChatGenerator):
+    def __init__(
+            self,
+            *,
+            wrapped: StreamAiChatGenerator,
+            extractor: MessageContentExtractor | None = None,
+            renderer: ContentRendering,
+    ) -> None:
+        super().__init__()
+
+        self._wrapped = wrapped
+        if extractor is None:
+            extractor = MessageContentExtractorImpl()
+        self._extractor = extractor
+        self._renderer = renderer
+
+    async def get_next_ai_messages_streamed(
+            self,
+            chat: mc.Chat,
+            delta_callback: ta.Callable[[mc.AiChoiceDelta], ta.Awaitable[None]] | None = None,
+    ) -> mc.AiChat:
+        async def inner(delta: mc.AiChoiceDelta) -> None:
+            # FIXME: render lol
+
+            if delta_callback is not None:
+                await delta_callback(delta)
+
+        return await self._wrapped.get_next_ai_messages_streamed(chat, delta_callback=inner)

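RenderingAiChatGenerator is a decorator over another AiChatGenerator: it extracts content from each returned message and hands it to a renderer. A wiring sketch with stub implementations; it assumes ContentRendering's abstract method is the async render_content called above, and that the default extractor surfaces plain AiMessage text:

    import asyncio

    from ommlds import minichain as mc
    from ommlds.cli.sessions.chat2.chat.ai.rendering import RenderingAiChatGenerator
    from ommlds.cli.sessions.chat2.chat.ai.types import AiChatGenerator
    from ommlds.cli.sessions.chat2.rendering.types import ContentRendering


    class EchoAiChatGenerator(AiChatGenerator):
        async def get_next_ai_messages(self, chat: mc.Chat) -> mc.AiChat:
            return [mc.AiMessage('hello')]


    class PrintContentRendering(ContentRendering):
        async def render_content(self, content) -> None:
            print(content)


    async def main() -> None:
        gen = RenderingAiChatGenerator(
            wrapped=EchoAiChatGenerator(),
            renderer=PrintContentRendering(),
        )
        await gen.get_next_ai_messages([mc.UserMessage('hi')])  # should print 'hello'


    asyncio.run(main())
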
ommlds/cli/sessions/chat2/chat/ai/services.py
ADDED

@@ -0,0 +1,70 @@
+import typing as ta
+
+from omlish import check
+
+from ...... import minichain as mc
+from ...backends.types import ChatChoicesServiceBackendProvider
+from .types import AiChatGenerator
+from .types import StreamAiChatGenerator
+
+
+##
+
+
+ChatChoicesServiceOptions = ta.NewType('ChatChoicesServiceOptions', ta.Sequence[mc.ChatChoicesOptions])
+
+
+##
+
+
+class ChatChoicesServiceAiChatGenerator(AiChatGenerator):
+    def __init__(
+            self,
+            service_provider: ChatChoicesServiceBackendProvider,
+            *,
+            options: ChatChoicesServiceOptions | None = None,
+    ) -> None:
+        super().__init__()
+
+        self._service_provider = service_provider
+        self._options = options
+
+    async def get_next_ai_messages(self, chat: mc.Chat) -> mc.AiChat:
+        async with self._service_provider.provide_backend() as service:
+            resp = await service.invoke(mc.ChatChoicesRequest(chat, self._options or []))
+
+        return check.single(resp.v).ms
+
+
+class ChatChoicesStreamServiceStreamAiChatGenerator(StreamAiChatGenerator):
+    def __init__(
+            self,
+            service: mc.ChatChoicesStreamService,
+            *,
+            options: ChatChoicesServiceOptions | None = None,
+    ) -> None:
+        super().__init__()
+
+        self._service = service
+        self._options = options
+
+    async def get_next_ai_messages_streamed(
+            self,
+            chat: mc.Chat,
+            delta_callback: ta.Callable[[mc.AiChoiceDelta], ta.Awaitable[None]] | None = None,
+    ) -> mc.AiChat:
+        lst: list[str] = []
+
+        async with (await self._service.invoke(mc.ChatChoicesStreamRequest(chat, self._options or []))).v as st_resp:
+            async for o in st_resp:
+                choice = check.single(o.choices)
+
+                for delta in choice.deltas:
+                    if delta_callback is not None:
+                        await delta_callback(delta)
+
+                    c = check.isinstance(delta, mc.ContentAiChoiceDelta).c  # noqa
+                    if c is not None:
+                        lst.append(check.isinstance(c, str))
+
+        return [mc.AiMessage(''.join(lst))]

ommlds/cli/sessions/chat2/chat/ai/types.py
ADDED

@@ -0,0 +1,28 @@
+import abc
+import typing as ta
+
+from omlish import lang
+
+from ...... import minichain as mc
+
+
+##
+
+
+class AiChatGenerator(lang.Abstract):
+    @abc.abstractmethod
+    def get_next_ai_messages(self, chat: mc.Chat) -> ta.Awaitable[mc.AiChat]:
+        raise NotImplementedError
+
+
+class StreamAiChatGenerator(AiChatGenerator, lang.Abstract):
+    def get_next_ai_messages(self, chat: mc.Chat) -> ta.Awaitable[mc.AiChat]:
+        return self.get_next_ai_messages_streamed(chat)
+
+    @abc.abstractmethod
+    def get_next_ai_messages_streamed(
+            self,
+            chat: mc.Chat,
+            delta_callback: ta.Callable[[mc.AiChoiceDelta], ta.Awaitable[None]] | None = None,
+    ) -> ta.Awaitable[mc.AiChat]:
+        raise NotImplementedError

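Note the default in StreamAiChatGenerator: get_next_ai_messages simply forwards to get_next_ai_messages_streamed, so stream-capable implementations only write the streamed method. A minimal sketch:

    import asyncio
    import typing as ta

    from ommlds import minichain as mc
    from ommlds.cli.sessions.chat2.chat.ai.types import StreamAiChatGenerator


    class CannedStreamAiChatGenerator(StreamAiChatGenerator):
        async def get_next_ai_messages_streamed(
                self,
                chat: mc.Chat,
                delta_callback: ta.Callable[[mc.AiChoiceDelta], ta.Awaitable[None]] | None = None,
        ) -> mc.AiChat:
            # A real implementation would invoke delta_callback per delta.
            return [mc.AiMessage('canned reply')]


    async def main() -> None:
        gen = CannedStreamAiChatGenerator()
        # The non-streamed call falls through to the streamed one.
        print(await gen.get_next_ai_messages([mc.UserMessage('hi')]))


    asyncio.run(main())
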
ommlds/cli/sessions/chat2/chat/state/__init__.py
ADDED

File without changes

ommlds/cli/sessions/chat2/chat/state/inmemory.py
ADDED

@@ -0,0 +1,34 @@
+import dataclasses as dc
+
+from omlish import lang
+
+from ...... import minichain as mc
+from .types import ChatState
+from .types import ChatStateManager
+
+
+##
+
+
+class InMemoryChatStateManager(ChatStateManager):
+    def __init__(self, initial_state: ChatState | None = None) -> None:
+        super().__init__()
+
+        if initial_state is None:
+            initial_state = ChatState()
+        self._state = initial_state
+
+    async def get_state(self) -> ChatState:
+        return self._state
+
+    async def clear_state(self) -> ChatState:
+        self._state = ChatState()
+        return self._state
+
+    async def extend_chat(self, chat_additions: mc.Chat) -> ChatState:
+        self._state = dc.replace(
+            self._state,
+            chat=[*self._state.chat, *chat_additions],
+            updated_at=lang.utcnow(),
+        )
+        return self._state

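Usage of the in-memory manager follows directly from the code above:

    import asyncio

    from ommlds import minichain as mc
    from ommlds.cli.sessions.chat2.chat.state.inmemory import InMemoryChatStateManager


    async def main() -> None:
        mgr = InMemoryChatStateManager()
        state = await mgr.extend_chat([mc.UserMessage('hi')])
        print(len(state.chat))  # 1; updated_at is also refreshed
        state = await mgr.clear_state()
        print(len(state.chat))  # 0: a fresh ChatState


    asyncio.run(main())
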
ommlds/cli/sessions/chat2/chat/state/storage.py
ADDED

@@ -0,0 +1,53 @@
+import dataclasses as dc
+
+from omlish import check
+from omlish import lang
+
+from ...... import minichain as mc
+from .....state import StateStorage
+from .types import ChatState
+from .types import ChatStateManager
+
+
+##
+
+
+class StateStorageChatStateManager(ChatStateManager):
+    def __init__(
+            self,
+            *,
+            storage: StateStorage,
+            key: str = 'chat',
+    ) -> None:
+        super().__init__()
+
+        self._storage = storage
+        self._key = check.non_empty_str(key)
+
+        self._state: ChatState | None = None
+
+    async def get_state(self) -> ChatState:
+        if self._state is not None:
+            return self._state
+        state: ChatState | None = await self._storage.load_state(self._key, ChatState)
+        if state is None:
+            state = ChatState()
+        self._state = state
+        return state
+
+    async def clear_state(self) -> ChatState:
+        state = ChatState()
+        await self._storage.save_state(self._key, state, ChatState)
+        self._state = state
+        return state
+
+    async def extend_chat(self, chat_additions: mc.Chat) -> ChatState:
+        state = await self.get_state()
+        state = dc.replace(
+            state,
+            chat=[*state.chat, *chat_additions],
+            updated_at=lang.utcnow(),
+        )
+        await self._storage.save_state(self._key, state, ChatState)
+        self._state = state
+        return state

ommlds/cli/sessions/chat2/chat/state/types.py
ADDED

@@ -0,0 +1,38 @@
+import abc
+import dataclasses as dc
+import datetime
+import typing as ta
+
+from omlish import lang
+
+from ...... import minichain as mc
+
+
+##
+
+
+@dc.dataclass(frozen=True)
+class ChatState:
+    name: str | None = None
+
+    created_at: datetime.datetime = dc.field(default_factory=lang.utcnow)
+    updated_at: datetime.datetime = dc.field(default_factory=lang.utcnow)
+
+    chat: mc.Chat = ()
+
+
+##
+
+
+class ChatStateManager(lang.Abstract):
+    @abc.abstractmethod
+    def get_state(self) -> ta.Awaitable[ChatState]:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def clear_state(self) -> ta.Awaitable[ChatState]:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def extend_chat(self, chat_additions: mc.Chat) -> ta.Awaitable[ChatState]:
+        raise NotImplementedError

ommlds/cli/sessions/chat2/chat/user/__init__.py
ADDED

File without changes

ommlds/cli/sessions/chat2/chat/user/interactive.py
ADDED

@@ -0,0 +1,29 @@
+import functools
+import typing as ta
+
+from omlish import lang
+
+from ...... import minichain as mc
+from .types import UserChatInput
+
+
+##
+
+
+class InteractiveUserChatInput(UserChatInput):
+    def __init__(
+            self,
+            string_input: ta.Callable[[], ta.Awaitable[str]] | None = None,
+    ) -> None:
+        super().__init__()
+
+        if string_input is None:
+            string_input = lang.as_async(functools.partial(input, '> '))
+        self._string_input = string_input
+
+    async def get_next_user_messages(self) -> mc.UserChat:
+        try:
+            s = await self._string_input()
+        except EOFError:
+            return []
+        return [mc.UserMessage(s)]

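InteractiveUserChatInput accepts any zero-argument async string source, defaulting to a '> '-prompted input(). A sketch substituting a scripted source:

    import asyncio

    from ommlds.cli.sessions.chat2.chat.user.interactive import InteractiveUserChatInput


    async def scripted_input() -> str:
        return 'hello'


    async def main() -> None:
        ui = InteractiveUserChatInput(scripted_input)
        print(await ui.get_next_user_messages())  # a single mc.UserMessage('hello')


    asyncio.run(main())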
|