wolof-translate 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,84 +1,79 @@
from math import ceil
from typing import Iterator, List, Optional, Union

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Sampler
6
 
9
class SequenceLengthBatchSampler(Sampler[List[int]]):
    """Yield batches of dataset indices grouped by sequence length.

    Each item's length is ``max(len(item[input_key]), len(item[label_key]))``.
    If either key is ``None``, keys ``0`` and ``2`` are used instead. Items are
    sorted by length and placed into the first bucket whose boundary they do
    not exceed (``length <= boundary``). Items longer than every boundary go
    into a final overflow bucket, so there are ``len(boundaries) + 1`` buckets.
    Bucket ``k`` is cut into consecutive batches of ``batch_sizes[k]`` items.
    Each batch is shuffled internally with ``np.random.shuffle``, but the order
    of batches is fixed: bucket by bucket, shortest items first.

    Buckets and batch sizes are paired with ``zip``. If ``batch_sizes`` has
    fewer than ``len(boundaries) + 1`` entries, the trailing buckets (including
    the overflow bucket) are never yielded.

    Args:
        dataset: Sized, indexable dataset. ``dataset[i]`` must support item
            access with the two keys and return objects supporting ``len``.
        boundaries: Inclusive upper length bound of each bucket.
        batch_sizes: Batch size for each bucket (must be non-zero).
        input_key: Key of the first length field (see fallback above).
        label_key: Key of the second length field (see fallback above).
        drop_unique: If True, a trailing partial batch of exactly one item is
            dropped instead of yielded.
    """

    def __init__(
        self,
        dataset,
        boundaries: List[int],
        batch_sizes: List[int],
        input_key: Optional[int] = None,
        label_key: Optional[int] = None,
        drop_unique: bool = True,
    ):
        self.dataset = dataset
        self.boundaries = boundaries
        self.batch_sizes = batch_sizes
        self.drop_unique = drop_unique
        # Maps dataset index -> {"index": i, "length": measured length}.
        self.data_info = {}

        # Fall back to positions 0 and 2 unless both keys were supplied.
        if input_key is None or label_key is None:
            first_key, second_key = 0, 2
        else:
            first_key, second_key = input_key, label_key

        for i in range(len(dataset)):
            data = dataset[i]
            length = max(len(data[first_key]), len(data[second_key]))
            self.data_info[i] = {"index": i, "length": length}

        self.calculate_length()

    def calculate_length(self):
        """Rebuild the length buckets (``self.batches``) and ``self.length``."""
        remaining = sorted(self.data_info, key=lambda i: self.data_info[i]["length"])

        self.batches = []
        for boundary in self.boundaries:
            # Two linear passes per boundary keep this O(n * len(boundaries)).
            # No lower bound is needed: shorter items were already consumed
            # by earlier buckets, and zero-length items land in the first one.
            fits = [i for i in remaining if self.data_info[i]["length"] <= boundary]
            remaining = [i for i in remaining if self.data_info[i]["length"] > boundary]
            self.batches.append(fits)

        # Overflow bucket: items longer than every boundary.
        self.batches.append(remaining)

        # Must match exactly what __iter__ yields.
        self.length = sum(
            len(bucket) // batch_size + int(self._keep_leftover(len(bucket) % batch_size))
            for bucket, batch_size in zip(self.batches, self.batch_sizes)
        )

    def _keep_leftover(self, leftover: int) -> bool:
        """Return True if a trailing partial batch of ``leftover`` items is yielded."""
        return leftover > 0 and (leftover != 1 or not self.drop_unique)

    def __iter__(self) -> Iterator[List[int]]:
        for bucket, batch_size in zip(self.batches, self.batch_sizes):
            full, leftover = divmod(len(bucket), batch_size)

            # List slices are copies, so shuffling never reorders self.batches.
            chunks = [bucket[k * batch_size:(k + 1) * batch_size] for k in range(full)]
            if self._keep_leftover(leftover):
                chunks.append(bucket[-leftover:])

            for chunk in chunks:
                np.random.shuffle(chunk)
                yield [self.data_info[idx]["index"] for idx in chunk]

    def __len__(self) -> int:
        return self.length
80
74
 
81
75
 
76
+
82
77
  class BucketSampler(Sampler):
83
78
  def __init__(self, dataset, batch_size, sort_key=lambda x, index_1, index_2: max(len(x[index_1]), len(x[index_2])), input_key: Union[str, int] = 0, label_key: Union[str, int] = 1):
84
79
  self.dataset = dataset
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: wolof-translate
3
- Version: 0.0.1
3
+ Version: 0.0.2
4
4
  Summary: Contain function and classes to process corpora for making translation between wolof text and other languages.
5
5
  Author: Oumar Kane
6
6
  Author-email: oumar.kane@univ-thies.sn
@@ -22,7 +22,7 @@ wolof_translate/trainers/transformer_trainer_custom.py,sha256=hHUBcU4YK6wuRUMiwX
22
22
  wolof_translate/trainers/transformer_trainer_ml.py,sha256=WgggaugkVHSJlwIAZT-QwI90Fl-_zT8Clhb-7M0m8gM,33561
23
23
  wolof_translate/trainers/transformer_trainer_ml_.py,sha256=QaN9DB5pqhBxV4WlFmJCmUyfwlX-UyAzKRwL6rVEr4Q,38199
24
24
  wolof_translate/utils/__init__.py,sha256=Nl3300H-Xd3uTHDR8y-rYa-UUR9FqbqZPwUKJUpQOb4,64
25
- wolof_translate/utils/bucket_iterator.py,sha256=Hglii1Hj6H_K51JunTjUAxuLd4ehPb6LeeMVhsmhNxQ,6248
25
+ wolof_translate/utils/bucket_iterator.py,sha256=VLqmgZ9Z7PoXFS76R81McV3jKlWQ_A4LYLAZx_a0xfU,5800
26
26
  wolof_translate/utils/database_manager.py,sha256=7yhgBN1LvVFNEQikxCjSCva82h5nX44Nx2zh8cpFWyA,3543
27
27
  wolof_translate/utils/display_predictions.py,sha256=y5H5lfgIODl6E5Zfb1YIwiAxIlHUxRBoChfQR5kjh24,5145
28
28
  wolof_translate/utils/download_model.py,sha256=x92KpfVPvNK8Suen1qnOcPtZOlB4kXTfqWgoVuuMUEM,1241
@@ -43,7 +43,7 @@ wolof_translate/utils/training.py,sha256=5vPVuqHL6_gqLkh4PTxXqW4UvAJBWNWVDDXC9Fk
43
43
  wolof_translate/utils/trunc_hg_training.py,sha256=mMGrU7Mjr9vYd7eLc8nbFRhRXwSWMKyg35lGf0L6RtQ,6418
44
44
  wolof_translate/utils/improvements/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
45
45
  wolof_translate/utils/improvements/end_marks.py,sha256=scmhMMYguZmrZTPozx1ZovizKrrPfPpMLXbU2-IOdGs,1194
46
- wolof_translate-0.0.1.dist-info/METADATA,sha256=itQMCA-zGM3gSDiKco5dMtY5qfpPHvKNJbw9KtffxzI,818
47
- wolof_translate-0.0.1.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
48
- wolof_translate-0.0.1.dist-info/top_level.txt,sha256=YG-kBnOwUZyQ7SofNvMxNYjzCreH2PVcW2UaEg1-Reg,16
49
- wolof_translate-0.0.1.dist-info/RECORD,,
46
+ wolof_translate-0.0.2.dist-info/METADATA,sha256=l8TYjNCIujICJklQFsyVy9XINJNNAU0hPkkNt0kbKFM,818
47
+ wolof_translate-0.0.2.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
48
+ wolof_translate-0.0.2.dist-info/top_level.txt,sha256=YG-kBnOwUZyQ7SofNvMxNYjzCreH2PVcW2UaEg1-Reg,16
49
+ wolof_translate-0.0.2.dist-info/RECORD,,