aury-agent 0.0.9__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,7 +5,7 @@ Backends provide abstracted interfaces for various capabilities:
 Data Backends (storage):
 - SessionBackend: Session management
 - InvocationBackend: Invocation management
-- MessageBackend: Message storage (truncated/raw)
+- MessageBackend: Message storage
 - MemoryBackend: Long-term memory with search
 - ArtifactBackend: File/artifact storage
 - StateBackend: Generic key-value state
@@ -26,7 +26,7 @@ from typing import TYPE_CHECKING
 # Data backends - new architecture
 from .session import SessionBackend, InMemorySessionBackend
 from .invocation import InvocationBackend, InMemoryInvocationBackend
-from .message import MessageBackend, MessageType, InMemoryMessageBackend
+from .message import MessageBackend, InMemoryMessageBackend
 from .memory import MemoryBackend, InMemoryMemoryBackend
 from .artifact import ArtifactBackend, ArtifactSource, InMemoryArtifactBackend
 
@@ -144,7 +144,6 @@ __all__ = [
 
     # Message backend
     "MessageBackend",
-    "MessageType",
    "InMemoryMessageBackend",
 
     # Memory backend
@@ -1,9 +1,8 @@
 """Message backend."""
-from .types import MessageBackend, MessageType
+from .types import MessageBackend
 from .memory import InMemoryMessageBackend
 
 __all__ = [
     "MessageBackend",
-    "MessageType",
     "InMemoryMessageBackend",
 ]
@@ -4,45 +4,36 @@ from __future__ import annotations
 from datetime import datetime
 from typing import Any
 
-from .types import MessageType
-
 
 class InMemoryMessageBackend:
     """In-memory implementation of MessageBackend.
 
-    Stores both truncated and raw messages in separate dicts.
-    Suitable for testing and simple single-process use cases.
+    Simple in-memory storage for testing and single-process use cases.
     """
 
     def __init__(self) -> None:
         # Key format: "{session_id}" or "{session_id}:{namespace}"
         # Value: list of message dicts
-        self._truncated: dict[str, list[dict[str, Any]]] = {}
-        self._raw: dict[str, list[dict[str, Any]]] = {}
+        self._messages: dict[str, list[dict[str, Any]]] = {}
 
     def _make_key(self, session_id: str, namespace: str | None) -> str:
         if namespace:
             return f"{session_id}:{namespace}"
         return session_id
 
-    def _get_store(self, type: MessageType) -> dict[str, list[dict[str, Any]]]:
-        return self._truncated if type == "truncated" else self._raw
-
     async def add(
         self,
         session_id: str,
         message: dict[str, Any],
-        type: MessageType = "truncated",
         agent_id: str | None = None,
         namespace: str | None = None,
         invocation_id: str | None = None,
     ) -> None:
         """Add a message."""
         key = self._make_key(session_id, namespace)
-        store = self._get_store(type)
 
-        if key not in store:
-            store[key] = []
+        if key not in self._messages:
+            self._messages[key] = []
 
         # Add metadata
         msg = {
@@ -51,20 +42,18 @@ class InMemoryMessageBackend:
             "invocation_id": invocation_id,
             "created_at": datetime.now().isoformat(),
         }
-        store[key].append(msg)
+        self._messages[key].append(msg)
 
     async def get(
         self,
         session_id: str,
-        type: MessageType = "truncated",
         agent_id: str | None = None,
         namespace: str | None = None,
         limit: int | None = None,
     ) -> list[dict[str, Any]]:
         """Get messages."""
         key = self._make_key(session_id, namespace)
-        store = self._get_store(type)
-        messages = store.get(key, [])
+        messages = self._messages.get(key, [])
 
         # Filter by agent_id if specified
         if agent_id:
@@ -80,42 +69,31 @@ class InMemoryMessageBackend:
         self,
         session_id: str,
         invocation_id: str,
-        type: MessageType | None = None,
         namespace: str | None = None,
     ) -> int:
         """Delete messages by invocation."""
         key = self._make_key(session_id, namespace)
-        deleted = 0
-
-        types_to_delete = [type] if type else ["truncated", "raw"]
 
-        for t in types_to_delete:
-            store = self._get_store(t)
-            if key in store:
-                original = store[key]
-                store[key] = [m for m in original if m.get("invocation_id") != invocation_id]
-                deleted += len(original) - len(store[key])
+        if key not in self._messages:
+            return 0
 
-        return deleted
+        original = self._messages[key]
+        self._messages[key] = [m for m in original if m.get("invocation_id") != invocation_id]
+        return len(original) - len(self._messages[key])
 
     async def clear(
         self,
         session_id: str,
-        type: MessageType | None = None,
         namespace: str | None = None,
     ) -> int:
         """Clear all messages for a session."""
         key = self._make_key(session_id, namespace)
-        deleted = 0
-
-        types_to_clear = [type] if type else ["truncated", "raw"]
 
-        for t in types_to_clear:
-            store = self._get_store(t)
-            if key in store:
-                deleted += len(store[key])
-                del store[key]
+        if key not in self._messages:
+            return 0
 
+        deleted = len(self._messages[key])
+        del self._messages[key]
         return deleted
 
 
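With the type parameter gone, the backend exposes a single store per (session, namespace) key. A minimal sketch of the new call pattern (the aury_agent.backends.message import path is an assumption; adjust to the installed package layout):

    import asyncio

    from aury_agent.backends.message import InMemoryMessageBackend  # assumed import path

    async def main() -> None:
        backend = InMemoryMessageBackend()
        await backend.add(
            session_id="sess_123",
            message={"role": "user", "content": "Hello"},
            invocation_id="inv_1",
        )
        messages = await backend.get("sess_123", limit=50)
        print(len(messages))                    # 1
        print(await backend.clear("sess_123"))  # 1 (count of deleted messages)

    asyncio.run(main())
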
@@ -1,49 +1,29 @@
 """Message backend types and protocols."""
 from __future__ import annotations
 
-from typing import Any, Literal, Protocol, runtime_checkable
-
-
-MessageType = Literal["truncated", "raw"]
+from typing import Any, Protocol, runtime_checkable
 
 
 @runtime_checkable
 class MessageBackend(Protocol):
     """Protocol for message storage.
 
-    Handles both truncated (context window) and raw (full history) messages
-    through a unified interface with type parameter.
-
-    - truncated: Messages kept in context window, may be summarized/trimmed
-    - raw: Full original messages for audit/replay
+    Simple interface for message persistence.
+    Storage details (raw/truncated handling) are left to the application layer.
 
     Example usage:
-        # Add truncated message (for LLM context)
         await backend.add(
             session_id="sess_123",
             message={"role": "user", "content": "Hello"},
-            type="truncated",
-        )
-
-        # Add raw message (for audit)
-        await backend.add(
-            session_id="sess_123",
-            message={"role": "user", "content": "Hello", "attachments": [...]},
-            type="raw",
         )
 
-        # Get messages for LLM
-        messages = await backend.get("sess_123", type="truncated", limit=50)
-
-        # Get raw history
-        raw_messages = await backend.get("sess_123", type="raw")
+        messages = await backend.get("sess_123", limit=50)
     """
 
     async def add(
         self,
         session_id: str,
         message: dict[str, Any],
-        type: MessageType = "truncated",
         agent_id: str | None = None,
         namespace: str | None = None,
         invocation_id: str | None = None,
@@ -53,7 +33,6 @@ class MessageBackend(Protocol):
         Args:
             session_id: Session ID
             message: Message dict (role, content, tool_call_id, etc.)
-            type: Message type - "truncated" or "raw"
             agent_id: Optional agent ID
             namespace: Optional namespace for sub-agent isolation
             invocation_id: Optional invocation ID for grouping
@@ -63,7 +42,6 @@ class MessageBackend(Protocol):
     async def get(
         self,
         session_id: str,
-        type: MessageType = "truncated",
         agent_id: str | None = None,
         namespace: str | None = None,
         limit: int | None = None,
@@ -72,7 +50,6 @@ class MessageBackend(Protocol):
 
         Args:
             session_id: Session ID
-            type: Message type - "truncated" or "raw"
             agent_id: Optional filter by agent
             namespace: Optional namespace filter
             limit: Max messages to return (None = all)
@@ -86,7 +63,6 @@ class MessageBackend(Protocol):
         self,
         session_id: str,
         invocation_id: str,
-        type: MessageType | None = None,
         namespace: str | None = None,
     ) -> int:
         """Delete messages by invocation (for revert).
@@ -94,7 +70,6 @@ class MessageBackend(Protocol):
         Args:
             session_id: Session ID
             invocation_id: Invocation ID to delete
-            type: Message type to delete, None = both types
             namespace: Optional namespace filter
 
         Returns:
@@ -105,14 +80,12 @@ class MessageBackend(Protocol):
     async def clear(
         self,
         session_id: str,
-        type: MessageType | None = None,
         namespace: str | None = None,
     ) -> int:
         """Clear all messages for a session.
 
         Args:
             session_id: Session ID
-            type: Message type to clear, None = both types
             namespace: Optional namespace filter
 
         Returns:
@@ -121,4 +94,4 @@ class MessageBackend(Protocol):
         ...
 
 
-__all__ = ["MessageBackend", "MessageType"]
+__all__ = ["MessageBackend"]
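Since MessageBackend remains a runtime_checkable Protocol, conformance is still checked structurally rather than by inheritance. A small sketch (import path assumed):

    from aury_agent.backends.message import MessageBackend, InMemoryMessageBackend  # assumed path

    backend = InMemoryMessageBackend()
    # runtime_checkable isinstance() only verifies that the protocol's method
    # names exist on the object; signatures are not checked at runtime.
    assert isinstance(backend, MessageBackend)
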
@@ -79,7 +79,6 @@ class MessageContextProvider(BaseContextProvider):
         if ctx.backends is not None and ctx.backends.message is not None:
             messages = await ctx.backends.message.get(
                 session_id=ctx.session.id,
-                type="truncated",
                 limit=self.max_messages,
             )
             # Convert to LLM format (include tool_call_id for tool messages)
@@ -728,11 +728,11 @@ class InvocationContext:
                 **{k: v for k, v in request.items() if k not in ("messages", "stream")}
             ):
                 if self.middleware:
-                    chunk_dict = {"chunk": chunk}
-                    processed = await self.middleware.process_stream_chunk(chunk_dict)
+                    chunk_dict = {"delta": chunk}
+                    processed = await self.middleware.process_text_stream(chunk_dict)
                     if processed is None:
                         continue
-                    chunk = processed.get("chunk", chunk)
+                    chunk = processed.get("delta", chunk)
                 yield chunk
 
         except Exception as e:
@@ -90,14 +90,41 @@ class ToolContext:
 
 @dataclass
 class ToolResult:
-    """Tool execution result for LLM."""
-    output: str
+    """Tool execution result for LLM.
+
+    Supports dual output for context management:
+    - output: Complete output (raw), for storage and recall
+    - truncated_output: Shortened output for context window
+
+    If truncated_output is not provided, it defaults to output.
+    """
+    output: str  # Complete output (raw)
     is_error: bool = False
+    truncated_output: str | None = None  # Shortened output (defaults to output)
+
+    def __post_init__(self):
+        # Default truncated to output if not provided
+        if self.truncated_output is None:
+            self.truncated_output = self.output
 
     @classmethod
-    def success(cls, output: str) -> ToolResult:
-        """Create a successful result."""
-        return cls(output=output, is_error=False)
+    def success(
+        cls,
+        output: str,
+        *,
+        truncated_output: str | None = None,
+    ) -> ToolResult:
+        """Create a successful result.
+
+        Args:
+            output: Complete output (raw)
+            truncated_output: Shortened output for context (defaults to output)
+        """
+        return cls(
+            output=output,
+            is_error=False,
+            truncated_output=truncated_output,
+        )
 
     @classmethod
     def error(cls, message: str) -> ToolResult:
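A hedged usage sketch of the dual-output result (import path assumed; fields and factories as in the hunk above):

    from aury_agent.types import ToolResult  # assumed import path

    full = "\n".join(f"row {i}" for i in range(10_000))
    result = ToolResult.success(
        full,
        truncated_output=full[:2_000] + "\n... (truncated)",
    )
    assert result.output == full                            # raw output kept for storage/recall
    assert result.truncated_output.endswith("(truncated)")  # short form for the context window

    plain = ToolResult.success("ok")
    assert plain.truncated_output == "ok"                   # __post_init__ defaults truncated to output
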
@@ -133,18 +133,31 @@ class LLMMessage:
     - system: System prompt
     - user: User message (can include images)
     - assistant: Assistant response (can include tool_calls)
-    - tool: Tool result (requires tool_call_id)
+    - tool: Tool result (requires tool_call_id and name)
     """
     role: Literal["system", "user", "assistant", "tool"]
     content: str | list[dict[str, Any]]
     tool_call_id: str | None = None  # Required for tool role
+    name: str | None = None  # Tool name, required for Gemini compatibility
 
     def to_dict(self) -> dict[str, Any]:
         d = {"role": self.role, "content": self.content}
         if self.tool_call_id:
             d["tool_call_id"] = self.tool_call_id
+        if self.name:
+            d["name"] = self.name
         return d
 
+    def get(self, key: str, default: Any = None) -> Any:
+        """Dict-like access for middleware compatibility."""
+        return getattr(self, key, default)
+
+    def __getitem__(self, key: str) -> Any:
+        """Dict-like access via []."""
+        if hasattr(self, key):
+            return getattr(self, key)
+        raise KeyError(key)
+
     @classmethod
     def system(cls, content: str) -> "LLMMessage":
         """Create system message."""
@@ -161,9 +174,15 @@ class LLMMessage:
         return cls(role="assistant", content=content)
 
     @classmethod
-    def tool(cls, content: str, tool_call_id: str) -> "LLMMessage":
-        """Create tool result message."""
-        return cls(role="tool", content=content, tool_call_id=tool_call_id)
+    def tool(cls, content: str, tool_call_id: str, name: str | None = None) -> "LLMMessage":
+        """Create tool result message.
+
+        Args:
+            content: Tool result content
+            tool_call_id: ID of the tool call this result is for
+            name: Tool name (required for Gemini compatibility)
+        """
+        return cls(role="tool", content=content, tool_call_id=tool_call_id, name=name)
 
 
 @runtime_checkable
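The new name field and the dict-like accessors combine as follows (import path assumed):

    from aury_agent.types import LLMMessage  # assumed import path

    msg = LLMMessage.tool("42", tool_call_id="call_1", name="calculator")
    assert msg.to_dict() == {
        "role": "tool",
        "content": "42",
        "tool_call_id": "call_1",
        "name": "calculator",  # emitted only when set; needed for Gemini
    }

    # Dict-like access keeps middleware that expects plain dicts working:
    assert msg.get("name") == "calculator"
    assert msg["role"] == "tool"
    assert msg.get("missing", "fallback") == "fallback"
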
@@ -15,11 +15,6 @@ from .store import (
     MessageStore,
     InMemoryMessageStore,
 )
-from .raw_store import (
-    RawMessageStore,
-    StateBackendRawMessageStore,
-    InMemoryRawMessageStore,
-)
 from .config import (
     MessageConfig,
 )
@@ -31,10 +26,6 @@ __all__ = [
     # Store (protocol + in-memory for testing)
     "MessageStore",
     "InMemoryMessageStore",
-    # Raw Store
-    "RawMessageStore",
-    "StateBackendRawMessageStore",
-    "InMemoryRawMessageStore",
     # Config
     "MessageConfig",
 ]
@@ -13,7 +13,6 @@ from .chain import MiddlewareChain
 from .message_container import MessageContainerMiddleware
 from .message import MessageBackendMiddleware
 from .truncation import MessageTruncationMiddleware
-from .raw_message import RawMessageMiddleware
 
 __all__ = [
     "TriggerMode",
@@ -27,5 +26,4 @@ __all__ = [
     "MessageContainerMiddleware",
     "MessageBackendMiddleware",
     "MessageTruncationMiddleware",
-    "RawMessageMiddleware",
 ]
@@ -73,17 +73,28 @@ class Middleware(Protocol):
         """
         ...
 
-    async def on_model_stream(
+    async def on_text_stream(
         self,
         chunk: dict[str, Any],
     ) -> dict[str, Any] | None:
-        """Process streaming chunk (triggered by trigger_mode).
+        """Process text streaming chunk.
 
         Args:
-            chunk: The streaming chunk
+            chunk: The text chunk with {"delta": str}
 
         Returns:
-            Modified chunk, or None to skip further processing
+            Modified chunk, or None to skip
+        """
+        ...
+
+    async def on_text_stream_end(self) -> dict[str, Any] | None:
+        """Called when text stream ends.
+
+        Use this to flush any buffered text content.
+
+        Returns:
+            Optional dict with {"delta": str} to emit final content,
+            or None if no additional content.
         """
         ...
 
@@ -101,6 +112,17 @@ class Middleware(Protocol):
         """
         ...
 
+    async def on_thinking_stream_end(self) -> dict[str, Any] | None:
+        """Called when thinking stream ends.
+
+        Use this to flush any buffered thinking content.
+
+        Returns:
+            Optional dict with {"delta": str} to emit final thinking content,
+            or None if no additional content.
+        """
+        ...
+
     # ========== Agent Lifecycle Hooks ==========
 
     async def on_agent_start(
@@ -283,13 +305,17 @@ class BaseMiddleware:
         """Default: re-raise error."""
         return error
 
-    async def on_model_stream(
+    async def on_text_stream(
         self,
         chunk: dict[str, Any],
     ) -> dict[str, Any] | None:
         """Default: pass through."""
         return chunk
 
+    async def on_text_stream_end(self) -> dict[str, Any] | None:
+        """Default: no additional content."""
+        return None
+
     async def on_thinking_stream(
         self,
         chunk: dict[str, Any],
@@ -297,6 +323,10 @@ class BaseMiddleware:
         """Default: pass through."""
         return chunk
 
+    async def on_thinking_stream_end(self) -> dict[str, Any] | None:
+        """Default: no additional content."""
+        return None
+
     # ========== Agent Lifecycle Hooks ==========
 
     async def on_agent_start(
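A sketch of how the new end-of-stream hook pairs with on_text_stream, e.g. a middleware that holds back partial sentences (the class is hypothetical; the BaseMiddleware import path is assumed):

    from typing import Any

    from aury_agent.middleware import BaseMiddleware  # assumed import path

    class SentenceBufferMiddleware(BaseMiddleware):  # hypothetical example
        """Buffers deltas and releases them only at sentence boundaries."""

        def __init__(self) -> None:
            self._buffer = ""

        async def on_text_stream(self, chunk: dict[str, Any]) -> dict[str, Any] | None:
            self._buffer += chunk.get("delta", "")
            if self._buffer.rstrip().endswith((".", "!", "?")):
                out, self._buffer = self._buffer, ""
                return {"delta": out}
            return None  # swallow the chunk until a boundary arrives

        async def on_text_stream_end(self) -> dict[str, Any] | None:
            # Flush whatever is still buffered when the stream closes.
            if self._buffer:
                out, self._buffer = self._buffer, ""
                return {"delta": out}
            return None
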
@@ -163,32 +163,44 @@ class MiddlewareChain:
             logger.debug("Error processing completed")
         return current
 
-    async def process_stream_chunk(
+    async def process_text_stream(
         self,
         chunk: dict[str, Any],
     ) -> dict[str, Any] | None:
-        """Process streaming chunk through middlewares based on trigger mode."""
-        text = chunk.get("text", chunk.get("delta", ""))
+        """Process text streaming chunk through middlewares based on trigger mode."""
+        text = chunk.get("delta", "")
         self._token_buffer += text
         self._token_count += 1
 
         current = chunk
-        triggered_count = 0
 
         for i, mw in enumerate(self._middlewares):
            should_trigger = self._should_trigger(mw, text)
 
             if should_trigger:
-                triggered_count += 1
-                result = await mw.on_model_stream(current)
+                result = await mw.on_text_stream(current)
                 if result is None:
-                    logger.info(f"Middleware #{i} blocked stream chunk")
                     return None
                 current = result
 
         # Log only every 50 tokens to reduce noise
         if self._token_count % 50 == 0:
-            logger.debug(f"Stream progress: token_count={self._token_count}, middlewares={len(self._middlewares)}")
+            logger.debug(f"Text stream progress: token_count={self._token_count}")
+
+        return current
+
+    async def process_thinking_stream(
+        self,
+        chunk: dict[str, Any],
+    ) -> dict[str, Any] | None:
+        """Process thinking streaming chunk through all middlewares."""
+        current = chunk
+
+        for i, mw in enumerate(self._middlewares):
+            result = await mw.on_thinking_stream(current)
+            if result is None:
+                return None
+            current = result
 
         return current
 
@@ -241,6 +253,50 @@ class MiddlewareChain:
         boundaries = (".", "。", "\n", "!", "?", "！", "？", ";", "；")
         return text.rstrip().endswith(boundaries)
 
+    async def process_text_stream_end(self) -> list[dict[str, Any]]:
+        """Process text stream end through all middlewares.
+
+        Called when text stream ends, before on_response.
+        Allows middlewares to flush any buffered content.
+
+        Returns:
+            List of final chunks to emit (may be empty)
+        """
+        final_chunks: list[dict[str, Any]] = []
+        logger.debug(f"Processing text_stream_end through {len(self._middlewares)} middlewares")
+
+        for i, mw in enumerate(self._middlewares):
+            if hasattr(mw, 'on_text_stream_end'):
+                result = await mw.on_text_stream_end()
+                if result is not None:
+                    logger.debug(f"Middleware #{i} returned final chunk on text_stream_end")
+                    final_chunks.append(result)
+
+        logger.debug(f"Text stream end processing completed, {len(final_chunks)} final chunks")
+        return final_chunks
+
+    async def process_thinking_stream_end(self) -> list[dict[str, Any]]:
+        """Process thinking stream end through all middlewares.
+
+        Called when thinking stream ends.
+        Allows middlewares to flush any buffered thinking content.
+
+        Returns:
+            List of final thinking chunks to emit (may be empty)
+        """
+        final_chunks: list[dict[str, Any]] = []
+        logger.debug(f"Processing thinking_stream_end through {len(self._middlewares)} middlewares")
+
+        for i, mw in enumerate(self._middlewares):
+            if hasattr(mw, 'on_thinking_stream_end'):
+                result = await mw.on_thinking_stream_end()
+                if result is not None:
+                    logger.debug(f"Middleware #{i} returned final chunk on thinking_stream_end")
+                    final_chunks.append(result)
+
+        logger.debug(f"Thinking stream end processing completed, {len(final_chunks)} final chunks")
+        return final_chunks
+
     def reset_stream_state(self) -> None:
         """Reset streaming state (call at start of new stream)."""
         logger.debug("Resetting stream state")