cehrgpt 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,7 +2,18 @@ import datetime
2
2
  from typing import Any, Dict
3
3
 
4
4
  import numpy as np
5
- from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import DatasetMapping
5
+ import pandas as pd
6
+ from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
7
+ ED_VISIT_TYPE_CODES,
8
+ INPATIENT_VISIT_TYPE_CODES,
9
+ INPATIENT_VISIT_TYPES,
10
+ DatasetMapping,
11
+ replace_escape_chars,
12
+ )
13
+ from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
14
+ from cehrbert_data.const.common import NA
15
+ from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
16
+ from dateutil.relativedelta import relativedelta
6
17
 
7
18
  from cehrgpt.models.tokenization_hf_cehrgpt import (
8
19
  NONE_BIN,
@@ -17,6 +28,261 @@ def convert_date_to_posix_time(index_date: datetime.date) -> float:
17
28
  ).timestamp()
18
29
 
19
30
 
31
class MedToCehrGPTDatasetMapping(DatasetMapping):
    """Convert a MEDS-style patient record into the CEHR-GPT token format.

    MEDS: https://github.com/Medical-Event-Data-Standard/meds/tree/main

    Fields read from the incoming record:
      - ``patient_id``, ``birth_datetime``, ``gender``, ``race``;
      - ``visits``: each visit has ``visit_start_datetime``, ``visit_type`` and
        ``events``, and optionally ``visit_end_datetime`` and ``discharge_facility``;
      - each event has ``time`` and ``code``, and optionally ``numeric_value``,
        ``text_value`` and ``unit``;
      - optionally ``label`` and ``age_at_index``, copied through unchanged.

    The emitted token sequence is a demographic prompt (``year:<YYYY>``,
    ``age:<N>``, gender, race -- computed at the earliest visit) followed by,
    for every visit that has events and in visit-start order: an artificial
    time token (from the second emitted visit on), ``[VS]``, the visit type,
    the event codes, an optional discharge-facility token and ``[VE]``.
    """

    def __init__(
        self,
        data_args: DataTrainingArguments,
        is_pretraining: bool = True,
        include_inpatient_hour_token: bool = True,
    ):
        # Generates the artificial time token placed between two visits.
        self._time_token_function = get_att_function(data_args.att_function_type)
        # When enabled, ER/inpatient visits get a discharge-facility token before [VE].
        self._include_auxiliary_token = data_args.include_auxiliary_token
        # Generates the "i-" prefixed time tokens between events of an ER/inpatient visit.
        self._inpatient_time_token_function = get_att_function(
            data_args.inpatient_att_function_type
        )
        # Stored for API compatibility; not read within this class.
        self._include_demographic_prompt = data_args.include_demographic_prompt
        # Controls which input columns remove_columns() reports.
        self._is_pretraining = is_pretraining
        # Stored for API compatibility; not read within this class.
        self._include_inpatient_hour_token = include_inpatient_hour_token

    def remove_columns(self):
        """Return the input columns to drop after the mapping is applied."""
        if self._is_pretraining:
            return ["visits", "birth_datetime", "index_date"]
        else:
            return [
                "visits",
                "birth_datetime",
                "visit_concept_ids",
            ]

    @staticmethod
    def _update_cehrgpt_record(
        cehrgpt_record: Dict[str, Any],
        code: str,
        concept_value_mask: int = 0,
        number_as_value: float = 0.0,
        concept_as_value: str = "0",
        is_numeric_type: int = 0,
        unit: str = NA,
    ) -> None:
        """Append one token and its value attributes to the parallel lists of ``cehrgpt_record``."""
        cehrgpt_record["concept_ids"].append(replace_escape_chars(code))
        cehrgpt_record["concept_value_masks"].append(concept_value_mask)
        cehrgpt_record["number_as_values"].append(number_as_value)
        cehrgpt_record["concept_as_values"].append(concept_as_value)
        cehrgpt_record["units"].append(unit)
        cehrgpt_record["is_numeric_types"].append(is_numeric_type)

    @staticmethod
    def _weeks_since_epoch(timestamp) -> int:
        """Return the number of whole weeks between 1970-01-01 and ``timestamp``."""
        return (timestamp - datetime.datetime(year=1970, month=1, day=1)).days // 7

    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Convert one patient record into CEHR-GPT token lists.

        :param record: patient record with the fields described on the class.
        :return: dict with ``person_id``, the parallel token lists
            (``concept_ids``, ``concept_value_masks``, ``number_as_values``,
            ``concept_as_values``, ``units``, ``is_numeric_types``), ``orders``,
            ``num_of_concepts``, ``num_of_visits`` and, when present in the
            input, ``label`` / ``age_at_index``.
        :raises IndexError: if ``record["visits"]`` is empty.
        """
        cehrgpt_record = {
            "person_id": record["patient_id"],
            "concept_ids": [],
            "concept_value_masks": [],
            "number_as_values": [],
            "concept_as_values": [],
            "units": [],
            "is_numeric_types": [],
        }
        birth_datetime = record["birth_datetime"]
        if isinstance(birth_datetime, pd.Timestamp):
            birth_datetime = birth_datetime.to_pydatetime()
        gender = record["gender"]
        race = record["race"]

        # Sort once so the demographic prompt and the timeline agree on which
        # visit comes first (the input list is not guaranteed to be ordered).
        sorted_visits = sorted(
            record["visits"], key=lambda v: v["visit_start_datetime"]
        )
        first_visit_start = sorted_visits[0]["visit_start_datetime"]

        # Demographic prompt: calendar year and age at the earliest visit, gender, race.
        self._update_cehrgpt_record(cehrgpt_record, f"year:{first_visit_start.year}")
        self._update_cehrgpt_record(
            cehrgpt_record,
            f"age:{relativedelta(first_visit_start, birth_datetime).years}",
        )
        self._update_cehrgpt_record(cehrgpt_record, gender)
        self._update_cehrgpt_record(cehrgpt_record, race)

        # The cursor tracks the latest timestamp already represented in the timeline.
        date_cursor = None
        for visit in sorted_visits:
            date_cursor = self._add_visit_tokens(cehrgpt_record, visit, date_cursor)

        # 1-based positions expected by the cehrbert dataset mapping functions.
        cehrgpt_record["orders"] = list(
            range(1, len(cehrgpt_record["concept_ids"]) + 1)
        )
        cehrgpt_record["num_of_concepts"] = len(cehrgpt_record["concept_ids"])
        # NOTE(review): this counts every input visit, including those skipped
        # for having no events -- confirm that is the intended statistic.
        cehrgpt_record["num_of_visits"] = len(record["visits"])

        if "label" in record:
            cehrgpt_record["label"] = record["label"]
        if "age_at_index" in record:
            cehrgpt_record["age_at_index"] = record["age_at_index"]

        return cehrgpt_record

    def _add_visit_tokens(
        self,
        cehrgpt_record: Dict[str, Any],
        visit: Dict[str, Any],
        date_cursor: Any,
    ) -> Any:
        """Append the tokens of one visit and return the updated time cursor.

        Visits without events are skipped entirely and leave the cursor unchanged.
        """
        events = visit["events"]
        if events is None or len(events) == 0:
            return date_cursor

        visit_start_datetime = visit["visit_start_datetime"]
        # Artificial time token for the gap since the previous emitted visit.
        if date_cursor is not None:
            time_delta = (visit_start_datetime - date_cursor).days
            self._update_cehrgpt_record(
                cehrgpt_record,
                code=self._time_token_function(time_delta),
            )
        date_cursor = visit_start_datetime

        visit_type = visit["visit_type"]
        is_er_or_inpatient = (
            visit_type in INPATIENT_VISIT_TYPES
            or visit_type in INPATIENT_VISIT_TYPE_CODES
            or visit_type in ED_VISIT_TYPE_CODES
        )
        self._update_cehrgpt_record(cehrgpt_record, code="[VS]")
        self._update_cehrgpt_record(cehrgpt_record, code=visit_type)

        # Outpatient visits are assumed to start and end on the same day, so their
        # events are deduplicated on (visit week, code, value) combinations.
        visit_week = self._weeks_since_epoch(visit_start_datetime)
        seen_outpatient_events = set()

        for event in events:
            # Events without a timestamp are ignored.
            if not event["time"]:
                continue

            numeric_value = event.get("numeric_value", None)
            text_value = event.get("text_value", None)
            # The unit may be missing or explicitly None/empty.
            unit = event.get("unit") or NA
            concept_value_mask = int(
                numeric_value is not None or text_value is not None
            )
            is_numeric_type = int(numeric_value is not None)
            code = replace_escape_chars(event["code"])

            if is_er_or_inpatient:
                # Patients can stay for days: emit an "i-" time token whenever
                # at least one day has passed since the cursor.
                meas_time_diff = (event["time"] - date_cursor).days
                if meas_time_diff > 0:
                    date_cursor = event["time"]
                    if self._inpatient_time_token_function:
                        self._update_cehrgpt_record(
                            cehrgpt_record,
                            code=f"i-{self._inpatient_time_token_function(meas_time_diff)}",
                        )
            else:
                dedup_key = (
                    visit_week,
                    code,
                    numeric_value,
                    text_value,
                    concept_value_mask,
                )
                if dedup_key in seen_outpatient_events:
                    continue
                seen_outpatient_events.add(dedup_key)

            self._update_cehrgpt_record(
                cehrgpt_record,
                code=code,
                concept_value_mask=concept_value_mask,
                unit=unit,
                number_as_value=numeric_value if numeric_value else 0.0,
                concept_as_value=(
                    replace_escape_chars(text_value) if text_value else "0"
                ),
                is_numeric_type=is_numeric_type,
            )

        if is_er_or_inpatient:
            # Advance the cursor to the visit end so the next gap is measured from discharge.
            visit_end_datetime = visit.get("visit_end_datetime", None)
            if visit_end_datetime:
                date_cursor = visit_end_datetime

            if self._include_auxiliary_token:
                discharge_facility = visit.get("discharge_facility") or "0"
                self._update_cehrgpt_record(cehrgpt_record, code=discharge_facility)

        self._update_cehrgpt_record(cehrgpt_record, code="[VE]")
        return date_cursor
284
+
285
+
20
286
  class HFCehrGptTokenizationMapping(DatasetMapping):
21
287
  def __init__(
22
288
  self,
@@ -0,0 +1,71 @@
1
+ import torch
2
+ from torch.nn.utils.rnn import pad_sequence
3
+
4
+ from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
5
+
6
+
7
class CehrGptDPODataCollator(CehrGptDataCollator):
    """Collate DPO preference pairs into a single padded batch.

    Every example carries two parallel sequences, one under the ``chosen_``
    prefix and one under the ``rejected_`` prefix: ``{prefix}_input_ids`` and,
    when ``include_values`` is set, ``{prefix}_value_indicators`` and
    ``{prefix}_values``. Each side is padded independently and both are merged
    into one dictionary of tensors.
    """

    def _collect_reversed(self, examples, key):
        """Convert ``example[key]`` of every example to a tensor and pass it through ``_try_reverse_tensor``."""
        return [
            self._try_reverse_tensor(self._convert_to_tensor(example[key]))
            for example in examples
        ]

    def create_preference_inputs(self, examples, prefix):
        """Pad the ``prefix`` side of ``examples`` into batch tensors.

        :param examples: list of example dicts.
        :param prefix: ``"chosen"`` or ``"rejected"``.
        :return: dict with ``{prefix}_input_ids`` (int64), ``{prefix}_attention_mask``
            (float) and, when ``include_values`` is set, ``{prefix}_value_indicators``
            and ``{prefix}_values``.
        """
        batch = {}
        batch_input_ids = self._collect_reversed(examples, f"{prefix}_input_ids")
        # The mask is 1 over every real token; padding below fills the rest with 0.
        batch_attention_mask = [
            torch.ones_like(input_ids, dtype=torch.float)
            for input_ids in batch_input_ids
        ]
        # Pad sequences to the max length in the batch
        batch[f"{prefix}_input_ids"] = self._try_reverse_tensor(
            pad_sequence(
                batch_input_ids,
                batch_first=True,
                padding_value=self.tokenizer.pad_token_id,
            ).to(torch.int64)
        )
        batch[f"{prefix}_attention_mask"] = self._try_reverse_tensor(
            pad_sequence(batch_attention_mask, batch_first=True, padding_value=0.0)
        )
        assert batch[f"{prefix}_input_ids"].shape[1] <= self.max_length
        assert batch[f"{prefix}_attention_mask"].shape[1] <= self.max_length

        if self.include_values:
            batch_value_indicators = self._collect_reversed(
                examples, f"{prefix}_value_indicators"
            )
            # Key written by the DPO tokenization mapping is "{prefix}_values"
            # (single underscore).
            batch_values = self._collect_reversed(examples, f"{prefix}_values")

            batch[f"{prefix}_value_indicators"] = self._try_reverse_tensor(
                pad_sequence(
                    batch_value_indicators, batch_first=True, padding_value=False
                )
            )
            batch[f"{prefix}_values"] = self._try_reverse_tensor(
                pad_sequence(batch_values, batch_first=True, padding_value=-1.0)
            )
            assert batch[f"{prefix}_value_indicators"].shape[1] <= self.max_length
            assert batch[f"{prefix}_values"].shape[1] <= self.max_length
        return batch

    def __call__(self, examples):
        """Collate both the chosen and the rejected side of ``examples`` into one dict."""
        batch = self.create_preference_inputs(examples, "chosen")
        batch.update(self.create_preference_inputs(examples, "rejected"))
        return batch
@@ -0,0 +1,61 @@
1
+ import copy
2
+ from typing import Any, Dict
3
+
4
+ import numpy as np
5
+ from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import DatasetMapping
6
+
7
+ from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
8
+
9
+
10
class HFCehrGptDPOTokenizationMapping(DatasetMapping):
    """Tokenize the chosen and rejected sides of a DPO preference record.

    For each prefix (``chosen`` / ``rejected``) this writes ``{prefix}_input_ids``
    and, when value masks are present, normalizes lab values and exposes them as
    ``{prefix}_value_indicators`` / ``{prefix}_values``.
    """

    def __init__(
        self,
        concept_tokenizer: CehrGptTokenizer,
    ):
        self._concept_tokenizer = concept_tokenizer
        # Frozen into a set: membership is tested once per token in every record.
        self._lab_token_ids = frozenset(self._concept_tokenizer.lab_token_ids)

    def transform_with_prefix(self, record: Dict[str, Any], prefix) -> Dict[str, Any]:
        """Tokenize one side of ``record`` in place and return the same dict.

        :param record: input row; mutated in place.
        :param prefix: ``"chosen"`` or ``"rejected"``.
        """
        concept_ids = record[f"{prefix}_concept_ids"]
        input_ids = self._concept_tokenizer.encode(concept_ids)
        record[f"{prefix}_input_ids"] = input_ids

        if f"{prefix}_concept_value_masks" in record:
            concept_value_masks = record[f"{prefix}_concept_value_masks"]
            concept_values = record[f"{prefix}_concept_values"]
            # The generation script stores None when a value column is unavailable;
            # only normalize when there are real masks/values and at least one is set.
            has_values = (
                concept_value_masks is not None
                and concept_values is not None
                and np.any(np.asarray(concept_value_masks) > 0)
            )
            if has_values:
                units = record[f"{prefix}_units"]
                normalized_concept_values = copy.deepcopy(concept_values)
                for i, (concept_id, unit, token_id, concept_value) in enumerate(
                    zip(concept_ids, units, input_ids, concept_values)
                ):
                    if token_id in self._lab_token_ids:
                        normalized_concept_values[i] = self._concept_tokenizer.normalize(
                            concept_id, unit, concept_value
                        )
                record[f"{prefix}_concept_values"] = normalized_concept_values
            # Expose the values under the column names the data collator reads.
            record[f"{prefix}_value_indicators"] = record[
                f"{prefix}_concept_value_masks"
            ]
            record[f"{prefix}_values"] = record[f"{prefix}_concept_values"]
        return record

    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Tokenize both the chosen and the rejected sides of ``record``."""
        record = self.transform_with_prefix(record, prefix="chosen")
        return self.transform_with_prefix(record, prefix="rejected")
@@ -0,0 +1,224 @@
1
+ import datetime
2
+ import os
3
+ import random
4
+ import uuid
5
+
6
+ import pandas as pd
7
+ import torch
8
+ from cehrbert.runners.runner_util import load_parquet_as_dataset
9
+ from transformers.utils import is_flash_attn_2_available, logging
10
+
11
+ from cehrgpt.cehrgpt_args import create_inference_base_arg_parser
12
+ from cehrgpt.generation.generate_batch_hf_gpt_sequence import (
13
+ generate_single_batch,
14
+ normalize_value,
15
+ )
16
+ from cehrgpt.gpt_utils import get_cehrgpt_output_folder
17
+ from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
18
+ from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
19
+
20
LOG = logging.get_logger("transformers")

# Column layout of every parquet file written by this script.
_OUTPUT_COLUMNS = [
    "person_id",
    "chosen_concept_ids",
    "chosen_concept_values",
    "chosen_concept_value_masks",
    "chosen_units",
    "prompt_length",
    "rejected_concept_ids",
    "rejected_concept_values",
    "rejected_concept_value_masks",
    "rejected_units",
]


def _flush_sequences(sequences, output_folder_name, suffix=""):
    """Write ``sequences`` (row dicts) to a new uuid-named parquet file in ``output_folder_name``."""
    pd.DataFrame(sequences, columns=_OUTPUT_COLUMNS).to_parquet(
        os.path.join(output_folder_name, f"{uuid.uuid4()}{suffix}.parquet")
    )


def main(args):
    """Generate (chosen, rejected) patient-sequence pairs for DPO training.

    For each batch a real patient sequence is sampled from
    ``args.sequence_data_path``. A random-length prefix of it is used as the
    prompt, and the model's continuation becomes the "rejected" sequence, while
    the real sequence is the "chosen" one. Pairs are buffered and written to
    parquet files under ``<output_folder>/<model folder name>/generated_sequences``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_flash_attention = is_flash_attn_2_available()

    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_folder)
    cehrgpt_model = (
        CEHRGPT2LMHeadModel.from_pretrained(
            args.model_folder,
            attn_implementation=(
                "flash_attention_2" if use_flash_attention else "eager"
            ),
            torch_dtype=(torch.bfloat16 if use_flash_attention else torch.float32),
        )
        .eval()
        .to(device)
    )
    cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
    cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
    # NOTE(review): bos_token_id is set to the end token id -- confirm this is intended.
    cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id

    folder_name = get_cehrgpt_output_folder(args, cehrgpt_tokenizer)
    output_folder_name = os.path.join(
        args.output_folder, folder_name, "generated_sequences"
    )
    os.makedirs(output_folder_name, exist_ok=True)

    LOG.info(f"Loading tokenizer at {args.tokenizer_folder}")
    LOG.info(f"Loading model at {args.model_folder}")
    LOG.info(f"Write sequences to {output_folder_name}")
    LOG.info(f"Context window {args.context_window}")
    LOG.info(f"Temperature {args.temperature}")
    LOG.info(f"Repetition Penalty {args.repetition_penalty}")
    LOG.info(f"Sampling Strategy {args.sampling_strategy}")
    LOG.info(f"Num beam {args.num_beams}")
    LOG.info(f"Num beam groups {args.num_beam_groups}")
    LOG.info(f"Epsilon cutoff {args.epsilon_cutoff}")
    LOG.info(f"Top P {args.top_p}")
    LOG.info(f"Top K {args.top_k}")
    LOG.info(f"Loading sequence_data_path at {args.sequence_data_path}")

    dataset = load_parquet_as_dataset(args.sequence_data_path)
    total_rows = len(dataset)
    num_of_batches = args.num_of_patients // args.batch_size + 1
    sequence_to_flush = []
    for i in range(num_of_batches):
        LOG.info(f"{datetime.datetime.now()}: Batch {i} started")
        # NOTE(review): each batch holds exactly one sampled patient (k=1), so
        # batch_size only affects num_of_batches. This keeps prompts the same
        # length within a batch; confirm it is intended. The loop also never
        # terminates if no row satisfies the length filter.
        sample_data = []
        while not sample_data:
            random_indices = random.sample(range(total_rows), k=1)
            for row in dataset.select(random_indices):
                if 4 <= len(row["concept_ids"]) <= cehrgpt_model.config.n_positions:
                    sample_data.append(row)

        prompts = []
        chosen_responses = []
        cutoff_frac = random.uniform(0, args.cutoff_frac_max)
        for row in sample_data:
            seq_len = len(row["concept_ids"])
            # Keep at least the 4 leading tokens as the prompt.
            prompt_len = max(4, int(seq_len * cutoff_frac))
            prompts.append(cehrgpt_tokenizer.encode(row["concept_ids"][:prompt_len]))
            chosen_responses.append(
                {
                    "person_id": row["person_id"],
                    "chosen_concept_ids": row["concept_ids"],
                    "chosen_concept_values": row.get("concept_values"),
                    "chosen_concept_value_masks": row.get("concept_value_masks"),
                    "chosen_units": row.get("units"),
                    "prompt_length": prompt_len,
                }
            )

        batch_sequences = generate_single_batch(
            cehrgpt_model,
            cehrgpt_tokenizer,
            prompts=prompts,
            max_new_tokens=args.context_window,
            mini_num_of_concepts=args.min_num_of_concepts,
            top_p=args.top_p,
            top_k=args.top_k,
            temperature=args.temperature,
            repetition_penalty=args.repetition_penalty,
            num_beams=args.num_beams,
            num_beam_groups=args.num_beam_groups,
            epsilon_cutoff=args.epsilon_cutoff,
            device=device,
        )

        # Clear the cache
        torch.cuda.empty_cache()

        for seq, value_indicator, value, chosen_response in zip(
            batch_sequences["sequences"],
            batch_sequences["value_indicators"],
            batch_sequences["values"],
            chosen_responses,
        ):
            output = {"rejected_concept_ids": seq}
            normalized_values, units = normalize_value(
                seq, value_indicator, value, cehrgpt_tokenizer
            )
            if normalized_values is not None:
                output["rejected_concept_values"] = normalized_values
            if value_indicator is not None:
                output["rejected_concept_value_masks"] = value_indicator
            if units is not None:
                output["rejected_units"] = units
            output.update(chosen_response)
            sequence_to_flush.append(output)

        if len(sequence_to_flush) >= args.buffer_size:
            LOG.info(f"{datetime.datetime.now()}: Flushing to the Disk at Batch {i}")
            _flush_sequences(sequence_to_flush, output_folder_name)
            sequence_to_flush.clear()

    if sequence_to_flush:
        LOG.info(f"{datetime.datetime.now()}: Flushing to the Disk at Final Batch")
        _flush_sequences(sequence_to_flush, output_folder_name, suffix="-last")
182
+
183
+
184
def create_arg_parser():
    """Build the CLI parser: shared inference options plus the DPO pair-generation options."""
    parser = create_inference_base_arg_parser(
        description="Arguments for generating paired patient sequences"
    )
    # Option name -> extra add_argument keywords; every option uses dest=name, action="store".
    option_specs = {
        "num_of_patients": dict(
            type=int,
            help="The number of patients that will be generated",
            required=True,
        ),
        "sequence_data_path": dict(
            help="The path for your sequence data",
            required=True,
        ),
        "cutoff_frac_max": dict(
            type=float,
            help="The max fraction of the patient sequences that will be used for prompting",
            required=False,
            default=0.5,
        ),
        "num_proc": dict(
            type=int,
            required=False,
            default=1,
        ),
    }
    for option_name, extra in option_specs.items():
        parser.add_argument(
            f"--{option_name}", dest=option_name, action="store", **extra
        )
    return parser
221
+
222
+
223
if __name__ == "__main__":
    # Script entry point: parse the command line and run the generator.
    cli_args = create_arg_parser().parse_args()
    main(cli_args)
@@ -35,6 +35,7 @@ from cehrgpt.models.tokenization_hf_cehrgpt import END_TOKEN
35
35
# TODO: move these to cehrbert_data
STOP_TOKENS = ["VE", "[VE]", END_TOKEN]

# Placeholder for out-of-vocabulary concepts; gpt_to_omop_converter_batch skips these events.
OOV = "[OOV]"
CURRENT_PATH = Path(__file__).parent
START_TOKEN_SIZE = 4
ATT_TIME_TOKENS = generate_artificial_time_tokens()
@@ -297,6 +298,8 @@ def gpt_to_omop_converter_batch(
297
298
  inpatient_visit_indicator = False
298
299
 
299
300
  for event_idx, event in enumerate(clinical_events, 0):
301
+ if event == OOV:
302
+ continue
300
303
  # For bad sequences, we don't proceed further and break from the for loop
301
304
  if bad_sequence:
302
305
  break
@@ -1766,6 +1766,7 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
1766
1766
  output_attentions: Optional[bool] = None,
1767
1767
  output_hidden_states: Optional[bool] = None,
1768
1768
  return_dict: Optional[bool] = None,
1769
+ **kwargs,
1769
1770
  ) -> CehrGptSequenceClassifierOutput:
1770
1771
  cehrgpt_output = self.cehrgpt(
1771
1772
  input_ids=input_ids,
@@ -918,12 +918,12 @@ class CehrGptTokenizer(PreTrainedTokenizer):
918
918
  map_statistics_partial = partial(map_statistics, size=SAMPLE_SIZE)
919
919
 
920
920
  if data_args.streaming:
921
+ first_example = next(iter(dataset))
921
922
  parts = dataset.map(
922
923
  partial(agg_helper, map_func=map_statistics_partial),
923
924
  batched=True,
924
925
  batch_size=data_args.preprocessing_batch_size,
925
- new_fingerprint="invalid",
926
- remove_columns=dataset.column_names,
926
+ remove_columns=first_example.keys(),
927
927
  )
928
928
  else:
929
929
  parts = dataset.map(
File without changes