manifold-microscope 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. experiment_scripts/__init__.py +0 -0
  2. experiment_scripts/manifold_analysis/__init__.py +0 -0
  3. experiment_scripts/manifold_analysis/analysis.py +107 -0
  4. experiment_scripts/manifold_fitting/__init__.py +0 -0
  5. experiment_scripts/manifold_fitting/analysis.py +106 -0
  6. experiment_scripts/manifold_fitting/inference.py +147 -0
  7. experiment_scripts/manifold_fitting/mmls.py +63 -0
  8. experiment_scripts/manifold_fitting/training.py +220 -0
  9. experiment_scripts/model_configs.py +102 -0
  10. experiment_scripts/toy_manifolds_experiment/__init__.py +0 -0
  11. experiment_scripts/toy_manifolds_experiment/fit_and_get_measures.py +218 -0
  12. experiment_scripts/toy_manifolds_experiment/manifold_fitting_denoising_autoencoder.py +156 -0
  13. experiment_scripts/toy_manifolds_experiment/manifold_fitting_no_noise.py +87 -0
  14. manifold_microscope-0.0.1.dist-info/METADATA +182 -0
  15. manifold_microscope-0.0.1.dist-info/RECORD +60 -0
  16. manifold_microscope-0.0.1.dist-info/WHEEL +5 -0
  17. manifold_microscope-0.0.1.dist-info/licenses/LICENSE +29 -0
  18. manifold_microscope-0.0.1.dist-info/licenses/representation_learning/beta_vae/LICENSE +21 -0
  19. manifold_microscope-0.0.1.dist-info/licenses/representation_learning/beta_vae/NOTICE.md +17 -0
  20. manifold_microscope-0.0.1.dist-info/top_level.txt +3 -0
  21. microscope/__init__.py +0 -0
  22. microscope/computations_grid/__init__.py +0 -0
  23. microscope/computations_grid/basic.py +166 -0
  24. microscope/computations_grid/curvature.py +240 -0
  25. microscope/computations_grid/data_analysis/__init__.py +0 -0
  26. microscope/computations_grid/data_analysis/data_analysis.py +630 -0
  27. microscope/computations_grid/data_analysis/merge_analysis_outputs.py +314 -0
  28. microscope/computations_grid/data_analysis/run_data_analysis.py +229 -0
  29. microscope/computations_grid/reach.py +148 -0
  30. microscope/computations_grid/volume.py +100 -0
  31. microscope/cyclic_dimensions.py +57 -0
  32. microscope/datasets/__init__.py +0 -0
  33. microscope/datasets/coil20.py +171 -0
  34. microscope/datasets/custom_dsprites.py +392 -0
  35. microscope/datasets/dataset_split.py +120 -0
  36. microscope/datasets/generic_dataset_loader.py +476 -0
  37. microscope/datasets/image_transforms.py +156 -0
  38. microscope/datasets/noise_adding.py +103 -0
  39. microscope/datasets/original_dsprites.py +58 -0
  40. microscope/datasets/toy_manifolds.py +686 -0
  41. microscope/manifold_examples/__init__.py +0 -0
  42. microscope/manifold_examples/ellipsoid.py +77 -0
  43. microscope/manifold_examples/hyperboloid.py +47 -0
  44. microscope/manifold_examples/plotting.py +74 -0
  45. microscope/manifold_examples/sampling_grid.py +103 -0
  46. microscope/manifold_examples/sampling_uniform.py +273 -0
  47. microscope/manifold_examples/sphere.py +41 -0
  48. microscope/manifold_examples/symbolic_computations.py +332 -0
  49. microscope/manifold_examples/utils.py +58 -0
  50. microscope/patches.py +120 -0
  51. representation_learning/__init__.py +0 -0
  52. representation_learning/beta_vae/LICENSE +21 -0
  53. representation_learning/beta_vae/NOTICE.md +17 -0
  54. representation_learning/beta_vae/__init__.py +0 -0
  55. representation_learning/beta_vae/dataset.py +106 -0
  56. representation_learning/beta_vae/inference_intermediate_layers.py +185 -0
  57. representation_learning/beta_vae/main.py +79 -0
  58. representation_learning/beta_vae/model.py +172 -0
  59. representation_learning/beta_vae/solver.py +432 -0
  60. representation_learning/beta_vae/utils.py +50 -0
File without changes
File without changes
@@ -0,0 +1,107 @@
1
+ import itertools
2
+ from pathlib import Path
3
+
4
+ import typer
5
+ from tqdm import tqdm
6
+
7
+ from microscope.datasets.generic_dataset_loader import DatasetName
8
+ from microscope.computations_grid.data_analysis.run_data_analysis import main as run_analysis
9
+
10
# Typer CLI application; Typer's pretty (Rich-formatted) exception output is disabled.
app = typer.Typer(pretty_exceptions_enable=False)


def analysis_on_model(
    inference_path: Path,
    output_path: Path,
    dataset_name: DatasetName,
    model_type: str,
    number_of_dims: int,
    only_evolution: bool,
    normalize_for_volume: bool,
    skip_done: bool,
    n_samples_for_plots: int = 50_000
) -> None:
    """Run the data analysis for a single model.

    Args:
        inference_path: Location of the model's inference outputs, forwarded to
            ``run_analysis``.
        output_path: Where the analysis is written. If it already exists and
            ``skip_done`` is true, the model is skipped.
        dataset_name: Forwarded to ``run_analysis`` as ``dataset``.
        model_type: Forwarded to ``run_analysis``.
        number_of_dims: Forwarded to ``run_analysis``.
        only_evolution: Forwarded to ``run_analysis``.
        normalize_for_volume: Forwarded to ``run_analysis``.
        skip_done: Skip models whose ``output_path`` already exists.
        n_samples_for_plots: Forwarded to ``run_analysis``.
    """
    already_done = output_path.exists()
    if skip_done and already_done:
        print(f"Skipping {output_path.name} as it exists.")
        return None

    print(f"Analysis on {output_path.name}.")
    analysis_kwargs = dict(
        inference_path=inference_path,
        output_path=output_path,
        dataset=dataset_name,
        model_type=model_type,
        number_of_dims=number_of_dims,
        only_evolution=only_evolution,
        normalize_for_volume=normalize_for_volume,
        n_samples_for_plots=n_samples_for_plots,
    )
    run_analysis(**analysis_kwargs)
39
+
40
+
41
@app.command()
def run_analyses(
    inference_path: Path = typer.Option(...),
    output_path: Path = typer.Option(...),
    only_evolution: bool = True,
    normalize_for_volume: bool = True,
    skip_done: bool = True,
    n_samples_for_plots: int = 50_000
) -> None:
    """Run ``analysis_on_model`` for every model of the hyperparameter grid.

    Each model lives in a sub-directory named
    ``<dataset>__<model_type>__<training_ratio_per_dim>__<dims>__<noise_sigma>``
    under both ``inference_path`` (input) and ``output_path`` (output).
    """
    datasets = ["custom_dsprites_balanced", "extended_coil20"]
    model_types = ["beta_vae", "mae"]
    ratios_per_dim = [1.0]
    # Dimensions 1 and 4 are currently disabled for this sweep.
    dims = [2, 3]
    noise_sigmas = [0]  # No noise.

    grid = list(itertools.product(datasets, model_types, ratios_per_dim, dims, noise_sigmas))

    for dataset_name, model_type, ratio, dim, sigma in tqdm(grid):
        # Dimension 4 is skipped for COIL20 and for MAE models.
        if dim == 4 and (dataset_name == "extended_coil20" or model_type == "mae"):
            continue
        model_dir = "__".join(str(part) for part in (dataset_name, model_type, ratio, dim, sigma))
        analysis_on_model(
            output_path=output_path / model_dir,
            inference_path=inference_path / model_dir,
            dataset_name=dataset_name,
            model_type=model_type,
            number_of_dims=dim,
            only_evolution=only_evolution,
            normalize_for_volume=normalize_for_volume,
            skip_done=skip_done,
            n_samples_for_plots=n_samples_for_plots,
        )


if __name__ == "__main__":
    # Dispatch to the Typer CLI when run as a script.
    app()
File without changes
@@ -0,0 +1,106 @@
1
+ import itertools
2
+ from pathlib import Path
3
+
4
+ import typer
5
+ from tqdm import tqdm
6
+
7
+ from microscope.datasets.generic_dataset_loader import DatasetName
8
+ from microscope.computations_grid.data_analysis.run_data_analysis import main as run_analysis
9
+
10
# Typer CLI application; Typer's pretty (Rich-formatted) exception output is disabled.
app = typer.Typer(pretty_exceptions_enable=False)


def analysis_on_model(
    inference_path: Path,
    output_path: Path,
    dataset_name: DatasetName,
    model_type: str,
    number_of_dims: int,
    only_evolution: bool,
    normalize_for_volume: bool,
    skip_done: bool
) -> None:
    """Analyse one model's inference outputs unless that was already done.

    Args:
        inference_path: Location of the model's inference outputs, forwarded to
            ``run_analysis``.
        output_path: Destination of the analysis. Its existence, combined with
            ``skip_done``, marks the model as done.
        dataset_name: Forwarded to ``run_analysis`` as ``dataset``.
        model_type: Forwarded to ``run_analysis``.
        number_of_dims: Forwarded to ``run_analysis``.
        only_evolution: Forwarded to ``run_analysis``.
        normalize_for_volume: Forwarded to ``run_analysis``.
        skip_done: When true, existing ``output_path`` directories are skipped.
    """
    is_present = output_path.exists()
    if is_present and skip_done:
        print(f"Skipping {output_path.name} as it exists.")
        return None

    print(f"Analysis on {output_path.name}.")
    run_analysis(
        dataset=dataset_name,
        model_type=model_type,
        number_of_dims=number_of_dims,
        inference_path=inference_path,
        output_path=output_path,
        only_evolution=only_evolution,
        normalize_for_volume=normalize_for_volume,
    )
37
+
38
+
39
@app.command()
def run_analyses(
    inference_path: Path = typer.Option(...),
    output_path: Path = typer.Option(...),
    only_evolution: bool = True,
    normalize_for_volume: bool = True,
    skip_done: bool = True
) -> None:
    """Analyse every model of the training-ratio sweep.

    Models are addressed by sub-directories named
    ``<dataset>__<model_type>__<training_ratio_per_dim>__<dims>__<noise_sigma>``
    below ``inference_path`` (input) and ``output_path`` (output).
    """
    axes = (
        ("custom_dsprites_balanced", "extended_coil20"),  # datasets
        ("beta_vae", "mae"),                               # model types
        (0.4, 0.5, 0.6, 1.0),                              # training ratio per dim
        (1, 2, 3, 4),                                      # number of dims
        (0,),                                              # noise sigma (no noise)
    )
    combinations = list(itertools.product(*axes))

    for combo in tqdm(combinations):
        dataset_name, model_type, _ratio, number_of_dims, _sigma = combo
        # 4-D variants exist only for dSprites and only for beta-VAE.
        skip_4d = dataset_name == "extended_coil20" or model_type == "mae"
        if number_of_dims == 4 and skip_4d:
            continue
        model_dir = "__".join(map(str, combo))
        analysis_on_model(
            output_path=output_path / model_dir,
            inference_path=inference_path / model_dir,
            dataset_name=dataset_name,
            model_type=model_type,
            number_of_dims=number_of_dims,
            only_evolution=only_evolution,
            normalize_for_volume=normalize_for_volume,
            skip_done=skip_done,
        )


if __name__ == "__main__":
    # Dispatch to the Typer CLI when run as a script.
    app()
@@ -0,0 +1,147 @@
1
+ import itertools
2
+ from pathlib import Path
3
+
4
+ import typer
5
+ from tqdm import tqdm
6
+
7
+ from microscope.datasets.generic_dataset_loader import DatasetName
8
+ from representation_learning.mae.inference_intermediate_layers import main as inference_mae
9
+ from representation_learning.beta_vae.inference_intermediate_layers import main as inference_beta_vae
10
+
11
# Typer CLI application; Typer's pretty (Rich-formatted) exception output is disabled.
app = typer.Typer(pretty_exceptions_enable=False)


def _natural_sort_key(path: Path) -> list:
    """Sort key that orders embedded integers numerically ("ckpt_2" before "ckpt_10")."""
    import re

    # With a capturing group, re.split alternates text (even idx) / digits (odd idx),
    # so every key has the same type at each position and compares safely.
    return [int(chunk) if idx % 2 else chunk for idx, chunk in enumerate(re.split(r"(\d+)", path.name))]


def inference_on_model(
    output_path: Path,
    model_path: Path,
    only_final_model: bool,
    only_latent_and_output: bool,
    dataset_name: DatasetName,
    model_type: str,
    number_of_dims: int,
    skip_done: bool
) -> None:
    """Run intermediate-layer inference for one trained model.

    Checkpoints are read from ``<model_path>/checkpoints`` (``.npz`` files are
    ignored). Exactly one of them must have "last" or "final" in its name; that
    file is the final checkpoint. With ``only_final_model`` only the final
    checkpoint is processed. Otherwise the first, middle and final checkpoints
    are processed, with the non-final files ordered by natural name order and
    the final checkpoint placed last. A checkpoint is never processed twice,
    even when fewer than three exist.

    Args:
        output_path: Output directory passed to the inference function. If it
            already exists and ``skip_done`` is true, nothing is done.
        model_path: Training directory of the model.
        only_final_model: Only process the final checkpoint.
        only_latent_and_output: Forwarded to the inference function.
        dataset_name: Forwarded to the inference function as ``dataset``.
        model_type: ``"beta_vae"`` or ``"mae"``; selects the inference function.
        number_of_dims: Forwarded to the inference function.
        skip_done: Skip models whose ``output_path`` already exists.

    Raises:
        ValueError: On an unknown ``model_type`` or when the number of final
            checkpoint candidates is not exactly one.
    """
    if output_path.exists() and skip_done:
        print(f"Skipping {output_path.name} as it exists.")
        return None

    print(f"Inference on {output_path.name}.")
    if model_type == "beta_vae":
        inference_fn = inference_beta_vae
    elif model_type == "mae":
        inference_fn = inference_mae
    else:
        raise ValueError(f"Unknown model type: {model_type}.")

    checkpoints_path = model_path / "checkpoints"
    # Filter .npz side files in both modes so they are never mistaken for checkpoints.
    checkpoint_paths = [p for p in checkpoints_path.glob("*") if "npz" not in p.suffix]
    final_candidates = [
        ckpt for ckpt in checkpoint_paths
        if "last" in ckpt.name or "final" in ckpt.name
    ]
    if len(final_candidates) != 1:
        raise ValueError(f"Found the following final model candidates {final_candidates}. Expected as single one.")
    final_checkpoint = final_candidates[0]

    if only_final_model:
        selected = [final_checkpoint]
    else:
        # glob() yields entries in arbitrary order, so impose a deterministic one
        # before picking the "first" and "middle" checkpoints.
        ordered = sorted((p for p in checkpoint_paths if p != final_checkpoint), key=_natural_sort_key)
        ordered.append(final_checkpoint)
        selected = [ordered[0], ordered[len(ordered) // 2], final_checkpoint]

    # dict.fromkeys de-duplicates while keeping order.
    for checkpoint_path in dict.fromkeys(selected):
        inference_fn(
            dataset=dataset_name,
            number_of_dims=number_of_dims,
            only_latent_and_output=only_latent_and_output,
            checkpoint_path=checkpoint_path,
            output_dir=output_path,
            random_model=False
        )
78
+
79
+
80
@app.command()
def run_inferences(
    training_path: Path = typer.Option(...),
    output_path: Path = typer.Option(...),
    only_final_model: bool = True,
    only_latent_and_output: bool = True,
    skip_done: bool = True
) -> None:
    """Run ``inference_on_model`` for every model of the training sweep.

    Each model is addressed by a sub-directory named
    ``<dataset>__<model_type>__<training_ratio_per_dim>__<dims>__<noise_sigma>``
    below ``training_path`` (input) and ``output_path`` (output).
    """
    grid = list(itertools.product(
        ("custom_dsprites_balanced", "extended_coil20"),
        ("beta_vae", "mae"),
        (0.4, 0.5, 0.6, 1.0),
        (1, 2, 3, 4),
        (0,),  # No noise.
    ))

    for dataset_name, model_type, ratio_per_dim, n_dims, sigma in tqdm(grid):
        # 4-D variants are skipped for COIL20 and for MAE models.
        if n_dims == 4 and (dataset_name == "extended_coil20" or model_type == "mae"):
            continue
        model_dir = "__".join(map(str, (dataset_name, model_type, ratio_per_dim, n_dims, sigma)))
        inference_on_model(
            output_path=output_path / model_dir,
            model_path=training_path / model_dir,
            only_final_model=only_final_model,
            only_latent_and_output=only_latent_and_output,
            dataset_name=dataset_name,
            model_type=model_type,
            number_of_dims=n_dims,
            skip_done=skip_done,
        )


if __name__ == "__main__":
    # Dispatch to the Typer CLI when run as a script.
    app()
@@ -0,0 +1,63 @@
1
+ import pickle
2
+ from pathlib import Path
3
+
4
+ import numpy as np
5
+
6
+ from experiment_scripts.model_configs import MMLSConfig
7
+ from experiment_scripts.toy_manifolds_experiment.manifold_fitting_no_noise import ANNMMLSProjector
8
+ from microscope.datasets.generic_dataset_loader import DatasetName, load_dataset_fixed_test_split
9
+
10
+
11
def fit_mmls(model_config: MMLSConfig) -> None:
    """Fit an ANN-MMLS projector on the training split and score it on the test split.

    The exported dataset described by ``model_config`` is loaded. Both splits are
    min-max scaled with one shared scale and flattened to ``(n_points, n_features)``.
    The projector is fitted on the training points, and each test point is
    projected onto the fitted manifold. The per-point Euclidean distances, their
    maximum (``hausdorff_distance``) and the split sizes are pickled to
    ``<output_dir>/distance_results_<dataset>.pkl``.

    Nothing is written when the training split has fewer points than
    ``model_config.number_of_neighbors``.
    """
    output_dir = Path(model_config.output_dir)
    dataset = DatasetName[model_config.dataset]
    number_of_dims = model_config.number_of_dims
    training_ratio = model_config.training_ratio
    number_of_neighbors = model_config.number_of_neighbors

    # Load the exported dataset.
    data_train, data_test, _, _ = load_dataset_fixed_test_split(
        datasets_dir=Path(model_config.exported_datasets_dir),
        dataset_name=dataset,
        number_of_dims=number_of_dims,
        ratio_per_dim=model_config.ratio_per_dim,
        training_ratio=training_ratio,
        noise_sigma=model_config.noise_sigma,
        weight_subsampling_by_manifold_volume=True
    )

    # The neighbourhood fit needs at least k training points; bail out before any array work.
    if len(data_train) < number_of_neighbors:
        print(
            f"Skipping the training ratio {training_ratio} as it results to {len(data_train)} points which are less "
            f"than the number of neighbors {number_of_neighbors}."
        )
        return

    # Min-max scale both splits with ONE shared scale. Scaling each split by its own
    # extrema would place test points in different coordinates than the training
    # points the projector is fitted on, distorting the reported distances. Taking
    # the extrema over both splits also keeps (x - min) non-negative, which matters
    # for unsigned integer image data.
    data_min = min(data_train.min(), data_test.min())
    data_range = max(data_train.max(), data_test.max()) - data_min
    if data_range == 0:
        # Constant data: avoid 0/0 and map everything to 0.
        data_range = 1
    data_train = (data_train - data_min) / data_range
    data_test = (data_test - data_min) / data_range

    # Flatten every sample to a feature vector.
    data_train = data_train.reshape(data_train.shape[0], int(np.prod(data_train.shape[1:])))
    data_test = data_test.reshape(data_test.shape[0], int(np.prod(data_test.shape[1:])))

    projector = ANNMMLSProjector(
        data_train, d=number_of_dims, k=number_of_neighbors, verbose=model_config.verbose, device=model_config.device
    )
    prediction = projector.project(data_test)

    distances = np.linalg.norm(data_test - prediction, axis=-1)
    hausdorff_distance = distances.max()

    results = dict(
        number_of_train_points=len(data_train),
        number_of_test_points=len(data_test),
        pointwise_distances=distances,
        hausdorff_distance=hausdorff_distance
    )

    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / f"distance_results_{dataset}.pkl", "wb") as f:
        pickle.dump(results, f, -1)
@@ -0,0 +1,220 @@
1
+ import itertools
2
+ from dataclasses import asdict, replace
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import torch
7
+ import typer
8
+ import yaml
9
+
10
+ from experiment_scripts.manifold_fitting.mmls import fit_mmls
11
+ from experiment_scripts.model_configs import BetaVAEConfig, MMLSConfig
12
+ from microscope.datasets.generic_dataset_loader import DatasetName, export_fixed_grid_test_set_and_rest_for_train
13
+ from representation_learning.beta_vae.solver import Solver
14
+
15
# Typer CLI application; Typer's pretty (Rich-formatted) exception output is disabled.
app = typer.Typer(pretty_exceptions_enable=False)


def update_max_epochs(
    config: BetaVAEConfig,
    mini_test_run: bool,
    training_ratio: float
) -> BetaVAEConfig:
    """Return ``config`` with its epoch budget adapted to the run.

    * ``mini_test_run``: a copy with ``max_epochs=1`` and ``plot_interval=2``.
    * ``training_ratio < 1``: a copy where ``max_epochs`` and ``plot_interval``
      are multiplied by ``1 / training_ratio`` and truncated to ``int``.
      Presumably this gives a smaller training subset a comparable number of
      optimisation steps.
    * otherwise ``config`` itself, unchanged.
    """
    if mini_test_run:
        return replace(config, max_epochs=1, plot_interval=2)

    if training_ratio < 1.0:
        scale = 1 / training_ratio
        return replace(
            config,
            max_epochs=int(scale * config.max_epochs),
            plot_interval=int(scale * config.plot_interval),
        )

    return config
36
+
37
+
38
+ def train_model(
39
+ output_path: Path,
40
+ exported_datasets_dir: Path,
41
+ seed: int,
42
+ dataset_name: str,
43
+ model_type: str,
44
+ training_ratio: float,
45
+ ratio_per_dim: bool,
46
+ number_of_dims: int,
47
+ noise_sigma: float,
48
+ skip_done: bool,
49
+ device: str = "cpu"
50
+ ) -> None:
51
+ if output_path.exists() and skip_done:
52
+ print(f"Skipping {output_path.name} as it exists.")
53
+ return None
54
+
55
+ print(f"Training on {output_path.name}.")
56
+ output_path.mkdir(parents=True)
57
+ experiment_config = dict(
58
+ output_path=str(output_path),
59
+ seed=seed,
60
+ dataset_name=dataset_name,
61
+ model_type=model_type,
62
+ training_ratio=training_ratio,
63
+ ratio_per_dim=ratio_per_dim,
64
+ number_of_dims=number_of_dims,
65
+ noise_sigma=noise_sigma,
66
+ device=device,
67
+ )
68
+
69
+ match model_type:
70
+ case "beta_vae":
71
+ model_config = BetaVAEConfig(
72
+ dataset=dataset_name,
73
+ ckpt_dir=output_path / "checkpoints",
74
+ exported_datasets_dir=exported_datasets_dir,
75
+ output_dir=output_path,
76
+ training_ratio=training_ratio,
77
+ ratio_per_dim=ratio_per_dim,
78
+ number_of_dims=number_of_dims,
79
+ noise_sigma=noise_sigma,
80
+ device=device
81
+ )
82
+
83
+ if dataset_name == DatasetName.extended_coil20:
84
+ model_config = replace(
85
+ model_config,
86
+ max_epochs=int(1e5),
87
+ objective="H",
88
+ model="H",
89
+ lr=1e-4,
90
+ loss_threshold=10
91
+ )
92
+
93
+ # Save config.
94
+ config = dict(
95
+ experiment_config=experiment_config,
96
+ model_config=asdict(model_config)
97
+ )
98
+ with open(output_path / "config.yml", "w") as f:
99
+ yaml.dump(config, f)
100
+
101
+ torch.manual_seed(seed)
102
+ if device == "mps" and torch.backends.mps.is_available():
103
+ torch.mps.manual_seed(seed)
104
+ elif device.startswith("cuda") and torch.cuda.is_available():
105
+ torch.cuda.manual_seed(seed)
106
+ np.random.seed(seed)
107
+
108
+ net = Solver(args=model_config)
109
+ net.train()
110
+ case "MMLS":
111
+ model_config = MMLSConfig(
112
+ output_dir=str(output_path),
113
+ exported_datasets_dir=exported_datasets_dir,
114
+ dataset=dataset_name,
115
+ number_of_dims=number_of_dims,
116
+ training_ratio=training_ratio,
117
+ ratio_per_dim=ratio_per_dim,
118
+ noise_sigma=noise_sigma,
119
+ number_of_neighbors=2*2**number_of_dims,
120
+ device=device
121
+ )
122
+ fit_mmls(model_config)
123
+ case _:
124
+ raise ValueError(f"Unknown model type: {model_type}.")
125
+
126
+
127
@app.command()
def run_trainings(
    output_path: Path = typer.Option(...),
    n_experiment_repetitions: int = 1,
    max_test_size: int = 500,
    seed: int = 42,
    skip_done: bool = True,
    device: str = "cuda:0"
) -> None:
    """Train every (dataset, model, noise, dims, training ratio, repetition) configuration.

    Splits are exported into ``<output_path>/datasets`` once per (dataset, number
    of dims). The export call does not depend on model type or training ratio.
    Each run goes to
    ``<output_path>/<dataset>__<model>__<ratio>__<dims>__<noise>__<repetition>``.
    Repetition ``r`` is seeded with ``seed + r`` so repetitions are independent;
    repetition 0 keeps ``seed``.

    Raises:
        ValueError: On an unknown dataset name.
    """
    exported_datasets_dir = output_path / "datasets"
    exported_datasets_dir.mkdir(exist_ok=True, parents=True)
    dataset_name_list = [
        "custom_dsprites_balanced",
        "extended_coil20"
    ]
    model_type_list = [
        "beta_vae",
        "MMLS"
    ]
    ratio_per_dim = False
    # No noise.
    noise_sigma_list = [
        0
    ]
    # Same training ratios for every dataset.
    training_ratio_list = [0.01, 0.02, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    number_of_dims_per_dataset = {
        "extended_coil20": [1, 2, 3],
        "custom_dsprites_balanced": [1, 2, 3, 4],
    }

    # (dataset_name, number_of_dims) pairs whose split has already been exported in this run.
    exported = set()

    for dataset_name, model_type, noise_sigma in itertools.product(
        dataset_name_list, model_type_list, noise_sigma_list
    ):
        try:
            number_of_dims_list = number_of_dims_per_dataset[dataset_name]
        except KeyError:
            raise ValueError(f"Unknown dataset name: {dataset_name}.") from None

        for number_of_dims, training_ratio in itertools.product(number_of_dims_list, training_ratio_list):
            # Training ratios below 0.1 are not run for 1-D variants.
            if number_of_dims == 1 and training_ratio < 0.1:
                continue

            export_key = (dataset_name, number_of_dims)
            if export_key not in exported:
                export_fixed_grid_test_set_and_rest_for_train(
                    dataset_name=dataset_name,
                    number_of_dims=number_of_dims,
                    output_dir=exported_datasets_dir,
                    max_test_size=max_test_size,
                    device=device
                )
                exported.add(export_key)

            for repetition_n in range(n_experiment_repetitions):
                model_dir = "__".join([
                    dataset_name,
                    model_type,
                    str(round(training_ratio, 3)),
                    str(number_of_dims),
                    str(noise_sigma),
                    str(repetition_n)
                ])
                train_model(
                    output_path=output_path / model_dir,
                    exported_datasets_dir=exported_datasets_dir,
                    # A distinct seed per repetition; a shared seed made every repetition identical.
                    seed=seed + repetition_n,
                    dataset_name=dataset_name,
                    model_type=model_type,
                    training_ratio=training_ratio,
                    ratio_per_dim=ratio_per_dim,
                    number_of_dims=number_of_dims,
                    noise_sigma=noise_sigma,
                    skip_done=skip_done,
                    device=device
                )


if __name__ == "__main__":
    # Dispatch to the Typer CLI when run as a script.
    app()