aury-agent 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,11 @@
- """Middleware protocol and base implementation."""
+ """Middleware protocol and base implementation.
+
+ Middleware can access InvocationContext via get_current_ctx_or_none() for:
+ - session_id, invocation_id, agent_id, agent_name
+ - backends, metadata, etc.
+
+ Middleware should use self._xxx for internal state between hooks.
+ """
  from __future__ import annotations

  from typing import Any, Protocol, runtime_checkable, TYPE_CHECKING
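
The new module docstring describes the intended access pattern: instead of receiving a context dict, a middleware pulls the active InvocationContext from get_current_ctx_or_none(). A minimal sketch of that pattern is shown below; the absolute import paths and the AuditMiddleware class are assumptions for illustration (the diff only shows the relative import from ..core.context), while BaseMiddleware, HookResult, and the ctx fields come from the code in this diff.

    # Sketch only: module paths are assumed, AuditMiddleware is hypothetical.
    from aury_agent.core.context import get_current_ctx_or_none  # assumed absolute path
    from aury_agent.middleware import BaseMiddleware, HookResult  # assumed module layout

    class AuditMiddleware(BaseMiddleware):
        """Log which agent/session is running, using the ambient context."""

        async def on_agent_start(self, input_data):
            ctx = get_current_ctx_or_none()  # replaces the removed `context` parameter
            if ctx is not None:
                print(f"agent={ctx.agent_id} session={ctx.session_id} invocation={ctx.invocation_id}")
            return HookResult.proceed()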
@@ -14,6 +21,7 @@ class Middleware(Protocol):
  """Middleware protocol for request/response processing.

  Includes both LLM request/response hooks and agent lifecycle hooks.
+ Use get_current_ctx_or_none() to access InvocationContext.
  """

  @property
@@ -26,13 +34,11 @@ class Middleware(Protocol):
  async def on_request(
  self,
  request: dict[str, Any],
- context: dict[str, Any],
  ) -> dict[str, Any] | None:
  """Process request before LLM call.

  Args:
  request: The request to process
- context: Execution context

  Returns:
  Modified request, or None to skip further processing
@@ -42,13 +48,11 @@ class Middleware(Protocol):
  async def on_response(
  self,
  response: dict[str, Any],
- context: dict[str, Any],
  ) -> dict[str, Any] | None:
  """Process response after LLM call.

  Args:
  response: The response to process
- context: Execution context

  Returns:
  Modified response, or None to skip further processing
@@ -58,13 +62,11 @@ class Middleware(Protocol):
  async def on_error(
  self,
  error: Exception,
- context: dict[str, Any],
  ) -> Exception | None:
  """Handle errors.

  Args:
  error: The exception that occurred
- context: Execution context

  Returns:
  Modified exception, or None to suppress
@@ -74,13 +76,11 @@ class Middleware(Protocol):
  async def on_model_stream(
  self,
  chunk: dict[str, Any],
- context: dict[str, Any],
  ) -> dict[str, Any] | None:
  """Process streaming chunk (triggered by trigger_mode).

  Args:
  chunk: The streaming chunk
- context: Execution context

  Returns:
  Modified chunk, or None to skip further processing
@@ -90,13 +90,11 @@ class Middleware(Protocol):
  async def on_thinking_stream(
  self,
  chunk: dict[str, Any],
- context: dict[str, Any],
  ) -> dict[str, Any] | None:
  """Process thinking stream chunk.

  Args:
  chunk: The thinking chunk with {"delta": str}
- context: Execution context

  Returns:
  Modified chunk, or None to skip
@@ -107,16 +105,14 @@ class Middleware(Protocol):

  async def on_agent_start(
  self,
- agent_id: str,
  input_data: Any,
- context: dict[str, Any],
  ) -> HookResult:
  """Called when agent starts processing.

+ Use get_current_ctx_or_none() to access agent_id, session_id, etc.
+
  Args:
- agent_id: The agent identifier
  input_data: Input to the agent
- context: Execution context

  Returns:
  HookResult controlling execution flow
@@ -125,16 +121,14 @@ class Middleware(Protocol):

  async def on_agent_end(
  self,
- agent_id: str,
  result: Any,
- context: dict[str, Any],
  ) -> HookResult:
  """Called when agent completes processing.

+ Use get_current_ctx_or_none() to access agent_id, session_id, etc.
+
  Args:
- agent_id: The agent identifier
  result: Agent's result
- context: Execution context

  Returns:
  HookResult (only CONTINUE/STOP meaningful here)
@@ -145,14 +139,12 @@ class Middleware(Protocol):
  self,
  tool: "BaseTool",
  params: dict[str, Any],
- context: dict[str, Any],
  ) -> HookResult:
  """Called before tool execution.

  Args:
  tool: The tool to be called
  params: Tool parameters
- context: Execution context

  Returns:
  HookResult - SKIP to skip tool, RETRY to modify params
@@ -164,7 +156,7 @@ class Middleware(Protocol):
  call_id: str,
  tool_name: str,
  delta: dict[str, Any],
- context: dict[str, Any],
+ accumulated_args: dict[str, Any],
  ) -> dict[str, Any] | None:
  """Called during streaming tool argument generation.

@@ -175,7 +167,7 @@ class Middleware(Protocol):
  call_id: Tool call identifier
  tool_name: Name of the tool being called
  delta: Incremental parameter update (e.g. {"content": "more text"})
- context: Execution context with 'accumulated_args' containing current state
+ accumulated_args: Current accumulated arguments state

  Returns:
  Modified delta, or None to skip emitting this delta
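
Under the new signature the accumulated argument state is passed in directly rather than being read from context["accumulated_args"]. A hedged sketch of an override (the class name is hypothetical; imports as in the first sketch above):

    class RedactToolArgsMiddleware(BaseMiddleware):
        """Hypothetical: stop emitting tool-argument deltas once a marker appears."""

        async def on_tool_call_delta(self, call_id, tool_name, delta, accumulated_args):
            # accumulated_args replaces the old context["accumulated_args"]
            if "SECRET" in str(accumulated_args) or "SECRET" in str(delta):
                return None  # skip emitting this delta
            return delta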
@@ -186,14 +178,12 @@ class Middleware(Protocol):
  self,
  tool: "BaseTool",
  result: "ToolResult",
- context: dict[str, Any],
  ) -> HookResult:
  """Called after tool execution.

  Args:
  tool: The tool that was called
  result: Tool execution result
- context: Execution context

  Returns:
  HookResult - RETRY to re-execute tool
@@ -205,7 +195,6 @@ class Middleware(Protocol):
  parent_agent_id: str,
  child_agent_id: str,
  mode: str, # "embedded" or "delegated"
- context: dict[str, Any],
  ) -> HookResult:
  """Called when delegating to a sub-agent.

@@ -213,7 +202,6 @@ class Middleware(Protocol):
  parent_agent_id: Parent agent identifier
  child_agent_id: Child agent identifier
  mode: Delegation mode
- context: Execution context

  Returns:
  HookResult - SKIP to skip delegation
@@ -225,7 +213,6 @@ class Middleware(Protocol):
  parent_agent_id: str,
  child_agent_id: str,
  result: Any,
- context: dict[str, Any],
  ) -> HookResult:
  """Called when sub-agent completes.

@@ -233,7 +220,6 @@ class Middleware(Protocol):
  parent_agent_id: Parent agent identifier
  child_agent_id: Child agent identifier
  result: Sub-agent's result
- context: Execution context

  Returns:
  HookResult (for post-processing)
@@ -243,7 +229,6 @@ class Middleware(Protocol):
  async def on_message_save(
  self,
  message: dict[str, Any],
- context: dict[str, Any],
  ) -> dict[str, Any] | None:
  """Called before saving a message to history.

@@ -252,7 +237,6 @@ class Middleware(Protocol):

  Args:
  message: Message dict with 'role', 'content', etc.
- context: Execution context

  Returns:
  Modified message, or None to skip saving
@@ -265,6 +249,9 @@ class BaseMiddleware:

  Subclass and override specific hooks as needed.
  All hooks have sensible pass-through defaults.
+
+ Use get_current_ctx_or_none() to access InvocationContext.
+ Use self._xxx for internal state between hooks.
  """

  _config: MiddlewareConfig = MiddlewareConfig()
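
The BaseMiddleware docstring now also points at self._xxx attributes for state that must survive between hooks, since the context dict that used to carry it is gone. A hedged sketch, reusing the assumed imports from the first example:

    import time

    class TimingMiddleware(BaseMiddleware):
        """Hypothetical: time an agent run via internal state on the instance."""

        async def on_agent_start(self, input_data):
            self._started_at = time.monotonic()  # self._xxx keeps state between hooks
            return HookResult.proceed()

        async def on_agent_end(self, result):
            ctx = get_current_ctx_or_none()
            agent = ctx.agent_name if ctx else "unknown"
            print(f"{agent} finished in {time.monotonic() - self._started_at:.2f}s")
            return HookResult.proceed()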
@@ -278,7 +265,6 @@ class BaseMiddleware:
  async def on_request(
  self,
  request: dict[str, Any],
- context: dict[str, Any],
  ) -> dict[str, Any] | None:
  """Default: pass through."""
  return request
@@ -286,7 +272,6 @@ class BaseMiddleware:
  async def on_response(
  self,
  response: dict[str, Any],
- context: dict[str, Any],
  ) -> dict[str, Any] | None:
  """Default: pass through."""
  return response
@@ -294,7 +279,6 @@ class BaseMiddleware:
  async def on_error(
  self,
  error: Exception,
- context: dict[str, Any],
  ) -> Exception | None:
  """Default: re-raise error."""
  return error
@@ -302,7 +286,6 @@ class BaseMiddleware:
  async def on_model_stream(
  self,
  chunk: dict[str, Any],
- context: dict[str, Any],
  ) -> dict[str, Any] | None:
  """Default: pass through."""
  return chunk
@@ -310,7 +293,6 @@ class BaseMiddleware:
  async def on_thinking_stream(
  self,
  chunk: dict[str, Any],
- context: dict[str, Any],
  ) -> dict[str, Any] | None:
  """Default: pass through."""
  return chunk
@@ -319,18 +301,14 @@ class BaseMiddleware:

  async def on_agent_start(
  self,
- agent_id: str,
  input_data: Any,
- context: dict[str, Any],
  ) -> HookResult:
  """Default: continue."""
  return HookResult.proceed()

  async def on_agent_end(
  self,
- agent_id: str,
  result: Any,
- context: dict[str, Any],
  ) -> HookResult:
  """Default: continue."""
  return HookResult.proceed()
@@ -339,7 +317,6 @@ class BaseMiddleware:
  self,
  tool: "BaseTool",
  params: dict[str, Any],
- context: dict[str, Any],
  ) -> HookResult:
  """Default: continue."""
  return HookResult.proceed()
@@ -349,7 +326,7 @@ class BaseMiddleware:
  call_id: str,
  tool_name: str,
  delta: dict[str, Any],
- context: dict[str, Any],
+ accumulated_args: dict[str, Any],
  ) -> dict[str, Any] | None:
  """Default: pass through."""
  return delta
@@ -358,7 +335,6 @@ class BaseMiddleware:
  self,
  tool: "BaseTool",
  result: "ToolResult",
- context: dict[str, Any],
  ) -> HookResult:
  """Default: continue."""
  return HookResult.proceed()
@@ -368,7 +344,6 @@ class BaseMiddleware:
  parent_agent_id: str,
  child_agent_id: str,
  mode: str,
- context: dict[str, Any],
  ) -> HookResult:
  """Default: continue."""
  return HookResult.proceed()
@@ -378,7 +353,6 @@ class BaseMiddleware:
  parent_agent_id: str,
  child_agent_id: str,
  result: Any,
- context: dict[str, Any],
  ) -> HookResult:
  """Default: continue."""
  return HookResult.proceed()
@@ -386,7 +360,6 @@ class BaseMiddleware:
  async def on_message_save(
  self,
  message: dict[str, Any],
- context: dict[str, Any],
  ) -> dict[str, Any] | None:
  """Default: pass through."""
  return message
@@ -112,14 +112,13 @@ class MiddlewareChain:
  async def process_request(
  self,
  request: dict[str, Any],
- context: dict[str, Any],
  ) -> dict[str, Any] | None:
  """Process request through all middlewares."""
  current = request
  logger.debug(f"Processing request through {len(self._middlewares)} middlewares")

  for i, mw in enumerate(self._middlewares):
- result = await mw.on_request(current, context)
+ result = await mw.on_request(current)
  if result is None:
  logger.info(f"Middleware #{i} blocked request")
  return None
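
On the chain side the hooks are now invoked with the payload only, and a None return from any middleware short-circuits the whole pass. How middlewares get registered on a MiddlewareChain is not visible in this diff, so the helper below only shows the call and its None semantics as a hedged sketch:

    async def run_request(chain, request):
        # chain: a MiddlewareChain; construction/registration is not shown in this diff.
        processed = await chain.process_request(request)  # no context argument anymore
        if processed is None:
            raise RuntimeError("request blocked by middleware")
        return processed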
@@ -131,14 +130,13 @@ class MiddlewareChain:
  async def process_response(
  self,
  response: dict[str, Any],
- context: dict[str, Any],
  ) -> dict[str, Any] | None:
  """Process response through all middlewares (reverse order)."""
  current = response
  logger.debug(f"Processing response through {len(self._middlewares)} middlewares (reverse order)")

  for i, mw in enumerate(reversed(self._middlewares)):
- result = await mw.on_response(current, context)
+ result = await mw.on_response(current)
  if result is None:
  logger.info(f"Middleware #{i} blocked response")
  return None
@@ -150,14 +148,13 @@ class MiddlewareChain:
  async def process_error(
  self,
  error: Exception,
- context: dict[str, Any],
  ) -> Exception | None:
  """Process error through all middlewares."""
  current = error
  logger.debug(f"Processing error {type(error).__name__} through {len(self._middlewares)} middlewares")

  for i, mw in enumerate(self._middlewares):
- result = await mw.on_error(current, context)
+ result = await mw.on_error(current)
  if result is None:
  logger.info(f"Middleware #{i} suppressed error")
  return None
@@ -169,13 +166,11 @@ class MiddlewareChain:
  async def process_stream_chunk(
  self,
  chunk: dict[str, Any],
- context: dict[str, Any],
  ) -> dict[str, Any] | None:
  """Process streaming chunk through middlewares based on trigger mode."""
  text = chunk.get("text", chunk.get("delta", ""))
  self._token_buffer += text
  self._token_count += 1
- logger.debug(f"Processing stream chunk, token_count={self._token_count}, triggered_middlewares=?")

  current = chunk
  triggered_count = 0
@@ -185,14 +180,16 @@ class MiddlewareChain:

  if should_trigger:
  triggered_count += 1
- logger.debug(f"Middleware #{i} triggered, mode={mw.config.trigger_mode}")
- result = await mw.on_model_stream(current, context)
+ result = await mw.on_model_stream(current)
  if result is None:
  logger.info(f"Middleware #{i} blocked stream chunk")
  return None
  current = result

- logger.debug(f"Stream chunk processing completed, {triggered_count} middlewares triggered")
+ # Log only every 50 tokens to reduce noise
+ if self._token_count % 50 == 0:
+ logger.debug(f"Stream progress: token_count={self._token_count}, middlewares={len(self._middlewares)}")
+
  return current

  async def process_tool_call_delta(
@@ -200,7 +197,7 @@ class MiddlewareChain:
  call_id: str,
  tool_name: str,
  delta: dict[str, Any],
- context: dict[str, Any],
+ accumulated_args: dict[str, Any],
  ) -> dict[str, Any] | None:
  """Process tool call delta through all middlewares.

@@ -208,7 +205,7 @@ class MiddlewareChain:
  call_id: Tool call identifier
  tool_name: Name of the tool being called
  delta: Incremental parameter update
- context: Execution context with 'accumulated_args'
+ accumulated_args: Current accumulated arguments state

  Returns:
  Modified delta, or None to skip emitting
@@ -217,7 +214,7 @@ class MiddlewareChain:
  logger.debug(f"Processing tool_call_delta for {tool_name} (call_id={call_id}) through {len(self._middlewares)} middlewares")

  for i, mw in enumerate(self._middlewares):
- result = await mw.on_tool_call_delta(call_id, tool_name, current, context)
+ result = await mw.on_tool_call_delta(call_id, tool_name, current, accumulated_args)
  if result is None:
  logger.info(f"Middleware #{i} blocked tool_call_delta")
  return None
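
MiddlewareChain.process_tool_call_delta now takes the running accumulated_args instead of a context dict, which means the caller owns the accumulation. A hedged sketch of that caller side (the helper name and the accumulation-by-concatenation are assumptions; the chain call matches the signature above):

    async def forward_tool_deltas(chain, call_id: str, tool_name: str, deltas):
        """Hypothetical caller: feed streamed tool-argument deltas through the chain."""
        accumulated_args: dict = {}
        for delta in deltas:  # e.g. [{"content": "more "}, {"content": "text"}]
            out = await chain.process_tool_call_delta(call_id, tool_name, delta, accumulated_args)
            if out is None:
                continue  # a middleware skipped this delta
            for key, piece in out.items():
                accumulated_args[key] = accumulated_args.get(key, "") + str(piece)
        return accumulated_args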
@@ -259,19 +256,17 @@ class MiddlewareChain:

  async def process_agent_start(
  self,
- agent_id: str,
  input_data: Any,
- context: dict[str, Any],
  ) -> HookResult:
  """Process agent start through all middlewares.

  Returns:
  First non-CONTINUE result, or CONTINUE if all pass
  """
- logger.debug(f"Processing agent_start for agent_id={agent_id}, {len(self._middlewares)} middlewares")
+ logger.debug(f"Processing agent_start, {len(self._middlewares)} middlewares")
  for i, mw in enumerate(self._middlewares):
  if hasattr(mw, 'on_agent_start'):
- result = await mw.on_agent_start(agent_id, input_data, context)
+ result = await mw.on_agent_start(input_data)
  if result.action != HookAction.CONTINUE:
  logger.info(f"Middleware #{i} returned {result.action} on agent_start")
  return result
@@ -280,15 +275,13 @@ class MiddlewareChain:

  async def process_agent_end(
  self,
- agent_id: str,
  result: Any,
- context: dict[str, Any],
  ) -> HookResult:
  """Process agent end through all middlewares (reverse order)."""
- logger.debug(f"Processing agent_end for agent_id={agent_id}, {len(self._middlewares)} middlewares (reverse order)")
+ logger.debug(f"Processing agent_end, {len(self._middlewares)} middlewares (reverse order)")
  for i, mw in enumerate(reversed(self._middlewares)):
  if hasattr(mw, 'on_agent_end'):
- hook_result = await mw.on_agent_end(agent_id, result, context)
+ hook_result = await mw.on_agent_end(result)
  if hook_result.action != HookAction.CONTINUE:
  logger.info(f"Middleware #{i} returned {hook_result.action} on agent_end")
  return hook_result
@@ -299,7 +292,6 @@ class MiddlewareChain:
  self,
  tool: "BaseTool",
  params: dict[str, Any],
- context: dict[str, Any],
  ) -> HookResult:
  """Process tool call through all middlewares.

@@ -309,7 +301,7 @@ class MiddlewareChain:
  logger.debug(f"Processing tool_call for tool={tool.name}, {len(self._middlewares)} middlewares")
  for i, mw in enumerate(self._middlewares):
  if hasattr(mw, 'on_tool_call'):
- result = await mw.on_tool_call(tool, params, context)
+ result = await mw.on_tool_call(tool, params)
  if result.action != HookAction.CONTINUE:
  logger.info(f"Middleware #{i} returned {result.action} on tool_call for tool={tool.name}")
  return result
@@ -320,13 +312,12 @@ class MiddlewareChain:
  self,
  tool: "BaseTool",
  result: "ToolResult",
- context: dict[str, Any],
  ) -> HookResult:
  """Process tool end through all middlewares (reverse order)."""
  logger.debug(f"Processing tool_end for tool={tool.name}, {len(self._middlewares)} middlewares (reverse order)")
  for i, mw in enumerate(reversed(self._middlewares)):
  if hasattr(mw, 'on_tool_end'):
- hook_result = await mw.on_tool_end(tool, result, context)
+ hook_result = await mw.on_tool_end(tool, result)
  if hook_result.action != HookAction.CONTINUE:
  logger.info(f"Middleware #{i} returned {hook_result.action} on tool_end for tool={tool.name}")
  return hook_result
@@ -338,14 +329,13 @@ class MiddlewareChain:
  parent_agent_id: str,
  child_agent_id: str,
  mode: str,
- context: dict[str, Any],
  ) -> HookResult:
  """Process sub-agent start through all middlewares."""
  logger.debug(f"Processing subagent_start, parent={parent_agent_id}, child={child_agent_id}, mode={mode}, {len(self._middlewares)} middlewares")
  for i, mw in enumerate(self._middlewares):
  if hasattr(mw, 'on_subagent_start'):
  result = await mw.on_subagent_start(
- parent_agent_id, child_agent_id, mode, context
+ parent_agent_id, child_agent_id, mode
  )
  if result.action != HookAction.CONTINUE:
  logger.info(f"Middleware #{i} returned {result.action} on subagent_start")
@@ -358,14 +348,13 @@ class MiddlewareChain:
  parent_agent_id: str,
  child_agent_id: str,
  result: Any,
- context: dict[str, Any],
  ) -> HookResult:
  """Process sub-agent end through all middlewares (reverse order)."""
  logger.debug(f"Processing subagent_end, parent={parent_agent_id}, child={child_agent_id}, {len(self._middlewares)} middlewares (reverse order)")
  for i, mw in enumerate(reversed(self._middlewares)):
  if hasattr(mw, 'on_subagent_end'):
  hook_result = await mw.on_subagent_end(
- parent_agent_id, child_agent_id, result, context
+ parent_agent_id, child_agent_id, result
  )
  if hook_result.action != HookAction.CONTINUE:
  logger.info(f"Middleware #{i} returned {hook_result.action} on subagent_end")
@@ -376,13 +365,11 @@ class MiddlewareChain:
  async def process_message_save(
  self,
  message: dict[str, Any],
- context: dict[str, Any],
  ) -> dict[str, Any] | None:
  """Process message save through all middlewares.

  Args:
  message: Message to be saved
- context: Execution context

  Returns:
  Modified message, or None to skip saving
@@ -392,7 +379,7 @@ class MiddlewareChain:

  for i, mw in enumerate(self._middlewares):
  if hasattr(mw, 'on_message_save'):
- result = await mw.on_message_save(current, context)
+ result = await mw.on_message_save(current)
  if result is None:
  logger.info(f"Middleware #{i} blocked message save for role={message.get('role')}")
  return None
@@ -56,37 +56,34 @@ class MessageBackendMiddleware(BaseMiddleware):
  async def on_message_save(
  self,
  message: dict[str, Any],
- context: dict[str, Any],
  ) -> dict[str, Any] | None:
  """Save message via backends.message.

  Args:
  message: Message dict with 'role', 'content', etc.
- context: Execution context with 'session_id', 'agent_id', 'backends'

  Returns:
  The message (pass through to other middlewares)
  """
  from ..core.context import get_current_ctx_or_none

- session_id = context.get("session_id", "")
- if not session_id:
- return message
-
  # Get MessageBackend from context
  ctx = get_current_ctx_or_none()
  if ctx is None or ctx.backends is None or ctx.backends.message is None:
  # No backend available, pass through
  return message

+ session_id = ctx.session_id or ""
+ if not session_id:
+ return message
+
  backend = ctx.backends.message

  # Extract message fields
  role = message.get("role", "")
  content = message.get("content", "")
- invocation_id = context.get("invocation_id", "")
- agent_id = context.get("agent_id")
- namespace = context.get("namespace")
+ invocation_id = ctx.invocation_id or ""
+ agent_id = ctx.agent_id
  tool_call_id = message.get("tool_call_id")

  # Build message dict for backend
@@ -107,7 +104,6 @@ class MessageBackendMiddleware(BaseMiddleware):
  message=msg_dict,
  type="truncated",
  agent_id=agent_id,
- namespace=namespace,
  invocation_id=invocation_id,
  )

@@ -118,7 +114,6 @@ class MessageBackendMiddleware(BaseMiddleware):
  message=message, # Full original message
  type="raw",
  agent_id=agent_id,
- namespace=namespace,
  invocation_id=invocation_id,
  )
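
MessageBackendMiddleware now reads session_id, invocation_id, and agent_id from the ambient context rather than from a context argument, and the namespace field is dropped. For comparison, a custom on_message_save under the new signature reduces to something like the hypothetical sketch below (returning None still means "do not save"; imports as in the first sketch):

    class DropToolMessagesMiddleware(BaseMiddleware):
        """Hypothetical: keep tool-role messages out of persisted history."""

        async def on_message_save(self, message):
            if message.get("role") == "tool":
                return None  # skip saving this message
            return message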