boltzmann9 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltzmann9/project.py +1 -1
- boltzmann9/templates/__init__.py +1 -0
- boltzmann9/templates/config.py +65 -0
- boltzmann9/templates/synthetic_generator.py +147 -0
- {boltzmann9-0.1.7.dist-info → boltzmann9-0.1.9.dist-info}/METADATA +1 -1
- {boltzmann9-0.1.7.dist-info → boltzmann9-0.1.9.dist-info}/RECORD +9 -6
- {boltzmann9-0.1.7.dist-info → boltzmann9-0.1.9.dist-info}/WHEEL +0 -0
- {boltzmann9-0.1.7.dist-info → boltzmann9-0.1.9.dist-info}/entry_points.txt +0 -0
- {boltzmann9-0.1.7.dist-info → boltzmann9-0.1.9.dist-info}/top_level.txt +0 -0
boltzmann9/project.py
CHANGED
|
@@ -11,7 +11,7 @@ from typing import Any, Dict
|
|
|
11
11
|
|
|
12
12
|
def _get_templates_dir() -> Path:
    """Return the directory that holds the project template files.

    The templates are shipped inside this package, in a ``templates``
    sub-directory located next to this module.
    """
    package_dir = Path(__file__).parent
    return package_dir.joinpath("templates")
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
def _load_config_template(data_path: str) -> str:
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Template files for project creation
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""RBM experiment configuration."""
|
|
2
|
+
|
|
3
|
+
# NOTE(review): this file is a template; project._load_config_template()
# returns a str and may rewrite it textually (e.g. to fill in the data
# path). Keep the key/value lines in their current textual form.
config = {
    "device": "auto", # "auto" | "cpu" | "cuda:0" | "mps"

    # --- Input data ---------------------------------------------------------
    "data": {
        "csv_path": "data.csv", # Path to training data CSV
        # Columns removed before training. NOTE(review): 'x' is presumably the
        # decision column produced by the synthetic generator template —
        # confirm against the preprocessing code.
        "drop_cols": ['x'],
    },

    # --- Model architecture -------------------------------------------------
    "model": {
        "bm_type": "rbm", # Currently only RBM is supported
        # Named visible blocks: 4 + 4 = 8 visible units in total.
        "visible_blocks": {"v1": 4, "v2": 4},
        # Named hidden blocks: 5 + 10 + 5 = 20 hidden units in total.
        "hidden_blocks": {"h1": 5, "h":10, "h2": 5},
        # (visible block, hidden block) pairs. NOTE(review): presumably these
        # pairs are left unconnected; block "h" appears in no pair. Confirm
        # in the model code.
        "cross_block_restrictions": [("v1", "h2"), ("v2", "h1")],
        "initialization": "random",
    },

    # --- Preprocessing ------------------------------------------------------
    # NOTE(review): q_low / q_high look like clipping quantiles and
    # max_categories / min_category_freq like categorical-encoding limits —
    # confirm in the preprocessor.
    "preprocess": {
        "q_low": 0.001,
        "q_high": 0.999,
        "add_missing_bit": True,
        "max_categories": 200,
        "min_category_freq": 2,
    },

    # --- Data loading -------------------------------------------------------
    "dataloader": {
        "batch_size": 256,
        # Three fractions summing to 1.0 (presumably train / val / test).
        "split": [0.8, 0.1, 0.1],
        "seed": 42,
        "shuffle_train": True,
        "num_workers": 0,
        "drop_last_train": True,
        "pin_memory": "auto", # "auto" | True | False
    },

    # --- Training -----------------------------------------------------------
    "train": {
        "epochs": 100,
        "lr": 1e-1,
        # NOTE(review): k / kind presumably choose the negative-phase sampler
        # (10 steps, mean-field) — confirm in the training code.
        "k": 10,
        "kind": "mean-field",
        "momentum": 0.9,
        "weight_decay": 1e-4,
        "clip_value": 0.05,
        "clip_norm": 5.0,
        # Cosine schedule; min_lr is presumably the floor the rate decays to.
        "lr_schedule": {"mode": "cosine", "min_lr": 1e-4},
        # Sparsity regularisation. NOTE(review): rho presumably is the target
        # mean hidden activation and lambda_sparse its penalty weight.
        "sparse_hidden": True,
        "rho": 0.1,
        "lambda_sparse": 0.01,
        # NOTE(review): es_patience presumably counts epochs without
        # improvement before training stops.
        "early_stopping": True,
        "es_patience": 8,
    },

    # --- Evaluation ---------------------------------------------------------
    "eval": {
        "recon_k": 1,
    },

    # --- Conditional sampling -----------------------------------------------
    # clamp_idx (0-3) and target_idx (4-7) together cover indices 0..7, which
    # matches the 8 visible units declared above. NOTE(review): presumably
    # these are positions in the visible vector (clamp the first block,
    # sample the second), and burn_in / thin presumably control how many
    # sampler steps are discarded / skipped between kept samples.
    "conditional": {
        "clamp_idx": [0, 1, 2, 3],
        "target_idx": [4, 5, 6, 7],
        "n_samples": 100,
        "burn_in": 500,
        "thin": 10,
    },
}
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
"""Generate synthetic data for this project.
|
|
3
|
+
|
|
4
|
+
This script is auto-generated from the BoltzmaNN9 template.
|
|
5
|
+
Modify the GeneratorConfig parameters below to customize data generation.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import sys
from pathlib import Path

try:
    # Installed package (e.g. ``pip install boltzmann9``): the wheel ships
    # ``boltzmann9/data_generator.py`` as a top-level package, so no path
    # manipulation is needed.
    from boltzmann9.data_generator import SyntheticDataGenerator, GeneratorConfig
except ModuleNotFoundError as exc:
    # Only fall back when the package itself is absent; a missing dependency
    # *inside* boltzmann9 (e.g. numpy) must surface as the real error instead
    # of being masked by the fallback import failing.
    if exc.name not in ("boltzmann9", "boltzmann9.data_generator"):
        raise
    # Source-checkout fallback (original behaviour): put the directory three
    # ``parent`` steps above this file on sys.path and import from the
    # ``src/`` layout.
    sys.path.insert(0, str(Path(__file__).parent.parent.parent))
    from src.boltzmann9.data_generator import SyntheticDataGenerator, GeneratorConfig
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def visualize_data(full_df, generator, save_path, show_from=0, show_to=500):
    """Create and save a 3x2 diagnostic figure of the generated data.

    Args:
        full_df: DataFrame with the columns read below: ``t``,
            ``R_continuous``, ``R``, ``R_bin_index``, ``R_bit_0`` ..
            ``R_bit_{n_bits-1}`` and ``x``.
        generator: object exposing ``config.k_bins`` and ``n_bits``.
        save_path: destination passed to ``Figure.savefig``.
        show_from: first row (positional) of the window used by the
            time-series panels 1, 3 and 4; clamped to be >= 0.
        show_to: end row (positional, exclusive) of that window; clamped to
            ``len(full_df)``.

    Returns:
        None. If matplotlib/numpy cannot be imported a message is printed and
        nothing is written.
    """
    try:
        import matplotlib.pyplot as plt
        import numpy as np
    except ImportError as e:
        print(f"[visualize_data] matplotlib/numpy import failed: {e}")
        return

    cfg = generator.config
    K = cfg.k_bins
    n_bits = generator.n_bits

    # Clamp range to valid indices
    show_from = max(0, show_from)
    show_to = min(len(full_df), show_to)
    # Slice once, positionally. Plain ``series[a:b]`` is only positional for
    # a default RangeIndex; pandas deprecates it for other integer indexes.
    window = full_df.iloc[show_from:show_to]
    t = window["t"]
    bit_cols = [f"R_bit_{i}" for i in range(n_bits)]

    fig = plt.figure(figsize=(15, 10))
    try:
        # Plot 1: Continuous vs Discretized R
        plt.subplot(3, 2, 1)
        plt.plot(t, window["R_continuous"],
                 alpha=0.5, label="Continuous", linewidth=1)
        plt.plot(t, window["R"],
                 alpha=0.8, label="Discretized", linewidth=1, drawstyle="steps-post")
        plt.xlabel("Time")
        plt.ylabel("R")
        plt.legend()
        plt.title("Continuous vs Discretized R")
        plt.grid(True, alpha=0.3)

        # Plot 2: Histogram of discretized R (whole dataset)
        plt.subplot(3, 2, 2)
        plt.hist(full_df["R"], bins=K, edgecolor="black", alpha=0.7)
        plt.xlabel("R (discretized)")
        plt.ylabel("Frequency")
        plt.title(f"Distribution across {K} bins")
        plt.grid(True, alpha=0.3)

        # Plot 3: Bin indices over time
        plt.subplot(3, 2, 3)
        plt.plot(t, window["R_bin_index"], drawstyle="steps-post")
        plt.xlabel("Time")
        plt.ylabel("Bin Index")
        plt.title("Bin Index over Time")
        plt.yticks(range(K))
        plt.grid(True, alpha=0.3)

        # Plot 4: Binary bits over time (heatmap style); rows = bits,
        # columns = samples in the window.
        plt.subplot(3, 2, 4)
        binary_matrix = np.asarray(window[bit_cols]).T
        im = plt.imshow(binary_matrix, aspect="auto", cmap="binary", interpolation="nearest")
        plt.xlabel("Time")
        plt.ylabel("Bit Position")
        plt.title(f"Binary Representation over Time ({n_bits} bits)")
        plt.yticks(range(n_bits), [f"Bit {i}" for i in range(n_bits)])
        plt.colorbar(im, label="Bit Value")

        # Plot 5: Bin index distribution (whole dataset)
        plt.subplot(3, 2, 5)
        bin_counts = full_df["R_bin_index"].value_counts().sort_index()
        plt.bar(bin_counts.index, bin_counts.values, alpha=0.7, edgecolor="black")
        plt.xlabel("Bin Index")
        plt.ylabel("Frequency")
        plt.title("Bin Index Distribution")
        plt.xticks(range(K))
        plt.grid(True, alpha=0.3)

        # Plot 6: Decision variable vs bin index (rows with a missing x are
        # excluded from the mean)
        plt.subplot(3, 2, 6)
        df_valid = full_df[full_df["x"].notna()]
        decision_by_bin = df_valid.groupby("R_bin_index")["x"].mean()
        plt.bar(decision_by_bin.index, decision_by_bin.values, alpha=0.7, edgecolor="black")
        plt.xlabel("Bin Index")
        plt.ylabel("Mean Decision (x)")
        plt.title("Average Decision by Bin")
        plt.xticks(range(K))
        plt.ylim([0, 1])
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even when a plotting call raises, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)

    print(f"Visualization saved to: {save_path}")
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def main():
    """Generate the synthetic dataset, write it to CSV and plot a preview.

    Both outputs (``data.csv`` and ``data_visualization.png``) are written
    into the directory that contains this script.
    """
    # ============================================================
    # MODIFY THESE PARAMETERS TO CUSTOMIZE DATA GENERATION
    # ============================================================
    gen_config = GeneratorConfig(
        n_samples=5000,
        dt=0.1,
        r_min=-2.0,
        r_max=2.0,
        k_bins=16,
        spring_k=5.0,
        sigma=1.0,
        eq_interval=100,
        m0=0.25,
        sigma_eq=0.0,
        lookahead=10,
    )
    seed = 42
    show_from, show_to = 0, 100
    # ============================================================

    gen = SyntheticDataGenerator(gen_config)
    gen.print_info()

    df_full, df_train = gen.generate(seed=seed)

    # Outputs live next to this script, not in a separate data folder.
    script_dir = Path(__file__).parent

    csv_path = script_dir / "data.csv"
    df_train.to_csv(csv_path, index=False)
    print(f"\nSaved training data to: {csv_path}")
    print(f" Columns: {list(df_train.columns)}")
    print(f" Rows: {len(df_train)}")

    # Preview plot of the full (un-simplified) frame.
    png_path = script_dir / "data_visualization.png"
    visualize_data(df_full, gen, png_path, show_from=show_from, show_to=show_to)


if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()
|
|
@@ -7,13 +7,16 @@ boltzmann9/data_generator.py,sha256=aXm8b6O4qoXpDa8TCOdDxF_lDaMtcLNR0P06u2cXb1g,
|
|
|
7
7
|
boltzmann9/model.py,sha256=npBDd_pow4f9gcsLIpWMmJLBB6sGljtaOluvXcFlLP0,31538
|
|
8
8
|
boltzmann9/pipeline.py,sha256=C3L7tHVakrmvdnW9KYMXhXvscwG7TP8xmVHbunx1E8w,5965
|
|
9
9
|
boltzmann9/preprocessor.py,sha256=USxpMCnC80wp4zO72lPxmFeVdfB4I5NE4P2YlojCVPQ,21758
|
|
10
|
-
boltzmann9/project.py,sha256=
|
|
10
|
+
boltzmann9/project.py,sha256=p24xKk6pj1YxgGCfuJy2zEA8dSdG8zAGEGc6_wdxMl0,5326
|
|
11
11
|
boltzmann9/run_utils.py,sha256=1eRDEkN8j1mG5g9iXArWRA1yQq6pw2agVUGV4Yhy0jk,7823
|
|
12
12
|
boltzmann9/tester.py,sha256=yMYhYKWWiG-LhJQvKv7QmMCYPQLbmtMxFVrJ9iHZVs8,5659
|
|
13
13
|
boltzmann9/utils.py,sha256=8ftNcbV4dz52-t0Ewu4E7CmkkVzPDa6kd8ZnxaBvhMM,1337
|
|
14
14
|
boltzmann9/visualization.py,sha256=iPSytX7GMfmbRbR43xOmybLZBLmy58D4Rp8wasISMcs,3407
|
|
15
|
-
boltzmann9
|
|
16
|
-
boltzmann9
|
|
17
|
-
boltzmann9
|
|
18
|
-
boltzmann9-0.1.
|
|
19
|
-
boltzmann9-0.1.
|
|
15
|
+
boltzmann9/templates/__init__.py,sha256=LEL_kwiunM5L5fdMLLJQPucOTDb4mL6uUR0U75eA7Ps,38
|
|
16
|
+
boltzmann9/templates/config.py,sha256=2_GRiJ_92lxSr3ypayAGrFoB58Ekyu264OUmoxytOlM,1568
|
|
17
|
+
boltzmann9/templates/synthetic_generator.py,sha256=cz5PZndsIbCh7Qfl40mXH5_yIBei71MvLLdV1mmbPFc,4858
|
|
18
|
+
boltzmann9-0.1.9.dist-info/METADATA,sha256=hrK5IhuTInrxJJwO1R5CYHJJcGSi16Lj2GZCr37kgs0,3097
|
|
19
|
+
boltzmann9-0.1.9.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
20
|
+
boltzmann9-0.1.9.dist-info/entry_points.txt,sha256=Y2wY1lqICCjTxuaPgme2QxkoAknrASKmhLljcb5FEyw,51
|
|
21
|
+
boltzmann9-0.1.9.dist-info/top_level.txt,sha256=AE0_urGKOFSgKxXf7WQfjGDROWYppyI4tw1VIg-paQI,11
|
|
22
|
+
boltzmann9-0.1.9.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|