gradia 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gradia/__init__.py +1 -0
- gradia/cli/__init__.py +0 -0
- gradia/cli/main.py +91 -0
- gradia/core/config.py +56 -0
- gradia/core/inspector.py +37 -0
- gradia/core/scenario.py +118 -0
- gradia/models/base.py +39 -0
- gradia/models/sklearn_wrappers.py +114 -0
- gradia/trainer/callbacks.py +48 -0
- gradia/trainer/engine.py +203 -0
- gradia/viz/assets/logo.png +0 -0
- gradia/viz/server.py +228 -0
- gradia/viz/static/css/style.css +312 -0
- gradia/viz/static/js/app.js +348 -0
- gradia/viz/templates/configure.html +304 -0
- gradia/viz/templates/index.html +147 -0
- gradia-1.0.0.dist-info/METADATA +143 -0
- gradia-1.0.0.dist-info/RECORD +22 -0
- gradia-1.0.0.dist-info/WHEEL +5 -0
- gradia-1.0.0.dist-info/entry_points.txt +2 -0
- gradia-1.0.0.dist-info/licenses/LICENSE +21 -0
- gradia-1.0.0.dist-info/top_level.txt +1 -0
gradia/__init__.py
ADDED
@@ -0,0 +1 @@
__version__ = "1.0.0"
gradia/cli/__init__.py
ADDED
File without changes
gradia/cli/main.py
ADDED
@@ -0,0 +1,91 @@
import typer
import threading
import time
import os
import webbrowser
from pathlib import Path
from rich.console import Console
from ..core.inspector import Inspector
from ..core.scenario import ScenarioInferrer
from ..core.config import ConfigManager
from ..trainer.engine import Trainer
from ..viz import server

app = typer.Typer()
console = Console()

@app.callback()
def callback():
    """
    gradia: Local-first ML training visualization.
    """

@app.command()
def run(
    ctx: typer.Context,
    path: str = typer.Argument(".", help="Path to data directory"),
    target: str = typer.Option(None, help="Manually specify target column"),
    port: int = typer.Option(8000, help="Port for visualization server")
):
    """
    Starts the gradia training and visualization session.
    """
    console.rule("[bold blue]gradia v1.0.0[/bold blue]")

    # 1. Inspect
    path = Path(path).resolve()
    inspector = Inspector(path)
    datasets = inspector.find_datasets()

    if not datasets:
        console.print(f"[red]No .csv or .parquet files found in {path}[/red]")
        raise typer.Exit(code=1)

    # Select first dataset for MVP
    dataset = datasets[0]
    console.print(f"[green]Found dataset:[/green] {dataset.name}")

    # 2. Config & Scenario Reuse
    run_dir = path / ".gradia_logs"
    config_mgr = ConfigManager(run_dir)
    config = config_mgr.load_or_create()

    # We infer scenario here to pass to server, but user confirms/configures in UI
    with console.status("Inferring scenario..."):
        inferrer = ScenarioInferrer()
        scenario = inferrer.infer(str(dataset), target_override=target)

    console.print(f"Target: [bold]{scenario.target_column}[/bold] | Task: [bold]{scenario.task_type}[/bold]")
    # Session Isolation: Create unique run directory
    session_id = int(time.time())
    run_dir = Path(".gradia_logs") / f"run_{session_id}"
    run_dir.mkdir(parents=True, exist_ok=True)

    config_mgr = ConfigManager(run_dir)
    config = config_mgr.load_or_create()

    # Apply Smart Recommendation
    config['model']['type'] = scenario.recommended_model
    console.print(f"[cyan]Smart Suggestion:[/cyan] Using [bold]{scenario.recommended_model}[/bold] for this dataset.")

    console.print(f"[bold green]Configuration moved to Web UI[/bold green]")
    console.print(f"Visualization running at http://localhost:{port}")
    console.print(f"Logs: {run_dir.resolve()}")

    # 3. Launch Server
    # We inject state into the server module before starting it
    server.SCENARIO = scenario
    server.CONFIG_MGR = config_mgr
    server.RUN_DIR = run_dir
    server.DEFAULT_CONFIG = config

    # Launch browser
    threading.Timer(1.5, lambda: webbrowser.open(f"http://localhost:{port}/configure")).start()

    # Start server (blocking the main thread is fine for now, as we don't have a separate training thread YET)
    # The training thread will be spawned by the server upon API request.
    server.start_server(str(run_dir), port)

if __name__ == "__main__":
    app()
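For reference, a minimal sketch of exercising this command through typer's test runner (assuming the console script registered in entry_points.txt maps `gradia` to this `app`, so the shell equivalent would be something like `gradia run ./data --target price --port 8000`). Invoking `run` for real blocks on the server, so the sketch only asks for help output:

from typer.testing import CliRunner
from gradia.cli.main import app

runner = CliRunner()
result = runner.invoke(app, ["run", "--help"])
print(result.output)  # documents the path argument plus the --target and --port options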
gradia/core/config.py
ADDED
@@ -0,0 +1,56 @@
import yaml
from pathlib import Path
from typing import Any, Dict

class ConfigManager:
    """Manages gradia configuration."""

    DEFAULT_CONFIG = {
        'model': {
            'type': 'auto',  # auto, linear, random_forest
            'params': {}
        },
        'training': {
            'test_split': 0.2,
            'random_seed': 42,
            'shuffle': True
        },
        'scenario': {
            'target': None,  # Auto-detect
            'task': None  # Auto-detect
        }
    }

    def __init__(self, run_dir: str = ".gradia_logs"):
        self.run_dir = Path(run_dir)
        self.config_path = self.run_dir / "config.yaml"

    def load_or_create(self, user_overrides: Dict[str, Any] = None) -> Dict[str, Any]:
        config = self.DEFAULT_CONFIG.copy()

        # Load existing if any (feature for restart, maybe not for MVP run-once)
        # For immutable runs, we usually generate NEW config.
        # But if gradia.yaml exists in ROOT, we load it.

        root_config = Path("gradia.yaml")
        if root_config.exists():
            with open(root_config, 'r') as f:
                user_config = yaml.safe_load(f)
                self._update_recursive(config, user_config)

        if user_overrides:
            self._update_recursive(config, user_overrides)

        return config

    def save(self, config: Dict[str, Any]):
        self.run_dir.mkdir(exist_ok=True)
        with open(self.config_path, 'w') as f:
            yaml.dump(config, f)

    def _update_recursive(self, base: Dict, update: Dict):
        for k, v in update.items():
            if k in base and isinstance(base[k], dict) and isinstance(v, dict):
                self._update_recursive(base[k], v)
            else:
                base[k] = v
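A short usage sketch (the run directory name and override values are hypothetical, and it assumes no gradia.yaml in the working directory). `load_or_create` merges three layers: class defaults, a `gradia.yaml` in the current directory, then explicit `user_overrides`:

from gradia.core.config import ConfigManager

mgr = ConfigManager(".gradia_logs/run_demo")

# Defaults <- gradia.yaml (if present in CWD) <- explicit overrides, in that order
config = mgr.load_or_create(user_overrides={"training": {"test_split": 0.3}})
assert config["training"]["test_split"] == 0.3
assert config["training"]["random_seed"] == 42  # untouched default

mgr.save(config)  # writes .gradia_logs/run_demo/config.yaml

One caveat worth noting: `DEFAULT_CONFIG.copy()` is a shallow copy, so `_update_recursive` appears to write into the nested dicts of the class-level `DEFAULT_CONFIG`, meaning overrides can leak between calls in the same process.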
gradia/core/inspector.py
ADDED
@@ -0,0 +1,37 @@
import os
from pathlib import Path
from typing import List, Optional

class Inspector:
    """Scans the working directory for potential dataset files."""

    SUPPORTED_EXTENSIONS = {'.csv', '.parquet'}

    def __init__(self, root_dir: str = "."):
        self.root_dir = Path(root_dir)

    def find_datasets(self) -> List[Path]:
        """Finds all supported dataset files in the root directory."""
        datasets = []
        for ext in self.SUPPORTED_EXTENSIONS:
            datasets.extend(self.root_dir.glob(f"*{ext}"))
        return sorted(datasets)

    def detect_split_layout(self):
        """
        Detects if proper 'train'/'val'/'test' folders exist.
        Returns a dictionary with paths or None.
        """
        layout = {}
        for split in ['train', 'val', 'validation', 'test']:
            split_dir = self.root_dir / split
            if split_dir.exists() and split_dir.is_dir():
                # Check for files inside
                files = []
                for ext in self.SUPPORTED_EXTENSIONS:
                    files.extend(list(split_dir.glob(f"*{ext}")))

                if files:
                    layout[split] = split_dir

        return layout if layout else None
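A usage sketch (the directory layout is hypothetical):

from gradia.core.inspector import Inspector

inspector = Inspector("./data")
print(inspector.find_datasets())
# e.g. [PosixPath('data/sales.csv'), PosixPath('data/sales.parquet')]

print(inspector.detect_split_layout())
# e.g. {'train': PosixPath('data/train'), 'test': PosixPath('data/test')}, or None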
gradia/core/scenario.py
ADDED
@@ -0,0 +1,118 @@
import pandas as pd
import numpy as np
from dataclasses import dataclass, field
from typing import Optional, List, Any

@dataclass
class Scenario:
    dataset_path: str
    target_column: str
    task_type: str  # 'classification' or 'regression'
    is_multiclass: bool = False
    class_count: int = 0
    features: List[str] = field(default_factory=list)
    recommended_model: str = "random_forest"

class ScenarioInferrer:
    """Infers the ML scenario (task type, target) from a dataset."""

    POSSIBLE_TARGET_NAMES = ['target', 'label', 'y', 'class', 'outcome', 'price', 'score']

    def infer(self, file_path: str, target_override: Optional[str] = None) -> Scenario:
        # Load a sample to infer types
        df = self._load_sample(file_path)

        target = target_override
        if not target:
            target = self._guess_target(df)

        if not target:
            raise ValueError(f"Could not infer target column for {file_path}. Please name one of {self.POSSIBLE_TARGET_NAMES} or provide config.")

        task_type, is_multiclass, count = self._infer_task_type(df[target])
        features = [c for c in df.columns if c != target]

        recommended_model = self._infer_model_recommendation(features)

        return Scenario(
            dataset_path=str(file_path),
            target_column=target,
            task_type=task_type,
            is_multiclass=is_multiclass,
            class_count=count,
            features=features,
            recommended_model=recommended_model
        )

    def _infer_model_recommendation(self, features: List[str]) -> str:
        # Heuristic 1: Check for pixel data (Fashion MNIST, MNIST, etc.):
        # more than 100 features, most of whose names contain 'pixel'
        if len(features) > 100:
            pixel_cols = [f for f in features if 'pixel' in f.lower()]
            if len(pixel_cols) > len(features) * 0.5:
                return "cnn"

        # Heuristic 2: Tabular default
        return "random_forest"

    def _load_sample(self, path: str, n_rows: int = 1000) -> pd.DataFrame:
        if path.endswith('.csv'):
            return pd.read_csv(path, nrows=n_rows)
        elif path.endswith('.parquet'):
            # Parquet readers don't expose an 'nrows' shortcut the way read_csv does;
            # pandas read_parquet loads the full file. For large files we might need pyarrow.
            # For MVP, assume the file fits in memory and truncate after loading.
            return pd.read_parquet(path).head(n_rows)
        else:
            raise ValueError("Unsupported format")

    def _guess_target(self, df: pd.DataFrame) -> Optional[str]:
        # 1. Exact name match
        for name in self.POSSIBLE_TARGET_NAMES:
            if name in df.columns:
                return name
            if name.upper() in df.columns:
                return name.upper()

        # 2. Heuristic: Avoid ID/Date columns
        candidates = []
        for col in df.columns:
            lower = col.lower()
            if not any(x in lower for x in ['id', 'date', 'time', 'created_at', 'uuid', 'index']):
                candidates.append(col)

        if candidates:
            return candidates[-1]

        # 3. Last column fallback
        return df.columns[-1]

    def _infer_task_type(self, series: pd.Series):
        """
        Returns (task_type, is_multiclass, class_count)
        """
        # Heuristics:
        # If string/object -> Classification
        # If float -> Regression (unless low cardinality)
        # If int -> Check cardinality. Low (<20) -> Classification. High -> Regression.

        unique_count = series.nunique()
        dtype = series.dtype

        if pd.api.types.is_string_dtype(dtype) or pd.api.types.is_object_dtype(dtype):
            return 'classification', unique_count > 2, unique_count

        if pd.api.types.is_float_dtype(dtype):
            # If floats are actually integers (e.g. 1.0, 0.0), treat as labels
            if series.apply(float.is_integer).all() and unique_count < 20:
                return 'classification', unique_count > 2, unique_count
            return 'regression', False, 0

        if pd.api.types.is_integer_dtype(dtype):
            if unique_count < 20:  # Arbitrary threshold for MVP
                return 'classification', unique_count > 2, unique_count
            else:
                return 'regression', False, 0

        # Fallback
        return 'regression', False, 0
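A self-contained sketch of the inference heuristics (the tiny dataset is invented for illustration):

import pandas as pd
from gradia.core.scenario import ScenarioInferrer

pd.DataFrame({
    "age": [25, 32, 47, 51],
    "income": [40_000, 55_000, 72_000, 68_000],
    "label": [0, 1, 1, 0],
}).to_csv("demo.csv", index=False)

scenario = ScenarioInferrer().infer("demo.csv")
print(scenario.target_column)      # 'label' (exact-name match)
print(scenario.task_type)          # 'classification' (int dtype, low cardinality)
print(scenario.recommended_model)  # 'random_forest' (tabular default)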
gradia/models/base.py
ADDED
@@ -0,0 +1,39 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
import numpy as np

class GradiaModel(ABC):
    """Abstract base class for all Gradia models."""

    @abstractmethod
    def fit(self, X, y, **kwargs):
        """Train the model fully."""
        pass

    def partial_fit(self, X, y, **kwargs):
        """Train on a batch or single epoch (optional)."""
        raise NotImplementedError("This model does not support iterative training.")

    @property
    def supports_iterative(self) -> bool:
        return False

    @abstractmethod
    def predict(self, X) -> np.ndarray:
        """Make predictions."""
        pass

    @abstractmethod
    def predict_proba(self, X) -> Optional[np.ndarray]:
        """Make probability predictions (if applicable)."""
        pass

    @abstractmethod
    def get_feature_importance(self) -> Optional[Dict[str, float]]:
        """Return feature importance map if available."""
        pass

    @abstractmethod
    def get_params(self) -> Dict[str, Any]:
        """Return model hyperparameters."""
        pass
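To illustrate the contract, a hypothetical minimal subclass (not part of the package; `MeanBaseline` is invented here):

import numpy as np
from gradia.models.base import GradiaModel

class MeanBaseline(GradiaModel):
    """Hypothetical regressor: always predicts the training-set mean."""

    def fit(self, X, y, **kwargs):
        self.mean_ = float(np.mean(y))

    def predict(self, X) -> np.ndarray:
        return np.full(len(X), self.mean_)

    def predict_proba(self, X):
        return None  # regression model: no class probabilities

    def get_feature_importance(self):
        return None  # no per-feature signal to report

    def get_params(self):
        return {}

With `partial_fit` and `supports_iterative` left at their defaults, the trainer would treat such a model as fit-once rather than epoch-by-epoch.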
gradia/models/sklearn_wrappers.py
ADDED
@@ -0,0 +1,114 @@
from typing import Any, Dict, Optional
import numpy as np
from sklearn.linear_model import LogisticRegression, LinearRegression, SGDClassifier, SGDRegressor
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from .base import GradiaModel

class SklearnWrapper(GradiaModel):
    def __init__(self, model, feature_names=None):
        self.model = model
        self.feature_names = feature_names

    def fit(self, X, y, **kwargs):
        self.model.fit(X, y)
        if hasattr(X, "columns"):
            self.feature_names = list(X.columns)

    def partial_fit(self, X, y, **kwargs):
        # For SGD (true partial_fit)
        if hasattr(self.model, "partial_fit"):
            classes = kwargs.get('classes')
            if classes is not None:
                self.model.partial_fit(X, y, classes=classes)
            else:
                self.model.partial_fit(X, y)

        # For RandomForest (warm_start simulation)
        elif hasattr(self.model, "warm_start") and self.model.warm_start:
            # Increase estimators by 1 step
            self.model.n_estimators += 1
            self.model.fit(X, y)

        if hasattr(X, "columns"):
            self.feature_names = list(X.columns)

    @property
    def supports_iterative(self) -> bool:
        return hasattr(self.model, "partial_fit") or (hasattr(self.model, "warm_start") and self.model.warm_start)

    def predict(self, X) -> np.ndarray:
        return self.model.predict(X)

    def predict_proba(self, X) -> Optional[np.ndarray]:
        if hasattr(self.model, "predict_proba"):
            return self.model.predict_proba(X)
        return None

    def get_feature_importance(self) -> Optional[Dict[str, float]]:
        if not self.feature_names:
            return None

        importances = None
        if hasattr(self.model, "coef_"):
            # Linear models
            importances = np.abs(self.model.coef_)
            if importances.ndim > 1:
                importances = importances.mean(axis=0)  # Multiclass avg
        elif hasattr(self.model, "feature_importances_"):
            # Tree-based models
            importances = self.model.feature_importances_

        if importances is not None:
            return dict(zip(self.feature_names, importances))
        return None

    def get_params(self) -> Dict[str, Any]:
        return self.model.get_params()

class ModelFactory:
    @staticmethod
    def create(model_type: str, task_type: str, params: Dict[str, Any] = {}) -> GradiaModel:
        # Standard Linear
        if model_type == 'linear':
            if task_type == 'classification':
                return SklearnWrapper(LogisticRegression(**params))
            else:
                return SklearnWrapper(LinearRegression(**params))

        # Random Forest
        elif model_type == 'random_forest':
            # Enable warm_start for iterative viz if not specified
            if 'warm_start' not in params:
                params['warm_start'] = True
            if task_type == 'classification':
                return SklearnWrapper(RandomForestClassifier(**params))
            else:
                return SklearnWrapper(RandomForestRegressor(**params))

        # SGD (Iterative Linear)
        elif model_type == 'sgd':
            # Optimizer/learning-rate params from the UI may need mapping to sklearn args:
            # a user might pass 'lr', while sklearn expects 'eta0' plus
            # learning_rate='constant'/'invscaling'. That normalization belongs in the
            # CLI or here; for the MVP, assume params are already sklearn-compatible.
            if task_type == 'classification':
                return SklearnWrapper(SGDClassifier(**params))
            else:
                return SklearnWrapper(SGDRegressor(**params))

        # MLP / CNN (Basic Neural Net)
        elif model_type in ['mlp', 'cnn']:
            from sklearn.neural_network import MLPClassifier, MLPRegressor
            if task_type == 'classification':
                # hidden_layer_sizes default for simple MNIST-like data
                if 'hidden_layer_sizes' not in params:
                    params['hidden_layer_sizes'] = (100, 50)
                return SklearnWrapper(MLPClassifier(warm_start=True, **params))
            else:
                return SklearnWrapper(MLPRegressor(warm_start=True, **params))

        # Default fallback
        if task_type == 'classification':
            return SklearnWrapper(RandomForestClassifier(warm_start=True, **params))
        else:
            return SklearnWrapper(RandomForestRegressor(warm_start=True, **params))
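A usage sketch of the factory (synthetic data via scikit-learn; column names are invented):

import pandas as pd
from sklearn.datasets import make_classification
from gradia.models.sklearn_wrappers import ModelFactory

X, y = make_classification(n_samples=200, n_features=5, random_state=42)
X = pd.DataFrame(X, columns=[f"f{i}" for i in range(5)])

model = ModelFactory.create("random_forest", "classification", {"n_estimators": 50})
model.fit(X, y)
print(model.supports_iterative)        # True: the factory enabled warm_start
print(model.get_feature_importance())  # {'f0': 0.21, 'f1': ...} (values vary)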
gradia/trainer/callbacks.py
ADDED
@@ -0,0 +1,48 @@
from typing import Dict, Any, List
import json
import time
import threading
import os
from pathlib import Path

# Shared lock for writing to the log file from multiple threads (Trainer vs SystemMonitor)
log_lock = threading.Lock()

class Callback:
    def on_train_begin(self, logs: Dict[str, Any] = {}): pass
    def on_epoch_end(self, epoch: int, logs: Dict[str, Any] = {}): pass
    def on_train_end(self, logs: Dict[str, Any] = {}): pass

class EventLogger(Callback):
    """
    Logs events to a file which can be tailed by the UI server.
    Also keeps an in-memory buffer.
    """
    def __init__(self, log_dir: str):
        self.log_path = Path(log_dir) / "events.jsonl"
        self.log_path.parent.mkdir(parents=True, exist_ok=True)
        # Clear existing
        if self.log_path.exists():
            with log_lock:
                # Double check to avoid race if multiple loggers init (rare)
                if self.log_path.exists():
                    self.log_path.unlink()

    def _emit(self, event_type: str, data: Dict[str, Any]):
        payload = {
            "timestamp": time.time(),
            "type": event_type,
            "data": data
        }
        with log_lock:
            with open(self.log_path, "a") as f:
                f.write(json.dumps(payload) + "\n")

    def on_train_begin(self, logs={}):
        self._emit("train_begin", logs)

    def on_epoch_end(self, epoch: int, logs={}):
        self._emit("epoch_end", {"epoch": epoch, **logs})

    def on_train_end(self, logs={}):
        self._emit("train_end", logs)
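A sketch of the emitted stream (the directory and metric names here are illustrative):

from gradia.trainer.callbacks import EventLogger

logger = EventLogger(".gradia_logs/run_demo")
logger.on_train_begin({"model": "random_forest"})
logger.on_epoch_end(0, {"train_loss": 0.42, "val_loss": 0.47})
logger.on_train_end({"status": "complete"})

# .gradia_logs/run_demo/events.jsonl now holds one JSON object per line, e.g.:
# {"timestamp": 1700000000.0, "type": "epoch_end",
#  "data": {"epoch": 0, "train_loss": 0.42, "val_loss": 0.47}}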