cehrgpt 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +99 -88
- cehrgpt/data/sample_packing_sampler.py +151 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/models/config.py +10 -0
- cehrgpt/models/hf_cehrgpt.py +243 -73
- cehrgpt/models/tokenization_hf_cehrgpt.py +4 -0
- cehrgpt/runners/data_utils.py +243 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +152 -279
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +229 -105
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +42 -0
- cehrgpt/runners/hyperparameter_search_util.py +4 -1
- cehrgpt/runners/sample_packing_trainer.py +168 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +7 -5
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +28 -26
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
@@ -1,394 +0,0 @@
-import datetime
-import os
-import pickle
-import random
-from collections import Counter, defaultdict
-from functools import partial
-from typing import Any, Dict, List, Tuple
-
-import numpy as np
-import torch
-from cehrbert.models.hf_models.tokenization_utils import agg_helper
-from cehrbert.runners.runner_util import load_parquet_as_dataset
-from tqdm import tqdm
-from transformers import GenerationConfig
-from transformers.utils import is_flash_attn_2_available, logging
-from trl import (
-    AutoModelForCausalLMWithValueHead,
-    PPOConfig,
-    PPOTrainer,
-    create_reference_model,
-)
-
-from cehrgpt.cehrgpt_args import create_inference_base_arg_parser
-from cehrgpt.gpt_utils import get_cehrgpt_output_folder
-from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
-from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
-
-LOG = logging.get_logger("transformers")
-
-
-def extract_demographics_info(
-    records: Dict[str, Any]
-) -> Dict[Tuple[str, str, str, str], Dict[str, int]]:
-    batched_concept_ids = records["concept_ids"]
-    outputs = defaultdict(dict)
-    for concept_ids in batched_concept_ids:
-        start_year, start_age, gender, race = concept_ids[:4]
-        existing_stats = outputs[(start_year, start_age, gender, race)]
-        for concept_id, cnt in dict(Counter(concept_ids[4:])).items():
-            if concept_id in existing_stats:
-                existing_stats[concept_id] += cnt
-            else:
-                existing_stats[concept_id] = cnt
-        if "total" in existing_stats:
-            existing_stats["total"] += 1
-        else:
-            existing_stats["total"] = 1
-    return outputs
-
-
-def generate_single_batch(
-    model,
-    tokenizer,
-    batched_prompts,
-    max_new_tokens=512,
-    mini_num_of_concepts=1,
-    top_p=0.95,
-    top_k=50,
-    temperature=1.0,
-    repetition_penalty=1.0,
-    num_beams=1,
-    num_beam_groups=1,
-    epsilon_cutoff=0.0,
-) -> List[List[str]]:
-    with torch.no_grad():
-        generation_config = GenerationConfig(
-            repetition_penalty=repetition_penalty,
-            max_length=max_new_tokens,
-            min_length=mini_num_of_concepts,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            bos_token_id=tokenizer.end_token_id,
-            eos_token_id=tokenizer.end_token_id,
-            pad_token_id=tokenizer.pad_token_id,
-            do_sample=True,
-            use_cache=True,
-            return_dict_in_generate=True,
-            output_attentions=False,
-            output_hidden_states=False,
-            output_scores=False,
-            renormalize_logits=True,
-            num_beams=num_beams,
-            num_beam_groups=num_beam_groups,
-            epsilon_cutoff=epsilon_cutoff,
-        )
-        results = model.generate(
-            inputs=batched_prompts, generation_config=generation_config
-        )
-
-    return [tokenizer.decode(seq.cpu().numpy()) for seq in results.sequences]
-
-
-def main(args):
-    if torch.cuda.is_available():
-        device = torch.device("cuda")
-    else:
-        device = torch.device("cpu")
-
-    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_folder)
-    model_folder_name = os.path.join(
-        args.output_folder, get_cehrgpt_output_folder(args, cehrgpt_tokenizer), "model"
-    )
-
-    if not os.path.exists(model_folder_name):
-        os.makedirs(model_folder_name)
-
-    if args.restore_from_checkpoint:
-        try:
-            cehrgpt_model = CEHRGPT2LMHeadModel.from_pretrained(
-                model_folder_name,
-                attn_implementation=(
-                    "flash_attention_2" if is_flash_attn_2_available() else "eager"
-                ),
-                torch_dtype=(
-                    torch.bfloat16 if is_flash_attn_2_available() else torch.float32
-                ),
-            )
-        except Exception:
-            LOG.warning(
-                "Checkpoint does not exist in %s, loading from the %s",
-                model_folder_name,
-                args.model_folder,
-            )
-            cehrgpt_model = CEHRGPT2LMHeadModel.from_pretrained(
-                args.model_folder,
-                attn_implementation=(
-                    "flash_attention_2" if is_flash_attn_2_available() else "eager"
-                ),
-                torch_dtype=(
-                    torch.bfloat16 if is_flash_attn_2_available() else torch.float32
-                ),
-            )
-    else:
-        cehrgpt_model = CEHRGPT2LMHeadModel.from_pretrained(
-            args.model_folder,
-            attn_implementation=(
-                "flash_attention_2" if is_flash_attn_2_available() else "eager"
-            ),
-            torch_dtype=(
-                torch.bfloat16 if is_flash_attn_2_available() else torch.float32
-            ),
-        )
-
-    cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
-    cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
-    cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id
-    model = AutoModelForCausalLMWithValueHead(cehrgpt_model).to(device)
-    model.is_peft_model = False
-    ref_model = create_reference_model(model).to(device)
-
-    # create a ppo trainer
-    ppo_trainer = PPOTrainer(
-        config=PPOConfig(
-            batch_size=args.batch_size,
-            mini_batch_size=args.mini_batch_size,
-            init_kl_coef=args.init_kl_coef,
-            vf_coef=args.vf_coef,
-            kl_penalty=args.kl_penalty,
-            gamma=args.gamma,
-        ),
-        model=model,
-        ref_model=ref_model,
-        tokenizer=cehrgpt_tokenizer,
-    )
-
-    LOG.info(f"Loading tokenizer at {args.model_folder}")
-    LOG.info(f"Loading model at {args.model_folder}")
-    LOG.info(f"Will save the fine-tuned model at {model_folder_name}")
-    LOG.info(f"Context window {args.context_window}")
-    LOG.info(f"Temperature {args.temperature}")
-    LOG.info(f"Repetition Penalty {args.repetition_penalty}")
-    LOG.info(f"Sampling Strategy {args.sampling_strategy}")
-    LOG.info(f"Num beam {args.num_beams}")
-    LOG.info(f"Num beam groups {args.num_beam_groups}")
-    LOG.info(f"Epsilon cutoff {args.epsilon_cutoff}")
-    LOG.info(f"Top P {args.top_p}")
-    LOG.info(f"Top K {args.top_k}")
-    LOG.info(f"Loading demographic_info at {args.demographic_data_path}")
-
-    dataset = load_parquet_as_dataset(args.demographic_data_path)
-    parts = dataset.filter(
-        lambda batched: [
-            num_of_concepts > 4 for num_of_concepts in batched["num_of_concepts"]
-        ],
-        batched=True,
-    ).map(
-        partial(agg_helper, map_func=extract_demographics_info),
-        batched=True,
-        batch_size=1000,
-        num_proc=args.num_proc,
-        remove_columns=dataset.column_names,
-    )
-    prompts_and_concept_stats = defaultdict(dict)
-    for stat in tqdm(parts, desc="Aggregating the concept counts"):
-        fixed_stat = pickle.loads(stat["data"])
-        for prompt, concept_stats in fixed_stat.items():
-            for concept_id, count in concept_stats.items():
-                if concept_id not in prompts_and_concept_stats[prompt]:
-                    prompts_and_concept_stats[prompt][concept_id] = count
-                else:
-                    prompts_and_concept_stats[prompt][concept_id] += count
-
-    prompt_weights = defaultdict(int)
-    for prompt, concept_stats in prompts_and_concept_stats.items():
-        prompt_weight = concept_stats.pop("total")
-        prompt_weights[prompt] = prompt_weight
-        total_count = sum(concept_stats.values())
-        for concept_id in concept_stats.keys():
-            concept_stats[concept_id] = concept_stats[concept_id] / total_count
-
-    logs = []
-    prompts = list(prompt_weights.keys())
-    weight_sum = sum(prompt_weights.values())
-    prompt_weights = np.asarray(list(prompt_weights.values())) / weight_sum
-    device = ppo_trainer.current_device
-    num_of_micro_batches = args.batch_size // args.mini_batch_size
-    for i in tqdm(range(args.num_of_steps)):
-        LOG.info(f"{datetime.datetime.now()}: Batch {i} started")
-        random_prompt = random.choices(prompts, weights=prompt_weights, k=1)[0]
-        prompt_weight = prompt_weights[prompts.index(random_prompt)]
-        LOG.info(
-            f"%s: Batch %s random_prompt: %s with weight %.2f%% (%d / %s)",
-            datetime.datetime.now(),
-            i,
-            random_prompt,
-            prompt_weight * 100,
-            int(prompt_weights[prompts.index(random_prompt)] * weight_sum),
-            weight_sum,
-        )
-        expected_concept_dist = prompts_and_concept_stats[random_prompt]
-        batched_sequences = []
-        for _ in range(num_of_micro_batches):
-            batched_prompts = torch.tensor(
-                [
-                    cehrgpt_tokenizer.encode(random_prompt)
-                    for _ in range(args.mini_batch_size)
-                ]
-            ).to(device)
-            mini_batched_sequences = generate_single_batch(
-                cehrgpt_model,
-                cehrgpt_tokenizer,
-                batched_prompts,
-                max_new_tokens=args.context_window,
-                mini_num_of_concepts=args.min_num_of_concepts,
-                top_p=args.top_p,
-                top_k=args.top_k,
-                temperature=args.temperature,
-                repetition_penalty=args.repetition_penalty,
-                num_beams=args.num_beams,
-                num_beam_groups=args.num_beam_groups,
-                epsilon_cutoff=args.epsilon_cutoff,
-            )
-            # Clear the cache
-            torch.cuda.empty_cache()
-            batched_sequences.extend(mini_batched_sequences)
-
-        LOG.info(f"{datetime.datetime.now()}: Batch {i} sequence generated")
-        reward = compute_marginal_dist_reward(
-            batched_sequences, expected_concept_dist, cehrgpt_tokenizer
-        )
-        LOG.info(f"{datetime.datetime.now()}: Batch {i} KL divergence reward: {reward}")
-        query_tensors = []
-        response_tensors = []
-        rewards = []
-        for sequence in batched_sequences:
-            query_tensors.append(torch.tensor(cehrgpt_tokenizer.encode(sequence[:4])))
-            response_tensors.append(
-                torch.tensor(cehrgpt_tokenizer.encode(sequence[4:]))
-            )
-            rewards.append(reward)
-        train_stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
-        LOG.info(f"{datetime.datetime.now()}: Batch {i} stats: {train_stats}")
-        logs.append(reward)
-        ppo_trainer.log_stats(stats=train_stats, batch={}, rewards=rewards)
-    ppo_trainer.save_pretrained(model_folder_name)
-    with open(os.path.join(model_folder_name, "ppo_finetune_stats.pkl"), "wb") as f:
-        pickle.dump(logs, f)
-
-
-def compute_marginal_dist_reward(
-    batched_sequences: List[List[str]],
-    expected_concept_dist: Dict[str, float],
-    tokenizer: CehrGptTokenizer,
-) -> torch.Tensor:
-    actual_concept_dist = dict(
-        Counter(
-            [
-                concept_id
-                for sequence in batched_sequences
-                for concept_id in sequence[4:]
-            ]
-        )
-    )
-    total_count = sum(actual_concept_dist.values())
-    for concept_id in actual_concept_dist.keys():
-        actual_concept_dist[concept_id] /= total_count
-    # Translate the concept ids to token ids
-    actual_dist = np.zeros(tokenizer.vocab_size)
-    actual_dist[tokenizer.encode(list(actual_concept_dist.keys()))] = list(
-        actual_concept_dist.values()
-    )
-    # Add a small epsilon to avoid log(0)
-    epsilon = 1e-10
-    logprob_dist = torch.tensor(np.log(actual_dist + epsilon))
-    # Translate the concept ids to token ids
-    ref_dist = np.zeros(tokenizer.vocab_size)
-    ref_dist[tokenizer.encode(list(expected_concept_dist.keys()))] = list(
-        expected_concept_dist.values()
-    )
-    ref_logprob_dist = torch.tensor(np.log(ref_dist + epsilon))
-
-    # Flip is required due to this issue: https://github.com/pytorch/pytorch/issues/57459
-    return -torch.nn.functional.kl_div(
-        ref_logprob_dist, logprob_dist, log_target=True, reduction="none"
-    ).sum(-1)
-
-
-def create_arg_parser():
-    base_arg_parser = create_inference_base_arg_parser(
-        description="Arguments for finetuning cehr-gpt using PPO"
-    )
-    base_arg_parser.add_argument(
-        "--mini_batch_size",
-        dest="mini_batch_size",
-        action="store",
-        type=int,
-        required=True,
-    )
-    base_arg_parser.add_argument(
-        "--init_kl_coef",
-        dest="init_kl_coef",
-        action="store",
-        type=float,
-        required=False,
-        default=0.1,
-    )
-    base_arg_parser.add_argument(
-        "--vf_coef",
-        dest="vf_coef",
-        action="store",
-        type=float,
-        required=False,
-        default=0.1,
-    )
-    base_arg_parser.add_argument(
-        "--kl_penalty",
-        dest="kl_penalty",
-        action="store",
-        choices=["kl", "abs", "mse", "full"],
-        required=False,
-        default="kl",
-    )
-    base_arg_parser.add_argument(
-        "--gamma",
-        dest="gamma",
-        action="store",
-        type=float,
-        required=False,
-        default=0.99,
-    )
-    base_arg_parser.add_argument(
-        "--num_proc",
-        dest="num_proc",
-        action="store",
-        type=int,
-        default=4,
-        required=False,
-    )
-    base_arg_parser.add_argument(
-        "--num_of_steps",
-        dest="num_of_steps",
-        action="store",
-        type=int,
-        default=1028,
-        required=False,
-    )
-    base_arg_parser.add_argument(
-        "--demographic_data_path",
-        dest="demographic_data_path",
-        action="store",
-        help="The path for your concept_path",
-        required=True,
-    )
-    base_arg_parser.add_argument(
-        "--restore_from_checkpoint",
-        dest="restore_from_checkpoint",
-        action="store_true",
-    )
-    return base_arg_parser
-
-
-if __name__ == "__main__":
-    main(create_arg_parser().parse_args())