crca-1.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (501)
  1. .github/ISSUE_TEMPLATE/bug_report.md +65 -0
  2. .github/ISSUE_TEMPLATE/feature_request.md +41 -0
  3. .github/PULL_REQUEST_TEMPLATE.md +20 -0
  4. .github/workflows/publish-manual.yml +61 -0
  5. .github/workflows/publish.yml +64 -0
  6. .gitignore +214 -0
  7. CRCA.py +4156 -0
  8. LICENSE +201 -0
  9. MANIFEST.in +43 -0
  10. PKG-INFO +5035 -0
  11. README.md +4959 -0
  12. __init__.py +17 -0
  13. branches/CRCA-Q.py +2728 -0
  14. branches/crca_cg/corposwarm.py +9065 -0
  15. branches/crca_cg/fix_rancher_docker_creds.ps1 +155 -0
  16. branches/crca_cg/package.json +5 -0
  17. branches/crca_cg/test_bolt_integration.py +446 -0
  18. branches/crca_cg/test_corposwarm_comprehensive.py +773 -0
  19. branches/crca_cg/test_new_features.py +163 -0
  20. branches/crca_sd/__init__.py +149 -0
  21. branches/crca_sd/crca_sd_core.py +770 -0
  22. branches/crca_sd/crca_sd_governance.py +1325 -0
  23. branches/crca_sd/crca_sd_mpc.py +1130 -0
  24. branches/crca_sd/crca_sd_realtime.py +1844 -0
  25. branches/crca_sd/crca_sd_tui.py +1133 -0
  26. crca-1.4.0.dist-info/METADATA +5035 -0
  27. crca-1.4.0.dist-info/RECORD +501 -0
  28. crca-1.4.0.dist-info/WHEEL +4 -0
  29. crca-1.4.0.dist-info/licenses/LICENSE +201 -0
  30. docs/CRCA-Q.md +2333 -0
  31. examples/config.yaml.example +25 -0
  32. examples/crca_sd_example.py +513 -0
  33. examples/data_broker_example.py +294 -0
  34. examples/logistics_corporation.py +861 -0
  35. examples/palantir_example.py +299 -0
  36. examples/policy_bench.py +934 -0
  37. examples/pridnestrovia-sd.py +705 -0
  38. examples/pridnestrovia_realtime.py +1902 -0
  39. prompts/__init__.py +10 -0
  40. prompts/default_crca.py +101 -0
  41. pyproject.toml +151 -0
  42. requirements.txt +76 -0
  43. schemas/__init__.py +43 -0
  44. schemas/mcpSchemas.py +51 -0
  45. schemas/policy.py +458 -0
  46. templates/__init__.py +38 -0
  47. templates/base_specialized_agent.py +195 -0
  48. templates/drift_detection.py +325 -0
  49. templates/examples/causal_agent_template.py +309 -0
  50. templates/examples/drag_drop_example.py +213 -0
  51. templates/examples/logistics_agent_template.py +207 -0
  52. templates/examples/trading_agent_template.py +206 -0
  53. templates/feature_mixins.py +253 -0
  54. templates/graph_management.py +442 -0
  55. templates/llm_integration.py +194 -0
  56. templates/module_registry.py +276 -0
  57. templates/mpc_planner.py +280 -0
  58. templates/policy_loop.py +1168 -0
  59. templates/prediction_framework.py +448 -0
  60. templates/statistical_methods.py +778 -0
  61. tests/sanity.yml +31 -0
  62. tests/sanity_check +406 -0
  63. tests/test_core.py +47 -0
  64. tests/test_crca_excel.py +166 -0
  65. tests/test_crca_sd.py +780 -0
  66. tests/test_data_broker.py +424 -0
  67. tests/test_palantir.py +349 -0
  68. tools/__init__.py +38 -0
  69. tools/actuators.py +437 -0
  70. tools/bolt.diy/Dockerfile +103 -0
  71. tools/bolt.diy/app/components/@settings/core/AvatarDropdown.tsx +175 -0
  72. tools/bolt.diy/app/components/@settings/core/ControlPanel.tsx +345 -0
  73. tools/bolt.diy/app/components/@settings/core/constants.tsx +108 -0
  74. tools/bolt.diy/app/components/@settings/core/types.ts +114 -0
  75. tools/bolt.diy/app/components/@settings/index.ts +12 -0
  76. tools/bolt.diy/app/components/@settings/shared/components/TabTile.tsx +151 -0
  77. tools/bolt.diy/app/components/@settings/shared/service-integration/ConnectionForm.tsx +193 -0
  78. tools/bolt.diy/app/components/@settings/shared/service-integration/ConnectionTestIndicator.tsx +60 -0
  79. tools/bolt.diy/app/components/@settings/shared/service-integration/ErrorState.tsx +102 -0
  80. tools/bolt.diy/app/components/@settings/shared/service-integration/LoadingState.tsx +94 -0
  81. tools/bolt.diy/app/components/@settings/shared/service-integration/ServiceHeader.tsx +72 -0
  82. tools/bolt.diy/app/components/@settings/shared/service-integration/index.ts +6 -0
  83. tools/bolt.diy/app/components/@settings/tabs/data/DataTab.tsx +721 -0
  84. tools/bolt.diy/app/components/@settings/tabs/data/DataVisualization.tsx +384 -0
  85. tools/bolt.diy/app/components/@settings/tabs/event-logs/EventLogsTab.tsx +1013 -0
  86. tools/bolt.diy/app/components/@settings/tabs/features/FeaturesTab.tsx +295 -0
  87. tools/bolt.diy/app/components/@settings/tabs/github/GitHubTab.tsx +281 -0
  88. tools/bolt.diy/app/components/@settings/tabs/github/components/GitHubAuthDialog.tsx +173 -0
  89. tools/bolt.diy/app/components/@settings/tabs/github/components/GitHubCacheManager.tsx +367 -0
  90. tools/bolt.diy/app/components/@settings/tabs/github/components/GitHubConnection.tsx +233 -0
  91. tools/bolt.diy/app/components/@settings/tabs/github/components/GitHubErrorBoundary.tsx +105 -0
  92. tools/bolt.diy/app/components/@settings/tabs/github/components/GitHubProgressiveLoader.tsx +266 -0
  93. tools/bolt.diy/app/components/@settings/tabs/github/components/GitHubRepositoryCard.tsx +121 -0
  94. tools/bolt.diy/app/components/@settings/tabs/github/components/GitHubRepositorySelector.tsx +312 -0
  95. tools/bolt.diy/app/components/@settings/tabs/github/components/GitHubStats.tsx +291 -0
  96. tools/bolt.diy/app/components/@settings/tabs/github/components/GitHubUserProfile.tsx +46 -0
  97. tools/bolt.diy/app/components/@settings/tabs/github/components/shared/GitHubStateIndicators.tsx +264 -0
  98. tools/bolt.diy/app/components/@settings/tabs/github/components/shared/RepositoryCard.tsx +361 -0
  99. tools/bolt.diy/app/components/@settings/tabs/github/components/shared/index.ts +11 -0
  100. tools/bolt.diy/app/components/@settings/tabs/gitlab/GitLabTab.tsx +305 -0
  101. tools/bolt.diy/app/components/@settings/tabs/gitlab/components/GitLabAuthDialog.tsx +186 -0
  102. tools/bolt.diy/app/components/@settings/tabs/gitlab/components/GitLabConnection.tsx +253 -0
  103. tools/bolt.diy/app/components/@settings/tabs/gitlab/components/GitLabRepositorySelector.tsx +358 -0
  104. tools/bolt.diy/app/components/@settings/tabs/gitlab/components/RepositoryCard.tsx +79 -0
  105. tools/bolt.diy/app/components/@settings/tabs/gitlab/components/RepositoryList.tsx +142 -0
  106. tools/bolt.diy/app/components/@settings/tabs/gitlab/components/StatsDisplay.tsx +91 -0
  107. tools/bolt.diy/app/components/@settings/tabs/gitlab/components/index.ts +4 -0
  108. tools/bolt.diy/app/components/@settings/tabs/mcp/McpServerList.tsx +99 -0
  109. tools/bolt.diy/app/components/@settings/tabs/mcp/McpServerListItem.tsx +70 -0
  110. tools/bolt.diy/app/components/@settings/tabs/mcp/McpStatusBadge.tsx +37 -0
  111. tools/bolt.diy/app/components/@settings/tabs/mcp/McpTab.tsx +239 -0
  112. tools/bolt.diy/app/components/@settings/tabs/netlify/NetlifyTab.tsx +1393 -0
  113. tools/bolt.diy/app/components/@settings/tabs/netlify/components/NetlifyConnection.tsx +990 -0
  114. tools/bolt.diy/app/components/@settings/tabs/netlify/components/index.ts +1 -0
  115. tools/bolt.diy/app/components/@settings/tabs/notifications/NotificationsTab.tsx +300 -0
  116. tools/bolt.diy/app/components/@settings/tabs/profile/ProfileTab.tsx +181 -0
  117. tools/bolt.diy/app/components/@settings/tabs/providers/cloud/CloudProvidersTab.tsx +308 -0
  118. tools/bolt.diy/app/components/@settings/tabs/providers/local/ErrorBoundary.tsx +68 -0
  119. tools/bolt.diy/app/components/@settings/tabs/providers/local/HealthStatusBadge.tsx +64 -0
  120. tools/bolt.diy/app/components/@settings/tabs/providers/local/LoadingSkeleton.tsx +107 -0
  121. tools/bolt.diy/app/components/@settings/tabs/providers/local/LocalProvidersTab.tsx +556 -0
  122. tools/bolt.diy/app/components/@settings/tabs/providers/local/ModelCard.tsx +106 -0
  123. tools/bolt.diy/app/components/@settings/tabs/providers/local/ProviderCard.tsx +120 -0
  124. tools/bolt.diy/app/components/@settings/tabs/providers/local/SetupGuide.tsx +671 -0
  125. tools/bolt.diy/app/components/@settings/tabs/providers/local/StatusDashboard.tsx +91 -0
  126. tools/bolt.diy/app/components/@settings/tabs/providers/local/types.ts +44 -0
  127. tools/bolt.diy/app/components/@settings/tabs/settings/SettingsTab.tsx +215 -0
  128. tools/bolt.diy/app/components/@settings/tabs/supabase/SupabaseTab.tsx +1089 -0
  129. tools/bolt.diy/app/components/@settings/tabs/vercel/VercelTab.tsx +909 -0
  130. tools/bolt.diy/app/components/@settings/tabs/vercel/components/VercelConnection.tsx +368 -0
  131. tools/bolt.diy/app/components/@settings/tabs/vercel/components/index.ts +1 -0
  132. tools/bolt.diy/app/components/@settings/utils/tab-helpers.ts +54 -0
  133. tools/bolt.diy/app/components/chat/APIKeyManager.tsx +169 -0
  134. tools/bolt.diy/app/components/chat/Artifact.tsx +296 -0
  135. tools/bolt.diy/app/components/chat/AssistantMessage.tsx +192 -0
  136. tools/bolt.diy/app/components/chat/BaseChat.module.scss +47 -0
  137. tools/bolt.diy/app/components/chat/BaseChat.tsx +522 -0
  138. tools/bolt.diy/app/components/chat/Chat.client.tsx +670 -0
  139. tools/bolt.diy/app/components/chat/ChatAlert.tsx +108 -0
  140. tools/bolt.diy/app/components/chat/ChatBox.tsx +334 -0
  141. tools/bolt.diy/app/components/chat/CodeBlock.module.scss +10 -0
  142. tools/bolt.diy/app/components/chat/CodeBlock.tsx +85 -0
  143. tools/bolt.diy/app/components/chat/DicussMode.tsx +17 -0
  144. tools/bolt.diy/app/components/chat/ExamplePrompts.tsx +37 -0
  145. tools/bolt.diy/app/components/chat/FilePreview.tsx +38 -0
  146. tools/bolt.diy/app/components/chat/GitCloneButton.tsx +327 -0
  147. tools/bolt.diy/app/components/chat/ImportFolderButton.tsx +141 -0
  148. tools/bolt.diy/app/components/chat/LLMApiAlert.tsx +109 -0
  149. tools/bolt.diy/app/components/chat/MCPTools.tsx +129 -0
  150. tools/bolt.diy/app/components/chat/Markdown.module.scss +171 -0
  151. tools/bolt.diy/app/components/chat/Markdown.spec.ts +48 -0
  152. tools/bolt.diy/app/components/chat/Markdown.tsx +252 -0
  153. tools/bolt.diy/app/components/chat/Messages.client.tsx +102 -0
  154. tools/bolt.diy/app/components/chat/ModelSelector.tsx +797 -0
  155. tools/bolt.diy/app/components/chat/NetlifyDeploymentLink.client.tsx +51 -0
  156. tools/bolt.diy/app/components/chat/ProgressCompilation.tsx +110 -0
  157. tools/bolt.diy/app/components/chat/ScreenshotStateManager.tsx +33 -0
  158. tools/bolt.diy/app/components/chat/SendButton.client.tsx +39 -0
  159. tools/bolt.diy/app/components/chat/SpeechRecognition.tsx +28 -0
  160. tools/bolt.diy/app/components/chat/StarterTemplates.tsx +38 -0
  161. tools/bolt.diy/app/components/chat/SupabaseAlert.tsx +199 -0
  162. tools/bolt.diy/app/components/chat/SupabaseConnection.tsx +339 -0
  163. tools/bolt.diy/app/components/chat/ThoughtBox.tsx +43 -0
  164. tools/bolt.diy/app/components/chat/ToolInvocations.tsx +409 -0
  165. tools/bolt.diy/app/components/chat/UserMessage.tsx +101 -0
  166. tools/bolt.diy/app/components/chat/VercelDeploymentLink.client.tsx +158 -0
  167. tools/bolt.diy/app/components/chat/chatExportAndImport/ExportChatButton.tsx +49 -0
  168. tools/bolt.diy/app/components/chat/chatExportAndImport/ImportButtons.tsx +96 -0
  169. tools/bolt.diy/app/components/deploy/DeployAlert.tsx +197 -0
  170. tools/bolt.diy/app/components/deploy/DeployButton.tsx +277 -0
  171. tools/bolt.diy/app/components/deploy/GitHubDeploy.client.tsx +171 -0
  172. tools/bolt.diy/app/components/deploy/GitHubDeploymentDialog.tsx +1041 -0
  173. tools/bolt.diy/app/components/deploy/GitLabDeploy.client.tsx +171 -0
  174. tools/bolt.diy/app/components/deploy/GitLabDeploymentDialog.tsx +764 -0
  175. tools/bolt.diy/app/components/deploy/NetlifyDeploy.client.tsx +246 -0
  176. tools/bolt.diy/app/components/deploy/VercelDeploy.client.tsx +235 -0
  177. tools/bolt.diy/app/components/editor/codemirror/BinaryContent.tsx +7 -0
  178. tools/bolt.diy/app/components/editor/codemirror/CodeMirrorEditor.tsx +555 -0
  179. tools/bolt.diy/app/components/editor/codemirror/EnvMasking.ts +80 -0
  180. tools/bolt.diy/app/components/editor/codemirror/cm-theme.ts +192 -0
  181. tools/bolt.diy/app/components/editor/codemirror/indent.ts +68 -0
  182. tools/bolt.diy/app/components/editor/codemirror/languages.ts +112 -0
  183. tools/bolt.diy/app/components/git/GitUrlImport.client.tsx +147 -0
  184. tools/bolt.diy/app/components/header/Header.tsx +42 -0
  185. tools/bolt.diy/app/components/header/HeaderActionButtons.client.tsx +54 -0
  186. tools/bolt.diy/app/components/mandate/MandateSubmission.tsx +167 -0
  187. tools/bolt.diy/app/components/observability/DeploymentStatus.tsx +168 -0
  188. tools/bolt.diy/app/components/observability/EventTimeline.tsx +119 -0
  189. tools/bolt.diy/app/components/observability/FileDiffViewer.tsx +121 -0
  190. tools/bolt.diy/app/components/observability/GovernanceStatus.tsx +197 -0
  191. tools/bolt.diy/app/components/observability/GovernorMetrics.tsx +246 -0
  192. tools/bolt.diy/app/components/observability/LogStream.tsx +244 -0
  193. tools/bolt.diy/app/components/observability/MandateDetails.tsx +201 -0
  194. tools/bolt.diy/app/components/observability/ObservabilityDashboard.tsx +200 -0
  195. tools/bolt.diy/app/components/sidebar/HistoryItem.tsx +187 -0
  196. tools/bolt.diy/app/components/sidebar/Menu.client.tsx +536 -0
  197. tools/bolt.diy/app/components/sidebar/date-binning.ts +59 -0
  198. tools/bolt.diy/app/components/txt +1 -0
  199. tools/bolt.diy/app/components/ui/BackgroundRays/index.tsx +18 -0
  200. tools/bolt.diy/app/components/ui/BackgroundRays/styles.module.scss +246 -0
  201. tools/bolt.diy/app/components/ui/Badge.tsx +53 -0
  202. tools/bolt.diy/app/components/ui/BranchSelector.tsx +270 -0
  203. tools/bolt.diy/app/components/ui/Breadcrumbs.tsx +101 -0
  204. tools/bolt.diy/app/components/ui/Button.tsx +46 -0
  205. tools/bolt.diy/app/components/ui/Card.tsx +55 -0
  206. tools/bolt.diy/app/components/ui/Checkbox.tsx +32 -0
  207. tools/bolt.diy/app/components/ui/CloseButton.tsx +49 -0
  208. tools/bolt.diy/app/components/ui/CodeBlock.tsx +103 -0
  209. tools/bolt.diy/app/components/ui/Collapsible.tsx +9 -0
  210. tools/bolt.diy/app/components/ui/ColorSchemeDialog.tsx +378 -0
  211. tools/bolt.diy/app/components/ui/Dialog.tsx +449 -0
  212. tools/bolt.diy/app/components/ui/Dropdown.tsx +63 -0
  213. tools/bolt.diy/app/components/ui/EmptyState.tsx +154 -0
  214. tools/bolt.diy/app/components/ui/FileIcon.tsx +346 -0
  215. tools/bolt.diy/app/components/ui/FilterChip.tsx +92 -0
  216. tools/bolt.diy/app/components/ui/GlowingEffect.tsx +192 -0
  217. tools/bolt.diy/app/components/ui/GradientCard.tsx +100 -0
  218. tools/bolt.diy/app/components/ui/IconButton.tsx +84 -0
  219. tools/bolt.diy/app/components/ui/Input.tsx +22 -0
  220. tools/bolt.diy/app/components/ui/Label.tsx +20 -0
  221. tools/bolt.diy/app/components/ui/LoadingDots.tsx +27 -0
  222. tools/bolt.diy/app/components/ui/LoadingOverlay.tsx +32 -0
  223. tools/bolt.diy/app/components/ui/PanelHeader.tsx +20 -0
  224. tools/bolt.diy/app/components/ui/PanelHeaderButton.tsx +36 -0
  225. tools/bolt.diy/app/components/ui/Popover.tsx +29 -0
  226. tools/bolt.diy/app/components/ui/Progress.tsx +22 -0
  227. tools/bolt.diy/app/components/ui/RepositoryStats.tsx +87 -0
  228. tools/bolt.diy/app/components/ui/ScrollArea.tsx +41 -0
  229. tools/bolt.diy/app/components/ui/SearchInput.tsx +80 -0
  230. tools/bolt.diy/app/components/ui/SearchResultItem.tsx +134 -0
  231. tools/bolt.diy/app/components/ui/Separator.tsx +22 -0
  232. tools/bolt.diy/app/components/ui/SettingsButton.tsx +35 -0
  233. tools/bolt.diy/app/components/ui/Slider.tsx +73 -0
  234. tools/bolt.diy/app/components/ui/StatusIndicator.tsx +90 -0
  235. tools/bolt.diy/app/components/ui/Switch.tsx +37 -0
  236. tools/bolt.diy/app/components/ui/Tabs.tsx +52 -0
  237. tools/bolt.diy/app/components/ui/TabsWithSlider.tsx +112 -0
  238. tools/bolt.diy/app/components/ui/ThemeSwitch.tsx +29 -0
  239. tools/bolt.diy/app/components/ui/Tooltip.tsx +122 -0
  240. tools/bolt.diy/app/components/ui/index.ts +38 -0
  241. tools/bolt.diy/app/components/ui/use-toast.ts +66 -0
  242. tools/bolt.diy/app/components/workbench/DiffView.tsx +796 -0
  243. tools/bolt.diy/app/components/workbench/EditorPanel.tsx +174 -0
  244. tools/bolt.diy/app/components/workbench/ExpoQrModal.tsx +55 -0
  245. tools/bolt.diy/app/components/workbench/FileBreadcrumb.tsx +150 -0
  246. tools/bolt.diy/app/components/workbench/FileTree.tsx +565 -0
  247. tools/bolt.diy/app/components/workbench/Inspector.tsx +126 -0
  248. tools/bolt.diy/app/components/workbench/InspectorPanel.tsx +146 -0
  249. tools/bolt.diy/app/components/workbench/LockManager.tsx +262 -0
  250. tools/bolt.diy/app/components/workbench/PortDropdown.tsx +91 -0
  251. tools/bolt.diy/app/components/workbench/Preview.tsx +1049 -0
  252. tools/bolt.diy/app/components/workbench/ScreenshotSelector.tsx +293 -0
  253. tools/bolt.diy/app/components/workbench/Search.tsx +257 -0
  254. tools/bolt.diy/app/components/workbench/Workbench.client.tsx +506 -0
  255. tools/bolt.diy/app/components/workbench/terminal/Terminal.tsx +131 -0
  256. tools/bolt.diy/app/components/workbench/terminal/TerminalManager.tsx +68 -0
  257. tools/bolt.diy/app/components/workbench/terminal/TerminalTabs.tsx +277 -0
  258. tools/bolt.diy/app/components/workbench/terminal/theme.ts +36 -0
  259. tools/bolt.diy/app/components/workflow/WorkflowPhase.tsx +109 -0
  260. tools/bolt.diy/app/components/workflow/WorkflowStatus.tsx +60 -0
  261. tools/bolt.diy/app/components/workflow/WorkflowTimeline.tsx +150 -0
  262. tools/bolt.diy/app/entry.client.tsx +7 -0
  263. tools/bolt.diy/app/entry.server.tsx +80 -0
  264. tools/bolt.diy/app/root.tsx +156 -0
  265. tools/bolt.diy/app/routes/_index.tsx +175 -0
  266. tools/bolt.diy/app/routes/api.bug-report.ts +254 -0
  267. tools/bolt.diy/app/routes/api.chat.ts +463 -0
  268. tools/bolt.diy/app/routes/api.check-env-key.ts +41 -0
  269. tools/bolt.diy/app/routes/api.configured-providers.ts +110 -0
  270. tools/bolt.diy/app/routes/api.corporate-swarm-status.ts +55 -0
  271. tools/bolt.diy/app/routes/api.enhancer.ts +137 -0
  272. tools/bolt.diy/app/routes/api.export-api-keys.ts +44 -0
  273. tools/bolt.diy/app/routes/api.git-info.ts +69 -0
  274. tools/bolt.diy/app/routes/api.git-proxy.$.ts +178 -0
  275. tools/bolt.diy/app/routes/api.github-branches.ts +166 -0
  276. tools/bolt.diy/app/routes/api.github-deploy.ts +67 -0
  277. tools/bolt.diy/app/routes/api.github-stats.ts +198 -0
  278. tools/bolt.diy/app/routes/api.github-template.ts +242 -0
  279. tools/bolt.diy/app/routes/api.github-user.ts +287 -0
  280. tools/bolt.diy/app/routes/api.gitlab-branches.ts +143 -0
  281. tools/bolt.diy/app/routes/api.gitlab-deploy.ts +67 -0
  282. tools/bolt.diy/app/routes/api.gitlab-projects.ts +105 -0
  283. tools/bolt.diy/app/routes/api.health.ts +8 -0
  284. tools/bolt.diy/app/routes/api.llmcall.ts +298 -0
  285. tools/bolt.diy/app/routes/api.mandate.ts +351 -0
  286. tools/bolt.diy/app/routes/api.mcp-check.ts +16 -0
  287. tools/bolt.diy/app/routes/api.mcp-update-config.ts +23 -0
  288. tools/bolt.diy/app/routes/api.models.$provider.ts +2 -0
  289. tools/bolt.diy/app/routes/api.models.ts +90 -0
  290. tools/bolt.diy/app/routes/api.netlify-deploy.ts +240 -0
  291. tools/bolt.diy/app/routes/api.netlify-user.ts +142 -0
  292. tools/bolt.diy/app/routes/api.supabase-user.ts +199 -0
  293. tools/bolt.diy/app/routes/api.supabase.query.ts +92 -0
  294. tools/bolt.diy/app/routes/api.supabase.ts +56 -0
  295. tools/bolt.diy/app/routes/api.supabase.variables.ts +32 -0
  296. tools/bolt.diy/app/routes/api.system.diagnostics.ts +142 -0
  297. tools/bolt.diy/app/routes/api.system.disk-info.ts +311 -0
  298. tools/bolt.diy/app/routes/api.system.git-info.ts +332 -0
  299. tools/bolt.diy/app/routes/api.update.ts +21 -0
  300. tools/bolt.diy/app/routes/api.vercel-deploy.ts +497 -0
  301. tools/bolt.diy/app/routes/api.vercel-user.ts +161 -0
  302. tools/bolt.diy/app/routes/api.workflow-status.$proposalId.ts +309 -0
  303. tools/bolt.diy/app/routes/chat.$id.tsx +8 -0
  304. tools/bolt.diy/app/routes/execute.$mandateId.tsx +432 -0
  305. tools/bolt.diy/app/routes/git.tsx +25 -0
  306. tools/bolt.diy/app/routes/observability.$mandateId.tsx +50 -0
  307. tools/bolt.diy/app/routes/webcontainer.connect.$id.tsx +32 -0
  308. tools/bolt.diy/app/routes/webcontainer.preview.$id.tsx +97 -0
  309. tools/bolt.diy/app/routes/workflow.$proposalId.tsx +170 -0
  310. tools/bolt.diy/app/styles/animations.scss +49 -0
  311. tools/bolt.diy/app/styles/components/code.scss +9 -0
  312. tools/bolt.diy/app/styles/components/editor.scss +135 -0
  313. tools/bolt.diy/app/styles/components/resize-handle.scss +30 -0
  314. tools/bolt.diy/app/styles/components/terminal.scss +3 -0
  315. tools/bolt.diy/app/styles/components/toast.scss +23 -0
  316. tools/bolt.diy/app/styles/diff-view.css +72 -0
  317. tools/bolt.diy/app/styles/index.scss +73 -0
  318. tools/bolt.diy/app/styles/variables.scss +255 -0
  319. tools/bolt.diy/app/styles/z-index.scss +37 -0
  320. tools/bolt.diy/app/types/GitHub.ts +182 -0
  321. tools/bolt.diy/app/types/GitLab.ts +103 -0
  322. tools/bolt.diy/app/types/actions.ts +85 -0
  323. tools/bolt.diy/app/types/artifact.ts +5 -0
  324. tools/bolt.diy/app/types/context.ts +26 -0
  325. tools/bolt.diy/app/types/design-scheme.ts +93 -0
  326. tools/bolt.diy/app/types/global.d.ts +13 -0
  327. tools/bolt.diy/app/types/mandate.ts +333 -0
  328. tools/bolt.diy/app/types/model.ts +25 -0
  329. tools/bolt.diy/app/types/netlify.ts +94 -0
  330. tools/bolt.diy/app/types/supabase.ts +54 -0
  331. tools/bolt.diy/app/types/template.ts +8 -0
  332. tools/bolt.diy/app/types/terminal.ts +9 -0
  333. tools/bolt.diy/app/types/theme.ts +1 -0
  334. tools/bolt.diy/app/types/vercel.ts +67 -0
  335. tools/bolt.diy/app/utils/buffer.ts +29 -0
  336. tools/bolt.diy/app/utils/classNames.ts +65 -0
  337. tools/bolt.diy/app/utils/constants.ts +147 -0
  338. tools/bolt.diy/app/utils/debounce.ts +13 -0
  339. tools/bolt.diy/app/utils/debugLogger.ts +1284 -0
  340. tools/bolt.diy/app/utils/diff.spec.ts +11 -0
  341. tools/bolt.diy/app/utils/diff.ts +117 -0
  342. tools/bolt.diy/app/utils/easings.ts +3 -0
  343. tools/bolt.diy/app/utils/fileLocks.ts +96 -0
  344. tools/bolt.diy/app/utils/fileUtils.ts +121 -0
  345. tools/bolt.diy/app/utils/folderImport.ts +73 -0
  346. tools/bolt.diy/app/utils/formatSize.ts +12 -0
  347. tools/bolt.diy/app/utils/getLanguageFromExtension.ts +24 -0
  348. tools/bolt.diy/app/utils/githubStats.ts +9 -0
  349. tools/bolt.diy/app/utils/gitlabStats.ts +54 -0
  350. tools/bolt.diy/app/utils/logger.ts +162 -0
  351. tools/bolt.diy/app/utils/markdown.ts +155 -0
  352. tools/bolt.diy/app/utils/mobile.ts +4 -0
  353. tools/bolt.diy/app/utils/os.ts +4 -0
  354. tools/bolt.diy/app/utils/path.ts +19 -0
  355. tools/bolt.diy/app/utils/projectCommands.ts +197 -0
  356. tools/bolt.diy/app/utils/promises.ts +19 -0
  357. tools/bolt.diy/app/utils/react.ts +6 -0
  358. tools/bolt.diy/app/utils/sampler.ts +49 -0
  359. tools/bolt.diy/app/utils/selectStarterTemplate.ts +255 -0
  360. tools/bolt.diy/app/utils/shell.ts +384 -0
  361. tools/bolt.diy/app/utils/stacktrace.ts +27 -0
  362. tools/bolt.diy/app/utils/stripIndent.ts +23 -0
  363. tools/bolt.diy/app/utils/terminal.ts +11 -0
  364. tools/bolt.diy/app/utils/unreachable.ts +3 -0
  365. tools/bolt.diy/app/vite-env.d.ts +2 -0
  366. tools/bolt.diy/assets/entitlements.mac.plist +25 -0
  367. tools/bolt.diy/assets/icons/icon.icns +0 -0
  368. tools/bolt.diy/assets/icons/icon.ico +0 -0
  369. tools/bolt.diy/assets/icons/icon.png +0 -0
  370. tools/bolt.diy/bindings.js +78 -0
  371. tools/bolt.diy/bindings.sh +33 -0
  372. tools/bolt.diy/docker-compose.yaml +145 -0
  373. tools/bolt.diy/electron/main/index.ts +201 -0
  374. tools/bolt.diy/electron/main/tsconfig.json +30 -0
  375. tools/bolt.diy/electron/main/ui/menu.ts +29 -0
  376. tools/bolt.diy/electron/main/ui/window.ts +54 -0
  377. tools/bolt.diy/electron/main/utils/auto-update.ts +110 -0
  378. tools/bolt.diy/electron/main/utils/constants.ts +4 -0
  379. tools/bolt.diy/electron/main/utils/cookie.ts +40 -0
  380. tools/bolt.diy/electron/main/utils/reload.ts +35 -0
  381. tools/bolt.diy/electron/main/utils/serve.ts +71 -0
  382. tools/bolt.diy/electron/main/utils/store.ts +3 -0
  383. tools/bolt.diy/electron/main/utils/vite-server.ts +44 -0
  384. tools/bolt.diy/electron/main/vite.config.ts +44 -0
  385. tools/bolt.diy/electron/preload/index.ts +22 -0
  386. tools/bolt.diy/electron/preload/tsconfig.json +7 -0
  387. tools/bolt.diy/electron/preload/vite.config.ts +31 -0
  388. tools/bolt.diy/electron-builder.yml +64 -0
  389. tools/bolt.diy/electron-update.yml +4 -0
  390. tools/bolt.diy/eslint.config.mjs +57 -0
  391. tools/bolt.diy/functions/[[path]].ts +12 -0
  392. tools/bolt.diy/icons/angular.svg +1 -0
  393. tools/bolt.diy/icons/astro.svg +8 -0
  394. tools/bolt.diy/icons/chat.svg +1 -0
  395. tools/bolt.diy/icons/expo-brand.svg +1 -0
  396. tools/bolt.diy/icons/expo.svg +4 -0
  397. tools/bolt.diy/icons/logo-text.svg +1 -0
  398. tools/bolt.diy/icons/logo.svg +4 -0
  399. tools/bolt.diy/icons/mcp.svg +1 -0
  400. tools/bolt.diy/icons/nativescript.svg +1 -0
  401. tools/bolt.diy/icons/netlify.svg +10 -0
  402. tools/bolt.diy/icons/nextjs.svg +1 -0
  403. tools/bolt.diy/icons/nuxt.svg +1 -0
  404. tools/bolt.diy/icons/qwik.svg +1 -0
  405. tools/bolt.diy/icons/react.svg +1 -0
  406. tools/bolt.diy/icons/remix.svg +24 -0
  407. tools/bolt.diy/icons/remotion.svg +1 -0
  408. tools/bolt.diy/icons/shadcn.svg +21 -0
  409. tools/bolt.diy/icons/slidev.svg +60 -0
  410. tools/bolt.diy/icons/solidjs.svg +1 -0
  411. tools/bolt.diy/icons/stars.svg +1 -0
  412. tools/bolt.diy/icons/svelte.svg +1 -0
  413. tools/bolt.diy/icons/typescript.svg +1 -0
  414. tools/bolt.diy/icons/vite.svg +1 -0
  415. tools/bolt.diy/icons/vue.svg +1 -0
  416. tools/bolt.diy/load-context.ts +9 -0
  417. tools/bolt.diy/notarize.cjs +31 -0
  418. tools/bolt.diy/package.json +218 -0
  419. tools/bolt.diy/playwright.config.preview.ts +35 -0
  420. tools/bolt.diy/pre-start.cjs +26 -0
  421. tools/bolt.diy/public/apple-touch-icon-precomposed.png +0 -0
  422. tools/bolt.diy/public/apple-touch-icon.png +0 -0
  423. tools/bolt.diy/public/favicon.ico +0 -0
  424. tools/bolt.diy/public/favicon.svg +4 -0
  425. tools/bolt.diy/public/icons/AmazonBedrock.svg +1 -0
  426. tools/bolt.diy/public/icons/Anthropic.svg +4 -0
  427. tools/bolt.diy/public/icons/Cohere.svg +4 -0
  428. tools/bolt.diy/public/icons/Deepseek.svg +5 -0
  429. tools/bolt.diy/public/icons/Default.svg +4 -0
  430. tools/bolt.diy/public/icons/Google.svg +4 -0
  431. tools/bolt.diy/public/icons/Groq.svg +4 -0
  432. tools/bolt.diy/public/icons/HuggingFace.svg +4 -0
  433. tools/bolt.diy/public/icons/Hyperbolic.svg +3 -0
  434. tools/bolt.diy/public/icons/LMStudio.svg +5 -0
  435. tools/bolt.diy/public/icons/Mistral.svg +4 -0
  436. tools/bolt.diy/public/icons/Ollama.svg +4 -0
  437. tools/bolt.diy/public/icons/OpenAI.svg +4 -0
  438. tools/bolt.diy/public/icons/OpenAILike.svg +4 -0
  439. tools/bolt.diy/public/icons/OpenRouter.svg +4 -0
  440. tools/bolt.diy/public/icons/Perplexity.svg +4 -0
  441. tools/bolt.diy/public/icons/Together.svg +4 -0
  442. tools/bolt.diy/public/icons/xAI.svg +5 -0
  443. tools/bolt.diy/public/inspector-script.js +292 -0
  444. tools/bolt.diy/public/logo-dark-styled.png +0 -0
  445. tools/bolt.diy/public/logo-dark.png +0 -0
  446. tools/bolt.diy/public/logo-light-styled.png +0 -0
  447. tools/bolt.diy/public/logo-light.png +0 -0
  448. tools/bolt.diy/public/logo.svg +15 -0
  449. tools/bolt.diy/public/social_preview_index.jpg +0 -0
  450. tools/bolt.diy/scripts/clean.js +45 -0
  451. tools/bolt.diy/scripts/electron-dev.mjs +181 -0
  452. tools/bolt.diy/scripts/setup-env.sh +41 -0
  453. tools/bolt.diy/scripts/update-imports.sh +7 -0
  454. tools/bolt.diy/scripts/update.sh +52 -0
  455. tools/bolt.diy/services/execution-governor/Dockerfile +41 -0
  456. tools/bolt.diy/services/execution-governor/config.ts +42 -0
  457. tools/bolt.diy/services/execution-governor/index.ts +683 -0
  458. tools/bolt.diy/services/execution-governor/metrics.ts +141 -0
  459. tools/bolt.diy/services/execution-governor/package.json +31 -0
  460. tools/bolt.diy/services/execution-governor/priority-queue.ts +139 -0
  461. tools/bolt.diy/services/execution-governor/tsconfig.json +21 -0
  462. tools/bolt.diy/services/execution-governor/types.ts +145 -0
  463. tools/bolt.diy/services/headless-executor/Dockerfile +43 -0
  464. tools/bolt.diy/services/headless-executor/executor.ts +210 -0
  465. tools/bolt.diy/services/headless-executor/index.ts +323 -0
  466. tools/bolt.diy/services/headless-executor/package.json +27 -0
  467. tools/bolt.diy/services/headless-executor/tsconfig.json +21 -0
  468. tools/bolt.diy/services/headless-executor/types.ts +38 -0
  469. tools/bolt.diy/test-workflows.sh +240 -0
  470. tools/bolt.diy/tests/integration/corporate-swarm.test.ts +208 -0
  471. tools/bolt.diy/tests/mandates/budget-limited.json +34 -0
  472. tools/bolt.diy/tests/mandates/complex.json +53 -0
  473. tools/bolt.diy/tests/mandates/constraint-enforced.json +36 -0
  474. tools/bolt.diy/tests/mandates/simple.json +35 -0
  475. tools/bolt.diy/tsconfig.json +37 -0
  476. tools/bolt.diy/types/istextorbinary.d.ts +15 -0
  477. tools/bolt.diy/uno.config.ts +279 -0
  478. tools/bolt.diy/vite-electron.config.ts +76 -0
  479. tools/bolt.diy/vite.config.ts +112 -0
  480. tools/bolt.diy/worker-configuration.d.ts +22 -0
  481. tools/bolt.diy/wrangler.toml +6 -0
  482. tools/code_generator.py +461 -0
  483. tools/file_operations.py +465 -0
  484. tools/mandate_generator.py +337 -0
  485. tools/mcpClientUtils.py +1216 -0
  486. tools/sensors.py +285 -0
  487. utils/Agent_types.py +15 -0
  488. utils/AnyToStr.py +0 -0
  489. utils/HHCS.py +277 -0
  490. utils/__init__.py +30 -0
  491. utils/agent.py +3627 -0
  492. utils/aop.py +2948 -0
  493. utils/canonical.py +143 -0
  494. utils/conversation.py +1195 -0
  495. utils/doctrine_versioning +230 -0
  496. utils/formatter.py +474 -0
  497. utils/ledger.py +311 -0
  498. utils/out_types.py +16 -0
  499. utils/rollback.py +339 -0
  500. utils/router.py +929 -0
  501. utils/tui.py +1908 -0
CRCA.py ADDED
@@ -0,0 +1,4156 @@
1
+ """CRCAAgent - Causal Reasoning and Counterfactual Analysis Agent.
2
+
3
+ This module provides a lightweight causal reasoning agent with LLM integration,
4
+ implemented in pure Python and intended as a flexible CR-CA engine for Swarms.
5
+ """
6
+
7
+ # Standard library imports
8
+ import asyncio
9
+ import importlib
10
+ import importlib.util
11
+ import inspect
12
+ import logging
13
+ import math
14
+ import os
15
+ import threading
16
+ from dataclasses import dataclass
17
+ from enum import Enum
18
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
+
20
+ # Third-party imports
21
+ import numpy as np
22
+ from loguru import logger
23
+ from swarms.structs.agent import Agent
24
+
25
+ # Try to import rustworkx (required dependency)
26
+ try:
27
+ import rustworkx as rx
28
+ except Exception as e:
29
+ raise ImportError(
30
+ "rustworkx is required for the CRCAAgent rustworkx upgrade: pip install rustworkx"
31
+ ) from e
32
+
33
+ # Optional heavy dependencies — used when available
34
+ try:
35
+ import pandas as pd # type: ignore
36
+ PANDAS_AVAILABLE = True
37
+ except Exception:
38
+ PANDAS_AVAILABLE = False
39
+
40
+ try:
41
+ from scipy import linalg as scipy_linalg # type: ignore
42
+ from scipy import stats as scipy_stats # type: ignore
43
+ SCIPY_AVAILABLE = True
44
+ except Exception:
45
+ SCIPY_AVAILABLE = False
46
+
47
+ try:
48
+ import cvxpy as cp # type: ignore
49
+ CVXPY_AVAILABLE = True
50
+ except Exception:
51
+ CVXPY_AVAILABLE = False
52
+
53
+ # Load environment variables from .env file
54
+ try:
55
+ from dotenv import load_dotenv
56
+ import os
57
+
58
+ # Explicitly load from .env file in the project root
59
+ env_path = os.path.join(os.path.dirname(__file__), '.env')
60
+ load_dotenv(dotenv_path=env_path, override=False)
61
+ except ImportError:
62
+ # dotenv not available, skip loading
63
+ pass
64
+ except Exception:
65
+ # .env file might not exist, that's okay
66
+ pass
67
+
68
+ # Local imports
69
+ try:
70
+ from prompts.default_crca import DEFAULT_CRCA_SYSTEM_PROMPT
71
+ except ImportError:
72
+ # Fallback if prompt file doesn't exist
73
+ DEFAULT_CRCA_SYSTEM_PROMPT = None
74
+
75
+ # Policy engine imports (optional - only if policy_mode is enabled)
76
+ try:
77
+ from schemas.policy import DoctrineV1
78
+ from utils.ledger import Ledger
79
+ from templates.policy_loop import PolicyLoopMixin
80
+ POLICY_ENGINE_AVAILABLE = True
81
+ except ImportError:
82
+ POLICY_ENGINE_AVAILABLE = False
83
+ logger.debug("Policy engine modules not available")
84
+
85
# Fix litellm async compatibility
# NOTE(review): if litellm's logging_utils module is importable, its
# `asyncio.iscoroutinefunction` attribute is rebound to
# `inspect.iscoroutinefunction`. Presumably this works around a
# coroutine-detection issue in some litellm versions — confirm upstream.
# Both guards are deliberately best-effort: missing litellm or a changed
# module layout must not prevent this module from importing.
try:
    lu_spec = importlib.util.find_spec("litellm.litellm_core_utils.logging_utils")
    if lu_spec is not None:
        lu = importlib.import_module("litellm.litellm_core_utils.logging_utils")
        try:
            # Only patch when the expected attribute chain actually exists.
            if hasattr(lu, "asyncio") and hasattr(lu.asyncio, "iscoroutinefunction"):
                lu.asyncio.iscoroutinefunction = inspect.iscoroutinefunction
        except Exception:
            pass
except Exception:
    pass
97
+
98
+
99
class CausalRelationType(Enum):
    """Enumeration of causal relationship types.

    Defines the different types of causal relationships that can exist
    between variables in a causal graph. Values are lowercase strings so
    they serialize cleanly.
    """

    DIRECT = "direct"            # edge A -> B with no intermediary
    INDIRECT = "indirect"        # effect transmitted through other nodes
    CONFOUNDING = "confounding"  # common-cause relationship
    MEDIATING = "mediating"      # variable lies on the causal path
    MODERATING = "moderating"    # variable modifies an effect's strength
111
+
112
+
113
@dataclass
class CausalNode:
    """Represents a node (variable) in the causal graph.

    Attributes:
        name: Name of the variable/node
        value: Current value of the variable, or None if unobserved
        confidence: Confidence level in the node (default: 1.0)
        node_type: Type of the node (default: "variable")
    """

    name: str
    value: Optional[float] = None
    confidence: float = 1.0
    node_type: str = "variable"  # free-form tag; "variable" is the default kind
128
+
129
+
130
@dataclass
class CausalEdge:
    """Represents an edge (causal relationship) in the causal graph.

    Attributes:
        source: Source variable name (the cause)
        target: Target variable name (the effect)
        strength: Strength of the causal relationship (default: 1.0)
        relation_type: Type of causal relation (default: DIRECT)
        confidence: Confidence level in the edge (default: 1.0)
    """

    source: str
    target: str
    strength: float = 1.0
    relation_type: CausalRelationType = CausalRelationType.DIRECT
    confidence: float = 1.0
147
+
148
+
149
@dataclass
class CounterfactualScenario:
    """Represents a counterfactual scenario for analysis.

    Attributes:
        name: Name/identifier of the scenario
        interventions: Dictionary mapping variable names to intervention values
        expected_outcomes: Dictionary mapping variable names to expected outcomes
        probability: Probability of this scenario (default: 1.0)
        reasoning: Explanation of why this scenario is important
        uncertainty_metadata: Optional metadata about prediction confidence,
            graph uncertainty, scenario relevance
        sampling_distribution: Optional distribution type used for sampling
            (gaussian/uniform/mixture/adaptive)
        monte_carlo_iterations: Optional number of Monte Carlo samples used
        meta_reasoning_score: Optional overall quality/informativeness score
    """

    name: str
    interventions: Dict[str, float]
    expected_outcomes: Dict[str, float]
    probability: float = 1.0
    reasoning: str = ""
    # The three fields below are populated only by the meta-Monte-Carlo
    # counterfactual pipeline; they stay None for plain scenarios.
    uncertainty_metadata: Optional[Dict[str, Any]] = None
    sampling_distribution: Optional[str] = None
    monte_carlo_iterations: Optional[int] = None
    meta_reasoning_score: Optional[float] = None
174
+
175
+
176
+ # Internal helper classes for meta-Monte Carlo counterfactual reasoning
177
+ class _AdaptiveInterventionSampler:
178
+ """Adaptive intervention sampler based on causal graph structure."""
179
+
180
+ def __init__(self, agent: 'CRCAAgent'):
181
+ """Initialize sampler with reference to agent.
182
+
183
+ Args:
184
+ agent: CRCAAgent instance for accessing causal graph and stats
185
+ """
186
+ self.agent = agent
187
+
188
+ def sample_interventions(
189
+ self,
190
+ factual_state: Dict[str, float],
191
+ target_variables: List[str],
192
+ n_samples: int
193
+ ) -> List[Dict[str, float]]:
194
+ """Sample interventions using adaptive distributions.
195
+
196
+ Args:
197
+ factual_state: Current factual state
198
+ target_variables: Variables to sample interventions for
199
+ n_samples: Number of samples to generate
200
+
201
+ Returns:
202
+ List of intervention dictionaries
203
+ """
204
+ samples = []
205
+ rng = np.random.default_rng(self.agent.seed if self.agent.seed is not None else None)
206
+
207
+ for _ in range(n_samples):
208
+ intervention = {}
209
+ for var in target_variables:
210
+ dist_type, params = self._get_adaptive_distribution(var, factual_state)
211
+
212
+ if dist_type == "gaussian":
213
+ val = self._sample_gaussian(var, params["mean"], params["std"], 1, rng)[0]
214
+ elif dist_type == "uniform":
215
+ val = self._sample_uniform(var, params["bounds"], 1, rng)[0]
216
+ elif dist_type == "mixture":
217
+ val = self._sample_mixture(var, params["components"], 1, rng)[0]
218
+ else: # adaptive/graph-based
219
+ val = self._sample_from_graph_structure(var, factual_state, 1, rng)[0]
220
+
221
+ intervention[var] = float(val)
222
+
223
+ samples.append(intervention)
224
+
225
+ return samples
226
+
227
+ def _get_adaptive_distribution(
228
+ self,
229
+ var: str,
230
+ factual_state: Dict[str, float]
231
+ ) -> Tuple[str, Dict[str, Any]]:
232
+ """Select distribution type based on graph structure.
233
+
234
+ Args:
235
+ var: Variable name
236
+ factual_state: Current factual state
237
+
238
+ Returns:
239
+ Tuple of (distribution_type, parameters)
240
+ """
241
+ # Get edge strengths and confidence
242
+ parents = self.agent._get_parents(var)
243
+ children = self.agent._get_children(var)
244
+
245
+ # Calculate average edge strength
246
+ avg_strength = 0.0
247
+ avg_confidence = 1.0
248
+ if parents:
249
+ strengths = [abs(self.agent._edge_strength(p, var)) for p in parents]
250
+ confidences = []
251
+ for p in parents:
252
+ edge = self.agent.causal_graph.get(p, {}).get(var, {})
253
+ if isinstance(edge, dict):
254
+ confidences.append(edge.get("confidence", 1.0))
255
+ else:
256
+ confidences.append(1.0)
257
+ avg_strength = sum(strengths) / len(strengths) if strengths else 0.0
258
+ avg_confidence = sum(confidences) / len(confidences) if confidences else 1.0
259
+
260
+ # Get path length (max depth from root)
261
+ path_length = self._get_max_path_length(var)
262
+
263
+ # Get variable importance (number of descendants)
264
+ importance = len(self.agent._get_descendants(var))
265
+
266
+ # Get stats
267
+ stats = self.agent.standardization_stats.get(var, {"mean": 0.0, "std": 1.0})
268
+ mean = stats.get("mean", 0.0)
269
+ std = stats.get("std", 1.0) or 1.0
270
+ factual_val = factual_state.get(var, mean)
271
+
272
+ # Decision logic
273
+ if avg_confidence > 0.8 and avg_strength > 0.5:
274
+ # High confidence + strong edges -> narrow Gaussian
275
+ return "gaussian", {
276
+ "mean": factual_val,
277
+ "std": std * 0.3 # Narrow distribution
278
+ }
279
+ elif avg_confidence < 0.5:
280
+ # Low confidence -> wide uniform or mixture
281
+ if path_length > 3:
282
+ return "mixture", {
283
+ "components": [
284
+ {"type": "gaussian", "mean": factual_val, "std": std * 1.5, "weight": 0.5},
285
+ {"type": "uniform", "bounds": (factual_val - 3*std, factual_val + 3*std), "weight": 0.5}
286
+ ]
287
+ }
288
+ else:
289
+ return "uniform", {
290
+ "bounds": (factual_val - 2*std, factual_val + 2*std)
291
+ }
292
+ elif path_length > 4:
293
+ # Long causal paths -> mixture to capture path uncertainty
294
+ return "mixture", {
295
+ "components": [
296
+ {"type": "gaussian", "mean": factual_val, "std": std * 0.8, "weight": 0.7},
297
+ {"type": "uniform", "bounds": (factual_val - 2*std, factual_val + 2*std), "weight": 0.3}
298
+ ]
299
+ }
300
+ elif len(parents) > 3:
301
+ # Many parents -> mixture to capture multi-parent uncertainty
302
+ return "mixture", {
303
+ "components": [
304
+ {"type": "gaussian", "mean": factual_val, "std": std * 0.6, "weight": 0.6},
305
+ {"type": "uniform", "bounds": (factual_val - 2*std, factual_val + 2*std), "weight": 0.4}
306
+ ]
307
+ }
308
+ elif len(parents) == 0:
309
+ # Exogenous variable -> uniform
310
+ return "uniform", {
311
+ "bounds": (factual_val - 2*std, factual_val + 2*std)
312
+ }
313
+ else:
314
+ # Default: adaptive based on graph structure
315
+ return "adaptive", {
316
+ "mean": factual_val,
317
+ "std": std,
318
+ "edge_strength": avg_strength,
319
+ "confidence": avg_confidence,
320
+ "path_length": path_length
321
+ }
322
+
323
+ def _get_max_path_length(self, var: str) -> int:
324
+ """Get maximum path length from root to variable."""
325
+ def dfs(node: str, visited: set, depth: int) -> int:
326
+ if node in visited:
327
+ return depth
328
+ visited.add(node)
329
+ parents = self.agent._get_parents(node)
330
+ if not parents:
331
+ return depth
332
+ return max([dfs(p, visited.copy(), depth + 1) for p in parents] + [depth])
333
+
334
+ return dfs(var, set(), 0)
335
+
336
+ def _sample_gaussian(
337
+ self,
338
+ var: str,
339
+ mean: float,
340
+ std: float,
341
+ n: int,
342
+ rng: np.random.Generator
343
+ ) -> List[float]:
344
+ """Sample from Gaussian distribution."""
345
+ return [float(x) for x in rng.normal(mean, std, n)]
346
+
347
+ def _sample_uniform(
348
+ self,
349
+ var: str,
350
+ bounds: Tuple[float, float],
351
+ n: int,
352
+ rng: np.random.Generator
353
+ ) -> List[float]:
354
+ """Sample from uniform distribution."""
355
+ low, high = bounds
356
+ return [float(x) for x in rng.uniform(low, high, n)]
357
+
358
+ def _sample_mixture(
359
+ self,
360
+ var: str,
361
+ components: List[Dict[str, Any]],
362
+ n: int,
363
+ rng: np.random.Generator
364
+ ) -> List[float]:
365
+ """Sample from mixture distribution."""
366
+ samples = []
367
+ for _ in range(n):
368
+ # Select component based on weights
369
+ weights = [c.get("weight", 1.0/len(components)) for c in components]
370
+ total_weight = sum(weights)
371
+ probs = [w / total_weight for w in weights]
372
+ component_idx = rng.choice(len(components), p=probs)
373
+ component = components[component_idx]
374
+
375
+ if component["type"] == "gaussian":
376
+ val = rng.normal(component["mean"], component["std"])
377
+ else: # uniform
378
+ low, high = component["bounds"]
379
+ val = rng.uniform(low, high)
380
+
381
+ samples.append(float(val))
382
+
383
+ return samples
384
+
385
+ def _sample_from_graph_structure(
386
+ self,
387
+ var: str,
388
+ factual_state: Dict[str, float],
389
+ n: int,
390
+ rng: np.random.Generator
391
+ ) -> List[float]:
392
+ """Sample using graph structure information."""
393
+ stats = self.agent.standardization_stats.get(var, {"mean": 0.0, "std": 1.0})
394
+ mean = stats.get("mean", 0.0)
395
+ std = stats.get("std", 1.0) or 1.0
396
+ factual_val = factual_state.get(var, mean)
397
+
398
+ # Use graph-based adaptive std
399
+ parents = self.agent._get_parents(var)
400
+ if parents:
401
+ avg_strength = sum([abs(self.agent._edge_strength(p, var)) for p in parents]) / len(parents)
402
+ # Stronger edges -> narrower distribution
403
+ adaptive_std = std * (1.0 - 0.5 * min(1.0, avg_strength))
404
+ else:
405
+ adaptive_std = std * 1.5 # Wider for exogenous
406
+
407
+ return self._sample_gaussian(var, factual_val, adaptive_std, n, rng)
408
+
409
+
410
+ class _GraphUncertaintySampler:
411
+ """Sample graph variations for uncertainty quantification."""
412
+
413
+ def __init__(self, agent: 'CRCAAgent'):
414
+ """Initialize sampler with reference to agent.
415
+
416
+ Args:
417
+ agent: CRCAAgent instance for accessing causal graph
418
+ """
419
+ self.agent = agent
420
+
421
+ def sample_graph_variations(
422
+ self,
423
+ n_samples: int,
424
+ uncertainty_data: Optional[Dict[str, Any]] = None
425
+ ) -> List[Dict[Tuple[str, str], float]]:
426
+ """Sample alternative graph structures.
427
+
428
+ Args:
429
+ n_samples: Number of graph variations to sample
430
+ uncertainty_data: Optional uncertainty data from quantify_uncertainty()
431
+
432
+ Returns:
433
+ List of graph variation dictionaries mapping (source, target) -> strength
434
+ """
435
+ variations = []
436
+ rng = np.random.default_rng(self.agent.seed if self.agent.seed is not None else None)
437
+
438
+ # Get baseline strengths
439
+ baseline_strengths: Dict[Tuple[str, str], float] = {}
440
+ for u, targets in self.agent.causal_graph.items():
441
+ for v, meta in targets.items():
442
+ try:
443
+ baseline_strengths[(u, v)] = float(meta.get("strength", 0.0)) if isinstance(meta, dict) else float(meta)
444
+ except Exception:
445
+ baseline_strengths[(u, v)] = 0.0
446
+
447
+ # Get confidence intervals if available
448
+ edge_cis = {}
449
+ if uncertainty_data and "edge_cis" in uncertainty_data:
450
+ edge_cis = uncertainty_data["edge_cis"]
451
+
452
+ for _ in range(n_samples):
453
+ variation = {}
454
+ for (u, v), baseline_strength in baseline_strengths.items():
455
+ edge_key = f"{u}->{v}"
456
+ if edge_key in edge_cis:
457
+ # Sample from confidence interval
458
+ ci_lower, ci_upper = edge_cis[edge_key]
459
+ # Use truncated normal within CI
460
+ mean = (ci_lower + ci_upper) / 2.0
461
+ std = (ci_upper - ci_lower) / 4.0 # Approximate std from CI
462
+ sampled = rng.normal(mean, std)
463
+ # Truncate to CI bounds
464
+ sampled = max(ci_lower, min(ci_upper, sampled))
465
+ else:
466
+ # Sample around baseline with small perturbation
467
+ edge = self.agent.causal_graph.get(u, {}).get(v, {})
468
+ confidence = edge.get("confidence", 1.0) if isinstance(edge, dict) else 1.0
469
+ # Lower confidence -> larger perturbation
470
+ perturbation_std = 0.1 * (2.0 - confidence)
471
+ sampled = baseline_strength + rng.normal(0.0, perturbation_std)
472
+
473
+ variation[(u, v)] = float(sampled)
474
+
475
+ variations.append(variation)
476
+
477
+ return variations
478
+
479
+
480
class _PredictionQualityAssessor:
    """Assess quality and reliability of counterfactual predictions."""

    def __init__(self, agent: 'CRCAAgent'):
        """Initialize assessor with reference to agent.

        Args:
            agent: CRCAAgent instance for accessing causal graph
        """
        self.agent = agent

    def assess_quality(
        self,
        predictions_across_variants: List[Dict[str, float]],
        factual_state: Dict[str, float],
        interventions: Dict[str, float]
    ) -> Tuple[float, Dict[str, Any]]:
        """Evaluate prediction reliability.

        Combines three per-variable metric families:
        consistency (agreement across graph variants), confidence
        (path strength/length from intervention variables), and
        sensitivity (deviation from the first variant's prediction).

        Args:
            predictions_across_variants: List of predictions from different graph variants
            factual_state: Original factual state
                (NOTE(review): currently unused by this method; kept for
                interface stability)
            interventions: Applied interventions

        Returns:
            Tuple of (quality_score, detailed_metrics)
        """
        if not predictions_across_variants:
            # No predictions at all -> zero quality, empty metrics.
            return 0.0, {}

        # Calculate consistency (variance across variants).
        # Collect the union of variables seen in any variant first.
        all_vars = set()
        for pred in predictions_across_variants:
            all_vars.update(pred.keys())

        consistency_scores = {}
        for var in all_vars:
            # Missing variables default to 0.0 in a variant's prediction.
            values = [pred.get(var, 0.0) for pred in predictions_across_variants]
            if len(values) > 1:
                variance = float(np.var(values))
                std = float(np.std(values))
                mean_val = float(np.mean(values))
                # Consistency: lower variance = higher consistency.
                # Normalize by mean to get coefficient of variation; near-zero
                # means fall back to the raw std to avoid blow-up.
                cv = std / abs(mean_val) if abs(mean_val) > 1e-6 else std
                consistency_scores[var] = {
                    "variance": variance,
                    "std": std,
                    "mean": mean_val,
                    "coefficient_of_variation": cv,
                    "consistency": max(0.0, 1.0 - min(1.0, cv))  # 1.0 = perfect consistency
                }
            else:
                # Single variant: trivially consistent.
                consistency_scores[var] = {
                    "variance": 0.0,
                    "std": 0.0,
                    "mean": values[0] if values else 0.0,
                    "coefficient_of_variation": 0.0,
                    "consistency": 1.0
                }

        # Calculate confidence based on edge strengths and path lengths.
        confidence_scores = {}
        for var in all_vars:
            # Best (strongest / shortest) causal path from any intervention
            # variable to this variable.
            max_path_strength = 0.0
            min_path_length = float('inf')

            for interv_var in interventions.keys():
                path = self.agent.identify_causal_chain(interv_var, var)
                if path:
                    path_length = len(path) - 1
                    min_path_length = min(min_path_length, path_length)

                    # Path strength = product of absolute edge strengths.
                    path_strength = 1.0
                    for i in range(len(path) - 1):
                        u, v = path[i], path[i + 1]
                        edge_strength = abs(self.agent._edge_strength(u, v))
                        path_strength *= edge_strength
                    max_path_strength = max(max_path_strength, path_strength)

            # Confidence: higher path strength, shorter path = higher
            # confidence. When no path exists, min_path_length stays inf and
            # this expression evaluates to 0.0 (1/inf -> 0, times 0).
            path_confidence = max_path_strength * (1.0 / (1.0 + min_path_length * 0.1))
            confidence_scores[var] = {
                "path_strength": max_path_strength,
                # Report 0 instead of inf when the variable is unreachable.
                "path_length": min_path_length if min_path_length != float('inf') else 0,
                "confidence": float(path_confidence)
            }

        # Calculate sensitivity (how much predictions change with small
        # graph perturbations), using the first variant as the baseline.
        sensitivity_scores = {}
        if len(predictions_across_variants) > 1:
            baseline_pred = predictions_across_variants[0]
            for var in all_vars:
                baseline_val = baseline_pred.get(var, 0.0)
                perturbations = [abs(pred.get(var, 0.0) - baseline_val) for pred in predictions_across_variants[1:]]
                avg_perturbation = float(np.mean(perturbations)) if perturbations else 0.0
                max_perturbation = float(max(perturbations)) if perturbations else 0.0

                # Sensitivity: lower perturbation = lower sensitivity = better.
                # The 1e-6 term guards division by a zero baseline.
                sensitivity_scores[var] = {
                    "avg_perturbation": avg_perturbation,
                    "max_perturbation": max_perturbation,
                    "sensitivity": min(1.0, avg_perturbation / (abs(baseline_val) + 1e-6))
                }

        # Overall quality score: weighted combination of the per-variable
        # averages (sensitivity contributes to metrics only, not the score).
        overall_quality = 0.0
        if consistency_scores and confidence_scores:
            consistency_avg = float(np.mean([s["consistency"] for s in consistency_scores.values()]))
            confidence_avg = float(np.mean([s["confidence"] for s in confidence_scores.values()]))

            # Weight: consistency 40%, confidence 60%
            overall_quality = 0.4 * consistency_avg + 0.6 * confidence_avg

        metrics = {
            "consistency": consistency_scores,
            "confidence": confidence_scores,
            "sensitivity": sensitivity_scores,
            "overall_quality": overall_quality
        }

        return overall_quality, metrics
604
+
605
+
606
class _MetaReasoningAnalyzer:
    """Rank counterfactual scenarios by how informative they are.

    Blends scenario relevance, prediction reliability, and graph
    uncertainty into one meta-reasoning score per scenario.
    """

    def __init__(self, agent: 'CRCAAgent'):
        """Store a back-reference to the owning agent.

        Args:
            agent: CRCAAgent instance for accessing standardization stats
        """
        self.agent = agent

    def analyze_scenarios(
        self,
        scenarios_with_metadata: List[Tuple[CounterfactualScenario, Dict[str, Any]]]
    ) -> List[Tuple[CounterfactualScenario, float]]:
        """Comprehensive meta-analysis of scenarios.

        Args:
            scenarios_with_metadata: List of (scenario, metadata) tuples

        Returns:
            List of (scenario, meta_reasoning_score) tuples, sorted by
            score in descending order
        """
        ranked = []
        for scenario, meta in scenarios_with_metadata:
            relevance = self._calculate_relevance(scenario, meta)
            uncertainty = self._calculate_uncertainty_impact(scenario, meta)
            reliability = meta.get("quality_score", 0.5)
            # Weighted blend: relevance 50%, reliability 30%,
            # certainty (inverse of uncertainty impact) 20%.
            score = 0.5 * relevance + 0.3 * reliability + 0.2 * (1.0 - uncertainty)
            ranked.append((scenario, score))

        # Most informative scenarios first.
        return sorted(ranked, key=lambda pair: pair[1], reverse=True)

    def _calculate_relevance(
        self,
        scenario: CounterfactualScenario,
        metadata: Dict[str, Any]
    ) -> float:
        """Score how far the scenario departs from the factual state."""
        factual = metadata.get("factual_state", {})

        def normalized_shift(values: Dict[str, float]) -> float:
            # Sum per-variable deviations from the factual state, each
            # scaled by that variable's historical std; fall back to the
            # raw difference when no positive std is available.
            total = 0.0
            for name, val in values.items():
                baseline = factual.get(name, 0.0)
                stats = self.agent.standardization_stats.get(name, {"mean": 0.0, "std": 1.0})
                spread = stats.get("std", 1.0) or 1.0
                if spread > 0:
                    total += abs(val - baseline) / spread
                else:
                    total += abs(val - baseline)
            return total

        # Information gain ~ how much the interventions move the state;
        # expected utility ~ how much the predicted outcomes move.
        intervention_shift = normalized_shift(scenario.interventions)
        outcome_shift = normalized_shift(scenario.expected_outcomes)

        # Average the two shifts, normalized per intervened variable,
        # and cap the relevance at 1.0.
        denom = 2.0 * max(1, len(scenario.interventions))
        return min(1.0, (intervention_shift + outcome_shift) / denom)

    def _calculate_uncertainty_impact(
        self,
        scenario: CounterfactualScenario,
        metadata: Dict[str, Any]
    ) -> float:
        """Estimate how strongly graph uncertainty perturbs this scenario."""
        consistency = metadata.get("quality_metrics", {}).get("consistency", {})
        if not consistency:
            # No per-variable consistency data: assume moderate uncertainty.
            return 0.5

        # Mean coefficient of variation across variables; a higher CV means
        # predictions disagree more across graph variants.
        cvs = [entry.get("coefficient_of_variation", 0.0) for entry in consistency.values()]
        mean_cv = float(np.mean(cvs)) if cvs else 0.0
        return min(1.0, mean_cv)
705
+
706
+
707
+ class CRCAAgent(Agent):
708
+ """Causal Reasoning with Counterfactual Analysis Agent.
709
+
710
+ A lightweight causal reasoning agent with LLM integration, providing both
711
+ LLM-based causal analysis and deterministic causal simulation. Supports
712
+ automatic variable extraction, causal graph management, counterfactual
713
+ scenario generation, and comprehensive causal analysis.
714
+
715
+ Key Features:
716
+ - LLM integration for sophisticated causal reasoning
717
+ - Dual-mode operation: LLM-based analysis and deterministic simulation
718
+ - Automatic variable extraction from natural language tasks
719
+ - Causal graph management with rustworkx backend
720
+ - Counterfactual scenario generation
721
+ - Batch prediction support
722
+ - Async/await support for concurrent operations
723
+
724
+ Attributes:
725
+ causal_graph: Dictionary representing the causal graph structure
726
+ causal_memory: List storing analysis steps and results
727
+ causal_max_loops: Maximum number of loops for causal reasoning
728
+ """
729
+
730
+ def __init__(
731
+ self,
732
+ variables: Optional[List[str]] = None,
733
+ causal_edges: Optional[List[Tuple[str, str]]] = None,
734
+ max_loops: Optional[Union[int, str]] = 3,
735
+ agent_name: str = "cr-ca-lite-agent",
736
+ agent_description: str = "Lightweight Causal Reasoning with Counterfactual Analysis Agent",
737
+ description: Optional[str] = None,
738
+ model_name: str = "gpt-4o",
739
+ system_prompt: Optional[str] = None,
740
+ global_system_prompt: Optional[str] = None,
741
+ secondary_system_prompt: Optional[str] = None,
742
+ enable_batch_predict: bool = False,
743
+ max_batch_size: int = 32,
744
+ bootstrap_workers: int = 0,
745
+ use_async: bool = False,
746
+ seed: Optional[int] = None,
747
+ enable_excel: bool = False,
748
+ agent_max_loops: Optional[Union[int, str]] = None,
749
+ policy: Optional[Union[DoctrineV1, str]] = None,
750
+ ledger_path: Optional[str] = None,
751
+ epoch_seconds: int = 3600,
752
+ policy_mode: bool = False,
753
+ sensor_registry: Optional[Any] = None,
754
+ actuator_registry: Optional[Any] = None,
755
+ **kwargs,
756
+ ):
757
+ """
758
+ Initialize CRCAAgent with causal reasoning capabilities.
759
+
760
+ Args:
761
+ variables: List of variable names for the causal graph
762
+ causal_edges: List of (source, target) tuples defining causal relationships
763
+ max_loops: Maximum loops for causal reasoning (default: 3)
764
+ agent_max_loops: Maximum loops for standard Agent operations (supports "auto")
765
+ If not provided, defaults to 1 for individual LLM calls.
766
+ Pass "auto" to enable automatic loop detection for standard operations.
767
+ policy: Policy doctrine (DoctrineV1 instance or path to JSON file) for policy mode
768
+ ledger_path: Path to SQLite ledger database for event storage
769
+ epoch_seconds: Length of one epoch in seconds (default: 3600)
770
+ policy_mode: Enable temporal policy engine mode (default: False)
771
+ **kwargs: Additional arguments passed to parent Agent class
772
+ """
773
+
774
+ cr_ca_schema = CRCAAgent._get_cr_ca_schema()
775
+ extract_variables_schema = CRCAAgent._get_extract_variables_schema()
776
+
777
+ # Backwards-compatible alias for description
778
+ agent_description = description or agent_description
779
+
780
+ # Handle max_loops for standard Agent operations
781
+ # agent_max_loops parameter takes precedence, then check kwargs, then default to 1
782
+ if agent_max_loops is None:
783
+ agent_max_loops = kwargs.pop("max_loops", 1)
784
+ else:
785
+ # Remove max_loops from kwargs if agent_max_loops was explicitly provided
786
+ kwargs.pop("max_loops", None)
787
+
788
+ # Merge tools_list_dictionary from kwargs with CRCA schema
789
+ # Only add CRCA schema if user hasn't explicitly disabled it
790
+ use_crca_tools = kwargs.pop("use_crca_tools", True) # Default to True for backwards compatibility
791
+ existing_tools = kwargs.pop("tools_list_dictionary", [])
792
+ if not isinstance(existing_tools, list):
793
+ existing_tools = [existing_tools] if existing_tools else []
794
+
795
+ # Only add CRCA schemas if enabled
796
+ if use_crca_tools:
797
+ tools_list = [cr_ca_schema, extract_variables_schema] + existing_tools
798
+ else:
799
+ tools_list = existing_tools
800
+
801
+ # Get existing callable tools (functions) from kwargs
802
+ existing_callable_tools = kwargs.pop("tools", [])
803
+ if not isinstance(existing_callable_tools, list):
804
+ existing_callable_tools = [existing_callable_tools] if existing_callable_tools else []
805
+
806
+ agent_kwargs = {
807
+ "agent_name": agent_name,
808
+ "agent_description": agent_description,
809
+ "model_name": model_name,
810
+ "max_loops": agent_max_loops, # Use user-provided value or default to 1
811
+ "output_type": "final",
812
+ **kwargs, # All other Agent parameters passed through
813
+ }
814
+
815
+ # Always provide tools_list_dictionary if we have schemas
816
+ # This ensures the LLM knows about the tools
817
+ if tools_list:
818
+ agent_kwargs["tools_list_dictionary"] = tools_list
819
+
820
+ # Add existing callable tools if any
821
+ if existing_callable_tools:
822
+ agent_kwargs["tools"] = existing_callable_tools
823
+
824
+ # Always apply default CRCA prompt as base, then add custom prompt on top if provided
825
+ final_system_prompt = None
826
+ if DEFAULT_CRCA_SYSTEM_PROMPT is not None:
827
+ if system_prompt is not None:
828
+ # Combine: default first, then custom on top
829
+ final_system_prompt = f"{DEFAULT_CRCA_SYSTEM_PROMPT}\n\n--- Additional Instructions ---\n{system_prompt}"
830
+ else:
831
+ # Just use default
832
+ final_system_prompt = DEFAULT_CRCA_SYSTEM_PROMPT
833
+ elif system_prompt is not None:
834
+ # If no default but custom prompt provided, use custom prompt
835
+ final_system_prompt = system_prompt
836
+
837
+ if final_system_prompt is not None:
838
+ agent_kwargs["system_prompt"] = final_system_prompt
839
+ if global_system_prompt is not None:
840
+ agent_kwargs["global_system_prompt"] = global_system_prompt
841
+ if secondary_system_prompt is not None:
842
+ agent_kwargs["secondary_system_prompt"] = secondary_system_prompt
843
+
844
+ super().__init__(**agent_kwargs)
845
+
846
+ # Now that self exists, create and add the CRCA tool handlers if tools are enabled
847
+ if use_crca_tools:
848
+ # Create a wrapper function with the correct name that matches the schema
849
+ def generate_causal_analysis(
850
+ causal_analysis: str,
851
+ intervention_planning: str,
852
+ counterfactual_scenarios: List[Dict[str, Any]],
853
+ causal_strength_assessment: str,
854
+ optimal_solution: str,
855
+ ) -> Dict[str, Any]:
856
+ """Tool handler for generate_causal_analysis - wrapper that calls the instance method."""
857
+ return self._generate_causal_analysis_handler(
858
+ causal_analysis=causal_analysis,
859
+ intervention_planning=intervention_planning,
860
+ counterfactual_scenarios=counterfactual_scenarios,
861
+ causal_strength_assessment=causal_strength_assessment,
862
+ optimal_solution=optimal_solution,
863
+ )
864
+
865
+ # Create a wrapper function for extract_causal_variables
866
+ def extract_causal_variables(
867
+ required_variables: List[str],
868
+ causal_edges: List[List[str]],
869
+ reasoning: str,
870
+ optional_variables: Optional[List[str]] = None,
871
+ counterfactual_variables: Optional[List[str]] = None,
872
+ ) -> Dict[str, Any]:
873
+ """Tool handler for extract_causal_variables - wrapper that calls the instance method."""
874
+ return self._extract_causal_variables_handler(
875
+ required_variables=required_variables,
876
+ causal_edges=causal_edges,
877
+ reasoning=reasoning,
878
+ optional_variables=optional_variables,
879
+ counterfactual_variables=counterfactual_variables,
880
+ )
881
+
882
+ # Add the wrapper functions to the tools list
883
+ # The function names must match the schema names
884
+ if self.tools is None:
885
+ self.tools = []
886
+ self.add_tool(generate_causal_analysis)
887
+ self.add_tool(extract_causal_variables)
888
+
889
+ # CRITICAL: Re-initialize tool_struct after adding tools
890
+ # This ensures the BaseTool instance has the updated tools and function_map
891
+ if hasattr(self, 'setup_tools'):
892
+ self.tool_struct = self.setup_tools()
893
+
894
+ # Ensure tools_list_dictionary is set with our manual schemas
895
+ if not self.tools_list_dictionary or len(self.tools_list_dictionary) == 0:
896
+ self.tools_list_dictionary = [cr_ca_schema, extract_variables_schema] + existing_tools
897
+ else:
898
+ # Replace any auto-generated schemas with our manual ones
899
+ our_tool_names = {"generate_causal_analysis", "extract_causal_variables"}
900
+ filtered = []
901
+ for schema in self.tools_list_dictionary:
902
+ if isinstance(schema, dict):
903
+ func_name = schema.get("function", {}).get("name", "")
904
+ if func_name in our_tool_names:
905
+ continue # Skip auto-generated, we'll add manual
906
+ filtered.append(schema)
907
+ # Add our manual schemas first, then existing tools
908
+ self.tools_list_dictionary = [cr_ca_schema, extract_variables_schema] + filtered
909
+
910
+ self.causal_max_loops = max_loops
911
+ self.causal_graph: Dict[str, Dict[str, float]] = {}
912
+ self.causal_graph_reverse: Dict[str, List[str]] = {} # For fast parent lookup
913
+ self._graph = rx.PyDiGraph()
914
+ self._node_to_index: Dict[str, int] = {}
915
+ self._index_to_node: Dict[int, str] = {}
916
+
917
+ self.standardization_stats: Dict[str, Dict[str, float]] = {}
918
+ self.use_nonlinear_scm: bool = True
919
+ self.nonlinear_activation: str = "tanh" # options: 'tanh'|'identity'
920
+ self.interaction_terms: Dict[str, List[Tuple[str, str]]] = {}
921
+ self.edge_sign_constraints: Dict[Tuple[str, str], int] = {}
922
+ self.bayesian_priors: Dict[Tuple[str, str], Dict[str, float]] = {}
923
+ self.enable_batch_predict = bool(enable_batch_predict)
924
+ self.max_batch_size = int(max_batch_size)
925
+ self.bootstrap_workers = int(max(0, bootstrap_workers))
926
+ self.use_async = bool(use_async)
927
+ self.seed = seed if seed is not None else 42
928
+ self._rng = np.random.default_rng(self.seed)
929
+
930
+ if variables:
931
+ for var in variables:
932
+ self._ensure_node_exists(var)
933
+
934
+ if causal_edges:
935
+ for source, target in causal_edges:
936
+ self.add_causal_relationship(source, target)
937
+
938
+ self.causal_memory: List[Dict[str, Any]] = []
939
+ self._prediction_cache: Dict[Tuple[Tuple[Tuple[str, float], ...], Tuple[Tuple[str, float], ...]], Dict[str, float]] = {}
940
+ self._prediction_cache_order: List[Tuple[Tuple[Tuple[str, float], ...], Tuple[Tuple[str, float], ...]]] = []
941
+ self._prediction_cache_max: int = 1000
942
+ self._cache_enabled: bool = True
943
+ self._prediction_cache_lock = threading.Lock()
944
+
945
+ # Policy engine integration
946
+ self.policy_mode = bool(policy_mode)
947
+ self.policy: Optional[DoctrineV1] = None
948
+ self.ledger: Optional[Ledger] = None
949
+ self.epoch_seconds = int(epoch_seconds)
950
+ self.policy_loop: Optional[PolicyLoopMixin] = None
951
+
952
+ if self.policy_mode:
953
+ if not POLICY_ENGINE_AVAILABLE:
954
+ raise ImportError(
955
+ "Policy engine modules not available. "
956
+ "Ensure schemas/policy.py, utils/ledger.py, and templates/policy_loop.py exist."
957
+ )
958
+
959
+ # Load policy
960
+ if policy is None:
961
+ raise ValueError("policy must be provided when policy_mode=True")
962
+
963
+ if isinstance(policy, str):
964
+ # Load from JSON file
965
+ self.policy = DoctrineV1.from_json(policy)
966
+ elif isinstance(policy, DoctrineV1):
967
+ self.policy = policy
968
+ else:
969
+ raise TypeError(f"policy must be DoctrineV1 or str (JSON path), got {type(policy)}")
970
+
971
+ # Initialize ledger
972
+ if ledger_path is None:
973
+ ledger_path = "crca_ledger.db"
974
+ self.ledger = Ledger(ledger_path)
975
+
976
+ # Initialize policy loop mixin
977
+ self.policy_loop = PolicyLoopMixin(
978
+ doctrine=self.policy,
979
+ ledger=self.ledger,
980
+ seed=self.seed,
981
+ sensor_registry=sensor_registry,
982
+ actuator_registry=actuator_registry
983
+ )
984
+
985
+ logger.info(f"Policy mode enabled: doctrine={self.policy.version}, ledger={ledger_path}")
986
+
987
+ # Excel TUI integration
988
+ self._excel_enabled = bool(enable_excel)
989
+ self._excel_tables = None
990
+ self._excel_eval_engine = None
991
+ self._excel_scm_bridge = None
992
+ self._excel_dependency_graph = None
993
+
994
+ if enable_excel:
995
+ try:
996
+ from crca_excel.core.tables import TableManager
997
+ from crca_excel.core.deps import DependencyGraph
998
+ from crca_excel.core.eval import EvaluationEngine
999
+ from crca_excel.core.scm import SCMBridge
1000
+
1001
+ self._excel_tables = TableManager()
1002
+ self._excel_dependency_graph = DependencyGraph()
1003
+ self._excel_eval_engine = EvaluationEngine(
1004
+ self._excel_tables,
1005
+ self._excel_dependency_graph,
1006
+ max_iter=self.causal_max_loops if isinstance(self.causal_max_loops, int) else 100,
1007
+ epsilon=1e-6
1008
+ )
1009
+ self._excel_scm_bridge = SCMBridge(
1010
+ self._excel_tables,
1011
+ self._excel_eval_engine,
1012
+ crca_agent=self
1013
+ )
1014
+
1015
+ # Initialize standard tables
1016
+ self._initialize_excel_tables()
1017
+
1018
+ # Link causal graph to tables
1019
+ self._link_causal_graph_to_tables()
1020
+ except ImportError as e:
1021
+ logger.warning(f"Excel TUI modules not available: {e}")
1022
+ self._excel_enabled = False
1023
+
1024
+ @staticmethod
1025
+ def _get_cr_ca_schema() -> Dict[str, Any]:
1026
+
1027
+ return {
1028
+ "type": "function",
1029
+ "function": {
1030
+ "name": "generate_causal_analysis",
1031
+ "description": "Generates structured causal reasoning and counterfactual analysis",
1032
+ "parameters": {
1033
+ "type": "object",
1034
+ "properties": {
1035
+ "causal_analysis": {
1036
+ "type": "string",
1037
+ "description": "Analysis of causal relationships and mechanisms"
1038
+ },
1039
+ "intervention_planning": {
1040
+ "type": "string",
1041
+ "description": "Planned interventions to test causal hypotheses"
1042
+ },
1043
+ "counterfactual_scenarios": {
1044
+ "type": "array",
1045
+ "items": {
1046
+ "type": "object",
1047
+ "properties": {
1048
+ "scenario_name": {"type": "string"},
1049
+ "interventions": {"type": "object"},
1050
+ "expected_outcomes": {"type": "object"},
1051
+ "reasoning": {"type": "string"}
1052
+ }
1053
+ },
1054
+ "description": "Multiple counterfactual scenarios to explore"
1055
+ },
1056
+ "causal_strength_assessment": {
1057
+ "type": "string",
1058
+ "description": "Assessment of causal relationship strengths and confounders"
1059
+ },
1060
+ "optimal_solution": {
1061
+ "type": "string",
1062
+ "description": "Recommended optimal solution based on causal analysis"
1063
+ }
1064
+ },
1065
+ "required": [
1066
+ "causal_analysis",
1067
+ "intervention_planning",
1068
+ "counterfactual_scenarios",
1069
+ "causal_strength_assessment",
1070
+ "optimal_solution"
1071
+ ]
1072
+ }
1073
+ }
1074
+ }
1075
+
1076
+ @staticmethod
1077
+ def _get_extract_variables_schema() -> Dict[str, Any]:
1078
+ """
1079
+ Get the schema for the extract_causal_variables tool.
1080
+
1081
+ Returns:
1082
+ Dictionary containing the OpenAI function schema for variable extraction
1083
+ """
1084
+ return {
1085
+ "type": "function",
1086
+ "function": {
1087
+ "name": "extract_causal_variables",
1088
+ "description": "Extract and propose causal variables, relationships, and counterfactual scenarios needed for causal analysis",
1089
+ "parameters": {
1090
+ "type": "object",
1091
+ "properties": {
1092
+ "required_variables": {
1093
+ "type": "array",
1094
+ "items": {"type": "string"},
1095
+ "description": "Core variables that must be included for causal analysis"
1096
+ },
1097
+ "optional_variables": {
1098
+ "type": "array",
1099
+ "items": {"type": "string"},
1100
+ "description": "Additional variables that may be useful but not essential"
1101
+ },
1102
+ "causal_edges": {
1103
+ "type": "array",
1104
+ "items": {
1105
+ "type": "array",
1106
+ "items": {"type": "string"},
1107
+ "minItems": 2,
1108
+ "maxItems": 2
1109
+ },
1110
+ "description": "Causal relationships as [source, target] pairs"
1111
+ },
1112
+ "counterfactual_variables": {
1113
+ "type": "array",
1114
+ "items": {"type": "string"},
1115
+ "description": "Variables to explore in counterfactual scenarios"
1116
+ },
1117
+ "reasoning": {
1118
+ "type": "string",
1119
+ "description": "Explanation of why these variables and relationships are needed"
1120
+ }
1121
+ },
1122
+ "required": ["required_variables", "causal_edges", "reasoning"]
1123
+ }
1124
+ }
1125
+ }
1126
+
1127
+ def step(self, task: str) -> str:
1128
+ """
1129
+ Execute a single step of causal reasoning.
1130
+
1131
+ Args:
1132
+ task: Task string to process
1133
+
1134
+ Returns:
1135
+ Response string from the agent
1136
+ """
1137
+ response = super().run(task)
1138
+ return response
1139
+
1140
+ def _generate_causal_analysis_handler(
1141
+ self,
1142
+ causal_analysis: str,
1143
+ intervention_planning: str,
1144
+ counterfactual_scenarios: List[Dict[str, Any]],
1145
+ causal_strength_assessment: str,
1146
+ optimal_solution: str,
1147
+ ) -> Dict[str, Any]:
1148
+ """
1149
+ Handler function for the generate_causal_analysis tool.
1150
+
1151
+ This function is called when the LLM invokes the generate_causal_analysis tool.
1152
+ It processes the tool's output and integrates it into the causal graph.
1153
+
1154
+ Args:
1155
+ causal_analysis: Analysis of causal relationships and mechanisms
1156
+ intervention_planning: Planned interventions to test causal hypotheses
1157
+ counterfactual_scenarios: List of counterfactual scenarios
1158
+ causal_strength_assessment: Assessment of causal relationship strengths
1159
+ optimal_solution: Recommended optimal solution
1160
+
1161
+ Returns:
1162
+ Dictionary with processed results
1163
+ """
1164
+ logger.info("Processing causal analysis from tool call")
1165
+
1166
+ # Store the analysis in causal memory
1167
+ analysis_entry = {
1168
+ 'type': 'analysis', # Mark this as an analysis entry
1169
+ 'causal_analysis': causal_analysis,
1170
+ 'intervention_planning': intervention_planning,
1171
+ 'counterfactual_scenarios': counterfactual_scenarios,
1172
+ 'causal_strength_assessment': causal_strength_assessment,
1173
+ 'optimal_solution': optimal_solution,
1174
+ 'timestamp': len(self.causal_memory)
1175
+ }
1176
+
1177
+ self.causal_memory.append(analysis_entry)
1178
+
1179
+ # Try to extract and update causal relationships from the analysis
1180
+ # This is a simple implementation - can be enhanced later
1181
+ try:
1182
+ # The LLM might mention relationships in the analysis
1183
+ # For now, we'll just store the analysis
1184
+ # In a more advanced version, we could parse the analysis to extract relationships
1185
+ pass
1186
+ except Exception as e:
1187
+ logger.warning(f"Error processing causal analysis: {e}")
1188
+
1189
+ # Return a structured response
1190
+ return {
1191
+ "status": "success",
1192
+ "message": "Causal analysis processed and stored",
1193
+ "analysis_summary": {
1194
+ "causal_analysis_length": len(causal_analysis),
1195
+ "num_scenarios": len(counterfactual_scenarios),
1196
+ "has_optimal_solution": bool(optimal_solution)
1197
+ }
1198
+ }
1199
+
1200
+ def _extract_causal_variables_handler(
1201
+ self,
1202
+ required_variables: List[str],
1203
+ causal_edges: List[List[str]],
1204
+ reasoning: str,
1205
+ optional_variables: Optional[List[str]] = None,
1206
+ counterfactual_variables: Optional[List[str]] = None,
1207
+ ) -> Dict[str, Any]:
1208
+ """
1209
+ Handler function for the extract_causal_variables tool.
1210
+
1211
+ This function is called when the LLM invokes the extract_causal_variables tool.
1212
+ It processes the tool's output and adds variables and edges to the causal graph.
1213
+
1214
+ Args:
1215
+ required_variables: Core variables that must be included for causal analysis
1216
+ causal_edges: Causal relationships as [source, target] pairs
1217
+ reasoning: Explanation of why these variables and relationships are needed
1218
+ optional_variables: Additional variables that may be useful but not essential
1219
+ counterfactual_variables: Variables to explore in counterfactual scenarios
1220
+
1221
+ Returns:
1222
+ Dictionary with processed results and summary of what was added
1223
+ """
1224
+ import traceback
1225
+
1226
+ try:
1227
+ logger.info("=== EXTRACT HANDLER CALLED ===")
1228
+ logger.info(f"Processing variable extraction from tool call")
1229
+ logger.info(f"Received required_variables: {required_variables} (type: {type(required_variables)})")
1230
+ logger.info(f"Received causal_edges: {causal_edges} (type: {type(causal_edges)})")
1231
+ logger.info(f"Received optional_variables: {optional_variables}")
1232
+ logger.info(f"Received counterfactual_variables: {counterfactual_variables}")
1233
+
1234
+ # Validate inputs
1235
+ if not required_variables:
1236
+ logger.warning("No required_variables provided!")
1237
+ required_variables = []
1238
+ if not isinstance(required_variables, list):
1239
+ logger.warning(f"required_variables is not a list: {type(required_variables)}")
1240
+ required_variables = [str(required_variables)] if required_variables else []
1241
+
1242
+ if not causal_edges:
1243
+ logger.warning("No causal_edges provided!")
1244
+ causal_edges = []
1245
+ if not isinstance(causal_edges, list):
1246
+ logger.warning(f"causal_edges is not a list: {type(causal_edges)}")
1247
+ causal_edges = []
1248
+
1249
+ # Track what was added
1250
+ added_variables = []
1251
+ added_edges = []
1252
+ skipped_edges = []
1253
+
1254
+ # Add required variables
1255
+ for var in required_variables:
1256
+ if var and var.strip():
1257
+ var_clean = var.strip()
1258
+ if var_clean not in self.causal_graph:
1259
+ self._ensure_node_exists(var_clean)
1260
+ added_variables.append(var_clean)
1261
+
1262
+ # Add optional variables
1263
+ if optional_variables:
1264
+ for var in optional_variables:
1265
+ if var and var.strip():
1266
+ var_clean = var.strip()
1267
+ if var_clean not in self.causal_graph:
1268
+ self._ensure_node_exists(var_clean)
1269
+ added_variables.append(var_clean)
1270
+
1271
+ # Add counterfactual variables (also add them as nodes)
1272
+ if counterfactual_variables:
1273
+ for var in counterfactual_variables:
1274
+ if var and var.strip():
1275
+ var_clean = var.strip()
1276
+ if var_clean not in self.causal_graph:
1277
+ self._ensure_node_exists(var_clean)
1278
+ if var_clean not in added_variables:
1279
+ added_variables.append(var_clean)
1280
+
1281
+ # Add causal edges
1282
+ for edge in causal_edges:
1283
+ if isinstance(edge, (list, tuple)) and len(edge) >= 2:
1284
+ source = str(edge[0]).strip() if edge[0] else None
1285
+ target = str(edge[1]).strip() if edge[1] else None
1286
+
1287
+ if source and target:
1288
+ # Ensure both nodes exist
1289
+ if source not in self.causal_graph:
1290
+ self._ensure_node_exists(source)
1291
+ if source not in added_variables:
1292
+ added_variables.append(source)
1293
+ if target not in self.causal_graph:
1294
+ self._ensure_node_exists(target)
1295
+ if target not in added_variables:
1296
+ added_variables.append(target)
1297
+
1298
+ # Add the edge
1299
+ try:
1300
+ self.add_causal_relationship(source, target)
1301
+ added_edges.append((source, target))
1302
+ except Exception as e:
1303
+ logger.warning(f"Failed to add edge {source} -> {target}: {e}")
1304
+ skipped_edges.append((source, target))
1305
+ else:
1306
+ logger.warning(f"Invalid edge format: {edge}")
1307
+ skipped_edges.append(edge)
1308
+
1309
+ # Store extraction metadata in causal memory
1310
+ extraction_entry = {
1311
+ 'type': 'variable_extraction',
1312
+ 'required_variables': required_variables,
1313
+ 'optional_variables': optional_variables or [],
1314
+ 'counterfactual_variables': counterfactual_variables or [],
1315
+ 'causal_edges': causal_edges,
1316
+ 'reasoning': reasoning,
1317
+ 'added_variables': added_variables,
1318
+ 'added_edges': added_edges,
1319
+ 'skipped_edges': skipped_edges,
1320
+ 'timestamp': len(self.causal_memory)
1321
+ }
1322
+
1323
+ self.causal_memory.append(extraction_entry)
1324
+
1325
+ # Log what was actually added
1326
+ logger.info(f"Extraction complete: Added {len(added_variables)} variables, {len(added_edges)} edges")
1327
+ logger.info(f"Added variables: {added_variables}")
1328
+ logger.info(f"Added edges: {added_edges}")
1329
+ logger.info(f"Current graph size: {len(self.causal_graph)} variables, {sum(len(children) for children in self.causal_graph.values())} edges")
1330
+
1331
+ # Return a structured response
1332
+ result = {
1333
+ "status": "success",
1334
+ "message": "Variables and relationships extracted and added to causal graph",
1335
+ "summary": {
1336
+ "variables_added": len(added_variables),
1337
+ "edges_added": len(added_edges),
1338
+ "edges_skipped": len(skipped_edges),
1339
+ "total_variables_in_graph": len(self.causal_graph),
1340
+ "total_edges_in_graph": sum(len(children) for children in self.causal_graph.values())
1341
+ },
1342
+ "details": {
1343
+ "added_variables": added_variables,
1344
+ "added_edges": added_edges,
1345
+ "skipped_edges": skipped_edges if skipped_edges else None
1346
+ }
1347
+ }
1348
+ logger.info(f"Returning result: {result}")
1349
+ return result
1350
+
1351
+ except Exception as e:
1352
+ logger.error(f"ERROR in _extract_causal_variables_handler: {e}")
1353
+ logger.error(f"Traceback: {traceback.format_exc()}")
1354
+ # Return error but don't fail completely
1355
+ return {
1356
+ "status": "error",
1357
+ "message": f"Error processing variable extraction: {str(e)}",
1358
+ "summary": {
1359
+ "variables_added": 0,
1360
+ "edges_added": 0,
1361
+ }
1362
+ }
1363
+
1364
+ def _build_causal_prompt(self, task: str) -> str:
1365
+ return (
1366
+ "You are a Causal Reasoning with Counterfactual Analysis (CR-CA) agent.\n"
1367
+ f"Problem: {task}\n"
1368
+ f"Current causal graph has {len(self.causal_graph)} variables and "
1369
+ f"{sum(len(children) for children in self.causal_graph.values())} relationships.\n\n"
1370
+ "CRITICAL: You MUST use the generate_causal_analysis tool to provide your analysis.\n"
1371
+ "The variables have already been extracted. Now you need to generate the causal analysis.\n"
1372
+ "Do NOT call extract_causal_variables again. You MUST call generate_causal_analysis with:\n"
1373
+ "- causal_analysis: Detailed analysis of causal relationships and mechanisms\n"
1374
+ "- intervention_planning: Planned interventions to test causal hypotheses\n"
1375
+ "- counterfactual_scenarios: List of what-if scenarios (array of objects)\n"
1376
+ "- causal_strength_assessment: Assessment of relationship strengths and confounders\n"
1377
+ "- optimal_solution: Recommended solution based on analysis\n"
1378
+ )
1379
+
1380
+ def _build_variable_extraction_prompt(self, task: str) -> str:
1381
+ """
1382
+ Build a prompt that guides the LLM to extract variables from a task.
1383
+
1384
+ Args:
1385
+ task: The task string to analyze
1386
+
1387
+ Returns:
1388
+ Formatted prompt string for variable extraction
1389
+ """
1390
+ return (
1391
+ "You are analyzing a causal reasoning task. The causal graph is currently empty.\n"
1392
+ f"Task: {task}\n\n"
1393
+ "**CRITICAL: You MUST use the extract_causal_variables tool to proceed.**\n"
1394
+ "Do NOT just describe what variables might be needed - you MUST call the tool.\n\n"
1395
+ "Call the extract_causal_variables tool with:\n"
1396
+ "1. required_variables: List of core variables needed for causal analysis\n"
1397
+ "2. causal_edges: List of [source, target] pairs showing causal relationships\n"
1398
+ "3. reasoning: Explanation of why these variables are needed\n"
1399
+ "4. optional_variables: (optional) Additional variables that may be useful\n"
1400
+ "5. counterfactual_variables: (optional) Variables to explore in what-if scenarios\n\n"
1401
+ "Example: For a pricing task, you might extract:\n"
1402
+ "- required_variables: ['price', 'demand', 'supply', 'cost']\n"
1403
+ "- causal_edges: [['price', 'demand'], ['cost', 'price'], ['supply', 'price']]\n"
1404
+ "- reasoning: 'Price affects demand, cost affects price, supply affects price'\n\n"
1405
+ "You can call the tool multiple times to refine your extraction.\n"
1406
+ "Be thorough - extract all variables and relationships implied by the task."
1407
+ )
1408
+
1409
+ def _build_memory_context(self) -> str:
1410
+ """Build memory context from causal_memory, handling different memory entry structures."""
1411
+ context_parts = []
1412
+ for step in self.causal_memory[-2:]: # Last 2 steps
1413
+ if isinstance(step, dict):
1414
+ # Handle standard analysis step structure
1415
+ if 'step' in step and 'analysis' in step:
1416
+ context_parts.append(f"Step {step['step']}: {step['analysis']}")
1417
+ # Handle extraction entry structure
1418
+ elif 'type' in step and step.get('type') == 'extraction':
1419
+ context_parts.append(f"Variable Extraction: {step.get('summary', 'Variables extracted')}")
1420
+ # Handle generic entry
1421
+ elif 'analysis' in step:
1422
+ context_parts.append(f"Analysis: {step['analysis']}")
1423
+ elif 'summary' in step:
1424
+ context_parts.append(f"Summary: {step['summary']}")
1425
+ return "\n".join(context_parts) if context_parts else ""
1426
+
1427
+ def _synthesize_causal_analysis(self, task: str) -> str:
1428
+ """
1429
+ Synthesize a final causal analysis report from the analysis steps.
1430
+
1431
+ Uses direct LLM call to avoid Agent tool execution issues.
1432
+
1433
+ Args:
1434
+ task: Original task string
1435
+
1436
+ Returns:
1437
+ Synthesized causal analysis report
1438
+ """
1439
+ synthesis_prompt = f"Based on the causal analysis steps performed, synthesize a concise causal report for: {task}"
1440
+ try:
1441
+ # Use direct LLM call to avoid tool execution errors
1442
+ response = self._call_llm_directly(synthesis_prompt)
1443
+ return str(response) if response else "Analysis synthesis failed"
1444
+ except Exception as e:
1445
+ logger.error(f"Error synthesizing causal analysis: {e}")
1446
+ return "Analysis synthesis failed"
1447
+
1448
+ def _should_trigger_causal_analysis(self, task: Optional[Union[str, Any]]) -> bool:
1449
+ """
1450
+ Automatically detect if a task should trigger causal analysis.
1451
+
1452
+ This method analyzes the task content to determine if it requires
1453
+ causal reasoning, counterfactual analysis, or relationship analysis.
1454
+
1455
+ Args:
1456
+ task: The task string to analyze
1457
+
1458
+ Returns:
1459
+ True if causal analysis should be triggered, False otherwise
1460
+ """
1461
+ if task is None:
1462
+ return False
1463
+
1464
+ if not isinstance(task, str):
1465
+ return False
1466
+
1467
+ task_lower = task.lower().strip()
1468
+
1469
+ # Keywords that indicate causal analysis is needed
1470
+ causal_keywords = [
1471
+ # Causal relationship terms
1472
+ 'causal', 'causality', 'cause', 'causes', 'caused by', 'causing',
1473
+ 'relationship', 'relationships', 'relate', 'relates', 'related',
1474
+ 'influence', 'influences', 'influenced', 'affect', 'affects', 'affected',
1475
+ 'impact', 'impacts', 'impacted', 'effect', 'effects',
1476
+ 'depend', 'depends', 'dependency', 'dependencies',
1477
+ 'correlation', 'correlate', 'correlates',
1478
+
1479
+ # Counterfactual analysis terms
1480
+ 'counterfactual', 'counterfactuals', 'what if', 'what-if',
1481
+ 'scenario', 'scenarios', 'alternative', 'alternatives',
1482
+ 'hypothetical', 'hypothesis', 'hypotheses',
1483
+ 'if then', 'if-then', 'suppose', 'assuming',
1484
+
1485
+ # Prediction and forecasting terms (often need causal reasoning)
1486
+ 'predict', 'prediction', 'forecast', 'forecasting', 'project', 'projection',
1487
+ 'expected', 'expect', 'expectation', 'estimate', 'estimation',
1488
+ 'future', 'future value', 'future price', 'in 24 months', 'in X months',
1489
+ 'will be', 'would be', 'could be', 'might be',
1490
+
1491
+ # Analysis terms
1492
+ 'analyze', 'analysis', 'analyzing', 'analyze the', 'analyze how',
1493
+ 'understand', 'understanding', 'explain', 'explanation',
1494
+ 'reasoning', 'reason', 'rationale',
1495
+
1496
+ # Relationship-specific terms
1497
+ 'between', 'among', 'link', 'links', 'connection', 'connections',
1498
+ 'chain', 'chains', 'path', 'paths', 'flow', 'flows',
1499
+
1500
+ # Intervention terms
1501
+ 'intervention', 'interventions', 'change', 'changes', 'modify', 'modifies',
1502
+ 'adjust', 'adjusts', 'alter', 'alters',
1503
+
1504
+ # Risk and consequence terms (NEW)
1505
+ 'risk', 'risks', 'consequence', 'consequences', 'benefit', 'benefits',
1506
+ 'trade-off', 'trade-offs', 'tradeoff', 'tradeoffs',
1507
+ 'downside', 'downsides', 'upside', 'upsides',
1508
+
1509
+ # Determination and outcome terms (NEW)
1510
+ 'determine', 'determines', 'determining', 'determination',
1511
+ 'result', 'results', 'resulting', 'outcome', 'outcomes',
1512
+ 'consideration', 'considerations', 'factor', 'factors',
1513
+ 'driver', 'drivers', 'driving', 'drives', 'driven',
1514
+ 'lead to', 'leads to', 'leading to', 'led to',
1515
+
1516
+ # Decision-making terms (NEW)
1517
+ 'should', 'should we', 'should i', 'should they',
1518
+ 'better', 'best', 'worse', 'worst', 'compare', 'comparison',
1519
+ 'option', 'options', 'choice', 'choices', 'choose',
1520
+ 'strategy', 'strategies', 'approach', 'approaches',
1521
+ 'decision', 'decisions', 'decide',
1522
+
1523
+ # Importance and consideration terms (NEW)
1524
+ 'important', 'importance', 'matter', 'matters', 'mattering',
1525
+ 'consider', 'considering', 'consideration', 'considerations',
1526
+ 'key', 'keys', 'critical', 'crucial', 'essential',
1527
+ ]
1528
+
1529
+ # Check if task contains causal keywords
1530
+ for keyword in causal_keywords:
1531
+ if keyword in task_lower:
1532
+ return True
1533
+
1534
+ # Check for questions about variables in the causal graph
1535
+ # If the task mentions any of our variables, it's likely a causal question
1536
+ graph_variables = [var.lower() for var in self.causal_graph.keys()]
1537
+ for var in graph_variables:
1538
+ if var in task_lower:
1539
+ return True
1540
+
1541
+ # Check for patterns that suggest causal reasoning
1542
+ causal_patterns = [
1543
+ 'how does', 'how do', 'how will', 'how would', 'how might', 'how can',
1544
+ 'why does', 'why do', 'why will', 'why would', 'why did',
1545
+ 'what happens if', 'what would happen', 'what will happen', 'what happens when',
1546
+ 'if we', 'if you', 'if they', 'if it', 'if this', 'if that',
1547
+ 'what should', 'what should we', 'what should i',
1548
+ 'which is', 'which are', 'which would', 'which will',
1549
+ 'what results', 'what results from', 'what comes', 'what comes next',
1550
+ 'what leads', 'what leads to', 'what follows', 'what follows from',
1551
+ ]
1552
+
1553
+ for pattern in causal_patterns:
1554
+ if pattern in task_lower:
1555
+ return True
1556
+
1557
+ return False
1558
+ def _ensure_node_exists(self, node: str) -> None:
1559
+
1560
+ if node not in self.causal_graph:
1561
+ self.causal_graph[node] = {}
1562
+ if node not in self.causal_graph_reverse:
1563
+ self.causal_graph_reverse[node] = []
1564
+ try:
1565
+ self._ensure_node_index(node)
1566
+ except Exception:
1567
+ pass
1568
+
1569
+ def add_causal_relationship(
1570
+ self,
1571
+ source: str,
1572
+ target: str,
1573
+ strength: float = 1.0,
1574
+ relation_type: CausalRelationType = CausalRelationType.DIRECT,
1575
+ confidence: float = 1.0
1576
+ ) -> None:
1577
+
1578
+ self._ensure_node_exists(source)
1579
+ self._ensure_node_exists(target)
1580
+
1581
+ meta = {
1582
+ "strength": float(strength),
1583
+ "relation_type": relation_type.value if isinstance(relation_type, Enum) else str(relation_type),
1584
+ "confidence": float(confidence),
1585
+ }
1586
+
1587
+ self.causal_graph.setdefault(source, {})[target] = meta
1588
+
1589
+ if source not in self.causal_graph_reverse.get(target, []):
1590
+ self.causal_graph_reverse.setdefault(target, []).append(source)
1591
+
1592
+ try:
1593
+ u_idx = self._ensure_node_index(source)
1594
+ v_idx = self._ensure_node_index(target)
1595
+ try:
1596
+ existing = self._graph.get_edge_data(u_idx, v_idx)
1597
+ except Exception:
1598
+ existing = None
1599
+
1600
+ if existing is None:
1601
+ try:
1602
+ self._graph.add_edge(u_idx, v_idx, meta)
1603
+ except Exception:
1604
+ try:
1605
+ import logging
1606
+ logging.getLogger(__name__).warning(
1607
+ f"rustworkx.add_edge failed for {source}->{target}; continuing with dict-only graph."
1608
+ )
1609
+ except Exception:
1610
+ pass
1611
+ else:
1612
+ try:
1613
+ if isinstance(existing, dict):
1614
+ existing.update(meta)
1615
+ try:
1616
+ import logging
1617
+ logging.getLogger(__name__).debug(
1618
+ f"Updated rustworkx edge data for {source}->{target} in-place."
1619
+ )
1620
+ except Exception:
1621
+ pass
1622
+ else:
1623
+ try:
1624
+ edge_idx = self._graph.get_edge_index(u_idx, v_idx)
1625
+ except Exception:
1626
+ edge_idx = None
1627
+ if edge_idx is not None and edge_idx >= 0:
1628
+ try:
1629
+ self._graph.remove_edge(edge_idx)
1630
+ self._graph.add_edge(u_idx, v_idx, meta)
1631
+ try:
1632
+ import logging
1633
+ logging.getLogger(__name__).debug(
1634
+ f"Replaced rustworkx edge for {source}->{target} with updated metadata."
1635
+ )
1636
+ except Exception:
1637
+ pass
1638
+ except Exception:
1639
+ try:
1640
+ import logging
1641
+ logging.getLogger(__name__).warning(
1642
+ f"Could not replace rustworkx edge for {source}->{target}; keeping dict-only metadata."
1643
+ )
1644
+ except Exception:
1645
+ pass
1646
+ else:
1647
+ try:
1648
+ import logging
1649
+ logging.getLogger(__name__).debug(
1650
+ f"rustworkx edge exists but index lookup failed for {source}->{target}; dict metadata used."
1651
+ )
1652
+ except Exception:
1653
+ pass
1654
+ except Exception:
1655
+ try:
1656
+ import logging
1657
+ logging.getLogger(__name__).warning(
1658
+ f"Failed updating rustworkx edge for {source}->{target}; continuing with dict-only graph."
1659
+ )
1660
+ except Exception:
1661
+ pass
1662
+ except Exception:
1663
+ try:
1664
+ import logging
1665
+ logging.getLogger(__name__).warning(
1666
+ "rustworkx operation failed during add_causal_relationship; continuing with dict-only graph."
1667
+ )
1668
+ except Exception:
1669
+ pass
1670
+
1671
+ def _get_parents(self, node: str) -> List[str]:
1672
+
1673
+ return self.causal_graph_reverse.get(node, [])
1674
+
1675
+ def _get_children(self, node: str) -> List[str]:
1676
+
1677
+ return list(self.causal_graph.get(node, {}).keys())
1678
+
1679
+ def _ensure_node_index(self, name: str) -> int:
1680
+
1681
+ if name in self._node_to_index:
1682
+ return self._node_to_index[name]
1683
+ idx = self._graph.add_node(name)
1684
+ self._node_to_index[name] = idx
1685
+ self._index_to_node[idx] = name
1686
+ return idx
1687
+
1688
+ def _node_index(self, name: str) -> Optional[int]:
1689
+
1690
+ return self._node_to_index.get(name)
1691
+
1692
+ def _node_name(self, idx: int) -> Optional[str]:
1693
+
1694
+ return self._index_to_node.get(idx)
1695
+
1696
+ def _edge_strength(self, source: str, target: str) -> float:
1697
+
1698
+ edge = self.causal_graph.get(source, {}).get(target, None)
1699
+ if isinstance(edge, dict):
1700
+ return float(edge.get("strength", 0.0))
1701
+ try:
1702
+ return float(edge) if edge is not None else 0.0
1703
+ except Exception:
1704
+ return 0.0
1705
+
1706
+ def _topological_sort(self) -> List[str]:
1707
+
1708
+ try:
1709
+ order_idx = rx.topological_sort(self._graph)
1710
+ result = [self._node_name(i) for i in order_idx if self._node_name(i) is not None]
1711
+ for n in list(self.causal_graph.keys()):
1712
+ if n not in result:
1713
+ result.append(n)
1714
+ return result
1715
+ except Exception:
1716
+ in_degree: Dict[str, int] = {node: 0 for node in self.causal_graph.keys()}
1717
+ for node in self.causal_graph:
1718
+ for child in self._get_children(node):
1719
+ in_degree[child] = in_degree.get(child, 0) + 1
1720
+
1721
+ queue: List[str] = [node for node, degree in in_degree.items() if degree == 0]
1722
+ result: List[str] = []
1723
+ while queue:
1724
+ node = queue.pop(0)
1725
+ result.append(node)
1726
+ for child in self._get_children(node):
1727
+ in_degree[child] -= 1
1728
+ if in_degree[child] == 0:
1729
+ queue.append(child)
1730
+ return result
1731
+
1732
def identify_causal_chain(self, start: str, end: str) -> List[str]:
    """Find the shortest directed causal path from *start* to *end* via BFS.

    Returns the node path including both endpoints, ``[start]`` when
    start == end, and ``[]`` when no path exists or *start* is unknown.

    Fix: the previous version also required *end* to be a key of
    ``causal_graph``, which wrongly rejected sink nodes that only appear
    as edge targets.  Reachability of *end* is now decided by the BFS
    itself (consistent with ``_has_path``, which performs no key check).
    """
    if start not in self.causal_graph:
        return []
    if start == end:
        return [start]

    queue: List[Tuple[str, List[str]]] = [(start, [start])]
    visited = {start}

    while queue:
        current, path = queue.pop(0)
        for child in self._get_children(current):
            if child == end:
                return path + [child]
            if child not in visited:
                visited.add(child)
                queue.append((child, path + [child]))

    return []  # No path found
1755
+
1756
+
1757
+ def _has_path(self, start: str, end: str) -> bool:
1758
+
1759
+ if start == end:
1760
+ return True
1761
+
1762
+ stack = [start]
1763
+ visited = set()
1764
+
1765
+ while stack:
1766
+ current = stack.pop()
1767
+ if current in visited:
1768
+ continue
1769
+ visited.add(current)
1770
+
1771
+ for child in self._get_children(current):
1772
+ if child == end:
1773
+ return True
1774
+ if child not in visited:
1775
+ stack.append(child)
1776
+
1777
+ return False
1778
+
1779
def clear_cache(self) -> None:
    """Drop every memoized prediction (thread-safe)."""
    with self._prediction_cache_lock:
        self._prediction_cache_order.clear()
        self._prediction_cache.clear()
1784
+
1785
def enable_cache(self, flag: bool) -> None:
    """Turn prediction memoization on or off (thread-safe)."""
    with self._prediction_cache_lock:
        self._cache_enabled = True if flag else False
1789
+
1790
+
1791
+ def _standardize_state(self, state: Dict[str, float]) -> Dict[str, float]:
1792
+
1793
+ z: Dict[str, float] = {}
1794
+ for k, v in state.items():
1795
+ s = self.standardization_stats.get(k)
1796
+ if s and s.get("std", 0.0) > 0:
1797
+ z[k] = (v - s["mean"]) / s["std"]
1798
+ else:
1799
+ z[k] = v
1800
+ return z
1801
+
1802
+ def _destandardize_value(self, var: str, z_value: float) -> float:
1803
+
1804
+ s = self.standardization_stats.get(var)
1805
+ if s and s.get("std", 0.0) > 0:
1806
+ return z_value * s["std"] + s["mean"]
1807
+ return z_value
1808
+
1809
+ def _predict_outcomes(
1810
+ self,
1811
+ factual_state: Dict[str, float],
1812
+ interventions: Dict[str, float]
1813
+ ) -> Dict[str, float]:
1814
+
1815
+ if self.use_nonlinear_scm:
1816
+ z_pred = self._predict_z(factual_state, interventions, use_noise=None)
1817
+ return {v: self._destandardize_value(v, z_val) for v, z_val in z_pred.items()}
1818
+
1819
+ raw = factual_state.copy()
1820
+ raw.update(interventions)
1821
+
1822
+ z_state = self._standardize_state(raw)
1823
+ z_pred = dict(z_state)
1824
+
1825
+ for node in self._topological_sort():
1826
+ if node in interventions:
1827
+ if node not in z_pred:
1828
+ z_pred[node] = z_state.get(node, 0.0)
1829
+ continue
1830
+
1831
+ parents = self._get_parents(node)
1832
+ if not parents:
1833
+ continue
1834
+
1835
+ s = 0.0
1836
+ for p in parents:
1837
+ pz = z_pred.get(p, z_state.get(p, 0.0))
1838
+ strength = self._edge_strength(p, node)
1839
+ s += pz * strength
1840
+
1841
+ z_pred[node] = s
1842
+
1843
+ return {v: self._destandardize_value(v, z) for v, z in z_pred.items()}
1844
+
1845
+ def _predict_outcomes_with_graph_variant(
1846
+ self,
1847
+ factual_state: Dict[str, float],
1848
+ interventions: Dict[str, float],
1849
+ graph_variant: Dict[Tuple[str, str], float]
1850
+ ) -> Dict[str, float]:
1851
+ """Predict outcomes using a temporary graph variant.
1852
+
1853
+ Temporarily applies a graph variant (perturbed edge strengths),
1854
+ runs prediction, then restores original graph.
1855
+
1856
+ Args:
1857
+ factual_state: Current factual state
1858
+ interventions: Interventions to apply
1859
+ graph_variant: Dictionary mapping (source, target) -> strength
1860
+
1861
+ Returns:
1862
+ Predicted outcomes
1863
+ """
1864
+ # Save original edge strengths
1865
+ original_strengths: Dict[Tuple[str, str], float] = {}
1866
+ for (u, v), variant_strength in graph_variant.items():
1867
+ if u in self.causal_graph and v in self.causal_graph[u]:
1868
+ edge = self.causal_graph[u][v]
1869
+ if isinstance(edge, dict):
1870
+ original_strengths[(u, v)] = edge.get("strength", 0.0)
1871
+ # Temporarily apply variant
1872
+ self.causal_graph[u][v]["strength"] = variant_strength
1873
+ else:
1874
+ original_strengths[(u, v)] = float(edge) if edge is not None else 0.0
1875
+ # Convert to dict format
1876
+ self.causal_graph[u][v] = {"strength": variant_strength, "confidence": 1.0}
1877
+
1878
+ try:
1879
+ # Run prediction with variant graph
1880
+ predictions = self._predict_outcomes(factual_state, interventions)
1881
+ finally:
1882
+ # Restore original edge strengths
1883
+ for (u, v), original_strength in original_strengths.items():
1884
+ if u in self.causal_graph and v in self.causal_graph[u]:
1885
+ edge = self.causal_graph[u][v]
1886
+ if isinstance(edge, dict):
1887
+ edge["strength"] = original_strength
1888
+ else:
1889
+ self.causal_graph[u][v] = original_strength
1890
+
1891
+ return predictions
1892
+
1893
+ def _predict_z(self, factual_state: Dict[str, float], interventions: Dict[str, float], use_noise: Optional[Dict[str, float]] = None) -> Dict[str, float]:
1894
+
1895
+ raw = factual_state.copy()
1896
+ raw.update(interventions)
1897
+ z_state = self._standardize_state(raw)
1898
+ z_pred: Dict[str, float] = dict(z_state)
1899
+
1900
+ for node in self._topological_sort():
1901
+ if node in interventions:
1902
+ z_pred[node] = z_state.get(node, 0.0)
1903
+ continue
1904
+
1905
+ parents = self._get_parents(node)
1906
+ if not parents:
1907
+ z_val = float(use_noise.get(node, 0.0)) if use_noise else z_state.get(node, 0.0)
1908
+ z_pred[node] = z_val
1909
+ continue
1910
+
1911
+ linear_term = 0.0
1912
+ for p in parents:
1913
+ parent_z = z_pred.get(p, z_state.get(p, 0.0))
1914
+ beta = self._edge_strength(p, node)
1915
+ linear_term += parent_z * beta
1916
+
1917
+ interaction_term = 0.0
1918
+ for (p1, p2) in self.interaction_terms.get(node, []):
1919
+ if p1 in parents and p2 in parents:
1920
+ z1 = z_pred.get(p1, z_state.get(p1, 0.0))
1921
+ z2 = z_pred.get(p2, z_state.get(p2, 0.0))
1922
+ gamma = 0.0
1923
+ edge_data = self.causal_graph.get(p1, {}).get(node, {})
1924
+ if isinstance(edge_data, dict):
1925
+ gamma = float(edge_data.get("interaction_strength", {}).get(p2, 0.0))
1926
+ interaction_term += gamma * z1 * z2
1927
+
1928
+ model_z = linear_term + interaction_term
1929
+
1930
+ if use_noise:
1931
+ model_z += float(use_noise.get(node, 0.0))
1932
+
1933
+ if self.nonlinear_activation == "tanh":
1934
+ model_z_act = float(np.tanh(model_z) * 3.0) # scale to limit
1935
+ else:
1936
+ model_z_act = float(model_z)
1937
+
1938
+ observed_z = z_state.get(node, 0.0)
1939
+
1940
+ threshold = float(getattr(self, "shock_preserve_threshold", 1e-3))
1941
+ if abs(observed_z) > threshold:
1942
+ z_pred[node] = float(observed_z)
1943
+ else:
1944
+ z_pred[node] = float(model_z_act)
1945
+
1946
+ return z_pred
1947
+
1948
def aap(self, factual_state: Dict[str, float], interventions: Dict[str, float]) -> Dict[str, float]:
    """Shorthand alias for :meth:`counterfactual_abduction_action_prediction`."""
    return self.counterfactual_abduction_action_prediction(
        factual_state, interventions
    )
1951
+
1952
+ def _predict_outcomes_cached(
1953
+ self,
1954
+ factual_state: Dict[str, float],
1955
+ interventions: Dict[str, float],
1956
+ ) -> Dict[str, float]:
1957
+
1958
+ with self._prediction_cache_lock:
1959
+ cache_enabled = self._cache_enabled
1960
+ if not cache_enabled:
1961
+ return self._predict_outcomes(factual_state, interventions)
1962
+
1963
+ state_key = tuple(sorted([(k, float(v)) for k, v in factual_state.items()]))
1964
+ inter_key = tuple(sorted([(k, float(v)) for k, v in interventions.items()]))
1965
+ cache_key = (state_key, inter_key)
1966
+
1967
+ with self._prediction_cache_lock:
1968
+ if cache_key in self._prediction_cache:
1969
+ return dict(self._prediction_cache[cache_key])
1970
+
1971
+ result = self._predict_outcomes(factual_state, interventions)
1972
+
1973
+ with self._prediction_cache_lock:
1974
+ if len(self._prediction_cache_order) >= self._prediction_cache_max:
1975
+ remove_count = max(1, self._prediction_cache_max // 10)
1976
+ for _ in range(remove_count):
1977
+ old = self._prediction_cache_order.pop(0)
1978
+ if old in self._prediction_cache:
1979
+ del self._prediction_cache[old]
1980
+
1981
+ self._prediction_cache_order.append(cache_key)
1982
+ self._prediction_cache[cache_key] = dict(result)
1983
+ return result
1984
+
1985
+ def _get_descendants(self, node: str) -> List[str]:
1986
+
1987
+ if node not in self.causal_graph:
1988
+ return []
1989
+ stack = [node]
1990
+ visited = set()
1991
+ descendants: List[str] = []
1992
+ while stack:
1993
+ cur = stack.pop()
1994
+ for child in self._get_children(cur):
1995
+ if child in visited:
1996
+ continue
1997
+ visited.add(child)
1998
+ descendants.append(child)
1999
+ stack.append(child)
2000
+ return descendants
2001
+
2002
def counterfactual_abduction_action_prediction(
    self,
    factual_state: Dict[str, float],
    interventions: Dict[str, float]
) -> Dict[str, float]:
    """Three-step counterfactual query: abduction, action, prediction.

    Abduction infers each node's exogenous noise as the residual between
    its observed z-value and the linear parental prediction; action clamps
    the intervened variables; prediction re-propagates the noise through
    the standardized linear SCM and de-standardizes the result.
    """
    z_factual = self._standardize_state(factual_state)
    order = self._topological_sort()

    # Abduction: residual noise per node from the factual world.
    residuals: Dict[str, float] = {}
    for node in order:
        node_parents = self._get_parents(node)
        if not node_parents:
            residuals[node] = float(z_factual.get(node, 0.0))
            continue

        expected = 0.0
        for parent in node_parents:
            expected += z_factual.get(parent, 0.0) * self._edge_strength(parent, node)

        residuals[node] = float(z_factual.get(node, 0.0) - expected)

    # Action: overwrite intervened variables in the counterfactual world.
    cf_state = dict(factual_state)
    cf_state.update(interventions)
    z_cf = self._standardize_state(cf_state)

    # Prediction: propagate the abducted noise through the graph.
    z_result: Dict[str, float] = {}
    for node in order:
        if node in interventions:
            z_result[node] = float(z_cf.get(node, 0.0))
            continue

        node_parents = self._get_parents(node)
        if not node_parents:
            z_result[node] = float(residuals.get(node, 0.0))
            continue

        acc = 0.0
        for parent in node_parents:
            parent_z = z_result.get(parent, z_cf.get(parent, 0.0))
            acc += parent_z * self._edge_strength(parent, node)

        z_result[node] = float(acc + residuals.get(node, 0.0))

    return {var: self._destandardize_value(var, z) for var, z in z_result.items()}
2049
+
2050
def detect_confounders(self, treatment: str, outcome: str) -> List[str]:
    """Return the common ancestors of *treatment* and *outcome*.

    Common ancestors are potential confounders.  Either endpoint missing
    from the graph yields an empty list.
    """
    if treatment not in self.causal_graph or outcome not in self.causal_graph:
        return []

    def collect_ancestors(start: str) -> set:
        # Walk parent links transitively.
        seen: set = set()
        frontier = [start]
        while frontier:
            cur = frontier.pop()
            for parent in self._get_parents(cur):
                if parent not in seen:
                    seen.add(parent)
                    frontier.append(parent)
        return seen

    return list(collect_ancestors(treatment) & collect_ancestors(outcome))
2071
+
2072
def identify_adjustment_set(self, treatment: str, outcome: str) -> List[str]:
    """Back-door-style adjustment set for estimating treatment -> outcome.

    Takes the parents of *treatment* and removes *treatment*'s descendants
    and the outcome itself.  Either endpoint missing yields an empty list.
    """
    if treatment not in self.causal_graph or outcome not in self.causal_graph:
        return []

    blocked = set(self._get_descendants(treatment))
    blocked.add(outcome)
    return [p for p in set(self._get_parents(treatment)) if p not in blocked]
2081
+
2082
+ def _calculate_scenario_probability(
2083
+ self,
2084
+ factual_state: Dict[str, float],
2085
+ interventions: Dict[str, float]
2086
+ ) -> float:
2087
+
2088
+ z_sq = 0.0
2089
+ for var, new in interventions.items():
2090
+ s = self.standardization_stats.get(var, {"mean": 0.0, "std": 1.0})
2091
+ mu, sd = s.get("mean", 0.0), s.get("std", 1.0) or 1.0
2092
+ old = factual_state.get(var, mu)
2093
+ dz = (new - mu) / sd - (old - mu) / sd
2094
+ z_sq += float(dz) * float(dz)
2095
+
2096
+ p = 0.95 * float(np.exp(-0.5 * z_sq)) + 0.05
2097
+ return float(max(0.05, min(0.98, p)))
2098
+
2099
def generate_counterfactual_scenarios(
    self,
    factual_state: Dict[str, float],
    target_variables: List[str],
    max_scenarios: int = 5,
    use_monte_carlo: bool = True,
    mc_samples: int = 1000,
    parallel_sampling: bool = True,
    use_deterministic_fallback: bool = False
) -> List[CounterfactualScenario]:
    """Generate counterfactual scenarios using meta-Monte Carlo reasoning.

    Args:
        factual_state: Current factual state
        target_variables: Variables to generate scenarios for
        max_scenarios: Maximum number of scenarios to return
        use_monte_carlo: Whether to use Monte Carlo sampling (default: True)
        mc_samples: Number of Monte Carlo iterations (default: 1000)
        parallel_sampling: Whether to sample graph and interventions in parallel (default: True)
        use_deterministic_fallback: Fallback to old deterministic method if MC fails (default: False)

    Returns:
        List of CounterfactualScenario objects with uncertainty metadata

    Raises:
        Exception: re-raised from the Monte Carlo pipeline when
            ``use_deterministic_fallback`` is False.
    """
    # Fallback to old deterministic method if requested
    if not use_monte_carlo:
        return self._generate_counterfactual_scenarios_deterministic(
            factual_state, target_variables, max_scenarios
        )

    try:
        self.ensure_standardization_stats(factual_state)

        # Initialize helper classes
        intervention_sampler = _AdaptiveInterventionSampler(self)
        graph_sampler = _GraphUncertaintySampler(self)
        quality_assessor = _PredictionQualityAssessor(self)
        meta_analyzer = _MetaReasoningAnalyzer(self)

        # Get uncertainty data if available (from quantify_uncertainty)
        uncertainty_data = None
        # Try to get from cached uncertainty results if available
        if hasattr(self, '_cached_uncertainty_data'):
            uncertainty_data = self._cached_uncertainty_data

        # Parallel sampling: graph variations and interventions
        import concurrent.futures

        # Far fewer graph variants than intervention draws: every variant
        # multiplies the per-intervention prediction cost below.
        n_graph_samples = max(10, mc_samples // 100)  # Sample fewer graph variants
        n_intervention_samples = mc_samples

        if parallel_sampling and self.bootstrap_workers > 0:
            # Parallel execution
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.bootstrap_workers) as executor:
                graph_future = executor.submit(
                    graph_sampler.sample_graph_variations,
                    n_graph_samples,
                    uncertainty_data
                )
                intervention_future = executor.submit(
                    intervention_sampler.sample_interventions,
                    factual_state,
                    target_variables,
                    n_intervention_samples
                )

                graph_future.result()  # noqa: intentionally resolved below
                graph_variations = graph_future.result()
                interventions_list = intervention_future.result()
        else:
            # Sequential execution
            graph_variations = graph_sampler.sample_graph_variations(
                n_graph_samples,
                uncertainty_data
            )
            interventions_list = intervention_sampler.sample_interventions(
                factual_state,
                target_variables,
                n_intervention_samples
            )

        # For each intervention, evaluate across graph variations
        scenarios_with_metadata = []

        for intervention in interventions_list:
            # Get predictions across all graph variations
            predictions_across_variants = []
            for graph_variant in graph_variations:
                pred = self._predict_outcomes_with_graph_variant(
                    factual_state,
                    intervention,
                    graph_variant
                )
                predictions_across_variants.append(pred)

            # Assess prediction quality
            quality_score, quality_metrics = quality_assessor.assess_quality(
                predictions_across_variants,
                factual_state,
                intervention
            )

            # Aggregate predictions (mean across variants)
            aggregated_outcomes = {}
            for var in set().union(*[p.keys() for p in predictions_across_variants]):
                values = [p.get(var, 0.0) for p in predictions_across_variants]
                aggregated_outcomes[var] = float(np.mean(values))

            # Determine sampling distribution used (first intervened var wins)
            sampling_dist = "adaptive"  # Default
            for var in intervention.keys():
                dist_type, _ = intervention_sampler._get_adaptive_distribution(var, factual_state)
                sampling_dist = dist_type
                break

            # Create scenario with metadata
            scenario = CounterfactualScenario(
                name=f"mc_scenario_{len(scenarios_with_metadata)}",
                interventions=intervention,
                expected_outcomes=aggregated_outcomes,
                probability=self._calculate_scenario_probability(factual_state, intervention),
                reasoning=f"Monte Carlo sampled intervention with quality score {quality_score:.3f}",
                uncertainty_metadata={
                    "quality_score": quality_score,
                    "quality_metrics": quality_metrics,
                    "graph_variations_tested": len(graph_variations),
                    "prediction_variance": {
                        var: float(np.var([p.get(var, 0.0) for p in predictions_across_variants]))
                        for var in aggregated_outcomes.keys()
                    }
                },
                sampling_distribution=sampling_dist,
                monte_carlo_iterations=mc_samples,
                meta_reasoning_score=None  # Will be set by meta_analyzer
            )

            scenarios_with_metadata.append((
                scenario,
                {
                    "factual_state": factual_state,
                    "quality_score": quality_score,
                    "quality_metrics": quality_metrics
                }
            ))

        # Meta-reasoning analysis: rank scenarios by informativeness
        ranked_scenarios = meta_analyzer.analyze_scenarios(scenarios_with_metadata)

        # Update meta_reasoning_score in scenarios
        final_scenarios = []
        for scenario, meta_score in ranked_scenarios[:max_scenarios]:
            # Update scenario with meta-reasoning score
            scenario.meta_reasoning_score = meta_score
            final_scenarios.append(scenario)

        return final_scenarios

    except Exception as e:
        logger.error(f"Monte Carlo counterfactual generation failed: {e}")
        if use_deterministic_fallback:
            logger.warning("Falling back to deterministic method")
            return self._generate_counterfactual_scenarios_deterministic(
                factual_state, target_variables, max_scenarios
            )
        else:
            raise
2264
+
2265
def _generate_counterfactual_scenarios_deterministic(
    self,
    factual_state: Dict[str, float],
    target_variables: List[str],
    max_scenarios: int = 5
) -> List[CounterfactualScenario]:
    """Grid-based fallback scenario generation (no Monte Carlo).

    For each target variable (at most *max_scenarios* of them), sweeps a
    fixed ladder of offsets — z-score steps when usable stats exist,
    absolute steps otherwise — and builds one scenario per offset.

    Args:
        factual_state: Current factual state.
        target_variables: Variables to perturb.
        max_scenarios: Maximum number of target variables to sweep
            (note: each variable yields one scenario per offset).

    Returns:
        List of CounterfactualScenario objects.
    """
    self.ensure_standardization_stats(factual_state)

    offsets = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
    generated: List[CounterfactualScenario] = []

    for var_idx, var in enumerate(target_variables[:max_scenarios]):
        stats = self.standardization_stats.get(var, {"mean": 0.0, "std": 1.0})
        current = factual_state.get(var, stats.get("mean", 0.0))

        if not stats or stats.get("std", 0.0) <= 0:
            # No usable spread: absolute offsets around the current value.
            candidate_values = [current + off for off in offsets]
        else:
            mean = stats["mean"]
            std = stats["std"]
            current_z = (current - mean) / std
            candidate_values = [(current_z + off) * std + mean for off in offsets]

        for val_idx, candidate in enumerate(candidate_values):
            intervention = {var: float(candidate)}
            generated.append(
                CounterfactualScenario(
                    name=f"scenario_{var_idx}_{val_idx}",
                    interventions=intervention,
                    expected_outcomes=self._predict_outcomes(
                        factual_state, intervention
                    ),
                    probability=self._calculate_scenario_probability(
                        factual_state, intervention
                    ),
                    reasoning=f"Intervention on {var} with value {candidate}",
                )
            )

    return generated
2317
+
2318
def analyze_causal_strength(self, source: str, target: str) -> Dict[str, float]:
    """Summarize the direct edge source->target.

    Returns strength, confidence, shortest-path length (in hops) and the
    relation type; a missing edge yields zeros with an infinite path
    length.
    """
    targets = self.causal_graph.get(source) if source in self.causal_graph else None
    if targets is None or target not in targets:
        return {"strength": 0.0, "confidence": 0.0, "path_length": float('inf')}

    edge = targets.get(target, {})
    is_meta = isinstance(edge, dict)
    strength = float(edge.get("strength", 0.0)) if is_meta else float(edge)

    chain = self.identify_causal_chain(source, target)
    hops = len(chain) - 1 if chain else float('inf')

    return {
        "strength": float(strength),
        "confidence": float(edge.get("confidence", 1.0)) if is_meta else 1.0,
        "path_length": hops,
        "relation_type": edge.get("relation_type", CausalRelationType.DIRECT.value) if is_meta else CausalRelationType.DIRECT.value
    }
2334
+
2335
def set_standardization_stats(
    self,
    variable: str,
    mean: float,
    std: float
) -> None:
    """Record (mean, std) for *variable*; non-positive std is coerced to 1.0."""
    safe_std = std if std > 0 else 1.0
    self.standardization_stats[variable] = {"mean": mean, "std": safe_std}
2343
+
2344
def ensure_standardization_stats(self, state: Dict[str, float]) -> None:
    """Seed default stats (mean=value, std=1.0) for any unseen variable."""
    for var, val in state.items():
        self.standardization_stats.setdefault(var, {"mean": float(val), "std": 1.0})
2349
+
2350
def get_nodes(self) -> List[str]:
    """Return all nodes currently registered in the dict-based causal graph."""
    return [node for node in self.causal_graph]
2353
+
2354
def get_edges(self) -> List[Tuple[str, str]]:
    """Return all (source, target) edges of the dict-based causal graph."""
    return [
        (src, dst)
        for src, dst_map in self.causal_graph.items()
        for dst in dst_map
    ]
2361
+
2362
def is_dag(self) -> bool:
    """True when the causal graph contains no directed cycle.

    Uses rustworkx when the mirror graph is usable; otherwise falls back
    to a recursive DFS with a recursion-stack marker over the dict graph.
    """
    try:
        return rx.is_directed_acyclic_graph(self._graph)
    except Exception:
        visiting: set = set()
        done: set = set()

        def cyclic(node: str) -> bool:
            # A child found on the current recursion stack is a back-edge.
            done.add(node)
            visiting.add(node)
            for child in self._get_children(node):
                if child not in done:
                    if cyclic(child):
                        return True
                elif child in visiting:
                    return True
            visiting.discard(node)
            return False

        for start in self.causal_graph:
            if start not in done and cyclic(start):
                return False
        return True
2391
+
2392
def run(
    self,
    task: Optional[Union[str, Any]] = None,
    img: Optional[str] = None,
    imgs: Optional[List[str]] = None,
    correct_answer: Optional[str] = None,
    streaming_callback: Optional[Any] = None,
    n: int = 1,
    initial_state: Optional[Any] = None,
    target_variables: Optional[List[str]] = None,
    max_steps: Union[int, str] = 1,
    *args,
    **kwargs,
) -> Union[Dict[str, Any], Any]:
    """
    Run the agent with support for both standard Agent features and causal analysis.

    This method maintains compatibility with the parent Agent class while adding
    causal reasoning capabilities. It routes to causal analysis when appropriate,
    otherwise delegates to the parent Agent's standard functionality.

    Args:
        task: Task string for LLM analysis, or state dict for causal evolution
        img: Optional image path for vision tasks (delegates to parent)
        imgs: Optional list of images (delegates to parent)
        correct_answer: Optional correct answer for validation (delegates to parent)
        streaming_callback: Optional callback for streaming output (delegates to parent)
        n: Number of runs (delegates to parent)
        initial_state: Initial state dictionary for causal evolution
        target_variables: Target variables for counterfactual analysis
        max_steps: Maximum evolution steps for causal analysis; "auto" or a
            non-integer value falls back to one step per graph node
        *args: Additional positional arguments
        **kwargs: Additional keyword arguments

    Returns:
        Dictionary with causal analysis results, or standard Agent output
    """
    # Check if this is a causal analysis operation
    # Criteria: initial_state or target_variables explicitly provided, or task is a dict,
    # OR automatic detection based on task content
    is_causal_operation = (
        initial_state is not None or
        target_variables is not None or
        (task is not None and isinstance(task, dict)) or
        (task is not None and isinstance(task, str) and task.strip().startswith('{')) or
        self._should_trigger_causal_analysis(task)  # Automatic detection
    )

    # Delegate to parent Agent for all standard operations
    # This includes: images, streaming, handoffs, multiple runs, regular text tasks, etc.
    # Only use causal analysis if explicitly requested via causal operation indicators
    if not is_causal_operation:
        # All standard Agent operations go to parent (handoffs, images, streaming, etc.)
        return super().run(
            task=task,
            img=img,
            imgs=imgs,
            correct_answer=correct_answer,
            streaming_callback=streaming_callback,
            n=n,
            *args,
            **kwargs,
        )

    # Causal analysis operations - only when explicitly indicated
    # A plain-text task (no embedded JSON, no explicit state) goes to the
    # LLM-driven causal analysis path.
    if task is not None and isinstance(task, str) and initial_state is None and not task.strip().startswith('{'):
        return self._run_llm_causal_analysis(task, **kwargs)

    # A dict/JSON task doubles as the initial state when none was given.
    if task is not None and initial_state is None:
        initial_state = task

    if not isinstance(initial_state, dict):
        try:
            import json
            parsed = json.loads(initial_state)
            if isinstance(parsed, dict):
                initial_state = parsed
            else:
                return {"error": "initial_state JSON must decode to a dict"}
        except Exception:
            return {"error": "initial_state must be a dict or JSON-encoded dict"}

    if target_variables is None:
        target_variables = list(self.causal_graph.keys())

    def _resolve_max_steps(value: Union[int, str]) -> int:
        # "auto" (or any non-integer value) falls back to one step per node.
        if isinstance(value, str) and value == "auto":
            return max(1, len(self.causal_graph))
        try:
            return int(value)
        except Exception:
            return max(1, len(self.causal_graph))

    # The default max_steps (1) defers to the configured causal_max_loops
    # when that setting differs from 1.
    effective_steps = _resolve_max_steps(max_steps if max_steps != 1 or self.causal_max_loops == 1 else self.causal_max_loops)
    if max_steps == 1 and self.causal_max_loops != 1:
        effective_steps = _resolve_max_steps(self.causal_max_loops)

    # Evolve the state forward by repeatedly applying the intervention-free SCM.
    current_state = initial_state.copy()
    for step in range(effective_steps):
        current_state = self._predict_outcomes(current_state, {})

    self.ensure_standardization_stats(current_state)
    counterfactual_scenarios = self.generate_counterfactual_scenarios(
        current_state,
        target_variables,
        max_scenarios=5
    )

    return {
        "initial_state": initial_state,
        "evolved_state": current_state,
        "counterfactual_scenarios": counterfactual_scenarios,
        "causal_graph_info": {
            "nodes": self.get_nodes(),
            "edges": self.get_edges(),
            "is_dag": self.is_dag()
        },
        "steps": effective_steps
    }
2511
+
2512
def _call_llm_directly(self, prompt: str, system_prompt: Optional[str] = None) -> str:
    """
    Call the LLM directly using litellm.completion(), bypassing Agent tool execution.

    Builds a plain chat request from the Agent's model configuration so the
    response is raw text/JSON with no function calling involved.

    Args:
        prompt: The user prompt to send to the LLM
        system_prompt: Optional system prompt (defaults to Agent's system prompt)

    Returns:
        The raw text response from the LLM

    Raises:
        ImportError: when litellm is not installed.
        Exception: whatever litellm raises, after logging.
    """
    try:
        import litellm
    except ImportError:
        raise ImportError("litellm is required for direct LLM calls")

    # Pull model configuration off the Agent, with env fallback for the key.
    model_name = getattr(self, 'model_name', 'gpt-4o')
    api_key = getattr(self, 'llm_api_key', None) or os.getenv('OPENAI_API_KEY')
    base_url = getattr(self, 'llm_base_url', None)

    if system_prompt is None:
        system_prompt = getattr(self, 'system_prompt', None)

    # Assemble the chat transcript (system message only when present).
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    try:
        response = litellm.completion(
            model=model_name,
            messages=messages,
            api_key=api_key,
            api_base=base_url,
            temperature=getattr(self, 'temperature', 0.5),
            max_tokens=getattr(self, 'max_tokens', 4096),
        )

        # Extract the first choice's text; empty string when nothing came back.
        if response and hasattr(response, 'choices') and len(response.choices) > 0:
            content = response.choices[0].message.content
            return content if content else ""
        else:
            logger.warning("Empty response from LLM")
            return ""
    except Exception as e:
        logger.error(f"Error calling LLM directly: {e}")
        raise
2567
+
2568
+ def _extract_variables_ml_based(self, task: str) -> bool:
2569
+ """
2570
+ Extract variables using ML/NLP-based approach - use LLM to generate structured output,
2571
+ parse the function call from the response, and automatically invoke the handler.
2572
+
2573
+ This bypasses unreliable function calling by directly parsing LLM output and
2574
+ invoking the handler programmatically.
2575
+
2576
+ Args:
2577
+ task: The task string to analyze
2578
+
2579
+ Returns:
2580
+ True if variables were successfully extracted, False otherwise
2581
+ """
2582
+ import json
2583
+ import re
2584
+
2585
+ logger.info(f"Starting ML-based variable extraction for task: {task[:100]}...")
2586
+
2587
+ # Use LLM to generate structured JSON output with variables and edges
2588
+ extraction_prompt = f"""Analyze this task and extract causal variables and relationships.
2589
+
2590
+ Task: {task}
2591
+
2592
+ Return a JSON object with this exact structure:
2593
+ {{
2594
+ "required_variables": ["var1", "var2", "var3"],
2595
+ "causal_edges": [["var1", "var2"], ["var2", "var3"]],
2596
+ "reasoning": "Brief explanation of why these variables are needed",
2597
+ "optional_variables": ["var4"],
2598
+ "counterfactual_variables": ["var1", "var2"]
2599
+ }}
2600
+
2601
+ Extract ALL relevant variables and causal relationships. Be thorough.
2602
+ Return ONLY valid JSON, no other text."""
2603
+
2604
+ try:
2605
+ # Call LLM directly (bypasses Agent tool execution)
2606
+ # This gives us pure text/JSON response without function calling
2607
+ # The LLM is still the core component - we just parse its output programmatically
2608
+ raw_response = self._call_llm_directly(extraction_prompt)
2609
+
2610
+ extracted_data = None
2611
+
2612
+ # Parse JSON from LLM text response
2613
+ # The LLM returns plain text with JSON embedded, so we extract it
2614
+ response_text = str(raw_response)
2615
+
2616
+ # Try to extract JSON from response text
2617
+ json_match = re.search(r'\{[^{}]*"required_variables"[^{}]*\}', response_text, re.DOTALL)
2618
+ if not json_match:
2619
+ # Try to find any JSON object
2620
+ json_match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', response_text, re.DOTALL)
2621
+
2622
+ if json_match:
2623
+ json_str = json_match.group(0)
2624
+ try:
2625
+ # Try to fix common JSON issues
2626
+ # Remove invalid escape sequences
2627
+ json_str = json_str.replace('\\"', '"').replace("\\'", "'")
2628
+ # Try parsing
2629
+ extracted_data = json.loads(json_str)
2630
+ logger.info("Parsed JSON from text response")
2631
+ except json.JSONDecodeError as e:
2632
+ # Try to extract just the JSON object more carefully
2633
+ try:
2634
+ # Find the innermost complete JSON object
2635
+ brace_count = 0
2636
+ start_idx = json_str.find('{')
2637
+ if start_idx >= 0:
2638
+ for i in range(start_idx, len(json_str)):
2639
+ if json_str[i] == '{':
2640
+ brace_count += 1
2641
+ elif json_str[i] == '}':
2642
+ brace_count -= 1
2643
+ if brace_count == 0:
2644
+ json_str = json_str[start_idx:i+1]
2645
+ extracted_data = json.loads(json_str)
2646
+ logger.info("Parsed JSON after fixing structure")
2647
+ break
2648
+ except (json.JSONDecodeError, ValueError) as e2:
2649
+ logger.warning(f"Failed to parse JSON after fixes: {e2}")
2650
+ logger.debug(f"JSON string was: {json_str[:500]}")
2651
+
2652
+ # Process extracted data
2653
+ if extracted_data:
2654
+ # Validate structure
2655
+ required_vars = extracted_data.get("required_variables", [])
2656
+ causal_edges = extracted_data.get("causal_edges", [])
2657
+ reasoning = extracted_data.get("reasoning", "Extracted from task analysis")
2658
+ optional_vars = extracted_data.get("optional_variables", [])
2659
+ counterfactual_vars = extracted_data.get("counterfactual_variables", [])
2660
+
2661
+ if required_vars and causal_edges:
2662
+ # Automatically invoke the handler with extracted data
2663
+ logger.info(f"Extracted {len(required_vars)} variables, {len(causal_edges)} edges via ML")
2664
+ result = self._extract_causal_variables_handler(
2665
+ required_variables=required_vars,
2666
+ causal_edges=causal_edges,
2667
+ reasoning=reasoning,
2668
+ optional_variables=optional_vars if optional_vars else None,
2669
+ counterfactual_variables=counterfactual_vars if counterfactual_vars else None,
2670
+ )
2671
+
2672
+ # Check if extraction was successful
2673
+ if result.get("status") == "success" and result.get("summary", {}).get("variables_added", 0) > 0:
2674
+ logger.info(f"ML-based extraction successful: {result.get('summary')}")
2675
+ return True
2676
+ else:
2677
+ logger.warning(f"ML-based extraction returned: {result}")
2678
+ else:
2679
+ logger.warning(f"Extracted data missing required fields. required_variables: {required_vars}, causal_edges: {causal_edges}")
2680
+ else:
2681
+ logger.warning("Could not extract data from LLM response")
2682
+ logger.debug(f"Raw response type: {type(raw_response)}, value: {str(raw_response)[:500]}")
2683
+
2684
+ except Exception as e:
2685
+ logger.error(f"Error in ML-based extraction: {e}")
2686
+ import traceback
2687
+ logger.debug(traceback.format_exc())
2688
+
2689
+ return len(self.causal_graph) > 0
2690
+
2691
+ def _generate_causal_analysis_ml_based(self, task: str) -> Optional[Dict[str, Any]]:
2692
+ """
2693
+ Generate causal analysis using ML-based approach - use LLM to generate structured output,
2694
+ parse the function call from the response, and automatically invoke the handler.
2695
+
2696
+ This bypasses unreliable function calling by directly parsing LLM output and
2697
+ invoking the handler programmatically.
2698
+
2699
+ Args:
2700
+ task: The task string to analyze
2701
+
2702
+ Returns:
2703
+ Dictionary with causal analysis results, or None if extraction failed
2704
+ """
2705
+ import json
2706
+ import re
2707
+
2708
+ logger.info(f"Starting ML-based causal analysis generation for task: {task[:100]}...")
2709
+
2710
+ # Build comprehensive causal analysis prompt
2711
+ causal_prompt = self._build_causal_prompt(task)
2712
+
2713
+ # Add instruction for structured output
2714
+ analysis_prompt = f"""{causal_prompt}
2715
+
2716
+ CRITICAL: You must return a JSON object with this EXACT structure. Do not include any text before or after the JSON.
2717
+
2718
+ Required JSON format:
2719
+ {{
2720
+ "causal_analysis": "Detailed analysis of causal relationships and mechanisms. This must be a comprehensive text explanation.",
2721
+ "intervention_planning": "Planned interventions to test causal hypotheses.",
2722
+ "counterfactual_scenarios": [
2723
+ {{
2724
+ "scenario_name": "Scenario 1",
2725
+ "interventions": {{"var1": 10, "var2": 20}},
2726
+ "expected_outcomes": {{"target_var": 30}},
2727
+ "reasoning": "Why this scenario is important..."
2728
+ }}
2729
+ ],
2730
+ "causal_strength_assessment": "Assessment of relationship strengths and confounders.",
2731
+ "optimal_solution": "Recommended solution based on analysis."
2732
+ }}
2733
+
2734
+ IMPORTANT:
2735
+ - The "causal_analysis" field is REQUIRED and must contain detailed text analysis
2736
+ - Return ONLY the JSON object, no markdown, no code blocks, no explanations
2737
+ - Ensure all fields are present and properly formatted"""
2738
+
2739
+ try:
2740
+ # Call LLM directly (bypasses Agent tool execution)
2741
+ # This gives us pure text/JSON response without function calling
2742
+ # The LLM is still the core component - we just parse its output programmatically
2743
+ raw_response = self._call_llm_directly(analysis_prompt)
2744
+
2745
+ extracted_data = None
2746
+
2747
+ # Parse JSON from LLM text response
2748
+ # The LLM returns plain text with JSON embedded, so we extract it
2749
+ response_text = str(raw_response)
2750
+
2751
+ # First, try to extract JSON from markdown code blocks (```json ... ```)
2752
+ json_block_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', response_text, re.DOTALL)
2753
+ if json_block_match:
2754
+ json_str = json_block_match.group(1)
2755
+ try:
2756
+ extracted_data = json.loads(json_str)
2757
+ logger.info("Parsed JSON from markdown code block")
2758
+ except json.JSONDecodeError:
2759
+ pass
2760
+
2761
+ # If not found in code block, try to extract JSON directly
2762
+ if not extracted_data:
2763
+ # Try to find JSON object with causal_analysis field
2764
+ json_match = re.search(r'\{[^{}]*"causal_analysis"[^{}]*\}', response_text, re.DOTALL)
2765
+ if not json_match:
2766
+ # Try to find any complete JSON object (handle nested objects)
2767
+ # This regex finds the outermost complete JSON object
2768
+ brace_count = 0
2769
+ start_idx = response_text.find('{')
2770
+ if start_idx >= 0:
2771
+ for i in range(start_idx, len(response_text)):
2772
+ if response_text[i] == '{':
2773
+ brace_count += 1
2774
+ elif response_text[i] == '}':
2775
+ brace_count -= 1
2776
+ if brace_count == 0:
2777
+ json_str = response_text[start_idx:i+1]
2778
+ try:
2779
+ # Try to fix common JSON issues
2780
+ json_str = json_str.replace('\\"', '"').replace("\\'", "'")
2781
+ extracted_data = json.loads(json_str)
2782
+ logger.info("Parsed JSON from text response")
2783
+ break
2784
+ except json.JSONDecodeError:
2785
+ # Try without the escape fixes
2786
+ try:
2787
+ extracted_data = json.loads(response_text[start_idx:i+1])
2788
+ logger.info("Parsed JSON after removing escape fixes")
2789
+ break
2790
+ except json.JSONDecodeError as e:
2791
+ logger.debug(f"Failed to parse JSON: {e}")
2792
+ break
2793
+
2794
+ # Process extracted data
2795
+ if extracted_data:
2796
+ # Validate structure - try multiple possible field names
2797
+ causal_analysis = (
2798
+ extracted_data.get("causal_analysis", "") or
2799
+ extracted_data.get("analysis", "") or
2800
+ extracted_data.get("causal_analysis_text", "") or
2801
+ str(extracted_data.get("analysis_text", ""))
2802
+ )
2803
+ intervention_planning = extracted_data.get("intervention_planning", "") or extracted_data.get("interventions", "")
2804
+ counterfactual_scenarios = extracted_data.get("counterfactual_scenarios", []) or extracted_data.get("scenarios", [])
2805
+ causal_strength_assessment = extracted_data.get("causal_strength_assessment", "") or extracted_data.get("strength_assessment", "")
2806
+ optimal_solution = extracted_data.get("optimal_solution", "") or extracted_data.get("solution", "")
2807
+
2808
+ # Log what we extracted for debugging (use INFO so it's visible)
2809
+ logger.info(f"Extracted fields - causal_analysis: {bool(causal_analysis)}, scenarios: {len(counterfactual_scenarios)}")
2810
+ logger.info(f"Extracted data keys: {list(extracted_data.keys())}")
2811
+ if not causal_analysis:
2812
+ # Log the full structure to see what we got
2813
+ logger.info(f"Full extracted data (first 1000 chars): {str(extracted_data)[:1000]}")
2814
+
2815
+ if causal_analysis:
2816
+ # Automatically invoke the handler with extracted data
2817
+ logger.info(f"Extracted causal analysis via ML ({len(causal_analysis)} chars, {len(counterfactual_scenarios)} scenarios)")
2818
+ result = self._generate_causal_analysis_handler(
2819
+ causal_analysis=causal_analysis,
2820
+ intervention_planning=intervention_planning,
2821
+ counterfactual_scenarios=counterfactual_scenarios,
2822
+ causal_strength_assessment=causal_strength_assessment,
2823
+ optimal_solution=optimal_solution,
2824
+ )
2825
+
2826
+ # Check if analysis was successful
2827
+ if result.get("status") == "success":
2828
+ logger.info(f"ML-based causal analysis generation successful")
2829
+ # Return the analysis data for use in final result
2830
+ return {
2831
+ 'causal_analysis': causal_analysis,
2832
+ 'intervention_planning': intervention_planning,
2833
+ 'counterfactual_scenarios': counterfactual_scenarios,
2834
+ 'causal_strength_assessment': causal_strength_assessment,
2835
+ 'optimal_solution': optimal_solution,
2836
+ }
2837
+ else:
2838
+ logger.warning(f"ML-based analysis returned: {result}")
2839
+ else:
2840
+ logger.warning("Extracted data missing causal_analysis field")
2841
+ logger.info(f"Extracted data structure: {extracted_data}")
2842
+ logger.info(f"Available keys: {list(extracted_data.keys()) if isinstance(extracted_data, dict) else 'Not a dict'}")
2843
+ # Try to use the raw response text as causal_analysis if JSON parsing failed
2844
+ if not extracted_data or not isinstance(extracted_data, dict):
2845
+ logger.info("Attempting to use raw response as causal analysis")
2846
+ # Use the raw response as a fallback
2847
+ causal_analysis = response_text[:5000] # Limit length
2848
+ if causal_analysis:
2849
+ result = self._generate_causal_analysis_handler(
2850
+ causal_analysis=causal_analysis,
2851
+ intervention_planning="",
2852
+ counterfactual_scenarios=[],
2853
+ causal_strength_assessment="",
2854
+ optimal_solution="",
2855
+ )
2856
+ if result.get("status") == "success":
2857
+ return {
2858
+ 'causal_analysis': causal_analysis,
2859
+ 'intervention_planning': "",
2860
+ 'counterfactual_scenarios': [],
2861
+ 'causal_strength_assessment': "",
2862
+ 'optimal_solution': "",
2863
+ }
2864
+ else:
2865
+ logger.warning("Could not extract data from LLM response for causal analysis")
2866
+ logger.debug(f"Raw response type: {type(raw_response)}, value: {str(raw_response)[:500]}")
2867
+
2868
+ except Exception as e:
2869
+ logger.error(f"Error in ML-based causal analysis generation: {e}")
2870
+ import traceback
2871
+ logger.debug(traceback.format_exc())
2872
+
2873
+ return None
2874
+
2875
+ def _extract_variables_from_task(self, task: str) -> bool:
2876
+ """
2877
+ Extract variables and causal relationships from a task.
2878
+
2879
+ Uses ML-based extraction (structured LLM output + automatic handler invocation)
2880
+ instead of relying on unreliable function calling.
2881
+
2882
+ Args:
2883
+ task: The task string to analyze
2884
+
2885
+ Returns:
2886
+ True if variables were successfully extracted, False otherwise
2887
+ """
2888
+ # Use ML-based extraction (more reliable than function calling)
2889
+ return self._extract_variables_ml_based(task)
2890
+
2891
+ def _run_llm_causal_analysis(self, task: str, target_variables: Optional[List[str]] = None, **kwargs) -> Dict[str, Any]:
2892
+ """
2893
+ Run LLM-based causal analysis on a task.
2894
+
2895
+ Args:
2896
+ task: Task string to analyze
2897
+ target_variables: Optional list of target variables for counterfactual scenarios.
2898
+ If None, uses all variables in the causal graph.
2899
+ **kwargs: Additional arguments
2900
+
2901
+ Returns:
2902
+ Dictionary with causal analysis results
2903
+ """
2904
+ self.causal_memory = []
2905
+
2906
+ # Check if causal graph is empty - if so, trigger variable extraction
2907
+ if len(self.causal_graph) == 0:
2908
+ logger.info("Causal graph is empty, starting variable extraction phase...")
2909
+ extraction_success = self._extract_variables_from_task(task)
2910
+
2911
+ if not extraction_success:
2912
+ # Extraction failed, return error
2913
+ return {
2914
+ 'task': task,
2915
+ 'error': 'Variable extraction failed',
2916
+ 'message': (
2917
+ 'Could not extract variables from the task. '
2918
+ 'Please ensure the task describes causal relationships or variables, '
2919
+ 'or manually initialize the agent with variables and causal_edges.'
2920
+ ),
2921
+ 'causal_graph_info': {
2922
+ 'nodes': [],
2923
+ 'edges': [],
2924
+ 'is_dag': True
2925
+ }
2926
+ }
2927
+
2928
+ logger.info(
2929
+ f"Variable extraction completed. "
2930
+ f"Graph now has {len(self.causal_graph)} variables and "
2931
+ f"{sum(len(children) for children in self.causal_graph.values())} edges."
2932
+ )
2933
+
2934
+ # Use Agent's normal run method to get rich output and proper loop handling
2935
+ # This will show the LLM's rich output and respect max_loops
2936
+ causal_prompt = self._build_causal_prompt(task)
2937
+
2938
+ # Store the current memory size before running to detect new entries
2939
+ memory_size_before = len(self.causal_memory)
2940
+
2941
+ # Run the agent - this will execute tools and store results in causal_memory
2942
+ super().run(task=causal_prompt)
2943
+
2944
+ # Extract the causal analysis from causal_memory (stored by tool handler)
2945
+ # Look for the most recent analysis entry
2946
+ analysis_result = None
2947
+ final_analysis = ""
2948
+
2949
+ # Search backwards through causal_memory for the most recent analysis
2950
+ # Check both new entries (after memory_size_before) and all entries
2951
+ for entry in reversed(self.causal_memory):
2952
+ if isinstance(entry, dict):
2953
+ # Check if this is an analysis entry (has 'type' == 'analysis' or has 'causal_analysis' field)
2954
+ entry_type = entry.get('type')
2955
+ has_causal_analysis = 'causal_analysis' in entry
2956
+
2957
+ if entry_type == 'analysis' or has_causal_analysis:
2958
+ analysis_result = entry
2959
+ final_analysis = entry.get('causal_analysis', '')
2960
+ if final_analysis:
2961
+ logger.info(f"Found causal analysis in memory (type: {entry_type}, length: {len(final_analysis)})")
2962
+ break
2963
+
2964
+ # If no analysis found in memory, use ML-based generation as fallback
2965
+ if not final_analysis:
2966
+ logger.warning("No causal analysis found in causal_memory, using ML-based generation fallback")
2967
+ ml_analysis = self._generate_causal_analysis_ml_based(task)
2968
+ if ml_analysis and ml_analysis.get('causal_analysis'):
2969
+ # Use the returned analysis directly
2970
+ final_analysis = ml_analysis.get('causal_analysis', '')
2971
+ analysis_result = {
2972
+ 'type': 'analysis',
2973
+ 'causal_analysis': final_analysis,
2974
+ 'intervention_planning': ml_analysis.get('intervention_planning', ''),
2975
+ 'counterfactual_scenarios': ml_analysis.get('counterfactual_scenarios', []),
2976
+ 'causal_strength_assessment': ml_analysis.get('causal_strength_assessment', ''),
2977
+ 'optimal_solution': ml_analysis.get('optimal_solution', ''),
2978
+ }
2979
+ # Also ensure it's stored in memory for consistency
2980
+ if analysis_result not in self.causal_memory:
2981
+ self.causal_memory.append(analysis_result)
2982
+ if final_analysis:
2983
+ logger.info(f"Found causal analysis from ML-based fallback (length: {len(final_analysis)})")
2984
+
2985
+ # If still no analysis found, use the LLM's response as final fallback
2986
+ if not final_analysis:
2987
+ logger.warning("No causal analysis found in causal_memory, attempting fallback from conversation history")
2988
+ # Try to get the last response from the conversation
2989
+ if hasattr(self, 'short_memory') and self.short_memory:
2990
+ # Conversation object has conversation_history attribute, not get_messages()
2991
+ if hasattr(self.short_memory, 'conversation_history'):
2992
+ conversation_history = self.short_memory.conversation_history
2993
+ if conversation_history:
2994
+ # Get the last assistant message
2995
+ for msg in reversed(conversation_history):
2996
+ if isinstance(msg, dict) and msg.get('role') == 'assistant':
2997
+ final_analysis = msg.get('content', '')
2998
+ if final_analysis:
2999
+ logger.info(f"Extracted causal analysis from conversation history (length: {len(final_analysis)})")
3000
+ break
3001
+ elif hasattr(msg, 'role') and msg.role == 'assistant':
3002
+ final_analysis = getattr(msg, 'content', str(msg))
3003
+ if final_analysis:
3004
+ logger.info(f"Extracted causal analysis from conversation history (length: {len(final_analysis)})")
3005
+ break
3006
+
3007
+ if not final_analysis:
3008
+ logger.warning(f"Could not extract causal analysis. Memory size: {len(self.causal_memory)} (was {memory_size_before})")
3009
+
3010
+ default_state = {var: 0.0 for var in self.get_nodes()}
3011
+ self.ensure_standardization_stats(default_state)
3012
+
3013
+ # Use provided target_variables or default to all variables
3014
+ if target_variables is None:
3015
+ target_variables = self.get_nodes()
3016
+
3017
+ # Limit to top variables if too many
3018
+ target_vars = target_variables[:5] if len(target_variables) > 5 else target_variables
3019
+
3020
+ # Use counterfactual scenarios from ML analysis if available
3021
+ if analysis_result and isinstance(analysis_result, dict):
3022
+ ml_scenarios = analysis_result.get('counterfactual_scenarios', [])
3023
+ if ml_scenarios:
3024
+ counterfactual_scenarios = ml_scenarios
3025
+ else:
3026
+ # Fallback to generated scenarios
3027
+ counterfactual_scenarios = self.generate_counterfactual_scenarios(
3028
+ default_state,
3029
+ target_vars,
3030
+ max_scenarios=5
3031
+ )
3032
+ else:
3033
+ # Fallback to generated scenarios
3034
+ counterfactual_scenarios = self.generate_counterfactual_scenarios(
3035
+ default_state,
3036
+ target_vars,
3037
+ max_scenarios=5
3038
+ )
3039
+
3040
+ return {
3041
+ 'task': task,
3042
+ 'causal_analysis': final_analysis,
3043
+ 'counterfactual_scenarios': counterfactual_scenarios,
3044
+ 'causal_graph_info': {
3045
+ 'nodes': self.get_nodes(),
3046
+ 'edges': self.get_edges(),
3047
+ 'is_dag': self.is_dag()
3048
+ },
3049
+ 'analysis_steps': self.causal_memory,
3050
+ 'intervention_planning': analysis_result.get('intervention_planning', '') if analysis_result else '',
3051
+ 'causal_strength_assessment': analysis_result.get('causal_strength_assessment', '') if analysis_result else '',
3052
+ 'optimal_solution': analysis_result.get('optimal_solution', '') if analysis_result else '',
3053
+ }
3054
+
3055
+ # =========================
3056
+ # Compatibility extensions
3057
+ # =========================
3058
+
3059
+ # ---- Helpers ----
3060
+ @staticmethod
3061
+ def _require_pandas() -> None:
3062
+ if not PANDAS_AVAILABLE:
3063
+ raise ImportError("pandas is required for this operation. Install pandas to proceed.")
3064
+
3065
+ @staticmethod
3066
+ def _require_scipy() -> None:
3067
+ if not SCIPY_AVAILABLE:
3068
+ raise ImportError("scipy is required for this operation. Install scipy to proceed.")
3069
+
3070
+ @staticmethod
3071
+ def _require_cvxpy() -> None:
3072
+ if not CVXPY_AVAILABLE:
3073
+ raise ImportError("cvxpy is required for this operation. Install cvxpy to proceed.")
3074
+
3075
+ def _edge_strength(self, source: str, target: str) -> float:
3076
+ edge = self.causal_graph.get(source, {}).get(target, None)
3077
+ if isinstance(edge, dict):
3078
+ return float(edge.get("strength", 0.0))
3079
+ try:
3080
+ return float(edge) if edge is not None else 0.0
3081
+ except Exception:
3082
+ return 0.0
3083
+
3084
+ class _TimeDebug:
3085
+ def __init__(self, name: str) -> None:
3086
+ self.name = name
3087
+ self.start = 0.0
3088
+ def __enter__(self):
3089
+ try:
3090
+ import time
3091
+ self.start = time.perf_counter()
3092
+ except Exception:
3093
+ self.start = 0.0
3094
+ return self
3095
+ def __exit__(self, exc_type, exc, tb):
3096
+ if logger.isEnabledFor(logging.DEBUG) and self.start:
3097
+ try:
3098
+ import time
3099
+ duration = time.perf_counter() - self.start
3100
+ logger.debug(f"{self.name} completed in {duration:.4f}s")
3101
+ except Exception:
3102
+ pass
3103
+
3104
+ def _ensure_edge(self, source: str, target: str) -> None:
3105
+ """Ensure edge exists in both dict graph and rustworkx graph."""
3106
+ self._ensure_node_exists(source)
3107
+ self._ensure_node_exists(target)
3108
+ if target not in self.causal_graph.get(source, {}):
3109
+ self.causal_graph.setdefault(source, {})[target] = {"strength": 0.0, "confidence": 1.0}
3110
+ try:
3111
+ u_idx = self._ensure_node_index(source)
3112
+ v_idx = self._ensure_node_index(target)
3113
+ if self._graph.get_edge_data(u_idx, v_idx) is None:
3114
+ self._graph.add_edge(u_idx, v_idx, self.causal_graph[source][target])
3115
+ except Exception:
3116
+ pass
3117
+
3118
+ # Convenience graph-like methods for compatibility
3119
+ def add_nodes_from(self, nodes: List[str]) -> None:
3120
+ for n in nodes:
3121
+ self._ensure_node_exists(n)
3122
+
3123
+ def add_edges_from(self, edges: List[Tuple[str, str]]) -> None:
3124
+ for u, v in edges:
3125
+ self.add_causal_relationship(u, v)
3126
+
3127
+ def edges(self) -> List[Tuple[str, str]]:
3128
+ return self.get_edges()
3129
+
3130
+ # ---- Batched predictions ----
3131
+ def _predict_outcomes_batch(
3132
+ self,
3133
+ factual_states: List[Dict[str, float]],
3134
+ interventions: Optional[Union[Dict[str, float], List[Dict[str, float]]]] = None,
3135
+ ) -> List[Dict[str, float]]:
3136
+ """
3137
+ Batched deterministic SCM forward pass. Uses shared topology and vectorized parent aggregation.
3138
+ """
3139
+ if not factual_states:
3140
+ return []
3141
+ if len(factual_states) == 1 or not self.enable_batch_predict:
3142
+ return [self._predict_outcomes(factual_states[0], interventions if isinstance(interventions, dict) else (interventions or {}))]
3143
+
3144
+ batch = len(factual_states)
3145
+ if interventions is None:
3146
+ interventions_list = [{} for _ in range(batch)]
3147
+ elif isinstance(interventions, list):
3148
+ interventions_list = interventions
3149
+ else:
3150
+ interventions_list = [interventions for _ in range(batch)]
3151
+
3152
+ topo = self._topological_sort()
3153
+ parents_map = {node: self._get_parents(node) for node in topo}
3154
+ stats = self.standardization_stats
3155
+ z_pred: Dict[str, np.ndarray] = {}
3156
+
3157
+ # Initialize z with raw + interventions standardized
3158
+ for node in topo:
3159
+ arr = np.empty(batch, dtype=float)
3160
+ mean = stats.get(node, {}).get("mean", 0.0)
3161
+ std = stats.get(node, {}).get("std", 1.0) or 1.0
3162
+ for i in range(batch):
3163
+ raw_val = interventions_list[i].get(node, factual_states[i].get(node, 0.0))
3164
+ arr[i] = (raw_val - mean) / std
3165
+ z_pred[node] = arr
3166
+
3167
+ # Propagate for non-intervened nodes
3168
+ for node in topo:
3169
+ parents = parents_map.get(node, [])
3170
+ if not parents:
3171
+ continue
3172
+ arr = z_pred[node]
3173
+ # Only recompute if node not directly intervened
3174
+ intervene_mask = np.array([node in interventions_list[i] for i in range(batch)], dtype=bool)
3175
+ if np.all(intervene_mask):
3176
+ continue
3177
+ if not parents:
3178
+ continue
3179
+ parent_matrix = np.vstack([z_pred[p] for p in parents]) # shape (k, batch)
3180
+ strengths = np.array([self._edge_strength(p, node) for p in parents], dtype=float).reshape(-1, 1)
3181
+ combined = (strengths * parent_matrix).sum(axis=0)
3182
+ if intervene_mask.any():
3183
+ # preserve intervened samples
3184
+ arr = np.where(intervene_mask, arr, combined)
3185
+ else:
3186
+ arr = combined
3187
+ z_pred[node] = arr
3188
+
3189
+ # De-standardize
3190
+ outputs: List[Dict[str, float]] = []
3191
+ for i in range(batch):
3192
+ out: Dict[str, float] = {}
3193
+ for node, z_arr in z_pred.items():
3194
+ s = stats.get(node, {"mean": 0.0, "std": 1.0})
3195
+ out[node] = float(z_arr[i] * s.get("std", 1.0) + s.get("mean", 0.0))
3196
+ outputs.append(out)
3197
+ return outputs
3198
+
3199
+ # Convenience graph-like methods for compatibility
3200
+ def add_nodes_from(self, nodes: List[str]) -> None:
3201
+ for n in nodes:
3202
+ self._ensure_node_exists(n)
3203
+
3204
+ def add_edges_from(self, edges: List[Tuple[str, str]]) -> None:
3205
+ for u, v in edges:
3206
+ self.add_causal_relationship(u, v)
3207
+
3208
+ def edges(self) -> List[Tuple[str, str]]:
3209
+ return self.get_edges()
3210
+
3211
+ # ---- Data-driven fitting ----
3212
+ def fit_from_dataframe(
3213
+ self,
3214
+ df: Any,
3215
+ variables: List[str],
3216
+ window: int = 30,
3217
+ decay_alpha: float = 0.9,
3218
+ ridge_lambda: float = 0.0,
3219
+ enforce_signs: bool = True
3220
+ ) -> None:
3221
+ """
3222
+ Fit edge strengths and standardization stats from a rolling window with recency weighting.
3223
+ """
3224
+ with self._TimeDebug("fit_from_dataframe"):
3225
+ self._require_pandas()
3226
+ if df is None:
3227
+ return
3228
+ if not isinstance(df, pd.DataFrame):
3229
+ raise TypeError(f"df must be a pandas DataFrame, got {type(df)}")
3230
+ if not variables:
3231
+ return
3232
+ missing = [v for v in variables if v not in df.columns]
3233
+ if missing:
3234
+ raise ValueError(f"Variables not in DataFrame: {missing}")
3235
+ window = max(1, int(window))
3236
+ if not (0 < decay_alpha <= 1):
3237
+ raise ValueError("decay_alpha must be in (0,1]")
3238
+
3239
+ df_local = df[variables].dropna().copy()
3240
+ if df_local.empty:
3241
+ return
3242
+ window_df = df_local.tail(window)
3243
+ n = len(window_df)
3244
+ weights = np.array([decay_alpha ** (n - 1 - i) for i in range(n)], dtype=float)
3245
+ weights = weights / (weights.sum() if weights.sum() != 0 else 1.0)
3246
+
3247
+ # Standardization stats
3248
+ self.standardization_stats = {}
3249
+ for v in variables:
3250
+ m = float(window_df[v].mean())
3251
+ s = float(window_df[v].std(ddof=0))
3252
+ if s == 0:
3253
+ s = 1.0
3254
+ self.standardization_stats[v] = {"mean": m, "std": s}
3255
+ for node in self.causal_graph.keys():
3256
+ if node not in self.standardization_stats:
3257
+ self.standardization_stats[node] = {"mean": 0.0, "std": 1.0}
3258
+
3259
+ # Estimate edge strengths
3260
+ for child in list(self.causal_graph.keys()):
3261
+ parents = self._get_parents(child)
3262
+ if not parents:
3263
+ continue
3264
+ if child not in window_df.columns:
3265
+ continue
3266
+ parent_vals = []
3267
+ for p in parents:
3268
+ if p in window_df.columns:
3269
+ stats = self.standardization_stats.get(p, {"mean": 0.0, "std": 1.0})
3270
+ parent_vals.append(((window_df[p] - stats["mean"]) / stats["std"]).values)
3271
+ if not parent_vals:
3272
+ continue
3273
+ X = np.vstack(parent_vals).T
3274
+ y_stats = self.standardization_stats.get(child, {"mean": 0.0, "std": 1.0})
3275
+ y = ((window_df[child] - y_stats["mean"]) / y_stats["std"]).values
3276
+ W = np.diag(weights)
3277
+ XtW = X.T @ W
3278
+ XtWX = XtW @ X
3279
+ if ridge_lambda > 0 and XtWX.size > 0:
3280
+ k = XtWX.shape[0]
3281
+ XtWX = XtWX + ridge_lambda * np.eye(k)
3282
+ try:
3283
+ XtWX_inv = np.linalg.pinv(XtWX)
3284
+ beta = XtWX_inv @ (XtW @ y)
3285
+ except Exception:
3286
+ beta = np.zeros(X.shape[1])
3287
+ beta = np.asarray(beta)
3288
+ for idx, p in enumerate(parents):
3289
+ strength = float(beta[idx]) if idx < len(beta) else 0.0
3290
+ if enforce_signs:
3291
+ sign = self.edge_sign_constraints.get((p, child))
3292
+ if sign == 1 and strength < 0:
3293
+ strength = 0.0
3294
+ elif sign == -1 and strength > 0:
3295
+ strength = 0.0
3296
+ self._ensure_edge(p, child)
3297
+ self.causal_graph[p][child]["strength"] = strength
3298
+ self.causal_graph[p][child]["confidence"] = 1.0
3299
+
3300
+ # ---- Uncertainty ----
3301
+ def quantify_uncertainty(
3302
+ self,
3303
+ df: Any,
3304
+ variables: List[str],
3305
+ windows: int = 200,
3306
+ alpha: float = 0.95
3307
+ ) -> Dict[str, Any]:
3308
+ with self._TimeDebug("quantify_uncertainty"):
3309
+ self._require_pandas()
3310
+ if df is None or not isinstance(df, pd.DataFrame):
3311
+ return {"edge_cis": {}, "samples": 0}
3312
+ usable = df[variables].dropna()
3313
+ if len(usable) < 10:
3314
+ return {"edge_cis": {}, "samples": 0}
3315
+ windows = max(1, int(windows))
3316
+ samples: Dict[Tuple[str, str], List[float]] = {}
3317
+
3318
+ # Snapshot current strengths to restore later
3319
+ baseline_strengths: Dict[Tuple[str, str], float] = {}
3320
+ for u, targets in self.causal_graph.items():
3321
+ for v, meta in targets.items():
3322
+ try:
3323
+ baseline_strengths[(u, v)] = float(meta.get("strength", 0.0)) if isinstance(meta, dict) else float(meta)
3324
+ except Exception:
3325
+ baseline_strengths[(u, v)] = 0.0
3326
+ baseline_stats = dict(self.standardization_stats)
3327
+
3328
+ def _snapshot_strengths() -> Dict[Tuple[str, str], float]:
3329
+ snap: Dict[Tuple[str, str], float] = {}
3330
+ for u, targets in self.causal_graph.items():
3331
+ for v, meta in targets.items():
3332
+ try:
3333
+ snap[(u, v)] = float(meta.get("strength", 0.0)) if isinstance(meta, dict) else float(meta)
3334
+ except Exception:
3335
+ snap[(u, v)] = 0.0
3336
+ return snap
3337
+
3338
+ def _bootstrap_single(df_sample: "pd.DataFrame") -> Dict[Tuple[str, str], float]:
3339
+ # Use a shallow clone to avoid mutating main agent when running in parallel
3340
+ clone = CRCAAgent(
3341
+ variables=list(self.causal_graph.keys()),
3342
+ causal_edges=self.get_edges(),
3343
+ model_name=self.model_name,
3344
+ max_loops=self.causal_max_loops,
3345
+ enable_batch_predict=self.enable_batch_predict,
3346
+ max_batch_size=self.max_batch_size,
3347
+ bootstrap_workers=0,
3348
+ use_async=self.use_async,
3349
+ seed=self.seed,
3350
+ )
3351
+ clone.edge_sign_constraints = dict(self.edge_sign_constraints)
3352
+ clone.standardization_stats = dict(baseline_stats)
3353
+ try:
3354
+ clone.fit_from_dataframe(
3355
+ df=df_sample,
3356
+ variables=variables,
3357
+ window=min(30, len(df_sample)),
3358
+ decay_alpha=0.9,
3359
+ ridge_lambda=0.0,
3360
+ enforce_signs=True,
3361
+ )
3362
+ return _snapshot_strengths_from_graph(clone.causal_graph)
3363
+ except Exception:
3364
+ return {}
3365
+
3366
+ def _snapshot_strengths_from_graph(graph: Dict[str, Dict[str, Any]]) -> Dict[Tuple[str, str], float]:
3367
+ res: Dict[Tuple[str, str], float] = {}
3368
+ for u, targets in graph.items():
3369
+ for v, meta in targets.items():
3370
+ try:
3371
+ res[(u, v)] = float(meta.get("strength", 0.0)) if isinstance(meta, dict) else float(meta)
3372
+ except Exception:
3373
+ res[(u, v)] = 0.0
3374
+ return res
3375
+
3376
+ use_parallel = self.bootstrap_workers > 0
3377
+ if use_parallel:
3378
+ import concurrent.futures
3379
+ with concurrent.futures.ThreadPoolExecutor(max_workers=self.bootstrap_workers) as executor:
3380
+ futures = []
3381
+ for i in range(windows):
3382
+ boot_df = usable.sample(n=len(usable), replace=True, random_state=self.seed + i)
3383
+ futures.append(executor.submit(_bootstrap_single, boot_df))
3384
+ for fut in futures:
3385
+ try:
3386
+ res_strengths = fut.result()
3387
+ for (u, v), w in res_strengths.items():
3388
+ samples.setdefault((u, v), []).append(w)
3389
+ except Exception:
3390
+ continue
3391
+ else:
3392
+ for i in range(windows):
3393
+ boot_df = usable.sample(n=len(usable), replace=True, random_state=self.seed + i)
3394
+ try:
3395
+ self.fit_from_dataframe(
3396
+ df=boot_df,
3397
+ variables=variables,
3398
+ window=min(30, len(boot_df)),
3399
+ decay_alpha=0.9,
3400
+ ridge_lambda=0.0,
3401
+ enforce_signs=True,
3402
+ )
3403
+ for (u, v), w in _snapshot_strengths().items():
3404
+ samples.setdefault((u, v), []).append(w)
3405
+ except Exception:
3406
+ continue
3407
+
3408
+ # Restore baseline strengths and stats
3409
+ for (u, v), w in baseline_strengths.items():
3410
+ if u in self.causal_graph and v in self.causal_graph[u]:
3411
+ self.causal_graph[u][v]["strength"] = w
3412
+ self.standardization_stats = baseline_stats
3413
+
3414
+ edge_cis: Dict[str, Tuple[float, float]] = {}
3415
+ for (u, v), arr in samples.items():
3416
+ arr_np = np.array(arr)
3417
+ lo = float(np.quantile(arr_np, (1 - alpha) / 2))
3418
+ hi = float(np.quantile(arr_np, 1 - (1 - alpha) / 2))
3419
+ edge_cis[f"{u}->{v}"] = (lo, hi)
3420
+ return {"edge_cis": edge_cis, "samples": windows}
3421
+
3422
+ # ---- Optimization ----
3423
+ def gradient_based_intervention_optimization(
3424
+ self,
3425
+ initial_state: Dict[str, float],
3426
+ target: str,
3427
+ intervention_vars: List[str],
3428
+ constraints: Optional[Dict[str, Tuple[float, float]]] = None,
3429
+ method: str = "L-BFGS-B",
3430
+ ) -> Dict[str, Any]:
3431
+ self._require_scipy()
3432
+ from scipy.optimize import minimize # type: ignore
3433
+
3434
+ if not intervention_vars:
3435
+ return {"error": "intervention_vars cannot be empty", "optimal_intervention": {}, "success": False}
3436
+
3437
+ bounds = []
3438
+ x0 = []
3439
+ for var in intervention_vars:
3440
+ cur = float(initial_state.get(var, 0.0))
3441
+ x0.append(cur)
3442
+ if constraints and var in constraints:
3443
+ bounds.append(constraints[var])
3444
+ else:
3445
+ bounds.append((cur - 3.0, cur + 3.0))
3446
+
3447
+ def objective(x: np.ndarray) -> float:
3448
+ intervention = {intervention_vars[i]: float(x[i]) for i in range(len(x))}
3449
+ outcome = self._predict_outcomes(initial_state, intervention)
3450
+ return -float(outcome.get(target, 0.0))
3451
+
3452
+ try:
3453
+ result = minimize(
3454
+ objective,
3455
+ x0=np.array(x0, dtype=float),
3456
+ method=method,
3457
+ bounds=bounds,
3458
+ options={"maxiter": 100, "ftol": 1e-6},
3459
+ )
3460
+ optimal_intervention = {intervention_vars[i]: float(result.x[i]) for i in range(len(result.x))}
3461
+ optimal_outcome = self._predict_outcomes(initial_state, optimal_intervention)
3462
+ return {
3463
+ "optimal_intervention": optimal_intervention,
3464
+ "optimal_target_value": float(optimal_outcome.get(target, 0.0)),
3465
+ "objective_value": float(result.fun),
3466
+ "success": bool(result.success),
3467
+ "iterations": int(getattr(result, "nit", 0)),
3468
+ "convergence_message": str(result.message),
3469
+ }
3470
+ except Exception as e:
3471
+ logger.debug(f"gradient_based_intervention_optimization failed: {e}")
3472
+ return {"error": str(e), "optimal_intervention": {}, "success": False}
3473
+
3474
+ def bellman_optimal_intervention(
3475
+ self,
3476
+ initial_state: Dict[str, float],
3477
+ target: str,
3478
+ intervention_vars: List[str],
3479
+ horizon: int = 5,
3480
+ discount: float = 0.9,
3481
+ ) -> Dict[str, Any]:
3482
+ if not intervention_vars:
3483
+ return {"error": "intervention_vars cannot be empty"}
3484
+ horizon = max(1, int(horizon))
3485
+ rng = np.random.default_rng(self.seed)
3486
+ current_state = dict(initial_state)
3487
+ sequence: List[Dict[str, float]] = []
3488
+
3489
+ def reward(state: Dict[str, float]) -> float:
3490
+ return float(state.get(target, 0.0))
3491
+
3492
+ for _ in range(horizon):
3493
+ best_value = float("-inf")
3494
+ best_intervention: Dict[str, float] = {}
3495
+ for _ in range(10):
3496
+ candidate = {}
3497
+ for var in intervention_vars:
3498
+ stats = self.standardization_stats.get(var, {"mean": current_state.get(var, 0.0), "std": 1.0})
3499
+ candidate[var] = float(rng.normal(stats["mean"], stats["std"]))
3500
+ next_state = self._predict_outcomes(current_state, candidate)
3501
+ val = reward(next_state)
3502
+ if val > best_value:
3503
+ best_value = val
3504
+ best_intervention = candidate
3505
+ if best_intervention:
3506
+ sequence.append(best_intervention)
3507
+ current_state = self._predict_outcomes(current_state, best_intervention)
3508
+
3509
+ return {
3510
+ "optimal_sequence": sequence,
3511
+ "final_state": current_state,
3512
+ "total_value": float(current_state.get(target, 0.0)),
3513
+ "horizon": horizon,
3514
+ "discount_factor": float(discount),
3515
+ }
3516
+
3517
+ # ---- Time-series & causality ----
3518
+ def granger_causality_test(
3519
+ self,
3520
+ df: Any,
3521
+ var1: str,
3522
+ var2: str,
3523
+ max_lag: int = 4,
3524
+ ) -> Dict[str, Any]:
3525
+ self._require_pandas()
3526
+ if df is None or not isinstance(df, pd.DataFrame):
3527
+ return {"error": "Invalid data or variables"}
3528
+ data = df[[var1, var2]].dropna()
3529
+ if len(data) < max_lag * 2 + 5:
3530
+ return {"error": "Insufficient data"}
3531
+ try:
3532
+ from scipy.stats import f as f_dist # type: ignore
3533
+ except Exception:
3534
+ return {"error": "scipy f distribution not available"}
3535
+
3536
+ n = len(data)
3537
+ y = data[var2].values
3538
+ Xr = []
3539
+ Xu = []
3540
+ for t in range(max_lag, n):
3541
+ y_t = y[t]
3542
+ lags_var2 = [data[var2].iloc[t - i] for i in range(1, max_lag + 1)]
3543
+ lags_var1 = [data[var1].iloc[t - i] for i in range(1, max_lag + 1)]
3544
+ Xr.append(lags_var2)
3545
+ Xu.append(lags_var2 + lags_var1)
3546
+ y_vec = np.array(y[max_lag:], dtype=float)
3547
+ Xr = np.array(Xr, dtype=float)
3548
+ Xu = np.array(Xu, dtype=float)
3549
+
3550
+ def ols(X: np.ndarray, yv: np.ndarray) -> Tuple[np.ndarray, float]:
3551
+ beta = np.linalg.pinv(X) @ yv
3552
+ y_pred = X @ beta
3553
+ rss = float(np.sum((yv - y_pred) ** 2))
3554
+ return beta, rss
3555
+
3556
+ try:
3557
+ _, rss_r = ols(Xr, y_vec)
3558
+ _, rss_u = ols(Xu, y_vec)
3559
+ m = max_lag
3560
+ df2 = len(y_vec) - 2 * m - 1
3561
+ if df2 <= 0 or rss_u <= 1e-12:
3562
+ return {"error": "Degenerate case in F-test"}
3563
+ f_stat = ((rss_r - rss_u) / m) / (rss_u / df2)
3564
+ p_value = float(1.0 - f_dist.cdf(f_stat, m, df2))
3565
+ return {
3566
+ "f_statistic": float(f_stat),
3567
+ "p_value": p_value,
3568
+ "granger_causes": p_value < 0.05,
3569
+ "max_lag": max_lag,
3570
+ "restricted_rss": rss_r,
3571
+ "unrestricted_rss": rss_u,
3572
+ }
3573
+ except Exception as e:
3574
+ return {"error": str(e)}
3575
+
3576
+ def vector_autoregression_estimation(
3577
+ self,
3578
+ df: Any,
3579
+ variables: List[str],
3580
+ max_lag: int = 2,
3581
+ ) -> Dict[str, Any]:
3582
+ self._require_pandas()
3583
+ if df is None or not isinstance(df, pd.DataFrame):
3584
+ return {"error": "Invalid data"}
3585
+ data = df[variables].dropna()
3586
+ if len(data) < max_lag * len(variables) + 5:
3587
+ return {"error": "Insufficient data"}
3588
+ n_vars = len(variables)
3589
+ X_lag = []
3590
+ y_mat = []
3591
+ for t in range(max_lag, len(data)):
3592
+ y_row = [data[var].iloc[t] for var in variables]
3593
+ y_mat.append(y_row)
3594
+ lag_row = []
3595
+ for lag in range(1, max_lag + 1):
3596
+ for var in variables:
3597
+ lag_row.append(data[var].iloc[t - lag])
3598
+ X_lag.append(lag_row)
3599
+ X = np.array(X_lag, dtype=float)
3600
+ Y = np.array(y_mat, dtype=float)
3601
+ coefficients: Dict[str, Any] = {}
3602
+ residuals = []
3603
+ for idx, var in enumerate(variables):
3604
+ y_vec = Y[:, idx]
3605
+ beta = np.linalg.pinv(X) @ y_vec
3606
+ y_pred = X @ beta
3607
+ res = y_vec - y_pred
3608
+ residuals.append(res)
3609
+ coefficients[var] = {"coefficients": beta.tolist()}
3610
+ residuals = np.array(residuals).T
3611
+ return {
3612
+ "coefficient_matrices": coefficients,
3613
+ "residuals": residuals.tolist(),
3614
+ "n_observations": len(Y),
3615
+ "n_variables": n_vars,
3616
+ "max_lag": max_lag,
3617
+ "variables": variables,
3618
+ }
3619
+
3620
    def compute_information_theoretic_measures(
        self,
        df: Any,
        variables: List[str],
    ) -> Dict[str, Any]:
        """
        Compute simple entropy and mutual information estimates using histograms.

        Args:
            df: pandas DataFrame holding the variables.
            variables: Column names to analyse; unknown or too-sparse columns
                are silently skipped.

        Returns:
            Dict with per-variable Shannon entropies (bits) under "entropies"
            and pairwise MI estimates under "mutual_information" (keys of the
            form "var1;var2"), or an error payload for unusable input.
        """
        self._require_pandas()
        if df is None or not isinstance(df, pd.DataFrame):
            return {"error": "Invalid data"}
        data = df[variables].dropna()
        if len(data) < 10:
            return {"error": "Insufficient data"}

        results: Dict[str, Any] = {"entropies": {}, "mutual_information": {}}
        for var in variables:
            if var not in data.columns:
                continue
            series = data[var].dropna()
            if len(series) < 5:
                continue
            # Square-root rule for the bin count, clamped to [5, 20].
            n_bins = min(20, max(5, int(np.sqrt(len(series)))))
            hist, _ = np.histogram(series, bins=n_bins)
            hist = hist[hist > 0]  # drop empty bins so log2 is defined
            probs = hist / hist.sum()
            entropy = -np.sum(probs * np.log2(probs))
            results["entropies"][var] = float(entropy)

        # Pairwise mutual information via MI(X;Y) = H(X) + H(Y) - H(X,Y).
        # NOTE(review): the marginal and joint entropies use different binning
        # rules (sqrt vs cbrt), so the estimate is approximate; negative
        # estimates are clamped to zero below.
        for i, var1 in enumerate(variables):
            if var1 not in results["entropies"]:
                continue
            for var2 in variables[i + 1:]:
                if var2 not in results["entropies"]:
                    continue
                joint = data[[var1, var2]].dropna()
                if len(joint) < 5:
                    continue
                # Cube-root rule for the 2-D histogram, clamped to [3, 10].
                n_bins = min(10, max(3, int(np.cbrt(len(joint)))))
                hist2d, _, _ = np.histogram2d(joint[var1], joint[var2], bins=n_bins)
                hist2d = hist2d[hist2d > 0]
                probs_joint = hist2d / hist2d.sum()
                h_joint = -np.sum(probs_joint * np.log2(probs_joint))
                mi = results["entropies"][var1] + results["entropies"][var2] - float(h_joint)
                results["mutual_information"][f"{var1};{var2}"] = float(max(0.0, mi))

        return results
3668
+
3669
+ # ---- Bayesian & attribution ----
3670
+ def bayesian_edge_inference(
3671
+ self,
3672
+ df: Any,
3673
+ parent: str,
3674
+ child: str,
3675
+ prior_mu: float = 0.0,
3676
+ prior_sigma: float = 1.0,
3677
+ ) -> Dict[str, Any]:
3678
+ self._require_pandas()
3679
+ if df is None or not isinstance(df, pd.DataFrame):
3680
+ return {"error": "Invalid data"}
3681
+ if parent not in df.columns or child not in df.columns:
3682
+ return {"error": "Variables not found"}
3683
+ data = df[[parent, child]].dropna()
3684
+ if len(data) < 5:
3685
+ return {"error": "Insufficient data"}
3686
+ X = data[parent].values.reshape(-1, 1)
3687
+ y = data[child].values
3688
+ X_mean, X_std = X.mean(), X.std() or 1.0
3689
+ y_mean, y_std = y.mean(), y.std() or 1.0
3690
+ X_norm = (X - X_mean) / X_std
3691
+ y_norm = (y - y_mean) / y_std
3692
+ XtX = X_norm.T @ X_norm
3693
+ Xty = X_norm.T @ y_norm
3694
+ beta_ols = float((np.linalg.pinv(XtX) @ Xty)[0])
3695
+ residuals = y_norm - X_norm @ np.array([beta_ols])
3696
+ sigma_sq = float(np.var(residuals))
3697
+ tau_likelihood = 1.0 / (sigma_sq + 1e-6)
3698
+ tau_prior = 1.0 / (prior_sigma ** 2)
3699
+ tau_post = tau_prior + tau_likelihood * len(data)
3700
+ mu_post = (tau_prior * prior_mu + tau_likelihood * len(data) * beta_ols) / tau_post
3701
+ sigma_post = math.sqrt(1.0 / tau_post)
3702
+ ci_lower = mu_post - 1.96 * sigma_post
3703
+ ci_upper = mu_post + 1.96 * sigma_post
3704
+ self.bayesian_priors[(parent, child)] = {"mu": prior_mu, "sigma": prior_sigma}
3705
+ return {
3706
+ "posterior_mean": float(mu_post),
3707
+ "posterior_std": float(sigma_post),
3708
+ "posterior_variance": float(sigma_post ** 2),
3709
+ "credible_interval_95": (float(ci_lower), float(ci_upper)),
3710
+ "ols_estimate": float(beta_ols),
3711
+ "prior_mu": float(prior_mu),
3712
+ "prior_sigma": float(prior_sigma),
3713
+ }
3714
+
3715
+ def sensitivity_analysis(
3716
+ self,
3717
+ intervention: Dict[str, float],
3718
+ target: str,
3719
+ perturbation_size: float = 0.01,
3720
+ ) -> Dict[str, Any]:
3721
+ base_outcome = self._predict_outcomes({}, intervention)
3722
+ base_target = base_outcome.get(target, 0.0)
3723
+ sensitivities: Dict[str, float] = {}
3724
+ elasticities: Dict[str, float] = {}
3725
+ for var, val in intervention.items():
3726
+ perturbed = dict(intervention)
3727
+ perturbed[var] = val + perturbation_size
3728
+ perturbed_outcome = self._predict_outcomes({}, perturbed)
3729
+ pert_target = perturbed_outcome.get(target, 0.0)
3730
+ sensitivity = (pert_target - base_target) / perturbation_size
3731
+ sensitivities[var] = float(sensitivity)
3732
+ if abs(base_target) > 1e-6 and abs(val) > 1e-6:
3733
+ elasticities[var] = float(sensitivity * (val / base_target))
3734
+ else:
3735
+ elasticities[var] = 0.0
3736
+ most_inf = max(sensitivities.items(), key=lambda x: abs(x[1])) if sensitivities else (None, 0.0)
3737
+ total_sens = float(np.linalg.norm(list(sensitivities.values()))) if sensitivities else 0.0
3738
+ return {
3739
+ "sensitivities": sensitivities,
3740
+ "elasticities": elasticities,
3741
+ "total_sensitivity": total_sens,
3742
+ "most_influential_variable": most_inf[0],
3743
+ "most_influential_sensitivity": float(most_inf[1]),
3744
+ }
3745
+
3746
    def deep_root_cause_analysis(
        self,
        problem_variable: str,
        max_depth: int = 20,
        min_path_strength: float = 0.01,
    ) -> Dict[str, Any]:
        """Trace causal paths from ancestors down to ``problem_variable``.

        Runs a breadth-first search from every known ancestor along child
        edges, recording each discovered path whose cumulative edge-strength
        product clears ``min_path_strength``.

        Args:
            problem_variable: Variable whose causes are being traced.
            max_depth: Maximum path length (in edges) to explore.
            min_path_strength: Paths whose running strength product drops
                below this threshold are discarded.

        Returns:
            Dict with ranked root causes (exogenous first, then by absolute
            path strength, then shallower paths), the exogenous subset, the
            number of significant paths, and the deepest path length; or an
            error payload if the variable is unknown.
        """
        if problem_variable not in self.causal_graph:
            return {"error": f"Variable {problem_variable} not in causal graph"}
        all_ancestors = list(self.causal_graph_reverse.get(problem_variable, []))
        root_causes: List[Dict[str, Any]] = []
        paths_to_problem: List[Dict[str, Any]] = []

        def path_strength(path: List[str]) -> float:
            # Product of edge strengths along the path; zeroed as soon as the
            # running product falls below the significance threshold.
            prod = 1.0
            for i in range(len(path) - 1):
                u, v = path[i], path[i + 1]
                prod *= self._edge_strength(u, v)
                if abs(prod) < min_path_strength:
                    return 0.0
            return prod

        for anc in all_ancestors:
            try:
                # BFS from this ancestor toward the problem variable.
                queue = [(anc, [anc])]
                visited = set()
                while queue:
                    node, path = queue.pop(0)
                    if len(path) - 1 > max_depth:
                        continue
                    if node == problem_variable and len(path) > 1:
                        ps = path_strength(path)
                        if abs(ps) > 0:
                            root_causes.append({
                                "root_cause": path[0],
                                "path_to_problem": path,
                                "path_string": " -> ".join(path),
                                "path_strength": float(ps),
                                "depth": len(path) - 1,
                                # Exogenous = the origin has no parents.
                                "is_exogenous": len(self._get_parents(path[0])) == 0,
                            })
                            paths_to_problem.append({
                                "from": path[0],
                                "to": problem_variable,
                                "path": path,
                                "strength": float(ps),
                            })
                        continue
                    for child in self._get_children(node):
                        # NOTE(review): ``visited`` is shared across the whole
                        # BFS from this ancestor, so at most one path through
                        # each intermediate node is explored — confirm that is
                        # the intended enumeration semantics.
                        if child not in visited:
                            visited.add(child)
                            queue.append((child, path + [child]))
            except Exception:
                continue

        # Rank: exogenous roots first, then stronger paths, then shallower.
        root_causes.sort(key=lambda x: (-x["is_exogenous"], -abs(x["path_strength"]), x["depth"]))
        ultimate_roots = [rc for rc in root_causes if rc.get("is_exogenous")]
        return {
            "problem_variable": problem_variable,
            "all_root_causes": root_causes[:20],
            "ultimate_root_causes": ultimate_roots[:10],
            "total_paths_found": len(paths_to_problem),
            "max_depth_reached": max([rc["depth"] for rc in root_causes] + [0]),
        }
3809
+
3810
+ def shapley_value_attribution(
3811
+ self,
3812
+ baseline_state: Dict[str, float],
3813
+ target_state: Dict[str, float],
3814
+ target: str,
3815
+ samples: int = 100,
3816
+ ) -> Dict[str, Any]:
3817
+ variables = list(set(list(baseline_state.keys()) + list(target_state.keys())))
3818
+ n = len(variables)
3819
+ if n == 0:
3820
+ return {"shapley_values": {}, "normalized": {}, "total_attribution": 0.0}
3821
+ rng = np.random.default_rng(self.seed)
3822
+ contributions: Dict[str, float] = {v: 0.0 for v in variables}
3823
+
3824
+ def value(subset: List[str]) -> float:
3825
+ state = dict(baseline_state)
3826
+ for var in subset:
3827
+ if var in target_state:
3828
+ state[var] = target_state[var]
3829
+ outcome = self._predict_outcomes({}, state)
3830
+ return float(outcome.get(target, 0.0))
3831
+
3832
+ for _ in range(max(1, samples)):
3833
+ perm = list(variables)
3834
+ rng.shuffle(perm)
3835
+ cur_set: List[str] = []
3836
+ prev_val = value(cur_set)
3837
+ for v in perm:
3838
+ cur_set.append(v)
3839
+ new_val = value(cur_set)
3840
+ contributions[v] += new_val - prev_val
3841
+ prev_val = new_val
3842
+
3843
+ shapley_values = {k: v / float(samples) for k, v in contributions.items()}
3844
+ total = sum(abs(v) for v in shapley_values.values()) or 1.0
3845
+ normalized = {k: v / total for k, v in shapley_values.items()}
3846
+ return {
3847
+ "shapley_values": shapley_values,
3848
+ "normalized": normalized,
3849
+ "total_attribution": float(sum(abs(v) for v in shapley_values.values())),
3850
+ }
3851
+
3852
+ # ---- Multi-layer scenarios ----
3853
+ def multi_layer_whatif_analysis(
3854
+ self,
3855
+ scenarios: List[Dict[str, float]],
3856
+ depth: int = 3,
3857
+ ) -> Dict[str, Any]:
3858
+ results: List[Dict[str, Any]] = []
3859
+ for scen in scenarios:
3860
+ layer1 = self._predict_outcomes({}, scen)
3861
+ affected = [k for k, v in layer1.items() if abs(v) > 0.01]
3862
+ layer2_scenarios = [{a: layer1.get(a, 0.0) * 1.2} for a in affected[:5]]
3863
+ layer2_results: List[Dict[str, Any]] = []
3864
+ for l2 in layer2_scenarios:
3865
+ l2_outcome = self._predict_outcomes(layer1, l2)
3866
+ layer2_results.append({"layer2_scenario": l2, "layer2_outcomes": l2_outcome})
3867
+ results.append({
3868
+ "scenario": scen,
3869
+ "layer1_direct_effects": layer1,
3870
+ "affected_variables": affected,
3871
+ "layer2_cascades": layer2_results,
3872
+ })
3873
+ return {"multi_layer_analysis": results, "summary": {"total_scenarios": len(results)}}
3874
+
3875
+ def explore_alternate_realities(
3876
+ self,
3877
+ factual_state: Dict[str, float],
3878
+ target_outcome: str,
3879
+ target_value: Optional[float] = None,
3880
+ max_realities: int = 50,
3881
+ max_interventions: int = 3,
3882
+ ) -> Dict[str, Any]:
3883
+ rng = np.random.default_rng(self.seed)
3884
+ variables = list(factual_state.keys())
3885
+ realities: List[Dict[str, Any]] = []
3886
+ for _ in range(max_realities):
3887
+ num_int = rng.integers(1, max(2, max_interventions + 1))
3888
+ selected = rng.choice(variables, size=min(num_int, len(variables)), replace=False)
3889
+ intervention = {}
3890
+ for var in selected:
3891
+ stats = self.standardization_stats.get(var, {"mean": factual_state.get(var, 0.0), "std": 1.0})
3892
+ intervention[var] = float(rng.normal(stats["mean"], stats["std"] * 1.5))
3893
+ outcome = self._predict_outcomes(factual_state, intervention)
3894
+ target_val = outcome.get(target_outcome, 0.0)
3895
+ if target_value is not None:
3896
+ objective = -abs(target_val - target_value)
3897
+ else:
3898
+ objective = target_val
3899
+ realities.append({
3900
+ "interventions": intervention,
3901
+ "outcome": outcome,
3902
+ "target_value": float(target_val),
3903
+ "objective": float(objective),
3904
+ "delta_from_factual": float(target_val - factual_state.get(target_outcome, 0.0)),
3905
+ })
3906
+ realities.sort(key=lambda x: x["objective"], reverse=True)
3907
+ best = realities[0] if realities else None
3908
+ return {
3909
+ "factual_state": factual_state,
3910
+ "target_outcome": target_outcome,
3911
+ "target_value": target_value,
3912
+ "best_reality": best,
3913
+ "top_10_realities": realities[:10],
3914
+ "all_realities_explored": len(realities),
3915
+ "improvement_potential": (best["target_value"] - factual_state.get(target_outcome, 0.0)) if best else 0.0,
3916
+ }
3917
+
3918
+ # ---- Async wrappers ----
3919
+ async def run_async(
3920
+ self,
3921
+ task: Optional[Union[str, Any]] = None,
3922
+ initial_state: Optional[Any] = None,
3923
+ target_variables: Optional[List[str]] = None,
3924
+ max_steps: Union[int, str] = 1,
3925
+ **kwargs,
3926
+ ) -> Dict[str, Any]:
3927
+ loop = asyncio.get_running_loop()
3928
+ return await loop.run_in_executor(None, lambda: self.run(task=task, initial_state=initial_state, target_variables=target_variables, max_steps=max_steps, **kwargs))
3929
+
3930
+ async def quantify_uncertainty_async(
3931
+ self,
3932
+ df: Any,
3933
+ variables: List[str],
3934
+ windows: int = 200,
3935
+ alpha: float = 0.95
3936
+ ) -> Dict[str, Any]:
3937
+ loop = asyncio.get_running_loop()
3938
+ return await loop.run_in_executor(None, lambda: self.quantify_uncertainty(df=df, variables=variables, windows=windows, alpha=alpha))
3939
+
3940
+ async def granger_causality_test_async(
3941
+ self,
3942
+ df: Any,
3943
+ var1: str,
3944
+ var2: str,
3945
+ max_lag: int = 4,
3946
+ ) -> Dict[str, Any]:
3947
+ loop = asyncio.get_running_loop()
3948
+ return await loop.run_in_executor(None, lambda: self.granger_causality_test(df=df, var1=var1, var2=var2, max_lag=max_lag))
3949
+
3950
+ async def vector_autoregression_estimation_async(
3951
+ self,
3952
+ df: Any,
3953
+ variables: List[str],
3954
+ max_lag: int = 2,
3955
+ ) -> Dict[str, Any]:
3956
+ loop = asyncio.get_running_loop()
3957
+ return await loop.run_in_executor(None, lambda: self.vector_autoregression_estimation(df=df, variables=variables, max_lag=max_lag))
3958
+
3959
+ # =========================
3960
+ # Excel TUI Integration Methods
3961
+ # =========================
3962
+
3963
+ def _initialize_excel_tables(self) -> None:
3964
+ """Initialize standard Excel tables."""
3965
+ if not self._excel_enabled or self._excel_tables is None:
3966
+ return
3967
+
3968
+ try:
3969
+ from crca_excel.core.standard_tables import initialize_standard_tables
3970
+ initialize_standard_tables(self._excel_tables)
3971
+ except Exception as e:
3972
+ logger.error(f"Error initializing standard tables: {e}")
3973
+
3974
+ def _link_causal_graph_to_tables(self) -> None:
3975
+ """Link CRCA causal graph to Excel tables."""
3976
+ if not self._excel_enabled or self._excel_scm_bridge is None:
3977
+ return
3978
+
3979
+ try:
3980
+ self._excel_scm_bridge.link_causal_graph_to_tables(self.causal_graph)
3981
+ except Exception as e:
3982
+ logger.error(f"Error linking causal graph to tables: {e}")
3983
+
3984
+ def excel_edit_cell(self, table_name: str, row_key: Any, column_name: str, value: Any) -> None:
3985
+ """
3986
+ Edit a cell in Excel tables and trigger recomputation.
3987
+
3988
+ Args:
3989
+ table_name: Table name
3990
+ row_key: Row key
3991
+ column_name: Column name
3992
+ value: Value to set
3993
+ """
3994
+ if not self._excel_enabled or self._excel_tables is None:
3995
+ raise RuntimeError("Excel TUI not enabled")
3996
+
3997
+ self._excel_tables.set_cell(table_name, row_key, column_name, value)
3998
+
3999
+ # Trigger recomputation
4000
+ if self._excel_eval_engine:
4001
+ self._excel_eval_engine.recompute_dirty_cells()
4002
+
4003
+ def excel_apply_plan(self, plan: Dict[Tuple[str, Any, str], Any]) -> Dict[str, Any]:
4004
+ """
4005
+ Apply a plan (set of interventions) to Excel tables.
4006
+
4007
+ Args:
4008
+ plan: Dictionary of (table_name, row_key, column_name) -> value
4009
+
4010
+ Returns:
4011
+ Dictionary with results
4012
+ """
4013
+ if not self._excel_enabled or self._excel_scm_bridge is None:
4014
+ raise RuntimeError("Excel TUI not enabled")
4015
+
4016
+ # Convert plan format
4017
+ interventions = {
4018
+ (table_name, row_key, column_name): value
4019
+ for (table_name, row_key, column_name), value in plan.items()
4020
+ }
4021
+
4022
+ snapshot = self._excel_scm_bridge.do_intervention(interventions)
4023
+
4024
+ return {
4025
+ "snapshot": snapshot,
4026
+ "success": True
4027
+ }
4028
+
4029
+ def excel_generate_scenarios(
4030
+ self,
4031
+ base_interventions: Dict[Tuple[str, Any, str], Any],
4032
+ target_variables: List[Tuple[str, Any, str]],
4033
+ n_scenarios: int = 10
4034
+ ) -> List[Dict[str, Any]]:
4035
+ """
4036
+ Generate counterfactual scenarios.
4037
+
4038
+ Args:
4039
+ base_interventions: Base intervention set
4040
+ target_variables: Variables to vary
4041
+ n_scenarios: Number of scenarios
4042
+
4043
+ Returns:
4044
+ List of scenario dictionaries
4045
+ """
4046
+ if not self._excel_enabled or self._excel_scm_bridge is None:
4047
+ raise RuntimeError("Excel TUI not enabled")
4048
+
4049
+ # Convert format
4050
+ base_interv = {
4051
+ (table_name, row_key, column_name): value
4052
+ for (table_name, row_key, column_name), value in base_interventions.items()
4053
+ }
4054
+ target_vars = [
4055
+ (table_name, row_key, column_name)
4056
+ for (table_name, row_key, column_name) in target_variables
4057
+ ]
4058
+
4059
+ return self._excel_scm_bridge.generate_counterfactual_scenarios(
4060
+ base_interv,
4061
+ target_vars,
4062
+ n_scenarios=n_scenarios,
4063
+ seed=self.seed
4064
+ )
4065
+
4066
+ def excel_get_table(self, table_name: str):
4067
+ """
4068
+ Get an Excel table.
4069
+
4070
+ Args:
4071
+ table_name: Table name
4072
+
4073
+ Returns:
4074
+ Table instance or None
4075
+ """
4076
+ if not self._excel_enabled or self._excel_tables is None:
4077
+ return None
4078
+
4079
+ return self._excel_tables.get_table(table_name)
4080
+
4081
+ def excel_get_all_tables(self):
4082
+ """
4083
+ Get all Excel tables.
4084
+
4085
+ Returns:
4086
+ Dictionary of table_name -> Table
4087
+ """
4088
+ if not self._excel_enabled or self._excel_tables is None:
4089
+ return {}
4090
+
4091
+ return self._excel_tables.get_all_tables()
4092
+
4093
+ # =========================
4094
+ # Policy Engine Methods
4095
+ # =========================
4096
+
4097
    def run_policy_loop(
        self,
        num_epochs: int,
        sensor_provider: Optional[Callable] = None,
        actuator: Optional[Callable] = None,
        start_epoch: int = 0
    ) -> Dict[str, Any]:
        """
        Execute temporal policy loop for specified number of epochs.

        Args:
            num_epochs: Number of epochs to execute
            sensor_provider: Function that returns current state snapshot (Dict[str, float])
                             If None, uses dummy sensor (all metrics = 0.0)
            actuator: Function that executes interventions (takes List[InterventionSpec])
                      If None, interventions are logged but not executed
            start_epoch: Starting epoch number (default: 0)

        Returns:
            Dict[str, Any]: Summary with decision hashes and epoch results

        Raises:
            RuntimeError: If policy_mode is not enabled
        """
        if not self.policy_mode or self.policy_loop is None:
            raise RuntimeError("Policy mode not enabled. Set policy_mode=True and provide policy.")

        results = []
        decision_hashes = []

        # Epochs run strictly sequentially; a failure is logged and then
        # re-raised, aborting the whole loop.
        for epoch in range(start_epoch, start_epoch + num_epochs):
            try:
                epoch_result = self.policy_loop.run_epoch(
                    epoch=epoch,
                    sensor_provider=sensor_provider,
                    actuator=actuator
                )
                results.append(epoch_result)
                # Each epoch result is expected to carry a decision hash for
                # the audit trail; missing hashes become empty strings.
                decision_hashes.append(epoch_result.get("decision_hash", ""))
                logger.info(f"Epoch {epoch} completed: {len(epoch_result.get('interventions', []))} interventions")
            except Exception as e:
                logger.error(f"Error in epoch {epoch}: {e}")
                raise

        return {
            "num_epochs": num_epochs,
            "start_epoch": start_epoch,
            "end_epoch": start_epoch + num_epochs - 1,
            "decision_hashes": decision_hashes,
            "epoch_results": results,
            "policy_hash": self.policy_loop.compiled_policy.policy_hash,
            "summary": {
                "total_interventions": sum(len(r.get("interventions", [])) for r in results),
                "conservative_mode_triggered": any(r.get("conservative_mode", False) for r in results),
                # Drift is flagged when any epoch's CUSUM statistic exceeds
                # the loop's configured threshold (cusum_h).
                "drift_detected": any(r.get("cusum_stat", 0.0) > self.policy_loop.cusum_h for r in results)
            }
        }
4154
+
4155
+
4156
+