returnn 1.20251006.114241__py3-none-any.whl → 1.20251007.223754__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of returnn might be problematic. Click here for more details.
- returnn/PKG-INFO +1 -1
- returnn/_setup_info_generated.py +2 -2
- returnn/datasets/basic.py +29 -13
- returnn/datasets/distrib_files.py +7 -1
- returnn/datasets/huggingface.py +434 -0
- {returnn-1.20251006.114241.dist-info → returnn-1.20251007.223754.dist-info}/METADATA +1 -1
- {returnn-1.20251006.114241.dist-info → returnn-1.20251007.223754.dist-info}/RECORD +10 -9
- {returnn-1.20251006.114241.dist-info → returnn-1.20251007.223754.dist-info}/LICENSE +0 -0
- {returnn-1.20251006.114241.dist-info → returnn-1.20251007.223754.dist-info}/WHEEL +0 -0
- {returnn-1.20251006.114241.dist-info → returnn-1.20251007.223754.dist-info}/top_level.txt +0 -0
returnn/PKG-INFO
CHANGED
returnn/_setup_info_generated.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
version = '1.
|
|
2
|
-
long_version = '1.
|
|
1
|
+
version = '1.20251007.223754'
|
|
2
|
+
long_version = '1.20251007.223754+git.eb1103a'
|
returnn/datasets/basic.py
CHANGED
|
@@ -19,6 +19,7 @@ import os
|
|
|
19
19
|
import math
|
|
20
20
|
import numpy
|
|
21
21
|
import functools
|
|
22
|
+
import types
|
|
22
23
|
from typing import TYPE_CHECKING, Optional, Any, Set, Tuple, Union, Type, Dict, Sequence, List, Callable
|
|
23
24
|
|
|
24
25
|
from returnn.log import log
|
|
@@ -154,7 +155,7 @@ class Dataset:
|
|
|
154
155
|
self.seq_tags_filter = set(self._load_seq_list_file(seq_list_filter_file)) if seq_list_filter_file else None
|
|
155
156
|
self.unique_seq_tags = unique_seq_tags
|
|
156
157
|
self._seq_order_seq_lens_file = seq_order_seq_lens_file
|
|
157
|
-
self._seq_order_seq_lens_by_idx = None
|
|
158
|
+
self._seq_order_seq_lens_by_idx: Optional[Sequence[Union[int, float]]] = None
|
|
158
159
|
# There is probably no use case for combining the two, so avoid potential misconfiguration.
|
|
159
160
|
assert self.partition_epoch == 1 or self.repeat_epoch == 1, (
|
|
160
161
|
"Combining partition_epoch and repeat_epoch is prohibited."
|
|
@@ -486,12 +487,8 @@ class Dataset:
|
|
|
486
487
|
"""
|
|
487
488
|
raise NotImplementedError
|
|
488
489
|
|
|
489
|
-
def _get_seq_order_seq_lens_by_idx(self, seq_idx):
|
|
490
|
-
|
|
491
|
-
:param int seq_idx:
|
|
492
|
-
:rtype: int
|
|
493
|
-
"""
|
|
494
|
-
if not self._seq_order_seq_lens_by_idx:
|
|
490
|
+
def _get_seq_order_seq_lens_by_idx(self, seq_idx: int) -> Union[int, float]:
|
|
491
|
+
if self._seq_order_seq_lens_by_idx is None:
|
|
495
492
|
assert self._seq_order_seq_lens_file
|
|
496
493
|
if self._seq_order_seq_lens_file.endswith(".gz"):
|
|
497
494
|
import gzip
|
|
@@ -502,11 +499,12 @@ class Dataset:
|
|
|
502
499
|
seq_lens = eval(raw)
|
|
503
500
|
assert isinstance(seq_lens, dict)
|
|
504
501
|
all_tags = self.get_all_tags()
|
|
505
|
-
self._seq_order_seq_lens_by_idx = [seq_lens[tag] for tag in all_tags]
|
|
502
|
+
self._seq_order_seq_lens_by_idx = numpy.array([seq_lens[tag] for tag in all_tags])
|
|
503
|
+
self._get_seq_order_seq_lens_by_idx = self._seq_order_seq_lens_by_idx.__getitem__ # faster
|
|
506
504
|
return self._seq_order_seq_lens_by_idx[seq_idx]
|
|
507
505
|
|
|
508
506
|
def get_seq_order_for_epoch(
|
|
509
|
-
self, epoch: Optional[int], num_seqs: int, get_seq_len: Optional[Callable[[int], int]] = None
|
|
507
|
+
self, epoch: Optional[int], num_seqs: int, get_seq_len: Optional[Callable[[int], Union[int, float]]] = None
|
|
510
508
|
) -> Sequence[int]:
|
|
511
509
|
"""
|
|
512
510
|
Returns the order of the given epoch.
|
|
@@ -515,7 +513,7 @@ class Dataset:
|
|
|
515
513
|
|
|
516
514
|
:param epoch: for 'random', this determines the random seed
|
|
517
515
|
:param num_seqs:
|
|
518
|
-
:param get_seq_len: function (originalSeqIdx: int) -> int
|
|
516
|
+
:param get_seq_len: function (originalSeqIdx: int) -> int|float
|
|
519
517
|
:return: the order for the given epoch. such that seq_idx -> underlying idx
|
|
520
518
|
"""
|
|
521
519
|
if epoch is None:
|
|
@@ -561,8 +559,9 @@ class Dataset:
|
|
|
561
559
|
seq_index = range(num_seqs - 1, -1, -1) # type: Union[range, Sequence[int]]
|
|
562
560
|
elif seq_ordering_method in ["sorted", "sorted_reverse"]:
|
|
563
561
|
assert get_seq_len
|
|
564
|
-
|
|
565
|
-
|
|
562
|
+
seq_lens = _get_seq_len_as_array(get_seq_len, num_seqs)
|
|
563
|
+
if seq_ordering_method == "sorted_reverse":
|
|
564
|
+
seq_lens = -seq_lens
|
|
566
565
|
seq_index = numpy.argsort(seq_lens, kind="stable")
|
|
567
566
|
elif seq_ordering_method == "random" or seq_ordering_method.startswith("random:"):
|
|
568
567
|
tmp = seq_ordering_method.split(":", 1)
|
|
@@ -628,7 +627,7 @@ class Dataset:
|
|
|
628
627
|
nth = 1
|
|
629
628
|
else:
|
|
630
629
|
nth = int(tmp[1])
|
|
631
|
-
seq_lens =
|
|
630
|
+
seq_lens = _get_seq_len_as_array(get_seq_len, num_seqs)
|
|
632
631
|
rnd_seed = self._get_random_seed_for_epoch(epoch=epoch, num_epochs_fixed=nth)
|
|
633
632
|
random_generator = numpy.random.RandomState(rnd_seed)
|
|
634
633
|
seq_index = random_generator.permutation(num_seqs) # type: Union[numpy.ndarray, List[int]]
|
|
@@ -1501,6 +1500,7 @@ def get_dataset_class(name: Union[str, Type[Dataset]]) -> Optional[Type[Dataset]
|
|
|
1501
1500
|
"distrib_files",
|
|
1502
1501
|
"postprocessing",
|
|
1503
1502
|
"text_dict",
|
|
1503
|
+
"huggingface",
|
|
1504
1504
|
]
|
|
1505
1505
|
for mod_name in mod_names:
|
|
1506
1506
|
mod = import_module("returnn.datasets.%s" % mod_name)
|
|
@@ -1757,3 +1757,19 @@ def set_config_extern_data_from_dataset(config, dataset):
|
|
|
1757
1757
|
"extern_data",
|
|
1758
1758
|
{key: _data_kwargs_from_dataset_key(dataset=dataset, key=key) for key in dataset.get_data_keys()},
|
|
1759
1759
|
)
|
|
1760
|
+
|
|
1761
|
+
|
|
1762
|
+
def _get_seq_len_as_array(get_seq_len: Callable[[int], Union[int, float]], num_seqs: int) -> numpy.ndarray:
|
|
1763
|
+
if num_seqs == 0:
|
|
1764
|
+
return numpy.zeros((0,), dtype=numpy.int32)
|
|
1765
|
+
if isinstance(get_seq_len, (types.BuiltinMethodType, types.MethodWrapperType, types.MethodType)):
|
|
1766
|
+
# Call it once. This might trigger some caching.
|
|
1767
|
+
get_seq_len(0)
|
|
1768
|
+
# Get it again. This might now get us a different (cached) function, e.g. array.__getitem__.
|
|
1769
|
+
get_seq_len = getattr(get_seq_len.__self__, get_seq_len.__name__)
|
|
1770
|
+
assert isinstance(get_seq_len, (types.BuiltinMethodType, types.MethodWrapperType, types.MethodType))
|
|
1771
|
+
obj = get_seq_len.__self__
|
|
1772
|
+
if isinstance(obj, numpy.ndarray) and get_seq_len.__name__ == "__getitem__":
|
|
1773
|
+
assert obj.shape == (num_seqs,)
|
|
1774
|
+
return obj
|
|
1775
|
+
return numpy.array([get_seq_len(i) for i in range(num_seqs)])
|
|
@@ -135,7 +135,7 @@ class DistributeFilesDataset(CachedDataset2):
|
|
|
135
135
|
def __init__(
|
|
136
136
|
self,
|
|
137
137
|
*,
|
|
138
|
-
files: Union[List[FileTree], os.PathLike],
|
|
138
|
+
files: Union[List[FileTree], os.PathLike, Callable[[], List[FileTree]]],
|
|
139
139
|
get_sub_epoch_dataset: Callable[[List[FileTree]], Dict[str, Any]],
|
|
140
140
|
preload_next_n_sub_epochs: int = 1,
|
|
141
141
|
buffer_size: int = 1,
|
|
@@ -151,6 +151,7 @@ class DistributeFilesDataset(CachedDataset2):
|
|
|
151
151
|
can also be specified as a path to a .txt file containing one file per line,
|
|
152
152
|
or a python file containing the repr of a list of arbitrarily nested python objects,
|
|
153
153
|
or a JSON file containing a list of arbitarily nested (JSON) objects.
|
|
154
|
+
It can also be a callable which returns such a list.
|
|
154
155
|
:param get_sub_epoch_dataset: callable which returns a dataset dict for a given subset of files
|
|
155
156
|
:param preload_next_n_sub_epochs: how many sub epoch datasets to preload
|
|
156
157
|
:param buffer_size: buffer size for each worker, number of seqs to prefetch
|
|
@@ -244,6 +245,11 @@ class DistributeFilesDataset(CachedDataset2):
|
|
|
244
245
|
return
|
|
245
246
|
if isinstance(self.files, list):
|
|
246
247
|
self._files = self.files
|
|
248
|
+
elif callable(self.files):
|
|
249
|
+
self._files = self.files()
|
|
250
|
+
assert isinstance(self._files, list), (
|
|
251
|
+
f"{self}: callable files {self.files} must return a list, got {type(self._files)}"
|
|
252
|
+
)
|
|
247
253
|
elif isinstance(self.files, (str, os.PathLike)):
|
|
248
254
|
_, ext = os.path.splitext(self.files)
|
|
249
255
|
assert ext, f"{self}: no file extension on file list file {self.files}"
|
|
@@ -0,0 +1,434 @@
|
|
|
1
|
+
"""
|
|
2
|
+
HuggingFace dataset wrapper
|
|
3
|
+
|
|
4
|
+
See https://github.com/rwth-i6/returnn/issues/1257 for some initial discussion.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
from typing import TYPE_CHECKING, Optional, Union, Any, Callable, Sequence, Dict, List
|
|
9
|
+
import os
|
|
10
|
+
import re
|
|
11
|
+
import numpy
|
|
12
|
+
from returnn.tensor import Tensor
|
|
13
|
+
from returnn.util import file_cache
|
|
14
|
+
from .basic import DatasetSeq
|
|
15
|
+
from .cached2 import CachedDataset2
|
|
16
|
+
from .util.vocabulary import Vocabulary
|
|
17
|
+
from .util.strings import str_to_numpy_array
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
# noinspection PyUnresolvedReferences,PyPackageRequirements
|
|
21
|
+
import datasets
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class HuggingFaceDataset(CachedDataset2):
    """
    HuggingFace dataset wrapper.

    The underlying :class:`datasets.Dataset` is loaded lazily (see :func:`_lazy_init`),
    i.e. only when it is first needed (e.g. in :func:`init_seq_order`, :func:`get_all_tags`, :func:`get_tag`).
    """

    def __init__(
        self,
        dataset_opts: Union[
            Dict[str, Any],
            str,
            os.PathLike,
            Sequence[Union[str, os.PathLike]],
            Callable[[], Union[Dict[str, Any], str, os.PathLike, Sequence[Union[str, os.PathLike]], datasets.Dataset]],
        ],
        *,
        use_file_cache: bool = False,
        map_func: Optional[Callable[[datasets.Dataset], datasets.Dataset]] = None,
        rename_columns: Optional[Dict[str, str]] = None,
        cast_columns: Optional[Dict[str, Dict[str, Any]]] = None,
        data_format: Dict[str, Dict[str, Any]],
        seq_tag_column: Optional[str] = "id",
        sorting_seq_len_column_data: Optional[str] = None,
        sorting_seq_len_column: Optional[str] = None,
        **kwargs,
    ):
        """
        :param dataset_opts: either a dict of options for :func:`datasets.load_dataset`
            or a path (str or :class:`os.PathLike`) to a local dataset for :func:`datasets.load_from_disk`,
            or a list of Arrow filenames to load with :func:`datasets.Dataset.from_file` and concatenate.
            It can also be a callable returning one of the above,
            or returning a :class:`datasets.Dataset` directly.
        :param use_file_cache: if True, will cache the dataset files on local disk using :mod:`file_cache`.
            This only works for dataset_opts which is a str or list of str (or callable returning that).
        :param map_func: optional function to apply to the dataset after loading
        :param rename_columns: if given, will rename these columns
        :param cast_columns: if given, will cast these columns to the specified types.
            This is useful if the dataset has not the expected types.
            See :func:`datasets.Dataset.cast` for details.
            You can also e.g. enforce some sample_rate for audio, etc.
        :param data_format:
            For each column name (data key), specify the format,
            as a dict with entries for "dim", "ndim", "shape", and/or "dtype",
            compatible to :class:`Tensor`.
            It can be a subset of the available columns.
            If "vocab" is specified, and the underlying HF datasets column is of dtype "string",
            it will automatically tokenize the string using the vocab.
        :param seq_tag_column: key (column name) in the dataset to use as sequence tag.
            If None, will use the sequence index as tag.
        :param sorting_seq_len_column_data: key (column name) in the dataset to use for sorting by sequence length.
            It will take len(dataset[sorting_seq_len_column_data]) as sequence length (only for sorting/shuffling).
        :param sorting_seq_len_column: key (column name) in the dataset to use for sorting by sequence length.
            It will take the value of dataset[sorting_seq_len_column] as sequence length (only for sorting/shuffling).
            E.g. some datasets provide "duration", "duration_ms", "wav_filesize" or similar such information
            which can be used.
        """
        super().__init__(**kwargs)

        self.dataset_opts = dataset_opts
        self.use_file_cache = use_file_cache
        self.map_func = map_func
        self.rename_columns = rename_columns
        self.cast_columns = cast_columns

        self.data_format: Dict[str, Tensor] = {k: _make_tensor_template(v, k) for k, v in data_format.items()}
        self.seq_tag_column: Optional[str] = seq_tag_column
        self.sorting_seq_len_column_data = sorting_seq_len_column_data
        self.sorting_seq_len_column = sorting_seq_len_column

        self.labels = {k: data.vocab.labels for k, data in self.data_format.items() if data.vocab}
        self.num_outputs = {k: (data.dim, data.ndim) for k, data in self.data_format.items()}

        self.hf_dataset: Optional[datasets.Dataset] = None  # lazily loaded, _lazy_init
        self._seq_order: Optional[Sequence[int]] = None  # init_seq_order
        self._seq_tags: Optional[List[str]] = None  # get_all_tags cache

    def _lazy_init(self):
        """
        Load the HF dataset (once), apply map_func/rename_columns/cast_columns,
        select only the needed columns, set numpy format, and verify/complete the data formats.
        """
        if self.hf_dataset is not None:
            return

        # Load the dataset
        # noinspection PyUnresolvedReferences,PyPackageRequirements
        import datasets

        dataset_opts = self.dataset_opts
        if callable(dataset_opts):
            dataset_opts = dataset_opts()
        if self.use_file_cache:
            assert isinstance(dataset_opts, (str, os.PathLike, list, tuple)), (
                f"{self}: with use_file_cache, dataset_opts must be str or list of str, got {type(dataset_opts)}"
            )
            if isinstance(dataset_opts, (str, os.PathLike)):
                dataset_opts = get_arrow_shard_files_from_hf_dataset_dir(dataset_opts)
            assert isinstance(dataset_opts, (list, tuple))
            cache = file_cache.get_instance()
            dataset_opts = [cache.get_file(os.fspath(fn)) for fn in dataset_opts]
            self.set_file_cache(cache)
        if isinstance(dataset_opts, dict):
            self.hf_dataset = datasets.load_dataset(**dataset_opts)
        elif isinstance(dataset_opts, (str, os.PathLike)):
            # Also accept os.PathLike (e.g. pathlib.Path), as documented; normalize to str.
            self.hf_dataset = datasets.load_from_disk(os.fspath(dataset_opts))
        elif isinstance(dataset_opts, (list, tuple)):
            self.hf_dataset = datasets.concatenate_datasets(
                [datasets.Dataset.from_file(os.fspath(fn)) for fn in dataset_opts]
            )
        elif isinstance(dataset_opts, datasets.Dataset):
            self.hf_dataset = dataset_opts
        else:
            raise TypeError(f"{self}: invalid dataset_opts type {type(dataset_opts)}")
        assert isinstance(self.hf_dataset, datasets.Dataset), (
            f"{self}: Expected single dataset, got {type(self.hf_dataset)} {self.hf_dataset}. Specify split if needed."
        )

        if self.map_func is not None:
            self.hf_dataset = self.map_func(self.hf_dataset)

        if self.rename_columns:
            self.hf_dataset = self.hf_dataset.rename_columns(self.rename_columns)

        if self.cast_columns:
            # Note: prefer cast_column, as this can avoid using `map`, i.e. be faster.
            for key, column_format in self.cast_columns.items():
                assert key in self.hf_dataset.features, (
                    f"{self}: cast_column {key} not in dataset features {self.hf_dataset.features}"
                )
                feat = datasets.features.features.generate_from_dict(column_format)
                self.hf_dataset = self.hf_dataset.cast_column(key, feat)

        if self.seq_tag_column:
            assert self.seq_tag_column in self.hf_dataset.features, (
                f"{self}: seq_tag_column {self.seq_tag_column} not in dataset features {self.hf_dataset.features}"
            )
            assert self.hf_dataset.features[self.seq_tag_column].dtype in ("string", "int64"), (
                f"{self}: seq_tag_column {self.seq_tag_column} must be of dtype string or int64,"
                f" got {self.hf_dataset.features[self.seq_tag_column].dtype}"
            )

        # Only keep the columns we actually need.
        selected_columns = list(self.data_format.keys())
        if self.seq_tag_column and self.seq_tag_column not in selected_columns:
            selected_columns.append(self.seq_tag_column)
        if self.sorting_seq_len_column and self.sorting_seq_len_column not in selected_columns:
            selected_columns.append(self.sorting_seq_len_column)
        if self.sorting_seq_len_column_data and self.sorting_seq_len_column_data not in selected_columns:
            selected_columns.append(self.sorting_seq_len_column_data)
        self.hf_dataset = self.hf_dataset.select_columns(selected_columns)

        self.hf_dataset.set_format("numpy")

        for key, user_format in self.data_format.items():
            feature = self.hf_dataset.features[key]
            inferred_format = _infer_data_format_for_feature(feature, f"{self}: column {key}: ")
            if user_format.vocab and inferred_format["dtype"] == "string":
                pass  # allow to auto-tokenize strings when vocab is specified
            else:
                for key_ in ["dtype", "ndim", "dim"]:
                    assert getattr(user_format, key_) == inferred_format[key_], (
                        f"{self}: column {key}, user-specified {user_format}, {key_}:"
                        f" user-specified {getattr(user_format, key_)} does not match inferred {inferred_format[key_]}"
                    )
            if "vocab" in inferred_format and not user_format.vocab:
                assert user_format.sparse, f"{self}: column {key}: user_format expected to be sparse, got {user_format}"
                user_format.sparse_dim.vocab = Vocabulary.create_vocab(**inferred_format["vocab"])
                self.labels[key] = user_format.vocab.labels

    def get_data_keys(self) -> List[str]:
        """:return: list of data keys"""
        return list(self.data_format.keys())

    def get_target_list(self) -> List[str]:
        """:return: list of target keys"""
        return self.get_data_keys()  # it's somewhat arbitrary...

    def get_data_shape(self, key: str) -> List[int]:
        """:return: data shape for the given key"""
        return list(self.data_format[key].shape)

    def get_data_dim(self, key: str) -> int:
        """:return: data dimension for the given key"""
        return self.data_format[key].dim

    def is_data_sparse(self, key: str) -> bool:
        """:return: whether the data is sparse for the given key"""
        return self.data_format[key].sparse

    def get_data_dtype(self, key: str) -> str:
        """:return: dtype"""
        return self.data_format[key].dtype

    def _get_seq_len(self, seq_idx: int) -> Union[int, float]:
        """
        Seq len, only used for sorting/shuffling (passed to :func:`get_seq_order_for_epoch`).

        With ``sorting_seq_len_column``, all values of that column are loaded once into a numpy array,
        and this method is replaced on the instance by the array's ``__getitem__`` for speed.

        :param seq_idx: corpus seq idx
        :return: seq len (int or float, depending on the column)
        """
        if self._seq_order_seq_lens_by_idx is not None:
            self._get_seq_len = self._seq_order_seq_lens_by_idx.__getitem__  # faster
            return self._seq_order_seq_lens_by_idx[seq_idx]
        assert not self._seq_order_seq_lens_file  # not expected to call this
        if self.sorting_seq_len_column:
            self._seq_order_seq_lens_by_idx = numpy.array(self.hf_dataset[self.sorting_seq_len_column])
            self._get_seq_len = self._seq_order_seq_lens_by_idx.__getitem__  # faster
            # Return the same as the cached __getitem__ will return for all later calls,
            # i.e. do not truncate float lengths (e.g. "duration" in seconds) on the first call.
            return self._seq_order_seq_lens_by_idx[seq_idx]
        if self.sorting_seq_len_column_data:
            v = self.hf_dataset[seq_idx][self.sorting_seq_len_column_data]
            return len(v)  # noqa
        raise ValueError(
            f"{self}: sorting/shuffling by seq len not configured,"
            f" need sorting_seq_len_column or sorting_seq_len_column_data"
        )

    @property
    def num_seqs(self) -> int:
        """:return: number of sequences"""
        assert self._seq_order is not None, "num_seqs is only known after calling init_seq_order()"
        return len(self._seq_order)

    def get_tag(self, sorted_seq_idx: int) -> str:
        """:return: tag of the sequence"""
        corpus_seq_idx = self.get_corpus_seq_idx(sorted_seq_idx)
        self._lazy_init()
        dataset_item = self.hf_dataset[corpus_seq_idx]
        return self._get_seq_tag(corpus_seq_idx, dataset_item)

    def get_all_tags(self) -> List[str]:
        """:return: all tags (cached after the first call)"""
        if self._seq_tags is not None:
            return self._seq_tags
        self._lazy_init()
        if self.seq_tag_column:
            res = list(map(str, self.hf_dataset[self.seq_tag_column]))
        else:
            res = [f"seq-{i}" for i in range(self.hf_dataset.num_rows)]
        self._seq_tags = res
        return res

    def get_total_num_seqs(self, *, fast: bool = False) -> int:
        """:return: total number of sequences in the dataset"""
        if fast:
            return super().get_total_num_seqs(fast=True)
        self._lazy_init()
        return self.hf_dataset.num_rows

    def init_seq_order(
        self,
        epoch: Optional[int] = None,
        seq_list: Optional[Sequence[str]] = None,
        seq_order: Optional[Sequence[int]] = None,
    ) -> bool:
        """
        :param epoch:
        :param seq_list: List of sequence tags, to set a predefined order.
        :param seq_order: List of corpus sequence indices, to set a predefined order.
        :returns whether the order changed (True is always safe to return)
        :raises ValueError: if some tag in seq_list is not found in the dataset
        """
        super().init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)

        if seq_order is not None:
            self._seq_order = seq_order
        elif seq_list is not None:
            all_tags = self.get_all_tags()
            # Build the tag -> idx mapping once (instead of list.index per tag, which is O(n) each).
            # Keep the first occurrence for duplicate tags, same as list.index.
            tag_to_idx: Dict[str, int] = {}
            for idx, tag in enumerate(all_tags):
                tag_to_idx.setdefault(tag, idx)
            try:
                self._seq_order = [tag_to_idx[tag] for tag in seq_list]
            except KeyError as exc:
                raise ValueError(f"{self}: seq tag {exc.args[0]!r} not found in dataset") from exc
        elif epoch is None:
            self._seq_order = ()
        else:
            self._lazy_init()
            self._seq_order = self.get_seq_order_for_epoch(
                epoch=epoch, num_seqs=self.hf_dataset.num_rows, get_seq_len=self._get_seq_len
            )
        return True

    def _collect_single_seq(self, seq_idx: int) -> DatasetSeq:
        """
        :param seq_idx: sorted seq idx (index into the current seq order)
        :return: the seq with all keys of data_format, converted to numpy arrays
        """
        # noinspection PyUnresolvedReferences,PyPackageRequirements
        import datasets

        corpus_seq_idx = self.get_corpus_seq_idx(seq_idx)

        def _ensure_numpy(k, x):
            if isinstance(x, numpy.ndarray):  # fast path
                return x
            if isinstance(x, str):
                if self.data_format[k].dtype == "string":
                    return str_to_numpy_array(x)
                if self.data_format[k].vocab:
                    return numpy.array(self.data_format[k].vocab.get_seq(x), dtype=self.data_format[k].dtype)
                raise ValueError(f"{self}: column {k}: cannot convert string {x!r} to numpy array")
            feat = self.hf_dataset.features[k]
            if isinstance(feat, datasets.features.Audio):
                # In HF datasets 3, this is just a dict.
                # In HF datasets 4, this can also be a datasets.features._torchcodec.AudioDecoder.
                assert isinstance(x, dict) or x.__class__.__name__ == "AudioDecoder"
                if feat.decode:
                    x = x["array"]
                else:
                    x = x["bytes"]
            if isinstance(x, numpy.ndarray):  # fast path
                return x
            if isinstance(x, (bytes, bytearray)):
                return numpy.frombuffer(x, dtype=self.data_format[k].dtype)
            return numpy.array(x)

        self._lazy_init()
        dataset_item = self.hf_dataset[corpus_seq_idx]
        seq_tag = self._get_seq_tag(corpus_seq_idx, dataset_item)
        features = {k: _ensure_numpy(k, dataset_item[k]) for k in self.data_format}
        return DatasetSeq(seq_idx, features=features, seq_tag=seq_tag)

    def _get_seq_tag(self, corpus_seq_idx: int, dataset_item: Dict[str, Any]) -> str:
        """
        :param corpus_seq_idx: used for the tag if no seq_tag_column is configured
        :param dataset_item: the dataset row
        :return: seq tag as str
        """
        if self.seq_tag_column:
            seq_tag = dataset_item[self.seq_tag_column]
            assert isinstance(seq_tag, (str, int, numpy.int64)), f"got {type(seq_tag)} {seq_tag!r}"
            seq_tag = str(seq_tag)
        else:
            seq_tag = f"seq-{corpus_seq_idx}"
        return seq_tag

    def get_current_seq_order(self) -> Sequence[int]:
        """:return: list of corpus seq idx"""
        assert self._seq_order is not None
        return self._seq_order

    def get_corpus_seq_idx(self, sorted_seq_idx: int) -> int:
        """:return: corpus seq idx"""
        return int(self._seq_order[sorted_seq_idx])
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def get_arrow_shard_files_from_hf_dataset_dir(hf_data_dir: Union[str, os.PathLike]) -> List[str]:
    """
    Given some HF datasets directory (created via :func:`datasets.save_to_disk`),
    return the list of Arrow shard files (``data-*-of-*.arrow``).
    This also verifies that the directory looks like a valid HF datasets directory.
    The order of the returned list is by shard index.
    Note that this does not load the dataset, just inspects the directory structure.

    :param hf_data_dir: directory
    :return: list of Arrow shard files (paths joined with hf_data_dir)
    """
    hf_data_dir = os.fspath(hf_data_dir)
    dir_entries = os.listdir(hf_data_dir)
    assert "state.json" in dir_entries, f"not a valid HF datasets dir: {hf_data_dir!r}"
    assert "dataset_info.json" in dir_entries, f"not a valid HF datasets dir: {hf_data_dir!r}"
    # The dot before "arrow" must be escaped, otherwise e.g. "data-00000-of-00001xarrow" would match as well.
    pat = re.compile(r"^(.*)-([0-9]+)-of-([0-9]+)\.arrow$")
    content = [m for m in map(pat.match, dir_entries) if m]
    assert content, f"no matching .arrow files in {hf_data_dir!r} found, expected *-*-of-*.arrow"
    prefix = content[0].group(1)
    assert all(m.group(1) == prefix for m in content), (
        f"mismatching prefix in {hf_data_dir!r}, expected {prefix}, got {[m.group(1) for m in content]}"
    )
    num_shards = int(content[0].group(3))
    assert all(int(m.group(3)) == num_shards for m in content), (
        f"mismatching number of shards in {hf_data_dir!r}, expected {num_shards}, got {[m.group(3) for m in content]}"
    )
    assert len(content) == num_shards, (
        f"expected {num_shards} shard files in {hf_data_dir!r}, got {[m.group(0) for m in content]}"
    )
    content_by_idx = {int(m.group(2)): m for m in content}
    assert set(content_by_idx.keys()) == set(range(num_shards)), (
        f"expected shard indices 0..{num_shards - 1} in {hf_data_dir!r}, got {sorted(content_by_idx.keys())}"
    )
    return [os.path.join(hf_data_dir, content_by_idx[i].group(0)) for i in range(num_shards)]
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def _infer_data_format_for_feature(
    feature: Union[
        datasets.features.Sequence,
        datasets.features.ClassLabel,
        datasets.features.Value,
        datasets.features.Array2D,
        datasets.features.Array3D,
        datasets.features.Array4D,
        datasets.features.Audio,
    ],
    exc_prefix: str = "",
) -> Dict[str, Any]:
    """
    Infer the data format (kwargs compatible to :class:`Tensor`) for a HF datasets feature.

    :param feature: HF datasets feature, potentially nested inside one or more sequence features
    :param exc_prefix: prefix for error messages
    :return: dict with "dim", "ndim", "dtype";
        for :class:`datasets.features.ClassLabel` additionally "sparse" and "vocab"
    :raises TypeError: for unsupported feature types
    """
    # noinspection PyUnresolvedReferences,PyPackageRequirements
    import datasets

    # Sequence-like feature types. Depending on the HF datasets version, some of these might not exist,
    # or might not be classes (e.g. a factory function), so only keep actual types for isinstance.
    seq_types = tuple(
        t
        for t in (
            getattr(datasets.features, "Sequence", None),
            getattr(datasets.features, "List", None),
            getattr(datasets.features, "LargeList", None),
        )
        if isinstance(t, type)
    )

    labels = None
    num_classes = None
    num_dims = 0
    while isinstance(feature, seq_types):
        feature: datasets.features.List  # typing for HF datasets 4
        num_dims += 1
        # dim refers to the innermost (last) axis, so a variable-length inner sequence resets it,
        # even if some outer sequence has a fixed length.
        length = getattr(feature, "length", -1)
        num_classes = length if length != -1 else None
        feature = feature.feature
    if isinstance(feature, datasets.features.ClassLabel):
        labels = feature.names
        dtype = feature.dtype
        num_classes = feature.num_classes  # noqa
    elif isinstance(feature, datasets.features.Value):
        dtype = feature.dtype
    elif isinstance(feature, (datasets.features.Array2D, datasets.features.Array3D, datasets.features.Array4D)):
        dtype = feature.dtype
        num_classes = feature.shape[-1]
        num_dims += len(feature.shape)
    elif isinstance(feature, datasets.features.Audio):
        if feature.decode:
            dtype = "float32"  # samples
        else:
            dtype = "uint8"  # bytes
        num_dims += 1  # time axis
    else:
        # Explicit raise (instead of assert) so that this also fails cleanly under `python -O`,
        # where `dtype` would otherwise be unbound below.
        raise TypeError(f"{exc_prefix}unsupported feature type {type(feature)} {feature}")

    d = {"dim": num_classes, "ndim": num_dims, "dtype": dtype}
    if labels:
        d["sparse"] = True
        d["vocab"] = {"vocab_file": None, "labels": labels, "unknown_label": None}
    return d
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def _make_tensor_template(data: Union[Dict[str, Any], Tensor], name: str) -> Tensor:
    """
    Create the template (without batch dim) for one data key.

    :param data: either an existing :class:`Tensor` (copied, and renamed to ``name``),
        or a dict of kwargs for the :class:`Tensor` constructor
    :param name: data key, used as tensor name
    :return: tensor template, asserted to have no batch dim axis
    """
    if not isinstance(data, Tensor):
        assert isinstance(data, dict)
        template = Tensor(name, batch_dim_axis=None, **data)
    else:
        template = data.copy(name)
    assert template.batch_dim_axis is None
    return template
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
returnn/PKG-INFO,sha256=
|
|
1
|
+
returnn/PKG-INFO,sha256=lgxotZdSfk01D3LjKSGuuHRSwL3ETFQiSn5GCw80DsE,5215
|
|
2
2
|
returnn/__init__.py,sha256=biBtRsM0WZ406vShaeH-9WFoqJ8XwTbn6g0EeFJ7l8E,1012
|
|
3
3
|
returnn/__main__.py,sha256=lHyZcu_0yc9f7Vf_Kfdy9PmeU0T76XVXnpalHi5WKro,31740
|
|
4
4
|
returnn/__old_mod_loader__.py,sha256=nvsNY-xELdS_IPNkv66Q9Rmvg4dbGW0-EBRDcCmctos,7654
|
|
5
5
|
returnn/__setup__.py,sha256=22kQn2fh11iPM0hLb2Fy5sLmoU1JGvmDxXRYuRgQkwU,4659
|
|
6
|
-
returnn/_setup_info_generated.py,sha256=
|
|
6
|
+
returnn/_setup_info_generated.py,sha256=kdGINGjXKqjDphnF8IBHiGlKw9_1pozhbIvFUOdv_vU,77
|
|
7
7
|
returnn/config.py,sha256=3tmKhB6FnQZaNdtcYsiB61JnEY--iZ2qmJ4yq0b6tE0,29140
|
|
8
8
|
returnn/forward_iface.py,sha256=A_OJiaXsX4MlXQRzST86ylyxSUZbC402PQL1REcqHjM,911
|
|
9
9
|
returnn/learning_rate_control.py,sha256=ZvWryAn_tv9DhV8sh1LV3eE34Yltl3On3mYZAG4hR9s,34684
|
|
@@ -13,13 +13,14 @@ returnn/native_op.py,sha256=4_NnvfNxsM8GE_FsD6yOg6PZegqIdtJ3Sl1GdBWmFvg,244424
|
|
|
13
13
|
returnn/pretrain.py,sha256=MHiXJZqkQFmDVyaYsGpd_Acv20wxl7Pr6s6qJzAT2FI,22648
|
|
14
14
|
returnn/datasets/__init__.py,sha256=PvDlfDOaaopIeUIt0OSvHD2eHZkdkyE-sjMXf35EH5U,390
|
|
15
15
|
returnn/datasets/audio.py,sha256=Gmj7a08dnvYh7Z-G1TNapz42L50AIcDE9JeIZaO1s1M,23334
|
|
16
|
-
returnn/datasets/basic.py,sha256=
|
|
16
|
+
returnn/datasets/basic.py,sha256=s0Vjag5lJ5wGXKENN4KHwGtx7ZDiLdWAFIjFbiqAQsE,74159
|
|
17
17
|
returnn/datasets/bundle_file.py,sha256=KQNrS1MSf-4_idlK0c0KFwON-f5sEK0sWU15WpoMYpE,2380
|
|
18
18
|
returnn/datasets/cached.py,sha256=RyefRjSDdp-HveK-2vLy2C6BIHcpqQ_lNvUKlIa4QAI,25412
|
|
19
19
|
returnn/datasets/cached2.py,sha256=oJOq2lWRQpxm6kyUKW1w5qZBd4kdKEpwM7KY_QnXbq4,11922
|
|
20
|
-
returnn/datasets/distrib_files.py,sha256
|
|
20
|
+
returnn/datasets/distrib_files.py,sha256=srTieLP02kCepAwZ6Y9p20cqB8nAlVJWbSAoOPna9ik,30567
|
|
21
21
|
returnn/datasets/generating.py,sha256=Qb7V94N_GfL2pZPxWS5PmzszoVXXKzuUmsHuW3dmVbc,99556
|
|
22
22
|
returnn/datasets/hdf.py,sha256=v5sjBenURR9Z-g7AQ9tsL84yDSye5RtbLpym3M6HSDE,67833
|
|
23
|
+
returnn/datasets/huggingface.py,sha256=Bh-1hGYERigvuxjQF8kGwd2gm_BFCPVTtedzk1gz9y0,20042
|
|
23
24
|
returnn/datasets/lm.py,sha256=rQ3jV43lSnlGkKu7m5jTTH7aK0BOMXQocsHfJ8OGec8,99950
|
|
24
25
|
returnn/datasets/map.py,sha256=kOBJVZmwDhLsOplzDNByIfa0NRSUaMo2Lsy36lBvxrM,10907
|
|
25
26
|
returnn/datasets/meta.py,sha256=6XPPxhiNSxWw9Hu5Z6wG8dD9Zk82FqiI-k9HGQSTKgw,95658
|
|
@@ -253,8 +254,8 @@ returnn/util/sig_proc.py,sha256=Tjz0VOAVyqu2qDCF5HZ1JjALjcFsHcNkcd96WgZeKfE,7265
|
|
|
253
254
|
returnn/util/task_system.py,sha256=y4sMVXQ25Qd2z0rx03uOlXlkE-jbCYC1Sjfn-XlraVU,26003
|
|
254
255
|
returnn/util/train_proc_manager.py,sha256=Pjht28k6uz6BNQ47uW6Gf880iyq5q4wx7P_K2tmoAM8,3266
|
|
255
256
|
returnn/util/watch_memory.py,sha256=BR5P2kvBN6UI81cE0_1WAA6Hd1SByLbBaiDxvLhPOew,4213
|
|
256
|
-
returnn-1.
|
|
257
|
-
returnn-1.
|
|
258
|
-
returnn-1.
|
|
259
|
-
returnn-1.
|
|
260
|
-
returnn-1.
|
|
257
|
+
returnn-1.20251007.223754.dist-info/LICENSE,sha256=ywBD_U2aD4vpuoIgNAsjIGBYydl0tVKll3De0Z8s77c,11041
|
|
258
|
+
returnn-1.20251007.223754.dist-info/METADATA,sha256=lgxotZdSfk01D3LjKSGuuHRSwL3ETFQiSn5GCw80DsE,5215
|
|
259
|
+
returnn-1.20251007.223754.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
260
|
+
returnn-1.20251007.223754.dist-info/top_level.txt,sha256=Lsn4WZc5Pbfk0-xDQOgnFCxOoqxL4CyeM3N1TFbJncw,8
|
|
261
|
+
returnn-1.20251007.223754.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|