smartpool-examples 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28) hide show
  1. smartpool_examples-0.1.0/LICENSE +21 -0
  2. smartpool_examples-0.1.0/PKG-INFO +84 -0
  3. smartpool_examples-0.1.0/README.md +61 -0
  4. smartpool_examples-0.1.0/pyproject.toml +32 -0
  5. smartpool_examples-0.1.0/setup.cfg +4 -0
  6. smartpool_examples-0.1.0/smartpool_examples/__init__.py +0 -0
  7. smartpool_examples-0.1.0/smartpool_examples/count_prime/__init__.py +0 -0
  8. smartpool_examples-0.1.0/smartpool_examples/count_prime/__main__.py +31 -0
  9. smartpool_examples-0.1.0/smartpool_examples/count_prime/count_prime.py +18 -0
  10. smartpool_examples-0.1.0/smartpool_examples/cross_validation/__init__.py +0 -0
  11. smartpool_examples-0.1.0/smartpool_examples/cross_validation/__main__.py +244 -0
  12. smartpool_examples-0.1.0/smartpool_examples/cross_validation/config.py +11 -0
  13. smartpool_examples-0.1.0/smartpool_examples/cross_validation/data_utils.py +47 -0
  14. smartpool_examples-0.1.0/smartpool_examples/cross_validation/model_utils.py +173 -0
  15. smartpool_examples-0.1.0/smartpool_examples/cross_validation/models/LeNet5.py +27 -0
  16. smartpool_examples-0.1.0/smartpool_examples/cross_validation/models/MLP.py +21 -0
  17. smartpool_examples-0.1.0/smartpool_examples/cross_validation/models/ModernCNN.py +33 -0
  18. smartpool_examples-0.1.0/smartpool_examples/cross_validation/models/ResNeXt.py +57 -0
  19. smartpool_examples-0.1.0/smartpool_examples/cross_validation/models/ResNeXtV2.py +59 -0
  20. smartpool_examples-0.1.0/smartpool_examples/cross_validation/models/ResNet.py +54 -0
  21. smartpool_examples-0.1.0/smartpool_examples/cross_validation/models/ResNetV2.py +55 -0
  22. smartpool_examples-0.1.0/smartpool_examples/cross_validation/models/__init__.py +9 -0
  23. smartpool_examples-0.1.0/smartpool_examples/cross_validation/visualization.py +86 -0
  24. smartpool_examples-0.1.0/smartpool_examples.egg-info/PKG-INFO +84 -0
  25. smartpool_examples-0.1.0/smartpool_examples.egg-info/SOURCES.txt +26 -0
  26. smartpool_examples-0.1.0/smartpool_examples.egg-info/dependency_links.txt +1 -0
  27. smartpool_examples-0.1.0/smartpool_examples.egg-info/requires.txt +8 -0
  28. smartpool_examples-0.1.0/smartpool_examples.egg-info/top_level.txt +1 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 王炳辉
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,84 @@
1
+ Metadata-Version: 2.4
2
+ Name: smartpool-examples
3
+ Version: 0.1.0
4
+ Summary: Examples for smartpool.
5
+ Author-email: "王炳辉 (Bing-Hui WANG)" <binghui.wang@foxmail.com>
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/Time-Coder/smartpool
8
+ Project-URL: Repository, https://github.com/Time-Coder/smartpool.git
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Operating System :: OS Independent
12
+ Description-Content-Type: text/markdown
13
+ License-File: LICENSE
14
+ Requires-Dist: pysmartpool
15
+ Requires-Dist: matplotlib
16
+ Requires-Dist: scikit-learn
17
+ Requires-Dist: numpy
18
+ Requires-Dist: rich
19
+ Requires-Dist: click
20
+ Requires-Dist: joblib
21
+ Requires-Dist: datawatcher
22
+ Dynamic: license-file
23
+
24
+ # SmartPool Examples
25
+
26
+ This package contains practical examples demonstrating the capabilities of SmartPool for various computational tasks.
27
+
28
+ ## Examples Overview
29
+
30
+ ### 1. Prime Number Counting (`count_prime`)
31
+
32
+ Count the number of prime numbers below 10000 using smartpool.ProcessPool.
33
+ Demonstrates basic usage of smartpool.ProcessPool.
34
+
35
+ #### Running the Example
36
+
37
+ ```bash
38
+ python -m smartpool_examples.count_prime
39
+ ```
40
+
41
+ ### 2. Cross-Validation for Deep Learning models (`cross_validation`)
42
+
43
+ Demonstrates SmartPool's capabilities for machine learning workloads with GPU resource management.
44
+
45
+ #### Running the Example
46
+
47
+ ```bash
48
+ # Using ProcessPool
49
+ python -m smartpool_examples.cross_validation --pool smartpool.ProcessPool
50
+
51
+ # Using ThreadPool
52
+ python -m smartpool_examples.cross_validation --pool smartpool.ThreadPool
53
+
54
+ # Using multiprocessing.Pool
55
+ python -m smartpool_examples.cross_validation --pool multiprocessing.Pool
56
+
57
+ # Using concurrent.futures.ProcessPoolExecutor
58
+ python -m smartpool_examples.cross_validation --pool concurrent.futures.ProcessPoolExecutor
59
+
60
+ # Using concurrent.futures.ThreadPoolExecutor
61
+ python -m smartpool_examples.cross_validation --pool concurrent.futures.ThreadPoolExecutor
62
+
63
+ # Using joblib.Parallel(backend='loky')
64
+ python -m smartpool_examples.cross_validation --pool joblib.Parallel(backend='loky')
65
+
66
+ # Using joblib.Parallel(backend='threading')
67
+ python -m smartpool_examples.cross_validation --pool joblib.Parallel(backend='threading')
68
+
69
+ # Using Ray
70
+ python -m smartpool_examples.cross_validation --pool ray
71
+ ```
72
+
73
+ #### What it Demonstrates
74
+
75
+ - GPU memory management and core allocation
76
+ - Automatic device selection (CPU vs GPU)
77
+ - Cross-validation pipeline parallelization
78
+ - Resource monitoring during training
79
+ - Performance comparison with external frameworks
80
+
81
+
82
+ ## License
83
+
84
+ MIT License - see main smartpool repository for details
@@ -0,0 +1,61 @@
1
+ # SmartPool Examples
2
+
3
+ This package contains practical examples demonstrating the capabilities of SmartPool for various computational tasks.
4
+
5
+ ## Examples Overview
6
+
7
+ ### 1. Prime Number Counting (`count_prime`)
8
+
9
+ Count the number of prime numbers below 10000 using smartpool.ProcessPool.
10
+ Demonstrates basic usage of smartpool.ProcessPool.
11
+
12
+ #### Running the Example
13
+
14
+ ```bash
15
+ python -m smartpool_examples.count_prime
16
+ ```
17
+
18
+ ### 2. Cross-Validation for Deep Learning models (`cross_validation`)
19
+
20
+ Demonstrates SmartPool's capabilities for machine learning workloads with GPU resource management.
21
+
22
+ #### Running the Example
23
+
24
+ ```bash
25
+ # Using ProcessPool
26
+ python -m smartpool_examples.cross_validation --pool smartpool.ProcessPool
27
+
28
+ # Using ThreadPool
29
+ python -m smartpool_examples.cross_validation --pool smartpool.ThreadPool
30
+
31
+ # Using multiprocessing.Pool
32
+ python -m smartpool_examples.cross_validation --pool multiprocessing.Pool
33
+
34
+ # Using concurrent.futures.ProcessPoolExecutor
35
+ python -m smartpool_examples.cross_validation --pool concurrent.futures.ProcessPoolExecutor
36
+
37
+ # Using concurrent.futures.ThreadPoolExecutor
38
+ python -m smartpool_examples.cross_validation --pool concurrent.futures.ThreadPoolExecutor
39
+
40
+ # Using joblib.Parallel(backend='loky')
41
+ python -m smartpool_examples.cross_validation --pool joblib.Parallel(backend='loky')
42
+
43
+ # Using joblib.Parallel(backend='threading')
44
+ python -m smartpool_examples.cross_validation --pool joblib.Parallel(backend='threading')
45
+
46
+ # Using Ray
47
+ python -m smartpool_examples.cross_validation --pool ray
48
+ ```
49
+
50
+ #### What it Demonstrates
51
+
52
+ - GPU memory management and core allocation
53
+ - Automatic device selection (CPU vs GPU)
54
+ - Cross-validation pipeline parallelization
55
+ - Resource monitoring during training
56
+ - Performance comparison with external frameworks
57
+
58
+
59
+ ## License
60
+
61
+ MIT License - see main smartpool repository for details
@@ -0,0 +1,32 @@
1
+ [build-system]
2
+ requires = ["setuptools>=45", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "smartpool-examples"
7
+ version = "0.1.0"
8
+ description = "Examples for smartpool."
9
+ readme = "README.md"
10
+ authors = [
11
+ { name = "王炳辉 (Bing-Hui WANG)", email = "binghui.wang@foxmail.com" }
12
+ ]
13
+ license = { text = "MIT" }
14
+ classifiers = [
15
+ "Programming Language :: Python :: 3",
16
+ "License :: OSI Approved :: MIT License",
17
+ "Operating System :: OS Independent",
18
+ ]
19
+ dependencies = [
20
+ "pysmartpool",
21
+ "matplotlib",
22
+ "scikit-learn",
23
+ "numpy",
24
+ "rich",
25
+ "click",
26
+ "joblib",
27
+ "datawatcher"
28
+ ]
29
+
30
+ [project.urls]
31
+ Homepage = "https://github.com/Time-Coder/smartpool"
32
+ Repository = "https://github.com/Time-Coder/smartpool.git"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,31 @@
1
import os
import sys

# Make this example folder importable so `count_prime` resolves both when the
# package is run with `python -m` and when the file is launched directly.
self_folder = os.path.dirname(os.path.abspath(__file__)).replace("\\", "/")
sys.path.append(self_folder)

from smartpool import ProcessPool

from count_prime import count_prime


if __name__ == "__main__":
    print("Use ProcessPool to count prime numbers lower than 10000.")
    print(f"See source code at folder {os.path.dirname(os.path.abspath(__file__))}")

    # Split [0, 10000) into ten contiguous (start, stop) chunks of 1000 each.
    tasks = []
    start = 0
    while start < 10000:
        stop = start + 1000
        tasks.append((start, stop))
        start = stop

    with ProcessPool() as pool:
        # Fan the chunks out to worker processes.
        futures = []
        for task in tasks:
            future = pool.submit(count_prime, args=task)
            futures.append(future)

        # Reduce the per-chunk counts into the grand total (blocks on results).
        total_primes_count = sum(future.result() for future in futures)
        print(total_primes_count)
+
@@ -0,0 +1,18 @@
1
+ import math
2
+
3
+
4
def is_prime(num: int) -> bool:
    """Return True if *num* is a prime number.

    Trial division by 2 and by the odd numbers up to the integer square
    root.  Uses ``math.isqrt`` (exact for arbitrarily large ints) instead
    of ``math.sqrt``, whose float rounding can drop the last candidate
    divisor for very large inputs.
    """
    if num < 2:
        return False
    if num == 2:
        return True
    if num % 2 == 0:
        return False
    # Only odd divisors 3, 5, ... up to isqrt(num) need checking.
    for divisor in range(3, math.isqrt(num) + 1, 2):
        if num % divisor == 0:
            return False
    return True
11
+
12
+
13
def count_prime(start: int, stop: int):
    """Return how many primes lie in the half-open interval [start, stop)."""
    return sum(1 for candidate in range(start, stop) if is_prime(candidate))
@@ -0,0 +1,244 @@
1
from smartpool import ProcessPool, ThreadPool, DataSize, limit_num_single_thread
# Pin numeric libraries to one thread per worker before anything else imports
# them, so the pool (not BLAS/OpenMP) controls the parallelism.
limit_num_single_thread()

import click


@click.command(help=f"Use smartpool to do 5-fold cross validatation for 7 deep learning models for handwritten digit recognition task.")
@click.option(
    '--pool', default='smartpool.ProcessPool', type=click.Choice([
        'smartpool.ProcessPool',
        'smartpool.ThreadPool',
        'multiprocessing.Pool',
        'concurrent.futures.ProcessPoolExecutor',
        'concurrent.futures.ThreadPoolExecutor',
        "joblib.Parallel(backend='loky')",
        "joblib.Parallel(backend='threading')",
        "ray"
    ]),
    help="choose process pool implementations"
)
@click.option(
    '--max_workers', default=0, type=int,
    help='max number of workers to use, 0 to use all available cores'
)
# NOTE(review): the "smart" default below is dead — click always supplies the
# --pool option default ('smartpool.ProcessPool'); verify before relying on it.
def main(pool:str="smart", max_workers:int=0):
    """Cross-validate every model in ``models`` with the selected pool backend.

    Submits one training job per (model, fold) pair, renders live progress
    from a shared queue, then aggregates per-model validation accuracies.
    """
    import os
    # Presumably works around a Ray accelerator env-var quirk — TODO confirm.
    os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"

    print(f"Use {pool} to do 5-fold cross validatation for 7 deep learning models for handwritten digit recognition task.")
    print("Use `python -m smartpool_examples.cross_validation --help` to see all options.")
    print(f"See source code at folder {os.path.dirname(os.path.abspath(__file__))}")
    print("\npreparing data...")

    # torch/torchvision are imported lazily so the CLI can print a friendly
    # installation hint instead of a bare ImportError traceback.
    try:
        import torch
        import torch.nn as nn
    except ImportError:
        print("PyTorch is not installed. Follow https://pytorch.org/ instructions to install PyTorch.")
        exit(1)

    try:
        import torchvision
    except ImportError:
        print("torchvision is not installed. Use `pip install torchvision` to install torchvision.")
        exit(1)

    import time
    from sklearn.model_selection import KFold
    import numpy as np
    from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
    import multiprocessing as mp
    import queue
    from collections import defaultdict
    from concurrent.futures import Future
    from typing import Dict, Union

    import os
    import sys

    # Make the sibling example modules importable regardless of how the
    # package was launched.
    self_folder = os.path.dirname(os.path.abspath(__file__)).replace("\\", "/")
    sys.path.append(self_folder)

    import models
    from data_utils import prepare_data
    from model_utils import train_single_fold, ErrorInfo, ProgressInfo, TrainingResult
    from visualization import plot_results, print_results_table
    from config import EPOCHS

    if max_workers == 0:
        max_workers = os.cpu_count()

    # Discover every nn.Module subclass exported by the models package.
    model_classes = [
        cls for cls in models.__dict__.values()
        if isinstance(cls, type) and issubclass(cls, nn.Module) and cls != nn.Module
    ]

    dataset = prepare_data()
    kfold = KFold(n_splits=5, shuffle=True, random_state=42)

    # NOTE(review): the manager is created even for the ray branch, where its
    # queue is unused.
    manager = mp.Manager()

    # Workers report progress through a cross-process queue; ray needs its own
    # queue implementation, every other backend can use a Manager queue.
    if pool != "ray":
        progress_queue:queue.Queue[Union[ProgressInfo, ErrorInfo]] = manager.Queue()
    else:
        try:
            import ray
            import ray.util.queue
        except ImportError:
            print("Ray is not installed. Use `pip install ray` to install Ray.")
            exit(1)

        progress_queue:queue.Queue[Union[ProgressInfo, ErrorInfo]] = ray.util.queue.Queue()

    # One task per (fold, model) pair; indices are copied so each task owns
    # its arrays.
    tasks = []
    for fold_idx, (train_indices, val_indices) in enumerate(kfold.split(dataset)):
        for model_class in model_classes:
            tasks.append((fold_idx, model_class, train_indices.copy(), val_indices.copy(), dataset, progress_queue))

    task_progress_bars = {}
    best_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    start_time = time.perf_counter()
    with Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeRemainingColumn()
    ) as progress:

        # NOTE(review): active_tasks is written below but never read.
        active_tasks = {}

        # Instantiate the requested pool backend.  The smartpool variants are
        # resource-aware; the others are standard library / joblib.
        if pool == "smartpool.ProcessPool":
            process_pool = ProcessPool(max_workers=max_workers, use_torch=True)
        elif pool == "smartpool.ThreadPool":
            process_pool = ThreadPool(max_workers=max_workers, use_torch=True)
        elif pool == "concurrent.futures.ProcessPoolExecutor":
            from concurrent.futures import ProcessPoolExecutor
            process_pool = ProcessPoolExecutor(max_workers=max_workers)
        elif pool == "concurrent.futures.ThreadPoolExecutor":
            from concurrent.futures import ThreadPoolExecutor
            process_pool = ThreadPoolExecutor(max_workers=max_workers)
        elif pool == "multiprocessing.Pool":
            import multiprocessing
            process_pool = multiprocessing.Pool(processes=max_workers)
        elif pool == "joblib.Parallel(backend='loky')":
            from joblib import Parallel, delayed
            process_pool = Parallel(n_jobs=max_workers, backend='loky', return_as="generator")
        elif pool == "joblib.Parallel(backend='threading')":
            from joblib import Parallel, delayed
            process_pool = Parallel(n_jobs=max_workers, backend='threading', return_as="generator")

        print("submitting training tasks...")
        futures_map:Dict[str, Future] = {}
        futures = []
        for i, task_args in enumerate(tasks):
            if pool.startswith("smartpool."):
                # smartpool schedules by declared resource needs and picks
                # the device itself (no explicit device argument).
                future = process_pool.submit(
                    train_single_fold,
                    args=task_args,
                    need_cpu_cores=1,
                    need_cpu_mem=1.1*DataSize.GB,
                    need_gpu_cores=1000,
                    need_gpu_mem=0.2*DataSize.GB
                )
            elif pool.startswith("concurrent.futures."):
                # Non-smartpool backends get a static device assignment:
                # the first 5 tasks of each worker "round" use the best
                # device, the rest fall back to CPU.
                future = process_pool.submit(train_single_fold, *task_args, best_device if i % max_workers < 5 else 'cpu')
            elif pool == "multiprocessing.Pool":
                future = process_pool.apply_async(train_single_fold, args=(*task_args, best_device if i % max_workers < 5 else 'cpu'))
            elif pool.startswith("joblib"):
                # joblib "futures" are delayed calls executed later by the
                # Parallel object.
                future = delayed(train_single_fold)(*task_args, best_device if i % max_workers < 5 else 'cpu')
            elif pool == "ray":
                future = ray.remote(num_cpus=1, num_gpus=(0.2 if i % max_workers < 5 else 0), memory=1.1*DataSize.GB)(train_single_fold).remote(*task_args, best_device if i % max_workers < 5 else 'cpu')

            fold_idx = task_args[0]
            model_class = task_args[1]
            model_name = model_class.__name__
            task_key = f"{model_name}_fold_{fold_idx}"
            futures_map[task_key] = future
            futures.append(future)

        print(f"training all models in {pool} ...")
        if pool.startswith("joblib"):
            # Kick off the joblib run; results stream through this generator.
            joblib_results = process_pool(futures)

        # Drain the progress queue until every task reports completion or a
        # worker reports an error.
        finished_tasks = set()
        while True:
            progress_info:Union[ProgressInfo, ErrorInfo] = progress_queue.get()
            if isinstance(progress_info, ErrorInfo):
                print(progress_info.traceback)
                break

            task_key = f"{progress_info.model_name}_fold_{progress_info.fold_idx}"

            # Lazily create one progress bar per (model, fold) task.
            if task_key not in task_progress_bars:
                initial_desc = f"train {progress_info.model_name} on {progress_info.device} "
                initial_desc += f"for fold {progress_info.fold_idx+1}/5"
                task_progress_bars[task_key] = progress.add_task(initial_desc, total=100)
                active_tasks[task_key] = True

            if task_key in task_progress_bars:
                # NOTE(review): the literal 5s below hard-code EPOCHS (=5 in
                # config); they would silently break if EPOCHS changed.
                epoch_progress = (progress_info.epoch - 1) / 5
                batch_progress = progress_info.batch / progress_info.total_batches
                total_progress = (epoch_progress + batch_progress / 5) * 100

                if progress_info.epoch == 5 and progress_info.batch == progress_info.total_batches:
                    total_progress = 100.0
                    finished_tasks.add(task_key)

                new_desc = f"train {progress_info.model_name} on {progress_info.device} "
                new_desc += f"for fold {progress_info.fold_idx+1}/5 - Epoch {progress_info.epoch}/{EPOCHS} "
                new_desc += f"Loss: {progress_info.avg_loss:.4f} "
                new_desc += f"Val Acc: {progress_info.val_accuracy*100:.2f}%"
                if progress_info.device.startswith("cuda"):
                    new_desc = "[bright_cyan]" + new_desc

                progress.update(
                    task_progress_bars[task_key],
                    completed=total_progress,
                    description=new_desc
                )
                # Hide bars of finished tasks so the display stays compact.
                if total_progress >= 100.0:
                    progress.update(task_progress_bars[task_key], visible=False)

            if len(finished_tasks) == len(futures_map):
                break

    # Collect results per backend: futures, joblib generator, or ray handles.
    model_results = defaultdict(list)
    if pool in ["smartpool.ProcessPool", "smartpool.ThreadPool", "concurrent.futures.ProcessPoolExecutor", "concurrent.futures.ThreadPoolExecutor", "multiprocessing.Pool"]:
        for task_key, future in futures_map.items():
            if pool == "multiprocessing.Pool":
                result:TrainingResult = future.get()
            else:
                result:TrainingResult = future.result()

            model_results[result.model_name].append(result.val_accuracy)
    elif pool.startswith("joblib"):
        for result in joblib_results:
            model_results[result.model_name].append(result.val_accuracy)
    elif pool == "ray":
        ray_results = ray.get(futures)
        for result in ray_results:
            model_results[result.model_name].append(result.val_accuracy)

    stop_time = time.perf_counter()
    print(f"train completed in {stop_time - start_time:.2f} seconds")

    print("analysing results...")

    # Summary statistics of the 5 per-fold accuracies for each model.
    stats = {}
    for model_name, accuracies in model_results.items():
        stats[model_name] = {
            'mean': np.mean(accuracies),
            'std': np.std(accuracies),
            'min': np.min(accuracies),
            'max': np.max(accuracies),
            'accuracies': accuracies
        }

    print_results_table(stats)
    plot_results(model_results, stats)


if __name__ == "__main__":
    main()
@@ -0,0 +1,11 @@
1
# Training parameters
BATCH_SIZE = 128       # mini-batch size for both train and validation loaders
EPOCHS = 5             # training epochs per (model, fold) job
LEARNING_RATE = 0.001  # Adam learning rate

# Data settings
import os
self_folder = os.path.dirname(os.path.abspath(__file__))

# MNIST is downloaded/cached under <this package>/data.
DATA_ROOT = f'{self_folder}/data'
DATASET_NAME = 'MNIST'
@@ -0,0 +1,47 @@
1
+ import os
2
+ from torch.utils.data import DataLoader, Subset
3
+ from torchvision import datasets, transforms
4
+
5
+ from config import DATA_ROOT, BATCH_SIZE
6
+
7
+
8
def prepare_data():
    """Load the MNIST training split, downloading it only when not cached.

    The underlying data/target tensors are moved into shared memory so
    forked worker processes can read the dataset without copying it.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # Standard MNIST mean/std normalisation constants.
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Skip the download when both cache folders already exist on disk.
    mnist_exists = (
        os.path.exists(os.path.join(DATA_ROOT, 'MNIST', 'raw')) and
        os.path.exists(os.path.join(DATA_ROOT, 'MNIST', 'processed'))
    )

    dataset = datasets.MNIST(
        root=DATA_ROOT,
        train=True,
        download=not mnist_exists,
        transform=transform
    )
    # Place tensors in shared memory for cheap cross-process access.
    dataset.data.share_memory_()
    dataset.targets.share_memory_()
    return dataset
28
+
29
+
30
def create_data_loaders(dataset, train_indices, val_indices):
    """Build the (train, val) DataLoader pair over the given index splits.

    Training batches are shuffled, validation batches are not; both use
    pinned memory for faster host-to-device transfer.
    """
    loaders = []
    for indices, shuffle in ((train_indices, True), (val_indices, False)):
        subset = Subset(dataset, indices)
        loaders.append(
            DataLoader(
                subset,
                batch_size=BATCH_SIZE,
                shuffle=shuffle,
                pin_memory=True,
            )
        )
    return tuple(loaders)
@@ -0,0 +1,173 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+
5
+ import traceback
6
+ from dataclasses import dataclass
7
+
8
+ from config import LEARNING_RATE, EPOCHS
9
+ from data_utils import create_data_loaders
10
+
11
+ from smartpool import move_optimizer_to, best_device
12
+
13
+
14
@dataclass
class TrainingResult:
    """Final outcome of one (model, fold) training job."""
    fold_idx:int       # 0-based cross-validation fold index
    model_name:str     # class name of the trained model
    val_accuracy:float # validation accuracy in [0, 1] after the last epoch
19
+
20
+
21
@dataclass
class ProgressInfo:
    """One progress heartbeat sent from a training worker to the parent."""
    model_name:str      # class name of the model being trained
    fold_idx:int        # 0-based cross-validation fold index
    epoch:int           # 1-based epoch currently running
    batch:int           # 1-based batch just completed (0 in the initial record)
    total_batches:int   # number of batches per epoch
    device:str          # device string the batch ran on, e.g. 'cpu' / 'cuda'
    avg_loss:float      # running mean training loss of the current epoch
    val_accuracy:float  # most recent validation accuracy (previous epoch's)
31
+
32
+
33
@dataclass
class ErrorInfo:
    """Carries a worker-side failure to the parent via the progress queue."""
    exception:BaseException  # the exception instance raised in the worker
    traceback:str            # traceback.format_exc() captured at the raise site
37
+
38
def train_single_fold(fold_idx, model_class, train_indices, val_indices, dataset, progress_queue, device=None):
    """Run one (model, fold) training job, reporting failures before re-raising.

    Thin wrapper around ``_train_single_fold`` that forwards any exception's
    traceback to the parent's progress loop through *progress_queue*.
    """
    try:
        result = _train_single_fold(fold_idx, model_class, train_indices, val_indices, dataset, progress_queue, device)
    except BaseException as exc:
        # Let the parent print the worker traceback before the pool surfaces
        # the failure through the future.
        progress_queue.put(ErrorInfo(exc, traceback.format_exc()))
        raise exc
    return result
45
+
46
def _train_single_fold(fold_idx, model_class, train_indices, val_indices, dataset, progress_queue, user_device):
    """Train one model on one fold and return its final validation accuracy.

    Streams ProgressInfo records to *progress_queue*.  When *user_device* is
    None the device is re-chosen with smartpool.best_device() before every
    batch, and the model/optimizer migrate whenever it changes; otherwise the
    given device is used for the whole run.
    """
    train_loader, val_loader = create_data_loaders(dataset, train_indices, val_indices)
    num_batches = len(train_loader)
    model = model_class()

    # Initial device: caller-specified, or the currently best available one.
    device = user_device
    if user_device is None:
        device = best_device()

    old_device = device  # device the model/optimizer currently live on
    model.to(device, non_blocking=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Emit one record up-front so the parent can create the progress bar
    # before the first batch completes.
    initial_progress = ProgressInfo(
        model_name=model_class.__name__,
        fold_idx=fold_idx,
        epoch=1,
        batch=0,
        total_batches=num_batches,
        device=device,
        avg_loss=0.0,
        val_accuracy=0.0
    )
    progress_queue.put(initial_progress)

    last_val_accuracy = 0.0  # previous epoch's accuracy, shown while training

    for epoch in range(EPOCHS):
        epoch_loss = 0.0
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            # Re-evaluate the best device each batch (dynamic scheduling).
            if user_device is None:
                device = best_device()

            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            # Migrate model and optimizer state if the target device changed.
            if device != old_device:
                model.to(device, non_blocking=True)
                move_optimizer_to(optimizer, device)
                old_device = device

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            progress_info = ProgressInfo(
                model_name=model_class.__name__,
                fold_idx=fold_idx,
                epoch=epoch + 1,
                batch=batch_idx + 1,
                total_batches=num_batches,
                device=device,
                avg_loss=epoch_loss / (batch_idx + 1),  # running mean this epoch
                val_accuracy=last_val_accuracy
            )
            progress_queue.put(progress_info)

        # ---- per-epoch validation pass ----
        model.eval()
        correct = 0
        total = 0
        val_accuracy = 0.0

        with torch.no_grad():
            for data, target in val_loader:
                if user_device is None:
                    device = best_device()

                data = data.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)

                if device != old_device:
                    model.to(device, non_blocking=True)
                    move_optimizer_to(optimizer, device)
                    old_device = device

                output = model(data)
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                total += target.size(0)

        val_accuracy = correct / total
        last_val_accuracy = val_accuracy
        model.train()

        # Epoch-completion record; the parent treats (epoch==EPOCHS,
        # batch==total_batches) as "task finished".
        final_progress = ProgressInfo(
            model_name=model_class.__name__,
            fold_idx=fold_idx,
            epoch=epoch + 1,
            batch=num_batches,
            total_batches=num_batches,
            device=device,
            avg_loss=epoch_loss / num_batches,
            val_accuracy=val_accuracy
        )
        progress_queue.put(final_progress)

    # ---- final validation pass ----
    # NOTE(review): this duplicates the last epoch's validation loop, so the
    # validation set is evaluated once more after training ends.
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in val_loader:
            if user_device is None:
                device = best_device()

            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            if device != old_device:
                model.to(device, non_blocking=True)
                move_optimizer_to(optimizer, device)
                old_device = device

            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)

    val_accuracy = correct / total

    return TrainingResult(fold_idx, model_class.__name__, val_accuracy)
@@ -0,0 +1,27 @@
1
+ import torch.nn as nn
2
+
3
+
4
class LeNet5(nn.Module):
    """Classic LeNet-5 style CNN for 1x28x28 digit images (10 classes)."""

    def __init__(self):
        super().__init__()
        # Two conv/pool stages: 28x28 -> 24x24 -> 12x12 -> 8x8 -> 4x4.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Classifier head over the flattened 16x4x4 feature map.
        self.fc = nn.Sequential(
            nn.Linear(16 * 4 * 4, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        features = self.conv(x)
        return self.fc(features.flatten(1))
@@ -0,0 +1,21 @@
1
+ import torch.nn as nn
2
+
3
+
4
class MLP(nn.Module):
    """Two-hidden-layer fully connected classifier for flattened 28x28 images."""

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        # 784 -> 512 -> 256 -> 10 with dropout between hidden layers.
        self.fc = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        return self.fc(self.flatten(x))
@@ -0,0 +1,33 @@
1
+ import torch.nn as nn
2
+
3
+
4
class ModernCNN(nn.Module):
    """CNN with batch norm and dropout for 1x28x28 digit images."""

    def __init__(self):
        super().__init__()
        # Feature extractor: two downsampling stages, 28x28 -> 14x14 -> 7x7.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
        )
        # Classifier over the flattened 128x7x7 feature map.
        self.fc = nn.Sequential(
            nn.Linear(128 * 7 * 7, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        features = self.conv(x)
        return self.fc(features.flatten(1))
@@ -0,0 +1,57 @@
1
+ import torch.nn as nn
2
+
3
class ResNeXtBlock(nn.Module):
    """Post-activation ResNeXt bottleneck: 1x1 -> grouped 3x3 -> 1x1 (2x expansion)."""

    def __init__(self, in_channels, out_channels, cardinality=32, stride=1):
        super().__init__()
        expanded = out_channels * 2  # bottleneck output width
        self.conv1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=stride, padding=1, groups=cardinality, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, expanded, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded)
        self.relu = nn.ReLU(inplace=True)

        # Project the identity when spatial size or channel count changes.
        if stride != 1 or in_channels != expanded:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, expanded, 1, stride=stride, bias=False),
                nn.BatchNorm2d(expanded),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu(out + self.shortcut(x))
29
+
30
class ResNeXt(nn.Module):
    """Small ResNeXt: stem conv + two stages of grouped bottleneck blocks."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        # Each stage halves the spatial size; blocks expand channels 2x.
        self.layer1 = self._make_layer(32, 64, 2, stride=2)    # -> 128 ch
        self.layer2 = self._make_layer(128, 128, 2, stride=2)  # -> 256 ch

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, 10)

    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        """Stack *blocks* ResNeXt blocks; only the first may downsample."""
        stage = [ResNeXtBlock(in_channels, out_channels, stride=stride)]
        stage.extend(ResNeXtBlock(out_channels * 2, out_channels) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.layer2(self.layer1(out))
        out = self.avgpool(out)
        return self.fc(out.flatten(1))
@@ -0,0 +1,59 @@
1
+ import torch.nn as nn
2
+
3
class PreActResNeXtBlock(nn.Module):
    """Pre-activation ResNeXt bottleneck (BN+ReLU before each conv, 2x expansion)."""

    def __init__(self, in_channels, out_channels, cardinality=32, stride=1):
        super().__init__()
        expanded = out_channels * 2  # bottleneck output width
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=stride, padding=1, groups=cardinality, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, expanded, 1, bias=False)

        # Pre-act variant: the projection shortcut carries no batch norm.
        if stride != 1 or in_channels != expanded:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, expanded, 1, stride=stride, bias=False)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        out = self.conv2(self.relu(self.bn2(out)))
        out = self.conv3(self.relu(self.bn3(out)))
        return out + self.shortcut(x)
30
+
31
class ResNeXtV2(nn.Module):
    """Pre-activation ResNeXt: stem conv, two stages, final BN+ReLU before pooling."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)

        self.layer1 = self._make_layer(32, 64, 2, stride=2)    # -> 128 ch
        self.layer2 = self._make_layer(128, 128, 2, stride=2)  # -> 256 ch

        # Pre-act networks normalise/activate once at the end of the trunk.
        self.bn = nn.BatchNorm2d(256)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, 10)

    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        """Stack pre-activation blocks; only the first may downsample."""
        stage = [PreActResNeXtBlock(in_channels, out_channels, stride=stride)]
        stage.extend(PreActResNeXtBlock(out_channels * 2, out_channels) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        out = self.layer2(self.layer1(self.conv1(x)))
        out = self.avgpool(self.relu(self.bn(out)))
        return self.fc(out.flatten(1))
@@ -0,0 +1,54 @@
1
+ import torch.nn as nn
2
+
3
class ResidualBlock(nn.Module):
    """Basic post-activation residual block (original ResNet style).

    Two 3x3 convs with BatchNorm; the skip connection is projected through
    a 1x1 conv + BN whenever the stride or channel count changes.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Identity skip unless the output shape differs from the input.
        if stride == 1 and in_channels == out_channels:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return relu(F(x) + shortcut(x))."""
        skip = self.shortcut(x)
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return self.relu(branch + skip)
26
+
27
class ResNet(nn.Module):
    """Compact post-activation ResNet for single-channel images, 10 classes."""

    def __init__(self):
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(32, 64, 2, stride=2)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, 10)

    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        """Stack ``blocks`` ResidualBlocks; only the first one downsamples."""
        stages = [ResidualBlock(in_channels, out_channels, stride)]
        stages.extend(
            ResidualBlock(out_channels, out_channels) for _ in range(blocks - 1)
        )
        return nn.Sequential(*stages)

    def forward(self, x):
        """Return raw class logits of shape (batch, 10)."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.layer2(self.layer1(out))
        out = self.avgpool(out)
        return self.fc(out.flatten(1))
@@ -0,0 +1,55 @@
1
+ import torch.nn as nn
2
+
3
class PreActBlock(nn.Module):
    """Pre-activation basic residual block (ResNet v2): BN-ReLU before each conv.

    The skip connection is projected with a bare 1x1 conv (no BN -- the next
    block's bn1 normalizes) whenever the stride or channel count changes.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3,
                               padding=1, bias=False)

        # Identity skip unless the output shape differs from the input.
        if stride == 1 and in_channels == out_channels:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False)
            )

    def forward(self, x):
        """Return F(x) + shortcut(x); no activation after the addition."""
        skip = self.shortcut(x)
        out = self.conv1(self.relu(self.bn1(x)))
        out = self.conv2(self.relu(self.bn2(out)))
        return out + skip
26
+
27
class ResNetV2(nn.Module):
    """Compact pre-activation ResNet (v2) for single-channel images, 10 classes.

    The stem conv carries no BN/ReLU -- the first PreActBlock normalizes its
    own input -- and a final BN-ReLU runs before pooling.
    """

    def __init__(self):
        super(ResNetV2, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)

        self.layer1 = self._make_layer(32, 64, 2, stride=2)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)

        self.bn = nn.BatchNorm2d(128)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, 10)

    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        """Stack ``blocks`` PreActBlocks; only the first one downsamples."""
        stages = [PreActBlock(in_channels, out_channels, stride)]
        stages.extend(
            PreActBlock(out_channels, out_channels) for _ in range(blocks - 1)
        )
        return nn.Sequential(*stages)

    def forward(self, x):
        """Return raw class logits of shape (batch, 10)."""
        out = self.layer2(self.layer1(self.conv1(x)))
        out = self.avgpool(self.relu(self.bn(out)))
        return self.fc(out.flatten(1))
@@ -0,0 +1,9 @@
1
+ from .LeNet5 import LeNet5
2
+ from .MLP import MLP
3
+ from .ModernCNN import ModernCNN
4
+ from .ResNet import ResNet
5
+ from .ResNetV2 import ResNetV2
6
+ from .ResNeXt import ResNeXt
7
+ from .ResNeXtV2 import ResNeXtV2
8
+
9
+ __all__ = ['LeNet5', 'MLP', 'ModernCNN', 'ResNet', 'ResNetV2', 'ResNeXt', 'ResNeXtV2']
@@ -0,0 +1,86 @@
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import rich
4
+ from rich.table import Table
5
+ from typing import Dict, Any
6
+
7
+
8
def plot_results(model_results: Dict[str, Any], stats: Dict[str, Any]):
    """Render a grouped bar chart of per-fold validation accuracies.

    One bar group per model (five folds each), overlaid with mean +/- std
    error bars and value annotations. Saves the figure to
    ``cross_validation_results.png``, then displays it.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    # Fonts that can render CJK text in labels; keep ASCII minus signs.
    plt.rcParams["font.family"] = ["Microsoft YaHei", "Arial", "sans-serif"]
    plt.rcParams["axes.unicode_minus"] = False

    num_models = len(model_results)
    num_folds = 5
    group_centers = np.arange(num_models)

    # Split a fixed group width evenly among the folds so each bar group
    # sits centered on its x position.
    group_width = 0.8
    single_width = group_width / num_folds

    # One fixed color per fold.
    fold_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']

    for fold in range(num_folds):
        # Models missing this fold get a zero-height bar.
        heights = [
            per_model[fold] if fold < len(per_model) else 0
            for per_model in model_results.values()
        ]
        # Offset each fold's bar so the whole group is centered.
        shift = (fold - num_folds / 2 + 0.5) * single_width
        ax.bar(group_centers + shift, heights,
               single_width, alpha=0.7,
               color=fold_colors[fold],
               label=f'Fold {fold+1}')

    means = [entry['mean'] for entry in stats.values()]
    stds = [entry['std'] for entry in stats.values()]

    ax.errorbar(group_centers, means, yerr=stds, fmt='o', color='black',
                capsize=5, capthick=2, elinewidth=2, markersize=8,
                label='Mean ± Std Dev')

    for idx, (mu, sigma) in enumerate(zip(means, stds)):
        ax.annotate(f'{mu:.4f}\n±{sigma:.4f}',
                    xy=(group_centers[idx], mu),
                    xytext=(0, 10),
                    textcoords='offset points',
                    ha='center', va='bottom',
                    fontsize=9)

    ax.set_xlabel('Model Type', fontsize=12)
    ax.set_ylabel('Validation Accuracy', fontsize=12)
    ax.set_title('Handwritten Digit Recognition Models: 5-Fold Cross-Validation Comparison', fontsize=14, pad=20)
    ax.set_xticks(group_centers)
    ax.set_xticklabels(model_results.keys())
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0.8, 1.0)

    plt.tight_layout()
    plt.savefig('cross_validation_results.png', dpi=300, bbox_inches='tight')
    plt.show()
67
+
68
def print_results_table(stats: Dict[str, Any]):
    """Print a rich table summarizing cross-validation results per model.

    ``stats`` maps model name -> dict with 'mean', 'std', 'min' and 'max'
    accuracy values; one table row is emitted per model.
    """
    table = Table(title="5-Fold Cross-Validation Results Summary")
    table.add_column("Model", style="cyan")
    table.add_column("Mean Accuracy", style="magenta")
    table.add_column("Std Deviation", style="green")
    table.add_column("Min", style="yellow")
    table.add_column("Max", style="blue")

    # The loop already unpacks the value; no need to re-fetch stats[model_name].
    for model_name, stat in stats.items():
        table.add_row(
            model_name,
            f"{stat['mean']:.4f}",
            f"{stat['std']:.4f}",
            f"{stat['min']:.4f}",
            f"{stat['max']:.4f}"
        )

    rich.print(table)
@@ -0,0 +1,84 @@
1
+ Metadata-Version: 2.4
2
+ Name: smartpool-examples
3
+ Version: 0.1.0
4
+ Summary: Examples for smartpool.
5
+ Author-email: "王炳辉 (Bing-Hui WANG)" <binghui.wang@foxmail.com>
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/Time-Coder/smartpool
8
+ Project-URL: Repository, https://github.com/Time-Coder/smartpool.git
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Operating System :: OS Independent
12
+ Description-Content-Type: text/markdown
13
+ License-File: LICENSE
14
+ Requires-Dist: pysmartpool
15
+ Requires-Dist: matplotlib
16
+ Requires-Dist: scikit-learn
17
+ Requires-Dist: numpy
18
+ Requires-Dist: rich
19
+ Requires-Dist: click
20
+ Requires-Dist: joblib
21
+ Requires-Dist: datawatcher
22
+ Dynamic: license-file
23
+
24
+ # SmartPool Examples
25
+
26
+ This package contains practical examples demonstrating the capabilities of SmartPool for various computational tasks.
27
+
28
+ ## Examples Overview
29
+
30
+ ### 1. Prime Number Counting (`count_prime`)
31
+
32
+ Count the number of prime numbers below 10000 using smartpool.ProcessPool.
33
+ Demonstrates basic usage of smartpool.ProcessPool.
34
+
35
+ #### Running the Example
36
+
37
+ ```bash
38
+ python -m smartpool_examples.count_prime
39
+ ```
40
+
41
+ ### 2. Cross-Validation for Deep Learning Models (`cross_validation`)
42
+
43
+ Demonstrates SmartPool's capabilities for machine learning workloads with GPU resource management.
44
+
45
+ #### Running the Example
46
+
47
+ ```bash
48
+ # Using ProcessPool
49
+ python -m smartpool_examples.cross_validation --pool smartpool.ProcessPool
50
+
51
+ # Using ThreadPool
52
+ python -m smartpool_examples.cross_validation --pool smartpool.ThreadPool
53
+
54
+ # Using multiprocessing.Pool
55
+ python -m smartpool_examples.cross_validation --pool multiprocessing.Pool
56
+
57
+ # Using concurrent.futures.ProcessPoolExecutor
58
+ python -m smartpool_examples.cross_validation --pool concurrent.futures.ProcessPoolExecutor
59
+
60
+ # Using concurrent.futures.ThreadPoolExecutor
61
+ python -m smartpool_examples.cross_validation --pool concurrent.futures.ThreadPoolExecutor
62
+
63
+ # Using joblib.Parallel(backend='loky')
64
+ python -m smartpool_examples.cross_validation --pool joblib.Parallel(backend='loky')
65
+
66
+ # Using joblib.Parallel(backend='threading')
67
+ python -m smartpool_examples.cross_validation --pool joblib.Parallel(backend='threading')
68
+
69
+ # Using Ray
70
+ python -m smartpool_examples.cross_validation --pool ray
71
+ ```
72
+
73
+ #### What it Demonstrates
74
+
75
+ - GPU memory management and core allocation
76
+ - Automatic device selection (CPU vs GPU)
77
+ - Cross-validation pipeline parallelization
78
+ - Resource monitoring during training
79
+ - Performance comparison with external frameworks
80
+
81
+
82
+ ## License
83
+
84
+ MIT License - see main smartpool repository for details
@@ -0,0 +1,26 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ smartpool_examples/__init__.py
5
+ smartpool_examples.egg-info/PKG-INFO
6
+ smartpool_examples.egg-info/SOURCES.txt
7
+ smartpool_examples.egg-info/dependency_links.txt
8
+ smartpool_examples.egg-info/requires.txt
9
+ smartpool_examples.egg-info/top_level.txt
10
+ smartpool_examples/count_prime/__init__.py
11
+ smartpool_examples/count_prime/__main__.py
12
+ smartpool_examples/count_prime/count_prime.py
13
+ smartpool_examples/cross_validation/__init__.py
14
+ smartpool_examples/cross_validation/__main__.py
15
+ smartpool_examples/cross_validation/config.py
16
+ smartpool_examples/cross_validation/data_utils.py
17
+ smartpool_examples/cross_validation/model_utils.py
18
+ smartpool_examples/cross_validation/visualization.py
19
+ smartpool_examples/cross_validation/models/LeNet5.py
20
+ smartpool_examples/cross_validation/models/MLP.py
21
+ smartpool_examples/cross_validation/models/ModernCNN.py
22
+ smartpool_examples/cross_validation/models/ResNeXt.py
23
+ smartpool_examples/cross_validation/models/ResNeXtV2.py
24
+ smartpool_examples/cross_validation/models/ResNet.py
25
+ smartpool_examples/cross_validation/models/ResNetV2.py
26
+ smartpool_examples/cross_validation/models/__init__.py
@@ -0,0 +1,8 @@
1
+ pysmartpool
2
+ matplotlib
3
+ scikit-learn
4
+ numpy
5
+ rich
6
+ click
7
+ joblib
8
+ datawatcher
@@ -0,0 +1 @@
1
+ smartpool_examples