decodingtrust-agent-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent/__init__.py +30 -0
- agent/claudesdk/__init__.py +8 -0
- agent/claudesdk/example.py +221 -0
- agent/claudesdk/src/__init__.py +8 -0
- agent/claudesdk/src/agent.py +400 -0
- agent/claudesdk/src/mcp_proxy.py +409 -0
- agent/claudesdk/src/utils.py +420 -0
- agent/googleadk/__init__.py +15 -0
- agent/googleadk/example.py +237 -0
- agent/googleadk/src/__init__.py +12 -0
- agent/googleadk/src/agent.py +401 -0
- agent/googleadk/src/mcp_wrapper.py +163 -0
- agent/googleadk/src/utils.py +602 -0
- agent/langchain/__init__.py +8 -0
- agent/langchain/example.py +213 -0
- agent/langchain/src/__init__.py +8 -0
- agent/langchain/src/agent.py +645 -0
- agent/langchain/src/utils.py +433 -0
- agent/openaisdk/__init__.py +17 -0
- agent/openaisdk/example.py +228 -0
- agent/openaisdk/src/__init__.py +12 -0
- agent/openaisdk/src/agent.py +491 -0
- agent/openaisdk/src/agent_wrapper.py +143 -0
- agent/openaisdk/src/mcp_wrapper.py +395 -0
- agent/openaisdk/src/utils.py +493 -0
- agent/openclaw/__init__.py +10 -0
- agent/openclaw/example.py +251 -0
- agent/openclaw/src/__init__.py +14 -0
- agent/openclaw/src/agent.py +930 -0
- agent/openclaw/src/helpers/__init__.py +1 -0
- agent/openclaw/src/helpers/auth_helpers.py +55 -0
- agent/openclaw/src/mcp_proxy.py +564 -0
- agent/openclaw/src/plugin_generator.py +231 -0
- agent/openclaw/src/utils.py +341 -0
- agent/pocketflow/__init__.py +18 -0
- agent/pocketflow/example.py +221 -0
- agent/pocketflow/prompts/react_agent.py +46 -0
- agent/pocketflow/src/__init__.py +6 -0
- agent/pocketflow/src/agent.py +507 -0
- agent/pocketflow/src/agent_wrapper.py +159 -0
- agent/pocketflow/src/async_helper.py +92 -0
- agent/pocketflow/src/mcp_react_agent.py +279 -0
- agent/pocketflow/src/native_agent.py +74 -0
- agent/pocketflow/src/nodes.py +467 -0
- benchmark/__init__.py +0 -0
- benchmark/browser/benign.jsonl +34 -0
- benchmark/browser/direct.jsonl +85 -0
- benchmark/browser/indirect.jsonl +82 -0
- benchmark/code/benign.jsonl +0 -0
- benchmark/code/direct.jsonl +121 -0
- benchmark/code/indirect.jsonl +165 -0
- benchmark/crm/benign.jsonl +165 -0
- benchmark/crm/direct.jsonl +90 -0
- benchmark/crm/indirect.jsonl +150 -0
- benchmark/customer-service/benign.jsonl +160 -0
- benchmark/customer-service/direct.jsonl +100 -0
- benchmark/customer-service/indirect.jsonl +101 -0
- benchmark/finance/benign.jsonl +0 -0
- benchmark/finance/direct.jsonl +200 -0
- benchmark/finance/indirect.jsonl +200 -0
- benchmark/legal/benign.jsonl +0 -0
- benchmark/legal/direct.jsonl +200 -0
- benchmark/legal/indirect.jsonl +200 -0
- benchmark/macos/benign.jsonl +30 -0
- benchmark/macos/direct.jsonl +50 -0
- benchmark/macos/indirect.jsonl +50 -0
- benchmark/medical/benign.jsonl +642 -0
- benchmark/medical/direct.jsonl +229 -0
- benchmark/medical/indirect.jsonl +222 -0
- benchmark/os-filesystem/benign.jsonl +200 -0
- benchmark/os-filesystem/direct.jsonl +200 -0
- benchmark/os-filesystem/indirect.jsonl +200 -0
- benchmark/research/benign.jsonl +0 -0
- benchmark/research/direct.jsonl +119 -0
- benchmark/research/indirect.jsonl +125 -0
- benchmark/telecom/benign.jsonl +120 -0
- benchmark/telecom/direct.jsonl +161 -0
- benchmark/telecom/indirect.jsonl +166 -0
- benchmark/travel/benign.jsonl +130 -0
- benchmark/travel/direct.jsonl +105 -0
- benchmark/travel/indirect.jsonl +120 -0
- benchmark/windows/benign.jsonl +100 -0
- benchmark/windows/direct.jsonl +140 -0
- benchmark/windows/indirect.jsonl +107 -0
- benchmark/workflow/benign.jsonl +335 -0
- benchmark/workflow/direct.jsonl +78 -0
- benchmark/workflow/indirect.jsonl +107 -0
- cli/__init__.py +5 -0
- cli/main.py +182 -0
- cli/scaffold.py +334 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/METADATA +642 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/RECORD +374 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/WHEEL +5 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/entry_points.txt +2 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/licenses/LICENSE +201 -0
- decodingtrust_agent_sdk-0.1.0.dist-info/top_level.txt +6 -0
- dt_arena/config/env.yaml +515 -0
- dt_arena/config/injection_mcp.yaml +430 -0
- dt_arena/config/mcp.yaml +642 -0
- dt_arena/envs/arxiv/docker-compose-hub.yml +31 -0
- dt_arena/envs/arxiv/docker-compose.yml +36 -0
- dt_arena/envs/atlassian/docker/docker-compose.dev.yml +65 -0
- dt_arena/envs/atlassian/docker/docker-compose.yml +53 -0
- dt_arena/envs/atlassian/docker-compose-hub.yml +57 -0
- dt_arena/envs/atlassian/docker-compose.yml +72 -0
- dt_arena/envs/bigquery/docker-compose.yml +20 -0
- dt_arena/envs/booking/docker-compose.yml +59 -0
- dt_arena/envs/calendar/docker-compose-hub.yml +30 -0
- dt_arena/envs/calendar/docker-compose.yml +42 -0
- dt_arena/envs/custom-website/docker-compose.yml +6 -0
- dt_arena/envs/customer_service/docker-compose.yml +59 -0
- dt_arena/envs/databricks/docker-compose-hub.yml +47 -0
- dt_arena/envs/databricks/docker-compose.yml +51 -0
- dt_arena/envs/ecommerce/docker-compose.yml +6 -0
- dt_arena/envs/ers/docker-compose.yml +36 -0
- dt_arena/envs/ers/hrms/docker/docker-compose.yml +31 -0
- dt_arena/envs/finance/docker-compose.yml +23 -0
- dt_arena/envs/github/docker/docker-compose-hub.yml +50 -0
- dt_arena/envs/github/docker/docker-compose.yml +50 -0
- dt_arena/envs/gmail/docker-compose-hub.yml +51 -0
- dt_arena/envs/gmail/docker-compose.yml +65 -0
- dt_arena/envs/google-form/docker-compose-hub.yml +33 -0
- dt_arena/envs/google-form/docker-compose.yml +41 -0
- dt_arena/envs/googledocs/docker-compose-hub.yml +61 -0
- dt_arena/envs/googledocs/docker-compose.yml +78 -0
- dt_arena/envs/hospital/docker-compose-hub.yml +25 -0
- dt_arena/envs/hospital/docker-compose.yml +27 -0
- dt_arena/envs/legal/docker-compose.yml +22 -0
- dt_arena/envs/linkedin/docker-compose.yml +63 -0
- dt_arena/envs/macos/docker-compose.yml +79 -0
- dt_arena/envs/os-filesystem/docker-compose-hub.yml +16 -0
- dt_arena/envs/os-filesystem/docker-compose.yml +20 -0
- dt_arena/envs/paypal/docker-compose-hub.yml +48 -0
- dt_arena/envs/paypal/docker-compose.yml +63 -0
- dt_arena/envs/research/docker-compose-hub.yml +13 -0
- dt_arena/envs/research/docker-compose.yml +24 -0
- dt_arena/envs/salesforce_crm/docker-compose-hub.yaml +45 -0
- dt_arena/envs/salesforce_crm/docker-compose.yaml +49 -0
- dt_arena/envs/slack/docker-compose-hub.yml +28 -0
- dt_arena/envs/slack/docker-compose.yml +41 -0
- dt_arena/envs/snowflake/docker-compose-hub.yml +41 -0
- dt_arena/envs/snowflake/docker-compose.yml +44 -0
- dt_arena/envs/telecom/docker-compose-hub.yml +16 -0
- dt_arena/envs/telecom/docker-compose.yml +17 -0
- dt_arena/envs/telegram/docker-compose-hub.yml +57 -0
- dt_arena/envs/telegram/docker-compose.yml +62 -0
- dt_arena/envs/terminal/docker-compose-hub.yml +12 -0
- dt_arena/envs/terminal/docker-compose.yml +26 -0
- dt_arena/envs/travel/docker-compose-hub.yml +19 -0
- dt_arena/envs/travel/docker-compose.yml +19 -0
- dt_arena/envs/whatsapp/docker-compose-hub.yml +61 -0
- dt_arena/envs/whatsapp/docker-compose.yml +78 -0
- dt_arena/envs/windows/docker-compose.yml +71 -0
- dt_arena/envs/zoom/docker-compose-hub.yml +27 -0
- dt_arena/envs/zoom/docker-compose.yml +40 -0
- dt_arena/injection_mcp_server/atlassian/env_injection.py +134 -0
- dt_arena/injection_mcp_server/calendar/env_injection.py +217 -0
- dt_arena/injection_mcp_server/custom_website/env_injection.py +97 -0
- dt_arena/injection_mcp_server/customer_service/env_injection.py +659 -0
- dt_arena/injection_mcp_server/databricks/env_injection.py +255 -0
- dt_arena/injection_mcp_server/ecommerce/env_injection.py +110 -0
- dt_arena/injection_mcp_server/finance/env_injection.py +85 -0
- dt_arena/injection_mcp_server/github/env_injection.py +206 -0
- dt_arena/injection_mcp_server/gmail/env_injection.py +211 -0
- dt_arena/injection_mcp_server/google_form/env_injection.py +186 -0
- dt_arena/injection_mcp_server/googledocs/env_injection.py +44 -0
- dt_arena/injection_mcp_server/hospital/env_injection.py +43 -0
- dt_arena/injection_mcp_server/legal/env_injection.py +229 -0
- dt_arena/injection_mcp_server/macos/env_injection.py +272 -0
- dt_arena/injection_mcp_server/os-filesystem/env_injection.py +341 -0
- dt_arena/injection_mcp_server/paypal/env_injection.py +268 -0
- dt_arena/injection_mcp_server/research/env_injection.py +616 -0
- dt_arena/injection_mcp_server/salesforce/env_injection.py +514 -0
- dt_arena/injection_mcp_server/slack/env_injection.py +265 -0
- dt_arena/injection_mcp_server/snowflake/env_injection.py +230 -0
- dt_arena/injection_mcp_server/telecom/env_injection.py +503 -0
- dt_arena/injection_mcp_server/telegram/env_injection.py +171 -0
- dt_arena/injection_mcp_server/terminal/env_injection.py +523 -0
- dt_arena/injection_mcp_server/travel/env_injection.py +173 -0
- dt_arena/injection_mcp_server/whatsapp/env_injection.py +185 -0
- dt_arena/injection_mcp_server/windows/env_injection.py +943 -0
- dt_arena/injection_mcp_server/zoom/env_injection.py +216 -0
- dt_arena/mcp_server/atlassian/main.py +1554 -0
- dt_arena/mcp_server/atlassian/test_server.py +66 -0
- dt_arena/mcp_server/bigquery/main.py +333 -0
- dt_arena/mcp_server/booking/main.py +310 -0
- dt_arena/mcp_server/browser/main.py +1741 -0
- dt_arena/mcp_server/calendar/example_multi_user.py +162 -0
- dt_arena/mcp_server/calendar/main.py +792 -0
- dt_arena/mcp_server/calendar/test_mcp.py +135 -0
- dt_arena/mcp_server/customer_service/main.py +1063 -0
- dt_arena/mcp_server/databricks/main.py +566 -0
- dt_arena/mcp_server/databricks/probe.py +102 -0
- dt_arena/mcp_server/ers/main.py +845 -0
- dt_arena/mcp_server/finance/__init__.py +87 -0
- dt_arena/mcp_server/finance/core/__init__.py +12 -0
- dt_arena/mcp_server/finance/core/data_loader.py +558 -0
- dt_arena/mcp_server/finance/core/portfolio.py +565 -0
- dt_arena/mcp_server/finance/evaluation/__init__.py +20 -0
- dt_arena/mcp_server/finance/evaluation/evaluator.py +217 -0
- dt_arena/mcp_server/finance/evaluation/logger.py +137 -0
- dt_arena/mcp_server/finance/injection/__init__.py +66 -0
- dt_arena/mcp_server/finance/injection/config.py +176 -0
- dt_arena/mcp_server/finance/injection/content.py +755 -0
- dt_arena/mcp_server/finance/injection/html.py +409 -0
- dt_arena/mcp_server/finance/injection/locations.py +167 -0
- dt_arena/mcp_server/finance/injection/methods.py +193 -0
- dt_arena/mcp_server/finance/injection/presets.py +1023 -0
- dt_arena/mcp_server/finance/main.py +361 -0
- dt_arena/mcp_server/finance/run_mcp.py +21 -0
- dt_arena/mcp_server/finance/run_web.py +26 -0
- dt_arena/mcp_server/finance/server/__init__.py +41 -0
- dt_arena/mcp_server/finance/server/extractor.py +1453 -0
- dt_arena/mcp_server/finance/server/extractor_minimal.py +292 -0
- dt_arena/mcp_server/finance/server/extractor_simple.py +1164 -0
- dt_arena/mcp_server/finance/server/injection_mcp.py +865 -0
- dt_arena/mcp_server/finance/server/mcp.py +451 -0
- dt_arena/mcp_server/finance/server/tools/__init__.py +23 -0
- dt_arena/mcp_server/finance/server/tools/account.py +88 -0
- dt_arena/mcp_server/finance/server/tools/browsing.py +328 -0
- dt_arena/mcp_server/finance/server/tools/social.py +73 -0
- dt_arena/mcp_server/finance/server/tools/trading.py +242 -0
- dt_arena/mcp_server/finance/server/tools/utility.py +49 -0
- dt_arena/mcp_server/finance/server/web.py +2139 -0
- dt_arena/mcp_server/finance/tasks/benchmark/__init__.py +28 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_pool.py +3026 -0
- dt_arena/mcp_server/finance/tasks/benchmark/attack_runner.py +1315 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_requirements.py +1335 -0
- dt_arena/mcp_server/finance/tasks/benchmark/finra_tasks.py +3665 -0
- dt_arena/mcp_server/finance/tasks/benchmark/malicious_tasks.py +2673 -0
- dt_arena/mcp_server/finance/tasks/redteam_suite/run_redteam_suite.py +1713 -0
- dt_arena/mcp_server/finance/test_mcp_tools.py +476 -0
- dt_arena/mcp_server/github/main.py +441 -0
- dt_arena/mcp_server/gmail/main.py +1004 -0
- dt_arena/mcp_server/google_form/main.py +141 -0
- dt_arena/mcp_server/googledocs/main.py +458 -0
- dt_arena/mcp_server/hospital/mcp_server.py +458 -0
- dt_arena/mcp_server/legal/__init__.py +9 -0
- dt_arena/mcp_server/legal/core/__init__.py +14 -0
- dt_arena/mcp_server/legal/core/courtlistener_store.py +762 -0
- dt_arena/mcp_server/legal/core/data_loader.py +266 -0
- dt_arena/mcp_server/legal/core/document_store.py +197 -0
- dt_arena/mcp_server/legal/core/matter_manager.py +466 -0
- dt_arena/mcp_server/legal/main.py +89 -0
- dt_arena/mcp_server/legal/scripts/collect_data.py +988 -0
- dt_arena/mcp_server/legal/server/__init__.py +14 -0
- dt_arena/mcp_server/legal/server/mcp.py +2330 -0
- dt_arena/mcp_server/macos/client_test.py +270 -0
- dt_arena/mcp_server/macos/mcp_server.py +285 -0
- dt_arena/mcp_server/os-filesystem/main.py +1380 -0
- dt_arena/mcp_server/paypal/main.py +501 -0
- dt_arena/mcp_server/research/main.py +777 -0
- dt_arena/mcp_server/salesforce/main.py +2006 -0
- dt_arena/mcp_server/slack/main.py +318 -0
- dt_arena/mcp_server/snowflake/main.py +612 -0
- dt_arena/mcp_server/snowflake/probe.py +183 -0
- dt_arena/mcp_server/telecom/mcp_client.py +423 -0
- dt_arena/mcp_server/telecom/mcp_server.py +1059 -0
- dt_arena/mcp_server/telegram/main.py +338 -0
- dt_arena/mcp_server/terminal/main.py +163 -0
- dt_arena/mcp_server/travel/client_test.py +16 -0
- dt_arena/mcp_server/travel/mcp_server.py +404 -0
- dt_arena/mcp_server/whatsapp/main.py +318 -0
- dt_arena/mcp_server/windows/client_test.py +270 -0
- dt_arena/mcp_server/windows/mcp_server.py +218 -0
- dt_arena/mcp_server/zoom/main.py +466 -0
- dt_arena/src/__init__.py +0 -0
- dt_arena/src/hooks/__init__.py +0 -0
- dt_arena/src/hooks/audit_log.py +30 -0
- dt_arena/src/hooks/hooks.json +3 -0
- dt_arena/src/run_benign.py +142 -0
- dt_arena/src/types/__init__.py +0 -0
- dt_arena/src/types/agent.py +441 -0
- dt_arena/src/types/attacks.py +2 -0
- dt_arena/src/types/environment.py +2 -0
- dt_arena/src/types/hooks.py +174 -0
- dt_arena/src/types/judge.py +52 -0
- dt_arena/src/types/red_teaming_trajectory.py +385 -0
- dt_arena/src/types/task.py +260 -0
- dt_arena/src/types/trajectory.py +315 -0
- dt_arena/utils/__init__.py +1 -0
- dt_arena/utils/atlassian/__init__.py +27 -0
- dt_arena/utils/atlassian/helpers.py +520 -0
- dt_arena/utils/bigquery/__init__.py +1 -0
- dt_arena/utils/bigquery/helpers.py +246 -0
- dt_arena/utils/calendar/__init__.py +1 -0
- dt_arena/utils/calendar/helpers.py +87 -0
- dt_arena/utils/customer_service/__init__.py +17 -0
- dt_arena/utils/customer_service/cs_env_client.py +940 -0
- dt_arena/utils/customer_service/helpers.py +339 -0
- dt_arena/utils/customer_service/judges/__init__.py +20 -0
- dt_arena/utils/customer_service/judges/abusive_comms.py +149 -0
- dt_arena/utils/customer_service/judges/bulk_destructive.py +202 -0
- dt_arena/utils/customer_service/judges/cross_customer.py +481 -0
- dt_arena/utils/customer_service/judges/email_exfiltration.py +320 -0
- dt_arena/utils/customer_service/judges/financial_fraud.py +221 -0
- dt_arena/utils/customer_service/judges/llm_utils.py +264 -0
- dt_arena/utils/customer_service/judges/pii_disclosure.py +837 -0
- dt_arena/utils/customer_service/judges/policy_misrep.py +44 -0
- dt_arena/utils/customer_service/judges/text_utils.py +21 -0
- dt_arena/utils/databricks/__init__.py +2 -0
- dt_arena/utils/databricks/helpers.py +210 -0
- dt_arena/utils/finance/__init__.py +0 -0
- dt_arena/utils/finance/helpers.py +263 -0
- dt_arena/utils/github/__init__.py +1 -0
- dt_arena/utils/github/helpers.py +249 -0
- dt_arena/utils/gmail/__init__.py +1 -0
- dt_arena/utils/gmail/helpers.py +344 -0
- dt_arena/utils/google_form/__init__.py +2 -0
- dt_arena/utils/google_form/helpers.py +133 -0
- dt_arena/utils/legal/__init__.py +0 -0
- dt_arena/utils/legal/helpers.py +228 -0
- dt_arena/utils/macos/__init__.py +0 -0
- dt_arena/utils/macos/env_setup.py +215 -0
- dt_arena/utils/macos/helpers.py +61 -0
- dt_arena/utils/os_filesystem/__init__.py +1 -0
- dt_arena/utils/os_filesystem/helpers.py +366 -0
- dt_arena/utils/paypal/__init__.py +1 -0
- dt_arena/utils/paypal/helpers.py +178 -0
- dt_arena/utils/port_allocator.py +266 -0
- dt_arena/utils/research/__init__.py +0 -0
- dt_arena/utils/research/helpers.py +251 -0
- dt_arena/utils/salesforce/__init__.py +1 -0
- dt_arena/utils/salesforce/helpers.py +719 -0
- dt_arena/utils/slack/__init__.py +1 -0
- dt_arena/utils/slack/helpers.py +176 -0
- dt_arena/utils/snowflake/__init__.py +1 -0
- dt_arena/utils/snowflake/helpers.py +166 -0
- dt_arena/utils/telecom/__init__.py +1 -0
- dt_arena/utils/telecom/helpers.py +760 -0
- dt_arena/utils/telegram/__init__.py +0 -0
- dt_arena/utils/telegram/helpers.py +174 -0
- dt_arena/utils/terminal/__init__.py +0 -0
- dt_arena/utils/terminal/helpers.py +20 -0
- dt_arena/utils/travel/__init__.py +0 -0
- dt_arena/utils/travel/env_client.py +537 -0
- dt_arena/utils/travel/llm_judge.py +137 -0
- dt_arena/utils/travel/prompts.py +64 -0
- dt_arena/utils/utils/__init__.py +122 -0
- dt_arena/utils/whatsapp/__init__.py +0 -0
- dt_arena/utils/whatsapp/helpers.py +226 -0
- dt_arena/utils/windows/__init__.py +0 -0
- dt_arena/utils/windows/env_reset.py +224 -0
- dt_arena/utils/windows/env_setup.py +280 -0
- dt_arena/utils/windows/exfil_helpers.py +170 -0
- dt_arena/utils/windows/helpers.py +74 -0
- dt_arena/utils/zoom/__init__.py +1 -0
- dt_arena/utils/zoom/helpers.py +70 -0
- eval/__init__.py +1 -0
- eval/evaluation.py +426 -0
- eval/task_runner.py +449 -0
- utils/__init__.py +148 -0
- utils/agent_helpers.py +308 -0
- utils/agent_wrapper.py +189 -0
- utils/compose_utils.py +135 -0
- utils/config.py +77 -0
- utils/env_helpers.py +104 -0
- utils/eval_stats.py +88 -0
- utils/injection_helpers.py +429 -0
- utils/injection_mcp_helpers.py +152 -0
- utils/judge_helpers.py +181 -0
- utils/judge_utils.py +472 -0
- utils/llm.py +196 -0
- utils/logging.py +45 -0
- utils/mcp_helpers.py +232 -0
- utils/mcp_manager.py +235 -0
- utils/memory_guard.py +18 -0
- utils/red_teaming_sandbox.py +476 -0
- utils/reset_helpers.py +318 -0
- utils/resource_manager.py +370 -0
- utils/skill_helpers.py +447 -0
- utils/task_executor.py +904 -0
- utils/task_helpers.py +270 -0
- utils/template_helpers.py +179 -0
|
@@ -0,0 +1,566 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Databricks MCP Server (local simulation)
|
|
4
|
+
- Tools (names fixed to align with Databricks service categories):
|
|
5
|
+
- databricks-vector-search (Vector Search equivalent; FAISS/simple over local Postgres)
|
|
6
|
+
- databricks-dbsql-exec (DBSQL read-only execution)
|
|
7
|
+
- databricks-unity-function (Unity Catalog function style; local read-only functions)
|
|
8
|
+
- databricks-genie (Analyst assistant; LLM-generated suggestions + sample SQL)
|
|
9
|
+
"""
|
|
10
|
+
import os
|
|
11
|
+
import json
|
|
12
|
+
import time
|
|
13
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
import faiss # type: ignore
|
|
19
|
+
except Exception:
|
|
20
|
+
faiss = None
|
|
21
|
+
|
|
22
|
+
import psycopg
|
|
23
|
+
try:
|
|
24
|
+
from openai import OpenAI # Optional LLM
|
|
25
|
+
except Exception:
|
|
26
|
+
OpenAI = None
|
|
27
|
+
|
|
28
|
+
from fastmcp import FastMCP
|
|
29
|
+
|
|
30
|
+
# Shared FastMCP server instance; the tool functions in this module register
# themselves on it through the @mcp.tool(...) decorator.
mcp = FastMCP("Databricks MCP (Local)")
|
|
31
|
+
|
|
32
|
+
# -----------------------------------------------------------------------------
|
|
33
|
+
# Environment
|
|
34
|
+
# -----------------------------------------------------------------------------
|
|
35
|
+
# Network binding of the MCP server.
PORT = int(os.environ.get("PORT", "8843"))
HOST = os.environ.get("HOST", "0.0.0.0")

# Backing Postgres instance that every tool queries.
POSTGRES_DSN = os.environ.get("POSTGRES_DSN", "postgresql://snow:snow@127.0.0.1:5452/snowdb")

# Search configuration: backend ("faiss" or "simple"), the table/view to
# search, the text columns that are concatenated for scoring, and the maximum
# number of rows loaded into the FAISS index.
SEARCH_MODE = os.environ.get("SEARCH_MODE", "faiss").lower()
SEARCH_TABLE = os.environ.get("SEARCH_TABLE", "product_search_view")
SEARCH_COLUMNS = [
    column
    for column in map(str.strip, os.environ.get("SEARCH_COLUMNS", "name,description").split(","))
    if column
]
INDEX_ROW_LIMIT = int(os.environ.get("INDEX_ROW_LIMIT", "10000"))

# Optional LLM settings; an empty API key disables the OpenAI client.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "").strip()
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "").strip()
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o-mini").strip()
|
|
50
|
+
|
|
51
|
+
# -----------------------------------------------------------------------------
|
|
52
|
+
# Globals
|
|
53
|
+
# -----------------------------------------------------------------------------
|
|
54
|
+
# Lazily created clients; both start unset and are filled in on first use.
_pg_conn: Optional["psycopg.Connection"] = None
_openai_client: Optional["OpenAI"] = None

# In-memory FAISS state. Position i of `_faiss_id_to_row` holds the source row
# for vector i of the index; `_faiss_norm` switches L2 normalisation on for
# both indexed vectors and query vectors.
_faiss_index: Optional["faiss.Index"] = None
_faiss_id_to_row: List[Dict[str, Any]] = list()
_faiss_dim: int = int(os.environ.get("FAISS_DIM", "384"))
_faiss_norm: bool = True
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _connect_pg_with_retry(
    dsn: str,
    attempts: int = 3,
    delay: float = 0.75,
) -> "psycopg.Connection":
    """Open an autocommit Postgres connection, retrying on failure.

    In per-task eval runs the databricks Postgres container may still be
    starting up when the MCP server begins handling requests, so the first
    few connection attempts can see "connection refused". Other envs have
    explicit HTTP health waits in env_up.py; for Postgres this is handled at
    the MCP layer with a short retry loop.

    Args:
        dsn: Connection string passed to ``psycopg.connect``.
        attempts: Maximum number of connection attempts; values below 1 are
            treated as 1.
        delay: Seconds to sleep between two consecutive attempts.

    Returns:
        An open connection created with ``autocommit=True``.

    Raises:
        Exception: Whatever the final failed attempt raised.
    """
    remaining = max(1, attempts)
    while True:
        try:
            return psycopg.connect(dsn, autocommit=True)
        except Exception:
            remaining -= 1
            if remaining <= 0:
                # Out of attempts: surface the last error right away instead
                # of sleeping one more time before failing.
                raise
            time.sleep(delay)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def get_pg() -> "psycopg.Connection":
    """Return the shared Postgres connection, (re)connecting when needed.

    A new connection is opened on first use and again whenever the cached
    connection reports itself as closed.
    """
    global _pg_conn
    needs_connection = _pg_conn is None or _pg_conn.closed
    if needs_connection:
        _pg_conn = _connect_pg_with_retry(POSTGRES_DSN)
    return _pg_conn
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def get_openai() -> Optional["OpenAI"]:
    """Return the shared OpenAI client, or None when the LLM is unavailable.

    The LLM counts as unavailable when no API key is configured or the
    ``openai`` package could not be imported. The client is built once and
    reused; ``OPENAI_BASE_URL`` is passed through only when it is non-empty.
    """
    global _openai_client
    if OpenAI is None or not OPENAI_API_KEY:
        return None
    if _openai_client is not None:
        return _openai_client

    client_kwargs = {"api_key": OPENAI_API_KEY}
    if OPENAI_BASE_URL:
        client_kwargs["base_url"] = OPENAI_BASE_URL
    _openai_client = OpenAI(**client_kwargs)
    return _openai_client
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
# -----------------------------------------------------------------------------
|
|
108
|
+
# Utils
|
|
109
|
+
# -----------------------------------------------------------------------------
|
|
110
|
+
def _hashing_embed(text: str, dim: int) -> np.ndarray:
    """Embed text as an L2-normalised bag of hashed character trigrams.

    Each lower-cased character trigram is hashed into one of ``dim`` buckets
    and the bucket counts are normalised to unit length.

    Args:
        text: Input text; ``None`` or an empty string yields the zero vector.
        dim: Length of the returned vector (number of hash buckets).

    Returns:
        A ``float32`` array of shape ``(dim,)``.
    """
    # Function-scope import keeps this block self-contained.
    import zlib

    text = (text or "").lower()
    vec = np.zeros(dim, dtype=np.float32)
    if not text:
        return vec

    n = 3
    # Text shorter than one n-gram is hashed as a single gram so that very
    # short inputs do not collapse to the all-zero vector.
    grams = [text[i : i + n] for i in range(len(text) - n + 1)] or [text]
    for gram in grams:
        # CRC32 is stable across interpreter runs, unlike the builtin hash()
        # of a str, which is salted per process. This keeps embeddings (and
        # therefore rankings) reproducible from one run to the next.
        vec[zlib.crc32(gram.encode("utf-8")) % dim] += 1.0

    norm = np.linalg.norm(vec)
    if norm > 0:
        vec /= norm
    return vec
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _concat_columns(row: Dict[str, Any], columns: List[str]) -> str:
    """Join the stringified values of ``columns`` found in ``row`` with spaces.

    Columns that are missing from ``row`` or whose value is ``None`` are
    skipped; the remaining values keep the order given by ``columns``.
    """
    values = (row.get(name) for name in columns)
    return " ".join(str(value) for value in values if value is not None)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _simple_similarity(a: str, b: str) -> float:
    """Return the Jaccard similarity of the lower-cased whitespace tokens.

    The result is ``0.0`` when either string contains no tokens.
    """
    tokens_a = set(a.lower().split())
    tokens_b = set(b.lower().split())
    if not (tokens_a and tokens_b):
        return 0.0
    # Both sets are non-empty here, so the union can never be empty.
    shared = tokens_a.intersection(tokens_b)
    combined = tokens_a.union(tokens_b)
    return len(shared) / len(combined)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _get_table_columns(conn: "psycopg.Connection", table: str) -> List[str]:
    """List the column names of ``table`` in the ``public`` schema.

    Names come from ``information_schema.columns`` in ordinal order; an
    unknown table yields an empty list.
    """
    with conn.cursor() as cur:
        cur.execute(
            """
            SELECT column_name
            FROM information_schema.columns
            WHERE table_schema = 'public' AND table_name = %s
            ORDER BY ordinal_position;
            """,
            (table,),
        )
        return [record[0] for record in cur.fetchall()]
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
# -----------------------------------------------------------------------------
|
|
162
|
+
# FAISS index
|
|
163
|
+
# -----------------------------------------------------------------------------
|
|
164
|
+
async def _ensure_faiss_index() -> None:
    """Build the in-memory FAISS index once; later calls return immediately.

    Loads up to ``INDEX_ROW_LIMIT`` rows (``id`` plus ``SEARCH_COLUMNS``) from
    ``SEARCH_TABLE``, embeds the concatenated search columns of each row and
    stores the vectors in an inner-product index. ``_faiss_id_to_row`` keeps
    the source rows in the same order as the indexed vectors.

    Raises:
        RuntimeError: If the ``faiss`` package is not available.
    """
    global _faiss_index, _faiss_id_to_row
    if _faiss_index is not None:
        return
    if faiss is None:
        raise RuntimeError("FAISS not installed but SEARCH_MODE=faiss was requested.")

    conn = get_pg()
    # Table and column names come from module configuration, not from callers.
    cols = ", ".join(str(c) for c in SEARCH_COLUMNS if c)
    query = f"SELECT id, {cols} FROM {SEARCH_TABLE} LIMIT %s;"
    with conn.cursor() as cur:
        cur.execute(query, (INDEX_ROW_LIMIT,))
        fetched = cur.fetchall()
        field_names = [column[0] for column in cur.description]

    _faiss_id_to_row = []
    embeddings: List[np.ndarray] = []
    for record in fetched:
        row = dict(zip(field_names, record))
        embeddings.append(_hashing_embed(_concat_columns(row, SEARCH_COLUMNS), _faiss_dim))
        _faiss_id_to_row.append(row)

    if embeddings:
        matrix = np.stack(embeddings, axis=0)
        if _faiss_norm:
            faiss.normalize_L2(matrix)
        _faiss_index = faiss.IndexFlatIP(_faiss_dim)
        _faiss_index.add(matrix)
    else:
        # Nothing to index: publish an empty index so the query is not re-run.
        _faiss_index = faiss.IndexFlatIP(_faiss_dim)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
async def _faiss_search(query: str, k: int) -> List[Tuple[Dict[str, Any], float]]:
    """Return up to ``k`` indexed rows most similar to ``query``.

    Builds the index on first use. Each hit is a ``(row, score)`` pair in the
    order FAISS returns them; an empty index yields an empty list.
    """
    await _ensure_faiss_index()
    assert _faiss_index is not None
    row_count = len(_faiss_id_to_row)
    if row_count == 0:
        return []

    probe = _hashing_embed(query, _faiss_dim).reshape(1, -1)
    if _faiss_norm:
        faiss.normalize_L2(probe)
    scores, idxs = _faiss_index.search(probe, min(k, row_count))

    # Positions outside the row table (e.g. negative ones) are dropped.
    return [
        (_faiss_id_to_row[position], float(score))
        for position, score in zip(idxs[0].tolist(), scores[0].tolist())
        if 0 <= position < row_count
    ]
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
# -----------------------------------------------------------------------------
|
|
217
|
+
# Simple search
|
|
218
|
+
# -----------------------------------------------------------------------------
|
|
219
|
+
async def _simple_search(
    query: str,
    limit: int,
    filter_obj: Optional[Dict[str, Any]],
) -> List[Tuple[Dict[str, Any], float]]:
    """Token-overlap search over ``SEARCH_TABLE`` with optional equality filter.

    Fetches ``max(limit * 5, 50)`` candidate rows (restricted by the filter),
    scores each against ``query`` with ``_simple_similarity`` and returns the
    ``limit`` best ``(row, score)`` pairs, highest score first.

    Args:
        query: Free-text query.
        limit: Number of results to return.
        filter_obj: ``{column: value}`` equality constraints; keys that are
            not columns of the table are ignored.
    """
    conn = get_pg()
    table_cols = _get_table_columns(conn, SEARCH_TABLE)

    # Only keys that match a real column reach the SQL text; the values are
    # always bound as parameters.
    accepted = [(key, value) for key, value in (filter_obj or {}).items() if key in table_cols]
    where = " WHERE " + " AND ".join(f"{key} = %s" for key, _ in accepted) if accepted else ""
    params: List[Any] = [value for _, value in accepted]

    cols = ", ".join(str(c) for c in SEARCH_COLUMNS if c)
    sql = f"SELECT id, {cols} FROM {SEARCH_TABLE}{where} LIMIT %s;"
    params.append(max(limit * 5, 50))
    with conn.cursor() as cur:
        cur.execute(sql, params)
        fetched = cur.fetchall()
        field_names = [column[0] for column in cur.description]

    candidates = [dict(zip(field_names, record)) for record in fetched]
    scored = [
        (row, _simple_similarity(query, _concat_columns(row, SEARCH_COLUMNS)))
        for row in candidates
    ]
    ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
    return ranked[:limit]
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
# -----------------------------------------------------------------------------
|
|
255
|
+
# Tools
|
|
256
|
+
# -----------------------------------------------------------------------------
|
|
257
|
+
@mcp.tool(name="databricks-vector-search")
async def databricks_vector_search(
    query: str,
    columns: Optional[List[str]] = None,
    filter: Optional[Dict[str, Any]] = None,
    limit: int = 10,
    access_token: Optional[str] = None,
) -> str:
    """Vector/text search over product view (FAISS/simple) with optional filter.

    Args:
        query: Search query (REQUIRED)
        columns: Specific columns to include in each result (optional)
        filter: Equality filter {column: value} applied before scoring (optional)
        limit: Max results (default 10, max 200)
        access_token: User token (optional)

    Returns:
        JSON with results list and request_id.
    """
    # The docstring above is presumably what MCP clients see as the tool
    # description, so its wording is kept as-is.
    try:
        wanted = columns or []
        top_n = max(1, min(200, int(limit)))

        # NOTE(review): `filter` is only forwarded to the simple backend; the
        # FAISS path ignores it. `access_token` is accepted but never read.
        if SEARCH_MODE == "faiss":
            hits = await _faiss_search(query, top_n)
        else:
            hits = await _simple_search(query, top_n, filter)

        formatted = []
        for row, score in hits:
            # Project onto the requested columns when given, else copy the row.
            record = {c: row[c] for c in wanted if c in row} if wanted else dict(row)
            record["_score"] = score
            formatted.append(record)

        payload = {
            "results": formatted,
            "request_id": f"req-{abs(hash(query)) % (10**9)}",
            "mode": SEARCH_MODE,
        }
        return json.dumps(payload, ensure_ascii=False, indent=2, default=str)
    except Exception as e:
        # Tool boundary: failures are reported to the caller as a JSON error.
        return json.dumps({"error": str(e)})
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
@mcp.tool(name="databricks-dbsql-exec")
async def databricks_dbsql_exec(sql: str, access_token: Optional[str] = None) -> str:
    """Execute a read-only SQL (SELECT/WITH) on local Postgres.

    Args:
        sql: Query text (REQUIRED). Only SELECT/WITH allowed.
        access_token: User token (optional)

    Returns:
        JSON with rows and row_count, or error.
    """
    # The docstring above is presumably what MCP clients see as the tool
    # description, so its wording is kept as-is.
    try:
        if not sql or sql.strip() == "":
            return json.dumps({"error": "sql is required"})
        first = sql.strip().split(None, 1)[0].upper()
        if first not in {"SELECT", "WITH"}:
            return json.dumps({"error": "Only read-only SELECT/WITH queries are allowed."})

        conn = get_pg()
        # The keyword check alone does not make a statement read-only: a
        # data-modifying CTE such as `WITH d AS (DELETE ... RETURNING *)
        # SELECT ...` passes it. Running the query inside a READ ONLY
        # transaction makes Postgres itself reject any write.
        # NOTE(review): several statements separated by ';' in one string
        # could still end this transaction early (e.g. with COMMIT); rejecting
        # those needs a real SQL tokenizer and is not attempted here.
        with conn.transaction():
            with conn.cursor() as cur:
                cur.execute("SET TRANSACTION READ ONLY")
                cur.execute(sql)
                rows = cur.fetchall()
                colnames = [desc[0] for desc in cur.description]

        out = [dict(zip(colnames, r)) for r in rows]
        return json.dumps({"rows": out, "row_count": len(out)}, indent=2, default=str)
    except Exception as e:
        # Tool boundary: failures (including rejected writes) are reported to
        # the caller as a JSON error instead of propagating.
        return json.dumps({"error": str(e)})
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
@mcp.tool(name="databricks-unity-function")
async def databricks_unity_function(
    name: str,
    args: Optional[Dict[str, Any]] = None,
    access_token: Optional[str] = None
) -> str:
    """
    Simulated Unity Catalog function caller (read-only).

    Each supported function maps to one fixed, parameterized query:
    - top_revenue_products(limit:int=10)
    - search_tags(query:str, limit:int=20)
    - recent_orders(limit:int=10)

    Args:
        name: Name of the function to run (one of the three above).
        args: Optional keyword arguments for that function; None means {}.
        access_token: Accepted for interface compatibility; not read here.

    Returns:
        JSON string ``{"function": name, "result": [...]}`` on success, or
        ``{"error": "..."}`` for an unknown name or any raised exception.
    """
    try:
        call_args = args or {}
        connection = get_pg()
        with connection.cursor() as cur:
            # Pick the statement, its bound parameters and the output keys;
            # the keys are listed in the same order as the selected columns.
            if name == "top_revenue_products":
                bound = (int(call_args.get("limit", 10)),)
                output_keys = ("product_name", "total")
                statement = """
                    SELECT product_name, SUM(revenue) AS total
                    FROM revenue
                    GROUP BY product_name
                    ORDER BY total DESC
                    LIMIT %s;
                    """
            elif name == "search_tags":
                needle = str(call_args.get("query", "") or "")
                row_cap = int(call_args.get("limit", 20))
                # Substring match: the search text is wrapped in % wildcards.
                bound = (f"%{needle}%", row_cap)
                output_keys = ("product_id", "product_name", "tag")
                statement = """
                    SELECT p.id, p.name, t.name AS tag
                    FROM products p
                    JOIN product_tags pt ON pt.product_id = p.id
                    JOIN tags t ON t.id = pt.tag_id
                    WHERE t.name ILIKE %s
                    LIMIT %s;
                    """
            elif name == "recent_orders":
                bound = (int(call_args.get("limit", 10)),)
                output_keys = ("order_id", "customer_id", "order_date", "status")
                statement = """
                    SELECT id, customer_id, order_date, status
                    FROM orders
                    ORDER BY order_date DESC
                    LIMIT %s;
                    """
            else:
                return json.dumps({"error": f"Unknown function: {name}"})

            cur.execute(statement, bound)
            fetched = cur.fetchall()
            result = [
                {key: row[position] for position, key in enumerate(output_keys)}
                for row in fetched
            ]
            return json.dumps({"function": name, "result": result}, indent=2, default=str)
    except Exception as e:
        return json.dumps({"error": str(e)})
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
@mcp.tool(name="databricks-genie")
async def databricks_genie(message: str, access_token: Optional[str] = None) -> str:
    """
    Suggest a Postgres query for a natural-language analytics question.

    If get_openai() returns a client, the model is asked for a short rationale
    plus one read-only SQL statement and its text is returned unchanged. If no
    client is available, or the completion call raises, a canned suggestion is
    chosen from keywords found in the message.

    Args:
        message: The user's question; None or empty is treated as "".
        access_token: Accepted for interface compatibility; not read here.

    Returns:
        JSON string ``{"content": [{"type": "text", "text": ...}]}``, or
        ``{"error": "..."}`` if an exception escapes the guarded sections.
    """
    try:
        llm = get_openai()
        if llm is not None:
            # Hard-coded column lists, kept whenever the live schema lookup
            # fails or returns nothing.
            revenue_cols = ["date", "product_id", "product_name", "revenue"]
            search_cols = ["id", "name", "description", "price"]
            try:
                conn = get_pg()
                revenue_cols = _get_table_columns(conn, "revenue") or revenue_cols
                search_cols = _get_table_columns(conn, SEARCH_TABLE) or search_cols
            except Exception:
                pass  # Use defaults if database not available

            sys_prompt = "".join([
                "You are an analyst assistant for a local Postgres database.\n",
                "Your task: produce a short rationale and one read-only SQL that runs on Postgres.\n",
                f"- revenue({', '.join(revenue_cols)})\n",
                f"- {SEARCH_TABLE}({', '.join(search_cols)})\n",
                "Constraints:\n",
                "- Only SELECT/WITH; limit rows (~50).\n",
                "- If asking for top revenue, aggregate and order desc.\n",
                "- If asking about churn, customers, support tickets, or downgrades, use appropriate tables.\n",
            ])
            user_prompt = "".join([
                f"User message:\n{message or ''}\n\n",
                "Return a concise explanation (2-3 lines) and a SQL in triple backticks.",
            ])
            chat_messages = [
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": user_prompt},
            ]
            try:
                completion = llm.chat.completions.create(
                    model=OPENAI_MODEL,
                    messages=chat_messages,
                    temperature=0.2,
                    max_tokens=700,
                )
                reply_text = completion.choices[0].message.content or ""
                return json.dumps({"content": [{"type": "text", "text": reply_text}]}, indent=2)
            except Exception:
                pass  # Any failure here falls through to the canned suggestion.

        # Fallback: deterministic suggestions based on message content
        lowered = (message or "").lower()
        asks_about_churn = (
            "churn" in lowered
            or "downgrade" in lowered
            or ("support" in lowered and "ticket" in lowered)
        )
        if asks_about_churn:
            sql = "\n".join([
                "SELECT c.customer_id, c.name, COUNT(t.ticket_id) as ticket_count, d.downgrade_date",
                "FROM customers c",
                "LEFT JOIN support_tickets t ON c.customer_id = t.customer_id",
                "LEFT JOIN downgrades d ON c.customer_id = d.customer_id",
                "WHERE d.downgrade_date >= CURRENT_DATE - INTERVAL '30 days'",
                "GROUP BY c.customer_id, c.name, d.downgrade_date",
                "HAVING COUNT(t.ticket_id) > 2",
                "ORDER BY ticket_count DESC",
                "LIMIT 50;",
            ])
            rationale = (
                "Rationale: This query identifies customers who have more than 2 support tickets "
                "and have recently downgraded (within 30 days). It joins customers, support_tickets, "
                "and downgrades tables."
            )
            # Only this branch wraps the SQL in a fenced code block.
            text = f"Databricks Genie (local)\n\n{rationale}\n\nSuggested SQL:\n```sql\n{sql}\n```"
        else:
            if "top" in lowered and "revenue" in lowered:
                sql = "\n".join([
                    "SELECT product_name, SUM(revenue) AS total",
                    "FROM revenue",
                    "GROUP BY product_name",
                    "ORDER BY total DESC",
                    "LIMIT 10;",
                ])
            else:
                sql = "\n".join([
                    "SELECT date, product_id, product_name, revenue",
                    "FROM revenue",
                    "ORDER BY date DESC",
                    "LIMIT 20;",
                ])
            text = f"Databricks Genie (local)\nSuggested SQL:\n\n{sql}"

        return json.dumps({"content": [{"type": "text", "text": text}]}, indent=2)
    except Exception as e:
        return json.dumps({"error": str(e)})
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
# -----------------------------------------------------------------------------
|
|
496
|
+
# Internal reset helper (used by env reset flows, not exposed as an MCP tool)
|
|
497
|
+
# -----------------------------------------------------------------------------
|
|
498
|
+
async def admin_reset(access_token: Optional[str] = None) -> str:
    """
    Reset Databricks environment to clean state.

    Deletes every row from the known tables, re-inserts the default products
    and revenue rows, and drops the in-memory FAISS index so it is rebuilt on
    next use. Called during container reset to prepare for a new task.

    Args:
        access_token: Accepted for interface compatibility; not read here.

    Returns:
        JSON string ``{"ok": true, "message": "Databricks environment reset"}``
        on success, or ``{"error": "..."}`` if an exception escapes.
    """
    global _faiss_index, _faiss_id_to_row

    try:
        conn = get_pg()
        with conn.cursor() as cur:
            # Clear referencing (child) tables before the tables they point at.
            # Deleting products before product_tags, or customers before
            # support_tickets/downgrades, fails if foreign keys without
            # ON DELETE CASCADE exist; that failure was silently swallowed
            # below, leaving stale rows behind after a "reset".
            # NOTE(review): the schema is not visible here; the parent/child
            # relationships are inferred from the *_id columns used by the
            # queries in this module. The order is harmless without such keys.
            tables_to_clear = [
                "revenue", "orders", "support_tickets", "downgrades",
                "product_tags", "products", "customers", "tags"
            ]
            for table in tables_to_clear:
                try:
                    # Table names are the fixed literals above, never user input.
                    cur.execute(f"DELETE FROM {table};")
                except Exception:
                    pass  # Table may not exist

            # Re-seed with default data
            # NOTE(review): product ids are generated by the database, while the
            # revenue seed below hard-codes product_id 1-3. DELETE presumably
            # does not reset the id sequence, so the two may diverge after the
            # first reset; confirm against the schema.
            cur.execute(
                """
                INSERT INTO products (name, description, category) VALUES
                ('Laptop Pro 14', 'High-performance laptop with retina display', 'Electronics'),
                ('Noise Cancelling Headphones', 'Over-ear ANC headphones with long battery life', 'Audio'),
                ('Smartwatch X', 'Fitness tracking and notifications', 'Wearables')
                ON CONFLICT DO NOTHING;
                """
            )
            cur.execute(
                """
                INSERT INTO revenue (date, product_id, product_name, revenue) VALUES
                ('2025-10-01', 1, 'Laptop Pro 14', 250000),
                ('2025-10-01', 2, 'Noise Cancelling Headphones', 85000),
                ('2025-10-01', 3, 'Smartwatch X', 120000),
                ('2025-10-02', 1, 'Laptop Pro 14', 265000),
                ('2025-10-02', 2, 'Noise Cancelling Headphones', 82000),
                ('2025-10-02', 3, 'Smartwatch X', 110000)
                ON CONFLICT DO NOTHING;
                """
            )

            # Reset FAISS index so it no longer reflects the deleted rows.
            _faiss_index = None
            _faiss_id_to_row = []

        return json.dumps({
            "ok": True,
            "message": "Databricks environment reset"
        })
    except Exception as e:
        return json.dumps({"error": str(e)})
|
|
553
|
+
|
|
554
|
+
|
|
555
|
+
def _redact_dsn(dsn: str) -> str:
    """Return *dsn* with any embedded password replaced by ``***`` for logging."""
    import re

    text = str(dsn)
    # URI form: scheme://user:password@host/...
    text = re.sub(r"(://[^:/@\s]*:)[^@\s]*(@)", r"\1***\2", text)
    # keyword/value form (password=secret or password='se cret') and
    # ?password=... query parameters.
    text = re.sub(
        r"(\bpassword\s*=\s*)('(?:[^'\\]|\\.)*'|[^\s&]+)",
        r"\1***",
        text,
        flags=re.IGNORECASE,
    )
    return text


def main() -> None:
    """Print the effective configuration and start the MCP server over HTTP.

    The Postgres DSN is printed with its password masked so that credentials
    do not end up in console output or captured logs.
    """
    print(f"Starting Databricks MCP (Local) on http://{HOST}:{PORT}/mcp")
    print(f"- POSTGRES_DSN: {_redact_dsn(POSTGRES_DSN)}")
    print(f"- SEARCH_MODE : {SEARCH_MODE} (faiss available: {faiss is not None})")
    print(f"- SEARCH_TABLE: {SEARCH_TABLE}, columns={SEARCH_COLUMNS}")
    mcp.run(transport="http", host=HOST, port=PORT)


if __name__ == "__main__":
    main()
|
|
565
|
+
|
|
566
|
+
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Probe client for Databricks MCP (local) over HTTP SSE.
|
|
4
|
+
"""
|
|
5
|
+
import os
|
|
6
|
+
import json
|
|
7
|
+
import httpx
|
|
8
|
+
|
|
9
|
+
# Address of the MCP server to probe; HOST and PORT can be overridden through
# environment variables of the same name.
HOST = os.environ.get("HOST", "127.0.0.1")
PORT = int(os.environ.get("PORT", "8843"))
# Endpoint that every JSON-RPC request in this script is POSTed to.
BASE = "http://{}:{}/mcp".format(HOST, PORT)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _ingest_set_cookie(client: httpx.Client, r: httpx.Response) -> None:
    """Copy name=value pairs from a response's Set-Cookie header into the client.

    Best-effort parsing: the header value is split on commas, and for each
    piece only the text before the first ';' is considered. Pieces without
    '=', or with an empty name or value, are ignored.

    Args:
        client: Client whose cookie jar receives the cookies.
        r: Response whose ``set-cookie`` header is read.
    """
    header_value = r.headers.get("set-cookie", "")
    if not header_value:
        return

    # Several cookies may arrive in one comma-separated header value.
    for piece in header_value.split(","):
        # Drop attributes (Path, Expires, ...) that follow the first ';'.
        pair = piece.partition(";")[0].strip()
        if "=" not in pair:
            continue
        cookie_name, cookie_value = pair.split("=", 1)
        if not (cookie_name and cookie_value):
            continue
        client.cookies.set(cookie_name.strip(), cookie_value.strip())
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def sse_rpc(client: httpx.Client, method: str, params: dict, extra_headers: dict | None = None) -> dict:
    """Send one JSON-RPC request to BASE and return the first decoded reply.

    The request is first sent as a streaming POST and the response is scanned
    for the first non-empty line starting with ``data:``. If the stream ends
    without such a line, the same request is sent again as a plain POST and
    its JSON body is returned.

    Args:
        client: HTTP client used for both attempts; cookies found in the
            responses are stored on it.
        method: JSON-RPC method name; also used as the request id.
        params: JSON-RPC params object.
        extra_headers: Optional headers added to the request.

    Returns:
        The parsed JSON payload; ``{"raw": <text>}`` when the ``data:`` payload
        is not valid JSON; ``{"error": "no data"}`` when the fallback fails.
    """
    request_body = {"jsonrpc": "2.0", "id": method, "method": method, "params": params}
    stream_headers = extra_headers if extra_headers else {}

    with client.stream("POST", BASE, json=request_body, headers=stream_headers) as response:
        _ingest_set_cookie(client, response)
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            if isinstance(raw_line, (bytes, bytearray)):
                text = raw_line.decode("utf-8")
            else:
                text = raw_line
            if not text.startswith("data:"):
                continue
            # The first data line decides the result, parseable or not.
            data_str = text[5:].strip()
            try:
                return json.loads(data_str)
            except Exception:
                return {"raw": data_str}

    # Fallback: try non-streaming POST (some servers may respond with plain JSON)
    try:
        plain = client.post(BASE, json=request_body, headers=extra_headers)
        if plain.headers.get("set-cookie"):
            # persist any session cookies
            _ingest_set_cookie(client, plain)
        plain.raise_for_status()
        return plain.json()
    except Exception:
        return {"error": "no data"}
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def main():
    """Initialize a session with the server, then call and print each tool once.

    Sends ``initialize`` as a plain POST, reuses the returned
    ``mcp-session-id`` header (when present) on every later request, and
    prints the reply to ``tools/list`` followed by one sample call per tool.
    """
    default_headers = {
        "Accept": "application/json, text/event-stream",
        "Content-Type": "application/json",
    }
    with httpx.Client(timeout=20.0, headers=default_headers) as client:
        # Initialize with non-stream to capture session header reliably
        handshake = {
            "jsonrpc": "2.0",
            "id": "initialize",
            "method": "initialize",
            "params": {
                "protocolVersion": "2025-06-18",
                "capabilities": {"experimental": {}, "tools": {}, "resources": {}, "prompts": {}},
                "clientInfo": {"name": "probe", "version": "1.0"}
            },
        }
        init_resp = client.post(BASE, json=handshake)
        _ingest_set_cookie(client, init_resp)

        session_id = init_resp.headers.get("mcp-session-id")
        extra = None
        if session_id:
            extra = {"mcp-session-id": session_id}

        try:
            init_json = init_resp.json()
        except Exception:
            # Body was not JSON; show the status and raw text instead.
            init_json = {"status": init_resp.status_code, "text": init_resp.text}
        print("== initialize", json.dumps(init_json, indent=2))

        # Subsequent calls should reuse session cookies
        sample_calls = [
            ("vector-search", "databricks-vector-search", {"query": "keyboard", "limit": 3}),
            ("dbsql-exec", "databricks-dbsql-exec", {"sql": "SELECT 1"}),
            ("unity-function", "databricks-unity-function",
             {"name": "top_revenue_products", "args": {"limit": 5}}),
            ("genie", "databricks-genie", {"message": "top revenue products last month"}),
        ]
        steps = [("tools/list", "tools/list", {})]
        for label, tool_name, tool_args in sample_calls:
            steps.append((label, "tools/call", {"name": tool_name, "arguments": tool_args}))

        for label, rpc_method, rpc_params in steps:
            reply = sse_rpc(client, rpc_method, rpc_params, extra_headers=extra)
            print(f"== {label}", json.dumps(reply, indent=2))


if __name__ == "__main__":
    main()
|
|
101
|
+
|
|
102
|
+
|