gradia 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,203 @@
1
+ from typing import Any, Dict, List
2
+ from tqdm import tqdm
3
+ import pandas as pd
4
+ import numpy as np
5
+ import time
6
+ import json
7
+ import pickle
8
+ from pathlib import Path
9
+ from sklearn.model_selection import train_test_split
10
+ from sklearn.metrics import (
11
+ accuracy_score, mean_squared_error, r2_score,
12
+ precision_score, recall_score, f1_score, mean_absolute_error, confusion_matrix
13
+ )
14
+ from ..models.base import GradiaModel
15
+ from ..models.sklearn_wrappers import ModelFactory
16
+ from ..core.scenario import Scenario
17
+ from .callbacks import Callback, EventLogger
18
+
19
class Trainer:
    """Runs the end-to-end training pipeline for a Scenario.

    Loads the dataset, preprocesses features, fits the configured model
    (iteratively when the model supports it), evaluates it each "epoch",
    and notifies callbacks so the UI can follow progress via the event log.
    """

    def __init__(self, scenario: Scenario, config: Dict[str, Any], run_dir: str):
        """
        Args:
            scenario: Dataset location, target column, feature list, task type.
            config: Run configuration with 'model' and 'training' sections.
            run_dir: Directory where callbacks write the event log.
        """
        self.scenario = scenario
        self.config = config
        self.run_dir = run_dir
        print(f"DEBUG: Trainer initialized with RUN_DIR: {self.run_dir}")
        self.model: GradiaModel = ModelFactory.create(
            config['model']['type'],
            scenario.task_type,
            config['model'].get('params', {})
        )
        self.callbacks: List[Callback] = [EventLogger(run_dir)]

    def run(self):
        """Execute the pipeline: load, preprocess, split, fit, evaluate, save.

        Re-raises any failure after printing a traceback so the caller
        (e.g. the server's training thread) can observe it.
        """
        print("DEBUG: Trainer.run() started.")
        try:
            # 1. Load Data
            df = self._load_full_data()

            # 2. Preprocess: drop rows with missing values
            df = df.dropna()

            # Separate target and features
            y = df[self.scenario.target_column]
            X = df[self.scenario.features]

            # --- Robust Preprocessing ---
            # Identify non-numeric columns
            non_numeric_cols = X.select_dtypes(include=['object', 'category']).columns.tolist()

            cols_to_drop = []
            cols_to_encode = []

            for col in non_numeric_cols:
                # High-cardinality heuristic (>50 unique and >5% of rows)
                # => drop (likely an ID or name column)
                unique_count = X[col].nunique()
                if unique_count > 50 and unique_count > len(X) * 0.05:
                    cols_to_drop.append(col)
                else:
                    cols_to_encode.append(col)

            if cols_to_drop:
                X = X.drop(columns=cols_to_drop)
                print(f"Dropped high-cardinality/ID columns: {cols_to_drop}")

            # One-hot encode the remaining categorical columns
            if cols_to_encode:
                X = pd.get_dummies(X, columns=cols_to_encode, drop_first=True)
                print(f"Encoded columns: {cols_to_encode}")

            # Update the feature list for the UI (drop/encode changed the columns)
            self.scenario.features = X.columns.tolist()
            # --- End Preprocessing ---

            # Simple label encoding for a string classification target
            if self.scenario.task_type == 'classification' and y.dtype == 'object':
                y = y.astype('category').cat.codes

            # 3. Split
            test_size = self.config['training'].get('test_split', 0.2)
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=test_size,
                random_state=self.config['training'].get('random_seed', 42)
            )

            # Notify start. 'epochs' is read once here and reused by the fit
            # loop below (previously read twice).
            epochs = self.config['training'].get('epochs', 10)
            self._dispatch('on_train_begin', {
                "scenario": str(self.scenario),
                "samples": len(df),
                "features": self.scenario.features,
                "epochs": epochs
            })

            # 4. Fit Loop
            if self.model.supports_iterative:
                # For Random Forest with warm_start we grow the ensemble from
                # zero estimators; 'epochs' acts as the total estimator count.
                if hasattr(self.model.model, "n_estimators"):
                    self.model.model.n_estimators = 0

                classes = np.unique(y) if self.scenario.task_type == 'classification' else None

                # tqdm progress output to the console (module-level 'time' is
                # used; the redundant inner imports were removed)
                with tqdm(range(1, epochs + 1), desc="Training", unit="epoch", colour="green") as pbar:
                    for epoch in pbar:
                        # Small delay to visualize speed if training is too fast
                        time.sleep(0.1)

                        self.model.partial_fit(X_train, y_train, classes=classes)

                        # Evaluate and broadcast per-epoch metrics
                        metrics = self._evaluate(X_train, y_train, X_test, y_test)
                        self._dispatch('on_epoch_end', epoch, metrics)
            else:
                # Non-iterative models (SVM, KNN, DecisionTree, etc.):
                # fit once, then simulate "epochs" for user visual satisfaction
                print("Training standard model (single batch fits)...")
                self.model.fit(X_train, y_train)

                # Compute final metrics once; the model no longer changes
                metrics = self._evaluate(X_train, y_train, X_test, y_test)

                # Simulate a progress bar so the UI doesn't look broken
                with tqdm(range(1, epochs + 1), desc="Training", unit="epoch", colour="blue") as pbar:
                    for epoch in pbar:
                        time.sleep(0.1)  # Simulate work
                        # Broadcast the SAME metrics for every "epoch" since
                        # the model doesn't change, keeping the UI consistent
                        self._dispatch('on_epoch_end', epoch, metrics)

                        # Update the progress bar postfix with headline metrics
                        pf = {}
                        if 'train_acc' in metrics:
                            pf['acc'] = f"{metrics['train_acc']:.3f}"
                        if 'train_mse' in metrics:
                            pf['mse'] = f"{metrics['train_mse']:.3f}"
                        pbar.set_postfix(pf)

            # 5. Training complete: broadcast feature importance exactly once
            # (previously computed twice, with the first result unused)
            self._dispatch("on_train_end", {
                "epoch": epochs,
                "feature_importance": self.model.get_feature_importance()
            })

            # 6. Optionally persist the fitted model
            if self.config.get('save_model'):
                ckpt_dir = Path(self.run_dir) / "models" / "best-ckpt"
                ckpt_dir.mkdir(parents=True, exist_ok=True)
                ckpt_path = ckpt_dir / "model.pkl"
                with open(ckpt_path, "wb") as f:
                    pickle.dump(self.model, f)
                print(f"Model saved to {ckpt_path}")

            print(f"DEBUG: Trainer.run() finished successfully.")

        except Exception as e:
            print(f"CRITICAL ERROR IN TRAINER: {e}")
            import traceback
            traceback.print_exc()
            # Bare raise preserves the original traceback ('raise e' resets it)
            raise

    def _evaluate(self, X_train, y_train, X_test, y_test):
        """Compute train/test metrics appropriate for the scenario's task type."""
        preds_train = self.model.predict(X_train)
        preds_test = self.model.predict(X_test)

        metrics = {}
        if self.scenario.task_type == 'classification':
            metrics['train_acc'] = accuracy_score(y_train, preds_train)
            metrics['test_acc'] = accuracy_score(y_test, preds_test)

            # Weighted averages so multiclass targets are supported
            metrics['precision'] = precision_score(y_test, preds_test, average='weighted', zero_division=0)
            metrics['recall'] = recall_score(y_test, preds_test, average='weighted', zero_division=0)
            metrics['f1'] = f1_score(y_test, preds_test, average='weighted', zero_division=0)
        else:
            metrics['train_mse'] = mean_squared_error(y_train, preds_train)
            metrics['test_mse'] = mean_squared_error(y_test, preds_test)
            metrics['mae'] = mean_absolute_error(y_test, preds_test)
            metrics['r2'] = r2_score(y_test, preds_test)

        return metrics

    def _load_full_data(self):
        """Load the whole dataset into memory (MVP): CSV or Parquet by extension."""
        path = self.scenario.dataset_path
        if path.endswith('.csv'):
            return pd.read_csv(path)
        return pd.read_parquet(path)

    def _dispatch(self, method_name, *args, **kwargs):
        """Invoke *method_name* on every registered callback, in order."""
        for cb in self.callbacks:
            getattr(cb, method_name)(*args, **kwargs)
Binary file
gradia/viz/server.py ADDED
@@ -0,0 +1,228 @@
1
+ from fastapi import FastAPI, Request
2
+ from fastapi.staticfiles import StaticFiles
3
+ from fastapi.templating import Jinja2Templates
4
+ from fastapi.responses import JSONResponse, RedirectResponse
5
+ import uvicorn
6
+ import json
7
+ import threading
8
+ from pathlib import Path
9
+ from typing import Dict, Any
10
+
11
+ from ..trainer.engine import Trainer
12
+
13
+ import psutil
14
+ import time
15
+
16
app = FastAPI()

# Global State (Injected by CLI)
# These module-level globals are rebound by the CLI / start_server() before
# the server begins handling requests.
SCENARIO = None          # active Scenario object (set by CLI)
CONFIG_MGR = None        # config manager used to persist run configuration
RUN_DIR = Path(".gradia_logs").resolve()  # default; start_server() overrides
DEFAULT_CONFIG = {}      # baseline config merged with per-request overrides
TRAINER = None           # current Trainer instance (None until /api/start)
TRAINING_THREAD = None   # daemon thread running Trainer.run()
SYSTEM_THREAD = None     # daemon thread running system_monitor_loop()

# Mounts
BASE_DIR = Path(__file__).resolve().parent

app.mount("/static", StaticFiles(directory=BASE_DIR / "static"), name="static")
# Mount assets if they exist outside static, or ensure user put them in static. Assuming viz/assets
assets_path = BASE_DIR / "assets"
if assets_path.exists():
    app.mount("/assets", StaticFiles(directory=assets_path), name="assets")

templates = Jinja2Templates(directory=BASE_DIR / "templates")

# Lock shared with the training callbacks so event-log writes don't interleave
from ..trainer.callbacks import log_lock

# ... imports ...
import os
43
# System Monitor
def system_monitor_loop():
    """Background sampler: append a CPU/RAM usage event to events.jsonl ~1/sec.

    Runs forever in a daemon thread started by the FastAPI startup event.
    Writes under log_lock so lines don't interleave with trainer callbacks.
    """
    while True:
        # psutil blocks for the sampling interval, pacing the loop at ~1 Hz
        cpu = psutil.cpu_percent(interval=1)
        mem = psutil.virtual_memory().percent
        t = time.time()

        event = {
            "timestamp": t,
            "type": "system_metrics",
            # NOTE(review): "epoch" carries the unix timestamp here, not a
            # training epoch — presumably for chart x-axis; confirm in the UI.
            "data": {"cpu": cpu, "ram": mem, "epoch": t}
        }

        # Re-resolve the path every iteration: start_server() rebinds the
        # RUN_DIR global, so a path captured once at thread start could go
        # stale (the original bound it before the loop).
        log_path = RUN_DIR / "events.jsonl"
        if RUN_DIR.exists():
            with log_lock:
                with open(log_path, "a") as f:
                    f.write(json.dumps(event) + "\n")
                    f.flush()
                    os.fsync(f.fileno())
63
+
64
# Launch the background system monitor once the server has started.
@app.on_event("startup")
async def startup_event():
    """Spawn the system-metrics sampler in a daemon thread at server boot."""
    global SYSTEM_THREAD
    monitor = threading.Thread(target=system_monitor_loop, daemon=True)
    monitor.start()
    SYSTEM_THREAD = monitor
70
+
71
+
72
@app.get("/")
async def read_root(request: Request):
    """Landing page: show the dashboard once a trainer exists, else configure."""
    if TRAINER is not None:
        return templates.TemplateResponse("index.html", {"request": request, "scenario": SCENARIO})
    return RedirectResponse("/configure")
77
+
78
@app.get("/configure")
async def configure_page(request: Request):
    """Render the pre-training configuration form."""
    if SCENARIO is None:
        return "System not initialized correctly from CLI."

    context = {
        "request": request,
        "scenario": SCENARIO,
        "features": SCENARIO.features,
        "default_config": DEFAULT_CONFIG,
    }
    return templates.TemplateResponse("configure.html", context)
89
+
90
@app.post("/api/start")
async def start_training(config_data: Dict[str, Any]):
    """Build the effective run config, create the Trainer, and launch training.

    Expects config_data shaped like
    {model: {type, params}, training: {epochs, ...}, project_name, save_model}.
    Training runs in a daemon thread so this request returns immediately.
    """
    global TRAINER, TRAINING_THREAD

    # Merge received config with defaults
    full_config = DEFAULT_CONFIG.copy()

    full_config['model'] = config_data.get('model', full_config['model'])
    # BUG FIX: full_config is a *shallow* copy, so the previous in-place
    # .update() mutated the nested 'training' dict shared with DEFAULT_CONFIG,
    # leaking settings across requests. Build a fresh merged dict instead.
    full_config['training'] = {**full_config['training'], **config_data.get('training', {})}

    # New fields
    full_config['project_name'] = config_data.get('project_name', 'experiment')
    full_config['save_model'] = config_data.get('save_model', False)

    # Persist the effective config for reproducibility
    CONFIG_MGR.save(full_config)

    # Initialize Trainer
    TRAINER = Trainer(SCENARIO, full_config, str(RUN_DIR))

    # Run training off the event loop; swallow-and-print errors so the
    # daemon thread dies quietly without killing the server.
    def train_wrapper():
        time.sleep(1)  # Breathe: give the server a moment before heavy work
        try:
            TRAINER.run()
        except Exception as e:
            print(f"Training Error: {e}")

    TRAINING_THREAD = threading.Thread(target=train_wrapper, daemon=True)
    TRAINING_THREAD.start()

    return {"status": "started"}
127
+
128
@app.get("/api/events")
async def get_events():
    """Return every event recorded in the run's events.jsonl as a JSON array.

    Reads without taking the writer lock: we consume whole lines and simply
    skip any partially-written line that fails to parse, so polling readers
    never block the training/monitor writers.
    """
    event_path = RUN_DIR / "events.jsonl"
    events = []

    if event_path.exists():
        with open(event_path, "r") as f:
            for raw in f:
                if not raw.strip():
                    continue
                try:
                    events.append(json.loads(raw))
                except json.JSONDecodeError:
                    continue

    return JSONResponse(content=events)
146
+
147
@app.get("/api/report/json")
async def download_report_json():
    """Download all recorded events as a single JSON document (404 if no log)."""
    event_path = RUN_DIR / "events.jsonl"
    if not event_path.exists():
        return JSONResponse({"error": "No logs found"}, status_code=404)

    events = []
    with open(event_path, "r") as f:
        for line in f:
            if line.strip():
                # Narrowed from a bare 'except': only skip malformed/partial
                # JSON lines instead of silently hiding every error.
                try:
                    events.append(json.loads(line))
                except json.JSONDecodeError:
                    pass

    return JSONResponse(content={"project": SCENARIO.target_column if SCENARIO else "gradia", "events": events})
161
+
162
@app.get("/api/report/pdf")
async def download_report_pdf():
    """Return an HTML page optimized for print-to-PDF (no reportlab dependency)."""
    event_path = RUN_DIR / "events.jsonl"
    events = []
    if event_path.exists():
        with open(event_path, "r") as f:
            for line in f:
                # Narrowed from a bare 'except': only skip malformed lines
                try:
                    events.append(json.loads(line))
                except json.JSONDecodeError:
                    pass

    # Use .get('type') throughout: a malformed event without a 'type' key
    # previously raised KeyError and turned the whole report into a 500.
    html = f"""
    <html>
    <head>
        <title>Training Report</title>
        <style>
            body {{ font-family: sans-serif; padding: 40px; }}
            h1 {{ border-bottom: 2px solid #333; }}
            table {{ border-collapse: collapse; width: 100%; margin-top: 20px; }}
            th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
            th {{ background-color: #f2f2f2; }}
            .metric {{ color: #0066cc; font-weight: bold; }}
        </style>
    </head>
    <body onload="window.print()">
        <h1>Gradia Training Report</h1>
        <p>Target: {SCENARIO.target_column if SCENARIO else 'N/A'}</p>
        <p>Total Epochs: {len([e for e in events if e.get('type') == 'epoch_end'])}</p>

        <h2>Training History</h2>
        <table>
            <thead><tr><th>Epoch</th><th>Train Acc/MSE</th><th>Test Acc/MSE</th><th>CPU %</th><th>RAM %</th></tr></thead>
            <tbody>
    """

    # Render one table row per epoch_end event (system CPU/RAM correlation
    # is not implemented yet — dashes as placeholders)
    epochs = [e for e in events if e.get('type') == 'epoch_end']
    for e in epochs:
        d = e['data']
        html += f"<tr><td>{d['epoch']}</td><td>{d.get('train_acc', d.get('train_mse', 'N/A'))}</td><td>{d.get('test_acc', d.get('test_mse', 'N/A'))}</td><td>-</td><td>-</td></tr>"

    html += """
            </tbody>
        </table>
    </body>
    </html>
    """
    from fastapi.responses import HTMLResponse
    return HTMLResponse(content=html)
212
+
213
@app.post("/api/evaluate")
async def evaluate_model():
    """Run a full evaluation of the trained model and return its metrics."""
    if TRAINER is None:
        return JSONResponse({"error": "No model trained"}, status_code=400)

    try:
        # NOTE(review): evaluate_full() is expected on the Trainer; confirm it
        # exists in trainer.engine — only _evaluate is visible in this package.
        results = TRAINER.evaluate_full()
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
    return JSONResponse(content=results)
223
+
224
def start_server(run_dir: str, port: int = 8000):
    """Bind the module-level RUN_DIR to *run_dir* and serve the app locally.

    Args:
        run_dir: Directory holding this run's event log and artifacts.
        port: Local TCP port to listen on (default 8000).
    """
    global RUN_DIR
    resolved = Path(run_dir).resolve()
    RUN_DIR = resolved
    print(f"DEBUG: Server using RUN_DIR: {RUN_DIR}")
    uvicorn.run(app, host="127.0.0.1", port=port, log_level="error")