cehrgpt 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/data/cehrgpt_data_processor.py +6 -5
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +14 -0
- {cehrgpt-0.1.4.dist-info → cehrgpt-0.1.6.dist-info}/METADATA +6 -4
- {cehrgpt-0.1.4.dist-info → cehrgpt-0.1.6.dist-info}/RECORD +7 -7
- {cehrgpt-0.1.4.dist-info → cehrgpt-0.1.6.dist-info}/WHEEL +0 -0
- {cehrgpt-0.1.4.dist-info → cehrgpt-0.1.6.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.4.dist-info → cehrgpt-0.1.6.dist-info}/top_level.txt +0 -0
@@ -275,11 +275,7 @@ class CehrGptDataProcessor(DatasetMapping):
|
|
275
275
|
if demographic_tokens is not None
|
276
276
|
else self.empty_array
|
277
277
|
),
|
278
|
-
np.asarray(
|
279
|
-
self._convert_time_to_event(
|
280
|
-
record["concept_ids"][start_index:end_index]
|
281
|
-
)
|
282
|
-
),
|
278
|
+
np.asarray(record["time_to_visits"][start_index:end_index]),
|
283
279
|
np.asarray([-100.0]) if add_last_token else self.empty_array,
|
284
280
|
]
|
285
281
|
).astype(np.float32)
|
@@ -303,6 +299,11 @@ class CehrGptDataProcessor(DatasetMapping):
|
|
303
299
|
record["concept_ids"], record["epoch_times"]
|
304
300
|
)
|
305
301
|
|
302
|
+
if self.include_ttv_prediction:
|
303
|
+
record["time_to_visits"] = np.asarray(
|
304
|
+
self._convert_time_to_event(record["concept_ids"])
|
305
|
+
)
|
306
|
+
|
306
307
|
# Return the record directly if the actual sequence length is less than the max sequence
|
307
308
|
if seq_length <= new_max_length:
|
308
309
|
# We only add [END] to the end of the sequence in pre-training
|
@@ -528,6 +528,7 @@ class SamplePackingCehrGptDataCollator(CehrGptDataCollator):
|
|
528
528
|
current_epoch_times = []
|
529
529
|
current_value_indicators = []
|
530
530
|
current_values = []
|
531
|
+
current_time_to_visits = []
|
531
532
|
|
532
533
|
# MOTOR inputs
|
533
534
|
current_motor_censor_times = []
|
@@ -567,6 +568,16 @@ class SamplePackingCehrGptDataCollator(CehrGptDataCollator):
|
|
567
568
|
)
|
568
569
|
current_epoch_times.extend(epoch_times + [max(epoch_times)])
|
569
570
|
|
571
|
+
if self.include_ttv_prediction:
|
572
|
+
current_time_to_visits.extend(
|
573
|
+
(
|
574
|
+
example["time_to_visits"].tolist()
|
575
|
+
if isinstance(example["time_to_visits"], torch.Tensor)
|
576
|
+
else list(example["time_to_visits"])
|
577
|
+
)
|
578
|
+
+ [-100]
|
579
|
+
)
|
580
|
+
|
570
581
|
if self.include_values:
|
571
582
|
current_value_indicators.extend(
|
572
583
|
(
|
@@ -649,6 +660,9 @@ class SamplePackingCehrGptDataCollator(CehrGptDataCollator):
|
|
649
660
|
"epoch_times": current_epoch_times,
|
650
661
|
}
|
651
662
|
|
663
|
+
if self.include_ttv_prediction:
|
664
|
+
packed_example.update({"time_to_visits": current_time_to_visits})
|
665
|
+
|
652
666
|
if self.include_values:
|
653
667
|
packed_example.update(
|
654
668
|
{"value_indicators": current_value_indicators, "values": current_values}
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: cehrgpt
|
3
|
-
Version: 0.1.4
|
3
|
+
Version: 0.1.6
|
4
4
|
Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
|
5
5
|
Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
|
6
6
|
License: MIT License
|
@@ -12,8 +12,8 @@ Classifier: Programming Language :: Python :: 3
|
|
12
12
|
Requires-Python: >=3.10.0
|
13
13
|
Description-Content-Type: text/markdown
|
14
14
|
License-File: LICENSE
|
15
|
-
Requires-Dist: cehrbert
|
16
|
-
Requires-Dist: cehrbert_data
|
15
|
+
Requires-Dist: cehrbert==1.4.8
|
16
|
+
Requires-Dist: cehrbert_data==0.1.1
|
17
17
|
Requires-Dist: openai==1.54.3
|
18
18
|
Requires-Dist: optuna==4.0.0
|
19
19
|
Requires-Dist: transformers==4.44.1
|
@@ -104,7 +104,9 @@ python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
|
|
104
104
|
--sample_packing --max_tokens_per_batch 16384 \
|
105
105
|
--warmup_ratio 0.01 --weight_decay 0.01 \
|
106
106
|
--num_train_epochs 50 --learning_rate 0.0002 \
|
107
|
-
--use_early_stopping
|
107
|
+
--use_early_stopping \
|
108
|
+
--load_best_model_at_end true \
|
109
|
+
--early_stopping_threshold 0.001
|
108
110
|
```
|
109
111
|
|
110
112
|
> **Tip**: Increase `max_position_embeddings` for longer context windows based on your use case.
|
@@ -18,9 +18,9 @@ cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py,sha256=qwAtJ3KV
|
|
18
18
|
cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py,sha256=0bsEE1VFIxzU33bSipM30p2fnHsWjGWWcu59y_38K3c,2870
|
19
19
|
cehrgpt/analysis/treatment_pathway/treatment_pathway.py,sha256=SCWphYH9ARa4ZKB9fgBYM9RC2Hc8PDwtoHHCX7th16Q,25496
|
20
20
|
cehrgpt/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
21
|
-
cehrgpt/data/cehrgpt_data_processor.py,sha256=
|
21
|
+
cehrgpt/data/cehrgpt_data_processor.py,sha256=27k_LsNs6m9M9uiaWKybNCp0d7aE2BxNmCVFW75EFo4,23143
|
22
22
|
cehrgpt/data/hf_cehrgpt_dataset.py,sha256=uz05TG5QCl3_Ybn9zZyWRg0pEbiAvL1yPWXK3BGsj0Q,3815
|
23
|
-
cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=
|
23
|
+
cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=sI_cszVCI7WeIYAcOfDV-IZFiHrZDJCLN3Hb5W-X72E,28467
|
24
24
|
cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=-Igd-P-yvYlJXGZSGlYHRnez464NCkZIko3boQDYS1E,27638
|
25
25
|
cehrgpt/data/sample_packing_sampler.py,sha256=vovGMtmhG70DRkSCeiaDEJ_rjKZ38y-YLaI1kkhFEkI,6747
|
26
26
|
cehrgpt/generation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -78,8 +78,8 @@ cehrgpt/tools/upload_omop_tables.py,sha256=vdBAbkeAsGPA4NsyhNjelPVj3gS8yzmS1sKNM
|
|
78
78
|
cehrgpt/tools/linear_prob/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
79
79
|
cehrgpt/tools/linear_prob/compute_cehrgpt_features.py,sha256=0i34zAwePG0hZK2HSDaUlO-Fzyb5K4LqRuhrCVWivxA,19906
|
80
80
|
cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py,sha256=w0UvzMKYGenN_KDVnbzutmy8IPLUxW5hPvpKKxDSL5U,5820
|
81
|
-
cehrgpt-0.1.4.dist-info/licenses/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
|
82
|
-
cehrgpt-0.1.
|
83
|
-
cehrgpt-0.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
84
|
-
cehrgpt-0.1.4.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
|
85
|
-
cehrgpt-0.1.4.dist-info/RECORD,,
|
81
|
+
cehrgpt-0.1.6.dist-info/licenses/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
|
82
|
+
cehrgpt-0.1.6.dist-info/METADATA,sha256=44PUxaHLJ6us2MvSiqJeXVeG7M-Tr9DUzkZAwX2GoyM,10204
|
83
|
+
cehrgpt-0.1.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
84
|
+
cehrgpt-0.1.6.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
|
85
|
+
cehrgpt-0.1.6.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|