ommlds 0.0.0.dev461__py3-none-any.whl → 0.0.0.dev463__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ommlds might be problematic. Click here for more details.

Files changed (47)
  1. ommlds/.omlish-manifests.json +4 -5
  2. ommlds/__about__.py +2 -2
  3. ommlds/backends/mlx/loading.py +58 -1
  4. ommlds/cli/main.py +18 -2
  5. ommlds/cli/sessions/chat/state.py +3 -3
  6. ommlds/cli/sessions/chat2/__init__.py +0 -0
  7. ommlds/cli/sessions/chat2/_inject.py +105 -0
  8. ommlds/cli/sessions/chat2/backends/__init__.py +0 -0
  9. ommlds/cli/sessions/chat2/backends/catalog.py +56 -0
  10. ommlds/cli/sessions/chat2/backends/types.py +36 -0
  11. ommlds/cli/sessions/chat2/chat/__init__.py +0 -0
  12. ommlds/cli/sessions/chat2/chat/ai/__init__.py +0 -0
  13. ommlds/cli/sessions/chat2/chat/ai/rendering.py +67 -0
  14. ommlds/cli/sessions/chat2/chat/ai/services.py +70 -0
  15. ommlds/cli/sessions/chat2/chat/ai/types.py +28 -0
  16. ommlds/cli/sessions/chat2/chat/state/__init__.py +0 -0
  17. ommlds/cli/sessions/chat2/chat/state/inmemory.py +34 -0
  18. ommlds/cli/sessions/chat2/chat/state/storage.py +53 -0
  19. ommlds/cli/sessions/chat2/chat/state/types.py +38 -0
  20. ommlds/cli/sessions/chat2/chat/user/__init__.py +0 -0
  21. ommlds/cli/sessions/chat2/chat/user/interactive.py +29 -0
  22. ommlds/cli/sessions/chat2/chat/user/oneshot.py +25 -0
  23. ommlds/cli/sessions/chat2/chat/user/types.py +15 -0
  24. ommlds/cli/sessions/chat2/configs.py +33 -0
  25. ommlds/cli/sessions/chat2/content/__init__.py +0 -0
  26. ommlds/cli/sessions/chat2/content/messages.py +30 -0
  27. ommlds/cli/sessions/chat2/content/strings.py +42 -0
  28. ommlds/cli/sessions/chat2/driver.py +43 -0
  29. ommlds/cli/sessions/chat2/inject.py +143 -0
  30. ommlds/cli/sessions/chat2/phases.py +55 -0
  31. ommlds/cli/sessions/chat2/rendering/__init__.py +0 -0
  32. ommlds/cli/sessions/chat2/rendering/markdown.py +52 -0
  33. ommlds/cli/sessions/chat2/rendering/raw.py +73 -0
  34. ommlds/cli/sessions/chat2/rendering/types.py +21 -0
  35. ommlds/cli/sessions/chat2/session.py +27 -0
  36. ommlds/cli/sessions/chat2/tools/__init__.py +0 -0
  37. ommlds/cli/sessions/chat2/tools/confirmation.py +46 -0
  38. ommlds/cli/sessions/chat2/tools/execution.py +53 -0
  39. ommlds/cli/sessions/inject.py +6 -1
  40. ommlds/cli/state.py +40 -23
  41. ommlds/minichain/backends/impls/anthropic/names.py +3 -4
  42. {ommlds-0.0.0.dev461.dist-info → ommlds-0.0.0.dev463.dist-info}/METADATA +7 -7
  43. {ommlds-0.0.0.dev461.dist-info → ommlds-0.0.0.dev463.dist-info}/RECORD +47 -14
  44. {ommlds-0.0.0.dev461.dist-info → ommlds-0.0.0.dev463.dist-info}/WHEEL +0 -0
  45. {ommlds-0.0.0.dev461.dist-info → ommlds-0.0.0.dev463.dist-info}/entry_points.txt +0 -0
  46. {ommlds-0.0.0.dev461.dist-info → ommlds-0.0.0.dev463.dist-info}/licenses/LICENSE +0 -0
  47. {ommlds-0.0.0.dev461.dist-info → ommlds-0.0.0.dev463.dist-info}/top_level.txt +0 -0
File without changes
@@ -0,0 +1,29 @@
1
+ import functools
2
+ import typing as ta
3
+
4
+ from omlish import lang
5
+
6
+ from ...... import minichain as mc
7
+ from .types import UserChatInput
8
+
9
+
10
+ ##
11
+
12
+
13
class InteractiveUserChatInput(UserChatInput):
    """User chat input that reads one line per turn from an async string source (a prompting `input()` by default)."""

    def __init__(
            self,
            string_input: ta.Callable[[], ta.Awaitable[str]] | None = None,
    ) -> None:
        super().__init__()

        # Without an explicit source, prompt with '> ' via `input()`, adapted to an awaitable by `lang.as_async`.
        self._string_input = (
            string_input
            if string_input is not None
            else lang.as_async(functools.partial(input, '> '))
        )

    async def get_next_user_messages(self) -> mc.UserChat:
        """Return the next line wrapped in a single user message, or an empty chat once input reaches EOF."""

        try:
            line = await self._string_input()
        except EOFError:
            # End of input is reported as an empty chat.
            return []

        return [mc.UserMessage(line)]
@@ -0,0 +1,25 @@
1
+ import typing as ta
2
+
3
+ from ...... import minichain as mc
4
+ from .types import UserChatInput
5
+
6
+
7
+ ##
8
+
9
+
10
# The user chat a `OneshotUserChatInput` hands back exactly once.
OneshotUserChatInputInitialChat = ta.NewType('OneshotUserChatInputInitialChat', mc.UserChat)


class OneshotUserChatInput(UserChatInput):
    """User chat input that yields its initial chat on the first call and an empty chat on every later call."""

    def __init__(
            self,
            initial_chat: OneshotUserChatInputInitialChat,
    ) -> None:
        super().__init__()

        # Reset to None once it has been handed out.
        self._pending_chat: mc.UserChat | None = initial_chat

    async def get_next_user_messages(self) -> mc.UserChat:
        # Take whatever is pending and clear it in one step.
        pending, self._pending_chat = self._pending_chat, None

        if not pending:
            return []

        return pending
@@ -0,0 +1,15 @@
1
+ import abc
2
+ import typing as ta
3
+
4
+ from omlish import lang
5
+
6
+ from ...... import minichain as mc
7
+
8
+
9
+ ##
10
+
11
+
12
class UserChatInput(lang.Abstract):
    """Source of the user's side of a chat session."""

    @abc.abstractmethod
    def get_next_user_messages(self) -> ta.Awaitable[mc.UserChat]:
        """Produce the next batch of user messages.

        The implementations in this package return an empty chat once input is exhausted.
        """

        raise NotImplementedError
@@ -0,0 +1,33 @@
1
+ import dataclasses as dc
2
+ import typing as ta
3
+
4
+ from .... import minichain as mc
5
+
6
+
7
+ ##
8
+
9
+
10
# Backend name substituted when a chat config's `backend` is None or empty.
DEFAULT_CHAT_MODEL_BACKEND: str = 'openai'
11
+
12
+
13
+ ##
14
+
15
+
16
@dc.dataclass(frozen=True)
class ChatConfig:
    """Immutable settings for a chat session; every field must be passed by keyword."""

    _: dc.KW_ONLY

    # Backend name; `bind_chat` substitutes DEFAULT_CHAT_MODEL_BACKEND when this is None or empty.
    backend: str | None = None
    # NOTE(review): not read by `bind_chat` as shown - presumably consumed by the backend layer; confirm.
    model_name: str | None = None

    # 'continue' and 'new' use storage-backed chat state ('new' clears it as the session starts);
    # 'ephemeral' keeps chat state in memory only.
    state: ta.Literal['new', 'continue', 'ephemeral'] = 'continue'

    # Seed user message; required when `interactive` is False.
    initial_content: mc.Content | None = None
    # Read user messages from an interactive prompt rather than sending only `initial_content`.
    interactive: bool = False

    # `silent` suppresses rendering of AI output; otherwise `markdown` selects Markdown over raw text rendering.
    silent: bool = False
    markdown: bool = False

    # Streaming generation; `bind_chat` currently raises NotImplementedError when this is set.
    stream: bool = False

    # NOTE(review): not read by `bind_chat` as shown - presumably disables the tool execution confirmation
    # prompt; confirm.
    dangerous_no_tool_confirmation: bool = False
File without changes
@@ -0,0 +1,30 @@
1
+ import abc
2
+
3
+ from omlish import check
4
+ from omlish import lang
5
+
6
+ from ..... import minichain as mc
7
+
8
+
9
+ ##
10
+
11
+
12
class MessageContentExtractor(lang.Abstract):
    """Strategy for pulling the content out of a chat message."""

    @abc.abstractmethod
    def extract_message_content(self, message: mc.Message) -> mc.Content | None:
        """Return the content carried by *message*, or None when it carries none."""

        raise NotImplementedError
16
+
17
+
18
class MessageContentExtractorImpl(MessageContentExtractor):
    """Default extractor: supports system, user, AI and tool-result messages, whose content must be a string."""

    def extract_message_content(self, message: mc.Message) -> mc.Content | None:
        # Plain chat messages may legitimately carry no content.
        if isinstance(message, (mc.SystemMessage, mc.UserMessage, mc.AiMessage)):
            body = message.c
            return None if body is None else check.isinstance(body, str)

        # Tool results carry their content on the nested tool-use result.
        if isinstance(message, mc.ToolUseResultMessage):
            return check.isinstance(message.tur.c, str)

        raise TypeError(message)
@@ -0,0 +1,42 @@
1
+ import abc
2
+ import typing as ta
3
+
4
+ from omlish import lang
5
+ from omlish.formats import json
6
+
7
+ from ..... import minichain as mc
8
+
9
+
10
+ ##
11
+
12
+
13
class ContentStringifier(lang.Abstract):
    """Strategy for turning chat content into text."""

    @abc.abstractmethod
    def stringify_content(self, content: mc.Content) -> str | None:
        """Return *content* as a string, or None when there is nothing to show."""

        raise NotImplementedError
17
+
18
+
19
class ContentStringifierImpl(ContentStringifier):
    """Default stringifier: passes strings through and pretty-prints JSON content; rejects anything else."""

    def stringify_content(self, content: mc.Content) -> str | None:
        if isinstance(content, str):
            return content

        if isinstance(content, mc.JsonContent):
            return json.dumps_pretty(content.v)

        raise TypeError(content)
29
+
30
+
31
class HasContentStringifier(lang.Abstract):
    """Mixin that stores a `_content_stringifier`, defaulting to a fresh `ContentStringifierImpl`.

    All other positional and keyword arguments are forwarded to the next initializer in the MRO.
    """

    def __init__(
            self,
            *args: ta.Any,
            content_stringifier: ContentStringifier | None = None,
            **kwargs: ta.Any,
    ) -> None:
        super().__init__(*args, **kwargs)

        self._content_stringifier = (
            content_stringifier
            if content_stringifier is not None
            else ContentStringifierImpl()
        )
@@ -0,0 +1,43 @@
1
+ from .chat.ai.types import AiChatGenerator
2
+ from .chat.state.types import ChatStateManager
3
+ from .chat.user.types import UserChatInput
4
+ from .phases import ChatPhase
5
+ from .phases import ChatPhaseManager
6
+
7
+
8
+ ##
9
+
10
+
11
class ChatDriver:
    """Runs a chat session: steps through the lifecycle phases and alternates user input with AI generation."""

    def __init__(
            self,
            *,
            phases: ChatPhaseManager,
            ai_chat_generator: AiChatGenerator,
            user_chat_input: UserChatInput,
            chat_state_manager: ChatStateManager,
    ):
        super().__init__()

        self._phases = phases
        self._ai_chat_generator = ai_chat_generator
        self._user_chat_input = user_chat_input
        self._chat_state_manager = chat_state_manager

    async def run(self) -> None:
        """Drive the session until the user input source yields an empty chat."""

        for phase in (ChatPhase.STARTING, ChatPhase.STARTED):
            await self._phases.set_phase(phase)

        while next_user_chat := await self._user_chat_input.get_next_user_messages():
            # The stored chat (not just prior user turns) plus the new user turn forms the request.
            history = (await self._chat_state_manager.get_state()).chat
            next_ai_chat = await self._ai_chat_generator.get_next_ai_messages([*history, *next_user_chat])

            # Persist the new user turn together with the AI's reply.
            await self._chat_state_manager.extend_chat([*next_user_chat, *next_ai_chat])

        for phase in (ChatPhase.STOPPING, ChatPhase.STOPPED):
            await self._phases.set_phase(phase)
@@ -0,0 +1,143 @@
1
+ import typing as ta
2
+
3
+ from omlish import inject as inj
4
+ from omlish import lang
5
+
6
+ from .... import minichain as mc
7
+ from . import _inject as _inj
8
+ from .configs import DEFAULT_CHAT_MODEL_BACKEND
9
+ from .configs import ChatConfig
10
+
11
+
12
# NOTE(review): not referenced anywhere else in this module - candidate for removal.
ItemT = ta.TypeVar('ItemT')


##


# Multi-binding helpers. Each appears to gather individually bound items of the bracketed type into the
# injectable sequence type passed as the argument; `bind_chat` binds their providers and items.
CHAT_OPTIONS = inj.items_binder_helper[mc.ChatChoicesOption](_inj.ChatChoicesServiceOptions)
BACKEND_CONFIGS = inj.items_binder_helper[mc.Config](_inj.BackendConfigs)
PHASE_CALLBACKS = inj.items_binder_helper[_inj.ChatPhaseCallback](_inj.ChatPhaseCallbacks)
21
+
22
+
23
+ ##
24
+
25
+
26
def bind_chat(cfg: ChatConfig) -> inj.Elements:
    """Build the injector elements that wire up a chat session described by *cfg*.

    Raises:
        TypeError: if ``cfg.state`` is not 'continue', 'new', or 'ephemeral'.
        ValueError: if the chat is non-interactive and ``cfg.initial_content`` is None.
        NotImplementedError: if ``cfg.stream`` is set, or if an interactive chat is given initial content.
    """

    els: list[inj.Elemental] = []

    # Providers for the multi-bound sequences (chat options, backend configs, phase callbacks).

    els.extend([
        CHAT_OPTIONS.bind_items_provider(singleton=True),
        BACKEND_CONFIGS.bind_items_provider(singleton=True),
        PHASE_CALLBACKS.bind_items_provider(singleton=True),
    ])

    # Chat state manager selection.

    if cfg.state in ('continue', 'new'):
        els.extend([
            inj.bind(_inj.StateStorageChatStateManager, singleton=True),
            inj.bind(_inj.ChatStateManager, to_key=_inj.StateStorageChatStateManager),
        ])

        if cfg.state == 'new':
            # 'new' uses the same storage-backed manager but clears its state during STARTING.
            els.append(PHASE_CALLBACKS.bind_item(to_fn=lang.typed_lambda(cm=_inj.ChatStateManager)(
                lambda cm: _inj.ChatPhaseCallback(_inj.ChatPhase.STARTING, cm.clear_state),
            )))

    elif cfg.state == 'ephemeral':
        els.extend([
            inj.bind(_inj.InMemoryChatStateManager, singleton=True),
            inj.bind(_inj.ChatStateManager, to_key=_inj.InMemoryChatStateManager),
        ])

    else:
        raise TypeError(cfg.state)

    # User input source.

    if cfg.interactive:
        if cfg.initial_content is not None:
            async def add_initial_content(cm: '_inj.ChatStateManager') -> None:
                await cm.extend_chat([mc.UserMessage(cfg.initial_content)])

            els.append(PHASE_CALLBACKS.bind_item(to_fn=lang.typed_lambda(cm=_inj.ChatStateManager)(
                lambda cm: _inj.ChatPhaseCallback(_inj.ChatPhase.STARTED, lambda: add_initial_content(cm)),
            )))

            # Work in progress: interactive chats seeded with initial content are rejected for now, so the
            # callback appended above never takes effect.
            raise NotImplementedError

        els.extend([
            inj.bind(_inj.InteractiveUserChatInput, singleton=True),
            inj.bind(_inj.UserChatInput, to_key=_inj.InteractiveUserChatInput),
        ])

    else:
        # One-shot mode: the single initial user message is the only input.
        if cfg.initial_content is None:
            raise ValueError('Initial content is required for non-interactive chat')

        els.extend([
            inj.bind(_inj.OneshotUserChatInputInitialChat, to_const=[mc.UserMessage(cfg.initial_content)]),
            inj.bind(_inj.OneshotUserChatInput, singleton=True),
            inj.bind(_inj.UserChatInput, to_key=_inj.OneshotUserChatInput),
        ])

    # AI message generation, assembled as a stack of generator bindings.

    if cfg.stream:
        # Streaming generation is not implemented yet.
        raise NotImplementedError

    else:
        ai_stack = inj.wrapper_binder_helper(_inj.AiChatGenerator)

        # Bottom of the stack: the generator backed by the chat-choices service.
        els.append(ai_stack.push_bind(to_ctor=_inj.ChatChoicesServiceAiChatGenerator, singleton=True))

        if not cfg.silent:
            # Pick the content renderer used by the rendering generator.
            if cfg.markdown:
                els.extend([
                    inj.bind(_inj.MarkdownContentRendering, singleton=True),
                    inj.bind(_inj.ContentRendering, to_key=_inj.MarkdownContentRendering),
                ])
            else:
                els.extend([
                    inj.bind(_inj.RawContentRendering, singleton=True),
                    inj.bind(_inj.ContentRendering, to_key=_inj.RawContentRendering),
                ])

            # Pushed after the service generator - presumably wrapping it; push order matters here.
            els.append(ai_stack.push_bind(to_ctor=_inj.RenderingAiChatGenerator, singleton=True))

        # The AiChatGenerator everyone else sees is whatever ended up on top of the stack.
        els.append(inj.bind(_inj.AiChatGenerator, to_key=ai_stack.top))

    # Backend selection; an unset or empty backend falls back to the default.

    els.append(inj.bind(_inj.BackendName, to_const=cfg.backend or DEFAULT_CHAT_MODEL_BACKEND))

    els.extend([
        # NOTE(review): unlike the other implementation bindings here this one is not singleton - confirm intended.
        inj.bind(_inj.CatalogChatChoicesServiceBackendProvider),
        inj.bind(_inj.ChatChoicesServiceBackendProvider, to_key=_inj.CatalogChatChoicesServiceBackendProvider),
    ])

    # Tool execution.

    els.extend([
        inj.bind(_inj.ToolUseExecutorImpl, singleton=True),
        inj.bind(_inj.ToolUseExecutor, to_key=_inj.ToolUseExecutorImpl),
    ])

    # Phase management.

    els.extend([
        inj.bind(_inj.ChatPhaseManager, singleton=True),
    ])

    # Driver.

    els.extend([
        inj.bind(_inj.ChatDriver, singleton=True),
    ])

    #

    return inj.as_elements(*els)
@@ -0,0 +1,55 @@
1
+ import enum
2
+ import typing as ta
3
+
4
+ from omlish import check
5
+ from omlish import collections as col
6
+ from omlish import dataclasses as dc
7
+
8
+
9
+ ##
10
+
11
+
12
class ChatPhase(enum.Enum):
    """Lifecycle phases of a chat session, declared in the order a session moves through them."""

    # Explicit values match what `enum.auto()` assigns: 1-based, in declaration order.
    NEW = 1

    STARTING = 2
    STARTED = 3

    STOPPING = 4
    STOPPED = 5
20
+
21
+
22
+ ##
23
+
24
+
25
def _is_not_new_chat_phase(v: ChatPhase) -> bool:
    # Callbacks may not target NEW, the phase a ChatPhaseManager starts in.
    return v != ChatPhase.NEW


@dc.dataclass(frozen=True)
class ChatPhaseCallback:
    """A zero-argument async callback to run when a chat session enters ``phase``."""

    phase: ChatPhase = dc.xfield(validate=_is_not_new_chat_phase)
    fn: ta.Callable[[], ta.Awaitable[None]] = dc.xfield()


# Every registered phase callback, as handed to ChatPhaseManager.
ChatPhaseCallbacks = ta.NewType('ChatPhaseCallbacks', ta.Sequence[ChatPhaseCallback])
32
+
33
+
34
+ ##
35
+
36
+
37
class ChatPhaseManager:
    """Tracks a session's current ChatPhase and runs the callbacks registered for each phase it is set to."""

    def __init__(self, callbacks: ChatPhaseCallbacks) -> None:
        super().__init__()

        self._callbacks = callbacks

        # Group callbacks by the phase they target.
        self._callbacks_by_phase = col.multi_map_by(lambda c: c.phase, callbacks)

        # Nothing may be registered for NEW, since the manager starts there rather than transitioning into it.
        check.state(not self._callbacks_by_phase.get(ChatPhase.NEW))

        self._phase = ChatPhase.NEW

    @property
    def phase(self) -> ChatPhase:
        """The most recently set phase (NEW until the first `set_phase`)."""

        return self._phase

    async def set_phase(self, phase: ChatPhase) -> None:
        """Record *phase* as current, then await each callback registered for it, one after another."""

        self._phase = phase

        phase_callbacks = self._callbacks_by_phase.get(phase, ())
        for callback in phase_callbacks:
            await callback.fn()
File without changes
@@ -0,0 +1,52 @@
1
+ import typing as ta
2
+
3
+ from omdev.tui import rich
4
+ from omlish import lang
5
+
6
+ from ..... import minichain as mc
7
+ from ..content.strings import ContentStringifier
8
+ from ..content.strings import HasContentStringifier
9
+ from .types import ContentRendering
10
+ from .types import StreamContentRendering
11
+
12
+
13
+ ##
14
+
15
+
16
class MarkdownContentRendering(ContentRendering, HasContentStringifier):
    """Renders each piece of content as Markdown through a rich console."""

    def __init__(
            self,
            *,
            content_stringifier: ContentStringifier | None = None,
    ) -> None:
        super().__init__(content_stringifier=content_stringifier)

    async def render_content(self, content: mc.Content) -> None:
        text = self._content_stringifier.stringify_content(content)
        if text is None:
            return

        # Whitespace-only output is skipped entirely.
        text = text.strip()
        if not text:
            return

        rich.Console().print(rich.Markdown(text))
27
+
28
+
29
class MarkdownStreamContentRendering(StreamContentRendering, HasContentStringifier):
    """Streaming Markdown renderer: each context feeds content pieces into an incremental live Markdown view."""

    def __init__(
            self,
            *,
            content_stringifier: ContentStringifier | None = None,
    ) -> None:
        super().__init__(content_stringifier=content_stringifier)

    @ta.final
    class _ContextInstance(ContentRendering, lang.AsyncExitStacked):
        # One streaming session. The live view is entered via `_enter_context`, presumably exited together with
        # this context by lang.AsyncExitStacked - confirm.

        def __init__(self, owner: 'MarkdownStreamContentRendering') -> None:
            # Cooperative init, as everywhere else in this package, so base-class setup is not skipped.
            super().__init__()

            self._owner = owner

        _ir: rich.MarkdownLiveStream

        async def _async_enter_contexts(self) -> None:
            self._ir = self._enter_context(rich.IncrementalMarkdownLiveStream())

        async def render_content(self, content: mc.Content) -> None:
            if (s := self._owner._content_stringifier.stringify_content(content)) is not None:  # noqa: SLF001
                self._ir.feed(s)

    def create_context(self) -> ta.AsyncContextManager[ContentRendering]:
        return MarkdownStreamContentRendering._ContextInstance(self)
@@ -0,0 +1,73 @@
1
+ import typing as ta
2
+
3
+ from omlish import lang
4
+
5
+ from ..... import minichain as mc
6
+ from ..content.strings import ContentStringifier
7
+ from ..content.strings import HasContentStringifier
8
+ from .types import ContentRendering
9
+ from .types import StreamContentRendering
10
+
11
+
12
+ ##
13
+
14
+
15
class RawContentRendering(ContentRendering, HasContentStringifier):
    """Renders each piece of content as plain text through an async printer (`print` by default)."""

    def __init__(
            self,
            *,
            printer: ta.Callable[[str], ta.Awaitable[None]] | None = None,
            content_stringifier: ContentStringifier | None = None,
    ) -> None:
        super().__init__(content_stringifier=content_stringifier)

        # `lang.as_async` adapts the synchronous builtin `print` to the awaitable printer interface.
        self._printer = printer if printer is not None else lang.as_async(print)

    async def render_content(self, content: mc.Content) -> None:
        text = self._content_stringifier.stringify_content(content)
        if text is not None:
            await self._printer(text)
31
+
32
+
33
class RawContentStreamRendering(StreamContentRendering, HasContentStringifier):
    """Streaming plain-text renderer: pieces are written to an async output as they arrive, then flushed on exit."""

    class Output(ta.Protocol):
        # Async text sink used by each stream context.
        def write(self, s: str) -> ta.Awaitable[None]: ...
        def flush(self) -> ta.Awaitable[None]: ...

    class PrintOutput:
        # Default Output: writes to stdout via `print`, without adding newlines between pieces.

        async def write(self, s: str) -> None:
            print(s, end='', flush=True)

        async def flush(self) -> None:
            # Besides flushing, this prints a newline, terminating the streamed line.
            print(flush=True)

    def __init__(
            self,
            *,
            output: Output | None = None,
            content_stringifier: ContentStringifier | None = None,
    ) -> None:
        super().__init__(content_stringifier=content_stringifier)

        if output is None:
            output = RawContentStreamRendering.PrintOutput()
        self._output = output

    @ta.final
    class _ContextInstance(ContentRendering, ta.AsyncContextManager):
        # One streaming session: each `render_content` writes immediately; leaving the context (normally or via
        # an exception, which is not suppressed) flushes the owner's output.

        def __init__(self, owner: 'RawContentStreamRendering') -> None:
            # Cooperative init, as everywhere else in this package, so base-class setup is not skipped.
            super().__init__()

            self._owner = owner

        async def __aenter__(self) -> ta.Self:
            return self

        async def __aexit__(self, *exc_info) -> None:
            await self._owner._output.flush()  # noqa: SLF001

        async def render_content(self, content: mc.Content) -> None:
            if (s := self._owner._content_stringifier.stringify_content(content)) is not None:  # noqa: SLF001
                await self._owner._output.write(s)  # noqa: SLF001

    def create_context(self) -> ta.AsyncContextManager[ContentRendering]:
        return RawContentStreamRendering._ContextInstance(self)
@@ -0,0 +1,21 @@
1
+ import abc
2
+ import typing as ta
3
+
4
+ from omlish import lang
5
+
6
+ from ..... import minichain as mc
7
+
8
+
9
+ ##
10
+
11
+
12
class ContentRendering(lang.Abstract):
    """Displays individual pieces of chat content."""

    @abc.abstractmethod
    def render_content(self, content: mc.Content) -> ta.Awaitable[None]:
        """Render *content* asynchronously."""

        raise NotImplementedError
16
+
17
+
18
class StreamContentRendering(lang.Abstract):
    """Factory for per-stream renderers."""

    @abc.abstractmethod
    def create_context(self) -> ta.AsyncContextManager[ContentRendering]:
        """Return an async context manager yielding the ContentRendering that receives one stream's pieces."""

        raise NotImplementedError
@@ -0,0 +1,27 @@
1
+ import dataclasses as dc
2
+
3
+ from ..base import Session
4
+ from .configs import ChatConfig
5
+ from .driver import ChatDriver
6
+
7
+
8
+ ##
9
+
10
+
11
class Chat2Session(Session['Chat2Session.Config']):
    """Session whose work is entirely delegated to an injected ChatDriver."""

    @dc.dataclass(frozen=True)
    class Config(Session.Config, ChatConfig):
        # Base session settings combined with the chat settings; adds nothing of its own.
        pass

    def __init__(self, config: Config, *, driver: ChatDriver) -> None:
        super().__init__(config)

        self._driver = driver

    async def run(self) -> None:
        await self._driver.run()
File without changes
@@ -0,0 +1,46 @@
1
+ import abc
2
+ import typing as ta
3
+
4
+ from omlish import lang
5
+ from omlish.formats import json
6
+ from omlish.term.confirm import confirm_action
7
+
8
+ from ..... import minichain as mc
9
+
10
+
11
+ ##
12
+
13
+
14
class ToolExecutionRequestDeniedError(Exception):
    """Raised when a requested tool execution is not confirmed."""
16
+
17
+
18
class ToolExecutionConfirmation(lang.Abstract):
    """Decides whether a requested tool use may proceed."""

    @abc.abstractmethod
    def confirm_tool_execution_or_raise(
            self,
            use: mc.ToolUse,
            entry: mc.ToolCatalogEntry,
    ) -> ta.Awaitable[None]:
        """Complete normally to allow *use* of the catalog *entry*; raise to refuse it.

        The interactive implementation signals refusal with ToolExecutionRequestDeniedError.
        """

        raise NotImplementedError
26
+
27
+
28
+ ##
29
+
30
+
31
class InteractiveToolExecutionConfirmation(ToolExecutionConfirmation):
    """Asks the user, via `confirm_action`, before allowing a tool to run."""

    async def confirm_tool_execution_or_raise(
            self,
            use: mc.ToolUse,
            entry: mc.ToolCatalogEntry,
    ) -> None:
        # Summary of the request shown to the user; key order fixes the order in the pretty-printed JSON.
        request = {
            'id': use.id,
            'name': entry.spec.name,
            'args': use.args,
        }
        prompt = f'Execute requested tool?\n\n{json.dumps_pretty(request)}'

        # FIXME: `confirm_action` appears to be synchronous, so it blocks the event loop while waiting.
        if not confirm_action(prompt):
            raise ToolExecutionRequestDeniedError