chuk-ai-session-manager 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,316 @@
1
+ # chuk_ai_session_manager/models/session.py
2
+ """
3
+ Session model for the chuk session manager with improved async support.
4
+ """
5
+ from __future__ import annotations
6
+ from datetime import datetime, timezone
7
+ from typing import Any, Dict, List, Optional, Generic, TypeVar, Union
8
+ from uuid import uuid4
9
+ from pydantic import BaseModel, Field, model_validator
10
+ import asyncio
11
+
12
+ # Import models that Session depends on
13
+ from chuk_ai_session_manager.models.session_metadata import SessionMetadata
14
+ from chuk_ai_session_manager.models.session_event import SessionEvent
15
+ from chuk_ai_session_manager.models.token_usage import TokenUsage, TokenSummary
16
+ # Import SessionRun and RunStatus directly to avoid circular import
17
+ from chuk_ai_session_manager.models.session_run import SessionRun, RunStatus
18
+
19
# Type parameter for the message payload carried by the session's events.
MessageT = TypeVar("MessageT")
20
+
21
class Session(BaseModel, Generic[MessageT]):
    """A standalone conversation session with hierarchical support and async methods.

    Sessions can be arranged in a tree via ``parent_id`` / ``child_ids``.
    Methods that read or write sessions in storage go through the store
    returned by ``SessionStoreProvider.get_store()`` and are coroutines.
    """

    id: str = Field(default_factory=lambda: str(uuid4()))
    metadata: SessionMetadata = Field(default_factory=SessionMetadata)

    # Hierarchy links: id of the parent session (if any) and ids of direct children.
    parent_id: Optional[str] = None
    child_ids: List[str] = Field(default_factory=list)

    task_ids: List[str] = Field(default_factory=list)
    runs: List[SessionRun] = Field(default_factory=list)
    events: List[SessionEvent[MessageT]] = Field(default_factory=list)
    # Free-form key/value state; see set_state()/get_state()/has_state()/remove_state().
    state: Dict[str, Any] = Field(default_factory=dict)

    # Token tracking, accumulated by add_event() from each event's token_usage.
    token_summary: TokenSummary = Field(default_factory=TokenSummary)

    @model_validator(mode="after")
    def _sync_hierarchy(self) -> Session:
        """Post-construction hook; intentionally a no-op.

        Syncing this session with its parent needs async store access, which a
        synchronous Pydantic validator cannot perform. Call ``async_init()``
        (or build the session with ``Session.create()``) instead.
        """
        return self

    @staticmethod
    def _get_store() -> Any:
        """Return the active session store.

        The storage package is imported lazily to avoid a circular import.
        """
        from chuk_ai_session_manager.storage import SessionStoreProvider

        return SessionStoreProvider.get_store()

    async def async_init(self) -> None:
        """Register this session with its parent in the async store.

        Call this after creating a new session. If ``parent_id`` is set and the
        parent is found in the store without this session's id in its
        ``child_ids``, the id is appended and the parent is saved. Sessions
        without a parent are left untouched.
        """
        if not self.parent_id:
            return
        store = self._get_store()
        parent = await store.get(self.parent_id)
        if parent and self.id not in parent.child_ids:
            parent.child_ids.append(self.id)
            await store.save(parent)

    @property
    def last_update_time(self) -> datetime:
        """Return the timestamp of the most recent event, or the session creation time."""
        if not self.events:
            return self.metadata.created_at
        return max(evt.timestamp for evt in self.events)

    @property
    def active_run(self) -> Optional[SessionRun]:
        """Return the most recently added run whose status is RUNNING, if any."""
        for run in reversed(self.runs):
            if run.status == RunStatus.RUNNING:
                return run
        return None

    @property
    def total_tokens(self) -> int:
        """Return the total number of tokens used in this session."""
        return self.token_summary.total_tokens

    @property
    def total_cost(self) -> float:
        """Return the total estimated cost (USD) of this session."""
        return self.token_summary.total_estimated_cost_usd

    async def add_child(self, child_id: str) -> None:
        """Record ``child_id`` as a child (if not already present) and save this session."""
        if child_id not in self.child_ids:
            self.child_ids.append(child_id)
        await self._get_store().save(self)

    async def remove_child(self, child_id: str) -> None:
        """Remove ``child_id`` from the children (if present) and save this session."""
        if child_id in self.child_ids:
            self.child_ids.remove(child_id)
        await self._get_store().save(self)

    async def ancestors(self) -> List[Session]:
        """Fetch ancestor sessions from the store, nearest parent first.

        Stops at the first parent id the store cannot resolve. Also stops if the
        parent chain revisits a session, so corrupt (cyclic) links cannot cause
        an infinite loop.
        """
        store = self._get_store()
        result: List[Session] = []
        seen = {self.id}
        current = self.parent_id
        while current and current not in seen:
            seen.add(current)
            parent = await store.get(current)
            if not parent:
                break
            result.append(parent)
            current = parent.parent_id
        return result

    async def descendants(self) -> List[Session]:
        """Fetch all descendant sessions from the store in DFS order.

        Child ids the store cannot resolve are skipped. Each session is visited
        at most once, so cyclic or shared child links cannot cause an infinite
        loop or duplicates.
        """
        store = self._get_store()
        result: List[Session] = []
        seen = {self.id}
        stack = list(self.child_ids)
        while stack:
            cid = stack.pop()
            if cid in seen:
                continue
            seen.add(cid)
            child = await store.get(cid)
            if child:
                result.append(child)
                stack.extend(child.child_ids)
        return result

    async def add_event(self, event: SessionEvent[MessageT]) -> None:
        """Append an event and fold its token usage (if any) into ``token_summary``.

        Args:
            event: The event to add.
        """
        self.events.append(event)
        if event.token_usage:
            await self.token_summary.add_usage(event.token_usage)

    async def add_event_and_save(self, event: SessionEvent[MessageT]) -> None:
        """Add an event (see ``add_event``) and then save the session to the store.

        Args:
            event: The event to add.
        """
        await self.add_event(event)
        await self._get_store().save(self)

    async def get_token_usage_by_source(self) -> Dict[str, TokenSummary]:
        """Aggregate token usage per event source.

        Returns:
            A dict mapping ``event.source.value`` to a TokenSummary. Only events
            that carry token usage contribute.
        """
        result: Dict[str, TokenSummary] = {}
        for event in self.events:
            if not event.token_usage:
                continue
            source = event.source.value
            if source not in result:
                result[source] = TokenSummary()
            await result[source].add_usage(event.token_usage)
        return result

    async def get_token_usage_by_run(self) -> Dict[str, TokenSummary]:
        """Aggregate token usage per run.

        Events are grouped by their ``task_id``. Events without one land under
        the ``"no_run"`` key, which is always present, even if empty.

        Returns:
            A dict mapping run ids (or ``"no_run"``) to TokenSummary objects.
        """
        result: Dict[str, TokenSummary] = {"no_run": TokenSummary()}
        for event in self.events:
            if not event.token_usage:
                continue
            run_id = event.task_id or "no_run"
            if run_id not in result:
                result[run_id] = TokenSummary()
            await result[run_id].add_usage(event.token_usage)
        return result

    async def count_message_tokens(
        self,
        message: Union[str, Dict[str, Any]],
        model: str = "gpt-3.5-turbo",
    ) -> int:
        """Count the tokens in a message.

        Args:
            message: A string, a dict with a ``"content"`` key (e.g. a chat
                message), or any other object (counted via ``str()``).
            model: The model whose tokenizer is used for counting.

        Returns:
            The number of tokens reported by ``TokenUsage.count_tokens``.
        """
        if isinstance(message, str):
            return await TokenUsage.count_tokens(message, model)
        if isinstance(message, dict) and "content" in message:
            return await TokenUsage.count_tokens(message["content"], model)
        return await TokenUsage.count_tokens(str(message), model)

    async def set_state(self, key: str, value: Any) -> None:
        """Set a state value in memory.

        The session is not saved automatically; persist it explicitly if needed.

        Args:
            key: The state key to set.
            value: The value to store.
        """
        self.state[key] = value

    async def get_state(self, key: str, default: Any = None) -> Any:
        """Return the state value for ``key``, or ``default`` if it is absent."""
        return self.state.get(key, default)

    async def has_state(self, key: str) -> bool:
        """Return True if ``key`` exists in the session state."""
        return key in self.state

    async def remove_state(self, key: str) -> None:
        """Remove ``key`` from the state if present.

        The session is not saved automatically; persist it explicitly if needed.
        """
        if key in self.state:
            del self.state[key]

    @classmethod
    async def create(cls, parent_id: Optional[str] = None, **kwargs) -> Session:
        """Create, link, and persist a new session.

        Args:
            parent_id: Optional parent session id.
            **kwargs: Additional arguments forwarded to the Session constructor.

        Returns:
            The new Session. It has been registered with its parent (via
            ``async_init``) and saved to the store.
        """
        session = cls(parent_id=parent_id, **kwargs)
        await session.async_init()
        await cls._get_store().save(session)
        return session
@@ -0,0 +1,166 @@
1
+ # chuk_ai_session_manager/models/session_event.py
2
+ """
3
+ Session event model for the chuk session manager with improved async support.
4
+ """
5
+ from __future__ import annotations
6
+ from datetime import datetime, timezone
7
+ from typing import Any, Dict, Generic, Optional, TypeVar, Union
8
+ from uuid import uuid4
9
+ from pydantic import BaseModel, Field, ConfigDict
10
+
11
+ # session manager
12
+ from chuk_ai_session_manager.models.event_source import EventSource
13
+ from chuk_ai_session_manager.models.event_type import EventType
14
+ from chuk_ai_session_manager.models.token_usage import TokenUsage
15
+
16
+ # Generic type for event message content
17
# Type parameter for the payload stored in SessionEvent.message.
MessageT = TypeVar("MessageT")
18
+
19
class SessionEvent(BaseModel, Generic[MessageT]):
    """A single event recorded in a session.

    Holds an optional message payload plus bookkeeping (type, source, task id,
    free-form metadata) and optional token-usage accounting.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    id: str = Field(default_factory=lambda: str(uuid4()))
    # Creation time as a timezone-aware UTC datetime.
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    message: Optional[MessageT] = None
    task_id: Optional[str] = None
    type: EventType = EventType.MESSAGE
    source: EventSource = EventSource.LLM
    metadata: Dict[str, Any] = Field(default_factory=dict)

    # Token accounting for this event; stays None until populated.
    token_usage: Optional[TokenUsage] = None

    @classmethod
    async def create_with_tokens(
        cls,
        message: MessageT,
        prompt: str,
        completion: Optional[str] = None,
        model: str = "gpt-3.5-turbo",
        source: EventSource = EventSource.LLM,
        type: EventType = EventType.MESSAGE,
        task_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> SessionEvent[MessageT]:
        """Build an event whose token usage is computed from the given texts.

        Args:
            message: Payload to attach to the event.
            prompt: Prompt text counted as prompt tokens.
            completion: Optional completion text counted as completion tokens.
            model: Model name passed to ``TokenUsage.from_text``.
            source: Origin of the event.
            type: Kind of event.
            task_id: Optional task this event belongs to.
            metadata: Optional extra metadata; a falsy value becomes ``{}``.

        Returns:
            The new SessionEvent with ``token_usage`` populated.
        """
        usage = await TokenUsage.from_text(prompt, completion, model)
        return cls(
            message=message,
            task_id=task_id,
            type=type,
            source=source,
            metadata=metadata if metadata else {},
            token_usage=usage,
        )

    async def update_token_usage(
        self,
        prompt: Optional[str] = None,
        completion: Optional[str] = None,
        model: Optional[str] = None
    ) -> None:
        """Recount tokens for this event and refresh totals and cost.

        Args:
            prompt: If non-empty, its token count replaces ``prompt_tokens``.
            completion: If non-empty, its token count replaces ``completion_tokens``.
            model: Model to record. It is only adopted when no model has been
                recorded yet, or when a fresh TokenUsage is created here.
        """
        if self.token_usage is None:
            self.token_usage = TokenUsage(model=model or "")
        usage = self.token_usage

        # Fill in the model only if none is recorded; an existing one is kept.
        if model and not usage.model:
            usage.model = model

        if prompt:
            usage.prompt_tokens = await TokenUsage.count_tokens(prompt, usage.model)
        if completion:
            usage.completion_tokens = await TokenUsage.count_tokens(completion, usage.model)

        usage.total_tokens = usage.prompt_tokens + usage.completion_tokens
        # Cost needs a model to price against; skip it otherwise.
        if usage.model:
            usage.estimated_cost_usd = await usage.calculate_cost()

    async def get_metadata(self, key: str, default: Any = None) -> Any:
        """Return ``metadata[key]``, or ``default`` when the key is absent."""
        return self.metadata.get(key, default)

    async def set_metadata(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` in this event's metadata."""
        self.metadata[key] = value

    async def has_metadata(self, key: str) -> bool:
        """Report whether ``key`` is present in this event's metadata."""
        return key in self.metadata

    async def remove_metadata(self, key: str) -> None:
        """Drop ``key`` from the metadata; a missing key is ignored."""
        self.metadata.pop(key, None)

    async def update_metadata(self, key: str, value: Any) -> None:
        """Backward-compatible alias that delegates to ``set_metadata``."""
        await self.set_metadata(key, value)
@@ -0,0 +1,37 @@
1
+ # chuk_ai_session_manager/models/session_metadata.py
2
+ from __future__ import annotations
3
+ from datetime import datetime, timezone
4
+ from typing import Any, Dict, Optional
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
class SessionMetadata(BaseModel):
    """Core metadata associated with a session: UTC timestamps plus free-form properties."""

    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

    # Free-form properties for session-level identifiers and custom data
    properties: Dict[str, Any] = Field(default_factory=dict)

    async def set_property(self, key: str, value: Any) -> None:
        """Add or update a custom property and bump ``updated_at`` to now (UTC)."""
        self.properties[key] = value
        self.updated_at = datetime.now(timezone.utc)

    async def get_property(self, key: str, default: Any = None) -> Any:
        """Return the property stored under ``key``.

        Args:
            key: The property name.
            default: Value returned when ``key`` is absent. Defaults to None,
                which matches the previous behaviour.
        """
        return self.properties.get(key, default)

    async def update_timestamp(self) -> None:
        """Set ``updated_at`` to the current UTC time."""
        self.updated_at = datetime.now(timezone.utc)

    @classmethod
    async def create(cls, properties: Optional[Dict[str, Any]] = None) -> SessionMetadata:
        """Create metadata whose ``created_at`` and ``updated_at`` share the same instant.

        Args:
            properties: Optional initial properties; a falsy value becomes ``{}``.
        """
        now = datetime.now(timezone.utc)
        return cls(
            created_at=now,
            updated_at=now,
            properties=properties or {},
        )
@@ -0,0 +1,115 @@
1
+ # chuk_ai_session_manager/models/session_run.py
2
+ """
3
+ Session run model for the chuk session manager with improved async support.
4
+ """
5
+ from __future__ import annotations
6
+ from datetime import datetime, timezone
7
+ from enum import Enum
8
+ from typing import Any, Dict, Optional, List
9
+ from uuid import uuid4
10
+ from pydantic import BaseModel, Field, ConfigDict
11
+
12
+
13
class RunStatus(str, Enum):
    """Lifecycle state of a session run.

    Because the enum mixes in ``str``, each member is itself a string equal
    to its lowercase value (e.g. ``RunStatus.RUNNING == "running"``).
    """

    PENDING = "pending"  # initial state assigned by SessionRun.create()
    RUNNING = "running"  # set by SessionRun.mark_running()
    COMPLETED = "completed"  # set by SessionRun.mark_completed()
    FAILED = "failed"  # set by SessionRun.mark_failed()
    CANCELLED = "cancelled"  # set by SessionRun.mark_cancelled()
20
+
21
+
22
class SessionRun(BaseModel):
    """A single execution or "run" within a session, with lifecycle timestamps."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    id: str = Field(default_factory=lambda: str(uuid4()))
    started_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    ended_at: Optional[datetime] = None
    status: RunStatus = RunStatus.PENDING
    metadata: Dict[str, Any] = Field(default_factory=dict)
    tool_calls: List[str] = Field(default_factory=list)  # IDs of associated tool call events

    @classmethod
    async def create(cls, metadata: Optional[Dict[str, Any]] = None) -> SessionRun:
        """Create a new PENDING run; a falsy ``metadata`` becomes ``{}``."""
        return cls(
            status=RunStatus.PENDING,
            metadata=metadata or {},
        )

    async def mark_running(self) -> None:
        """Set status to RUNNING and reset ``started_at`` to now (UTC)."""
        self.status = RunStatus.RUNNING
        self.started_at = datetime.now(timezone.utc)

    async def mark_completed(self) -> None:
        """Set status to COMPLETED and stamp ``ended_at`` with now (UTC)."""
        self.status = RunStatus.COMPLETED
        self.ended_at = datetime.now(timezone.utc)

    async def mark_failed(self, reason: Optional[str] = None) -> None:
        """Set status to FAILED, stamp ``ended_at``, and record a non-empty reason
        under the ``"failure_reason"`` metadata key."""
        self.status = RunStatus.FAILED
        self.ended_at = datetime.now(timezone.utc)
        if reason:
            await self.set_metadata("failure_reason", reason)

    async def mark_cancelled(self, reason: Optional[str] = None) -> None:
        """Set status to CANCELLED, stamp ``ended_at``, and record a non-empty reason
        under the ``"cancel_reason"`` metadata key."""
        self.status = RunStatus.CANCELLED
        self.ended_at = datetime.now(timezone.utc)
        if reason:
            await self.set_metadata("cancel_reason", reason)

    async def set_metadata(self, key: str, value: Any) -> None:
        """Set a metadata value."""
        self.metadata[key] = value

    async def get_metadata(self, key: str, default: Any = None) -> Any:
        """Return the metadata value for ``key``, or ``default`` if it is absent."""
        return self.metadata.get(key, default)

    async def has_metadata(self, key: str) -> bool:
        """Return True if ``key`` exists in the metadata."""
        return key in self.metadata

    async def remove_metadata(self, key: str) -> None:
        """Remove ``key`` from the metadata if present."""
        if key in self.metadata:
            del self.metadata[key]

    async def get_duration(self) -> Optional[float]:
        """Return the run duration in seconds, or None while ``ended_at`` is unset."""
        if self.ended_at is None:
            return None
        return (self.ended_at - self.started_at).total_seconds()

    async def add_tool_call(self, tool_call_id: str) -> None:
        """Associate a tool call event id with this run (duplicates are ignored)."""
        if tool_call_id not in self.tool_calls:
            self.tool_calls.append(tool_call_id)

    async def get_tool_calls(self, session: Any) -> List[Any]:
        """Return the events in ``session.events`` whose id is associated with this run.

        Results keep the order of ``session.events``. ``session`` is typed as Any
        to avoid a circular import with the Session model.
        """
        # Build the id set once so the scan is O(len(events)) rather than
        # O(len(events) * len(tool_calls)).
        wanted = set(self.tool_calls)
        return [event for event in session.events if event.id in wanted]

    async def to_dict(self) -> Dict[str, Any]:
        """Convert the run to a plain dictionary.

        Datetimes are ISO-8601 strings. ``ended_at`` and ``duration`` are included
        only once the run has ended. ``metadata`` and ``tool_calls`` are shallow
        copies, so mutating the result does not alter this run.
        """
        result: Dict[str, Any] = {
            "id": self.id,
            "status": self.status.value,
            "started_at": self.started_at.isoformat(),
            "metadata": dict(self.metadata),
            "tool_calls": list(self.tool_calls),
        }

        if self.ended_at is not None:
            result["ended_at"] = self.ended_at.isoformat()
            result["duration"] = await self.get_duration()

        return result