wolof-translate 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. The information is provided for informational purposes only.
wolof_translate/utils/bucket_iterator.py

@@ -1,84 +1,153 @@
-
- import torch
  import numpy as np
- from typing import *
+ from typing import Optional, List, Iterator, Union
  from torch.utils.data import Sampler
- from torch.nn.utils.rnn import pad_sequence
  from math import ceil
-
- class SequenceLengthBatchSampler(Sampler):
-     def __init__(self, dataset, boundaries, batch_sizes, input_key = None, label_key = None, drop_unique = True):
+ from tqdm import tqdm
+ import time
+
+ class SequenceLengthBatchSampler(Sampler[List[int]]):
+     def __init__(
+         self,
+         dataset,
+         boundaries: List[int],
+         batch_sizes: List[int],
+         input_key: Optional[int] = None,
+         label_key: Optional[int] = None,
+         drop_unique: bool = True,
+     ):
          self.dataset = dataset
          self.boundaries = boundaries
          self.batch_sizes = batch_sizes
-         self.data_info = {}
          self.drop_unique = drop_unique

-         # Initialize dictionary with indices and element lengths
-         for i in range(len(dataset)):
-             data = dataset[i]
-             length = (
-                 max(len(data[0]), len(data[2]))
-                 if (input_key is None and label_key is None)
-                 else max(len(data[input_key]), len(data[label_key]))
-             )
-             self.data_info[i] = {"index": i, "length": length}
-
-         self.calculate_length()
-
-     def calculate_length(self):
-         self.batches = []
+         start_time = time.time()
+         tqdm.write("Computing sequence lengths...")

-         # Sort indices based on element length
-         sorted_indices = sorted(self.data_info.keys(), key=lambda i: self.data_info[i]["length"])
-
-         # Group indices into batches of sequences with the same length
-         for boundary in self.boundaries:
-             batch = [i for i in sorted_indices if self.data_info[i]["length"] <= boundary]  # Filter indices based on length boundary
-             self.batches.append(batch)
-             sorted_indices = [i for i in sorted_indices if i not in batch]  # Remove processed indices
+         # Compute lengths with tqdm progress bar
+         self.lengths = np.array([
+             max(len(data[0]), len(data[2])) if input_key is None or label_key is None
+             else max(len(data[input_key]), len(data[label_key]))
+             for data in tqdm(dataset, desc="Lengths", unit="seq")
+         ])

-         # Add remaining indices to the last batch
-         self.batches.append(sorted_indices)
+         tqdm.write(f"Sequence lengths computed in {time.time() - start_time:.2f} seconds.")

-         # Calculate the total length of the data loader
-         self.length = sum(ceil(len(batch) / batch_size) for batch, batch_size in zip(self.batches, self.batch_sizes) if len(batch) % batch_size != 1 or not self.drop_unique)
+         start_time = time.time()
+         tqdm.write("Assigning buckets...")

-     def __iter__(self):
-         # indices = list(self.data_info.keys())  # Get indices from the data_info dictionary
-         # np.random.shuffle(indices)  # Shuffle the indices
+         # Assign bucket ids using digitize (vectorized)
+         self.bucket_ids = np.digitize(self.lengths, bins=self.boundaries, right=True)

-         # Yield batches with the corresponding batch sizes
-         for batch_indices, batch_size in zip(self.batches, self.batch_sizes):
-             num_batches = len(batch_indices) // batch_size
+         # Create buckets of indices
+         self.buckets = [np.where(self.bucket_ids == i)[0] for i in range(len(boundaries) + 1)]

-             for i in range(num_batches):
-                 # Recuperate the current bucket
-                 current_bucket = batch_indices[i * batch_size: (i + 1) * batch_size]
+         tqdm.write(f"Buckets assigned in {time.time() - start_time:.2f} seconds.")

-                 # Shuffle the current bucket
-                 np.random.shuffle(current_bucket)
+         start_time = time.time()
+         tqdm.write("Preparing batches...")

-                 # Yield the current bucket
-                 yield [self.data_info[i]["index"] for i in current_bucket]
+         # Prepare batches from buckets
+         self.batches = []
+         for bucket, batch_size in zip(self.buckets, self.batch_sizes):
+             bucket = bucket.copy()
+             np.random.shuffle(bucket)

-             remaining_indices = len(batch_indices) % batch_size
-
-             if remaining_indices > 0 and remaining_indices != 1 or not self.drop_unique:
-
-                 # Recuperate the current bucket
-                 current_bucket = batch_indices[-remaining_indices:]
+             n_full_batches = len(bucket) // batch_size
+             leftover = len(bucket) % batch_size

-                 # Shuffle the current bucket
-                 np.random.shuffle(current_bucket)
+             for i in range(n_full_batches):
+                 batch = bucket[i * batch_size : (i + 1) * batch_size].tolist()
+                 self.batches.append(batch)

-                 # Yield the current bucket
-                 yield [self.data_info[i]["index"] for i in batch_indices[-remaining_indices:]]
+             if leftover > 0 and (leftover != 1 or not self.drop_unique):
+                 batch = bucket[-leftover:].tolist()
+                 self.batches.append(batch)

-     def __len__(self):
+         self.length = len(self.batches)
+         tqdm.write(f"Batches prepared in {time.time() - start_time:.2f} seconds.")
+
+     def __iter__(self) -> Iterator[List[int]]:
+         # Shuffle all batches globally to add randomness between buckets
+         np.random.shuffle(self.batches)
+         for batch in self.batches:
+             yield batch
+
+     def __len__(self) -> int:
          return self.length


+
+
+ # class SequenceLengthBatchSampler(Sampler[List[int]]):
+ #     def __init__(
+ #         self,
+ #         dataset,
+ #         boundaries: List[int],
+ #         batch_sizes: List[int],
+ #         input_key: Optional[int] = None,
+ #         label_key: Optional[int] = None,
+ #         drop_unique: bool = True,
+ #     ):
+ #         self.dataset = dataset
+ #         self.boundaries = boundaries
+ #         self.batch_sizes = batch_sizes
+ #         self.drop_unique = drop_unique
+ #         self.data_info = {}
+
+ #         # Extract lengths
+ #         for i in range(len(dataset)):
+ #             data = dataset[i]
+ #             if input_key is None or label_key is None:
+ #                 length = max(len(data[0]), len(data[2]))
+ #             else:
+ #                 length = max(len(data[input_key]), len(data[label_key]))
+ #             self.data_info[i] = {"index": i, "length": length}
+
+ #         self.calculate_length()
+
+ #     def calculate_length(self):
+ #         self.batches = []
+ #         sorted_indices = sorted(self.data_info.keys(), key=lambda i: self.data_info[i]["length"])
+
+ #         prev_boundary = 0
+ #         for boundary in self.boundaries:
+ #             batch = [i for i in sorted_indices if prev_boundary < self.data_info[i]["length"] <= boundary]
+ #             self.batches.append(batch)
+ #             sorted_indices = [i for i in sorted_indices if i not in batch]
+ #             prev_boundary = boundary
+
+ #         # Remaining sequences > last boundary
+ #         self.batches.append(sorted_indices)
+
+ #         total_batches = 0
+ #         for batch, batch_size in zip(self.batches, self.batch_sizes):
+ #             n_full_batches = len(batch) // batch_size
+ #             leftover = len(batch) % batch_size
+ #             total_batches += n_full_batches
+ #             if leftover > 0 and (leftover != 1 or not self.drop_unique):
+ #                 total_batches += 1
+ #         self.length = total_batches
+
+ #     def __iter__(self) -> Iterator[List[int]]:
+ #         for batch_indices, batch_size in zip(self.batches, self.batch_sizes):
+ #             num_batches = len(batch_indices) // batch_size
+
+ #             for i in range(num_batches):
+ #                 current_bucket = batch_indices[i * batch_size: (i + 1) * batch_size]
+ #                 np.random.shuffle(current_bucket)
+ #                 yield [self.data_info[idx]["index"] for idx in current_bucket]
+
+ #             remaining = len(batch_indices) % batch_size
+ #             if remaining > 0 and (remaining != 1 or not self.drop_unique):
+ #                 current_bucket = batch_indices[-remaining:]
+ #                 np.random.shuffle(current_bucket)
+ #                 yield [self.data_info[idx]["index"] for idx in current_bucket]
+
+ #     def __len__(self) -> int:
+ #         return self.length
+
+
+
  class BucketSampler(Sampler):
      def __init__(self, dataset, batch_size, sort_key=lambda x, index_1, index_2: max(len(x[index_1]), len(x[index_2])), input_key: Union[str, int] = 0, label_key: Union[str, int] = 1):
          self.dataset = dataset
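The headline change in this hunk is `SequenceLengthBatchSampler`: the per-boundary Python list filtering of the old `calculate_length` is replaced by a single vectorized `np.digitize` call, and batches are now pre-built in the constructor and shuffled globally in `__iter__`. A minimal sketch of the bucket-assignment semantics, with made-up `boundaries` and `lengths` values for illustration:

    import numpy as np

    # Hypothetical values, for illustration only.
    boundaries = [10, 20, 30]            # right-inclusive bucket edges
    lengths = np.array([5, 12, 20, 31])  # sequence lengths

    # right=True means boundaries[i-1] < x <= boundaries[i], so a length equal
    # to a boundary lands in that boundary's bucket, matching the old
    # `length <= boundary` filtering; lengths above the last boundary get id
    # len(boundaries), i.e. the extra overflow bucket.
    bucket_ids = np.digitize(lengths, bins=boundaries, right=True)
    print(bucket_ids)  # [0 1 1 3]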
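Because the sampler yields lists of dataset indices, it should plug into a standard PyTorch `DataLoader` through the `batch_sampler` argument. A hedged usage sketch (the dataset and the boundary/batch-size values below are assumptions, not taken from the package):

    from torch.utils.data import DataLoader

    # `dataset` is assumed to be a map-style dataset whose items are tuples
    # of sequences, e.g. (input_ids, ..., labels).
    sampler = SequenceLengthBatchSampler(
        dataset,
        boundaries=[10, 20, 30],
        batch_sizes=[64, 32, 16, 8],  # one entry per bucket: len(boundaries) + 1
    )

    # With batch_sampler, DataLoader consumes each yielded index list as one
    # batch; batch_size, shuffle and sampler must be left at their defaults.
    loader = DataLoader(dataset, batch_sampler=sampler)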
wolof_translate-0.0.1.dist-info/METADATA → wolof_translate-0.0.3.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: wolof-translate
- Version: 0.0.1
+ Version: 0.0.3
  Summary: Contain function and classes to process corpora for making translation between wolof text and other languages.
  Author: Oumar Kane
  Author-email: oumar.kane@univ-thies.sn

wolof_translate-0.0.1.dist-info/RECORD → wolof_translate-0.0.3.dist-info/RECORD

@@ -22,7 +22,7 @@ wolof_translate/trainers/transformer_trainer_custom.py,sha256=hHUBcU4YK6wuRUMiwX
  wolof_translate/trainers/transformer_trainer_ml.py,sha256=WgggaugkVHSJlwIAZT-QwI90Fl-_zT8Clhb-7M0m8gM,33561
  wolof_translate/trainers/transformer_trainer_ml_.py,sha256=QaN9DB5pqhBxV4WlFmJCmUyfwlX-UyAzKRwL6rVEr4Q,38199
  wolof_translate/utils/__init__.py,sha256=Nl3300H-Xd3uTHDR8y-rYa-UUR9FqbqZPwUKJUpQOb4,64
- wolof_translate/utils/bucket_iterator.py,sha256=Hglii1Hj6H_K51JunTjUAxuLd4ehPb6LeeMVhsmhNxQ,6248
+ wolof_translate/utils/bucket_iterator.py,sha256=JctNzZI1yU1FjZUKY-zNiRMXS3JpRiXOOSKdObRdPbg,8511
  wolof_translate/utils/database_manager.py,sha256=7yhgBN1LvVFNEQikxCjSCva82h5nX44Nx2zh8cpFWyA,3543
  wolof_translate/utils/display_predictions.py,sha256=y5H5lfgIODl6E5Zfb1YIwiAxIlHUxRBoChfQR5kjh24,5145
  wolof_translate/utils/download_model.py,sha256=x92KpfVPvNK8Suen1qnOcPtZOlB4kXTfqWgoVuuMUEM,1241
@@ -43,7 +43,7 @@ wolof_translate/utils/training.py,sha256=5vPVuqHL6_gqLkh4PTxXqW4UvAJBWNWVDDXC9Fk
  wolof_translate/utils/trunc_hg_training.py,sha256=mMGrU7Mjr9vYd7eLc8nbFRhRXwSWMKyg35lGf0L6RtQ,6418
  wolof_translate/utils/improvements/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  wolof_translate/utils/improvements/end_marks.py,sha256=scmhMMYguZmrZTPozx1ZovizKrrPfPpMLXbU2-IOdGs,1194
- wolof_translate-0.0.1.dist-info/METADATA,sha256=itQMCA-zGM3gSDiKco5dMtY5qfpPHvKNJbw9KtffxzI,818
- wolof_translate-0.0.1.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
- wolof_translate-0.0.1.dist-info/top_level.txt,sha256=YG-kBnOwUZyQ7SofNvMxNYjzCreH2PVcW2UaEg1-Reg,16
- wolof_translate-0.0.1.dist-info/RECORD,,
+ wolof_translate-0.0.3.dist-info/METADATA,sha256=IztySbGGsGKbrtFqCcLpppA07bO8kt2_dxdd4hdMVOI,818
+ wolof_translate-0.0.3.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+ wolof_translate-0.0.3.dist-info/top_level.txt,sha256=YG-kBnOwUZyQ7SofNvMxNYjzCreH2PVcW2UaEg1-Reg,16
+ wolof_translate-0.0.3.dist-info/RECORD,,