cehrgpt-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. __init__.py +0 -0
  2. cehrgpt/__init__.py +0 -0
  3. cehrgpt/analysis/__init__.py +0 -0
  4. cehrgpt/analysis/privacy/__init__.py +0 -0
  5. cehrgpt/analysis/privacy/attribute_inference.py +275 -0
  6. cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
  7. cehrgpt/analysis/privacy/member_inference.py +172 -0
  8. cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
  9. cehrgpt/analysis/privacy/reid_inference.py +407 -0
  10. cehrgpt/analysis/privacy/utils.py +255 -0
  11. cehrgpt/cehrgpt_args.py +142 -0
  12. cehrgpt/data/__init__.py +0 -0
  13. cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
  14. cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
  15. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
  16. cehrgpt/generation/__init__.py +0 -0
  17. cehrgpt/generation/chatgpt_generation.py +106 -0
  18. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
  19. cehrgpt/generation/omop_converter_batch.py +644 -0
  20. cehrgpt/generation/omop_entity.py +515 -0
  21. cehrgpt/gpt_utils.py +331 -0
  22. cehrgpt/models/__init__.py +0 -0
  23. cehrgpt/models/config.py +205 -0
  24. cehrgpt/models/hf_cehrgpt.py +1817 -0
  25. cehrgpt/models/hf_modeling_outputs.py +158 -0
  26. cehrgpt/models/pretrained_embeddings.py +82 -0
  27. cehrgpt/models/special_tokens.py +30 -0
  28. cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
  29. cehrgpt/omop/__init__.py +0 -0
  30. cehrgpt/omop/condition_era.py +20 -0
  31. cehrgpt/omop/observation_period.py +43 -0
  32. cehrgpt/omop/omop_argparse.py +38 -0
  33. cehrgpt/omop/omop_table_builder.py +86 -0
  34. cehrgpt/omop/queries/__init__.py +0 -0
  35. cehrgpt/omop/queries/condition_era.py +86 -0
  36. cehrgpt/omop/queries/observation_period.py +135 -0
  37. cehrgpt/omop/sample_omop_tables.py +71 -0
  38. cehrgpt/runners/__init__.py +0 -0
  39. cehrgpt/runners/gpt_runner_util.py +99 -0
  40. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
  41. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
  42. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
  43. cehrgpt/runners/hyperparameter_search_util.py +223 -0
  44. cehrgpt/time_to_event/__init__.py +0 -0
  45. cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
  46. cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
  47. cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
  48. cehrgpt/time_to_event/time_to_event_model.py +226 -0
  49. cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
  50. cehrgpt/time_to_event/time_to_event_utils.py +55 -0
  51. cehrgpt/tools/__init__.py +0 -0
  52. cehrgpt/tools/ehrshot_benchmark.py +74 -0
  53. cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
  54. cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
  55. cehrgpt/tools/upload_omop_tables.py +108 -0
  56. cehrgpt-0.0.1.dist-info/LICENSE +21 -0
  57. cehrgpt-0.0.1.dist-info/METADATA +66 -0
  58. cehrgpt-0.0.1.dist-info/RECORD +60 -0
  59. cehrgpt-0.0.1.dist-info/WHEEL +5 -0
  60. cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/data/hf_cehrgpt_dataset_collator.py
@@ -0,0 +1,482 @@
+ import random
+ from typing import Any, Dict
+
+ import numpy as np
+ import torch
+ from torch.nn.utils.rnn import pad_sequence
+
+ from cehrgpt.gpt_utils import (
+     DEMOGRAPHIC_PROMPT_SIZE,
+     collect_demographic_prompts_at_visits,
+     extract_time_interval_in_days,
+     is_att_token,
+     is_inpatient_att_token,
+     random_slice_gpt_sequence,
+ )
+ from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+
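+ # Inpatient stays longer than this many days are masked out (-100) in the
+ # time-to-visit labels built by _convert_time_to_event below.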
+ INPATIENT_STAY_DURATION_LIMIT = 30
+
+
+ class CehrGptDataCollator:
+     def __init__(
+         self,
+         tokenizer: CehrGptTokenizer,
+         max_length: int,
+         shuffle_records: bool = False,
+         include_values: bool = False,
+         include_ttv_prediction: bool = False,
+         use_sub_time_tokenization: bool = False,
+         pretraining: bool = True,
+         include_demographics: bool = False,
+     ):
+         self.tokenizer = tokenizer
+         self.max_length = max_length
+         # Pre-compute these token ids so we can use them later on.
+         # Historical data used "VS"/"VE", while newer data uses "[VS]"/"[VE]",
+         # so we need to check both cases.
+         self.vs_token_id = tokenizer._convert_token_to_id("VS")
+         if self.vs_token_id == tokenizer._oov_token_id:
+             self.vs_token_id = tokenizer._convert_token_to_id("[VS]")
+         self.ve_token_id = tokenizer._convert_token_to_id("VE")
+         if self.ve_token_id == tokenizer._oov_token_id:
+             self.ve_token_id = tokenizer._convert_token_to_id("[VE]")
+
+         self.shuffle_records = shuffle_records
+         self.include_values = include_values
+         self.include_ttv_prediction = include_ttv_prediction
+         self.use_sub_time_tokenization = use_sub_time_tokenization
+         self.pretraining = pretraining
+         self.include_demographics = include_demographics
+
+         if self.use_sub_time_tokenization:
+             token_to_time_token_mapping = tokenizer.token_to_time_token_mapping
+             if not token_to_time_token_mapping:
+                 raise ValueError(
+                     "The token_to_time_token_mapping in CehrGptTokenizer cannot be None "
+                     "when use_sub_time_tokenization is enabled"
+                 )
+             # Create the tensors for converting time tokens to the sub time tokens
+             self.time_tokens = torch.tensor(
+                 list(tokenizer.token_to_time_token_mapping.keys()), dtype=torch.int64
+             )
+             self.mapped_sub_time_tokens = torch.tensor(
+                 list(token_to_time_token_mapping.values()), dtype=torch.int64
+             )
+
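+     # Note: when pretraining=False, __call__ flips each sequence before padding and
+     # flips the padded batch back, so pad_sequence effectively left-pads and the
+     # most recent tokens stay at the end of the window.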
+     def _try_reverse_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
+         if not self.pretraining:
+             return torch.flip(tensor, dims=[-1])
+         return tensor
+
+     @staticmethod
+     def _convert_to_tensor(features: Any) -> torch.Tensor:
+         if isinstance(features, torch.Tensor):
+             return features
+         else:
+             return torch.tensor(features)
+
+     @staticmethod
+     def _convert_time_to_event(concept_ids):
+         def default_value(c):
+             try:
+                 if is_att_token(c):
+                     time_to_visit = extract_time_interval_in_days(c)
+                     if (
+                         is_inpatient_att_token(c)
+                         and time_to_visit > INPATIENT_STAY_DURATION_LIMIT
+                     ):
+                         return -100
+                     return time_to_visit
+                 return -100
+             except ValueError:
+                 return -100
+
+         return [float(default_value(_)) for _ in concept_ids]
+
+     def __call__(self, examples):
+         examples = [self.generate_start_end_index(_) for _ in examples]
+         examples = [self.random_sort(_) for _ in examples]
+         batch = {}
+
+         # Assume that each example in the batch is a dictionary with 'input_ids';
+         # the attention mask is derived from it below
+         batch_input_ids = [
+             self._try_reverse_tensor(self._convert_to_tensor(example["input_ids"]))
+             for example in examples
+         ]
+         batch_attention_mask = [
+             self._try_reverse_tensor(
+                 torch.ones_like(
+                     self._convert_to_tensor(example["input_ids"]), dtype=torch.float
+                 )
+             )
+             for example in examples
+         ]
+
+         # Pad sequences to the max length in the batch
+         batch["input_ids"] = self._try_reverse_tensor(
+             pad_sequence(
+                 batch_input_ids,
+                 batch_first=True,
+                 padding_value=self.tokenizer.pad_token_id,
+             ).to(torch.int64)
+         )
+
+         batch["attention_mask"] = self._try_reverse_tensor(
+             pad_sequence(batch_attention_mask, batch_first=True, padding_value=0.0)
+         )
+         assert batch["input_ids"].shape[1] <= self.max_length
+         assert batch["attention_mask"].shape[1] <= self.max_length
+
+         if self.pretraining:
+             batch["labels"] = self._try_reverse_tensor(
+                 pad_sequence(
+                     batch_input_ids,
+                     batch_first=True,
+                     padding_value=-100,
+                 ).to(torch.int64)
+             )
+
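+         # Each time token maps to a fixed-length tuple of sub-time tokens. The one-hot
+         # comparison + argmax below finds the position of every matched token in
+         # self.time_tokens so the corresponding rows of self.mapped_sub_time_tokens
+         # can be gathered.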
+         if self.use_sub_time_tokenization:
+             time_token_indicators = torch.isin(batch["input_ids"], self.time_tokens)
+             masked_tokens = batch["input_ids"].clone()
+             masked_tokens[~time_token_indicators] = -1
+             # Get the index of the sub_time_tokens from the time_tokens tensor
+             sub_time_token_indices = torch.argmax(
+                 (
+                     masked_tokens.unsqueeze(-1)
+                     == self.time_tokens.unsqueeze(0).unsqueeze(0)
+                 ).to(torch.int32),
+                 dim=-1,
+             )
+             sub_time_tokens = self.mapped_sub_time_tokens[sub_time_token_indices]
+             batch["time_token_indicators"] = time_token_indicators
+             batch["sub_time_tokens"] = sub_time_tokens
+
+         if self.include_ttv_prediction:
+             batch_time_to_visits = [
+                 self._try_reverse_tensor(
+                     self._convert_to_tensor(example["time_to_visits"])
+                 )
+                 for example in examples
+             ]
+             batch["time_to_visits"] = self._try_reverse_tensor(
+                 pad_sequence(
+                     batch_time_to_visits, batch_first=True, padding_value=-100.0
+                 )
+             )
+
+         if self.include_values:
+             batch_value_indicators = [
+                 self._try_reverse_tensor(
+                     self._convert_to_tensor(example["value_indicators"])
+                 )
+                 for example in examples
+             ]
+             batch_values = [
+                 self._try_reverse_tensor(self._convert_to_tensor(example["values"]))
+                 for example in examples
+             ]
+
+             batch["value_indicators"] = self._try_reverse_tensor(
+                 pad_sequence(
+                     batch_value_indicators, batch_first=True, padding_value=False
+                 )
+             )
+             batch["values"] = self._try_reverse_tensor(
+                 pad_sequence(
+                     batch_values,
+                     batch_first=True,
+                     padding_value=self.tokenizer.pad_value_token_id,
+                 ).to(torch.int64)
+             )
+             assert batch["value_indicators"].shape[1] <= self.max_length
+             assert batch["values"].shape[1] <= self.max_length
+
+             if self.pretraining:
+                 batch["true_value_indicators"] = batch["value_indicators"].clone()
+                 batch["true_values"] = torch.where(
+                     batch["value_indicators"], batch["values"].clone(), -100
+                 )
+
+         if "person_id" in examples[0]:
+             batch["person_id"] = torch.cat(
+                 [
+                     self._convert_to_tensor(example["person_id"]).reshape(-1, 1)
+                     for example in examples
+                 ],
+                 dim=0,
+             ).to(torch.int32)
+
+         if "index_date" in examples[0]:
+             batch["index_date"] = torch.cat(
+                 [
+                     self._convert_to_tensor(example["index_date"]).reshape(-1, 1)
+                     for example in examples
+                 ],
+                 dim=0,
+             ).to(torch.float32)
+
+         if "age_at_index" in examples[0]:
+             batch["age_at_index"] = torch.cat(
+                 [
+                     self._convert_to_tensor(example["age_at_index"]).reshape(-1, 1)
+                     for example in examples
+                 ],
+                 dim=0,
+             ).to(torch.float32)
+
+         if "classifier_label" in examples[0]:
+             batch["classifier_label"] = torch.cat(
+                 [
+                     self._convert_to_tensor(example["classifier_label"]).reshape(-1, 1)
+                     for example in examples
+                 ],
+                 dim=0,
+             ).to(torch.float32)
+
+         return batch
+
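+     # Shuffling works by sorting on (rank, random key, token id): records keep their
+     # rank order while ties within the same rank are randomly permuted.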
+     def random_sort(self, record: Dict[str, Any]) -> Dict[str, Any]:
+         if not self.shuffle_records:
+             return record
+
+         if "record_ranks" not in record:
+             return record
+
+         sorting_column = record["record_ranks"]
+         random_order = np.random.rand(len(sorting_column))
+
+         if self.include_values:
+             iterator = zip(
+                 sorting_column,
+                 random_order,
+                 record["input_ids"],
+                 record["value_indicators"],
+                 record["values"],
+             )
+             sorted_list = sorted(iterator, key=lambda tup2: (tup2[0], tup2[1], tup2[2]))
+             _, _, sorted_input_ids, sorted_value_indicators, sorted_values = zip(
+                 *list(sorted_list)
+             )
+             record["input_ids"] = self._convert_to_tensor(sorted_input_ids)
+             record["value_indicators"] = self._convert_to_tensor(
+                 sorted_value_indicators
+             )
+             record["values"] = self._convert_to_tensor(sorted_values)
+         else:
+             iterator = zip(sorting_column, random_order, record["input_ids"])
+             sorted_list = sorted(iterator, key=lambda tup2: (tup2[0], tup2[1], tup2[2]))
+             _, _, sorted_input_ids = zip(*list(sorted_list))
+             record["input_ids"] = self._convert_to_tensor(sorted_input_ids)
+         return record
+
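+     # When a sequence exceeds max_length, two truncation regimes apply:
+     #   * pretraining: with 50% probability take a random slice (rebuilding the
+     #     demographic prompt for the new starting point), otherwise right-truncate
+     #     at the last visit-end token;
+     #   * fine-tuning: left-truncate at a visit-start token so the most recent
+     #     history is kept, re-inserting the demographic prompt when
+     #     include_demographics is enabled.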
+     def generate_start_end_index(self, record: Dict[str, Any]) -> Dict[str, Any]:
+         """Add the start and end indices to extract a portion of the patient sequence."""
+         # concept_ids will be used for the time-to-event predictions and for identifying the visit starts
+         input_ids = record["input_ids"]
+         if isinstance(input_ids, torch.Tensor):
+             input_ids = input_ids.detach().tolist()
+         concept_ids = self.tokenizer.decode(input_ids, skip_special_tokens=False)
+         seq_length = len(record["input_ids"])
+         new_max_length = self.max_length - 1  # Subtract one for the [END] token
+
+         # Return the record directly if the actual sequence length does not exceed the max length
+         if seq_length <= new_max_length:
+             record["input_ids"] = torch.concat(
+                 [
+                     self._convert_to_tensor(record["input_ids"]),
+                     self._convert_to_tensor([self.tokenizer.end_token_id]),
+                 ]
+             )
+             if self.include_values:
+                 record["value_indicators"] = torch.concat(
+                     [
+                         self._convert_to_tensor(record["value_indicators"]),
+                         self._convert_to_tensor([False]),
+                     ]
+                 ).to(torch.bool)
+                 record["values"] = torch.concat(
+                     [
+                         self._convert_to_tensor(record["values"]),
+                         self._convert_to_tensor([self.tokenizer.pad_value_token_id]),
+                     ]
+                 )
+             if self.include_ttv_prediction:
+                 record["time_to_visits"] = torch.concat(
+                     [
+                         self._convert_to_tensor(
+                             self._convert_time_to_event(concept_ids)
+                         ),
+                         self._convert_to_tensor([-100.0]),
+                     ]
+                 )
+
+             return record
+
+         if self.pretraining:
+             # There is a 50% chance that we randomly slice out a portion of the patient history
+             # and update the demographic prompt depending on the new starting point
+             if random.random() < 0.5:
+                 start_index, end_index, demographic_tokens = random_slice_gpt_sequence(
+                     concept_ids, new_max_length
+                 )
+                 if start_index != end_index:
+                     record["input_ids"] = self._convert_to_tensor(
+                         record["input_ids"][start_index : end_index + 1]
+                     )
+                     if self.include_values:
+                         record["value_indicators"] = self._convert_to_tensor(
+                             record["value_indicators"][start_index : end_index + 1]
+                         ).to(torch.bool)
+                         record["values"] = self._convert_to_tensor(
+                             record["values"][start_index : end_index + 1]
+                         )
+                     if self.include_ttv_prediction:
+                         record["time_to_visits"] = self._convert_to_tensor(
+                             self._convert_time_to_event(
+                                 concept_ids[start_index : end_index + 1]
+                             )
+                         )
+                     return record
+
+             # The default employs a right truncation strategy, where the demographic prompt is preserved
+             end_index = new_max_length
+             for i in reversed(list(range(0, end_index))):
+                 current_token = record["input_ids"][i]
+                 if current_token == self.ve_token_id:
+                     end_index = i
+                     break
+
+             record["input_ids"] = record["input_ids"][0:end_index]
+             if self.include_values:
+                 record["value_indicators"] = self._convert_to_tensor(
+                     record["value_indicators"][0:end_index]
+                 ).to(torch.bool)
+                 record["values"] = self._convert_to_tensor(
+                     record["values"][0:end_index]
+                 )
+             if self.include_ttv_prediction:
+                 record["time_to_visits"] = self._convert_to_tensor(
+                     self._convert_time_to_event(concept_ids[0:end_index])
+                 )
+             return record
+         else:
+             if self.include_demographics:
+                 # We employ a left truncation strategy, where the most recent patient history is preserved for fine-tuning
+                 demographic_prompts_at_visits = collect_demographic_prompts_at_visits(
+                     concept_ids
+                 )
+                 for token_index, demographic_prompt in demographic_prompts_at_visits:
+                     if (
+                         seq_length - token_index
+                         <= new_max_length - DEMOGRAPHIC_PROMPT_SIZE
+                     ):
+                         demographic_tokens = self.tokenizer.encode(demographic_prompt)
+                         record["input_ids"] = torch.concat(
+                             [
+                                 self._convert_to_tensor(demographic_tokens),
+                                 self._convert_to_tensor(
+                                     record["input_ids"][token_index:seq_length]
+                                 ),
+                             ]
+                         )
+                         if self.include_values:
+                             record["value_indicators"] = torch.concat(
+                                 [
+                                     torch.zeros(
+                                         [DEMOGRAPHIC_PROMPT_SIZE], dtype=torch.int32
+                                     ).to(torch.bool),
+                                     self._convert_to_tensor(
+                                         record["value_indicators"][
+                                             token_index:seq_length
+                                         ]
+                                     ),
+                                 ]
+                             )
+                             record["values"] = torch.concat(
+                                 [
+                                     torch.zeros(
+                                         [DEMOGRAPHIC_PROMPT_SIZE], dtype=torch.int32
+                                     )
+                                     .to(torch.int32)
+                                     .fill_(self.tokenizer.pad_value_token_id),
+                                     self._convert_to_tensor(
+                                         record["values"][token_index:seq_length]
+                                     ),
+                                 ]
+                             )
+                         if self.include_ttv_prediction:
+                             record["time_to_visits"] = torch.concat(
+                                 [
+                                     torch.zeros(
+                                         [DEMOGRAPHIC_PROMPT_SIZE], dtype=torch.int32
+                                     )
+                                     .to(torch.float32)
+                                     .fill_(-100.0),
+                                     record["time_to_visits"][token_index:seq_length],
+                                 ]
+                             )
+                         break
+             else:
+                 start_index = seq_length - new_max_length
+                 end_index = seq_length
+                 for i in range(start_index, end_index):
+                     current_token = record["input_ids"][i]
+                     if current_token == self.vs_token_id:
+                         record["input_ids"] = record["input_ids"][i:end_index]
+                         if self.include_values:
+                             record["value_indicators"] = record["value_indicators"][
+                                 i:end_index
+                             ]
+                             record["values"] = record["values"][i:end_index]
+                         if self.include_ttv_prediction:
+                             record["time_to_visits"] = record["time_to_visits"][
+                                 i:end_index
+                             ]
+                         break
+
+                 # This could happen when the last visit contains more than new_max_length tokens;
+                 # in that case we simply take the last new_max_length tokens from the patient sequence
+                 if len(record["input_ids"]) > new_max_length:
+                     record["input_ids"] = record["input_ids"][-new_max_length:]
+                     if self.include_values:
+                         record["value_indicators"] = record["value_indicators"][
+                             -new_max_length:
+                         ]
+                         record["values"] = record["values"][-new_max_length:]
+                     if self.include_ttv_prediction:
+                         record["time_to_visits"] = record["time_to_visits"][
+                             -new_max_length:
+                         ]
+
+             # Finally, append the [END] token to the sequence
+             record["input_ids"] = torch.concat(
+                 [
+                     self._convert_to_tensor(record["input_ids"]),
+                     self._convert_to_tensor([self.tokenizer.end_token_id]),
+                 ]
+             )
+             if self.include_values:
+                 record["value_indicators"] = torch.concat(
+                     [
+                         self._convert_to_tensor(record["value_indicators"]),
+                         self._convert_to_tensor([False]),
+                     ]
+                 ).to(torch.bool)
+                 record["values"] = torch.concat(
+                     [
+                         self._convert_to_tensor(record["values"]),
+                         self._convert_to_tensor([self.tokenizer.pad_value_token_id]),
+                     ]
+                 )
+             if self.include_ttv_prediction:
+                 record["time_to_visits"] = torch.concat(
+                     [
+                         record["time_to_visits"],
+                         self._convert_to_tensor([-100.0]),
+                     ]
+                 )
+             return record
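
A minimal usage sketch for the collator above (not part of the package diff): it assumes a trained CehrGptTokenizer can be loaded via a HF-style from_pretrained, and tokenized_dataset is a hypothetical dataset whose examples already contain input_ids (e.g. produced by the mapping in hf_cehrgpt_dataset_mapping.py below).

    from torch.utils.data import DataLoader

    from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
    from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer

    # The tokenizer path, from_pretrained call, and tokenized_dataset are
    # assumptions for illustration only.
    tokenizer = CehrGptTokenizer.from_pretrained("/path/to/tokenizer")
    collator = CehrGptDataCollator(
        tokenizer=tokenizer,
        max_length=512,
        include_values=True,
        pretraining=True,
    )
    loader = DataLoader(tokenized_dataset, batch_size=8, collate_fn=collator)
    batch = next(iter(loader))  # input_ids, attention_mask, labels, value_indicators, values, ...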
cehrgpt/data/hf_cehrgpt_dataset_mapping.py
@@ -0,0 +1,116 @@
+ import datetime
+ from typing import Any, Dict
+
+ import numpy as np
+ from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import DatasetMapping
+
+ from cehrgpt.models.tokenization_hf_cehrgpt import (
+     NONE_BIN,
+     UNKNOWN_BIN,
+     CehrGptTokenizer,
+ )
+
+
+ def convert_date_to_posix_time(index_date: datetime.date) -> float:
+     return datetime.datetime.combine(
+         index_date, datetime.datetime.min.time()
+     ).timestamp()
+
+
+ class HFCehrGptTokenizationMapping(DatasetMapping):
+     def __init__(
+         self,
+         concept_tokenizer: CehrGptTokenizer,
+     ):
+         self._concept_tokenizer = concept_tokenizer
+         self._lab_token_ids = self._concept_tokenizer.lab_token_ids
+
+     def remove_columns(self):
+         return [
+             "concept_value_masks",
+             "is_numeric_types",
+         ]
+
+     def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
+         # If any concept has a value associated with it, we normalize the value
+         record["input_ids"] = self._concept_tokenizer.encode(record["concept_ids"])
+         record["value_indicators"] = record["concept_value_masks"]
+         if "number_as_values" not in record or "concept_as_values" not in record:
+             record["number_as_values"] = [
+                 float(value) if isinstance(value, float) else None
+                 for value in record["concept_values"]
+             ]
+             record["is_numeric_types"] = [
+                 int(isinstance(value, float)) for value in record["concept_values"]
+             ]
+             record["concept_as_values"] = [
+                 value if isinstance(value, str) else None
+                 for value in record["concept_values"]
+             ]
+         if np.any(np.asarray(record["concept_value_masks"]) > 0):
+             values = []
+             for i, (
+                 concept_id,
+                 unit,
+                 concept_value_mask,
+                 number_as_value,
+                 concept_as_value,
+                 is_numeric_type,
+             ) in enumerate(
+                 zip(
+                     record["concept_ids"],
+                     record["units"],
+                     record["concept_value_masks"],
+                     record["number_as_values"],
+                     record["concept_as_values"],
+                     record["is_numeric_types"],
+                 )
+             ):
+                 if concept_value_mask == 1:
+                     value = UNKNOWN_BIN
+                     if is_numeric_type == 1:
+                         if concept_id in self._concept_tokenizer.numeric_concept_ids:
+                             value = self._concept_tokenizer.normalize(
+                                 concept_id, unit, number_as_value
+                             )
+                     elif isinstance(concept_as_value, str):
+                         value = concept_as_value
+                     values.append(value)
+                 else:
+                     values.append(NONE_BIN)
+             assert len(values) == len(record["input_ids"])
+             record["values"] = self._concept_tokenizer.encode_value(values)
+         else:
+             record["values"] = self._concept_tokenizer.encode_value(
+                 [NONE_BIN for _ in range(len(record["concept_value_masks"]))]
+             )
+         # Delete these features because they contain null values and pyarrow cannot concatenate multiple records
+         del record["number_as_values"]
+         del record["concept_as_values"]
+         return record
+
+
+ class HFFineTuningMapping(HFCehrGptTokenizationMapping):
+     """Consider removing this transformation in the future."""
+
+     def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
+         record = super().transform(record)
+         record.update(
+             {
+                 "age_at_index": (
+                     record["age"] if "age" in record else record["age_at_index"]
+                 ),
+                 "classifier_label": int(record["label"] > 0),
+                 "index_date": (
+                     convert_date_to_posix_time(record["index_date"])
+                     if "index_date" in record
+                     else None
+                 ),
+             }
+         )
+         return record
+
+     def remove_columns(self):
+         columns = super().remove_columns()
+         columns.append("label")
+         return columns
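
A hedged sketch (not part of this package) of how these mappings might be applied with HuggingFace datasets. The dataset path and its column layout (concept_ids, concept_values, concept_value_masks, units, plus label for fine-tuning) are assumptions; the tokenizer loading mirrors the sketch above.

    from datasets import load_from_disk

    from cehrgpt.data.hf_cehrgpt_dataset_mapping import HFFineTuningMapping
    from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer

    tokenizer = CehrGptTokenizer.from_pretrained("/path/to/tokenizer")  # assumed API
    mapping = HFFineTuningMapping(concept_tokenizer=tokenizer)

    dataset = load_from_disk("/path/to/patient_sequences")  # hypothetical dataset
    dataset = dataset.map(
        mapping.transform,
        remove_columns=mapping.remove_columns(),
    )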