openai-agents 0.2.8__py3-none-any.whl → 0.6.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. agents/__init__.py +105 -4
  2. agents/_debug.py +15 -4
  3. agents/_run_impl.py +1203 -96
  4. agents/agent.py +164 -19
  5. agents/apply_diff.py +329 -0
  6. agents/editor.py +47 -0
  7. agents/exceptions.py +35 -0
  8. agents/extensions/experimental/__init__.py +6 -0
  9. agents/extensions/experimental/codex/__init__.py +92 -0
  10. agents/extensions/experimental/codex/codex.py +89 -0
  11. agents/extensions/experimental/codex/codex_options.py +35 -0
  12. agents/extensions/experimental/codex/codex_tool.py +1142 -0
  13. agents/extensions/experimental/codex/events.py +162 -0
  14. agents/extensions/experimental/codex/exec.py +263 -0
  15. agents/extensions/experimental/codex/items.py +245 -0
  16. agents/extensions/experimental/codex/output_schema_file.py +50 -0
  17. agents/extensions/experimental/codex/payloads.py +31 -0
  18. agents/extensions/experimental/codex/thread.py +214 -0
  19. agents/extensions/experimental/codex/thread_options.py +54 -0
  20. agents/extensions/experimental/codex/turn_options.py +36 -0
  21. agents/extensions/handoff_filters.py +13 -1
  22. agents/extensions/memory/__init__.py +120 -0
  23. agents/extensions/memory/advanced_sqlite_session.py +1285 -0
  24. agents/extensions/memory/async_sqlite_session.py +239 -0
  25. agents/extensions/memory/dapr_session.py +423 -0
  26. agents/extensions/memory/encrypt_session.py +185 -0
  27. agents/extensions/memory/redis_session.py +261 -0
  28. agents/extensions/memory/sqlalchemy_session.py +334 -0
  29. agents/extensions/models/litellm_model.py +449 -36
  30. agents/extensions/models/litellm_provider.py +3 -1
  31. agents/function_schema.py +47 -5
  32. agents/guardrail.py +16 -2
  33. agents/{handoffs.py → handoffs/__init__.py} +89 -47
  34. agents/handoffs/history.py +268 -0
  35. agents/items.py +237 -11
  36. agents/lifecycle.py +75 -14
  37. agents/mcp/server.py +280 -37
  38. agents/mcp/util.py +24 -3
  39. agents/memory/__init__.py +22 -2
  40. agents/memory/openai_conversations_session.py +91 -0
  41. agents/memory/openai_responses_compaction_session.py +249 -0
  42. agents/memory/session.py +19 -261
  43. agents/memory/sqlite_session.py +275 -0
  44. agents/memory/util.py +20 -0
  45. agents/model_settings.py +14 -3
  46. agents/models/__init__.py +13 -0
  47. agents/models/chatcmpl_converter.py +303 -50
  48. agents/models/chatcmpl_helpers.py +63 -0
  49. agents/models/chatcmpl_stream_handler.py +290 -68
  50. agents/models/default_models.py +58 -0
  51. agents/models/interface.py +4 -0
  52. agents/models/openai_chatcompletions.py +103 -49
  53. agents/models/openai_provider.py +10 -4
  54. agents/models/openai_responses.py +162 -46
  55. agents/realtime/__init__.py +4 -0
  56. agents/realtime/_util.py +14 -3
  57. agents/realtime/agent.py +7 -0
  58. agents/realtime/audio_formats.py +53 -0
  59. agents/realtime/config.py +78 -10
  60. agents/realtime/events.py +18 -0
  61. agents/realtime/handoffs.py +2 -2
  62. agents/realtime/items.py +17 -1
  63. agents/realtime/model.py +13 -0
  64. agents/realtime/model_events.py +12 -0
  65. agents/realtime/model_inputs.py +18 -1
  66. agents/realtime/openai_realtime.py +696 -150
  67. agents/realtime/session.py +243 -23
  68. agents/repl.py +7 -3
  69. agents/result.py +197 -38
  70. agents/run.py +949 -168
  71. agents/run_context.py +13 -2
  72. agents/stream_events.py +1 -0
  73. agents/strict_schema.py +14 -0
  74. agents/tool.py +413 -15
  75. agents/tool_context.py +22 -1
  76. agents/tool_guardrails.py +279 -0
  77. agents/tracing/__init__.py +2 -0
  78. agents/tracing/config.py +9 -0
  79. agents/tracing/create.py +4 -0
  80. agents/tracing/processor_interface.py +84 -11
  81. agents/tracing/processors.py +65 -54
  82. agents/tracing/provider.py +64 -7
  83. agents/tracing/spans.py +105 -0
  84. agents/tracing/traces.py +116 -16
  85. agents/usage.py +134 -12
  86. agents/util/_json.py +19 -1
  87. agents/util/_transforms.py +12 -2
  88. agents/voice/input.py +5 -4
  89. agents/voice/models/openai_stt.py +17 -9
  90. agents/voice/pipeline.py +2 -0
  91. agents/voice/pipeline_config.py +4 -0
  92. {openai_agents-0.2.8.dist-info → openai_agents-0.6.8.dist-info}/METADATA +44 -19
  93. openai_agents-0.6.8.dist-info/RECORD +134 -0
  94. {openai_agents-0.2.8.dist-info → openai_agents-0.6.8.dist-info}/WHEEL +1 -1
  95. openai_agents-0.2.8.dist-info/RECORD +0 -103
  96. {openai_agents-0.2.8.dist-info → openai_agents-0.6.8.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,239 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import json
5
+ from collections.abc import AsyncIterator
6
+ from contextlib import asynccontextmanager
7
+ from pathlib import Path
8
+ from typing import cast
9
+
10
+ import aiosqlite
11
+
12
+ from ...items import TResponseInputItem
13
+ from ...memory import SessionABC
14
+
15
+
16
class AsyncSQLiteSession(SessionABC):
    """Async SQLite-based implementation of session storage.

    This implementation stores conversation history in a SQLite database using
    ``aiosqlite``. By default, uses an in-memory database that is lost when the
    process ends. For persistent storage, provide a file path.

    A single connection is opened lazily on first use and shared by every
    operation on this instance; operations on it are serialized with an
    ``asyncio.Lock``. Call :meth:`close` to release the connection.
    """

    def __init__(
        self,
        session_id: str,
        db_path: str | Path = ":memory:",
        sessions_table: str = "agent_sessions",
        messages_table: str = "agent_messages",
    ):
        """Initialize the async SQLite session.

        No connection is opened here; it is created on first use.

        Args:
            session_id: Unique identifier for the conversation session
            db_path: Path to the SQLite database file. Defaults to ':memory:' (in-memory database)
            sessions_table: Name of the table to store session metadata. Defaults to
                'agent_sessions'
            messages_table: Name of the table to store message data. Defaults to 'agent_messages'

        Note:
            Table names are interpolated directly into SQL text, so they must be
            trusted, developer-supplied identifiers (never end-user input).
        """
        self.session_id = session_id
        self.db_path = db_path
        self.sessions_table = sessions_table
        self.messages_table = messages_table
        self._connection: aiosqlite.Connection | None = None
        # Serializes every operation that uses the shared connection.
        self._lock = asyncio.Lock()
        # Guards the lazy creation of the connection in _get_connection.
        self._init_lock = asyncio.Lock()

    async def _init_db_for_connection(self, conn: aiosqlite.Connection) -> None:
        """Create the sessions table, messages table and index if missing, then commit.

        Note: the ``ON DELETE CASCADE`` clause is only enforced by SQLite when
        ``PRAGMA foreign_keys=ON`` is set on the connection, which this class does
        not do; ``clear_session`` therefore deletes message rows explicitly.
        """
        await conn.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self.sessions_table} (
                session_id TEXT PRIMARY KEY,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """
        )

        await conn.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self.messages_table} (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT NOT NULL,
                message_data TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (session_id) REFERENCES {self.sessions_table} (session_id)
                    ON DELETE CASCADE
            )
        """
        )

        # (session_id, id) supports the per-session, id-ordered scans used below.
        await conn.execute(
            f"""
            CREATE INDEX IF NOT EXISTS idx_{self.messages_table}_session_id
            ON {self.messages_table} (session_id, id)
        """
        )

        await conn.commit()

    async def _get_connection(self) -> aiosqlite.Connection:
        """Get the shared database connection, creating and initializing it if needed.

        The new connection is only stored on ``self._connection`` after the WAL
        pragma and schema creation succeed. If setup fails, the connection is
        closed and the error propagates, so no half-initialized connection is
        cached for later calls.
        """
        if self._connection is not None:
            return self._connection

        async with self._init_lock:
            if self._connection is None:
                conn = await aiosqlite.connect(str(self.db_path))
                try:
                    await conn.execute("PRAGMA journal_mode=WAL")
                    await self._init_db_for_connection(conn)
                except BaseException:
                    await conn.close()
                    raise
                self._connection = conn

            return self._connection

    @asynccontextmanager
    async def _locked_connection(self) -> AsyncIterator[aiosqlite.Connection]:
        """Provide the shared connection while holding the session lock."""
        async with self._lock:
            conn = await self._get_connection()
            yield conn

    @asynccontextmanager
    async def _write_transaction(self) -> AsyncIterator[aiosqlite.Connection]:
        """Provide the locked connection for writes, committing on success.

        If the body (or the commit) raises, the pending transaction is rolled
        back. Otherwise a partially applied write would stay open on the shared
        connection and be committed by a later, unrelated operation.
        """
        async with self._locked_connection() as conn:
            try:
                yield conn
                await conn.commit()
            except BaseException:
                await conn.rollback()
                raise

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Retrieve the conversation history for this session.

        Args:
            limit: Maximum number of items to retrieve. If None, retrieves all items.
                When specified, returns the latest N items in chronological order.
                A limit of zero or less returns an empty list. Without this check,
                SQLite would treat a negative ``LIMIT`` as "no limit".

        Returns:
            List of input items representing the conversation history. Rows whose
            stored JSON cannot be decoded are skipped.
        """
        if limit is not None and limit <= 0:
            return []

        async with self._locked_connection() as conn:
            if limit is None:
                cursor = await conn.execute(
                    f"""
                    SELECT message_data FROM {self.messages_table}
                    WHERE session_id = ?
                    ORDER BY id ASC
                """,
                    (self.session_id,),
                )
            else:
                # Fetch the newest `limit` rows; they are flipped back to
                # chronological order after the query.
                cursor = await conn.execute(
                    f"""
                    SELECT message_data FROM {self.messages_table}
                    WHERE session_id = ?
                    ORDER BY id DESC
                    LIMIT ?
                """,
                    (self.session_id, limit),
                )
            try:
                rows = list(await cursor.fetchall())
            finally:
                await cursor.close()

        if limit is not None:
            rows.reverse()

        items: list[TResponseInputItem] = []
        for (message_data,) in rows:
            try:
                items.append(json.loads(message_data))
            except json.JSONDecodeError:
                # Best-effort read: skip a corrupted row instead of failing the call.
                continue

        return items

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Add new items to the conversation history.

        The session row is created if needed, all items are inserted, and the
        session's ``updated_at`` is bumped, all in one transaction. If any step
        fails, the transaction is rolled back.

        Args:
            items: List of input items to add to the history
        """
        if not items:
            return

        # Serialize before touching the database, so an item that cannot be
        # JSON-encoded fails before any statement runs.
        message_data = [(self.session_id, json.dumps(item)) for item in items]

        async with self._write_transaction() as conn:
            await conn.execute(
                f"""
                INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?)
            """,
                (self.session_id,),
            )

            await conn.executemany(
                f"""
                INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?)
            """,
                message_data,
            )

            await conn.execute(
                f"""
                UPDATE {self.sessions_table}
                SET updated_at = CURRENT_TIMESTAMP
                WHERE session_id = ?
            """,
                (self.session_id,),
            )

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the most recent item from the session.

        Returns:
            The most recent item if it exists, None if the session is empty. The
            row is deleted even when its stored JSON is corrupted; in that case
            None is returned.
        """
        async with self._write_transaction() as conn:
            cursor = await conn.execute(
                f"""
                DELETE FROM {self.messages_table}
                WHERE id = (
                    SELECT id FROM {self.messages_table}
                    WHERE session_id = ?
                    ORDER BY id DESC
                    LIMIT 1
                )
                RETURNING message_data
            """,
                (self.session_id,),
            )
            try:
                result = await cursor.fetchone()
            finally:
                await cursor.close()

        if result is None:
            return None

        try:
            return cast(TResponseInputItem, json.loads(result[0]))
        except json.JSONDecodeError:
            return None

    async def clear_session(self) -> None:
        """Clear all items and the session metadata row for this session."""
        async with self._write_transaction() as conn:
            await conn.execute(
                f"DELETE FROM {self.messages_table} WHERE session_id = ?",
                (self.session_id,),
            )
            await conn.execute(
                f"DELETE FROM {self.sessions_table} WHERE session_id = ?",
                (self.session_id,),
            )

    async def close(self) -> None:
        """Close the database connection, if one is open.

        Safe to call more than once. The ``None`` check is done under the lock,
        so concurrent calls cannot both try to close the same connection.
        """
        async with self._lock:
            if self._connection is None:
                return
            conn, self._connection = self._connection, None
            await conn.close()
@@ -0,0 +1,423 @@
1
+ """Dapr State Store-powered Session backend.
2
+
3
+ Usage::
4
+
5
+ from agents.extensions.memory import DaprSession
6
+
7
+ # Create from Dapr sidecar address
8
+ session = DaprSession.from_address(
9
+ session_id="user-123",
10
+ state_store_name="statestore",
11
+ dapr_address="localhost:50001",
12
+ )
13
+
14
+ # Or pass an existing Dapr client that your application already manages
15
+ session = DaprSession(
16
+ session_id="user-123",
17
+ state_store_name="statestore",
18
+ dapr_client=my_dapr_client,
19
+ )
20
+
21
+ await Runner.run(agent, "Hello", session=session)
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import asyncio
27
+ import json
28
+ import random
29
+ import time
30
+ from typing import Any, Final, Literal
31
+
32
+ try:
33
+ from dapr.aio.clients import DaprClient
34
+ from dapr.clients.grpc._state import Concurrency, Consistency, StateOptions
35
+ except ImportError as e:
36
+ raise ImportError(
37
+ "DaprSession requires the 'dapr' package. Install it with: pip install dapr"
38
+ ) from e
39
+
40
+ from ...items import TResponseInputItem
41
+ from ...logger import logger
42
+ from ...memory.session import SessionABC
43
+
44
# Retry policy for conditional (ETag-checked) writes that hit a concurrency
# conflict: at most this many attempts, with exponential backoff that starts at
# the base delay and is capped at the max delay (jitter is added on top).
_MAX_WRITE_ATTEMPTS: Final[int] = 5
_RETRY_BASE_DELAY_SECONDS: Final[float] = 0.05
_RETRY_MAX_DELAY_SECONDS: Final[float] = 1.0

# Accepted consistency settings for DaprSession. DaprSession maps each value
# onto the corresponding Dapr `Consistency` option when building StateOptions.
ConsistencyLevel = Literal["eventual", "strong"]

# Named constants for the two allowed consistency values.
DAPR_CONSISTENCY_STRONG: ConsistencyLevel = "strong"
DAPR_CONSISTENCY_EVENTUAL: ConsistencyLevel = "eventual"
54
+
55
+
56
class DaprSession(SessionABC):
    """Dapr State Store implementation of :pyclass:`agents.memory.session.Session`.

    All messages for a session are stored as one JSON array under the key
    ``"<session_id>:messages"``. Each array element is an individually serialized
    item (see :meth:`_serialize_item`). Session metadata is stored under
    ``"<session_id>:metadata"``. Writes to the message list pass the ETag from
    the preceding read with first-write concurrency, and are retried with
    backoff when the store reports a concurrency conflict.
    """

    def __init__(
        self,
        session_id: str,
        *,
        state_store_name: str,
        dapr_client: DaprClient,
        ttl: int | None = None,
        consistency: ConsistencyLevel = DAPR_CONSISTENCY_EVENTUAL,
    ):
        """Initializes a new DaprSession.

        Args:
            session_id (str): Unique identifier for the conversation.
            state_store_name (str): Name of the Dapr state store component.
            dapr_client (DaprClient): A pre-configured Dapr client.
            ttl (int | None, optional): Time-to-live in seconds for session data.
                If None, data persists indefinitely. Note that TTL support depends on
                the underlying state store implementation. Defaults to None.
            consistency (ConsistencyLevel, optional): Consistency level for state operations.
                Use DAPR_CONSISTENCY_EVENTUAL or DAPR_CONSISTENCY_STRONG constants.
                Defaults to DAPR_CONSISTENCY_EVENTUAL.
        """
        self.session_id = session_id
        self._dapr_client = dapr_client
        self._state_store_name = state_store_name
        self._ttl = ttl
        self._consistency = consistency
        # Serializes operations issued through this instance. Cross-process
        # safety relies on the ETag-conditional writes instead.
        self._lock = asyncio.Lock()
        self._owns_client = False  # Track if we own the Dapr client

        # State keys
        self._messages_key = f"{self.session_id}:messages"
        self._metadata_key = f"{self.session_id}:metadata"

    @classmethod
    def from_address(
        cls,
        session_id: str,
        *,
        state_store_name: str,
        dapr_address: str = "localhost:50001",
        **kwargs: Any,
    ) -> DaprSession:
        """Create a session from a Dapr sidecar address.

        Args:
            session_id (str): Conversation ID.
            state_store_name (str): Name of the Dapr state store component.
            dapr_address (str): Dapr sidecar gRPC address. Defaults to "localhost:50001".
            **kwargs: Additional keyword arguments forwarded to the main constructor
                (e.g., ttl, consistency).

        Returns:
            DaprSession: An instance of DaprSession connected to the specified Dapr
            sidecar. The session owns the client and closes it in :meth:`close`.

        Note:
            The Dapr Python SDK performs health checks on the HTTP endpoint (default: http://localhost:3500).
            Ensure the Dapr sidecar is started with --dapr-http-port 3500. Alternatively, set one of
            these environment variables: DAPR_HTTP_ENDPOINT (e.g., "http://localhost:3500") or
            DAPR_HTTP_PORT (e.g., "3500") to avoid connection errors.
        """
        dapr_client = DaprClient(address=dapr_address)
        session = cls(
            session_id, state_store_name=state_store_name, dapr_client=dapr_client, **kwargs
        )
        session._owns_client = True  # We created the client, so we own it
        return session

    def _get_read_metadata(self) -> dict[str, str]:
        """Get metadata for read operations including consistency.

        The consistency level is passed through state_metadata as per Dapr's state API.
        """
        metadata: dict[str, str] = {}
        # Add consistency level to metadata for read operations
        if self._consistency:
            metadata["consistency"] = self._consistency
        return metadata

    def _get_state_options(self, *, concurrency: Concurrency | None = None) -> StateOptions | None:
        """Get StateOptions configured with consistency and optional concurrency.

        Returns None when neither a recognized consistency level nor a
        concurrency mode is set.
        """
        options_kwargs: dict[str, Any] = {}
        if self._consistency == DAPR_CONSISTENCY_STRONG:
            options_kwargs["consistency"] = Consistency.strong
        elif self._consistency == DAPR_CONSISTENCY_EVENTUAL:
            options_kwargs["consistency"] = Consistency.eventual
        if concurrency is not None:
            options_kwargs["concurrency"] = concurrency
        if options_kwargs:
            return StateOptions(**options_kwargs)
        return None

    def _get_metadata(self) -> dict[str, str]:
        """Get metadata for write operations, including the TTL if configured."""
        metadata: dict[str, str] = {}
        if self._ttl is not None:
            metadata["ttlInSeconds"] = str(self._ttl)
        return metadata

    async def _serialize_item(self, item: TResponseInputItem) -> str:
        """Serialize an item to JSON string. Can be overridden by subclasses."""
        return json.dumps(item, separators=(",", ":"))

    async def _deserialize_item(self, item: str) -> TResponseInputItem:
        """Deserialize a JSON string to an item. Can be overridden by subclasses."""
        return json.loads(item)  # type: ignore[no-any-return]

    def _decode_messages(self, data: bytes | None) -> list[Any]:
        """Decode the stored message array. Empty or invalid data yields ``[]``."""
        if not data:
            return []
        try:
            messages_json = data.decode("utf-8")
            messages = json.loads(messages_json)
            if isinstance(messages, list):
                return list(messages)
        except (json.JSONDecodeError, UnicodeDecodeError):
            return []
        return []

    def _calculate_retry_delay(self, attempt: int) -> float:
        """Exponential backoff for 1-based ``attempt``, capped, plus up to 10% jitter."""
        base: float = _RETRY_BASE_DELAY_SECONDS * (2 ** max(0, attempt - 1))
        delay: float = min(base, _RETRY_MAX_DELAY_SECONDS)
        # Add jitter (10%) similar to tracing processors to avoid thundering herd.
        return delay + random.uniform(0, 0.1 * delay)

    def _is_concurrency_conflict(self, error: Exception) -> bool:
        """Heuristically decide whether ``error`` is an optimistic-concurrency conflict.

        First checks a callable ``code()`` attribute for the ABORTED or
        FAILED_PRECONDITION status names. Otherwise falls back to known
        substrings in the error message.
        """
        code_attr = getattr(error, "code", None)
        if callable(code_attr):
            try:
                status_code = code_attr()
            except Exception:
                status_code = None
            if status_code is not None:
                status_name = getattr(status_code, "name", str(status_code))
                if status_name in {"ABORTED", "FAILED_PRECONDITION"}:
                    return True
        message = str(error).lower()
        conflict_markers = (
            "etag mismatch",
            "etag does not match",
            "precondition failed",
            "concurrency conflict",
            "invalid etag",
            "failed to set key",  # Redis state store Lua script error during conditional write
            "user_script",  # Redis script failure hint
        )
        return any(marker in message for marker in conflict_markers)

    async def _handle_concurrency_conflict(self, error: Exception, attempt: int) -> bool:
        """Sleep with backoff and return True if the failed write should be retried."""
        if not self._is_concurrency_conflict(error):
            return False
        if attempt >= _MAX_WRITE_ATTEMPTS:
            return False
        delay = self._calculate_retry_delay(attempt)
        if delay > 0:
            await asyncio.sleep(delay)
        return True

    async def _read_messages_with_etag(self) -> tuple[list[Any], str | None]:
        """Read the stored message array together with the ETag for a conditional write."""
        response = await self._dapr_client.get_state(
            store_name=self._state_store_name,
            key=self._messages_key,
            state_metadata=self._get_read_metadata(),
        )
        # Normalize a falsy ETag (e.g. an empty string) to None so no ETag is sent.
        etag = getattr(response, "etag", None) or None
        return self._decode_messages(response.data), etag

    async def _save_messages(self, messages: list[Any], etag: str | None) -> None:
        """Write the message array with first-write concurrency against ``etag``."""
        await self._dapr_client.save_state(
            store_name=self._state_store_name,
            key=self._messages_key,
            value=json.dumps(messages, separators=(",", ":")),
            etag=etag,
            state_metadata=self._get_metadata(),
            options=self._get_state_options(concurrency=Concurrency.first_write),
        )

    async def _read_created_at(self) -> str | None:
        """Return the ``created_at`` value from stored session metadata, if present."""
        response = await self._dapr_client.get_state(
            store_name=self._state_store_name,
            key=self._metadata_key,
            state_metadata=self._get_read_metadata(),
        )
        if not response.data:
            return None
        try:
            metadata = json.loads(response.data.decode("utf-8"))
        except (json.JSONDecodeError, UnicodeDecodeError):
            return None
        if isinstance(metadata, dict):
            created_at = metadata.get("created_at")
            if isinstance(created_at, str):
                return created_at
        return None

    # ------------------------------------------------------------------
    # Session protocol implementation
    # ------------------------------------------------------------------

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Retrieve the conversation history for this session.

        Args:
            limit: Maximum number of items to retrieve. If None, retrieves all items.
                When specified, returns the latest N items in chronological order.
                A limit of zero or less returns an empty list.

        Returns:
            List of input items representing the conversation history. Entries
            that fail to deserialize are skipped.
        """
        async with self._lock:
            # Get messages from state store with consistency level
            response = await self._dapr_client.get_state(
                store_name=self._state_store_name,
                key=self._messages_key,
                state_metadata=self._get_read_metadata(),
            )

            messages = self._decode_messages(response.data)
            if not messages:
                return []
            if limit is not None:
                if limit <= 0:
                    return []
                messages = messages[-limit:]
            items: list[TResponseInputItem] = []
            for msg in messages:
                try:
                    if isinstance(msg, str):
                        item = await self._deserialize_item(msg)
                    else:
                        item = msg
                    items.append(item)
                except (json.JSONDecodeError, TypeError):
                    continue
            return items

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Add new items to the conversation history.

        The message array is appended with an ETag-conditional write and
        retried on concurrency conflicts. After that, the session metadata is
        updated. An existing ``created_at`` is preserved, and ``updated_at`` is
        set to the current time.

        Args:
            items: List of input items to add to the history
        """
        if not items:
            return

        async with self._lock:
            serialized_items: list[str] = [await self._serialize_item(item) for item in items]
            attempt = 0
            while True:
                attempt += 1
                existing_messages, etag = await self._read_messages_with_etag()
                try:
                    await self._save_messages(existing_messages + serialized_items, etag)
                    break
                except Exception as error:
                    if await self._handle_concurrency_conflict(error, attempt):
                        continue
                    raise

            # Update metadata. Keep the first recorded creation time instead of
            # resetting it on every write.
            now = str(int(time.time()))
            created_at = await self._read_created_at() or now
            metadata = {
                "session_id": self.session_id,
                "created_at": created_at,
                "updated_at": now,
            }
            await self._dapr_client.save_state(
                store_name=self._state_store_name,
                key=self._metadata_key,
                value=json.dumps(metadata),
                state_metadata=self._get_metadata(),
                options=self._get_state_options(),
            )

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the most recent item from the session.

        Returns:
            The most recent item if it exists, None if the session is empty. If
            the removed entry fails to deserialize, it is still removed and None
            is returned.
        """
        async with self._lock:
            attempt = 0
            while True:
                attempt += 1
                messages, etag = await self._read_messages_with_etag()
                if not messages:
                    return None
                last_item = messages.pop()
                try:
                    await self._save_messages(messages, etag)
                    break
                except Exception as error:
                    if await self._handle_concurrency_conflict(error, attempt):
                        continue
                    raise
            try:
                if isinstance(last_item, str):
                    return await self._deserialize_item(last_item)
                return last_item  # type: ignore[no-any-return]
            except (json.JSONDecodeError, TypeError):
                return None

    async def clear_session(self) -> None:
        """Clear all items for this session by deleting the messages and metadata keys."""
        async with self._lock:
            # Delete messages and metadata keys
            await self._dapr_client.delete_state(
                store_name=self._state_store_name,
                key=self._messages_key,
                options=self._get_state_options(),
            )

            await self._dapr_client.delete_state(
                store_name=self._state_store_name,
                key=self._metadata_key,
                options=self._get_state_options(),
            )

    async def close(self) -> None:
        """Close the Dapr client connection.

        Only closes the connection if this session owns the Dapr client
        (i.e., created via from_address). If the client was injected externally,
        the caller is responsible for managing its lifecycle.
        """
        if self._owns_client:
            await self._dapr_client.close()

    async def __aenter__(self) -> DaprSession:
        """Enter async context manager."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Exit async context manager and close the connection."""
        await self.close()

    async def ping(self) -> bool:
        """Test Dapr connectivity by reading (and, if needed, writing) a probe key.

        Returns:
            True if Dapr is reachable, False otherwise.
        """
        try:
            # First attempt a read; some stores may not be initialized yet.
            await self._dapr_client.get_state(
                store_name=self._state_store_name,
                key="__ping__",
                state_metadata=self._get_read_metadata(),
            )
            return True
        except Exception as initial_error:
            # If relation/table is missing or store isn't initialized,
            # attempt a write to initialize it, then read again.
            try:
                await self._dapr_client.save_state(
                    store_name=self._state_store_name,
                    key="__ping__",
                    value="ok",
                    state_metadata=self._get_metadata(),
                    options=self._get_state_options(),
                )
                # Read again after write.
                await self._dapr_client.get_state(
                    store_name=self._state_store_name,
                    key="__ping__",
                    state_metadata=self._get_read_metadata(),
                )
                return True
            except Exception as retry_error:
                # Report both failures; the second is often the more informative one.
                logger.error(
                    "Dapr connection failed: %s (write-then-read retry also failed: %s)",
                    initial_error,
                    retry_error,
                )
                return False