crca 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- .github/ISSUE_TEMPLATE/bug_report.md +65 -0
- .github/ISSUE_TEMPLATE/feature_request.md +41 -0
- .github/PULL_REQUEST_TEMPLATE.md +20 -0
- .github/workflows/publish-manual.yml +61 -0
- .github/workflows/publish.yml +64 -0
- .gitignore +214 -0
- CRCA.py +4156 -0
- LICENSE +201 -0
- MANIFEST.in +43 -0
- PKG-INFO +5035 -0
- README.md +4959 -0
- __init__.py +17 -0
- branches/CRCA-Q.py +2728 -0
- branches/crca_cg/corposwarm.py +9065 -0
- branches/crca_cg/fix_rancher_docker_creds.ps1 +155 -0
- branches/crca_cg/package.json +5 -0
- branches/crca_cg/test_bolt_integration.py +446 -0
- branches/crca_cg/test_corposwarm_comprehensive.py +773 -0
- branches/crca_cg/test_new_features.py +163 -0
- branches/crca_sd/__init__.py +149 -0
- branches/crca_sd/crca_sd_core.py +770 -0
- branches/crca_sd/crca_sd_governance.py +1325 -0
- branches/crca_sd/crca_sd_mpc.py +1130 -0
- branches/crca_sd/crca_sd_realtime.py +1844 -0
- branches/crca_sd/crca_sd_tui.py +1133 -0
- crca-1.4.0.dist-info/METADATA +5035 -0
- crca-1.4.0.dist-info/RECORD +501 -0
- crca-1.4.0.dist-info/WHEEL +4 -0
- crca-1.4.0.dist-info/licenses/LICENSE +201 -0
- docs/CRCA-Q.md +2333 -0
- examples/config.yaml.example +25 -0
- examples/crca_sd_example.py +513 -0
- examples/data_broker_example.py +294 -0
- examples/logistics_corporation.py +861 -0
- examples/palantir_example.py +299 -0
- examples/policy_bench.py +934 -0
- examples/pridnestrovia-sd.py +705 -0
- examples/pridnestrovia_realtime.py +1902 -0
- prompts/__init__.py +10 -0
- prompts/default_crca.py +101 -0
- pyproject.toml +151 -0
- requirements.txt +76 -0
- schemas/__init__.py +43 -0
- schemas/mcpSchemas.py +51 -0
- schemas/policy.py +458 -0
- templates/__init__.py +38 -0
- templates/base_specialized_agent.py +195 -0
- templates/drift_detection.py +325 -0
- templates/examples/causal_agent_template.py +309 -0
- templates/examples/drag_drop_example.py +213 -0
- templates/examples/logistics_agent_template.py +207 -0
- templates/examples/trading_agent_template.py +206 -0
- templates/feature_mixins.py +253 -0
- templates/graph_management.py +442 -0
- templates/llm_integration.py +194 -0
- templates/module_registry.py +276 -0
- templates/mpc_planner.py +280 -0
- templates/policy_loop.py +1168 -0
- templates/prediction_framework.py +448 -0
- templates/statistical_methods.py +778 -0
- tests/sanity.yml +31 -0
- tests/sanity_check +406 -0
- tests/test_core.py +47 -0
- tests/test_crca_excel.py +166 -0
- tests/test_crca_sd.py +780 -0
- tests/test_data_broker.py +424 -0
- tests/test_palantir.py +349 -0
- tools/__init__.py +38 -0
- tools/actuators.py +437 -0
- tools/bolt.diy/Dockerfile +103 -0
- tools/bolt.diy/app/components/@settings/core/AvatarDropdown.tsx +175 -0
- tools/bolt.diy/app/components/@settings/core/ControlPanel.tsx +345 -0
- tools/bolt.diy/app/components/@settings/core/constants.tsx +108 -0
- tools/bolt.diy/app/components/@settings/core/types.ts +114 -0
- tools/bolt.diy/app/components/@settings/index.ts +12 -0
- tools/bolt.diy/app/components/@settings/shared/components/TabTile.tsx +151 -0
- tools/bolt.diy/app/components/@settings/shared/service-integration/ConnectionForm.tsx +193 -0
- tools/bolt.diy/app/components/@settings/shared/service-integration/ConnectionTestIndicator.tsx +60 -0
- tools/bolt.diy/app/components/@settings/shared/service-integration/ErrorState.tsx +102 -0
- tools/bolt.diy/app/components/@settings/shared/service-integration/LoadingState.tsx +94 -0
- tools/bolt.diy/app/components/@settings/shared/service-integration/ServiceHeader.tsx +72 -0
- tools/bolt.diy/app/components/@settings/shared/service-integration/index.ts +6 -0
- tools/bolt.diy/app/components/@settings/tabs/data/DataTab.tsx +721 -0
- tools/bolt.diy/app/components/@settings/tabs/data/DataVisualization.tsx +384 -0
- tools/bolt.diy/app/components/@settings/tabs/event-logs/EventLogsTab.tsx +1013 -0
- tools/bolt.diy/app/components/@settings/tabs/features/FeaturesTab.tsx +295 -0
- tools/bolt.diy/app/components/@settings/tabs/github/GitHubTab.tsx +281 -0
- tools/bolt.diy/app/components/@settings/tabs/github/components/GitHubAuthDialog.tsx +173 -0
- tools/bolt.diy/app/components/@settings/tabs/github/components/GitHubCacheManager.tsx +367 -0
- tools/bolt.diy/app/components/@settings/tabs/github/components/GitHubConnection.tsx +233 -0
- tools/bolt.diy/app/components/@settings/tabs/github/components/GitHubErrorBoundary.tsx +105 -0
- tools/bolt.diy/app/components/@settings/tabs/github/components/GitHubProgressiveLoader.tsx +266 -0
- tools/bolt.diy/app/components/@settings/tabs/github/components/GitHubRepositoryCard.tsx +121 -0
- tools/bolt.diy/app/components/@settings/tabs/github/components/GitHubRepositorySelector.tsx +312 -0
- tools/bolt.diy/app/components/@settings/tabs/github/components/GitHubStats.tsx +291 -0
- tools/bolt.diy/app/components/@settings/tabs/github/components/GitHubUserProfile.tsx +46 -0
- tools/bolt.diy/app/components/@settings/tabs/github/components/shared/GitHubStateIndicators.tsx +264 -0
- tools/bolt.diy/app/components/@settings/tabs/github/components/shared/RepositoryCard.tsx +361 -0
- tools/bolt.diy/app/components/@settings/tabs/github/components/shared/index.ts +11 -0
- tools/bolt.diy/app/components/@settings/tabs/gitlab/GitLabTab.tsx +305 -0
- tools/bolt.diy/app/components/@settings/tabs/gitlab/components/GitLabAuthDialog.tsx +186 -0
- tools/bolt.diy/app/components/@settings/tabs/gitlab/components/GitLabConnection.tsx +253 -0
- tools/bolt.diy/app/components/@settings/tabs/gitlab/components/GitLabRepositorySelector.tsx +358 -0
- tools/bolt.diy/app/components/@settings/tabs/gitlab/components/RepositoryCard.tsx +79 -0
- tools/bolt.diy/app/components/@settings/tabs/gitlab/components/RepositoryList.tsx +142 -0
- tools/bolt.diy/app/components/@settings/tabs/gitlab/components/StatsDisplay.tsx +91 -0
- tools/bolt.diy/app/components/@settings/tabs/gitlab/components/index.ts +4 -0
- tools/bolt.diy/app/components/@settings/tabs/mcp/McpServerList.tsx +99 -0
- tools/bolt.diy/app/components/@settings/tabs/mcp/McpServerListItem.tsx +70 -0
- tools/bolt.diy/app/components/@settings/tabs/mcp/McpStatusBadge.tsx +37 -0
- tools/bolt.diy/app/components/@settings/tabs/mcp/McpTab.tsx +239 -0
- tools/bolt.diy/app/components/@settings/tabs/netlify/NetlifyTab.tsx +1393 -0
- tools/bolt.diy/app/components/@settings/tabs/netlify/components/NetlifyConnection.tsx +990 -0
- tools/bolt.diy/app/components/@settings/tabs/netlify/components/index.ts +1 -0
- tools/bolt.diy/app/components/@settings/tabs/notifications/NotificationsTab.tsx +300 -0
- tools/bolt.diy/app/components/@settings/tabs/profile/ProfileTab.tsx +181 -0
- tools/bolt.diy/app/components/@settings/tabs/providers/cloud/CloudProvidersTab.tsx +308 -0
- tools/bolt.diy/app/components/@settings/tabs/providers/local/ErrorBoundary.tsx +68 -0
- tools/bolt.diy/app/components/@settings/tabs/providers/local/HealthStatusBadge.tsx +64 -0
- tools/bolt.diy/app/components/@settings/tabs/providers/local/LoadingSkeleton.tsx +107 -0
- tools/bolt.diy/app/components/@settings/tabs/providers/local/LocalProvidersTab.tsx +556 -0
- tools/bolt.diy/app/components/@settings/tabs/providers/local/ModelCard.tsx +106 -0
- tools/bolt.diy/app/components/@settings/tabs/providers/local/ProviderCard.tsx +120 -0
- tools/bolt.diy/app/components/@settings/tabs/providers/local/SetupGuide.tsx +671 -0
- tools/bolt.diy/app/components/@settings/tabs/providers/local/StatusDashboard.tsx +91 -0
- tools/bolt.diy/app/components/@settings/tabs/providers/local/types.ts +44 -0
- tools/bolt.diy/app/components/@settings/tabs/settings/SettingsTab.tsx +215 -0
- tools/bolt.diy/app/components/@settings/tabs/supabase/SupabaseTab.tsx +1089 -0
- tools/bolt.diy/app/components/@settings/tabs/vercel/VercelTab.tsx +909 -0
- tools/bolt.diy/app/components/@settings/tabs/vercel/components/VercelConnection.tsx +368 -0
- tools/bolt.diy/app/components/@settings/tabs/vercel/components/index.ts +1 -0
- tools/bolt.diy/app/components/@settings/utils/tab-helpers.ts +54 -0
- tools/bolt.diy/app/components/chat/APIKeyManager.tsx +169 -0
- tools/bolt.diy/app/components/chat/Artifact.tsx +296 -0
- tools/bolt.diy/app/components/chat/AssistantMessage.tsx +192 -0
- tools/bolt.diy/app/components/chat/BaseChat.module.scss +47 -0
- tools/bolt.diy/app/components/chat/BaseChat.tsx +522 -0
- tools/bolt.diy/app/components/chat/Chat.client.tsx +670 -0
- tools/bolt.diy/app/components/chat/ChatAlert.tsx +108 -0
- tools/bolt.diy/app/components/chat/ChatBox.tsx +334 -0
- tools/bolt.diy/app/components/chat/CodeBlock.module.scss +10 -0
- tools/bolt.diy/app/components/chat/CodeBlock.tsx +85 -0
- tools/bolt.diy/app/components/chat/DicussMode.tsx +17 -0
- tools/bolt.diy/app/components/chat/ExamplePrompts.tsx +37 -0
- tools/bolt.diy/app/components/chat/FilePreview.tsx +38 -0
- tools/bolt.diy/app/components/chat/GitCloneButton.tsx +327 -0
- tools/bolt.diy/app/components/chat/ImportFolderButton.tsx +141 -0
- tools/bolt.diy/app/components/chat/LLMApiAlert.tsx +109 -0
- tools/bolt.diy/app/components/chat/MCPTools.tsx +129 -0
- tools/bolt.diy/app/components/chat/Markdown.module.scss +171 -0
- tools/bolt.diy/app/components/chat/Markdown.spec.ts +48 -0
- tools/bolt.diy/app/components/chat/Markdown.tsx +252 -0
- tools/bolt.diy/app/components/chat/Messages.client.tsx +102 -0
- tools/bolt.diy/app/components/chat/ModelSelector.tsx +797 -0
- tools/bolt.diy/app/components/chat/NetlifyDeploymentLink.client.tsx +51 -0
- tools/bolt.diy/app/components/chat/ProgressCompilation.tsx +110 -0
- tools/bolt.diy/app/components/chat/ScreenshotStateManager.tsx +33 -0
- tools/bolt.diy/app/components/chat/SendButton.client.tsx +39 -0
- tools/bolt.diy/app/components/chat/SpeechRecognition.tsx +28 -0
- tools/bolt.diy/app/components/chat/StarterTemplates.tsx +38 -0
- tools/bolt.diy/app/components/chat/SupabaseAlert.tsx +199 -0
- tools/bolt.diy/app/components/chat/SupabaseConnection.tsx +339 -0
- tools/bolt.diy/app/components/chat/ThoughtBox.tsx +43 -0
- tools/bolt.diy/app/components/chat/ToolInvocations.tsx +409 -0
- tools/bolt.diy/app/components/chat/UserMessage.tsx +101 -0
- tools/bolt.diy/app/components/chat/VercelDeploymentLink.client.tsx +158 -0
- tools/bolt.diy/app/components/chat/chatExportAndImport/ExportChatButton.tsx +49 -0
- tools/bolt.diy/app/components/chat/chatExportAndImport/ImportButtons.tsx +96 -0
- tools/bolt.diy/app/components/deploy/DeployAlert.tsx +197 -0
- tools/bolt.diy/app/components/deploy/DeployButton.tsx +277 -0
- tools/bolt.diy/app/components/deploy/GitHubDeploy.client.tsx +171 -0
- tools/bolt.diy/app/components/deploy/GitHubDeploymentDialog.tsx +1041 -0
- tools/bolt.diy/app/components/deploy/GitLabDeploy.client.tsx +171 -0
- tools/bolt.diy/app/components/deploy/GitLabDeploymentDialog.tsx +764 -0
- tools/bolt.diy/app/components/deploy/NetlifyDeploy.client.tsx +246 -0
- tools/bolt.diy/app/components/deploy/VercelDeploy.client.tsx +235 -0
- tools/bolt.diy/app/components/editor/codemirror/BinaryContent.tsx +7 -0
- tools/bolt.diy/app/components/editor/codemirror/CodeMirrorEditor.tsx +555 -0
- tools/bolt.diy/app/components/editor/codemirror/EnvMasking.ts +80 -0
- tools/bolt.diy/app/components/editor/codemirror/cm-theme.ts +192 -0
- tools/bolt.diy/app/components/editor/codemirror/indent.ts +68 -0
- tools/bolt.diy/app/components/editor/codemirror/languages.ts +112 -0
- tools/bolt.diy/app/components/git/GitUrlImport.client.tsx +147 -0
- tools/bolt.diy/app/components/header/Header.tsx +42 -0
- tools/bolt.diy/app/components/header/HeaderActionButtons.client.tsx +54 -0
- tools/bolt.diy/app/components/mandate/MandateSubmission.tsx +167 -0
- tools/bolt.diy/app/components/observability/DeploymentStatus.tsx +168 -0
- tools/bolt.diy/app/components/observability/EventTimeline.tsx +119 -0
- tools/bolt.diy/app/components/observability/FileDiffViewer.tsx +121 -0
- tools/bolt.diy/app/components/observability/GovernanceStatus.tsx +197 -0
- tools/bolt.diy/app/components/observability/GovernorMetrics.tsx +246 -0
- tools/bolt.diy/app/components/observability/LogStream.tsx +244 -0
- tools/bolt.diy/app/components/observability/MandateDetails.tsx +201 -0
- tools/bolt.diy/app/components/observability/ObservabilityDashboard.tsx +200 -0
- tools/bolt.diy/app/components/sidebar/HistoryItem.tsx +187 -0
- tools/bolt.diy/app/components/sidebar/Menu.client.tsx +536 -0
- tools/bolt.diy/app/components/sidebar/date-binning.ts +59 -0
- tools/bolt.diy/app/components/txt +1 -0
- tools/bolt.diy/app/components/ui/BackgroundRays/index.tsx +18 -0
- tools/bolt.diy/app/components/ui/BackgroundRays/styles.module.scss +246 -0
- tools/bolt.diy/app/components/ui/Badge.tsx +53 -0
- tools/bolt.diy/app/components/ui/BranchSelector.tsx +270 -0
- tools/bolt.diy/app/components/ui/Breadcrumbs.tsx +101 -0
- tools/bolt.diy/app/components/ui/Button.tsx +46 -0
- tools/bolt.diy/app/components/ui/Card.tsx +55 -0
- tools/bolt.diy/app/components/ui/Checkbox.tsx +32 -0
- tools/bolt.diy/app/components/ui/CloseButton.tsx +49 -0
- tools/bolt.diy/app/components/ui/CodeBlock.tsx +103 -0
- tools/bolt.diy/app/components/ui/Collapsible.tsx +9 -0
- tools/bolt.diy/app/components/ui/ColorSchemeDialog.tsx +378 -0
- tools/bolt.diy/app/components/ui/Dialog.tsx +449 -0
- tools/bolt.diy/app/components/ui/Dropdown.tsx +63 -0
- tools/bolt.diy/app/components/ui/EmptyState.tsx +154 -0
- tools/bolt.diy/app/components/ui/FileIcon.tsx +346 -0
- tools/bolt.diy/app/components/ui/FilterChip.tsx +92 -0
- tools/bolt.diy/app/components/ui/GlowingEffect.tsx +192 -0
- tools/bolt.diy/app/components/ui/GradientCard.tsx +100 -0
- tools/bolt.diy/app/components/ui/IconButton.tsx +84 -0
- tools/bolt.diy/app/components/ui/Input.tsx +22 -0
- tools/bolt.diy/app/components/ui/Label.tsx +20 -0
- tools/bolt.diy/app/components/ui/LoadingDots.tsx +27 -0
- tools/bolt.diy/app/components/ui/LoadingOverlay.tsx +32 -0
- tools/bolt.diy/app/components/ui/PanelHeader.tsx +20 -0
- tools/bolt.diy/app/components/ui/PanelHeaderButton.tsx +36 -0
- tools/bolt.diy/app/components/ui/Popover.tsx +29 -0
- tools/bolt.diy/app/components/ui/Progress.tsx +22 -0
- tools/bolt.diy/app/components/ui/RepositoryStats.tsx +87 -0
- tools/bolt.diy/app/components/ui/ScrollArea.tsx +41 -0
- tools/bolt.diy/app/components/ui/SearchInput.tsx +80 -0
- tools/bolt.diy/app/components/ui/SearchResultItem.tsx +134 -0
- tools/bolt.diy/app/components/ui/Separator.tsx +22 -0
- tools/bolt.diy/app/components/ui/SettingsButton.tsx +35 -0
- tools/bolt.diy/app/components/ui/Slider.tsx +73 -0
- tools/bolt.diy/app/components/ui/StatusIndicator.tsx +90 -0
- tools/bolt.diy/app/components/ui/Switch.tsx +37 -0
- tools/bolt.diy/app/components/ui/Tabs.tsx +52 -0
- tools/bolt.diy/app/components/ui/TabsWithSlider.tsx +112 -0
- tools/bolt.diy/app/components/ui/ThemeSwitch.tsx +29 -0
- tools/bolt.diy/app/components/ui/Tooltip.tsx +122 -0
- tools/bolt.diy/app/components/ui/index.ts +38 -0
- tools/bolt.diy/app/components/ui/use-toast.ts +66 -0
- tools/bolt.diy/app/components/workbench/DiffView.tsx +796 -0
- tools/bolt.diy/app/components/workbench/EditorPanel.tsx +174 -0
- tools/bolt.diy/app/components/workbench/ExpoQrModal.tsx +55 -0
- tools/bolt.diy/app/components/workbench/FileBreadcrumb.tsx +150 -0
- tools/bolt.diy/app/components/workbench/FileTree.tsx +565 -0
- tools/bolt.diy/app/components/workbench/Inspector.tsx +126 -0
- tools/bolt.diy/app/components/workbench/InspectorPanel.tsx +146 -0
- tools/bolt.diy/app/components/workbench/LockManager.tsx +262 -0
- tools/bolt.diy/app/components/workbench/PortDropdown.tsx +91 -0
- tools/bolt.diy/app/components/workbench/Preview.tsx +1049 -0
- tools/bolt.diy/app/components/workbench/ScreenshotSelector.tsx +293 -0
- tools/bolt.diy/app/components/workbench/Search.tsx +257 -0
- tools/bolt.diy/app/components/workbench/Workbench.client.tsx +506 -0
- tools/bolt.diy/app/components/workbench/terminal/Terminal.tsx +131 -0
- tools/bolt.diy/app/components/workbench/terminal/TerminalManager.tsx +68 -0
- tools/bolt.diy/app/components/workbench/terminal/TerminalTabs.tsx +277 -0
- tools/bolt.diy/app/components/workbench/terminal/theme.ts +36 -0
- tools/bolt.diy/app/components/workflow/WorkflowPhase.tsx +109 -0
- tools/bolt.diy/app/components/workflow/WorkflowStatus.tsx +60 -0
- tools/bolt.diy/app/components/workflow/WorkflowTimeline.tsx +150 -0
- tools/bolt.diy/app/entry.client.tsx +7 -0
- tools/bolt.diy/app/entry.server.tsx +80 -0
- tools/bolt.diy/app/root.tsx +156 -0
- tools/bolt.diy/app/routes/_index.tsx +175 -0
- tools/bolt.diy/app/routes/api.bug-report.ts +254 -0
- tools/bolt.diy/app/routes/api.chat.ts +463 -0
- tools/bolt.diy/app/routes/api.check-env-key.ts +41 -0
- tools/bolt.diy/app/routes/api.configured-providers.ts +110 -0
- tools/bolt.diy/app/routes/api.corporate-swarm-status.ts +55 -0
- tools/bolt.diy/app/routes/api.enhancer.ts +137 -0
- tools/bolt.diy/app/routes/api.export-api-keys.ts +44 -0
- tools/bolt.diy/app/routes/api.git-info.ts +69 -0
- tools/bolt.diy/app/routes/api.git-proxy.$.ts +178 -0
- tools/bolt.diy/app/routes/api.github-branches.ts +166 -0
- tools/bolt.diy/app/routes/api.github-deploy.ts +67 -0
- tools/bolt.diy/app/routes/api.github-stats.ts +198 -0
- tools/bolt.diy/app/routes/api.github-template.ts +242 -0
- tools/bolt.diy/app/routes/api.github-user.ts +287 -0
- tools/bolt.diy/app/routes/api.gitlab-branches.ts +143 -0
- tools/bolt.diy/app/routes/api.gitlab-deploy.ts +67 -0
- tools/bolt.diy/app/routes/api.gitlab-projects.ts +105 -0
- tools/bolt.diy/app/routes/api.health.ts +8 -0
- tools/bolt.diy/app/routes/api.llmcall.ts +298 -0
- tools/bolt.diy/app/routes/api.mandate.ts +351 -0
- tools/bolt.diy/app/routes/api.mcp-check.ts +16 -0
- tools/bolt.diy/app/routes/api.mcp-update-config.ts +23 -0
- tools/bolt.diy/app/routes/api.models.$provider.ts +2 -0
- tools/bolt.diy/app/routes/api.models.ts +90 -0
- tools/bolt.diy/app/routes/api.netlify-deploy.ts +240 -0
- tools/bolt.diy/app/routes/api.netlify-user.ts +142 -0
- tools/bolt.diy/app/routes/api.supabase-user.ts +199 -0
- tools/bolt.diy/app/routes/api.supabase.query.ts +92 -0
- tools/bolt.diy/app/routes/api.supabase.ts +56 -0
- tools/bolt.diy/app/routes/api.supabase.variables.ts +32 -0
- tools/bolt.diy/app/routes/api.system.diagnostics.ts +142 -0
- tools/bolt.diy/app/routes/api.system.disk-info.ts +311 -0
- tools/bolt.diy/app/routes/api.system.git-info.ts +332 -0
- tools/bolt.diy/app/routes/api.update.ts +21 -0
- tools/bolt.diy/app/routes/api.vercel-deploy.ts +497 -0
- tools/bolt.diy/app/routes/api.vercel-user.ts +161 -0
- tools/bolt.diy/app/routes/api.workflow-status.$proposalId.ts +309 -0
- tools/bolt.diy/app/routes/chat.$id.tsx +8 -0
- tools/bolt.diy/app/routes/execute.$mandateId.tsx +432 -0
- tools/bolt.diy/app/routes/git.tsx +25 -0
- tools/bolt.diy/app/routes/observability.$mandateId.tsx +50 -0
- tools/bolt.diy/app/routes/webcontainer.connect.$id.tsx +32 -0
- tools/bolt.diy/app/routes/webcontainer.preview.$id.tsx +97 -0
- tools/bolt.diy/app/routes/workflow.$proposalId.tsx +170 -0
- tools/bolt.diy/app/styles/animations.scss +49 -0
- tools/bolt.diy/app/styles/components/code.scss +9 -0
- tools/bolt.diy/app/styles/components/editor.scss +135 -0
- tools/bolt.diy/app/styles/components/resize-handle.scss +30 -0
- tools/bolt.diy/app/styles/components/terminal.scss +3 -0
- tools/bolt.diy/app/styles/components/toast.scss +23 -0
- tools/bolt.diy/app/styles/diff-view.css +72 -0
- tools/bolt.diy/app/styles/index.scss +73 -0
- tools/bolt.diy/app/styles/variables.scss +255 -0
- tools/bolt.diy/app/styles/z-index.scss +37 -0
- tools/bolt.diy/app/types/GitHub.ts +182 -0
- tools/bolt.diy/app/types/GitLab.ts +103 -0
- tools/bolt.diy/app/types/actions.ts +85 -0
- tools/bolt.diy/app/types/artifact.ts +5 -0
- tools/bolt.diy/app/types/context.ts +26 -0
- tools/bolt.diy/app/types/design-scheme.ts +93 -0
- tools/bolt.diy/app/types/global.d.ts +13 -0
- tools/bolt.diy/app/types/mandate.ts +333 -0
- tools/bolt.diy/app/types/model.ts +25 -0
- tools/bolt.diy/app/types/netlify.ts +94 -0
- tools/bolt.diy/app/types/supabase.ts +54 -0
- tools/bolt.diy/app/types/template.ts +8 -0
- tools/bolt.diy/app/types/terminal.ts +9 -0
- tools/bolt.diy/app/types/theme.ts +1 -0
- tools/bolt.diy/app/types/vercel.ts +67 -0
- tools/bolt.diy/app/utils/buffer.ts +29 -0
- tools/bolt.diy/app/utils/classNames.ts +65 -0
- tools/bolt.diy/app/utils/constants.ts +147 -0
- tools/bolt.diy/app/utils/debounce.ts +13 -0
- tools/bolt.diy/app/utils/debugLogger.ts +1284 -0
- tools/bolt.diy/app/utils/diff.spec.ts +11 -0
- tools/bolt.diy/app/utils/diff.ts +117 -0
- tools/bolt.diy/app/utils/easings.ts +3 -0
- tools/bolt.diy/app/utils/fileLocks.ts +96 -0
- tools/bolt.diy/app/utils/fileUtils.ts +121 -0
- tools/bolt.diy/app/utils/folderImport.ts +73 -0
- tools/bolt.diy/app/utils/formatSize.ts +12 -0
- tools/bolt.diy/app/utils/getLanguageFromExtension.ts +24 -0
- tools/bolt.diy/app/utils/githubStats.ts +9 -0
- tools/bolt.diy/app/utils/gitlabStats.ts +54 -0
- tools/bolt.diy/app/utils/logger.ts +162 -0
- tools/bolt.diy/app/utils/markdown.ts +155 -0
- tools/bolt.diy/app/utils/mobile.ts +4 -0
- tools/bolt.diy/app/utils/os.ts +4 -0
- tools/bolt.diy/app/utils/path.ts +19 -0
- tools/bolt.diy/app/utils/projectCommands.ts +197 -0
- tools/bolt.diy/app/utils/promises.ts +19 -0
- tools/bolt.diy/app/utils/react.ts +6 -0
- tools/bolt.diy/app/utils/sampler.ts +49 -0
- tools/bolt.diy/app/utils/selectStarterTemplate.ts +255 -0
- tools/bolt.diy/app/utils/shell.ts +384 -0
- tools/bolt.diy/app/utils/stacktrace.ts +27 -0
- tools/bolt.diy/app/utils/stripIndent.ts +23 -0
- tools/bolt.diy/app/utils/terminal.ts +11 -0
- tools/bolt.diy/app/utils/unreachable.ts +3 -0
- tools/bolt.diy/app/vite-env.d.ts +2 -0
- tools/bolt.diy/assets/entitlements.mac.plist +25 -0
- tools/bolt.diy/assets/icons/icon.icns +0 -0
- tools/bolt.diy/assets/icons/icon.ico +0 -0
- tools/bolt.diy/assets/icons/icon.png +0 -0
- tools/bolt.diy/bindings.js +78 -0
- tools/bolt.diy/bindings.sh +33 -0
- tools/bolt.diy/docker-compose.yaml +145 -0
- tools/bolt.diy/electron/main/index.ts +201 -0
- tools/bolt.diy/electron/main/tsconfig.json +30 -0
- tools/bolt.diy/electron/main/ui/menu.ts +29 -0
- tools/bolt.diy/electron/main/ui/window.ts +54 -0
- tools/bolt.diy/electron/main/utils/auto-update.ts +110 -0
- tools/bolt.diy/electron/main/utils/constants.ts +4 -0
- tools/bolt.diy/electron/main/utils/cookie.ts +40 -0
- tools/bolt.diy/electron/main/utils/reload.ts +35 -0
- tools/bolt.diy/electron/main/utils/serve.ts +71 -0
- tools/bolt.diy/electron/main/utils/store.ts +3 -0
- tools/bolt.diy/electron/main/utils/vite-server.ts +44 -0
- tools/bolt.diy/electron/main/vite.config.ts +44 -0
- tools/bolt.diy/electron/preload/index.ts +22 -0
- tools/bolt.diy/electron/preload/tsconfig.json +7 -0
- tools/bolt.diy/electron/preload/vite.config.ts +31 -0
- tools/bolt.diy/electron-builder.yml +64 -0
- tools/bolt.diy/electron-update.yml +4 -0
- tools/bolt.diy/eslint.config.mjs +57 -0
- tools/bolt.diy/functions/[[path]].ts +12 -0
- tools/bolt.diy/icons/angular.svg +1 -0
- tools/bolt.diy/icons/astro.svg +8 -0
- tools/bolt.diy/icons/chat.svg +1 -0
- tools/bolt.diy/icons/expo-brand.svg +1 -0
- tools/bolt.diy/icons/expo.svg +4 -0
- tools/bolt.diy/icons/logo-text.svg +1 -0
- tools/bolt.diy/icons/logo.svg +4 -0
- tools/bolt.diy/icons/mcp.svg +1 -0
- tools/bolt.diy/icons/nativescript.svg +1 -0
- tools/bolt.diy/icons/netlify.svg +10 -0
- tools/bolt.diy/icons/nextjs.svg +1 -0
- tools/bolt.diy/icons/nuxt.svg +1 -0
- tools/bolt.diy/icons/qwik.svg +1 -0
- tools/bolt.diy/icons/react.svg +1 -0
- tools/bolt.diy/icons/remix.svg +24 -0
- tools/bolt.diy/icons/remotion.svg +1 -0
- tools/bolt.diy/icons/shadcn.svg +21 -0
- tools/bolt.diy/icons/slidev.svg +60 -0
- tools/bolt.diy/icons/solidjs.svg +1 -0
- tools/bolt.diy/icons/stars.svg +1 -0
- tools/bolt.diy/icons/svelte.svg +1 -0
- tools/bolt.diy/icons/typescript.svg +1 -0
- tools/bolt.diy/icons/vite.svg +1 -0
- tools/bolt.diy/icons/vue.svg +1 -0
- tools/bolt.diy/load-context.ts +9 -0
- tools/bolt.diy/notarize.cjs +31 -0
- tools/bolt.diy/package.json +218 -0
- tools/bolt.diy/playwright.config.preview.ts +35 -0
- tools/bolt.diy/pre-start.cjs +26 -0
- tools/bolt.diy/public/apple-touch-icon-precomposed.png +0 -0
- tools/bolt.diy/public/apple-touch-icon.png +0 -0
- tools/bolt.diy/public/favicon.ico +0 -0
- tools/bolt.diy/public/favicon.svg +4 -0
- tools/bolt.diy/public/icons/AmazonBedrock.svg +1 -0
- tools/bolt.diy/public/icons/Anthropic.svg +4 -0
- tools/bolt.diy/public/icons/Cohere.svg +4 -0
- tools/bolt.diy/public/icons/Deepseek.svg +5 -0
- tools/bolt.diy/public/icons/Default.svg +4 -0
- tools/bolt.diy/public/icons/Google.svg +4 -0
- tools/bolt.diy/public/icons/Groq.svg +4 -0
- tools/bolt.diy/public/icons/HuggingFace.svg +4 -0
- tools/bolt.diy/public/icons/Hyperbolic.svg +3 -0
- tools/bolt.diy/public/icons/LMStudio.svg +5 -0
- tools/bolt.diy/public/icons/Mistral.svg +4 -0
- tools/bolt.diy/public/icons/Ollama.svg +4 -0
- tools/bolt.diy/public/icons/OpenAI.svg +4 -0
- tools/bolt.diy/public/icons/OpenAILike.svg +4 -0
- tools/bolt.diy/public/icons/OpenRouter.svg +4 -0
- tools/bolt.diy/public/icons/Perplexity.svg +4 -0
- tools/bolt.diy/public/icons/Together.svg +4 -0
- tools/bolt.diy/public/icons/xAI.svg +5 -0
- tools/bolt.diy/public/inspector-script.js +292 -0
- tools/bolt.diy/public/logo-dark-styled.png +0 -0
- tools/bolt.diy/public/logo-dark.png +0 -0
- tools/bolt.diy/public/logo-light-styled.png +0 -0
- tools/bolt.diy/public/logo-light.png +0 -0
- tools/bolt.diy/public/logo.svg +15 -0
- tools/bolt.diy/public/social_preview_index.jpg +0 -0
- tools/bolt.diy/scripts/clean.js +45 -0
- tools/bolt.diy/scripts/electron-dev.mjs +181 -0
- tools/bolt.diy/scripts/setup-env.sh +41 -0
- tools/bolt.diy/scripts/update-imports.sh +7 -0
- tools/bolt.diy/scripts/update.sh +52 -0
- tools/bolt.diy/services/execution-governor/Dockerfile +41 -0
- tools/bolt.diy/services/execution-governor/config.ts +42 -0
- tools/bolt.diy/services/execution-governor/index.ts +683 -0
- tools/bolt.diy/services/execution-governor/metrics.ts +141 -0
- tools/bolt.diy/services/execution-governor/package.json +31 -0
- tools/bolt.diy/services/execution-governor/priority-queue.ts +139 -0
- tools/bolt.diy/services/execution-governor/tsconfig.json +21 -0
- tools/bolt.diy/services/execution-governor/types.ts +145 -0
- tools/bolt.diy/services/headless-executor/Dockerfile +43 -0
- tools/bolt.diy/services/headless-executor/executor.ts +210 -0
- tools/bolt.diy/services/headless-executor/index.ts +323 -0
- tools/bolt.diy/services/headless-executor/package.json +27 -0
- tools/bolt.diy/services/headless-executor/tsconfig.json +21 -0
- tools/bolt.diy/services/headless-executor/types.ts +38 -0
- tools/bolt.diy/test-workflows.sh +240 -0
- tools/bolt.diy/tests/integration/corporate-swarm.test.ts +208 -0
- tools/bolt.diy/tests/mandates/budget-limited.json +34 -0
- tools/bolt.diy/tests/mandates/complex.json +53 -0
- tools/bolt.diy/tests/mandates/constraint-enforced.json +36 -0
- tools/bolt.diy/tests/mandates/simple.json +35 -0
- tools/bolt.diy/tsconfig.json +37 -0
- tools/bolt.diy/types/istextorbinary.d.ts +15 -0
- tools/bolt.diy/uno.config.ts +279 -0
- tools/bolt.diy/vite-electron.config.ts +76 -0
- tools/bolt.diy/vite.config.ts +112 -0
- tools/bolt.diy/worker-configuration.d.ts +22 -0
- tools/bolt.diy/wrangler.toml +6 -0
- tools/code_generator.py +461 -0
- tools/file_operations.py +465 -0
- tools/mandate_generator.py +337 -0
- tools/mcpClientUtils.py +1216 -0
- tools/sensors.py +285 -0
- utils/Agent_types.py +15 -0
- utils/AnyToStr.py +0 -0
- utils/HHCS.py +277 -0
- utils/__init__.py +30 -0
- utils/agent.py +3627 -0
- utils/aop.py +2948 -0
- utils/canonical.py +143 -0
- utils/conversation.py +1195 -0
- utils/doctrine_versioning +230 -0
- utils/formatter.py +474 -0
- utils/ledger.py +311 -0
- utils/out_types.py +16 -0
- utils/rollback.py +339 -0
- utils/router.py +929 -0
- utils/tui.py +1908 -0
CRCA.py
ADDED
|
@@ -0,0 +1,4156 @@
|
|
|
1
|
+
"""CRCAAgent - Causal Reasoning and Counterfactual Analysis Agent.
|
|
2
|
+
|
|
3
|
+
This module provides a lightweight causal reasoning agent with LLM integration,
|
|
4
|
+
implemented in pure Python and intended as a flexible CR-CA engine for Swarms.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
# Standard library imports
|
|
8
|
+
import asyncio
|
|
9
|
+
import importlib
|
|
10
|
+
import importlib.util
|
|
11
|
+
import inspect
|
|
12
|
+
import logging
|
|
13
|
+
import math
|
|
14
|
+
import os
|
|
15
|
+
import threading
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
from enum import Enum
|
|
18
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
19
|
+
|
|
20
|
+
# Third-party imports
|
|
21
|
+
import numpy as np
|
|
22
|
+
from loguru import logger
|
|
23
|
+
from swarms.structs.agent import Agent
|
|
24
|
+
|
|
25
|
+
# Try to import rustworkx (required dependency)
|
|
26
|
+
try:
|
|
27
|
+
import rustworkx as rx
|
|
28
|
+
except Exception as e:
|
|
29
|
+
raise ImportError(
|
|
30
|
+
"rustworkx is required for the CRCAAgent rustworkx upgrade: pip install rustworkx"
|
|
31
|
+
) from e
|
|
32
|
+
|
|
33
|
+
# Optional heavy dependencies — used when available
|
|
34
|
+
try:
|
|
35
|
+
import pandas as pd # type: ignore
|
|
36
|
+
PANDAS_AVAILABLE = True
|
|
37
|
+
except Exception:
|
|
38
|
+
PANDAS_AVAILABLE = False
|
|
39
|
+
|
|
40
|
+
try:
|
|
41
|
+
from scipy import linalg as scipy_linalg # type: ignore
|
|
42
|
+
from scipy import stats as scipy_stats # type: ignore
|
|
43
|
+
SCIPY_AVAILABLE = True
|
|
44
|
+
except Exception:
|
|
45
|
+
SCIPY_AVAILABLE = False
|
|
46
|
+
|
|
47
|
+
try:
|
|
48
|
+
import cvxpy as cp # type: ignore
|
|
49
|
+
CVXPY_AVAILABLE = True
|
|
50
|
+
except Exception:
|
|
51
|
+
CVXPY_AVAILABLE = False
|
|
52
|
+
|
|
53
|
+
# Load environment variables from a .env file sitting next to this module.
try:
    from dotenv import load_dotenv

    # Resolve the .env path relative to this file so loading works regardless
    # of the current working directory. override=False keeps any variables
    # already present in the process environment authoritative.
    # (os is already imported at module top; the previous redundant
    # `import os` inside this try-block has been removed.)
    env_path = os.path.join(os.path.dirname(__file__), '.env')
    load_dotenv(dotenv_path=env_path, override=False)
except ImportError:
    # python-dotenv is optional; without it, configuration must come from
    # the ambient process environment.
    pass
except Exception:
    # A missing or unreadable .env file is fine — loading is best-effort.
    pass
|
|
67
|
+
|
|
68
|
+
# Local imports
|
|
69
|
+
try:
|
|
70
|
+
from prompts.default_crca import DEFAULT_CRCA_SYSTEM_PROMPT
|
|
71
|
+
except ImportError:
|
|
72
|
+
# Fallback if prompt file doesn't exist
|
|
73
|
+
DEFAULT_CRCA_SYSTEM_PROMPT = None
|
|
74
|
+
|
|
75
|
+
# Policy engine imports (optional - only if policy_mode is enabled)
|
|
76
|
+
try:
|
|
77
|
+
from schemas.policy import DoctrineV1
|
|
78
|
+
from utils.ledger import Ledger
|
|
79
|
+
from templates.policy_loop import PolicyLoopMixin
|
|
80
|
+
POLICY_ENGINE_AVAILABLE = True
|
|
81
|
+
except ImportError:
|
|
82
|
+
POLICY_ENGINE_AVAILABLE = False
|
|
83
|
+
logger.debug("Policy engine modules not available")
|
|
84
|
+
|
|
85
|
+
# Fix litellm async compatibility
# NOTE(review): presumably works around a litellm logging_utils module whose
# `asyncio.iscoroutinefunction` reference misbehaves with this runtime —
# confirm against the installed litellm version. The patch replaces it with
# the stdlib `inspect.iscoroutinefunction`. Entirely best-effort: every
# failure path is swallowed so a missing/odd litellm install cannot break import.
try:
    # Probe for the module first (find_spec) so an absent litellm is skipped
    # cheaply without triggering an import error.
    lu_spec = importlib.util.find_spec("litellm.litellm_core_utils.logging_utils")
    if lu_spec is not None:
        lu = importlib.import_module("litellm.litellm_core_utils.logging_utils")
        try:
            # Only patch when the full attribute chain actually exists on the
            # imported module, to avoid creating attributes that weren't there.
            if hasattr(lu, "asyncio") and hasattr(lu.asyncio, "iscoroutinefunction"):
                lu.asyncio.iscoroutinefunction = inspect.iscoroutinefunction
        except Exception:
            pass
except Exception:
    pass
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class CausalRelationType(Enum):
    """Kinds of causal links an edge in the graph may represent.

    Used as the ``relation_type`` tag on edges to distinguish direct
    effects from indirect, confounding, mediating, and moderating
    relationships.
    """

    DIRECT = "direct"
    INDIRECT = "indirect"
    CONFOUNDING = "confounding"
    MEDIATING = "mediating"
    MODERATING = "moderating"
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@dataclass
class CausalNode:
    """A single variable in the causal graph.

    Attributes:
        name: Name of the variable/node
        value: Current observed value of the variable, or None when unset
        confidence: Confidence level in the node (default: 1.0)
        node_type: Kind of node (default: "variable")
    """

    name: str
    value: Optional[float] = None
    confidence: float = 1.0
    node_type: str = "variable"
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
@dataclass
class CausalEdge:
    """A directed causal link between two variables in the graph.

    Attributes:
        source: Name of the cause variable
        target: Name of the effect variable
        strength: Strength of the causal relationship (default: 1.0)
        relation_type: Kind of causal relation (default: DIRECT)
        confidence: Confidence level in the edge (default: 1.0)
    """

    source: str
    target: str
    strength: float = 1.0
    relation_type: CausalRelationType = CausalRelationType.DIRECT
    confidence: float = 1.0
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
@dataclass
class CounterfactualScenario:
    """A single "what-if" scenario used in counterfactual analysis.

    Attributes:
        name: Name/identifier of the scenario
        interventions: Variable name -> value forced by the intervention
        expected_outcomes: Variable name -> predicted outcome value
        probability: Probability assigned to this scenario (default: 1.0)
        reasoning: Free-text explanation of why this scenario is important
        uncertainty_metadata: Optional metadata about prediction confidence,
            graph uncertainty, and scenario relevance
        sampling_distribution: Optional distribution type used for sampling
            (gaussian/uniform/mixture/adaptive)
        monte_carlo_iterations: Optional number of Monte Carlo samples used
        meta_reasoning_score: Optional overall quality/informativeness score
    """

    name: str
    interventions: Dict[str, float]
    expected_outcomes: Dict[str, float]
    probability: float = 1.0
    reasoning: str = ""
    uncertainty_metadata: Optional[Dict[str, Any]] = None
    sampling_distribution: Optional[str] = None
    monte_carlo_iterations: Optional[int] = None
    meta_reasoning_score: Optional[float] = None
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
# Internal helper classes for meta-Monte Carlo counterfactual reasoning
|
|
177
|
+
class _AdaptiveInterventionSampler:
    """Adaptive intervention sampler based on causal graph structure.

    Picks a per-variable sampling distribution (gaussian / uniform /
    mixture / graph-adaptive) from edge strengths, edge confidence,
    path depth, and parent count as read off the owning agent's graph.
    """

    def __init__(self, agent: 'CRCAAgent'):
        """Initialize sampler with reference to agent.

        Args:
            agent: CRCAAgent instance for accessing causal graph and stats
        """
        self.agent = agent

    def sample_interventions(
        self,
        factual_state: Dict[str, float],
        target_variables: List[str],
        n_samples: int
    ) -> List[Dict[str, float]]:
        """Sample interventions using adaptive distributions.

        Args:
            factual_state: Current factual state
            target_variables: Variables to sample interventions for
            n_samples: Number of samples to generate

        Returns:
            List of intervention dictionaries, one per sample, each mapping
            variable name -> sampled intervention value
        """
        samples = []
        # A fresh generator per call keeps results reproducible for a fixed
        # agent seed (default_rng(None) seeds from the OS when seed is unset).
        rng = np.random.default_rng(self.agent.seed)

        for _ in range(n_samples):
            intervention = {}
            for var in target_variables:
                dist_type, params = self._get_adaptive_distribution(var, factual_state)

                if dist_type == "gaussian":
                    val = self._sample_gaussian(var, params["mean"], params["std"], 1, rng)[0]
                elif dist_type == "uniform":
                    val = self._sample_uniform(var, params["bounds"], 1, rng)[0]
                elif dist_type == "mixture":
                    val = self._sample_mixture(var, params["components"], 1, rng)[0]
                else:  # adaptive/graph-based
                    val = self._sample_from_graph_structure(var, factual_state, 1, rng)[0]

                intervention[var] = float(val)

            samples.append(intervention)

        return samples

    def _get_adaptive_distribution(
        self,
        var: str,
        factual_state: Dict[str, float]
    ) -> Tuple[str, Dict[str, Any]]:
        """Select distribution type based on graph structure.

        Args:
            var: Variable name
            factual_state: Current factual state

        Returns:
            Tuple of (distribution_type, parameters)
        """
        # Average incoming-edge strength and confidence for this variable.
        parents = self.agent._get_parents(var)

        avg_strength = 0.0
        avg_confidence = 1.0
        if parents:
            strengths = [abs(self.agent._edge_strength(p, var)) for p in parents]
            confidences = []
            for p in parents:
                edge = self.agent.causal_graph.get(p, {}).get(var, {})
                if isinstance(edge, dict):
                    confidences.append(edge.get("confidence", 1.0))
                else:
                    confidences.append(1.0)
            avg_strength = sum(strengths) / len(strengths) if strengths else 0.0
            avg_confidence = sum(confidences) / len(confidences) if confidences else 1.0

        # Max depth from any root to this variable.
        path_length = self._get_max_path_length(var)

        # Standardization stats give the sampling scale; `or 1.0` guards a
        # zero/None std.
        stats = self.agent.standardization_stats.get(var, {"mean": 0.0, "std": 1.0})
        mean = stats.get("mean", 0.0)
        std = stats.get("std", 1.0) or 1.0
        factual_val = factual_state.get(var, mean)

        # Decision ladder: confidence first, then structural heuristics.
        if avg_confidence > 0.8 and avg_strength > 0.5:
            # High confidence + strong edges -> narrow Gaussian
            return "gaussian", {
                "mean": factual_val,
                "std": std * 0.3  # Narrow distribution
            }
        elif avg_confidence < 0.5:
            # Low confidence -> wide uniform or mixture
            if path_length > 3:
                return "mixture", {
                    "components": [
                        {"type": "gaussian", "mean": factual_val, "std": std * 1.5, "weight": 0.5},
                        {"type": "uniform", "bounds": (factual_val - 3*std, factual_val + 3*std), "weight": 0.5}
                    ]
                }
            else:
                return "uniform", {
                    "bounds": (factual_val - 2*std, factual_val + 2*std)
                }
        elif path_length > 4:
            # Long causal paths -> mixture to capture path uncertainty
            return "mixture", {
                "components": [
                    {"type": "gaussian", "mean": factual_val, "std": std * 0.8, "weight": 0.7},
                    {"type": "uniform", "bounds": (factual_val - 2*std, factual_val + 2*std), "weight": 0.3}
                ]
            }
        elif len(parents) > 3:
            # Many parents -> mixture to capture multi-parent uncertainty
            return "mixture", {
                "components": [
                    {"type": "gaussian", "mean": factual_val, "std": std * 0.6, "weight": 0.6},
                    {"type": "uniform", "bounds": (factual_val - 2*std, factual_val + 2*std), "weight": 0.4}
                ]
            }
        elif len(parents) == 0:
            # Exogenous variable -> uniform
            return "uniform", {
                "bounds": (factual_val - 2*std, factual_val + 2*std)
            }
        else:
            # Default: adaptive based on graph structure
            return "adaptive", {
                "mean": factual_val,
                "std": std,
                "edge_strength": avg_strength,
                "confidence": avg_confidence,
                "path_length": path_length
            }

    def _get_max_path_length(self, var: str) -> int:
        """Get maximum path length from any root ancestor to variable."""
        def dfs(node: str, visited: set, depth: int) -> int:
            # Cycle guard: a revisited node terminates this branch at its depth.
            if node in visited:
                return depth
            visited.add(node)
            parents = self.agent._get_parents(node)
            if not parents:
                return depth
            # visited is copied per branch so parallel ancestor paths are each
            # explored fully (exponential worst case on dense shared-ancestor DAGs).
            return max([dfs(p, visited.copy(), depth + 1) for p in parents] + [depth])

        return dfs(var, set(), 0)

    def _sample_gaussian(
        self,
        var: str,
        mean: float,
        std: float,
        n: int,
        rng: np.random.Generator
    ) -> List[float]:
        """Draw n values from a Gaussian(mean, std)."""
        return [float(x) for x in rng.normal(mean, std, n)]

    def _sample_uniform(
        self,
        var: str,
        bounds: Tuple[float, float],
        n: int,
        rng: np.random.Generator
    ) -> List[float]:
        """Draw n values uniformly from [low, high)."""
        low, high = bounds
        return [float(x) for x in rng.uniform(low, high, n)]

    def _sample_mixture(
        self,
        var: str,
        components: List[Dict[str, Any]],
        n: int,
        rng: np.random.Generator
    ) -> List[float]:
        """Draw n values from a weighted mixture of gaussian/uniform components."""
        # Component-selection probabilities are invariant across draws:
        # compute them once instead of per sample.
        weights = [c.get("weight", 1.0 / len(components)) for c in components]
        total_weight = sum(weights)
        probs = [w / total_weight for w in weights]

        samples = []
        for _ in range(n):
            component = components[rng.choice(len(components), p=probs)]

            if component["type"] == "gaussian":
                val = rng.normal(component["mean"], component["std"])
            else:  # uniform
                low, high = component["bounds"]
                val = rng.uniform(low, high)

            samples.append(float(val))

        return samples

    def _sample_from_graph_structure(
        self,
        var: str,
        factual_state: Dict[str, float],
        n: int,
        rng: np.random.Generator
    ) -> List[float]:
        """Draw n values from a Gaussian whose width adapts to edge strength."""
        stats = self.agent.standardization_stats.get(var, {"mean": 0.0, "std": 1.0})
        mean = stats.get("mean", 0.0)
        std = stats.get("std", 1.0) or 1.0
        factual_val = factual_state.get(var, mean)

        parents = self.agent._get_parents(var)
        if parents:
            avg_strength = sum(abs(self.agent._edge_strength(p, var)) for p in parents) / len(parents)
            # Stronger incoming edges -> narrower distribution (down to 0.5*std).
            adaptive_std = std * (1.0 - 0.5 * min(1.0, avg_strength))
        else:
            # Exogenous variable: widen the distribution.
            adaptive_std = std * 1.5

        return self._sample_gaussian(var, factual_val, adaptive_std, n, rng)
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
class _GraphUncertaintySampler:
    """Sample graph variations for uncertainty quantification.

    Each variation re-draws every edge strength either from a supplied
    confidence interval or from a confidence-scaled perturbation of the
    current strength.
    """

    def __init__(self, agent: 'CRCAAgent'):
        """Initialize sampler with reference to agent.

        Args:
            agent: CRCAAgent instance for accessing causal graph
        """
        self.agent = agent

    def sample_graph_variations(
        self,
        n_samples: int,
        uncertainty_data: Optional[Dict[str, Any]] = None
    ) -> List[Dict[Tuple[str, str], float]]:
        """Sample alternative graph structures.

        Args:
            n_samples: Number of graph variations to sample
            uncertainty_data: Optional uncertainty data from quantify_uncertainty().
                When it contains an "edge_cis" mapping keyed "u->v", strengths
                for those edges are drawn from a truncated normal within the CI.

        Returns:
            List of graph variation dictionaries mapping (source, target) -> strength
        """
        variations = []
        # A fresh generator per call keeps results reproducible for a fixed
        # agent seed (default_rng(None) seeds from the OS when seed is unset).
        rng = np.random.default_rng(self.agent.seed)

        # Snapshot current edge strengths. Non-dict edge metadata is treated
        # as a bare strength value; anything unparseable falls back to 0.0.
        baseline_strengths: Dict[Tuple[str, str], float] = {}
        for u, targets in self.agent.causal_graph.items():
            for v, meta in targets.items():
                try:
                    baseline_strengths[(u, v)] = float(meta.get("strength", 0.0)) if isinstance(meta, dict) else float(meta)
                except Exception:
                    baseline_strengths[(u, v)] = 0.0

        # Confidence intervals, if the caller computed them.
        edge_cis = {}
        if uncertainty_data and "edge_cis" in uncertainty_data:
            edge_cis = uncertainty_data["edge_cis"]

        for _ in range(n_samples):
            variation = {}
            for (u, v), baseline_strength in baseline_strengths.items():
                edge_key = f"{u}->{v}"
                if edge_key in edge_cis:
                    # Draw from a normal centered on the CI midpoint with
                    # std ~ CI width / 4, then truncate to the CI bounds.
                    ci_lower, ci_upper = edge_cis[edge_key]
                    mean = (ci_lower + ci_upper) / 2.0
                    std = (ci_upper - ci_lower) / 4.0
                    sampled = rng.normal(mean, std)
                    sampled = max(ci_lower, min(ci_upper, sampled))
                else:
                    # Perturb around the baseline; lower edge confidence widens
                    # the perturbation. Clamp at 0 so a confidence above 2.0
                    # can never hand a negative scale to rng.normal (ValueError).
                    edge = self.agent.causal_graph.get(u, {}).get(v, {})
                    confidence = edge.get("confidence", 1.0) if isinstance(edge, dict) else 1.0
                    perturbation_std = max(0.0, 0.1 * (2.0 - confidence))
                    sampled = baseline_strength + rng.normal(0.0, perturbation_std)

                variation[(u, v)] = float(sampled)

            variations.append(variation)

        return variations
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
class _PredictionQualityAssessor:
    """Assess quality and reliability of counterfactual predictions.

    Combines three per-variable signals computed across graph variants:
    consistency (agreement between variants), confidence (causal-path
    strength/length from the interventions), and sensitivity (drift from
    the first variant's prediction).
    """

    def __init__(self, agent: 'CRCAAgent'):
        """Initialize assessor with reference to agent.

        Args:
            agent: CRCAAgent instance for accessing causal graph
        """
        self.agent = agent

    def assess_quality(
        self,
        predictions_across_variants: List[Dict[str, float]],
        factual_state: Dict[str, float],
        interventions: Dict[str, float]
    ) -> Tuple[float, Dict[str, Any]]:
        """Evaluate prediction reliability.

        Args:
            predictions_across_variants: List of predictions from different graph variants
            factual_state: Original factual state (currently unused by the metrics)
            interventions: Applied interventions

        Returns:
            Tuple of (quality_score, detailed_metrics). The score is
            0.4 * mean consistency + 0.6 * mean confidence; sensitivity is
            reported in the metrics but intentionally not folded into the score.
        """
        if not predictions_across_variants:
            return 0.0, {}

        # Union of every variable any variant predicted.
        all_vars = set()
        for pred in predictions_across_variants:
            all_vars.update(pred.keys())

        consistency_scores = self._consistency_scores(all_vars, predictions_across_variants)
        confidence_scores = self._confidence_scores(all_vars, interventions)
        sensitivity_scores = self._sensitivity_scores(all_vars, predictions_across_variants)

        # Overall quality score: weighted combination.
        overall_quality = 0.0
        if consistency_scores and confidence_scores:
            consistency_avg = float(np.mean([s["consistency"] for s in consistency_scores.values()]))
            confidence_avg = float(np.mean([s["confidence"] for s in confidence_scores.values()]))
            # Weight: consistency 40%, confidence 60%
            overall_quality = 0.4 * consistency_avg + 0.6 * confidence_avg

        metrics = {
            "consistency": consistency_scores,
            "confidence": confidence_scores,
            "sensitivity": sensitivity_scores,
            "overall_quality": overall_quality
        }

        return overall_quality, metrics

    def _consistency_scores(
        self,
        all_vars: set,
        predictions: List[Dict[str, float]]
    ) -> Dict[str, Dict[str, float]]:
        """Per-variable agreement across variants (lower variance = more consistent)."""
        scores: Dict[str, Dict[str, float]] = {}
        for var in all_vars:
            values = [pred.get(var, 0.0) for pred in predictions]
            if len(values) > 1:
                variance = float(np.var(values))
                std = float(np.std(values))
                mean_val = float(np.mean(values))
                # Coefficient of variation normalizes spread by magnitude;
                # fall back to the raw std near zero mean.
                cv = std / abs(mean_val) if abs(mean_val) > 1e-6 else std
                scores[var] = {
                    "variance": variance,
                    "std": std,
                    "mean": mean_val,
                    "coefficient_of_variation": cv,
                    "consistency": max(0.0, 1.0 - min(1.0, cv))  # 1.0 = perfect consistency
                }
            else:
                # Single variant: trivially consistent.
                scores[var] = {
                    "variance": 0.0,
                    "std": 0.0,
                    "mean": values[0] if values else 0.0,
                    "coefficient_of_variation": 0.0,
                    "consistency": 1.0
                }
        return scores

    def _confidence_scores(
        self,
        all_vars: set,
        interventions: Dict[str, float]
    ) -> Dict[str, Dict[str, float]]:
        """Per-variable confidence from causal-path strength and length."""
        scores: Dict[str, Dict[str, float]] = {}
        for var in all_vars:
            max_path_strength = 0.0
            min_path_length = float('inf')

            for interv_var in interventions.keys():
                path = self.agent.identify_causal_chain(interv_var, var)
                if path:
                    path_length = len(path) - 1
                    min_path_length = min(min_path_length, path_length)

                    # Path strength = product of absolute edge strengths.
                    path_strength = 1.0
                    for i in range(len(path) - 1):
                        u, v = path[i], path[i + 1]
                        path_strength *= abs(self.agent._edge_strength(u, v))
                    max_path_strength = max(max_path_strength, path_strength)

            # Higher path strength, shorter path -> higher confidence.
            # min_path_length stays inf when no path exists, which drives the
            # confidence to 0.0 (1/(1+inf) == 0) without producing NaN.
            path_confidence = max_path_strength * (1.0 / (1.0 + min_path_length * 0.1))
            scores[var] = {
                "path_strength": max_path_strength,
                "path_length": min_path_length if min_path_length != float('inf') else 0,
                "confidence": float(path_confidence)
            }
        return scores

    def _sensitivity_scores(
        self,
        all_vars: set,
        predictions: List[Dict[str, float]]
    ) -> Dict[str, Dict[str, float]]:
        """Per-variable drift of later variants from the first variant's prediction."""
        scores: Dict[str, Dict[str, float]] = {}
        if len(predictions) > 1:
            baseline_pred = predictions[0]
            for var in all_vars:
                baseline_val = baseline_pred.get(var, 0.0)
                perturbations = [abs(pred.get(var, 0.0) - baseline_val) for pred in predictions[1:]]
                avg_perturbation = float(np.mean(perturbations)) if perturbations else 0.0
                max_perturbation = float(max(perturbations)) if perturbations else 0.0

                # Lower perturbation = less sensitive to graph changes = better.
                scores[var] = {
                    "avg_perturbation": avg_perturbation,
                    "max_perturbation": max_perturbation,
                    "sensitivity": min(1.0, avg_perturbation / (abs(baseline_val) + 1e-6))
                }
        return scores
|
|
604
|
+
|
|
605
|
+
|
|
606
|
+
class _MetaReasoningAnalyzer:
    """Rank counterfactual scenarios by how informative they are expected to be."""

    def __init__(self, agent: 'CRCAAgent'):
        """Initialize analyzer with reference to agent.

        Args:
            agent: CRCAAgent instance for accessing causal graph
        """
        self.agent = agent

    def analyze_scenarios(
        self,
        scenarios_with_metadata: List[Tuple[CounterfactualScenario, Dict[str, Any]]]
    ) -> List[Tuple[CounterfactualScenario, float]]:
        """Score every scenario and return them ranked best-first.

        The informativeness score blends scenario relevance (50%),
        prediction reliability (30%), and the complement of graph
        uncertainty impact (20%).

        Args:
            scenarios_with_metadata: List of (scenario, metadata) tuples

        Returns:
            List of (scenario, meta_reasoning_score) tuples, sorted by
            score in descending order
        """
        ranked: List[Tuple[CounterfactualScenario, float]] = []
        for scenario, metadata in scenarios_with_metadata:
            relevance = self._calculate_relevance(scenario, metadata)
            uncertainty = self._calculate_uncertainty_impact(scenario, metadata)
            reliability = metadata.get("quality_score", 0.5)
            score = 0.5 * relevance + 0.3 * reliability + 0.2 * (1.0 - uncertainty)
            ranked.append((scenario, score))

        return sorted(ranked, key=lambda pair: pair[1], reverse=True)

    def _calculate_relevance(
        self,
        scenario: CounterfactualScenario,
        metadata: Dict[str, Any]
    ) -> float:
        """Relevance = std-normalized shift of interventions and outcomes, capped at 1."""
        factual_state = metadata.get("factual_state", {})

        def normalized_shift(assignments: Dict[str, float]) -> float:
            # Sum of |new - factual| per variable, scaled by that variable's std.
            total = 0.0
            for var_name, new_val in assignments.items():
                baseline = factual_state.get(var_name, 0.0)
                var_stats = self.agent.standardization_stats.get(var_name, {"mean": 0.0, "std": 1.0})
                scale = var_stats.get("std", 1.0) or 1.0
                delta = abs(new_val - baseline)
                total += delta / scale if scale > 0 else delta
            return total

        # Information gain: how far the interventions move from the factual.
        intervention_magnitude = normalized_shift(scenario.interventions)
        # Expected utility: how far the predicted outcomes move.
        outcome_magnitude = normalized_shift(scenario.expected_outcomes)

        # Average the two magnitudes, normalized by intervention count.
        n_vars = max(1, len(scenario.interventions))
        return min(1.0, (intervention_magnitude + outcome_magnitude) / (2.0 * n_vars))

    def _calculate_uncertainty_impact(
        self,
        scenario: CounterfactualScenario,
        metadata: Dict[str, Any]
    ) -> float:
        """Graph-uncertainty impact: mean coefficient of variation, capped at 1."""
        consistency = metadata.get("quality_metrics", {}).get("consistency", {})
        if not consistency:
            # No consistency data available: assume moderate uncertainty.
            return 0.5

        cvs = [entry.get("coefficient_of_variation", 0.0) for entry in consistency.values()]
        avg_cv = float(np.mean(cvs)) if cvs else 0.0
        return min(1.0, avg_cv)
|
|
705
|
+
|
|
706
|
+
|
|
707
|
+
class CRCAAgent(Agent):
|
|
708
|
+
"""Causal Reasoning with Counterfactual Analysis Agent.
|
|
709
|
+
|
|
710
|
+
A lightweight causal reasoning agent with LLM integration, providing both
|
|
711
|
+
LLM-based causal analysis and deterministic causal simulation. Supports
|
|
712
|
+
automatic variable extraction, causal graph management, counterfactual
|
|
713
|
+
scenario generation, and comprehensive causal analysis.
|
|
714
|
+
|
|
715
|
+
Key Features:
|
|
716
|
+
- LLM integration for sophisticated causal reasoning
|
|
717
|
+
- Dual-mode operation: LLM-based analysis and deterministic simulation
|
|
718
|
+
- Automatic variable extraction from natural language tasks
|
|
719
|
+
- Causal graph management with rustworkx backend
|
|
720
|
+
- Counterfactual scenario generation
|
|
721
|
+
- Batch prediction support
|
|
722
|
+
- Async/await support for concurrent operations
|
|
723
|
+
|
|
724
|
+
Attributes:
|
|
725
|
+
causal_graph: Dictionary representing the causal graph structure
|
|
726
|
+
causal_memory: List storing analysis steps and results
|
|
727
|
+
causal_max_loops: Maximum number of loops for causal reasoning
|
|
728
|
+
"""
|
|
729
|
+
|
|
730
|
+
def __init__(
|
|
731
|
+
self,
|
|
732
|
+
variables: Optional[List[str]] = None,
|
|
733
|
+
causal_edges: Optional[List[Tuple[str, str]]] = None,
|
|
734
|
+
max_loops: Optional[Union[int, str]] = 3,
|
|
735
|
+
agent_name: str = "cr-ca-lite-agent",
|
|
736
|
+
agent_description: str = "Lightweight Causal Reasoning with Counterfactual Analysis Agent",
|
|
737
|
+
description: Optional[str] = None,
|
|
738
|
+
model_name: str = "gpt-4o",
|
|
739
|
+
system_prompt: Optional[str] = None,
|
|
740
|
+
global_system_prompt: Optional[str] = None,
|
|
741
|
+
secondary_system_prompt: Optional[str] = None,
|
|
742
|
+
enable_batch_predict: bool = False,
|
|
743
|
+
max_batch_size: int = 32,
|
|
744
|
+
bootstrap_workers: int = 0,
|
|
745
|
+
use_async: bool = False,
|
|
746
|
+
seed: Optional[int] = None,
|
|
747
|
+
enable_excel: bool = False,
|
|
748
|
+
agent_max_loops: Optional[Union[int, str]] = None,
|
|
749
|
+
policy: Optional[Union[DoctrineV1, str]] = None,
|
|
750
|
+
ledger_path: Optional[str] = None,
|
|
751
|
+
epoch_seconds: int = 3600,
|
|
752
|
+
policy_mode: bool = False,
|
|
753
|
+
sensor_registry: Optional[Any] = None,
|
|
754
|
+
actuator_registry: Optional[Any] = None,
|
|
755
|
+
**kwargs,
|
|
756
|
+
):
|
|
757
|
+
"""
|
|
758
|
+
Initialize CRCAAgent with causal reasoning capabilities.
|
|
759
|
+
|
|
760
|
+
Args:
|
|
761
|
+
variables: List of variable names for the causal graph
|
|
762
|
+
causal_edges: List of (source, target) tuples defining causal relationships
|
|
763
|
+
max_loops: Maximum loops for causal reasoning (default: 3)
|
|
764
|
+
agent_max_loops: Maximum loops for standard Agent operations (supports "auto")
|
|
765
|
+
If not provided, defaults to 1 for individual LLM calls.
|
|
766
|
+
Pass "auto" to enable automatic loop detection for standard operations.
|
|
767
|
+
policy: Policy doctrine (DoctrineV1 instance or path to JSON file) for policy mode
|
|
768
|
+
ledger_path: Path to SQLite ledger database for event storage
|
|
769
|
+
epoch_seconds: Length of one epoch in seconds (default: 3600)
|
|
770
|
+
policy_mode: Enable temporal policy engine mode (default: False)
|
|
771
|
+
**kwargs: Additional arguments passed to parent Agent class
|
|
772
|
+
"""
|
|
773
|
+
|
|
774
|
+
cr_ca_schema = CRCAAgent._get_cr_ca_schema()
|
|
775
|
+
extract_variables_schema = CRCAAgent._get_extract_variables_schema()
|
|
776
|
+
|
|
777
|
+
# Backwards-compatible alias for description
|
|
778
|
+
agent_description = description or agent_description
|
|
779
|
+
|
|
780
|
+
# Handle max_loops for standard Agent operations
|
|
781
|
+
# agent_max_loops parameter takes precedence, then check kwargs, then default to 1
|
|
782
|
+
if agent_max_loops is None:
|
|
783
|
+
agent_max_loops = kwargs.pop("max_loops", 1)
|
|
784
|
+
else:
|
|
785
|
+
# Remove max_loops from kwargs if agent_max_loops was explicitly provided
|
|
786
|
+
kwargs.pop("max_loops", None)
|
|
787
|
+
|
|
788
|
+
# Merge tools_list_dictionary from kwargs with CRCA schema
|
|
789
|
+
# Only add CRCA schema if user hasn't explicitly disabled it
|
|
790
|
+
use_crca_tools = kwargs.pop("use_crca_tools", True) # Default to True for backwards compatibility
|
|
791
|
+
existing_tools = kwargs.pop("tools_list_dictionary", [])
|
|
792
|
+
if not isinstance(existing_tools, list):
|
|
793
|
+
existing_tools = [existing_tools] if existing_tools else []
|
|
794
|
+
|
|
795
|
+
# Only add CRCA schemas if enabled
|
|
796
|
+
if use_crca_tools:
|
|
797
|
+
tools_list = [cr_ca_schema, extract_variables_schema] + existing_tools
|
|
798
|
+
else:
|
|
799
|
+
tools_list = existing_tools
|
|
800
|
+
|
|
801
|
+
# Get existing callable tools (functions) from kwargs
|
|
802
|
+
existing_callable_tools = kwargs.pop("tools", [])
|
|
803
|
+
if not isinstance(existing_callable_tools, list):
|
|
804
|
+
existing_callable_tools = [existing_callable_tools] if existing_callable_tools else []
|
|
805
|
+
|
|
806
|
+
agent_kwargs = {
|
|
807
|
+
"agent_name": agent_name,
|
|
808
|
+
"agent_description": agent_description,
|
|
809
|
+
"model_name": model_name,
|
|
810
|
+
"max_loops": agent_max_loops, # Use user-provided value or default to 1
|
|
811
|
+
"output_type": "final",
|
|
812
|
+
**kwargs, # All other Agent parameters passed through
|
|
813
|
+
}
|
|
814
|
+
|
|
815
|
+
# Always provide tools_list_dictionary if we have schemas
|
|
816
|
+
# This ensures the LLM knows about the tools
|
|
817
|
+
if tools_list:
|
|
818
|
+
agent_kwargs["tools_list_dictionary"] = tools_list
|
|
819
|
+
|
|
820
|
+
# Add existing callable tools if any
|
|
821
|
+
if existing_callable_tools:
|
|
822
|
+
agent_kwargs["tools"] = existing_callable_tools
|
|
823
|
+
|
|
824
|
+
# Always apply default CRCA prompt as base, then add custom prompt on top if provided
|
|
825
|
+
final_system_prompt = None
|
|
826
|
+
if DEFAULT_CRCA_SYSTEM_PROMPT is not None:
|
|
827
|
+
if system_prompt is not None:
|
|
828
|
+
# Combine: default first, then custom on top
|
|
829
|
+
final_system_prompt = f"{DEFAULT_CRCA_SYSTEM_PROMPT}\n\n--- Additional Instructions ---\n{system_prompt}"
|
|
830
|
+
else:
|
|
831
|
+
# Just use default
|
|
832
|
+
final_system_prompt = DEFAULT_CRCA_SYSTEM_PROMPT
|
|
833
|
+
elif system_prompt is not None:
|
|
834
|
+
# If no default but custom prompt provided, use custom prompt
|
|
835
|
+
final_system_prompt = system_prompt
|
|
836
|
+
|
|
837
|
+
if final_system_prompt is not None:
|
|
838
|
+
agent_kwargs["system_prompt"] = final_system_prompt
|
|
839
|
+
if global_system_prompt is not None:
|
|
840
|
+
agent_kwargs["global_system_prompt"] = global_system_prompt
|
|
841
|
+
if secondary_system_prompt is not None:
|
|
842
|
+
agent_kwargs["secondary_system_prompt"] = secondary_system_prompt
|
|
843
|
+
|
|
844
|
+
super().__init__(**agent_kwargs)
|
|
845
|
+
|
|
846
|
+
# Now that self exists, create and add the CRCA tool handlers if tools are enabled
|
|
847
|
+
if use_crca_tools:
|
|
848
|
+
# Create a wrapper function with the correct name that matches the schema
|
|
849
|
+
def generate_causal_analysis(
|
|
850
|
+
causal_analysis: str,
|
|
851
|
+
intervention_planning: str,
|
|
852
|
+
counterfactual_scenarios: List[Dict[str, Any]],
|
|
853
|
+
causal_strength_assessment: str,
|
|
854
|
+
optimal_solution: str,
|
|
855
|
+
) -> Dict[str, Any]:
|
|
856
|
+
"""Tool handler for generate_causal_analysis - wrapper that calls the instance method."""
|
|
857
|
+
return self._generate_causal_analysis_handler(
|
|
858
|
+
causal_analysis=causal_analysis,
|
|
859
|
+
intervention_planning=intervention_planning,
|
|
860
|
+
counterfactual_scenarios=counterfactual_scenarios,
|
|
861
|
+
causal_strength_assessment=causal_strength_assessment,
|
|
862
|
+
optimal_solution=optimal_solution,
|
|
863
|
+
)
|
|
864
|
+
|
|
865
|
+
# Create a wrapper function for extract_causal_variables
|
|
866
|
+
def extract_causal_variables(
|
|
867
|
+
required_variables: List[str],
|
|
868
|
+
causal_edges: List[List[str]],
|
|
869
|
+
reasoning: str,
|
|
870
|
+
optional_variables: Optional[List[str]] = None,
|
|
871
|
+
counterfactual_variables: Optional[List[str]] = None,
|
|
872
|
+
) -> Dict[str, Any]:
|
|
873
|
+
"""Tool handler for extract_causal_variables - wrapper that calls the instance method."""
|
|
874
|
+
return self._extract_causal_variables_handler(
|
|
875
|
+
required_variables=required_variables,
|
|
876
|
+
causal_edges=causal_edges,
|
|
877
|
+
reasoning=reasoning,
|
|
878
|
+
optional_variables=optional_variables,
|
|
879
|
+
counterfactual_variables=counterfactual_variables,
|
|
880
|
+
)
|
|
881
|
+
|
|
882
|
+
# Add the wrapper functions to the tools list
|
|
883
|
+
# The function names must match the schema names
|
|
884
|
+
if self.tools is None:
|
|
885
|
+
self.tools = []
|
|
886
|
+
self.add_tool(generate_causal_analysis)
|
|
887
|
+
self.add_tool(extract_causal_variables)
|
|
888
|
+
|
|
889
|
+
# CRITICAL: Re-initialize tool_struct after adding tools
|
|
890
|
+
# This ensures the BaseTool instance has the updated tools and function_map
|
|
891
|
+
if hasattr(self, 'setup_tools'):
|
|
892
|
+
self.tool_struct = self.setup_tools()
|
|
893
|
+
|
|
894
|
+
# Ensure tools_list_dictionary is set with our manual schemas
|
|
895
|
+
if not self.tools_list_dictionary or len(self.tools_list_dictionary) == 0:
|
|
896
|
+
self.tools_list_dictionary = [cr_ca_schema, extract_variables_schema] + existing_tools
|
|
897
|
+
else:
|
|
898
|
+
# Replace any auto-generated schemas with our manual ones
|
|
899
|
+
our_tool_names = {"generate_causal_analysis", "extract_causal_variables"}
|
|
900
|
+
filtered = []
|
|
901
|
+
for schema in self.tools_list_dictionary:
|
|
902
|
+
if isinstance(schema, dict):
|
|
903
|
+
func_name = schema.get("function", {}).get("name", "")
|
|
904
|
+
if func_name in our_tool_names:
|
|
905
|
+
continue # Skip auto-generated, we'll add manual
|
|
906
|
+
filtered.append(schema)
|
|
907
|
+
# Add our manual schemas first, then existing tools
|
|
908
|
+
self.tools_list_dictionary = [cr_ca_schema, extract_variables_schema] + filtered
|
|
909
|
+
|
|
910
|
+
self.causal_max_loops = max_loops
|
|
911
|
+
self.causal_graph: Dict[str, Dict[str, float]] = {}
|
|
912
|
+
self.causal_graph_reverse: Dict[str, List[str]] = {} # For fast parent lookup
|
|
913
|
+
self._graph = rx.PyDiGraph()
|
|
914
|
+
self._node_to_index: Dict[str, int] = {}
|
|
915
|
+
self._index_to_node: Dict[int, str] = {}
|
|
916
|
+
|
|
917
|
+
self.standardization_stats: Dict[str, Dict[str, float]] = {}
|
|
918
|
+
self.use_nonlinear_scm: bool = True
|
|
919
|
+
self.nonlinear_activation: str = "tanh" # options: 'tanh'|'identity'
|
|
920
|
+
self.interaction_terms: Dict[str, List[Tuple[str, str]]] = {}
|
|
921
|
+
self.edge_sign_constraints: Dict[Tuple[str, str], int] = {}
|
|
922
|
+
self.bayesian_priors: Dict[Tuple[str, str], Dict[str, float]] = {}
|
|
923
|
+
self.enable_batch_predict = bool(enable_batch_predict)
|
|
924
|
+
self.max_batch_size = int(max_batch_size)
|
|
925
|
+
self.bootstrap_workers = int(max(0, bootstrap_workers))
|
|
926
|
+
self.use_async = bool(use_async)
|
|
927
|
+
self.seed = seed if seed is not None else 42
|
|
928
|
+
self._rng = np.random.default_rng(self.seed)
|
|
929
|
+
|
|
930
|
+
if variables:
|
|
931
|
+
for var in variables:
|
|
932
|
+
self._ensure_node_exists(var)
|
|
933
|
+
|
|
934
|
+
if causal_edges:
|
|
935
|
+
for source, target in causal_edges:
|
|
936
|
+
self.add_causal_relationship(source, target)
|
|
937
|
+
|
|
938
|
+
self.causal_memory: List[Dict[str, Any]] = []
|
|
939
|
+
self._prediction_cache: Dict[Tuple[Tuple[Tuple[str, float], ...], Tuple[Tuple[str, float], ...]], Dict[str, float]] = {}
|
|
940
|
+
self._prediction_cache_order: List[Tuple[Tuple[Tuple[str, float], ...], Tuple[Tuple[str, float], ...]]] = []
|
|
941
|
+
self._prediction_cache_max: int = 1000
|
|
942
|
+
self._cache_enabled: bool = True
|
|
943
|
+
self._prediction_cache_lock = threading.Lock()
|
|
944
|
+
|
|
945
|
+
# Policy engine integration
|
|
946
|
+
self.policy_mode = bool(policy_mode)
|
|
947
|
+
self.policy: Optional[DoctrineV1] = None
|
|
948
|
+
self.ledger: Optional[Ledger] = None
|
|
949
|
+
self.epoch_seconds = int(epoch_seconds)
|
|
950
|
+
self.policy_loop: Optional[PolicyLoopMixin] = None
|
|
951
|
+
|
|
952
|
+
if self.policy_mode:
|
|
953
|
+
if not POLICY_ENGINE_AVAILABLE:
|
|
954
|
+
raise ImportError(
|
|
955
|
+
"Policy engine modules not available. "
|
|
956
|
+
"Ensure schemas/policy.py, utils/ledger.py, and templates/policy_loop.py exist."
|
|
957
|
+
)
|
|
958
|
+
|
|
959
|
+
# Load policy
|
|
960
|
+
if policy is None:
|
|
961
|
+
raise ValueError("policy must be provided when policy_mode=True")
|
|
962
|
+
|
|
963
|
+
if isinstance(policy, str):
|
|
964
|
+
# Load from JSON file
|
|
965
|
+
self.policy = DoctrineV1.from_json(policy)
|
|
966
|
+
elif isinstance(policy, DoctrineV1):
|
|
967
|
+
self.policy = policy
|
|
968
|
+
else:
|
|
969
|
+
raise TypeError(f"policy must be DoctrineV1 or str (JSON path), got {type(policy)}")
|
|
970
|
+
|
|
971
|
+
# Initialize ledger
|
|
972
|
+
if ledger_path is None:
|
|
973
|
+
ledger_path = "crca_ledger.db"
|
|
974
|
+
self.ledger = Ledger(ledger_path)
|
|
975
|
+
|
|
976
|
+
# Initialize policy loop mixin
|
|
977
|
+
self.policy_loop = PolicyLoopMixin(
|
|
978
|
+
doctrine=self.policy,
|
|
979
|
+
ledger=self.ledger,
|
|
980
|
+
seed=self.seed,
|
|
981
|
+
sensor_registry=sensor_registry,
|
|
982
|
+
actuator_registry=actuator_registry
|
|
983
|
+
)
|
|
984
|
+
|
|
985
|
+
logger.info(f"Policy mode enabled: doctrine={self.policy.version}, ledger={ledger_path}")
|
|
986
|
+
|
|
987
|
+
# Excel TUI integration
|
|
988
|
+
self._excel_enabled = bool(enable_excel)
|
|
989
|
+
self._excel_tables = None
|
|
990
|
+
self._excel_eval_engine = None
|
|
991
|
+
self._excel_scm_bridge = None
|
|
992
|
+
self._excel_dependency_graph = None
|
|
993
|
+
|
|
994
|
+
if enable_excel:
|
|
995
|
+
try:
|
|
996
|
+
from crca_excel.core.tables import TableManager
|
|
997
|
+
from crca_excel.core.deps import DependencyGraph
|
|
998
|
+
from crca_excel.core.eval import EvaluationEngine
|
|
999
|
+
from crca_excel.core.scm import SCMBridge
|
|
1000
|
+
|
|
1001
|
+
self._excel_tables = TableManager()
|
|
1002
|
+
self._excel_dependency_graph = DependencyGraph()
|
|
1003
|
+
self._excel_eval_engine = EvaluationEngine(
|
|
1004
|
+
self._excel_tables,
|
|
1005
|
+
self._excel_dependency_graph,
|
|
1006
|
+
max_iter=self.causal_max_loops if isinstance(self.causal_max_loops, int) else 100,
|
|
1007
|
+
epsilon=1e-6
|
|
1008
|
+
)
|
|
1009
|
+
self._excel_scm_bridge = SCMBridge(
|
|
1010
|
+
self._excel_tables,
|
|
1011
|
+
self._excel_eval_engine,
|
|
1012
|
+
crca_agent=self
|
|
1013
|
+
)
|
|
1014
|
+
|
|
1015
|
+
# Initialize standard tables
|
|
1016
|
+
self._initialize_excel_tables()
|
|
1017
|
+
|
|
1018
|
+
# Link causal graph to tables
|
|
1019
|
+
self._link_causal_graph_to_tables()
|
|
1020
|
+
except ImportError as e:
|
|
1021
|
+
logger.warning(f"Excel TUI modules not available: {e}")
|
|
1022
|
+
self._excel_enabled = False
|
|
1023
|
+
|
|
1024
|
+
@staticmethod
|
|
1025
|
+
def _get_cr_ca_schema() -> Dict[str, Any]:
|
|
1026
|
+
|
|
1027
|
+
return {
|
|
1028
|
+
"type": "function",
|
|
1029
|
+
"function": {
|
|
1030
|
+
"name": "generate_causal_analysis",
|
|
1031
|
+
"description": "Generates structured causal reasoning and counterfactual analysis",
|
|
1032
|
+
"parameters": {
|
|
1033
|
+
"type": "object",
|
|
1034
|
+
"properties": {
|
|
1035
|
+
"causal_analysis": {
|
|
1036
|
+
"type": "string",
|
|
1037
|
+
"description": "Analysis of causal relationships and mechanisms"
|
|
1038
|
+
},
|
|
1039
|
+
"intervention_planning": {
|
|
1040
|
+
"type": "string",
|
|
1041
|
+
"description": "Planned interventions to test causal hypotheses"
|
|
1042
|
+
},
|
|
1043
|
+
"counterfactual_scenarios": {
|
|
1044
|
+
"type": "array",
|
|
1045
|
+
"items": {
|
|
1046
|
+
"type": "object",
|
|
1047
|
+
"properties": {
|
|
1048
|
+
"scenario_name": {"type": "string"},
|
|
1049
|
+
"interventions": {"type": "object"},
|
|
1050
|
+
"expected_outcomes": {"type": "object"},
|
|
1051
|
+
"reasoning": {"type": "string"}
|
|
1052
|
+
}
|
|
1053
|
+
},
|
|
1054
|
+
"description": "Multiple counterfactual scenarios to explore"
|
|
1055
|
+
},
|
|
1056
|
+
"causal_strength_assessment": {
|
|
1057
|
+
"type": "string",
|
|
1058
|
+
"description": "Assessment of causal relationship strengths and confounders"
|
|
1059
|
+
},
|
|
1060
|
+
"optimal_solution": {
|
|
1061
|
+
"type": "string",
|
|
1062
|
+
"description": "Recommended optimal solution based on causal analysis"
|
|
1063
|
+
}
|
|
1064
|
+
},
|
|
1065
|
+
"required": [
|
|
1066
|
+
"causal_analysis",
|
|
1067
|
+
"intervention_planning",
|
|
1068
|
+
"counterfactual_scenarios",
|
|
1069
|
+
"causal_strength_assessment",
|
|
1070
|
+
"optimal_solution"
|
|
1071
|
+
]
|
|
1072
|
+
}
|
|
1073
|
+
}
|
|
1074
|
+
}
|
|
1075
|
+
|
|
1076
|
+
@staticmethod
|
|
1077
|
+
def _get_extract_variables_schema() -> Dict[str, Any]:
|
|
1078
|
+
"""
|
|
1079
|
+
Get the schema for the extract_causal_variables tool.
|
|
1080
|
+
|
|
1081
|
+
Returns:
|
|
1082
|
+
Dictionary containing the OpenAI function schema for variable extraction
|
|
1083
|
+
"""
|
|
1084
|
+
return {
|
|
1085
|
+
"type": "function",
|
|
1086
|
+
"function": {
|
|
1087
|
+
"name": "extract_causal_variables",
|
|
1088
|
+
"description": "Extract and propose causal variables, relationships, and counterfactual scenarios needed for causal analysis",
|
|
1089
|
+
"parameters": {
|
|
1090
|
+
"type": "object",
|
|
1091
|
+
"properties": {
|
|
1092
|
+
"required_variables": {
|
|
1093
|
+
"type": "array",
|
|
1094
|
+
"items": {"type": "string"},
|
|
1095
|
+
"description": "Core variables that must be included for causal analysis"
|
|
1096
|
+
},
|
|
1097
|
+
"optional_variables": {
|
|
1098
|
+
"type": "array",
|
|
1099
|
+
"items": {"type": "string"},
|
|
1100
|
+
"description": "Additional variables that may be useful but not essential"
|
|
1101
|
+
},
|
|
1102
|
+
"causal_edges": {
|
|
1103
|
+
"type": "array",
|
|
1104
|
+
"items": {
|
|
1105
|
+
"type": "array",
|
|
1106
|
+
"items": {"type": "string"},
|
|
1107
|
+
"minItems": 2,
|
|
1108
|
+
"maxItems": 2
|
|
1109
|
+
},
|
|
1110
|
+
"description": "Causal relationships as [source, target] pairs"
|
|
1111
|
+
},
|
|
1112
|
+
"counterfactual_variables": {
|
|
1113
|
+
"type": "array",
|
|
1114
|
+
"items": {"type": "string"},
|
|
1115
|
+
"description": "Variables to explore in counterfactual scenarios"
|
|
1116
|
+
},
|
|
1117
|
+
"reasoning": {
|
|
1118
|
+
"type": "string",
|
|
1119
|
+
"description": "Explanation of why these variables and relationships are needed"
|
|
1120
|
+
}
|
|
1121
|
+
},
|
|
1122
|
+
"required": ["required_variables", "causal_edges", "reasoning"]
|
|
1123
|
+
}
|
|
1124
|
+
}
|
|
1125
|
+
}
|
|
1126
|
+
|
|
1127
|
+
def step(self, task: str) -> str:
    """
    Execute a single step of causal reasoning.

    Args:
        task: Task string to process

    Returns:
        Response string from the agent
    """
    # Delegate one reasoning pass directly to the parent Agent implementation.
    return super().run(task)
|
|
1139
|
+
|
|
1140
|
+
def _generate_causal_analysis_handler(
    self,
    causal_analysis: str,
    intervention_planning: str,
    counterfactual_scenarios: List[Dict[str, Any]],
    causal_strength_assessment: str,
    optimal_solution: str,
) -> Dict[str, Any]:
    """
    Handler function for the generate_causal_analysis tool.

    This function is called when the LLM invokes the generate_causal_analysis tool.
    It stores the analysis in causal memory and returns a structured summary.

    Args:
        causal_analysis: Analysis of causal relationships and mechanisms
        intervention_planning: Planned interventions to test causal hypotheses
        counterfactual_scenarios: List of counterfactual scenarios
        causal_strength_assessment: Assessment of causal relationship strengths
        optimal_solution: Recommended optimal solution

    Returns:
        Dictionary with processed results
    """
    logger.info("Processing causal analysis from tool call")

    # Store the analysis in causal memory
    analysis_entry = {
        'type': 'analysis',  # Mark this as an analysis entry
        'causal_analysis': causal_analysis,
        'intervention_planning': intervention_planning,
        'counterfactual_scenarios': counterfactual_scenarios,
        'causal_strength_assessment': causal_strength_assessment,
        'optimal_solution': optimal_solution,
        # 'timestamp' is the entry's position index in causal_memory,
        # not wall-clock time.
        'timestamp': len(self.causal_memory)
    }

    self.causal_memory.append(analysis_entry)

    # NOTE(review): a future version could parse `causal_analysis` to extract
    # relationships into the causal graph. The previous implementation wrapped
    # a bare `pass` in try/except here; that dead block (it could never raise)
    # has been removed.

    # Return a structured response
    return {
        "status": "success",
        "message": "Causal analysis processed and stored",
        "analysis_summary": {
            "causal_analysis_length": len(causal_analysis),
            "num_scenarios": len(counterfactual_scenarios),
            "has_optimal_solution": bool(optimal_solution)
        }
    }
|
|
1199
|
+
|
|
1200
|
+
def _extract_causal_variables_handler(
    self,
    required_variables: List[str],
    causal_edges: List[List[str]],
    reasoning: str,
    optional_variables: Optional[List[str]] = None,
    counterfactual_variables: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """
    Handler function for the extract_causal_variables tool.

    This function is called when the LLM invokes the extract_causal_variables tool.
    It processes the tool's output and adds variables and edges to the causal graph.

    Args:
        required_variables: Core variables that must be included for causal analysis
        causal_edges: Causal relationships as [source, target] pairs
        reasoning: Explanation of why these variables and relationships are needed
        optional_variables: Additional variables that may be useful but not essential
        counterfactual_variables: Variables to explore in counterfactual scenarios

    Returns:
        Dictionary with processed results and summary of what was added
    """
    import traceback

    try:
        logger.info("=== EXTRACT HANDLER CALLED ===")
        logger.info("Processing variable extraction from tool call")
        logger.info(f"Received required_variables: {required_variables} (type: {type(required_variables)})")
        logger.info(f"Received causal_edges: {causal_edges} (type: {type(causal_edges)})")
        logger.info(f"Received optional_variables: {optional_variables}")
        logger.info(f"Received counterfactual_variables: {counterfactual_variables}")

        # Coerce malformed inputs to lists so the loops below are always safe.
        if not required_variables:
            logger.warning("No required_variables provided!")
            required_variables = []
        if not isinstance(required_variables, list):
            logger.warning(f"required_variables is not a list: {type(required_variables)}")
            required_variables = [str(required_variables)] if required_variables else []

        if not causal_edges:
            logger.warning("No causal_edges provided!")
            causal_edges = []
        if not isinstance(causal_edges, list):
            logger.warning(f"causal_edges is not a list: {type(causal_edges)}")
            causal_edges = []

        # Track what was added during this call.
        added_variables = []
        added_edges = []
        skipped_edges = []

        def register_variable(var):
            """Create a graph node for the stripped name and record it once."""
            if var and str(var).strip():
                var_clean = str(var).strip()
                if var_clean not in self.causal_graph:
                    self._ensure_node_exists(var_clean)
                    if var_clean not in added_variables:
                        added_variables.append(var_clean)

        # Required, optional and counterfactual variables all become graph
        # nodes; previously three near-identical loops, now one shared helper.
        for var in required_variables:
            register_variable(var)
        for var in optional_variables or []:
            register_variable(var)
        for var in counterfactual_variables or []:
            register_variable(var)

        # Add causal edges, creating missing endpoint nodes on the fly.
        for edge in causal_edges:
            if isinstance(edge, (list, tuple)) and len(edge) >= 2:
                source = str(edge[0]).strip() if edge[0] else None
                target = str(edge[1]).strip() if edge[1] else None

                if source and target:
                    # Ensure both endpoints exist before wiring the edge.
                    register_variable(source)
                    register_variable(target)
                    try:
                        self.add_causal_relationship(source, target)
                        added_edges.append((source, target))
                    except Exception as e:
                        logger.warning(f"Failed to add edge {source} -> {target}: {e}")
                        skipped_edges.append((source, target))
                else:
                    # Empty source or target after stripping.
                    logger.warning(f"Invalid edge format: {edge}")
                    skipped_edges.append(edge)
            else:
                # Not a 2-element [source, target] pair.
                logger.warning(f"Invalid edge format: {edge}")
                skipped_edges.append(edge)

        # Store extraction metadata in causal memory for later context building.
        extraction_entry = {
            'type': 'variable_extraction',
            'required_variables': required_variables,
            'optional_variables': optional_variables or [],
            'counterfactual_variables': counterfactual_variables or [],
            'causal_edges': causal_edges,
            'reasoning': reasoning,
            'added_variables': added_variables,
            'added_edges': added_edges,
            'skipped_edges': skipped_edges,
            # Position index in causal_memory, not wall-clock time.
            'timestamp': len(self.causal_memory)
        }

        self.causal_memory.append(extraction_entry)

        # Log what was actually added
        logger.info(f"Extraction complete: Added {len(added_variables)} variables, {len(added_edges)} edges")
        logger.info(f"Added variables: {added_variables}")
        logger.info(f"Added edges: {added_edges}")
        logger.info(f"Current graph size: {len(self.causal_graph)} variables, {sum(len(children) for children in self.causal_graph.values())} edges")

        # Return a structured response
        result = {
            "status": "success",
            "message": "Variables and relationships extracted and added to causal graph",
            "summary": {
                "variables_added": len(added_variables),
                "edges_added": len(added_edges),
                "edges_skipped": len(skipped_edges),
                "total_variables_in_graph": len(self.causal_graph),
                "total_edges_in_graph": sum(len(children) for children in self.causal_graph.values())
            },
            "details": {
                "added_variables": added_variables,
                "added_edges": added_edges,
                "skipped_edges": skipped_edges if skipped_edges else None
            }
        }
        logger.info(f"Returning result: {result}")
        return result

    except Exception as e:
        logger.error(f"ERROR in _extract_causal_variables_handler: {e}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        # Return error but don't fail completely
        return {
            "status": "error",
            "message": f"Error processing variable extraction: {str(e)}",
            "summary": {
                "variables_added": 0,
                "edges_added": 0,
            }
        }
|
|
1363
|
+
|
|
1364
|
+
def _build_causal_prompt(self, task: str) -> str:
|
|
1365
|
+
return (
|
|
1366
|
+
"You are a Causal Reasoning with Counterfactual Analysis (CR-CA) agent.\n"
|
|
1367
|
+
f"Problem: {task}\n"
|
|
1368
|
+
f"Current causal graph has {len(self.causal_graph)} variables and "
|
|
1369
|
+
f"{sum(len(children) for children in self.causal_graph.values())} relationships.\n\n"
|
|
1370
|
+
"CRITICAL: You MUST use the generate_causal_analysis tool to provide your analysis.\n"
|
|
1371
|
+
"The variables have already been extracted. Now you need to generate the causal analysis.\n"
|
|
1372
|
+
"Do NOT call extract_causal_variables again. You MUST call generate_causal_analysis with:\n"
|
|
1373
|
+
"- causal_analysis: Detailed analysis of causal relationships and mechanisms\n"
|
|
1374
|
+
"- intervention_planning: Planned interventions to test causal hypotheses\n"
|
|
1375
|
+
"- counterfactual_scenarios: List of what-if scenarios (array of objects)\n"
|
|
1376
|
+
"- causal_strength_assessment: Assessment of relationship strengths and confounders\n"
|
|
1377
|
+
"- optimal_solution: Recommended solution based on analysis\n"
|
|
1378
|
+
)
|
|
1379
|
+
|
|
1380
|
+
def _build_variable_extraction_prompt(self, task: str) -> str:
|
|
1381
|
+
"""
|
|
1382
|
+
Build a prompt that guides the LLM to extract variables from a task.
|
|
1383
|
+
|
|
1384
|
+
Args:
|
|
1385
|
+
task: The task string to analyze
|
|
1386
|
+
|
|
1387
|
+
Returns:
|
|
1388
|
+
Formatted prompt string for variable extraction
|
|
1389
|
+
"""
|
|
1390
|
+
return (
|
|
1391
|
+
"You are analyzing a causal reasoning task. The causal graph is currently empty.\n"
|
|
1392
|
+
f"Task: {task}\n\n"
|
|
1393
|
+
"**CRITICAL: You MUST use the extract_causal_variables tool to proceed.**\n"
|
|
1394
|
+
"Do NOT just describe what variables might be needed - you MUST call the tool.\n\n"
|
|
1395
|
+
"Call the extract_causal_variables tool with:\n"
|
|
1396
|
+
"1. required_variables: List of core variables needed for causal analysis\n"
|
|
1397
|
+
"2. causal_edges: List of [source, target] pairs showing causal relationships\n"
|
|
1398
|
+
"3. reasoning: Explanation of why these variables are needed\n"
|
|
1399
|
+
"4. optional_variables: (optional) Additional variables that may be useful\n"
|
|
1400
|
+
"5. counterfactual_variables: (optional) Variables to explore in what-if scenarios\n\n"
|
|
1401
|
+
"Example: For a pricing task, you might extract:\n"
|
|
1402
|
+
"- required_variables: ['price', 'demand', 'supply', 'cost']\n"
|
|
1403
|
+
"- causal_edges: [['price', 'demand'], ['cost', 'price'], ['supply', 'price']]\n"
|
|
1404
|
+
"- reasoning: 'Price affects demand, cost affects price, supply affects price'\n\n"
|
|
1405
|
+
"You can call the tool multiple times to refine your extraction.\n"
|
|
1406
|
+
"Be thorough - extract all variables and relationships implied by the task."
|
|
1407
|
+
)
|
|
1408
|
+
|
|
1409
|
+
def _build_memory_context(self) -> str:
|
|
1410
|
+
"""Build memory context from causal_memory, handling different memory entry structures."""
|
|
1411
|
+
context_parts = []
|
|
1412
|
+
for step in self.causal_memory[-2:]: # Last 2 steps
|
|
1413
|
+
if isinstance(step, dict):
|
|
1414
|
+
# Handle standard analysis step structure
|
|
1415
|
+
if 'step' in step and 'analysis' in step:
|
|
1416
|
+
context_parts.append(f"Step {step['step']}: {step['analysis']}")
|
|
1417
|
+
# Handle extraction entry structure
|
|
1418
|
+
elif 'type' in step and step.get('type') == 'extraction':
|
|
1419
|
+
context_parts.append(f"Variable Extraction: {step.get('summary', 'Variables extracted')}")
|
|
1420
|
+
# Handle generic entry
|
|
1421
|
+
elif 'analysis' in step:
|
|
1422
|
+
context_parts.append(f"Analysis: {step['analysis']}")
|
|
1423
|
+
elif 'summary' in step:
|
|
1424
|
+
context_parts.append(f"Summary: {step['summary']}")
|
|
1425
|
+
return "\n".join(context_parts) if context_parts else ""
|
|
1426
|
+
|
|
1427
|
+
def _synthesize_causal_analysis(self, task: str) -> str:
|
|
1428
|
+
"""
|
|
1429
|
+
Synthesize a final causal analysis report from the analysis steps.
|
|
1430
|
+
|
|
1431
|
+
Uses direct LLM call to avoid Agent tool execution issues.
|
|
1432
|
+
|
|
1433
|
+
Args:
|
|
1434
|
+
task: Original task string
|
|
1435
|
+
|
|
1436
|
+
Returns:
|
|
1437
|
+
Synthesized causal analysis report
|
|
1438
|
+
"""
|
|
1439
|
+
synthesis_prompt = f"Based on the causal analysis steps performed, synthesize a concise causal report for: {task}"
|
|
1440
|
+
try:
|
|
1441
|
+
# Use direct LLM call to avoid tool execution errors
|
|
1442
|
+
response = self._call_llm_directly(synthesis_prompt)
|
|
1443
|
+
return str(response) if response else "Analysis synthesis failed"
|
|
1444
|
+
except Exception as e:
|
|
1445
|
+
logger.error(f"Error synthesizing causal analysis: {e}")
|
|
1446
|
+
return "Analysis synthesis failed"
|
|
1447
|
+
|
|
1448
|
+
def _should_trigger_causal_analysis(self, task: Optional[Union[str, Any]]) -> bool:
|
|
1449
|
+
"""
|
|
1450
|
+
Automatically detect if a task should trigger causal analysis.
|
|
1451
|
+
|
|
1452
|
+
This method analyzes the task content to determine if it requires
|
|
1453
|
+
causal reasoning, counterfactual analysis, or relationship analysis.
|
|
1454
|
+
|
|
1455
|
+
Args:
|
|
1456
|
+
task: The task string to analyze
|
|
1457
|
+
|
|
1458
|
+
Returns:
|
|
1459
|
+
True if causal analysis should be triggered, False otherwise
|
|
1460
|
+
"""
|
|
1461
|
+
if task is None:
|
|
1462
|
+
return False
|
|
1463
|
+
|
|
1464
|
+
if not isinstance(task, str):
|
|
1465
|
+
return False
|
|
1466
|
+
|
|
1467
|
+
task_lower = task.lower().strip()
|
|
1468
|
+
|
|
1469
|
+
# Keywords that indicate causal analysis is needed
|
|
1470
|
+
causal_keywords = [
|
|
1471
|
+
# Causal relationship terms
|
|
1472
|
+
'causal', 'causality', 'cause', 'causes', 'caused by', 'causing',
|
|
1473
|
+
'relationship', 'relationships', 'relate', 'relates', 'related',
|
|
1474
|
+
'influence', 'influences', 'influenced', 'affect', 'affects', 'affected',
|
|
1475
|
+
'impact', 'impacts', 'impacted', 'effect', 'effects',
|
|
1476
|
+
'depend', 'depends', 'dependency', 'dependencies',
|
|
1477
|
+
'correlation', 'correlate', 'correlates',
|
|
1478
|
+
|
|
1479
|
+
# Counterfactual analysis terms
|
|
1480
|
+
'counterfactual', 'counterfactuals', 'what if', 'what-if',
|
|
1481
|
+
'scenario', 'scenarios', 'alternative', 'alternatives',
|
|
1482
|
+
'hypothetical', 'hypothesis', 'hypotheses',
|
|
1483
|
+
'if then', 'if-then', 'suppose', 'assuming',
|
|
1484
|
+
|
|
1485
|
+
# Prediction and forecasting terms (often need causal reasoning)
|
|
1486
|
+
'predict', 'prediction', 'forecast', 'forecasting', 'project', 'projection',
|
|
1487
|
+
'expected', 'expect', 'expectation', 'estimate', 'estimation',
|
|
1488
|
+
'future', 'future value', 'future price', 'in 24 months', 'in X months',
|
|
1489
|
+
'will be', 'would be', 'could be', 'might be',
|
|
1490
|
+
|
|
1491
|
+
# Analysis terms
|
|
1492
|
+
'analyze', 'analysis', 'analyzing', 'analyze the', 'analyze how',
|
|
1493
|
+
'understand', 'understanding', 'explain', 'explanation',
|
|
1494
|
+
'reasoning', 'reason', 'rationale',
|
|
1495
|
+
|
|
1496
|
+
# Relationship-specific terms
|
|
1497
|
+
'between', 'among', 'link', 'links', 'connection', 'connections',
|
|
1498
|
+
'chain', 'chains', 'path', 'paths', 'flow', 'flows',
|
|
1499
|
+
|
|
1500
|
+
# Intervention terms
|
|
1501
|
+
'intervention', 'interventions', 'change', 'changes', 'modify', 'modifies',
|
|
1502
|
+
'adjust', 'adjusts', 'alter', 'alters',
|
|
1503
|
+
|
|
1504
|
+
# Risk and consequence terms (NEW)
|
|
1505
|
+
'risk', 'risks', 'consequence', 'consequences', 'benefit', 'benefits',
|
|
1506
|
+
'trade-off', 'trade-offs', 'tradeoff', 'tradeoffs',
|
|
1507
|
+
'downside', 'downsides', 'upside', 'upsides',
|
|
1508
|
+
|
|
1509
|
+
# Determination and outcome terms (NEW)
|
|
1510
|
+
'determine', 'determines', 'determining', 'determination',
|
|
1511
|
+
'result', 'results', 'resulting', 'outcome', 'outcomes',
|
|
1512
|
+
'consideration', 'considerations', 'factor', 'factors',
|
|
1513
|
+
'driver', 'drivers', 'driving', 'drives', 'driven',
|
|
1514
|
+
'lead to', 'leads to', 'leading to', 'led to',
|
|
1515
|
+
|
|
1516
|
+
# Decision-making terms (NEW)
|
|
1517
|
+
'should', 'should we', 'should i', 'should they',
|
|
1518
|
+
'better', 'best', 'worse', 'worst', 'compare', 'comparison',
|
|
1519
|
+
'option', 'options', 'choice', 'choices', 'choose',
|
|
1520
|
+
'strategy', 'strategies', 'approach', 'approaches',
|
|
1521
|
+
'decision', 'decisions', 'decide',
|
|
1522
|
+
|
|
1523
|
+
# Importance and consideration terms (NEW)
|
|
1524
|
+
'important', 'importance', 'matter', 'matters', 'mattering',
|
|
1525
|
+
'consider', 'considering', 'consideration', 'considerations',
|
|
1526
|
+
'key', 'keys', 'critical', 'crucial', 'essential',
|
|
1527
|
+
]
|
|
1528
|
+
|
|
1529
|
+
# Check if task contains causal keywords
|
|
1530
|
+
for keyword in causal_keywords:
|
|
1531
|
+
if keyword in task_lower:
|
|
1532
|
+
return True
|
|
1533
|
+
|
|
1534
|
+
# Check for questions about variables in the causal graph
|
|
1535
|
+
# If the task mentions any of our variables, it's likely a causal question
|
|
1536
|
+
graph_variables = [var.lower() for var in self.causal_graph.keys()]
|
|
1537
|
+
for var in graph_variables:
|
|
1538
|
+
if var in task_lower:
|
|
1539
|
+
return True
|
|
1540
|
+
|
|
1541
|
+
# Check for patterns that suggest causal reasoning
|
|
1542
|
+
causal_patterns = [
|
|
1543
|
+
'how does', 'how do', 'how will', 'how would', 'how might', 'how can',
|
|
1544
|
+
'why does', 'why do', 'why will', 'why would', 'why did',
|
|
1545
|
+
'what happens if', 'what would happen', 'what will happen', 'what happens when',
|
|
1546
|
+
'if we', 'if you', 'if they', 'if it', 'if this', 'if that',
|
|
1547
|
+
'what should', 'what should we', 'what should i',
|
|
1548
|
+
'which is', 'which are', 'which would', 'which will',
|
|
1549
|
+
'what results', 'what results from', 'what comes', 'what comes next',
|
|
1550
|
+
'what leads', 'what leads to', 'what follows', 'what follows from',
|
|
1551
|
+
]
|
|
1552
|
+
|
|
1553
|
+
for pattern in causal_patterns:
|
|
1554
|
+
if pattern in task_lower:
|
|
1555
|
+
return True
|
|
1556
|
+
|
|
1557
|
+
return False
|
|
1558
|
+
def _ensure_node_exists(self, node: str) -> None:
|
|
1559
|
+
|
|
1560
|
+
if node not in self.causal_graph:
|
|
1561
|
+
self.causal_graph[node] = {}
|
|
1562
|
+
if node not in self.causal_graph_reverse:
|
|
1563
|
+
self.causal_graph_reverse[node] = []
|
|
1564
|
+
try:
|
|
1565
|
+
self._ensure_node_index(node)
|
|
1566
|
+
except Exception:
|
|
1567
|
+
pass
|
|
1568
|
+
|
|
1569
|
+
def add_causal_relationship(
    self,
    source: str,
    target: str,
    strength: float = 1.0,
    relation_type: CausalRelationType = CausalRelationType.DIRECT,
    confidence: float = 1.0
) -> None:
    """Add (or update) a directed causal edge ``source -> target``.

    The dict-based graphs (``causal_graph`` / ``causal_graph_reverse``) are
    the source of truth; the rustworkx mirror is synchronized best-effort
    and any failure there is logged and otherwise ignored.

    Args:
        source: Name of the cause variable.
        target: Name of the effect variable.
        strength: Linear edge coefficient stored under ``"strength"``.
        relation_type: Relation kind (enum value is stored as its ``.value``,
            anything else as ``str``).
        confidence: Confidence in the edge.
    """
    # Single hoisted logger instead of re-importing logging per call site.
    import logging
    log = logging.getLogger(__name__)

    self._ensure_node_exists(source)
    self._ensure_node_exists(target)

    meta = {
        "strength": float(strength),
        "relation_type": relation_type.value if isinstance(relation_type, Enum) else str(relation_type),
        "confidence": float(confidence),
    }

    self.causal_graph.setdefault(source, {})[target] = meta

    if source not in self.causal_graph_reverse.get(target, []):
        self.causal_graph_reverse.setdefault(target, []).append(source)

    # Best-effort mirror into the rustworkx graph.
    try:
        u_idx = self._ensure_node_index(source)
        v_idx = self._ensure_node_index(target)
        try:
            existing = self._graph.get_edge_data(u_idx, v_idx)
        except Exception:
            existing = None

        if existing is None:
            try:
                self._graph.add_edge(u_idx, v_idx, meta)
            except Exception:
                log.warning(
                    f"rustworkx.add_edge failed for {source}->{target}; continuing with dict-only graph."
                )
        elif isinstance(existing, dict):
            # The stored payload is the shared metadata dict: update in place.
            try:
                existing.update(meta)
                log.debug(f"Updated rustworkx edge data for {source}->{target} in-place.")
            except Exception:
                log.warning(
                    f"Failed updating rustworkx edge for {source}->{target}; continuing with dict-only graph."
                )
        else:
            # Non-dict payload: replace the edge wholesale with the new metadata.
            try:
                try:
                    edge_idx = self._graph.get_edge_index(u_idx, v_idx)
                except Exception:
                    edge_idx = None
                if edge_idx is not None and edge_idx >= 0:
                    try:
                        self._graph.remove_edge(edge_idx)
                        self._graph.add_edge(u_idx, v_idx, meta)
                        log.debug(
                            f"Replaced rustworkx edge for {source}->{target} with updated metadata."
                        )
                    except Exception:
                        log.warning(
                            f"Could not replace rustworkx edge for {source}->{target}; keeping dict-only metadata."
                        )
                else:
                    log.debug(
                        f"rustworkx edge exists but index lookup failed for {source}->{target}; dict metadata used."
                    )
            except Exception:
                log.warning(
                    f"Failed updating rustworkx edge for {source}->{target}; continuing with dict-only graph."
                )
    except Exception:
        log.warning(
            "rustworkx operation failed during add_causal_relationship; continuing with dict-only graph."
        )
|
|
1670
|
+
|
|
1671
|
+
def _get_parents(self, node: str) -> List[str]:
|
|
1672
|
+
|
|
1673
|
+
return self.causal_graph_reverse.get(node, [])
|
|
1674
|
+
|
|
1675
|
+
def _get_children(self, node: str) -> List[str]:
|
|
1676
|
+
|
|
1677
|
+
return list(self.causal_graph.get(node, {}).keys())
|
|
1678
|
+
|
|
1679
|
+
def _ensure_node_index(self, name: str) -> int:
|
|
1680
|
+
|
|
1681
|
+
if name in self._node_to_index:
|
|
1682
|
+
return self._node_to_index[name]
|
|
1683
|
+
idx = self._graph.add_node(name)
|
|
1684
|
+
self._node_to_index[name] = idx
|
|
1685
|
+
self._index_to_node[idx] = name
|
|
1686
|
+
return idx
|
|
1687
|
+
|
|
1688
|
+
def _node_index(self, name: str) -> Optional[int]:
|
|
1689
|
+
|
|
1690
|
+
return self._node_to_index.get(name)
|
|
1691
|
+
|
|
1692
|
+
def _node_name(self, idx: int) -> Optional[str]:
|
|
1693
|
+
|
|
1694
|
+
return self._index_to_node.get(idx)
|
|
1695
|
+
|
|
1696
|
+
def _edge_strength(self, source: str, target: str) -> float:
|
|
1697
|
+
|
|
1698
|
+
edge = self.causal_graph.get(source, {}).get(target, None)
|
|
1699
|
+
if isinstance(edge, dict):
|
|
1700
|
+
return float(edge.get("strength", 0.0))
|
|
1701
|
+
try:
|
|
1702
|
+
return float(edge) if edge is not None else 0.0
|
|
1703
|
+
except Exception:
|
|
1704
|
+
return 0.0
|
|
1705
|
+
|
|
1706
|
+
def _topological_sort(self) -> List[str]:
|
|
1707
|
+
|
|
1708
|
+
try:
|
|
1709
|
+
order_idx = rx.topological_sort(self._graph)
|
|
1710
|
+
result = [self._node_name(i) for i in order_idx if self._node_name(i) is not None]
|
|
1711
|
+
for n in list(self.causal_graph.keys()):
|
|
1712
|
+
if n not in result:
|
|
1713
|
+
result.append(n)
|
|
1714
|
+
return result
|
|
1715
|
+
except Exception:
|
|
1716
|
+
in_degree: Dict[str, int] = {node: 0 for node in self.causal_graph.keys()}
|
|
1717
|
+
for node in self.causal_graph:
|
|
1718
|
+
for child in self._get_children(node):
|
|
1719
|
+
in_degree[child] = in_degree.get(child, 0) + 1
|
|
1720
|
+
|
|
1721
|
+
queue: List[str] = [node for node, degree in in_degree.items() if degree == 0]
|
|
1722
|
+
result: List[str] = []
|
|
1723
|
+
while queue:
|
|
1724
|
+
node = queue.pop(0)
|
|
1725
|
+
result.append(node)
|
|
1726
|
+
for child in self._get_children(node):
|
|
1727
|
+
in_degree[child] -= 1
|
|
1728
|
+
if in_degree[child] == 0:
|
|
1729
|
+
queue.append(child)
|
|
1730
|
+
return result
|
|
1731
|
+
|
|
1732
|
+
def identify_causal_chain(self, start: str, end: str) -> List[str]:
    """Shortest directed path from *start* to *end* via BFS.

    Returns ``[start]`` when start == end, and ``[]`` when either endpoint
    is unknown or no path exists.
    """
    if start not in self.causal_graph or end not in self.causal_graph:
        return []
    if start == end:
        return [start]

    # deque gives O(1) dequeues; the original list.pop(0) was O(n) each.
    from collections import deque

    queue = deque([(start, [start])])
    visited = {start}
    while queue:
        current, path = queue.popleft()
        for child in self._get_children(current):
            if child == end:
                return path + [child]
            if child not in visited:
                visited.add(child)
                queue.append((child, path + [child]))
    return []  # No path found
|
|
1755
|
+
|
|
1756
|
+
|
|
1757
|
+
def _has_path(self, start: str, end: str) -> bool:
|
|
1758
|
+
|
|
1759
|
+
if start == end:
|
|
1760
|
+
return True
|
|
1761
|
+
|
|
1762
|
+
stack = [start]
|
|
1763
|
+
visited = set()
|
|
1764
|
+
|
|
1765
|
+
while stack:
|
|
1766
|
+
current = stack.pop()
|
|
1767
|
+
if current in visited:
|
|
1768
|
+
continue
|
|
1769
|
+
visited.add(current)
|
|
1770
|
+
|
|
1771
|
+
for child in self._get_children(current):
|
|
1772
|
+
if child == end:
|
|
1773
|
+
return True
|
|
1774
|
+
if child not in visited:
|
|
1775
|
+
stack.append(child)
|
|
1776
|
+
|
|
1777
|
+
return False
|
|
1778
|
+
|
|
1779
|
+
def clear_cache(self) -> None:
    """Drop every memoized prediction (thread-safe)."""
    with self._prediction_cache_lock:
        self._prediction_cache_order.clear()
        self._prediction_cache.clear()
|
|
1784
|
+
|
|
1785
|
+
def enable_cache(self, flag: bool) -> None:
    """Switch prediction memoization on or off (thread-safe)."""
    with self._prediction_cache_lock:
        self._cache_enabled = bool(flag)
|
|
1789
|
+
|
|
1790
|
+
|
|
1791
|
+
def _standardize_state(self, state: Dict[str, float]) -> Dict[str, float]:
|
|
1792
|
+
|
|
1793
|
+
z: Dict[str, float] = {}
|
|
1794
|
+
for k, v in state.items():
|
|
1795
|
+
s = self.standardization_stats.get(k)
|
|
1796
|
+
if s and s.get("std", 0.0) > 0:
|
|
1797
|
+
z[k] = (v - s["mean"]) / s["std"]
|
|
1798
|
+
else:
|
|
1799
|
+
z[k] = v
|
|
1800
|
+
return z
|
|
1801
|
+
|
|
1802
|
+
def _destandardize_value(self, var: str, z_value: float) -> float:
|
|
1803
|
+
|
|
1804
|
+
s = self.standardization_stats.get(var)
|
|
1805
|
+
if s and s.get("std", 0.0) > 0:
|
|
1806
|
+
return z_value * s["std"] + s["mean"]
|
|
1807
|
+
return z_value
|
|
1808
|
+
|
|
1809
|
+
def _predict_outcomes(
|
|
1810
|
+
self,
|
|
1811
|
+
factual_state: Dict[str, float],
|
|
1812
|
+
interventions: Dict[str, float]
|
|
1813
|
+
) -> Dict[str, float]:
|
|
1814
|
+
|
|
1815
|
+
if self.use_nonlinear_scm:
|
|
1816
|
+
z_pred = self._predict_z(factual_state, interventions, use_noise=None)
|
|
1817
|
+
return {v: self._destandardize_value(v, z_val) for v, z_val in z_pred.items()}
|
|
1818
|
+
|
|
1819
|
+
raw = factual_state.copy()
|
|
1820
|
+
raw.update(interventions)
|
|
1821
|
+
|
|
1822
|
+
z_state = self._standardize_state(raw)
|
|
1823
|
+
z_pred = dict(z_state)
|
|
1824
|
+
|
|
1825
|
+
for node in self._topological_sort():
|
|
1826
|
+
if node in interventions:
|
|
1827
|
+
if node not in z_pred:
|
|
1828
|
+
z_pred[node] = z_state.get(node, 0.0)
|
|
1829
|
+
continue
|
|
1830
|
+
|
|
1831
|
+
parents = self._get_parents(node)
|
|
1832
|
+
if not parents:
|
|
1833
|
+
continue
|
|
1834
|
+
|
|
1835
|
+
s = 0.0
|
|
1836
|
+
for p in parents:
|
|
1837
|
+
pz = z_pred.get(p, z_state.get(p, 0.0))
|
|
1838
|
+
strength = self._edge_strength(p, node)
|
|
1839
|
+
s += pz * strength
|
|
1840
|
+
|
|
1841
|
+
z_pred[node] = s
|
|
1842
|
+
|
|
1843
|
+
return {v: self._destandardize_value(v, z) for v, z in z_pred.items()}
|
|
1844
|
+
|
|
1845
|
+
def _predict_outcomes_with_graph_variant(
|
|
1846
|
+
self,
|
|
1847
|
+
factual_state: Dict[str, float],
|
|
1848
|
+
interventions: Dict[str, float],
|
|
1849
|
+
graph_variant: Dict[Tuple[str, str], float]
|
|
1850
|
+
) -> Dict[str, float]:
|
|
1851
|
+
"""Predict outcomes using a temporary graph variant.
|
|
1852
|
+
|
|
1853
|
+
Temporarily applies a graph variant (perturbed edge strengths),
|
|
1854
|
+
runs prediction, then restores original graph.
|
|
1855
|
+
|
|
1856
|
+
Args:
|
|
1857
|
+
factual_state: Current factual state
|
|
1858
|
+
interventions: Interventions to apply
|
|
1859
|
+
graph_variant: Dictionary mapping (source, target) -> strength
|
|
1860
|
+
|
|
1861
|
+
Returns:
|
|
1862
|
+
Predicted outcomes
|
|
1863
|
+
"""
|
|
1864
|
+
# Save original edge strengths
|
|
1865
|
+
original_strengths: Dict[Tuple[str, str], float] = {}
|
|
1866
|
+
for (u, v), variant_strength in graph_variant.items():
|
|
1867
|
+
if u in self.causal_graph and v in self.causal_graph[u]:
|
|
1868
|
+
edge = self.causal_graph[u][v]
|
|
1869
|
+
if isinstance(edge, dict):
|
|
1870
|
+
original_strengths[(u, v)] = edge.get("strength", 0.0)
|
|
1871
|
+
# Temporarily apply variant
|
|
1872
|
+
self.causal_graph[u][v]["strength"] = variant_strength
|
|
1873
|
+
else:
|
|
1874
|
+
original_strengths[(u, v)] = float(edge) if edge is not None else 0.0
|
|
1875
|
+
# Convert to dict format
|
|
1876
|
+
self.causal_graph[u][v] = {"strength": variant_strength, "confidence": 1.0}
|
|
1877
|
+
|
|
1878
|
+
try:
|
|
1879
|
+
# Run prediction with variant graph
|
|
1880
|
+
predictions = self._predict_outcomes(factual_state, interventions)
|
|
1881
|
+
finally:
|
|
1882
|
+
# Restore original edge strengths
|
|
1883
|
+
for (u, v), original_strength in original_strengths.items():
|
|
1884
|
+
if u in self.causal_graph and v in self.causal_graph[u]:
|
|
1885
|
+
edge = self.causal_graph[u][v]
|
|
1886
|
+
if isinstance(edge, dict):
|
|
1887
|
+
edge["strength"] = original_strength
|
|
1888
|
+
else:
|
|
1889
|
+
self.causal_graph[u][v] = original_strength
|
|
1890
|
+
|
|
1891
|
+
return predictions
|
|
1892
|
+
|
|
1893
|
+
def _predict_z(self, factual_state: Dict[str, float], interventions: Dict[str, float], use_noise: Optional[Dict[str, float]] = None) -> Dict[str, float]:
|
|
1894
|
+
|
|
1895
|
+
raw = factual_state.copy()
|
|
1896
|
+
raw.update(interventions)
|
|
1897
|
+
z_state = self._standardize_state(raw)
|
|
1898
|
+
z_pred: Dict[str, float] = dict(z_state)
|
|
1899
|
+
|
|
1900
|
+
for node in self._topological_sort():
|
|
1901
|
+
if node in interventions:
|
|
1902
|
+
z_pred[node] = z_state.get(node, 0.0)
|
|
1903
|
+
continue
|
|
1904
|
+
|
|
1905
|
+
parents = self._get_parents(node)
|
|
1906
|
+
if not parents:
|
|
1907
|
+
z_val = float(use_noise.get(node, 0.0)) if use_noise else z_state.get(node, 0.0)
|
|
1908
|
+
z_pred[node] = z_val
|
|
1909
|
+
continue
|
|
1910
|
+
|
|
1911
|
+
linear_term = 0.0
|
|
1912
|
+
for p in parents:
|
|
1913
|
+
parent_z = z_pred.get(p, z_state.get(p, 0.0))
|
|
1914
|
+
beta = self._edge_strength(p, node)
|
|
1915
|
+
linear_term += parent_z * beta
|
|
1916
|
+
|
|
1917
|
+
interaction_term = 0.0
|
|
1918
|
+
for (p1, p2) in self.interaction_terms.get(node, []):
|
|
1919
|
+
if p1 in parents and p2 in parents:
|
|
1920
|
+
z1 = z_pred.get(p1, z_state.get(p1, 0.0))
|
|
1921
|
+
z2 = z_pred.get(p2, z_state.get(p2, 0.0))
|
|
1922
|
+
gamma = 0.0
|
|
1923
|
+
edge_data = self.causal_graph.get(p1, {}).get(node, {})
|
|
1924
|
+
if isinstance(edge_data, dict):
|
|
1925
|
+
gamma = float(edge_data.get("interaction_strength", {}).get(p2, 0.0))
|
|
1926
|
+
interaction_term += gamma * z1 * z2
|
|
1927
|
+
|
|
1928
|
+
model_z = linear_term + interaction_term
|
|
1929
|
+
|
|
1930
|
+
if use_noise:
|
|
1931
|
+
model_z += float(use_noise.get(node, 0.0))
|
|
1932
|
+
|
|
1933
|
+
if self.nonlinear_activation == "tanh":
|
|
1934
|
+
model_z_act = float(np.tanh(model_z) * 3.0) # scale to limit
|
|
1935
|
+
else:
|
|
1936
|
+
model_z_act = float(model_z)
|
|
1937
|
+
|
|
1938
|
+
observed_z = z_state.get(node, 0.0)
|
|
1939
|
+
|
|
1940
|
+
threshold = float(getattr(self, "shock_preserve_threshold", 1e-3))
|
|
1941
|
+
if abs(observed_z) > threshold:
|
|
1942
|
+
z_pred[node] = float(observed_z)
|
|
1943
|
+
else:
|
|
1944
|
+
z_pred[node] = float(model_z_act)
|
|
1945
|
+
|
|
1946
|
+
return z_pred
|
|
1947
|
+
|
|
1948
|
+
def aap(self, factual_state: Dict[str, float], interventions: Dict[str, float]) -> Dict[str, float]:
    """Shorthand alias for ``counterfactual_abduction_action_prediction``."""
    return self.counterfactual_abduction_action_prediction(factual_state, interventions)
|
|
1951
|
+
|
|
1952
|
+
def _predict_outcomes_cached(
|
|
1953
|
+
self,
|
|
1954
|
+
factual_state: Dict[str, float],
|
|
1955
|
+
interventions: Dict[str, float],
|
|
1956
|
+
) -> Dict[str, float]:
|
|
1957
|
+
|
|
1958
|
+
with self._prediction_cache_lock:
|
|
1959
|
+
cache_enabled = self._cache_enabled
|
|
1960
|
+
if not cache_enabled:
|
|
1961
|
+
return self._predict_outcomes(factual_state, interventions)
|
|
1962
|
+
|
|
1963
|
+
state_key = tuple(sorted([(k, float(v)) for k, v in factual_state.items()]))
|
|
1964
|
+
inter_key = tuple(sorted([(k, float(v)) for k, v in interventions.items()]))
|
|
1965
|
+
cache_key = (state_key, inter_key)
|
|
1966
|
+
|
|
1967
|
+
with self._prediction_cache_lock:
|
|
1968
|
+
if cache_key in self._prediction_cache:
|
|
1969
|
+
return dict(self._prediction_cache[cache_key])
|
|
1970
|
+
|
|
1971
|
+
result = self._predict_outcomes(factual_state, interventions)
|
|
1972
|
+
|
|
1973
|
+
with self._prediction_cache_lock:
|
|
1974
|
+
if len(self._prediction_cache_order) >= self._prediction_cache_max:
|
|
1975
|
+
remove_count = max(1, self._prediction_cache_max // 10)
|
|
1976
|
+
for _ in range(remove_count):
|
|
1977
|
+
old = self._prediction_cache_order.pop(0)
|
|
1978
|
+
if old in self._prediction_cache:
|
|
1979
|
+
del self._prediction_cache[old]
|
|
1980
|
+
|
|
1981
|
+
self._prediction_cache_order.append(cache_key)
|
|
1982
|
+
self._prediction_cache[cache_key] = dict(result)
|
|
1983
|
+
return result
|
|
1984
|
+
|
|
1985
|
+
def _get_descendants(self, node: str) -> List[str]:
|
|
1986
|
+
|
|
1987
|
+
if node not in self.causal_graph:
|
|
1988
|
+
return []
|
|
1989
|
+
stack = [node]
|
|
1990
|
+
visited = set()
|
|
1991
|
+
descendants: List[str] = []
|
|
1992
|
+
while stack:
|
|
1993
|
+
cur = stack.pop()
|
|
1994
|
+
for child in self._get_children(cur):
|
|
1995
|
+
if child in visited:
|
|
1996
|
+
continue
|
|
1997
|
+
visited.add(child)
|
|
1998
|
+
descendants.append(child)
|
|
1999
|
+
stack.append(child)
|
|
2000
|
+
return descendants
|
|
2001
|
+
|
|
2002
|
+
def counterfactual_abduction_action_prediction(
    self,
    factual_state: Dict[str, float],
    interventions: Dict[str, float]
) -> Dict[str, float]:
    """Pearl-style counterfactual via abduction / action / prediction.

    Abduction recovers per-node exogenous noise from the factual state
    under the linear SCM; action clamps the intervened variables;
    prediction re-propagates with the recovered noise and de-standardizes
    the result.
    """
    z_fact = self._standardize_state(factual_state)

    # Abduction: residual noise for each node under the linear model.
    noise: Dict[str, float] = {}
    for node in self._topological_sort():
        parents = self._get_parents(node)
        if not parents:
            noise[node] = float(z_fact.get(node, 0.0))
            continue
        expected = sum(
            z_fact.get(p, 0.0) * self._edge_strength(p, node) for p in parents
        )
        noise[node] = float(z_fact.get(node, 0.0) - expected)

    # Action: overwrite intervened variables in the raw state.
    cf_state = dict(factual_state)
    cf_state.update(interventions)
    z_cf = self._standardize_state(cf_state)

    # Prediction: forward pass reusing the abducted noise.
    z_pred: Dict[str, float] = {}
    for node in self._topological_sort():
        if node in interventions:
            z_pred[node] = float(z_cf.get(node, 0.0))
            continue

        parents = self._get_parents(node)
        if not parents:
            z_pred[node] = float(noise.get(node, 0.0))
            continue

        total = sum(
            z_pred.get(p, z_cf.get(p, 0.0)) * self._edge_strength(p, node)
            for p in parents
        )
        z_pred[node] = float(total + noise.get(node, 0.0))

    return {var: self._destandardize_value(var, zv) for var, zv in z_pred.items()}
|
|
2049
|
+
|
|
2050
|
+
def detect_confounders(self, treatment: str, outcome: str) -> List[str]:
    """Variables that are ancestors of both *treatment* and *outcome*."""
    if treatment not in self.causal_graph or outcome not in self.causal_graph:
        return []

    def ancestors_of(node: str) -> set:
        # Iterative upward walk through the reverse graph.
        seen: set = set()
        frontier = [node]
        while frontier:
            cur = frontier.pop()
            for parent in self._get_parents(cur):
                if parent not in seen:
                    seen.add(parent)
                    frontier.append(parent)
        return seen

    return list(ancestors_of(treatment) & ancestors_of(outcome))
|
|
2071
|
+
|
|
2072
|
+
def identify_adjustment_set(self, treatment: str, outcome: str) -> List[str]:
    """Backdoor-style adjustment set: parents of *treatment* that are not
    descendants of it and not the outcome itself."""
    if treatment not in self.causal_graph or outcome not in self.causal_graph:
        return []

    candidates = set(self._get_parents(treatment))
    blocked = set(self._get_descendants(treatment))
    blocked.add(outcome)
    return [var for var in candidates if var not in blocked]
|
|
2081
|
+
|
|
2082
|
+
def _calculate_scenario_probability(
|
|
2083
|
+
self,
|
|
2084
|
+
factual_state: Dict[str, float],
|
|
2085
|
+
interventions: Dict[str, float]
|
|
2086
|
+
) -> float:
|
|
2087
|
+
|
|
2088
|
+
z_sq = 0.0
|
|
2089
|
+
for var, new in interventions.items():
|
|
2090
|
+
s = self.standardization_stats.get(var, {"mean": 0.0, "std": 1.0})
|
|
2091
|
+
mu, sd = s.get("mean", 0.0), s.get("std", 1.0) or 1.0
|
|
2092
|
+
old = factual_state.get(var, mu)
|
|
2093
|
+
dz = (new - mu) / sd - (old - mu) / sd
|
|
2094
|
+
z_sq += float(dz) * float(dz)
|
|
2095
|
+
|
|
2096
|
+
p = 0.95 * float(np.exp(-0.5 * z_sq)) + 0.05
|
|
2097
|
+
return float(max(0.05, min(0.98, p)))
|
|
2098
|
+
|
|
2099
|
+
def generate_counterfactual_scenarios(
    self,
    factual_state: Dict[str, float],
    target_variables: List[str],
    max_scenarios: int = 5,
    use_monte_carlo: bool = True,
    mc_samples: int = 1000,
    parallel_sampling: bool = True,
    use_deterministic_fallback: bool = False
) -> List[CounterfactualScenario]:
    """Generate counterfactual scenarios via meta-Monte-Carlo reasoning.

    Samples candidate interventions and perturbed graph variants, scores
    each intervention's predictions across the variants, and returns the
    top-ranked scenarios annotated with uncertainty metadata.

    Args:
        factual_state: Current factual state.
        target_variables: Variables to generate scenarios for.
        max_scenarios: Maximum number of scenarios to return.
        use_monte_carlo: Whether to use Monte Carlo sampling (default: True).
        mc_samples: Number of Monte Carlo iterations (default: 1000).
        parallel_sampling: Sample graph and interventions in parallel (default: True).
        use_deterministic_fallback: Fall back to the deterministic method
            if the Monte Carlo path fails (default: False).

    Returns:
        List of CounterfactualScenario objects with uncertainty metadata.
    """
    if not use_monte_carlo:
        # Legacy deterministic path, kept for backward compatibility.
        return self._generate_counterfactual_scenarios_deterministic(
            factual_state, target_variables, max_scenarios
        )

    try:
        self.ensure_standardization_stats(factual_state)

        # Helper objects driving sampling, scoring and ranking.
        sampler = _AdaptiveInterventionSampler(self)
        graph_sampler = _GraphUncertaintySampler(self)
        assessor = _PredictionQualityAssessor(self)
        meta_analyzer = _MetaReasoningAnalyzer(self)

        # Reuse cached uncertainty estimates (from quantify_uncertainty) if present.
        uncertainty_data = getattr(self, '_cached_uncertainty_data', None)

        import concurrent.futures

        n_graph_samples = max(10, mc_samples // 100)  # far fewer graph variants than interventions
        n_intervention_samples = mc_samples

        if parallel_sampling and self.bootstrap_workers > 0:
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.bootstrap_workers) as pool:
                graphs_fut = pool.submit(
                    graph_sampler.sample_graph_variations,
                    n_graph_samples,
                    uncertainty_data
                )
                inter_fut = pool.submit(
                    sampler.sample_interventions,
                    factual_state,
                    target_variables,
                    n_intervention_samples
                )
                graph_variations = graphs_fut.result()
                interventions_list = inter_fut.result()
        else:
            graph_variations = graph_sampler.sample_graph_variations(
                n_graph_samples,
                uncertainty_data
            )
            interventions_list = sampler.sample_interventions(
                factual_state,
                target_variables,
                n_intervention_samples
            )

        scenarios_with_metadata = []
        for intervention in interventions_list:
            # Evaluate this intervention under every sampled graph variant.
            variant_predictions = [
                self._predict_outcomes_with_graph_variant(
                    factual_state, intervention, variant
                )
                for variant in graph_variations
            ]

            quality_score, quality_metrics = assessor.assess_quality(
                variant_predictions,
                factual_state,
                intervention
            )

            # Mean prediction per variable across variants.
            all_vars = set().union(*[p.keys() for p in variant_predictions])
            aggregated = {
                var: float(np.mean([p.get(var, 0.0) for p in variant_predictions]))
                for var in all_vars
            }

            # Record the sampling distribution of the first intervened variable.
            sampling_dist = "adaptive"
            for var in intervention.keys():
                dist_type, _ = sampler._get_adaptive_distribution(var, factual_state)
                sampling_dist = dist_type
                break

            scenario = CounterfactualScenario(
                name=f"mc_scenario_{len(scenarios_with_metadata)}",
                interventions=intervention,
                expected_outcomes=aggregated,
                probability=self._calculate_scenario_probability(factual_state, intervention),
                reasoning=f"Monte Carlo sampled intervention with quality score {quality_score:.3f}",
                uncertainty_metadata={
                    "quality_score": quality_score,
                    "quality_metrics": quality_metrics,
                    "graph_variations_tested": len(graph_variations),
                    "prediction_variance": {
                        var: float(np.var([p.get(var, 0.0) for p in variant_predictions]))
                        for var in aggregated.keys()
                    }
                },
                sampling_distribution=sampling_dist,
                monte_carlo_iterations=mc_samples,
                meta_reasoning_score=None  # filled in by the meta analyzer below
            )

            scenarios_with_metadata.append((
                scenario,
                {
                    "factual_state": factual_state,
                    "quality_score": quality_score,
                    "quality_metrics": quality_metrics
                }
            ))

        # Rank scenarios by informativeness and keep the best ones.
        ranked = meta_analyzer.analyze_scenarios(scenarios_with_metadata)
        best: List[CounterfactualScenario] = []
        for scenario, meta_score in ranked[:max_scenarios]:
            scenario.meta_reasoning_score = meta_score
            best.append(scenario)
        return best

    except Exception as e:
        logger.error(f"Monte Carlo counterfactual generation failed: {e}")
        if use_deterministic_fallback:
            logger.warning("Falling back to deterministic method")
            return self._generate_counterfactual_scenarios_deterministic(
                factual_state, target_variables, max_scenarios
            )
        else:
            raise
|
|
2264
|
+
|
|
2265
|
+
def _generate_counterfactual_scenarios_deterministic(
    self,
    factual_state: Dict[str, float],
    target_variables: List[str],
    max_scenarios: int = 5
) -> List[CounterfactualScenario]:
    """Deterministic grid of counterfactual scenarios (legacy behavior).

    For each target variable, sweeps a fixed set of z-score offsets
    (absolute offsets when no usable variance stats exist) and builds one
    scenario per candidate intervention value.

    Args:
        factual_state: Current factual state.
        target_variables: Variables to generate scenarios for.
        max_scenarios: Maximum number of target variables to sweep.

    Returns:
        List of CounterfactualScenario objects.
    """
    self.ensure_standardization_stats(factual_state)

    offsets = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
    scenarios: List[CounterfactualScenario] = []

    for var_idx, var in enumerate(target_variables[:max_scenarios]):
        stats = self.standardization_stats.get(var, {"mean": 0.0, "std": 1.0})
        current = factual_state.get(var, stats.get("mean", 0.0))

        if not stats or stats.get("std", 0.0) <= 0:
            # No usable variance: sweep absolute offsets around the current value.
            candidate_values = [current + step for step in offsets]
        else:
            mean = stats["mean"]
            std = stats["std"]
            current_z = (current - mean) / std
            candidate_values = [(current_z + dz) * std + mean for dz in offsets]

        for val_idx, value in enumerate(candidate_values):
            interventions = {var: float(value)}
            scenarios.append(
                CounterfactualScenario(
                    name=f"scenario_{var_idx}_{val_idx}",
                    interventions=interventions,
                    expected_outcomes=self._predict_outcomes(
                        factual_state, interventions
                    ),
                    probability=self._calculate_scenario_probability(
                        factual_state, interventions
                    ),
                    reasoning=f"Intervention on {var} with value {value}",
                )
            )

    return scenarios
|
|
2317
|
+
|
|
2318
|
+
def analyze_causal_strength(self, source: str, target: str) -> Dict[str, float]:
    """Summarize the direct edge source->target.

    Returns strength, confidence, shortest-path length and relation type;
    a zeroed summary (infinite path length) when no direct edge exists.
    """
    if source not in self.causal_graph or target not in self.causal_graph.get(source, {}):
        return {"strength": 0.0, "confidence": 0.0, "path_length": float('inf')}

    edge = self.causal_graph[source].get(target, {})
    is_meta = isinstance(edge, dict)
    strength = float(edge.get("strength", 0.0)) if is_meta else float(edge)

    chain = self.identify_causal_chain(source, target)
    hops = len(chain) - 1 if chain else float('inf')

    return {
        "strength": float(strength),
        "confidence": float(edge.get("confidence", 1.0) if is_meta else 1.0),
        "path_length": hops,
        "relation_type": edge.get("relation_type", CausalRelationType.DIRECT.value) if is_meta else CausalRelationType.DIRECT.value
    }
|
|
2334
|
+
|
|
2335
|
+
def set_standardization_stats(
    self,
    variable: str,
    mean: float,
    std: float
) -> None:
    """Record mean/std for *variable*; a non-positive std is coerced to 1.0."""
    safe_std = std if std > 0 else 1.0
    self.standardization_stats[variable] = {"mean": mean, "std": safe_std}
|
|
2343
|
+
|
|
2344
|
+
def ensure_standardization_stats(self, state: Dict[str, float]) -> None:
    """Seed default stats (mean = observed value, std = 1.0) for unseen vars."""
    for var, value in state.items():
        self.standardization_stats.setdefault(var, {"mean": float(value), "std": 1.0})
|
|
2349
|
+
|
|
2350
|
+
def get_nodes(self) -> List[str]:
    """All variable names currently in the causal graph."""
    return [node for node in self.causal_graph]
|
|
2353
|
+
|
|
2354
|
+
def get_edges(self) -> List[Tuple[str, str]]:
    """Return all directed edges as (source, target) pairs.

    Iterates the adjacency mapping in insertion order, so the result
    is deterministic for a given graph.
    """
    # Flatten the nested adjacency dict into explicit pairs
    # (comprehension instead of the manual append loop).
    return [
        (source, target)
        for source, targets in self.causal_graph.items()
        for target in targets
    ]
|
|
2361
|
+
|
|
2362
|
+
def is_dag(self) -> bool:
    """Return True when the causal graph contains no directed cycles.

    Prefers the rustworkx check on the backing graph; if that raises
    for any reason, falls back to a pure-Python depth-first search
    over the adjacency dict.
    """
    try:
        return rx.is_directed_acyclic_graph(self._graph)
    except Exception:
        # Fallback: classic DFS cycle detection with a recursion stack.
        seen: set = set()
        on_path: set = set()

        def dfs(vertex: str) -> bool:
            """Return True if a cycle is reachable from *vertex*."""
            seen.add(vertex)
            on_path.add(vertex)
            for successor in self._get_children(vertex):
                if successor not in seen:
                    if dfs(successor):
                        return True
                elif successor in on_path:
                    # Back-edge to a vertex on the current path => cycle.
                    return True
            on_path.discard(vertex)
            return False

        for start in self.causal_graph:
            if start not in seen and dfs(start):
                return False
        return True
|
|
2391
|
+
|
|
2392
|
+
def run(
    self,
    task: Optional[Union[str, Any]] = None,
    img: Optional[str] = None,
    imgs: Optional[List[str]] = None,
    correct_answer: Optional[str] = None,
    streaming_callback: Optional[Any] = None,
    n: int = 1,
    initial_state: Optional[Any] = None,
    target_variables: Optional[List[str]] = None,
    max_steps: Union[int, str] = 1,
    *args,
    **kwargs,
) -> Union[Dict[str, Any], Any]:
    """
    Run the agent with support for both standard Agent features and causal analysis.

    This method maintains compatibility with the parent Agent class while adding
    causal reasoning capabilities. It routes to causal analysis when appropriate,
    otherwise delegates to the parent Agent's standard functionality.

    Args:
        task: Task string for LLM analysis, or state dict for causal evolution
        img: Optional image path for vision tasks (delegates to parent)
        imgs: Optional list of images (delegates to parent)
        correct_answer: Optional correct answer for validation (delegates to parent)
        streaming_callback: Optional callback for streaming output (delegates to parent)
        n: Number of runs (delegates to parent)
        initial_state: Initial state dictionary for causal evolution
        target_variables: Target variables for counterfactual analysis
        max_steps: Maximum evolution steps for causal analysis ("auto" uses graph size)
        *args: Additional positional arguments
        **kwargs: Additional keyword arguments

    Returns:
        Dictionary with causal analysis results, or standard Agent output
    """
    # Check if this is a causal analysis operation
    # Criteria: initial_state or target_variables explicitly provided, or task is a dict,
    # OR automatic detection based on task content
    is_causal_operation = (
        initial_state is not None or
        target_variables is not None or
        (task is not None and isinstance(task, dict)) or
        (task is not None and isinstance(task, str) and task.strip().startswith('{')) or
        self._should_trigger_causal_analysis(task)  # Automatic detection
    )

    # Delegate to parent Agent for all standard operations
    # This includes: images, streaming, handoffs, multiple runs, regular text tasks, etc.
    # Only use causal analysis if explicitly requested via causal operation indicators
    if not is_causal_operation:
        # All standard Agent operations go to parent (handoffs, images, streaming, etc.)
        return super().run(
            task=task,
            img=img,
            imgs=imgs,
            correct_answer=correct_answer,
            streaming_callback=streaming_callback,
            n=n,
            *args,
            **kwargs,
        )

    # Causal analysis operations - only when explicitly indicated
    # A plain-text task (not JSON, no explicit state) goes through LLM-driven analysis.
    if task is not None and isinstance(task, str) and initial_state is None and not task.strip().startswith('{'):
        return self._run_llm_causal_analysis(task, **kwargs)

    # A dict/JSON task doubles as the initial state for causal evolution.
    if task is not None and initial_state is None:
        initial_state = task

    # Coerce a JSON-encoded state string into a dict; reject anything else.
    if not isinstance(initial_state, dict):
        try:
            import json
            parsed = json.loads(initial_state)
            if isinstance(parsed, dict):
                initial_state = parsed
            else:
                return {"error": "initial_state JSON must decode to a dict"}
        except Exception:
            return {"error": "initial_state must be a dict or JSON-encoded dict"}

    if target_variables is None:
        target_variables = list(self.causal_graph.keys())

    def _resolve_max_steps(value: Union[int, str]) -> int:
        # "auto" (or anything non-numeric) falls back to one step per graph node.
        if isinstance(value, str) and value == "auto":
            return max(1, len(self.causal_graph))
        try:
            return int(value)
        except Exception:
            return max(1, len(self.causal_graph))

    # NOTE(review): when max_steps is left at its default (1) but causal_max_loops
    # differs, causal_max_loops wins — the second assignment below overrides the
    # first, making part of the first expression redundant but harmless.
    effective_steps = _resolve_max_steps(max_steps if max_steps != 1 or self.causal_max_loops == 1 else self.causal_max_loops)
    if max_steps == 1 and self.causal_max_loops != 1:
        effective_steps = _resolve_max_steps(self.causal_max_loops)

    # Evolve the state through the causal model, one prediction pass per step
    # (no interventions applied during plain evolution).
    current_state = initial_state.copy()
    for step in range(effective_steps):
        current_state = self._predict_outcomes(current_state, {})

    self.ensure_standardization_stats(current_state)
    counterfactual_scenarios = self.generate_counterfactual_scenarios(
        current_state,
        target_variables,
        max_scenarios=5
    )

    return {
        "initial_state": initial_state,
        "evolved_state": current_state,
        "counterfactual_scenarios": counterfactual_scenarios,
        "causal_graph_info": {
            "nodes": self.get_nodes(),
            "edges": self.get_edges(),
            "is_dag": self.is_dag()
        },
        "steps": effective_steps
    }
|
|
2511
|
+
|
|
2512
|
+
def _call_llm_directly(self, prompt: str, system_prompt: Optional[str] = None) -> str:
    """Send *prompt* straight to the configured model via litellm.

    Bypasses the Agent's tool-execution machinery entirely so the reply
    is plain text/JSON with no function calling involved.

    Args:
        prompt: The user prompt to send to the LLM
        system_prompt: Optional system prompt (defaults to Agent's system prompt)

    Returns:
        The raw text response from the LLM
    """
    try:
        import litellm
    except ImportError:
        raise ImportError("litellm is required for direct LLM calls")

    # Pull model configuration off the Agent, with safe fallbacks.
    model_name = getattr(self, 'model_name', 'gpt-4o')
    api_key = getattr(self, 'llm_api_key', None) or os.getenv('OPENAI_API_KEY')
    base_url = getattr(self, 'llm_base_url', None)

    if system_prompt is None:
        system_prompt = getattr(self, 'system_prompt', None)

    # Assemble the chat transcript: optional system turn, then the user turn.
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    try:
        response = litellm.completion(
            model=model_name,
            messages=messages,
            api_key=api_key,
            api_base=base_url,
            temperature=getattr(self, 'temperature', 0.5),
            max_tokens=getattr(self, 'max_tokens', 4096),
        )
        # Pull the first choice's text out of the completion object.
        if response and hasattr(response, 'choices') and len(response.choices) > 0:
            content = response.choices[0].message.content
            return content if content else ""
        logger.warning("Empty response from LLM")
        return ""
    except Exception as e:
        logger.error(f"Error calling LLM directly: {e}")
        raise
|
|
2567
|
+
|
|
2568
|
+
def _extract_variables_ml_based(self, task: str) -> bool:
    """
    Extract variables using ML/NLP-based approach - use LLM to generate structured output,
    parse the function call from the response, and automatically invoke the handler.

    This bypasses unreliable function calling by directly parsing LLM output and
    invoking the handler programmatically.

    Args:
        task: The task string to analyze

    Returns:
        True if variables were successfully extracted, False otherwise.
        Note: success is ultimately judged by whether the causal graph is
        non-empty at the end, regardless of what happened in this call.
    """
    import json
    import re

    logger.info(f"Starting ML-based variable extraction for task: {task[:100]}...")

    # Use LLM to generate structured JSON output with variables and edges
    extraction_prompt = f"""Analyze this task and extract causal variables and relationships.

Task: {task}

Return a JSON object with this exact structure:
{{
    "required_variables": ["var1", "var2", "var3"],
    "causal_edges": [["var1", "var2"], ["var2", "var3"]],
    "reasoning": "Brief explanation of why these variables are needed",
    "optional_variables": ["var4"],
    "counterfactual_variables": ["var1", "var2"]
}}

Extract ALL relevant variables and causal relationships. Be thorough.
Return ONLY valid JSON, no other text."""

    try:
        # Call LLM directly (bypasses Agent tool execution)
        # This gives us pure text/JSON response without function calling
        # The LLM is still the core component - we just parse its output programmatically
        raw_response = self._call_llm_directly(extraction_prompt)

        extracted_data = None

        # Parse JSON from LLM text response
        # The LLM returns plain text with JSON embedded, so we extract it
        response_text = str(raw_response)

        # Try to extract JSON from response text
        # First pattern requires the "required_variables" key; the fallback
        # pattern accepts any (possibly one-level-nested) JSON object.
        json_match = re.search(r'\{[^{}]*"required_variables"[^{}]*\}', response_text, re.DOTALL)
        if not json_match:
            # Try to find any JSON object
            json_match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', response_text, re.DOTALL)

        if json_match:
            json_str = json_match.group(0)
            try:
                # Try to fix common JSON issues
                # Remove invalid escape sequences
                json_str = json_str.replace('\\"', '"').replace("\\'", "'")
                # Try parsing
                extracted_data = json.loads(json_str)
                logger.info("Parsed JSON from text response")
            except json.JSONDecodeError as e:
                # Try to extract just the JSON object more carefully
                try:
                    # Find the innermost complete JSON object by balancing braces
                    brace_count = 0
                    start_idx = json_str.find('{')
                    if start_idx >= 0:
                        for i in range(start_idx, len(json_str)):
                            if json_str[i] == '{':
                                brace_count += 1
                            elif json_str[i] == '}':
                                brace_count -= 1
                                if brace_count == 0:
                                    json_str = json_str[start_idx:i+1]
                                    extracted_data = json.loads(json_str)
                                    logger.info("Parsed JSON after fixing structure")
                                    break
                except (json.JSONDecodeError, ValueError) as e2:
                    logger.warning(f"Failed to parse JSON after fixes: {e2}")
                    logger.debug(f"JSON string was: {json_str[:500]}")

        # Process extracted data
        if extracted_data:
            # Validate structure
            required_vars = extracted_data.get("required_variables", [])
            causal_edges = extracted_data.get("causal_edges", [])
            reasoning = extracted_data.get("reasoning", "Extracted from task analysis")
            optional_vars = extracted_data.get("optional_variables", [])
            counterfactual_vars = extracted_data.get("counterfactual_variables", [])

            if required_vars and causal_edges:
                # Automatically invoke the handler with extracted data
                logger.info(f"Extracted {len(required_vars)} variables, {len(causal_edges)} edges via ML")
                result = self._extract_causal_variables_handler(
                    required_variables=required_vars,
                    causal_edges=causal_edges,
                    reasoning=reasoning,
                    optional_variables=optional_vars if optional_vars else None,
                    counterfactual_variables=counterfactual_vars if counterfactual_vars else None,
                )

                # Check if extraction was successful
                if result.get("status") == "success" and result.get("summary", {}).get("variables_added", 0) > 0:
                    logger.info(f"ML-based extraction successful: {result.get('summary')}")
                    return True
                else:
                    logger.warning(f"ML-based extraction returned: {result}")
            else:
                logger.warning(f"Extracted data missing required fields. required_variables: {required_vars}, causal_edges: {causal_edges}")
        else:
            logger.warning("Could not extract data from LLM response")
            logger.debug(f"Raw response type: {type(raw_response)}, value: {str(raw_response)[:500]}")

    except Exception as e:
        logger.error(f"Error in ML-based extraction: {e}")
        import traceback
        logger.debug(traceback.format_exc())

    # Fall-through: report success iff the graph already has variables
    # (e.g. populated by a previous call), even when this attempt failed.
    return len(self.causal_graph) > 0
|
|
2690
|
+
|
|
2691
|
+
def _generate_causal_analysis_ml_based(self, task: str) -> Optional[Dict[str, Any]]:
    """
    Generate causal analysis using ML-based approach - use LLM to generate structured output,
    parse the function call from the response, and automatically invoke the handler.

    This bypasses unreliable function calling by directly parsing LLM output and
    invoking the handler programmatically.

    Args:
        task: The task string to analyze

    Returns:
        Dictionary with causal analysis results, or None if extraction failed
    """
    import json
    import re

    logger.info(f"Starting ML-based causal analysis generation for task: {task[:100]}...")

    # Build comprehensive causal analysis prompt
    causal_prompt = self._build_causal_prompt(task)

    # Add instruction for structured output
    analysis_prompt = f"""{causal_prompt}

CRITICAL: You must return a JSON object with this EXACT structure. Do not include any text before or after the JSON.

Required JSON format:
{{
    "causal_analysis": "Detailed analysis of causal relationships and mechanisms. This must be a comprehensive text explanation.",
    "intervention_planning": "Planned interventions to test causal hypotheses.",
    "counterfactual_scenarios": [
        {{
            "scenario_name": "Scenario 1",
            "interventions": {{"var1": 10, "var2": 20}},
            "expected_outcomes": {{"target_var": 30}},
            "reasoning": "Why this scenario is important..."
        }}
    ],
    "causal_strength_assessment": "Assessment of relationship strengths and confounders.",
    "optimal_solution": "Recommended solution based on analysis."
}}

IMPORTANT:
- The "causal_analysis" field is REQUIRED and must contain detailed text analysis
- Return ONLY the JSON object, no markdown, no code blocks, no explanations
- Ensure all fields are present and properly formatted"""

    try:
        # Call LLM directly (bypasses Agent tool execution)
        # This gives us pure text/JSON response without function calling
        # The LLM is still the core component - we just parse its output programmatically
        raw_response = self._call_llm_directly(analysis_prompt)

        extracted_data = None

        # Parse JSON from LLM text response
        # The LLM returns plain text with JSON embedded, so we extract it
        response_text = str(raw_response)

        # First, try to extract JSON from markdown code blocks (```json ... ```)
        json_block_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', response_text, re.DOTALL)
        if json_block_match:
            json_str = json_block_match.group(1)
            try:
                extracted_data = json.loads(json_str)
                logger.info("Parsed JSON from markdown code block")
            except json.JSONDecodeError:
                pass

        # If not found in code block, try to extract JSON directly
        if not extracted_data:
            # Try to find JSON object with causal_analysis field
            json_match = re.search(r'\{[^{}]*"causal_analysis"[^{}]*\}', response_text, re.DOTALL)
            if not json_match:
                # Try to find any complete JSON object (handle nested objects)
                # Manual brace balancing finds the outermost complete JSON object.
                brace_count = 0
                start_idx = response_text.find('{')
                if start_idx >= 0:
                    for i in range(start_idx, len(response_text)):
                        if response_text[i] == '{':
                            brace_count += 1
                        elif response_text[i] == '}':
                            brace_count -= 1
                            if brace_count == 0:
                                json_str = response_text[start_idx:i+1]
                                try:
                                    # Try to fix common JSON issues
                                    json_str = json_str.replace('\\"', '"').replace("\\'", "'")
                                    extracted_data = json.loads(json_str)
                                    logger.info("Parsed JSON from text response")
                                    break
                                except json.JSONDecodeError:
                                    # Try without the escape fixes
                                    try:
                                        extracted_data = json.loads(response_text[start_idx:i+1])
                                        logger.info("Parsed JSON after removing escape fixes")
                                        break
                                    except json.JSONDecodeError as e:
                                        logger.debug(f"Failed to parse JSON: {e}")
                                        break

        # Process extracted data
        if extracted_data:
            # Validate structure - try multiple possible field names
            causal_analysis = (
                extracted_data.get("causal_analysis", "") or
                extracted_data.get("analysis", "") or
                extracted_data.get("causal_analysis_text", "") or
                str(extracted_data.get("analysis_text", ""))
            )
            intervention_planning = extracted_data.get("intervention_planning", "") or extracted_data.get("interventions", "")
            counterfactual_scenarios = extracted_data.get("counterfactual_scenarios", []) or extracted_data.get("scenarios", [])
            causal_strength_assessment = extracted_data.get("causal_strength_assessment", "") or extracted_data.get("strength_assessment", "")
            optimal_solution = extracted_data.get("optimal_solution", "") or extracted_data.get("solution", "")

            # Log what we extracted for debugging (use INFO so it's visible)
            logger.info(f"Extracted fields - causal_analysis: {bool(causal_analysis)}, scenarios: {len(counterfactual_scenarios)}")
            logger.info(f"Extracted data keys: {list(extracted_data.keys())}")
            if not causal_analysis:
                # Log the full structure to see what we got
                logger.info(f"Full extracted data (first 1000 chars): {str(extracted_data)[:1000]}")

            if causal_analysis:
                # Automatically invoke the handler with extracted data
                logger.info(f"Extracted causal analysis via ML ({len(causal_analysis)} chars, {len(counterfactual_scenarios)} scenarios)")
                result = self._generate_causal_analysis_handler(
                    causal_analysis=causal_analysis,
                    intervention_planning=intervention_planning,
                    counterfactual_scenarios=counterfactual_scenarios,
                    causal_strength_assessment=causal_strength_assessment,
                    optimal_solution=optimal_solution,
                )

                # Check if analysis was successful
                if result.get("status") == "success":
                    logger.info(f"ML-based causal analysis generation successful")
                    # Return the analysis data for use in final result
                    return {
                        'causal_analysis': causal_analysis,
                        'intervention_planning': intervention_planning,
                        'counterfactual_scenarios': counterfactual_scenarios,
                        'causal_strength_assessment': causal_strength_assessment,
                        'optimal_solution': optimal_solution,
                    }
                else:
                    logger.warning(f"ML-based analysis returned: {result}")
            else:
                logger.warning("Extracted data missing causal_analysis field")
                logger.info(f"Extracted data structure: {extracted_data}")
                logger.info(f"Available keys: {list(extracted_data.keys()) if isinstance(extracted_data, dict) else 'Not a dict'}")
                # Try to use the raw response text as causal_analysis if JSON parsing failed
                if not extracted_data or not isinstance(extracted_data, dict):
                    logger.info("Attempting to use raw response as causal analysis")
                    # Use the raw response as a fallback
                    causal_analysis = response_text[:5000]  # Limit length
                    if causal_analysis:
                        result = self._generate_causal_analysis_handler(
                            causal_analysis=causal_analysis,
                            intervention_planning="",
                            counterfactual_scenarios=[],
                            causal_strength_assessment="",
                            optimal_solution="",
                        )
                        if result.get("status") == "success":
                            return {
                                'causal_analysis': causal_analysis,
                                'intervention_planning': "",
                                'counterfactual_scenarios': [],
                                'causal_strength_assessment': "",
                                'optimal_solution': "",
                            }
        else:
            logger.warning("Could not extract data from LLM response for causal analysis")
            logger.debug(f"Raw response type: {type(raw_response)}, value: {str(raw_response)[:500]}")

    except Exception as e:
        logger.error(f"Error in ML-based causal analysis generation: {e}")
        import traceback
        logger.debug(traceback.format_exc())

    return None
|
|
2874
|
+
|
|
2875
|
+
def _extract_variables_from_task(self, task: str) -> bool:
|
|
2876
|
+
"""
|
|
2877
|
+
Extract variables and causal relationships from a task.
|
|
2878
|
+
|
|
2879
|
+
Uses ML-based extraction (structured LLM output + automatic handler invocation)
|
|
2880
|
+
instead of relying on unreliable function calling.
|
|
2881
|
+
|
|
2882
|
+
Args:
|
|
2883
|
+
task: The task string to analyze
|
|
2884
|
+
|
|
2885
|
+
Returns:
|
|
2886
|
+
True if variables were successfully extracted, False otherwise
|
|
2887
|
+
"""
|
|
2888
|
+
# Use ML-based extraction (more reliable than function calling)
|
|
2889
|
+
return self._extract_variables_ml_based(task)
|
|
2890
|
+
|
|
2891
|
+
def _run_llm_causal_analysis(self, task: str, target_variables: Optional[List[str]] = None, **kwargs) -> Dict[str, Any]:
    """
    Run LLM-based causal analysis on a task.

    Args:
        task: Task string to analyze
        target_variables: Optional list of target variables for counterfactual scenarios.
                          If None, uses all variables in the causal graph.
        **kwargs: Additional arguments

    Returns:
        Dictionary with causal analysis results
    """
    # Reset the per-run scratchpad; tool handlers append analysis entries here.
    self.causal_memory = []

    # Check if causal graph is empty - if so, trigger variable extraction
    if len(self.causal_graph) == 0:
        logger.info("Causal graph is empty, starting variable extraction phase...")
        extraction_success = self._extract_variables_from_task(task)

        if not extraction_success:
            # Extraction failed, return error
            return {
                'task': task,
                'error': 'Variable extraction failed',
                'message': (
                    'Could not extract variables from the task. '
                    'Please ensure the task describes causal relationships or variables, '
                    'or manually initialize the agent with variables and causal_edges.'
                ),
                'causal_graph_info': {
                    'nodes': [],
                    'edges': [],
                    'is_dag': True
                }
            }

        logger.info(
            f"Variable extraction completed. "
            f"Graph now has {len(self.causal_graph)} variables and "
            f"{sum(len(children) for children in self.causal_graph.values())} edges."
        )

    # Use Agent's normal run method to get rich output and proper loop handling
    # This will show the LLM's rich output and respect max_loops
    causal_prompt = self._build_causal_prompt(task)

    # Store the current memory size before running to detect new entries
    memory_size_before = len(self.causal_memory)

    # Run the agent - this will execute tools and store results in causal_memory
    super().run(task=causal_prompt)

    # Extract the causal analysis from causal_memory (stored by tool handler)
    # Look for the most recent analysis entry
    analysis_result = None
    final_analysis = ""

    # Search backwards through causal_memory for the most recent analysis
    # Check both new entries (after memory_size_before) and all entries
    for entry in reversed(self.causal_memory):
        if isinstance(entry, dict):
            # Check if this is an analysis entry (has 'type' == 'analysis' or has 'causal_analysis' field)
            entry_type = entry.get('type')
            has_causal_analysis = 'causal_analysis' in entry

            if entry_type == 'analysis' or has_causal_analysis:
                analysis_result = entry
                final_analysis = entry.get('causal_analysis', '')
                if final_analysis:
                    logger.info(f"Found causal analysis in memory (type: {entry_type}, length: {len(final_analysis)})")
                break

    # If no analysis found in memory, use ML-based generation as fallback
    if not final_analysis:
        logger.warning("No causal analysis found in causal_memory, using ML-based generation fallback")
        ml_analysis = self._generate_causal_analysis_ml_based(task)
        if ml_analysis and ml_analysis.get('causal_analysis'):
            # Use the returned analysis directly
            final_analysis = ml_analysis.get('causal_analysis', '')
            analysis_result = {
                'type': 'analysis',
                'causal_analysis': final_analysis,
                'intervention_planning': ml_analysis.get('intervention_planning', ''),
                'counterfactual_scenarios': ml_analysis.get('counterfactual_scenarios', []),
                'causal_strength_assessment': ml_analysis.get('causal_strength_assessment', ''),
                'optimal_solution': ml_analysis.get('optimal_solution', ''),
            }
            # Also ensure it's stored in memory for consistency
            if analysis_result not in self.causal_memory:
                self.causal_memory.append(analysis_result)
            if final_analysis:
                logger.info(f"Found causal analysis from ML-based fallback (length: {len(final_analysis)})")

    # If still no analysis found, use the LLM's response as final fallback
    if not final_analysis:
        logger.warning("No causal analysis found in causal_memory, attempting fallback from conversation history")
        # Try to get the last response from the conversation
        if hasattr(self, 'short_memory') and self.short_memory:
            # Conversation object has conversation_history attribute, not get_messages()
            if hasattr(self.short_memory, 'conversation_history'):
                conversation_history = self.short_memory.conversation_history
                if conversation_history:
                    # Get the last assistant message (entries may be dicts or objects)
                    for msg in reversed(conversation_history):
                        if isinstance(msg, dict) and msg.get('role') == 'assistant':
                            final_analysis = msg.get('content', '')
                            if final_analysis:
                                logger.info(f"Extracted causal analysis from conversation history (length: {len(final_analysis)})")
                                break
                        elif hasattr(msg, 'role') and msg.role == 'assistant':
                            final_analysis = getattr(msg, 'content', str(msg))
                            if final_analysis:
                                logger.info(f"Extracted causal analysis from conversation history (length: {len(final_analysis)})")
                                break

    if not final_analysis:
        logger.warning(f"Could not extract causal analysis. Memory size: {len(self.causal_memory)} (was {memory_size_before})")

    # Build a zeroed state so counterfactual generation always has inputs.
    default_state = {var: 0.0 for var in self.get_nodes()}
    self.ensure_standardization_stats(default_state)

    # Use provided target_variables or default to all variables
    if target_variables is None:
        target_variables = self.get_nodes()

    # Limit to top variables if too many
    target_vars = target_variables[:5] if len(target_variables) > 5 else target_variables

    # Use counterfactual scenarios from ML analysis if available
    if analysis_result and isinstance(analysis_result, dict):
        ml_scenarios = analysis_result.get('counterfactual_scenarios', [])
        if ml_scenarios:
            counterfactual_scenarios = ml_scenarios
        else:
            # Fallback to generated scenarios
            counterfactual_scenarios = self.generate_counterfactual_scenarios(
                default_state,
                target_vars,
                max_scenarios=5
            )
    else:
        # Fallback to generated scenarios
        counterfactual_scenarios = self.generate_counterfactual_scenarios(
            default_state,
            target_vars,
            max_scenarios=5
        )

    return {
        'task': task,
        'causal_analysis': final_analysis,
        'counterfactual_scenarios': counterfactual_scenarios,
        'causal_graph_info': {
            'nodes': self.get_nodes(),
            'edges': self.get_edges(),
            'is_dag': self.is_dag()
        },
        'analysis_steps': self.causal_memory,
        'intervention_planning': analysis_result.get('intervention_planning', '') if analysis_result else '',
        'causal_strength_assessment': analysis_result.get('causal_strength_assessment', '') if analysis_result else '',
        'optimal_solution': analysis_result.get('optimal_solution', '') if analysis_result else '',
    }
|
|
3054
|
+
|
|
3055
|
+
# =========================
|
|
3056
|
+
# Compatibility extensions
|
|
3057
|
+
# =========================
|
|
3058
|
+
|
|
3059
|
+
# ---- Helpers ----
|
|
3060
|
+
@staticmethod
def _require_pandas() -> None:
    """Raise ImportError unless pandas was importable at module load time."""
    if PANDAS_AVAILABLE:
        return
    raise ImportError("pandas is required for this operation. Install pandas to proceed.")
|
|
3064
|
+
|
|
3065
|
+
@staticmethod
def _require_scipy() -> None:
    """Raise ImportError unless scipy was importable at module load time."""
    if SCIPY_AVAILABLE:
        return
    raise ImportError("scipy is required for this operation. Install scipy to proceed.")
|
|
3069
|
+
|
|
3070
|
+
@staticmethod
def _require_cvxpy() -> None:
    """Raise ImportError unless cvxpy was importable at module load time."""
    if CVXPY_AVAILABLE:
        return
    raise ImportError("cvxpy is required for this operation. Install cvxpy to proceed.")
|
|
3074
|
+
|
|
3075
|
+
def _edge_strength(self, source: str, target: str) -> float:
|
|
3076
|
+
edge = self.causal_graph.get(source, {}).get(target, None)
|
|
3077
|
+
if isinstance(edge, dict):
|
|
3078
|
+
return float(edge.get("strength", 0.0))
|
|
3079
|
+
try:
|
|
3080
|
+
return float(edge) if edge is not None else 0.0
|
|
3081
|
+
except Exception:
|
|
3082
|
+
return 0.0
|
|
3083
|
+
|
|
3084
|
+
class _TimeDebug:
|
|
3085
|
+
def __init__(self, name: str) -> None:
|
|
3086
|
+
self.name = name
|
|
3087
|
+
self.start = 0.0
|
|
3088
|
+
def __enter__(self):
|
|
3089
|
+
try:
|
|
3090
|
+
import time
|
|
3091
|
+
self.start = time.perf_counter()
|
|
3092
|
+
except Exception:
|
|
3093
|
+
self.start = 0.0
|
|
3094
|
+
return self
|
|
3095
|
+
def __exit__(self, exc_type, exc, tb):
|
|
3096
|
+
if logger.isEnabledFor(logging.DEBUG) and self.start:
|
|
3097
|
+
try:
|
|
3098
|
+
import time
|
|
3099
|
+
duration = time.perf_counter() - self.start
|
|
3100
|
+
logger.debug(f"{self.name} completed in {duration:.4f}s")
|
|
3101
|
+
except Exception:
|
|
3102
|
+
pass
|
|
3103
|
+
|
|
3104
|
+
def _ensure_edge(self, source: str, target: str) -> None:
    """Ensure edge exists in both dict graph and rustworkx graph."""
    # Make sure both endpoints are registered before touching edges.
    self._ensure_node_exists(source)
    self._ensure_node_exists(target)
    # Create the dict-graph edge with neutral metadata if it is missing.
    if target not in self.causal_graph.get(source, {}):
        self.causal_graph.setdefault(source, {})[target] = {"strength": 0.0, "confidence": 1.0}
    # Best-effort mirror of the edge into the backing graph object; failures
    # are deliberately swallowed so the dict graph stays authoritative.
    # NOTE(review): if self._graph is a rustworkx PyDiGraph (as the docstring
    # says), get_edge_data raises NoEdgeBetweenNodes instead of returning
    # None when the edge is absent — which would land in the except branch
    # and skip add_edge entirely. Confirm the backend's semantics.
    try:
        u_idx = self._ensure_node_index(source)
        v_idx = self._ensure_node_index(target)
        if self._graph.get_edge_data(u_idx, v_idx) is None:
            # Payload is the shared metadata dict, so later strength updates
            # on the dict graph are visible through the backing graph too.
            self._graph.add_edge(u_idx, v_idx, self.causal_graph[source][target])
    except Exception:
        pass
|
|
3117
|
+
|
|
3118
|
+
# Convenience graph-like methods for compatibility
|
|
3119
|
+
def add_nodes_from(self, nodes: List[str]) -> None:
    """Register every name in *nodes* as a graph node (NetworkX-style shim)."""
    for node_name in nodes:
        self._ensure_node_exists(node_name)
|
|
3122
|
+
|
|
3123
|
+
def add_edges_from(self, edges: List[Tuple[str, str]]) -> None:
    """Add each (source, target) pair as a causal edge (NetworkX-style shim)."""
    for source, target in edges:
        self.add_causal_relationship(source, target)
|
|
3126
|
+
|
|
3127
|
+
def edges(self) -> List[Tuple[str, str]]:
    """NetworkX-style alias for :meth:`get_edges`."""
    return self.get_edges()
|
|
3129
|
+
|
|
3130
|
+
# ---- Batched predictions ----
|
|
3131
|
+
def _predict_outcomes_batch(
|
|
3132
|
+
self,
|
|
3133
|
+
factual_states: List[Dict[str, float]],
|
|
3134
|
+
interventions: Optional[Union[Dict[str, float], List[Dict[str, float]]]] = None,
|
|
3135
|
+
) -> List[Dict[str, float]]:
|
|
3136
|
+
"""
|
|
3137
|
+
Batched deterministic SCM forward pass. Uses shared topology and vectorized parent aggregation.
|
|
3138
|
+
"""
|
|
3139
|
+
if not factual_states:
|
|
3140
|
+
return []
|
|
3141
|
+
if len(factual_states) == 1 or not self.enable_batch_predict:
|
|
3142
|
+
return [self._predict_outcomes(factual_states[0], interventions if isinstance(interventions, dict) else (interventions or {}))]
|
|
3143
|
+
|
|
3144
|
+
batch = len(factual_states)
|
|
3145
|
+
if interventions is None:
|
|
3146
|
+
interventions_list = [{} for _ in range(batch)]
|
|
3147
|
+
elif isinstance(interventions, list):
|
|
3148
|
+
interventions_list = interventions
|
|
3149
|
+
else:
|
|
3150
|
+
interventions_list = [interventions for _ in range(batch)]
|
|
3151
|
+
|
|
3152
|
+
topo = self._topological_sort()
|
|
3153
|
+
parents_map = {node: self._get_parents(node) for node in topo}
|
|
3154
|
+
stats = self.standardization_stats
|
|
3155
|
+
z_pred: Dict[str, np.ndarray] = {}
|
|
3156
|
+
|
|
3157
|
+
# Initialize z with raw + interventions standardized
|
|
3158
|
+
for node in topo:
|
|
3159
|
+
arr = np.empty(batch, dtype=float)
|
|
3160
|
+
mean = stats.get(node, {}).get("mean", 0.0)
|
|
3161
|
+
std = stats.get(node, {}).get("std", 1.0) or 1.0
|
|
3162
|
+
for i in range(batch):
|
|
3163
|
+
raw_val = interventions_list[i].get(node, factual_states[i].get(node, 0.0))
|
|
3164
|
+
arr[i] = (raw_val - mean) / std
|
|
3165
|
+
z_pred[node] = arr
|
|
3166
|
+
|
|
3167
|
+
# Propagate for non-intervened nodes
|
|
3168
|
+
for node in topo:
|
|
3169
|
+
parents = parents_map.get(node, [])
|
|
3170
|
+
if not parents:
|
|
3171
|
+
continue
|
|
3172
|
+
arr = z_pred[node]
|
|
3173
|
+
# Only recompute if node not directly intervened
|
|
3174
|
+
intervene_mask = np.array([node in interventions_list[i] for i in range(batch)], dtype=bool)
|
|
3175
|
+
if np.all(intervene_mask):
|
|
3176
|
+
continue
|
|
3177
|
+
if not parents:
|
|
3178
|
+
continue
|
|
3179
|
+
parent_matrix = np.vstack([z_pred[p] for p in parents]) # shape (k, batch)
|
|
3180
|
+
strengths = np.array([self._edge_strength(p, node) for p in parents], dtype=float).reshape(-1, 1)
|
|
3181
|
+
combined = (strengths * parent_matrix).sum(axis=0)
|
|
3182
|
+
if intervene_mask.any():
|
|
3183
|
+
# preserve intervened samples
|
|
3184
|
+
arr = np.where(intervene_mask, arr, combined)
|
|
3185
|
+
else:
|
|
3186
|
+
arr = combined
|
|
3187
|
+
z_pred[node] = arr
|
|
3188
|
+
|
|
3189
|
+
# De-standardize
|
|
3190
|
+
outputs: List[Dict[str, float]] = []
|
|
3191
|
+
for i in range(batch):
|
|
3192
|
+
out: Dict[str, float] = {}
|
|
3193
|
+
for node, z_arr in z_pred.items():
|
|
3194
|
+
s = stats.get(node, {"mean": 0.0, "std": 1.0})
|
|
3195
|
+
out[node] = float(z_arr[i] * s.get("std", 1.0) + s.get("mean", 0.0))
|
|
3196
|
+
outputs.append(out)
|
|
3197
|
+
return outputs
|
|
3198
|
+
|
|
3199
|
+
# Convenience graph-like methods for compatibility
|
|
3200
|
+
# NOTE(review): duplicate definition — an identical add_nodes_from appears
# earlier in this class body; this later binding wins, so behavior is
# unchanged, but one copy should eventually be removed.
def add_nodes_from(self, nodes: List[str]) -> None:
    # Register every listed node in the causal graph (NetworkX-style shim).
    for n in nodes:
        self._ensure_node_exists(n)
|
|
3203
|
+
|
|
3204
|
+
# NOTE(review): duplicate definition — an identical add_edges_from appears
# earlier in this class body; the definitions match, so the rebinding is
# harmless, but one copy should eventually be removed.
def add_edges_from(self, edges: List[Tuple[str, str]]) -> None:
    # Add each (source, target) pair as a causal edge (NetworkX-style shim).
    for u, v in edges:
        self.add_causal_relationship(u, v)
|
|
3207
|
+
|
|
3208
|
+
# NOTE(review): duplicate definition — an identical edges() appears earlier
# in this class body; harmless, but one copy should eventually be removed.
def edges(self) -> List[Tuple[str, str]]:
    # NetworkX-style alias for get_edges().
    return self.get_edges()
|
|
3210
|
+
|
|
3211
|
+
# ---- Data-driven fitting ----
|
|
3212
|
+
def fit_from_dataframe(
    self,
    df: Any,
    variables: List[str],
    window: int = 30,
    decay_alpha: float = 0.9,
    ridge_lambda: float = 0.0,
    enforce_signs: bool = True
) -> None:
    """
    Fit edge strengths and standardization stats from a rolling window with recency weighting.

    Args:
        df: pandas DataFrame with one column per variable.
        variables: Columns to fit; all must exist in ``df``.
        window: Number of most recent (complete) rows to use, clamped to >= 1.
        decay_alpha: Recency weight in (0, 1]; newest row gets weight 1 and
            older rows decay geometrically. 1.0 means uniform weights.
        ridge_lambda: Ridge term added to the weighted normal equations.
        enforce_signs: Zero out strengths violating ``self.edge_sign_constraints``.

    Raises:
        ImportError: If pandas is not installed.
        TypeError: If ``df`` is not a DataFrame.
        ValueError: If variables are missing or ``decay_alpha`` is out of range.
    """
    with self._TimeDebug("fit_from_dataframe"):
        self._require_pandas()
        if df is None:
            return
        if not isinstance(df, pd.DataFrame):
            raise TypeError(f"df must be a pandas DataFrame, got {type(df)}")
        if not variables:
            return
        missing = [v for v in variables if v not in df.columns]
        if missing:
            raise ValueError(f"Variables not in DataFrame: {missing}")
        window = max(1, int(window))
        if not (0 < decay_alpha <= 1):
            raise ValueError("decay_alpha must be in (0,1]")

        df_local = df[variables].dropna().copy()
        if df_local.empty:
            return
        window_df = df_local.tail(window)
        n = len(window_df)
        weights = np.array([decay_alpha ** (n - 1 - i) for i in range(n)], dtype=float)
        weights = weights / (weights.sum() if weights.sum() != 0 else 1.0)

        # Standardization stats (population std; degenerate columns get std=1).
        self.standardization_stats = {}
        for v in variables:
            m = float(window_df[v].mean())
            s = float(window_df[v].std(ddof=0))
            if s == 0:
                s = 1.0
            self.standardization_stats[v] = {"mean": m, "std": s}
        for node in self.causal_graph.keys():
            if node not in self.standardization_stats:
                self.standardization_stats[node] = {"mean": 0.0, "std": 1.0}

        # Estimate edge strengths via weighted (ridge) least squares per child.
        for child in list(self.causal_graph.keys()):
            parents = self._get_parents(child)
            if not parents or child not in window_df.columns:
                continue
            # BUG FIX: track which parents actually made it into the design
            # matrix. The old code enumerated *all* parents against a design
            # matrix built from only the parents with data columns, so
            # coefficients could be attributed to the wrong parent.
            used_parents: List[str] = []
            parent_vals = []
            for p in parents:
                if p in window_df.columns:
                    stats = self.standardization_stats.get(p, {"mean": 0.0, "std": 1.0})
                    parent_vals.append(((window_df[p] - stats["mean"]) / stats["std"]).values)
                    used_parents.append(p)
            if not parent_vals:
                continue
            X = np.vstack(parent_vals).T
            y_stats = self.standardization_stats.get(child, {"mean": 0.0, "std": 1.0})
            y = ((window_df[child] - y_stats["mean"]) / y_stats["std"]).values
            # Weighted normal equations; broadcasting X.T * weights avoids
            # materializing the O(n^2) diagonal matrix np.diag(weights).
            XtW = X.T * weights
            XtWX = XtW @ X
            if ridge_lambda > 0 and XtWX.size > 0:
                XtWX = XtWX + ridge_lambda * np.eye(XtWX.shape[0])
            try:
                beta = np.linalg.pinv(XtWX) @ (XtW @ y)
            except Exception:
                beta = np.zeros(X.shape[1])
            beta = np.asarray(beta)
            for idx, p in enumerate(used_parents):
                strength = float(beta[idx]) if idx < len(beta) else 0.0
                if enforce_signs:
                    sign = self.edge_sign_constraints.get((p, child))
                    if (sign == 1 and strength < 0) or (sign == -1 and strength > 0):
                        strength = 0.0
                self._ensure_edge(p, child)
                self.causal_graph[p][child]["strength"] = strength
                self.causal_graph[p][child]["confidence"] = 1.0
|
|
3299
|
+
|
|
3300
|
+
# ---- Uncertainty ----
|
|
3301
|
+
def quantify_uncertainty(
    self,
    df: Any,
    variables: List[str],
    windows: int = 200,
    alpha: float = 0.95
) -> Dict[str, Any]:
    """Bootstrap confidence intervals for fitted edge strengths.

    Resamples ``df`` with replacement ``windows`` times, refits edge
    strengths on each resample, and returns per-edge ``alpha``-level
    quantile intervals. The agent's own strengths/stats are snapshotted
    first and restored afterwards.

    Returns a dict with ``edge_cis`` ("u->v" -> (lo, hi)) and ``samples``.
    NOTE(review): ``samples`` reports the *requested* window count, not the
    number of bootstrap fits that actually succeeded.
    """
    with self._TimeDebug("quantify_uncertainty"):
        self._require_pandas()
        if df is None or not isinstance(df, pd.DataFrame):
            return {"edge_cis": {}, "samples": 0}
        usable = df[variables].dropna()
        # Too few complete rows for a meaningful bootstrap.
        if len(usable) < 10:
            return {"edge_cis": {}, "samples": 0}
        windows = max(1, int(windows))
        samples: Dict[Tuple[str, str], List[float]] = {}

        # Snapshot current strengths to restore later
        baseline_strengths: Dict[Tuple[str, str], float] = {}
        for u, targets in self.causal_graph.items():
            for v, meta in targets.items():
                try:
                    baseline_strengths[(u, v)] = float(meta.get("strength", 0.0)) if isinstance(meta, dict) else float(meta)
                except Exception:
                    baseline_strengths[(u, v)] = 0.0
        baseline_stats = dict(self.standardization_stats)

        # Read the strengths currently stored on *this* agent's graph.
        def _snapshot_strengths() -> Dict[Tuple[str, str], float]:
            snap: Dict[Tuple[str, str], float] = {}
            for u, targets in self.causal_graph.items():
                for v, meta in targets.items():
                    try:
                        snap[(u, v)] = float(meta.get("strength", 0.0)) if isinstance(meta, dict) else float(meta)
                    except Exception:
                        snap[(u, v)] = 0.0
            return snap

        # One bootstrap fit on an independent clone (used by the parallel
        # path so worker threads never mutate this agent's state).
        def _bootstrap_single(df_sample: "pd.DataFrame") -> Dict[Tuple[str, str], float]:
            # Use a shallow clone to avoid mutating main agent when running in parallel
            clone = CRCAAgent(
                variables=list(self.causal_graph.keys()),
                causal_edges=self.get_edges(),
                model_name=self.model_name,
                max_loops=self.causal_max_loops,
                enable_batch_predict=self.enable_batch_predict,
                max_batch_size=self.max_batch_size,
                bootstrap_workers=0,
                use_async=self.use_async,
                seed=self.seed,
            )
            clone.edge_sign_constraints = dict(self.edge_sign_constraints)
            clone.standardization_stats = dict(baseline_stats)
            try:
                clone.fit_from_dataframe(
                    df=df_sample,
                    variables=variables,
                    window=min(30, len(df_sample)),
                    decay_alpha=0.9,
                    ridge_lambda=0.0,
                    enforce_signs=True,
                )
                # Defined below; resolved at call time via the closure, so
                # the later definition is fine.
                return _snapshot_strengths_from_graph(clone.causal_graph)
            except Exception:
                return {}

        # Like _snapshot_strengths but for an arbitrary graph dict (the clone's).
        def _snapshot_strengths_from_graph(graph: Dict[str, Dict[str, Any]]) -> Dict[Tuple[str, str], float]:
            res: Dict[Tuple[str, str], float] = {}
            for u, targets in graph.items():
                for v, meta in targets.items():
                    try:
                        res[(u, v)] = float(meta.get("strength", 0.0)) if isinstance(meta, dict) else float(meta)
                    except Exception:
                        res[(u, v)] = 0.0
            return res

        use_parallel = self.bootstrap_workers > 0
        if use_parallel:
            import concurrent.futures
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.bootstrap_workers) as executor:
                futures = []
                for i in range(windows):
                    # Deterministic resample: seed offset by window index.
                    boot_df = usable.sample(n=len(usable), replace=True, random_state=self.seed + i)
                    futures.append(executor.submit(_bootstrap_single, boot_df))
                for fut in futures:
                    try:
                        res_strengths = fut.result()
                        for (u, v), w in res_strengths.items():
                            samples.setdefault((u, v), []).append(w)
                    except Exception:
                        continue
        else:
            # Serial path refits *this* agent in place; state is restored below.
            for i in range(windows):
                boot_df = usable.sample(n=len(usable), replace=True, random_state=self.seed + i)
                try:
                    self.fit_from_dataframe(
                        df=boot_df,
                        variables=variables,
                        window=min(30, len(boot_df)),
                        decay_alpha=0.9,
                        ridge_lambda=0.0,
                        enforce_signs=True,
                    )
                    for (u, v), w in _snapshot_strengths().items():
                        samples.setdefault((u, v), []).append(w)
                except Exception:
                    continue

        # Restore baseline strengths and stats
        for (u, v), w in baseline_strengths.items():
            if u in self.causal_graph and v in self.causal_graph[u]:
                self.causal_graph[u][v]["strength"] = w
        self.standardization_stats = baseline_stats

        # Two-sided quantile interval per edge from the bootstrap samples.
        edge_cis: Dict[str, Tuple[float, float]] = {}
        for (u, v), arr in samples.items():
            arr_np = np.array(arr)
            lo = float(np.quantile(arr_np, (1 - alpha) / 2))
            hi = float(np.quantile(arr_np, 1 - (1 - alpha) / 2))
            edge_cis[f"{u}->{v}"] = (lo, hi)
        return {"edge_cis": edge_cis, "samples": windows}
|
|
3421
|
+
|
|
3422
|
+
# ---- Optimization ----
|
|
3423
|
+
def gradient_based_intervention_optimization(
    self,
    initial_state: Dict[str, float],
    target: str,
    intervention_vars: List[str],
    constraints: Optional[Dict[str, Tuple[float, float]]] = None,
    method: str = "L-BFGS-B",
) -> Dict[str, Any]:
    """Maximize the predicted value of *target* over the intervention variables.

    Uses scipy's bounded minimizer on the negated SCM prediction. Variables
    without an explicit constraint get a +/-3 box around their current value.
    Returns the best intervention, its predicted target value, and solver
    diagnostics; on failure returns an ``error`` dict.
    """
    self._require_scipy()
    from scipy.optimize import minimize  # type: ignore

    if not intervention_vars:
        return {"error": "intervention_vars cannot be empty", "optimal_intervention": {}, "success": False}

    starting_point: List[float] = []
    box_bounds: List[Tuple[float, float]] = []
    for name in intervention_vars:
        current = float(initial_state.get(name, 0.0))
        starting_point.append(current)
        if constraints and name in constraints:
            box_bounds.append(constraints[name])
        else:
            box_bounds.append((current - 3.0, current + 3.0))

    def negated_target(x: np.ndarray) -> float:
        # scipy minimizes, so return the negative of the predicted target.
        proposal = {intervention_vars[i]: float(x[i]) for i in range(len(x))}
        predicted = self._predict_outcomes(initial_state, proposal)
        return -float(predicted.get(target, 0.0))

    try:
        result = minimize(
            negated_target,
            x0=np.array(starting_point, dtype=float),
            method=method,
            bounds=box_bounds,
            options={"maxiter": 100, "ftol": 1e-6},
        )
        best = {intervention_vars[i]: float(result.x[i]) for i in range(len(result.x))}
        best_outcome = self._predict_outcomes(initial_state, best)
        return {
            "optimal_intervention": best,
            "optimal_target_value": float(best_outcome.get(target, 0.0)),
            "objective_value": float(result.fun),
            "success": bool(result.success),
            "iterations": int(getattr(result, "nit", 0)),
            "convergence_message": str(result.message),
        }
    except Exception as e:
        logger.debug(f"gradient_based_intervention_optimization failed: {e}")
        return {"error": str(e), "optimal_intervention": {}, "success": False}
|
|
3473
|
+
|
|
3474
|
+
def bellman_optimal_intervention(
    self,
    initial_state: Dict[str, float],
    target: str,
    intervention_vars: List[str],
    horizon: int = 5,
    discount: float = 0.9,
) -> Dict[str, Any]:
    """Greedy multi-step intervention search toward maximizing *target*.

    At each of ``horizon`` steps, samples 10 random candidate interventions
    (normal draws around each variable's standardization stats) and commits
    the best one before advancing the state through the SCM.

    NOTE(review): ``discount`` is echoed back in the result but is not used
    by the greedy search itself — confirm whether discounting was intended.
    """
    if not intervention_vars:
        return {"error": "intervention_vars cannot be empty"}
    horizon = max(1, int(horizon))
    rng = np.random.default_rng(self.seed)
    state = dict(initial_state)
    chosen_steps: List[Dict[str, float]] = []

    def _target_value(s: Dict[str, float]) -> float:
        return float(s.get(target, 0.0))

    for _ in range(horizon):
        best_val = float("-inf")
        best_action: Dict[str, float] = {}
        for _ in range(10):
            action = {}
            for name in intervention_vars:
                st = self.standardization_stats.get(name, {"mean": state.get(name, 0.0), "std": 1.0})
                action[name] = float(rng.normal(st["mean"], st["std"]))
            candidate_val = _target_value(self._predict_outcomes(state, action))
            if candidate_val > best_val:
                best_val = candidate_val
                best_action = action
        if best_action:
            chosen_steps.append(best_action)
            state = self._predict_outcomes(state, best_action)

    return {
        "optimal_sequence": chosen_steps,
        "final_state": state,
        "total_value": float(state.get(target, 0.0)),
        "horizon": horizon,
        "discount_factor": float(discount),
    }
|
|
3516
|
+
|
|
3517
|
+
# ---- Time-series & causality ----
|
|
3518
|
+
def granger_causality_test(
    self,
    df: Any,
    var1: str,
    var2: str,
    max_lag: int = 4,
) -> Dict[str, Any]:
    """Test whether *var1* Granger-causes *var2* via a restricted/unrestricted F-test.

    Regresses var2 on its own lags (restricted) and on lags of both series
    (unrestricted), then compares residual sums of squares with an F test.
    Returns statistics on success, or an ``error`` dict on any failure.
    """
    self._require_pandas()
    if df is None or not isinstance(df, pd.DataFrame):
        return {"error": "Invalid data or variables"}
    # BUG FIX: missing columns previously raised KeyError from df[[...]],
    # instead of returning an error dict like every other failure mode here.
    if var1 not in df.columns or var2 not in df.columns:
        return {"error": "Invalid data or variables"}
    data = df[[var1, var2]].dropna()
    if len(data) < max_lag * 2 + 5:
        return {"error": "Insufficient data"}
    try:
        from scipy.stats import f as f_dist  # type: ignore
    except Exception:
        return {"error": "scipy f distribution not available"}

    n = len(data)
    y = data[var2].values
    # Build restricted (own lags) and unrestricted (both series) designs.
    Xr = []
    Xu = []
    for t in range(max_lag, n):
        lags_var2 = [data[var2].iloc[t - i] for i in range(1, max_lag + 1)]
        lags_var1 = [data[var1].iloc[t - i] for i in range(1, max_lag + 1)]
        Xr.append(lags_var2)
        Xu.append(lags_var2 + lags_var1)
    y_vec = np.array(y[max_lag:], dtype=float)
    Xr = np.array(Xr, dtype=float)
    Xu = np.array(Xu, dtype=float)

    def ols(X: np.ndarray, yv: np.ndarray) -> Tuple[np.ndarray, float]:
        # Pseudo-inverse OLS; returns coefficients and residual sum of squares.
        beta = np.linalg.pinv(X) @ yv
        y_pred = X @ beta
        rss = float(np.sum((yv - y_pred) ** 2))
        return beta, rss

    try:
        _, rss_r = ols(Xr, y_vec)
        _, rss_u = ols(Xu, y_vec)
        m = max_lag
        # Denominator degrees of freedom for the unrestricted model.
        df2 = len(y_vec) - 2 * m - 1
        if df2 <= 0 or rss_u <= 1e-12:
            return {"error": "Degenerate case in F-test"}
        f_stat = ((rss_r - rss_u) / m) / (rss_u / df2)
        p_value = float(1.0 - f_dist.cdf(f_stat, m, df2))
        return {
            "f_statistic": float(f_stat),
            "p_value": p_value,
            "granger_causes": p_value < 0.05,
            "max_lag": max_lag,
            "restricted_rss": rss_r,
            "unrestricted_rss": rss_u,
        }
    except Exception as e:
        return {"error": str(e)}
|
|
3575
|
+
|
|
3576
|
+
def vector_autoregression_estimation(
    self,
    df: Any,
    variables: List[str],
    max_lag: int = 2,
) -> Dict[str, Any]:
    """Estimate a VAR(max_lag) model by per-equation least squares.

    Each variable is regressed on ``max_lag`` lags of every variable
    (no intercept). Returns coefficient vectors, residuals, and shape
    metadata, or an ``error`` dict on bad input.
    """
    self._require_pandas()
    if df is None or not isinstance(df, pd.DataFrame):
        return {"error": "Invalid data"}
    # BUG FIX: missing columns previously raised KeyError from df[variables]
    # instead of returning an error dict.
    if not variables or any(v not in df.columns for v in variables):
        return {"error": "Invalid data"}
    data = df[variables].dropna()
    if len(data) < max_lag * len(variables) + 5:
        return {"error": "Insufficient data"}
    n_vars = len(variables)
    # Build the lagged design matrix and the contemporaneous targets.
    X_lag = []
    y_mat = []
    for t in range(max_lag, len(data)):
        y_mat.append([data[var].iloc[t] for var in variables])
        lag_row = []
        for lag in range(1, max_lag + 1):
            for var in variables:
                lag_row.append(data[var].iloc[t - lag])
        X_lag.append(lag_row)
    X = np.array(X_lag, dtype=float)
    Y = np.array(y_mat, dtype=float)
    coefficients: Dict[str, Any] = {}
    residuals = []
    for idx, var in enumerate(variables):
        y_vec = Y[:, idx]
        beta = np.linalg.pinv(X) @ y_vec
        residuals.append(y_vec - X @ beta)
        coefficients[var] = {"coefficients": beta.tolist()}
    residuals = np.array(residuals).T  # (n_obs, n_vars)
    return {
        "coefficient_matrices": coefficients,
        "residuals": residuals.tolist(),
        "n_observations": len(Y),
        "n_variables": n_vars,
        "max_lag": max_lag,
        "variables": variables,
    }
|
|
3619
|
+
|
|
3620
|
+
def compute_information_theoretic_measures(
    self,
    df: Any,
    variables: List[str],
) -> Dict[str, Any]:
    """
    Compute simple entropy and mutual information estimates using histograms.
    """
    self._require_pandas()
    if df is None or not isinstance(df, pd.DataFrame):
        return {"error": "Invalid data"}
    data = df[variables].dropna()
    if len(data) < 10:
        return {"error": "Insufficient data"}

    results: Dict[str, Any] = {"entropies": {}, "mutual_information": {}}
    # Marginal entropies (bits) from sqrt-rule histograms.
    for column in variables:
        if column not in data.columns:
            continue
        values = data[column].dropna()
        if len(values) < 5:
            continue
        bin_count = min(20, max(5, int(np.sqrt(len(values)))))
        counts, _ = np.histogram(values, bins=bin_count)
        counts = counts[counts > 0]
        p = counts / counts.sum()
        results["entropies"][column] = float(-np.sum(p * np.log2(p)))

    # Pairwise MI via I(X;Y) = H(X) + H(Y) - H(X,Y), clipped at zero.
    for i, first in enumerate(variables):
        if first not in results["entropies"]:
            continue
        for second in variables[i + 1:]:
            if second not in results["entropies"]:
                continue
            pair = data[[first, second]].dropna()
            if len(pair) < 5:
                continue
            bin_count = min(10, max(3, int(np.cbrt(len(pair)))))
            joint_counts, _, _ = np.histogram2d(pair[first], pair[second], bins=bin_count)
            joint_counts = joint_counts[joint_counts > 0]
            pj = joint_counts / joint_counts.sum()
            joint_entropy = -np.sum(pj * np.log2(pj))
            mi = results["entropies"][first] + results["entropies"][second] - float(joint_entropy)
            results["mutual_information"][f"{first};{second}"] = float(max(0.0, mi))

    return results
|
|
3668
|
+
|
|
3669
|
+
# ---- Bayesian & attribution ----
|
|
3670
|
+
def bayesian_edge_inference(
    self,
    df: Any,
    parent: str,
    child: str,
    prior_mu: float = 0.0,
    prior_sigma: float = 1.0,
) -> Dict[str, Any]:
    """Conjugate-normal posterior for the standardized parent->child slope.

    Computes the OLS slope on z-scored data, treats its residual variance as
    the likelihood noise, and combines it with a Normal(prior_mu, prior_sigma)
    prior. Records the prior on ``self.bayesian_priors`` and returns posterior
    summaries, or an ``error`` dict on bad input.
    """
    self._require_pandas()
    if df is None or not isinstance(df, pd.DataFrame):
        return {"error": "Invalid data"}
    if parent not in df.columns or child not in df.columns:
        return {"error": "Variables not found"}
    pair = df[[parent, child]].dropna()
    if len(pair) < 5:
        return {"error": "Insufficient data"}
    X = pair[parent].values.reshape(-1, 1)
    y = pair[child].values
    x_mu, x_sd = X.mean(), X.std() or 1.0
    y_mu, y_sd = y.mean(), y.std() or 1.0
    Xn = (X - x_mu) / x_sd
    yn = (y - y_mu) / y_sd
    # OLS slope on standardized data via the normal equations.
    beta_ols = float((np.linalg.pinv(Xn.T @ Xn) @ (Xn.T @ yn))[0])
    noise_var = float(np.var(yn - Xn @ np.array([beta_ols])))
    # Precision-weighted conjugate update (epsilon guards a zero-residual fit).
    like_precision = 1.0 / (noise_var + 1e-6)
    prior_precision = 1.0 / (prior_sigma ** 2)
    post_precision = prior_precision + like_precision * len(pair)
    post_mean = (prior_precision * prior_mu + like_precision * len(pair) * beta_ols) / post_precision
    post_sd = math.sqrt(1.0 / post_precision)
    self.bayesian_priors[(parent, child)] = {"mu": prior_mu, "sigma": prior_sigma}
    return {
        "posterior_mean": float(post_mean),
        "posterior_std": float(post_sd),
        "posterior_variance": float(post_sd ** 2),
        "credible_interval_95": (float(post_mean - 1.96 * post_sd), float(post_mean + 1.96 * post_sd)),
        "ols_estimate": float(beta_ols),
        "prior_mu": float(prior_mu),
        "prior_sigma": float(prior_sigma),
    }
|
|
3714
|
+
|
|
3715
|
+
def sensitivity_analysis(
    self,
    intervention: Dict[str, float],
    target: str,
    perturbation_size: float = 0.01,
) -> Dict[str, Any]:
    """Finite-difference sensitivity of *target* to each intervention variable.

    Bumps every intervention value by ``perturbation_size`` in turn,
    re-predicts, and reports forward-difference slopes plus elasticities
    (zero when the baseline target or the value is effectively zero).
    """
    baseline = self._predict_outcomes({}, intervention)
    baseline_target = baseline.get(target, 0.0)
    sensitivities: Dict[str, float] = {}
    elasticities: Dict[str, float] = {}
    for name, original_value in intervention.items():
        bumped = dict(intervention)
        bumped[name] = original_value + perturbation_size
        bumped_target = self._predict_outcomes({}, bumped).get(target, 0.0)
        slope = (bumped_target - baseline_target) / perturbation_size
        sensitivities[name] = float(slope)
        if abs(baseline_target) > 1e-6 and abs(original_value) > 1e-6:
            elasticities[name] = float(slope * (original_value / baseline_target))
        else:
            elasticities[name] = 0.0
    if sensitivities:
        top_name, top_slope = max(sensitivities.items(), key=lambda kv: abs(kv[1]))
    else:
        top_name, top_slope = None, 0.0
    magnitude = float(np.linalg.norm(list(sensitivities.values()))) if sensitivities else 0.0
    return {
        "sensitivities": sensitivities,
        "elasticities": elasticities,
        "total_sensitivity": magnitude,
        "most_influential_variable": top_name,
        "most_influential_sensitivity": float(top_slope),
    }
|
|
3745
|
+
|
|
3746
|
+
def deep_root_cause_analysis(
    self,
    problem_variable: str,
    max_depth: int = 20,
    min_path_strength: float = 0.01,
) -> Dict[str, Any]:
    """Trace ancestor paths into *problem_variable* and rank root causes.

    For each known ancestor, runs a BFS forward through the graph looking
    for paths that reach ``problem_variable`` within ``max_depth`` hops and
    whose multiplied edge strength stays above ``min_path_strength``.
    Exogenous roots (no parents) are ranked first, then by |path strength|,
    then by depth.

    NOTE(review): `visited` is shared across the whole BFS from one
    ancestor, so at most one path per intermediate node is explored per
    ancestor — alternative routes through a shared node are skipped.
    Confirm this pruning is intentional.
    """
    if problem_variable not in self.causal_graph:
        return {"error": f"Variable {problem_variable} not in causal graph"}
    # Candidate starting points: everything recorded as an ancestor.
    all_ancestors = list(self.causal_graph_reverse.get(problem_variable, []))
    root_causes: List[Dict[str, Any]] = []
    paths_to_problem: List[Dict[str, Any]] = []

    def path_strength(path: List[str]) -> float:
        # Product of edge strengths along the path; early-out to 0.0 as soon
        # as the running product drops below the significance threshold.
        prod = 1.0
        for i in range(len(path) - 1):
            u, v = path[i], path[i + 1]
            prod *= self._edge_strength(u, v)
            if abs(prod) < min_path_strength:
                return 0.0
        return prod

    for anc in all_ancestors:
        try:
            # BFS from this ancestor toward the problem variable.
            queue = [(anc, [anc])]
            visited = set()
            while queue:
                node, path = queue.pop(0)
                if len(path) - 1 > max_depth:
                    continue
                if node == problem_variable and len(path) > 1:
                    ps = path_strength(path)
                    if abs(ps) > 0:
                        root_causes.append({
                            "root_cause": path[0],
                            "path_to_problem": path,
                            "path_string": " -> ".join(path),
                            "path_strength": float(ps),
                            "depth": len(path) - 1,
                            # Exogenous = no parents, i.e. an ultimate cause.
                            "is_exogenous": len(self._get_parents(path[0])) == 0,
                        })
                        paths_to_problem.append({
                            "from": path[0],
                            "to": problem_variable,
                            "path": path,
                            "strength": float(ps),
                        })
                    continue
                for child in self._get_children(node):
                    if child not in visited:
                        visited.add(child)
                        queue.append((child, path + [child]))
        except Exception:
            continue

    # Exogenous first, then strongest paths, then shallowest.
    root_causes.sort(key=lambda x: (-x["is_exogenous"], -abs(x["path_strength"]), x["depth"]))
    ultimate_roots = [rc for rc in root_causes if rc.get("is_exogenous")]
    return {
        "problem_variable": problem_variable,
        "all_root_causes": root_causes[:20],
        "ultimate_root_causes": ultimate_roots[:10],
        "total_paths_found": len(paths_to_problem),
        "max_depth_reached": max([rc["depth"] for rc in root_causes] + [0]),
    }
|
|
3809
|
+
|
|
3810
|
+
def shapley_value_attribution(
|
|
3811
|
+
self,
|
|
3812
|
+
baseline_state: Dict[str, float],
|
|
3813
|
+
target_state: Dict[str, float],
|
|
3814
|
+
target: str,
|
|
3815
|
+
samples: int = 100,
|
|
3816
|
+
) -> Dict[str, Any]:
|
|
3817
|
+
variables = list(set(list(baseline_state.keys()) + list(target_state.keys())))
|
|
3818
|
+
n = len(variables)
|
|
3819
|
+
if n == 0:
|
|
3820
|
+
return {"shapley_values": {}, "normalized": {}, "total_attribution": 0.0}
|
|
3821
|
+
rng = np.random.default_rng(self.seed)
|
|
3822
|
+
contributions: Dict[str, float] = {v: 0.0 for v in variables}
|
|
3823
|
+
|
|
3824
|
+
def value(subset: List[str]) -> float:
|
|
3825
|
+
state = dict(baseline_state)
|
|
3826
|
+
for var in subset:
|
|
3827
|
+
if var in target_state:
|
|
3828
|
+
state[var] = target_state[var]
|
|
3829
|
+
outcome = self._predict_outcomes({}, state)
|
|
3830
|
+
return float(outcome.get(target, 0.0))
|
|
3831
|
+
|
|
3832
|
+
for _ in range(max(1, samples)):
|
|
3833
|
+
perm = list(variables)
|
|
3834
|
+
rng.shuffle(perm)
|
|
3835
|
+
cur_set: List[str] = []
|
|
3836
|
+
prev_val = value(cur_set)
|
|
3837
|
+
for v in perm:
|
|
3838
|
+
cur_set.append(v)
|
|
3839
|
+
new_val = value(cur_set)
|
|
3840
|
+
contributions[v] += new_val - prev_val
|
|
3841
|
+
prev_val = new_val
|
|
3842
|
+
|
|
3843
|
+
shapley_values = {k: v / float(samples) for k, v in contributions.items()}
|
|
3844
|
+
total = sum(abs(v) for v in shapley_values.values()) or 1.0
|
|
3845
|
+
normalized = {k: v / total for k, v in shapley_values.items()}
|
|
3846
|
+
return {
|
|
3847
|
+
"shapley_values": shapley_values,
|
|
3848
|
+
"normalized": normalized,
|
|
3849
|
+
"total_attribution": float(sum(abs(v) for v in shapley_values.values())),
|
|
3850
|
+
}
|
|
3851
|
+
|
|
3852
|
+
# ---- Multi-layer scenarios ----
|
|
3853
|
+
def multi_layer_whatif_analysis(
|
|
3854
|
+
self,
|
|
3855
|
+
scenarios: List[Dict[str, float]],
|
|
3856
|
+
depth: int = 3,
|
|
3857
|
+
) -> Dict[str, Any]:
|
|
3858
|
+
results: List[Dict[str, Any]] = []
|
|
3859
|
+
for scen in scenarios:
|
|
3860
|
+
layer1 = self._predict_outcomes({}, scen)
|
|
3861
|
+
affected = [k for k, v in layer1.items() if abs(v) > 0.01]
|
|
3862
|
+
layer2_scenarios = [{a: layer1.get(a, 0.0) * 1.2} for a in affected[:5]]
|
|
3863
|
+
layer2_results: List[Dict[str, Any]] = []
|
|
3864
|
+
for l2 in layer2_scenarios:
|
|
3865
|
+
l2_outcome = self._predict_outcomes(layer1, l2)
|
|
3866
|
+
layer2_results.append({"layer2_scenario": l2, "layer2_outcomes": l2_outcome})
|
|
3867
|
+
results.append({
|
|
3868
|
+
"scenario": scen,
|
|
3869
|
+
"layer1_direct_effects": layer1,
|
|
3870
|
+
"affected_variables": affected,
|
|
3871
|
+
"layer2_cascades": layer2_results,
|
|
3872
|
+
})
|
|
3873
|
+
return {"multi_layer_analysis": results, "summary": {"total_scenarios": len(results)}}
|
|
3874
|
+
|
|
3875
|
+
def explore_alternate_realities(
|
|
3876
|
+
self,
|
|
3877
|
+
factual_state: Dict[str, float],
|
|
3878
|
+
target_outcome: str,
|
|
3879
|
+
target_value: Optional[float] = None,
|
|
3880
|
+
max_realities: int = 50,
|
|
3881
|
+
max_interventions: int = 3,
|
|
3882
|
+
) -> Dict[str, Any]:
|
|
3883
|
+
rng = np.random.default_rng(self.seed)
|
|
3884
|
+
variables = list(factual_state.keys())
|
|
3885
|
+
realities: List[Dict[str, Any]] = []
|
|
3886
|
+
for _ in range(max_realities):
|
|
3887
|
+
num_int = rng.integers(1, max(2, max_interventions + 1))
|
|
3888
|
+
selected = rng.choice(variables, size=min(num_int, len(variables)), replace=False)
|
|
3889
|
+
intervention = {}
|
|
3890
|
+
for var in selected:
|
|
3891
|
+
stats = self.standardization_stats.get(var, {"mean": factual_state.get(var, 0.0), "std": 1.0})
|
|
3892
|
+
intervention[var] = float(rng.normal(stats["mean"], stats["std"] * 1.5))
|
|
3893
|
+
outcome = self._predict_outcomes(factual_state, intervention)
|
|
3894
|
+
target_val = outcome.get(target_outcome, 0.0)
|
|
3895
|
+
if target_value is not None:
|
|
3896
|
+
objective = -abs(target_val - target_value)
|
|
3897
|
+
else:
|
|
3898
|
+
objective = target_val
|
|
3899
|
+
realities.append({
|
|
3900
|
+
"interventions": intervention,
|
|
3901
|
+
"outcome": outcome,
|
|
3902
|
+
"target_value": float(target_val),
|
|
3903
|
+
"objective": float(objective),
|
|
3904
|
+
"delta_from_factual": float(target_val - factual_state.get(target_outcome, 0.0)),
|
|
3905
|
+
})
|
|
3906
|
+
realities.sort(key=lambda x: x["objective"], reverse=True)
|
|
3907
|
+
best = realities[0] if realities else None
|
|
3908
|
+
return {
|
|
3909
|
+
"factual_state": factual_state,
|
|
3910
|
+
"target_outcome": target_outcome,
|
|
3911
|
+
"target_value": target_value,
|
|
3912
|
+
"best_reality": best,
|
|
3913
|
+
"top_10_realities": realities[:10],
|
|
3914
|
+
"all_realities_explored": len(realities),
|
|
3915
|
+
"improvement_potential": (best["target_value"] - factual_state.get(target_outcome, 0.0)) if best else 0.0,
|
|
3916
|
+
}
|
|
3917
|
+
|
|
3918
|
+
# ---- Async wrappers ----
|
|
3919
|
+
async def run_async(
|
|
3920
|
+
self,
|
|
3921
|
+
task: Optional[Union[str, Any]] = None,
|
|
3922
|
+
initial_state: Optional[Any] = None,
|
|
3923
|
+
target_variables: Optional[List[str]] = None,
|
|
3924
|
+
max_steps: Union[int, str] = 1,
|
|
3925
|
+
**kwargs,
|
|
3926
|
+
) -> Dict[str, Any]:
|
|
3927
|
+
loop = asyncio.get_running_loop()
|
|
3928
|
+
return await loop.run_in_executor(None, lambda: self.run(task=task, initial_state=initial_state, target_variables=target_variables, max_steps=max_steps, **kwargs))
|
|
3929
|
+
|
|
3930
|
+
async def quantify_uncertainty_async(
|
|
3931
|
+
self,
|
|
3932
|
+
df: Any,
|
|
3933
|
+
variables: List[str],
|
|
3934
|
+
windows: int = 200,
|
|
3935
|
+
alpha: float = 0.95
|
|
3936
|
+
) -> Dict[str, Any]:
|
|
3937
|
+
loop = asyncio.get_running_loop()
|
|
3938
|
+
return await loop.run_in_executor(None, lambda: self.quantify_uncertainty(df=df, variables=variables, windows=windows, alpha=alpha))
|
|
3939
|
+
|
|
3940
|
+
async def granger_causality_test_async(
|
|
3941
|
+
self,
|
|
3942
|
+
df: Any,
|
|
3943
|
+
var1: str,
|
|
3944
|
+
var2: str,
|
|
3945
|
+
max_lag: int = 4,
|
|
3946
|
+
) -> Dict[str, Any]:
|
|
3947
|
+
loop = asyncio.get_running_loop()
|
|
3948
|
+
return await loop.run_in_executor(None, lambda: self.granger_causality_test(df=df, var1=var1, var2=var2, max_lag=max_lag))
|
|
3949
|
+
|
|
3950
|
+
async def vector_autoregression_estimation_async(
|
|
3951
|
+
self,
|
|
3952
|
+
df: Any,
|
|
3953
|
+
variables: List[str],
|
|
3954
|
+
max_lag: int = 2,
|
|
3955
|
+
) -> Dict[str, Any]:
|
|
3956
|
+
loop = asyncio.get_running_loop()
|
|
3957
|
+
return await loop.run_in_executor(None, lambda: self.vector_autoregression_estimation(df=df, variables=variables, max_lag=max_lag))
|
|
3958
|
+
|
|
3959
|
+
# =========================
|
|
3960
|
+
# Excel TUI Integration Methods
|
|
3961
|
+
# =========================
|
|
3962
|
+
|
|
3963
|
+
def _initialize_excel_tables(self) -> None:
|
|
3964
|
+
"""Initialize standard Excel tables."""
|
|
3965
|
+
if not self._excel_enabled or self._excel_tables is None:
|
|
3966
|
+
return
|
|
3967
|
+
|
|
3968
|
+
try:
|
|
3969
|
+
from crca_excel.core.standard_tables import initialize_standard_tables
|
|
3970
|
+
initialize_standard_tables(self._excel_tables)
|
|
3971
|
+
except Exception as e:
|
|
3972
|
+
logger.error(f"Error initializing standard tables: {e}")
|
|
3973
|
+
|
|
3974
|
+
def _link_causal_graph_to_tables(self) -> None:
|
|
3975
|
+
"""Link CRCA causal graph to Excel tables."""
|
|
3976
|
+
if not self._excel_enabled or self._excel_scm_bridge is None:
|
|
3977
|
+
return
|
|
3978
|
+
|
|
3979
|
+
try:
|
|
3980
|
+
self._excel_scm_bridge.link_causal_graph_to_tables(self.causal_graph)
|
|
3981
|
+
except Exception as e:
|
|
3982
|
+
logger.error(f"Error linking causal graph to tables: {e}")
|
|
3983
|
+
|
|
3984
|
+
def excel_edit_cell(self, table_name: str, row_key: Any, column_name: str, value: Any) -> None:
|
|
3985
|
+
"""
|
|
3986
|
+
Edit a cell in Excel tables and trigger recomputation.
|
|
3987
|
+
|
|
3988
|
+
Args:
|
|
3989
|
+
table_name: Table name
|
|
3990
|
+
row_key: Row key
|
|
3991
|
+
column_name: Column name
|
|
3992
|
+
value: Value to set
|
|
3993
|
+
"""
|
|
3994
|
+
if not self._excel_enabled or self._excel_tables is None:
|
|
3995
|
+
raise RuntimeError("Excel TUI not enabled")
|
|
3996
|
+
|
|
3997
|
+
self._excel_tables.set_cell(table_name, row_key, column_name, value)
|
|
3998
|
+
|
|
3999
|
+
# Trigger recomputation
|
|
4000
|
+
if self._excel_eval_engine:
|
|
4001
|
+
self._excel_eval_engine.recompute_dirty_cells()
|
|
4002
|
+
|
|
4003
|
+
def excel_apply_plan(self, plan: Dict[Tuple[str, Any, str], Any]) -> Dict[str, Any]:
|
|
4004
|
+
"""
|
|
4005
|
+
Apply a plan (set of interventions) to Excel tables.
|
|
4006
|
+
|
|
4007
|
+
Args:
|
|
4008
|
+
plan: Dictionary of (table_name, row_key, column_name) -> value
|
|
4009
|
+
|
|
4010
|
+
Returns:
|
|
4011
|
+
Dictionary with results
|
|
4012
|
+
"""
|
|
4013
|
+
if not self._excel_enabled or self._excel_scm_bridge is None:
|
|
4014
|
+
raise RuntimeError("Excel TUI not enabled")
|
|
4015
|
+
|
|
4016
|
+
# Convert plan format
|
|
4017
|
+
interventions = {
|
|
4018
|
+
(table_name, row_key, column_name): value
|
|
4019
|
+
for (table_name, row_key, column_name), value in plan.items()
|
|
4020
|
+
}
|
|
4021
|
+
|
|
4022
|
+
snapshot = self._excel_scm_bridge.do_intervention(interventions)
|
|
4023
|
+
|
|
4024
|
+
return {
|
|
4025
|
+
"snapshot": snapshot,
|
|
4026
|
+
"success": True
|
|
4027
|
+
}
|
|
4028
|
+
|
|
4029
|
+
def excel_generate_scenarios(
|
|
4030
|
+
self,
|
|
4031
|
+
base_interventions: Dict[Tuple[str, Any, str], Any],
|
|
4032
|
+
target_variables: List[Tuple[str, Any, str]],
|
|
4033
|
+
n_scenarios: int = 10
|
|
4034
|
+
) -> List[Dict[str, Any]]:
|
|
4035
|
+
"""
|
|
4036
|
+
Generate counterfactual scenarios.
|
|
4037
|
+
|
|
4038
|
+
Args:
|
|
4039
|
+
base_interventions: Base intervention set
|
|
4040
|
+
target_variables: Variables to vary
|
|
4041
|
+
n_scenarios: Number of scenarios
|
|
4042
|
+
|
|
4043
|
+
Returns:
|
|
4044
|
+
List of scenario dictionaries
|
|
4045
|
+
"""
|
|
4046
|
+
if not self._excel_enabled or self._excel_scm_bridge is None:
|
|
4047
|
+
raise RuntimeError("Excel TUI not enabled")
|
|
4048
|
+
|
|
4049
|
+
# Convert format
|
|
4050
|
+
base_interv = {
|
|
4051
|
+
(table_name, row_key, column_name): value
|
|
4052
|
+
for (table_name, row_key, column_name), value in base_interventions.items()
|
|
4053
|
+
}
|
|
4054
|
+
target_vars = [
|
|
4055
|
+
(table_name, row_key, column_name)
|
|
4056
|
+
for (table_name, row_key, column_name) in target_variables
|
|
4057
|
+
]
|
|
4058
|
+
|
|
4059
|
+
return self._excel_scm_bridge.generate_counterfactual_scenarios(
|
|
4060
|
+
base_interv,
|
|
4061
|
+
target_vars,
|
|
4062
|
+
n_scenarios=n_scenarios,
|
|
4063
|
+
seed=self.seed
|
|
4064
|
+
)
|
|
4065
|
+
|
|
4066
|
+
def excel_get_table(self, table_name: str):
|
|
4067
|
+
"""
|
|
4068
|
+
Get an Excel table.
|
|
4069
|
+
|
|
4070
|
+
Args:
|
|
4071
|
+
table_name: Table name
|
|
4072
|
+
|
|
4073
|
+
Returns:
|
|
4074
|
+
Table instance or None
|
|
4075
|
+
"""
|
|
4076
|
+
if not self._excel_enabled or self._excel_tables is None:
|
|
4077
|
+
return None
|
|
4078
|
+
|
|
4079
|
+
return self._excel_tables.get_table(table_name)
|
|
4080
|
+
|
|
4081
|
+
def excel_get_all_tables(self):
|
|
4082
|
+
"""
|
|
4083
|
+
Get all Excel tables.
|
|
4084
|
+
|
|
4085
|
+
Returns:
|
|
4086
|
+
Dictionary of table_name -> Table
|
|
4087
|
+
"""
|
|
4088
|
+
if not self._excel_enabled or self._excel_tables is None:
|
|
4089
|
+
return {}
|
|
4090
|
+
|
|
4091
|
+
return self._excel_tables.get_all_tables()
|
|
4092
|
+
|
|
4093
|
+
# =========================
|
|
4094
|
+
# Policy Engine Methods
|
|
4095
|
+
# =========================
|
|
4096
|
+
|
|
4097
|
+
def run_policy_loop(
|
|
4098
|
+
self,
|
|
4099
|
+
num_epochs: int,
|
|
4100
|
+
sensor_provider: Optional[Callable] = None,
|
|
4101
|
+
actuator: Optional[Callable] = None,
|
|
4102
|
+
start_epoch: int = 0
|
|
4103
|
+
) -> Dict[str, Any]:
|
|
4104
|
+
"""
|
|
4105
|
+
Execute temporal policy loop for specified number of epochs.
|
|
4106
|
+
|
|
4107
|
+
Args:
|
|
4108
|
+
num_epochs: Number of epochs to execute
|
|
4109
|
+
sensor_provider: Function that returns current state snapshot (Dict[str, float])
|
|
4110
|
+
If None, uses dummy sensor (all metrics = 0.0)
|
|
4111
|
+
actuator: Function that executes interventions (takes List[InterventionSpec])
|
|
4112
|
+
If None, interventions are logged but not executed
|
|
4113
|
+
start_epoch: Starting epoch number (default: 0)
|
|
4114
|
+
|
|
4115
|
+
Returns:
|
|
4116
|
+
Dict[str, Any]: Summary with decision hashes and epoch results
|
|
4117
|
+
|
|
4118
|
+
Raises:
|
|
4119
|
+
RuntimeError: If policy_mode is not enabled
|
|
4120
|
+
"""
|
|
4121
|
+
if not self.policy_mode or self.policy_loop is None:
|
|
4122
|
+
raise RuntimeError("Policy mode not enabled. Set policy_mode=True and provide policy.")
|
|
4123
|
+
|
|
4124
|
+
results = []
|
|
4125
|
+
decision_hashes = []
|
|
4126
|
+
|
|
4127
|
+
for epoch in range(start_epoch, start_epoch + num_epochs):
|
|
4128
|
+
try:
|
|
4129
|
+
epoch_result = self.policy_loop.run_epoch(
|
|
4130
|
+
epoch=epoch,
|
|
4131
|
+
sensor_provider=sensor_provider,
|
|
4132
|
+
actuator=actuator
|
|
4133
|
+
)
|
|
4134
|
+
results.append(epoch_result)
|
|
4135
|
+
decision_hashes.append(epoch_result.get("decision_hash", ""))
|
|
4136
|
+
logger.info(f"Epoch {epoch} completed: {len(epoch_result.get('interventions', []))} interventions")
|
|
4137
|
+
except Exception as e:
|
|
4138
|
+
logger.error(f"Error in epoch {epoch}: {e}")
|
|
4139
|
+
raise
|
|
4140
|
+
|
|
4141
|
+
return {
|
|
4142
|
+
"num_epochs": num_epochs,
|
|
4143
|
+
"start_epoch": start_epoch,
|
|
4144
|
+
"end_epoch": start_epoch + num_epochs - 1,
|
|
4145
|
+
"decision_hashes": decision_hashes,
|
|
4146
|
+
"epoch_results": results,
|
|
4147
|
+
"policy_hash": self.policy_loop.compiled_policy.policy_hash,
|
|
4148
|
+
"summary": {
|
|
4149
|
+
"total_interventions": sum(len(r.get("interventions", [])) for r in results),
|
|
4150
|
+
"conservative_mode_triggered": any(r.get("conservative_mode", False) for r in results),
|
|
4151
|
+
"drift_detected": any(r.get("cusum_stat", 0.0) > self.policy_loop.cusum_h for r in results)
|
|
4152
|
+
}
|
|
4153
|
+
}
|
|
4154
|
+
|
|
4155
|
+
|
|
4156
|
+
|