modelit 0.2.5__tar.gz → 0.2.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. {modelit-0.2.5 → modelit-0.2.6}/CONTRIBUTING.md +3 -0
  2. {modelit-0.2.5 → modelit-0.2.6}/PKG-INFO +11 -1
  3. {modelit-0.2.5 → modelit-0.2.6}/README.md +10 -0
  4. {modelit-0.2.5 → modelit-0.2.6}/modelit/cli.py +1 -1
  5. modelit-0.2.6/modelit/registry.py +92 -0
  6. modelit-0.2.6/modelit/templates/advanced_exceptions/template.java +42 -0
  7. modelit-0.2.6/modelit/templates/assertions/template.java +19 -0
  8. modelit-0.2.6/modelit/templates/diffusion/template.py +231 -0
  9. modelit-0.2.6/modelit/templates/gan/template.py +233 -0
  10. modelit-0.2.6/modelit/templates/java_default_methods/template.java +38 -0
  11. modelit-0.2.6/modelit/templates/jca_digital_signature/template.java +32 -0
  12. modelit-0.2.6/modelit/templates/jca_encryption/template.java +33 -0
  13. modelit-0.2.6/modelit/templates/jca_hashing/template.java +26 -0
  14. modelit-0.2.6/modelit/templates/jca_hmac/template.java +31 -0
  15. modelit-0.2.6/modelit/templates/jca_mac/template.java +30 -0
  16. modelit-0.2.6/modelit/templates/jca_rsa/template.java +31 -0
  17. modelit-0.2.6/modelit/templates/jca_secure_random/template.java +19 -0
  18. modelit-0.2.6/modelit/templates/lstm/template.py +390 -0
  19. modelit-0.2.6/modelit/templates/pca_ae/template.py +226 -0
  20. modelit-0.2.6/modelit/templates/rnn/template.py +392 -0
  21. modelit-0.2.6/modelit/templates/sofm/template.py +90 -0
  22. modelit-0.2.6/modelit/templates/solid_notification/EmailSender.java +6 -0
  23. modelit-0.2.6/modelit/templates/solid_notification/Main.java +24 -0
  24. modelit-0.2.6/modelit/templates/solid_notification/NotificationSender.java +3 -0
  25. modelit-0.2.6/modelit/templates/solid_notification/NotificationService.java +11 -0
  26. modelit-0.2.6/modelit/templates/solid_notification/PushSender.java +6 -0
  27. modelit-0.2.6/modelit/templates/solid_notification/SlackSender.java +6 -0
  28. modelit-0.2.6/modelit/templates/solid_notification/SmsSender.java +6 -0
  29. modelit-0.2.6/modelit/templates/solid_shapes/AreaCalculator.java +14 -0
  30. modelit-0.2.6/modelit/templates/solid_shapes/Circle.java +12 -0
  31. modelit-0.2.6/modelit/templates/solid_shapes/Main.java +20 -0
  32. modelit-0.2.6/modelit/templates/solid_shapes/Pentagon.java +14 -0
  33. modelit-0.2.6/modelit/templates/solid_shapes/Rectangle.java +14 -0
  34. modelit-0.2.6/modelit/templates/solid_shapes/Shape.java +3 -0
  35. modelit-0.2.6/modelit/templates/solid_shapes/Triangle.java +14 -0
  36. modelit-0.2.6/modelit/templates/streams_employee/Employee.java +23 -0
  37. modelit-0.2.6/modelit/templates/streams_employee/Main.java +57 -0
  38. modelit-0.2.6/modelit/templates/streams_word_freq/template.java +59 -0
  39. modelit-0.2.6/modelit/templates/thread_deadlock/template.java +40 -0
  40. modelit-0.2.6/modelit/templates/thread_deadlock_solved/template.java +54 -0
  41. modelit-0.2.6/modelit/templates/vae/template.py +322 -0
  42. {modelit-0.2.5 → modelit-0.2.6}/modelit.egg-info/PKG-INFO +11 -1
  43. modelit-0.2.6/modelit.egg-info/SOURCES.txt +75 -0
  44. modelit-0.2.5/modelit/registry.py +0 -94
  45. modelit-0.2.5/modelit.egg-info/SOURCES.txt +0 -39
  46. {modelit-0.2.5 → modelit-0.2.6}/.github/workflows/publish.yml +0 -0
  47. {modelit-0.2.5 → modelit-0.2.6}/.github/workflows/test.yml +0 -0
  48. {modelit-0.2.5 → modelit-0.2.6}/.gitignore +0 -0
  49. {modelit-0.2.5 → modelit-0.2.6}/LICENSE +0 -0
  50. {modelit-0.2.5 → modelit-0.2.6}/MANIFEST.in +0 -0
  51. {modelit-0.2.5 → modelit-0.2.6}/modelit/__init__.py +0 -0
  52. {modelit-0.2.5 → modelit-0.2.6}/modelit/__main__.py +0 -0
  53. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/abstract_factory_pattern/template.java +0 -0
  54. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/adapter_pattern/template.java +0 -0
  55. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/backpropagation/template.py +0 -0
  56. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/bridge_pattern/template.java +0 -0
  57. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/builder_pattern/template.java +0 -0
  58. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/cnn/template.py +0 -0
  59. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/command_pattern/template.java +0 -0
  60. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/composite_pattern/template.java +0 -0
  61. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/decorator_pattern/template.java +0 -0
  62. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/facade_pattern/template.java +0 -0
  63. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/factory_pattern/template.java +0 -0
  64. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/flyweight_pattern/template.java +0 -0
  65. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/interpreter_pattern/template.java +0 -0
  66. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/mediator_pattern/template.java +0 -0
  67. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/memento_pattern/template.java +0 -0
  68. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/observer_pattern/template.java +0 -0
  69. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/perceptron/template.py +0 -0
  70. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/prototype_pattern/template.java +0 -0
  71. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/proxy_pattern/template.java +0 -0
  72. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/state_pattern/template.java +0 -0
  73. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/template_pattern/template.java +0 -0
  74. {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/visitor_pattern/template.java +0 -0
  75. {modelit-0.2.5 → modelit-0.2.6}/modelit.egg-info/dependency_links.txt +0 -0
  76. {modelit-0.2.5 → modelit-0.2.6}/modelit.egg-info/entry_points.txt +0 -0
  77. {modelit-0.2.5 → modelit-0.2.6}/modelit.egg-info/top_level.txt +0 -0
  78. {modelit-0.2.5 → modelit-0.2.6}/pyproject.toml +0 -0
  79. {modelit-0.2.5 → modelit-0.2.6}/setup.cfg +0 -0
@@ -17,6 +17,8 @@ modelit/templates/<name>/
17
17
  modelit/templates/<name>/template.py
18
18
  ```
19
19
 
20
+ For multi-file templates, add more files in the same folder.
21
+
20
22
  4. Make sure the folder name is the function name.
21
23
  5. Test it locally:
22
24
 
@@ -39,6 +41,7 @@ modelit create <name>
39
41
  - Do not add `metadata.json`.
40
42
  - Keep file names lowercase.
41
43
  - Prefer clean, beginner-friendly code.
44
+ - If a template has multiple files, pass a directory path to `--output`; each file is written inside that directory.
42
45
 
43
46
  ## Pull request flow
44
47
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: modelit
3
- Version: 0.2.5
3
+ Version: 0.2.6
4
4
  Summary: Local-first ML starter templates you can print or save
5
5
  Author: Yashsmith
6
6
  License-Expression: MIT
@@ -28,6 +28,8 @@ ModelIt is a tiny Python package for storing your ML/DL boilerplate templates.
28
28
  - `perceptron()` prints the full code
29
29
  - `perceptron(output="code1.py")` saves it to a file
30
30
  - `modelit create perceptron` works from the terminal
31
+ - single-file templates write to a file
32
+ - multi-file templates write to a folder
31
33
 
32
34
  ## Install
33
35
 
@@ -94,6 +96,8 @@ modelit/templates/mycode/template.java
94
96
 
95
97
  3. Done. The folder name becomes the function name.
96
98
 
99
+ If a template has more than one file, `--output` must be a directory path.
100
+
97
101
  That means this will work automatically:
98
102
 
99
103
  ```python
@@ -109,6 +113,12 @@ modelit create mycode
109
113
  modelit create mycode --output mycode.py
110
114
  ```
111
115
 
116
+ For multi-file templates:
117
+
118
+ ```bash
119
+ modelit create solid_notification --output out/solid_notification
120
+ ```
121
+
112
122
  ## Publish flow
113
123
 
114
124
  1. Add a new template folder.
@@ -8,6 +8,8 @@ ModelIt is a tiny Python package for storing your ML/DL boilerplate templates.
8
8
  - `perceptron()` prints the full code
9
9
  - `perceptron(output="code1.py")` saves it to a file
10
10
  - `modelit create perceptron` works from the terminal
11
+ - single-file templates write to a file
12
+ - multi-file templates write to a folder
11
13
 
12
14
  ## Install
13
15
 
@@ -74,6 +76,8 @@ modelit/templates/mycode/template.java
74
76
 
75
77
  3. Done. The folder name becomes the function name.
76
78
 
79
+ If a template has more than one file, `--output` must be a directory path.
80
+
77
81
  That means this will work automatically:
78
82
 
79
83
  ```python
@@ -89,6 +93,12 @@ modelit create mycode
89
93
  modelit create mycode --output mycode.py
90
94
  ```
91
95
 
96
+ For multi-file templates:
97
+
98
+ ```bash
99
+ modelit create solid_notification --output out/solid_notification
100
+ ```
101
+
92
102
  ## Publish flow
93
103
 
94
104
  1. Add a new template folder.
@@ -17,7 +17,7 @@ def main(argv: list[str] | None = None) -> None:
17
17
 
18
18
  create_parser = subparsers.add_parser("create", help="Print or save a template")
19
19
  create_parser.add_argument("name", choices=available_models(), help="Template name")
20
- create_parser.add_argument("-o", "--output", help="Write the template to a file instead of printing")
20
+ create_parser.add_argument("-o", "--output", help="Write to a file for single-file templates or a directory for multi-file templates")
21
21
  create_parser.set_defaults(func=_create)
22
22
 
23
23
  args = parser.parse_args(argv)
@@ -0,0 +1,92 @@
1
+ """Template discovery and loading."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from importlib.resources import files
7
+ from pathlib import Path
8
+
9
# Package that ships the templates, and the resource folder inside it that
# holds one sub-folder per template.
PACKAGE_NAME = "modelit"
TEMPLATES_DIR = "templates"


@dataclass(frozen=True)
class TemplateInfo:
    """Immutable record describing one bundled template.

    Only the template name is stored; it is the name of the template's folder
    and becomes the name of the generated public function.
    """

    name: str  # template folder name
15
+
16
def _templates_root():
    """Return the traversable ``templates`` resource directory of the package."""
    package_dir = files(PACKAGE_NAME)
    return package_dir.joinpath(TEMPLATES_DIR)
18
+
19
def _template_dir(name: str):
    """Return the resource directory holding the files of template *name*."""
    templates_root = _templates_root()
    return templates_root.joinpath(name)
21
+
22
def available_models() -> tuple[str, ...]:
    """Return the sorted names of every non-empty template folder.

    Folders whose names start with ``__`` (such as ``__pycache__``) are
    ignored. An empty tuple is returned when the templates directory is
    missing.
    """
    root = _templates_root()
    if not root.is_dir():
        return ()

    found = [
        entry.name
        for entry in root.iterdir()
        if entry.is_dir()
        and not entry.name.startswith("__")
        and any(entry.iterdir())
    ]
    return tuple(sorted(found))
33
+
34
def load_metadata(name: str) -> TemplateInfo:
    """Return the metadata record for template *name*.

    Templates ship no metadata file, so the record is built from the name alone.
    """
    metadata = TemplateInfo(name=name)
    return metadata
36
+
37
def load_template_files(name: str) -> dict[str, str]:
    """Read every file of template *name* into a ``{filename: text}`` mapping.

    Files are returned sorted by name. ``iterdir`` yields children in
    arbitrary order, and this mapping's order decides how multi-file
    templates are printed and written, so it must be stable.

    Two kinds of names are skipped:
    - names starting with ``__`` (e.g. ``__init__.py``);
    - hidden files (names starting with ``.``, e.g. ``.DS_Store``). A stray
      hidden file would otherwise turn a single-file template into a
      multi-file one.

    Raises:
        FileNotFoundError: if the template directory does not exist.
    """
    target_dir = _template_dir(name)
    if not target_dir.is_dir():
        raise FileNotFoundError(f"Missing template directory for {name!r}")

    file_contents: dict[str, str] = {}
    for child in sorted(target_dir.iterdir(), key=lambda entry: entry.name):
        if not child.is_file() or child.name.startswith(("__", ".")):
            continue
        file_contents[child.name] = child.read_text(encoding="utf-8")
    return file_contents
47
+
48
def _write_single_file(out_path: Path, content: str) -> None:
    """Write a single-file template to *out_path* without overwriting it.

    Missing parent directories are created.

    Raises:
        FileExistsError: if *out_path* already exists.
    """
    if out_path.exists():
        raise FileExistsError(f"Output path already exists: {out_path}")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(content, encoding="utf-8")
    print(f"Generated {out_path}")


def _write_multiple_files(out_path: Path, files_dict: dict[str, str]) -> None:
    """Write every file of a multi-file template into directory *out_path*.

    The directory and its parents are created if needed. Files that already
    exist inside it are left untouched and reported as skipped.

    Raises:
        ValueError: if *out_path* has a suffix (looks like a file name) and
            is not an existing directory.
        FileExistsError: if *out_path* exists but is not a directory.
    """
    # A suffix usually means a file name was passed by mistake. An existing
    # directory whose name contains a dot (e.g. ``out/v1.2``) is still a
    # legitimate target and must not be rejected.
    if out_path.suffix and not out_path.is_dir():
        raise ValueError("Multi-file templates require a directory output path")
    if out_path.exists() and not out_path.is_dir():
        raise FileExistsError(f"Output path already exists and is not a directory: {out_path}")
    out_path.mkdir(parents=True, exist_ok=True)
    for fname, content in files_dict.items():
        file_path = out_path / fname
        if file_path.exists():
            print(f"Skipping {file_path} (already exists)")
            continue
        file_path.write_text(content, encoding="utf-8")
        print(f"Generated {file_path}")


def _print_files(files_dict: dict[str, str], is_single_file: bool) -> None:
    """Print template files to stdout.

    A single-file template is printed as-is. Each file of a multi-file
    template is preceded by a ``File: <name>`` banner.
    """
    if is_single_file:
        content = next(iter(files_dict.values()))
        # End with a newline only when the template lacks one, so the shell
        # prompt is not glued to the last line of code.
        print(content, end="" if content.endswith("\n") else "\n")
        return
    for fname, content in files_dict.items():
        print(f"\n{'=' * 40}\nFile: {fname}\n{'=' * 40}")
        print(content, end="\n\n")


def build_template_callable(name: str):
    """Build the public function that prints or saves template *name*.

    The returned callable takes an optional ``output`` path:

    * No ``output`` (``None`` or an empty string): print the template.
    * Single-file template: write the file to ``output``. Raises
      ``FileExistsError`` if that path already exists.
    * Multi-file template: treat ``output`` as a directory, create it if
      needed, and write each file into it, skipping files that already
      exist. Raises ``ValueError`` when ``output`` looks like a file name
      (has a suffix and is not an existing directory). Raises
      ``FileExistsError`` when it is an existing non-directory.

    The callable is renamed to *name* and carries two attributes:
    ``output_file`` (the single file's name, or *name* for multi-file
    templates) and ``template_info``.
    """
    files_dict = load_template_files(name)
    info = load_metadata(name)

    is_single_file = len(files_dict) == 1
    default_filename = next(iter(files_dict)) if is_single_file else name

    def runner(output: str | None = None) -> None:
        if output:
            out_path = Path(output)
            if is_single_file:
                _write_single_file(out_path, next(iter(files_dict.values())))
            else:
                _write_multiple_files(out_path, files_dict)
            return None

        # No output path given: print to the terminal instead.
        _print_files(files_dict, is_single_file)
        return None

    runner.__name__ = name
    runner.__qualname__ = name
    runner.__module__ = "modelit"
    runner.__doc__ = f"Print or save the {name} template."
    runner.output_file = default_filename  # type: ignore[attr-defined]
    runner.template_info = info  # type: ignore[attr-defined]
    return runner
@@ -0,0 +1,42 @@
1
+ import java.io.BufferedReader;
2
+ import java.io.FileReader;
3
+ import java.io.IOException;
4
+
5
/**
 * Checked exception raised when a requested transfer is larger than the
 * available account balance.
 */
class InsufficientFundsException extends Exception {

    /**
     * @param detail explanation of the failure, returned by {@link #getMessage()}
     */
    public InsufficientFundsException(String detail) {
        super(detail);
    }
}
10
+
11
+ public class AdvancedExceptionDemo {
12
+
13
+ public static void processTransfer(int accountBalance, int transferAmount) throws InsufficientFundsException {
14
+ if (transferAmount > accountBalance) {
15
+ throw new InsufficientFundsException("Transfer failed: Attempted to send $" + transferAmount + " but balance is $" + accountBalance);
16
+ }
17
+ System.out.println("Transfer of $" + transferAmount + " completed successfully.");
18
+ }
19
+
20
+ public static void main(String[] args) {
21
+ try {
22
+ processTransfer(5000, 7500);
23
+ } catch (InsufficientFundsException e) {
24
+ System.err.println("Business Error: " + e.getMessage());
25
+ }
26
+
27
+ System.out.println("-----------------");
28
+
29
+ String filePath = "non_existent_transactions.csv";
30
+
31
+ try (BufferedReader reader = new BufferedReader(new FileReader(filePath))) {
32
+ System.out.println(reader.readLine());
33
+ int value = Integer.parseInt("Not_A_Number");
34
+
35
+ } catch (IOException | NumberFormatException e) {
36
+ System.err.println("System Failure: Could not process file or format was corrupted.");
37
+ System.err.println("Root Cause: " + e.getClass().getSimpleName() + " -> " + e.getMessage());
38
+ } finally {
39
+ System.out.println("Cleanup executed. System remains stable.");
40
+ }
41
+ }
42
+ }
@@ -0,0 +1,19 @@
1
class DiscountCalculator {

    /**
     * Returns {@code originalPrice} reduced by {@code discountPercentage} percent.
     * The non-negative result is checked with an {@code assert}, which only
     * fires when the JVM runs with assertions enabled ({@code -ea}).
     */
    public static double applyDiscount(double originalPrice, double discountPercentage) {
        double discountAmount = originalPrice * (discountPercentage / 100);
        double finalPrice = originalPrice - discountAmount;
        assert finalPrice >= 0 : "CRITICAL LOGIC ERROR: Final price fell below zero!";
        return finalPrice;
    }
}
11
+
12
+ public class AssertionDemo {
13
+ public static void main(String[] args) {
14
+ System.out.println("Final Price: $" + DiscountCalculator.applyDiscount(1000, 20));
15
+
16
+ System.out.println("Testing invalid logic...");
17
+ System.out.println("Final Price: $" + DiscountCalculator.applyDiscount(500, 150));
18
+ }
19
+ }
@@ -0,0 +1,231 @@
1
+ # =========================================================
2
+ # DENOISING DIFFUSION PROBABILISTIC MODEL (DDPM)
3
+ # Learning the y = -x Manifold via Reverse Diffusion
4
+ # Built using PyTorch
5
+ # =========================================================
6
+
7
+
8
+ # =========================================================
9
+ # IMPORTS
10
+ # =========================================================
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.optim as optim
16
+ import matplotlib.pyplot as plt
17
+
18
+
19
+ # =========================================================
20
+ # DATA GENERATION FUNCTION (REAL MANIFOLD)
21
+ # =========================================================
22
+
23
def sample_line_data(n_samples=512):
    """Draw ``n_samples`` points lying on the line y = -x.

    x is uniform on [-4.5, 4.5). The result is an ``(n_samples, 2)`` tensor
    whose columns are ``(x, -x)``.
    """
    xs = torch.empty(n_samples, 1).uniform_(-4.5, 4.5)
    return torch.cat((xs, xs.neg()), dim=1)
30
+
31
+
32
+ # =========================================================
33
+ # DIFFUSION HYPERPARAMETERS & SCHEDULING
34
+ # =========================================================
35
+
36
# Length of the forward (noising) chain.
steps = 120

# Linear beta schedule. beta_end is chosen so that the cumulative signal
# fraction alpha_cum[-1] ends near zero (~0.007). `generate` starts the
# reverse process from pure N(0, I) noise, so x_T must really be close to
# N(0, I). With the previous end value (0.015), alpha_cum[-1] was ~0.40,
# meaning x_T still carried sqrt(0.40) ~ 63% of the signal. Sampling then
# started from a distribution the model was never trained on.
beta_start = 2e-4
beta_end = 0.08

betas = torch.linspace(beta_start, beta_end, steps)

# Per-step signal retention and its running product (alpha-bar in DDPM notation).
alphas = 1.0 - betas

alpha_cum = torch.cumprod(alphas, dim=0)
43
+
44
+
45
+ # =========================================================
46
+ # FORWARD DIFFUSION PROCESS (NOISING)
47
+ # =========================================================
48
+
49
def diffuse(x0, t_idx):
    """Sample x_t from q(x_t | x_0) in closed form.

    Returns ``(x_t, noise)``, where ``noise`` is standard Gaussian with the
    shape of ``x0`` and
    ``x_t = sqrt(alpha_cum[t]) * x0 + sqrt(1 - alpha_cum[t]) * noise``.
    ``t_idx`` is reshaped to a column, i.e. one step index per row of ``x0``.
    """
    eps = torch.randn_like(x0)
    signal_frac = alpha_cum[t_idx].view(-1, 1)
    noisy = torch.sqrt(signal_frac) * x0 + torch.sqrt(1 - signal_frac) * eps
    return noisy, eps
58
+
59
+
60
+ # =========================================================
61
+ # NOISE ESTIMATOR NETWORK (MLP WITH TIME EMBEDDING)
62
+ # =========================================================
63
+
64
class NoiseEstimator(nn.Module):
    """MLP that predicts the noise added to ``x`` at diffusion step ``t``.

    The integer step is turned into a sinusoidal embedding of size
    ``time_dim`` and concatenated to ``x`` before the MLP.

    Args:
        dim: size of each data point (2 for the 2-D toy data).
        hidden: width of the two hidden layers.
        time_dim: size of the time embedding (half sines, half cosines), so
            it must be even. The default of 16 matches the previous fixed size.

    Raises:
        ValueError: if ``time_dim`` is odd.
    """

    def __init__(self, dim=2, hidden=128, time_dim=16):
        super().__init__()
        if time_dim % 2:
            raise ValueError("time_dim must be even")
        # Stored once so the MLP input width and the embedding width stay in sync.
        self.time_dim = time_dim
        self.mlp = nn.Sequential(
            nn.Linear(dim + time_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, dim),
        )

    def time_embed(self, t):
        """Return a ``(len(t), time_dim)`` sinusoidal embedding of steps ``t``.

        Step values are divided by frequencies spaced geometrically from e^0
        to e^3. The frequency tensor is created on ``t``'s device so the model
        also works after being moved to a GPU.
        """
        half = self.time_dim // 2
        freqs = torch.exp(torch.linspace(0, 3, half, device=t.device))
        t = t.float().unsqueeze(1)
        return torch.cat([torch.sin(t / freqs), torch.cos(t / freqs)], dim=1)

    def forward(self, x, t):
        """Predict the noise for batch ``x`` at per-row steps ``t``."""
        t_emb = self.time_embed(t)
        return self.mlp(torch.cat([x, t_emb], dim=1))
106
+
107
+
108
+ # =========================================================
109
+ # INITIALIZATION & CONFIGURATION
110
+ # =========================================================
111
+
112
# Noise predictor, optimiser and objective. The network regresses the
# Gaussian noise injected by `diffuse`, scored with an L1 (mean absolute
# error) loss.
model = NoiseEstimator()
opt = optim.Adam(model.parameters(), lr=2e-3)
loss_fn = nn.L1Loss()

epochs = 1800  # optimisation steps; each draws a fresh batch
batch = 128    # samples per step


# =========================================================
# DIFFUSION MODEL TRAINING LOOP
# =========================================================

print("\n=================================================")
print("TRAINING DENOISING DIFFUSION MODEL")
print("=================================================")

for ep in range(epochs):
    # Clean points, one random diffusion step per point, then corrupt them.
    x0 = sample_line_data(batch)
    t = torch.randint(0, steps, (batch,))
    xt, eps = diffuse(x0, t)

    opt.zero_grad()
    eps_hat = model(xt, t)
    loss = loss_fn(eps_hat, eps)
    loss.backward()
    opt.step()

    # Progress report every 300 steps.
    if not ep % 300:
        print(f"Epoch {ep:04d} | L1 Loss: {loss.item():.5f}")


print("\nDiffusion Training Complete.")
160
+
161
+
162
+ # =========================================================
163
+ # REVERSE DIFFUSION PROCESS (GENERATION)
164
+ # =========================================================
165
+
166
@torch.no_grad()
def generate(n=800):
    """Run the reverse (denoising) chain and return ``n`` generated 2-D points.

    Starts from standard Gaussian noise. For t = steps-1 down to 0 it applies
        x <- (1 / sqrt(a_t)) * (x - (1 - a_t) / sqrt(1 - a_bar_t) * eps_hat)
             + sqrt(beta_t) * z
    where ``eps_hat`` is the model's noise prediction. Fresh noise ``z`` is
    dropped on the final step. Gradients are disabled throughout.
    """
    x = torch.randn(n, 2)

    for step in range(steps - 1, -1, -1):
        eps_hat = model(x, torch.full((n,), step))

        a_t = alphas[step]
        a_bar_t = alpha_cum[step]
        beta_t = betas[step]

        coef1 = 1 / torch.sqrt(a_t)
        coef2 = (1 - a_t) / torch.sqrt(1 - a_bar_t)
        z = torch.randn_like(x) if step > 0 else 0

        x = coef1 * (x - coef2 * eps_hat) + torch.sqrt(beta_t) * z

    return x
192
+
193
+
194
+ # =========================================================
195
+ # GENERATING SAMPLES FOR EVALUATION
196
+ # =========================================================
197
+
198
# Reference points from the true line and an equal number of model samples.
real = sample_line_data(1000)
fake = generate(1000)


# =========================================================
# VISUALIZATION
# =========================================================

plt.figure(figsize=(6, 6))
for points, label in ((real, "Real data"), (fake, "Generated")):
    plt.scatter(points[:, 0], points[:, 1], s=5, alpha=0.4, label=label)
plt.legend()
plt.title("Diffusion learns y = -x manifold")
plt.axis("equal")
plt.show()
@@ -0,0 +1,233 @@
1
+ # =========================================================
2
+ # GENERATIVE ADVERSARIAL NETWORK (GAN) FOR 2D DISTRIBUTION
3
+ # Learning the y = -x Manifold from Scratch
4
+ # Built using PyTorch
5
+ # =========================================================
6
+
7
+
8
+ # =========================================================
9
+ # IMPORTS
10
+ # =========================================================
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.optim as optim
15
+ import matplotlib.pyplot as plt
16
+
17
+
18
+ # =========================================================
19
+ # HYPERPARAMETERS & DEVICE CONFIGURATION
20
+ # =========================================================
21
+
22
# Training hyperparameters.
batch_size = 128   # samples per discriminator/generator step
noise_dim = 2      # size of the generator's latent input
lr = 2e-4          # Adam learning rate for both networks
epochs = 2000      # number of alternating D/G updates

# Use a CUDA device when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
31
+
32
+
33
+ # =========================================================
34
+ # DATA GENERATION FUNCTION (REAL DISTRIBUTION)
35
+ # =========================================================
36
+
37
def get_real_samples(batch_size):
    """Return ``batch_size`` points on the line y = -x, moved to ``device``.

    x is drawn from a standard normal, so the real data is a Gaussian spread
    along the line. The result has shape ``(batch_size, 2)``.
    """
    xs = torch.randn(batch_size, 1)
    points = torch.cat((xs, -xs), dim=1)
    return points.to(device)
46
+
47
+
48
+ # =========================================================
49
+ # GENERATOR ARCHITECTURE
50
+ # =========================================================
51
+
52
class Generator(nn.Module):
    """Maps ``noise_dim``-dimensional latent vectors to 2-D points.

    A small MLP (noise_dim -> 16 -> 16 -> 2) with ReLU activations and a
    linear output layer, so generated coordinates are unbounded.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(noise_dim, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
            nn.Linear(16, 2),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Produce one 2-D point per row of the latent batch ``z``."""
        return self.model(z)
76
+
77
+
78
+ # =========================================================
79
+ # DISCRIMINATOR ARCHITECTURE
80
+ # =========================================================
81
+
82
class Discriminator(nn.Module):
    """Scores 2-D points with the probability that they are real.

    A small MLP (2 -> 16 -> 16 -> 1) with ReLU activations and a final
    Sigmoid, so each output lies in (0, 1) and pairs with ``nn.BCELoss``.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return one real-vs-fake probability per row of ``x``."""
        return self.model(x)
108
+
109
+
110
+ # =========================================================
111
+ # MODEL INITIALIZATION & OPTIMIZERS
112
+ # =========================================================
113
+
114
# Networks, loss and optimisers. D ends in a Sigmoid, so plain binary
# cross-entropy scores both players.
G = Generator().to(device)
D = Discriminator().to(device)
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=lr)
optimizer_D = optim.Adam(D.parameters(), lr=lr)


# =========================================================
# GAN TRAINING LOOP
# =========================================================

print("\n=================================================")
print("TRAINING GENERATIVE ADVERSARIAL NETWORK")
print("=================================================")

for epoch in range(epochs):
    # --- Discriminator step: push D(real) toward 1 and D(G(z)) toward 0. ---
    real_data = get_real_samples(batch_size)
    real_labels = torch.ones(batch_size, 1).to(device)
    noise = torch.randn(batch_size, noise_dim).to(device)
    fake_data = G(noise)
    fake_labels = torch.zeros(batch_size, 1).to(device)

    # detach() stops this step from back-propagating into the generator.
    real_preds = D(real_data)
    fake_preds = D(fake_data.detach())
    loss_D = criterion(real_preds, real_labels) + criterion(fake_preds, fake_labels)

    optimizer_D.zero_grad()
    loss_D.backward()
    optimizer_D.step()

    # --- Generator step: fresh samples scored against "real" labels. ---
    noise = torch.randn(batch_size, noise_dim).to(device)
    fake_preds = D(G(noise))
    loss_G = criterion(fake_preds, real_labels)

    optimizer_G.zero_grad()
    loss_G.backward()
    optimizer_G.step()

    # Progress report every 200 updates.
    if not epoch % 200:
        print(f"Epoch {epoch:04d} | D Loss: {loss_D.item():.4f} | G Loss: {loss_G.item():.4f}")


print("\nGAN Training Complete.")
196
+
197
+
198
+ # =========================================================
199
+ # EVALUATION & IMAGE GENERATION
200
+ # =========================================================
201
+
202
# Draw 1000 generator samples and 1000 real points, detached from the graph
# and moved to the CPU for plotting.
noise = torch.randn(1000, noise_dim).to(device)
generated = G(noise).detach().cpu()
real = get_real_samples(1000).cpu()


# =========================================================
# VISUALIZATION
# =========================================================

plt.figure(figsize=(6, 6))
for points, label in ((real, "Real (y=-x)"), (generated, "Generated")):
    plt.scatter(points[:, 0], points[:, 1], label=label, alpha=0.5)
plt.legend()
plt.title("GAN Learning y = -x Distribution")
plt.show()