gradia-1.0.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gradia-1.0.0/LICENSE +21 -0
- gradia-1.0.0/MANIFEST.in +10 -0
- gradia-1.0.0/PKG-INFO +143 -0
- gradia-1.0.0/README.md +92 -0
- gradia-1.0.0/gradia/__init__.py +1 -0
- gradia-1.0.0/gradia/cli/__init__.py +0 -0
- gradia-1.0.0/gradia/cli/main.py +91 -0
- gradia-1.0.0/gradia/core/config.py +56 -0
- gradia-1.0.0/gradia/core/inspector.py +37 -0
- gradia-1.0.0/gradia/core/scenario.py +118 -0
- gradia-1.0.0/gradia/models/base.py +39 -0
- gradia-1.0.0/gradia/models/sklearn_wrappers.py +114 -0
- gradia-1.0.0/gradia/trainer/callbacks.py +48 -0
- gradia-1.0.0/gradia/trainer/engine.py +203 -0
- gradia-1.0.0/gradia/viz/assets/logo.png +0 -0
- gradia-1.0.0/gradia/viz/server.py +228 -0
- gradia-1.0.0/gradia/viz/static/css/style.css +312 -0
- gradia-1.0.0/gradia/viz/static/js/app.js +348 -0
- gradia-1.0.0/gradia/viz/templates/configure.html +304 -0
- gradia-1.0.0/gradia/viz/templates/index.html +147 -0
- gradia-1.0.0/gradia.egg-info/PKG-INFO +143 -0
- gradia-1.0.0/gradia.egg-info/SOURCES.txt +27 -0
- gradia-1.0.0/gradia.egg-info/dependency_links.txt +1 -0
- gradia-1.0.0/gradia.egg-info/entry_points.txt +2 -0
- gradia-1.0.0/gradia.egg-info/requires.txt +11 -0
- gradia-1.0.0/gradia.egg-info/top_level.txt +1 -0
- gradia-1.0.0/pyproject.toml +45 -0
- gradia-1.0.0/setup.cfg +4 -0
- gradia-1.0.0/tests/test_gradia.py +134 -0
gradia-1.0.0/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 STiFLeR
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
gradia-1.0.0/MANIFEST.in
ADDED
@@ -0,0 +1,10 @@
+include README.md
+include LICENSE
+recursive-include gradia/viz/templates *
+recursive-include gradia/viz/static *
+recursive-include gradia/viz/assets *
+global-exclude .gradia
+global-exclude .gradia_logs
+global-exclude *.pyc
+global-exclude __pycache__
+global-exclude *.DS_Store
gradia-1.0.0/PKG-INFO
ADDED
@@ -0,0 +1,143 @@
+Metadata-Version: 2.4
+Name: gradia
+Version: 1.0.0
+Summary: Local-first ML training visualization and tracking dashboard.
+Author-email: STiFLeR <stifler@example.com>
+License: MIT License
+
+Copyright (c) 2025 STiFLeR
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+Project-URL: Homepage, https://github.com/STiFLeR7/gradia
+Project-URL: Bug Tracker, https://github.com/STiFLeR7/gradia/issues
+Keywords: machine-learning,dashboard,visualization,tracking,mlops
+Classifier: Programming Language :: Python :: 3
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Operating System :: OS Independent
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Classifier: Topic :: System :: Monitoring
+Requires-Python: >=3.9
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: scikit-learn
+Requires-Dist: pandas
+Requires-Dist: numpy
+Requires-Dist: fastapi
+Requires-Dist: uvicorn
+Requires-Dist: typer
+Requires-Dist: jinja2
+Requires-Dist: watchdog
+Requires-Dist: rich
+Requires-Dist: psutil
+Requires-Dist: tqdm
+Dynamic: license-file
+
+<div align="center">
+
+# G R A D I A
+
+**Next-Generation Local-First MLOps Platform**
+
+[](https://pypi.org/project/gradia/)
+[](https://python.org)
+[](LICENSE)
+[](https://github.com/STiFLeR7/gradia/actions)
+
+<p align="center">
+<img src="https://github.com/STiFLeR7/gradia/blob/master/docs/dashboard.png" alt="Gradia Dashboard" width="100%" />
+</p>
+
+</div>
+
+---
+
+## 🚀 Overview
+
+**Gradia** is a high-performance, asynchronous monitoring solution designed for local machine learning workflows. Unlike cloud-native behemoths, Gradia focuses on **zero-latency**, **privacy-first** tracking that runs directly alongside your training loop.
+
+Built on **FastAPI**, **WebSockets** (simulated via high-frequency polling), and **Reactive UI**, Gradia provides granular visibility into your model's training dynamics, system resource consumption, and feature importance without the overhead of external servers.
+
+## ⚡ Key Capabilities
+
+| Feature | Description |
+| :--- | :--- |
+| **Real-Time Telemetry** | Nanosecond-precision tracking of Loss, Accuracy, and custom metrics via async event dispatching. |
+| **Intelligent Auto-Discovery** | Heuristic analysis of tabular datasets to automatically infer task types (Regression vs Classification) and suggest optimal architectures (CNNs, RFCs). |
+| **System Profiling** | Integrated `psutil` hooks for monitoring CPU/GPU* and RAM saturation during training epochs. |
+| **Artifact Management** | Automated checkpointing (`best-ckpt`) and structured logging (`events.jsonl`) with thread-safe IO. |
+| **Comprehensive Reporting** | One-click generation of audit-ready PDF/JSON reports containing full training history and confusion matrices. |
+| **Interactive Evaluation** | Post-training validation suite featuring dynamic Heatmap visualization for Confusion Matrices. |
+
+## 🛠️ Architecture
+
+Gradia employs a **Producer-Consumer** architecture:
+
+1. **Trainer Thread (Producer)**: Executes the Scikit-Learn training loop, emitting atomic events to a thread-locked `EventLogger`.
+2. **System Thread**: Asynchronously samples hardware metrics.
+3. **Visualization Server (Consumer)**: A lightweight FastAPI instance that aggregates logs and serves a reactive Single Page Application (SPA).
+
+This decoupling ensures that monitoring never bottlenecks your training throughput.
+
+## 📦 Installation
+
+```bash
+pip install gradia --upgrade
+```
+
+## 💻 Usage
+
+### Quick Start
+Initialize the environment and start the observer in one command. Gradia will auto-detect any CSV files in the directory.
+
+```bash
+gradia run .
+```
+
+### Advanced CLI
+Override default heuristics and bind to specific interfaces.
+
+```bash
+gradia run . \
+--target "churn_label" \
+--port 8080 \
+--workers 4
+```
+
+## 📊 Dashboard
+
+Access the dashboard at `http://localhost:8000`.
+
+- **Configure**: Select your model (Random Forest, MLP, etc.) and hyperparameters.
+- **Observe**: Watch metrics stream in real-time.
+- **Analyze**: Use the built-in Feature Importance charts to debug model bias.
+
+## 🤝 Contributing
+
+We welcome contributions! Please see `CONTRIBUTING.md` for details on submitting logical PRs.
+
+## 📄 License
+
+Distributed under the MIT License. See `LICENSE` for more information.
+
+---
+
+<div align="center">
+<sub>Built with ❤️ by STiFLeR for the ML Community.</sub>
+</div>
gradia-1.0.0/README.md
ADDED
@@ -0,0 +1,92 @@
+<div align="center">
+
+# G R A D I A
+
+**Next-Generation Local-First MLOps Platform**
+
+[](https://pypi.org/project/gradia/)
+[](https://python.org)
+[](LICENSE)
+[](https://github.com/STiFLeR7/gradia/actions)
+
+<p align="center">
+<img src="https://github.com/STiFLeR7/gradia/blob/master/docs/dashboard.png" alt="Gradia Dashboard" width="100%" />
+</p>
+
+</div>
+
+---
+
+## 🚀 Overview
+
+**Gradia** is a high-performance, asynchronous monitoring solution designed for local machine learning workflows. Unlike cloud-native behemoths, Gradia focuses on **zero-latency**, **privacy-first** tracking that runs directly alongside your training loop.
+
+Built on **FastAPI**, **WebSockets** (simulated via high-frequency polling), and **Reactive UI**, Gradia provides granular visibility into your model's training dynamics, system resource consumption, and feature importance without the overhead of external servers.
+
+## ⚡ Key Capabilities
+
+| Feature | Description |
+| :--- | :--- |
+| **Real-Time Telemetry** | Nanosecond-precision tracking of Loss, Accuracy, and custom metrics via async event dispatching. |
+| **Intelligent Auto-Discovery** | Heuristic analysis of tabular datasets to automatically infer task types (Regression vs Classification) and suggest optimal architectures (CNNs, RFCs). |
+| **System Profiling** | Integrated `psutil` hooks for monitoring CPU/GPU* and RAM saturation during training epochs. |
+| **Artifact Management** | Automated checkpointing (`best-ckpt`) and structured logging (`events.jsonl`) with thread-safe IO. |
+| **Comprehensive Reporting** | One-click generation of audit-ready PDF/JSON reports containing full training history and confusion matrices. |
+| **Interactive Evaluation** | Post-training validation suite featuring dynamic Heatmap visualization for Confusion Matrices. |
+
+## 🛠️ Architecture
+
+Gradia employs a **Producer-Consumer** architecture:
+
+1. **Trainer Thread (Producer)**: Executes the Scikit-Learn training loop, emitting atomic events to a thread-locked `EventLogger`.
+2. **System Thread**: Asynchronously samples hardware metrics.
+3. **Visualization Server (Consumer)**: A lightweight FastAPI instance that aggregates logs and serves a reactive Single Page Application (SPA).
+
+This decoupling ensures that monitoring never bottlenecks your training throughput.
+
+## 📦 Installation
+
+```bash
+pip install gradia --upgrade
+```
+
+## 💻 Usage
+
+### Quick Start
+Initialize the environment and start the observer in one command. Gradia will auto-detect any CSV files in the directory.
+
+```bash
+gradia run .
+```
+
+### Advanced CLI
+Override default heuristics and bind to specific interfaces.
+
+```bash
+gradia run . \
+--target "churn_label" \
+--port 8080 \
+--workers 4
+```
+
+## 📊 Dashboard
+
+Access the dashboard at `http://localhost:8000`.
+
+- **Configure**: Select your model (Random Forest, MLP, etc.) and hyperparameters.
+- **Observe**: Watch metrics stream in real-time.
+- **Analyze**: Use the built-in Feature Importance charts to debug model bias.
+
+## 🤝 Contributing
+
+We welcome contributions! Please see `CONTRIBUTING.md` for details on submitting logical PRs.
+
+## 📄 License
+
+Distributed under the MIT License. See `LICENSE` for more information.
+
+---
+
+<div align="center">
+<sub>Built with ❤️ by STiFLeR for the ML Community.</sub>
+</div>
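
The Architecture section above describes a trainer thread appending events through a thread-locked `EventLogger` backed by `events.jsonl`, with the FastAPI server reading that log on the consumer side. The modules that actually implement this (`gradia/trainer/engine.py`, `gradia/viz/server.py`) are not reproduced in this section, so the following is only an illustrative sketch of that producer/consumer pattern under those assumptions: the class name `EventLogger` comes from the README, everything else is hypothetical and not Gradia's real code.

```python
import json
import threading
import time
from pathlib import Path


class EventLogger:
    """Illustrative thread-safe JSONL event sink (not Gradia's actual implementation)."""

    def __init__(self, run_dir: str):
        self.path = Path(run_dir) / "events.jsonl"
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self._lock = threading.Lock()

    def emit(self, event_type: str, **payload):
        record = {"ts": time.time(), "type": event_type, **payload}
        with self._lock:  # serialize writers so JSON lines never interleave
            with self.path.open("a") as f:
                f.write(json.dumps(record) + "\n")


def read_events(run_dir: str):
    """Consumer side: re-read the append-only log and return parsed events."""
    path = Path(run_dir) / "events.jsonl"
    if not path.exists():
        return []
    return [json.loads(line) for line in path.read_text().splitlines() if line.strip()]


# Producer (training loop) and consumer (dashboard poll) in one process:
logger = EventLogger(".gradia_logs/run_demo")
for epoch in range(3):
    logger.emit("metric", epoch=epoch, loss=1.0 / (epoch + 1))
print(read_events(".gradia_logs/run_demo"))
```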
gradia-1.0.0/gradia/__init__.py
ADDED
@@ -0,0 +1 @@
+__version__ = "1.0.0"
gradia-1.0.0/gradia/cli/__init__.py
File without changes
gradia-1.0.0/gradia/cli/main.py
ADDED
@@ -0,0 +1,91 @@
+import typer
+import threading
+import time
+import os
+import webbrowser
+from pathlib import Path
+from rich.console import Console
+from ..core.inspector import Inspector
+from ..core.scenario import ScenarioInferrer
+from ..core.config import ConfigManager
+from ..trainer.engine import Trainer
+from ..viz import server
+
+app = typer.Typer()
+console = Console()
+
+@app.callback()
+def callback():
+    """
+    gradia: Local-first ML training visualization.
+    """
+
+@app.command()
+def run(
+    ctx: typer.Context,
+    path: str = typer.Argument(".", help="Path to data directory"),
+    target: str = typer.Option(None, help="Manually specify target column"),
+    port: int = typer.Option(8000, help="Port for visualization server")
+):
+    """
+    Starts the gradia training and visualization session.
+    """
+    console.rule("[bold blue]gradia v1.0.0[/bold blue]")
+
+    # 1. Inspect
+    path = Path(path).resolve()
+    inspector = Inspector(path)
+    datasets = inspector.find_datasets()
+
+    if not datasets:
+        console.print(f"[red]No .csv or .parquet files found in {path}[/red]")
+        raise typer.Exit(code=1)
+
+    # Select first dataset for MVP
+    dataset = datasets[0]
+    console.print(f"[green]Found dataset:[/green] {dataset.name}")
+
+    # 2. Config & Scenario Reuse
+    run_dir = path / ".gradia_logs"
+    config_mgr = ConfigManager(run_dir)
+    config = config_mgr.load_or_create()
+
+    # We infer scenario here to pass to server, but user confirms/configures in UI
+    with console.status("Inferring scenario..."):
+        inferrer = ScenarioInferrer()
+        scenario = inferrer.infer(str(dataset), target_override=target)
+
+    console.print(f"Target: [bold]{scenario.target_column}[/bold] | Task: [bold]{scenario.task_type}[/bold]")
+    # Session Isolation: Create unique run directory
+    session_id = int(time.time())
+    run_dir = Path(".gradia_logs") / f"run_{session_id}"
+    run_dir.mkdir(parents=True, exist_ok=True)
+
+    config_mgr = ConfigManager(run_dir)
+    config = config_mgr.load_or_create()
+
+    # Apply Smart Recommendation
+    config['model']['type'] = scenario.recommended_model
+    console.print(f"[cyan]Smart Suggestion:[/cyan] Using [bold]{scenario.recommended_model}[/bold] for this dataset.")
+
+    console.print(f"[bold green]Configuration moved to Web UI[/bold green]")
+    console.print(f"Visualization running at http://localhost:{port}")
+    console.print(f"Logs: {run_dir.resolve()}")
+
+    # 3. Launch Server
+    # We inject state into the server module before starting it
+    server.SCENARIO = scenario
+    server.CONFIG_MGR = config_mgr
+    server.RUN_DIR = run_dir
+    server.DEFAULT_CONFIG = config
+
+
+    # Launch browser
+    threading.Timer(1.5, lambda: webbrowser.open(f"http://localhost:{port}/configure")).start()
+
+    # Start server (blocking main thread is fine now as we don't have a separate training thread YET)
+    # The training thread will be spawned by the server upon API request.
+    server.start_server(str(run_dir), port)
+
+if __name__ == "__main__":
+    app()
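
Because `main.py` exposes a standard Typer application, it can be exercised without starting a real session by invoking it through Typer's test runner. A minimal sketch, assuming `gradia` 1.0.0 is installed; only `--help` is requested here, since an actual `run` invocation would scan for datasets and start the visualization server.

```python
from typer.testing import CliRunner

from gradia.cli.main import app

runner = CliRunner()
# Render the help text for the `run` command without launching a server.
result = runner.invoke(app, ["run", "--help"])
assert result.exit_code == 0
print(result.output)
```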
gradia-1.0.0/gradia/core/config.py
ADDED
@@ -0,0 +1,56 @@
+import yaml
+from pathlib import Path
+from typing import Any, Dict
+
+class ConfigManager:
+    """Manages gradia configuration."""
+
+    DEFAULT_CONFIG = {
+        'model': {
+            'type': 'auto', # auto, linear, random_forest
+            'params': {}
+        },
+        'training': {
+            'test_split': 0.2,
+            'random_seed': 42,
+            'shuffle': True
+        },
+        'scenario': {
+            'target': None, # Auto-detect
+            'task': None # Auto-detect
+        }
+    }
+
+    def __init__(self, run_dir: str = ".gradia_logs"):
+        self.run_dir = Path(run_dir)
+        self.config_path = self.run_dir / "config.yaml"
+
+    def load_or_create(self, user_overrides: Dict[str, Any] = None) -> Dict[str, Any]:
+        config = self.DEFAULT_CONFIG.copy()
+
+        # Load existing if any (feature for restart, maybe not for MVP run-once)
+        # For immutable runs, we usually generate NEW config.
+        # But if gradia.yaml exists in ROOT, we load it.
+
+        root_config = Path("gradia.yaml")
+        if root_config.exists():
+            with open(root_config, 'r') as f:
+                user_config = yaml.safe_load(f)
+                self._update_recursive(config, user_config)
+
+        if user_overrides:
+            self._update_recursive(config, user_overrides)
+
+        return config
+
+    def save(self, config: Dict[str, Any]):
+        self.run_dir.mkdir(exist_ok=True)
+        with open(self.config_path, 'w') as f:
+            yaml.dump(config, f)
+
+    def _update_recursive(self, base: Dict, update: Dict):
+        for k, v in update.items():
+            if k in base and isinstance(base[k], dict) and isinstance(v, dict):
+                self._update_recursive(base[k], v)
+            else:
+                base[k] = v
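
`ConfigManager.load_or_create` layers three sources: the built-in defaults, an optional `gradia.yaml` in the current working directory, and any programmatic overrides, merged key by key with `_update_recursive`. A minimal usage sketch, assuming `gradia` 1.0.0 is installed:

```python
from gradia.core.config import ConfigManager

mgr = ConfigManager(run_dir=".gradia_logs")

# Defaults, then gradia.yaml (if present in the CWD), then these overrides.
config = mgr.load_or_create(user_overrides={"model": {"type": "random_forest"}})
print(config["model"]["type"])           # "random_forest"
print(config["training"]["test_split"])  # 0.2 from DEFAULT_CONFIG

# Persist the effective configuration to .gradia_logs/config.yaml
mgr.save(config)
```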
gradia-1.0.0/gradia/core/inspector.py
ADDED
@@ -0,0 +1,37 @@
+import os
+from pathlib import Path
+from typing import List, Optional
+
+class Inspector:
+    """Scans the working directory for potential dataset files."""
+
+    SUPPORTED_EXTENSIONS = {'.csv', '.parquet'}
+
+    def __init__(self, root_dir: str = "."):
+        self.root_dir = Path(root_dir)
+
+    def find_datasets(self) -> List[Path]:
+        """Finds all supported dataset files in the root directory."""
+        datasets = []
+        for ext in self.SUPPORTED_EXTENSIONS:
+            datasets.extend(self.root_dir.glob(f"*{ext}"))
+        return sorted(datasets)
+
+    def detect_split_layout(self):
+        """
+        Detects if proper 'train'/'val'/'test' folders exist.
+        Returns a dictionary with paths or None.
+        """
+        layout = {}
+        for split in ['train', 'val', 'validation', 'test']:
+            split_dir = self.root_dir / split
+            if split_dir.exists() and split_dir.is_dir():
+                # Check for files inside
+                files = []
+                for ext in self.SUPPORTED_EXTENSIONS:
+                    files.extend(list(split_dir.glob(f"*{ext}")))
+
+                if files:
+                    layout[split] = split_dir
+
+        return layout if layout else None
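
`Inspector` only looks at the top level of the given directory for `.csv`/`.parquet` files, while `detect_split_layout` checks for conventional `train`/`val`/`validation`/`test` subfolders that contain such files. A short usage sketch, assuming `gradia` 1.0.0 is installed and the script is run from a directory with data files:

```python
from gradia.core.inspector import Inspector

inspector = Inspector(".")

# All top-level CSV/Parquet files, sorted by path.
for dataset in inspector.find_datasets():
    print(dataset.name)

# Mapping like {'train': Path('train'), 'test': Path('test')}, or None if no split folders exist.
layout = inspector.detect_split_layout()
print(layout)
```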
gradia-1.0.0/gradia/core/scenario.py
ADDED
@@ -0,0 +1,118 @@
+import pandas as pd
+import numpy as np
+from dataclasses import dataclass, field
+from typing import Optional, List, Any
+
+@dataclass
+class Scenario:
+    dataset_path: str
+    target_column: str
+    task_type: str # 'classification' or 'regression'
+    is_multiclass: bool = False
+    class_count: int = 0
+    features: List[str] = field(default_factory=list)
+    recommended_model: str = "random_forest"
+
+class ScenarioInferrer:
+    """Infers the ML scenario (Task type, Target) from a dataset."""
+
+    POSSIBLE_TARGET_NAMES = ['target', 'label', 'y', 'class', 'outcome', 'price', 'score']
+
+    def infer(self, file_path: str, target_override: Optional[str] = None) -> Scenario:
+        # Load a sample to infer types
+        df = self._load_sample(file_path)
+
+        target = target_override
+        if not target:
+            target = self._guess_target(df)
+
+        if not target:
+            raise ValueError(f"Could not infer target column for {file_path}. Please name one of {self.POSSIBLE_TARGET_NAMES} or provide config.")
+
+        task_type, is_multiclass, count = self._infer_task_type(df[target])
+        features = [c for c in df.columns if c != target]
+
+        recommended_model = self._infer_model_recommendation(features)
+
+        return Scenario(
+            dataset_path=str(file_path),
+            target_column=target,
+            task_type=task_type,
+            is_multiclass=is_multiclass,
+            class_count=count,
+            features=features,
+            recommended_model=recommended_model
+        )
+
+    def _infer_model_recommendation(self, features: List[str]) -> str:
+        # Heuristic 1: Check for pixel data (Fashion MNIST, MNIST, etc.)
+        # If > 100 features and names contain 'pixel'
+        if len(features) > 100:
+            pixel_cols = [f for f in features if 'pixel' in f.lower()]
+            if len(pixel_cols) > len(features) * 0.5:
+                return "cnn"
+
+        # Heuristic 2: Tabular default
+        return "random_forest"
+
+    def _load_sample(self, path: str, n_rows: int = 1000) -> pd.DataFrame:
+        if path.endswith('.csv'):
+            return pd.read_csv(path, nrows=n_rows)
+        elif path.endswith('.parquet'):
+            # Parquet doesn't support 'nrows' efficiently same as csv sometimes,
+            # but pandas read_parquet usually loads full. For large files we might need pyarrow.
+            # For MVP assume fits in memory or use logic to limits.
+            return pd.read_parquet(path).head(n_rows)
+        else:
+            raise ValueError("Unsupported format")
+
+    def _guess_target(self, df: pd.DataFrame) -> Optional[str]:
+        # 1. Exact Name match
+        for name in self.POSSIBLE_TARGET_NAMES:
+            if name in df.columns:
+                return name
+            if name.upper() in df.columns:
+                return name.upper()
+
+        # 2. Heuristic: Avoid ID/Date columns
+        candidates = []
+        for col in df.columns:
+            lower = col.lower()
+            if not any(x in lower for x in ['id', 'date', 'time', 'created_at', 'uuid', 'index']):
+                candidates.append(col)
+
+        if candidates:
+            return candidates[-1]
+
+        # 3. Last column fallback
+        return df.columns[-1]
+
+    def _infer_task_type(self, series: pd.Series):
+        """
+        Returns (task_type, is_multiclass, class_count)
+        """
+        # Heuristics:
+        # If string/object -> Classification
+        # If float -> Regression (unless low cardinality?)
+        # If int -> Check cardinality. Low (<20) -> Classification. High -> Regression.
+
+        unique_count = series.nunique()
+        dtype = series.dtype
+
+        if pd.api.types.is_string_dtype(dtype) or pd.api.types.is_object_dtype(dtype):
+            return 'classification', unique_count > 2, unique_count
+
+        if pd.api.types.is_float_dtype(dtype):
+            # If floats are actually integers (e.g. 1.0, 0.0), check that
+            if series.apply(float.is_integer).all() and unique_count < 20:
+                return 'classification', unique_count > 2, unique_count
+            return 'regression', False, 0
+
+        if pd.api.types.is_integer_dtype(dtype):
+            if unique_count < 20: # Arbitrary threshold for MVP
+                return 'classification', unique_count > 2, unique_count
+            else:
+                return 'regression', False, 0
+
+        # Fallback
+        return 'regression', False, 0
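
`ScenarioInferrer.infer` samples up to 1000 rows, guesses the target column by name (falling back to the last non-ID-like column), and classifies the task by dtype and cardinality. A self-contained sketch, assuming `gradia` 1.0.0 is installed; the toy churn dataset below is invented purely for illustration:

```python
import pandas as pd
from gradia.core.scenario import ScenarioInferrer

# Tiny synthetic dataset; 'label' is in POSSIBLE_TARGET_NAMES, so it is picked as target.
df = pd.DataFrame({
    "tenure": [1, 12, 24, 3, 48, 7],
    "monthly_fee": [20.0, 35.5, 50.0, 22.0, 80.0, 30.0],
    "label": [0, 1, 0, 0, 1, 1],
})
df.to_csv("toy_churn.csv", index=False)

scenario = ScenarioInferrer().infer("toy_churn.csv")
print(scenario.target_column)      # "label"
print(scenario.task_type)          # "classification" (integer target, low cardinality)
print(scenario.is_multiclass)      # False (only two classes)
print(scenario.recommended_model)  # "random_forest" (few, non-pixel features)
```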
gradia-1.0.0/gradia/models/base.py
ADDED
@@ -0,0 +1,39 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional
+import numpy as np
+
+class GradiaModel(ABC):
+    """Abstract base class for all Gradia models."""
+
+    @abstractmethod
+    def fit(self, X, y, **kwargs):
+        """Train the model fully."""
+        pass
+
+    def partial_fit(self, X, y, **kwargs):
+        """Train on a batch or single epoch (optional)."""
+        raise NotImplementedError("This model does not support iterative training.")
+
+    @property
+    def supports_iterative(self) -> bool:
+        return False
+
+    @abstractmethod
+    def predict(self, X) -> np.ndarray:
+        """Make predictions."""
+        pass
+
+    @abstractmethod
+    def predict_proba(self, X) -> Optional[np.ndarray]:
+        """Make probability predictions (if applicable)."""
+        pass
+
+    @abstractmethod
+    def get_feature_importance(self) -> Optional[Dict[str, float]]:
+        """Return feature importance map if available."""
+        pass
+
+    @abstractmethod
+    def get_params(self) -> Dict[str, Any]:
+        """Return model hyperparameters."""
+        pass
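
Concrete models are expected to live in `gradia/models/sklearn_wrappers.py` (listed above but not shown in this section). The sketch below is a hypothetical minimal subclass wrapping scikit-learn's `LogisticRegression`, written only to show which methods the `GradiaModel` contract requires; it is not the package's actual wrapper.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

from gradia.models.base import GradiaModel


class LogisticWrapper(GradiaModel):
    """Hypothetical example wrapper; the real wrappers live in sklearn_wrappers.py."""

    def __init__(self, **params):
        self._model = LogisticRegression(**params)
        self._feature_names = None

    def fit(self, X, y, **kwargs):
        # Remember column names (if X is a DataFrame) for feature-importance reporting.
        self._feature_names = list(getattr(X, "columns", range(np.asarray(X).shape[1])))
        self._model.fit(X, y)
        return self

    def predict(self, X) -> np.ndarray:
        return self._model.predict(X)

    def predict_proba(self, X):
        return self._model.predict_proba(X)

    def get_feature_importance(self):
        # Use mean absolute coefficients as a rough importance proxy.
        coefs = np.abs(self._model.coef_).mean(axis=0)
        return dict(zip(map(str, self._feature_names), coefs.tolist()))

    def get_params(self):
        return self._model.get_params()
```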