rtisanepy 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,18 @@
1
+ {
2
+ "permissions": {
3
+ "allow": [
4
+ "WebFetch(domain:github.com)",
5
+ "WebSearch",
6
+ "WebFetch(domain:tisane-stats.org)",
7
+ "WebFetch(domain:dl.acm.org)",
8
+ "WebFetch(domain:raw.githubusercontent.com)",
9
+ "WebFetch(domain:api.github.com)",
10
+ "Bash(python:*)",
11
+ "Bash(git add:*)",
12
+ "Bash(git commit:*)",
13
+ "Bash(git push:*)",
14
+ "Bash(gh pr:*)",
15
+ "Bash(git:*)"
16
+ ]
17
+ }
18
+ }
@@ -0,0 +1,7 @@
1
+ .venv/
2
+ __pycache__/
3
+ *.egg-info/
4
+ .pytest_cache/
5
+ dist/
6
+ build/
7
+ *.pyc
@@ -0,0 +1 @@
1
+ 3.12.0
@@ -0,0 +1,162 @@
1
+ Metadata-Version: 2.4
2
+ Name: rtisanepy
3
+ Version: 0.1.0
4
+ Summary: Authoring generalized linear mixed effects models from conceptual models
5
+ Project-URL: Homepage, https://github.com/emjun/rTisanePy
6
+ Project-URL: Repository, https://github.com/emjun/rTisanePy
7
+ Author: Eunice Jun
8
+ License-Expression: MIT
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Operating System :: OS Independent
11
+ Classifier: Programming Language :: Python :: 3
12
+ Requires-Python: >=3.10
13
+ Requires-Dist: networkx>=3.0
14
+ Provides-Extra: dev
15
+ Requires-Dist: pytest>=7.0; extra == 'dev'
16
+ Description-Content-Type: text/markdown
17
+
18
+ # rTisanePy
19
+
20
+ Authoring generalized linear mixed effects models from conceptual models. Based on [rTisane](https://github.com/emjun/rTisane).
21
+
22
+ rTisanePy lets you:
23
+
24
+ - Declare variables and specify causal and relational assumptions
25
+ - Query for a statistical model from a conceptual model
26
+ - Automatically identify confounders (following [Cinelli, Forney & Pearl, 2022](https://doi.org/10.1177/00491241221099552))
27
+ - Infer random effects from nesting and repeated measures
28
+ - Select candidate family/link functions based on DV type
29
+
30
+ ## Installation
31
+
32
+ ```bash
33
+ pip install rtisanepy
34
+ ```
35
+
36
+ For development:
37
+
38
+ ```bash
39
+ pip install -e ".[dev]"
40
+ ```
41
+
42
+ Requires Python 3.10+ and `networkx`.
43
+
44
+ ## Quick Start
45
+
46
+ ```python
47
+ from rtisanepy import (
48
+ Unit, continuous, categories,
49
+ causes, equals, increases,
50
+ ConceptualModel,
51
+ )
52
+
53
+ # Declare variables
54
+ student = Unit("student", cardinality=100)
55
+ classroom = Unit("classroom", cardinality=10)
56
+ student = Unit("student", cardinality=100, nests_within=classroom)
57
+
58
+ tutoring = categories(unit=student, name="Tutoring", cardinality=2)
59
+ ses = categories(unit=student, name="SES", order=["lower", "middle", "upper"])
60
+ test_score = continuous(unit=student, name="TestScore")
61
+
62
+ # Build a conceptual model
63
+ cm = (
64
+ ConceptualModel()
65
+ .assume(causes(ses, test_score))
66
+ .assume(causes(ses, tutoring))
67
+ .hypothesize(causes(tutoring, test_score,
68
+ when=equals(tutoring, "in-person"),
69
+ then=increases(test_score)))
70
+ )
71
+
72
+ # Query for a statistical model
73
+ model = cm.query(iv=tutoring, dv=test_score)
74
+
75
+ print(model.formula()) # TestScore ~ Tutoring + SES + (1 | classroom)
76
+ print(model.family) # Inverse Gaussian
77
+ print(model.main_effects) # ['Tutoring', 'SES']
78
+ print(model.random_effects) # [RandomIntercept(group=classroom)]
79
+
80
+ # Override family/link
81
+ model2 = model.with_family("Gaussian", link="identity")
82
+ ```
83
+
84
+ ## API Overview
85
+
86
+ ### Variables
87
+
88
+ | Constructor | Description |
89
+ |---|---|
90
+ | `Unit(name, cardinality, nests_within=None)` | Entity (participant, subject) |
91
+ | `Participant(name, cardinality, nests_within=None)` | Alias for Unit |
92
+ | `Time(name, order=None, cardinality=0)` | Time variable |
93
+ | `continuous(unit, name, number_of_instances=1)` | Continuous measure |
94
+ | `counts(unit, name, number_of_instances=1)` | Count measure |
95
+ | `categories(unit, name, *, cardinality=None, order=None)` | Categorical measure (ordered if `order` given) |
96
+
97
+ ### Relationships
98
+
99
+ | Function | Description |
100
+ |---|---|
101
+ | `causes(cause, effect, *, when=None, then=None)` | Directed causal relationship |
102
+ | `relates(lhs, rhs, *, when=None, then=None)` | Undirected (ambiguous) relationship |
103
+ | `nests(base, group)` | Nesting relationship between units |
104
+
105
+ ### Comparisons (for `when`/`then` annotations)
106
+
107
+ | Function | Description |
108
+ |---|---|
109
+ | `equals(variable, value)` | Variable equals a value |
110
+ | `not_equals(variable, value)` | Variable does not equal a value |
111
+ | `increases(variable)` | Variable increases |
112
+ | `decreases(variable)` | Variable decreases |
113
+
114
+ ### ConceptualModel
115
+
116
+ ```python
117
+ cm = ConceptualModel()
118
+ cm.assume(relationship) # Add assumed relationship (returns self)
119
+ cm.hypothesize(relationship) # Add hypothesized relationship (returns self)
120
+ cm.interacts(*vars, dv=dv) # Add interaction annotation (returns self)
121
+ cm.query(iv=iv, dv=dv) # Infer a StatisticalModel
122
+ ```
123
+
124
+ At least one hypothesized relationship connecting `iv` and `dv` is required to query.
125
+
126
+ ### StatisticalModel
127
+
128
+ | Attribute / Method | Description |
129
+ |---|---|
130
+ | `.formula()` | R-style formula string |
131
+ | `.main_effects` | List of main effect variable names |
132
+ | `.interaction_effects` | List of interaction terms |
133
+ | `.random_effects` | List of `RandomIntercept` / `RandomSlope` |
134
+ | `.family` | Selected family function |
135
+ | `.link` | Selected link function |
136
+ | `.family_candidates` | All candidate family/link pairs |
137
+ | `.summary()` | Machine-readable dict |
138
+ | `.with_family(family, link=None)` | Copy with overridden family/link |
139
+
140
+ ## Running Tests
141
+
142
+ ```bash
143
+ python -m pytest tests/ -v
144
+ ```
145
+
146
+ ## Citation
147
+
148
+ If you use rTisanePy in your research, please cite:
149
+
150
+ ```bibtex
151
+ @software{rtisanepy,
152
+ author = {Eunice Jun},
153
+ title = {rTisanePy: A Python Tool for Authoring Statistical Models from Conceptual Models},
154
+ url = {https://github.com/emjun/rTisanePy},
155
+ year = {2026}
156
+ }
157
+ ```
158
+
159
+ ## References
160
+
161
+ - Eunice Jun, Audrey Seo, Jeffrey Heer, and René Just. 2024. [Tisane: Authoring Statistical Models via Formal Reasoning from Conceptual and Data Relationships.](https://doi.org/10.1145/3613904.3642347) *CHI 2024*.
162
+ - Carlos Cinelli, Andrew Forney, and Judea Pearl. 2022. [A Crash Course in Good and Bad Controls.](https://doi.org/10.1177/00491241221099552) *Sociological Methods & Research*.
@@ -0,0 +1,145 @@
1
+ # rTisanePy
2
+
3
+ Authoring generalized linear mixed effects models from conceptual models. Based on [rTisane](https://github.com/emjun/rTisane).
4
+
5
+ rTisanePy lets you:
6
+
7
+ - Declare variables and specify causal and relational assumptions
8
+ - Query for a statistical model from a conceptual model
9
+ - Automatically identify confounders (following [Cinelli, Forney & Pearl, 2022](https://doi.org/10.1177/00491241221099552))
10
+ - Infer random effects from nesting and repeated measures
11
+ - Select candidate family/link functions based on DV type
12
+
13
+ ## Installation
14
+
15
+ ```bash
16
+ pip install rtisanepy
17
+ ```
18
+
19
+ For development:
20
+
21
+ ```bash
22
+ pip install -e ".[dev]"
23
+ ```
24
+
25
+ Requires Python 3.10+ and `networkx`.
26
+
27
+ ## Quick Start
28
+
29
+ ```python
30
+ from rtisanepy import (
31
+ Unit, continuous, categories,
32
+ causes, equals, increases,
33
+ ConceptualModel,
34
+ )
35
+
36
+ # Declare variables
37
+ student = Unit("student", cardinality=100)
38
+ classroom = Unit("classroom", cardinality=10)
39
+ student = Unit("student", cardinality=100, nests_within=classroom)
40
+
41
+ tutoring = categories(unit=student, name="Tutoring", cardinality=2)
42
+ ses = categories(unit=student, name="SES", order=["lower", "middle", "upper"])
43
+ test_score = continuous(unit=student, name="TestScore")
44
+
45
+ # Build a conceptual model
46
+ cm = (
47
+ ConceptualModel()
48
+ .assume(causes(ses, test_score))
49
+ .assume(causes(ses, tutoring))
50
+ .hypothesize(causes(tutoring, test_score,
51
+ when=equals(tutoring, "in-person"),
52
+ then=increases(test_score)))
53
+ )
54
+
55
+ # Query for a statistical model
56
+ model = cm.query(iv=tutoring, dv=test_score)
57
+
58
+ print(model.formula()) # TestScore ~ Tutoring + SES + (1 | classroom)
59
+ print(model.family) # Inverse Gaussian
60
+ print(model.main_effects) # ['Tutoring', 'SES']
61
+ print(model.random_effects) # [RandomIntercept(group=classroom)]
62
+
63
+ # Override family/link
64
+ model2 = model.with_family("Gaussian", link="identity")
65
+ ```
66
+
67
+ ## API Overview
68
+
69
+ ### Variables
70
+
71
+ | Constructor | Description |
72
+ |---|---|
73
+ | `Unit(name, cardinality, nests_within=None)` | Entity (participant, subject) |
74
+ | `Participant(name, cardinality, nests_within=None)` | Alias for Unit |
75
+ | `Time(name, order=None, cardinality=0)` | Time variable |
76
+ | `continuous(unit, name, number_of_instances=1)` | Continuous measure |
77
+ | `counts(unit, name, number_of_instances=1)` | Count measure |
78
+ | `categories(unit, name, *, cardinality=None, order=None)` | Categorical measure (ordered if `order` given) |
79
+
80
+ ### Relationships
81
+
82
+ | Function | Description |
83
+ |---|---|
84
+ | `causes(cause, effect, *, when=None, then=None)` | Directed causal relationship |
85
+ | `relates(lhs, rhs, *, when=None, then=None)` | Undirected (ambiguous) relationship |
86
+ | `nests(base, group)` | Nesting relationship between units |
87
+
88
+ ### Comparisons (for `when`/`then` annotations)
89
+
90
+ | Function | Description |
91
+ |---|---|
92
+ | `equals(variable, value)` | Variable equals a value |
93
+ | `not_equals(variable, value)` | Variable does not equal a value |
94
+ | `increases(variable)` | Variable increases |
95
+ | `decreases(variable)` | Variable decreases |
96
+
97
+ ### ConceptualModel
98
+
99
+ ```python
100
+ cm = ConceptualModel()
101
+ cm.assume(relationship) # Add assumed relationship (returns self)
102
+ cm.hypothesize(relationship) # Add hypothesized relationship (returns self)
103
+ cm.interacts(*vars, dv=dv) # Add interaction annotation (returns self)
104
+ cm.query(iv=iv, dv=dv) # Infer a StatisticalModel
105
+ ```
106
+
107
+ At least one hypothesized relationship connecting `iv` and `dv` is required to query.
108
+
109
+ ### StatisticalModel
110
+
111
+ | Attribute / Method | Description |
112
+ |---|---|
113
+ | `.formula()` | R-style formula string |
114
+ | `.main_effects` | List of main effect variable names |
115
+ | `.interaction_effects` | List of interaction terms |
116
+ | `.random_effects` | List of `RandomIntercept` / `RandomSlope` |
117
+ | `.family` | Selected family function |
118
+ | `.link` | Selected link function |
119
+ | `.family_candidates` | All candidate family/link pairs |
120
+ | `.summary()` | Machine-readable dict |
121
+ | `.with_family(family, link=None)` | Copy with overridden family/link |
122
+
123
+ ## Running Tests
124
+
125
+ ```bash
126
+ python -m pytest tests/ -v
127
+ ```
128
+
129
+ ## Citation
130
+
131
+ If you use rTisanePy in your research, please cite:
132
+
133
+ ```bibtex
134
+ @software{rtisanepy,
135
+ author = {Eunice Jun},
136
+ title = {rTisanePy: A Python Tool for Authoring Statistical Models from Conceptual Models},
137
+ url = {https://github.com/emjun/rTisanePy},
138
+ year = {2026}
139
+ }
140
+ ```
141
+
142
+ ## References
143
+
144
+ - Eunice Jun, Audrey Seo, Jeffrey Heer, and René Just. 2024. [Tisane: Authoring Statistical Models via Formal Reasoning from Conceptual and Data Relationships.](https://doi.org/10.1145/3613904.3642347) *CHI 2024*.
145
+ - Carlos Cinelli, Andrew Forney, and Judea Pearl. 2022. [A Crash Course in Good and Bad Controls.](https://doi.org/10.1177/00491241221099552) *Sociological Methods & Research*.
@@ -0,0 +1,27 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "rtisanepy"
7
+ version = "0.1.0"
8
+ description = "Authoring generalized linear mixed effects models from conceptual models"
9
+ readme = "README.md"
10
+ license = "MIT"
11
+ requires-python = ">=3.10"
12
+ dependencies = ["networkx>=3.0"]
13
+ authors = [
14
+ { name = "Eunice Jun" },
15
+ ]
16
+ classifiers = [
17
+ "Programming Language :: Python :: 3",
18
+ "License :: OSI Approved :: MIT License",
19
+ "Operating System :: OS Independent",
20
+ ]
21
+
22
+ [project.urls]
23
+ Homepage = "https://github.com/emjun/rTisanePy"
24
+ Repository = "https://github.com/emjun/rTisanePy"
25
+
26
+ [project.optional-dependencies]
27
+ dev = ["pytest>=7.0"]
@@ -0,0 +1,48 @@
1
+ """rTisanePy — Python port of rTisane for authoring statistical models."""
2
+
3
+ from .variables import (
4
+ Unit,
5
+ Participant,
6
+ Time,
7
+ Continuous,
8
+ Counts,
9
+ UnorderedCategories,
10
+ OrderedCategories,
11
+ Measure,
12
+ Exactly,
13
+ AtMost,
14
+ Per,
15
+ continuous,
16
+ counts,
17
+ categories,
18
+ )
19
+ from .relationships import causes, relates, nests
20
+ from .comparisons import equals, not_equals, increases, decreases
21
+ from .conceptual_model import ConceptualModel
22
+ from .statistical_model import StatisticalModel
23
+
24
+ __all__ = [
25
+ "Unit",
26
+ "Participant",
27
+ "Time",
28
+ "Continuous",
29
+ "Counts",
30
+ "UnorderedCategories",
31
+ "OrderedCategories",
32
+ "Measure",
33
+ "Exactly",
34
+ "AtMost",
35
+ "Per",
36
+ "continuous",
37
+ "counts",
38
+ "categories",
39
+ "causes",
40
+ "relates",
41
+ "nests",
42
+ "equals",
43
+ "not_equals",
44
+ "increases",
45
+ "decreases",
46
+ "ConceptualModel",
47
+ "StatisticalModel",
48
+ ]
@@ -0,0 +1,37 @@
1
+ """Comparison predicates for when/then annotations.
2
+
3
+ Ports equals.R, increases.R, decreases.R, notEquals.R.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass
9
+
10
+ from .variables import AbstractVariable
11
+
12
+
13
+ @dataclass(frozen=True)
14
+ class Compares:
15
+ """Holds a variable reference and a condition string."""
16
+ variable: AbstractVariable
17
+ condition: str
18
+
19
+
20
+ def equals(variable: AbstractVariable, value) -> Compares:
21
+ """Condition: variable == value."""
22
+ return Compares(variable=variable, condition=f"=={value}")
23
+
24
+
25
+ def not_equals(variable: AbstractVariable, value) -> Compares:
26
+ """Condition: variable != value."""
27
+ return Compares(variable=variable, condition=f"!={value}")
28
+
29
+
30
+ def increases(variable: AbstractVariable) -> Compares:
31
+ """Condition: variable increases."""
32
+ return Compares(variable=variable, condition="increases")
33
+
34
+
35
+ def decreases(variable: AbstractVariable) -> Compares:
36
+ """Condition: variable decreases."""
37
+ return Compares(variable=variable, condition="decreases")
@@ -0,0 +1,171 @@
1
+ """ConceptualModel: the central object for declaring domain knowledge.
2
+
3
+ Ports ConceptualModel, assume(), hypothesize() from rTisane.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass, field
9
+ from typing import TYPE_CHECKING
10
+
11
+ import networkx as nx
12
+
13
+ from .variables import AbstractVariable, Measure, Unit, Time
14
+ from .relationships import Causes, Relates, Nests, Interacts
15
+ from .graph import build_graph
16
+
17
+ if TYPE_CHECKING:
18
+ from .statistical_model import StatisticalModel
19
+
20
+
21
+ # ---------------------------------------------------------------------------
22
+ # Wrappers that label a relationship as assumed vs hypothesized
23
+ # ---------------------------------------------------------------------------
24
+
25
+ @dataclass
26
+ class Assumption:
27
+ """A relationship the analyst assumes (prior work / strong belief)."""
28
+ relationship: Causes | Relates | Interacts
29
+
30
+
31
+ @dataclass
32
+ class Hypothesis:
33
+ """A relationship the analyst hypothesizes (focus of analysis)."""
34
+ relationship: Causes | Relates | Interacts
35
+
36
+
37
+ # ---------------------------------------------------------------------------
38
+ # ConceptualModel
39
+ # ---------------------------------------------------------------------------
40
+
41
+ class ConceptualModel:
42
+ """Graph-backed collection of variables and labeled relationships."""
43
+
44
+ def __init__(self):
45
+ self.variables: dict[str, AbstractVariable] = {}
46
+ self.relationships: list[Assumption | Hypothesis] = []
47
+ self.graph: nx.DiGraph = nx.DiGraph()
48
+ self.interactions: list[Interacts] = []
49
+
50
+ # -- helpers ----------------------------------------------------------
51
+
52
+ def _track_variables(self, rel: Causes | Relates | Nests | Interacts):
53
+ """Register variables mentioned in a relationship."""
54
+ if isinstance(rel, Causes):
55
+ self._add_var(rel.cause)
56
+ self._add_var(rel.effect)
57
+ elif isinstance(rel, Relates):
58
+ self._add_var(rel.lhs)
59
+ self._add_var(rel.rhs)
60
+ elif isinstance(rel, Nests):
61
+ self._add_var(rel.base)
62
+ self._add_var(rel.group)
63
+ elif isinstance(rel, Interacts):
64
+ for v in rel.variables:
65
+ self._add_var(v)
66
+ if rel.dv is not None:
67
+ self._add_var(rel.dv)
68
+
69
+ def _add_var(self, var: AbstractVariable):
70
+ if var.name not in self.variables:
71
+ self.variables[var.name] = var
72
+
73
+ def _rebuild_graph(self):
74
+ self.graph = build_graph(self.relationships)
75
+
76
+ # -- public API -------------------------------------------------------
77
+
78
+ def assume(self, relationship: Causes | Relates) -> "ConceptualModel":
79
+ """Add an assumed relationship. Returns *self* for chaining."""
80
+ a = Assumption(relationship)
81
+ self.relationships.append(a)
82
+ self._track_variables(relationship)
83
+ self._rebuild_graph()
84
+ return self
85
+
86
+ def hypothesize(self, relationship: Causes | Relates) -> "ConceptualModel":
87
+ """Add a hypothesized relationship. Returns *self* for chaining."""
88
+ h = Hypothesis(relationship)
89
+ self.relationships.append(h)
90
+ self._track_variables(relationship)
91
+ self._rebuild_graph()
92
+ return self
93
+
94
+ def interacts(self, *variables: AbstractVariable, dv: AbstractVariable) -> "ConceptualModel":
95
+ """Annotate an interaction among *variables* w.r.t. *dv*."""
96
+ interaction = Interacts(variables=list(variables), dv=dv)
97
+ self.interactions.append(interaction)
98
+ self._track_variables(interaction)
99
+ return self
100
+
101
+ # -- query ------------------------------------------------------------
102
+
103
+ def query(self, *, iv: AbstractVariable, dv: AbstractVariable, data=None) -> "StatisticalModel":
104
+ """Infer a statistical model for the effect of *iv* on *dv*.
105
+
106
+ Requires at least one hypothesized relationship connecting *iv* and *dv*: either ``causes(iv, dv)`` (in that direction) or a ``relates`` between the two.
107
+ """
108
+ from .inference.confounders import infer_confounders
109
+ from .inference.random_effects import infer_random_effects
110
+ from .inference.family_link import infer_family_link_functions
111
+ from .statistical_model import StatisticalModel
112
+
113
+ # Validate: must have at least one hypothesis
114
+ hypotheses = [r for r in self.relationships if isinstance(r, Hypothesis)]
115
+ if not hypotheses:
116
+ raise ValueError("ConceptualModel must contain at least one hypothesized relationship to query.")
117
+
118
+ # Validate: hypothesis must connect iv and dv
119
+ has_iv_dv = False
120
+ for h in hypotheses:
121
+ rel = h.relationship
122
+ if isinstance(rel, Causes):
123
+ if rel.cause.name == iv.name and rel.effect.name == dv.name:
124
+ has_iv_dv = True
125
+ elif isinstance(rel, Relates):
126
+ names = {rel.lhs.name, rel.rhs.name}
127
+ if iv.name in names and dv.name in names:
128
+ has_iv_dv = True
129
+ if not has_iv_dv:
130
+ raise ValueError(
131
+ f"No hypothesis connects iv={iv.name!r} and dv={dv.name!r}."
132
+ )
133
+
134
+ # 1. Confounders
135
+ confounders = infer_confounders(self, iv, dv)
136
+
137
+ # 2. Random effects
138
+ relevant_interactions = [
139
+ ix for ix in self.interactions if ix.dv is not None and ix.dv.name == dv.name
140
+ ]
141
+ random_effects = infer_random_effects(
142
+ confounders=confounders,
143
+ interactions=relevant_interactions,
144
+ conceptual_model=self,
145
+ iv=iv,
146
+ dv=dv,
147
+ )
148
+
149
+ # 3. Family / link
150
+ family_candidates = infer_family_link_functions(dv)
151
+
152
+ # 4. Interaction effects
153
+ interaction_effects: list[str] = []
154
+ for ix in relevant_interactions:
155
+ var_names = [v.name for v in ix.variables]
156
+ interaction_effects.append(":".join(var_names))
157
+
158
+ # 5. Build StatisticalModel
159
+ first_family = next(iter(family_candidates))
160
+ first_link = family_candidates[first_family][0] if family_candidates[first_family] else None
161
+
162
+ return StatisticalModel(
163
+ dv=dv.name,
164
+ iv=iv.name,
165
+ main_effects=[iv.name] + confounders,
166
+ interaction_effects=interaction_effects,
167
+ random_effects=random_effects,
168
+ family_candidates=family_candidates,
169
+ family=first_family,
170
+ link=first_link,
171
+ )
@@ -0,0 +1,62 @@
1
+ """DAG construction and traversal using networkx.
2
+
3
+ Replaces rTisane's dagitty-based graph operations.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import networkx as nx
9
+
10
+ from .relationships import Causes, Relates
11
+
12
+
13
+ def build_graph(relationships) -> nx.DiGraph:
14
+ """Build a directed graph from a list of Assumption/Hypothesis wrappers.
15
+
16
+ - Causes → single directed edge (cause → effect) tagged ``ambiguous=False``
17
+ - Relates → two directed edges tagged ``ambiguous=True``
18
+ """
19
+ g = nx.DiGraph()
20
+ for entry in relationships:
21
+ rel = entry.relationship
22
+ if isinstance(rel, Causes):
23
+ g.add_edge(rel.cause.name, rel.effect.name, ambiguous=False)
24
+ elif isinstance(rel, Relates):
25
+ g.add_edge(rel.lhs.name, rel.rhs.name, ambiguous=True)
26
+ g.add_edge(rel.rhs.name, rel.lhs.name, ambiguous=True)
27
+ return g
28
+
29
+
30
+ def parents(graph: nx.DiGraph, node: str) -> list[str]:
31
+ """Direct predecessors of *node*."""
32
+ if node not in graph:
33
+ return []
34
+ return list(graph.predecessors(node))
35
+
36
+
37
+ def children(graph: nx.DiGraph, node: str) -> list[str]:
38
+ """Direct successors of *node*."""
39
+ if node not in graph:
40
+ return []
41
+ return list(graph.successors(node))
42
+
43
+
44
+ def ancestors(graph: nx.DiGraph, node: str) -> set[str]:
45
+ """All ancestors of *node* (including *node* itself, matching dagitty)."""
46
+ if node not in graph:
47
+ return {node}
48
+ return nx.ancestors(graph, node) | {node}
49
+
50
+
51
+ def descendants(graph: nx.DiGraph, node: str) -> set[str]:
52
+ """All descendants of *node* (including *node* itself)."""
53
+ if node not in graph:
54
+ return {node}
55
+ return nx.descendants(graph, node) | {node}
56
+
57
+
58
+ def find_paths(graph: nx.DiGraph, source: str, target: str) -> list[list[str]]:
59
+ """All simple paths from *source* to *target*."""
60
+ if source not in graph or target not in graph:
61
+ return []
62
+ return list(nx.all_simple_paths(graph, source, target))
File without changes