cehrgpt 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +267 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +71 -0
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +61 -0
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +224 -0
- cehrgpt/generation/omop_converter_batch.py +3 -0
- cehrgpt/models/hf_cehrgpt.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +2 -2
- cehrgpt/rl_finetune/__init__.py +0 -0
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +586 -0
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +464 -0
- cehrgpt/rl_finetune/ppo_finetune.py +394 -0
- cehrgpt/rl_finetune/ppo_finetune_v2.py +373 -0
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +119 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -3
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +44 -8
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +4 -0
- cehrgpt/tools/generate_causal_patient_split_by_age.py +146 -0
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.0.2.dist-info}/METADATA +52 -6
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.0.2.dist-info}/RECORD +22 -12
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.0.2.dist-info}/WHEEL +1 -1
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.0.2.dist-info}/LICENSE +0 -0
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.0.2.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,18 @@ import datetime
|
|
2
2
|
from typing import Any, Dict
|
3
3
|
|
4
4
|
import numpy as np
|
5
|
-
|
5
|
+
import pandas as pd
|
6
|
+
from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
|
7
|
+
ED_VISIT_TYPE_CODES,
|
8
|
+
INPATIENT_VISIT_TYPE_CODES,
|
9
|
+
INPATIENT_VISIT_TYPES,
|
10
|
+
DatasetMapping,
|
11
|
+
replace_escape_chars,
|
12
|
+
)
|
13
|
+
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
|
14
|
+
from cehrbert_data.const.common import NA
|
15
|
+
from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
|
16
|
+
from dateutil.relativedelta import relativedelta
|
6
17
|
|
7
18
|
from cehrgpt.models.tokenization_hf_cehrgpt import (
|
8
19
|
NONE_BIN,
|
@@ -17,6 +28,261 @@ def convert_date_to_posix_time(index_date: datetime.date) -> float:
|
|
17
28
|
).timestamp()
|
18
29
|
|
19
30
|
|
31
|
+
class MedToCehrGPTDatasetMapping(DatasetMapping):
    """
    Convert a MEDS-style patient record into the CehrGPT token-sequence format.

    MEDS: https://github.com/Medical-Event-Data-Standard/meds/tree/main

    Fields read from the incoming record by :meth:`transform`:

    - ``patient_id``, ``birth_datetime``, ``gender`` and ``race``;
    - ``visits``: each visit provides ``visit_start_datetime``, ``visit_type`` and
      ``events``, and optionally ``visit_end_datetime`` and ``discharge_facility``;
    - each event provides ``code`` and ``time``, and optionally ``numeric_value``,
      ``text_value`` and ``unit``;
    - optionally ``label`` and ``age_at_index``, which are copied to the output unchanged.

    The emitted token sequence is ``year:<Y>``, ``age:<A>``, gender, race, followed by one
    ``[VS]`` <visit_type> ... ``[VE]`` segment per visit that has events, in chronological
    order. Artificial time tokens are inserted between visits, and ``i-``-prefixed time
    tokens are inserted between days inside ER/inpatient visits.
    """

    def __init__(
        self,
        data_args: DataTrainingArguments,
        is_pretraining: bool = True,
        include_inpatient_hour_token: bool = True,
    ):
        # Maps a gap in days between visits to an artificial time token.
        self._time_token_function = get_att_function(data_args.att_function_type)
        # When set, a discharge-facility token is emitted at the end of ER/inpatient visits.
        self._include_auxiliary_token = data_args.include_auxiliary_token
        # Maps a gap in days between events inside an ER/inpatient visit to a time token.
        self._inpatient_time_token_function = get_att_function(
            data_args.inpatient_att_function_type
        )
        self._include_demographic_prompt = data_args.include_demographic_prompt
        self._is_pretraining = is_pretraining
        self._include_inpatient_hour_token = include_inpatient_hour_token

    def remove_columns(self):
        """Return the input columns to drop once :meth:`transform` has been applied."""
        if self._is_pretraining:
            return ["visits", "birth_datetime", "index_date"]
        else:
            return [
                "visits",
                "birth_datetime",
                "visit_concept_ids",
            ]

    @staticmethod
    def _update_cehrgpt_record(
        cehrgpt_record: Dict[str, Any],
        code: str,
        concept_value_mask: int = 0,
        number_as_value: float = 0.0,
        concept_as_value: str = "0",
        is_numeric_type: int = 0,
        unit: str = NA,
    ) -> None:
        """Append one token and its value/unit metadata to the parallel lists of ``cehrgpt_record``."""
        cehrgpt_record["concept_ids"].append(replace_escape_chars(code))
        cehrgpt_record["concept_value_masks"].append(concept_value_mask)
        cehrgpt_record["number_as_values"].append(number_as_value)
        cehrgpt_record["concept_as_values"].append(concept_as_value)
        cehrgpt_record["units"].append(unit)
        cehrgpt_record["is_numeric_types"].append(is_numeric_type)

    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build a CehrGPT record from ``record``.

        :param record: a patient record with the fields listed in the class docstring.
        :return: a dict with ``person_id``; the parallel token columns ``concept_ids``,
            ``concept_value_masks``, ``number_as_values``, ``concept_as_values``, ``units``
            and ``is_numeric_types``; plus ``orders``, ``num_of_concepts``, ``num_of_visits``
            and, when present in the input, ``label`` and ``age_at_index``.
        :raises IndexError: if ``record["visits"]`` is empty.
        """
        cehrgpt_record = {
            "person_id": record["patient_id"],
            "concept_ids": [],
            "concept_value_masks": [],
            "number_as_values": [],
            "concept_as_values": [],
            "units": [],
            "is_numeric_types": [],
        }
        epoch = datetime.datetime(year=1970, month=1, day=1)

        birth_datetime = record["birth_datetime"]
        if isinstance(birth_datetime, pd.Timestamp):
            birth_datetime = birth_datetime.to_pydatetime()

        # Process visits chronologically. The year/age demographic tokens are taken from the
        # earliest visit so they agree with the timeline below even if the input is unsorted.
        visits = sorted(record["visits"], key=lambda v: v["visit_start_datetime"])
        first_visit_start = visits[0]["visit_start_datetime"]
        year_str = f"year:{first_visit_start.year}"
        age_str = f"age:{relativedelta(first_visit_start, birth_datetime).years}"
        for demographic_token in (year_str, age_str, record["gender"], record["race"]):
            self._update_cehrgpt_record(cehrgpt_record, demographic_token)

        # Tracks the "current" time in the timeline; advanced per visit and, inside
        # ER/inpatient visits, per day of events.
        date_cursor = None

        for visit in visits:
            events = visit["events"]
            # Skip visits without any events
            if events is None or len(events) == 0:
                continue

            visit_start_datetime = visit["visit_start_datetime"]
            time_delta = (
                (visit_start_datetime - date_cursor).days if date_cursor else None
            )
            date_cursor = visit_start_datetime

            visit_type = visit["visit_type"]
            is_er_or_inpatient = (
                visit_type in INPATIENT_VISIT_TYPES
                or visit_type in INPATIENT_VISIT_TYPE_CODES
                or visit_type in ED_VISIT_TYPE_CODES
            )

            # Add an artificial time token between this visit and the previous one
            if time_delta is not None:
                self._update_cehrgpt_record(
                    cehrgpt_record,
                    code=self._time_token_function(time_delta),
                )

            # Week number since the epoch; used as the date component of the outpatient
            # de-duplication key below.
            visit_week = (visit_start_datetime - epoch).days // 7

            # Mark the start of the visit, followed by the visit type token
            self._update_cehrgpt_record(cehrgpt_record, code="[VS]")
            self._update_cehrgpt_record(cehrgpt_record, code=visit_type)

            # (week, code, numeric_value, text_value, mask) keys already emitted for this
            # outpatient visit; repeated events are dropped.
            existing_outpatient_events = set()
            for e in events:
                # Events without a timestamp are skipped
                if not e["time"]:
                    continue

                numeric_value = e.get("numeric_value", None)
                text_value = e.get("text_value", None)
                # The unit may be missing or populated with a falsy value such as None
                unit = e.get("unit") or NA
                # A concept/value tuple is flagged through concept_value_mask
                concept_value_mask = int(
                    numeric_value is not None or text_value is not None
                )
                is_numeric_type = int(numeric_value is not None)
                code = replace_escape_chars(e["code"])

                if is_er_or_inpatient:
                    # Patients can stay for several days, so the event timestamps drive the
                    # in-visit time tokens. A token is emitted only when at least one full
                    # day has elapsed since the cursor.
                    meas_time_diff = (e["time"] - date_cursor).days
                    if meas_time_diff > 0:
                        date_cursor = e["time"]
                        if self._inpatient_time_token_function:
                            self._update_cehrgpt_record(
                                cehrgpt_record,
                                code=f"i-{self._inpatient_time_token_function(meas_time_diff)}",
                            )
                    dedup_key = None
                else:
                    # Outpatient visits are assumed to start and end on the same day, so
                    # identical week/code/value combinations are emitted only once.
                    dedup_key = (
                        visit_week,
                        code,
                        numeric_value,
                        text_value,
                        concept_value_mask,
                    )
                    if dedup_key in existing_outpatient_events:
                        continue

                self._update_cehrgpt_record(
                    cehrgpt_record,
                    code=code,
                    concept_value_mask=concept_value_mask,
                    unit=unit,
                    number_as_value=numeric_value if numeric_value else 0.0,
                    concept_as_value=(
                        replace_escape_chars(text_value) if text_value else "0"
                    ),
                    is_numeric_type=is_numeric_type,
                )
                if dedup_key is not None:
                    existing_outpatient_events.add(dedup_key)

            if is_er_or_inpatient:
                # Move the cursor to the end of the stay when it is known
                visit_end_datetime = visit.get("visit_end_datetime", None)
                if visit_end_datetime:
                    date_cursor = visit_end_datetime

                if self._include_auxiliary_token:
                    discharge_facility = (
                        visit["discharge_facility"]
                        if ("discharge_facility" in visit)
                        and visit["discharge_facility"]
                        else "0"
                    )
                    self._update_cehrgpt_record(
                        cehrgpt_record,
                        code=discharge_facility,
                    )

            # Mark the end of the visit
            self._update_cehrgpt_record(cehrgpt_record, code="[VE]")

        # 1-based positions that the cehrbert dataset mapping function expects
        cehrgpt_record["orders"] = list(
            range(1, len(cehrgpt_record["concept_ids"]) + 1)
        )

        cehrgpt_record["num_of_concepts"] = len(cehrgpt_record["concept_ids"])
        # Counts every input visit, including those skipped above for having no events
        cehrgpt_record["num_of_visits"] = len(record["visits"])

        if "label" in record:
            cehrgpt_record["label"] = record["label"]
        if "age_at_index" in record:
            cehrgpt_record["age_at_index"] = record["age_at_index"]

        return cehrgpt_record
|
284
|
+
|
285
|
+
|
20
286
|
class HFCehrGptTokenizationMapping(DatasetMapping):
|
21
287
|
def __init__(
|
22
288
|
self,
|
@@ -0,0 +1,71 @@
|
|
1
|
+
import torch
|
2
|
+
from torch.nn.utils.rnn import pad_sequence
|
3
|
+
|
4
|
+
from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
|
5
|
+
|
6
|
+
|
7
|
+
class CehrGptDPODataCollator(CehrGptDataCollator):
    """
    Collate DPO preference examples into padded ``chosen_*`` and ``rejected_*`` tensors.

    For each of the ``chosen`` and ``rejected`` prefixes, every example must provide
    ``{prefix}_input_ids``. When ``include_values`` is enabled, it must also provide
    ``{prefix}_value_indicators`` and ``{prefix}_values``, which are the keys produced by
    ``HFCehrGptDPOTokenizationMapping``.
    """

    def create_preference_inputs(self, examples, prefix):
        """
        Pad the ``prefix``-ed fields of ``examples`` into batch tensors.

        Each sequence goes through ``_try_reverse_tensor`` before and after ``pad_sequence``.
        This is presumably how the parent collator supports left padding; confirm in
        ``CehrGptDataCollator``.

        :param examples: list of dict-like examples.
        :param prefix: ``"chosen"`` or ``"rejected"``.
        :return: dict with ``{prefix}_input_ids`` (int64) and ``{prefix}_attention_mask`` (float).
            When ``include_values`` is set, it also holds ``{prefix}_value_indicators`` and
            ``{prefix}_values``.
        """
        input_ids_key = f"{prefix}_input_ids"
        attention_mask_key = f"{prefix}_attention_mask"
        value_indicators_key = f"{prefix}_value_indicators"
        values_key = f"{prefix}_values"

        batch_input_ids = [
            self._try_reverse_tensor(self._convert_to_tensor(example[input_ids_key]))
            for example in examples
        ]
        # Real tokens get an attention weight of 1.0; padding positions are filled with 0.0 below
        batch_attention_mask = [
            self._try_reverse_tensor(
                torch.ones_like(
                    self._convert_to_tensor(example[input_ids_key]),
                    dtype=torch.float,
                )
            )
            for example in examples
        ]

        # Pad sequences to the max length in the batch
        batch = {
            input_ids_key: self._try_reverse_tensor(
                pad_sequence(
                    batch_input_ids,
                    batch_first=True,
                    padding_value=self.tokenizer.pad_token_id,
                ).to(torch.int64)
            ),
            attention_mask_key: self._try_reverse_tensor(
                pad_sequence(batch_attention_mask, batch_first=True, padding_value=0.0)
            ),
        }
        assert batch[input_ids_key].shape[1] <= self.max_length
        assert batch[attention_mask_key].shape[1] <= self.max_length

        if self.include_values:
            batch_value_indicators = [
                self._try_reverse_tensor(
                    self._convert_to_tensor(example[value_indicators_key])
                )
                for example in examples
            ]
            # The key must match the single-underscore "{prefix}_values" column written by
            # the DPO tokenization mapping.
            batch_values = [
                self._try_reverse_tensor(self._convert_to_tensor(example[values_key]))
                for example in examples
            ]

            batch[value_indicators_key] = self._try_reverse_tensor(
                pad_sequence(
                    batch_value_indicators, batch_first=True, padding_value=False
                )
            )
            batch[values_key] = self._try_reverse_tensor(
                pad_sequence(batch_values, batch_first=True, padding_value=-1.0)
            )
            assert batch[value_indicators_key].shape[1] <= self.max_length
            assert batch[values_key].shape[1] <= self.max_length
        return batch

    def __call__(self, examples):
        """Collate ``examples`` into one dict holding both the chosen and rejected tensors."""
        batch = self.create_preference_inputs(examples, "chosen")
        batch.update(self.create_preference_inputs(examples, "rejected"))
        return batch
|
@@ -0,0 +1,61 @@
|
|
1
|
+
import copy
|
2
|
+
from typing import Any, Dict
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import DatasetMapping
|
6
|
+
|
7
|
+
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
|
8
|
+
|
9
|
+
|
10
|
+
class HFCehrGptDPOTokenizationMapping(DatasetMapping):
    """
    Tokenize the ``chosen`` and ``rejected`` halves of a paired (DPO) record.

    For each prefix, ``{prefix}_concept_ids`` is encoded into ``{prefix}_input_ids``. When
    ``{prefix}_concept_value_masks`` is present, the values of lab tokens are normalized with
    the tokenizer and exposed as ``{prefix}_value_indicators`` and ``{prefix}_values``.
    The input record is modified in place and returned.
    """

    def __init__(
        self,
        concept_tokenizer: CehrGptTokenizer,
    ):
        self._concept_tokenizer = concept_tokenizer
        self._lab_token_ids = self._concept_tokenizer.lab_token_ids

    def transform_with_prefix(self, record: Dict[str, Any], prefix) -> Dict[str, Any]:
        """
        Tokenize (and normalize values for) the ``prefix``-ed columns of ``record`` in place.

        :param record: dict with ``{prefix}_concept_ids``. Optionally it also holds
            ``{prefix}_concept_value_masks``, ``{prefix}_concept_values`` and
            ``{prefix}_units``.
        :param prefix: ``"chosen"`` or ``"rejected"``.
        :return: the same ``record`` object.
        """
        concept_ids = record[f"{prefix}_concept_ids"]
        input_ids = self._concept_tokenizer.encode(concept_ids)
        record[f"{prefix}_input_ids"] = input_ids

        masks_key = f"{prefix}_concept_value_masks"
        values_key = f"{prefix}_concept_values"
        if masks_key in record:
            concept_value_masks = record[masks_key]
            concept_values = record[values_key]
            # Normalize only if at least one concept carries a value
            if np.any(np.asarray(concept_value_masks) > 0):
                units = record[f"{prefix}_units"]
                # Set for O(1) membership tests inside the per-token loop
                lab_token_ids = set(self._lab_token_ids)
                normalized_concept_values = copy.deepcopy(concept_values)
                for i, (concept_id, unit, token_id, _, concept_value) in enumerate(
                    zip(
                        concept_ids,
                        units,
                        input_ids,
                        concept_value_masks,
                        concept_values,
                    )
                ):
                    if token_id in lab_token_ids:
                        normalized_concept_values[i] = (
                            self._concept_tokenizer.normalize(
                                concept_id, unit, concept_value
                            )
                        )
                record[values_key] = normalized_concept_values
            # Expose the columns under the names the DPO collator reads
            record[f"{prefix}_value_indicators"] = record[masks_key]
            record[f"{prefix}_values"] = record[values_key]
        return record

    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Tokenize both the ``chosen`` and ``rejected`` columns of ``record`` in place."""
        record = self.transform_with_prefix(record, prefix="chosen")
        return self.transform_with_prefix(record, prefix="rejected")
|
@@ -0,0 +1,224 @@
|
|
1
|
+
import datetime
|
2
|
+
import os
|
3
|
+
import random
|
4
|
+
import uuid
|
5
|
+
|
6
|
+
import pandas as pd
|
7
|
+
import torch
|
8
|
+
from cehrbert.runners.runner_util import load_parquet_as_dataset
|
9
|
+
from transformers.utils import is_flash_attn_2_available, logging
|
10
|
+
|
11
|
+
from cehrgpt.cehrgpt_args import create_inference_base_arg_parser
|
12
|
+
from cehrgpt.generation.generate_batch_hf_gpt_sequence import (
|
13
|
+
generate_single_batch,
|
14
|
+
normalize_value,
|
15
|
+
)
|
16
|
+
from cehrgpt.gpt_utils import get_cehrgpt_output_folder
|
17
|
+
from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
|
18
|
+
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
|
19
|
+
|
20
|
+
LOG = logging.get_logger("transformers")
|
21
|
+
|
22
|
+
|
23
|
+
def main(args):
    """
    Generate paired (chosen, rejected) patient sequences, e.g. for DPO fine-tuning.

    In each batch a real patient sequence is sampled from ``args.sequence_data_path``. Its first
    ``max(4, int(len * f))`` tokens are used as the prompt, where ``f`` is drawn uniformly from
    ``[0, args.cutoff_frac_max]`` once per batch. The real sequence is stored under
    ``chosen_*`` and the model's continuation under ``rejected_*``. Rows are buffered and
    written as parquet files to ``<args.output_folder>/<cehrgpt folder>/generated_sequences``.

    :param args: parsed arguments from :func:`create_arg_parser`.
    :raises ValueError: if the sequence dataset is empty.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_folder)
    cehrgpt_model = (
        CEHRGPT2LMHeadModel.from_pretrained(
            args.model_folder,
            attn_implementation=(
                "flash_attention_2" if is_flash_attn_2_available() else "eager"
            ),
            torch_dtype=(
                torch.bfloat16 if is_flash_attn_2_available() else torch.float32
            ),
        )
        .eval()
        .to(device)
    )
    cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
    cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
    cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id

    folder_name = get_cehrgpt_output_folder(args, cehrgpt_tokenizer)
    output_folder_name = os.path.join(
        args.output_folder, folder_name, "generated_sequences"
    )
    os.makedirs(output_folder_name, exist_ok=True)

    LOG.info(f"Loading tokenizer at {args.tokenizer_folder}")
    LOG.info(f"Loading model at {args.model_folder}")
    LOG.info(f"Write sequences to {output_folder_name}")
    LOG.info(f"Context window {args.context_window}")
    LOG.info(f"Temperature {args.temperature}")
    LOG.info(f"Repetition Penalty {args.repetition_penalty}")
    LOG.info(f"Sampling Strategy {args.sampling_strategy}")
    LOG.info(f"Num beam {args.num_beams}")
    LOG.info(f"Num beam groups {args.num_beam_groups}")
    LOG.info(f"Epsilon cutoff {args.epsilon_cutoff}")
    LOG.info(f"Top P {args.top_p}")
    LOG.info(f"Top K {args.top_k}")
    LOG.info(f"Loading sequence_data_path at {args.sequence_data_path}")

    output_columns = [
        "person_id",
        "chosen_concept_ids",
        "chosen_concept_values",
        "chosen_concept_value_masks",
        "chosen_units",
        "prompt_length",
        "rejected_concept_ids",
        "rejected_concept_values",
        "rejected_concept_value_masks",
        "rejected_units",
    ]
    sequence_to_flush = []

    def flush(file_name: str) -> None:
        """Write the buffered rows to ``file_name`` in the output folder and empty the buffer."""
        pd.DataFrame(sequence_to_flush, columns=output_columns).to_parquet(
            os.path.join(output_folder_name, file_name)
        )
        sequence_to_flush.clear()

    dataset = load_parquet_as_dataset(args.sequence_data_path)
    total_rows = len(dataset)
    if total_rows == 0:
        raise ValueError(f"No sequences found at {args.sequence_data_path}")

    num_of_batches = args.num_of_patients // args.batch_size + 1
    for i in range(num_of_batches):
        LOG.info(f"{datetime.datetime.now()}: Batch {i} started")
        # NOTE(review): only one row is sampled per batch (k=1), so each batch yields at most one
        # pair and roughly num_of_batches pairs are produced in total. Prompts have per-row
        # lengths; confirm generate_single_batch accepts ragged prompts before raising k.
        # Resample until the row's length is within [4, n_positions].
        sample_data = []
        while len(sample_data) == 0:
            random_indices = random.sample(range(total_rows), k=1)
            for row in dataset.select(random_indices):
                if 4 <= len(row["concept_ids"]) <= cehrgpt_model.config.n_positions:
                    sample_data.append(row)

        prompts = []
        chosen_responses = []
        cutoff_frac = random.uniform(0, args.cutoff_frac_max)
        for row in sample_data:
            seq_len = len(row["concept_ids"])
            prompt_len = max(4, int(seq_len * cutoff_frac))
            prompts.append(cehrgpt_tokenizer.encode(row["concept_ids"][:prompt_len]))
            chosen_responses.append(
                {
                    "person_id": row["person_id"],
                    "chosen_concept_ids": (
                        row["concept_ids"] if "concept_ids" in row else None
                    ),
                    "chosen_concept_values": (
                        row["concept_values"] if "concept_values" in row else None
                    ),
                    "chosen_concept_value_masks": (
                        row["concept_value_masks"]
                        if "concept_value_masks" in row
                        else None
                    ),
                    "chosen_units": row["units"] if "units" in row else None,
                    "prompt_length": prompt_len,
                }
            )

        batch_sequences = generate_single_batch(
            cehrgpt_model,
            cehrgpt_tokenizer,
            prompts=prompts,
            max_new_tokens=args.context_window,
            mini_num_of_concepts=args.min_num_of_concepts,
            top_p=args.top_p,
            top_k=args.top_k,
            temperature=args.temperature,
            repetition_penalty=args.repetition_penalty,
            num_beams=args.num_beams,
            num_beam_groups=args.num_beam_groups,
            epsilon_cutoff=args.epsilon_cutoff,
            device=device,
        )

        # Clear the cache
        torch.cuda.empty_cache()

        for seq, value_indicator, value, chosen_response in zip(
            batch_sequences["sequences"],
            batch_sequences["value_indicators"],
            batch_sequences["values"],
            chosen_responses,
        ):
            output = {"rejected_concept_ids": seq}
            normalized_values, units = normalize_value(
                seq, value_indicator, value, cehrgpt_tokenizer
            )
            if normalized_values is not None:
                output["rejected_concept_values"] = normalized_values
            if value_indicator is not None:
                output["rejected_concept_value_masks"] = value_indicator
            if units is not None:
                output["rejected_units"] = units
            output.update(chosen_response)
            sequence_to_flush.append(output)

        if len(sequence_to_flush) >= args.buffer_size:
            LOG.info(f"{datetime.datetime.now()}: Flushing to the Disk at Batch {i}")
            flush(f"{uuid.uuid4()}.parquet")

    if len(sequence_to_flush) > 0:
        LOG.info(f"{datetime.datetime.now()}: Flushing to the Disk at Final Batch")
        flush(f"{uuid.uuid4()}-last.parquet")
|
182
|
+
|
183
|
+
|
184
|
+
def create_arg_parser():
    """
    Build the CLI parser for paired-sequence generation.

    Starts from the shared inference parser and adds the options specific to this script:
    the number of patients, the input sequence path, the maximum prompt fraction, and the
    worker count.
    """
    parser = create_inference_base_arg_parser(
        description="Arguments for generating paired patient sequences"
    )
    parser.add_argument(
        "--num_of_patients",
        dest="num_of_patients",
        type=int,
        action="store",
        required=True,
        help="The number of patients that will be generated",
    )
    parser.add_argument(
        "--sequence_data_path",
        dest="sequence_data_path",
        action="store",
        required=True,
        help="The path for your sequence data",
    )
    parser.add_argument(
        "--cutoff_frac_max",
        dest="cutoff_frac_max",
        type=float,
        action="store",
        required=False,
        default=0.5,
        help="The max fraction of the patient sequences that will be used for prompting",
    )
    parser.add_argument(
        "--num_proc",
        dest="num_proc",
        type=int,
        action="store",
        required=False,
        default=1,
    )
    return parser


if __name__ == "__main__":
    cli_args = create_arg_parser().parse_args()
    main(cli_args)
|
@@ -35,6 +35,7 @@ from cehrgpt.models.tokenization_hf_cehrgpt import END_TOKEN
|
|
35
35
|
# TODO: move these to cehrbert_data
|
36
36
|
STOP_TOKENS = ["VE", "[VE]", END_TOKEN]
|
37
37
|
|
38
|
+
# Out-of-vocabulary marker; gpt_to_omop_converter_batch skips clinical events equal to it.
OOV = "[OOV]"
|
38
39
|
CURRENT_PATH = Path(__file__).parent
|
39
40
|
START_TOKEN_SIZE = 4
|
40
41
|
ATT_TIME_TOKENS = generate_artificial_time_tokens()
|
@@ -297,6 +298,8 @@ def gpt_to_omop_converter_batch(
|
|
297
298
|
inpatient_visit_indicator = False
|
298
299
|
|
299
300
|
for event_idx, event in enumerate(clinical_events, 0):
|
301
|
+
if event == OOV:
|
302
|
+
continue
|
300
303
|
# For bad sequences, we don't proceed further and break from the for loop
|
301
304
|
if bad_sequence:
|
302
305
|
break
|
cehrgpt/models/hf_cehrgpt.py
CHANGED
@@ -1766,6 +1766,7 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
|
|
1766
1766
|
output_attentions: Optional[bool] = None,
|
1767
1767
|
output_hidden_states: Optional[bool] = None,
|
1768
1768
|
return_dict: Optional[bool] = None,
|
1769
|
+
**kwargs,
|
1769
1770
|
) -> CehrGptSequenceClassifierOutput:
|
1770
1771
|
cehrgpt_output = self.cehrgpt(
|
1771
1772
|
input_ids=input_ids,
|
@@ -918,12 +918,12 @@ class CehrGptTokenizer(PreTrainedTokenizer):
|
|
918
918
|
map_statistics_partial = partial(map_statistics, size=SAMPLE_SIZE)
|
919
919
|
|
920
920
|
if data_args.streaming:
|
921
|
+
first_example = next(iter(dataset))
|
921
922
|
parts = dataset.map(
|
922
923
|
partial(agg_helper, map_func=map_statistics_partial),
|
923
924
|
batched=True,
|
924
925
|
batch_size=data_args.preprocessing_batch_size,
|
925
|
-
|
926
|
-
remove_columns=dataset.column_names,
|
926
|
+
remove_columns=first_example.keys(),
|
927
927
|
)
|
928
928
|
else:
|
929
929
|
parts = dataset.map(
|
File without changes
|