fastworkflow 2.16.0__py3-none-any.whl → 2.17.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fastworkflow might be problematic; see the advisory on the package registry for more details.

@@ -0,0 +1,391 @@
1
+ """
2
+ Conversation persistence layer for FastWorkflow
3
+ Provides Rdict-backed storage for multi-turn conversations with AI-generated topics/summaries
4
+ """
5
+
6
+ import json
7
+ import os
8
+ from re import I
9
+ import time
10
+ from typing import Any, Optional
11
+
12
+ import dspy
13
+ from pydantic import BaseModel
14
+ from speedict import Rdict
15
+
16
+ from fastworkflow.utils.logging import logger
17
+ from fastworkflow.utils.dspy_utils import get_lm
18
+
19
+
20
def extract_turns_from_history(conversation_history: 'dspy.History') -> list[dict[str, Any]]:
    """
    Convert dspy.History messages into the Rdict turn format.

    Each entry of ``conversation_history.messages`` is a dict carrying a
    "conversation summary", "conversation_traces", and an optional "feedback"
    mapping. All three fields are copied through unchanged (missing keys become
    None) so the full conversation can be persisted and later restored.

    Args:
        conversation_history: History object whose ``messages`` attribute holds
            the per-turn dicts.

    Returns:
        List of turn dicts in the Rdict storage format.
    """
    return [
        {
            "conversation summary": message.get("conversation summary"),
            "conversation_traces": message.get("conversation_traces"),
            "feedback": message.get("feedback"),  # carry feedback through as-is
        }
        for message in conversation_history.messages
    ]
59
+
60
+
61
def restore_history_from_turns(turns: list[dict[str, Any]]) -> 'dspy.History':
    """
    Rebuild a dspy.History from Rdict-format turns.

    Inverse of ``extract_turns_from_history``: each stored turn's
    "conversation summary", "conversation_traces", and "feedback" fields are
    copied into a message dict (missing keys become None) and wrapped in a
    fresh dspy.History.

    Args:
        turns: Turn dicts as stored by ConversationStore.

    Returns:
        dspy.History carrying one message per stored turn.
    """
    restored_messages = [
        {
            "conversation summary": turn.get("conversation summary"),
            "conversation_traces": turn.get("conversation_traces"),
            "feedback": turn.get("feedback"),  # re-attach feedback when present
        }
        for turn in turns
    ]
    return dspy.History(messages=restored_messages)
79
+
80
+
81
class ConversationSummary(BaseModel):
    """Lightweight listing view of a stored conversation (turns are not included)."""
    # ID assigned from the per-user counter in ConversationStore
    conversation_id: int
    # Short topic; de-duplicated per user by ConversationStore._ensure_unique_topic
    topic: str
    # Summary paragraph (AI-generated by generate_topic_and_summary)
    summary: str
    # Creation time, epoch milliseconds (int(time.time() * 1000))
    created_at: int
    # Last-update time, epoch milliseconds; used as the sort key in list_conversations
    updated_at: int
88
+
89
+
90
class ConversationStore:
    """Rdict-backed conversation persistence, one database file per user.

    Storage layout inside the per-user Rdict:
      - "meta":      {"last_conversation_id": int} — monotonically increasing counter
      - "conv:<id>": {"topic", "summary", "created_at", "updated_at", "turns"}

    All timestamps are epoch milliseconds. Every public method opens a fresh
    Rdict handle and closes it before returning. There is no cross-process
    locking here — NOTE(review): confirm callers serialize writes per user.
    """

    def __init__(self, user_id: str, base_folder: str):
        """Bind the store to <base_folder>/<user_id>.rdb, creating the folder if needed."""
        self.user_id = user_id
        self.db_path = os.path.join(base_folder, f"{user_id}.rdb")
        os.makedirs(base_folder, exist_ok=True)

    def _get_db(self) -> 'Rdict':
        """Open and return a fresh Rdict handle; the caller must close it."""
        return Rdict(self.db_path)

    def get_last_conversation_id(self) -> Optional[int]:
        """Return the most recently assigned conversation ID, or None if none was ever assigned."""
        # Bugfix: acquire the handle *before* the try block. Previously the
        # assignment lived inside try, so a failure in _get_db() would raise
        # UnboundLocalError on `db` from the finally clause. This also matches
        # the acquire-then-try pattern used by every other method.
        db = self._get_db()
        try:
            meta = db.get("meta", {})
            return meta.get("last_conversation_id")
        finally:
            db.close()

    def _increment_conversation_id(self, db: 'Rdict') -> int:
        """Increment the per-user counter stored under "meta" and return the new ID."""
        meta = db.get("meta", {"last_conversation_id": 0})
        new_id = meta["last_conversation_id"] + 1
        meta["last_conversation_id"] = new_id
        db["meta"] = meta  # write back: fetched values are copies, not live views
        return new_id

    def reserve_next_conversation_id(self) -> int:
        """Reserve and return the next conversation ID without creating a conversation."""
        db = self._get_db()
        try:
            return self._increment_conversation_id(db)
        finally:
            db.close()

    def _ensure_unique_topic(
        self,
        db: 'Rdict',
        candidate_topic: str,
        exclude_conv_id: Optional[int] = None
    ) -> str:
        """Return a topic unique for this user (case/whitespace-insensitive).

        On collision, an increasing numeric suffix is appended
        ("Topic 1", "Topic 2", ...) until the result is unique.

        Args:
            db: Open Rdict handle.
            candidate_topic: Desired topic.
            exclude_conv_id: Conversation whose own stored topic should not
                count as a collision. Update paths pass the conversation being
                updated so re-saving its current topic does not acquire a
                spurious suffix (previously it did).
        """
        meta = db.get("meta", {"last_conversation_id": 0})
        # Normalize existing topics once up front instead of re-normalizing on
        # every pass of the collision loop (also removes a dead local that the
        # original computed and never used).
        existing_topics = set()
        for i in range(1, meta.get("last_conversation_id", 0) + 1):
            if i == exclude_conv_id:
                continue
            conv_key = f"conv:{i}"
            if conv_key in db:
                existing_topics.add(db[conv_key].get("topic", "").lower().strip())

        collision_count = 0
        final_topic = candidate_topic
        while final_topic.lower().strip() in existing_topics:
            collision_count += 1
            final_topic = f"{candidate_topic} {collision_count}"
        return final_topic

    def save_conversation(
        self,
        topic: str,
        summary: str,
        turns: list[dict[str, Any]],
        conversation_id: Optional[int] = None
    ) -> int:
        """
        Save a conversation and return its ID.

        Args:
            topic: Conversation topic (de-duplicated against existing topics).
            summary: Conversation summary.
            turns: List of conversation turns.
            conversation_id: Optional pre-reserved ID. If None, the counter is
                incremented to obtain the next ID.

        Returns:
            The conversation ID used.
        """
        db = self._get_db()
        try:
            if conversation_id is not None:
                conv_id = conversation_id  # assumes the ID was reserved and is valid
            else:
                conv_id = self._increment_conversation_id(db)

            unique_topic = self._ensure_unique_topic(db, topic, exclude_conv_id=conv_id)
            now_ms = int(time.time() * 1000)  # single clock read so created_at == updated_at
            db[f"conv:{conv_id}"] = {
                "topic": unique_topic,
                "summary": summary,
                "created_at": now_ms,
                "updated_at": now_ms,
                "turns": turns
            }
            return conv_id
        finally:
            db.close()

    def get_conversation(self, conv_id: int) -> Optional[dict[str, Any]]:
        """Return the conversation dict for `conv_id`, or None when absent."""
        db = self._get_db()
        try:
            return db.get(f"conv:{conv_id}")
        finally:
            db.close()

    def get_conversation_by_topic(self, topic: str) -> Optional[tuple[int, dict[str, Any]]]:
        """Find a conversation by topic (case/whitespace-insensitive).

        Returns:
            (conversation_id, conversation_dict) for the first match in ID
            order, or None when no topic matches.
        """
        db = self._get_db()
        try:
            meta = db.get("meta", {"last_conversation_id": 0})
            normalized_topic = topic.lower().strip()

            for i in range(1, meta.get("last_conversation_id", 0) + 1):
                conv_key = f"conv:{i}"
                if conv_key in db:
                    conv = db[conv_key]
                    if conv.get("topic", "").lower().strip() == normalized_topic:
                        return i, conv
            return None
        finally:
            db.close()

    def list_conversations(self, limit: int) -> list['ConversationSummary']:
        """Return up to `limit` conversation summaries, newest-updated first."""
        db = self._get_db()
        try:
            meta = db.get("meta", {"last_conversation_id": 0})
            conversations = []

            for i in range(1, meta.get("last_conversation_id", 0) + 1):
                conv_key = f"conv:{i}"
                if conv_key in db:
                    conv = db[conv_key]
                    conversations.append(
                        ConversationSummary(
                            conversation_id=i,
                            topic=conv.get("topic", ""),
                            summary=conv.get("summary", ""),
                            created_at=conv.get("created_at", 0),
                            updated_at=conv.get("updated_at", 0)
                        )
                    )

            # Most recently touched conversations first, truncated to `limit`.
            conversations.sort(key=lambda c: c.updated_at, reverse=True)
            return conversations[:limit]
        finally:
            db.close()

    def update_conversation(
        self,
        conv_id: int,
        topic: str,
        summary: str,
        turns: list[dict[str, Any]]
    ) -> None:
        """Replace topic, summary, and turns of an existing conversation.

        Preserves created_at and refreshes updated_at.

        Raises:
            ValueError: If `conv_id` does not exist.
        """
        db = self._get_db()
        try:
            conv_key = f"conv:{conv_id}"
            if conv_key not in db:
                raise ValueError(f"Conversation {conv_id} not found")

            conv = db[conv_key]
            # Exclude this conversation from the uniqueness check so re-saving
            # its current topic no longer gets a spurious numeric suffix.
            unique_topic = self._ensure_unique_topic(db, topic, exclude_conv_id=conv_id)

            conv["topic"] = unique_topic
            conv["summary"] = summary
            conv["updated_at"] = int(time.time() * 1000)
            conv["turns"] = turns

            db[conv_key] = conv
        finally:
            db.close()

    def update_conversation_topic_summary(
        self,
        conv_id: int,
        topic: str,
        summary: str
    ) -> None:
        """
        Update only the topic and summary of an existing conversation.
        Used when finalizing a conversation (turns already saved incrementally).

        Raises:
            ValueError: If `conv_id` does not exist.
        """
        db = self._get_db()
        try:
            conv_key = f"conv:{conv_id}"
            if conv_key not in db:
                raise ValueError(f"Conversation {conv_id} not found")

            conv = db[conv_key]
            # Exclude self from the uniqueness check (see update_conversation).
            unique_topic = self._ensure_unique_topic(db, topic, exclude_conv_id=conv_id)

            # Only topic, summary, and timestamp change — turns are preserved.
            conv["topic"] = unique_topic
            conv["summary"] = summary
            conv["updated_at"] = int(time.time() * 1000)

            db[conv_key] = conv
        finally:
            db.close()

    def save_conversation_turns(
        self,
        conversation_id: int,
        turns: list[dict[str, Any]]
    ) -> int:
        """
        Create a new conversation with placeholder topic/summary, or update existing turns.
        Used for incremental saves without generating topic/summary.

        Args:
            conversation_id: The conversation ID to use (typically pre-reserved).
            turns: List of conversation turns.

        Returns:
            The conversation ID used.
        """
        db = self._get_db()
        try:
            conv_key = f"conv:{conversation_id}"
            now_ms = int(time.time() * 1000)

            if conv_key in db:
                # Conversation exists: refresh turns and timestamp only.
                conv = db[conv_key]
                conv["updated_at"] = now_ms
                conv["turns"] = turns
                db[conv_key] = conv
            else:
                # First incremental save: topic/summary are generated later.
                db[conv_key] = {
                    "topic": "",
                    "summary": "",
                    "created_at": now_ms,
                    "updated_at": now_ms,
                    "turns": turns
                }

            return conversation_id
        finally:
            db.close()

    # NOTE: update_turn_feedback() removed - feedback is now saved via save_conversation_turns()
    # in the incremental save flow after modifying conversation_history in memory

    def get_all_conversations_for_dump(self) -> list[dict[str, Any]]:
        """Return every conversation (with user_id and conversation_id attached) for admin dumps."""
        db = self._get_db()
        try:
            meta = db.get("meta", {"last_conversation_id": 0})
            conversations = []

            for i in range(1, meta.get("last_conversation_id", 0) + 1):
                conv_key = f"conv:{i}"
                if conv_key in db:
                    conv = db[conv_key]
                    conversations.append({
                        "user_id": self.user_id,
                        "conversation_id": i,
                        **conv
                    })

            return conversations
        finally:
            db.close()
363
+
364
+
365
def generate_topic_and_summary(turns: list[dict[str, Any]]) -> tuple[str, str]:
    """
    Produce a (topic, summary) pair for a conversation via DSPy.

    Only the per-turn conversation summaries are serialized into the prompt;
    the verbose traces are deliberately excluded so the model sees a compact,
    high-signal view of the conversation.
    """
    class TopicSummarySignature(dspy.Signature):
        """Generate a concise topic and summary for a conversation"""
        conversation_turns: str = dspy.InputField(desc="JSON representation of conversation turns")
        topic: str = dspy.OutputField(desc="Short topic (3-6 words)")
        summary: str = dspy.OutputField(desc="Brief summary paragraph")

    # Strip each turn down to its summary before serializing for the model.
    condensed_turns = [
        {"conversation summary": turn.get("conversation summary", "")}
        for turn in turns
    ]
    serialized_turns = json.dumps(condensed_turns, indent=2)

    # Scope DSPy to the conversation-store LM for just this call.
    lm = get_lm("LLM_CONVERSATION_STORE", "LITELLM_API_KEY_CONVERSATION_STORE")
    with dspy.context(lm=lm):
        predictor = dspy.ChainOfThought(TopicSummarySignature)
        prediction = predictor(conversation_turns=serialized_turns)
    return prediction.topic, prediction.summary
391
+
@@ -0,0 +1,256 @@
1
+ """
2
+ JWT Token Management for FastWorkflow FastAPI Service
3
+
4
+ Handles RSA key pair generation, JWT token creation and verification.
5
+ Keys are stored in ./jwt_keys/ directory.
6
+ """
7
+
8
+ import os
9
+ from datetime import datetime, timedelta, timezone
10
+ from typing import Optional
11
+
12
+ from jose import JWTError, jwt
13
+ from jose.constants import ALGORITHMS
14
+ from cryptography.hazmat.primitives import serialization
15
+ from cryptography.hazmat.primitives.asymmetric import rsa
16
+ from cryptography.hazmat.backends import default_backend
17
+
18
+ from fastworkflow.utils.logging import logger
19
+
20
+
21
+ # JWT Configuration (can be made configurable via env vars)
22
+ JWT_ALGORITHM = ALGORITHMS.RS256
23
+ JWT_ACCESS_TOKEN_EXPIRE_MINUTES = 60 # 1 hour
24
+ JWT_REFRESH_TOKEN_EXPIRE_DAYS = 30 # 30 days
25
+ JWT_ISSUER = "fastworkflow-api"
26
+ JWT_AUDIENCE = "fastworkflow-client"
27
+
28
+ # Key storage location (relative to project root)
29
+ KEYS_DIR = "./jwt_keys"
30
+ PRIVATE_KEY_PATH = os.path.join(KEYS_DIR, "private_key.pem")
31
+ PUBLIC_KEY_PATH = os.path.join(KEYS_DIR, "public_key.pem")
32
+
33
+ # In-memory cache for loaded keys
34
+ _private_key: Optional[str] = None
35
+ _public_key: Optional[str] = None
36
+
37
+
38
def ensure_keys_directory() -> None:
    """Make sure the JWT key directory exists, creating it when missing (idempotent)."""
    os.makedirs(KEYS_DIR, exist_ok=True)
    logger.info(f"JWT keys directory ensured at: {KEYS_DIR}")
42
+
43
+
44
def generate_rsa_key_pair() -> tuple[str, str]:
    """
    Create a fresh RSA 2048-bit key pair for JWT signing.

    Returns:
        tuple[str, str]: (private_key_pem, public_key_pem) — PKCS8 private key
        and SubjectPublicKeyInfo public key, both unencrypted PEM text.
    """
    logger.info("Generating new RSA 2048-bit key pair for JWT...")

    key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=2048,
        backend=default_backend()
    )

    # Private half: PKCS8 PEM, intentionally unencrypted (file perms guard it).
    priv_pem = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption()
    ).decode('utf-8')

    # Public half derived from the same key, standard SPKI PEM.
    pub_pem = key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo
    ).decode('utf-8')

    logger.info("RSA key pair generated successfully")
    return priv_pem, pub_pem
76
+
77
+
78
def save_keys_to_disk(private_pem: str, public_pem: str) -> None:
    """
    Persist the RSA key pair under KEYS_DIR with appropriate permissions.

    Args:
        private_pem: Private key in PEM format (written mode 600).
        public_pem: Public key in PEM format (written mode 644).
    """
    ensure_keys_directory()

    def _write_pem(path: str, pem: str, mode: int) -> None:
        # Write the PEM text, then clamp the file's permission bits.
        with open(path, 'w') as fh:
            fh.write(pem)
        os.chmod(path, mode)

    # Private key must stay readable by the service user only.
    _write_pem(PRIVATE_KEY_PATH, private_pem, 0o600)
    logger.info(f"Private key saved to: {PRIVATE_KEY_PATH} (mode 600)")

    # Public key is safe to expose for verification.
    _write_pem(PUBLIC_KEY_PATH, public_pem, 0o644)
    logger.info(f"Public key saved to: {PUBLIC_KEY_PATH} (mode 644)")
99
+
100
+
101
def load_or_generate_keys() -> tuple[str, str]:
    """
    Return the (private_pem, public_pem) key pair, loading or generating as needed.

    Order of preference: in-memory cache, then the PEM files on disk, then a
    freshly generated pair (which is also written to disk). The module-level
    globals cache the result for subsequent calls.

    Returns:
        tuple[str, str]: (private_key_pem, public_key_pem)
    """
    global _private_key, _public_key

    # Fast path: both halves already cached in this process.
    if _private_key and _public_key:
        return _private_key, _public_key

    both_files_present = os.path.exists(PRIVATE_KEY_PATH) and os.path.exists(PUBLIC_KEY_PATH)
    if both_files_present:
        logger.info("Loading existing RSA keys from disk...")
        with open(PRIVATE_KEY_PATH, 'r') as fh:
            _private_key = fh.read()
        with open(PUBLIC_KEY_PATH, 'r') as fh:
            _public_key = fh.read()
        logger.info("RSA keys loaded successfully")
    else:
        logger.info("No existing RSA keys found, generating new ones...")
        _private_key, _public_key = generate_rsa_key_pair()
        save_keys_to_disk(_private_key, _public_key)

    return _private_key, _public_key
130
+
131
+
132
def create_access_token(user_id: str, expires_days: int | None = None) -> str:
    """
    Issue a signed JWT access token for a user.

    Args:
        user_id: User identifier (becomes the `sub` claim).
        expires_days: Optional custom lifetime in days. When None, the default
            JWT_ACCESS_TOKEN_EXPIRE_MINUTES (60 minutes) applies.

    Returns:
        str: Encoded JWT access token.
    """
    signing_key, _ = load_or_generate_keys()

    issued_at = datetime.now(timezone.utc)
    lifetime = (
        timedelta(days=expires_days)
        if expires_days is not None
        else timedelta(minutes=JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
    )
    expire = issued_at + lifetime

    # Standard claims plus a type marker so verify_token can reject refresh
    # tokens presented as access tokens.
    claims = {
        "sub": user_id,
        "iat": int(issued_at.timestamp()),
        "exp": int(expire.timestamp()),
        "jti": f"{user_id}_{int(issued_at.timestamp())}",  # unique per user per second
        "type": "access",
        "iss": JWT_ISSUER,
        "aud": JWT_AUDIENCE
    }

    token = jwt.encode(claims, signing_key, algorithm=JWT_ALGORITHM)
    logger.debug(f"Created access token for user_id: {user_id}, expires: {expire.isoformat()}")
    return token
165
+
166
+
167
def create_refresh_token(user_id: str) -> str:
    """
    Issue a signed JWT refresh token for a user.

    Lifetime is fixed at JWT_REFRESH_TOKEN_EXPIRE_DAYS (30 days).

    Args:
        user_id: User identifier (becomes the `sub` claim).

    Returns:
        str: Encoded JWT refresh token.
    """
    signing_key, _ = load_or_generate_keys()

    issued_at = datetime.now(timezone.utc)
    expire = issued_at + timedelta(days=JWT_REFRESH_TOKEN_EXPIRE_DAYS)

    # Same claim set as access tokens, but typed "refresh" and with a
    # distinct jti suffix so the two token kinds can never share an ID.
    claims = {
        "sub": user_id,
        "iat": int(issued_at.timestamp()),
        "exp": int(expire.timestamp()),
        "jti": f"{user_id}_{int(issued_at.timestamp())}_refresh",
        "type": "refresh",
        "iss": JWT_ISSUER,
        "aud": JWT_AUDIENCE
    }

    token = jwt.encode(claims, signing_key, algorithm=JWT_ALGORITHM)
    logger.debug(f"Created refresh token for user_id: {user_id}, expires: {expire.isoformat()}")
    return token
196
+
197
+
198
def verify_token(token: str, expected_type: str = "access") -> dict:
    """
    Verify and decode a JWT token.

    Checks signature, expiry, issuer, and audience, then confirms the token's
    `type` claim matches `expected_type`.

    Args:
        token: JWT token string.
        expected_type: Expected token type ("access" or "refresh").

    Returns:
        dict: Decoded token payload.

    Raises:
        JWTError: If the token is invalid, expired, or of the wrong type.
    """
    _, verification_key = load_or_generate_keys()

    try:
        claims = jwt.decode(
            token,
            verification_key,
            algorithms=[JWT_ALGORITHM],
            issuer=JWT_ISSUER,
            audience=JWT_AUDIENCE
        )

        # Type mismatch is raised inside the try on purpose: it funnels
        # through the same warning-log-and-reraise path as decode failures.
        if claims.get("type") != expected_type:
            raise JWTError(f"Invalid token type: expected {expected_type}, got {claims.get('type')}")

        logger.debug(f"Token verified successfully: user_id={claims.get('sub')}, type={expected_type}")
        return claims

    except JWTError as e:
        logger.warning(f"Token verification failed: {e}")
        raise
234
+
235
+
236
def get_token_expiry(token: str) -> Optional[datetime]:
    """
    Best-effort read of a token's expiry without verifying the signature.

    Intended for debugging/logging only — the claims are NOT validated.

    Args:
        token: JWT token string.

    Returns:
        datetime: Expiration time in UTC, or None when the token is malformed
        or carries no `exp` claim.
    """
    try:
        claims = jwt.get_unverified_claims(token)
        if exp_ts := claims.get("exp"):
            return datetime.fromtimestamp(exp_ts, tz=timezone.utc)
    except Exception as e:
        # Deliberate broad catch: any parse failure just means "no expiry".
        logger.debug(f"Failed to get token expiry: {e}")
    return None
256
+