syntha-ehr 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- syntha/__init__.py +14 -0
- syntha/cli.py +256 -0
- syntha/conditional.py +176 -0
- syntha/data.py +30 -0
- syntha/export_model.py +50 -0
- syntha/fhir/__init__.py +3 -0
- syntha/fhir/clinical_extras.py +246 -0
- syntha/fhir/codes.py +95 -0
- syntha/fhir/export.py +326 -0
- syntha/fhir/panels.py +104 -0
- syntha/fhir/resources.py +154 -0
- syntha/fhir/rxnorm.py +57 -0
- syntha/generator/__init__.py +4 -0
- syntha/generator/constraints.py +79 -0
- syntha/generator/copula.py +191 -0
- syntha/generator/missingness.py +155 -0
- syntha/generator/mixed_corr.py +229 -0
- syntha/locale/__init__.py +19 -0
- syntha/locale/turkish.py +121 -0
- syntha/longitudinal.py +79 -0
- syntha/longitudinal_labs.py +228 -0
- syntha/models/__init__.py +3 -0
- syntha/models/registry.py +124 -0
- syntha/modules/__init__.py +32 -0
- syntha/modules/asthma_copd.py +45 -0
- syntha/modules/base.py +46 -0
- syntha/modules/depression_anxiety.py +64 -0
- syntha/modules/diabetes.py +53 -0
- syntha/modules/hyperlipidemia.py +32 -0
- syntha/modules/hypertension.py +47 -0
- syntha/modules/ihd.py +28 -0
- syntha/modules/thyroid.py +28 -0
- syntha/pipeline.py +193 -0
- syntha/preprocess.py +36 -0
- syntha/privacy.py +204 -0
- syntha/reference_ranges.py +175 -0
- syntha/schema.py +127 -0
- syntha/server.py +187 -0
- syntha/validate.py +138 -0
- syntha_ehr-0.5.0.dist-info/METADATA +526 -0
- syntha_ehr-0.5.0.dist-info/RECORD +45 -0
- syntha_ehr-0.5.0.dist-info/WHEEL +5 -0
- syntha_ehr-0.5.0.dist-info/entry_points.txt +2 -0
- syntha_ehr-0.5.0.dist-info/licenses/LICENSE +17 -0
- syntha_ehr-0.5.0.dist-info/top_level.txt +1 -0
syntha/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""syntha — synthetic patient record generator.
|
|
2
|
+
|
|
3
|
+
Top-level re-exports for ergonomics. Users can do:
|
|
4
|
+
|
|
5
|
+
from syntha import GaussianCopulaGenerator, PipelineConfig, run
|
|
6
|
+
|
|
7
|
+
instead of having to know the submodule layout.
|
|
8
|
+
"""
|
|
9
|
+
__version__ = "0.5.0"
|
|
10
|
+
|
|
11
|
+
from .generator.copula import GaussianCopulaGenerator
|
|
12
|
+
from .pipeline import PipelineConfig, run
|
|
13
|
+
|
|
14
|
+
__all__ = ["GaussianCopulaGenerator", "PipelineConfig", "run", "__version__"]
|
syntha/cli.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
"""Command-line interface for syntha."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import json
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import click
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pandas as pd
|
|
10
|
+
|
|
11
|
+
from .generator.copula import GaussianCopulaGenerator
|
|
12
|
+
from .models.registry import ModelRegistry
|
|
13
|
+
from .pipeline import PipelineConfig, run
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@click.group()
# The wheel is published as "syntha-ehr" while the import package is "syntha".
# Without an explicit package_name click derives the distribution name from
# this module's top-level package ("syntha"), fails the metadata lookup and
# makes `syntha --version` raise RuntimeError instead of printing the version.
@click.version_option(package_name="syntha-ehr")
def main() -> None:
    """syntha — synthetic patient record generator."""
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@main.command()
@click.option("--input", "input_csv", required=True, type=click.Path(exists=True))
@click.option("--output", "output_dir", required=True, type=click.Path())
@click.option(
    "--n",
    default=1000,
    show_default=True,
    help="Number of synthetic episodes (or total encounters in longitudinal mode)",
)
@click.option("--cohort", default="strict", show_default=True)
@click.option("--seed", default=42, show_default=True)
@click.option("--csv/--no-csv", default=True, show_default=True)
@click.option("--fhir/--no-fhir", default=True, show_default=True)
@click.option(
    "--fhir-format",
    type=click.Choice(["ndjson", "json"]),
    default="ndjson",
    show_default=True,
)
@click.option(
    "--modules/--no-modules",
    default=True,
    show_default=True,
    help="Run Synthea-style clinical modules during FHIR export",
)
@click.option(
    "--longitudinal",
    is_flag=True,
    default=False,
    help="Expand each baseline into multiple encounters over time",
)
@click.option("--encounters-per-patient", default=4.0, show_default=True)
@click.option("--years-of-history", default=3.0, show_default=True)
@click.option(
    "--registry-dir",
    default=None,
    help="Directory for the trained-model registry (default: <output>/models)",
)
@click.option(
    "--lab-history/--no-lab-history",
    default=False,
    show_default=True,
    help="Emit 2-4 prior measurements per lab (v0.5.5 longitudinal labs)",
)
@click.option(
    "--conditional-missingness/--no-conditional-missingness",
    default=True,
    show_default=True,
    help="Apply comorbidity-conditional missingness (v0.5.2)",
)
@click.option(
    "--validation/--no-validation",
    default=True,
    show_default=True,
    help="Compute KS/Wasserstein/correlation report alongside output",
)
def generate(input_csv, output_dir, n, cohort, seed, csv, fhir, fhir_format,
             modules, longitudinal, encounters_per_patient, years_of_history,
             registry_dir, lab_history, conditional_missingness, validation):
    """Train copula, sample, run modules, write CSV + FHIR + model card."""
    # Translate the CLI option names into the PipelineConfig field names;
    # several of them differ (e.g. --seed -> random_seed, --csv -> write_csv).
    settings = {
        "n": n,
        "cohort": cohort,
        "random_seed": seed,
        "write_csv": csv,
        "write_fhir": fhir,
        "fhir_format": fhir_format,
        "run_modules": modules,
        "longitudinal": longitudinal,
        "encounters_per_patient_mean": encounters_per_patient,
        "years_of_history": years_of_history,
        "registry_dir": registry_dir,
        "write_validation": validation,
        "apply_conditional_missingness": conditional_missingness,
        "include_lab_history": lab_history,
    }
    config = PipelineConfig(**settings)

    # The pipeline's return value is echoed to stdout as pretty-printed JSON.
    outcome = run(input_csv, output_dir, config)
    click.echo(json.dumps(outcome, indent=2))
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@main.command()
@click.option("--input", "input_csv", required=True, type=click.Path(exists=True))
@click.option("--registry", "registry_dir", required=True, type=click.Path())
@click.option("--name", required=True)
@click.option("--cohort", default="strict", show_default=True)
@click.option("--seed", default=42, show_default=True)
def fit(input_csv, registry_dir, name, cohort, seed):
    """Fit a copula and store it (with model card) in the registry."""
    from . import data, preprocess

    # Load -> keep modeled columns -> coerce dtypes -> clip to physiologic range.
    raw = data.load_episodes(input_csv)
    typed = preprocess.coerce_types(data.filter_to_modeled(raw))
    modeled = preprocess.clip_to_physiologic(typed)
    feat_df, bcols, ccols = preprocess.split_modeled(modeled)

    # Keep whatever ``fit`` returns: that is the object persisted below.
    untrained = GaussianCopulaGenerator(random_seed=seed)
    gen = untrained.fit(feat_df, bcols, ccols, cohort=cohort)

    registry = ModelRegistry(registry_dir)
    card = registry.save(name, gen, input_csv, modeled, bcols, ccols, cohort)

    summary = {
        "registry": str(registry_dir),
        "name": name,
        "n_train": card.n_train,
        "source_sha256": card.source_sha256,
    }
    click.echo(json.dumps(summary, indent=2))
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@main.command(name="list-models")
@click.option("--registry", "registry_dir", required=True, type=click.Path(exists=True))
def list_models(registry_dir):
    """List models in a registry."""
    # Print the registry's model listing as pretty-printed JSON.
    models = ModelRegistry(registry_dir).list_models()
    click.echo(json.dumps(models, indent=2))
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@main.command(name="export-model")
@click.option("--registry", "registry_dir", required=True, type=click.Path(exists=True))
@click.option("--name", required=True)
@click.option("--output", "output_path", required=True, type=click.Path())
@click.option(
    "--quantiles",
    default=200,
    show_default=True,
    help="Order statistics per continuous marginal",
)
def export_model(registry_dir, name, output_path, quantiles):
    """Export a registered copula to JSON for use by the Tauri desktop app."""
    from .export_model import export_model_to_json

    registry = ModelRegistry(registry_dir)
    gen, card = registry.load(name)
    path = export_model_to_json(gen, output_path, n_quantiles=quantiles)

    # Report where the file went plus its size in whole kibibytes.
    summary = {
        "path": str(path),
        "cohort": card.cohort,
        "n_train": card.n_train,
        "size_kb": path.stat().st_size // 1024,
    }
    click.echo(json.dumps(summary, indent=2))
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@main.command(name="show-card")
@click.option("--registry", "registry_dir", required=True, type=click.Path(exists=True))
@click.option("--name", required=True)
def show_card(registry_dir, name):
    """Show a model card."""
    # ``load`` returns (generator, card); only the card is printed here.
    _, card = ModelRegistry(registry_dir).load(name)
    click.echo(card.to_json())
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@main.command()
@click.option("--registry", "registry_dir", required=True, type=click.Path(exists=True))
@click.option("--name", required=True)
@click.option("--output", "output_csv", required=True, type=click.Path())
@click.option("--n", default=1000, show_default=True)
def sample(registry_dir, name, output_csv, n):
    """Sample n rows from a registered model (raw — no constraints, no FHIR)."""
    gen, _ = ModelRegistry(registry_dir).load(name)
    frame = gen.sample(n)

    # Create the destination directory on demand before writing the CSV.
    Path(output_csv).parent.mkdir(parents=True, exist_ok=True)
    frame.to_csv(output_csv, index=False)

    summary = {"rows": len(frame), "csv": output_csv}
    click.echo(json.dumps(summary, indent=2))
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
@main.command(name="sample-conditional")
@click.option("--registry", "registry_dir", required=True, type=click.Path(exists=True))
@click.option("--name", required=True)
@click.option("--output", "output_csv", required=True, type=click.Path())
@click.option("--n", default=1000, show_default=True, help="Target number of accepted rows")
@click.option(
    "--condition",
    required=True,
    help=(
        "pandas-style filter expression in the model's column names. "
        'Example: --condition "age > 60 & DM_Tum == 1 & bp_systolic >= 140"'
    ),
)
@click.option(
    "--oversample",
    default=5.0,
    show_default=True,
    help="Initial oversample factor for rejection sampling",
)
@click.option("--max-rounds", default=10, show_default=True)
def sample_conditional_cmd(registry_dir, name, output_csv, n, condition,
                           oversample, max_rounds):
    """Conditional sampling via rejection — return only rows matching --condition."""
    from .conditional import sample_conditional

    gen, _ = ModelRegistry(registry_dir).load(name)
    result = sample_conditional(
        gen,
        n=n,
        condition=condition,
        oversample_factor=oversample,
        max_rounds=max_rounds,
    )

    # Create the destination directory on demand before writing the CSV.
    Path(output_csv).parent.mkdir(parents=True, exist_ok=True)
    result.rows.to_csv(output_csv, index=False)

    # n_generated may be smaller than n_requested; both are reported.
    summary = {
        "n_requested": result.n_requested,
        "n_generated": result.n_generated,
        "rounds": result.rounds,
        "rejection_rate": round(result.rejection_rate, 3),
        "csv": output_csv,
    }
    click.echo(json.dumps(summary, indent=2))
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
@main.command()
@click.option("--input", "input_csv", required=True, type=click.Path(exists=True))
@click.option("--output", "output_dir", required=True, type=click.Path())
@click.option(
    "--format",
    "fmt",
    type=click.Choice(["ndjson", "json"]),
    default="ndjson",
    show_default=True,
)
@click.option("--modules/--no-modules", default=True, show_default=True)
def fhir(input_csv, output_dir, fmt, modules):
    """Convert an existing synthetic CSV to FHIR R4 bundles."""
    from .fhir.export import write_fhir_bundles

    # The CSV is read as-is; no type coercion or filtering happens here.
    frame = pd.read_csv(input_csv)
    bundle_path = write_fhir_bundles(frame, output_dir, fmt=fmt, run_modules=modules)

    summary = {"rows": len(frame), "fhir": str(bundle_path)}
    click.echo(json.dumps(summary, indent=2))
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
@main.command()
@click.option(
    "--bundles",
    "bundles_ndjson",
    required=True,
    type=click.Path(exists=True),
    help="Path to bundles.ndjson",
)
@click.option("--host", default="127.0.0.1", show_default=True)
@click.option("--port", default=8080, show_default=True)
def serve(bundles_ndjson, host, port):
    """Boot a minimal read-only FHIR R4 server backed by a bundles NDJSON file."""
    # The server module is imported only when this command actually runs.
    from . import server

    server.serve_forever(bundles_ndjson, host=host, port=port)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
@main.command(name="validate")
@click.option("--source", "source_csv", required=True, type=click.Path(exists=True))
@click.option("--synthetic", "synthetic_csv", required=True, type=click.Path(exists=True))
@click.option("--output", "report_path", required=True, type=click.Path())
def validate_cmd(source_csv, synthetic_csv, report_path):
    """Compute KS / Wasserstein / correlation-diff between source and synthetic.

    (CLI command name is `validate`; the Python function is `validate_cmd`
    to avoid shadowing the `validate.validate` module function — per the
    v0.5 architecture review.)
    """
    from . import data, preprocess
    from .validate import save_report
    from .validate import validate as _v

    # The source goes through the modeled-column filter and type coercion;
    # the synthetic CSV is read verbatim.
    raw_source = data.load_episodes(source_csv)
    src = preprocess.coerce_types(data.filter_to_modeled(raw_source))
    syn = pd.read_csv(synthetic_csv)

    # Only the column groupings are needed from the split, not the frame.
    _, bcols, ccols = preprocess.split_modeled(src)

    report = _v(src, syn, ccols, bcols)
    save_report(report, report_path)
    click.echo(json.dumps(report.summary(), indent=2))
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
@main.command()
@click.option("--source", "source_csv", required=True, type=click.Path(exists=True))
@click.option("--synthetic", "synthetic_csv", required=True, type=click.Path(exists=True))
@click.option("--output", "report_path", required=True, type=click.Path())
@click.option("--split", default=0.8, show_default=True,
              help="Fraction of source used as train (rest is held out for the attack)")
def audit(source_csv, synthetic_csv, report_path, split):
    """Run a privacy audit (membership + attribute inference attacks) against
    a synthetic CSV and write a JSON report. CI fails on MIA AUC > 0.60."""
    from . import data, preprocess
    from .privacy import DEFAULT_MIA_THRESHOLD, run_privacy_audit

    # A split of 0 or 1 (or anything outside that range) leaves the train or
    # the hold-out partition empty, so the attack would have nothing to
    # compare. Reject it up front with a usage error instead.
    if not 0.0 < split < 1.0:
        raise click.BadParameter(
            "must be strictly between 0 and 1 so that both the train and "
            "hold-out partitions are non-empty",
            param_hint="--split",
        )

    src = preprocess.coerce_types(data.filter_to_modeled(data.load_episodes(source_csv)))
    syn = pd.read_csv(synthetic_csv)

    # Fixed seed: the train/hold-out partition is the same on every run.
    rng = np.random.default_rng(42)
    idx = rng.permutation(len(src))
    n_train = int(split * len(src))
    real_train = src.iloc[idx[:n_train]].copy()
    real_holdout = src.iloc[idx[n_train:]].copy()

    # Attack features: columns present in both frames, minus the excluded
    # identifier, date and gender columns.
    excluded = ("RF_EPISODE2", "HASTA_ID", "episode_date", "gender")
    feature_cols = [c for c in src.columns
                    if c in syn.columns and c not in excluded]
    sensitive_targets = [c for c in ["Hipertansiyon", "DM_Tum", "Hiperlipidemi"]
                         if c in feature_cols]

    report = run_privacy_audit(
        real_train, real_holdout, syn,
        feature_cols=feature_cols,
        sensitive_targets=sensitive_targets,
    )
    summary = report.summary()

    # The report is written before the threshold check, so it exists on disk
    # even when the command exits with an error.
    Path(report_path).parent.mkdir(parents=True, exist_ok=True)
    Path(report_path).write_text(json.dumps(summary, indent=2))
    click.echo(json.dumps(summary, indent=2))

    if summary["membership_inference_auc"] > DEFAULT_MIA_THRESHOLD:
        raise click.ClickException(
            f"membership inference AUC {summary['membership_inference_auc']:.3f} > "
            f"{DEFAULT_MIA_THRESHOLD} — possible memorization"
        )
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
if __name__ == "__main__":
|
|
256
|
+
main()
|
syntha/conditional.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
"""Conditional sampling — generate synthetic episodes matching a filter.
|
|
2
|
+
|
|
3
|
+
Usage from the CLI:
|
|
4
|
+
|
|
5
|
+
syntha sample-conditional \\
|
|
6
|
+
--registry <registry-dir> --name <model-name> \\
|
|
7
|
+
--condition 'age > 60 & DM_Tum == 1' \\
|
|
8
|
+
--n 1000 --output out/diabetic_seniors.csv
|
|
9
|
+
|
|
10
|
+
Implementation: rejection sampling around the existing copula. Sample
|
|
11
|
+
`oversample_factor × n` rows from the fitted copula, evaluate the
|
|
12
|
+
pandas-style filter expression, keep the matches, repeat up to
|
|
13
|
+
``max_rounds`` until `n` matches are accumulated. Inefficient for very
|
|
14
|
+
rare conditions (e.g. P(condition) ≈ 0.01) but exact — no need to refit.
|
|
15
|
+
|
|
16
|
+
For rarer conditions, we could implement true conditional Gaussian
|
|
17
|
+
sampling (condition on a subset of variables, sample from the remaining
|
|
18
|
+
marginal of the multivariate normal), but that's a v0.7 task.
|
|
19
|
+
"""
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
import ast
|
|
23
|
+
from dataclasses import dataclass
|
|
24
|
+
|
|
25
|
+
import pandas as pd
|
|
26
|
+
|
|
27
|
+
from .generator.constraints import PhysiologicConstraints
|
|
28
|
+
from .generator.copula import GaussianCopulaGenerator
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class ConditionalSamplingResult:
    """Bundle returned by ``sample_conditional``: the rows plus run statistics."""

    # Accepted rows, capped at ``n_requested``.
    rows: pd.DataFrame
    # Number of rows the caller asked for.
    n_requested: int
    # Number of rows actually present in ``rows``.
    n_generated: int
    # Number of sampling rounds that were executed.
    rounds: int
    # Rejection rate reported by the sampler (1.0 when nothing was drawn).
    rejection_rate: float
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# Allowed AST node types in a condition expression. The walker rejects
|
|
41
|
+
# anything that could reach out (Attribute, Call, Subscript, Import) or
|
|
42
|
+
# bind new names (Lambda, comprehensions). Everything left over is just
|
|
43
|
+
# comparisons + boolean ops + arithmetic + literals + column references.
|
|
44
|
+
_ALLOWED_AST_NODES = {
|
|
45
|
+
ast.Expression, ast.Compare, ast.BoolOp, ast.UnaryOp, ast.BinOp,
|
|
46
|
+
ast.Name, ast.Constant, ast.Load,
|
|
47
|
+
ast.And, ast.Or, ast.Not, ast.Invert,
|
|
48
|
+
ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE,
|
|
49
|
+
ast.In, ast.NotIn, ast.Is, ast.IsNot,
|
|
50
|
+
ast.Add, ast.Sub, ast.Mult, ast.Div, ast.Mod, ast.FloorDiv, ast.Pow,
|
|
51
|
+
ast.BitAnd, ast.BitOr, ast.BitXor, ast.LShift, ast.RShift,
|
|
52
|
+
ast.USub, ast.UAdd,
|
|
53
|
+
ast.Tuple, ast.List,
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _safe_eval_filter(df: pd.DataFrame, expression: str) -> pd.Series:
|
|
58
|
+
"""Apply a pandas-query-style filter on `df`, AST-validated.
|
|
59
|
+
|
|
60
|
+
We parse the expression with ``ast.parse(mode="eval")`` and walk the
|
|
61
|
+
tree rejecting any node not on the allowlist. This catches attribute
|
|
62
|
+
access (``age.__class__``), function calls (``__import__('os')``),
|
|
63
|
+
subscripts (``a[0]``), lambdas, comprehensions, named expressions
|
|
64
|
+
(``:=``), starred expressions, and anything else that could escape
|
|
65
|
+
the DataFrame namespace. Names are further restricted to actual
|
|
66
|
+
DataFrame columns.
|
|
67
|
+
"""
|
|
68
|
+
try:
|
|
69
|
+
tree = ast.parse(expression, mode="eval")
|
|
70
|
+
except SyntaxError as e:
|
|
71
|
+
raise ValueError(f"invalid condition syntax: {e}") from e
|
|
72
|
+
|
|
73
|
+
for node in ast.walk(tree):
|
|
74
|
+
if type(node) not in _ALLOWED_AST_NODES:
|
|
75
|
+
raise ValueError(
|
|
76
|
+
f"condition contains disallowed construct "
|
|
77
|
+
f"{type(node).__name__!r}; only comparisons, boolean ops, "
|
|
78
|
+
f"arithmetic, literals, and column-name references are allowed. "
|
|
79
|
+
f"Got: {expression!r}"
|
|
80
|
+
)
|
|
81
|
+
if isinstance(node, ast.Name) and node.id not in df.columns:
|
|
82
|
+
raise ValueError(
|
|
83
|
+
f"condition references unknown name {node.id!r}; "
|
|
84
|
+
f"available columns: {list(df.columns)[:8]}…"
|
|
85
|
+
)
|
|
86
|
+
mask = df.eval(expression, engine="python")
|
|
87
|
+
if not isinstance(mask, pd.Series):
|
|
88
|
+
raise ValueError(
|
|
89
|
+
f"condition {expression!r} did not evaluate to a Series"
|
|
90
|
+
)
|
|
91
|
+
# Accept both numpy bool and pandas nullable boolean. The copula casts
|
|
92
|
+
# some columns to Int64 (nullable) which propagates to boolean masks.
|
|
93
|
+
if mask.dtype.name not in {"bool", "boolean"}:
|
|
94
|
+
raise ValueError(
|
|
95
|
+
f"condition {expression!r} did not evaluate to a boolean mask "
|
|
96
|
+
f"(got dtype {mask.dtype.name})"
|
|
97
|
+
)
|
|
98
|
+
# Convert nullable boolean to plain numpy bool, treating NA as False so
|
|
99
|
+
# downstream df[mask] indexing works without choking.
|
|
100
|
+
return mask.fillna(False).astype(bool)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def sample_conditional(
    generator: GaussianCopulaGenerator,
    n: int,
    condition: str,
    *,
    oversample_factor: float = 5.0,
    max_rounds: int = 10,
    constraints: PhysiologicConstraints | None = None,
) -> ConditionalSamplingResult:
    """Rejection-sample n rows that satisfy ``condition``.

    Parameters
    ----------
    generator:
        A fitted GaussianCopulaGenerator.
    n:
        Target number of accepted rows. With ``n <= 0`` nothing is drawn
        and an empty result is returned.
    condition:
        A pandas df.eval-style expression in the model's column names.
        Example: ``"age > 60 & DM_Tum == 1 & bp_systolic >= 140"``
    oversample_factor:
        Initial multiplier — how many candidates to draw per accepted
        target. 5× is a reasonable default for moderately selective
        filters (P ≥ 0.05). Raise for rarer conditions.
    max_rounds:
        Safety cap. After this many rounds without filling the quota,
        the function returns what it has and reports the rejection rate.
    constraints:
        Optional PhysiologicConstraints to apply before the user filter.

    Returns
    -------
    ConditionalSamplingResult with the accepted rows (at most ``n``), the
    round count, and the empirical rejection rate: the share of all drawn
    candidates that did not pass the constraints plus the condition.
    Accepted rows dropped only because the quota was overshot are NOT
    counted as rejections, so the rate reflects how rare the condition is.

    Raises
    ------
    ValueError
        If ``condition`` is rejected by the expression validator.
    """
    collected: list[pd.DataFrame] = []
    rounds = 0
    total_drawn = 0
    accepted = 0  # running count of matching rows, before truncation to n

    while accepted < n and rounds < max_rounds:
        rounds += 1
        deficit = n - accepted
        batch_size = max(1, int(deficit * oversample_factor))
        drawn = generator.sample(batch_size)
        total_drawn += len(drawn)

        if constraints is not None:
            kept, _ = constraints.apply(drawn)
        else:
            kept = drawn

        mask = _safe_eval_filter(kept, condition)
        matches = kept[mask].reset_index(drop=True)
        collected.append(matches)
        accepted += len(matches)

        if accepted < n and total_drawn > 0:
            # Size the next batch for roughly 1.5x the remaining deficit at
            # the acceptance rate seen so far; the factor never shrinks.
            # The total_drawn guard avoids a ZeroDivisionError should the
            # generator ever return no rows.
            empirical_rate = accepted / total_drawn
            if empirical_rate > 0:
                oversample_factor = max(oversample_factor, 1.5 / empirical_rate)

    if collected:
        final = pd.concat(collected, ignore_index=True).head(n).reset_index(drop=True)
    else:
        # No round ran (n <= 0 or max_rounds <= 0); pd.concat([]) would raise.
        final = pd.DataFrame()

    # Use the untruncated acceptance count: after .head(n) the surplus
    # matches would otherwise be misreported as rejections.
    rejection_rate = 1.0 - (accepted / total_drawn) if total_drawn else 1.0
    return ConditionalSamplingResult(
        rows=final,
        n_requested=n,
        n_generated=len(final),
        rounds=rounds,
        rejection_rate=rejection_rate,
    )
|
syntha/data.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""CSV loader for pristine-episode source files."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
8
|
+
from . import schema
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def load_episodes(path: str | Path) -> pd.DataFrame:
|
|
12
|
+
"""Read a pristine-episodes CSV. Handles BOM-prefixed first column."""
|
|
13
|
+
df = pd.read_csv(path, encoding="utf-8-sig", low_memory=False)
|
|
14
|
+
df.columns = [c.lstrip("").strip() for c in df.columns]
|
|
15
|
+
return df
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def filter_to_modeled(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` restricted to the schema's modeled columns.

    Columns are emitted in the order given by ``schema.all_modeled_columns()``;
    modeled columns absent from ``df`` are skipped and all other columns
    are dropped.
    """
    keep = []
    for column in schema.all_modeled_columns():
        if column in df.columns:
            keep.append(column)
    subset = df[keep]
    return subset.copy()
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def date_range(df: pd.DataFrame) -> tuple[pd.Timestamp, pd.Timestamp]:
    """Return the (earliest, latest) parseable ``episode_date`` in ``df``.

    Falls back to 2015-01-01 .. 2024-12-31 when the column is missing or
    holds no parseable date; unparseable entries are ignored.
    """
    fallback = (pd.Timestamp("2015-01-01"), pd.Timestamp("2024-12-31"))
    if "episode_date" in df.columns:
        # errors="coerce" turns bad values into NaT, which dropna() removes.
        dates = pd.to_datetime(df["episode_date"], errors="coerce").dropna()
        if not dates.empty:
            return dates.min(), dates.max()
    return fallback
|
syntha/export_model.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""Export a trained copula model to a compact JSON the Tauri app consumes.
|
|
2
|
+
|
|
3
|
+
We downsample each continuous marginal to ``n_quantiles`` order statistics to
|
|
4
|
+
keep the JSON small (≈100 KB for the tolerant cohort at 200 quantiles), then
|
|
5
|
+
serialize the correlation matrix and binary marginals verbatim.
|
|
6
|
+
"""
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
|
|
14
|
+
from .generator.copula import GaussianCopulaGenerator
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _quantile_grid(values: np.ndarray, n: int) -> list[float]:
|
|
18
|
+
if len(values) == 0:
|
|
19
|
+
return [0.0]
|
|
20
|
+
if len(values) <= n:
|
|
21
|
+
return [float(v) for v in np.sort(values)]
|
|
22
|
+
qs = np.linspace(0.0, 1.0, n)
|
|
23
|
+
return [float(v) for v in np.quantile(values, qs)]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def export_model_to_json(
    gen: GaussianCopulaGenerator,
    path: str | Path,
    n_quantiles: int = 200,
) -> Path:
    """Serialise the generator's fitted model to a JSON file.

    Continuous marginals are reduced via ``_quantile_grid`` to at most
    ``n_quantiles`` values each; the correlation matrix, the binary
    marginals and the missingness probabilities are written in full.
    Parent directories of ``path`` are created when missing.

    Returns
    -------
    The output location as a ``Path``.

    Raises
    ------
    RuntimeError
        If ``gen.model`` is ``None``.
    """
    model = gen.model
    if model is None:
        raise RuntimeError("generator has no fitted model")

    grids = {}
    for column, samples in model.continuous_quantiles.items():
        grids[column] = _quantile_grid(samples, n_quantiles)

    # Key order is kept stable so the emitted document layout does not change.
    document = {
        "format": "syntha-copula-v1",
        "cohort": model.cohort,
        "columns": list(model.columns),
        "binary_cols": sorted(model.binary_cols),
        "p_missing": {name: float(p) for name, p in model.p_missing.items()},
        "binary_p": {name: float(p) for name, p in model.binary_p.items()},
        "continuous_quantiles": grids,
        "correlation": model.correlation.tolist(),
        "n_train": model.n_train,
    }

    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    # ensure_ascii=False keeps non-ASCII column names readable in the file.
    serialised = json.dumps(document, ensure_ascii=False)
    destination.write_text(serialised, encoding="utf-8")
    return destination
|
syntha/fhir/__init__.py
ADDED