boltzmann9 0.1.4-py3-none-any.whl → 0.1.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltzmann9/__init__.py +38 -0
- boltzmann9/__main__.py +4 -0
- boltzmann9/cli.py +389 -0
- boltzmann9/config.py +58 -0
- boltzmann9/data.py +145 -0
- boltzmann9/data_generator.py +234 -0
- boltzmann9/model.py +867 -0
- boltzmann9/pipeline.py +216 -0
- boltzmann9/preprocessor.py +627 -0
- boltzmann9/project.py +195 -0
- boltzmann9/run_utils.py +262 -0
- boltzmann9/tester.py +167 -0
- boltzmann9/utils.py +42 -0
- boltzmann9/visualization.py +115 -0
- {boltzmann9-0.1.4.dist-info → boltzmann9-0.1.7.dist-info}/METADATA +1 -1
- boltzmann9-0.1.7.dist-info/RECORD +19 -0
- boltzmann9-0.1.7.dist-info/top_level.txt +1 -0
- boltzmann9-0.1.4.dist-info/RECORD +0 -5
- boltzmann9-0.1.4.dist-info/top_level.txt +0 -1
- {boltzmann9-0.1.4.dist-info → boltzmann9-0.1.7.dist-info}/WHEEL +0 -0
- {boltzmann9-0.1.4.dist-info → boltzmann9-0.1.7.dist-info}/entry_points.txt +0 -0

boltzmann9/preprocessor.py (new file)
@@ -0,0 +1,627 @@
+from __future__ import annotations
+
+import json
+import os
+from dataclasses import dataclass, asdict
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import pandas as pd
+
+try:
+    import torch
+except Exception:
+    torch = None
+
+
+# -----------------------------
+# Helpers: device
+# -----------------------------
+def resolve_device(device_cfg: Optional[str]) -> str:
+    """
+    Resolve the device string deterministically. Returns a torch-style device string.
+    """
+    if device_cfg is None:
+        return "cpu"
+    device_cfg = str(device_cfg).lower().strip()
+    if device_cfg in ("cpu", "cuda", "cuda:0", "cuda:1", "mps"):
+        return device_cfg
+    if device_cfg != "auto":
+        # accept arbitrary explicit torch device strings like "cuda:2"
+        return str(device_cfg)
+
+    # auto
+    if torch is None:
+        return "cpu"
+    if torch.cuda.is_available():
+        return "cuda:0"
+    # mps availability check
+    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        return "mps"
+    return "cpu"
+
+
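
Note: the mapping implied by the branches above, for quick reference (illustrative values):

    resolve_device(None)      # -> "cpu"
    resolve_device("CUDA:2")  # -> "cuda:2" (explicit device strings pass through, lowercased)
    resolve_device("auto")    # -> "cuda:0" if CUDA is available, else "mps" if available, else "cpu"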
+# -----------------------------
+# Helpers: parsing "path/to.csv/col"
+# -----------------------------
+def parse_source_key(key: str) -> Tuple[str, str]:
+    """
+    Parse "path/to/file.csv/column_name" -> (csv_path, column_name).
+
+    Notes:
+    - This assumes the column name does not contain "/".
+    - Windows paths are supported if they use "/" in the config. If you need
+      backslashes, use raw strings and still include a final "/" before the column name.
+    """
+    key = str(key).strip()
+    if "/" not in key:
+        raise ValueError(
+            f"Invalid visible_blocks key '{key}'. Expected 'path/to.csv/column_name'."
+        )
+    csv_path, col = key.rsplit("/", 1)
+    csv_path = csv_path.strip()
+    col = col.strip()
+    if not csv_path or not col:
+        raise ValueError(
+            f"Invalid visible_blocks key '{key}'. Parsed csv_path='{csv_path}', col='{col}'."
+        )
+    return csv_path, col
+
+
+def safe_feature_name(source_key: str) -> str:
+    """
+    Create a stable, filesystem-safe, column-safe feature name from a source key.
+    Example: "data/raw_1.csv/price" -> "data__raw_1.csv__price"
+    """
+    # Keep it readable and deterministic
+    return (
+        source_key.replace("\\", "/")
+        .replace("/", "__")
+        .replace(" ", "_")
+        .replace(":", "_")
+    )
+
+
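
Note: both helpers are pure string transforms; the round trip from config key to stable column name (illustrative):

    parse_source_key("data/raw_1.csv/price")   # -> ("data/raw_1.csv", "price")
    safe_feature_name("data/raw_1.csv/price")  # -> "data__raw_1.csv__price"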
+# -----------------------------
+# Bit packing utilities
+# -----------------------------
+def int_to_bits_msb_first(x: np.ndarray, nbits: int) -> np.ndarray:
+    """
+    x: (N,) uint64
+    returns: (N, nbits) uint8 bits, MSB first
+    """
+    x = x.astype(np.uint64, copy=False)
+    shifts = np.arange(nbits - 1, -1, -1, dtype=np.uint64)
+    return ((x[:, None] >> shifts) & 1).astype(np.uint8)
+
+
+def bits_to_int_msb_first(bits: np.ndarray) -> np.ndarray:
+    """
+    bits: (N, nbits) uint8, MSB first
+    returns: (N,) uint64
+    """
+    bits = bits.astype(np.uint64, copy=False)
+    nbits = bits.shape[1]
+    shifts = np.arange(nbits - 1, -1, -1, dtype=np.uint64)
+    return (bits << shifts).sum(axis=1)
+
+
+def binary_to_gray(q: np.ndarray) -> np.ndarray:
+    q = q.astype(np.uint64, copy=False)
+    return (q ^ (q >> 1)).astype(np.uint64)
+
+
+def gray_to_binary(g: np.ndarray) -> np.ndarray:
+    g = g.astype(np.uint64, copy=False)
+    q = g.copy()
+    shift = 1
+    while True:
+        shifted = q >> shift
+        if np.all(shifted == 0):
+            break
+        q ^= shifted
+        shift <<= 1
+    return q
+
+
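
Note: Gray coding guarantees that adjacent quantization levels differ in exactly one bit, avoiding multi-bit jumps between neighboring values. A minimal round-trip check using the four helpers above (illustrative):

    import numpy as np
    q = np.arange(8, dtype=np.uint64)
    g = binary_to_gray(q)                # [0, 1, 3, 2, 6, 7, 5, 4]
    bits = int_to_bits_msb_first(g, 3)   # shape (8, 3), one row of Gray bits per value
    assert np.array_equal(gray_to_binary(bits_to_int_msb_first(bits)), q)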
+# -----------------------------
+# Specs / metadata
+# -----------------------------
+@dataclass
+class FloatGraySpec:
+    source_key: str
+    feature_name: str
+    nbits: int
+    low: float
+    high: float
+    q_low: float
+    q_high: float
+    add_missing_bit: bool = True
+
+
+@dataclass
+class BinarySpec:
+    source_key: str
+    feature_name: str
+    add_missing_bit: bool = True
+
+
+@dataclass
+class CategoricalSpec:
+    source_key: str
+    feature_name: str
+    categories: List[str]
+    add_unk: bool = True
+    add_missing: bool = True
+
+
+# -----------------------------
+# Main preprocessor
+# -----------------------------
+class DataPreprocessor:
+    """
+    Industry-grade RBM preprocessor:
+    - Reads multiple raw CSV files based on config["model"]["visible_blocks"]
+    - Auto-detects types: float -> Gray K-bit, categorical -> one-hot, binary -> passthrough
+    - Adds missing-indicator bits (default)
+    - Writes the processed dataset to config["data"]["csv_path"]
+    - Writes metadata JSON next to the output CSV
+    - Resolves relative paths against the provided config_dir
+    """
+
+    def __init__(
+        self,
+        config: Dict[str, Any],
+        *,
+        config_dir: Optional[Union[str, os.PathLike]] = None,
+    ) -> None:
+        self.config = config
+        self.config_dir = Path(config_dir).resolve() if config_dir is not None else None
+
+        self.device_str: str = resolve_device(config.get("device", "auto"))
+
+        data_cfg = config.get("data", {})
+        model_cfg = config.get("model", {})
+
+        raw_output_path = data_cfg.get("csv_path") or "data/processed.csv"
+        self.output_csv_path = self._resolve_path(raw_output_path)
+        self.drop_cols: List[str] = list(data_cfg.get("drop_cols", []))
+
+        self.bm_type = str(model_cfg.get("bm_type", "rbm")).lower().strip()
+        if self.bm_type != "rbm":
+            raise ValueError(f"Only bm_type='rbm' is supported, got: {self.bm_type}")
+
+        # visible_blocks: { "path/to/raw.csv/col": K_i, ... }
+        raw_visible = model_cfg.get("visible_blocks", {})
+        if not isinstance(raw_visible, dict) or len(raw_visible) == 0:
+            raise ValueError("model.visible_blocks must be a non-empty dict.")
+
+        # Preprocessing params (optional; safe defaults)
+        prep_cfg = config.get("preprocess", {})  # optional block (not required)
+        self.q_low: float = float(prep_cfg.get("q_low", 0.001))
+        self.q_high: float = float(prep_cfg.get("q_high", 0.999))
+        self.add_missing_bit: bool = bool(prep_cfg.get("add_missing_bit", True))
+        self.max_categories: int = int(prep_cfg.get("max_categories", 200))
+        self.min_category_freq: int = int(prep_cfg.get("min_category_freq", 1))
+        self.force_float: bool = bool(prep_cfg.get("force_float", False))
+        # if force_float=True, numeric columns are always treated as float quantized (unless binary)
+
+        # Store parsed requests
+        self.requested: List[Tuple[str, str, int]] = []  # (source_key, col, K)
+        for source_key, K in raw_visible.items():
+            csv_path, col = parse_source_key(source_key)
+            try:
+                K_i = int(K)
+            except Exception as e:
+                raise ValueError(f"K_i for '{source_key}' must be int-like, got {K}") from e
+            if K_i <= 0:
+                raise ValueError(f"K_i for '{source_key}' must be >= 1, got {K_i}")
+            self.requested.append((source_key, col, K_i))
+
+        # Will be filled after fit_transform
+        self.float_specs: List[FloatGraySpec] = []
+        self.bin_specs: List[BinarySpec] = []
+        self.cat_specs: List[CategoricalSpec] = []
+
+        self.visible_blocks_out: Dict[str, List[str]] = {}  # feature -> produced bit columns
+        self.processed_columns: List[str] = []
+
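
Note: the constructor expects a config shaped roughly as follows (a hypothetical sketch; paths and bit widths are illustrative, and only model.visible_blocks is strictly required):

    config = {
        "device": "auto",
        "data": {"csv_path": "data/processed.csv", "drop_cols": []},
        "model": {
            "bm_type": "rbm",
            "visible_blocks": {
                "data/raw_1.csv/price": 8,      # float column -> 8 Gray bits (+ missing bit)
                "data/raw_1.csv/is_active": 1,  # 0/1 column -> 1 passthrough bit (+ missing bit)
            },
        },
        "preprocess": {"q_low": 0.001, "q_high": 0.999, "max_categories": 200},
    }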
+    def _resolve_path(self, path_like: Union[str, os.PathLike]) -> Path:
+        """
+        Resolve a path relative to the config directory if provided.
+        Absolute paths pass through unchanged.
+        """
+        p = Path(path_like)
+        if p.is_absolute() or self.config_dir is None:
+            return p
+        return (self.config_dir / p).resolve()
+
+    # -----------------------------
+    # I/O: load raw columns
+    # -----------------------------
+    def _load_raw_dataframe(self) -> pd.DataFrame:
+        """
+        Load all requested columns from their CSV files and merge them into one dataframe.
+        Validates consistent row counts across files.
+        """
+        # Group by csv path to avoid re-reading files
+        by_file: Dict[Path, List[Tuple[str, str, int]]] = {}
+        for source_key, col, k in self.requested:
+            csv_path, _ = parse_source_key(source_key)
+            resolved_csv = self._resolve_path(csv_path)
+            by_file.setdefault(resolved_csv, []).append((source_key, col, k))
+
+        merged_parts: List[pd.DataFrame] = []
+        expected_len: Optional[int] = None
+
+        for csv_path, items in by_file.items():
+            p = csv_path
+            if not p.exists():
+                raise FileNotFoundError(f"Raw CSV not found: {p}")
+
+            # Read only the required columns
+            usecols = list({col for (_, col, _) in items})
+            df = pd.read_csv(p, usecols=usecols)
+
+            if expected_len is None:
+                expected_len = len(df)
+            else:
+                if len(df) != expected_len:
+                    raise ValueError(
+                        f"Row count mismatch: file '{p}' has {len(df)} rows, expected {expected_len}. "
+                        f"Align datasets before preprocessing (same order / same rows)."
+                    )
+
+            # Rename raw columns to stable feature names (based on the full source_key)
+            renamed = {}
+            for source_key, col, _k in items:
+                renamed[col] = safe_feature_name(source_key)
+            df = df.rename(columns=renamed)
+
+            merged_parts.append(df)
+
+        merged = pd.concat(merged_parts, axis=1)
+
+        # drop requested columns if the user asked (in output space)
+        for c in self.drop_cols:
+            if c in merged.columns:
+                merged = merged.drop(columns=[c])
+
+        return merged
+
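
Note: because raw columns are renamed using the full source key, same-named columns from different files cannot collide after the merge (illustrative):

    # "data/a.csv/price" -> merged column "data__a.csv__price"
    # "data/b.csv/price" -> merged column "data__b.csv__price"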
+    # -----------------------------
+    # Type detection
+    # -----------------------------
+    @staticmethod
+    def _is_binary_series(s: pd.Series) -> bool:
+        """
+        True if (ignoring NaNs) the values are a subset of {0, 1} or {False, True}.
+        """
+        x = s.dropna()
+        if x.empty:
+            return False
+        # Try numeric interpretation
+        if pd.api.types.is_bool_dtype(x):
+            return True
+        if pd.api.types.is_numeric_dtype(x):
+            vals = set(pd.unique(x.astype(float)))
+            return vals.issubset({0.0, 1.0})
+        return False
+
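
Note: some boundary cases of the detection above (illustrative):

    import pandas as pd
    DataPreprocessor._is_binary_series(pd.Series([0, 1, None]))     # True  (NaNs are ignored)
    DataPreprocessor._is_binary_series(pd.Series([0.0, 2.0]))       # False (falls through to the float path)
    DataPreprocessor._is_binary_series(pd.Series([], dtype=float))  # False (empty after dropna)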
+    # -----------------------------
+    # Float preprocessing: quantize + Gray bits
+    # -----------------------------
+    def _fit_float_range(self, x: np.ndarray) -> Tuple[float, float]:
+        x = x[np.isfinite(x)]
+        if x.size == 0:
+            raise ValueError("Cannot fit float range: all values are NaN/inf.")
+        lo = float(np.quantile(x, self.q_low))
+        hi = float(np.quantile(x, self.q_high))
+        if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
+            # fallback: min/max if quantiles collapse
+            lo = float(np.min(x))
+            hi = float(np.max(x))
+            if hi <= lo:
+                raise ValueError(f"Degenerate float range after fallback: lo={lo}, hi={hi}")
+        return lo, hi
+
+    def _encode_float_gray(
+        self,
+        s: pd.Series,
+        nbits: int,
+        low: float,
+        high: float,
+        feature_name: str,
+        add_missing_bit: bool,
+    ) -> pd.DataFrame:
+        x = s.to_numpy(dtype=np.float64, copy=False)
+
+        miss = ~np.isfinite(x)
+        x_filled = x.copy()
+        # midpoint fill (but missing indicator makes it safe)
+        x_filled[miss] = 0.5 * (low + high)
+
+        # clip
+        x_filled = np.clip(x_filled, low, high)
+
+        # normalize -> quantize
+        denom = (high - low)
+        u = (x_filled - low) / denom  # in [0, 1]
+        qmax = (1 << nbits) - 1
+        q = np.rint(u * qmax).astype(np.uint64)
+        q = np.clip(q, 0, qmax).astype(np.uint64)
+
+        g = binary_to_gray(q)
+        bits = int_to_bits_msb_first(g, nbits)  # (N, nbits)
+
+        col_bits = [f"{feature_name}__g{j:02d}" for j in range(nbits)]
+        out = pd.DataFrame(bits, columns=col_bits, index=s.index, dtype=np.uint8)
+
+        if add_missing_bit:
+            out[f"{feature_name}__missing"] = miss.astype(np.uint8)
+
+        return out
+
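
Note: worked through by hand for nbits=2, low=0.0, high=1.0, so qmax=3 (values illustrative):

    # x = 0.40: u = 0.40, q = rint(1.2) = 1, gray(1) = 0b01 -> bits [0, 1], __missing = 0
    # x = NaN : midpoint-filled to 0.50, q = rint(1.5) = 2, gray(2) = 0b11 -> bits [1, 1], __missing = 1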
+    # -----------------------------
+    # Binary preprocessing
+    # -----------------------------
+    def _encode_binary(
+        self,
+        s: pd.Series,
+        feature_name: str,
+        add_missing_bit: bool,
+    ) -> pd.DataFrame:
+        miss = s.isna().to_numpy(dtype=bool)
+
+        # strict 0/1
+        if pd.api.types.is_bool_dtype(s):
+            v = s.fillna(False).astype(bool).to_numpy(dtype=np.uint8)
+        else:
+            v = s.fillna(0).astype(float).to_numpy()
+            # values other than 0/1 are suspicious; round and validate
+            # (no clipping first: it would silently map out-of-range values to 0/1)
+            v_round = np.rint(v).astype(np.int64)
+            # validate (after rounding)
+            bad = ~np.isin(v_round, [0, 1])
+            if bad.any():
+                idx = np.where(bad)[0][:5]
+                raise ValueError(
+                    f"Binary column '{feature_name}' has non-binary values at rows {idx.tolist()}."
+                )
+            v = v_round.astype(np.uint8)
+
+        out = pd.DataFrame(index=s.index)
+        out[f"{feature_name}__bin"] = v.astype(np.uint8)
+
+        if add_missing_bit:
+            out[f"{feature_name}__missing"] = miss.astype(np.uint8)
+
+        return out
+
+    # -----------------------------
+    # Categorical preprocessing: one-hot (+UNK +MISSING)
+    # -----------------------------
+    def _fit_categories(self, s: pd.Series) -> List[str]:
+        s_obj = s.astype("object")
+        miss = s_obj.isna()
+        vc = s_obj[~miss].value_counts()
+
+        # filter by frequency
+        kept = vc[vc >= self.min_category_freq].index.astype(str).tolist()
+
+        # cap categories (value_counts sorts by frequency, so the most common are kept)
+        if len(kept) > self.max_categories:
+            kept = kept[: self.max_categories]
+
+        return kept
+
+    def _encode_categorical(
+        self,
+        s: pd.Series,
+        categories: List[str],
+        feature_name: str,
+        add_unk: bool = True,
+        add_missing: bool = True,
+    ) -> pd.DataFrame:
+        s_obj = s.astype("object")
+        miss = s_obj.isna()
+
+        cols = [f"{feature_name}__{c}" for c in categories]
+        if add_unk:
+            cols.append(f"{feature_name}__UNK")
+        if add_missing:
+            cols.append(f"{feature_name}__MISSING")
+
+        out = pd.DataFrame(0, index=s.index, columns=cols, dtype=np.uint8)
+        cat_set = set(categories)
+
+        # set one-hot
+        for idx, val in s_obj[~miss].items():
+            v = str(val)
+            if v in cat_set:
+                out.at[idx, f"{feature_name}__{v}"] = 1
+            else:
+                if add_unk:
+                    out.at[idx, f"{feature_name}__UNK"] = 1
+
+        if add_missing:
+            out.loc[miss, f"{feature_name}__MISSING"] = 1
+
+        return out
+
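
Note: for a feature f with fitted categories ["red", "blue"], the encoder produces a block like this (illustrative):

    # value     f__red  f__blue  f__UNK  f__MISSING
    # "red"          1        0       0           0
    # "green"        0        0       1           0   (unseen category -> UNK)
    # NaN            0        0       0           1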
+    # -----------------------------
+    # Main pipeline
+    # -----------------------------
+    def fit_transform(self) -> pd.DataFrame:
+        """
+        Fits the necessary per-feature parameters (ranges/categories) on the loaded raw data,
+        then transforms it into RBM-ready bits, writes CSV + metadata, and returns the processed dataframe.
+        """
+        raw = self._load_raw_dataframe()
+
+        self.float_specs.clear()
+        self.bin_specs.clear()
+        self.cat_specs.clear()
+        self.visible_blocks_out.clear()
+
+        blocks: List[pd.DataFrame] = []
+
+        # Process each requested feature in config order.
+        # Reconstruct the "source_key -> feature_name" mapping first.
+        source_to_feature: Dict[str, str] = {}
+        source_to_bits: Dict[str, int] = {}  # (unused below)
+        for source_key, _col, k in self.requested:
+            fn = safe_feature_name(source_key)
+            source_to_feature[source_key] = fn
+            source_to_bits[source_key] = int(k)
+
+        # Process in stable order: as provided in visible_blocks
+        for source_key, _col, k in self.requested:
+            feature_name = source_to_feature[source_key]
+            if feature_name not in raw.columns:
+                raise KeyError(
+                    f"Expected column '{feature_name}' from source '{source_key}' not found in merged dataframe."
+                )
+
+            s = raw[feature_name]
+
+            # Decide type
+            if self._is_binary_series(s):
+                # binary
+                encoded = self._encode_binary(s, feature_name, add_missing_bit=self.add_missing_bit)
+                self.bin_specs.append(BinarySpec(source_key=source_key, feature_name=feature_name, add_missing_bit=self.add_missing_bit))
+                self.visible_blocks_out[feature_name] = list(encoded.columns)
+                blocks.append(encoded)
+                continue
+
+            if pd.api.types.is_object_dtype(s) or isinstance(s.dtype, pd.CategoricalDtype):
+                # categorical
+                cats = self._fit_categories(s)
+                encoded = self._encode_categorical(
+                    s,
+                    categories=cats,
+                    feature_name=feature_name,
+                    add_unk=True,
+                    add_missing=True,
+                )
+                self.cat_specs.append(
+                    CategoricalSpec(
+                        source_key=source_key,
+                        feature_name=feature_name,
+                        categories=cats,
+                        add_unk=True,
+                        add_missing=True,
+                    )
+                )
+                self.visible_blocks_out[feature_name] = list(encoded.columns)
+                blocks.append(encoded)
+                continue
+
+            # numeric non-binary -> float quantize (Gray) using K_i = k
+            if not pd.api.types.is_numeric_dtype(s):
+                # last resort: treat as categorical
+                cats = self._fit_categories(s.astype("object"))
+                encoded = self._encode_categorical(
+                    s.astype("object"),
+                    categories=cats,
+                    feature_name=feature_name,
+                    add_unk=True,
+                    add_missing=True,
+                )
+                self.cat_specs.append(
+                    CategoricalSpec(
+                        source_key=source_key,
+                        feature_name=feature_name,
+                        categories=cats,
+                        add_unk=True,
+                        add_missing=True,
+                    )
+                )
+                self.visible_blocks_out[feature_name] = list(encoded.columns)
+                blocks.append(encoded)
+                continue
+
+            # float encoding
+            x = s.to_numpy(dtype=np.float64, copy=False)
+            low, high = self._fit_float_range(x)
+
+            encoded = self._encode_float_gray(
+                s=s,
+                nbits=int(k),
+                low=low,
+                high=high,
+                feature_name=feature_name,
+                add_missing_bit=self.add_missing_bit,
+            )
+
+            self.float_specs.append(
+                FloatGraySpec(
+                    source_key=source_key,
+                    feature_name=feature_name,
+                    nbits=int(k),
+                    low=float(low),
+                    high=float(high),
+                    q_low=float(self.q_low),
+                    q_high=float(self.q_high),
+                    add_missing_bit=self.add_missing_bit,
+                )
+            )
+            self.visible_blocks_out[feature_name] = list(encoded.columns)
+            blocks.append(encoded)
+
+        X = pd.concat(blocks, axis=1)
+        # enforce strict {0,1} uint8
+        X = X.astype(np.uint8)
+
+        # Validate
+        if not ((X.values == 0) | (X.values == 1)).all():
+            raise ValueError("Processed dataset contains values outside {0,1}.")
+
+        self.processed_columns = list(X.columns)
+
+        # Write outputs
+        self._export(X)
+
+        return X
+
+    def _export(self, X: pd.DataFrame) -> None:
+        self.output_csv_path.parent.mkdir(parents=True, exist_ok=True)
+        X.to_csv(self.output_csv_path, index=False)
+
+        meta = self.export_metadata()
+        meta_path = self.output_csv_path.with_suffix(self.output_csv_path.suffix + ".meta.json")
+        with open(meta_path, "w", encoding="utf-8") as f:
+            json.dump(meta, f, indent=2, ensure_ascii=False)
+
+    def export_metadata(self) -> Dict[str, Any]:
+        """
+        Metadata includes enough to:
+        - audit preprocessing
+        - rebuild visible_blocks sizes
+        - decode floats approximately if desired (range + bits)
+        """
+        return {
+            "type": "DataPreprocessor",
+            "version": 1,
+            "device": self.device_str,
+            "bm_type": self.bm_type,
+            "output_csv_path": str(self.output_csv_path),
+            "q_low": self.q_low,
+            "q_high": self.q_high,
+            "add_missing_bit": self.add_missing_bit,
+            "max_categories": self.max_categories,
+            "min_category_freq": self.min_category_freq,
+            "float_specs": [asdict(s) for s in self.float_specs],
+            "binary_specs": [asdict(s) for s in self.bin_specs],
+            "categorical_specs": [asdict(s) for s in self.cat_specs],
+            "visible_blocks_out": self.visible_blocks_out,  # feature_name -> [bit cols]
+            "visible_blocks_sizes": {k: len(v) for k, v in self.visible_blocks_out.items()},
+            "processed_columns": self.processed_columns,
+        }
+
+    # Convenience: build RBM-ready visible_blocks dict
+    def get_visible_blocks_sizes(self) -> Dict[str, int]:
+        if not self.visible_blocks_out:
+            raise RuntimeError("Run fit_transform() first.")
+        return {k: len(v) for k, v in self.visible_blocks_out.items()}
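
Note: end to end, the intended call pattern appears to be the following (a sketch assuming the hypothetical config above; paths are illustrative):

    from boltzmann9.preprocessor import DataPreprocessor

    pre = DataPreprocessor(config, config_dir="configs/")
    X = pre.fit_transform()                 # writes data/processed.csv and data/processed.csv.meta.json
    sizes = pre.get_visible_blocks_sizes()  # e.g. {"data__raw_1.csv__price": 9, ...} (8 Gray bits + missing)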