cehrgpt 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
  2. cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
  3. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +99 -88
  4. cehrgpt/data/sample_packing_sampler.py +151 -0
  5. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  6. cehrgpt/models/config.py +10 -0
  7. cehrgpt/models/hf_cehrgpt.py +243 -73
  8. cehrgpt/models/tokenization_hf_cehrgpt.py +4 -0
  9. cehrgpt/runners/data_utils.py +243 -0
  10. cehrgpt/runners/gpt_runner_util.py +0 -10
  11. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +152 -279
  12. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +229 -105
  13. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +42 -0
  14. cehrgpt/runners/hyperparameter_search_util.py +4 -1
  15. cehrgpt/runners/sample_packing_trainer.py +168 -0
  16. cehrgpt/simulations/generate_plots.py +95 -0
  17. cehrgpt/simulations/run_simulation.sh +24 -0
  18. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  19. cehrgpt/simulations/time_token_simulation.py +177 -0
  20. cehrgpt/tools/linear_prob/__init__.py +0 -0
  21. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
  22. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  23. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +7 -5
  24. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +28 -26
  25. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
  26. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  27. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  28. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  29. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  30. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  31. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  32. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  33. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  34. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  35. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
  36. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,152 @@
1
+ import argparse
2
+ import json
3
+ import pickle
4
+ from pathlib import Path
5
+ from typing import Any, Dict, Union
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import polars as pl
10
+ from sklearn.linear_model import LogisticRegressionCV
11
+ from sklearn.metrics import auc, precision_recall_curve, roc_auc_score
12
+ from sklearn.preprocessing import OneHotEncoder, StandardScaler
13
+
14
+
15
def prepare_dataset(
    df: pd.DataFrame,
    feature_processor: "Dict[str, Union[StandardScaler, OneHotEncoder]]",
    include_demographics: bool = False,
) -> Dict[str, Any]:
    """Turn a feature dataframe into the arrays used to fit/evaluate the probe.

    Args:
        df: Dataframe with the columns ``subject_id``, ``prediction_time``,
            ``features`` and ``boolean_value``. When ``include_demographics``
            is true it must also contain ``age_at_index``,
            ``gender_concept_id`` and ``race_concept_id``.
        feature_processor: Fitted transformers stored under the keys
            ``"age_scaler"``, ``"gender_encoder"`` and ``"race_encoder"``.
            Only consulted when ``include_demographics`` is true.
        include_demographics: When true, prepend the scaled age and the
            one-hot encoded gender/race to each feature vector. Defaults to
            false, which yields the same feature matrix as before (the
            demographic columns were computed but never used).

    Returns:
        A dict with ``subject_id`` (array), ``prediction_time`` (list),
        ``features`` (2-D array, one flattened row per record) and
        ``boolean_value`` (array).
    """
    # Every record's feature payload is flattened so rows can be stacked.
    features = np.stack(df["features"].apply(lambda x: np.array(x).flatten()))

    if include_demographics:

        def as_dense(matrix):
            # Encoders may return a sparse matrix; hstack needs dense arrays.
            return matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)

        scaled_age = feature_processor["age_scaler"].transform(
            df[["age_at_index"]].to_numpy()
        )
        one_hot_gender = feature_processor["gender_encoder"].transform(
            df[["gender_concept_id"]].to_numpy()
        )
        one_hot_race = feature_processor["race_encoder"].transform(
            df[["race_concept_id"]].to_numpy()
        )
        features = np.hstack(
            [
                as_dense(scaled_age),
                as_dense(one_hot_gender),
                as_dense(one_hot_race),
                features,
            ]
        )

    return {
        "subject_id": df["subject_id"].to_numpy(),
        "prediction_time": df["prediction_time"].tolist(),
        "features": features,
        "boolean_value": df["boolean_value"].to_numpy(),
    }
40
+
41
+
42
def main(args):
    """Fit a logistic-regression probe on pre-computed features and score it.

    Reads ``train_features`` and ``test_features`` from
    ``<features_data_dir>/features_with_label``, fits (or reloads) the
    demographic feature processors and the ``LogisticRegressionCV`` model,
    then writes test predictions and ROC-AUC / PR-AUC metrics under
    ``<output_dir>/logistic``.

    The run is resumable: existing ``metrics.json`` short-circuits the whole
    function, and an existing ``feature_processor.pickle`` / ``model.pickle``
    is loaded instead of being refit.

    Args:
        args: Namespace with ``features_data_dir`` and ``output_dir``.
    """
    features_data_dir = Path(args.features_data_dir)
    output_dir = Path(args.output_dir)
    feature_processor_path = output_dir / "feature_processor.pickle"
    logistic_dir = output_dir / "logistic"
    logistic_dir.mkdir(exist_ok=True, parents=True)
    logistic_test_result_file = logistic_dir / "metrics.json"
    if logistic_test_result_file.exists():
        print("The models have been trained, and skip ...")
        # Return instead of calling the site-provided exit(): the process
        # still ends with status 0 when run as a script, but importing
        # callers are not terminated.
        return

    feature_train = pd.read_parquet(
        features_data_dir / "features_with_label" / "train_features"
    )
    feature_test = pd.read_parquet(
        features_data_dir / "features_with_label" / "test_features"
    )

    # Sort first so the seeded shuffle is reproducible regardless of the
    # row order in the parquet files.
    feature_train = feature_train.sort_values(["subject_id", "prediction_time"]).sample(
        frac=1.0,
        random_state=42,
        replace=False,
    )

    if feature_processor_path.exists():
        # NOTE(security): pickle executes arbitrary code on load; this file is
        # expected to be an artifact written by a previous run of this script.
        with open(feature_processor_path, "rb") as f:
            feature_processor = pickle.load(f)
    else:
        age_scaler = StandardScaler().fit(feature_train[["age_at_index"]].to_numpy())
        gender_encoder = OneHotEncoder(handle_unknown="ignore").fit(
            feature_train[["gender_concept_id"]].to_numpy()
        )
        race_encoder = OneHotEncoder(handle_unknown="ignore").fit(
            feature_train[["race_concept_id"]].to_numpy()
        )
        feature_processor = {
            "age_scaler": age_scaler,
            "gender_encoder": gender_encoder,
            "race_encoder": race_encoder,
        }
        with open(feature_processor_path, "wb") as f:
            pickle.dump(feature_processor, f)

    logistic_model_file = logistic_dir / "model.pickle"
    if logistic_model_file.exists():
        print(
            f"The logistic regression model already exist, loading it from {logistic_model_file}"
        )
        # NOTE(security): same trust assumption as the feature processor.
        with open(logistic_model_file, "rb") as f:
            model = pickle.load(f)
    else:
        train_dataset = prepare_dataset(feature_train, feature_processor)
        # Train logistic regression
        model = LogisticRegressionCV(scoring="roc_auc", random_state=42)
        model.fit(train_dataset["features"], train_dataset["boolean_value"])
        with open(logistic_model_file, "wb") as f:
            pickle.dump(model, f)

    test_dataset = prepare_dataset(feature_test, feature_processor)
    # Column 1 of predict_proba is the probability of the positive class.
    y_pred = model.predict_proba(test_dataset["features"])[:, 1]
    logistic_predictions = pl.DataFrame(
        {
            "subject_id": test_dataset["subject_id"].tolist(),
            "prediction_time": test_dataset["prediction_time"],
            "predicted_boolean_probability": y_pred.tolist(),
            # No thresholded prediction is produced; the column is kept as
            # an all-null boolean column (cast below).
            "predicted_boolean_value": None,
            "boolean_value": test_dataset["boolean_value"].astype(bool).tolist(),
        }
    )
    logistic_predictions = logistic_predictions.with_columns(
        pl.col("predicted_boolean_value").cast(pl.Boolean())
    )
    logistic_test_predictions = logistic_dir / "test_predictions"
    logistic_test_predictions.mkdir(exist_ok=True, parents=True)
    logistic_predictions.write_parquet(
        logistic_test_predictions / "predictions.parquet"
    )

    roc_auc = roc_auc_score(test_dataset["boolean_value"], y_pred)
    precision, recall, _ = precision_recall_curve(
        test_dataset["boolean_value"], y_pred
    )
    pr_auc = auc(recall, precision)

    metrics = {"roc_auc": roc_auc, "pr_auc": pr_auc}
    print("Logistic:", features_data_dir.name, metrics)
    # Written last: its presence marks the run as complete for the early
    # return at the top of this function.
    with open(logistic_test_result_file, "w") as f:
        json.dump(metrics, f, indent=4)
138
+
139
+
140
if __name__ == "__main__":
    # Both options are mandatory paths; register them from a small table.
    arg_parser = argparse.ArgumentParser(
        description="Train logistic regression model with cehrgpt features"
    )
    required_options = (
        (
            "--features_data_dir",
            "Directory containing training and test feature files",
        ),
        ("--output_dir", "Directory to save the output results"),
    )
    for option_name, option_help in required_options:
        arg_parser.add_argument(option_name, required=True, help=option_help)
    main(arg_parser.parse_args())
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: cehrgpt
3
- Version: 0.0.2
3
+ Version: 0.1.0
4
4
  Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
5
5
  Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
6
6
  License: MIT License
@@ -12,13 +12,14 @@ Classifier: Programming Language :: Python :: 3
12
12
  Requires-Python: >=3.10.0
13
13
  Description-Content-Type: text/markdown
14
14
  License-File: LICENSE
15
- Requires-Dist: cehrbert==1.3.8
15
+ Requires-Dist: cehrbert==1.4.1
16
+ Requires-Dist: cehrbert_data==0.0.7
16
17
  Requires-Dist: openai==1.54.3
17
18
  Requires-Dist: optuna==4.0.0
18
- Requires-Dist: transformers==4.40.0
19
+ Requires-Dist: transformers==4.44.0
19
20
  Requires-Dist: tokenizers==0.19.0
20
21
  Requires-Dist: peft==0.10.0
21
- Requires-Dist: trl==0.11.4
22
+ Requires-Dist: lightgbm
22
23
  Provides-Extra: dev
23
24
  Requires-Dist: pre-commit; extra == "dev"
24
25
  Requires-Dist: pytest; extra == "dev"
@@ -29,6 +30,7 @@ Requires-Dist: hypothesis; extra == "dev"
29
30
  Requires-Dist: black; extra == "dev"
30
31
  Provides-Extra: flash-attn
31
32
  Requires-Dist: flash_attn; extra == "flash-attn"
33
+ Dynamic: license-file
32
34
 
33
35
  # CEHRGPT
34
36
 
@@ -11,24 +11,22 @@ cehrgpt/analysis/privacy/nearest_neighbor_inference.py,sha256=qoJgWW7VsUMzjMGpTa
11
11
  cehrgpt/analysis/privacy/reid_inference.py,sha256=Pypd3QJXQNY8VljpnIEa5zeAbTZHMjQOazaL-9VsBGw,13955
12
12
  cehrgpt/analysis/privacy/utils.py,sha256=CRA4H9mPLBjMQGKzZ_x_3ro3tMap-NjsMDVqSOjHSVQ,8226
13
13
  cehrgpt/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- cehrgpt/data/hf_cehrgpt_dataset.py,sha256=7hvjjqE8WInVuRvAtNkFI_J-xluFBv1Ij4TPTdUxPM4,2570
15
- cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=RYw5Isrwa4sdyQQ3Nf3cu7xPDA3m-c5ecCFf_y1TJKY,20497
16
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=IjGwLKbEfNPxH3hsNmb8p48_imHnMWtslDK6f7R_1pc,16053
17
- cehrgpt/data/hf_cehrgpt_dpo_collator.py,sha256=cqDK0SUOt3yAqUHWKGuLVi3WmmUMZ6eyxTv9fC9idZA,2787
18
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py,sha256=uCLF5VEsyZAG1aNwqEM6Jy5Lx7bI5ALku52Z6Anine0,2574
14
+ cehrgpt/data/hf_cehrgpt_dataset.py,sha256=t9vpN05e--CiKgIlxLP0aLacISnvWWDPXtuFuJi3ksE,3736
15
+ cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=DOvIF4Wzkd8-IO3zpIRZkX1j0IdvefaiSnrDn1YivCk,27912
16
+ cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=eI8CTk6yJ4DlNJWrNAkEmhWh353NeLqg5rwPpKqKT-U,17308
17
+ cehrgpt/data/sample_packing_sampler.py,sha256=0uKTbvtXpfS81esy_3epJ88eohyJPK46bfmxhle1fws,5419
19
18
  cehrgpt/generation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
19
  cehrgpt/generation/chatgpt_generation.py,sha256=SrnLwHLdNtnAOEg36gNjqfoT9yd12iyPgpZffL2AFJo,4428
21
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256=-WLpKlulVVDJSdA2jXyp87gfLW4Q3aAtwULK8fDtn_E,11408
22
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py,sha256=fLu3SHhRe_ZQfS09ebOktq2dekStgYfxmbrRawZQAO4,8280
20
+ cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256=uSEh8aMmPD61nGewIaPSkIqm-2AxDjCBiu4cBfxHxU4,11503
23
21
  cehrgpt/generation/omop_converter_batch.py,sha256=-c0AlDVy5pJ5Afhr8ERiCHhoRrEk8ozJi3g0yFdWaMI,25348
24
22
  cehrgpt/generation/omop_entity.py,sha256=Q5Sr0AlyuPAm1FRPfnJO13q-u1fqRgYVHXruZ9g4xNE,19400
25
23
  cehrgpt/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
- cehrgpt/models/config.py,sha256=xek4W_siO7WtMAKE7zDsENotsIE70F8dcW-PTC0kBKk,9700
27
- cehrgpt/models/hf_cehrgpt.py,sha256=CKseTvGkBFwXK40Z_uKD1_d84oSYCFqKmHI0qtdk72g,75757
24
+ cehrgpt/models/config.py,sha256=Y3CiXZWniLP9_RlpU80Oe9gjn5leLmTYnNe_fWqfJLQ,10158
25
+ cehrgpt/models/hf_cehrgpt.py,sha256=3EQIOfa--oz4f8bM8KzbDi98G3XrUEQkox1vmBN001M,83321
28
26
  cehrgpt/models/hf_modeling_outputs.py,sha256=LaWa1jI6BRIKMEjWOy1QUeOfTur5y_p2c-JyuGVTdtw,10301
29
27
  cehrgpt/models/pretrained_embeddings.py,sha256=vLLVs17TLpXRqCVEWQxGGwPHkUJUO7laNTeBuyBK_yk,3238
30
28
  cehrgpt/models/special_tokens.py,sha256=-a7HPJBbdIH0qQ6B3CcRKqvpG6FZlm4nbVPTswGSJ4U,485
31
- cehrgpt/models/tokenization_hf_cehrgpt.py,sha256=JAZjnmQq-JjUxZK7XIsqdZB07ZB7BC2WraCjpO_6AOM,42161
29
+ cehrgpt/models/tokenization_hf_cehrgpt.py,sha256=jjCRqS29IzMnKp40jNOs80UKh2z9lK5S6M02GSB-4mk,42351
32
30
  cehrgpt/omop/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
33
31
  cehrgpt/omop/condition_era.py,sha256=hPZALz2XaWnro_1bwIYNkI48foOJjueyg3CZ1BliCno,626
34
32
  cehrgpt/omop/observation_period.py,sha256=TRMgv5Ya2RaS2im7oQ6BLC_5JL9EJYNYR62ApxIuHvg,1211
@@ -38,18 +36,19 @@ cehrgpt/omop/sample_omop_tables.py,sha256=2JZ8BNSvssceinwFanvuCRh-YlKrKn25U9w1pL
38
36
  cehrgpt/omop/queries/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
39
37
  cehrgpt/omop/queries/condition_era.py,sha256=LFB6vBAvshHJxtYIRkl7cfrF0kf7ay0piBKpmHBwrpE,2578
40
38
  cehrgpt/omop/queries/observation_period.py,sha256=fpzr5DMNw-QLoSwp2Iatfch88E3hyhZ75usiIdG3A0U,6410
41
- cehrgpt/rl_finetune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
42
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py,sha256=VQHf5vy5i8K1imcqYakhitfAW-d2mnaEzkSoAYSW5kg,26062
43
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py,sha256=nYWYPCaWNjDGEwlo6UHOK1rvOZUx1vuJ8kYuAszI8Zg,17925
44
- cehrgpt/rl_finetune/ppo_finetune.py,sha256=tSy-C0Kzgj5ffclBIDj-RTj78ZfrLmTESxVxd0n9yuE,13971
45
- cehrgpt/rl_finetune/ppo_finetune_v2.py,sha256=7dChwKpq4zKmpkcxP4hryqBoIkcwmTJ44_BF8R2RghQ,13285
46
39
  cehrgpt/runners/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
47
- cehrgpt/runners/gpt_runner_util.py,sha256=88HKSVj-ADGBCMo7C3znKSMPnAAALa1iU_6P6i9sD0M,3867
48
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py,sha256=Z4qNl9CZFC5YvUBc9ZzdOV5wsBFvMTdxfTn4jjtJQ-Y,4583
49
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=reflNRb6YB6f_3jAfzFAdwKtTl6hvdIp9Jc7DC-Sv-U,32580
50
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=L3UpjtzxuS8a_tshlqpZN_sXnJSs3yzry0GZNT__05A,18200
51
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=gKVf4BLzNCFiJR7nZVkf-QRcj8fAEVvIUTV-AVH0g_U,5312
52
- cehrgpt/runners/hyperparameter_search_util.py,sha256=i4qAb_22JO78l40MSyBPwDgAGuGc96efXmg_833cSSo,9044
40
+ cehrgpt/runners/data_utils.py,sha256=ScZZnfXwgXKaMvKgFzdb4vtQ7F_lw97O5uNsFbfsyP4,10620
41
+ cehrgpt/runners/gpt_runner_util.py,sha256=YJQSRW9Mo4TjXSOUOTf6BUFcs1MGFiXU5T4ztKZcYhU,3485
42
+ cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=bkPl30Y9CSXBlmMkH-3cA3-aW8XJK36Q-adx___WjkE,26921
43
+ cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=ViVa_flEGdk_SO0psMR7ho-o79igsz_l1x80u81WJ3A,23875
44
+ cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=VrqgDSiAMfGyHEIodoOg_8LU5O0ndWf9EE0YOKDFKKA,7019
45
+ cehrgpt/runners/hyperparameter_search_util.py,sha256=pWFmGo9Ezju4YmuZ-ohbAbYB0GGMfIDVUCyvcTxS1iU,9153
46
+ cehrgpt/runners/sample_packing_trainer.py,sha256=aezX30vxpP1DDcH5hO-yn395NqBKi2Xhb0mFNHi9OBs,7340
47
+ cehrgpt/simulations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
48
+ cehrgpt/simulations/generate_plots.py,sha256=BTZ71r8Kah0PMorkiO3vw55_p_9U1Z8KiD3GsPfaV0s,2520
49
+ cehrgpt/simulations/run_simulation.sh,sha256=DcJ6B19jIteUO0pZ0Tc21876lB9XxQHFAxlre7MtAzk,795
50
+ cehrgpt/simulations/time_embedding_simulation.py,sha256=HZ-imXH-bN-QYZN1PAfcERmNtaWIwKjbf0UrZduwCiA,8687
51
+ cehrgpt/simulations/time_token_simulation.py,sha256=sLg8vVXydvR_zk3BbqyrlA7sDIdhFnS-s5pSKcCilSc,6057
53
52
  cehrgpt/time_to_event/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
54
53
  cehrgpt/time_to_event/time_to_event_model.py,sha256=tfXa24l_0q1TBZ68BPRrHRC_3KRWYxrWGIv4myJlIb8,8497
55
54
  cehrgpt/time_to_event/time_to_event_prediction.py,sha256=Ajesq2gSsILghWHCTLiiBhcyOCa7m6JPPMdi_xvBlR4,12624
@@ -63,8 +62,11 @@ cehrgpt/tools/generate_causal_patient_split_by_age.py,sha256=dmHiPAL_kR1WrhRteIi
63
62
  cehrgpt/tools/generate_pretrained_embeddings.py,sha256=lhFSacGv8bMld6qigKZN8Op8eXpFi0DsJuQbWKOWXqI,4160
64
63
  cehrgpt/tools/merge_synthetic_real_dataasets.py,sha256=O1dbQ32Le0t15fwymwAh9mfNVLEWuFwW53DNvESrWbY,7589
65
64
  cehrgpt/tools/upload_omop_tables.py,sha256=vdBAbkeAsGPA4NsyhNjelPVj3gS8yzmS1sKNM1Qk96g,3791
66
- cehrgpt-0.0.2.dist-info/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
67
- cehrgpt-0.0.2.dist-info/METADATA,sha256=joUmDJWMEBvYphrkwYiK273FwSL9okY74D93ncrbvMU,4878
68
- cehrgpt-0.0.2.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
69
- cehrgpt-0.0.2.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
70
- cehrgpt-0.0.2.dist-info/RECORD,,
65
+ cehrgpt/tools/linear_prob/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
66
+ cehrgpt/tools/linear_prob/compute_cehrgpt_features.py,sha256=jVgAmBrZKp7ABfqKkzwV5Vl_G9jDCjPl98NSVmSwHpE,19291
67
+ cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py,sha256=w0UvzMKYGenN_KDVnbzutmy8IPLUxW5hPvpKKxDSL5U,5820
68
+ cehrgpt-0.1.0.dist-info/licenses/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
69
+ cehrgpt-0.1.0.dist-info/METADATA,sha256=V02vsptjJRD_bybXVRFXPrJa-By9CX4j-oAA3EfXFq4,4933
70
+ cehrgpt-0.1.0.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
71
+ cehrgpt-0.1.0.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
72
+ cehrgpt-0.1.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (76.0.0)
2
+ Generator: setuptools (80.7.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,71 +0,0 @@
1
- import torch
2
- from torch.nn.utils.rnn import pad_sequence
3
-
4
- from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
5
-
6
-
7
class CehrGptDPODataCollator(CehrGptDataCollator):
    """Collate paired ``chosen`` / ``rejected`` examples into one padded batch.

    Each example is expected to carry ``chosen_*`` and ``rejected_*`` fields;
    the two halves are padded independently and merged into a single dict.
    """

    def create_preference_inputs(self, examples, prefix):
        """Build the padded tensors for one side (``prefix``) of the pair.

        Args:
            examples: Sequence of dict-like records containing
                ``{prefix}_input_ids`` and, when ``self.include_values`` is
                set, ``{prefix}_value_indicators`` and ``{prefix}_values``.
            prefix: Field prefix, e.g. ``"chosen"`` or ``"rejected"``.

        Returns:
            Dict with ``{prefix}_input_ids`` and ``{prefix}_attention_mask``,
            plus ``{prefix}_value_indicators`` / ``{prefix}_values`` when
            values are included.
        """

        def to_tensors(key):
            # Convert each example's field and pass it through the base
            # class's reversal hook before padding.
            return [
                self._try_reverse_tensor(self._convert_to_tensor(example[key]))
                for example in examples
            ]

        def pad(tensors, padding_value):
            return pad_sequence(tensors, batch_first=True, padding_value=padding_value)

        input_ids_key = f"{prefix}_input_ids"
        attention_mask_key = f"{prefix}_attention_mask"

        batch = {}
        batch_attention_mask = [
            self._try_reverse_tensor(
                torch.ones_like(
                    self._convert_to_tensor(example[input_ids_key]),
                    dtype=torch.float,
                )
            )
            for example in examples
        ]
        # Pad sequences to the max length in the batch
        batch[input_ids_key] = self._try_reverse_tensor(
            pad(to_tensors(input_ids_key), self.tokenizer.pad_token_id).to(torch.int64)
        )
        batch[attention_mask_key] = self._try_reverse_tensor(
            pad(batch_attention_mask, 0.0)
        )
        assert batch[input_ids_key].shape[1] <= self.max_length
        assert batch[attention_mask_key].shape[1] <= self.max_length

        if self.include_values:
            indicators_key = f"{prefix}_value_indicators"
            # Fixed: this used to read f"{prefix}__values" (double
            # underscore), a key the tokenization mapping never writes.
            values_key = f"{prefix}_values"

            batch[indicators_key] = self._try_reverse_tensor(
                pad(to_tensors(indicators_key), False)
            )
            batch[values_key] = self._try_reverse_tensor(
                pad(to_tensors(values_key), -1.0)
            )
            assert batch[indicators_key].shape[1] <= self.max_length
            assert batch[values_key].shape[1] <= self.max_length
        return batch

    def __call__(self, examples):
        """Collate ``examples`` into one dict holding both sides of the pair."""
        batch_chosen = self.create_preference_inputs(examples, "chosen")
        batch_rejected = self.create_preference_inputs(examples, "rejected")
        batch_chosen.update(batch_rejected)
        return batch_chosen
@@ -1,61 +0,0 @@
1
- import copy
2
- from typing import Any, Dict
3
-
4
- import numpy as np
5
- from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import DatasetMapping
6
-
7
- from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
8
-
9
-
10
class HFCehrGptDPOTokenizationMapping(DatasetMapping):
    """Tokenize the ``chosen`` and ``rejected`` halves of a preference record."""

    def __init__(
        self,
        concept_tokenizer: CehrGptTokenizer,
    ):
        self._concept_tokenizer = concept_tokenizer
        self._lab_token_ids = self._concept_tokenizer.lab_token_ids

    def transform_with_prefix(self, record: Dict[str, Any], prefix) -> Dict[str, Any]:
        """Encode ``{prefix}_concept_ids`` and normalize values for lab tokens.

        The record is modified in place and returned. When the record has a
        ``{prefix}_concept_value_masks`` field, the masks and (possibly
        normalized) values are also exposed as ``{prefix}_value_indicators``
        and ``{prefix}_values``.
        """
        masks_key = f"{prefix}_concept_value_masks"
        values_key = f"{prefix}_concept_values"

        concept_ids = record[f"{prefix}_concept_ids"]
        token_ids = self._concept_tokenizer.encode(concept_ids)
        record[f"{prefix}_input_ids"] = token_ids

        if masks_key not in record:
            return record

        value_masks = record[masks_key]
        raw_values = record[values_key]
        # Normalization only runs when at least one concept carries a value.
        if np.any(np.asarray(value_masks) > 0):
            units = record[f"{prefix}_units"]
            normalized = copy.deepcopy(raw_values)
            aligned = zip(concept_ids, units, token_ids, value_masks, raw_values)
            for position, (concept_id, unit, token_id, _mask, raw_value) in enumerate(
                aligned
            ):
                if token_id not in self._lab_token_ids:
                    continue
                normalized[position] = self._concept_tokenizer.normalize(
                    concept_id, unit, raw_value
                )
            record[values_key] = normalized

        # Overwrite the column names
        record[f"{prefix}_value_indicators"] = record[masks_key]
        record[f"{prefix}_values"] = record[values_key]
        return record

    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Apply the prefixed transform to the chosen side, then the rejected side."""
        with_chosen = self.transform_with_prefix(record, prefix="chosen")
        with_rejected = self.transform_with_prefix(with_chosen, prefix="rejected")
        with_chosen.update(with_rejected)
        return with_chosen
@@ -1,224 +0,0 @@
1
- import datetime
2
- import os
3
- import random
4
- import uuid
5
-
6
- import pandas as pd
7
- import torch
8
- from cehrbert.runners.runner_util import load_parquet_as_dataset
9
- from transformers.utils import is_flash_attn_2_available, logging
10
-
11
- from cehrgpt.cehrgpt_args import create_inference_base_arg_parser
12
- from cehrgpt.generation.generate_batch_hf_gpt_sequence import (
13
- generate_single_batch,
14
- normalize_value,
15
- )
16
- from cehrgpt.gpt_utils import get_cehrgpt_output_folder
17
- from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
18
- from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
19
-
20
- LOG = logging.get_logger("transformers")
21
-
22
-
23
def main(args):
    """Generate (chosen, rejected) sequence pairs and write them as parquet.

    For every batch a real patient sequence is sampled as the ``chosen``
    response, a prefix of it is used as the prompt, and the model's
    continuation is stored as the ``rejected`` response. Pairs are buffered
    and flushed to ``<output_folder>/<run folder>/generated_sequences``.

    Args:
        args: Parsed command-line namespace (see ``create_arg_parser``).

    Raises:
        ValueError: If the dataset at ``args.sequence_data_path`` is empty.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Use flash attention with bfloat16 when available, eager/float32 otherwise.
    use_flash_attn = is_flash_attn_2_available()
    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_folder)
    cehrgpt_model = (
        CEHRGPT2LMHeadModel.from_pretrained(
            args.model_folder,
            attn_implementation="flash_attention_2" if use_flash_attn else "eager",
            torch_dtype=torch.bfloat16 if use_flash_attn else torch.float32,
        )
        .eval()
        .to(device)
    )
    cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
    cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
    cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id

    folder_name = get_cehrgpt_output_folder(args, cehrgpt_tokenizer)
    output_folder_name = os.path.join(
        args.output_folder, folder_name, "generated_sequences"
    )
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_folder_name, exist_ok=True)

    # Fixed: the tokenizer is loaded from tokenizer_folder, not model_folder.
    LOG.info(f"Loading tokenizer at {args.tokenizer_folder}")
    LOG.info(f"Loading model at {args.model_folder}")
    LOG.info(f"Write sequences to {output_folder_name}")
    LOG.info(f"Context window {args.context_window}")
    LOG.info(f"Temperature {args.temperature}")
    LOG.info(f"Repetition Penalty {args.repetition_penalty}")
    LOG.info(f"Sampling Strategy {args.sampling_strategy}")
    LOG.info(f"Num beam {args.num_beams}")
    LOG.info(f"Num beam groups {args.num_beam_groups}")
    LOG.info(f"Epsilon cutoff {args.epsilon_cutoff}")
    LOG.info(f"Top P {args.top_p}")
    LOG.info(f"Top K {args.top_k}")
    LOG.info(f"Loading sequence_data_path at {args.sequence_data_path}")

    output_columns = [
        "person_id",
        "chosen_concept_ids",
        "chosen_concept_values",
        "chosen_concept_value_masks",
        "chosen_units",
        "prompt_length",
        "rejected_concept_ids",
        "rejected_concept_values",
        "rejected_concept_value_masks",
        "rejected_units",
    ]

    def flush(sequences, file_suffix=""):
        # Each flush goes to a uniquely named parquet file.
        pd.DataFrame(sequences, columns=output_columns).to_parquet(
            os.path.join(output_folder_name, f"{uuid.uuid4()}{file_suffix}.parquet")
        )

    dataset = load_parquet_as_dataset(args.sequence_data_path)
    total_rows = len(dataset)
    if total_rows == 0:
        raise ValueError(
            f"No sequences found at {args.sequence_data_path}; cannot sample prompts"
        )
    num_of_batches = args.num_of_patients // args.batch_size + 1
    sequence_to_flush = []
    for i in range(num_of_batches):
        LOG.info(f"{datetime.datetime.now()}: Batch {i} started")
        sample_data = []
        # NOTE(review): one row is drawn per batch regardless of
        # args.batch_size, and this loop does not terminate if no row has a
        # length within [4, n_positions] -- confirm that is intended.
        while len(sample_data) == 0:
            random_indices = random.sample(range(total_rows), k=1)
            for row in dataset.select(random_indices):
                if 4 <= len(row["concept_ids"]) <= cehrgpt_model.config.n_positions:
                    sample_data.append(row)
        prompts = []
        chosen_responses = []
        cutoff_frac = random.uniform(0, args.cutoff_frac_max)
        for row in sample_data:
            seq_len = len(row["concept_ids"])
            # Always keep at least the first four tokens as the prompt.
            prompt_len = max(4, int(seq_len * cutoff_frac))
            prompts.append(cehrgpt_tokenizer.encode(row["concept_ids"][:prompt_len]))
            chosen_responses.append(
                {
                    "person_id": row["person_id"],
                    "chosen_concept_ids": (
                        row["concept_ids"] if "concept_ids" in row else None
                    ),
                    "chosen_concept_values": (
                        row["concept_values"] if "concept_values" in row else None
                    ),
                    "chosen_concept_value_masks": (
                        row["concept_value_masks"]
                        if "concept_value_masks" in row
                        else None
                    ),
                    "chosen_units": row["units"] if "units" in row else None,
                    "prompt_length": prompt_len,
                }
            )

        batch_sequences = generate_single_batch(
            cehrgpt_model,
            cehrgpt_tokenizer,
            prompts=prompts,
            max_new_tokens=args.context_window,
            mini_num_of_concepts=args.min_num_of_concepts,
            top_p=args.top_p,
            top_k=args.top_k,
            temperature=args.temperature,
            repetition_penalty=args.repetition_penalty,
            num_beams=args.num_beams,
            num_beam_groups=args.num_beam_groups,
            epsilon_cutoff=args.epsilon_cutoff,
            device=device,
        )

        # Clear the cache
        torch.cuda.empty_cache()

        for seq, value_indicator, value, chosen_response in zip(
            batch_sequences["sequences"],
            batch_sequences["value_indicators"],
            batch_sequences["values"],
            chosen_responses,
        ):
            output = {"rejected_concept_ids": seq}
            normalized_values, units = normalize_value(
                seq, value_indicator, value, cehrgpt_tokenizer
            )
            if normalized_values is not None:
                output["rejected_concept_values"] = normalized_values
            if value_indicator is not None:
                output["rejected_concept_value_masks"] = value_indicator
            if units is not None:
                output["rejected_units"] = units
            output.update(chosen_response)
            sequence_to_flush.append(output)

        if len(sequence_to_flush) >= args.buffer_size:
            LOG.info(f"{datetime.datetime.now()}: Flushing to the Disk at Batch {i}")
            flush(sequence_to_flush)
            sequence_to_flush.clear()

    if len(sequence_to_flush) > 0:
        LOG.info(f"{datetime.datetime.now()}: Flushing to the Disk at Final Batch")
        flush(sequence_to_flush, file_suffix="-last")
182
-
183
-
184
def create_arg_parser():
    """Return the base inference parser extended with the pairing options."""
    parser = create_inference_base_arg_parser(
        description="Arguments for generating paired patient sequences"
    )
    # (flag, add_argument keyword options), registered in this order.
    extra_options = (
        (
            "--num_of_patients",
            {
                "dest": "num_of_patients",
                "action": "store",
                "type": int,
                "help": "The number of patients that will be generated",
                "required": True,
            },
        ),
        (
            "--sequence_data_path",
            {
                "dest": "sequence_data_path",
                "action": "store",
                "help": "The path for your sequence data",
                "required": True,
            },
        ),
        (
            "--cutoff_frac_max",
            {
                "dest": "cutoff_frac_max",
                "action": "store",
                "type": float,
                "help": "The max fraction of the patient sequences that will be used for prompting",
                "required": False,
                "default": 0.5,
            },
        ),
        (
            "--num_proc",
            {
                "dest": "num_proc",
                "action": "store",
                "type": int,
                "required": False,
                "default": 1,
            },
        ),
    )
    for flag, options in extra_options:
        parser.add_argument(flag, **options)
    return parser
221
-
222
-
223
if __name__ == "__main__":
    # Parse the command line first, then hand the namespace to main().
    cli_args = create_arg_parser().parse_args()
    main(cli_args)