cehrgpt 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. cehrgpt/analysis/htn_treatment_pathway.py +546 -0
  2. cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
  3. cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
  4. cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
  5. cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
  6. cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
  7. cehrgpt/data/cehrgpt_data_processor.py +549 -0
  8. cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
  9. cehrgpt/data/hf_cehrgpt_dataset_collator.py +285 -652
  10. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +38 -5
  11. cehrgpt/generation/cehrgpt_conditional_generation.py +2 -0
  12. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +20 -12
  13. cehrgpt/generation/omop_converter_batch.py +11 -4
  14. cehrgpt/gpt_utils.py +73 -3
  15. cehrgpt/models/activations.py +27 -0
  16. cehrgpt/models/config.py +6 -2
  17. cehrgpt/models/gpt2.py +560 -0
  18. cehrgpt/models/hf_cehrgpt.py +183 -460
  19. cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
  20. cehrgpt/omop/ontology.py +154 -0
  21. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -78
  22. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
  23. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -34
  24. cehrgpt/runners/hyperparameter_search_util.py +180 -69
  25. cehrgpt/runners/sample_packing_trainer.py +11 -2
  26. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +8 -2
  27. cehrgpt-0.1.3.dist-info/METADATA +238 -0
  28. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +32 -22
  29. cehrgpt-0.1.2.dist-info/METADATA +0 -209
  30. /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
  31. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
  32. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
  33. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
@@ -1,19 +1,24 @@
1
1
  import collections
2
2
  import copy
3
3
  import json
4
+ import math
4
5
  import os
5
6
  import pickle
6
7
  from functools import partial
7
8
  from itertools import islice
8
9
  from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
9
10
 
11
+ import femr
12
+ import femr.ontology
10
13
  import numpy as np
11
14
  import scipy.stats as stats
12
15
  import transformers
13
16
  from cehrbert.models.hf_models.tokenization_utils import agg_helper, load_json_file
14
17
  from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
18
+ from cehrbert_data.const.artificial_tokens import DEATH_TOKEN
15
19
  from datasets import Dataset, DatasetDict, IterableDataset
16
20
  from femr.stat_utils import OnlineStatistics, ReservoirSampler
21
+ from meds import death_code
17
22
  from scipy.interpolate import UnivariateSpline
18
23
  from tokenizers import AddedToken, Tokenizer
19
24
  from tokenizers.models import WordLevel
@@ -38,6 +43,7 @@ from cehrgpt.models.special_tokens import (
38
43
  PAD_TOKEN,
39
44
  START_TOKEN,
40
45
  )
46
+ from cehrgpt.omop.ontology import Ontology
41
47
 
42
48
  NUM_OF_BINS = 10
43
49
  DEGREE_OF_FREEDOM = 3
@@ -52,8 +58,10 @@ TOKEN_TO_SUB_TIME_TOKEN_MAPPING_FILE_NAME = "token_to_sub_time_token_mapping.jso
52
58
  LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.pickle"
53
59
  LEGACY_LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.json"
54
60
  CONCEPT_STATS_FILE_NAME = "cehrgpt_concept_stats.json"
61
+ DEMOGRAPHICS_STATS_FILE_NAME = "demographics_stats.pickle"
55
62
  CONCEPT_MAPPING_FILE_NAME = "concept_name_mapping.json"
56
- MOTOR_TIME_TO_EVENT_CODES_FILE_NAME = "motor_time_to_event_codes.json"
63
+ MOTOR_TIME_TO_EVENT_TASK_INFO_FILE_NAME = "motor_time_to_event_info.pickle"
64
+ ONTOLOGY_FILE_NAME = "ontology.pickle"
57
65
 
58
66
  LOG = logging.get_logger("transformers")
59
67
 
@@ -84,6 +92,23 @@ def get_dataset_len(dataset: Union[Dataset, IterableDataset]) -> int:
84
92
  )
85
93
 
86
94
 
95
+ def get_allowed_motor_codes(
96
+ original_concept_codes: List[str], ontology: Optional[Ontology]
97
+ ) -> List[str]:
98
+ filtered_original_concept_codes = filter(is_clinical_event, original_concept_codes)
99
+ if ontology:
100
+ allowed_motor_codes = []
101
+ for concept in filtered_original_concept_codes:
102
+ domain = ontology.get_domain(concept)
103
+ if domain and domain in ["Condition", "Procedure", "Drug", "Visit"]:
104
+ allowed_motor_codes.append(concept)
105
+ elif concept in [DEATH_TOKEN, death_code]:
106
+ allowed_motor_codes.append(concept)
107
+ return allowed_motor_codes
108
+ else:
109
+ return list(filtered_original_concept_codes)
110
+
111
+
87
112
  def create_sample_from_bins(bins, sample_size: int = 10_000) -> List[float]:
88
113
  """
89
114
  Generates a specified number of samples from a list of bins, each containing a fitted spline.
@@ -199,6 +224,112 @@ def create_bins_with_spline(samples, num_bins, d_freedom=3) -> List[Dict[str, An
199
224
  return bins
200
225
 
201
226
 
227
+ def map_motor_tte_statistics(
228
+ batch: Dict[str, Any],
229
+ allowed_motor_codes: List[str],
230
+ ) -> Dict[str, Any]:
231
+ motor_event_times = femr.stat_utils.ReservoirSampler(100_000)
232
+ task_tte_stats: Dict[str, int] = collections.defaultdict(int)
233
+ task_censor_stats: Dict[str, int] = collections.defaultdict(int)
234
+ for concept_ids in batch["concept_ids"]:
235
+ # First collect TTE data in reverse chronological order
236
+ censor_time = 0
237
+ time_to_event_dict: Dict[str, int] = {}
238
+ next_future_visit_concepts = set()
239
+ # Reverse walk through concept_ids to calculate TTE from each [VE] point
240
+ for concept_id in reversed(concept_ids):
241
+ if is_att_token(concept_id):
242
+ time_interval = extract_time_interval_in_days(concept_id)
243
+ if time_interval > 0:
244
+ # Update TTE for existing concepts, or add new ones seen in this visit
245
+ for existing_concept_id in list(time_to_event_dict.keys()):
246
+ if existing_concept_id in next_future_visit_concepts:
247
+ time_to_event_dict[existing_concept_id] = time_interval
248
+ else:
249
+ time_to_event_dict[existing_concept_id] += time_interval
250
+
251
+ for next_concept_id in next_future_visit_concepts:
252
+ if next_concept_id not in time_to_event_dict:
253
+ time_to_event_dict[next_concept_id] = time_interval
254
+
255
+ # Record the censor time at the end of the visit
256
+ censor_time += time_interval
257
+
258
+ # Keep track of the time to event value
259
+ for tte in time_to_event_dict.values():
260
+ motor_event_times.add(tte, 1)
261
+
262
+ for motor_code in allowed_motor_codes:
263
+ if motor_code in time_to_event_dict:
264
+ task_tte_stats[motor_code] += 1
265
+ else:
266
+ task_censor_stats[motor_code] += 1
267
+ next_future_visit_concepts.clear()
268
+ else:
269
+ next_future_visit_concepts.add(concept_id)
270
+
271
+ return {
272
+ "motor_event_times": motor_event_times,
273
+ "task_tte_stats": task_tte_stats,
274
+ "task_censor_stats": task_censor_stats,
275
+ }
276
+
277
+
278
+ def compute_motor_tte_statistics(
279
+ dataset: Dataset,
280
+ data_args: DataTrainingArguments,
281
+ allowed_motor_codes: List[str],
282
+ ontology: Optional[Ontology] = None,
283
+ ) -> Dict[str, Any]:
284
+ map_motor_tte_statistics_partial = partial(
285
+ map_motor_tte_statistics,
286
+ allowed_motor_codes=allowed_motor_codes,
287
+ )
288
+ if data_args.streaming:
289
+ first_example = next(iter(dataset))
290
+ parts = dataset.map(
291
+ partial(agg_helper, map_func=map_motor_tte_statistics_partial),
292
+ batched=True,
293
+ batch_size=data_args.preprocessing_batch_size,
294
+ remove_columns=first_example.keys(),
295
+ )
296
+ else:
297
+ parts = dataset.map(
298
+ partial(agg_helper, map_func=map_motor_tte_statistics_partial),
299
+ batched=True,
300
+ batch_size=data_args.preprocessing_batch_size,
301
+ remove_columns=dataset.column_names,
302
+ num_proc=data_args.preprocessing_num_workers,
303
+ keep_in_memory=True,
304
+ new_fingerprint="invalid",
305
+ )
306
+ current = None
307
+ for stat in tqdm(parts, desc="Aggregating the MOTOR TTE statistics"):
308
+ fixed_stat = pickle.loads(stat["data"])
309
+ if current is None:
310
+ current = fixed_stat
311
+ else:
312
+ current["motor_event_times"].combine(fixed_stat["motor_event_times"])
313
+ for k, v in fixed_stat["task_tte_stats"].items():
314
+ current["task_tte_stats"][k] += v
315
+ for k, v in fixed_stat["task_censor_stats"].items():
316
+ current["task_censor_stats"][k] += v
317
+
318
+ # Aggregate the counts for the parent concepts
319
+ if ontology is not None:
320
+ for k in list(current["task_tte_stats"].keys()):
321
+ for parent in ontology.get_all_parents(k):
322
+ if parent != k:
323
+ current["task_tte_stats"][parent] += current["task_tte_stats"][k]
324
+ for k in list(current["task_censor_stats"].keys()):
325
+ for parent in ontology.get_all_parents(k):
326
+ if parent != k:
327
+ current["task_censor_stats"][parent] += current[
328
+ "task_censor_stats"
329
+ ][k]
330
+ return current
331
+
332
+
202
333
  def agg_statistics(stats1, stats2):
203
334
  if stats1.get("numeric_stats_by_lab"):
204
335
  for k, v in stats2["numeric_stats_by_lab"].items():
@@ -211,6 +342,10 @@ def agg_statistics(stats1, stats2):
211
342
  if stats1.get("concept_code_stats"):
212
343
  for concept_id, weight in stats2["concept_code_stats"].items():
213
344
  stats1["concept_code_stats"][concept_id] += weight
345
+ if stats1.get("gender_list"):
346
+ stats1.get("gender_list").update(stats2.get("gender_list"))
347
+ if stats1.get("race_list"):
348
+ stats1.get("race_list").update(stats2.get("race_list"))
214
349
  return stats1
215
350
 
216
351
 
@@ -239,6 +374,8 @@ def map_statistics(batch: Dict[str, Any], total_size, size=10_000) -> Dict[str,
239
374
  numeric_stats_by_lab = collections.defaultdict(partial(ReservoirSampler, size=size))
240
375
  categorical_stats_by_lab = collections.defaultdict(int)
241
376
  concept_code_stats = collections.defaultdict(int)
377
+ gender_list = set()
378
+ race_list = set()
242
379
  for (
243
380
  concept_ids,
244
381
  number_as_values,
@@ -252,6 +389,11 @@ def map_statistics(batch: Dict[str, Any], total_size, size=10_000) -> Dict[str,
252
389
  batch["concept_value_masks"],
253
390
  batch_value_units,
254
391
  ):
392
+ # Collecting demographics
393
+ gender, race = concept_ids[2:4]
394
+ gender_list.add(gender)
395
+ race_list.add(race)
396
+
255
397
  unique_codes = set()
256
398
  for (
257
399
  concept_id,
@@ -280,11 +422,15 @@ def map_statistics(batch: Dict[str, Any], total_size, size=10_000) -> Dict[str,
280
422
  "numeric_stats_by_lab": numeric_stats_by_lab,
281
423
  "categorical_stats_by_lab": categorical_stats_by_lab,
282
424
  "concept_code_stats": concept_code_stats,
425
+ "gender_list": gender_list,
426
+ "race_list": race_list,
283
427
  }
284
428
 
285
429
 
286
430
  def compute_statistics(
287
- dataset: Dataset, data_args: DataTrainingArguments
431
+ dataset: Dataset,
432
+ data_args: DataTrainingArguments,
433
+ ontology: Optional[Ontology] = None,
288
434
  ) -> Dict[str, Any]:
289
435
  total = get_dataset_len(dataset)
290
436
  map_statistics_partial = partial(map_statistics, total_size=total, size=SAMPLE_SIZE)
@@ -339,25 +485,44 @@ def compute_statistics(
339
485
  ].items():
340
486
  categorical_lab_stats[(concept_id, value_as_concept)] += count
341
487
 
342
- concept_code_stats = collections.defaultdict(int)
488
+ all_concept_code_stats = collections.defaultdict(float)
343
489
  for concept_id, count in current["concept_code_stats"].items():
344
- concept_code_stats[concept_id] += count
345
-
346
- code_weights = np.asarray(list(concept_code_stats.values())).clip(1e-8, 1 - 1e-8)
347
- # Clip the values so we don't get errors when applying np.log
348
- code_entropies = np.log(code_weights) * code_weights + (1 - code_weights) * np.log(
349
- 1 - code_weights
350
- )
351
-
352
- concept_code_entropies = {
353
- k: v for k, v in zip(concept_code_stats.keys(), code_entropies)
354
- }
490
+ if ontology is not None:
491
+ parents = ontology.get_all_parents(concept_id)
492
+ for parent in parents:
493
+ all_concept_code_stats[parent] += count
494
+ else:
495
+ all_concept_code_stats[concept_id] += count
496
+
497
+ all_concept_code_entropies = collections.defaultdict(float)
498
+ for concept_id, weight in all_concept_code_stats.items():
499
+ baseline = (
500
+ min(
501
+ [1]
502
+ + [
503
+ all_concept_code_stats[parent]
504
+ for parent in ontology.get_parents(concept_id)
505
+ ]
506
+ )
507
+ if ontology is not None
508
+ else 1
509
+ )
510
+ weight = weight / baseline
511
+ weight = min(1.0, weight)
512
+ if weight != 0 and weight != 1:
513
+ weight = baseline * (
514
+ weight * math.log(weight) + (1 - weight) * math.log(1 - weight)
515
+ )
516
+ all_concept_code_entropies[concept_id] = weight
355
517
 
356
518
  return {
357
519
  "numeric_lab_stats": numeric_lab_stats,
358
520
  "categorical_lab_stats": categorical_lab_stats,
359
- "concept_code_stats": concept_code_stats,
360
- "concept_code_entropies": concept_code_entropies,
521
+ "original_concept_codes": list(current["concept_code_stats"].keys()),
522
+ "all_concept_code_stats": all_concept_code_stats,
523
+ "all_concept_code_entropies": all_concept_code_entropies,
524
+ "gender_list": current["gender_list"],
525
+ "race_list": current["race_list"],
361
526
  "total": total,
362
527
  }
363
528
 
@@ -465,7 +630,10 @@ class CehrGptTokenizer(PreTrainedTokenizer):
465
630
  categorical_lab_stats: Dict[Tuple[str, str], int],
466
631
  concept_name_mapping: Dict[str, str],
467
632
  pretrained_concept_embedding_model: PretrainedEmbeddings = None,
468
- motor_time_to_event_codes: Optional[List[str]] = None,
633
+ motor_task_info: Optional[Dict[str, Any]] = None,
634
+ gender_map: Optional[Dict[str, int]] = None,
635
+ race_map: Optional[Dict[str, int]] = None,
636
+ ontology: Optional[Ontology] = None,
469
637
  ):
470
638
  self._tokenizer = tokenizer
471
639
  self._value_tokenizer = value_tokenizer
@@ -485,6 +653,8 @@ class CehrGptTokenizer(PreTrainedTokenizer):
485
653
  self._linear_token_id = None
486
654
  if LINEAR_PROB_TOKEN in self._tokenizer.get_vocab():
487
655
  self._linear_token_id = self._tokenizer.token_to_id(LINEAR_PROB_TOKEN)
656
+ else:
657
+ self._linear_token_id = self._oov_token_id
488
658
 
489
659
  self._numeric_concept_ids = (
490
660
  self._numeric_event_statistics.get_numeric_concept_ids()
@@ -503,13 +673,18 @@ class CehrGptTokenizer(PreTrainedTokenizer):
503
673
  for _ in self.get_vocab().keys()
504
674
  if self._pretrained_concept_embedding_model.is_concept_available(_)
505
675
  ]
506
- self._motor_time_to_event_codes = (
507
- motor_time_to_event_codes if motor_time_to_event_codes else []
676
+ self._motor_task_info: Dict[str, Any] = (
677
+ motor_task_info if motor_task_info is not None else {}
678
+ )
679
+ self._motor_time_to_event_codes = self._motor_task_info.get(
680
+ "motor_time_to_event_codes", []
508
681
  )
509
682
  self._motor_code_to_id_mapping = {
510
683
  code: i for i, code in enumerate(sorted(self._motor_time_to_event_codes))
511
684
  }
512
-
685
+ self._gender_map = gender_map if gender_map else {}
686
+ self._race_map = race_map if race_map else {}
687
+ self._ontology = ontology
513
688
  super().__init__()
514
689
 
515
690
  @property
@@ -545,6 +720,16 @@ class CehrGptTokenizer(PreTrainedTokenizer):
545
720
  def time_token_vocab_size(self) -> int:
546
721
  return self._att_tokenizer.get_vocab_size()
547
722
 
723
+ @property
724
+ def gender_size(self) -> int:
725
+ # Plus one for the unknown
726
+ return len(self._gender_map) + 1
727
+
728
+ @property
729
+ def race_size(self) -> int:
730
+ # Plus one for the unknown
731
+ return len(self._race_map) + 1
732
+
548
733
  @property
549
734
  def pad_value_token_id(self):
550
735
  return self._padding_value_token_id
@@ -591,6 +776,7 @@ class CehrGptTokenizer(PreTrainedTokenizer):
591
776
  return LINEAR_PROB_TOKEN
592
777
  return None
593
778
 
779
+ @property
594
780
  def vs_token_id(self):
595
781
  # We used VS for the historical data, currently, we use the new [VS] for the newer data
596
782
  # so we need to check both cases.
@@ -658,8 +844,17 @@ class CehrGptTokenizer(PreTrainedTokenizer):
658
844
  def pretrained_concept_embedding_model(self):
659
845
  return self._pretrained_concept_embedding_model
660
846
 
847
+ def get_motor_time_bins(self, motor_num_time_pieces: int) -> List[int]:
848
+ time_bins = np.percentile(
849
+ self._motor_task_info["motor_event_times"].samples,
850
+ np.linspace(0, 100, motor_num_time_pieces + 1),
851
+ )
852
+ time_bins[0] = 0
853
+ time_bins[-1] = float("inf")
854
+ return list(time_bins)
855
+
661
856
  def get_motor_token_id(self, concept_id: str) -> int:
662
- if concept_id not in concept_id:
857
+ if not self.is_motor_time_to_event_code(concept_id):
663
858
  raise RuntimeError(f"Invalid motor concept id: {concept_id}")
664
859
  return self._motor_code_to_id_mapping[concept_id]
665
860
 
@@ -671,6 +866,21 @@ class CehrGptTokenizer(PreTrainedTokenizer):
671
866
  return True
672
867
  return False
673
868
 
869
+ def get_motor_parents(self, concept_id: str) -> List[str]:
870
+ motor_codes = []
871
+ if self._ontology is None:
872
+ if self.is_motor_time_to_event_code(concept_id):
873
+ motor_codes.append(concept_id)
874
+ else:
875
+ motor_codes.extend(
876
+ [
877
+ p
878
+ for p in self._ontology.get_all_parents(concept_id)
879
+ if self.is_motor_time_to_event_code(p)
880
+ ]
881
+ )
882
+ return motor_codes
883
+
674
884
  def get_vocab(self) -> Dict[str, int]:
675
885
  return self._tokenizer.get_vocab()
676
886
 
@@ -699,6 +909,12 @@ class CehrGptTokenizer(PreTrainedTokenizer):
699
909
  concept_value_token_ids, skip_special_tokens=skip_special_tokens
700
910
  ).split(" ")
701
911
 
912
+ def encode_gender(self, gender: str) -> int:
913
+ return self._gender_map.get(gender, 0)
914
+
915
+ def encode_race(self, race: str) -> int:
916
+ return self._race_map.get(race, 0)
917
+
702
918
  def add_token(self, tokens: Union[str, List[str]]) -> None:
703
919
  if isinstance(tokens, str):
704
920
  tokens = [tokens]
@@ -782,13 +998,28 @@ class CehrGptTokenizer(PreTrainedTokenizer):
782
998
  }
783
999
  pickle.dump(lab_stats, f)
784
1000
 
1001
+ with open(
1002
+ os.path.join(save_directory, DEMOGRAPHICS_STATS_FILE_NAME), "wb"
1003
+ ) as f:
1004
+ pickle.dump(
1005
+ {
1006
+ "gender_map": self._gender_map,
1007
+ "race_map": self._race_map,
1008
+ },
1009
+ f,
1010
+ )
1011
+
785
1012
  with open(os.path.join(save_directory, CONCEPT_MAPPING_FILE_NAME), "w") as f:
786
1013
  json.dump(self._concept_name_mapping, f)
787
1014
 
788
1015
  with open(
789
- os.path.join(save_directory, MOTOR_TIME_TO_EVENT_CODES_FILE_NAME), "w"
1016
+ os.path.join(save_directory, MOTOR_TIME_TO_EVENT_TASK_INFO_FILE_NAME), "wb"
790
1017
  ) as f:
791
- json.dump(self._motor_time_to_event_codes, f)
1018
+ pickle.dump(self._motor_task_info, f)
1019
+
1020
+ if self._ontology is not None:
1021
+ with open(os.path.join(save_directory, ONTOLOGY_FILE_NAME), "wb") as f:
1022
+ pickle.dump(self._ontology, f)
792
1023
 
793
1024
  self._pretrained_concept_embedding_model.save(save_directory)
794
1025
 
@@ -890,6 +1121,23 @@ class CehrGptTokenizer(PreTrainedTokenizer):
890
1121
 
891
1122
  with open(lab_stats_file, "rb") as file:
892
1123
  lab_stats = pickle.load(file)
1124
+ try:
1125
+ # Load the demographics stats json file
1126
+ demographics_stats_file = transformers.utils.hub.cached_file(
1127
+ pretrained_model_name_or_path, DEMOGRAPHICS_STATS_FILE_NAME, **kwargs
1128
+ )
1129
+ if demographics_stats_file:
1130
+ with open(demographics_stats_file, "rb") as file:
1131
+ demographics_stats = pickle.load(file)
1132
+ else:
1133
+ demographics_stats = None
1134
+ except EnvironmentError:
1135
+ LOG.warning(
1136
+ f"The %s files does not exist in %s, setting demographics_stats to None",
1137
+ DEMOGRAPHICS_STATS_FILE_NAME,
1138
+ pretrained_model_name_or_path,
1139
+ )
1140
+ demographics_stats = None
893
1141
 
894
1142
  # Load the concept_name json file
895
1143
  concept_name_mapping_file = transformers.utils.hub.cached_file(
@@ -907,13 +1155,42 @@ class CehrGptTokenizer(PreTrainedTokenizer):
907
1155
  return None
908
1156
  concept_code_stats = load_json_file(concept_code_stats_mapping_file)
909
1157
 
910
- # Load the MOTOR time to event codes file
911
- motor_time_to_event_codes_file = transformers.utils.hub.cached_file(
912
- pretrained_model_name_or_path, MOTOR_TIME_TO_EVENT_CODES_FILE_NAME, **kwargs
913
- )
914
- if not motor_time_to_event_codes_file:
915
- return None
916
- motor_time_to_event_codes = load_json_file(motor_time_to_event_codes_file)
1158
+ try:
1159
+ # Load the MOTOR time to event codes file
1160
+ motor_tte_task_info_file = transformers.utils.hub.cached_file(
1161
+ pretrained_model_name_or_path,
1162
+ MOTOR_TIME_TO_EVENT_TASK_INFO_FILE_NAME,
1163
+ **kwargs,
1164
+ )
1165
+ if motor_tte_task_info_file:
1166
+ with open(motor_tte_task_info_file, "rb") as file:
1167
+ motor_task_info = pickle.load(file)
1168
+ else:
1169
+ motor_task_info = None
1170
+ except EnvironmentError:
1171
+ LOG.warning(
1172
+ f"The %s files does not exist in %s, setting motor_task_info to None",
1173
+ MOTOR_TIME_TO_EVENT_TASK_INFO_FILE_NAME,
1174
+ pretrained_model_name_or_path,
1175
+ )
1176
+ motor_task_info = None
1177
+
1178
+ ontology = None
1179
+ try:
1180
+ ontology_file = transformers.utils.hub.cached_file(
1181
+ pretrained_model_name_or_path,
1182
+ ONTOLOGY_FILE_NAME,
1183
+ **kwargs,
1184
+ )
1185
+ if ontology_file:
1186
+ with open(ontology_file, "rb") as file:
1187
+ ontology = pickle.load(file)
1188
+ except EnvironmentError | OSError:
1189
+ LOG.warning(
1190
+ "The ontology file %s does not existing in %s",
1191
+ ONTOLOGY_FILE_NAME,
1192
+ pretrained_model_name_or_path,
1193
+ )
917
1194
 
918
1195
  pretrained_embedding_model = PretrainedEmbeddings(pretrained_model_name_or_path)
919
1196
 
@@ -927,7 +1204,10 @@ class CehrGptTokenizer(PreTrainedTokenizer):
927
1204
  lab_stats["categorical_lab_stats"],
928
1205
  concept_name_mapping,
929
1206
  pretrained_embedding_model,
930
- motor_time_to_event_codes,
1207
+ motor_task_info,
1208
+ demographics_stats["gender_map"] if demographics_stats else None,
1209
+ demographics_stats["race_map"] if demographics_stats else None,
1210
+ ontology=ontology,
931
1211
  )
932
1212
 
933
1213
  @classmethod
@@ -1048,6 +1328,18 @@ class CehrGptTokenizer(PreTrainedTokenizer):
1048
1328
  motor_time_to_event_code
1049
1329
  )
1050
1330
 
1331
+ for gender in new_tokenizer._gender_map.keys():
1332
+ if gender not in cehrgpt_tokenizer_copy._gender_map:
1333
+ cehrgpt_tokenizer_copy._gender_map[gender] = len(
1334
+ cehrgpt_tokenizer_copy._gender_map
1335
+ )
1336
+
1337
+ for race in new_tokenizer._race_map.keys():
1338
+ if race not in cehrgpt_tokenizer_copy._race_map:
1339
+ cehrgpt_tokenizer_copy._race_map[race] = len(
1340
+ cehrgpt_tokenizer_copy._race_map
1341
+ )
1342
+
1051
1343
  return CehrGptTokenizer(
1052
1344
  tokenizer=cehrgpt_tokenizer_copy._tokenizer,
1053
1345
  value_tokenizer=cehrgpt_tokenizer_copy._value_tokenizer,
@@ -1058,7 +1350,10 @@ class CehrGptTokenizer(PreTrainedTokenizer):
1058
1350
  categorical_lab_stats=cehrgpt_tokenizer_copy._categorical_lab_stats,
1059
1351
  concept_name_mapping=cehrgpt_tokenizer_copy._concept_name_mapping,
1060
1352
  pretrained_concept_embedding_model=pretrained_concept_embedding_model,
1061
- motor_time_to_event_codes=cehrgpt_tokenizer_copy._motor_time_to_event_codes,
1353
+ motor_task_info=cehrgpt_tokenizer_copy._motor_task_info,
1354
+ gender_map=cehrgpt_tokenizer_copy._gender_map,
1355
+ race_map=cehrgpt_tokenizer_copy._race_map,
1356
+ ontology=cehrgpt_tokenizer_copy._ontology,
1062
1357
  )
1063
1358
 
1064
1359
  @classmethod
@@ -1128,10 +1423,10 @@ class CehrGptTokenizer(PreTrainedTokenizer):
1128
1423
  concept_name_mapping: Dict[str, str],
1129
1424
  data_args: DataTrainingArguments,
1130
1425
  pretrained_concept_embedding_model: PretrainedEmbeddings = None,
1131
- allowed_motor_codes: Optional[List[int]] = None,
1132
1426
  num_motor_tasks: Optional[int] = None,
1133
1427
  apply_entropy_filter: bool = False,
1134
1428
  min_prevalence: float = 1 / 1000,
1429
+ ontology: Optional[Ontology] = None,
1135
1430
  ):
1136
1431
  """
1137
1432
  Train a huggingface word level tokenizer.
@@ -1144,12 +1439,14 @@ class CehrGptTokenizer(PreTrainedTokenizer):
1144
1439
  dataset = dataset["train"]
1145
1440
 
1146
1441
  LOG.info("Calculating data statistics")
1147
- cehrgpt_data_statistics = compute_statistics(dataset, data_args)
1148
- cehrgpt_data_statistics["total"]
1442
+ cehrgpt_data_statistics = compute_statistics(dataset, data_args, ontology)
1149
1443
  numeric_lab_stats = cehrgpt_data_statistics["numeric_lab_stats"]
1150
1444
  categorical_lab_stats = cehrgpt_data_statistics["categorical_lab_stats"]
1151
- concept_code_stats = cehrgpt_data_statistics["concept_code_stats"]
1152
- concept_code_entropies = cehrgpt_data_statistics["concept_code_entropies"]
1445
+ original_concept_codes = cehrgpt_data_statistics["original_concept_codes"]
1446
+ all_concept_code_stats = cehrgpt_data_statistics["all_concept_code_stats"]
1447
+ all_concept_code_entropies = cehrgpt_data_statistics[
1448
+ "all_concept_code_entropies"
1449
+ ]
1153
1450
 
1154
1451
  if apply_entropy_filter:
1155
1452
  min_prevalence = max(1e-8, min_prevalence)
@@ -1159,15 +1456,15 @@ class CehrGptTokenizer(PreTrainedTokenizer):
1159
1456
  )
1160
1457
  qualified_codes = [
1161
1458
  k
1162
- for k, v in concept_code_entropies.items()
1163
- if v <= min_entropy
1459
+ for k in original_concept_codes
1460
+ if all_concept_code_entropies[k] <= min_entropy
1164
1461
  or not is_clinical_event(k, data_args.is_data_in_meds)
1165
1462
  ]
1166
1463
  else:
1167
1464
  qualified_codes = [
1168
1465
  k
1169
- for k, v in concept_code_stats.items()
1170
- if min_prevalence <= v
1466
+ for k in original_concept_codes
1467
+ if min_prevalence <= all_concept_code_stats[k]
1171
1468
  or not is_clinical_event(k, data_args.is_data_in_meds)
1172
1469
  ]
1173
1470
 
@@ -1240,36 +1537,69 @@ class CehrGptTokenizer(PreTrainedTokenizer):
1240
1537
  if concept_id in concept_name_mapping
1241
1538
  }
1242
1539
 
1243
- motor_time_to_event_codes = None
1244
- if num_motor_tasks and allowed_motor_codes:
1540
+ motor_task_info = None
1541
+ if num_motor_tasks:
1542
+ LOG.info("Computing the MOTOR TTE statistics")
1543
+ allowed_motor_codes = get_allowed_motor_codes(
1544
+ original_concept_codes, ontology
1545
+ )
1546
+ motor_tte_statistics = compute_motor_tte_statistics(
1547
+ dataset, data_args, allowed_motor_codes, ontology
1548
+ )
1245
1549
  motor_time_to_event_codes = []
1246
1550
  for concept_id, _ in sorted(
1247
- concept_code_entropies.items(), key=lambda t: t[1]
1551
+ all_concept_code_entropies.items(), key=lambda t: t[1]
1248
1552
  ):
1249
- if (
1250
- concept_id not in allowed_motor_codes
1251
- or concept_id not in qualified_codes
1252
- ):
1553
+ if concept_id not in allowed_motor_codes:
1554
+ continue
1555
+ tte_stats = motor_tte_statistics["task_tte_stats"][concept_id]
1556
+ censor_stats = motor_tte_statistics["task_censor_stats"][concept_id]
1557
+ frac_events = tte_stats / (tte_stats + censor_stats)
1558
+
1559
+ if frac_events < 1 / 1000:
1560
+ LOG.info(
1561
+ "Ran into very rare task %s with %s", concept_id, frac_events
1562
+ )
1253
1563
  continue
1564
+
1254
1565
  if len(motor_time_to_event_codes) < num_motor_tasks:
1255
1566
  motor_time_to_event_codes.append(concept_id)
1256
1567
  else:
1257
1568
  break
1569
+
1258
1570
  LOG.info(
1259
1571
  f"{len(motor_time_to_event_codes)} number of tasks have been added as MOTOR tasks"
1260
1572
  )
1573
+ motor_task_info = {
1574
+ "motor_event_times": motor_tte_statistics["motor_event_times"],
1575
+ "task_tte_stats": motor_tte_statistics["task_tte_stats"],
1576
+ "task_censor_stats": motor_tte_statistics["task_censor_stats"],
1577
+ "motor_time_to_event_codes": motor_time_to_event_codes,
1578
+ }
1579
+
1580
+ gender_map = {
1581
+ gender: i + 1
1582
+ for i, gender in enumerate(sorted(cehrgpt_data_statistics["gender_list"]))
1583
+ }
1584
+ race_map = {
1585
+ race: i + 1
1586
+ for i, race in enumerate(sorted(cehrgpt_data_statistics["race_list"]))
1587
+ }
1261
1588
 
1262
1589
  return CehrGptTokenizer(
1263
1590
  concept_tokenizer,
1264
1591
  value_tokenizer,
1265
1592
  att_tokenizer,
1266
1593
  token_to_sub_time_token_mapping,
1267
- concept_code_stats,
1594
+ all_concept_code_stats,
1268
1595
  numeric_lab_stats,
1269
1596
  categorical_lab_stats,
1270
1597
  concept_name_mapping,
1271
1598
  pretrained_concept_embedding_model,
1272
- motor_time_to_event_codes,
1599
+ motor_task_info,
1600
+ gender_map,
1601
+ race_map,
1602
+ ontology,
1273
1603
  )
1274
1604
 
1275
1605
  @classmethod