cehrgpt 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +1 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +454 -68
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +232 -17
- cehrgpt/data/sample_packing_sampler.py +36 -6
- cehrgpt/generation/cehrgpt_conditional_generation.py +314 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +15 -3
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +25 -0
- cehrgpt/models/hf_cehrgpt.py +244 -39
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +354 -71
- cehrgpt/runners/data_utils.py +131 -5
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +84 -51
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +59 -7
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +60 -0
- cehrgpt/runners/hyperparameter_search_util.py +6 -7
- cehrgpt/runners/sample_packing_trainer.py +17 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +80 -62
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/METADATA +102 -7
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/RECORD +29 -26
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/WHEEL +1 -1
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/top_level.txt +0 -0
cehrgpt/data/hf_cehrgpt_dataset_mapping.py
@@ -1,8 +1,10 @@
 import datetime
-from
+from collections import defaultdict
+from typing import Any, Dict, Generator, List, Optional, Union
 
 import numpy as np
 import pandas as pd
+from cehrbert.data_generators.hf_data_generator import UNKNOWN_VALUE
 from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
     ED_VISIT_TYPE_CODES,
     INPATIENT_VISIT_TYPE_CODES,
@@ -15,9 +17,16 @@ from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
 )
 from cehrbert.med_extension.schema_extension import Event
 from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
+from cehrbert_data.const.artificial_tokens import (
+    DISCHARGE_UNKNOWN_TOKEN,
+    GENDER_UNKNOWN_TOKEN,
+    RACE_UNKNOWN_TOKEN,
+)
 from cehrbert_data.const.common import NA
 from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
+from datasets.formatting.formatting import LazyBatch
 from dateutil.relativedelta import relativedelta
+from pandas import Series
 
 from cehrgpt.models.tokenization_hf_cehrgpt import (
     NONE_BIN,
@@ -25,14 +34,60 @@ from cehrgpt.models.tokenization_hf_cehrgpt import (
     CehrGptTokenizer,
 )
 
+CEHRGPT_COLUMNS = [
+    "concept_ids",
+    "concept_value_masks",
+    "number_as_values",
+    "concept_as_values",
+    "is_numeric_types",
+    "concept_values",
+    "units",
+    "epoch_times",
+]
+
+
+def convert_date_to_posix_time(index_date: Union[datetime.date, int, float]) -> float:
+    if isinstance(index_date, datetime.date):
+        return (
+            datetime.datetime.combine(index_date, datetime.datetime.min.time())
+            .replace(tzinfo=datetime.timezone.utc)
+            .timestamp()
+        )
+    elif isinstance(index_date, datetime.datetime):
+        return index_date.replace(tzinfo=datetime.timezone.utc).timestamp()
+    return index_date
+
+
+class DatasetMappingDecorator(DatasetMapping):
+
+    def batch_transform(
+        self, records: Union[LazyBatch, Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
+        """
+        Drop index_date if it contains None.
+
+        :param records:
+        :return:
+        """
+        if isinstance(records, LazyBatch):
+            table = records.pa_table
+
+            if "index_date" in table.column_names:
+                index_col = table.column("index_date")
+                if index_col.null_count > 0:
+                    table = table.drop(["index_date"])
+                    records = LazyBatch(pa_table=table, formatter=records.formatter)
+        else:
+            if "index_date" in records:
+                if pd.isna(records["index_date"][0]):
+                    del records["index_date"]
+        return super().batch_transform(records=records)
 
-def convert_date_to_posix_time(index_date: datetime.date) -> float:
-    return datetime.datetime.combine(
-        index_date, datetime.datetime.min.time()
-    ).timestamp()
+    def transform(self, record: Dict[str, Any]) -> Union[Dict[str, Any], Series]:
+        raise NotImplemented("Must be implemented")
 
 
-class MedToCehrGPTDatasetMapping(DatasetMapping):
+class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
     def __init__(
         self,
         data_args: DataTrainingArguments,
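
Behavior of the new `convert_date_to_posix_time` helper, sketched as assertions (the import path is taken from the file list above; this reflects my reading of the branch order, not a documented guarantee). Because `datetime.datetime` subclasses `datetime.date`, a datetime input also takes the first branch, where `datetime.datetime.combine()` ignores its time component, so the `elif` branch is effectively unreachable:

```python
import datetime
from cehrgpt.data.hf_cehrgpt_dataset_mapping import convert_date_to_posix_time

# A plain date is anchored to midnight UTC:
assert convert_date_to_posix_time(datetime.date(1970, 1, 2)) == 86400.0
# Numeric timestamps pass through unchanged:
assert convert_date_to_posix_time(1234.5) == 1234.5
# A datetime is matched by isinstance(..., datetime.date) first and is
# therefore also anchored to midnight of its day:
assert convert_date_to_posix_time(datetime.datetime(1970, 1, 1, 1, 0, 0)) == 0.0
```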
@@ -65,6 +120,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
     def _update_cehrgpt_record(
         cehrgpt_record: Dict[str, Any],
         code: str,
+        time: datetime.datetime,
         concept_value_mask: int = 0,
         number_as_value: float = 0.0,
         concept_as_value: str = "0",
@@ -77,6 +133,9 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
         cehrgpt_record["concept_as_values"].append(concept_as_value)
         cehrgpt_record["units"].append(unit)
         cehrgpt_record["is_numeric_types"].append(is_numeric_type)
+        cehrgpt_record["epoch_times"].append(
+            time.replace(tzinfo=datetime.timezone.utc).timestamp()
+        )
 
     def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
         cehrgpt_record = {
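
The new `time` parameter threads a timestamp through every `_update_cehrgpt_record` call so that the new `epoch_times` column stays index-aligned with `concept_ids`. A simplified sketch (names mirror the diff, logic reduced; not package code):

```python
import datetime
from collections import defaultdict

record = defaultdict(list)

def append_token(record, code, time):
    # Every token carries a UTC POSIX timestamp in a parallel list.
    record["concept_ids"].append(code)
    record["epoch_times"].append(
        time.replace(tzinfo=datetime.timezone.utc).timestamp()
    )

visit_start = datetime.datetime(2020, 3, 1, 9, 30)
append_token(record, "[VS]", visit_start)
append_token(record, "9202", visit_start)  # visit type token, same timestamp

assert len(record["concept_ids"]) == len(record["epoch_times"])
```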
@@ -87,13 +146,16 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
             "concept_as_values": [],
             "units": [],
             "is_numeric_types": [],
+            "epoch_times": [],
         }
         # Extract the demographic information
         birth_datetime = record["birth_datetime"]
         if isinstance(birth_datetime, pd.Timestamp):
             birth_datetime = birth_datetime.to_pydatetime()
         gender = record["gender"]
+        gender = GENDER_UNKNOWN_TOKEN if gender == UNKNOWN_VALUE else gender
         race = record["race"]
+        race = RACE_UNKNOWN_TOKEN if race == UNKNOWN_VALUE else race
         visits = record["visits"]
         # This indicates this is columnar format
         if isinstance(visits, dict):
@@ -108,10 +170,12 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
         )
         year_str = f"year:{str(first_visit_start_datetime.year)}"
         age_str = f"age:{str(relativedelta(first_visit_start_datetime, birth_datetime).years)}"
-        self._update_cehrgpt_record(
-
-
-        self._update_cehrgpt_record(cehrgpt_record,
+        self._update_cehrgpt_record(
+            cehrgpt_record, year_str, first_visit_start_datetime
+        )
+        self._update_cehrgpt_record(cehrgpt_record, age_str, first_visit_start_datetime)
+        self._update_cehrgpt_record(cehrgpt_record, gender, first_visit_start_datetime)
+        self._update_cehrgpt_record(cehrgpt_record, race, first_visit_start_datetime)
 
         # Use a data cursor to keep track of time
         datetime_cursor: Optional[datetime.datetime] = None
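
Sketch of the demographic prologue after this hunk: unknown gender and race values are remapped to dedicated unknown tokens, and all four leading tokens share the first visit's start time. The token string values below are assumptions for illustration, not the actual constants from `cehrbert_data`:

```python
import datetime

UNKNOWN_VALUE = "Unknown"           # assumed sentinel from cehrbert
GENDER_UNKNOWN_TOKEN = "Gender/0"   # assumed token constants from cehrbert_data
RACE_UNKNOWN_TOKEN = "Race/0"

gender, race = UNKNOWN_VALUE, "Race/8527"
gender = GENDER_UNKNOWN_TOKEN if gender == UNKNOWN_VALUE else gender
race = RACE_UNKNOWN_TOKEN if race == UNKNOWN_VALUE else race

first_visit_start = datetime.datetime(2015, 6, 1)
prologue = [
    (f"year:{first_visit_start.year}", first_visit_start),
    ("age:40", first_visit_start),  # age computed via relativedelta in the real code
    (gender, first_visit_start),
    (race, first_visit_start),
]
print(prologue)  # four (token, timestamp) pairs opening the sequence
```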
@@ -146,6 +210,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
                 self._update_cehrgpt_record(
                     cehrgpt_record,
                     code=self._time_token_function(time_delta),
+                    time=visit_start_datetime,
                 )
 
             datetime_cursor = visit_start_datetime
@@ -153,11 +218,13 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
             self._update_cehrgpt_record(
                 cehrgpt_record,
                 code="[VS]",
+                time=datetime_cursor,
             )
             # Add a visit type token
             self._update_cehrgpt_record(
                 cehrgpt_record,
                 code=visit_type,
+                time=datetime_cursor,
             )
             # We need to insert an inpatient hour token right after the visit type, we calculate the hour interval
             # with respect to the midnight of the day
@@ -167,6 +234,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
                 self._update_cehrgpt_record(
                     cehrgpt_record,
                     code=f"i-H{datetime_cursor.hour}",
+                    time=datetime_cursor,
                 )
 
             # Keep track of the existing outpatient events, we don't want to add them again
@@ -185,6 +253,10 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
                 concept_value_mask = int(
                     numeric_value is not None or text_value is not None
                 )
+                if numeric_value is None and text_value is not None:
+                    if text_value.isnumeric():
+                        numeric_value = float(text_value)
+
                 is_numeric_type = int(numeric_value is not None)
                 code = replace_escape_chars(e["code"])
 
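
The coercion above relies on `str.isnumeric()`, which matches only strings composed entirely of numeric characters. A quick behavior check (not package code):

```python
for text_value in ["120", "3.5", "-2", "98.6 F"]:
    numeric_value = float(text_value) if text_value.isnumeric() else None
    print(text_value, "->", numeric_value)
# 120 -> 120.0; the other three stay None and remain text values
```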
@@ -208,6 +280,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
                         self._update_cehrgpt_record(
                             cehrgpt_record,
                             code=f"i-{self._inpatient_time_token_function(time_diff_days)}",
+                            time=event_time,
                         )
 
                         if self._include_inpatient_hour_token:
@@ -226,6 +299,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
                             self._update_cehrgpt_record(
                                 cehrgpt_record,
                                 code=f"i-H{time_diff_hours}",
+                                time=event_time,
                             )
 
                 if event_identity in existing_duplicate_events:
@@ -234,6 +308,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
                 self._update_cehrgpt_record(
                     cehrgpt_record,
                     code=code,
+                    time=event_time,
                     concept_value_mask=concept_value_mask,
                     unit=unit,
                     number_as_value=numeric_value if numeric_value else 0.0,
@@ -262,17 +337,24 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
             # facility event
             discharge_facility = get_value(visit, "discharge_facility")
             if not discharge_facility:
-                discharge_facility =
-
+                discharge_facility = DISCHARGE_UNKNOWN_TOKEN
+            else:
+                discharge_facility = (
+                    DISCHARGE_UNKNOWN_TOKEN
+                    if discharge_facility == UNKNOWN_VALUE
+                    else discharge_facility
+                )
             self._update_cehrgpt_record(
                 cehrgpt_record,
                 code=discharge_facility,
+                time=datetime_cursor,
             )
 
             # Reuse the age and date calculated for the last event in the patient timeline
             self._update_cehrgpt_record(
                 cehrgpt_record,
                 code="[VE]",
+                time=datetime_cursor,
             )
 
         # Generate the orders of the concepts that the cehrbert dataset mapping function expects
@@ -284,17 +366,23 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
         cehrgpt_record["num_of_concepts"] = len(cehrgpt_record["concept_ids"])
         cehrgpt_record["num_of_visits"] = len(visits)
 
-        if record.get("index_date", None):
-            cehrgpt_record["index_date"] =
-
+        if record.get("index_date", None) is not None:
+            cehrgpt_record["index_date"] = (
+                record["index_date"].replace(tzinfo=datetime.timezone.utc).timestamp()
+            )
+        if record.get("label", None) is not None:
             cehrgpt_record["label"] = record["label"]
-        if record.get("age_at_index", None):
+        if record.get("age_at_index", None) is not None:
             cehrgpt_record["age_at_index"] = record["age_at_index"]
 
+        assert len(cehrgpt_record["epoch_times"]) == len(
+            cehrgpt_record["concept_ids"]
+        ), "The number of time stamps must match with the number of concepts in the sequence"
+
         return cehrgpt_record
 
 
-class HFCehrGptTokenizationMapping(DatasetMapping):
+class HFCehrGptTokenizationMapping(DatasetMappingDecorator):
     def __init__(
         self,
         concept_tokenizer: CehrGptTokenizer,
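
The guards above switch from truthiness to explicit `is not None` checks. A small illustration (not package code) of why that matters: a label of 0, the negative class, or an `age_at_index` of 0 is falsy and would be silently dropped by the old style:

```python
record = {"label": 0, "age_at_index": 0}

kept_old = {k: v for k, v in record.items() if record.get(k)}              # {}
kept_new = {k: v for k, v in record.items() if record.get(k) is not None}  # both kept
assert kept_old == {} and kept_new == {"label": 0, "age_at_index": 0}
```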
@@ -308,9 +396,46 @@ class HFCehrGptTokenizationMapping(DatasetMapping):
         "is_numeric_types",
     ]
 
+    def filter_out_invalid_tokens(self, record: Dict[str, Any]) -> Dict[str, Any]:
+        column_names = []
+        seq_length = len(record["concept_ids"])
+
+        # We can't have "0" as a token in the tokenizer because it would break tokenization for "Race/0", "Visit/0"
+        # This is a pre-caution
+        if "0" in record["concept_ids"]:
+            if isinstance(record["concept_ids"], np.ndarray):
+                record["concept_ids"][record["concept_ids"] == "0"] = "Unknown"
+            else:
+                record["concept_ids"] = [
+                    "Unknown" if x == "0" else x for x in record["concept_ids"]
+                ]
+
+        for k, v in record.items():
+            if k not in CEHRGPT_COLUMNS:
+                continue
+            if isinstance(v, (list, np.ndarray)) and len(v) == seq_length:
+                column_names.append(k)
+        valid_concept_ids = self._concept_tokenizer.get_vocab().keys()
+        valid_indices = [
+            idx
+            for idx, concept_id in enumerate(record["concept_ids"])
+            if concept_id in valid_concept_ids
+        ]
+        if len(valid_indices) != len(record["concept_ids"]):
+            for column in column_names:
+                values = record[column]
+                record[column] = [values[idx] for idx in valid_indices]
+        return record
+
     def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
+        # Remove the tokens from patient sequences that do not exist in the tokenizer
+        record = self.filter_out_invalid_tokens(record)
         # If any concept has a value associated with it, we normalize the value
         record["input_ids"] = self._concept_tokenizer.encode(record["concept_ids"])
+        assert len(record["input_ids"]) == len(record["concept_ids"]), (
+            "The number of tokens must equal to the number of concepts\n"
+            f"decoded concept_ids: {self._concept_tokenizer.decode(record['input_ids'], skip_special_tokens=False)}"
+        )
         record["value_indicators"] = record["concept_value_masks"]
         if "number_as_values" not in record or "concept_as_values" not in record:
             record["number_as_values"] = [
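
Standalone sketch of the new out-of-vocabulary filtering: sequence-aligned columns are sliced down to the indices whose `concept_ids` exist in the tokenizer vocabulary, while scalar fields are left alone. The toy vocab below is illustrative; the real method consults `self._concept_tokenizer.get_vocab()`:

```python
import numpy as np

vocab = {"[VS]": 0, "[VE]": 1, "year:2015": 2, "age:40": 3}
record = {
    "concept_ids": ["year:2015", "age:40", "NotInVocab", "[VS]", "[VE]"],
    "epoch_times": [1.0, 1.0, 2.0, 2.0, 3.0],
    "person_id": 42,  # not sequence-aligned, left untouched
}
seq_length = len(record["concept_ids"])
valid = [i for i, c in enumerate(record["concept_ids"]) if c in vocab]
for k, v in record.items():
    if isinstance(v, (list, np.ndarray)) and len(v) == seq_length:
        record[k] = [v[i] for i in valid]
print(record["concept_ids"])  # ['year:2015', 'age:40', '[VS]', '[VE]']
print(record["epoch_times"])  # [1.0, 1.0, 2.0, 3.0]
```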
@@ -391,3 +516,93 @@ class HFFineTuningMapping(HFCehrGptTokenizationMapping):
         columns = super().remove_columns()
         columns.append("label")
         return columns
+
+
+class ExtractTokenizedSequenceDataMapping:
+    def __init__(
+        self,
+        person_index_date_map: Dict[int, List[Dict[str, Any]]],
+        observation_window: int = 0,
+    ):
+        self.person_index_date_map = person_index_date_map
+        self.observation_window = observation_window
+
+    def _calculate_prediction_start_time(self, prediction_time: float):
+        if self.observation_window and self.observation_window > 0:
+            return max(prediction_time - self.observation_window * 24 * 3600, 0)
+        return 0
+
+    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
+        person_id = record["person_id"]
+        prediction_times = self.person_index_date_map[person_id]
+        prediction_start_end_times = [
+            (
+                self._calculate_prediction_start_time(
+                    prediction_time_label_map["index_date"]
+                    .replace(tzinfo=datetime.timezone.utc)
+                    .timestamp()
+                ),
+                prediction_time_label_map["index_date"]
+                .replace(tzinfo=datetime.timezone.utc)
+                .timestamp(),
+                prediction_time_label_map["label"],
+            )
+            for prediction_time_label_map in prediction_times
+        ]
+        observation_window_indices = np.zeros(
+            (len(prediction_times), len(record["epoch_times"])), dtype=bool
+        )
+        for i, epoch_time in enumerate(record["epoch_times"]):
+            for sample_n, (
+                feature_extraction_time_start,
+                feature_extraction_end_end,
+                _,
+            ) in enumerate(prediction_start_end_times):
+                if (
+                    feature_extraction_time_start
+                    <= epoch_time
+                    <= feature_extraction_end_end
+                ):
+                    observation_window_indices[sample_n][i] = True
+
+        seq_length = len(record["epoch_times"])
+        time_series_columns = ["concept_ids", "input_ids"]
+        static_inputs = dict()
+        for k, v in record.items():
+            if k in ["concept_ids", "input_ids"]:
+                continue
+            if isinstance(v, (list, np.ndarray)) and len(v) == seq_length:
+                time_series_columns.append(k)
+            else:
+                static_inputs[k] = v
+
+        batched_samples = defaultdict(list)
+        for (_, index_date, label), observation_window_index in zip(
+            prediction_start_end_times, observation_window_indices
+        ):
+            for k, v in static_inputs.items():
+                batched_samples[k].append(v)
+            batched_samples["classifier_label"].append(label)
+            batched_samples["index_date"].append(index_date)
+            try:
+                start_age = int(record["concept_ids"][1].split(":")[1])
+            except Exception:
+                start_age = -1
+            batched_samples["age_at_index"].append(start_age)
+            for time_series_column in time_series_columns:
+                batched_samples[time_series_column].append(
+                    np.asarray(record[time_series_column])[observation_window_index]
+                )
+        return batched_samples
+
+    def batch_transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
+        all_batched_record = defaultdict(list)
+        all_columns = record.keys()
+        for i in range(len(record["concept_ids"])):
+            one_record = {}
+            for column in all_columns:
+                one_record[column] = record[column][i]
+            new_batched_record = self.transform(one_record)
+            for k, v in new_batched_record.items():
+                all_batched_record[k].extend(v)
+        return all_batched_record
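
Hypothetical usage of the new `ExtractTokenizedSequenceDataMapping`: for each prediction time registered for a person, tokens whose `epoch_times` fall within [index_date − observation_window days, index_date] are sliced into one classification sample. The values below are invented for illustration:

```python
import datetime

utc = datetime.timezone.utc
person_index_date_map = {
    42: [
        {"index_date": datetime.datetime(2020, 1, 10), "label": 1},
        {"index_date": datetime.datetime(2020, 6, 1), "label": 0},
    ]
}
record = {
    "person_id": 42,
    "concept_ids": ["year:2020", "age:40", "[VS]", "320128", "[VE]"],
    "input_ids": [2, 3, 4, 5, 6],
    # all five tokens stamped Jan 5, 2020
    "epoch_times": [datetime.datetime(2020, 1, 5, tzinfo=utc).timestamp()] * 5,
}
# mapping = ExtractTokenizedSequenceDataMapping(person_index_date_map, observation_window=30)
# batched = mapping.transform(record)
# -> two samples: the Jan 10 window keeps all five tokens, the Jun 1 window keeps none.
```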
cehrgpt/data/sample_packing_sampler.py
@@ -1,5 +1,6 @@
 from typing import Iterator, List, Optional
 
+import numpy as np
 import torch
 import torch.distributed as dist
 from torch.utils.data import Sampler
@@ -33,6 +34,8 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
         rank: Optional[int] = None,
         seed: int = 0,
         drop_last: bool = False,
+        negative_sampling_probability: Optional[float] = None,
+        labels: Optional[List[int]] = None,
     ):
         """
         Args:
@@ -73,6 +76,11 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
                 f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
             )
 
+        if negative_sampling_probability is not None and labels is None:
+            raise ValueError(
+                f"When the negative sampling probability is provide, the labels must be provided as well"
+            )
+
         self.lengths = lengths
         self.max_tokens_per_batch = max_tokens_per_batch
         self.max_position_embeddings = max_position_embeddings
@@ -80,6 +88,8 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
         self.rank = rank
         self.seed = seed
         self.drop_last = drop_last
+        self.negative_sampling_probability = negative_sampling_probability
+        self.labels = labels
         # Trainer https://github.com/huggingface/transformers/blame/main/src/transformers/trainer.py#L2470
         # http://github.com/huggingface/accelerate/blob/v0.31.0/src/accelerate/data_loader.py#L482
         # the huggingface trainer will call the accelerate.data_loader.DataLoaderShard.set_epoch,
@@ -100,6 +110,14 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
         current_batch_tokens = 0
 
         for idx in indices:
+            # There is a chance to skip the negative samples to account for the class imbalance
+            # in the fine-tuning dataset
+            if self.negative_sampling_probability:
+                if (
+                    np.random.random() > self.negative_sampling_probability
+                    and self.labels[idx] == 0
+                ):
+                    continue
             # We take the minimum of the two because each sequence will be truncated to fit
             # the context window of the model
             sample_length = min(self.lengths[idx], self.max_position_embeddings)
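
Sketch of the sampling rule added in this hunk (not package code): a negative sample (label == 0) survives with probability `negative_sampling_probability`, positives always survive, so a suitable probability roughly rebalances the classes:

```python
import numpy as np

rng = np.random.default_rng(0)
labels = np.array([0] * 9000 + [1] * 1000)
negative_sampling_probability = 1 / 9  # keep ~1 negative per positive

kept = [
    label
    for label in labels
    if label == 1 or rng.random() <= negative_sampling_probability
]
print(sum(kept), len(kept) - sum(kept))  # ~1000 positives, ~1000 negatives
```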
@@ -131,10 +149,22 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
         if len(self.lengths) == 0:
             return 0
 
-        # We need to truncate the lengths due to the context window limit imposed by the model
-        truncated_lengths = [
-            min(self.max_position_embeddings, length + 2) for length in self.lengths
-        ]
+        # There is a chance to skip the negative samples to account for the class imbalance
+        # in the fine-tuning dataset
+        if self.negative_sampling_probability:
+            truncated_lengths = []
+            for length, label in zip(self.lengths, self.labels):
+                if (
+                    np.random.random() > self.negative_sampling_probability
+                    and label == 0
+                ):
+                    continue
+                truncated_lengths.append(length)
+        else:
+            # We need to truncate the lengths due to the context window limit imposed by the model
+            truncated_lengths = [
+                min(self.max_position_embeddings, length + 2) for length in self.lengths
+            ]
 
         # Calculate average sequence length
         avg_seq_length = sum(truncated_lengths) // len(truncated_lengths)
@@ -145,7 +175,7 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
         # Estimate total number of batches
         if self.drop_last:
             # If dropping last incomplete batch
-            return len(truncated_lengths) // seqs_per_batch
+            return len(truncated_lengths) // seqs_per_batch
         else:
             # If keeping last incomplete batch, ensure at least 1 batch
-            return max(1, len(truncated_lengths) // seqs_per_batch)
+            return max(1, len(truncated_lengths) // seqs_per_batch)
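
A hypothetical construction of the sampler with the two new arguments; the other parameters follow the signature shown in the hunks above, and any omitted ones are assumed to keep their defaults:

```python
from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler

lengths = [512, 2048, 128, 4096]
labels = [0, 1, 0, 1]

batch_sampler = SamplePackingBatchSampler(
    lengths=lengths,
    max_tokens_per_batch=8192,
    max_position_embeddings=2048,
    seed=42,
    drop_last=False,
    negative_sampling_probability=0.5,  # each negative kept ~50% of the time
    labels=labels,                      # required when downsampling negatives
)
# Note: __len__ and __iter__ each re-run the random filter, so epoch sizes vary.
```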