modelit 0.2.5__tar.gz → 0.2.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79) hide show
  1. {modelit-0.2.5 → modelit-0.2.7}/CONTRIBUTING.md +9 -0
  2. {modelit-0.2.5 → modelit-0.2.7}/PKG-INFO +13 -1
  3. {modelit-0.2.5 → modelit-0.2.7}/README.md +12 -0
  4. {modelit-0.2.5 → modelit-0.2.7}/modelit/cli.py +9 -1
  5. modelit-0.2.7/modelit/registry.py +92 -0
  6. modelit-0.2.7/modelit/templates/advanced_exceptions/template.java +42 -0
  7. modelit-0.2.7/modelit/templates/assertions/template.java +19 -0
  8. modelit-0.2.7/modelit/templates/diffusion/template.py +231 -0
  9. modelit-0.2.7/modelit/templates/gan/template.py +233 -0
  10. modelit-0.2.7/modelit/templates/java_default_methods/template.java +38 -0
  11. modelit-0.2.7/modelit/templates/jca_digital_signature/template.java +32 -0
  12. modelit-0.2.7/modelit/templates/jca_encryption/template.java +33 -0
  13. modelit-0.2.7/modelit/templates/jca_hashing/template.java +26 -0
  14. modelit-0.2.7/modelit/templates/jca_hmac/template.java +31 -0
  15. modelit-0.2.7/modelit/templates/jca_mac/template.java +30 -0
  16. modelit-0.2.7/modelit/templates/jca_rsa/template.java +31 -0
  17. modelit-0.2.7/modelit/templates/jca_secure_random/template.java +19 -0
  18. modelit-0.2.7/modelit/templates/lstm/template.py +390 -0
  19. modelit-0.2.7/modelit/templates/pca_ae/template.py +226 -0
  20. modelit-0.2.7/modelit/templates/rnn/template.py +392 -0
  21. modelit-0.2.7/modelit/templates/sofm/template.py +90 -0
  22. modelit-0.2.7/modelit/templates/solid_notification/EmailSender.java +6 -0
  23. modelit-0.2.7/modelit/templates/solid_notification/Main.java +24 -0
  24. modelit-0.2.7/modelit/templates/solid_notification/NotificationSender.java +3 -0
  25. modelit-0.2.7/modelit/templates/solid_notification/NotificationService.java +11 -0
  26. modelit-0.2.7/modelit/templates/solid_notification/PushSender.java +6 -0
  27. modelit-0.2.7/modelit/templates/solid_notification/SlackSender.java +6 -0
  28. modelit-0.2.7/modelit/templates/solid_notification/SmsSender.java +6 -0
  29. modelit-0.2.7/modelit/templates/solid_shapes/AreaCalculator.java +14 -0
  30. modelit-0.2.7/modelit/templates/solid_shapes/Circle.java +12 -0
  31. modelit-0.2.7/modelit/templates/solid_shapes/Main.java +20 -0
  32. modelit-0.2.7/modelit/templates/solid_shapes/Pentagon.java +14 -0
  33. modelit-0.2.7/modelit/templates/solid_shapes/Rectangle.java +14 -0
  34. modelit-0.2.7/modelit/templates/solid_shapes/Shape.java +3 -0
  35. modelit-0.2.7/modelit/templates/solid_shapes/Triangle.java +14 -0
  36. modelit-0.2.7/modelit/templates/streams_employee/Employee.java +23 -0
  37. modelit-0.2.7/modelit/templates/streams_employee/Main.java +57 -0
  38. modelit-0.2.7/modelit/templates/streams_word_freq/template.java +59 -0
  39. modelit-0.2.7/modelit/templates/thread_deadlock/template.java +40 -0
  40. modelit-0.2.7/modelit/templates/thread_deadlock_solved/template.java +54 -0
  41. modelit-0.2.7/modelit/templates/vae/template.py +322 -0
  42. {modelit-0.2.5 → modelit-0.2.7}/modelit.egg-info/PKG-INFO +13 -1
  43. modelit-0.2.7/modelit.egg-info/SOURCES.txt +75 -0
  44. modelit-0.2.5/modelit/registry.py +0 -94
  45. modelit-0.2.5/modelit.egg-info/SOURCES.txt +0 -39
  46. {modelit-0.2.5 → modelit-0.2.7}/.github/workflows/publish.yml +0 -0
  47. {modelit-0.2.5 → modelit-0.2.7}/.github/workflows/test.yml +0 -0
  48. {modelit-0.2.5 → modelit-0.2.7}/.gitignore +0 -0
  49. {modelit-0.2.5 → modelit-0.2.7}/LICENSE +0 -0
  50. {modelit-0.2.5 → modelit-0.2.7}/MANIFEST.in +0 -0
  51. {modelit-0.2.5 → modelit-0.2.7}/modelit/__init__.py +0 -0
  52. {modelit-0.2.5 → modelit-0.2.7}/modelit/__main__.py +0 -0
  53. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/abstract_factory_pattern/template.java +0 -0
  54. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/adapter_pattern/template.java +0 -0
  55. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/backpropagation/template.py +0 -0
  56. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/bridge_pattern/template.java +0 -0
  57. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/builder_pattern/template.java +0 -0
  58. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/cnn/template.py +0 -0
  59. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/command_pattern/template.java +0 -0
  60. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/composite_pattern/template.java +0 -0
  61. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/decorator_pattern/template.java +0 -0
  62. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/facade_pattern/template.java +0 -0
  63. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/factory_pattern/template.java +0 -0
  64. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/flyweight_pattern/template.java +0 -0
  65. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/interpreter_pattern/template.java +0 -0
  66. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/mediator_pattern/template.java +0 -0
  67. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/memento_pattern/template.java +0 -0
  68. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/observer_pattern/template.java +0 -0
  69. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/perceptron/template.py +0 -0
  70. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/prototype_pattern/template.java +0 -0
  71. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/proxy_pattern/template.java +0 -0
  72. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/state_pattern/template.java +0 -0
  73. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/template_pattern/template.java +0 -0
  74. {modelit-0.2.5 → modelit-0.2.7}/modelit/templates/visitor_pattern/template.java +0 -0
  75. {modelit-0.2.5 → modelit-0.2.7}/modelit.egg-info/dependency_links.txt +0 -0
  76. {modelit-0.2.5 → modelit-0.2.7}/modelit.egg-info/entry_points.txt +0 -0
  77. {modelit-0.2.5 → modelit-0.2.7}/modelit.egg-info/top_level.txt +0 -0
  78. {modelit-0.2.5 → modelit-0.2.7}/pyproject.toml +0 -0
  79. {modelit-0.2.5 → modelit-0.2.7}/setup.cfg +0 -0
@@ -17,6 +17,8 @@ modelit/templates/<name>/
17
17
  modelit/templates/<name>/template.py
18
18
  ```
19
19
 
20
+ For multi-file templates, add more files in the same folder.
21
+
20
22
  4. Make sure the folder name is the function name.
21
23
  5. Test it locally:
22
24
 
@@ -32,6 +34,12 @@ or:
32
34
  modelit create <name>
33
35
  ```
34
36
 
37
+ You can list all available templates with:
38
+
39
+ ```bash
40
+ modelit list
41
+ ```
42
+
35
43
  ## Rules
36
44
 
37
45
  - Keep templates simple and runnable.
@@ -39,6 +47,7 @@ modelit create <name>
39
47
  - Do not add `metadata.json`.
40
48
  - Keep file names lowercase.
41
49
  - Prefer clean, beginner-friendly code.
50
+ - If a template has multiple files, `--output` should be treated as a directory.
42
51
 
43
52
  ## Pull request flow
44
53
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: modelit
3
- Version: 0.2.5
3
+ Version: 0.2.7
4
4
  Summary: Local-first ML starter templates you can print or save
5
5
  Author: Yashsmith
6
6
  License-Expression: MIT
@@ -28,6 +28,9 @@ ModelIt is a tiny Python package for storing your ML/DL boilerplate templates.
28
28
  - `perceptron()` prints the full code
29
29
  - `perceptron(output="code1.py")` saves it to a file
30
30
  - `modelit create perceptron` works from the terminal
31
+ - `modelit list` shows everything available
32
+ - single-file templates write to a file
33
+ - multi-file templates write to a folder
31
34
 
32
35
  ## Install
33
36
 
@@ -68,6 +71,7 @@ python3 code1.py
68
71
  ### CLI
69
72
 
70
73
  ```bash
74
+ modelit list
71
75
  modelit create perceptron
72
76
  modelit create perceptron --output code1.py
73
77
  ```
@@ -94,6 +98,8 @@ modelit/templates/mycode/template.java
94
98
 
95
99
  3. Done. The folder name becomes the function name.
96
100
 
101
+ If a template has more than one file, `--output` must be a directory path.
102
+
97
103
  That means this will work automatically:
98
104
 
99
105
  ```python
@@ -109,6 +115,12 @@ modelit create mycode
109
115
  modelit create mycode --output mycode.py
110
116
  ```
111
117
 
118
+ For multi-file templates:
119
+
120
+ ```bash
121
+ modelit create solid_notification --output out/solid_notification
122
+ ```
123
+
112
124
  ## Publish flow
113
125
 
114
126
  1. Add a new template folder.
@@ -8,6 +8,9 @@ ModelIt is a tiny Python package for storing your ML/DL boilerplate templates.
8
8
  - `perceptron()` prints the full code
9
9
  - `perceptron(output="code1.py")` saves it to a file
10
10
  - `modelit create perceptron` works from the terminal
11
+ - `modelit list` shows everything available
12
+ - single-file templates write to a file
13
+ - multi-file templates write to a folder
11
14
 
12
15
  ## Install
13
16
 
@@ -48,6 +51,7 @@ python3 code1.py
48
51
  ### CLI
49
52
 
50
53
  ```bash
54
+ modelit list
51
55
  modelit create perceptron
52
56
  modelit create perceptron --output code1.py
53
57
  ```
@@ -74,6 +78,8 @@ modelit/templates/mycode/template.java
74
78
 
75
79
  3. Done. The folder name becomes the function name.
76
80
 
81
+ If a template has more than one file, `--output` must be a directory path.
82
+
77
83
  That means this will work automatically:
78
84
 
79
85
  ```python
@@ -89,6 +95,12 @@ modelit create mycode
89
95
  modelit create mycode --output mycode.py
90
96
  ```
91
97
 
98
+ For multi-file templates:
99
+
100
+ ```bash
101
+ modelit create solid_notification --output out/solid_notification
102
+ ```
103
+
92
104
  ## Publish flow
93
105
 
94
106
  1. Add a new template folder.
@@ -11,14 +11,22 @@ def _create(args: argparse.Namespace) -> None:
11
11
  build_template_callable(args.name)(output=args.output)
12
12
 
13
13
 
14
+ def _list(args: argparse.Namespace) -> None:
15
+ for name in available_models():
16
+ print(name)
17
+
18
+
14
19
  def main(argv: list[str] | None = None) -> None:
15
20
  parser = argparse.ArgumentParser(prog="modelit", description="Print or save ML template code.")
16
21
  subparsers = parser.add_subparsers(dest="command", required=True)
17
22
 
18
23
  create_parser = subparsers.add_parser("create", help="Print or save a template")
19
24
  create_parser.add_argument("name", choices=available_models(), help="Template name")
20
- create_parser.add_argument("-o", "--output", help="Write the template to a file instead of printing")
25
+ create_parser.add_argument("-o", "--output", help="Write to a file for single-file templates or a directory for multi-file templates")
21
26
  create_parser.set_defaults(func=_create)
22
27
 
28
+ list_parser = subparsers.add_parser("list", help="List available templates")
29
+ list_parser.set_defaults(func=_list)
30
+
23
31
  args = parser.parse_args(argv)
24
32
  args.func(args)
@@ -0,0 +1,92 @@
1
+ """Template discovery and loading."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from importlib.resources import files
7
+ from pathlib import Path
8
+
9
+ PACKAGE_NAME = "modelit"
10
+ TEMPLATES_DIR = "templates"
11
+
12
+ @dataclass(frozen=True)
13
+ class TemplateInfo:
14
+ name: str
15
+
16
+ def _templates_root():
17
+ return files(PACKAGE_NAME).joinpath(TEMPLATES_DIR)
18
+
19
+ def _template_dir(name: str):
20
+ return _templates_root().joinpath(name)
21
+
22
+ def available_models() -> tuple[str, ...]:
23
+ root = _templates_root()
24
+ if not root.is_dir():
25
+ return ()
26
+
27
+ names: list[str] = []
28
+ for child in root.iterdir():
29
+ if child.is_dir() and not child.name.startswith("__"):
30
+ if any(child.iterdir()):
31
+ names.append(child.name)
32
+ return tuple(sorted(names))
33
+
34
+ def load_metadata(name: str) -> TemplateInfo:
35
+ return TemplateInfo(name=name)
36
+
37
+ def load_template_files(name: str) -> dict[str, str]:
38
+ target_dir = _template_dir(name)
39
+ if not target_dir.is_dir():
40
+ raise FileNotFoundError(f"Missing template directory for {name!r}")
41
+
42
+ file_contents = {}
43
+ for child in target_dir.iterdir():
44
+ if child.is_file() and not child.name.startswith("__"):
45
+ file_contents[child.name] = child.read_text(encoding="utf-8")
46
+ return file_contents
47
+
48
+ def build_template_callable(name: str):
49
+ files_dict = load_template_files(name)
50
+ info = load_metadata(name)
51
+
52
+ is_single_file = len(files_dict) == 1
53
+ default_filename = list(files_dict.keys())[0] if is_single_file else name
54
+
55
+ def runner(output: str | None = None) -> None:
56
+ if output:
57
+ out_path = Path(output)
58
+
59
+ if is_single_file:
60
+ if out_path.exists():
61
+ raise FileExistsError(f"Output path already exists: {out_path}")
62
+ out_path.parent.mkdir(parents=True, exist_ok=True)
63
+ out_path.write_text(list(files_dict.values())[0], encoding="utf-8")
64
+ print(f"Generated {out_path}")
65
+ else:
66
+ if out_path.suffix:
67
+ raise ValueError("Multi-file templates require a directory output path")
68
+ if out_path.exists() and not out_path.is_dir():
69
+ raise FileExistsError(f"Output path already exists and is not a directory: {out_path}")
70
+ out_path.mkdir(parents=True, exist_ok=True)
71
+ for fname, content in files_dict.items():
72
+ file_path = out_path / fname
73
+ if file_path.exists():
74
+ print(f"Skipping {file_path} (already exists)")
75
+ continue
76
+ file_path.write_text(content, encoding="utf-8")
77
+ print(f"Generated {file_path}")
78
+ return None
79
+
80
+ # If no output is specified, print to terminal
81
+ for fname, content in files_dict.items():
82
+ if not is_single_file:
83
+ print(f"\n{'='*40}\nFile: {fname}\n{'='*40}")
84
+ print(content, end="\n\n" if not is_single_file else "")
85
+
86
+ runner.__name__ = name
87
+ runner.__qualname__ = name
88
+ runner.__module__ = "modelit"
89
+ runner.__doc__ = f"Print or save the {name} template."
90
+ runner.output_file = default_filename # type: ignore[attr-defined]
91
+ runner.template_info = info # type: ignore[attr-defined]
92
+ return runner
@@ -0,0 +1,42 @@
1
+ import java.io.BufferedReader;
2
+ import java.io.FileReader;
3
+ import java.io.IOException;
4
+
5
+ class InsufficientFundsException extends Exception {
6
+ public InsufficientFundsException(String message) {
7
+ super(message);
8
+ }
9
+ }
10
+
11
+ public class AdvancedExceptionDemo {
12
+
13
+ public static void processTransfer(int accountBalance, int transferAmount) throws InsufficientFundsException {
14
+ if (transferAmount > accountBalance) {
15
+ throw new InsufficientFundsException("Transfer failed: Attempted to send $" + transferAmount + " but balance is $" + accountBalance);
16
+ }
17
+ System.out.println("Transfer of $" + transferAmount + " completed successfully.");
18
+ }
19
+
20
+ public static void main(String[] args) {
21
+ try {
22
+ processTransfer(5000, 7500);
23
+ } catch (InsufficientFundsException e) {
24
+ System.err.println("Business Error: " + e.getMessage());
25
+ }
26
+
27
+ System.out.println("-----------------");
28
+
29
+ String filePath = "non_existent_transactions.csv";
30
+
31
+ try (BufferedReader reader = new BufferedReader(new FileReader(filePath))) {
32
+ System.out.println(reader.readLine());
33
+ int value = Integer.parseInt("Not_A_Number");
34
+
35
+ } catch (IOException | NumberFormatException e) {
36
+ System.err.println("System Failure: Could not process file or format was corrupted.");
37
+ System.err.println("Root Cause: " + e.getClass().getSimpleName() + " -> " + e.getMessage());
38
+ } finally {
39
+ System.out.println("Cleanup executed. System remains stable.");
40
+ }
41
+ }
42
+ }
@@ -0,0 +1,19 @@
1
+ class DiscountCalculator {
2
+ public static double applyDiscount(double originalPrice, double discountPercentage) {
3
+ double discountAmount = originalPrice * (discountPercentage / 100);
4
+ double finalPrice = originalPrice - discountAmount;
5
+
6
+ assert finalPrice >= 0 : "CRITICAL LOGIC ERROR: Final price fell below zero!";
7
+
8
+ return finalPrice;
9
+ }
10
+ }
11
+
12
+ public class AssertionDemo {
13
+ public static void main(String[] args) {
14
+ System.out.println("Final Price: $" + DiscountCalculator.applyDiscount(1000, 20));
15
+
16
+ System.out.println("Testing invalid logic...");
17
+ System.out.println("Final Price: $" + DiscountCalculator.applyDiscount(500, 150));
18
+ }
19
+ }
@@ -0,0 +1,231 @@
1
+ # =========================================================
2
+ # DENOISING DIFFUSION PROBABILISTIC MODEL (DDPM)
3
+ # Learning the y = -x Manifold via Reverse Diffusion
4
+ # Built using PyTorch
5
+ # =========================================================
6
+
7
+
8
+ # =========================================================
9
+ # IMPORTS
10
+ # =========================================================
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.optim as optim
16
+ import matplotlib.pyplot as plt
17
+
18
+
19
+ # =========================================================
20
+ # DATA GENERATION FUNCTION (REAL MANIFOLD)
21
+ # =========================================================
22
+
23
+ def sample_line_data(n_samples=512):
24
+
25
+ x = torch.empty(n_samples, 1).uniform_(-4.5, 4.5)
26
+
27
+ y = -x
28
+
29
+ return torch.cat([x, y], dim=1)
30
+
31
+
32
+ # =========================================================
33
+ # DIFFUSION HYPERPARAMETERS & SCHEDULING
34
+ # =========================================================
35
+
36
+ steps = 120
37
+
38
+ betas = torch.linspace(2e-4, 0.015, steps)
39
+
40
+ alphas = 1.0 - betas
41
+
42
+ alpha_cum = torch.cumprod(alphas, dim=0)
43
+
44
+
45
+ # =========================================================
46
+ # FORWARD DIFFUSION PROCESS (NOISING)
47
+ # =========================================================
48
+
49
+ def diffuse(x0, t_idx):
50
+
51
+ noise = torch.randn_like(x0)
52
+
53
+ a_bar = alpha_cum[t_idx].view(-1, 1)
54
+
55
+ x_t = torch.sqrt(a_bar) * x0 + torch.sqrt(1 - a_bar) * noise
56
+
57
+ return x_t, noise
58
+
59
+
60
+ # =========================================================
61
+ # NOISE ESTIMATOR NETWORK (MLP WITH TIME EMBEDDING)
62
+ # =========================================================
63
+
64
+ class NoiseEstimator(nn.Module):
65
+
66
+ def __init__(self, dim=2, hidden=128):
67
+
68
+ super().__init__()
69
+
70
+ self.mlp = nn.Sequential(
71
+
72
+ nn.Linear(dim + 16, hidden),
73
+
74
+ nn.SiLU(),
75
+
76
+ nn.Linear(hidden, hidden),
77
+
78
+ nn.SiLU(),
79
+
80
+ nn.Linear(hidden, dim)
81
+
82
+ )
83
+
84
+
85
+ def time_embed(self, t):
86
+
87
+ half = 8
88
+
89
+ freqs = torch.exp(torch.linspace(0, 3, half))
90
+
91
+ t = t.float().unsqueeze(1)
92
+
93
+ emb = torch.cat(
94
+ [torch.sin(t / freqs), torch.cos(t / freqs)],
95
+ dim=1
96
+ )
97
+
98
+ return emb
99
+
100
+
101
+ def forward(self, x, t):
102
+
103
+ t_emb = self.time_embed(t)
104
+
105
+ return self.mlp(torch.cat([x, t_emb], dim=1))
106
+
107
+
108
+ # =========================================================
109
+ # INITIALIZATION & CONFIGURATION
110
+ # =========================================================
111
+
112
+ model = NoiseEstimator()
113
+
114
+ opt = optim.Adam(model.parameters(), lr=2e-3)
115
+
116
+ loss_fn = nn.L1Loss()
117
+
118
+ epochs = 1800
119
+
120
+ batch = 128
121
+
122
+
123
+ # =========================================================
124
+ # DIFFUSION MODEL TRAINING LOOP
125
+ # =========================================================
126
+
127
+ print("\n=================================================")
128
+ print("TRAINING DENOISING DIFFUSION MODEL")
129
+ print("=================================================")
130
+
131
+ for ep in range(epochs):
132
+
133
+ x0 = sample_line_data(batch)
134
+
135
+ t = torch.randint(0, steps, (batch,))
136
+
137
+ xt, eps = diffuse(x0, t)
138
+
139
+ eps_hat = model(xt, t)
140
+
141
+ loss = loss_fn(eps_hat, eps)
142
+
143
+ opt.zero_grad()
144
+
145
+ loss.backward()
146
+
147
+ opt.step()
148
+
149
+
150
+ # =========================================================
151
+ # PROGRESS LOGGING
152
+ # =========================================================
153
+
154
+ if ep % 300 == 0:
155
+
156
+ print(f"Epoch {ep:04d} | L1 Loss: {loss.item():.5f}")
157
+
158
+
159
+ print("\nDiffusion Training Complete.")
160
+
161
+
162
+ # =========================================================
163
+ # REVERSE DIFFUSION PROCESS (GENERATION)
164
+ # =========================================================
165
+
166
+ @torch.no_grad()
167
+ def generate(n=800):
168
+
169
+ x = torch.randn(n, 2)
170
+
171
+ for t in reversed(range(steps)):
172
+
173
+ t_batch = torch.full((n,), t)
174
+
175
+ eps = model(x, t_batch)
176
+
177
+ a = alphas[t]
178
+
179
+ a_bar = alpha_cum[t]
180
+
181
+ b = betas[t]
182
+
183
+ coef1 = 1 / torch.sqrt(a)
184
+
185
+ coef2 = (1 - a) / torch.sqrt(1 - a_bar)
186
+
187
+ noise = torch.randn_like(x) if t > 0 else 0
188
+
189
+ x = coef1 * (x - coef2 * eps) + torch.sqrt(b) * noise
190
+
191
+ return x
192
+
193
+
194
+ # =========================================================
195
+ # GENERATING SAMPLES FOR EVALUATION
196
+ # =========================================================
197
+
198
+ real = sample_line_data(1000)
199
+
200
+ fake = generate(1000)
201
+
202
+
203
+ # =========================================================
204
+ # VISUALIZATION
205
+ # =========================================================
206
+
207
+ plt.figure(figsize=(6, 6))
208
+
209
+ plt.scatter(
210
+ real[:, 0],
211
+ real[:, 1],
212
+ s=5,
213
+ alpha=0.4,
214
+ label="Real data"
215
+ )
216
+
217
+ plt.scatter(
218
+ fake[:, 0],
219
+ fake[:, 1],
220
+ s=5,
221
+ alpha=0.4,
222
+ label="Generated"
223
+ )
224
+
225
+ plt.legend()
226
+
227
+ plt.title("Diffusion learns y = -x manifold")
228
+
229
+ plt.axis("equal")
230
+
231
+ plt.show()