cehrgpt-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- __init__.py +0 -0
- cehrgpt/__init__.py +0 -0
- cehrgpt/analysis/__init__.py +0 -0
- cehrgpt/analysis/privacy/__init__.py +0 -0
- cehrgpt/analysis/privacy/attribute_inference.py +275 -0
- cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
- cehrgpt/analysis/privacy/member_inference.py +172 -0
- cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
- cehrgpt/analysis/privacy/reid_inference.py +407 -0
- cehrgpt/analysis/privacy/utils.py +255 -0
- cehrgpt/cehrgpt_args.py +142 -0
- cehrgpt/data/__init__.py +0 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
- cehrgpt/generation/__init__.py +0 -0
- cehrgpt/generation/chatgpt_generation.py +106 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
- cehrgpt/generation/omop_converter_batch.py +644 -0
- cehrgpt/generation/omop_entity.py +515 -0
- cehrgpt/gpt_utils.py +331 -0
- cehrgpt/models/__init__.py +0 -0
- cehrgpt/models/config.py +205 -0
- cehrgpt/models/hf_cehrgpt.py +1817 -0
- cehrgpt/models/hf_modeling_outputs.py +158 -0
- cehrgpt/models/pretrained_embeddings.py +82 -0
- cehrgpt/models/special_tokens.py +30 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
- cehrgpt/omop/__init__.py +0 -0
- cehrgpt/omop/condition_era.py +20 -0
- cehrgpt/omop/observation_period.py +43 -0
- cehrgpt/omop/omop_argparse.py +38 -0
- cehrgpt/omop/omop_table_builder.py +86 -0
- cehrgpt/omop/queries/__init__.py +0 -0
- cehrgpt/omop/queries/condition_era.py +86 -0
- cehrgpt/omop/queries/observation_period.py +135 -0
- cehrgpt/omop/sample_omop_tables.py +71 -0
- cehrgpt/runners/__init__.py +0 -0
- cehrgpt/runners/gpt_runner_util.py +99 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
- cehrgpt/runners/hyperparameter_search_util.py +223 -0
- cehrgpt/time_to_event/__init__.py +0 -0
- cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
- cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
- cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
- cehrgpt/time_to_event/time_to_event_model.py +226 -0
- cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
- cehrgpt/time_to_event/time_to_event_utils.py +55 -0
- cehrgpt/tools/__init__.py +0 -0
- cehrgpt/tools/ehrshot_benchmark.py +74 -0
- cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
- cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
- cehrgpt/tools/upload_omop_tables.py +108 -0
- cehrgpt-0.0.1.dist-info/LICENSE +21 -0
- cehrgpt-0.0.1.dist-info/METADATA +66 -0
- cehrgpt-0.0.1.dist-info/RECORD +60 -0
- cehrgpt-0.0.1.dist-info/WHEEL +5 -0
- cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,1077 @@
import collections
import copy
import json
import os
import pickle
from functools import partial
from itertools import islice
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import scipy.stats as stats
import transformers
from cehrbert.models.hf_models.tokenization_utils import (
    agg_helper,
    agg_statistics,
    load_json_file,
)
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
from datasets import Dataset, DatasetDict
from femr.stat_utils import OnlineStatistics, ReservoirSampler
from scipy.interpolate import UnivariateSpline
from tokenizers import AddedToken, Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.trainers import WordLevelTrainer
from tqdm import tqdm
from transformers import PreTrainedTokenizer

from cehrgpt.gpt_utils import (
    convert_time_interval_to_time_tuple,
    extract_time_interval_in_days,
    is_att_token,
    is_inpatient_att_token,
)
from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
from cehrgpt.models.special_tokens import (
    END_TOKEN,
    OUT_OF_VOCABULARY_TOKEN,
    PAD_TOKEN,
    START_TOKEN,
)

NUM_OF_BINS = 10
DEGREE_OF_FREEDOM = 3
SAMPLE_SIZE = 10_000
NA = "N/A"
UNKNOWN_BIN = "BIN:unknown"
NONE_BIN = "BIN:NONE"
TOKENIZER_FILE_NAME = "cehrgpt_tokenizer.json"
VALUE_TOKENIZER_FILE_NAME = "cehrgpt_value_tokenizer.json"
TIME_TOKENIZER_FILE_NAME = "cehrgpt_time_tokenizer.json"
TOKEN_TO_SUB_TIME_TOKEN_MAPPING_FILE_NAME = "token_to_sub_time_token_mapping.json"
LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.pickle"
LEGACY_LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.json"
CONCEPT_MAPPING_FILE_NAME = "concept_name_mapping.json"


def truncated_sample(sample, standard_deviation):
    """Drop values outside the empirical quantile band implied by +/- `standard_deviation`."""
    lower_quantile = stats.norm.cdf(-standard_deviation)
    upper_quantile = stats.norm.cdf(standard_deviation)
    lower_bound = np.quantile(sample, lower_quantile)
    upper_bound = np.quantile(sample, upper_quantile)
    return [x for x in sample if lower_bound <= x <= upper_bound]
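
# Illustrative sketch: trimming a sample at +/-2 standard deviations keeps the
# values between the empirical ~2.3rd and ~97.7th percentiles (norm.cdf(-2) and
# norm.cdf(2)), which is how lab-value outliers are removed before bin fitting
# further down in this file.
#
#   demo = list(np.random.normal(loc=0.0, scale=1.0, size=1_000))
#   trimmed = truncated_sample(demo, 2.0)
#   assert len(trimmed) <= len(demo)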


def is_valid_value_bin(token: str) -> bool:
    return token.startswith("BIN:")


def create_value_bin(bin_index: int) -> str:
    return "BIN:" + str(bin_index)


def create_sample_from_bins(bins, sample_size: int = 10_000) -> List[float]:
    """
    Generates a specified number of samples from a list of bins, each containing a fitted spline.

    This function iterates over each bin, extracts the spline, and uses it to generate a set of samples
    uniformly distributed along the x-axis defined by the spline's knots. The total number of samples is
    distributed evenly across the bins.

    Parameters:
        bins (List[Dict[str, UnivariateSpline]]): A list of dictionaries, each containing a 'spline' key
            with a UnivariateSpline object as its value. These splines define the data distribution within
            each bin from which samples are to be generated.
        sample_size (int, optional): The total number of samples to generate from all bins combined.
            Defaults to 10,000.

    Returns:
        List[float]: A list of sampled values, where each value is generated based on the spline functions
            provided in the bins. Because `sample_size` is divided evenly across the bins using integer
            division, the returned list may be slightly shorter than `sample_size` when it is not a
            multiple of the number of bins.

    Example:
        >>> x = np.linspace(0, 10, 100)
        >>> y = np.sin(x)
        >>> spline = UnivariateSpline(x, y, s=1)
        >>> bins = [{'spline': spline} for _ in range(5)]
        >>> samples = create_sample_from_bins(bins, 1000)
        >>> len(samples)
        1000

    Note:
        The function assumes that each bin's spline has a sufficient range of x-values (knots) to allow for
        meaningful sampling. If the range of x-values is too narrow, the uniformity of the sample distribution
        may be affected.
    """
    sample = []
    num_of_bins = len(bins)
    if num_of_bins > 0:
        sample_per_bin = sample_size // num_of_bins
        for value_bin in bins:
            bin_spline = value_bin["spline"]
            x = np.random.uniform(
                bin_spline.get_knots()[0], bin_spline.get_knots()[-1], sample_per_bin
            )
            y = bin_spline(x)
            sample.extend(y)
    return sample


def create_bins_with_spline(samples, num_bins, d_freedom=3) -> List[Dict[str, Any]]:
    """
    Divides a list of numeric samples into a specified number of bins and fits a spline to the data in each bin.

    This function first sorts the list of samples, then partitions the sorted list into `num_bins` bins. For each
    bin, a UnivariateSpline is fitted to the data within the bin, using the specified spline degree. The function
    handles edge cases by assigning infinity to the outer bounds of the first and last bins.

    Parameters:
        samples (List[float]): A list of sample data points, which are real numbers.
        num_bins (int): The number of bins to divide the sample data into. It is assumed that there are enough
            samples to at least fill the bins to the minimum required for spline fitting.
        d_freedom (int, optional): The degree of the spline. Default is 3, which fits a cubic spline.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries, each representing a bin. Each dictionary contains:
            - 'bin_index' (int): The index of the bin.
            - 'start_val' (float): The starting value of the bin, with the first bin starting at negative infinity.
            - 'end_val' (float): The ending value of the bin, with the last bin ending at positive infinity.
            - 'spline' (UnivariateSpline): The spline object fitted to the data within the bin.

    Example:
        >>> samples = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> bins = create_bins_with_spline(samples, 2)
        >>> for b in bins:
        ...     print(b['bin_index'], b['start_val'], b['end_val'])
        ...
        0 -inf 6
        1 6 inf

    Note:
        The function assumes that each bin will have at least `d_freedom + 1` samples to fit the spline. If the
        total number of samples is less than `num_bins * (d_freedom + 1)`, no bins are created and an empty
        list is returned.
    """
    samples.sort()
    bins = []
    if len(samples) >= num_bins * (d_freedom + 1):
        samples_per_bin = len(samples) // num_bins
        for bin_index in range(0, num_bins):
            if bin_index == 0:
                start_val = float("-inf")
            else:
                start_val = samples[bin_index * samples_per_bin]

            if bin_index == num_bins - 1:
                end_val = float("inf")
            else:
                end_val = samples[(bin_index + 1) * samples_per_bin]
            x = range(bin_index * samples_per_bin, (bin_index + 1) * samples_per_bin)
            y = samples[bin_index * samples_per_bin : (bin_index + 1) * samples_per_bin]
            spline = UnivariateSpline(x, y, k=d_freedom)
            bins.append(
                {
                    "bin_index": bin_index,
                    "start_val": start_val,
                    "end_val": end_val,
                    "spline": spline,
                }
            )
    return bins
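
# Illustrative sketch of how the two helpers compose: bins fitted on a sample
# can later be used to draw a synthetic sample with a similar empirical
# distribution, which is how merge_numeric_lab_stats below refits bins when
# merging two tokenizers' statistics.
#
#   source = list(np.random.normal(100.0, 15.0, 5_000))
#   fitted_bins = create_bins_with_spline(source, NUM_OF_BINS, DEGREE_OF_FREEDOM)
#   synthetic = create_sample_from_bins(fitted_bins, sample_size=1_000)
#   refit = create_bins_with_spline(synthetic, NUM_OF_BINS, DEGREE_OF_FREEDOM)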


def map_statistics(batch: Dict[str, Any], size=10_000) -> Dict[str, Any]:
    """Collect reservoir samples of numeric lab values per (concept, unit) and counts of categorical values per (concept, value) from one batch."""
    if "units" in batch:
        batch_value_units = batch["units"]
    else:
        batch_value_units = [[NA for _ in cons] for cons in batch["concept_ids"]]

    if "number_as_values" not in batch:
        batched_number_as_values = [
            [value if isinstance(value, float) else None for value in concept_values]
            for concept_values in batch["concept_values"]
        ]
    else:
        batched_number_as_values = batch["number_as_values"]

    if "concept_as_values" not in batch:
        batched_concept_as_values = [
            [value if isinstance(value, str) else None for value in concept_values]
            for concept_values in batch["concept_values"]
        ]
    else:
        batched_concept_as_values = batch["concept_as_values"]

    numeric_stats_by_lab = collections.defaultdict(partial(ReservoirSampler, size=size))
    categorical_stats_by_lab = collections.defaultdict(int)
    for (
        concept_ids,
        number_as_values,
        concept_as_values,
        concept_value_indicators,
        units,
    ) in zip(
        batch["concept_ids"],
        batched_number_as_values,
        batched_concept_as_values,
        batch["concept_value_masks"],
        batch_value_units,
    ):
        for (
            concept_id,
            number_as_value,
            concept_as_value,
            concept_value_indicator,
            unit,
        ) in zip(
            concept_ids,
            number_as_values,
            concept_as_values,
            concept_value_indicators,
            units,
        ):
            if concept_value_indicator == 1:
                # Compare against None explicitly so legitimate zero-valued
                # lab results are not dropped by a falsy check.
                if number_as_value is not None:
                    numeric_stats_by_lab[(concept_id, unit)].add(number_as_value, 1)
                if concept_as_value is not None:
                    categorical_stats_by_lab[(concept_id, concept_as_value)] += 1

    return {
        "numeric_stats_by_lab": numeric_stats_by_lab,
        "categorical_stats_by_lab": categorical_stats_by_lab,
    }
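
# Illustrative sketch of the expected batch schema and the aggregate this
# function returns (a toy single-visit batch with one numeric and one
# categorical lab value; the concept id is hypothetical):
#
#   toy_batch = {
#       "concept_ids": [["3004249", "3004249"]],
#       "concept_value_masks": [[1, 1]],
#       "number_as_values": [[120.0, None]],
#       "concept_as_values": [[None, "High"]],
#       "units": [["mmHg", "mmHg"]],
#   }
#   stats = map_statistics(toy_batch, size=100)
#   # stats["numeric_stats_by_lab"][("3004249", "mmHg")] is a ReservoirSampler
#   # holding 120.0; stats["categorical_stats_by_lab"][("3004249", "High")] == 1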


def create_numeric_concept_unit_mapping(
    lab_stats: List[Dict[str, Any]]
) -> Tuple[Dict[str, List[float]], Dict[str, List[str]]]:
    numeric_concept_unit_mapping = collections.defaultdict(list)
    for each_lab_stat in lab_stats:
        numeric_concept_unit_mapping[each_lab_stat["concept_id"]].append(
            (each_lab_stat["count"], each_lab_stat["unit"])
        )

    concept_prob_mapping = dict()
    concept_unit_mapping = dict()
    for concept_id in numeric_concept_unit_mapping.keys():
        counts, units = zip(*numeric_concept_unit_mapping[concept_id])
        total_count = sum(counts)
        probs = [float(c) / total_count for c in counts]
        concept_prob_mapping[concept_id] = probs
        concept_unit_mapping[concept_id] = units
    return concept_prob_mapping, concept_unit_mapping
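
# Illustrative sketch: unit probabilities are simply count-weighted frequencies
# per concept (hypothetical concept id and counts):
#
#   probs, units = create_numeric_concept_unit_mapping(
#       [
#           {"concept_id": "3004501", "unit": "mg/dL", "count": 75},
#           {"concept_id": "3004501", "unit": "mmol/L", "count": 25},
#       ]
#   )
#   # probs["3004501"] == [0.75, 0.25]; units["3004501"] == ("mg/dL", "mmol/L")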


class NumericEventStatistics:
    """Summary statistics for numeric lab events, used to map raw values to value-bin tokens and back."""

    def __init__(self, lab_stats: List[Dict[str, Any]]):
        self._lab_stats = lab_stats
        self._lab_stats_mapping = {
            (lab_stat["concept_id"], lab_stat["unit"]): {
                "unit": lab_stat["unit"],
                "mean": lab_stat["mean"],
                "std": lab_stat["std"],
                "value_outlier_std": lab_stat["value_outlier_std"],
                "bins": lab_stat["bins"],
            }
            for lab_stat in lab_stats
        }
        self._concept_prob_mapping, self._concept_unit_mapping = (
            create_numeric_concept_unit_mapping(lab_stats)
        )

    def get_numeric_concept_ids(self) -> List[str]:
        return [_["concept_id"] for _ in self._lab_stats]

    def get_random_unit(self, concept_id: str) -> str:
        if concept_id in self._concept_prob_mapping:
            unit_probs = self._concept_prob_mapping[concept_id]
            return np.random.choice(
                self._concept_unit_mapping[concept_id], p=unit_probs
            )
        return NA

    def normalize(
        self, concept_id: str, unit: str, concept_value: Union[float, str]
    ) -> str:
        if isinstance(concept_value, float):
            if (concept_id, unit) in self._lab_stats_mapping:
                concept_unit_stats = self._lab_stats_mapping[(concept_id, unit)]
                bins = concept_unit_stats["bins"]
                if bins:
                    for each_bin in bins:
                        if (
                            each_bin["start_val"]
                            <= concept_value
                            <= each_bin["end_val"]
                        ):
                            return create_value_bin(each_bin["bin_index"])
        return UNKNOWN_BIN

    def denormalize(
        self, concept_id: str, value_bin: str
    ) -> Tuple[Optional[Union[float, str]], str]:
        unit = self.get_random_unit(concept_id)
        concept_value = value_bin
        if (
            is_valid_value_bin(value_bin)
            and (concept_id, unit) in self._lab_stats_mapping
        ):
            lab_stats = self._lab_stats_mapping[(concept_id, unit)]
            bin_index = value_bin.split(":")[1]
            if bin_index.isnumeric():
                bin_index = int(bin_index)
                # There are rare cases during sequence generation where bin_index could be out of range
                # when there are no bins for (concept_id, unit) due to the small number of values in the source data
                if len(lab_stats["bins"]) > bin_index:
                    assert bin_index == lab_stats["bins"][bin_index]["bin_index"]
                    bin_spline = lab_stats["bins"][bin_index]["spline"]
                    x = np.random.uniform(
                        bin_spline.get_knots()[0], bin_spline.get_knots()[-1]
                    )
                    concept_value = bin_spline(x).item()
        return concept_value, unit
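
# Illustrative sketch of the normalize/denormalize round trip: a raw value maps
# to a "BIN:<i>" token, and the reverse direction draws a value from that bin's
# spline, so the recovered number is a sample from the bin, not the original.
#
#   stats = NumericEventStatistics(numeric_lab_stats)  # as built by train_tokenizer
#   token = stats.normalize("3004249", "mmHg", 120.0)  # e.g. "BIN:6"
#   value, unit = stats.denormalize("3004249", token)  # float drawn from that bin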


class CehrGptTokenizer(PreTrainedTokenizer):
    """
    A word-level tokenizer for CEHR-GPT patient sequences.

    Wraps three Hugging Face tokenizers (concepts, value bins, and sub-time
    tokens) together with the lab statistics needed to normalize numeric
    values into bin tokens and denormalize them back.
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        value_tokenizer: Tokenizer,
        att_tokenizer: Tokenizer,
        token_to_sub_time_token_mapping: Dict[str, List[str]],
        numeric_lab_stats: List[Dict[str, Any]],
        categorical_lab_stats: Dict[Tuple[str, str], int],
        concept_name_mapping: Dict[str, str],
        pretrained_concept_embedding_model: PretrainedEmbeddings = None,
    ):
        self._tokenizer = tokenizer
        self._value_tokenizer = value_tokenizer
        self._att_tokenizer = att_tokenizer
        self._token_to_sub_time_token_mapping = token_to_sub_time_token_mapping
        self._numeric_lab_stats = numeric_lab_stats
        self._numeric_event_statistics = NumericEventStatistics(numeric_lab_stats)
        self._categorical_lab_stats = categorical_lab_stats
        self._concept_name_mapping = concept_name_mapping
        self._oov_token_id = self._tokenizer.token_to_id(OUT_OF_VOCABULARY_TOKEN)
        self._padding_token_id = self._tokenizer.token_to_id(PAD_TOKEN)
        self._start_token_id = self._tokenizer.token_to_id(START_TOKEN)
        self._end_token_id = self._tokenizer.token_to_id(END_TOKEN)
        self._numeric_concept_ids = (
            self._numeric_event_statistics.get_numeric_concept_ids()
        )
        self._categorical_concept_ids = list(
            {t[0] for t in self._categorical_lab_stats.keys()}
        )
        self._padding_value_token_id = self._value_tokenizer.token_to_id(PAD_TOKEN)
        self._pretrained_concept_embedding_model = (
            pretrained_concept_embedding_model
            if pretrained_concept_embedding_model
            else PretrainedEmbeddings(None)
        )
        self._pretrained_concept_ids = [
            _
            for _ in self.get_vocab().keys()
            if self._pretrained_concept_embedding_model.is_concept_available(_)
        ]

        super().__init__()

    @property
    def pretrained_concept_ids(self):
        return self._pretrained_concept_ids

    @property
    def pretrained_token_ids(self):
        return self.encode(self._pretrained_concept_ids)

    @property
    def pretrained_embeddings(self):
        return np.asarray(
            [
                self._pretrained_concept_embedding_model.get_concept_embeddings(_)
                for _ in self._pretrained_concept_ids
            ]
        )

    @property
    def vocab_size(self) -> int:
        return self._tokenizer.get_vocab_size()

    @property
    def value_vocab_size(self) -> int:
        return self._value_tokenizer.get_vocab_size()

    @property
    def time_token_vocab_size(self) -> int:
        return self._att_tokenizer.get_vocab_size()

    @property
    def pad_value_token_id(self):
        return self._padding_value_token_id

    @property
    def start_token_id(self):
        return self._start_token_id

    @property
    def end_token_id(self):
        return self._end_token_id

    @property
    def end_token(self):
        return END_TOKEN

    @property
    def eos_token(self):
        return END_TOKEN

    @property
    def eos_token_id(self):
        return self._end_token_id

    @property
    def pad_token_id(self):
        return self._padding_token_id

    @property
    def pad_token(self):
        return PAD_TOKEN

    @property
    def numeric_concept_ids(self):
        return self._numeric_concept_ids

    @property
    def categorical_concept_ids(self):
        return self._categorical_concept_ids

    @property
    def lab_token_ids(self):
        reserved_tokens = [START_TOKEN, PAD_TOKEN, END_TOKEN, OUT_OF_VOCABULARY_TOKEN]
        return self.encode(
            [
                concept_id
                for concept_id in self._numeric_concept_ids
                + self._categorical_concept_ids
                if concept_id not in reserved_tokens
            ]
        )

    @property
    def token_to_time_token_mapping(self) -> Dict[int, List[int]]:
        default_mapping = {-1: [0, 0, 0]}
        default_mapping.update(
            {
                self._tokenizer.token_to_id(time_token): list(
                    map(self._att_tokenizer.token_to_id, sub_time_tokens)
                )
                for time_token, sub_time_tokens in self._token_to_sub_time_token_mapping.items()
            }
        )
        return default_mapping

    @property
    def pretrained_concept_embedding_model(self):
        return self._pretrained_concept_embedding_model

    def get_vocab(self) -> Dict[str, int]:
        return self._tokenizer.get_vocab()

    def get_value_vocab(self) -> Dict[str, int]:
        return self._value_tokenizer.get_vocab()

    def encode(self, concept_ids, **kwargs) -> Sequence[int]:
        encoded = self._tokenizer.encode(concept_ids, is_pretokenized=True)
        return encoded.ids

    def decode(
        self, concept_token_ids: List[int], skip_special_tokens: bool = True, **kwargs
    ) -> List[str]:
        return self._tokenizer.decode(
            concept_token_ids, skip_special_tokens=skip_special_tokens
        ).split(" ")

    def encode_value(self, concept_values: Sequence[str]) -> Sequence[int]:
        encoded = self._value_tokenizer.encode(concept_values, is_pretokenized=True)
        return encoded.ids

    def decode_value(
        self, concept_value_token_ids: List[int], skip_special_tokens: bool = True
    ) -> List[str]:
        return self._value_tokenizer.decode(
            concept_value_token_ids, skip_special_tokens=skip_special_tokens
        ).split(" ")
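
    # Illustrative sketch: encode/decode operate on pre-tokenized lists of
    # concept strings (is_pretokenized=True), not free text, because the
    # underlying WordLevel tokenizers split on whitespace only. The token
    # strings below are hypothetical.
    #
    #   ids = tokenizer.encode(["320128", "D7", "9202"])   # -> List[int]
    #   tokens = tokenizer.decode(ids)                     # -> ["320128", "D7", "9202"]
    #   value_ids = tokenizer.encode_value(["BIN:3", "BIN:7"])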

    def _convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        token_id = self._tokenizer.token_to_id(token)
        # token_to_id returns None for unknown tokens; compare against None so
        # that a legitimate token id of 0 is not mistaken for out-of-vocabulary.
        return token_id if token_id is not None else self._oov_token_id

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        token = self._tokenizer.id_to_token(index)
        return token if token is not None else OUT_OF_VOCABULARY_TOKEN

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) into a single string."""
        out_string = " ".join([self._concept_name_mapping[t] for t in tokens])
        return out_string

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        push_to_hub: bool = False,
        **kwargs,
    ):
        """
        Save the CehrGPT tokenizer.

        This method makes sure the tokenizer can be re-loaded using the
        .from_pretrained class method.

        Args:
            save_directory (`str` or `os.PathLike`): The path to a directory where the tokenizer will be saved.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`PushToHubMixin.push_to_hub`] method.
        """
        assert not os.path.isfile(
            save_directory
        ), f"Provided path ({save_directory}) should be a directory, not a file"

        os.makedirs(save_directory, exist_ok=True)

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            repo_id = kwargs.pop("repo_id", str(save_directory).split(os.path.sep)[-1])
            repo_id = self._create_repo(repo_id, **kwargs)
            files_timestamps = self._get_files_timestamps(save_directory)

        self._tokenizer.save(os.path.join(save_directory, TOKENIZER_FILE_NAME))

        self._value_tokenizer.save(
            os.path.join(save_directory, VALUE_TOKENIZER_FILE_NAME)
        )

        self._att_tokenizer.save(os.path.join(save_directory, TIME_TOKENIZER_FILE_NAME))

        with open(
            os.path.join(save_directory, TOKEN_TO_SUB_TIME_TOKEN_MAPPING_FILE_NAME), "w"
        ) as f:
            json.dump(self._token_to_sub_time_token_mapping, f)

        with open(os.path.join(save_directory, LAB_STATS_FILE_NAME), "wb") as f:
            lab_stats = {
                "numeric_lab_stats": self._numeric_lab_stats,
                "categorical_lab_stats": self._categorical_lab_stats,
            }
            pickle.dump(lab_stats, f)

        with open(os.path.join(save_directory, CONCEPT_MAPPING_FILE_NAME), "w") as f:
            json.dump(self._concept_name_mapping, f)

        self._pretrained_concept_embedding_model.save(save_directory)

        if push_to_hub:
            self._upload_modified_files(
                save_directory,
                repo_id,
                files_timestamps,
                commit_message=commit_message,
                token=kwargs.get("token"),
            )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        **kwargs,
    ):
        """
        Load the CehrGPT tokenizer.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:
                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                  Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                  user or organization name, like `dbmdz/bert-base-german-cased`.
                - A path to a *directory* containing tokenization data saved using
                  [`save_pretrained`], e.g., `./my_data_directory/`.
            kwargs: Arguments for loading to pass to transformers.utils.hub.cached_file

        Returns:
            A CehrGptTokenizer
        """
        is_legacy_tokenizer = CehrGptTokenizer.is_legacy_tokenizer(
            pretrained_model_name_or_path, **kwargs
        )

        # Load the concept tokenizer
        tokenizer_file = transformers.utils.hub.cached_file(
            pretrained_model_name_or_path, TOKENIZER_FILE_NAME, **kwargs
        )
        if not tokenizer_file:
            return None
        tokenizer = Tokenizer.from_file(tokenizer_file)

        # Load the concept_value_tokenizer
        if is_legacy_tokenizer:
            value_tokenizer = Tokenizer(
                WordLevel(unk_token=OUT_OF_VOCABULARY_TOKEN, vocab=dict())
            )
        else:
            value_tokenizer_file = transformers.utils.hub.cached_file(
                pretrained_model_name_or_path, VALUE_TOKENIZER_FILE_NAME, **kwargs
            )
            if not value_tokenizer_file:
                return None
            value_tokenizer = Tokenizer.from_file(value_tokenizer_file)

        # Load the att (time) tokenizer
        att_tokenizer_file = transformers.utils.hub.cached_file(
            pretrained_model_name_or_path, TIME_TOKENIZER_FILE_NAME, **kwargs
        )
        if not att_tokenizer_file:
            return None
        att_tokenizer = Tokenizer.from_file(att_tokenizer_file)

        # Load the sub time token json file
        token_to_sub_time_token_mapping_file = transformers.utils.hub.cached_file(
            pretrained_model_name_or_path,
            TOKEN_TO_SUB_TIME_TOKEN_MAPPING_FILE_NAME,
            **kwargs,
        )
        if not token_to_sub_time_token_mapping_file:
            return None
        token_to_sub_time_token_mapping = load_json_file(
            token_to_sub_time_token_mapping_file
        )

        # Load the lab stats pickle file
        if is_legacy_tokenizer:
            legacy_lab_stats_file = transformers.utils.hub.cached_file(
                pretrained_model_name_or_path, LEGACY_LAB_STATS_FILE_NAME, **kwargs
            )
            if not legacy_lab_stats_file:
                return None
            # Support the old version of the numeric lab stats file
            lab_stats = {
                "numeric_lab_stats": load_json_file(legacy_lab_stats_file),
                "categorical_lab_stats": dict(),
            }
        else:
            lab_stats_file = transformers.utils.hub.cached_file(
                pretrained_model_name_or_path, LAB_STATS_FILE_NAME, **kwargs
            )
            if not lab_stats_file:
                return None

            with open(lab_stats_file, "rb") as file:
                lab_stats = pickle.load(file)

        # Load the concept_name json file
        concept_name_mapping_file = transformers.utils.hub.cached_file(
            pretrained_model_name_or_path, CONCEPT_MAPPING_FILE_NAME, **kwargs
        )
        if not concept_name_mapping_file:
            return None
        concept_name_mapping = load_json_file(concept_name_mapping_file)

        pretrained_embedding_model = PretrainedEmbeddings(pretrained_model_name_or_path)

        return CehrGptTokenizer(
            tokenizer,
            value_tokenizer,
            att_tokenizer,
            token_to_sub_time_token_mapping,
            lab_stats["numeric_lab_stats"],
            lab_stats["categorical_lab_stats"],
            concept_name_mapping,
            pretrained_embedding_model,
        )
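
    # Illustrative sketch of a save/load round trip: all artifacts (concept,
    # value, and time tokenizers, the sub-time-token mapping, the lab stats
    # pickle, and the concept-name mapping) land in one directory.
    #
    #   tokenizer.save_pretrained("./cehrgpt_tokenizer_dir")
    #   reloaded = CehrGptTokenizer.from_pretrained("./cehrgpt_tokenizer_dir")
    #   assert reloaded.vocab_size == tokenizer.vocab_size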

    @classmethod
    def is_legacy_tokenizer(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ):
        try:
            legacy_lab_stats_file = transformers.utils.hub.cached_file(
                pretrained_model_name_or_path, LEGACY_LAB_STATS_FILE_NAME, **kwargs
            )
            return legacy_lab_stats_file is not None
        except Exception:
            return False

    @classmethod
    def expand_trained_tokenizer(
        cls,
        cehrgpt_tokenizer,
        dataset: Union[Dataset, DatasetDict],
        concept_name_mapping: Dict[str, str],
        data_args: DataTrainingArguments,
        pretrained_concept_embedding_model: PretrainedEmbeddings = None,
    ):
        if not isinstance(cehrgpt_tokenizer, CehrGptTokenizer):
            raise ValueError(
                "The existing cehrgpt_tokenizer must be an instance of CehrGptTokenizer"
            )

        cehrgpt_tokenizer_copy = copy.deepcopy(cehrgpt_tokenizer)

        new_tokenizer = CehrGptTokenizer.train_tokenizer(
            dataset=dataset,
            concept_name_mapping=concept_name_mapping,
            data_args=data_args,
        )

        new_tokens = set(new_tokenizer.get_vocab().keys()) - set(
            cehrgpt_tokenizer_copy.get_vocab().keys()
        )
        new_value_tokens = set(new_tokenizer.get_value_vocab().keys()) - set(
            cehrgpt_tokenizer_copy.get_value_vocab().keys()
        )
        new_att_tokens = set(new_tokenizer._att_tokenizer.get_vocab().keys()) - set(
            cehrgpt_tokenizer_copy._att_tokenizer.get_vocab().keys()
        )
        new_token_to_sub_time_token_mapping = (
            new_tokenizer._token_to_sub_time_token_mapping
        )
        new_numeric_lab_stats = new_tokenizer._numeric_lab_stats
        new_categorical_lab_stats = new_tokenizer._categorical_lab_stats
        new_concept_name_mapping = new_tokenizer._concept_name_mapping

        # Add new tokens to the existing tokenizer
        cehrgpt_tokenizer_copy._tokenizer.add_tokens(
            [
                AddedToken(token, single_word=True, normalized=False)
                for token in new_tokens
            ]
        )
        # Add new tokens to the existing value tokenizer
        cehrgpt_tokenizer_copy._value_tokenizer.add_tokens(
            [
                AddedToken(token, single_word=True, normalized=False)
                for token in new_value_tokens
            ]
        )
        # Add new time tokens to the existing att tokenizer
        cehrgpt_tokenizer_copy._att_tokenizer.add_tokens(
            [
                AddedToken(token, single_word=True, normalized=False)
                for token in new_att_tokens
            ]
        )
        # Merge the time_token -> List[sub_time_tokens] mapping
        for time_token, sub_time_tokens in new_token_to_sub_time_token_mapping.items():
            if (
                time_token
                not in cehrgpt_tokenizer_copy._token_to_sub_time_token_mapping
            ):
                cehrgpt_tokenizer_copy._token_to_sub_time_token_mapping[time_token] = (
                    sub_time_tokens
                )

        # Merge numeric lab_stats
        cehrgpt_tokenizer_copy._numeric_lab_stats = cls.merge_numeric_lab_stats(
            cehrgpt_tokenizer_copy._numeric_lab_stats,
            new_numeric_lab_stats,
        )
        # Merge categorical lab_stats
        cehrgpt_tokenizer_copy._categorical_lab_stats = cls.merge_categorical_lab_stats(
            cehrgpt_tokenizer_copy._categorical_lab_stats,
            new_categorical_lab_stats,
        )

        # Merge concept_name_mapping
        for token, concept_name in new_concept_name_mapping.items():
            if token not in cehrgpt_tokenizer_copy._concept_name_mapping:
                cehrgpt_tokenizer_copy._concept_name_mapping[token] = concept_name

        return CehrGptTokenizer(
            tokenizer=cehrgpt_tokenizer_copy._tokenizer,
            value_tokenizer=cehrgpt_tokenizer_copy._value_tokenizer,
            att_tokenizer=cehrgpt_tokenizer_copy._att_tokenizer,
            token_to_sub_time_token_mapping=cehrgpt_tokenizer_copy._token_to_sub_time_token_mapping,
            numeric_lab_stats=cehrgpt_tokenizer_copy._numeric_lab_stats,
            categorical_lab_stats=cehrgpt_tokenizer_copy._categorical_lab_stats,
            concept_name_mapping=cehrgpt_tokenizer_copy._concept_name_mapping,
            pretrained_concept_embedding_model=pretrained_concept_embedding_model,
        )

    @classmethod
    def merge_numeric_lab_stats(
        cls,
        lab_stats_existing: List[Dict[str, Any]],
        lab_stats_new: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:

        lab_stats_existing_mapping = {
            (lab_stat["concept_id"], lab_stat["unit"]): lab_stat
            for lab_stat in lab_stats_existing
        }
        for lab_stat in lab_stats_new:
            concept_unit_pair = (lab_stat["concept_id"], lab_stat["unit"])
            if concept_unit_pair in lab_stats_existing_mapping:
                # OnlineStatistics.variance holds the sum of squared deviations
                # (count * sigma^2), which is why std is squared and scaled by
                # count before the two running statistics are combined.
                existing = OnlineStatistics()
                existing.count = lab_stats_existing_mapping[concept_unit_pair]["count"]
                existing.current_mean = lab_stats_existing_mapping[concept_unit_pair][
                    "mean"
                ]
                existing.variance = (
                    lab_stats_existing_mapping[concept_unit_pair]["std"] ** 2
                    * existing.count
                )
                new = OnlineStatistics()
                new.count = lab_stat["count"]
                new.current_mean = lab_stat["mean"]
                new.variance = lab_stat["std"] ** 2 * new.count
                existing.combine(new)
                lab_stats_existing_mapping[concept_unit_pair]["mean"] = existing.mean()
                lab_stats_existing_mapping[concept_unit_pair][
                    "std"
                ] = existing.standard_deviation()
                lab_stats_existing_mapping[concept_unit_pair]["count"] = existing.count
                # Recreate the bins: draw synthetic samples from both sets of
                # existing bins and refit the splines on the combined sample.
                sample = create_sample_from_bins(
                    lab_stats_existing_mapping[concept_unit_pair]["bins"]
                )
                sample.extend(create_sample_from_bins(lab_stat["bins"]))
                lab_stats_existing_mapping[concept_unit_pair]["bins"] = (
                    create_bins_with_spline(sample, NUM_OF_BINS, DEGREE_OF_FREEDOM)
                )

            else:
                if lab_stat["count"] > 0:
                    lab_stats_existing_mapping[concept_unit_pair] = lab_stat

        return list(lab_stats_existing_mapping.values())
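
    # Illustrative sketch of the pooled-statistics math this merge relies on,
    # i.e. the standard parallel combination of (count, mean, M2) triples where
    # M2 = count * variance (this is what OnlineStatistics.combine is expected
    # to implement):
    #
    #   count = n1 + n2
    #   mean  = (n1 * m1 + n2 * m2) / (n1 + n2)
    #   M2    = M2_1 + M2_2 + (m1 - m2) ** 2 * n1 * n2 / (n1 + n2)
    #
    # e.g. n1=100, m1=10.0, std1=2.0 and n2=300, m2=14.0, std2=3.0 pool to
    # mean = 13.0 and std = sqrt((400 + 2700 + 1200) / 400) ~= 3.279.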

    @classmethod
    def merge_categorical_lab_stats(
        cls,
        categorical_lab_stats_existing: Dict[Tuple[str, str], int],
        categorical_lab_stats_new: Dict[Tuple[str, str], int],
    ) -> Dict[Tuple[str, str], int]:
        for (concept_id, concept_as_value), count in categorical_lab_stats_new.items():
            # Initialize keys missing from the existing mapping before
            # accumulating the new counts.
            if (concept_id, concept_as_value) not in categorical_lab_stats_existing:
                categorical_lab_stats_existing[(concept_id, concept_as_value)] = 0
            categorical_lab_stats_existing[(concept_id, concept_as_value)] += count
        return categorical_lab_stats_existing

    @classmethod
    def train_tokenizer(
        cls,
        dataset: Union[Dataset, DatasetDict],
        concept_name_mapping: Dict[str, str],
        data_args: DataTrainingArguments,
        pretrained_concept_embedding_model: PretrainedEmbeddings = None,
    ):
        """
        Train a Hugging Face word-level tokenizer.

        To use their tokenizer, we need to concatenate all the concepts
        together and treat it as a sequence.
        """
        if isinstance(dataset, DatasetDict):
            dataset = dataset["train"]

        concept_tokenizer = cls.train_concept_tokenizer(
            dataset,
            feature_name="concept_ids",
            special_tokens=[PAD_TOKEN, OUT_OF_VOCABULARY_TOKEN, START_TOKEN, END_TOKEN],
            unk_token=OUT_OF_VOCABULARY_TOKEN,
            data_args=data_args,
        )
        # Inspect the first record only to decide which column carries the
        # categorical concept values.
        concept_value_column = "concept_as_values"
        for row in dataset:
            if concept_value_column not in row:
                concept_value_column = "concept_values"
            break
        value_tokenizer = cls.train_concept_tokenizer(
            dataset,
            feature_name=concept_value_column,
            special_tokens=[OUT_OF_VOCABULARY_TOKEN, PAD_TOKEN],
            unk_token=OUT_OF_VOCABULARY_TOKEN,
            data_args=data_args,
        )
        value_tokenizer.add_tokens(
            [
                AddedToken(_, single_word=True, normalized=False)
                for _ in [create_value_bin(_) for _ in range(NUM_OF_BINS)]
                + [UNKNOWN_BIN, NONE_BIN]
            ]
        )

        map_statistics_partial = partial(map_statistics, size=SAMPLE_SIZE)

        if data_args.streaming:
            parts = dataset.map(
                partial(agg_helper, map_func=map_statistics_partial),
                batched=True,
                batch_size=data_args.preprocessing_batch_size,
                new_fingerprint="invalid",
                remove_columns=dataset.column_names,
            )
        else:
            parts = dataset.map(
                partial(agg_helper, map_func=map_statistics_partial),
                batched=True,
                batch_size=data_args.preprocessing_batch_size,
                remove_columns=dataset.column_names,
                num_proc=data_args.preprocessing_num_workers,
                keep_in_memory=True,
                new_fingerprint="invalid",
            )
        current = None
        for stat in tqdm(parts, desc="Aggregating the lab statistics"):
            fixed_stat = pickle.loads(stat["data"])
            if current is None:
                current = fixed_stat
            else:
                current = agg_statistics(current, fixed_stat)

        numeric_lab_stats = []
        for (concept_id, unit), online_stats in current["numeric_stats_by_lab"].items():
            if len(online_stats.samples) == 0:
                continue
            samples = truncated_sample(
                online_stats.samples, data_args.value_outlier_std
            )
            bins = create_bins_with_spline(samples, NUM_OF_BINS, DEGREE_OF_FREEDOM)
            if len(bins) > 0:
                numeric_lab_stats.append(
                    {
                        "concept_id": concept_id,
                        "unit": unit,
                        "mean": np.mean(samples),
                        "std": np.std(samples),
                        "count": len(online_stats.samples),
                        "value_outlier_std": data_args.value_outlier_std,
                        "bins": bins,
                    }
                )

        categorical_lab_stats = collections.defaultdict(int)
        for (concept_id, value_as_concept), count in current[
            "categorical_stats_by_lab"
        ].items():
            categorical_lab_stats[(concept_id, value_as_concept)] += count

        # We will train a tokenizer specifically for time intervals
        sub_time_token_data = []
        token_to_sub_time_token_mapping = collections.defaultdict(list)
        for token, token_id in concept_tokenizer.get_vocab().items():
            if is_att_token(token):
                time_interval = extract_time_interval_in_days(token)
                time_tuple = convert_time_interval_to_time_tuple(
                    time_interval, is_inpatient_att_token(token)
                )
                token_to_sub_time_token_mapping[token] = list(time_tuple)
                sub_time_token_data.append(" ".join(time_tuple))

        att_tokenizer = Tokenizer(
            WordLevel(unk_token=OUT_OF_VOCABULARY_TOKEN, vocab=dict())
        )
        att_tokenizer.pre_tokenizer = WhitespaceSplit()
        att_trainer = WordLevelTrainer(
            special_tokens=[OUT_OF_VOCABULARY_TOKEN],
            vocab_size=data_args.vocab_size,
            min_frequency=0,
            show_progress=True,
        )
        att_tokenizer.train_from_iterator(sub_time_token_data, trainer=att_trainer)

        return CehrGptTokenizer(
            concept_tokenizer,
            value_tokenizer,
            att_tokenizer,
            token_to_sub_time_token_mapping,
            numeric_lab_stats,
            categorical_lab_stats,
            concept_name_mapping,
            pretrained_concept_embedding_model,
        )
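
    # Illustrative sketch of end-to-end training. Only the DataTrainingArguments
    # fields referenced in this file matter here: streaming,
    # preprocessing_batch_size, preprocessing_num_workers, vocab_size,
    # min_frequency, and value_outlier_std.
    #
    #   tokenizer = CehrGptTokenizer.train_tokenizer(
    #       dataset=dataset_dict,                # datasets.DatasetDict with a "train" split
    #       concept_name_mapping=concept_names,  # Dict[str, str]
    #       data_args=data_args,                 # cehrbert DataTrainingArguments
    #   )
    #   tokenizer.save_pretrained("./cehrgpt_tokenizer_dir")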

    @classmethod
    def train_concept_tokenizer(
        cls,
        dataset,
        feature_name,
        special_tokens: List[str],
        unk_token,
        data_args,
    ):
        # Use the Fast Tokenizer from the Hugging Face tokenizers Rust implementation.
        # https://github.com/huggingface/tokenizers
        concept_tokenizer = Tokenizer(WordLevel(unk_token=unk_token, vocab=dict()))
        concept_tokenizer.pre_tokenizer = WhitespaceSplit()
        concept_trainer = WordLevelTrainer(
            special_tokens=special_tokens,
            vocab_size=data_args.vocab_size,
            min_frequency=data_args.min_frequency,
            show_progress=True,
        )
        batch_concat_concepts_partial_func = partial(
            cls.batch_concat_concepts, feature_name=feature_name
        )
        if data_args.streaming:
            concatenated_features = dataset.map(
                batch_concat_concepts_partial_func,
                batched=True,
                batch_size=data_args.preprocessing_batch_size,
            )

            def batched_generator():
                iterator = iter(concatenated_features)
                while True:
                    batch = list(islice(iterator, data_args.preprocessing_batch_size))
                    if not batch:
                        break
                    yield [example[feature_name] for example in batch]

            # We pass a generator of lists of texts (concatenated concept_ids)
            # to train_from_iterator for efficient training
            generator = batched_generator()
        else:
            concatenated_features = dataset.map(
                batch_concat_concepts_partial_func,
                num_proc=data_args.preprocessing_num_workers,
                batched=True,
                batch_size=data_args.preprocessing_batch_size,
                remove_columns=dataset.column_names,
            )
            generator = concatenated_features[feature_name]
        concept_tokenizer.train_from_iterator(generator, trainer=concept_trainer)
        return concept_tokenizer

    def normalize(self, concept_id: str, unit: str, concept_value: float) -> str:
        return self._numeric_event_statistics.normalize(concept_id, unit, concept_value)

    def denormalize(self, concept_id: str, value_bin: str) -> Tuple[float, str]:
        return self._numeric_event_statistics.denormalize(concept_id, value_bin)

    @classmethod
    def batch_concat_concepts(
        cls, records: Dict[str, List], feature_name
    ) -> Dict[str, List]:
        return {
            feature_name: [
                " ".join(
                    [token for token in tokens if token and isinstance(token, str)]
                )
                for tokens in records[feature_name]
            ]
        }