cehrgpt 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. __init__.py +0 -0
  2. cehrgpt/__init__.py +0 -0
  3. cehrgpt/analysis/__init__.py +0 -0
  4. cehrgpt/analysis/privacy/__init__.py +0 -0
  5. cehrgpt/analysis/privacy/attribute_inference.py +275 -0
  6. cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
  7. cehrgpt/analysis/privacy/member_inference.py +172 -0
  8. cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
  9. cehrgpt/analysis/privacy/reid_inference.py +407 -0
  10. cehrgpt/analysis/privacy/utils.py +255 -0
  11. cehrgpt/cehrgpt_args.py +142 -0
  12. cehrgpt/data/__init__.py +0 -0
  13. cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
  14. cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
  15. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
  16. cehrgpt/generation/__init__.py +0 -0
  17. cehrgpt/generation/chatgpt_generation.py +106 -0
  18. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
  19. cehrgpt/generation/omop_converter_batch.py +644 -0
  20. cehrgpt/generation/omop_entity.py +515 -0
  21. cehrgpt/gpt_utils.py +331 -0
  22. cehrgpt/models/__init__.py +0 -0
  23. cehrgpt/models/config.py +205 -0
  24. cehrgpt/models/hf_cehrgpt.py +1817 -0
  25. cehrgpt/models/hf_modeling_outputs.py +158 -0
  26. cehrgpt/models/pretrained_embeddings.py +82 -0
  27. cehrgpt/models/special_tokens.py +30 -0
  28. cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
  29. cehrgpt/omop/__init__.py +0 -0
  30. cehrgpt/omop/condition_era.py +20 -0
  31. cehrgpt/omop/observation_period.py +43 -0
  32. cehrgpt/omop/omop_argparse.py +38 -0
  33. cehrgpt/omop/omop_table_builder.py +86 -0
  34. cehrgpt/omop/queries/__init__.py +0 -0
  35. cehrgpt/omop/queries/condition_era.py +86 -0
  36. cehrgpt/omop/queries/observation_period.py +135 -0
  37. cehrgpt/omop/sample_omop_tables.py +71 -0
  38. cehrgpt/runners/__init__.py +0 -0
  39. cehrgpt/runners/gpt_runner_util.py +99 -0
  40. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
  41. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
  42. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
  43. cehrgpt/runners/hyperparameter_search_util.py +223 -0
  44. cehrgpt/time_to_event/__init__.py +0 -0
  45. cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
  46. cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
  47. cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
  48. cehrgpt/time_to_event/time_to_event_model.py +226 -0
  49. cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
  50. cehrgpt/time_to_event/time_to_event_utils.py +55 -0
  51. cehrgpt/tools/__init__.py +0 -0
  52. cehrgpt/tools/ehrshot_benchmark.py +74 -0
  53. cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
  54. cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
  55. cehrgpt/tools/upload_omop_tables.py +108 -0
  56. cehrgpt-0.0.1.dist-info/LICENSE +21 -0
  57. cehrgpt-0.0.1.dist-info/METADATA +66 -0
  58. cehrgpt-0.0.1.dist-info/RECORD +60 -0
  59. cehrgpt-0.0.1.dist-info/WHEEL +5 -0
  60. cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/time_to_event/time_to_event_prediction.py
@@ -0,0 +1,347 @@
+ import datetime
+ import glob
+ import os
+ import shutil
+ import uuid
+ from dataclasses import asdict, dataclass
+ from typing import Any, Dict, List
+
+ import pandas as pd
+ import torch
+ import yaml
+ from cehrbert.runners.runner_util import load_parquet_as_dataset
+ from datasets import Dataset
+ from tqdm import tqdm
+ from transformers.utils import is_flash_attn_2_available, logging
+
+ from cehrgpt.cehrgpt_args import create_inference_base_arg_parser
+ from cehrgpt.gpt_utils import get_cehrgpt_output_folder, is_visit_end, is_visit_start
+ from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
+ from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+ from cehrgpt.time_to_event.time_to_event_model import TimeToEventModel
+
+ LOG = logging.get_logger("transformers")
+
+
+ @dataclass
+ class TaskConfig:
+     task_name: str
+     outcome_events: List[str]
+     include_descendants: bool = False
+     future_visit_start: int = 0
+     future_visit_end: int = -1
+     prediction_window_start: int = 0
+     prediction_window_end: int = 365
+     max_new_tokens: int = 128
+
+
+ def load_task_config_from_yaml(task_config_yaml_file_path: str) -> TaskConfig:
+     # Read the YAML file
+     try:
+         with open(task_config_yaml_file_path, "r") as stream:
+             task_definition = yaml.safe_load(stream)
+             return TaskConfig(**task_definition)
+     except (yaml.YAMLError, OSError) as e:
+         raise ValueError(
+             f"Could not open the task_config yaml file from {task_config_yaml_file_path}"
+         ) from e
+
+
+ def get_device():
+     return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+
+ def main(args):
+     cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_folder)
+     cehrgpt_model = (
+         CEHRGPT2LMHeadModel.from_pretrained(
+             args.model_folder,
+             attn_implementation=(
+                 "flash_attention_2" if is_flash_attn_2_available() else "eager"
+             ),
+             torch_dtype=(
+                 torch.bfloat16
+                 if is_flash_attn_2_available() and args.use_bfloat16
+                 else "auto"
+             ),
+         )
+         .eval()
+         .to(get_device())
+     )
+     cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
+     cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
+     cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id
+
+     folder_name = get_cehrgpt_output_folder(args, cehrgpt_tokenizer)
+
+     task_config = load_task_config_from_yaml(args.task_config)
+     task_name = task_config.task_name
+     outcome_events = task_config.outcome_events
+
+     if task_config.include_descendants:
+         if not args.concept_ancestor:
+             raise RuntimeError(
+                 "When include_descendants is set to True, the concept_ancestor data needs to be provided."
+             )
+         concept_ancestor = pd.read_parquet(args.concept_ancestor)
+         ancestor_concept_ids = [int(_) for _ in outcome_events if _.isnumeric()]
+         descendant_concept_ids = (
+             concept_ancestor[
+                 concept_ancestor.ancestor_concept_id.isin(ancestor_concept_ids)
+             ]
+             .descendant_concept_id.unique()
+             .astype(str)
+             .tolist()
+         )
+         descendant_concept_ids = [
+             _ for _ in descendant_concept_ids if _ not in outcome_events
+         ]
+         outcome_events += descendant_concept_ids
+
+     prediction_output_folder_name = os.path.join(
+         args.output_folder, folder_name, task_name
+     )
+     temp_folder = os.path.join(args.output_folder, folder_name, "temp")
+     os.makedirs(prediction_output_folder_name, exist_ok=True)
+     os.makedirs(temp_folder, exist_ok=True)
+
+     LOG.info(f"Loading tokenizer at {args.tokenizer_folder}")
+     LOG.info(f"Loading model at {args.model_folder}")
+     LOG.info(f"Loading dataset_folder at {args.dataset_folder}")
+     LOG.info(f"Writing time sensitive predictions to {prediction_output_folder_name}")
+     LOG.info(f"Context window {args.context_window}")
+     LOG.info(f"Number of new tokens {task_config.max_new_tokens}")
+     LOG.info(f"Temperature {args.temperature}")
+     LOG.info(f"Repetition Penalty {args.repetition_penalty}")
+     LOG.info(f"Sampling Strategy {args.sampling_strategy}")
+     LOG.info(f"Epsilon cutoff {args.epsilon_cutoff}")
+     LOG.info(f"Top P {args.top_p}")
+     LOG.info(f"Top K {args.top_k}")
+
+     cehrgpt_model.resize_position_embeddings(
+         cehrgpt_model.config.max_position_embeddings + task_config.max_new_tokens
+     )
+
+     generation_config = TimeToEventModel.get_generation_config(
+         tokenizer=cehrgpt_tokenizer,
+         max_length=cehrgpt_model.config.n_positions,
+         num_return_sequences=args.num_return_sequences,
+         top_p=args.top_p,
+         top_k=args.top_k,
+         temperature=args.temperature,
+         repetition_penalty=args.repetition_penalty,
+         epsilon_cutoff=args.epsilon_cutoff,
+         max_new_tokens=task_config.max_new_tokens,
+     )
+     ts_pred_model = TimeToEventModel(
+         tokenizer=cehrgpt_tokenizer,
+         model=cehrgpt_model,
+         outcome_events=outcome_events,
+         generation_config=generation_config,
+         batch_size=args.batch_size,
+         device=get_device(),
+     )
+     dataset = load_parquet_as_dataset(args.dataset_folder)
+
+     def filter_func(examples):
+         return [_ >= args.min_num_of_concepts for _ in examples["num_of_concepts"]]
+
+     test_dataset = dataset.filter(filter_func, batched=True, batch_size=1000)
+     test_dataset = test_dataset.shuffle(seed=42)
+
+     # Filter out the records for which predictions have been generated previously
+     test_dataset = filter_out_existing_results(
+         test_dataset, prediction_output_folder_name
+     )
+     tte_outputs = []
+     for record in tqdm(test_dataset, total=len(test_dataset)):
+         sample_identifier = (
+             f"{record['person_id']}_{record['index_date'].strftime('%Y_%m_%d')}"
+         )
+         if acquire_lock_or_skip_if_already_exist(
+             output_folder=temp_folder, sample_id=sample_identifier
+         ):
+             continue
+         partial_history = record["concept_ids"]
+         label = record["label"]
+         time_to_event = record["time_to_event"] if "time_to_event" in record else None
+         seq_length = len(partial_history)
+         if (
+             generation_config.max_length
+             <= seq_length + generation_config.max_new_tokens
+         ):
+             start_index = seq_length - (
+                 generation_config.max_length - generation_config.max_new_tokens
+             )
+             # Make sure the truncated history starts on a visit-start (VS) token
+             for i, token in enumerate(partial_history[start_index:]):
+                 if is_visit_start(token):
+                     start_index += i
+                     break
+             partial_history = partial_history[start_index:]
+
+         concept_time_to_event = ts_pred_model.predict_time_to_events(
+             partial_history,
+             task_config.future_visit_start,
+             task_config.future_visit_end,
+             task_config.prediction_window_start,
+             task_config.prediction_window_end,
+             args.debug,
+             args.max_n_trial,
+         )
+         visit_counter = sum([int(is_visit_end(_)) for _ in partial_history])
+         tte_outputs.append(
+             {
+                 "person_id": record["person_id"],
+                 "index_date": record["index_date"],
+                 "visit_counter": visit_counter,
+                 "label": label,
+                 "time_to_event": time_to_event,
+                 "prediction": (
+                     asdict(concept_time_to_event) if concept_time_to_event else None
+                 ),
+             }
+         )
+         delete_lock_create_processed_flag(
+             output_folder=temp_folder, sample_id=sample_identifier
+         )
+         flush_to_disk_if_full(
+             tte_outputs, prediction_output_folder_name, args.buffer_size
+         )
+
+     # Final flush: force out any remaining records with a buffer size of 1
+     flush_to_disk_if_full(tte_outputs, prediction_output_folder_name, 1)
+     # Remove the temp folder
+     shutil.rmtree(temp_folder)
+
+
+ def delete_lock_create_processed_flag(output_folder: str, sample_id: str):
+     processed_flag_file = os.path.join(output_folder, f"{sample_id}.done")
+     # Mark this example as processed by creating an empty flag file
+     try:
+         # Using 'x' mode for exclusive creation; fails if the file already exists
+         with open(processed_flag_file, "x"):
+             pass  # The file is created; nothing is written to it
+     except FileExistsError as e:
+         raise FileExistsError(
+             f"The processed flag file {processed_flag_file} already exists."
+         ) from e
+
+     lock_file = os.path.join(output_folder, f"{sample_id}.lock")
+     # Clean up the lock file
+     # Safely attempt to delete the lock file
+     try:
+         os.remove(lock_file)
+     except OSError as e:
+         raise OSError(f"Cannot remove the lock file at {lock_file}") from e
+
+
+ def acquire_lock_or_skip_if_already_exist(output_folder: str, sample_id: str):
+     lock_file = os.path.join(output_folder, f"{sample_id}.lock")
+     if os.path.exists(lock_file):
+         LOG.info("Another process acquired the lock --> %s. Skipping...", sample_id)
+         return True
+     processed_flag_file = os.path.join(output_folder, f"{sample_id}.done")
+     if os.path.exists(processed_flag_file):
+         LOG.info("The sample has been processed --> %s. Skipping...", sample_id)
+         return True
+
+     # Obtain the lock for this example by creating an empty lock file
+     try:
+         # Using 'x' mode for exclusive creation; fails if the file already exists
+         with open(lock_file, "x"):
+             pass  # The file is created; nothing is written to it
+     except FileExistsError:
+         LOG.info("Another process acquired the lock --> %s. Skipping...", sample_id)
+         return True
+     return False
+
+
+ def filter_out_existing_results(
+     test_dataset: Dataset, prediction_output_folder_name: str
+ ):
+     parquet_files = glob.glob(os.path.join(prediction_output_folder_name, "*parquet"))
+     if parquet_files:
+         cohort_members = set()
+         results_dataframe = pd.read_parquet(parquet_files)[["person_id", "index_date"]]
+         for row in results_dataframe.itertuples():
+             cohort_members.add((row.person_id, row.index_date.strftime("%Y-%m-%d")))
+
+         def filter_func(batched):
+             return [
+                 (person_id, index_date.strftime("%Y-%m-%d")) not in cohort_members
+                 for person_id, index_date in zip(
+                     batched["person_id"], batched["index_date"]
+                 )
+             ]
+
+         test_dataset = test_dataset.filter(filter_func, batched=True, batch_size=1000)
+     return test_dataset
+
+
+ def flush_to_disk_if_full(
+     tte_outputs: List[Dict[str, Any]], prediction_output_folder_name, buffer_size: int
+ ) -> None:
+     if len(tte_outputs) >= buffer_size:
+         LOG.info(
+             f"{datetime.datetime.now()}: Flushing time to visit predictions to disk"
+         )
+         output_parquet_file = os.path.join(
+             prediction_output_folder_name, f"{uuid.uuid4()}.parquet"
+         )
+         pd.DataFrame(
+             tte_outputs,
+             columns=[
+                 "person_id",
+                 "index_date",
+                 "visit_counter",
+                 "label",
+                 "time_to_event",
+                 "prediction",
+             ],
+         ).to_parquet(output_parquet_file)
+         tte_outputs.clear()
+
+
+ def create_arg_parser():
+     base_arg_parser = create_inference_base_arg_parser(
+         description="Arguments for time sensitive prediction"
+     )
+     base_arg_parser.add_argument(
+         "--dataset_folder",
+         dest="dataset_folder",
+         action="store",
+         help="The path for your dataset",
+         required=True,
+     )
+     base_arg_parser.add_argument(
+         "--num_return_sequences",
+         dest="num_return_sequences",
+         action="store",
+         type=int,
+         required=True,
+     )
+     base_arg_parser.add_argument(
+         "--task_config", dest="task_config", action="store", required=True
+     )
+     base_arg_parser.add_argument(
+         "--concept_ancestor", dest="concept_ancestor", action="store", required=False
+     )
+     base_arg_parser.add_argument(
+         "--debug",
+         dest="debug",
+         action="store_true",
+     )
+     base_arg_parser.add_argument(
+         "--max_n_trial",
+         dest="max_n_trial",
+         action="store",
+         type=int,
+         default=2,
+         required=False,
+     )
+     return base_arg_parser
+
+
+ if __name__ == "__main__":
+     main(create_arg_parser().parse_args())
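
The task YAML consumed by `load_task_config_from_yaml` supplies one key per `TaskConfig` field. The packaged configs (e.g. `cehrgpt/time_to_event/config/30_day_readmission.yaml`) are not shown in this diff, so the following is only a plausible sketch: the field names come from the dataclass above, while the values and the outcome concept ID are made up for illustration.

    # Hypothetical task definition; values and the concept ID are illustrative.
    import yaml

    from cehrgpt.time_to_event.time_to_event_prediction import TaskConfig

    example_yaml = """
    task_name: 30_day_readmission_demo
    outcome_events: ["319835"]
    include_descendants: true
    future_visit_start: 0
    future_visit_end: -1
    prediction_window_start: 0
    prediction_window_end: 30
    max_new_tokens: 128
    """

    task_config = TaskConfig(**yaml.safe_load(example_yaml))
    assert task_config.prediction_window_end == 30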
cehrgpt/time_to_event/time_to_event_utils.py
@@ -0,0 +1,55 @@
+ from collections import defaultdict
+ from typing import Any, Dict, List, Tuple
+
+ import numpy as np
+
+ from cehrgpt.gpt_utils import is_att_token
+
+
+ def convert_month_token_to_upperbound_days(
+     month_token: str, time_bucket_size: int = 90
+ ) -> str:
+     if is_att_token(month_token):
+         if month_token == "LT":
+             return ">= 1095 days"
+         else:
+             base = (int(month_token[1:]) + 1) * 30 // (time_bucket_size + 1)
+             return (
+                 f"{base * time_bucket_size} days - {(base + 1) * time_bucket_size} days"
+             )
+     raise ValueError(f"month_token: {month_token} is not a valid month token")
+
+
+ def calculate_time_bucket_probability(
+     predictions: List[Dict[str, Any]], time_bucket_size: int = 90
+ ) -> List[Tuple[str, Any]]:
+     predictions_with_time_buckets = [
+         {
+             "probability": p["probability"],
+             "time_bucket": convert_month_token_to_upperbound_days(
+                 p["time_interval"], time_bucket_size
+             ),
+         }
+         for p in predictions
+     ]
+     # Dictionary to store summed probabilities per time bucket
+     grouped_probabilities = defaultdict(float)
+     # Loop through the data
+     for entry in predictions_with_time_buckets:
+         time_bucket = entry["time_bucket"]
+         probability = entry["probability"]
+         grouped_probabilities[time_bucket] += probability
+     return sorted(grouped_probabilities.items(), key=lambda item: item[1], reverse=True)
+
+
+ def calculate_accumulative_time_bucket_probability(
+     predictions: List[Dict[str, Any]], time_bucket_size: int = 90
+ ) -> List[Tuple[str, Any]]:
+     time_bucket_probability = calculate_time_bucket_probability(
+         predictions, time_bucket_size
+     )
+     accumulative_probs = np.cumsum([_[1] for _ in time_bucket_probability])
+     return [
+         (*_, accumulative_prob)
+         for _, accumulative_prob in zip(time_bucket_probability, accumulative_probs)
+     ]
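
A small usage sketch of the helpers above, with made-up prediction dicts shaped like the `{"time_interval", "probability"}` records the functions expect. It assumes `is_att_token` accepts month tokens of the form `M<k>`, as the `int(month_token[1:])` parsing implies.

    from cehrgpt.time_to_event.time_to_event_utils import (
        calculate_accumulative_time_bucket_probability,
    )

    # Made-up sampled outcomes: ATT-style month tokens with probabilities.
    predictions = [
        {"time_interval": "M1", "probability": 0.5},
        {"time_interval": "M2", "probability": 0.2},
        {"time_interval": "M13", "probability": 0.3},
    ]
    # M1 and M2 fall into the same 90-day bucket, so their mass is summed,
    # and the cumulative probability over the ranked buckets reaches 1.0.
    for bucket, prob, cum_prob in calculate_accumulative_time_bucket_probability(
        predictions, time_bucket_size=90
    ):
        print(f"{bucket}: p={prob:.2f} cumulative={cum_prob:.2f}")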
cehrgpt/tools/__init__.py
File without changes
cehrgpt/tools/ehrshot_benchmark.py
@@ -0,0 +1,74 @@
+ import argparse
+ import os
+ import select  # Import select for monitoring stdout and stderr
+ import subprocess
+ from pathlib import Path
+
+ import yaml
+
+
+ def create_arg_parser():
+     parser = argparse.ArgumentParser(
+         description="Arguments for benchmarking CEHRGPT on ehrshot cohorts"
+     )
+     parser.add_argument("--cohort_dir", required=True)
+     parser.add_argument("--base_yaml_file", required=True)
+     parser.add_argument("--output_folder", required=True)
+     return parser.parse_args()
+
+
+ if __name__ == "__main__":
+     args = create_arg_parser()
+
+     with open(args.base_yaml_file, "rb") as stream:
+         base_config = yaml.safe_load(stream)
+
+     for cohort_name in os.listdir(args.cohort_dir):
+         if cohort_name.endswith("/"):
+             cohort_name = cohort_name[:-1]
+         individual_output = os.path.join(args.output_folder, cohort_name)
+         if os.path.exists(individual_output):
+             continue
+         Path(individual_output).mkdir(parents=True, exist_ok=True)
+         base_config["data_folder"] = os.path.join(args.cohort_dir, cohort_name, "train")
+         base_config["test_data_folder"] = os.path.join(
+             args.cohort_dir, cohort_name, "test"
+         )
+         base_config["output_dir"] = individual_output
+
+         # Write the per-cohort YAML config to a file
+         config_path = os.path.join(individual_output, "config.yaml")
+         with open(config_path, "w") as yaml_file:
+             yaml.dump(base_config, yaml_file, default_flow_style=False)
+
+         command = [
+             "python",
+             "-u",
+             "-m",
+             "cehrgpt.runners.hf_cehrgpt_finetune_runner",
+             config_path,
+         ]
+
+         # Start the subprocess
+         with subprocess.Popen(
+             command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
+         ) as process:
+             while True:
+                 # Use select to wait for either stdout or stderr to have data
+                 reads = [process.stdout.fileno(), process.stderr.fileno()]
+                 ret = select.select(reads, [], [])
+
+                 # Read data from stdout and stderr as it becomes available
+                 for fd in ret[0]:
+                     if fd == process.stdout.fileno():
+                         line = process.stdout.readline()
+                         if line:
+                             print(line, end="")
+                     elif fd == process.stderr.fileno():
+                         line = process.stderr.readline()
+                         if line:
+                             print(line, end="")
+
+                 # Break loop when process finishes
+                 if process.poll() is not None:
+                     break
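
Since the loop above only echoes both streams to the console, the `select` machinery (which is also POSIX-only) could be replaced by merging stderr into stdout. A minimal sketch of that alternative, reusing the `command` list built above; it loses the stdout/stderr distinction, which this script does not use anyway:

    # Merge stderr into stdout and stream the combined output line by line.
    with subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
    ) as process:
        for line in process.stdout:
            print(line, end="")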
cehrgpt/tools/generate_pretrained_embeddings.py
@@ -0,0 +1,130 @@
+ import argparse
+ import os
+ import pickle
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ from sklearn.preprocessing import normalize
+ from tqdm import tqdm
+ from transformers import AutoModel, AutoTokenizer
+
+ from cehrgpt.models.pretrained_embeddings import (
+     PRETRAINED_EMBEDDING_CONCEPT_FILE_NAME,
+     PRETRAINED_EMBEDDING_VECTOR_FILE_NAME,
+ )
+ from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+
+
+ def generate_embeddings_batch(texts, tokenizer, device, model):
+     inputs = tokenizer(
+         texts, return_tensors="pt", padding=True, truncation=True, max_length=512
+     )
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+         embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
+     return normalize(embeddings)
+
+
+ def main(args):
+     tokenizer = AutoTokenizer.from_pretrained(
+         "dunzhang/stella_en_1.5B_v5", trust_remote_code=True
+     )
+     model = AutoModel.from_pretrained(
+         "dunzhang/stella_en_1.5B_v5", trust_remote_code=True
+     ).eval()
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = model.to(device)
+
+     print("Load cehrgpt tokenizer")
+     cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_path)
+     concept_ids = [_ for _ in cehrgpt_tokenizer.get_vocab().keys() if _.isnumeric()]
+     vocab = pd.DataFrame(concept_ids, columns=["concept_id"])
+     vocab.drop_duplicates(subset=["concept_id"], inplace=True)
+     vocab = vocab.astype(str)
+
+     print("Load concept dataframe")
+     concept = pd.read_parquet(args.concept_parquet_file_path)
+     concept = concept.astype(str)
+
+     print("Merge concept ids and concept names")
+     vocab_with_name = vocab.merge(
+         concept, how="inner", left_on="concept_id", right_on="concept_id"
+     )
+
+     concept_ids = vocab_with_name["concept_id"].to_list()
+     concept_names = vocab_with_name["concept_name"].to_list()
+     all_embeddings = []
+     concept_dict = []
+     for i in tqdm(
+         range(0, len(concept_names), args.batch_size)
+     ):
+         batched_concept_names = concept_names[i : i + args.batch_size]
+         batched_concept_ids = concept_ids[i : i + args.batch_size]
+         try:
+             batch_embeddings = generate_embeddings_batch(
+                 batched_concept_names, tokenizer, device, model
+             )
+             all_embeddings.extend(batch_embeddings)
+             concept_dict.extend(
+                 [
+                     {"concept_id": concept_id, "concept_name": concept_name}
+                     for concept_id, concept_name in zip(
+                         batched_concept_ids, batched_concept_names
+                     )
+                 ]
+             )
+         except Exception as e:
+             print(f"Error processing batch: {str(e)}")
+
+     np.save(
+         os.path.join(args.output_folder_path, PRETRAINED_EMBEDDING_VECTOR_FILE_NAME),
+         all_embeddings,
+     )
+
+     with open(
+         os.path.join(args.output_folder_path, PRETRAINED_EMBEDDING_CONCEPT_FILE_NAME),
+         "wb",
+     ) as file:
+         pickle.dump(concept_dict, file)
+
+
+ def create_arg_parser():
+     parser = argparse.ArgumentParser(description="Create pretrained embeddings")
+     parser.add_argument(
+         "--tokenizer_path",
+         dest="tokenizer_path",
+         action="store",
+         help="The path for the vocabulary json file",
+         required=True,
+     )
+     parser.add_argument(
+         "--concept_parquet_file_path",
+         dest="concept_parquet_file_path",
+         action="store",
+         help="The path to the OMOP concept parquet table",
+         required=True,
+     )
+     parser.add_argument(
+         "--batch_size",
+         dest="batch_size",
+         type=int,
+         default=16,
+         action="store",
+         help="Batch size to process the concept_names",
+         required=False,
+     )
+     parser.add_argument(
+         "--output_folder_path",
+         dest="output_folder_path",
+         action="store",
+         help="Output folder path for saving the embeddings and concept_names",
+         required=True,
+     )
+     return parser
+
+
+ if __name__ == "__main__":
+     main(create_arg_parser().parse_args())
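
Because `generate_embeddings_batch` mean-pools the last hidden state and L2-normalizes the result, the dot product of two rows is their cosine similarity. A short sketch, assuming `tokenizer`, `model`, and `device` have been set up as in `main` above; the concept names are arbitrary examples:

    import numpy as np

    # Rows are L2-normalized, so the dot product equals cosine similarity.
    emb = generate_embeddings_batch(
        ["Type 2 diabetes mellitus", "Essential hypertension"],
        tokenizer, device, model,
    )
    print(float(np.dot(emb[0], emb[1])))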