cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. cehrgpt/analysis/irregularity.py +36 -0
  2. cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
  3. cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
  4. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
  5. cehrgpt/data/sample_packing_sampler.py +181 -0
  6. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  7. cehrgpt/generation/omop_converter_batch.py +32 -2
  8. cehrgpt/gpt_utils.py +20 -2
  9. cehrgpt/models/config.py +35 -0
  10. cehrgpt/models/hf_cehrgpt.py +470 -106
  11. cehrgpt/models/hf_modeling_outputs.py +1 -0
  12. cehrgpt/models/special_tokens.py +1 -0
  13. cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
  14. cehrgpt/runners/data_utils.py +358 -0
  15. cehrgpt/runners/gpt_runner_util.py +0 -10
  16. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
  17. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
  18. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
  19. cehrgpt/runners/hyperparameter_search_util.py +10 -8
  20. cehrgpt/runners/sample_packing_trainer.py +185 -0
  21. cehrgpt/simulations/generate_plots.py +95 -0
  22. cehrgpt/simulations/run_simulation.sh +24 -0
  23. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  24. cehrgpt/simulations/time_token_simulation.py +177 -0
  25. cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  26. cehrgpt/time_to_event/time_to_event_model.py +2 -13
  27. cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  28. cehrgpt/tools/linear_prob/__init__.py +0 -0
  29. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
  30. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  31. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
  32. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
  33. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
  34. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  35. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  36. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  37. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  38. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  39. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  40. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  41. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  42. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  43. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
  44. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
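The hunks below appear to come from cehrgpt/models/tokenization_hf_cehrgpt.py (file 13 in the list above). In short: tokenizer training now computes per-concept prevalence statistics with optional entropy-based vocabulary filtering, a LINEAR_PROB_TOKEN special token and an add_token helper are added, and MOTOR time-to-event task codes are selected, saved, and reloaded alongside the existing lab statistics.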
@@ -10,13 +10,9 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
  import numpy as np
  import scipy.stats as stats
  import transformers
- from cehrbert.models.hf_models.tokenization_utils import (
-     agg_helper,
-     agg_statistics,
-     load_json_file,
- )
+ from cehrbert.models.hf_models.tokenization_utils import agg_helper, load_json_file
  from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
- from datasets import Dataset, DatasetDict
+ from datasets import Dataset, DatasetDict, IterableDataset
  from femr.stat_utils import OnlineStatistics, ReservoirSampler
  from scipy.interpolate import UnivariateSpline
  from tokenizers import AddedToken, Tokenizer
@@ -25,16 +21,19 @@ from tokenizers.pre_tokenizers import WhitespaceSplit
  from tokenizers.trainers import WordLevelTrainer
  from tqdm import tqdm
  from transformers import PreTrainedTokenizer
+ from transformers.utils import logging

  from cehrgpt.gpt_utils import (
      convert_time_interval_to_time_tuple,
      extract_time_interval_in_days,
      is_att_token,
+     is_clinical_event,
      is_inpatient_att_token,
  )
  from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
  from cehrgpt.models.special_tokens import (
      END_TOKEN,
+     LINEAR_PROB_TOKEN,
      OUT_OF_VOCABULARY_TOKEN,
      PAD_TOKEN,
      START_TOKEN,
@@ -52,7 +51,11 @@ TIME_TOKENIZER_FILE_NAME = "cehrgpt_time_tokenizer.json"
  TOKEN_TO_SUB_TIME_TOKEN_MAPPING_FILE_NAME = "token_to_sub_time_token_mapping.json"
  LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.pickle"
  LEGACY_LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.json"
+ CONCEPT_STATS_FILE_NAME = "cehrgpt_concept_stats.json"
  CONCEPT_MAPPING_FILE_NAME = "concept_name_mapping.json"
+ MOTOR_TIME_TO_EVENT_CODES_FILE_NAME = "motor_time_to_event_codes.json"
+
+ LOG = logging.get_logger("transformers")


  def truncated_sample(sample, standard_deviation):
@@ -71,6 +74,16 @@ def create_value_bin(bin_index: int) -> str:
      return "BIN:" + str(bin_index)


+ def get_dataset_len(dataset: Union[Dataset, IterableDataset]) -> int:
+     if isinstance(dataset, Dataset):
+         return len(dataset)
+     elif isinstance(dataset, IterableDataset):
+         return sum([1 for _ in dataset])
+     raise RuntimeError(
+         "The dataset must be one of the two types (Dataset, IterableDataset)"
+     )
+
+
  def create_sample_from_bins(bins, sample_size: int = 10_000) -> List[float]:
      """
      Generates a specified number of samples from a list of bins, each containing a fitted spline.
@@ -186,7 +199,22 @@ def create_bins_with_spline(samples, num_bins, d_freedom=3) -> List[Dict[str, An
      return bins


- def map_statistics(batch: Dict[str, Any], size=10_000) -> Dict[str, Any]:
+ def agg_statistics(stats1, stats2):
+     if stats1.get("numeric_stats_by_lab"):
+         for k, v in stats2["numeric_stats_by_lab"].items():
+             stats1["numeric_stats_by_lab"][k].combine(v)
+     if stats1.get("categorical_stats_by_lab"):
+         for (concept_id, concept_as_value), count in stats2[
+             "categorical_stats_by_lab"
+         ].items():
+             stats1["categorical_stats_by_lab"][(concept_id, concept_as_value)] += count
+     if stats1.get("concept_code_stats"):
+         for concept_id, weight in stats2["concept_code_stats"].items():
+             stats1["concept_code_stats"][concept_id] += weight
+     return stats1
+
+
+ def map_statistics(batch: Dict[str, Any], total_size, size=10_000) -> Dict[str, Any]:
      if "units" in batch:
          batch_value_units = batch["units"]
      else:
@@ -210,6 +238,7 @@ def map_statistics(batch: Dict[str, Any], size=10_000) -> Dict[str, Any]:

      numeric_stats_by_lab = collections.defaultdict(partial(ReservoirSampler, size=size))
      categorical_stats_by_lab = collections.defaultdict(int)
+     concept_code_stats = collections.defaultdict(int)
      for (
          concept_ids,
          number_as_values,
@@ -223,6 +252,7 @@ def map_statistics(batch: Dict[str, Any], size=10_000) -> Dict[str, Any]:
          batch["concept_value_masks"],
          batch_value_units,
      ):
+         unique_codes = set()
          for (
              concept_id,
              number_as_value,
@@ -241,10 +271,94 @@ def map_statistics(batch: Dict[str, Any], size=10_000) -> Dict[str, Any]:
                      numeric_stats_by_lab[(concept_id, unit)].add(number_as_value, 1)
                  if concept_as_value:
                      categorical_stats_by_lab[(concept_id, concept_as_value)] += 1
+             unique_codes.add(concept_id)
+
+         for code in unique_codes:
+             concept_code_stats[code] += 1 / total_size

      return {
          "numeric_stats_by_lab": numeric_stats_by_lab,
          "categorical_stats_by_lab": categorical_stats_by_lab,
+         "concept_code_stats": concept_code_stats,
+     }
+
+
+ def compute_statistics(
+     dataset: Dataset, data_args: DataTrainingArguments
+ ) -> Dict[str, Any]:
+     total = get_dataset_len(dataset)
+     map_statistics_partial = partial(map_statistics, total_size=total, size=SAMPLE_SIZE)
+     if data_args.streaming:
+         first_example = next(iter(dataset))
+         parts = dataset.map(
+             partial(agg_helper, map_func=map_statistics_partial),
+             batched=True,
+             batch_size=data_args.preprocessing_batch_size,
+             remove_columns=first_example.keys(),
+         )
+     else:
+         parts = dataset.map(
+             partial(agg_helper, map_func=map_statistics_partial),
+             batched=True,
+             batch_size=data_args.preprocessing_batch_size,
+             remove_columns=dataset.column_names,
+             num_proc=data_args.preprocessing_num_workers,
+             keep_in_memory=True,
+             new_fingerprint="invalid",
+         )
+     current = None
+     for stat in tqdm(parts, desc="Aggregating the lab statistics"):
+         fixed_stat = pickle.loads(stat["data"])
+         if current is None:
+             current = fixed_stat
+         else:
+             current = agg_statistics(current, fixed_stat)
+
+     numeric_lab_stats = []
+     for (concept_id, unit), online_stats in current["numeric_stats_by_lab"].items():
+         if len(online_stats.samples) == 0:
+             continue
+         samples = truncated_sample(online_stats.samples, data_args.value_outlier_std)
+         bins = create_bins_with_spline(samples, NUM_OF_BINS, DEGREE_OF_FREEDOM)
+         if len(bins) > 0:
+             numeric_lab_stats.append(
+                 {
+                     "concept_id": concept_id,
+                     "unit": unit,
+                     "mean": np.mean(samples),
+                     "std": np.std(samples),
+                     "count": len(online_stats.samples),
+                     "value_outlier_std": data_args.value_outlier_std,
+                     "bins": bins,
+                 }
+             )
+
+     categorical_lab_stats = collections.defaultdict(int)
+     for (concept_id, value_as_concept), count in current[
+         "categorical_stats_by_lab"
+     ].items():
+         categorical_lab_stats[(concept_id, value_as_concept)] += count
+
+     concept_code_stats = collections.defaultdict(int)
+     for concept_id, count in current["concept_code_stats"].items():
+         concept_code_stats[concept_id] += count
+
+     code_weights = np.asarray(list(concept_code_stats.values())).clip(1e-8, 1 - 1e-8)
+     # Clip the values so we don't get errors when applying np.log
+     code_entropies = np.log(code_weights) * code_weights + (1 - code_weights) * np.log(
+         1 - code_weights
+     )
+
+     concept_code_entropies = {
+         k: v for k, v in zip(concept_code_stats.keys(), code_entropies)
+     }
+
+     return {
+         "numeric_lab_stats": numeric_lab_stats,
+         "categorical_lab_stats": categorical_lab_stats,
+         "concept_code_stats": concept_code_stats,
+         "concept_code_entropies": concept_code_entropies,
+         "total": total,
      }

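For orientation, an interpretation of the hunk above rather than text from the package: map_statistics now also records, per concept code, the fraction of sequences that contain it (each sequence contributes 1/total_size once per unique code), and compute_statistics converts those prevalences into negative Bernoulli entropies (the code_weights/code_entropies block) used later for vocabulary filtering and MOTOR task selection. A minimal sketch of the prevalence bookkeeping, with a made-up three-sequence dataset:

    sequences = [["c1", "c2"], ["c1"], ["c1", "c3"]]
    total_size = len(sequences)
    concept_code_stats = {}
    for seq in sequences:
        for code in set(seq):
            concept_code_stats[code] = concept_code_stats.get(code, 0.0) + 1 / total_size
    # concept_code_stats is roughly {"c1": 1.0, "c2": 0.33, "c3": 0.33}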
@@ -346,15 +460,18 @@ class CehrGptTokenizer(PreTrainedTokenizer):
          value_tokenizer: Tokenizer,
          att_tokenizer: Tokenizer,
          token_to_sub_time_token_mapping: Dict[str, List[str]],
+         concept_code_stats: Dict[str, Any],
          numeric_lab_stats: List[Dict[str, Any]],
          categorical_lab_stats: Dict[Tuple[str, str], int],
          concept_name_mapping: Dict[str, str],
          pretrained_concept_embedding_model: PretrainedEmbeddings = None,
+         motor_time_to_event_codes: Optional[List[str]] = None,
      ):
          self._tokenizer = tokenizer
          self._value_tokenizer = value_tokenizer
          self._att_tokenizer = att_tokenizer
          self._token_to_sub_time_token_mapping = token_to_sub_time_token_mapping
+         self._concept_code_stats = concept_code_stats
          self._numeric_lab_stats = numeric_lab_stats
          self._numeric_event_statistics = NumericEventStatistics(numeric_lab_stats)
          self._categorical_lab_stats = categorical_lab_stats
@@ -363,6 +480,12 @@ class CehrGptTokenizer(PreTrainedTokenizer):
          self._padding_token_id = self._tokenizer.token_to_id(PAD_TOKEN)
          self._start_token_id = self._tokenizer.token_to_id(START_TOKEN)
          self._end_token_id = self._tokenizer.token_to_id(END_TOKEN)
+
+         # Backward compatible with the old tokenizer
+         self._linear_token_id = None
+         if LINEAR_PROB_TOKEN in self._tokenizer.get_vocab():
+             self._linear_token_id = self._tokenizer.token_to_id(LINEAR_PROB_TOKEN)
+
          self._numeric_concept_ids = (
              self._numeric_event_statistics.get_numeric_concept_ids()
          )
@@ -380,6 +503,12 @@ class CehrGptTokenizer(PreTrainedTokenizer):
              for _ in self.get_vocab().keys()
              if self._pretrained_concept_embedding_model.is_concept_available(_)
          ]
+         self._motor_time_to_event_codes = (
+             motor_time_to_event_codes if motor_time_to_event_codes else []
+         )
+         self._motor_code_to_id_mapping = {
+             code: i for i, code in enumerate(sorted(self._motor_time_to_event_codes))
+         }

          super().__init__()

@@ -400,6 +529,10 @@ class CehrGptTokenizer(PreTrainedTokenizer):
              ]
          )

+     @property
+     def motor_tte_vocab_size(self) -> int:
+         return len(self._motor_code_to_id_mapping)
+
      @property
      def vocab_size(self) -> int:
          return self._tokenizer.get_vocab_size()
@@ -440,10 +573,45 @@ class CehrGptTokenizer(PreTrainedTokenizer):
      def pad_token_id(self):
          return self._padding_token_id

+     @property
+     def oov_token_id(self):
+         return self._oov_token_id
+
      @property
      def pad_token(self):
          return PAD_TOKEN

+     @property
+     def linear_token_id(self) -> Optional[int]:
+         return self._linear_token_id
+
+     @property
+     def linear_token(self) -> Optional[str]:
+         if LINEAR_PROB_TOKEN in self._tokenizer.get_vocab():
+             return LINEAR_PROB_TOKEN
+         return None
+
+     def vs_token_id(self):
+         # We used VS for the historical data, currently, we use the new [VS] for the newer data
+         # so we need to check both cases.
+         if "VS" in self._tokenizer.get_vocab():
+             return self._convert_token_to_id("VS")
+         elif "[VS]" in self._tokenizer.get_vocab():
+             return self._convert_token_to_id("[VS]")
+         else:
+             raise RuntimeError("The tokenizer does not contain either VS or [VS]")
+
+     @property
+     def ve_token_id(self):
+         # We used VE for the historical data, currently, we use the new [VE] for the newer data
+         # so we need to check both cases.
+         if "VE" in self._tokenizer.get_vocab():
+             return self._convert_token_to_id("VE")
+         elif "[VE]" in self._tokenizer.get_vocab():
+             return self._convert_token_to_id("[VE]")
+         else:
+             raise RuntimeError("The tokenizer does not contain either VE or [VE]")
+
      @property
      def numeric_concept_ids(self):
          return self._numeric_concept_ids
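One asymmetry worth flagging in the hunk above: vs_token_id is added without a @property decorator while ve_token_id has one, so callers access the two differently. A minimal sketch, assuming tokenizer is a CehrGptTokenizer instance:

    vs_id = tokenizer.vs_token_id()  # plain method, parentheses required
    ve_id = tokenizer.ve_token_id    # property, read as an attribute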
@@ -454,8 +622,14 @@ class CehrGptTokenizer(PreTrainedTokenizer):

      @property
      def lab_token_ids(self):
-         reserved_tokens = [START_TOKEN, PAD_TOKEN, END_TOKEN, OUT_OF_VOCABULARY_TOKEN]
-         return self.encode(
+         reserved_tokens = [
+             START_TOKEN,
+             PAD_TOKEN,
+             END_TOKEN,
+             OUT_OF_VOCABULARY_TOKEN,
+             LINEAR_PROB_TOKEN,
+         ]
+         lab_token_ids = self.encode(
              [
                  concept_id
                  for concept_id in self._numeric_concept_ids
463
637
  if concept_id not in reserved_tokens
464
638
  ]
465
639
  )
640
+ return list(
641
+ filter(lambda token_id: token_id != self._oov_token_id, lab_token_ids)
642
+ )
466
643
 
467
644
  @property
468
645
  def token_to_time_token_mapping(self) -> Dict[int, List[int]]:
@@ -481,6 +658,19 @@ class CehrGptTokenizer(PreTrainedTokenizer):
481
658
  def pretrained_concept_embedding_model(self):
482
659
  return self._pretrained_concept_embedding_model
483
660
 
661
+ def get_motor_token_id(self, concept_id: str) -> int:
662
+ if concept_id not in concept_id:
663
+ raise RuntimeError(f"Invalid motor concept id: {concept_id}")
664
+ return self._motor_code_to_id_mapping[concept_id]
665
+
666
+ def is_motor_time_to_event_code(self, future_concept_id: str) -> bool:
667
+ if (
668
+ self._motor_time_to_event_codes
669
+ and future_concept_id in self._motor_time_to_event_codes
670
+ ):
671
+ return True
672
+ return False
673
+
484
674
  def get_vocab(self) -> Dict[str, int]:
485
675
  return self._tokenizer.get_vocab()
486
676
 
@@ -509,6 +699,18 @@ class CehrGptTokenizer(PreTrainedTokenizer):
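The two helpers above give the MOTOR time-to-event codes a dense task index that is independent of the main concept vocabulary. A minimal usage sketch, assuming tokenizer is a CehrGptTokenizer trained with MOTOR codes and that the concept id shown is purely illustrative:

    future_code = "320128"  # hypothetical concept id observed later in the patient sequence
    if tokenizer.is_motor_time_to_event_code(future_code):
        task_index = tokenizer.get_motor_token_id(future_code)
        # task_index falls in [0, tokenizer.motor_tte_vocab_size)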
509
699
  concept_value_token_ids, skip_special_tokens=skip_special_tokens
510
700
  ).split(" ")
511
701
 
702
+ def add_token(self, tokens: Union[str, List[str]]) -> None:
703
+ if isinstance(tokens, str):
704
+ tokens = [tokens]
705
+ vocab = self.get_vocab()
706
+ self._tokenizer.add_tokens(
707
+ [
708
+ AddedToken(token, single_word=True, normalized=False)
709
+ for token in tokens
710
+ if token not in vocab
711
+ ]
712
+ )
713
+
512
714
  def _convert_token_to_id(self, token):
513
715
  """Converts a token (str) in an id using the vocab."""
514
716
  token_id = self._tokenizer.token_to_id(token)
@@ -570,6 +772,9 @@ class CehrGptTokenizer(PreTrainedTokenizer):
          ) as f:
              json.dump(self._token_to_sub_time_token_mapping, f)

+         with open(os.path.join(save_directory, CONCEPT_STATS_FILE_NAME), "w") as f:
+             json.dump(self._concept_code_stats, f)
+
          with open(os.path.join(save_directory, LAB_STATS_FILE_NAME), "wb") as f:
              lab_stats = {
                  "numeric_lab_stats": self._numeric_lab_stats,
@@ -580,6 +785,11 @@ class CehrGptTokenizer(PreTrainedTokenizer):
          with open(os.path.join(save_directory, CONCEPT_MAPPING_FILE_NAME), "w") as f:
              json.dump(self._concept_name_mapping, f)

+         with open(
+             os.path.join(save_directory, MOTOR_TIME_TO_EVENT_CODES_FILE_NAME), "w"
+         ) as f:
+             json.dump(self._motor_time_to_event_codes, f)
+
          self._pretrained_concept_embedding_model.save(save_directory)

          if push_to_hub:
@@ -689,6 +899,22 @@ class CehrGptTokenizer(PreTrainedTokenizer):
              return None
          concept_name_mapping = load_json_file(concept_name_mapping_file)

+         # Load the concept_code_stats json file
+         concept_code_stats_mapping_file = transformers.utils.hub.cached_file(
+             pretrained_model_name_or_path, CONCEPT_STATS_FILE_NAME, **kwargs
+         )
+         if not concept_code_stats_mapping_file:
+             return None
+         concept_code_stats = load_json_file(concept_code_stats_mapping_file)
+
+         # Load the MOTOR time to event codes file
+         motor_time_to_event_codes_file = transformers.utils.hub.cached_file(
+             pretrained_model_name_or_path, MOTOR_TIME_TO_EVENT_CODES_FILE_NAME, **kwargs
+         )
+         if not motor_time_to_event_codes_file:
+             return None
+         motor_time_to_event_codes = load_json_file(motor_time_to_event_codes_file)
+
          pretrained_embedding_model = PretrainedEmbeddings(pretrained_model_name_or_path)

          return CehrGptTokenizer(
@@ -696,10 +922,12 @@ class CehrGptTokenizer(PreTrainedTokenizer):
              value_tokenizer,
              att_tokenizer,
              token_to_sub_time_token_mapping,
+             concept_code_stats,
              lab_stats["numeric_lab_stats"],
              lab_stats["categorical_lab_stats"],
              concept_name_mapping,
              pretrained_embedding_model,
+             motor_time_to_event_codes,
          )

      @classmethod
@@ -722,6 +950,9 @@ class CehrGptTokenizer(PreTrainedTokenizer):
          concept_name_mapping: Dict[str, str],
          data_args: DataTrainingArguments,
          pretrained_concept_embedding_model: PretrainedEmbeddings = None,
+         num_motor_tasks: Optional[int] = None,
+         apply_entropy_filter: bool = False,
+         min_prevalence: float = 1 / 1000,
      ):
          if not isinstance(cehrgpt_tokenizer, CehrGptTokenizer):
              raise ValueError(
@@ -734,6 +965,9 @@ class CehrGptTokenizer(PreTrainedTokenizer):
              dataset=dataset,
              concept_name_mapping=concept_name_mapping,
              data_args=data_args,
+             num_motor_tasks=num_motor_tasks,
+             apply_entropy_filter=apply_entropy_filter,
+             min_prevalence=min_prevalence,
          )

          new_tokens = set(new_tokenizer.get_vocab().keys()) - set(
@@ -751,6 +985,7 @@ class CehrGptTokenizer(PreTrainedTokenizer):
          new_numeric_lab_stats = new_tokenizer._numeric_lab_stats
          new_categorical_lab_stats = new_tokenizer._categorical_lab_stats
          new_concept_name_mapping = new_tokenizer._concept_name_mapping
+         new_motor_time_to_event_codes = new_tokenizer._motor_time_to_event_codes

          # Add new tokens to the existing tokenizer
          cehrgpt_tokenizer_copy._tokenizer.add_tokens(
@@ -799,15 +1034,31 @@ class CehrGptTokenizer(PreTrainedTokenizer):
              if token not in cehrgpt_tokenizer_copy._concept_name_mapping:
                  cehrgpt_tokenizer_copy._concept_name_mapping[token] = concept_name

+         # Merge motor_time_to_event_codes
+         if (
+             new_motor_time_to_event_codes
+             and cehrgpt_tokenizer_copy._motor_time_to_event_codes
+         ):
+             for motor_time_to_event_code in new_motor_time_to_event_codes:
+                 if (
+                     motor_time_to_event_code
+                     not in cehrgpt_tokenizer_copy._motor_time_to_event_codes
+                 ):
+                     cehrgpt_tokenizer_copy._motor_time_to_event_codes.append(
+                         motor_time_to_event_code
+                     )
+
          return CehrGptTokenizer(
              tokenizer=cehrgpt_tokenizer_copy._tokenizer,
              value_tokenizer=cehrgpt_tokenizer_copy._value_tokenizer,
              att_tokenizer=cehrgpt_tokenizer_copy._att_tokenizer,
              token_to_sub_time_token_mapping=cehrgpt_tokenizer_copy._token_to_sub_time_token_mapping,
+             concept_code_stats=cehrgpt_tokenizer_copy._concept_code_stats,
              numeric_lab_stats=cehrgpt_tokenizer_copy._numeric_lab_stats,
              categorical_lab_stats=cehrgpt_tokenizer_copy._categorical_lab_stats,
              concept_name_mapping=cehrgpt_tokenizer_copy._concept_name_mapping,
              pretrained_concept_embedding_model=pretrained_concept_embedding_model,
+             motor_time_to_event_codes=cehrgpt_tokenizer_copy._motor_time_to_event_codes,
          )

      @classmethod
@@ -877,6 +1128,10 @@ class CehrGptTokenizer(PreTrainedTokenizer):
          concept_name_mapping: Dict[str, str],
          data_args: DataTrainingArguments,
          pretrained_concept_embedding_model: PretrainedEmbeddings = None,
+         allowed_motor_codes: Optional[List[int]] = None,
+         num_motor_tasks: Optional[int] = None,
+         apply_entropy_filter: bool = False,
+         min_prevalence: float = 1 / 1000,
      ):
          """
          Train a huggingface word level tokenizer.
@@ -888,18 +1143,56 @@ class CehrGptTokenizer(PreTrainedTokenizer):
          if isinstance(dataset, DatasetDict):
              dataset = dataset["train"]

+         LOG.info("Calculating data statistics")
+         cehrgpt_data_statistics = compute_statistics(dataset, data_args)
+         cehrgpt_data_statistics["total"]
+         numeric_lab_stats = cehrgpt_data_statistics["numeric_lab_stats"]
+         categorical_lab_stats = cehrgpt_data_statistics["categorical_lab_stats"]
+         concept_code_stats = cehrgpt_data_statistics["concept_code_stats"]
+         concept_code_entropies = cehrgpt_data_statistics["concept_code_entropies"]
+
+         if apply_entropy_filter:
+             min_prevalence = max(1e-8, min_prevalence)
+             min_entropy = (
+                 np.log(1 - min_prevalence) * (1 - min_prevalence)
+                 + np.log(min_prevalence) * min_prevalence
+             )
+             qualified_codes = [
+                 k
+                 for k, v in concept_code_entropies.items()
+                 if v <= min_entropy
+                 or not is_clinical_event(k, data_args.is_data_in_meds)
+             ]
+         else:
+             qualified_codes = [
+                 k
+                 for k, v in concept_code_stats.items()
+                 if min_prevalence <= v
+                 or not is_clinical_event(k, data_args.is_data_in_meds)
+             ]
+
+         # Create the tokenizer now
+         LOG.info("Training the tokenizer for concepts")
          concept_tokenizer = cls.train_concept_tokenizer(
              dataset,
              feature_name="concept_ids",
-             special_tokens=[PAD_TOKEN, OUT_OF_VOCABULARY_TOKEN, START_TOKEN, END_TOKEN],
+             special_tokens=[
+                 PAD_TOKEN,
+                 OUT_OF_VOCABULARY_TOKEN,
+                 START_TOKEN,
+                 END_TOKEN,
+                 LINEAR_PROB_TOKEN,
+             ],
              unk_token=OUT_OF_VOCABULARY_TOKEN,
              data_args=data_args,
+             qualified_codes=qualified_codes,
          )
          concept_value_column = "concept_as_values"
          for row in dataset:
              if concept_value_column not in row:
                  concept_value_column = "concept_values"
              break
+         LOG.info("Training the tokenizer for values")
          value_tokenizer = cls.train_concept_tokenizer(
              dataset,
              feature_name=concept_value_column,
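A note on the entropy filter above (an interpretation of the hunk, not text from the package): concept_code_entropies holds p*log(p) + (1-p)*log(1-p) for each code's per-sequence prevalence p, i.e. the negative Bernoulli entropy, so the condition v <= min_entropy keeps codes whose prevalence is at least as informative as min_prevalence; for rare codes this behaves much like the plain min_prevalence <= v cut in the else branch, and tokens that are not clinical events are always kept. A small numeric check under those assumptions:

    import numpy as np

    def neg_entropy(p: float) -> float:
        # Same expression the diff applies to the clipped code weights
        return np.log(p) * p + (1 - p) * np.log(1 - p)

    min_prevalence = 1 / 1000
    min_entropy = neg_entropy(min_prevalence)
    assert neg_entropy(1 / 10_000) > min_entropy   # rarer code: filtered out
    assert neg_entropy(1 / 100) <= min_entropy     # more prevalent code: kept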
@@ -915,65 +1208,11 @@ class CehrGptTokenizer(PreTrainedTokenizer):
              ]
          )

-         map_statistics_partial = partial(map_statistics, size=SAMPLE_SIZE)
-
-         if data_args.streaming:
-             first_example = next(iter(dataset))
-             parts = dataset.map(
-                 partial(agg_helper, map_func=map_statistics_partial),
-                 batched=True,
-                 batch_size=data_args.preprocessing_batch_size,
-                 remove_columns=first_example.keys(),
-             )
-         else:
-             parts = dataset.map(
-                 partial(agg_helper, map_func=map_statistics_partial),
-                 batched=True,
-                 batch_size=data_args.preprocessing_batch_size,
-                 remove_columns=dataset.column_names,
-                 num_proc=data_args.preprocessing_num_workers,
-                 keep_in_memory=True,
-                 new_fingerprint="invalid",
-             )
-         current = None
-         for stat in tqdm(parts, desc="Aggregating the lab statistics"):
-             fixed_stat = pickle.loads(stat["data"])
-             if current is None:
-                 current = fixed_stat
-             else:
-                 current = agg_statistics(current, fixed_stat)
-
-         numeric_lab_stats = []
-         for (concept_id, unit), online_stats in current["numeric_stats_by_lab"].items():
-             if len(online_stats.samples) == 0:
-                 continue
-             samples = truncated_sample(
-                 online_stats.samples, data_args.value_outlier_std
-             )
-             bins = create_bins_with_spline(samples, NUM_OF_BINS, DEGREE_OF_FREEDOM)
-             if len(bins) > 0:
-                 numeric_lab_stats.append(
-                     {
-                         "concept_id": concept_id,
-                         "unit": unit,
-                         "mean": np.mean(samples),
-                         "std": np.std(samples),
-                         "count": len(online_stats.samples),
-                         "value_outlier_std": data_args.value_outlier_std,
-                         "bins": bins,
-                     }
-                 )
-
-         categorical_lab_stats = collections.defaultdict(int)
-         for (concept_id, value_as_concept), count in current[
-             "categorical_stats_by_lab"
-         ].items():
-             categorical_lab_stats[(concept_id, value_as_concept)] += count
-
          # We will train a tokenizer specifically for time intervals
          sub_time_token_data = []
          token_to_sub_time_token_mapping = collections.defaultdict(list)
-         for token, token_id in concept_tokenizer.get_vocab().items():
+         vocab = concept_tokenizer.get_vocab()
+         for token, token_id in vocab.items():
              if is_att_token(token):
                  time_interval = extract_time_interval_in_days(token)
                  time_tuple = convert_time_interval_to_time_tuple(
@@ -994,15 +1233,43 @@ class CehrGptTokenizer(PreTrainedTokenizer):
          )
          att_tokenizer.train_from_iterator(sub_time_token_data, trainer=att_trainer)

+         # Prune concept_name_mapping
+         concept_name_mapping = {
+             concept_id: concept_name_mapping[concept_id]
+             for concept_id in vocab.keys()
+             if concept_id in concept_name_mapping
+         }
+
+         motor_time_to_event_codes = None
+         if num_motor_tasks and allowed_motor_codes:
+             motor_time_to_event_codes = []
+             for concept_id, _ in sorted(
+                 concept_code_entropies.items(), key=lambda t: t[1]
+             ):
+                 if (
+                     concept_id not in allowed_motor_codes
+                     or concept_id not in qualified_codes
+                 ):
+                     continue
+                 if len(motor_time_to_event_codes) < num_motor_tasks:
+                     motor_time_to_event_codes.append(concept_id)
+                 else:
+                     break
+             LOG.info(
+                 f"{len(motor_time_to_event_codes)} number of tasks have been added as MOTOR tasks"
+             )
+
          return CehrGptTokenizer(
              concept_tokenizer,
              value_tokenizer,
              att_tokenizer,
              token_to_sub_time_token_mapping,
+             concept_code_stats,
              numeric_lab_stats,
              categorical_lab_stats,
              concept_name_mapping,
              pretrained_concept_embedding_model,
+             motor_time_to_event_codes,
          )

      @classmethod
@@ -1013,6 +1280,7 @@ class CehrGptTokenizer(PreTrainedTokenizer):
          special_tokens: List[str],
          unk_token,
          data_args,
+         qualified_codes: Optional[List[str]] = None,
      ):
          # Use the Fast Tokenizer from the Huggingface tokenizers Rust implementation.
          # https://github.com/huggingface/tokenizers
@@ -1025,7 +1293,9 @@ class CehrGptTokenizer(PreTrainedTokenizer):
              show_progress=True,
          )
          batch_concat_concepts_partial_func = partial(
-             cls.batch_concat_concepts, feature_name=feature_name
+             cls.batch_concat_concepts,
+             feature_name=feature_name,
+             qualified_codes=qualified_codes,
          )
          if data_args.streaming:
              concatenated_features = dataset.map(
@@ -1065,13 +1335,30 @@ class CehrGptTokenizer(PreTrainedTokenizer):

      @classmethod
      def batch_concat_concepts(
-         cls, records: Dict[str, List], feature_name
+         cls,
+         records: Dict[str, List],
+         feature_name: str,
+         qualified_codes: Optional[List[str]] = None,
      ) -> Dict[str, List]:
+         def filter_token(t: str) -> bool:
+             """
+             If the token is None or not string, return False.
+
+             When qualified_codes is provided, t must be in
+             qualified_codes to be valid, otherwise the tokens are always valid
+
+             :param t:
+             :return:
+             """
+             if t is None or not isinstance(t, str):
+                 return False
+             if qualified_codes:
+                 return t in qualified_codes
+             return True
+
          return {
              feature_name: [
-                 " ".join(
-                     [token for token in tokens if token and isinstance(token, str)]
-                 )
+                 " ".join(filter(filter_token, tokens))
                  for tokens in records[feature_name]
              ]
          }
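Taken together, the hunks change the tokenizer-training entry point. A rough usage sketch, assuming the classmethod whose signature is extended above is CehrGptTokenizer.train_tokenizer and that dataset, concept_name_mapping, data_args, and allowed_motor_codes are prepared elsewhere (all placeholder names here):

    from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer

    tokenizer = CehrGptTokenizer.train_tokenizer(
        dataset=dataset,
        concept_name_mapping=concept_name_mapping,
        data_args=data_args,
        apply_entropy_filter=True,   # roughly: keep codes seen in at least min_prevalence of sequences
        min_prevalence=1 / 1000,
        allowed_motor_codes=allowed_motor_codes,
        num_motor_tasks=512,         # cap on MOTOR time-to-event task codes
    )

Per the save and from_pretrained hunks, the trained tokenizer persists the new statistics to cehrgpt_concept_stats.json and the selected MOTOR codes to motor_time_to_event_codes.json, and loading now expects both files to be present (from_pretrained returns None when either is missing).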