cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
- cehrgpt/data/sample_packing_sampler.py +181 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +35 -0
- cehrgpt/models/hf_cehrgpt.py +470 -106
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
- cehrgpt/runners/data_utils.py +358 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
- cehrgpt/runners/hyperparameter_search_util.py +10 -8
- cehrgpt/runners/sample_packing_trainer.py +185 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
@@ -1,373 +0,0 @@
|
|
1
|
-
import datetime
|
2
|
-
import os
|
3
|
-
import pickle
|
4
|
-
from collections import Counter, defaultdict
|
5
|
-
from functools import partial
|
6
|
-
from typing import Any, Dict, List
|
7
|
-
|
8
|
-
import numpy as np
|
9
|
-
import torch
|
10
|
-
from cehrbert.models.hf_models.tokenization_utils import agg_helper
|
11
|
-
from cehrbert.runners.runner_util import load_parquet_as_dataset
|
12
|
-
from tqdm import tqdm
|
13
|
-
from transformers.utils import is_flash_attn_2_available, logging
|
14
|
-
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, create_reference_model
|
15
|
-
|
16
|
-
from cehrgpt.cehrgpt_args import create_inference_base_arg_parser
|
17
|
-
from cehrgpt.generation.generate_batch_hf_gpt_sequence import generate_single_batch
|
18
|
-
from cehrgpt.gpt_utils import get_cehrgpt_output_folder
|
19
|
-
from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
|
20
|
-
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
|
21
|
-
from cehrgpt.rl_finetune.cehrgpt_ppo_trainer import (
|
22
|
-
CehrGptPPODataCollator,
|
23
|
-
CehrGptPPOTrainer,
|
24
|
-
)
|
25
|
-
|
26
|
-
# Module-level logger; reuses the logger named "transformers" obtained via transformers.utils.logging.
LOG = logging.get_logger("transformers")
|
27
|
-
|
28
|
-
|
29
|
-
def extract_concept_frequency(records: Dict[str, Any]) -> Dict[str, int]:
    """Count concept occurrences across a batch of patient sequences.

    The first four concept ids of every sequence are skipped. ``main`` uses
    exactly those four tokens as the generation prompt, so only the remainder
    of each sequence contributes to the counts.

    Args:
        records: A batch mapping whose ``"concept_ids"`` entry is a list of
            concept-id sequences.

    Returns:
        A ``defaultdict(int)`` mapping each concept id to its total count.
    """
    frequency: Dict[str, int] = defaultdict(int)
    for sequence in records["concept_ids"]:
        for token, occurrences in Counter(sequence[4:]).items():
            frequency[token] += occurrences
    return frequency
|
36
|
-
|
37
|
-
|
38
|
-
def _load_cehrgpt_model(model_path: str) -> "CEHRGPT2LMHeadModel":
    """Load a CEHRGPT2LMHeadModel from ``model_path``.

    Uses flash attention 2 with bfloat16 weights when flash attention is
    available, and eager attention with float32 weights otherwise.

    Args:
        model_path: Location passed straight to ``from_pretrained``.

    Returns:
        The loaded model.
    """
    use_flash_attention = is_flash_attn_2_available()
    return CEHRGPT2LMHeadModel.from_pretrained(
        model_path,
        attn_implementation="flash_attention_2" if use_flash_attention else "eager",
        torch_dtype=torch.bfloat16 if use_flash_attention else torch.float32,
    )


def main(args):
    """Fine-tune a CEHR-GPT model with PPO toward the real concept distribution.

    Steps:
      1. Load the tokenizer and the policy model. The model is optionally
         restored from this run's output folder. It is wrapped with a value
         head, and a reference copy is created for the KL penalty.
      2. Compute the marginal concept distribution of the real patient
         sequences in ``args.demographic_data_path``, excluding the first four
         (prompt) tokens.
      3. For ``args.num_of_steps`` iterations:
         - sample prompts from real patients;
         - generate ``batch_size // mini_batch_size`` micro batches of
           sequences;
         - score the whole batch with ``compute_marginal_dist_reward``;
         - take one PPO step.

    After every step the trainer is saved to ``model_folder_name``, and the
    list of per-step rewards is pickled to ``ppo_finetune_stats.pkl`` there.

    Args:
        args: Parsed command-line arguments (see ``create_arg_parser``).

    Raises:
        ValueError: If ``batch_size`` is smaller than ``mini_batch_size``, or
            if no record survives the length filter.
    """
    num_of_micro_batches = args.batch_size // args.mini_batch_size
    if num_of_micro_batches == 0:
        # Fail fast instead of running a PPO step on an empty batch.
        raise ValueError(
            f"batch_size ({args.batch_size}) must be at least "
            f"mini_batch_size ({args.mini_batch_size})"
        )

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_folder)
    model_folder_name = os.path.join(
        args.output_folder, get_cehrgpt_output_folder(args, cehrgpt_tokenizer), "model"
    )
    os.makedirs(model_folder_name, exist_ok=True)

    # Track where the weights actually came from so the log below is accurate.
    model_source = args.model_folder
    if args.restore_from_checkpoint:
        try:
            cehrgpt_model = _load_cehrgpt_model(model_folder_name)
            model_source = model_folder_name
        except Exception:
            # Best effort: fall back to the base model when no usable checkpoint exists.
            LOG.warning(
                "Checkpoint does not exist in %s, loading from the %s",
                model_folder_name,
                args.model_folder,
            )
            cehrgpt_model = _load_cehrgpt_model(args.model_folder)
    else:
        cehrgpt_model = _load_cehrgpt_model(args.model_folder)

    generation_config = cehrgpt_model.generation_config
    generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
    generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
    # NOTE(review): bos_token_id is set to the END token id, unchanged from the
    # original script. Confirm this is intended rather than a start-token id.
    generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id

    # PPO needs a value head on the policy plus a reference model for the KL penalty.
    model = AutoModelForCausalLMWithValueHead(cehrgpt_model).to(device)
    model.is_peft_model = False
    ref_model = create_reference_model(model).to(device)

    ppo_trainer = CehrGptPPOTrainer(
        config=PPOConfig(
            batch_size=args.batch_size,
            mini_batch_size=args.mini_batch_size,
            init_kl_coef=args.init_kl_coef,
            vf_coef=args.vf_coef,
            kl_penalty=args.kl_penalty,
            gamma=args.gamma,
            use_score_scaling=args.use_score_scaling,
        ),
        model=model,
        ref_model=ref_model,
        tokenizer=cehrgpt_tokenizer,
        training_data_collator=CehrGptPPODataCollator(
            cehrgpt_tokenizer, max_length=args.context_window
        ),
    )

    LOG.info(f"Loading tokenizer at {args.tokenizer_folder}")
    LOG.info(f"Loading model at {model_source}")
    LOG.info(f"Will save the fine-tuned model at {model_folder_name}")
    LOG.info(f"Context window {args.context_window}")
    LOG.info(f"Temperature {args.temperature}")
    LOG.info(f"Repetition Penalty {args.repetition_penalty}")
    LOG.info(f"Sampling Strategy {args.sampling_strategy}")
    LOG.info(f"Num beam {args.num_beams}")
    LOG.info(f"Num beam groups {args.num_beam_groups}")
    LOG.info(f"Epsilon cutoff {args.epsilon_cutoff}")
    LOG.info(f"Top P {args.top_p}")
    LOG.info(f"Top K {args.top_k}")
    LOG.info(f"Loading demographic_info at {args.demographic_data_path}")

    # Keep sequences that fit in the model context and are longer than min_num_tokens.
    dataset = load_parquet_as_dataset(args.demographic_data_path).filter(
        lambda batched: [
            model.config.n_positions >= num_of_concepts > args.min_num_tokens
            for num_of_concepts in batched["num_of_concepts"]
        ],
        batched=True,
    )
    total_rows = len(dataset)
    if total_rows == 0:
        # np.random.randint(0, 0, ...) would otherwise fail with an opaque error.
        raise ValueError(
            f"No records in {args.demographic_data_path} have more than "
            f"{args.min_num_tokens} and at most {model.config.n_positions} concepts"
        )

    # Empirical marginal concept distribution of the real data (prompt tokens excluded).
    parts = dataset.map(
        partial(agg_helper, map_func=extract_concept_frequency),
        batched=True,
        batch_size=1000,
        num_proc=args.num_proc,
        remove_columns=dataset.column_names,
    )
    concept_stats = defaultdict(float)
    for stat in tqdm(parts, desc="Aggregating the concept counts"):
        # stat["data"] is presumably pickled by agg_helper in this same pipeline (trusted).
        for concept_id, count in pickle.loads(stat["data"]).items():
            concept_stats[concept_id] += count
    total_sum = sum(concept_stats.values())
    for concept_id, count in concept_stats.items():
        concept_stats[concept_id] = count / total_sum

    logs = []
    device = ppo_trainer.current_device
    for i in tqdm(range(args.num_of_steps)):
        LOG.info(f"{datetime.datetime.now()}: Batch {i} started")
        batched_sequences = []
        batched_values = []
        batched_value_indicators = []
        for _ in range(num_of_micro_batches):
            random_indices = np.random.randint(0, total_rows, args.mini_batch_size)
            # The first four concept ids of a real patient form the generation prompt.
            random_prompts_micro_batch = [
                record["concept_ids"][:4] for record in dataset.select(random_indices)
            ]
            micro_batched_prompts = [
                cehrgpt_tokenizer.encode(random_prompt)
                for random_prompt in random_prompts_micro_batch
            ]

            micro_batched_sequences = generate_single_batch(
                cehrgpt_model,
                cehrgpt_tokenizer,
                micro_batched_prompts,
                max_new_tokens=args.context_window,
                mini_num_of_concepts=args.min_num_of_concepts,
                top_p=args.top_p,
                top_k=args.top_k,
                temperature=args.temperature,
                repetition_penalty=args.repetition_penalty,
                num_beams=args.num_beams,
                num_beam_groups=args.num_beam_groups,
                epsilon_cutoff=args.epsilon_cutoff,
                device=device,
            )
            # Clear the cache
            torch.cuda.empty_cache()
            batched_sequences.extend(micro_batched_sequences["sequences"])
            batched_values.extend(micro_batched_sequences["values"])
            batched_value_indicators.extend(micro_batched_sequences["value_indicators"])

        LOG.info(f"{datetime.datetime.now()}: Batch {i} sequence generated")
        # One scalar reward for the whole batch; it is shared by every sequence below.
        reward = compute_marginal_dist_reward(
            batched_sequences, concept_stats, cehrgpt_tokenizer
        )
        LOG.info(f"{datetime.datetime.now()}: Batch {i} KL divergence reward: {reward}")
        query_tensors = []
        response_tensors = []
        value_tensors = []
        value_indicator_tensors = []
        rewards = []
        for sequence, values, value_indicators in zip(
            batched_sequences, batched_values, batched_value_indicators
        ):
            # Truncate at the first end token (inclusive), or keep everything if absent.
            sequence_array = np.asarray(sequence)
            condition_array = sequence_array == cehrgpt_tokenizer.end_token
            end_index = (
                np.argmax(condition_array)
                if condition_array.any()
                else len(sequence_array) - 1
            )

            sequence = sequence[: end_index + 1]
            values = values[: end_index + 1]
            value_indicators = value_indicators[: end_index + 1]

            # Query = 4-token prompt, response = generated continuation.
            query_tensors.append(torch.tensor(cehrgpt_tokenizer.encode(sequence[:4])))
            response_tensors.append(
                torch.tensor(cehrgpt_tokenizer.encode(sequence[4:]))
            )
            value_tensors.append(torch.tensor(cehrgpt_tokenizer.encode_value(values)))
            value_indicator_tensors.append(torch.tensor(value_indicators))
            rewards.append(reward)

        train_stats = ppo_trainer.step(
            query_tensors,
            response_tensors,
            rewards,
            value_tensors,
            value_indicator_tensors,
        )
        LOG.info(f"{datetime.datetime.now()}: Batch {i} stats: {train_stats}")
        logs.append(reward)
        ppo_trainer.log_stats(stats=train_stats, batch={}, rewards=rewards)
        # Persist after every step so --restore_from_checkpoint can resume a run.
        ppo_trainer.save_pretrained(model_folder_name)
        with open(os.path.join(model_folder_name, "ppo_finetune_stats.pkl"), "wb") as f:
            pickle.dump(logs, f)
|
243
|
-
|
244
|
-
|
245
|
-
def _concept_dist_to_token_dist(
    concept_dist: Dict[str, float], tokenizer: "CehrGptTokenizer"
) -> np.ndarray:
    """Project a concept-id -> probability mapping onto a dense token-id vector.

    Probabilities of concepts that the tokenizer encodes to the same token id
    are summed. This happens, e.g., when unknown concepts share one OOV id.
    Plain fancy-index assignment would keep only the last value and lose mass.

    Args:
        concept_dist: Mapping from concept id to probability.
        tokenizer: Provides ``vocab_size`` and ``encode(list_of_concepts)``.

    Returns:
        A float64 array of length ``tokenizer.vocab_size``.
    """
    token_dist = np.zeros(tokenizer.vocab_size)
    if concept_dist:
        token_ids = np.asarray(
            tokenizer.encode(list(concept_dist.keys())), dtype=np.int64
        )
        probabilities = np.asarray(list(concept_dist.values()), dtype=np.float64)
        # Unbuffered accumulation: repeated token ids add up instead of overwriting.
        np.add.at(token_dist, token_ids, probabilities)
    return token_dist


def compute_marginal_dist_reward(
    batched_sequences: List[List[str]],
    expected_concept_dist: Dict[str, float],
    tokenizer: "CehrGptTokenizer",
) -> torch.Tensor:
    """Score a batch of generated sequences by how well their concepts match a target.

    The marginal concept distribution of the batch (the first four prompt
    tokens of each sequence are ignored) is compared with
    ``expected_concept_dist`` over the tokenizer vocabulary. The reward is
    ``exp(-KL(actual || expected))``, a value in (0, 1] that equals 1 when
    the distributions match.

    NOTE(review): if the batch generated no concepts at all, the actual
    distribution is all-epsilon, so KL is ~0 and the reward is ~1.

    Args:
        batched_sequences: Generated concept-id sequences, each including its
            4-token prompt.
        expected_concept_dist: Reference concept id -> probability mapping.
        tokenizer: Maps concept ids to token ids; provides ``vocab_size``.

    Returns:
        A 0-dim float64 tensor holding the reward.
    """
    concept_counts = Counter(
        concept_id for sequence in batched_sequences for concept_id in sequence[4:]
    )
    total_count = sum(concept_counts.values())
    actual_concept_dist = {
        concept_id: count / total_count for concept_id, count in concept_counts.items()
    }

    # Add a small epsilon so tokens absent from a distribution do not yield log(0).
    epsilon = 1e-10
    logprob_dist = torch.tensor(
        np.log(_concept_dist_to_token_dist(actual_concept_dist, tokenizer) + epsilon)
    )
    ref_logprob_dist = torch.tensor(
        np.log(_concept_dist_to_token_dist(expected_concept_dist, tokenizer) + epsilon)
    )

    # F.kl_div(input, target, log_target=True) computes
    # sum(exp(target) * (target - input)). With input=expected and
    # target=actual, this is KL(actual || expected).
    kl_divergence = torch.nn.functional.kl_div(
        ref_logprob_dist, logprob_dist, log_target=True, reduction="none"
    ).sum(-1)
    return torch.exp(-kl_divergence)
|
283
|
-
|
284
|
-
|
285
|
-
def create_arg_parser():
    """Build the command-line parser for PPO fine-tuning of CEHR-GPT.

    Starts from ``create_inference_base_arg_parser``, which presumably supplies
    the generation options ``main`` reads (temperature, top_p, top_k, ...).
    It then adds:

    - PPO hyper-parameters forwarded to ``PPOConfig``;
    - the data-loading options;
    - checkpoint/score-scaling flags.

    Returns:
        The configured ``argparse`` parser.
    """
    base_arg_parser = create_inference_base_arg_parser(
        description="Arguments for finetuning cehr-gpt using PPO"
    )
    base_arg_parser.add_argument(
        "--mini_batch_size",
        dest="mini_batch_size",
        action="store",
        type=int,
        required=True,
    )
    base_arg_parser.add_argument(
        "--init_kl_coef",
        dest="init_kl_coef",
        action="store",
        type=float,
        required=False,
        default=0.1,
    )
    base_arg_parser.add_argument(
        "--vf_coef",
        dest="vf_coef",
        action="store",
        type=float,
        required=False,
        default=0.1,
    )
    base_arg_parser.add_argument(
        "--kl_penalty",
        dest="kl_penalty",
        action="store",
        choices=["kl", "abs", "mse", "full"],
        required=False,
        default="kl",
    )
    base_arg_parser.add_argument(
        "--gamma",
        dest="gamma",
        action="store",
        type=float,
        required=False,
        default=0.99,
    )
    base_arg_parser.add_argument(
        "--num_proc",
        dest="num_proc",
        action="store",
        type=int,
        default=4,
        required=False,
    )
    base_arg_parser.add_argument(
        "--num_of_steps",
        dest="num_of_steps",
        action="store",
        type=int,
        default=1028,
        required=False,
    )
    base_arg_parser.add_argument(
        "--min_num_tokens",
        dest="min_num_tokens",
        action="store",
        type=int,
        default=4,
        required=False,
    )
    base_arg_parser.add_argument(
        "--demographic_data_path",
        dest="demographic_data_path",
        action="store",
        help=(
            "Path to the parquet patient sequences used to sample generation "
            "prompts and to compute the reference concept distribution"
        ),
        required=True,
    )
    base_arg_parser.add_argument(
        "--restore_from_checkpoint",
        dest="restore_from_checkpoint",
        action="store_true",
    )
    base_arg_parser.add_argument(
        "--use_score_scaling",
        dest="use_score_scaling",
        action="store_true",
    )
    return base_arg_parser
|
370
|
-
|
371
|
-
|
372
|
-
if __name__ == "__main__":
    # Script entry point: parse the CLI and run PPO fine-tuning.
    cli_args = create_arg_parser().parse_args()
    main(cli_args)
|
@@ -1,119 +0,0 @@
|
|
1
|
-
from cehrbert.data_generators.hf_data_generator.hf_dataset import (
|
2
|
-
apply_cehrbert_dataset_mapping,
|
3
|
-
)
|
4
|
-
from cehrbert.runners.runner_util import (
|
5
|
-
generate_prepared_ds_path,
|
6
|
-
get_last_hf_checkpoint,
|
7
|
-
load_parquet_as_dataset,
|
8
|
-
)
|
9
|
-
from datasets import DatasetDict, load_from_disk
|
10
|
-
from transformers import set_seed
|
11
|
-
from transformers.utils import is_flash_attn_2_available, logging
|
12
|
-
|
13
|
-
from cehrgpt.data.hf_cehrgpt_dpo_collator import CehrGptDPODataCollator
|
14
|
-
from cehrgpt.data.hf_cehrgpt_dpo_dataset_mapping import HFCehrGptDPOTokenizationMapping
|
15
|
-
from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
|
16
|
-
from cehrgpt.rl_finetune.cehrgpt_dpo_trainer import CehrGptDPOTrainer
|
17
|
-
from cehrgpt.runners.gpt_runner_util import parse_dpo_runner_args
|
18
|
-
from cehrgpt.runners.hf_cehrgpt_finetune_runner import load_pretrained_tokenizer
|
19
|
-
|
20
|
-
# Module-level logger; reuses the logger named "transformers" obtained via transformers.utils.logging.
LOG = logging.get_logger("transformers")
|
21
|
-
|
22
|
-
|
23
|
-
def main():
    """Fine-tune a CEHR-GPT model with Direct Preference Optimization (DPO).

    Steps:
      1. Reuse the prepared dataset cached at ``prepared_ds_path`` if present.
         Otherwise, load the parquet data, randomly split it into train/test,
         tokenize the chosen/rejected pairs, drop pairs where either sequence
         is not shorter than ``max_position_embeddings``, and cache the result.
      2. Load the policy and reference models from
         ``model_args.model_name_or_path``.
      3. Train ``CehrGptDPOTrainer`` on the ``train`` split, evaluating on
         ``test``, then save the model, metrics and trainer state.
    """
    cehrgpt_args, data_args, model_args, dpo_config = parse_dpo_runner_args()
    tokenizer = load_pretrained_tokenizer(model_args)
    prepared_ds_path = generate_prepared_ds_path(
        data_args, model_args, data_folder=data_args.cohort_folder
    )
    if any(prepared_ds_path.glob("*")):
        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
        processed_dataset = load_from_disk(str(prepared_ds_path))
        LOG.info("Prepared dataset loaded from disk...")
    else:
        dataset = load_parquet_as_dataset(data_args.data_folder)
        # Random split
        dataset = dataset.train_test_split(
            test_size=data_args.validation_split_percentage, seed=dpo_config.seed
        )
        processed_dataset = apply_cehrbert_dataset_mapping(
            dataset,
            mapping_function=HFCehrGptDPOTokenizationMapping(tokenizer),
            batch_size=data_args.preprocessing_batch_size,
            num_proc=data_args.preprocessing_num_workers,
            streaming=data_args.streaming,
        )

        # Drop pairs where either sequence does not fit in the model context.
        # A single pass checks both sides; this is equivalent to the former two
        # chained filters but reads the dataset once.
        max_length = model_args.max_position_embeddings
        processed_dataset = processed_dataset.filter(
            lambda batch: [
                len(chosen_concept_ids) < max_length
                and len(rejected_concept_ids) < max_length
                for chosen_concept_ids, rejected_concept_ids in zip(
                    batch["chosen_concept_ids"], batch["rejected_concept_ids"]
                )
            ],
            batched=True,
            batch_size=data_args.preprocessing_batch_size,
            num_proc=data_args.preprocessing_num_workers,
        )
        processed_dataset.save_to_disk(prepared_ds_path)

    # Set seed before initializing model.
    set_seed(dpo_config.seed)
    processed_dataset.set_format("pt")

    # Workaround: listing these columns as label_names stops the Trainer from
    # dropping them as inputs that do not match the model's forward signature.
    dpo_config.label_names = [
        "chosen_input_ids",
        "rejected_input_ids",
        "chosen_concept_values",
        "rejected_concept_values",
        "chosen_concept_value_masks",
        "rejected_concept_value_masks",
    ]

    attn_implementation = (
        "flash_attention_2" if is_flash_attn_2_available() else "eager"
    )
    model = CEHRGPT2LMHeadModel.from_pretrained(
        model_args.model_name_or_path,
        attn_implementation=attn_implementation,
    )
    # The reference model starts from the same weights as the policy.
    ref_model = CEHRGPT2LMHeadModel.from_pretrained(
        model_args.model_name_or_path,
        attn_implementation=attn_implementation,
    )

    # Train on the "train" split and evaluate on the held-out "test" split.
    trainer = CehrGptDPOTrainer(
        model=model,
        ref_model=ref_model,
        args=dpo_config,
        tokenizer=tokenizer,
        train_dataset=processed_dataset["train"],
        eval_dataset=processed_dataset["test"],
        data_collator=CehrGptDPODataCollator(
            tokenizer=tokenizer,
            max_length=model_args.max_position_embeddings,
            pretraining=False,
            include_ttv_prediction=False,
            use_sub_time_tokenization=False,
        ),
    )
    # Resume from the last checkpoint in the output directory, if any.
    checkpoint = get_last_hf_checkpoint(dpo_config)
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    trainer.save_model()  # Saves the tokenizer too for easy upload
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
|
116
|
-
|
117
|
-
|
118
|
-
if __name__ == "__main__":
    # Script entry point; main() parses its own arguments via parse_dpo_runner_args().
    main()
|
File without changes
|
File without changes
|
File without changes
|