cehrgpt 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,8 +1,10 @@
1
1
  import datetime
2
- from typing import Any, Dict, Generator, Optional
2
+ from collections import defaultdict
3
+ from typing import Any, Dict, Generator, List, Optional, Union
3
4
 
4
5
  import numpy as np
5
6
  import pandas as pd
7
+ from cehrbert.data_generators.hf_data_generator import UNKNOWN_VALUE
6
8
  from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
7
9
  ED_VISIT_TYPE_CODES,
8
10
  INPATIENT_VISIT_TYPE_CODES,
@@ -15,9 +17,17 @@ from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
15
17
  )
16
18
  from cehrbert.med_extension.schema_extension import Event
17
19
  from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
20
+ from cehrbert_data.const.artificial_tokens import (
21
+ DISCHARGE_UNKNOWN_TOKEN,
22
+ GENDER_UNKNOWN_TOKEN,
23
+ RACE_UNKNOWN_TOKEN,
24
+ VISIT_UNKNOWN_TOKEN,
25
+ )
18
26
  from cehrbert_data.const.common import NA
19
27
  from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
28
+ from datasets.formatting.formatting import LazyBatch
20
29
  from dateutil.relativedelta import relativedelta
30
+ from pandas import Series
21
31
 
22
32
  from cehrgpt.models.tokenization_hf_cehrgpt import (
23
33
  NONE_BIN,
@@ -25,6 +35,17 @@ from cehrgpt.models.tokenization_hf_cehrgpt import (
25
35
  CehrGptTokenizer,
26
36
  )
27
37
 
38
# Names of the record columns that are considered when per-position
# filtering is applied to a CEHR-GPT sequence.
CEHRGPT_COLUMNS = [
    "concept_ids", "concept_value_masks",
    "number_as_values", "concept_as_values",
    "is_numeric_types", "concept_values",
    "units", "epoch_times",
]
48
+
28
49
 
29
50
  def convert_date_to_posix_time(index_date: datetime.date) -> float:
30
51
  return datetime.datetime.combine(
@@ -32,7 +53,36 @@ def convert_date_to_posix_time(index_date: datetime.date) -> float:
32
53
  ).timestamp()
33
54
 
34
55
 
35
- class MedToCehrGPTDatasetMapping(DatasetMapping):
56
class DatasetMappingDecorator(DatasetMapping):
    """
    Dataset mapping base class that sanitises ``index_date`` before batching.

    ``batch_transform`` removes the ``index_date`` column from a batch when the
    column contains null values and then delegates to the parent class.
    Subclasses must provide ``transform``.
    """

    def batch_transform(
        self, records: Union[LazyBatch, Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Drop index_date if it contains None, then delegate to the parent class.

        :param records: the batch, either a ``LazyBatch`` or a plain dict of columns
        :return: the result of the parent ``batch_transform`` on the cleaned batch
        """
        if isinstance(records, LazyBatch):
            table = records.pa_table
            if (
                "index_date" in table.column_names
                and table.column("index_date").null_count > 0
            ):
                # The batch has to be rebuilt around the reduced table so the
                # dropped column is no longer exposed.
                table = table.drop(["index_date"])
                records = LazyBatch(pa_table=table, formatter=records.formatter)
        elif "index_date" in records and np.any(pd.isna(records["index_date"])):
            # Check every value (not just the first one) so this branch agrees
            # with the LazyBatch branch and does not fail on an empty column.
            del records["index_date"]
        return super().batch_transform(records=records)

    def transform(self, record: Dict[str, Any]) -> Union[Dict[str, Any], Series]:
        """
        Transform a single record; must be overridden by subclasses.

        :param record: a single example
        :raises NotImplementedError: always, in this base class
        """
        # ``NotImplemented`` is a comparison sentinel and is not callable; the
        # exception type for an unimplemented method is ``NotImplementedError``.
        raise NotImplementedError("Must be implemented")
83
+
84
+
85
+ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
36
86
  def __init__(
37
87
  self,
38
88
  data_args: DataTrainingArguments,
@@ -65,6 +115,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
65
115
  def _update_cehrgpt_record(
66
116
  cehrgpt_record: Dict[str, Any],
67
117
  code: str,
118
+ time: datetime.datetime,
68
119
  concept_value_mask: int = 0,
69
120
  number_as_value: float = 0.0,
70
121
  concept_as_value: str = "0",
@@ -77,6 +128,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
77
128
  cehrgpt_record["concept_as_values"].append(concept_as_value)
78
129
  cehrgpt_record["units"].append(unit)
79
130
  cehrgpt_record["is_numeric_types"].append(is_numeric_type)
131
+ cehrgpt_record["epoch_times"].append(time.timestamp())
80
132
 
81
133
  def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
82
134
  cehrgpt_record = {
@@ -87,13 +139,16 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
87
139
  "concept_as_values": [],
88
140
  "units": [],
89
141
  "is_numeric_types": [],
142
+ "epoch_times": [],
90
143
  }
91
144
  # Extract the demographic information
92
145
  birth_datetime = record["birth_datetime"]
93
146
  if isinstance(birth_datetime, pd.Timestamp):
94
147
  birth_datetime = birth_datetime.to_pydatetime()
95
148
  gender = record["gender"]
149
+ gender = GENDER_UNKNOWN_TOKEN if gender == UNKNOWN_VALUE else gender
96
150
  race = record["race"]
151
+ race = RACE_UNKNOWN_TOKEN if race == UNKNOWN_VALUE else race
97
152
  visits = record["visits"]
98
153
  # This indicates this is columnar format
99
154
  if isinstance(visits, dict):
@@ -108,10 +163,12 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
108
163
  )
109
164
  year_str = f"year:{str(first_visit_start_datetime.year)}"
110
165
  age_str = f"age:{str(relativedelta(first_visit_start_datetime, birth_datetime).years)}"
111
- self._update_cehrgpt_record(cehrgpt_record, year_str)
112
- self._update_cehrgpt_record(cehrgpt_record, age_str)
113
- self._update_cehrgpt_record(cehrgpt_record, gender)
114
- self._update_cehrgpt_record(cehrgpt_record, race)
166
+ self._update_cehrgpt_record(
167
+ cehrgpt_record, year_str, first_visit_start_datetime
168
+ )
169
+ self._update_cehrgpt_record(cehrgpt_record, age_str, first_visit_start_datetime)
170
+ self._update_cehrgpt_record(cehrgpt_record, gender, first_visit_start_datetime)
171
+ self._update_cehrgpt_record(cehrgpt_record, race, first_visit_start_datetime)
115
172
 
116
173
  # Use a data cursor to keep track of time
117
174
  datetime_cursor: Optional[datetime.datetime] = None
@@ -146,6 +203,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
146
203
  self._update_cehrgpt_record(
147
204
  cehrgpt_record,
148
205
  code=self._time_token_function(time_delta),
206
+ time=visit_start_datetime,
149
207
  )
150
208
 
151
209
  datetime_cursor = visit_start_datetime
@@ -153,11 +211,13 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
153
211
  self._update_cehrgpt_record(
154
212
  cehrgpt_record,
155
213
  code="[VS]",
214
+ time=datetime_cursor,
156
215
  )
157
216
  # Add a visit type token
158
217
  self._update_cehrgpt_record(
159
218
  cehrgpt_record,
160
219
  code=visit_type,
220
+ time=datetime_cursor,
161
221
  )
162
222
  # We need to insert an inpatient hour token right after the visit type, we calculate the hour interval
163
223
  # with respect to the midnight of the day
@@ -167,6 +227,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
167
227
  self._update_cehrgpt_record(
168
228
  cehrgpt_record,
169
229
  code=f"i-H{datetime_cursor.hour}",
230
+ time=datetime_cursor,
170
231
  )
171
232
 
172
233
  # Keep track of the existing outpatient events, we don't want to add them again
@@ -185,6 +246,10 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
185
246
  concept_value_mask = int(
186
247
  numeric_value is not None or text_value is not None
187
248
  )
249
+ if numeric_value is None and text_value is not None:
250
+ if text_value.isnumeric():
251
+ numeric_value = float(text_value)
252
+
188
253
  is_numeric_type = int(numeric_value is not None)
189
254
  code = replace_escape_chars(e["code"])
190
255
 
@@ -208,6 +273,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
208
273
  self._update_cehrgpt_record(
209
274
  cehrgpt_record,
210
275
  code=f"i-{self._inpatient_time_token_function(time_diff_days)}",
276
+ time=event_time,
211
277
  )
212
278
 
213
279
  if self._include_inpatient_hour_token:
@@ -226,6 +292,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
226
292
  self._update_cehrgpt_record(
227
293
  cehrgpt_record,
228
294
  code=f"i-H{time_diff_hours}",
295
+ time=event_time,
229
296
  )
230
297
 
231
298
  if event_identity in existing_duplicate_events:
@@ -234,6 +301,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
234
301
  self._update_cehrgpt_record(
235
302
  cehrgpt_record,
236
303
  code=code,
304
+ time=event_time,
237
305
  concept_value_mask=concept_value_mask,
238
306
  unit=unit,
239
307
  number_as_value=numeric_value if numeric_value else 0.0,
@@ -262,17 +330,24 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
262
330
  # facility event
263
331
  discharge_facility = get_value(visit, "discharge_facility")
264
332
  if not discharge_facility:
265
- discharge_facility = "0"
266
-
333
+ discharge_facility = DISCHARGE_UNKNOWN_TOKEN
334
+ else:
335
+ discharge_facility = (
336
+ DISCHARGE_UNKNOWN_TOKEN
337
+ if discharge_facility == UNKNOWN_VALUE
338
+ else discharge_facility
339
+ )
267
340
  self._update_cehrgpt_record(
268
341
  cehrgpt_record,
269
342
  code=discharge_facility,
343
+ time=datetime_cursor,
270
344
  )
271
345
 
272
346
  # Reuse the age and date calculated for the last event in the patient timeline
273
347
  self._update_cehrgpt_record(
274
348
  cehrgpt_record,
275
349
  code="[VE]",
350
+ time=datetime_cursor,
276
351
  )
277
352
 
278
353
  # Generate the orders of the concepts that the cehrbert dataset mapping function expects
@@ -284,17 +359,21 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
284
359
  cehrgpt_record["num_of_concepts"] = len(cehrgpt_record["concept_ids"])
285
360
  cehrgpt_record["num_of_visits"] = len(visits)
286
361
 
287
- if record.get("index_date", None):
362
+ if record.get("index_date", None) is not None:
288
363
  cehrgpt_record["index_date"] = record["index_date"]
289
- if record.get("label", None):
364
+ if record.get("label", None) is not None:
290
365
  cehrgpt_record["label"] = record["label"]
291
- if record.get("age_at_index", None):
366
+ if record.get("age_at_index", None) is not None:
292
367
  cehrgpt_record["age_at_index"] = record["age_at_index"]
293
368
 
369
+ assert len(cehrgpt_record["epoch_times"]) == len(
370
+ cehrgpt_record["concept_ids"]
371
+ ), "The number of time stamps must match with the number of concepts in the sequence"
372
+
294
373
  return cehrgpt_record
295
374
 
296
375
 
297
- class HFCehrGptTokenizationMapping(DatasetMapping):
376
+ class HFCehrGptTokenizationMapping(DatasetMappingDecorator):
298
377
  def __init__(
299
378
  self,
300
379
  concept_tokenizer: CehrGptTokenizer,
@@ -308,9 +387,46 @@ class HFCehrGptTokenizationMapping(DatasetMapping):
308
387
  "is_numeric_types",
309
388
  ]
310
389
 
390
def filter_out_invalid_tokens(self, record: Dict[str, Any]) -> Dict[str, Any]:
    """
    Remove the sequence positions whose concept is not in the tokenizer vocabulary.

    Any literal ``"0"`` concept is first rewritten to ``"Unknown"``. Then every
    column listed in ``CEHRGPT_COLUMNS`` that is a list/array of the same length
    as ``concept_ids`` is reduced to the positions whose concept id exists in
    the tokenizer vocabulary. Filtered columns are returned as plain lists.

    :param record: a single example; must contain ``concept_ids``
    :return: the same ``record`` object, updated
    """
    concept_ids = record["concept_ids"]
    seq_length = len(concept_ids)

    # We can't have "0" as a token in the tokenizer because it would break
    # tokenization for "Race/0", "Visit/0". This is a precaution.
    if isinstance(concept_ids, np.ndarray):
        zero_mask = concept_ids == "0"
        if np.any(zero_mask):
            # np.where builds a new array whose dtype is wide enough for
            # "Unknown"; assigning in place into a fixed-width unicode array
            # would silently truncate the replacement string.
            record["concept_ids"] = np.where(zero_mask, "Unknown", concept_ids)
    elif "0" in concept_ids:
        record["concept_ids"] = [
            "Unknown" if token == "0" else token for token in concept_ids
        ]

    # Only the known CEHR-GPT columns aligned with the sequence are filtered
    column_names = [
        name
        for name, values in record.items()
        if name in CEHRGPT_COLUMNS
        and isinstance(values, (list, np.ndarray))
        and len(values) == seq_length
    ]

    vocab = self._concept_tokenizer.get_vocab()
    valid_indices = [
        idx
        for idx, concept_id in enumerate(record["concept_ids"])
        if concept_id in vocab
    ]
    # Nothing to do when every concept is known to the tokenizer
    if len(valid_indices) != seq_length:
        for column in column_names:
            values = record[column]
            record[column] = [values[idx] for idx in valid_indices]
    return record
420
+
311
421
  def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
422
+ # Remove the tokens from patient sequences that do not exist in the tokenizer
423
+ record = self.filter_out_invalid_tokens(record)
312
424
  # If any concept has a value associated with it, we normalize the value
313
425
  record["input_ids"] = self._concept_tokenizer.encode(record["concept_ids"])
426
+ assert len(record["input_ids"]) == len(record["concept_ids"]), (
427
+ "The number of tokens must equal to the number of concepts\n"
428
+ f"decoded concept_ids: {self._concept_tokenizer.decode(record['input_ids'], skip_special_tokens=False)}"
429
+ )
314
430
  record["value_indicators"] = record["concept_value_masks"]
315
431
  if "number_as_values" not in record or "concept_as_values" not in record:
316
432
  record["number_as_values"] = [
@@ -391,3 +507,89 @@ class HFFineTuningMapping(HFCehrGptTokenizationMapping):
391
507
  columns = super().remove_columns()
392
508
  columns.append("label")
393
509
  return columns
510
+
511
+
512
class ExtractTokenizedSequenceDataMapping:
    """
    Expand a tokenized patient record into one sample per prediction time.

    For every entry registered for the patient in ``person_index_date_map``,
    the tokens whose ``epoch_times`` fall inside the observation window
    ``[index_date - observation_window days, index_date]`` are extracted.
    """

    # Columns generated per sample by ``transform``. A same-named column in the
    # incoming record must not be copied as well, otherwise that column would
    # receive two values per sample and no longer line up with the others.
    _GENERATED_COLUMNS = ("classifier_label", "index_date", "age_at_index")

    def __init__(
        self,
        person_index_date_map: Dict[int, List[Dict[str, Any]]],
        observation_window: int = 0,
    ):
        """
        :param person_index_date_map: maps a person_id to a list of dicts, each
            holding an ``index_date`` (an object exposing ``timestamp()``) and a ``label``
        :param observation_window: look-back window in days; a value of 0 (or
            less) means the window starts at epoch 0
        """
        self.person_index_date_map = person_index_date_map
        self.observation_window = observation_window

    def _calculate_prediction_start_time(self, prediction_time: float):
        """Return the POSIX time at which the observation window starts (never negative)."""
        if self.observation_window and self.observation_window > 0:
            return max(prediction_time - self.observation_window * 24 * 3600, 0)
        return 0

    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build the samples for one patient record.

        :param record: a single example containing ``person_id``, ``concept_ids``,
            ``input_ids`` and ``epoch_times``
        :return: a dict of lists with one entry per prediction time of the patient
        :raises KeyError: if the person is not in ``person_index_date_map``
        """
        person_id = record["person_id"]
        prediction_times = self.person_index_date_map[person_id]

        # (window start, window end, label) for every prediction time
        prediction_start_end_times = []
        for prediction_time_label_map in prediction_times:
            window_end = prediction_time_label_map["index_date"].timestamp()
            prediction_start_end_times.append(
                (
                    self._calculate_prediction_start_time(window_end),
                    window_end,
                    prediction_time_label_map["label"],
                )
            )

        seq_length = len(record["epoch_times"])
        epoch_times = np.asarray(record["epoch_times"], dtype=float)
        window_starts = np.asarray(
            [start for start, _, _ in prediction_start_end_times], dtype=float
        )
        window_ends = np.asarray(
            [end for _, end, _ in prediction_start_end_times], dtype=float
        )
        # Boolean matrix of shape (number of samples, sequence length); both
        # window bounds are inclusive.
        observation_window_indices = (
            window_starts[:, None] <= epoch_times[None, :]
        ) & (epoch_times[None, :] <= window_ends[:, None])

        # Split the record into per-token columns and per-patient values.
        # concept_ids and input_ids are always treated as per-token columns.
        time_series = {
            "concept_ids": np.asarray(record["concept_ids"]),
            "input_ids": np.asarray(record["input_ids"]),
        }
        static_inputs = dict()
        for k, v in record.items():
            if k in time_series or k in self._GENERATED_COLUMNS:
                continue
            if isinstance(v, (list, np.ndarray)) and len(v) == seq_length:
                time_series[k] = np.asarray(v)
            else:
                static_inputs[k] = v

        # The age is parsed from the second token ("age:<n>"); it is the same
        # for every sample of the record, so it is computed once.
        # NOTE(review): this is the age token at the start of the sequence, not
        # an age recomputed for each index date - confirm this is intended.
        try:
            start_age = int(record["concept_ids"][1].split(":")[1])
        except (IndexError, ValueError, AttributeError, TypeError):
            start_age = -1

        batched_samples = defaultdict(list)
        for (_, index_date, label), observation_window_index in zip(
            prediction_start_end_times, observation_window_indices
        ):
            for k, v in static_inputs.items():
                batched_samples[k].append(v)
            batched_samples["classifier_label"].append(label)
            batched_samples["index_date"].append(index_date)
            batched_samples["age_at_index"].append(start_age)
            for time_series_column, values in time_series.items():
                batched_samples[time_series_column].append(
                    values[observation_window_index]
                )
        return batched_samples

    def batch_transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply ``transform`` to every row of a columnar batch and concatenate the results.

        :param record: a batch in columnar form (column name -> list of row values)
        :return: a dict of lists holding the samples of all the rows
        """
        all_batched_record = defaultdict(list)
        all_columns = list(record.keys())
        for i in range(len(record["concept_ids"])):
            one_record = {column: record[column][i] for column in all_columns}
            for k, v in self.transform(one_record).items():
                all_batched_record[k].extend(v)
        return all_batched_record
@@ -1,5 +1,6 @@
1
1
  from typing import Iterator, List, Optional
2
2
 
3
+ import numpy as np
3
4
  import torch
4
5
  import torch.distributed as dist
5
6
  from torch.utils.data import Sampler
@@ -33,6 +34,8 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
33
34
  rank: Optional[int] = None,
34
35
  seed: int = 0,
35
36
  drop_last: bool = False,
37
+ negative_sampling_probability: Optional[float] = None,
38
+ labels: Optional[List[int]] = None,
36
39
  ):
37
40
  """
38
41
  Args:
@@ -73,6 +76,11 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
73
76
  f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
74
77
  )
75
78
 
79
+ if negative_sampling_probability is not None and labels is None:
80
+ raise ValueError(
81
+ f"When the negative sampling probability is provide, the labels must be provided as well"
82
+ )
83
+
76
84
  self.lengths = lengths
77
85
  self.max_tokens_per_batch = max_tokens_per_batch
78
86
  self.max_position_embeddings = max_position_embeddings
@@ -80,6 +88,8 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
80
88
  self.rank = rank
81
89
  self.seed = seed
82
90
  self.drop_last = drop_last
91
+ self.negative_sampling_probability = negative_sampling_probability
92
+ self.labels = labels
83
93
  # Trainer https://github.com/huggingface/transformers/blame/main/src/transformers/trainer.py#L2470
84
94
  # http://github.com/huggingface/accelerate/blob/v0.31.0/src/accelerate/data_loader.py#L482
85
95
  # the huggingface trainer will call the accelerate.data_loader.DataLoaderShard.set_epoch,
@@ -100,6 +110,14 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
100
110
  current_batch_tokens = 0
101
111
 
102
112
  for idx in indices:
113
+ # There is a chance to skip the negative samples to account for the class imbalance
114
+ # in the fine-tuning dataset
115
+ if self.negative_sampling_probability:
116
+ if (
117
+ np.random.random() > self.negative_sampling_probability
118
+ and self.labels[idx] == 0
119
+ ):
120
+ continue
103
121
  # We take the minimum of the two because each sequence will be truncated to fit
104
122
  # the context window of the model
105
123
  sample_length = min(self.lengths[idx], self.max_position_embeddings)
@@ -131,10 +149,22 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
131
149
  if len(self.lengths) == 0:
132
150
  return 0
133
151
 
134
- # We need to truncate the lengths due to the context window limit imposed by the model
135
- truncated_lengths = [
136
- min(self.max_position_embeddings, length + 2) for length in self.lengths
137
- ]
152
+ # There is a chance to skip the negative samples to account for the class imbalance
153
+ # in the fine-tuning dataset
154
+ if self.negative_sampling_probability:
155
+ truncated_lengths = []
156
+ for length, label in zip(self.lengths, self.labels):
157
+ if (
158
+ np.random.random() > self.negative_sampling_probability
159
+ and label == 0
160
+ ):
161
+ continue
162
+ truncated_lengths.append(length)
163
+ else:
164
+ # We need to truncate the lengths due to the context window limit imposed by the model
165
+ truncated_lengths = [
166
+ min(self.max_position_embeddings, length + 2) for length in self.lengths
167
+ ]
138
168
 
139
169
  # Calculate average sequence length
140
170
  avg_seq_length = sum(truncated_lengths) // len(truncated_lengths)
@@ -145,7 +175,7 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
145
175
  # Estimate total number of batches
146
176
  if self.drop_last:
147
177
  # If dropping last incomplete batch
148
- return len(truncated_lengths) // seqs_per_batch * self.num_replicas
178
+ return len(truncated_lengths) // seqs_per_batch
149
179
  else:
150
180
  # If keeping last incomplete batch, ensure at least 1 batch
151
- return max(1, len(truncated_lengths) // seqs_per_batch) * self.num_replicas
181
+ return max(1, len(truncated_lengths) // seqs_per_batch)
@@ -60,6 +60,24 @@ OOV_CONCEPT_MAP = {
60
60
  }
61
61
 
62
62
 
63
def _parse_prefixed_concept_id(token: str, prefix: str) -> int:
    """Parse ``<prefix><int>`` or a bare numeric token; return 0 when no integer can be read."""
    if token.startswith(prefix):
        candidate = token[len(prefix) :]
    elif token.isnumeric():
        candidate = token
    else:
        return 0
    try:
        return int(candidate)
    except ValueError:
        # Non-integer suffix (e.g. "Gender/F"), empty suffix, or a character
        # that str.isnumeric() accepts but int() cannot parse (e.g. "²").
        return 0


def extract_gender_concept_id(gender_token: str) -> int:
    """
    Extract the gender concept id from a gender token.

    :param gender_token: either ``Gender/<concept_id>`` or a bare numeric string
    :return: the concept id, or 0 when the token does not carry an integer id
    """
    return _parse_prefixed_concept_id(gender_token, "Gender/")


def extract_race_concept_id(race_token: str) -> int:
    """
    Extract the race concept id from a race token.

    :param race_token: either ``Race/<concept_id>`` or a bare numeric string
    :return: the concept id, or 0 when the token does not carry an integer id
    """
    return _parse_prefixed_concept_id(race_token, "Race/")
79
+
80
+
63
81
  def create_folder_if_not_exists(output_folder, table_name):
64
82
  if not os.path.isdir(Path(output_folder) / table_name):
65
83
  os.mkdir(Path(output_folder) / table_name)
@@ -288,7 +306,13 @@ def gpt_to_omop_converter_batch(
288
306
  if int(birth_year) < 1900 or int(birth_year) > datetime.date.today().year:
289
307
  continue
290
308
 
291
- p = Person(person_id, start_gender, birth_year, start_race)
309
+ p = Person(
310
+ person_id=person_id,
311
+ gender_concept_id=extract_gender_concept_id(start_gender),
312
+ year_of_birth=birth_year,
313
+ race_concept_id=extract_race_concept_id(start_race),
314
+ )
315
+
292
316
  append_to_dict(omop_export_dict, p, person_id)
293
317
  id_mappings_dict["person"][person_id] = person_id
294
318
  pt_seq_dict[person_id] = " ".join(concept_ids)
@@ -316,7 +340,12 @@ def gpt_to_omop_converter_batch(
316
340
  id_mappings_dict["death"][person_id] = person_id
317
341
  else:
318
342
  try:
319
- visit_concept_id = int(clinical_events[event_idx + 1])
343
+ if clinical_events[event_idx + 1].startswith("Visit/"):
344
+ visit_concept_id = int(
345
+ clinical_events[event_idx + 1][len("Visit/") :]
346
+ )
347
+ else:
348
+ visit_concept_id = int(clinical_events[event_idx + 1])
320
349
  inpatient_visit_indicator = visit_concept_id in [
321
350
  9201,
322
351
  262,
@@ -349,6 +378,7 @@ def gpt_to_omop_converter_batch(
349
378
  visit_occurrence_id
350
379
  ] = person_id
351
380
  visit_occurrence_id += 1
381
+
352
382
  elif event in ATT_TIME_TOKENS:
353
383
  if event[0] == "D":
354
384
  att_date_delta = int(event[1:])
cehrgpt/gpt_utils.py CHANGED
@@ -11,6 +11,7 @@ from cehrgpt.models.special_tokens import (
11
11
  )
12
12
 
13
13
  # Regular expression pattern to match inpatient attendance tokens
14
+ MEDS_CODE_PATTERN = re.compile(r".*/.*")
14
15
  INPATIENT_ATT_PATTERN = re.compile(r"(?:VS-|i-)D(\d+)(?:-VE)?")
15
16
  DEMOGRAPHIC_PROMPT_SIZE = 4
16
17
 
@@ -194,8 +195,12 @@ def get_cehrgpt_output_folder(args, cehrgpt_tokenizer) -> str:
194
195
  return folder_name
195
196
 
196
197
 
197
- def is_clinical_event(token: str) -> bool:
198
- return token.isnumeric()
198
def is_clinical_event(token: str, meds: bool = False) -> bool:
    """
    Tell whether a token represents a clinical event.

    A numeric token always qualifies. When ``meds`` is set, a token matching
    ``MEDS_CODE_PATTERN`` qualifies as well.
    """
    return token.isnumeric() or bool(meds and MEDS_CODE_PATTERN.match(token))
199
204
 
200
205
 
201
206
  def is_visit_start(token: str):
@@ -212,6 +217,18 @@ def is_visit_end(token: str) -> bool:
212
217
  return token in ["VE", "[VE]"]
213
218
 
214
219
 
220
def is_inpatient_hour_token(token: str) -> bool:
    """Return True when the token begins with the ``i-H`` prefix."""
    prefix = "i-H"
    return token[: len(prefix)] == prefix
222
+
223
+
224
def extract_time_interval_in_hours(token: str) -> int:
    """
    Read the integer that follows the first three characters of the token.

    :param token: a token such as ``i-H12``
    :return: the parsed number, or 0 when the remainder is not an integer
    """
    remainder = token[3:]
    try:
        return int(remainder)
    except ValueError:
        return 0
230
+
231
+
215
232
  def is_att_token(token: str):
216
233
  """
217
234
  Check if the token is an attention token.
@@ -251,6 +268,7 @@ def is_artificial_token(token: str) -> bool:
251
268
  return True
252
269
  if token == END_TOKEN:
253
270
  return True
271
+
254
272
  return False
255
273
 
256
274
 
cehrgpt/models/config.py CHANGED
@@ -121,6 +121,7 @@ class CEHRGPTConfig(PretrainedConfig):
121
121
  bos_token_id=50256,
122
122
  eos_token_id=50256,
123
123
  lab_token_ids=None,
124
+ ve_token_id=None,
124
125
  scale_attn_by_inverse_layer_idx=False,
125
126
  reorder_and_upcast_attn=False,
126
127
  exclude_position_ids=False,
@@ -128,6 +129,10 @@ class CEHRGPTConfig(PretrainedConfig):
128
129
  value_vocab_size=None,
129
130
  include_ttv_prediction=False,
130
131
  use_sub_time_tokenization=True,
132
+ include_motor_time_to_event=True,
133
+ motor_tte_vocab_size=None,
134
+ motor_time_to_event_weight=1.0,
135
+ motor_num_time_pieces=16,
131
136
  token_to_time_token_mapping: Dict[int, List] = None,
132
137
  use_pretrained_embeddings=False,
133
138
  n_pretrained_embeddings_layers=2,
@@ -144,6 +149,7 @@ class CEHRGPTConfig(PretrainedConfig):
144
149
  entropy_penalty=False,
145
150
  entropy_penalty_alpha=0.01,
146
151
  sample_packing_max_positions=None,
152
+ class_weights=None,
147
153
  **kwargs,
148
154
  ):
149
155
  if token_to_time_token_mapping is None:
@@ -192,6 +198,22 @@ class CEHRGPTConfig(PretrainedConfig):
192
198
  self._token_to_time_token_mapping = token_to_time_token_mapping
193
199
  self.time_token_loss_weight = time_token_loss_weight
194
200
  self.time_to_visit_loss_weight = time_to_visit_loss_weight
201
+
202
+ # MOTOR TTE configuration
203
+ self.motor_tte_vocab_size = motor_tte_vocab_size
204
+ self.include_motor_time_to_event = (
205
+ include_motor_time_to_event
206
+ and self.motor_tte_vocab_size
207
+ and self.motor_tte_vocab_size > 0
208
+ )
209
+ if self.include_motor_time_to_event and not ve_token_id:
210
+ raise RuntimeError(
211
+ f"ve_token_id must be provided when include_motor_time_to_event is True"
212
+ )
213
+ self.ve_token_id = ve_token_id
214
+ self.motor_time_to_event_weight = motor_time_to_event_weight
215
+ self.motor_num_time_pieces = motor_num_time_pieces
216
+
195
217
  self.causal_sfm = causal_sfm
196
218
  self.demographics_size = demographics_size
197
219
  self.use_pretrained_embeddings = use_pretrained_embeddings
@@ -206,6 +228,9 @@ class CEHRGPTConfig(PretrainedConfig):
206
228
  self.entropy_penalty_alpha = entropy_penalty_alpha
207
229
  self.value_prediction_loss_weight = value_prediction_loss_weight
208
230
 
231
+ # Class weights for fine-tuning
232
+ self.class_weights = class_weights
233
+
209
234
  kwargs["tie_word_embeddings"] = not use_pretrained_embeddings
210
235
 
211
236
  super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)