returnn 1.20250901.123052__py3-none-any.whl → 1.20260105.192646__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. returnn/PKG-INFO +2 -2
  2. returnn/_setup_info_generated.py +2 -2
  3. returnn/config.py +1 -1
  4. returnn/datasets/basic.py +29 -13
  5. returnn/datasets/distrib_files.py +61 -3
  6. returnn/datasets/generating.py +12 -21
  7. returnn/datasets/huggingface.py +434 -0
  8. returnn/datasets/lm.py +20 -0
  9. returnn/datasets/meta.py +179 -60
  10. returnn/datasets/multi_proc.py +1 -1
  11. returnn/datasets/postprocessing.py +597 -108
  12. returnn/datasets/text_dict.py +1 -1
  13. returnn/datasets/util/vocabulary.py +90 -0
  14. returnn/frontend/_backend.py +7 -0
  15. returnn/frontend/array_.py +54 -1
  16. returnn/frontend/attention.py +54 -20
  17. returnn/frontend/conv.py +273 -54
  18. returnn/frontend/decoder/transformer.py +36 -17
  19. returnn/frontend/encoder/conformer.py +1 -0
  20. returnn/frontend/encoder/transformer.py +2 -0
  21. returnn/frontend/loss.py +40 -1
  22. returnn/frontend/module.py +8 -1
  23. returnn/frontend/nested.py +9 -0
  24. returnn/native_op.cpp +80 -0
  25. returnn/sprint/cache.py +12 -13
  26. returnn/tensor/_dim_extra.py +51 -29
  27. returnn/tensor/_tensor_extra.py +6 -1
  28. returnn/tensor/utils.py +7 -4
  29. returnn/tf/frontend_layers/_backend.py +11 -2
  30. returnn/tf/frontend_low_level/_backend.py +15 -0
  31. returnn/tf/layers/basic.py +16 -38
  32. returnn/tf/native_op.py +11 -58
  33. returnn/tf/network.py +1 -1
  34. returnn/tf/util/basic.py +19 -0
  35. returnn/torch/data/returnn_dataset_wrapper.py +9 -3
  36. returnn/torch/engine.py +67 -2
  37. returnn/torch/frontend/_backend.py +119 -7
  38. returnn/torch/util/diagnose_gpu.py +65 -31
  39. returnn/torch/util/exception_helper.py +7 -1
  40. returnn/util/basic.py +6 -7
  41. returnn/util/better_exchook.py +4 -0
  42. returnn/util/collect_outputs_dict.py +79 -0
  43. returnn/util/debug.py +11 -2
  44. returnn/util/file_cache.py +42 -4
  45. returnn/util/task_system.py +1 -1
  46. {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/METADATA +2 -2
  47. {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/RECORD +50 -48
  48. {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/LICENSE +0 -0
  49. {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/WHEEL +0 -0
  50. {returnn-1.20250901.123052.dist-info → returnn-1.20260105.192646.dist-info}/top_level.txt +0 -0
returnn/datasets/huggingface.py ADDED
@@ -0,0 +1,434 @@
+ """
+ HuggingFace dataset wrapper
+
+ See https://github.com/rwth-i6/returnn/issues/1257 for some initial discussion.
+ """
+
+ from __future__ import annotations
+ from typing import TYPE_CHECKING, Optional, Union, Any, Callable, Sequence, Dict, List
+ import os
+ import re
+ import numpy
+ from returnn.tensor import Tensor
+ from returnn.util import file_cache
+ from .basic import DatasetSeq
+ from .cached2 import CachedDataset2
+ from .util.vocabulary import Vocabulary
+ from .util.strings import str_to_numpy_array
+
+ if TYPE_CHECKING:
+     # noinspection PyUnresolvedReferences,PyPackageRequirements
+     import datasets
+
+
+ class HuggingFaceDataset(CachedDataset2):
+     """
+     HuggingFace dataset wrapper.
+     """
+
+     def __init__(
+         self,
+         dataset_opts: Union[
+             Dict[str, Any],
+             str,
+             os.PathLike,
+             Sequence[Union[str, os.PathLike]],
+             Callable[[], Union[Dict[str, Any], str, os.PathLike, Sequence[Union[str, os.PathLike]], datasets.Dataset]],
+         ],
+         *,
+         use_file_cache: bool = False,
+         map_func: Optional[Callable[[datasets.Dataset], datasets.Dataset]] = None,
+         rename_columns: Optional[Dict[str, str]] = None,
+         cast_columns: Optional[Dict[str, Dict[str, Any]]] = None,
+         data_format: Dict[str, Dict[str, Any]],
+         seq_tag_column: Optional[str] = "id",
+         sorting_seq_len_column_data: Optional[str] = None,
+         sorting_seq_len_column: Optional[str] = None,
+         **kwargs,
+     ):
+         """
+         :param dataset_opts: either a dict of options for :func:`datasets.load_dataset`,
+             or a path to a local dataset for :func:`datasets.load_from_disk`,
+             or a list of Arrow filenames to load with :func:`datasets.Dataset.from_file` and concatenate.
+             It can also be a callable returning one of the above,
+             or returning a :class:`datasets.Dataset` directly.
+         :param use_file_cache: if True, will cache the dataset files on local disk using :mod:`file_cache`.
+             This only works if dataset_opts is a str or list of str (or a callable returning that).
+         :param map_func: optional function to apply to the dataset after loading
+         :param rename_columns: if given, will rename these columns
+         :param cast_columns: if given, will cast these columns to the specified types.
+             This is useful if the dataset does not have the expected types.
+             See :func:`datasets.Dataset.cast` for details.
+             You can also e.g. enforce some sample_rate for audio, etc.
+         :param data_format:
+             For each column name (data key), specify the format,
+             as a dict with entries for "dim", "ndim", "shape", and/or "dtype",
+             compatible to :class:`Tensor`.
+             It can be a subset of the available columns.
+             If "vocab" is specified, and the underlying HF datasets column is of dtype "string",
+             it will automatically tokenize the string using the vocab.
+         :param seq_tag_column: key (column name) in the dataset to use as sequence tag.
+             If None, will use the sequence index as tag.
+         :param sorting_seq_len_column_data: key (column name) in the dataset to use for sorting by sequence length.
+             It will take len(dataset[sorting_seq_len_column_data]) as sequence length (only for sorting/shuffling).
+         :param sorting_seq_len_column: key (column name) in the dataset to use for sorting by sequence length.
+             It will take the value of dataset[sorting_seq_len_column] as sequence length (only for sorting/shuffling).
+             E.g. some datasets provide "duration", "duration_ms", "wav_filesize" or similar information
+             which can be used.
+         """
+         super().__init__(**kwargs)
+
+         self.dataset_opts = dataset_opts
+         self.use_file_cache = use_file_cache
+         self.map_func = map_func
+         self.rename_columns = rename_columns
+         self.cast_columns = cast_columns
+
+         self.data_format: Dict[str, Tensor] = {k: _make_tensor_template(v, k) for k, v in data_format.items()}
+         self.seq_tag_column: Optional[str] = seq_tag_column
+         self.sorting_seq_len_column_data = sorting_seq_len_column_data
+         self.sorting_seq_len_column = sorting_seq_len_column
+
+         self.labels = {k: data.vocab.labels for k, data in self.data_format.items() if data.vocab}
+         self.num_outputs = {k: (data.dim or 1, data.ndim) for k, data in self.data_format.items()}
+
+         self.hf_dataset: Optional[datasets.Dataset] = None  # lazily loaded, _lazy_init
+         self._seq_order: Optional[Sequence[int]] = None  # init_seq_order
+         self._seq_tags: Optional[List[str]] = None  # get_all_tags cache
+
+     def _lazy_init(self):
+         if self.hf_dataset is not None:
+             return
+
+         # Load the dataset
+         # noinspection PyUnresolvedReferences,PyPackageRequirements
+         import datasets
+
+         dataset_opts = self.dataset_opts
+         if callable(dataset_opts):
+             dataset_opts = dataset_opts()
+         if self.use_file_cache:
+             assert isinstance(dataset_opts, (str, os.PathLike, list, tuple)), (
+                 f"{self}: with use_file_cache, dataset_opts must be str or list of str, got {type(dataset_opts)}"
+             )
+             if isinstance(dataset_opts, (str, os.PathLike)):
+                 dataset_opts = get_arrow_shard_files_from_hf_dataset_dir(dataset_opts)
+             assert isinstance(dataset_opts, (list, tuple))
+             cache = file_cache.get_instance()
+             dataset_opts = [cache.get_file(os.fspath(fn)) for fn in dataset_opts]
+             self.set_file_cache(cache)
+         if isinstance(dataset_opts, dict):
+             self.hf_dataset = datasets.load_dataset(**dataset_opts)
+         elif isinstance(dataset_opts, str):
+             self.hf_dataset = datasets.load_from_disk(dataset_opts)
+         elif isinstance(dataset_opts, (list, tuple)):
+             self.hf_dataset = datasets.concatenate_datasets([datasets.Dataset.from_file(fn) for fn in dataset_opts])
+         elif isinstance(dataset_opts, datasets.Dataset):
+             self.hf_dataset = dataset_opts
+         else:
+             raise TypeError(f"{self}: invalid dataset_opts type {type(dataset_opts)}")
+         assert isinstance(self.hf_dataset, datasets.Dataset), (
+             f"{self}: Expected single dataset, got {type(self.hf_dataset)} {self.hf_dataset}. Specify split if needed."
+         )
+
+         if self.map_func is not None:
+             self.hf_dataset = self.map_func(self.hf_dataset)
+
+         if self.rename_columns:
+             self.hf_dataset = self.hf_dataset.rename_columns(self.rename_columns)
+
+         if self.cast_columns:
+             # Note: prefer cast_column, as this can avoid using `map`, i.e. be faster.
+             for key, column_format in self.cast_columns.items():
+                 assert key in self.hf_dataset.features, (
+                     f"{self}: cast_column {key} not in dataset features {self.hf_dataset.features}"
+                 )
+                 feat = datasets.features.features.generate_from_dict(column_format)
+                 self.hf_dataset = self.hf_dataset.cast_column(key, feat)
+
+         if self.seq_tag_column:
+             assert self.seq_tag_column in self.hf_dataset.features, (
+                 f"{self}: seq_tag_column {self.seq_tag_column} not in dataset features {self.hf_dataset.features}"
+             )
+             assert self.hf_dataset.features[self.seq_tag_column].dtype in ("string", "int64"), (
+                 f"{self}: seq_tag_column {self.seq_tag_column} must be of dtype string or int64,"
+                 f" got {self.hf_dataset.features[self.seq_tag_column].dtype}"
+             )
+
+         selected_columns = list(self.data_format.keys())
+         if self.seq_tag_column and self.seq_tag_column not in selected_columns:
+             selected_columns.append(self.seq_tag_column)
+         if self.sorting_seq_len_column and self.sorting_seq_len_column not in selected_columns:
+             selected_columns.append(self.sorting_seq_len_column)
+         if self.sorting_seq_len_column_data and self.sorting_seq_len_column_data not in selected_columns:
+             selected_columns.append(self.sorting_seq_len_column_data)
+         self.hf_dataset = self.hf_dataset.select_columns(selected_columns)
+
+         self.hf_dataset.set_format("numpy")
+
+         for key, user_format in self.data_format.items():
+             feature = self.hf_dataset.features[key]
+             inferred_format = _infer_data_format_for_feature(feature, f"{self}: column {key}: ")
+             if user_format.vocab and inferred_format["dtype"] == "string":
+                 pass  # allow to auto-tokenize strings when vocab is specified
+             else:
+                 for key_ in ["dtype", "ndim", "dim"]:
+                     assert getattr(user_format, key_) == inferred_format[key_], (
+                         f"{self}: column {key}, user-specified {user_format}, {key_}:"
+                         f" user-specified {getattr(user_format, key_)} does not match inferred {inferred_format[key_]}"
+                     )
+             if "vocab" in inferred_format and not user_format.vocab:
+                 assert user_format.sparse, f"{self}: column {key}: user_format expected to be sparse, got {user_format}"
+                 user_format.sparse_dim.vocab = Vocabulary.create_vocab(**inferred_format["vocab"])
+                 self.labels[key] = user_format.vocab.labels
+
+     def get_data_keys(self) -> List[str]:
+         """:return: list of data keys"""
+         return list(self.data_format.keys())
+
+     def get_target_list(self) -> List[str]:
+         """:return: list of target keys"""
+         return self.get_data_keys()  # it's somewhat arbitrary...
+
+     def get_data_shape(self, key: str) -> List[int]:
+         """:return: data shape for the given key"""
+         return list(self.data_format[key].shape)
+
+     def get_data_dim(self, key: str) -> int:
+         """:return: data dimension for the given key"""
+         return self.data_format[key].dim
+
+     def is_data_sparse(self, key: str) -> bool:
+         """:return: whether the data is sparse for the given key"""
+         return self.data_format[key].sparse
+
+     def get_data_dtype(self, key: str) -> str:
+         """:return: dtype"""
+         return self.data_format[key].dtype
+
+     def _get_seq_len(self, seq_idx: int) -> Union[int, float]:
+         if self._seq_order_seq_lens_by_idx is not None:
+             self._get_seq_len = self._seq_order_seq_lens_by_idx.__getitem__  # faster
+             return self._seq_order_seq_lens_by_idx[seq_idx]
+         assert not self._seq_order_seq_lens_file  # not expected to call this
+         if self.sorting_seq_len_column:
+             self._seq_order_seq_lens_by_idx = numpy.array(self.hf_dataset[self.sorting_seq_len_column])
+             self._get_seq_len = self._seq_order_seq_lens_by_idx.__getitem__  # faster
+             v = self._seq_order_seq_lens_by_idx[seq_idx]
+             return int(v)  # noqa
+         if self.sorting_seq_len_column_data:
+             v = self.hf_dataset[seq_idx][self.sorting_seq_len_column_data]
+             return len(v)  # noqa
+         raise ValueError(
+             f"{self}: sorting/shuffling by seq len not configured,"
+             f" need sorting_seq_len_column or sorting_seq_len_column_data"
+         )
+
+     @property
+     def num_seqs(self) -> int:
+         """:return: number of sequences"""
+         assert self._seq_order is not None, "num_seqs is only known after calling init_seq_order()"
+         return len(self._seq_order)
+
+     def get_tag(self, sorted_seq_idx: int) -> str:
+         """:return: tag of the sequence"""
+         corpus_seq_idx = self.get_corpus_seq_idx(sorted_seq_idx)
+         self._lazy_init()
+         dataset_item = self.hf_dataset[corpus_seq_idx]
+         return self._get_seq_tag(corpus_seq_idx, dataset_item)
+
+     def get_all_tags(self) -> List[str]:
+         """:return: all tags"""
+         if self._seq_tags is not None:
+             return self._seq_tags
+         self._lazy_init()
+         if self.seq_tag_column:
+             res = list(map(str, self.hf_dataset[self.seq_tag_column]))
+         else:
+             res = [f"seq-{i}" for i in range(self.hf_dataset.num_rows)]
+         self._seq_tags = res
+         return res
+
+     def get_total_num_seqs(self, *, fast: bool = False) -> int:
+         """:return: total number of sequences in the dataset"""
+         if fast:
+             return super().get_total_num_seqs(fast=True)
+         self._lazy_init()
+         return self.hf_dataset.num_rows
+
+     def init_seq_order(
+         self,
+         epoch: Optional[int] = None,
+         seq_list: Optional[Sequence[str]] = None,
+         seq_order: Optional[Sequence[int]] = None,
+     ) -> bool:
+         """
+         :param epoch:
+         :param seq_list: List of sequence tags, to set a predefined order.
+         :param seq_order: List of corpus sequence indices, to set a predefined order.
+         :returns: whether the order changed (True is always safe to return)
+         """
+         super().init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
+
+         if seq_order is not None:
+             self._seq_order = seq_order
+         elif seq_list is not None:
+             all_tags = self.get_all_tags()
+             self._seq_order = [all_tags.index(tag) for tag in seq_list]
+         elif epoch is None:
+             self._seq_order = ()
+         else:
+             self._lazy_init()
+             self._seq_order = self.get_seq_order_for_epoch(
+                 epoch=epoch, num_seqs=self.hf_dataset.num_rows, get_seq_len=self._get_seq_len
+             )
+         return True
+
+     def _collect_single_seq(self, seq_idx: int) -> DatasetSeq:
+         # noinspection PyUnresolvedReferences,PyPackageRequirements
+         import datasets
+
+         corpus_seq_idx = self.get_corpus_seq_idx(seq_idx)
+
+         def _ensure_numpy(k, x):
+             if isinstance(x, numpy.ndarray):  # fast path
+                 return x
+             if isinstance(x, str):
+                 if self.data_format[k].dtype == "string":
+                     return str_to_numpy_array(x)
+                 if self.data_format[k].vocab:
+                     return numpy.array(self.data_format[k].vocab.get_seq(x), dtype=self.data_format[k].dtype)
+                 raise ValueError(f"{self}: column {k}: cannot convert string {x!r} to numpy array")
+             feat = self.hf_dataset.features[k]
+             if isinstance(feat, datasets.features.Audio):
+                 # In HF datasets 3, this is just a dict.
+                 # In HF datasets 4, this can also be a datasets.features._torchcodec.AudioDecoder.
+                 assert isinstance(x, dict) or x.__class__.__name__ == "AudioDecoder"
+                 if feat.decode:
+                     x = x["array"]
+                 else:
+                     x = x["bytes"]
+                 if isinstance(x, numpy.ndarray):  # fast path
+                     return x
+                 if isinstance(x, (bytes, bytearray)):
+                     return numpy.frombuffer(x, dtype=self.data_format[k].dtype)
+             return numpy.array(x)
+
+         self._lazy_init()
+         dataset_item = self.hf_dataset[corpus_seq_idx]
+         seq_tag = self._get_seq_tag(corpus_seq_idx, dataset_item)
+         features = {k: _ensure_numpy(k, dataset_item[k]) for k in self.data_format}
+         return DatasetSeq(seq_idx, features=features, seq_tag=seq_tag)
+
+     def _get_seq_tag(self, corpus_seq_idx: int, dataset_item: Dict[str, Any]) -> str:
+         if self.seq_tag_column:
+             seq_tag = dataset_item[self.seq_tag_column]
+             assert isinstance(seq_tag, (str, int, numpy.int64)), f"got {type(seq_tag)} {seq_tag!r}"
+             seq_tag = str(seq_tag)
+         else:
+             seq_tag = f"seq-{corpus_seq_idx}"
+         return seq_tag
+
+     def get_current_seq_order(self) -> Sequence[int]:
+         """:return: list of corpus seq idx"""
+         assert self._seq_order is not None
+         return self._seq_order
+
+     def get_corpus_seq_idx(self, sorted_seq_idx: int) -> int:
+         """:return: corpus seq idx"""
+         return int(self._seq_order[sorted_seq_idx])
+
+
+ def get_arrow_shard_files_from_hf_dataset_dir(hf_data_dir: Union[str, os.PathLike]) -> List[str]:
+     """
+     Given some HF datasets directory (created via :func:`datasets.save_to_disk`),
+     return the list of Arrow shard files (``data-*-of-*.arrow``).
+     This also verifies that the directory looks like a valid HF datasets directory.
+     The order of the returned list is by shard index.
+     Note that this does not load the dataset, just inspects the directory structure.
+
+     :param hf_data_dir: directory
+     :return: list of Arrow shard files
+     """
+     hf_data_dir = os.fspath(hf_data_dir)
+     content = os.listdir(hf_data_dir)
+     assert "state.json" in content, f"not a valid HF datasets dir: {hf_data_dir!r}"
+     assert "dataset_info.json" in content, f"not a valid HF datasets dir: {hf_data_dir!r}"
+     pat = re.compile("^(.*)-([0-9]+)-of-([0-9]+).arrow$")
+     content = [pat.match(fn) for fn in content]
+     content = [m for m in content if m]
+     assert content, f"no matching .arrow files in {hf_data_dir!r} found, expected *-*-of-*.arrow"
+     prefix = content[0].group(1)
+     assert all(m.group(1) == prefix for m in content), (
+         f"mismatching prefix in {hf_data_dir!r}, expected {prefix}, got {[m.group(1) for m in content]}"
+     )
+     num_shards = int(content[0].group(3))
+     assert all(int(m.group(3)) == num_shards for m in content), (
+         f"mismatching number of shards in {hf_data_dir!r}, expected {num_shards}, got {[m.group(3) for m in content]}"
+     )
+     assert len(content) == num_shards, f"expected {num_shards} shard files in {hf_data_dir!r}, got {content}"
+     content_by_idx = {int(m.group(2)): m for m in content}
+     assert set(content_by_idx.keys()) == set(range(num_shards)), (
+         f"expected shard indices 0..{num_shards - 1} in {hf_data_dir!r}, got {sorted(content_by_idx.keys())}"
+     )
+     return [hf_data_dir + "/" + content_by_idx[i].group(0) for i in range(num_shards)]
+
+
+ def _infer_data_format_for_feature(
+     feature: Union[
+         datasets.features.Sequence,
+         datasets.features.ClassLabel,
+         datasets.features.Value,
+         datasets.features.Array2D,
+         datasets.features.Array3D,
+         datasets.features.Array4D,
+         datasets.features.Audio,
+     ],
+     exc_prefix: str = "",
+ ) -> Dict[str, Any]:
+     # noinspection PyUnresolvedReferences,PyPackageRequirements
+     import datasets
+
+     labels = None
+     num_classes = None
+     num_dims = 0
+     while isinstance(feature, datasets.features.Sequence):
+         feature: datasets.features.List  # typing for HF datasets 4
+         num_dims += 1
+         if feature.length != -1:
+             num_classes = feature.length
+         feature = feature.feature
+     if isinstance(feature, datasets.features.ClassLabel):
+         labels = feature.names
+         dtype = feature.dtype
+         num_classes = feature.num_classes  # noqa
+     elif isinstance(feature, datasets.features.Value):
+         dtype = feature.dtype
+     elif isinstance(feature, (datasets.features.Array2D, datasets.features.Array3D, datasets.features.Array4D)):
+         dtype = feature.dtype
+         num_classes = feature.shape[-1]
+         num_dims += len(feature.shape)
+     elif isinstance(feature, datasets.features.Audio):
+         if feature.decode:
+             dtype = "float32"  # samples
+         else:
+             dtype = "uint8"  # bytes
+         num_dims += 1  # time axis
+     else:
+         assert False, f"{exc_prefix}unsupported feature type {type(feature)} {feature}"
+
+     d = {"dim": num_classes, "ndim": num_dims, "dtype": dtype}
+     if labels:
+         d["sparse"] = True
+         d["vocab"] = {"vocab_file": None, "labels": labels, "unknown_label": None}
+     return d
+
+
+ def _make_tensor_template(data: Union[Dict[str, Any], Tensor], name: str) -> Tensor:
+     if isinstance(data, Tensor):
+         data = data.copy(name)
+     else:
+         assert isinstance(data, dict)
+         data = Tensor(name, batch_dim_axis=None, **data)
+     assert data.batch_dim_axis is None
+     return data
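
For orientation, the new class is configured like any other RETURNN dataset. A minimal config sketch follows; the dataset path, the column names "audio"/"text"/"id", the dim value, and the vocab options are hypothetical placeholders, not taken from the diff:

    train = {
        "class": "HuggingFaceDataset",
        # dict form is passed to datasets.load_dataset(**dataset_opts)
        "dataset_opts": {"path": "some/hf-dataset", "split": "train"},
        "data_format": {
            # dense float audio samples with one (time) axis
            "audio": {"dtype": "float32", "ndim": 1, "dim": None},
            # sparse token IDs; with "vocab" set, a string column is auto-tokenized
            "text": {
                "dtype": "int32", "ndim": 1, "dim": 10025, "sparse": True,
                "vocab": {"vocab_file": "vocab.txt", "unknown_label": None},
            },
        },
        "seq_tag_column": "id",
        # sort/shuffle by len() of the raw "text" entries
        "sorting_seq_len_column_data": "text",
    }

Alternatively, dataset_opts can be a datasets.save_to_disk directory or a list of Arrow shard files; together with use_file_cache=True, the shards are first copied to a local cache via returnn.util.file_cache (see get_arrow_shard_files_from_hf_dataset_dir above).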
returnn/datasets/lm.py CHANGED
@@ -694,6 +694,26 @@ class LmDataset(CachedDataset2):
          self.next_seq_idx = seq_idx + 1
          return DatasetSeq(seq_idx=seq_idx, features=data, targets=targets, seq_tag=seq_tag)
 
+     def finish_epoch(self, *, free_resources: bool = False):
+         """finish epoch"""
+         super().finish_epoch(free_resources=free_resources)
+
+         if free_resources:
+             self._orths_offsets_and_lens = None
+             if self._orth_mmaps is not None:
+                 for m in self._orth_mmaps:
+                     if m is not None:
+                         m.close()
+                 self._orth_mmaps = None
+             if self._orth_files is not None:
+                 for f in self._orth_files:
+                     if f is not None:
+                         f.close()
+                 self._orth_files = None
+
+             self._seq_list = None
+             self._seq_index_by_tag = None
+
 
  def _is_bliss(filename):
      """