crca 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- CRCA.py +172 -7
- MODEL_CARD.md +53 -0
- PKG-INFO +8 -2
- RELEASE_NOTES.md +17 -0
- STABILITY.md +19 -0
- architecture/hybrid/consistency_engine.py +362 -0
- architecture/hybrid/conversation_manager.py +421 -0
- architecture/hybrid/explanation_generator.py +452 -0
- architecture/hybrid/few_shot_learner.py +533 -0
- architecture/hybrid/graph_compressor.py +286 -0
- architecture/hybrid/hybrid_agent.py +4398 -0
- architecture/hybrid/language_compiler.py +623 -0
- architecture/hybrid/main,py +0 -0
- architecture/hybrid/reasoning_tracker.py +322 -0
- architecture/hybrid/self_verifier.py +524 -0
- architecture/hybrid/task_decomposer.py +567 -0
- architecture/hybrid/text_corrector.py +341 -0
- benchmark_results/crca_core_benchmarks.json +178 -0
- branches/crca_sd/crca_sd_realtime.py +6 -2
- branches/general_agent/__init__.py +102 -0
- branches/general_agent/general_agent.py +1400 -0
- branches/general_agent/personality.py +169 -0
- branches/general_agent/utils/__init__.py +19 -0
- branches/general_agent/utils/prompt_builder.py +170 -0
- {crca-1.4.0.dist-info → crca-1.5.0.dist-info}/METADATA +8 -2
- {crca-1.4.0.dist-info → crca-1.5.0.dist-info}/RECORD +303 -20
- crca_core/__init__.py +35 -0
- crca_core/benchmarks/__init__.py +14 -0
- crca_core/benchmarks/synthetic_scm.py +103 -0
- crca_core/core/__init__.py +23 -0
- crca_core/core/api.py +120 -0
- crca_core/core/estimate.py +208 -0
- crca_core/core/godclass.py +72 -0
- crca_core/core/intervention_design.py +174 -0
- crca_core/core/lifecycle.py +48 -0
- crca_core/discovery/__init__.py +9 -0
- crca_core/discovery/tabular.py +193 -0
- crca_core/identify/__init__.py +171 -0
- crca_core/identify/backdoor.py +39 -0
- crca_core/identify/frontdoor.py +48 -0
- crca_core/identify/graph.py +106 -0
- crca_core/identify/id_algorithm.py +43 -0
- crca_core/identify/iv.py +48 -0
- crca_core/models/__init__.py +67 -0
- crca_core/models/provenance.py +56 -0
- crca_core/models/refusal.py +39 -0
- crca_core/models/result.py +83 -0
- crca_core/models/spec.py +151 -0
- crca_core/models/validation.py +68 -0
- crca_core/scm/__init__.py +9 -0
- crca_core/scm/linear_gaussian.py +198 -0
- crca_core/timeseries/__init__.py +6 -0
- crca_core/timeseries/pcmci.py +181 -0
- crca_llm/__init__.py +12 -0
- crca_llm/client.py +85 -0
- crca_llm/coauthor.py +118 -0
- crca_llm/orchestrator.py +289 -0
- crca_llm/types.py +21 -0
- crca_reasoning/__init__.py +16 -0
- crca_reasoning/critique.py +54 -0
- crca_reasoning/godclass.py +206 -0
- crca_reasoning/memory.py +24 -0
- crca_reasoning/rationale.py +10 -0
- crca_reasoning/react_controller.py +81 -0
- crca_reasoning/tool_router.py +97 -0
- crca_reasoning/types.py +40 -0
- crca_sd/__init__.py +15 -0
- crca_sd/crca_sd_core.py +2 -0
- crca_sd/crca_sd_governance.py +2 -0
- crca_sd/crca_sd_mpc.py +2 -0
- crca_sd/crca_sd_realtime.py +2 -0
- crca_sd/crca_sd_tui.py +2 -0
- cuda-keyring_1.1-1_all.deb +0 -0
- cuda-keyring_1.1-1_all.deb.1 +0 -0
- docs/IMAGE_ANNOTATION_USAGE.md +539 -0
- docs/INSTALL_DEEPSPEED.md +125 -0
- docs/api/branches/crca-cg.md +19 -0
- docs/api/branches/crca-q.md +27 -0
- docs/api/branches/crca-sd.md +37 -0
- docs/api/branches/general-agent.md +24 -0
- docs/api/branches/overview.md +19 -0
- docs/api/crca/agent-methods.md +62 -0
- docs/api/crca/operations.md +79 -0
- docs/api/crca/overview.md +32 -0
- docs/api/image-annotation/engine.md +52 -0
- docs/api/image-annotation/overview.md +17 -0
- docs/api/schemas/annotation.md +34 -0
- docs/api/schemas/core-schemas.md +82 -0
- docs/api/schemas/overview.md +32 -0
- docs/api/schemas/policy.md +30 -0
- docs/api/utils/conversation.md +22 -0
- docs/api/utils/graph-reasoner.md +32 -0
- docs/api/utils/overview.md +21 -0
- docs/api/utils/router.md +19 -0
- docs/api/utils/utilities.md +97 -0
- docs/architecture/causal-graphs.md +41 -0
- docs/architecture/data-flow.md +29 -0
- docs/architecture/design-principles.md +33 -0
- docs/architecture/hybrid-agent/components.md +38 -0
- docs/architecture/hybrid-agent/consistency.md +26 -0
- docs/architecture/hybrid-agent/overview.md +44 -0
- docs/architecture/hybrid-agent/reasoning.md +22 -0
- docs/architecture/llm-integration.md +26 -0
- docs/architecture/modular-structure.md +37 -0
- docs/architecture/overview.md +69 -0
- docs/architecture/policy-engine-arch.md +29 -0
- docs/branches/crca-cg/corposwarm.md +39 -0
- docs/branches/crca-cg/esg-scoring.md +30 -0
- docs/branches/crca-cg/multi-agent.md +35 -0
- docs/branches/crca-cg/overview.md +40 -0
- docs/branches/crca-q/alternative-data.md +55 -0
- docs/branches/crca-q/architecture.md +71 -0
- docs/branches/crca-q/backtesting.md +45 -0
- docs/branches/crca-q/causal-engine.md +33 -0
- docs/branches/crca-q/execution.md +39 -0
- docs/branches/crca-q/market-data.md +60 -0
- docs/branches/crca-q/overview.md +58 -0
- docs/branches/crca-q/philosophy.md +60 -0
- docs/branches/crca-q/portfolio-optimization.md +66 -0
- docs/branches/crca-q/risk-management.md +102 -0
- docs/branches/crca-q/setup.md +65 -0
- docs/branches/crca-q/signal-generation.md +61 -0
- docs/branches/crca-q/signal-validation.md +43 -0
- docs/branches/crca-sd/core.md +84 -0
- docs/branches/crca-sd/governance.md +53 -0
- docs/branches/crca-sd/mpc-solver.md +65 -0
- docs/branches/crca-sd/overview.md +59 -0
- docs/branches/crca-sd/realtime.md +28 -0
- docs/branches/crca-sd/tui.md +20 -0
- docs/branches/general-agent/overview.md +37 -0
- docs/branches/general-agent/personality.md +36 -0
- docs/branches/general-agent/prompt-builder.md +30 -0
- docs/changelog/index.md +79 -0
- docs/contributing/code-style.md +69 -0
- docs/contributing/documentation.md +43 -0
- docs/contributing/overview.md +29 -0
- docs/contributing/testing.md +29 -0
- docs/core/crcagent/async-operations.md +65 -0
- docs/core/crcagent/automatic-extraction.md +107 -0
- docs/core/crcagent/batch-prediction.md +80 -0
- docs/core/crcagent/bayesian-inference.md +60 -0
- docs/core/crcagent/causal-graph.md +92 -0
- docs/core/crcagent/counterfactuals.md +96 -0
- docs/core/crcagent/deterministic-simulation.md +78 -0
- docs/core/crcagent/dual-mode-operation.md +82 -0
- docs/core/crcagent/initialization.md +88 -0
- docs/core/crcagent/optimization.md +65 -0
- docs/core/crcagent/overview.md +63 -0
- docs/core/crcagent/time-series.md +57 -0
- docs/core/schemas/annotation.md +30 -0
- docs/core/schemas/core-schemas.md +82 -0
- docs/core/schemas/overview.md +30 -0
- docs/core/schemas/policy.md +41 -0
- docs/core/templates/base-agent.md +31 -0
- docs/core/templates/feature-mixins.md +31 -0
- docs/core/templates/overview.md +29 -0
- docs/core/templates/templates-guide.md +75 -0
- docs/core/tools/mcp-client.md +34 -0
- docs/core/tools/overview.md +24 -0
- docs/core/utils/conversation.md +27 -0
- docs/core/utils/graph-reasoner.md +29 -0
- docs/core/utils/overview.md +27 -0
- docs/core/utils/router.md +27 -0
- docs/core/utils/utilities.md +97 -0
- docs/css/custom.css +84 -0
- docs/examples/basic-usage.md +57 -0
- docs/examples/general-agent/general-agent-examples.md +50 -0
- docs/examples/hybrid-agent/hybrid-agent-examples.md +56 -0
- docs/examples/image-annotation/image-annotation-examples.md +54 -0
- docs/examples/integration/integration-examples.md +58 -0
- docs/examples/overview.md +37 -0
- docs/examples/trading/trading-examples.md +46 -0
- docs/features/causal-reasoning/advanced-topics.md +101 -0
- docs/features/causal-reasoning/counterfactuals.md +43 -0
- docs/features/causal-reasoning/do-calculus.md +50 -0
- docs/features/causal-reasoning/overview.md +47 -0
- docs/features/causal-reasoning/structural-models.md +52 -0
- docs/features/hybrid-agent/advanced-components.md +55 -0
- docs/features/hybrid-agent/core-components.md +64 -0
- docs/features/hybrid-agent/overview.md +34 -0
- docs/features/image-annotation/engine.md +82 -0
- docs/features/image-annotation/features.md +113 -0
- docs/features/image-annotation/integration.md +75 -0
- docs/features/image-annotation/overview.md +53 -0
- docs/features/image-annotation/quickstart.md +73 -0
- docs/features/policy-engine/doctrine-ledger.md +105 -0
- docs/features/policy-engine/monitoring.md +44 -0
- docs/features/policy-engine/mpc-control.md +89 -0
- docs/features/policy-engine/overview.md +46 -0
- docs/getting-started/configuration.md +225 -0
- docs/getting-started/first-agent.md +164 -0
- docs/getting-started/installation.md +144 -0
- docs/getting-started/quickstart.md +137 -0
- docs/index.md +118 -0
- docs/js/mathjax.js +13 -0
- docs/lrm/discovery_proof_notes.md +25 -0
- docs/lrm/finetune_full.md +83 -0
- docs/lrm/math_appendix.md +120 -0
- docs/lrm/overview.md +32 -0
- docs/mkdocs.yml +238 -0
- docs/stylesheets/extra.css +21 -0
- docs_generated/crca_core/CounterfactualResult.md +12 -0
- docs_generated/crca_core/DiscoveryHypothesisResult.md +13 -0
- docs_generated/crca_core/DraftSpec.md +13 -0
- docs_generated/crca_core/EstimateResult.md +13 -0
- docs_generated/crca_core/IdentificationResult.md +17 -0
- docs_generated/crca_core/InterventionDesignResult.md +12 -0
- docs_generated/crca_core/LockedSpec.md +15 -0
- docs_generated/crca_core/RefusalResult.md +12 -0
- docs_generated/crca_core/ValidationReport.md +9 -0
- docs_generated/crca_core/index.md +13 -0
- examples/general_agent_example.py +277 -0
- examples/general_agent_quickstart.py +202 -0
- examples/general_agent_simple.py +92 -0
- examples/hybrid_agent_auto_extraction.py +84 -0
- examples/hybrid_agent_dictionary_demo.py +104 -0
- examples/hybrid_agent_enhanced.py +179 -0
- examples/hybrid_agent_general_knowledge.py +107 -0
- examples/image_annotation_quickstart.py +328 -0
- examples/test_hybrid_fixes.py +77 -0
- image_annotation/__init__.py +27 -0
- image_annotation/annotation_engine.py +2593 -0
- install_cuda_wsl2.sh +59 -0
- install_deepspeed.sh +56 -0
- install_deepspeed_simple.sh +87 -0
- mkdocs.yml +252 -0
- ollama/Modelfile +8 -0
- prompts/__init__.py +2 -1
- prompts/default_crca.py +9 -1
- prompts/general_agent.py +227 -0
- prompts/image_annotation.py +56 -0
- pyproject.toml +17 -2
- requirements-docs.txt +10 -0
- requirements.txt +21 -2
- schemas/__init__.py +26 -1
- schemas/annotation.py +222 -0
- schemas/conversation.py +193 -0
- schemas/hybrid.py +211 -0
- schemas/reasoning.py +276 -0
- schemas_export/crca_core/CounterfactualResult.schema.json +108 -0
- schemas_export/crca_core/DiscoveryHypothesisResult.schema.json +113 -0
- schemas_export/crca_core/DraftSpec.schema.json +635 -0
- schemas_export/crca_core/EstimateResult.schema.json +113 -0
- schemas_export/crca_core/IdentificationResult.schema.json +145 -0
- schemas_export/crca_core/InterventionDesignResult.schema.json +111 -0
- schemas_export/crca_core/LockedSpec.schema.json +646 -0
- schemas_export/crca_core/RefusalResult.schema.json +90 -0
- schemas_export/crca_core/ValidationReport.schema.json +62 -0
- scripts/build_lrm_dataset.py +80 -0
- scripts/export_crca_core_schemas.py +54 -0
- scripts/export_hf_lrm.py +37 -0
- scripts/export_ollama_gguf.py +45 -0
- scripts/generate_changelog.py +157 -0
- scripts/generate_crca_core_docs_from_schemas.py +86 -0
- scripts/run_crca_core_benchmarks.py +163 -0
- scripts/run_full_finetune.py +198 -0
- scripts/run_lrm_eval.py +31 -0
- templates/graph_management.py +29 -0
- tests/conftest.py +9 -0
- tests/test_core.py +2 -3
- tests/test_crca_core_discovery_tabular.py +15 -0
- tests/test_crca_core_estimate_dowhy.py +36 -0
- tests/test_crca_core_identify.py +18 -0
- tests/test_crca_core_intervention_design.py +36 -0
- tests/test_crca_core_linear_gaussian_scm.py +69 -0
- tests/test_crca_core_spec.py +25 -0
- tests/test_crca_core_timeseries_pcmci.py +15 -0
- tests/test_crca_llm_coauthor.py +12 -0
- tests/test_crca_llm_orchestrator.py +80 -0
- tests/test_hybrid_agent_llm_enhanced.py +556 -0
- tests/test_image_annotation_demo.py +376 -0
- tests/test_image_annotation_operational.py +408 -0
- tests/test_image_annotation_unit.py +551 -0
- tests/test_training_moe.py +13 -0
- training/__init__.py +42 -0
- training/datasets.py +140 -0
- training/deepspeed_zero2_0_5b.json +22 -0
- training/deepspeed_zero2_1_5b.json +22 -0
- training/deepspeed_zero3_0_5b.json +28 -0
- training/deepspeed_zero3_14b.json +28 -0
- training/deepspeed_zero3_h100_3gpu.json +20 -0
- training/deepspeed_zero3_offload.json +28 -0
- training/eval.py +92 -0
- training/finetune.py +516 -0
- training/public_datasets.py +89 -0
- training_data/react_train.jsonl +7473 -0
- utils/agent_discovery.py +311 -0
- utils/batch_processor.py +317 -0
- utils/conversation.py +78 -0
- utils/edit_distance.py +118 -0
- utils/formatter.py +33 -0
- utils/graph_reasoner.py +530 -0
- utils/rate_limiter.py +283 -0
- utils/router.py +2 -2
- utils/tool_discovery.py +307 -0
- webui/__init__.py +10 -0
- webui/app.py +229 -0
- webui/config.py +104 -0
- webui/static/css/style.css +332 -0
- webui/static/js/main.js +284 -0
- webui/templates/index.html +42 -0
- tests/test_crca_excel.py +0 -166
- tests/test_data_broker.py +0 -424
- tests/test_palantir.py +0 -349
- {crca-1.4.0.dist-info → crca-1.5.0.dist-info}/WHEEL +0 -0
- {crca-1.4.0.dist-info → crca-1.5.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""Pydantic models and schema contracts for CRCA core."""
|
|
2
|
+
|
|
3
|
+
from crca_core.models.provenance import ProvenanceManifest
|
|
4
|
+
from crca_core.models.refusal import RefusalChecklistItem, RefusalReasonCode, RefusalResult
|
|
5
|
+
from crca_core.models.result import (
|
|
6
|
+
AnyResult,
|
|
7
|
+
BaseResult,
|
|
8
|
+
CounterfactualResult,
|
|
9
|
+
DiscoveryHypothesisResult,
|
|
10
|
+
EstimateResult,
|
|
11
|
+
IdentificationResult,
|
|
12
|
+
InterventionDesignResult,
|
|
13
|
+
ValidationIssue,
|
|
14
|
+
ValidationReport,
|
|
15
|
+
)
|
|
16
|
+
from crca_core.models.spec import (
|
|
17
|
+
AssumptionItem,
|
|
18
|
+
AssumptionSpec,
|
|
19
|
+
CausalGraphSpec,
|
|
20
|
+
DataColumnSpec,
|
|
21
|
+
DataSpec,
|
|
22
|
+
DraftSpec,
|
|
23
|
+
EdgeSpec,
|
|
24
|
+
EntityIndexSpec,
|
|
25
|
+
LockedSpec,
|
|
26
|
+
NoiseSpec,
|
|
27
|
+
NodeSpec,
|
|
28
|
+
RoleSpec,
|
|
29
|
+
SCMSpec,
|
|
30
|
+
SpecStatus,
|
|
31
|
+
StructuralEquationSpec,
|
|
32
|
+
TimeIndexSpec,
|
|
33
|
+
)
|
|
34
|
+
from crca_core.models.validation import validate_spec
|
|
35
|
+
|
|
36
|
+
__all__ = [
|
|
37
|
+
"ProvenanceManifest",
|
|
38
|
+
"RefusalChecklistItem",
|
|
39
|
+
"RefusalReasonCode",
|
|
40
|
+
"RefusalResult",
|
|
41
|
+
"AnyResult",
|
|
42
|
+
"BaseResult",
|
|
43
|
+
"CounterfactualResult",
|
|
44
|
+
"DiscoveryHypothesisResult",
|
|
45
|
+
"EstimateResult",
|
|
46
|
+
"IdentificationResult",
|
|
47
|
+
"InterventionDesignResult",
|
|
48
|
+
"ValidationIssue",
|
|
49
|
+
"ValidationReport",
|
|
50
|
+
"AssumptionItem",
|
|
51
|
+
"AssumptionSpec",
|
|
52
|
+
"CausalGraphSpec",
|
|
53
|
+
"DataColumnSpec",
|
|
54
|
+
"DataSpec",
|
|
55
|
+
"DraftSpec",
|
|
56
|
+
"EdgeSpec",
|
|
57
|
+
"EntityIndexSpec",
|
|
58
|
+
"LockedSpec",
|
|
59
|
+
"NoiseSpec",
|
|
60
|
+
"NodeSpec",
|
|
61
|
+
"RoleSpec",
|
|
62
|
+
"SCMSpec",
|
|
63
|
+
"SpecStatus",
|
|
64
|
+
"StructuralEquationSpec",
|
|
65
|
+
"TimeIndexSpec",
|
|
66
|
+
"validate_spec",
|
|
67
|
+
]
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Provenance manifest for reproducible causal R&D runs.
|
|
2
|
+
|
|
3
|
+
The manifest must not contain raw data; only hashes and schema summaries.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import platform
|
|
9
|
+
import sys
|
|
10
|
+
import uuid
|
|
11
|
+
from datetime import datetime, timezone
|
|
12
|
+
from typing import Any, Dict, Optional
|
|
13
|
+
|
|
14
|
+
from pydantic import BaseModel, Field
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def utc_now_iso() -> str:
    """Return the current UTC time as an ISO 8601 string.

    The value is produced from a timezone-aware ``datetime``, so the text
    carries an explicit UTC offset.
    """
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ProvenanceManifest(BaseModel):
    """Required provenance for every `crca_core` result."""

    # Run identity: when not supplied, a fresh UUID4 string and the current
    # UTC timestamp (via utc_now_iso) are generated per instance.
    run_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    timestamp_utc: str = Field(default_factory=utc_now_iso)
    # spec_hash is the only field without a default; data_hash may be omitted.
    spec_hash: str
    data_hash: Optional[str] = None

    # Reproducibility context. Each mapping defaults to a new empty dict;
    # hardware_notes defaults to None.
    library_versions: Dict[str, str] = Field(default_factory=dict)
    random_seeds: Dict[str, Any] = Field(default_factory=dict)
    algorithm_config: Dict[str, Any] = Field(default_factory=dict)
    hardware_notes: Optional[Dict[str, Any]] = None

    @classmethod
    def minimal(
        cls,
        *,
        spec_hash: str,
        data_hash: Optional[str] = None,
        random_seeds: Optional[Dict[str, Any]] = None,
        algorithm_config: Optional[Dict[str, Any]] = None,
    ) -> "ProvenanceManifest":
        """Build a manifest whose ``library_versions`` describe this interpreter.

        Args:
            spec_hash: Hash identifying the spec the run was based on.
            data_hash: Optional hash of the dataset.
            random_seeds: Optional seed mapping; a falsy value becomes ``{}``.
            algorithm_config: Optional config mapping; a falsy value becomes ``{}``.

        Returns:
            A manifest with ``python`` and ``platform`` entries recorded in
            ``library_versions``; ``run_id`` and ``timestamp_utc`` take their
            generated defaults.
        """
        environment = {
            # First whitespace-separated token of sys.version.
            "python": sys.version.split()[0],
            "platform": platform.platform(),
        }
        seeds = random_seeds if random_seeds else {}
        config = algorithm_config if algorithm_config else {}
        return cls(
            spec_hash=spec_hash,
            data_hash=data_hash,
            library_versions=environment,
            random_seeds=seeds,
            algorithm_config=config,
        )
|
|
56
|
+
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Hard-refusal types for H1 enforcement."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import List, Optional
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, Field
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class RefusalReasonCode(str, Enum):
    """Stable reason codes for refusal-first behavior."""

    # Each member's value is identical to its name; members are also ``str``
    # instances because of the ``str`` mixin.
    SPEC_NOT_LOCKED = "SPEC_NOT_LOCKED"
    NO_SCM_FOR_COUNTERFACTUAL = "NO_SCM_FOR_COUNTERFACTUAL"
    NOT_IDENTIFIABLE = "NOT_IDENTIFIABLE"
    TIME_INDEX_INVALID = "TIME_INDEX_INVALID"
    ASSUMPTIONS_UNDECLARED = "ASSUMPTIONS_UNDECLARED"
    INPUT_INVALID = "INPUT_INVALID"
    UNSUPPORTED_OPERATION = "UNSUPPORTED_OPERATION"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class RefusalChecklistItem(BaseModel):
    """A single required input/action needed to proceed."""

    # Both fields are required and must be non-empty strings (min_length=1).
    item: str = Field(..., min_length=1)
    rationale: str = Field(..., min_length=1)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class RefusalResult(BaseModel):
    """Structured refusal (no numeric causal output)."""

    # Tag for this result kind; declared frozen with default "Refusal".
    result_type: str = Field(default="Refusal", frozen=True)
    # Zero or more reason codes; defaults to an empty list.
    reason_codes: List[RefusalReasonCode] = Field(default_factory=list)
    # Required, non-empty message.
    message: str = Field(..., min_length=1)
    # Optional supporting material; list fields default to empty lists.
    checklist: List[RefusalChecklistItem] = Field(default_factory=list)
    suggested_next_steps: List[str] = Field(default_factory=list)
    details: Optional[str] = None
|
|
39
|
+
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""Structured result types for crca_core.
|
|
2
|
+
|
|
3
|
+
All results are structured objects. Human-readable reports must be generated
|
|
4
|
+
by rendering these objects, not by mixing narrative into scientific fields.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import Any, Dict, List, Literal, Optional, Union
|
|
10
|
+
|
|
11
|
+
from pydantic import BaseModel, Field
|
|
12
|
+
|
|
13
|
+
from crca_core.models.provenance import ProvenanceManifest
|
|
14
|
+
from crca_core.models.refusal import RefusalResult
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ValidationIssue(BaseModel):
    # One validation finding: a non-empty code, a non-empty message, and an
    # optional path string locating the finding.
    code: str = Field(..., min_length=1)
    message: str = Field(..., min_length=1)
    path: Optional[str] = None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ValidationReport(BaseModel):
    """Returned by `validate_spec`."""

    # Required overall flag; this model does not derive it from `errors`.
    ok: bool
    # Findings split by severity; both default to empty lists.
    errors: List[ValidationIssue] = Field(default_factory=list)
    warnings: List[ValidationIssue] = Field(default_factory=list)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class BaseResult(BaseModel):
    """Base result type with mandatory provenance."""

    # Required here; the subclasses visible in this file narrow it to a
    # Literal tag with a default.
    result_type: str
    # Required provenance manifest (no default).
    provenance: ProvenanceManifest
    # Free-form context; every container defaults to a new empty value.
    assumptions: List[str] = Field(default_factory=list)
    limitations: List[str] = Field(default_factory=list)
    artifacts: Dict[str, Any] = Field(default_factory=dict)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class DiscoveryHypothesisResult(BaseResult):
    # result_type is pinned to the literal tag "DiscoveryHypothesis".
    result_type: Literal["DiscoveryHypothesis"] = "DiscoveryHypothesis"
    # Untyped payload mappings; both default to empty dicts.
    graph_hypothesis: Dict[str, Any] = Field(default_factory=dict)
    stability_report: Dict[str, Any] = Field(default_factory=dict)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class InterventionDesignResult(BaseResult):
    # result_type is pinned to the literal tag "InterventionDesign".
    result_type: Literal["InterventionDesign"] = "InterventionDesign"
    # List of untyped design mappings; defaults to an empty list.
    designs: List[Dict[str, Any]] = Field(default_factory=list)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class CounterfactualResult(BaseResult):
    # result_type is pinned to the literal tag "CounterfactualResult".
    result_type: Literal["CounterfactualResult"] = "CounterfactualResult"
    # Untyped payload mapping; defaults to an empty dict.
    counterfactual: Dict[str, Any] = Field(default_factory=dict)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class IdentificationResult(BaseResult):
    # result_type is pinned to the literal tag "IdentificationResult".
    result_type: Literal["IdentificationResult"] = "IdentificationResult"
    # Required name of the identification method.
    method: str
    # Closed vocabularies; the defaults are the weakest option of each.
    scope: Literal["conservative", "partial", "complete"] = "conservative"
    confidence: Literal["low", "medium", "high"] = "low"
    # Required estimand text.
    estimand_expression: str
    # Supporting material; containers default to empty values.
    assumptions_used: List[str] = Field(default_factory=list)
    witnesses: Dict[str, Any] = Field(default_factory=dict)
    proof: Dict[str, Any] = Field(default_factory=dict)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class EstimateResult(BaseResult):
    # result_type is pinned to the literal tag "EstimateResult".
    result_type: Literal["EstimateResult"] = "EstimateResult"
    # Untyped payload mappings; both default to empty dicts.
    estimate: Dict[str, Any] = Field(default_factory=dict)
    refutations: Dict[str, Any] = Field(default_factory=dict)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# Alias covering every result object defined or imported in this module.
# Built with typing.Union rather than the `A | B` operator: this is a runtime
# expression, not an annotation, so `from __future__ import annotations` does
# not defer it and `class | class` raises TypeError on Python < 3.10.
AnyResult = Union[
    RefusalResult,
    ValidationReport,
    DiscoveryHypothesisResult,
    InterventionDesignResult,
    CounterfactualResult,
    IdentificationResult,
    EstimateResult,
]
|
|
83
|
+
|
crca_core/models/spec.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
"""Typed causal specification objects (Draft → Locked)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import Any, Dict, List, Literal, Optional, Tuple
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, Field
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SpecStatus(str, Enum):
    # Two lifecycle states; values equal the member names and members are
    # also ``str`` instances.
    draft = "draft"
    locked = "locked"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class AssumptionStatus(str, Enum):
    # Closed set of states an assumption can be in; values equal the names.
    declared = "declared"
    contested = "contested"
    violated = "violated"
    unknown = "unknown"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class DataColumnSpec(BaseModel):
    # Required, non-empty column name and dtype label.
    name: str = Field(..., min_length=1)
    dtype: str = Field(..., min_length=1)
    # Optional pair of floats; presumably (low, high) - ordering is not
    # validated here, confirm with consumers.
    allowed_range: Optional[Tuple[float, float]] = None
    # Optional value constrained to the closed interval [0.0, 1.0].
    missingness_expected: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    unit: Optional[str] = None
    description: Optional[str] = None
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TimeIndexSpec(BaseModel):
    # Required, non-empty name of the time column.
    column: str = Field(..., min_length=1)
    # Optional free-text descriptors; none are validated by this model.
    frequency: Optional[str] = None
    timezone: Optional[str] = None
    irregular_sampling_policy: Optional[str] = None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class EntityIndexSpec(BaseModel):
    # Both column names are required and must be non-empty.
    entity_id_column: str = Field(..., min_length=1)
    time_column: str = Field(..., min_length=1)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class DataSpec(BaseModel):
    # Optional dataset identity.
    dataset_name: Optional[str] = None
    dataset_hash: Optional[str] = None
    # Column declarations; defaults to an empty list.
    columns: List[DataColumnSpec] = Field(default_factory=list)
    # Optional index declarations.
    time_index: Optional[TimeIndexSpec] = None
    entity_index: Optional[EntityIndexSpec] = None
    measurement_error_notes: Optional[str] = None
    # String-to-string mapping; defaults to an empty dict. The direction of
    # the mapping is not established in this model.
    proxy_variables: Dict[str, str] = Field(default_factory=dict)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class NodeSpec(BaseModel):
    # Required, non-empty node name.
    name: str = Field(..., min_length=1)
    # Nodes are treated as observed unless stated otherwise.
    observed: bool = True
    unit: Optional[str] = None
    description: Optional[str] = None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class EdgeSpec(BaseModel):
    # Required, non-empty endpoint names.
    source: str = Field(..., min_length=1)
    target: str = Field(..., min_length=1)
    # Optional integer lag; no sign or range constraint is declared here.
    lag: Optional[int] = None
    description: Optional[str] = None
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class CausalGraphSpec(BaseModel):
    # Graph structure; both lists default to empty.
    nodes: List[NodeSpec] = Field(default_factory=list)
    edges: List[EdgeSpec] = Field(default_factory=list)
    # List of strings naming latent confounders; defaults to empty.
    latent_confounders: List[str] = Field(default_factory=list)
    notes: Optional[str] = None
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class RoleSpec(BaseModel):
    # Each role is a list of variable names; every list defaults to empty.
    treatments: List[str] = Field(default_factory=list)
    outcomes: List[str] = Field(default_factory=list)
    mediators: List[str] = Field(default_factory=list)
    instruments: List[str] = Field(default_factory=list)
    adjustment_candidates: List[str] = Field(default_factory=list)
    prohibited_controls: List[str] = Field(default_factory=list)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class AssumptionItem(BaseModel):
    # Required, non-empty assumption name.
    name: str = Field(..., min_length=1)
    # Status starts as `unknown` unless the caller sets it.
    status: AssumptionStatus = AssumptionStatus.unknown
    description: Optional[str] = None
    evidence: Optional[str] = None
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class AssumptionSpec(BaseModel):
    # Declared assumptions; defaults to an empty list.
    items: List[AssumptionItem] = Field(default_factory=list)
    # List of strings; defaults to empty.
    falsification_plan: List[str] = Field(default_factory=list)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class NoiseSpec(BaseModel):
    # Only the "gaussian" distribution is accepted by this field's type.
    distribution: Literal["gaussian"] = "gaussian"
    # Untyped parameter mapping; defaults to an empty dict.
    params: Dict[str, Any] = Field(default_factory=dict)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class StructuralEquationSpec(BaseModel):
    """Represents one structural equation V = f(Pa(V), U_V).

    v0.1: store both a human-readable formula and an executable parameterization
    for supported SCM families.
    """

    # Required, non-empty name of the variable this equation defines.
    variable: str = Field(..., min_length=1)
    # Names of the parent variables; defaults to an empty list.
    parents: List[str] = Field(default_factory=list)
    # Only the "linear_gaussian" form is accepted by this field's type.
    form: Literal["linear_gaussian"] = "linear_gaussian"
    coefficients: Dict[str, float] = Field(default_factory=dict)  # parent -> beta
    intercept: float = 0.0
    # A default NoiseSpec is created per instance when none is given.
    noise: NoiseSpec = Field(default_factory=NoiseSpec)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class SCMSpec(BaseModel):
    """Explicit SCM required for counterfactuals."""

    # Only the "linear_gaussian" family is accepted by this field's type.
    scm_type: Literal["linear_gaussian"] = "linear_gaussian"
    # One entry per structural equation; defaults to an empty list.
    equations: List[StructuralEquationSpec] = Field(default_factory=list)
    # Optional correlated noise for linear-Gaussian SCMs (advanced; v0.1 may require diagonal)
    noise_cov: Optional[List[List[float]]] = None
    intervention_semantics: Dict[str, str] = Field(default_factory=dict)  # var -> set/shift/mechanism-change
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class DraftSpec(BaseModel):
    """Draft spec (may be LLM-generated; never authorizes numeric causal outputs)."""

    # Defaults to SpecStatus.draft and is declared frozen.
    # NOTE(review): the annotation is the full SpecStatus enum, so confirm
    # whether construction with a non-draft status should be rejected.
    status: SpecStatus = Field(default=SpecStatus.draft, frozen=True)
    # Every section is optional at draft stage and defaults to an empty
    # instance of its spec type.
    data: DataSpec = Field(default_factory=DataSpec)
    graph: CausalGraphSpec = Field(default_factory=CausalGraphSpec)
    roles: RoleSpec = Field(default_factory=RoleSpec)
    assumptions: AssumptionSpec = Field(default_factory=AssumptionSpec)
    scm: Optional[SCMSpec] = None
    draft_notes: Optional[str] = None
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class LockedSpec(BaseModel):
    """Locked spec (authoritative for identification/estimation/simulation semantics)."""

    # Defaults to SpecStatus.locked and is declared frozen.
    # NOTE(review): the annotation is the full SpecStatus enum, so confirm
    # whether construction with a non-locked status should be rejected.
    status: SpecStatus = Field(default=SpecStatus.locked, frozen=True)
    # Locking metadata: spec_hash and locked_at_utc are required strings;
    # approvals defaults to an empty list.
    spec_hash: str
    approvals: List[str] = Field(default_factory=list)
    locked_at_utc: str

    # Unlike DraftSpec, the four sections below have no defaults and must be
    # supplied explicitly; only the SCM stays optional.
    data: DataSpec
    graph: CausalGraphSpec
    roles: RoleSpec
    assumptions: AssumptionSpec
    scm: Optional[SCMSpec] = None
|
|
151
|
+
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""Spec validation (DraftSpec or LockedSpec).
|
|
2
|
+
|
|
3
|
+
Validation is intentionally conservative: it checks for obvious structural
|
|
4
|
+
issues and missing required fields for different operations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import Iterable, Optional
|
|
10
|
+
|
|
11
|
+
from crca_core.models.result import ValidationIssue, ValidationReport
|
|
12
|
+
from crca_core.models.spec import DraftSpec, LockedSpec
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _node_names(nodes) -> set[str]:
    """Collect the distinct ``name`` attributes of *nodes* into a set."""
    names: set[str] = set()
    for node in nodes:
        names.add(node.name)
    return names
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def validate_spec(spec: DraftSpec | LockedSpec) -> ValidationReport:
    """Run conservative structural checks on a draft or locked spec.

    Checks, in order:
      * every edge source and target names a declared graph node (error);
      * a declared time index column exists among the data columns (error);
      * treatment, outcome, mediator and instrument roles name graph nodes
        (warning only).

    ``roles.adjustment_candidates`` and ``roles.prohibited_controls`` are not
    checked here.

    Returns:
        A ``ValidationReport`` whose ``ok`` flag is true exactly when no
        errors were collected; warnings do not affect ``ok``.
    """
    node_names = _node_names(spec.graph.nodes)
    errors: list[ValidationIssue] = []

    # Basic graph consistency: both endpoints of each edge must be declared.
    for e in spec.graph.edges:
        source_known = e.source in node_names
        target_known = e.target in node_names
        if not source_known:
            errors.append(
                ValidationIssue(
                    code="EDGE_SOURCE_UNKNOWN",
                    message=f"Edge source '{e.source}' not in nodes",
                    path="graph.edges",
                )
            )
        if not target_known:
            errors.append(
                ValidationIssue(
                    code="EDGE_TARGET_UNKNOWN",
                    message=f"Edge target '{e.target}' not in nodes",
                    path="graph.edges",
                )
            )

    # Time-series checks: only applied when a time index is declared.
    time_index = spec.data.time_index
    if time_index is not None:
        time_col = time_index.column
        declared_columns = {c.name for c in spec.data.columns}
        if time_col and time_col not in declared_columns:
            errors.append(
                ValidationIssue(
                    code="TIME_INDEX_COLUMN_UNKNOWN",
                    message=f"time_index.column '{time_col}' not in data columns",
                    path="data.time_index.column",
                )
            )

    # Roles consistency: unknown role variables are reported as warnings.
    role_variables = [
        *spec.roles.treatments,
        *spec.roles.outcomes,
        *spec.roles.mediators,
        *spec.roles.instruments,
    ]
    warnings: list[ValidationIssue] = [
        ValidationIssue(
            code="ROLE_NODE_UNKNOWN",
            message=f"Role variable '{v}' not present in graph nodes",
            path="roles",
        )
        for v in role_variables
        if v not in node_names
    ]

    return ValidationReport(ok=not errors, errors=errors, warnings=warnings)
|
|
68
|
+
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
"""Linear-Gaussian Structural Causal Model (SCM).
|
|
2
|
+
|
|
3
|
+
Implements counterfactual reasoning via abduction–action–prediction:
|
|
4
|
+
- Abduction: infer exogenous noise U from a factual observation of all endogenous variables.
|
|
5
|
+
- Action: apply do-interventions by replacing structural equations for intervened variables.
|
|
6
|
+
- Prediction: propagate values in topological order using the same inferred noise U.
|
|
7
|
+
|
|
8
|
+
This implementation is intentionally conservative:
|
|
9
|
+
- v0.1 assumes a fully observed system and (by default) independent noises.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from dataclasses import dataclass
|
|
15
|
+
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
import networkx as nx # type: ignore
|
|
21
|
+
except Exception as e: # pragma: no cover
|
|
22
|
+
raise ImportError("networkx is required for LinearGaussianSCM") from e
|
|
23
|
+
|
|
24
|
+
from crca_core.models.spec import SCMSpec, StructuralEquationSpec
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass(frozen=True)
class LinearGaussianSCM:
    """A linear SCM with additive Gaussian noise for each endogenous variable.

    Each variable ``v`` is generated as
    ``intercepts[v] + sum(coefficients[(p, v)] * p for p in parents[v]) + U_v``.
    """

    # Variables that own a structural equation. Their position here is also the
    # row/column index used by the matrix form and by ``noise_cov``.
    variables: Tuple[str, ...]
    # child -> ordered parents of that child
    parents: Dict[str, Tuple[str, ...]]
    coefficients: Dict[Tuple[str, str], float]  # (parent, child) -> beta
    intercepts: Dict[str, float]
    # Noise is represented per-variable; v0.1 assumes diagonal covariance by default.
    # ``None`` is treated as the identity matrix by ``_matrix_form``.
    noise_cov: Optional[np.ndarray] = None

    @classmethod
    def from_spec(cls, spec: "SCMSpec") -> "LinearGaussianSCM":
        """Build a model from an ``SCMSpec`` and validate it.

        Raises:
            ValueError: if the spec or one of its equations is not
                ``"linear_gaussian"``, a variable has more than one equation,
                ``noise_cov`` is not a square 2-D matrix whose size equals the
                number of equations, or the resulting graph is not a DAG.
        """
        if spec.scm_type != "linear_gaussian":
            raise ValueError(f"Unsupported scm_type: {spec.scm_type}")

        variables: List[str] = []
        parents: Dict[str, Tuple[str, ...]] = {}
        coefficients: Dict[Tuple[str, str], float] = {}
        intercepts: Dict[str, float] = {}

        for eq in spec.equations:
            if eq.form != "linear_gaussian":
                raise ValueError(f"Unsupported equation form: {eq.form}")
            v = eq.variable
            if v in parents:  # every accepted variable is a key of ``parents``
                raise ValueError(f"Duplicate equation for variable: {v}")
            variables.append(v)
            parents[v] = tuple(eq.parents)
            intercepts[v] = float(eq.intercept)
            for p, beta in eq.coefficients.items():
                coefficients[(p, v)] = float(beta)

        noise_cov = None
        if spec.noise_cov is not None:
            noise_cov = np.array(spec.noise_cov, dtype=float)
            # Check ndim first: a 1-D input (e.g. a list of variances) has no
            # shape[1] and would otherwise surface as an IndexError.
            if noise_cov.ndim != 2 or noise_cov.shape[0] != noise_cov.shape[1]:
                raise ValueError("noise_cov must be square")
            if noise_cov.shape[0] != len(variables):
                raise ValueError("noise_cov dimension must match number of equations/variables")

        model = cls(
            variables=tuple(variables),
            parents=parents,
            coefficients=coefficients,
            intercepts=intercepts,
            noise_cov=noise_cov,
        )
        # Validate acyclicity/topological order.
        _ = model.topological_order()
        return model

    def topological_order(self) -> List[str]:
        """Return the graph nodes in a topological order.

        The graph contains every entry of ``variables`` plus every parent named
        in ``parents`` (so a parent without its own equation is still a node).

        Raises:
            ValueError: if the graph contains a directed cycle.
        """
        g = nx.DiGraph()
        g.add_nodes_from(self.variables)
        for child, ps in self.parents.items():
            for p in ps:
                g.add_edge(p, child)
        if not nx.is_directed_acyclic_graph(g):
            raise ValueError("SCM graph must be a DAG")
        return list(nx.topological_sort(g))

    def _matrix_form(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return (A, b, Sigma_u) for A V = b + U.

        Rows and columns follow the order of ``self.variables``. A coefficient
        or intercept that names a variable outside ``self.variables`` raises
        ``KeyError``.
        """
        n = len(self.variables)
        index = {v: i for i, v in enumerate(self.variables)}
        M = np.zeros((n, n), dtype=float)
        b = np.zeros((n,), dtype=float)
        for v, intercept in self.intercepts.items():
            b[index[v]] = float(intercept)
        for (p, v), beta in self.coefficients.items():
            M[index[v], index[p]] = float(beta)
        A = np.eye(n) - M
        # Without an explicit covariance the noises are taken as unit-variance
        # and independent.
        Sigma_u = np.eye(n) if self.noise_cov is None else self.noise_cov
        return A, b, Sigma_u

    def abduce_noise_conditional(self, factual: Mapping[str, float]) -> Dict[str, float]:
        """Conditional Gaussian abduction for partial observations.

        Observed variables keep their factual values; unobserved ones are set
        to their conditional mean given the observations. The returned noise is
        ``A @ v - b`` for that completed state, one entry per model variable.
        Keys of ``factual`` that are not model variables are ignored.
        """
        A, b, Sigma_u = self._matrix_form()
        n = len(self.variables)
        index = {v: i for i, v in enumerate(self.variables)}

        obs_idx = [index[v] for v in factual if v in index]
        observed = set(obs_idx)  # O(1) membership while collecting the rest
        miss_idx = [i for i in range(n) if i not in observed]
        obs = np.array(obs_idx, dtype=int)
        mis = np.array(miss_idx, dtype=int)

        v_full = np.zeros(n, dtype=float)
        v_full[obs] = [float(factual[self.variables[i]]) for i in obs_idx]

        if miss_idx:
            # Mean and covariance of V implied by A V = b + U.
            A_inv = np.linalg.inv(A)
            mu_v = A_inv @ b
            if obs_idx:
                Sigma_v = A_inv @ Sigma_u @ A_inv.T
                Sigma_oo = Sigma_v[np.ix_(obs, obs)]
                Sigma_mo = Sigma_v[np.ix_(mis, obs)]
                resid = v_full[obs] - mu_v[obs]
                try:
                    weights = np.linalg.solve(Sigma_oo, resid)
                except np.linalg.LinAlgError:
                    # Singular observed block (e.g. a zero-variance variable
                    # observed together with its parent): fall back to the
                    # pseudo-inverse instead of failing.
                    weights = np.linalg.pinv(Sigma_oo) @ resid
                # Conditional mean of V_m given V_o
                v_full[mis] = mu_v[mis] + Sigma_mo @ weights
            else:
                # Nothing observed: the conditional mean is the prior mean.
                v_full[mis] = mu_v[mis]

        u_mean = A @ v_full - b
        return {v: float(u_mean[i]) for i, v in enumerate(self.variables)}

    def abduce_noise(self, factual: Mapping[str, float], *, allow_partial: bool = False) -> Dict[str, float]:
        """Infer per-variable noise U from a factual state.

        If `allow_partial` is False, all endogenous variables must be observed.
        If True and variables are missing:
          - with `noise_cov` set, `abduce_noise_conditional` is used and noise
            is returned for every model variable;
          - otherwise noise is inferred only for observed variables whose
            parents are all observed.

        Raises:
            ValueError: if variables are missing and `allow_partial` is False.
        """
        order = self.topological_order()
        u: Dict[str, float] = {}
        x: Dict[str, float] = {k: float(v) for k, v in factual.items()}

        missing = [v for v in order if v not in x]
        if missing and not allow_partial:
            raise ValueError(f"Factual observation missing variables: {missing}")
        if missing and allow_partial and self.noise_cov is not None:
            return self.abduce_noise_conditional(factual)

        for v in order:
            if v not in x:
                continue
            pred = self.intercepts.get(v, 0.0)
            for p in self.parents.get(v, ()):
                if p not in x:
                    if allow_partial:
                        # Cannot form the structural prediction; skip this variable.
                        pred = None
                        break
                    raise ValueError(f"Factual observation missing parent '{p}' for '{v}'")
                beta = self.coefficients.get((p, v), 0.0)
                pred += beta * x[p]
            if pred is None:
                continue
            u[v] = float(x[v] - pred)
        return u

    def predict(self, noise: Mapping[str, float], interventions: Optional[Mapping[str, float]] = None) -> Dict[str, float]:
        """Forward simulate endogenous variables under interventions using fixed noise.

        Intervened variables are set to the given value; all others are
        computed from their parents in topological order. Variables absent
        from ``noise`` get zero noise.
        """
        interventions = interventions or {}
        order = self.topological_order()
        x: Dict[str, float] = {}

        for v in order:
            if v in interventions:
                # do(v := value): the structural equation is bypassed.
                x[v] = float(interventions[v])
                continue
            pred = self.intercepts.get(v, 0.0)
            for p in self.parents.get(v, ()):
                beta = self.coefficients.get((p, v), 0.0)
                pred += beta * x[p]
            pred += float(noise.get(v, 0.0))
            x[v] = float(pred)
        return x

    def counterfactual(self, factual: Mapping[str, float], interventions: Mapping[str, float]) -> Dict[str, float]:
        """Compute a counterfactual state via abduction–action–prediction.

        Requires a fully observed ``factual`` state (``abduce_noise`` is called
        without ``allow_partial``).
        """
        u = self.abduce_noise(factual)
        return self.predict(u, interventions=interventions)
|
|
198
|
+
|