dtSpark 1.1.0a2__py3-none-any.whl → 1.1.0a6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. dtSpark/_version.txt +1 -1
  2. dtSpark/aws/authentication.py +1 -1
  3. dtSpark/aws/bedrock.py +238 -239
  4. dtSpark/aws/costs.py +9 -5
  5. dtSpark/aws/pricing.py +25 -21
  6. dtSpark/cli_interface.py +69 -62
  7. dtSpark/conversation_manager.py +54 -47
  8. dtSpark/core/application.py +151 -111
  9. dtSpark/core/context_compaction.py +241 -226
  10. dtSpark/daemon/__init__.py +36 -22
  11. dtSpark/daemon/action_monitor.py +46 -17
  12. dtSpark/daemon/daemon_app.py +126 -104
  13. dtSpark/daemon/daemon_manager.py +59 -23
  14. dtSpark/daemon/pid_file.py +3 -2
  15. dtSpark/database/autonomous_actions.py +3 -0
  16. dtSpark/database/credential_prompt.py +52 -54
  17. dtSpark/files/manager.py +6 -12
  18. dtSpark/limits/__init__.py +1 -1
  19. dtSpark/limits/tokens.py +2 -2
  20. dtSpark/llm/anthropic_direct.py +246 -141
  21. dtSpark/llm/ollama.py +3 -1
  22. dtSpark/mcp_integration/manager.py +4 -4
  23. dtSpark/mcp_integration/tool_selector.py +83 -77
  24. dtSpark/resources/config.yaml.template +10 -0
  25. dtSpark/safety/patterns.py +45 -46
  26. dtSpark/safety/prompt_inspector.py +8 -1
  27. dtSpark/scheduler/creation_tools.py +273 -181
  28. dtSpark/scheduler/executor.py +503 -221
  29. dtSpark/tools/builtin.py +70 -53
  30. dtSpark/web/endpoints/autonomous_actions.py +12 -9
  31. dtSpark/web/endpoints/chat.py +18 -6
  32. dtSpark/web/endpoints/conversations.py +57 -17
  33. dtSpark/web/endpoints/main_menu.py +132 -105
  34. dtSpark/web/endpoints/streaming.py +2 -2
  35. dtSpark/web/server.py +65 -5
  36. dtSpark/web/ssl_utils.py +3 -3
  37. dtSpark/web/static/css/dark-theme.css +8 -29
  38. dtSpark/web/static/js/actions.js +2 -1
  39. dtSpark/web/static/js/chat.js +6 -8
  40. dtSpark/web/static/js/main.js +8 -8
  41. dtSpark/web/static/js/sse-client.js +130 -122
  42. dtSpark/web/templates/actions.html +5 -5
  43. dtSpark/web/templates/base.html +13 -0
  44. dtSpark/web/templates/chat.html +52 -50
  45. dtSpark/web/templates/conversations.html +50 -22
  46. dtSpark/web/templates/goodbye.html +2 -2
  47. dtSpark/web/templates/main_menu.html +17 -17
  48. dtSpark/web/templates/new_conversation.html +51 -20
  49. dtSpark/web/web_interface.py +2 -2
  50. {dtspark-1.1.0a2.dist-info → dtspark-1.1.0a6.dist-info}/METADATA +9 -2
  51. dtspark-1.1.0a6.dist-info/RECORD +96 -0
  52. dtspark-1.1.0a2.dist-info/RECORD +0 -96
  53. {dtspark-1.1.0a2.dist-info → dtspark-1.1.0a6.dist-info}/WHEEL +0 -0
  54. {dtspark-1.1.0a2.dist-info → dtspark-1.1.0a6.dist-info}/entry_points.txt +0 -0
  55. {dtspark-1.1.0a2.dist-info → dtspark-1.1.0a6.dist-info}/licenses/LICENSE +0 -0
  56. {dtspark-1.1.0a2.dist-info → dtspark-1.1.0a6.dist-info}/top_level.txt +0 -0
@@ -156,26 +156,55 @@ class ActionChangeMonitor:
156
156
  current_ids.add(action_id)
157
157
 
158
158
  if action_id not in self._known_actions:
159
- # New action detected
160
- logger.info(f"New action detected: {action['name']} (ID: {action_id})")
161
- self._known_actions[action_id] = version
162
- if self.on_action_added:
163
- try:
164
- self.on_action_added(action)
165
- except Exception as e:
166
- logger.error(f"Error in on_action_added callback: {e}")
167
-
159
+ self._handle_new_action(action, version)
168
160
  elif self._known_actions[action_id] != version:
169
- # Modified action detected
170
- logger.info(f"Action modified: {action['name']} (ID: {action_id}, v{self._known_actions[action_id]} -> v{version})")
171
- self._known_actions[action_id] = version
172
- if self.on_action_modified:
173
- try:
174
- self.on_action_modified(action)
175
- except Exception as e:
176
- logger.error(f"Error in on_action_modified callback: {e}")
161
+ self._handle_modified_action(action, version)
177
162
 
178
163
  # Check for deleted actions
164
+ self._handle_deleted_actions(current_ids)
165
+
166
+ def _handle_new_action(self, action: Dict, version: int) -> None:
167
+ """
168
+ Process a newly detected action.
169
+
170
+ Args:
171
+ action: The action dictionary from the database
172
+ version: The action's version number
173
+ """
174
+ action_id = action['id']
175
+ logger.info(f"New action detected: {action['name']} (ID: {action_id})")
176
+ self._known_actions[action_id] = version
177
+ if self.on_action_added:
178
+ try:
179
+ self.on_action_added(action)
180
+ except Exception as e:
181
+ logger.error(f"Error in on_action_added callback: {e}")
182
+
183
+ def _handle_modified_action(self, action: Dict, version: int) -> None:
184
+ """
185
+ Process a modified action.
186
+
187
+ Args:
188
+ action: The action dictionary from the database
189
+ version: The action's new version number
190
+ """
191
+ action_id = action['id']
192
+ old_version = self._known_actions[action_id]
193
+ logger.info(f"Action modified: {action['name']} (ID: {action_id}, v{old_version} -> v{version})")
194
+ self._known_actions[action_id] = version
195
+ if self.on_action_modified:
196
+ try:
197
+ self.on_action_modified(action)
198
+ except Exception as e:
199
+ logger.error(f"Error in on_action_modified callback: {e}")
200
+
201
+ def _handle_deleted_actions(self, current_ids: set) -> None:
202
+ """
203
+ Process actions that have been deleted from the database.
204
+
205
+ Args:
206
+ current_ids: Set of action IDs currently present in the database
207
+ """
179
208
  deleted_ids = set(self._known_actions.keys()) - current_ids
180
209
  for action_id in deleted_ids:
181
210
  logger.info(f"Action deleted: ID {action_id}")
@@ -190,44 +190,9 @@ class DaemonApplication(AbstractApp):
190
190
  logger.info("Waiting for shutdown signal (SIGTERM/SIGINT)...")
191
191
  logger.info("=" * 60)
192
192
 
193
- # Set up signal handlers for graceful shutdown
194
- import signal
195
-
196
- def signal_handler(signum, frame):
197
- print(f"\nReceived signal {signum}, initiating shutdown...")
198
- logger.info(f"Received signal {signum}, initiating shutdown...")
199
- self._shutdown_event.set()
200
-
201
- signal.signal(signal.SIGINT, signal_handler)
202
- signal.signal(signal.SIGTERM, signal_handler)
203
- if sys.platform == 'win32':
204
- signal.signal(signal.SIGBREAK, signal_handler)
205
-
206
- # Block until shutdown signal
207
- # Use polling with timeout to allow signal processing on Windows
208
- # Also check for stop signal file (Windows cross-console shutdown)
209
- stop_signal_file = pid_file_path + '.stop'
210
-
211
- while not self._shutdown_event.is_set():
212
- try:
213
- # Check for stop signal file (used by 'daemon stop' on Windows)
214
- if os.path.exists(stop_signal_file):
215
- print("\nStop signal file detected, initiating shutdown...")
216
- logger.info("Stop signal file detected, initiating shutdown...")
217
- # Remove the signal file
218
- try:
219
- os.remove(stop_signal_file)
220
- except Exception:
221
- pass
222
- self._shutdown_event.set()
223
- break
224
-
225
- self._shutdown_event.wait(timeout=1.0)
226
- except KeyboardInterrupt:
227
- print("\nKeyboard interrupt received, initiating shutdown...")
228
- logger.info("Keyboard interrupt received, initiating shutdown...")
229
- self._shutdown_event.set()
230
- break
193
+ # Set up signal handlers and wait for shutdown
194
+ self._setup_signal_handlers()
195
+ self._wait_for_shutdown(pid_file_path)
231
196
 
232
197
  print("Shutdown signal received")
233
198
  logger.info("Shutdown signal received")
@@ -245,6 +210,51 @@ class DaemonApplication(AbstractApp):
245
210
 
246
211
  return 0
247
212
 
213
+ def _setup_signal_handlers(self):
214
+ """Register OS signal handlers for graceful shutdown."""
215
+ import signal
216
+
217
+ def signal_handler(signum, frame):
218
+ print(f"\nReceived signal {signum}, initiating shutdown...")
219
+ logger.info(f"Received signal {signum}, initiating shutdown...")
220
+ self._shutdown_event.set()
221
+
222
+ signal.signal(signal.SIGINT, signal_handler)
223
+ signal.signal(signal.SIGTERM, signal_handler)
224
+ if sys.platform == 'win32':
225
+ signal.signal(signal.SIGBREAK, signal_handler)
226
+
227
+ def _wait_for_shutdown(self, pid_file_path: str):
228
+ """
229
+ Block until a shutdown signal is received.
230
+
231
+ Polls for stop signal file (Windows cross-console shutdown) and
232
+ handles keyboard interrupts.
233
+
234
+ Args:
235
+ pid_file_path: Path to the PID file (used to derive stop signal file path)
236
+ """
237
+ stop_signal_file = pid_file_path + '.stop'
238
+
239
+ while not self._shutdown_event.is_set():
240
+ try:
241
+ if os.path.exists(stop_signal_file):
242
+ print("\nStop signal file detected, initiating shutdown...")
243
+ logger.info("Stop signal file detected, initiating shutdown...")
244
+ try:
245
+ os.remove(stop_signal_file)
246
+ except Exception:
247
+ pass
248
+ self._shutdown_event.set()
249
+ break
250
+
251
+ self._shutdown_event.wait(timeout=1.0)
252
+ except KeyboardInterrupt:
253
+ print("\nKeyboard interrupt received, initiating shutdown...")
254
+ logger.info("Keyboard interrupt received, initiating shutdown...")
255
+ self._shutdown_event.set()
256
+ break
257
+
248
258
  def _initialise_components(self):
249
259
  """Initialise database, LLM manager, and scheduler components."""
250
260
  print(" - Initialising daemon components...")
@@ -288,92 +298,104 @@ class DaemonApplication(AbstractApp):
288
298
 
289
299
  def _configure_llm_providers(self):
290
300
  """Configure LLM providers based on settings."""
291
- # AWS Bedrock
301
+ self._configure_aws_bedrock()
302
+ self._configure_anthropic_direct()
303
+ self._configure_ollama()
304
+
305
+ # Log summary of configured providers
306
+ providers = list(self.llm_manager.providers.keys())
307
+ if providers:
308
+ print(f" - LLM providers configured: {', '.join(providers)}")
309
+ logger.info(f"LLM providers configured: {providers}")
310
+ else:
311
+ print(" - Warning: No LLM providers configured!")
312
+ logger.warning("No LLM providers configured - actions will fail to execute")
313
+
314
+ def _configure_aws_bedrock(self):
315
+ """Configure AWS Bedrock LLM provider if enabled."""
292
316
  aws_enabled = self._get_nested_setting('llm_providers.aws_bedrock.enabled', True)
293
- if aws_enabled:
294
- try:
295
- from dtSpark.llm import BedrockService
296
- from dtSpark.aws.authenticator import AWSAuthenticator
317
+ if not aws_enabled:
318
+ return
319
+
320
+ try:
321
+ from dtSpark.llm import BedrockService
322
+ from dtSpark.aws.authenticator import AWSAuthenticator
297
323
 
298
- aws_region = self._get_nested_setting('llm_providers.aws_bedrock.region', 'us-east-1')
299
- aws_profile = self._get_nested_setting('llm_providers.aws_bedrock.sso_profile', 'default')
300
- request_timeout = self.settings.get('bedrock.request_timeout', 300)
324
+ aws_region = self._get_nested_setting('llm_providers.aws_bedrock.region', 'us-east-1')
325
+ aws_profile = self._get_nested_setting('llm_providers.aws_bedrock.sso_profile', 'default')
326
+ request_timeout = self.settings.get('bedrock.request_timeout', 300)
301
327
 
302
- # Check for API key authentication
303
- aws_access_key_id = self._get_nested_setting('llm_providers.aws_bedrock.access_key_id', None)
304
- aws_secret_access_key = self._get_nested_setting('llm_providers.aws_bedrock.secret_access_key', None)
328
+ aws_access_key_id = self._get_nested_setting('llm_providers.aws_bedrock.access_key_id', None)
329
+ aws_secret_access_key = self._get_nested_setting('llm_providers.aws_bedrock.secret_access_key', None)
305
330
 
306
- authenticator = AWSAuthenticator(
331
+ authenticator = AWSAuthenticator(
332
+ region=aws_region,
333
+ sso_profile=aws_profile,
334
+ access_key_id=aws_access_key_id,
335
+ secret_access_key=aws_secret_access_key
336
+ )
337
+
338
+ if authenticator.authenticate():
339
+ bedrock_service = BedrockService(
340
+ session=authenticator.session,
307
341
  region=aws_region,
308
- sso_profile=aws_profile,
309
- access_key_id=aws_access_key_id,
310
- secret_access_key=aws_secret_access_key
342
+ request_timeout=request_timeout
311
343
  )
344
+ self.llm_manager.register_provider(bedrock_service)
345
+ logger.info("AWS Bedrock provider configured")
312
346
 
313
- if authenticator.authenticate():
314
- bedrock_service = BedrockService(
315
- session=authenticator.session,
316
- region=aws_region,
317
- request_timeout=request_timeout
318
- )
319
- self.llm_manager.register_provider(bedrock_service)
320
- logger.info("AWS Bedrock provider configured")
321
-
322
- except Exception as e:
323
- logger.warning(f"Failed to configure AWS Bedrock: {e}")
347
+ except Exception as e:
348
+ logger.warning(f"Failed to configure AWS Bedrock: {e}")
324
349
 
325
- # Anthropic Direct
350
+ def _configure_anthropic_direct(self):
351
+ """Configure Anthropic Direct LLM provider if enabled."""
326
352
  anthropic_enabled = self._get_nested_setting('llm_providers.anthropic.enabled', False)
327
353
  logger.debug(f"Anthropic Direct enabled: {anthropic_enabled}")
328
- if anthropic_enabled:
329
- try:
330
- from dtSpark.llm import AnthropicService
354
+ if not anthropic_enabled:
355
+ return
331
356
 
332
- api_key = self._get_nested_setting('llm_providers.anthropic.api_key', None)
333
- max_tokens = self.settings.get('bedrock.max_tokens', 8192)
357
+ try:
358
+ from dtSpark.llm import AnthropicService
334
359
 
335
- # Log whether API key was found (don't log the actual key)
336
- if api_key:
337
- logger.info(f"Anthropic API key found (starts with: {api_key[:10] if len(api_key) > 10 else 'SHORT'}...)")
338
- else:
339
- logger.warning("Anthropic API key not found in settings")
360
+ api_key = self._get_nested_setting('llm_providers.anthropic.api_key', None)
361
+ max_tokens = self.settings.get('bedrock.max_tokens', 8192)
340
362
 
341
- anthropic_service = AnthropicService(
342
- api_key=api_key,
343
- default_max_tokens=max_tokens
344
- )
345
- self.llm_manager.register_provider(anthropic_service)
346
- print(f" - Anthropic Direct provider configured")
347
- logger.info("Anthropic Direct provider configured")
363
+ if api_key:
364
+ key_prefix = api_key[:10] if len(api_key) > 10 else 'SHORT'
365
+ logger.info("Anthropic API key found (starts with: %s...)", key_prefix)
366
+ else:
367
+ logger.warning("Anthropic API key not found in settings")
348
368
 
349
- except Exception as e:
350
- print(f" - Warning: Failed to configure Anthropic Direct: {e}")
351
- logger.warning(f"Failed to configure Anthropic Direct: {e}")
369
+ anthropic_service = AnthropicService(
370
+ api_key=api_key,
371
+ default_max_tokens=max_tokens
372
+ )
373
+ self.llm_manager.register_provider(anthropic_service)
374
+ print(" - Anthropic Direct provider configured")
375
+ logger.info("Anthropic Direct provider configured")
376
+
377
+ except Exception as e:
378
+ print(f" - Warning: Failed to configure Anthropic Direct: {e}")
379
+ logger.warning(f"Failed to configure Anthropic Direct: {e}")
352
380
 
353
- # Ollama
381
+ def _configure_ollama(self):
382
+ """Configure Ollama LLM provider if enabled."""
354
383
  ollama_enabled = self._get_nested_setting('llm_providers.ollama.enabled', False)
355
- if ollama_enabled:
356
- try:
357
- from dtSpark.llm import OllamaService
384
+ if not ollama_enabled:
385
+ return
358
386
 
359
- base_url = self._get_nested_setting('llm_providers.ollama.base_url', 'http://localhost:11434')
360
- verify_ssl = self._get_nested_setting('llm_providers.ollama.verify_ssl', True)
387
+ try:
388
+ from dtSpark.llm import OllamaService
361
389
 
362
- ollama_service = OllamaService(base_url=base_url, verify_ssl=verify_ssl)
363
- self.llm_manager.register_provider(ollama_service)
364
- logger.info("Ollama provider configured")
390
+ base_url = self._get_nested_setting('llm_providers.ollama.base_url', 'http://localhost:11434')
391
+ verify_ssl = self._get_nested_setting('llm_providers.ollama.verify_ssl', True)
365
392
 
366
- except Exception as e:
367
- logger.warning(f"Failed to configure Ollama: {e}")
393
+ ollama_service = OllamaService(base_url=base_url, verify_ssl=verify_ssl)
394
+ self.llm_manager.register_provider(ollama_service)
395
+ logger.info("Ollama provider configured")
368
396
 
369
- # Log summary of configured providers
370
- providers = list(self.llm_manager.providers.keys())
371
- if providers:
372
- print(f" - LLM providers configured: {', '.join(providers)}")
373
- logger.info(f"LLM providers configured: {providers}")
374
- else:
375
- print(" - Warning: No LLM providers configured!")
376
- logger.warning("No LLM providers configured - actions will fail to execute")
397
+ except Exception as e:
398
+ logger.warning(f"Failed to configure Ollama: {e}")
377
399
 
378
400
  def _initialise_mcp(self):
379
401
  """Initialise MCP manager if enabled."""
@@ -190,21 +190,39 @@ class DaemonManager:
190
190
  print(f"Stopping daemon (PID: {pid})...")
191
191
 
192
192
  # Send termination signal
193
+ signal_result = self._send_stop_signal(pid)
194
+ if signal_result is not None:
195
+ return signal_result
196
+
197
+ # Wait for graceful shutdown
198
+ for i in range(timeout):
199
+ if not self.pid_file.is_running():
200
+ print("Daemon stopped")
201
+ self._cleanup_stop_file()
202
+ return 0
203
+ time.sleep(1)
204
+ if (i + 1) % 5 == 0:
205
+ print(f"Waiting for shutdown... ({i + 1}/{timeout}s)")
206
+
207
+ # Process didn't stop gracefully - clean up and report
208
+ self._cleanup_stop_file()
209
+ return self._handle_stop_timeout(pid, timeout)
210
+
211
+ def _send_stop_signal(self, pid: int) -> Optional[int]:
212
+ """
213
+ Send a termination signal to the daemon process.
214
+
215
+ Args:
216
+ pid: Process ID of the daemon
217
+
218
+ Returns:
219
+ Exit code if the stop completed immediately (success or failure),
220
+ or None if the caller should wait for shutdown
221
+ """
193
222
  try:
194
223
  if sys.platform == 'win32':
195
- # Windows: Create a stop signal file that the daemon checks
196
- # This is more reliable than signals across console sessions
197
- stop_file = str(self.pid_file.path) + '.stop'
198
- try:
199
- with open(stop_file, 'w') as f:
200
- f.write(str(pid))
201
- print("Stop signal sent")
202
- except Exception as e:
203
- print(f"Failed to create stop signal: {e}")
204
- # Fall back to taskkill /F if signal file fails
205
- subprocess.run(['taskkill', '/F', '/PID', str(pid)], capture_output=True)
224
+ self._send_stop_signal_windows(pid)
206
225
  else:
207
- # Unix: SIGTERM for graceful shutdown
208
226
  os.kill(pid, signal.SIGTERM)
209
227
  except ProcessLookupError:
210
228
  print("Daemon process not found")
@@ -216,19 +234,37 @@ class DaemonManager:
216
234
  except Exception as e:
217
235
  print(f"Error stopping daemon: {e}")
218
236
  return 1
237
+ return None
219
238
 
220
- # Wait for graceful shutdown
221
- for i in range(timeout):
222
- if not self.pid_file.is_running():
223
- print("Daemon stopped")
224
- self._cleanup_stop_file()
225
- return 0
226
- time.sleep(1)
227
- if (i + 1) % 5 == 0:
228
- print(f"Waiting for shutdown... ({i + 1}/{timeout}s)")
239
+ def _send_stop_signal_windows(self, pid: int) -> None:
240
+ """
241
+ Send a stop signal on Windows using a signal file.
229
242
 
230
- # Process didn't stop gracefully - clean up and report
231
- self._cleanup_stop_file()
243
+ Falls back to taskkill if the signal file cannot be created.
244
+
245
+ Args:
246
+ pid: Process ID of the daemon
247
+ """
248
+ stop_file = str(self.pid_file.path) + '.stop'
249
+ try:
250
+ with open(stop_file, 'w') as f:
251
+ f.write(str(pid))
252
+ print("Stop signal sent")
253
+ except Exception as e:
254
+ print(f"Failed to create stop signal: {e}")
255
+ subprocess.run(['taskkill', '/F', '/PID', str(pid)], capture_output=True)
256
+
257
+ def _handle_stop_timeout(self, pid: int, timeout: int) -> int:
258
+ """
259
+ Handle the case where the daemon did not stop within the timeout.
260
+
261
+ Args:
262
+ pid: Process ID of the daemon
263
+ timeout: The timeout that was exceeded
264
+
265
+ Returns:
266
+ 0 if force-terminated successfully, 1 otherwise
267
+ """
232
268
  print(f"Daemon did not stop within {timeout} seconds")
233
269
  if sys.platform == 'win32':
234
270
  print("Forcing termination...")
@@ -108,8 +108,9 @@ class PIDFile:
108
108
  import ctypes
109
109
  kernel32 = ctypes.windll.kernel32
110
110
 
111
- # PROCESS_QUERY_LIMITED_INFORMATION = 0x1000
112
- handle = kernel32.OpenProcess(0x1000, False, pid)
111
+ # Windows constant: PROCESS_QUERY_LIMITED_INFORMATION
112
+ process_query_limited_information = 0x1000
113
+ handle = kernel32.OpenProcess(process_query_limited_information, False, pid)
113
114
  if handle:
114
115
  kernel32.CloseHandle(handle)
115
116
  return True
@@ -1059,6 +1059,9 @@ def register_daemon(
1059
1059
  conn.commit()
1060
1060
  logging.info(f"Daemon re-registered: {daemon_id} (PID: {pid})")
1061
1061
  return True
1062
+ except sqlite3.Error as e:
1063
+ logging.error(f"Failed to register daemon {daemon_id}: {e}")
1064
+ return False
1062
1065
 
1063
1066
 
1064
1067
  def update_daemon_heartbeat(
@@ -31,10 +31,7 @@ def prompt_for_credentials(db_type: str, existing_credentials: Optional[Database
31
31
  db_type_lower = db_type.lower()
32
32
 
33
33
  # Start with existing credentials or create new
34
- if existing_credentials:
35
- creds = existing_credentials
36
- else:
37
- creds = DatabaseCredentials()
34
+ creds = existing_credentials if existing_credentials else DatabaseCredentials()
38
35
 
39
36
  # SQLite only needs path
40
37
  if db_type_lower == 'sqlite':
@@ -46,6 +43,23 @@ def prompt_for_credentials(db_type: str, existing_credentials: Optional[Database
46
43
  return creds
47
44
 
48
45
  # Display information panel
46
+ _display_connection_panel(console, db_type)
47
+
48
+ # Prompt for remote database credentials
49
+ _prompt_remote_credentials(console, creds, db_type_lower)
50
+
51
+ # MSSQL-specific: driver selection
52
+ if db_type_lower in ('mssql', 'sqlserver', 'mssqlserver'):
53
+ _prompt_mssql_driver(console, creds)
54
+
55
+ console.print()
56
+ logging.info(f"Database credentials collected for {db_type}")
57
+
58
+ return creds
59
+
60
+
61
+ def _display_connection_panel(console: Console, db_type: str) -> None:
62
+ """Display the database connection setup information panel."""
49
63
  console.print()
50
64
  console.print(Panel(
51
65
  f"[bold cyan]Database Connection Setup[/bold cyan]\n\n"
@@ -56,78 +70,62 @@ def prompt_for_credentials(db_type: str, existing_credentials: Optional[Database
56
70
  ))
57
71
  console.print()
58
72
 
59
- # Prompt for remote database credentials
73
+
74
+ def _prompt_remote_credentials(console: Console, creds: DatabaseCredentials, db_type_lower: str) -> None:
75
+ """Prompt for remote database connection credentials (host, port, database, username, password, SSL)."""
60
76
  if not creds.host:
61
- creds.host = Prompt.ask(
62
- "Database host",
63
- default="localhost"
64
- )
77
+ creds.host = Prompt.ask("Database host", default="localhost")
65
78
 
66
79
  if not creds.port:
67
80
  default_ports = {
68
- 'mysql': 3306,
69
- 'mariadb': 3306,
70
- 'postgresql': 5432,
71
- 'mssql': 1433,
72
- 'sqlserver': 1433
81
+ 'mysql': 3306, 'mariadb': 3306,
82
+ 'postgresql': 5432, 'mssql': 1433, 'sqlserver': 1433,
73
83
  }
74
84
  default_port = default_ports.get(db_type_lower, 3306)
75
-
76
- port_input = Prompt.ask(
77
- "Database port",
78
- default=str(default_port)
79
- )
85
+ port_input = Prompt.ask("Database port", default=str(default_port))
80
86
  creds.port = int(port_input)
81
87
 
82
88
  if not creds.database:
83
- creds.database = Prompt.ask(
84
- "Database name",
85
- default="dtawsbedrockcli"
86
- )
89
+ creds.database = Prompt.ask("Database name", default="dtawsbedrockcli")
87
90
 
88
91
  if not creds.username:
89
92
  creds.username = Prompt.ask("Database username")
90
93
 
91
94
  if not creds.password:
92
- # Use getpass for secure password input
93
95
  console.print("[cyan]Database password:[/cyan] ", end="")
94
96
  creds.password = getpass.getpass("")
95
97
 
96
- # SSL option
97
- use_ssl = Confirm.ask("Use SSL/TLS connection?", default=False)
98
- creds.ssl = use_ssl
98
+ creds.ssl = Confirm.ask("Use SSL/TLS connection?", default=False)
99
99
 
100
- # MSSQL-specific: driver selection
101
- if db_type_lower in ('mssql', 'sqlserver', 'mssqlserver'):
102
- if not creds.driver:
103
- drivers = [
104
- "ODBC Driver 17 for SQL Server",
105
- "ODBC Driver 18 for SQL Server",
106
- "SQL Server Native Client 11.0",
107
- "Custom"
108
- ]
109
100
 
110
- console.print()
111
- console.print("[bold cyan]Select ODBC driver:[/bold cyan]")
112
- for i, driver in enumerate(drivers, 1):
113
- console.print(f" [{i}] {driver}")
114
-
115
- choice = Prompt.ask(
116
- "Driver selection",
117
- choices=[str(i) for i in range(1, len(drivers) + 1)],
118
- default="1"
119
- )
101
+ def _prompt_mssql_driver(console: Console, creds: DatabaseCredentials) -> None:
102
+ """Prompt for MSSQL ODBC driver selection if not already set."""
103
+ if creds.driver:
104
+ return
120
105
 
121
- choice_idx = int(choice) - 1
122
- if choice_idx == len(drivers) - 1: # Custom
123
- creds.driver = Prompt.ask("Enter custom ODBC driver name")
124
- else:
125
- creds.driver = drivers[choice_idx]
106
+ drivers = [
107
+ "ODBC Driver 17 for SQL Server",
108
+ "ODBC Driver 18 for SQL Server",
109
+ "SQL Server Native Client 11.0",
110
+ "Custom",
111
+ ]
126
112
 
127
113
  console.print()
128
- logging.info(f"Database credentials collected for {db_type}")
129
-
130
- return creds
114
+ console.print("[bold cyan]Select ODBC driver:[/bold cyan]")
115
+ for i, driver in enumerate(drivers, 1):
116
+ console.print(f" [{i}] {driver}")
117
+
118
+ choice = Prompt.ask(
119
+ "Driver selection",
120
+ choices=[str(i) for i in range(1, len(drivers) + 1)],
121
+ default="1",
122
+ )
123
+
124
+ choice_idx = int(choice) - 1
125
+ if choice_idx == len(drivers) - 1: # Custom
126
+ creds.driver = Prompt.ask("Enter custom ODBC driver name")
127
+ else:
128
+ creds.driver = drivers[choice_idx]
131
129
 
132
130
 
133
131
  def test_credentials(db_type: str, credentials: DatabaseCredentials) -> tuple[bool, Optional[str]]:
dtSpark/files/manager.py CHANGED
@@ -306,17 +306,11 @@ class FileManager:
306
306
  if not dir_path.is_dir():
307
307
  raise NotADirectoryError(f"Not a directory: {directory_path}")
308
308
 
309
- supported_files = []
310
-
311
- if recursive:
312
- # Recursively find all files
313
- for file_path in dir_path.rglob('*'):
314
- if file_path.is_file() and cls.is_supported(str(file_path)):
315
- supported_files.append(str(file_path.absolute()))
316
- else:
317
- # Only scan immediate directory
318
- for file_path in dir_path.iterdir():
319
- if file_path.is_file() and cls.is_supported(str(file_path)):
320
- supported_files.append(str(file_path.absolute()))
309
+ iterator = dir_path.rglob('*') if recursive else dir_path.iterdir()
310
+ supported_files = [
311
+ str(file_path.absolute())
312
+ for file_path in iterator
313
+ if file_path.is_file() and cls.is_supported(str(file_path))
314
+ ]
321
315
 
322
316
  return sorted(supported_files)
@@ -2,7 +2,7 @@
2
2
  from .tokens import TokenManager, LimitStatus
3
3
  try:
4
4
  from .costs import CostManager
5
- except:
5
+ except ImportError:
6
6
  CostManager = None
7
7
 
8
8
  __all__ = ['TokenManager', 'LimitStatus']