autoencoders 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autoencoders/__init__.py +181 -0
- autoencoders/configuration_utils.py +49 -0
- autoencoders/data/__init__.py +41 -0
- autoencoders/data/base.py +276 -0
- autoencoders/data/embeddings.py +145 -0
- autoencoders/data/fasttext.py +143 -0
- autoencoders/data/glove.py +165 -0
- autoencoders/data/loading.py +21 -0
- autoencoders/data/numberbatch.py +142 -0
- autoencoders/modeling_outputs.py +131 -0
- autoencoders/modeling_utils.py +51 -0
- autoencoders/models/__init__.py +5 -0
- autoencoders/models/aae/__init__.py +6 -0
- autoencoders/models/aae/configuration_aae.py +44 -0
- autoencoders/models/aae/modeling_aae.py +101 -0
- autoencoders/models/ae/__init__.py +13 -0
- autoencoders/models/ae/configuration_ae.py +44 -0
- autoencoders/models/ae/modeling_ae.py +60 -0
- autoencoders/models/base/__init__.py +13 -0
- autoencoders/models/base/configuration_base.py +33 -0
- autoencoders/models/base/configuration_vae.py +52 -0
- autoencoders/models/base/configuration_vq.py +62 -0
- autoencoders/models/base/modeling_base.py +122 -0
- autoencoders/models/base/modeling_vae.py +166 -0
- autoencoders/models/base/modeling_vq.py +187 -0
- autoencoders/models/betavae/__init__.py +13 -0
- autoencoders/models/betavae/configuration_betavae.py +47 -0
- autoencoders/models/betavae/modeling_betavae.py +12 -0
- autoencoders/models/cae/__init__.py +6 -0
- autoencoders/models/cae/configuration_cae.py +38 -0
- autoencoders/models/cae/modeling_cae.py +73 -0
- autoencoders/models/dae/__init__.py +14 -0
- autoencoders/models/dae/configuration_dae.py +49 -0
- autoencoders/models/dae/modeling_dae.py +74 -0
- autoencoders/models/dvae/__init__.py +6 -0
- autoencoders/models/dvae/configuration_dvae.py +54 -0
- autoencoders/models/dvae/modeling_dvae.py +89 -0
- autoencoders/models/fsq/__init__.py +6 -0
- autoencoders/models/fsq/configuration_fsq.py +48 -0
- autoencoders/models/fsq/modeling_fsq.py +90 -0
- autoencoders/models/hvae/__init__.py +6 -0
- autoencoders/models/hvae/configuration_hvae.py +45 -0
- autoencoders/models/hvae/modeling_hvae.py +145 -0
- autoencoders/models/klsae/__init__.py +6 -0
- autoencoders/models/klsae/configuration_klsae.py +42 -0
- autoencoders/models/klsae/modeling_klsae.py +63 -0
- autoencoders/models/loading.py +56 -0
- autoencoders/models/pqvae/__init__.py +6 -0
- autoencoders/models/pqvae/configuration_pqvae.py +52 -0
- autoencoders/models/pqvae/modeling_pqvae.py +133 -0
- autoencoders/models/rqvae/__init__.py +6 -0
- autoencoders/models/rqvae/configuration_rqvae.py +50 -0
- autoencoders/models/rqvae/modeling_rqvae.py +124 -0
- autoencoders/models/sae/__init__.py +13 -0
- autoencoders/models/sae/configuration_sae.py +38 -0
- autoencoders/models/sae/modeling_sae.py +53 -0
- autoencoders/models/topksae/__init__.py +6 -0
- autoencoders/models/topksae/configuration_topksae.py +40 -0
- autoencoders/models/topksae/modeling_topksae.py +55 -0
- autoencoders/models/vae/__init__.py +14 -0
- autoencoders/models/vae/configuration_vae.py +39 -0
- autoencoders/models/vae/modeling_vae.py +66 -0
- autoencoders/models/vqvae/__init__.py +13 -0
- autoencoders/models/vqvae/configuration_vqvae.py +45 -0
- autoencoders/models/vqvae/modeling_vqvae.py +95 -0
- autoencoders/models/wae/__init__.py +6 -0
- autoencoders/models/wae/configuration_wae.py +44 -0
- autoencoders/models/wae/modeling_wae.py +67 -0
- autoencoders/training/__init__.py +26 -0
- autoencoders/training/display.py +373 -0
- autoencoders/training/trainer.py +495 -0
- autoencoders-0.1.0.dist-info/METADATA +284 -0
- autoencoders-0.1.0.dist-info/RECORD +75 -0
- autoencoders-0.1.0.dist-info/WHEEL +5 -0
- autoencoders-0.1.0.dist-info/top_level.txt +1 -0
autoencoders/__init__.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
"""Top-level package for the autoencoders library."""
|
|
2
|
+
|
|
3
|
+
from .models.aae.configuration_aae import AdversarialAutoencoderConfig
|
|
4
|
+
from .configuration_utils import PretrainedConfig
|
|
5
|
+
from .modeling_outputs import (
|
|
6
|
+
AdversarialAutoencoderOutput,
|
|
7
|
+
AutoencoderExport,
|
|
8
|
+
AutoencoderOutput,
|
|
9
|
+
BaseAutoencoderOutput,
|
|
10
|
+
ContractiveAutoencoderOutput,
|
|
11
|
+
DenoisingAutoencoderOutput,
|
|
12
|
+
DenoisingVariationalAutoencoderOutput,
|
|
13
|
+
FiniteScalarQuantizedAutoencoderOutput,
|
|
14
|
+
HierarchicalVariationalAutoencoderOutput,
|
|
15
|
+
KLSparseAutoencoderOutput,
|
|
16
|
+
QuantizedAutoencoderOutput,
|
|
17
|
+
SparseAutoencoderOutput,
|
|
18
|
+
TopKSparseAutoencoderOutput,
|
|
19
|
+
VariationalAutoencoderOutput,
|
|
20
|
+
WassersteinAutoencoderOutput,
|
|
21
|
+
)
|
|
22
|
+
from .models.ae.configuration_ae import AutoencoderConfig
|
|
23
|
+
from .models.cae.configuration_cae import ContractiveAutoencoderConfig
|
|
24
|
+
from .models.base.configuration_base import BaseAutoencoderConfig
|
|
25
|
+
from .models.base.configuration_vae import BaseVariationalAutoencoderConfig
|
|
26
|
+
from .models.base.configuration_vq import BaseVectorQuantizedAutoencoderConfig
|
|
27
|
+
from .models.betavae.configuration_betavae import BetaVariationalAutoencoderConfig
|
|
28
|
+
from .models.dae.configuration_dae import DenoisingAutoencoderConfig
|
|
29
|
+
from .models.dvae.configuration_dvae import DenoisingVariationalAutoencoderConfig
|
|
30
|
+
from .models.fsq.configuration_fsq import FiniteScalarQuantizedAutoencoderConfig
|
|
31
|
+
from .models.hvae.configuration_hvae import HierarchicalVariationalAutoencoderConfig
|
|
32
|
+
from .models.klsae.configuration_klsae import KLSparseAutoencoderConfig
|
|
33
|
+
from .models.pqvae.configuration_pqvae import ProductQuantizedAutoencoderConfig
|
|
34
|
+
from .models.rqvae.configuration_rqvae import ResidualQuantizedAutoencoderConfig
|
|
35
|
+
from .models.sae.configuration_sae import SparseAutoencoderConfig
|
|
36
|
+
from .models.topksae.configuration_topksae import TopKSparseAutoencoderConfig
|
|
37
|
+
from .models.vae.configuration_vae import VariationalAutoencoderConfig
|
|
38
|
+
from .models.wae.configuration_wae import WassersteinAutoencoderConfig
|
|
39
|
+
from .models.vqvae.configuration_vqvae import VectorQuantizedAutoencoderConfig
|
|
40
|
+
|
|
41
|
+
__all__ = [
|
|
42
|
+
"AdversarialAutoencoderConfig",
|
|
43
|
+
"AdversarialAutoencoderOutput",
|
|
44
|
+
"AutoencoderConfig",
|
|
45
|
+
"AutoencoderExport",
|
|
46
|
+
"AutoencoderOutput",
|
|
47
|
+
"BaseAutoencoderOutput",
|
|
48
|
+
"BaseAutoencoderConfig",
|
|
49
|
+
"BaseVariationalAutoencoderConfig",
|
|
50
|
+
"BaseVectorQuantizedAutoencoderConfig",
|
|
51
|
+
"BetaVariationalAutoencoderConfig",
|
|
52
|
+
"ContractiveAutoencoderConfig",
|
|
53
|
+
"ContractiveAutoencoderOutput",
|
|
54
|
+
"DenoisingAutoencoderConfig",
|
|
55
|
+
"DenoisingAutoencoderOutput",
|
|
56
|
+
"DenoisingVariationalAutoencoderConfig",
|
|
57
|
+
"DenoisingVariationalAutoencoderOutput",
|
|
58
|
+
"FiniteScalarQuantizedAutoencoderConfig",
|
|
59
|
+
"FiniteScalarQuantizedAutoencoderOutput",
|
|
60
|
+
"HierarchicalVariationalAutoencoderConfig",
|
|
61
|
+
"HierarchicalVariationalAutoencoderOutput",
|
|
62
|
+
"KLSparseAutoencoderConfig",
|
|
63
|
+
"KLSparseAutoencoderOutput",
|
|
64
|
+
"PretrainedConfig",
|
|
65
|
+
"ProductQuantizedAutoencoderConfig",
|
|
66
|
+
"QuantizedAutoencoderOutput",
|
|
67
|
+
"ResidualQuantizedAutoencoderConfig",
|
|
68
|
+
"SparseAutoencoderConfig",
|
|
69
|
+
"SparseAutoencoderOutput",
|
|
70
|
+
"TopKSparseAutoencoderConfig",
|
|
71
|
+
"TopKSparseAutoencoderOutput",
|
|
72
|
+
"VariationalAutoencoderConfig",
|
|
73
|
+
"VariationalAutoencoderOutput",
|
|
74
|
+
"WassersteinAutoencoderConfig",
|
|
75
|
+
"WassersteinAutoencoderOutput",
|
|
76
|
+
"VectorQuantizedAutoencoderConfig",
|
|
77
|
+
]
|
|
78
|
+
|
|
79
|
+
try:
|
|
80
|
+
from .models.aae.modeling_aae import AdversarialAutoencoderModel
|
|
81
|
+
from .modeling_utils import PreTrainedAutoencoderModel
|
|
82
|
+
from .data import (
|
|
83
|
+
AutoencoderDataset,
|
|
84
|
+
CachedDataset,
|
|
85
|
+
ConceptNetNumberbatchDataset,
|
|
86
|
+
DatasetLoaders,
|
|
87
|
+
DatasetSplits,
|
|
88
|
+
EmbeddingMatrix,
|
|
89
|
+
EmbeddingTensorDataset,
|
|
90
|
+
FastTextEnglishDataset,
|
|
91
|
+
GloVeDataset,
|
|
92
|
+
create_dataloaders,
|
|
93
|
+
load_dataset,
|
|
94
|
+
load_embedding_artifact,
|
|
95
|
+
load_text_embedding_matrix,
|
|
96
|
+
split_dataset,
|
|
97
|
+
)
|
|
98
|
+
from .models.ae.modeling_ae import AutoencoderModel
|
|
99
|
+
from .models.base.modeling_base import BaseAutoencoderModel
|
|
100
|
+
from .models.base.modeling_vae import BaseVariationalAutoencoderModel
|
|
101
|
+
from .models.base.modeling_vq import BaseVectorQuantizedAutoencoderModel
|
|
102
|
+
from .models.betavae.modeling_betavae import BetaVariationalAutoencoderModel
|
|
103
|
+
from .models.cae.modeling_cae import ContractiveAutoencoderModel
|
|
104
|
+
from .models.dae.modeling_dae import DenoisingAutoencoderModel
|
|
105
|
+
from .models.dvae.modeling_dvae import DenoisingVariationalAutoencoderModel
|
|
106
|
+
from .models.fsq.modeling_fsq import FiniteScalarQuantizedAutoencoderModel
|
|
107
|
+
from .models.hvae.modeling_hvae import HierarchicalVariationalAutoencoderModel
|
|
108
|
+
from .models.klsae.modeling_klsae import KLSparseAutoencoderModel
|
|
109
|
+
from .models.loading import load_model
|
|
110
|
+
from .models.pqvae.modeling_pqvae import ProductQuantizedAutoencoderModel
|
|
111
|
+
from .models.rqvae.modeling_rqvae import ResidualQuantizedAutoencoderModel
|
|
112
|
+
from .models.sae.modeling_sae import SparseAutoencoderModel
|
|
113
|
+
from .models.topksae.modeling_topksae import TopKSparseAutoencoderModel
|
|
114
|
+
from .models.vae.modeling_vae import VariationalAutoencoderModel
|
|
115
|
+
from .models.wae.modeling_wae import WassersteinAutoencoderModel
|
|
116
|
+
from .models.vqvae.modeling_vqvae import VectorQuantizedAutoencoderModel
|
|
117
|
+
from .training import (
|
|
118
|
+
AETrainer,
|
|
119
|
+
AdversarialAutoencoderTrainer,
|
|
120
|
+
AdversarialAutoencoderTrainingArguments,
|
|
121
|
+
TrainerDisplay,
|
|
122
|
+
TrainerDisplayConfig,
|
|
123
|
+
TrainingArguments,
|
|
124
|
+
VAETrainer,
|
|
125
|
+
VQTrainer,
|
|
126
|
+
resolve_device,
|
|
127
|
+
set_seed,
|
|
128
|
+
)
|
|
129
|
+
except ModuleNotFoundError as exc:
|
|
130
|
+
if exc.name != "torch":
|
|
131
|
+
raise
|
|
132
|
+
else:
|
|
133
|
+
__all__.extend(
|
|
134
|
+
[
|
|
135
|
+
"AETrainer",
|
|
136
|
+
"AdversarialAutoencoderModel",
|
|
137
|
+
"AdversarialAutoencoderTrainer",
|
|
138
|
+
"AdversarialAutoencoderTrainingArguments",
|
|
139
|
+
"AutoencoderModel",
|
|
140
|
+
"AutoencoderDataset",
|
|
141
|
+
"BaseAutoencoderModel",
|
|
142
|
+
"BaseVariationalAutoencoderModel",
|
|
143
|
+
"BaseVectorQuantizedAutoencoderModel",
|
|
144
|
+
"BetaVariationalAutoencoderModel",
|
|
145
|
+
"CachedDataset",
|
|
146
|
+
"ContractiveAutoencoderModel",
|
|
147
|
+
"ConceptNetNumberbatchDataset",
|
|
148
|
+
"DatasetLoaders",
|
|
149
|
+
"DatasetSplits",
|
|
150
|
+
"DenoisingAutoencoderModel",
|
|
151
|
+
"DenoisingVariationalAutoencoderModel",
|
|
152
|
+
"EmbeddingMatrix",
|
|
153
|
+
"EmbeddingTensorDataset",
|
|
154
|
+
"FastTextEnglishDataset",
|
|
155
|
+
"FiniteScalarQuantizedAutoencoderModel",
|
|
156
|
+
"GloVeDataset",
|
|
157
|
+
"HierarchicalVariationalAutoencoderModel",
|
|
158
|
+
"KLSparseAutoencoderModel",
|
|
159
|
+
"PreTrainedAutoencoderModel",
|
|
160
|
+
"ProductQuantizedAutoencoderModel",
|
|
161
|
+
"ResidualQuantizedAutoencoderModel",
|
|
162
|
+
"SparseAutoencoderModel",
|
|
163
|
+
"TopKSparseAutoencoderModel",
|
|
164
|
+
"TrainerDisplay",
|
|
165
|
+
"TrainerDisplayConfig",
|
|
166
|
+
"TrainingArguments",
|
|
167
|
+
"VAETrainer",
|
|
168
|
+
"VariationalAutoencoderModel",
|
|
169
|
+
"VQTrainer",
|
|
170
|
+
"VectorQuantizedAutoencoderModel",
|
|
171
|
+
"WassersteinAutoencoderModel",
|
|
172
|
+
"create_dataloaders",
|
|
173
|
+
"load_dataset",
|
|
174
|
+
"load_model",
|
|
175
|
+
"load_embedding_artifact",
|
|
176
|
+
"load_text_embedding_matrix",
|
|
177
|
+
"resolve_device",
|
|
178
|
+
"set_seed",
|
|
179
|
+
"split_dataset",
|
|
180
|
+
]
|
|
181
|
+
)
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""Utilities for model configuration objects."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class PretrainedConfig:
    """A lightweight configuration base inspired by the transformers API.

    Every keyword argument given to the constructor is stored as an instance
    attribute, so the set of fields is defined by the caller (or a subclass).
    ``model_type`` is a class-level tag that is added to the serialized form
    and discarded again when a configuration is rebuilt from a dictionary.
    """

    model_type = "config"

    def __init__(self, **kwargs: Any) -> None:
        """Store ``return_dict`` (default ``True``) and all other keywords as attributes."""
        self.return_dict = kwargs.pop("return_dict", True)
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self) -> str:
        """Return a debugging representation listing the serialized fields."""
        return f"{type(self).__name__}({self.to_dict()!r})"

    def to_dict(self) -> dict[str, Any]:
        """Return the instance attributes plus ``model_type`` as a new dictionary.

        The attributes are deep-copied so that editing a nested list or dict in
        the returned payload cannot silently modify this configuration.
        """
        # Local import keeps this block self-contained; the module does not
        # import ``copy`` at the top level.
        import copy

        payload = copy.deepcopy(self.__dict__)
        payload["model_type"] = self.model_type
        return payload

    def to_json_string(self) -> str:
        """Serialize :meth:`to_dict` as indented, key-sorted JSON ending in a newline."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def save_pretrained(self, save_directory: str | Path) -> Path:
        """Write ``config.json`` into ``save_directory`` and return the file path.

        The directory (and any missing parents) is created when necessary.
        """
        save_path = Path(save_directory)
        save_path.mkdir(parents=True, exist_ok=True)
        config_path = save_path / "config.json"
        config_path.write_text(self.to_json_string(), encoding="utf-8")
        return config_path

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any], **kwargs: Any) -> "PretrainedConfig":
        """Build a configuration from ``config_dict``, with ``kwargs`` taking precedence.

        ``config_dict`` is copied, not modified. Its ``model_type`` entry is
        dropped because that value is a class attribute, not a constructor
        argument.
        """
        payload = dict(config_dict)
        payload.pop("model_type", None)
        payload.update(kwargs)
        return cls(**payload)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str | Path, **kwargs: Any) -> "PretrainedConfig":
        """Load a configuration from a JSON file or from a directory holding ``config.json``.

        Extra ``kwargs`` override the values read from disk.
        """
        config_path = Path(pretrained_model_name_or_path)
        if config_path.is_dir():
            config_path = config_path / "config.json"
        config_dict = json.loads(config_path.read_text(encoding="utf-8"))
        return cls.from_dict(config_dict, **kwargs)
|
|
49
|
+
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""Utilities for working with real embedding matrices."""
|
|
2
|
+
|
|
3
|
+
from .base import (
|
|
4
|
+
AutoencoderDataset,
|
|
5
|
+
CachedDataset,
|
|
6
|
+
DatasetLoaders,
|
|
7
|
+
DatasetSplits,
|
|
8
|
+
create_dataloaders,
|
|
9
|
+
default_cache_dir,
|
|
10
|
+
split_dataset,
|
|
11
|
+
)
|
|
12
|
+
from .embeddings import (
|
|
13
|
+
EmbeddingMatrix,
|
|
14
|
+
EmbeddingTensorDataset,
|
|
15
|
+
load_embedding_artifact,
|
|
16
|
+
load_text_embedding_matrix,
|
|
17
|
+
save_embedding_artifact,
|
|
18
|
+
)
|
|
19
|
+
from .fasttext import FastTextEnglishDataset
|
|
20
|
+
from .glove import GloVeDataset
|
|
21
|
+
from .loading import load_dataset
|
|
22
|
+
from .numberbatch import ConceptNetNumberbatchDataset
|
|
23
|
+
|
|
24
|
+
__all__ = [
|
|
25
|
+
"AutoencoderDataset",
|
|
26
|
+
"CachedDataset",
|
|
27
|
+
"DatasetLoaders",
|
|
28
|
+
"DatasetSplits",
|
|
29
|
+
"EmbeddingMatrix",
|
|
30
|
+
"EmbeddingTensorDataset",
|
|
31
|
+
"FastTextEnglishDataset",
|
|
32
|
+
"GloVeDataset",
|
|
33
|
+
"ConceptNetNumberbatchDataset",
|
|
34
|
+
"create_dataloaders",
|
|
35
|
+
"default_cache_dir",
|
|
36
|
+
"load_dataset",
|
|
37
|
+
"load_embedding_artifact",
|
|
38
|
+
"load_text_embedding_matrix",
|
|
39
|
+
"save_embedding_artifact",
|
|
40
|
+
"split_dataset",
|
|
41
|
+
]
|
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
"""Base dataset abstractions for autoencoder training and evaluation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
import os
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
import sys
|
|
10
|
+
from typing import Callable
|
|
11
|
+
import urllib.request
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
from torch.utils.data import DataLoader, Dataset, Subset
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def default_cache_dir() -> Path:
    """Resolve the directory used to cache downloadable datasets.

    A non-empty ``AUTOENCODERS_CACHE`` environment variable takes precedence
    (with ``~`` expanded); otherwise ``~/.cache/autoencoders`` is returned.
    """

    override = os.environ.get("AUTOENCODERS_CACHE")
    if not override:
        # Unset or empty: fall back to the per-user cache location.
        return Path.home() / ".cache" / "autoencoders"
    return Path(override).expanduser()
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def format_num_bytes(num_bytes: int) -> str:
    """Format a byte count into a compact human-readable string.

    The value is divided by 1024 until it drops below 1024 or the largest
    unit (TB) is reached, then rendered with one decimal place, e.g.
    ``1536 -> "1.5KB"``. Values of 1024 TB and above stay expressed in TB.
    """

    value = float(num_bytes)
    for unit in ("B", "KB", "MB", "GB"):
        if value < 1024.0:
            return f"{value:.1f}{unit}"
        value /= 1024.0
    # Anything that survived four divisions is reported in the largest unit.
    return f"{value:.1f}TB"
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class DownloadProgressBar:
    """Single-line terminal progress indicator for dataset downloads.

    Each render rewrites the current line (leading ``\\r``). When a positive
    total size is known a 20-character bar with a percentage is shown;
    otherwise only the number of bytes received so far is printed.
    """

    def __init__(self, description: str, total_bytes: int | None, stream=None) -> None:
        """Record the label, expected size and output stream, then draw the initial state.

        ``stream`` defaults to ``sys.stderr`` when not supplied.
        """
        self.description = description
        self.total_bytes = total_bytes
        self.stream = stream if stream is not None else sys.stderr
        self.downloaded_bytes = 0
        self._finished = False
        self._render()

    def update(self, chunk_size: int) -> None:
        """Add ``chunk_size`` bytes to the running total and redraw."""
        self.downloaded_bytes += chunk_size
        self._render()

    def close(self) -> None:
        """Draw the final state followed by a newline; later calls do nothing."""
        if not self._finished:
            self._finished = True
            self._render(final=True)

    def _render(self, final: bool = False) -> None:
        """Write the current progress line to the stream and flush it."""
        received = format_num_bytes(self.downloaded_bytes)
        total = self.total_bytes

        if total is None or total <= 0:
            # Unknown size: only the amount received can be reported.
            message = f"\r{self.description} {received}"
        else:
            # Clamp so an under-reported total never overflows the bar.
            ratio = min(self.downloaded_bytes / total, 1.0)
            filled = int(ratio * 20)
            bar = "=" * filled + "." * (20 - filled)
            percent = int(ratio * 100)
            message = f"\r{self.description} [{bar}] {percent:3d}% " + f"{received}/{format_num_bytes(total)}"

        if final:
            message += "\n"

        self.stream.write(message)
        self.stream.flush()
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class AutoencoderDataset(Dataset[torch.Tensor]):
    """Thin ``Dataset`` base for tensor samples that also remembers its split name.

    The class itself only stores ``split``; it defines no item access.
    """

    # Name of the split this instance represents (defaults to "train").
    split: str

    def __init__(self, split: str = "train") -> None:
        """Record which split this dataset instance stands for."""
        self.split = split
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class CachedDataset(ABC):
    """Base class for datasets that download raw files and cache processed artifacts.

    Directory layout under ``root`` (``default_cache_dir()`` when omitted)::

        <root>/<dataset_name>/raw
        <root>/<dataset_name>/external
        <root>/<dataset_name>/processed

    Subclasses provide ``artifact_dir``, ``has_raw_data``, ``download`` and
    ``prepare``; this base class supplies the orchestration and a cached,
    atomic file download helper.
    """

    dataset_name = "dataset"

    def __init__(self, root: str | Path | None = None) -> None:
        """Derive the dataset directories from ``root`` without creating them."""
        self.root = Path(root) if root is not None else default_cache_dir()
        self.dataset_dir = self.root / self.dataset_name
        self.raw_dir = self.dataset_dir / "raw"
        self.external_dir = self.dataset_dir / "external"
        self.processed_dir = self.dataset_dir / "processed"

    def ensure_prepared(
        self,
        *,
        download: bool = True,
        force_download: bool = False,
        force_prepare: bool = False,
    ) -> Path:
        """Ensure the processed artifact exists and return its directory.

        Args:
            download: When true, call ``download()`` before preparing. When
                false, the raw data must already be present.
            force_download: Forwarded to ``download(force=...)``.
            force_prepare: Re-run ``prepare()`` even if the artifact exists.

        Raises:
            FileNotFoundError: If ``download`` is false and ``has_raw_data()``
                reports that the raw files are missing.
        """

        artifact_dir = self.artifact_dir
        if self.is_prepared() and not force_prepare:
            return artifact_dir

        if download:
            self.download(force=force_download)
        elif not self.has_raw_data():
            raise FileNotFoundError(
                f"Raw data for {self.dataset_name!r} is missing under {self.raw_dir}."
            )

        self.prepare()
        return artifact_dir

    def is_prepared(self) -> bool:
        """Return True when the processed artifact is complete and ready to use.

        The default implementation only checks that ``artifact_dir`` exists.
        """

        return self.artifact_dir.exists()

    def download_to_cache(
        self,
        *,
        url: str,
        destination: Path,
        validator: Callable[[Path], bool] | None = None,
        description: str | None = None,
        force: bool = False,
        chunk_size: int = 1024 * 1024,
    ) -> Path:
        """Download a file atomically with progress reporting and cache validation.

        The payload is streamed into ``<destination>.tmp`` and only moved onto
        ``destination`` once it is complete and (if a ``validator`` is given)
        valid, so an interrupted transfer never leaves a partial file at the
        final path.

        Args:
            url: Location passed to ``urllib.request.urlopen``.
            destination: Final path of the cached file.
            validator: Optional predicate; a cached or freshly downloaded file
                is accepted only when it returns true.
            description: Progress-bar label; defaults to the destination name.
            force: Re-download even when a cached file exists.
            chunk_size: Number of bytes requested per read.

        Returns:
            ``destination``.

        Raises:
            OSError: If the response advertised a ``Content-Length`` and fewer
                bytes were received.
            ValueError: If the downloaded file fails ``validator``.
        """

        destination.parent.mkdir(parents=True, exist_ok=True)
        temp_path = destination.with_name(f"{destination.name}.tmp")
        # Remove leftovers from an earlier interrupted run.
        self._cleanup_temp_file(temp_path)

        if destination.exists():
            if not force and (validator is None or validator(destination)):
                return destination
            # Forced refresh or a cached file that failed validation.
            destination.unlink()

        try:
            received = 0
            with urllib.request.urlopen(url) as response, temp_path.open("wb") as handle:
                total_bytes = self._response_content_length(response)
                progress = DownloadProgressBar(description or destination.name, total_bytes)
                try:
                    while True:
                        chunk = response.read(chunk_size)
                        if not chunk:
                            break
                        handle.write(chunk)
                        received += len(chunk)
                        progress.update(len(chunk))
                finally:
                    progress.close()

            # A connection that ends early can look like a normal end of
            # stream; never promote a short file into the cache.
            if total_bytes is not None and received < total_bytes:
                raise OSError(
                    f"Incomplete download of {url}: received {received} of {total_bytes} bytes."
                )
        except BaseException:
            # BaseException so that Ctrl-C also removes the partial file;
            # the original error is always re-raised.
            self._cleanup_temp_file(temp_path)
            raise

        if validator is not None and not validator(temp_path):
            self._cleanup_temp_file(temp_path)
            raise ValueError(f"Downloaded file {temp_path} failed validation.")

        temp_path.replace(destination)
        return destination

    @staticmethod
    def _cleanup_temp_file(path: Path) -> None:
        """Delete ``path`` if it exists; a missing file is not an error."""
        path.unlink(missing_ok=True)

    @staticmethod
    def _response_content_length(response) -> int | None:
        """Return the response's ``Content-Length`` as an int, or None if absent or unparsable."""
        length = response.headers.get("Content-Length")
        if length is None:
            return None
        try:
            return int(length)
        except (TypeError, ValueError):
            return None

    @property
    @abstractmethod
    def artifact_dir(self) -> Path:
        """Directory containing the processed dataset artifact."""

    @abstractmethod
    def has_raw_data(self) -> bool:
        """Return True when enough raw files are present to prepare the dataset."""

    @abstractmethod
    def download(self, *, force: bool = False) -> None:
        """Download the raw dataset files into the cache."""

    @abstractmethod
    def prepare(self) -> None:
        """Convert raw files into a processed artifact."""
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
@dataclass
class DatasetSplits:
    """Bundle of the train, validation and test partitions of one dataset."""

    # Partition used for fitting the model.
    train: Dataset[torch.Tensor]
    # Partition used for model selection / monitoring.
    validation: Dataset[torch.Tensor]
    # Held-out partition for final evaluation.
    test: Dataset[torch.Tensor]
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def split_dataset(
    dataset: Dataset[torch.Tensor],
    *,
    validation_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42,
) -> DatasetSplits:
    """Partition ``dataset`` into seeded, non-overlapping train/validation/test subsets.

    The indices are shuffled with a ``torch.Generator`` seeded by ``seed``.
    Validation and test sizes are the floor of ``len(dataset) * ratio``; the
    training subset receives every remaining example.

    Raises:
        ValueError: If either ratio is negative or their sum is 1.0 or more.
    """

    if validation_ratio < 0 or test_ratio < 0:
        raise ValueError("validation_ratio and test_ratio must be non-negative.")
    if validation_ratio + test_ratio >= 1.0:
        raise ValueError("validation_ratio + test_ratio must be less than 1.0.")

    total = len(dataset)
    rng = torch.Generator().manual_seed(seed)
    shuffled = torch.randperm(total, generator=rng).tolist()

    validation_count = int(total * validation_ratio)
    test_count = int(total * test_ratio)

    # Boundaries within the shuffled index list: [train | validation | test].
    validation_start = total - validation_count - test_count
    test_start = validation_start + validation_count

    return DatasetSplits(
        train=Subset(dataset, shuffled[:validation_start]),
        validation=Subset(dataset, shuffled[validation_start:test_start]),
        test=Subset(dataset, shuffled[test_start:]),
    )
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
@dataclass
class DatasetLoaders:
    """Bundle of the dataloaders for the train, validation and test partitions."""

    # Loader over the training partition.
    train: DataLoader[torch.Tensor]
    # Loader over the validation partition.
    validation: DataLoader[torch.Tensor]
    # Loader over the test partition.
    test: DataLoader[torch.Tensor]
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def create_dataloaders(
    splits: DatasetSplits,
    *,
    batch_size: int = 256,
    num_workers: int = 0,
) -> DatasetLoaders:
    """Wrap each split in a ``DataLoader`` that shares batch size and worker count.

    Only the training loader shuffles; validation and test keep their order.
    """

    shared = {"batch_size": batch_size, "num_workers": num_workers}

    train_loader = DataLoader(splits.train, shuffle=True, **shared)
    validation_loader = DataLoader(splits.validation, shuffle=False, **shared)
    test_loader = DataLoader(splits.test, shuffle=False, **shared)

    return DatasetLoaders(
        train=train_loader,
        validation=validation_loader,
        test=test_loader,
    )
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""Helpers for loading and packaging embedding matrices."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from .base import AutoencoderDataset
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class EmbeddingMatrix:
    """In-memory token embedding table: tokens, their vectors and a lookup index."""

    # Tokens, kept alongside ``matrix`` and ``token_to_index``.
    tokens: list[str]
    # Embedding values; the two leading dimensions are read by the properties below.
    matrix: torch.Tensor
    # Mapping from token to an integer index.
    token_to_index: dict[str, int]
    # Optional provenance information.
    source_path: str | None = None
    name: str | None = None

    @property
    def num_embeddings(self) -> int:
        """Size of the matrix's first dimension."""
        row_count = self.matrix.shape[0]
        return int(row_count)

    @property
    def embedding_dim(self) -> int:
        """Size of the matrix's second dimension."""
        column_count = self.matrix.shape[1]
        return int(column_count)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class EmbeddingTensorDataset(AutoencoderDataset):
    """Dataset view over an ``EmbeddingMatrix`` where every vector is one sample."""

    def __init__(self, embedding_matrix: EmbeddingMatrix, split: str = "train") -> None:
        """Keep a reference to ``embedding_matrix`` and record the split name."""
        super().__init__(split=split)
        self.embedding_matrix = embedding_matrix

    def __len__(self) -> int:
        """Number of samples, taken from the matrix's ``num_embeddings``."""
        sample_count = self.embedding_matrix.num_embeddings
        return sample_count

    def __getitem__(self, index: int) -> torch.Tensor:
        """Return the entry of the underlying matrix at ``index``."""
        vectors = self.embedding_matrix.matrix
        return vectors[index]
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def load_text_embedding_matrix(
    path: str | Path,
    *,
    max_vectors: int | None = None,
    expected_dim: int | None = None,
    skip_first_line: bool = False,
    dtype: torch.dtype = torch.float32,
) -> EmbeddingMatrix:
    """Load a whitespace-separated embedding text file such as GloVe.

    Each non-blank line is ``<token> <v1> <v2> ...``. The first field is the
    token and the remaining fields are parsed as floats.

    Args:
        path: Text file to read (UTF-8).
        max_vectors: Stop after this many rows; ``None`` reads the whole file.
        expected_dim: Required number of values per row. When ``None`` it is
            taken from the first data row and enforced for the rest.
        skip_first_line: Ignore line 1 (for formats that start with a header).
        dtype: Dtype of the resulting tensor.

    Returns:
        An ``EmbeddingMatrix`` whose ``name`` is the file stem and whose
        ``source_path`` is the given path as a string.

    Raises:
        ValueError: If ``max_vectors`` is not positive, a row has no values,
            a row's width differs from ``expected_dim``, a value is not
            numeric, or the file yields no rows at all.
    """

    # Reject a non-positive limit up front; otherwise the post-append check
    # below would still let exactly one row through.
    if max_vectors is not None and max_vectors < 1:
        raise ValueError(f"max_vectors must be a positive integer or None, got {max_vectors!r}.")

    file_path = Path(path)
    tokens: list[str] = []
    rows: list[list[float]] = []

    with file_path.open("r", encoding="utf-8") as handle:
        for line_number, line in enumerate(handle, start=1):
            if skip_first_line and line_number == 1:
                continue
            stripped = line.strip()
            if not stripped:
                continue

            parts = stripped.split()
            if len(parts) < 2:
                raise ValueError(f"Malformed embedding row at line {line_number}: {stripped!r}")

            token, *values = parts

            if expected_dim is None:
                # The first data row fixes the dimensionality for the file.
                expected_dim = len(values)
            elif len(values) != expected_dim:
                raise ValueError(
                    f"Embedding dimension mismatch at line {line_number}: "
                    f"expected {expected_dim}, got {len(values)}"
                )

            try:
                row = [float(value) for value in values]
            except ValueError as exc:
                # Point at the offending row instead of a bare float() error.
                raise ValueError(
                    f"Non-numeric embedding value at line {line_number} for token {token!r}"
                ) from exc

            tokens.append(token)
            rows.append(row)

            if max_vectors is not None and len(tokens) >= max_vectors:
                break

    if not rows:
        raise ValueError(f"No embeddings were loaded from {file_path}.")

    matrix = torch.tensor(rows, dtype=dtype)
    # For a repeated token the last occurrence wins.
    token_to_index = {token: index for index, token in enumerate(tokens)}
    return EmbeddingMatrix(
        tokens=tokens,
        matrix=matrix,
        token_to_index=token_to_index,
        source_path=str(file_path),
        name=file_path.stem,
    )
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def save_embedding_artifact(embedding_matrix: EmbeddingMatrix, output_dir: str | Path) -> Path:
    """Persist an embedding matrix as a three-file artifact and return its directory.

    Files written into ``output_dir`` (created when missing):

    * ``embeddings.pt`` - the matrix, via ``torch.save``;
    * ``tokens.txt`` - one token per line, UTF-8;
    * ``metadata.json`` - name, source path, row count and dimensionality.
    """

    save_dir = Path(output_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    torch.save(embedding_matrix.matrix, save_dir / "embeddings.pt")

    token_lines = "\n".join(embedding_matrix.tokens) + "\n"
    (save_dir / "tokens.txt").write_text(token_lines, encoding="utf-8")

    metadata = {
        "name": embedding_matrix.name,
        "source_path": embedding_matrix.source_path,
        "num_embeddings": embedding_matrix.num_embeddings,
        "embedding_dim": embedding_matrix.embedding_dim,
    }
    serialized = json.dumps(metadata, indent=2, sort_keys=True) + "\n"
    (save_dir / "metadata.json").write_text(serialized, encoding="utf-8")

    return save_dir
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def load_embedding_artifact(path: str | Path) -> EmbeddingMatrix:
    """Load a processed embedding artifact created by save_embedding_artifact().

    Reads ``embeddings.pt``, ``tokens.txt`` and ``metadata.json`` from the
    directory ``path`` and rebuilds the token-to-index mapping.

    Raises:
        ValueError: If the number of tokens does not match the number of rows
            in the stored matrix, which indicates a truncated or mixed-up
            artifact.
    """

    load_dir = Path(path)
    tensor_path = load_dir / "embeddings.pt"
    tokens_path = load_dir / "tokens.txt"
    metadata_path = load_dir / "metadata.json"

    # SECURITY NOTE: torch.load unpickles the file. Only load artifacts from a
    # trusted cache; consider ``weights_only=True`` where the installed torch
    # version supports it.
    matrix = torch.load(tensor_path, map_location="cpu")
    tokens = tokens_path.read_text(encoding="utf-8").splitlines()

    num_rows = int(matrix.shape[0])
    if num_rows == 0 and tokens == [""]:
        # An empty token list is written as a lone newline, which reads back
        # as one empty string.
        tokens = []
    if len(tokens) != num_rows:
        raise ValueError(
            f"Corrupt embedding artifact at {load_dir}: "
            f"{len(tokens)} tokens for {num_rows} embedding rows."
        )

    metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
    token_to_index = {token: index for index, token in enumerate(tokens)}

    return EmbeddingMatrix(
        tokens=tokens,
        matrix=matrix,
        token_to_index=token_to_index,
        source_path=metadata.get("source_path"),
        name=metadata.get("name"),
    )
|