@defai.digital/automatosx 5.6.10 → 5.6.12
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +30 -0
- package/README.md +2 -1
- package/dist/index.js +19 -2
- package/examples/AGENTS_INFO.md +2 -2
- package/examples/claude/commands/ax-agent.md +6 -6
- package/examples/claude/commands/ax-clear.md +4 -4
- package/examples/claude/commands/ax-init.md +3 -3
- package/examples/claude/commands/ax-list.md +2 -2
- package/examples/claude/commands/ax-memory.md +7 -7
- package/examples/claude/commands/ax-status.md +3 -3
- package/examples/claude/commands/ax-update.md +4 -4
- package/examples/pytorch_resnet50_training.py +289 -0
- package/package.json +1 -1
package/CHANGELOG.md
CHANGED
|
@@ -2,6 +2,36 @@
|
|
|
2
2
|
|
|
3
3
|
All notable changes to this project will be documented in this file. See [standard-version](https://github.com/conventional-changelog/standard-version) for commit guidelines.
|
|
4
4
|
|
|
5
|
+
## [5.6.12](https://github.com/defai-digital/automatosx/compare/v5.6.11...v5.6.12) (2025-10-21)
|
|
6
|
+
|
|
7
|
+
### Bug Fixes
|
|
8
|
+
|
|
9
|
+
* **docs:** correct slash command format from `/ax:` to `/ax-` across all files ([#3](https://github.com/defai-digital/automatosx/issues/3)) ([420b17f](https://github.com/defai-digital/automatosx/commit/420b17f))
|
|
10
|
+
- **Problem**: Command files use dash format (`ax-agent.md`) but documentation incorrectly used colon format (`/ax:agent`)
|
|
11
|
+
- **Solution**: Corrected all slash commands to use consistent dash format
|
|
12
|
+
- **Files Changed**: 20 files, 40+ occurrences corrected
|
|
13
|
+
- **P0**: 7 command definition files (`.claude/commands/*.md`)
|
|
14
|
+
- **P1**: 1 source code file (`src/cli/commands/init.ts`)
|
|
15
|
+
- **P2**: 12 documentation files (README, BEST-PRACTICES, examples, etc.)
|
|
16
|
+
- **Impact**: Users now see correct slash command format (`/ax-agent`, `/ax-status`, `/ax-memory`, etc.) when running `ax init` and reading documentation
|
|
17
|
+
|
|
18
|
+
### Documentation
|
|
19
|
+
|
|
20
|
+
* **docs:** update CLAUDE.md for v5.6.11 - 24 agents, workspace conventions ([e8911f7](https://github.com/defai-digital/automatosx/commit/e8911f7))
|
|
21
|
+
- Updated version from v5.6.9 to v5.6.11
|
|
22
|
+
- Updated agent count: 19 → 24 agents
|
|
23
|
+
- Added Phase 2 agent expansion details (Fiona, Ivy)
|
|
24
|
+
- Consolidated v5.6.9-5.6.11 release notes
|
|
25
|
+
- Added comprehensive workspace path conventions section
|
|
26
|
+
- Documented `/tmp` vs `/automatosx/tmp` distinction
|
|
27
|
+
- Fixed markdown lint warnings
|
|
28
|
+
|
|
29
|
+
## [5.6.11](https://github.com/defai-digital/automatosx/compare/v5.6.10...v5.6.11) (2025-10-20)
|
|
30
|
+
|
|
31
|
+
### Version Note
|
|
32
|
+
|
|
33
|
+
Version 5.6.11 was a version bump to reflect Phase 2 agent expansion completion. All changes are documented in v5.6.10 release notes below.
|
|
34
|
+
|
|
5
35
|
## [5.6.10](https://github.com/defai-digital/automatosx/compare/v5.6.9...v5.6.10) (2025-10-20)
|
|
6
36
|
|
|
7
37
|
|
package/README.md
CHANGED
|
@@ -171,7 +171,7 @@ Use natural language to collaborate with agents directly within your editor. Cla
|
|
|
171
171
|
"please work with ax agent to refactor this module with best practices"
|
|
172
172
|
```
|
|
173
173
|
|
|
174
|
-
For simple, direct tasks, use slash commands: `/ax:agent backend, write a function to validate emails`.
|
|
174
|
+
For simple, direct tasks, use slash commands: `/ax-agent backend, write a function to validate emails`.
|
|
175
175
|
|
|
176
176
|
### 2. Terminal/CLI Mode (For Power Users)
|
|
177
177
|
|
|
@@ -208,6 +208,7 @@ ax run quality "Write tests for the API" # Auto-receives design + implementation
|
|
|
208
208
|
- **[Core Concepts](docs/guide/core-concepts.md)**
|
|
209
209
|
- **[Full CLI Command Reference](docs/reference/cli-commands.md)**
|
|
210
210
|
- **[Agent Directory](examples/AGENTS_INFO.md)**
|
|
211
|
+
- **[Workspace Path Conventions](docs/workspace-conventions.md)**
|
|
211
212
|
- **[Troubleshooting Guide](TROUBLESHOOTING.md)**
|
|
212
213
|
|
|
213
214
|
---
|
package/dist/index.js
CHANGED
|
@@ -5435,8 +5435,8 @@ var initCommand = {
|
|
|
5435
5435
|
console.log(chalk27.gray(" \u2022 cto - Technology strategist"));
|
|
5436
5436
|
console.log(chalk27.gray(" \u2022 researcher - Research analyst\n"));
|
|
5437
5437
|
console.log(chalk27.cyan("Claude Code Integration:"));
|
|
5438
|
-
console.log(chalk27.gray(" \u2022 Use /ax command in Claude Code"));
|
|
5439
|
-
console.log(chalk27.gray(
|
|
5438
|
+
console.log(chalk27.gray(" \u2022 Use /ax-agent command in Claude Code"));
|
|
5439
|
+
console.log(chalk27.gray(" \u2022 Example: /ax-agent backend, create a REST API"));
|
|
5440
5440
|
console.log(chalk27.gray(" \u2022 MCP tools available in .claude/mcp/\n"));
|
|
5441
5441
|
logger.info("AutomatosX initialized", {
|
|
5442
5442
|
projectDir,
|
|
@@ -11630,6 +11630,23 @@ Note: You can delegate to ANY agent by name, not just those listed above.
|
|
|
11630
11630
|
`;
|
|
11631
11631
|
prompt += `- These are recommendations for organization, not access restrictions
|
|
11632
11632
|
|
|
11633
|
+
`;
|
|
11634
|
+
prompt += `**File Writing Guidelines:**
|
|
11635
|
+
`;
|
|
11636
|
+
prompt += `- ALWAYS use WorkspaceManager for workspace file operations:
|
|
11637
|
+
`;
|
|
11638
|
+
prompt += ` \`\`\`typescript
|
|
11639
|
+
`;
|
|
11640
|
+
prompt += ` await workspaceManager.writeTmp('report.md', content);
|
|
11641
|
+
`;
|
|
11642
|
+
prompt += ` \`\`\`
|
|
11643
|
+
`;
|
|
11644
|
+
prompt += `- DO NOT use relative paths like 'tmp/' or './tmp/' for workspace files
|
|
11645
|
+
`;
|
|
11646
|
+
prompt += `- DO NOT use direct fs.writeFile() for workspace files (automatosx/tmp/, automatosx/PRD/)
|
|
11647
|
+
`;
|
|
11648
|
+
prompt += `- For project files outside workspace (src/, tests/, docs/), use normal file I/O
|
|
11649
|
+
|
|
11633
11650
|
`;
|
|
11634
11651
|
prompt += `**When to ask user permission:**
|
|
11635
11652
|
`;
|
package/examples/AGENTS_INFO.md
CHANGED
|
@@ -78,8 +78,8 @@ Instead of directly commanding agents with slash commands, **let Claude Code coo
|
|
|
78
78
|
### Express Method: Slash Commands (for simple tasks only)
|
|
79
79
|
|
|
80
80
|
```
|
|
81
|
-
⚡ EXPRESS: /ax:agent backend, write an email validation function
|
|
82
|
-
⚡ EXPRESS: /ax:agent quality, review this code snippet
|
|
81
|
+
⚡ EXPRESS: /ax-agent backend, write an email validation function
|
|
82
|
+
⚡ EXPRESS: /ax-agent quality, review this code snippet
|
|
83
83
|
```
|
|
84
84
|
|
|
85
85
|
**Use slash commands only when**:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Execute an AutomatosX agent with a specific task.
|
|
2
2
|
|
|
3
|
-
**IMPORTANT**: When user types `/ax:agent <agent>, <task>`, you MUST:
|
|
3
|
+
**IMPORTANT**: When user types `/ax-agent <agent>, <task>`, you MUST:
|
|
4
4
|
|
|
5
5
|
1. Split the input on the FIRST comma
|
|
6
6
|
2. Extract agent name (text before comma, trimmed)
|
|
@@ -10,7 +10,7 @@ Execute an AutomatosX agent with a specific task.
|
|
|
10
10
|
**Parsing Rules**:
|
|
11
11
|
|
|
12
12
|
```
|
|
13
|
-
Input: /ax:agent backend, explain quantum computing
|
|
13
|
+
Input: /ax-agent backend, explain quantum computing
|
|
14
14
|
↓
|
|
15
15
|
Agent: "backend"
|
|
16
16
|
Task: "explain quantum computing"
|
|
@@ -20,16 +20,16 @@ Execute: automatosx run backend "explain quantum computing"
|
|
|
20
20
|
|
|
21
21
|
**Examples**:
|
|
22
22
|
|
|
23
|
-
User input: `/ax:agent bob, i want you help me write a validation function`
|
|
23
|
+
User input: `/ax-agent bob, i want you help me write a validation function`
|
|
24
24
|
→ Execute: `automatosx run bob "i want you help me write a validation function"`
|
|
25
25
|
|
|
26
|
-
User input: `/ax:agent backend, explain quantum computing to me`
|
|
26
|
+
User input: `/ax-agent backend, explain quantum computing to me`
|
|
27
27
|
→ Execute: `automatosx run backend "explain quantum computing to me"`
|
|
28
28
|
|
|
29
|
-
User input: `/ax:agent backend, create a REST API for user management`
|
|
29
|
+
User input: `/ax-agent backend, create a REST API for user management`
|
|
30
30
|
→ Execute: `automatosx run backend "create a REST API for user management"`
|
|
31
31
|
|
|
32
|
-
User input: `/ax:agent quality, review the changes in src/auth.ts and suggest improvements`
|
|
32
|
+
User input: `/ax-agent quality, review the changes in src/auth.ts and suggest improvements`
|
|
33
33
|
→ Execute: `automatosx run quality "review the changes in src/auth.ts and suggest improvements"`
|
|
34
34
|
|
|
35
35
|
**Available built-in agents**: backend, frontend, devops, data, security, quality, design, writer, product, ceo, cto, researcher
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Clear AutomatosX memory.
|
|
2
2
|
|
|
3
|
-
**IMPORTANT**: When user types `/ax:clear`, you MUST execute:
|
|
3
|
+
**IMPORTANT**: When user types `/ax-clear`, you MUST execute:
|
|
4
4
|
|
|
5
5
|
```bash
|
|
6
6
|
automatosx memory clear --confirm
|
|
@@ -10,13 +10,13 @@ This will delete all stored memories from the AutomatosX memory database.
|
|
|
10
10
|
|
|
11
11
|
**Examples**:
|
|
12
12
|
|
|
13
|
-
User input: `/ax:clear`
|
|
13
|
+
User input: `/ax-clear`
|
|
14
14
|
→ Execute: `automatosx memory clear --confirm`
|
|
15
15
|
|
|
16
|
-
User input: `/ax:clear --type task`
|
|
16
|
+
User input: `/ax-clear --type task`
|
|
17
17
|
→ Execute: `automatosx memory clear --confirm --type task`
|
|
18
18
|
|
|
19
|
-
User input: `/ax:clear --older-than 30`
|
|
19
|
+
User input: `/ax-clear --older-than 30`
|
|
20
20
|
→ Execute: `automatosx memory clear --confirm --older-than 30`
|
|
21
21
|
|
|
22
22
|
⚠️ **Warning**: This action cannot be undone. Consider backing up first.
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Initialize AutomatosX in the current project directory.
|
|
2
2
|
|
|
3
|
-
**IMPORTANT**: When user types `/ax:init`, you MUST execute:
|
|
3
|
+
**IMPORTANT**: When user types `/ax-init`, you MUST execute:
|
|
4
4
|
|
|
5
5
|
```bash
|
|
6
6
|
automatosx init
|
|
@@ -17,9 +17,9 @@ This will:
|
|
|
17
17
|
|
|
18
18
|
**Examples**:
|
|
19
19
|
|
|
20
|
-
User input: `/ax:init`
|
|
20
|
+
User input: `/ax-init`
|
|
21
21
|
→ Execute: `automatosx init`
|
|
22
22
|
|
|
23
|
-
User input: `/ax:init --force`
|
|
23
|
+
User input: `/ax-init --force`
|
|
24
24
|
→ Execute: `automatosx init --force`
|
|
25
25
|
(Use `--force` to reinitialize if `.automatosx` already exists)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
List available AutomatosX agents.
|
|
2
2
|
|
|
3
|
-
**IMPORTANT**: When user types `/ax:list`, you MUST execute:
|
|
3
|
+
**IMPORTANT**: When user types `/ax-list`, you MUST execute:
|
|
4
4
|
|
|
5
5
|
```bash
|
|
6
6
|
automatosx list agents
|
|
@@ -10,7 +10,7 @@ This will display all available agents with their display names, roles, and desc
|
|
|
10
10
|
|
|
11
11
|
**Example**:
|
|
12
12
|
|
|
13
|
-
User input: `/ax:list`
|
|
13
|
+
User input: `/ax-list`
|
|
14
14
|
→ Execute: `automatosx list agents`
|
|
15
15
|
|
|
16
16
|
Expected output shows all agents including:
|
|
@@ -1,25 +1,25 @@
|
|
|
1
1
|
Search AutomatosX memory for relevant information.
|
|
2
2
|
|
|
3
|
-
**IMPORTANT**: When user types `/ax:memory <query>`, you MUST:
|
|
3
|
+
**IMPORTANT**: When user types `/ax-memory <query>`, you MUST:
|
|
4
4
|
|
|
5
|
-
1. Take everything after `/ax:memory` as the search query
|
|
5
|
+
1. Take everything after `/ax-memory` as the search query
|
|
6
6
|
2. Execute: `automatosx memory search "<query>"`
|
|
7
7
|
|
|
8
8
|
**Examples**:
|
|
9
9
|
|
|
10
|
-
User input: `/ax:memory authentication`
|
|
10
|
+
User input: `/ax-memory authentication`
|
|
11
11
|
→ Execute: `automatosx memory search "authentication"`
|
|
12
12
|
|
|
13
|
-
User input: `/ax:memory how to setup database`
|
|
13
|
+
User input: `/ax-memory how to setup database`
|
|
14
14
|
→ Execute: `automatosx memory search "how to setup database"`
|
|
15
15
|
|
|
16
|
-
User input: `/ax:memory API errors and solutions`
|
|
16
|
+
User input: `/ax-memory API errors and solutions`
|
|
17
17
|
→ Execute: `automatosx memory search "API errors and solutions"`
|
|
18
18
|
|
|
19
19
|
**With Options**:
|
|
20
20
|
|
|
21
|
-
User input: `/ax:memory authentication --limit 5`
|
|
21
|
+
User input: `/ax-memory authentication --limit 5`
|
|
22
22
|
→ Execute: `automatosx memory search "authentication" --limit 5`
|
|
23
23
|
|
|
24
|
-
User input: `/ax:memory database --type task`
|
|
24
|
+
User input: `/ax-memory database --type task`
|
|
25
25
|
→ Execute: `automatosx memory search "database" --type task`
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Display AutomatosX system status and configuration.
|
|
2
2
|
|
|
3
|
-
**IMPORTANT**: When user types `/ax:status`, you MUST execute:
|
|
3
|
+
**IMPORTANT**: When user types `/ax-status`, you MUST execute:
|
|
4
4
|
|
|
5
5
|
```bash
|
|
6
6
|
automatosx status
|
|
@@ -16,9 +16,9 @@ This will show:
|
|
|
16
16
|
|
|
17
17
|
**Examples**:
|
|
18
18
|
|
|
19
|
-
User input: `/ax:status`
|
|
19
|
+
User input: `/ax-status`
|
|
20
20
|
→ Execute: `automatosx status`
|
|
21
21
|
|
|
22
|
-
User input: `/ax:status --verbose`
|
|
22
|
+
User input: `/ax-status --verbose`
|
|
23
23
|
→ Execute: `automatosx status --verbose`
|
|
24
24
|
(Shows detailed configuration and environment information)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Update AutomatosX to the latest version.
|
|
2
2
|
|
|
3
|
-
**IMPORTANT**: When user types `/ax:update`, you MUST execute:
|
|
3
|
+
**IMPORTANT**: When user types `/ax-update`, you MUST execute:
|
|
4
4
|
|
|
5
5
|
```bash
|
|
6
6
|
automatosx update
|
|
@@ -15,14 +15,14 @@ This will:
|
|
|
15
15
|
|
|
16
16
|
**Examples**:
|
|
17
17
|
|
|
18
|
-
User input: `/ax:update`
|
|
18
|
+
User input: `/ax-update`
|
|
19
19
|
→ Execute: `automatosx update`
|
|
20
20
|
(Interactive update with confirmation)
|
|
21
21
|
|
|
22
|
-
User input: `/ax:update --check`
|
|
22
|
+
User input: `/ax-update --check`
|
|
23
23
|
→ Execute: `automatosx update --check`
|
|
24
24
|
(Only check for updates without installing)
|
|
25
25
|
|
|
26
|
-
User input: `/ax:update --yes`
|
|
26
|
+
User input: `/ax-update --yes`
|
|
27
27
|
→ Execute: `automatosx update --yes`
|
|
28
28
|
(Update without confirmation prompt)
|
|
@@ -0,0 +1,289 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Mixed precision ResNet-50 training loop for image classification.
|
|
3
|
+
|
|
4
|
+
Author: Mira — "From architecture to inference - I build models that ship."
|
|
5
|
+
|
|
6
|
+
This script focuses on the core PyTorch 2.x training primitives:
|
|
7
|
+
- ImageFolder-based DataLoader with realistic augmentations
|
|
8
|
+
- Transfer learning from torchvision's pretrained ResNet-50
|
|
9
|
+
- Mixed precision training via torch.amp.autocast + GradScaler
|
|
10
|
+
- torch.compile() to squeeze extra throughput from the forward pass
|
|
11
|
+
|
|
12
|
+
Expected directory structure for the dataset:
|
|
13
|
+
data_root/
|
|
14
|
+
train/
|
|
15
|
+
class_a/*.jpg
|
|
16
|
+
class_b/*.jpg
|
|
17
|
+
val/
|
|
18
|
+
class_a/*.jpg
|
|
19
|
+
class_b/*.jpg
|
|
20
|
+
|
|
21
|
+
Run:
|
|
22
|
+
python examples/pytorch_resnet50_training.py --data-root /path/to/data
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
from __future__ import annotations
|
|
26
|
+
|
|
27
|
+
import argparse
|
|
28
|
+
import random
|
|
29
|
+
import time
|
|
30
|
+
from dataclasses import dataclass
|
|
31
|
+
from pathlib import Path
|
|
32
|
+
from typing import Tuple
|
|
33
|
+
|
|
34
|
+
import torch
|
|
35
|
+
from torch import nn
|
|
36
|
+
from torch.amp import GradScaler, autocast
|
|
37
|
+
from torch.optim import Optimizer
|
|
38
|
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
|
39
|
+
from torch.utils.data import DataLoader
|
|
40
|
+
from torchvision import datasets, transforms
|
|
41
|
+
from torchvision.models import ResNet50_Weights, resnet50
|
|
42
|
+
from tqdm import tqdm
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class TrainConfig:
    """Hyperparameters and paths for a single training run.

    ``parse_args`` builds one of these from command-line flags; ``train_model``
    consumes it.
    """

    # Dataset root; must contain ``train/`` and ``val/`` ImageFolder trees.
    data_root: Path
    # Samples per batch, used for both the train and val DataLoaders.
    batch_size: int = 64
    # DataLoader worker processes; workers are kept persistent when > 0.
    num_workers: int = 8
    # Number of passes over the training set (also the cosine schedule's T_max).
    epochs: int = 20
    # Initial AdamW learning rate.
    learning_rate: float = 5e-4
    # AdamW weight decay.
    weight_decay: float = 0.01
    # Seed handed to ``set_seed`` (Python ``random`` and torch RNGs).
    seed: int = 17
    # Directory that receives a checkpoint each time val accuracy improves.
    output_dir: Path = Path("artifacts/checkpoints")
    # Print a training log line every ``log_every`` steps.
    log_every: int = 25
    # Optional classifier width override; ``None`` means use the number of
    # class folders found under ``data_root/train``.
    num_classes: int | None = None
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def set_seed(seed: int) -> None:
    """Seed Python's and torch's RNGs while keeping cuDNN tuned for speed.

    cuDNN autotuning is left on and deterministic kernels off, so identically
    seeded GPU runs are not guaranteed to be bit-for-bit reproducible.
    """
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def build_dataloaders(cfg: TrainConfig) -> Tuple[DataLoader, DataLoader, int]:
    """Construct train/val DataLoaders and infer the class count.

    Args:
        cfg: run configuration; ``data_root``, ``batch_size`` and
            ``num_workers`` are read.

    Returns:
        ``(train_loader, val_loader, num_classes)`` where ``num_classes`` is the
        number of class folders under ``data_root/train``.

    Raises:
        FileNotFoundError: if ``data_root/train`` or ``data_root/val`` is missing.
        ValueError: if the val class folders differ from the train ones, or if
            the training split holds fewer images than one batch (with
            ``drop_last=True`` that would yield zero training steps).
    """
    train_dir = cfg.data_root / "train"
    val_dir = cfg.data_root / "val"

    if not train_dir.is_dir() or not val_dir.is_dir():
        raise FileNotFoundError(
            f"Expecting 'train' and 'val' subdirectories under {cfg.data_root}"
        )

    # torchvision weight enums do not store mean/std in ``meta`` (indexing it
    # raises KeyError); the values live on the preset transform instead.
    preset = ResNet50_Weights.IMAGENET1K_V2.transforms()
    normalize = transforms.Normalize(mean=preset.mean, std=preset.std)

    train_tfms = transforms.Compose(
        [
            transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
            transforms.RandomApply(
                [transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.3
            ),
            transforms.ToTensor(),
            normalize,
        ]
    )

    eval_tfms = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]
    )

    train_dataset = datasets.ImageFolder(train_dir, transform=train_tfms)
    val_dataset = datasets.ImageFolder(val_dir, transform=eval_tfms)

    # ImageFolder derives label indices from each split's own sorted folder
    # names, so a missing or extra val folder would silently shift labels.
    if val_dataset.class_to_idx != train_dataset.class_to_idx:
        raise ValueError(
            f"Class folders differ between {train_dir} and {val_dir}: "
            f"{train_dataset.classes} vs {val_dataset.classes}"
        )
    num_classes = len(train_dataset.classes)

    if len(train_dataset) < cfg.batch_size:
        raise ValueError(
            f"Training split has {len(train_dataset)} images, fewer than "
            f"batch_size={cfg.batch_size}; drop_last=True would yield no batches"
        )

    loader_kwargs = dict(
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        # Pinned host memory only speeds up host-to-GPU copies; without CUDA
        # PyTorch merely warns about it.
        pin_memory=torch.cuda.is_available(),
        persistent_workers=cfg.num_workers > 0,
    )

    train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, **loader_kwargs)

    return train_loader, val_loader, num_classes
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def build_model(num_classes: int) -> nn.Module:
    """Return an ImageNet-pretrained ResNet-50 with a new ``num_classes``-way head.

    The stock ``fc`` layer is replaced by dropout (p=0.2) followed by a linear
    layer sized to the backbone's feature width.
    """
    model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
    feature_dim = model.fc.in_features
    head = nn.Sequential(nn.Dropout(p=0.2), nn.Linear(feature_dim, num_classes))
    model.fc = head
    return model
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def accuracy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Fraction of samples whose arg-max over dim 1 equals ``target``.

    Returns a 0-dim float tensor. Assumes ``output`` holds per-class scores
    along dim 1 and ``target`` one class index per sample (not validated).
    """
    hits = output.argmax(dim=1).eq(target)
    return hits.float().mean()
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: Optimizer,
    scaler: GradScaler,
    device: torch.device,
    epoch: int,
    cfg: TrainConfig,
) -> Tuple[float, float]:
    """Run one mixed-precision training pass over ``loader``.

    Each step runs the forward pass under autocast (float16 on CUDA, bfloat16
    otherwise), backpropagates the scaled loss, unscales and clips gradients
    to a global norm of 1.0, then lets ``scaler`` drive the optimizer step.
    A log line is printed every ``cfg.log_every`` steps.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over batches.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    amp_dtype = torch.float16 if device.type == "cuda" else torch.bfloat16
    running_loss = 0.0
    running_acc = 0.0
    steps = 0

    for step, (images, labels) in enumerate(tqdm(loader, desc=f"Epoch {epoch} [train]")):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type=device.type, dtype=amp_dtype):
            logits = model(images)
            loss = nn.functional.cross_entropy(logits, labels)

        scaler.scale(loss).backward()
        # Unscale first so the clipping threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        # Read the loss once per step (each .item() is a device sync).
        batch_loss = loss.item()
        with torch.no_grad():
            running_acc += accuracy(logits, labels).item()
        running_loss += batch_loss
        steps += 1

        if step % cfg.log_every == 0:
            current_lr = optimizer.param_groups[0]["lr"]
            print(
                f"Epoch {epoch} | step {step:04d} | lr {current_lr:.2e} | "
                f"loss {batch_loss:.4f}"
            )

    if steps == 0:
        raise ValueError(
            "Training loader yielded no batches; the dataset may be smaller "
            "than one batch while drop_last=True"
        )
    return running_loss / steps, running_acc / steps
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
) -> Tuple[float, float]:
    """Compute mean cross-entropy loss and top-1 accuracy over ``loader``.

    Both metrics are averaged per sample, not per batch. This keeps a smaller
    final batch (e.g. from a loader built with ``drop_last=False``) from being
    over-weighted. No autocast is applied, so the model runs in its default
    precision.

    Returns:
        ``(mean_loss, accuracy)`` over every sample yielded by ``loader``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0

    for images, labels in tqdm(loader, desc="Validation"):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        logits = model(images)
        # Summed (not averaged) loss so the division below is an exact per-sample mean.
        loss_sum += nn.functional.cross_entropy(logits, labels, reduction="sum").item()
        correct += (logits.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)

    if seen == 0:
        raise ValueError("Validation loader yielded no samples")
    return loss_sum / seen, correct / seen
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def save_checkpoint(
    model: nn.Module,
    optimizer: Optimizer,
    epoch: int,
    cfg: TrainConfig,
    metric: float,
) -> None:
    """Write model and optimizer state to ``cfg.output_dir``.

    The file is named ``resnet50_epoch{epoch:03d}_acc{metric:.3f}.pt``. It holds
    the keys ``epoch``, ``state_dict``, ``optimizer_state`` and ``val_top1``
    (the given ``metric``). ``cfg.output_dir`` is created if missing.

    A ``torch.compile``-wrapped model prefixes every state-dict key with
    ``_orig_mod.``. The wrapper is therefore unwrapped first, so the saved
    weights load directly into a plain, uncompiled model.
    """
    cfg.output_dir.mkdir(parents=True, exist_ok=True)
    # The compiled wrapper (OptimizedModule) exposes the original as ``_orig_mod``.
    raw_model = getattr(model, "_orig_mod", model)
    ckpt = {
        "epoch": epoch,
        "state_dict": raw_model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "val_top1": metric,
    }
    torch.save(ckpt, cfg.output_dir / f"resnet50_epoch{epoch:03d}_acc{metric:.3f}.pt")
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def train_model(cfg: TrainConfig) -> None:
    """Run the full fine-tuning loop described by ``cfg``.

    The function builds the data loaders and a ``torch.compile``d ResNet-50.
    It trains with AdamW and a per-epoch cosine-annealed learning rate,
    evaluates after every epoch, and saves a checkpoint whenever validation
    accuracy improves on the best seen so far.
    """
    set_seed(cfg.seed)

    train_loader, val_loader, inferred_classes = build_dataloaders(cfg)
    num_classes = cfg.num_classes or inferred_classes

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    torch.set_float32_matmul_precision("high")
    model = build_model(num_classes).to(device)
    model = torch.compile(model)  # PyTorch 2.x graph capture for extra throughput

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay
    )
    scheduler = CosineAnnealingLR(optimizer, T_max=cfg.epochs)
    # Loss scaling is only needed for float16 autocast (used on CUDA). The CPU
    # path autocasts to bfloat16, so the scaler is disabled and becomes a
    # pass-through that simply calls optimizer.step().
    scaler = GradScaler(device.type, enabled=device.type == "cuda")

    best_acc = 0.0

    for epoch in range(1, cfg.epochs + 1):
        epoch_start = time.time()

        train_loss, train_acc = train_one_epoch(
            model, train_loader, optimizer, scaler, device, epoch, cfg
        )
        val_loss, val_acc = evaluate(model, val_loader, device)
        scheduler.step()

        elapsed = time.time() - epoch_start
        print(
            f"Epoch {epoch:02d} finished in {elapsed:.1f}s | "
            f"train loss {train_loss:.4f}, train acc {train_acc*100:.2f}% | "
            f"val loss {val_loss:.4f}, val acc {val_acc*100:.2f}%"
        )

        if val_acc > best_acc:
            best_acc = val_acc
            save_checkpoint(model, optimizer, epoch, cfg, metric=val_acc)
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def parse_args() -> TrainConfig:
    """Parse command-line flags into a ``TrainConfig``.

    Flag defaults are read from ``TrainConfig``'s own field defaults, so the
    CLI and the dataclass cannot drift apart. ``--data-root`` is required.
    """
    parser = argparse.ArgumentParser(description="ResNet-50 mixed precision training")
    parser.add_argument("--data-root", type=Path, required=True, help="Dataset root path")
    parser.add_argument("--epochs", type=int, default=TrainConfig.epochs)
    parser.add_argument("--batch-size", type=int, default=TrainConfig.batch_size)
    parser.add_argument("--num-workers", type=int, default=TrainConfig.num_workers)
    parser.add_argument("--lr", type=float, default=TrainConfig.learning_rate)
    parser.add_argument("--weight-decay", type=float, default=TrainConfig.weight_decay)
    parser.add_argument("--seed", type=int, default=TrainConfig.seed)
    parser.add_argument("--output-dir", type=Path, default=TrainConfig.output_dir)
    parser.add_argument(
        "--log-every",
        type=int,
        default=TrainConfig.log_every,
        help="Print a training log line every N steps",
    )
    args = parser.parse_args()

    return TrainConfig(
        data_root=args.data_root,
        epochs=args.epochs,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        learning_rate=args.lr,
        weight_decay=args.weight_decay,
        seed=args.seed,
        output_dir=args.output_dir,
        log_every=args.log_every,
    )
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
if __name__ == "__main__":
    # Script entry point: turn CLI flags into a TrainConfig and train.
    train_model(parse_args())
|