cehrgpt 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29) hide show
  1. cehrgpt/analysis/irregularity.py +36 -0
  2. cehrgpt/data/hf_cehrgpt_dataset.py +1 -0
  3. cehrgpt/data/hf_cehrgpt_dataset_collator.py +454 -68
  4. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +232 -17
  5. cehrgpt/data/sample_packing_sampler.py +36 -6
  6. cehrgpt/generation/cehrgpt_conditional_generation.py +314 -0
  7. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +15 -3
  8. cehrgpt/generation/omop_converter_batch.py +32 -2
  9. cehrgpt/gpt_utils.py +20 -2
  10. cehrgpt/models/config.py +25 -0
  11. cehrgpt/models/hf_cehrgpt.py +244 -39
  12. cehrgpt/models/hf_modeling_outputs.py +1 -0
  13. cehrgpt/models/special_tokens.py +1 -0
  14. cehrgpt/models/tokenization_hf_cehrgpt.py +354 -71
  15. cehrgpt/runners/data_utils.py +131 -5
  16. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +84 -51
  17. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +59 -7
  18. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +60 -0
  19. cehrgpt/runners/hyperparameter_search_util.py +6 -7
  20. cehrgpt/runners/sample_packing_trainer.py +17 -0
  21. cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  22. cehrgpt/time_to_event/time_to_event_model.py +2 -13
  23. cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  24. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +80 -62
  25. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/METADATA +102 -7
  26. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/RECORD +29 -26
  27. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/WHEEL +1 -1
  28. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/licenses/LICENSE +0 -0
  29. {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.2.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,10 @@
1
1
  import datetime
2
- from typing import Any, Dict, Generator, Optional
2
+ from collections import defaultdict
3
+ from typing import Any, Dict, Generator, List, Optional, Union
3
4
 
4
5
  import numpy as np
5
6
  import pandas as pd
7
+ from cehrbert.data_generators.hf_data_generator import UNKNOWN_VALUE
6
8
  from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
7
9
  ED_VISIT_TYPE_CODES,
8
10
  INPATIENT_VISIT_TYPE_CODES,
@@ -15,9 +17,16 @@ from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
15
17
  )
16
18
  from cehrbert.med_extension.schema_extension import Event
17
19
  from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
20
+ from cehrbert_data.const.artificial_tokens import (
21
+ DISCHARGE_UNKNOWN_TOKEN,
22
+ GENDER_UNKNOWN_TOKEN,
23
+ RACE_UNKNOWN_TOKEN,
24
+ )
18
25
  from cehrbert_data.const.common import NA
19
26
  from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
27
+ from datasets.formatting.formatting import LazyBatch
20
28
  from dateutil.relativedelta import relativedelta
29
+ from pandas import Series
21
30
 
22
31
  from cehrgpt.models.tokenization_hf_cehrgpt import (
23
32
  NONE_BIN,
@@ -25,14 +34,60 @@ from cehrgpt.models.tokenization_hf_cehrgpt import (
25
34
  CehrGptTokenizer,
26
35
  )
27
36
 
37
+ CEHRGPT_COLUMNS = [
38
+ "concept_ids",
39
+ "concept_value_masks",
40
+ "number_as_values",
41
+ "concept_as_values",
42
+ "is_numeric_types",
43
+ "concept_values",
44
+ "units",
45
+ "epoch_times",
46
+ ]
47
+
48
+
49
def convert_date_to_posix_time(index_date: Union[datetime.date, int, float]) -> float:
    """
    Convert a date-like value to POSIX seconds, interpreting it as UTC.

    :param index_date: a ``datetime.datetime``, a ``datetime.date`` or a number
        that is already a POSIX timestamp
    :return: seconds since the epoch; a ``datetime.date`` maps to midnight UTC
        of that day, a ``datetime.datetime`` keeps its time of day, and any other
        value is returned unchanged
    """
    # datetime.datetime is a subclass of datetime.date, so it must be tested
    # first; otherwise the date branch swallows it and drops the time of day.
    if isinstance(index_date, datetime.datetime):
        # tzinfo is overwritten (not converted), i.e. the wall-clock value is read as UTC
        return index_date.replace(tzinfo=datetime.timezone.utc).timestamp()
    if isinstance(index_date, datetime.date):
        return (
            datetime.datetime.combine(index_date, datetime.datetime.min.time())
            .replace(tzinfo=datetime.timezone.utc)
            .timestamp()
        )
    return index_date
59
+
60
+
61
class DatasetMappingDecorator(DatasetMapping):
    """Dataset mapping that removes an unusable ``index_date`` column before batching."""

    def batch_transform(
        self, records: Union[LazyBatch, Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Drop index_date if it contains None, then delegate to the parent mapping.

        :param records: a batch in columnar form, either a LazyBatch backed by an
            arrow table or a plain dict of column name -> values
        :return: whatever the parent ``batch_transform`` returns for the cleaned batch
        """
        if isinstance(records, LazyBatch):
            table = records.pa_table

            if "index_date" in table.column_names:
                index_col = table.column("index_date")
                if index_col.null_count > 0:
                    table = table.drop(["index_date"])
            records = LazyBatch(pa_table=table, formatter=records.formatter)
        elif "index_date" in records:
            # Check every value, not only the first one, so that both branches
            # drop the column under the same condition (any missing value).
            # An empty column is kept as-is instead of raising an IndexError.
            if any(pd.isna(value) for value in records["index_date"]):
                del records["index_date"]
        return super().batch_transform(records=records)
28
85
 
29
- def convert_date_to_posix_time(index_date: datetime.date) -> float:
30
- return datetime.datetime.combine(
31
- index_date, datetime.datetime.min.time()
32
- ).timestamp()
86
+ def transform(self, record: Dict[str, Any]) -> Union[Dict[str, Any], Series]:
87
+ raise NotImplemented("Must be implemented")
33
88
 
34
89
 
35
- class MedToCehrGPTDatasetMapping(DatasetMapping):
90
+ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
36
91
  def __init__(
37
92
  self,
38
93
  data_args: DataTrainingArguments,
@@ -65,6 +120,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
65
120
  def _update_cehrgpt_record(
66
121
  cehrgpt_record: Dict[str, Any],
67
122
  code: str,
123
+ time: datetime.datetime,
68
124
  concept_value_mask: int = 0,
69
125
  number_as_value: float = 0.0,
70
126
  concept_as_value: str = "0",
@@ -77,6 +133,9 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
77
133
  cehrgpt_record["concept_as_values"].append(concept_as_value)
78
134
  cehrgpt_record["units"].append(unit)
79
135
  cehrgpt_record["is_numeric_types"].append(is_numeric_type)
136
+ cehrgpt_record["epoch_times"].append(
137
+ time.replace(tzinfo=datetime.timezone.utc).timestamp()
138
+ )
80
139
 
81
140
  def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
82
141
  cehrgpt_record = {
@@ -87,13 +146,16 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
87
146
  "concept_as_values": [],
88
147
  "units": [],
89
148
  "is_numeric_types": [],
149
+ "epoch_times": [],
90
150
  }
91
151
  # Extract the demographic information
92
152
  birth_datetime = record["birth_datetime"]
93
153
  if isinstance(birth_datetime, pd.Timestamp):
94
154
  birth_datetime = birth_datetime.to_pydatetime()
95
155
  gender = record["gender"]
156
+ gender = GENDER_UNKNOWN_TOKEN if gender == UNKNOWN_VALUE else gender
96
157
  race = record["race"]
158
+ race = RACE_UNKNOWN_TOKEN if race == UNKNOWN_VALUE else race
97
159
  visits = record["visits"]
98
160
  # This indicates this is columnar format
99
161
  if isinstance(visits, dict):
@@ -108,10 +170,12 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
108
170
  )
109
171
  year_str = f"year:{str(first_visit_start_datetime.year)}"
110
172
  age_str = f"age:{str(relativedelta(first_visit_start_datetime, birth_datetime).years)}"
111
- self._update_cehrgpt_record(cehrgpt_record, year_str)
112
- self._update_cehrgpt_record(cehrgpt_record, age_str)
113
- self._update_cehrgpt_record(cehrgpt_record, gender)
114
- self._update_cehrgpt_record(cehrgpt_record, race)
173
+ self._update_cehrgpt_record(
174
+ cehrgpt_record, year_str, first_visit_start_datetime
175
+ )
176
+ self._update_cehrgpt_record(cehrgpt_record, age_str, first_visit_start_datetime)
177
+ self._update_cehrgpt_record(cehrgpt_record, gender, first_visit_start_datetime)
178
+ self._update_cehrgpt_record(cehrgpt_record, race, first_visit_start_datetime)
115
179
 
116
180
  # Use a data cursor to keep track of time
117
181
  datetime_cursor: Optional[datetime.datetime] = None
@@ -146,6 +210,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
146
210
  self._update_cehrgpt_record(
147
211
  cehrgpt_record,
148
212
  code=self._time_token_function(time_delta),
213
+ time=visit_start_datetime,
149
214
  )
150
215
 
151
216
  datetime_cursor = visit_start_datetime
@@ -153,11 +218,13 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
153
218
  self._update_cehrgpt_record(
154
219
  cehrgpt_record,
155
220
  code="[VS]",
221
+ time=datetime_cursor,
156
222
  )
157
223
  # Add a visit type token
158
224
  self._update_cehrgpt_record(
159
225
  cehrgpt_record,
160
226
  code=visit_type,
227
+ time=datetime_cursor,
161
228
  )
162
229
  # We need to insert an inpatient hour token right after the visit type, we calculate the hour interval
163
230
  # with respect to the midnight of the day
@@ -167,6 +234,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
167
234
  self._update_cehrgpt_record(
168
235
  cehrgpt_record,
169
236
  code=f"i-H{datetime_cursor.hour}",
237
+ time=datetime_cursor,
170
238
  )
171
239
 
172
240
  # Keep track of the existing outpatient events, we don't want to add them again
@@ -185,6 +253,10 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
185
253
  concept_value_mask = int(
186
254
  numeric_value is not None or text_value is not None
187
255
  )
256
+ if numeric_value is None and text_value is not None:
257
+ if text_value.isnumeric():
258
+ numeric_value = float(text_value)
259
+
188
260
  is_numeric_type = int(numeric_value is not None)
189
261
  code = replace_escape_chars(e["code"])
190
262
 
@@ -208,6 +280,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
208
280
  self._update_cehrgpt_record(
209
281
  cehrgpt_record,
210
282
  code=f"i-{self._inpatient_time_token_function(time_diff_days)}",
283
+ time=event_time,
211
284
  )
212
285
 
213
286
  if self._include_inpatient_hour_token:
@@ -226,6 +299,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
226
299
  self._update_cehrgpt_record(
227
300
  cehrgpt_record,
228
301
  code=f"i-H{time_diff_hours}",
302
+ time=event_time,
229
303
  )
230
304
 
231
305
  if event_identity in existing_duplicate_events:
@@ -234,6 +308,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
234
308
  self._update_cehrgpt_record(
235
309
  cehrgpt_record,
236
310
  code=code,
311
+ time=event_time,
237
312
  concept_value_mask=concept_value_mask,
238
313
  unit=unit,
239
314
  number_as_value=numeric_value if numeric_value else 0.0,
@@ -262,17 +337,24 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
262
337
  # facility event
263
338
  discharge_facility = get_value(visit, "discharge_facility")
264
339
  if not discharge_facility:
265
- discharge_facility = "0"
266
-
340
+ discharge_facility = DISCHARGE_UNKNOWN_TOKEN
341
+ else:
342
+ discharge_facility = (
343
+ DISCHARGE_UNKNOWN_TOKEN
344
+ if discharge_facility == UNKNOWN_VALUE
345
+ else discharge_facility
346
+ )
267
347
  self._update_cehrgpt_record(
268
348
  cehrgpt_record,
269
349
  code=discharge_facility,
350
+ time=datetime_cursor,
270
351
  )
271
352
 
272
353
  # Reuse the age and date calculated for the last event in the patient timeline
273
354
  self._update_cehrgpt_record(
274
355
  cehrgpt_record,
275
356
  code="[VE]",
357
+ time=datetime_cursor,
276
358
  )
277
359
 
278
360
  # Generate the orders of the concepts that the cehrbert dataset mapping function expects
@@ -284,17 +366,23 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
284
366
  cehrgpt_record["num_of_concepts"] = len(cehrgpt_record["concept_ids"])
285
367
  cehrgpt_record["num_of_visits"] = len(visits)
286
368
 
287
- if record.get("index_date", None):
288
- cehrgpt_record["index_date"] = record["index_date"]
289
- if record.get("label", None):
369
+ if record.get("index_date", None) is not None:
370
+ cehrgpt_record["index_date"] = (
371
+ record["index_date"].replace(tzinfo=datetime.timezone.utc).timestamp()
372
+ )
373
+ if record.get("label", None) is not None:
290
374
  cehrgpt_record["label"] = record["label"]
291
- if record.get("age_at_index", None):
375
+ if record.get("age_at_index", None) is not None:
292
376
  cehrgpt_record["age_at_index"] = record["age_at_index"]
293
377
 
378
+ assert len(cehrgpt_record["epoch_times"]) == len(
379
+ cehrgpt_record["concept_ids"]
380
+ ), "The number of time stamps must match with the number of concepts in the sequence"
381
+
294
382
  return cehrgpt_record
295
383
 
296
384
 
297
- class HFCehrGptTokenizationMapping(DatasetMapping):
385
+ class HFCehrGptTokenizationMapping(DatasetMappingDecorator):
298
386
  def __init__(
299
387
  self,
300
388
  concept_tokenizer: CehrGptTokenizer,
@@ -308,9 +396,46 @@ class HFCehrGptTokenizationMapping(DatasetMapping):
308
396
  "is_numeric_types",
309
397
  ]
310
398
 
399
+ def filter_out_invalid_tokens(self, record: Dict[str, Any]) -> Dict[str, Any]:
400
+ column_names = []
401
+ seq_length = len(record["concept_ids"])
402
+
403
+ # We can't have "0" as a token in the tokenizer because it would break tokenization for "Race/0", "Visit/0"
404
+ # This is a pre-caution
405
+ if "0" in record["concept_ids"]:
406
+ if isinstance(record["concept_ids"], np.ndarray):
407
+ record["concept_ids"][record["concept_ids"] == "0"] = "Unknown"
408
+ else:
409
+ record["concept_ids"] = [
410
+ "Unknown" if x == "0" else x for x in record["concept_ids"]
411
+ ]
412
+
413
+ for k, v in record.items():
414
+ if k not in CEHRGPT_COLUMNS:
415
+ continue
416
+ if isinstance(v, (list, np.ndarray)) and len(v) == seq_length:
417
+ column_names.append(k)
418
+ valid_concept_ids = self._concept_tokenizer.get_vocab().keys()
419
+ valid_indices = [
420
+ idx
421
+ for idx, concept_id in enumerate(record["concept_ids"])
422
+ if concept_id in valid_concept_ids
423
+ ]
424
+ if len(valid_indices) != len(record["concept_ids"]):
425
+ for column in column_names:
426
+ values = record[column]
427
+ record[column] = [values[idx] for idx in valid_indices]
428
+ return record
429
+
311
430
  def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
431
+ # Remove the tokens from patient sequences that do not exist in the tokenizer
432
+ record = self.filter_out_invalid_tokens(record)
312
433
  # If any concept has a value associated with it, we normalize the value
313
434
  record["input_ids"] = self._concept_tokenizer.encode(record["concept_ids"])
435
+ assert len(record["input_ids"]) == len(record["concept_ids"]), (
436
+ "The number of tokens must equal to the number of concepts\n"
437
+ f"decoded concept_ids: {self._concept_tokenizer.decode(record['input_ids'], skip_special_tokens=False)}"
438
+ )
314
439
  record["value_indicators"] = record["concept_value_masks"]
315
440
  if "number_as_values" not in record or "concept_as_values" not in record:
316
441
  record["number_as_values"] = [
@@ -391,3 +516,93 @@ class HFFineTuningMapping(HFCehrGptTokenizationMapping):
391
516
  columns = super().remove_columns()
392
517
  columns.append("label")
393
518
  return columns
519
+
520
+
521
class ExtractTokenizedSequenceDataMapping:
    """
    Split a tokenized patient sequence into one sample per prediction time.

    For every ``{"index_date": ..., "label": ...}`` entry registered for the
    record's ``person_id``, a sample is produced that keeps only the positions
    whose ``epoch_times`` value lies in ``[window start, index time]`` (both
    ends inclusive). The window start is ``observation_window`` days before the
    index time, floored at 0; when ``observation_window`` is 0 it is always 0.
    """

    # Columns written by transform() itself; a same-named static field in the
    # incoming record is not copied, otherwise the column would receive two
    # values per sample and end up with a different length than the others.
    _DERIVED_KEYS = ("classifier_label", "index_date", "age_at_index")

    def __init__(
        self,
        person_index_date_map: Dict[int, List[Dict[str, Any]]],
        observation_window: int = 0,
    ):
        """
        :param person_index_date_map: person_id -> list of dicts with the keys
            "index_date" (a datetime) and "label"
        :param observation_window: look-back window in days, 0 disables the lower bound
        """
        self.person_index_date_map = person_index_date_map
        self.observation_window = observation_window

    def _calculate_prediction_start_time(self, prediction_time: float):
        """Return the POSIX time at which the observation window starts (never below 0)."""
        if self.observation_window and self.observation_window > 0:
            return max(prediction_time - self.observation_window * 24 * 3600, 0)
        return 0

    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build the samples of one patient record.

        :param record: one record holding at least "person_id", "epoch_times",
            "concept_ids" and "input_ids"
        :return: column name -> list with one entry per prediction time; per-token
            columns hold numpy arrays restricted to the observation window
        :raises KeyError: if the person has no entry in ``person_index_date_map``
        """
        person_id = record["person_id"]
        prediction_times = self.person_index_date_map[person_id]

        # (window start, index time, label), index dates are interpreted as UTC
        prediction_start_end_times = []
        for prediction_time_label_map in prediction_times:
            index_time = (
                prediction_time_label_map["index_date"]
                .replace(tzinfo=datetime.timezone.utc)
                .timestamp()
            )
            prediction_start_end_times.append(
                (
                    self._calculate_prediction_start_time(index_time),
                    index_time,
                    prediction_time_label_map["label"],
                )
            )

        # Boolean matrix of shape (number of samples, sequence length), computed
        # with broadcasting instead of a nested Python loop
        epoch_times = np.asarray(record["epoch_times"], dtype=float)
        window_starts = np.asarray(
            [start for start, _, _ in prediction_start_end_times], dtype=float
        )
        window_ends = np.asarray(
            [end for _, end, _ in prediction_start_end_times], dtype=float
        )
        observation_window_indices = (
            window_starts[:, None] <= epoch_times[None, :]
        ) & (epoch_times[None, :] <= window_ends[:, None])

        # A column is per-token when it is a list/ndarray as long as the sequence
        seq_length = len(record["epoch_times"])
        time_series_columns = ["concept_ids", "input_ids"]
        static_inputs = dict()
        for k, v in record.items():
            if k in ["concept_ids", "input_ids"]:
                continue
            if isinstance(v, (list, np.ndarray)) and len(v) == seq_length:
                time_series_columns.append(k)
            elif k not in self._DERIVED_KEYS:
                static_inputs[k] = v

        # The age is parsed from the second concept ("<prefix>:<age>") and is the
        # same for every sample of the record, so it is computed once; -1 marks
        # a sequence from which it cannot be parsed.
        try:
            start_age = int(record["concept_ids"][1].split(":")[1])
        except (KeyError, IndexError, ValueError, AttributeError, TypeError):
            start_age = -1

        batched_samples = defaultdict(list)
        if not prediction_start_end_times:
            return batched_samples

        # Convert each per-token column once rather than once per sample
        time_series = {
            column: np.asarray(record[column]) for column in time_series_columns
        }
        for (_, index_date, label), observation_window_index in zip(
            prediction_start_end_times, observation_window_indices
        ):
            for k, v in static_inputs.items():
                batched_samples[k].append(v)
            batched_samples["classifier_label"].append(label)
            batched_samples["index_date"].append(index_date)
            batched_samples["age_at_index"].append(start_age)
            for time_series_column, values in time_series.items():
                batched_samples[time_series_column].append(
                    values[observation_window_index]
                )
        return batched_samples

    def batch_transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply ``transform`` to every row of a columnar batch and concatenate the results.

        :param record: column name -> list of per-row values; must contain "concept_ids"
        :return: column name -> list with one entry per generated sample
        """
        all_batched_record = defaultdict(list)
        all_columns = record.keys()
        for i in range(len(record["concept_ids"])):
            one_record = {column: record[column][i] for column in all_columns}
            for k, v in self.transform(one_record).items():
                all_batched_record[k].extend(v)
        return all_batched_record
@@ -1,5 +1,6 @@
1
1
  from typing import Iterator, List, Optional
2
2
 
3
+ import numpy as np
3
4
  import torch
4
5
  import torch.distributed as dist
5
6
  from torch.utils.data import Sampler
@@ -33,6 +34,8 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
33
34
  rank: Optional[int] = None,
34
35
  seed: int = 0,
35
36
  drop_last: bool = False,
37
+ negative_sampling_probability: Optional[float] = None,
38
+ labels: Optional[List[int]] = None,
36
39
  ):
37
40
  """
38
41
  Args:
@@ -73,6 +76,11 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
73
76
  f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
74
77
  )
75
78
 
79
+ if negative_sampling_probability is not None and labels is None:
80
+ raise ValueError(
81
+ f"When the negative sampling probability is provide, the labels must be provided as well"
82
+ )
83
+
76
84
  self.lengths = lengths
77
85
  self.max_tokens_per_batch = max_tokens_per_batch
78
86
  self.max_position_embeddings = max_position_embeddings
@@ -80,6 +88,8 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
80
88
  self.rank = rank
81
89
  self.seed = seed
82
90
  self.drop_last = drop_last
91
+ self.negative_sampling_probability = negative_sampling_probability
92
+ self.labels = labels
83
93
  # Trainer https://github.com/huggingface/transformers/blame/main/src/transformers/trainer.py#L2470
84
94
  # http://github.com/huggingface/accelerate/blob/v0.31.0/src/accelerate/data_loader.py#L482
85
95
  # the huggingface trainer will call the accelerate.data_loader.DataLoaderShard.set_epoch,
@@ -100,6 +110,14 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
100
110
  current_batch_tokens = 0
101
111
 
102
112
  for idx in indices:
113
+ # There is a chance to skip the negative samples to account for the class imbalance
114
+ # in the fine-tuning dataset
115
+ if self.negative_sampling_probability:
116
+ if (
117
+ np.random.random() > self.negative_sampling_probability
118
+ and self.labels[idx] == 0
119
+ ):
120
+ continue
103
121
  # We take the minimum of the two because each sequence will be truncated to fit
104
122
  # the context window of the model
105
123
  sample_length = min(self.lengths[idx], self.max_position_embeddings)
@@ -131,10 +149,22 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
131
149
  if len(self.lengths) == 0:
132
150
  return 0
133
151
 
134
- # We need to truncate the lengths due to the context window limit imposed by the model
135
- truncated_lengths = [
136
- min(self.max_position_embeddings, length + 2) for length in self.lengths
137
- ]
152
+ # There is a chance to skip the negative samples to account for the class imbalance
153
+ # in the fine-tuning dataset
154
+ if self.negative_sampling_probability:
155
+ truncated_lengths = []
156
+ for length, label in zip(self.lengths, self.labels):
157
+ if (
158
+ np.random.random() > self.negative_sampling_probability
159
+ and label == 0
160
+ ):
161
+ continue
162
+ truncated_lengths.append(length)
163
+ else:
164
+ # We need to truncate the lengths due to the context window limit imposed by the model
165
+ truncated_lengths = [
166
+ min(self.max_position_embeddings, length + 2) for length in self.lengths
167
+ ]
138
168
 
139
169
  # Calculate average sequence length
140
170
  avg_seq_length = sum(truncated_lengths) // len(truncated_lengths)
@@ -145,7 +175,7 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
145
175
  # Estimate total number of batches
146
176
  if self.drop_last:
147
177
  # If dropping last incomplete batch
148
- return len(truncated_lengths) // seqs_per_batch * self.num_replicas
178
+ return len(truncated_lengths) // seqs_per_batch
149
179
  else:
150
180
  # If keeping last incomplete batch, ensure at least 1 batch
151
- return max(1, len(truncated_lengths) // seqs_per_batch) * self.num_replicas
181
+ return max(1, len(truncated_lengths) // seqs_per_batch)