fastworkflow 2.17.8__py3-none-any.whl → 2.17.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -90,9 +90,9 @@ class ConversationSummary(BaseModel):
90
90
  class ConversationStore:
91
91
  """Rdict-backed conversation persistence per user"""
92
92
 
93
- def __init__(self, user_id: str, base_folder: str):
94
- self.user_id = user_id
95
- self.db_path = os.path.join(base_folder, f"{user_id}.rdb")
93
+ def __init__(self, channel_id: str, base_folder: str):
94
+ self.channel_id = channel_id
95
+ self.db_path = os.path.join(base_folder, f"{channel_id}.rdb")
96
96
  os.makedirs(base_folder, exist_ok=True)
97
97
 
98
98
  def _get_db(self) -> Rdict:
@@ -352,7 +352,7 @@ class ConversationStore:
352
352
  if conv_key in db:
353
353
  conv = db[conv_key]
354
354
  conversations.append({
355
- "user_id": self.user_id,
355
+ "channel_id": self.channel_id,
356
356
  "conversation_id": i,
357
357
  **conv
358
358
  })
@@ -153,7 +153,7 @@ def set_jwt_verification_mode(expect_encrypted: bool) -> None:
153
153
  )
154
154
 
155
155
 
156
- def create_access_token(user_id: str, expires_days: int | None = None) -> str:
156
+ def create_access_token(channel_id: str, user_id: Optional[str] = None, expires_days: int | None = None) -> str:
157
157
  """
158
158
  Create a JWT access token for a user.
159
159
 
@@ -162,7 +162,8 @@ def create_access_token(user_id: str, expires_days: int | None = None) -> str:
162
162
  - If False: Creates an unsigned token for trusted network use
163
163
 
164
164
  Args:
165
- user_id: User identifier
165
+ channel_id: Channel identifier (required)
166
+ user_id: User identifier (optional, will be included as uid claim if provided)
166
167
  expires_days: Optional custom expiration in days. If None, uses JWT_ACCESS_TOKEN_EXPIRE_MINUTES (default 60 minutes).
167
168
 
168
169
  Returns:
@@ -176,30 +177,34 @@ def create_access_token(user_id: str, expires_days: int | None = None) -> str:
176
177
 
177
178
  # JWT claims
178
179
  payload = {
179
- "sub": user_id, # Subject: the user identifier
180
+ "sub": channel_id, # Subject: the channel identifier
180
181
  "iat": int(now.timestamp()), # Issued at
181
182
  "exp": int(expire.timestamp()), # Expiration time
182
- "jti": f"{user_id}_{int(now.timestamp())}", # JWT ID (unique identifier)
183
+ "jti": f"{channel_id}_{int(now.timestamp())}", # JWT ID (unique identifier)
183
184
  "type": "access", # Token type
184
185
  "iss": JWT_ISSUER, # Issuer
185
186
  "aud": JWT_AUDIENCE # Audience
186
187
  }
187
188
 
189
+ # Add optional user_id claim
190
+ if user_id is not None:
191
+ payload["uid"] = user_id
192
+
188
193
  if EXPECT_ENCRYPTED_JWT:
189
194
  # Secure mode: create signed token
190
195
  private_key, _ = load_or_generate_keys()
191
196
  token = jwt.encode(payload, private_key, algorithm=JWT_ALGORITHM)
192
- logger.debug(f"Created signed access token for user_id: {user_id}, expires: {expire.isoformat()}")
197
+ logger.debug(f"Created signed access token for channel_id: {channel_id}, user_id: {user_id}, expires: {expire.isoformat()}")
193
198
  else:
194
199
  # Trusted network mode: create unsigned token using HS256 with empty key
195
200
  # This creates a JWT that can be decoded without verification
196
201
  token = jwt.encode(payload, "", algorithm="HS256")
197
- logger.debug(f"Created unsigned access token for user_id: {user_id}, expires: {expire.isoformat()}")
202
+ logger.debug(f"Created unsigned access token for channel_id: {channel_id}, user_id: {user_id}, expires: {expire.isoformat()}")
198
203
 
199
204
  return token
200
205
 
201
206
 
202
- def create_refresh_token(user_id: str) -> str:
207
+ def create_refresh_token(channel_id: str, user_id: Optional[str] = None) -> str:
203
208
  """
204
209
  Create a JWT refresh token for a user.
205
210
 
@@ -208,7 +213,8 @@ def create_refresh_token(user_id: str) -> str:
208
213
  - If False: Creates an unsigned token for trusted network use
209
214
 
210
215
  Args:
211
- user_id: User identifier
216
+ channel_id: Channel identifier (required)
217
+ user_id: User identifier (optional, will be included as uid claim if provided)
212
218
 
213
219
  Returns:
214
220
  str: Encoded JWT refresh token (signed or unsigned based on EXPECT_ENCRYPTED_JWT)
@@ -218,25 +224,29 @@ def create_refresh_token(user_id: str) -> str:
218
224
 
219
225
  # JWT claims
220
226
  payload = {
221
- "sub": user_id, # Subject: the user identifier
227
+ "sub": channel_id, # Subject: the channel identifier
222
228
  "iat": int(now.timestamp()), # Issued at
223
229
  "exp": int(expire.timestamp()), # Expiration time
224
- "jti": f"{user_id}_{int(now.timestamp())}_refresh", # JWT ID (unique identifier)
230
+ "jti": f"{channel_id}_{int(now.timestamp())}_refresh", # JWT ID (unique identifier)
225
231
  "type": "refresh", # Token type
226
232
  "iss": JWT_ISSUER, # Issuer
227
233
  "aud": JWT_AUDIENCE # Audience
228
234
  }
229
235
 
236
+ # Add optional user_id claim
237
+ if user_id is not None:
238
+ payload["uid"] = user_id
239
+
230
240
  if EXPECT_ENCRYPTED_JWT:
231
241
  # Secure mode: create signed token
232
242
  private_key, _ = load_or_generate_keys()
233
243
  token = jwt.encode(payload, private_key, algorithm=JWT_ALGORITHM)
234
- logger.debug(f"Created signed refresh token for user_id: {user_id}, expires: {expire.isoformat()}")
244
+ logger.debug(f"Created signed refresh token for channel_id: {channel_id}, user_id: {user_id}, expires: {expire.isoformat()}")
235
245
  else:
236
246
  # Trusted network mode: create unsigned token using HS256 with empty key
237
247
  # This creates a JWT that can be decoded without verification
238
248
  token = jwt.encode(payload, "", algorithm="HS256")
239
- logger.debug(f"Created unsigned refresh token for user_id: {user_id}, expires: {expire.isoformat()}")
249
+ logger.debug(f"Created unsigned refresh token for channel_id: {channel_id}, user_id: {user_id}, expires: {expire.isoformat()}")
240
250
 
241
251
  return token
242
252
 
@@ -281,7 +291,7 @@ def verify_token(token: str, expected_type: str = "access") -> dict:
281
291
  if payload.get("type") != expected_type:
282
292
  raise JWTError(f"Invalid token type: expected {expected_type}, got {payload.get('type')}")
283
293
 
284
- logger.debug(f"Token decoded (unverified mode): user_id={payload.get('sub')}, type={expected_type}")
294
+ logger.debug(f"Token decoded (unverified mode): channel_id={payload.get('sub')}, type={expected_type}")
285
295
  return payload
286
296
 
287
297
  # Standard mode: full verification (existing code)
@@ -301,7 +311,7 @@ def verify_token(token: str, expected_type: str = "access") -> dict:
301
311
  if payload.get("type") != expected_type:
302
312
  raise JWTError(f"Invalid token type: expected {expected_type}, got {payload.get('type')}")
303
313
 
304
- logger.debug(f"Token verified successfully: user_id={payload.get('sub')}, type={expected_type}")
314
+ logger.debug(f"Token verified successfully: channel_id={payload.get('sub')}, type={expected_type}")
305
315
  return payload
306
316
 
307
317
  except JWTError as e:
@@ -22,9 +22,12 @@ from .jwt_manager import verify_token
22
22
  # ============================================================================
23
23
 
24
24
  class InitializationRequest(BaseModel):
25
- """Request to initialize a FastWorkflow session for a user"""
26
- user_id: str
25
+ """Request to initialize a FastWorkflow session for a channel"""
26
+ channel_id: str
27
+ user_id: Optional[str] = None # Required if startup_command or startup_action provided
27
28
  stream_format: Optional[str] = None # "ndjson" | "sse" (default ndjson)
29
+ startup_command: Optional[str] = None # Mutually exclusive with startup_action
30
+ startup_action: Optional[dict[str, Any]] = None # Mutually exclusive with startup_command
28
31
 
29
32
 
30
33
  class TokenResponse(BaseModel):
@@ -35,9 +38,19 @@ class TokenResponse(BaseModel):
35
38
  expires_in: int # Access token expiration in seconds
36
39
 
37
40
 
41
+ class InitializeResponse(BaseModel):
42
+ """Response from initialization including tokens and optional startup output"""
43
+ access_token: str
44
+ refresh_token: str
45
+ token_type: str = "bearer"
46
+ expires_in: int # Access token expiration in seconds
47
+ startup_output: Optional[fastworkflow.CommandOutput] = None # Present if startup was executed
48
+
49
+
38
50
  class SessionData(BaseModel):
39
51
  """Validated session data extracted from JWT token"""
40
- user_id: str
52
+ channel_id: str
53
+ user_id: Optional[str] = None # From JWT uid claim
41
54
  token_type: str # "access" or "refresh"
42
55
  issued_at: int # Unix timestamp
43
56
  expires_at: int # Unix timestamp
@@ -48,7 +61,7 @@ class SessionData(BaseModel):
48
61
  class InvokeRequest(BaseModel):
49
62
  """
50
63
  Request to invoke agent or assistant.
51
- Requires user_id to be passed in the Authorization header (via JWT token).
64
+ Requires channel_id to be passed in the Authorization header (via JWT token).
52
65
  """
53
66
  user_query: str
54
67
  timeout_seconds: int = 60
@@ -57,7 +70,7 @@ class InvokeRequest(BaseModel):
57
70
  class PerformActionRequest(BaseModel):
58
71
  """
59
72
  Request to perform a specific action.
60
- Requires user_id to be passed in the Authorization header (via JWT token).
73
+ Requires channel_id to be passed in the Authorization header (via JWT token).
61
74
  """
62
75
  action: dict[str, Any] # Will be converted to fastworkflow.Action
63
76
  timeout_seconds: int = 60
@@ -66,7 +79,7 @@ class PerformActionRequest(BaseModel):
66
79
  class PostFeedbackRequest(BaseModel):
67
80
  """
68
81
  Request to post feedback on the latest turn.
69
- Requires user_id to be passed in the Authorization header (via JWT token).
82
+ Requires channel_id to be passed in the Authorization header (via JWT token).
70
83
 
71
84
  Note: binary_or_numeric_score accepts numeric values (float).
72
85
  Boolean values (True/False) are automatically converted to 1.0/0.0.
@@ -86,7 +99,7 @@ class PostFeedbackRequest(BaseModel):
86
99
  class ActivateConversationRequest(BaseModel):
87
100
  """
88
101
  Request to activate a conversation by ID.
89
- Requires user_id to be passed in the Authorization header (via JWT token).
102
+ Requires channel_id to be passed in the Authorization header (via JWT token).
90
103
  """
91
104
  conversation_id: int
92
105
 
@@ -98,7 +111,8 @@ class DumpConversationsRequest(BaseModel):
98
111
 
99
112
  class GenerateMCPTokenRequest(BaseModel):
100
113
  """Request to generate a long-lived MCP token"""
101
- user_id: str
114
+ channel_id: str
115
+ user_id: Optional[str] = None
102
116
  expires_days: int = 365
103
117
 
104
118
 
@@ -151,7 +165,7 @@ def get_session_from_jwt(
151
165
  ```python
152
166
  @app.post("/endpoint")
153
167
  async def endpoint(session: SessionData = Depends(get_session_from_jwt)):
154
- # Use session.user_id, session.token_type, etc.
168
+ # Use session.channel_id, session.token_type, etc.
155
169
  pass
156
170
  ```
157
171
 
@@ -177,7 +191,8 @@ def get_session_from_jwt(
177
191
 
178
192
  # Extract session data from payload, including the token for workflow context
179
193
  return SessionData(
180
- user_id=payload["sub"],
194
+ channel_id=payload["sub"],
195
+ user_id=payload.get("uid"), # Optional user_id from uid claim
181
196
  token_type=payload["type"],
182
197
  issued_at=payload["iat"],
183
198
  expires_at=payload["exp"],
@@ -200,8 +215,8 @@ def get_session_from_jwt(
200
215
 
201
216
 
202
217
  async def ensure_user_runtime_exists(
203
- user_id: str,
204
- session_manager: 'UserSessionManager',
218
+ channel_id: str,
219
+ session_manager: 'ChannelSessionManager',
205
220
  workflow_path: str,
206
221
  context: Optional[dict] = None,
207
222
  startup_command: Optional[str] = None,
@@ -216,8 +231,8 @@ async def ensure_user_runtime_exists(
216
231
  allowing it to be reused across different parts of the application without duplicating code.
217
232
 
218
233
  Args:
219
- user_id: The user identifier
220
- session_manager: The UserSessionManager instance
234
+ channel_id: The channel identifier
235
+ session_manager: The ChannelSessionManager instance
221
236
  workflow_path: Path to the workflow directory (validated at server startup)
222
237
  context: Optional workflow context dictionary
223
238
  startup_command: Optional startup command
@@ -229,9 +244,9 @@ async def ensure_user_runtime_exists(
229
244
  HTTPException: If session creation fails
230
245
  """
231
246
  # Check if user already has an active session
232
- existing_runtime = await session_manager.get_session(user_id)
247
+ existing_runtime = await session_manager.get_session(channel_id)
233
248
  if existing_runtime:
234
- logger.debug(f"Session for user_id {user_id} already exists, skipping creation")
249
+ logger.debug(f"Session for channel_id {channel_id} already exists, skipping creation")
235
250
 
236
251
  # Update the workflow's context with the current token if provided
237
252
  if http_bearer_token and existing_runtime.chat_session:
@@ -243,7 +258,7 @@ async def ensure_user_runtime_exists(
243
258
  # 2. The workflow is NOT marked dirty (won't persist to disk)
244
259
  # 3. This is intentional for JWT tokens - we don't want to persist sensitive tokens
245
260
  active_workflow.context['http_bearer_token'] = http_bearer_token
246
- logger.debug(f"Updated http_bearer_token in workflow context for user_id {user_id}")
261
+ logger.debug(f"Updated http_bearer_token in workflow context for channel_id {channel_id}")
247
262
 
248
263
  return
249
264
 
@@ -256,13 +271,13 @@ async def ensure_user_runtime_exists(
256
271
  # Initialize context with http_bearer_token
257
272
  context = {'http_bearer_token': http_bearer_token}
258
273
 
259
- logger.info(f"Creating new session for user_id: {user_id}")
274
+ logger.info(f"Creating new session for channel_id: {channel_id}")
260
275
 
261
- # Resolve conversation store base folder from SPEEDDICT_FOLDERNAME/user_conversations
262
- conv_base_folder = get_userconversations_dir()
276
+ # Resolve conversation store base folder from SPEEDDICT_FOLDERNAME/channel_conversations
277
+ conv_base_folder = get_channelconversations_dir()
263
278
 
264
279
  # Create conversation store for this user
265
- conversation_store = ConversationStore(user_id, conv_base_folder)
280
+ conversation_store = ConversationStore(channel_id, conv_base_folder)
266
281
 
267
282
  # Create ChatSession in agent mode (forced)
268
283
  chat_session = fastworkflow.ChatSession(run_as_agent=True)
@@ -280,9 +295,9 @@ async def ensure_user_runtime_exists(
280
295
  # Restore the conversation history from saved turns
281
296
  restored_history = restore_history_from_turns(conversation["turns"])
282
297
  chat_session._conversation_history = restored_history
283
- logger.info(f"Restored conversation {conv_id_to_restore} for user {user_id}")
298
+ logger.info(f"Restored conversation {conv_id_to_restore} for user {channel_id}")
284
299
  else:
285
- logger.info(f"No conversations available for user {user_id}, starting new")
300
+ logger.info(f"No conversations available for user {channel_id}, starting new")
286
301
  conv_id_to_restore = None
287
302
 
288
303
  # Start the workflow
@@ -296,14 +311,14 @@ async def ensure_user_runtime_exists(
296
311
 
297
312
  # Create and store user runtime
298
313
  await session_manager.create_session(
299
- user_id=user_id,
314
+ channel_id=channel_id,
300
315
  chat_session=chat_session,
301
316
  conversation_store=conversation_store,
302
317
  active_conversation_id=conv_id_to_restore,
303
318
  stream_format=stream_format
304
319
  )
305
320
 
306
- logger.info(f"Successfully created session for user_id: {user_id}")
321
+ logger.info(f"Successfully created session for channel_id: {channel_id}")
307
322
 
308
323
  # Wait for workflow to be ready (background thread sets status to RUNNING)
309
324
  import asyncio
@@ -318,19 +333,19 @@ async def ensure_user_runtime_exists(
318
333
  logger.warning(f"Workflow not fully started after {max_wait}s, status={chat_session._status}")
319
334
 
320
335
 
321
- def get_userconversations_dir() -> str:
336
+ def get_channelconversations_dir() -> str:
322
337
  """
323
- Return SPEEDDICT_FOLDERNAME/user_conversations, creating the directory if missing.
338
+ Return SPEEDDICT_FOLDERNAME/channel_conversations, creating the directory if missing.
324
339
  fastworkflow is injected to avoid circular imports and to access get_env_var.
325
340
  """
326
341
  speedict_foldername = fastworkflow.get_env_var("SPEEDDICT_FOLDERNAME")
327
- user_conversations_dir = os.path.join(speedict_foldername, "user_conversations")
342
+ user_conversations_dir = os.path.join(speedict_foldername, "channel_conversations")
328
343
  os.makedirs(user_conversations_dir, exist_ok=True)
329
344
  return user_conversations_dir
330
345
 
331
346
 
332
347
  async def wait_for_command_output(
333
- runtime: 'UserRuntime',
348
+ runtime: 'ChannelRuntime',
334
349
  timeout_seconds: int
335
350
  ) -> 'fastworkflow.CommandOutput':
336
351
  """Wait for command output from the queue with timeout"""
@@ -349,14 +364,60 @@ async def wait_for_command_output(
349
364
  )
350
365
 
351
366
 
352
- def collect_trace_events(runtime: 'UserRuntime') -> list[dict[str, Any]]:
353
- """Drain and collect all trace events from the queue"""
367
+ def collect_trace_events(runtime: 'ChannelRuntime', user_id: Optional[str] = None) -> list[dict[str, Any]]:
368
+ """
369
+ Drain and collect all trace events from the queue.
370
+
371
+ Args:
372
+ runtime: ChannelRuntime containing the trace queue
373
+ user_id: Optional user_id to include in traces
374
+
375
+ Returns:
376
+ List of trace event dictionaries with optional user_id
377
+ """
354
378
  traces = []
355
379
 
356
380
  while True:
357
381
  try:
358
382
  evt = runtime.chat_session.command_trace_queue.get_nowait()
359
- traces.append({
383
+ trace = {
384
+ "direction": evt.direction.value if hasattr(evt.direction, 'value') else str(evt.direction),
385
+ "raw_command": evt.raw_command,
386
+ "command_name": evt.command_name,
387
+ "parameters": evt.parameters,
388
+ "response_text": evt.response_text,
389
+ "success": evt.success,
390
+ "timestamp_ms": evt.timestamp_ms
391
+ }
392
+ if user_id is not None:
393
+ trace["user_id"] = user_id
394
+ traces.append(trace)
395
+ except queue.Empty:
396
+ break
397
+
398
+ return traces
399
+
400
+
401
+ async def collect_trace_events_async(
402
+ trace_queue: queue.Queue,
403
+ user_id: Optional[str] = None
404
+ ) -> list[dict[str, Any]]:
405
+ """
406
+ Async version: Drain and collect all trace events from a trace queue.
407
+
408
+ Args:
409
+ trace_queue: The trace queue to drain
410
+ user_id: Optional user_id to include in traces
411
+
412
+ Returns:
413
+ List of trace event dictionaries with optional user_id
414
+ """
415
+ traces = []
416
+
417
+ while True:
418
+ try:
419
+ evt = trace_queue.get_nowait()
420
+ trace = {
360
421
  "direction": evt.direction.value if hasattr(evt.direction, 'value') else str(evt.direction),
361
422
  "raw_command": evt.raw_command,
362
423
  "command_name": evt.command_name,
@@ -364,7 +425,10 @@ def collect_trace_events(runtime: 'UserRuntime') -> list[dict[str, Any]]:
364
425
  "response_text": evt.response_text,
365
426
  "success": evt.success,
366
427
  "timestamp_ms": evt.timestamp_ms
367
- })
428
+ }
429
+ if user_id is not None:
430
+ trace["user_id"] = user_id
431
+ traces.append(trace)
368
432
  except queue.Empty:
369
433
  break
370
434
 
@@ -376,9 +440,9 @@ def collect_trace_events(runtime: 'UserRuntime') -> list[dict[str, Any]]:
376
440
  # ============================================================================
377
441
 
378
442
  @dataclass
379
- class UserRuntime:
380
- """Per-user runtime state"""
381
- user_id: str
443
+ class ChannelRuntime:
444
+ """Per-channel runtime state"""
445
+ channel_id: str
382
446
  active_conversation_id: int
383
447
  chat_session: 'fastworkflow.ChatSession'
384
448
  lock: asyncio.Lock
@@ -386,51 +450,51 @@ class UserRuntime:
386
450
  stream_format: str = "ndjson" # "ndjson" | "sse"
387
451
 
388
452
 
389
- class UserSessionManager:
390
- """Process-wide manager for user sessions"""
453
+ class ChannelSessionManager:
454
+ """Process-wide manager for channel sessions"""
391
455
 
392
456
  def __init__(self):
393
- self._sessions: dict[str, UserRuntime] = {}
457
+ self._sessions: dict[str, ChannelRuntime] = {}
394
458
  self._lock = asyncio.Lock()
395
459
 
396
- async def get_session(self, user_id: str) -> Optional[UserRuntime]:
397
- """Get a session by user_id"""
460
+ async def get_session(self, channel_id: str) -> Optional[ChannelRuntime]:
461
+ """Get a session by channel_id"""
398
462
  async with self._lock:
399
- return self._sessions.get(user_id)
463
+ return self._sessions.get(channel_id)
400
464
 
401
465
  async def create_session(
402
466
  self,
403
- user_id: str,
467
+ channel_id: str,
404
468
  chat_session: 'fastworkflow.ChatSession',
405
469
  conversation_store: 'ConversationStore',
406
470
  active_conversation_id: Optional[int] = None,
407
471
  stream_format: str = "ndjson"
408
- ) -> UserRuntime:
472
+ ) -> ChannelRuntime:
409
473
  """Create or update a session"""
410
474
  async with self._lock:
411
- runtime = UserRuntime(
412
- user_id=user_id,
475
+ runtime = ChannelRuntime(
476
+ channel_id=channel_id,
413
477
  active_conversation_id=active_conversation_id or 0,
414
478
  chat_session=chat_session,
415
479
  lock=asyncio.Lock(),
416
480
  conversation_store=conversation_store,
417
481
  stream_format=stream_format
418
482
  )
419
- self._sessions[user_id] = runtime
483
+ self._sessions[channel_id] = runtime
420
484
  return runtime
421
485
 
422
- async def remove_session(self, user_id: str) -> None:
486
+ async def remove_session(self, channel_id: str) -> None:
423
487
  """Remove a session"""
424
488
  async with self._lock:
425
- if user_id in self._sessions:
426
- del self._sessions[user_id]
489
+ if channel_id in self._sessions:
490
+ del self._sessions[channel_id]
427
491
 
428
492
 
429
493
  # ============================================================================
430
494
  # Helper Functions
431
495
  # ============================================================================
432
496
 
433
- def save_conversation_incremental(runtime: UserRuntime, extract_turns_func, logger) -> None:
497
+ def save_conversation_incremental(runtime: ChannelRuntime, extract_turns_func, logger) -> None:
434
498
  """
435
499
  Save conversation turns incrementally after each turn (without generating topic/summary).
436
500
  This provides crash protection - all turns except the last will be preserved.
@@ -442,7 +506,7 @@ def save_conversation_incremental(runtime: UserRuntime, extract_turns_func, logg
442
506
  # This is the first conversation for this session
443
507
  # Reserve ID 1 and use it
444
508
  runtime.active_conversation_id = runtime.conversation_store.reserve_next_conversation_id()
445
- logger.debug(f"Initialized first conversation with ID {runtime.active_conversation_id} for user {runtime.user_id}")
509
+ logger.debug(f"Initialized first conversation with ID {runtime.active_conversation_id} for user {runtime.channel_id}")
446
510
 
447
511
  # Save turns using the active conversation ID
448
512
  runtime.conversation_store.save_conversation_turns(
@@ -114,7 +114,7 @@ def _get_commands_with_parameters(json_path):
114
114
  command_directory = json.load(f)
115
115
 
116
116
  # Extract the command metadata
117
- commands_metadata = command_directory.get("map_commandkey_2_metadata", {})
117
+ commands_metadata = command_directory.get("map_command_2_metadata", {})
118
118
 
119
119
  # Initialize result dictionary
120
120
  commands_with_parameters = {}
@@ -2,6 +2,7 @@ import logging
2
2
  from typing import TYPE_CHECKING, Any, Callable, Literal
3
3
 
4
4
  from litellm import ContextWindowExceededError
5
+ from litellm import exceptions as litellm_exceptions
5
6
 
6
7
  import dspy
7
8
  from dspy.adapters.types.tool import Tool
@@ -41,6 +42,7 @@ class fastWorkflowReAct(Module):
41
42
  super().__init__()
42
43
  self.signature = signature = ensure_signature(signature)
43
44
  self.max_iters = max_iters
45
+ self.iteration_counter = 0
44
46
 
45
47
  tools = [t if isinstance(t, Tool) else Tool(t) for t in tools]
46
48
  tools = {tool.name: tool for tool in tools}
@@ -105,7 +107,8 @@ class fastWorkflowReAct(Module):
105
107
 
106
108
  trajectory = {}
107
109
  max_iters = input_args.pop("max_iters", self.max_iters)
108
- for idx in range(max_iters):
110
+ idx = 0
111
+ while True:
109
112
  try:
110
113
  pred = self._call_with_potential_trajectory_truncation(self.react, trajectory, **input_args)
111
114
  except ValueError as err:
@@ -126,6 +129,12 @@ class fastWorkflowReAct(Module):
126
129
  if pred.next_tool_name == "finish":
127
130
  break
128
131
 
132
+ idx += 1 # this is the counter for the index of the entire trajectory
133
+ self.iteration_counter += 1 # this counter just determines the number of times we run the react agent and it's reset every time we call the user for clarification
134
+ if self.iteration_counter >= max_iters:
135
+ logger.warning("Max iterations reached")
136
+ break
137
+
129
138
  extract = self._call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args)
130
139
  return dspy.Prediction(trajectory=trajectory, **extract)
131
140
 
@@ -161,6 +170,9 @@ class fastWorkflowReAct(Module):
161
170
  **input_args,
162
171
  trajectory=self._format_trajectory(trajectory),
163
172
  )
173
+ except litellm_exceptions.BadRequestError:
174
+ logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.")
175
+ trajectory = self.truncate_trajectory(trajectory)
164
176
  except ContextWindowExceededError:
165
177
  logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.")
166
178
  trajectory = self.truncate_trajectory(trajectory)