cehrgpt 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. cehrgpt/analysis/htn_treatment_pathway.py +546 -0
  2. cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
  3. cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
  4. cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
  5. cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
  6. cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
  7. cehrgpt/data/cehrgpt_data_processor.py +549 -0
  8. cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
  9. cehrgpt/data/hf_cehrgpt_dataset_collator.py +285 -652
  10. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +38 -5
  11. cehrgpt/generation/cehrgpt_conditional_generation.py +2 -0
  12. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +20 -12
  13. cehrgpt/generation/omop_converter_batch.py +11 -4
  14. cehrgpt/gpt_utils.py +73 -3
  15. cehrgpt/models/activations.py +27 -0
  16. cehrgpt/models/config.py +6 -2
  17. cehrgpt/models/gpt2.py +560 -0
  18. cehrgpt/models/hf_cehrgpt.py +183 -460
  19. cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
  20. cehrgpt/omop/ontology.py +154 -0
  21. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -78
  22. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
  23. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -34
  24. cehrgpt/runners/hyperparameter_search_util.py +180 -69
  25. cehrgpt/runners/sample_packing_trainer.py +11 -2
  26. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +8 -2
  27. cehrgpt-0.1.4.dist-info/METADATA +238 -0
  28. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/RECORD +32 -22
  29. cehrgpt-0.1.2.dist-info/METADATA +0 -209
  30. /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
  31. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/WHEEL +0 -0
  32. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/licenses/LICENSE +0 -0
  33. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,549 @@
1
+ import random
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import DatasetMapping
8
+ from transformers.utils import logging
9
+
10
+ from cehrgpt.gpt_utils import (
11
+ DEMOGRAPHIC_PROMPT_SIZE,
12
+ collect_demographic_prompts_at_visits,
13
+ construct_age_sequence,
14
+ construct_time_sequence,
15
+ extract_time_interval_in_days,
16
+ extract_time_interval_in_hours,
17
+ is_att_token,
18
+ is_clinical_event,
19
+ is_inpatient_att_token,
20
+ is_inpatient_hour_token,
21
+ random_slice_gpt_sequence,
22
+ )
23
+ from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
24
+
25
# Not referenced anywhere in the visible part of this module.
# NOTE(review): presumably an upper bound (in days) for time-to-event targets — confirm against callers.
TIME_TO_EVENT_MAX_TIME = 3650
# Inpatient ATT tokens whose interval (in days) exceeds this limit are mapped to the
# ignore value -100 by CehrGptDataProcessor._convert_time_to_event.
INPATIENT_STAY_DURATION_LIMIT = 30
# Module-level logger obtained from the transformers logging utilities.
LOG = logging.get_logger("transformers")
28
+
29
+
30
class CehrGptDataProcessor(DatasetMapping):
    """Dataset mapping that turns one patient record into model-ready arrays.

    For every example the processor optionally shuffles records that share a
    rank, truncates the sequence to ``max_length`` (strategy depends on
    ``pretraining`` / ``include_demographics``), appends an end or
    linear-probing token when requested, and optionally derives
    time-to-visit and MOTOR time-to-event labels.
    """

    def __init__(
        self,
        tokenizer: CehrGptTokenizer,
        max_length: int,
        shuffle_records: bool = False,
        include_values: bool = False,
        include_ttv_prediction: bool = False,
        include_motor_time_to_event: bool = False,
        motor_sampling_probability: float = 0.5,
        pretraining: bool = True,
        include_demographics: bool = False,
        add_linear_prob_token: bool = False,
    ):
        """Store the configuration and pre-compute vocabulary lookups.

        Args:
            tokenizer: Tokenizer used to encode/decode concepts and to look up
                special token ids and MOTOR codes.
            max_length: Maximum number of tokens kept per example (including
                the appended end / linear-probing token, if any).
            shuffle_records: Randomly reorder records that share a rank.
            include_values: Also slice ``value_indicators`` and ``values``.
            include_ttv_prediction: Also produce ``time_to_visits`` labels.
            include_motor_time_to_event: Also produce MOTOR labels.
            motor_sampling_probability: Probability of sampling a position as
                a MOTOR prediction position.
            pretraining: Use the pre-training truncation strategy and append
                the end token to sequences that fit.
            include_demographics: In fine-tuning, prepend the demographic
                prompt of the visit at which the kept history starts.
            add_linear_prob_token: Append the linear-probing token instead of
                the end token.

        Raises:
            ValueError: If both ``pretraining`` and ``add_linear_prob_token``
                are enabled.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.vs_token_id = tokenizer.vs_token_id
        self.ve_token_id = tokenizer.ve_token_id

        self.shuffle_records = shuffle_records
        self.include_values = include_values
        self.include_ttv_prediction = include_ttv_prediction
        self.pretraining = pretraining
        self.include_demographics = include_demographics
        self.add_linear_prob_token = add_linear_prob_token
        # Shared zero-length placeholder used when a concatenation part is absent.
        self.empty_array = np.asarray([])

        if self.pretraining and self.add_linear_prob_token:
            raise ValueError(
                "pretraining and add_linear_prob_token cannot be specify at the same time"
            )

        # Motor related codes
        self.include_motor_time_to_event = include_motor_time_to_event
        self.motor_sampling_probability = motor_sampling_probability
        # concept id -> MOTOR parent codes, filled lazily across examples.
        self.motor_code_cache: Dict[str, List[str]] = {}
        # Pre-compute vocab-wide token type mappings
        self._precompute_vocab_mappings()

    def _precompute_vocab_mappings(self):
        """Pre-compute token type mappings for entire vocabulary.

        Builds ``vocab_to_idx`` (token -> position in the vocabulary's
        iteration order, which is not necessarily the tokenizer's token id)
        and three arrays indexed by that position: whether the token is an
        ATT token, whether it is a clinical event, and the ATT interval in
        days (``-1`` when not applicable or not parseable).
        """
        LOG.info("Pre-computing vocabulary-wide token mappings...")

        vocab = self.tokenizer.get_vocab()
        self.vocab_to_idx = {token: idx for idx, token in enumerate(vocab.keys())}
        self.vocab_tokens = list(vocab.keys())

        # Pre-compute boolean arrays for token types
        n_vocab = len(self.vocab_tokens)
        self.is_att_token_array = np.zeros(n_vocab, dtype=bool)
        self.is_clinical_event_array = np.zeros(n_vocab, dtype=bool)
        self.time_intervals_array = np.full(n_vocab, -1, dtype=int)

        for i, token in enumerate(self.vocab_tokens):
            if is_att_token(token):
                self.is_att_token_array[i] = True
                try:
                    self.time_intervals_array[i] = extract_time_interval_in_days(token)
                except (ValueError, AttributeError):
                    # Keep the "unknown interval" marker for unparseable ATT tokens.
                    self.time_intervals_array[i] = -1

            if is_clinical_event(token):
                self.is_clinical_event_array[i] = True

        LOG.info(f"Processed {n_vocab} vocabulary tokens")

    @staticmethod
    def _convert_time_to_event(concept_ids):
        """Map each concept to a time-to-visit label in days.

        ATT tokens yield their interval in days, inpatient hour tokens yield
        hours / 24, and everything else (including inpatient ATT tokens longer
        than ``INPATIENT_STAY_DURATION_LIMIT`` and unparseable tokens) yields
        ``-100``.

        Returns:
            A list of floats with one entry per input concept.
        """

        def default_value(c):
            try:
                if is_att_token(c):
                    time_to_visit = extract_time_interval_in_days(c)
                    if (
                        is_inpatient_att_token(c)
                        and time_to_visit > INPATIENT_STAY_DURATION_LIMIT
                    ):
                        return -100
                    return time_to_visit
                elif is_inpatient_hour_token(c):
                    return extract_time_interval_in_hours(c) / 24
                return -100
            except ValueError:
                return -100

        return [float(default_value(_)) for _ in concept_ids]

    def random_sort(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Shuffle records that share the same rank.

        Elements are ordered by ``record_ranks``, ties are broken by a random
        number and then by the input id. ``value_indicators`` and ``values``
        are reordered alongside ``input_ids`` when ``include_values`` is set.
        The record is returned unchanged when it has no ``record_ranks`` or
        when that column is empty.
        """
        if "record_ranks" not in record:
            return record

        sorting_column = record["record_ranks"]
        if len(sorting_column) == 0:
            # Nothing to sort; unpacking zip(*[]) below would raise.
            return record
        random_order = np.random.rand(len(sorting_column))

        if self.include_values:
            iterator = zip(
                sorting_column,
                random_order,
                record["input_ids"],
                record["value_indicators"],
                record["values"],
            )
            sorted_list = sorted(iterator, key=lambda tup2: (tup2[0], tup2[1], tup2[2]))
            _, _, sorted_input_ids, sorted_value_indicators, sorted_values = zip(
                *sorted_list
            )
            record["input_ids"] = sorted_input_ids
            record["value_indicators"] = sorted_value_indicators
            record["values"] = sorted_values
        else:
            iterator = zip(sorting_column, random_order, record["input_ids"])
            sorted_list = sorted(iterator, key=lambda tup2: (tup2[0], tup2[1], tup2[2]))
            _, _, sorted_input_ids = zip(*sorted_list)
            record["input_ids"] = sorted_input_ids
        return record

    def transform(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Process one example: shuffle, decode, truncate and label it.

        ``concept_ids`` is decoded from ``input_ids`` when absent, used
        internally, and removed from the returned example.
        """
        if self.shuffle_records:
            example = self.random_sort(example)

        if "concept_ids" not in example:
            input_ids = example["input_ids"]
            if isinstance(input_ids, torch.Tensor):
                input_ids = input_ids.detach().tolist()
            example["concept_ids"] = self.tokenizer.decode(
                input_ids, skip_special_tokens=False
            )
        # Forward-fill missing ages so every position carries a value.
        example["ages"] = pd.Series(example["ages"]).ffill().tolist()
        example = self.slice_out_input_sequence(example)
        # Add the motor labels
        if self.include_motor_time_to_event:
            motor_inputs = self.create_time_to_event_labels(example)
            example.update(motor_inputs)
        del example["concept_ids"]
        return example

    def update_inputs_based_on_indexes(
        self,
        record: Dict[str, Any],
        start_index,
        end_index,
        add_end_token: bool = False,
        demographic_tokens: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Slice every per-token column to ``[start_index:end_index]``.

        Each column is optionally prefixed with entries for the demographic
        prompt and optionally suffixed with an entry for the last token (the
        linear-probing token when ``add_linear_prob_token`` is set, otherwise
        the end token when ``add_end_token`` is true).

        Args:
            record: Example holding ``concept_ids``, ``input_ids``, ``ages``
                and ``epoch_times`` (plus ``value_indicators`` / ``values``
                when ``include_values`` is set).
            start_index: Inclusive start of the kept span (may be negative).
            end_index: Exclusive end of the kept span.
            add_end_token: Append the end token after the kept span.
            demographic_tokens: Demographic prompt to prepend, if any.

        Returns:
            The same ``record`` object with its columns replaced.
        """
        last_token_id = (
            self.tokenizer.linear_token_id
            if self.add_linear_prob_token
            else self.tokenizer.end_token_id
        )

        add_last_token = self.add_linear_prob_token or add_end_token

        # Take the concept slice BEFORE record["concept_ids"] is overwritten, so
        # that the time-to-visit labels below are derived from the same span as
        # every other column (re-slicing the rewritten list would misalign them).
        sliced_concept_ids = record["concept_ids"][start_index:end_index]

        # Slice out the concept ids
        record["concept_ids"] = (
            (demographic_tokens if demographic_tokens is not None else [])
            + sliced_concept_ids
            + (
                self.tokenizer.decode([last_token_id], skip_special_tokens=False)
                if add_last_token
                else []
            )
        )

        record["input_ids"] = np.concatenate(
            [
                (
                    np.asarray(self.tokenizer.encode(demographic_tokens))
                    if demographic_tokens is not None
                    else self.empty_array
                ),
                np.asarray(record["input_ids"][start_index:end_index]),
                (np.asarray([last_token_id]) if add_last_token else self.empty_array),
            ]
        ).astype(np.int32)

        # The prompt positions reuse the first age of the unsliced sequence and
        # the appended token reuses the last one.
        record["ages"] = np.concatenate(
            [
                (
                    np.full([DEMOGRAPHIC_PROMPT_SIZE], record["ages"][0])
                    if demographic_tokens is not None
                    else self.empty_array
                ),
                np.asarray(record["ages"][start_index:end_index]),
                (
                    np.asarray([record["ages"][-1]])
                    if add_last_token
                    else self.empty_array
                ),
            ]
        ).astype(np.int32)

        # For the new datasets, they contain the column "epoch_times"
        record["epoch_times"] = np.concatenate(
            [
                (
                    np.zeros([DEMOGRAPHIC_PROMPT_SIZE])
                    if demographic_tokens is not None
                    else self.empty_array
                ),
                np.asarray(record["epoch_times"][start_index:end_index]),
                (
                    np.asarray([record["epoch_times"][-1]])
                    if add_last_token
                    else self.empty_array
                ),
            ]
        ).astype(np.float32)

        if self.include_values:
            record["value_indicators"] = np.concatenate(
                [
                    (
                        np.zeros([DEMOGRAPHIC_PROMPT_SIZE])
                        if demographic_tokens is not None
                        else self.empty_array
                    ),
                    np.asarray(record["value_indicators"][start_index:end_index]),
                    np.asarray([False]) if add_last_token else self.empty_array,
                ]
            ).astype(np.bool_)
            record["values"] = np.concatenate(
                [
                    (
                        np.full(
                            [DEMOGRAPHIC_PROMPT_SIZE], self.tokenizer.pad_value_token_id
                        )
                        if demographic_tokens is not None
                        else self.empty_array
                    ),
                    np.asarray(record["values"][start_index:end_index]),
                    (
                        np.asarray([self.tokenizer.pad_value_token_id])
                        if add_last_token
                        else self.empty_array
                    ),
                ]
            ).astype(np.int32)

        if self.include_ttv_prediction:
            # -100 marks positions without a time-to-visit target.
            record["time_to_visits"] = np.concatenate(
                [
                    (
                        np.full([DEMOGRAPHIC_PROMPT_SIZE], -100.0)
                        if demographic_tokens is not None
                        else self.empty_array
                    ),
                    np.asarray(self._convert_time_to_event(sliced_concept_ids)),
                    np.asarray([-100.0]) if add_last_token else self.empty_array,
                ]
            ).astype(np.float32)

        return record

    def slice_out_input_sequence(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Truncate the patient sequence so that it fits ``max_length``.

        Sequences that already fit are kept whole (with the end token appended
        in pre-training). Otherwise:

        * pre-training: with 50% probability a random slice is taken;
          otherwise, or when the random slice is rejected, the sequence is
          right-truncated at the last visit-end token within the limit;
        * fine-tuning: the sequence is left-truncated, starting either at a
          visit whose demographic prompt can be prepended
          (``include_demographics``) or at the first visit-start token inside
          the last ``new_max_length`` tokens; if neither is found, the last
          ``new_max_length`` tokens are kept.
        """
        # Subtract one for the [END] or [LINEAR_PROB] token when sample_packing is not enabled
        new_max_length = (
            self.max_length - 1
            if self.add_linear_prob_token or self.pretraining
            else self.max_length
        )
        concept_ids = record["concept_ids"]
        seq_length = len(record["input_ids"])

        # For backward compatibility, in case these two columns do not already exist
        record["ages"] = construct_age_sequence(record["concept_ids"], record["ages"])
        record["epoch_times"] = construct_time_sequence(
            record["concept_ids"], record["epoch_times"]
        )

        # Return the record directly if the actual sequence length is less than the max sequence
        if seq_length <= new_max_length:
            # We only add [END] to the end of the sequence in pre-training
            record = self.update_inputs_based_on_indexes(
                record, 0, seq_length, add_end_token=self.pretraining
            )
            return record

        if self.pretraining:
            # There is a 50% chance we randomly slice out a portion of the patient history and update the demographic
            # prompt depending on the new starting point
            if random.random() < 0.5:
                # NOTE(review): the demographic tokens returned by the helper are
                # not prepended to the slice — confirm whether that is intended.
                random_start, random_end, _demographic_tokens = (
                    random_slice_gpt_sequence(concept_ids, new_max_length)
                )
                if random_start != random_end:
                    record = self.update_inputs_based_on_indexes(
                        record, random_start, random_end + 1, add_end_token=False
                    )
                    return record

            # The default employs a right truncation strategy, where the demographic prompt is reserved.
            # The limit is set here, after the random attempt, so that a rejected
            # random slice cannot shrink the fallback window.
            end_index = new_max_length
            for i in reversed(range(end_index)):
                current_token = record["input_ids"][i]
                if current_token == self.ve_token_id:
                    # Plus one because slicing is right exclusive
                    end_index = i + 1
                    break

            record = self.update_inputs_based_on_indexes(
                record=record, start_index=0, end_index=end_index, add_end_token=False
            )
            return record
        else:
            if self.include_demographics:
                # We employ a left truncation strategy, where the most recent patient history is reserved for fine-tuning
                demographic_prompts_at_visits = collect_demographic_prompts_at_visits(
                    concept_ids
                )
                for token_index, demographic_prompt in demographic_prompts_at_visits:
                    if (
                        seq_length - token_index
                        <= new_max_length - DEMOGRAPHIC_PROMPT_SIZE
                    ):
                        return self.update_inputs_based_on_indexes(
                            record=record,
                            start_index=token_index,
                            end_index=seq_length,
                            add_end_token=False,
                            demographic_tokens=demographic_prompt,
                        )
            else:
                start_index = seq_length - new_max_length
                end_index = seq_length
                for i in range(start_index, end_index):
                    current_token = record["input_ids"][i]
                    if current_token == self.vs_token_id:
                        return self.update_inputs_based_on_indexes(
                            record=record,
                            start_index=i,
                            end_index=end_index,
                            add_end_token=False,
                        )

            # This could happen when the last visit contains more than new_max_length number of tokens
            # We simply take the last new_max_length number of tokens from the patient sequence
            if len(record["input_ids"]) > new_max_length:
                record = self.update_inputs_based_on_indexes(
                    record=record,
                    start_index=-new_max_length,
                    end_index=seq_length,
                    add_end_token=False,
                )
            return record

    def create_time_to_event_labels(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Generate sparse MOTOR time-to-event labels for sampled positions.

        Prediction positions are the union of (a) non-ATT positions sampled
        with probability ``motor_sampling_probability`` (the first four
        positions are never sampled) and (b) positions immediately before an
        ATT token with a positive interval; positions sharing the last time
        stamp of the sequence are excluded. Prediction positions are then
        walked from last to first while a running map of MOTOR code ->
        (event time, motor token id) is maintained, so that each position
        sees the most recently recorded (i.e. earliest, given the backward
        walk) future occurrence of every code.

        Args:
            record: Example containing ``concept_ids`` and ``epoch_times`` of
                equal length. ``epoch_times`` is forced to be non-decreasing
                before use.

        Returns:
            A dict with the following keys:
                - "motor_censor_times": for each prediction position, the last
                  event time minus the time at that position.
                - "motor_row_indices" / "motor_col_indices" / "motor_values":
                  parallel lists describing a sparse matrix; the row is the
                  ordinal of the prediction position, the column is the motor
                  token id, and the value is the event time minus the time at
                  the prediction position.
                - "motor_tte_task_indicators": list of booleans, one per
                  concept, that is True at prediction positions.
        """
        concept_ids = record["concept_ids"]
        n_concepts = len(concept_ids)
        if n_concepts == 0:
            # Nothing to label; the vectorized lookups below need at least one token.
            return {
                "motor_censor_times": [],
                "motor_row_indices": [],
                "motor_col_indices": [],
                "motor_values": [],
                "motor_tte_task_indicators": [],
            }

        # Convert concept_ids to indices for vectorized operations
        concept_indices = np.array([self.vocab_to_idx[cid] for cid in concept_ids])
        # Vectorized token type detection
        is_att_tokens = self.is_att_token_array[concept_indices]
        is_clinical_events = self.is_clinical_event_array[concept_indices]
        time_intervals = self.time_intervals_array[concept_indices]

        # Find valid time tokens (att tokens with positive intervals)
        valid_time_tokens = is_att_tokens & (time_intervals > 0)

        # We need to make sure event_times is monotonic
        event_times = np.zeros(n_concepts, dtype=float)
        previous_time_stamp = record["epoch_times"][0]
        for i, time_stamp in enumerate(record["epoch_times"]):
            if time_stamp < previous_time_stamp:
                time_stamp = previous_time_stamp
            else:
                previous_time_stamp = time_stamp
            event_times[i] = time_stamp

        # Determine prediction positions.
        # np.roll wraps around, but the last position is always removed by the
        # "last time stamp" filter below, so the wrapped element has no effect.
        before_valid_time_tokens = np.roll(valid_time_tokens, -1)
        # We randomly sample prediction positions with motor_sampling_probability
        prediction_positions = (
            np.random.rand(n_concepts) < self.motor_sampling_probability
        )
        # We don't predict at the att time tokens
        prediction_positions &= ~is_att_tokens
        # We disable TTE predictions using the demographics alone
        # NOTE(review): 4 presumably equals DEMOGRAPHIC_PROMPT_SIZE — confirm.
        prediction_positions[:4] = False
        # We take the union of the random prediction positions and the positions right before time token
        prediction_positions = prediction_positions | before_valid_time_tokens
        # We exclude the events that occur at the last time stamp
        prediction_positions &= event_times != event_times[-1]

        prediction_indices = np.where(prediction_positions)[0]
        if len(prediction_indices) == 0:
            return {
                "motor_censor_times": [],
                "motor_row_indices": [],
                "motor_col_indices": [],
                "motor_values": [],
                "motor_tte_task_indicators": [False] * n_concepts,
            }

        # Pre-compute all motor codes for clinical events to avoid repeated lookups
        clinical_positions = np.where(is_clinical_events)[0]
        motor_codes_cache = {}  # position -> list of (motor_code, motor_token_id)

        for pos in clinical_positions:
            concept_id = concept_ids[pos]
            if concept_id in self.motor_code_cache:
                motor_codes = self.motor_code_cache[concept_id]
            else:
                motor_codes = self.tokenizer.get_motor_parents(concept_id)
                self.motor_code_cache[concept_id] = motor_codes

            if motor_codes:
                motor_codes_cache[pos] = [
                    (motor_code, self.tokenizer.get_motor_token_id(motor_code))
                    for motor_code in motor_codes
                ]

        # Process sections in REVERSE order but build results in FORWARD order
        section_boundaries = np.concatenate([prediction_indices, [n_concepts]])
        last_event_time = event_times[-1]

        # Pre-allocate arrays with exact size needed
        num_prediction_positions = len(prediction_indices)
        motor_censor_times = np.zeros(num_prediction_positions, dtype=float)
        motor_tte_task_indicators = np.zeros(n_concepts, dtype=bool)

        # Store sparse matrix data grouped by row for efficient construction
        sparse_data_by_row = {}  # row_idx -> [(col_idx, value), ...]

        # Global motor event state that accumulates as we go backwards
        global_motor_events = {}  # motor_code -> (earliest_future_time, motor_token_id)

        # Process in reverse order but assign to forward row indices
        for i in range(len(prediction_indices) - 1, -1, -1):
            start_index = prediction_indices[i]
            end_index = section_boundaries[i + 1]
            current_event_time = event_times[start_index]

            # Add new motor events from this section to global state.
            # The section runs from the token after this prediction position up
            # to and including the next prediction position (or the sequence end).
            section_start = start_index + 1
            section_end = end_index + 1 if end_index < n_concepts else n_concepts

            # Process clinical events in this section (in reverse order within section)
            section_clinical_positions = clinical_positions[
                (clinical_positions >= section_start)
                & (clinical_positions < section_end)
            ]

            for pos in reversed(section_clinical_positions):
                if pos in motor_codes_cache:
                    concept_time = event_times[pos]
                    # NOTE(review): events sharing this position's time stamp are
                    # skipped and are not revisited for earlier prediction
                    # positions, although they may lie in those positions'
                    # future — confirm this is intended.
                    if concept_time > current_event_time:
                        for motor_code, motor_token_id in motor_codes_cache[pos]:
                            global_motor_events[motor_code] = (
                                concept_time,
                                motor_token_id,
                            )

            # Store sparse matrix data for current prediction position
            # Even if global_motor_events is empty, we still need to record this position
            # because it indicates all motor tasks are censored at this time point
            sparse_data_by_row[i] = [
                (motor_token_id, motor_time - current_event_time)
                for motor_code, (
                    motor_time,
                    motor_token_id,
                ) in global_motor_events.items()
            ]
            motor_tte_task_indicators[start_index] = True
            motor_censor_times[i] = last_event_time - current_event_time

        # Build final sparse matrix lists in forward order (no reversals needed)
        motor_row_indices = []
        motor_col_indices = []
        motor_values = []

        for row_idx in sorted(sparse_data_by_row.keys()):
            for col_idx, value in sparse_data_by_row[row_idx]:
                motor_row_indices.append(row_idx)
                motor_col_indices.append(col_idx)
                motor_values.append(value)

        # Filter out unused positions from motor_censor_times
        motor_censor_times = [
            motor_censor_times[i] for i in sorted(sparse_data_by_row.keys())
        ]

        if len(motor_row_indices) == 0:
            LOG.debug(
                "No MOTOR tasks detected for this sample. "
                "Length: %s, last 10 concepts: %s",
                len(concept_ids),
                concept_ids[-10:] if len(concept_ids) >= 10 else concept_ids,
            )

        return {
            "motor_censor_times": motor_censor_times,
            "motor_row_indices": motor_row_indices,
            "motor_col_indices": motor_col_indices,
            "motor_values": motor_values,
            "motor_tte_task_indicators": motor_tte_task_indicators.tolist(),
        }
@@ -24,6 +24,10 @@ CEHRGPT_COLUMNS = [
24
24
  "values",
25
25
  "value_indicators",
26
26
  "epoch_times",
27
+ "ages",
28
+ "genders",
29
+ "races",
30
+ "position_ids",
27
31
  ]
28
32
 
29
33
  TRANSFORMER_COLUMNS = ["input_ids"]