cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
- cehrgpt/data/sample_packing_sampler.py +181 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +35 -0
- cehrgpt/models/hf_cehrgpt.py +470 -106
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
- cehrgpt/runners/data_utils.py +358 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
- cehrgpt/runners/hyperparameter_search_util.py +10 -8
- cehrgpt/runners/sample_packing_trainer.py +185 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
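Most of what follows is the diff of the wheel's RECORD file, which lists every installed file as `path,sha256=<digest>,<size>`; the digest is the urlsafe-base64 SHA-256 of the file contents with the trailing `=` padding stripped, per the wheel spec. Several removed-side entries were truncated by the diff viewer and are reproduced as-is. A minimal sketch of the hashing convention:

import base64
import hashlib

def record_hash(data: bytes) -> str:
    # Wheel RECORD convention: urlsafe base64 of the sha256 digest,
    # with the trailing "=" padding stripped.
    digest = hashlib.sha256(data).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")

# Every empty __init__.py shares the same entry: the hash of zero bytes.
assert record_hash(b"") == "47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU"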
{cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD
@@ -1,8 +1,9 @@
 __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cehrgpt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cehrgpt/cehrgpt_args.py,sha256=zPLp9Qjlq5PapWx3R15BNnyaX8zV3dxr4PuWj71r0Lg,3516
-cehrgpt/gpt_utils.py,sha256=
+cehrgpt/gpt_utils.py,sha256=IA5qw-hxcKkGO07AB47lDNRU6mlb9jblpKO7KeLLN78,11342
 cehrgpt/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+cehrgpt/analysis/irregularity.py,sha256=Rfl_daMvSh9cZ68vUwfmuH-JYCFXdAph2ITHHffYC0Y,1047
 cehrgpt/analysis/privacy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cehrgpt/analysis/privacy/attribute_inference.py,sha256=0ANVW0I5uvOl6IxQ15-vMVQd0mugOgSGReBUQQESImg,9368
 cehrgpt/analysis/privacy/attribute_inference_config.yml,sha256=hfLfpBlDqqsNOynpRHK414vV24edKA6ta-inmEhM2ao,103272
@@ -11,24 +12,22 @@ cehrgpt/analysis/privacy/nearest_neighbor_inference.py,sha256=qoJgWW7VsUMzjMGpTa
 cehrgpt/analysis/privacy/reid_inference.py,sha256=Pypd3QJXQNY8VljpnIEa5zeAbTZHMjQOazaL-9VsBGw,13955
 cehrgpt/analysis/privacy/utils.py,sha256=CRA4H9mPLBjMQGKzZ_x_3ro3tMap-NjsMDVqSOjHSVQ,8226
 cehrgpt/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cehrgpt/data/hf_cehrgpt_dataset.py,sha256=
-cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=
-cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=
-cehrgpt/data/
-cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py,sha256=uCLF5VEsyZAG1aNwqEM6Jy5Lx7bI5ALku52Z6Anine0,2574
+cehrgpt/data/hf_cehrgpt_dataset.py,sha256=hwJlGW7XiJIr6cXtmwvReQf9yLZJPD-dvJGvRg5ERqU,3755
+cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=ACMXiaYnR3bKD5dRleL0_siEvhL-2HAFcy5eBgvxnH4,44412
+cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=KU0WMjc2vT1zBAl7JJkOc8dgGxsL1uFDy4dDrv-RkII,25668
+cehrgpt/data/sample_packing_sampler.py,sha256=vovGMtmhG70DRkSCeiaDEJ_rjKZ38y-YLaI1kkhFEkI,6747
 cehrgpt/generation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cehrgpt/generation/chatgpt_generation.py,sha256=SrnLwHLdNtnAOEg36gNjqfoT9yd12iyPgpZffL2AFJo,4428
-cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256
-cehrgpt/generation/
-cehrgpt/generation/omop_converter_batch.py,sha256=-c0AlDVy5pJ5Afhr8ERiCHhoRrEk8ozJi3g0yFdWaMI,25348
+cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256=uSEh8aMmPD61nGewIaPSkIqm-2AxDjCBiu4cBfxHxU4,11503
+cehrgpt/generation/omop_converter_batch.py,sha256=LUmCD-t_6ZP1YfNDZCqYewl-XIIaIgRZ_dAxuR_VdCQ,26275
 cehrgpt/generation/omop_entity.py,sha256=Q5Sr0AlyuPAm1FRPfnJO13q-u1fqRgYVHXruZ9g4xNE,19400
 cehrgpt/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cehrgpt/models/config.py,sha256=
-cehrgpt/models/hf_cehrgpt.py,sha256=
-cehrgpt/models/hf_modeling_outputs.py,sha256=
+cehrgpt/models/config.py,sha256=nOAKgH5420HLCcy7n1hE7MbqR861Iq4DTutKoAd25tg,11090
+cehrgpt/models/hf_cehrgpt.py,sha256=77CAkdMPgxD4xSpFU7gYGzRn6_Iv-4q7FnHpnZGsKxw,92450
+cehrgpt/models/hf_modeling_outputs.py,sha256=5X4WEYKqT37phv_e5ZAv3A_N0wqdAUJLJRm6TxS6dDQ,10356
 cehrgpt/models/pretrained_embeddings.py,sha256=vLLVs17TLpXRqCVEWQxGGwPHkUJUO7laNTeBuyBK_yk,3238
-cehrgpt/models/special_tokens.py,sha256
-cehrgpt/models/tokenization_hf_cehrgpt.py,sha256=
+cehrgpt/models/special_tokens.py,sha256=lrw45B4tea4Dsajn09Cz6w5D2TfHmYXikZkgwnstu_o,521
+cehrgpt/models/tokenization_hf_cehrgpt.py,sha256=cAxHTctpVBxfWfC3XcwDQavN1zwWN9Nid_Fajd5zQWQ,53159
 cehrgpt/omop/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cehrgpt/omop/condition_era.py,sha256=hPZALz2XaWnro_1bwIYNkI48foOJjueyg3CZ1BliCno,626
 cehrgpt/omop/observation_period.py,sha256=TRMgv5Ya2RaS2im7oQ6BLC_5JL9EJYNYR62ApxIuHvg,1211
@@ -38,22 +37,24 @@ cehrgpt/omop/sample_omop_tables.py,sha256=2JZ8BNSvssceinwFanvuCRh-YlKrKn25U9w1pL
 cehrgpt/omop/queries/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cehrgpt/omop/queries/condition_era.py,sha256=LFB6vBAvshHJxtYIRkl7cfrF0kf7ay0piBKpmHBwrpE,2578
 cehrgpt/omop/queries/observation_period.py,sha256=fpzr5DMNw-QLoSwp2Iatfch88E3hyhZ75usiIdG3A0U,6410
-cehrgpt/rl_finetune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py,sha256=VQHf5vy5i8K1imcqYakhitfAW-d2mnaEzkSoAYSW5kg,26062
-cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py,sha256=nYWYPCaWNjDGEwlo6UHOK1rvOZUx1vuJ8kYuAszI8Zg,17925
-cehrgpt/rl_finetune/ppo_finetune.py,sha256=tSy-C0Kzgj5ffclBIDj-RTj78ZfrLmTESxVxd0n9yuE,13971
-cehrgpt/rl_finetune/ppo_finetune_v2.py,sha256=7dChwKpq4zKmpkcxP4hryqBoIkcwmTJ44_BF8R2RghQ,13285
 cehrgpt/runners/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cehrgpt/runners/
-cehrgpt/runners/
-cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=
-cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=
-cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=
-cehrgpt/runners/hyperparameter_search_util.py,sha256=
+cehrgpt/runners/data_utils.py,sha256=I6k1TkiiZR8ggw3eVO16g2lVPY-Hu3b-nbrIOKlFIO0,15528
+cehrgpt/runners/gpt_runner_util.py,sha256=YJQSRW9Mo4TjXSOUOTf6BUFcs1MGFiXU5T4ztKZcYhU,3485
+cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=GVbHHqf5TWGbVWlQG-XurgYH8pKRjTk8ug_ib9L9U7E,28118
+cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=ERSnvB38fPYVghtKQeNTZ8VfeXnoRcCHB0cWISWaZ84,26523
+cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=ejAFLM9g765p1fyeF5MITsiIeWHKkz9wTeFDeVgxSto,8851
+cehrgpt/runners/hyperparameter_search_util.py,sha256=YWdFQ1igQs-G_wqWUrUzYraGiz8OSpSYyvid-I5nhWA,9262
+cehrgpt/runners/sample_packing_trainer.py,sha256=Zb7Aqwnk8-VqrjEKUVeg5XzZWmHxXOU2sDn1YURS-FU,7960
+cehrgpt/simulations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+cehrgpt/simulations/generate_plots.py,sha256=BTZ71r8Kah0PMorkiO3vw55_p_9U1Z8KiD3GsPfaV0s,2520
+cehrgpt/simulations/run_simulation.sh,sha256=DcJ6B19jIteUO0pZ0Tc21876lB9XxQHFAxlre7MtAzk,795
+cehrgpt/simulations/time_embedding_simulation.py,sha256=HZ-imXH-bN-QYZN1PAfcERmNtaWIwKjbf0UrZduwCiA,8687
+cehrgpt/simulations/time_token_simulation.py,sha256=sLg8vVXydvR_zk3BbqyrlA7sDIdhFnS-s5pSKcCilSc,6057
 cehrgpt/time_to_event/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cehrgpt/time_to_event/time_to_event_model.py,sha256=
-cehrgpt/time_to_event/time_to_event_prediction.py,sha256=
+cehrgpt/time_to_event/time_to_event_model.py,sha256=Plm0bZxvlAbnMl82DTBXWvaXLvrqcdkzcP_celX8WC4,8055
+cehrgpt/time_to_event/time_to_event_prediction.py,sha256=W2e7UqIV7ELdfTy997HS66vggjnhdncCKt840knI0Dw,13183
 cehrgpt/time_to_event/time_to_event_utils.py,sha256=KN4hwGgxy2nJtO7osbYQBF3-HpmGUWefNfexzPYiEwc,1937
+cehrgpt/time_to_event/config/1_year_cabg.yaml,sha256=SFF2-F5D02pDSMRddDrEUoERBCd0t2Hzln_xC-Mo2hA,407
 cehrgpt/time_to_event/config/30_day_readmission.yaml,sha256=Hn5KnEXMtSV_CtCpmAU4wjkc0-gTXvniaH991TSbUXA,234
 cehrgpt/time_to_event/config/next_visit_type_prediction.yaml,sha256=WMj2ZutEvHKIMyGG51xtXaL6MyRANKvpg9xT8ouctLc,319
 cehrgpt/time_to_event/config/t2dm_hf.yaml,sha256=_oMQzh2eJTYzEaMOpmhAzbX-qmdsKlkORELL6HxOxHo,202
@@ -63,8 +64,11 @@ cehrgpt/tools/generate_causal_patient_split_by_age.py,sha256=dmHiPAL_kR1WrhRteIi
 cehrgpt/tools/generate_pretrained_embeddings.py,sha256=lhFSacGv8bMld6qigKZN8Op8eXpFi0DsJuQbWKOWXqI,4160
 cehrgpt/tools/merge_synthetic_real_dataasets.py,sha256=O1dbQ32Le0t15fwymwAh9mfNVLEWuFwW53DNvESrWbY,7589
 cehrgpt/tools/upload_omop_tables.py,sha256=vdBAbkeAsGPA4NsyhNjelPVj3gS8yzmS1sKNM1Qk96g,3791
-cehrgpt
-cehrgpt
-cehrgpt
-cehrgpt-0.
-cehrgpt-0.
+cehrgpt/tools/linear_prob/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+cehrgpt/tools/linear_prob/compute_cehrgpt_features.py,sha256=q0rmlBWDDEkjHjwcTouGUhCYa32a1vRicaDOAMsdW0I,20741
+cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py,sha256=w0UvzMKYGenN_KDVnbzutmy8IPLUxW5hPvpKKxDSL5U,5820
+cehrgpt-0.1.1.dist-info/licenses/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
+cehrgpt-0.1.1.dist-info/METADATA,sha256=VnXH74vJQZaV7VxGiIvJnFhQA0jzJQNx86yHFkygobM,4922
+cehrgpt-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cehrgpt-0.1.1.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
+cehrgpt-0.1.1.dist-info/RECORD,,
cehrgpt/data/hf_cehrgpt_dpo_collator.py
@@ -1,71 +0,0 @@
-import torch
-from torch.nn.utils.rnn import pad_sequence
-
-from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
-
-
-class CehrGptDPODataCollator(CehrGptDataCollator):
-
-    def create_preference_inputs(self, examples, prefix):
-        batch = {}
-        # Assume that each example in the batch is a dictionary with 'input_ids' and 'attention_mask'
-        batch_input_ids = [
-            self._try_reverse_tensor(
-                self._convert_to_tensor(example[f"{prefix}_input_ids"])
-            )
-            for example in examples
-        ]
-        batch_attention_mask = [
-            self._try_reverse_tensor(
-                torch.ones_like(
-                    self._convert_to_tensor(example[f"{prefix}_input_ids"]),
-                    dtype=torch.float,
-                )
-            )
-            for example in examples
-        ]
-        # Pad sequences to the max length in the batch
-        batch[f"{prefix}_input_ids"] = self._try_reverse_tensor(
-            pad_sequence(
-                batch_input_ids,
-                batch_first=True,
-                padding_value=self.tokenizer.pad_token_id,
-            ).to(torch.int64)
-        )
-        batch[f"{prefix}_attention_mask"] = self._try_reverse_tensor(
-            pad_sequence(batch_attention_mask, batch_first=True, padding_value=0.0)
-        )
-        assert batch[f"{prefix}_input_ids"].shape[1] <= self.max_length
-        assert batch[f"{prefix}_attention_mask"].shape[1] <= self.max_length
-
-        if self.include_values:
-            batch_value_indicators = [
-                self._try_reverse_tensor(
-                    self._convert_to_tensor(example[f"{prefix}_value_indicators"])
-                )
-                for example in examples
-            ]
-            batch_values = [
-                self._try_reverse_tensor(
-                    self._convert_to_tensor(example[f"{prefix}__values"])
-                )
-                for example in examples
-            ]
-
-            batch[f"{prefix}_value_indicators"] = self._try_reverse_tensor(
-                pad_sequence(
-                    batch_value_indicators, batch_first=True, padding_value=False
-                )
-            )
-            batch[f"{prefix}_values"] = self._try_reverse_tensor(
-                pad_sequence(batch_values, batch_first=True, padding_value=-1.0)
-            )
-            assert batch[f"{prefix}_value_indicators"].shape[1] <= self.max_length
-            assert batch[f"{prefix}_values"].shape[1] <= self.max_length
-        return batch
-
-    def __call__(self, examples):
-        batch_chosen = self.create_preference_inputs(examples, "chosen")
-        batch_rejected = self.create_preference_inputs(examples, "rejected")
-        batch_chosen.update(batch_rejected)
-        return batch_chosen
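The deleted collator wraps every tensor in `_try_reverse_tensor` before and after `pad_sequence`. Assuming that helper flips a sequence end-for-end (an inference from the surrounding code, not documented behavior), the reverse-pad-reverse round trip is the standard way to get left-padding out of `pad_sequence`, which only right-pads; left-padding keeps the live tokens at the right edge of the context, which is what a decoder-only model needs at generation time. A self-contained sketch of the trick:

import torch
from torch.nn.utils.rnn import pad_sequence

def left_pad(sequences, padding_value):
    # pad_sequence can only right-pad, so flip each sequence,
    # right-pad the flipped copies, then flip the padded rows back.
    flipped = [seq.flip(0) for seq in sequences]
    padded = pad_sequence(flipped, batch_first=True, padding_value=padding_value)
    return padded.flip(1)

print(left_pad([torch.tensor([1, 2, 3]), torch.tensor([4])], padding_value=0))
# tensor([[1, 2, 3],
#         [0, 0, 4]])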
cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py
@@ -1,61 +0,0 @@
-import copy
-from typing import Any, Dict
-
-import numpy as np
-from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import DatasetMapping
-
-from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
-
-
-class HFCehrGptDPOTokenizationMapping(DatasetMapping):
-    def __init__(
-        self,
-        concept_tokenizer: CehrGptTokenizer,
-    ):
-        self._concept_tokenizer = concept_tokenizer
-        self._lab_token_ids = self._concept_tokenizer.lab_token_ids
-
-    def transform_with_prefix(self, record: Dict[str, Any], prefix) -> Dict[str, Any]:
-        concept_ids = record[f"{prefix}_concept_ids"]
-        input_ids = self._concept_tokenizer.encode(concept_ids)
-        record[f"{prefix}_input_ids"] = input_ids
-
-        if f"{prefix}_concept_value_masks" in record:
-            concept_value_masks = record[f"{prefix}_concept_value_masks"]
-            concept_values = record[f"{prefix}_concept_values"]
-            # If any concept has a value associated with it, we normalize the value
-            if np.any(np.asarray(concept_value_masks) > 0):
-                units = record[f"{prefix}_units"]
-                normalized_concept_values = copy.deepcopy(concept_values)
-                for i, (
-                    concept_id,
-                    unit,
-                    token_id,
-                    concept_value_mask,
-                    concept_value,
-                ) in enumerate(
-                    zip(
-                        concept_ids,
-                        units,
-                        input_ids,
-                        concept_value_masks,
-                        concept_values,
-                    )
-                ):
-                    if token_id in self._lab_token_ids:
-                        normalized_concept_value = self._concept_tokenizer.normalize(
-                            concept_id, unit, concept_value
-                        )
-                        normalized_concept_values[i] = normalized_concept_value
-                record[f"{prefix}_concept_values"] = normalized_concept_values
-            # Overwrite the column names
-            record[f"{prefix}_value_indicators"] = record[
-                f"{prefix}_concept_value_masks"
-            ]
-            record[f"{prefix}_values"] = record[f"{prefix}_concept_values"]
-        return record
-
-    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
-        record = self.transform_with_prefix(record, prefix="chosen")
-        record.update(self.transform_with_prefix(record, prefix="rejected"))
-        return record
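The deleted mapping ran the same transformation twice, once per preference side, producing `chosen_input_ids` and `rejected_input_ids` columns next to the original concept-id columns. A stripped-down restatement of that core step (the toy vocabulary and `encode` stand-in are illustrative, and the lab-value normalization branch is omitted):

from typing import Any, Callable, Dict, List

def tokenize_prefix(
    record: Dict[str, Any], prefix: str, encode: Callable[[List[str]], List[int]]
) -> Dict[str, Any]:
    # Mirrors the removed mapping: encode "{prefix}_concept_ids"
    # into "{prefix}_input_ids".
    record[f"{prefix}_input_ids"] = encode(record[f"{prefix}_concept_ids"])
    return record

vocab = {"VS": 0, "320128": 1, "9201": 2, "VE": 3}  # toy vocabulary
record = {
    "chosen_concept_ids": ["VS", "320128", "VE"],
    "rejected_concept_ids": ["VS", "9201", "VE"],
}
for prefix in ("chosen", "rejected"):
    record = tokenize_prefix(record, prefix, lambda ids: [vocab[c] for c in ids])
print(record["chosen_input_ids"], record["rejected_input_ids"])  # [0, 1, 3] [0, 2, 3]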
cehrgpt/generation/generate_paired_cehrgpt_sequence.py
@@ -1,224 +0,0 @@
-import datetime
-import os
-import random
-import uuid
-
-import pandas as pd
-import torch
-from cehrbert.runners.runner_util import load_parquet_as_dataset
-from transformers.utils import is_flash_attn_2_available, logging
-
-from cehrgpt.cehrgpt_args import create_inference_base_arg_parser
-from cehrgpt.generation.generate_batch_hf_gpt_sequence import (
-    generate_single_batch,
-    normalize_value,
-)
-from cehrgpt.gpt_utils import get_cehrgpt_output_folder
-from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
-from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
-
-LOG = logging.get_logger("transformers")
-
-
-def main(args):
-    if torch.cuda.is_available():
-        device = torch.device("cuda")
-    else:
-        device = torch.device("cpu")
-
-    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_folder)
-    cehrgpt_model = (
-        CEHRGPT2LMHeadModel.from_pretrained(
-            args.model_folder,
-            attn_implementation=(
-                "flash_attention_2" if is_flash_attn_2_available() else "eager"
-            ),
-            torch_dtype=(
-                torch.bfloat16 if is_flash_attn_2_available() else torch.float32
-            ),
-        )
-        .eval()
-        .to(device)
-    )
-    cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
-    cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
-    cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id
-
-    folder_name = get_cehrgpt_output_folder(args, cehrgpt_tokenizer)
-    output_folder_name = os.path.join(
-        args.output_folder, folder_name, "generated_sequences"
-    )
-
-    if not os.path.exists(output_folder_name):
-        os.makedirs(output_folder_name)
-
-    LOG.info(f"Loading tokenizer at {args.model_folder}")
-    LOG.info(f"Loading model at {args.model_folder}")
-    LOG.info(f"Write sequences to {output_folder_name}")
-    LOG.info(f"Context window {args.context_window}")
-    LOG.info(f"Temperature {args.temperature}")
-    LOG.info(f"Repetition Penalty {args.repetition_penalty}")
-    LOG.info(f"Sampling Strategy {args.sampling_strategy}")
-    LOG.info(f"Num beam {args.num_beams}")
-    LOG.info(f"Num beam groups {args.num_beam_groups}")
-    LOG.info(f"Epsilon cutoff {args.epsilon_cutoff}")
-    LOG.info(f"Top P {args.top_p}")
-    LOG.info(f"Top K {args.top_k}")
-    LOG.info(f"Loading sequence_data_path at {args.sequence_data_path}")
-
-    dataset = load_parquet_as_dataset(args.sequence_data_path)
-    total_rows = len(dataset)
-    float(args.batch_size) / total_rows
-    num_of_batches = args.num_of_patients // args.batch_size + 1
-    sequence_to_flush = []
-    for i in range(num_of_batches):
-        LOG.info(f"{datetime.datetime.now()}: Batch {i} started")
-        sample_data = []
-        while len(sample_data) == 0:
-            random_indices = random.sample(range(total_rows), k=1)
-            for row in dataset.select(random_indices):
-                if 4 <= len(row["concept_ids"]) <= cehrgpt_model.config.n_positions:
-                    sample_data.append(row)
-        prompts = []
-        chosen_responses = []
-        cutoff_frac = random.uniform(0, args.cutoff_frac_max)
-        for row in sample_data:
-            seq_len = len(row["concept_ids"])
-            prompt_len = max(4, int(seq_len * cutoff_frac))
-            prompts.append(cehrgpt_tokenizer.encode(row["concept_ids"][:prompt_len]))
-            chosen_responses.append(
-                {
-                    "person_id": row["person_id"],
-                    "chosen_concept_ids": (
-                        row["concept_ids"] if "concept_ids" in row else None
-                    ),
-                    "chosen_concept_values": (
-                        row["concept_values"] if "concept_values" in row else None
-                    ),
-                    "chosen_concept_value_masks": (
-                        row["concept_value_masks"]
-                        if "concept_value_masks" in row
-                        else None
-                    ),
-                    "chosen_units": row["units"] if "units" in row else None,
-                    "prompt_length": prompt_len,
-                }
-            )
-
-        batch_sequences = generate_single_batch(
-            cehrgpt_model,
-            cehrgpt_tokenizer,
-            prompts=prompts,
-            max_new_tokens=args.context_window,
-            mini_num_of_concepts=args.min_num_of_concepts,
-            top_p=args.top_p,
-            top_k=args.top_k,
-            temperature=args.temperature,
-            repetition_penalty=args.repetition_penalty,
-            num_beams=args.num_beams,
-            num_beam_groups=args.num_beam_groups,
-            epsilon_cutoff=args.epsilon_cutoff,
-            device=device,
-        )
-
-        # Clear the cache
-        torch.cuda.empty_cache()
-
-        for seq, value_indicator, value, chosen_response in zip(
-            batch_sequences["sequences"],
-            batch_sequences["value_indicators"],
-            batch_sequences["values"],
-            chosen_responses,
-        ):
-            output = {"rejected_concept_ids": seq}
-            normalized_values, units = normalize_value(
-                seq, value_indicator, value, cehrgpt_tokenizer
-            )
-            if normalized_values is not None:
-                output["rejected_concept_values"] = normalized_values
-            if value_indicator is not None:
-                output["rejected_concept_value_masks"] = value_indicator
-            if units is not None:
-                output["rejected_units"] = units
-            output.update(chosen_response)
-            sequence_to_flush.append(output)
-
-        if len(sequence_to_flush) >= args.buffer_size:
-            LOG.info(f"{datetime.datetime.now()}: Flushing to the Disk at Batch {i}")
-            pd.DataFrame(
-                sequence_to_flush,
-                columns=[
-                    "person_id",
-                    "chosen_concept_ids",
-                    "chosen_concept_values",
-                    "chosen_concept_value_masks",
-                    "chosen_units",
-                    "prompt_length",
-                    "rejected_concept_ids",
-                    "rejected_concept_values",
-                    "rejected_concept_value_masks",
-                    "rejected_units",
-                ],
-            ).to_parquet(os.path.join(output_folder_name, f"{uuid.uuid4()}.parquet"))
-            sequence_to_flush.clear()
-
-    if len(sequence_to_flush) > 0:
-        LOG.info(f"{datetime.datetime.now()}: Flushing to the Disk at Final Batch")
-        pd.DataFrame(
-            sequence_to_flush,
-            columns=[
-                "person_id",
-                "chosen_concept_ids",
-                "chosen_concept_values",
-                "chosen_concept_value_masks",
-                "chosen_units",
-                "prompt_length",
-                "rejected_concept_ids",
-                "rejected_concept_values",
-                "rejected_concept_value_masks",
-                "rejected_units",
-            ],
-        ).to_parquet(os.path.join(output_folder_name, f"{uuid.uuid4()}-last.parquet"))
-
-
-def create_arg_parser():
-    base_arg_parser = create_inference_base_arg_parser(
-        description="Arguments for generating paired patient sequences"
-    )
-    base_arg_parser.add_argument(
-        "--num_of_patients",
-        dest="num_of_patients",
-        action="store",
-        type=int,
-        help="The number of patients that will be generated",
-        required=True,
-    )
-    base_arg_parser.add_argument(
-        "--sequence_data_path",
-        dest="sequence_data_path",
-        action="store",
-        help="The path for your sequence data",
-        required=True,
-    )
-    base_arg_parser.add_argument(
-        "--cutoff_frac_max",
-        dest="cutoff_frac_max",
-        action="store",
-        type=float,
-        help="The max fraction of the patient sequences that will be used for prompting",
-        required=False,
-        default=0.5,
-    )
-    base_arg_parser.add_argument(
-        "--num_proc",
-        dest="num_proc",
-        action="store",
-        type=int,
-        required=False,
-        default=1,
-    )
-    return base_arg_parser
-
-
-if __name__ == "__main__":
-    main(create_arg_parser().parse_args())
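This removed script built each DPO training pair from a real patient: it truncated the recorded sequence to a randomly sized prompt, let the model complete it, then stored the patient's full real sequence under the `chosen_*` columns and the model's generation under `rejected_*`, together with `prompt_length`. The prompt construction distills to a few lines (a standalone restatement, not an import from the package):

import random

def build_prompt(concept_ids, cutoff_frac_max=0.5, rng=random):
    # Keep a random fraction of the patient's history as the prompt,
    # but never fewer than 4 tokens, mirroring the removed script.
    cutoff_frac = rng.uniform(0, cutoff_frac_max)
    prompt_len = max(4, int(len(concept_ids) * cutoff_frac))
    return concept_ids[:prompt_len], prompt_len

Note that in the original, `cutoff_frac` was drawn once per batch rather than per patient, so every sequence in a batch kept the same fraction of its history.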