cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. cehrgpt/analysis/irregularity.py +36 -0
  2. cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
  3. cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
  4. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
  5. cehrgpt/data/sample_packing_sampler.py +181 -0
  6. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  7. cehrgpt/generation/omop_converter_batch.py +32 -2
  8. cehrgpt/gpt_utils.py +20 -2
  9. cehrgpt/models/config.py +35 -0
  10. cehrgpt/models/hf_cehrgpt.py +470 -106
  11. cehrgpt/models/hf_modeling_outputs.py +1 -0
  12. cehrgpt/models/special_tokens.py +1 -0
  13. cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
  14. cehrgpt/runners/data_utils.py +358 -0
  15. cehrgpt/runners/gpt_runner_util.py +0 -10
  16. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
  17. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
  18. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
  19. cehrgpt/runners/hyperparameter_search_util.py +10 -8
  20. cehrgpt/runners/sample_packing_trainer.py +185 -0
  21. cehrgpt/simulations/generate_plots.py +95 -0
  22. cehrgpt/simulations/run_simulation.sh +24 -0
  23. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  24. cehrgpt/simulations/time_token_simulation.py +177 -0
  25. cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  26. cehrgpt/time_to_event/time_to_event_model.py +2 -13
  27. cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  28. cehrgpt/tools/linear_prob/__init__.py +0 -0
  29. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
  30. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  31. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
  32. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
  33. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
  34. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  35. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  36. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  37. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  38. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  39. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  40. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  41. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  42. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  43. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
  44. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,9 @@
1
1
  __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  cehrgpt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  cehrgpt/cehrgpt_args.py,sha256=zPLp9Qjlq5PapWx3R15BNnyaX8zV3dxr4PuWj71r0Lg,3516
4
- cehrgpt/gpt_utils.py,sha256=bksHCXMX4j859VSv1Q284rVr4gn1Y8dCx4a_V-g4mug,10939
4
+ cehrgpt/gpt_utils.py,sha256=IA5qw-hxcKkGO07AB47lDNRU6mlb9jblpKO7KeLLN78,11342
5
5
  cehrgpt/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ cehrgpt/analysis/irregularity.py,sha256=Rfl_daMvSh9cZ68vUwfmuH-JYCFXdAph2ITHHffYC0Y,1047
6
7
  cehrgpt/analysis/privacy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
8
  cehrgpt/analysis/privacy/attribute_inference.py,sha256=0ANVW0I5uvOl6IxQ15-vMVQd0mugOgSGReBUQQESImg,9368
8
9
  cehrgpt/analysis/privacy/attribute_inference_config.yml,sha256=hfLfpBlDqqsNOynpRHK414vV24edKA6ta-inmEhM2ao,103272
@@ -11,24 +12,22 @@ cehrgpt/analysis/privacy/nearest_neighbor_inference.py,sha256=qoJgWW7VsUMzjMGpTa
11
12
  cehrgpt/analysis/privacy/reid_inference.py,sha256=Pypd3QJXQNY8VljpnIEa5zeAbTZHMjQOazaL-9VsBGw,13955
12
13
  cehrgpt/analysis/privacy/utils.py,sha256=CRA4H9mPLBjMQGKzZ_x_3ro3tMap-NjsMDVqSOjHSVQ,8226
13
14
  cehrgpt/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- cehrgpt/data/hf_cehrgpt_dataset.py,sha256=7hvjjqE8WInVuRvAtNkFI_J-xluFBv1Ij4TPTdUxPM4,2570
15
- cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=RYw5Isrwa4sdyQQ3Nf3cu7xPDA3m-c5ecCFf_y1TJKY,20497
16
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=IjGwLKbEfNPxH3hsNmb8p48_imHnMWtslDK6f7R_1pc,16053
17
- cehrgpt/data/hf_cehrgpt_dpo_collator.py,sha256=cqDK0SUOt3yAqUHWKGuLVi3WmmUMZ6eyxTv9fC9idZA,2787
18
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py,sha256=uCLF5VEsyZAG1aNwqEM6Jy5Lx7bI5ALku52Z6Anine0,2574
15
+ cehrgpt/data/hf_cehrgpt_dataset.py,sha256=hwJlGW7XiJIr6cXtmwvReQf9yLZJPD-dvJGvRg5ERqU,3755
16
+ cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=ACMXiaYnR3bKD5dRleL0_siEvhL-2HAFcy5eBgvxnH4,44412
17
+ cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=KU0WMjc2vT1zBAl7JJkOc8dgGxsL1uFDy4dDrv-RkII,25668
18
+ cehrgpt/data/sample_packing_sampler.py,sha256=vovGMtmhG70DRkSCeiaDEJ_rjKZ38y-YLaI1kkhFEkI,6747
19
19
  cehrgpt/generation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
20
  cehrgpt/generation/chatgpt_generation.py,sha256=SrnLwHLdNtnAOEg36gNjqfoT9yd12iyPgpZffL2AFJo,4428
21
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256=-WLpKlulVVDJSdA2jXyp87gfLW4Q3aAtwULK8fDtn_E,11408
22
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py,sha256=fLu3SHhRe_ZQfS09ebOktq2dekStgYfxmbrRawZQAO4,8280
23
- cehrgpt/generation/omop_converter_batch.py,sha256=-c0AlDVy5pJ5Afhr8ERiCHhoRrEk8ozJi3g0yFdWaMI,25348
21
+ cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256=uSEh8aMmPD61nGewIaPSkIqm-2AxDjCBiu4cBfxHxU4,11503
22
+ cehrgpt/generation/omop_converter_batch.py,sha256=LUmCD-t_6ZP1YfNDZCqYewl-XIIaIgRZ_dAxuR_VdCQ,26275
24
23
  cehrgpt/generation/omop_entity.py,sha256=Q5Sr0AlyuPAm1FRPfnJO13q-u1fqRgYVHXruZ9g4xNE,19400
25
24
  cehrgpt/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
- cehrgpt/models/config.py,sha256=xek4W_siO7WtMAKE7zDsENotsIE70F8dcW-PTC0kBKk,9700
27
- cehrgpt/models/hf_cehrgpt.py,sha256=CKseTvGkBFwXK40Z_uKD1_d84oSYCFqKmHI0qtdk72g,75757
28
- cehrgpt/models/hf_modeling_outputs.py,sha256=LaWa1jI6BRIKMEjWOy1QUeOfTur5y_p2c-JyuGVTdtw,10301
25
+ cehrgpt/models/config.py,sha256=nOAKgH5420HLCcy7n1hE7MbqR861Iq4DTutKoAd25tg,11090
26
+ cehrgpt/models/hf_cehrgpt.py,sha256=77CAkdMPgxD4xSpFU7gYGzRn6_Iv-4q7FnHpnZGsKxw,92450
27
+ cehrgpt/models/hf_modeling_outputs.py,sha256=5X4WEYKqT37phv_e5ZAv3A_N0wqdAUJLJRm6TxS6dDQ,10356
29
28
  cehrgpt/models/pretrained_embeddings.py,sha256=vLLVs17TLpXRqCVEWQxGGwPHkUJUO7laNTeBuyBK_yk,3238
30
- cehrgpt/models/special_tokens.py,sha256=-a7HPJBbdIH0qQ6B3CcRKqvpG6FZlm4nbVPTswGSJ4U,485
31
- cehrgpt/models/tokenization_hf_cehrgpt.py,sha256=JAZjnmQq-JjUxZK7XIsqdZB07ZB7BC2WraCjpO_6AOM,42161
29
+ cehrgpt/models/special_tokens.py,sha256=lrw45B4tea4Dsajn09Cz6w5D2TfHmYXikZkgwnstu_o,521
30
+ cehrgpt/models/tokenization_hf_cehrgpt.py,sha256=cAxHTctpVBxfWfC3XcwDQavN1zwWN9Nid_Fajd5zQWQ,53159
32
31
  cehrgpt/omop/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
33
32
  cehrgpt/omop/condition_era.py,sha256=hPZALz2XaWnro_1bwIYNkI48foOJjueyg3CZ1BliCno,626
34
33
  cehrgpt/omop/observation_period.py,sha256=TRMgv5Ya2RaS2im7oQ6BLC_5JL9EJYNYR62ApxIuHvg,1211
@@ -38,22 +37,24 @@ cehrgpt/omop/sample_omop_tables.py,sha256=2JZ8BNSvssceinwFanvuCRh-YlKrKn25U9w1pL
38
37
  cehrgpt/omop/queries/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
39
38
  cehrgpt/omop/queries/condition_era.py,sha256=LFB6vBAvshHJxtYIRkl7cfrF0kf7ay0piBKpmHBwrpE,2578
40
39
  cehrgpt/omop/queries/observation_period.py,sha256=fpzr5DMNw-QLoSwp2Iatfch88E3hyhZ75usiIdG3A0U,6410
41
- cehrgpt/rl_finetune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
42
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py,sha256=VQHf5vy5i8K1imcqYakhitfAW-d2mnaEzkSoAYSW5kg,26062
43
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py,sha256=nYWYPCaWNjDGEwlo6UHOK1rvOZUx1vuJ8kYuAszI8Zg,17925
44
- cehrgpt/rl_finetune/ppo_finetune.py,sha256=tSy-C0Kzgj5ffclBIDj-RTj78ZfrLmTESxVxd0n9yuE,13971
45
- cehrgpt/rl_finetune/ppo_finetune_v2.py,sha256=7dChwKpq4zKmpkcxP4hryqBoIkcwmTJ44_BF8R2RghQ,13285
46
40
  cehrgpt/runners/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
47
- cehrgpt/runners/gpt_runner_util.py,sha256=88HKSVj-ADGBCMo7C3znKSMPnAAALa1iU_6P6i9sD0M,3867
48
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py,sha256=Z4qNl9CZFC5YvUBc9ZzdOV5wsBFvMTdxfTn4jjtJQ-Y,4583
49
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=reflNRb6YB6f_3jAfzFAdwKtTl6hvdIp9Jc7DC-Sv-U,32580
50
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=L3UpjtzxuS8a_tshlqpZN_sXnJSs3yzry0GZNT__05A,18200
51
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=gKVf4BLzNCFiJR7nZVkf-QRcj8fAEVvIUTV-AVH0g_U,5312
52
- cehrgpt/runners/hyperparameter_search_util.py,sha256=i4qAb_22JO78l40MSyBPwDgAGuGc96efXmg_833cSSo,9044
41
+ cehrgpt/runners/data_utils.py,sha256=I6k1TkiiZR8ggw3eVO16g2lVPY-Hu3b-nbrIOKlFIO0,15528
42
+ cehrgpt/runners/gpt_runner_util.py,sha256=YJQSRW9Mo4TjXSOUOTf6BUFcs1MGFiXU5T4ztKZcYhU,3485
43
+ cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=GVbHHqf5TWGbVWlQG-XurgYH8pKRjTk8ug_ib9L9U7E,28118
44
+ cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=ERSnvB38fPYVghtKQeNTZ8VfeXnoRcCHB0cWISWaZ84,26523
45
+ cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=ejAFLM9g765p1fyeF5MITsiIeWHKkz9wTeFDeVgxSto,8851
46
+ cehrgpt/runners/hyperparameter_search_util.py,sha256=YWdFQ1igQs-G_wqWUrUzYraGiz8OSpSYyvid-I5nhWA,9262
47
+ cehrgpt/runners/sample_packing_trainer.py,sha256=Zb7Aqwnk8-VqrjEKUVeg5XzZWmHxXOU2sDn1YURS-FU,7960
48
+ cehrgpt/simulations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
49
+ cehrgpt/simulations/generate_plots.py,sha256=BTZ71r8Kah0PMorkiO3vw55_p_9U1Z8KiD3GsPfaV0s,2520
50
+ cehrgpt/simulations/run_simulation.sh,sha256=DcJ6B19jIteUO0pZ0Tc21876lB9XxQHFAxlre7MtAzk,795
51
+ cehrgpt/simulations/time_embedding_simulation.py,sha256=HZ-imXH-bN-QYZN1PAfcERmNtaWIwKjbf0UrZduwCiA,8687
52
+ cehrgpt/simulations/time_token_simulation.py,sha256=sLg8vVXydvR_zk3BbqyrlA7sDIdhFnS-s5pSKcCilSc,6057
53
53
  cehrgpt/time_to_event/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
54
- cehrgpt/time_to_event/time_to_event_model.py,sha256=tfXa24l_0q1TBZ68BPRrHRC_3KRWYxrWGIv4myJlIb8,8497
55
- cehrgpt/time_to_event/time_to_event_prediction.py,sha256=Ajesq2gSsILghWHCTLiiBhcyOCa7m6JPPMdi_xvBlR4,12624
54
+ cehrgpt/time_to_event/time_to_event_model.py,sha256=Plm0bZxvlAbnMl82DTBXWvaXLvrqcdkzcP_celX8WC4,8055
55
+ cehrgpt/time_to_event/time_to_event_prediction.py,sha256=W2e7UqIV7ELdfTy997HS66vggjnhdncCKt840knI0Dw,13183
56
56
  cehrgpt/time_to_event/time_to_event_utils.py,sha256=KN4hwGgxy2nJtO7osbYQBF3-HpmGUWefNfexzPYiEwc,1937
57
+ cehrgpt/time_to_event/config/1_year_cabg.yaml,sha256=SFF2-F5D02pDSMRddDrEUoERBCd0t2Hzln_xC-Mo2hA,407
57
58
  cehrgpt/time_to_event/config/30_day_readmission.yaml,sha256=Hn5KnEXMtSV_CtCpmAU4wjkc0-gTXvniaH991TSbUXA,234
58
59
  cehrgpt/time_to_event/config/next_visit_type_prediction.yaml,sha256=WMj2ZutEvHKIMyGG51xtXaL6MyRANKvpg9xT8ouctLc,319
59
60
  cehrgpt/time_to_event/config/t2dm_hf.yaml,sha256=_oMQzh2eJTYzEaMOpmhAzbX-qmdsKlkORELL6HxOxHo,202
@@ -63,8 +64,11 @@ cehrgpt/tools/generate_causal_patient_split_by_age.py,sha256=dmHiPAL_kR1WrhRteIi
63
64
  cehrgpt/tools/generate_pretrained_embeddings.py,sha256=lhFSacGv8bMld6qigKZN8Op8eXpFi0DsJuQbWKOWXqI,4160
64
65
  cehrgpt/tools/merge_synthetic_real_dataasets.py,sha256=O1dbQ32Le0t15fwymwAh9mfNVLEWuFwW53DNvESrWbY,7589
65
66
  cehrgpt/tools/upload_omop_tables.py,sha256=vdBAbkeAsGPA4NsyhNjelPVj3gS8yzmS1sKNM1Qk96g,3791
66
- cehrgpt-0.0.2.dist-info/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
67
- cehrgpt-0.0.2.dist-info/METADATA,sha256=joUmDJWMEBvYphrkwYiK273FwSL9okY74D93ncrbvMU,4878
68
- cehrgpt-0.0.2.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
69
- cehrgpt-0.0.2.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
70
- cehrgpt-0.0.2.dist-info/RECORD,,
67
+ cehrgpt/tools/linear_prob/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
68
+ cehrgpt/tools/linear_prob/compute_cehrgpt_features.py,sha256=q0rmlBWDDEkjHjwcTouGUhCYa32a1vRicaDOAMsdW0I,20741
69
+ cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py,sha256=w0UvzMKYGenN_KDVnbzutmy8IPLUxW5hPvpKKxDSL5U,5820
70
+ cehrgpt-0.1.1.dist-info/licenses/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
71
+ cehrgpt-0.1.1.dist-info/METADATA,sha256=VnXH74vJQZaV7VxGiIvJnFhQA0jzJQNx86yHFkygobM,4922
72
+ cehrgpt-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
73
+ cehrgpt-0.1.1.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
74
+ cehrgpt-0.1.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (76.0.0)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,71 +0,0 @@
1
- import torch
2
- from torch.nn.utils.rnn import pad_sequence
3
-
4
- from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
5
-
6
-
7
- class CehrGptDPODataCollator(CehrGptDataCollator):
8
-
9
- def create_preference_inputs(self, examples, prefix):
10
- batch = {}
11
- # Assume that each example in the batch is a dictionary with 'input_ids' and 'attention_mask'
12
- batch_input_ids = [
13
- self._try_reverse_tensor(
14
- self._convert_to_tensor(example[f"{prefix}_input_ids"])
15
- )
16
- for example in examples
17
- ]
18
- batch_attention_mask = [
19
- self._try_reverse_tensor(
20
- torch.ones_like(
21
- self._convert_to_tensor(example[f"{prefix}_input_ids"]),
22
- dtype=torch.float,
23
- )
24
- )
25
- for example in examples
26
- ]
27
- # Pad sequences to the max length in the batch
28
- batch[f"{prefix}_input_ids"] = self._try_reverse_tensor(
29
- pad_sequence(
30
- batch_input_ids,
31
- batch_first=True,
32
- padding_value=self.tokenizer.pad_token_id,
33
- ).to(torch.int64)
34
- )
35
- batch[f"{prefix}_attention_mask"] = self._try_reverse_tensor(
36
- pad_sequence(batch_attention_mask, batch_first=True, padding_value=0.0)
37
- )
38
- assert batch[f"{prefix}_input_ids"].shape[1] <= self.max_length
39
- assert batch[f"{prefix}_attention_mask"].shape[1] <= self.max_length
40
-
41
- if self.include_values:
42
- batch_value_indicators = [
43
- self._try_reverse_tensor(
44
- self._convert_to_tensor(example[f"{prefix}_value_indicators"])
45
- )
46
- for example in examples
47
- ]
48
- batch_values = [
49
- self._try_reverse_tensor(
50
- self._convert_to_tensor(example[f"{prefix}__values"])
51
- )
52
- for example in examples
53
- ]
54
-
55
- batch[f"{prefix}_value_indicators"] = self._try_reverse_tensor(
56
- pad_sequence(
57
- batch_value_indicators, batch_first=True, padding_value=False
58
- )
59
- )
60
- batch[f"{prefix}_values"] = self._try_reverse_tensor(
61
- pad_sequence(batch_values, batch_first=True, padding_value=-1.0)
62
- )
63
- assert batch[f"{prefix}_value_indicators"].shape[1] <= self.max_length
64
- assert batch[f"{prefix}_values"].shape[1] <= self.max_length
65
- return batch
66
-
67
- def __call__(self, examples):
68
- batch_chosen = self.create_preference_inputs(examples, "chosen")
69
- batch_rejected = self.create_preference_inputs(examples, "rejected")
70
- batch_chosen.update(batch_rejected)
71
- return batch_chosen
@@ -1,61 +0,0 @@
1
- import copy
2
- from typing import Any, Dict
3
-
4
- import numpy as np
5
- from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import DatasetMapping
6
-
7
- from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
8
-
9
-
10
- class HFCehrGptDPOTokenizationMapping(DatasetMapping):
11
- def __init__(
12
- self,
13
- concept_tokenizer: CehrGptTokenizer,
14
- ):
15
- self._concept_tokenizer = concept_tokenizer
16
- self._lab_token_ids = self._concept_tokenizer.lab_token_ids
17
-
18
- def transform_with_prefix(self, record: Dict[str, Any], prefix) -> Dict[str, Any]:
19
- concept_ids = record[f"{prefix}_concept_ids"]
20
- input_ids = self._concept_tokenizer.encode(concept_ids)
21
- record[f"{prefix}_input_ids"] = input_ids
22
-
23
- if f"{prefix}_concept_value_masks" in record:
24
- concept_value_masks = record[f"{prefix}_concept_value_masks"]
25
- concept_values = record[f"{prefix}_concept_values"]
26
- # If any concept has a value associated with it, we normalize the value
27
- if np.any(np.asarray(concept_value_masks) > 0):
28
- units = record[f"{prefix}_units"]
29
- normalized_concept_values = copy.deepcopy(concept_values)
30
- for i, (
31
- concept_id,
32
- unit,
33
- token_id,
34
- concept_value_mask,
35
- concept_value,
36
- ) in enumerate(
37
- zip(
38
- concept_ids,
39
- units,
40
- input_ids,
41
- concept_value_masks,
42
- concept_values,
43
- )
44
- ):
45
- if token_id in self._lab_token_ids:
46
- normalized_concept_value = self._concept_tokenizer.normalize(
47
- concept_id, unit, concept_value
48
- )
49
- normalized_concept_values[i] = normalized_concept_value
50
- record[f"{prefix}_concept_values"] = normalized_concept_values
51
- # Overwrite the column names
52
- record[f"{prefix}_value_indicators"] = record[
53
- f"{prefix}_concept_value_masks"
54
- ]
55
- record[f"{prefix}_values"] = record[f"{prefix}_concept_values"]
56
- return record
57
-
58
- def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
59
- record = self.transform_with_prefix(record, prefix="chosen")
60
- record.update(self.transform_with_prefix(record, prefix="rejected"))
61
- return record
@@ -1,224 +0,0 @@
1
- import datetime
2
- import os
3
- import random
4
- import uuid
5
-
6
- import pandas as pd
7
- import torch
8
- from cehrbert.runners.runner_util import load_parquet_as_dataset
9
- from transformers.utils import is_flash_attn_2_available, logging
10
-
11
- from cehrgpt.cehrgpt_args import create_inference_base_arg_parser
12
- from cehrgpt.generation.generate_batch_hf_gpt_sequence import (
13
- generate_single_batch,
14
- normalize_value,
15
- )
16
- from cehrgpt.gpt_utils import get_cehrgpt_output_folder
17
- from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
18
- from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
19
-
20
- LOG = logging.get_logger("transformers")
21
-
22
-
23
- def main(args):
24
- if torch.cuda.is_available():
25
- device = torch.device("cuda")
26
- else:
27
- device = torch.device("cpu")
28
-
29
- cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_folder)
30
- cehrgpt_model = (
31
- CEHRGPT2LMHeadModel.from_pretrained(
32
- args.model_folder,
33
- attn_implementation=(
34
- "flash_attention_2" if is_flash_attn_2_available() else "eager"
35
- ),
36
- torch_dtype=(
37
- torch.bfloat16 if is_flash_attn_2_available() else torch.float32
38
- ),
39
- )
40
- .eval()
41
- .to(device)
42
- )
43
- cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
44
- cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
45
- cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id
46
-
47
- folder_name = get_cehrgpt_output_folder(args, cehrgpt_tokenizer)
48
- output_folder_name = os.path.join(
49
- args.output_folder, folder_name, "generated_sequences"
50
- )
51
-
52
- if not os.path.exists(output_folder_name):
53
- os.makedirs(output_folder_name)
54
-
55
- LOG.info(f"Loading tokenizer at {args.model_folder}")
56
- LOG.info(f"Loading model at {args.model_folder}")
57
- LOG.info(f"Write sequences to {output_folder_name}")
58
- LOG.info(f"Context window {args.context_window}")
59
- LOG.info(f"Temperature {args.temperature}")
60
- LOG.info(f"Repetition Penalty {args.repetition_penalty}")
61
- LOG.info(f"Sampling Strategy {args.sampling_strategy}")
62
- LOG.info(f"Num beam {args.num_beams}")
63
- LOG.info(f"Num beam groups {args.num_beam_groups}")
64
- LOG.info(f"Epsilon cutoff {args.epsilon_cutoff}")
65
- LOG.info(f"Top P {args.top_p}")
66
- LOG.info(f"Top K {args.top_k}")
67
- LOG.info(f"Loading sequence_data_path at {args.sequence_data_path}")
68
-
69
- dataset = load_parquet_as_dataset(args.sequence_data_path)
70
- total_rows = len(dataset)
71
- float(args.batch_size) / total_rows
72
- num_of_batches = args.num_of_patients // args.batch_size + 1
73
- sequence_to_flush = []
74
- for i in range(num_of_batches):
75
- LOG.info(f"{datetime.datetime.now()}: Batch {i} started")
76
- sample_data = []
77
- while len(sample_data) == 0:
78
- random_indices = random.sample(range(total_rows), k=1)
79
- for row in dataset.select(random_indices):
80
- if 4 <= len(row["concept_ids"]) <= cehrgpt_model.config.n_positions:
81
- sample_data.append(row)
82
- prompts = []
83
- chosen_responses = []
84
- cutoff_frac = random.uniform(0, args.cutoff_frac_max)
85
- for row in sample_data:
86
- seq_len = len(row["concept_ids"])
87
- prompt_len = max(4, int(seq_len * cutoff_frac))
88
- prompts.append(cehrgpt_tokenizer.encode(row["concept_ids"][:prompt_len]))
89
- chosen_responses.append(
90
- {
91
- "person_id": row["person_id"],
92
- "chosen_concept_ids": (
93
- row["concept_ids"] if "concept_ids" in row else None
94
- ),
95
- "chosen_concept_values": (
96
- row["concept_values"] if "concept_values" in row else None
97
- ),
98
- "chosen_concept_value_masks": (
99
- row["concept_value_masks"]
100
- if "concept_value_masks" in row
101
- else None
102
- ),
103
- "chosen_units": row["units"] if "units" in row else None,
104
- "prompt_length": prompt_len,
105
- }
106
- )
107
-
108
- batch_sequences = generate_single_batch(
109
- cehrgpt_model,
110
- cehrgpt_tokenizer,
111
- prompts=prompts,
112
- max_new_tokens=args.context_window,
113
- mini_num_of_concepts=args.min_num_of_concepts,
114
- top_p=args.top_p,
115
- top_k=args.top_k,
116
- temperature=args.temperature,
117
- repetition_penalty=args.repetition_penalty,
118
- num_beams=args.num_beams,
119
- num_beam_groups=args.num_beam_groups,
120
- epsilon_cutoff=args.epsilon_cutoff,
121
- device=device,
122
- )
123
-
124
- # Clear the cache
125
- torch.cuda.empty_cache()
126
-
127
- for seq, value_indicator, value, chosen_response in zip(
128
- batch_sequences["sequences"],
129
- batch_sequences["value_indicators"],
130
- batch_sequences["values"],
131
- chosen_responses,
132
- ):
133
- output = {"rejected_concept_ids": seq}
134
- normalized_values, units = normalize_value(
135
- seq, value_indicator, value, cehrgpt_tokenizer
136
- )
137
- if normalized_values is not None:
138
- output["rejected_concept_values"] = normalized_values
139
- if value_indicator is not None:
140
- output["rejected_concept_value_masks"] = value_indicator
141
- if units is not None:
142
- output["rejected_units"] = units
143
- output.update(chosen_response)
144
- sequence_to_flush.append(output)
145
-
146
- if len(sequence_to_flush) >= args.buffer_size:
147
- LOG.info(f"{datetime.datetime.now()}: Flushing to the Disk at Batch {i}")
148
- pd.DataFrame(
149
- sequence_to_flush,
150
- columns=[
151
- "person_id",
152
- "chosen_concept_ids",
153
- "chosen_concept_values",
154
- "chosen_concept_value_masks",
155
- "chosen_units",
156
- "prompt_length",
157
- "rejected_concept_ids",
158
- "rejected_concept_values",
159
- "rejected_concept_value_masks",
160
- "rejected_units",
161
- ],
162
- ).to_parquet(os.path.join(output_folder_name, f"{uuid.uuid4()}.parquet"))
163
- sequence_to_flush.clear()
164
-
165
- if len(sequence_to_flush) > 0:
166
- LOG.info(f"{datetime.datetime.now()}: Flushing to the Disk at Final Batch")
167
- pd.DataFrame(
168
- sequence_to_flush,
169
- columns=[
170
- "person_id",
171
- "chosen_concept_ids",
172
- "chosen_concept_values",
173
- "chosen_concept_value_masks",
174
- "chosen_units",
175
- "prompt_length",
176
- "rejected_concept_ids",
177
- "rejected_concept_values",
178
- "rejected_concept_value_masks",
179
- "rejected_units",
180
- ],
181
- ).to_parquet(os.path.join(output_folder_name, f"{uuid.uuid4()}-last.parquet"))
182
-
183
-
184
- def create_arg_parser():
185
- base_arg_parser = create_inference_base_arg_parser(
186
- description="Arguments for generating paired patient sequences"
187
- )
188
- base_arg_parser.add_argument(
189
- "--num_of_patients",
190
- dest="num_of_patients",
191
- action="store",
192
- type=int,
193
- help="The number of patients that will be generated",
194
- required=True,
195
- )
196
- base_arg_parser.add_argument(
197
- "--sequence_data_path",
198
- dest="sequence_data_path",
199
- action="store",
200
- help="The path for your sequence data",
201
- required=True,
202
- )
203
- base_arg_parser.add_argument(
204
- "--cutoff_frac_max",
205
- dest="cutoff_frac_max",
206
- action="store",
207
- type=float,
208
- help="The max fraction of the patient sequences that will be used for prompting",
209
- required=False,
210
- default=0.5,
211
- )
212
- base_arg_parser.add_argument(
213
- "--num_proc",
214
- dest="num_proc",
215
- action="store",
216
- type=int,
217
- required=False,
218
- default=1,
219
- )
220
- return base_arg_parser
221
-
222
-
223
- if __name__ == "__main__":
224
- main(create_arg_parser().parse_args())