modelit 0.2.5__tar.gz → 0.2.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {modelit-0.2.5 → modelit-0.2.6}/CONTRIBUTING.md +3 -0
- {modelit-0.2.5 → modelit-0.2.6}/PKG-INFO +11 -1
- {modelit-0.2.5 → modelit-0.2.6}/README.md +10 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/cli.py +1 -1
- modelit-0.2.6/modelit/registry.py +92 -0
- modelit-0.2.6/modelit/templates/advanced_exceptions/template.java +42 -0
- modelit-0.2.6/modelit/templates/assertions/template.java +19 -0
- modelit-0.2.6/modelit/templates/diffusion/template.py +231 -0
- modelit-0.2.6/modelit/templates/gan/template.py +233 -0
- modelit-0.2.6/modelit/templates/java_default_methods/template.java +38 -0
- modelit-0.2.6/modelit/templates/jca_digital_signature/template.java +32 -0
- modelit-0.2.6/modelit/templates/jca_encryption/template.java +33 -0
- modelit-0.2.6/modelit/templates/jca_hashing/template.java +26 -0
- modelit-0.2.6/modelit/templates/jca_hmac/template.java +31 -0
- modelit-0.2.6/modelit/templates/jca_mac/template.java +30 -0
- modelit-0.2.6/modelit/templates/jca_rsa/template.java +31 -0
- modelit-0.2.6/modelit/templates/jca_secure_random/template.java +19 -0
- modelit-0.2.6/modelit/templates/lstm/template.py +390 -0
- modelit-0.2.6/modelit/templates/pca_ae/template.py +226 -0
- modelit-0.2.6/modelit/templates/rnn/template.py +392 -0
- modelit-0.2.6/modelit/templates/sofm/template.py +90 -0
- modelit-0.2.6/modelit/templates/solid_notification/EmailSender.java +6 -0
- modelit-0.2.6/modelit/templates/solid_notification/Main.java +24 -0
- modelit-0.2.6/modelit/templates/solid_notification/NotificationSender.java +3 -0
- modelit-0.2.6/modelit/templates/solid_notification/NotificationService.java +11 -0
- modelit-0.2.6/modelit/templates/solid_notification/PushSender.java +6 -0
- modelit-0.2.6/modelit/templates/solid_notification/SlackSender.java +6 -0
- modelit-0.2.6/modelit/templates/solid_notification/SmsSender.java +6 -0
- modelit-0.2.6/modelit/templates/solid_shapes/AreaCalculator.java +14 -0
- modelit-0.2.6/modelit/templates/solid_shapes/Circle.java +12 -0
- modelit-0.2.6/modelit/templates/solid_shapes/Main.java +20 -0
- modelit-0.2.6/modelit/templates/solid_shapes/Pentagon.java +14 -0
- modelit-0.2.6/modelit/templates/solid_shapes/Rectangle.java +14 -0
- modelit-0.2.6/modelit/templates/solid_shapes/Shape.java +3 -0
- modelit-0.2.6/modelit/templates/solid_shapes/Triangle.java +14 -0
- modelit-0.2.6/modelit/templates/streams_employee/Employee.java +23 -0
- modelit-0.2.6/modelit/templates/streams_employee/Main.java +57 -0
- modelit-0.2.6/modelit/templates/streams_word_freq/template.java +59 -0
- modelit-0.2.6/modelit/templates/thread_deadlock/template.java +40 -0
- modelit-0.2.6/modelit/templates/thread_deadlock_solved/template.java +54 -0
- modelit-0.2.6/modelit/templates/vae/template.py +322 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit.egg-info/PKG-INFO +11 -1
- modelit-0.2.6/modelit.egg-info/SOURCES.txt +75 -0
- modelit-0.2.5/modelit/registry.py +0 -94
- modelit-0.2.5/modelit.egg-info/SOURCES.txt +0 -39
- {modelit-0.2.5 → modelit-0.2.6}/.github/workflows/publish.yml +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/.github/workflows/test.yml +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/.gitignore +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/LICENSE +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/MANIFEST.in +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/__init__.py +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/__main__.py +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/abstract_factory_pattern/template.java +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/adapter_pattern/template.java +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/backpropagation/template.py +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/bridge_pattern/template.java +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/builder_pattern/template.java +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/cnn/template.py +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/command_pattern/template.java +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/composite_pattern/template.java +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/decorator_pattern/template.java +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/facade_pattern/template.java +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/factory_pattern/template.java +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/flyweight_pattern/template.java +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/interpreter_pattern/template.java +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/mediator_pattern/template.java +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/memento_pattern/template.java +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/observer_pattern/template.java +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/perceptron/template.py +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/prototype_pattern/template.java +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/proxy_pattern/template.java +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/state_pattern/template.java +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/template_pattern/template.java +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit/templates/visitor_pattern/template.java +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit.egg-info/dependency_links.txt +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit.egg-info/entry_points.txt +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/modelit.egg-info/top_level.txt +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/pyproject.toml +0 -0
- {modelit-0.2.5 → modelit-0.2.6}/setup.cfg +0 -0
|
@@ -17,6 +17,8 @@ modelit/templates/<name>/
|
|
|
17
17
|
modelit/templates/<name>/template.py
|
|
18
18
|
```
|
|
19
19
|
|
|
20
|
+
For multi-file templates, add more files in the same folder.
|
|
21
|
+
|
|
20
22
|
4. Make sure the folder name is the function name.
|
|
21
23
|
5. Test it locally:
|
|
22
24
|
|
|
@@ -39,6 +41,7 @@ modelit create <name>
|
|
|
39
41
|
- Do not add `metadata.json`.
|
|
40
42
|
- Keep file names lowercase.
|
|
41
43
|
- Prefer clean, beginner-friendly code.
|
|
44
|
+
- If a template has multiple files, `--output` should be treated as a directory.
|
|
42
45
|
|
|
43
46
|
## Pull request flow
|
|
44
47
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: modelit
|
|
3
|
-
Version: 0.2.5
|
|
3
|
+
Version: 0.2.6
|
|
4
4
|
Summary: Local-first ML starter templates you can print or save
|
|
5
5
|
Author: Yashsmith
|
|
6
6
|
License-Expression: MIT
|
|
@@ -28,6 +28,8 @@ ModelIt is a tiny Python package for storing your ML/DL boilerplate templates.
|
|
|
28
28
|
- `perceptron()` prints the full code
|
|
29
29
|
- `perceptron(output="code1.py")` saves it to a file
|
|
30
30
|
- `modelit create perceptron` works from the terminal
|
|
31
|
+
- single-file templates write to a file
|
|
32
|
+
- multi-file templates write to a folder
|
|
31
33
|
|
|
32
34
|
## Install
|
|
33
35
|
|
|
@@ -94,6 +96,8 @@ modelit/templates/mycode/template.java
|
|
|
94
96
|
|
|
95
97
|
3. Done. The folder name becomes the function name.
|
|
96
98
|
|
|
99
|
+
If a template has more than one file, `--output` must be a directory path.
|
|
100
|
+
|
|
97
101
|
That means this will work automatically:
|
|
98
102
|
|
|
99
103
|
```python
|
|
@@ -109,6 +113,12 @@ modelit create mycode
|
|
|
109
113
|
modelit create mycode --output mycode.py
|
|
110
114
|
```
|
|
111
115
|
|
|
116
|
+
For multi-file templates:
|
|
117
|
+
|
|
118
|
+
```bash
|
|
119
|
+
modelit create solid_notification --output out/solid_notification
|
|
120
|
+
```
|
|
121
|
+
|
|
112
122
|
## Publish flow
|
|
113
123
|
|
|
114
124
|
1. Add a new template folder.
|
|
@@ -8,6 +8,8 @@ ModelIt is a tiny Python package for storing your ML/DL boilerplate templates.
|
|
|
8
8
|
- `perceptron()` prints the full code
|
|
9
9
|
- `perceptron(output="code1.py")` saves it to a file
|
|
10
10
|
- `modelit create perceptron` works from the terminal
|
|
11
|
+
- single-file templates write to a file
|
|
12
|
+
- multi-file templates write to a folder
|
|
11
13
|
|
|
12
14
|
## Install
|
|
13
15
|
|
|
@@ -74,6 +76,8 @@ modelit/templates/mycode/template.java
|
|
|
74
76
|
|
|
75
77
|
3. Done. The folder name becomes the function name.
|
|
76
78
|
|
|
79
|
+
If a template has more than one file, `--output` must be a directory path.
|
|
80
|
+
|
|
77
81
|
That means this will work automatically:
|
|
78
82
|
|
|
79
83
|
```python
|
|
@@ -89,6 +93,12 @@ modelit create mycode
|
|
|
89
93
|
modelit create mycode --output mycode.py
|
|
90
94
|
```
|
|
91
95
|
|
|
96
|
+
For multi-file templates:
|
|
97
|
+
|
|
98
|
+
```bash
|
|
99
|
+
modelit create solid_notification --output out/solid_notification
|
|
100
|
+
```
|
|
101
|
+
|
|
92
102
|
## Publish flow
|
|
93
103
|
|
|
94
104
|
1. Add a new template folder.
|
|
@@ -17,7 +17,7 @@ def main(argv: list[str] | None = None) -> None:
|
|
|
17
17
|
|
|
18
18
|
create_parser = subparsers.add_parser("create", help="Print or save a template")
|
|
19
19
|
create_parser.add_argument("name", choices=available_models(), help="Template name")
|
|
20
|
-
create_parser.add_argument("-o", "--output", help="Write
|
|
20
|
+
create_parser.add_argument("-o", "--output", help="Write to a file for single-file templates or a directory for multi-file templates")
|
|
21
21
|
create_parser.set_defaults(func=_create)
|
|
22
22
|
|
|
23
23
|
args = parser.parse_args(argv)
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
"""Template discovery and loading."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from importlib.resources import files
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
# Package that ships the templates, and the sub-folder inside it that holds
# one directory per template.
PACKAGE_NAME: str = "modelit"
TEMPLATES_DIR: str = "templates"
|
|
11
|
+
|
|
12
|
+
@dataclass(frozen=True)
class TemplateInfo:
    """Immutable record describing one template.

    Only the template's folder name is tracked.
    """

    # Template folder name, e.g. "perceptron".
    name: str
|
|
15
|
+
|
|
16
|
+
def _templates_root():
    """Return the packaged ``templates`` folder as an importlib.resources Traversable."""
    package_root = files(PACKAGE_NAME)
    return package_root.joinpath(TEMPLATES_DIR)
|
|
18
|
+
|
|
19
|
+
def _template_dir(name: str):
    """Return the resource folder that holds the files of template *name*."""
    templates = _templates_root()
    return templates.joinpath(name)
|
|
21
|
+
|
|
22
|
+
def available_models() -> tuple[str, ...]:
    """Return the sorted names of every usable template folder.

    A folder counts as a template when it is a directory whose name does not
    start with ``__`` and which contains at least one file whose name does not
    start with ``__``. This is the same filter ``load_template_files`` applies,
    so a folder holding only ``__pycache__`` or sub-directories is not
    advertised as an (empty) template.

    Returns:
        Template names in sorted order, or ``()`` if the templates folder is
        missing.
    """
    root = _templates_root()
    if not root.is_dir():
        return ()

    def _has_template_file(folder) -> bool:
        # True when the folder holds at least one loadable (non-dunder) file.
        return any(
            entry.is_file() and not entry.name.startswith("__")
            for entry in folder.iterdir()
        )

    return tuple(
        sorted(
            child.name
            for child in root.iterdir()
            if child.is_dir()
            and not child.name.startswith("__")
            and _has_template_file(child)
        )
    )
|
|
33
|
+
|
|
34
|
+
def load_metadata(name: str) -> TemplateInfo:
    """Build the metadata record for template *name*.

    Nothing is read from disk; the record carries only the name.
    """
    info = TemplateInfo(name=name)
    return info
|
|
36
|
+
|
|
37
|
+
def load_template_files(name: str) -> dict[str, str]:
    """Read every file of template *name* into memory.

    Args:
        name: Template folder name under the packaged ``templates`` folder.

    Returns:
        Mapping of file name to its UTF-8 text. Entries are ordered by file
        name, so printing and writing multi-file templates is deterministic
        (``iterdir`` makes no ordering guarantee).

    Raises:
        FileNotFoundError: If the template folder does not exist.
    """
    target_dir = _template_dir(name)
    if not target_dir.is_dir():
        raise FileNotFoundError(f"Missing template directory for {name!r}")

    # Dunder entries (e.g. __init__.py, __pycache__) are packaging artefacts,
    # not template content.
    entries = sorted(
        (
            child
            for child in target_dir.iterdir()
            if child.is_file() and not child.name.startswith("__")
        ),
        key=lambda child: child.name,
    )
    return {child.name: child.read_text(encoding="utf-8") for child in entries}
|
|
47
|
+
|
|
48
|
+
def build_template_callable(name: str):
    """Create the public ``<name>(output=None)`` function for a template.

    Calling the returned function without *output* prints the template to
    stdout. With *output*, a single-file template is written to that file
    path, and a multi-file template is written into that directory.

    Args:
        name: Template folder name.

    Returns:
        A function named *name*. It carries ``output_file`` (the single file's
        name, or *name* for multi-file templates) and ``template_info``.

    Raises:
        FileNotFoundError: If the template folder does not exist.
    """
    files_dict = load_template_files(name)
    info = load_metadata(name)

    is_single_file = len(files_dict) == 1
    default_filename = next(iter(files_dict)) if is_single_file else name

    def runner(output: str | None = None) -> None:
        if output:
            out_path = Path(output)

            if is_single_file:
                # Refuse to overwrite anything that already exists (file or dir).
                if out_path.exists():
                    raise FileExistsError(f"Output path already exists: {out_path}")
                out_path.parent.mkdir(parents=True, exist_ok=True)
                out_path.write_text(next(iter(files_dict.values())), encoding="utf-8")
                print(f"Generated {out_path}")
            else:
                # A suffix usually means a file name was passed by mistake, but
                # an existing directory with a dot in its name (e.g. "out/v1.0")
                # is a valid target.
                if out_path.suffix and not out_path.is_dir():
                    raise ValueError("Multi-file templates require a directory output path")
                if out_path.exists() and not out_path.is_dir():
                    raise FileExistsError(f"Output path already exists and is not a directory: {out_path}")
                out_path.mkdir(parents=True, exist_ok=True)
                for fname, content in files_dict.items():
                    file_path = out_path / fname
                    # Existing files are left untouched; only missing ones are written.
                    if file_path.exists():
                        print(f"Skipping {file_path} (already exists)")
                        continue
                    file_path.write_text(content, encoding="utf-8")
                    print(f"Generated {file_path}")
            return None

        # If no output is specified, print to terminal (with a banner per file
        # when the template has several files).
        for fname, content in files_dict.items():
            if not is_single_file:
                print(f"\n{'='*40}\nFile: {fname}\n{'='*40}")
            print(content, end="\n\n" if not is_single_file else "")

    runner.__name__ = name
    runner.__qualname__ = name
    runner.__module__ = "modelit"
    runner.__doc__ = f"Print or save the {name} template."
    runner.output_file = default_filename  # type: ignore[attr-defined]
    runner.template_info = info  # type: ignore[attr-defined]
    return runner
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
import java.io.BufferedReader;
|
|
2
|
+
import java.io.FileReader;
|
|
3
|
+
import java.io.IOException;
|
|
4
|
+
|
|
5
|
+
/**
 * Checked exception signalling that an account balance cannot cover a
 * requested transfer.
 */
class InsufficientFundsException extends Exception {

    /**
     * @param message human-readable description, later returned by getMessage()
     */
    public InsufficientFundsException(String message) {
        super(message);
    }
}
|
|
10
|
+
|
|
11
|
+
public class AdvancedExceptionDemo {
|
|
12
|
+
|
|
13
|
+
public static void processTransfer(int accountBalance, int transferAmount) throws InsufficientFundsException {
|
|
14
|
+
if (transferAmount > accountBalance) {
|
|
15
|
+
throw new InsufficientFundsException("Transfer failed: Attempted to send $" + transferAmount + " but balance is $" + accountBalance);
|
|
16
|
+
}
|
|
17
|
+
System.out.println("Transfer of $" + transferAmount + " completed successfully.");
|
|
18
|
+
}
|
|
19
|
+
|
|
20
|
+
public static void main(String[] args) {
|
|
21
|
+
try {
|
|
22
|
+
processTransfer(5000, 7500);
|
|
23
|
+
} catch (InsufficientFundsException e) {
|
|
24
|
+
System.err.println("Business Error: " + e.getMessage());
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
System.out.println("-----------------");
|
|
28
|
+
|
|
29
|
+
String filePath = "non_existent_transactions.csv";
|
|
30
|
+
|
|
31
|
+
try (BufferedReader reader = new BufferedReader(new FileReader(filePath))) {
|
|
32
|
+
System.out.println(reader.readLine());
|
|
33
|
+
int value = Integer.parseInt("Not_A_Number");
|
|
34
|
+
|
|
35
|
+
} catch (IOException | NumberFormatException e) {
|
|
36
|
+
System.err.println("System Failure: Could not process file or format was corrupted.");
|
|
37
|
+
System.err.println("Root Cause: " + e.getClass().getSimpleName() + " -> " + e.getMessage());
|
|
38
|
+
} finally {
|
|
39
|
+
System.out.println("Cleanup executed. System remains stable.");
|
|
40
|
+
}
|
|
41
|
+
}
|
|
42
|
+
}
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
/** Price helper whose post-condition is checked with a Java assertion. */
class DiscountCalculator {

    /**
     * Returns originalPrice reduced by discountPercentage percent.
     * The non-negative check is an assert, so it only fires when the JVM
     * runs with -ea.
     */
    public static double applyDiscount(double originalPrice, double discountPercentage) {
        double fraction = discountPercentage / 100;
        double finalPrice = originalPrice - originalPrice * fraction;
        assert finalPrice >= 0 : "CRITICAL LOGIC ERROR: Final price fell below zero!";
        return finalPrice;
    }
}
|
|
11
|
+
|
|
12
|
+
public class AssertionDemo {
|
|
13
|
+
public static void main(String[] args) {
|
|
14
|
+
System.out.println("Final Price: $" + DiscountCalculator.applyDiscount(1000, 20));
|
|
15
|
+
|
|
16
|
+
System.out.println("Testing invalid logic...");
|
|
17
|
+
System.out.println("Final Price: $" + DiscountCalculator.applyDiscount(500, 150));
|
|
18
|
+
}
|
|
19
|
+
}
|
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
# =========================================================
|
|
2
|
+
# DENOISING DIFFUSION PROBABILISTIC MODEL (DDPM)
|
|
3
|
+
# Learning the y = -x Manifold via Reverse Diffusion
|
|
4
|
+
# Built using PyTorch
|
|
5
|
+
# =========================================================
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# =========================================================
|
|
9
|
+
# IMPORTS
|
|
10
|
+
# =========================================================
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import torch
|
|
14
|
+
import torch.nn as nn
|
|
15
|
+
import torch.optim as optim
|
|
16
|
+
import matplotlib.pyplot as plt
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# =========================================================
# DATA GENERATION FUNCTION (REAL MANIFOLD)
# =========================================================

def sample_line_data(n_samples=512):
    """Draw points that lie exactly on the line y = -x.

    x is drawn uniformly between -4.5 and 4.5. The result has shape
    (n_samples, 2), with columns (x, y).
    """
    xs = torch.empty(n_samples, 1).uniform_(-4.5, 4.5)
    return torch.cat((xs, -xs), dim=1)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# =========================================================
# DIFFUSION HYPERPARAMETERS & SCHEDULING
# =========================================================

steps = 120

# Linear beta schedule. generate() starts sampling from pure N(0, I) noise,
# so after `steps` noising steps almost no data signal may remain:
# alpha_cum[-1] is the fraction of signal variance left at the last step.
# With an end value of 0.07 it is about 0.01. A small end value such as
# 0.015 would leave about 0.40 over only 120 steps, far from pure noise.
beta_start = 2e-4
beta_end = 0.07

betas = torch.linspace(beta_start, beta_end, steps)

alphas = 1.0 - betas

# alpha_cum[t] = alphas[0] * ... * alphas[t]  ("alpha bar" in DDPM notation)
alpha_cum = torch.cumprod(alphas, dim=0)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# =========================================================
# FORWARD DIFFUSION PROCESS (NOISING)
# =========================================================

def diffuse(x0, t_idx):
    """Noise clean samples x0 straight to timesteps t_idx in closed form.

    Returns (x_t, noise). noise is the Gaussian sample that was mixed in,
    i.e. the regression target for the noise estimator.
    """
    noise = torch.randn_like(x0)
    # One alpha-bar per sample, shaped as a column so it broadcasts over features.
    a_bar = alpha_cum[t_idx].view(-1, 1)
    signal_scale = torch.sqrt(a_bar)
    noise_scale = torch.sqrt(1 - a_bar)
    return signal_scale * x0 + noise_scale * noise, noise
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# =========================================================
# NOISE ESTIMATOR NETWORK (MLP WITH TIME EMBEDDING)
# =========================================================

class NoiseEstimator(nn.Module):
    """MLP that predicts the noise added to a point, conditioned on the timestep.

    Args:
        dim: Dimensionality of the data points.
        hidden: Width of the two hidden layers.
        time_dim: Size of the sinusoidal timestep embedding. Must be even:
            half of it is sine features, half cosine.

    Raises:
        ValueError: If time_dim is odd.
    """

    def __init__(self, dim=2, hidden=128, time_dim=16):
        super().__init__()
        if time_dim % 2:
            raise ValueError(f"time_dim must be even, got {time_dim}")
        self.time_dim = time_dim
        self.mlp = nn.Sequential(
            nn.Linear(dim + time_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, dim)
        )

    def time_embed(self, t):
        """Embed a 1-D tensor of timesteps as (len(t), time_dim) sin/cos features."""
        half = self.time_dim // 2
        # Built on t's device so the model also works after .to("cuda").
        freqs = torch.exp(torch.linspace(0, 3, half, device=t.device))
        t = t.float().unsqueeze(1)
        return torch.cat([torch.sin(t / freqs), torch.cos(t / freqs)], dim=1)

    def forward(self, x, t):
        """Predict the noise in x (batch, dim) at timesteps t (batch,)."""
        t_emb = self.time_embed(t)
        return self.mlp(torch.cat([x, t_emb], dim=1))
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
# =========================================================
# INITIALIZATION & CONFIGURATION
# =========================================================

model = NoiseEstimator()
opt = optim.Adam(model.parameters(), lr=2e-3)

# Mean absolute error between predicted and true noise.
loss_fn = nn.L1Loss()

epochs = 1800  # optimisation steps, each on a freshly sampled batch
batch = 128    # samples per step
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
# =========================================================
# DIFFUSION MODEL TRAINING LOOP
# =========================================================

print("\n=================================================")
print("TRAINING DENOISING DIFFUSION MODEL")
print("=================================================")

# Each "epoch" is one gradient step on a newly sampled batch.
for ep in range(epochs):
    # Clean points plus one random timestep per point, then noise them.
    x0 = sample_line_data(batch)
    t = torch.randint(0, steps, (batch,))
    xt, eps = diffuse(x0, t)

    # The network learns to recover exactly the noise that was added.
    eps_hat = model(xt, t)
    loss = loss_fn(eps_hat, eps)

    opt.zero_grad()
    loss.backward()
    opt.step()

    # =========================================================
    # PROGRESS LOGGING
    # =========================================================

    if ep % 300 == 0:
        print(f"Epoch {ep:04d} | L1 Loss: {loss.item():.5f}")

print("\nDiffusion Training Complete.")
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
# =========================================================
# REVERSE DIFFUSION PROCESS (GENERATION)
# =========================================================

@torch.no_grad()
def generate(n=800):
    """Sample n 2-D points by running the learned reverse process from noise.

    Each step applies
        x <- (x - (1 - a_t) / sqrt(1 - a_bar_t) * eps) / sqrt(a_t) + sqrt(b_t) * z
    with z ~ N(0, I) for t > 0 and z = 0 at the final step.
    """
    x = torch.randn(n, 2)
    for t in range(steps - 1, -1, -1):
        eps = model(x, torch.full((n,), t))
        a, a_bar, b = alphas[t], alpha_cum[t], betas[t]
        coef1 = 1 / torch.sqrt(a)
        coef2 = (1 - a) / torch.sqrt(1 - a_bar)
        noise = 0 if t == 0 else torch.randn_like(x)
        x = coef1 * (x - coef2 * eps) + torch.sqrt(b) * noise
    return x
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
# =========================================================
# GENERATING SAMPLES FOR EVALUATION
# =========================================================

# 1000 reference points from the true manifold, then 1000 model samples.
real = sample_line_data(1000)
fake = generate(1000)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
# =========================================================
# VISUALIZATION
# =========================================================

plt.figure(figsize=(6, 6))

# Overlay real and generated points with identical marker styling.
for points, label in ((real, "Real data"), (fake, "Generated")):
    plt.scatter(points[:, 0], points[:, 1], s=5, alpha=0.4, label=label)

plt.legend()
plt.title("Diffusion learns y = -x manifold")
plt.axis("equal")
plt.show()
|
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
# =========================================================
|
|
2
|
+
# GENERATIVE ADVERSARIAL NETWORK (GAN) FOR 2D DISTRIBUTION
|
|
3
|
+
# Learning the y = -x Manifold from Scratch
|
|
4
|
+
# Built using PyTorch
|
|
5
|
+
# =========================================================
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# =========================================================
|
|
9
|
+
# IMPORTS
|
|
10
|
+
# =========================================================
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
import torch.nn as nn
|
|
14
|
+
import torch.optim as optim
|
|
15
|
+
import matplotlib.pyplot as plt
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# =========================================================
# HYPERPARAMETERS & DEVICE CONFIGURATION
# =========================================================

batch_size = 128   # samples per training step
noise_dim = 2      # size of the generator's latent input
lr = 0.0002        # Adam learning rate for both networks
epochs = 2000      # number of training steps

# Use the GPU when PyTorch can see one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# =========================================================
# DATA GENERATION FUNCTION (REAL DISTRIBUTION)
# =========================================================

def get_real_samples(batch_size):
    """Return (batch_size, 2) points on y = -x with x ~ N(0, 1), moved to device."""
    xs = torch.randn(batch_size, 1)
    points = torch.cat([xs, -xs], dim=1)
    return points.to(device)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# =========================================================
# GENERATOR ARCHITECTURE
# =========================================================

class Generator(nn.Module):
    """Small ReLU MLP mapping latent vectors of size noise_dim to 2-D points."""

    def __init__(self):
        super().__init__()
        width = 16
        self.model = nn.Sequential(
            nn.Linear(noise_dim, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 2)
        )

    def forward(self, z):
        """Map latent z of shape (batch, noise_dim) to points of shape (batch, 2)."""
        return self.model(z)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# =========================================================
# DISCRIMINATOR ARCHITECTURE
# =========================================================

class Discriminator(nn.Module):
    """Small ReLU MLP scoring 2-D points with a Sigmoid output in [0, 1]."""

    def __init__(self):
        super().__init__()
        width = 16
        self.model = nn.Sequential(
            nn.Linear(2, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return a (batch, 1) score for each point in x of shape (batch, 2)."""
        return self.model(x)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
# =========================================================
# MODEL INITIALIZATION & OPTIMIZERS
# =========================================================

G = Generator().to(device)
D = Discriminator().to(device)

# The discriminator already outputs probabilities, so plain BCE applies.
criterion = nn.BCELoss()

# One Adam optimiser per network; each is stepped only on its own loss.
optimizer_G = optim.Adam(G.parameters(), lr=lr)
optimizer_D = optim.Adam(D.parameters(), lr=lr)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
# =========================================================
# GAN TRAINING LOOP
# =========================================================

print("\n=================================================")
print("TRAINING GENERATIVE ADVERSARIAL NETWORK")
print("=================================================")

# Target labels never change, so build them once, directly on the device,
# instead of allocating and copying two new tensors every iteration.
real_labels = torch.ones(batch_size, 1, device=device)
fake_labels = torch.zeros(batch_size, 1, device=device)

for epoch in range(epochs):

    # =========================================================
    # TRAIN DISCRIMINATOR
    # =========================================================

    real_data = get_real_samples(batch_size)
    noise = torch.randn(batch_size, noise_dim).to(device)
    fake_data = G(noise)

    real_preds = D(real_data)
    # detach(): the discriminator step must not send gradients into G.
    fake_preds = D(fake_data.detach())

    loss_real = criterion(real_preds, real_labels)
    loss_fake = criterion(fake_preds, fake_labels)
    loss_D = loss_real + loss_fake

    optimizer_D.zero_grad()
    loss_D.backward()
    optimizer_D.step()

    # =========================================================
    # TRAIN GENERATOR
    # =========================================================

    noise = torch.randn(batch_size, noise_dim).to(device)
    fake_data = G(noise)
    fake_preds = D(fake_data)

    # The generator is rewarded when D labels its samples as real.
    loss_G = criterion(fake_preds, real_labels)

    optimizer_G.zero_grad()
    loss_G.backward()
    optimizer_G.step()

    # =========================================================
    # PROGRESS LOGGING
    # =========================================================

    if epoch % 200 == 0:
        print(f"Epoch {epoch:04d} | D Loss: {loss_D.item():.4f} | G Loss: {loss_G.item():.4f}")

print("\nGAN Training Complete.")
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
# =========================================================
# EVALUATION & SAMPLE GENERATION
# =========================================================

# Sampling only: no_grad() avoids building (and then discarding) an
# autograd graph for the 1000 generated points.
with torch.no_grad():
    noise = torch.randn(1000, noise_dim).to(device)
    generated = G(noise).cpu()

real = get_real_samples(1000).cpu()
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
# =========================================================
# VISUALIZATION
# =========================================================

plt.figure(figsize=(6, 6))

# Overlay the target distribution and the generator's samples.
for points, label in ((real, "Real (y=-x)"), (generated, "Generated")):
    plt.scatter(points[:, 0], points[:, 1], label=label, alpha=0.5)

plt.legend()
plt.title("GAN Learning y = -x Distribution")
plt.show()
|