fastworkflow 2.16.0__py3-none-any.whl → 2.17.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fastworkflow might be problematic. Click here for more details.
- fastworkflow/_workflows/command_metadata_extraction/_commands/IntentDetection/what_can_i_do.py +12 -7
- fastworkflow/chat_session.py +1 -0
- fastworkflow/command_context_model.py +73 -7
- fastworkflow/command_metadata_api.py +56 -6
- fastworkflow/run/__main__.py +0 -6
- fastworkflow/run_fastapi_mcp/README.md +300 -0
- fastworkflow/run_fastapi_mcp/__init__.py +0 -0
- fastworkflow/run_fastapi_mcp/conversation_store.py +391 -0
- fastworkflow/run_fastapi_mcp/jwt_manager.py +256 -0
- fastworkflow/run_fastapi_mcp/main.py +1206 -0
- fastworkflow/run_fastapi_mcp/mcp_specific.py +103 -0
- fastworkflow/run_fastapi_mcp/redoc_2_standalone_html.py +40 -0
- fastworkflow/run_fastapi_mcp/utils.py +427 -0
- {fastworkflow-2.16.0.dist-info → fastworkflow-2.17.1.dist-info}/METADATA +1 -1
- {fastworkflow-2.16.0.dist-info → fastworkflow-2.17.1.dist-info}/RECORD +18 -10
- {fastworkflow-2.16.0.dist-info → fastworkflow-2.17.1.dist-info}/LICENSE +0 -0
- {fastworkflow-2.16.0.dist-info → fastworkflow-2.17.1.dist-info}/WHEEL +0 -0
- {fastworkflow-2.16.0.dist-info → fastworkflow-2.17.1.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,391 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Conversation persistence layer for FastWorkflow
|
|
3
|
+
Provides Rdict-backed storage for multi-turn conversations with AI-generated topics/summaries
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
import os
|
|
8
|
+
from re import I
|
|
9
|
+
import time
|
|
10
|
+
from typing import Any, Optional
|
|
11
|
+
|
|
12
|
+
import dspy
|
|
13
|
+
from pydantic import BaseModel
|
|
14
|
+
from speedict import Rdict
|
|
15
|
+
|
|
16
|
+
from fastworkflow.utils.logging import logger
|
|
17
|
+
from fastworkflow.utils.dspy_utils import get_lm
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def extract_turns_from_history(conversation_history: 'dspy.History') -> list[dict[str, Any]]:
    """
    Convert the messages of a dspy.History into the turn list stored in Rdict.

    Each message in ``conversation_history.messages`` is a dict; the stored
    turn keeps exactly three of its entries, in this order:

        {
            "conversation summary": "...",
            "conversation_traces": "...",
            "feedback": {...} or None
        }

    A key that is absent from a message is stored as None, and any other keys
    on the message are dropped. The result is a new list of new dicts, one per
    message, in the original order.
    """
    persisted_keys = ("conversation summary", "conversation_traces", "feedback")
    return [
        {key: message.get(key) for key in persisted_keys}
        for message in conversation_history.messages
    ]
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def restore_history_from_turns(turns: list[dict[str, Any]]) -> 'dspy.History':
    """
    Rebuild a dspy.History from turns previously stored in Rdict.

    This is the inverse of the Rdict turn format: every turn becomes one
    history message carrying its "conversation summary",
    "conversation_traces" and "feedback" values (None for any key the stored
    turn does not have). Turn order is preserved.
    """
    restored_keys = ("conversation summary", "conversation_traces", "feedback")
    messages = [
        {key: turn.get(key) for key in restored_keys}
        for turn in turns
    ]
    return dspy.History(messages=messages)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class ConversationSummary(BaseModel):
    """Summary of a conversation"""
    # Listing entry built by ConversationStore.list_conversations(); it carries
    # the conversation's metadata but not its turns.

    # Numeric ID; the store keeps the full record under the key "conv:<id>".
    conversation_id: int
    # Topic as stored (may be "" while the record is still a placeholder).
    topic: str
    # Summary text as stored (may be "" for a placeholder record).
    summary: str
    # Creation time, epoch milliseconds (0 when the stored record has none).
    created_at: int
    # Last modification time, epoch milliseconds (0 when the record has none).
    updated_at: int
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class ConversationStore:
    """Rdict-backed conversation persistence per user.

    Layout of the per-user database file ``<base_folder>/<user_id>.rdb``:

    * ``"meta"``      -> ``{"last_conversation_id": int}`` (the ID counter)
    * ``"conv:<id>"`` -> ``{"topic", "summary", "created_at", "updated_at", "turns"}``

    Timestamps are epoch milliseconds. Every public method opens the database,
    performs its work and closes the handle again in a ``finally`` block, so
    no handle is kept between calls.
    """

    def __init__(self, user_id: str, base_folder: str):
        """Bind the store to ``<base_folder>/<user_id>.rdb`` and create the folder."""
        self.user_id = user_id
        # NOTE(review): user_id is interpolated into the file name as-is;
        # presumably callers pass an identifier without path separators - confirm.
        self.db_path = os.path.join(base_folder, f"{user_id}.rdb")
        os.makedirs(base_folder, exist_ok=True)

    def _get_db(self) -> Rdict:
        """Open a new Rdict handle on this user's database; the caller closes it."""
        return Rdict(self.db_path)

    @staticmethod
    def _now_ms() -> int:
        """Return the current wall-clock time as epoch milliseconds."""
        return int(time.time() * 1000)

    @staticmethod
    def _iter_conversations(db: Rdict):
        """Yield ``(conversation_id, record)`` for every stored conversation.

        IDs are scanned from 1 up to the ``last_conversation_id`` counter in
        ascending order; IDs that were reserved but never saved are skipped.
        """
        meta = db.get("meta", {"last_conversation_id": 0})
        for conv_id in range(1, meta.get("last_conversation_id", 0) + 1):
            conv_key = f"conv:{conv_id}"
            if conv_key in db:
                yield conv_id, db[conv_key]

    def get_last_conversation_id(self) -> Optional[int]:
        """Return the last conversation ID for this user, or None if none was issued."""
        # Open outside the try block: if opening fails there is no handle to
        # close, and the original error must not be masked by the cleanup.
        db = self._get_db()
        try:
            return db.get("meta", {}).get("last_conversation_id")
        finally:
            db.close()

    def _increment_conversation_id(self, db: Rdict) -> int:
        """Increment the ID counter stored under "meta" and return the new ID."""
        meta = db.get("meta", {"last_conversation_id": 0})
        new_id = meta["last_conversation_id"] + 1
        meta["last_conversation_id"] = new_id
        db["meta"] = meta
        return new_id

    def reserve_next_conversation_id(self) -> int:
        """Reserve the next conversation ID by incrementing the counter without creating a conversation"""
        db = self._get_db()
        try:
            return self._increment_conversation_id(db)
        finally:
            db.close()

    def _ensure_unique_topic(
        self,
        db: Rdict,
        candidate_topic: str,
        exclude_conv_id: Optional[int] = None
    ) -> str:
        """Return a topic that does not clash with any other stored topic.

        Topics are compared case-insensitively and ignoring surrounding
        whitespace. On a clash, " 1", " 2", ... is appended to the candidate
        until the result is free.

        Args:
            db: Open database handle.
            candidate_topic: Topic the caller would like to use.
            exclude_conv_id: ID of the conversation being written, if any. Its
                current topic is ignored so that re-saving a conversation with
                an unchanged topic does not clash with itself.

        Returns:
            ``candidate_topic`` unchanged when free, otherwise the first free
            numbered variant.
        """
        # Normalize once into a set instead of re-normalizing every stored
        # topic on each pass of the collision loop.
        taken = {
            conv.get("topic", "").lower().strip()
            for conv_id, conv in self._iter_conversations(db)
            if conv_id != exclude_conv_id
        }

        final_topic = candidate_topic
        collision_count = 0
        while final_topic.lower().strip() in taken:
            collision_count += 1
            final_topic = f"{candidate_topic} {collision_count}"

        return final_topic

    def save_conversation(
        self,
        topic: str,
        summary: str,
        turns: list[dict[str, Any]],
        conversation_id: Optional[int] = None
    ) -> int:
        """
        Save a conversation and return its ID.

        Any record already stored under the chosen ID is replaced.

        Args:
            topic: Conversation topic (made unique among the other conversations)
            summary: Conversation summary
            turns: List of conversation turns
            conversation_id: Optional specific ID to use (assumed to be valid and
                reserved). If None, increments to get next ID.

        Returns:
            The conversation ID used
        """
        db = self._get_db()
        try:
            if conversation_id is not None:
                conv_id = conversation_id
            else:
                conv_id = self._increment_conversation_id(db)

            unique_topic = self._ensure_unique_topic(db, topic, exclude_conv_id=conv_id)

            # One clock read so a new record has created_at == updated_at.
            now = self._now_ms()
            db[f"conv:{conv_id}"] = {
                "topic": unique_topic,
                "summary": summary,
                "created_at": now,
                "updated_at": now,
                "turns": turns
            }
            return conv_id
        finally:
            db.close()

    def get_conversation(self, conv_id: int) -> Optional[dict[str, Any]]:
        """Get a conversation by ID, or None when no record is stored under it."""
        db = self._get_db()
        try:
            return db.get(f"conv:{conv_id}")
        finally:
            db.close()

    def get_conversation_by_topic(self, topic: str) -> Optional[tuple[int, dict[str, Any]]]:
        """Get conversation ID and data by topic (case/whitespace insensitive).

        Returns the match with the lowest ID, or None when no topic matches.
        """
        db = self._get_db()
        try:
            normalized_topic = topic.lower().strip()
            for conv_id, conv in self._iter_conversations(db):
                if conv.get("topic", "").lower().strip() == normalized_topic:
                    return conv_id, conv
            return None
        finally:
            db.close()

    def list_conversations(self, limit: int) -> list[ConversationSummary]:
        """List conversations ordered by updated_at desc, up to limit"""
        db = self._get_db()
        try:
            conversations = [
                ConversationSummary(
                    conversation_id=conv_id,
                    topic=conv.get("topic", ""),
                    summary=conv.get("summary", ""),
                    created_at=conv.get("created_at", 0),
                    updated_at=conv.get("updated_at", 0)
                )
                for conv_id, conv in self._iter_conversations(db)
            ]

            # Most recently updated first; the sort is stable, so ties keep
            # ascending-ID order.
            conversations.sort(key=lambda c: c.updated_at, reverse=True)
            return conversations[:limit]
        finally:
            db.close()

    def update_conversation(
        self,
        conv_id: int,
        topic: str,
        summary: str,
        turns: list[dict[str, Any]]
    ) -> None:
        """Update an existing conversation with new topic, summary, and turns.

        ``created_at`` is preserved; ``updated_at`` is set to the current time.

        Raises:
            ValueError: If no conversation is stored under ``conv_id``.
        """
        db = self._get_db()
        try:
            conv_key = f"conv:{conv_id}"
            if conv_key not in db:
                raise ValueError(f"Conversation {conv_id} not found")

            conv = db[conv_key]
            conv["topic"] = self._ensure_unique_topic(db, topic, exclude_conv_id=conv_id)
            conv["summary"] = summary
            conv["updated_at"] = self._now_ms()
            conv["turns"] = turns

            # Rdict hands back a copy, so the record must be written back.
            db[conv_key] = conv
        finally:
            db.close()

    def update_conversation_topic_summary(
        self,
        conv_id: int,
        topic: str,
        summary: str
    ) -> None:
        """
        Update only the topic and summary of an existing conversation.
        Used when finalizing a conversation (turns already saved incrementally).

        ``updated_at`` is refreshed; turns and ``created_at`` are left untouched.

        Raises:
            ValueError: If no conversation is stored under ``conv_id``.
        """
        db = self._get_db()
        try:
            conv_key = f"conv:{conv_id}"
            if conv_key not in db:
                raise ValueError(f"Conversation {conv_id} not found")

            conv = db[conv_key]
            conv["topic"] = self._ensure_unique_topic(db, topic, exclude_conv_id=conv_id)
            conv["summary"] = summary
            conv["updated_at"] = self._now_ms()

            db[conv_key] = conv
        finally:
            db.close()

    def save_conversation_turns(
        self,
        conversation_id: int,
        turns: list[dict[str, Any]]
    ) -> int:
        """
        Create a new conversation with placeholder topic/summary, or update existing turns.
        Used for incremental saves without generating topic/summary.

        Args:
            conversation_id: The conversation ID to use
            turns: List of conversation turns

        Returns:
            The conversation ID used
        """
        db = self._get_db()
        try:
            conv_key = f"conv:{conversation_id}"
            now = self._now_ms()

            if conv_key in db:
                # Conversation exists, just update turns
                conv = db[conv_key]
                conv["updated_at"] = now
                conv["turns"] = turns
                db[conv_key] = conv
            else:
                # Create new conversation with placeholder topic/summary
                db[conv_key] = {
                    "topic": "",  # Will be generated later
                    "summary": "",  # Will be generated later
                    "created_at": now,
                    "updated_at": now,
                    "turns": turns
                }

            return conversation_id
        finally:
            db.close()

    # NOTE: update_turn_feedback() removed - feedback is now saved via save_conversation_turns()
    # in the incremental save flow after modifying conversation_history in memory

    def get_all_conversations_for_dump(self) -> list[dict[str, Any]]:
        """Get all conversations for admin dump.

        Each entry is the stored record plus "user_id" and "conversation_id";
        keys of the stored record win if it happens to contain either name.
        """
        db = self._get_db()
        try:
            return [
                {
                    "user_id": self.user_id,
                    "conversation_id": conv_id,
                    **conv
                }
                for conv_id, conv in self._iter_conversations(db)
            ]
        finally:
            db.close()
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def generate_topic_and_summary(turns: list[dict[str, Any]]) -> tuple[str, str]:
    """
    Ask the conversation-store language model for a topic and a summary.

    Only each turn's "conversation summary" is sent to the model (an empty
    string when a turn has none); the verbose traces and feedback are left
    out. The condensed turns are serialized as indented JSON and passed to a
    DSPy ChainOfThought predictor, which runs with the LM configured through
    LLM_CONVERSATION_STORE / LITELLM_API_KEY_CONVERSATION_STORE.

    Returns:
        (topic, summary) as produced by the model.
    """
    class TopicSummarySignature(dspy.Signature):
        """Generate a concise topic and summary for a conversation"""
        conversation_turns: str = dspy.InputField(desc="JSON representation of conversation turns")
        topic: str = dspy.OutputField(desc="Short topic (3-6 words)")
        summary: str = dspy.OutputField(desc="Brief summary paragraph")

    # Reduce every turn to its summary before serializing.
    condensed_turns = []
    for turn in turns:
        condensed_turns.append(
            {"conversation summary": turn.get("conversation summary", "")}
        )
    serialized_turns = json.dumps(condensed_turns, indent=2)

    # The LM override is scoped to this call via the context manager.
    conversation_lm = get_lm("LLM_CONVERSATION_STORE", "LITELLM_API_KEY_CONVERSATION_STORE")
    with dspy.context(lm=conversation_lm):
        predictor = dspy.ChainOfThought(TopicSummarySignature)
        prediction = predictor(conversation_turns=serialized_turns)
        return prediction.topic, prediction.summary
|
|
391
|
+
|
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
"""
|
|
2
|
+
JWT Token Management for FastWorkflow FastAPI Service
|
|
3
|
+
|
|
4
|
+
Handles RSA key pair generation, JWT token creation and verification.
|
|
5
|
+
Keys are stored in ./jwt_keys/ directory.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
from datetime import datetime, timedelta, timezone
|
|
10
|
+
from typing import Optional
|
|
11
|
+
|
|
12
|
+
from jose import JWTError, jwt
|
|
13
|
+
from jose.constants import ALGORITHMS
|
|
14
|
+
from cryptography.hazmat.primitives import serialization
|
|
15
|
+
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
16
|
+
from cryptography.hazmat.backends import default_backend
|
|
17
|
+
|
|
18
|
+
from fastworkflow.utils.logging import logger
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# JWT Configuration (can be made configurable via env vars)
|
|
22
|
+
JWT_ALGORITHM = ALGORITHMS.RS256
|
|
23
|
+
JWT_ACCESS_TOKEN_EXPIRE_MINUTES = 60 # 1 hour
|
|
24
|
+
JWT_REFRESH_TOKEN_EXPIRE_DAYS = 30 # 30 days
|
|
25
|
+
JWT_ISSUER = "fastworkflow-api"
|
|
26
|
+
JWT_AUDIENCE = "fastworkflow-client"
|
|
27
|
+
|
|
28
|
+
# Key storage location (relative to the process's current working directory)
|
|
29
|
+
KEYS_DIR = "./jwt_keys"
|
|
30
|
+
PRIVATE_KEY_PATH = os.path.join(KEYS_DIR, "private_key.pem")
|
|
31
|
+
PUBLIC_KEY_PATH = os.path.join(KEYS_DIR, "public_key.pem")
|
|
32
|
+
|
|
33
|
+
# In-memory cache for loaded keys
|
|
34
|
+
_private_key: Optional[str] = None
|
|
35
|
+
_public_key: Optional[str] = None
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def ensure_keys_directory() -> None:
    """Make sure the JWT key directory (KEYS_DIR) exists, creating it if needed."""
    keys_dir = KEYS_DIR
    # exist_ok makes this safe to call on every save, not only the first one.
    os.makedirs(keys_dir, exist_ok=True)
    logger.info(f"JWT keys directory ensured at: {keys_dir}")
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def generate_rsa_key_pair() -> tuple[str, str]:
    """
    Create a fresh 2048-bit RSA key pair and return both halves as PEM text.

    The private key is serialized as unencrypted PKCS8, the public key as
    SubjectPublicKeyInfo.

    Returns:
        tuple[str, str]: (private_key_pem, public_key_pem)
    """
    logger.info("Generating new RSA 2048-bit key pair for JWT...")

    key_pair = rsa.generate_private_key(
        public_exponent=65537,
        key_size=2048,
        backend=default_backend()
    )

    # Private half: PKCS8 PEM without a passphrase.
    private_bytes = key_pair.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption()
    )

    # Public half is derived from the private key.
    public_bytes = key_pair.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo
    )

    private_pem = private_bytes.decode('utf-8')
    public_pem = public_bytes.decode('utf-8')

    logger.info("RSA key pair generated successfully")
    return private_pem, public_pem
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def save_keys_to_disk(private_pem: str, public_pem: str) -> None:
    """
    Save RSA keys to disk with appropriate permissions.

    The private key file is created with mode 600 from the start (and an
    already existing file is tightened before anything is written), so the key
    material is never on disk with wider permissions. The public key is
    written with mode 644.

    Args:
        private_pem: Private key in PEM format
        public_pem: Public key in PEM format
    """
    ensure_keys_directory()

    # Save private key (mode 600 for security)
    _write_key_file(PRIVATE_KEY_PATH, private_pem, 0o600)
    logger.info(f"Private key saved to: {PRIVATE_KEY_PATH} (mode 600)")

    # Save public key (mode 644 is fine)
    _write_key_file(PUBLIC_KEY_PATH, public_pem, 0o644)
    logger.info(f"Public key saved to: {PUBLIC_KEY_PATH} (mode 644)")


def _write_key_file(path: str, pem: str, mode: int) -> None:
    """Write *pem* to *path*, ensuring the file has *mode* before any data lands."""
    # Passing the mode to os.open applies it at creation time, instead of
    # creating the file with the default umask permissions and fixing it later.
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode)
    with os.fdopen(fd, 'w') as f:
        # os.open does not change the mode of a file that already existed
        # (and the umask can strip bits), so set it explicitly before writing.
        os.chmod(path, mode)
        f.write(pem)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def load_or_generate_keys() -> tuple[str, str]:
    """
    Return the RSA key pair used for JWTs, loading or creating it on first use.

    Lookup order:
      1. the module-level in-memory cache,
      2. the PEM files at PRIVATE_KEY_PATH / PUBLIC_KEY_PATH (both must exist),
      3. a newly generated pair, which is also written to disk.

    Whatever is loaded or generated is stored in the cache for later calls.

    Returns:
        tuple[str, str]: (private_key_pem, public_key_pem)
    """
    global _private_key, _public_key

    # Fast path: both halves already cached.
    if _private_key and _public_key:
        return _private_key, _public_key

    both_on_disk = os.path.exists(PRIVATE_KEY_PATH) and os.path.exists(PUBLIC_KEY_PATH)

    if not both_on_disk:
        # Either file missing counts as "no keys": create and persist a new pair.
        logger.info("No existing RSA keys found, generating new ones...")
        _private_key, _public_key = generate_rsa_key_pair()
        save_keys_to_disk(_private_key, _public_key)
        return _private_key, _public_key

    logger.info("Loading existing RSA keys from disk...")
    with open(PRIVATE_KEY_PATH, 'r') as private_file:
        _private_key = private_file.read()
    with open(PUBLIC_KEY_PATH, 'r') as public_file:
        _public_key = public_file.read()
    logger.info("RSA keys loaded successfully")

    return _private_key, _public_key
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def create_access_token(user_id: str, expires_days: int | None = None) -> str:
    """
    Create a JWT access token for a user.

    The token is signed with the service's private key using JWT_ALGORITHM and
    carries the claims sub, iat, exp, jti, type="access", iss and aud.

    Args:
        user_id: User identifier
        expires_days: Optional custom expiration in days. If None, uses JWT_ACCESS_TOKEN_EXPIRE_MINUTES (default 60 minutes).

    Returns:
        str: Encoded JWT access token
    """
    import uuid  # function-scope import: only the jti claim needs it

    private_key, _ = load_or_generate_keys()

    now = datetime.now(timezone.utc)
    if expires_days is not None:
        expire = now + timedelta(days=expires_days)
    else:
        expire = now + timedelta(minutes=JWT_ACCESS_TOKEN_EXPIRE_MINUTES)

    issued_at = int(now.timestamp())

    # JWT claims
    payload = {
        "sub": user_id,  # Subject: the user identifier
        "iat": issued_at,  # Issued at
        "exp": int(expire.timestamp()),  # Expiration time
        # JWT ID: user and issue second alone repeat for tokens issued to the
        # same user within one second, so a random component keeps it unique.
        "jti": f"{user_id}_{issued_at}_{uuid.uuid4().hex}",
        "type": "access",  # Token type
        "iss": JWT_ISSUER,  # Issuer
        "aud": JWT_AUDIENCE  # Audience
    }

    token = jwt.encode(payload, private_key, algorithm=JWT_ALGORITHM)
    logger.debug(f"Created access token for user_id: {user_id}, expires: {expire.isoformat()}")
    return token
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def create_refresh_token(user_id: str) -> str:
    """
    Create a JWT refresh token for a user.

    The token is signed with the service's private key using JWT_ALGORITHM,
    expires after JWT_REFRESH_TOKEN_EXPIRE_DAYS days and carries the claims
    sub, iat, exp, jti, type="refresh", iss and aud.

    Args:
        user_id: User identifier

    Returns:
        str: Encoded JWT refresh token
    """
    import uuid  # function-scope import: only the jti claim needs it

    private_key, _ = load_or_generate_keys()

    now = datetime.now(timezone.utc)
    expire = now + timedelta(days=JWT_REFRESH_TOKEN_EXPIRE_DAYS)

    issued_at = int(now.timestamp())

    # JWT claims
    payload = {
        "sub": user_id,  # Subject: the user identifier
        "iat": issued_at,  # Issued at
        "exp": int(expire.timestamp()),  # Expiration time
        # JWT ID: user and issue second alone repeat for tokens issued to the
        # same user within one second, so a random component keeps it unique.
        # The "_refresh" suffix is kept at the end as before.
        "jti": f"{user_id}_{issued_at}_{uuid.uuid4().hex}_refresh",
        "type": "refresh",  # Token type
        "iss": JWT_ISSUER,  # Issuer
        "aud": JWT_AUDIENCE  # Audience
    }

    token = jwt.encode(payload, private_key, algorithm=JWT_ALGORITHM)
    logger.debug(f"Created refresh token for user_id: {user_id}, expires: {expire.isoformat()}")
    return token
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def verify_token(token: str, expected_type: str = "access") -> dict:
    """
    Check a JWT's signature and claims and return its payload.

    The token is decoded with the service's public key, restricted to
    JWT_ALGORITHM, and must carry the configured issuer and audience. Its
    "type" claim must then equal ``expected_type``.

    Args:
        token: JWT token string
        expected_type: Expected token type ("access" or "refresh")

    Returns:
        dict: Decoded token payload

    Raises:
        JWTError: If decoding fails or the token type does not match. The
            failure is logged as a warning before being re-raised.
    """
    _, public_key = load_or_generate_keys()

    try:
        claims = jwt.decode(
            token,
            public_key,
            algorithms=[JWT_ALGORITHM],
            issuer=JWT_ISSUER,
            audience=JWT_AUDIENCE
        )

        # Raised inside the try so a type mismatch is logged like any other
        # verification failure.
        actual_type = claims.get("type")
        if actual_type != expected_type:
            raise JWTError(f"Invalid token type: expected {expected_type}, got {actual_type}")
    except JWTError as e:
        logger.warning(f"Token verification failed: {e}")
        raise

    logger.debug(f"Token verified successfully: user_id={claims.get('sub')}, type={expected_type}")
    return claims
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def get_token_expiry(token: str) -> Optional[datetime]:
    """
    Read the "exp" claim of a JWT without verifying the token.

    Intended for debugging/logging only: the claims are taken as-is, so the
    result says nothing about whether the token is genuine.

    Args:
        token: JWT token string

    Returns:
        datetime: Expiration time in UTC, or None if the claims cannot be
        read or the token has no (or a falsy) "exp" claim
    """
    try:
        # Inspect the claims only; no signature or claim validation happens here.
        claims = jwt.get_unverified_claims(token)
        exp_timestamp = claims.get("exp")
        if not exp_timestamp:
            return None
        return datetime.fromtimestamp(exp_timestamp, tz=timezone.utc)
    except Exception as e:
        # Best effort by design: any failure is reported as "unknown expiry".
        logger.debug(f"Failed to get token expiry: {e}")
        return None
|
|
256
|
+
|