ommlds 0.0.0.dev448__py3-none-any.whl → 0.0.0.dev450__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ommlds/.omlish-manifests.json +1 -1
- ommlds/backends/google/protocol/__init__.py +3 -0
- ommlds/backends/google/protocol/_marshal.py +16 -0
- ommlds/backends/google/protocol/types.py +303 -76
- ommlds/backends/mlx/generation.py +1 -1
- ommlds/cli/main.py +27 -6
- ommlds/cli/sessions/chat/code.py +114 -0
- ommlds/cli/sessions/chat/interactive.py +2 -5
- ommlds/cli/sessions/chat/printing.py +1 -4
- ommlds/cli/sessions/chat/prompt.py +8 -1
- ommlds/cli/sessions/chat/state.py +1 -0
- ommlds/cli/sessions/chat/tools.py +17 -7
- ommlds/cli/tools/config.py +1 -0
- ommlds/cli/tools/inject.py +11 -3
- ommlds/minichain/__init__.py +4 -0
- ommlds/minichain/backends/impls/google/chat.py +66 -11
- ommlds/minichain/backends/impls/google/tools.py +149 -0
- ommlds/minichain/lib/code/prompts.py +6 -0
- ommlds/minichain/lib/fs/binfiles.py +108 -0
- ommlds/minichain/lib/fs/context.py +112 -0
- ommlds/minichain/lib/fs/errors.py +95 -0
- ommlds/minichain/lib/fs/suggestions.py +36 -0
- ommlds/minichain/lib/fs/tools/__init__.py +0 -0
- ommlds/minichain/lib/fs/tools/ls.py +38 -0
- ommlds/minichain/lib/fs/tools/read.py +115 -0
- ommlds/minichain/lib/fs/tools/recursivels/__init__.py +0 -0
- ommlds/minichain/lib/fs/tools/recursivels/execution.py +40 -0
- ommlds/minichain/lib/todo/__init__.py +0 -0
- ommlds/minichain/lib/todo/context.py +27 -0
- ommlds/minichain/lib/todo/tools/__init__.py +0 -0
- ommlds/minichain/lib/todo/tools/read.py +39 -0
- ommlds/minichain/lib/todo/tools/write.py +275 -0
- ommlds/minichain/lib/todo/types.py +55 -0
- ommlds/minichain/tools/execution/context.py +34 -14
- ommlds/minichain/tools/execution/errors.py +15 -0
- ommlds/minichain/tools/execution/reflect.py +0 -3
- ommlds/minichain/tools/jsonschema.py +11 -1
- ommlds/minichain/tools/reflect.py +47 -15
- ommlds/minichain/tools/types.py +9 -0
- ommlds/minichain/utils.py +27 -0
- {ommlds-0.0.0.dev448.dist-info → ommlds-0.0.0.dev450.dist-info}/METADATA +3 -3
- {ommlds-0.0.0.dev448.dist-info → ommlds-0.0.0.dev450.dist-info}/RECORD +49 -29
- ommlds/minichain/lib/fs/ls/execution.py +0 -32
- /ommlds/minichain/lib/{fs/ls → code}/__init__.py +0 -0
- /ommlds/minichain/lib/fs/{ls → tools/recursivels}/rendering.py +0 -0
- /ommlds/minichain/lib/fs/{ls → tools/recursivels}/running.py +0 -0
- {ommlds-0.0.0.dev448.dist-info → ommlds-0.0.0.dev450.dist-info}/WHEEL +0 -0
- {ommlds-0.0.0.dev448.dist-info → ommlds-0.0.0.dev450.dist-info}/entry_points.txt +0 -0
- {ommlds-0.0.0.dev448.dist-info → ommlds-0.0.0.dev450.dist-info}/licenses/LICENSE +0 -0
- {ommlds-0.0.0.dev448.dist-info → ommlds-0.0.0.dev450.dist-info}/top_level.txt +0 -0
ommlds/cli/main.py
CHANGED
|
@@ -24,6 +24,7 @@ from omlish.subprocesses.sync import subprocesses
|
|
|
24
24
|
from .. import minichain as mc
|
|
25
25
|
from .inject import bind_main
|
|
26
26
|
from .sessions.base import Session
|
|
27
|
+
from .sessions.chat.code import CodeChatSession
|
|
27
28
|
from .sessions.chat.interactive import InteractiveChatSession
|
|
28
29
|
from .sessions.chat.prompt import PromptChatSession
|
|
29
30
|
from .sessions.completion.completion import CompletionSession
|
|
@@ -58,6 +59,7 @@ async def _a_main(args: ta.Any = None) -> None:
|
|
|
58
59
|
|
|
59
60
|
parser.add_argument('-e', '--editor', action='store_true')
|
|
60
61
|
parser.add_argument('-i', '--interactive', action='store_true')
|
|
62
|
+
parser.add_argument('-c', '--code', action='store_true')
|
|
61
63
|
parser.add_argument('-s', '--stream', action='store_true')
|
|
62
64
|
parser.add_argument('-M', '--markdown', action='store_true')
|
|
63
65
|
|
|
@@ -65,6 +67,7 @@ async def _a_main(args: ta.Any = None) -> None:
|
|
|
65
67
|
parser.add_argument('-j', '--image', action='store_true')
|
|
66
68
|
|
|
67
69
|
parser.add_argument('--enable-fs-tools', action='store_true')
|
|
70
|
+
parser.add_argument('--enable-todo-tools', action='store_true')
|
|
68
71
|
parser.add_argument('--enable-unsafe-bash-tool', action='store_true')
|
|
69
72
|
parser.add_argument('--enable-test-weather-tool', action='store_true')
|
|
70
73
|
parser.add_argument('--dangerous-no-tool-confirmation', action='store_true')
|
|
@@ -73,7 +76,7 @@ async def _a_main(args: ta.Any = None) -> None:
|
|
|
73
76
|
|
|
74
77
|
#
|
|
75
78
|
|
|
76
|
-
content: mc.Content
|
|
79
|
+
content: mc.Content | None
|
|
77
80
|
|
|
78
81
|
if args.image:
|
|
79
82
|
content = mc.ImageContent(pimg.open(check.non_empty_str(check.single(args.prompt))))
|
|
@@ -88,12 +91,19 @@ async def _a_main(args: ta.Any = None) -> None:
|
|
|
88
91
|
if args.prompt:
|
|
89
92
|
raise ValueError('Must not provide prompt')
|
|
90
93
|
|
|
94
|
+
elif args.code:
|
|
95
|
+
if args.prompt:
|
|
96
|
+
content = ' '.join(args.prompt)
|
|
97
|
+
else:
|
|
98
|
+
content = None
|
|
99
|
+
|
|
91
100
|
elif not args.prompt:
|
|
92
101
|
raise ValueError('Must provide prompt')
|
|
93
102
|
|
|
94
103
|
else:
|
|
95
104
|
prompt = ' '.join(args.prompt)
|
|
96
105
|
|
|
106
|
+
# FIXME: ptk / maysync
|
|
97
107
|
if not sys.stdin.isatty() and not pycharm.is_pycharm_hosted():
|
|
98
108
|
stdin_data = sys.stdin.read()
|
|
99
109
|
prompt = '\n'.join([prompt, stdin_data])
|
|
@@ -126,21 +136,31 @@ async def _a_main(args: ta.Any = None) -> None:
|
|
|
126
136
|
dangerous_no_tool_confirmation=bool(args.dangerous_no_tool_confirmation),
|
|
127
137
|
)
|
|
128
138
|
|
|
139
|
+
elif args.code:
|
|
140
|
+
session_cfg = CodeChatSession.Config(
|
|
141
|
+
backend=args.backend,
|
|
142
|
+
model_name=args.model_name,
|
|
143
|
+
new=bool(args.new),
|
|
144
|
+
dangerous_no_tool_confirmation=bool(args.dangerous_no_tool_confirmation),
|
|
145
|
+
initial_message=content, # noqa
|
|
146
|
+
markdown=bool(args.markdown),
|
|
147
|
+
)
|
|
148
|
+
|
|
129
149
|
elif args.embed:
|
|
130
150
|
session_cfg = EmbeddingSession.Config(
|
|
131
|
-
content, # noqa
|
|
151
|
+
check.not_none(content), # noqa
|
|
132
152
|
backend=args.backend,
|
|
133
153
|
)
|
|
134
154
|
|
|
135
155
|
elif args.completion:
|
|
136
156
|
session_cfg = CompletionSession.Config(
|
|
137
|
-
content, # noqa
|
|
157
|
+
check.not_none(content), # noqa
|
|
138
158
|
backend=args.backend,
|
|
139
159
|
)
|
|
140
160
|
|
|
141
161
|
else:
|
|
142
162
|
session_cfg = PromptChatSession.Config(
|
|
143
|
-
content, # noqa
|
|
163
|
+
check.not_none(content), # noqa
|
|
144
164
|
backend=args.backend,
|
|
145
165
|
model_name=args.model_name,
|
|
146
166
|
new=bool(args.new),
|
|
@@ -152,7 +172,8 @@ async def _a_main(args: ta.Any = None) -> None:
|
|
|
152
172
|
#
|
|
153
173
|
|
|
154
174
|
tools_config = ToolsConfig(
|
|
155
|
-
enable_fs_tools=args.enable_fs_tools,
|
|
175
|
+
enable_fs_tools=args.enable_fs_tools or args.code,
|
|
176
|
+
enable_todo_tools=args.enable_todo_tools or args.code,
|
|
156
177
|
enable_unsafe_bash_tool=args.enable_unsafe_bash_tool,
|
|
157
178
|
enable_test_weather_tool=args.enable_test_weather_tool,
|
|
158
179
|
)
|
|
@@ -162,7 +183,7 @@ async def _a_main(args: ta.Any = None) -> None:
|
|
|
162
183
|
with inj.create_managed_injector(bind_main(
|
|
163
184
|
session_cfg=session_cfg,
|
|
164
185
|
tools_config=tools_config,
|
|
165
|
-
enable_backend_strings=isinstance(session_cfg, PromptChatSession.Config),
|
|
186
|
+
enable_backend_strings=isinstance(session_cfg, (PromptChatSession.Config, CodeChatSession.Config)),
|
|
166
187
|
)) as injector:
|
|
167
188
|
await injector[Session].run()
|
|
168
189
|
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
import dataclasses as dc
|
|
2
|
+
import itertools
|
|
3
|
+
import os.path
|
|
4
|
+
|
|
5
|
+
from omlish import check
|
|
6
|
+
from omlish import lang
|
|
7
|
+
|
|
8
|
+
from .... import minichain as mc
|
|
9
|
+
from ....minichain.lib.code.prompts import CODE_AGENT_SYSTEM_PROMPT
|
|
10
|
+
from .base import DEFAULT_CHAT_MODEL_BACKEND
|
|
11
|
+
from .base import ChatOptions
|
|
12
|
+
from .base import ChatSession
|
|
13
|
+
from .printing import ChatSessionPrinter
|
|
14
|
+
from .state import ChatStateManager
|
|
15
|
+
from .tools import ToolExecRequestExecutor
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
with lang.auto_proxy_import(globals()):
|
|
19
|
+
from omdev import ptk
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
##
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class CodeChatSession(ChatSession['CodeChatSession.Config']):
|
|
26
|
+
@dc.dataclass(frozen=True)
|
|
27
|
+
class Config(ChatSession.Config):
|
|
28
|
+
_: dc.KW_ONLY
|
|
29
|
+
|
|
30
|
+
new: bool = False
|
|
31
|
+
|
|
32
|
+
backend: str | None = None
|
|
33
|
+
model_name: str | None = None
|
|
34
|
+
|
|
35
|
+
initial_message: mc.Content | None = None
|
|
36
|
+
|
|
37
|
+
def __init__(
|
|
38
|
+
self,
|
|
39
|
+
config: Config,
|
|
40
|
+
*,
|
|
41
|
+
state_manager: ChatStateManager,
|
|
42
|
+
chat_options: ChatOptions | None = None,
|
|
43
|
+
printer: ChatSessionPrinter,
|
|
44
|
+
backend_catalog: mc.BackendCatalog,
|
|
45
|
+
tool_exec_request_executor: ToolExecRequestExecutor,
|
|
46
|
+
) -> None:
|
|
47
|
+
super().__init__(config)
|
|
48
|
+
|
|
49
|
+
self._state_manager = state_manager
|
|
50
|
+
self._chat_options = chat_options
|
|
51
|
+
self._printer = printer
|
|
52
|
+
self._backend_catalog = backend_catalog
|
|
53
|
+
self._tool_exec_request_executor = tool_exec_request_executor
|
|
54
|
+
|
|
55
|
+
async def run(self) -> None:
|
|
56
|
+
if self._config.new:
|
|
57
|
+
self._state_manager.clear_state()
|
|
58
|
+
state = self._state_manager.extend_chat([
|
|
59
|
+
mc.SystemMessage(CODE_AGENT_SYSTEM_PROMPT),
|
|
60
|
+
])
|
|
61
|
+
|
|
62
|
+
else:
|
|
63
|
+
state = self._state_manager.get_state()
|
|
64
|
+
|
|
65
|
+
backend = self._config.backend
|
|
66
|
+
if backend is None:
|
|
67
|
+
backend = DEFAULT_CHAT_MODEL_BACKEND
|
|
68
|
+
|
|
69
|
+
# FIXME: lol
|
|
70
|
+
from ....minichain.lib.fs.context import FsContext
|
|
71
|
+
fs_tool_context = FsContext(root_dir=os.getcwd())
|
|
72
|
+
from ....minichain.lib.todo.context import TodoContext
|
|
73
|
+
todo_tool_context = TodoContext()
|
|
74
|
+
|
|
75
|
+
mdl: mc.ChatChoicesService
|
|
76
|
+
async with lang.async_maybe_managing(self._backend_catalog.get_backend(
|
|
77
|
+
mc.ChatChoicesService,
|
|
78
|
+
backend,
|
|
79
|
+
*([mc.ModelName(mn)] if (mn := self._config.model_name) is not None else []),
|
|
80
|
+
)) as mdl:
|
|
81
|
+
for i in itertools.count():
|
|
82
|
+
if not i and self._config.initial_message is not None:
|
|
83
|
+
req_msg = mc.UserMessage(self._config.initial_message)
|
|
84
|
+
else:
|
|
85
|
+
prompt = await ptk.prompt('> ')
|
|
86
|
+
req_msg = mc.UserMessage(prompt)
|
|
87
|
+
|
|
88
|
+
state = self._state_manager.extend_chat([req_msg])
|
|
89
|
+
|
|
90
|
+
while True:
|
|
91
|
+
response = await mdl.invoke(mc.ChatChoicesRequest(
|
|
92
|
+
state.chat,
|
|
93
|
+
(self._chat_options or []),
|
|
94
|
+
))
|
|
95
|
+
resp_msg = check.single(response.v).m
|
|
96
|
+
|
|
97
|
+
self._printer.print(resp_msg)
|
|
98
|
+
state = self._state_manager.extend_chat([resp_msg])
|
|
99
|
+
|
|
100
|
+
if not (trs := resp_msg.tool_exec_requests):
|
|
101
|
+
break
|
|
102
|
+
|
|
103
|
+
tool_resp_lst = []
|
|
104
|
+
for tr in trs:
|
|
105
|
+
trm = await self._tool_exec_request_executor.execute_tool_request(
|
|
106
|
+
tr,
|
|
107
|
+
fs_tool_context,
|
|
108
|
+
todo_tool_context,
|
|
109
|
+
)
|
|
110
|
+
|
|
111
|
+
self._printer.print(trm.c)
|
|
112
|
+
tool_resp_lst.append(trm)
|
|
113
|
+
|
|
114
|
+
state = self._state_manager.extend_chat(tool_resp_lst)
|
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import dataclasses as dc
|
|
2
|
-
import typing as ta
|
|
3
2
|
|
|
4
3
|
from omlish import lang
|
|
5
4
|
|
|
@@ -10,10 +9,8 @@ from .printing import ChatSessionPrinter
|
|
|
10
9
|
from .state import ChatStateManager
|
|
11
10
|
|
|
12
11
|
|
|
13
|
-
|
|
12
|
+
with lang.auto_proxy_import(globals()):
|
|
14
13
|
from omdev import ptk
|
|
15
|
-
else:
|
|
16
|
-
ptk = lang.proxy_import('omdev.ptk')
|
|
17
14
|
|
|
18
15
|
|
|
19
16
|
##
|
|
@@ -60,7 +57,7 @@ class InteractiveChatSession(ChatSession['InteractiveChatSession.Config']):
|
|
|
60
57
|
*([mc.ModelName(mn)] if (mn := self._config.model_name) is not None else []),
|
|
61
58
|
)) as mdl:
|
|
62
59
|
while True:
|
|
63
|
-
prompt = ptk.prompt('> ')
|
|
60
|
+
prompt = await ptk.prompt('> ')
|
|
64
61
|
|
|
65
62
|
req_msg = mc.UserMessage(prompt)
|
|
66
63
|
|
|
@@ -7,12 +7,9 @@ from omlish import lang
|
|
|
7
7
|
from .... import minichain as mc
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
|
|
10
|
+
with lang.auto_proxy_import(globals()):
|
|
11
11
|
from omdev import ptk
|
|
12
12
|
from omdev.ptk import markdown as ptk_md
|
|
13
|
-
else:
|
|
14
|
-
ptk = lang.proxy_import('omdev.ptk')
|
|
15
|
-
ptk_md = lang.proxy_import('omdev.ptk.markdown')
|
|
16
13
|
|
|
17
14
|
|
|
18
15
|
##
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import dataclasses as dc
|
|
2
|
+
import os
|
|
2
3
|
|
|
3
4
|
from omlish import check
|
|
4
5
|
from omlish import lang
|
|
@@ -125,7 +126,13 @@ class PromptChatSession(ChatSession['PromptChatSession.Config']):
|
|
|
125
126
|
|
|
126
127
|
tr: mc.ToolExecRequest = check.single(check.not_none(trs))
|
|
127
128
|
|
|
128
|
-
|
|
129
|
+
# FIXME: lol
|
|
130
|
+
from ....minichain.lib.fs.context import FsContext
|
|
131
|
+
|
|
132
|
+
trm = await self._tool_exec_request_executor.execute_tool_request(
|
|
133
|
+
tr,
|
|
134
|
+
FsContext(root_dir=os.getcwd()),
|
|
135
|
+
)
|
|
129
136
|
|
|
130
137
|
print(trm.c)
|
|
131
138
|
new_chat.append(trm)
|
|
@@ -9,10 +9,8 @@ from omlish.formats import json
|
|
|
9
9
|
from .... import minichain as mc
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
|
|
12
|
+
with lang.auto_proxy_import(globals()):
|
|
13
13
|
from omdev import ptk
|
|
14
|
-
else:
|
|
15
|
-
ptk = lang.proxy_import('omdev.ptk')
|
|
16
14
|
|
|
17
15
|
|
|
18
16
|
##
|
|
@@ -49,8 +47,9 @@ class AskingToolExecutionConfirmation(ToolExecutionConfirmation):
|
|
|
49
47
|
) -> None:
|
|
50
48
|
tr_dct = dict(
|
|
51
49
|
id=tr.id,
|
|
52
|
-
|
|
50
|
+
name=tce.spec.name,
|
|
53
51
|
args=tr.args,
|
|
52
|
+
spec=msh.marshal(tce.spec),
|
|
54
53
|
)
|
|
55
54
|
cr = await ptk.strict_confirm(f'Execute requested tool?\n\n{json.dumps_pretty(tr_dct)}\n\n')
|
|
56
55
|
|
|
@@ -63,7 +62,11 @@ class AskingToolExecutionConfirmation(ToolExecutionConfirmation):
|
|
|
63
62
|
|
|
64
63
|
class ToolExecRequestExecutor(lang.Abstract):
|
|
65
64
|
@abc.abstractmethod
|
|
66
|
-
def execute_tool_request(
|
|
65
|
+
def execute_tool_request(
|
|
66
|
+
self,
|
|
67
|
+
tr: mc.ToolExecRequest,
|
|
68
|
+
*ctx_items: ta.Any,
|
|
69
|
+
) -> ta.Awaitable[mc.ToolExecResultMessage]:
|
|
67
70
|
raise NotImplementedError
|
|
68
71
|
|
|
69
72
|
|
|
@@ -79,13 +82,20 @@ class ToolExecRequestExecutorImpl(ToolExecRequestExecutor):
|
|
|
79
82
|
self._catalog = catalog
|
|
80
83
|
self._confirmation = confirmation
|
|
81
84
|
|
|
82
|
-
async def execute_tool_request(
|
|
85
|
+
async def execute_tool_request(
|
|
86
|
+
self,
|
|
87
|
+
tr: mc.ToolExecRequest,
|
|
88
|
+
*ctx_items: ta.Any,
|
|
89
|
+
) -> mc.ToolExecResultMessage:
|
|
83
90
|
tce = self._catalog.by_name[check.non_empty_str(tr.name)]
|
|
84
91
|
|
|
85
92
|
await self._confirmation.confirm_tool_execution_or_raise(tr, tce)
|
|
86
93
|
|
|
87
94
|
return await mc.execute_tool_request(
|
|
88
|
-
mc.ToolContext(
|
|
95
|
+
mc.ToolContext(
|
|
96
|
+
tr,
|
|
97
|
+
*ctx_items,
|
|
98
|
+
),
|
|
89
99
|
tce.executor(),
|
|
90
100
|
tr,
|
|
91
101
|
)
|
ommlds/cli/tools/config.py
CHANGED
ommlds/cli/tools/inject.py
CHANGED
|
@@ -47,13 +47,21 @@ def bind_tools(tools_config: ToolsConfig) -> inj.Elements:
|
|
|
47
47
|
#
|
|
48
48
|
|
|
49
49
|
if tools_config.enable_fs_tools:
|
|
50
|
-
from ...minichain.lib.fs.ls
|
|
51
|
-
|
|
50
|
+
from ...minichain.lib.fs.tools.ls import ls_tool
|
|
52
51
|
els.append(bind_tool(ls_tool()))
|
|
53
52
|
|
|
53
|
+
from ...minichain.lib.fs.tools.read import read_tool
|
|
54
|
+
els.append(bind_tool(read_tool()))
|
|
55
|
+
|
|
56
|
+
if tools_config.enable_todo_tools:
|
|
57
|
+
from ...minichain.lib.todo.tools.read import todo_read_tool
|
|
58
|
+
els.append(bind_tool(todo_read_tool()))
|
|
59
|
+
|
|
60
|
+
from ...minichain.lib.todo.tools.write import todo_write_tool
|
|
61
|
+
els.append(bind_tool(todo_write_tool()))
|
|
62
|
+
|
|
54
63
|
if tools_config.enable_unsafe_bash_tool:
|
|
55
64
|
from ...minichain.lib.bash import bash_tool
|
|
56
|
-
|
|
57
65
|
els.append(bind_tool(bash_tool()))
|
|
58
66
|
|
|
59
67
|
if tools_config.enable_test_weather_tool:
|
ommlds/minichain/__init__.py
CHANGED
|
@@ -17,10 +17,14 @@ from ....chat.choices.types import AiChoice
|
|
|
17
17
|
from ....chat.messages import AiMessage
|
|
18
18
|
from ....chat.messages import Message
|
|
19
19
|
from ....chat.messages import SystemMessage
|
|
20
|
+
from ....chat.messages import ToolExecResultMessage
|
|
20
21
|
from ....chat.messages import UserMessage
|
|
22
|
+
from ....chat.tools.types import Tool
|
|
21
23
|
from ....models.configs import ModelName
|
|
22
24
|
from ....standard import ApiKey
|
|
25
|
+
from ....tools.types import ToolExecRequest
|
|
23
26
|
from .names import MODEL_NAMES
|
|
27
|
+
from .tools import build_tool_spec_schema
|
|
24
28
|
|
|
25
29
|
|
|
26
30
|
##
|
|
@@ -54,7 +58,6 @@ class GoogleChatChoicesService:
|
|
|
54
58
|
BASE_URL: ta.ClassVar[str] = 'https://generativelanguage.googleapis.com/v1beta/models'
|
|
55
59
|
|
|
56
60
|
ROLES_MAP: ta.ClassVar[ta.Mapping[type[Message], str]] = {
|
|
57
|
-
SystemMessage: 'system',
|
|
58
61
|
UserMessage: 'user',
|
|
59
62
|
AiMessage: 'assistant',
|
|
60
63
|
}
|
|
@@ -65,16 +68,56 @@ class GoogleChatChoicesService:
|
|
|
65
68
|
) -> ChatChoicesResponse:
|
|
66
69
|
key = check.not_none(self._api_key).reveal()
|
|
67
70
|
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
+
g_sys_content: pt.Content | None = None
|
|
72
|
+
g_contents: list[pt.Content] = []
|
|
73
|
+
for i, m in enumerate(request.v):
|
|
74
|
+
if isinstance(m, SystemMessage):
|
|
75
|
+
check.arg(i == 0)
|
|
76
|
+
check.none(g_sys_content)
|
|
77
|
+
g_sys_content = pt.Content(
|
|
71
78
|
parts=[pt.Part(
|
|
72
79
|
text=check.not_none(self._get_msg_content(m)),
|
|
73
80
|
)],
|
|
74
|
-
role=self.ROLES_MAP[type(m)], # type: ignore[arg-type]
|
|
75
81
|
)
|
|
76
|
-
|
|
77
|
-
|
|
82
|
+
elif isinstance(m, ToolExecResultMessage):
|
|
83
|
+
tr_resp_val: pt.Value
|
|
84
|
+
if m.c is None:
|
|
85
|
+
tr_resp_val = pt.NullValue() # type: ignore[unreachable]
|
|
86
|
+
elif isinstance(m.c, str):
|
|
87
|
+
tr_resp_val = pt.StringValue(m.c)
|
|
88
|
+
else:
|
|
89
|
+
raise TypeError(m.c)
|
|
90
|
+
g_contents.append(pt.Content(
|
|
91
|
+
parts=[pt.Part(
|
|
92
|
+
function_response=pt.FunctionResponse(
|
|
93
|
+
id=m.id,
|
|
94
|
+
name=m.name,
|
|
95
|
+
response={
|
|
96
|
+
'value': tr_resp_val,
|
|
97
|
+
},
|
|
98
|
+
),
|
|
99
|
+
)],
|
|
100
|
+
))
|
|
101
|
+
else:
|
|
102
|
+
g_contents.append(pt.Content(
|
|
103
|
+
parts=[pt.Part(
|
|
104
|
+
text=check.not_none(self._get_msg_content(m)),
|
|
105
|
+
)],
|
|
106
|
+
role=self.ROLES_MAP[type(m)], # type: ignore[arg-type]
|
|
107
|
+
))
|
|
108
|
+
|
|
109
|
+
g_tools: list[pt.Tool] = []
|
|
110
|
+
with tv.TypedValues(*request.options).consume() as oc:
|
|
111
|
+
t: Tool
|
|
112
|
+
for t in oc.pop(Tool, []):
|
|
113
|
+
g_tools.append(pt.Tool(
|
|
114
|
+
function_declarations=[build_tool_spec_schema(t.spec)],
|
|
115
|
+
))
|
|
116
|
+
|
|
117
|
+
g_req = pt.GenerateContentRequest(
|
|
118
|
+
contents=g_contents or None,
|
|
119
|
+
tools=g_tools or None,
|
|
120
|
+
system_instruction=g_sys_content,
|
|
78
121
|
)
|
|
79
122
|
|
|
80
123
|
req_dct = msh.marshal(g_req)
|
|
@@ -92,7 +135,19 @@ class GoogleChatChoicesService:
|
|
|
92
135
|
|
|
93
136
|
g_resp = msh.unmarshal(resp_dct, pt.GenerateContentResponse)
|
|
94
137
|
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
138
|
+
ai_choices: list[AiChoice] = []
|
|
139
|
+
for c in g_resp.candidates or []:
|
|
140
|
+
g_resp_part = check.single(check.not_none(check.not_none(c.content).parts))
|
|
141
|
+
ter: ToolExecRequest | None = None
|
|
142
|
+
if (g_fc := g_resp_part.function_call) is not None:
|
|
143
|
+
ter = ToolExecRequest(
|
|
144
|
+
id=g_fc.id,
|
|
145
|
+
name=g_fc.name,
|
|
146
|
+
args=g_fc.args or {},
|
|
147
|
+
)
|
|
148
|
+
ai_choices.append(AiChoice(AiMessage(
|
|
149
|
+
c=g_resp_part.text,
|
|
150
|
+
tool_exec_requests=[ter] if ter is not None else None,
|
|
151
|
+
)))
|
|
152
|
+
|
|
153
|
+
return ChatChoicesResponse(ai_choices)
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
"""
|
|
2
|
+
https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models
|
|
3
|
+
"""
|
|
4
|
+
import typing as ta
|
|
5
|
+
|
|
6
|
+
from omlish import check
|
|
7
|
+
from omlish import dataclasses as dc
|
|
8
|
+
|
|
9
|
+
from .....backends.google.protocol import types as pt
|
|
10
|
+
from ....content.prepare import ContentStrPreparer
|
|
11
|
+
from ....content.prepare import default_content_str_preparer
|
|
12
|
+
from ....tools.types import EnumToolDtype
|
|
13
|
+
from ....tools.types import MappingToolDtype
|
|
14
|
+
from ....tools.types import NullableToolDtype
|
|
15
|
+
from ....tools.types import ObjectToolDtype
|
|
16
|
+
from ....tools.types import PrimitiveToolDtype
|
|
17
|
+
from ....tools.types import SequenceToolDtype
|
|
18
|
+
from ....tools.types import ToolDtype
|
|
19
|
+
from ....tools.types import ToolSpec
|
|
20
|
+
from ....tools.types import TupleToolDtype
|
|
21
|
+
from ....tools.types import UnionToolDtype
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
##
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _shallow_dc_asdict_not_none(o: ta.Any) -> dict[str, ta.Any]:
|
|
28
|
+
return {k: v for k, v in dc.shallow_asdict(o).items() if v is not None}
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
PT_TYPE_BY_PRIMITIVE_TYPE: ta.Mapping[str, pt.Type] = {
|
|
32
|
+
'string': 'STRING',
|
|
33
|
+
'number': 'NUMBER',
|
|
34
|
+
'integer': 'INTEGER',
|
|
35
|
+
'boolean': 'BOOLEAN',
|
|
36
|
+
'array': 'ARRAY',
|
|
37
|
+
'null': 'NULL',
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class ToolSchemaRenderer:
|
|
42
|
+
def __init__(
|
|
43
|
+
self,
|
|
44
|
+
*,
|
|
45
|
+
content_str_preparer: ContentStrPreparer | None = None,
|
|
46
|
+
) -> None:
|
|
47
|
+
super().__init__()
|
|
48
|
+
|
|
49
|
+
if content_str_preparer is None:
|
|
50
|
+
content_str_preparer = default_content_str_preparer()
|
|
51
|
+
self._content_str_preparer = content_str_preparer
|
|
52
|
+
|
|
53
|
+
def render_type(self, t: ToolDtype) -> pt.Schema:
|
|
54
|
+
if isinstance(t, PrimitiveToolDtype):
|
|
55
|
+
return pt.Schema(type=PT_TYPE_BY_PRIMITIVE_TYPE[t.type])
|
|
56
|
+
|
|
57
|
+
if isinstance(t, UnionToolDtype):
|
|
58
|
+
return pt.Schema(
|
|
59
|
+
any_of=[self.render_type(a) for a in t.args],
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
if isinstance(t, NullableToolDtype):
|
|
63
|
+
return pt.Schema(**{
|
|
64
|
+
**_shallow_dc_asdict_not_none(self.render_type(t.type)),
|
|
65
|
+
**dict(nullable=True),
|
|
66
|
+
})
|
|
67
|
+
|
|
68
|
+
if isinstance(t, SequenceToolDtype):
|
|
69
|
+
return pt.Schema(
|
|
70
|
+
type='ARRAY',
|
|
71
|
+
items=self.render_type(t.element),
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
if isinstance(t, MappingToolDtype):
|
|
75
|
+
# FIXME: t.key
|
|
76
|
+
# return {
|
|
77
|
+
# 'type': 'object',
|
|
78
|
+
# 'additionalProperties': self.render_type(t.value),
|
|
79
|
+
# }
|
|
80
|
+
raise NotImplementedError
|
|
81
|
+
|
|
82
|
+
if isinstance(t, TupleToolDtype):
|
|
83
|
+
# return {
|
|
84
|
+
# 'type': 'array',
|
|
85
|
+
# 'prefixItems': [self.render_type(e) for e in t.elements],
|
|
86
|
+
# }
|
|
87
|
+
raise NotImplementedError
|
|
88
|
+
|
|
89
|
+
if isinstance(t, EnumToolDtype):
|
|
90
|
+
return pt.Schema(**{
|
|
91
|
+
**_shallow_dc_asdict_not_none(self.render_type(t.type)),
|
|
92
|
+
**dict(enum=list(t.values)),
|
|
93
|
+
})
|
|
94
|
+
|
|
95
|
+
if isinstance(t, ObjectToolDtype):
|
|
96
|
+
return pt.Schema(
|
|
97
|
+
type='OBJECT',
|
|
98
|
+
properties={
|
|
99
|
+
k: self.render_type(v)
|
|
100
|
+
for k, v in t.fields.items()
|
|
101
|
+
},
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
raise TypeError(t)
|
|
105
|
+
|
|
106
|
+
def render_tool_params(self, ts: ToolSpec) -> pt.Schema:
|
|
107
|
+
pr_dct: dict[str, pt.Schema] | None = None
|
|
108
|
+
req_lst: list[str] | None = None
|
|
109
|
+
if ts.params is not None:
|
|
110
|
+
pr_dct = {}
|
|
111
|
+
req_lst = []
|
|
112
|
+
for p in ts.params or []:
|
|
113
|
+
pr_dct[check.non_empty_str(p.name)] = pt.Schema(**{
|
|
114
|
+
**(dict(description=self._content_str_preparer.prepare_str(p.desc)) if p.desc is not None else {}),
|
|
115
|
+
**(_shallow_dc_asdict_not_none(self.render_type(p.type)) if p.type is not None else {}),
|
|
116
|
+
})
|
|
117
|
+
if p.required:
|
|
118
|
+
req_lst.append(check.non_empty_str(p.name))
|
|
119
|
+
|
|
120
|
+
return pt.Schema(
|
|
121
|
+
type='OBJECT',
|
|
122
|
+
**(dict(properties=pr_dct) if pr_dct is not None else {}), # type: ignore[arg-type]
|
|
123
|
+
**(dict(required=req_lst) if req_lst is not None else {}), # type: ignore[arg-type]
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
def render_tool(self, ts: ToolSpec) -> pt.FunctionDeclaration:
|
|
127
|
+
ret_dct = {
|
|
128
|
+
**(dict(description=self._content_str_preparer.prepare_str(ts.returns_desc)) if ts.returns_desc is not None else {}), # noqa
|
|
129
|
+
**(_shallow_dc_asdict_not_none(self.render_type(ts.returns_type)) if ts.returns_type is not None else {}),
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
return pt.FunctionDeclaration(
|
|
133
|
+
name=check.non_empty_str(ts.name),
|
|
134
|
+
description=self._content_str_preparer.prepare_str(ts.desc) if ts.desc is not None else None, # type: ignore[arg-type] # noqa
|
|
135
|
+
behavior='BLOCKING',
|
|
136
|
+
parameters=self.render_tool_params(ts) if ts.params else None,
|
|
137
|
+
response=(pt.Schema(**ret_dct) if ret_dct else None),
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
##
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def build_tool_spec_schema(ts: ToolSpec) -> pt.FunctionDeclaration:
|
|
145
|
+
return ToolSchemaRenderer().render_tool(ts)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def build_tool_spec_params_schema(ts: ToolSpec) -> pt.Schema:
|
|
149
|
+
return ToolSchemaRenderer().render_tool_params(ts)
|