cehrgpt 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +1 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +454 -68
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +232 -17
- cehrgpt/data/sample_packing_sampler.py +36 -6
- cehrgpt/generation/cehrgpt_conditional_generation.py +314 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +15 -3
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +25 -0
- cehrgpt/models/hf_cehrgpt.py +244 -39
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +354 -71
- cehrgpt/runners/data_utils.py +131 -5
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +84 -51
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +59 -7
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +60 -0
- cehrgpt/runners/hyperparameter_search_util.py +6 -7
- cehrgpt/runners/sample_packing_trainer.py +17 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +80 -62
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/METADATA +102 -7
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/RECORD +29 -26
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/WHEEL +1 -1
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/top_level.txt +0 -0
cehrgpt/models/tokenization_hf_cehrgpt.py

```diff
@@ -10,13 +10,9 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 import numpy as np
 import scipy.stats as stats
 import transformers
-from cehrbert.models.hf_models.tokenization_utils import (
-    agg_helper,
-    agg_statistics,
-    load_json_file,
-)
+from cehrbert.models.hf_models.tokenization_utils import agg_helper, load_json_file
 from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
-from datasets import Dataset, DatasetDict
+from datasets import Dataset, DatasetDict, IterableDataset
 from femr.stat_utils import OnlineStatistics, ReservoirSampler
 from scipy.interpolate import UnivariateSpline
 from tokenizers import AddedToken, Tokenizer
```
```diff
@@ -31,11 +27,13 @@ from cehrgpt.gpt_utils import (
     convert_time_interval_to_time_tuple,
     extract_time_interval_in_days,
     is_att_token,
+    is_clinical_event,
     is_inpatient_att_token,
 )
 from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
 from cehrgpt.models.special_tokens import (
     END_TOKEN,
+    LINEAR_PROB_TOKEN,
     OUT_OF_VOCABULARY_TOKEN,
     PAD_TOKEN,
     START_TOKEN,
```
```diff
@@ -53,7 +51,10 @@ TIME_TOKENIZER_FILE_NAME = "cehrgpt_time_tokenizer.json"
 TOKEN_TO_SUB_TIME_TOKEN_MAPPING_FILE_NAME = "token_to_sub_time_token_mapping.json"
 LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.pickle"
 LEGACY_LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.json"
+CONCEPT_STATS_FILE_NAME = "cehrgpt_concept_stats.json"
 CONCEPT_MAPPING_FILE_NAME = "concept_name_mapping.json"
+MOTOR_TIME_TO_EVENT_CODES_FILE_NAME = "motor_time_to_event_codes.json"
+
 LOG = logging.get_logger("transformers")
 
 
```
```diff
@@ -73,6 +74,16 @@ def create_value_bin(bin_index: int) -> str:
     return "BIN:" + str(bin_index)
 
 
+def get_dataset_len(dataset: Union[Dataset, IterableDataset]) -> int:
+    if isinstance(dataset, Dataset):
+        return len(dataset)
+    elif isinstance(dataset, IterableDataset):
+        return sum([1 for _ in dataset])
+    raise RuntimeError(
+        "The dataset must be one of the two types (Dataset, IterableDataset)"
+    )
+
+
 def create_sample_from_bins(bins, sample_size: int = 10_000) -> List[float]:
     """
     Generates a specified number of samples from a list of bins, each containing a fitted spline.
```
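Note the cost asymmetry in `get_dataset_len`: a regular `Dataset` knows its length, while an `IterableDataset` is exhausted once just to count its rows. A small sketch contrasting the two branches (assumes a recent `datasets` release that provides `to_iterable_dataset`):

```python
from datasets import Dataset

ds = Dataset.from_dict({"concept_ids": [["c1", "c2"], ["c3"]]})

# Dataset branch: O(1), the length is stored with the Arrow table.
print(len(ds))  # 2

# IterableDataset branch: O(n), the stream is consumed once just to count.
print(sum(1 for _ in ds.to_iterable_dataset()))  # 2
```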
```diff
@@ -188,7 +199,22 @@ def create_bins_with_spline(samples, num_bins, d_freedom=3) -> List[Dict[str, Any]]:
     return bins
 
 
-def map_statistics(batch: Dict[str, Any], size=10_000) -> Dict[str, Any]:
+def agg_statistics(stats1, stats2):
+    if stats1.get("numeric_stats_by_lab"):
+        for k, v in stats2["numeric_stats_by_lab"].items():
+            stats1["numeric_stats_by_lab"][k].combine(v)
+    if stats1.get("categorical_stats_by_lab"):
+        for (concept_id, concept_as_value), count in stats2[
+            "categorical_stats_by_lab"
+        ].items():
+            stats1["categorical_stats_by_lab"][(concept_id, concept_as_value)] += count
+    if stats1.get("concept_code_stats"):
+        for concept_id, weight in stats2["concept_code_stats"].items():
+            stats1["concept_code_stats"][concept_id] += weight
+    return stats1
+
+
+def map_statistics(batch: Dict[str, Any], total_size, size=10_000) -> Dict[str, Any]:
     if "units" in batch:
         batch_value_units = batch["units"]
     else:
```
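`agg_statistics` is now defined locally (it was previously imported from cehrbert, as the import hunk above shows) and additionally folds together the new `concept_code_stats` counters produced by each worker. A simplified sketch of just that merge step, with hypothetical concept ids and plain floats standing in for the real per-shard results:

```python
import collections

def merge_concept_code_stats(stats1, stats2):
    # Mirrors the third branch of agg_statistics: per-shard prevalence
    # fractions simply add, because map_statistics contributes
    # 1 / total_size once per patient record that contains the code.
    for concept_id, weight in stats2["concept_code_stats"].items():
        stats1["concept_code_stats"][concept_id] += weight
    return stats1

shard_a = {"concept_code_stats": collections.defaultdict(float, {"320128": 0.25})}
shard_b = {"concept_code_stats": collections.defaultdict(float, {"320128": 0.5, "201826": 0.125})}
print(dict(merge_concept_code_stats(shard_a, shard_b)["concept_code_stats"]))
# {'320128': 0.75, '201826': 0.125}
```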
```diff
@@ -212,6 +238,7 @@ def map_statistics(batch: Dict[str, Any], size=10_000) -> Dict[str, Any]:
 
     numeric_stats_by_lab = collections.defaultdict(partial(ReservoirSampler, size=size))
     categorical_stats_by_lab = collections.defaultdict(int)
+    concept_code_stats = collections.defaultdict(int)
     for (
         concept_ids,
         number_as_values,
```
```diff
@@ -225,6 +252,7 @@ def map_statistics(batch: Dict[str, Any], size=10_000) -> Dict[str, Any]:
         batch["concept_value_masks"],
         batch_value_units,
     ):
+        unique_codes = set()
         for (
             concept_id,
             number_as_value,
```
```diff
@@ -243,10 +271,94 @@ def map_statistics(batch: Dict[str, Any], size=10_000) -> Dict[str, Any]:
                 numeric_stats_by_lab[(concept_id, unit)].add(number_as_value, 1)
             if concept_as_value:
                 categorical_stats_by_lab[(concept_id, concept_as_value)] += 1
+            unique_codes.add(concept_id)
+
+        for code in unique_codes:
+            concept_code_stats[code] += 1 / total_size
 
     return {
         "numeric_stats_by_lab": numeric_stats_by_lab,
         "categorical_stats_by_lab": categorical_stats_by_lab,
+        "concept_code_stats": concept_code_stats,
+    }
+
+
+def compute_statistics(
+    dataset: Dataset, data_args: DataTrainingArguments
+) -> Dict[str, Any]:
+    total = get_dataset_len(dataset)
+    map_statistics_partial = partial(map_statistics, total_size=total, size=SAMPLE_SIZE)
+    if data_args.streaming:
+        first_example = next(iter(dataset))
+        parts = dataset.map(
+            partial(agg_helper, map_func=map_statistics_partial),
+            batched=True,
+            batch_size=data_args.preprocessing_batch_size,
+            remove_columns=first_example.keys(),
+        )
+    else:
+        parts = dataset.map(
+            partial(agg_helper, map_func=map_statistics_partial),
+            batched=True,
+            batch_size=data_args.preprocessing_batch_size,
+            remove_columns=dataset.column_names,
+            num_proc=data_args.preprocessing_num_workers,
+            keep_in_memory=True,
+            new_fingerprint="invalid",
+        )
+    current = None
+    for stat in tqdm(parts, desc="Aggregating the lab statistics"):
+        fixed_stat = pickle.loads(stat["data"])
+        if current is None:
+            current = fixed_stat
+        else:
+            current = agg_statistics(current, fixed_stat)
+
+    numeric_lab_stats = []
+    for (concept_id, unit), online_stats in current["numeric_stats_by_lab"].items():
+        if len(online_stats.samples) == 0:
+            continue
+        samples = truncated_sample(online_stats.samples, data_args.value_outlier_std)
+        bins = create_bins_with_spline(samples, NUM_OF_BINS, DEGREE_OF_FREEDOM)
+        if len(bins) > 0:
+            numeric_lab_stats.append(
+                {
+                    "concept_id": concept_id,
+                    "unit": unit,
+                    "mean": np.mean(samples),
+                    "std": np.std(samples),
+                    "count": len(online_stats.samples),
+                    "value_outlier_std": data_args.value_outlier_std,
+                    "bins": bins,
+                }
+            )
+
+    categorical_lab_stats = collections.defaultdict(int)
+    for (concept_id, value_as_concept), count in current[
+        "categorical_stats_by_lab"
+    ].items():
+        categorical_lab_stats[(concept_id, value_as_concept)] += count
+
+    concept_code_stats = collections.defaultdict(int)
+    for concept_id, count in current["concept_code_stats"].items():
+        concept_code_stats[concept_id] += count
+
+    code_weights = np.asarray(list(concept_code_stats.values())).clip(1e-8, 1 - 1e-8)
+    # Clip the values so we don't get errors when applying np.log
+    code_entropies = np.log(code_weights) * code_weights + (1 - code_weights) * np.log(
+        1 - code_weights
+    )
+
+    concept_code_entropies = {
+        k: v for k, v in zip(concept_code_stats.keys(), code_entropies)
+    }
+
+    return {
+        "numeric_lab_stats": numeric_lab_stats,
+        "categorical_lab_stats": categorical_lab_stats,
+        "concept_code_stats": concept_code_stats,
+        "concept_code_entropies": concept_code_entropies,
+        "total": total,
     }
 
 
```
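Because `map_statistics` adds `1 / total_size` once per patient record containing a code, the merged `concept_code_stats` values approximate each code's patient-level prevalence p, and `code_entropies` is the negative binary entropy p·log p + (1 − p)·log(1 − p). A standalone check of the formula (printed values are approximate):

```python
import numpy as np

def negative_binary_entropy(p) -> np.ndarray:
    # Same expression as compute_statistics; clipping keeps np.log away
    # from 0 and 1, where it would produce -inf or runtime warnings.
    p = np.asarray(p).clip(1e-8, 1 - 1e-8)
    return np.log(p) * p + (1 - p) * np.log(1 - p)

print(negative_binary_entropy([0.001, 0.5, 0.999]))
# ~ [-0.0079, -0.6931, -0.0079]: rare and near-universal codes sit near zero,
# while codes near 50% prevalence are the most negative ("most informative").
```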
```diff
@@ -348,15 +460,18 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         value_tokenizer: Tokenizer,
         att_tokenizer: Tokenizer,
         token_to_sub_time_token_mapping: Dict[str, List[str]],
+        concept_code_stats: Dict[str, Any],
         numeric_lab_stats: List[Dict[str, Any]],
         categorical_lab_stats: Dict[Tuple[str, str], int],
         concept_name_mapping: Dict[str, str],
         pretrained_concept_embedding_model: PretrainedEmbeddings = None,
+        motor_time_to_event_codes: Optional[List[str]] = None,
     ):
         self._tokenizer = tokenizer
         self._value_tokenizer = value_tokenizer
         self._att_tokenizer = att_tokenizer
         self._token_to_sub_time_token_mapping = token_to_sub_time_token_mapping
+        self._concept_code_stats = concept_code_stats
         self._numeric_lab_stats = numeric_lab_stats
         self._numeric_event_statistics = NumericEventStatistics(numeric_lab_stats)
         self._categorical_lab_stats = categorical_lab_stats
```
```diff
@@ -365,6 +480,12 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         self._padding_token_id = self._tokenizer.token_to_id(PAD_TOKEN)
         self._start_token_id = self._tokenizer.token_to_id(START_TOKEN)
         self._end_token_id = self._tokenizer.token_to_id(END_TOKEN)
+
+        # Backward compatible with the old tokenizer
+        self._linear_token_id = None
+        if LINEAR_PROB_TOKEN in self._tokenizer.get_vocab():
+            self._linear_token_id = self._tokenizer.token_to_id(LINEAR_PROB_TOKEN)
+
         self._numeric_concept_ids = (
             self._numeric_event_statistics.get_numeric_concept_ids()
         )
```
```diff
@@ -382,6 +503,12 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             for _ in self.get_vocab().keys()
             if self._pretrained_concept_embedding_model.is_concept_available(_)
         ]
+        self._motor_time_to_event_codes = (
+            motor_time_to_event_codes if motor_time_to_event_codes else []
+        )
+        self._motor_code_to_id_mapping = {
+            code: i for i, code in enumerate(sorted(self._motor_time_to_event_codes))
+        }
 
         super().__init__()
 
```
```diff
@@ -402,6 +529,10 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             ]
         )
 
+    @property
+    def motor_tte_vocab_size(self) -> int:
+        return len(self._motor_code_to_id_mapping)
+
     @property
     def vocab_size(self) -> int:
         return self._tokenizer.get_vocab_size()
```
```diff
@@ -442,10 +573,45 @@ class CehrGptTokenizer(PreTrainedTokenizer):
     def pad_token_id(self):
         return self._padding_token_id
 
+    @property
+    def oov_token_id(self):
+        return self._oov_token_id
+
     @property
     def pad_token(self):
         return PAD_TOKEN
 
+    @property
+    def linear_token_id(self) -> Optional[int]:
+        return self._linear_token_id
+
+    @property
+    def linear_token(self) -> Optional[str]:
+        if LINEAR_PROB_TOKEN in self._tokenizer.get_vocab():
+            return LINEAR_PROB_TOKEN
+        return None
+
+    def vs_token_id(self):
+        # We used VS for the historical data; the newer data uses [VS],
+        # so we need to check both cases.
+        if "VS" in self._tokenizer.get_vocab():
+            return self._convert_token_to_id("VS")
+        elif "[VS]" in self._tokenizer.get_vocab():
+            return self._convert_token_to_id("[VS]")
+        else:
+            raise RuntimeError("The tokenizer does not contain either VS or [VS]")
+
+    @property
+    def ve_token_id(self):
+        # We used VE for the historical data; the newer data uses [VE],
+        # so we need to check both cases.
+        if "VE" in self._tokenizer.get_vocab():
+            return self._convert_token_to_id("VE")
+        elif "[VE]" in self._tokenizer.get_vocab():
+            return self._convert_token_to_id("[VE]")
+        else:
+            raise RuntimeError("The tokenizer does not contain either VE or [VE]")
+
     @property
     def numeric_concept_ids(self):
         return self._numeric_concept_ids
```
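Note the asymmetry carried over from the source: `ve_token_id` is a property while `vs_token_id` is a plain method, so the two are invoked differently. Both lookups follow the same fallback pattern, trying the legacy bare marker before the bracketed one. A minimal standalone sketch of that pattern (vocabulary contents are hypothetical):

```python
def resolve_token_id(vocab: dict, legacy: str, current: str) -> int:
    # Tokenizers trained on historical data used bare markers (e.g. "VS");
    # newer ones use bracketed markers (e.g. "[VS]"), so check both.
    if legacy in vocab:
        return vocab[legacy]
    if current in vocab:
        return vocab[current]
    raise RuntimeError(f"The tokenizer does not contain either {legacy} or {current}")

print(resolve_token_id({"[VS]": 7, "[VE]": 8}, "VS", "[VS]"))  # 7
```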
```diff
@@ -456,8 +622,14 @@ class CehrGptTokenizer(PreTrainedTokenizer):
 
     @property
     def lab_token_ids(self):
-        reserved_tokens = [START_TOKEN, PAD_TOKEN, END_TOKEN, OUT_OF_VOCABULARY_TOKEN]
-        return self.encode(
+        reserved_tokens = [
+            START_TOKEN,
+            PAD_TOKEN,
+            END_TOKEN,
+            OUT_OF_VOCABULARY_TOKEN,
+            LINEAR_PROB_TOKEN,
+        ]
+        lab_token_ids = self.encode(
             [
                 concept_id
                 for concept_id in self._numeric_concept_ids
```
```diff
@@ -465,6 +637,9 @@ class CehrGptTokenizer(PreTrainedTokenizer):
                 if concept_id not in reserved_tokens
             ]
         )
+        return list(
+            filter(lambda token_id: token_id != self._oov_token_id, lab_token_ids)
+        )
 
     @property
     def token_to_time_token_mapping(self) -> Dict[int, List[int]]:
```
```diff
@@ -483,6 +658,19 @@ class CehrGptTokenizer(PreTrainedTokenizer):
     def pretrained_concept_embedding_model(self):
         return self._pretrained_concept_embedding_model
 
+    def get_motor_token_id(self, concept_id: str) -> int:
+        if concept_id not in self._motor_code_to_id_mapping:
+            raise RuntimeError(f"Invalid motor concept id: {concept_id}")
+        return self._motor_code_to_id_mapping[concept_id]
+
+    def is_motor_time_to_event_code(self, future_concept_id: str) -> bool:
+        if (
+            self._motor_time_to_event_codes
+            and future_concept_id in self._motor_time_to_event_codes
+        ):
+            return True
+        return False
+
     def get_vocab(self) -> Dict[str, int]:
         return self._tokenizer.get_vocab()
 
```
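The MOTOR task vocabulary is separate from the main token vocabulary: ids come from enumerating the sorted code list built in `__init__`, so they stay stable across runs as long as the code set is unchanged. A minimal sketch with hypothetical concept ids:

```python
# Hypothetical MOTOR time-to-event codes; sorting before enumeration makes
# the code -> id assignment deterministic.
motor_time_to_event_codes = ["434557", "320128", "4329847"]
motor_code_to_id_mapping = {
    code: i for i, code in enumerate(sorted(motor_time_to_event_codes))
}
print(motor_code_to_id_mapping)  # {'320128': 0, '4329847': 1, '434557': 2}
print("320128" in motor_code_to_id_mapping)  # True: a valid motor concept id
```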
```diff
@@ -511,6 +699,18 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             concept_value_token_ids, skip_special_tokens=skip_special_tokens
         ).split(" ")
 
+    def add_token(self, tokens: Union[str, List[str]]) -> None:
+        if isinstance(tokens, str):
+            tokens = [tokens]
+        vocab = self.get_vocab()
+        self._tokenizer.add_tokens(
+            [
+                AddedToken(token, single_word=True, normalized=False)
+                for token in tokens
+                if token not in vocab
+            ]
+        )
+
     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""
         token_id = self._tokenizer.token_to_id(token)
```
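`add_token` is effectively idempotent: tokens already in the vocabulary are filtered out before reaching the underlying Rust tokenizer. A self-contained sketch of the same pattern against a bare `tokenizers` word-level tokenizer (the token strings here are hypothetical):

```python
from tokenizers import AddedToken, Tokenizer
from tokenizers.models import WordLevel

tokenizer = Tokenizer(WordLevel({"[PAD]": 0}, unk_token="[PAD]"))
vocab = tokenizer.get_vocab()
new_tokens = ["[LINEAR_PROB]", "[PAD]"]  # "[PAD]" is already present
tokenizer.add_tokens(
    [
        AddedToken(t, single_word=True, normalized=False)
        for t in new_tokens
        if t not in vocab
    ]
)
print(tokenizer.get_vocab_size())  # 2: only "[LINEAR_PROB]" was added
```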
```diff
@@ -572,6 +772,9 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         ) as f:
             json.dump(self._token_to_sub_time_token_mapping, f)
 
+        with open(os.path.join(save_directory, CONCEPT_STATS_FILE_NAME), "w") as f:
+            json.dump(self._concept_code_stats, f)
+
         with open(os.path.join(save_directory, LAB_STATS_FILE_NAME), "wb") as f:
             lab_stats = {
                 "numeric_lab_stats": self._numeric_lab_stats,
```
```diff
@@ -582,6 +785,11 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         with open(os.path.join(save_directory, CONCEPT_MAPPING_FILE_NAME), "w") as f:
             json.dump(self._concept_name_mapping, f)
 
+        with open(
+            os.path.join(save_directory, MOTOR_TIME_TO_EVENT_CODES_FILE_NAME), "w"
+        ) as f:
+            json.dump(self._motor_time_to_event_codes, f)
+
         self._pretrained_concept_embedding_model.save(save_directory)
 
         if push_to_hub:
```
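Taken together, a tokenizer directory saved by 0.1.2 gains two JSON artifacts alongside the existing ones; reading them back needs nothing beyond the filename constants defined near the top of this file. A hedged sketch (the directory path is hypothetical):

```python
import json
import os

save_directory = "./cehrgpt_tokenizer"  # hypothetical path
# Filenames match the constants added in this diff.
with open(os.path.join(save_directory, "cehrgpt_concept_stats.json")) as f:
    concept_code_stats = json.load(f)  # code -> approximate prevalence
with open(os.path.join(save_directory, "motor_time_to_event_codes.json")) as f:
    motor_time_to_event_codes = json.load(f)  # list of MOTOR task codes
```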
```diff
@@ -691,6 +899,22 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             return None
         concept_name_mapping = load_json_file(concept_name_mapping_file)
 
+        # Load the concept_code_stats json file
+        concept_code_stats_mapping_file = transformers.utils.hub.cached_file(
+            pretrained_model_name_or_path, CONCEPT_STATS_FILE_NAME, **kwargs
+        )
+        if not concept_code_stats_mapping_file:
+            return None
+        concept_code_stats = load_json_file(concept_code_stats_mapping_file)
+
+        # Load the MOTOR time to event codes file
+        motor_time_to_event_codes_file = transformers.utils.hub.cached_file(
+            pretrained_model_name_or_path, MOTOR_TIME_TO_EVENT_CODES_FILE_NAME, **kwargs
+        )
+        if not motor_time_to_event_codes_file:
+            return None
+        motor_time_to_event_codes = load_json_file(motor_time_to_event_codes_file)
+
         pretrained_embedding_model = PretrainedEmbeddings(pretrained_model_name_or_path)
 
         return CehrGptTokenizer(
```
```diff
@@ -698,10 +922,12 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             value_tokenizer,
             att_tokenizer,
             token_to_sub_time_token_mapping,
+            concept_code_stats,
             lab_stats["numeric_lab_stats"],
             lab_stats["categorical_lab_stats"],
             concept_name_mapping,
             pretrained_embedding_model,
+            motor_time_to_event_codes,
         )
 
     @classmethod
```
```diff
@@ -724,6 +950,9 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         concept_name_mapping: Dict[str, str],
         data_args: DataTrainingArguments,
         pretrained_concept_embedding_model: PretrainedEmbeddings = None,
+        num_motor_tasks: Optional[int] = None,
+        apply_entropy_filter: bool = False,
+        min_prevalence: float = 1 / 1000,
     ):
         if not isinstance(cehrgpt_tokenizer, CehrGptTokenizer):
             raise ValueError(
```
```diff
@@ -736,6 +965,9 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             dataset=dataset,
             concept_name_mapping=concept_name_mapping,
             data_args=data_args,
+            num_motor_tasks=num_motor_tasks,
+            apply_entropy_filter=apply_entropy_filter,
+            min_prevalence=min_prevalence,
         )
 
         new_tokens = set(new_tokenizer.get_vocab().keys()) - set(
```
```diff
@@ -753,6 +985,7 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         new_numeric_lab_stats = new_tokenizer._numeric_lab_stats
         new_categorical_lab_stats = new_tokenizer._categorical_lab_stats
         new_concept_name_mapping = new_tokenizer._concept_name_mapping
+        new_motor_time_to_event_codes = new_tokenizer._motor_time_to_event_codes
 
         # Add new tokens to the existing tokenizer
         cehrgpt_tokenizer_copy._tokenizer.add_tokens(
```
```diff
@@ -801,15 +1034,31 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             if token not in cehrgpt_tokenizer_copy._concept_name_mapping:
                 cehrgpt_tokenizer_copy._concept_name_mapping[token] = concept_name
 
+        # Merge motor_time_to_event_codes
+        if (
+            new_motor_time_to_event_codes
+            and cehrgpt_tokenizer_copy._motor_time_to_event_codes
+        ):
+            for motor_time_to_event_code in new_motor_time_to_event_codes:
+                if (
+                    motor_time_to_event_code
+                    not in cehrgpt_tokenizer_copy._motor_time_to_event_codes
+                ):
+                    cehrgpt_tokenizer_copy._motor_time_to_event_codes.append(
+                        motor_time_to_event_code
+                    )
+
         return CehrGptTokenizer(
             tokenizer=cehrgpt_tokenizer_copy._tokenizer,
             value_tokenizer=cehrgpt_tokenizer_copy._value_tokenizer,
             att_tokenizer=cehrgpt_tokenizer_copy._att_tokenizer,
             token_to_sub_time_token_mapping=cehrgpt_tokenizer_copy._token_to_sub_time_token_mapping,
+            concept_code_stats=cehrgpt_tokenizer_copy._concept_code_stats,
             numeric_lab_stats=cehrgpt_tokenizer_copy._numeric_lab_stats,
             categorical_lab_stats=cehrgpt_tokenizer_copy._categorical_lab_stats,
             concept_name_mapping=cehrgpt_tokenizer_copy._concept_name_mapping,
             pretrained_concept_embedding_model=pretrained_concept_embedding_model,
+            motor_time_to_event_codes=cehrgpt_tokenizer_copy._motor_time_to_event_codes,
         )
 
     @classmethod
```
```diff
@@ -879,6 +1128,10 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         concept_name_mapping: Dict[str, str],
         data_args: DataTrainingArguments,
         pretrained_concept_embedding_model: PretrainedEmbeddings = None,
+        allowed_motor_codes: Optional[List[int]] = None,
+        num_motor_tasks: Optional[int] = None,
+        apply_entropy_filter: bool = False,
+        min_prevalence: float = 1 / 1000,
     ):
         """
         Train a huggingface word level tokenizer.
```
```diff
@@ -890,13 +1143,49 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         if isinstance(dataset, DatasetDict):
             dataset = dataset["train"]
 
+        LOG.info("Calculating data statistics")
+        cehrgpt_data_statistics = compute_statistics(dataset, data_args)
+        cehrgpt_data_statistics["total"]
+        numeric_lab_stats = cehrgpt_data_statistics["numeric_lab_stats"]
+        categorical_lab_stats = cehrgpt_data_statistics["categorical_lab_stats"]
+        concept_code_stats = cehrgpt_data_statistics["concept_code_stats"]
+        concept_code_entropies = cehrgpt_data_statistics["concept_code_entropies"]
+
+        if apply_entropy_filter:
+            min_prevalence = max(1e-8, min_prevalence)
+            min_entropy = (
+                np.log(1 - min_prevalence) * (1 - min_prevalence)
+                + np.log(min_prevalence) * min_prevalence
+            )
+            qualified_codes = [
+                k
+                for k, v in concept_code_entropies.items()
+                if v <= min_entropy
+                or not is_clinical_event(k, data_args.is_data_in_meds)
+            ]
+        else:
+            qualified_codes = [
+                k
+                for k, v in concept_code_stats.items()
+                if min_prevalence <= v
+                or not is_clinical_event(k, data_args.is_data_in_meds)
+            ]
+
+        # Create the tokenizer now
         LOG.info("Training the tokenizer for concepts")
         concept_tokenizer = cls.train_concept_tokenizer(
             dataset,
             feature_name="concept_ids",
-            special_tokens=[PAD_TOKEN, OUT_OF_VOCABULARY_TOKEN, START_TOKEN, END_TOKEN],
+            special_tokens=[
+                PAD_TOKEN,
+                OUT_OF_VOCABULARY_TOKEN,
+                START_TOKEN,
+                END_TOKEN,
+                LINEAR_PROB_TOKEN,
+            ],
             unk_token=OUT_OF_VOCABULARY_TOKEN,
             data_args=data_args,
+            qualified_codes=qualified_codes,
         )
         concept_value_column = "concept_as_values"
         for row in dataset:
```
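With `apply_entropy_filter`, a code qualifies when its negative binary entropy is at most that of a code occurring at exactly `min_prevalence`; because the entropy is symmetric in p and 1 − p, this also drops near-universal codes that a plain prevalence cut-off would keep. A standalone sketch of the decision rule (the `is_clinical_event` escape hatch in the real code additionally keeps non-clinical structural tokens unconditionally):

```python
import numpy as np

def keeps_code(p: float, min_prevalence: float = 1 / 1000) -> bool:
    # Mirrors the apply_entropy_filter branch of train_tokenizer.
    min_prevalence = max(1e-8, min_prevalence)
    min_entropy = (
        np.log(1 - min_prevalence) * (1 - min_prevalence)
        + np.log(min_prevalence) * min_prevalence
    )
    p = float(np.clip(p, 1e-8, 1 - 1e-8))
    entropy = np.log(p) * p + (1 - p) * np.log(1 - p)
    return entropy <= min_entropy

print(keeps_code(0.0001))  # False: rarer than 1/1000
print(keeps_code(0.05))    # True
print(keeps_code(0.9999))  # False: near-universal codes carry little signal
```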
```diff
@@ -919,65 +1208,11 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             ]
         )
 
-        map_statistics_partial = partial(map_statistics, size=SAMPLE_SIZE)
-
-        if data_args.streaming:
-            first_example = next(iter(dataset))
-            parts = dataset.map(
-                partial(agg_helper, map_func=map_statistics_partial),
-                batched=True,
-                batch_size=data_args.preprocessing_batch_size,
-                remove_columns=first_example.keys(),
-            )
-        else:
-            parts = dataset.map(
-                partial(agg_helper, map_func=map_statistics_partial),
-                batched=True,
-                batch_size=data_args.preprocessing_batch_size,
-                remove_columns=dataset.column_names,
-                num_proc=data_args.preprocessing_num_workers,
-                keep_in_memory=True,
-                new_fingerprint="invalid",
-            )
-        current = None
-        for stat in tqdm(parts, desc="Aggregating the lab statistics"):
-            fixed_stat = pickle.loads(stat["data"])
-            if current is None:
-                current = fixed_stat
-            else:
-                current = agg_statistics(current, fixed_stat)
-
-        numeric_lab_stats = []
-        for (concept_id, unit), online_stats in current["numeric_stats_by_lab"].items():
-            if len(online_stats.samples) == 0:
-                continue
-            samples = truncated_sample(
-                online_stats.samples, data_args.value_outlier_std
-            )
-            bins = create_bins_with_spline(samples, NUM_OF_BINS, DEGREE_OF_FREEDOM)
-            if len(bins) > 0:
-                numeric_lab_stats.append(
-                    {
-                        "concept_id": concept_id,
-                        "unit": unit,
-                        "mean": np.mean(samples),
-                        "std": np.std(samples),
-                        "count": len(online_stats.samples),
-                        "value_outlier_std": data_args.value_outlier_std,
-                        "bins": bins,
-                    }
-                )
-
-        categorical_lab_stats = collections.defaultdict(int)
-        for (concept_id, value_as_concept), count in current[
-            "categorical_stats_by_lab"
-        ].items():
-            categorical_lab_stats[(concept_id, value_as_concept)] += count
-
         # We will train a tokenizer specifically for time intervals
         sub_time_token_data = []
         token_to_sub_time_token_mapping = collections.defaultdict(list)
-        for token, token_id in concept_tokenizer.get_vocab().items():
+        vocab = concept_tokenizer.get_vocab()
+        for token, token_id in vocab.items():
             if is_att_token(token):
                 time_interval = extract_time_interval_in_days(token)
                 time_tuple = convert_time_interval_to_time_tuple(
```
```diff
@@ -998,15 +1233,43 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         )
         att_tokenizer.train_from_iterator(sub_time_token_data, trainer=att_trainer)
 
+        # Prune concept_name_mapping
+        concept_name_mapping = {
+            concept_id: concept_name_mapping[concept_id]
+            for concept_id in vocab.keys()
+            if concept_id in concept_name_mapping
+        }
+
+        motor_time_to_event_codes = None
+        if num_motor_tasks and allowed_motor_codes:
+            motor_time_to_event_codes = []
+            for concept_id, _ in sorted(
+                concept_code_entropies.items(), key=lambda t: t[1]
+            ):
+                if (
+                    concept_id not in allowed_motor_codes
+                    or concept_id not in qualified_codes
+                ):
+                    continue
+                if len(motor_time_to_event_codes) < num_motor_tasks:
+                    motor_time_to_event_codes.append(concept_id)
+                else:
+                    break
+            LOG.info(
+                f"{len(motor_time_to_event_codes)} tasks have been added as MOTOR tasks"
+            )
+
         return CehrGptTokenizer(
             concept_tokenizer,
             value_tokenizer,
             att_tokenizer,
             token_to_sub_time_token_mapping,
+            concept_code_stats,
             numeric_lab_stats,
             categorical_lab_stats,
             concept_name_mapping,
             pretrained_concept_embedding_model,
+            motor_time_to_event_codes,
         )
 
     @classmethod
```
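The MOTOR task list is built greedily: candidates are sorted by ascending entropy (most negative, i.e. closest to 50% prevalence, first), and the first `num_motor_tasks` codes passing both the `allowed_motor_codes` and `qualified_codes` gates are taken. A toy sketch with hypothetical codes and entropies:

```python
# Hypothetical entropies; more negative means closer to 50% prevalence.
concept_code_entropies = {"a": -0.20, "b": -0.05, "c": -0.35}
allowed_motor_codes = {"a", "c"}
qualified_codes = {"a", "b", "c"}
num_motor_tasks = 2

motor_time_to_event_codes = []
for concept_id, _ in sorted(concept_code_entropies.items(), key=lambda t: t[1]):
    if concept_id not in allowed_motor_codes or concept_id not in qualified_codes:
        continue
    if len(motor_time_to_event_codes) < num_motor_tasks:
        motor_time_to_event_codes.append(concept_id)
    else:
        break
print(motor_time_to_event_codes)  # ['c', 'a']
```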
```diff
@@ -1017,6 +1280,7 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         special_tokens: List[str],
         unk_token,
         data_args,
+        qualified_codes: Optional[List[str]] = None,
     ):
         # Use the Fast Tokenizer from the Huggingface tokenizers Rust implementation.
         # https://github.com/huggingface/tokenizers
```
```diff
@@ -1029,7 +1293,9 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             show_progress=True,
         )
         batch_concat_concepts_partial_func = partial(
-            cls.batch_concat_concepts, feature_name=feature_name
+            cls.batch_concat_concepts,
+            feature_name=feature_name,
+            qualified_codes=qualified_codes,
         )
         if data_args.streaming:
             concatenated_features = dataset.map(
```
```diff
@@ -1069,13 +1335,30 @@ class CehrGptTokenizer(PreTrainedTokenizer):
 
     @classmethod
     def batch_concat_concepts(
-        cls, records: Dict[str, List], feature_name: str
+        cls,
+        records: Dict[str, List],
+        feature_name: str,
+        qualified_codes: Optional[List[str]] = None,
     ) -> Dict[str, List]:
+        def filter_token(t: str) -> bool:
+            """
+            If the token is None or not a string, return False.
+
+            When qualified_codes is provided, t must be in
+            qualified_codes to be valid; otherwise tokens are always valid.
+
+            :param t:
+            :return:
+            """
+            if t is None or not isinstance(t, str):
+                return False
+            if qualified_codes:
+                return t in qualified_codes
+            return True
+
         return {
             feature_name: [
-                " ".join(
-                    [token for token in tokens if token and isinstance(token, str)]
-                )
+                " ".join(filter(filter_token, tokens))
                 for tokens in records[feature_name]
             ]
         }
```