ommlds 0.0.0.dev449__py3-none-any.whl → 0.0.0.dev451__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ommlds might be problematic. Click here for more details.
- ommlds/.omlish-manifests.json +3 -3
- ommlds/backends/anthropic/protocol/_marshal.py +1 -1
- ommlds/backends/anthropic/protocol/sse/_marshal.py +1 -1
- ommlds/backends/anthropic/protocol/sse/assemble.py +1 -1
- ommlds/backends/anthropic/protocol/types.py +30 -9
- ommlds/backends/google/protocol/__init__.py +3 -0
- ommlds/backends/google/protocol/_marshal.py +16 -0
- ommlds/backends/google/protocol/types.py +303 -76
- ommlds/backends/mlx/generation.py +1 -1
- ommlds/backends/openai/protocol/_marshal.py +1 -1
- ommlds/cli/main.py +29 -8
- ommlds/cli/sessions/chat/code.py +124 -0
- ommlds/cli/sessions/chat/interactive.py +2 -5
- ommlds/cli/sessions/chat/printing.py +5 -4
- ommlds/cli/sessions/chat/prompt.py +2 -2
- ommlds/cli/sessions/chat/state.py +1 -0
- ommlds/cli/sessions/chat/tools.py +3 -5
- ommlds/cli/tools/config.py +2 -1
- ommlds/cli/tools/inject.py +13 -3
- ommlds/minichain/__init__.py +12 -0
- ommlds/minichain/_marshal.py +39 -0
- ommlds/minichain/backends/impls/anthropic/chat.py +78 -10
- ommlds/minichain/backends/impls/google/chat.py +95 -12
- ommlds/minichain/backends/impls/google/tools.py +149 -0
- ommlds/minichain/chat/_marshal.py +1 -1
- ommlds/minichain/content/_marshal.py +24 -3
- ommlds/minichain/content/json.py +13 -0
- ommlds/minichain/content/materialize.py +13 -20
- ommlds/minichain/content/prepare.py +4 -0
- ommlds/minichain/json.py +20 -0
- ommlds/minichain/lib/code/prompts.py +6 -0
- ommlds/minichain/lib/fs/context.py +18 -4
- ommlds/minichain/lib/fs/errors.py +6 -0
- ommlds/minichain/lib/fs/tools/edit.py +104 -0
- ommlds/minichain/lib/fs/{catalog → tools}/ls.py +3 -3
- ommlds/minichain/lib/fs/{catalog → tools}/read.py +6 -6
- ommlds/minichain/lib/fs/tools/recursivels/__init__.py +0 -0
- ommlds/minichain/lib/fs/{catalog → tools}/recursivels/execution.py +2 -2
- ommlds/minichain/lib/todo/__init__.py +0 -0
- ommlds/minichain/lib/todo/context.py +54 -0
- ommlds/minichain/lib/todo/tools/__init__.py +0 -0
- ommlds/minichain/lib/todo/tools/read.py +44 -0
- ommlds/minichain/lib/todo/tools/write.py +335 -0
- ommlds/minichain/lib/todo/types.py +60 -0
- ommlds/minichain/llms/_marshal.py +1 -1
- ommlds/minichain/services/_marshal.py +1 -1
- ommlds/minichain/tools/_marshal.py +1 -1
- ommlds/minichain/tools/execution/catalog.py +2 -1
- ommlds/minichain/tools/execution/executors.py +8 -3
- ommlds/minichain/tools/execution/reflect.py +43 -5
- ommlds/minichain/tools/fns.py +46 -9
- ommlds/minichain/tools/jsonschema.py +11 -1
- ommlds/minichain/tools/reflect.py +9 -2
- ommlds/minichain/tools/types.py +9 -0
- ommlds/minichain/utils.py +27 -0
- ommlds/minichain/vectors/_marshal.py +1 -1
- ommlds/tools/ocr.py +7 -1
- {ommlds-0.0.0.dev449.dist-info → ommlds-0.0.0.dev451.dist-info}/METADATA +3 -3
- {ommlds-0.0.0.dev449.dist-info → ommlds-0.0.0.dev451.dist-info}/RECORD +67 -53
- /ommlds/minichain/lib/{fs/catalog → code}/__init__.py +0 -0
- /ommlds/minichain/lib/fs/{catalog/recursivels → tools}/__init__.py +0 -0
- /ommlds/minichain/lib/fs/{catalog → tools}/recursivels/rendering.py +0 -0
- /ommlds/minichain/lib/fs/{catalog → tools}/recursivels/running.py +0 -0
- {ommlds-0.0.0.dev449.dist-info → ommlds-0.0.0.dev451.dist-info}/WHEEL +0 -0
- {ommlds-0.0.0.dev449.dist-info → ommlds-0.0.0.dev451.dist-info}/entry_points.txt +0 -0
- {ommlds-0.0.0.dev449.dist-info → ommlds-0.0.0.dev451.dist-info}/licenses/LICENSE +0 -0
- {ommlds-0.0.0.dev449.dist-info → ommlds-0.0.0.dev451.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
import dataclasses as dc
|
|
2
|
+
import itertools
|
|
3
|
+
import os.path
|
|
4
|
+
|
|
5
|
+
from omlish import check
|
|
6
|
+
from omlish import lang
|
|
7
|
+
|
|
8
|
+
from .... import minichain as mc
|
|
9
|
+
from ....minichain.lib.code.prompts import CODE_AGENT_SYSTEM_PROMPT
|
|
10
|
+
from ...tools.config import ToolsConfig
|
|
11
|
+
from .base import DEFAULT_CHAT_MODEL_BACKEND
|
|
12
|
+
from .base import ChatOptions
|
|
13
|
+
from .base import ChatSession
|
|
14
|
+
from .printing import ChatSessionPrinter
|
|
15
|
+
from .state import ChatStateManager
|
|
16
|
+
from .tools import ToolExecRequestExecutor
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
with lang.auto_proxy_import(globals()):
|
|
20
|
+
from omdev import ptk
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
##
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class CodeChatSession(ChatSession['CodeChatSession.Config']):
|
|
27
|
+
@dc.dataclass(frozen=True)
|
|
28
|
+
class Config(ChatSession.Config):
|
|
29
|
+
_: dc.KW_ONLY
|
|
30
|
+
|
|
31
|
+
new: bool = False
|
|
32
|
+
|
|
33
|
+
backend: str | None = None
|
|
34
|
+
model_name: str | None = None
|
|
35
|
+
|
|
36
|
+
initial_message: mc.Content | None = None
|
|
37
|
+
|
|
38
|
+
def __init__(
|
|
39
|
+
self,
|
|
40
|
+
config: Config,
|
|
41
|
+
*,
|
|
42
|
+
state_manager: ChatStateManager,
|
|
43
|
+
chat_options: ChatOptions | None = None,
|
|
44
|
+
printer: ChatSessionPrinter,
|
|
45
|
+
backend_catalog: mc.BackendCatalog,
|
|
46
|
+
tool_exec_request_executor: ToolExecRequestExecutor,
|
|
47
|
+
tools_config: ToolsConfig | None = None,
|
|
48
|
+
) -> None:
|
|
49
|
+
super().__init__(config)
|
|
50
|
+
|
|
51
|
+
self._state_manager = state_manager
|
|
52
|
+
self._chat_options = chat_options
|
|
53
|
+
self._printer = printer
|
|
54
|
+
self._backend_catalog = backend_catalog
|
|
55
|
+
self._tool_exec_request_executor = tool_exec_request_executor
|
|
56
|
+
self._tools_config = tools_config
|
|
57
|
+
|
|
58
|
+
async def run(self) -> None:
|
|
59
|
+
if self._config.new:
|
|
60
|
+
self._state_manager.clear_state()
|
|
61
|
+
state = self._state_manager.extend_chat([
|
|
62
|
+
mc.SystemMessage(CODE_AGENT_SYSTEM_PROMPT),
|
|
63
|
+
])
|
|
64
|
+
|
|
65
|
+
else:
|
|
66
|
+
state = self._state_manager.get_state()
|
|
67
|
+
|
|
68
|
+
backend = self._config.backend
|
|
69
|
+
if backend is None:
|
|
70
|
+
backend = DEFAULT_CHAT_MODEL_BACKEND
|
|
71
|
+
|
|
72
|
+
# FIXME: lol
|
|
73
|
+
from ....minichain.lib.fs.context import FsContext
|
|
74
|
+
fs_tool_context = FsContext(
|
|
75
|
+
root_dir=os.getcwd(),
|
|
76
|
+
writes_permitted=self._tools_config is not None and self._tools_config.enable_unsafe_tools_do_not_use_lol,
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
from ....minichain.lib.todo.context import TodoContext
|
|
80
|
+
todo_tool_context = TodoContext()
|
|
81
|
+
|
|
82
|
+
mdl: mc.ChatChoicesService
|
|
83
|
+
async with lang.async_maybe_managing(self._backend_catalog.get_backend(
|
|
84
|
+
mc.ChatChoicesService,
|
|
85
|
+
backend,
|
|
86
|
+
*([mc.ModelName(mn)] if (mn := self._config.model_name) is not None else []),
|
|
87
|
+
)) as mdl:
|
|
88
|
+
for i in itertools.count():
|
|
89
|
+
if not i and self._config.initial_message is not None:
|
|
90
|
+
req_msg = mc.UserMessage(self._config.initial_message)
|
|
91
|
+
else:
|
|
92
|
+
try:
|
|
93
|
+
prompt = await ptk.prompt('> ')
|
|
94
|
+
except EOFError:
|
|
95
|
+
break
|
|
96
|
+
req_msg = mc.UserMessage(prompt)
|
|
97
|
+
|
|
98
|
+
state = self._state_manager.extend_chat([req_msg])
|
|
99
|
+
|
|
100
|
+
while True:
|
|
101
|
+
response = await mdl.invoke(mc.ChatChoicesRequest(
|
|
102
|
+
state.chat,
|
|
103
|
+
(self._chat_options or []),
|
|
104
|
+
))
|
|
105
|
+
resp_msg = check.single(response.v).m
|
|
106
|
+
|
|
107
|
+
self._printer.print(resp_msg)
|
|
108
|
+
state = self._state_manager.extend_chat([resp_msg])
|
|
109
|
+
|
|
110
|
+
if not (trs := resp_msg.tool_exec_requests):
|
|
111
|
+
break
|
|
112
|
+
|
|
113
|
+
tool_resp_lst = []
|
|
114
|
+
for tr in trs:
|
|
115
|
+
trm = await self._tool_exec_request_executor.execute_tool_request(
|
|
116
|
+
tr,
|
|
117
|
+
fs_tool_context,
|
|
118
|
+
todo_tool_context,
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
self._printer.print(trm.c)
|
|
122
|
+
tool_resp_lst.append(trm)
|
|
123
|
+
|
|
124
|
+
state = self._state_manager.extend_chat(tool_resp_lst)
|
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import dataclasses as dc
|
|
2
|
-
import typing as ta
|
|
3
2
|
|
|
4
3
|
from omlish import lang
|
|
5
4
|
|
|
@@ -10,10 +9,8 @@ from .printing import ChatSessionPrinter
|
|
|
10
9
|
from .state import ChatStateManager
|
|
11
10
|
|
|
12
11
|
|
|
13
|
-
|
|
12
|
+
with lang.auto_proxy_import(globals()):
|
|
14
13
|
from omdev import ptk
|
|
15
|
-
else:
|
|
16
|
-
ptk = lang.proxy_import('omdev.ptk')
|
|
17
14
|
|
|
18
15
|
|
|
19
16
|
##
|
|
@@ -60,7 +57,7 @@ class InteractiveChatSession(ChatSession['InteractiveChatSession.Config']):
|
|
|
60
57
|
*([mc.ModelName(mn)] if (mn := self._config.model_name) is not None else []),
|
|
61
58
|
)) as mdl:
|
|
62
59
|
while True:
|
|
63
|
-
prompt = ptk.prompt('> ')
|
|
60
|
+
prompt = await ptk.prompt('> ')
|
|
64
61
|
|
|
65
62
|
req_msg = mc.UserMessage(prompt)
|
|
66
63
|
|
|
@@ -3,16 +3,14 @@ import typing as ta
|
|
|
3
3
|
|
|
4
4
|
from omlish import check
|
|
5
5
|
from omlish import lang
|
|
6
|
+
from omlish.formats import json
|
|
6
7
|
|
|
7
8
|
from .... import minichain as mc
|
|
8
9
|
|
|
9
10
|
|
|
10
|
-
|
|
11
|
+
with lang.auto_proxy_import(globals()):
|
|
11
12
|
from omdev import ptk
|
|
12
13
|
from omdev.ptk import markdown as ptk_md
|
|
13
|
-
else:
|
|
14
|
-
ptk = lang.proxy_import('omdev.ptk')
|
|
15
|
-
ptk_md = lang.proxy_import('omdev.ptk.markdown')
|
|
16
14
|
|
|
17
15
|
|
|
18
16
|
##
|
|
@@ -51,6 +49,9 @@ class StringChatSessionPrinter(ChatSessionPrinter, lang.Abstract):
|
|
|
51
49
|
else:
|
|
52
50
|
raise TypeError(obj)
|
|
53
51
|
|
|
52
|
+
elif isinstance(obj, mc.JsonContent):
|
|
53
|
+
self._print_str(json.dumps_pretty(obj.v))
|
|
54
|
+
|
|
54
55
|
elif isinstance(obj, str):
|
|
55
56
|
self._print_str(obj)
|
|
56
57
|
|
|
@@ -127,11 +127,11 @@ class PromptChatSession(ChatSession['PromptChatSession.Config']):
|
|
|
127
127
|
tr: mc.ToolExecRequest = check.single(check.not_none(trs))
|
|
128
128
|
|
|
129
129
|
# FIXME: lol
|
|
130
|
-
from ....minichain.lib.fs.context import
|
|
130
|
+
from ....minichain.lib.fs.context import FsContext
|
|
131
131
|
|
|
132
132
|
trm = await self._tool_exec_request_executor.execute_tool_request(
|
|
133
133
|
tr,
|
|
134
|
-
|
|
134
|
+
FsContext(root_dir=os.getcwd()),
|
|
135
135
|
)
|
|
136
136
|
|
|
137
137
|
print(trm.c)
|
|
@@ -3,16 +3,13 @@ import typing as ta
|
|
|
3
3
|
|
|
4
4
|
from omlish import check
|
|
5
5
|
from omlish import lang
|
|
6
|
-
from omlish import marshal as msh
|
|
7
6
|
from omlish.formats import json
|
|
8
7
|
|
|
9
8
|
from .... import minichain as mc
|
|
10
9
|
|
|
11
10
|
|
|
12
|
-
|
|
11
|
+
with lang.auto_proxy_import(globals()):
|
|
13
12
|
from omdev import ptk
|
|
14
|
-
else:
|
|
15
|
-
ptk = lang.proxy_import('omdev.ptk')
|
|
16
13
|
|
|
17
14
|
|
|
18
15
|
##
|
|
@@ -49,8 +46,9 @@ class AskingToolExecutionConfirmation(ToolExecutionConfirmation):
|
|
|
49
46
|
) -> None:
|
|
50
47
|
tr_dct = dict(
|
|
51
48
|
id=tr.id,
|
|
52
|
-
|
|
49
|
+
name=tce.spec.name,
|
|
53
50
|
args=tr.args,
|
|
51
|
+
# spec=msh.marshal(tce.spec),
|
|
54
52
|
)
|
|
55
53
|
cr = await ptk.strict_confirm(f'Execute requested tool?\n\n{json.dumps_pretty(tr_dct)}\n\n')
|
|
56
54
|
|
ommlds/cli/tools/config.py
CHANGED
|
@@ -7,7 +7,8 @@ from omlish import dataclasses as dc
|
|
|
7
7
|
@dc.dataclass(frozen=True, kw_only=True)
|
|
8
8
|
class ToolsConfig:
|
|
9
9
|
enable_fs_tools: bool = False
|
|
10
|
+
enable_todo_tools: bool = False
|
|
10
11
|
|
|
11
|
-
|
|
12
|
+
enable_unsafe_tools_do_not_use_lol: bool = False
|
|
12
13
|
|
|
13
14
|
enable_test_weather_tool: bool = False
|
ommlds/cli/tools/inject.py
CHANGED
|
@@ -47,16 +47,26 @@ def bind_tools(tools_config: ToolsConfig) -> inj.Elements:
|
|
|
47
47
|
#
|
|
48
48
|
|
|
49
49
|
if tools_config.enable_fs_tools:
|
|
50
|
-
from ...minichain.lib.fs.
|
|
50
|
+
from ...minichain.lib.fs.tools.ls import ls_tool
|
|
51
51
|
els.append(bind_tool(ls_tool()))
|
|
52
52
|
|
|
53
|
-
from ...minichain.lib.fs.
|
|
53
|
+
from ...minichain.lib.fs.tools.read import read_tool
|
|
54
54
|
els.append(bind_tool(read_tool()))
|
|
55
55
|
|
|
56
|
-
if tools_config.
|
|
56
|
+
if tools_config.enable_todo_tools:
|
|
57
|
+
from ...minichain.lib.todo.tools.read import todo_read_tool
|
|
58
|
+
els.append(bind_tool(todo_read_tool()))
|
|
59
|
+
|
|
60
|
+
from ...minichain.lib.todo.tools.write import todo_write_tool
|
|
61
|
+
els.append(bind_tool(todo_write_tool()))
|
|
62
|
+
|
|
63
|
+
if tools_config.enable_unsafe_tools_do_not_use_lol:
|
|
57
64
|
from ...minichain.lib.bash import bash_tool
|
|
58
65
|
els.append(bind_tool(bash_tool()))
|
|
59
66
|
|
|
67
|
+
from ...minichain.lib.fs.tools.edit import edit_tool
|
|
68
|
+
els.append(bind_tool(edit_tool()))
|
|
69
|
+
|
|
60
70
|
if tools_config.enable_test_weather_tool:
|
|
61
71
|
els.append(bind_tool(WEATHER_TOOL))
|
|
62
72
|
|
ommlds/minichain/__init__.py
CHANGED
|
@@ -197,8 +197,14 @@ with _lang.auto_proxy_init(
|
|
|
197
197
|
ImageContent,
|
|
198
198
|
)
|
|
199
199
|
|
|
200
|
+
from .content.json import ( # noqa
|
|
201
|
+
JsonContent,
|
|
202
|
+
)
|
|
203
|
+
|
|
200
204
|
from .content.materialize import ( # noqa
|
|
201
205
|
CanContent,
|
|
206
|
+
|
|
207
|
+
materialize_content,
|
|
202
208
|
)
|
|
203
209
|
|
|
204
210
|
from .content.metadata import ( # noqa
|
|
@@ -493,6 +499,12 @@ with _lang.auto_proxy_init(
|
|
|
493
499
|
EnvKey,
|
|
494
500
|
)
|
|
495
501
|
|
|
502
|
+
from .json import ( # noqa
|
|
503
|
+
JsonSchema,
|
|
504
|
+
|
|
505
|
+
JsonValue,
|
|
506
|
+
)
|
|
507
|
+
|
|
496
508
|
from .metadata import ( # noqa
|
|
497
509
|
Metadata,
|
|
498
510
|
|
ommlds/minichain/_marshal.py
CHANGED
|
@@ -1,10 +1,13 @@
|
|
|
1
1
|
from omlish import dataclasses as dc
|
|
2
|
+
from omlish import lang
|
|
2
3
|
from omlish import marshal as msh
|
|
3
4
|
from omlish import reflect as rfl
|
|
4
5
|
from omlish.funcs import match as mfs
|
|
5
6
|
from omlish.typedvalues.marshal import build_typed_values_marshaler
|
|
6
7
|
from omlish.typedvalues.marshal import build_typed_values_unmarshaler
|
|
7
8
|
|
|
9
|
+
from .json import JsonValue
|
|
10
|
+
|
|
8
11
|
|
|
9
12
|
##
|
|
10
13
|
|
|
@@ -25,3 +28,39 @@ class _TypedValuesFieldUnmarshalerFactory(msh.UnmarshalerFactoryMatchClass):
|
|
|
25
28
|
@mfs.simple(lambda _, ctx, rty: True)
|
|
26
29
|
def _build(self, ctx: msh.UnmarshalContext, rty: rfl.Type) -> msh.Unmarshaler:
|
|
27
30
|
return build_typed_values_unmarshaler(ctx, self.tvs_rty)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
##
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class MarshalJsonValue(lang.NotInstantiable, lang.Final):
|
|
37
|
+
pass
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class _JsonValueMarshalerFactory(msh.MarshalerFactoryMatchClass):
|
|
41
|
+
@mfs.simple(lambda _, ctx, rty: rty is MarshalJsonValue)
|
|
42
|
+
def _build(self, ctx: msh.MarshalContext, rty: rfl.Type) -> msh.Marshaler:
|
|
43
|
+
return msh.NopMarshalerUnmarshaler()
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class _JsonValueUnmarshalerFactory(msh.UnmarshalerFactoryMatchClass):
|
|
47
|
+
@mfs.simple(lambda _, ctx, rty: rty is MarshalJsonValue)
|
|
48
|
+
def _build(self, ctx: msh.UnmarshalContext, rty: rfl.Type) -> msh.Unmarshaler:
|
|
49
|
+
return msh.NopMarshalerUnmarshaler()
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
##
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@lang.static_init
|
|
56
|
+
def _install_standard_marshaling() -> None:
|
|
57
|
+
msh.register_global_config(
|
|
58
|
+
JsonValue,
|
|
59
|
+
msh.ReflectOverride(MarshalJsonValue),
|
|
60
|
+
identity=True,
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
msh.install_standard_factories(
|
|
64
|
+
_JsonValueMarshalerFactory(),
|
|
65
|
+
_JsonValueUnmarshalerFactory(),
|
|
66
|
+
)
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""
|
|
2
|
+
https://docs.claude.com/en/api/messages
|
|
2
3
|
https://github.com/anthropics/anthropic-sdk-python/tree/cd80d46f7a223a5493565d155da31b898a4c6ee5/src/anthropic/types
|
|
3
4
|
https://github.com/anthropics/anthropic-sdk-python/blob/cd80d46f7a223a5493565d155da31b898a4c6ee5/src/anthropic/resources/completions.py#L53
|
|
4
5
|
https://github.com/anthropics/anthropic-sdk-python/blob/cd80d46f7a223a5493565d155da31b898a4c6ee5/src/anthropic/resources/messages.py#L70
|
|
@@ -6,11 +7,12 @@ https://github.com/anthropics/anthropic-sdk-python/blob/cd80d46f7a223a5493565d15
|
|
|
6
7
|
import typing as ta
|
|
7
8
|
|
|
8
9
|
from omlish import check
|
|
9
|
-
from omlish import
|
|
10
|
+
from omlish import marshal as msh
|
|
10
11
|
from omlish import typedvalues as tv
|
|
11
12
|
from omlish.formats import json
|
|
12
13
|
from omlish.http import all as http
|
|
13
14
|
|
|
15
|
+
from .....backends.anthropic.protocol import types as pt
|
|
14
16
|
from ....chat.choices.services import ChatChoicesRequest
|
|
15
17
|
from ....chat.choices.services import ChatChoicesResponse
|
|
16
18
|
from ....chat.choices.services import static_check_is_chat_choices_service
|
|
@@ -18,9 +20,14 @@ from ....chat.choices.types import AiChoice
|
|
|
18
20
|
from ....chat.messages import AiMessage
|
|
19
21
|
from ....chat.messages import Message
|
|
20
22
|
from ....chat.messages import SystemMessage
|
|
23
|
+
from ....chat.messages import ToolExecResultMessage
|
|
21
24
|
from ....chat.messages import UserMessage
|
|
25
|
+
from ....chat.tools.types import Tool
|
|
26
|
+
from ....content.prepare import prepare_content_str
|
|
22
27
|
from ....models.configs import ModelName
|
|
23
28
|
from ....standard import ApiKey
|
|
29
|
+
from ....tools.jsonschema import build_tool_spec_params_json_schema
|
|
30
|
+
from ....tools.types import ToolExecRequest
|
|
24
31
|
from .names import MODEL_NAMES
|
|
25
32
|
|
|
26
33
|
|
|
@@ -68,26 +75,69 @@ class AnthropicChatChoicesService:
|
|
|
68
75
|
*,
|
|
69
76
|
max_tokens: int = 4096, # FIXME: ChatOption
|
|
70
77
|
) -> ChatChoicesResponse:
|
|
71
|
-
messages = []
|
|
72
|
-
system:
|
|
78
|
+
messages: list[pt.Message] = []
|
|
79
|
+
system: list[pt.Content] | None = None
|
|
73
80
|
for i, m in enumerate(request.v):
|
|
74
81
|
if isinstance(m, SystemMessage):
|
|
75
82
|
if i != 0 or system is not None:
|
|
76
83
|
raise Exception('Only supports one system message and must be first')
|
|
77
|
-
system = self._get_msg_content(m)
|
|
84
|
+
system = [pt.Text(check.not_none(self._get_msg_content(m)))]
|
|
85
|
+
|
|
86
|
+
elif isinstance(m, ToolExecResultMessage):
|
|
87
|
+
messages.append(pt.Message(
|
|
88
|
+
role='user',
|
|
89
|
+
content=[pt.ToolResult(
|
|
90
|
+
tool_use_id=check.not_none(m.id),
|
|
91
|
+
content=json.dumps_compact(msh.marshal(m.c)) if not isinstance(m.c, str) else m.c,
|
|
92
|
+
)],
|
|
93
|
+
))
|
|
94
|
+
|
|
95
|
+
elif isinstance(m, AiMessage):
|
|
96
|
+
# messages.append(pt.Message(
|
|
97
|
+
# role=self.ROLES_MAP[type(m)], # noqa
|
|
98
|
+
# content=[pt.Text(check.isinstance(self._get_msg_content(m), str))],
|
|
99
|
+
# ))
|
|
100
|
+
a_tus: list[pt.ToolUse] = []
|
|
101
|
+
for tr in m.tool_exec_requests or []:
|
|
102
|
+
a_tus.append(pt.ToolUse(
|
|
103
|
+
id=check.not_none(tr.id),
|
|
104
|
+
name=check.not_none(tr.name),
|
|
105
|
+
input=tr.args,
|
|
106
|
+
))
|
|
107
|
+
messages.append(pt.Message(
|
|
108
|
+
role='assistant',
|
|
109
|
+
content=[
|
|
110
|
+
*([pt.Text(check.isinstance(m.c, str))] if m.c is not None else []),
|
|
111
|
+
*a_tus,
|
|
112
|
+
],
|
|
113
|
+
))
|
|
114
|
+
|
|
78
115
|
else:
|
|
79
|
-
messages.append(
|
|
80
|
-
role=self.ROLES_MAP[type(m)], #
|
|
81
|
-
content=check.isinstance(self._get_msg_content(m), str),
|
|
116
|
+
messages.append(pt.Message(
|
|
117
|
+
role=self.ROLES_MAP[type(m)], # type: ignore[arg-type]
|
|
118
|
+
content=[pt.Text(check.isinstance(self._get_msg_content(m), str))],
|
|
119
|
+
))
|
|
120
|
+
|
|
121
|
+
tools: list[pt.ToolSpec] = []
|
|
122
|
+
with tv.TypedValues(*request.options).consume() as oc:
|
|
123
|
+
t: Tool
|
|
124
|
+
for t in oc.pop(Tool, []):
|
|
125
|
+
tools.append(pt.ToolSpec(
|
|
126
|
+
name=check.not_none(t.spec.name),
|
|
127
|
+
description=prepare_content_str(t.spec.desc),
|
|
128
|
+
input_schema=build_tool_spec_params_json_schema(t.spec),
|
|
82
129
|
))
|
|
83
130
|
|
|
84
|
-
|
|
131
|
+
a_req = pt.MessagesRequest(
|
|
85
132
|
model=MODEL_NAMES.resolve(self._model_name.v),
|
|
86
|
-
|
|
133
|
+
system=system,
|
|
87
134
|
messages=messages,
|
|
135
|
+
tools=tools or None,
|
|
88
136
|
max_tokens=max_tokens,
|
|
89
137
|
)
|
|
90
138
|
|
|
139
|
+
raw_request = msh.marshal(a_req)
|
|
140
|
+
|
|
91
141
|
raw_response = http.request(
|
|
92
142
|
'https://api.anthropic.com/v1/messages',
|
|
93
143
|
headers={
|
|
@@ -100,6 +150,24 @@ class AnthropicChatChoicesService:
|
|
|
100
150
|
|
|
101
151
|
response = json.loads(check.not_none(raw_response.data).decode('utf-8'))
|
|
102
152
|
|
|
153
|
+
resp_c: ta.Any = None
|
|
154
|
+
ters: list[ToolExecRequest] = []
|
|
155
|
+
for c in response['content']:
|
|
156
|
+
if c['type'] == 'text':
|
|
157
|
+
check.none(resp_c)
|
|
158
|
+
resp_c = check.not_none(c['text'])
|
|
159
|
+
elif c['type'] == 'tool_use':
|
|
160
|
+
ters.append(ToolExecRequest(
|
|
161
|
+
id=c['id'],
|
|
162
|
+
name=c['name'],
|
|
163
|
+
args=c['input'],
|
|
164
|
+
))
|
|
165
|
+
else:
|
|
166
|
+
raise TypeError(c['type'])
|
|
167
|
+
|
|
103
168
|
return ChatChoicesResponse([
|
|
104
|
-
AiChoice(AiMessage(
|
|
169
|
+
AiChoice(AiMessage(
|
|
170
|
+
resp_c,
|
|
171
|
+
tool_exec_requests=ters if ters else None,
|
|
172
|
+
)),
|
|
105
173
|
])
|
|
@@ -17,10 +17,15 @@ from ....chat.choices.types import AiChoice
|
|
|
17
17
|
from ....chat.messages import AiMessage
|
|
18
18
|
from ....chat.messages import Message
|
|
19
19
|
from ....chat.messages import SystemMessage
|
|
20
|
+
from ....chat.messages import ToolExecResultMessage
|
|
20
21
|
from ....chat.messages import UserMessage
|
|
22
|
+
from ....chat.tools.types import Tool
|
|
23
|
+
from ....content.types import Content
|
|
21
24
|
from ....models.configs import ModelName
|
|
22
25
|
from ....standard import ApiKey
|
|
26
|
+
from ....tools.types import ToolExecRequest
|
|
23
27
|
from .names import MODEL_NAMES
|
|
28
|
+
from .tools import build_tool_spec_schema
|
|
24
29
|
|
|
25
30
|
|
|
26
31
|
##
|
|
@@ -54,9 +59,8 @@ class GoogleChatChoicesService:
|
|
|
54
59
|
BASE_URL: ta.ClassVar[str] = 'https://generativelanguage.googleapis.com/v1beta/models'
|
|
55
60
|
|
|
56
61
|
ROLES_MAP: ta.ClassVar[ta.Mapping[type[Message], str]] = {
|
|
57
|
-
SystemMessage: 'system',
|
|
58
62
|
UserMessage: 'user',
|
|
59
|
-
AiMessage: '
|
|
63
|
+
AiMessage: 'model',
|
|
60
64
|
}
|
|
61
65
|
|
|
62
66
|
async def invoke(
|
|
@@ -65,16 +69,77 @@ class GoogleChatChoicesService:
|
|
|
65
69
|
) -> ChatChoicesResponse:
|
|
66
70
|
key = check.not_none(self._api_key).reveal()
|
|
67
71
|
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
72
|
+
g_sys_content: pt.Content | None = None
|
|
73
|
+
g_contents: list[pt.Content] = []
|
|
74
|
+
for i, m in enumerate(request.v):
|
|
75
|
+
if isinstance(m, SystemMessage):
|
|
76
|
+
check.arg(i == 0)
|
|
77
|
+
check.none(g_sys_content)
|
|
78
|
+
g_sys_content = pt.Content(
|
|
71
79
|
parts=[pt.Part(
|
|
72
80
|
text=check.not_none(self._get_msg_content(m)),
|
|
73
81
|
)],
|
|
74
|
-
role=self.ROLES_MAP[type(m)], # type: ignore[arg-type]
|
|
75
82
|
)
|
|
76
|
-
|
|
77
|
-
|
|
83
|
+
|
|
84
|
+
elif isinstance(m, ToolExecResultMessage):
|
|
85
|
+
tr_resp_val: pt.Value
|
|
86
|
+
if m.c is None:
|
|
87
|
+
tr_resp_val = pt.NullValue() # type: ignore[unreachable]
|
|
88
|
+
elif isinstance(m.c, str):
|
|
89
|
+
tr_resp_val = pt.StringValue(m.c)
|
|
90
|
+
else:
|
|
91
|
+
raise TypeError(m.c)
|
|
92
|
+
g_contents.append(pt.Content(
|
|
93
|
+
parts=[pt.Part(
|
|
94
|
+
function_response=pt.FunctionResponse(
|
|
95
|
+
id=m.id,
|
|
96
|
+
name=m.name,
|
|
97
|
+
response={
|
|
98
|
+
'value': tr_resp_val,
|
|
99
|
+
},
|
|
100
|
+
),
|
|
101
|
+
)],
|
|
102
|
+
))
|
|
103
|
+
|
|
104
|
+
elif isinstance(m, AiMessage):
|
|
105
|
+
ai_parts: list[pt.Part] = []
|
|
106
|
+
if m.c is not None:
|
|
107
|
+
ai_parts.append(pt.Part(
|
|
108
|
+
text=check.not_none(self._get_msg_content(m)),
|
|
109
|
+
))
|
|
110
|
+
for teq in m.tool_exec_requests or []:
|
|
111
|
+
ai_parts.append(pt.Part(
|
|
112
|
+
function_call=pt.FunctionCall(
|
|
113
|
+
id=teq.id,
|
|
114
|
+
name=teq.name,
|
|
115
|
+
args=teq.args,
|
|
116
|
+
),
|
|
117
|
+
))
|
|
118
|
+
g_contents.append(pt.Content(
|
|
119
|
+
parts=ai_parts,
|
|
120
|
+
role='model',
|
|
121
|
+
))
|
|
122
|
+
|
|
123
|
+
else:
|
|
124
|
+
g_contents.append(pt.Content(
|
|
125
|
+
parts=[pt.Part(
|
|
126
|
+
text=check.not_none(self._get_msg_content(m)),
|
|
127
|
+
)],
|
|
128
|
+
role=self.ROLES_MAP[type(m)], # type: ignore[arg-type]
|
|
129
|
+
))
|
|
130
|
+
|
|
131
|
+
g_tools: list[pt.Tool] = []
|
|
132
|
+
with tv.TypedValues(*request.options).consume() as oc:
|
|
133
|
+
t: Tool
|
|
134
|
+
for t in oc.pop(Tool, []):
|
|
135
|
+
g_tools.append(pt.Tool(
|
|
136
|
+
function_declarations=[build_tool_spec_schema(t.spec)],
|
|
137
|
+
))
|
|
138
|
+
|
|
139
|
+
g_req = pt.GenerateContentRequest(
|
|
140
|
+
contents=g_contents or None,
|
|
141
|
+
tools=g_tools or None,
|
|
142
|
+
system_instruction=g_sys_content,
|
|
78
143
|
)
|
|
79
144
|
|
|
80
145
|
req_dct = msh.marshal(g_req)
|
|
@@ -92,7 +157,25 @@ class GoogleChatChoicesService:
|
|
|
92
157
|
|
|
93
158
|
g_resp = msh.unmarshal(resp_dct, pt.GenerateContentResponse)
|
|
94
159
|
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
160
|
+
ai_choices: list[AiChoice] = []
|
|
161
|
+
for c in g_resp.candidates or []:
|
|
162
|
+
ai_c: Content | None = None
|
|
163
|
+
ters: list[ToolExecRequest] = []
|
|
164
|
+
for g_resp_part in check.not_none(check.not_none(c.content).parts):
|
|
165
|
+
if (g_txt := g_resp_part.text) is not None:
|
|
166
|
+
check.none(ai_c)
|
|
167
|
+
ai_c = g_txt
|
|
168
|
+
elif (g_fc := g_resp_part.function_call) is not None:
|
|
169
|
+
ters.append(ToolExecRequest(
|
|
170
|
+
id=g_fc.id,
|
|
171
|
+
name=g_fc.name,
|
|
172
|
+
args=g_fc.args or {},
|
|
173
|
+
))
|
|
174
|
+
else:
|
|
175
|
+
raise TypeError(g_resp_part)
|
|
176
|
+
ai_choices.append(AiChoice(AiMessage(
|
|
177
|
+
c=ai_c,
|
|
178
|
+
tool_exec_requests=ters if ters else None,
|
|
179
|
+
)))
|
|
180
|
+
|
|
181
|
+
return ChatChoicesResponse(ai_choices)
|