cehrgpt 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +99 -88
- cehrgpt/data/sample_packing_sampler.py +151 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/models/config.py +10 -0
- cehrgpt/models/hf_cehrgpt.py +243 -73
- cehrgpt/models/tokenization_hf_cehrgpt.py +4 -0
- cehrgpt/runners/data_utils.py +243 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +152 -279
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +229 -105
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +42 -0
- cehrgpt/runners/hyperparameter_search_util.py +4 -1
- cehrgpt/runners/sample_packing_trainer.py +168 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +7 -5
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +28 -26
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
cehrgpt/data/hf_cehrgpt_dataset_mapping.py
CHANGED

```diff
@@ -1,5 +1,5 @@
 import datetime
-from typing import Any, Dict
+from typing import Any, Dict, Generator, Optional
 
 import numpy as np
 import pandas as pd
@@ -8,8 +8,12 @@ from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
     INPATIENT_VISIT_TYPE_CODES,
     INPATIENT_VISIT_TYPES,
     DatasetMapping,
+    VisitObject,
+    get_value,
+    has_events_and_get_events,
     replace_escape_chars,
 )
+from cehrbert.med_extension.schema_extension import Event
 from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
 from cehrbert_data.const.common import NA
 from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
@@ -32,7 +36,6 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
     def __init__(
         self,
         data_args: DataTrainingArguments,
-        is_pretraining: bool = True,
         include_inpatient_hour_token: bool = True,
     ):
         self._time_token_function = get_att_function(data_args.att_function_type)
@@ -41,7 +44,6 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
             data_args.inpatient_att_function_type
         )
         self._include_demographic_prompt = data_args.include_demographic_prompt
-        self._is_pretraining = is_pretraining
         self._include_inpatient_hour_token = include_inpatient_hour_token
 
     """
@@ -57,14 +59,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
     """
 
     def remove_columns(self):
-        if self._is_pretraining:
-            return ["visits", "birth_datetime", "index_date"]
-        else:
-            return [
-                "visits",
-                "birth_datetime",
-                "visit_concept_ids",
-            ]
+        return ["patient_id", "visits", "birth_datetime"]
 
     @staticmethod
     def _update_cehrgpt_record(
@@ -99,38 +94,45 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
             birth_datetime = birth_datetime.to_pydatetime()
         gender = record["gender"]
         race = record["race"]
+        visits = record["visits"]
+        # This indicates this is columnar format
+        if isinstance(visits, dict):
+            visits = sorted(self.convert_visit_columnar_to_python(visits))
+        else:
+            visits = sorted(visits, key=lambda _: get_value(_, "visit_start_datetime"))
 
         # Add the demographic tokens
-        first_visit =
-
-
+        first_visit = visits[0]
+        first_visit_start_datetime: datetime.datetime = get_value(
+            first_visit, "visit_start_datetime"
+        )
+        year_str = f"year:{str(first_visit_start_datetime.year)}"
+        age_str = f"age:{str(relativedelta(first_visit_start_datetime, birth_datetime).years)}"
         self._update_cehrgpt_record(cehrgpt_record, year_str)
         self._update_cehrgpt_record(cehrgpt_record, age_str)
         self._update_cehrgpt_record(cehrgpt_record, gender)
         self._update_cehrgpt_record(cehrgpt_record, race)
 
         # Use a data cursor to keep track of time
-
-
-        # Loop through all the visits
-        for i, visit in enumerate(
-
-
-
-            events = visit["events"]
-
-            # Skip this visit if the number measurements in the event is zero
-            if events is None or len(events) == 0:
+        datetime_cursor: Optional[datetime.datetime] = None
+        visit: VisitObject
+        # Loop through all the visits
+        for i, visit in enumerate(visits):
+            events: Generator[Event, None, None] = get_value(visit, "events")
+            has_events, events = has_events_and_get_events(events)
+            if not has_events:
                 continue
 
-            visit_start_datetime =
-
-
+            visit_start_datetime: datetime.datetime = get_value(
+                visit, "visit_start_datetime"
+            )
+            # If visit_end_datetime is populated for the inpatient visit, we update the datetime_cursor
+            visit_end_datetime: Optional[datetime.datetime] = get_value(
+                visit, "visit_end_datetime"
             )
-            date_cursor = visit_start_datetime
 
             # We assume the first measurement to be the visit type of the current visit
-            visit_type = visit["visit_type"]
+            visit_type = get_value(visit, "visit_type")
            is_er_or_inpatient = (
                 visit_type in INPATIENT_VISIT_TYPES
                 or visit_type in INPATIENT_VISIT_TYPE_CODES
@@ -138,21 +140,15 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
             )
 
             # Add artificial time tokens to the patient timeline if timedelta exists
-            if
+            if datetime_cursor is not None:
+                time_delta = max((visit_start_datetime - datetime_cursor).days, 0)
                 # This generates an artificial time token depending on the choice of the time token functions
                 self._update_cehrgpt_record(
                     cehrgpt_record,
                     code=self._time_token_function(time_delta),
                 )
 
-
-            relativedelta(visit["visit_start_datetime"], birth_datetime).years
-            # Calculate the week number since the epoch time
-            date = (
-                visit["visit_start_datetime"]
-                - datetime.datetime(year=1970, month=1, day=1)
-            ).days // 7
-
+            datetime_cursor = visit_start_datetime
             # Add a [VS] token
             self._update_cehrgpt_record(
                 cehrgpt_record,
@@ -163,11 +159,22 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
                 cehrgpt_record,
                 code=visit_type,
             )
+            # We need to insert an inpatient hour token right after the visit type, we calculate the hour interval
+            # with respect to the midnight of the day
+            if is_er_or_inpatient and self._include_inpatient_hour_token:
+                if datetime_cursor.hour > 0:
+                    # This generates an artificial time token depending on the choice of the time token functions
+                    self._update_cehrgpt_record(
+                        cehrgpt_record,
+                        code=f"i-H{datetime_cursor.hour}",
+                    )
+
             # Keep track of the existing outpatient events, we don't want to add them again
-
+            existing_duplicate_events = list()
             for e in events:
                 # If the event doesn't have a time stamp, we skip it
-
+                event_time: datetime.datetime = e["time"]
+                if not event_time:
                     continue
 
                 # If numeric_value exists, this is a concept/value tuple, we indicate this using a concept_value_mask
@@ -181,40 +188,48 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
                 is_numeric_type = int(numeric_value is not None)
                 code = replace_escape_chars(e["code"])
 
+                # Create the event identity
+                event_identity = (
+                    (event_time, code, text_value, unit)
+                    if is_er_or_inpatient
+                    else (event_time.date(), code, text_value, unit)
+                )
+
                 # Add a medical token to the patient timeline
                 # If this is an inpatient visit, we use the event time stamps to calculate age and date
                 # because the patient can stay in the hospital for a period of time.
                 if is_er_or_inpatient:
-                    # Calculate the week number since the epoch time
-                    date = (
-                        e["time"] - datetime.datetime(year=1970, month=1, day=1)
-                    ).days // 7
                     # Calculate the time diff in days w.r.t the previous measurement
-
-                    # Update the
+                    time_diff_days = (event_time - datetime_cursor).days
+                    # Update the datetime_cursor if the time diff between two neighboring measurements is greater than and
                     # equal to 1 day
-                    if
-
-
+                    if self._inpatient_time_token_function and time_diff_days > 0:
+                        # This generates an artificial time token depending on the choice of the time token functions
+                        self._update_cehrgpt_record(
+                            cehrgpt_record,
+                            code=f"i-{self._inpatient_time_token_function(time_diff_days)}",
+                        )
+
+                    if self._include_inpatient_hour_token:
+                        # if the time difference in days is greater than 0, we calculate the hour interval
+                        # with respect to the midnight of the day
+                        time_diff_hours = (
+                            event_time.hour
+                            if time_diff_days > 0
+                            else int(
+                                (event_time - datetime_cursor).total_seconds() // 3600
+                            )
+                        )
+
+                        if time_diff_hours > 0:
                             # This generates an artificial time token depending on the choice of the time token functions
                             self._update_cehrgpt_record(
                                 cehrgpt_record,
-                                code=f"i-{
+                                code=f"i-H{time_diff_hours}",
                             )
-
-
-
-                # We check whether the date/code/value combination already exists in the existing events
-                # If they exist, we do not add them to the patient timeline for outpatient visits.
-                if (
-                    date,
-                    code,
-                    numeric_value,
-                    text_value,
-                    concept_value_mask,
-                    numeric_value,
-                ) in existing_outpatient_events:
-                    continue
+
+                if event_identity in existing_duplicate_events:
+                    continue
 
                 self._update_cehrgpt_record(
                     cehrgpt_record,
@@ -227,33 +242,27 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
                     ),
                     is_numeric_type=is_numeric_type,
                 )
-
-
-
-
-
-
-
-
-                    )
-                )
+                existing_duplicate_events.append(event_identity)
+                # we only want to update the time stamp when data_cursor is less than the event time
+                if datetime_cursor < event_time or datetime_cursor is None:
+                    datetime_cursor = event_time
+                    # We need to bound the datetime_cursor if the current visit is an admission type of visit
+                    # as the associated events could be generated after the visits are complete
+                    if is_er_or_inpatient and visit_end_datetime is not None:
+                        datetime_cursor = min(datetime_cursor, visit_end_datetime)
 
             # For inpatient or ER visits, we want to discharge_facility to the end of the visit
             if is_er_or_inpatient:
-                # If visit_end_datetime is populated for the inpatient visit, we update the
-                visit_end_datetime
-
-                date_cursor = visit_end_datetime
+                # If visit_end_datetime is populated for the inpatient visit, we update the datetime_cursor
+                if visit_end_datetime is not None:
+                    datetime_cursor = visit_end_datetime
 
             if self._include_auxiliary_token:
                 # Reuse the age and date calculated for the last event in the patient timeline for the discharge
                 # facility event
-                discharge_facility = (
-
-
-                    and visit["discharge_facility"]
-                    else "0"
-                )
+                discharge_facility = get_value(visit, "discharge_facility")
+                if not discharge_facility:
+                    discharge_facility = "0"
 
                 self._update_cehrgpt_record(
                     cehrgpt_record,
@@ -273,11 +282,13 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
 
         # Add some count information for this sequence
         cehrgpt_record["num_of_concepts"] = len(cehrgpt_record["concept_ids"])
-        cehrgpt_record["num_of_visits"] = len(
+        cehrgpt_record["num_of_visits"] = len(visits)
 
-        if "
+        if record.get("index_date", None):
+            cehrgpt_record["index_date"] = record["index_date"]
+        if record.get("label", None):
             cehrgpt_record["label"] = record["label"]
-        if "age_at_index"
+        if record.get("age_at_index", None):
             cehrgpt_record["age_at_index"] = record["age_at_index"]
 
         return cehrgpt_record
```
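The mapping rewrite above replaces the old epoch-week dedup key with an explicit per-visit event identity: inpatient/ER events are keyed on their full timestamp, outpatient events on the calendar date only. A minimal sketch of that rule (illustrative only; `dedup_events` and its tuple-shaped events are hypothetical stand-ins, not the packaged API):

```python
import datetime

def dedup_events(events, is_er_or_inpatient: bool):
    # Hypothetical helper mirroring the event_identity logic in the diff:
    # inpatient/ER events keep the full timestamp, outpatient events
    # collapse to the calendar date, so repeated same-day codes are dropped.
    seen = []  # the packaged code likewise tracks identities in a list
    for event_time, code, text_value, unit in events:
        event_identity = (
            (event_time, code, text_value, unit)
            if is_er_or_inpatient
            else (event_time.date(), code, text_value, unit)
        )
        if event_identity in seen:
            continue
        seen.append(event_identity)
        yield event_time, code, text_value, unit

day = datetime.datetime(2024, 1, 1)
events = [
    (day.replace(hour=9), "1234", None, "mmHg"),
    (day.replace(hour=17), "1234", None, "mmHg"),
]
# Outpatient: the second same-day reading is treated as a duplicate.
assert len(list(dedup_events(events, is_er_or_inpatient=False))) == 1
# Inpatient: distinct timestamps keep both readings.
assert len(list(dedup_events(events, is_er_or_inpatient=True))) == 2
```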
cehrgpt/data/sample_packing_sampler.py
ADDED

```diff
@@ -0,0 +1,151 @@
+from typing import Iterator, List, Optional
+
+import torch
+import torch.distributed as dist
+from torch.utils.data import Sampler
+from transformers import logging
+
+LOG = logging.get_logger("transformers")
+
+
+class SamplePlacerHolder:
+    def __init__(self):
+        self.epoch = 0
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
+
+class SamplePackingBatchSampler(Sampler[List[int]]):
+    """
+    A batch sampler that creates batches by packing samples together.
+
+    to maximize GPU utilization, ensuring the total tokens per batch
+    doesn't exceed max_tokens.
+    """
+
+    def __init__(
+        self,
+        lengths: List[int],
+        max_tokens_per_batch: int,
+        max_position_embeddings: int,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        seed: int = 0,
+        drop_last: bool = False,
+    ):
+        """
+        Args:
+
+            lengths: List of sequence lengths for each sample
+            max_tokens: Maximum number of tokens in a batch
+            drop_last: Whether to drop the last incomplete batch
+        """
+        super().__init__()
+
+        if num_replicas is None:
+            if dist.is_available() and dist.is_initialized():
+                num_replicas = dist.get_world_size()
+                LOG.info(
+                    "torch.distributed is initialized and there are %s of replicas",
+                    num_replicas,
+                )
+            else:
+                num_replicas = 1
+                LOG.info(
+                    "torch.dist is not initialized and therefore default to 1 for num_replicas"
+                )
+
+        if rank is None:
+            if dist.is_available() and dist.is_initialized():
+                rank = dist.get_rank()
+                LOG.info(
+                    "torch.distributed is initialized and the current rank is %s", rank
+                )
+            else:
+                rank = 0
+                LOG.info(
+                    "torch.distributed is not initialized and therefore default to 0 for rank"
+                )
+
+        if not (0 <= rank < num_replicas):
+            raise ValueError(
+                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
+            )
+
+        self.lengths = lengths
+        self.max_tokens_per_batch = max_tokens_per_batch
+        self.max_position_embeddings = max_position_embeddings
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.seed = seed
+        self.drop_last = drop_last
+        # Trainer https://github.com/huggingface/transformers/blame/main/src/transformers/trainer.py#L2470
+        # http://github.com/huggingface/accelerate/blob/v0.31.0/src/accelerate/data_loader.py#L482
+        # the huggingface trainer will call the accelerate.data_loader.DataLoaderShard.set_epoch,
+        # which will call batch_sampler.sample.set_epoch
+        self.sampler = SamplePlacerHolder()
+
+    def __iter__(self) -> Iterator[List[int]]:
+
+        # deterministically shuffle based on epoch and seed
+        g = torch.Generator()
+        g.manual_seed(self.seed + self.sampler.epoch)
+        indices = torch.randperm(len(self.lengths), generator=g).tolist()
+
+        # Partition indices for this rank
+        indices = indices[self.rank :: self.num_replicas]
+
+        batch = []
+        current_batch_tokens = 0
+
+        for idx in indices:
+            # We take the minimum of the two because each sequence will be truncated to fit
+            # the context window of the model
+            sample_length = min(self.lengths[idx], self.max_position_embeddings)
+            # If adding this sample would exceed max_tokens_per_batch, yield the current batch
+            if (
+                current_batch_tokens + sample_length + 2 > self.max_tokens_per_batch
+                and batch
+            ):
+                yield batch
+                batch = []
+                current_batch_tokens = 0
+
+            # Add the sample to the current batch
+            batch.append(idx)
+            # plus extract one for the [END] and [PAD] tokens to separate samples
+            current_batch_tokens += sample_length + 2
+
+        # Yield the last batch if it's not empty and we're not dropping it
+        if batch and not self.drop_last:
+            yield batch
+
+    def __len__(self) -> int:
+        """
+        Estimates the number of batches that will be generated.
+
+        This is an approximation since the exact number depends on the specific
+        sequence lengths and their order.
+        """
+        if len(self.lengths) == 0:
+            return 0
+
+        # We need to truncate the lengths due to the context window limit imposed by the model
+        truncated_lengths = [
+            min(self.max_position_embeddings, length + 2) for length in self.lengths
+        ]
+
+        # Calculate average sequence length
+        avg_seq_length = sum(truncated_lengths) // len(truncated_lengths)
+
+        # Estimate average number of sequences per batch
+        seqs_per_batch = self.max_tokens_per_batch // avg_seq_length
+
+        # Estimate total number of batches
+        if self.drop_last:
+            # If dropping last incomplete batch
+            return len(truncated_lengths) // seqs_per_batch * self.num_replicas
+        else:
+            # If keeping last incomplete batch, ensure at least 1 batch
+            return max(1, len(truncated_lengths) // seqs_per_batch) * self.num_replicas
```
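A minimal usage sketch for the new batch sampler (assumptions: `ToyDataset`, the toy lengths, and the pass-through collator below are made up; in this release the actual packing is done by the collator in `cehrgpt/data/hf_cehrgpt_dataset_collator.py`, while the sampler only groups indices so that each batch stays under `max_tokens_per_batch`):

```python
import torch
from torch.utils.data import DataLoader, Dataset

from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler

class ToyDataset(Dataset):  # hypothetical stand-in for a tokenized EHR dataset
    def __init__(self, lengths):
        self.lengths = lengths

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, idx):
        return {"input_ids": torch.zeros(self.lengths[idx], dtype=torch.long)}

lengths = [512, 1024, 128, 2048, 300]  # per-sample token counts
dataset = ToyDataset(lengths)
batch_sampler = SamplePackingBatchSampler(
    lengths=lengths,
    max_tokens_per_batch=4096,
    max_position_embeddings=1024,  # longer samples are counted as truncated
    drop_last=False,
)
loader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=lambda x: x)
for batch in loader:
    # Each batch is a list of samples whose packed length (plus the 2 separator
    # tokens each, per the sampler's accounting) fits the token budget.
    assert sum(min(len(s["input_ids"]), 1024) + 2 for s in batch) <= 4096
```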
cehrgpt/generation/generate_batch_hf_gpt_sequence.py
CHANGED

```diff
@@ -93,9 +93,9 @@ def generate_single_batch(
         temperature=temperature,
         top_p=top_p,
         top_k=top_k,
-        bos_token_id=
-        eos_token_id=
-        pad_token_id=
+        bos_token_id=model.generation_config.bos_token_id,
+        eos_token_id=model.generation_config.eos_token_id,
+        pad_token_id=model.generation_config.pad_token_id,
         do_sample=True,
         use_cache=True,
         return_dict_in_generate=True,
@@ -150,15 +150,11 @@ def main(args):
             attn_implementation=(
                 "flash_attention_2" if is_flash_attn_2_available() else "eager"
             ),
-            torch_dtype=(
-                torch.bfloat16
-                if is_flash_attn_2_available() and args.use_bfloat16
-                else torch.float32
-            ),
         )
         .eval()
         .to(device)
     )
+
     cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
     cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
     cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id
@@ -192,6 +188,7 @@ def main(args):
     LOG.info(f"Top P {args.top_p}")
     LOG.info(f"Top K {args.top_k}")
     LOG.info(f"Loading demographic_info at {args.demographic_data_path}")
+    LOG.info(f"MEDS format: {args.meds_format}")
 
     dataset = load_parquet_as_dataset(args.demographic_data_path)
     total_rows = len(dataset)
@@ -199,6 +196,7 @@ def main(args):
     num_of_batches = args.num_of_patients // args.batch_size + 1
     sequence_to_flush = []
     current_person_id = 1
+    prompt_size = 2 if args.meds_format else START_TOKEN_SIZE
     for i in range(num_of_batches):
         LOG.info(f"{datetime.datetime.now()}: Batch {i} started")
 
@@ -215,7 +213,7 @@ def main(args):
                 <= max_seq_allowed
             ):
                 random_prompts.append(
-                    cehrgpt_tokenizer.encode(row["concept_ids"][:
+                    cehrgpt_tokenizer.encode(row["concept_ids"][:prompt_size])
                 )
             iter += 1
         if not random_prompts and iter > 10:
@@ -326,6 +324,11 @@
         dest="drop_long_sequences",
         action="store_true",
     )
+    base_arg_parser.add_argument(
+        "--meds_format",
+        dest="meds_format",
+        action="store_true",
+    )
     return base_arg_parser
 
 
```
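The new `--meds_format` flag changes how generation prompts are sliced from real sequences: MEDS-style data is prompted with its first two tokens, while the default path keeps the demographic prompt of `START_TOKEN_SIZE` tokens. A sketch of that rule (the value 4 for `START_TOKEN_SIZE` is an assumption, based on the year/age/gender/race prompt built in hf_cehrgpt_dataset_mapping.py; `prompt_tokens` is a hypothetical helper):

```python
START_TOKEN_SIZE = 4  # assumed: year, age, gender, race demographic tokens

def prompt_tokens(concept_ids, meds_format: bool):
    # Mirrors the prompt_size logic added to main() above.
    prompt_size = 2 if meds_format else START_TOKEN_SIZE
    return concept_ids[:prompt_size]

seq = ["year:2021", "age:63", "gender_token", "race_token", "[VS]", "9202"]
assert prompt_tokens(seq, meds_format=True) == ["year:2021", "age:63"]
assert prompt_tokens(seq, meds_format=False) == seq[:4]
```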
cehrgpt/models/config.py
CHANGED
```diff
@@ -133,14 +133,17 @@ class CEHRGPTConfig(PretrainedConfig):
         n_pretrained_embeddings_layers=2,
         pretrained_embedding_dim=768,
         pretrained_token_ids: List[int] = None,
+        next_token_prediction_loss_weight=1.0,
         time_token_loss_weight=1.0,
         time_to_visit_loss_weight=1.0,
         causal_sfm=False,
         demographics_size=4,
         lab_token_penalty=False,
         lab_token_loss_weight=0.9,
+        value_prediction_loss_weight=1.0,
         entropy_penalty=False,
         entropy_penalty_alpha=0.01,
+        sample_packing_max_positions=None,
         **kwargs,
     ):
         if token_to_time_token_mapping is None:
@@ -150,6 +153,11 @@ class CEHRGPTConfig(PretrainedConfig):
         self.vocab_size = vocab_size
         self.time_token_vocab_size = time_token_vocab_size
         self.n_positions = n_positions
+        self.sample_packing_max_positions = (
+            sample_packing_max_positions
+            if sample_packing_max_positions
+            else n_positions
+        )
         self.n_embd = n_embd
         self.n_layer = n_layer
         self.n_head = n_head
@@ -178,6 +186,7 @@ class CEHRGPTConfig(PretrainedConfig):
         self.include_values = include_values
         self.value_vocab_size = value_vocab_size
 
+        self.next_token_prediction_loss_weight = next_token_prediction_loss_weight
         self.include_ttv_prediction = include_ttv_prediction
         self.use_sub_time_tokenization = use_sub_time_tokenization
         self._token_to_time_token_mapping = token_to_time_token_mapping
@@ -195,6 +204,7 @@ class CEHRGPTConfig(PretrainedConfig):
         self.lab_token_loss_weight = lab_token_loss_weight
         self.entropy_penalty = entropy_penalty
         self.entropy_penalty_alpha = entropy_penalty_alpha
+        self.value_prediction_loss_weight = value_prediction_loss_weight
 
         kwargs["tie_word_embeddings"] = not use_pretrained_embeddings
 
```
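A brief example of the new configuration knobs (assuming all other `CEHRGPTConfig` arguments keep their defaults): `sample_packing_max_positions` falls back to `n_positions` when unset, and the new loss weights are exposed as plain attributes for the runners to consume:

```python
from cehrgpt.models.config import CEHRGPTConfig

config = CEHRGPTConfig(
    n_positions=1024,
    next_token_prediction_loss_weight=1.0,
    value_prediction_loss_weight=1.0,
    # sample_packing_max_positions omitted -> falls back to n_positions
)
assert config.sample_packing_max_positions == config.n_positions == 1024

packed = CEHRGPTConfig(n_positions=1024, sample_packing_max_positions=8192)
assert packed.sample_packing_max_positions == 8192
```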