chuk-ai-session-manager 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chuk_ai_session_manager/__init__.py +57 -0
- chuk_ai_session_manager/exceptions.py +129 -0
- chuk_ai_session_manager/infinite_conversation.py +316 -0
- chuk_ai_session_manager/models/__init__.py +44 -0
- chuk_ai_session_manager/models/event_source.py +8 -0
- chuk_ai_session_manager/models/event_type.py +9 -0
- chuk_ai_session_manager/models/session.py +316 -0
- chuk_ai_session_manager/models/session_event.py +166 -0
- chuk_ai_session_manager/models/session_metadata.py +37 -0
- chuk_ai_session_manager/models/session_run.py +115 -0
- chuk_ai_session_manager/models/token_usage.py +316 -0
- chuk_ai_session_manager/sample_tools.py +194 -0
- chuk_ai_session_manager/session_aware_tool_processor.py +178 -0
- chuk_ai_session_manager/session_prompt_builder.py +474 -0
- chuk_ai_session_manager/storage/__init__.py +44 -0
- chuk_ai_session_manager/storage/base.py +50 -0
- chuk_ai_session_manager/storage/providers/__init__.py +0 -0
- chuk_ai_session_manager/storage/providers/file.py +348 -0
- chuk_ai_session_manager/storage/providers/memory.py +96 -0
- chuk_ai_session_manager/storage/providers/redis.py +295 -0
- chuk_ai_session_manager-0.1.1.dist-info/METADATA +501 -0
- chuk_ai_session_manager-0.1.1.dist-info/RECORD +24 -0
- chuk_ai_session_manager-0.1.1.dist-info/WHEEL +5 -0
- chuk_ai_session_manager-0.1.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,316 @@
|
|
|
1
|
+
# chuk_ai_session_manager/models/session.py
|
|
2
|
+
"""
|
|
3
|
+
Session model for the chuk session manager with improved async support.
|
|
4
|
+
"""
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
from datetime import datetime, timezone
|
|
7
|
+
from typing import Any, Dict, List, Optional, Generic, TypeVar, Union
|
|
8
|
+
from uuid import uuid4
|
|
9
|
+
from pydantic import BaseModel, Field, model_validator
|
|
10
|
+
import asyncio
|
|
11
|
+
|
|
12
|
+
# Import models that Session depends on
|
|
13
|
+
from chuk_ai_session_manager.models.session_metadata import SessionMetadata
|
|
14
|
+
from chuk_ai_session_manager.models.session_event import SessionEvent
|
|
15
|
+
from chuk_ai_session_manager.models.token_usage import TokenUsage, TokenSummary
|
|
16
|
+
# Import SessionRun and RunStatus directly to avoid circular import
|
|
17
|
+
from chuk_ai_session_manager.models.session_run import SessionRun, RunStatus
|
|
18
|
+
|
|
19
|
+
# Type of the message payload carried by the SessionEvents stored in a Session.
MessageT = TypeVar("MessageT")
|
|
20
|
+
|
|
21
|
+
class Session(BaseModel, Generic[MessageT]):
    """A standalone conversation session with hierarchical support and async methods.

    Sessions form a tree through ``parent_id`` / ``child_ids``. Every
    operation that touches persistence goes through the store returned by
    ``SessionStoreProvider.get_store()`` and is therefore ``async``.
    """

    id: str = Field(default_factory=lambda: str(uuid4()))
    metadata: SessionMetadata = Field(default_factory=SessionMetadata)

    # Hierarchy links: IDs only, resolved through the store on demand.
    parent_id: Optional[str] = None
    child_ids: List[str] = Field(default_factory=list)

    task_ids: List[str] = Field(default_factory=list)
    runs: List[SessionRun] = Field(default_factory=list)
    events: List[SessionEvent[MessageT]] = Field(default_factory=list)
    state: Dict[str, Any] = Field(default_factory=dict)

    # Token tracking, accumulated by add_event().
    token_summary: TokenSummary = Field(default_factory=TokenSummary)

    @model_validator(mode="after")
    def _sync_hierarchy(self) -> Session:
        """Post-validation hook; deliberately a no-op.

        Pydantic validators run synchronously, but registering this session
        with its parent requires the async store. That work is done by
        :meth:`async_init`, which :meth:`create` calls automatically.
        """
        # Pydantic v2 "after" validators receive the validated model as
        # ``self`` and must return it.
        return self

    @staticmethod
    def _get_store() -> Any:
        """Return the active session store.

        The storage package is imported lazily to avoid a circular import.
        """
        from chuk_ai_session_manager.storage import SessionStoreProvider

        return SessionStoreProvider.get_store()

    async def async_init(self) -> None:
        """Register this session in its parent's ``child_ids``.

        Call this after constructing a session directly; :meth:`create` does
        it for you. Nothing happens if there is no ``parent_id``, if the
        parent is not found in the store, or if it already lists this session.
        """
        if not self.parent_id:
            return
        store = self._get_store()
        parent = await store.get(self.parent_id)
        if parent and self.id not in parent.child_ids:
            parent.child_ids.append(self.id)
            await store.save(parent)

    @property
    def last_update_time(self) -> datetime:
        """Return timestamp of most recent event, or session creation."""
        if not self.events:
            return self.metadata.created_at
        return max(evt.timestamp for evt in self.events)

    @property
    def active_run(self) -> Optional[SessionRun]:
        """Return the most recently added run whose status is RUNNING, if any."""
        for run in reversed(self.runs):
            if run.status == RunStatus.RUNNING:
                return run
        return None

    @property
    def total_tokens(self) -> int:
        """Return the total number of tokens used in this session."""
        return self.token_summary.total_tokens

    @property
    def total_cost(self) -> float:
        """Return the total estimated cost of this session (USD)."""
        return self.token_summary.total_estimated_cost_usd

    async def add_child(self, child_id: str) -> None:
        """Add ``child_id`` to ``child_ids`` (if absent) and save this session.

        The session is saved even when the ID was already present.
        """
        if child_id not in self.child_ids:
            self.child_ids.append(child_id)
        await self._get_store().save(self)

    async def remove_child(self, child_id: str) -> None:
        """Remove ``child_id`` from ``child_ids`` (if present) and save this session.

        The session is saved even when the ID was not present.
        """
        if child_id in self.child_ids:
            self.child_ids.remove(child_id)
        await self._get_store().save(self)

    async def ancestors(self) -> List[Session]:
        """Fetch ancestor sessions from the store, nearest parent first.

        Stops at the first parent ID the store cannot resolve. A visited set
        guards against cyclic ``parent_id`` chains, which would otherwise
        loop forever.
        """
        result: List[Session] = []
        seen = {self.id}
        store = self._get_store()

        current = self.parent_id
        while current and current not in seen:
            parent = await store.get(current)
            if not parent:
                break
            seen.add(current)
            result.append(parent)
            current = parent.parent_id
        return result

    async def descendants(self) -> List[Session]:
        """Fetch all descendant sessions from the store in DFS order.

        Child IDs the store cannot resolve are skipped. Each session is
        visited at most once, so cyclic or shared ``child_ids`` neither loop
        forever nor produce duplicates.
        """
        result: List[Session] = []
        seen = {self.id}
        stack = list(self.child_ids)
        store = self._get_store()

        while stack:
            cid = stack.pop()
            if cid in seen:
                continue
            seen.add(cid)
            child = await store.get(cid)
            if child:
                result.append(child)
                stack.extend(child.child_ids)
        return result

    async def add_event(self, event: SessionEvent[MessageT]) -> None:
        """Append ``event`` and fold its token usage (if any) into ``token_summary``.

        Does not persist the session; use :meth:`add_event_and_save` for that.

        Args:
            event: The event to add.
        """
        self.events.append(event)
        if event.token_usage:
            await self.token_summary.add_usage(event.token_usage)

    async def add_event_and_save(self, event: SessionEvent[MessageT]) -> None:
        """Add an event (see :meth:`add_event`) and then save the session.

        Args:
            event: The event to add.
        """
        await self.add_event(event)
        await self._get_store().save(self)

    async def get_token_usage_by_source(self) -> Dict[str, TokenSummary]:
        """Group token usage by event source.

        Events without ``token_usage`` are ignored.

        Returns:
            A dict mapping ``event.source.value`` to an aggregated TokenSummary.
        """
        result: Dict[str, TokenSummary] = {}
        for event in self.events:
            if not event.token_usage:
                continue
            source = event.source.value
            if source not in result:
                result[source] = TokenSummary()
            await result[source].add_usage(event.token_usage)
        return result

    async def get_token_usage_by_run(self) -> Dict[str, TokenSummary]:
        """Group token usage by run.

        The grouping key is each event's ``task_id``. Events without a
        ``task_id`` are grouped under ``"no_run"``, and that entry is always
        present, even when empty. Events without ``token_usage`` are ignored.

        Returns:
            A dict mapping run (task) IDs to aggregated TokenSummary objects.
        """
        result: Dict[str, TokenSummary] = {"no_run": TokenSummary()}
        for event in self.events:
            if not event.token_usage:
                continue
            run_id = event.task_id or "no_run"
            if run_id not in result:
                result[run_id] = TokenSummary()
            await result[run_id].add_usage(event.token_usage)
        return result

    async def count_message_tokens(
        self,
        message: Union[str, Dict[str, Any]],
        model: str = "gpt-3.5-turbo",
    ) -> int:
        """Count tokens in a message.

        Args:
            message: A plain string, or a dict with a ``"content"`` key
                (e.g. a chat message). Anything else is counted via ``str()``.
            model: The model whose tokenizer is used for counting.

        Returns:
            The number of tokens, as reported by ``TokenUsage.count_tokens``.
        """
        if isinstance(message, str):
            return await TokenUsage.count_tokens(message, model)

        if isinstance(message, dict) and "content" in message:
            # NOTE(review): content is passed through as-is; a None or
            # list-valued content relies on count_tokens handling it.
            return await TokenUsage.count_tokens(message["content"], model)

        return await TokenUsage.count_tokens(str(message), model)

    async def set_state(self, key: str, value: Any) -> None:
        """Set ``state[key] = value`` (in memory only; the session is not saved).

        Args:
            key: The state key to set.
            value: The value to set.
        """
        self.state[key] = value

    async def get_state(self, key: str, default: Any = None) -> Any:
        """Return ``state[key]``, or ``default`` if the key is absent.

        Args:
            key: The state key to retrieve.
            default: Value returned when the key is not found.
        """
        return self.state.get(key, default)

    async def has_state(self, key: str) -> bool:
        """Return True if ``key`` exists in ``state``.

        Args:
            key: The state key to check.
        """
        return key in self.state

    async def remove_state(self, key: str) -> None:
        """Remove ``key`` from ``state`` if present (in memory only).

        Args:
            key: The state key to remove.
        """
        self.state.pop(key, None)

    @classmethod
    async def create(cls, parent_id: Optional[str] = None, **kwargs) -> Session:
        """Create, register with its parent, and save a new session.

        Args:
            parent_id: Optional parent session ID.
            **kwargs: Additional field values for Session initialization.

        Returns:
            The new, saved Session instance.
        """
        session = cls(parent_id=parent_id, **kwargs)
        await session.async_init()
        await cls._get_store().save(session)
        return session
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
# chuk_ai_session_manager/models/session_event.py
|
|
2
|
+
"""
|
|
3
|
+
Session event model for the chuk session manager with improved async support.
|
|
4
|
+
"""
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
from datetime import datetime, timezone
|
|
7
|
+
from typing import Any, Dict, Generic, Optional, TypeVar, Union
|
|
8
|
+
from uuid import uuid4
|
|
9
|
+
from pydantic import BaseModel, Field, ConfigDict
|
|
10
|
+
|
|
11
|
+
# session manager
|
|
12
|
+
from chuk_ai_session_manager.models.event_source import EventSource
|
|
13
|
+
from chuk_ai_session_manager.models.event_type import EventType
|
|
14
|
+
from chuk_ai_session_manager.models.token_usage import TokenUsage
|
|
15
|
+
|
|
16
|
+
# Generic type for event message content
|
|
17
|
+
MessageT = TypeVar("MessageT")  # parameterizes SessionEvent.message
|
|
18
|
+
|
|
19
|
+
class SessionEvent(BaseModel, Generic[MessageT]):
    """An event in a session.

    ``message`` holds the event payload, typed by the generic parameter.
    ``token_usage`` is optional; when present, the owning Session folds it
    into its token summary.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    id: str = Field(default_factory=lambda: str(uuid4()))
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    message: Optional[MessageT] = None
    task_id: Optional[str] = None
    type: EventType = EventType.MESSAGE
    source: EventSource = EventSource.LLM
    metadata: Dict[str, Any] = Field(default_factory=dict)

    # Token usage for this event, if tracked.
    token_usage: Optional[TokenUsage] = None

    @classmethod
    async def create_with_tokens(
        cls,
        message: MessageT,
        prompt: str,
        completion: Optional[str] = None,
        model: str = "gpt-3.5-turbo",
        source: EventSource = EventSource.LLM,
        type: EventType = EventType.MESSAGE,
        task_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> SessionEvent[MessageT]:
        """Create a session event whose token usage is computed from text.

        Args:
            message: The message content.
            prompt: The prompt text used (for token counting).
            completion: The completion text (for token counting).
            model: The model used for this interaction.
            source: The source of this event.
            type: The type of this event.
            task_id: Optional task ID this event is associated with.
            metadata: Optional additional metadata (a new dict if omitted).

        Returns:
            A new SessionEvent with ``token_usage`` from ``TokenUsage.from_text``.
        """
        token_usage = await TokenUsage.from_text(prompt, completion, model)
        return cls(
            message=message,
            task_id=task_id,
            type=type,
            source=source,
            metadata=metadata or {},
            token_usage=token_usage,
        )

    async def update_token_usage(
        self,
        prompt: Optional[str] = None,
        completion: Optional[str] = None,
        model: Optional[str] = None
    ) -> None:
        """Recount tokens, and cost when a model is known, for this event.

        Args:
            prompt: Prompt text; when non-empty, its count replaces ``prompt_tokens``.
            completion: Completion text; when non-empty, its count replaces
                ``completion_tokens``.
            model: The model to use for counting tokens and costing. When
                given, it replaces any model already recorded on
                ``token_usage``; otherwise the recorded model is used.
        """
        if self.token_usage is None:
            self.token_usage = TokenUsage(model=model or "")
        elif model:
            # An explicitly requested model wins, so counting and cost are
            # computed with the model the caller asked for rather than a
            # previously recorded one.
            self.token_usage.model = model

        usage = self.token_usage

        if prompt:
            usage.prompt_tokens = await TokenUsage.count_tokens(prompt, usage.model)

        if completion:
            usage.completion_tokens = await TokenUsage.count_tokens(completion, usage.model)

        usage.total_tokens = usage.prompt_tokens + usage.completion_tokens
        # Cost is only recomputed when a model is known.
        if usage.model:
            usage.estimated_cost_usd = await usage.calculate_cost()

    async def get_metadata(self, key: str, default: Any = None) -> Any:
        """Return ``metadata[key]``, or ``default`` if the key is absent.

        Args:
            key: The metadata key to retrieve.
            default: Value returned when the key is not found.
        """
        return self.metadata.get(key, default)

    async def set_metadata(self, key: str, value: Any) -> None:
        """Set ``metadata[key] = value``.

        Args:
            key: The metadata key to set.
            value: The value to set.
        """
        self.metadata[key] = value

    async def has_metadata(self, key: str) -> bool:
        """Return True if ``key`` exists in ``metadata``.

        Args:
            key: The metadata key to check.
        """
        return key in self.metadata

    async def remove_metadata(self, key: str) -> None:
        """Remove ``key`` from ``metadata`` if present.

        Args:
            key: The metadata key to remove.
        """
        self.metadata.pop(key, None)

    async def update_metadata(self, key: str, value: Any) -> None:
        """Alias for :meth:`set_metadata`, kept for backward compatibility.

        Args:
            key: The metadata key to set.
            value: The value to set.
        """
        await self.set_metadata(key, value)
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# chuk_ai_session_manager/models/session_metadata.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
from datetime import datetime, timezone
|
|
4
|
+
from typing import Any, Dict, Optional
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SessionMetadata(BaseModel):
    """Core metadata associated with a session: timestamps plus free-form properties."""

    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

    # Free-form properties for session-level identifiers and custom data.
    properties: Dict[str, Any] = Field(default_factory=dict)

    async def set_property(self, key: str, value: Any) -> None:
        """Add or update a custom property and bump ``updated_at`` (UTC)."""
        self.properties[key] = value
        self.updated_at = datetime.now(timezone.utc)

    async def get_property(self, key: str, default: Any = None) -> Any:
        """Return the property stored under ``key``, or ``default`` if absent.

        ``default`` is optional and defaults to None, so existing
        single-argument callers are unaffected.
        """
        return self.properties.get(key, default)

    async def update_timestamp(self) -> None:
        """Set ``updated_at`` to the current UTC time."""
        self.updated_at = datetime.now(timezone.utc)

    @classmethod
    async def create(cls, properties: Optional[Dict[str, Any]] = None) -> SessionMetadata:
        """Create a new instance whose ``created_at`` and ``updated_at`` share one timestamp."""
        now = datetime.now(timezone.utc)
        return cls(
            created_at=now,
            updated_at=now,
            properties=properties or {},
        )
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
# chuk_ai_session_manager/models/session_run.py
|
|
2
|
+
"""
|
|
3
|
+
Session run model for the chuk session manager with improved async support.
|
|
4
|
+
"""
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
from datetime import datetime, timezone
|
|
7
|
+
from enum import Enum
|
|
8
|
+
from typing import Any, Dict, Optional, List
|
|
9
|
+
from uuid import uuid4
|
|
10
|
+
from pydantic import BaseModel, Field, ConfigDict
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class RunStatus(str, Enum):
    """Lifecycle state of a session run.

    Because the enum also subclasses ``str``, members compare equal to their
    raw string values (``RunStatus.RUNNING == "running"``).
    """

    # Initial state assigned by SessionRun.create().
    PENDING = "pending"
    # Set by SessionRun.mark_running().
    RUNNING = "running"
    # Final states; the matching mark_* methods also stamp ended_at.
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SessionRun(BaseModel):
    """A single execution or "run" within a session.

    Tracks lifecycle status, start/end timestamps, free-form metadata and the
    IDs of tool-call events associated with the run.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    id: str = Field(default_factory=lambda: str(uuid4()))
    started_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    ended_at: Optional[datetime] = None
    status: RunStatus = RunStatus.PENDING
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # IDs of associated tool call events.
    tool_calls: List[str] = Field(default_factory=list)

    @classmethod
    async def create(cls, metadata: Optional[Dict[str, Any]] = None) -> SessionRun:
        """Create a new PENDING run with the given metadata (a new dict if omitted)."""
        return cls(
            status=RunStatus.PENDING,
            metadata=metadata or {},
        )

    async def mark_running(self) -> None:
        """Set status to RUNNING and reset ``started_at`` to now (UTC)."""
        self.status = RunStatus.RUNNING
        self.started_at = datetime.now(timezone.utc)

    async def mark_completed(self) -> None:
        """Set status to COMPLETED and stamp ``ended_at`` (UTC)."""
        self.status = RunStatus.COMPLETED
        self.ended_at = datetime.now(timezone.utc)

    async def mark_failed(self, reason: Optional[str] = None) -> None:
        """Set status to FAILED, stamp ``ended_at``, and record a non-empty reason
        under the ``"failure_reason"`` metadata key."""
        self.status = RunStatus.FAILED
        self.ended_at = datetime.now(timezone.utc)
        if reason:
            await self.set_metadata("failure_reason", reason)

    async def mark_cancelled(self, reason: Optional[str] = None) -> None:
        """Set status to CANCELLED, stamp ``ended_at``, and record a non-empty
        reason under the ``"cancel_reason"`` metadata key."""
        self.status = RunStatus.CANCELLED
        self.ended_at = datetime.now(timezone.utc)
        if reason:
            await self.set_metadata("cancel_reason", reason)

    async def set_metadata(self, key: str, value: Any) -> None:
        """Set ``metadata[key] = value``."""
        self.metadata[key] = value

    async def get_metadata(self, key: str, default: Any = None) -> Any:
        """Return ``metadata[key]``, or ``default`` if the key is absent."""
        return self.metadata.get(key, default)

    async def has_metadata(self, key: str) -> bool:
        """Return True if ``key`` exists in ``metadata``."""
        return key in self.metadata

    async def remove_metadata(self, key: str) -> None:
        """Remove ``key`` from ``metadata`` if present."""
        self.metadata.pop(key, None)

    async def get_duration(self) -> Optional[float]:
        """Return the run duration in seconds, or None if the run has not ended."""
        if self.ended_at is None:
            return None
        return (self.ended_at - self.started_at).total_seconds()

    async def add_tool_call(self, tool_call_id: str) -> None:
        """Associate a tool call event ID with this run (duplicates are ignored)."""
        if tool_call_id not in self.tool_calls:
            self.tool_calls.append(tool_call_id)

    async def get_tool_calls(self, session: Any) -> List[Any]:
        """Return the events in ``session.events`` whose IDs are associated with this run.

        ``session`` is typed as Any to avoid a circular import; it only needs
        an ``events`` iterable whose items expose ``id``. Events are returned
        in session order.
        """
        # Build the ID set once so the scan is O(events) rather than
        # O(events * tool_calls).
        wanted = set(self.tool_calls)
        return [event for event in session.events if event.id in wanted]

    async def to_dict(self) -> Dict[str, Any]:
        """Serialize the run to a plain dict.

        ``ended_at`` and ``duration`` are included only once the run has
        ended. ``metadata`` and ``tool_calls`` are shallow copies, so mutating
        the result does not alter the run.
        """
        result: Dict[str, Any] = {
            "id": self.id,
            "status": self.status.value,
            "started_at": self.started_at.isoformat(),
            "metadata": dict(self.metadata),
            "tool_calls": list(self.tool_calls),
        }

        if self.ended_at is not None:
            result["ended_at"] = self.ended_at.isoformat()
            result["duration"] = await self.get_duration()

        return result
|