cehrgpt 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/htn_treatment_pathway.py +546 -0
- cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
- cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
- cehrgpt/data/cehrgpt_data_processor.py +549 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +285 -652
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +38 -5
- cehrgpt/generation/cehrgpt_conditional_generation.py +2 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +20 -12
- cehrgpt/generation/omop_converter_batch.py +11 -4
- cehrgpt/gpt_utils.py +73 -3
- cehrgpt/models/activations.py +27 -0
- cehrgpt/models/config.py +6 -2
- cehrgpt/models/gpt2.py +560 -0
- cehrgpt/models/hf_cehrgpt.py +183 -460
- cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
- cehrgpt/omop/ontology.py +154 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -78
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -34
- cehrgpt/runners/hyperparameter_search_util.py +180 -69
- cehrgpt/runners/sample_packing_trainer.py +11 -2
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +8 -2
- cehrgpt-0.1.4.dist-info/METADATA +238 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/RECORD +32 -22
- cehrgpt-0.1.2.dist-info/METADATA +0 -209
- /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/WHEEL +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,549 @@
|
|
1
|
+
import random
|
2
|
+
from typing import Any, Dict, List, Optional
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
import pandas as pd
|
6
|
+
import torch
|
7
|
+
from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import DatasetMapping
|
8
|
+
from transformers.utils import logging
|
9
|
+
|
10
|
+
from cehrgpt.gpt_utils import (
|
11
|
+
DEMOGRAPHIC_PROMPT_SIZE,
|
12
|
+
collect_demographic_prompts_at_visits,
|
13
|
+
construct_age_sequence,
|
14
|
+
construct_time_sequence,
|
15
|
+
extract_time_interval_in_days,
|
16
|
+
extract_time_interval_in_hours,
|
17
|
+
is_att_token,
|
18
|
+
is_clinical_event,
|
19
|
+
is_inpatient_att_token,
|
20
|
+
is_inpatient_hour_token,
|
21
|
+
random_slice_gpt_sequence,
|
22
|
+
)
|
23
|
+
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
|
24
|
+
|
25
|
+
# Not referenced anywhere in this module; presumably a cap on time-to-event horizons
# (days?) consumed by importers — verify before removing.
TIME_TO_EVENT_MAX_TIME: int = 3_650
# Inpatient time-interval tokens whose interval (in days) exceeds this limit are given
# the -100 target when time-to-visit labels are built.
INPATIENT_STAY_DURATION_LIMIT: int = 30
# Module logger shared with the transformers logging hierarchy.
LOG = logging.get_logger("transformers")
|
28
|
+
|
29
|
+
|
30
|
+
class CehrGptDataProcessor(DatasetMapping):
    """Prepare tokenized CEHR-GPT patient sequences for training.

    Each call to :meth:`transform` optionally shuffles tokens that share a record rank,
    decodes ``input_ids`` back to ``concept_ids`` when that column is missing, truncates
    or slices the sequence to the configured length budget (appending ``[END]`` or
    ``[LINEAR_PROB]`` where configured), and optionally attaches sparse MOTOR
    time-to-event labels.

    Args:
        tokenizer: Tokenizer used to encode/decode tokens and to look up MOTOR codes.
        max_length: Length budget for an output sequence. One slot is reserved for the
            trailing ``[END]``/``[LINEAR_PROB]`` token when ``pretraining`` or
            ``add_linear_prob_token`` is set.
        shuffle_records: Randomly reorder tokens that share the same ``record_ranks`` value.
        include_values: Also slice the ``value_indicators`` and ``values`` columns.
        include_ttv_prediction: Emit ``time_to_visits`` regression targets.
        include_motor_time_to_event: Emit MOTOR time-to-event labels.
        motor_sampling_probability: Probability that a non time-interval position is
            randomly sampled as a MOTOR prediction position.
        pretraining: Use the pre-training truncation strategy (random window / right
            truncation, ``[END]`` appended to sequences that fit).
        include_demographics: During fine-tuning, prepend a demographic prompt when
            left-truncating.
        add_linear_prob_token: Append ``[LINEAR_PROB]`` (not allowed with ``pretraining``).

    Raises:
        ValueError: If both ``pretraining`` and ``add_linear_prob_token`` are set.
    """

    def __init__(
        self,
        tokenizer: CehrGptTokenizer,
        max_length: int,
        shuffle_records: bool = False,
        include_values: bool = False,
        include_ttv_prediction: bool = False,
        include_motor_time_to_event: bool = False,
        motor_sampling_probability: float = 0.5,
        pretraining: bool = True,
        include_demographics: bool = False,
        add_linear_prob_token: bool = False,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.vs_token_id = tokenizer.vs_token_id
        self.ve_token_id = tokenizer.ve_token_id

        self.shuffle_records = shuffle_records
        self.include_values = include_values
        self.include_ttv_prediction = include_ttv_prediction
        self.pretraining = pretraining
        self.include_demographics = include_demographics
        self.add_linear_prob_token = add_linear_prob_token
        # Shared empty array used as a no-op segment in np.concatenate calls
        self.empty_array = np.asarray([])

        if self.pretraining and self.add_linear_prob_token:
            raise ValueError(
                "pretraining and add_linear_prob_token cannot be specify at the same time"
            )

        # Motor related codes
        self.include_motor_time_to_event = include_motor_time_to_event
        self.motor_sampling_probability = motor_sampling_probability
        # concept_id -> MOTOR parent codes, memoised across transform() calls
        self.motor_code_cache: Dict[str, List[str]] = {}
        # Pre-compute vocab-wide token type mappings
        self._precompute_vocab_mappings()

    def _precompute_vocab_mappings(self):
        """Build per-vocabulary lookup arrays used by the vectorised MOTOR labelling.

        ``vocab_to_idx`` maps a token string to its position in ``tokenizer.get_vocab()``
        iteration order (a positional index into the arrays below, not the token id).
        ``is_att_token_array``, ``is_clinical_event_array`` and ``time_intervals_array``
        hold, per position, whether the token is a time-interval token, whether it is a
        clinical event, and its interval in days (-1 when not applicable/unparseable).
        """
        LOG.info("Pre-computing vocabulary-wide token mappings...")

        vocab = self.tokenizer.get_vocab()
        self.vocab_to_idx = {token: idx for idx, token in enumerate(vocab.keys())}
        self.vocab_tokens = list(vocab.keys())

        n_vocab = len(self.vocab_tokens)
        self.is_att_token_array = np.zeros(n_vocab, dtype=bool)
        self.is_clinical_event_array = np.zeros(n_vocab, dtype=bool)
        self.time_intervals_array = np.full(n_vocab, -1, dtype=int)

        for i, token in enumerate(self.vocab_tokens):
            if is_att_token(token):
                self.is_att_token_array[i] = True
                try:
                    self.time_intervals_array[i] = extract_time_interval_in_days(token)
                except (ValueError, AttributeError):
                    self.time_intervals_array[i] = -1

            if is_clinical_event(token):
                self.is_clinical_event_array[i] = True

        LOG.info(f"Processed {n_vocab} vocabulary tokens")

    @staticmethod
    def _convert_time_to_event(concept_ids):
        """Map each concept to a time-to-visit regression target (in days).

        * time-interval tokens -> their interval in days, except inpatient time-interval
          tokens longer than ``INPATIENT_STAY_DURATION_LIMIT`` -> -100;
        * inpatient hour tokens -> hours / 24;
        * any other token, or one whose interval raises ``ValueError`` -> -100.

        -100 presumably acts as an ignore label downstream — confirm with the loss.
        """

        def default_value(c):
            try:
                if is_att_token(c):
                    time_to_visit = extract_time_interval_in_days(c)
                    if (
                        is_inpatient_att_token(c)
                        and time_to_visit > INPATIENT_STAY_DURATION_LIMIT
                    ):
                        return -100
                    return time_to_visit
                elif is_inpatient_hour_token(c):
                    return extract_time_interval_in_hours(c) / 24
                return -100
            except ValueError:
                return -100

        return [float(default_value(_)) for _ in concept_ids]

    def random_sort(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Randomly reorder tokens that share the same ``record_ranks`` value.

        Tokens are sorted by ``(record_rank, random tiebreaker, input_id)``. Only
        ``input_ids`` (plus ``value_indicators``/``values`` when ``include_values``) are
        reordered, and they are stored back as tuples.
        NOTE(review): ``concept_ids``, ``ages`` and ``epoch_times`` are not reordered; if
        ``concept_ids`` is already present it will no longer line up with ``input_ids``.

        Records without a ``record_ranks`` column are returned unchanged.
        """
        if "record_ranks" not in record:
            return record

        sorting_column = record["record_ranks"]
        random_order = np.random.rand(len(sorting_column))

        if self.include_values:
            iterator = zip(
                sorting_column,
                random_order,
                record["input_ids"],
                record["value_indicators"],
                record["values"],
            )
            sorted_list = sorted(iterator, key=lambda tup2: (tup2[0], tup2[1], tup2[2]))
            _, _, sorted_input_ids, sorted_value_indicators, sorted_values = zip(
                *sorted_list
            )
            record["input_ids"] = sorted_input_ids
            record["value_indicators"] = sorted_value_indicators
            record["values"] = sorted_values
        else:
            iterator = zip(sorting_column, random_order, record["input_ids"])
            sorted_list = sorted(iterator, key=lambda tup2: (tup2[0], tup2[1], tup2[2]))
            _, _, sorted_input_ids = zip(*sorted_list)
            record["input_ids"] = sorted_input_ids
        return record

    def transform(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Run the preprocessing pipeline on one example; the dict is mutated and returned.

        Steps: optional shuffling, decoding ``concept_ids`` from ``input_ids`` when absent,
        forward-filling ``ages``, truncation via :meth:`slice_out_input_sequence`, optional
        MOTOR labels, and finally removal of the ``concept_ids`` column.
        """
        if self.shuffle_records:
            example = self.random_sort(example)

        if "concept_ids" not in example:
            input_ids = example["input_ids"]
            if isinstance(input_ids, torch.Tensor):
                input_ids = input_ids.detach().tolist()
            example["concept_ids"] = self.tokenizer.decode(
                input_ids, skip_special_tokens=False
            )
        # Missing ages inherit the previous known age
        example["ages"] = pd.Series(example["ages"]).ffill().tolist()
        example = self.slice_out_input_sequence(example)
        # Add the motor labels
        if self.include_motor_time_to_event:
            motor_inputs = self.create_time_to_event_labels(example)
            example.update(motor_inputs)
        # concept_ids was only needed by the steps above; drop it from the output
        del example["concept_ids"]
        return example

    def update_inputs_based_on_indexes(
        self,
        record: Dict[str, Any],
        start_index,
        end_index,
        add_end_token: bool = False,
        demographic_tokens: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Slice every per-token column of ``record`` to ``[start_index, end_index)``.

        Optionally prepends ``demographic_tokens`` and appends a trailing token —
        ``[LINEAR_PROB]`` when ``add_linear_prob_token`` is set, otherwise ``[END]`` when
        ``add_end_token`` is True. Auxiliary columns (``ages``, ``epoch_times``, and
        optionally ``value_indicators``, ``values``, ``time_to_visits``) are padded so they
        stay aligned with ``input_ids``.

        Args:
            record: Example dict; must contain ``concept_ids``, ``input_ids``, ``ages`` and
                ``epoch_times`` (plus ``value_indicators``/``values`` if ``include_values``).
            start_index: Inclusive slice start (negative values index from the end).
            end_index: Exclusive slice end.
            add_end_token: Append ``[END]`` (ignored in favour of ``[LINEAR_PROB]`` when
                ``add_linear_prob_token`` is set).
            demographic_tokens: Optional demographic prompt to prepend.

        Returns:
            The same ``record`` object, mutated in place.
        """
        last_token_id = (
            self.tokenizer.linear_token_id
            if self.add_linear_prob_token
            else self.tokenizer.end_token_id
        )
        add_last_token = self.add_linear_prob_token or add_end_token
        has_demographics = demographic_tokens is not None
        # Size the padding from the prompt actually supplied so all columns stay aligned
        num_demographic_tokens = len(demographic_tokens) if has_demographics else 0
        empty = self.empty_array

        # Slice the ORIGINAL columns before anything in `record` is overwritten. The
        # time-to-visit targets below must come from this slice, not from the rebuilt
        # record["concept_ids"] (which already contains the prompt / trailing token).
        sliced_concept_ids = list(record["concept_ids"][start_index:end_index])
        original_ages = record["ages"]
        original_epoch_times = record["epoch_times"]

        def prefix(fill_value) -> np.ndarray:
            # Alignment padding for the demographic prompt (empty when there is none)
            return (
                np.full([num_demographic_tokens], fill_value)
                if has_demographics
                else empty
            )

        def suffix(fill_value) -> np.ndarray:
            # Alignment padding for the trailing [END]/[LINEAR_PROB] token
            return np.asarray([fill_value]) if add_last_token else empty

        record["concept_ids"] = (
            (list(demographic_tokens) if has_demographics else [])
            + sliced_concept_ids
            + (
                self.tokenizer.decode([last_token_id], skip_special_tokens=False)
                if add_last_token
                else []
            )
        )

        record["input_ids"] = np.concatenate(
            [
                (
                    np.asarray(self.tokenizer.encode(demographic_tokens))
                    if has_demographics
                    else empty
                ),
                np.asarray(record["input_ids"][start_index:end_index]),
                suffix(last_token_id),
            ]
        ).astype(np.int32)

        # The prompt reuses the first age of the record; the trailing token reuses the last
        record["ages"] = np.concatenate(
            [
                (
                    np.full([num_demographic_tokens], original_ages[0])
                    if has_demographics
                    else empty
                ),
                np.asarray(original_ages[start_index:end_index]),
                np.asarray([original_ages[-1]]) if add_last_token else empty,
            ]
        ).astype(np.int32)

        # For the new datasets, they contain the column "epoch_times"
        record["epoch_times"] = np.concatenate(
            [
                prefix(0.0),
                np.asarray(original_epoch_times[start_index:end_index]),
                (
                    np.asarray([original_epoch_times[-1]])
                    if add_last_token
                    else empty
                ),
            ]
        ).astype(np.float32)

        if self.include_values:
            record["value_indicators"] = np.concatenate(
                [
                    prefix(False),
                    np.asarray(record["value_indicators"][start_index:end_index]),
                    suffix(False),
                ]
            ).astype(np.bool_)
            record["values"] = np.concatenate(
                [
                    prefix(self.tokenizer.pad_value_token_id),
                    np.asarray(record["values"][start_index:end_index]),
                    suffix(self.tokenizer.pad_value_token_id),
                ]
            ).astype(np.int32)

        if self.include_ttv_prediction:
            record["time_to_visits"] = np.concatenate(
                [
                    prefix(-100.0),
                    np.asarray(self._convert_time_to_event(sliced_concept_ids)),
                    suffix(-100.0),
                ]
            ).astype(np.float32)

        return record

    def slice_out_input_sequence(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Truncate ``record`` so that it fits within the configured length budget.

        The budget is ``max_length - 1`` when pretraining or ``add_linear_prob_token``
        (one slot is reserved for the trailing token), otherwise ``max_length``.
        Sequences that already fit are kept whole, with ``[END]`` appended during
        pretraining. Longer sequences are handled as follows:

        * pretraining: with probability 0.5 a window from ``random_slice_gpt_sequence``
          is used. Otherwise, or if that window is empty (start == end), the sequence is
          right-truncated after the last ``[VE]`` inside the budget.
        * fine-tuning with ``include_demographics``: left-truncate at the first visit,
          in the order returned by ``collect_demographic_prompts_at_visits``, whose
          remaining history plus a demographic prompt fits.
        * fine-tuning otherwise: left-truncate at the first ``[VS]`` within the last
          ``budget`` tokens.
        * fine-tuning fallback: keep the last ``budget`` tokens.

        ``ages`` and ``epoch_times`` are (re)built first via ``construct_age_sequence`` /
        ``construct_time_sequence``.
        """
        # Subtract one for the [END] or [LINEAR_PROB] token when sample_packing is not enabled
        new_max_length = (
            self.max_length - 1
            if self.add_linear_prob_token or self.pretraining
            else self.max_length
        )
        concept_ids = record["concept_ids"]
        seq_length = len(record["input_ids"])

        # For backward compatibility, in case these two columns do not already exist
        record["ages"] = construct_age_sequence(record["concept_ids"], record["ages"])
        record["epoch_times"] = construct_time_sequence(
            record["concept_ids"], record["epoch_times"]
        )

        # The whole sequence fits; [END] is only appended during pre-training
        if seq_length <= new_max_length:
            return self.update_inputs_based_on_indexes(
                record, 0, seq_length, add_end_token=self.pretraining
            )

        if self.pretraining:
            # With 50% probability, train on a random window of the patient history
            if random.random() < 0.5:
                slice_start, slice_end, _ = random_slice_gpt_sequence(
                    concept_ids, new_max_length
                )
                if slice_start != slice_end:
                    # NOTE(review): the demographic prompt returned by
                    # random_slice_gpt_sequence is discarded here, so the window is not
                    # prefixed with an updated prompt — confirm this is intended.
                    return self.update_inputs_based_on_indexes(
                        record, slice_start, slice_end + 1, add_end_token=False
                    )

            # Default: right truncation after the last [VE] inside the budget. This uses
            # its own end index so a rejected random window cannot shrink the cut-off.
            end_index = new_max_length
            for i in range(new_max_length - 1, -1, -1):
                if record["input_ids"][i] == self.ve_token_id:
                    # Plus one because slicing is right exclusive
                    end_index = i + 1
                    break
            return self.update_inputs_based_on_indexes(
                record=record, start_index=0, end_index=end_index, add_end_token=False
            )

        if self.include_demographics:
            # Left truncation: keep the most recent history and prepend the demographic
            # prompt that corresponds to the new starting visit
            demographic_prompts_at_visits = collect_demographic_prompts_at_visits(
                concept_ids
            )
            for token_index, demographic_prompt in demographic_prompts_at_visits:
                if (
                    seq_length - token_index
                    <= new_max_length - DEMOGRAPHIC_PROMPT_SIZE
                ):
                    return self.update_inputs_based_on_indexes(
                        record=record,
                        start_index=token_index,
                        end_index=seq_length,
                        add_end_token=False,
                        demographic_tokens=demographic_prompt,
                    )
        else:
            # Left truncation starting at the first [VS] inside the last budget tokens
            start_index = seq_length - new_max_length
            for i in range(start_index, seq_length):
                if record["input_ids"][i] == self.vs_token_id:
                    return self.update_inputs_based_on_indexes(
                        record=record,
                        start_index=i,
                        end_index=seq_length,
                        add_end_token=False,
                    )

        # This could happen when the last visit contains more than new_max_length number of tokens
        # We simply take the last new_max_length number of tokens from the patient sequence
        if len(record["input_ids"]) > new_max_length:
            record = self.update_inputs_based_on_indexes(
                record=record,
                start_index=-new_max_length,
                end_index=seq_length,
                add_end_token=False,
            )
        return record

    @staticmethod
    def _empty_motor_labels(n_concepts: int) -> Dict[str, Any]:
        """Return MOTOR labels for a sequence that has no prediction positions."""
        return {
            "motor_censor_times": [],
            "motor_row_indices": [],
            "motor_col_indices": [],
            "motor_values": [],
            "motor_tte_task_indicators": [False] * n_concepts,
        }

    def create_time_to_event_labels(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Generate sparse MOTOR time-to-event labels for sampled sequence positions.

        Prediction positions:

        * each non time-interval token is sampled with probability
          ``motor_sampling_probability``, except the first four positions;
        * every position immediately preceding a time-interval token with a positive
          interval is always included;
        * positions whose (monotonised) timestamp equals the final timestamp are dropped.

        For each prediction position, every MOTOR code (from ``tokenizer.get_motor_parents``)
        of a clinical event with a strictly later timestamp gets the time until that
        code's earliest strictly-later occurrence.

        Args:
            record: Must contain ``concept_ids`` and ``epoch_times``, aligned position by
                position.

        Returns:
            A dict with:

            - ``motor_censor_times``: one value per prediction position, in sequence
              order: ``last event time - position time``.
            - ``motor_row_indices`` / ``motor_col_indices`` / ``motor_values``: COO-style
              triplets. The row is the ordinal of the prediction position, the column is
              the MOTOR token id from ``tokenizer.get_motor_token_id``, and the value is
              the time until the earliest later occurrence of that code (in the units of
              ``epoch_times``).
            - ``motor_tte_task_indicators``: per-token booleans marking prediction
              positions.
        """
        concept_ids = record["concept_ids"]
        n_concepts = len(concept_ids)
        if n_concepts == 0:
            return self._empty_motor_labels(0)

        # Positional indices into the pre-computed vocabulary arrays.
        # NOTE(review): raises KeyError for concepts missing from tokenizer.get_vocab().
        concept_indices = np.array(
            [self.vocab_to_idx[cid] for cid in concept_ids], dtype=np.int64
        )
        is_att_tokens = self.is_att_token_array[concept_indices]
        is_clinical_events = self.is_clinical_event_array[concept_indices]
        time_intervals = self.time_intervals_array[concept_indices]

        # Time-interval tokens carrying a strictly positive number of days
        valid_time_tokens = is_att_tokens & (time_intervals > 0)

        # Make event times non-decreasing: a timestamp earlier than the most recently
        # accepted one is replaced by it
        event_times = np.zeros(n_concepts, dtype=float)
        previous_time_stamp = record["epoch_times"][0]
        for i, time_stamp in enumerate(record["epoch_times"]):
            if time_stamp < previous_time_stamp:
                time_stamp = previous_time_stamp
            else:
                previous_time_stamp = time_stamp
            event_times[i] = time_stamp

        # Positions right before a valid time token. np.roll wraps position 0 into the
        # last slot, but the last slot is removed by the final-timestamp filter below.
        before_valid_time_tokens = np.roll(valid_time_tokens, -1)
        # Randomly sample positions with probability motor_sampling_probability
        prediction_positions = (
            np.random.rand(n_concepts) < self.motor_sampling_probability
        )
        # We don't predict at the att time tokens
        prediction_positions &= ~is_att_tokens
        # We disable TTE predictions using the demographics alone
        # (NOTE(review): hard-coded 4; DEMOGRAPHIC_PROMPT_SIZE may be the intended constant)
        prediction_positions[:4] = False
        # Always predict right before a time token
        prediction_positions = prediction_positions | before_valid_time_tokens
        # Nothing can happen strictly after the last timestamp, so drop those positions
        prediction_positions &= event_times != event_times[-1]

        prediction_indices = np.where(prediction_positions)[0]
        if len(prediction_indices) == 0:
            return self._empty_motor_labels(n_concepts)

        # Resolve MOTOR codes (and their token ids) once per clinical-event position
        position_motor_codes: Dict[int, List[Any]] = {}
        for pos in np.where(is_clinical_events)[0]:
            concept_id = concept_ids[pos]
            if concept_id in self.motor_code_cache:
                motor_codes = self.motor_code_cache[concept_id]
            else:
                motor_codes = self.tokenizer.get_motor_parents(concept_id)
                self.motor_code_cache[concept_id] = motor_codes

            if motor_codes:
                position_motor_codes[pos] = [
                    (motor_code, self.tokenizer.get_motor_token_id(motor_code))
                    for motor_code in motor_codes
                ]
        motor_positions_desc = sorted(position_motor_codes, reverse=True)

        last_event_time = event_times[-1]
        num_prediction_positions = len(prediction_indices)
        motor_censor_times = np.zeros(num_prediction_positions, dtype=float)
        motor_tte_task_indicators = np.zeros(n_concepts, dtype=bool)
        row_entries: List[List[Any]] = [[] for _ in range(num_prediction_positions)]

        # motor_code -> (earliest event time, motor token id) over absorbed positions
        future_motor_events: Dict[Any, Any] = {}
        cursor = 0  # number of entries of motor_positions_desc absorbed so far

        # Walk prediction positions from last to first; their timestamps are non-increasing
        for row in range(num_prediction_positions - 1, -1, -1):
            position = prediction_indices[row]
            current_event_time = event_times[position]
            # First index whose time is strictly later than the current one. This relies
            # on event_times being sorted, which the loop above guarantees for NaN-free
            # epoch_times.
            first_future_index = np.searchsorted(
                event_times, current_event_time, side="right"
            )
            # Absorb every MOTOR event at or after first_future_index. Positions are
            # visited in descending order, so the last write per code is its earliest
            # occurrence. Events that share the current timestamp are only deferred until
            # an earlier prediction time is reached; they are not dropped for good.
            while (
                cursor < len(motor_positions_desc)
                and motor_positions_desc[cursor] >= first_future_index
            ):
                pos = motor_positions_desc[cursor]
                for motor_code, motor_token_id in position_motor_codes[pos]:
                    future_motor_events[motor_code] = (event_times[pos], motor_token_id)
                cursor += 1

            # Recorded even when empty: then every MOTOR task is censored at this position
            row_entries[row] = [
                (motor_token_id, motor_time - current_event_time)
                for motor_time, motor_token_id in future_motor_events.values()
            ]
            motor_tte_task_indicators[position] = True
            motor_censor_times[row] = last_event_time - current_event_time

        # Flatten to COO triplets in forward row order
        motor_row_indices: List[int] = []
        motor_col_indices: List[Any] = []
        motor_values: List[Any] = []
        for row, entries in enumerate(row_entries):
            for motor_token_id, value in entries:
                motor_row_indices.append(row)
                motor_col_indices.append(motor_token_id)
                motor_values.append(value)

        if len(motor_row_indices) == 0:
            LOG.debug(
                "No MOTOR tasks detected for this sample. "
                "Length: %s, last 10 concepts: %s",
                len(concept_ids),
                concept_ids[-10:] if len(concept_ids) >= 10 else concept_ids,
            )

        return {
            "motor_censor_times": motor_censor_times.tolist(),
            "motor_row_indices": motor_row_indices,
            "motor_col_indices": motor_col_indices,
            "motor_values": motor_values,
            "motor_tte_task_indicators": motor_tte_task_indicators.tolist(),
        }
|