cehrgpt 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +267 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +71 -0
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +61 -0
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +224 -0
- cehrgpt/generation/omop_converter_batch.py +3 -0
- cehrgpt/models/hf_cehrgpt.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +2 -2
- cehrgpt/rl_finetune/__init__.py +0 -0
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +586 -0
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +464 -0
- cehrgpt/rl_finetune/ppo_finetune.py +394 -0
- cehrgpt/rl_finetune/ppo_finetune_v2.py +373 -0
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +119 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -3
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +44 -8
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +4 -0
- cehrgpt/tools/generate_causal_patient_split_by_age.py +146 -0
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.0.2.dist-info}/METADATA +52 -6
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.0.2.dist-info}/RECORD +22 -12
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.0.2.dist-info}/WHEEL +1 -1
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.0.2.dist-info}/LICENSE +0 -0
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.0.2.dist-info}/top_level.txt +0 -0
cehrgpt/rl_finetune/ppo_finetune_v2.py (new file)

```diff
@@ -0,0 +1,373 @@
+import datetime
+import os
+import pickle
+from collections import Counter, defaultdict
+from functools import partial
+from typing import Any, Dict, List
+
+import numpy as np
+import torch
+from cehrbert.models.hf_models.tokenization_utils import agg_helper
+from cehrbert.runners.runner_util import load_parquet_as_dataset
+from tqdm import tqdm
+from transformers.utils import is_flash_attn_2_available, logging
+from trl import AutoModelForCausalLMWithValueHead, PPOConfig, create_reference_model
+
+from cehrgpt.cehrgpt_args import create_inference_base_arg_parser
+from cehrgpt.generation.generate_batch_hf_gpt_sequence import generate_single_batch
+from cehrgpt.gpt_utils import get_cehrgpt_output_folder
+from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
+from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+from cehrgpt.rl_finetune.cehrgpt_ppo_trainer import (
+    CehrGptPPODataCollator,
+    CehrGptPPOTrainer,
+)
+
+LOG = logging.get_logger("transformers")
+
+
+def extract_concept_frequency(records: Dict[str, Any]) -> Dict[str, int]:
+    batched_concept_ids = records["concept_ids"]
+    outputs = defaultdict(int)
+    for concept_ids in batched_concept_ids:
+        for concept_id, cnt in dict(Counter(concept_ids[4:])).items():
+            outputs[concept_id] += cnt
+    return outputs
+
+
+def main(args):
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_folder)
+    model_folder_name = os.path.join(
+        args.output_folder, get_cehrgpt_output_folder(args, cehrgpt_tokenizer), "model"
+    )
+
+    if not os.path.exists(model_folder_name):
+        os.makedirs(model_folder_name)
+
+    if args.restore_from_checkpoint:
+        try:
+            cehrgpt_model = CEHRGPT2LMHeadModel.from_pretrained(
+                model_folder_name,
+                attn_implementation=(
+                    "flash_attention_2" if is_flash_attn_2_available() else "eager"
+                ),
+                torch_dtype=(
+                    torch.bfloat16 if is_flash_attn_2_available() else torch.float32
+                ),
+            )
+        except Exception:
+            LOG.warning(
+                "Checkpoint does not exist in %s, loading from the %s",
+                model_folder_name,
+                args.model_folder,
+            )
+            cehrgpt_model = CEHRGPT2LMHeadModel.from_pretrained(
+                args.model_folder,
+                attn_implementation=(
+                    "flash_attention_2" if is_flash_attn_2_available() else "eager"
+                ),
+                torch_dtype=(
+                    torch.bfloat16 if is_flash_attn_2_available() else torch.float32
+                ),
+            )
+    else:
+        cehrgpt_model = CEHRGPT2LMHeadModel.from_pretrained(
+            args.model_folder,
+            attn_implementation=(
+                "flash_attention_2" if is_flash_attn_2_available() else "eager"
+            ),
+            torch_dtype=(
+                torch.bfloat16 if is_flash_attn_2_available() else torch.float32
+            ),
+        )
+
+    cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
+    cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
+    cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id
+    model = AutoModelForCausalLMWithValueHead(cehrgpt_model).to(device)
+    model.is_peft_model = False
+    ref_model = create_reference_model(model).to(device)
+
+    # create a ppo trainer
+    ppo_trainer = CehrGptPPOTrainer(
+        config=PPOConfig(
+            batch_size=args.batch_size,
+            mini_batch_size=args.mini_batch_size,
+            init_kl_coef=args.init_kl_coef,
+            vf_coef=args.vf_coef,
+            kl_penalty=args.kl_penalty,
+            gamma=args.gamma,
+            use_score_scaling=args.use_score_scaling,
+        ),
+        model=model,
+        ref_model=ref_model,
+        tokenizer=cehrgpt_tokenizer,
+        training_data_collator=CehrGptPPODataCollator(
+            cehrgpt_tokenizer, max_length=args.context_window
+        ),
+    )
+
+    LOG.info(f"Loading tokenizer at {args.model_folder}")
+    LOG.info(f"Loading model at {args.model_folder}")
+    LOG.info(f"Will save the fine-tuned model at {model_folder_name}")
+    LOG.info(f"Context window {args.context_window}")
+    LOG.info(f"Temperature {args.temperature}")
+    LOG.info(f"Repetition Penalty {args.repetition_penalty}")
+    LOG.info(f"Sampling Strategy {args.sampling_strategy}")
+    LOG.info(f"Num beam {args.num_beams}")
+    LOG.info(f"Num beam groups {args.num_beam_groups}")
+    LOG.info(f"Epsilon cutoff {args.epsilon_cutoff}")
+    LOG.info(f"Top P {args.top_p}")
+    LOG.info(f"Top K {args.top_k}")
+    LOG.info(f"Loading demographic_info at {args.demographic_data_path}")
+
+    dataset = load_parquet_as_dataset(args.demographic_data_path).filter(
+        lambda batched: [
+            model.config.n_positions >= num_of_concepts > args.min_num_tokens
+            for num_of_concepts in batched["num_of_concepts"]
+        ],
+        batched=True,
+    )
+    parts = dataset.map(
+        partial(agg_helper, map_func=extract_concept_frequency),
+        batched=True,
+        batch_size=1000,
+        num_proc=args.num_proc,
+        remove_columns=dataset.column_names,
+    )
+
+    concept_stats = defaultdict(float)
+    for stat in tqdm(parts, desc="Aggregating the concept counts"):
+        fixed_stat = pickle.loads(stat["data"])
+        for concept_id, count in fixed_stat.items():
+            concept_stats[concept_id] += count
+    total_sum = sum(concept_stats.values())
+    for concept_id, count in concept_stats.items():
+        concept_stats[concept_id] = count / total_sum
+
+    logs = []
+    device = ppo_trainer.current_device
+    total_rows = len(dataset)
+    num_of_micro_batches = args.batch_size // args.mini_batch_size
+    for i in tqdm(range(args.num_of_steps)):
+        LOG.info(f"{datetime.datetime.now()}: Batch {i} started")
+        random_prompts = []
+        batched_sequences = []
+        batched_values = []
+        batched_value_indicators = []
+        for _ in range(num_of_micro_batches):
+            random_indices = np.random.randint(0, total_rows, args.mini_batch_size)
+            random_prompts_micro_batch = [
+                record["concept_ids"][:4] for record in dataset.select(random_indices)
+            ]
+            random_prompts.extend(random_prompts_micro_batch)
+            micro_batched_prompts = [
+                cehrgpt_tokenizer.encode(random_prompt)
+                for random_prompt in random_prompts_micro_batch
+            ]
+
+            micro_batched_sequences = generate_single_batch(
+                cehrgpt_model,
+                cehrgpt_tokenizer,
+                micro_batched_prompts,
+                max_new_tokens=args.context_window,
+                mini_num_of_concepts=args.min_num_of_concepts,
+                top_p=args.top_p,
+                top_k=args.top_k,
+                temperature=args.temperature,
+                repetition_penalty=args.repetition_penalty,
+                num_beams=args.num_beams,
+                num_beam_groups=args.num_beam_groups,
+                epsilon_cutoff=args.epsilon_cutoff,
+                device=device,
+            )
+            # Clear the cache
+            torch.cuda.empty_cache()
+            batched_sequences.extend(micro_batched_sequences["sequences"])
+            batched_values.extend(micro_batched_sequences["values"])
+            batched_value_indicators.extend(micro_batched_sequences["value_indicators"])
+
+        LOG.info(f"{datetime.datetime.now()}: Batch {i} sequence generated")
+        reward = compute_marginal_dist_reward(
+            batched_sequences, concept_stats, cehrgpt_tokenizer
+        )
+        LOG.info(f"{datetime.datetime.now()}: Batch {i} KL divergence reward: {reward}")
+        query_tensors = []
+        response_tensors = []
+        value_tensors = []
+        value_indicator_tensors = []
+        rewards = []
+        for sequence, values, value_indicators in zip(
+            batched_sequences, batched_values, batched_value_indicators
+        ):
+            # Convert sequence to a NumPy array if it's not already one
+            sequence_array = np.asarray(sequence)
+            # Find the end token
+            condition_array = sequence_array == cehrgpt_tokenizer.end_token
+            end_index = (
+                np.argmax(condition_array)
+                if condition_array.any()
+                else len(sequence_array) - 1
+            )
+
+            sequence = sequence[: end_index + 1]
+            values = values[: end_index + 1]
+            value_indicators = value_indicators[: end_index + 1]
+
+            query_tensors.append(torch.tensor(cehrgpt_tokenizer.encode(sequence[:4])))
+            response_tensors.append(
+                torch.tensor(cehrgpt_tokenizer.encode(sequence[4:]))
+            )
+            value_tensors.append(torch.tensor(cehrgpt_tokenizer.encode_value(values)))
+            value_indicator_tensors.append(torch.tensor(value_indicators))
+            rewards.append(reward)
+
+        train_stats = ppo_trainer.step(
+            query_tensors,
+            response_tensors,
+            rewards,
+            value_tensors,
+            value_indicator_tensors,
+        )
+        LOG.info(f"{datetime.datetime.now()}: Batch {i} stats: {train_stats}")
+        logs.append(reward)
+        ppo_trainer.log_stats(stats=train_stats, batch={}, rewards=rewards)
+        ppo_trainer.save_pretrained(model_folder_name)
+        with open(os.path.join(model_folder_name, "ppo_finetune_stats.pkl"), "wb") as f:
+            pickle.dump(logs, f)
+
+
+def compute_marginal_dist_reward(
+    batched_sequences: List[List[str]],
+    expected_concept_dist: Dict[str, float],
+    tokenizer: CehrGptTokenizer,
+) -> torch.Tensor:
+    actual_concept_dist = dict(
+        Counter(
+            [
+                concept_id
+                for sequence in batched_sequences
+                for concept_id in sequence[4:]
+            ]
+        )
+    )
+    total_count = sum(actual_concept_dist.values())
+    for concept_id in actual_concept_dist.keys():
+        actual_concept_dist[concept_id] /= total_count
+    # Translate the concept ids to token ids
+    actual_dist = np.zeros(tokenizer.vocab_size)
+    actual_dist[tokenizer.encode(list(actual_concept_dist.keys()))] = list(
+        actual_concept_dist.values()
+    )
+    # Add a small epsilon to avoid log(0)
+    epsilon = 1e-10
+    logprob_dist = torch.tensor(np.log(actual_dist + epsilon))
+    # Translate the concept ids to token ids
+    ref_dist = np.zeros(tokenizer.vocab_size)
+    ref_dist[tokenizer.encode(list(expected_concept_dist.keys()))] = list(
+        expected_concept_dist.values()
+    )
+    ref_logprob_dist = torch.tensor(np.log(ref_dist + epsilon))
+
+    # Flip is required due to this issue? :https://github.com/pytorch/pytorch/issues/57459
+    return torch.exp(
+        -torch.nn.functional.kl_div(
+            ref_logprob_dist, logprob_dist, log_target=True, reduction="none"
+        ).sum(-1)
+    )
+
+
+def create_arg_parser():
+    base_arg_parser = create_inference_base_arg_parser(
+        description="Arguments for finetuning cehr-gpt using PPO"
+    )
+    base_arg_parser.add_argument(
+        "--mini_batch_size",
+        dest="mini_batch_size",
+        action="store",
+        type=int,
+        required=True,
+    )
+    base_arg_parser.add_argument(
+        "--init_kl_coef",
+        dest="init_kl_coef",
+        action="store",
+        type=float,
+        required=False,
+        default=0.1,
+    )
+    base_arg_parser.add_argument(
+        "--vf_coef",
+        dest="vf_coef",
+        action="store",
+        type=float,
+        required=False,
+        default=0.1,
+    )
+    base_arg_parser.add_argument(
+        "--kl_penalty",
+        dest="kl_penalty",
+        action="store",
+        choices=["kl", "abs", "mse", "full"],
+        required=False,
+        default="kl",
+    )
+    base_arg_parser.add_argument(
+        "--gamma",
+        dest="gamma",
+        action="store",
+        type=float,
+        required=False,
+        default=0.99,
+    )
+    base_arg_parser.add_argument(
+        "--num_proc",
+        dest="num_proc",
+        action="store",
+        type=int,
+        default=4,
+        required=False,
+    )
+    base_arg_parser.add_argument(
+        "--num_of_steps",
+        dest="num_of_steps",
+        action="store",
+        type=int,
+        default=1028,
+        required=False,
+    )
+    base_arg_parser.add_argument(
+        "--min_num_tokens",
+        dest="min_num_tokens",
+        action="store",
+        type=int,
+        default=4,
+        required=False,
+    )
+    base_arg_parser.add_argument(
+        "--demographic_data_path",
+        dest="demographic_data_path",
+        action="store",
+        help="The path for your concept_path",
+        required=True,
+    )
+    base_arg_parser.add_argument(
+        "--restore_from_checkpoint",
+        dest="restore_from_checkpoint",
+        action="store_true",
+    )
+    base_arg_parser.add_argument(
+        "--use_score_scaling",
+        dest="use_score_scaling",
+        action="store_true",
+    )
+    return base_arg_parser


+if __name__ == "__main__":
+    main(create_arg_parser().parse_args())
```
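Note: the reward driving the PPO loop above, `compute_marginal_dist_reward`, scores each generated batch by how closely the marginal concept distribution of the sampled sequences matches the corpus-level distribution estimated before training, returning exp(-KL(actual || expected)) so that a perfect match earns a reward of 1. Below is a minimal, self-contained sketch of that computation; the three-concept vocabulary and the frequencies are invented for illustration, and only the `kl_div` call mirrors the source.

```python
import numpy as np
import torch


def marginal_dist_reward(actual_freq, expected_freq, vocab):
    """exp(-KL(actual || expected)) over a shared concept vocabulary."""
    eps = 1e-10  # avoid log(0) for concepts missing from one distribution
    actual = np.zeros(len(vocab))
    expected = np.zeros(len(vocab))
    for concept, p in actual_freq.items():
        actual[vocab[concept]] = p
    for concept, p in expected_freq.items():
        expected[vocab[concept]] = p
    log_actual = torch.tensor(np.log(actual + eps))
    log_expected = torch.tensor(np.log(expected + eps))
    # kl_div(input, target, log_target=True) computes KL(target || input),
    # which is why the arguments look flipped (see pytorch/pytorch#57459).
    kl = torch.nn.functional.kl_div(
        log_expected, log_actual, log_target=True, reduction="none"
    ).sum(-1)
    return torch.exp(-kl)


vocab = {"diabetes": 0, "hypertension": 1, "statin": 2}  # toy vocabulary
reward = marginal_dist_reward(
    {"diabetes": 0.5, "hypertension": 0.5},  # distribution of a generated batch
    {"diabetes": 0.4, "hypertension": 0.4, "statin": 0.2},  # corpus distribution
    vocab,
)
print(reward)  # ~0.80; approaches 1.0 as the two distributions converge
```

Because every sequence in a batch receives the same scalar reward, the signal acts as a batch-level distributional constraint rather than a per-sequence preference.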
cehrgpt/runners/hf_cehrgpt_dpo_runner.py (new file)

```diff
@@ -0,0 +1,119 @@
+from cehrbert.data_generators.hf_data_generator.hf_dataset import (
+    apply_cehrbert_dataset_mapping,
+)
+from cehrbert.runners.runner_util import (
+    generate_prepared_ds_path,
+    get_last_hf_checkpoint,
+    load_parquet_as_dataset,
+)
+from datasets import DatasetDict, load_from_disk
+from transformers import set_seed
+from transformers.utils import is_flash_attn_2_available, logging
+
+from cehrgpt.data.hf_cehrgpt_dpo_collator import CehrGptDPODataCollator
+from cehrgpt.data.hf_cehrgpt_dpo_dataset_mapping import HFCehrGptDPOTokenizationMapping
+from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
+from cehrgpt.rl_finetune.cehrgpt_dpo_trainer import CehrGptDPOTrainer
+from cehrgpt.runners.gpt_runner_util import parse_dpo_runner_args
+from cehrgpt.runners.hf_cehrgpt_finetune_runner import load_pretrained_tokenizer
+
+LOG = logging.get_logger("transformers")
+
+
+def main():
+    cehrgpt_args, data_args, model_args, dpo_config = parse_dpo_runner_args()
+    tokenizer = load_pretrained_tokenizer(model_args)
+    prepared_ds_path = generate_prepared_ds_path(
+        data_args, model_args, data_folder=data_args.cohort_folder
+    )
+    if any(prepared_ds_path.glob("*")):
+        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
+        processed_dataset = load_from_disk(str(prepared_ds_path))
+        LOG.info("Prepared dataset loaded from disk...")
+    else:
+        dataset = load_parquet_as_dataset(data_args.data_folder)
+        # Random split
+        dataset = dataset.train_test_split(
+            test_size=data_args.validation_split_percentage, seed=dpo_config.seed
+        )
+        processed_dataset = apply_cehrbert_dataset_mapping(
+            dataset,
+            mapping_function=HFCehrGptDPOTokenizationMapping(tokenizer),
+            batch_size=data_args.preprocessing_batch_size,
+            num_proc=data_args.preprocessing_num_workers,
+            streaming=data_args.streaming,
+        )
+
+        processed_dataset = processed_dataset.filter(
+            lambda batch: [
+                len(chosen_concept_ids) < model_args.max_position_embeddings
+                for chosen_concept_ids in batch["chosen_concept_ids"]
+            ],
+            batched=True,
+            batch_size=data_args.preprocessing_batch_size,
+            num_proc=data_args.preprocessing_num_workers,
+        ).filter(
+            lambda batch: [
+                len(rejected_concept_ids) < model_args.max_position_embeddings
+                for rejected_concept_ids in batch["rejected_concept_ids"]
+            ],
+            batched=True,
+            batch_size=data_args.preprocessing_batch_size,
+            num_proc=data_args.preprocessing_num_workers,
+        )
+        processed_dataset.save_to_disk(prepared_ds_path)
+
+    # Set seed before initializing model.
+    set_seed(dpo_config.seed)
+    processed_dataset.set_format("pt")
+
+    # A hacky way to prevent the training from removing unmatched inputs
+    dpo_config.label_names = [
+        "chosen_input_ids",
+        "rejected_input_ids",
+        "chosen_concept_values",
+        "rejected_concept_values",
+        "chosen_concept_value_masks",
+        "rejected_concept_value_masks",
+    ]
+
+    attn_implementation = (
+        "flash_attention_2" if is_flash_attn_2_available() else "eager"
+    )
+    model = CEHRGPT2LMHeadModel.from_pretrained(
+        model_args.model_name_or_path,
+        attn_implementation=attn_implementation,
+    )
+    ref_model = CEHRGPT2LMHeadModel.from_pretrained(
+        model_args.model_name_or_path,
+        attn_implementation=attn_implementation,
+    )
+
+    # Initialize Trainer for final training on the combined train+val set
+    trainer = CehrGptDPOTrainer(
+        model=model,
+        ref_model=ref_model,
+        args=dpo_config,
+        tokenizer=tokenizer,
+        train_dataset=processed_dataset["train"],
+        eval_dataset=processed_dataset["test"],
+        data_collator=CehrGptDPODataCollator(
+            tokenizer=tokenizer,
+            max_length=model_args.max_position_embeddings,
+            pretraining=False,
+            include_ttv_prediction=False,
+            use_sub_time_tokenization=False,
+        ),
+    )
+    # Train the model on the combined train + val set
+    checkpoint = get_last_hf_checkpoint(dpo_config)
+    train_result = trainer.train(resume_from_checkpoint=checkpoint)
+    trainer.save_model()  # Saves the tokenizer too for easy upload
+    metrics = train_result.metrics
+    trainer.log_metrics("train", metrics)
+    trainer.save_metrics("train", metrics)
+    trainer.save_state()
+
+
+if __name__ == "__main__":
+    main()
```
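Note: before training, the runner above keeps a preference pair only when both its chosen and rejected sequences fit inside the model's context window. A toy sketch of that filtering pattern with Hugging Face `datasets` follows (fabricated data; the runner expresses it as two chained `.filter` calls, which is equivalent to the single combined predicate used here).

```python
from datasets import Dataset

max_position_embeddings = 8  # stands in for model_args.max_position_embeddings

pairs = Dataset.from_dict(
    {
        "chosen_concept_ids": [[1, 2, 3], list(range(10))],
        "rejected_concept_ids": [[4, 5], [6, 7]],
    }
)

# Keep a pair only if BOTH sequences fit inside the context window.
filtered = pairs.filter(
    lambda batch: [
        len(chosen) < max_position_embeddings
        and len(rejected) < max_position_embeddings
        for chosen, rejected in zip(
            batch["chosen_concept_ids"], batch["rejected_concept_ids"]
        )
    ],
    batched=True,
)
print(filtered.num_rows)  # 1 -- the second pair's chosen sequence is too long
```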
cehrgpt/runners/hf_cehrgpt_finetune_runner.py

```diff
@@ -43,6 +43,7 @@ from transformers.utils import is_flash_attn_2_available, logging
 
 from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_finetuning_dataset
 from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
+from cehrgpt.data.hf_cehrgpt_dataset_mapping import MedToCehrGPTDatasetMapping
 from cehrgpt.models.hf_cehrgpt import (
     CEHRGPTConfig,
     CehrGptForClassification,
@@ -408,10 +409,24 @@ def main():
         except Exception as e:
             LOG.exception(e)
             dataset = create_dataset_from_meds_reader(
-                data_args,
+                data_args=data_args,
+                dataset_mappings=[
+                    MedToCehrGPTDatasetMapping(
+                        data_args=data_args,
+                        is_pretraining=False,
+                        include_inpatient_hour_token=cehrgpt_args.include_inpatient_hour_token,
+                    )
+                ],
             )
             if not data_args.streaming:
-                dataset.save_to_disk(meds_extension_path)
+                dataset.save_to_disk(str(meds_extension_path))
+                stats = dataset.cleanup_cache_files()
+                LOG.info(
+                    "Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
+                    stats,
+                )
+                dataset = load_from_disk(str(meds_extension_path))
+
         train_set = dataset["train"]
         validation_set = dataset["validation"]
         test_set = dataset["test"]
@@ -451,7 +466,13 @@ def main():
         dataset=final_splits, cehrgpt_tokenizer=tokenizer, data_args=data_args
     )
     if not data_args.streaming:
-        processed_dataset.save_to_disk(prepared_ds_path)
+        processed_dataset.save_to_disk(str(prepared_ds_path))
+        stats = processed_dataset.cleanup_cache_files()
+        LOG.info(
+            "Clean up the cached files for the cehrgpt finetuning dataset : %s",
+            stats,
+        )
+        processed_dataset = load_from_disk(str(prepared_ds_path))
 
     # Set seed before initializing model.
     set_seed(training_args.seed)
```
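Note: both save sites in this file now follow the same three-step pattern: materialize the dataset with `save_to_disk`, delete the intermediate Arrow cache files with `cleanup_cache_files`, then `load_from_disk` so downstream code reads the materialized copy rather than the deleted caches. A minimal sketch of the pattern (toy data; a temporary directory stands in for the real paths):

```python
import tempfile

from datasets import Dataset, load_from_disk

dataset = Dataset.from_dict({"concept_ids": [[1, 2, 3], [4, 5]]})
dataset = dataset.map(lambda row: {"num_of_concepts": len(row["concept_ids"])})

with tempfile.TemporaryDirectory() as prepared_ds_path:
    # 1. Materialize the processed dataset on disk.
    dataset.save_to_disk(str(prepared_ds_path))
    # 2. Delete intermediate Arrow cache files left behind by map() -- a no-op
    #    for this tiny in-memory example, but substantial for large corpora.
    num_removed = dataset.cleanup_cache_files()
    # 3. Reload the saved copy so nothing references the deleted cache files.
    dataset = load_from_disk(str(prepared_ds_path))
    print(num_removed, dataset.num_rows)
```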
cehrgpt/runners/hf_cehrgpt_pretrain_runner.py

```diff
@@ -21,12 +21,13 @@ from transformers.utils import is_flash_attn_2_available, logging
 
 from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_pretraining_dataset
 from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
+from cehrgpt.data.hf_cehrgpt_dataset_mapping import MedToCehrGPTDatasetMapping
 from cehrgpt.models.config import CEHRGPTConfig
 from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
 from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
 from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
 from cehrgpt.runners.gpt_runner_util import parse_runner_args
-from
+from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
 
 LOG = logging.get_logger("transformers")
 
@@ -82,11 +83,25 @@ def load_and_create_model(
     model_abspath = os.path.expanduser(model_args.model_name_or_path)
     if cehrgpt_args.continue_pretrain:
         try:
-
+            pretrained_model = CEHRGPT2LMHeadModel.from_pretrained(
                 model_abspath,
                 attn_implementation=attn_implementation,
                 torch_dtype=torch_dtype,
             )
+            if (
+                pretrained_model.config.max_position_embeddings
+                < model_args.max_position_embeddings
+            ):
+                LOG.info(
+                    f"Increase model.config.max_position_embeddings to {model_args.max_position_embeddings}"
+                )
+                pretrained_model.config.max_position_embeddings = (
+                    model_args.max_position_embeddings
+                )
+                pretrained_model.resize_position_embeddings(
+                    model_args.max_position_embeddings
+                )
+            return pretrained_model
         except Exception as e:
             LOG.error(
                 f"When continue_pretrain is set to True, it assumes that CEHR-GPT has been trained "
@@ -94,7 +109,7 @@ def load_and_create_model(
             )
             raise e
     try:
-        model_config =
+        model_config = CEHRGPTConfig.from_pretrained(
             model_abspath, attn_implementation=attn_implementation
         )
     except Exception as e:
@@ -148,7 +163,7 @@ def main():
         # The iterable dataset doesn't have sharding implemented, so the number of works has to be set to 0
         # Otherwise the trainer will throw an error
         training_args.dataloader_num_workers = 0
-        training_args.dataloader_prefetch_factor =
+        training_args.dataloader_prefetch_factor = None
 
     prepared_ds_path = generate_prepared_ds_path(data_args, model_args)
     if os.path.exists(os.path.join(data_args.data_folder, "dataset_dict.json")):
@@ -212,14 +227,29 @@ def main():
         except FileNotFoundError as e:
             LOG.exception(e)
             dataset = create_dataset_from_meds_reader(
-                data_args,
+                data_args=data_args,
+                dataset_mappings=[
+                    MedToCehrGPTDatasetMapping(
+                        data_args=data_args,
+                        is_pretraining=True,
+                        include_inpatient_hour_token=cehrgpt_args.include_inpatient_hour_token,
+                    )
+                ],
             )
             if not data_args.streaming:
-                dataset.save_to_disk(meds_extension_path)
+                dataset.save_to_disk(str(meds_extension_path))
+                stats = dataset.cleanup_cache_files()
+                LOG.info(
+                    "Clean up the cached files for the cehrgpt dataset transformed from the MEDS: %s",
+                    stats,
+                )
+                dataset = load_from_disk(str(meds_extension_path))
     else:
         # Load the dataset from the parquet files
         dataset = load_parquet_as_dataset(
-            data_args.data_folder,
+            os.path.expanduser(data_args.data_folder),
+            split="train",
+            streaming=data_args.streaming,
         )
         # If streaming is enabled, we need to manually split the data into train/val
         if data_args.streaming and data_args.validation_split_num:
@@ -274,7 +304,13 @@ def main():
     )
     # only save the data to the disk if it is not streaming
    if not data_args.streaming:
-        processed_dataset.save_to_disk(prepared_ds_path)
+        processed_dataset.save_to_disk(str(prepared_ds_path))
+        stats = processed_dataset.cleanup_cache_files()
+        LOG.info(
+            "Clean up the cached files for the cehrgpt pretraining dataset: %s",
+            stats,
+        )
+        processed_dataset = load_from_disk(str(prepared_ds_path))
 
     def filter_func(examples):
         if cehrgpt_args.drop_long_sequences:
```
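Note: the `continue_pretrain` branch above grows the model when the requested context exceeds the checkpoint's `max_position_embeddings`: it bumps the config and calls `resize_position_embeddings`. That method is model-specific and not shown in this diff; the sketch below is a hypothetical, generic version of such a resize for a GPT-2-style learned position-embedding table, not the CEHR-GPT implementation.

```python
import torch
import torch.nn as nn


def resize_position_embeddings(wpe: nn.Embedding, new_num_positions: int) -> nn.Embedding:
    """Grow (or truncate) a learned position-embedding table, keeping trained rows."""
    old_num_positions, dim = wpe.weight.shape
    resized = nn.Embedding(new_num_positions, dim)
    n = min(old_num_positions, new_num_positions)
    with torch.no_grad():
        # Copy the trained embeddings; any extra rows keep their fresh random
        # initialization and are learned during continued pretraining.
        resized.weight[:n] = wpe.weight[:n]
    return resized


wpe = nn.Embedding(1024, 768)  # e.g. a GPT-2 sized table
wpe = resize_position_embeddings(wpe, 2048)
print(wpe.weight.shape)  # torch.Size([2048, 768])
```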
cehrgpt/runners/hf_gpt_runner_argument_dataclass.py

```diff
@@ -6,6 +6,10 @@ from typing import List, Optional
 class CehrGPTArguments:
     """Arguments pertaining to what data we are going to input our model for training and eval."""
 
+    include_inpatient_hour_token: Optional[bool] = dataclasses.field(
+        default=True,
+        metadata={"help": "Include inpatient hour token"},
+    )
     include_demographics: Optional[bool] = dataclasses.field(
         default=False,
         metadata={
```