cehrgpt 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +99 -88
- cehrgpt/data/sample_packing_sampler.py +151 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/models/config.py +10 -0
- cehrgpt/models/hf_cehrgpt.py +243 -73
- cehrgpt/models/tokenization_hf_cehrgpt.py +4 -0
- cehrgpt/runners/data_utils.py +243 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +152 -279
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +229 -105
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +42 -0
- cehrgpt/runners/hyperparameter_search_util.py +4 -1
- cehrgpt/runners/sample_packing_trainer.py +168 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +7 -5
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +28 -26
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py
@@ -0,0 +1,152 @@
+import argparse
+import json
+import pickle
+from pathlib import Path
+from typing import Any, Dict, Union
+
+import numpy as np
+import pandas as pd
+import polars as pl
+from sklearn.linear_model import LogisticRegressionCV
+from sklearn.metrics import auc, precision_recall_curve, roc_auc_score
+from sklearn.preprocessing import OneHotEncoder, StandardScaler
+
+
+def prepare_dataset(
+    df: pd.DataFrame, feature_processor: Dict[str, Union[StandardScaler, OneHotEncoder]]
+) -> Dict[str, Any]:
+    age_scaler = feature_processor["age_scaler"]
+    gender_encoder = feature_processor["gender_encoder"]
+    race_encoder = feature_processor["race_encoder"]
+    age_scaler.transform(df[["age_at_index"]].to_numpy())
+
+    one_hot_gender = gender_encoder.transform(
+        np.expand_dims(df.gender_concept_id.to_numpy(), axis=1)
+    )
+    one_hot_race = race_encoder.transform(
+        np.expand_dims(df.race_concept_id.to_numpy(), axis=1)
+    )
+
+    features = np.stack(df["features"].apply(lambda x: np.array(x).flatten()))
+    # features = np.hstack(
+    #     [scaled_age, one_hot_gender.toarray(), one_hot_race.toarray(), features]
+    # )
+    return {
+        "subject_id": df["subject_id"].to_numpy(),
+        "prediction_time": df["prediction_time"].tolist(),
+        "features": features,
+        "boolean_value": df["boolean_value"].to_numpy(),
+    }
+
+
+def main(args):
+    features_data_dir = Path(args.features_data_dir)
+    output_dir = Path(args.output_dir)
+    feature_processor_path = output_dir / "feature_processor.pickle"
+    logistic_dir = output_dir / "logistic"
+    logistic_dir.mkdir(exist_ok=True, parents=True)
+    logistic_test_result_file = logistic_dir / "metrics.json"
+    if logistic_test_result_file.exists():
+        print("The models have been trained, and skip ...")
+        exit(0)
+
+    feature_train = pd.read_parquet(
+        features_data_dir / "features_with_label" / "train_features"
+    )
+    feature_test = pd.read_parquet(
+        features_data_dir / "features_with_label" / "test_features"
+    )
+
+    feature_train = feature_train.sort_values(["subject_id", "prediction_time"]).sample(
+        frac=1.0,
+        random_state=42,
+        replace=False,
+    )
+
+    if feature_processor_path.exists():
+        with open(feature_processor_path, "rb") as f:
+            feature_processor = pickle.load(f)
+    else:
+        age_scaler, gender_encoder, race_encoder = (
+            StandardScaler(),
+            OneHotEncoder(handle_unknown="ignore"),
+            OneHotEncoder(handle_unknown="ignore"),
+        )
+        age_scaler = age_scaler.fit(feature_train[["age_at_index"]].to_numpy())
+        gender_encoder = gender_encoder.fit(
+            feature_train[["gender_concept_id"]].to_numpy()
+        )
+        race_encoder = race_encoder.fit(feature_train[["race_concept_id"]].to_numpy())
+        feature_processor = {
+            "age_scaler": age_scaler,
+            "gender_encoder": gender_encoder,
+            "race_encoder": race_encoder,
+        }
+        with open(feature_processor_path, "wb") as f:
+            pickle.dump(feature_processor, f)
+
+    if logistic_test_result_file.exists():
+        print(
+            f"The results for logistic regression already exist at {logistic_test_result_file}"
+        )
+    else:
+        logistic_model_file = logistic_dir / "model.pickle"
+        if logistic_model_file.exists():
+            print(
+                f"The logistic regression model already exist, loading it from {logistic_model_file}"
+            )
+            with open(logistic_model_file, "rb") as f:
+                model = pickle.load(f)
+        else:
+            train_dataset = prepare_dataset(feature_train, feature_processor)
+            # Train logistic regression
+            model = LogisticRegressionCV(scoring="roc_auc", random_state=42)
+            model.fit(train_dataset["features"], train_dataset["boolean_value"])
+            with open(logistic_model_file, "wb") as f:
+                pickle.dump(model, f)
+
+        test_dataset = prepare_dataset(feature_test, feature_processor)
+        y_pred = model.predict_proba(test_dataset["features"])[:, 1]
+        logistic_predictions = pl.DataFrame(
+            {
+                "subject_id": test_dataset["subject_id"].tolist(),
+                "prediction_time": test_dataset["prediction_time"],
+                "predicted_boolean_probability": y_pred.tolist(),
+                "predicted_boolean_value": None,
+                "boolean_value": test_dataset["boolean_value"].astype(bool).tolist(),
+            }
+        )
+        logistic_predictions = logistic_predictions.with_columns(
+            pl.col("predicted_boolean_value").cast(pl.Boolean())
+        )
+        logistic_test_predictions = logistic_dir / "test_predictions"
+        logistic_test_predictions.mkdir(exist_ok=True, parents=True)
+        logistic_predictions.write_parquet(
+            logistic_test_predictions / "predictions.parquet"
+        )
+
+        roc_auc = roc_auc_score(test_dataset["boolean_value"], y_pred)
+        precision, recall, _ = precision_recall_curve(
+            test_dataset["boolean_value"], y_pred
+        )
+        pr_auc = auc(recall, precision)
+
+        metrics = {"roc_auc": roc_auc, "pr_auc": pr_auc}
+        print("Logistic:", features_data_dir.name, metrics)
+        with open(logistic_test_result_file, "w") as f:
+            json.dump(metrics, f, indent=4)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Train logistic regression model with cehrgpt features"
+    )
+    parser.add_argument(
+        "--features_data_dir",
+        required=True,
+        help="Directory containing training and test feature files",
+    )
+    parser.add_argument(
+        "--output_dir", required=True, help="Directory to save the output results"
+    )
+    main(parser.parse_args())
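
For a quick sanity check of the new prepare_dataset helper, the sketch below builds a toy feature frame and a fitted feature_processor. It assumes cehrgpt 0.1.0, pandas, and scikit-learn are installed; the toy concept IDs and feature vectors are illustrative only:

import numpy as np
import pandas as pd
from sklearn.preprocessing import OneHotEncoder, StandardScaler

from cehrgpt.tools.linear_prob.train_with_cehrgpt_features import prepare_dataset

# Two toy rows shaped like the parquet files the script reads from
# features_with_label/train_features.
df = pd.DataFrame(
    {
        "subject_id": [1, 2],
        "prediction_time": pd.to_datetime(["2020-01-01", "2020-06-01"]).tolist(),
        "age_at_index": [45, 67],
        "gender_concept_id": [8507, 8532],
        "race_concept_id": [8527, 8516],
        "features": [np.ones(8).tolist(), np.zeros(8).tolist()],
        "boolean_value": [True, False],
    }
)
feature_processor = {
    "age_scaler": StandardScaler().fit(df[["age_at_index"]].to_numpy()),
    "gender_encoder": OneHotEncoder(handle_unknown="ignore").fit(
        df[["gender_concept_id"]].to_numpy()
    ),
    "race_encoder": OneHotEncoder(handle_unknown="ignore").fit(
        df[["race_concept_id"]].to_numpy()
    ),
}
dataset = prepare_dataset(df, feature_processor)
# Only the flattened CEHR-GPT features come back: the demographic columns are
# transformed but never concatenated (the np.hstack call is commented out).
print(dataset["features"].shape)  # (2, 8)

Note that as released, the demographics therefore do not enter the logistic model; the script is driven from the command line via --features_data_dir and --output_dir.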

{cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/METADATA
@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: cehrgpt
-Version: 0.0.2
+Version: 0.1.0
 Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
 Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
 License: MIT License
@@ -12,13 +12,14 @@ Classifier: Programming Language :: Python :: 3
 Requires-Python: >=3.10.0
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: cehrbert==1.
+Requires-Dist: cehrbert==1.4.1
+Requires-Dist: cehrbert_data==0.0.7
 Requires-Dist: openai==1.54.3
 Requires-Dist: optuna==4.0.0
-Requires-Dist: transformers==4.
+Requires-Dist: transformers==4.44.0
 Requires-Dist: tokenizers==0.19.0
 Requires-Dist: peft==0.10.0
-Requires-Dist:
+Requires-Dist: lightgbm
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
 Requires-Dist: pytest; extra == "dev"
@@ -29,6 +30,7 @@ Requires-Dist: hypothesis; extra == "dev"
 Requires-Dist: black; extra == "dev"
 Provides-Extra: flash-attn
 Requires-Dist: flash_attn; extra == "flash-attn"
+Dynamic: license-file

 # CEHRGPT

|
@@ -11,24 +11,22 @@ cehrgpt/analysis/privacy/nearest_neighbor_inference.py,sha256=qoJgWW7VsUMzjMGpTa
|
|
11
11
|
cehrgpt/analysis/privacy/reid_inference.py,sha256=Pypd3QJXQNY8VljpnIEa5zeAbTZHMjQOazaL-9VsBGw,13955
|
12
12
|
cehrgpt/analysis/privacy/utils.py,sha256=CRA4H9mPLBjMQGKzZ_x_3ro3tMap-NjsMDVqSOjHSVQ,8226
|
13
13
|
cehrgpt/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
|
-
cehrgpt/data/hf_cehrgpt_dataset.py,sha256=
|
15
|
-
cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=
|
16
|
-
cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=
|
17
|
-
cehrgpt/data/
|
18
|
-
cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py,sha256=uCLF5VEsyZAG1aNwqEM6Jy5Lx7bI5ALku52Z6Anine0,2574
|
14
|
+
cehrgpt/data/hf_cehrgpt_dataset.py,sha256=t9vpN05e--CiKgIlxLP0aLacISnvWWDPXtuFuJi3ksE,3736
|
15
|
+
cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=DOvIF4Wzkd8-IO3zpIRZkX1j0IdvefaiSnrDn1YivCk,27912
|
16
|
+
cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=eI8CTk6yJ4DlNJWrNAkEmhWh353NeLqg5rwPpKqKT-U,17308
|
17
|
+
cehrgpt/data/sample_packing_sampler.py,sha256=0uKTbvtXpfS81esy_3epJ88eohyJPK46bfmxhle1fws,5419
|
19
18
|
cehrgpt/generation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
20
19
|
cehrgpt/generation/chatgpt_generation.py,sha256=SrnLwHLdNtnAOEg36gNjqfoT9yd12iyPgpZffL2AFJo,4428
|
21
|
-
cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256
|
22
|
-
cehrgpt/generation/generate_paired_cehrgpt_sequence.py,sha256=fLu3SHhRe_ZQfS09ebOktq2dekStgYfxmbrRawZQAO4,8280
|
20
|
+
cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256=uSEh8aMmPD61nGewIaPSkIqm-2AxDjCBiu4cBfxHxU4,11503
|
23
21
|
cehrgpt/generation/omop_converter_batch.py,sha256=-c0AlDVy5pJ5Afhr8ERiCHhoRrEk8ozJi3g0yFdWaMI,25348
|
24
22
|
cehrgpt/generation/omop_entity.py,sha256=Q5Sr0AlyuPAm1FRPfnJO13q-u1fqRgYVHXruZ9g4xNE,19400
|
25
23
|
cehrgpt/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
26
|
-
cehrgpt/models/config.py,sha256=
|
27
|
-
cehrgpt/models/hf_cehrgpt.py,sha256=
|
24
|
+
cehrgpt/models/config.py,sha256=Y3CiXZWniLP9_RlpU80Oe9gjn5leLmTYnNe_fWqfJLQ,10158
|
25
|
+
cehrgpt/models/hf_cehrgpt.py,sha256=3EQIOfa--oz4f8bM8KzbDi98G3XrUEQkox1vmBN001M,83321
|
28
26
|
cehrgpt/models/hf_modeling_outputs.py,sha256=LaWa1jI6BRIKMEjWOy1QUeOfTur5y_p2c-JyuGVTdtw,10301
|
29
27
|
cehrgpt/models/pretrained_embeddings.py,sha256=vLLVs17TLpXRqCVEWQxGGwPHkUJUO7laNTeBuyBK_yk,3238
|
30
28
|
cehrgpt/models/special_tokens.py,sha256=-a7HPJBbdIH0qQ6B3CcRKqvpG6FZlm4nbVPTswGSJ4U,485
|
31
|
-
cehrgpt/models/tokenization_hf_cehrgpt.py,sha256=
|
29
|
+
cehrgpt/models/tokenization_hf_cehrgpt.py,sha256=jjCRqS29IzMnKp40jNOs80UKh2z9lK5S6M02GSB-4mk,42351
|
32
30
|
cehrgpt/omop/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
33
31
|
cehrgpt/omop/condition_era.py,sha256=hPZALz2XaWnro_1bwIYNkI48foOJjueyg3CZ1BliCno,626
|
34
32
|
cehrgpt/omop/observation_period.py,sha256=TRMgv5Ya2RaS2im7oQ6BLC_5JL9EJYNYR62ApxIuHvg,1211
|
@@ -38,18 +36,19 @@ cehrgpt/omop/sample_omop_tables.py,sha256=2JZ8BNSvssceinwFanvuCRh-YlKrKn25U9w1pL
|
|
38
36
|
cehrgpt/omop/queries/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
39
37
|
cehrgpt/omop/queries/condition_era.py,sha256=LFB6vBAvshHJxtYIRkl7cfrF0kf7ay0piBKpmHBwrpE,2578
|
40
38
|
cehrgpt/omop/queries/observation_period.py,sha256=fpzr5DMNw-QLoSwp2Iatfch88E3hyhZ75usiIdG3A0U,6410
|
41
|
-
cehrgpt/rl_finetune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
42
|
-
cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py,sha256=VQHf5vy5i8K1imcqYakhitfAW-d2mnaEzkSoAYSW5kg,26062
|
43
|
-
cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py,sha256=nYWYPCaWNjDGEwlo6UHOK1rvOZUx1vuJ8kYuAszI8Zg,17925
|
44
|
-
cehrgpt/rl_finetune/ppo_finetune.py,sha256=tSy-C0Kzgj5ffclBIDj-RTj78ZfrLmTESxVxd0n9yuE,13971
|
45
|
-
cehrgpt/rl_finetune/ppo_finetune_v2.py,sha256=7dChwKpq4zKmpkcxP4hryqBoIkcwmTJ44_BF8R2RghQ,13285
|
46
39
|
cehrgpt/runners/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
47
|
-
cehrgpt/runners/
|
48
|
-
cehrgpt/runners/
|
49
|
-
cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=
|
50
|
-
cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=
|
51
|
-
cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=
|
52
|
-
cehrgpt/runners/hyperparameter_search_util.py,sha256=
|
40
|
+
cehrgpt/runners/data_utils.py,sha256=ScZZnfXwgXKaMvKgFzdb4vtQ7F_lw97O5uNsFbfsyP4,10620
|
41
|
+
cehrgpt/runners/gpt_runner_util.py,sha256=YJQSRW9Mo4TjXSOUOTf6BUFcs1MGFiXU5T4ztKZcYhU,3485
|
42
|
+
cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=bkPl30Y9CSXBlmMkH-3cA3-aW8XJK36Q-adx___WjkE,26921
|
43
|
+
cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=ViVa_flEGdk_SO0psMR7ho-o79igsz_l1x80u81WJ3A,23875
|
44
|
+
cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=VrqgDSiAMfGyHEIodoOg_8LU5O0ndWf9EE0YOKDFKKA,7019
|
45
|
+
cehrgpt/runners/hyperparameter_search_util.py,sha256=pWFmGo9Ezju4YmuZ-ohbAbYB0GGMfIDVUCyvcTxS1iU,9153
|
46
|
+
cehrgpt/runners/sample_packing_trainer.py,sha256=aezX30vxpP1DDcH5hO-yn395NqBKi2Xhb0mFNHi9OBs,7340
|
47
|
+
cehrgpt/simulations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
48
|
+
cehrgpt/simulations/generate_plots.py,sha256=BTZ71r8Kah0PMorkiO3vw55_p_9U1Z8KiD3GsPfaV0s,2520
|
49
|
+
cehrgpt/simulations/run_simulation.sh,sha256=DcJ6B19jIteUO0pZ0Tc21876lB9XxQHFAxlre7MtAzk,795
|
50
|
+
cehrgpt/simulations/time_embedding_simulation.py,sha256=HZ-imXH-bN-QYZN1PAfcERmNtaWIwKjbf0UrZduwCiA,8687
|
51
|
+
cehrgpt/simulations/time_token_simulation.py,sha256=sLg8vVXydvR_zk3BbqyrlA7sDIdhFnS-s5pSKcCilSc,6057
|
53
52
|
cehrgpt/time_to_event/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
54
53
|
cehrgpt/time_to_event/time_to_event_model.py,sha256=tfXa24l_0q1TBZ68BPRrHRC_3KRWYxrWGIv4myJlIb8,8497
|
55
54
|
cehrgpt/time_to_event/time_to_event_prediction.py,sha256=Ajesq2gSsILghWHCTLiiBhcyOCa7m6JPPMdi_xvBlR4,12624
|
@@ -63,8 +62,11 @@ cehrgpt/tools/generate_causal_patient_split_by_age.py,sha256=dmHiPAL_kR1WrhRteIi
|
|
63
62
|
cehrgpt/tools/generate_pretrained_embeddings.py,sha256=lhFSacGv8bMld6qigKZN8Op8eXpFi0DsJuQbWKOWXqI,4160
|
64
63
|
cehrgpt/tools/merge_synthetic_real_dataasets.py,sha256=O1dbQ32Le0t15fwymwAh9mfNVLEWuFwW53DNvESrWbY,7589
|
65
64
|
cehrgpt/tools/upload_omop_tables.py,sha256=vdBAbkeAsGPA4NsyhNjelPVj3gS8yzmS1sKNM1Qk96g,3791
|
66
|
-
cehrgpt
|
67
|
-
cehrgpt
|
68
|
-
cehrgpt
|
69
|
-
cehrgpt-0.0.
|
70
|
-
cehrgpt-0.0.
|
65
|
+
cehrgpt/tools/linear_prob/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
66
|
+
cehrgpt/tools/linear_prob/compute_cehrgpt_features.py,sha256=jVgAmBrZKp7ABfqKkzwV5Vl_G9jDCjPl98NSVmSwHpE,19291
|
67
|
+
cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py,sha256=w0UvzMKYGenN_KDVnbzutmy8IPLUxW5hPvpKKxDSL5U,5820
|
68
|
+
cehrgpt-0.1.0.dist-info/licenses/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
|
69
|
+
cehrgpt-0.1.0.dist-info/METADATA,sha256=V02vsptjJRD_bybXVRFXPrJa-By9CX4j-oAA3EfXFq4,4933
|
70
|
+
cehrgpt-0.1.0.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
71
|
+
cehrgpt-0.1.0.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
|
72
|
+
cehrgpt-0.1.0.dist-info/RECORD,,
|

cehrgpt/data/hf_cehrgpt_dpo_collator.py
@@ -1,71 +0,0 @@
-import torch
-from torch.nn.utils.rnn import pad_sequence
-
-from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
-
-
-class CehrGptDPODataCollator(CehrGptDataCollator):
-
-    def create_preference_inputs(self, examples, prefix):
-        batch = {}
-        # Assume that each example in the batch is a dictionary with 'input_ids' and 'attention_mask'
-        batch_input_ids = [
-            self._try_reverse_tensor(
-                self._convert_to_tensor(example[f"{prefix}_input_ids"])
-            )
-            for example in examples
-        ]
-        batch_attention_mask = [
-            self._try_reverse_tensor(
-                torch.ones_like(
-                    self._convert_to_tensor(example[f"{prefix}_input_ids"]),
-                    dtype=torch.float,
-                )
-            )
-            for example in examples
-        ]
-        # Pad sequences to the max length in the batch
-        batch[f"{prefix}_input_ids"] = self._try_reverse_tensor(
-            pad_sequence(
-                batch_input_ids,
-                batch_first=True,
-                padding_value=self.tokenizer.pad_token_id,
-            ).to(torch.int64)
-        )
-        batch[f"{prefix}_attention_mask"] = self._try_reverse_tensor(
-            pad_sequence(batch_attention_mask, batch_first=True, padding_value=0.0)
-        )
-        assert batch[f"{prefix}_input_ids"].shape[1] <= self.max_length
-        assert batch[f"{prefix}_attention_mask"].shape[1] <= self.max_length
-
-        if self.include_values:
-            batch_value_indicators = [
-                self._try_reverse_tensor(
-                    self._convert_to_tensor(example[f"{prefix}_value_indicators"])
-                )
-                for example in examples
-            ]
-            batch_values = [
-                self._try_reverse_tensor(
-                    self._convert_to_tensor(example[f"{prefix}__values"])
-                )
-                for example in examples
-            ]
-
-            batch[f"{prefix}_value_indicators"] = self._try_reverse_tensor(
-                pad_sequence(
-                    batch_value_indicators, batch_first=True, padding_value=False
-                )
-            )
-            batch[f"{prefix}_values"] = self._try_reverse_tensor(
-                pad_sequence(batch_values, batch_first=True, padding_value=-1.0)
-            )
-            assert batch[f"{prefix}_value_indicators"].shape[1] <= self.max_length
-            assert batch[f"{prefix}_values"].shape[1] <= self.max_length
-        return batch
-
-    def __call__(self, examples):
-        batch_chosen = self.create_preference_inputs(examples, "chosen")
-        batch_rejected = self.create_preference_inputs(examples, "rejected")
-        batch_chosen.update(batch_rejected)
-        return batch_chosen

cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py
@@ -1,61 +0,0 @@
-import copy
-from typing import Any, Dict
-
-import numpy as np
-from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import DatasetMapping
-
-from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
-
-
-class HFCehrGptDPOTokenizationMapping(DatasetMapping):
-    def __init__(
-        self,
-        concept_tokenizer: CehrGptTokenizer,
-    ):
-        self._concept_tokenizer = concept_tokenizer
-        self._lab_token_ids = self._concept_tokenizer.lab_token_ids
-
-    def transform_with_prefix(self, record: Dict[str, Any], prefix) -> Dict[str, Any]:
-        concept_ids = record[f"{prefix}_concept_ids"]
-        input_ids = self._concept_tokenizer.encode(concept_ids)
-        record[f"{prefix}_input_ids"] = input_ids
-
-        if f"{prefix}_concept_value_masks" in record:
-            concept_value_masks = record[f"{prefix}_concept_value_masks"]
-            concept_values = record[f"{prefix}_concept_values"]
-            # If any concept has a value associated with it, we normalize the value
-            if np.any(np.asarray(concept_value_masks) > 0):
-                units = record[f"{prefix}_units"]
-                normalized_concept_values = copy.deepcopy(concept_values)
-                for i, (
-                    concept_id,
-                    unit,
-                    token_id,
-                    concept_value_mask,
-                    concept_value,
-                ) in enumerate(
-                    zip(
-                        concept_ids,
-                        units,
-                        input_ids,
-                        concept_value_masks,
-                        concept_values,
-                    )
-                ):
-                    if token_id in self._lab_token_ids:
-                        normalized_concept_value = self._concept_tokenizer.normalize(
-                            concept_id, unit, concept_value
-                        )
-                        normalized_concept_values[i] = normalized_concept_value
-                record[f"{prefix}_concept_values"] = normalized_concept_values
-            # Overwrite the column names
-            record[f"{prefix}_value_indicators"] = record[
-                f"{prefix}_concept_value_masks"
-            ]
-            record[f"{prefix}_values"] = record[f"{prefix}_concept_values"]
-        return record
-
-    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
-        record = self.transform_with_prefix(record, prefix="chosen")
-        record.update(self.transform_with_prefix(record, prefix="rejected"))
-        return record

cehrgpt/generation/generate_paired_cehrgpt_sequence.py
@@ -1,224 +0,0 @@
-import datetime
-import os
-import random
-import uuid
-
-import pandas as pd
-import torch
-from cehrbert.runners.runner_util import load_parquet_as_dataset
-from transformers.utils import is_flash_attn_2_available, logging
-
-from cehrgpt.cehrgpt_args import create_inference_base_arg_parser
-from cehrgpt.generation.generate_batch_hf_gpt_sequence import (
-    generate_single_batch,
-    normalize_value,
-)
-from cehrgpt.gpt_utils import get_cehrgpt_output_folder
-from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
-from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
-
-LOG = logging.get_logger("transformers")
-
-
-def main(args):
-    if torch.cuda.is_available():
-        device = torch.device("cuda")
-    else:
-        device = torch.device("cpu")
-
-    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_folder)
-    cehrgpt_model = (
-        CEHRGPT2LMHeadModel.from_pretrained(
-            args.model_folder,
-            attn_implementation=(
-                "flash_attention_2" if is_flash_attn_2_available() else "eager"
-            ),
-            torch_dtype=(
-                torch.bfloat16 if is_flash_attn_2_available() else torch.float32
-            ),
-        )
-        .eval()
-        .to(device)
-    )
-    cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
-    cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
-    cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id
-
-    folder_name = get_cehrgpt_output_folder(args, cehrgpt_tokenizer)
-    output_folder_name = os.path.join(
-        args.output_folder, folder_name, "generated_sequences"
-    )
-
-    if not os.path.exists(output_folder_name):
-        os.makedirs(output_folder_name)
-
-    LOG.info(f"Loading tokenizer at {args.model_folder}")
-    LOG.info(f"Loading model at {args.model_folder}")
-    LOG.info(f"Write sequences to {output_folder_name}")
-    LOG.info(f"Context window {args.context_window}")
-    LOG.info(f"Temperature {args.temperature}")
-    LOG.info(f"Repetition Penalty {args.repetition_penalty}")
-    LOG.info(f"Sampling Strategy {args.sampling_strategy}")
-    LOG.info(f"Num beam {args.num_beams}")
-    LOG.info(f"Num beam groups {args.num_beam_groups}")
-    LOG.info(f"Epsilon cutoff {args.epsilon_cutoff}")
-    LOG.info(f"Top P {args.top_p}")
-    LOG.info(f"Top K {args.top_k}")
-    LOG.info(f"Loading sequence_data_path at {args.sequence_data_path}")
-
-    dataset = load_parquet_as_dataset(args.sequence_data_path)
-    total_rows = len(dataset)
-    float(args.batch_size) / total_rows
-    num_of_batches = args.num_of_patients // args.batch_size + 1
-    sequence_to_flush = []
-    for i in range(num_of_batches):
-        LOG.info(f"{datetime.datetime.now()}: Batch {i} started")
-        sample_data = []
-        while len(sample_data) == 0:
-            random_indices = random.sample(range(total_rows), k=1)
-            for row in dataset.select(random_indices):
-                if 4 <= len(row["concept_ids"]) <= cehrgpt_model.config.n_positions:
-                    sample_data.append(row)
-        prompts = []
-        chosen_responses = []
-        cutoff_frac = random.uniform(0, args.cutoff_frac_max)
-        for row in sample_data:
-            seq_len = len(row["concept_ids"])
-            prompt_len = max(4, int(seq_len * cutoff_frac))
-            prompts.append(cehrgpt_tokenizer.encode(row["concept_ids"][:prompt_len]))
-            chosen_responses.append(
-                {
-                    "person_id": row["person_id"],
-                    "chosen_concept_ids": (
-                        row["concept_ids"] if "concept_ids" in row else None
-                    ),
-                    "chosen_concept_values": (
-                        row["concept_values"] if "concept_values" in row else None
-                    ),
-                    "chosen_concept_value_masks": (
-                        row["concept_value_masks"]
-                        if "concept_value_masks" in row
-                        else None
-                    ),
-                    "chosen_units": row["units"] if "units" in row else None,
-                    "prompt_length": prompt_len,
-                }
-            )
-
-        batch_sequences = generate_single_batch(
-            cehrgpt_model,
-            cehrgpt_tokenizer,
-            prompts=prompts,
-            max_new_tokens=args.context_window,
-            mini_num_of_concepts=args.min_num_of_concepts,
-            top_p=args.top_p,
-            top_k=args.top_k,
-            temperature=args.temperature,
-            repetition_penalty=args.repetition_penalty,
-            num_beams=args.num_beams,
-            num_beam_groups=args.num_beam_groups,
-            epsilon_cutoff=args.epsilon_cutoff,
-            device=device,
-        )
-
-        # Clear the cache
-        torch.cuda.empty_cache()
-
-        for seq, value_indicator, value, chosen_response in zip(
-            batch_sequences["sequences"],
-            batch_sequences["value_indicators"],
-            batch_sequences["values"],
-            chosen_responses,
-        ):
-            output = {"rejected_concept_ids": seq}
-            normalized_values, units = normalize_value(
-                seq, value_indicator, value, cehrgpt_tokenizer
-            )
-            if normalized_values is not None:
-                output["rejected_concept_values"] = normalized_values
-            if value_indicator is not None:
-                output["rejected_concept_value_masks"] = value_indicator
-            if units is not None:
-                output["rejected_units"] = units
-            output.update(chosen_response)
-            sequence_to_flush.append(output)
-
-        if len(sequence_to_flush) >= args.buffer_size:
-            LOG.info(f"{datetime.datetime.now()}: Flushing to the Disk at Batch {i}")
-            pd.DataFrame(
-                sequence_to_flush,
-                columns=[
-                    "person_id",
-                    "chosen_concept_ids",
-                    "chosen_concept_values",
-                    "chosen_concept_value_masks",
-                    "chosen_units",
-                    "prompt_length",
-                    "rejected_concept_ids",
-                    "rejected_concept_values",
-                    "rejected_concept_value_masks",
-                    "rejected_units",
-                ],
-            ).to_parquet(os.path.join(output_folder_name, f"{uuid.uuid4()}.parquet"))
-            sequence_to_flush.clear()
-
-    if len(sequence_to_flush) > 0:
-        LOG.info(f"{datetime.datetime.now()}: Flushing to the Disk at Final Batch")
-        pd.DataFrame(
-            sequence_to_flush,
-            columns=[
-                "person_id",
-                "chosen_concept_ids",
-                "chosen_concept_values",
-                "chosen_concept_value_masks",
-                "chosen_units",
-                "prompt_length",
-                "rejected_concept_ids",
-                "rejected_concept_values",
-                "rejected_concept_value_masks",
-                "rejected_units",
-            ],
-        ).to_parquet(os.path.join(output_folder_name, f"{uuid.uuid4()}-last.parquet"))
-
-
-def create_arg_parser():
-    base_arg_parser = create_inference_base_arg_parser(
-        description="Arguments for generating paired patient sequences"
-    )
-    base_arg_parser.add_argument(
-        "--num_of_patients",
-        dest="num_of_patients",
-        action="store",
-        type=int,
-        help="The number of patients that will be generated",
-        required=True,
-    )
-    base_arg_parser.add_argument(
-        "--sequence_data_path",
-        dest="sequence_data_path",
-        action="store",
-        help="The path for your sequence data",
-        required=True,
-    )
-    base_arg_parser.add_argument(
-        "--cutoff_frac_max",
-        dest="cutoff_frac_max",
-        action="store",
-        type=float,
-        help="The max fraction of the patient sequences that will be used for prompting",
-        required=False,
-        default=0.5,
-    )
-    base_arg_parser.add_argument(
-        "--num_proc",
-        dest="num_proc",
-        action="store",
-        type=int,
-        required=False,
-        default=1,
-    )
-    return base_arg_parser
-
-
-if __name__ == "__main__":
-    main(create_arg_parser().parse_args())
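
The removed script built each preference pair by keeping a real patient sequence as the "chosen" response and sampling a model continuation of a truncated prompt as the "rejected" one. A minimal sketch of its prompt-construction step, with illustrative names and toy data:

import random

def make_prompt(concept_ids, cutoff_frac_max=0.5, min_prompt_len=4):
    # Mirrors the removed logic: keep at least four tokens, and at most a
    # random fraction (up to cutoff_frac_max) of the original sequence.
    cutoff_frac = random.uniform(0, cutoff_frac_max)
    prompt_len = max(min_prompt_len, int(len(concept_ids) * cutoff_frac))
    return concept_ids[:prompt_len], prompt_len

prompt, prompt_len = make_prompt([f"c{i}" for i in range(20)])
print(prompt_len, prompt)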