cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
- cehrgpt/data/sample_packing_sampler.py +181 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +35 -0
- cehrgpt/models/hf_cehrgpt.py +470 -106
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
- cehrgpt/runners/data_utils.py +358 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
- cehrgpt/runners/hyperparameter_search_util.py +10 -8
- cehrgpt/runners/sample_packing_trainer.py +185 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
```diff
@@ -10,13 +10,9 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 import numpy as np
 import scipy.stats as stats
 import transformers
-from cehrbert.models.hf_models.tokenization_utils import (
-    agg_helper,
-    agg_statistics,
-    load_json_file,
-)
+from cehrbert.models.hf_models.tokenization_utils import agg_helper, load_json_file
 from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
-from datasets import Dataset, DatasetDict
+from datasets import Dataset, DatasetDict, IterableDataset
 from femr.stat_utils import OnlineStatistics, ReservoirSampler
 from scipy.interpolate import UnivariateSpline
 from tokenizers import AddedToken, Tokenizer
@@ -25,16 +21,19 @@ from tokenizers.pre_tokenizers import WhitespaceSplit
 from tokenizers.trainers import WordLevelTrainer
 from tqdm import tqdm
 from transformers import PreTrainedTokenizer
+from transformers.utils import logging
 
 from cehrgpt.gpt_utils import (
     convert_time_interval_to_time_tuple,
     extract_time_interval_in_days,
     is_att_token,
+    is_clinical_event,
     is_inpatient_att_token,
 )
 from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
 from cehrgpt.models.special_tokens import (
     END_TOKEN,
+    LINEAR_PROB_TOKEN,
     OUT_OF_VOCABULARY_TOKEN,
     PAD_TOKEN,
     START_TOKEN,
@@ -52,7 +51,11 @@ TIME_TOKENIZER_FILE_NAME = "cehrgpt_time_tokenizer.json"
 TOKEN_TO_SUB_TIME_TOKEN_MAPPING_FILE_NAME = "token_to_sub_time_token_mapping.json"
 LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.pickle"
 LEGACY_LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.json"
+CONCEPT_STATS_FILE_NAME = "cehrgpt_concept_stats.json"
 CONCEPT_MAPPING_FILE_NAME = "concept_name_mapping.json"
+MOTOR_TIME_TO_EVENT_CODES_FILE_NAME = "motor_time_to_event_codes.json"
+
+LOG = logging.get_logger("transformers")
 
 
 def truncated_sample(sample, standard_deviation):
@@ -71,6 +74,16 @@ def create_value_bin(bin_index: int) -> str:
     return "BIN:" + str(bin_index)
 
 
+def get_dataset_len(dataset: Union[Dataset, IterableDataset]) -> int:
+    if isinstance(dataset, Dataset):
+        return len(dataset)
+    elif isinstance(dataset, IterableDataset):
+        return sum([1 for _ in dataset])
+    raise RuntimeError(
+        "The dataset must be one of the two types (Dataset, IterableDataset)"
+    )
+
+
 def create_sample_from_bins(bins, sample_size: int = 10_000) -> List[float]:
     """
     Generates a specified number of samples from a list of bins, each containing a fitted spline.
```
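The new `get_dataset_len` helper sizes either dataset flavor uniformly; note that the `IterableDataset` branch has to consume a full pass over the data. A minimal usage sketch (the import path is an assumption taken from the file list above, and `to_iterable_dataset()` requires a recent `datasets` release):

```python
from datasets import Dataset

# Assumed module path; the diff hunks themselves do not name the file.
from cehrgpt.models.tokenization_hf_cehrgpt import get_dataset_len

ds = Dataset.from_dict({"concept_ids": [["c1"], ["c2"], ["c3"]]})
print(get_dataset_len(ds))                        # 3, via len()
print(get_dataset_len(ds.to_iterable_dataset()))  # 3, by exhausting one full pass
```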
```diff
@@ -186,7 +199,22 @@ def create_bins_with_spline(samples, num_bins, d_freedom=3) -> List[Dict[str, Any]]:
     return bins
 
 
-def map_statistics(batch: Dict[str, Any], size=10_000) -> Dict[str, Any]:
+def agg_statistics(stats1, stats2):
+    if stats1.get("numeric_stats_by_lab"):
+        for k, v in stats2["numeric_stats_by_lab"].items():
+            stats1["numeric_stats_by_lab"][k].combine(v)
+    if stats1.get("categorical_stats_by_lab"):
+        for (concept_id, concept_as_value), count in stats2[
+            "categorical_stats_by_lab"
+        ].items():
+            stats1["categorical_stats_by_lab"][(concept_id, concept_as_value)] += count
+    if stats1.get("concept_code_stats"):
+        for concept_id, weight in stats2["concept_code_stats"].items():
+            stats1["concept_code_stats"][concept_id] += weight
+    return stats1
+
+
+def map_statistics(batch: Dict[str, Any], total_size, size=10_000) -> Dict[str, Any]:
     if "units" in batch:
         batch_value_units = batch["units"]
     else:
```
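The reworked `map_statistics` now also tracks how many sequences each concept appears in: every sequence contributes `1 / total_size` for each unique code it contains, so the aggregated `concept_code_stats` value approximates per-sequence prevalence. A small illustrative sketch of that bookkeeping (toy data, not the package API):

```python
total_size = 4
sequences = [["c1", "c1", "c2"], ["c1"], ["c2"], ["c3"]]

concept_code_stats = {}
for seq in sequences:
    # Each unique code in a sequence adds 1 / total_size, regardless of repeats.
    for code in set(seq):
        concept_code_stats[code] = concept_code_stats.get(code, 0) + 1 / total_size

print({k: concept_code_stats[k] for k in sorted(concept_code_stats)})
# {'c1': 0.5, 'c2': 0.5, 'c3': 0.25}
```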
```diff
@@ -210,6 +238,7 @@ def map_statistics(batch: Dict[str, Any], size=10_000) -> Dict[str, Any]:
 
     numeric_stats_by_lab = collections.defaultdict(partial(ReservoirSampler, size=size))
     categorical_stats_by_lab = collections.defaultdict(int)
+    concept_code_stats = collections.defaultdict(int)
     for (
         concept_ids,
         number_as_values,
@@ -223,6 +252,7 @@ def map_statistics(batch: Dict[str, Any], size=10_000) -> Dict[str, Any]:
         batch["concept_value_masks"],
         batch_value_units,
     ):
+        unique_codes = set()
         for (
             concept_id,
             number_as_value,
@@ -241,10 +271,94 @@ def map_statistics(batch: Dict[str, Any], size=10_000) -> Dict[str, Any]:
                     numeric_stats_by_lab[(concept_id, unit)].add(number_as_value, 1)
                 if concept_as_value:
                     categorical_stats_by_lab[(concept_id, concept_as_value)] += 1
+            unique_codes.add(concept_id)
+
+        for code in unique_codes:
+            concept_code_stats[code] += 1 / total_size
 
     return {
         "numeric_stats_by_lab": numeric_stats_by_lab,
         "categorical_stats_by_lab": categorical_stats_by_lab,
+        "concept_code_stats": concept_code_stats,
+    }
+
+
+def compute_statistics(
+    dataset: Dataset, data_args: DataTrainingArguments
+) -> Dict[str, Any]:
+    total = get_dataset_len(dataset)
+    map_statistics_partial = partial(map_statistics, total_size=total, size=SAMPLE_SIZE)
+    if data_args.streaming:
+        first_example = next(iter(dataset))
+        parts = dataset.map(
+            partial(agg_helper, map_func=map_statistics_partial),
+            batched=True,
+            batch_size=data_args.preprocessing_batch_size,
+            remove_columns=first_example.keys(),
+        )
+    else:
+        parts = dataset.map(
+            partial(agg_helper, map_func=map_statistics_partial),
+            batched=True,
+            batch_size=data_args.preprocessing_batch_size,
+            remove_columns=dataset.column_names,
+            num_proc=data_args.preprocessing_num_workers,
+            keep_in_memory=True,
+            new_fingerprint="invalid",
+        )
+    current = None
+    for stat in tqdm(parts, desc="Aggregating the lab statistics"):
+        fixed_stat = pickle.loads(stat["data"])
+        if current is None:
+            current = fixed_stat
+        else:
+            current = agg_statistics(current, fixed_stat)
+
+    numeric_lab_stats = []
+    for (concept_id, unit), online_stats in current["numeric_stats_by_lab"].items():
+        if len(online_stats.samples) == 0:
+            continue
+        samples = truncated_sample(online_stats.samples, data_args.value_outlier_std)
+        bins = create_bins_with_spline(samples, NUM_OF_BINS, DEGREE_OF_FREEDOM)
+        if len(bins) > 0:
+            numeric_lab_stats.append(
+                {
+                    "concept_id": concept_id,
+                    "unit": unit,
+                    "mean": np.mean(samples),
+                    "std": np.std(samples),
+                    "count": len(online_stats.samples),
+                    "value_outlier_std": data_args.value_outlier_std,
+                    "bins": bins,
+                }
+            )
+
+    categorical_lab_stats = collections.defaultdict(int)
+    for (concept_id, value_as_concept), count in current[
+        "categorical_stats_by_lab"
+    ].items():
+        categorical_lab_stats[(concept_id, value_as_concept)] += count
+
+    concept_code_stats = collections.defaultdict(int)
+    for concept_id, count in current["concept_code_stats"].items():
+        concept_code_stats[concept_id] += count
+
+    code_weights = np.asarray(list(concept_code_stats.values())).clip(1e-8, 1 - 1e-8)
+    # Clip the values so we don't get errors when applying np.log
+    code_entropies = np.log(code_weights) * code_weights + (1 - code_weights) * np.log(
+        1 - code_weights
+    )
+
+    concept_code_entropies = {
+        k: v for k, v in zip(concept_code_stats.keys(), code_entropies)
+    }
+
+    return {
+        "numeric_lab_stats": numeric_lab_stats,
+        "categorical_lab_stats": categorical_lab_stats,
+        "concept_code_stats": concept_code_stats,
+        "concept_code_entropies": concept_code_entropies,
+        "total": total,
     }
 
 
```
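`compute_statistics` turns those prevalences into per-code "entropies" using the clipped formula above, i.e. the negative binary entropy p·ln(p) + (1 − p)·ln(1 − p), which is always ≤ 0 and most negative near p = 0.5. A toy rerun of that step (illustrative values only):

```python
import numpy as np

concept_code_stats = {"c1": 0.5, "c2": 0.05, "c3": 0.001}

# Mirrors the clipped entropy computation in compute_statistics above.
code_weights = np.asarray(list(concept_code_stats.values())).clip(1e-8, 1 - 1e-8)
code_entropies = np.log(code_weights) * code_weights + (1 - code_weights) * np.log(1 - code_weights)

concept_code_entropies = {
    k: round(float(v), 4) for k, v in zip(concept_code_stats.keys(), code_entropies)
}
print(concept_code_entropies)  # {'c1': -0.6931, 'c2': -0.1985, 'c3': -0.0079}
```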
```diff
@@ -346,15 +460,18 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         value_tokenizer: Tokenizer,
         att_tokenizer: Tokenizer,
         token_to_sub_time_token_mapping: Dict[str, List[str]],
+        concept_code_stats: Dict[str, Any],
         numeric_lab_stats: List[Dict[str, Any]],
         categorical_lab_stats: Dict[Tuple[str, str], int],
         concept_name_mapping: Dict[str, str],
         pretrained_concept_embedding_model: PretrainedEmbeddings = None,
+        motor_time_to_event_codes: Optional[List[str]] = None,
     ):
         self._tokenizer = tokenizer
         self._value_tokenizer = value_tokenizer
         self._att_tokenizer = att_tokenizer
         self._token_to_sub_time_token_mapping = token_to_sub_time_token_mapping
+        self._concept_code_stats = concept_code_stats
         self._numeric_lab_stats = numeric_lab_stats
         self._numeric_event_statistics = NumericEventStatistics(numeric_lab_stats)
         self._categorical_lab_stats = categorical_lab_stats
@@ -363,6 +480,12 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         self._padding_token_id = self._tokenizer.token_to_id(PAD_TOKEN)
         self._start_token_id = self._tokenizer.token_to_id(START_TOKEN)
         self._end_token_id = self._tokenizer.token_to_id(END_TOKEN)
+
+        # Backward compatible with the old tokenizer
+        self._linear_token_id = None
+        if LINEAR_PROB_TOKEN in self._tokenizer.get_vocab():
+            self._linear_token_id = self._tokenizer.token_to_id(LINEAR_PROB_TOKEN)
+
         self._numeric_concept_ids = (
             self._numeric_event_statistics.get_numeric_concept_ids()
         )
@@ -380,6 +503,12 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             for _ in self.get_vocab().keys()
             if self._pretrained_concept_embedding_model.is_concept_available(_)
         ]
+        self._motor_time_to_event_codes = (
+            motor_time_to_event_codes if motor_time_to_event_codes else []
+        )
+        self._motor_code_to_id_mapping = {
+            code: i for i, code in enumerate(sorted(self._motor_time_to_event_codes))
+        }
 
         super().__init__()
 
@@ -400,6 +529,10 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             ]
         )
 
+    @property
+    def motor_tte_vocab_size(self) -> int:
+        return len(self._motor_code_to_id_mapping)
+
     @property
     def vocab_size(self) -> int:
         return self._tokenizer.get_vocab_size()
@@ -440,10 +573,45 @@ class CehrGptTokenizer(PreTrainedTokenizer):
     def pad_token_id(self):
         return self._padding_token_id
 
+    @property
+    def oov_token_id(self):
+        return self._oov_token_id
+
     @property
     def pad_token(self):
         return PAD_TOKEN
 
+    @property
+    def linear_token_id(self) -> Optional[int]:
+        return self._linear_token_id
+
+    @property
+    def linear_token(self) -> Optional[str]:
+        if LINEAR_PROB_TOKEN in self._tokenizer.get_vocab():
+            return LINEAR_PROB_TOKEN
+        return None
+
+    def vs_token_id(self):
+        # We used VS for the historical data, currently, we use the new [VS] for the newer data
+        # so we need to check both cases.
+        if "VS" in self._tokenizer.get_vocab():
+            return self._convert_token_to_id("VS")
+        elif "[VS]" in self._tokenizer.get_vocab():
+            return self._convert_token_to_id("[VS]")
+        else:
+            raise RuntimeError("The tokenizer does not contain either VS or [VS]")
+
+    @property
+    def ve_token_id(self):
+        # We used VE for the historical data, currently, we use the new [VE] for the newer data
+        # so we need to check both cases.
+        if "VE" in self._tokenizer.get_vocab():
+            return self._convert_token_to_id("VE")
+        elif "[VE]" in self._tokenizer.get_vocab():
+            return self._convert_token_to_id("[VE]")
+        else:
+            raise RuntimeError("The tokenizer does not contain either VE or [VE]")
+
     @property
     def numeric_concept_ids(self):
         return self._numeric_concept_ids
@@ -454,8 +622,14 @@ class CehrGptTokenizer(PreTrainedTokenizer):
 
     @property
     def lab_token_ids(self):
-        reserved_tokens = [START_TOKEN, PAD_TOKEN, END_TOKEN, OUT_OF_VOCABULARY_TOKEN]
-        return self.encode(
+        reserved_tokens = [
+            START_TOKEN,
+            PAD_TOKEN,
+            END_TOKEN,
+            OUT_OF_VOCABULARY_TOKEN,
+            LINEAR_PROB_TOKEN,
+        ]
+        lab_token_ids = self.encode(
             [
                 concept_id
                 for concept_id in self._numeric_concept_ids
@@ -463,6 +637,9 @@ class CehrGptTokenizer(PreTrainedTokenizer):
                 if concept_id not in reserved_tokens
             ]
         )
+        return list(
+            filter(lambda token_id: token_id != self._oov_token_id, lab_token_ids)
+        )
 
     @property
     def token_to_time_token_mapping(self) -> Dict[int, List[int]]:
@@ -481,6 +658,19 @@ class CehrGptTokenizer(PreTrainedTokenizer):
     def pretrained_concept_embedding_model(self):
         return self._pretrained_concept_embedding_model
 
+    def get_motor_token_id(self, concept_id: str) -> int:
+        if concept_id not in concept_id:
+            raise RuntimeError(f"Invalid motor concept id: {concept_id}")
+        return self._motor_code_to_id_mapping[concept_id]
+
+    def is_motor_time_to_event_code(self, future_concept_id: str) -> bool:
+        if (
+            self._motor_time_to_event_codes
+            and future_concept_id in self._motor_time_to_event_codes
+        ):
+            return True
+        return False
+
     def get_vocab(self) -> Dict[str, int]:
         return self._tokenizer.get_vocab()
 
@@ -509,6 +699,18 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             concept_value_token_ids, skip_special_tokens=skip_special_tokens
         ).split(" ")
 
+    def add_token(self, tokens: Union[str, List[str]]) -> None:
+        if isinstance(tokens, str):
+            tokens = [tokens]
+        vocab = self.get_vocab()
+        self._tokenizer.add_tokens(
+            [
+                AddedToken(token, single_word=True, normalized=False)
+                for token in tokens
+                if token not in vocab
+            ]
+        )
+
     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""
         token_id = self._tokenizer.token_to_id(token)
```
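The new `add_token` method accepts a single token or a list and silently skips anything already in the vocabulary, so it is safe to call repeatedly, for example to retrofit `LINEAR_PROB_TOKEN` into an older tokenizer. A hedged usage sketch (the module path and the tokenizer path are placeholders/assumptions, not shown in the diff):

```python
from cehrgpt.models.special_tokens import LINEAR_PROB_TOKEN
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer

tokenizer = CehrGptTokenizer.from_pretrained("path/to/pretrained_tokenizer")  # placeholder path
tokenizer.add_token(LINEAR_PROB_TOKEN)             # no-op if the token already exists
tokenizer.add_token(["NEW_CODE_1", "NEW_CODE_2"])  # tokens already in the vocab are filtered out
```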
```diff
@@ -570,6 +772,9 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         ) as f:
             json.dump(self._token_to_sub_time_token_mapping, f)
 
+        with open(os.path.join(save_directory, CONCEPT_STATS_FILE_NAME), "w") as f:
+            json.dump(self._concept_code_stats, f)
+
         with open(os.path.join(save_directory, LAB_STATS_FILE_NAME), "wb") as f:
             lab_stats = {
                 "numeric_lab_stats": self._numeric_lab_stats,
@@ -580,6 +785,11 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         with open(os.path.join(save_directory, CONCEPT_MAPPING_FILE_NAME), "w") as f:
             json.dump(self._concept_name_mapping, f)
 
+        with open(
+            os.path.join(save_directory, MOTOR_TIME_TO_EVENT_CODES_FILE_NAME), "w"
+        ) as f:
+            json.dump(self._motor_time_to_event_codes, f)
+
         self._pretrained_concept_embedding_model.save(save_directory)
 
         if push_to_hub:
@@ -689,6 +899,22 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             return None
         concept_name_mapping = load_json_file(concept_name_mapping_file)
 
+        # Load the concept_code_stats json file
+        concept_code_stats_mapping_file = transformers.utils.hub.cached_file(
+            pretrained_model_name_or_path, CONCEPT_STATS_FILE_NAME, **kwargs
+        )
+        if not concept_code_stats_mapping_file:
+            return None
+        concept_code_stats = load_json_file(concept_code_stats_mapping_file)
+
+        # Load the MOTOR time to event codes file
+        motor_time_to_event_codes_file = transformers.utils.hub.cached_file(
+            pretrained_model_name_or_path, MOTOR_TIME_TO_EVENT_CODES_FILE_NAME, **kwargs
+        )
+        if not motor_time_to_event_codes_file:
+            return None
+        motor_time_to_event_codes = load_json_file(motor_time_to_event_codes_file)
+
         pretrained_embedding_model = PretrainedEmbeddings(pretrained_model_name_or_path)
 
         return CehrGptTokenizer(
@@ -696,10 +922,12 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             value_tokenizer,
             att_tokenizer,
             token_to_sub_time_token_mapping,
+            concept_code_stats,
             lab_stats["numeric_lab_stats"],
             lab_stats["categorical_lab_stats"],
             concept_name_mapping,
             pretrained_embedding_model,
+            motor_time_to_event_codes,
         )
 
     @classmethod
@@ -722,6 +950,9 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         concept_name_mapping: Dict[str, str],
         data_args: DataTrainingArguments,
         pretrained_concept_embedding_model: PretrainedEmbeddings = None,
+        num_motor_tasks: Optional[int] = None,
+        apply_entropy_filter: bool = False,
+        min_prevalence: float = 1 / 1000,
     ):
         if not isinstance(cehrgpt_tokenizer, CehrGptTokenizer):
             raise ValueError(
@@ -734,6 +965,9 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             dataset=dataset,
             concept_name_mapping=concept_name_mapping,
             data_args=data_args,
+            num_motor_tasks=num_motor_tasks,
+            apply_entropy_filter=apply_entropy_filter,
+            min_prevalence=min_prevalence,
         )
 
         new_tokens = set(new_tokenizer.get_vocab().keys()) - set(
@@ -751,6 +985,7 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         new_numeric_lab_stats = new_tokenizer._numeric_lab_stats
         new_categorical_lab_stats = new_tokenizer._categorical_lab_stats
         new_concept_name_mapping = new_tokenizer._concept_name_mapping
+        new_motor_time_to_event_codes = new_tokenizer._motor_time_to_event_codes
 
         # Add new tokens to the existing tokenizer
         cehrgpt_tokenizer_copy._tokenizer.add_tokens(
@@ -799,15 +1034,31 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             if token not in cehrgpt_tokenizer_copy._concept_name_mapping:
                 cehrgpt_tokenizer_copy._concept_name_mapping[token] = concept_name
 
+        # Merge motor_time_to_event_codes
+        if (
+            new_motor_time_to_event_codes
+            and cehrgpt_tokenizer_copy._motor_time_to_event_codes
+        ):
+            for motor_time_to_event_code in new_motor_time_to_event_codes:
+                if (
+                    motor_time_to_event_code
+                    not in cehrgpt_tokenizer_copy._motor_time_to_event_codes
+                ):
+                    cehrgpt_tokenizer_copy._motor_time_to_event_codes.append(
+                        motor_time_to_event_code
+                    )
+
         return CehrGptTokenizer(
             tokenizer=cehrgpt_tokenizer_copy._tokenizer,
             value_tokenizer=cehrgpt_tokenizer_copy._value_tokenizer,
             att_tokenizer=cehrgpt_tokenizer_copy._att_tokenizer,
             token_to_sub_time_token_mapping=cehrgpt_tokenizer_copy._token_to_sub_time_token_mapping,
+            concept_code_stats=cehrgpt_tokenizer_copy._concept_code_stats,
             numeric_lab_stats=cehrgpt_tokenizer_copy._numeric_lab_stats,
             categorical_lab_stats=cehrgpt_tokenizer_copy._categorical_lab_stats,
             concept_name_mapping=cehrgpt_tokenizer_copy._concept_name_mapping,
             pretrained_concept_embedding_model=pretrained_concept_embedding_model,
+            motor_time_to_event_codes=cehrgpt_tokenizer_copy._motor_time_to_event_codes,
         )
 
     @classmethod
@@ -877,6 +1128,10 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         concept_name_mapping: Dict[str, str],
         data_args: DataTrainingArguments,
         pretrained_concept_embedding_model: PretrainedEmbeddings = None,
+        allowed_motor_codes: Optional[List[int]] = None,
+        num_motor_tasks: Optional[int] = None,
+        apply_entropy_filter: bool = False,
+        min_prevalence: float = 1 / 1000,
     ):
         """
         Train a huggingface word level tokenizer.
@@ -888,18 +1143,56 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         if isinstance(dataset, DatasetDict):
             dataset = dataset["train"]
 
+        LOG.info("Calculating data statistics")
+        cehrgpt_data_statistics = compute_statistics(dataset, data_args)
+        cehrgpt_data_statistics["total"]
+        numeric_lab_stats = cehrgpt_data_statistics["numeric_lab_stats"]
+        categorical_lab_stats = cehrgpt_data_statistics["categorical_lab_stats"]
+        concept_code_stats = cehrgpt_data_statistics["concept_code_stats"]
+        concept_code_entropies = cehrgpt_data_statistics["concept_code_entropies"]
+
+        if apply_entropy_filter:
+            min_prevalence = max(1e-8, min_prevalence)
+            min_entropy = (
+                np.log(1 - min_prevalence) * (1 - min_prevalence)
+                + np.log(min_prevalence) * min_prevalence
+            )
+            qualified_codes = [
+                k
+                for k, v in concept_code_entropies.items()
+                if v <= min_entropy
+                or not is_clinical_event(k, data_args.is_data_in_meds)
+            ]
+        else:
+            qualified_codes = [
+                k
+                for k, v in concept_code_stats.items()
+                if min_prevalence <= v
+                or not is_clinical_event(k, data_args.is_data_in_meds)
+            ]
+
+        # Create the tokenizer now
+        LOG.info("Training the tokenizer for concepts")
         concept_tokenizer = cls.train_concept_tokenizer(
             dataset,
             feature_name="concept_ids",
-            special_tokens=[PAD_TOKEN, OUT_OF_VOCABULARY_TOKEN, START_TOKEN, END_TOKEN],
+            special_tokens=[
+                PAD_TOKEN,
+                OUT_OF_VOCABULARY_TOKEN,
+                START_TOKEN,
+                END_TOKEN,
+                LINEAR_PROB_TOKEN,
+            ],
             unk_token=OUT_OF_VOCABULARY_TOKEN,
             data_args=data_args,
+            qualified_codes=qualified_codes,
         )
         concept_value_column = "concept_as_values"
         for row in dataset:
             if concept_value_column not in row:
                 concept_value_column = "concept_values"
                 break
+        LOG.info("Training the tokenizer for values")
         value_tokenizer = cls.train_concept_tokenizer(
             dataset,
             feature_name=concept_value_column,
```
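With `apply_entropy_filter=True`, a code qualifies when its entropy is at least as negative as that of a code with the default `min_prevalence` of 1/1000, i.e. roughly "at least 1-in-1000 prevalence"; non-clinical tokens bypass the filter via `is_clinical_event`. A small standalone sketch of the threshold arithmetic (not the package API):

```python
import numpy as np

def neg_binary_entropy(p: float) -> float:
    # Same clipped p*ln(p) + (1-p)*ln(1-p) form used in the diff; always <= 0.
    p = float(np.clip(p, 1e-8, 1 - 1e-8))
    return p * np.log(p) + (1 - p) * np.log(1 - p)

min_entropy = neg_binary_entropy(1 / 1000)        # about -0.0079
print(neg_binary_entropy(0.01) <= min_entropy)    # True  -> code kept
print(neg_binary_entropy(0.0001) <= min_entropy)  # False -> code filtered out
```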
```diff
@@ -915,65 +1208,11 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             ]
         )
 
-        map_statistics_partial = partial(map_statistics, size=SAMPLE_SIZE)
-
-        if data_args.streaming:
-            first_example = next(iter(dataset))
-            parts = dataset.map(
-                partial(agg_helper, map_func=map_statistics_partial),
-                batched=True,
-                batch_size=data_args.preprocessing_batch_size,
-                remove_columns=first_example.keys(),
-            )
-        else:
-            parts = dataset.map(
-                partial(agg_helper, map_func=map_statistics_partial),
-                batched=True,
-                batch_size=data_args.preprocessing_batch_size,
-                remove_columns=dataset.column_names,
-                num_proc=data_args.preprocessing_num_workers,
-                keep_in_memory=True,
-                new_fingerprint="invalid",
-            )
-        current = None
-        for stat in tqdm(parts, desc="Aggregating the lab statistics"):
-            fixed_stat = pickle.loads(stat["data"])
-            if current is None:
-                current = fixed_stat
-            else:
-                current = agg_statistics(current, fixed_stat)
-
-        numeric_lab_stats = []
-        for (concept_id, unit), online_stats in current["numeric_stats_by_lab"].items():
-            if len(online_stats.samples) == 0:
-                continue
-            samples = truncated_sample(
-                online_stats.samples, data_args.value_outlier_std
-            )
-            bins = create_bins_with_spline(samples, NUM_OF_BINS, DEGREE_OF_FREEDOM)
-            if len(bins) > 0:
-                numeric_lab_stats.append(
-                    {
-                        "concept_id": concept_id,
-                        "unit": unit,
-                        "mean": np.mean(samples),
-                        "std": np.std(samples),
-                        "count": len(online_stats.samples),
-                        "value_outlier_std": data_args.value_outlier_std,
-                        "bins": bins,
-                    }
-                )
-
-        categorical_lab_stats = collections.defaultdict(int)
-        for (concept_id, value_as_concept), count in current[
-            "categorical_stats_by_lab"
-        ].items():
-            categorical_lab_stats[(concept_id, value_as_concept)] += count
-
         # We will train a tokenizer specifically for time intervals
         sub_time_token_data = []
         token_to_sub_time_token_mapping = collections.defaultdict(list)
-        for token, token_id in concept_tokenizer.get_vocab().items():
+        vocab = concept_tokenizer.get_vocab()
+        for token, token_id in vocab.items():
             if is_att_token(token):
                 time_interval = extract_time_interval_in_days(token)
                 time_tuple = convert_time_interval_to_time_tuple(
@@ -994,15 +1233,43 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             )
         att_tokenizer.train_from_iterator(sub_time_token_data, trainer=att_trainer)
 
+        # Prune concept_name_mapping
+        concept_name_mapping = {
+            concept_id: concept_name_mapping[concept_id]
+            for concept_id in vocab.keys()
+            if concept_id in concept_name_mapping
+        }
+
+        motor_time_to_event_codes = None
+        if num_motor_tasks and allowed_motor_codes:
+            motor_time_to_event_codes = []
+            for concept_id, _ in sorted(
+                concept_code_entropies.items(), key=lambda t: t[1]
+            ):
+                if (
+                    concept_id not in allowed_motor_codes
+                    or concept_id not in qualified_codes
+                ):
+                    continue
+                if len(motor_time_to_event_codes) < num_motor_tasks:
+                    motor_time_to_event_codes.append(concept_id)
+                else:
+                    break
+            LOG.info(
+                f"{len(motor_time_to_event_codes)} number of tasks have been added as MOTOR tasks"
+            )
+
         return CehrGptTokenizer(
             concept_tokenizer,
             value_tokenizer,
             att_tokenizer,
             token_to_sub_time_token_mapping,
+            concept_code_stats,
             numeric_lab_stats,
             categorical_lab_stats,
             concept_name_mapping,
             pretrained_concept_embedding_model,
+            motor_time_to_event_codes,
         )
 
     @classmethod
```
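MOTOR time-to-event tasks are picked by ranking codes by ascending entropy (most prevalent first), keeping only codes that are in both `allowed_motor_codes` and `qualified_codes`, and stopping once `num_motor_tasks` have been collected. A standalone sketch of that selection loop with toy values (names mirror the diff; the data is illustrative):

```python
concept_code_entropies = {"c1": -0.69, "c2": -0.05, "c3": -0.001, "c4": -0.3}
allowed_motor_codes = {"c1", "c2", "c4"}
qualified_codes = {"c1", "c2", "c3", "c4"}
num_motor_tasks = 2

motor_time_to_event_codes = []
# Ascending entropy puts the most prevalent codes first.
for concept_id, _ in sorted(concept_code_entropies.items(), key=lambda t: t[1]):
    if concept_id not in allowed_motor_codes or concept_id not in qualified_codes:
        continue
    if len(motor_time_to_event_codes) < num_motor_tasks:
        motor_time_to_event_codes.append(concept_id)
    else:
        break

print(motor_time_to_event_codes)  # ['c1', 'c4']
```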
```diff
@@ -1013,6 +1280,7 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         special_tokens: List[str],
         unk_token,
         data_args,
+        qualified_codes: Optional[List[str]] = None,
     ):
         # Use the Fast Tokenizer from the Huggingface tokenizers Rust implementation.
         # https://github.com/huggingface/tokenizers
@@ -1025,7 +1293,9 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             show_progress=True,
         )
         batch_concat_concepts_partial_func = partial(
-            cls.batch_concat_concepts, feature_name=feature_name
+            cls.batch_concat_concepts,
+            feature_name=feature_name,
+            qualified_codes=qualified_codes,
         )
         if data_args.streaming:
             concatenated_features = dataset.map(
@@ -1065,13 +1335,30 @@ class CehrGptTokenizer(PreTrainedTokenizer):
 
     @classmethod
     def batch_concat_concepts(
-        cls, records: Dict[str, List], feature_name: str
+        cls,
+        records: Dict[str, List],
+        feature_name: str,
+        qualified_codes: Optional[List[str]] = None,
     ) -> Dict[str, List]:
+        def filter_token(t: str) -> bool:
+            """
+            If the token is None or not string, return False.
+
+            When qualified_codes is provided, t must be in
+            qualified_codes to be valid, otherwise the tokens are always valid
+
+            :param t:
+            :return:
+            """
+            if t is None or not isinstance(t, str):
+                return False
+            if qualified_codes:
+                return t in qualified_codes
+            return True
+
         return {
             feature_name: [
-                " ".join(
-                    [token for token in tokens if token and isinstance(token, str)]
-                )
+                " ".join(filter(filter_token, tokens))
                 for tokens in records[feature_name]
             ]
         }
```