cehrgpt 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
- cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +279 -2
- cehrgpt/data/sample_packing_sampler.py +151 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +3 -0
- cehrgpt/models/config.py +10 -0
- cehrgpt/models/hf_cehrgpt.py +244 -73
- cehrgpt/models/tokenization_hf_cehrgpt.py +6 -2
- cehrgpt/runners/data_utils.py +243 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +154 -260
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +250 -90
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -0
- cehrgpt/runners/hyperparameter_search_util.py +4 -1
- cehrgpt/runners/sample_packing_trainer.py +168 -0
- cehrgpt/simulations/__init__.py +0 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/tools/generate_causal_patient_split_by_age.py +146 -0
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +57 -9
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +30 -18
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
cehrgpt/data/hf_cehrgpt_dataset_mapping.py
CHANGED
@@ -1,8 +1,23 @@
 import datetime
-from typing import Any, Dict
+from typing import Any, Dict, Generator, Optional
 
 import numpy as np
-
+import pandas as pd
+from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
+    ED_VISIT_TYPE_CODES,
+    INPATIENT_VISIT_TYPE_CODES,
+    INPATIENT_VISIT_TYPES,
+    DatasetMapping,
+    VisitObject,
+    get_value,
+    has_events_and_get_events,
+    replace_escape_chars,
+)
+from cehrbert.med_extension.schema_extension import Event
+from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
+from cehrbert_data.const.common import NA
+from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
+from dateutil.relativedelta import relativedelta
 
 from cehrgpt.models.tokenization_hf_cehrgpt import (
     NONE_BIN,
@@ -17,6 +32,268 @@ def convert_date_to_posix_time(index_date: datetime.date) -> float:
     ).timestamp()
 
 
+class MedToCehrGPTDatasetMapping(DatasetMapping):
+    def __init__(
+        self,
+        data_args: DataTrainingArguments,
+        include_inpatient_hour_token: bool = True,
+    ):
+        self._time_token_function = get_att_function(data_args.att_function_type)
+        self._include_auxiliary_token = data_args.include_auxiliary_token
+        self._inpatient_time_token_function = get_att_function(
+            data_args.inpatient_att_function_type
+        )
+        self._include_demographic_prompt = data_args.include_demographic_prompt
+        self._include_inpatient_hour_token = include_inpatient_hour_token
+
+    """
+    This mapping function converts the MED (https://github.com/Medical-Event-Data-Standard/meds/tree/main) extension
+    to the CehrGPT format. We make several assumptions
+    - The first event contains the demographic information
+    - From the second event onward
+        - the time of the event is visit_start_datetime.
+        - the first measurement contains the code indicating a standard OMOP Visit concept_id (e.g. 9201, 9202)
+        - in case of inpatient visits, the last measurement is assumed to
+            contain the standard OMOP concept id for discharge facilities (e.g 8536)
+        - in case of inpatient visits, datetime_value of the last measurement stores visit_end_datetime
+    """
+
+    def remove_columns(self):
+        return ["patient_id", "visits", "birth_datetime"]
+
+    @staticmethod
+    def _update_cehrgpt_record(
+        cehrgpt_record: Dict[str, Any],
+        code: str,
+        concept_value_mask: int = 0,
+        number_as_value: float = 0.0,
+        concept_as_value: str = "0",
+        is_numeric_type: int = 0,
+        unit: str = NA,
+    ) -> None:
+        cehrgpt_record["concept_ids"].append(replace_escape_chars(code))
+        cehrgpt_record["concept_value_masks"].append(concept_value_mask)
+        cehrgpt_record["number_as_values"].append(number_as_value)
+        cehrgpt_record["concept_as_values"].append(concept_as_value)
+        cehrgpt_record["units"].append(unit)
+        cehrgpt_record["is_numeric_types"].append(is_numeric_type)
+
+    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
+        cehrgpt_record = {
+            "person_id": record["patient_id"],
+            "concept_ids": [],
+            "concept_value_masks": [],
+            "number_as_values": [],
+            "concept_as_values": [],
+            "units": [],
+            "is_numeric_types": [],
+        }
+        # Extract the demographic information
+        birth_datetime = record["birth_datetime"]
+        if isinstance(birth_datetime, pd.Timestamp):
+            birth_datetime = birth_datetime.to_pydatetime()
+        gender = record["gender"]
+        race = record["race"]
+        visits = record["visits"]
+        # This indicates this is columnar format
+        if isinstance(visits, dict):
+            visits = sorted(self.convert_visit_columnar_to_python(visits))
+        else:
+            visits = sorted(visits, key=lambda _: get_value(_, "visit_start_datetime"))
+
+        # Add the demographic tokens
+        first_visit = visits[0]
+        first_visit_start_datetime: datetime.datetime = get_value(
+            first_visit, "visit_start_datetime"
+        )
+        year_str = f"year:{str(first_visit_start_datetime.year)}"
+        age_str = f"age:{str(relativedelta(first_visit_start_datetime, birth_datetime).years)}"
+        self._update_cehrgpt_record(cehrgpt_record, year_str)
+        self._update_cehrgpt_record(cehrgpt_record, age_str)
+        self._update_cehrgpt_record(cehrgpt_record, gender)
+        self._update_cehrgpt_record(cehrgpt_record, race)
+
+        # Use a data cursor to keep track of time
+        datetime_cursor: Optional[datetime.datetime] = None
+        visit: VisitObject
+        # Loop through all the visits
+        for i, visit in enumerate(visits):
+            events: Generator[Event, None, None] = get_value(visit, "events")
+            has_events, events = has_events_and_get_events(events)
+            if not has_events:
+                continue
+
+            visit_start_datetime: datetime.datetime = get_value(
+                visit, "visit_start_datetime"
+            )
+            # If visit_end_datetime is populated for the inpatient visit, we update the datetime_cursor
+            visit_end_datetime: Optional[datetime.datetime] = get_value(
+                visit, "visit_end_datetime"
+            )
+
+            # We assume the first measurement to be the visit type of the current visit
+            visit_type = get_value(visit, "visit_type")
+            is_er_or_inpatient = (
+                visit_type in INPATIENT_VISIT_TYPES
+                or visit_type in INPATIENT_VISIT_TYPE_CODES
+                or visit_type in ED_VISIT_TYPE_CODES
+            )
+
+            # Add artificial time tokens to the patient timeline if timedelta exists
+            if datetime_cursor is not None:
+                time_delta = max((visit_start_datetime - datetime_cursor).days, 0)
+                # This generates an artificial time token depending on the choice of the time token functions
+                self._update_cehrgpt_record(
+                    cehrgpt_record,
+                    code=self._time_token_function(time_delta),
+                )
+
+            datetime_cursor = visit_start_datetime
+            # Add a [VS] token
+            self._update_cehrgpt_record(
+                cehrgpt_record,
+                code="[VS]",
+            )
+            # Add a visit type token
+            self._update_cehrgpt_record(
+                cehrgpt_record,
+                code=visit_type,
+            )
+            # We need to insert an inpatient hour token right after the visit type, we calculate the hour interval
+            # with respect to the midnight of the day
+            if is_er_or_inpatient and self._include_inpatient_hour_token:
+                if datetime_cursor.hour > 0:
+                    # This generates an artificial time token depending on the choice of the time token functions
+                    self._update_cehrgpt_record(
+                        cehrgpt_record,
+                        code=f"i-H{datetime_cursor.hour}",
+                    )
+
+            # Keep track of the existing outpatient events, we don't want to add them again
+            existing_duplicate_events = list()
+            for e in events:
+                # If the event doesn't have a time stamp, we skip it
+                event_time: datetime.datetime = e["time"]
+                if not event_time:
+                    continue
+
+                # If numeric_value exists, this is a concept/value tuple, we indicate this using a concept_value_mask
+                numeric_value = e.get("numeric_value", None)
+                text_value = e.get("text_value", None)
+                # The unit might be populated with a None value
+                unit = e.get("unit", NA) if e.get("unit", NA) else NA
+                concept_value_mask = int(
+                    numeric_value is not None or text_value is not None
+                )
+                is_numeric_type = int(numeric_value is not None)
+                code = replace_escape_chars(e["code"])
+
+                # Create the event identity
+                event_identity = (
+                    (event_time, code, text_value, unit)
+                    if is_er_or_inpatient
+                    else (event_time.date(), code, text_value, unit)
+                )
+
+                # Add a medical token to the patient timeline
+                # If this is an inpatient visit, we use the event time stamps to calculate age and date
+                # because the patient can stay in the hospital for a period of time.
+                if is_er_or_inpatient:
+                    # Calculate the time diff in days w.r.t the previous measurement
+                    time_diff_days = (event_time - datetime_cursor).days
+                    # Update the datetime_cursor if the time diff between two neighboring measurements is greater than and
+                    # equal to 1 day
+                    if self._inpatient_time_token_function and time_diff_days > 0:
+                        # This generates an artificial time token depending on the choice of the time token functions
+                        self._update_cehrgpt_record(
+                            cehrgpt_record,
+                            code=f"i-{self._inpatient_time_token_function(time_diff_days)}",
+                        )
+
+                    if self._include_inpatient_hour_token:
+                        # if the time difference in days is greater than 0, we calculate the hour interval
+                        # with respect to the midnight of the day
+                        time_diff_hours = (
+                            event_time.hour
+                            if time_diff_days > 0
+                            else int(
+                                (event_time - datetime_cursor).total_seconds() // 3600
+                            )
+                        )
+
+                        if time_diff_hours > 0:
+                            # This generates an artificial time token depending on the choice of the time token functions
+                            self._update_cehrgpt_record(
+                                cehrgpt_record,
+                                code=f"i-H{time_diff_hours}",
+                            )
+
+                if event_identity in existing_duplicate_events:
+                    continue
+
+                self._update_cehrgpt_record(
+                    cehrgpt_record,
+                    code=code,
+                    concept_value_mask=concept_value_mask,
+                    unit=unit,
+                    number_as_value=numeric_value if numeric_value else 0.0,
+                    concept_as_value=(
+                        replace_escape_chars(text_value) if text_value else "0"
+                    ),
+                    is_numeric_type=is_numeric_type,
+                )
+                existing_duplicate_events.append(event_identity)
+                # we only want to update the time stamp when data_cursor is less than the event time
+                if datetime_cursor < event_time or datetime_cursor is None:
+                    datetime_cursor = event_time
+                    # We need to bound the datetime_cursor if the current visit is an admission type of visit
+                    # as the associated events could be generated after the visits are complete
+                    if is_er_or_inpatient and visit_end_datetime is not None:
+                        datetime_cursor = min(datetime_cursor, visit_end_datetime)
+
+            # For inpatient or ER visits, we want to discharge_facility to the end of the visit
+            if is_er_or_inpatient:
+                # If visit_end_datetime is populated for the inpatient visit, we update the datetime_cursor
+                if visit_end_datetime is not None:
+                    datetime_cursor = visit_end_datetime
+
+                if self._include_auxiliary_token:
+                    # Reuse the age and date calculated for the last event in the patient timeline for the discharge
+                    # facility event
+                    discharge_facility = get_value(visit, "discharge_facility")
+                    if not discharge_facility:
+                        discharge_facility = "0"
+
+                    self._update_cehrgpt_record(
+                        cehrgpt_record,
+                        code=discharge_facility,
+                    )
+
+            # Reuse the age and date calculated for the last event in the patient timeline
+            self._update_cehrgpt_record(
+                cehrgpt_record,
+                code="[VE]",
+            )
+
+        # Generate the orders of the concepts that the cehrbert dataset mapping function expects
+        cehrgpt_record["orders"] = list(
+            range(1, len(cehrgpt_record["concept_ids"]) + 1)
+        )
+
+        # Add some count information for this sequence
+        cehrgpt_record["num_of_concepts"] = len(cehrgpt_record["concept_ids"])
+        cehrgpt_record["num_of_visits"] = len(visits)
+
+        if record.get("index_date", None):
+            cehrgpt_record["index_date"] = record["index_date"]
+        if record.get("label", None):
+            cehrgpt_record["label"] = record["label"]
+        if record.get("age_at_index", None):
+            cehrgpt_record["age_at_index"] = record["age_at_index"]
+
+        return cehrgpt_record
+
+
 class HFCehrGptTokenizationMapping(DatasetMapping):
     def __init__(
         self,
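The new MedToCehrGPTDatasetMapping is a cehrbert DatasetMapping, so its transform can be driven record by record. Below is a minimal, hypothetical sketch; the stubbed DataTrainingArguments fields and the toy record are illustrative only, and it assumes get_value accepts plain-dict visits (as the non-columnar branch above suggests).

import datetime
from types import SimpleNamespace

from cehrgpt.data.hf_cehrgpt_dataset_mapping import MedToCehrGPTDatasetMapping

# Stand-in for cehrbert's DataTrainingArguments; only the attributes the
# mapping reads are stubbed here, and the values are illustrative.
data_args = SimpleNamespace(
    att_function_type="cehr_bert",
    inpatient_att_function_type="mix",
    include_auxiliary_token=True,
    include_demographic_prompt=False,
)
mapping = MedToCehrGPTDatasetMapping(data_args, include_inpatient_hour_token=True)

record = {
    "patient_id": 1,
    "birth_datetime": datetime.datetime(1980, 1, 1),
    "gender": "Gender/FEMALE",
    "race": "Race/White",
    "visits": [
        {
            "visit_type": "9202",  # OMOP outpatient visit concept id
            "visit_start_datetime": datetime.datetime(2020, 5, 1, 9, 30),
            "visit_end_datetime": None,
            "discharge_facility": None,
            "events": [
                {"time": datetime.datetime(2020, 5, 1, 9, 45), "code": "320128"},
            ],
        }
    ],
}
cehrgpt_record = mapping.transform(record)
# For this toy outpatient record, concept_ids should come out as the
# demographic prompt followed by one visit block:
# ["year:2020", "age:40", "Gender/FEMALE", "Race/White", "[VS]", "9202", "320128", "[VE]"]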
cehrgpt/data/sample_packing_sampler.py
ADDED
@@ -0,0 +1,151 @@
+from typing import Iterator, List, Optional
+
+import torch
+import torch.distributed as dist
+from torch.utils.data import Sampler
+from transformers import logging
+
+LOG = logging.get_logger("transformers")
+
+
+class SamplePlacerHolder:
+    def __init__(self):
+        self.epoch = 0
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
+
+class SamplePackingBatchSampler(Sampler[List[int]]):
+    """
+    A batch sampler that creates batches by packing samples together.
+
+    to maximize GPU utilization, ensuring the total tokens per batch
+    doesn't exceed max_tokens.
+    """
+
+    def __init__(
+        self,
+        lengths: List[int],
+        max_tokens_per_batch: int,
+        max_position_embeddings: int,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        seed: int = 0,
+        drop_last: bool = False,
+    ):
+        """
+        Args:
+
+            lengths: List of sequence lengths for each sample
+            max_tokens: Maximum number of tokens in a batch
+            drop_last: Whether to drop the last incomplete batch
+        """
+        super().__init__()
+
+        if num_replicas is None:
+            if dist.is_available() and dist.is_initialized():
+                num_replicas = dist.get_world_size()
+                LOG.info(
+                    "torch.distributed is initialized and there are %s of replicas",
+                    num_replicas,
+                )
+            else:
+                num_replicas = 1
+                LOG.info(
+                    "torch.dist is not initialized and therefore default to 1 for num_replicas"
+                )
+
+        if rank is None:
+            if dist.is_available() and dist.is_initialized():
+                rank = dist.get_rank()
+                LOG.info(
+                    "torch.distributed is initialized and the current rank is %s", rank
+                )
+            else:
+                rank = 0
+                LOG.info(
+                    "torch.distributed is not initialized and therefore default to 0 for rank"
+                )
+
+        if not (0 <= rank < num_replicas):
+            raise ValueError(
+                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
+            )
+
+        self.lengths = lengths
+        self.max_tokens_per_batch = max_tokens_per_batch
+        self.max_position_embeddings = max_position_embeddings
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.seed = seed
+        self.drop_last = drop_last
+        # Trainer https://github.com/huggingface/transformers/blame/main/src/transformers/trainer.py#L2470
+        # http://github.com/huggingface/accelerate/blob/v0.31.0/src/accelerate/data_loader.py#L482
+        # the huggingface trainer will call the accelerate.data_loader.DataLoaderShard.set_epoch,
+        # which will call batch_sampler.sample.set_epoch
+        self.sampler = SamplePlacerHolder()
+
+    def __iter__(self) -> Iterator[List[int]]:
+
+        # deterministically shuffle based on epoch and seed
+        g = torch.Generator()
+        g.manual_seed(self.seed + self.sampler.epoch)
+        indices = torch.randperm(len(self.lengths), generator=g).tolist()
+
+        # Partition indices for this rank
+        indices = indices[self.rank :: self.num_replicas]
+
+        batch = []
+        current_batch_tokens = 0
+
+        for idx in indices:
+            # We take the minimum of the two because each sequence will be truncated to fit
+            # the context window of the model
+            sample_length = min(self.lengths[idx], self.max_position_embeddings)
+            # If adding this sample would exceed max_tokens_per_batch, yield the current batch
+            if (
+                current_batch_tokens + sample_length + 2 > self.max_tokens_per_batch
+                and batch
+            ):
+                yield batch
+                batch = []
+                current_batch_tokens = 0
+
+            # Add the sample to the current batch
+            batch.append(idx)
+            # plus extract one for the [END] and [PAD] tokens to separate samples
+            current_batch_tokens += sample_length + 2
+
+        # Yield the last batch if it's not empty and we're not dropping it
+        if batch and not self.drop_last:
+            yield batch
+
+    def __len__(self) -> int:
+        """
+        Estimates the number of batches that will be generated.
+
+        This is an approximation since the exact number depends on the specific
+        sequence lengths and their order.
+        """
+        if len(self.lengths) == 0:
+            return 0
+
+        # We need to truncate the lengths due to the context window limit imposed by the model
+        truncated_lengths = [
+            min(self.max_position_embeddings, length + 2) for length in self.lengths
+        ]
+
+        # Calculate average sequence length
+        avg_seq_length = sum(truncated_lengths) // len(truncated_lengths)
+
+        # Estimate average number of sequences per batch
+        seqs_per_batch = self.max_tokens_per_batch // avg_seq_length
+
+        # Estimate total number of batches
+        if self.drop_last:
+            # If dropping last incomplete batch
+            return len(truncated_lengths) // seqs_per_batch * self.num_replicas
+        else:
+            # If keeping last incomplete batch, ensure at least 1 batch
+            return max(1, len(truncated_lengths) // seqs_per_batch) * self.num_replicas
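A minimal sketch of wiring the new batch sampler into a PyTorch DataLoader; the lengths, token budget, and the dataset/collator mentioned at the end are placeholders, not values taken from the package.

from torch.utils.data import DataLoader

from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler

# Per-sample token counts, e.g. len(record["concept_ids"]) for each record.
lengths = [743, 1210, 96, 3005, 412]

batch_sampler = SamplePackingBatchSampler(
    lengths=lengths,
    max_tokens_per_batch=4096,     # token budget per packed batch
    max_position_embeddings=2048,  # longer samples count as truncated to this
    seed=42,
)

# Each yielded batch is a list of dataset indices; every sample contributes
# its truncated length plus 2 tokens of [END]/[PAD] overhead, so the batch
# total stays within max_tokens_per_batch.
for batch_indices in batch_sampler:
    print(batch_indices)

# With a map-style dataset and a packing collator (both placeholders here):
# loader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=packing_collator)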
cehrgpt/generation/generate_batch_hf_gpt_sequence.py
CHANGED
@@ -93,9 +93,9 @@ def generate_single_batch(
         temperature=temperature,
         top_p=top_p,
         top_k=top_k,
-        bos_token_id=
-        eos_token_id=
-        pad_token_id=
+        bos_token_id=model.generation_config.bos_token_id,
+        eos_token_id=model.generation_config.eos_token_id,
+        pad_token_id=model.generation_config.pad_token_id,
         do_sample=True,
         use_cache=True,
         return_dict_in_generate=True,
@@ -150,15 +150,11 @@ def main(args):
             attn_implementation=(
                 "flash_attention_2" if is_flash_attn_2_available() else "eager"
             ),
-            torch_dtype=(
-                torch.bfloat16
-                if is_flash_attn_2_available() and args.use_bfloat16
-                else torch.float32
-            ),
         )
         .eval()
         .to(device)
     )
+
     cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
     cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
     cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id
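Taken together, these two hunks move the special-token ids out of generate_single_batch: main() copies them from the tokenizer onto the model's generation_config, and the generation call reads them back from there. A hedged sketch of the resulting call, assuming the kwargs above are forwarded to a Hugging Face-style model.generate; token_ids and max_new_tokens are illustrative, not from the package.

results = model.generate(
    inputs=token_ids,               # hypothetical batch of encoded prompts
    max_new_tokens=max_new_tokens,  # hypothetical generation budget
    temperature=temperature,
    top_p=top_p,
    top_k=top_k,
    bos_token_id=model.generation_config.bos_token_id,
    eos_token_id=model.generation_config.eos_token_id,
    pad_token_id=model.generation_config.pad_token_id,
    do_sample=True,
    use_cache=True,
    return_dict_in_generate=True,
)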
@@ -192,6 +188,7 @@ def main(args):
     LOG.info(f"Top P {args.top_p}")
     LOG.info(f"Top K {args.top_k}")
     LOG.info(f"Loading demographic_info at {args.demographic_data_path}")
+    LOG.info(f"MEDS format: {args.meds_format}")
 
     dataset = load_parquet_as_dataset(args.demographic_data_path)
     total_rows = len(dataset)
@@ -199,6 +196,7 @@ def main(args):
     num_of_batches = args.num_of_patients // args.batch_size + 1
     sequence_to_flush = []
     current_person_id = 1
+    prompt_size = 2 if args.meds_format else START_TOKEN_SIZE
     for i in range(num_of_batches):
         LOG.info(f"{datetime.datetime.now()}: Batch {i} started")
 
@@ -215,7 +213,7 @@ def main(args):
                 <= max_seq_allowed
             ):
                 random_prompts.append(
-                    cehrgpt_tokenizer.encode(row["concept_ids"][:
+                    cehrgpt_tokenizer.encode(row["concept_ids"][:prompt_size])
                 )
                 iter += 1
             if not random_prompts and iter > 10:
@@ -326,6 +324,11 @@ def create_arg_parser():
         dest="drop_long_sequences",
         action="store_true",
     )
+    base_arg_parser.add_argument(
+        "--meds_format",
+        dest="meds_format",
+        action="store_true",
+    )
     return base_arg_parser
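The net effect of the new --meds_format flag on prompting, as an illustrative recap (the token values below are examples; the demographic prompt layout comes from the MedToCehrGPTDatasetMapping diff above):

concept_ids = ["year:2020", "age:40", "Gender/FEMALE", "Race/White", "[VS]"]
prompt_size = 2 if args.meds_format else START_TOKEN_SIZE  # START_TOKEN_SIZE = 4
prompt = cehrgpt_tokenizer.encode(concept_ids[:prompt_size])
# MEDS format keeps only the first two tokens (year and age); otherwise the
# full four-token demographic prompt seeds generation.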
cehrgpt/generation/omop_converter_batch.py
CHANGED
@@ -35,6 +35,7 @@ from cehrgpt.models.tokenization_hf_cehrgpt import END_TOKEN
 # TODO: move these to cehrbert_data
 STOP_TOKENS = ["VE", "[VE]", END_TOKEN]
 
+OOV = "[OOV]"
 CURRENT_PATH = Path(__file__).parent
 START_TOKEN_SIZE = 4
 ATT_TIME_TOKENS = generate_artificial_time_tokens()
@@ -297,6 +298,8 @@ def gpt_to_omop_converter_batch(
     inpatient_visit_indicator = False
 
     for event_idx, event in enumerate(clinical_events, 0):
+        if event == OOV:
+            continue
         # For bad sequences, we don't proceed further and break from the for loop
         if bad_sequence:
             break
cehrgpt/models/config.py
CHANGED
@@ -133,14 +133,17 @@ class CEHRGPTConfig(PretrainedConfig):
         n_pretrained_embeddings_layers=2,
         pretrained_embedding_dim=768,
         pretrained_token_ids: List[int] = None,
+        next_token_prediction_loss_weight=1.0,
         time_token_loss_weight=1.0,
         time_to_visit_loss_weight=1.0,
         causal_sfm=False,
         demographics_size=4,
         lab_token_penalty=False,
         lab_token_loss_weight=0.9,
+        value_prediction_loss_weight=1.0,
         entropy_penalty=False,
         entropy_penalty_alpha=0.01,
+        sample_packing_max_positions=None,
         **kwargs,
     ):
         if token_to_time_token_mapping is None:
@@ -150,6 +153,11 @@ class CEHRGPTConfig(PretrainedConfig):
         self.vocab_size = vocab_size
         self.time_token_vocab_size = time_token_vocab_size
         self.n_positions = n_positions
+        self.sample_packing_max_positions = (
+            sample_packing_max_positions
+            if sample_packing_max_positions
+            else n_positions
+        )
         self.n_embd = n_embd
         self.n_layer = n_layer
         self.n_head = n_head
@@ -178,6 +186,7 @@ class CEHRGPTConfig(PretrainedConfig):
         self.include_values = include_values
         self.value_vocab_size = value_vocab_size
 
+        self.next_token_prediction_loss_weight = next_token_prediction_loss_weight
         self.include_ttv_prediction = include_ttv_prediction
         self.use_sub_time_tokenization = use_sub_time_tokenization
         self._token_to_time_token_mapping = token_to_time_token_mapping
@@ -195,6 +204,7 @@ class CEHRGPTConfig(PretrainedConfig):
         self.lab_token_loss_weight = lab_token_loss_weight
         self.entropy_penalty = entropy_penalty
         self.entropy_penalty_alpha = entropy_penalty_alpha
+        self.value_prediction_loss_weight = value_prediction_loss_weight
 
         kwargs["tie_word_embeddings"] = not use_pretrained_embeddings
 
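A hedged sketch of the three new configuration knobs; every other constructor argument keeps its default, and the values shown are illustrative rather than recommended.

from cehrgpt.models.config import CEHRGPTConfig

config = CEHRGPTConfig(
    # Weight of the next-token prediction term in the overall loss.
    next_token_prediction_loss_weight=1.0,
    # Weight of the value prediction term.
    value_prediction_loss_weight=1.0,
    # Position limit used for sample packing; per the diff above it
    # falls back to n_positions when left as None.
    sample_packing_max_positions=8192,
)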