smartpool-examples 0.1.3__tar.gz → 0.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {smartpool_examples-0.1.3/smartpool_examples.egg-info → smartpool_examples-0.1.6}/PKG-INFO +24 -4
  2. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/README.md +20 -0
  3. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/pyproject.toml +9 -6
  4. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples/count_prime/__main__.py +8 -6
  5. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples/count_prime/count_prime.py +3 -5
  6. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples/cross_validation/__main__.py +88 -59
  7. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples/cross_validation/config.py +4 -3
  8. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples/cross_validation/data_utils.py +16 -16
  9. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples/cross_validation/model_utils.py +23 -24
  10. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples/cross_validation/models/LeNet5.py +4 -4
  11. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples/cross_validation/models/MLP.py +3 -3
  12. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples/cross_validation/models/ModernCNN.py +4 -4
  13. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples/cross_validation/models/ResNeXt.py +7 -6
  14. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples/cross_validation/models/ResNeXtV2.py +7 -6
  15. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples/cross_validation/models/ResNet.py +7 -6
  16. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples/cross_validation/models/ResNetV2.py +7 -6
  17. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples/cross_validation/models/__init__.py +1 -1
  18. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples/cross_validation/visualization.py +21 -20
  19. smartpool_examples-0.1.6/smartpool_examples/onnx_infer/__init__.py +0 -0
  20. smartpool_examples-0.1.6/smartpool_examples/onnx_infer/__main__.py +88 -0
  21. smartpool_examples-0.1.6/smartpool_examples/onnx_infer/config.py +32 -0
  22. smartpool_examples-0.1.6/smartpool_examples/onnx_infer/data_utils.py +78 -0
  23. smartpool_examples-0.1.6/smartpool_examples/onnx_infer/inference.py +130 -0
  24. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6/smartpool_examples.egg-info}/PKG-INFO +24 -4
  25. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples.egg-info/SOURCES.txt +6 -1
  26. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples.egg-info/requires.txt +3 -3
  27. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/LICENSE +0 -0
  28. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/setup.cfg +0 -0
  29. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples/__init__.py +0 -0
  30. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples/count_prime/__init__.py +0 -0
  31. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples/cross_validation/__init__.py +0 -0
  32. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples.egg-info/dependency_links.txt +0 -0
  33. {smartpool_examples-0.1.3 → smartpool_examples-0.1.6}/smartpool_examples.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: smartpool-examples
3
- Version: 0.1.3
3
+ Version: 0.1.6
4
4
  Summary: Examples for smartpool.
5
5
  Author-email: "王炳辉 (Bing-Hui WANG)" <binghui.wang@foxmail.com>
6
6
  License: MIT
@@ -11,14 +11,14 @@ Classifier: License :: OSI Approved :: MIT License
11
11
  Classifier: Operating System :: OS Independent
12
12
  Description-Content-Type: text/markdown
13
13
  License-File: LICENSE
14
- Requires-Dist: pysmartpool
14
+ Requires-Dist: pysmartpool>=0.1.6
15
15
  Requires-Dist: matplotlib
16
16
  Requires-Dist: scikit-learn
17
17
  Requires-Dist: numpy
18
18
  Requires-Dist: rich
19
- Requires-Dist: click
19
+ Requires-Dist: typer
20
20
  Requires-Dist: joblib
21
- Requires-Dist: datawatcher
21
+ Requires-Dist: opencv-python
22
22
  Dynamic: license-file
23
23
 
24
24
  # SmartPool Examples
@@ -79,6 +79,26 @@ python -m smartpool_examples.cross_validation --pool ray
79
79
  - Performance comparison with external frameworks
80
80
 
81
81
 
82
+ ### 3. ONNX Inference (`onnx_infer`)
83
+
84
+ Runs batched ONNX model inference using `InferSessionPool` for concurrent GPU/CPU execution.
85
+ Automatically manages inference sessions across worker threads.
86
+
87
+ #### Running the Example
88
+
89
+ ```bash
90
+ python -m smartpool_examples.onnx_infer --max-workers 4
91
+ ```
92
+
93
+ #### What it Demonstrates
94
+
95
+ - `InferSessionPool` creation and session lifecycle management
96
+ - Multi-threaded inference with automatic device placement
97
+ - COCO-format image preprocessing (resize, normalize, letterbox)
98
+ - Softmax + top-5 postprocessing
99
+ - Progress bars for downloads and inference steps
100
+
101
+
82
102
  ## License
83
103
 
84
104
  MIT License - see main smartpool repository for details
@@ -56,6 +56,26 @@ python -m smartpool_examples.cross_validation --pool ray
56
56
  - Performance comparison with external frameworks
57
57
 
58
58
 
59
+ ### 3. ONNX Inference (`onnx_infer`)
60
+
61
+ Runs batched ONNX model inference using `InferSessionPool` for concurrent GPU/CPU execution.
62
+ Automatically manages inference sessions across worker threads.
63
+
64
+ #### Running the Example
65
+
66
+ ```bash
67
+ python -m smartpool_examples.onnx_infer --max-workers 4
68
+ ```
69
+
70
+ #### What it Demonstrates
71
+
72
+ - `InferSessionPool` creation and session lifecycle management
73
+ - Multi-threaded inference with automatic device placement
74
+ - COCO-format image preprocessing (resize, normalize, letterbox)
75
+ - Softmax + top-5 postprocessing
76
+ - Progress bars for downloads and inference steps
77
+
78
+
59
79
  ## License
60
80
 
61
81
  MIT License - see main smartpool repository for details
@@ -1,10 +1,10 @@
1
1
  [build-system]
2
- requires = ["setuptools>=45", "wheel"]
2
+ requires = ["setuptools", "wheel"]
3
3
  build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "smartpool-examples"
7
- version = "0.1.3"
7
+ version = "0.1.6"
8
8
  description = "Examples for smartpool."
9
9
  readme = "README.md"
10
10
  authors = [
@@ -17,16 +17,19 @@ classifiers = [
17
17
  "Operating System :: OS Independent",
18
18
  ]
19
19
  dependencies = [
20
- "pysmartpool",
20
+ "pysmartpool>=0.1.6",
21
21
  "matplotlib",
22
22
  "scikit-learn",
23
23
  "numpy",
24
24
  "rich",
25
- "click",
25
+ "typer",
26
26
  "joblib",
27
- "datawatcher"
27
+ "opencv-python",
28
28
  ]
29
29
 
30
30
  [project.urls]
31
31
  Homepage = "https://github.com/Time-Coder/smartpool"
32
- Repository = "https://github.com/Time-Coder/smartpool.git"
32
+ Repository = "https://github.com/Time-Coder/smartpool.git"
33
+
34
+ [tool.setuptools.packages.find]
35
+ include = ["smartpool_examples", "smartpool_examples.*"]
@@ -4,28 +4,30 @@ import sys
4
4
  self_folder = os.path.dirname(os.path.abspath(__file__)).replace("\\", "/")
5
5
  sys.path.append(self_folder)
6
6
 
7
- from smartpool import ProcessPool
7
+ if __name__ == "__main__":
8
+ target_folder = os.path.abspath(self_folder + "/../../../smartpool").replace("\\", "/")
9
+ sys.path.append(target_folder)
8
10
 
9
11
  from count_prime import count_prime
10
12
 
13
+ from smartpool import ProcessPool
11
14
 
12
15
  if __name__ == "__main__":
13
16
  print("Use ProcessPool to count prime numbers lower than 10000.")
14
- print(f"See source code at folder {os.path.dirname(os.path.abspath(__file__))}")
15
-
17
+ print(f"See source code at folder {self_folder}")
18
+
16
19
  tasks = []
17
20
  start = 0
18
21
  while start < 10000:
19
22
  stop = start + 1000
20
23
  tasks.append((start, stop))
21
24
  start = stop
22
-
25
+
23
26
  with ProcessPool() as pool:
24
27
  futures = []
25
28
  for task in tasks:
26
29
  future = pool.submit(count_prime, args=task)
27
30
  futures.append(future)
28
-
31
+
29
32
  total_primes_count = sum(future.result() for future in futures)
30
33
  print(total_primes_count)
31
-
@@ -4,10 +4,8 @@ import math
4
4
  def is_prime(num:int):
5
5
  if num < 2:
6
6
  return False
7
- for i in range(2, int(math.sqrt(num)) + 1):
8
- if num % i == 0:
9
- return False
10
- return True
7
+
8
+ return all(num % i != 0 for i in range(2, int(math.sqrt(num)) + 1))
11
9
 
12
10
 
13
11
  def count_prime(start:int, stop:int):
@@ -15,4 +13,4 @@ def count_prime(start:int, stop:int):
15
13
  for i in range(start, stop):
16
14
  if is_prime(i):
17
15
  count += 1
18
- return count
16
+ return count
@@ -1,36 +1,61 @@
1
- from smartpool import ProcessPool, ThreadPool, DataSize, limit_num_single_thread
2
- limit_num_single_thread()
1
+ if __name__ == "__main__":
2
+ import os
3
+ import sys
4
+ self_folder = os.path.dirname(os.path.abspath(__file__)).replace("\\", "/")
5
+ target_folder = os.path.abspath(self_folder + "/../../../smartpool").replace("\\", "/")
6
+ sys.path.append(target_folder)
3
7
 
4
- import click
5
-
6
-
7
- @click.command(help=f"Use smartpool to do 5-fold cross validatation for 7 deep learning models for handwritten digit recognition task.")
8
- @click.option(
9
- '--pool', default='smartpool.ProcessPool', type=click.Choice([
10
- 'smartpool.ProcessPool',
11
- 'smartpool.ThreadPool',
12
- 'multiprocessing.Pool',
13
- 'concurrent.futures.ProcessPoolExecutor',
14
- 'concurrent.futures.ThreadPoolExecutor',
15
- "joblib.Parallel(backend='loky')",
16
- "joblib.Parallel(backend='threading')",
17
- "ray"
18
- ]),
19
- help="choose process pool implementations"
20
- )
21
- @click.option(
22
- '--max_workers', default=0, type=int,
23
- help='max number of workers to use, 0 to use all available cores'
8
+
9
+ from smartpool import (
10
+ DataSize,
11
+ ProcessPool,
12
+ Resource,
13
+ ThreadPool,
14
+ limit_num_single_thread,
24
15
  )
25
- def main(pool:str="smart", max_workers:int=0):
16
+
17
+ limit_num_single_thread()
18
+
19
+ from typing import Literal, TypeAlias
20
+
21
+ import typer
22
+
23
+ app = typer.Typer(help="Use smartpool to do 5-fold cross validatation for 7 deep learning models for handwritten digit recognition task.")
24
+
25
+ PoolChoice: TypeAlias = Literal[
26
+ "smartpool.ProcessPool",
27
+ "smartpool.ThreadPool",
28
+ "multiprocessing.Pool",
29
+ "concurrent.futures.ProcessPoolExecutor",
30
+ "concurrent.futures.ThreadPoolExecutor",
31
+ "joblib.Parallel(backend='loky')",
32
+ "joblib.Parallel(backend='threading')",
33
+ "ray"
34
+ ]
35
+
36
+
37
+ @app.command()
38
+ def main(
39
+ pool: PoolChoice = typer.Option(
40
+ "smartpool.ProcessPool",
41
+ "--pool",
42
+ help="choose process pool implementations"
43
+ ),
44
+ max_workers: int = typer.Option(
45
+ 0,
46
+ "--max_workers",
47
+ help="max number of workers to use, 0 to use all available cores"
48
+ )
49
+ ):
26
50
  import os
51
+ import importlib
27
52
  os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
28
-
53
+
29
54
  print(f"Use {pool} to do 5-fold cross validatation for 7 deep learning models for handwritten digit recognition task.")
30
55
  print("Use `python -m smartpool_examples.cross_validation --help` to see all options.")
31
56
  print(f"See source code at folder {os.path.dirname(os.path.abspath(__file__))}")
32
57
  print("\npreparing data...")
33
-
58
+
34
59
  try:
35
60
  import torch
36
61
  import torch.nn as nn
@@ -38,33 +63,31 @@ def main(pool:str="smart", max_workers:int=0):
38
63
  print("PyTorch is not installed. Follow https://pytorch.org/ instructions to install PyTorch.")
39
64
  exit(1)
40
65
 
41
- try:
42
- import torchvision
43
- except ImportError:
66
+ if importlib.util.find_spec("torchvision") is None:
44
67
  print("torchvision is not installed. Use `pip install torchvision` to install torchvision.")
45
68
  exit(1)
46
69
 
47
- import time
48
- from sklearn.model_selection import KFold
49
- import numpy as np
50
- from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
51
70
  import multiprocessing as mp
71
+ import os
52
72
  import queue
73
+ import sys
74
+ import time
53
75
  from collections import defaultdict
54
76
  from concurrent.futures import Future
55
77
  from typing import Dict, Union
56
78
 
57
- import os
58
- import sys
79
+ import numpy as np
80
+ from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
81
+ from sklearn.model_selection import KFold
59
82
 
60
83
  self_folder = os.path.dirname(os.path.abspath(__file__)).replace("\\", "/")
61
84
  sys.path.append(self_folder)
62
85
 
63
86
  import models
87
+ from config import EPOCHS
64
88
  from data_utils import prepare_data
65
- from model_utils import train_single_fold, ErrorInfo, ProgressInfo, TrainingResult
89
+ from model_utils import ErrorInfo, ProgressInfo, TrainingResult, train_single_fold
66
90
  from visualization import plot_results, print_results_table
67
- from config import EPOCHS
68
91
 
69
92
  if max_workers == 0:
70
93
  max_workers = os.cpu_count()
@@ -73,10 +96,10 @@ def main(pool:str="smart", max_workers:int=0):
73
96
  cls for cls in models.__dict__.values()
74
97
  if isinstance(cls, type) and issubclass(cls, nn.Module) and cls != nn.Module
75
98
  ]
76
-
99
+
77
100
  dataset = prepare_data()
78
101
  kfold = KFold(n_splits=5, shuffle=True, random_state=42)
79
-
102
+
80
103
  manager = mp.Manager()
81
104
 
82
105
  if pool != "ray":
@@ -88,17 +111,17 @@ def main(pool:str="smart", max_workers:int=0):
88
111
  except ImportError:
89
112
  print("Ray is not installed. Use `pip install ray` to install Ray.")
90
113
  exit(1)
91
-
114
+
92
115
  progress_queue:queue.Queue[Union[ProgressInfo, ErrorInfo]] = ray.util.queue.Queue()
93
116
 
94
117
  tasks = []
95
118
  for fold_idx, (train_indices, val_indices) in enumerate(kfold.split(dataset)):
96
119
  for model_class in model_classes:
97
120
  tasks.append((fold_idx, model_class, train_indices.copy(), val_indices.copy(), dataset, progress_queue))
98
-
121
+
99
122
  task_progress_bars = {}
100
123
  best_device = 'cuda' if torch.cuda.is_available() else 'cpu'
101
-
124
+
102
125
  start_time = time.perf_counter()
103
126
  with Progress(
104
127
  TextColumn("[progress.description]{task.description}"),
@@ -106,7 +129,7 @@ def main(pool:str="smart", max_workers:int=0):
106
129
  TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
107
130
  TimeRemainingColumn()
108
131
  ) as progress:
109
-
132
+
110
133
  active_tasks = {}
111
134
 
112
135
  if pool == "smartpool.ProcessPool":
@@ -128,7 +151,7 @@ def main(pool:str="smart", max_workers:int=0):
128
151
  elif pool == "joblib.Parallel(backend='threading')":
129
152
  from joblib import Parallel, delayed
130
153
  process_pool = Parallel(n_jobs=max_workers, backend='threading', return_as="generator")
131
-
154
+
132
155
  print("submitting training tasks...")
133
156
  futures_map:Dict[str, Future] = {}
134
157
  futures = []
@@ -137,10 +160,16 @@ def main(pool:str="smart", max_workers:int=0):
137
160
  future = process_pool.submit(
138
161
  train_single_fold,
139
162
  args=task_args,
140
- need_cpu_cores=1,
141
- need_cpu_mem=1.1*DataSize.GB,
142
- need_gpu_cores=1000,
143
- need_gpu_mem=0.2*DataSize.GB
163
+ cpu_mode_res=Resource(
164
+ cpu_cores=1,
165
+ cpu_mem=1.1*DataSize.GB
166
+ ),
167
+ gpu_mode_res=Resource(
168
+ cpu_cores=1,
169
+ cpu_mem=1.1*DataSize.GB,
170
+ gpu_cores=1000,
171
+ gpu_mem=0.2*DataSize.GB
172
+ )
144
173
  )
145
174
  elif pool.startswith("concurrent.futures."):
146
175
  future = process_pool.submit(train_single_fold, *task_args, best_device if i % max_workers < 5 else 'cpu')
@@ -150,14 +179,14 @@ def main(pool:str="smart", max_workers:int=0):
150
179
  future = delayed(train_single_fold)(*task_args, best_device if i % max_workers < 5 else 'cpu')
151
180
  elif pool == "ray":
152
181
  future = ray.remote(num_cpus=1, num_gpus=(0.2 if i % max_workers < 5 else 0), memory=1.1*DataSize.GB)(train_single_fold).remote(*task_args, best_device if i % max_workers < 5 else 'cpu')
153
-
182
+
154
183
  fold_idx = task_args[0]
155
184
  model_class = task_args[1]
156
185
  model_name = model_class.__name__
157
186
  task_key = f"{model_name}_fold_{fold_idx}"
158
187
  futures_map[task_key] = future
159
188
  futures.append(future)
160
-
189
+
161
190
  print(f"training all models in {pool} ...")
162
191
  if pool.startswith("joblib"):
163
192
  joblib_results = process_pool(futures)
@@ -170,31 +199,31 @@ def main(pool:str="smart", max_workers:int=0):
170
199
  break
171
200
 
172
201
  task_key = f"{progress_info.model_name}_fold_{progress_info.fold_idx}"
173
-
202
+
174
203
  if task_key not in task_progress_bars:
175
204
  initial_desc = f"train {progress_info.model_name} on {progress_info.device} "
176
205
  initial_desc += f"for fold {progress_info.fold_idx+1}/5"
177
206
  task_progress_bars[task_key] = progress.add_task(initial_desc, total=100)
178
207
  active_tasks[task_key] = True
179
-
208
+
180
209
  if task_key in task_progress_bars:
181
210
  epoch_progress = (progress_info.epoch - 1) / 5
182
211
  batch_progress = progress_info.batch / progress_info.total_batches
183
212
  total_progress = (epoch_progress + batch_progress / 5) * 100
184
-
213
+
185
214
  if progress_info.epoch == 5 and progress_info.batch == progress_info.total_batches:
186
215
  total_progress = 100.0
187
216
  finished_tasks.add(task_key)
188
-
217
+
189
218
  new_desc = f"train {progress_info.model_name} on {progress_info.device} "
190
219
  new_desc += f"for fold {progress_info.fold_idx+1}/5 - Epoch {progress_info.epoch}/{EPOCHS} "
191
220
  new_desc += f"Loss: {progress_info.avg_loss:.4f} "
192
221
  new_desc += f"Val Acc: {progress_info.val_accuracy*100:.2f}%"
193
222
  if progress_info.device.startswith("cuda"):
194
223
  new_desc = "[bright_cyan]" + new_desc
195
-
224
+
196
225
  progress.update(
197
- task_progress_bars[task_key],
226
+ task_progress_bars[task_key],
198
227
  completed=total_progress,
199
228
  description=new_desc
200
229
  )
@@ -206,7 +235,7 @@ def main(pool:str="smart", max_workers:int=0):
206
235
 
207
236
  model_results = defaultdict(list)
208
237
  if pool in ["smartpool.ProcessPool", "smartpool.ThreadPool", "concurrent.futures.ProcessPoolExecutor", "concurrent.futures.ThreadPoolExecutor", "multiprocessing.Pool"]:
209
- for task_key, future in futures_map.items():
238
+ for future in futures_map.values():
210
239
  if pool == "multiprocessing.Pool":
211
240
  result:TrainingResult = future.get()
212
241
  else:
@@ -220,7 +249,7 @@ def main(pool:str="smart", max_workers:int=0):
220
249
  ray_results = ray.get(futures)
221
250
  for result in ray_results:
222
251
  model_results[result.model_name].append(result.val_accuracy)
223
-
252
+
224
253
  stop_time = time.perf_counter()
225
254
  print(f"train completed in {stop_time - start_time:.2f} seconds")
226
255
 
@@ -235,10 +264,10 @@ def main(pool:str="smart", max_workers:int=0):
235
264
  'max': np.max(accuracies),
236
265
  'accuracies': accuracies
237
266
  }
238
-
267
+
239
268
  print_results_table(stats)
240
269
  plot_results(model_results, stats)
241
270
 
242
271
 
243
272
  if __name__ == "__main__":
244
- main()
273
+ app()
@@ -1,11 +1,12 @@
1
+ import os
2
+
3
+ self_folder = os.path.dirname(os.path.abspath(__file__))
4
+
1
5
  # Training parameters
2
6
  BATCH_SIZE = 128
3
7
  EPOCHS = 5
4
8
  LEARNING_RATE = 0.001
5
9
 
6
10
  # Data settings
7
- import os
8
- self_folder = os.path.dirname(os.path.abspath(__file__))
9
-
10
11
  DATA_ROOT = f'{self_folder}/data'
11
12
  DATASET_NAME = 'MNIST'
@@ -1,25 +1,25 @@
1
1
  import os
2
+
3
+ from config import BATCH_SIZE, DATA_ROOT
2
4
  from torch.utils.data import DataLoader, Subset
3
5
  from torchvision import datasets, transforms
4
6
 
5
- from config import DATA_ROOT, BATCH_SIZE
6
-
7
7
 
8
8
  def prepare_data():
9
9
  transform = transforms.Compose([
10
10
  transforms.ToTensor(),
11
11
  transforms.Normalize((0.1307,), (0.3081,))
12
12
  ])
13
-
13
+
14
14
  mnist_exists = (
15
15
  os.path.exists(os.path.join(DATA_ROOT, 'MNIST', 'raw')) and
16
16
  os.path.exists(os.path.join(DATA_ROOT, 'MNIST', 'processed'))
17
17
  )
18
-
18
+
19
19
  dataset = datasets.MNIST(
20
- root=DATA_ROOT,
21
- train=True,
22
- download=not mnist_exists,
20
+ root=DATA_ROOT,
21
+ train=True,
22
+ download=not mnist_exists,
23
23
  transform=transform
24
24
  )
25
25
  dataset.data.share_memory_()
@@ -30,18 +30,18 @@ def prepare_data():
30
30
  def create_data_loaders(dataset, train_indices, val_indices):
31
31
  train_subset = Subset(dataset, train_indices)
32
32
  val_subset = Subset(dataset, val_indices)
33
-
33
+
34
34
  train_loader = DataLoader(
35
- train_subset,
36
- batch_size=BATCH_SIZE,
37
- shuffle=True,
35
+ train_subset,
36
+ batch_size=BATCH_SIZE,
37
+ shuffle=True,
38
38
  pin_memory=True
39
39
  )
40
40
  val_loader = DataLoader(
41
- val_subset,
42
- batch_size=BATCH_SIZE,
43
- shuffle=False,
41
+ val_subset,
42
+ batch_size=BATCH_SIZE,
43
+ shuffle=False,
44
44
  pin_memory=True
45
45
  )
46
-
47
- return train_loader, val_loader
46
+
47
+ return train_loader, val_loader
@@ -1,14 +1,13 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.optim as optim
4
-
5
1
  import traceback
6
2
  from dataclasses import dataclass
7
3
 
8
- from config import LEARNING_RATE, EPOCHS
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ from config import EPOCHS, LEARNING_RATE
9
8
  from data_utils import create_data_loaders
10
9
 
11
- from smartpool import move_optimizer_to, best_device
10
+ from smartpool import best_device, move_optimizer_to
12
11
 
13
12
 
14
13
  @dataclass
@@ -47,17 +46,17 @@ def _train_single_fold(fold_idx, model_class, train_indices, val_indices, datase
47
46
  train_loader, val_loader = create_data_loaders(dataset, train_indices, val_indices)
48
47
  num_batches = len(train_loader)
49
48
  model = model_class()
50
-
49
+
51
50
  device = user_device
52
51
  if user_device is None:
53
52
  device = best_device()
54
53
 
55
54
  old_device = device
56
55
  model.to(device, non_blocking=True)
57
-
56
+
58
57
  criterion = nn.CrossEntropyLoss()
59
58
  optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
60
-
59
+
61
60
  initial_progress = ProgressInfo(
62
61
  model_name=model_class.__name__,
63
62
  fold_idx=fold_idx,
@@ -71,7 +70,7 @@ def _train_single_fold(fold_idx, model_class, train_indices, val_indices, datase
71
70
  progress_queue.put(initial_progress)
72
71
 
73
72
  last_val_accuracy = 0.0
74
-
73
+
75
74
  for epoch in range(EPOCHS):
76
75
  epoch_loss = 0.0
77
76
  model.train()
@@ -81,20 +80,20 @@ def _train_single_fold(fold_idx, model_class, train_indices, val_indices, datase
81
80
 
82
81
  data = data.to(device, non_blocking=True)
83
82
  target = target.to(device, non_blocking=True)
84
-
83
+
85
84
  if device != old_device:
86
85
  model.to(device, non_blocking=True)
87
86
  move_optimizer_to(optimizer, device)
88
87
  old_device = device
89
-
88
+
90
89
  optimizer.zero_grad()
91
90
  output = model(data)
92
91
  loss = criterion(output, target)
93
92
  loss.backward()
94
93
  optimizer.step()
95
-
94
+
96
95
  epoch_loss += loss.item()
97
-
96
+
98
97
  progress_info = ProgressInfo(
99
98
  model_name=model_class.__name__,
100
99
  fold_idx=fold_idx,
@@ -106,12 +105,12 @@ def _train_single_fold(fold_idx, model_class, train_indices, val_indices, datase
106
105
  val_accuracy=last_val_accuracy
107
106
  )
108
107
  progress_queue.put(progress_info)
109
-
108
+
110
109
  model.eval()
111
110
  correct = 0
112
111
  total = 0
113
112
  val_accuracy = 0.0
114
-
113
+
115
114
  with torch.no_grad():
116
115
  for data, target in val_loader:
117
116
  if user_device is None:
@@ -119,21 +118,21 @@ def _train_single_fold(fold_idx, model_class, train_indices, val_indices, datase
119
118
 
120
119
  data = data.to(device, non_blocking=True)
121
120
  target = target.to(device, non_blocking=True)
122
-
121
+
123
122
  if device != old_device:
124
123
  model.to(device, non_blocking=True)
125
124
  move_optimizer_to(optimizer, device)
126
125
  old_device = device
127
-
126
+
128
127
  output = model(data)
129
128
  pred = output.argmax(dim=1, keepdim=True)
130
129
  correct += pred.eq(target.view_as(pred)).sum().item()
131
130
  total += target.size(0)
132
-
131
+
133
132
  val_accuracy = correct / total
134
133
  last_val_accuracy = val_accuracy
135
134
  model.train()
136
-
135
+
137
136
  final_progress = ProgressInfo(
138
137
  model_name=model_class.__name__,
139
138
  fold_idx=fold_idx,
@@ -145,11 +144,11 @@ def _train_single_fold(fold_idx, model_class, train_indices, val_indices, datase
145
144
  val_accuracy=val_accuracy
146
145
  )
147
146
  progress_queue.put(final_progress)
148
-
147
+
149
148
  model.eval()
150
149
  correct = 0
151
150
  total = 0
152
-
151
+
153
152
  with torch.no_grad():
154
153
  for data, target in val_loader:
155
154
  if user_device is None:
@@ -167,7 +166,7 @@ def _train_single_fold(fold_idx, model_class, train_indices, val_indices, datase
167
166
  pred = output.argmax(dim=1, keepdim=True)
168
167
  correct += pred.eq(target.view_as(pred)).sum().item()
169
168
  total += target.size(0)
170
-
169
+
171
170
  val_accuracy = correct / total
172
-
171
+
173
172
  return TrainingResult(fold_idx, model_class.__name__, val_accuracy)