autoencoders 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. autoencoders/__init__.py +181 -0
  2. autoencoders/configuration_utils.py +49 -0
  3. autoencoders/data/__init__.py +41 -0
  4. autoencoders/data/base.py +276 -0
  5. autoencoders/data/embeddings.py +145 -0
  6. autoencoders/data/fasttext.py +143 -0
  7. autoencoders/data/glove.py +165 -0
  8. autoencoders/data/loading.py +21 -0
  9. autoencoders/data/numberbatch.py +142 -0
  10. autoencoders/modeling_outputs.py +131 -0
  11. autoencoders/modeling_utils.py +51 -0
  12. autoencoders/models/__init__.py +5 -0
  13. autoencoders/models/aae/__init__.py +6 -0
  14. autoencoders/models/aae/configuration_aae.py +44 -0
  15. autoencoders/models/aae/modeling_aae.py +101 -0
  16. autoencoders/models/ae/__init__.py +13 -0
  17. autoencoders/models/ae/configuration_ae.py +44 -0
  18. autoencoders/models/ae/modeling_ae.py +60 -0
  19. autoencoders/models/base/__init__.py +13 -0
  20. autoencoders/models/base/configuration_base.py +33 -0
  21. autoencoders/models/base/configuration_vae.py +52 -0
  22. autoencoders/models/base/configuration_vq.py +62 -0
  23. autoencoders/models/base/modeling_base.py +122 -0
  24. autoencoders/models/base/modeling_vae.py +166 -0
  25. autoencoders/models/base/modeling_vq.py +187 -0
  26. autoencoders/models/betavae/__init__.py +13 -0
  27. autoencoders/models/betavae/configuration_betavae.py +47 -0
  28. autoencoders/models/betavae/modeling_betavae.py +12 -0
  29. autoencoders/models/cae/__init__.py +6 -0
  30. autoencoders/models/cae/configuration_cae.py +38 -0
  31. autoencoders/models/cae/modeling_cae.py +73 -0
  32. autoencoders/models/dae/__init__.py +14 -0
  33. autoencoders/models/dae/configuration_dae.py +49 -0
  34. autoencoders/models/dae/modeling_dae.py +74 -0
  35. autoencoders/models/dvae/__init__.py +6 -0
  36. autoencoders/models/dvae/configuration_dvae.py +54 -0
  37. autoencoders/models/dvae/modeling_dvae.py +89 -0
  38. autoencoders/models/fsq/__init__.py +6 -0
  39. autoencoders/models/fsq/configuration_fsq.py +48 -0
  40. autoencoders/models/fsq/modeling_fsq.py +90 -0
  41. autoencoders/models/hvae/__init__.py +6 -0
  42. autoencoders/models/hvae/configuration_hvae.py +45 -0
  43. autoencoders/models/hvae/modeling_hvae.py +145 -0
  44. autoencoders/models/klsae/__init__.py +6 -0
  45. autoencoders/models/klsae/configuration_klsae.py +42 -0
  46. autoencoders/models/klsae/modeling_klsae.py +63 -0
  47. autoencoders/models/loading.py +56 -0
  48. autoencoders/models/pqvae/__init__.py +6 -0
  49. autoencoders/models/pqvae/configuration_pqvae.py +52 -0
  50. autoencoders/models/pqvae/modeling_pqvae.py +133 -0
  51. autoencoders/models/rqvae/__init__.py +6 -0
  52. autoencoders/models/rqvae/configuration_rqvae.py +50 -0
  53. autoencoders/models/rqvae/modeling_rqvae.py +124 -0
  54. autoencoders/models/sae/__init__.py +13 -0
  55. autoencoders/models/sae/configuration_sae.py +38 -0
  56. autoencoders/models/sae/modeling_sae.py +53 -0
  57. autoencoders/models/topksae/__init__.py +6 -0
  58. autoencoders/models/topksae/configuration_topksae.py +40 -0
  59. autoencoders/models/topksae/modeling_topksae.py +55 -0
  60. autoencoders/models/vae/__init__.py +14 -0
  61. autoencoders/models/vae/configuration_vae.py +39 -0
  62. autoencoders/models/vae/modeling_vae.py +66 -0
  63. autoencoders/models/vqvae/__init__.py +13 -0
  64. autoencoders/models/vqvae/configuration_vqvae.py +45 -0
  65. autoencoders/models/vqvae/modeling_vqvae.py +95 -0
  66. autoencoders/models/wae/__init__.py +6 -0
  67. autoencoders/models/wae/configuration_wae.py +44 -0
  68. autoencoders/models/wae/modeling_wae.py +67 -0
  69. autoencoders/training/__init__.py +26 -0
  70. autoencoders/training/display.py +373 -0
  71. autoencoders/training/trainer.py +495 -0
  72. autoencoders-0.1.0.dist-info/METADATA +284 -0
  73. autoencoders-0.1.0.dist-info/RECORD +75 -0
  74. autoencoders-0.1.0.dist-info/WHEEL +5 -0
  75. autoencoders-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,181 @@
1
+ """Top-level package for the autoencoders library."""
2
+
3
+ from .models.aae.configuration_aae import AdversarialAutoencoderConfig
4
+ from .configuration_utils import PretrainedConfig
5
+ from .modeling_outputs import (
6
+ AdversarialAutoencoderOutput,
7
+ AutoencoderExport,
8
+ AutoencoderOutput,
9
+ BaseAutoencoderOutput,
10
+ ContractiveAutoencoderOutput,
11
+ DenoisingAutoencoderOutput,
12
+ DenoisingVariationalAutoencoderOutput,
13
+ FiniteScalarQuantizedAutoencoderOutput,
14
+ HierarchicalVariationalAutoencoderOutput,
15
+ KLSparseAutoencoderOutput,
16
+ QuantizedAutoencoderOutput,
17
+ SparseAutoencoderOutput,
18
+ TopKSparseAutoencoderOutput,
19
+ VariationalAutoencoderOutput,
20
+ WassersteinAutoencoderOutput,
21
+ )
22
+ from .models.ae.configuration_ae import AutoencoderConfig
23
+ from .models.cae.configuration_cae import ContractiveAutoencoderConfig
24
+ from .models.base.configuration_base import BaseAutoencoderConfig
25
+ from .models.base.configuration_vae import BaseVariationalAutoencoderConfig
26
+ from .models.base.configuration_vq import BaseVectorQuantizedAutoencoderConfig
27
+ from .models.betavae.configuration_betavae import BetaVariationalAutoencoderConfig
28
+ from .models.dae.configuration_dae import DenoisingAutoencoderConfig
29
+ from .models.dvae.configuration_dvae import DenoisingVariationalAutoencoderConfig
30
+ from .models.fsq.configuration_fsq import FiniteScalarQuantizedAutoencoderConfig
31
+ from .models.hvae.configuration_hvae import HierarchicalVariationalAutoencoderConfig
32
+ from .models.klsae.configuration_klsae import KLSparseAutoencoderConfig
33
+ from .models.pqvae.configuration_pqvae import ProductQuantizedAutoencoderConfig
34
+ from .models.rqvae.configuration_rqvae import ResidualQuantizedAutoencoderConfig
35
+ from .models.sae.configuration_sae import SparseAutoencoderConfig
36
+ from .models.topksae.configuration_topksae import TopKSparseAutoencoderConfig
37
+ from .models.vae.configuration_vae import VariationalAutoencoderConfig
38
+ from .models.wae.configuration_wae import WassersteinAutoencoderConfig
39
+ from .models.vqvae.configuration_vqvae import VectorQuantizedAutoencoderConfig
40
+
41
# Public names exported by the package for the symbols imported unconditionally
# above (configuration classes and model output containers). Names that come from
# the torch-guarded imports further down are appended to this list only when
# those imports succeed.
__all__ = [
    "AdversarialAutoencoderConfig",
    "AdversarialAutoencoderOutput",
    "AutoencoderConfig",
    "AutoencoderExport",
    "AutoencoderOutput",
    "BaseAutoencoderOutput",
    "BaseAutoencoderConfig",
    "BaseVariationalAutoencoderConfig",
    "BaseVectorQuantizedAutoencoderConfig",
    "BetaVariationalAutoencoderConfig",
    "ContractiveAutoencoderConfig",
    "ContractiveAutoencoderOutput",
    "DenoisingAutoencoderConfig",
    "DenoisingAutoencoderOutput",
    "DenoisingVariationalAutoencoderConfig",
    "DenoisingVariationalAutoencoderOutput",
    "FiniteScalarQuantizedAutoencoderConfig",
    "FiniteScalarQuantizedAutoencoderOutput",
    "HierarchicalVariationalAutoencoderConfig",
    "HierarchicalVariationalAutoencoderOutput",
    "KLSparseAutoencoderConfig",
    "KLSparseAutoencoderOutput",
    "PretrainedConfig",
    "ProductQuantizedAutoencoderConfig",
    "QuantizedAutoencoderOutput",
    "ResidualQuantizedAutoencoderConfig",
    "SparseAutoencoderConfig",
    "SparseAutoencoderOutput",
    "TopKSparseAutoencoderConfig",
    "TopKSparseAutoencoderOutput",
    "VariationalAutoencoderConfig",
    "VariationalAutoencoderOutput",
    "WassersteinAutoencoderConfig",
    "WassersteinAutoencoderOutput",
    "VectorQuantizedAutoencoderConfig",
]
78
+
79
# Models, datasets and trainers are imported inside this guard. If the
# failing import is the ``torch`` module itself, the error is swallowed so the
# package still imports with only the names imported above; a missing module
# with any other name is re-raised as a genuine error.
try:
    from .models.aae.modeling_aae import AdversarialAutoencoderModel
    from .modeling_utils import PreTrainedAutoencoderModel
    from .data import (
        AutoencoderDataset,
        CachedDataset,
        ConceptNetNumberbatchDataset,
        DatasetLoaders,
        DatasetSplits,
        EmbeddingMatrix,
        EmbeddingTensorDataset,
        FastTextEnglishDataset,
        GloVeDataset,
        create_dataloaders,
        load_dataset,
        load_embedding_artifact,
        load_text_embedding_matrix,
        split_dataset,
    )
    from .models.ae.modeling_ae import AutoencoderModel
    from .models.base.modeling_base import BaseAutoencoderModel
    from .models.base.modeling_vae import BaseVariationalAutoencoderModel
    from .models.base.modeling_vq import BaseVectorQuantizedAutoencoderModel
    from .models.betavae.modeling_betavae import BetaVariationalAutoencoderModel
    from .models.cae.modeling_cae import ContractiveAutoencoderModel
    from .models.dae.modeling_dae import DenoisingAutoencoderModel
    from .models.dvae.modeling_dvae import DenoisingVariationalAutoencoderModel
    from .models.fsq.modeling_fsq import FiniteScalarQuantizedAutoencoderModel
    from .models.hvae.modeling_hvae import HierarchicalVariationalAutoencoderModel
    from .models.klsae.modeling_klsae import KLSparseAutoencoderModel
    from .models.loading import load_model
    from .models.pqvae.modeling_pqvae import ProductQuantizedAutoencoderModel
    from .models.rqvae.modeling_rqvae import ResidualQuantizedAutoencoderModel
    from .models.sae.modeling_sae import SparseAutoencoderModel
    from .models.topksae.modeling_topksae import TopKSparseAutoencoderModel
    from .models.vae.modeling_vae import VariationalAutoencoderModel
    from .models.wae.modeling_wae import WassersteinAutoencoderModel
    from .models.vqvae.modeling_vqvae import VectorQuantizedAutoencoderModel
    from .training import (
        AETrainer,
        AdversarialAutoencoderTrainer,
        AdversarialAutoencoderTrainingArguments,
        TrainerDisplay,
        TrainerDisplayConfig,
        TrainingArguments,
        VAETrainer,
        VQTrainer,
        resolve_device,
        set_seed,
    )
except ModuleNotFoundError as exc:
    # Tolerate only the absence of the top-level ``torch`` module.
    if exc.name != "torch":
        raise
else:
    # Export the guarded names only when every import above succeeded.
    __all__.extend(
        [
            "AETrainer",
            "AdversarialAutoencoderModel",
            "AdversarialAutoencoderTrainer",
            "AdversarialAutoencoderTrainingArguments",
            "AutoencoderModel",
            "AutoencoderDataset",
            "BaseAutoencoderModel",
            "BaseVariationalAutoencoderModel",
            "BaseVectorQuantizedAutoencoderModel",
            "BetaVariationalAutoencoderModel",
            "CachedDataset",
            "ContractiveAutoencoderModel",
            "ConceptNetNumberbatchDataset",
            "DatasetLoaders",
            "DatasetSplits",
            "DenoisingAutoencoderModel",
            "DenoisingVariationalAutoencoderModel",
            "EmbeddingMatrix",
            "EmbeddingTensorDataset",
            "FastTextEnglishDataset",
            "FiniteScalarQuantizedAutoencoderModel",
            "GloVeDataset",
            "HierarchicalVariationalAutoencoderModel",
            "KLSparseAutoencoderModel",
            "PreTrainedAutoencoderModel",
            "ProductQuantizedAutoencoderModel",
            "ResidualQuantizedAutoencoderModel",
            "SparseAutoencoderModel",
            "TopKSparseAutoencoderModel",
            "TrainerDisplay",
            "TrainerDisplayConfig",
            "TrainingArguments",
            "VAETrainer",
            "VariationalAutoencoderModel",
            "VQTrainer",
            "VectorQuantizedAutoencoderModel",
            "WassersteinAutoencoderModel",
            "create_dataloaders",
            "load_dataset",
            "load_model",
            "load_embedding_artifact",
            "load_text_embedding_matrix",
            "resolve_device",
            "set_seed",
            "split_dataset",
        ]
    )
@@ -0,0 +1,49 @@
1
+ """Utilities for model configuration objects."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+
10
class PretrainedConfig:
    """A lightweight configuration base inspired by the transformers API.

    Every keyword argument given to the constructor becomes an instance
    attribute. ``return_dict`` defaults to ``True`` when not supplied.
    Subclasses override the ``model_type`` class attribute; it is added to
    the serialized dictionary and stripped again when loading.
    """

    model_type = "config"

    def __init__(self, **kwargs: Any) -> None:
        self.return_dict = kwargs.pop("return_dict", True)
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self) -> str:
        """Show the class name and all instance attributes, sorted by name."""
        fields = ", ".join(f"{key}={value!r}" for key, value in sorted(self.__dict__.items()))
        return f"{type(self).__name__}({fields})"

    def to_dict(self) -> dict[str, Any]:
        """Return a shallow copy of the instance attributes plus ``model_type``."""
        payload = dict(self.__dict__)
        payload["model_type"] = self.model_type
        return payload

    def to_json_string(self) -> str:
        """Serialize ``to_dict()`` as JSON (2-space indent, sorted keys, trailing newline)."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def save_pretrained(self, save_directory: str | Path) -> Path:
        """Write ``config.json`` into ``save_directory`` (created if missing).

        Returns:
            The path of the written ``config.json``.
        """
        save_path = Path(save_directory)
        save_path.mkdir(parents=True, exist_ok=True)
        config_path = save_path / "config.json"
        config_path.write_text(self.to_json_string(), encoding="utf-8")
        return config_path

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any], **kwargs: Any) -> "PretrainedConfig":
        """Build a config from ``config_dict``.

        The ``model_type`` entry is ignored (it is a class attribute), and
        ``kwargs`` override entries of ``config_dict``. The input dict is not
        modified.
        """
        payload = dict(config_dict)
        payload.pop("model_type", None)
        payload.update(kwargs)
        return cls(**payload)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str | Path, **kwargs: Any) -> "PretrainedConfig":
        """Load a config from a JSON file, or from ``config.json`` inside a directory.

        ``kwargs`` override values read from the file (see ``from_dict``).
        """
        config_path = Path(pretrained_model_name_or_path)
        if config_path.is_dir():
            config_path = config_path / "config.json"
        config_dict = json.loads(config_path.read_text(encoding="utf-8"))
        return cls.from_dict(config_dict, **kwargs)
49
+
@@ -0,0 +1,41 @@
1
+ """Utilities for working with real embedding matrices."""
2
+
3
+ from .base import (
4
+ AutoencoderDataset,
5
+ CachedDataset,
6
+ DatasetLoaders,
7
+ DatasetSplits,
8
+ create_dataloaders,
9
+ default_cache_dir,
10
+ split_dataset,
11
+ )
12
+ from .embeddings import (
13
+ EmbeddingMatrix,
14
+ EmbeddingTensorDataset,
15
+ load_embedding_artifact,
16
+ load_text_embedding_matrix,
17
+ save_embedding_artifact,
18
+ )
19
+ from .fasttext import FastTextEnglishDataset
20
+ from .glove import GloVeDataset
21
+ from .loading import load_dataset
22
+ from .numberbatch import ConceptNetNumberbatchDataset
23
+
24
# Names re-exported by the ``data`` package from the submodules imported above.
__all__ = [
    "AutoencoderDataset",
    "CachedDataset",
    "DatasetLoaders",
    "DatasetSplits",
    "EmbeddingMatrix",
    "EmbeddingTensorDataset",
    "FastTextEnglishDataset",
    "GloVeDataset",
    "ConceptNetNumberbatchDataset",
    "create_dataloaders",
    "default_cache_dir",
    "load_dataset",
    "load_embedding_artifact",
    "load_text_embedding_matrix",
    "save_embedding_artifact",
    "split_dataset",
]
@@ -0,0 +1,276 @@
1
+ """Base dataset abstractions for autoencoder training and evaluation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import ABC, abstractmethod
6
+ from dataclasses import dataclass
7
+ import os
8
+ from pathlib import Path
9
+ import sys
10
+ from typing import Callable
11
+ import urllib.request
12
+
13
+ import torch
14
+ from torch.utils.data import DataLoader, Dataset, Subset
15
+
16
+
17
def default_cache_dir() -> Path:
    """Resolve the directory where downloadable datasets are cached.

    A non-empty ``AUTOENCODERS_CACHE`` environment variable takes precedence
    (with ``~`` expanded); otherwise ``~/.cache/autoencoders`` is used.
    """

    override = os.environ.get("AUTOENCODERS_CACHE")
    if not override:
        return Path.home() / ".cache" / "autoencoders"
    return Path(override).expanduser()
24
+
25
+
26
def format_num_bytes(num_bytes: int) -> str:
    """Format a byte count into a compact human-readable string.

    The value is scaled by powers of 1024 and shown with one decimal place,
    e.g. ``1536 -> "1.5KB"``. Counts of 1024 TB and above stay in TB.
    """

    value = float(num_bytes)
    units = ("B", "KB", "MB", "GB", "TB")
    for unit in units[:-1]:
        # Compare the value as it will be displayed: otherwise e.g. 1048575
        # bytes would render as "1024.0KB" instead of "1.0MB".
        if round(value, 1) < 1024.0:
            return f"{value:.1f}{unit}"
        value /= 1024.0
    return f"{value:.1f}{units[-1]}"


class DownloadProgressBar:
    """A small single-line terminal progress bar for dataset downloads.

    Every redraw starts with a carriage return so it overwrites the previous
    one. With a positive ``total_bytes`` it shows a 20-character bar, a
    percentage and ``downloaded/total``; otherwise only the running byte
    count. Output goes to ``stream`` (``sys.stderr`` by default).
    """

    def __init__(self, description: str, total_bytes: int | None, stream=None) -> None:
        self.description = description
        self.total_bytes = total_bytes
        self.stream = sys.stderr if stream is None else stream
        self.downloaded_bytes = 0
        self._finished = False
        # Draw the initial zero-byte state right away.
        self._render()

    def update(self, chunk_size: int) -> None:
        """Record ``chunk_size`` additional bytes and redraw."""
        self.downloaded_bytes += chunk_size
        self._render()

    def close(self) -> None:
        """Draw the final state followed by a newline; later calls are no-ops."""
        if self._finished:
            return
        self._finished = True
        self._render(final=True)

    def _render(self, final: bool = False) -> None:
        """Write the current progress line (plus a newline when ``final``) and flush."""
        if self.total_bytes is not None and self.total_bytes > 0:
            ratio = min(self.downloaded_bytes / self.total_bytes, 1.0)
            filled = int(ratio * 20)
            bar = "=" * filled + "." * (20 - filled)
            percent = int(ratio * 100)
            message = (
                f"\r{self.description} [{bar}] {percent:3d}% "
                f"{format_num_bytes(self.downloaded_bytes)}/{format_num_bytes(self.total_bytes)}"
            )
        else:
            message = f"\r{self.description} {format_num_bytes(self.downloaded_bytes)}"

        if final:
            message += "\n"

        self.stream.write(message)
        self.stream.flush()
+ self.stream.flush()
77
+
78
+
79
class AutoencoderDataset(Dataset[torch.Tensor]):
    """Common base for tensor-valued datasets consumed by the autoencoders.

    The base itself only records which split the instance represents;
    subclasses are expected to supply the usual ``Dataset`` item access.
    """

    # Name of the split this instance holds, e.g. "train".
    split: str

    def __init__(self, split: str = "train") -> None:
        self.split = split
86
+
87
+
88
class CachedDataset(ABC):
    """Base class for datasets that download raw files and cache processed artifacts.

    Paths are rooted at ``<root>/<dataset_name>/`` with ``raw/``,
    ``external/`` and ``processed/`` sub-directories; ``root`` defaults to
    ``default_cache_dir()``. Subclasses implement ``artifact_dir``,
    ``has_raw_data``, ``download`` and ``prepare``.
    """

    dataset_name = "dataset"

    def __init__(self, root: str | Path | None = None) -> None:
        self.root = Path(root) if root is not None else default_cache_dir()
        self.dataset_dir = self.root / self.dataset_name
        self.raw_dir = self.dataset_dir / "raw"
        self.external_dir = self.dataset_dir / "external"
        self.processed_dir = self.dataset_dir / "processed"

    def ensure_prepared(
        self,
        *,
        download: bool = True,
        force_download: bool = False,
        force_prepare: bool = False,
    ) -> Path:
        """Ensure the processed artifact exists and return its directory.

        Args:
            download: Call ``download()`` before preparing. When False the raw
                files must already be present.
            force_download: Passed to ``download()`` as ``force``.
            force_prepare: Re-run ``prepare()`` even if ``is_prepared()`` is True.

        Raises:
            FileNotFoundError: ``download`` is False and ``has_raw_data()`` is False.
        """

        artifact_dir = self.artifact_dir
        if self.is_prepared() and not force_prepare:
            return artifact_dir

        if download:
            self.download(force=force_download)
        elif not self.has_raw_data():
            raise FileNotFoundError(
                f"Raw data for {self.dataset_name!r} is missing under {self.raw_dir}."
            )

        self.prepare()
        return artifact_dir

    def is_prepared(self) -> bool:
        """Return True when the processed artifact is complete and ready to use.

        The default only checks that ``artifact_dir`` exists; subclasses may
        override it with a stricter completeness check.
        """

        return self.artifact_dir.exists()

    def download_to_cache(
        self,
        *,
        url: str,
        destination: Path,
        validator: Callable[[Path], bool] | None = None,
        description: str | None = None,
        force: bool = False,
        chunk_size: int = 1024 * 1024,
    ) -> Path:
        """Download a file atomically with progress reporting and cache validation.

        The body is streamed into ``<destination>.tmp`` and moved onto
        ``destination`` only once it is complete and, if ``validator`` is
        given, accepted by it. An existing ``destination`` is returned as-is
        unless ``force`` is set or it fails ``validator``. With ``force``,
        the existing file is kept until the replacement is ready, so a failed
        re-download does not destroy a good cached copy.

        Raises:
            urllib.error.ContentTooShortError: fewer bytes arrived than the
                response's ``Content-Length`` announced.
            ValueError: the downloaded file failed ``validator``.
            Any error raised by ``urllib.request.urlopen`` or file I/O.
        """

        from urllib.error import ContentTooShortError

        destination.parent.mkdir(parents=True, exist_ok=True)
        temp_path = destination.with_name(f"{destination.name}.tmp")
        # Remove leftovers of an earlier interrupted attempt.
        self._cleanup_temp_file(temp_path)

        if destination.exists() and not force:
            if validator is None or validator(destination):
                return destination
            # The cached copy is invalid; discard it before fetching a new one.
            destination.unlink()

        try:
            with urllib.request.urlopen(url) as response, temp_path.open("wb") as handle:
                total_bytes = self._response_content_length(response)
                progress = DownloadProgressBar(description or destination.name, total_bytes)
                received = 0
                try:
                    while True:
                        chunk = response.read(chunk_size)
                        if not chunk:
                            break
                        handle.write(chunk)
                        received += len(chunk)
                        progress.update(len(chunk))
                finally:
                    progress.close()
            # urllib/http.client is not guaranteed to raise when the peer closes
            # the connection early, so compare against the advertised length.
            if total_bytes is not None and received < total_bytes:
                raise ContentTooShortError(
                    f"Download of {url} was truncated: received {received} of {total_bytes} bytes.",
                    None,
                )
        except BaseException:
            # Also covers KeyboardInterrupt: never leave a partial .tmp file behind.
            self._cleanup_temp_file(temp_path)
            raise

        if validator is not None and not validator(temp_path):
            self._cleanup_temp_file(temp_path)
            raise ValueError(f"Downloaded file {temp_path} failed validation.")

        # Path.replace overwrites an existing destination in a single step.
        temp_path.replace(destination)
        return destination

    @staticmethod
    def _cleanup_temp_file(path: Path) -> None:
        """Delete ``path`` if it exists."""
        if path.exists():
            path.unlink()

    @staticmethod
    def _response_content_length(response) -> int | None:
        """Return the integer ``Content-Length`` header, or None if absent/unparsable."""
        length = response.headers.get("Content-Length")
        if length is None:
            return None
        try:
            return int(length)
        except (TypeError, ValueError):
            return None

    @property
    @abstractmethod
    def artifact_dir(self) -> Path:
        """Directory containing the processed dataset artifact."""

    @abstractmethod
    def has_raw_data(self) -> bool:
        """Return True when enough raw files are present to prepare the dataset."""

    @abstractmethod
    def download(self, *, force: bool = False) -> None:
        """Download the raw dataset files into the cache."""

    @abstractmethod
    def prepare(self) -> None:
        """Convert raw files into a processed artifact."""
206
+
207
+
208
@dataclass
class DatasetSplits:
    """Train/validation/test partitions of one dataset, as built by ``split_dataset``."""

    train: Dataset[torch.Tensor]
    validation: Dataset[torch.Tensor]
    test: Dataset[torch.Tensor]


def split_dataset(
    dataset: Dataset[torch.Tensor],
    *,
    validation_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42,
) -> DatasetSplits:
    """Split a dataset into train, validation, and test subsets.

    Indices are shuffled with a ``torch.Generator`` seeded by ``seed``, so the
    same seed and length always give the same partition. The validation and
    test sizes are ``int(len(dataset) * ratio)`` (rounded down); the train
    split receives every remaining index.

    Raises:
        ValueError: a ratio is negative, or the two ratios sum to 1.0 or more.
    """

    if any(ratio < 0 for ratio in (validation_ratio, test_ratio)):
        raise ValueError("validation_ratio and test_ratio must be non-negative.")
    if validation_ratio + test_ratio >= 1.0:
        raise ValueError("validation_ratio + test_ratio must be less than 1.0.")

    total = len(dataset)
    rng = torch.Generator().manual_seed(seed)
    order = torch.randperm(total, generator=rng).tolist()

    validation_size = int(total * validation_ratio)
    test_size = int(total * test_ratio)
    validation_start = total - validation_size - test_size
    test_start = validation_start + validation_size

    return DatasetSplits(
        train=Subset(dataset, order[:validation_start]),
        validation=Subset(dataset, order[validation_start:test_start]),
        test=Subset(dataset, order[test_start:]),
    )
248
+
249
+
250
@dataclass
class DatasetLoaders:
    """Train/validation/test ``DataLoader`` objects built from one set of splits."""

    train: DataLoader[torch.Tensor]
    validation: DataLoader[torch.Tensor]
    test: DataLoader[torch.Tensor]


def create_dataloaders(
    splits: DatasetSplits,
    *,
    batch_size: int = 256,
    num_workers: int = 0,
) -> DatasetLoaders:
    """Wrap each split in a ``DataLoader`` sharing ``batch_size``/``num_workers``.

    Only the training loader shuffles; validation and test keep their order.
    """

    shared = {"batch_size": batch_size, "num_workers": num_workers}
    return DatasetLoaders(
        train=DataLoader(splits.train, shuffle=True, **shared),
        validation=DataLoader(splits.validation, shuffle=False, **shared),
        test=DataLoader(splits.test, shuffle=False, **shared),
    )
@@ -0,0 +1,145 @@
1
+ """Helpers for loading and packaging embedding matrices."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+
9
+ import torch
10
+
11
+ from .base import AutoencoderDataset
12
+
13
+
14
@dataclass
class EmbeddingMatrix:
    """A vocabulary together with the tensor holding its vectors.

    Attributes:
        tokens: Vocabulary, expected to line up with the rows of ``matrix``.
        matrix: Embedding tensor; dimension 0 counts vectors, dimension 1 is
            the embedding size.
        token_to_index: Lookup from token to row number.
        source_path: Where the vectors were loaded from, if known.
        name: Optional identifier for the embedding set.
    """

    tokens: list[str]
    matrix: torch.Tensor
    token_to_index: dict[str, int]
    source_path: str | None = None
    name: str | None = None

    @property
    def num_embeddings(self) -> int:
        """Number of vectors (size of dimension 0 of ``matrix``)."""
        return int(self.matrix.size(0))

    @property
    def embedding_dim(self) -> int:
        """Length of each vector (size of dimension 1 of ``matrix``)."""
        return int(self.matrix.size(1))
31
+
32
+
33
class EmbeddingTensorDataset(AutoencoderDataset):
    """Serve each row of an ``EmbeddingMatrix`` as one training sample."""

    def __init__(self, embedding_matrix: EmbeddingMatrix, split: str = "train") -> None:
        super().__init__(split=split)
        self.embedding_matrix = embedding_matrix

    def __len__(self) -> int:
        """Number of samples, i.e. the matrix's ``num_embeddings``."""
        return self.embedding_matrix.num_embeddings

    def __getitem__(self, index: int) -> torch.Tensor:
        """Return ``matrix[index]`` of the wrapped embedding matrix."""
        vectors = self.embedding_matrix.matrix
        return vectors[index]
45
+
46
+
47
def load_text_embedding_matrix(
    path: str | Path,
    *,
    max_vectors: int | None = None,
    expected_dim: int | None = None,
    skip_first_line: bool = False,
    dtype: torch.dtype = torch.float32,
) -> EmbeddingMatrix:
    """Load a whitespace-separated embedding text file such as GloVe.

    Each non-blank line is ``<token> <value> <value> ...``; blank lines are
    skipped.

    Args:
        path: UTF-8 text file to read.
        max_vectors: Stop after this many vectors (``None`` reads everything).
            ``0`` loads nothing, which then raises as "no embeddings loaded".
        expected_dim: Required number of values per row; inferred from the
            first parsed row when ``None``.
        skip_first_line: Ignore line 1 (e.g. a ``"<count> <dim>"`` header).
        dtype: dtype of the resulting matrix.

    Returns:
        An ``EmbeddingMatrix`` named after the file stem, with one row per
        token in file order.

    Raises:
        ValueError: ``max_vectors`` is negative, a row has no values, a value
            is not numeric, a row has the wrong dimension, or no vectors were
            loaded.
    """

    if max_vectors is not None and max_vectors < 0:
        raise ValueError(f"max_vectors must be non-negative, got {max_vectors}.")

    file_path = Path(path)
    tokens: list[str] = []
    # Keep each row as a 1-D tensor instead of a list of Python floats: for
    # large vocabularies the per-float object overhead dominates memory.
    rows: list[torch.Tensor] = []

    with file_path.open("r", encoding="utf-8") as handle:
        for line_number, line in enumerate(handle, start=1):
            # Checked before parsing so that max_vectors=0 really loads nothing.
            if max_vectors is not None and len(tokens) >= max_vectors:
                break
            if skip_first_line and line_number == 1:
                continue
            stripped = line.strip()
            if not stripped:
                continue

            token, *values = stripped.split()
            if not values:
                raise ValueError(f"Malformed embedding row at line {line_number}: {stripped!r}")

            if expected_dim is None:
                expected_dim = len(values)
            elif len(values) != expected_dim:
                raise ValueError(
                    f"Embedding dimension mismatch at line {line_number}: "
                    f"expected {expected_dim}, got {len(values)}"
                )

            try:
                numbers = [float(value) for value in values]
            except ValueError as exc:
                raise ValueError(
                    f"Non-numeric embedding value at line {line_number}: {exc}"
                ) from exc

            tokens.append(token)
            rows.append(torch.tensor(numbers, dtype=dtype))

    if not rows:
        raise ValueError(f"No embeddings were loaded from {file_path}.")

    matrix = torch.stack(rows)
    token_to_index = {token: index for index, token in enumerate(tokens)}
    return EmbeddingMatrix(
        tokens=tokens,
        matrix=matrix,
        token_to_index=token_to_index,
        source_path=str(file_path),
        name=file_path.stem,
    )
102
+
103
+
104
def save_embedding_artifact(embedding_matrix: EmbeddingMatrix, output_dir: str | Path) -> Path:
    """Save a processed embedding matrix in a compact torch-friendly format.

    Writes three files into ``output_dir`` (created if needed):
    ``embeddings.pt`` (the tensor, via ``torch.save``), ``tokens.txt`` (one
    token per line, each line newline-terminated) and ``metadata.json``.

    Returns:
        The output directory as a ``Path``.

    Raises:
        ValueError: the number of tokens differs from the number of matrix
            rows, or a token contains a line break (it could not be read back
            as a single line). Nothing is written in that case.
    """

    tokens = embedding_matrix.tokens
    if len(tokens) != embedding_matrix.num_embeddings:
        raise ValueError(
            f"Token count ({len(tokens)}) does not match the number of embedding rows "
            f"({embedding_matrix.num_embeddings})."
        )
    for token in tokens:
        if "\n" in token or "\r" in token:
            raise ValueError(f"Token {token!r} contains a line break and cannot be stored in tokens.txt.")

    save_dir = Path(output_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    tensor_path = save_dir / "embeddings.pt"
    tokens_path = save_dir / "tokens.txt"
    metadata_path = save_dir / "metadata.json"

    torch.save(embedding_matrix.matrix, tensor_path)
    # Newline-terminate every token; an empty vocabulary yields an empty file
    # (writing "\n" would read back as a single empty token).
    tokens_path.write_text("".join(f"{token}\n" for token in tokens), encoding="utf-8")
    metadata = {
        "name": embedding_matrix.name,
        "source_path": embedding_matrix.source_path,
        "num_embeddings": embedding_matrix.num_embeddings,
        "embedding_dim": embedding_matrix.embedding_dim,
    }
    metadata_path.write_text(json.dumps(metadata, indent=2, sort_keys=True) + "\n", encoding="utf-8")
    return save_dir
124
+
125
+
126
def load_embedding_artifact(path: str | Path) -> EmbeddingMatrix:
    """Load a processed embedding artifact created by save_embedding_artifact().

    Reads ``embeddings.pt``, ``tokens.txt`` and ``metadata.json`` from
    ``path``; the tensor is mapped to CPU.

    Raises:
        ValueError: ``tokens.txt`` and ``embeddings.pt`` disagree on the
            number of entries (a corrupt or mismatched artifact).
    """

    load_dir = Path(path)
    tensor_path = load_dir / "embeddings.pt"
    tokens_path = load_dir / "tokens.txt"
    metadata_path = load_dir / "metadata.json"

    # NOTE(review): torch.load unpickles its input; only load artifacts from a
    # trusted cache directory.
    matrix = torch.load(tensor_path, map_location="cpu")
    # Split on "\n" only: str.splitlines() also breaks on characters such as
    # "\x85", "\x0c" or "\u2028", which may occur inside tokens.
    tokens = tokens_path.read_text(encoding="utf-8").split("\n")
    if tokens[-1] == "":
        tokens.pop()  # newline terminating the last token
    metadata = json.loads(metadata_path.read_text(encoding="utf-8"))

    num_rows = int(matrix.shape[0])
    if len(tokens) != num_rows:
        raise ValueError(
            f"Embedding artifact at {load_dir} is inconsistent: {len(tokens)} tokens "
            f"but {num_rows} embedding rows."
        )
    token_to_index = {token: index for index, token in enumerate(tokens)}

    return EmbeddingMatrix(
        tokens=tokens,
        matrix=matrix,
        token_to_index=token_to_index,
        source_path=metadata.get("source_path"),
        name=metadata.get("name"),
    )
+ )