cehrgpt 0.1.0__tar.gz → 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. {cehrgpt-0.1.0/src/cehrgpt.egg-info → cehrgpt-0.1.1}/PKG-INFO +8 -7
  2. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/README.md +4 -4
  3. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/pyproject.toml +4 -3
  4. cehrgpt-0.1.1/sample_data/omop_vocab/concept/concept.parquet +0 -0
  5. cehrgpt-0.1.1/scripts/run_linear_prob.sh +260 -0
  6. cehrgpt-0.1.1/src/cehrgpt/analysis/irregularity.py +36 -0
  7. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/data/hf_cehrgpt_dataset.py +1 -0
  8. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/data/hf_cehrgpt_dataset_collator.py +398 -36
  9. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/data/hf_cehrgpt_dataset_mapping.py +214 -12
  10. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/data/sample_packing_sampler.py +36 -6
  11. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/generation/omop_converter_batch.py +32 -2
  12. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/gpt_utils.py +20 -2
  13. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/models/config.py +25 -0
  14. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/models/hf_cehrgpt.py +227 -33
  15. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/models/hf_modeling_outputs.py +1 -0
  16. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/models/special_tokens.py +1 -0
  17. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/models/tokenization_hf_cehrgpt.py +354 -71
  18. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/runners/data_utils.py +117 -2
  19. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/runners/hf_cehrgpt_finetune_runner.py +75 -50
  20. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +59 -7
  21. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +48 -0
  22. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/runners/hyperparameter_search_util.py +6 -7
  23. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/runners/sample_packing_trainer.py +17 -0
  24. cehrgpt-0.1.1/src/cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  25. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/time_to_event_model.py +2 -13
  26. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  27. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +85 -57
  28. {cehrgpt-0.1.0 → cehrgpt-0.1.1/src/cehrgpt.egg-info}/PKG-INFO +8 -7
  29. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt.egg-info/SOURCES.txt +4 -0
  30. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt.egg-info/requires.txt +4 -3
  31. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/tests/integration_tests/runners/hf_cehrgpt_pretrain_runner_test.py +15 -6
  32. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/tests/integration_tests/runners/hf_cehrgpt_pretrain_sample_packing_runner_test.py +7 -0
  33. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/.github/workflows/build-python.yaml +0 -0
  34. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/.github/workflows/tests.yaml +0 -0
  35. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/.gitignore +0 -0
  36. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/.pre-commit-config.yaml +0 -0
  37. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/LICENSE +0 -0
  38. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/sample_configs/cehrgpt_pretrain_sample_config.yaml +0 -0
  39. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/sample_data/pretrain/patient_sequence.parquet +0 -0
  40. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/sample_data/pretrained_embeddings/pretrained_embedding_concepts.pkl +0 -0
  41. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/sample_data/pretrained_embeddings/pretrained_embedding_vectors.npy +0 -0
  42. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/scripts/level_three_evaluation.sh +0 -0
  43. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/scripts/omop_pipeline.sh +0 -0
  44. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/scripts/pool_generated_sequences.sh +0 -0
  45. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/setup.cfg +0 -0
  46. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/__init__.py +0 -0
  47. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/__init__.py +0 -0
  48. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/analysis/__init__.py +0 -0
  49. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/__init__.py +0 -0
  50. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/attribute_inference.py +0 -0
  51. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/attribute_inference_config.yml +0 -0
  52. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/member_inference.py +0 -0
  53. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/nearest_neighbor_inference.py +0 -0
  54. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/reid_inference.py +0 -0
  55. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/analysis/privacy/utils.py +0 -0
  56. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/cehrgpt_args.py +0 -0
  57. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/data/__init__.py +0 -0
  58. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/generation/__init__.py +0 -0
  59. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/generation/chatgpt_generation.py +0 -0
  60. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/generation/generate_batch_hf_gpt_sequence.py +0 -0
  61. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/generation/omop_entity.py +0 -0
  62. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/models/__init__.py +0 -0
  63. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/models/pretrained_embeddings.py +0 -0
  64. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/omop/__init__.py +0 -0
  65. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/omop/condition_era.py +0 -0
  66. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/omop/observation_period.py +0 -0
  67. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/omop/omop_argparse.py +0 -0
  68. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/omop/omop_table_builder.py +0 -0
  69. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/omop/queries/__init__.py +0 -0
  70. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/omop/queries/condition_era.py +0 -0
  71. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/omop/queries/observation_period.py +0 -0
  72. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/omop/sample_omop_tables.py +0 -0
  73. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/runners/__init__.py +0 -0
  74. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/runners/gpt_runner_util.py +0 -0
  75. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/simulations/__init__.py +0 -0
  76. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/simulations/generate_plots.py +0 -0
  77. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/simulations/run_simulation.sh +0 -0
  78. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/simulations/time_embedding_simulation.py +0 -0
  79. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/simulations/time_token_simulation.py +0 -0
  80. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/__init__.py +0 -0
  81. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/config/30_day_readmission.yaml +0 -0
  82. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +0 -0
  83. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/config/t2dm_hf.yaml +0 -0
  84. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/time_to_event/time_to_event_utils.py +0 -0
  85. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/tools/__init__.py +0 -0
  86. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/tools/ehrshot_benchmark.py +0 -0
  87. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/tools/generate_causal_patient_split_by_age.py +0 -0
  88. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/tools/generate_pretrained_embeddings.py +0 -0
  89. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/tools/linear_prob/__init__.py +0 -0
  90. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +0 -0
  91. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/tools/merge_synthetic_real_dataasets.py +0 -0
  92. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt/tools/upload_omop_tables.py +0 -0
  93. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt.egg-info/dependency_links.txt +0 -0
  94. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/src/cehrgpt.egg-info/top_level.txt +0 -0
  95. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/tests/__init__.py +0 -0
  96. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/tests/integration_tests/__init__.py +0 -0
  97. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/tests/integration_tests/runners/__init__.py +0 -0
  98. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/tests/integration_tests/runners/hf_cehrgpt_pretrain_sfm_runner_test.py +0 -0
  99. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/tests/unit_tests/__init__.py +0 -0
  100. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/tests/unit_tests/gpt_utils_test.py +0 -0
  101. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/tests/unit_tests/models/__init__.py +0 -0
  102. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/tests/unit_tests/models/model_utils_test.py +0 -0
  103. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/tests/unit_tests/models/tokenization/__init__.py +0 -0
  104. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/tests/unit_tests/models/tokenization/create_bins_with_spline_test.py +0 -0
  105. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/tests/unit_tests/models/tokenization/create_sample_from_bins_test.py +0 -0
  106. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/tests/unit_tests/numeric_concept_statistics_test.py +0 -0
  107. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/tests/unit_tests/runners/__init__.py +0 -0
  108. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/tests/unit_tests/runners/hf_cehrgpt_finetune_runner_test.py +0 -0
  109. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/tests/unit_tests/tokenization_test.py +0 -0
  110. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/tests/unit_tests/tools/__init__.py +0 -0
  111. {cehrgpt-0.1.0 → cehrgpt-0.1.1}/tests/unit_tests/tools/upload_omop_tables_test.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cehrgpt
3
- Version: 0.1.0
3
+ Version: 0.1.1
4
4
  Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
5
5
  Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
6
6
  License: MIT License
@@ -12,14 +12,15 @@ Classifier: Programming Language :: Python :: 3
12
12
  Requires-Python: >=3.10.0
13
13
  Description-Content-Type: text/markdown
14
14
  License-File: LICENSE
15
- Requires-Dist: cehrbert==1.4.1
16
- Requires-Dist: cehrbert_data==0.0.7
15
+ Requires-Dist: cehrbert==1.4.5
16
+ Requires-Dist: cehrbert_data==0.0.11
17
17
  Requires-Dist: openai==1.54.3
18
18
  Requires-Dist: optuna==4.0.0
19
- Requires-Dist: transformers==4.44.0
19
+ Requires-Dist: transformers==4.44.1
20
20
  Requires-Dist: tokenizers==0.19.0
21
21
  Requires-Dist: peft==0.10.0
22
22
  Requires-Dist: lightgbm
23
+ Requires-Dist: polars
23
24
  Provides-Extra: dev
24
25
  Requires-Dist: pre-commit; extra == "dev"
25
26
  Requires-Dist: pytest; extra == "dev"
@@ -36,9 +37,9 @@ Dynamic: license-file
36
37
 
37
38
  [![PyPI - Version](https://img.shields.io/pypi/v/cehrgpt)](https://pypi.org/project/cehrgpt/)
38
39
  ![Python](https://img.shields.io/badge/-Python_3.11-blue?logo=python&logoColor=white)
39
- [![tests](https://github.com/knatarajan-lab/cehrgpt-public/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt-public/actions/workflows/tests.yml)
40
- [![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt-public/blob/main/LICENSE)
41
- [![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt-public.svg)](https://github.com/knatarajan-lab/cehrgpt-public/graphs/contributors)
40
+ [![tests](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml)
41
+ [![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt/blob/main/LICENSE)
42
+ [![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt.svg)](https://github.com/knatarajan-lab/cehrgpt/graphs/contributors)
42
43
 
43
44
  ## Description
44
45
  CEHRGPT is a synthetic data generation model developed to handle structured electronic health records (EHR) with enhanced privacy and reliability. It leverages state-of-the-art natural language processing techniques to create realistic, anonymized patient data that can be used for research and development without compromising patient privacy.
@@ -2,9 +2,9 @@
2
2
 
3
3
  [![PyPI - Version](https://img.shields.io/pypi/v/cehrgpt)](https://pypi.org/project/cehrgpt/)
4
4
  ![Python](https://img.shields.io/badge/-Python_3.11-blue?logo=python&logoColor=white)
5
- [![tests](https://github.com/knatarajan-lab/cehrgpt-public/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt-public/actions/workflows/tests.yml)
6
- [![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt-public/blob/main/LICENSE)
7
- [![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt-public.svg)](https://github.com/knatarajan-lab/cehrgpt-public/graphs/contributors)
5
+ [![tests](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml)
6
+ [![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt/blob/main/LICENSE)
7
+ [![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt.svg)](https://github.com/knatarajan-lab/cehrgpt/graphs/contributors)
8
8
 
9
9
  ## Description
10
10
  CEHRGPT is a synthetic data generation model developed to handle structured electronic health records (EHR) with enhanced privacy and reliability. It leverages state-of-the-art natural language processing techniques to create realistic, anonymized patient data that can be used for research and development without compromising patient privacy.
@@ -77,4 +77,4 @@ sh scripts/omop_pipeline.sh \
77
77
  author={Natarajan, K and others},
78
78
  journal={arXiv preprint arXiv:2402.04400},
79
79
  year={2024}
80
- }
80
+ }
@@ -28,14 +28,15 @@ classifiers = [
28
28
  ]
29
29
 
30
30
  dependencies = [
31
- "cehrbert==1.4.1",
32
- "cehrbert_data==0.0.7",
31
+ "cehrbert==1.4.5",
32
+ "cehrbert_data==0.0.11",
33
33
  "openai==1.54.3",
34
34
  "optuna==4.0.0",
35
- "transformers==4.44.0",
35
+ "transformers==4.44.1",
36
36
  "tokenizers==0.19.0",
37
37
  "peft==0.10.0",
38
38
  "lightgbm",
39
+ "polars",
39
40
  ]
40
41
 
41
42
  [tool.setuptools_scm]
@@ -0,0 +1,260 @@
1
+ #!/bin/sh
2
+
3
+ # Function to display usage information
4
+ usage() {
5
+ echo "Usage: $0 [options]"
6
+ echo ""
7
+ echo "Options:"
8
+ echo " --base_dir=DIR Base directory containing cohorts (required)"
9
+ echo " --dataset_prepared_path=PATH Path to prepared dataset (required)"
10
+ echo " --model_path=PATH Path to pre-trained model and tokenizer (required)"
11
+ echo " --preprocessing_workers=NUM Number of preprocessing workers (required)"
12
+ echo " --batch_size=NUM Batch size for evaluation (required)"
13
+ echo " --output_dir=DIR Output directory for results (required)"
14
+ echo " --model_name=NAME Name for the model output directory (default: cehrgpt_model)"
15
+ echo " --max_tokens_per_batch=NUM Maximum tokens per batch (default: 16384)"
16
+ echo " --torch_type=TYPE Torch data type (default: float32)"
17
+ echo " --disable_sample_packing Disable sample packing (enabled by default)"
18
+ echo ""
19
+ echo "Example:"
20
+ echo " $0 --base_dir=/path/to/cohorts --dataset_prepared_path=/path/to/dataset_prepared \\"
21
+ echo " --model_path=/path/to/model --preprocessing_workers=16 --batch_size=64 \\"
22
+ echo " --output_dir=/path/to/outputs --model_name=my_model --torch_type=float16"
23
+ exit 1
24
+ }
25
+
26
+ # Default values
27
+ MODEL_NAME="cehrgpt_model"
28
+ MAX_TOKENS_PER_BATCH="16384"
29
+ TORCH_TYPE="bfloat16"
30
+ DISABLE_SAMPLE_PACKING="false"
31
+
32
+ # Parse command line arguments
33
+ for arg in "$@"; do
34
+ case $arg in
35
+ --base_dir=*)
36
+ BASE_DIR="${arg#*=}"
37
+ ;;
38
+ --dataset_prepared_path=*)
39
+ DATASET_PREPARED_PATH="${arg#*=}"
40
+ ;;
41
+ --model_path=*)
42
+ MODEL_PATH="${arg#*=}"
43
+ ;;
44
+ --preprocessing_workers=*)
45
+ PREPROCESSING_WORKERS="${arg#*=}"
46
+ ;;
47
+ --batch_size=*)
48
+ BATCH_SIZE="${arg#*=}"
49
+ ;;
50
+ --output_dir=*)
51
+ OUTPUT_DIR="${arg#*=}"
52
+ ;;
53
+ --model_name=*)
54
+ MODEL_NAME="${arg#*=}"
55
+ ;;
56
+ --max_tokens_per_batch=*)
57
+ MAX_TOKENS_PER_BATCH="${arg#*=}"
58
+ ;;
59
+ --torch_type=*)
60
+ TORCH_TYPE="${arg#*=}"
61
+ ;;
62
+ --disable_sample_packing)
63
+ DISABLE_SAMPLE_PACKING="true"
64
+ ;;
65
+ --help|-h)
66
+ usage
67
+ ;;
68
+ *)
69
+ echo "Error: Unknown option: $arg"
70
+ usage
71
+ ;;
72
+ esac
73
+ done
74
+
75
+ # Check for required arguments
76
+ if [ -z "$BASE_DIR" ] || [ -z "$DATASET_PREPARED_PATH" ] || [ -z "$MODEL_PATH" ] || [ -z "$PREPROCESSING_WORKERS" ] || [ -z "$BATCH_SIZE" ] || [ -z "$OUTPUT_DIR" ]; then
77
+ echo "Error: Missing required arguments"
78
+ usage
79
+ fi
80
+
81
+ # Validate arguments
82
+ if [ ! -d "$BASE_DIR" ]; then
83
+ echo "Error: Base directory does not exist: $BASE_DIR"
84
+ exit 1
85
+ fi
86
+
87
+ if [ ! -d "$DATASET_PREPARED_PATH" ]; then
88
+ echo "Error: Dataset prepared path does not exist: $DATASET_PREPARED_PATH"
89
+ exit 1
90
+ fi
91
+
92
+ if [ ! -d "$MODEL_PATH" ]; then
93
+ echo "Error: Model path does not exist: $MODEL_PATH"
94
+ exit 1
95
+ fi
96
+
97
+ # Create output directory if it doesn't exist
98
+ mkdir -p "$OUTPUT_DIR"
99
+
100
+ # Check if preprocessing workers is a number
101
+ if ! [ "$PREPROCESSING_WORKERS" -eq "$PREPROCESSING_WORKERS" ] 2>/dev/null; then
102
+ echo "Error: Preprocessing workers must be a number: $PREPROCESSING_WORKERS"
103
+ exit 1
104
+ fi
105
+
106
+ # Check if batch size is a number
107
+ if ! [ "$BATCH_SIZE" -eq "$BATCH_SIZE" ] 2>/dev/null; then
108
+ echo "Error: Batch size must be a number: $BATCH_SIZE"
109
+ exit 1
110
+ fi
111
+
112
+ # Check if max tokens per batch is a number
113
+ if ! [ "$MAX_TOKENS_PER_BATCH" -eq "$MAX_TOKENS_PER_BATCH" ] 2>/dev/null; then
114
+ echo "Error: Max tokens per batch must be a number: $MAX_TOKENS_PER_BATCH"
115
+ exit 1
116
+ fi
117
+
118
+ # Validate torch_type (common PyTorch data types)
119
+ case "$TORCH_TYPE" in
120
+ float16|float32|float64|bfloat16|int8|int16|int32|int64)
121
+ ;;
122
+ *)
123
+ echo "Error: Invalid torch_type. Supported types: float16, float32, float64, bfloat16, int8, int16, int32, int64"
124
+ exit 1
125
+ ;;
126
+ esac
127
+
128
+ # Validate disable_sample_packing is boolean-like
129
+ if [ "$DISABLE_SAMPLE_PACKING" != "true" ] && [ "$DISABLE_SAMPLE_PACKING" != "false" ]; then
130
+ echo "Error: disable_sample_packing must be 'true' or 'false': $DISABLE_SAMPLE_PACKING"
131
+ exit 1
132
+ fi
133
+
134
+ # Log file setup
135
+ LOG_DIR="$BASE_DIR/logs"
136
+ mkdir -p "$LOG_DIR"
137
+ TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
138
+ MAIN_LOG="$LOG_DIR/run_${TIMESTAMP}.log"
139
+
140
+ # Log function
141
+ log() {
142
+ message="[$(date '+%Y-%m-%d %H:%M:%S')] $1"
143
+ echo "$message" | tee -a "$MAIN_LOG"
144
+ }
145
+
146
+ # Main execution
147
+ log "Starting feature extraction and model training process"
148
+ log "Configuration:"
149
+ log " --base_dir=$BASE_DIR"
150
+ log " --dataset_prepared_path=$DATASET_PREPARED_PATH"
151
+ log " --model_path=$MODEL_PATH"
152
+ log " --preprocessing_workers=$PREPROCESSING_WORKERS"
153
+ log " --batch_size=$BATCH_SIZE"
154
+ log " --output_dir=$OUTPUT_DIR"
155
+ log " --model_name=$MODEL_NAME"
156
+ log " --max_tokens_per_batch=$MAX_TOKENS_PER_BATCH"
157
+ log " --torch_type=$TORCH_TYPE"
158
+ log " --disable_sample_packing=$DISABLE_SAMPLE_PACKING"
159
+
160
+ # Find valid cohorts and write to a temp file
161
+ TEMP_COHORT_LIST="$LOG_DIR/cohort_list_${TIMESTAMP}.txt"
162
+ > "$TEMP_COHORT_LIST" # Clear the file
163
+
164
+ # Find all valid cohorts (directories with train and test subdirectories)
165
+ for cohort_dir in "$BASE_DIR"/*; do
166
+ if [ -d "$cohort_dir" ] && [ -d "$cohort_dir/train" ] && [ -d "$cohort_dir/test" ]; then
167
+ cohort_name=$(basename "$cohort_dir")
168
+ echo "$cohort_name" >> "$TEMP_COHORT_LIST"
169
+ fi
170
+ done
171
+
172
+ # Check if any valid cohorts were found
173
+ if [ ! -s "$TEMP_COHORT_LIST" ]; then
174
+ log "ERROR: No valid cohorts found in $BASE_DIR"
175
+ rm -f "$TEMP_COHORT_LIST"
176
+ exit 1
177
+ fi
178
+
179
+ # Display all cohorts that will be processed
180
+ cohort_count=$(wc -l < "$TEMP_COHORT_LIST")
181
+ log "Found $cohort_count cohorts to process:"
182
+ while read -r cohort; do
183
+ log "- $cohort"
184
+ done < "$TEMP_COHORT_LIST"
185
+
186
+ # Process each cohort sequentially
187
+ while read -r cohort_name; do
188
+ cohort_dir="$OUTPUT_DIR/$cohort_name"
189
+ output_dir="$cohort_dir/$MODEL_NAME"
190
+
191
+ log "===================================================="
192
+ log "Processing cohort: $cohort_name"
193
+ log "===================================================="
194
+
195
+ cohort_log="$LOG_DIR/${cohort_name}_${TIMESTAMP}.log"
196
+
197
+ # Create output directory if it doesn't exist
198
+ mkdir -p "$output_dir"
199
+
200
+ # Prepare command for feature extraction
201
+ FEATURE_CMD="python -u -m cehrgpt.tools.linear_prob.compute_cehrgpt_features \
202
+ --data_folder \"$BASE_DIR/$cohort_name/train/\" \
203
+ --test_data_folder \"$BASE_DIR/$cohort_name/test/\" \
204
+ --dataset_prepared_path \"$DATASET_PREPARED_PATH\" \
205
+ --model_name_or_path \"$MODEL_PATH\" \
206
+ --tokenizer_name_or_path \"$MODEL_PATH\" \
207
+ --output_dir \"$output_dir\" \
208
+ --preprocessing_num_workers \"$PREPROCESSING_WORKERS\" \
209
+ --per_device_eval_batch_size \"$BATCH_SIZE\" \
210
+ --max_tokens_per_batch \"$MAX_TOKENS_PER_BATCH\" \
211
+ --torch_type \"$TORCH_TYPE\""
212
+
213
+ # Add sample packing flag if not disabled
214
+ if [ "$DISABLE_SAMPLE_PACKING" = "false" ]; then
215
+ FEATURE_CMD="$FEATURE_CMD --sample_packing"
216
+ fi
217
+
218
+ # Step 1: Feature extraction
219
+ log "Starting feature extraction for $cohort_name..."
220
+ log "Command: $FEATURE_CMD"
221
+
222
+ eval "$FEATURE_CMD > \"$cohort_log\" 2>&1"
223
+
224
+ feature_extraction_status=$?
225
+ if [ $feature_extraction_status -ne 0 ]; then
226
+ log "ERROR: Feature extraction failed for $cohort_name. Check $cohort_log for details."
227
+ continue
228
+ fi
229
+
230
+ # Step 2: Model training
231
+ log "Starting model training for $cohort_name..."
232
+ log "Command: python -u -m cehrgpt.tools.linear_prob.train_with_cehrgpt_features --features_data_dir $output_dir --output_dir $output_dir"
233
+
234
+ python -u -m cehrgpt.tools.linear_prob.train_with_cehrgpt_features \
235
+ --features_data_dir "$output_dir" \
236
+ --output_dir "$output_dir" \
237
+ >> "$cohort_log" 2>&1
238
+
239
+ echo "Running meds-evaluation for logistic regression for $TASK_NAME..."
240
+ meds-evaluation-cli predictions_path="$output_dir/logistic/test_predictions" \
241
+ output_dir="$output_dir/logistic/"
242
+
243
+ # Check if the second command succeeded
244
+ if [ $? -ne 0 ]; then
245
+ echo "Error: Running meds-evaluation failed for logistic regression for task $TASK_NAME"
246
+ fi
247
+
248
+ model_training_status=$?
249
+ if [ $model_training_status -ne 0 ]; then
250
+ log "ERROR: Model training failed for $cohort_name. Check $cohort_log for details."
251
+ continue
252
+ fi
253
+
254
+ log "Successfully completed processing for $cohort_name"
255
+ done < "$TEMP_COHORT_LIST"
256
+
257
+ # Clean up
258
+ rm -f "$TEMP_COHORT_LIST"
259
+
260
+ log "All processing completed"
@@ -0,0 +1,36 @@
1
+ import os
2
+
3
+ import polars as pl
4
+
5
+ from cehrgpt.gpt_utils import extract_time_interval_in_days, is_att_token
6
+
7
+
8
+ def main(args):
9
+ dataset = pl.read_parquet(os.path.join(args.input_dir, "*.parquet"))
10
+ time_token_frequency_df = (
11
+ dataset.select(pl.col("concept_ids").explode().alias("concept_id"))
12
+ .filter(pl.col("concept_id").map_elements(is_att_token))
13
+ .with_columns(
14
+ pl.col("concept_id")
15
+ .map_elements(extract_time_interval_in_days)
16
+ .alias("time_interval")
17
+ )
18
+ )
19
+ results = time_token_frequency_df.select(
20
+ pl.mean("time_interval").alias("mean"), pl.std("time_interval").alias("std")
21
+ ).to_dicts()[0]
22
+ print(results)
23
+
24
+
25
+ if __name__ == "__main__":
26
+ import argparse
27
+
28
+ parser = argparse.ArgumentParser(description="EHR Irregularity analysis")
29
+ parser.add_argument(
30
+ "--input_dir",
31
+ dest="input_dir",
32
+ action="store",
33
+ help="The path for where the input data folder",
34
+ required=True,
35
+ )
36
+ main(parser.parse_args())
@@ -23,6 +23,7 @@ CEHRGPT_COLUMNS = [
23
23
  "num_of_visits",
24
24
  "values",
25
25
  "value_indicators",
26
+ "epoch_times",
26
27
  ]
27
28
 
28
29
  TRANSFORMER_COLUMNS = ["input_ids"]