ommlds 0.0.0.dev451__py3-none-any.whl → 0.0.0.dev452__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ommlds might be problematic. Click here for more details.

Files changed (61) hide show
  1. ommlds/.omlish-manifests.json +11 -11
  2. ommlds/backends/anthropic/protocol/_marshal.py +1 -1
  3. ommlds/backends/openai/protocol/_common.py +18 -0
  4. ommlds/backends/openai/protocol/_marshal.py +2 -1
  5. ommlds/backends/openai/protocol/chatcompletion/chunk.py +4 -0
  6. ommlds/backends/openai/protocol/chatcompletion/contentpart.py +15 -7
  7. ommlds/backends/openai/protocol/chatcompletion/message.py +10 -0
  8. ommlds/backends/openai/protocol/chatcompletion/request.py +25 -7
  9. ommlds/backends/openai/protocol/chatcompletion/response.py +10 -0
  10. ommlds/backends/openai/protocol/chatcompletion/responseformat.py +6 -0
  11. ommlds/backends/openai/protocol/chatcompletion/tokenlogprob.py +4 -0
  12. ommlds/backends/openai/protocol/completionusage.py +5 -0
  13. ommlds/cli/sessions/chat/code.py +22 -17
  14. ommlds/cli/sessions/chat/inject.py +4 -4
  15. ommlds/cli/sessions/chat/interactive.py +2 -1
  16. ommlds/cli/sessions/chat/printing.py +2 -2
  17. ommlds/cli/sessions/chat/prompt.py +28 -27
  18. ommlds/cli/sessions/chat/tools.py +12 -12
  19. ommlds/minichain/__init__.py +20 -8
  20. ommlds/minichain/backends/impls/anthropic/chat.py +27 -23
  21. ommlds/minichain/backends/impls/anthropic/names.py +3 -3
  22. ommlds/minichain/backends/impls/anthropic/stream.py +7 -7
  23. ommlds/minichain/backends/impls/google/chat.py +30 -32
  24. ommlds/minichain/backends/impls/google/stream.py +8 -4
  25. ommlds/minichain/backends/impls/llamacpp/chat.py +23 -17
  26. ommlds/minichain/backends/impls/llamacpp/format.py +4 -2
  27. ommlds/minichain/backends/impls/llamacpp/stream.py +6 -6
  28. ommlds/minichain/backends/impls/mistral.py +1 -1
  29. ommlds/minichain/backends/impls/mlx/chat.py +1 -1
  30. ommlds/minichain/backends/impls/openai/chat.py +6 -3
  31. ommlds/minichain/backends/impls/openai/format.py +80 -61
  32. ommlds/minichain/backends/impls/openai/format2.py +210 -0
  33. ommlds/minichain/backends/impls/openai/stream.py +9 -6
  34. ommlds/minichain/backends/impls/tinygrad/chat.py +10 -5
  35. ommlds/minichain/backends/impls/transformers/transformers.py +20 -16
  36. ommlds/minichain/chat/_marshal.py +15 -8
  37. ommlds/minichain/chat/choices/adapters.py +3 -3
  38. ommlds/minichain/chat/choices/types.py +2 -2
  39. ommlds/minichain/chat/history.py +1 -1
  40. ommlds/minichain/chat/messages.py +55 -19
  41. ommlds/minichain/chat/services.py +2 -2
  42. ommlds/minichain/chat/stream/_marshal.py +16 -0
  43. ommlds/minichain/chat/stream/adapters.py +39 -28
  44. ommlds/minichain/chat/stream/services.py +2 -2
  45. ommlds/minichain/chat/stream/types.py +20 -13
  46. ommlds/minichain/chat/tools/execution.py +8 -7
  47. ommlds/minichain/chat/tools/ids.py +9 -15
  48. ommlds/minichain/chat/tools/parsing.py +17 -26
  49. ommlds/minichain/chat/transforms/base.py +29 -38
  50. ommlds/minichain/chat/transforms/metadata.py +30 -4
  51. ommlds/minichain/chat/transforms/services.py +5 -7
  52. ommlds/minichain/tools/jsonschema.py +5 -6
  53. ommlds/minichain/tools/types.py +24 -1
  54. ommlds/server/server.py +1 -1
  55. ommlds/tools/git.py +18 -2
  56. {ommlds-0.0.0.dev451.dist-info → ommlds-0.0.0.dev452.dist-info}/METADATA +3 -3
  57. {ommlds-0.0.0.dev451.dist-info → ommlds-0.0.0.dev452.dist-info}/RECORD +61 -58
  58. {ommlds-0.0.0.dev451.dist-info → ommlds-0.0.0.dev452.dist-info}/WHEEL +0 -0
  59. {ommlds-0.0.0.dev451.dist-info → ommlds-0.0.0.dev452.dist-info}/entry_points.txt +0 -0
  60. {ommlds-0.0.0.dev451.dist-info → ommlds-0.0.0.dev452.dist-info}/licenses/LICENSE +0 -0
  61. {ommlds-0.0.0.dev451.dist-info → ommlds-0.0.0.dev452.dist-info}/top_level.txt +0 -0
@@ -9,7 +9,7 @@ from omlish import lang
9
9
 
10
10
  from ...types import Option
11
11
  from ...types import Output
12
- from ..messages import AiMessage
12
+ from ..messages import AiChat
13
13
  from ..types import ChatOptions
14
14
  from ..types import ChatOutputs
15
15
 
@@ -39,7 +39,7 @@ ChatChoicesOutputs: ta.TypeAlias = ChatChoicesOutput | ChatOutputs
39
39
 
40
40
  @dc.dataclass(frozen=True)
41
41
  class AiChoice(lang.Final):
42
- m: AiMessage
42
+ ms: AiChat
43
43
 
44
44
 
45
45
  AiChoices: ta.TypeAlias = ta.Sequence[AiChoice]
@@ -68,6 +68,6 @@ class HistoryAddingChatService:
68
68
  response = await self._inner.invoke(new_req)
69
69
  self._history.add(
70
70
  *request.v,
71
- response.v,
71
+ *response.v,
72
72
  )
73
73
  return response
@@ -13,7 +13,8 @@ from ..content.materialize import CanContent
13
13
  from ..content.transforms.base import ContentTransform
14
14
  from ..content.types import Content
15
15
  from ..metadata import MetadataContainer
16
- from ..tools.types import ToolExecRequest
16
+ from ..tools.types import ToolUse
17
+ from ..tools.types import ToolUseResult
17
18
  from .metadata import MessageMetadatas
18
19
 
19
20
 
@@ -47,11 +48,48 @@ class Message( # noqa
47
48
  return dc.replace(self, _metadata=tv.TypedValues(*self._metadata, *mds, override=override))
48
49
 
49
50
 
51
+ Chat: ta.TypeAlias = ta.Sequence[Message]
52
+
53
+
54
+ ##
55
+
56
+
57
+ @dc.dataclass(frozen=True)
58
+ class AnyUserMessage(Message, lang.Abstract):
59
+ pass
60
+
61
+
62
+ UserChat: ta.TypeAlias = ta.Sequence[AnyUserMessage]
63
+
64
+
65
+ def check_user_chat(chat: Chat) -> UserChat:
66
+ for m in chat:
67
+ check.isinstance(m, AnyUserMessage)
68
+ return ta.cast(UserChat, chat)
69
+
70
+
50
71
  #
51
72
 
52
73
 
53
74
  @dc.dataclass(frozen=True)
54
- class SystemMessage(Message, lang.Final):
75
+ class AnyAiMessage(Message, lang.Abstract):
76
+ pass
77
+
78
+
79
+ AiChat: ta.TypeAlias = ta.Sequence[AnyAiMessage]
80
+
81
+
82
+ def check_ai_chat(chat: Chat) -> AiChat:
83
+ for m in chat:
84
+ check.isinstance(m, AnyAiMessage)
85
+ return ta.cast(AiChat, chat)
86
+
87
+
88
+ ##
89
+
90
+
91
+ @dc.dataclass(frozen=True)
92
+ class SystemMessage(AnyUserMessage, lang.Final):
55
93
  c: CanContent
56
94
 
57
95
 
@@ -60,7 +98,7 @@ class SystemMessage(Message, lang.Final):
60
98
 
61
99
  @dc.dataclass(frozen=True)
62
100
  @msh.update_fields_metadata(['name'], omit_if=operator.not_)
63
- class UserMessage(Message, lang.Final):
101
+ class UserMessage(AnyUserMessage, lang.Final):
64
102
  c: CanContent
65
103
 
66
104
  name: str | None = dc.xfield(None, repr_fn=dc.opt_repr)
@@ -70,27 +108,21 @@ class UserMessage(Message, lang.Final):
70
108
 
71
109
 
72
110
  @dc.dataclass(frozen=True)
73
- @msh.update_fields_metadata(['tool_exec_requests'], omit_if=operator.not_)
74
- class AiMessage(Message, lang.Final):
75
- c: Content | None = dc.xfield(None, repr_fn=dc.opt_repr)
76
-
77
- tool_exec_requests: ta.Sequence[ToolExecRequest] | None = dc.xfield(None, repr_fn=dc.opt_repr)
111
+ class AiMessage(AnyAiMessage, lang.Final):
112
+ c: Content = dc.xfield(None, repr_fn=dc.opt_repr) # TODO: non-null?
78
113
 
79
114
 
80
115
  #
81
116
 
82
117
 
83
- @dc.dataclass(frozen=True, kw_only=True)
84
- class ToolExecResultMessage(Message, lang.Final):
85
- id: str | None = None
86
- name: str
87
- c: Content
88
-
89
-
90
- ##
118
+ @dc.dataclass(frozen=True)
119
+ class ToolUseMessage(AnyAiMessage, lang.Final):
120
+ tu: ToolUse
91
121
 
92
122
 
93
- Chat: ta.TypeAlias = ta.Sequence[Message]
123
+ @dc.dataclass(frozen=True)
124
+ class ToolUseResultMessage(AnyUserMessage, lang.Final):
125
+ tur: ToolUseResult
94
126
 
95
127
 
96
128
  ##
@@ -110,5 +142,9 @@ class _MessageContentTransform(ContentTransform, lang.Final, lang.NotInstantiabl
110
142
  return dc.replace(m, c=self.apply(m.c))
111
143
 
112
144
  @dispatch.install_method(ContentTransform.apply)
113
- def apply_tool_exec_result_message(self, m: ToolExecResultMessage) -> ToolExecResultMessage:
114
- return m
145
+ def apply_tool_use_message(self, m: ToolUseMessage) -> ToolUseMessage:
146
+ return dc.replace(m, tu=self.apply(m.tu))
147
+
148
+ @dispatch.install_method(ContentTransform.apply)
149
+ def apply_tool_use_result_message(self, m: ToolUseResultMessage) -> ToolUseResultMessage:
150
+ return dc.replace(m, tur=self.apply(m.tur))
@@ -7,7 +7,7 @@ from ..registries.globals import register_type
7
7
  from ..services import Request
8
8
  from ..services import Response
9
9
  from ..services import Service
10
- from .messages import AiMessage
10
+ from .messages import AiChat
11
11
  from .messages import Chat
12
12
  from .types import ChatOptions
13
13
  from .types import ChatOutputs
@@ -18,7 +18,7 @@ from .types import ChatOutputs
18
18
 
19
19
  ChatRequest: ta.TypeAlias = Request[Chat, ChatOptions]
20
20
 
21
- ChatResponse: ta.TypeAlias = Response[AiMessage, ChatOutputs]
21
+ ChatResponse: ta.TypeAlias = Response[AiChat, ChatOutputs]
22
22
 
23
23
  # @omlish-manifest $.minichain.registries.manifests.RegistryTypeManifest
24
24
  ChatService: ta.TypeAlias = Service[ChatRequest, ChatResponse]
@@ -0,0 +1,16 @@
1
+ from omlish import lang
2
+ from omlish import marshal as msh
3
+
4
+ from .types import AiChoiceDelta
5
+
6
+
7
+ ##
8
+
9
+
10
+ @lang.static_init
11
+ def _install_standard_marshaling() -> None:
12
+ acd_poly = msh.polymorphism_from_subclasses(AiChoiceDelta, naming=msh.Naming.SNAKE)
13
+ msh.install_standard_factories(
14
+ msh.PolymorphismMarshalerFactory(acd_poly),
15
+ msh.PolymorphismUnmarshalerFactory(acd_poly),
16
+ )
@@ -1,19 +1,22 @@
1
- import typing as ta
2
-
3
1
  from omlish import check
4
2
  from omlish import dataclasses as dc
5
3
  from omlish import lang
6
4
 
7
5
  from ...services import Response
6
+ from ...tools.types import ToolUse
8
7
  from ..choices.services import ChatChoicesRequest
9
8
  from ..choices.services import static_check_is_chat_choices_service
10
9
  from ..choices.types import AiChoice
11
10
  from ..choices.types import AiChoices
12
11
  from ..messages import AiMessage
13
- from ..messages import ToolExecRequest
12
+ from ..messages import AnyAiMessage
13
+ from ..messages import ToolUseMessage
14
14
  from .services import ChatChoicesOutputs
15
15
  from .services import ChatChoicesStreamOutputs
16
16
  from .services import ChatChoicesStreamService
17
+ from .types import AiChoiceDelta
18
+ from .types import ContentAiChoiceDelta
19
+ from .types import ToolUseAiChoiceDelta
17
20
 
18
21
 
19
22
  ##
@@ -24,46 +27,54 @@ from .services import ChatChoicesStreamService
24
27
  class ChatChoicesStreamServiceChatChoicesService:
25
28
  service: ChatChoicesStreamService
26
29
 
27
- class _Choice(ta.NamedTuple):
28
- parts: list[str]
29
- trs: list[ToolExecRequest]
30
-
31
30
  async def invoke(self, request: ChatChoicesRequest) -> Response[
32
31
  AiChoices,
33
32
  ChatChoicesOutputs | ChatChoicesStreamOutputs,
34
33
  ]:
35
- lst: list[ChatChoicesStreamServiceChatChoicesService._Choice] = []
34
+ choice_lsts: list[list[list[str] | ToolUse]] = []
35
+
36
+ def add(l: list[list[str] | ToolUse], d: AiChoiceDelta) -> None:
37
+ if isinstance(d, ContentAiChoiceDelta):
38
+ s = check.isinstance(d.c, str)
39
+ if l and isinstance(l[-1], list):
40
+ l[-1].append(s)
41
+ else:
42
+ l.append([s])
43
+
44
+ elif isinstance(d, ToolUseAiChoiceDelta):
45
+ l.append(d.tu)
46
+
47
+ else:
48
+ raise TypeError(d)
36
49
 
37
50
  async with (resp := await self.service.invoke(request)).v as it: # noqa
38
51
  i = -1 # noqa
52
+ l: list[list[str] | ToolUse]
39
53
  async for i, cs in lang.async_enumerate(it):
40
54
  if i == 0:
41
- for c in cs:
42
- m = c.m
43
- lst.append(self._Choice(
44
- [check.isinstance(m.c, str)] if m.c is not None else [],
45
- # FIXME
46
- # list(m.tool_exec_requests or []),
47
- [],
48
- ))
55
+ for c in cs.choices:
56
+ choice_lsts.append(l := [])
57
+ for d in c.deltas:
58
+ add(l, d)
49
59
 
50
60
  else:
51
- for ch, c in zip(lst, cs, strict=True):
52
- m = c.m
53
- if m.c is not None:
54
- ch.parts.append(check.isinstance(m.c, str))
55
- # FIXME
56
- # if m.tool_exec_requests:
57
- # ch.trs.extend(m.tool_exec_requests)
61
+ for l, c in zip(choice_lsts, cs.choices, strict=True):
62
+ for d in c.deltas:
63
+ add(l, d)
58
64
 
59
65
  # check.state(resp_v.is_done)
60
66
 
61
67
  ret: list[AiChoice] = []
62
- for ch in lst:
63
- ret.append(AiChoice(AiMessage(
64
- ''.join(ch.parts) if ch.parts else None,
65
- ch.trs or None,
66
- )))
68
+ for cl in choice_lsts:
69
+ cc: list[AnyAiMessage] = []
70
+ for e in cl:
71
+ if isinstance(e, list):
72
+ cc.append(AiMessage(''.join(e)))
73
+ elif isinstance(e, ToolUse):
74
+ cc.append(ToolUseMessage(e))
75
+ else:
76
+ raise TypeError(e)
77
+ ret.append(AiChoice(cc))
67
78
 
68
79
  # FIXME: outputs lol
69
80
  return Response(ret)
@@ -9,7 +9,7 @@ from ...services import Service
9
9
  from ...stream.services import StreamResponse
10
10
  from ..choices.types import ChatChoicesOutputs
11
11
  from ..messages import Chat
12
- from .types import AiChoiceDeltas
12
+ from .types import AiChoicesDeltas
13
13
  from .types import ChatChoicesStreamOptions
14
14
  from .types import ChatChoicesStreamOutputs
15
15
 
@@ -20,7 +20,7 @@ from .types import ChatChoicesStreamOutputs
20
20
  ChatChoicesStreamRequest: ta.TypeAlias = Request[Chat, ChatChoicesStreamOptions]
21
21
 
22
22
  ChatChoicesStreamResponse: ta.TypeAlias = StreamResponse[
23
- AiChoiceDeltas,
23
+ AiChoicesDeltas,
24
24
  ChatChoicesOutputs,
25
25
  ChatChoicesStreamOutputs,
26
26
  ]
@@ -1,4 +1,3 @@
1
- import operator
2
1
  import typing as ta
3
2
 
4
3
  from omlish import dataclasses as dc
@@ -7,11 +6,15 @@ from omlish import marshal as msh
7
6
 
8
7
  from ...content.types import Content
9
8
  from ...stream.services import StreamOptions
9
+ from ...tools.types import ToolUse
10
10
  from ...types import Option
11
11
  from ...types import Output
12
12
  from ..choices.types import ChatChoicesOptions
13
13
 
14
14
 
15
+ msh.register_global_module_import('._marshal', __package__)
16
+
17
+
15
18
  ##
16
19
 
17
20
 
@@ -36,24 +39,28 @@ ChatChoicesStreamOutputs: ta.TypeAlias = ChatChoicesStreamOutput
36
39
 
37
40
 
38
41
  @dc.dataclass(frozen=True)
39
- class ToolExecRequestDelta(lang.Final):
40
- index: int | None = None
41
- id: str | None = None
42
- name: str | None = None
43
- args: str | None = None
42
+ class AiChoiceDelta(lang.Sealed, lang.Abstract):
43
+ pass
44
+
45
+
46
+ @dc.dataclass(frozen=True)
47
+ class ContentAiChoiceDelta(AiChoiceDelta, lang.Final):
48
+ c: Content
44
49
 
45
50
 
46
51
  @dc.dataclass(frozen=True)
47
- @msh.update_fields_metadata(['tool_exec_requests'], omit_if=operator.not_)
48
- class AiMessageDelta(lang.Final):
49
- c: Content | None = dc.xfield(None, repr_fn=dc.opt_repr)
52
+ class ToolUseAiChoiceDelta(AiChoiceDelta, lang.Final):
53
+ tu: ToolUse
50
54
 
51
- tool_exec_requests: ta.Sequence[ToolExecRequestDelta] | None = dc.xfield(None, repr_fn=dc.opt_repr)
55
+
56
+ #
52
57
 
53
58
 
54
59
  @dc.dataclass(frozen=True)
55
- class AiChoiceDelta(lang.Final):
56
- m: AiMessageDelta
60
+ class AiChoiceDeltas(lang.Final):
61
+ deltas: ta.Sequence[AiChoiceDelta]
57
62
 
58
63
 
59
- AiChoiceDeltas: ta.TypeAlias = ta.Sequence[AiChoiceDelta]
64
+ @dc.dataclass(frozen=True)
65
+ class AiChoicesDeltas(lang.Final):
66
+ choices: ta.Sequence[AiChoiceDeltas]
@@ -1,25 +1,26 @@
1
1
  from ...tools.execution.context import ToolContext
2
2
  from ...tools.execution.executors import ToolExecutor
3
- from ..messages import ToolExecRequest
4
- from ..messages import ToolExecResultMessage
3
+ from ...tools.types import ToolUseResult
4
+ from ..messages import ToolUse
5
+ from ..messages import ToolUseResultMessage
5
6
 
6
7
 
7
8
  ##
8
9
 
9
10
 
10
- async def execute_tool_request(
11
+ async def execute_tool_use(
11
12
  ctx: ToolContext,
12
13
  tex: ToolExecutor,
13
- ter: ToolExecRequest,
14
- ) -> ToolExecResultMessage:
14
+ ter: ToolUse,
15
+ ) -> ToolUseResultMessage:
15
16
  result_str = await tex.execute_tool(
16
17
  ctx,
17
18
  ter.name,
18
19
  ter.args,
19
20
  )
20
21
 
21
- return ToolExecResultMessage(
22
+ return ToolUseResultMessage(ToolUseResult(
22
23
  id=ter.id,
23
24
  name=ter.name,
24
25
  c=result_str,
25
- )
26
+ ))
@@ -3,31 +3,25 @@ import uuid
3
3
 
4
4
  from omlish import dataclasses as dc
5
5
 
6
- from ..messages import AiMessage
6
+ from ..messages import Chat
7
7
  from ..messages import Message
8
- from ..messages import ToolExecRequest
8
+ from ..messages import ToolUseMessage
9
9
  from ..transforms.base import MessageTransform
10
10
 
11
11
 
12
12
  ##
13
13
 
14
14
 
15
- def simple_uuid_tool_exec_request_id_factory(m: AiMessage, ter: ToolExecRequest) -> str: # noqa
15
+ def simple_uuid_tool_exec_request_id_factory(m: ToolUseMessage) -> str: # noqa
16
16
  return str(uuid.uuid4())
17
17
 
18
18
 
19
19
  @dc.dataclass(frozen=True)
20
- class ToolExecRequestIdAddingMessageTransform(MessageTransform):
21
- id_factory: ta.Callable[[AiMessage, ToolExecRequest], str] = dc.field(default=simple_uuid_tool_exec_request_id_factory) # noqa
20
+ class ToolUseIdAddingMessageTransform(MessageTransform):
21
+ id_factory: ta.Callable[[ToolUseMessage], str] = dc.field(default=simple_uuid_tool_exec_request_id_factory) # noqa
22
22
 
23
- def transform_message(self, m: Message) -> Message:
24
- if not isinstance(m, AiMessage) or not m.tool_exec_requests:
25
- return m
23
+ def transform_message(self, m: Message) -> Chat:
24
+ if not isinstance(m, ToolUseMessage) or m.tu.id is not None:
25
+ return [m]
26
26
 
27
- lst: list[ToolExecRequest] = []
28
- for ter in m.tool_exec_requests:
29
- if ter.id is None:
30
- ter = dc.replace(ter, id=self.id_factory(m, ter))
31
- lst.append(ter)
32
-
33
- return dc.replace(m, tool_exec_requests=lst)
27
+ return [dc.replace(m, tu=dc.replace(m.tu, id=self.id_factory(m)))]
@@ -1,54 +1,45 @@
1
+ import typing as ta
2
+
1
3
  from omlish import check
2
4
  from omlish import dataclasses as dc
3
5
 
4
- from ...content.types import Content
5
6
  from ...text.toolparsing.base import ParsedToolExec
6
7
  from ...text.toolparsing.base import ToolExecParser
7
8
  from ..messages import AiMessage
8
- from ..messages import ToolExecRequest
9
- from ..transforms.base import MessageTransform
9
+ from ..messages import AnyAiMessage
10
+ from ..messages import ToolUse
11
+ from ..messages import ToolUseMessage
12
+ from ..transforms.base import AiMessageTransform
10
13
 
11
14
 
12
15
  ##
13
16
 
14
17
 
15
18
  @dc.dataclass(frozen=True)
16
- class ToolExecParsingMessageTransform(MessageTransform[AiMessage]):
19
+ class ToolExecParsingMessageTransform(AiMessageTransform):
17
20
  parser: ToolExecParser
18
21
 
19
- def transform_message(self, message: AiMessage) -> AiMessage:
22
+ def transform_message(self, message: AnyAiMessage) -> ta.Sequence[AnyAiMessage]:
23
+ if not isinstance(message, AiMessage):
24
+ return [message]
25
+
20
26
  pts = self.parser.parse_tool_execs_(check.isinstance(message.c or '', str))
21
27
 
22
- sl: list[str] = []
23
- xl: list[ToolExecRequest] = []
28
+ out: list[AnyAiMessage] = []
29
+
24
30
  for pt in pts:
25
31
  if isinstance(pt, ParsedToolExec):
26
- xl.append(ToolExecRequest(
32
+ out.append(ToolUseMessage(ToolUse(
27
33
  id=pt.id,
28
34
  name=pt.name,
29
35
  args=pt.args,
30
36
  raw_args=pt.raw_body,
31
- ))
37
+ )))
32
38
 
33
39
  elif isinstance(pt, str):
34
- sl.append(pt)
40
+ out.append(AiMessage(pt))
35
41
 
36
42
  else:
37
43
  raise TypeError(pt)
38
44
 
39
- c: Content | None
40
- if len(sl) == 1:
41
- [c] = sl
42
- elif sl:
43
- c = sl
44
- else:
45
- c = None
46
-
47
- return dc.replace(
48
- message,
49
- c=c,
50
- tool_exec_requests=[
51
- *(message.tool_exec_requests or []),
52
- *xl,
53
- ],
54
- )
45
+ return out
@@ -1,76 +1,67 @@
1
+ """
2
+ Mirrors omlish.funcs.pairs.
3
+
4
+ TODO:
5
+ - MessagesTransform ? MessageTransformMessagesTransform? :| ...
6
+ """
1
7
  import abc
2
8
  import typing as ta
3
9
 
4
10
  from omlish import dataclasses as dc
5
11
  from omlish import lang
6
12
 
13
+ from ..messages import AnyAiMessage
14
+ from ..messages import AnyUserMessage
7
15
  from ..messages import Chat
8
16
  from ..messages import Message
9
17
 
10
18
 
19
+ MessageF = ta.TypeVar('MessageF', bound=Message)
11
20
  MessageT = ta.TypeVar('MessageT', bound=Message)
12
21
 
13
22
 
14
23
  ##
15
24
 
16
25
 
17
- class MessageTransform(lang.Abstract, ta.Generic[MessageT]):
26
+ class MessageTransform(lang.Abstract, ta.Generic[MessageF, MessageT]):
18
27
  @abc.abstractmethod
19
- def transform_message(self, message: MessageT) -> MessageT:
28
+ def transform_message(self, message: MessageF) -> ta.Sequence[MessageT]:
20
29
  raise NotImplementedError
21
30
 
22
31
 
32
+ AiMessageTransform: ta.TypeAlias = MessageTransform[AnyAiMessage, AnyAiMessage]
33
+ UserMessageTransform: ta.TypeAlias = MessageTransform[AnyUserMessage, AnyUserMessage]
34
+
35
+
23
36
  @dc.dataclass(frozen=True)
24
37
  class CompositeMessageTransform(MessageTransform):
25
38
  mts: ta.Sequence[MessageTransform]
26
39
 
27
- def transform_message(self, message: Message) -> Message:
40
+ def transform_message(self, message: Message) -> Chat:
41
+ chat: Chat = [message]
28
42
  for mt in self.mts:
29
- message = mt.transform_message(message)
30
- return message
43
+ chat = [o for i in chat for o in mt.transform_message(i)]
44
+ return chat
31
45
 
32
46
 
33
47
  @dc.dataclass(frozen=True)
34
- class FnMessageTransform(MessageTransform, ta.Generic[MessageT]):
35
- fn: ta.Callable[[MessageT], MessageT]
48
+ class FnMessageTransform(MessageTransform, ta.Generic[MessageF, MessageT]):
49
+ fn: ta.Callable[[MessageF], ta.Sequence[MessageT]]
36
50
 
37
- def transform_message(self, message: MessageT) -> MessageT:
51
+ def transform_message(self, message: MessageF) -> ta.Sequence[MessageT]:
38
52
  return self.fn(message)
39
53
 
40
54
 
41
55
  @dc.dataclass(frozen=True)
42
- class TypeFilteredMessageTransform(MessageTransform[Message], ta.Generic[MessageT]):
56
+ class TypeFilteredMessageTransform(MessageTransform, ta.Generic[MessageF, MessageT]):
43
57
  ty: type | tuple[type, ...]
44
- mt: MessageTransform[MessageT]
58
+ mt: MessageTransform[MessageF, MessageT]
45
59
 
46
- def transform_message(self, message: Message) -> Message:
60
+ def transform_message(self, message: Message) -> Chat:
47
61
  if isinstance(message, self.ty):
48
- return self.mt.transform_message(ta.cast(MessageT, message))
62
+ return self.mt.transform_message(ta.cast(MessageF, message))
49
63
  else:
50
- return message
51
-
52
-
53
- @ta.overload
54
- def fn_message_transform(
55
- fn: ta.Callable[[MessageT], MessageT],
56
- ty: type[MessageT],
57
- ) -> MessageTransform[MessageT]:
58
- ...
59
-
60
-
61
- @ta.overload
62
- def fn_message_transform(
63
- fn: ta.Callable[[Message], Message],
64
- ty: type | tuple[type, ...] | None = None,
65
- ) -> MessageTransform:
66
- ...
67
-
68
-
69
- def fn_message_transform(fn, ty=None) -> MessageTransform[MessageT]:
70
- mt: MessageTransform = FnMessageTransform(fn)
71
- if ty is not None:
72
- mt = TypeFilteredMessageTransform(ty, mt)
73
- return mt
64
+ return [message]
74
65
 
75
66
 
76
67
  ##
@@ -108,7 +99,7 @@ class MessageTransformChatTransform(ChatTransform):
108
99
  mt: MessageTransform
109
100
 
110
101
  def transform_chat(self, chat: Chat) -> Chat:
111
- return [self.mt.transform_message(m) for m in chat]
102
+ return [o for i in chat for o in self.mt.transform_message(i)]
112
103
 
113
104
 
114
105
  @dc.dataclass(frozen=True)
@@ -117,6 +108,6 @@ class LastMessageTransformChatTransform(ChatTransform):
117
108
 
118
109
  def transform_chat(self, chat: Chat) -> Chat:
119
110
  if chat:
120
- return [*chat[:-1], self.mt.transform_message(chat[-1])]
111
+ return [*chat[:-1], *self.mt.transform_message(chat[-1])]
121
112
  else:
122
113
  return []