cehrgpt 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. __init__.py +0 -0
  2. cehrgpt/__init__.py +0 -0
  3. cehrgpt/analysis/__init__.py +0 -0
  4. cehrgpt/analysis/privacy/__init__.py +0 -0
  5. cehrgpt/analysis/privacy/attribute_inference.py +275 -0
  6. cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
  7. cehrgpt/analysis/privacy/member_inference.py +172 -0
  8. cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
  9. cehrgpt/analysis/privacy/reid_inference.py +407 -0
  10. cehrgpt/analysis/privacy/utils.py +255 -0
  11. cehrgpt/cehrgpt_args.py +142 -0
  12. cehrgpt/data/__init__.py +0 -0
  13. cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
  14. cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
  15. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
  16. cehrgpt/generation/__init__.py +0 -0
  17. cehrgpt/generation/chatgpt_generation.py +106 -0
  18. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
  19. cehrgpt/generation/omop_converter_batch.py +644 -0
  20. cehrgpt/generation/omop_entity.py +515 -0
  21. cehrgpt/gpt_utils.py +331 -0
  22. cehrgpt/models/__init__.py +0 -0
  23. cehrgpt/models/config.py +205 -0
  24. cehrgpt/models/hf_cehrgpt.py +1817 -0
  25. cehrgpt/models/hf_modeling_outputs.py +158 -0
  26. cehrgpt/models/pretrained_embeddings.py +82 -0
  27. cehrgpt/models/special_tokens.py +30 -0
  28. cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
  29. cehrgpt/omop/__init__.py +0 -0
  30. cehrgpt/omop/condition_era.py +20 -0
  31. cehrgpt/omop/observation_period.py +43 -0
  32. cehrgpt/omop/omop_argparse.py +38 -0
  33. cehrgpt/omop/omop_table_builder.py +86 -0
  34. cehrgpt/omop/queries/__init__.py +0 -0
  35. cehrgpt/omop/queries/condition_era.py +86 -0
  36. cehrgpt/omop/queries/observation_period.py +135 -0
  37. cehrgpt/omop/sample_omop_tables.py +71 -0
  38. cehrgpt/runners/__init__.py +0 -0
  39. cehrgpt/runners/gpt_runner_util.py +99 -0
  40. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
  41. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
  42. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
  43. cehrgpt/runners/hyperparameter_search_util.py +223 -0
  44. cehrgpt/time_to_event/__init__.py +0 -0
  45. cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
  46. cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
  47. cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
  48. cehrgpt/time_to_event/time_to_event_model.py +226 -0
  49. cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
  50. cehrgpt/time_to_event/time_to_event_utils.py +55 -0
  51. cehrgpt/tools/__init__.py +0 -0
  52. cehrgpt/tools/ehrshot_benchmark.py +74 -0
  53. cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
  54. cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
  55. cehrgpt/tools/upload_omop_tables.py +108 -0
  56. cehrgpt-0.0.1.dist-info/LICENSE +21 -0
  57. cehrgpt-0.0.1.dist-info/METADATA +66 -0
  58. cehrgpt-0.0.1.dist-info/RECORD +60 -0
  59. cehrgpt-0.0.1.dist-info/WHEEL +5 -0
  60. cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/models/tokenization_hf_cehrgpt.py
@@ -0,0 +1,1077 @@
+ import collections
+ import copy
+ import json
+ import os
+ import pickle
+ from functools import partial
+ from itertools import islice
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+ import numpy as np
+ import scipy.stats as stats
+ import transformers
+ from cehrbert.models.hf_models.tokenization_utils import (
+     agg_helper,
+     agg_statistics,
+     load_json_file,
+ )
+ from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
+ from datasets import Dataset, DatasetDict
+ from femr.stat_utils import OnlineStatistics, ReservoirSampler
+ from scipy.interpolate import UnivariateSpline
+ from tokenizers import AddedToken, Tokenizer
+ from tokenizers.models import WordLevel
+ from tokenizers.pre_tokenizers import WhitespaceSplit
+ from tokenizers.trainers import WordLevelTrainer
+ from tqdm import tqdm
+ from transformers import PreTrainedTokenizer
+
+ from cehrgpt.gpt_utils import (
+     convert_time_interval_to_time_tuple,
+     extract_time_interval_in_days,
+     is_att_token,
+     is_inpatient_att_token,
+ )
+ from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
+ from cehrgpt.models.special_tokens import (
+     END_TOKEN,
+     OUT_OF_VOCABULARY_TOKEN,
+     PAD_TOKEN,
+     START_TOKEN,
+ )
+
+ NUM_OF_BINS = 10
+ DEGREE_OF_FREEDOM = 3
+ SAMPLE_SIZE = 10_000
+ NA = "N/A"
+ UNKNOWN_BIN = "BIN:unknown"
+ NONE_BIN = "BIN:NONE"
+ TOKENIZER_FILE_NAME = "cehrgpt_tokenizer.json"
+ VALUE_TOKENIZER_FILE_NAME = "cehrgpt_value_tokenizer.json"
+ TIME_TOKENIZER_FILE_NAME = "cehrgpt_time_tokenizer.json"
+ TOKEN_TO_SUB_TIME_TOKEN_MAPPING_FILE_NAME = "token_to_sub_time_token_mapping.json"
+ LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.pickle"
+ LEGACY_LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.json"
+ CONCEPT_MAPPING_FILE_NAME = "concept_name_mapping.json"
+
+
+ def truncated_sample(sample, standard_deviation):
+     # Keep only values within +/- standard_deviation of the empirical distribution,
+     # translating the z-bound into quantiles via the normal CDF.
+     lower_quantile = stats.norm.cdf(-standard_deviation)
+     upper_quantile = stats.norm.cdf(standard_deviation)
+     lower_bound = np.quantile(sample, lower_quantile)
+     upper_bound = np.quantile(sample, upper_quantile)
+     return [x for x in sample if lower_bound <= x <= upper_bound]
+
+
+ def is_valid_value_bin(token: str) -> bool:
+     return token.startswith("BIN:")
+
+
+ def create_value_bin(bin_index: int) -> str:
+     return "BIN:" + str(bin_index)
+
+
+ def create_sample_from_bins(bins, sample_size: int = 10_000) -> List[float]:
+     """
+     Generates a specified number of samples from a list of bins, each containing a fitted spline.
+
+     This function iterates over each bin, extracts the spline, and uses it to generate a set of samples
+     uniformly distributed along the x-axis defined by the spline's knots. The samples are distributed
+     evenly across the bins, so the total generated is `sample_size` rounded down to a multiple of the
+     number of bins.
+
+     Parameters:
+         bins (List[Dict[str, UnivariateSpline]]): A list of dictionaries, each containing a 'spline' key
+             with a UnivariateSpline object as its value. These splines define the data distribution within
+             each bin from which samples are to be generated.
+         sample_size (int, optional): The total number of samples to generate from all bins combined.
+             Defaults to 10,000.
+
+     Returns:
+         List[float]: A list of sampled values, where each value is generated based on the spline functions
+             provided in the bins.
+
+     Example:
+         >>> x = np.linspace(0, 10, 100)
+         >>> y = np.sin(x)
+         >>> spline = UnivariateSpline(x, y, s=1)
+         >>> bins = [{'spline': spline} for _ in range(5)]
+         >>> samples = create_sample_from_bins(bins, 1000)
+         >>> len(samples)
+         1000
+
+     Note:
+         The function assumes that each bin's spline has a sufficient range of x-values (knots) to allow for
+         meaningful sampling. If the range of x-values is too narrow, the uniformity of the sample distribution
+         may be affected.
+     """
+     sample = []
+     num_of_bins = len(bins)
+     if num_of_bins > 0:
+         sample_per_bin = sample_size // num_of_bins
+         for value_bin in bins:
+             bin_spline = value_bin["spline"]
+             x = np.random.uniform(
+                 bin_spline.get_knots()[0], bin_spline.get_knots()[-1], sample_per_bin
+             )
+             y = bin_spline(x)
+             sample.extend(y)
+     return sample
+
+
+ def create_bins_with_spline(samples, num_bins, d_freedom=3) -> List[Dict[str, Any]]:
+     """
+     Divides a list of numeric samples into a specified number of bins and fits a spline to the data in each bin.
+
+     This function first sorts the list of samples, then partitions the sorted list into `num_bins` bins. For each bin,
+     a UnivariateSpline is fitted to the data within the bin, using the specified spline degree. The function
+     handles edge cases by assigning infinity to the bounds of the first and last bins.
+
+     Parameters:
+         samples (List[float]): A list of sample data points, which are real numbers.
+         num_bins (int): The number of bins to divide the sample data into. It is assumed that there are enough
+             samples to at least fill the bins to the minimum required for spline fitting.
+         d_freedom (int, optional): The degree of the spline. Default is 3, which fits a cubic spline.
+
+     Returns:
+         List[Dict[str, Any]]: A list of dictionaries, each representing a bin. Each dictionary contains:
+             - 'bin_index' (int): The index of the bin.
+             - 'start_val' (float): The starting value of the bin, with the first bin starting at negative infinity.
+             - 'end_val' (float): The ending value of the bin, with the last bin ending at positive infinity.
+             - 'spline' (UnivariateSpline): The spline object fitted to the data within the bin.
+
+     Example:
+         >>> samples = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+         >>> bins = create_bins_with_spline(samples, 2)
+         >>> for b in bins:
+         ...     print(b['bin_index'], b['start_val'], b['end_val'])
+         ...
+         0 -inf 6
+         1 6 inf
+
+     Note:
+         The function assumes that each bin will have at least `d_freedom + 1` samples to fit the spline. If the
+         total number of samples is less than `num_bins * (d_freedom + 1)`, no bins are created and an empty
+         list is returned.
+     """
+     samples.sort()
+     bins = []
+     if len(samples) >= num_bins * (d_freedom + 1):
+         samples_per_bin = len(samples) // num_bins
+         for bin_index in range(0, num_bins):
+             if bin_index == 0:
+                 start_val = float("-inf")
+             else:
+                 start_val = samples[bin_index * samples_per_bin]
+
+             if bin_index == num_bins - 1:
+                 end_val = float("inf")
+             else:
+                 end_val = samples[(bin_index + 1) * samples_per_bin]
+             x = range(bin_index * samples_per_bin, (bin_index + 1) * samples_per_bin)
+             y = samples[bin_index * samples_per_bin : (bin_index + 1) * samples_per_bin]
+             spline = UnivariateSpline(x, y, k=d_freedom)
+             bins.append(
+                 {
+                     "bin_index": bin_index,
+                     "start_val": start_val,
+                     "end_val": end_val,
+                     "spline": spline,
+                 }
+             )
+     return bins
+
+
+ def map_statistics(batch: Dict[str, Any], size=10_000) -> Dict[str, Any]:
+     if "units" in batch:
+         batch_value_units = batch["units"]
+     else:
+         batch_value_units = [[NA for _ in cons] for cons in batch["concept_ids"]]
+
+     if "number_as_values" not in batch:
+         batched_number_as_values = [
+             [value if isinstance(value, float) else None for value in concept_values]
+             for concept_values in batch["concept_values"]
+         ]
+     else:
+         batched_number_as_values = batch["number_as_values"]
+
+     if "concept_as_values" not in batch:
+         batched_concept_as_values = [
+             [value if isinstance(value, str) else None for value in concept_values]
+             for concept_values in batch["concept_values"]
+         ]
+     else:
+         batched_concept_as_values = batch["concept_as_values"]
+
+     numeric_stats_by_lab = collections.defaultdict(partial(ReservoirSampler, size=size))
+     categorical_stats_by_lab = collections.defaultdict(int)
+     for (
+         concept_ids,
+         number_as_values,
+         concept_as_values,
+         concept_value_indicators,
+         units,
+     ) in zip(
+         batch["concept_ids"],
+         batched_number_as_values,
+         batched_concept_as_values,
+         batch["concept_value_masks"],
+         batch_value_units,
+     ):
+         for (
+             concept_id,
+             number_as_value,
+             concept_as_value,
+             concept_value_indicator,
+             unit,
+         ) in zip(
+             concept_ids,
+             number_as_values,
+             concept_as_values,
+             concept_value_indicators,
+             units,
+         ):
+             if concept_value_indicator == 1:
+                 if number_as_value:
+                     numeric_stats_by_lab[(concept_id, unit)].add(number_as_value, 1)
+                 if concept_as_value:
+                     categorical_stats_by_lab[(concept_id, concept_as_value)] += 1
+
+     return {
+         "numeric_stats_by_lab": numeric_stats_by_lab,
+         "categorical_stats_by_lab": categorical_stats_by_lab,
+     }
+
+
+ def create_numeric_concept_unit_mapping(
+     lab_stats: List[Dict[str, Any]]
+ ) -> Tuple[Dict[str, List[float]], Dict[str, List[str]]]:
+     numeric_concept_unit_mapping = collections.defaultdict(list)
+     for each_lab_stat in lab_stats:
+         numeric_concept_unit_mapping[each_lab_stat["concept_id"]].append(
+             (each_lab_stat["count"], each_lab_stat["unit"])
+         )
+
+     concept_prob_mapping = dict()
+     concept_unit_mapping = dict()
+     for concept_id in numeric_concept_unit_mapping.keys():
+         counts, units = zip(*numeric_concept_unit_mapping[concept_id])
+         total_count = sum(counts)
+         probs = [float(c) / total_count for c in counts]
+         concept_prob_mapping[concept_id] = probs
+         concept_unit_mapping[concept_id] = units
+     return concept_prob_mapping, concept_unit_mapping
+
+
+ class NumericEventStatistics:
+     def __init__(self, lab_stats: List[Dict[str, Any]]):
+         self._lab_stats = lab_stats
+         self._lab_stats_mapping = {
+             (lab_stat["concept_id"], lab_stat["unit"]): {
+                 "unit": lab_stat["unit"],
+                 "mean": lab_stat["mean"],
+                 "std": lab_stat["std"],
+                 "value_outlier_std": lab_stat["value_outlier_std"],
+                 "bins": lab_stat["bins"],
+             }
+             for lab_stat in lab_stats
+         }
+         self._concept_prob_mapping, self._concept_unit_mapping = (
+             create_numeric_concept_unit_mapping(lab_stats)
+         )
+
+     def get_numeric_concept_ids(self) -> List[str]:
+         return [_["concept_id"] for _ in self._lab_stats]
+
+     def get_random_unit(self, concept_id: str) -> str:
+         if concept_id in self._concept_prob_mapping:
+             unit_probs = self._concept_prob_mapping[concept_id]
+             return np.random.choice(
+                 self._concept_unit_mapping[concept_id], p=unit_probs
+             )
+         return NA
+
+     def normalize(
+         self, concept_id: str, unit: str, concept_value: Union[float, str]
+     ) -> str:
+         # Map a raw numeric value to its bin token (e.g. "BIN:3")
+         if isinstance(concept_value, float):
+             if (concept_id, unit) in self._lab_stats_mapping:
+                 concept_unit_stats = self._lab_stats_mapping[(concept_id, unit)]
+                 bins = concept_unit_stats["bins"]
+                 if bins:
+                     for each_bin in bins:
+                         if each_bin["start_val"] <= concept_value <= each_bin["end_val"]:
+                             return create_value_bin(each_bin["bin_index"])
+         return UNKNOWN_BIN
+
+     def denormalize(
+         self, concept_id: str, value_bin: str
+     ) -> Tuple[Optional[Union[float, str]], str]:
+         # Reverse of normalize: sample a unit by frequency, then draw a concrete
+         # value from the selected bin's fitted spline
+         unit = self.get_random_unit(concept_id)
+         concept_value = value_bin
+         if (
+             is_valid_value_bin(value_bin)
+             and (concept_id, unit) in self._lab_stats_mapping
+         ):
+             lab_stats = self._lab_stats_mapping[(concept_id, unit)]
+             bin_index = value_bin.split(":")[1]
+             if bin_index.isnumeric():
+                 bin_index = int(bin_index)
+                 # There are rare cases during sequence generation where bin_index could be out of range
+                 # when there are no bins for (concept_id, unit) due to the small number of values in the source data
+                 if len(lab_stats["bins"]) > bin_index:
+                     assert bin_index == lab_stats["bins"][bin_index]["bin_index"]
+                     bin_spline = lab_stats["bins"][bin_index]["spline"]
+                     x = np.random.uniform(
+                         bin_spline.get_knots()[0], bin_spline.get_knots()[-1]
+                     )
+                     concept_value = bin_spline(x).item()
+         return concept_value, unit
+
+
+ class CehrGptTokenizer(PreTrainedTokenizer):
+
+     def __init__(
+         self,
+         tokenizer: Tokenizer,
+         value_tokenizer: Tokenizer,
+         att_tokenizer: Tokenizer,
+         token_to_sub_time_token_mapping: Dict[str, List[str]],
+         numeric_lab_stats: List[Dict[str, Any]],
+         categorical_lab_stats: Dict[Tuple[str, str], int],
+         concept_name_mapping: Dict[str, str],
+         pretrained_concept_embedding_model: PretrainedEmbeddings = None,
+     ):
+         self._tokenizer = tokenizer
+         self._value_tokenizer = value_tokenizer
+         self._att_tokenizer = att_tokenizer
+         self._token_to_sub_time_token_mapping = token_to_sub_time_token_mapping
+         self._numeric_lab_stats = numeric_lab_stats
+         self._numeric_event_statistics = NumericEventStatistics(numeric_lab_stats)
+         self._categorical_lab_stats = categorical_lab_stats
+         self._concept_name_mapping = concept_name_mapping
+         self._oov_token_id = self._tokenizer.token_to_id(OUT_OF_VOCABULARY_TOKEN)
+         self._padding_token_id = self._tokenizer.token_to_id(PAD_TOKEN)
+         self._start_token_id = self._tokenizer.token_to_id(START_TOKEN)
+         self._end_token_id = self._tokenizer.token_to_id(END_TOKEN)
+         self._numeric_concept_ids = (
+             self._numeric_event_statistics.get_numeric_concept_ids()
+         )
+         self._categorical_concept_ids = list(
+             {t[0] for t in self._categorical_lab_stats.keys()}
+         )
+         self._padding_value_token_id = self._value_tokenizer.token_to_id(PAD_TOKEN)
+         self._pretrained_concept_embedding_model = (
+             pretrained_concept_embedding_model
+             if pretrained_concept_embedding_model
+             else PretrainedEmbeddings(None)
+         )
+         self._pretrained_concept_ids = [
+             _
+             for _ in self.get_vocab().keys()
+             if self._pretrained_concept_embedding_model.is_concept_available(_)
+         ]
+
+         super().__init__()
+
+     @property
+     def pretrained_concept_ids(self):
+         return self._pretrained_concept_ids
+
+     @property
+     def pretrained_token_ids(self):
+         return self.encode(self._pretrained_concept_ids)
+
+     @property
+     def pretrained_embeddings(self):
+         return np.asarray(
+             [
+                 self._pretrained_concept_embedding_model.get_concept_embeddings(_)
+                 for _ in self._pretrained_concept_ids
+             ]
+         )
+
+     @property
+     def vocab_size(self) -> int:
+         return self._tokenizer.get_vocab_size()
+
+     @property
+     def value_vocab_size(self) -> int:
+         return self._value_tokenizer.get_vocab_size()
+
+     @property
+     def time_token_vocab_size(self) -> int:
+         return self._att_tokenizer.get_vocab_size()
+
+     @property
+     def pad_value_token_id(self):
+         return self._padding_value_token_id
+
+     @property
+     def start_token_id(self):
+         return self._start_token_id
+
+     @property
+     def end_token_id(self):
+         return self._end_token_id
+
+     @property
+     def end_token(self):
+         return END_TOKEN
+
+     @property
+     def eos_token(self):
+         return END_TOKEN
+
+     @property
+     def eos_token_id(self):
+         return self._end_token_id
+
+     @property
+     def pad_token_id(self):
+         return self._padding_token_id
+
+     @property
+     def pad_token(self):
+         return PAD_TOKEN
+
+     @property
+     def numeric_concept_ids(self):
+         return self._numeric_concept_ids
+
+     @property
+     def categorical_concept_ids(self):
+         return self._categorical_concept_ids
+
+     @property
+     def lab_token_ids(self):
+         reserved_tokens = [START_TOKEN, PAD_TOKEN, END_TOKEN, OUT_OF_VOCABULARY_TOKEN]
+         return self.encode(
+             [
+                 concept_id
+                 for concept_id in self._numeric_concept_ids
+                 + self._categorical_concept_ids
+                 if concept_id not in reserved_tokens
+             ]
+         )
+
+     @property
+     def token_to_time_token_mapping(self) -> Dict[int, List[int]]:
+         default_mapping = {-1: [0, 0, 0]}
+         default_mapping.update(
+             {
+                 self._tokenizer.token_to_id(time_token): list(
+                     map(self._att_tokenizer.token_to_id, sub_time_tokens)
+                 )
+                 for time_token, sub_time_tokens in self._token_to_sub_time_token_mapping.items()
+             }
+         )
+         return default_mapping
+
+     @property
+     def pretrained_concept_embedding_model(self):
+         return self._pretrained_concept_embedding_model
+
+     def get_vocab(self) -> Dict[str, int]:
+         return self._tokenizer.get_vocab()
+
+     def get_value_vocab(self) -> Dict[str, int]:
+         return self._value_tokenizer.get_vocab()
+
+     def encode(self, concept_ids, **kwargs) -> Sequence[int]:
+         encoded = self._tokenizer.encode(concept_ids, is_pretokenized=True)
+         return encoded.ids
+
+     def decode(
+         self, concept_token_ids: List[int], skip_special_tokens: bool = True, **kwargs
+     ) -> List[str]:
+         return self._tokenizer.decode(
+             concept_token_ids, skip_special_tokens=skip_special_tokens
+         ).split(" ")
+
+     def encode_value(self, concept_values: Sequence[str]) -> Sequence[int]:
+         encoded = self._value_tokenizer.encode(concept_values, is_pretokenized=True)
+         return encoded.ids
+
+     def decode_value(
+         self, concept_value_token_ids: List[int], skip_special_tokens: bool = True
+     ) -> List[str]:
+         return self._value_tokenizer.decode(
+             concept_value_token_ids, skip_special_tokens=skip_special_tokens
+         ).split(" ")
+
+     def _convert_token_to_id(self, token):
+         """Converts a token (str) into an id using the vocab."""
+         token_id = self._tokenizer.token_to_id(token)
+         # token_to_id returns None for unknown tokens; an explicit None check keeps
+         # a legitimate token id of 0 from being treated as out-of-vocabulary
+         return token_id if token_id is not None else self._oov_token_id
+
+     def _convert_id_to_token(self, index):
+         """Converts an index (integer) into a token (str) using the vocab."""
+         token = self._tokenizer.id_to_token(index)
+         return token if token else OUT_OF_VOCABULARY_TOKEN
+
+     def convert_tokens_to_string(self, tokens):
+         """Converts a sequence of tokens (strings) into a single string."""
+         out_string = " ".join([self._concept_name_mapping[t] for t in tokens])
+         return out_string
+
+     def save_pretrained(
+         self,
+         save_directory: Union[str, os.PathLike],
+         push_to_hub: bool = False,
+         **kwargs,
+     ):
+         """
+         Save the CEHR-GPT tokenizer.
+
+         This method makes sure the tokenizer can be re-loaded using the
+         `.from_pretrained` class method.
+
+         Args:
+             save_directory (`str` or `os.PathLike`): The path to a directory where the tokenizer will be saved.
+             push_to_hub (`bool`, *optional*, defaults to `False`):
+                 Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                 namespace).
+             kwargs (`Dict[str, Any]`, *optional*):
+                 Additional keyword arguments passed along to the [`PushToHubMixin.push_to_hub`] method.
+         """
+         assert not os.path.isfile(
+             save_directory
+         ), f"Provided path ({save_directory}) should be a directory, not a file"
+
+         os.makedirs(save_directory, exist_ok=True)
+
+         if push_to_hub:
+             commit_message = kwargs.pop("commit_message", None)
+             repo_id = kwargs.pop("repo_id", str(save_directory).split(os.path.sep)[-1])
+             repo_id = self._create_repo(repo_id, **kwargs)
+             files_timestamps = self._get_files_timestamps(save_directory)
+
+         self._tokenizer.save(os.path.join(save_directory, TOKENIZER_FILE_NAME))
+
+         self._value_tokenizer.save(
+             os.path.join(save_directory, VALUE_TOKENIZER_FILE_NAME)
+         )
+
+         self._att_tokenizer.save(os.path.join(save_directory, TIME_TOKENIZER_FILE_NAME))
+
+         with open(
+             os.path.join(save_directory, TOKEN_TO_SUB_TIME_TOKEN_MAPPING_FILE_NAME), "w"
+         ) as f:
+             json.dump(self._token_to_sub_time_token_mapping, f)
+
+         with open(os.path.join(save_directory, LAB_STATS_FILE_NAME), "wb") as f:
+             lab_stats = {
+                 "numeric_lab_stats": self._numeric_lab_stats,
+                 "categorical_lab_stats": self._categorical_lab_stats,
+             }
+             pickle.dump(lab_stats, f)
+
+         with open(os.path.join(save_directory, CONCEPT_MAPPING_FILE_NAME), "w") as f:
+             json.dump(self._concept_name_mapping, f)
+
+         self._pretrained_concept_embedding_model.save(save_directory)
+
+         if push_to_hub:
+             self._upload_modified_files(
+                 save_directory,
+                 repo_id,
+                 files_timestamps,
+                 commit_message=commit_message,
+                 token=kwargs.get("token"),
+             )
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         pretrained_model_name_or_path: Union[str, os.PathLike],
+         **kwargs,
+     ):
+         """
+         Load the CEHR-GPT tokenizer.
+
+         Parameters:
+             pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                 Can be either:
+                 - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                   Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
+                   user or organization name, like `dbmdz/bert-base-german-cased`.
+                 - A path to a *directory* containing tokenization data saved using
+                   [`save_pretrained`], e.g., `./my_data_directory/`.
+             kwargs: Arguments for loading, passed to transformers.utils.hub.cached_file
+
+         Returns:
+             A CehrGptTokenizer
+         """
+
+         is_legacy_tokenizer = CehrGptTokenizer.is_legacy_tokenizer(
+             pretrained_model_name_or_path, **kwargs
+         )
+
+         # Load the concept tokenizer
+         tokenizer_file = transformers.utils.hub.cached_file(
+             pretrained_model_name_or_path, TOKENIZER_FILE_NAME, **kwargs
+         )
+         if not tokenizer_file:
+             return None
+         tokenizer = Tokenizer.from_file(tokenizer_file)
+
+         # Load the concept_value_tokenizer
+         if is_legacy_tokenizer:
+             value_tokenizer = Tokenizer(
+                 WordLevel(unk_token=OUT_OF_VOCABULARY_TOKEN, vocab=dict())
+             )
+         else:
+             value_tokenizer_file = transformers.utils.hub.cached_file(
+                 pretrained_model_name_or_path, VALUE_TOKENIZER_FILE_NAME, **kwargs
+             )
+             if not value_tokenizer_file:
+                 return None
+             value_tokenizer = Tokenizer.from_file(value_tokenizer_file)
+
+         # Load the time-token (att) tokenizer
+         att_tokenizer_file = transformers.utils.hub.cached_file(
+             pretrained_model_name_or_path, TIME_TOKENIZER_FILE_NAME, **kwargs
+         )
+         if not att_tokenizer_file:
+             return None
+         att_tokenizer = Tokenizer.from_file(att_tokenizer_file)
+
+         # Load the sub time token json file
+         token_to_sub_time_token_mapping_file = transformers.utils.hub.cached_file(
+             pretrained_model_name_or_path,
+             TOKEN_TO_SUB_TIME_TOKEN_MAPPING_FILE_NAME,
+             **kwargs,
+         )
+         if not token_to_sub_time_token_mapping_file:
+             return None
+         token_to_sub_time_token_mapping = load_json_file(
+             token_to_sub_time_token_mapping_file
+         )
+
+         # Load the lab stats pickle file
+         if is_legacy_tokenizer:
+             legacy_lab_stats_file = transformers.utils.hub.cached_file(
+                 pretrained_model_name_or_path, LEGACY_LAB_STATS_FILE_NAME, **kwargs
+             )
+             if not legacy_lab_stats_file:
+                 return None
+             # Support the old version of the numeric lab stats file
+             lab_stats = {
+                 "numeric_lab_stats": load_json_file(legacy_lab_stats_file),
+                 "categorical_lab_stats": dict(),
+             }
+         else:
+             lab_stats_file = transformers.utils.hub.cached_file(
+                 pretrained_model_name_or_path, LAB_STATS_FILE_NAME, **kwargs
+             )
+             if not lab_stats_file:
+                 return None
+
+             with open(lab_stats_file, "rb") as file:
+                 lab_stats = pickle.load(file)
+
+         # Load the concept_name json file
+         concept_name_mapping_file = transformers.utils.hub.cached_file(
+             pretrained_model_name_or_path, CONCEPT_MAPPING_FILE_NAME, **kwargs
+         )
+         if not concept_name_mapping_file:
+             return None
+         concept_name_mapping = load_json_file(concept_name_mapping_file)
+
+         pretrained_embedding_model = PretrainedEmbeddings(pretrained_model_name_or_path)
+
+         return CehrGptTokenizer(
+             tokenizer,
+             value_tokenizer,
+             att_tokenizer,
+             token_to_sub_time_token_mapping,
+             lab_stats["numeric_lab_stats"],
+             lab_stats["categorical_lab_stats"],
+             concept_name_mapping,
+             pretrained_embedding_model,
+         )
+
+     @classmethod
+     def is_legacy_tokenizer(
+         cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+     ):
+         try:
+             legacy_lab_stats_file = transformers.utils.hub.cached_file(
+                 pretrained_model_name_or_path, LEGACY_LAB_STATS_FILE_NAME, **kwargs
+             )
+             return legacy_lab_stats_file is not None
+         except Exception:
+             return False
+
+     @classmethod
+     def expand_trained_tokenizer(
+         cls,
+         cehrgpt_tokenizer,
+         dataset: Union[Dataset, DatasetDict],
+         concept_name_mapping: Dict[str, str],
+         data_args: DataTrainingArguments,
+         pretrained_concept_embedding_model: PretrainedEmbeddings = None,
+     ):
+         if not isinstance(cehrgpt_tokenizer, CehrGptTokenizer):
+             raise ValueError(
+                 "The existing cehrgpt_tokenizer must be an instance of CehrGptTokenizer"
+             )
+
+         cehrgpt_tokenizer_copy = copy.deepcopy(cehrgpt_tokenizer)
+
+         new_tokenizer = CehrGptTokenizer.train_tokenizer(
+             dataset=dataset,
+             concept_name_mapping=concept_name_mapping,
+             data_args=data_args,
+         )
+
+         new_tokens = set(new_tokenizer.get_vocab().keys()) - set(
+             cehrgpt_tokenizer_copy.get_vocab().keys()
+         )
+         new_value_tokens = set(new_tokenizer.get_value_vocab().keys()) - set(
+             cehrgpt_tokenizer_copy.get_value_vocab().keys()
+         )
+         new_att_tokens = set(new_tokenizer._att_tokenizer.get_vocab().keys()) - set(
+             cehrgpt_tokenizer_copy._att_tokenizer.get_vocab().keys()
+         )
+         new_token_to_sub_time_token_mapping = (
+             new_tokenizer._token_to_sub_time_token_mapping
+         )
+         new_numeric_lab_stats = new_tokenizer._numeric_lab_stats
+         new_categorical_lab_stats = new_tokenizer._categorical_lab_stats
+         new_concept_name_mapping = new_tokenizer._concept_name_mapping
+
+         # Add new tokens to the existing tokenizer
+         cehrgpt_tokenizer_copy._tokenizer.add_tokens(
+             [
+                 AddedToken(token, single_word=True, normalized=False)
+                 for token in new_tokens
+             ]
+         )
+         # Add new tokens to the existing value tokenizer
+         cehrgpt_tokenizer_copy._value_tokenizer.add_tokens(
+             [
+                 AddedToken(token, single_word=True, normalized=False)
+                 for token in new_value_tokens
+             ]
+         )
+         # Add new time tokens to the existing att tokenizer
+         cehrgpt_tokenizer_copy._att_tokenizer.add_tokens(
+             [
+                 AddedToken(token, single_word=True, normalized=False)
+                 for token in new_att_tokens
+             ]
+         )
+         # Merge the time_token -> List[sub_time_tokens] mapping
+         for time_token, sub_time_tokens in new_token_to_sub_time_token_mapping.items():
+             if time_token not in cehrgpt_tokenizer_copy._token_to_sub_time_token_mapping:
+                 cehrgpt_tokenizer_copy._token_to_sub_time_token_mapping[time_token] = (
+                     sub_time_tokens
+                 )
+
+         # Merge numeric lab_stats
+         cehrgpt_tokenizer_copy._numeric_lab_stats = cls.merge_numeric_lab_stats(
+             cehrgpt_tokenizer_copy._numeric_lab_stats,
+             new_numeric_lab_stats,
+         )
+         # Merge categorical lab_stats
+         cehrgpt_tokenizer_copy._categorical_lab_stats = cls.merge_categorical_lab_stats(
+             cehrgpt_tokenizer_copy._categorical_lab_stats,
+             new_categorical_lab_stats,
+         )
+
+         # Merge concept_name_mapping
+         for token, concept_name in new_concept_name_mapping.items():
+             if token not in cehrgpt_tokenizer_copy._concept_name_mapping:
+                 cehrgpt_tokenizer_copy._concept_name_mapping[token] = concept_name
+
+         return CehrGptTokenizer(
+             tokenizer=cehrgpt_tokenizer_copy._tokenizer,
+             value_tokenizer=cehrgpt_tokenizer_copy._value_tokenizer,
+             att_tokenizer=cehrgpt_tokenizer_copy._att_tokenizer,
+             token_to_sub_time_token_mapping=cehrgpt_tokenizer_copy._token_to_sub_time_token_mapping,
+             numeric_lab_stats=cehrgpt_tokenizer_copy._numeric_lab_stats,
+             categorical_lab_stats=cehrgpt_tokenizer_copy._categorical_lab_stats,
+             concept_name_mapping=cehrgpt_tokenizer_copy._concept_name_mapping,
+             pretrained_concept_embedding_model=pretrained_concept_embedding_model,
+         )
+
+     @classmethod
+     def merge_numeric_lab_stats(
+         cls,
+         lab_stats_existing: List[Dict[str, Any]],
+         lab_stats_new: List[Dict[str, Any]],
+     ) -> List[Dict[str, Any]]:
+
+         lab_stats_existing_mapping = {
+             (lab_stat["concept_id"], lab_stat["unit"]): lab_stat
+             for lab_stat in lab_stats_existing
+         }
+         for lab_stat in lab_stats_new:
+             concept_unit_pair = (lab_stat["concept_id"], lab_stat["unit"])
+             if concept_unit_pair in lab_stats_existing_mapping:
+                 existing = OnlineStatistics()
+                 existing.count = lab_stats_existing_mapping[concept_unit_pair]["count"]
+                 existing.current_mean = lab_stats_existing_mapping[concept_unit_pair][
+                     "mean"
+                 ]
+                 existing.variance = (
+                     lab_stats_existing_mapping[concept_unit_pair]["std"] ** 2
+                     * existing.count
+                 )
+                 new = OnlineStatistics()
+                 new.count = lab_stat["count"]
+                 new.current_mean = lab_stat["mean"]
+                 new.variance = lab_stat["std"] ** 2 * new.count
+                 existing.combine(new)
+                 lab_stats_existing_mapping[concept_unit_pair]["mean"] = existing.mean()
+                 lab_stats_existing_mapping[concept_unit_pair][
+                     "std"
+                 ] = existing.standard_deviation()
+                 lab_stats_existing_mapping[concept_unit_pair]["count"] = existing.count
+                 # recreate the bins
+                 sample = create_sample_from_bins(
+                     lab_stats_existing_mapping[concept_unit_pair]["bins"]
+                 )
+                 sample.extend(create_sample_from_bins(lab_stat["bins"]))
+                 lab_stats_existing_mapping[concept_unit_pair]["bins"] = (
+                     create_bins_with_spline(sample, NUM_OF_BINS, DEGREE_OF_FREEDOM)
+                 )
+             else:
+                 if lab_stat["count"] > 0:
+                     lab_stats_existing_mapping[concept_unit_pair] = lab_stat
+
+         return list(lab_stats_existing_mapping.values())
+
+     @classmethod
+     def merge_categorical_lab_stats(
+         cls,
+         categorical_lab_stats_existing: Dict[Tuple[str, str], int],
+         categorical_lab_stats_new: Dict[Tuple[str, str], int],
+     ) -> Dict[Tuple[str, str], int]:
+         for (concept_id, concept_as_value), count in categorical_lab_stats_new.items():
+             # Check membership in the existing stats (not the new ones) so unseen
+             # pairs are initialized before the counts are accumulated
+             if (concept_id, concept_as_value) not in categorical_lab_stats_existing:
+                 categorical_lab_stats_existing[(concept_id, concept_as_value)] = 0
+             categorical_lab_stats_existing[(concept_id, concept_as_value)] += count
+         return categorical_lab_stats_existing
+
+     @classmethod
+     def train_tokenizer(
+         cls,
+         dataset: Union[Dataset, DatasetDict],
+         concept_name_mapping: Dict[str, str],
+         data_args: DataTrainingArguments,
+         pretrained_concept_embedding_model: PretrainedEmbeddings = None,
+     ):
+         """
+         Train a Hugging Face word-level tokenizer.
+
+         To use the Hugging Face tokenizer, we need to concatenate all the concepts
+         together and treat them as a single sequence.
+         """
+
+         if isinstance(dataset, DatasetDict):
+             dataset = dataset["train"]
+
+         concept_tokenizer = cls.train_concept_tokenizer(
+             dataset,
+             feature_name="concept_ids",
+             special_tokens=[PAD_TOKEN, OUT_OF_VOCABULARY_TOKEN, START_TOKEN, END_TOKEN],
+             unk_token=OUT_OF_VOCABULARY_TOKEN,
+             data_args=data_args,
+         )
+         # Inspect only the first row to decide which column stores concept values
+         concept_value_column = "concept_as_values"
+         for row in dataset:
+             if concept_value_column not in row:
+                 concept_value_column = "concept_values"
+             break
+         value_tokenizer = cls.train_concept_tokenizer(
+             dataset,
+             feature_name=concept_value_column,
+             special_tokens=[OUT_OF_VOCABULARY_TOKEN, PAD_TOKEN],
+             unk_token=OUT_OF_VOCABULARY_TOKEN,
+             data_args=data_args,
+         )
+         value_tokenizer.add_tokens(
+             [
+                 AddedToken(_, single_word=True, normalized=False)
+                 for _ in [create_value_bin(_) for _ in range(NUM_OF_BINS)]
+                 + [UNKNOWN_BIN, NONE_BIN]
+             ]
+         )
+
+         map_statistics_partial = partial(map_statistics, size=SAMPLE_SIZE)
+
+         if data_args.streaming:
+             parts = dataset.map(
+                 partial(agg_helper, map_func=map_statistics_partial),
+                 batched=True,
+                 batch_size=data_args.preprocessing_batch_size,
+                 new_fingerprint="invalid",
+                 remove_columns=dataset.column_names,
+             )
+         else:
+             parts = dataset.map(
+                 partial(agg_helper, map_func=map_statistics_partial),
+                 batched=True,
+                 batch_size=data_args.preprocessing_batch_size,
+                 remove_columns=dataset.column_names,
+                 num_proc=data_args.preprocessing_num_workers,
+                 keep_in_memory=True,
+                 new_fingerprint="invalid",
+             )
+         current = None
+         for stat in tqdm(parts, desc="Aggregating the lab statistics"):
+             fixed_stat = pickle.loads(stat["data"])
+             if current is None:
+                 current = fixed_stat
+             else:
+                 current = agg_statistics(current, fixed_stat)
+
+         numeric_lab_stats = []
+         for (concept_id, unit), online_stats in current["numeric_stats_by_lab"].items():
+             if len(online_stats.samples) == 0:
+                 continue
+             samples = truncated_sample(online_stats.samples, data_args.value_outlier_std)
+             bins = create_bins_with_spline(samples, NUM_OF_BINS, DEGREE_OF_FREEDOM)
+             if len(bins) > 0:
+                 numeric_lab_stats.append(
+                     {
+                         "concept_id": concept_id,
+                         "unit": unit,
+                         "mean": np.mean(samples),
+                         "std": np.std(samples),
+                         "count": len(online_stats.samples),
+                         "value_outlier_std": data_args.value_outlier_std,
+                         "bins": bins,
+                     }
+                 )
+
+         categorical_lab_stats = collections.defaultdict(int)
+         for (concept_id, value_as_concept), count in current[
+             "categorical_stats_by_lab"
+         ].items():
+             categorical_lab_stats[(concept_id, value_as_concept)] += count
+
+         # We will train a tokenizer specifically for time intervals
+         sub_time_token_data = []
+         token_to_sub_time_token_mapping = collections.defaultdict(list)
+         for token, token_id in concept_tokenizer.get_vocab().items():
+             if is_att_token(token):
+                 time_interval = extract_time_interval_in_days(token)
+                 time_tuple = convert_time_interval_to_time_tuple(
+                     time_interval, is_inpatient_att_token(token)
+                 )
+                 token_to_sub_time_token_mapping[token] = list(time_tuple)
+                 sub_time_token_data.append(" ".join(time_tuple))
+
+         att_tokenizer = Tokenizer(
+             WordLevel(unk_token=OUT_OF_VOCABULARY_TOKEN, vocab=dict())
+         )
+         att_tokenizer.pre_tokenizer = WhitespaceSplit()
+         att_trainer = WordLevelTrainer(
+             special_tokens=[OUT_OF_VOCABULARY_TOKEN],
+             vocab_size=data_args.vocab_size,
+             min_frequency=0,
+             show_progress=True,
+         )
+         att_tokenizer.train_from_iterator(sub_time_token_data, trainer=att_trainer)
+
+         return CehrGptTokenizer(
+             concept_tokenizer,
+             value_tokenizer,
+             att_tokenizer,
+             token_to_sub_time_token_mapping,
+             numeric_lab_stats,
+             categorical_lab_stats,
+             concept_name_mapping,
+             pretrained_concept_embedding_model,
+         )
+
+     @classmethod
+     def train_concept_tokenizer(
+         cls,
+         dataset,
+         feature_name,
+         special_tokens: List[str],
+         unk_token,
+         data_args,
+     ):
+         # Use the Fast Tokenizer from the Huggingface tokenizers Rust implementation.
+         # https://github.com/huggingface/tokenizers
+         concept_tokenizer = Tokenizer(WordLevel(unk_token=unk_token, vocab=dict()))
+         concept_tokenizer.pre_tokenizer = WhitespaceSplit()
+         concept_trainer = WordLevelTrainer(
+             special_tokens=special_tokens,
+             vocab_size=data_args.vocab_size,
+             min_frequency=data_args.min_frequency,
+             show_progress=True,
+         )
+         batch_concat_concepts_partial_func = partial(
+             cls.batch_concat_concepts, feature_name=feature_name
+         )
+         if data_args.streaming:
+             concatenated_features = dataset.map(
+                 batch_concat_concepts_partial_func,
+                 batched=True,
+                 batch_size=data_args.preprocessing_batch_size,
+             )
+
+             def batched_generator():
+                 iterator = iter(concatenated_features)
+                 while True:
+                     batch = list(islice(iterator, data_args.preprocessing_batch_size))
+                     if not batch:
+                         break
+                     yield [example[feature_name] for example in batch]
+
+             # We pass a generator of list of texts (concatenated concept_ids) to train_from_iterator
+             # for efficient training
+             generator = batched_generator()
+         else:
+             concatenated_features = dataset.map(
+                 batch_concat_concepts_partial_func,
+                 num_proc=data_args.preprocessing_num_workers,
+                 batched=True,
+                 batch_size=data_args.preprocessing_batch_size,
+                 remove_columns=dataset.column_names,
+             )
+             generator = concatenated_features[feature_name]
+         concept_tokenizer.train_from_iterator(generator, trainer=concept_trainer)
+         return concept_tokenizer
+
+     def normalize(self, concept_id: str, unit: str, concept_value: float) -> str:
+         return self._numeric_event_statistics.normalize(concept_id, unit, concept_value)
+
+     def denormalize(self, concept_id: str, value_bin: str) -> Tuple[float, str]:
+         return self._numeric_event_statistics.denormalize(concept_id, value_bin)
+
+     @classmethod
+     def batch_concat_concepts(
+         cls, records: Dict[str, List], feature_name
+     ) -> Dict[str, List]:
+         return {
+             feature_name: [
+                 " ".join(
+                     [token for token in tokens if token and isinstance(token, str)]
+                 )
+                 for tokens in records[feature_name]
+             ]
+         }
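
For orientation, here is a minimal sketch (not part of the package) of the value-binning round trip implemented above by truncated_sample, create_bins_with_spline, and NumericEventStatistics. The concept id "4548-4", the unit "%", and the synthetic samples are illustrative assumptions only.

    import numpy as np

    from cehrgpt.models.tokenization_hf_cehrgpt import (
        NumericEventStatistics,
        create_bins_with_spline,
        truncated_sample,
    )

    # Synthetic lab values standing in for one (concept_id, unit) distribution.
    samples = np.random.normal(loc=5.5, scale=0.8, size=1_000).tolist()

    # Mirror the training pipeline: trim outliers beyond 3 standard deviations,
    # then fit 10 spline-backed bins over the trimmed sample.
    trimmed = truncated_sample(samples, 3.0)
    bins = create_bins_with_spline(trimmed, num_bins=10, d_freedom=3)

    event_stats = NumericEventStatistics(
        [
            {
                "concept_id": "4548-4",  # hypothetical concept id
                "unit": "%",             # hypothetical unit
                "mean": float(np.mean(trimmed)),
                "std": float(np.std(trimmed)),
                "count": len(trimmed),
                "value_outlier_std": 3.0,
                "bins": bins,
            }
        ]
    )

    bin_token = event_stats.normalize("4548-4", "%", 6.1)       # e.g. "BIN:7"
    value, unit = event_stats.denormalize("4548-4", bin_token)  # value re-sampled from that bin
    print(bin_token, value, unit)

Because denormalize draws a fresh value from the selected bin's spline, the round trip recovers a plausible value from the same bin rather than the original measurement.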
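
And a usage sketch for training, saving, and re-loading the tokenizer itself. Everything here is an assumption for illustration: the tiny in-memory dataset, the concept ids and names, and the SimpleNamespace standing in for cehrbert's DataTrainingArguments (only the attributes train_tokenizer actually reads are supplied); a real run would use a full OMOP-derived dataset and the real dataclass.

    from types import SimpleNamespace

    from datasets import Dataset

    from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer

    # Hypothetical stand-in for cehrbert's DataTrainingArguments: only the
    # attributes read by train_tokenizer / train_concept_tokenizer are supplied.
    data_args = SimpleNamespace(
        vocab_size=50_000,
        min_frequency=0,
        streaming=False,
        preprocessing_batch_size=1_000,
        preprocessing_num_workers=None,  # single process
        value_outlier_std=3.0,
    )

    # A toy dataset with the columns map_statistics expects.
    dataset = Dataset.from_dict(
        {
            "concept_ids": [["[VS]", "320128", "4548-4", "[VE]"]],
            "concept_value_masks": [[0, 0, 1, 0]],
            "concept_values": [[-1.0, -1.0, 6.1, -1.0]],
        }
    )
    concept_name_mapping = {
        "320128": "Essential hypertension",  # hypothetical names
        "4548-4": "Hemoglobin A1c",
    }

    tokenizer = CehrGptTokenizer.train_tokenizer(dataset, concept_name_mapping, data_args)
    tokenizer.save_pretrained("./cehrgpt_tokenizer")

    reloaded = CehrGptTokenizer.from_pretrained("./cehrgpt_tokenizer")
    token_ids = reloaded.encode(["[VS]", "320128", "4548-4", "[VE]"])
    print(token_ids, reloaded.decode(token_ids))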