cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. cehrgpt/analysis/irregularity.py +36 -0
  2. cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
  3. cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
  4. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
  5. cehrgpt/data/sample_packing_sampler.py +181 -0
  6. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  7. cehrgpt/generation/omop_converter_batch.py +32 -2
  8. cehrgpt/gpt_utils.py +20 -2
  9. cehrgpt/models/config.py +35 -0
  10. cehrgpt/models/hf_cehrgpt.py +470 -106
  11. cehrgpt/models/hf_modeling_outputs.py +1 -0
  12. cehrgpt/models/special_tokens.py +1 -0
  13. cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
  14. cehrgpt/runners/data_utils.py +358 -0
  15. cehrgpt/runners/gpt_runner_util.py +0 -10
  16. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
  17. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
  18. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
  19. cehrgpt/runners/hyperparameter_search_util.py +10 -8
  20. cehrgpt/runners/sample_packing_trainer.py +185 -0
  21. cehrgpt/simulations/generate_plots.py +95 -0
  22. cehrgpt/simulations/run_simulation.sh +24 -0
  23. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  24. cehrgpt/simulations/time_token_simulation.py +177 -0
  25. cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  26. cehrgpt/time_to_event/time_to_event_model.py +2 -13
  27. cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  28. cehrgpt/tools/linear_prob/__init__.py +0 -0
  29. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
  30. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  31. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
  32. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
  33. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
  34. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  35. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  36. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  37. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  38. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  39. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  40. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  41. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  42. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  43. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
  44. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
@@ -1,19 +1,33 @@
1
1
  import datetime
2
- from typing import Any, Dict
2
+ from collections import defaultdict
3
+ from typing import Any, Dict, Generator, List, Optional, Union
3
4
 
4
5
  import numpy as np
5
6
  import pandas as pd
7
+ from cehrbert.data_generators.hf_data_generator import UNKNOWN_VALUE
6
8
  from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
7
9
  ED_VISIT_TYPE_CODES,
8
10
  INPATIENT_VISIT_TYPE_CODES,
9
11
  INPATIENT_VISIT_TYPES,
10
12
  DatasetMapping,
13
+ VisitObject,
14
+ get_value,
15
+ has_events_and_get_events,
11
16
  replace_escape_chars,
12
17
  )
18
+ from cehrbert.med_extension.schema_extension import Event
13
19
  from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
20
+ from cehrbert_data.const.artificial_tokens import (
21
+ DISCHARGE_UNKNOWN_TOKEN,
22
+ GENDER_UNKNOWN_TOKEN,
23
+ RACE_UNKNOWN_TOKEN,
24
+ VISIT_UNKNOWN_TOKEN,
25
+ )
14
26
  from cehrbert_data.const.common import NA
15
27
  from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
28
+ from datasets.formatting.formatting import LazyBatch
16
29
  from dateutil.relativedelta import relativedelta
30
+ from pandas import Series
17
31
 
18
32
  from cehrgpt.models.tokenization_hf_cehrgpt import (
19
33
  NONE_BIN,
@@ -21,6 +35,17 @@ from cehrgpt.models.tokenization_hf_cehrgpt import (
21
35
  CehrGptTokenizer,
22
36
  )
23
37
 
38
# Column names that HFCehrGptTokenizationMapping.filter_out_invalid_tokens
# considers when it removes out-of-vocabulary tokens: only columns listed here
# (and whose length equals the length of "concept_ids") are filtered.
CEHRGPT_COLUMNS = [
    "concept_ids",
    "concept_value_masks",
    "number_as_values",
    "concept_as_values",
    "is_numeric_types",
    "concept_values",
    "units",
    "epoch_times",
]
48
+
24
49
 
25
50
  def convert_date_to_posix_time(index_date: datetime.date) -> float:
26
51
  return datetime.datetime.combine(
@@ -28,11 +53,39 @@ def convert_date_to_posix_time(index_date: datetime.date) -> float:
28
53
  ).timestamp()
29
54
 
30
55
 
31
- class MedToCehrGPTDatasetMapping(DatasetMapping):
56
class DatasetMappingDecorator(DatasetMapping):
    """
    Base mapping that removes an unusable ``index_date`` column from a batch.

    ``batch_transform`` drops ``index_date`` whenever the column contains a
    null value and then delegates to ``DatasetMapping.batch_transform``.
    Subclasses must provide ``transform``.
    """

    def batch_transform(
        self, records: Union[LazyBatch, Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Drop index_date if it contains None, then run the parent batch transform.

        :param records: either a ``LazyBatch`` or a plain column -> values dict.
            A plain dict is modified in place when ``index_date`` is removed.
        :return: whatever ``DatasetMapping.batch_transform`` returns for the
            (possibly reduced) batch
        """
        if isinstance(records, LazyBatch):
            table = records.pa_table
            if (
                "index_date" in table.column_names
                and table.column("index_date").null_count > 0
            ):
                # Rebuild the batch around the reduced table, keeping the formatter
                table = table.drop(["index_date"])
                records = LazyBatch(pa_table=table, formatter=records.formatter)
        elif "index_date" in records and any(
            pd.isna(index_date) for index_date in records["index_date"]
        ):
            # Check every value (not only the first one) so that this branch
            # agrees with the null_count check above and an empty batch does
            # not raise an IndexError.
            del records["index_date"]
        return super().batch_transform(records=records)

    def transform(self, record: Dict[str, Any]) -> Union[Dict[str, Any], Series]:
        """
        Transform a single record; must be overridden by subclasses.

        :raises NotImplementedError: always
        """
        # NotImplemented is a non-callable singleton; the exception type is
        # NotImplementedError.
        raise NotImplementedError("Must be implemented")
83
+
84
+
85
+ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
32
86
  def __init__(
33
87
  self,
34
88
  data_args: DataTrainingArguments,
35
- is_pretraining: bool = True,
36
89
  include_inpatient_hour_token: bool = True,
37
90
  ):
38
91
  self._time_token_function = get_att_function(data_args.att_function_type)
@@ -41,7 +94,6 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
41
94
  data_args.inpatient_att_function_type
42
95
  )
43
96
  self._include_demographic_prompt = data_args.include_demographic_prompt
44
- self._is_pretraining = is_pretraining
45
97
  self._include_inpatient_hour_token = include_inpatient_hour_token
46
98
 
47
99
  """
@@ -57,19 +109,13 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
57
109
  """
58
110
 
59
111
  def remove_columns(self):
60
- if self._is_pretraining:
61
- return ["visits", "birth_datetime", "index_date"]
62
- else:
63
- return [
64
- "visits",
65
- "birth_datetime",
66
- "visit_concept_ids",
67
- ]
112
+ return ["patient_id", "visits", "birth_datetime"]
68
113
 
69
114
  @staticmethod
70
115
  def _update_cehrgpt_record(
71
116
  cehrgpt_record: Dict[str, Any],
72
117
  code: str,
118
+ time: datetime.datetime,
73
119
  concept_value_mask: int = 0,
74
120
  number_as_value: float = 0.0,
75
121
  concept_as_value: str = "0",
@@ -82,6 +128,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
82
128
  cehrgpt_record["concept_as_values"].append(concept_as_value)
83
129
  cehrgpt_record["units"].append(unit)
84
130
  cehrgpt_record["is_numeric_types"].append(is_numeric_type)
131
+ cehrgpt_record["epoch_times"].append(time.timestamp())
85
132
 
86
133
  def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
87
134
  cehrgpt_record = {
@@ -92,45 +139,57 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
92
139
  "concept_as_values": [],
93
140
  "units": [],
94
141
  "is_numeric_types": [],
142
+ "epoch_times": [],
95
143
  }
96
144
  # Extract the demographic information
97
145
  birth_datetime = record["birth_datetime"]
98
146
  if isinstance(birth_datetime, pd.Timestamp):
99
147
  birth_datetime = birth_datetime.to_pydatetime()
100
148
  gender = record["gender"]
149
+ gender = GENDER_UNKNOWN_TOKEN if gender == UNKNOWN_VALUE else gender
101
150
  race = record["race"]
151
+ race = RACE_UNKNOWN_TOKEN if race == UNKNOWN_VALUE else race
152
+ visits = record["visits"]
153
+ # This indicates this is columnar format
154
+ if isinstance(visits, dict):
155
+ visits = sorted(self.convert_visit_columnar_to_python(visits))
156
+ else:
157
+ visits = sorted(visits, key=lambda _: get_value(_, "visit_start_datetime"))
102
158
 
103
159
  # Add the demographic tokens
104
- first_visit = record["visits"][0]
105
- year_str = f'year:{str(first_visit["visit_start_datetime"].year)}'
106
- age_str = f'age:{str(relativedelta(first_visit["visit_start_datetime"], birth_datetime).years)}'
107
- self._update_cehrgpt_record(cehrgpt_record, year_str)
108
- self._update_cehrgpt_record(cehrgpt_record, age_str)
109
- self._update_cehrgpt_record(cehrgpt_record, gender)
110
- self._update_cehrgpt_record(cehrgpt_record, race)
160
+ first_visit = visits[0]
161
+ first_visit_start_datetime: datetime.datetime = get_value(
162
+ first_visit, "visit_start_datetime"
163
+ )
164
+ year_str = f"year:{str(first_visit_start_datetime.year)}"
165
+ age_str = f"age:{str(relativedelta(first_visit_start_datetime, birth_datetime).years)}"
166
+ self._update_cehrgpt_record(
167
+ cehrgpt_record, year_str, first_visit_start_datetime
168
+ )
169
+ self._update_cehrgpt_record(cehrgpt_record, age_str, first_visit_start_datetime)
170
+ self._update_cehrgpt_record(cehrgpt_record, gender, first_visit_start_datetime)
171
+ self._update_cehrgpt_record(cehrgpt_record, race, first_visit_start_datetime)
111
172
 
112
173
  # Use a data cursor to keep track of time
113
- date_cursor = None
114
-
115
- # Loop through all the visits excluding the first event containing the demographics
116
- for i, visit in enumerate(
117
- sorted(record["visits"], key=lambda e: e["visit_start_datetime"])
118
- ):
119
-
120
- events = visit["events"]
121
-
122
- # Skip this visit if the number measurements in the event is zero
123
- if events is None or len(events) == 0:
174
+ datetime_cursor: Optional[datetime.datetime] = None
175
+ visit: VisitObject
176
+ # Loop through all the visits
177
+ for i, visit in enumerate(visits):
178
+ events: Generator[Event, None, None] = get_value(visit, "events")
179
+ has_events, events = has_events_and_get_events(events)
180
+ if not has_events:
124
181
  continue
125
182
 
126
- visit_start_datetime = visit["visit_start_datetime"]
127
- time_delta = (
128
- (visit_start_datetime - date_cursor).days if date_cursor else None
183
+ visit_start_datetime: datetime.datetime = get_value(
184
+ visit, "visit_start_datetime"
185
+ )
186
+ # If visit_end_datetime is populated for the inpatient visit, we update the datetime_cursor
187
+ visit_end_datetime: Optional[datetime.datetime] = get_value(
188
+ visit, "visit_end_datetime"
129
189
  )
130
- date_cursor = visit_start_datetime
131
190
 
132
191
  # We assume the first measurement to be the visit type of the current visit
133
- visit_type = visit["visit_type"]
192
+ visit_type = get_value(visit, "visit_type")
134
193
  is_er_or_inpatient = (
135
194
  visit_type in INPATIENT_VISIT_TYPES
136
195
  or visit_type in INPATIENT_VISIT_TYPE_CODES
@@ -138,36 +197,45 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
138
197
  )
139
198
 
140
199
  # Add artificial time tokens to the patient timeline if timedelta exists
141
- if time_delta is not None:
200
+ if datetime_cursor is not None:
201
+ time_delta = max((visit_start_datetime - datetime_cursor).days, 0)
142
202
  # This generates an artificial time token depending on the choice of the time token functions
143
203
  self._update_cehrgpt_record(
144
204
  cehrgpt_record,
145
205
  code=self._time_token_function(time_delta),
206
+ time=visit_start_datetime,
146
207
  )
147
208
 
148
- # Add the VS token to the patient timeline to mark the start of a visit
149
- relativedelta(visit["visit_start_datetime"], birth_datetime).years
150
- # Calculate the week number since the epoch time
151
- date = (
152
- visit["visit_start_datetime"]
153
- - datetime.datetime(year=1970, month=1, day=1)
154
- ).days // 7
155
-
209
+ datetime_cursor = visit_start_datetime
156
210
  # Add a [VS] token
157
211
  self._update_cehrgpt_record(
158
212
  cehrgpt_record,
159
213
  code="[VS]",
214
+ time=datetime_cursor,
160
215
  )
161
216
  # Add a visit type token
162
217
  self._update_cehrgpt_record(
163
218
  cehrgpt_record,
164
219
  code=visit_type,
220
+ time=datetime_cursor,
165
221
  )
222
+ # We need to insert an inpatient hour token right after the visit type, we calculate the hour interval
223
+ # with respect to the midnight of the day
224
+ if is_er_or_inpatient and self._include_inpatient_hour_token:
225
+ if datetime_cursor.hour > 0:
226
+ # This generates an artificial time token depending on the choice of the time token functions
227
+ self._update_cehrgpt_record(
228
+ cehrgpt_record,
229
+ code=f"i-H{datetime_cursor.hour}",
230
+ time=datetime_cursor,
231
+ )
232
+
166
233
  # Keep track of the existing outpatient events, we don't want to add them again
167
- existing_outpatient_events = list()
234
+ existing_duplicate_events = list()
168
235
  for e in events:
169
236
  # If the event doesn't have a time stamp, we skip it
170
- if not e["time"]:
237
+ event_time: datetime.datetime = e["time"]
238
+ if not event_time:
171
239
  continue
172
240
 
173
241
  # If numeric_value exists, this is a concept/value tuple, we indicate this using a concept_value_mask
@@ -178,47 +246,62 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
178
246
  concept_value_mask = int(
179
247
  numeric_value is not None or text_value is not None
180
248
  )
249
+ if numeric_value is None and text_value is not None:
250
+ if text_value.isnumeric():
251
+ numeric_value = float(text_value)
252
+
181
253
  is_numeric_type = int(numeric_value is not None)
182
254
  code = replace_escape_chars(e["code"])
183
255
 
256
+ # Create the event identity
257
+ event_identity = (
258
+ (event_time, code, text_value, unit)
259
+ if is_er_or_inpatient
260
+ else (event_time.date(), code, text_value, unit)
261
+ )
262
+
184
263
  # Add a medical token to the patient timeline
185
264
  # If this is an inpatient visit, we use the event time stamps to calculate age and date
186
265
  # because the patient can stay in the hospital for a period of time.
187
266
  if is_er_or_inpatient:
188
- # Calculate the week number since the epoch time
189
- date = (
190
- e["time"] - datetime.datetime(year=1970, month=1, day=1)
191
- ).days // 7
192
267
  # Calculate the time diff in days w.r.t the previous measurement
193
- meas_time_diff = (e["time"] - date_cursor).days
194
- # Update the date_cursor if the time diff between two neighboring measurements is greater than and
268
+ time_diff_days = (event_time - datetime_cursor).days
269
+ # Update the datetime_cursor if the time diff between two neighboring measurements is greater than and
195
270
  # equal to 1 day
196
- if meas_time_diff > 0:
197
- date_cursor = e["time"]
198
- if self._inpatient_time_token_function:
271
+ if self._inpatient_time_token_function and time_diff_days > 0:
272
+ # This generates an artificial time token depending on the choice of the time token functions
273
+ self._update_cehrgpt_record(
274
+ cehrgpt_record,
275
+ code=f"i-{self._inpatient_time_token_function(time_diff_days)}",
276
+ time=event_time,
277
+ )
278
+
279
+ if self._include_inpatient_hour_token:
280
+ # if the time difference in days is greater than 0, we calculate the hour interval
281
+ # with respect to the midnight of the day
282
+ time_diff_hours = (
283
+ event_time.hour
284
+ if time_diff_days > 0
285
+ else int(
286
+ (event_time - datetime_cursor).total_seconds() // 3600
287
+ )
288
+ )
289
+
290
+ if time_diff_hours > 0:
199
291
  # This generates an artificial time token depending on the choice of the time token functions
200
292
  self._update_cehrgpt_record(
201
293
  cehrgpt_record,
202
- code=f"i-{self._inpatient_time_token_function(meas_time_diff)}",
294
+ code=f"i-H{time_diff_hours}",
295
+ time=event_time,
203
296
  )
204
- else:
205
- # For outpatient visits, we use the visit time stamp to calculate age and time because we assume
206
- # the outpatient visits start and end on the same day.
207
- # We check whether the date/code/value combination already exists in the existing events
208
- # If they exist, we do not add them to the patient timeline for outpatient visits.
209
- if (
210
- date,
211
- code,
212
- numeric_value,
213
- text_value,
214
- concept_value_mask,
215
- numeric_value,
216
- ) in existing_outpatient_events:
217
- continue
297
+
298
+ if event_identity in existing_duplicate_events:
299
+ continue
218
300
 
219
301
  self._update_cehrgpt_record(
220
302
  cehrgpt_record,
221
303
  code=code,
304
+ time=event_time,
222
305
  concept_value_mask=concept_value_mask,
223
306
  unit=unit,
224
307
  number_as_value=numeric_value if numeric_value else 0.0,
@@ -227,43 +310,44 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
227
310
  ),
228
311
  is_numeric_type=is_numeric_type,
229
312
  )
230
- existing_outpatient_events.append(
231
- (
232
- date,
233
- code,
234
- numeric_value,
235
- text_value,
236
- concept_value_mask,
237
- numeric_value,
238
- )
239
- )
313
+ existing_duplicate_events.append(event_identity)
314
+ # we only want to update the time stamp when data_cursor is less than the event time
315
+ if datetime_cursor < event_time or datetime_cursor is None:
316
+ datetime_cursor = event_time
317
+ # We need to bound the datetime_cursor if the current visit is an admission type of visit
318
+ # as the associated events could be generated after the visits are complete
319
+ if is_er_or_inpatient and visit_end_datetime is not None:
320
+ datetime_cursor = min(datetime_cursor, visit_end_datetime)
240
321
 
241
322
  # For inpatient or ER visits, we want to discharge_facility to the end of the visit
242
323
  if is_er_or_inpatient:
243
- # If visit_end_datetime is populated for the inpatient visit, we update the date_cursor
244
- visit_end_datetime = visit.get("visit_end_datetime", None)
245
- if visit_end_datetime:
246
- date_cursor = visit_end_datetime
324
+ # If visit_end_datetime is populated for the inpatient visit, we update the datetime_cursor
325
+ if visit_end_datetime is not None:
326
+ datetime_cursor = visit_end_datetime
247
327
 
248
328
  if self._include_auxiliary_token:
249
329
  # Reuse the age and date calculated for the last event in the patient timeline for the discharge
250
330
  # facility event
251
- discharge_facility = (
252
- visit["discharge_facility"]
253
- if ("discharge_facility" in visit)
254
- and visit["discharge_facility"]
255
- else "0"
256
- )
257
-
331
+ discharge_facility = get_value(visit, "discharge_facility")
332
+ if not discharge_facility:
333
+ discharge_facility = DISCHARGE_UNKNOWN_TOKEN
334
+ else:
335
+ discharge_facility = (
336
+ DISCHARGE_UNKNOWN_TOKEN
337
+ if discharge_facility == UNKNOWN_VALUE
338
+ else discharge_facility
339
+ )
258
340
  self._update_cehrgpt_record(
259
341
  cehrgpt_record,
260
342
  code=discharge_facility,
343
+ time=datetime_cursor,
261
344
  )
262
345
 
263
346
  # Reuse the age and date calculated for the last event in the patient timeline
264
347
  self._update_cehrgpt_record(
265
348
  cehrgpt_record,
266
349
  code="[VE]",
350
+ time=datetime_cursor,
267
351
  )
268
352
 
269
353
  # Generate the orders of the concepts that the cehrbert dataset mapping function expects
@@ -273,17 +357,23 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
273
357
 
274
358
  # Add some count information for this sequence
275
359
  cehrgpt_record["num_of_concepts"] = len(cehrgpt_record["concept_ids"])
276
- cehrgpt_record["num_of_visits"] = len(record["visits"])
360
+ cehrgpt_record["num_of_visits"] = len(visits)
277
361
 
278
- if "label" in record:
362
+ if record.get("index_date", None) is not None:
363
+ cehrgpt_record["index_date"] = record["index_date"]
364
+ if record.get("label", None) is not None:
279
365
  cehrgpt_record["label"] = record["label"]
280
- if "age_at_index" in record:
366
+ if record.get("age_at_index", None) is not None:
281
367
  cehrgpt_record["age_at_index"] = record["age_at_index"]
282
368
 
369
+ assert len(cehrgpt_record["epoch_times"]) == len(
370
+ cehrgpt_record["concept_ids"]
371
+ ), "The number of time stamps must match with the number of concepts in the sequence"
372
+
283
373
  return cehrgpt_record
284
374
 
285
375
 
286
- class HFCehrGptTokenizationMapping(DatasetMapping):
376
+ class HFCehrGptTokenizationMapping(DatasetMappingDecorator):
287
377
  def __init__(
288
378
  self,
289
379
  concept_tokenizer: CehrGptTokenizer,
@@ -297,9 +387,46 @@ class HFCehrGptTokenizationMapping(DatasetMapping):
297
387
  "is_numeric_types",
298
388
  ]
299
389
 
390
def filter_out_invalid_tokens(self, record: Dict[str, Any]) -> Dict[str, Any]:
    """
    Remove the tokens that the tokenizer does not know from a record.

    Every ``"0"`` concept id is first replaced with ``"Unknown"``. Then all
    positions whose concept id is missing from the tokenizer vocabulary are
    removed from ``concept_ids`` and from every other column listed in
    ``CEHRGPT_COLUMNS`` that has the same length as ``concept_ids``.

    :param record: the record to clean; it is updated in place
    :return: the same ``record`` object
    """
    concept_ids = record["concept_ids"]
    seq_length = len(concept_ids)

    # We can't have "0" as a token in the tokenizer because it would break
    # tokenization for "Race/0", "Visit/0". This is a precaution.
    if isinstance(concept_ids, np.ndarray):
        is_zero = concept_ids == "0"
        if np.any(is_zero):
            # np.where allocates a result wide enough for "Unknown". Assigning
            # into the original array would truncate the replacement when the
            # array has a narrow fixed-width string dtype (e.g. "<U2"), and it
            # would also modify the caller's array.
            record["concept_ids"] = np.where(is_zero, "Unknown", concept_ids)
    elif "0" in concept_ids:
        record["concept_ids"] = ["Unknown" if x == "0" else x for x in concept_ids]

    # Only sequence columns aligned with concept_ids are filtered
    column_names = [
        name
        for name, values in record.items()
        if name in CEHRGPT_COLUMNS
        and isinstance(values, (list, np.ndarray))
        and len(values) == seq_length
    ]

    vocab = self._concept_tokenizer.get_vocab()
    valid_indices = [
        idx
        for idx, concept_id in enumerate(record["concept_ids"])
        if concept_id in vocab
    ]
    if len(valid_indices) != seq_length:
        for column in column_names:
            values = record[column]
            record[column] = [values[idx] for idx in valid_indices]
    return record
420
+
300
421
  def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
422
+ # Remove the tokens from patient sequences that do not exist in the tokenizer
423
+ record = self.filter_out_invalid_tokens(record)
301
424
  # If any concept has a value associated with it, we normalize the value
302
425
  record["input_ids"] = self._concept_tokenizer.encode(record["concept_ids"])
426
+ assert len(record["input_ids"]) == len(record["concept_ids"]), (
427
+ "The number of tokens must equal to the number of concepts\n"
428
+ f"decoded concept_ids: {self._concept_tokenizer.decode(record['input_ids'], skip_special_tokens=False)}"
429
+ )
303
430
  record["value_indicators"] = record["concept_value_masks"]
304
431
  if "number_as_values" not in record or "concept_as_values" not in record:
305
432
  record["number_as_values"] = [
@@ -380,3 +507,89 @@ class HFFineTuningMapping(HFCehrGptTokenizationMapping):
380
507
  columns = super().remove_columns()
381
508
  columns.append("label")
382
509
  return columns
510
+
511
+
512
class ExtractTokenizedSequenceDataMapping:
    """
    Split one tokenized patient sequence into one sample per prediction time.

    For every entry registered for a person in ``person_index_date_map``, the
    tokens whose ``epoch_times`` fall inside ``[window start, index date]`` are
    extracted into a separate sample.

    :param person_index_date_map: maps a person id to a list of dicts; each
        dict must provide ``"index_date"`` (an object with a ``timestamp()``
        method) and ``"label"``
    :param observation_window: look-back window in days; when it is 0 (or not
        positive) the window starts at epoch time 0
    """

    # Columns that transform() derives for every sample. Values stored under
    # these names in the incoming record are not copied, otherwise the column
    # would receive two entries per sample.
    _DERIVED_COLUMNS = ("classifier_label", "index_date", "age_at_index")

    def __init__(
        self,
        person_index_date_map: Dict[int, List[Dict[str, Any]]],
        observation_window: int = 0,
    ):
        self.person_index_date_map = person_index_date_map
        self.observation_window = observation_window

    def _calculate_prediction_start_time(self, prediction_time: float):
        """Return the start of the observation window (epoch seconds, >= 0)."""
        if self.observation_window and self.observation_window > 0:
            return max(prediction_time - self.observation_window * 24 * 3600, 0)
        return 0

    @staticmethod
    def _parse_start_age(record: Dict[str, Any]) -> int:
        """Parse the integer after ':' in the second concept id, or return -1."""
        try:
            return int(record["concept_ids"][1].split(":")[1])
        except (KeyError, IndexError, ValueError, AttributeError, TypeError):
            return -1

    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """
        Expand a single record into one sample per prediction time.

        :param record: a record with ``person_id``, ``epoch_times``,
            ``concept_ids`` and ``input_ids``; any other list/array column with
            the same length as ``epoch_times`` is sliced per sample, all
            remaining columns are copied to every sample
        :return: column name -> list with one entry per prediction time. The
            derived columns are ``classifier_label``, ``index_date`` (epoch
            seconds) and ``age_at_index`` (parsed from the second concept id,
            -1 when it cannot be parsed)
        :raises KeyError: if the person id is not in ``person_index_date_map``
        """
        prediction_times = self.person_index_date_map[record["person_id"]]
        index_timestamps = [
            prediction["index_date"].timestamp() for prediction in prediction_times
        ]
        labels = [prediction["label"] for prediction in prediction_times]

        # Boolean matrix of shape (number of samples, sequence length): entry
        # [s, i] is True when token i lies inside the window of sample s.
        window_starts = np.asarray(
            [self._calculate_prediction_start_time(t) for t in index_timestamps],
            dtype=float,
        )
        window_ends = np.asarray(index_timestamps, dtype=float)
        epoch_times = np.asarray(record["epoch_times"], dtype=float)
        observation_window_indices = (
            window_starts[:, None] <= epoch_times[None, :]
        ) & (epoch_times[None, :] <= window_ends[:, None])

        seq_length = len(record["epoch_times"])
        time_series_columns = ["concept_ids", "input_ids"]
        static_inputs = dict()
        for column, value in record.items():
            if column in ("concept_ids", "input_ids"):
                continue
            if column in self._DERIVED_COLUMNS:
                # Replaced by the per-sample values computed below
                continue
            if isinstance(value, (list, np.ndarray)) and len(value) == seq_length:
                time_series_columns.append(column)
            else:
                static_inputs[column] = value

        batched_samples = defaultdict(list)
        if not index_timestamps:
            return batched_samples

        # Both are identical for every sample, so compute them only once
        start_age = self._parse_start_age(record)
        series = {column: np.asarray(record[column]) for column in time_series_columns}

        for index_timestamp, label, window_mask in zip(
            index_timestamps, labels, observation_window_indices
        ):
            for column, value in static_inputs.items():
                batched_samples[column].append(value)
            batched_samples["classifier_label"].append(label)
            batched_samples["index_date"].append(index_timestamp)
            batched_samples["age_at_index"].append(start_age)
            for column in time_series_columns:
                batched_samples[column].append(series[column][window_mask])
        return batched_samples

    def batch_transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply ``transform`` to every row of a columnar batch.

        :param record: column name -> list of per-row values; the number of
            rows is taken from ``concept_ids``
        :return: column name -> concatenated samples of all rows
        """
        all_batched_record = defaultdict(list)
        # Read every column once instead of once per row
        columns = {column: record[column] for column in record.keys()}
        for i in range(len(columns["concept_ids"])):
            one_record = {column: values[i] for column, values in columns.items()}
            for k, v in self.transform(one_record).items():
                all_batched_record[k].extend(v)
        return all_batched_record