wolof-translate 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
wolof_translate/utils/bucket_iterator.py

@@ -1,8 +1,10 @@
- import torch
  import numpy as np
- from typing import Optional, List, Iterator
+ from typing import Optional, List, Iterator, Union
  from torch.utils.data import Sampler
  from math import ceil
+ from tqdm import tqdm
+ import time
+

  class SequenceLengthBatchSampler(Sampler[List[int]]):
      def __init__(
@@ -10,70 +12,173 @@ class SequenceLengthBatchSampler(Sampler[List[int]]):
          dataset,
          boundaries: List[int],
          batch_sizes: List[int],
-         input_key: Optional[int] = None,
-         label_key: Optional[int] = None,
+         input_key: Optional[Union[int, str]] = None,
+         label_key: Optional[Union[int, str]] = None,
          drop_unique: bool = True,
      ):
+         """
+         Sampler that batches sequences of similar lengths together to minimize padding.
+
+         Args:
+             dataset: Dataset to sample from.
+             boundaries: List of length boundaries to bucket sequences.
+             batch_sizes: List of batch sizes per bucket (length must be len(boundaries)+1).
+             input_key: Key or index to access input sequence in dataset item.
+             label_key: Key or index to access label sequence in dataset item.
+             drop_unique: Whether to drop batches with a single leftover element.
+         """
          self.dataset = dataset
          self.boundaries = boundaries
          self.batch_sizes = batch_sizes
          self.drop_unique = drop_unique
-         self.data_info = {}

-         # Extract lengths
-         for i in range(len(dataset)):
-             data = dataset[i]
-             if input_key is None or label_key is None:
-                 length = max(len(data[0]), len(data[2]))
-             else:
-                 length = max(len(data[input_key]), len(data[label_key]))
-             self.data_info[i] = {"index": i, "length": length}
+         assert len(batch_sizes) == len(boundaries) + 1, (
+             f"batch_sizes length ({len(batch_sizes)}) must be one more than boundaries length ({len(boundaries)})"
+         )
+
+         start_time = time.time()
+         tqdm.write("Computing sequence lengths...")
+
+         self.lengths = np.array([
+             self._get_length(data, input_key, label_key)
+             for data in tqdm(self.dataset, desc="Lengths", unit="seq")
+         ])
+
+         tqdm.write(f"Sequence lengths computed in {time.time() - start_time:.2f} seconds.")

-         self.calculate_length()
+         start_time = time.time()
+         tqdm.write("Assigning buckets...")
+
+         # Assign bucket ids (0-based)
+         self.bucket_ids = np.digitize(self.lengths, bins=self.boundaries, right=True)
+
+         # Create buckets of indices
+         self.buckets = [np.where(self.bucket_ids == i)[0] for i in range(len(boundaries) + 1)]
+
+         tqdm.write(f"Buckets assigned in {time.time() - start_time:.2f} seconds.")
+
+         start_time = time.time()
+         tqdm.write("Preparing batches...")

-     def calculate_length(self):
          self.batches = []
-         sorted_indices = sorted(self.data_info.keys(), key=lambda i: self.data_info[i]["length"])
-
-         prev_boundary = 0
-         for boundary in self.boundaries:
-             batch = [i for i in sorted_indices if prev_boundary < self.data_info[i]["length"] <= boundary]
-             self.batches.append(batch)
-             sorted_indices = [i for i in sorted_indices if i not in batch]
-             prev_boundary = boundary
-
-         # Remaining sequences > last boundary
-         self.batches.append(sorted_indices)
-
-         total_batches = 0
-         for batch, batch_size in zip(self.batches, self.batch_sizes):
-             n_full_batches = len(batch) // batch_size
-             leftover = len(batch) % batch_size
-             total_batches += n_full_batches
+         for bucket, batch_size in zip(self.buckets, self.batch_sizes):
+             bucket = bucket.copy()
+             np.random.shuffle(bucket)
+
+             n_full_batches = len(bucket) // batch_size
+             leftover = len(bucket) % batch_size
+
+             for i in range(n_full_batches):
+                 batch = bucket[i * batch_size : (i + 1) * batch_size].tolist()
+                 self.batches.append(batch)
+
              if leftover > 0 and (leftover != 1 or not self.drop_unique):
-                 total_batches += 1
-         self.length = total_batches
+                 batch = bucket[-leftover:].tolist()
+                 self.batches.append(batch)

-     def __iter__(self) -> Iterator[List[int]]:
-         for batch_indices, batch_size in zip(self.batches, self.batch_sizes):
-             num_batches = len(batch_indices) // batch_size
+         self.length = len(self.batches)
+         tqdm.write(f"Batches prepared in {time.time() - start_time:.2f} seconds.")

-             for i in range(num_batches):
-                 current_bucket = batch_indices[i * batch_size: (i + 1) * batch_size]
-                 np.random.shuffle(current_bucket)
-                 yield [self.data_info[idx]["index"] for idx in current_bucket]
+     def _get_length(self, data, input_key, label_key) -> int:
+         """
+         Helper to get the max length of input and label sequences in a dataset item.

-             remaining = len(batch_indices) % batch_size
-             if remaining > 0 and (remaining != 1 or not self.drop_unique):
-                 current_bucket = batch_indices[-remaining:]
-                 np.random.shuffle(current_bucket)
-                 yield [self.data_info[idx]["index"] for idx in current_bucket]
+         Supports dict-like or tuple/list-like dataset items.
+         """
+         try:
+             if input_key is None or label_key is None:
+                 # Assume tuple/list with input at 0, label at 2
+                 input_seq = data[0]
+                 label_seq = data[2]
+             else:
+                 input_seq = data[input_key]
+                 label_seq = data[label_key]
+             return max(len(input_seq), len(label_seq))
+         except Exception as e:
+             raise ValueError(f"Error accessing lengths with input_key={input_key}, label_key={label_key}: {e}")
+
+     def __iter__(self) -> Iterator[List[int]]:
+         # Shuffle batches globally for randomness
+         np.random.shuffle(self.batches)
+         for batch in self.batches:
+             yield batch

      def __len__(self) -> int:
          return self.length



+
+ # class SequenceLengthBatchSampler(Sampler[List[int]]):
+ #     def __init__(
+ #         self,
+ #         dataset,
+ #         boundaries: List[int],
+ #         batch_sizes: List[int],
+ #         input_key: Optional[int] = None,
+ #         label_key: Optional[int] = None,
+ #         drop_unique: bool = True,
+ #     ):
+ #         self.dataset = dataset
+ #         self.boundaries = boundaries
+ #         self.batch_sizes = batch_sizes
+ #         self.drop_unique = drop_unique
+ #         self.data_info = {}
+
+ #         # Extract lengths
+ #         for i in range(len(dataset)):
+ #             data = dataset[i]
+ #             if input_key is None or label_key is None:
+ #                 length = max(len(data[0]), len(data[2]))
+ #             else:
+ #                 length = max(len(data[input_key]), len(data[label_key]))
+ #             self.data_info[i] = {"index": i, "length": length}
+
+ #         self.calculate_length()
+
+ #     def calculate_length(self):
+ #         self.batches = []
+ #         sorted_indices = sorted(self.data_info.keys(), key=lambda i: self.data_info[i]["length"])
+
+ #         prev_boundary = 0
+ #         for boundary in self.boundaries:
+ #             batch = [i for i in sorted_indices if prev_boundary < self.data_info[i]["length"] <= boundary]
+ #             self.batches.append(batch)
+ #             sorted_indices = [i for i in sorted_indices if i not in batch]
+ #             prev_boundary = boundary
+
+ #         # Remaining sequences > last boundary
+ #         self.batches.append(sorted_indices)
+
+ #         total_batches = 0
+ #         for batch, batch_size in zip(self.batches, self.batch_sizes):
+ #             n_full_batches = len(batch) // batch_size
+ #             leftover = len(batch) % batch_size
+ #             total_batches += n_full_batches
+ #             if leftover > 0 and (leftover != 1 or not self.drop_unique):
+ #                 total_batches += 1
+ #         self.length = total_batches
+
+ #     def __iter__(self) -> Iterator[List[int]]:
+ #         for batch_indices, batch_size in zip(self.batches, self.batch_sizes):
+ #             num_batches = len(batch_indices) // batch_size
+
+ #             for i in range(num_batches):
+ #                 current_bucket = batch_indices[i * batch_size: (i + 1) * batch_size]
+ #                 np.random.shuffle(current_bucket)
+ #                 yield [self.data_info[idx]["index"] for idx in current_bucket]
+
+ #             remaining = len(batch_indices) % batch_size
+ #             if remaining > 0 and (remaining != 1 or not self.drop_unique):
+ #                 current_bucket = batch_indices[-remaining:]
+ #                 np.random.shuffle(current_bucket)
+ #                 yield [self.data_info[idx]["index"] for idx in current_bucket]
+
+ #     def __len__(self) -> int:
+ #         return self.length
+
+
+
  class BucketSampler(Sampler):
      def __init__(self, dataset, batch_size, sort_key=lambda x, index_1, index_2: max(len(x[index_1]), len(x[index_2])), input_key: Union[str, int] = 0, label_key: Union[str, int] = 1):
          self.dataset = dataset
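
The rewritten sampler yields lists of dataset indices, one list per batch, so it is meant to be passed to torch.utils.data.DataLoader through the batch_sampler argument rather than batch_size. The following is a minimal usage sketch, not code from the package: the toy dataset (tuples with the input at index 0 and the label at index 2, matching the sampler's defaults), the boundary values, the batch sizes, and the identity collate function are all illustrative assumptions.

import numpy as np
from torch.utils.data import DataLoader
from wolof_translate.utils.bucket_iterator import SequenceLengthBatchSampler

# Hypothetical toy dataset: (input_seq, extra_field, label_seq) tuples,
# matching the sampler's default access pattern data[0] / data[2].
rng = np.random.default_rng(0)
dataset = [
    (list(range(rng.integers(1, 40))), None, list(range(rng.integers(1, 40))))
    for _ in range(1000)
]

# Lengths <= 10 fall in bucket 0, 11-20 in bucket 1, and anything longer in
# bucket 2 (np.digitize with right=True), so batch_sizes needs
# len(boundaries) + 1 entries.
sampler = SequenceLengthBatchSampler(
    dataset,
    boundaries=[10, 20],
    batch_sizes=[64, 32, 16],
)

# batch_sampler consumes the yielded index lists directly; batch_size must
# stay unset. The identity collate_fn sidesteps default collation of the
# ragged, None-containing toy items.
loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=lambda items: items)

for batch in loader:
    ...  # every item in `batch` comes from the same length bucket

Note the design change this version makes: batches are materialized once in __init__ (items shuffled within each bucket), and __iter__ only reshuffles their order, so batch membership stays fixed across epochs while batch order varies.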
wolof_translate-0.0.2.dist-info/METADATA → wolof_translate-0.0.4.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: wolof-translate
- Version: 0.0.2
+ Version: 0.0.4
  Summary: Contain function and classes to process corpora for making translation between wolof text and other languages.
  Author: Oumar Kane
  Author-email: oumar.kane@univ-thies.sn
wolof_translate-0.0.2.dist-info/RECORD → wolof_translate-0.0.4.dist-info/RECORD

@@ -22,7 +22,7 @@ wolof_translate/trainers/transformer_trainer_custom.py,sha256=hHUBcU4YK6wuRUMiwX
  wolof_translate/trainers/transformer_trainer_ml.py,sha256=WgggaugkVHSJlwIAZT-QwI90Fl-_zT8Clhb-7M0m8gM,33561
  wolof_translate/trainers/transformer_trainer_ml_.py,sha256=QaN9DB5pqhBxV4WlFmJCmUyfwlX-UyAzKRwL6rVEr4Q,38199
  wolof_translate/utils/__init__.py,sha256=Nl3300H-Xd3uTHDR8y-rYa-UUR9FqbqZPwUKJUpQOb4,64
- wolof_translate/utils/bucket_iterator.py,sha256=VLqmgZ9Z7PoXFS76R81McV3jKlWQ_A4LYLAZx_a0xfU,5800
+ wolof_translate/utils/bucket_iterator.py,sha256=sGSBCGPn8NzZ32mfEKh0cfH1Z0WbNJWbDjghR8-u5tU,9847
  wolof_translate/utils/database_manager.py,sha256=7yhgBN1LvVFNEQikxCjSCva82h5nX44Nx2zh8cpFWyA,3543
  wolof_translate/utils/display_predictions.py,sha256=y5H5lfgIODl6E5Zfb1YIwiAxIlHUxRBoChfQR5kjh24,5145
  wolof_translate/utils/download_model.py,sha256=x92KpfVPvNK8Suen1qnOcPtZOlB4kXTfqWgoVuuMUEM,1241
@@ -43,7 +43,7 @@ wolof_translate/utils/training.py,sha256=5vPVuqHL6_gqLkh4PTxXqW4UvAJBWNWVDDXC9Fk
  wolof_translate/utils/trunc_hg_training.py,sha256=mMGrU7Mjr9vYd7eLc8nbFRhRXwSWMKyg35lGf0L6RtQ,6418
  wolof_translate/utils/improvements/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  wolof_translate/utils/improvements/end_marks.py,sha256=scmhMMYguZmrZTPozx1ZovizKrrPfPpMLXbU2-IOdGs,1194
- wolof_translate-0.0.2.dist-info/METADATA,sha256=l8TYjNCIujICJklQFsyVy9XINJNNAU0hPkkNt0kbKFM,818
- wolof_translate-0.0.2.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
- wolof_translate-0.0.2.dist-info/top_level.txt,sha256=YG-kBnOwUZyQ7SofNvMxNYjzCreH2PVcW2UaEg1-Reg,16
- wolof_translate-0.0.2.dist-info/RECORD,,
+ wolof_translate-0.0.4.dist-info/METADATA,sha256=Adyexcw1wpc80mDnoeLxFuaWe4Pg0ZAZroqx6eHnSz4,818
+ wolof_translate-0.0.4.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+ wolof_translate-0.0.4.dist-info/top_level.txt,sha256=YG-kBnOwUZyQ7SofNvMxNYjzCreH2PVcW2UaEg1-Reg,16
+ wolof_translate-0.0.4.dist-info/RECORD,,