cehrgpt 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. cehrgpt/analysis/irregularity.py +36 -0
  2. cehrgpt/data/hf_cehrgpt_dataset.py +1 -0
  3. cehrgpt/data/hf_cehrgpt_dataset_collator.py +454 -68
  4. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +232 -17
  5. cehrgpt/data/sample_packing_sampler.py +36 -6
  6. cehrgpt/generation/cehrgpt_conditional_generation.py +314 -0
  7. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +15 -3
  8. cehrgpt/generation/omop_converter_batch.py +32 -2
  9. cehrgpt/gpt_utils.py +20 -2
  10. cehrgpt/models/config.py +25 -0
  11. cehrgpt/models/hf_cehrgpt.py +244 -39
  12. cehrgpt/models/hf_modeling_outputs.py +1 -0
  13. cehrgpt/models/special_tokens.py +1 -0
  14. cehrgpt/models/tokenization_hf_cehrgpt.py +354 -71
  15. cehrgpt/runners/data_utils.py +131 -5
  16. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +84 -51
  17. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +59 -7
  18. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +60 -0
  19. cehrgpt/runners/hyperparameter_search_util.py +6 -7
  20. cehrgpt/runners/sample_packing_trainer.py +17 -0
  21. cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  22. cehrgpt/time_to_event/time_to_event_model.py +2 -13
  23. cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  24. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +80 -62
  25. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/METADATA +102 -7
  26. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/RECORD +29 -26
  27. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/WHEEL +1 -1
  28. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/licenses/LICENSE +0 -0
  29. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/top_level.txt +0 -0
@@ -10,13 +10,9 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 import numpy as np
 import scipy.stats as stats
 import transformers
-from cehrbert.models.hf_models.tokenization_utils import (
-    agg_helper,
-    agg_statistics,
-    load_json_file,
-)
+from cehrbert.models.hf_models.tokenization_utils import agg_helper, load_json_file
 from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
-from datasets import Dataset, DatasetDict
+from datasets import Dataset, DatasetDict, IterableDataset
 from femr.stat_utils import OnlineStatistics, ReservoirSampler
 from scipy.interpolate import UnivariateSpline
 from tokenizers import AddedToken, Tokenizer
@@ -31,11 +27,13 @@ from cehrgpt.gpt_utils import (
     convert_time_interval_to_time_tuple,
     extract_time_interval_in_days,
     is_att_token,
+    is_clinical_event,
     is_inpatient_att_token,
 )
 from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
 from cehrgpt.models.special_tokens import (
     END_TOKEN,
+    LINEAR_PROB_TOKEN,
     OUT_OF_VOCABULARY_TOKEN,
     PAD_TOKEN,
     START_TOKEN,
@@ -53,7 +51,10 @@ TIME_TOKENIZER_FILE_NAME = "cehrgpt_time_tokenizer.json"
 TOKEN_TO_SUB_TIME_TOKEN_MAPPING_FILE_NAME = "token_to_sub_time_token_mapping.json"
 LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.pickle"
 LEGACY_LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.json"
+CONCEPT_STATS_FILE_NAME = "cehrgpt_concept_stats.json"
 CONCEPT_MAPPING_FILE_NAME = "concept_name_mapping.json"
+MOTOR_TIME_TO_EVENT_CODES_FILE_NAME = "motor_time_to_event_codes.json"
+
 LOG = logging.get_logger("transformers")


@@ -73,6 +74,16 @@ def create_value_bin(bin_index: int) -> str:
     return "BIN:" + str(bin_index)


+def get_dataset_len(dataset: Union[Dataset, IterableDataset]) -> int:
+    if isinstance(dataset, Dataset):
+        return len(dataset)
+    elif isinstance(dataset, IterableDataset):
+        return sum([1 for _ in dataset])
+    raise RuntimeError(
+        "The dataset must be one of the two types (Dataset, IterableDataset)"
+    )
+
+
 def create_sample_from_bins(bins, sample_size: int = 10_000) -> List[float]:
     """
     Generates a specified number of samples from a list of bins, each containing a fitted spline.
@@ -188,7 +199,22 @@ def create_bins_with_spline(samples, num_bins, d_freedom=3) -> List[Dict[str, An
     return bins


-def map_statistics(batch: Dict[str, Any], size=10_000) -> Dict[str, Any]:
+def agg_statistics(stats1, stats2):
+    if stats1.get("numeric_stats_by_lab"):
+        for k, v in stats2["numeric_stats_by_lab"].items():
+            stats1["numeric_stats_by_lab"][k].combine(v)
+    if stats1.get("categorical_stats_by_lab"):
+        for (concept_id, concept_as_value), count in stats2[
+            "categorical_stats_by_lab"
+        ].items():
+            stats1["categorical_stats_by_lab"][(concept_id, concept_as_value)] += count
+    if stats1.get("concept_code_stats"):
+        for concept_id, weight in stats2["concept_code_stats"].items():
+            stats1["concept_code_stats"][concept_id] += weight
+    return stats1
+
+
+def map_statistics(batch: Dict[str, Any], total_size, size=10_000) -> Dict[str, Any]:
     if "units" in batch:
         batch_value_units = batch["units"]
     else:
@@ -212,6 +238,7 @@ def map_statistics(batch: Dict[str, Any], size=10_000) -> Dict[str, Any]:

     numeric_stats_by_lab = collections.defaultdict(partial(ReservoirSampler, size=size))
     categorical_stats_by_lab = collections.defaultdict(int)
+    concept_code_stats = collections.defaultdict(int)
     for (
         concept_ids,
         number_as_values,
@@ -225,6 +252,7 @@ def map_statistics(batch: Dict[str, Any], size=10_000) -> Dict[str, Any]:
         batch["concept_value_masks"],
         batch_value_units,
     ):
+        unique_codes = set()
         for (
             concept_id,
             number_as_value,
@@ -243,10 +271,94 @@ def map_statistics(batch: Dict[str, Any], size=10_000) -> Dict[str, Any]:
                     numeric_stats_by_lab[(concept_id, unit)].add(number_as_value, 1)
                 if concept_as_value:
                     categorical_stats_by_lab[(concept_id, concept_as_value)] += 1
+            unique_codes.add(concept_id)
+
+        for code in unique_codes:
+            concept_code_stats[code] += 1 / total_size

     return {
         "numeric_stats_by_lab": numeric_stats_by_lab,
         "categorical_stats_by_lab": categorical_stats_by_lab,
+        "concept_code_stats": concept_code_stats,
+    }
+
+
+def compute_statistics(
+    dataset: Dataset, data_args: DataTrainingArguments
+) -> Dict[str, Any]:
+    total = get_dataset_len(dataset)
+    map_statistics_partial = partial(map_statistics, total_size=total, size=SAMPLE_SIZE)
+    if data_args.streaming:
+        first_example = next(iter(dataset))
+        parts = dataset.map(
+            partial(agg_helper, map_func=map_statistics_partial),
+            batched=True,
+            batch_size=data_args.preprocessing_batch_size,
+            remove_columns=first_example.keys(),
+        )
+    else:
+        parts = dataset.map(
+            partial(agg_helper, map_func=map_statistics_partial),
+            batched=True,
+            batch_size=data_args.preprocessing_batch_size,
+            remove_columns=dataset.column_names,
+            num_proc=data_args.preprocessing_num_workers,
+            keep_in_memory=True,
+            new_fingerprint="invalid",
+        )
+    current = None
+    for stat in tqdm(parts, desc="Aggregating the lab statistics"):
+        fixed_stat = pickle.loads(stat["data"])
+        if current is None:
+            current = fixed_stat
+        else:
+            current = agg_statistics(current, fixed_stat)
+
+    numeric_lab_stats = []
+    for (concept_id, unit), online_stats in current["numeric_stats_by_lab"].items():
+        if len(online_stats.samples) == 0:
+            continue
+        samples = truncated_sample(online_stats.samples, data_args.value_outlier_std)
+        bins = create_bins_with_spline(samples, NUM_OF_BINS, DEGREE_OF_FREEDOM)
+        if len(bins) > 0:
+            numeric_lab_stats.append(
+                {
+                    "concept_id": concept_id,
+                    "unit": unit,
+                    "mean": np.mean(samples),
+                    "std": np.std(samples),
+                    "count": len(online_stats.samples),
+                    "value_outlier_std": data_args.value_outlier_std,
+                    "bins": bins,
+                }
+            )
+
+    categorical_lab_stats = collections.defaultdict(int)
+    for (concept_id, value_as_concept), count in current[
+        "categorical_stats_by_lab"
+    ].items():
+        categorical_lab_stats[(concept_id, value_as_concept)] += count
+
+    concept_code_stats = collections.defaultdict(int)
+    for concept_id, count in current["concept_code_stats"].items():
+        concept_code_stats[concept_id] += count
+
+    code_weights = np.asarray(list(concept_code_stats.values())).clip(1e-8, 1 - 1e-8)
+    # Clip the values so we don't get errors when applying np.log
+    code_entropies = np.log(code_weights) * code_weights + (1 - code_weights) * np.log(
+        1 - code_weights
+    )
+
+    concept_code_entropies = {
+        k: v for k, v in zip(concept_code_stats.keys(), code_entropies)
+    }
+
+    return {
+        "numeric_lab_stats": numeric_lab_stats,
+        "categorical_lab_stats": categorical_lab_stats,
+        "concept_code_stats": concept_code_stats,
+        "concept_code_entropies": concept_code_entropies,
+        "total": total,
     }


@@ -348,15 +460,18 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         value_tokenizer: Tokenizer,
         att_tokenizer: Tokenizer,
         token_to_sub_time_token_mapping: Dict[str, List[str]],
+        concept_code_stats: Dict[str, Any],
         numeric_lab_stats: List[Dict[str, Any]],
         categorical_lab_stats: Dict[Tuple[str, str], int],
         concept_name_mapping: Dict[str, str],
         pretrained_concept_embedding_model: PretrainedEmbeddings = None,
+        motor_time_to_event_codes: Optional[List[str]] = None,
     ):
         self._tokenizer = tokenizer
         self._value_tokenizer = value_tokenizer
         self._att_tokenizer = att_tokenizer
         self._token_to_sub_time_token_mapping = token_to_sub_time_token_mapping
+        self._concept_code_stats = concept_code_stats
         self._numeric_lab_stats = numeric_lab_stats
         self._numeric_event_statistics = NumericEventStatistics(numeric_lab_stats)
         self._categorical_lab_stats = categorical_lab_stats
@@ -365,6 +480,12 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         self._padding_token_id = self._tokenizer.token_to_id(PAD_TOKEN)
         self._start_token_id = self._tokenizer.token_to_id(START_TOKEN)
         self._end_token_id = self._tokenizer.token_to_id(END_TOKEN)
+
+        # Backward compatible with the old tokenizer
+        self._linear_token_id = None
+        if LINEAR_PROB_TOKEN in self._tokenizer.get_vocab():
+            self._linear_token_id = self._tokenizer.token_to_id(LINEAR_PROB_TOKEN)
+
         self._numeric_concept_ids = (
             self._numeric_event_statistics.get_numeric_concept_ids()
         )
@@ -382,6 +503,12 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             for _ in self.get_vocab().keys()
             if self._pretrained_concept_embedding_model.is_concept_available(_)
         ]
+        self._motor_time_to_event_codes = (
+            motor_time_to_event_codes if motor_time_to_event_codes else []
+        )
+        self._motor_code_to_id_mapping = {
+            code: i for i, code in enumerate(sorted(self._motor_time_to_event_codes))
+        }

         super().__init__()

@@ -402,6 +529,10 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             ]
         )

+    @property
+    def motor_tte_vocab_size(self) -> int:
+        return len(self._motor_code_to_id_mapping)
+
     @property
     def vocab_size(self) -> int:
         return self._tokenizer.get_vocab_size()
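The MOTOR additions in the constructor above reduce to a sorted code list and a stable code-to-id mapping, whose size is what motor_tte_vocab_size reports. A minimal sketch with made-up concept codes (the real list is loaded from motor_time_to_event_codes.json):

# Illustrative codes only, not values shipped with the package.
motor_time_to_event_codes = ["4329847", "313217", "1112807"]

# Mirrors CehrGptTokenizer.__init__: ids follow lexicographic code order, so the
# mapping is reproducible regardless of the order the codes were collected in.
motor_code_to_id_mapping = {
    code: i for i, code in enumerate(sorted(motor_time_to_event_codes))
}
print(motor_code_to_id_mapping)      # {'1112807': 0, '313217': 1, '4329847': 2}
print(len(motor_code_to_id_mapping)) # what motor_tte_vocab_size returns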
@@ -442,10 +573,45 @@ class CehrGptTokenizer(PreTrainedTokenizer):
     def pad_token_id(self):
         return self._padding_token_id

+    @property
+    def oov_token_id(self):
+        return self._oov_token_id
+
     @property
     def pad_token(self):
         return PAD_TOKEN

+    @property
+    def linear_token_id(self) -> Optional[int]:
+        return self._linear_token_id
+
+    @property
+    def linear_token(self) -> Optional[str]:
+        if LINEAR_PROB_TOKEN in self._tokenizer.get_vocab():
+            return LINEAR_PROB_TOKEN
+        return None
+
+    def vs_token_id(self):
+        # We used VS for the historical data, currently, we use the new [VS] for the newer data
+        # so we need to check both cases.
+        if "VS" in self._tokenizer.get_vocab():
+            return self._convert_token_to_id("VS")
+        elif "[VS]" in self._tokenizer.get_vocab():
+            return self._convert_token_to_id("[VS]")
+        else:
+            raise RuntimeError("The tokenizer does not contain either VS or [VS]")
+
+    @property
+    def ve_token_id(self):
+        # We used VE for the historical data, currently, we use the new [VE] for the newer data
+        # so we need to check both cases.
+        if "VE" in self._tokenizer.get_vocab():
+            return self._convert_token_to_id("VE")
+        elif "[VE]" in self._tokenizer.get_vocab():
+            return self._convert_token_to_id("[VE]")
+        else:
+            raise RuntimeError("The tokenizer does not contain either VE or [VE]")
+
     @property
     def numeric_concept_ids(self):
         return self._numeric_concept_ids
@@ -456,8 +622,14 @@ class CehrGptTokenizer(PreTrainedTokenizer):

     @property
     def lab_token_ids(self):
-        reserved_tokens = [START_TOKEN, PAD_TOKEN, END_TOKEN, OUT_OF_VOCABULARY_TOKEN]
-        return self.encode(
+        reserved_tokens = [
+            START_TOKEN,
+            PAD_TOKEN,
+            END_TOKEN,
+            OUT_OF_VOCABULARY_TOKEN,
+            LINEAR_PROB_TOKEN,
+        ]
+        lab_token_ids = self.encode(
             [
                 concept_id
                 for concept_id in self._numeric_concept_ids
@@ -465,6 +637,9 @@ class CehrGptTokenizer(PreTrainedTokenizer):
                 if concept_id not in reserved_tokens
             ]
         )
+        return list(
+            filter(lambda token_id: token_id != self._oov_token_id, lab_token_ids)
+        )

     @property
     def token_to_time_token_mapping(self) -> Dict[int, List[int]]:
@@ -483,6 +658,19 @@ class CehrGptTokenizer(PreTrainedTokenizer):
     def pretrained_concept_embedding_model(self):
         return self._pretrained_concept_embedding_model

+    def get_motor_token_id(self, concept_id: str) -> int:
+        if concept_id not in self._motor_code_to_id_mapping:
+            raise RuntimeError(f"Invalid motor concept id: {concept_id}")
+        return self._motor_code_to_id_mapping[concept_id]
+
+    def is_motor_time_to_event_code(self, future_concept_id: str) -> bool:
+        if (
+            self._motor_time_to_event_codes
+            and future_concept_id in self._motor_time_to_event_codes
+        ):
+            return True
+        return False
+
     def get_vocab(self) -> Dict[str, int]:
         return self._tokenizer.get_vocab()

@@ -511,6 +699,18 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             concept_value_token_ids, skip_special_tokens=skip_special_tokens
         ).split(" ")

+    def add_token(self, tokens: Union[str, List[str]]) -> None:
+        if isinstance(tokens, str):
+            tokens = [tokens]
+        vocab = self.get_vocab()
+        self._tokenizer.add_tokens(
+            [
+                AddedToken(token, single_word=True, normalized=False)
+                for token in tokens
+                if token not in vocab
+            ]
+        )
+
     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""
         token_id = self._tokenizer.token_to_id(token)
@@ -572,6 +772,9 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         ) as f:
             json.dump(self._token_to_sub_time_token_mapping, f)

+        with open(os.path.join(save_directory, CONCEPT_STATS_FILE_NAME), "w") as f:
+            json.dump(self._concept_code_stats, f)
+
         with open(os.path.join(save_directory, LAB_STATS_FILE_NAME), "wb") as f:
             lab_stats = {
                 "numeric_lab_stats": self._numeric_lab_stats,
@@ -582,6 +785,11 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         with open(os.path.join(save_directory, CONCEPT_MAPPING_FILE_NAME), "w") as f:
             json.dump(self._concept_name_mapping, f)

+        with open(
+            os.path.join(save_directory, MOTOR_TIME_TO_EVENT_CODES_FILE_NAME), "w"
+        ) as f:
+            json.dump(self._motor_time_to_event_codes, f)
+
         self._pretrained_concept_embedding_model.save(save_directory)

         if push_to_hub:
@@ -691,6 +899,22 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             return None
         concept_name_mapping = load_json_file(concept_name_mapping_file)

+        # Load the concept_code_stats json file
+        concept_code_stats_mapping_file = transformers.utils.hub.cached_file(
+            pretrained_model_name_or_path, CONCEPT_STATS_FILE_NAME, **kwargs
+        )
+        if not concept_code_stats_mapping_file:
+            return None
+        concept_code_stats = load_json_file(concept_code_stats_mapping_file)
+
+        # Load the MOTOR time to event codes file
+        motor_time_to_event_codes_file = transformers.utils.hub.cached_file(
+            pretrained_model_name_or_path, MOTOR_TIME_TO_EVENT_CODES_FILE_NAME, **kwargs
+        )
+        if not motor_time_to_event_codes_file:
+            return None
+        motor_time_to_event_codes = load_json_file(motor_time_to_event_codes_file)
+
         pretrained_embedding_model = PretrainedEmbeddings(pretrained_model_name_or_path)

         return CehrGptTokenizer(
@@ -698,10 +922,12 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             value_tokenizer,
             att_tokenizer,
             token_to_sub_time_token_mapping,
+            concept_code_stats,
             lab_stats["numeric_lab_stats"],
             lab_stats["categorical_lab_stats"],
             concept_name_mapping,
             pretrained_embedding_model,
+            motor_time_to_event_codes,
         )

     @classmethod
@@ -724,6 +950,9 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         concept_name_mapping: Dict[str, str],
         data_args: DataTrainingArguments,
         pretrained_concept_embedding_model: PretrainedEmbeddings = None,
+        num_motor_tasks: Optional[int] = None,
+        apply_entropy_filter: bool = False,
+        min_prevalence: float = 1 / 1000,
     ):
         if not isinstance(cehrgpt_tokenizer, CehrGptTokenizer):
             raise ValueError(
@@ -736,6 +965,9 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             dataset=dataset,
             concept_name_mapping=concept_name_mapping,
             data_args=data_args,
+            num_motor_tasks=num_motor_tasks,
+            apply_entropy_filter=apply_entropy_filter,
+            min_prevalence=min_prevalence,
         )

         new_tokens = set(new_tokenizer.get_vocab().keys()) - set(
@@ -753,6 +985,7 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         new_numeric_lab_stats = new_tokenizer._numeric_lab_stats
         new_categorical_lab_stats = new_tokenizer._categorical_lab_stats
         new_concept_name_mapping = new_tokenizer._concept_name_mapping
+        new_motor_time_to_event_codes = new_tokenizer._motor_time_to_event_codes

         # Add new tokens to the existing tokenizer
         cehrgpt_tokenizer_copy._tokenizer.add_tokens(
@@ -801,15 +1034,31 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             if token not in cehrgpt_tokenizer_copy._concept_name_mapping:
                 cehrgpt_tokenizer_copy._concept_name_mapping[token] = concept_name

+        # Merge motor_time_to_event_codes
+        if (
+            new_motor_time_to_event_codes
+            and cehrgpt_tokenizer_copy._motor_time_to_event_codes
+        ):
+            for motor_time_to_event_code in new_motor_time_to_event_codes:
+                if (
+                    motor_time_to_event_code
+                    not in cehrgpt_tokenizer_copy._motor_time_to_event_codes
+                ):
+                    cehrgpt_tokenizer_copy._motor_time_to_event_codes.append(
+                        motor_time_to_event_code
+                    )
+
         return CehrGptTokenizer(
             tokenizer=cehrgpt_tokenizer_copy._tokenizer,
             value_tokenizer=cehrgpt_tokenizer_copy._value_tokenizer,
             att_tokenizer=cehrgpt_tokenizer_copy._att_tokenizer,
             token_to_sub_time_token_mapping=cehrgpt_tokenizer_copy._token_to_sub_time_token_mapping,
+            concept_code_stats=cehrgpt_tokenizer_copy._concept_code_stats,
             numeric_lab_stats=cehrgpt_tokenizer_copy._numeric_lab_stats,
             categorical_lab_stats=cehrgpt_tokenizer_copy._categorical_lab_stats,
             concept_name_mapping=cehrgpt_tokenizer_copy._concept_name_mapping,
             pretrained_concept_embedding_model=pretrained_concept_embedding_model,
+            motor_time_to_event_codes=cehrgpt_tokenizer_copy._motor_time_to_event_codes,
         )

     @classmethod
@@ -879,6 +1128,10 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         concept_name_mapping: Dict[str, str],
         data_args: DataTrainingArguments,
         pretrained_concept_embedding_model: PretrainedEmbeddings = None,
+        allowed_motor_codes: Optional[List[int]] = None,
+        num_motor_tasks: Optional[int] = None,
+        apply_entropy_filter: bool = False,
+        min_prevalence: float = 1 / 1000,
     ):
         """
         Train a huggingface word level tokenizer.
@@ -890,13 +1143,49 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         if isinstance(dataset, DatasetDict):
             dataset = dataset["train"]

+        LOG.info("Calculating data statistics")
+        cehrgpt_data_statistics = compute_statistics(dataset, data_args)
+        cehrgpt_data_statistics["total"]
+        numeric_lab_stats = cehrgpt_data_statistics["numeric_lab_stats"]
+        categorical_lab_stats = cehrgpt_data_statistics["categorical_lab_stats"]
+        concept_code_stats = cehrgpt_data_statistics["concept_code_stats"]
+        concept_code_entropies = cehrgpt_data_statistics["concept_code_entropies"]
+
+        if apply_entropy_filter:
+            min_prevalence = max(1e-8, min_prevalence)
+            min_entropy = (
+                np.log(1 - min_prevalence) * (1 - min_prevalence)
+                + np.log(min_prevalence) * min_prevalence
+            )
+            qualified_codes = [
+                k
+                for k, v in concept_code_entropies.items()
+                if v <= min_entropy
+                or not is_clinical_event(k, data_args.is_data_in_meds)
+            ]
+        else:
+            qualified_codes = [
+                k
+                for k, v in concept_code_stats.items()
+                if min_prevalence <= v
+                or not is_clinical_event(k, data_args.is_data_in_meds)
+            ]
+
+        # Create the tokenizer now
         LOG.info("Training the tokenizer for concepts")
         concept_tokenizer = cls.train_concept_tokenizer(
             dataset,
             feature_name="concept_ids",
-            special_tokens=[PAD_TOKEN, OUT_OF_VOCABULARY_TOKEN, START_TOKEN, END_TOKEN],
+            special_tokens=[
+                PAD_TOKEN,
+                OUT_OF_VOCABULARY_TOKEN,
+                START_TOKEN,
+                END_TOKEN,
+                LINEAR_PROB_TOKEN,
+            ],
             unk_token=OUT_OF_VOCABULARY_TOKEN,
             data_args=data_args,
+            qualified_codes=qualified_codes,
         )
         concept_value_column = "concept_as_values"
         for row in dataset:
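The filter above turns min_prevalence into the entropy cut-off min_entropy and keeps a clinical code only when its score from compute_statistics (p * log(p) + (1 - p) * log(1 - p), with p the fraction of sequences containing the code) is at or below that cut-off; roughly, codes rarer than min_prevalence fall above the cut-off and are dropped unless they are not clinical events. A worked check of the default threshold, using illustrative prevalences rather than real data:

import numpy as np

min_prevalence = 1 / 1000
min_entropy = (
    np.log(1 - min_prevalence) * (1 - min_prevalence)
    + np.log(min_prevalence) * min_prevalence
)
print(round(float(min_entropy), 5))  # -0.00791

def entropy_score(p: float) -> float:
    # Same per-code statistic that compute_statistics stores in concept_code_entropies.
    return float(p * np.log(p) + (1 - p) * np.log(1 - p))

print(entropy_score(0.05) <= min_entropy)     # True  -> kept
print(entropy_score(0.00005) <= min_entropy)  # False -> dropped (unless not a clinical event)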
@@ -919,65 +1208,11 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             ]
         )

-        map_statistics_partial = partial(map_statistics, size=SAMPLE_SIZE)
-
-        if data_args.streaming:
-            first_example = next(iter(dataset))
-            parts = dataset.map(
-                partial(agg_helper, map_func=map_statistics_partial),
-                batched=True,
-                batch_size=data_args.preprocessing_batch_size,
-                remove_columns=first_example.keys(),
-            )
-        else:
-            parts = dataset.map(
-                partial(agg_helper, map_func=map_statistics_partial),
-                batched=True,
-                batch_size=data_args.preprocessing_batch_size,
-                remove_columns=dataset.column_names,
-                num_proc=data_args.preprocessing_num_workers,
-                keep_in_memory=True,
-                new_fingerprint="invalid",
-            )
-        current = None
-        for stat in tqdm(parts, desc="Aggregating the lab statistics"):
-            fixed_stat = pickle.loads(stat["data"])
-            if current is None:
-                current = fixed_stat
-            else:
-                current = agg_statistics(current, fixed_stat)
-
-        numeric_lab_stats = []
-        for (concept_id, unit), online_stats in current["numeric_stats_by_lab"].items():
-            if len(online_stats.samples) == 0:
-                continue
-            samples = truncated_sample(
-                online_stats.samples, data_args.value_outlier_std
-            )
-            bins = create_bins_with_spline(samples, NUM_OF_BINS, DEGREE_OF_FREEDOM)
-            if len(bins) > 0:
-                numeric_lab_stats.append(
-                    {
-                        "concept_id": concept_id,
-                        "unit": unit,
-                        "mean": np.mean(samples),
-                        "std": np.std(samples),
-                        "count": len(online_stats.samples),
-                        "value_outlier_std": data_args.value_outlier_std,
-                        "bins": bins,
-                    }
-                )
-
-        categorical_lab_stats = collections.defaultdict(int)
-        for (concept_id, value_as_concept), count in current[
-            "categorical_stats_by_lab"
-        ].items():
-            categorical_lab_stats[(concept_id, value_as_concept)] += count
-
         # We will train a tokenizer specifically for time intervals
         sub_time_token_data = []
         token_to_sub_time_token_mapping = collections.defaultdict(list)
-        for token, token_id in concept_tokenizer.get_vocab().items():
+        vocab = concept_tokenizer.get_vocab()
+        for token, token_id in vocab.items():
             if is_att_token(token):
                 time_interval = extract_time_interval_in_days(token)
                 time_tuple = convert_time_interval_to_time_tuple(
@@ -998,15 +1233,43 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         )
         att_tokenizer.train_from_iterator(sub_time_token_data, trainer=att_trainer)

+        # Prune concept_name_mapping
+        concept_name_mapping = {
+            concept_id: concept_name_mapping[concept_id]
+            for concept_id in vocab.keys()
+            if concept_id in concept_name_mapping
+        }
+
+        motor_time_to_event_codes = None
+        if num_motor_tasks and allowed_motor_codes:
+            motor_time_to_event_codes = []
+            for concept_id, _ in sorted(
+                concept_code_entropies.items(), key=lambda t: t[1]
+            ):
+                if (
+                    concept_id not in allowed_motor_codes
+                    or concept_id not in qualified_codes
+                ):
+                    continue
+                if len(motor_time_to_event_codes) < num_motor_tasks:
+                    motor_time_to_event_codes.append(concept_id)
+                else:
+                    break
+            LOG.info(
+                f"{len(motor_time_to_event_codes)} number of tasks have been added as MOTOR tasks"
+            )
+
         return CehrGptTokenizer(
             concept_tokenizer,
             value_tokenizer,
             att_tokenizer,
             token_to_sub_time_token_mapping,
+            concept_code_stats,
             numeric_lab_stats,
             categorical_lab_stats,
             concept_name_mapping,
             pretrained_concept_embedding_model,
+            motor_time_to_event_codes,
         )

     @classmethod
@@ -1017,6 +1280,7 @@ class CehrGptTokenizer(PreTrainedTokenizer):
         special_tokens: List[str],
         unk_token,
         data_args,
+        qualified_codes: Optional[List[str]] = None,
     ):
         # Use the Fast Tokenizer from the Huggingface tokenizers Rust implementation.
         # https://github.com/huggingface/tokenizers
@@ -1029,7 +1293,9 @@ class CehrGptTokenizer(PreTrainedTokenizer):
             show_progress=True,
         )
         batch_concat_concepts_partial_func = partial(
-            cls.batch_concat_concepts, feature_name=feature_name
+            cls.batch_concat_concepts,
+            feature_name=feature_name,
+            qualified_codes=qualified_codes,
         )
         if data_args.streaming:
             concatenated_features = dataset.map(
@@ -1069,13 +1335,30 @@ class CehrGptTokenizer(PreTrainedTokenizer):

     @classmethod
     def batch_concat_concepts(
-        cls, records: Dict[str, List], feature_name
+        cls,
+        records: Dict[str, List],
+        feature_name: str,
+        qualified_codes: Optional[List[str]] = None,
     ) -> Dict[str, List]:
+        def filter_token(t: str) -> bool:
+            """
+            If the token is None or not string, return False.
+
+            When qualified_codes is provided, t must be in
+            qualified_codes to be valid, otherwise the tokens are always valid
+
+            :param t:
+            :return:
+            """
+            if t is None or not isinstance(t, str):
+                return False
+            if qualified_codes:
+                return t in qualified_codes
+            return True
+
         return {
             feature_name: [
-                " ".join(
-                    [token for token in tokens if token and isinstance(token, str)]
-                )
+                " ".join(filter(filter_token, tokens))
                 for tokens in records[feature_name]
             ]
         }
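For reference, a small usage sketch of the reworked batch_concat_concepts classmethod shown in the last hunk; the record contents below are made up, and the whitelist behaviour follows the new filter_token helper:

from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer

records = {"concept_ids": [["[VS]", "320128", "junk_code", None, "[VE]"]]}

# Without qualified_codes every non-None string token is kept, as in 0.1.0.
print(CehrGptTokenizer.batch_concat_concepts(records, feature_name="concept_ids"))
# {'concept_ids': ['[VS] 320128 junk_code [VE]']}

# With qualified_codes, tokens outside the whitelist are dropped before the
# word-level concept tokenizer is trained on the concatenated strings.
print(
    CehrGptTokenizer.batch_concat_concepts(
        records,
        feature_name="concept_ids",
        qualified_codes=["[VS]", "[VE]", "320128"],
    )
)
# {'concept_ids': ['[VS] 320128 [VE]']}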