cehrgpt 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- __init__.py +0 -0
- cehrgpt/__init__.py +0 -0
- cehrgpt/analysis/__init__.py +0 -0
- cehrgpt/analysis/privacy/__init__.py +0 -0
- cehrgpt/analysis/privacy/attribute_inference.py +275 -0
- cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
- cehrgpt/analysis/privacy/member_inference.py +172 -0
- cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
- cehrgpt/analysis/privacy/reid_inference.py +407 -0
- cehrgpt/analysis/privacy/utils.py +255 -0
- cehrgpt/cehrgpt_args.py +142 -0
- cehrgpt/data/__init__.py +0 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
- cehrgpt/generation/__init__.py +0 -0
- cehrgpt/generation/chatgpt_generation.py +106 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
- cehrgpt/generation/omop_converter_batch.py +644 -0
- cehrgpt/generation/omop_entity.py +515 -0
- cehrgpt/gpt_utils.py +331 -0
- cehrgpt/models/__init__.py +0 -0
- cehrgpt/models/config.py +205 -0
- cehrgpt/models/hf_cehrgpt.py +1817 -0
- cehrgpt/models/hf_modeling_outputs.py +158 -0
- cehrgpt/models/pretrained_embeddings.py +82 -0
- cehrgpt/models/special_tokens.py +30 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
- cehrgpt/omop/__init__.py +0 -0
- cehrgpt/omop/condition_era.py +20 -0
- cehrgpt/omop/observation_period.py +43 -0
- cehrgpt/omop/omop_argparse.py +38 -0
- cehrgpt/omop/omop_table_builder.py +86 -0
- cehrgpt/omop/queries/__init__.py +0 -0
- cehrgpt/omop/queries/condition_era.py +86 -0
- cehrgpt/omop/queries/observation_period.py +135 -0
- cehrgpt/omop/sample_omop_tables.py +71 -0
- cehrgpt/runners/__init__.py +0 -0
- cehrgpt/runners/gpt_runner_util.py +99 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
- cehrgpt/runners/hyperparameter_search_util.py +223 -0
- cehrgpt/time_to_event/__init__.py +0 -0
- cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
- cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
- cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
- cehrgpt/time_to_event/time_to_event_model.py +226 -0
- cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
- cehrgpt/time_to_event/time_to_event_utils.py +55 -0
- cehrgpt/tools/__init__.py +0 -0
- cehrgpt/tools/ehrshot_benchmark.py +74 -0
- cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
- cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
- cehrgpt/tools/upload_omop_tables.py +108 -0
- cehrgpt-0.0.1.dist-info/LICENSE +21 -0
- cehrgpt-0.0.1.dist-info/METADATA +66 -0
- cehrgpt-0.0.1.dist-info/RECORD +60 -0
- cehrgpt-0.0.1.dist-info/WHEEL +5 -0
- cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,106 @@
|
|
1
|
+
import os
|
2
|
+
from textwrap import dedent
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
from jinja2 import BaseLoader, Environment
|
6
|
+
from openai import OpenAI
|
7
|
+
from pydantic import BaseModel
|
8
|
+
|
9
|
+
# OpenAI model name passed as `model` to the chat-completions `parse` call made
# by the __main__ script of this file.
MODEL = "gpt-4o-2024-08-06"
|
10
|
+
# Jinja2 prompt template. `{{ demographic_prompt }}` is rendered with the string
# form of one sampled patient's demographic tokens before each request.
# NOTE(review): the example sequence contains bare '0' / '1' tokens and visits
# whose '[VS]' is not followed by a visit-type token (9201/9202/9203), which the
# rules above require — confirm the example is intended as written.
TEMPLATE = """
|
11
|
+
You are a medical professional tasked with generating a synthetic patient sequence using the CEHR-GPT format, outlined as follows:
|
12
|
+
|
13
|
+
[year]: Represents the start year of the patient sequence.
|
14
|
+
[age]: Represents the start age of the patient sequence.
|
15
|
+
[gender]: Patient's gender, allowed values are "Male," "Female," and "Unknown."
|
16
|
+
[race]: Patient's race, allowed values are "White," "Black," "Asian," and "Unknown."
|
17
|
+
[VS]: Marks the start of a visit.
|
18
|
+
[VE]: Marks the end of a visit.
|
19
|
+
[VT]: Type of visit, with allowed values "9202" (outpatient), "9201" (inpatient), and "9203" (emergency room).
|
20
|
+
[C_i]: Clinical concept represented by an OMOP concept ID (could be a drug, condition, or procedure).
|
21
|
+
[ATT]: Artificial time tokens, representing time intervals in days (e.g., "D1," "D10").
|
22
|
+
[i-ATT]: Inpatient-specific artificial time tokens, representing intervals within inpatient stays (e.g., "i-D1"), these tokens should only appear in inpatient visits.
|
23
|
+
Each sequence can encompass multiple concepts within each visit and vary from one to ten visits, reflective of real-world clinical scenarios. All clinical concepts must correspond to valid OMOP IDs. The sequence must end on [VE]
|
24
|
+
|
25
|
+
Example of a sequence:
|
26
|
+
|
27
|
+
{
|
28
|
+
"seq": ['year:2008', 'age:28', '8532', '8527', '[VS]', '9202', '4301351',
|
29
|
+
'19078924', '35603428', '35603429', '40221381', '40223365',
|
30
|
+
'4155151', '4239130', '42536500', '4294382', '2108974', '433736',
|
31
|
+
'[VE]', 'D7', '[VS]', '9201', '43011850', '35603429', '35603600',
|
32
|
+
'35605482', '40163870', '40169706', '40221381', '35603428',
|
33
|
+
'19078921', '40244026', '948080', '1154615', '1593063', '4056973',
|
34
|
+
'4155151', '4194550', '3047860', '35604843', '43011962', '4160730',
|
35
|
+
'i-D1', '35604843', '40162587', '43011962', '433736', '948080',
|
36
|
+
'0', '[VE]', 'D14', '[VS]', '9202', '4019497', '[VE]', 'D26', '[VS]',
|
37
|
+
'1', '4019497', '[VE]', 'D198', '[VS]', '581477', '433736',
|
38
|
+
'[VE]', 'D19', '[VS]', '581477', '194152', '320128', '40483287', '433736', '[VE]']
|
39
|
+
}
|
40
|
+
|
41
|
+
When creating the sequence, please use the demographic tokens {{ demographic_prompt }} to construct a realistic and medically plausible patient trajectory.
|
42
|
+
"""
|
43
|
+
|
44
|
+
|
45
|
+
class PatientSequence(BaseModel):
    """Response schema for the OpenAI structured-output request.

    It is passed as ``response_format`` so the reply is parsed into this model.
    ``seq`` is the ordered list of generated CEHR-GPT tokens.
    """

    seq: list[str]
|
47
|
+
|
48
|
+
|
49
|
+
if __name__ == "__main__":
    # Script entry point: sample a patient's demographic tokens, ask the OpenAI
    # model for a full CEHR-GPT sequence, and write each result to its own
    # parquet file in --output_folder.
    import argparse
    import uuid

    import pandas as pd
    from tqdm import tqdm

    parser = argparse.ArgumentParser("ChatGPT patient generation")
    parser.add_argument(
        "--demographic_data",
        dest="demographic_data",
        action="store",
        help="The path for your demographic_data",
        required=True,
    )
    parser.add_argument(
        "--output_folder",
        dest="output_folder",
        action="store",
        help="The path for your output_folder",
        required=True,
    )
    parser.add_argument(
        "--num_sequences",
        dest="num_sequences",
        action="store",
        type=int,
        help="The number of patient sequences to generate",
        required=True,
    )
    args = parser.parse_args()

    # to_parquet does not create missing directories.
    os.makedirs(args.output_folder, exist_ok=True)

    # Create a Jinja2 environment and render the template
    env = Environment(loader=BaseLoader())
    template = env.from_string(TEMPLATE)
    demographics = pd.read_parquet(args.demographic_data)

    # One client for the whole run instead of constructing a new one per request.
    # When OPEN_AI_KEY is unset, api_key=None lets the SDK fall back to its own
    # default environment variable.
    client = OpenAI(api_key=os.environ.get("OPEN_AI_KEY"))

    for _ in tqdm(range(args.num_sequences)):
        demographic_tokens = str(demographics.sample(1).concept_ids.iloc[0].tolist())
        prompt = template.render(demographic_prompt=demographic_tokens)
        completion = client.beta.chat.completions.parse(
            model=MODEL,
            messages=[
                {"role": "system", "content": "You are a medical professional."},
                {"role": "user", "content": dedent(prompt)},
            ],
            response_format=PatientSequence,
        )
        message = completion.choices[0].message
        if message.parsed is None:
            # A refusal (or unparseable reply) has no parsed payload; skip this
            # sample instead of aborting the whole run with an AttributeError.
            tqdm.write(
                f"Skipping sample: no parsed sequence returned "
                f"(refusal={message.refusal!r})"
            )
            continue
        patient_sequence = message.parsed.seq
        pd.DataFrame(
            [
                {
                    "concept_ids": patient_sequence,
                    # NOTE(review): np.zeros_like on a list of str yields empty
                    # strings, not numeric zeros — confirm what consumers of
                    # concept_values expect.
                    "concept_values": np.zeros_like(patient_sequence),
                }
            ],
            columns=["concept_ids", "concept_values"],
        ).to_parquet(os.path.join(args.output_folder, f"{uuid.uuid4()}.parquet"))
|
@@ -0,0 +1,333 @@
|
|
1
|
+
import datetime
|
2
|
+
import os
|
3
|
+
import random
|
4
|
+
import uuid
|
5
|
+
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
import pandas as pd
|
9
|
+
import torch
|
10
|
+
from cehrbert.runners.runner_util import load_parquet_as_dataset
|
11
|
+
from transformers import GenerationConfig
|
12
|
+
from transformers.utils import is_flash_attn_2_available, logging
|
13
|
+
|
14
|
+
from cehrgpt.cehrgpt_args import create_inference_base_arg_parser
|
15
|
+
from cehrgpt.generation.omop_converter_batch import START_TOKEN_SIZE
|
16
|
+
from cehrgpt.gpt_utils import get_cehrgpt_output_folder
|
17
|
+
from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
|
18
|
+
from cehrgpt.models.special_tokens import END_TOKEN
|
19
|
+
from cehrgpt.models.tokenization_hf_cehrgpt import (
|
20
|
+
NA,
|
21
|
+
CehrGptTokenizer,
|
22
|
+
is_valid_valid_bin,
|
23
|
+
)
|
24
|
+
|
25
|
+
# Logger obtained through transformers.utils.logging under the "transformers"
# name, so this script's messages are emitted via the transformers logger.
LOG = logging.get_logger("transformers")
|
26
|
+
|
27
|
+
|
28
|
+
def normalize_value(
    seq: Sequence[str],
    values: Sequence[str],
    tokenizer: CehrGptTokenizer,
) -> Tuple[
    Sequence[str],
    Optional[Sequence[Optional[int]]],
    Optional[Sequence[Optional[float]]],
    Optional[Sequence[Optional[str]]],
    Optional[Sequence[str]],
]:
    """Split a decoded sequence and its aligned value tokens into per-column lists.

    ``seq`` and ``values`` are walked in lockstep with ``zip``, so the shorter
    one bounds the walk. The walk stops at the first ``END_TOKEN``, which is not
    included in the output.

    For each kept position the outputs are:

    * ``concept_as_value``: the value token itself when it is non-empty and
      ``str.isnumeric()``; otherwise ``None``.
    * When ``is_valid_valid_bin(value)`` is true, the pair goes through
      ``tokenizer.denormalize(concept, value)``. The returned unit is always
      kept. If the returned number is a ``float``, it becomes
      ``number_as_value`` and ``is_numeric_type`` is 1.
    * In every other case ``number_as_value`` is ``None`` and
      ``is_numeric_type`` is 0. The unit is ``NA`` unless ``denormalize``
      supplied one.

    Returns:
        ``(concepts, is_numeric_types, number_as_values, concept_as_values,
        units)``: five lists of equal length. They are never ``None``, despite
        the ``Optional`` hints.
    """
    kept_concepts: List[str] = []
    numeric_flags: List[int] = []
    numeric_values: List[Optional[float]] = []
    categorical_values: List[Optional[str]] = []
    value_units: List[str] = []

    for token, raw_value in zip(seq, values):
        if token == END_TOKEN:
            break

        categorical_value = None
        if raw_value and raw_value.isnumeric():
            categorical_value = raw_value

        numeric_value = None
        numeric_flag = 0
        unit = NA
        # The value slot is aligned with this token (not the next one); only a
        # valid value bin can be mapped back to a number and unit.
        if is_valid_valid_bin(raw_value):
            denormalized, unit = tokenizer.denormalize(token, raw_value)
            if isinstance(denormalized, float):
                numeric_value, numeric_flag = denormalized, 1

        kept_concepts.append(token)
        numeric_flags.append(numeric_flag)
        numeric_values.append(numeric_value)
        categorical_values.append(categorical_value)
        value_units.append(unit)

    return (
        kept_concepts,
        numeric_flags,
        numeric_values,
        categorical_values,
        value_units,
    )
|
71
|
+
|
72
|
+
|
73
|
+
def generate_single_batch(
    model: CEHRGPT2LMHeadModel,
    tokenizer: CehrGptTokenizer,
    prompts: List[List[int]],
    max_new_tokens=512,
    mini_num_of_concepts=1,
    top_p=0.95,
    top_k=50,
    temperature=1.0,
    repetition_penalty=1.0,
    num_beams=1,
    num_beam_groups=1,
    epsilon_cutoff=0.0,
    device: Any = "cpu",
) -> Dict[str, Any]:
    """Sample one batch of patient sequences from ``model``.

    Args:
        model: the CEHR-GPT language model whose ``generate`` is called.
        tokenizer: supplies the special-token ids, ``lab_token_ids``, and the
            ``decode`` / ``decode_value`` methods.
        prompts: encoded prompts. They are stacked with ``torch.tensor``, so
            every prompt must have the same length.
        max_new_tokens: forwarded as ``GenerationConfig(max_length=...)``.
            Despite its name, it therefore caps the TOTAL length, prompt
            included.
        mini_num_of_concepts: forwarded as ``min_length``.
        top_p, top_k, temperature, repetition_penalty, num_beams,
            num_beam_groups, epsilon_cutoff: sampling settings forwarded to
            ``GenerationConfig``.
        device: device the prompt tensor is moved to.

    Returns:
        A dict with these keys:

        * ``"sequences"``: one decoded sequence per generated row.
        * ``"values"``: per-row decoded value tokens, or all-``NA`` rows when
          the model returned no ``sequence_vals``.
        * ``"value_indicators"``: per-row boolean masks taken from
          ``sequence_val_masks``, or all-``False`` rows when absent.
    """
    with torch.no_grad():
        generation_config = GenerationConfig(
            repetition_penalty=repetition_penalty,
            max_length=max_new_tokens,
            min_length=mini_num_of_concepts,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            bos_token_id=tokenizer.end_token_id,
            eos_token_id=tokenizer.end_token_id,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            use_cache=True,
            return_dict_in_generate=True,
            output_attentions=False,
            output_hidden_states=False,
            output_scores=False,
            renormalize_logits=True,
            num_beams=num_beams,
            num_beam_groups=num_beam_groups,
            epsilon_cutoff=epsilon_cutoff,
        )
        batched_prompts = torch.tensor(prompts).to(device)
        results = model.generate(
            inputs=batched_prompts,
            generation_config=generation_config,
            lab_token_ids=tokenizer.lab_token_ids,
        )

        # presumably decode returns a list of token strings per row — the
        # fallbacks below size themselves per row via len(seq).
        sequences = [
            tokenizer.decode(seq.cpu().numpy(), skip_special_tokens=False)
            for seq in results.sequences
        ]
        if results.sequence_vals is not None:
            values = [
                tokenizer.decode_value(
                    seq_vals.cpu().numpy(), skip_special_tokens=False
                )
                for seq_vals in results.sequence_vals
            ]
        else:
            # Plain per-row lists instead of np.zeros_like(sequences).fill(NA).
            # zeros_like on nested strings creates a fixed-width unicode array,
            # so NA was silently truncated whenever it was longer than every
            # generated token, and ragged rows could not be handled at all.
            values = [[NA] * len(seq) for seq in sequences]
        if results.sequence_val_masks is not None:
            value_indicators = results.sequence_val_masks.cpu().numpy()
        else:
            value_indicators = [np.zeros(len(seq), dtype=bool) for seq in sequences]
        return {
            "sequences": sequences,
            "values": values,
            "value_indicators": value_indicators,
        }
|
138
|
+
|
139
|
+
|
140
|
+
def _sample_demographic_prompts(
    dataset,
    tokenizer: CehrGptTokenizer,
    num_prompts: int,
    min_num_of_concepts: int,
    max_seq_allowed: int,
    data_path: str,
) -> List[List[int]]:
    """Encode ``num_prompts`` prompts drawn at random from rows of ``dataset``.

    A row qualifies when ``len(row["concept_ids"])`` lies in
    ``[min_num_of_concepts, max_seq_allowed]``. Its first ``START_TOKEN_SIZE``
    concept ids are encoded as the prompt. Sampling rounds repeat until enough
    prompts have been collected.

    Raises:
        RuntimeError: if more than 10 rounds produce no qualifying row at all.
    """
    total_rows = len(dataset)
    prompts: List[List[int]] = []
    attempts = 0
    while len(prompts) < num_prompts:
        # random.sample raises ValueError when k exceeds the population size,
        # so never request more rows than the dataset holds.
        indices = random.sample(range(total_rows), k=min(num_prompts, total_rows))
        for row in dataset.select(indices):
            if min_num_of_concepts <= len(row["concept_ids"]) <= max_seq_allowed:
                prompts.append(tokenizer.encode(row["concept_ids"][:START_TOKEN_SIZE]))
        attempts += 1
        if not prompts and attempts > 10:
            raise RuntimeError(
                f"The length of concept_ids in {data_path} does not qualify!"
            )
    # A round can overshoot; make sure the batch does not exceed num_prompts.
    return prompts[:num_prompts]


def _flush_sequences(
    sequences: List[Dict[str, Any]], output_folder_name: str, file_name: str
) -> None:
    """Write buffered generated-sequence records to one parquet file."""
    pd.DataFrame(
        sequences,
        columns=[
            "concept_ids",
            "person_id",
            "is_numeric_types",
            "number_as_values",
            "concept_as_values",
            "concept_value_masks",
            "units",
        ],
    ).to_parquet(os.path.join(output_folder_name, file_name))


def main(args):
    """Generate ``args.num_of_patients`` synthetic patient sequences.

    Steps:

    1. Load the tokenizer and model (flash-attention / bfloat16 when available
       and requested).
    2. Sample demographic prompts from ``args.demographic_data_path``.
    3. Generate in batches of at most ``args.batch_size``. The last batch is
       shrunk so that exactly ``num_of_patients`` sequences are produced.

    Results are buffered and flushed as parquet files into
    ``<args.output_folder>/<get_cehrgpt_output_folder(...)>/generated_sequences``
    once ``args.buffer_size`` records accumulate, plus one ``-last`` file for
    the remainder.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_folder)
    flash_attention_available = is_flash_attn_2_available()
    cehrgpt_model = (
        CEHRGPT2LMHeadModel.from_pretrained(
            args.model_folder,
            attn_implementation=(
                "flash_attention_2" if flash_attention_available else "eager"
            ),
            torch_dtype=(
                torch.bfloat16
                if flash_attention_available and args.use_bfloat16
                else torch.float32
            ),
        )
        .eval()
        .to(device)
    )
    cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
    cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
    cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id

    folder_name = get_cehrgpt_output_folder(args, cehrgpt_tokenizer)
    output_folder_name = os.path.join(
        args.output_folder, folder_name, "generated_sequences"
    )
    os.makedirs(output_folder_name, exist_ok=True)

    # With --drop_long_sequences, prompts come only from patients whose full
    # sequence fits within the model's n_positions; otherwise any length is allowed.
    max_seq_allowed = (
        cehrgpt_model.config.n_positions
        if args.drop_long_sequences
        else np.iinfo(np.int32).max
    )

    LOG.info(f"Loading tokenizer at {args.tokenizer_folder}")
    LOG.info(f"Loading model at {args.model_folder}")
    LOG.info(f"Write sequences to {output_folder_name}")
    LOG.info(f"Context window {args.context_window}")
    LOG.info(f"Max sequence allowed {max_seq_allowed}")
    LOG.info(f"Temperature {args.temperature}")
    LOG.info(f"Repetition Penalty {args.repetition_penalty}")
    LOG.info(f"Sampling Strategy {args.sampling_strategy}")
    LOG.info(f"Num beam {args.num_beams}")
    LOG.info(f"Num beam groups {args.num_beam_groups}")
    LOG.info(f"Epsilon cutoff {args.epsilon_cutoff}")
    LOG.info(f"Top P {args.top_p}")
    LOG.info(f"Top K {args.top_k}")
    LOG.info(f"Loading demographic_info at {args.demographic_data_path}")

    dataset = load_parquet_as_dataset(args.demographic_data_path)

    # Ceiling division. The previous `n // batch_size + 1` always added a batch,
    # over-generating by up to batch_size patients.
    num_of_batches = (args.num_of_patients + args.batch_size - 1) // args.batch_size
    sequence_to_flush: List[Dict[str, Any]] = []
    current_person_id = 1
    for i in range(num_of_batches):
        LOG.info(f"{datetime.datetime.now()}: Batch {i} started")

        # Only the final batch can be smaller: it generates the patients still missing.
        current_batch_size = min(
            args.batch_size, args.num_of_patients - i * args.batch_size
        )
        # Randomly pick demographics from the existing population
        random_prompts = _sample_demographic_prompts(
            dataset,
            cehrgpt_tokenizer,
            current_batch_size,
            args.min_num_of_concepts,
            max_seq_allowed,
            args.demographic_data_path,
        )

        batch_sequences = generate_single_batch(
            cehrgpt_model,
            cehrgpt_tokenizer,
            random_prompts,
            max_new_tokens=args.context_window,
            mini_num_of_concepts=args.min_num_of_concepts,
            top_p=args.top_p,
            top_k=args.top_k,
            temperature=args.temperature,
            repetition_penalty=args.repetition_penalty,
            num_beams=args.num_beams,
            num_beam_groups=args.num_beam_groups,
            epsilon_cutoff=args.epsilon_cutoff,
            device=device,
        )

        # Clear the cache
        torch.cuda.empty_cache()

        for generated_concepts, value_indicators, generated_values in zip(
            batch_sequences["sequences"],
            batch_sequences["value_indicators"],
            batch_sequences["values"],
        ):
            (
                concept_ids,
                is_numeric_types,
                number_as_values,
                concept_as_values,
                units,
            ) = normalize_value(generated_concepts, generated_values, cehrgpt_tokenizer)
            output = {"concept_ids": concept_ids, "person_id": current_person_id}
            optional_columns = {
                "is_numeric_types": is_numeric_types,
                "number_as_values": number_as_values,
                "concept_as_values": concept_as_values,
                "concept_value_masks": value_indicators,
                "units": units,
            }
            output.update(
                {key: val for key, val in optional_columns.items() if val is not None}
            )
            sequence_to_flush.append(output)
            current_person_id += 1

        if len(sequence_to_flush) >= args.buffer_size:
            LOG.info(f"{datetime.datetime.now()}: Flushing to the Disk at Batch {i}")
            _flush_sequences(
                sequence_to_flush, output_folder_name, f"{uuid.uuid4()}.parquet"
            )
            sequence_to_flush.clear()

    if sequence_to_flush:
        LOG.info(f"{datetime.datetime.now()}: Flushing to the Disk at Final Batch")
        _flush_sequences(
            sequence_to_flush, output_folder_name, f"{uuid.uuid4()}-last.parquet"
        )
|
303
|
+
|
304
|
+
|
305
|
+
def create_arg_parser():
    """Build the CLI parser for batch patient-sequence generation.

    Starts from the shared parser returned by
    ``create_inference_base_arg_parser`` and adds the options specific to this
    script.
    """
    base_arg_parser = create_inference_base_arg_parser(
        description="Arguments for generating patient sequences"
    )
    base_arg_parser.add_argument(
        "--num_of_patients",
        dest="num_of_patients",
        action="store",
        type=int,
        help="The number of patients that will be generated",
        required=True,
    )
    base_arg_parser.add_argument(
        "--demographic_data_path",
        dest="demographic_data_path",
        action="store",
        help="The path to the demographic data whose sequences are sampled as generation prompts",
        required=True,
    )
    base_arg_parser.add_argument(
        "--drop_long_sequences",
        dest="drop_long_sequences",
        action="store_true",
        help="Only sample prompts from patients whose sequence length does not exceed the model's n_positions",
    )
    return base_arg_parser
|
330
|
+
|
331
|
+
|
332
|
+
if __name__ == "__main__":
    # Parse the command line first, then hand the namespace to main().
    cli_args = create_arg_parser().parse_args()
    main(cli_args)
|