boltzmann9 0.1.8__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
boltzmann9/project.py CHANGED
@@ -11,7 +11,7 @@ from typing import Any, Dict
11
11
 
12
12
  def _get_templates_dir() -> Path:
13
13
  """Get the path to the templates directory."""
14
- return Path(__file__).parent.parent / "templates"
14
+ return Path(__file__).parent / "templates"
15
15
 
16
16
 
17
17
  def _load_config_template(data_path: str) -> str:
@@ -0,0 +1 @@
1
+ # Template files for project creation
@@ -0,0 +1,65 @@
1
+ """RBM experiment configuration."""
2
+
3
# Default experiment configuration used as the project template.
# Sections: device, data, model, preprocess, dataloader, train, eval, conditional.
# NOTE(review): this file is read as a template by the project-creation code;
# keep the existing key/value text stable unless that loader is updated too.
config = {
    "device": "auto",  # "auto" | "cpu" | "cuda:0" | "mps"

    # Input data location and columns to exclude.
    # NOTE(review): drop_cols presumably names CSV columns removed before
    # training - confirm against the pipeline.
    "data": {
        "csv_path": "data.csv",  # Path to training data CSV
        "drop_cols": ['x'],
    },

    # Model layout. Assuming each block value is a unit count, the defaults
    # give 4 + 4 = 8 visible units and 5 + 10 + 5 = 20 hidden units.
    # NOTE(review): cross_block_restrictions holds (visible_block, hidden_block)
    # name pairs - presumably connections that are disallowed; confirm in the
    # model code.
    "model": {
        "bm_type": "rbm",  # Currently only RBM is supported
        "visible_blocks": {"v1": 4, "v2": 4},
        "hidden_blocks": {"h1": 5, "h":10, "h2": 5},
        "cross_block_restrictions": [("v1", "h2"), ("v2", "h1")],
        "initialization": "random",
    },

    # Preprocessing options.
    # NOTE(review): q_low/q_high look like quantile bounds and the category
    # settings like limits for categorical columns - confirm in the
    # preprocessor.
    "preprocess": {
        "q_low": 0.001,
        "q_high": 0.999,
        "add_missing_bit": True,
        "max_categories": 200,
        "min_category_freq": 2,
    },

    # Data loading options.
    # NOTE(review): "split" holds three fractions summing to 1.0 - presumably
    # train/validation/test; confirm against the loader.
    "dataloader": {
        "batch_size": 256,
        "split": [0.8, 0.1, 0.1],
        "seed": 42,
        "shuffle_train": True,
        "num_workers": 0,
        "drop_last_train": True,
        "pin_memory": "auto",  # "auto" | True | False
    },

    # Training hyperparameters.
    # NOTE(review): the meaning of "k", "kind", "rho" and "lambda_sparse" is
    # defined by the trainer, which is not shown here - consult it before
    # changing them.
    "train": {
        "epochs": 100,
        "lr": 1e-1,
        "k": 10,
        "kind": "mean-field",
        "momentum": 0.9,
        "weight_decay": 1e-4,
        "clip_value": 0.05,
        "clip_norm": 5.0,
        "lr_schedule": {"mode": "cosine", "min_lr": 1e-4},
        "sparse_hidden": True,
        "rho": 0.1,
        "lambda_sparse": 0.01,
        "early_stopping": True,
        "es_patience": 8,
    },

    # Evaluation options.
    "eval": {
        "recon_k": 1,
    },

    # Conditional sampling options. The default index lists cover 0-3 and 4-7,
    # which matches the two 4-wide visible blocks above.
    # NOTE(review): presumably these are visible-unit indices - confirm.
    "conditional": {
        "clamp_idx": [0, 1, 2, 3],
        "target_idx": [4, 5, 6, 7],
        "n_samples": 100,
        "burn_in": 500,
        "thin": 10,
    },
}
@@ -0,0 +1,147 @@
1
+ #!/usr/bin/env python
2
+ """Generate synthetic data for this project.
3
+
4
+ This script is auto-generated from the BoltzmaNN9 template.
5
+ Modify the GeneratorConfig parameters below to customize data generation.
6
+ """
7
+
8
import sys
from pathlib import Path

try:
    # Normal case: the library is installed (e.g. from the wheel), where the
    # import name is ``boltzmann9``.
    from boltzmann9.data_generator import GeneratorConfig, SyntheticDataGenerator
except ModuleNotFoundError as exc:
    # Only fall back when the package itself is absent; a missing dependency
    # *inside* an installed boltzmann9 must surface as the real error.
    if exc.name != "boltzmann9":
        raise
    # Source-checkout fallback: put the directory two levels above this
    # script's folder on sys.path so the ``src`` package layout resolves.
    sys.path.insert(0, str(Path(__file__).parent.parent.parent))

    from src.boltzmann9.data_generator import GeneratorConfig, SyntheticDataGenerator
15
+
16
+
17
def visualize_data(full_df, generator, save_path, show_from=0, show_to=500):
    """Render a six-panel overview of the generated data and save it to disk.

    Panels (3 rows x 2 columns):
      1. continuous vs. discretized ``R`` over time (windowed)
      2. histogram of discretized ``R`` (all rows)
      3. bin index over time (windowed)
      4. binary bit columns over time as a heatmap (windowed)
      5. bin index frequencies (all rows)
      6. mean of ``x`` per bin index (rows where ``x`` is not NaN)

    Args:
        full_df: pandas DataFrame with the columns ``t``, ``R_continuous``,
            ``R``, ``R_bin_index``, ``x`` and ``R_bit_0`` ..
            ``R_bit_{n_bits - 1}``.
        generator: object exposing ``config.k_bins`` and ``n_bits``.
        save_path: destination handed to ``Figure.savefig``.
        show_from: first row position of the window used by the time plots.
        show_to: end row position (exclusive) of that window.

    Returns:
        None. When matplotlib or numpy cannot be imported a message is
        printed and nothing is written.
    """
    try:
        import matplotlib.pyplot as plt
        import numpy as np
    except ImportError as e:
        print(f"[visualize_data] matplotlib/numpy import failed: {e}")
        return

    cfg = generator.config
    k_bins = cfg.k_bins
    n_bits = generator.n_bits

    # Clamp the window to valid row positions. The end is also kept >= the
    # start so a negative or inverted range yields an empty window instead of
    # being interpreted as "count from the end".
    show_from = max(0, show_from)
    show_to = max(show_from, min(len(full_df), show_to))
    window = full_df.iloc[show_from:show_to]

    fig, axes = plt.subplots(3, 2, figsize=(15, 10))
    try:
        # Plot 1: Continuous vs Discretized R
        ax = axes[0, 0]
        ax.plot(window["t"], window["R_continuous"],
                alpha=0.5, label="Continuous", linewidth=1)
        ax.plot(window["t"], window["R"],
                alpha=0.8, label="Discretized", linewidth=1,
                drawstyle="steps-post")
        ax.set_xlabel("Time")
        ax.set_ylabel("R")
        ax.legend()
        ax.set_title("Continuous vs Discretized R")
        ax.grid(True, alpha=0.3)

        # Plot 2: Histogram of discretized R
        ax = axes[0, 1]
        ax.hist(full_df["R"], bins=k_bins, edgecolor="black", alpha=0.7)
        ax.set_xlabel("R (discretized)")
        ax.set_ylabel("Frequency")
        ax.set_title(f"Distribution across {k_bins} bins")
        ax.grid(True, alpha=0.3)

        # Plot 3: Bin indices over time
        ax = axes[1, 0]
        ax.plot(window["t"], window["R_bin_index"], drawstyle="steps-post")
        ax.set_xlabel("Time")
        ax.set_ylabel("Bin Index")
        ax.set_title("Bin Index over Time")
        ax.set_yticks(range(k_bins))
        ax.grid(True, alpha=0.3)

        # Plot 4: Binary bits over time (heatmap style)
        ax = axes[1, 1]
        bit_matrix = np.array(
            [window[f"R_bit_{i}"].values for i in range(n_bits)]
        )
        # An empty window has no pixels to draw; skip the image and colorbar.
        if bit_matrix.size:
            image = ax.imshow(bit_matrix, aspect="auto", cmap="binary",
                              interpolation="nearest")
            fig.colorbar(image, ax=ax, label="Bit Value")
        ax.set_xlabel("Time")
        ax.set_ylabel("Bit Position")
        ax.set_title(f"Binary Representation over Time ({n_bits} bits)")
        ax.set_yticks(range(n_bits))
        ax.set_yticklabels([f"Bit {i}" for i in range(n_bits)])

        # Plot 5: Bin index distribution
        ax = axes[2, 0]
        bin_counts = full_df["R_bin_index"].value_counts().sort_index()
        ax.bar(bin_counts.index, bin_counts.values, alpha=0.7,
               edgecolor="black")
        ax.set_xlabel("Bin Index")
        ax.set_ylabel("Frequency")
        ax.set_title("Bin Index Distribution")
        ax.set_xticks(range(k_bins))
        ax.grid(True, alpha=0.3)

        # Plot 6: Decision variable vs bin index
        ax = axes[2, 1]
        df_valid = full_df[full_df["x"].notna()]
        decision_by_bin = df_valid.groupby("R_bin_index")["x"].mean()
        ax.bar(decision_by_bin.index, decision_by_bin.values, alpha=0.7,
               edgecolor="black")
        ax.set_xlabel("Bin Index")
        ax.set_ylabel("Mean Decision (x)")
        ax.set_title("Average Decision by Bin")
        ax.set_xticks(range(k_bins))
        ax.set_ylim(0, 1)
        ax.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    print(f"Visualization saved to: {save_path}")
103
+
104
+
105
def main():
    """Generate the synthetic dataset, write it to CSV and plot a preview.

    Both outputs (``data.csv`` and ``data_visualization.png``) are written
    into the directory that contains this script.
    """
    # ============================================================
    # MODIFY THESE PARAMETERS TO CUSTOMIZE DATA GENERATION
    # ============================================================
    gen_config = GeneratorConfig(
        n_samples=5000,
        dt=0.1,
        r_min=-2.0,
        r_max=2.0,
        k_bins=16,
        spring_k=5.0,
        sigma=1.0,
        eq_interval=100,
        m0=0.25,
        sigma_eq=0.0,
        lookahead=10,
    )
    seed = 42
    show_from = 0
    show_to = 100
    # ============================================================

    script_dir = Path(__file__).parent
    csv_target = script_dir / "data.csv"
    figure_target = script_dir / "data_visualization.png"

    data_source = SyntheticDataGenerator(gen_config)
    data_source.print_info()
    detailed_frame, training_frame = data_source.generate(seed=seed)

    # Persist the training table next to this script.
    training_frame.to_csv(csv_target, index=False)

    print(
        f"\nSaved training data to: {csv_target}",
        f" Columns: {list(training_frame.columns)}",
        f" Rows: {len(training_frame)}",
        sep="\n",
    )

    # Render the preview figure from the full (detailed) frame.
    visualize_data(
        detailed_frame,
        data_source,
        figure_target,
        show_from=show_from,
        show_to=show_to,
    )


if __name__ == "__main__":
    main()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: boltzmann9
3
- Version: 0.1.8
3
+ Version: 0.1.9
4
4
  Summary: Restricted Boltzmann Machine implementation in PyTorch
5
5
  License: MIT
6
6
  Requires-Python: >=3.10
@@ -7,13 +7,16 @@ boltzmann9/data_generator.py,sha256=aXm8b6O4qoXpDa8TCOdDxF_lDaMtcLNR0P06u2cXb1g,
7
7
  boltzmann9/model.py,sha256=npBDd_pow4f9gcsLIpWMmJLBB6sGljtaOluvXcFlLP0,31538
8
8
  boltzmann9/pipeline.py,sha256=C3L7tHVakrmvdnW9KYMXhXvscwG7TP8xmVHbunx1E8w,5965
9
9
  boltzmann9/preprocessor.py,sha256=USxpMCnC80wp4zO72lPxmFeVdfB4I5NE4P2YlojCVPQ,21758
10
- boltzmann9/project.py,sha256=Iap-nxhhAZDw2Y4yQg53gFQ0CQPCk2zfmZZC0bTPmxQ,5333
10
+ boltzmann9/project.py,sha256=p24xKk6pj1YxgGCfuJy2zEA8dSdG8zAGEGc6_wdxMl0,5326
11
11
  boltzmann9/run_utils.py,sha256=1eRDEkN8j1mG5g9iXArWRA1yQq6pw2agVUGV4Yhy0jk,7823
12
12
  boltzmann9/tester.py,sha256=yMYhYKWWiG-LhJQvKv7QmMCYPQLbmtMxFVrJ9iHZVs8,5659
13
13
  boltzmann9/utils.py,sha256=8ftNcbV4dz52-t0Ewu4E7CmkkVzPDa6kd8ZnxaBvhMM,1337
14
14
  boltzmann9/visualization.py,sha256=iPSytX7GMfmbRbR43xOmybLZBLmy58D4Rp8wasISMcs,3407
15
- boltzmann9-0.1.8.dist-info/METADATA,sha256=-mbyn-rmPZCIVK-k1FvsfZhwyIvZsCCfMwg0gqzXKsk,3097
16
- boltzmann9-0.1.8.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
17
- boltzmann9-0.1.8.dist-info/entry_points.txt,sha256=Y2wY1lqICCjTxuaPgme2QxkoAknrASKmhLljcb5FEyw,51
18
- boltzmann9-0.1.8.dist-info/top_level.txt,sha256=AE0_urGKOFSgKxXf7WQfjGDROWYppyI4tw1VIg-paQI,11
19
- boltzmann9-0.1.8.dist-info/RECORD,,
15
+ boltzmann9/templates/__init__.py,sha256=LEL_kwiunM5L5fdMLLJQPucOTDb4mL6uUR0U75eA7Ps,38
16
+ boltzmann9/templates/config.py,sha256=2_GRiJ_92lxSr3ypayAGrFoB58Ekyu264OUmoxytOlM,1568
17
+ boltzmann9/templates/synthetic_generator.py,sha256=cz5PZndsIbCh7Qfl40mXH5_yIBei71MvLLdV1mmbPFc,4858
18
+ boltzmann9-0.1.9.dist-info/METADATA,sha256=hrK5IhuTInrxJJwO1R5CYHJJcGSi16Lj2GZCr37kgs0,3097
19
+ boltzmann9-0.1.9.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
20
+ boltzmann9-0.1.9.dist-info/entry_points.txt,sha256=Y2wY1lqICCjTxuaPgme2QxkoAknrASKmhLljcb5FEyw,51
21
+ boltzmann9-0.1.9.dist-info/top_level.txt,sha256=AE0_urGKOFSgKxXf7WQfjGDROWYppyI4tw1VIg-paQI,11
22
+ boltzmann9-0.1.9.dist-info/RECORD,,