fastworkflow 2.17.8__py3-none-any.whl → 2.17.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastworkflow/_workflows/command_metadata_extraction/parameter_extraction.py +81 -5
- fastworkflow/chat_session.py +102 -33
- fastworkflow/command_metadata_api.py +50 -0
- fastworkflow/intent_clarification_agent.py +132 -0
- fastworkflow/run_fastapi_mcp/README.md +58 -40
- fastworkflow/run_fastapi_mcp/__main__.py +197 -114
- fastworkflow/run_fastapi_mcp/conversation_store.py +4 -4
- fastworkflow/run_fastapi_mcp/jwt_manager.py +24 -14
- fastworkflow/run_fastapi_mcp/utils.py +116 -52
- fastworkflow/train/__main__.py +1 -1
- fastworkflow/utils/react.py +13 -1
- fastworkflow/utils/signatures.py +49 -31
- fastworkflow/workflow_agent.py +78 -4
- {fastworkflow-2.17.8.dist-info → fastworkflow-2.17.10.dist-info}/METADATA +1 -1
- {fastworkflow-2.17.8.dist-info → fastworkflow-2.17.10.dist-info}/RECORD +18 -17
- {fastworkflow-2.17.8.dist-info → fastworkflow-2.17.10.dist-info}/LICENSE +0 -0
- {fastworkflow-2.17.8.dist-info → fastworkflow-2.17.10.dist-info}/WHEEL +0 -0
- {fastworkflow-2.17.8.dist-info → fastworkflow-2.17.10.dist-info}/entry_points.txt +0 -0
|
@@ -90,9 +90,9 @@ class ConversationSummary(BaseModel):
|
|
|
90
90
|
class ConversationStore:
|
|
91
91
|
"""Rdict-backed conversation persistence per user"""
|
|
92
92
|
|
|
93
|
-
def __init__(self,
|
|
94
|
-
self.
|
|
95
|
-
self.db_path = os.path.join(base_folder, f"{
|
|
93
|
+
def __init__(self, channel_id: str, base_folder: str):
|
|
94
|
+
self.channel_id = channel_id
|
|
95
|
+
self.db_path = os.path.join(base_folder, f"{channel_id}.rdb")
|
|
96
96
|
os.makedirs(base_folder, exist_ok=True)
|
|
97
97
|
|
|
98
98
|
def _get_db(self) -> Rdict:
|
|
@@ -352,7 +352,7 @@ class ConversationStore:
|
|
|
352
352
|
if conv_key in db:
|
|
353
353
|
conv = db[conv_key]
|
|
354
354
|
conversations.append({
|
|
355
|
-
"
|
|
355
|
+
"channel_id": self.channel_id,
|
|
356
356
|
"conversation_id": i,
|
|
357
357
|
**conv
|
|
358
358
|
})
|
|
@@ -153,7 +153,7 @@ def set_jwt_verification_mode(expect_encrypted: bool) -> None:
|
|
|
153
153
|
)
|
|
154
154
|
|
|
155
155
|
|
|
156
|
-
def create_access_token(user_id: str, expires_days: int | None = None) -> str:
|
|
156
|
+
def create_access_token(channel_id: str, user_id: Optional[str] = None, expires_days: int | None = None) -> str:
|
|
157
157
|
"""
|
|
158
158
|
Create a JWT access token for a user.
|
|
159
159
|
|
|
@@ -162,7 +162,8 @@ def create_access_token(user_id: str, expires_days: int | None = None) -> str:
|
|
|
162
162
|
- If False: Creates an unsigned token for trusted network use
|
|
163
163
|
|
|
164
164
|
Args:
|
|
165
|
-
|
|
165
|
+
channel_id: Channel identifier (required)
|
|
166
|
+
user_id: User identifier (optional, will be included as uid claim if provided)
|
|
166
167
|
expires_days: Optional custom expiration in days. If None, uses JWT_ACCESS_TOKEN_EXPIRE_MINUTES (default 60 minutes).
|
|
167
168
|
|
|
168
169
|
Returns:
|
|
@@ -176,30 +177,34 @@ def create_access_token(user_id: str, expires_days: int | None = None) -> str:
|
|
|
176
177
|
|
|
177
178
|
# JWT claims
|
|
178
179
|
payload = {
|
|
179
|
-
"sub":
|
|
180
|
+
"sub": channel_id, # Subject: the channel identifier
|
|
180
181
|
"iat": int(now.timestamp()), # Issued at
|
|
181
182
|
"exp": int(expire.timestamp()), # Expiration time
|
|
182
|
-
"jti": f"{
|
|
183
|
+
"jti": f"{channel_id}_{int(now.timestamp())}", # JWT ID (unique identifier)
|
|
183
184
|
"type": "access", # Token type
|
|
184
185
|
"iss": JWT_ISSUER, # Issuer
|
|
185
186
|
"aud": JWT_AUDIENCE # Audience
|
|
186
187
|
}
|
|
187
188
|
|
|
189
|
+
# Add optional user_id claim
|
|
190
|
+
if user_id is not None:
|
|
191
|
+
payload["uid"] = user_id
|
|
192
|
+
|
|
188
193
|
if EXPECT_ENCRYPTED_JWT:
|
|
189
194
|
# Secure mode: create signed token
|
|
190
195
|
private_key, _ = load_or_generate_keys()
|
|
191
196
|
token = jwt.encode(payload, private_key, algorithm=JWT_ALGORITHM)
|
|
192
|
-
logger.debug(f"Created signed access token for user_id: {user_id}, expires: {expire.isoformat()}")
|
|
197
|
+
logger.debug(f"Created signed access token for channel_id: {channel_id}, user_id: {user_id}, expires: {expire.isoformat()}")
|
|
193
198
|
else:
|
|
194
199
|
# Trusted network mode: create unsigned token using HS256 with empty key
|
|
195
200
|
# This creates a JWT that can be decoded without verification
|
|
196
201
|
token = jwt.encode(payload, "", algorithm="HS256")
|
|
197
|
-
logger.debug(f"Created unsigned access token for user_id: {user_id}, expires: {expire.isoformat()}")
|
|
202
|
+
logger.debug(f"Created unsigned access token for channel_id: {channel_id}, user_id: {user_id}, expires: {expire.isoformat()}")
|
|
198
203
|
|
|
199
204
|
return token
|
|
200
205
|
|
|
201
206
|
|
|
202
|
-
def create_refresh_token(user_id: str) -> str:
|
|
207
|
+
def create_refresh_token(channel_id: str, user_id: Optional[str] = None) -> str:
|
|
203
208
|
"""
|
|
204
209
|
Create a JWT refresh token for a user.
|
|
205
210
|
|
|
@@ -208,7 +213,8 @@ def create_refresh_token(user_id: str) -> str:
|
|
|
208
213
|
- If False: Creates an unsigned token for trusted network use
|
|
209
214
|
|
|
210
215
|
Args:
|
|
211
|
-
|
|
216
|
+
channel_id: Channel identifier (required)
|
|
217
|
+
user_id: User identifier (optional, will be included as uid claim if provided)
|
|
212
218
|
|
|
213
219
|
Returns:
|
|
214
220
|
str: Encoded JWT refresh token (signed or unsigned based on EXPECT_ENCRYPTED_JWT)
|
|
@@ -218,25 +224,29 @@ def create_refresh_token(user_id: str) -> str:
|
|
|
218
224
|
|
|
219
225
|
# JWT claims
|
|
220
226
|
payload = {
|
|
221
|
-
"sub":
|
|
227
|
+
"sub": channel_id, # Subject: the channel identifier
|
|
222
228
|
"iat": int(now.timestamp()), # Issued at
|
|
223
229
|
"exp": int(expire.timestamp()), # Expiration time
|
|
224
|
-
"jti": f"{
|
|
230
|
+
"jti": f"{channel_id}_{int(now.timestamp())}_refresh", # JWT ID (unique identifier)
|
|
225
231
|
"type": "refresh", # Token type
|
|
226
232
|
"iss": JWT_ISSUER, # Issuer
|
|
227
233
|
"aud": JWT_AUDIENCE # Audience
|
|
228
234
|
}
|
|
229
235
|
|
|
236
|
+
# Add optional user_id claim
|
|
237
|
+
if user_id is not None:
|
|
238
|
+
payload["uid"] = user_id
|
|
239
|
+
|
|
230
240
|
if EXPECT_ENCRYPTED_JWT:
|
|
231
241
|
# Secure mode: create signed token
|
|
232
242
|
private_key, _ = load_or_generate_keys()
|
|
233
243
|
token = jwt.encode(payload, private_key, algorithm=JWT_ALGORITHM)
|
|
234
|
-
logger.debug(f"Created signed refresh token for user_id: {user_id}, expires: {expire.isoformat()}")
|
|
244
|
+
logger.debug(f"Created signed refresh token for channel_id: {channel_id}, user_id: {user_id}, expires: {expire.isoformat()}")
|
|
235
245
|
else:
|
|
236
246
|
# Trusted network mode: create unsigned token using HS256 with empty key
|
|
237
247
|
# This creates a JWT that can be decoded without verification
|
|
238
248
|
token = jwt.encode(payload, "", algorithm="HS256")
|
|
239
|
-
logger.debug(f"Created unsigned refresh token for user_id: {user_id}, expires: {expire.isoformat()}")
|
|
249
|
+
logger.debug(f"Created unsigned refresh token for channel_id: {channel_id}, user_id: {user_id}, expires: {expire.isoformat()}")
|
|
240
250
|
|
|
241
251
|
return token
|
|
242
252
|
|
|
@@ -281,7 +291,7 @@ def verify_token(token: str, expected_type: str = "access") -> dict:
|
|
|
281
291
|
if payload.get("type") != expected_type:
|
|
282
292
|
raise JWTError(f"Invalid token type: expected {expected_type}, got {payload.get('type')}")
|
|
283
293
|
|
|
284
|
-
logger.debug(f"Token decoded (unverified mode):
|
|
294
|
+
logger.debug(f"Token decoded (unverified mode): channel_id={payload.get('sub')}, type={expected_type}")
|
|
285
295
|
return payload
|
|
286
296
|
|
|
287
297
|
# Standard mode: full verification (existing code)
|
|
@@ -301,7 +311,7 @@ def verify_token(token: str, expected_type: str = "access") -> dict:
|
|
|
301
311
|
if payload.get("type") != expected_type:
|
|
302
312
|
raise JWTError(f"Invalid token type: expected {expected_type}, got {payload.get('type')}")
|
|
303
313
|
|
|
304
|
-
logger.debug(f"Token verified successfully:
|
|
314
|
+
logger.debug(f"Token verified successfully: channel_id={payload.get('sub')}, type={expected_type}")
|
|
305
315
|
return payload
|
|
306
316
|
|
|
307
317
|
except JWTError as e:
|
|
@@ -22,9 +22,12 @@ from .jwt_manager import verify_token
|
|
|
22
22
|
# ============================================================================
|
|
23
23
|
|
|
24
24
|
class InitializationRequest(BaseModel):
|
|
25
|
-
"""Request to initialize a FastWorkflow session for a
|
|
26
|
-
|
|
25
|
+
"""Request to initialize a FastWorkflow session for a channel"""
|
|
26
|
+
channel_id: str
|
|
27
|
+
user_id: Optional[str] = None # Required if startup_command or startup_action provided
|
|
27
28
|
stream_format: Optional[str] = None # "ndjson" | "sse" (default ndjson)
|
|
29
|
+
startup_command: Optional[str] = None # Mutually exclusive with startup_action
|
|
30
|
+
startup_action: Optional[dict[str, Any]] = None # Mutually exclusive with startup_command
|
|
28
31
|
|
|
29
32
|
|
|
30
33
|
class TokenResponse(BaseModel):
|
|
@@ -35,9 +38,19 @@ class TokenResponse(BaseModel):
|
|
|
35
38
|
expires_in: int # Access token expiration in seconds
|
|
36
39
|
|
|
37
40
|
|
|
41
|
+
class InitializeResponse(BaseModel):
|
|
42
|
+
"""Response from initialization including tokens and optional startup output"""
|
|
43
|
+
access_token: str
|
|
44
|
+
refresh_token: str
|
|
45
|
+
token_type: str = "bearer"
|
|
46
|
+
expires_in: int # Access token expiration in seconds
|
|
47
|
+
startup_output: Optional[fastworkflow.CommandOutput] = None # Present if startup was executed
|
|
48
|
+
|
|
49
|
+
|
|
38
50
|
class SessionData(BaseModel):
|
|
39
51
|
"""Validated session data extracted from JWT token"""
|
|
40
|
-
|
|
52
|
+
channel_id: str
|
|
53
|
+
user_id: Optional[str] = None # From JWT uid claim
|
|
41
54
|
token_type: str # "access" or "refresh"
|
|
42
55
|
issued_at: int # Unix timestamp
|
|
43
56
|
expires_at: int # Unix timestamp
|
|
@@ -48,7 +61,7 @@ class SessionData(BaseModel):
|
|
|
48
61
|
class InvokeRequest(BaseModel):
|
|
49
62
|
"""
|
|
50
63
|
Request to invoke agent or assistant.
|
|
51
|
-
Requires
|
|
64
|
+
Requires channel_id to be passed in the Authorization header (via JWT token).
|
|
52
65
|
"""
|
|
53
66
|
user_query: str
|
|
54
67
|
timeout_seconds: int = 60
|
|
@@ -57,7 +70,7 @@ class InvokeRequest(BaseModel):
|
|
|
57
70
|
class PerformActionRequest(BaseModel):
|
|
58
71
|
"""
|
|
59
72
|
Request to perform a specific action.
|
|
60
|
-
Requires
|
|
73
|
+
Requires channel_id to be passed in the Authorization header (via JWT token).
|
|
61
74
|
"""
|
|
62
75
|
action: dict[str, Any] # Will be converted to fastworkflow.Action
|
|
63
76
|
timeout_seconds: int = 60
|
|
@@ -66,7 +79,7 @@ class PerformActionRequest(BaseModel):
|
|
|
66
79
|
class PostFeedbackRequest(BaseModel):
|
|
67
80
|
"""
|
|
68
81
|
Request to post feedback on the latest turn.
|
|
69
|
-
Requires
|
|
82
|
+
Requires channel_id to be passed in the Authorization header (via JWT token).
|
|
70
83
|
|
|
71
84
|
Note: binary_or_numeric_score accepts numeric values (float).
|
|
72
85
|
Boolean values (True/False) are automatically converted to 1.0/0.0.
|
|
@@ -86,7 +99,7 @@ class PostFeedbackRequest(BaseModel):
|
|
|
86
99
|
class ActivateConversationRequest(BaseModel):
|
|
87
100
|
"""
|
|
88
101
|
Request to activate a conversation by ID.
|
|
89
|
-
Requires
|
|
102
|
+
Requires channel_id to be passed in the Authorization header (via JWT token).
|
|
90
103
|
"""
|
|
91
104
|
conversation_id: int
|
|
92
105
|
|
|
@@ -98,7 +111,8 @@ class DumpConversationsRequest(BaseModel):
|
|
|
98
111
|
|
|
99
112
|
class GenerateMCPTokenRequest(BaseModel):
|
|
100
113
|
"""Request to generate a long-lived MCP token"""
|
|
101
|
-
|
|
114
|
+
channel_id: str
|
|
115
|
+
user_id: Optional[str] = None
|
|
102
116
|
expires_days: int = 365
|
|
103
117
|
|
|
104
118
|
|
|
@@ -151,7 +165,7 @@ def get_session_from_jwt(
|
|
|
151
165
|
```python
|
|
152
166
|
@app.post("/endpoint")
|
|
153
167
|
async def endpoint(session: SessionData = Depends(get_session_from_jwt)):
|
|
154
|
-
# Use session.
|
|
168
|
+
# Use session.channel_id, session.token_type, etc.
|
|
155
169
|
pass
|
|
156
170
|
```
|
|
157
171
|
|
|
@@ -177,7 +191,8 @@ def get_session_from_jwt(
|
|
|
177
191
|
|
|
178
192
|
# Extract session data from payload, including the token for workflow context
|
|
179
193
|
return SessionData(
|
|
180
|
-
|
|
194
|
+
channel_id=payload["sub"],
|
|
195
|
+
user_id=payload.get("uid"), # Optional user_id from uid claim
|
|
181
196
|
token_type=payload["type"],
|
|
182
197
|
issued_at=payload["iat"],
|
|
183
198
|
expires_at=payload["exp"],
|
|
@@ -200,8 +215,8 @@ def get_session_from_jwt(
|
|
|
200
215
|
|
|
201
216
|
|
|
202
217
|
async def ensure_user_runtime_exists(
|
|
203
|
-
|
|
204
|
-
session_manager: '
|
|
218
|
+
channel_id: str,
|
|
219
|
+
session_manager: 'ChannelSessionManager',
|
|
205
220
|
workflow_path: str,
|
|
206
221
|
context: Optional[dict] = None,
|
|
207
222
|
startup_command: Optional[str] = None,
|
|
@@ -216,8 +231,8 @@ async def ensure_user_runtime_exists(
|
|
|
216
231
|
allowing it to be reused across different parts of the application without duplicating code.
|
|
217
232
|
|
|
218
233
|
Args:
|
|
219
|
-
|
|
220
|
-
session_manager: The
|
|
234
|
+
channel_id: The user identifier
|
|
235
|
+
session_manager: The ChannelSessionManager instance
|
|
221
236
|
workflow_path: Path to the workflow directory (validated at server startup)
|
|
222
237
|
context: Optional workflow context dictionary
|
|
223
238
|
startup_command: Optional startup command
|
|
@@ -229,9 +244,9 @@ async def ensure_user_runtime_exists(
|
|
|
229
244
|
HTTPException: If session creation fails
|
|
230
245
|
"""
|
|
231
246
|
# Check if user already has an active session
|
|
232
|
-
existing_runtime = await session_manager.get_session(
|
|
247
|
+
existing_runtime = await session_manager.get_session(channel_id)
|
|
233
248
|
if existing_runtime:
|
|
234
|
-
logger.debug(f"Session for
|
|
249
|
+
logger.debug(f"Session for channel_id {channel_id} already exists, skipping creation")
|
|
235
250
|
|
|
236
251
|
# Update the workflow's context with the current token if provided
|
|
237
252
|
if http_bearer_token and existing_runtime.chat_session:
|
|
@@ -243,7 +258,7 @@ async def ensure_user_runtime_exists(
|
|
|
243
258
|
# 2. The workflow is NOT marked dirty (won't persist to disk)
|
|
244
259
|
# 3. This is intentional for JWT tokens - we don't want to persist sensitive tokens
|
|
245
260
|
active_workflow.context['http_bearer_token'] = http_bearer_token
|
|
246
|
-
logger.debug(f"Updated http_bearer_token in workflow context for
|
|
261
|
+
logger.debug(f"Updated http_bearer_token in workflow context for channel_id {channel_id}")
|
|
247
262
|
|
|
248
263
|
return
|
|
249
264
|
|
|
@@ -256,13 +271,13 @@ async def ensure_user_runtime_exists(
|
|
|
256
271
|
# Initialize context with http_bearer_token
|
|
257
272
|
context = {'http_bearer_token': http_bearer_token}
|
|
258
273
|
|
|
259
|
-
logger.info(f"Creating new session for
|
|
274
|
+
logger.info(f"Creating new session for channel_id: {channel_id}")
|
|
260
275
|
|
|
261
|
-
# Resolve conversation store base folder from SPEEDDICT_FOLDERNAME/
|
|
262
|
-
conv_base_folder =
|
|
276
|
+
# Resolve conversation store base folder from SPEEDDICT_FOLDERNAME/channel_conversations
|
|
277
|
+
conv_base_folder = get_channelconversations_dir()
|
|
263
278
|
|
|
264
279
|
# Create conversation store for this user
|
|
265
|
-
conversation_store = ConversationStore(
|
|
280
|
+
conversation_store = ConversationStore(channel_id, conv_base_folder)
|
|
266
281
|
|
|
267
282
|
# Create ChatSession in agent mode (forced)
|
|
268
283
|
chat_session = fastworkflow.ChatSession(run_as_agent=True)
|
|
@@ -280,9 +295,9 @@ async def ensure_user_runtime_exists(
|
|
|
280
295
|
# Restore the conversation history from saved turns
|
|
281
296
|
restored_history = restore_history_from_turns(conversation["turns"])
|
|
282
297
|
chat_session._conversation_history = restored_history
|
|
283
|
-
logger.info(f"Restored conversation {conv_id_to_restore} for user {
|
|
298
|
+
logger.info(f"Restored conversation {conv_id_to_restore} for user {channel_id}")
|
|
284
299
|
else:
|
|
285
|
-
logger.info(f"No conversations available for user {
|
|
300
|
+
logger.info(f"No conversations available for user {channel_id}, starting new")
|
|
286
301
|
conv_id_to_restore = None
|
|
287
302
|
|
|
288
303
|
# Start the workflow
|
|
@@ -296,14 +311,14 @@ async def ensure_user_runtime_exists(
|
|
|
296
311
|
|
|
297
312
|
# Create and store user runtime
|
|
298
313
|
await session_manager.create_session(
|
|
299
|
-
|
|
314
|
+
channel_id=channel_id,
|
|
300
315
|
chat_session=chat_session,
|
|
301
316
|
conversation_store=conversation_store,
|
|
302
317
|
active_conversation_id=conv_id_to_restore,
|
|
303
318
|
stream_format=stream_format
|
|
304
319
|
)
|
|
305
320
|
|
|
306
|
-
logger.info(f"Successfully created session for
|
|
321
|
+
logger.info(f"Successfully created session for channel_id: {channel_id}")
|
|
307
322
|
|
|
308
323
|
# Wait for workflow to be ready (background thread sets status to RUNNING)
|
|
309
324
|
import asyncio
|
|
@@ -318,19 +333,19 @@ async def ensure_user_runtime_exists(
|
|
|
318
333
|
logger.warning(f"Workflow not fully started after {max_wait}s, status={chat_session._status}")
|
|
319
334
|
|
|
320
335
|
|
|
321
|
-
def
|
|
336
|
+
def get_channelconversations_dir() -> str:
|
|
322
337
|
"""
|
|
323
|
-
Return SPEEDDICT_FOLDERNAME/
|
|
338
|
+
Return SPEEDDICT_FOLDERNAME/channel_conversations, creating the directory if missing.
|
|
324
339
|
fastworkflow is injected to avoid circular imports and to access get_env_var.
|
|
325
340
|
"""
|
|
326
341
|
speedict_foldername = fastworkflow.get_env_var("SPEEDDICT_FOLDERNAME")
|
|
327
|
-
user_conversations_dir = os.path.join(speedict_foldername, "
|
|
342
|
+
user_conversations_dir = os.path.join(speedict_foldername, "channel_conversations")
|
|
328
343
|
os.makedirs(user_conversations_dir, exist_ok=True)
|
|
329
344
|
return user_conversations_dir
|
|
330
345
|
|
|
331
346
|
|
|
332
347
|
async def wait_for_command_output(
|
|
333
|
-
runtime: '
|
|
348
|
+
runtime: 'ChannelRuntime',
|
|
334
349
|
timeout_seconds: int
|
|
335
350
|
) -> 'fastworkflow.CommandOutput':
|
|
336
351
|
"""Wait for command output from the queue with timeout"""
|
|
@@ -349,14 +364,60 @@ async def wait_for_command_output(
|
|
|
349
364
|
)
|
|
350
365
|
|
|
351
366
|
|
|
352
|
-
def collect_trace_events(runtime: '
|
|
353
|
-
"""
|
|
367
|
+
def collect_trace_events(runtime: 'ChannelRuntime', user_id: Optional[str] = None) -> list[dict[str, Any]]:
|
|
368
|
+
"""
|
|
369
|
+
Drain and collect all trace events from the queue.
|
|
370
|
+
|
|
371
|
+
Args:
|
|
372
|
+
runtime: ChannelRuntime containing the trace queue
|
|
373
|
+
user_id: Optional user_id to include in traces
|
|
374
|
+
|
|
375
|
+
Returns:
|
|
376
|
+
List of trace event dictionaries with optional user_id
|
|
377
|
+
"""
|
|
354
378
|
traces = []
|
|
355
379
|
|
|
356
380
|
while True:
|
|
357
381
|
try:
|
|
358
382
|
evt = runtime.chat_session.command_trace_queue.get_nowait()
|
|
359
|
-
|
|
383
|
+
trace = {
|
|
384
|
+
"direction": evt.direction.value if hasattr(evt.direction, 'value') else str(evt.direction),
|
|
385
|
+
"raw_command": evt.raw_command,
|
|
386
|
+
"command_name": evt.command_name,
|
|
387
|
+
"parameters": evt.parameters,
|
|
388
|
+
"response_text": evt.response_text,
|
|
389
|
+
"success": evt.success,
|
|
390
|
+
"timestamp_ms": evt.timestamp_ms
|
|
391
|
+
}
|
|
392
|
+
if user_id is not None:
|
|
393
|
+
trace["user_id"] = user_id
|
|
394
|
+
traces.append(trace)
|
|
395
|
+
except queue.Empty:
|
|
396
|
+
break
|
|
397
|
+
|
|
398
|
+
return traces
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
async def collect_trace_events_async(
|
|
402
|
+
trace_queue: queue.Queue,
|
|
403
|
+
user_id: Optional[str] = None
|
|
404
|
+
) -> list[dict[str, Any]]:
|
|
405
|
+
"""
|
|
406
|
+
Async version: Drain and collect all trace events from a trace queue.
|
|
407
|
+
|
|
408
|
+
Args:
|
|
409
|
+
trace_queue: The trace queue to drain
|
|
410
|
+
user_id: Optional user_id to include in traces
|
|
411
|
+
|
|
412
|
+
Returns:
|
|
413
|
+
List of trace event dictionaries with optional user_id
|
|
414
|
+
"""
|
|
415
|
+
traces = []
|
|
416
|
+
|
|
417
|
+
while True:
|
|
418
|
+
try:
|
|
419
|
+
evt = trace_queue.get_nowait()
|
|
420
|
+
trace = {
|
|
360
421
|
"direction": evt.direction.value if hasattr(evt.direction, 'value') else str(evt.direction),
|
|
361
422
|
"raw_command": evt.raw_command,
|
|
362
423
|
"command_name": evt.command_name,
|
|
@@ -364,7 +425,10 @@ def collect_trace_events(runtime: 'UserRuntime') -> list[dict[str, Any]]:
|
|
|
364
425
|
"response_text": evt.response_text,
|
|
365
426
|
"success": evt.success,
|
|
366
427
|
"timestamp_ms": evt.timestamp_ms
|
|
367
|
-
}
|
|
428
|
+
}
|
|
429
|
+
if user_id is not None:
|
|
430
|
+
trace["user_id"] = user_id
|
|
431
|
+
traces.append(trace)
|
|
368
432
|
except queue.Empty:
|
|
369
433
|
break
|
|
370
434
|
|
|
@@ -376,9 +440,9 @@ def collect_trace_events(runtime: 'UserRuntime') -> list[dict[str, Any]]:
|
|
|
376
440
|
# ============================================================================
|
|
377
441
|
|
|
378
442
|
@dataclass
|
|
379
|
-
class
|
|
380
|
-
"""Per-
|
|
381
|
-
|
|
443
|
+
class ChannelRuntime:
|
|
444
|
+
"""Per-channel runtime state"""
|
|
445
|
+
channel_id: str
|
|
382
446
|
active_conversation_id: int
|
|
383
447
|
chat_session: 'fastworkflow.ChatSession'
|
|
384
448
|
lock: asyncio.Lock
|
|
@@ -386,51 +450,51 @@ class UserRuntime:
|
|
|
386
450
|
stream_format: str = "ndjson" # "ndjson" | "sse"
|
|
387
451
|
|
|
388
452
|
|
|
389
|
-
class
|
|
390
|
-
"""Process-wide manager for
|
|
453
|
+
class ChannelSessionManager:
|
|
454
|
+
"""Process-wide manager for channel sessions"""
|
|
391
455
|
|
|
392
456
|
def __init__(self):
|
|
393
|
-
self._sessions: dict[str,
|
|
457
|
+
self._sessions: dict[str, ChannelRuntime] = {}
|
|
394
458
|
self._lock = asyncio.Lock()
|
|
395
459
|
|
|
396
|
-
async def get_session(self,
|
|
397
|
-
"""Get a session by
|
|
460
|
+
async def get_session(self, channel_id: str) -> Optional[ChannelRuntime]:
|
|
461
|
+
"""Get a session by channel_id"""
|
|
398
462
|
async with self._lock:
|
|
399
|
-
return self._sessions.get(
|
|
463
|
+
return self._sessions.get(channel_id)
|
|
400
464
|
|
|
401
465
|
async def create_session(
|
|
402
466
|
self,
|
|
403
|
-
|
|
467
|
+
channel_id: str,
|
|
404
468
|
chat_session: 'fastworkflow.ChatSession',
|
|
405
469
|
conversation_store: 'ConversationStore',
|
|
406
470
|
active_conversation_id: Optional[int] = None,
|
|
407
471
|
stream_format: str = "ndjson"
|
|
408
|
-
) ->
|
|
472
|
+
) -> ChannelRuntime:
|
|
409
473
|
"""Create or update a session"""
|
|
410
474
|
async with self._lock:
|
|
411
|
-
runtime =
|
|
412
|
-
|
|
475
|
+
runtime = ChannelRuntime(
|
|
476
|
+
channel_id=channel_id,
|
|
413
477
|
active_conversation_id=active_conversation_id or 0,
|
|
414
478
|
chat_session=chat_session,
|
|
415
479
|
lock=asyncio.Lock(),
|
|
416
480
|
conversation_store=conversation_store,
|
|
417
481
|
stream_format=stream_format
|
|
418
482
|
)
|
|
419
|
-
self._sessions[
|
|
483
|
+
self._sessions[channel_id] = runtime
|
|
420
484
|
return runtime
|
|
421
485
|
|
|
422
|
-
async def remove_session(self,
|
|
486
|
+
async def remove_session(self, channel_id: str) -> None:
|
|
423
487
|
"""Remove a session"""
|
|
424
488
|
async with self._lock:
|
|
425
|
-
if
|
|
426
|
-
del self._sessions[
|
|
489
|
+
if channel_id in self._sessions:
|
|
490
|
+
del self._sessions[channel_id]
|
|
427
491
|
|
|
428
492
|
|
|
429
493
|
# ============================================================================
|
|
430
494
|
# Helper Functions
|
|
431
495
|
# ============================================================================
|
|
432
496
|
|
|
433
|
-
def save_conversation_incremental(runtime:
|
|
497
|
+
def save_conversation_incremental(runtime: ChannelRuntime, extract_turns_func, logger) -> None:
|
|
434
498
|
"""
|
|
435
499
|
Save conversation turns incrementally after each turn (without generating topic/summary).
|
|
436
500
|
This provides crash protection - all turns except the last will be preserved.
|
|
@@ -442,7 +506,7 @@ def save_conversation_incremental(runtime: UserRuntime, extract_turns_func, logg
|
|
|
442
506
|
# This is the first conversation for this session
|
|
443
507
|
# Reserve ID 1 and use it
|
|
444
508
|
runtime.active_conversation_id = runtime.conversation_store.reserve_next_conversation_id()
|
|
445
|
-
logger.debug(f"Initialized first conversation with ID {runtime.active_conversation_id} for user {runtime.
|
|
509
|
+
logger.debug(f"Initialized first conversation with ID {runtime.active_conversation_id} for user {runtime.channel_id}")
|
|
446
510
|
|
|
447
511
|
# Save turns using the active conversation ID
|
|
448
512
|
runtime.conversation_store.save_conversation_turns(
|
fastworkflow/train/__main__.py
CHANGED
|
@@ -114,7 +114,7 @@ def _get_commands_with_parameters(json_path):
|
|
|
114
114
|
command_directory = json.load(f)
|
|
115
115
|
|
|
116
116
|
# Extract the command metadata
|
|
117
|
-
commands_metadata = command_directory.get("
|
|
117
|
+
commands_metadata = command_directory.get("map_command_2_metadata", {})
|
|
118
118
|
|
|
119
119
|
# Initialize result dictionary
|
|
120
120
|
commands_with_parameters = {}
|
fastworkflow/utils/react.py
CHANGED
|
@@ -2,6 +2,7 @@ import logging
|
|
|
2
2
|
from typing import TYPE_CHECKING, Any, Callable, Literal
|
|
3
3
|
|
|
4
4
|
from litellm import ContextWindowExceededError
|
|
5
|
+
from litellm import exceptions as litellm_exceptions
|
|
5
6
|
|
|
6
7
|
import dspy
|
|
7
8
|
from dspy.adapters.types.tool import Tool
|
|
@@ -41,6 +42,7 @@ class fastWorkflowReAct(Module):
|
|
|
41
42
|
super().__init__()
|
|
42
43
|
self.signature = signature = ensure_signature(signature)
|
|
43
44
|
self.max_iters = max_iters
|
|
45
|
+
self.iteration_counter = 0
|
|
44
46
|
|
|
45
47
|
tools = [t if isinstance(t, Tool) else Tool(t) for t in tools]
|
|
46
48
|
tools = {tool.name: tool for tool in tools}
|
|
@@ -105,7 +107,8 @@ class fastWorkflowReAct(Module):
|
|
|
105
107
|
|
|
106
108
|
trajectory = {}
|
|
107
109
|
max_iters = input_args.pop("max_iters", self.max_iters)
|
|
108
|
-
|
|
110
|
+
idx = 0
|
|
111
|
+
while True:
|
|
109
112
|
try:
|
|
110
113
|
pred = self._call_with_potential_trajectory_truncation(self.react, trajectory, **input_args)
|
|
111
114
|
except ValueError as err:
|
|
@@ -126,6 +129,12 @@ class fastWorkflowReAct(Module):
|
|
|
126
129
|
if pred.next_tool_name == "finish":
|
|
127
130
|
break
|
|
128
131
|
|
|
132
|
+
idx += 1 # this is the counter for the index of the entire trajectory
|
|
133
|
+
self.iteration_counter += 1 # this counter just determines the number of times we run the react agent and it's reset everytime we call the user for clarification
|
|
134
|
+
if self.iteration_counter >= max_iters:
|
|
135
|
+
logger.warning("Max iterations reached")
|
|
136
|
+
break
|
|
137
|
+
|
|
129
138
|
extract = self._call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args)
|
|
130
139
|
return dspy.Prediction(trajectory=trajectory, **extract)
|
|
131
140
|
|
|
@@ -161,6 +170,9 @@ class fastWorkflowReAct(Module):
|
|
|
161
170
|
**input_args,
|
|
162
171
|
trajectory=self._format_trajectory(trajectory),
|
|
163
172
|
)
|
|
173
|
+
except litellm_exceptions.BadRequestError:
|
|
174
|
+
logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.")
|
|
175
|
+
trajectory = self.truncate_trajectory(trajectory)
|
|
164
176
|
except ContextWindowExceededError:
|
|
165
177
|
logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.")
|
|
166
178
|
trajectory = self.truncate_trajectory(trajectory)
|