ommlds 0.0.0.dev449__py3-none-any.whl → 0.0.0.dev451__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ommlds might be problematic. Click here for more details.

Files changed (67)
  1. ommlds/.omlish-manifests.json +3 -3
  2. ommlds/backends/anthropic/protocol/_marshal.py +1 -1
  3. ommlds/backends/anthropic/protocol/sse/_marshal.py +1 -1
  4. ommlds/backends/anthropic/protocol/sse/assemble.py +1 -1
  5. ommlds/backends/anthropic/protocol/types.py +30 -9
  6. ommlds/backends/google/protocol/__init__.py +3 -0
  7. ommlds/backends/google/protocol/_marshal.py +16 -0
  8. ommlds/backends/google/protocol/types.py +303 -76
  9. ommlds/backends/mlx/generation.py +1 -1
  10. ommlds/backends/openai/protocol/_marshal.py +1 -1
  11. ommlds/cli/main.py +29 -8
  12. ommlds/cli/sessions/chat/code.py +124 -0
  13. ommlds/cli/sessions/chat/interactive.py +2 -5
  14. ommlds/cli/sessions/chat/printing.py +5 -4
  15. ommlds/cli/sessions/chat/prompt.py +2 -2
  16. ommlds/cli/sessions/chat/state.py +1 -0
  17. ommlds/cli/sessions/chat/tools.py +3 -5
  18. ommlds/cli/tools/config.py +2 -1
  19. ommlds/cli/tools/inject.py +13 -3
  20. ommlds/minichain/__init__.py +12 -0
  21. ommlds/minichain/_marshal.py +39 -0
  22. ommlds/minichain/backends/impls/anthropic/chat.py +78 -10
  23. ommlds/minichain/backends/impls/google/chat.py +95 -12
  24. ommlds/minichain/backends/impls/google/tools.py +149 -0
  25. ommlds/minichain/chat/_marshal.py +1 -1
  26. ommlds/minichain/content/_marshal.py +24 -3
  27. ommlds/minichain/content/json.py +13 -0
  28. ommlds/minichain/content/materialize.py +13 -20
  29. ommlds/minichain/content/prepare.py +4 -0
  30. ommlds/minichain/json.py +20 -0
  31. ommlds/minichain/lib/code/prompts.py +6 -0
  32. ommlds/minichain/lib/fs/context.py +18 -4
  33. ommlds/minichain/lib/fs/errors.py +6 -0
  34. ommlds/minichain/lib/fs/tools/edit.py +104 -0
  35. ommlds/minichain/lib/fs/{catalog → tools}/ls.py +3 -3
  36. ommlds/minichain/lib/fs/{catalog → tools}/read.py +6 -6
  37. ommlds/minichain/lib/fs/tools/recursivels/__init__.py +0 -0
  38. ommlds/minichain/lib/fs/{catalog → tools}/recursivels/execution.py +2 -2
  39. ommlds/minichain/lib/todo/__init__.py +0 -0
  40. ommlds/minichain/lib/todo/context.py +54 -0
  41. ommlds/minichain/lib/todo/tools/__init__.py +0 -0
  42. ommlds/minichain/lib/todo/tools/read.py +44 -0
  43. ommlds/minichain/lib/todo/tools/write.py +335 -0
  44. ommlds/minichain/lib/todo/types.py +60 -0
  45. ommlds/minichain/llms/_marshal.py +1 -1
  46. ommlds/minichain/services/_marshal.py +1 -1
  47. ommlds/minichain/tools/_marshal.py +1 -1
  48. ommlds/minichain/tools/execution/catalog.py +2 -1
  49. ommlds/minichain/tools/execution/executors.py +8 -3
  50. ommlds/minichain/tools/execution/reflect.py +43 -5
  51. ommlds/minichain/tools/fns.py +46 -9
  52. ommlds/minichain/tools/jsonschema.py +11 -1
  53. ommlds/minichain/tools/reflect.py +9 -2
  54. ommlds/minichain/tools/types.py +9 -0
  55. ommlds/minichain/utils.py +27 -0
  56. ommlds/minichain/vectors/_marshal.py +1 -1
  57. ommlds/tools/ocr.py +7 -1
  58. {ommlds-0.0.0.dev449.dist-info → ommlds-0.0.0.dev451.dist-info}/METADATA +3 -3
  59. {ommlds-0.0.0.dev449.dist-info → ommlds-0.0.0.dev451.dist-info}/RECORD +67 -53
  60. /ommlds/minichain/lib/{fs/catalog → code}/__init__.py +0 -0
  61. /ommlds/minichain/lib/fs/{catalog/recursivels → tools}/__init__.py +0 -0
  62. /ommlds/minichain/lib/fs/{catalog → tools}/recursivels/rendering.py +0 -0
  63. /ommlds/minichain/lib/fs/{catalog → tools}/recursivels/running.py +0 -0
  64. {ommlds-0.0.0.dev449.dist-info → ommlds-0.0.0.dev451.dist-info}/WHEEL +0 -0
  65. {ommlds-0.0.0.dev449.dist-info → ommlds-0.0.0.dev451.dist-info}/entry_points.txt +0 -0
  66. {ommlds-0.0.0.dev449.dist-info → ommlds-0.0.0.dev451.dist-info}/licenses/LICENSE +0 -0
  67. {ommlds-0.0.0.dev449.dist-info → ommlds-0.0.0.dev451.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,124 @@
1
+ import dataclasses as dc
2
+ import itertools
3
+ import os.path
4
+
5
+ from omlish import check
6
+ from omlish import lang
7
+
8
+ from .... import minichain as mc
9
+ from ....minichain.lib.code.prompts import CODE_AGENT_SYSTEM_PROMPT
10
+ from ...tools.config import ToolsConfig
11
+ from .base import DEFAULT_CHAT_MODEL_BACKEND
12
+ from .base import ChatOptions
13
+ from .base import ChatSession
14
+ from .printing import ChatSessionPrinter
15
+ from .state import ChatStateManager
16
+ from .tools import ToolExecRequestExecutor
17
+
18
+
19
+ with lang.auto_proxy_import(globals()):
20
+ from omdev import ptk
21
+
22
+
23
+ ##
24
+
25
+
26
+ class CodeChatSession(ChatSession['CodeChatSession.Config']):
27
+ @dc.dataclass(frozen=True)
28
+ class Config(ChatSession.Config):
29
+ _: dc.KW_ONLY
30
+
31
+ new: bool = False
32
+
33
+ backend: str | None = None
34
+ model_name: str | None = None
35
+
36
+ initial_message: mc.Content | None = None
37
+
38
+ def __init__(
39
+ self,
40
+ config: Config,
41
+ *,
42
+ state_manager: ChatStateManager,
43
+ chat_options: ChatOptions | None = None,
44
+ printer: ChatSessionPrinter,
45
+ backend_catalog: mc.BackendCatalog,
46
+ tool_exec_request_executor: ToolExecRequestExecutor,
47
+ tools_config: ToolsConfig | None = None,
48
+ ) -> None:
49
+ super().__init__(config)
50
+
51
+ self._state_manager = state_manager
52
+ self._chat_options = chat_options
53
+ self._printer = printer
54
+ self._backend_catalog = backend_catalog
55
+ self._tool_exec_request_executor = tool_exec_request_executor
56
+ self._tools_config = tools_config
57
+
58
+ async def run(self) -> None:
59
+ if self._config.new:
60
+ self._state_manager.clear_state()
61
+ state = self._state_manager.extend_chat([
62
+ mc.SystemMessage(CODE_AGENT_SYSTEM_PROMPT),
63
+ ])
64
+
65
+ else:
66
+ state = self._state_manager.get_state()
67
+
68
+ backend = self._config.backend
69
+ if backend is None:
70
+ backend = DEFAULT_CHAT_MODEL_BACKEND
71
+
72
+ # FIXME: lol
73
+ from ....minichain.lib.fs.context import FsContext
74
+ fs_tool_context = FsContext(
75
+ root_dir=os.getcwd(),
76
+ writes_permitted=self._tools_config is not None and self._tools_config.enable_unsafe_tools_do_not_use_lol,
77
+ )
78
+
79
+ from ....minichain.lib.todo.context import TodoContext
80
+ todo_tool_context = TodoContext()
81
+
82
+ mdl: mc.ChatChoicesService
83
+ async with lang.async_maybe_managing(self._backend_catalog.get_backend(
84
+ mc.ChatChoicesService,
85
+ backend,
86
+ *([mc.ModelName(mn)] if (mn := self._config.model_name) is not None else []),
87
+ )) as mdl:
88
+ for i in itertools.count():
89
+ if not i and self._config.initial_message is not None:
90
+ req_msg = mc.UserMessage(self._config.initial_message)
91
+ else:
92
+ try:
93
+ prompt = await ptk.prompt('> ')
94
+ except EOFError:
95
+ break
96
+ req_msg = mc.UserMessage(prompt)
97
+
98
+ state = self._state_manager.extend_chat([req_msg])
99
+
100
+ while True:
101
+ response = await mdl.invoke(mc.ChatChoicesRequest(
102
+ state.chat,
103
+ (self._chat_options or []),
104
+ ))
105
+ resp_msg = check.single(response.v).m
106
+
107
+ self._printer.print(resp_msg)
108
+ state = self._state_manager.extend_chat([resp_msg])
109
+
110
+ if not (trs := resp_msg.tool_exec_requests):
111
+ break
112
+
113
+ tool_resp_lst = []
114
+ for tr in trs:
115
+ trm = await self._tool_exec_request_executor.execute_tool_request(
116
+ tr,
117
+ fs_tool_context,
118
+ todo_tool_context,
119
+ )
120
+
121
+ self._printer.print(trm.c)
122
+ tool_resp_lst.append(trm)
123
+
124
+ state = self._state_manager.extend_chat(tool_resp_lst)
@@ -1,5 +1,4 @@
1
1
  import dataclasses as dc
2
- import typing as ta
3
2
 
4
3
  from omlish import lang
5
4
 
@@ -10,10 +9,8 @@ from .printing import ChatSessionPrinter
10
9
  from .state import ChatStateManager
11
10
 
12
11
 
13
- if ta.TYPE_CHECKING:
12
+ with lang.auto_proxy_import(globals()):
14
13
  from omdev import ptk
15
- else:
16
- ptk = lang.proxy_import('omdev.ptk')
17
14
 
18
15
 
19
16
  ##
@@ -60,7 +57,7 @@ class InteractiveChatSession(ChatSession['InteractiveChatSession.Config']):
60
57
  *([mc.ModelName(mn)] if (mn := self._config.model_name) is not None else []),
61
58
  )) as mdl:
62
59
  while True:
63
- prompt = ptk.prompt('> ')
60
+ prompt = await ptk.prompt('> ')
64
61
 
65
62
  req_msg = mc.UserMessage(prompt)
66
63
 
@@ -3,16 +3,14 @@ import typing as ta
3
3
 
4
4
  from omlish import check
5
5
  from omlish import lang
6
+ from omlish.formats import json
6
7
 
7
8
  from .... import minichain as mc
8
9
 
9
10
 
10
- if ta.TYPE_CHECKING:
11
+ with lang.auto_proxy_import(globals()):
11
12
  from omdev import ptk
12
13
  from omdev.ptk import markdown as ptk_md
13
- else:
14
- ptk = lang.proxy_import('omdev.ptk')
15
- ptk_md = lang.proxy_import('omdev.ptk.markdown')
16
14
 
17
15
 
18
16
  ##
@@ -51,6 +49,9 @@ class StringChatSessionPrinter(ChatSessionPrinter, lang.Abstract):
51
49
  else:
52
50
  raise TypeError(obj)
53
51
 
52
+ elif isinstance(obj, mc.JsonContent):
53
+ self._print_str(json.dumps_pretty(obj.v))
54
+
54
55
  elif isinstance(obj, str):
55
56
  self._print_str(obj)
56
57
 
@@ -127,11 +127,11 @@ class PromptChatSession(ChatSession['PromptChatSession.Config']):
127
127
  tr: mc.ToolExecRequest = check.single(check.not_none(trs))
128
128
 
129
129
  # FIXME: lol
130
- from ....minichain.lib.fs.context import FsToolContext
130
+ from ....minichain.lib.fs.context import FsContext
131
131
 
132
132
  trm = await self._tool_exec_request_executor.execute_tool_request(
133
133
  tr,
134
- FsToolContext(root_dir=os.getcwd()),
134
+ FsContext(root_dir=os.getcwd()),
135
135
  )
136
136
 
137
137
  print(trm.c)
@@ -106,4 +106,5 @@ class StateStorageChatStateManager(ChatStateManager):
106
106
  updated_at=lang.utcnow(),
107
107
  )
108
108
  self._storage.save_state(self._key, state, ChatState)
109
+ self._state = state
109
110
  return state
@@ -3,16 +3,13 @@ import typing as ta
3
3
 
4
4
  from omlish import check
5
5
  from omlish import lang
6
- from omlish import marshal as msh
7
6
  from omlish.formats import json
8
7
 
9
8
  from .... import minichain as mc
10
9
 
11
10
 
12
- if ta.TYPE_CHECKING:
11
+ with lang.auto_proxy_import(globals()):
13
12
  from omdev import ptk
14
- else:
15
- ptk = lang.proxy_import('omdev.ptk')
16
13
 
17
14
 
18
15
  ##
@@ -49,8 +46,9 @@ class AskingToolExecutionConfirmation(ToolExecutionConfirmation):
49
46
  ) -> None:
50
47
  tr_dct = dict(
51
48
  id=tr.id,
52
- spec=msh.marshal(tce.spec),
49
+ name=tce.spec.name,
53
50
  args=tr.args,
51
+ # spec=msh.marshal(tce.spec),
54
52
  )
55
53
  cr = await ptk.strict_confirm(f'Execute requested tool?\n\n{json.dumps_pretty(tr_dct)}\n\n')
56
54
 
@@ -7,7 +7,8 @@ from omlish import dataclasses as dc
7
7
  @dc.dataclass(frozen=True, kw_only=True)
8
8
  class ToolsConfig:
9
9
  enable_fs_tools: bool = False
10
+ enable_todo_tools: bool = False
10
11
 
11
- enable_unsafe_bash_tool: bool = False
12
+ enable_unsafe_tools_do_not_use_lol: bool = False
12
13
 
13
14
  enable_test_weather_tool: bool = False
@@ -47,16 +47,26 @@ def bind_tools(tools_config: ToolsConfig) -> inj.Elements:
47
47
  #
48
48
 
49
49
  if tools_config.enable_fs_tools:
50
- from ...minichain.lib.fs.catalog.ls import ls_tool
50
+ from ...minichain.lib.fs.tools.ls import ls_tool
51
51
  els.append(bind_tool(ls_tool()))
52
52
 
53
- from ...minichain.lib.fs.catalog.read import read_tool
53
+ from ...minichain.lib.fs.tools.read import read_tool
54
54
  els.append(bind_tool(read_tool()))
55
55
 
56
- if tools_config.enable_unsafe_bash_tool:
56
+ if tools_config.enable_todo_tools:
57
+ from ...minichain.lib.todo.tools.read import todo_read_tool
58
+ els.append(bind_tool(todo_read_tool()))
59
+
60
+ from ...minichain.lib.todo.tools.write import todo_write_tool
61
+ els.append(bind_tool(todo_write_tool()))
62
+
63
+ if tools_config.enable_unsafe_tools_do_not_use_lol:
57
64
  from ...minichain.lib.bash import bash_tool
58
65
  els.append(bind_tool(bash_tool()))
59
66
 
67
+ from ...minichain.lib.fs.tools.edit import edit_tool
68
+ els.append(bind_tool(edit_tool()))
69
+
60
70
  if tools_config.enable_test_weather_tool:
61
71
  els.append(bind_tool(WEATHER_TOOL))
62
72
 
@@ -197,8 +197,14 @@ with _lang.auto_proxy_init(
197
197
  ImageContent,
198
198
  )
199
199
 
200
+ from .content.json import ( # noqa
201
+ JsonContent,
202
+ )
203
+
200
204
  from .content.materialize import ( # noqa
201
205
  CanContent,
206
+
207
+ materialize_content,
202
208
  )
203
209
 
204
210
  from .content.metadata import ( # noqa
@@ -493,6 +499,12 @@ with _lang.auto_proxy_init(
493
499
  EnvKey,
494
500
  )
495
501
 
502
+ from .json import ( # noqa
503
+ JsonSchema,
504
+
505
+ JsonValue,
506
+ )
507
+
496
508
  from .metadata import ( # noqa
497
509
  Metadata,
498
510
 
@@ -1,10 +1,13 @@
1
1
  from omlish import dataclasses as dc
2
+ from omlish import lang
2
3
  from omlish import marshal as msh
3
4
  from omlish import reflect as rfl
4
5
  from omlish.funcs import match as mfs
5
6
  from omlish.typedvalues.marshal import build_typed_values_marshaler
6
7
  from omlish.typedvalues.marshal import build_typed_values_unmarshaler
7
8
 
9
+ from .json import JsonValue
10
+
8
11
 
9
12
  ##
10
13
 
@@ -25,3 +28,39 @@ class _TypedValuesFieldUnmarshalerFactory(msh.UnmarshalerFactoryMatchClass):
25
28
  @mfs.simple(lambda _, ctx, rty: True)
26
29
  def _build(self, ctx: msh.UnmarshalContext, rty: rfl.Type) -> msh.Unmarshaler:
27
30
  return build_typed_values_unmarshaler(ctx, self.tvs_rty)
31
+
32
+
33
+ ##
34
+
35
+
36
+ class MarshalJsonValue(lang.NotInstantiable, lang.Final):
37
+ pass
38
+
39
+
40
+ class _JsonValueMarshalerFactory(msh.MarshalerFactoryMatchClass):
41
+ @mfs.simple(lambda _, ctx, rty: rty is MarshalJsonValue)
42
+ def _build(self, ctx: msh.MarshalContext, rty: rfl.Type) -> msh.Marshaler:
43
+ return msh.NopMarshalerUnmarshaler()
44
+
45
+
46
+ class _JsonValueUnmarshalerFactory(msh.UnmarshalerFactoryMatchClass):
47
+ @mfs.simple(lambda _, ctx, rty: rty is MarshalJsonValue)
48
+ def _build(self, ctx: msh.UnmarshalContext, rty: rfl.Type) -> msh.Unmarshaler:
49
+ return msh.NopMarshalerUnmarshaler()
50
+
51
+
52
+ ##
53
+
54
+
55
+ @lang.static_init
56
+ def _install_standard_marshaling() -> None:
57
+ msh.register_global_config(
58
+ JsonValue,
59
+ msh.ReflectOverride(MarshalJsonValue),
60
+ identity=True,
61
+ )
62
+
63
+ msh.install_standard_factories(
64
+ _JsonValueMarshalerFactory(),
65
+ _JsonValueUnmarshalerFactory(),
66
+ )
@@ -1,4 +1,5 @@
1
1
  """
2
+ https://docs.claude.com/en/api/messages
2
3
  https://github.com/anthropics/anthropic-sdk-python/tree/cd80d46f7a223a5493565d155da31b898a4c6ee5/src/anthropic/types
3
4
  https://github.com/anthropics/anthropic-sdk-python/blob/cd80d46f7a223a5493565d155da31b898a4c6ee5/src/anthropic/resources/completions.py#L53
4
5
  https://github.com/anthropics/anthropic-sdk-python/blob/cd80d46f7a223a5493565d155da31b898a4c6ee5/src/anthropic/resources/messages.py#L70
@@ -6,11 +7,12 @@ https://github.com/anthropics/anthropic-sdk-python/blob/cd80d46f7a223a5493565d15
6
7
  import typing as ta
7
8
 
8
9
  from omlish import check
9
- from omlish import lang
10
+ from omlish import marshal as msh
10
11
  from omlish import typedvalues as tv
11
12
  from omlish.formats import json
12
13
  from omlish.http import all as http
13
14
 
15
+ from .....backends.anthropic.protocol import types as pt
14
16
  from ....chat.choices.services import ChatChoicesRequest
15
17
  from ....chat.choices.services import ChatChoicesResponse
16
18
  from ....chat.choices.services import static_check_is_chat_choices_service
@@ -18,9 +20,14 @@ from ....chat.choices.types import AiChoice
18
20
  from ....chat.messages import AiMessage
19
21
  from ....chat.messages import Message
20
22
  from ....chat.messages import SystemMessage
23
+ from ....chat.messages import ToolExecResultMessage
21
24
  from ....chat.messages import UserMessage
25
+ from ....chat.tools.types import Tool
26
+ from ....content.prepare import prepare_content_str
22
27
  from ....models.configs import ModelName
23
28
  from ....standard import ApiKey
29
+ from ....tools.jsonschema import build_tool_spec_params_json_schema
30
+ from ....tools.types import ToolExecRequest
24
31
  from .names import MODEL_NAMES
25
32
 
26
33
 
@@ -68,26 +75,69 @@ class AnthropicChatChoicesService:
68
75
  *,
69
76
  max_tokens: int = 4096, # FIXME: ChatOption
70
77
  ) -> ChatChoicesResponse:
71
- messages = []
72
- system: str | None = None
78
+ messages: list[pt.Message] = []
79
+ system: list[pt.Content] | None = None
73
80
  for i, m in enumerate(request.v):
74
81
  if isinstance(m, SystemMessage):
75
82
  if i != 0 or system is not None:
76
83
  raise Exception('Only supports one system message and must be first')
77
- system = self._get_msg_content(m)
84
+ system = [pt.Text(check.not_none(self._get_msg_content(m)))]
85
+
86
+ elif isinstance(m, ToolExecResultMessage):
87
+ messages.append(pt.Message(
88
+ role='user',
89
+ content=[pt.ToolResult(
90
+ tool_use_id=check.not_none(m.id),
91
+ content=json.dumps_compact(msh.marshal(m.c)) if not isinstance(m.c, str) else m.c,
92
+ )],
93
+ ))
94
+
95
+ elif isinstance(m, AiMessage):
96
+ # messages.append(pt.Message(
97
+ # role=self.ROLES_MAP[type(m)], # noqa
98
+ # content=[pt.Text(check.isinstance(self._get_msg_content(m), str))],
99
+ # ))
100
+ a_tus: list[pt.ToolUse] = []
101
+ for tr in m.tool_exec_requests or []:
102
+ a_tus.append(pt.ToolUse(
103
+ id=check.not_none(tr.id),
104
+ name=check.not_none(tr.name),
105
+ input=tr.args,
106
+ ))
107
+ messages.append(pt.Message(
108
+ role='assistant',
109
+ content=[
110
+ *([pt.Text(check.isinstance(m.c, str))] if m.c is not None else []),
111
+ *a_tus,
112
+ ],
113
+ ))
114
+
78
115
  else:
79
- messages.append(dict(
80
- role=self.ROLES_MAP[type(m)], # noqa
81
- content=check.isinstance(self._get_msg_content(m), str),
116
+ messages.append(pt.Message(
117
+ role=self.ROLES_MAP[type(m)], # type: ignore[arg-type]
118
+ content=[pt.Text(check.isinstance(self._get_msg_content(m), str))],
119
+ ))
120
+
121
+ tools: list[pt.ToolSpec] = []
122
+ with tv.TypedValues(*request.options).consume() as oc:
123
+ t: Tool
124
+ for t in oc.pop(Tool, []):
125
+ tools.append(pt.ToolSpec(
126
+ name=check.not_none(t.spec.name),
127
+ description=prepare_content_str(t.spec.desc),
128
+ input_schema=build_tool_spec_params_json_schema(t.spec),
82
129
  ))
83
130
 
84
- raw_request = dict(
131
+ a_req = pt.MessagesRequest(
85
132
  model=MODEL_NAMES.resolve(self._model_name.v),
86
- **lang.opt_kw(system=system),
133
+ system=system,
87
134
  messages=messages,
135
+ tools=tools or None,
88
136
  max_tokens=max_tokens,
89
137
  )
90
138
 
139
+ raw_request = msh.marshal(a_req)
140
+
91
141
  raw_response = http.request(
92
142
  'https://api.anthropic.com/v1/messages',
93
143
  headers={
@@ -100,6 +150,24 @@ class AnthropicChatChoicesService:
100
150
 
101
151
  response = json.loads(check.not_none(raw_response.data).decode('utf-8'))
102
152
 
153
+ resp_c: ta.Any = None
154
+ ters: list[ToolExecRequest] = []
155
+ for c in response['content']:
156
+ if c['type'] == 'text':
157
+ check.none(resp_c)
158
+ resp_c = check.not_none(c['text'])
159
+ elif c['type'] == 'tool_use':
160
+ ters.append(ToolExecRequest(
161
+ id=c['id'],
162
+ name=c['name'],
163
+ args=c['input'],
164
+ ))
165
+ else:
166
+ raise TypeError(c['type'])
167
+
103
168
  return ChatChoicesResponse([
104
- AiChoice(AiMessage(response['content'][0]['text'])), # noqa
169
+ AiChoice(AiMessage(
170
+ resp_c,
171
+ tool_exec_requests=ters if ters else None,
172
+ )),
105
173
  ])
@@ -17,10 +17,15 @@ from ....chat.choices.types import AiChoice
17
17
  from ....chat.messages import AiMessage
18
18
  from ....chat.messages import Message
19
19
  from ....chat.messages import SystemMessage
20
+ from ....chat.messages import ToolExecResultMessage
20
21
  from ....chat.messages import UserMessage
22
+ from ....chat.tools.types import Tool
23
+ from ....content.types import Content
21
24
  from ....models.configs import ModelName
22
25
  from ....standard import ApiKey
26
+ from ....tools.types import ToolExecRequest
23
27
  from .names import MODEL_NAMES
28
+ from .tools import build_tool_spec_schema
24
29
 
25
30
 
26
31
  ##
@@ -54,9 +59,8 @@ class GoogleChatChoicesService:
54
59
  BASE_URL: ta.ClassVar[str] = 'https://generativelanguage.googleapis.com/v1beta/models'
55
60
 
56
61
  ROLES_MAP: ta.ClassVar[ta.Mapping[type[Message], str]] = {
57
- SystemMessage: 'system',
58
62
  UserMessage: 'user',
59
- AiMessage: 'assistant',
63
+ AiMessage: 'model',
60
64
  }
61
65
 
62
66
  async def invoke(
@@ -65,16 +69,77 @@ class GoogleChatChoicesService:
65
69
  ) -> ChatChoicesResponse:
66
70
  key = check.not_none(self._api_key).reveal()
67
71
 
68
- g_req = pt.GenerateContentRequest(
69
- contents=[
70
- pt.Content(
72
+ g_sys_content: pt.Content | None = None
73
+ g_contents: list[pt.Content] = []
74
+ for i, m in enumerate(request.v):
75
+ if isinstance(m, SystemMessage):
76
+ check.arg(i == 0)
77
+ check.none(g_sys_content)
78
+ g_sys_content = pt.Content(
71
79
  parts=[pt.Part(
72
80
  text=check.not_none(self._get_msg_content(m)),
73
81
  )],
74
- role=self.ROLES_MAP[type(m)], # type: ignore[arg-type]
75
82
  )
76
- for m in request.v
77
- ],
83
+
84
+ elif isinstance(m, ToolExecResultMessage):
85
+ tr_resp_val: pt.Value
86
+ if m.c is None:
87
+ tr_resp_val = pt.NullValue() # type: ignore[unreachable]
88
+ elif isinstance(m.c, str):
89
+ tr_resp_val = pt.StringValue(m.c)
90
+ else:
91
+ raise TypeError(m.c)
92
+ g_contents.append(pt.Content(
93
+ parts=[pt.Part(
94
+ function_response=pt.FunctionResponse(
95
+ id=m.id,
96
+ name=m.name,
97
+ response={
98
+ 'value': tr_resp_val,
99
+ },
100
+ ),
101
+ )],
102
+ ))
103
+
104
+ elif isinstance(m, AiMessage):
105
+ ai_parts: list[pt.Part] = []
106
+ if m.c is not None:
107
+ ai_parts.append(pt.Part(
108
+ text=check.not_none(self._get_msg_content(m)),
109
+ ))
110
+ for teq in m.tool_exec_requests or []:
111
+ ai_parts.append(pt.Part(
112
+ function_call=pt.FunctionCall(
113
+ id=teq.id,
114
+ name=teq.name,
115
+ args=teq.args,
116
+ ),
117
+ ))
118
+ g_contents.append(pt.Content(
119
+ parts=ai_parts,
120
+ role='model',
121
+ ))
122
+
123
+ else:
124
+ g_contents.append(pt.Content(
125
+ parts=[pt.Part(
126
+ text=check.not_none(self._get_msg_content(m)),
127
+ )],
128
+ role=self.ROLES_MAP[type(m)], # type: ignore[arg-type]
129
+ ))
130
+
131
+ g_tools: list[pt.Tool] = []
132
+ with tv.TypedValues(*request.options).consume() as oc:
133
+ t: Tool
134
+ for t in oc.pop(Tool, []):
135
+ g_tools.append(pt.Tool(
136
+ function_declarations=[build_tool_spec_schema(t.spec)],
137
+ ))
138
+
139
+ g_req = pt.GenerateContentRequest(
140
+ contents=g_contents or None,
141
+ tools=g_tools or None,
142
+ system_instruction=g_sys_content,
78
143
  )
79
144
 
80
145
  req_dct = msh.marshal(g_req)
@@ -92,7 +157,25 @@ class GoogleChatChoicesService:
92
157
 
93
158
  g_resp = msh.unmarshal(resp_dct, pt.GenerateContentResponse)
94
159
 
95
- return ChatChoicesResponse([
96
- AiChoice(AiMessage(check.not_none(check.not_none(check.not_none(c.content).parts)[0].text)))
97
- for c in g_resp.candidates or []
98
- ])
160
+ ai_choices: list[AiChoice] = []
161
+ for c in g_resp.candidates or []:
162
+ ai_c: Content | None = None
163
+ ters: list[ToolExecRequest] = []
164
+ for g_resp_part in check.not_none(check.not_none(c.content).parts):
165
+ if (g_txt := g_resp_part.text) is not None:
166
+ check.none(ai_c)
167
+ ai_c = g_txt
168
+ elif (g_fc := g_resp_part.function_call) is not None:
169
+ ters.append(ToolExecRequest(
170
+ id=g_fc.id,
171
+ name=g_fc.name,
172
+ args=g_fc.args or {},
173
+ ))
174
+ else:
175
+ raise TypeError(g_resp_part)
176
+ ai_choices.append(AiChoice(AiMessage(
177
+ c=ai_c,
178
+ tool_exec_requests=ters if ters else None,
179
+ )))
180
+
181
+ return ChatChoicesResponse(ai_choices)