cehrgpt 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/htn_treatment_pathway.py +546 -0
- cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
- cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
- cehrgpt/data/cehrgpt_data_processor.py +549 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +285 -652
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +38 -5
- cehrgpt/generation/cehrgpt_conditional_generation.py +2 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +20 -12
- cehrgpt/generation/omop_converter_batch.py +11 -4
- cehrgpt/gpt_utils.py +73 -3
- cehrgpt/models/activations.py +27 -0
- cehrgpt/models/config.py +6 -2
- cehrgpt/models/gpt2.py +560 -0
- cehrgpt/models/hf_cehrgpt.py +183 -460
- cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
- cehrgpt/omop/ontology.py +154 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -78
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -34
- cehrgpt/runners/hyperparameter_search_util.py +180 -69
- cehrgpt/runners/sample_packing_trainer.py +11 -2
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +8 -2
- cehrgpt-0.1.4.dist-info/METADATA +238 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/RECORD +32 -22
- cehrgpt-0.1.2.dist-info/METADATA +0 -209
- /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/WHEEL +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/top_level.txt +0 -0
@@ -1,19 +1,24 @@
|
|
1
1
|
import collections
|
2
2
|
import copy
|
3
3
|
import json
|
4
|
+
import math
|
4
5
|
import os
|
5
6
|
import pickle
|
6
7
|
from functools import partial
|
7
8
|
from itertools import islice
|
8
9
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
9
10
|
|
11
|
+
import femr
|
12
|
+
import femr.ontology
|
10
13
|
import numpy as np
|
11
14
|
import scipy.stats as stats
|
12
15
|
import transformers
|
13
16
|
from cehrbert.models.hf_models.tokenization_utils import agg_helper, load_json_file
|
14
17
|
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
|
18
|
+
from cehrbert_data.const.artificial_tokens import DEATH_TOKEN
|
15
19
|
from datasets import Dataset, DatasetDict, IterableDataset
|
16
20
|
from femr.stat_utils import OnlineStatistics, ReservoirSampler
|
21
|
+
from meds import death_code
|
17
22
|
from scipy.interpolate import UnivariateSpline
|
18
23
|
from tokenizers import AddedToken, Tokenizer
|
19
24
|
from tokenizers.models import WordLevel
|
@@ -38,6 +43,7 @@ from cehrgpt.models.special_tokens import (
|
|
38
43
|
PAD_TOKEN,
|
39
44
|
START_TOKEN,
|
40
45
|
)
|
46
|
+
from cehrgpt.omop.ontology import Ontology
|
41
47
|
|
42
48
|
NUM_OF_BINS = 10
|
43
49
|
DEGREE_OF_FREEDOM = 3
|
@@ -52,8 +58,10 @@ TOKEN_TO_SUB_TIME_TOKEN_MAPPING_FILE_NAME = "token_to_sub_time_token_mapping.jso
|
|
52
58
|
LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.pickle"
|
53
59
|
LEGACY_LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.json"
|
54
60
|
CONCEPT_STATS_FILE_NAME = "cehrgpt_concept_stats.json"
|
61
|
+
DEMOGRAPHICS_STATS_FILE_NAME = "demographics_stats.pickle"
|
55
62
|
CONCEPT_MAPPING_FILE_NAME = "concept_name_mapping.json"
|
56
|
-
|
63
|
+
MOTOR_TIME_TO_EVENT_TASK_INFO_FILE_NAME = "motor_time_to_event_info.pickle"
|
64
|
+
ONTOLOGY_FILE_NAME = "ontology.pickle"
|
57
65
|
|
58
66
|
LOG = logging.get_logger("transformers")
|
59
67
|
|
@@ -84,6 +92,23 @@ def get_dataset_len(dataset: Union[Dataset, IterableDataset]) -> int:
|
|
84
92
|
)
|
85
93
|
|
86
94
|
|
95
|
+
def get_allowed_motor_codes(
    original_concept_codes: List[str], ontology: Optional[Ontology]
) -> List[str]:
    """
    Select the concept codes that may be used as MOTOR time-to-event tasks.

    Codes rejected by ``is_clinical_event`` are always dropped. When an
    ontology is supplied, a remaining code is kept only if
    ``ontology.get_domain`` reports Condition, Procedure, Drug or Visit for it,
    or if the code is one of the death codes (``DEATH_TOKEN`` / ``death_code``).
    Without an ontology every code accepted by ``is_clinical_event`` is kept.

    Args:
        original_concept_codes: Candidate concept codes.
        ontology: Optional ontology used to look up the domain of each code.

    Returns:
        The allowed codes, in the order they appear in the input.
    """
    clinical_codes = [
        code for code in original_concept_codes if is_clinical_event(code)
    ]
    if not ontology:
        return clinical_codes

    motor_domains = ("Condition", "Procedure", "Drug", "Visit")
    death_codes = (DEATH_TOKEN, death_code)
    kept_codes = []
    for code in clinical_codes:
        code_domain = ontology.get_domain(code)
        belongs_to_motor_domain = bool(code_domain) and code_domain in motor_domains
        # Death codes are kept even when the ontology has no usable domain.
        if belongs_to_motor_domain or code in death_codes:
            kept_codes.append(code)
    return kept_codes
|
110
|
+
|
111
|
+
|
87
112
|
def create_sample_from_bins(bins, sample_size: int = 10_000) -> List[float]:
|
88
113
|
"""
|
89
114
|
Generates a specified number of samples from a list of bins, each containing a fitted spline.
|
@@ -199,6 +224,112 @@ def create_bins_with_spline(samples, num_bins, d_freedom=3) -> List[Dict[str, An
|
|
199
224
|
return bins
|
200
225
|
|
201
226
|
|
227
|
+
def map_motor_tte_statistics(
    batch: Dict[str, Any],
    allowed_motor_codes: List[str],
) -> Dict[str, Any]:
    """
    Collect MOTOR time-to-event statistics for one batch of patient sequences.

    Each sequence in ``batch["concept_ids"]`` is walked backwards. Tokens for
    which ``is_att_token`` is false are buffered as "concepts of the next
    future visit". When a token with a positive interval (as returned by
    ``extract_time_interval_in_days``) is reached, the time-to-event of every
    concept seen so far is updated, the values are sampled, and each allowed
    motor code is counted as either an event (a time-to-event exists) or
    censored (it does not).

    Args:
        batch: Batch dictionary; only ``batch["concept_ids"]`` is read.
        allowed_motor_codes: Codes for which event/censor counts are tracked.

    Returns:
        A dict with ``motor_event_times`` (a reservoir sampler holding up to
        100,000 time-to-event values), ``task_tte_stats`` and
        ``task_censor_stats`` (per-code counters).
    """
    motor_event_times = femr.stat_utils.ReservoirSampler(100_000)
    task_tte_stats: Dict[str, int] = collections.defaultdict(int)
    task_censor_stats: Dict[str, int] = collections.defaultdict(int)
    for concept_ids in batch["concept_ids"]:
        time_to_event_dict: Dict[str, int] = {}
        next_future_visit_concepts = set()
        # Reverse walk so that, at every time gap, we already know what happens later.
        for concept_id in reversed(concept_ids):
            if not is_att_token(concept_id):
                next_future_visit_concepts.add(concept_id)
                continue

            time_interval = extract_time_interval_in_days(concept_id)
            if not time_interval > 0:
                # Non-positive gaps do not move time; keep buffering concepts.
                continue

            # Concepts that only occur further in the future get pushed out by this gap.
            for existing_concept_id in time_to_event_dict:
                if existing_concept_id not in next_future_visit_concepts:
                    time_to_event_dict[existing_concept_id] += time_interval
            # Concepts in the immediately following visit are exactly one gap away,
            # whether or not they were already tracked.
            for next_concept_id in next_future_visit_concepts:
                time_to_event_dict[next_concept_id] = time_interval

            # Keep track of the time to event values
            for tte in time_to_event_dict.values():
                motor_event_times.add(tte, 1)

            for motor_code in allowed_motor_codes:
                if motor_code in time_to_event_dict:
                    task_tte_stats[motor_code] += 1
                else:
                    task_censor_stats[motor_code] += 1
            next_future_visit_concepts.clear()

    return {
        "motor_event_times": motor_event_times,
        "task_tte_stats": task_tte_stats,
        "task_censor_stats": task_censor_stats,
    }
|
276
|
+
|
277
|
+
|
278
|
+
def compute_motor_tte_statistics(
    dataset: Dataset,
    data_args: DataTrainingArguments,
    allowed_motor_codes: List[str],
    ontology: Optional[Ontology] = None,
) -> Dict[str, Any]:
    """
    Compute the MOTOR time-to-event statistics over a whole dataset.

    ``map_motor_tte_statistics`` is applied batch-wise through ``agg_helper``;
    each resulting row carries a pickled partial result in ``row["data"]``.
    The partial results are merged, and when an ontology is given the event and
    censor counts of every observed code are also added to its parents.

    Args:
        dataset: Dataset to scan (streaming or in-memory).
        data_args: Supplies ``streaming``, ``preprocessing_batch_size`` and
            ``preprocessing_num_workers``.
        allowed_motor_codes: Codes for which event/censor counts are tracked.
        ontology: Optional ontology used to roll counts up to parent codes.

    Returns:
        The merged statistics dict (``motor_event_times``, ``task_tte_stats``,
        ``task_censor_stats``), or ``None`` when the dataset yields no rows.
    """
    map_func = partial(
        agg_helper,
        map_func=partial(
            map_motor_tte_statistics,
            allowed_motor_codes=allowed_motor_codes,
        ),
    )
    map_kwargs: Dict[str, Any] = {
        "batched": True,
        "batch_size": data_args.preprocessing_batch_size,
    }
    if data_args.streaming:
        # The streaming path takes the column names from the first example.
        first_example = next(iter(dataset))
        map_kwargs["remove_columns"] = first_example.keys()
    else:
        map_kwargs.update(
            remove_columns=dataset.column_names,
            num_proc=data_args.preprocessing_num_workers,
            keep_in_memory=True,
            new_fingerprint="invalid",
        )
    parts = dataset.map(map_func, **map_kwargs)

    count_keys = ("task_tte_stats", "task_censor_stats")
    current = None
    for stat in tqdm(parts, desc="Aggregating the MOTOR TTE statistics"):
        # NOTE(review): payload is produced by agg_helper in this same pipeline;
        # never point this at externally supplied data, pickle is not safe for that.
        fixed_stat = pickle.loads(stat["data"])
        if current is None:
            current = fixed_stat
            continue
        current["motor_event_times"].combine(fixed_stat["motor_event_times"])
        for count_key in count_keys:
            for k, v in fixed_stat[count_key].items():
                current[count_key][k] += v

    # Aggregate the counts for the parent concepts
    if ontology is not None and current is not None:
        for count_key in count_keys:
            counts = current[count_key]
            # Roll up from a snapshot of the observed counts. Reading from the
            # dict being updated would let a code that is both observed and an
            # ancestor forward already-inflated totals, making the result depend
            # on iteration order.
            observed_counts = dict(counts)
            for code, count in observed_counts.items():
                for parent in ontology.get_all_parents(code):
                    if parent != code:
                        counts[parent] += count
    return current
|
331
|
+
|
332
|
+
|
202
333
|
def agg_statistics(stats1, stats2):
|
203
334
|
if stats1.get("numeric_stats_by_lab"):
|
204
335
|
for k, v in stats2["numeric_stats_by_lab"].items():
|
@@ -211,6 +342,10 @@ def agg_statistics(stats1, stats2):
|
|
211
342
|
if stats1.get("concept_code_stats"):
|
212
343
|
for concept_id, weight in stats2["concept_code_stats"].items():
|
213
344
|
stats1["concept_code_stats"][concept_id] += weight
|
345
|
+
if stats1.get("gender_list"):
|
346
|
+
stats1.get("gender_list").update(stats2.get("gender_list"))
|
347
|
+
if stats1.get("race_list"):
|
348
|
+
stats1.get("race_list").update(stats2.get("race_list"))
|
214
349
|
return stats1
|
215
350
|
|
216
351
|
|
@@ -239,6 +374,8 @@ def map_statistics(batch: Dict[str, Any], total_size, size=10_000) -> Dict[str,
|
|
239
374
|
numeric_stats_by_lab = collections.defaultdict(partial(ReservoirSampler, size=size))
|
240
375
|
categorical_stats_by_lab = collections.defaultdict(int)
|
241
376
|
concept_code_stats = collections.defaultdict(int)
|
377
|
+
gender_list = set()
|
378
|
+
race_list = set()
|
242
379
|
for (
|
243
380
|
concept_ids,
|
244
381
|
number_as_values,
|
@@ -252,6 +389,11 @@ def map_statistics(batch: Dict[str, Any], total_size, size=10_000) -> Dict[str,
|
|
252
389
|
batch["concept_value_masks"],
|
253
390
|
batch_value_units,
|
254
391
|
):
|
392
|
+
# Collecting demographics
|
393
|
+
gender, race = concept_ids[2:4]
|
394
|
+
gender_list.add(gender)
|
395
|
+
race_list.add(race)
|
396
|
+
|
255
397
|
unique_codes = set()
|
256
398
|
for (
|
257
399
|
concept_id,
|
@@ -280,11 +422,15 @@ def map_statistics(batch: Dict[str, Any], total_size, size=10_000) -> Dict[str,
|
|
280
422
|
"numeric_stats_by_lab": numeric_stats_by_lab,
|
281
423
|
"categorical_stats_by_lab": categorical_stats_by_lab,
|
282
424
|
"concept_code_stats": concept_code_stats,
|
425
|
+
"gender_list": gender_list,
|
426
|
+
"race_list": race_list,
|
283
427
|
}
|
284
428
|
|
285
429
|
|
286
430
|
def compute_statistics(
|
287
|
-
dataset: Dataset,
|
431
|
+
dataset: Dataset,
|
432
|
+
data_args: DataTrainingArguments,
|
433
|
+
ontology: Optional[Ontology] = None,
|
288
434
|
) -> Dict[str, Any]:
|
289
435
|
total = get_dataset_len(dataset)
|
290
436
|
map_statistics_partial = partial(map_statistics, total_size=total, size=SAMPLE_SIZE)
|
@@ -339,25 +485,44 @@ def compute_statistics(
|
|
339
485
|
].items():
|
340
486
|
categorical_lab_stats[(concept_id, value_as_concept)] += count
|
341
487
|
|
342
|
-
|
488
|
+
all_concept_code_stats = collections.defaultdict(float)
|
343
489
|
for concept_id, count in current["concept_code_stats"].items():
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
490
|
+
if ontology is not None:
|
491
|
+
parents = ontology.get_all_parents(concept_id)
|
492
|
+
for parent in parents:
|
493
|
+
all_concept_code_stats[parent] += count
|
494
|
+
else:
|
495
|
+
all_concept_code_stats[concept_id] += count
|
496
|
+
|
497
|
+
all_concept_code_entropies = collections.defaultdict(float)
|
498
|
+
for concept_id, weight in all_concept_code_stats.items():
|
499
|
+
baseline = (
|
500
|
+
min(
|
501
|
+
[1]
|
502
|
+
+ [
|
503
|
+
all_concept_code_stats[parent]
|
504
|
+
for parent in ontology.get_parents(concept_id)
|
505
|
+
]
|
506
|
+
)
|
507
|
+
if ontology is not None
|
508
|
+
else 1
|
509
|
+
)
|
510
|
+
weight = weight / baseline
|
511
|
+
weight = min(1.0, weight)
|
512
|
+
if weight != 0 and weight != 1:
|
513
|
+
weight = baseline * (
|
514
|
+
weight * math.log(weight) + (1 - weight) * math.log(1 - weight)
|
515
|
+
)
|
516
|
+
all_concept_code_entropies[concept_id] = weight
|
355
517
|
|
356
518
|
return {
|
357
519
|
"numeric_lab_stats": numeric_lab_stats,
|
358
520
|
"categorical_lab_stats": categorical_lab_stats,
|
359
|
-
"
|
360
|
-
"
|
521
|
+
"original_concept_codes": list(current["concept_code_stats"].keys()),
|
522
|
+
"all_concept_code_stats": all_concept_code_stats,
|
523
|
+
"all_concept_code_entropies": all_concept_code_entropies,
|
524
|
+
"gender_list": current["gender_list"],
|
525
|
+
"race_list": current["race_list"],
|
361
526
|
"total": total,
|
362
527
|
}
|
363
528
|
|
@@ -465,7 +630,10 @@ class CehrGptTokenizer(PreTrainedTokenizer):
|
|
465
630
|
categorical_lab_stats: Dict[Tuple[str, str], int],
|
466
631
|
concept_name_mapping: Dict[str, str],
|
467
632
|
pretrained_concept_embedding_model: PretrainedEmbeddings = None,
|
468
|
-
|
633
|
+
motor_task_info: Optional[Dict[str, Any]] = None,
|
634
|
+
gender_map: Optional[Dict[str, int]] = None,
|
635
|
+
race_map: Optional[Dict[str, int]] = None,
|
636
|
+
ontology: Optional[Ontology] = None,
|
469
637
|
):
|
470
638
|
self._tokenizer = tokenizer
|
471
639
|
self._value_tokenizer = value_tokenizer
|
@@ -485,6 +653,8 @@ class CehrGptTokenizer(PreTrainedTokenizer):
|
|
485
653
|
self._linear_token_id = None
|
486
654
|
if LINEAR_PROB_TOKEN in self._tokenizer.get_vocab():
|
487
655
|
self._linear_token_id = self._tokenizer.token_to_id(LINEAR_PROB_TOKEN)
|
656
|
+
else:
|
657
|
+
self._linear_token_id = self._oov_token_id
|
488
658
|
|
489
659
|
self._numeric_concept_ids = (
|
490
660
|
self._numeric_event_statistics.get_numeric_concept_ids()
|
@@ -503,13 +673,18 @@ class CehrGptTokenizer(PreTrainedTokenizer):
|
|
503
673
|
for _ in self.get_vocab().keys()
|
504
674
|
if self._pretrained_concept_embedding_model.is_concept_available(_)
|
505
675
|
]
|
506
|
-
self.
|
507
|
-
|
676
|
+
self._motor_task_info: Dict[str, Any] = (
|
677
|
+
motor_task_info if motor_task_info is not None else {}
|
678
|
+
)
|
679
|
+
self._motor_time_to_event_codes = self._motor_task_info.get(
|
680
|
+
"motor_time_to_event_codes", []
|
508
681
|
)
|
509
682
|
self._motor_code_to_id_mapping = {
|
510
683
|
code: i for i, code in enumerate(sorted(self._motor_time_to_event_codes))
|
511
684
|
}
|
512
|
-
|
685
|
+
self._gender_map = gender_map if gender_map else {}
|
686
|
+
self._race_map = race_map if race_map else {}
|
687
|
+
self._ontology = ontology
|
513
688
|
super().__init__()
|
514
689
|
|
515
690
|
@property
|
@@ -545,6 +720,16 @@ class CehrGptTokenizer(PreTrainedTokenizer):
|
|
545
720
|
def time_token_vocab_size(self) -> int:
|
546
721
|
return self._att_tokenizer.get_vocab_size()
|
547
722
|
|
723
|
+
@property
def gender_size(self) -> int:
    """Return the number of entries in the gender map plus one slot for unknown."""
    known_genders = len(self._gender_map)
    # The extra slot is reserved for the unknown gender
    return 1 + known_genders
|
727
|
+
|
728
|
+
@property
def race_size(self) -> int:
    """Return the number of entries in the race map plus one slot for unknown."""
    known_races = len(self._race_map)
    # The extra slot is reserved for the unknown race
    return 1 + known_races
|
732
|
+
|
548
733
|
@property
|
549
734
|
def pad_value_token_id(self):
|
550
735
|
return self._padding_value_token_id
|
@@ -591,6 +776,7 @@ class CehrGptTokenizer(PreTrainedTokenizer):
|
|
591
776
|
return LINEAR_PROB_TOKEN
|
592
777
|
return None
|
593
778
|
|
779
|
+
@property
|
594
780
|
def vs_token_id(self):
|
595
781
|
# We used VS for the historical data, currently, we use the new [VS] for the newer data
|
596
782
|
# so we need to check both cases.
|
@@ -658,8 +844,17 @@ class CehrGptTokenizer(PreTrainedTokenizer):
|
|
658
844
|
def pretrained_concept_embedding_model(self):
|
659
845
|
return self._pretrained_concept_embedding_model
|
660
846
|
|
847
|
+
def get_motor_time_bins(self, motor_num_time_pieces: int) -> List[int]:
    """
    Build the time-bin edges used for the MOTOR time-to-event pieces.

    The edges are evenly spaced percentiles (0 to 100) of the sampled
    time-to-event values stored under ``motor_event_times`` in the motor task
    info. The first edge is then forced to 0 and the last to infinity.

    Args:
        motor_num_time_pieces: Number of pieces; one more edge is produced.

    Returns:
        The bin edges as a list of ``motor_num_time_pieces + 1`` values.
    """
    sampled_times = self._motor_task_info["motor_event_times"].samples
    percentile_points = np.linspace(0, 100, motor_num_time_pieces + 1)
    edges = np.percentile(sampled_times, percentile_points)
    # Open the outer bins so every non-negative time falls into some piece.
    edges[0], edges[-1] = 0, float("inf")
    return list(edges)
|
855
|
+
|
661
856
|
def get_motor_token_id(self, concept_id: str) -> int:
    """
    Look up the MOTOR task index of a concept.

    Args:
        concept_id: The concept code to look up.

    Returns:
        The index stored for ``concept_id`` in the motor code-to-id mapping.

    Raises:
        RuntimeError: If ``concept_id`` is not a MOTOR time-to-event code.
    """
    if self.is_motor_time_to_event_code(concept_id):
        return self._motor_code_to_id_mapping[concept_id]
    raise RuntimeError(f"Invalid motor concept id: {concept_id}")
|
665
860
|
|
@@ -671,6 +866,21 @@ class CehrGptTokenizer(PreTrainedTokenizer):
|
|
671
866
|
return True
|
672
867
|
return False
|
673
868
|
|
869
|
+
def get_motor_parents(self, concept_id: str) -> List[str]:
    """
    Return the MOTOR time-to-event codes associated with a concept.

    With an ontology, the candidates are whatever
    ``ontology.get_all_parents(concept_id)`` returns. Without one, the only
    candidate is ``concept_id`` itself. Candidates that are not MOTOR
    time-to-event codes are dropped.

    Args:
        concept_id: The concept code to resolve.

    Returns:
        The candidates that are MOTOR time-to-event codes (possibly empty).
    """
    if self._ontology is not None:
        candidates = self._ontology.get_all_parents(concept_id)
    else:
        candidates = [concept_id]
    return [
        candidate
        for candidate in candidates
        if self.is_motor_time_to_event_code(candidate)
    ]
|
883
|
+
|
674
884
|
def get_vocab(self) -> Dict[str, int]:
|
675
885
|
return self._tokenizer.get_vocab()
|
676
886
|
|
@@ -699,6 +909,12 @@ class CehrGptTokenizer(PreTrainedTokenizer):
|
|
699
909
|
concept_value_token_ids, skip_special_tokens=skip_special_tokens
|
700
910
|
).split(" ")
|
701
911
|
|
912
|
+
def encode_gender(self, gender: str) -> int:
    """Return the id stored for ``gender`` in the gender map, or 0 if absent."""
    fallback_id = 0
    return self._gender_map.get(gender, fallback_id)
|
914
|
+
|
915
|
+
def encode_race(self, race: str) -> int:
    """Return the id stored for ``race`` in the race map, or 0 if absent."""
    fallback_id = 0
    return self._race_map.get(race, fallback_id)
|
917
|
+
|
702
918
|
def add_token(self, tokens: Union[str, List[str]]) -> None:
|
703
919
|
if isinstance(tokens, str):
|
704
920
|
tokens = [tokens]
|
@@ -782,13 +998,28 @@ class CehrGptTokenizer(PreTrainedTokenizer):
|
|
782
998
|
}
|
783
999
|
pickle.dump(lab_stats, f)
|
784
1000
|
|
1001
|
+
with open(
|
1002
|
+
os.path.join(save_directory, DEMOGRAPHICS_STATS_FILE_NAME), "wb"
|
1003
|
+
) as f:
|
1004
|
+
pickle.dump(
|
1005
|
+
{
|
1006
|
+
"gender_map": self._gender_map,
|
1007
|
+
"race_map": self._race_map,
|
1008
|
+
},
|
1009
|
+
f,
|
1010
|
+
)
|
1011
|
+
|
785
1012
|
with open(os.path.join(save_directory, CONCEPT_MAPPING_FILE_NAME), "w") as f:
|
786
1013
|
json.dump(self._concept_name_mapping, f)
|
787
1014
|
|
788
1015
|
with open(
|
789
|
-
os.path.join(save_directory,
|
1016
|
+
os.path.join(save_directory, MOTOR_TIME_TO_EVENT_TASK_INFO_FILE_NAME), "wb"
|
790
1017
|
) as f:
|
791
|
-
|
1018
|
+
pickle.dump(self._motor_task_info, f)
|
1019
|
+
|
1020
|
+
if self._ontology is not None:
|
1021
|
+
with open(os.path.join(save_directory, ONTOLOGY_FILE_NAME), "wb") as f:
|
1022
|
+
pickle.dump(self._ontology, f)
|
792
1023
|
|
793
1024
|
self._pretrained_concept_embedding_model.save(save_directory)
|
794
1025
|
|
@@ -890,6 +1121,23 @@ class CehrGptTokenizer(PreTrainedTokenizer):
|
|
890
1121
|
|
891
1122
|
with open(lab_stats_file, "rb") as file:
|
892
1123
|
lab_stats = pickle.load(file)
|
1124
|
+
try:
|
1125
|
+
# Load the demographics stats json file
|
1126
|
+
demographics_stats_file = transformers.utils.hub.cached_file(
|
1127
|
+
pretrained_model_name_or_path, DEMOGRAPHICS_STATS_FILE_NAME, **kwargs
|
1128
|
+
)
|
1129
|
+
if demographics_stats_file:
|
1130
|
+
with open(demographics_stats_file, "rb") as file:
|
1131
|
+
demographics_stats = pickle.load(file)
|
1132
|
+
else:
|
1133
|
+
demographics_stats = None
|
1134
|
+
except EnvironmentError:
|
1135
|
+
LOG.warning(
|
1136
|
+
f"The %s files does not exist in %s, setting demographics_stats to None",
|
1137
|
+
DEMOGRAPHICS_STATS_FILE_NAME,
|
1138
|
+
pretrained_model_name_or_path,
|
1139
|
+
)
|
1140
|
+
demographics_stats = None
|
893
1141
|
|
894
1142
|
# Load the concept_name json file
|
895
1143
|
concept_name_mapping_file = transformers.utils.hub.cached_file(
|
@@ -907,13 +1155,42 @@ class CehrGptTokenizer(PreTrainedTokenizer):
|
|
907
1155
|
return None
|
908
1156
|
concept_code_stats = load_json_file(concept_code_stats_mapping_file)
|
909
1157
|
|
910
|
-
|
911
|
-
|
912
|
-
|
913
|
-
|
914
|
-
|
915
|
-
|
916
|
-
|
1158
|
+
try:
|
1159
|
+
# Load the MOTOR time to event codes file
|
1160
|
+
motor_tte_task_info_file = transformers.utils.hub.cached_file(
|
1161
|
+
pretrained_model_name_or_path,
|
1162
|
+
MOTOR_TIME_TO_EVENT_TASK_INFO_FILE_NAME,
|
1163
|
+
**kwargs,
|
1164
|
+
)
|
1165
|
+
if motor_tte_task_info_file:
|
1166
|
+
with open(motor_tte_task_info_file, "rb") as file:
|
1167
|
+
motor_task_info = pickle.load(file)
|
1168
|
+
else:
|
1169
|
+
motor_task_info = None
|
1170
|
+
except EnvironmentError:
|
1171
|
+
LOG.warning(
|
1172
|
+
f"The %s files does not exist in %s, setting motor_task_info to None",
|
1173
|
+
MOTOR_TIME_TO_EVENT_TASK_INFO_FILE_NAME,
|
1174
|
+
pretrained_model_name_or_path,
|
1175
|
+
)
|
1176
|
+
motor_task_info = None
|
1177
|
+
|
1178
|
+
ontology = None
|
1179
|
+
try:
|
1180
|
+
ontology_file = transformers.utils.hub.cached_file(
|
1181
|
+
pretrained_model_name_or_path,
|
1182
|
+
ONTOLOGY_FILE_NAME,
|
1183
|
+
**kwargs,
|
1184
|
+
)
|
1185
|
+
if ontology_file:
|
1186
|
+
with open(ontology_file, "rb") as file:
|
1187
|
+
ontology = pickle.load(file)
|
1188
|
+
except EnvironmentError | OSError:
|
1189
|
+
LOG.warning(
|
1190
|
+
"The ontology file %s does not existing in %s",
|
1191
|
+
ONTOLOGY_FILE_NAME,
|
1192
|
+
pretrained_model_name_or_path,
|
1193
|
+
)
|
917
1194
|
|
918
1195
|
pretrained_embedding_model = PretrainedEmbeddings(pretrained_model_name_or_path)
|
919
1196
|
|
@@ -927,7 +1204,10 @@ class CehrGptTokenizer(PreTrainedTokenizer):
|
|
927
1204
|
lab_stats["categorical_lab_stats"],
|
928
1205
|
concept_name_mapping,
|
929
1206
|
pretrained_embedding_model,
|
930
|
-
|
1207
|
+
motor_task_info,
|
1208
|
+
demographics_stats["gender_map"] if demographics_stats else None,
|
1209
|
+
demographics_stats["race_map"] if demographics_stats else None,
|
1210
|
+
ontology=ontology,
|
931
1211
|
)
|
932
1212
|
|
933
1213
|
@classmethod
|
@@ -1048,6 +1328,18 @@ class CehrGptTokenizer(PreTrainedTokenizer):
|
|
1048
1328
|
motor_time_to_event_code
|
1049
1329
|
)
|
1050
1330
|
|
1331
|
+
for gender in new_tokenizer._gender_map.keys():
|
1332
|
+
if gender not in cehrgpt_tokenizer_copy._gender_map:
|
1333
|
+
cehrgpt_tokenizer_copy._gender_map[gender] = len(
|
1334
|
+
cehrgpt_tokenizer_copy._gender_map
|
1335
|
+
)
|
1336
|
+
|
1337
|
+
for race in new_tokenizer._race_map.keys():
|
1338
|
+
if race not in cehrgpt_tokenizer_copy._race_map:
|
1339
|
+
cehrgpt_tokenizer_copy._race_map[race] = len(
|
1340
|
+
cehrgpt_tokenizer_copy._race_map
|
1341
|
+
)
|
1342
|
+
|
1051
1343
|
return CehrGptTokenizer(
|
1052
1344
|
tokenizer=cehrgpt_tokenizer_copy._tokenizer,
|
1053
1345
|
value_tokenizer=cehrgpt_tokenizer_copy._value_tokenizer,
|
@@ -1058,7 +1350,10 @@ class CehrGptTokenizer(PreTrainedTokenizer):
|
|
1058
1350
|
categorical_lab_stats=cehrgpt_tokenizer_copy._categorical_lab_stats,
|
1059
1351
|
concept_name_mapping=cehrgpt_tokenizer_copy._concept_name_mapping,
|
1060
1352
|
pretrained_concept_embedding_model=pretrained_concept_embedding_model,
|
1061
|
-
|
1353
|
+
motor_task_info=cehrgpt_tokenizer_copy._motor_task_info,
|
1354
|
+
gender_map=cehrgpt_tokenizer_copy._gender_map,
|
1355
|
+
race_map=cehrgpt_tokenizer_copy._race_map,
|
1356
|
+
ontology=cehrgpt_tokenizer_copy._ontology,
|
1062
1357
|
)
|
1063
1358
|
|
1064
1359
|
@classmethod
|
@@ -1128,10 +1423,10 @@ class CehrGptTokenizer(PreTrainedTokenizer):
|
|
1128
1423
|
concept_name_mapping: Dict[str, str],
|
1129
1424
|
data_args: DataTrainingArguments,
|
1130
1425
|
pretrained_concept_embedding_model: PretrainedEmbeddings = None,
|
1131
|
-
allowed_motor_codes: Optional[List[int]] = None,
|
1132
1426
|
num_motor_tasks: Optional[int] = None,
|
1133
1427
|
apply_entropy_filter: bool = False,
|
1134
1428
|
min_prevalence: float = 1 / 1000,
|
1429
|
+
ontology: Optional[Ontology] = None,
|
1135
1430
|
):
|
1136
1431
|
"""
|
1137
1432
|
Train a huggingface word level tokenizer.
|
@@ -1144,12 +1439,14 @@ class CehrGptTokenizer(PreTrainedTokenizer):
|
|
1144
1439
|
dataset = dataset["train"]
|
1145
1440
|
|
1146
1441
|
LOG.info("Calculating data statistics")
|
1147
|
-
cehrgpt_data_statistics = compute_statistics(dataset, data_args)
|
1148
|
-
cehrgpt_data_statistics["total"]
|
1442
|
+
cehrgpt_data_statistics = compute_statistics(dataset, data_args, ontology)
|
1149
1443
|
numeric_lab_stats = cehrgpt_data_statistics["numeric_lab_stats"]
|
1150
1444
|
categorical_lab_stats = cehrgpt_data_statistics["categorical_lab_stats"]
|
1151
|
-
|
1152
|
-
|
1445
|
+
original_concept_codes = cehrgpt_data_statistics["original_concept_codes"]
|
1446
|
+
all_concept_code_stats = cehrgpt_data_statistics["all_concept_code_stats"]
|
1447
|
+
all_concept_code_entropies = cehrgpt_data_statistics[
|
1448
|
+
"all_concept_code_entropies"
|
1449
|
+
]
|
1153
1450
|
|
1154
1451
|
if apply_entropy_filter:
|
1155
1452
|
min_prevalence = max(1e-8, min_prevalence)
|
@@ -1159,15 +1456,15 @@ class CehrGptTokenizer(PreTrainedTokenizer):
|
|
1159
1456
|
)
|
1160
1457
|
qualified_codes = [
|
1161
1458
|
k
|
1162
|
-
for k
|
1163
|
-
if
|
1459
|
+
for k in original_concept_codes
|
1460
|
+
if all_concept_code_entropies[k] <= min_entropy
|
1164
1461
|
or not is_clinical_event(k, data_args.is_data_in_meds)
|
1165
1462
|
]
|
1166
1463
|
else:
|
1167
1464
|
qualified_codes = [
|
1168
1465
|
k
|
1169
|
-
for k
|
1170
|
-
if min_prevalence <=
|
1466
|
+
for k in original_concept_codes
|
1467
|
+
if min_prevalence <= all_concept_code_stats[k]
|
1171
1468
|
or not is_clinical_event(k, data_args.is_data_in_meds)
|
1172
1469
|
]
|
1173
1470
|
|
@@ -1240,36 +1537,69 @@ class CehrGptTokenizer(PreTrainedTokenizer):
|
|
1240
1537
|
if concept_id in concept_name_mapping
|
1241
1538
|
}
|
1242
1539
|
|
1243
|
-
|
1244
|
-
if num_motor_tasks
|
1540
|
+
motor_task_info = None
|
1541
|
+
if num_motor_tasks:
|
1542
|
+
LOG.info("Computing the MOTOR TTE statistics")
|
1543
|
+
allowed_motor_codes = get_allowed_motor_codes(
|
1544
|
+
original_concept_codes, ontology
|
1545
|
+
)
|
1546
|
+
motor_tte_statistics = compute_motor_tte_statistics(
|
1547
|
+
dataset, data_args, allowed_motor_codes, ontology
|
1548
|
+
)
|
1245
1549
|
motor_time_to_event_codes = []
|
1246
1550
|
for concept_id, _ in sorted(
|
1247
|
-
|
1551
|
+
all_concept_code_entropies.items(), key=lambda t: t[1]
|
1248
1552
|
):
|
1249
|
-
if
|
1250
|
-
|
1251
|
-
|
1252
|
-
|
1553
|
+
if concept_id not in allowed_motor_codes:
|
1554
|
+
continue
|
1555
|
+
tte_stats = motor_tte_statistics["task_tte_stats"][concept_id]
|
1556
|
+
censor_stats = motor_tte_statistics["task_censor_stats"][concept_id]
|
1557
|
+
frac_events = tte_stats / (tte_stats + censor_stats)
|
1558
|
+
|
1559
|
+
if frac_events < 1 / 1000:
|
1560
|
+
LOG.info(
|
1561
|
+
"Ran into very rare task %s with %s", concept_id, frac_events
|
1562
|
+
)
|
1253
1563
|
continue
|
1564
|
+
|
1254
1565
|
if len(motor_time_to_event_codes) < num_motor_tasks:
|
1255
1566
|
motor_time_to_event_codes.append(concept_id)
|
1256
1567
|
else:
|
1257
1568
|
break
|
1569
|
+
|
1258
1570
|
LOG.info(
|
1259
1571
|
f"{len(motor_time_to_event_codes)} number of tasks have been added as MOTOR tasks"
|
1260
1572
|
)
|
1573
|
+
motor_task_info = {
|
1574
|
+
"motor_event_times": motor_tte_statistics["motor_event_times"],
|
1575
|
+
"task_tte_stats": motor_tte_statistics["task_tte_stats"],
|
1576
|
+
"task_censor_stats": motor_tte_statistics["task_censor_stats"],
|
1577
|
+
"motor_time_to_event_codes": motor_time_to_event_codes,
|
1578
|
+
}
|
1579
|
+
|
1580
|
+
gender_map = {
|
1581
|
+
gender: i + 1
|
1582
|
+
for i, gender in enumerate(sorted(cehrgpt_data_statistics["gender_list"]))
|
1583
|
+
}
|
1584
|
+
race_map = {
|
1585
|
+
race: i + 1
|
1586
|
+
for i, race in enumerate(sorted(cehrgpt_data_statistics["race_list"]))
|
1587
|
+
}
|
1261
1588
|
|
1262
1589
|
return CehrGptTokenizer(
|
1263
1590
|
concept_tokenizer,
|
1264
1591
|
value_tokenizer,
|
1265
1592
|
att_tokenizer,
|
1266
1593
|
token_to_sub_time_token_mapping,
|
1267
|
-
|
1594
|
+
all_concept_code_stats,
|
1268
1595
|
numeric_lab_stats,
|
1269
1596
|
categorical_lab_stats,
|
1270
1597
|
concept_name_mapping,
|
1271
1598
|
pretrained_concept_embedding_model,
|
1272
|
-
|
1599
|
+
motor_task_info,
|
1600
|
+
gender_map,
|
1601
|
+
race_map,
|
1602
|
+
ontology,
|
1273
1603
|
)
|
1274
1604
|
|
1275
1605
|
@classmethod
|