wolof-translate 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,29 +5,43 @@ from math import ceil
5
5
  from tqdm import tqdm
6
6
  import time
7
7
 
8
+
8
9
  class SequenceLengthBatchSampler(Sampler[List[int]]):
9
10
  def __init__(
10
11
  self,
11
12
  dataset,
12
13
  boundaries: List[int],
13
14
  batch_sizes: List[int],
14
- input_key: Optional[int] = None,
15
- label_key: Optional[int] = None,
15
+ input_key: Optional[Union[int, str]] = None,
16
+ label_key: Optional[Union[int, str]] = None,
16
17
  drop_unique: bool = True,
17
18
  ):
19
+ """
20
+ Sampler that batches sequences of similar lengths together to minimize padding.
21
+
22
+ Args:
23
+ dataset: Dataset to sample from.
24
+ boundaries: List of length boundaries to bucket sequences.
25
+ batch_sizes: List of batch sizes per bucket (length must be len(boundaries)+1).
26
+ input_key: Key or index to access input sequence in dataset item.
27
+ label_key: Key or index to access label sequence in dataset item.
28
+ drop_unique: Whether to drop batches with a single leftover element.
29
+ """
18
30
  self.dataset = dataset
19
31
  self.boundaries = boundaries
20
32
  self.batch_sizes = batch_sizes
21
33
  self.drop_unique = drop_unique
22
34
 
35
+ assert len(batch_sizes) == len(boundaries) + 1, (
36
+ f"batch_sizes length ({len(batch_sizes)}) must be one more than boundaries length ({len(boundaries)})"
37
+ )
38
+
23
39
  start_time = time.time()
24
40
  tqdm.write("Computing sequence lengths...")
25
41
 
26
- # Compute lengths with tqdm progress bar
27
42
  self.lengths = np.array([
28
- max(len(data[0]), len(data[2])) if input_key is None or label_key is None
29
- else max(len(data[input_key]), len(data[label_key]))
30
- for data in tqdm(dataset, desc="Lengths", unit="seq")
43
+ self._get_length(data, input_key, label_key)
44
+ for data in tqdm(self.dataset, desc="Lengths", unit="seq")
31
45
  ])
32
46
 
33
47
  tqdm.write(f"Sequence lengths computed in {time.time() - start_time:.2f} seconds.")
@@ -35,7 +49,7 @@ class SequenceLengthBatchSampler(Sampler[List[int]]):
35
49
  start_time = time.time()
36
50
  tqdm.write("Assigning buckets...")
37
51
 
38
- # Assign bucket ids using digitize (vectorized)
52
+ # Assign bucket ids (0-based)
39
53
  self.bucket_ids = np.digitize(self.lengths, bins=self.boundaries, right=True)
40
54
 
41
55
  # Create buckets of indices
@@ -46,7 +60,6 @@ class SequenceLengthBatchSampler(Sampler[List[int]]):
46
60
  start_time = time.time()
47
61
  tqdm.write("Preparing batches...")
48
62
 
49
- # Prepare batches from buckets
50
63
  self.batches = []
51
64
  for bucket, batch_size in zip(self.buckets, self.batch_sizes):
52
65
  bucket = bucket.copy()
@@ -66,8 +79,26 @@ class SequenceLengthBatchSampler(Sampler[List[int]]):
66
79
  self.length = len(self.batches)
67
80
  tqdm.write(f"Batches prepared in {time.time() - start_time:.2f} seconds.")
68
81
 
82
+ def _get_length(self, data, input_key, label_key) -> int:
83
+ """
84
+ Helper to get the max length of input and label sequences in a dataset item.
85
+
86
+ Supports dict-like or tuple/list-like dataset items.
87
+ """
88
+ try:
89
+ if input_key is None or label_key is None:
90
+ # Assume tuple/list with input at 0, label at 2
91
+ input_seq = data[0]
92
+ label_seq = data[2]
93
+ else:
94
+ input_seq = data[input_key]
95
+ label_seq = data[label_key]
96
+ return max(len(input_seq), len(label_seq))
97
+ except Exception as e:
98
+ raise ValueError(f"Error accessing lengths with input_key={input_key}, label_key={label_key}: {e}")
99
+
69
100
  def __iter__(self) -> Iterator[List[int]]:
70
- # Shuffle all batches globally to add randomness between buckets
101
+ # Shuffle batches globally for randomness
71
102
  np.random.shuffle(self.batches)
72
103
  for batch in self.batches:
73
104
  yield batch
@@ -13,6 +13,8 @@ def add_guillemet_space(sentences: Union[list, str]):
13
13
 
14
14
  if type(sentences) is str:
15
15
  sentences = [sentences]
16
+ elif type(sentences) is float:
17
+ sentences = [str(sentences)]
16
18
 
17
19
  for s in range(len(sentences)):
18
20
 
@@ -51,6 +53,8 @@ def delete_guillemet_space(sentences: Union[list, str]):
51
53
 
52
54
  if type(sentences) is str:
53
55
  sentences = [sentences]
56
+ elif type(sentences) is float:
57
+ sentences = [str(sentences)]
54
58
 
55
59
  for s in range(len(sentences)):
56
60
 
@@ -91,6 +95,8 @@ def add_mark_space(
91
95
 
92
96
  if type(sentences) is str:
93
97
  sentences = [sentences]
98
+ elif type(sentences) is float:
99
+ sentences = [str(sentences)]
94
100
 
95
101
  for s in range(len(sentences)):
96
102
 
@@ -123,6 +129,8 @@ def remove_mark_space(sentences: Union[list, str], marks: list = ["'", "-"]):
123
129
 
124
130
  if type(sentences) is str:
125
131
  sentences = [sentences]
132
+ elif type(sentences) is float:
133
+ sentences = [str(sentences)]
126
134
 
127
135
  for s in range(len(sentences)):
128
136
 
@@ -161,6 +169,8 @@ def delete_much_space(sentences: Union[list, str]):
161
169
 
162
170
  if type(sentences) is str:
163
171
  sentences = [sentences]
172
+ elif type(sentences) is float:
173
+ sentences = [str(sentences)]
164
174
 
165
175
  for i in range(len(sentences)):
166
176
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: wolof-translate
3
- Version: 0.0.3
3
+ Version: 0.0.5
4
4
  Summary: Contain function and classes to process corpora for making translation between wolof text and other languages.
5
5
  Author: Oumar Kane
6
6
  Author-email: oumar.kane@univ-thies.sn
@@ -22,7 +22,7 @@ wolof_translate/trainers/transformer_trainer_custom.py,sha256=hHUBcU4YK6wuRUMiwX
22
22
  wolof_translate/trainers/transformer_trainer_ml.py,sha256=WgggaugkVHSJlwIAZT-QwI90Fl-_zT8Clhb-7M0m8gM,33561
23
23
  wolof_translate/trainers/transformer_trainer_ml_.py,sha256=QaN9DB5pqhBxV4WlFmJCmUyfwlX-UyAzKRwL6rVEr4Q,38199
24
24
  wolof_translate/utils/__init__.py,sha256=Nl3300H-Xd3uTHDR8y-rYa-UUR9FqbqZPwUKJUpQOb4,64
25
- wolof_translate/utils/bucket_iterator.py,sha256=JctNzZI1yU1FjZUKY-zNiRMXS3JpRiXOOSKdObRdPbg,8511
25
+ wolof_translate/utils/bucket_iterator.py,sha256=sGSBCGPn8NzZ32mfEKh0cfH1Z0WbNJWbDjghR8-u5tU,9847
26
26
  wolof_translate/utils/database_manager.py,sha256=7yhgBN1LvVFNEQikxCjSCva82h5nX44Nx2zh8cpFWyA,3543
27
27
  wolof_translate/utils/display_predictions.py,sha256=y5H5lfgIODl6E5Zfb1YIwiAxIlHUxRBoChfQR5kjh24,5145
28
28
  wolof_translate/utils/download_model.py,sha256=x92KpfVPvNK8Suen1qnOcPtZOlB4kXTfqWgoVuuMUEM,1241
@@ -34,7 +34,7 @@ wolof_translate/utils/extract_sentences.py,sha256=-PDBmceKUqiTdV9ieezSIITfADAnv_
34
34
  wolof_translate/utils/recuperate_datasets.py,sha256=4yTNXPOIfTokon0Bke50SdB8MT_Ojmu1aTmYv_K_w64,2644
35
35
  wolof_translate/utils/recuperate_datasets_trunc.py,sha256=82T7mHbxruYJUw0L0ZUUoPHxO2Yr65rApakmIhe034M,2500
36
36
  wolof_translate/utils/send_model.py,sha256=v_dQJDDpk3ak_DutbhwSqKF8-Q_-Gx9zezZsTot6Onk,797
37
- wolof_translate/utils/sent_corrections.py,sha256=5iqdS4j78ayag0GxnCEl_dBUs4zbBAWAOac2h0ECv4c,3534
37
+ wolof_translate/utils/sent_corrections.py,sha256=BC2AJ1Qq8LTNPIDNJitkaygo4g6a0c_tlikLV-QpVBc,3904
38
38
  wolof_translate/utils/sent_transformers.py,sha256=kbbc5H-zPkxSM1uOghGeZa9fCAcm2GwTSuiRHM0asgI,574
39
39
  wolof_translate/utils/sent_unification.py,sha256=UD9uZ--NREj5Z462n5hs-UjMPNhUN8Nr_6ZmR2w-B6Y,2104
40
40
  wolof_translate/utils/split_with_valid.py,sha256=7-e6EfvPbLpTYrZOXJVYYqm_nV7n6yUYOaWkn8hsJJw,2424
@@ -43,7 +43,7 @@ wolof_translate/utils/training.py,sha256=5vPVuqHL6_gqLkh4PTxXqW4UvAJBWNWVDDXC9Fk
43
43
  wolof_translate/utils/trunc_hg_training.py,sha256=mMGrU7Mjr9vYd7eLc8nbFRhRXwSWMKyg35lGf0L6RtQ,6418
44
44
  wolof_translate/utils/improvements/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
45
45
  wolof_translate/utils/improvements/end_marks.py,sha256=scmhMMYguZmrZTPozx1ZovizKrrPfPpMLXbU2-IOdGs,1194
46
- wolof_translate-0.0.3.dist-info/METADATA,sha256=IztySbGGsGKbrtFqCcLpppA07bO8kt2_dxdd4hdMVOI,818
47
- wolof_translate-0.0.3.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
48
- wolof_translate-0.0.3.dist-info/top_level.txt,sha256=YG-kBnOwUZyQ7SofNvMxNYjzCreH2PVcW2UaEg1-Reg,16
49
- wolof_translate-0.0.3.dist-info/RECORD,,
46
+ wolof_translate-0.0.5.dist-info/METADATA,sha256=kQLynATL140xAtQaN8Q9nziqKEx9TjT7mx_oD_9XrN8,818
47
+ wolof_translate-0.0.5.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
48
+ wolof_translate-0.0.5.dist-info/top_level.txt,sha256=YG-kBnOwUZyQ7SofNvMxNYjzCreH2PVcW2UaEg1-Reg,16
49
+ wolof_translate-0.0.5.dist-info/RECORD,,